diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index abd139420db9..000000000000 --- a/.dockerignore +++ /dev/null @@ -1,3 +0,0 @@ -target/ -.git/ -py-polars/wheels/ \ No newline at end of file diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000000..a99a2770491c --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,10 @@ +* @ritchie46 @c-peters + +/crates/ @ritchie46 @orlp @c-peters +/crates/polars-sql/ @ritchie46 @orlp @c-peters @alexander-beedie +/crates/polars-parquet/ @ritchie46 @orlp @c-peters @coastalwhite +/crates/polars-time/ @ritchie46 @orlp @c-peters @MarcoGorelli +/crates/polars-python/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa +/crates/polars-python/src/lazyframe/visit.rs @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- +/crates/polars-python/src/lazyframe/visitor/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa @wence- +/py-polars/ @ritchie46 @c-peters @alexander-beedie @MarcoGorelli @reswqa diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 000000000000..12a6354d20af --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,68 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers +pledge to make participation in our project and our community a harassment-free experience for +everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity +and expression, level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +- The use of sexualized language or imagery and unwelcome sexual attention or advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic address, without explicit + permission +- Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are +expected to take appropriate and fair corrective action in response to any instances of unacceptable +behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, +code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or +to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when an individual is +representing the project or its community in public spaces. Examples of representing a project or +community include using an official project e-mail address, posting via an official social media +account, or acting as an appointed representative at an online or offline event. Representation of a +project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting +the project team at ritchie46@gmail.com. All complaints will be reviewed and investigated and will +result in a response that is deemed necessary and appropriate to the circumstances. The project team +is obligated to maintain confidentiality with regard to the reporter of an incident. Further details +of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face +temporary or permanent repercussions as determined by other members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at +https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000000..4eaa95ad65ec --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: ritchie46 diff --git a/.github/ISSUE_TEMPLATE/bug_report_python.yml b/.github/ISSUE_TEMPLATE/bug_report_python.yml new file mode 100644 index 000000000000..ffd99fca3467 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report_python.yml @@ -0,0 +1,73 @@ +name: '🐞 Bug report - Python' +description: Report an issue with Python Polars. +labels: [bug, needs triage, python] + +body: + - type: checkboxes + id: checks + attributes: + label: Checks + options: + - label: I have checked that this issue has not already been reported. + required: true + - label: I have confirmed this bug exists on the [latest version](https://pypi.org/project/polars/) of Polars. + required: true + + - type: textarea + id: example + attributes: + label: Reproducible example + description: > + Please follow [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) on how to + provide a minimal, copy-pasteable example. Include the (wrong) output if applicable. + value: | + ```python + + ``` + validations: + required: true + + - type: textarea + id: logs + attributes: + label: Log output + description: > + Set the environment variable ``POLARS_VERBOSE=1`` before running the query. + Paste the output of ``stderr`` here. + render: shell + + - type: textarea + id: problem + attributes: + label: Issue description + description: > + Provide any additional information you think might be relevant. + validations: + required: true + + - type: textarea + id: expected-behavior + attributes: + label: Expected behavior + description: > + Describe or show a code example of the expected behavior. + validations: + required: true + + - type: textarea + id: version + attributes: + label: Installed versions + description: > + Paste the output of ``pl.show_versions()`` + value: | +
+ + ``` + Replace this line with the output of pl.show_versions(). Leave the backticks in place. + ``` + +
+ validations: + required: true + diff --git a/.github/ISSUE_TEMPLATE/bug_report_rust.yml b/.github/ISSUE_TEMPLATE/bug_report_rust.yml new file mode 100644 index 000000000000..2013020dd4a2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report_rust.yml @@ -0,0 +1,70 @@ +name: '🐞 Bug report - Rust' +description: Report an issue with Rust Polars. +labels: [bug, needs triage, rust] + +body: + - type: checkboxes + id: checks + attributes: + label: Checks + options: + - label: I have checked that this issue has not already been reported. + required: true + - label: I have confirmed this bug exists on the [latest version](https://crates.io/crates/polars) of Polars. + required: true + + - type: textarea + id: example + attributes: + label: Reproducible example + description: > + Please follow [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) on how to + provide a minimal, copy-pasteable example. Include the (wrong) output if applicable. + value: | + ```rust + + ``` + validations: + required: true + + - type: textarea + id: logs + attributes: + label: Log output + description: > + Set the environment variable ``POLARS_VERBOSE=1`` before running the query. + Paste the output of ``stderr`` here. + render: shell + + - type: textarea + id: problem + attributes: + label: Issue description + description: > + Provide any additional information you think might be relevant. + validations: + required: true + + - type: textarea + id: expected-behavior + attributes: + label: Expected behavior + description: > + Describe or show a code example of the expected behavior. + validations: + required: true + + - type: textarea + id: version + attributes: + label: Installed versions + description: > + List the feature gates you used. + value: | +
+ + Replace this line with a list of feature gates + +
+ validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000000..d7ae24c827ff --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,15 @@ +# Ref: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository#configuring-the-template-chooser +blank_issues_enabled: true +contact_links: +- name: '❓ Question - Python' + url: https://stackoverflow.com/questions/ask?tags=python-polars%20python + about: | + Ask a question about Python Polars on Stack Overflow. +- name: '❓ Question - Rust' + url: https://stackoverflow.com/questions/ask?tags=rust-polars%20rust + about: | + Ask a question about Rust Polars on Stack Overflow. +- name: '💬 Discord server' + url: https://discord.gg/4UfP5cfBE7 + about: | + Chat with the community and Polars maintainers about the usage and development of the project. diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml new file mode 100644 index 000000000000..64cc3c1d32c4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -0,0 +1,23 @@ +name: '📖 Documentation improvement' +description: Report an issue with the documentation. +labels: [documentation] + +body: + - type: textarea + id: description + attributes: + label: Description + description: > + Describe the issue with the documentation and how it can be fixed or improved. + validations: + required: true + + - type: input + id: link + attributes: + label: Link + description: > + Provide a link to the existing documentation, if applicable. + placeholder: ex. https://docs.pola.rs/api/python/dev/... + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 000000000000..eed3105bf95f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,14 @@ +name: '✨ Feature request' +description: Suggest a new feature or enhancement for Polars. +labels: [enhancement] + +body: + - type: textarea + id: description + attributes: + label: Description + description: > + Describe the feature or enhancement and explain why it should be implemented. + Include a code example if applicable. + validations: + required: true diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 000000000000..6b0e3b4c274c --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,20 @@ +coverage: + status: + project: off + patch: off +comment: + require_changes: true +ignore: + - crates/polars-arrow/src/io/flight/*.rs + - crates/polars-arrow/src/io/ipc/append/*.rs + - crates/polars-arrow/src/io/ipc/read/array/union.rs + - crates/polars-arrow/src/io/ipc/read/array/map.rs + - crates/polars-arrow/src/io/ipc/read/array/binary.rs + - crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs + - crates/polars-arrow/src/io/ipc/read/array/null.rs + - crates/polars-arrow/src/io/ipc/write/serialize/fixed_size_binary.rs + - crates/polars-arrow/src/io/ipc/write/serialize/union.rs + - crates/polars-arrow/src/io/ipc/write/serialize/map.rs + - crates/polars-arrow/src/array/union/*.rs + - crates/polars-arrow/src/array/map/*.rs + - crates/polars-arrow/src/array/fixed_size_binary/*.rs diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..b95366972728 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,66 @@ +version: 2 +updates: + # GitHub Actions + - package-ecosystem: github-actions + directory: '/' + schedule: + interval: monthly + ignore: + - dependency-name: '*' + update-types: ['version-update:semver-patch'] + commit-message: + prefix: ci + labels: ['skip changelog'] + groups: + ci: + patterns: + - '*' + + # Rust Polars + - package-ecosystem: cargo + directory: '/' + schedule: + interval: monthly + ignore: + - dependency-name: '*' + update-types: ['version-update:semver-patch'] + commit-message: + prefix: build + prefix-development: chore(rust) + labels: ['skip changelog'] + groups: + rust: + patterns: + - '*' + + # Python Polars + - package-ecosystem: pip + directory: py-polars + schedule: + interval: monthly + ignore: + - dependency-name: '*' + update-types: ['version-update:semver-patch'] + commit-message: + prefix: chore(python) + labels: ['skip changelog'] + groups: + python: + patterns: + - '*' + + # Documentation + - package-ecosystem: pip + directory: docs + schedule: + interval: monthly + ignore: + - dependency-name: '*' + update-types: ['version-update:semver-patch'] + commit-message: + prefix: chore(python) + labels: ['skip changelog'] + groups: + documentation: + patterns: + - '*' diff --git a/.github/pr-title-checker-config.json b/.github/pr-title-checker-config.json new file mode 100644 index 000000000000..65f09d5aab9d --- /dev/null +++ b/.github/pr-title-checker-config.json @@ -0,0 +1,15 @@ +{ + "LABEL": { + "name": "title needs formatting", + "color": "FF0000" + }, + "CHECKS": { + "regexp": "^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)(\\((python|rust)\\!?(,(python|rust)\\!?)?\\))?\\!?\\: [A-Z].*[^\\.\\!\\?,… ]$", + "ignoreLabels": ["skip changelog"] + }, + "MESSAGES": { + "success": "PR title OK!", + "failure": "Invalid PR title! Please update according to the contributing guidelines: https://docs.pola.rs/development/contributing/#pull-requests", + "notice": "" + } +} diff --git a/.github/release-drafter-python.yml b/.github/release-drafter-python.yml new file mode 100644 index 000000000000..284dfbb7b428 --- /dev/null +++ b/.github/release-drafter-python.yml @@ -0,0 +1,41 @@ +_extends: polars:.github/release-drafter.yml + +name-template: Python Polars $RESOLVED_VERSION +tag-template: py-$RESOLVED_VERSION +tag-prefix: py- + +include-labels: + - python + +version-resolver: + major: + labels: + - breaking + - breaking python + minor: + labels: + - performance + - enhancement + default: patch + +categories: + - title: 🏆 Highlights + labels: highlight + - title: 💥 Breaking changes + labels: + - breaking + - breaking python + - title: ⚠️ Deprecations + labels: deprecation + - title: 🚀 Performance improvements + labels: performance + - title: ✨ Enhancements + labels: enhancement + - title: 🐞 Bug fixes + labels: fix + - title: 📖 Documentation + labels: documentation + - title: 📦 Build system + labels: build + - title: 🛠️ Other improvements + labels: internal diff --git a/.github/release-drafter-rust.yml b/.github/release-drafter-rust.yml new file mode 100644 index 000000000000..43f0a8ecc7e8 --- /dev/null +++ b/.github/release-drafter-rust.yml @@ -0,0 +1,37 @@ +_extends: polars:.github/release-drafter.yml + +name-template: Rust Polars $RESOLVED_VERSION +tag-template: rs-$RESOLVED_VERSION +tag-prefix: rs- + +include-labels: + - rust + +version-resolver: + minor: + labels: + - breaking + - breaking rust + default: patch + +categories: + - title: 🏆 Highlights + labels: highlight + - title: 💥 Breaking changes + labels: + - breaking + - breaking rust + - title: 🚀 Performance improvements + labels: performance + - title: ✨ Enhancements + labels: enhancement + - title: 🐞 Bug fixes + labels: fix + - title: 📖 Documentation + labels: documentation + - title: 📦 Build system + labels: build + - title: 🛠️ Other improvements + labels: + - deprecation + - internal diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 000000000000..8216254ab6bc --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,62 @@ +exclude-labels: + - skip changelog + - release + +change-template: '- $TITLE (#$NUMBER)' +change-title-escapes: '\<*_&' +replacers: + # Remove conventional commits from titles + - search: '/- (build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)(\(.*\))?(\!)?\: /g' + replace: '- ' + +autolabeler: + - label: rust + title: + # Example: feat(rust): ... + - '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)(\(.*rust.*\))?\!?\: /' + - label: python + title: + # Example: feat(python): ... + - '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)(\(.*python.*\))?\!?\: /' + - label: breaking + title: + # Example: feat!: ... + - '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)(\(.*\))?\!\: /' + - label: breaking rust + title: + # Example: feat(rust!, python): ... + - '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)\(.*rust\!.*\)\: /' + - label: breaking python + title: + # Example: feat(python!): ... + - '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)\(.*python\!.*\)\: /' + - label: build + title: + - '/^build/' + - label: internal + title: + - '/^(chore|ci|refactor|test)/' + - label: deprecation + title: + - '/^depr/' + - label: documentation + title: + - '/^docs/' + - label: enhancement + title: + - '/^feat/' + - label: fix + title: + - '/^fix/' + - label: performance + title: + - '/^perf/' + - label: release + title: + - '/^release/' + +template: | + $CHANGES + + Thank you to all our contributors for making this release possible! + $CONTRIBUTORS diff --git a/.github/scripts/test_bytecode_parser.py b/.github/scripts/test_bytecode_parser.py new file mode 100644 index 000000000000..2abe9bf5de13 --- /dev/null +++ b/.github/scripts/test_bytecode_parser.py @@ -0,0 +1,101 @@ +""" +Minimal testing script of the BytecodeParser class. + +This can be run without polars installed, and so can be easily run in CI +over all supported Python versions. + +All that needs to be installed is pytest, numpy, and ipython. + +Usage: + + $ PYTHONPATH=py-polars pytest .github/scripts/test_bytecode_parser.py + +Running it without `PYTHONPATH` set will result in the test failing. +""" + +import datetime as dt # noqa: F401 +import subprocess +from datetime import datetime # noqa: F401 +from typing import Any, Callable + +import pytest +from polars._utils.udfs import BytecodeParser +from tests.unit.operations.map.test_inefficient_map_warning import ( + MY_DICT, + NOOP_TEST_CASES, + TEST_CASES, +) + + +@pytest.mark.parametrize( + ("col", "func", "expected"), + TEST_CASES, +) +def test_bytecode_parser_expression(col: str, func: str, expected: str) -> None: + bytecode_parser = BytecodeParser(eval(func), map_target="expr") + result = bytecode_parser.to_expression(col) + assert result == expected + + +@pytest.mark.parametrize( + ("col", "func", "expected"), + TEST_CASES, +) +def test_bytecode_parser_expression_in_ipython( + col: str, func: Callable[[Any], Any], expected: str +) -> None: + script = ( + "from polars._utils.udfs import BytecodeParser; " + "import datetime as dt; " + "from datetime import datetime; " + "import numpy as np; " + "import json; " + f"MY_DICT = {MY_DICT};" + f'bytecode_parser = BytecodeParser({func}, map_target="expr");' + f'print(bytecode_parser.to_expression("{col}"));' + ) + + output = subprocess.run(["ipython", "-c", script], text=True, capture_output=True) + assert expected == output.stdout.rstrip("\n") + + +@pytest.mark.parametrize( + "func", + NOOP_TEST_CASES, +) +def test_bytecode_parser_expression_noop(func: str) -> None: + parser = BytecodeParser(eval(func), map_target="expr") + assert not parser.can_attempt_rewrite() or not parser.to_expression("x") + + +@pytest.mark.parametrize( + "func", + NOOP_TEST_CASES, +) +def test_bytecode_parser_expression_noop_in_ipython(func: str) -> None: + script = ( + "from polars._utils.udfs import BytecodeParser; " + f"MY_DICT = {MY_DICT};" + f'parser = BytecodeParser({func}, map_target="expr");' + f'print(not parser.can_attempt_rewrite() or not parser.to_expression("x"));' + ) + + output = subprocess.run(["ipython", "-c", script], text=True, capture_output=True) + assert output.stdout == "True\n" + + +def test_local_imports() -> None: + import datetime as dt # noqa: F811 + import json + + bytecode_parser = BytecodeParser(lambda x: json.loads(x), map_target="expr") + result = bytecode_parser.to_expression("x") + expected = 'pl.col("x").str.json_decode()' + assert result == expected + + bytecode_parser = BytecodeParser( + lambda x: dt.datetime.strptime(x, "%Y-%m-%d"), map_target="expr" + ) + result = bytecode_parser.to_expression("x") + expected = 'pl.col("x").str.to_datetime(format="%Y-%m-%d")' + assert result == expected diff --git a/.github/workflows/benchmark-remote.yml b/.github/workflows/benchmark-remote.yml new file mode 100644 index 000000000000..24b3e0959d7a --- /dev/null +++ b/.github/workflows/benchmark-remote.yml @@ -0,0 +1,80 @@ +name: Remote Benchmark + +on: + workflow_dispatch: + push: + branches: + - 'main' + paths: + - crates/** + pull_request: + types: [ labeled ] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event.label.name == 'needs-bench' }} + +env: + SCALE_FACTOR: '10.0' + +jobs: + main: + if: ${{ github.ref == 'refs/heads/main' || github.event.label.name == 'needs-bench' }} + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + + - name: Clone Polars-benchmark + run: | + git clone --depth=1 https://github.com/pola-rs/polars-benchmark.git + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Create virtual environment + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + working-directory: py-polars + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose --index-strategy=unsafe-best-match + + - name: Install Polars-Benchmark dependencies + working-directory: polars-benchmark + run: | + uv pip install --compile-bytecode -r requirements-polars-only.txt + + - name: Set up Rust + run: rustup show + + - name: Install Polars release build + env: + RUSTFLAGS: -C embed-bitcode -D warnings + working-directory: py-polars + run: | + maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native + + - name: Run benchmark + working-directory: polars-benchmark + run: | + "$HOME/py-polars-cache/run-benchmarks.sh" | tee ../py-polars/benchmark-results + + - name: Cache the Polars build + if: ${{ github.ref == 'refs/heads/main' }} + working-directory: py-polars + run: | + "$HOME/py-polars-cache/add-data.py" "$PWD/polars" < ./benchmark-results + pip install seaborn + "$HOME/py-polars-cache/create-plots.py" + touch "$HOME/py-polars-cache/upload-probe" + "$HOME/py-polars-cache/cache-build.sh" "$PWD/polars" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 000000000000..aa759712ff9b --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,150 @@ +name: Benchmark + +on: + pull_request: + paths: + - crates/** + - Cargo.toml + - py-polars/tests/benchmark/** + - .github/workflows/benchmark.yml + push: + branches: + - main + paths: + - crates/** + - Cargo.toml + - py-polars/tests/benchmark/** + - .github/workflows/benchmark.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + SCCACHE_GHA_ENABLED: 'true' + RUSTC_WRAPPER: sccache + RUST_BACKTRACE: 1 + +jobs: + main: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v2 + + - name: Create virtual environment + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + working-directory: py-polars + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose --index-strategy=unsafe-best-match + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + workspaces: py-polars + save-if: ${{ github.ref_name == 'main' }} + + - name: Run sccache-cache + uses: mozilla-actions/sccache-action@v0.0.9 + + - name: Build Polars release build + working-directory: py-polars + run: maturin build --profile nodebug-release -- -C codegen-units=8 -C lto=thin -C target-cpu=native + + - name: Install Polars release build + run: | + pip install --force-reinstall target/wheels/polars*.whl + + # This workflow builds and installs a wheel, meaning there is no polars.abi3.so under + # py-polars/. This causes a binary not found error if a test tries to import polars in + # a Python executed using `subprocess.check_output([sys.executable])`. Fix this by + # symlinking the binary. + ln -sv \ + $(python -c "import importlib; print(importlib.util.find_spec('polars').submodule_search_locations[0] + '/polars.abi3.so')") \ + py-polars/polars/polars.abi3.so + + - name: Set wheel size + run: | + LIB_SIZE=$(ls -l target/wheels/polars*.whl | awk '{ print $5 }') + echo "LIB_SIZE=$LIB_SIZE" >> $GITHUB_ENV + + - name: Comment wheel size + uses: actions/github-script@v7 + if: github.ref_name != 'main' + with: + script: | + const currentSize = process.env.LIB_SIZE || 'Unknown'; + + // Convert to MB + const currentSizeMB = currentSize !== 'Unknown' ? (currentSize / 1024 / 1024).toFixed(4) : 'Unknown'; + + let commentBody = `The uncompressed lib size after this PR is **${currentSizeMB} MB**.`; + + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + // Look for an existing comment + const existingComment = comments.find(comment => + comment.body.includes('The previous uncompressed lib size was') + ); + + if (existingComment) { + // Update the existing comment + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: existingComment.id, + body: commentBody, + }); + } else { + // Create a new comment + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: commentBody, + }); + } + continue-on-error: true + + # - name: Run benchmark tests + # uses: CodSpeedHQ/action@v3 + # with: + # working-directory: py-polars + # run: pytest -m benchmark --codspeed -v + + - name: Run non-benchmark tests + working-directory: py-polars + run: pytest -m 'not benchmark and not debug' -n auto + env: + POLARS_TIMEOUT_MS: 60000 + + - name: Run non-benchmark tests on new streaming engine + working-directory: py-polars + env: + POLARS_AUTO_NEW_STREAMING: 1 + POLARS_TIMEOUT_MS: 60000 + run: pytest -n auto -m "not may_fail_auto_streaming and not slow and not write_disk and not release and not docs and not hypothesis and not benchmark and not ci_only" diff --git a/.github/workflows/clear-caches.yml b/.github/workflows/clear-caches.yml new file mode 100644 index 000000000000..fc75374b21fb --- /dev/null +++ b/.github/workflows/clear-caches.yml @@ -0,0 +1,19 @@ +# Clearing caches regularly takes care of Rust caches growing to problematic size over time + +name: Clear caches + +on: + schedule: + - cron: '0 4 * * MON' + workflow_dispatch: + +jobs: + clear-caches: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Clear all caches + run: gh cache delete --all + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml new file mode 100644 index 000000000000..1bdfcc57ce7f --- /dev/null +++ b/.github/workflows/docs-global.yml @@ -0,0 +1,111 @@ +name: Build documentation + +on: + pull_request: + paths: + - docs/** + - mkdocs.yml + - .github/workflows/docs-global.yml + repository_dispatch: + types: + - python-release-docs + workflow_dispatch: + +jobs: + markdown-link-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.client_payload.sha }} + - uses: gaurav-nelson/github-action-markdown-link-check@v1 + with: + config-file: docs/mlc-config.json + folder-path: docs + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.client_payload.sha }} + + - name: Get ruff version from requirements file + id: version + run: | + VERSION=$(grep -m 1 -oP 'ruff==\K(.*)' py-polars/requirements-lint.txt) + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - uses: chartboost/ruff-action@v1 + with: + src: docs/source/ + version: ${{ steps.version.outputs.version }} + args: check --no-fix + + - uses: chartboost/ruff-action@v1 + with: + src: docs/source/ + version: ${{ steps.version.outputs.version }} + args: format --diff + + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.client_payload.sha }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Create virtual environment + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + run: uv pip install -r py-polars/requirements-dev.txt -r docs/source/requirements.txt + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + workspaces: py-polars + save-if: ${{ github.ref_name == 'main' }} + + - name: Install Polars + working-directory: py-polars + run: maturin develop + + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v2 + + - name: Build documentation + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: mkdocs build + + - name: Add .nojekyll + if: github.event_name == 'repository_dispatch' || github.event_name == 'workflow_dispatch' + working-directory: site + run: touch .nojekyll + + - name: Deploy docs + if: github.event_name == 'repository_dispatch' || github.event_name == 'workflow_dispatch' + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: site + # docs/ and py-polars/ included for backwards compatibility + clean-exclude: | + api/python/ + api/rust/ + docs/python + docs/rust + py-polars/html + single-commit: true diff --git a/.github/workflows/docs-python.yml b/.github/workflows/docs-python.yml new file mode 100644 index 000000000000..2e90c6c611f5 --- /dev/null +++ b/.github/workflows/docs-python.yml @@ -0,0 +1,119 @@ +name: Build Python documentation + +on: + pull_request: + paths: + - py-polars/docs/** + - py-polars/polars/** + - .github/workflows/docs-python.yml + push: + branches: + - main + paths: + - py-polars/docs/** + - py-polars/polars/** + - .github/workflows/docs-python.yml + repository_dispatch: + types: + - python-release + # To build older versions, specify the commit and the version + workflow_dispatch: + inputs: + polars_version: + description: 'Specify Polars version (e.g., py-0.20.1)' + required: true + git_commit: + description: 'Which commit to build' + required: true + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + +jobs: + build-python-docs: + runs-on: ubuntu-latest + steps: + - name: Parse the tag to find the major version of Polars + id: version + if: github.event_name == 'repository_dispatch' || github.event_name == 'workflow_dispatch' + shell: bash + run: | + tag="${{ github.event.inputs.polars_version || github.event.client_payload.tag }}" + regex="py-([0-9]+)\.[0-9]+\.[0-9]+.*" + if [[ $tag =~ $regex ]]; then + major=${BASH_REMATCH[1]} + minor=${BASH_REMATCH[2]} + + if [[ "$major" == "0" ]]; then + version="$major.$minor" # Keep "0.X" for 0.x.y versions + else + version="$major" # Use only major for 1.x.y+ + fi + + echo "version=$version" >> "$GITHUB_OUTPUT" + else + echo "Error: Invalid version format. Cancelling workflow." + exit 1 + fi + + # Chooses the manually given commit if manually triggered. + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.git_commit || github.event.client_payload.sha }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Create virtual environment + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + working-directory: py-polars/docs + run: uv pip install -r requirements-docs.txt + + - name: Build Python documentation + working-directory: py-polars/docs + env: + POLARS_VERSION: ${{ github.event.inputs.polars_version || github.event.client_payload.tag || 'main' }} + run: make html + + - name: Deploy Python docs for latest release version - versioned + if: github.event_name == 'repository_dispatch' || github.event_name == 'workflow_dispatch' + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: py-polars/docs/build/html + target-folder: api/python/version/${{ steps.version.outputs.version }} + single-commit: true + + - name: Deploy Python docs for latest development version + if: github.event_name == 'push' && github.ref_name == 'main' + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: py-polars/docs/build/html + target-folder: api/python/dev + single-commit: true + + - name: Deploy Python docs for latest release version - stable + if: github.event_name == 'repository_dispatch' && github.event.client_payload.is_prerelease == 'false' || (github.event_name == 'workflow_dispatch' && steps.version.outputs.version == '1') + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: py-polars/docs/build/html + target-folder: api/python/stable + single-commit: true + + # Build global docs _after_ this workflow to avoid contention on the gh-pages branch + - name: Trigger global docs workflow + if: github.event_name == 'repository_dispatch' + uses: peter-evans/repository-dispatch@v3 + with: + event-type: python-release-docs + client-payload: > + { + "sha": "${{ github.event.client_payload.sha }}" + } diff --git a/.github/workflows/docs-rust.yml b/.github/workflows/docs-rust.yml new file mode 100644 index 000000000000..262125b82322 --- /dev/null +++ b/.github/workflows/docs-rust.yml @@ -0,0 +1,55 @@ +name: Build Rust documentation + +on: + pull_request: + paths: + - crates/** + - .github/workflows/docs-rust.yml + push: + branches: + - main + paths: + - crates/** + - .github/workflows/docs-rust.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + +jobs: + build-rust-docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup component add rust-docs + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Build Rust documentation + env: + RUSTDOCFLAGS: --cfg docsrs -D warnings --allow=rustdoc::redundant-explicit-links + working-directory: crates + run: make doctest + + - name: Create redirect to Polars crate and set no-jekyll + if: ${{ github.ref_name == 'main' }} + run: | + echo '' > target/doc/index.html + touch target/doc/.nojekyll + + - name: Deploy Rust docs + if: ${{ github.ref_name == 'main' }} + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: target/doc + target-folder: api/rust/dev + single-commit: true + + # Make sure documentation artifacts are not cached + - name: Clean up documentation artifacts + if: ${{ github.ref_name == 'main' }} + run: rm -rf target/doc diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml new file mode 100644 index 000000000000..a178a4ee9ae1 --- /dev/null +++ b/.github/workflows/lint-global.yml @@ -0,0 +1,18 @@ +name: Lint global + +on: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + main: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Lint Markdown and TOML + uses: dprint/check@v2.2 + - name: Spell Check with Typos + uses: crate-ci/typos@v1.30.0 diff --git a/.github/workflows/lint-python.yml b/.github/workflows/lint-python.yml new file mode 100644 index 000000000000..895a6a85b576 --- /dev/null +++ b/.github/workflows/lint-python.yml @@ -0,0 +1,66 @@ +name: Lint Python + +on: + pull_request: + paths: + - py-polars/** + - .github/workflows/lint-python.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Get ruff version from requirements file + id: version + run: | + VERSION=$(grep -m 1 -oP 'ruff==\K(.*)' py-polars/requirements-lint.txt) + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - uses: chartboost/ruff-action@v1 + with: + src: py-polars/ + version: ${{ steps.version.outputs.version }} + args: check --no-fix + + - uses: chartboost/ruff-action@v1 + with: + src: py-polars/ + version: ${{ steps.version.outputs.version }} + args: format --diff + + mypy: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.9', '3.13'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Create virtual environment + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + working-directory: py-polars + run: uv pip install -r requirements-dev.txt -r requirements-lint.txt + + # Allow untyped calls for older Python versions + - name: Run mypy + working-directory: py-polars + run: mypy ${{ (matrix.python-version == '3.9') && '--allow-untyped-calls' || '' }} diff --git a/.github/workflows/lint-rust.yml b/.github/workflows/lint-rust.yml new file mode 100644 index 000000000000..246fc3cd2e7f --- /dev/null +++ b/.github/workflows/lint-rust.yml @@ -0,0 +1,101 @@ +name: Lint Rust + +on: + pull_request: + paths: + - crates/** + - docs/source/src/rust/** + - examples/** + - py-polars/src/** + - py-polars/Cargo.toml + - Cargo.toml + - .github/workflows/lint-rust.yml + push: + branches: + - main + paths: + - crates/** + - docs/source/src/rust/** + - examples/** + - py-polars/src/** + - py-polars/Cargo.toml + - Cargo.toml + - .github/workflows/lint-rust.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down + +jobs: + clippy-nightly: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup component add clippy + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Run cargo clippy with all features enabled + run: cargo clippy --workspace --all-targets --all-features --locked -- -D warnings -D clippy::dbg_macro + + # Default feature set should compile on the stable toolchain + clippy-stable: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup override set stable && rustup update + + - name: Install clippy + run: rustup component add clippy + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Run cargo clippy + run: cargo clippy --all-targets --locked -- -D warnings -D clippy::dbg_macro + + rustfmt: + if: github.ref_name != 'main' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup component add rustfmt + + - name: Run cargo fmt + run: cargo fmt --all --check + + miri: + if: github.ref_name != 'main' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup component add miri + + - name: Set up miri + run: cargo miri setup + + - name: Run miri + env: + MIRIFLAGS: -Zmiri-disable-isolation -Zmiri-ignore-leaks -Zmiri-disable-stacked-borrows + POLARS_ALLOW_EXTENSION: '1' + run: > + cargo miri test + --features object + -p polars-core +# -p polars-arrow diff --git a/.github/workflows/pr-labeler.yml b/.github/workflows/pr-labeler.yml new file mode 100644 index 000000000000..a0c7e83bfdb3 --- /dev/null +++ b/.github/workflows/pr-labeler.yml @@ -0,0 +1,25 @@ +name: Pull request labeler + +on: + pull_request_target: + types: [opened, edited] + +permissions: + contents: read + pull-requests: write + +jobs: + labeler: + runs-on: ubuntu-latest + steps: + - name: Check pull request title + uses: thehanimo/pr-title-checker@v1.4.2 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Label pull request + uses: release-drafter/release-drafter@v6 + with: + disable-releaser: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml new file mode 100644 index 000000000000..84229ef07920 --- /dev/null +++ b/.github/workflows/release-drafter.yml @@ -0,0 +1,38 @@ +name: Update draft releases + +on: + push: + branches: + - main + workflow_dispatch: + inputs: + # Latest commit to include with the release. If omitted, use the latest commit on the main branch. + sha: + description: Commit SHA + type: string + +permissions: + contents: write + pull-requests: read + +jobs: + main: + runs-on: ubuntu-latest + steps: + - name: Draft Rust release + uses: release-drafter/release-drafter@v6 + with: + config-name: release-drafter-rust.yml + commitish: ${{ inputs.sha || github.sha }} + disable-autolabeler: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Draft Python release + uses: release-drafter/release-drafter@v6 + with: + config-name: release-drafter-python.yml + commitish: ${{ inputs.sha || github.sha }} + disable-autolabeler: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml new file mode 100644 index 000000000000..650e8bbe7732 --- /dev/null +++ b/.github/workflows/release-python.yml @@ -0,0 +1,464 @@ +name: Release Python + +on: + workflow_dispatch: + inputs: + # Latest commit to include with the release. If omitted, use the latest commit on the main branch. + sha: + description: Commit SHA + type: string + # Create the sdist and build the wheels, but do not publish to PyPI / GitHub. + dry-run: + description: Dry run + type: boolean + default: false + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + PYTHON_VERSION: '3.9' + PYTHON_VERSION_WIN_ARM64: '3.11' # ARM64 Windows doesn't have older versions + CARGO_INCREMENTAL: 0 + CARGO_NET_RETRY: 10 + RUSTUP_MAX_RETRIES: 10 + +defaults: + run: + shell: bash + +jobs: + create-sdist: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + package: [polars, polars-lts-cpu, polars-u64-idx] + + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.sha }} + + # Avoid potential out-of-memory errors + - name: Set swap space for Linux + uses: pierotofy/set-swap-space@master + with: + swap-size-gb: 10 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install yq + if: matrix.package != 'polars' + run: pip install yq + - name: Update package name + if: matrix.package != 'polars' + run: tomlq -i -t ".project.name = \"${{ matrix.package }}\"" py-polars/pyproject.toml + - name: Add bigidx feature + if: matrix.package == 'polars-u64-idx' + run: tomlq -i -t '.dependencies.polars.features += ["bigidx"]' py-polars/Cargo.toml + - name: Update optional dependencies + if: matrix.package != 'polars' + run: sed -i 's/polars\[/${{ matrix.package }}\[/' py-polars/pyproject.toml + + - name: Create source distribution + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: > + --manifest-path py-polars/Cargo.toml + --out dist + maturin-version: 1.8.3 + + - name: Test sdist + run: | + pip install --force-reinstall --verbose dist/*.tar.gz + python -c 'import polars' + + - name: Upload sdist + uses: actions/upload-artifact@v4 + with: + name: sdist-${{ matrix.package }} + path: dist/*.tar.gz + + build-wheels: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + package: [polars, polars-lts-cpu, polars-u64-idx] + # macos-13 is x86-64 + # macos-15 is aarch64 + os: [ubuntu-latest, macos-13, macos-15, windows-latest, windows-11-arm] + architecture: [x86-64, aarch64] + exclude: + - os: windows-latest + architecture: aarch64 + - os: windows-11-arm + architecture: x86-64 + - os: macos-15 + architecture: x86-64 + - os: macos-13 + architecture: aarch64 + + env: + SED_INPLACE: ${{ startsWith(matrix.os, 'macos') && '-i ''''' || '-i' }} + CPU_CHECK_MODULE: py-polars/polars/_cpu_check.py + + steps: + - name: Setup build environment (ARM64 Windows) + if: matrix.os == 'windows-11-arm' + shell: + powershell + # Notes + # * We update `Expand-Archive` to avoid "" is not a supported archive file format when extracting + # files that don't end in `.zip` + run: | + Write-Output "> Update Expand-Archive (Microsoft.PowerShell.Archive)" + Install-PackageProvider -Name NuGet -Force + Install-Module -Name Microsoft.PowerShell.Archive -Force + + Write-Output "> Setup bash.exe (git-for-windows/PortableGit)" + Invoke-WebRequest "https://github.com/git-for-windows/git/releases/download/v2.47.1.windows.1/PortableGit-2.47.1-arm64.7z.exe" -OutFile /git.7z.exe + /git.7z.exe -o/git -y + + Write-Output "> Setup Rust" + Invoke-WebRequest "https://static.rust-lang.org/rustup/dist/aarch64-pc-windows-msvc/rustup-init.exe" -OutFile /rustup-init.exe + /rustup-init.exe --default-host aarch64-pc-windows-msvc -y + + Write-Output "> Setup VS Build Tools" + Invoke-WebRequest "https://aka.ms/vs/17/release/vs_BuildTools.exe" -OutFile /vs_BuildTools.exe + Start-Process C:/vs_BuildTools.exe -ArgumentList " ` + --add Microsoft.VisualStudio.Workload.NativeDesktop ` + --add Microsoft.VisualStudio.Workload.VCTools ` + --add Microsoft.VisualStudio.Component.VC.Tools.arm64 ` + --add Microsoft.VisualStudio.Component.VC.Llvm.Clang ` + --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset ` + --includeRecommended --quiet --norestart --wait" -Wait + + Write-Output "> Setup CMake" + Invoke-WebRequest "https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-3.31.2-windows-arm64.zip" -OutFile /cmake.zip + Expand-Archive /cmake.zip -DestinationPath / + + Write-Output "> Download jq.exe (github.com/jqlang) (needed for tomlq / yq)" + Invoke-WebRequest https://github.com/jqlang/jq/releases/download/jq-1.7.1/jq-windows-i386.exe -OutFile /jq.exe + + Write-Output "> Update GITHUB_PATH" + [System.IO.File]::AppendAllText($Env:GITHUB_PATH, "`n" + "C:/git/bin/") + [System.IO.File]::AppendAllText($Env:GITHUB_PATH, "`n" + $Env:USERPROFILE + "/.cargo/bin/") + [System.IO.File]::AppendAllText($Env:GITHUB_PATH, "`n" + "C:/Program Files (x86)/Microsoft Visual Studio/2022/BuildTools/VC/Tools/Llvm/bin/") + [System.IO.File]::AppendAllText($Env:GITHUB_PATH, "`n" + "C:/cmake-3.31.2-windows-arm64/bin") + [System.IO.File]::AppendAllText($Env:GITHUB_PATH, "`n" + "C:/") + [System.IO.File]::AppendAllText($Env:GITHUB_PATH, "`n") + + Get-Content $Env:GITHUB_PATH | Out-Host + + - name: Check build environment (ARM64 Windows) + if: matrix.os == 'windows-11-arm' + run: | + set -x + bash --version + rustup show + clang -v + cmake --version + + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.sha }} + + # Avoid potential out-of-memory errors + - name: Set swap space for Linux + if: matrix.os == 'ubuntu-latest' + uses: pierotofy/set-swap-space@master + with: + swap-size-gb: 10 + + - name: Set up Python + if: matrix.os != 'windows-11-arm' + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Set up Python (ARM64 Windows) + if: matrix.os == 'windows-11-arm' + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION_WIN_ARM64 }} + + # Otherwise can't find `tomlq` after `pip install yq` + - name: Add Python scripts folder to GITHUB_PATH (ARM64 Windows) + if: matrix.os == 'windows-11-arm' + run: | + python -c "import sysconfig; print(sysconfig.get_path('scripts'))" >> $GITHUB_PATH + + - name: Install yq + if: matrix.package != 'polars' + run: pip install yq + + - name: Update package name + if: matrix.package != 'polars' + run: tomlq -i -t ".project.name = \"${{ matrix.package }}\"" py-polars/pyproject.toml + + - name: Add bigidx feature + if: matrix.package == 'polars-u64-idx' + run: tomlq -i -t '.dependencies.polars.features += ["bigidx"]' py-polars/Cargo.toml + + - name: Update optional dependencies + if: matrix.package != 'polars' + run: sed $SED_INPLACE 's/polars\[/${{ matrix.package }}\[/' py-polars/pyproject.toml + + - name: Determine CPU features for x86-64 + id: features + if: matrix.architecture == 'x86-64' + env: + IS_LTS_CPU: ${{ matrix.package == 'polars-lts-cpu' }} + # IMPORTANT: All features enabled here should also be included in py-polars/polars/_cpu_check.py + run: | + if [[ "$IS_LTS_CPU" = true ]]; then + FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b + CC_FEATURES="-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16" + else + TUNE_CPU=skylake + FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt,+pclmulqdq,+movbe + CC_FEATURES="-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 -mavx -mavx2 -mfma -mbmi -mbmi2 -mlzcnt -mpclmul -mmovbe" + fi + echo "features=$FEATURES" >> $GITHUB_OUTPUT + echo "tune_cpu=$TUNE_CPU" >> $GITHUB_OUTPUT + echo "cc_features=$CC_FEATURES" >> $GITHUB_OUTPUT + + - name: Set RUSTFLAGS for x86-64 + if: matrix.architecture == 'x86-64' + env: + FEATURES: ${{ steps.features.outputs.features }} + TUNE_CPU: ${{ steps.features.outputs.tune_cpu }} + CC_FEATURES: ${{ steps.features.outputs.cc_features }} + CFG: ${{ matrix.package == 'polars-lts-cpu' && '--cfg allocator="default"' || '' }} + run: | + if [[ -z "$TUNE_CPU" ]]; then + echo "RUSTFLAGS=-C target-feature=$FEATURES $CFG" >> $GITHUB_ENV + echo "CFLAGS=$CC_FEATURES" >> $GITHUB_ENV + else + echo "RUSTFLAGS=-C target-feature=$FEATURES -Z tune-cpu=$TUNE_CPU $CFG" >> $GITHUB_ENV + echo "CFLAGS=$CC_FEATURES -mtune=$TUNE_CPU" >> $GITHUB_ENV + fi + + - name: Set variables in CPU check module + run: | + sed $SED_INPLACE 's/^_POLARS_ARCH = \"unknown\"$/_POLARS_ARCH = \"${{ matrix.architecture }}\"/g' $CPU_CHECK_MODULE + sed $SED_INPLACE 's/^_POLARS_FEATURE_FLAGS = \"\"$/_POLARS_FEATURE_FLAGS = \"${{ steps.features.outputs.features }}\"/g' $CPU_CHECK_MODULE + - name: Set variables in CPU check module - LTS_CPU + if: matrix.package == 'polars-lts-cpu' + run: | + sed $SED_INPLACE 's/^_POLARS_LTS_CPU = False$/_POLARS_LTS_CPU = True/g' $CPU_CHECK_MODULE + + - name: Set Rust target for aarch64 + if: matrix.architecture == 'aarch64' + id: target + run: | + TARGET=$( + if [[ "${{ matrix.os }}" == "macos-15" ]]; then + echo "aarch64-apple-darwin"; + elif [[ "${{ matrix.os }}" == "windows-11-arm" ]]; then + echo "aarch64-pc-windows-msvc"; + else + echo "aarch64-unknown-linux-gnu"; + fi + ) + echo "target=$TARGET" >> $GITHUB_OUTPUT + + - name: Set jemalloc for aarch64 Linux + if: matrix.architecture == 'aarch64' && matrix.os == 'ubuntu-latest' + run: | + echo "JEMALLOC_SYS_WITH_LG_PAGE=16" >> $GITHUB_ENV + + - name: Copy toolchain to py-polars/ (ARM64 Windows) + # Manual fix for: + # TomlError: Unknown character "46" at row 1, col 2, pos 1: + # 1> ../rust-toolchain.toml + if: matrix.os == 'windows-11-arm' + run: cp rust-toolchain.toml py-polars/ + + - name: Build wheel + uses: PyO3/maturin-action@v1 + with: + command: build + target: ${{ steps.target.outputs.target }} + args: > + --profile dist-release + --manifest-path py-polars/Cargo.toml + --out dist + manylinux: ${{ matrix.architecture == 'aarch64' && '2_24' || 'auto' }} + maturin-version: 1.8.3 + + - name: Test wheel + # For linux; only test on x86-64 for now as this matches the runner architecture + if: matrix.architecture == 'x86-64' || startsWith(matrix.os, 'macos') || startsWith(matrix.os, 'windows') + run: | + pip install --force-reinstall --verbose dist/*.whl + python -c 'import polars; polars.show_versions()' + + - name: Upload wheel + uses: actions/upload-artifact@v4 + with: + name: wheel-${{ matrix.package }}-${{ matrix.os }}-${{ matrix.architecture }} + path: dist/*.whl + + build-wheel-pyodide: + name: build-wheels (polars, pyodide, wasm32) + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.sha }} + + # Avoid potential out-of-memory errors + - name: Set swap space for Linux + uses: pierotofy/set-swap-space@master + with: + swap-size-gb: 10 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Disable incompatible features + env: + FEATURES: parquet|async|json|extract_jsonpath|catalog|cloud|polars_cloud|tokio|clipboard|decompress|new_streaming + run: | + sed -i 's/^ "json",$/ "serde_json",/' crates/polars-python/Cargo.toml + sed -E -i "/^ \"(${FEATURES})\",$/d" crates/polars-python/Cargo.toml py-polars/Cargo.toml + + - name: Setup emsdk + uses: mymindstorm/setup-emsdk@v14 + with: + # This should match the exact version of Emscripten used by Pyodide + # UPDATE; set to latest due to breaking build + version: latest + + - name: Set CFLAGS and RUSTFLAGS for wasm32 + run: | + echo "CFLAGS=-fPIC" >> $GITHUB_ENV + echo "RUSTFLAGS=-C link-self-contained=no" >> $GITHUB_ENV + + - name: Build wheel + uses: PyO3/maturin-action@v1 + with: + command: build + target: wasm32-unknown-emscripten + args: > + --profile dist-release + --manifest-path py-polars/Cargo.toml + --interpreter python3.10 + --out wasm-dist + maturin-version: 1.8.3 + + - name: Upload wheel + uses: actions/upload-artifact@v4 + with: + name: wheel-polars-emscripten-wasm32 + path: wasm-dist/*.whl + + publish-to-pypi: + needs: [create-sdist, build-wheels, build-wheel-pyodide] + environment: + name: release-python + url: https://pypi.org/project/polars + runs-on: ubuntu-latest + permissions: + id-token: write + + steps: + - name: Download sdists and wheels + uses: actions/download-artifact@v4 + with: + path: dist + merge-multiple: true + + - name: Remove Emscripten wheel + run: rm -f dist/*emscripten*.whl + + - name: Publish to PyPI + if: inputs.dry-run == false + uses: pypa/gh-action-pypi-publish@release/v1 + with: + verbose: true + + publish-to-github: + needs: [publish-to-pypi, build-wheel-pyodide] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.sha }} + + - name: Download sdist + uses: actions/download-artifact@v4 + with: + name: sdist-polars + path: dist + + - name: Download Pyodide wheel + uses: actions/download-artifact@v4 + with: + name: wheel-polars-emscripten-wasm32 + path: wasm-dist + + - name: Get version from Cargo.toml + id: version + working-directory: py-polars + run: | + VERSION=$(grep -m 1 -oP 'version = "\K[^"]+' Cargo.toml) + if [[ "$VERSION" == *"-"* ]]; then + IS_PRERELEASE=true + else + IS_PRERELEASE=false + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "is_prerelease=$IS_PRERELEASE" >> $GITHUB_OUTPUT + + - name: Create GitHub release + id: github-release + uses: release-drafter/release-drafter@v6 + with: + config-name: release-drafter-python.yml + name: Python Polars ${{ steps.version.outputs.version }} + tag: py-${{ steps.version.outputs.version }} + version: ${{ steps.version.outputs.version }} + prerelease: ${{ steps.version.outputs.is_prerelease }} + commitish: ${{ inputs.sha || github.sha }} + disable-autolabeler: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload sdist to GitHub release + run: gh release upload $TAG $FILES --clobber + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG: ${{ steps.github-release.outputs.tag_name }} + FILES: dist/polars-*.tar.gz wasm-dist/polars-*.whl + + - name: Publish GitHub release + if: inputs.dry-run == false + run: gh release edit $TAG --draft=false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG: ${{ steps.github-release.outputs.tag_name }} + + - name: Trigger other workflows related to the release + if: inputs.dry-run == false + uses: peter-evans/repository-dispatch@v3 + with: + event-type: python-release + client-payload: > + { + "version": "${{ steps.version.outputs.version }}", + "is_prerelease": "${{ steps.version.outputs.is_prerelease }}", + "tag": "${{ steps.github-release.outputs.tag_name }}", + "sha": "${{ inputs.sha || github.sha }}" + } diff --git a/.github/workflows/release-rust.yml b/.github/workflows/release-rust.yml new file mode 100644 index 000000000000..ad7be2155053 --- /dev/null +++ b/.github/workflows/release-rust.yml @@ -0,0 +1,14 @@ +name: Release Rust + +on: + push: + tags: + - rs-* + +# TODO: Implement +jobs: + release-rust: + if: false + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 diff --git a/.github/workflows/test-bytecode-parser.yml b/.github/workflows/test-bytecode-parser.yml new file mode 100644 index 000000000000..d42e598124a6 --- /dev/null +++ b/.github/workflows/test-bytecode-parser.yml @@ -0,0 +1,36 @@ +name: Test Bytecode Parser + +on: + pull_request: + paths: + - py-polars/** + - .github/workflows/test-bytecode-parser.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + ubuntu: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + # Only the versions that are not already run as part of the regular test suite + python-version: ['3.10', '3.11'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: pip install ipython numpy pytest + + - name: Run tests + env: + PYTHONPATH: py-polars + run: pytest .github/scripts/test_bytecode_parser.py diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml new file mode 100644 index 000000000000..06f3d025ffb6 --- /dev/null +++ b/.github/workflows/test-coverage.yml @@ -0,0 +1,187 @@ +name: Code coverage + +on: + pull_request: + paths: + - '**.rs' + - '**.py' + - .github/workflows/test-coverage.yml + push: + branches: + - main + paths: + - '**.rs' + - '**.py' + - .github/workflows/test-coverage.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash + +env: + RUSTFLAGS: '-C instrument-coverage --cfg=coverage --cfg=coverage_nightly --cfg=trybuild_no_target' + RUST_BACKTRACE: 1 + LLVM_PROFILE_FILE: ${{ github.workspace }}/target/polars-%p-%3m.profraw + CARGO_LLVM_COV: 1 + CARGO_LLVM_COV_SHOW_ENV: 1 + CARGO_LLVM_COV_TARGET_DIR: ${{ github.workspace }}/target + # We use the stable ABI, silences error from PyO3 that the system Python is too new. + PYO3_USE_ABI3_FORWARD_COMPATIBILITY: 1 + +jobs: + coverage-rust: + # Running under ubuntu doesn't seem to work: + # https://github.com/pola-rs/polars/issues/14255 + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup component add llvm-tools-preview + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Prepare coverage + run: cargo llvm-cov clean --workspace + + - name: Run tests + run: > + cargo test --all-features + -p polars-arrow + -p polars-compute + -p polars-core + -p polars-io + -p polars-lazy + -p polars-ops + -p polars-parquet + -p polars-plan + -p polars-row + -p polars-sql + -p polars-time + -p polars-utils + + - name: Run integration tests + run: cargo test --all-features -p polars --test it + + - name: Report coverage + run: cargo llvm-cov report --lcov --output-path coverage-rust.lcov + + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: coverage-rust + path: coverage-rust.lcov + + coverage-python: + # Running under ubuntu doesn't seem to work: + # https://github.com/pola-rs/polars/issues/14255 + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + check-latest: true + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + + - name: Create virtual environment + run: | + uv venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + working-directory: py-polars + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose + + - name: Set up Rust + run: rustup component add llvm-tools-preview + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Prepare coverage + run: cargo llvm-cov clean --workspace + + - name: Install Polars + run: maturin develop -m py-polars/Cargo.toml + + - name: Run Python tests + working-directory: py-polars + env: + POLARS_TIMEOUT_MS: 60000 + run: > + pytest + -n auto + -m "not release and not benchmark and not docs" + -k 'not test_polars_import' + --cov --cov-report xml:main.xml + + - name: Run Python tests - async reader + working-directory: py-polars + env: + POLARS_FORCE_ASYNC: 1 + POLARS_TIMEOUT_MS: 60000 + run: > + pytest tests/unit/io/ + -n auto + -m "not release and not benchmark and not docs" + --cov --cov-report xml:async.xml --cov-fail-under=0 + + - name: Report Rust coverage + run: cargo llvm-cov report --lcov --output-path coverage-python.lcov + + - name: Upload coverage reports + uses: actions/upload-artifact@v4 + with: + name: coverage-python + path: | + coverage-python.lcov + py-polars/main.xml + py-polars/async.xml + + upload-coverage: + needs: [coverage-rust, coverage-python] + runs-on: ubuntu-latest + + steps: + # Needed to fetch the Codecov config file + - uses: actions/checkout@v4 + + - name: Download coverage reports + uses: actions/download-artifact@v4 + with: + merge-multiple: true + + - name: Upload coverage reports + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage-rust.lcov,coverage-python.lcov,py-polars/main.xml,py-polars/async.xml + root_dir: ${{ github.workspace }} diff --git a/.github/workflows/test-pyodide.yml b/.github/workflows/test-pyodide.yml new file mode 100644 index 000000000000..cecf79153ec7 --- /dev/null +++ b/.github/workflows/test-pyodide.yml @@ -0,0 +1,71 @@ +name: Test Pyodide + +on: + pull_request: + paths: + - Cargo.lock + - py-polars/** + - docs/source/src/python/** + - crates/** + - .github/workflows/test-pyodide.yml + push: + branches: + - main + paths: + - Cargo.lock + - crates/** + - docs/source/src/python/** + - py-polars/** + - .github/workflows/test-pyodide.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + RUSTFLAGS: -C link-self-contained=no -C debuginfo=0 + CFLAGS: -fPIC + +defaults: + run: + shell: bash + +jobs: + test-pyodide: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Disable incompatible features + env: + # Note: If you change this line, you must also copy the change to `release-python.yml`. + FEATURES: parquet|async|json|extract_jsonpath|catalog|cloud|polars_cloud|tokio|clipboard|decompress|new_streaming + run: | + sed -i 's/^ "json",$/ "serde_json",/' crates/polars-python/Cargo.toml + sed -E -i "/^ \"(${FEATURES})\",$/d" crates/polars-python/Cargo.toml py-polars/Cargo.toml + + - name: Setup emsdk + uses: mymindstorm/setup-emsdk@v14 + with: + # This should match the exact version of Emscripten used by Pyodide + version: 3.1.58 + + - name: Install LLVM + # This should match the major version of LLVM expected by Emscripten + run: | + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh 19 + echo "EM_LLVM_ROOT=/usr/lib/llvm-19/bin" >> $GITHUB_ENV + + - name: Build wheel + uses: PyO3/maturin-action@v1 + with: + command: build + target: wasm32-unknown-emscripten + args: > + --profile dev + --manifest-path py-polars/Cargo.toml + --interpreter python3.10 + maturin-version: 1.7.4 diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml new file mode 100644 index 000000000000..fbb9508802c6 --- /dev/null +++ b/.github/workflows/test-python.yml @@ -0,0 +1,130 @@ +name: Test Python + +on: + pull_request: + paths: + - Cargo.lock + - py-polars/** + - docs/source/src/python/** + - crates/** + - .github/workflows/test-python.yml + push: + branches: + - main + paths: + - Cargo.lock + - crates/** + - docs/source/src/python/** + - py-polars/** + - .github/workflows/test-python.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down + RUST_BACKTRACE: 1 + PYTHONUTF8: 1 + +defaults: + run: + working-directory: py-polars + shell: bash + +jobs: + test-python: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ['3.9', '3.12', '3.13'] + include: + - os: windows-latest + python-version: '3.13' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v2 + + - name: Create virtual environment + env: + BIN: ${{ matrix.os == 'windows-latest' && 'Scripts' || 'bin' }} + run: | + python -m venv .venv + echo "$GITHUB_WORKSPACE/py-polars/.venv/$BIN" >> $GITHUB_PATH + echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/py-polars/.venv" >> $GITHUB_ENV + + - name: Install Python dependencies + run: | + pip install uv + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose --index-strategy=unsafe-best-match + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + workspaces: py-polars + save-if: ${{ github.ref_name == 'main' }} + + - name: Install Polars + run: maturin develop + + - name: Run doctests + if: github.ref_name != 'main' && matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' + run: | + python tests/docs/run_doctest.py + pytest tests/docs/test_user_guide.py -m docs + + - name: Run tests + if: github.ref_name != 'main' + env: + POLARS_TIMEOUT_MS: 60000 + run: pytest -n auto -m "not release and not benchmark and not docs" + + - name: Run tests with new streaming engine + if: github.ref_name != 'main' + env: + POLARS_AUTO_NEW_STREAMING: 1 + POLARS_TIMEOUT_MS: 60000 + run: pytest -n auto -m "not may_fail_auto_streaming and not slow and not write_disk and not release and not docs and not hypothesis and not benchmark and not ci_only" + + - name: Run tests async reader tests + if: github.ref_name != 'main' && matrix.os != 'windows-latest' + env: + POLARS_FORCE_ASYNC: 1 + POLARS_TIMEOUT_MS: 60000 + run: pytest -n auto -m "not release and not benchmark and not docs" tests/unit/io/ + + - name: Check import without optional dependencies + if: github.ref_name != 'main' && matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' + run: | + declare -a deps=("pandas" + "pyarrow" + "fsspec" + "matplotlib" + "backports.zoneinfo" + "connectorx" + "pyiceberg" + "deltalake" + "xlsx2csv" + ) + for d in "${deps[@]}" + do + echo "uninstall $i and check imports..." + pip uninstall "$d" -y + python -c 'import polars' + done diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml new file mode 100644 index 000000000000..4e54ca0cf8e9 --- /dev/null +++ b/.github/workflows/test-rust.yml @@ -0,0 +1,147 @@ +name: Test Rust + +on: + pull_request: + paths: + - crates/** + - examples/** + - Cargo.toml + - .github/workflows/test-rust.yml + push: + branches: + - main + paths: + - crates/** + - examples/** + - Cargo.toml + - .github/workflows/test-rust.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down + RUST_BACKTRACE: 1 + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Compile tests + run: > + cargo test --all-features --no-run + -p polars-arrow + -p polars-compute + -p polars-core + -p polars-io + -p polars-lazy + -p polars-ops + -p polars-parquet + -p polars-plan + -p polars-row + -p polars-sql + -p polars-time + -p polars-utils + + - name: Run tests + if: github.ref_name != 'main' + run: > + cargo test --all-features + -p polars-arrow + -p polars-compute + -p polars-core + -p polars-io + -p polars-lazy + -p polars-ops + -p polars-parquet + -p polars-plan + -p polars-row + -p polars-sql + -p polars-time + -p polars-utils + + integration-test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: true + matrix: + os: [ubuntu-latest, windows-latest] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Compile integration tests + run: cargo test --all-features -p polars --test it --no-run + + - name: Run integration tests + if: github.ref_name != 'main' + run: cargo test --all-features -p polars --test it + + check-features: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Install cargo hack + uses: taiki-e/install-action@v2 + with: + tool: cargo-hack + + - name: Run cargo hack + run: cargo hack check -p polars --each-feature --no-dev-deps + + check-wasm: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Rust + run: | + rustup target add wasm32-unknown-unknown + rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Install cargo hack + uses: taiki-e/install-action@v2 + with: + tool: cargo-hack + + - name: Check wasm + working-directory: crates + run: make check-wasm diff --git a/.gitignore b/.gitignore index 8f2926f377f5..a055cde851a6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,44 @@ +*.iml +*.so +*.pyd +*.pdb +*.ipynb +.ENV +.env +.ipynb_checkpoints/ +.python-version +.yarn/ +coverage.lcov +coverage.xml +profile.json +polars/vendor + +# OS +.DS_Store + +# IDE .idea/ -target/ -Cargo.lock -data/ +.vscode/ +.vim + +# Python +.hypothesis/ +.mypy_cache/ +.pytest_cache/ +.ruff_cache/ +.venv*/ __pycache__/ -.ipynb_checkpoints/ +.coverage + +# Rust +target/ + +# Data +*.csv +*.parquet +*.feather +*.tbl + +# Project +/docs/assets/data/ +/docs/assets/people.md diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 6f128601469f..000000000000 --- a/.travis.yml +++ /dev/null @@ -1,15 +0,0 @@ -language: rust -rust: - - nightly - -script: - - cd polars && cargo test - -after_success: | - cd .. && \ - cargo doc --no-deps --all-features --package polars \ - && echo '' > target/doc/index.html && \ - sudo pip install ghp-import && \ - ghp-import -n target/doc && \ - git push -qf https://${GITHUB_TOKEN}@github.com/${TRAVIS_REPO_SLUG}.git gh-pages - diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 000000000000..fd3a53fbe6b1 --- /dev/null +++ b/.typos.toml @@ -0,0 +1,37 @@ +[files] +extend-exclude = [ + ".git/", + "*.csv", + "*.gz", + "dists.dss", + "**/images/*", + "**/py-polars/polars/_utils/nest_asyncio.py", +] +ignore-hidden = false + +[default] +extend-ignore-re = [ + '"Theatre": \[.+\],', +] + +[default.extend-words] +arange = "arange" +ser = "ser" +splitted = "splitted" +strat = "strat" + +[type.py.extend-identifiers] +ba = "ba" +ody = "ody" + +[type.rust.extend-identifiers] +ANDed = "ANDed" +bck = "bck" +Fo = "Fo" +ND = "ND" +nd = "nd" +NDJson = "NDJson" +NDJsonSinkNode = "NDJsonSinkNode" +NDJsonReadOptions = "NDJsonReadOptions" +NDJsonFileReader = "NDJsonFileReader" +opt_nd = "opt_nd" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000000..175df57e8064 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,9 @@ +# Contributing to Polars + +Thanks for taking the time to contribute! We appreciate all contributions, from reporting bugs to +implementing new features. + +Please refer to the [contributing section](https://docs.pola.rs/development/contributing/) of our +documentation to get started. + +We look forward to your contributions! diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000000..513bcbf02086 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,5843 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom 0.2.15", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anyhow" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" + +[[package]] +name = "apache-avro" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aef82843a0ec9f8b19567445ad2421ceeb1d711514384bdd3d49fe37102ee13" +dependencies = [ + "bigdecimal", + "crc32fast", + "digest", + "libflate 2.1.0", + "log", + "num-bigint", + "quad-rand", + "rand", + "regex-lite", + "serde", + "serde_bytes", + "serde_json", + "snap", + "strum", + "strum_macros", + "thiserror 1.0.69", + "typed-builder", + "uuid", +] + +[[package]] +name = "arboard" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df099ccb16cd014ff054ac1bf392c67feeef57164b05c42f037cd40f5d4357f4" +dependencies = [ + "clipboard-win", + "log", + "objc2", + "objc2-app-kit", + "objc2-foundation", + "parking_lot", + "x11rb", +] + +[[package]] +name = "argminmax" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "arrow2" +version = "0.17.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59c468daea140b747d781a1da9f7db5f0a8e6636d4af20cc539e43d05b0604fa" +dependencies = [ + "ahash", + "bytemuck", + "chrono", + "dyn-clone", + "either", + "ethnum", + "foreign_vec", + "getrandom 0.2.15", + "hash_hasher", + "num-traits", + "rustc_version", + "simdutf8", +] + +[[package]] +name = "async-channel" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" +dependencies = [ + "concurrent-queue", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "async-trait" +version = "0.1.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d556ec1359574147ec0c4fc5eb525f3f23263a592b1a9c07e0a75b427de55c97" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "atoi_simd" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4790f9e8961209112beb783d85449b508673cf4a6a419c8449b210743ac4dbe9" + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "avro-schema" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5281855b39aba9684d2f47bf96983fbfd8f1725f12fabb0513a8ab879647bbd" +dependencies = [ + "async-stream", + "crc 2.1.0", + "fallible-streaming-iterator", + "futures", + "libflate 1.4.0", + "serde", + "serde_json", + "snap", +] + +[[package]] +name = "aws-config" +version = "1.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490aa7465ee685b2ced076bb87ef654a47724a7844e2c7d3af4e749ce5b875dd" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 0.2.12", + "ring", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + +[[package]] +name = "aws-runtime" +version = "1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76dd04d39cc12844c0994f2c9c5a6f5184c22e9188ec1ff723de41910a21dcad" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-s3" +version = "1.77.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34e87342432a3de0e94e82c99a7cbd9042f99de029ae1f4e368160f9e9929264" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-checksums 0.63.0", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "bytes", + "fastrand", + "hex", + "hmac", + "http 0.2.12", + "http-body 0.4.6", + "lru", + "once_cell", + "percent-encoding", + "regex-lite", + "sha2", + "tracing", + "url", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60186fab60b24376d3e33b9ff0a43485f99efd470e3b75a9160c849741d63d56" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7033130ce1ee13e6018905b7b976c915963755aef299c1521897679d6cd4f8ef" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5c1cac7677179d622b4448b0d31bcb359185295dc6fca891920cfb17e2b5156" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" +dependencies = [ + "aws-credential-types", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "crypto-bigint 0.5.5", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.2.0", + "once_cell", + "p256", + "percent-encoding", + "ring", + "sha2", + "subtle", + "time", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa59d1327d8b5053c54bf2eaae63bf629ba9e904434d0835a28ed3c0ed0a614e" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-checksums" +version = "0.60.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba1a71073fca26775c8b5189175ea8863afb1c9ea2cceb02a5de5ad9dfbaa795" +dependencies = [ + "aws-smithy-http", + "aws-smithy-types", + "bytes", + "crc32c", + "crc32fast", + "hex", + "http 0.2.12", + "http-body 0.4.6", + "md-5", + "pin-project-lite", + "sha1", + "sha2", + "tracing", +] + +[[package]] +name = "aws-smithy-checksums" +version = "0.63.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2dc8d842d872529355c72632de49ef8c5a2949a4472f10e802f28cf925770c" +dependencies = [ + "aws-smithy-http", + "aws-smithy-types", + "bytes", + "crc32c", + "crc32fast", + "crc64fast-nvme", + "hex", + "http 0.2.12", + "http-body 0.4.6", + "md-5", + "pin-project-lite", + "sha1", + "sha2", + "tracing", +] + +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "461e5e02f9864cba17cff30f007c2e37ade94d01e87cdb5204e44a84e6d38c17" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + +[[package]] +name = "aws-smithy-http" +version = "0.60.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7809c27ad8da6a6a68c454e651d4962479e81472aa19ae99e59f9aba1f9713cc" +dependencies = [ + "aws-smithy-eventstream", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "623a51127f24c30776c8b374295f2df78d92517386f77ba30773f15a30ce1422" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-query" +version = "0.60.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d526a12d9ed61fadefda24abe2e682892ba288c2018bcb38b1b4c111d13f6d92" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "http-body 1.0.1", + "httparse", + "hyper 0.14.32", + "hyper-rustls 0.24.2", + "once_cell", + "pin-project-lite", + "pin-utils", + "rustls 0.21.12", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.2.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7b8a53819e42f10d0821f56da995e1470b199686a1809168db6ca485665f042" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.2.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbd0a668309ec1f66c0f6bda4840dd6d4796ae26d699ebc266d7cc95c6d040f" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets 0.52.6", +] + +[[package]] +name = "base16ct" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bigdecimal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f31f3af01c5c65a07985c804d3366560e6fa7883d640a122819b14ec327482c" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bitflags" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +dependencies = [ + "serde", +] + +[[package]] +name = "blake3" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675f87afced0413c9bb02843499dbbd3882a237645883f71a2b59644a6d2f753" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "block2" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f" +dependencies = [ + "objc2", +] + +[[package]] +name = "brotli" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bumpalo" +version = "3.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" + +[[package]] +name = "bytemuck" +version = "1.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +dependencies = [ + "serde", +] + +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + +[[package]] +name = "casey" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e779867f62d81627d1438e0d3fb6ed7d7c9d64293ca6d87a1e88781b94ece1c" +dependencies = [ + "syn 2.0.99", +] + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "castaway" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5" +dependencies = [ + "rustversion", +] + +[[package]] +name = "cc" +version = "1.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" +dependencies = [ + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "chrono" +version = "0.4.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-link", +] + +[[package]] +name = "chrono-tz" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c6ac4f2c0bf0f44e9161aec9675e1050aa4a530663c4a9e37e108fa948bca9f" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" +dependencies = [ + "parse-zoneinfo", + "phf_codegen", +] + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "clipboard-win" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15efe7a882b08f34e38556b14f2fb3daa98769d06c7f0c1b076dfd0d983bc892" +dependencies = [ + "error-code", +] + +[[package]] +name = "comfy-table" +version = "7.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a65ebfec4fb190b6f90e944a817d60499ee0744e582530e2c9900a22e591d9a" +dependencies = [ + "crossterm", + "unicode-segmentation", + "unicode-width", +] + +[[package]] +name = "compact_str" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core2" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fc9a695bca7f35f5f4c15cddc84415f66a74ea78eef08e90c5024f2b540e23" +dependencies = [ + "crc-catalog 1.1.1", +] + +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog 2.4.0", +] + +[[package]] +name = "crc-catalog" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + +[[package]] +name = "crc32c" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +dependencies = [ + "rustc_version", +] + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crc64fast-nvme" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4955638f00a809894c947f85a024020a20815b65a5eea633798ea7924edab2b3" +dependencies = [ + "crc 3.2.1", +] + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crossterm" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" +dependencies = [ + "bitflags", + "crossterm_winapi", + "parking_lot", + "rustix", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "crunchy" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" + +[[package]] +name = "crypto-bigint" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef2b4b23cddf68b89b8f8069890e8c270d54e2d5fe1b143820234805e4cb17ef" +dependencies = [ + "generic-array", + "rand_core", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "rand_core", + "subtle", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "dary_heap" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" + +[[package]] +name = "der" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de" +dependencies = [ + "const-oid", + "zeroize", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + +[[package]] +name = "dyn-clone" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" + +[[package]] +name = "ecdsa" +version = "0.14.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413301934810f597c1d19ca71c8710e99a3f1ba28a0d2ebc01551a2daeea3c5c" +dependencies = [ + "der", + "elliptic-curve", + "rfc6979", + "signature", +] + +[[package]] +name = "either" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" +dependencies = [ + "serde", +] + +[[package]] +name = "elliptic-curve" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7bb888ab5300a19b8e5bceef25ac745ad065f3c9f7efc6de1b91958110891d3" +dependencies = [ + "base16ct", + "crypto-bigint 0.4.9", + "der", + "digest", + "ff", + "generic-array", + "group", + "pkcs8", + "rand_core", + "sec1", + "subtle", + "zeroize", +] + +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "env_logger" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "error-code" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" + +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + +[[package]] +name = "event-listener" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "fast-float2" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "ff" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" +dependencies = [ + "rand_core", + "subtle", +] + +[[package]] +name = "flate2" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +dependencies = [ + "crc32fast", + "libz-rs-sys", + "miniz_oxide", +] + +[[package]] +name = "float-cmp" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" +dependencies = [ + "num-traits", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fs4" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be058769cf1633370c3d0dac6bb9b223b8f18900cf808abadf7843192e706238" +dependencies = [ + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "gethostname" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0176e0459c2e4a1fe232f984bca6890e681076abb9934f6cea7c326f3fc47818" +dependencies = [ + "libc", + "windows-targets 0.48.5", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + +[[package]] +name = "group" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" +dependencies = [ + "ff", + "rand_core", + "subtle", +] + +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.2.0", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "halfbrown" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" +dependencies = [ + "hashbrown 0.14.5", + "serde", +] + +[[package]] +name = "hash_hasher" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74721d007512d0cb3338cd20f0654ac913920061a4c4d0d8708edb3f2a698c0c" + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", + "rayon", + "serde", +] + +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", + "rayon", + "serde", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.2.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.8", + "http 1.2.0", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +dependencies = [ + "futures-util", + "http 1.2.0", + "hyper 1.6.0", + "hyper-util", + "rustls 0.23.23", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.2", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.2.0", + "http-body 1.0.1", + "hyper 1.6.0", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core 0.52.0", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "idna" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +dependencies = [ + "equivalent", + "hashbrown 0.15.2", + "serde", +] + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "inventory" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83" +dependencies = [ + "rustversion", +] + +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "is-terminal" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "jsonpath_lib_polars_vendor" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4bd9354947622f7471ff713eacaabdb683ccb13bba4edccaab9860abf480b7d" +dependencies = [ + "log", + "serde", + "serde_json", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.170" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" + +[[package]] +name = "libflate" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ff4ae71b685bbad2f2f391fe74f6b7659a34871c08b210fdc039e43bee07d18" +dependencies = [ + "adler32", + "crc32fast", + "libflate_lz77 1.2.0", +] + +[[package]] +name = "libflate" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77 2.1.0", +] + +[[package]] +name = "libflate_lz77" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a52d3a8bfc85f250440e4424db7d857e241a3aebbbe301f3eb606ab15c39acbf" +dependencies = [ + "rle-decode-fast", +] + +[[package]] +name = "libflate_lz77" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" +dependencies = [ + "core2", + "hashbrown 0.14.5", + "rle-decode-fast", +] + +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + +[[package]] +name = "libm" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" + +[[package]] +name = "libmimalloc-sys" +version = "0.1.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "libz-rs-sys" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d" +dependencies = [ + "zlib-rs", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "litemap" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" + +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown 0.15.2", +] + +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "lz4_flex" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +dependencies = [ + "twox-hash", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "memmap2" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "mimalloc" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" +dependencies = [ + "libmimalloc-sys", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "miniz_oxide" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +dependencies = [ + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.52.0", +] + +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", + "serde", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "numpy" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7cfbf3f0feededcaa4d289fe3079b03659e85c5b5a177f4ba6fb01ab4fb3e39" +dependencies = [ + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash", +] + +[[package]] +name = "objc-sys" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb91bdd390c7ce1a8607f35f3ca7151b65afc0ff5ff3b34fa350f7d7c7e4310" + +[[package]] +name = "objc2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46a785d4eeff09c14c487497c162e92766fbb3e4059a71840cecc03d9a50b804" +dependencies = [ + "objc-sys", + "objc2-encode", +] + +[[package]] +name = "objc2-app-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" +dependencies = [ + "bitflags", + "block2", + "libc", + "objc2", + "objc2-core-data", + "objc2-core-image", + "objc2-foundation", + "objc2-quartz-core", +] + +[[package]] +name = "objc2-core-data" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" +dependencies = [ + "bitflags", + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-image" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" +dependencies = [ + "bitflags", + "block2", + "libc", + "objc2", +] + +[[package]] +name = "objc2-metal" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" +dependencies = [ + "bitflags", + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-quartz-core" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" +dependencies = [ + "bitflags", + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + +[[package]] +name = "object_store" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cfccb68961a56facde1163f9319e0d15743352344e7808a11795fb99698dcaf" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "chrono", + "futures", + "httparse", + "humantime", + "hyper 1.6.0", + "itertools 0.13.0", + "md-5", + "parking_lot", + "percent-encoding", + "quick-xml", + "rand", + "reqwest", + "ring", + "rustls-pemfile 2.2.0", + "serde", + "serde_json", + "snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + +[[package]] +name = "once_cell" +version = "1.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" + +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + +[[package]] +name = "openssl" +version = "0.10.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + +[[package]] +name = "p256" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" +dependencies = [ + "ecdsa", + "elliptic-curve", + "sha2", +] + +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", +] + +[[package]] +name = "parse-zoneinfo" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" +dependencies = [ + "regex", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkcs8" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" +dependencies = [ + "der", + "spki", +] + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "polars" +version = "0.46.0" +dependencies = [ + "apache-avro", + "avro-schema", + "chrono", + "either", + "ethnum", + "futures", + "getrandom 0.2.15", + "polars-arrow", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-plan", + "polars-sql", + "polars-time", + "polars-utils", + "proptest", + "rand", + "tokio", + "tokio-util", + "version_check", +] + +[[package]] +name = "polars-arrow" +version = "0.46.0" +dependencies = [ + "async-stream", + "atoi_simd", + "avro-schema", + "bytemuck", + "chrono", + "chrono-tz", + "criterion", + "crossbeam-channel", + "doc-comment", + "dyn-clone", + "either", + "ethnum", + "fast-float2", + "flate2", + "futures", + "getrandom 0.2.15", + "hashbrown 0.15.2", + "hex", + "indexmap", + "itoa", + "lz4", + "num-traits", + "parking_lot", + "polars-arrow-format", + "polars-error", + "polars-schema", + "polars-utils", + "proptest", + "rand", + "regex", + "regex-syntax 0.8.5", + "ryu", + "sample-arrow2", + "sample-std", + "sample-test", + "serde", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "strum_macros", + "tokio", + "tokio-util", + "version_check", + "zstd", +] + +[[package]] +name = "polars-arrow-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b0ef2474af9396b19025b189d96e992311e6a47f90c53cd998b36c4c64b84c" +dependencies = [ + "planus", + "prost", + "prost-derive", + "serde", +] + +[[package]] +name = "polars-compute" +version = "0.46.0" +dependencies = [ + "atoi_simd", + "bytemuck", + "chrono", + "either", + "fast-float2", + "hashbrown 0.15.2", + "itoa", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "rand", + "ryu", + "serde", + "skiplist", + "strength_reduce", + "strum_macros", + "version_check", +] + +[[package]] +name = "polars-core" +version = "0.46.0" +dependencies = [ + "bincode", + "bitflags", + "bytemuck", + "chrono", + "chrono-tz", + "comfy-table", + "either", + "hashbrown 0.14.5", + "hashbrown 0.15.2", + "indexmap", + "itoa", + "ndarray", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-schema", + "polars-utils", + "rand", + "rand_distr", + "rayon", + "regex", + "serde", + "serde_json", + "strum_macros", + "version_check", + "xxhash-rust", +] + +[[package]] +name = "polars-doc-examples" +version = "0.46.0" +dependencies = [ + "aws-config", + "aws-sdk-s3", + "aws-smithy-checksums 0.60.13", + "chrono", + "polars", + "rand", + "reqwest", + "tokio", +] + +[[package]] +name = "polars-dylib" +version = "0.46.0" +dependencies = [ + "polars", + "polars-arrow", + "polars-core", + "polars-expr", + "polars-lazy", + "polars-mem-engine", + "polars-plan", + "polars-python", +] + +[[package]] +name = "polars-error" +version = "0.46.0" +dependencies = [ + "avro-schema", + "object_store", + "parking_lot", + "polars-arrow-format", + "regex", + "signal-hook", + "simdutf8", +] + +[[package]] +name = "polars-expr" +version = "0.46.0" +dependencies = [ + "bitflags", + "hashbrown 0.15.2", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-io", + "polars-json", + "polars-ops", + "polars-plan", + "polars-row", + "polars-time", + "polars-utils", + "rand", + "rayon", + "recursive", +] + +[[package]] +name = "polars-ffi" +version = "0.46.0" +dependencies = [ + "polars-arrow", + "polars-core", +] + +[[package]] +name = "polars-io" +version = "0.46.0" +dependencies = [ + "async-trait", + "atoi_simd", + "blake3", + "bytes", + "chrono", + "chrono-tz", + "fast-float2", + "flate2", + "fs4", + "futures", + "glob", + "hashbrown 0.15.2", + "home", + "itoa", + "memchr", + "memmap2", + "num-traits", + "object_store", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-error", + "polars-json", + "polars-parquet", + "polars-schema", + "polars-time", + "polars-utils", + "pyo3", + "rayon", + "regex", + "reqwest", + "ryu", + "serde", + "serde_json", + "simd-json", + "simdutf8", + "strum", + "strum_macros", + "tempfile", + "tokio", + "tokio-util", + "url", + "zstd", +] + +[[package]] +name = "polars-json" +version = "0.46.0" +dependencies = [ + "chrono", + "chrono-tz", + "fallible-streaming-iterator", + "hashbrown 0.15.2", + "indexmap", + "itoa", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", + "ryu", + "simd-json", + "streaming-iterator", +] + +[[package]] +name = "polars-lazy" +version = "0.46.0" +dependencies = [ + "bitflags", + "chrono", + "either", + "futures", + "memchr", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-json", + "polars-mem-engine", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-stream", + "polars-time", + "polars-utils", + "pyo3", + "rayon", + "serde_json", + "tokio", + "version_check", +] + +[[package]] +name = "polars-mem-engine" +version = "0.46.0" +dependencies = [ + "futures", + "memmap2", + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-json", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "pyo3", + "rayon", + "recursive", + "tokio", +] + +[[package]] +name = "polars-ops" +version = "0.46.0" +dependencies = [ + "aho-corasick", + "argminmax", + "base64 0.22.1", + "bytemuck", + "chrono", + "chrono-tz", + "either", + "hashbrown 0.15.2", + "hex", + "indexmap", + "jsonpath_lib_polars_vendor", + "libm", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-json", + "polars-schema", + "polars-utils", + "rand", + "rand_distr", + "rayon", + "regex", + "regex-syntax 0.8.5", + "serde", + "serde_json", + "strum_macros", + "unicode-normalization", + "unicode-reverse", + "version_check", +] + +[[package]] +name = "polars-parquet" +version = "0.46.0" +dependencies = [ + "async-stream", + "base64 0.22.1", + "brotli", + "bytemuck", + "ethnum", + "fallible-streaming-iterator", + "flate2", + "futures", + "hashbrown 0.15.2", + "lz4", + "lz4_flex", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-parquet-format", + "polars-utils", + "rand", + "serde", + "simdutf8", + "snap", + "streaming-decompression", + "xxhash-rust", + "zstd", +] + +[[package]] +name = "polars-parquet-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" +dependencies = [ + "async-trait", + "futures", +] + +[[package]] +name = "polars-pipe" +version = "0.46.0" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "futures", + "hashbrown 0.15.2", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "tokio", + "uuid", + "version_check", +] + +[[package]] +name = "polars-plan" +version = "0.46.0" +dependencies = [ + "bitflags", + "bytemuck", + "bytes", + "chrono", + "chrono-tz", + "either", + "futures", + "hashbrown 0.15.2", + "libloading", + "memmap2", + "num-traits", + "percent-encoding", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-ffi", + "polars-io", + "polars-json", + "polars-ops", + "polars-parquet", + "polars-time", + "polars-utils", + "pyo3", + "rayon", + "recursive", + "regex", + "serde", + "serde_json", + "strum_macros", + "version_check", +] + +[[package]] +name = "polars-python" +version = "0.46.0" +dependencies = [ + "arboard", + "bincode", + "bytemuck", + "bytes", + "chrono", + "chrono-tz", + "either", + "flate2", + "hashbrown 0.15.2", + "itoa", + "libc", + "ndarray", + "num-traits", + "numpy", + "polars", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-expr", + "polars-ffi", + "polars-io", + "polars-lazy", + "polars-mem-engine", + "polars-ops", + "polars-parquet", + "polars-plan", + "polars-row", + "polars-time", + "polars-utils", + "pyo3", + "rayon", + "recursive", + "serde_json", + "version_check", +] + +[[package]] +name = "polars-row" +version = "0.46.0" +dependencies = [ + "bitflags", + "bytemuck", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-schema" +version = "0.46.0" +dependencies = [ + "indexmap", + "polars-error", + "polars-utils", + "serde", + "version_check", +] + +[[package]] +name = "polars-sql" +version = "0.46.0" +dependencies = [ + "bitflags", + "hex", + "polars-core", + "polars-error", + "polars-lazy", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rand", + "regex", + "serde", + "sqlparser", +] + +[[package]] +name = "polars-stream" +version = "0.46.0" +dependencies = [ + "async-channel", + "async-trait", + "atomic-waker", + "bitflags", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-queue", + "crossbeam-utils", + "futures", + "memmap2", + "parking_lot", + "percent-encoding", + "pin-project-lite", + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-mem-engine", + "polars-ops", + "polars-parquet", + "polars-plan", + "polars-utils", + "pyo3", + "rand", + "rayon", + "recursive", + "slotmap", + "tokio", + "version_check", +] + +[[package]] +name = "polars-testing" +version = "0.46.0" +dependencies = [ + "polars-core", + "polars-ops", +] + +[[package]] +name = "polars-time" +version = "0.46.0" +dependencies = [ + "atoi_simd", + "bytemuck", + "chrono", + "chrono-tz", + "now", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "rayon", + "regex", + "serde", + "strum_macros", +] + +[[package]] +name = "polars-utils" +version = "0.46.0" +dependencies = [ + "bincode", + "bytemuck", + "bytes", + "compact_str", + "flate2", + "foldhash", + "hashbrown 0.15.2", + "indexmap", + "libc", + "memmap2", + "num-traits", + "polars-error", + "pyo3", + "rand", + "raw-cpuid", + "rayon", + "regex", + "rmp-serde", + "serde", + "serde_json", + "slotmap", + "stacker", + "sysinfo", + "version_check", +] + +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proptest" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" +dependencies = [ + "bitflags", + "lazy_static", + "num-traits", + "rand", + "rand_chacha", + "rand_xorshift", + "regex-syntax 0.8.5", + "unarray", +] + +[[package]] +name = "prost" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools 0.10.5", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "psm" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f58e5423e24c18cc840e1c98370b3993c6649cd1678b4d24318bcf0a083cbe88" +dependencies = [ + "cc", +] + +[[package]] +name = "py-polars" +version = "1.28.1" +dependencies = [ + "either", + "libc", + "mimalloc", + "polars", + "polars-python", + "pyo3", + "tikv-jemallocator", +] + +[[package]] +name = "pyo3" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" +dependencies = [ + "cfg-if", + "chrono", + "chrono-tz", + "indoc", + "inventory", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "quad-rand" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" + +[[package]] +name = "quick-xml" +version = "0.37.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "165859e9e55f79d67b96c5d96f4e88b6f2695a1972849c15a6a3f5c59fc2c003" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "quickcheck" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" +dependencies = [ + "env_logger", + "log", + "rand", +] + +[[package]] +name = "quinn" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls 0.23.23", + "socket2", + "thiserror 2.0.12", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +dependencies = [ + "bytes", + "getrandom 0.2.15", + "rand", + "ring", + "rustc-hash", + "rustls 0.23.23", + "rustls-pki-types", + "slab", + "thiserror 2.0.12", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.59.0", +] + +[[package]] +name = "quote" +version = "1.0.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.15", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rand_regex" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a9fe2d7d9eeaf3279d1780452a5bbd26b31b27938787ef1c3e930d1e9cfbd" +dependencies = [ + "rand", + "regex-syntax 0.6.29", +] + +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core", +] + +[[package]] +name = "raw-cpuid" +version = "11.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.99", +] + +[[package]] +name = "redox_syscall" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1" +dependencies = [ + "bitflags", +] + +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "reqwest" +version = "0.12.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.4.8", + "http 1.2.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-rustls 0.27.5", + "hyper-tls", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls 0.23.23", + "rustls-native-certs 0.8.1", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-rustls 0.26.2", + "tokio-util", + "tower", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "windows-registry", +] + +[[package]] +name = "rfc6979" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7743f17af12fa0b03b803ba12cd6a8d9483a587e89c69445e3909655c0b9fabb" +dependencies = [ + "crypto-bigint 0.4.9", + "hmac", + "zeroize", +] + +[[package]] +name = "ring" +version = "0.17.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.15", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + +[[package]] +name = "rustls" +version = "0.23.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.2.0", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +dependencies = [ + "web-time", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "sample-arrow2" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "502b30097ae5cc57ee8359bb59d8af349db022492de04596119d83f561ab8977" +dependencies = [ + "arrow2", + "sample-std", +] + +[[package]] +name = "sample-std" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "948bd219c6eb2b2ca1e004d8aefa8bbcf12614f60e0139b1758b49f9a94358c8" +dependencies = [ + "casey", + "quickcheck", + "rand", + "rand_regex", + "regex", +] + +[[package]] +name = "sample-test" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8b253ca516416756b09b582e2b7275de8f51f35e5d5711e20712b9377c7d5bf" +dependencies = [ + "quickcheck", + "sample-std", + "sample-test-macros", +] + +[[package]] +name = "sample-test-macros" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc6439a7589bb4581fdadb6391700ce4d26f8bffd34e2a75acb320822e9b5ef" +dependencies = [ + "proc-macro2", + "quote", + "sample-std", + "syn 1.0.109", +] + +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "sec1" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags", + "core-foundation 0.10.0", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + +[[package]] +name = "serde" +version = "1.0.218" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_bytes" +version = "0.11.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "364fec0df39c49a083c9a8a18a23a6bcfd9af130fe9fe321d18520a0d113e09e" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.218" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "serde_json" +version = "1.0.140" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +dependencies = [ + "indexmap", + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "signature" +version = "1.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c" +dependencies = [ + "digest", + "rand_core", +] + +[[package]] +name = "simd-json" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2bcf6c6e164e81bc7a5d49fc6988b3d515d9e8c07457d7b74ffb9324b9cd40" +dependencies = [ + "ahash", + "getrandom 0.2.15", + "halfbrown", + "once_cell", + "ref-cast", + "serde", + "serde_json", + "simdutf8", + "value-trait", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + +[[package]] +name = "skiplist" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eec25f46463fcdc5e02f388c2780b1b58e01be81a8378e62ec60931beccc3f6" +dependencies = [ + "rand", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "slotmap" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" +dependencies = [ + "version_check", +] + +[[package]] +name = "smallvec" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" + +[[package]] +name = "snafu" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + +[[package]] +name = "socket2" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "spki" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "sqlparser" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" +dependencies = [ + "log", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "stacker" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9156ebd5870ef293bfb43f91c7a74528d363ec0d424afe24160ed5a4343d08a" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "streaming-decompression" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" +dependencies = [ + "fallible-streaming-iterator", +] + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.99", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e02e925281e18ffd9d640e234264753c43edc62d64b2d4cf898f1bc5e75f3fc2" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "sysinfo" +version = "0.33.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01" +dependencies = [ + "core-foundation-sys", + "libc", + "memchr", + "ntapi", + "windows", +] + +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + +[[package]] +name = "tempfile" +version = "3.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" +dependencies = [ + "cfg-if", + "fastrand", + "getrandom 0.3.1", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "tikv-jemalloc-sys" +version = "0.6.0+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" +source = "git+https://github.com/pola-rs/jemallocator?rev=c7991e5bb6b3e9f79db6b0f48dcda67c5c3d2936#c7991e5bb6b3e9f79db6b0f48dcda67c5c3d2936" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tikv-jemallocator" +version = "0.6.0" +source = "git+https://github.com/pola-rs/jemallocator?rev=c7991e5bb6b3e9f79db6b0f48dcda67c5c3d2936#c7991e5bb6b3e9f79db6b0f48dcda67c5c3d2936" +dependencies = [ + "libc", + "tikv-jemalloc-sys", +] + +[[package]] +name = "time" +version = "0.3.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb041120f25f8fbe8fd2dbe4671c7c2ed74d83be2e7a77529bf7e0790ae3f472" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c97a5b985b7c11d7bc27fa927dc4fe6af3a6dfb021d28deb60d3bf51e76ef" + +[[package]] +name = "time-macros" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8093bc3e81c3bc5f7879de09619d06c9a5a5e45ca44dfeeb7225bae38005c5c" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tinyvec" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.44.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +dependencies = [ + "rustls 0.23.23", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "tracing-core" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + +[[package]] +name = "typed-builder" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06fbd5b8de54c5f7c91f6fe4cebb949be2125d7758e630bb58b1d831dbce600" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-reverse" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" +dependencies = [ + "unicode-segmentation", +] + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "uuid" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" +dependencies = [ + "getrandom 0.3.1", + "serde", +] + +[[package]] +name = "value-trait" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" +dependencies = [ + "float-cmp", + "halfbrown", + "itoa", + "ryu", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn 2.0.99", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" +dependencies = [ + "cfg-if", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "web-sys" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +dependencies = [ + "windows-core 0.57.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result 0.1.2", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-implement" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "windows-interface" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "windows-link" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dccfd733ce2b1753b03b6d3c65edf020262ea35e20ccdf3e288043e6dd620e3" + +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result 0.2.0", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result 0.2.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + +[[package]] +name = "x11rb" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d91ffca73ee7f68ce055750bf9f6eca0780b8c85eff9bc046a3b0da41755e12" +dependencies = [ + "gethostname", + "rustix", + "x11rb-protocol", +] + +[[package]] +name = "x11rb-protocol" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" + +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "zlib-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05" + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3051792fbdc2e1e143244dc28c60f73d8470e93f3f9cbd0ead44da5ed802722" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.14+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 19827db99ed5..0e03c294fc0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,163 @@ [workspace] - +resolver = "2" members = [ - "bench", - "polars", - "pandas_cmp", - "py-polars", - "examples/iris_classifier" + "crates/*", + "docs/source/src/rust", + "py-polars", +] +default-members = [ + "crates/*", ] +[workspace.package] +version = "0.46.0" +authors = ["Ritchie Vink "] +edition = "2024" +homepage = "https://www.pola.rs/" +license = "MIT" +repository = "https://github.com/pola-rs/polars" + +[workspace.dependencies] +aho-corasick = "1.1" +arboard = { version = "3.4.0", default-features = false } +async-channel = { version = "2.3.1" } +async-trait = { version = "0.1.59" } +atoi_simd = "0.16" +atomic-waker = "1" +avro-schema = { version = "0.3" } +base64 = "0.22.0" +bincode = "1.3.3" +bitflags = "2" +bytemuck = { version = "1.22", features = ["derive", "extern_crate_alloc"] } +bytes = { version = "1.10" } +chrono = { version = "0.4.31", default-features = false, features = ["std"] } +chrono-tz = "0.10" +compact_str = { version = "0.8.0", features = ["serde"] } +crossbeam-channel = "0.5.15" +crossbeam-deque = "0.8.5" +crossbeam-queue = "0.3" +crossbeam-utils = "0.8.20" +either = "1.14" +ethnum = "1.3.2" +fallible-streaming-iterator = "0.1.9" +fast-float2 = { version = "^0.2.2" } +flate2 = { version = "1", default-features = false } +foldhash = "0.1.5" +futures = "0.3.25" +hashbrown = { version = "0.15.0", features = ["rayon", "serde"] } +# https://github.com/rust-lang/hashbrown/issues/564 +hashbrown_old_nightly_hack = { package = "hashbrown", version = "0.14.5", features = ["rayon", "serde"] } +hex = "0.4.3" +indexmap = { version = "2", features = ["std", "serde"] } +itoa = "1.0.6" +libc = "0.2" +libm = "0.2" +memchr = "2.6" +memmap = { package = "memmap2", version = "0.9" } +ndarray = { version = "0.16", default-features = false } +num-traits = "0.2" +numpy = "0.24" +object_store = { version = "0.11", default-features = false } +parking_lot = "0.12" +percent-encoding = "2.3" +pin-project-lite = "0.2" +pyo3 = "0.24.2" +rand = "0.8" +rand_distr = "0.4" +raw-cpuid = "11" +rayon = "1.9" +recursive = "0.1" +regex = "1.9" +regex-syntax = "0.8.5" +reqwest = { version = "0.12", default-features = false } +rmp-serde = "1.3" +ryu = "1.0.13" +serde = { version = "1.0.188", features = ["derive", "rc"] } +serde_json = "1" +simd-json = { version = "0.14", features = ["known-key"] } +simdutf8 = "0.1.4" +skiplist = "0.5.1" +slotmap = "1" +sqlparser = "0.53" +stacker = "0.1" +streaming-iterator = "0.1.9" +strength_reduce = "0.2" +strum = "0.26" +strum_macros = "0.26" +tokio = { version = "1.44", default-features = false } +tokio-util = "0.7.8" +unicode-normalization = "0.1.24" +unicode-reverse = "1.0.8" +url = "2.4" +uuid = { version = "1.15.1", features = ["v4"] } +version_check = "0.9.4" +xxhash-rust = { version = "0.8.6", features = ["xxh3"] } +zstd = "0.13" + +polars = { version = "0.46.0", path = "crates/polars", default-features = false } +polars-compute = { version = "0.46.0", path = "crates/polars-compute", default-features = false } +polars-core = { version = "0.46.0", path = "crates/polars-core", default-features = false } +polars-dylib = { version = "0.46.0", path = "crates/polars-dylib", default-features = false } +polars-error = { version = "0.46.0", path = "crates/polars-error", default-features = false } +polars-expr = { version = "0.46.0", path = "crates/polars-expr", default-features = false } +polars-ffi = { version = "0.46.0", path = "crates/polars-ffi", default-features = false } +polars-io = { version = "0.46.0", path = "crates/polars-io", default-features = false } +polars-json = { version = "0.46.0", path = "crates/polars-json", default-features = false } +polars-lazy = { version = "0.46.0", path = "crates/polars-lazy", default-features = false } +polars-mem-engine = { version = "0.46.0", path = "crates/polars-mem-engine", default-features = false } +polars-ops = { version = "0.46.0", path = "crates/polars-ops", default-features = false } +polars-parquet = { version = "0.46.0", path = "crates/polars-parquet", default-features = false } +polars-pipe = { version = "0.46.0", path = "crates/polars-pipe", default-features = false } +polars-plan = { version = "0.46.0", path = "crates/polars-plan", default-features = false } +polars-python = { version = "0.46.0", path = "crates/polars-python", default-features = false } +polars-row = { version = "0.46.0", path = "crates/polars-row", default-features = false } +polars-schema = { version = "0.46.0", path = "crates/polars-schema", default-features = false } +polars-sql = { version = "0.46.0", path = "crates/polars-sql", default-features = false } +polars-stream = { version = "0.46.0", path = "crates/polars-stream", default-features = false } +polars-time = { version = "0.46.0", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.46.0", path = "crates/polars-utils", default-features = false } + +[workspace.dependencies.arrow-format] +package = "polars-arrow-format" +version = "0.1.0" + +[workspace.dependencies.arrow] +package = "polars-arrow" +version = "0.46.0" +path = "crates/polars-arrow" +default-features = false +features = [ + "compute_aggregate", + "compute_arithmetics", + "compute_bitwise", + "compute_boolean", + "compute_boolean_kleene", + "compute_comparison", +] + +[patch.crates-io] +# packed_simd_2 = { git = "https://github.com/rust-lang/packed_simd", rev = "e57c7ba11386147e6d2cbad7c88f376aab4bdc86" } +# simd-json = { git = "https://github.com/ritchie46/simd-json", branch = "alignment" } +tikv-jemallocator = { git = "https://github.com/pola-rs/jemallocator", rev = "c7991e5bb6b3e9f79db6b0f48dcda67c5c3d2936" } + +[profile.mindebug-dev] +inherits = "dev" +debug = "line-tables-only" + [profile.release] +lto = "thin" +debug = "line-tables-only" + +[profile.nodebug-release] +inherits = "release" +debug = false + +[profile.debug-release] +inherits = "release" +debug = true + +[profile.dist-release] +inherits = "release" codegen-units = 1 -rustflags = ["-C", "target-cpu=native"] -target-cpu = "native" -lto = "fat" \ No newline at end of file +debug = false +lto = "fat" diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index ed3e909fc85b..000000000000 --- a/Dockerfile +++ /dev/null @@ -1,29 +0,0 @@ -FROM rustlang/rust:nightly-slim - -RUN apt-get update \ -&& apt-get install \ - libssl-dev \ - lld \ - cmake \ - jupyter-notebook \ - pkg-config \ - git \ - -y \ -&& rm -rf /var/lib/apt/lists/* - -RUN useradd -m -d /home/polars -s /bin/bash -U -u 1000 polars \ -&& chown polars /usr/local/ -USER 1000 -RUN mkdir --parents home/polars/.config/evcxr \ -&& cargo install evcxr_jupyter \ -&& cargo install sccache \ -&& evcxr_jupyter --install \ -&& echo ':dep polars = { path = "/polars" }' > home/polars/.config/evcxr/init.evcxr - -RUN mkdir -p $(jupyter --data-dir)/nbextensions \ -&& cd $(jupyter --data-dir)/nbextensions \ -&& git clone --depth=1 https://github.com/lambdalisue/jupyter-vim-binding vim_binding \ -&& jupyter nbextension enable vim_binding/vim_binding - - -CMD [ "bash", "-c", "jupyter-notebook --no-browser --ip=0.0.0.0 --NotebookApp.token=''" ] diff --git a/LICENSE b/LICENSE index 9dbfcb834b96..cc6d55aa7571 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,5 @@ -Copyright (c) 2020 Ritchie Vink +Copyright (c) 2025 Ritchie Vink +Some portions Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -16,4 +17,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000000..84f55912ad92 --- /dev/null +++ b/Makefile @@ -0,0 +1,165 @@ +.DEFAULT_GOAL := help + +PYTHONPATH= +SHELL=bash +VENV=.venv + +ifeq ($(OS),Windows_NT) + VENV_BIN=$(VENV)/Scripts +else + VENV_BIN=$(VENV)/bin +endif + +# Detect CPU architecture. +ifeq ($(OS),Windows_NT) + ifeq ($(PROCESSOR_ARCHITECTURE),AMD64) + ARCH := amd64 + else ifeq ($(PROCESSOR_ARCHITECTURE),x86) + ARCH := x86 + else ifeq ($(PROCESSOR_ARCHITECTURE),ARM64) + ARCH := arm64 + else + ARCH := unknown + endif +else + UNAME_P := $(shell uname -p) + ifeq ($(UNAME_P),x86_64) + ARCH := amd64 + else ifneq ($(filter %86,$(UNAME_P)),) + ARCH := x86 + else ifneq ($(filter arm%,$(UNAME_P)),) + ARCH := arm64 + else + ARCH := unknown + endif +endif + +# Ensure boolean arguments are normalized to 1/0 to prevent surprises. +ifdef LTS_CPU + ifeq ($(LTS_CPU),0) + else ifeq ($(LTS_CPU),1) + else +$(error LTS_CPU must be 0 or 1 (or undefined, default to 0)) + endif +endif + +# Define RUSTFLAGS and CFLAGS appropriate for the architecture. +# Keep synchronized with .github/workflows/release-python.yml. +ifeq ($(ARCH),amd64) + ifeq ($(LTS_CPU),1) + FEAT_RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b + FEAT_CFLAGS=-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 + else + FEAT_RUSTFLAGS=-C target-feature=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt,+pclmulqdq,+movbe -Z tune-cpu=skylake + FEAT_CFLAGS=-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 -mavx -mavx2 -mfma -mbmi -mbmi2 -mlzcnt -mpclmul -mmovbe -mtune=skylake + endif +endif + +override RUSTFLAGS+=$(FEAT_RUSTFLAGS) +override CFLAGS+=$(FEAT_CFLAGS) +export RUSTFLAGS +export CFLAGS + +# Define command to filter pip warnings when running maturin +FILTER_PIP_WARNINGS=| grep -v "don't match your environment"; test $${PIPESTATUS[0]} -eq 0 + +.venv: ## Set up Python virtual environment and install requirements + python3 -m venv $(VENV) + $(MAKE) requirements + +.PHONY: requirements +requirements: .venv ## Install/refresh Python project requirements + @unset CONDA_PREFIX \ + && $(VENV_BIN)/python -m pip install --upgrade uv \ + && $(VENV_BIN)/uv pip install --upgrade --compile-bytecode --no-build \ + -r py-polars/requirements-dev.txt \ + -r py-polars/requirements-lint.txt \ + -r py-polars/docs/requirements-docs.txt \ + -r docs/source/requirements.txt + +.PHONY: requirements-all +requirements-all: .venv ## Install/refresh all Python requirements (including those needed for CI tests) + $(MAKE) requirements + $(VENV_BIN)/uv pip install --upgrade --compile-bytecode -r py-polars/requirements-ci.txt + +.PHONY: build +build: .venv ## Compile and install Python Polars for development + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml $(ARGS) \ + $(FILTER_PIP_WARNINGS) + +.PHONY: build-mindebug +build-mindebug: .venv ## Same as build, but don't include full debug information + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile mindebug-dev $(ARGS) \ + $(FILTER_PIP_WARNINGS) + +.PHONY: build-release +build-release: .venv ## Compile and install Python Polars binary with optimizations, with minimal debug symbols + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release $(ARGS) \ + $(FILTER_PIP_WARNINGS) + +.PHONY: build-nodebug-release +build-nodebug-release: .venv ## Same as build-release, but without any debug symbols at all (a bit faster to build) + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile nodebug-release $(ARGS) \ + $(FILTER_PIP_WARNINGS) + +.PHONY: build-debug-release +build-debug-release: .venv ## Same as build-release, but with full debug symbols turned on (a bit slower to build) + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release $(ARGS) \ + $(FILTER_PIP_WARNINGS) + +.PHONY: build-dist-release +build-dist-release: .venv ## Compile and install Python Polars binary with super slow extra optimization turned on, for distribution + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile dist-release $(ARGS) \ + $(FILTER_PIP_WARNINGS) + +.PHONY: check +check: ## Run cargo check with all features + cargo check --workspace --all-targets --all-features + +.PHONY: clippy +clippy: ## Run clippy with all features + cargo clippy --workspace --all-targets --all-features --locked -- -D warnings -D clippy::dbg_macro + +.PHONY: clippy-default +clippy-default: ## Run clippy with default features + cargo clippy --all-targets --locked -- -D warnings -D clippy::dbg_macro + +.PHONY: fmt +fmt: ## Run autoformatting and linting + $(VENV_BIN)/ruff check + $(VENV_BIN)/ruff format + cargo fmt --all + dprint fmt + $(VENV_BIN)/typos + +.PHONY: fix +fix: + cargo clippy --workspace --all-targets --all-features --fix + @# Good chance the fixing introduced formatting issues, best to just do a quick format. + cargo fmt --all + + +.PHONY: pre-commit +pre-commit: fmt clippy clippy-default ## Run all code quality checks + +.PHONY: clean +clean: ## Clean up caches, build artifacts, and the venv + @$(MAKE) -s -C py-polars/ $@ + @rm -rf .ruff_cache/ + @rm -rf .hypothesis/ + @rm -rf .venv/ + @cargo clean + +.PHONY: help +help: ## Display this help screen + @echo -e "\033[1mAvailable commands:\033[0m" + @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort + @echo + @echo The build commands support LTS_CPU=1 for building for older CPUs, and ARGS which is passed through to maturin. + @echo 'For example to build without default features use: make build ARGS="--no-default-features".' diff --git a/README.md b/README.md index 365e0eda70d8..14b7def5fa01 100644 --- a/README.md +++ b/README.md @@ -1,200 +1,280 @@ -# Polars -[![rust docs](https://docs.rs/polars/badge.svg)](https://docs.rs/polars/latest/polars/) -[![Build Status](https://travis-ci.com/ritchie46/polars.svg?branch=master)](https://travis-ci.com/ritchie46/polars) -[![](http://meritbadge.herokuapp.com/polars)](https://crates.io/crates/polars) +

+ + Polars logo + +

-## Blazingly fast in memory DataFrames in Rust +
+ + crates.io Latest Release + + + PyPi Latest Release + + + NPM Latest Release + + + R-universe Latest Release + + + DOI Latest Release + +
-Polars is a DataFrames library implemented in Rust, using Apache Arrow as backend. -Its focus is being a fast in memory DataFrame library. +

+ Documentation: + Python + - + Rust + - + Node.js + - + R + | + StackOverflow: + Python + - + Rust + - + Node.js + - + R + | + User guide + | + Discord +

-Polars is in rapid development, but it already supports most features needed for a useful DataFrame library. Do you -miss something, please make an issue and/or sent a PR. +## Polars: Blazingly fast DataFrames in Rust, Python, Node.js, R, and SQL -## First run -Take a look at the [10 minutes to Polars notebook](examples/10_minutes_to_polars.ipynb) to get you started. -Want to run the notebook yourself? Clone the repo and run `$ cargo c && docker-compose up`. This will spin up a jupyter -notebook on `http://localhost:8891`. The notebooks are in the `/examples` directory. - -Oh yeah.. and get a cup of coffee because compilation will take while during the first run. +Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Rust using +[Apache Arrow Columnar Format](https://arrow.apache.org/docs/format/Columnar.html) as the memory +model. +- Lazy | eager execution +- Multi-threaded +- SIMD +- Query optimization +- Powerful expression API +- Hybrid Streaming (larger-than-RAM datasets) +- Rust | Python | NodeJS | R | ... -## Documentation -Want to know what features Polars support? [Check the current master docs](https://ritchie46.github.io/polars). +To learn more, read the [user guide](https://docs.pola.rs/). -Most features are described on the [DataFrame](https://ritchie46.github.io/polars/polars/frame/struct.DataFrame.html), -[Series](https://ritchie46.github.io/polars/polars/series/enum.Series.html), and [ChunkedArray](https://ritchie46.github.io/polars/polars/chunked_array/struct.ChunkedArray.html) -structs in that order. For `ChunkedArray` a lot of functionality is also defined by `Traits` in the -[ops module](https://ritchie46.github.io/polars/polars/chunked_array/ops/index.html). +## Python -## Performance -Polars is written to be performant. Below are some comparisons with the (also very fast) Pandas DataFrame library. +```python +>>> import polars as pl +>>> df = pl.DataFrame( +... { +... "A": [1, 2, 3, 4, 5], +... "fruits": ["banana", "banana", "apple", "apple", "banana"], +... "B": [5, 4, 3, 2, 1], +... "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], +... } +... ) -#### GroupBy -![](pandas_cmp/img/groupby10_.png) +# embarrassingly parallel execution & very expressive query language +>>> df.sort("fruits").select( +... "fruits", +... "cars", +... pl.lit("fruits").alias("literal_string_fruits"), +... pl.col("B").filter(pl.col("cars") == "beetle").sum(), +... pl.col("A").filter(pl.col("B") > 2).sum().over("cars").alias("sum_A_by_cars"), +... pl.col("A").sum().over("fruits").alias("sum_A_by_fruits"), +... pl.col("A").reverse().over("fruits").alias("rev_A_by_fruits"), +... pl.col("A").sort_by("B").over("fruits").alias("sort_A_by_B_by_fruits"), +... ) +shape: (5, 8) +┌──────────┬──────────┬──────────────┬─────┬─────────────┬─────────────┬─────────────┬─────────────┐ +│ fruits ┆ cars ┆ literal_stri ┆ B ┆ sum_A_by_ca ┆ sum_A_by_fr ┆ rev_A_by_fr ┆ sort_A_by_B │ +│ --- ┆ --- ┆ ng_fruits ┆ --- ┆ rs ┆ uits ┆ uits ┆ _by_fruits │ +│ str ┆ str ┆ --- ┆ i64 ┆ --- ┆ --- ┆ --- ┆ --- │ +│ ┆ ┆ str ┆ ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞══════════╪══════════╪══════════════╪═════╪═════════════╪═════════════╪═════════════╪═════════════╡ +│ "apple" ┆ "beetle" ┆ "fruits" ┆ 11 ┆ 4 ┆ 7 ┆ 4 ┆ 4 │ +│ "apple" ┆ "beetle" ┆ "fruits" ┆ 11 ┆ 4 ┆ 7 ┆ 3 ┆ 3 │ +│ "banana" ┆ "beetle" ┆ "fruits" ┆ 11 ┆ 4 ┆ 8 ┆ 5 ┆ 5 │ +│ "banana" ┆ "audi" ┆ "fruits" ┆ 11 ┆ 2 ┆ 8 ┆ 2 ┆ 2 │ +│ "banana" ┆ "beetle" ┆ "fruits" ┆ 11 ┆ 4 ┆ 8 ┆ 1 ┆ 1 │ +└──────────┴──────────┴──────────────┴─────┴─────────────┴─────────────┴─────────────┴─────────────┘ +``` -#### Joins -![](pandas_cmp/img/join_80_000.png) +## SQL -## Functionality +```python +>>> df = pl.scan_csv("docs/assets/data/iris.csv") +>>> ## OPTION 1 +>>> # run SQL queries on frame-level +>>> df.sql(""" +... SELECT species, +... AVG(sepal_length) AS avg_sepal_length +... FROM self +... GROUP BY species +... """).collect() +shape: (3, 2) +┌────────────┬──────────────────┐ +│ species ┆ avg_sepal_length │ +│ --- ┆ --- │ +│ str ┆ f64 │ +╞════════════╪══════════════════╡ +│ Virginica ┆ 6.588 │ +│ Versicolor ┆ 5.936 │ +│ Setosa ┆ 5.006 │ +└────────────┴──────────────────┘ +>>> ## OPTION 2 +>>> # use pl.sql() to operate on the global context +>>> df2 = pl.LazyFrame({ +... "species": ["Setosa", "Versicolor", "Virginica"], +... "blooming_season": ["Spring", "Summer", "Fall"] +...}) +>>> pl.sql(""" +... SELECT df.species, +... AVG(df.sepal_length) AS avg_sepal_length, +... df2.blooming_season +... FROM df +... LEFT JOIN df2 ON df.species = df2.species +... GROUP BY df.species, df2.blooming_season +... """).collect() +``` -### Read and write CSV | JSON | IPC | Parquet +SQL commands can also be run directly from your terminal using the Polars CLI: -```rust - use polars::prelude::*; - use std::fs::File; - - fn example() -> Result { - let file = File::open("iris.csv") - .expect("could not open file"); - - CsvReader::new(file) - .infer_schema(None) - .has_header(true) - .finish() - } -``` +```bash +# run an inline SQL query +> polars -c "SELECT species, AVG(sepal_length) AS avg_sepal_length, AVG(sepal_width) AS avg_sepal_width FROM read_csv('docs/assets/data/iris.csv') GROUP BY species;" -### Joins - -```rust - use polars::prelude::*; - - fn join() -> Result { - // Create first df. - let s0 = Series::new("days", &[0, 1, 2, 3, 4]); - let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); - let temp = DataFrame::new(vec![s0, s1])?; - - // Create second df. - let s0 = Series::new("days", &[1, 2]); - let s1 = Series::new("rain", &[0.1, 0.2]); - let rain = DataFrame::new(vec![s0, s1])?; - - // Left join on days column. - temp.left_join(&rain, "days", "days") - } - println!("{}", join().unwrap()); -``` +# run interactively +> polars +Polars CLI v0.3.0 +Type .help for help. -```text - +------+------+------+ - | days | temp | rain | - | --- | --- | --- | - | i32 | f64 | f64 | - +======+======+======+ - | 0 | 22.1 | null | - +------+------+------+ - | 1 | 19.9 | 0.1 | - +------+------+------+ - | 2 | 7 | 0.2 | - +------+------+------+ - | 3 | 2 | null | - +------+------+------+ - | 4 | 3 | null | - +------+------+------+ +> SELECT species, AVG(sepal_length) AS avg_sepal_length, AVG(sepal_width) AS avg_sepal_width FROM read_csv('docs/assets/data/iris.csv') GROUP BY species; ``` -### Groupby's | aggregations | pivots +Refer to the [Polars CLI repository](https://github.com/pola-rs/polars-cli) for more information. -```rust - use polars::prelude::*; - fn groupby_sum(df: &DataFrame) -> Result { - df.groupby("column_name")? - .select("agg_column_name") - .sum() - } -``` +## Performance 🚀🚀 + +### Blazingly fast + +Polars is very fast. In fact, it is one of the best performing solutions available. See the +[PDS-H benchmarks](https://www.pola.rs/benchmarks.html) results. + +### Lightweight + +Polars is also very lightweight. It comes with zero required dependencies, and this shows in the +import times: + +- polars: 70ms +- numpy: 104ms +- pandas: 520ms -### Arithmetic -```rust - use polars::prelude::*; - let s: Series = [1, 2, 3].iter().collect(); - let s_squared = &s * &s; +### Handles larger-than-RAM data + +If you have data that does not fit into memory, Polars' query engine is able to process your query +(or parts of your query) in a streaming fashion. This drastically reduces memory requirements, so +you might be able to process your 250GB dataset on your laptop. Collect with +`collect(engine='streaming')` to run the query streaming. (This might be a little slower, but it is +still very fast!) + +## Setup + +### Python + +Install the latest Polars version with: + +```sh +pip install polars ``` -### Rust iterators - -```rust - use polars::prelude::*; - - let s: Series = [1, 2, 3].iter().collect(); - let s_squared: Series = s.i32() - .expect("datatype mismatch") - .into_iter() - .map(|optional_v| { - match optional_v { - Some(v) => Some(v * v), - None => None, // null value - } - }).collect(); +We also have a conda package (`conda install -c conda-forge polars`), however pip is the preferred +way to install Polars. + +Install Polars with all optional dependencies. + +```sh +pip install 'polars[all]' ``` -### Apply custom closures -```rust - use polars::prelude::*; - - let s: Series = Series::new("values", [Some(1.0), None, Some(3.0)]); - // null values are ignored automatically - let squared = s.f64() - .unwrap() - .apply(|value| value.powf(2.0)) - .into_series(); - - assert_eq!(Vec::from(squared.f64().unwrap()), &[Some(1.0), None, Some(9.0)]) +You can also install a subset of all optional dependencies. + +```sh +pip install 'polars[numpy,pandas,pyarrow]' ``` -### Comparisons +See the [User Guide](https://docs.pola.rs/user-guide/installation/#feature-flags) for more details +on optional dependencies -```rust - use polars::prelude::*; - use itertools::Itertools; - let s = Series::new("dollars", &[1, 2, 3]); - let mask = s.eq(1); - let valid = [true, false, false].iter(); - - assert_eq!(Vec::from(mask), &[Some(true), Some(false), Some(false)]); +To see the current Polars version and a full list of its optional dependencies, run: + +```python +pl.show_versions() ``` -## Temporal data types - -```rust - let dates = &[ - "2020-08-21", - "2020-08-21", - "2020-08-22", - "2020-08-23", - "2020-08-22", - ]; - // date format - let fmt = "%Y-%m-%d"; - // create date series - let s0 = Date32Chunked::parse_from_str_slice("date", dates, fmt) - .into_series(); +Releases happen quite often (weekly / every few days) at the moment, so updating Polars regularly to +get the latest bugfixes / features might not be a bad idea. + +### Rust + +You can take latest release from `crates.io`, or if you want to use the latest features / +performance improvements point to the `main` branch of this repo. + +```toml +polars = { git = "https://github.com/pola-rs/polars", rev = "" } ``` -## And more... - -* [DataFrame](https://ritchie46.github.io/polars/polars/frame/struct.DataFrame.html) -* [Series](https://ritchie46.github.io/polars/polars/series/enum.Series.html) -* [ChunkedArray](https://ritchie46.github.io/polars/polars/chunked_array/struct.ChunkedArray.html) - - [Operations implemented by Traits](https://ritchie46.github.io/polars/polars/chunked_array/ops/index.html) -* [Time/ DateTime utilities](https://ritchie46.github.io/polars/polars/doc/time/index.html) -* [Groupby, aggregations and pivots](https://ritchie46.github.io/polars/polars/frame/group_by/struct.GroupBy.html) - -## Features - -Additional cargo features: - -* `pretty` (default) - - pretty printing of DataFrames -* `temporal (default)` - - Conversions between Chrono and Polars for temporal data -* `simd` - - SIMD operations -* `paquet` - - Read Apache Parquet format -* `random` - - Generate array's with randomly sampled values -* `ndarray` - - Convert from `DataFrame` to `ndarray` +Requires Rust version `>=1.80`. + +## Contributing + +Want to contribute? Read our [contributing guide](https://docs.pola.rs/development/contributing/). + +## Python: compile Polars from source + +If you want a bleeding edge release or maximal performance you should compile Polars from source. + +This can be done by going through the following steps in sequence: + +1. Install the latest [Rust compiler](https://www.rust-lang.org/tools/install) +2. Install [maturin](https://maturin.rs/): `pip install maturin` +3. `cd py-polars` and choose one of the following: + - `make build`, slow binary with debug assertions and symbols, fast compile times + - `make build-release`, fast binary without debug assertions, minimal debug symbols, long compile + times + - `make build-nodebug-release`, same as build-release but without any debug symbols, slightly + faster to compile + - `make build-debug-release`, same as build-release but with full debug symbols, slightly slower + to compile + - `make build-dist-release`, fastest binary, extreme compile times + +By default the binary is compiled with optimizations turned on for a modern CPU. Specify `LTS_CPU=1` +with the command if your CPU is older and does not support e.g. AVX2. + +Note that the Rust crate implementing the Python bindings is called `py-polars` to distinguish from +the wrapped Rust crate `polars` itself. However, both the Python package and the Python module are +named `polars`, so you can `pip install polars` and `import polars`. + +## Using custom Rust functions in Python + +Extending Polars with UDFs compiled in Rust is easy. We expose PyO3 extensions for `DataFrame` and +`Series` data structures. See more in https://github.com/pola-rs/pyo3-polars. + +## Going big... + +Do you expect more than 2^32 (~4.2 billion) rows? Compile Polars with the `bigidx` feature flag or, +for Python users, install `pip install polars-u64-idx`. + +Don't use this unless you hit the row boundary as the default build of Polars is faster and consumes +less memory. + +## Legacy + +Do you want Polars to run on an old CPU (e.g. dating from before 2011), or on an `x86-64` build of +Python on Apple Silicon under Rosetta? Install `pip install polars-lts-cpu`. This version of Polars +is compiled without [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) target features. + +## Sponsors + +[JetBrains logo](https://www.jetbrains.com) diff --git a/bench/Cargo.toml b/bench/Cargo.toml deleted file mode 100644 index b1018022c519..000000000000 --- a/bench/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "bench" -version = "0.1.0" -authors = ["ritchie46 "] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -polars = {path = "../polars"} -arrow = {version = "1.0.0", default_features = false} diff --git a/bench/src/main.rs b/bench/src/main.rs deleted file mode 100644 index 69597ffce5b8..000000000000 --- a/bench/src/main.rs +++ /dev/null @@ -1,63 +0,0 @@ -#![feature(test)] -extern crate test; -use arrow::array::{Array, ArrayRef}; -use arrow::{ - array::{PrimitiveArray, PrimitiveBuilder}, - datatypes::Int32Type, -}; -use polars::prelude::*; -use std::sync::Arc; -use test::Bencher; - -const SIZE: usize = 1000; -const N: usize = 10; - -fn create_arrow_array(size: usize) -> Arc { - let mut builder = PrimitiveBuilder::::new(size); - for i in 0..size { - builder.append_value(i as i32).expect("append"); - } - Arc::new(builder.finish()) -} - -fn n_arrow_arrays(n: usize, size: usize) -> Vec> { - let mut arrays = Vec::with_capacity(n); - for _ in 0..n { - arrays.push(create_arrow_array(size)) - } - arrays -} - -#[bench] -fn bench_vec_clone(b: &mut Bencher) { - let arrays = n_arrow_arrays(N, SIZE); - b.iter(|| arrays.clone()) -} - -#[bench] -fn bench_arc_clone(b: &mut Bencher) { - let arrays = n_arrow_arrays(N, SIZE); - let arrays = Arc::new(arrays); - b.iter(|| arrays.clone()) -} - -#[bench] -fn bench_series_clone(b: &mut Bencher) { - let arrays = n_arrow_arrays(N, SIZE); - - let ca = ChunkedArray::new_from_chunks("a", arrays); - let s = Series::Int32(ca); - b.iter(|| s.clone()) -} - -#[bench] -fn bench_data_clone(b: &mut Bencher) { - // Is not really comparable, as arrow arrays also have null bits - let mut arrays = Vec::with_capacity(N); - for _ in 0..N { - arrays.push((0..SIZE).collect::>()); - } - - b.iter(|| arrays.clone()) -} -fn main() {} diff --git a/crates/Makefile b/crates/Makefile new file mode 100644 index 000000000000..db7a6d3859dd --- /dev/null +++ b/crates/Makefile @@ -0,0 +1,164 @@ +.DEFAULT_GOAL := help + +SHELL=bash +BASE ?= main + +.PHONY: fix +fix: + @$(MAKE) -s -C .. $@ + +.PHONY: fmt +fmt: ## Run rustfmt and dprint + cargo fmt --all + dprint fmt + +.PHONY: check +check: ## Run cargo check with all features + cargo check -p polars --all-features + +.PHONY: clippy +clippy: ## Run clippy with all features + cargo clippy --all-targets --all-features -- -W clippy::dbg_macro + +.PHONY: clippy-default +clippy-default: ## Run clippy with default features + cargo clippy --all-targets -- -W clippy::dbg_macro + +.PHONY: pre-commit +pre-commit: fmt clippy clippy-default ## Run autoformatting and linting + +.PHONY: check-features +check-features: ## Run cargo check for feature flag combinations (warning: slow) + cargo hack check -p polars --each-feature --no-dev-deps + +.PHONY: miri +miri: ## Run miri + # not tested on all features because miri does not support SIMD + # some tests are also filtered, because miri cannot deal with the rayon threadpool + # we ignore leaks because the thread pool of rayon is never killed. + MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-ignore-leaks -Zmiri-disable-stacked-borrows" \ + POLARS_ALLOW_EXTENSION=1 \ + cargo miri test \ + --features object \ + -p polars-core \ +# -p polars-arrow + +.PHONY: test +test: ## Run tests + cargo test --all-features \ + -p polars-compute \ + -p polars-core \ + -p polars-io \ + -p polars-lazy \ + -p polars-ops \ + -p polars-plan \ + -p polars-row \ + -p polars-sql \ + -p polars-testing \ + -p polars-time \ + -p polars-utils \ + -- \ + --test-threads=2 + +.PHONY: nextest +nextest: ## Run tests with nextest + cargo nextest run --all-features \ + -p polars-compute \ + -p polars-core \ + -p polars-io \ + -p polars-lazy \ + -p polars-ops \ + -p polars-plan \ + -p polars-row \ + -p polars-sql \ + -p polars-testing \ + -p polars-time \ + -p polars-utils \ + +.PHONY: integration-tests +integration-tests: ## Run integration tests + cargo test --all-features --test it -p polars + +.PHONY: test-doc +test-doc: ## Run doc examples + cargo test --doc \ + -p polars-lazy \ + -p polars-io \ + -p polars-core \ + -p polars-testing \ + -p polars-sql + +.PHONY: bench-save +bench-save: ## Run benchmark and save + cargo bench --features=random --bench $(BENCH) -- --save-baseline $(SAVE) + +.PHONY: bench-cmp +bench-cmp: ## Run benchmark and compare + cargo bench --features=random --bench $(BENCH) -- --load-baseline $(FEAT) --baseline $(BASE) + +.PHONY: doctest +doctest: ## Check that documentation builds + cargo doc --no-deps --all-features -p polars-utils + cargo doc --no-deps --features=docs-selection -p polars-core + cargo doc --no-deps -p polars-time + cargo doc --no-deps -p polars-ops + cargo doc --no-deps --all-features -p polars-io + cargo doc --no-deps --all-features -p polars-lazy + cargo doc --no-deps --features=docs-selection -p polars + cargo doc --no-deps --all-features -p polars-sql + +.PHONY: publish +publish: ## Publish Polars crates + cargo publish --allow-dirty -p polars-error + cargo publish --allow-dirty -p polars-utils + cargo publish --allow-dirty -p polars-schema + cargo publish --allow-dirty -p polars-arrow + cargo publish --allow-dirty -p polars-compute + cargo publish --allow-dirty -p polars-row + cargo publish --allow-dirty -p polars-json + cargo publish --allow-dirty -p polars-core + cargo publish --allow-dirty -p polars-ffi + cargo publish --allow-dirty -p polars-ops + cargo publish --allow-dirty -p polars-time + cargo publish --allow-dirty -p polars-parquet + cargo publish --allow-dirty -p polars-io + cargo publish --allow-dirty -p polars-plan + cargo publish --allow-dirty -p polars-expr + cargo publish --allow-dirty -p polars-mem-engine + cargo publish --allow-dirty -p polars-stream + cargo publish --allow-dirty -p polars-pipe + cargo publish --allow-dirty -p polars-lazy + cargo publish --allow-dirty -p polars-sql + cargo publish --allow-dirty -p polars + # This is independent + cargo publish --allow-dirty -p polars-python + +.PHONY: help +help: ## Display this help screen + @echo -e "\033[1mAvailable commands:\033[0m" + @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-18s\033[0m %s\n", $$1, $$2}' | sort + +.PHONY: check-wasm +check-wasm: ## Check wasm build without supported features + cargo hack check --target wasm32-unknown-unknown -p polars --no-dev-deps \ + --each-feature \ + --exclude-features async \ + --exclude-features aws \ + --exclude-features azure \ + --exclude-features cloud \ + --exclude-features decompress \ + --exclude-features default \ + --exclude-features docs-selection \ + --exclude-features extract_jsonpath \ + --exclude-features fmt \ + --exclude-features gcp \ + --exclude-features ipc \ + --exclude-features ipc_streaming \ + --exclude-features json \ + --exclude-features nightly \ + --exclude-features parquet \ + --exclude-features performant \ + --exclude-features streaming \ + --exclude-features http \ + --exclude-features full \ + --exclude-features test diff --git a/crates/clippy.toml b/crates/clippy.toml new file mode 100644 index 000000000000..7e715b3edcd1 --- /dev/null +++ b/crates/clippy.toml @@ -0,0 +1,5 @@ +disallowed-types = ["std::collections::HashMap", "std::collections::HashSet"] + +disallowed-methods = [ + { path = "regex::Regex::new", reason = "use polars_utils::regex_cache" }, +] diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml new file mode 100644 index 000000000000..f1960a359df0 --- /dev/null +++ b/crates/polars-arrow/Cargo.toml @@ -0,0 +1,154 @@ +[package] +name = "polars-arrow" +version = { workspace = true } +authors = [ + "Jorge C. Leitao ", + "Apache Arrow ", + "Ritchie Vink ", +] +edition = { workspace = true } +homepage = { workspace = true } +license = "MIT AND Apache-2.0" +repository = { workspace = true } +description = "Minimal implementation of the Arrow specification forked from arrow2" + +[dependencies] +bytemuck = { workspace = true, features = ["must_cast"] } +chrono = { workspace = true } +# for timezone support +chrono-tz = { workspace = true, optional = true } +dyn-clone = { version = "1" } +either = { workspace = true } +hashbrown = { workspace = true } +num-traits = { workspace = true } +parking_lot = { workspace = true } +polars-error = { workspace = true } +polars-schema = { workspace = true } +polars-utils = { workspace = true } +serde = { workspace = true, optional = true } +simdutf8 = { workspace = true } + +ethnum = { workspace = true } + +# To efficiently cast numbers to strings +atoi_simd = { workspace = true, optional = true } +fast-float2 = { workspace = true, optional = true } +itoa = { workspace = true, optional = true } +ryu = { workspace = true, optional = true } + +regex = { workspace = true, optional = true } +regex-syntax = { version = "0.8", optional = true } +streaming-iterator = { workspace = true } + +indexmap = { workspace = true, optional = true } + +arrow-format = { workspace = true, optional = true, features = ["ipc"] } + +hex = { workspace = true, optional = true } + +# for IPC compression +lz4 = { version = "1.24", optional = true } +zstd = { workspace = true, optional = true } + +# to write to parquet as a stream +futures = { workspace = true, optional = true } + +# avro support +avro-schema = { workspace = true, optional = true } + +# for division/remainder optimization at runtime +strength_reduce = { workspace = true, optional = true } + +# For async arrow flight conversion +async-stream = { version = "0.3", optional = true } +tokio = { workspace = true, optional = true, features = ["io-util"] } + +strum_macros = { workspace = true } + +[dev-dependencies] +criterion = "0.5" +crossbeam-channel = { workspace = true } +doc-comment = "0.3" +flate2 = { workspace = true, default-features = true } +# used to run formal property testing +proptest = { version = "1", default-features = false, features = ["std"] } +# use for flaky testing +rand = { workspace = true } +# use for generating and testing random data samples +sample-arrow2 = "0.17" +sample-std = "0.2" +sample-test = "0.2" +# used to test async readers +tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { workspace = true, features = ["compat"] } + +[build-dependencies] +version_check = { workspace = true } + +[target.wasm32-unknown-unknown.dependencies] +getrandom = { version = "0.2", features = ["js"] } + +[features] +default = [] +full = [ + "io_ipc", + "io_flight", + "io_ipc_compression", + "io_avro", + "io_avro_compression", + "io_avro_async", + "regex-syntax", + "compute", + "serde", + # parses timezones used in timestamp conversions + "chrono-tz", +] +io_ipc = ["arrow-format", "polars-error/arrow-format"] +io_ipc_compression = ["lz4", "zstd", "io_ipc"] +io_flight = ["io_ipc", "arrow-format/flight-data", "async-stream", "futures", "tokio"] + +io_avro = ["avro-schema", "polars-error/avro-schema"] +io_avro_compression = [ + "avro-schema/compression", +] +io_avro_async = ["avro-schema/async"] + +# the compute kernels. Disabling this significantly reduces compile time. +compute_aggregate = [] +compute_arithmetics_decimal = ["strength_reduce"] +compute_arithmetics = ["strength_reduce", "compute_arithmetics_decimal"] +compute_bitwise = [] +compute_boolean = [] +compute_boolean_kleene = [] +compute_comparison = ["compute_boolean"] +compute_temporal = [] +compute = [ + "compute_aggregate", + "compute_arithmetics", + "compute_bitwise", + "compute_boolean", + "compute_boolean_kleene", + "compute_comparison", + "compute_temporal", +] +serde = ["dep:serde", "polars-schema/serde", "polars-utils/serde"] +simd = [] + +# polars-arrow +timezones = [ + "chrono-tz", +] +dtype-array = [] +dtype-decimal = ["atoi_simd", "itoa"] +bigidx = ["polars-utils/bigidx"] +nightly = [] +performant = [] +strings = [] +temporal = [] + +[package.metadata.docs.rs] +features = ["full"] +rustdoc-args = ["--cfg", "docsrs"] + +[package.metadata.cargo-all-features] +allowlist = ["compute", "compute_sort", "compute_nullif"] diff --git a/crates/polars-arrow/LICENSE b/crates/polars-arrow/LICENSE new file mode 100644 index 000000000000..a4b4b70523c3 --- /dev/null +++ b/crates/polars-arrow/LICENSE @@ -0,0 +1,196 @@ +Some of the code in this crate is subject to the Apache 2 license below, as it +was taken from the arrow2 Rust crate in October 2023. Later changes are subject +to the MIT license in ../../LICENSE. + + + + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2020-2022 Jorge C. Leitão + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/crates/polars-arrow/build.rs b/crates/polars-arrow/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-arrow/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-arrow/src/README.md b/crates/polars-arrow/src/README.md new file mode 100644 index 000000000000..cb24938c5f78 --- /dev/null +++ b/crates/polars-arrow/src/README.md @@ -0,0 +1,37 @@ +# Crate's design + +This document describes the design of this module, and thus the overall crate. Each module MAY have +its own design document, that concerns specifics of that module, and if yes, it MUST be on each +module's `README.md`. + +## Equality + +Array equality is not defined in the Arrow specification. This crate follows the intent of the +specification, but there is no guarantee that this no verification that this equals e.g. C++'s +definition. + +There is a single source of truth about whether two arrays are equal, and that is via their equality +operators, defined on the module [`array/equal`](array/equal/mod.rs). + +Implementation MUST use these operators for asserting equality, so that all testing follows the same +definition of array equality. + +## Error handling + +- Errors from an external dependency MUST be encapsulated on `External`. +- Errors from IO MUST be encapsulated on `Io`. +- This crate MAY return `NotYetImplemented` when the functionality does not exist, or it MAY panic + with `unimplemented!`. + +## Logical and physical types + +There is a strict separation between physical and logical types: + +- physical types MUST be implemented via generics +- logical types MUST be implemented via variables (whose value is e.g. an `enum`) +- logical types MUST be declared and implemented on the `datatypes` module + +## Source of undefined behavior + +There is one, and only one, acceptable source of undefined behavior: FFI. It is impossible to prove +that data passed via pointers are safe for consumption (only a promise from the specification). diff --git a/crates/polars-arrow/src/array/README.md b/crates/polars-arrow/src/array/README.md new file mode 100644 index 000000000000..b3c8424c4c45 --- /dev/null +++ b/crates/polars-arrow/src/array/README.md @@ -0,0 +1,84 @@ +# Array module + +This document describes the overall design of this module. + +## Notation: + +- "array" in this module denotes any struct that implements the trait `Array`. +- "mutable array" in this module denotes any struct that implements the trait `MutableArray`. +- words in `code` denote existing terms on this implementation. + +## Arrays: + +- Every arrow array with a different physical representation MUST be implemented as a struct or + generic struct. + +- An array MAY have its own module. E.g. `primitive/mod.rs` + +- An array with a null bitmap MUST implement it as `Option` + +- An array MUST be `#[derive(Clone)]` + +- The trait `Array` MUST only be implemented by structs in this module. + +- Every child array on the struct MUST be `Box`. + +- An array MUST implement `try_new(...) -> Self`. This method MUST error iff the data does not + follow the arrow specification, including any sentinel types such as utf8. + +- An array MAY implement `unsafe try_new_unchecked` that skips validation steps that are `O(N)`. + +- An array MUST implement either `new_empty()` or `new_empty(DataType)` that returns a zero-len of + `Self`. + +- An array MUST implement either `new_null(length: usize)` or `new_null(DataType, length: usize)` + that returns a valid array of length `length` whose all elements are null. + +- An array MAY implement `value(i: usize)` that returns the value at slot `i` ignoring the validity + bitmap. + +- functions to create new arrays from native Rust SHOULD be named as follows: + - `from`: from a slice of optional values (e.g. `AsRef<[Option]` for `BooleanArray`) + - `from_slice`: from a slice of values (e.g. `AsRef<[bool]>` for `BooleanArray`) + - `from_trusted_len_iter` from an iterator of trusted len of optional values + - `from_trusted_len_values_iter` from an iterator of trusted len of values + - `try_from_trusted_len_iter` from an fallible iterator of trusted len of optional values + +### Slot offsets + +- An array MUST have a `offset: usize` measuring the number of slots that the array is currently + offsetted by if the specification requires. + +- An array MUST implement `fn slice(&self, offset: usize, length: usize) -> Self` that returns an + offsetted and/or truncated clone of the array. This function MUST increase the array's offset if + it exists. + +- Conversely, `offset` MUST only be changed by `slice`. + +The rational of the above is that it enable us to be fully interoperable with the offset logic +supported by the C data interface, while at the same time easily perform array slices within Rust's +type safety mechanism. + +### Mutable Arrays + +- An array MAY have a mutable counterpart. E.g. `MutablePrimitiveArray` is the mutable + counterpart of `PrimitiveArray`. + +- Arrays with mutable counterparts MUST have its own module, and have the mutable counterpart + declared in `{module}/mutable.rs`. + +- The trait `MutableArray` MUST only be implemented by mutable arrays in this module. + +- A mutable array MUST be `#[derive(Debug)]` + +- A mutable array with a null bitmap MUST implement it as `Option` + +- Converting a `MutableArray` to its immutable counterpart MUST be `O(1)`. Specifically: + - it must not allocate + - it must not cause `O(N)` data transformations + + This is achieved by converting mutable versions to immutable counterparts (e.g. + `MutableBitmap -> Bitmap`). + + The rational is that `MutableArray`s can be used to perform in-place operations under the arrow + spec. diff --git a/crates/polars-arrow/src/array/binary/builder.rs b/crates/polars-arrow/src/array/binary/builder.rs new file mode 100644 index 000000000000..0f834bdce4ca --- /dev/null +++ b/crates/polars-arrow/src/array/binary/builder.rs @@ -0,0 +1,149 @@ +use polars_utils::IdxSize; + +use crate::array::BinaryArray; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::OptBitmapBuilder; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +pub struct BinaryArrayBuilder { + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + validity: OptBitmapBuilder, +} + +impl BinaryArrayBuilder { + pub fn new(dtype: ArrowDataType) -> Self { + Self { + dtype, + offsets: Offsets::new(), + values: Vec::new(), + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for BinaryArrayBuilder { + type Array = BinaryArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + self.validity.reserve(additional); + // No values reserve, we have no idea how large it needs to be. + } + + fn freeze(self) -> BinaryArray { + let offsets = OffsetsBuffer::from(self.offsets); + let values = Buffer::from(self.values); + let validity = self.validity.into_opt_validity(); + BinaryArray::new(self.dtype, offsets, values, validity) + } + + fn freeze_reset(&mut self) -> Self::Array { + let offsets = OffsetsBuffer::from(core::mem::take(&mut self.offsets)); + let values = Buffer::from(core::mem::take(&mut self.values)); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + BinaryArray::new(self.dtype.clone(), offsets, values, validity) + } + + fn len(&self) -> usize { + self.offsets.len_proxy() + } + + fn extend_nulls(&mut self, length: usize) { + self.offsets.extend_constant(length); + self.validity.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &BinaryArray, + start: usize, + length: usize, + _share: ShareStrategy, + ) { + let start_offset = other.offsets()[start].to_usize(); + let stop_offset = other.offsets()[start + length].to_usize(); + self.offsets + .try_extend_from_slice(other.offsets(), start, length) + .unwrap(); + self.values + .extend_from_slice(&other.values()[start_offset..stop_offset]); + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + unsafe fn gather_extend( + &mut self, + other: &BinaryArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + let other_values = &**other.values(); + let other_offsets = other.offsets(); + + // Pre-compute proper length for reserve. + let total_len: usize = idxs + .iter() + .map(|i| { + let start_offset = other_offsets.get_unchecked(*i as usize).to_usize(); + let stop_offset = other_offsets.get_unchecked(*i as usize + 1).to_usize(); + stop_offset - start_offset + }) + .sum(); + self.values.reserve(total_len); + + for idx in idxs { + let start_offset = other_offsets.get_unchecked(*idx as usize).to_usize(); + let stop_offset = other_offsets.get_unchecked(*idx as usize + 1).to_usize(); + self.values + .extend_from_slice(other_values.get_unchecked(start_offset..stop_offset)); + } + + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + } + + fn opt_gather_extend( + &mut self, + other: &BinaryArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + let other_values = &**other.values(); + let other_offsets = other.offsets(); + + unsafe { + // Pre-compute proper length for reserve. + let total_len: usize = idxs + .iter() + .map(|idx| { + if (*idx as usize) < other.len() { + let start_offset = other_offsets.get_unchecked(*idx as usize).to_usize(); + let stop_offset = other_offsets.get_unchecked(*idx as usize + 1).to_usize(); + stop_offset - start_offset + } else { + 0 + } + }) + .sum(); + self.values.reserve(total_len); + + for idx in idxs { + let start_offset = other_offsets.get_unchecked(*idx as usize).to_usize(); + let stop_offset = other_offsets.get_unchecked(*idx as usize + 1).to_usize(); + self.values + .extend_from_slice(other_values.get_unchecked(start_offset..stop_offset)); + } + + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } + } +} diff --git a/crates/polars-arrow/src/array/binary/ffi.rs b/crates/polars-arrow/src/array/binary/ffi.rs new file mode 100644 index 000000000000..107cf0fcb421 --- /dev/null +++ b/crates/polars-arrow/src/array/binary/ffi.rs @@ -0,0 +1,64 @@ +use polars_error::PolarsResult; + +use super::BinaryArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for BinaryArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().storage_ptr().cast::()), + Some(self.values.storage_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for BinaryArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let values = unsafe { array.buffer::(2) }?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Self::try_new(dtype, offsets, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/binary/fmt.rs b/crates/polars-arrow/src/array/binary/fmt.rs new file mode 100644 index 000000000000..d2a6788ce4d8 --- /dev/null +++ b/crates/polars-arrow/src/array/binary/fmt.rs @@ -0,0 +1,26 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BinaryArray; +use crate::offset::Offset; + +pub fn write_value(array: &BinaryArray, index: usize, f: &mut W) -> Result { + let bytes = array.value(index); + let writer = |f: &mut W, index| write!(f, "{}", bytes[index]); + + write_vec(f, writer, None, bytes.len(), "None", false) +} + +impl Debug for BinaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + let head = if O::IS_LARGE { + "LargeBinaryArray" + } else { + "BinaryArray" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/binary/from.rs b/crates/polars-arrow/src/array/binary/from.rs new file mode 100644 index 000000000000..9ffac9827bb8 --- /dev/null +++ b/crates/polars-arrow/src/array/binary/from.rs @@ -0,0 +1,9 @@ +use super::{BinaryArray, MutableBinaryArray}; +use crate::offset::Offset; + +impl> FromIterator> for BinaryArray { + #[inline] + fn from_iter>>(iter: I) -> Self { + MutableBinaryArray::::from_iter(iter).into() + } +} diff --git a/crates/polars-arrow/src/array/binary/iterator.rs b/crates/polars-arrow/src/array/binary/iterator.rs new file mode 100644 index 000000000000..3fccec58eb50 --- /dev/null +++ b/crates/polars-arrow/src/array/binary/iterator.rs @@ -0,0 +1,42 @@ +use super::{BinaryArray, MutableBinaryValuesArray}; +use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for BinaryArray { + type Item = &'a [u8]; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`BinaryArray`]. +pub type BinaryValueIter<'a, O> = ArrayValuesIter<'a, BinaryArray>; + +impl<'a, O: Offset> IntoIterator for &'a BinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], BinaryValueIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// Iterator of values of an [`MutableBinaryValuesArray`]. +pub type MutableBinaryValuesIter<'a, O> = ArrayValuesIter<'a, MutableBinaryValuesArray>; + +impl<'a, O: Offset> IntoIterator for &'a MutableBinaryValuesArray { + type Item = &'a [u8]; + type IntoIter = MutableBinaryValuesIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/polars-arrow/src/array/binary/mod.rs b/crates/polars-arrow/src/array/binary/mod.rs new file mode 100644 index 000000000000..9be1b11c2e0a --- /dev/null +++ b/crates/polars-arrow/src/array/binary/mod.rs @@ -0,0 +1,472 @@ +use either::Either; + +use super::specification::try_check_offsets_bounds; +use super::{Array, GenericBinaryArray, Splitable}; +use crate::array::iterator::NonNullValuesIter; +use crate::bitmap::Bitmap; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; + +mod builder; +pub use builder::*; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod from; +mod mutable_values; +pub use mutable_values::*; +mod mutable; +pub use mutable::*; +use polars_error::{PolarsResult, polars_bail}; + +/// A [`BinaryArray`] is Arrow's semantically equivalent of an immutable `Vec>>`. +/// It implements [`Array`]. +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. +/// # Example +/// ``` +/// use polars_arrow::array::BinaryArray; +/// use polars_arrow::bitmap::Bitmap; +/// use polars_arrow::buffer::Buffer; +/// +/// let array = BinaryArray::::from([Some([1, 2].as_ref()), None, Some([3].as_ref())]); +/// assert_eq!(array.value(0), &[1, 2]); +/// assert_eq!(array.iter().collect::>(), vec![Some([1, 2].as_ref()), None, Some([3].as_ref())]); +/// assert_eq!(array.values_iter().collect::>(), vec![[1, 2].as_ref(), &[], &[3]]); +/// // the underlying representation: +/// assert_eq!(array.values(), &Buffer::from(vec![1, 2, 3])); +/// assert_eq!(array.offsets().buffer(), &Buffer::from(vec![0, 2, 2, 3])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// ``` +/// +/// # Generic parameter +/// The generic parameter [`Offset`] can only be `i32` or `i64` and tradeoffs maximum array length with +/// memory usage: +/// * the sum of lengths of all elements cannot exceed `Offset::MAX` +/// * the total size of the underlying data is `array.len() * size_of::() + sum of lengths of all elements` +/// +/// # Safety +/// The following invariants hold: +/// * Two consecutive `offsets` cast (`as`) to `usize` are valid slices of `values`. +/// * `len` is equal to `validity.len()`, when defined. +#[derive(Clone)] +pub struct BinaryArray { + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, +} + +impl BinaryArray { + /// Returns a [`BinaryArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> PolarsResult { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != offsets.len_proxy()) + { + polars_bail!(ComputeError: "validity mask length must match the number of values") + } + + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { + polars_bail!(ComputeError: "BinaryArray can only be initialized with DataType::Binary or DataType::LargeBinary") + } + + Ok(Self { + dtype, + offsets, + values, + validity, + }) + } + + /// Creates a new [`BinaryArray`] without checking invariants. + /// + /// # Safety + /// + /// The invariants must be valid (see try_new). + pub unsafe fn new_unchecked( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self { + dtype, + offsets, + values, + validity, + } + } + + /// Creates a new [`BinaryArray`] from slices of `&[u8]`. + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter()) + } + + /// Creates a new [`BinaryArray`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + MutableBinaryArray::::from(slice).into() + } + + /// Returns an iterator of `Option<&[u8]>` over every element of this array. + pub fn iter(&self) -> ZipValidity<&[u8], BinaryValueIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity.as_ref()) + } + + /// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity + pub fn values_iter(&self) -> BinaryValueIter { + BinaryValueIter::new(self) + } + + /// Returns an iterator of the non-null values. + #[inline] + pub fn non_null_values_iter(&self) -> NonNullValuesIter<'_, BinaryArray> { + NonNullValuesIter::new(self, self.validity()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the element at index `i` + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + + // soundness: the invariant of the struct + self.values.get_unchecked(start..end) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&[u8]> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns the [`ArrowDataType`] of this array. + #[inline] + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + /// Returns the values of this [`BinaryArray`]. + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the offsets of this [`BinaryArray`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Slices this [`BinaryArray`]. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`BinaryArray`]. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (ArrowDataType, OffsetsBuffer, Buffer, Option) { + let Self { + dtype, + offsets, + values, + validity, + } = self; + (dtype, offsets, values, validity) + } + + /// Try to convert this `BinaryArray` to a `MutableBinaryArray` + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + // SAFETY: invariants are preserved + Left(bitmap) => Left(BinaryArray::new( + self.dtype, + self.offsets, + self.values, + Some(bitmap), + )), + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => Left(BinaryArray::new( + self.dtype, + offsets, + values, + Some(mutable_bitmap.into()), + )), + (Left(values), Right(offsets)) => Left(BinaryArray::new( + self.dtype, + offsets.into(), + values, + Some(mutable_bitmap.into()), + )), + (Right(values), Left(offsets)) => Left(BinaryArray::new( + self.dtype, + offsets, + values.into(), + Some(mutable_bitmap.into()), + )), + (Right(values), Right(offsets)) => Right( + MutableBinaryArray::try_new( + self.dtype, + offsets, + values, + Some(mutable_bitmap), + ) + .unwrap(), + ), + }, + } + } else { + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(BinaryArray::new(self.dtype, offsets, values, None)) + }, + (Left(values), Right(offsets)) => { + Left(BinaryArray::new(self.dtype, offsets.into(), values, None)) + }, + (Right(values), Left(offsets)) => { + Left(BinaryArray::new(self.dtype, offsets, values.into(), None)) + }, + (Right(values), Right(offsets)) => { + Right(MutableBinaryArray::try_new(self.dtype, offsets, values, None).unwrap()) + }, + } + } + } + + /// Creates an empty [`BinaryArray`], i.e. whose `.len` is zero. + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, OffsetsBuffer::new(), Buffer::new(), None) + } + + /// Creates an null [`BinaryArray`], i.e. whose `.null_count() == .len()`. + #[inline] + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + unsafe { + Self::new_unchecked( + dtype, + Offsets::new_zeroed(length).into(), + Buffer::new(), + Some(Bitmap::new_zeroed(length)), + ) + } + } + + /// Returns the default [`ArrowDataType`], `DataType::Binary` or `DataType::LargeBinary` + pub fn default_dtype() -> ArrowDataType { + if O::IS_LARGE { + ArrowDataType::LargeBinary + } else { + ArrowDataType::Binary + } + } + + /// Alias for unwrapping [`Self::try_new`] + pub fn new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new(dtype, offsets, values, validity).unwrap() + } + + /// Returns a [`BinaryArray`] from an iterator of trusted length. + /// + /// The [`BinaryArray`] is guaranteed to not have a validity + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableBinaryArray::::from_trusted_len_values_iter(iterator).into() + } + + /// Returns a new [`BinaryArray`] from a [`Iterator`] of `&[u8]`. + /// + /// The [`BinaryArray`] is guaranteed to not have a validity + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableBinaryArray::::from_iter_values(iterator).into() + } + + /// Creates a [`BinaryArray`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + MutableBinaryArray::::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`BinaryArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`BinaryArray`] from an falible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + MutableBinaryArray::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) + } + + /// Creates a [`BinaryArray`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } +} + +impl Array for BinaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +unsafe impl GenericBinaryArray for BinaryArray { + #[inline] + fn values(&self) -> &[u8] { + self.values() + } + + #[inline] + fn offsets(&self) -> &[O] { + self.offsets().buffer() + } +} + +impl Splitable for BinaryArray { + #[inline(always)] + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_offsets, rhs_offsets) = unsafe { self.offsets.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + ( + Self { + dtype: self.dtype.clone(), + offsets: lhs_offsets, + values: self.values.clone(), + validity: lhs_validity, + }, + Self { + dtype: self.dtype.clone(), + offsets: rhs_offsets, + values: self.values.clone(), + validity: rhs_validity, + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/binary/mutable.rs b/crates/polars-arrow/src/array/binary/mutable.rs new file mode 100644 index 000000000000..84de3a9893d5 --- /dev/null +++ b/crates/polars-arrow/src/array/binary/mutable.rs @@ -0,0 +1,468 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::{BinaryArray, MutableBinaryValuesArray, MutableBinaryValuesIter}; +use crate::array::physical_binary::*; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// The Arrow's equivalent to `Vec>>`. +/// Converting a [`MutableBinaryArray`] into a [`BinaryArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableBinaryArray { + values: MutableBinaryValuesArray, + validity: Option, +} + +impl From> for BinaryArray { + fn from(other: MutableBinaryArray) -> Self { + let validity = other.validity.and_then(|x| { + let validity: Option = x.into(); + validity + }); + let array: BinaryArray = other.values.into(); + array.with_validity(validity) + } +} + +impl Default for MutableBinaryArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBinaryArray { + /// Creates a new empty [`MutableBinaryArray`]. + /// # Implementation + /// This allocates a [`Vec`] of one element + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Returns a [`MutableBinaryArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> PolarsResult { + let values = MutableBinaryValuesArray::try_new(dtype, offsets, values)?; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != values.len()) + { + polars_bail!(ComputeError: "validity's length must be equal to the number of values") + } + + Ok(Self { values, validity }) + } + + /// Creates a new [`MutableBinaryArray`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } + + fn default_dtype() -> ArrowDataType { + BinaryArray::::default_dtype() + } + + /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots and values. + /// # Implementation + /// This does not allocate the validity. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + values: MutableBinaryValuesArray::with_capacities(capacity, values), + validity: None, + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.values.reserve(additional, additional_values); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Pushes a new element to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + pub fn push>(&mut self, value: Option) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableBinaryArray`]. + /// This function returns `None` iff this array is empty + pub fn pop(&mut self) -> Option> { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + fn try_from_iter, I: IntoIterator>>( + iter: I, + ) -> PolarsResult { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut primitive = Self::with_capacity(lower); + for item in iterator { + primitive.try_push(item.as_ref())? + } + Ok(primitive) + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity); + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: BinaryArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableBinaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + impl_mutable_array_mut_validity!(); +} + +impl MutableBinaryArray { + /// returns its values. + pub fn values(&self) -> &Vec { + self.values.values() + } + + /// returns its offsets. + pub fn offsets(&self) -> &Offsets { + self.values.offsets() + } + + /// Returns an iterator of `Option<&[u8]>` + pub fn iter(&self) -> ZipValidity<&[u8], MutableBinaryValuesIter, BitmapIter> { + ZipValidity::new(self.values_iter(), self.validity.as_ref().map(|x| x.iter())) + } + + /// Returns an iterator over the values of this array + pub fn values_iter(&self) -> MutableBinaryValuesIter { + self.values.iter() + } +} + +impl MutableArray for MutableBinaryArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: BinaryArray = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: BinaryArray = std::mem::take(self).into(); + array.arced() + } + + fn dtype(&self) -> &ArrowDataType { + self.values.dtype() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&[u8]>(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator> for MutableBinaryArray { + fn from_iter>>(iter: I) -> Self { + Self::try_from_iter(iter).unwrap() + } +} + +impl MutableBinaryArray { + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + let (validity, offsets, values) = trusted_len_unzip(iterator); + + Self::try_new(Self::default_dtype(), offsets, values, validity).unwrap() + } + + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked, I: Iterator>( + iterator: I, + ) -> Self { + let (offsets, values) = trusted_len_values_iter(iterator); + Self::try_new(Self::default_dtype(), offsets, values, None).unwrap() + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_values_iter_unchecked(iterator) } + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + // soundness: assumed trusted len + let (validity, offsets, values) = try_trusted_len_unzip(iterator)?; + Ok(Self::try_new(Self::default_dtype(), offsets, values, validity).unwrap()) + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of trusted length. + /// This differs from `extend_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen, + { + // SAFETY: The iterator is `TrustedLen` + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of values. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + let length = self.values.len(); + self.values.extend(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableBinaryArray`] from an `iterator` of values of trusted length. + /// This differs from `extend_trusted_len_unchecked` which accepts iterator of optional + /// values. + /// + /// # Safety + /// The `iterator` must be [`TrustedLen`] + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + let length = self.values.len(); + self.values.extend_trusted_len_unchecked(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // SAFETY: The iterator is `TrustedLen` + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + /// + /// # Safety + /// The `iterator` must be [`TrustedLen`] + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + self.values + .extend_from_trusted_len_iter(self.validity.as_mut().unwrap(), iterator); + } + + /// Creates a new [`MutableBinaryArray`] from a [`Iterator`] of `&[u8]`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + let (offsets, values) = values_iter(iterator); + Self::try_new(Self::default_dtype(), offsets, values, None).unwrap() + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator, E>>, + T: AsRef<[u8]>, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend> for MutableBinaryArray { + fn extend>>(&mut self, iter: I) { + self.try_extend(iter).unwrap(); + } +} + +impl> TryExtend> for MutableBinaryArray { + fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush> for MutableBinaryArray { + fn try_push(&mut self, value: Option) -> PolarsResult<()> { + match value { + Some(value) => { + self.values.try_push(value.as_ref())?; + + if let Some(validity) = &mut self.validity { + validity.push(true) + } + }, + None => { + self.values.push(""); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } +} + +impl PartialEq for MutableBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableBinaryArray { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/polars-arrow/src/array/binary/mutable_values.rs b/crates/polars-arrow/src/array/binary/mutable_values.rs new file mode 100644 index 000000000000..c02d8dfbbb40 --- /dev/null +++ b/crates/polars-arrow/src/array/binary/mutable_values.rs @@ -0,0 +1,374 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::{BinaryArray, MutableBinaryArray}; +use crate::array::physical_binary::*; +use crate::array::specification::try_check_offsets_bounds; +use crate::array::{ + Array, ArrayAccessor, ArrayValuesIter, MutableArray, TryExtend, TryExtendFromSelf, TryPush, +}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`BinaryArray`]. It differs +/// from [`MutableBinaryArray`] in that it builds non-null [`BinaryArray`]. +#[derive(Debug, Clone)] +pub struct MutableBinaryValuesArray { + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, +} + +impl From> for BinaryArray { + fn from(other: MutableBinaryValuesArray) -> Self { + BinaryArray::::new(other.dtype, other.offsets.into(), other.values.into(), None) + } +} + +impl From> for MutableBinaryArray { + fn from(other: MutableBinaryValuesArray) -> Self { + MutableBinaryArray::::try_new(other.dtype, other.offsets, other.values, None) + .expect("MutableBinaryValuesArray is consistent with MutableBinaryArray") + } +} + +impl Default for MutableBinaryValuesArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBinaryValuesArray { + /// Returns an empty [`MutableBinaryValuesArray`]. + pub fn new() -> Self { + Self { + dtype: Self::default_dtype(), + offsets: Offsets::new(), + values: Vec::::new(), + } + } + + /// Returns a [`MutableBinaryValuesArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + ) -> PolarsResult { + try_check_offsets_bounds(&offsets, values.len())?; + + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { + polars_bail!(ComputeError: "MutableBinaryValuesArray can only be initialized with DataType::Binary or DataType::LargeBinary",) + } + + Ok(Self { + dtype, + offsets, + values, + }) + } + + /// Returns the default [`ArrowDataType`] of this container: [`ArrowDataType::Utf8`] or [`ArrowDataType::LargeUtf8`] + /// depending on the generic [`Offset`]. + pub fn default_dtype() -> ArrowDataType { + BinaryArray::::default_dtype() + } + + /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + dtype: Self::default_dtype(), + offsets: Offsets::::with_capacity(capacity), + values: Vec::::with_capacity(values), + } + } + + /// returns its values. + #[inline] + pub fn values(&self) -> &Vec { + &self.values + } + + /// returns its offsets. + #[inline] + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// Reserves `additional` elements and `additional_values` on the values. + #[inline] + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.offsets.reserve(additional); + self.values.reserve(additional_values); + } + + /// Returns the capacity in number of items + pub fn capacity(&self) -> usize { + self.offsets.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Pushes a new item to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: T) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableBinaryValuesArray`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option> { + if self.len() == 0 { + return None; + } + self.offsets.pop()?; + let start = self.offsets.last().to_usize(); + let value = self.values.split_off(start); + Some(value.to_vec()) + } + + /// Returns the value of the element at index `i`. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`. + /// + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end(i); + + // soundness: the invariant of the struct + self.values.get_unchecked(start..end) + } + + /// Returns an iterator of `&[u8]` + pub fn iter(&self) -> ArrayValuesIter { + ArrayValuesIter::new(self) + } + + /// Shrinks the capacity of the [`MutableBinaryValuesArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + } + + /// Extract the low-end APIs from the [`MutableBinaryValuesArray`]. + pub fn into_inner(self) -> (ArrowDataType, Offsets, Vec) { + (self.dtype, self.offsets, self.values) + } +} + +impl MutableArray for MutableBinaryValuesArray { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let (dtype, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(dtype, offsets.into(), values.into(), None).boxed() + } + + fn as_arc(&mut self) -> Arc { + let (dtype, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(dtype, offsets.into(), values.into(), None).arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&[u8]>(b"") + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator

for MutableBinaryValuesArray { + fn from_iter>(iter: I) -> Self { + let (offsets, values) = values_iter(iter.into_iter()); + Self::try_new(Self::default_dtype(), offsets, values).unwrap() + } +} + +impl MutableBinaryValuesArray { + pub(crate) unsafe fn extend_from_trusted_len_iter( + &mut self, + validity: &mut MutableBitmap, + iterator: I, + ) where + P: AsRef<[u8]>, + I: Iterator>, + { + extend_from_trusted_len_iter(&mut self.offsets, &mut self.values, validity, iterator); + } + + /// Extends the [`MutableBinaryValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableBinaryValuesArray`] from an iterator of trusted len. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + } + + /// Creates a [`MutableBinaryValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Returns a new [`MutableBinaryValuesArray`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator, + { + let (offsets, values) = trusted_len_values_iter(iterator); + Self::try_new(Self::default_dtype(), offsets, values).unwrap() + } + + /// Returns a new [`MutableBinaryValuesArray`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). + pub fn try_from_iter, I: IntoIterator>(iter: I) -> PolarsResult { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator>, + T: AsRef<[u8]>, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend for MutableBinaryValuesArray { + fn extend>(&mut self, iter: I) { + extend_from_values_iter(&mut self.offsets, &mut self.values, iter.into_iter()); + } +} + +impl> TryExtend for MutableBinaryValuesArray { + fn try_extend>(&mut self, iter: I) -> PolarsResult<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush for MutableBinaryValuesArray { + #[inline] + fn try_push(&mut self, value: T) -> PolarsResult<()> { + let bytes = value.as_ref(); + self.values.extend_from_slice(bytes); + self.offsets.try_push(bytes.len()) + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableBinaryValuesArray { + type Item = &'a [u8]; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +impl TryExtendFromSelf for MutableBinaryValuesArray { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + self.values.extend_from_slice(&other.values); + self.offsets.try_extend_from_self(&other.offsets) + } +} diff --git a/crates/polars-arrow/src/array/binview/builder.rs b/crates/polars-arrow/src/array/binview/builder.rs new file mode 100644 index 000000000000..0a4bb22e4ab6 --- /dev/null +++ b/crates/polars-arrow/src/array/binview/builder.rs @@ -0,0 +1,381 @@ +use std::marker::PhantomData; +use std::sync::{Arc, LazyLock}; + +use hashbrown::hash_map::Entry; +use polars_utils::IdxSize; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; + +use crate::array::binview::{DEFAULT_BLOCK_SIZE, MAX_EXP_BLOCK_SIZE}; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder}; +use crate::array::{Array, BinaryViewArrayGeneric, View, ViewType}; +use crate::bitmap::OptBitmapBuilder; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::pushable::Pushable; + +static PLACEHOLDER_BUFFER: LazyLock> = LazyLock::new(|| Buffer::from_static(&[])); + +pub struct BinaryViewArrayGenericBuilder { + dtype: ArrowDataType, + views: Vec, + active_buffer: Vec, + active_buffer_idx: u32, + buffer_set: Vec>, + stolen_buffers: PlHashMap, + + // With these we can amortize buffer set translation costs if repeatedly + // stealing from the same set of buffers. + last_buffer_set_stolen_from: Option]>>, + buffer_set_translation_idxs: Vec<(u32, u32)>, // (idx, generation) + buffer_set_translation_generation: u32, + + validity: OptBitmapBuilder, + /// Total bytes length if we would concatenate them all. + total_bytes_len: usize, + /// Total bytes in the buffer set (excluding remaining capacity). + total_buffer_len: usize, + view_type: PhantomData, +} + +impl BinaryViewArrayGenericBuilder { + pub fn new(dtype: ArrowDataType) -> Self { + Self { + dtype, + views: Vec::new(), + active_buffer: Vec::new(), + active_buffer_idx: 0, + buffer_set: Vec::new(), + stolen_buffers: PlHashMap::new(), + last_buffer_set_stolen_from: None, + buffer_set_translation_idxs: Vec::new(), + buffer_set_translation_generation: 0, + validity: OptBitmapBuilder::default(), + total_bytes_len: 0, + total_buffer_len: 0, + view_type: PhantomData, + } + } + + #[inline] + fn reserve_active_buffer(&mut self, additional: usize) { + let len = self.active_buffer.len(); + let cap = self.active_buffer.capacity(); + if additional > cap - len || len + additional >= (u32::MAX - 1) as usize { + self.reserve_active_buffer_slow(additional); + } + } + + #[cold] + fn reserve_active_buffer_slow(&mut self, additional: usize) { + assert!( + additional <= (u32::MAX - 1) as usize, + "strings longer than 2^32 - 2 are not supported" + ); + + // Allocate a new buffer and flush the old buffer. + let new_capacity = (self.active_buffer.capacity() * 2) + .clamp(DEFAULT_BLOCK_SIZE, MAX_EXP_BLOCK_SIZE) + .max(additional); + + let old_buffer = + core::mem::replace(&mut self.active_buffer, Vec::with_capacity(new_capacity)); + if !old_buffer.is_empty() { + // Replace dummy with real buffer. + self.buffer_set[self.active_buffer_idx as usize] = Buffer::from(old_buffer); + } + self.active_buffer_idx = self.buffer_set.len().try_into().unwrap(); + self.buffer_set.push(PLACEHOLDER_BUFFER.clone()) // Push placeholder so active_buffer_idx stays valid. + } + + pub fn push_value_ignore_validity(&mut self, bytes: &V) { + let bytes = bytes.to_bytes(); + self.total_bytes_len += bytes.len(); + unsafe { + let view = if bytes.len() > View::MAX_INLINE_SIZE as usize { + self.reserve_active_buffer(bytes.len()); + + let offset = self.active_buffer.len() as u32; // Ensured no overflow by reserve_active_buffer. + self.active_buffer.extend_from_slice(bytes); + self.total_buffer_len += bytes.len(); + View::new_noninline_unchecked(bytes, self.active_buffer_idx, offset) + } else { + View::new_inline_unchecked(bytes) + }; + self.views.push(view); + } + } + + /// # Safety + /// The view must be inline. + pub unsafe fn push_inline_view_ignore_validity(&mut self, view: View) { + debug_assert!(view.is_inline()); + self.total_bytes_len += view.length as usize; + self.views.push(view); + } + + fn switch_active_stealing_bufferset_to(&mut self, buffer_set: &Arc<[Buffer]>) { + // Fat pointer equality, checks both start and length. + if self + .last_buffer_set_stolen_from + .as_ref() + .is_some_and(|stolen_bs| std::ptr::eq(Arc::as_ptr(stolen_bs), Arc::as_ptr(buffer_set))) + { + return; // Already active. + } + + // Switch to new generation (invalidating all old translation indices), + // and resizing the buffer with invalid indices if necessary. + let old_gen = self.buffer_set_translation_generation; + self.buffer_set_translation_generation = old_gen.wrapping_add(1); + if self.buffer_set_translation_idxs.len() < buffer_set.len() { + self.buffer_set_translation_idxs + .resize(buffer_set.len(), (0, old_gen)); + } + } + + unsafe fn extend_views_dedup_ignore_validity( + &mut self, + views: impl IntoIterator, + other_bufferset: &Arc<[Buffer]>, + ) { + // TODO: if there are way more buffers than length translate per-view + // rather than all at once. + self.switch_active_stealing_bufferset_to(other_bufferset); + + for mut view in views { + if view.length > View::MAX_INLINE_SIZE { + // Translate from old array-local buffer idx to global stolen buffer idx. + let (mut new_buffer_idx, gen_) = *self + .buffer_set_translation_idxs + .get_unchecked(view.buffer_idx as usize); + if gen_ != self.buffer_set_translation_generation { + // This buffer index wasn't seen before for this array, do a dedup lookup. + // Since we map by starting pointer and different subslices may have different lengths, we expand + // the buffer to the maximum it could be. + let buffer = other_bufferset + .get_unchecked(view.buffer_idx as usize) + .clone() + .expand_end_to_storage(); + let buf_id = buffer.as_slice().as_ptr().addr(); + let idx = match self.stolen_buffers.entry(buf_id) { + Entry::Occupied(o) => *o.get(), + Entry::Vacant(v) => { + let idx = self.buffer_set.len() as u32; + self.total_buffer_len += buffer.len(); + self.buffer_set.push(buffer); + v.insert(idx); + idx + }, + }; + + // Cache result for future lookups. + *self + .buffer_set_translation_idxs + .get_unchecked_mut(view.buffer_idx as usize) = + (idx, self.buffer_set_translation_generation); + new_buffer_idx = idx; + } + view.buffer_idx = new_buffer_idx; + } + + self.total_bytes_len += view.length as usize; + self.views.push(view); + } + } +} + +impl StaticArrayBuilder for BinaryViewArrayGenericBuilder { + type Array = BinaryViewArrayGeneric; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + self.views.reserve(additional); + self.validity.reserve(additional); + } + + fn freeze(mut self) -> Self::Array { + // Flush active buffer and/or remove extra placeholder buffer. + if !self.active_buffer.is_empty() { + self.buffer_set[self.active_buffer_idx as usize] = Buffer::from(self.active_buffer); + } else if self.buffer_set.last().is_some_and(|b| b.is_empty()) { + self.buffer_set.pop(); + } + + unsafe { + BinaryViewArrayGeneric::new_unchecked( + self.dtype, + Buffer::from(self.views), + Arc::from(self.buffer_set), + self.validity.into_opt_validity(), + self.total_bytes_len, + self.total_buffer_len, + ) + } + } + + fn freeze_reset(&mut self) -> Self::Array { + // Flush active buffer and/or remove extra placeholder buffer. + if !self.active_buffer.is_empty() { + self.buffer_set[self.active_buffer_idx as usize] = + Buffer::from(core::mem::take(&mut self.active_buffer)); + } else if self.buffer_set.last().is_some_and(|b| b.is_empty()) { + self.buffer_set.pop(); + } + + let out = unsafe { + BinaryViewArrayGeneric::new_unchecked( + self.dtype.clone(), + Buffer::from(core::mem::take(&mut self.views)), + Arc::from(core::mem::take(&mut self.buffer_set)), + core::mem::take(&mut self.validity).into_opt_validity(), + self.total_bytes_len, + self.total_buffer_len, + ) + }; + + self.total_buffer_len = 0; + self.total_bytes_len = 0; + self.active_buffer_idx = 0; + self.stolen_buffers.clear(); + self.last_buffer_set_stolen_from = None; + out + } + + fn len(&self) -> usize { + self.views.len() + } + + fn extend_nulls(&mut self, length: usize) { + self.views.extend_constant(length, View::default()); + self.validity.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &Self::Array, + start: usize, + length: usize, + share: ShareStrategy, + ) { + self.views.reserve(length); + + unsafe { + match share { + ShareStrategy::Never => { + if let Some(v) = other.validity() { + for i in start..start + length { + if v.get_bit_unchecked(i) { + self.push_value_ignore_validity(other.value_unchecked(i)); + } else { + self.views.push(View::default()) + } + } + } else { + for i in start..start + length { + self.push_value_ignore_validity(other.value_unchecked(i)); + } + } + }, + ShareStrategy::Always => { + let other_views = &other.views()[start..start + length]; + self.extend_views_dedup_ignore_validity( + other_views.iter().copied(), + other.data_buffers(), + ); + }, + } + } + + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + unsafe fn gather_extend( + &mut self, + other: &Self::Array, + idxs: &[IdxSize], + share: ShareStrategy, + ) { + self.views.reserve(idxs.len()); + + unsafe { + match share { + ShareStrategy::Never => { + if let Some(v) = other.validity() { + for idx in idxs { + if v.get_bit_unchecked(*idx as usize) { + self.push_value_ignore_validity( + other.value_unchecked(*idx as usize), + ); + } else { + self.views.push(View::default()) + } + } + } else { + for idx in idxs { + self.push_value_ignore_validity(other.value_unchecked(*idx as usize)); + } + } + }, + ShareStrategy::Always => { + let other_view_slice = other.views().as_slice(); + let other_views = idxs + .iter() + .map(|idx| *other_view_slice.get_unchecked(*idx as usize)); + self.extend_views_dedup_ignore_validity(other_views, other.data_buffers()); + }, + } + } + + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + } + + fn opt_gather_extend(&mut self, other: &Self::Array, idxs: &[IdxSize], share: ShareStrategy) { + self.views.reserve(idxs.len()); + + unsafe { + match share { + ShareStrategy::Never => { + if let Some(v) = other.validity() { + for idx in idxs { + if (*idx as usize) < v.len() && v.get_bit_unchecked(*idx as usize) { + self.push_value_ignore_validity( + other.value_unchecked(*idx as usize), + ); + } else { + self.views.push(View::default()) + } + } + } else { + for idx in idxs { + if (*idx as usize) < other.len() { + self.push_value_ignore_validity( + other.value_unchecked(*idx as usize), + ); + } else { + self.views.push(View::default()) + } + } + } + }, + ShareStrategy::Always => { + let other_view_slice = other.views().as_slice(); + let other_views = idxs.iter().map(|idx| { + other_view_slice + .get(*idx as usize) + .copied() + .unwrap_or_default() + }); + self.extend_views_dedup_ignore_validity(other_views, other.data_buffers()); + }, + } + } + + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } +} diff --git a/crates/polars-arrow/src/array/binview/ffi.rs b/crates/polars-arrow/src/array/binview/ffi.rs new file mode 100644 index 000000000000..b2a983363691 --- /dev/null +++ b/crates/polars-arrow/src/array/binview/ffi.rs @@ -0,0 +1,100 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use polars_error::PolarsResult; + +use super::BinaryViewArrayGeneric; +use crate::array::binview::{View, ViewType}; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::ffi; + +unsafe impl ToFfi for BinaryViewArrayGeneric { + fn buffers(&self) -> Vec> { + let mut buffers = Vec::with_capacity(self.buffers.len() + 2); + buffers.push(self.validity.as_ref().map(|x| x.as_ptr())); + buffers.push(Some(self.views.storage_ptr().cast::())); + buffers.extend(self.buffers.iter().map(|b| Some(b.storage_ptr()))); + buffers + } + + fn offset(&self) -> Option { + let offset = self.views.offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.views.offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + views: self.views.clone(), + buffers: self.buffers.clone(), + phantom: Default::default(), + total_bytes_len: AtomicU64::new(self.total_bytes_len.load(Ordering::Relaxed)), + total_buffer_len: self.total_buffer_len, + } + } +} + +impl FromFfi for BinaryViewArrayGeneric { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + + let validity = unsafe { array.validity() }?; + let views = unsafe { array.buffer::(1) }?; + + // 2 - validity + views + let n_buffers = array.n_buffers(); + let mut remaining_buffers = n_buffers - 2; + if remaining_buffers <= 1 { + return Ok(Self::new_unchecked_unknown_md( + dtype, + views, + Arc::from([]), + validity, + None, + )); + } + + let n_variadic_buffers = remaining_buffers - 1; + let variadic_buffer_offset = n_buffers - 1; + + let variadic_buffer_sizes = + array.buffer_known_len::(variadic_buffer_offset, n_variadic_buffers)?; + remaining_buffers -= 1; + + let mut variadic_buffers = Vec::with_capacity(remaining_buffers); + + let offset = 2; + for (i, &size) in (offset..remaining_buffers + offset).zip(variadic_buffer_sizes.iter()) { + let values = unsafe { array.buffer_known_len::(i, size as usize) }?; + variadic_buffers.push(values); + } + + Ok(Self::new_unchecked_unknown_md( + dtype, + views, + Arc::from(variadic_buffers), + validity, + None, + )) + } +} diff --git a/crates/polars-arrow/src/array/binview/fmt.rs b/crates/polars-arrow/src/array/binview/fmt.rs new file mode 100644 index 000000000000..53a0f71dd4b6 --- /dev/null +++ b/crates/polars-arrow/src/array/binview/fmt.rs @@ -0,0 +1,36 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BinaryViewArrayGeneric; +use crate::array::binview::ViewType; +use crate::array::{Array, BinaryViewArray, Utf8ViewArray}; + +pub fn write_value<'a, T: ViewType + ?Sized, W: Write>( + array: &'a BinaryViewArrayGeneric, + index: usize, + f: &mut W, +) -> Result +where + &'a T: Debug, +{ + let bytes = array.value(index).to_bytes(); + let writer = |f: &mut W, index| write!(f, "{}", bytes[index]); + + write_vec(f, writer, None, bytes.len(), "None", false) +} + +impl Debug for BinaryViewArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + write!(f, "BinaryViewArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} + +impl Debug for Utf8ViewArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write!(f, "{}", self.value(index)); + write!(f, "Utf8ViewArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/binview/iterator.rs b/crates/polars-arrow/src/array/binview/iterator.rs new file mode 100644 index 000000000000..26587d5c1b72 --- /dev/null +++ b/crates/polars-arrow/src/array/binview/iterator.rs @@ -0,0 +1,47 @@ +use super::BinaryViewArrayGeneric; +use crate::array::binview::ViewType; +use crate::array::{ArrayAccessor, ArrayValuesIter, MutableBinaryViewArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +unsafe impl<'a, T: ViewType + ?Sized> ArrayAccessor<'a> for BinaryViewArrayGeneric { + type Item = &'a T; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.views.len() + } +} + +/// Iterator of values of an [`BinaryArray`]. +pub type BinaryViewValueIter<'a, T> = ArrayValuesIter<'a, BinaryViewArrayGeneric>; + +impl<'a, T: ViewType + ?Sized> IntoIterator for &'a BinaryViewArrayGeneric { + type Item = Option<&'a T>; + type IntoIter = ZipValidity<&'a T, BinaryViewValueIter<'a, T>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, T: ViewType + ?Sized> ArrayAccessor<'a> for MutableBinaryViewArray { + type Item = &'a T; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.views().len() + } +} + +/// Iterator of values of an [`MutableBinaryViewArray`]. +pub type MutableBinaryViewValueIter<'a, T> = ArrayValuesIter<'a, MutableBinaryViewArray>; diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs new file mode 100644 index 000000000000..bd5e025c2802 --- /dev/null +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -0,0 +1,708 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! See thread: https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt + +mod builder; +pub use builder::*; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +mod view; + +use std::any::Any; +use std::fmt::Debug; +use std::marker::PhantomData; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use polars_error::*; + +use crate::array::Array; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; + +mod private { + pub trait Sealed: Send + Sync {} + + impl Sealed for str {} + impl Sealed for [u8] {} +} +pub use iterator::BinaryViewValueIter; +pub use mutable::MutableBinaryViewArray; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; +use private::Sealed; + +use crate::array::binview::view::{validate_binary_views, validate_views_utf8_only}; +use crate::array::iterator::NonNullValuesIter; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +pub type BinaryViewArray = BinaryViewArrayGeneric<[u8]>; +pub type Utf8ViewArray = BinaryViewArrayGeneric; +pub use view::{View, validate_utf8_views}; + +use super::Splitable; + +pub type MutablePlString = MutableBinaryViewArray; +pub type MutablePlBinary = MutableBinaryViewArray<[u8]>; + +static BIN_VIEW_TYPE: ArrowDataType = ArrowDataType::BinaryView; +static UTF8_VIEW_TYPE: ArrowDataType = ArrowDataType::Utf8View; + +// Growth parameters of view array buffers. +const DEFAULT_BLOCK_SIZE: usize = 8 * 1024; +const MAX_EXP_BLOCK_SIZE: usize = 16 * 1024 * 1024; + +pub trait ViewType: Sealed + 'static + PartialEq + AsRef { + const IS_UTF8: bool; + const DATA_TYPE: ArrowDataType; + type Owned: Debug + Clone + Sync + Send + AsRef; + + /// # Safety + /// The caller must ensure that `slice` is a valid view. + unsafe fn from_bytes_unchecked(slice: &[u8]) -> &Self; + fn from_bytes(slice: &[u8]) -> Option<&Self>; + + fn to_bytes(&self) -> &[u8]; + + #[allow(clippy::wrong_self_convention)] + fn into_owned(&self) -> Self::Owned; + + fn dtype() -> &'static ArrowDataType; +} + +impl ViewType for str { + const IS_UTF8: bool = true; + const DATA_TYPE: ArrowDataType = ArrowDataType::Utf8View; + type Owned = String; + + #[inline(always)] + unsafe fn from_bytes_unchecked(slice: &[u8]) -> &Self { + std::str::from_utf8_unchecked(slice) + } + #[inline(always)] + fn from_bytes(slice: &[u8]) -> Option<&Self> { + std::str::from_utf8(slice).ok() + } + + #[inline(always)] + fn to_bytes(&self) -> &[u8] { + self.as_bytes() + } + + fn into_owned(&self) -> Self::Owned { + self.to_string() + } + fn dtype() -> &'static ArrowDataType { + &UTF8_VIEW_TYPE + } +} + +impl ViewType for [u8] { + const IS_UTF8: bool = false; + const DATA_TYPE: ArrowDataType = ArrowDataType::BinaryView; + type Owned = Vec; + + #[inline(always)] + unsafe fn from_bytes_unchecked(slice: &[u8]) -> &Self { + slice + } + #[inline(always)] + fn from_bytes(slice: &[u8]) -> Option<&Self> { + Some(slice) + } + + #[inline(always)] + fn to_bytes(&self) -> &[u8] { + self + } + + fn into_owned(&self) -> Self::Owned { + self.to_vec() + } + + fn dtype() -> &'static ArrowDataType { + &BIN_VIEW_TYPE + } +} + +pub struct BinaryViewArrayGeneric { + dtype: ArrowDataType, + views: Buffer, + buffers: Arc<[Buffer]>, + validity: Option, + phantom: PhantomData, + /// Total bytes length if we would concatenate them all. + total_bytes_len: AtomicU64, + /// Total bytes in the buffer (excluding remaining capacity) + total_buffer_len: usize, +} + +impl PartialEq for BinaryViewArrayGeneric { + fn eq(&self, other: &Self) -> bool { + self.len() == other.len() && self.into_iter().zip(other).all(|(l, r)| l == r) + } +} + +impl Clone for BinaryViewArrayGeneric { + fn clone(&self) -> Self { + Self { + dtype: self.dtype.clone(), + views: self.views.clone(), + buffers: self.buffers.clone(), + validity: self.validity.clone(), + phantom: Default::default(), + total_bytes_len: AtomicU64::new(self.total_bytes_len.load(Ordering::Relaxed)), + total_buffer_len: self.total_buffer_len, + } + } +} + +unsafe impl Send for BinaryViewArrayGeneric {} +unsafe impl Sync for BinaryViewArrayGeneric {} + +const UNKNOWN_LEN: u64 = u64::MAX; + +impl BinaryViewArrayGeneric { + /// # Safety + /// The caller must ensure + /// - the data is valid utf8 (if required) + /// - The offsets match the buffers. + pub unsafe fn new_unchecked( + dtype: ArrowDataType, + views: Buffer, + buffers: Arc<[Buffer]>, + validity: Option, + total_bytes_len: usize, + total_buffer_len: usize, + ) -> Self { + // Verify the invariants + #[cfg(debug_assertions)] + { + if let Some(validity) = validity.as_ref() { + assert_eq!(validity.len(), views.len()); + } + + // @TODO: Enable this. This is currently bugged with concatenate. + // let mut actual_total_buffer_len = 0; + // let mut actual_total_bytes_len = 0; + // + // for buffer in buffers.iter() { + // actual_total_buffer_len += buffer.len(); + // } + + for (i, view) in views.iter().enumerate() { + let is_valid = validity.as_ref().is_none_or(|v| v.get_bit(i)); + + if !is_valid { + continue; + } + + // actual_total_bytes_len += view.length as usize; + if view.length > View::MAX_INLINE_SIZE { + assert!((view.buffer_idx as usize) < (buffers.len())); + assert!( + view.offset as usize + view.length as usize + <= buffers[view.buffer_idx as usize].len() + ); + } + } + + // assert_eq!(actual_total_buffer_len, total_buffer_len); + // if (total_bytes_len as u64) != UNKNOWN_LEN { + // assert_eq!(actual_total_bytes_len, total_bytes_len); + // } + } + + Self { + dtype, + views, + buffers, + validity, + phantom: Default::default(), + total_bytes_len: AtomicU64::new(total_bytes_len as u64), + total_buffer_len, + } + } + + /// Create a new BinaryViewArray but initialize a statistics compute. + /// + /// # Safety + /// The caller must ensure the invariants + pub unsafe fn new_unchecked_unknown_md( + dtype: ArrowDataType, + views: Buffer, + buffers: Arc<[Buffer]>, + validity: Option, + total_buffer_len: Option, + ) -> Self { + let total_bytes_len = UNKNOWN_LEN as usize; + let total_buffer_len = + total_buffer_len.unwrap_or_else(|| buffers.iter().map(|b| b.len()).sum()); + Self::new_unchecked( + dtype, + views, + buffers, + validity, + total_bytes_len, + total_buffer_len, + ) + } + + pub fn data_buffers(&self) -> &Arc<[Buffer]> { + &self.buffers + } + + pub fn variadic_buffer_lengths(&self) -> Vec { + self.buffers.iter().map(|buf| buf.len() as i64).collect() + } + + pub fn views(&self) -> &Buffer { + &self.views + } + + pub fn into_views(self) -> Vec { + self.views.make_mut() + } + + pub fn into_inner( + self, + ) -> ( + Buffer, + Arc<[Buffer]>, + Option, + usize, + usize, + ) { + let views = self.views; + let buffers = self.buffers; + let validity = self.validity; + + ( + views, + buffers, + validity, + self.total_bytes_len.load(Ordering::Relaxed) as usize, + self.total_buffer_len, + ) + } + + /// Apply a function over the views. This can be used to update views in operations like slicing. + /// + /// # Safety + /// Update the views. All invariants of the views apply. + pub unsafe fn apply_views View>(&self, mut update_view: F) -> Self { + let arr = self.clone(); + let (views, buffers, validity, total_bytes_len, total_buffer_len) = arr.into_inner(); + + let mut views = views.make_mut(); + for v in views.iter_mut() { + let str_slice = T::from_bytes_unchecked(v.get_slice_unchecked(&buffers)); + *v = update_view(*v, str_slice); + } + Self::new_unchecked( + self.dtype.clone(), + views.into(), + buffers, + validity, + total_bytes_len, + total_buffer_len, + ) + } + + pub fn try_new( + dtype: ArrowDataType, + views: Buffer, + buffers: Arc<[Buffer]>, + validity: Option, + ) -> PolarsResult { + if T::IS_UTF8 { + validate_utf8_views(views.as_ref(), buffers.as_ref())?; + } else { + validate_binary_views(views.as_ref(), buffers.as_ref())?; + } + + if let Some(validity) = &validity { + polars_ensure!(validity.len()== views.len(), ComputeError: "validity mask length must match the number of values" ) + } + + unsafe { + Ok(Self::new_unchecked_unknown_md( + dtype, views, buffers, validity, None, + )) + } + } + + /// Creates an empty [`BinaryViewArrayGeneric`], i.e. whose `.len` is zero. + #[inline] + pub fn new_empty(dtype: ArrowDataType) -> Self { + unsafe { Self::new_unchecked(dtype, Buffer::new(), Arc::from([]), None, 0, 0) } + } + + /// Returns a new null [`BinaryViewArrayGeneric`] of `length`. + #[inline] + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let validity = Some(Bitmap::new_zeroed(length)); + unsafe { Self::new_unchecked(dtype, Buffer::zeroed(length), Arc::from([]), validity, 0, 0) } + } + + /// Returns the element at index `i` + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> &T { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &T { + let v = self.views.get_unchecked(i); + T::from_bytes_unchecked(v.get_slice_unchecked(&self.buffers)) + } + + /// Returns the element at index `i`, or None if it is null. + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&T> { + assert!(i < self.len()); + unsafe { self.get_unchecked(i) } + } + + /// Returns the element at index `i`, or None if it is null. + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn get_unchecked(&self, i: usize) -> Option<&T> { + if self + .validity + .as_ref() + .is_none_or(|v| v.get_bit_unchecked(i)) + { + let v = self.views.get_unchecked(i); + Some(T::from_bytes_unchecked( + v.get_slice_unchecked(&self.buffers), + )) + } else { + None + } + } + + /// Returns an iterator of `Option<&T>` over every element of this array. + pub fn iter(&self) -> ZipValidity<&T, BinaryViewValueIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity.as_ref()) + } + + /// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity + pub fn values_iter(&self) -> BinaryViewValueIter { + BinaryViewValueIter::new(self) + } + + pub fn len_iter(&self) -> impl Iterator + '_ { + self.views.iter().map(|v| v.length) + } + + /// Returns an iterator of the non-null values. + pub fn non_null_values_iter(&self) -> NonNullValuesIter<'_, BinaryViewArrayGeneric> { + NonNullValuesIter::new(self, self.validity()) + } + + /// Returns an iterator of the non-null values. + pub fn non_null_views_iter(&self) -> NonNullValuesIter<'_, Buffer> { + NonNullValuesIter::new(self.views(), self.validity()) + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + pub fn from_slice, P: AsRef<[Option]>>(slice: P) -> Self { + let mutable = MutableBinaryViewArray::from_iterator( + slice.as_ref().iter().map(|opt_v| opt_v.as_ref()), + ); + mutable.into() + } + + pub fn from_slice_values, P: AsRef<[S]>>(slice: P) -> Self { + let mutable = + MutableBinaryViewArray::from_values_iter(slice.as_ref().iter().map(|v| v.as_ref())); + mutable.into() + } + + /// Get the total length of bytes that it would take to concatenate all binary/str values in this array. + pub fn total_bytes_len(&self) -> usize { + let total = self.total_bytes_len.load(Ordering::Relaxed); + if total == UNKNOWN_LEN { + let total = self.len_iter().map(|v| v as usize).sum::(); + self.total_bytes_len.store(total as u64, Ordering::Relaxed); + total + } else { + total as usize + } + } + + /// Get the length of bytes that are stored in the variadic buffers. + pub fn total_buffer_len(&self) -> usize { + self.total_buffer_len + } + + fn total_unshared_buffer_len(&self) -> usize { + // XXX: it is O(n), not O(1). + // Given this function is only called in `maybe_gc()`, + // it may not be worthy to add an extra field for this. + self.buffers + .iter() + .map(|buf| { + if buf.storage_refcount() > 1 { + 0 + } else { + buf.len() + } + }) + .sum() + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.views.len() + } + + /// Garbage collect + pub fn gc(self) -> Self { + if self.buffers.is_empty() { + return self; + } + let mut mutable = MutableBinaryViewArray::with_capacity(self.len()); + let buffers = self.buffers.as_ref(); + + for view in self.views.as_ref() { + unsafe { mutable.push_view_unchecked(*view, buffers) } + } + mutable.freeze().with_validity(self.validity) + } + + pub fn deshare(&self) -> Self { + if Arc::strong_count(&self.buffers) == 1 + && self.buffers.iter().all(|b| b.storage_refcount() == 1) + { + return self.clone(); + } + self.clone().gc() + } + + pub fn is_sliced(&self) -> bool { + !std::ptr::eq(self.views.as_ptr(), self.views.storage_ptr()) + } + + pub fn maybe_gc(self) -> Self { + const GC_MINIMUM_SAVINGS: usize = 16 * 1024; // At least 16 KiB. + + if self.total_buffer_len <= GC_MINIMUM_SAVINGS { + return self; + } + + if Arc::strong_count(&self.buffers) != 1 { + // There are multiple holders of this `buffers`. + // If we allow gc in this case, + // it may end up copying the same content multiple times. + return self; + } + + // Subtract the maximum amount of inlined strings to get a lower bound + // on the number of buffer bytes needed (assuming no dedup). + let total_bytes_len = self.total_bytes_len(); + let buffer_req_lower_bound = total_bytes_len.saturating_sub(self.len() * 12); + + let lower_bound_mem_usage_post_gc = self.len() * 16 + buffer_req_lower_bound; + // Use unshared buffer len. Shared buffer won't be freed; no savings. + let cur_mem_usage = self.len() * 16 + self.total_unshared_buffer_len(); + let savings_upper_bound = cur_mem_usage.saturating_sub(lower_bound_mem_usage_post_gc); + + if savings_upper_bound >= GC_MINIMUM_SAVINGS + && cur_mem_usage >= 4 * lower_bound_mem_usage_post_gc + { + self.gc() + } else { + self + } + } + + pub fn make_mut(self) -> MutableBinaryViewArray { + let views = self.views.make_mut(); + let completed_buffers = self.buffers.to_vec(); + let validity = self.validity.map(|bitmap| bitmap.make_mut()); + + // We need to know the total_bytes_len if we are going to mutate it. + let mut total_bytes_len = self.total_bytes_len.load(Ordering::Relaxed); + if total_bytes_len == UNKNOWN_LEN { + total_bytes_len = views.iter().map(|view| view.length as u64).sum(); + } + let total_bytes_len = total_bytes_len as usize; + + MutableBinaryViewArray { + views, + completed_buffers, + in_progress_buffer: vec![], + validity, + phantom: Default::default(), + total_bytes_len, + total_buffer_len: self.total_buffer_len, + stolen_buffers: PlHashMap::new(), + } + } +} + +impl BinaryViewArray { + /// Validate the underlying bytes on UTF-8. + pub fn validate_utf8(&self) -> PolarsResult<()> { + // SAFETY: views are correct + unsafe { validate_views_utf8_only(&self.views, &self.buffers, 0) } + } + + /// Convert [`BinaryViewArray`] to [`Utf8ViewArray`]. + pub fn to_utf8view(&self) -> PolarsResult { + self.validate_utf8()?; + unsafe { Ok(self.to_utf8view_unchecked()) } + } + + /// Convert [`BinaryViewArray`] to [`Utf8ViewArray`] without checking UTF-8. + /// + /// # Safety + /// The caller must ensure the underlying data is valid UTF-8. + pub unsafe fn to_utf8view_unchecked(&self) -> Utf8ViewArray { + Utf8ViewArray::new_unchecked( + ArrowDataType::Utf8View, + self.views.clone(), + self.buffers.clone(), + self.validity.clone(), + self.total_bytes_len.load(Ordering::Relaxed) as usize, + self.total_buffer_len, + ) + } +} + +impl Utf8ViewArray { + pub fn to_binview(&self) -> BinaryViewArray { + // SAFETY: same invariants. + unsafe { + BinaryViewArray::new_unchecked( + ArrowDataType::BinaryView, + self.views.clone(), + self.buffers.clone(), + self.validity.clone(), + self.total_bytes_len.load(Ordering::Relaxed) as usize, + self.total_buffer_len, + ) + } + } +} + +impl Array for BinaryViewArrayGeneric { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + #[inline(always)] + fn len(&self) -> usize { + BinaryViewArrayGeneric::len(self) + } + + fn dtype(&self) -> &ArrowDataType { + T::dtype() + } + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + fn split_at_boxed(&self, offset: usize) -> (Box, Box) { + let (lhs, rhs) = Splitable::split_at(self, offset); + (Box::new(lhs), Box::new(rhs)) + } + + unsafe fn split_at_boxed_unchecked(&self, offset: usize) -> (Box, Box) { + let (lhs, rhs) = unsafe { Splitable::split_at_unchecked(self, offset) }; + (Box::new(lhs), Box::new(rhs)) + } + + fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.views.slice_unchecked(offset, length); + self.total_bytes_len.store(UNKNOWN_LEN, Ordering::Relaxed) + } + + fn with_validity(&self, validity: Option) -> Box { + debug_assert!( + validity.as_ref().is_none_or(|v| v.len() == self.len()), + "{} != {}", + validity.as_ref().unwrap().len(), + self.len() + ); + + let mut new = self.clone(); + new.validity = validity; + Box::new(new) + } + + fn to_boxed(&self) -> Box { + Box::new(self.clone()) + } +} + +impl Splitable for BinaryViewArrayGeneric { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_views, rhs_views) = unsafe { self.views.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + unsafe { + ( + Self::new_unchecked( + self.dtype.clone(), + lhs_views, + self.buffers.clone(), + lhs_validity, + if offset == 0 { 0 } else { UNKNOWN_LEN as _ }, + self.total_buffer_len(), + ), + Self::new_unchecked( + self.dtype.clone(), + rhs_views, + self.buffers.clone(), + rhs_validity, + if offset == self.len() { + 0 + } else { + UNKNOWN_LEN as _ + }, + self.total_buffer_len(), + ), + ) + } + } +} diff --git a/crates/polars-arrow/src/array/binview/mutable.rs b/crates/polars-arrow/src/array/binview/mutable.rs new file mode 100644 index 000000000000..85a6c5925b61 --- /dev/null +++ b/crates/polars-arrow/src/array/binview/mutable.rs @@ -0,0 +1,901 @@ +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::ops::Deref; +use std::sync::Arc; + +use hashbrown::hash_map::Entry; +use polars_error::PolarsResult; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; + +use crate::array::binview::iterator::MutableBinaryViewValueIter; +use crate::array::binview::view::validate_views_utf8_only; +use crate::array::binview::{ + BinaryViewArrayGeneric, DEFAULT_BLOCK_SIZE, MAX_EXP_BLOCK_SIZE, ViewType, +}; +use crate::array::{Array, MutableArray, TryExtend, TryPush, View}; +use crate::bitmap::MutableBitmap; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::legacy::trusted_len::TrustedLenPush; +use crate::trusted_len::TrustedLen; + +// Invariants: +// +// - Each view must point to a valid slice of a buffer +// - `total_buffer_len` must be equal to `completed_buffers.iter().map(Vec::len).sum()` +// - `total_bytes_len` must be equal to `views.iter().map(View::len).sum()` +pub struct MutableBinaryViewArray { + pub(crate) views: Vec, + pub(crate) completed_buffers: Vec>, + pub(crate) in_progress_buffer: Vec, + pub(crate) validity: Option, + pub(crate) phantom: std::marker::PhantomData, + /// Total bytes length if we would concatenate them all. + pub(crate) total_bytes_len: usize, + /// Total bytes in the buffer (excluding remaining capacity) + pub(crate) total_buffer_len: usize, + /// Mapping from `Buffer::deref()` to index in `completed_buffers`. + /// Used in `push_view()`. + pub(crate) stolen_buffers: PlHashMap, +} + +impl Clone for MutableBinaryViewArray { + fn clone(&self) -> Self { + Self { + views: self.views.clone(), + completed_buffers: self.completed_buffers.clone(), + in_progress_buffer: self.in_progress_buffer.clone(), + validity: self.validity.clone(), + phantom: Default::default(), + total_bytes_len: self.total_bytes_len, + total_buffer_len: self.total_buffer_len, + stolen_buffers: PlHashMap::new(), + } + } +} + +impl Debug for MutableBinaryViewArray { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "mutable-binview{:?}", T::DATA_TYPE) + } +} + +impl Default for MutableBinaryViewArray { + fn default() -> Self { + Self::with_capacity(0) + } +} + +impl From> for BinaryViewArrayGeneric { + fn from(mut value: MutableBinaryViewArray) -> Self { + value.finish_in_progress(); + unsafe { + Self::new_unchecked( + T::DATA_TYPE, + value.views.into(), + Arc::from(value.completed_buffers), + value.validity.map(|b| b.into()), + value.total_bytes_len, + value.total_buffer_len, + ) + } + } +} + +impl MutableBinaryViewArray { + pub fn new() -> Self { + Self::default() + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + views: Vec::with_capacity(capacity), + completed_buffers: vec![], + in_progress_buffer: vec![], + validity: None, + phantom: Default::default(), + total_buffer_len: 0, + total_bytes_len: 0, + stolen_buffers: PlHashMap::new(), + } + } + + /// Get a mutable reference to the [`Vec`] of [`View`]s in this [`MutableBinaryViewArray`]. + /// + /// # Safety + /// + /// This is safe as long as any mutation of the [`Vec`] does not break any invariants of the + /// [`MutableBinaryViewArray`] before it is read again. + #[inline] + pub unsafe fn views_mut(&mut self) -> &mut Vec { + &mut self.views + } + + /// Set the `total_bytes_len` of the [`MutableBinaryViewArray`] + /// + /// # Safety + /// + /// This should not break invariants of the [`MutableBinaryViewArray`] + #[inline] + pub unsafe fn set_total_bytes_len(&mut self, value: usize) { + #[cfg(debug_assertions)] + { + let actual_length: usize = self.views().iter().map(|v| v.length as usize).sum(); + assert_eq!(value, actual_length); + } + + self.total_bytes_len = value; + } + + pub fn total_bytes_len(&self) -> usize { + self.total_bytes_len + } + + pub fn total_buffer_len(&self) -> usize { + self.total_buffer_len + } + + #[inline] + pub fn views(&self) -> &[View] { + &self.views + } + + #[inline] + pub fn completed_buffers(&self) -> &[Buffer] { + &self.completed_buffers + } + + pub fn validity(&mut self) -> Option<&mut MutableBitmap> { + self.validity.as_mut() + } + + /// Reserves `additional` elements and `additional_buffer` on the buffer. + pub fn reserve(&mut self, additional: usize) { + self.views.reserve(additional); + } + + #[inline] + pub fn len(&self) -> usize { + self.views.len() + } + + #[inline] + pub fn capacity(&self) -> usize { + self.views.capacity() + } + + fn init_validity(&mut self, unset_last: bool) { + let mut validity = MutableBitmap::with_capacity(self.views.capacity()); + validity.extend_constant(self.len(), true); + if unset_last { + validity.set(self.len() - 1, false); + } + self.validity = Some(validity); + } + + /// # Safety + /// - caller must allocate enough capacity + /// - caller must ensure the view and buffers match. + /// - The array must not have validity. + pub(crate) unsafe fn push_view_unchecked(&mut self, v: View, buffers: &[Buffer]) { + let len = v.length; + self.total_bytes_len += len as usize; + if len <= 12 { + debug_assert!(self.views.capacity() > self.views.len()); + self.views.push_unchecked(v) + } else { + self.total_buffer_len += len as usize; + let data = buffers.get_unchecked(v.buffer_idx as usize); + let offset = v.offset as usize; + let bytes = data.get_unchecked(offset..offset + len as usize); + let t = T::from_bytes_unchecked(bytes); + self.push_value_ignore_validity(t) + } + } + + /// # Safety + /// - caller must allocate enough capacity + /// - caller must ensure the view and buffers match. + /// - The array must not have validity. + /// - caller must not mix use this function with other push functions. + pub unsafe fn push_view_unchecked_dedupe(&mut self, mut v: View, buffers: &[Buffer]) { + let len = v.length; + self.total_bytes_len += len as usize; + if len <= 12 { + self.views.push_unchecked(v); + } else { + let buffer = buffers.get_unchecked(v.buffer_idx as usize); + let idx = match self.stolen_buffers.entry(buffer.deref().as_ptr() as usize) { + Entry::Occupied(entry) => *entry.get(), + Entry::Vacant(entry) => { + let idx = self.completed_buffers.len() as u32; + entry.insert(idx); + self.completed_buffers.push(buffer.clone()); + self.total_buffer_len += buffer.len(); + idx + }, + }; + v.buffer_idx = idx; + self.views.push_unchecked(v); + } + } + + pub fn push_view(&mut self, mut v: View, buffers: &[Buffer]) { + let len = v.length; + self.total_bytes_len += len as usize; + if len <= 12 { + self.views.push(v); + } else { + // Do no mix use of push_view and push_value_ignore_validity - + // it causes fragmentation. + self.finish_in_progress(); + + let buffer = &buffers[v.buffer_idx as usize]; + let idx = match self.stolen_buffers.entry(buffer.deref().as_ptr() as usize) { + Entry::Occupied(entry) => { + let idx = *entry.get(); + let target_buffer = &self.completed_buffers[idx as usize]; + debug_assert_eq!(buffer, target_buffer); + idx + }, + Entry::Vacant(entry) => { + let idx = self.completed_buffers.len() as u32; + entry.insert(idx); + self.completed_buffers.push(buffer.clone()); + self.total_buffer_len += buffer.len(); + idx + }, + }; + v.buffer_idx = idx; + self.views.push(v); + } + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + #[inline] + pub fn push_value_ignore_validity>(&mut self, value: V) { + let bytes = value.as_ref().to_bytes(); + self.total_bytes_len += bytes.len(); + + // A string can only be maximum of 4GB in size. + let len = u32::try_from(bytes.len()).unwrap(); + + let view = if len <= View::MAX_INLINE_SIZE { + View::new_inline(bytes) + } else { + self.total_buffer_len += bytes.len(); + + // We want to make sure that we never have to memcopy between buffers. So if the + // current buffer is not large enough, create a new buffer that is large enough and try + // to anticipate the larger size. + let required_capacity = self.in_progress_buffer.len() + bytes.len(); + let does_not_fit_in_buffer = self.in_progress_buffer.capacity() < required_capacity; + + // We can only save offsets that are below u32::MAX + let offset_will_not_fit = self.in_progress_buffer.len() > u32::MAX as usize; + + if does_not_fit_in_buffer || offset_will_not_fit { + // Allocate a new buffer and flush the old buffer + let new_capacity = (self.in_progress_buffer.capacity() * 2) + .clamp(DEFAULT_BLOCK_SIZE, MAX_EXP_BLOCK_SIZE) + .max(bytes.len()); + let in_progress = Vec::with_capacity(new_capacity); + let flushed = std::mem::replace(&mut self.in_progress_buffer, in_progress); + if !flushed.is_empty() { + self.completed_buffers.push(flushed.into()) + } + } + + let offset = self.in_progress_buffer.len() as u32; + self.in_progress_buffer.extend_from_slice(bytes); + + let buffer_idx = u32::try_from(self.completed_buffers.len()).unwrap(); + + View::new_from_bytes(bytes, buffer_idx, offset) + }; + + self.views.push(view); + } + + #[inline] + pub fn push_buffer(&mut self, buffer: Buffer) -> u32 { + self.finish_in_progress(); + + let buffer_idx = self.completed_buffers.len(); + self.total_buffer_len += buffer.len(); + self.completed_buffers.push(buffer); + buffer_idx as u32 + } + + #[inline] + pub fn push_value>(&mut self, value: V) { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + self.push_value_ignore_validity(value) + } + + #[inline] + pub fn push>(&mut self, value: Option) { + if let Some(value) = value { + self.push_value(value) + } else { + self.push_null() + } + } + + #[inline] + pub fn push_null(&mut self) { + self.views.push(View::default()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(true), + } + } + + pub fn extend_null(&mut self, additional: usize) { + if self.validity.is_none() && additional > 0 { + self.init_validity(false); + } + self.views + .extend(std::iter::repeat_n(View::default(), additional)); + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, false); + } + } + + pub fn extend_constant>(&mut self, additional: usize, value: Option) { + if value.is_none() && self.validity.is_none() { + self.init_validity(false); + } + + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, value.is_some()) + } + + // Push and pop to get the properly encoded value. + // For long string this leads to a dictionary encoding, + // as we push the string only once in the buffers + let view_value = value + .map(|v| { + self.push_value_ignore_validity(v); + self.views.pop().unwrap() + }) + .unwrap_or_default(); + self.views + .extend(std::iter::repeat_n(view_value, additional)); + } + + impl_mutable_array_mut_validity!(); + + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + I: Iterator, + P: AsRef, + { + self.reserve(iterator.size_hint().0); + for v in iterator { + self.push_value(v) + } + } + + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + P: AsRef, + { + self.extend_values(iterator) + } + + #[inline] + pub fn extend(&mut self, iterator: I) + where + I: Iterator>, + P: AsRef, + { + self.reserve(iterator.size_hint().0); + for p in iterator { + self.push(p) + } + } + + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + I: TrustedLen>, + P: AsRef, + { + self.extend(iterator) + } + + #[inline] + pub fn extend_views(&mut self, iterator: I, buffers: &[Buffer]) + where + I: Iterator>, + { + self.reserve(iterator.size_hint().0); + for p in iterator { + match p { + Some(v) => self.push_view(v, buffers), + None => self.push_null(), + } + } + } + + #[inline] + pub fn extend_views_trusted_len(&mut self, iterator: I, buffers: &[Buffer]) + where + I: TrustedLen>, + { + self.extend_views(iterator, buffers); + } + + #[inline] + pub fn extend_non_null_views(&mut self, iterator: I, buffers: &[Buffer]) + where + I: Iterator, + { + self.reserve(iterator.size_hint().0); + for v in iterator { + self.push_view(v, buffers); + } + } + + #[inline] + pub fn extend_non_null_views_trusted_len(&mut self, iterator: I, buffers: &[Buffer]) + where + I: TrustedLen, + { + self.extend_non_null_views(iterator, buffers); + } + + /// # Safety + /// Same as `push_view_unchecked()`. + #[inline] + pub unsafe fn extend_non_null_views_unchecked(&mut self, iterator: I, buffers: &[Buffer]) + where + I: Iterator, + { + self.reserve(iterator.size_hint().0); + for v in iterator { + self.push_view_unchecked(v, buffers); + } + } + + /// # Safety + /// Same as `push_view_unchecked()`. + #[inline] + pub unsafe fn extend_non_null_views_unchecked_dedupe( + &mut self, + iterator: I, + buffers: &[Buffer], + ) where + I: Iterator, + { + self.reserve(iterator.size_hint().0); + for v in iterator { + self.push_view_unchecked_dedupe(v, buffers); + } + } + + #[inline] + pub fn from_iterator(iterator: I) -> Self + where + I: Iterator>, + P: AsRef, + { + let mut mutable = Self::with_capacity(iterator.size_hint().0); + mutable.extend(iterator); + mutable + } + + pub fn from_values_iter(iterator: I) -> Self + where + I: Iterator, + P: AsRef, + { + let mut mutable = Self::with_capacity(iterator.size_hint().0); + mutable.extend_values(iterator); + mutable + } + + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_iterator(slice.as_ref().iter().map(|opt_v| opt_v.as_ref())) + } + + pub fn finish_in_progress(&mut self) -> bool { + if !self.in_progress_buffer.is_empty() { + self.completed_buffers + .push(std::mem::take(&mut self.in_progress_buffer).into()); + true + } else { + false + } + } + + #[inline] + pub fn freeze(self) -> BinaryViewArrayGeneric { + self.into() + } + + #[inline] + pub fn freeze_with_dtype(self, dtype: ArrowDataType) -> BinaryViewArrayGeneric { + let mut arr: BinaryViewArrayGeneric = self.into(); + arr.dtype = dtype; + arr + } + + pub fn take(self) -> (Vec, Vec>) { + (self.views, self.completed_buffers) + } + + #[inline] + pub fn value(&self, i: usize) -> &T { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &T { + self.value_from_view_unchecked(self.views.get_unchecked(i)) + } + + /// Returns the element indicated by the given view. + /// + /// # Safety + /// Assumes the View belongs to this MutableBinaryViewArray. + pub unsafe fn value_from_view_unchecked<'a>(&'a self, view: &'a View) -> &'a T { + // View layout: + // length: 4 bytes + // prefix: 4 bytes + // buffer_index: 4 bytes + // offset: 4 bytes + + // Inlined layout: + // length: 4 bytes + // data: 12 bytes + let len = view.length; + let bytes = if len <= 12 { + let ptr = view as *const View as *const u8; + std::slice::from_raw_parts(ptr.add(4), len as usize) + } else { + let buffer_idx = view.buffer_idx as usize; + let offset = view.offset; + + let data = if buffer_idx == self.completed_buffers.len() { + self.in_progress_buffer.as_slice() + } else { + self.completed_buffers.get_unchecked(buffer_idx) + }; + + let offset = offset as usize; + data.get_unchecked(offset..offset + len as usize) + }; + T::from_bytes_unchecked(bytes) + } + + /// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity + pub fn values_iter(&self) -> MutableBinaryViewValueIter { + MutableBinaryViewValueIter::new(self) + } + + pub fn extend_from_array(&mut self, other: &BinaryViewArrayGeneric) { + let slf_len = self.len(); + match (&mut self.validity, other.validity()) { + (None, None) => {}, + (Some(v), None) => v.extend_constant(other.len(), true), + (v @ None, Some(other)) => { + let mut bm = MutableBitmap::with_capacity(slf_len + other.len()); + bm.extend_constant(slf_len, true); + bm.extend_from_bitmap(other); + *v = Some(bm); + }, + (Some(slf), Some(other)) => slf.extend_from_bitmap(other), + } + + if other.total_buffer_len() == 0 { + self.views.extend(other.views().iter().copied()); + } else { + self.finish_in_progress(); + + let buffer_offset = self.completed_buffers().len() as u32; + self.completed_buffers + .extend(other.data_buffers().iter().cloned()); + + self.views.extend(other.views().iter().map(|view| { + let mut view = *view; + if view.length > View::MAX_INLINE_SIZE { + view.buffer_idx += buffer_offset; + } + view + })); + + let new_total_buffer_len = self.total_buffer_len() + other.total_buffer_len(); + self.total_buffer_len = new_total_buffer_len; + } + + self.total_bytes_len = self.total_bytes_len() + other.total_bytes_len(); + } +} + +impl MutableBinaryViewArray<[u8]> { + pub fn validate_utf8(&mut self, buffer_offset: usize, views_offset: usize) -> PolarsResult<()> { + // Finish the in progress as it might be required for validation. + let pushed = self.finish_in_progress(); + // views are correct + unsafe { + validate_views_utf8_only( + &self.views[views_offset..], + &self.completed_buffers, + buffer_offset, + )? + } + // Restore in-progress buffer as we don't want to get too small buffers + if pushed { + if let Some(last) = self.completed_buffers.pop() { + self.in_progress_buffer = last.into_mut().right().unwrap(); + } + } + Ok(()) + } + + /// Extend from a `buffer` and `length` of items given some statistics about the lengths. + /// + /// This will attempt to dispatch to several optimized implementations. + /// + /// # Safety + /// + /// This is safe if the statistics are correct. + pub unsafe fn extend_from_lengths_with_stats( + &mut self, + buffer: &[u8], + lengths_iterator: impl Clone + ExactSizeIterator, + min_length: usize, + max_length: usize, + sum_length: usize, + ) { + let num_items = lengths_iterator.len(); + + if num_items == 0 { + return; + } + + #[cfg(debug_assertions)] + { + let (min, max, sum) = lengths_iterator.clone().map(|v| (v, v, v)).fold( + (usize::MAX, usize::MIN, 0usize), + |(cmin, cmax, csum), (emin, emax, esum)| { + (cmin.min(emin), cmax.max(emax), csum + esum) + }, + ); + + assert_eq!(min, min_length); + assert_eq!(max, max_length); + assert_eq!(sum, sum_length); + } + + assert!(sum_length <= buffer.len()); + + let mut buffer_offset = 0; + if min_length > View::MAX_INLINE_SIZE as usize + && (num_items == 1 || sum_length + self.in_progress_buffer.len() <= u32::MAX as usize) + { + let buffer_idx = self.completed_buffers().len() as u32; + let in_progress_buffer_offset = self.in_progress_buffer.len(); + + self.total_bytes_len += sum_length; + self.total_buffer_len += sum_length; + + self.in_progress_buffer + .extend_from_slice(&buffer[..sum_length]); + self.views.extend(lengths_iterator.map(|length| { + // SAFETY: We asserted before that the sum of all lengths is smaller or equal to + // the buffer length. + let view_buffer = + unsafe { buffer.get_unchecked(buffer_offset..buffer_offset + length) }; + + // SAFETY: We know that the minimum length > View::MAX_INLINE_SIZE. Therefore, this + // length is > View::MAX_INLINE_SIZE. + let view = unsafe { + View::new_noninline_unchecked( + view_buffer, + buffer_idx, + (buffer_offset + in_progress_buffer_offset) as u32, + ) + }; + buffer_offset += length; + view + })); + } else if max_length <= View::MAX_INLINE_SIZE as usize { + self.total_bytes_len += sum_length; + + // If the min and max are the same, we can dispatch to the optimized SIMD + // implementation. + if min_length == max_length { + let length = min_length; + if length == 0 { + self.views + .resize(self.views.len() + num_items, View::new_inline(&[])); + } else { + View::extend_with_inlinable_strided( + &mut self.views, + &buffer[..length * num_items], + length as u8, + ); + } + } else { + self.views.extend(lengths_iterator.map(|length| { + // SAFETY: We asserted before that the sum of all lengths is smaller or equal + // to the buffer length. + let view_buffer = + unsafe { buffer.get_unchecked(buffer_offset..buffer_offset + length) }; + + // SAFETY: We know that each view has a length <= View::MAX_INLINE_SIZE because + // the maximum length is <= View::MAX_INLINE_SIZE + let view = unsafe { View::new_inline_unchecked(view_buffer) }; + + buffer_offset += length; + + view + })); + } + } else { + // If all fails, just fall back to a base implementation. + self.reserve(num_items); + for length in lengths_iterator { + let value = &buffer[buffer_offset..buffer_offset + length]; + buffer_offset += length; + self.push_value(value); + } + } + } + + /// Extend from a `buffer` and `length` of items. + /// + /// This will attempt to dispatch to several optimized implementations. + #[inline] + pub fn extend_from_lengths( + &mut self, + buffer: &[u8], + lengths_iterator: impl Clone + ExactSizeIterator, + ) { + let (min, max, sum) = lengths_iterator.clone().map(|v| (v, v, v)).fold( + (usize::MAX, usize::MIN, 0usize), + |(cmin, cmax, csum), (emin, emax, esum)| (cmin.min(emin), cmax.max(emax), csum + esum), + ); + + // SAFETY: We just collected the right stats. + unsafe { self.extend_from_lengths_with_stats(buffer, lengths_iterator, min, max, sum) } + } +} + +impl> Extend> for MutableBinaryViewArray { + #[inline] + fn extend>>(&mut self, iter: I) { + Self::extend(self, iter.into_iter()) + } +} + +impl> FromIterator> for MutableBinaryViewArray { + #[inline] + fn from_iter>>(iter: I) -> Self { + Self::from_iterator(iter.into_iter()) + } +} + +impl MutableArray for MutableBinaryViewArray { + fn dtype(&self) -> &ArrowDataType { + T::dtype() + } + + fn len(&self) -> usize { + MutableBinaryViewArray::len(self) + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let mutable = std::mem::take(self); + let arr: BinaryViewArrayGeneric = mutable.into(); + arr.boxed() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn push_null(&mut self) { + MutableBinaryViewArray::push_null(self) + } + + fn reserve(&mut self, additional: usize) { + MutableBinaryViewArray::reserve(self, additional) + } + + fn shrink_to_fit(&mut self) { + self.views.shrink_to_fit() + } +} + +impl> TryExtend> for MutableBinaryViewArray { + /// This is infallible and is implemented for consistency with all other types + #[inline] + fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { + self.extend(iter.into_iter()); + Ok(()) + } +} + +impl> TryPush> for MutableBinaryViewArray { + /// This is infallible and is implemented for consistency with all other types + #[inline(always)] + fn try_push(&mut self, item: Option

) -> PolarsResult<()> { + self.push(item.as_ref().map(|p| p.as_ref())); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn roundtrip(values: &[&[u8]]) -> bool { + let buffer = values + .iter() + .flat_map(|v| v.iter().copied()) + .collect::>(); + let lengths = values.iter().map(|v| v.len()).collect::>(); + let mut bv = MutableBinaryViewArray::<[u8]>::with_capacity(values.len()); + + bv.extend_from_lengths(&buffer[..], lengths.into_iter()); + + &bv.values_iter().collect::>()[..] == values + } + + #[test] + fn extend_with_lengths_basic() { + assert!(roundtrip(&[])); + assert!(roundtrip(&[b"abc"])); + assert!(roundtrip(&[ + b"a_very_very_long_string_that_is_not_inlinable" + ])); + assert!(roundtrip(&[ + b"abc", + b"a_very_very_long_string_that_is_not_inlinable" + ])); + } + + #[test] + fn extend_with_inlinable_fastpath() { + assert!(roundtrip(&[b"abc", b"defg", b"hix"])); + assert!(roundtrip(&[b"abc", b"defg", b"hix", b"xyza1234abcd"])); + } + + #[test] + fn extend_with_inlinable_eq_len_fastpath() { + assert!(roundtrip(&[b"abc", b"def", b"hix"])); + assert!(roundtrip(&[b"abc", b"def", b"hix", b"xyz"])); + } + + #[test] + fn extend_with_not_inlinable_fastpath() { + assert!(roundtrip(&[ + b"a_very_long_string123", + b"a_longer_string_than_the_previous" + ])); + } +} diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs new file mode 100644 index 000000000000..1e9179fcc93e --- /dev/null +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -0,0 +1,495 @@ +use std::cmp::Ordering; +use std::fmt::{self, Display, Formatter}; + +use bytemuck::{Pod, Zeroable}; +use polars_error::*; +use polars_utils::min_max::MinMax; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{TotalEq, TotalOrd}; + +use crate::datatypes::PrimitiveType; +use crate::types::{Bytes16Alignment4, NativeType}; + +// We use this instead of u128 because we want alignment of <= 8 bytes. +/// A reference to a set of bytes. +/// +/// If `length <= 12`, these bytes are inlined over the `prefix`, `buffer_idx` and `offset` fields. +/// If `length > 12`, these fields specify a slice of a buffer. +#[derive(Copy, Clone, Default)] +#[repr(C)] +pub struct View { + /// The length of the string/bytes. + pub length: u32, + /// First 4 bytes of string/bytes data. + pub prefix: u32, + /// The buffer index. + pub buffer_idx: u32, + /// The offset into the buffer. + pub offset: u32, +} + +impl fmt::Debug for View { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.length <= Self::MAX_INLINE_SIZE { + fmt.debug_struct("View") + .field("length", &self.length) + .field("content", &unsafe { + std::slice::from_raw_parts( + (self as *const _ as *const u8).add(4), + self.length as usize, + ) + }) + .finish() + } else { + fmt.debug_struct("View") + .field("length", &self.length) + .field("prefix", &self.prefix.to_be_bytes()) + .field("buffer_idx", &self.buffer_idx) + .field("offset", &self.offset) + .finish() + } + } +} + +impl View { + pub const MAX_INLINE_SIZE: u32 = 12; + + #[inline(always)] + pub fn is_inline(&self) -> bool { + self.length <= Self::MAX_INLINE_SIZE + } + + #[inline(always)] + pub fn as_u128(self) -> u128 { + unsafe { std::mem::transmute(self) } + } + + /// Create a new inline view without verifying the length + /// + /// # Safety + /// + /// It needs to hold that `bytes.len() <= View::MAX_INLINE_SIZE`. + #[inline] + pub unsafe fn new_inline_unchecked(bytes: &[u8]) -> Self { + debug_assert!(bytes.len() <= u32::MAX as usize); + debug_assert!(bytes.len() as u32 <= Self::MAX_INLINE_SIZE); + + let mut view = Self { + length: bytes.len() as u32, + ..Default::default() + }; + + let view_ptr = &mut view as *mut _ as *mut u8; + + // SAFETY: + // - bytes length <= 12, + // - size_of:: == 16 + // - View is laid out as [length, prefix, buffer_idx, offset] (using repr(C)) + // - By grabbing the view_ptr and adding 4, we have provenance over prefix, buffer_idx and + // offset. (i.e. the same could not be achieved with &mut self.prefix as *mut _ as *mut u8) + unsafe { + let inline_data_ptr = view_ptr.add(4); + core::ptr::copy_nonoverlapping(bytes.as_ptr(), inline_data_ptr, bytes.len()); + } + view + } + + /// Create a new inline view + /// + /// # Panics + /// + /// Panics if the `bytes.len() > View::MAX_INLINE_SIZE`. + #[inline] + pub fn new_inline(bytes: &[u8]) -> Self { + assert!(bytes.len() as u32 <= Self::MAX_INLINE_SIZE); + unsafe { Self::new_inline_unchecked(bytes) } + } + + /// Create a new inline view + /// + /// # Safety + /// + /// It needs to hold that `bytes.len() > View::MAX_INLINE_SIZE`. + #[inline] + pub unsafe fn new_noninline_unchecked(bytes: &[u8], buffer_idx: u32, offset: u32) -> Self { + debug_assert!(bytes.len() <= u32::MAX as usize); + debug_assert!(bytes.len() as u32 > View::MAX_INLINE_SIZE); + + // SAFETY: The invariant of this function guarantees that this is safe. + let prefix = unsafe { u32::from_le_bytes(bytes[0..4].try_into().unwrap_unchecked()) }; + Self { + length: bytes.len() as u32, + prefix, + buffer_idx, + offset, + } + } + + #[inline] + pub fn new_from_bytes(bytes: &[u8], buffer_idx: u32, offset: u32) -> Self { + debug_assert!(bytes.len() <= u32::MAX as usize); + + // SAFETY: We verify the invariant with the outer if statement + unsafe { + if bytes.len() as u32 <= Self::MAX_INLINE_SIZE { + Self::new_inline_unchecked(bytes) + } else { + Self::new_noninline_unchecked(bytes, buffer_idx, offset) + } + } + } + + /// Constructs a byteslice from this view. + /// + /// # Safety + /// Assumes that this view is valid for the given buffers. + #[inline] + pub unsafe fn get_slice_unchecked<'a, B: AsRef<[u8]>>(&'a self, buffers: &'a [B]) -> &'a [u8] { + unsafe { + if self.length <= Self::MAX_INLINE_SIZE { + self.get_inlined_slice_unchecked() + } else { + self.get_external_slice_unchecked(buffers) + } + } + } + + /// Construct a byte slice from an inline view, if it is inline. + #[inline] + pub fn get_inlined_slice(&self) -> Option<&[u8]> { + if self.length <= Self::MAX_INLINE_SIZE { + unsafe { Some(self.get_inlined_slice_unchecked()) } + } else { + None + } + } + + /// Construct a byte slice from an inline view. + /// + /// # Safety + /// Assumes that this view is inlinable. + #[inline] + pub unsafe fn get_inlined_slice_unchecked(&self) -> &[u8] { + debug_assert!(self.length <= View::MAX_INLINE_SIZE); + let ptr = self as *const View as *const u8; + unsafe { std::slice::from_raw_parts(ptr.add(4), self.length as usize) } + } + + /// Construct a byte slice from an external view. + /// + /// # Safety + /// Assumes that this view is in the external buffers. + #[inline] + pub unsafe fn get_external_slice_unchecked<'a, B: AsRef<[u8]>>( + &self, + buffers: &'a [B], + ) -> &'a [u8] { + debug_assert!(self.length > View::MAX_INLINE_SIZE); + let data = buffers.get_unchecked(self.buffer_idx as usize); + let offset = self.offset as usize; + data.as_ref() + .get_unchecked(offset..offset + self.length as usize) + } + + /// Extend a `Vec` with inline views slices of `src` with `width`. + /// + /// This tries to use SIMD to optimize the copying and can be massively faster than doing a + /// `views.extend(src.chunks_exact(width).map(View::new_inline))`. + /// + /// # Panics + /// + /// This function panics if `src.len()` is not divisible by `width`, `width > + /// View::MAX_INLINE_SIZE` or `width == 0`. + pub fn extend_with_inlinable_strided(views: &mut Vec, src: &[u8], width: u8) { + macro_rules! dispatch { + ($n:ident = $match:ident in [$($v:literal),+ $(,)?] => $block:block, otherwise = $otherwise:expr) => { + match $match { + $( + $v => { + const $n: usize = $v; + + $block + } + )+ + _ => $otherwise, + } + } + } + + let width = width as usize; + + assert!(width > 0); + assert!(width <= View::MAX_INLINE_SIZE as usize); + + assert_eq!(src.len() % width, 0); + + let num_values = src.len() / width; + + views.reserve(num_values); + + #[allow(unused_mut)] + let mut src = src; + + dispatch! { + N = width in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] => { + #[cfg(feature = "simd")] + { + macro_rules! repeat_with { + ($i:ident = [$($v:literal),+ $(,)?] => $block:block) => { + $({ + const $i: usize = $v; + + $block + })+ + } + } + + use std::simd::*; + + // SAFETY: This is always allowed, since views.len() is always in the Vec + // buffer. + let mut dst = unsafe { views.as_mut_ptr().add(views.len()).cast::() }; + + let length_mask = u8x16::from_array([N as u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + const BLOCKS_PER_LOAD: usize = 16 / N; + const BYTES_PER_LOOP: usize = N * BLOCKS_PER_LOAD; + + let num_loops = (src.len() / BYTES_PER_LOOP).saturating_sub(1); + + for _ in 0..num_loops { + // SAFETY: The num_loops calculates how many times we can do this. + let loaded = u8x16::from_array(unsafe { + src.get_unchecked(..16).try_into().unwrap() + }); + src = unsafe { src.get_unchecked(BYTES_PER_LOOP..) }; + + // This way we can reuse the same load for multiple views. + repeat_with!( + I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] => { + if I < BLOCKS_PER_LOAD { + let zero = u8x16::default(); + const SWIZZLE: [usize; 16] = const { + let mut swizzle = [16usize; 16]; + + let mut i = 0; + while i < N { + let idx = i + I * N; + if idx < 16 { + swizzle[4+i] = idx; + } + i += 1; + } + + swizzle + }; + + let scattered = simd_swizzle!(loaded, zero, SWIZZLE); + let view_bytes = (scattered | length_mask).to_array(); + + // SAFETY: dst has the capacity reserved and view_bytes is 16 + // bytes long. + unsafe { + core::ptr::copy_nonoverlapping(view_bytes.as_ptr(), dst, 16); + dst = dst.add(16); + } + } + } + ); + } + + unsafe { + views.set_len(views.len() + num_loops * BLOCKS_PER_LOAD); + } + } + + views.extend(src.chunks_exact(N).map(|slice| unsafe { + View::new_inline_unchecked(slice) + })); + }, + otherwise = unreachable!() + } + } +} + +impl IsNull for View { + const HAS_NULLS: bool = false; + type Inner = Self; + + fn is_null(&self) -> bool { + false + } + + fn unwrap_inner(self) -> Self::Inner { + self + } +} + +impl Display for View { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +unsafe impl Zeroable for View {} + +unsafe impl Pod for View {} + +impl PartialEq for View { + fn eq(&self, other: &Self) -> bool { + self.as_u128() == other.as_u128() + } +} + +// These are 'implemented' because we want to implement NativeType +// for View, that should probably not be done. +impl TotalOrd for View { + fn tot_cmp(&self, _other: &Self) -> Ordering { + unimplemented!() + } +} + +impl TotalEq for View { + fn tot_eq(&self, other: &Self) -> bool { + self.eq(other) + } +} + +impl MinMax for View { + fn nan_min_lt(&self, _other: &Self) -> bool { + unimplemented!() + } + + fn nan_max_lt(&self, _other: &Self) -> bool { + unimplemented!() + } +} + +impl NativeType for View { + const PRIMITIVE: PrimitiveType = PrimitiveType::UInt128; + + type Bytes = [u8; 16]; + type AlignedBytes = Bytes16Alignment4; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + self.as_u128().to_le_bytes() + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + self.as_u128().to_be_bytes() + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self::from(u128::from_le_bytes(bytes)) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self::from(u128::from_be_bytes(bytes)) + } +} + +impl From for View { + #[inline] + fn from(value: u128) -> Self { + unsafe { std::mem::transmute(value) } + } +} + +impl From for u128 { + #[inline] + fn from(value: View) -> Self { + value.as_u128() + } +} + +pub fn validate_views, F>( + views: &[View], + buffers: &[B], + validate_bytes: F, +) -> PolarsResult<()> +where + F: Fn(&[u8]) -> PolarsResult<()>, +{ + for view in views { + if let Some(inline_slice) = view.get_inlined_slice() { + if view.length < View::MAX_INLINE_SIZE && view.as_u128() >> (32 + view.length * 8) != 0 + { + polars_bail!(ComputeError: "view contained non-zero padding in prefix"); + } + + validate_bytes(inline_slice)?; + } else { + let data = buffers.get(view.buffer_idx as usize).ok_or_else(|| { + polars_err!(OutOfBounds: "view index out of bounds\n\nGot: {} buffers and index: {}", buffers.len(), view.buffer_idx) + })?; + + let start = view.offset as usize; + let end = start + view.length as usize; + let b = data + .as_ref() + .get(start..end) + .ok_or_else(|| polars_err!(OutOfBounds: "buffer slice out of bounds"))?; + + polars_ensure!(b.starts_with(&view.prefix.to_le_bytes()), ComputeError: "prefix does not match string data"); + validate_bytes(b)?; + } + } + + Ok(()) +} + +pub fn validate_binary_views>(views: &[View], buffers: &[B]) -> PolarsResult<()> { + validate_views(views, buffers, |_| Ok(())) +} + +fn validate_utf8(b: &[u8]) -> PolarsResult<()> { + match simdutf8::basic::from_utf8(b) { + Ok(_) => Ok(()), + Err(_) => Err(polars_err!(ComputeError: "invalid utf8")), + } +} + +pub fn validate_utf8_views>(views: &[View], buffers: &[B]) -> PolarsResult<()> { + validate_views(views, buffers, validate_utf8) +} + +/// Checks the views for valid UTF-8. Assumes the first num_trusted_buffers are +/// valid UTF-8 without checking. +/// # Safety +/// The views and buffers must uphold the invariants of BinaryView otherwise we will go OOB. +pub unsafe fn validate_views_utf8_only>( + views: &[View], + buffers: &[B], + mut num_trusted_buffers: usize, +) -> PolarsResult<()> { + unsafe { + while num_trusted_buffers < buffers.len() + && buffers[num_trusted_buffers].as_ref().is_ascii() + { + num_trusted_buffers += 1; + } + + // Fast path if all buffers are ASCII (or there are no buffers). + if num_trusted_buffers >= buffers.len() { + for view in views { + if let Some(inlined_slice) = view.get_inlined_slice() { + validate_utf8(inlined_slice)?; + } + } + } else { + for view in views { + if view.length <= View::MAX_INLINE_SIZE + || view.buffer_idx as usize >= num_trusted_buffers + { + validate_utf8(view.get_slice_unchecked(buffers))?; + } + } + } + + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/boolean/builder.rs b/crates/polars-arrow/src/array/boolean/builder.rs new file mode 100644 index 000000000000..38919020b355 --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/builder.rs @@ -0,0 +1,102 @@ +use polars_utils::IdxSize; + +use super::BooleanArray; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::{BitmapBuilder, OptBitmapBuilder}; +use crate::datatypes::ArrowDataType; + +pub struct BooleanArrayBuilder { + dtype: ArrowDataType, + values: BitmapBuilder, + validity: OptBitmapBuilder, +} + +impl BooleanArrayBuilder { + pub fn new(dtype: ArrowDataType) -> Self { + Self { + dtype, + values: BitmapBuilder::new(), + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for BooleanArrayBuilder { + type Array = BooleanArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + self.validity.reserve(additional); + } + + fn freeze(self) -> BooleanArray { + let values = self.values.freeze(); + let validity = self.validity.into_opt_validity(); + BooleanArray::try_new(self.dtype, values, validity).unwrap() + } + + fn freeze_reset(&mut self) -> Self::Array { + let values = core::mem::take(&mut self.values).freeze(); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + BooleanArray::try_new(self.dtype.clone(), values, validity).unwrap() + } + + fn len(&self) -> usize { + self.values.len() + } + + fn extend_nulls(&mut self, length: usize) { + self.values.extend_constant(length, false); + self.validity.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &BooleanArray, + start: usize, + length: usize, + _share: ShareStrategy, + ) { + self.values + .subslice_extend_from_bitmap(other.values(), start, length); + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + unsafe fn gather_extend( + &mut self, + other: &BooleanArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + self.values.reserve(idxs.len()); + for idx in idxs { + self.values + .push_unchecked(other.value_unchecked(*idx as usize)); + } + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + } + + fn opt_gather_extend(&mut self, other: &BooleanArray, idxs: &[IdxSize], _share: ShareStrategy) { + self.values.reserve(idxs.len()); + unsafe { + for idx in idxs { + let val = if (*idx as usize) < other.len() { + // We don't use get here as that double-checks the validity + // which we don't care about here. + other.value_unchecked(*idx as usize) + } else { + false + }; + self.values.push_unchecked(val); + } + } + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } +} diff --git a/crates/polars-arrow/src/array/boolean/ffi.rs b/crates/polars-arrow/src/array/boolean/ffi.rs new file mode 100644 index 000000000000..dfaf3ac90571 --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/ffi.rs @@ -0,0 +1,55 @@ +use polars_error::PolarsResult; + +use super::BooleanArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::ffi; + +unsafe impl ToFfi for BooleanArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.as_ptr()), + ] + } + + fn offset(&self) -> Option { + let offset = self.values.offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for BooleanArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.bitmap(1) }?; + Self::try_new(dtype, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/boolean/fmt.rs b/crates/polars-arrow/src/array/boolean/fmt.rs new file mode 100644 index 000000000000..229a01cd3e03 --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/fmt.rs @@ -0,0 +1,17 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BooleanArray; + +pub fn write_value(array: &BooleanArray, index: usize, f: &mut W) -> Result { + write!(f, "{}", array.value(index)) +} + +impl Debug for BooleanArray { + fn fmt(&self, f: &mut Formatter) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + write!(f, "BooleanArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/boolean/from.rs b/crates/polars-arrow/src/array/boolean/from.rs new file mode 100644 index 000000000000..07553d78b737 --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/from.rs @@ -0,0 +1,13 @@ +use super::{BooleanArray, MutableBooleanArray}; + +impl]>> From

for BooleanArray { + fn from(slice: P) -> Self { + MutableBooleanArray::from(slice).into() + } +} + +impl>> FromIterator for BooleanArray { + fn from_iter>(iter: I) -> Self { + MutableBooleanArray::from_iter(iter).into() + } +} diff --git a/crates/polars-arrow/src/array/boolean/iterator.rs b/crates/polars-arrow/src/array/boolean/iterator.rs new file mode 100644 index 000000000000..5d4c39e1ec64 --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/iterator.rs @@ -0,0 +1,70 @@ +use super::super::MutableArray; +use super::{BooleanArray, MutableBooleanArray}; +use crate::array::ArrayAccessor; +use crate::bitmap::IntoIter; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +impl<'a> IntoIterator for &'a BooleanArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl IntoIterator for BooleanArray { + type Item = Option; + type IntoIter = ZipValidity; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + let (_, values, validity) = self.into_inner(); + let values = values.into_iter(); + let validity = + validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.into_iter())); + ZipValidity::new(values, validity) + } +} + +impl<'a> IntoIterator for &'a MutableBooleanArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MutableBooleanArray { + /// Returns an iterator over the optional values of this [`MutableBooleanArray`]. + #[inline] + pub fn iter(&'a self) -> ZipValidity, BitmapIter<'a>> { + ZipValidity::new( + self.values().iter(), + self.validity().as_ref().map(|x| x.iter()), + ) + } + + /// Returns an iterator over the values of this [`MutableBooleanArray`] + #[inline] + pub fn values_iter(&'a self) -> BitmapIter<'a> { + self.values().iter() + } +} + +unsafe impl<'a> ArrayAccessor<'a> for BooleanArray { + type Item = bool; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + (*self).value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + (*self).len() + } +} diff --git a/crates/polars-arrow/src/array/boolean/mod.rs b/crates/polars-arrow/src/array/boolean/mod.rs new file mode 100644 index 000000000000..47a651bbd9dd --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/mod.rs @@ -0,0 +1,439 @@ +use either::Either; +use polars_error::{PolarsResult, polars_bail}; + +use super::{Array, Splitable}; +use crate::array::iterator::NonNullValuesIter; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::compute::utils::{combine_validities_and, combine_validities_or}; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::trusted_len::TrustedLen; + +mod ffi; +pub(super) mod fmt; +mod from; +mod iterator; +mod mutable; +pub use mutable::*; +mod builder; +pub use builder::*; + +/// A [`BooleanArray`] is Arrow's semantically equivalent of an immutable `Vec>`. +/// It implements [`Array`]. +/// +/// One way to think about a [`BooleanArray`] is `(DataType, Arc>, Option>>)` +/// where: +/// * the first item is the array's logical type +/// * the second is the immutable values +/// * the third is the immutable validity (whether a value is null or not as a bitmap). +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. +/// # Example +/// ``` +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::bitmap::Bitmap; +/// use polars_arrow::buffer::Buffer; +/// +/// let array = BooleanArray::from([Some(true), None, Some(false)]); +/// assert_eq!(array.value(0), true); +/// assert_eq!(array.iter().collect::>(), vec![Some(true), None, Some(false)]); +/// assert_eq!(array.values_iter().collect::>(), vec![true, false, false]); +/// // the underlying representation +/// assert_eq!(array.values(), &Bitmap::from([true, false, false])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// +/// ``` +#[derive(Clone)] +pub struct BooleanArray { + dtype: ArrowDataType, + values: Bitmap, + validity: Option, +} + +impl BooleanArray { + /// The canonical method to create a [`BooleanArray`] out of low-end APIs. + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + pub fn try_new( + dtype: ArrowDataType, + values: Bitmap, + validity: Option, + ) -> PolarsResult { + if validity + .as_ref() + .is_some_and(|validity| validity.len() != values.len()) + { + polars_bail!(ComputeError: "validity mask length must match the number of values") + } + + if dtype.to_physical_type() != PhysicalType::Boolean { + polars_bail!(ComputeError: "BooleanArray can only be initialized with a DataType whose physical type is Boolean") + } + + Ok(Self { + dtype, + values, + validity, + }) + } + + /// Alias to `Self::try_new().unwrap()` + pub fn new(dtype: ArrowDataType, values: Bitmap, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() + } + + /// Returns an iterator over the optional values of this [`BooleanArray`]. + #[inline] + pub fn iter(&self) -> ZipValidity { + ZipValidity::new_with_validity(self.values().iter(), self.validity()) + } + + /// Returns an iterator over the values of this [`BooleanArray`]. + #[inline] + pub fn values_iter(&self) -> BitmapIter { + self.values().iter() + } + + /// Returns an iterator of the non-null values. + #[inline] + pub fn non_null_values_iter(&self) -> NonNullValuesIter<'_, BooleanArray> { + NonNullValuesIter::new(self, self.validity()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// The values [`Bitmap`]. + /// Values on null slots are undetermined (they can be anything). + #[inline] + pub fn values(&self) -> &Bitmap { + &self.values + } + + /// Returns the optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the arrays' [`ArrowDataType`]. + #[inline] + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + /// Returns the value at index `i` + /// # Panic + /// This function panics iff `i >= self.len()`. + #[inline] + pub fn value(&self, i: usize) -> bool { + self.values.get_bit(i) + } + + /// Returns the element at index `i` as bool + /// + /// # Safety + /// Caller must be sure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + self.values.get_bit_unchecked(i) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Slices this [`BooleanArray`]. + /// # Implementation + /// This operation is `O(1)` as it amounts to increase up to two ref counts. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`BooleanArray`]. + /// # Implementation + /// This operation is `O(1)` as it amounts to increase two ref counts. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values.slice_unchecked(offset, length); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns a clone of this [`BooleanArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(&self, values: Bitmap) -> Self { + let mut out = self.clone(); + out.set_values(values); + out + } + + /// Sets the values of this [`BooleanArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + pub fn set_values(&mut self, values: Bitmap) { + assert_eq!( + values.len(), + self.len(), + "values length must be equal to this arrays length" + ); + self.values = values; + } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics if the function modifies the length of the [`MutableBitmap`]. + pub fn apply_values_mut(&mut self, f: F) { + let values = std::mem::take(&mut self.values); + let mut values = values.make_mut(); + f(&mut values); + if let Some(validity) = &self.validity { + assert_eq!(validity.len(), values.len()); + } + self.values = values.into(); + } + + /// Try to convert this [`BooleanArray`] to a [`MutableBooleanArray`] + pub fn into_mut(self) -> Either { + use Either::*; + + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + Left(bitmap) => Left(BooleanArray::new(self.dtype, self.values, Some(bitmap))), + Right(mutable_bitmap) => match self.values.into_mut() { + Left(immutable) => Left(BooleanArray::new( + self.dtype, + immutable, + Some(mutable_bitmap.into()), + )), + Right(mutable) => Right( + MutableBooleanArray::try_new(self.dtype, mutable, Some(mutable_bitmap)) + .unwrap(), + ), + }, + } + } else { + match self.values.into_mut() { + Left(immutable) => Left(BooleanArray::new(self.dtype, immutable, None)), + Right(mutable) => { + Right(MutableBooleanArray::try_new(self.dtype, mutable, None).unwrap()) + }, + } + } + } + + /// Returns a new empty [`BooleanArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, Bitmap::new(), None) + } + + /// Returns a new [`BooleanArray`] whose all slots are null / `None`. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let bitmap = Bitmap::new_zeroed(length); + Self::new(dtype, bitmap.clone(), Some(bitmap)) + } + + /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. + #[inline] + pub fn from_trusted_len_values_iter>(iterator: I) -> Self { + MutableBooleanArray::from_trusted_len_values_iter(iterator).into() + } + + /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked>( + iterator: I, + ) -> Self { + MutableBooleanArray::from_trusted_len_values_iter_unchecked(iterator).into() + } + + /// Creates a new [`BooleanArray`] from a slice of `bool`. + #[inline] + pub fn from_slice>(slice: P) -> Self { + MutableBooleanArray::from_slice(slice).into() + } + + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + MutableBooleanArray::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + MutableBooleanArray::from_trusted_len_iter(iterator).into() + } + + /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: std::borrow::Borrow, + I: Iterator, E>>, + { + Ok(MutableBooleanArray::try_from_trusted_len_iter_unchecked(iterator)?.into()) + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + Ok(MutableBooleanArray::try_from_trusted_len_iter(iterator)?.into()) + } + + pub fn true_and_valid(&self) -> Bitmap { + match &self.validity { + None => self.values.clone(), + Some(validity) => combine_validities_and(Some(&self.values), Some(validity)).unwrap(), + } + } + + pub fn true_or_valid(&self) -> Bitmap { + match &self.validity { + None => self.values.clone(), + Some(validity) => combine_validities_or(Some(&self.values), Some(validity)).unwrap(), + } + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (ArrowDataType, Bitmap, Option) { + let Self { + dtype, + values, + validity, + } = self; + (dtype, values, validity) + } + + /// Creates a [`BooleanArray`] from its internal representation. + /// This is the inverted from [`BooleanArray::into_inner`] + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. + pub unsafe fn from_inner_unchecked( + dtype: ArrowDataType, + values: Bitmap, + validity: Option, + ) -> Self { + Self { + dtype, + values, + validity, + } + } +} + +impl Array for BooleanArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for BooleanArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_values, rhs_values) = unsafe { self.values.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + ( + Self { + dtype: self.dtype.clone(), + values: lhs_values, + validity: lhs_validity, + }, + Self { + dtype: self.dtype.clone(), + values: rhs_values, + validity: rhs_validity, + }, + ) + } +} + +impl From for BooleanArray { + fn from(values: Bitmap) -> Self { + Self { + dtype: ArrowDataType::Boolean, + values, + validity: None, + } + } +} diff --git a/crates/polars-arrow/src/array/boolean/mutable.rs b/crates/polars-arrow/src/array/boolean/mutable.rs new file mode 100644 index 000000000000..616eabe159a5 --- /dev/null +++ b/crates/polars-arrow/src/array/boolean/mutable.rs @@ -0,0 +1,605 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::BooleanArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::trusted_len::TrustedLen; + +/// The Arrow's equivalent to `Vec>`, but with `1/16` of its size. +/// Converting a [`MutableBooleanArray`] into a [`BooleanArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableBooleanArray { + dtype: ArrowDataType, + values: MutableBitmap, + validity: Option, +} + +impl From for BooleanArray { + fn from(other: MutableBooleanArray) -> Self { + BooleanArray::new( + other.dtype, + other.values.into(), + other.validity.map(|x| x.into()), + ) + } +} + +impl]>> From

for MutableBooleanArray { + /// Creates a new [`MutableBooleanArray`] out of a slice of Optional `bool`. + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } +} + +impl Default for MutableBooleanArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBooleanArray { + /// Creates an new empty [`MutableBooleanArray`]. + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// The canonical method to create a [`MutableBooleanArray`] out of low-end APIs. + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + pub fn try_new( + dtype: ArrowDataType, + values: MutableBitmap, + validity: Option, + ) -> PolarsResult { + if validity + .as_ref() + .is_some_and(|validity| validity.len() != values.len()) + { + polars_bail!(ComputeError: + "validity mask length must match the number of values", + ) + } + + if dtype.to_physical_type() != PhysicalType::Boolean { + polars_bail!( + oos = "MutableBooleanArray can only be initialized with a DataType whose physical type is Boolean", + ) + } + + Ok(Self { + dtype, + values, + validity, + }) + } + + /// Creates an new [`MutableBooleanArray`] with a capacity of values. + pub fn with_capacity(capacity: usize) -> Self { + Self { + dtype: ArrowDataType::Boolean, + values: MutableBitmap::with_capacity(capacity), + validity: None, + } + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + #[inline] + pub fn push_value(&mut self, value: bool) { + self.values.push(value); + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + #[inline] + pub fn push_null(&mut self) { + self.values.push(false); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + /// Pushes a new entry to [`MutableBooleanArray`]. + #[inline] + pub fn push(&mut self, value: Option) { + match value { + Some(value) => self.push_value(value), + None => self.push_null(), + } + } + + /// Pop an entry from [`MutableBooleanArray`]. + /// Note If the values is empty, this method will return None. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| value)) + .unwrap_or_else(|| Some(value)) + } + + /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len` which accepts in iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + { + // SAFETY: `I` is `TrustedLen` + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len_unchecked`, which accepts in iterator of optional values. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + I: Iterator, + { + let (_, upper) = iterator.size_hint(); + let additional = + upper.expect("extend_trusted_len_values_unchecked requires an upper limit"); + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + + self.values.extend_from_trusted_len_iter_unchecked(iterator) + } + + /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + // SAFETY: `I` is `TrustedLen` + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: Iterator>, + { + if let Some(validity) = self.validity.as_mut() { + extend_trusted_len_unzip(iterator, validity, &mut self.values); + } else { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + + extend_trusted_len_unzip(iterator, &mut validity, &mut self.values); + + if validity.unset_bits() > 0 { + self.validity = Some(validity); + } + } + } + + /// Extends `MutableBooleanArray` by additional values of constant value. + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: Option) { + match value { + Some(value) => { + self.values.extend_constant(additional, value); + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + }, + None => { + self.values.extend_constant(additional, false); + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, false) + } else { + self.init_validity(); + self.validity + .as_mut() + .unwrap() + .extend_constant(additional, false) + }; + }, + }; + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: BooleanArray = self.into(); + Arc::new(a) + } + + pub fn freeze(self) -> BooleanArray { + self.into() + } +} + +/// Getters +impl MutableBooleanArray { + /// Returns its values. + pub fn values(&self) -> &MutableBitmap { + &self.values + } +} + +/// Setters +impl MutableBooleanArray { + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Panic + /// Panics iff index is larger than `self.len()`. + pub fn set(&mut self, index: usize, value: Option) { + self.values.set(index, value.unwrap_or_default()); + + if value.is_none() && self.validity.is_none() { + // When the validity is None, all elements so far are valid. When one of the elements is set of null, + // the validity must be initialized. + self.validity = Some(MutableBitmap::from_trusted_len_iter(std::iter::repeat_n( + true, + self.len(), + ))); + } + if let Some(x) = self.validity.as_mut() { + x.set(index, value.is_some()) + } + } +} + +/// From implementations +impl MutableBooleanArray { + /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. + #[inline] + pub fn from_trusted_len_values_iter>(iterator: I) -> Self { + Self::try_new( + ArrowDataType::Boolean, + MutableBitmap::from_trusted_len_iter(iterator), + None, + ) + .unwrap() + } + + /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked>( + iterator: I, + ) -> Self { + let mut mutable = MutableBitmap::new(); + mutable.extend_from_trusted_len_iter_unchecked(iterator); + MutableBooleanArray::try_new(ArrowDataType::Boolean, mutable, None).unwrap() + } + + /// Creates a new [`MutableBooleanArray`] from a slice of `bool`. + #[inline] + pub fn from_slice>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter().copied()) + } + + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + let (validity, values) = trusted_len_unzip(iterator); + + Self::try_new(ArrowDataType::Boolean, values, validity).unwrap() + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + // SAFETY: `I` is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: std::borrow::Borrow, + I: Iterator, E>>, + { + let (validity, values) = try_trusted_len_unzip(iterator)?; + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + Ok(Self::try_new(ArrowDataType::Boolean, values, validity).unwrap()) + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + // SAFETY: `I` is `TrustedLen` + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Shrinks the capacity of the [`MutableBooleanArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +/// Creates a Bitmap and an optional [`MutableBitmap`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, MutableBitmap) +where + P: std::borrow::Borrow, + I: Iterator>, +{ + let mut validity = MutableBitmap::new(); + let mut values = MutableBitmap::new(); + + extend_trusted_len_unzip(iterator, &mut validity, &mut values); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + (validity, values) +} + +/// Extends validity [`MutableBitmap`] and values [`MutableBitmap`] from an iterator of `Option`. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn extend_trusted_len_unzip( + iterator: I, + validity: &mut MutableBitmap, + values: &mut MutableBitmap, +) where + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_trusted_len_unzip requires an upper limit"); + + // Length of the array before new values are pushed, + // variable created for assertion post operation + let pre_length = values.len(); + + validity.reserve(additional); + values.reserve(additional); + + for item in iterator { + let item = if let Some(item) = item { + validity.push_unchecked(true); + *item.borrow() + } else { + validity.push_unchecked(false); + bool::default() + }; + values.push_unchecked(item); + } + + debug_assert_eq!( + values.len(), + pre_length + additional, + "Trusted iterator length was not accurately reported" + ); +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(MutableBitmap, MutableBitmap), E> +where + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut values = MutableBitmap::with_capacity(len); + + for item in iterator { + let item = if let Some(item) = item? { + null.push(true); + *item.borrow() + } else { + null.push(false); + false + }; + values.push(item); + } + assert_eq!( + values.len(), + len, + "Trusted iterator length was not accurately reported" + ); + values.set_len(len); + null.set_len(len); + + Ok((null, values)) +} + +impl>> FromIterator for MutableBooleanArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: MutableBitmap = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + false + } + }) + .collect(); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + MutableBooleanArray::try_new(ArrowDataType::Boolean, values, validity).unwrap() + } +} + +impl MutableArray for MutableBooleanArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: BooleanArray = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: BooleanArray = std::mem::take(self).into(); + array.arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl Extend> for MutableBooleanArray { + fn extend>>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + iter.for_each(|x| self.push(x)) + } +} + +impl TryExtend> for MutableBooleanArray { + /// This is infalible and is implemented for consistency with all other types + fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { + self.extend(iter); + Ok(()) + } +} + +impl TryPush> for MutableBooleanArray { + /// This is infalible and is implemented for consistency with all other types + fn try_push(&mut self, item: Option) -> PolarsResult<()> { + self.push(item); + Ok(()) + } +} + +impl PartialEq for MutableBooleanArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableBooleanArray { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + // SAFETY: invariant offset + length <= slice.len() + unsafe { + self.values + .extend_from_slice_unchecked(slice, 0, other.values.len()); + } + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/builder.rs b/crates/polars-arrow/src/array/builder.rs new file mode 100644 index 000000000000..cefb1bab948e --- /dev/null +++ b/crates/polars-arrow/src/array/builder.rs @@ -0,0 +1,334 @@ +use polars_utils::IdxSize; + +use crate::array::binary::BinaryArrayBuilder; +use crate::array::binview::BinaryViewArrayGenericBuilder; +use crate::array::boolean::BooleanArrayBuilder; +use crate::array::fixed_size_binary::FixedSizeBinaryArrayBuilder; +use crate::array::fixed_size_list::FixedSizeListArrayBuilder; +use crate::array::list::ListArrayBuilder; +use crate::array::null::NullArrayBuilder; +use crate::array::struct_::StructArrayBuilder; +use crate::array::{Array, PrimitiveArrayBuilder}; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::with_match_primitive_type_full; + +/// Used for arrays which can share buffers with input arrays to appends, +/// gathers, etc. +#[derive(Copy, Clone, Debug)] +pub enum ShareStrategy { + Never, + Always, +} + +pub trait StaticArrayBuilder: Send { + type Array: Array; + + fn dtype(&self) -> &ArrowDataType; + fn reserve(&mut self, additional: usize); + + /// Consume this builder returning the built array. + fn freeze(self) -> Self::Array; + + /// Return the built array and reset to an empty state. + fn freeze_reset(&mut self) -> Self::Array; + + /// Returns the length of this builder (so far). + fn len(&self) -> usize; + + /// Extend this builder with the given number of null elements. + fn extend_nulls(&mut self, length: usize); + + /// Extends this builder with the contents of the given array. May panic if + /// other does not match the dtype of this array. + fn extend(&mut self, other: &Self::Array, share: ShareStrategy) { + self.subslice_extend(other, 0, other.len(), share); + } + + /// Extends this builder with the contents of the given array subslice. May + /// panic if other does not match the dtype of this array. + fn subslice_extend( + &mut self, + other: &Self::Array, + start: usize, + length: usize, + share: ShareStrategy, + ); + + /// The same as subslice_extend, but repeats the extension `repeats` times. + fn subslice_extend_repeated( + &mut self, + other: &Self::Array, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ) { + for _ in 0..repeats { + self.subslice_extend(other, start, length, share) + } + } + + /// Extends this builder with the contents of the given array at the given + /// indices. That is, `other[idxs[i]]` is appended to this array in order, + /// for each i=0..idxs.len(). May panic if other does not match the + /// dtype of this array. + /// + /// # Safety + /// The indices must be in-bounds. + unsafe fn gather_extend(&mut self, other: &Self::Array, idxs: &[IdxSize], share: ShareStrategy); + + /// Extends this builder with the contents of the given array at the given + /// indices. That is, `other[idxs[i]]` is appended to this array in order, + /// for each i=0..idxs.len(). May panic if other does not match the + /// dtype of this array. Out-of-bounds indices are mapped to nulls. + fn opt_gather_extend(&mut self, other: &Self::Array, idxs: &[IdxSize], share: ShareStrategy); +} + +impl ArrayBuilder for T { + #[inline(always)] + fn dtype(&self) -> &ArrowDataType { + StaticArrayBuilder::dtype(self) + } + + #[inline(always)] + fn reserve(&mut self, additional: usize) { + StaticArrayBuilder::reserve(self, additional) + } + + #[inline(always)] + fn freeze(self) -> Box { + Box::new(StaticArrayBuilder::freeze(self)) + } + + #[inline(always)] + fn freeze_reset(&mut self) -> Box { + Box::new(StaticArrayBuilder::freeze_reset(self)) + } + + #[inline(always)] + fn len(&self) -> usize { + StaticArrayBuilder::len(self) + } + + #[inline(always)] + fn extend_nulls(&mut self, length: usize) { + StaticArrayBuilder::extend_nulls(self, length); + } + + #[inline(always)] + fn subslice_extend( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + share: ShareStrategy, + ) { + let other: &T::Array = other.as_any().downcast_ref().unwrap(); + StaticArrayBuilder::subslice_extend(self, other, start, length, share); + } + + #[inline(always)] + fn subslice_extend_repeated( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ) { + let other: &T::Array = other.as_any().downcast_ref().unwrap(); + StaticArrayBuilder::subslice_extend_repeated(self, other, start, length, repeats, share); + } + + #[inline(always)] + unsafe fn gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], share: ShareStrategy) { + let other: &T::Array = other.as_any().downcast_ref().unwrap(); + StaticArrayBuilder::gather_extend(self, other, idxs, share); + } + + #[inline(always)] + fn opt_gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], share: ShareStrategy) { + let other: &T::Array = other.as_any().downcast_ref().unwrap(); + StaticArrayBuilder::opt_gather_extend(self, other, idxs, share); + } +} + +#[allow(private_bounds)] +pub trait ArrayBuilder: ArrayBuilderBoxedHelper + Send { + fn dtype(&self) -> &ArrowDataType; + fn reserve(&mut self, additional: usize); + + /// Consume this builder returning the built array. + fn freeze(self) -> Box; + + /// Return the built array and reset to an empty state. + fn freeze_reset(&mut self) -> Box; + + /// Returns the length of this builder (so far). + fn len(&self) -> usize; + + /// Extend this builder with the given number of null elements. + fn extend_nulls(&mut self, length: usize); + + /// Extends this builder with the contents of the given array. May panic if + /// other does not match the dtype of this array. + fn extend(&mut self, other: &dyn Array, share: ShareStrategy) { + self.subslice_extend(other, 0, other.len(), share); + } + + /// Extends this builder with the contents of the given array subslice. May + /// panic if other does not match the dtype of this array. + fn subslice_extend( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + share: ShareStrategy, + ); + + /// The same as subslice_extend, but repeats the extension `repeats` times. + fn subslice_extend_repeated( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ); + + /// Extends this builder with the contents of the given array at the given + /// indices. That is, `other[idxs[i]]` is appended to this array in order, + /// for each i=0..idxs.len(). May panic if other does not match the + /// dtype of this array. + /// + /// # Safety + /// The indices must be in-bounds. + unsafe fn gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], share: ShareStrategy); + + /// Extends this builder with the contents of the given array at the given + /// indices. That is, `other[idxs[i]]` is appended to this array in order, + /// for each i=0..idxs.len(). May panic if other does not match the + /// dtype of this array. Out-of-bounds indices are mapped to nulls. + fn opt_gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], share: ShareStrategy); +} + +/// A hack that lets us call the consuming `freeze` method on Box. +trait ArrayBuilderBoxedHelper { + fn freeze_boxed(self: Box) -> Box; +} + +impl ArrayBuilderBoxedHelper for T { + fn freeze_boxed(self: Box) -> Box { + self.freeze() + } +} + +impl ArrayBuilder for Box { + #[inline(always)] + fn dtype(&self) -> &ArrowDataType { + (**self).dtype() + } + + #[inline(always)] + fn reserve(&mut self, additional: usize) { + (**self).reserve(additional) + } + + #[inline(always)] + fn freeze(self) -> Box { + self.freeze_boxed() + } + + #[inline(always)] + fn freeze_reset(&mut self) -> Box { + (**self).freeze_reset() + } + + #[inline(always)] + fn len(&self) -> usize { + (**self).len() + } + + #[inline(always)] + fn extend_nulls(&mut self, length: usize) { + (**self).extend_nulls(length); + } + + #[inline(always)] + fn subslice_extend( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + share: ShareStrategy, + ) { + (**self).subslice_extend(other, start, length, share); + } + + #[inline(always)] + fn subslice_extend_repeated( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ) { + (**self).subslice_extend_repeated(other, start, length, repeats, share); + } + + #[inline(always)] + unsafe fn gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], share: ShareStrategy) { + (**self).gather_extend(other, idxs, share); + } + + #[inline(always)] + fn opt_gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], share: ShareStrategy) { + (**self).opt_gather_extend(other, idxs, share); + } +} + +/// Construct an ArrayBuilder for the given type. +pub fn make_builder(dtype: &ArrowDataType) -> Box { + use PhysicalType::*; + match dtype.to_physical_type() { + Null => Box::new(NullArrayBuilder::new(dtype.clone())), + Boolean => Box::new(BooleanArrayBuilder::new(dtype.clone())), + Primitive(prim_t) => with_match_primitive_type_full!(prim_t, |$T| { + Box::new(PrimitiveArrayBuilder::<$T>::new(dtype.clone())) + }), + LargeBinary => Box::new(BinaryArrayBuilder::::new(dtype.clone())), + FixedSizeBinary => Box::new(FixedSizeBinaryArrayBuilder::new(dtype.clone())), + LargeList => { + let ArrowDataType::LargeList(inner_dt) = dtype else { + unreachable!() + }; + Box::new(ListArrayBuilder::::new( + dtype.clone(), + make_builder(inner_dt.dtype()), + )) + }, + FixedSizeList => { + let ArrowDataType::FixedSizeList(inner_dt, _) = dtype else { + unreachable!() + }; + Box::new(FixedSizeListArrayBuilder::new( + dtype.clone(), + make_builder(inner_dt.dtype()), + )) + }, + Struct => { + let ArrowDataType::Struct(fields) = dtype else { + unreachable!() + }; + let builders = fields.iter().map(|f| make_builder(f.dtype())).collect(); + Box::new(StructArrayBuilder::new(dtype.clone(), builders)) + }, + BinaryView => Box::new(BinaryViewArrayGenericBuilder::<[u8]>::new(dtype.clone())), + Utf8View => Box::new(BinaryViewArrayGenericBuilder::::new(dtype.clone())), + + List | Binary | Utf8 | LargeUtf8 | Map | Union | Dictionary(_) => { + unimplemented!() + }, + } +} diff --git a/crates/polars-arrow/src/array/dictionary/ffi.rs b/crates/polars-arrow/src/array/dictionary/ffi.rs new file mode 100644 index 000000000000..241c559a9a74 --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/ffi.rs @@ -0,0 +1,42 @@ +use polars_error::{PolarsResult, polars_err}; + +use super::{DictionaryArray, DictionaryKey}; +use crate::array::{FromFfi, PrimitiveArray, ToFfi}; +use crate::ffi; + +unsafe impl ToFfi for DictionaryArray { + fn buffers(&self) -> Vec> { + self.keys.buffers() + } + + fn offset(&self) -> Option { + self.keys.offset() + } + + fn to_ffi_aligned(&self) -> Self { + Self { + dtype: self.dtype.clone(), + keys: self.keys.to_ffi_aligned(), + values: self.values.clone(), + } + } +} + +impl FromFfi for DictionaryArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + // keys: similar to PrimitiveArray, but the datatype is the inner one + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + let dtype = array.dtype().clone(); + + let keys = PrimitiveArray::::try_new(K::PRIMITIVE.into(), values, validity)?; + let values = array.dictionary()?.ok_or_else( + || polars_err!(ComputeError: "Dictionary Array must contain a dictionary in ffi"), + )?; + let values = ffi::try_from(values)?; + + // the assumption of this trait + DictionaryArray::::try_new_unchecked(dtype, keys, values) + } +} diff --git a/crates/polars-arrow/src/array/dictionary/fmt.rs b/crates/polars-arrow/src/array/dictionary/fmt.rs new file mode 100644 index 000000000000..b3ce55515902 --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/fmt.rs @@ -0,0 +1,31 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::{DictionaryArray, DictionaryKey}; +use crate::array::Array; + +pub fn write_value( + array: &DictionaryArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let keys = array.keys(); + let values = array.values(); + + if keys.is_valid(index) { + let key = array.key_value(index); + get_display(values.as_ref(), null)(f, key) + } else { + write!(f, "{null}") + } +} + +impl Debug for DictionaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "DictionaryArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/dictionary/iterator.rs b/crates/polars-arrow/src/array/dictionary/iterator.rs new file mode 100644 index 000000000000..af6ef539572d --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/iterator.rs @@ -0,0 +1,67 @@ +use super::{DictionaryArray, DictionaryKey}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::scalar::Scalar; +use crate::trusted_len::TrustedLen; + +/// Iterator of values of an `ListArray`. +pub struct DictionaryValuesIter<'a, K: DictionaryKey> { + array: &'a DictionaryArray, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey> DictionaryValuesIter<'a, K> { + #[inline] + pub fn new(array: &'a DictionaryArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl Iterator for DictionaryValuesIter<'_, K> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(self.array.value(old)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl TrustedLen for DictionaryValuesIter<'_, K> {} + +impl DoubleEndedIterator for DictionaryValuesIter<'_, K> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(self.array.value(self.end)) + } + } +} + +type ValuesIter<'a, K> = DictionaryValuesIter<'a, K>; +type ZipIter<'a, K> = ZipValidity, ValuesIter<'a, K>, BitmapIter<'a>>; + +impl<'a, K: DictionaryKey> IntoIterator for &'a DictionaryArray { + type Item = Option>; + type IntoIter = ZipIter<'a, K>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs new file mode 100644 index 000000000000..5b39a3b6aeb2 --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/mod.rs @@ -0,0 +1,448 @@ +use std::hash::Hash; +use std::hint::unreachable_unchecked; + +use crate::bitmap::Bitmap; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::datatypes::{ArrowDataType, IntegerType}; +use crate::scalar::{Scalar, new_scalar}; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +use crate::array::specification::check_indexes_unchecked; +mod typed_iterator; +mod value_map; + +pub use iterator::*; +pub use mutable::*; +use polars_error::{PolarsResult, polars_bail}; + +use super::primitive::PrimitiveArray; +use super::specification::check_indexes; +use super::{Array, Splitable, new_empty_array, new_null_array}; +use crate::array::dictionary::typed_iterator::{ + DictValue, DictionaryIterTyped, DictionaryValuesIterTyped, +}; + +/// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. +/// # Safety +/// +/// Any implementation of this trait must ensure that `always_fits_usize` only +/// returns `true` if all values succeeds on `value::try_into::().unwrap()`. +pub unsafe trait DictionaryKey: NativeType + TryInto + TryFrom + Hash { + /// The corresponding [`IntegerType`] of this key + const KEY_TYPE: IntegerType; + const MAX_USIZE_VALUE: usize; + + /// Represents this key as a `usize`. + /// + /// # Safety + /// The caller _must_ have checked that the value can be cast to `usize`. + #[inline] + unsafe fn as_usize(self) -> usize { + match self.try_into() { + Ok(v) => v, + Err(_) => unreachable_unchecked(), + } + } + + /// Create a key from a `usize` without checking bounds. + /// + /// # Safety + /// The caller _must_ have checked that the value can be created from a `usize`. + #[inline] + unsafe fn from_usize_unchecked(x: usize) -> Self { + debug_assert!(Self::try_from(x).is_ok()); + unsafe { Self::try_from(x).unwrap_unchecked() } + } + + /// If the key type always can be converted to `usize`. + fn always_fits_usize() -> bool { + false + } +} + +unsafe impl DictionaryKey for i8 { + const KEY_TYPE: IntegerType = IntegerType::Int8; + const MAX_USIZE_VALUE: usize = i8::MAX as usize; +} +unsafe impl DictionaryKey for i16 { + const KEY_TYPE: IntegerType = IntegerType::Int16; + const MAX_USIZE_VALUE: usize = i16::MAX as usize; +} +unsafe impl DictionaryKey for i32 { + const KEY_TYPE: IntegerType = IntegerType::Int32; + const MAX_USIZE_VALUE: usize = i32::MAX as usize; +} +unsafe impl DictionaryKey for i64 { + const KEY_TYPE: IntegerType = IntegerType::Int64; + const MAX_USIZE_VALUE: usize = i64::MAX as usize; +} +unsafe impl DictionaryKey for i128 { + const KEY_TYPE: IntegerType = IntegerType::Int128; + const MAX_USIZE_VALUE: usize = i128::MAX as usize; +} +unsafe impl DictionaryKey for u8 { + const KEY_TYPE: IntegerType = IntegerType::UInt8; + const MAX_USIZE_VALUE: usize = u8::MAX as usize; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u16 { + const KEY_TYPE: IntegerType = IntegerType::UInt16; + const MAX_USIZE_VALUE: usize = u16::MAX as usize; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u32 { + const KEY_TYPE: IntegerType = IntegerType::UInt32; + const MAX_USIZE_VALUE: usize = u32::MAX as usize; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u64 { + const KEY_TYPE: IntegerType = IntegerType::UInt64; + const MAX_USIZE_VALUE: usize = u64::MAX as usize; + + #[cfg(target_pointer_width = "64")] + fn always_fits_usize() -> bool { + true + } +} + +/// An [`Array`] whose values are stored as indices. This [`Array`] is useful when the cardinality of +/// values is low compared to the length of the [`Array`]. +/// +/// # Safety +/// This struct guarantees that each item of [`DictionaryArray::keys`] is castable to `usize` and +/// its value is smaller than [`DictionaryArray::values`]`.len()`. In other words, you can safely +/// use `unchecked` calls to retrieve the values +#[derive(Clone)] +pub struct DictionaryArray { + dtype: ArrowDataType, + keys: PrimitiveArray, + values: Box, +} + +fn check_dtype( + key_type: IntegerType, + dtype: &ArrowDataType, + values_dtype: &ArrowDataType, +) -> PolarsResult<()> { + if let ArrowDataType::Dictionary(key, value, _) = dtype.to_logical_type() { + if *key != key_type { + polars_bail!(ComputeError: "DictionaryArray must be initialized with a DataType::Dictionary whose integer is compatible to its keys") + } + if value.as_ref().to_logical_type() != values_dtype.to_logical_type() { + polars_bail!(ComputeError: "DictionaryArray must be initialized with a DataType::Dictionary whose value is equal to its values") + } + } else { + polars_bail!(ComputeError: "DictionaryArray must be initialized with logical DataType::Dictionary") + } + Ok(()) +} + +impl DictionaryArray { + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * the `dtype`'s logical type is not a `DictionaryArray` + /// * the `dtype`'s keys is not compatible with `keys` + /// * the `dtype`'s values's dtype is not equal with `values.dtype()` + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_new( + dtype: ArrowDataType, + keys: PrimitiveArray, + values: Box, + ) -> PolarsResult { + check_dtype(K::KEY_TYPE, &dtype, values.dtype())?; + + if keys.null_count() != keys.len() { + if K::always_fits_usize() { + // SAFETY: we just checked that conversion to `usize` always + // succeeds + unsafe { check_indexes_unchecked(keys.values(), values.len()) }?; + } else { + check_indexes(keys.values(), values.len())?; + } + } + + Ok(Self { + dtype, + keys, + values, + }) + } + + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_from_keys(keys: PrimitiveArray, values: Box) -> PolarsResult { + let dtype = Self::default_dtype(values.dtype().clone()); + Self::try_new(dtype, keys, values) + } + + /// Returns a new [`DictionaryArray`]. + /// # Errors + /// This function errors iff + /// * the `dtype`'s logical type is not a `DictionaryArray` + /// * the `dtype`'s keys is not compatible with `keys` + /// * the `dtype`'s values's dtype is not equal with `values.dtype()` + /// + /// # Safety + /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` + pub unsafe fn try_new_unchecked( + dtype: ArrowDataType, + keys: PrimitiveArray, + values: Box, + ) -> PolarsResult { + check_dtype(K::KEY_TYPE, &dtype, values.dtype())?; + + Ok(Self { + dtype, + keys, + values, + }) + } + + /// Returns a new empty [`DictionaryArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + let values = Self::try_get_child(&dtype).unwrap(); + let values = new_empty_array(values.clone()); + Self::try_new( + dtype, + PrimitiveArray::::new_empty(K::PRIMITIVE.into()), + values, + ) + .unwrap() + } + + /// Returns an [`DictionaryArray`] whose all elements are null + #[inline] + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let values = Self::try_get_child(&dtype).unwrap(); + let values = new_null_array(values.clone(), 1); + Self::try_new( + dtype, + PrimitiveArray::::new_null(K::PRIMITIVE.into(), length), + values, + ) + .unwrap() + } + + /// Returns an iterator of [`Option>`]. + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn iter(&self) -> ZipValidity, DictionaryValuesIter, BitmapIter> { + ZipValidity::new_with_validity(DictionaryValuesIter::new(self), self.keys.validity()) + } + + /// Returns an iterator of [`Box`] + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn values_iter(&self) -> DictionaryValuesIter { + DictionaryValuesIter::new(self) + } + + /// Returns an iterator over the values [`V::IterValue`]. + /// + /// # Panics + /// + /// Panics if the keys of this [`DictionaryArray`] has any nulls. + /// If they do [`DictionaryArray::iter_typed`] should be used. + pub fn values_iter_typed(&self) -> PolarsResult> { + let keys = &self.keys; + assert_eq!(keys.null_count(), 0); + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + Ok(DictionaryValuesIterTyped::new(keys, values)) + } + + /// Returns an iterator over the optional values of [`Option`]. + pub fn iter_typed(&self) -> PolarsResult> { + let keys = &self.keys; + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + Ok(DictionaryIterTyped::new(keys, values)) + } + + /// Returns the [`ArrowDataType`] of this [`DictionaryArray`] + #[inline] + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + /// Returns whether the values of this [`DictionaryArray`] are ordered + #[inline] + pub fn is_ordered(&self) -> bool { + match self.dtype.to_logical_type() { + ArrowDataType::Dictionary(_, _, is_ordered) => *is_ordered, + _ => unreachable!(), + } + } + + pub(crate) fn default_dtype(values_datatype: ArrowDataType) -> ArrowDataType { + ArrowDataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false) + } + + /// Slices this [`DictionaryArray`]. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + self.keys.slice(offset, length); + } + + /// Slices this [`DictionaryArray`]. + /// + /// # Safety + /// Safe iff `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.keys.slice_unchecked(offset, length); + } + + impl_sliced!(); + + /// Returns this [`DictionaryArray`] with a new validity. + /// # Panic + /// This function panics iff `validity.len() != self.len()`. + #[must_use] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of the keys of this [`DictionaryArray`]. + /// # Panics + /// This function panics iff `validity.len() != self.len()`. + pub fn set_validity(&mut self, validity: Option) { + self.keys.set_validity(validity); + } + + impl_into_array!(); + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.keys.len() + } + + /// The optional validity. Equivalent to `self.keys().validity()`. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.keys.validity() + } + + /// Returns the keys of the [`DictionaryArray`]. These keys can be used to fetch values + /// from `values`. + #[inline] + pub fn keys(&self) -> &PrimitiveArray { + &self.keys + } + + /// Returns an iterator of the keys' values of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_values_iter(&self) -> impl TrustedLen + Clone + '_ { + // SAFETY: invariant of the struct + self.keys.values_iter().map(|x| unsafe { x.as_usize() }) + } + + /// Returns an iterator of the keys' of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_iter(&self) -> impl TrustedLen> + Clone + '_ { + // SAFETY: invariant of the struct + self.keys.iter().map(|x| x.map(|x| unsafe { x.as_usize() })) + } + + /// Returns the keys' value of the [`DictionaryArray`] as `usize` + /// # Panics + /// This function panics iff `index >= self.len()` + #[inline] + pub fn key_value(&self, index: usize) -> usize { + // SAFETY: invariant of the struct + unsafe { self.keys.values()[index].as_usize() } + } + + /// Returns the values of the [`DictionaryArray`]. + #[inline] + pub fn values(&self) -> &Box { + &self.values + } + + /// Returns the value of the [`DictionaryArray`] at position `i`. + /// # Implementation + /// This function will allocate a new [`Scalar`] and is usually not performant. + /// Consider calling `keys` and `values`, downcasting `values`, and iterating over that. + /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn value(&self, index: usize) -> Box { + // SAFETY: invariant of this struct + let index = unsafe { self.keys.value(index).as_usize() }; + new_scalar(self.values.as_ref(), index) + } + + pub(crate) fn try_get_child(dtype: &ArrowDataType) -> PolarsResult<&ArrowDataType> { + Ok(match dtype.to_logical_type() { + ArrowDataType::Dictionary(_, values, _) => values.as_ref(), + _ => { + polars_bail!(ComputeError: "Dictionaries must be initialized with DataType::Dictionary") + }, + }) + } + + pub fn take(self) -> (ArrowDataType, PrimitiveArray, Box) { + (self.dtype, self.keys, self.values) + } +} + +impl Array for DictionaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.keys.validity() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for DictionaryArray { + fn check_bound(&self, offset: usize) -> bool { + offset < self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_keys, rhs_keys) = unsafe { Splitable::split_at_unchecked(&self.keys, offset) }; + + ( + Self { + dtype: self.dtype.clone(), + keys: lhs_keys, + values: self.values.clone(), + }, + Self { + dtype: self.dtype.clone(), + keys: rhs_keys, + values: self.values.clone(), + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/dictionary/mutable.rs b/crates/polars-arrow/src/array/dictionary/mutable.rs new file mode 100644 index 000000000000..1d01b1e719f9 --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/mutable.rs @@ -0,0 +1,227 @@ +use std::hash::Hash; +use std::sync::Arc; + +use polars_error::PolarsResult; + +use super::value_map::ValueMap; +use super::{DictionaryArray, DictionaryKey}; +use crate::array::indexable::{AsIndexed, Indexable}; +use crate::array::primitive::MutablePrimitiveArray; +use crate::array::{Array, MutableArray, TryExtend, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::ArrowDataType; + +#[derive(Debug)] +pub struct MutableDictionaryArray { + dtype: ArrowDataType, + map: ValueMap, + // invariant: `max(keys) < map.values().len()` + keys: MutablePrimitiveArray, +} + +impl From> for DictionaryArray { + fn from(other: MutableDictionaryArray) -> Self { + // SAFETY: the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new_unchecked( + other.dtype, + other.keys.into(), + other.map.into_values().as_box(), + ) + .unwrap() + } + } +} + +impl MutableDictionaryArray { + /// Creates an empty [`MutableDictionaryArray`]. + pub fn new() -> Self { + Self::try_empty(M::default()).unwrap() + } +} + +impl Default for MutableDictionaryArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableDictionaryArray { + /// Creates an empty [`MutableDictionaryArray`] from a given empty values array. + /// # Errors + /// Errors if the array is non-empty. + pub fn try_empty(values: M) -> PolarsResult { + Ok(Self::from_value_map(ValueMap::::try_empty(values)?)) + } + + /// Creates an empty [`MutableDictionaryArray`] preloaded with a given dictionary of values. + /// Indices associated with those values are automatically assigned based on the order of + /// the values. + /// # Errors + /// Errors if there's more values than the maximum value of `K` or if values are not unique. + pub fn from_values(values: M) -> PolarsResult + where + M: Indexable, + M::Type: Eq + Hash, + { + Ok(Self::from_value_map(ValueMap::::from_values(values)?)) + } + + fn from_value_map(value_map: ValueMap) -> Self { + let keys = MutablePrimitiveArray::::new(); + let dtype = + ArrowDataType::Dictionary(K::KEY_TYPE, Box::new(value_map.dtype().clone()), false); + Self { + dtype, + map: value_map, + keys, + } + } + + /// Creates an empty [`MutableDictionaryArray`] retaining the same dictionary as the current + /// mutable dictionary array, but with no data. This may come useful when serializing the + /// array into multiple chunks, where there's a requirement that the dictionary is the same. + /// No copying is performed, the value map is moved over to the new array. + pub fn into_empty(self) -> Self { + Self::from_value_map(self.map) + } + + /// Same as `into_empty` but clones the inner value map instead of taking full ownership. + pub fn to_empty(&self) -> Self + where + M: Clone, + { + Self::from_value_map(self.map.clone()) + } + + /// pushes a null value + pub fn push_null(&mut self) { + self.keys.push(None) + } + + /// returns a reference to the inner values. + pub fn values(&self) -> &M { + self.map.values() + } + + /// converts itself into [`Arc`] + pub fn into_arc(self) -> Arc { + let a: DictionaryArray = self.into(); + Arc::new(a) + } + + /// converts itself into [`Box`] + pub fn into_box(self) -> Box { + let a: DictionaryArray = self.into(); + Box::new(a) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.keys.reserve(additional); + } + + /// Shrinks the capacity of the [`MutableDictionaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.map.shrink_to_fit(); + self.keys.shrink_to_fit(); + } + + /// Returns the dictionary keys + pub fn keys(&self) -> &MutablePrimitiveArray { + &self.keys + } + + fn take_into(&mut self) -> DictionaryArray { + DictionaryArray::::try_new( + self.dtype.clone(), + std::mem::take(&mut self.keys).into(), + self.map.take_into(), + ) + .unwrap() + } +} + +impl MutableArray for MutableDictionaryArray { + fn len(&self) -> usize { + self.keys.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.keys.validity() + } + + fn as_box(&mut self) -> Box { + Box::new(self.take_into()) + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.take_into()) + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.keys.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl TryExtend> for MutableDictionaryArray +where + K: DictionaryKey, + M: MutableArray + Indexable + TryExtend> + TryPush>, + T: AsIndexed, + M::Type: Eq + Hash, +{ + fn try_extend>>(&mut self, iter: II) -> PolarsResult<()> { + for value in iter { + if let Some(value) = value { + let key = self + .map + .try_push_valid(value, |arr, v| arr.try_push(Some(v)))?; + self.keys.try_push(Some(key))?; + } else { + self.push_null(); + } + } + Ok(()) + } +} + +impl TryPush> for MutableDictionaryArray +where + K: DictionaryKey, + M: MutableArray + Indexable + TryPush>, + T: AsIndexed, + M::Type: Eq + Hash, +{ + fn try_push(&mut self, item: Option) -> PolarsResult<()> { + if let Some(value) = item { + let key = self + .map + .try_push_valid(value, |arr, v| arr.try_push(Some(v)))?; + self.keys.try_push(Some(key))?; + } else { + self.push_null(); + } + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs new file mode 100644 index 000000000000..fe8ea38372f8 --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs @@ -0,0 +1,202 @@ +use polars_error::{PolarsResult, polars_err}; + +use super::DictionaryKey; +use crate::array::{Array, PrimitiveArray, StaticArray, Utf8Array, Utf8ViewArray}; +use crate::trusted_len::TrustedLen; +use crate::types::Offset; + +pub trait DictValue { + type IterValue<'this> + where + Self: 'this; + + /// # Safety + /// Will not do any bound checks but must check validity. + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_>; + + /// Take a [`dyn Array`] an try to downcast it to the type of `DictValue`. + fn downcast_values(array: &dyn Array) -> PolarsResult<&Self> + where + Self: Sized; +} + +impl DictValue for Utf8Array { + type IterValue<'a> = &'a str; + + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> { + self.value_unchecked(item) + } + + fn downcast_values(array: &dyn Array) -> PolarsResult<&Self> + where + Self: Sized, + { + array + .as_any() + .downcast_ref::() + .ok_or_else( + || polars_err!(InvalidOperation: "could not convert array to dictionary value"), + ) + .inspect(|arr| { + assert_eq!( + arr.null_count(), + 0, + "null values in values not supported in iteration" + ); + }) + } +} + +impl DictValue for Utf8ViewArray { + type IterValue<'a> = &'a str; + + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> { + self.value_unchecked(item) + } + + fn downcast_values(array: &dyn Array) -> PolarsResult<&Self> + where + Self: Sized, + { + array + .as_any() + .downcast_ref::() + .ok_or_else( + || polars_err!(InvalidOperation: "could not convert array to dictionary value"), + ) + .inspect(|arr| { + assert_eq!( + arr.null_count(), + 0, + "null values in values not supported in iteration" + ); + }) + } +} + +/// Iterator of values of an `ListArray`. +pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> { + keys: &'a PrimitiveArray, + values: &'a V, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey, V: DictValue> DictionaryValuesIterTyped<'a, K, V> { + pub(super) fn new(keys: &'a PrimitiveArray, values: &'a V) -> Self { + assert_eq!(keys.null_count(), 0); + Self { + keys, + values, + index: 0, + end: keys.len(), + } + } +} + +impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped<'a, K, V> { + type Item = V::IterValue<'a>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + unsafe { + let key = self.keys.value_unchecked(old); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl TrustedLen for DictionaryValuesIterTyped<'_, K, V> {} + +impl DoubleEndedIterator for DictionaryValuesIterTyped<'_, K, V> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + unsafe { + let key = self.keys.value_unchecked(self.end); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + } +} + +pub struct DictionaryIterTyped<'a, K: DictionaryKey, V: DictValue> { + keys: &'a PrimitiveArray, + values: &'a V, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey, V: DictValue> DictionaryIterTyped<'a, K, V> { + pub(super) fn new(keys: &'a PrimitiveArray, values: &'a V) -> Self { + Self { + keys, + values, + index: 0, + end: keys.len(), + } + } +} + +impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryIterTyped<'a, K, V> { + type Item = Option>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + unsafe { + if let Some(key) = self.keys.get_unchecked(old) { + let idx = key.as_usize(); + Some(Some(self.values.get_unchecked(idx))) + } else { + Some(None) + } + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl TrustedLen for DictionaryIterTyped<'_, K, V> {} + +impl ExactSizeIterator for DictionaryIterTyped<'_, K, V> {} +impl DoubleEndedIterator for DictionaryIterTyped<'_, K, V> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + unsafe { + if let Some(key) = self.keys.get_unchecked(self.end) { + let idx = key.as_usize(); + Some(Some(self.values.get_unchecked(idx))) + } else { + Some(None) + } + } + } + } +} diff --git a/crates/polars-arrow/src/array/dictionary/value_map.rs b/crates/polars-arrow/src/array/dictionary/value_map.rs new file mode 100644 index 000000000000..b139268d7f5b --- /dev/null +++ b/crates/polars-arrow/src/array/dictionary/value_map.rs @@ -0,0 +1,135 @@ +use std::borrow::Borrow; +use std::fmt::{self, Debug}; +use std::hash::{BuildHasher, Hash}; + +use hashbrown::HashTable; +use hashbrown::hash_table::Entry; +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::aliases::PlRandomState; + +use super::DictionaryKey; +use crate::array::indexable::{AsIndexed, Indexable}; +use crate::array::{Array, MutableArray}; +use crate::datatypes::ArrowDataType; + +#[derive(Clone)] +pub struct ValueMap { + pub values: M, + pub map: HashTable<(u64, K)>, + random_state: PlRandomState, +} + +impl ValueMap { + pub fn try_empty(values: M) -> PolarsResult { + if !values.is_empty() { + polars_bail!(ComputeError: "initializing value map with non-empty values array") + } + Ok(Self { + values, + map: HashTable::default(), + random_state: PlRandomState::default(), + }) + } + + pub fn from_values(values: M) -> PolarsResult + where + M: Indexable, + M::Type: Eq + Hash, + { + let mut map: HashTable<(u64, K)> = HashTable::with_capacity(values.len()); + let random_state = PlRandomState::default(); + for index in 0..values.len() { + let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?; + // SAFETY: we only iterate within bounds + let value = unsafe { values.value_unchecked_at(index) }; + let hash = random_state.hash_one(value.borrow()); + + let entry = map.entry( + hash, + |(_h, key)| { + // SAFETY: invariant of the struct, it's always in bounds. + let stored_value = unsafe { values.value_unchecked_at(key.as_usize()) }; + stored_value.borrow() == value.borrow() + }, + |(h, _key)| *h, + ); + match entry { + Entry::Occupied(_) => { + polars_bail!(InvalidOperation: "duplicate value in dictionary values array") + }, + Entry::Vacant(entry) => { + entry.insert((hash, key)); + }, + } + } + Ok(Self { + values, + map, + random_state, + }) + } + + pub fn dtype(&self) -> &ArrowDataType { + self.values.dtype() + } + + pub fn into_values(self) -> M { + self.values + } + + pub fn take_into(&mut self) -> Box { + let arr = self.values.as_box(); + self.map.clear(); + arr + } + + #[inline] + pub fn values(&self) -> &M { + &self.values + } + + /// Try to insert a value and return its index (it may or may not get inserted). + pub fn try_push_valid( + &mut self, + value: V, + mut push: impl FnMut(&mut M, V) -> PolarsResult<()>, + ) -> PolarsResult + where + M: Indexable, + V: AsIndexed, + M::Type: Eq + Hash, + { + let hash = self.random_state.hash_one(value.as_indexed()); + let entry = self.map.entry( + hash, + |(_h, key)| { + // SAFETY: invariant of the struct, it's always in bounds. + let stored_value = unsafe { self.values.value_unchecked_at(key.as_usize()) }; + stored_value.borrow() == value.as_indexed() + }, + |(h, _key)| *h, + ); + let out = match entry { + Entry::Occupied(entry) => entry.get().1, + Entry::Vacant(entry) => { + let index = self.values.len(); + let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?; + entry.insert((hash, key)); + push(&mut self.values, value)?; + debug_assert_eq!(self.values.len(), index + 1); + key + }, + }; + Ok(out) + } + + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + } +} + +impl Debug for ValueMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.values.fmt(f) + } +} diff --git a/crates/polars-arrow/src/array/equal/binary.rs b/crates/polars-arrow/src/array/equal/binary.rs new file mode 100644 index 000000000000..93145aa461e2 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/binary.rs @@ -0,0 +1,6 @@ +use crate::array::BinaryArray; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &BinaryArray, rhs: &BinaryArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/binary_view.rs b/crates/polars-arrow/src/array/equal/binary_view.rs new file mode 100644 index 000000000000..8200155b519b --- /dev/null +++ b/crates/polars-arrow/src/array/equal/binary_view.rs @@ -0,0 +1,9 @@ +use crate::array::Array; +use crate::array::binview::{BinaryViewArrayGeneric, ViewType}; + +pub(super) fn equal( + lhs: &BinaryViewArrayGeneric, + rhs: &BinaryViewArrayGeneric, +) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/boolean.rs b/crates/polars-arrow/src/array/equal/boolean.rs new file mode 100644 index 000000000000..d9c6af9b0276 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/boolean.rs @@ -0,0 +1,5 @@ +use crate::array::BooleanArray; + +pub(super) fn equal(lhs: &BooleanArray, rhs: &BooleanArray) -> bool { + lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/dictionary.rs b/crates/polars-arrow/src/array/equal/dictionary.rs new file mode 100644 index 000000000000..88213cbc059a --- /dev/null +++ b/crates/polars-arrow/src/array/equal/dictionary.rs @@ -0,0 +1,14 @@ +use crate::array::{DictionaryArray, DictionaryKey}; + +pub(super) fn equal(lhs: &DictionaryArray, rhs: &DictionaryArray) -> bool { + if !(lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len()) { + return false; + }; + + // if x is not valid and y is but its child is not, the slots are equal. + lhs.iter().zip(rhs.iter()).all(|(x, y)| match (&x, &y) { + (None, Some(y)) => !y.is_valid(), + (Some(x), None) => !x.is_valid(), + _ => x == y, + }) +} diff --git a/crates/polars-arrow/src/array/equal/fixed_size_binary.rs b/crates/polars-arrow/src/array/equal/fixed_size_binary.rs new file mode 100644 index 000000000000..0e956e872090 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/fixed_size_binary.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, FixedSizeBinaryArray}; + +pub(super) fn equal(lhs: &FixedSizeBinaryArray, rhs: &FixedSizeBinaryArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/fixed_size_list.rs b/crates/polars-arrow/src/array/equal/fixed_size_list.rs new file mode 100644 index 000000000000..26582aa05379 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/fixed_size_list.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, FixedSizeListArray}; + +pub(super) fn equal(lhs: &FixedSizeListArray, rhs: &FixedSizeListArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/list.rs b/crates/polars-arrow/src/array/equal/list.rs new file mode 100644 index 000000000000..5c08e2103dcb --- /dev/null +++ b/crates/polars-arrow/src/array/equal/list.rs @@ -0,0 +1,6 @@ +use crate::array::{Array, ListArray}; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &ListArray, rhs: &ListArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/map.rs b/crates/polars-arrow/src/array/equal/map.rs new file mode 100644 index 000000000000..b98d65cea03a --- /dev/null +++ b/crates/polars-arrow/src/array/equal/map.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, MapArray}; + +pub(super) fn equal(lhs: &MapArray, rhs: &MapArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/mod.rs b/crates/polars-arrow/src/array/equal/mod.rs new file mode 100644 index 000000000000..971e4cbca4e8 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/mod.rs @@ -0,0 +1,298 @@ +use super::*; +use crate::offset::Offset; +use crate::types::NativeType; + +mod binary; +mod binary_view; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod union; +mod utf8; + +impl PartialEq for dyn Array + '_ { + fn eq(&self, that: &dyn Array) -> bool { + equal(self, that) + } +} + +impl PartialEq for std::sync::Arc { + fn eq(&self, that: &dyn Array) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for Box { + fn eq(&self, that: &dyn Array) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for NullArray { + fn eq(&self, other: &Self) -> bool { + null::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for NullArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq<&dyn Array> for PrimitiveArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &PrimitiveArray) -> bool { + equal(*self, other) + } +} + +impl PartialEq> for PrimitiveArray { + fn eq(&self, other: &Self) -> bool { + primitive::equal::(self, other) + } +} + +impl PartialEq for BooleanArray { + fn eq(&self, other: &Self) -> bool { + equal(self, other) + } +} + +impl PartialEq<&dyn Array> for BooleanArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for Utf8Array { + fn eq(&self, other: &Self) -> bool { + utf8::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for Utf8Array { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &Utf8Array) -> bool { + equal(*self, other) + } +} + +impl PartialEq> for BinaryArray { + fn eq(&self, other: &Self) -> bool { + binary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for BinaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &BinaryArray) -> bool { + equal(*self, other) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + fixed_size_binary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for FixedSizeBinaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for ListArray { + fn eq(&self, other: &Self) -> bool { + list::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for ListArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for FixedSizeListArray { + fn eq(&self, other: &Self) -> bool { + fixed_size_list::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for FixedSizeListArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for StructArray { + fn eq(&self, other: &Self) -> bool { + struct_::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for StructArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for DictionaryArray { + fn eq(&self, other: &Self) -> bool { + dictionary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for DictionaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for UnionArray { + fn eq(&self, other: &Self) -> bool { + union::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for UnionArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for MapArray { + fn eq(&self, other: &Self) -> bool { + map::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for MapArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +/// Logically compares two [`Array`]s. +/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * each of their items are equal +pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { + if lhs.dtype() != rhs.dtype() { + return false; + } + + use crate::datatypes::PhysicalType::*; + match lhs.dtype().to_physical_type() { + Null => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + null::equal(lhs, rhs) + }, + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + boolean::equal(lhs, rhs) + }, + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::equal::<$T>(lhs, rhs) + }), + Utf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::equal::(lhs, rhs) + }, + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::equal::(lhs, rhs) + }, + Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::equal::(lhs, rhs) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::equal::(lhs, rhs) + }, + List => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + list::equal::(lhs, rhs) + }, + LargeList => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + list::equal::(lhs, rhs) + }, + Struct => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + struct_::equal(lhs, rhs) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + dictionary::equal::<$T>(lhs, rhs) + }) + }, + FixedSizeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + fixed_size_binary::equal(lhs, rhs) + }, + FixedSizeList => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + fixed_size_list::equal(lhs, rhs) + }, + Union => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + union::equal(lhs, rhs) + }, + Map => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + map::equal(lhs, rhs) + }, + BinaryView => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary_view::equal::<[u8]>(lhs, rhs) + }, + Utf8View => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary_view::equal::(lhs, rhs) + }, + } +} diff --git a/crates/polars-arrow/src/array/equal/null.rs b/crates/polars-arrow/src/array/equal/null.rs new file mode 100644 index 000000000000..599d826f0772 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/null.rs @@ -0,0 +1,6 @@ +use crate::array::NullArray; + +#[inline] +pub(super) fn equal(lhs: &NullArray, rhs: &NullArray) -> bool { + lhs.len() == rhs.len() +} diff --git a/crates/polars-arrow/src/array/equal/primitive.rs b/crates/polars-arrow/src/array/equal/primitive.rs new file mode 100644 index 000000000000..375335155dc8 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/primitive.rs @@ -0,0 +1,6 @@ +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +pub(super) fn equal(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/struct_.rs b/crates/polars-arrow/src/array/equal/struct_.rs new file mode 100644 index 000000000000..3e50626fe7d1 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/struct_.rs @@ -0,0 +1,54 @@ +use crate::array::{Array, StructArray}; + +pub(super) fn equal(lhs: &StructArray, rhs: &StructArray) -> bool { + lhs.dtype() == rhs.dtype() + && lhs.len() == rhs.len() + && match (lhs.validity(), rhs.validity()) { + (None, None) => lhs.values().iter().eq(rhs.values().iter()), + (Some(l_validity), Some(r_validity)) => lhs + .values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + l_validity.iter().zip(r_validity.iter()).enumerate().all( + |(i, (lhs_is_valid, rhs_is_valid))| { + if lhs_is_valid && rhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + lhs_is_valid == rhs_is_valid + } + }, + ) + }), + (Some(l_validity), None) => { + lhs.values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + l_validity.iter().enumerate().all(|(i, lhs_is_valid)| { + if lhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + // rhs is always valid => different + false + } + }) + }) + }, + (None, Some(r_validity)) => { + lhs.values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + r_validity.iter().enumerate().all(|(i, rhs_is_valid)| { + if rhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + // lhs is always valid => different + false + } + }) + }) + }, + } +} diff --git a/crates/polars-arrow/src/array/equal/union.rs b/crates/polars-arrow/src/array/equal/union.rs new file mode 100644 index 000000000000..94881c187fe9 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/union.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, UnionArray}; + +pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/equal/utf8.rs b/crates/polars-arrow/src/array/equal/utf8.rs new file mode 100644 index 000000000000..f76d30a87368 --- /dev/null +++ b/crates/polars-arrow/src/array/equal/utf8.rs @@ -0,0 +1,6 @@ +use crate::array::Utf8Array; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &Utf8Array, rhs: &Utf8Array) -> bool { + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/polars-arrow/src/array/ffi.rs b/crates/polars-arrow/src/array/ffi.rs new file mode 100644 index 000000000000..977568406900 --- /dev/null +++ b/crates/polars-arrow/src/array/ffi.rs @@ -0,0 +1,90 @@ +use crate::array::*; +use crate::datatypes::PhysicalType; +use crate::ffi; + +/// Trait describing how a struct presents itself to the +/// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +/// # Safety +/// Implementing this trait incorrect will lead to UB +pub(crate) unsafe trait ToFfi { + /// The pointers to the buffers. + fn buffers(&self) -> Vec>; + + /// The children + fn children(&self) -> Vec> { + vec![] + } + + /// The offset + fn offset(&self) -> Option; + + /// return a partial clone of self with an offset. + fn to_ffi_aligned(&self) -> Self; +} + +/// Trait describing how a struct imports into itself from the +/// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub(crate) trait FromFfi: Sized { + /// Convert itself from FFI. + /// + /// # Safety + /// This function is intrinsically `unsafe` as it requires the FFI to be made according + /// to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) + unsafe fn try_from_ffi(array: T) -> PolarsResult; +} + +macro_rules! ffi_dyn { + ($array:expr, $ty:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + + ( + array.offset().unwrap(), + array.buffers(), + array.children(), + None, + ) + }}; +} + +type BuffersChildren = ( + usize, + Vec>, + Vec>, + Option>, +); + +pub fn offset_buffers_children_dictionary(array: &dyn Array) -> BuffersChildren { + use PhysicalType::*; + + match array.dtype().to_physical_type() { + Null => ffi_dyn!(array, NullArray), + Boolean => ffi_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + ffi_dyn!(array, PrimitiveArray<$T>) + }), + Binary => ffi_dyn!(array, BinaryArray), + LargeBinary => ffi_dyn!(array, BinaryArray), + FixedSizeBinary => ffi_dyn!(array, FixedSizeBinaryArray), + Utf8 => ffi_dyn!(array, Utf8Array::), + LargeUtf8 => ffi_dyn!(array, Utf8Array::), + List => ffi_dyn!(array, ListArray::), + LargeList => ffi_dyn!(array, ListArray::), + FixedSizeList => ffi_dyn!(array, FixedSizeListArray), + Struct => ffi_dyn!(array, StructArray), + Union => ffi_dyn!(array, UnionArray), + Map => ffi_dyn!(array, MapArray), + BinaryView => ffi_dyn!(array, BinaryViewArray), + Utf8View => ffi_dyn!(array, Utf8ViewArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + ( + array.offset().unwrap(), + array.buffers(), + array.children(), + Some(array.values().clone()), + ) + }) + }, + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/builder.rs b/crates/polars-arrow/src/array/fixed_size_binary/builder.rs new file mode 100644 index 000000000000..edb1b6246c7f --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_binary/builder.rs @@ -0,0 +1,124 @@ +use polars_utils::IdxSize; + +use super::FixedSizeBinaryArray; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::OptBitmapBuilder; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::pushable::Pushable; + +pub struct FixedSizeBinaryArrayBuilder { + dtype: ArrowDataType, + size: usize, + length: usize, + values: Vec, + validity: OptBitmapBuilder, +} + +impl FixedSizeBinaryArrayBuilder { + pub fn new(dtype: ArrowDataType) -> Self { + Self { + size: FixedSizeBinaryArray::get_size(&dtype), + length: 0, + dtype, + values: Vec::new(), + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for FixedSizeBinaryArrayBuilder { + type Array = FixedSizeBinaryArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + let bytes = additional * self.size; + self.values.reserve(bytes); + self.validity.reserve(additional); + } + + fn freeze(self) -> FixedSizeBinaryArray { + // TODO: FixedSizeBinaryArray should track its own length to be correct + // for size-0 inner. + let values = Buffer::from(self.values); + let validity = self.validity.into_opt_validity(); + FixedSizeBinaryArray::new(self.dtype, values, validity) + } + + fn freeze_reset(&mut self) -> Self::Array { + // TODO: FixedSizeBinaryArray should track its own length to be correct + // for size-0 inner. + let values = Buffer::from(core::mem::take(&mut self.values)); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + let out = FixedSizeBinaryArray::new(self.dtype.clone(), values, validity); + self.length = 0; + out + } + + fn len(&self) -> usize { + self.length + } + + fn extend_nulls(&mut self, length: usize) { + self.values.extend_constant(length * self.size, 0); + self.validity.extend_constant(length, false); + self.length += length; + } + + fn subslice_extend( + &mut self, + other: &FixedSizeBinaryArray, + start: usize, + length: usize, + _share: ShareStrategy, + ) { + let other_slice = other.values().as_slice(); + self.values + .extend_from_slice(&other_slice[start * self.size..(start + length) * self.size]); + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + self.length += length.min(other.len().saturating_sub(start)); + } + + unsafe fn gather_extend( + &mut self, + other: &FixedSizeBinaryArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + let other_slice = other.values().as_slice(); + self.values.reserve(idxs.len() * self.size); + for idx in idxs { + let idx = *idx as usize; + let subslice = other_slice.get_unchecked(idx * self.size..(idx + 1) * self.size); + self.values.extend_from_slice(subslice); + } + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + self.length += idxs.len(); + } + + fn opt_gather_extend( + &mut self, + other: &FixedSizeBinaryArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + let other_slice = other.values().as_slice(); + self.values.reserve(idxs.len() * self.size); + for idx in idxs { + let idx = *idx as usize; + if let Some(subslice) = other_slice.get(idx * self.size..(idx + 1) * self.size) { + self.values.extend_from_slice(subslice); + } else { + self.values.extend_constant(self.size, 0); + } + } + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + self.length += idxs.len(); + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs b/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs new file mode 100644 index 000000000000..d3d0c777dd66 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs @@ -0,0 +1,57 @@ +use polars_error::PolarsResult; + +use super::FixedSizeBinaryArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::ffi; + +unsafe impl ToFfi for FixedSizeBinaryArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.storage_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.values.offset() / self.size; + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset() / self.size; + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + size: self.size, + dtype: self.dtype.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for FixedSizeBinaryArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + Self::try_new(dtype, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs b/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs new file mode 100644 index 000000000000..6aa47acf3fd8 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs @@ -0,0 +1,20 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::FixedSizeBinaryArray; + +pub fn write_value(array: &FixedSizeBinaryArray, index: usize, f: &mut W) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| write!(f, "{}", values[index]); + + write_vec(f, writer, None, values.len(), "None", false) +} + +impl Debug for FixedSizeBinaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + write!(f, "{:?}", self.dtype)?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/iterator.rs b/crates/polars-arrow/src/array/fixed_size_binary/iterator.rs new file mode 100644 index 000000000000..4c885c591943 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_binary/iterator.rs @@ -0,0 +1,49 @@ +use super::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray}; +use crate::array::MutableArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +impl<'a> IntoIterator for &'a FixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> FixedSizeBinaryArray { + /// constructs a new iterator + pub fn iter( + &'a self, + ) -> ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>> { + ZipValidity::new_with_validity(self.values_iter(), self.validity()) + } + + /// Returns iterator over the values of [`FixedSizeBinaryArray`] + pub fn values_iter(&'a self) -> std::slice::ChunksExact<'a, u8> { + self.values().chunks_exact(self.size) + } +} + +impl<'a> IntoIterator for &'a MutableFixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MutableFixedSizeBinaryArray { + /// constructs a new iterator + pub fn iter( + &'a self, + ) -> ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>> { + ZipValidity::new(self.iter_values(), self.validity().map(|x| x.iter())) + } + + /// Returns iterator over the values of [`MutableFixedSizeBinaryArray`] + pub fn iter_values(&'a self) -> std::slice::ChunksExact<'a, u8> { + self.values().chunks_exact(self.size()) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..3591ed88ac92 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs @@ -0,0 +1,293 @@ +use super::{Array, Splitable}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; + +mod builder; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use builder::*; +mod mutable; +pub use mutable::*; +use polars_error::{PolarsResult, polars_bail, polars_ensure}; + +/// The Arrow's equivalent to an immutable `Vec>`. +/// Cloning and slicing this struct is `O(1)`. +#[derive(Clone)] +pub struct FixedSizeBinaryArray { + size: usize, // this is redundant with `dtype`, but useful to not have to deconstruct the dtype. + dtype: ArrowDataType, + values: Buffer, + validity: Option, +} + +impl FixedSizeBinaryArray { + /// Creates a new [`FixedSizeBinaryArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `dtype` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + dtype: ArrowDataType, + values: Buffer, + validity: Option, + ) -> PolarsResult { + let size = Self::maybe_get_size(&dtype)?; + + if values.len() % size != 0 { + polars_bail!(ComputeError: + "values (of len {}) must be a multiple of size ({}) in FixedSizeBinaryArray.", + values.len(), + size + ) + } + let len = values.len() / size; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != len) + { + polars_bail!(ComputeError: "validity mask length must be equal to the number of values divided by size") + } + + Ok(Self { + size, + dtype, + values, + validity, + }) + } + + /// Creates a new [`FixedSizeBinaryArray`]. + /// # Panics + /// This function panics iff: + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `dtype` + /// * the validity's length is not equal to `values.len() / size`. + pub fn new(dtype: ArrowDataType, values: Buffer, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() + } + + /// Returns a new empty [`FixedSizeBinaryArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, Buffer::new(), None) + } + + /// Returns a new null [`FixedSizeBinaryArray`]. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let size = Self::maybe_get_size(&dtype).unwrap(); + Self::new( + dtype, + vec![0u8; length * size].into(), + Some(Bitmap::new_zeroed(length)), + ) + } +} + +// must use +impl FixedSizeBinaryArray { + /// Slices this [`FixedSizeBinaryArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`FixedSizeBinaryArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values + .slice_unchecked(offset * self.size, length * self.size); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// accessors +impl FixedSizeBinaryArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the values allocated on this [`FixedSizeBinaryArray`]. + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns value at position `i`. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` as &str + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: invariant of the function. + self.values + .get_unchecked(i * self.size..(i + 1) * self.size) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&[u8]> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns a new [`FixedSizeBinaryArray`] with a different logical type. + /// This is `O(1)`. + /// # Panics + /// Panics iff the dtype is not supported for the physical type. + #[inline] + pub fn to(self, dtype: ArrowDataType) -> Self { + match (dtype.to_logical_type(), self.dtype().to_logical_type()) { + (ArrowDataType::FixedSizeBinary(size_a), ArrowDataType::FixedSizeBinary(size_b)) + if size_a == size_b => {}, + _ => panic!("Wrong DataType"), + } + + Self { + size: self.size, + dtype, + values: self.values, + validity: self.validity, + } + } + + /// Returns the size + pub fn size(&self) -> usize { + self.size + } +} + +impl FixedSizeBinaryArray { + pub(crate) fn maybe_get_size(dtype: &ArrowDataType) -> PolarsResult { + match dtype.to_logical_type() { + ArrowDataType::FixedSizeBinary(size) => { + polars_ensure!(*size != 0, ComputeError: "FixedSizeBinaryArray expects a positive size"); + Ok(*size) + }, + other => { + polars_bail!(ComputeError: "FixedSizeBinaryArray expects DataType::FixedSizeBinary. found {other:?}") + }, + } + } + + pub fn get_size(dtype: &ArrowDataType) -> usize { + Self::maybe_get_size(dtype).unwrap() + } +} + +impl Array for FixedSizeBinaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for FixedSizeBinaryArray { + fn check_bound(&self, offset: usize) -> bool { + offset < self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_values, rhs_values) = unsafe { self.values.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + let size = self.size; + + ( + Self { + dtype: self.dtype.clone(), + values: lhs_values, + validity: lhs_validity, + size, + }, + Self { + dtype: self.dtype.clone(), + values: rhs_values, + validity: rhs_validity, + size, + }, + ) + } +} + +impl FixedSizeBinaryArray { + /// Creates a [`FixedSizeBinaryArray`] from an fallible iterator of optional `[u8]`. + pub fn try_from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> PolarsResult { + MutableFixedSizeBinaryArray::try_from_iter(iter, size).map(|x| x.into()) + } + + /// Creates a [`FixedSizeBinaryArray`] from an iterator of optional `[u8]`. + pub fn from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Self { + MutableFixedSizeBinaryArray::try_from_iter(iter, size) + .unwrap() + .into() + } + + /// Creates a [`FixedSizeBinaryArray`] from a slice of arrays of bytes + pub fn from_slice>(a: P) -> Self { + let values = a.as_ref().iter().flatten().copied().collect::>(); + Self::new(ArrowDataType::FixedSizeBinary(N), values.into(), None) + } + + /// Creates a new [`FixedSizeBinaryArray`] from a slice of optional `[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from]>>(slice: P) -> Self { + MutableFixedSizeBinaryArray::from(slice).into() + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs new file mode 100644 index 000000000000..f49066c7e410 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs @@ -0,0 +1,314 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::FixedSizeBinaryArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtendFromSelf}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::ArrowDataType; + +/// The Arrow's equivalent to a mutable `Vec>`. +/// Converting a [`MutableFixedSizeBinaryArray`] into a [`FixedSizeBinaryArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableFixedSizeBinaryArray { + dtype: ArrowDataType, + size: usize, + values: Vec, + validity: Option, +} + +impl From for FixedSizeBinaryArray { + fn from(other: MutableFixedSizeBinaryArray) -> Self { + FixedSizeBinaryArray::new( + other.dtype, + other.values.into(), + other.validity.map(|x| x.into()), + ) + } +} + +impl MutableFixedSizeBinaryArray { + /// Creates a new [`MutableFixedSizeBinaryArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `dtype` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + dtype: ArrowDataType, + values: Vec, + validity: Option, + ) -> PolarsResult { + let size = FixedSizeBinaryArray::maybe_get_size(&dtype)?; + + if values.len() % size != 0 { + polars_bail!(ComputeError: + "values (of len {}) must be a multiple of size ({}) in FixedSizeBinaryArray.", + values.len(), + size + ) + } + let len = values.len() / size; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != len) + { + polars_bail!(ComputeError: "validity mask length must be equal to the number of values divided by size") + } + + Ok(Self { + size, + dtype, + values, + validity, + }) + } + + /// Creates a new empty [`MutableFixedSizeBinaryArray`]. + pub fn new(size: usize) -> Self { + Self::with_capacity(size, 0) + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] with capacity for `capacity` entries. + pub fn with_capacity(size: usize, capacity: usize) -> Self { + Self::try_new( + ArrowDataType::FixedSizeBinary(size), + Vec::::with_capacity(capacity * size), + None, + ) + .unwrap() + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] from a slice of optional `[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from]>>(slice: P) -> Self { + let values = slice + .as_ref() + .iter() + .copied() + .flat_map(|x| x.unwrap_or([0; N])) + .collect::>(); + let validity = slice + .as_ref() + .iter() + .map(|x| x.is_some()) + .collect::(); + Self::try_new(ArrowDataType::FixedSizeBinary(N), values, validity.into()).unwrap() + } + + /// tries to push a new entry to [`MutableFixedSizeBinaryArray`]. + /// # Error + /// Errors iff the size of `value` is not equal to its own size. + #[inline] + pub fn try_push>(&mut self, value: Option

) -> PolarsResult<()> { + match value { + Some(bytes) => { + let bytes = bytes.as_ref(); + if self.size != bytes.len() { + polars_bail!(ComputeError: "FixedSizeBinaryArray requires every item to be of its length") + } + self.values.extend_from_slice(bytes); + + if let Some(validity) = &mut self.validity { + validity.push(true) + } + }, + None => { + self.values.resize(self.values.len() + self.size, 0); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } + + /// pushes a new entry to [`MutableFixedSizeBinaryArray`]. + /// # Panics + /// Panics iff the size of `value` is not equal to its own size. + #[inline] + pub fn push>(&mut self, value: Option

) { + self.try_push(value).unwrap() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// Pop the last entry from [`MutableFixedSizeBinaryArray`]. + /// This function returns `None` iff this array is empty + pub fn pop(&mut self) -> Option> { + if self.values.len() < self.size { + return None; + } + let value_start = self.values.len() - self.size; + let value = self.values.split_off(value_start); + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] from an iterator of values. + /// # Errors + /// Errors iff the size of any of the `value` is not equal to its own size. + pub fn try_from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> PolarsResult { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut primitive = Self::with_capacity(size, lower); + for item in iterator { + primitive.try_push(item)? + } + Ok(primitive) + } + + /// returns the (fixed) size of the [`MutableFixedSizeBinaryArray`]. + #[inline] + pub fn size(&self) -> usize { + self.size + } + + /// Returns the capacity of this array + pub fn capacity(&self) -> usize { + self.values.capacity() / self.size + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Returns the element at index `i` as `&[u8]` + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + &self.values[i * self.size..(i + 1) * self.size] + } + + /// Returns the element at index `i` as `&[u8]` + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + std::slice::from_raw_parts(self.values.as_ptr().add(i * self.size), self.size) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional * self.size); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableFixedSizeBinaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + pub fn freeze(self) -> FixedSizeBinaryArray { + FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(self.size), + self.values.into(), + self.validity.map(|x| x.into()), + ) + } +} + +/// Accessors +impl MutableFixedSizeBinaryArray { + /// Returns its values. + pub fn values(&self) -> &Vec { + &self.values + } + + /// Returns a mutable slice of values. + pub fn values_mut_slice(&mut self) -> &mut [u8] { + self.values.as_mut_slice() + } +} + +impl MutableArray for MutableFixedSizeBinaryArray { + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(self.size), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(self.size), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push::<&[u8]>(None); + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl PartialEq for MutableFixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableFixedSizeBinaryArray { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + self.values.extend_from_slice(slice); + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_list/builder.rs b/crates/polars-arrow/src/array/fixed_size_list/builder.rs new file mode 100644 index 000000000000..3674a9302581 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_list/builder.rs @@ -0,0 +1,160 @@ +use polars_utils::IdxSize; + +use super::FixedSizeListArray; +use crate::array::builder::{ArrayBuilder, ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::OptBitmapBuilder; +use crate::datatypes::ArrowDataType; + +pub struct FixedSizeListArrayBuilder { + dtype: ArrowDataType, + size: usize, + length: usize, + inner_builder: B, + validity: OptBitmapBuilder, +} +impl FixedSizeListArrayBuilder { + pub fn new(dtype: ArrowDataType, inner_builder: B) -> Self { + Self { + size: FixedSizeListArray::get_child_and_size(&dtype).1, + dtype, + length: 0, + inner_builder, + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for FixedSizeListArrayBuilder { + type Array = FixedSizeListArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + self.inner_builder.reserve(additional); + self.validity.reserve(additional); + } + + fn freeze(self) -> FixedSizeListArray { + let values = self.inner_builder.freeze(); + let validity = self.validity.into_opt_validity(); + FixedSizeListArray::new(self.dtype, self.length, values, validity) + } + + fn freeze_reset(&mut self) -> Self::Array { + let values = self.inner_builder.freeze_reset(); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + let out = FixedSizeListArray::new(self.dtype.clone(), self.length, values, validity); + self.length = 0; + out + } + + fn len(&self) -> usize { + self.length + } + + fn extend_nulls(&mut self, length: usize) { + self.inner_builder.extend_nulls(length * self.size); + self.validity.extend_constant(length, false); + self.length += length; + } + + fn subslice_extend( + &mut self, + other: &FixedSizeListArray, + start: usize, + length: usize, + share: ShareStrategy, + ) { + self.inner_builder.subslice_extend( + &**other.values(), + start * self.size, + length * self.size, + share, + ); + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + self.length += length.min(other.len().saturating_sub(start)); + } + + unsafe fn gather_extend( + &mut self, + other: &FixedSizeListArray, + idxs: &[IdxSize], + share: ShareStrategy, + ) { + let other_values = &**other.values(); + self.inner_builder.reserve(idxs.len() * self.size); + + // Group consecutive indices into larger copies. + let mut group_start = 0; + while group_start < idxs.len() { + let start_idx = idxs[group_start] as usize; + let mut group_len = 1; + while group_start + group_len < idxs.len() + && idxs[group_start + group_len] as usize == start_idx + group_len + { + group_len += 1; + } + self.inner_builder.subslice_extend( + other_values, + start_idx * self.size, + group_len * self.size, + share, + ); + group_start += group_len; + } + + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + self.length += idxs.len(); + } + + fn opt_gather_extend( + &mut self, + other: &FixedSizeListArray, + idxs: &[IdxSize], + share: ShareStrategy, + ) { + let other_values = &**other.values(); + self.inner_builder.reserve(idxs.len() * self.size); + + // Group consecutive indices into larger copies. + let mut group_start = 0; + while group_start < idxs.len() { + let start_idx = idxs[group_start] as usize; + let mut group_len = 1; + let in_bounds = start_idx < other.len(); + + if in_bounds { + while group_start + group_len < idxs.len() + && idxs[group_start + group_len] as usize == start_idx + group_len + && start_idx + group_len < other.len() + { + group_len += 1; + } + + self.inner_builder.subslice_extend( + other_values, + start_idx * self.size, + group_len * self.size, + share, + ); + } else { + while group_start + group_len < idxs.len() + && idxs[group_start + group_len] as usize >= other.len() + { + group_len += 1; + } + + self.inner_builder.extend_nulls(group_len * self.size); + } + group_start += group_len; + } + + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + self.length += idxs.len(); + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs new file mode 100644 index 000000000000..6541eaab56a9 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs @@ -0,0 +1,50 @@ +use polars_error::{PolarsResult, polars_ensure}; + +use super::FixedSizeListArray; +use crate::array::Array; +use crate::array::ffi::{FromFfi, ToFfi}; +use crate::ffi; + +unsafe impl ToFfi for FixedSizeListArray { + fn buffers(&self) -> Vec> { + vec![self.validity.as_ref().map(|x| x.as_ptr())] + } + + fn children(&self) -> Vec> { + vec![self.values.clone()] + } + + fn offset(&self) -> Option { + Some( + self.validity + .as_ref() + .map(|bitmap| bitmap.offset()) + .unwrap_or_default(), + ) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for FixedSizeListArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let (_, width) = FixedSizeListArray::try_child_and_size(&dtype)?; + let validity = unsafe { array.validity() }?; + let child = unsafe { array.child(0) }?; + let values = ffi::try_from(child)?; + + let length = if values.is_empty() { + 0 + } else { + polars_ensure!(width > 0, InvalidOperation: "Zero-width array with values"); + values.len() / width + }; + + let mut fsl = Self::try_new(dtype, length, values, validity)?; + fsl.slice(array.offset(), array.length()); + Ok(fsl) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_list/fmt.rs b/crates/polars-arrow/src/array/fixed_size_list/fmt.rs new file mode 100644 index 000000000000..ee7d86115a14 --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_list/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::FixedSizeListArray; + +pub fn write_value( + array: &FixedSizeListArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for FixedSizeListArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "FixedSizeListArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_list/iterator.rs b/crates/polars-arrow/src/array/fixed_size_list/iterator.rs new file mode 100644 index 000000000000..123658005adc --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_list/iterator.rs @@ -0,0 +1,43 @@ +use super::FixedSizeListArray; +use crate::array::{Array, ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +unsafe impl<'a> ArrayAccessor<'a> for FixedSizeListArray { + type Item = Box; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of a [`FixedSizeListArray`]. +pub type FixedSizeListValuesIter<'a> = ArrayValuesIter<'a, FixedSizeListArray>; + +type ZipIter<'a> = ZipValidity, FixedSizeListValuesIter<'a>, BitmapIter<'a>>; + +impl<'a> IntoIterator for &'a FixedSizeListArray { + type Item = Option>; + type IntoIter = ZipIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> FixedSizeListArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a> { + ZipValidity::new_with_validity(FixedSizeListValuesIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> FixedSizeListValuesIter<'a> { + FixedSizeListValuesIter::new(self) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs new file mode 100644 index 000000000000..8cff9379820c --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -0,0 +1,408 @@ +use super::{Array, ArrayRef, Splitable, new_empty_array, new_null_array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{ArrowDataType, Field}; + +mod ffi; +pub(super) mod fmt; +mod iterator; + +mod builder; +pub use builder::*; +mod mutable; +pub use mutable::*; +use polars_error::{PolarsResult, polars_bail, polars_ensure}; +use polars_utils::format_tuple; +use polars_utils::pl_str::PlSmallStr; + +use crate::datatypes::reshape::{Dimension, ReshapeDimension}; + +/// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. +/// Cloning and slicing this struct is `O(1)`. +#[derive(Clone)] +pub struct FixedSizeListArray { + size: usize, // this is redundant with `dtype`, but useful to not have to deconstruct the dtype. + length: usize, // invariant: this is values.len() / size if size > 0 + dtype: ArrowDataType, + values: Box, + validity: Option, +} + +impl FixedSizeListArray { + /// Creates a new [`FixedSizeListArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeList`] + /// * The `dtype`'s inner field's data type is not equal to `values.dtype`. + /// * The length of `values` is not a multiple of `size` in `dtype` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + dtype: ArrowDataType, + length: usize, + values: Box, + validity: Option, + ) -> PolarsResult { + let (child, size) = Self::try_child_and_size(&dtype)?; + + let child_dtype = &child.dtype; + let values_dtype = values.dtype(); + if child_dtype != values_dtype { + polars_bail!(ComputeError: "FixedSizeListArray's child's DataType must match. However, the expected DataType is {child_dtype:?} while it got {values_dtype:?}.") + } + + polars_ensure!(size == 0 || values.len() % size == 0, ComputeError: + "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", + values.len(), + size + ); + + polars_ensure!(size == 0 || values.len() / size == length, ComputeError: + "length of values ({}) is not equal to given length ({}) in FixedSizeListArray({size}).", + values.len() / size, + length, + ); + polars_ensure!(size != 0 || values.is_empty(), ComputeError: + "zero width FixedSizeListArray has values (length = {}).", + values.len(), + ); + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != length) + { + polars_bail!(ComputeError: "validity mask length must be equal to the number of values divided by size") + } + + Ok(Self { + size, + length, + dtype, + values, + validity, + }) + } + + #[inline] + fn has_invariants(&self) -> bool { + let has_valid_length = (self.size == 0 && self.values().is_empty()) + || (self.size > 0 + && self.values().len() % self.size() == 0 + && self.values().len() / self.size() == self.length); + let has_valid_validity = self + .validity + .as_ref() + .is_none_or(|v| v.len() == self.length); + + has_valid_length && has_valid_validity + } + + /// Alias to `Self::try_new(...).unwrap()` + #[track_caller] + pub fn new( + dtype: ArrowDataType, + length: usize, + values: Box, + validity: Option, + ) -> Self { + Self::try_new(dtype, length, values, validity).unwrap() + } + + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. + pub const fn size(&self) -> usize { + self.size + } + + /// Returns a new empty [`FixedSizeListArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + let values = new_empty_array(Self::get_child_and_size(&dtype).0.dtype().clone()); + Self::new(dtype, 0, values, None) + } + + /// Returns a new null [`FixedSizeListArray`]. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let (field, size) = Self::get_child_and_size(&dtype); + + let values = new_null_array(field.dtype().clone(), length * size); + Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) + } + + pub fn from_shape( + leaf_array: ArrayRef, + dimensions: &[ReshapeDimension], + ) -> PolarsResult { + polars_ensure!( + !dimensions.is_empty(), + InvalidOperation: "at least one dimension must be specified" + ); + let size = leaf_array.len(); + + let mut total_dim_size = 1; + let mut num_infers = 0; + for &dim in dimensions { + match dim { + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, + } + } + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + if size == 0 { + polars_ensure!( + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. + for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(prev_array); + } + + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + + let mut prev_arrow_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array; + + // We pop the outer dimension as that is the height of the series. + for dim in dimensions[1..].iter().rev() { + // Infer dimension if needed + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + + prev_array = FixedSizeListArray::new( + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, + prev_array, + None, + ) + .boxed(); + } + Ok(prev_array) + } + + pub fn get_dims(&self) -> Vec { + let mut dims = vec![ + Dimension::new(self.length as _), + Dimension::new(self.size as _), + ]; + + let mut prev_array = &self.values; + + while let Some(a) = prev_array.as_any().downcast_ref::() { + dims.push(Dimension::new(a.size as _)); + prev_array = &a.values; + } + dims + } + + pub fn propagate_nulls(&self) -> Self { + let Some(validity) = self.validity() else { + return self.clone(); + }; + + let propagated_validity = if self.size == 1 { + validity.clone() + } else { + Bitmap::from_trusted_len_iter( + (0..self.size * validity.len()) + .map(|i| unsafe { validity.get_bit_unchecked(i / self.size) }), + ) + }; + + let propagated_validity = match self.values.validity() { + None => propagated_validity, + Some(val) => val & &propagated_validity, + }; + Self::new( + self.dtype().clone(), + self.length, + self.values.with_validity(Some(propagated_validity)), + self.validity.clone(), + ) + } +} + +// must use +impl FixedSizeListArray { + /// Slices this [`FixedSizeListArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`FixedSizeListArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values + .slice_unchecked(offset * self.size, length * self.size); + self.length = length; + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// accessors +impl FixedSizeListArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + debug_assert!(self.has_invariants()); + self.length + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the inner array. + pub fn values(&self) -> &Box { + &self.values + } + + /// Returns the `Vec` at position `i`. + /// # Panic: + /// panics iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> Box { + self.values.sliced(i * self.size, self.size) + } + + /// Returns the `Vec` at position `i`. + /// + /// # Safety + /// Caller must ensure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + self.values.sliced_unchecked(i * self.size, self.size) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } +} + +impl FixedSizeListArray { + pub(crate) fn try_child_and_size(dtype: &ArrowDataType) -> PolarsResult<(&Field, usize)> { + match dtype.to_logical_type() { + ArrowDataType::FixedSizeList(child, size) => Ok((child.as_ref(), *size)), + _ => polars_bail!(ComputeError: "FixedSizeListArray expects DataType::FixedSizeList"), + } + } + + pub(crate) fn get_child_and_size(dtype: &ArrowDataType) -> (&Field, usize) { + Self::try_child_and_size(dtype).unwrap() + } + + /// Returns a [`ArrowDataType`] consistent with [`FixedSizeListArray`]. + pub fn default_datatype(dtype: ArrowDataType, size: usize) -> ArrowDataType { + let field = Box::new(Field::new(PlSmallStr::from_static("item"), dtype, true)); + ArrowDataType::FixedSizeList(field, size) + } +} + +impl Array for FixedSizeListArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for FixedSizeListArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_values, rhs_values) = + unsafe { self.values.split_at_boxed_unchecked(offset * self.size) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + let size = self.size; + + ( + Self { + dtype: self.dtype.clone(), + length: offset, + values: lhs_values, + validity: lhs_validity, + size, + }, + Self { + dtype: self.dtype.clone(), + length: self.length - offset, + values: rhs_values, + validity: rhs_validity, + size, + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs new file mode 100644 index 000000000000..79014192394a --- /dev/null +++ b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs @@ -0,0 +1,292 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::pl_str::PlSmallStr; + +use super::FixedSizeListArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, PushUnchecked, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{ArrowDataType, Field}; + +/// The mutable version of [`FixedSizeListArray`]. +#[derive(Debug, Clone)] +pub struct MutableFixedSizeListArray { + dtype: ArrowDataType, + size: usize, + length: usize, + values: M, + validity: Option, +} + +impl From> for FixedSizeListArray { + fn from(mut other: MutableFixedSizeListArray) -> Self { + FixedSizeListArray::new( + other.dtype, + other.length, + other.values.as_box(), + other.validity.map(|x| x.into()), + ) + } +} + +impl MutableFixedSizeListArray { + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. + pub fn new(values: M, size: usize) -> Self { + let dtype = FixedSizeListArray::default_datatype(values.dtype().clone(), size); + Self::new_from(values, dtype, size) + } + + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. + pub fn new_with_field(values: M, name: PlSmallStr, nullable: bool, size: usize) -> Self { + let dtype = ArrowDataType::FixedSizeList( + Box::new(Field::new(name, values.dtype().clone(), nullable)), + size, + ); + Self::new_from(values, dtype, size) + } + + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`], [`ArrowDataType`] and size. + pub fn new_from(values: M, dtype: ArrowDataType, size: usize) -> Self { + assert_eq!(values.len(), 0); + match dtype { + ArrowDataType::FixedSizeList(..) => (), + _ => panic!("data type must be FixedSizeList (got {dtype:?})"), + }; + Self { + size, + length: 0, + dtype, + values, + validity: None, + } + } + + #[inline] + fn has_valid_invariants(&self) -> bool { + (self.size == 0 && self.values().len() == 0) + || (self.size > 0 && self.values.len() / self.size == self.length) + } + + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. + pub const fn size(&self) -> usize { + self.size + } + + /// The length of this array + pub fn len(&self) -> usize { + debug_assert!(self.has_valid_invariants()); + self.length + } + + /// The inner values + pub fn values(&self) -> &M { + &self.values + } + + fn init_validity(&mut self) { + let len = self.values.len() / self.size; + + let mut validity = MutableBitmap::new(); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. + pub fn try_push_valid(&mut self) -> PolarsResult<()> { + if self.values.len() % self.size != 0 { + polars_bail!(ComputeError: "overflow") + }; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); + + Ok(()) + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. + pub fn push_valid(&mut self) { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); + } + + #[inline] + fn push_null(&mut self) { + (0..self.size).for_each(|_| self.values.push_null()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableFixedSizeListArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableFixedSizeListArray { + fn len(&self) -> usize { + debug_assert!(self.has_valid_invariants()); + self.length + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + FixedSizeListArray::new( + self.dtype.clone(), + self.length, + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + FixedSizeListArray::new( + self.dtype.clone(), + self.length, + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + (0..self.size).for_each(|_| { + self.values.push_null(); + }); + if let Some(validity) = &mut self.validity { + validity.push(false) + } else { + self.init_validity() + } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl TryExtend> for MutableFixedSizeListArray +where + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_extend>>(&mut self, iter: II) -> PolarsResult<()> { + for items in iter { + self.try_push(items)?; + } + + debug_assert!(self.has_valid_invariants()); + + Ok(()) + } +} + +impl TryPush> for MutableFixedSizeListArray +where + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_push(&mut self, item: Option) -> PolarsResult<()> { + if let Some(items) = item { + self.values.try_extend(items)?; + self.try_push_valid()?; + } else { + self.push_null(); + } + + debug_assert!(self.has_valid_invariants()); + + Ok(()) + } +} + +impl PushUnchecked> for MutableFixedSizeListArray +where + M: MutableArray + Extend>, + I: IntoIterator>, +{ + /// # Safety + /// The caller must ensure that the `I` iterates exactly over `size` + /// items, where `size` is the fixed size width. + #[inline] + unsafe fn push_unchecked(&mut self, item: Option) { + if let Some(items) = item { + self.values.extend(items); + self.push_valid(); + } else { + self.push_null(); + } + + debug_assert!(self.has_valid_invariants()); + } +} + +impl TryExtendFromSelf for MutableFixedSizeListArray +where + M: MutableArray + TryExtendFromSelf, +{ + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values)?; + self.length += other.len(); + + debug_assert!(self.has_valid_invariants()); + + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/fmt.rs b/crates/polars-arrow/src/array/fmt.rs new file mode 100644 index 000000000000..6b3fc21752b1 --- /dev/null +++ b/crates/polars-arrow/src/array/fmt.rs @@ -0,0 +1,196 @@ +use std::fmt::{Result, Write}; + +use super::Array; +use crate::bitmap::Bitmap; +use crate::{match_integer_type, with_match_primitive_type_full}; + +/// Returns a function that writes the value of the element of `array` +/// at position `index` to a [`Write`], +/// writing `null` in the null slots. +pub fn get_value_display<'a, F: Write + 'a>( + array: &'a dyn Array, + null: &'static str, +) -> Box Result + 'a> { + use crate::datatypes::PhysicalType::*; + match array.dtype().to_physical_type() { + Null => Box::new(move |f, _| write!(f, "{null}")), + Boolean => Box::new(|f, index| { + super::boolean::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, f) + }), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let writer = super::primitive::fmt::get_write_value::<$T, _>( + array.as_any().downcast_ref().unwrap(), + ); + Box::new(move |f, index| writer(f, index)) + }), + Binary => Box::new(|f, index| { + super::binary::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + FixedSizeBinary => Box::new(|f, index| { + super::fixed_size_binary::fmt::write_value( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + LargeBinary => Box::new(|f, index| { + super::binary::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + Utf8 => Box::new(|f, index| { + super::utf8::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + LargeUtf8 => Box::new(|f, index| { + super::utf8::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + List => Box::new(move |f, index| { + super::list::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + FixedSizeList => Box::new(move |f, index| { + super::fixed_size_list::fmt::write_value( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + LargeList => Box::new(move |f, index| { + super::list::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + Struct => Box::new(move |f, index| { + super::struct_::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Union => Box::new(move |f, index| { + super::union::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Map => Box::new(move |f, index| { + super::map::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + BinaryView => Box::new(move |f, index| { + super::binview::fmt::write_value::<[u8], _>( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + Utf8View => Box::new(move |f, index| { + super::binview::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + Box::new(move |f, index| { + super::dictionary::fmt::write_value::<$T,_>(array.as_any().downcast_ref().unwrap(), index, null, f) + }) + }), + } +} + +/// Returns a function that writes the element of `array` +/// at position `index` to a [`Write`], writing `null` to the null slots. +pub fn get_display<'a, F: Write + 'a>( + array: &'a dyn Array, + null: &'static str, +) -> Box Result + 'a> { + let value_display = get_value_display(array, null); + Box::new(move |f, row| { + if array.is_null(row) { + f.write_str(null) + } else { + value_display(f, row) + } + }) +} + +pub fn write_vec( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + f.write_char('[')?; + write_list(f, d, validity, len, null, new_lines)?; + f.write_char(']')?; + Ok(()) +} + +fn write_list( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + for index in 0..len { + if index != 0 { + f.write_char(',')?; + f.write_char(if new_lines { '\n' } else { ' ' })?; + } + if let Some(val) = validity { + if val.get_bit(index) { + d(f, index) + } else { + write!(f, "{null}") + } + } else { + d(f, index) + }?; + } + Ok(()) +} + +pub fn write_map( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + f.write_char('{')?; + write_list(f, d, validity, len, null, new_lines)?; + f.write_char('}')?; + Ok(()) +} diff --git a/crates/polars-arrow/src/array/indexable.rs b/crates/polars-arrow/src/array/indexable.rs new file mode 100644 index 000000000000..f071a731663c --- /dev/null +++ b/crates/polars-arrow/src/array/indexable.rs @@ -0,0 +1,217 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::borrow::Borrow; + +use crate::array::{ + MutableArray, MutableBinaryArray, MutableBinaryValuesArray, MutableBinaryViewArray, + MutableBooleanArray, MutableFixedSizeBinaryArray, MutablePrimitiveArray, MutableUtf8Array, + MutableUtf8ValuesArray, ViewType, +}; +use crate::offset::Offset; +use crate::types::NativeType; + +/// Trait for arrays that can be indexed directly to extract a value. +pub trait Indexable { + /// The type of the element at index `i`; may be a reference type or a value type. + type Value<'a>: Borrow + where + Self: 'a; + + type Type: ?Sized; + + /// Returns the element at index `i`. + /// # Panic + /// May panic if `i >= self.len()`. + fn value_at(&self, index: usize) -> Self::Value<'_>; + + /// Returns the element at index `i`. + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + unsafe fn value_unchecked_at(&self, index: usize) -> Self::Value<'_> { + self.value_at(index) + } +} + +pub trait AsIndexed { + fn as_indexed(&self) -> &M::Type; +} + +impl Indexable for MutableBooleanArray { + type Value<'a> = bool; + type Type = bool; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.values().get(i) + } +} + +impl AsIndexed for bool { + #[inline] + fn as_indexed(&self) -> &bool { + self + } +} + +impl Indexable for MutableBinaryArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + // TODO: add .value() / .value_unchecked() to MutableBinaryArray? + assert!(i < self.len()); + unsafe { self.value_unchecked_at(i) } + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + // TODO: add .value() / .value_unchecked() to MutableBinaryArray? + // soundness: the invariant of the function + let (start, end) = self.offsets().start_end_unchecked(i); + // soundness: the invariant of the struct + self.values().get_unchecked(start..end) + } +} + +impl AsIndexed> for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableBinaryValuesArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl AsIndexed> for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableFixedSizeBinaryArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + // soundness: the invariant of the struct + self.value_unchecked(i) + } +} + +impl AsIndexed for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableBinaryViewArray { + type Value<'a> = &'a T; + type Type = T; + + fn value_at(&self, index: usize) -> Self::Value<'_> { + self.value(index) + } + + unsafe fn value_unchecked_at(&self, index: usize) -> Self::Value<'_> { + self.value_unchecked(index) + } +} + +impl AsIndexed> for &T { + #[inline] + fn as_indexed(&self) -> &T { + self + } +} + +// TODO: should NativeType derive from Hash? +impl Indexable for MutablePrimitiveArray { + type Value<'a> = T; + type Type = T; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + assert!(i < self.len()); + // TODO: add Length trait? (for both Array and MutableArray) + unsafe { self.value_unchecked_at(i) } + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + *self.values().get_unchecked(i) + } +} + +impl AsIndexed> for T { + #[inline] + fn as_indexed(&self) -> &T { + self + } +} + +impl Indexable for MutableUtf8Array { + type Value<'a> = &'a str; + type Type = str; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl> AsIndexed> for V { + #[inline] + fn as_indexed(&self) -> &str { + self.as_ref() + } +} + +impl Indexable for MutableUtf8ValuesArray { + type Value<'a> = &'a str; + type Type = str; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl> AsIndexed> for V { + #[inline] + fn as_indexed(&self) -> &str { + self.as_ref() + } +} diff --git a/crates/polars-arrow/src/array/iterator.rs b/crates/polars-arrow/src/array/iterator.rs new file mode 100644 index 000000000000..0b8bd1de8ee9 --- /dev/null +++ b/crates/polars-arrow/src/array/iterator.rs @@ -0,0 +1,128 @@ +use crate::bitmap::Bitmap; +use crate::bitmap::iterator::TrueIdxIter; +use crate::trusted_len::TrustedLen; + +mod private { + pub trait Sealed {} + + impl<'a, T: super::ArrayAccessor<'a> + ?Sized> Sealed for T {} +} + +/// Sealed trait representing access to a value of an array. +/// # Safety +/// Implementers of this trait guarantee that +/// `value_unchecked` is safe when called up to `len` +pub unsafe trait ArrayAccessor<'a>: private::Sealed { + type Item: 'a; + /// # Safety + /// The index must be in-bounds in the array. + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item; + fn len(&self) -> usize; +} + +/// Iterator of values of an [`ArrayAccessor`]. +#[derive(Debug, Clone)] +pub struct ArrayValuesIter<'a, A: ArrayAccessor<'a>> { + array: &'a A, + index: usize, + end: usize, +} + +impl<'a, A: ArrayAccessor<'a>> ArrayValuesIter<'a, A> { + /// Creates a new [`ArrayValuesIter`] + #[inline] + pub fn new(array: &'a A) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a, A: ArrayAccessor<'a>> Iterator for ArrayValuesIter<'a, A> { + type Item = A::Item; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(unsafe { self.array.value_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl<'a, A: ArrayAccessor<'a>> DoubleEndedIterator for ArrayValuesIter<'a, A> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(unsafe { self.array.value_unchecked(self.end) }) + } + } +} + +unsafe impl<'a, A: ArrayAccessor<'a>> TrustedLen for ArrayValuesIter<'a, A> {} +impl<'a, A: ArrayAccessor<'a>> ExactSizeIterator for ArrayValuesIter<'a, A> {} + +pub struct NonNullValuesIter<'a, A: ?Sized> { + accessor: &'a A, + idxs: TrueIdxIter<'a>, +} + +impl<'a, A: ArrayAccessor<'a> + ?Sized> NonNullValuesIter<'a, A> { + pub fn new(accessor: &'a A, validity: Option<&'a Bitmap>) -> Self { + Self { + idxs: TrueIdxIter::new(accessor.len(), validity), + accessor, + } + } +} + +impl<'a, A: ArrayAccessor<'a> + ?Sized> Iterator for NonNullValuesIter<'a, A> { + type Item = A::Item; + + #[inline] + fn next(&mut self) -> Option { + if let Some(i) = self.idxs.next() { + return Some(unsafe { self.accessor.value_unchecked(i) }); + } + None + } + + fn size_hint(&self) -> (usize, Option) { + self.idxs.size_hint() + } +} + +unsafe impl<'a, A: ArrayAccessor<'a> + ?Sized> TrustedLen for NonNullValuesIter<'a, A> {} + +impl Clone for NonNullValuesIter<'_, A> { + fn clone(&self) -> Self { + Self { + accessor: self.accessor, + idxs: self.idxs.clone(), + } + } +} diff --git a/crates/polars-arrow/src/array/list/builder.rs b/crates/polars-arrow/src/array/list/builder.rs new file mode 100644 index 000000000000..909e3caaefc4 --- /dev/null +++ b/crates/polars-arrow/src/array/list/builder.rs @@ -0,0 +1,199 @@ +use polars_utils::IdxSize; + +use super::ListArray; +use crate::array::builder::{ArrayBuilder, ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::OptBitmapBuilder; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +pub struct ListArrayBuilder { + dtype: ArrowDataType, + offsets: Offsets, + inner_builder: B, + validity: OptBitmapBuilder, +} + +impl ListArrayBuilder { + pub fn new(dtype: ArrowDataType, inner_builder: B) -> Self { + Self { + dtype, + inner_builder, + offsets: Offsets::new(), + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for ListArrayBuilder { + type Array = ListArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + self.validity.reserve(additional); + // No inner reserve, we have no idea how large it needs to be. + } + + fn freeze(self) -> ListArray { + let offsets = OffsetsBuffer::from(self.offsets); + let values = self.inner_builder.freeze(); + let validity = self.validity.into_opt_validity(); + ListArray::new(self.dtype, offsets, values, validity) + } + + fn freeze_reset(&mut self) -> Self::Array { + let offsets = OffsetsBuffer::from(core::mem::take(&mut self.offsets)); + let values = self.inner_builder.freeze_reset(); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + ListArray::new(self.dtype.clone(), offsets, values, validity) + } + + fn len(&self) -> usize { + self.offsets.len_proxy() + } + + fn extend_nulls(&mut self, length: usize) { + self.offsets.extend_constant(length); + self.validity.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &ListArray, + start: usize, + length: usize, + share: ShareStrategy, + ) { + let start_offset = other.offsets()[start].to_usize(); + let stop_offset = other.offsets()[start + length].to_usize(); + self.offsets + .try_extend_from_slice(other.offsets(), start, length) + .unwrap(); + self.inner_builder.subslice_extend( + &**other.values(), + start_offset, + stop_offset - start_offset, + share, + ); + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + unsafe fn gather_extend( + &mut self, + other: &ListArray, + idxs: &[IdxSize], + share: ShareStrategy, + ) { + let other_values = &**other.values(); + let other_offsets = other.offsets(); + + // Pre-compute proper length for reserve. + let total_len: usize = idxs + .iter() + .map(|i| { + let start = other_offsets.get_unchecked(*i as usize).to_usize(); + let stop = other_offsets.get_unchecked(*i as usize + 1).to_usize(); + stop - start + }) + .sum(); + self.inner_builder.reserve(total_len); + + // Group consecutive indices into larger copies. + let mut group_start = 0; + while group_start < idxs.len() { + let start_idx = idxs[group_start] as usize; + let mut group_len = 1; + while group_start + group_len < idxs.len() + && idxs[group_start + group_len] as usize == start_idx + group_len + { + group_len += 1; + } + + let start_offset = other_offsets.get_unchecked(start_idx).to_usize(); + let stop_offset = other_offsets + .get_unchecked(start_idx + group_len) + .to_usize(); + self.offsets + .try_extend_from_slice(other_offsets, start_idx, group_len) + .unwrap(); + self.inner_builder.subslice_extend( + other_values, + start_offset, + stop_offset - start_offset, + share, + ); + group_start += group_len; + } + + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + } + + fn opt_gather_extend(&mut self, other: &ListArray, idxs: &[IdxSize], share: ShareStrategy) { + let other_values = &**other.values(); + let other_offsets = other.offsets(); + + unsafe { + // Pre-compute proper length for reserve. + let total_len: usize = idxs + .iter() + .map(|idx| { + if (*idx as usize) < other.len() { + let start = other_offsets.get_unchecked(*idx as usize).to_usize(); + let stop = other_offsets.get_unchecked(*idx as usize + 1).to_usize(); + stop - start + } else { + 0 + } + }) + .sum(); + self.inner_builder.reserve(total_len); + + // Group consecutive indices into larger copies. + let mut group_start = 0; + while group_start < idxs.len() { + let start_idx = idxs[group_start] as usize; + let mut group_len = 1; + let in_bounds = start_idx < other.len(); + + if in_bounds { + while group_start + group_len < idxs.len() + && idxs[group_start + group_len] as usize == start_idx + group_len + && start_idx + group_len < other.len() + { + group_len += 1; + } + + let start_offset = other_offsets.get_unchecked(start_idx).to_usize(); + let stop_offset = other_offsets + .get_unchecked(start_idx + group_len) + .to_usize(); + self.offsets + .try_extend_from_slice(other_offsets, start_idx, group_len) + .unwrap(); + self.inner_builder.subslice_extend( + other_values, + start_offset, + stop_offset - start_offset, + share, + ); + } else { + while group_start + group_len < idxs.len() + && idxs[group_start + group_len] as usize >= other.len() + { + group_len += 1; + } + self.offsets.extend_constant(group_len); + } + group_start += group_len; + } + + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } + } +} diff --git a/crates/polars-arrow/src/array/list/ffi.rs b/crates/polars-arrow/src/array/list/ffi.rs new file mode 100644 index 000000000000..a39de3341e0c --- /dev/null +++ b/crates/polars-arrow/src/array/list/ffi.rs @@ -0,0 +1,69 @@ +use polars_error::PolarsResult; + +use super::super::Array; +use super::super::ffi::ToFfi; +use super::ListArray; +use crate::array::FromFfi; +use crate::bitmap::align; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for ListArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().storage_ptr().cast::()), + ] + } + + fn children(&self) -> Vec> { + vec![self.values.clone()] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for ListArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let child = unsafe { array.child(0)? }; + let values = ffi::try_from(child)?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Self::try_new(dtype, offsets, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/list/fmt.rs b/crates/polars-arrow/src/array/list/fmt.rs new file mode 100644 index 000000000000..67dcd6b78786 --- /dev/null +++ b/crates/polars-arrow/src/array/list/fmt.rs @@ -0,0 +1,30 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::ListArray; +use crate::offset::Offset; + +pub fn write_value( + array: &ListArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for ListArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + let head = if O::IS_LARGE { + "LargeListArray" + } else { + "ListArray" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/list/iterator.rs b/crates/polars-arrow/src/array/list/iterator.rs new file mode 100644 index 000000000000..5ae7f96a5c4c --- /dev/null +++ b/crates/polars-arrow/src/array/list/iterator.rs @@ -0,0 +1,51 @@ +use super::ListArray; +use crate::array::iterator::NonNullValuesIter; +use crate::array::{Array, ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for ListArray { + type Item = Box; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of a [`ListArray`]. +pub type ListValuesIter<'a, O> = ArrayValuesIter<'a, ListArray>; + +type ZipIter<'a, O> = ZipValidity, ListValuesIter<'a, O>, BitmapIter<'a>>; + +impl<'a, O: Offset> IntoIterator for &'a ListArray { + type Item = Option>; + type IntoIter = ZipIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, O: Offset> ListArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a, O> { + ZipValidity::new_with_validity(ListValuesIter::new(self), self.validity.as_ref()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> ListValuesIter<'a, O> { + ListValuesIter::new(self) + } + + /// Returns an iterator of the non-null values `Box`. + #[inline] + pub fn non_null_values_iter(&'a self) -> NonNullValuesIter<'a, ListArray> { + NonNullValuesIter::new(self, self.validity()) + } +} diff --git a/crates/polars-arrow/src/array/list/mod.rs b/crates/polars-arrow/src/array/list/mod.rs new file mode 100644 index 000000000000..f68ed475d13c --- /dev/null +++ b/crates/polars-arrow/src/array/list/mod.rs @@ -0,0 +1,309 @@ +use super::specification::try_check_offsets_bounds; +use super::{Array, Splitable, new_empty_array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{ArrowDataType, Field}; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +mod builder; +pub use builder::*; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::pl_str::PlSmallStr; + +/// An [`Array`] semantically equivalent to `Vec>>>` with Arrow's in-memory. +#[derive(Clone)] +pub struct ListArray { + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, +} + +impl ListArray { + /// Creates a new [`ListArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * `offsets.last()` is greater than `values.len()`. + /// * the validity's length is not equal to `offsets.len_proxy()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `dtype`'s inner field's data type is not equal to `values.dtype`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, + ) -> PolarsResult { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != offsets.len_proxy()) + { + polars_bail!(ComputeError: "validity mask length must match the number of values") + } + + let child_dtype = Self::try_get_child(&dtype)?.dtype(); + let values_dtype = values.dtype(); + if child_dtype != values_dtype { + polars_bail!(ComputeError: "ListArray's child's DataType must match. However, the expected DataType is {child_dtype:?} while it got {values_dtype:?}."); + } + + Ok(Self { + dtype, + offsets, + values, + validity, + }) + } + + /// Creates a new [`ListArray`]. + /// + /// # Panics + /// This function panics iff: + /// * `offsets.last()` is greater than `values.len()`. + /// * the validity's length is not equal to `offsets.len_proxy()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `dtype`'s inner field's data type is not equal to `values.dtype`. + /// # Implementation + /// This function is `O(1)` + pub fn new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, + ) -> Self { + Self::try_new(dtype, offsets, values, validity).unwrap() + } + + /// Returns a new empty [`ListArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + let values = new_empty_array(Self::get_child_type(&dtype).clone()); + Self::new(dtype, OffsetsBuffer::default(), values, None) + } + + /// Returns a new null [`ListArray`]. + #[inline] + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let child = Self::get_child_type(&dtype).clone(); + Self::new( + dtype, + Offsets::new_zeroed(length).into(), + new_empty_array(child), + Some(Bitmap::new_zeroed(length)), + ) + } +} + +impl ListArray { + /// Slices this [`ListArray`]. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`ListArray`]. + /// + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + pub fn trim_to_normalized_offsets_recursive(&self) -> Self { + let offsets = self.offsets(); + let values = self.values(); + + let first_idx = *offsets.first(); + let len = offsets.range().to_usize(); + + let values = if values.len() == len { + values.clone() + } else { + values.sliced(first_idx.to_usize(), len) + }; + + let offsets = if first_idx.to_usize() == 0 { + offsets.clone() + } else { + let v = offsets.iter().map(|x| *x - first_idx).collect::>(); + unsafe { OffsetsBuffer::::new_unchecked(v.into()) } + }; + + let values = match values.dtype() { + ArrowDataType::List(_) => { + let inner: &ListArray = values.as_ref().as_any().downcast_ref().unwrap(); + Box::new(inner.trim_to_normalized_offsets_recursive()) as Box + }, + ArrowDataType::LargeList(_) => { + let inner: &ListArray = values.as_ref().as_any().downcast_ref().unwrap(); + Box::new(inner.trim_to_normalized_offsets_recursive()) as Box + }, + _ => values, + }; + + assert_eq!(offsets.first().to_usize(), 0); + assert_eq!(values.len(), offsets.range().to_usize()); + + Self::new( + self.dtype().clone(), + offsets, + values, + self.validity().cloned(), + ) + } +} + +// Accessors +impl ListArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the element at index `i` + /// # Panic + /// Panics iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> Box { + assert!(i < self.len()); + // SAFETY: invariant of this function + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` as &str + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + // SAFETY: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + let length = end - start; + + // SAFETY: the invariant of the struct + self.values.sliced_unchecked(start, length) + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// The offsets [`Buffer`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The values. + #[inline] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl ListArray { + /// Returns a default [`ArrowDataType`]: inner field is named "item" and is nullable + pub fn default_datatype(dtype: ArrowDataType) -> ArrowDataType { + let field = Box::new(Field::new(PlSmallStr::from_static("item"), dtype, true)); + if O::IS_LARGE { + ArrowDataType::LargeList(field) + } else { + ArrowDataType::List(field) + } + } + + /// Returns a the inner [`Field`] + /// # Panics + /// Panics iff the logical type is not consistent with this struct. + pub fn get_child_field(dtype: &ArrowDataType) -> &Field { + Self::try_get_child(dtype).unwrap() + } + + /// Returns a the inner [`Field`] + /// # Errors + /// Panics iff the logical type is not consistent with this struct. + pub fn try_get_child(dtype: &ArrowDataType) -> PolarsResult<&Field> { + if O::IS_LARGE { + match dtype.to_logical_type() { + ArrowDataType::LargeList(child) => Ok(child.as_ref()), + _ => polars_bail!(ComputeError: "ListArray expects DataType::LargeList"), + } + } else { + match dtype.to_logical_type() { + ArrowDataType::List(child) => Ok(child.as_ref()), + _ => polars_bail!(ComputeError: "ListArray expects DataType::List"), + } + } + } + + /// Returns a the inner [`ArrowDataType`] + /// # Panics + /// Panics iff the logical type is not consistent with this struct. + pub fn get_child_type(dtype: &ArrowDataType) -> &ArrowDataType { + Self::get_child_field(dtype).dtype() + } +} + +impl Array for ListArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for ListArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_offsets, rhs_offsets) = unsafe { self.offsets.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + ( + Self { + dtype: self.dtype.clone(), + offsets: lhs_offsets, + validity: lhs_validity, + values: self.values.clone(), + }, + Self { + dtype: self.dtype.clone(), + offsets: rhs_offsets, + validity: rhs_validity, + values: self.values.clone(), + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/list/mutable.rs b/crates/polars-arrow/src/array/list/mutable.rs new file mode 100644 index 000000000000..21310dc62524 --- /dev/null +++ b/crates/polars-arrow/src/array/list/mutable.rs @@ -0,0 +1,319 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_err}; +use polars_utils::pl_str::PlSmallStr; + +use super::ListArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{ArrowDataType, Field}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// The mutable version of [`ListArray`]. +#[derive(Debug, Clone)] +pub struct MutableListArray { + dtype: ArrowDataType, + offsets: Offsets, + values: M, + validity: Option, +} + +impl MutableListArray { + /// Creates a new empty [`MutableListArray`]. + pub fn new() -> Self { + let values = M::default(); + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Self::new_from(values, dtype, 0) + } + + /// Creates a new [`MutableListArray`] with a capacity. + pub fn with_capacity(capacity: usize) -> Self { + let values = M::default(); + let dtype = ListArray::::default_datatype(values.dtype().clone()); + + let offsets = Offsets::::with_capacity(capacity); + Self { + dtype, + offsets, + values, + validity: None, + } + } +} + +impl Default for MutableListArray { + fn default() -> Self { + Self::new() + } +} + +impl From> for ListArray { + fn from(mut other: MutableListArray) -> Self { + ListArray::new( + other.dtype, + other.offsets.into(), + other.values.as_box(), + other.validity.map(|x| x.into()), + ) + } +} + +impl TryExtend> for MutableListArray +where + O: Offset, + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + fn try_extend>>(&mut self, iter: II) -> PolarsResult<()> { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for items in iter { + self.try_push(items)?; + } + Ok(()) + } +} + +impl TryPush> for MutableListArray +where + O: Offset, + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_push(&mut self, item: Option) -> PolarsResult<()> { + if let Some(items) = item { + let values = self.mut_values(); + values.try_extend(items)?; + self.try_push_valid()?; + } else { + self.push_null(); + } + Ok(()) + } +} + +impl TryExtendFromSelf for MutableListArray +where + O: Offset, + M: MutableArray + TryExtendFromSelf, +{ + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values)?; + self.offsets.try_extend_from_self(&other.offsets) + } +} + +impl MutableListArray { + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_from(values: M, dtype: ArrowDataType, capacity: usize) -> Self { + let offsets = Offsets::::with_capacity(capacity); + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&dtype); + Self { + dtype, + offsets, + values, + validity: None, + } + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`]. + pub fn new_with_field(values: M, name: PlSmallStr, nullable: bool) -> Self { + let field = Box::new(Field::new(name, values.dtype().clone(), nullable)); + let dtype = if O::IS_LARGE { + ArrowDataType::LargeList(field) + } else { + ArrowDataType::List(field) + }; + Self::new_from(values, dtype, 0) + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_with_capacity(values: M, capacity: usize) -> Self { + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Self::new_from(values, dtype, capacity) + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`], [`Offsets`] and + /// [`MutableBitmap`]. + pub fn new_from_mutable( + values: M, + offsets: Offsets, + validity: Option, + ) -> Self { + assert_eq!(values.len(), offsets.last().to_usize()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Self { + dtype, + offsets, + values, + validity, + } + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. + pub fn try_push_valid(&mut self) -> PolarsResult<()> { + let total_length = self.values.len(); + let offset = self.offsets.last().to_usize(); + let length = total_length + .checked_sub(offset) + .ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + + self.offsets.try_push(length)?; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.extend_constant(1); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + /// Expand this array, using elements from the underlying backing array. + /// Assumes the expansion begins at the highest previous offset, or zero if + /// this [`MutableListArray`] is currently empty. + /// + /// Panics if: + /// - the new offsets are not in monotonic increasing order. + /// - any new offset is not in bounds of the backing array. + /// - the passed iterator has no upper bound. + pub fn try_extend_from_lengths(&mut self, iterator: II) -> PolarsResult<()> + where + II: TrustedLen> + Clone, + { + self.offsets + .try_extend_from_lengths(iterator.clone().map(|x| x.unwrap_or_default()))?; + if let Some(validity) = &mut self.validity { + validity.extend_from_trusted_len_iter(iterator.map(|x| x.is_some())) + } + assert_eq!(self.offsets.last().to_usize(), self.values.len()); + Ok(()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// The values + pub fn mut_values(&mut self) -> &mut M { + &mut self.values + } + + /// The offsets + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// The values + pub fn values(&self) -> &M { + &self.values + } + + fn init_validity(&mut self) { + let len = self.offsets.len_proxy(); + + let mut validity = MutableBitmap::with_capacity(self.offsets.capacity()); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: ListArray = self.into(); + Arc::new(a) + } + + /// converts itself into [`Box`] + pub fn into_box(self) -> Box { + let a: ListArray = self.into(); + Box::new(a) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableListArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableListArray { + fn len(&self) -> usize { + MutableListArray::len(self) + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + ListArray::new( + self.dtype.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + ListArray::new( + self.dtype.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit(); + } +} diff --git a/crates/polars-arrow/src/array/map/ffi.rs b/crates/polars-arrow/src/array/map/ffi.rs new file mode 100644 index 000000000000..a80db436e838 --- /dev/null +++ b/crates/polars-arrow/src/array/map/ffi.rs @@ -0,0 +1,69 @@ +use polars_error::PolarsResult; + +use super::super::Array; +use super::super::ffi::ToFfi; +use super::MapArray; +use crate::array::FromFfi; +use crate::bitmap::align; +use crate::ffi; +use crate::offset::OffsetsBuffer; + +unsafe impl ToFfi for MapArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().storage_ptr().cast::()), + ] + } + + fn children(&self) -> Vec> { + vec![self.field.clone()] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + offsets: self.offsets.clone(), + field: self.field.clone(), + } + } +} + +impl FromFfi for MapArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let child = array.child(0)?; + let values = ffi::try_from(child)?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Self::try_new(dtype, offsets, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/map/fmt.rs b/crates/polars-arrow/src/array/map/fmt.rs new file mode 100644 index 000000000000..60abf56e18c5 --- /dev/null +++ b/crates/polars-arrow/src/array/map/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::MapArray; + +pub fn write_value( + array: &MapArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for MapArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "MapArray")?; + write_vec(f, writer, self.validity.as_ref(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/map/iterator.rs b/crates/polars-arrow/src/array/map/iterator.rs new file mode 100644 index 000000000000..79fc630cc520 --- /dev/null +++ b/crates/polars-arrow/src/array/map/iterator.rs @@ -0,0 +1,81 @@ +use super::MapArray; +use crate::array::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::trusted_len::TrustedLen; + +/// Iterator of values of an [`ListArray`]. +#[derive(Clone, Debug)] +pub struct MapValuesIter<'a> { + array: &'a MapArray, + index: usize, + end: usize, +} + +impl<'a> MapValuesIter<'a> { + #[inline] + pub fn new(array: &'a MapArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl Iterator for MapValuesIter<'_> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + // SAFETY: + // self.end is maximized by the length of the array + Some(unsafe { self.array.value_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl TrustedLen for MapValuesIter<'_> {} + +impl DoubleEndedIterator for MapValuesIter<'_> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + // SAFETY: + // self.end is maximized by the length of the array + Some(unsafe { self.array.value_unchecked(self.end) }) + } + } +} + +impl<'a> IntoIterator for &'a MapArray { + type Item = Option>; + type IntoIter = ZipValidity, MapValuesIter<'a>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MapArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipValidity, MapValuesIter<'a>, BitmapIter<'a>> { + ZipValidity::new_with_validity(MapValuesIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> MapValuesIter<'a> { + MapValuesIter::new(self) + } +} diff --git a/crates/polars-arrow/src/array/map/mod.rs b/crates/polars-arrow/src/array/map/mod.rs new file mode 100644 index 000000000000..9747ff23259e --- /dev/null +++ b/crates/polars-arrow/src/array/map/mod.rs @@ -0,0 +1,221 @@ +use super::specification::try_check_offsets_bounds; +use super::{Array, Splitable, new_empty_array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{ArrowDataType, Field}; +use crate::offset::OffsetsBuffer; + +mod ffi; +pub(super) mod fmt; +mod iterator; + +use polars_error::{PolarsResult, polars_bail}; + +/// An array representing a (key, value), both of arbitrary logical types. +#[derive(Clone)] +pub struct MapArray { + dtype: ArrowDataType, + // invariant: field.len() == offsets.len() + offsets: OffsetsBuffer, + field: Box, + // invariant: offsets.len() - 1 == Bitmap::len() + validity: Option, +} + +impl MapArray { + /// Returns a new [`MapArray`]. + /// # Errors + /// This function errors iff: + /// * `offsets.last()` is greater than `field.len()` + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Map`] + /// * The fields' `dtype` is not equal to the inner field of `dtype` + /// * The validity is not `None` and its length is different from `offsets.len() - 1`. + pub fn try_new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + field: Box, + validity: Option, + ) -> PolarsResult { + try_check_offsets_bounds(&offsets, field.len())?; + + let inner_field = Self::try_get_field(&dtype)?; + if let ArrowDataType::Struct(inner) = inner_field.dtype() { + if inner.len() != 2 { + polars_bail!(ComputeError: "MapArray's inner `Struct` must have 2 fields (keys and maps)") + } + } else { + polars_bail!(ComputeError: "MapArray expects `DataType::Struct` as its inner logical type") + } + if field.dtype() != inner_field.dtype() { + polars_bail!(ComputeError: "MapArray expects `field.dtype` to match its inner DataType") + } + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != offsets.len_proxy()) + { + polars_bail!(ComputeError: "validity mask length must match the number of values") + } + + Ok(Self { + dtype, + field, + offsets, + validity, + }) + } + + /// Creates a new [`MapArray`]. + /// # Panics + /// * `offsets.last()` is greater than `field.len()`. + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Map`], + /// * The validity is not `None` and its length is different from `offsets.len() - 1`. + pub fn new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + field: Box, + validity: Option, + ) -> Self { + Self::try_new(dtype, offsets, field, validity).unwrap() + } + + /// Returns a new null [`MapArray`] of `length`. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let field = new_empty_array(Self::get_field(&dtype).dtype().clone()); + Self::new( + dtype, + vec![0i32; 1 + length].try_into().unwrap(), + field, + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns a new empty [`MapArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + let field = new_empty_array(Self::get_field(&dtype).dtype().clone()); + Self::new(dtype, OffsetsBuffer::default(), field, None) + } +} + +impl MapArray { + /// Returns a slice of this [`MapArray`]. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a slice of this [`MapArray`]. + /// + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + pub(crate) fn try_get_field(dtype: &ArrowDataType) -> PolarsResult<&Field> { + if let ArrowDataType::Map(field, _) = dtype.to_logical_type() { + Ok(field.as_ref()) + } else { + polars_bail!(ComputeError: "The dtype's logical type must be DataType::Map") + } + } + + pub(crate) fn get_field(dtype: &ArrowDataType) -> &Field { + Self::try_get_field(dtype).unwrap() + } +} + +// Accessors +impl MapArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// returns the offsets + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// Returns the field (guaranteed to be a `Struct`) + #[inline] + pub fn field(&self) -> &Box { + &self.field + } + + /// Returns the element at index `i`. + #[inline] + pub fn value(&self, i: usize) -> Box { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i`. + /// + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + let length = end - start; + + // soundness: the invariant of the struct + self.field.sliced_unchecked(start, length) + } +} + +impl Array for MapArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for MapArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_offsets, rhs_offsets) = unsafe { self.offsets.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + ( + Self { + dtype: self.dtype.clone(), + offsets: lhs_offsets, + field: self.field.clone(), + validity: lhs_validity, + }, + Self { + dtype: self.dtype.clone(), + offsets: rhs_offsets, + field: self.field.clone(), + validity: rhs_validity, + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs new file mode 100644 index 000000000000..9a55349b56b7 --- /dev/null +++ b/crates/polars-arrow/src/array/mod.rs @@ -0,0 +1,777 @@ +//! Contains the [`Array`] and [`MutableArray`] trait objects declaring arrays, +//! as well as concrete arrays (such as [`Utf8Array`] and [`MutableUtf8Array`]). +//! +//! Fixed-length containers with optional values +//! that are laid in memory according to the Arrow specification. +//! Each array type has its own `struct`. The following are the main array types: +//! * [`PrimitiveArray`] and [`MutablePrimitiveArray`], an array of values with a fixed length such as integers, floats, etc. +//! * [`BooleanArray`] and [`MutableBooleanArray`], an array of boolean values (stored as a bitmap) +//! * [`Utf8Array`] and [`MutableUtf8Array`], an array of variable length utf8 values +//! * [`BinaryArray`] and [`MutableBinaryArray`], an array of opaque variable length values +//! * [`ListArray`] and [`MutableListArray`], an array of arrays (e.g. `[[1, 2], None, [], [None]]`) +//! * [`StructArray`] and [`MutableStructArray`], an array of arrays identified by a string (e.g. `{"a": [1, 2], "b": [true, false]}`) +//! +//! All immutable arrays implement the trait object [`Array`] and that can be downcast +//! to a concrete struct based on [`PhysicalType`](crate::datatypes::PhysicalType) available from [`Array::dtype`]. +//! All immutable arrays are backed by [`Buffer`](crate::buffer::Buffer) and thus cloning and slicing them is `O(1)`. +//! +//! Most arrays contain a [`MutableArray`] counterpart that is neither cloneable nor sliceable, but +//! can be operated in-place. +#![allow(unsafe_op_in_unsafe_fn)] +use std::any::Any; +use std::sync::Arc; + +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::ArrowDataType; + +pub mod physical_binary; + +pub trait Splitable: Sized { + fn check_bound(&self, offset: usize) -> bool; + + /// Split [`Self`] at `offset` where `offset <= self.len()`. + #[inline] + #[must_use] + fn split_at(&self, offset: usize) -> (Self, Self) { + assert!(self.check_bound(offset)); + unsafe { self._split_at_unchecked(offset) } + } + + /// Split [`Self`] at `offset` without checking `offset <= self.len()`. + /// + /// # Safety + /// + /// Safe if `offset <= self.len()`. + #[inline] + #[must_use] + unsafe fn split_at_unchecked(&self, offset: usize) -> (Self, Self) { + debug_assert!(self.check_bound(offset)); + unsafe { self._split_at_unchecked(offset) } + } + + /// Internal implementation of `split_at_unchecked`. For any usage, prefer the using + /// `split_at` or `split_at_unchecked`. + /// + /// # Safety + /// + /// Safe if `offset <= self.len()`. + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self); +} + +/// A trait representing an immutable Arrow array. Arrow arrays are trait objects +/// that are infallibly downcast to concrete types according to the [`Array::dtype`]. +pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { + /// Converts itself to a reference of [`Any`], which enables downcasting to concrete types. + fn as_any(&self) -> &dyn Any; + + /// Converts itself to a mutable reference of [`Any`], which enables mutable downcasting to concrete types. + fn as_any_mut(&mut self) -> &mut dyn Any; + + /// The length of the [`Array`]. Every array has a length corresponding to the number of + /// elements (slots). + fn len(&self) -> usize; + + /// whether the array is empty + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The [`ArrowDataType`] of the [`Array`]. In combination with [`Array::as_any`], this can be + /// used to downcast trait objects (`dyn Array`) to concrete arrays. + fn dtype(&self) -> &ArrowDataType; + + /// The validity of the [`Array`]: every array has an optional [`Bitmap`] that, when available + /// specifies whether the array slot is valid or not (null). + /// When the validity is [`None`], all slots are valid. + fn validity(&self) -> Option<&Bitmap>; + + /// The number of null slots on this [`Array`]. + /// # Implementation + /// This is `O(1)` since the number of null elements is pre-computed. + #[inline] + fn null_count(&self) -> usize { + if self.dtype() == &ArrowDataType::Null { + return self.len(); + }; + self.validity() + .as_ref() + .map(|x| x.unset_bits()) + .unwrap_or(0) + } + + #[inline] + fn has_nulls(&self) -> bool { + self.null_count() > 0 + } + + /// Returns whether slot `i` is null. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + fn is_null(&self, i: usize) -> bool { + assert!(i < self.len()); + unsafe { self.is_null_unchecked(i) } + } + + /// Returns whether slot `i` is null. + /// + /// # Safety + /// The caller must ensure `i < self.len()` + #[inline] + unsafe fn is_null_unchecked(&self, i: usize) -> bool { + self.validity() + .as_ref() + .map(|x| !x.get_bit_unchecked(i)) + .unwrap_or(false) + } + + /// Returns whether slot `i` is valid. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + fn is_valid(&self, i: usize) -> bool { + !self.is_null(i) + } + + /// Split [`Self`] at `offset` into two boxed [`Array`]s where `offset <= self.len()`. + #[must_use] + fn split_at_boxed(&self, offset: usize) -> (Box, Box); + + /// Split [`Self`] at `offset` into two boxed [`Array`]s without checking `offset <= self.len()`. + /// + /// # Safety + /// + /// Safe if `offset <= self.len()`. + #[must_use] + unsafe fn split_at_boxed_unchecked(&self, offset: usize) -> (Box, Box); + + /// Slices this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + fn slice(&mut self, offset: usize, length: usize); + + /// Slices the [`Array`]. + /// # Implementation + /// This operation is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize); + + /// Returns a slice of this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + #[must_use] + fn sliced(&self, offset: usize, length: usize) -> Box { + if length == 0 { + return new_empty_array(self.dtype().clone()); + } + let mut new = self.to_boxed(); + new.slice(offset, length); + new + } + + /// Returns a slice of this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`, as it amounts to increase two ref counts + /// and moving the struct to the heap. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + #[must_use] + unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Box { + debug_assert!(offset + length <= self.len()); + let mut new = self.to_boxed(); + new.slice_unchecked(offset, length); + new + } + + /// Clones this [`Array`] with a new assigned bitmap. + /// # Panic + /// This function panics iff `validity.len() != self.len()`. + fn with_validity(&self, validity: Option) -> Box; + + /// Clone a `&dyn Array` to an owned `Box`. + fn to_boxed(&self) -> Box; +} + +dyn_clone::clone_trait_object!(Array); + +pub trait IntoBoxedArray { + fn into_boxed(self) -> Box; +} + +impl IntoBoxedArray for A { + #[inline(always)] + fn into_boxed(self) -> Box { + Box::new(self) as _ + } +} +impl IntoBoxedArray for Box { + #[inline(always)] + fn into_boxed(self) -> Box { + self + } +} + +/// A trait describing a mutable array; i.e. an array whose values can be changed. +/// +/// Mutable arrays cannot be cloned but can be mutated in place, +/// thereby making them useful to perform numeric operations without allocations. +/// As in [`Array`], concrete arrays (such as [`MutablePrimitiveArray`]) implement how they are mutated. +pub trait MutableArray: std::fmt::Debug + Send + Sync { + /// The [`ArrowDataType`] of the array. + fn dtype(&self) -> &ArrowDataType; + + /// The length of the array. + fn len(&self) -> usize; + + /// Whether the array is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The optional validity of the array. + fn validity(&self) -> Option<&MutableBitmap>; + + /// Convert itself to an (immutable) [`Array`]. + fn as_box(&mut self) -> Box; + + /// Convert itself to an (immutable) atomically reference counted [`Array`]. + // This provided implementation has an extra allocation as it first + // boxes `self`, then converts the box into an `Arc`. Implementors may wish + // to avoid an allocation by skipping the box completely. + fn as_arc(&mut self) -> std::sync::Arc { + self.as_box().into() + } + + /// Convert to `Any`, to enable dynamic casting. + fn as_any(&self) -> &dyn Any; + + /// Convert to mutable `Any`, to enable dynamic casting. + fn as_mut_any(&mut self) -> &mut dyn Any; + + /// Adds a new null element to the array. + fn push_null(&mut self); + + /// Whether `index` is valid / set. + /// # Panic + /// Panics if `index >= self.len()`. + #[inline] + fn is_valid(&self, index: usize) -> bool { + self.validity() + .as_ref() + .map(|x| x.get(index)) + .unwrap_or(true) + } + + /// Reserves additional slots to its capacity. + fn reserve(&mut self, additional: usize); + + /// Shrink the array to fit its length. + fn shrink_to_fit(&mut self); +} + +impl MutableArray for Box { + fn len(&self) -> usize { + self.as_ref().len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.as_ref().validity() + } + + fn as_box(&mut self) -> Box { + self.as_mut().as_box() + } + + fn as_arc(&mut self) -> Arc { + self.as_mut().as_arc() + } + + fn dtype(&self) -> &ArrowDataType { + self.as_ref().dtype() + } + + fn as_any(&self) -> &dyn std::any::Any { + self.as_ref().as_any() + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self.as_mut().as_mut_any() + } + + #[inline] + fn push_null(&mut self) { + self.as_mut().push_null() + } + + fn shrink_to_fit(&mut self) { + self.as_mut().shrink_to_fit(); + } + + fn reserve(&mut self, additional: usize) { + self.as_mut().reserve(additional); + } +} + +macro_rules! general_dyn { + ($array:expr, $ty:ty, $f:expr) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + ($f)(array) + }}; +} + +macro_rules! fmt_dyn { + ($array:expr, $ty:ty, $f:expr) => {{ + let mut f = |x: &$ty| x.fmt($f); + general_dyn!($array, $ty, f) + }}; +} + +impl std::fmt::Debug for dyn Array + '_ { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::datatypes::PhysicalType::*; + match self.dtype().to_physical_type() { + Null => fmt_dyn!(self, NullArray, f), + Boolean => fmt_dyn!(self, BooleanArray, f), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + fmt_dyn!(self, PrimitiveArray<$T>, f) + }), + BinaryView => fmt_dyn!(self, BinaryViewArray, f), + Utf8View => fmt_dyn!(self, Utf8ViewArray, f), + Binary => fmt_dyn!(self, BinaryArray, f), + LargeBinary => fmt_dyn!(self, BinaryArray, f), + FixedSizeBinary => fmt_dyn!(self, FixedSizeBinaryArray, f), + Utf8 => fmt_dyn!(self, Utf8Array::, f), + LargeUtf8 => fmt_dyn!(self, Utf8Array::, f), + List => fmt_dyn!(self, ListArray::, f), + LargeList => fmt_dyn!(self, ListArray::, f), + FixedSizeList => fmt_dyn!(self, FixedSizeListArray, f), + Struct => fmt_dyn!(self, StructArray, f), + Union => fmt_dyn!(self, UnionArray, f), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + fmt_dyn!(self, DictionaryArray::<$T>, f) + }) + }, + Map => fmt_dyn!(self, MapArray, f), + } + } +} + +/// Creates a new [`Array`] with a [`Array::len`] of 0. +pub fn new_empty_array(dtype: ArrowDataType) -> Box { + use crate::datatypes::PhysicalType::*; + match dtype.to_physical_type() { + Null => Box::new(NullArray::new_empty(dtype)), + Boolean => Box::new(BooleanArray::new_empty(dtype)), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::new_empty(dtype)) + }), + Binary => Box::new(BinaryArray::::new_empty(dtype)), + LargeBinary => Box::new(BinaryArray::::new_empty(dtype)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_empty(dtype)), + Utf8 => Box::new(Utf8Array::::new_empty(dtype)), + LargeUtf8 => Box::new(Utf8Array::::new_empty(dtype)), + List => Box::new(ListArray::::new_empty(dtype)), + LargeList => Box::new(ListArray::::new_empty(dtype)), + FixedSizeList => Box::new(FixedSizeListArray::new_empty(dtype)), + Struct => Box::new(StructArray::new_empty(dtype)), + Union => Box::new(UnionArray::new_empty(dtype)), + Map => Box::new(MapArray::new_empty(dtype)), + Utf8View => Box::new(Utf8ViewArray::new_empty(dtype)), + BinaryView => Box::new(BinaryViewArray::new_empty(dtype)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::new_empty(dtype)) + }) + }, + } +} + +/// Creates a new [`Array`] of [`ArrowDataType`] `dtype` and `length`. +/// +/// The array is guaranteed to have [`Array::null_count`] equal to [`Array::len`] +/// for all types except Union, which does not have a validity. +pub fn new_null_array(dtype: ArrowDataType, length: usize) -> Box { + use crate::datatypes::PhysicalType::*; + match dtype.to_physical_type() { + Null => Box::new(NullArray::new_null(dtype, length)), + Boolean => Box::new(BooleanArray::new_null(dtype, length)), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::new_null(dtype, length)) + }), + Binary => Box::new(BinaryArray::::new_null(dtype, length)), + LargeBinary => Box::new(BinaryArray::::new_null(dtype, length)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_null(dtype, length)), + Utf8 => Box::new(Utf8Array::::new_null(dtype, length)), + LargeUtf8 => Box::new(Utf8Array::::new_null(dtype, length)), + List => Box::new(ListArray::::new_null(dtype, length)), + LargeList => Box::new(ListArray::::new_null(dtype, length)), + FixedSizeList => Box::new(FixedSizeListArray::new_null(dtype, length)), + Struct => Box::new(StructArray::new_null(dtype, length)), + Union => Box::new(UnionArray::new_null(dtype, length)), + Map => Box::new(MapArray::new_null(dtype, length)), + BinaryView => Box::new(BinaryViewArray::new_null(dtype, length)), + Utf8View => Box::new(Utf8ViewArray::new_null(dtype, length)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::new_null(dtype, length)) + }) + }, + } +} + +macro_rules! clone_dyn { + ($array:expr, $ty:ty) => {{ + let f = |x: &$ty| Box::new(x.clone()); + general_dyn!($array, $ty, f) + }}; +} + +// macro implementing `sliced` and `sliced_unchecked` +macro_rules! impl_sliced { + () => { + /// Returns this array sliced. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + #[inline] + #[must_use] + pub fn sliced(self, offset: usize, length: usize) -> Self { + let total = offset + .checked_add(length) + .expect("offset + length overflowed"); + assert!( + total <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { Self::sliced_unchecked(self, offset, length) } + } + + /// Returns this array sliced. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + Self::slice_unchecked(&mut self, offset, length); + self + } + }; +} + +// macro implementing `with_validity` and `set_validity` +macro_rules! impl_mut_validity { + () => { + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } + + /// Takes the validity of this array, leaving it without a validity mask. + #[inline] + pub fn take_validity(&mut self) -> Option { + self.validity.take() + } + } +} + +// macro implementing `with_validity`, `set_validity` and `apply_validity` for mutable arrays +macro_rules! impl_mutable_array_mut_validity { + () => { + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + #[inline] + pub fn apply_validity MutableBitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + } +} + +// macro implementing `boxed` and `arced` +macro_rules! impl_into_array { + () => { + /// Boxes this array into a [`Box`]. + pub fn boxed(self) -> Box { + Box::new(self) + } + + /// Arcs this array into a [`std::sync::Arc`]. + pub fn arced(self) -> std::sync::Arc { + std::sync::Arc::new(self) + } + }; +} + +// macro implementing common methods of trait `Array` +macro_rules! impl_common_array { + () => { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + #[inline] + fn split_at_boxed(&self, offset: usize) -> (Box, Box) { + let (lhs, rhs) = $crate::array::Splitable::split_at(self, offset); + (Box::new(lhs), Box::new(rhs)) + } + + #[inline] + unsafe fn split_at_boxed_unchecked( + &self, + offset: usize, + ) -> (Box, Box) { + let (lhs, rhs) = unsafe { $crate::array::Splitable::split_at_unchecked(self, offset) }; + (Box::new(lhs), Box::new(rhs)) + } + + #[inline] + fn slice(&mut self, offset: usize, length: usize) { + self.slice(offset, length); + } + + #[inline] + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.slice_unchecked(offset, length); + } + + #[inline] + fn to_boxed(&self) -> Box { + Box::new(self.clone()) + } + }; +} + +/// Clones a dynamic [`Array`]. +/// # Implementation +/// This operation is `O(1)` over `len`, as it amounts to increase two ref counts +/// and moving the concrete struct under a `Box`. +pub fn clone(array: &dyn Array) -> Box { + use crate::datatypes::PhysicalType::*; + match array.dtype().to_physical_type() { + Null => clone_dyn!(array, NullArray), + Boolean => clone_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + clone_dyn!(array, PrimitiveArray<$T>) + }), + Binary => clone_dyn!(array, BinaryArray), + LargeBinary => clone_dyn!(array, BinaryArray), + FixedSizeBinary => clone_dyn!(array, FixedSizeBinaryArray), + Utf8 => clone_dyn!(array, Utf8Array::), + LargeUtf8 => clone_dyn!(array, Utf8Array::), + List => clone_dyn!(array, ListArray::), + LargeList => clone_dyn!(array, ListArray::), + FixedSizeList => clone_dyn!(array, FixedSizeListArray), + Struct => clone_dyn!(array, StructArray), + Union => clone_dyn!(array, UnionArray), + Map => clone_dyn!(array, MapArray), + BinaryView => clone_dyn!(array, BinaryViewArray), + Utf8View => clone_dyn!(array, Utf8ViewArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + clone_dyn!(array, DictionaryArray::<$T>) + }) + }, + } +} + +// see https://users.rust-lang.org/t/generic-for-dyn-a-or-box-dyn-a-or-arc-dyn-a/69430/3 +// for details +impl<'a> AsRef<(dyn Array + 'a)> for dyn Array { + fn as_ref(&self) -> &(dyn Array + 'a) { + self + } +} + +mod binary; +mod boolean; +pub mod builder; +mod dictionary; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +pub mod specification; +mod static_array; +mod static_array_collect; +mod struct_; +mod total_ord; +mod union; +mod utf8; + +mod equal; +mod ffi; +mod fmt; +#[doc(hidden)] +pub mod indexable; +pub mod iterator; + +mod binview; +mod values; + +pub use binary::{ + BinaryArray, BinaryArrayBuilder, BinaryValueIter, MutableBinaryArray, MutableBinaryValuesArray, +}; +pub use binview::{ + BinaryViewArray, BinaryViewArrayGeneric, BinaryViewArrayGenericBuilder, MutableBinaryViewArray, + MutablePlBinary, MutablePlString, Utf8ViewArray, View, ViewType, +}; +pub use boolean::{BooleanArray, BooleanArrayBuilder, MutableBooleanArray}; +pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray}; +pub use equal::equal; +pub use fixed_size_binary::{ + FixedSizeBinaryArray, FixedSizeBinaryArrayBuilder, MutableFixedSizeBinaryArray, +}; +pub use fixed_size_list::{ + FixedSizeListArray, FixedSizeListArrayBuilder, MutableFixedSizeListArray, +}; +pub use fmt::{get_display, get_value_display}; +pub(crate) use iterator::ArrayAccessor; +pub use iterator::ArrayValuesIter; +pub use list::{ListArray, ListArrayBuilder, ListValuesIter, MutableListArray}; +pub use map::MapArray; +pub use null::{MutableNullArray, NullArray, NullArrayBuilder}; +use polars_error::PolarsResult; +pub use primitive::*; +pub use static_array::{ParameterFreeDtypeStaticArray, StaticArray}; +pub use static_array_collect::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype}; +pub use struct_::{StructArray, StructArrayBuilder}; +pub use union::UnionArray; +pub use utf8::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array, Utf8ValuesIter}; +pub use values::ValueSize; + +pub(crate) use self::ffi::{FromFfi, ToFfi, offset_buffers_children_dictionary}; +use crate::{match_integer_type, with_match_primitive_type_full}; + +/// A trait describing the ability of a struct to create itself from a iterator. +/// This is similar to [`Extend`], but accepted the creation to error. +pub trait TryExtend { + /// Fallible version of [`Extend::extend`]. + fn try_extend>(&mut self, iter: I) -> PolarsResult<()>; +} + +/// A trait describing the ability of a struct to receive new items. +pub trait TryPush { + /// Tries to push a new element. + fn try_push(&mut self, item: A) -> PolarsResult<()>; +} + +/// A trait describing the ability of a struct to receive new items. +pub trait PushUnchecked { + /// Push a new element that holds the invariants of the struct. + /// + /// # Safety + /// The items must uphold the invariants of the struct + /// Read the specific implementation of the trait to understand what these are. + unsafe fn push_unchecked(&mut self, item: A); +} + +/// A trait describing the ability of a struct to extend from a reference of itself. +/// Specialization of [`TryExtend`]. +pub trait TryExtendFromSelf { + /// Tries to extend itself with elements from `other`, failing only on overflow. + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()>; +} + +/// Trait that [`BinaryArray`] and [`Utf8Array`] implement for the purposes of DRY. +/// # Safety +/// The implementer must ensure that +/// 1. `offsets.len() > 0` +/// 2. `offsets[i] >= offsets[i-1] for all i` +/// 3. `offsets[i] < values.len() for all i` +pub unsafe trait GenericBinaryArray: Array { + /// The values of the array + fn values(&self) -> &[u8]; + /// The offsets of the array + fn offsets(&self) -> &[O]; +} + +pub type ArrayRef = Box; + +impl Splitable for Option { + #[inline(always)] + fn check_bound(&self, offset: usize) -> bool { + self.as_ref().is_none_or(|v| offset <= v.len()) + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + self.as_ref().map_or((None, None), |bm| { + let (lhs, rhs) = unsafe { bm.split_at_unchecked(offset) }; + ( + (lhs.unset_bits() > 0).then_some(lhs), + (rhs.unset_bits() > 0).then_some(rhs), + ) + }) + } +} diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs new file mode 100644 index 000000000000..d35bf761c090 --- /dev/null +++ b/crates/polars-arrow/src/array/null.rs @@ -0,0 +1,290 @@ +use std::any::Any; + +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::IdxSize; + +use super::Splitable; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder}; +use crate::array::{Array, FromFfi, MutableArray, ToFfi}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::ffi; + +/// The concrete [`Array`] of [`ArrowDataType::Null`]. +#[derive(Clone)] +pub struct NullArray { + dtype: ArrowDataType, + + /// Validity mask. This is always all-zeroes. + validity: Bitmap, + + length: usize, +} + +impl NullArray { + /// Returns a new [`NullArray`]. + /// # Errors + /// This function errors iff: + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn try_new(dtype: ArrowDataType, length: usize) -> PolarsResult { + if dtype.to_physical_type() != PhysicalType::Null { + polars_bail!(ComputeError: "NullArray can only be initialized with a DataType whose physical type is Null"); + } + + let validity = Bitmap::new_zeroed(length); + + Ok(Self { + dtype, + validity, + length, + }) + } + + /// Returns a new [`NullArray`]. + /// # Panics + /// This function errors iff: + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(dtype: ArrowDataType, length: usize) -> Self { + Self::try_new(dtype, length).unwrap() + } + + /// Returns a new empty [`NullArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, 0) + } + + /// Returns a new [`NullArray`]. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + Self::new(dtype, length) + } + + impl_sliced!(); + impl_into_array!(); +} + +impl NullArray { + /// Returns a slice of the [`NullArray`]. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the arrays' length" + ); + unsafe { self.slice_unchecked(offset, length) }; + } + + /// Returns a slice of the [`NullArray`]. + /// + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.length = length; + self.validity.slice_unchecked(offset, length); + } + + #[inline] + pub fn len(&self) -> usize { + self.length + } +} + +impl Array for NullArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + Some(&self.validity) + } + + fn with_validity(&self, _: Option) -> Box { + // Nulls with invalid nulls are also nulls. + self.clone().boxed() + } +} + +#[derive(Debug)] +/// A distinct type to disambiguate +/// clashing methods +pub struct MutableNullArray { + inner: NullArray, +} + +impl MutableNullArray { + /// Returns a new [`MutableNullArray`]. + /// # Panics + /// This function errors iff: + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(dtype: ArrowDataType, length: usize) -> Self { + let inner = NullArray::try_new(dtype, length).unwrap(); + Self { inner } + } +} + +impl From for NullArray { + fn from(value: MutableNullArray) -> Self { + value.inner + } +} + +impl MutableArray for MutableNullArray { + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::Null + } + + fn len(&self) -> usize { + self.inner.length + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + self.inner.clone().boxed() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn push_null(&mut self) { + self.inner.length += 1; + } + + fn reserve(&mut self, _additional: usize) { + // no-op + } + + fn shrink_to_fit(&mut self) { + // no-op + } +} + +impl std::fmt::Debug for NullArray { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NullArray({})", self.len()) + } +} + +unsafe impl ToFfi for NullArray { + fn buffers(&self) -> Vec> { + // `None` is technically not required by the specification, but older C++ implementations require it, so leaving + // it here for backward compatibility + vec![None] + } + + fn offset(&self) -> Option { + Some(0) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl Splitable for NullArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs, rhs) = self.validity.split_at(offset); + + ( + Self { + dtype: self.dtype.clone(), + validity: lhs, + length: offset, + }, + Self { + dtype: self.dtype.clone(), + validity: rhs, + length: self.len() - offset, + }, + ) + } +} + +impl FromFfi for NullArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + Self::try_new(dtype, array.array().len()) + } +} + +pub struct NullArrayBuilder { + dtype: ArrowDataType, + length: usize, +} + +impl NullArrayBuilder { + pub fn new(dtype: ArrowDataType) -> Self { + Self { dtype, length: 0 } + } +} + +impl StaticArrayBuilder for NullArrayBuilder { + type Array = NullArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, _additional: usize) {} + + fn freeze(self) -> NullArray { + NullArray::new(self.dtype, self.length) + } + + fn freeze_reset(&mut self) -> Self::Array { + let out = NullArray::new(self.dtype.clone(), self.length); + self.length = 0; + out + } + + fn len(&self) -> usize { + self.length + } + + fn extend_nulls(&mut self, length: usize) { + self.length += length; + } + + fn subslice_extend( + &mut self, + _other: &NullArray, + _start: usize, + length: usize, + _share: ShareStrategy, + ) { + self.length += length; + } + + fn subslice_extend_repeated( + &mut self, + _other: &NullArray, + _start: usize, + length: usize, + repeats: usize, + _share: ShareStrategy, + ) { + self.length += length * repeats; + } + + unsafe fn gather_extend( + &mut self, + _other: &NullArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + self.length += idxs.len(); + } + + fn opt_gather_extend(&mut self, _other: &NullArray, idxs: &[IdxSize], _share: ShareStrategy) { + self.length += idxs.len(); + } +} diff --git a/crates/polars-arrow/src/array/physical_binary.rs b/crates/polars-arrow/src/array/physical_binary.rs new file mode 100644 index 000000000000..ba23f3412fa2 --- /dev/null +++ b/crates/polars-arrow/src/array/physical_binary.rs @@ -0,0 +1,234 @@ +use crate::bitmap::{BitmapBuilder, MutableBitmap}; +use crate::offset::{Offset, Offsets}; + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +#[allow(clippy::type_complexity)] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, Offsets, Vec), E> +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = BitmapBuilder::with_capacity(len); + let mut offsets = Vec::::with_capacity(len + 1); + let mut values = Vec::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + if let Some(item) = item? { + null.push_unchecked(true); + let s = item.as_ref(); + length += O::from_as_usize(s.len()); + values.extend_from_slice(s); + } else { + null.push_unchecked(false); + }; + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + + Ok(( + null.into_opt_mut_validity(), + Offsets::new_unchecked(offsets), + values, + )) +} + +/// Creates [`MutableBitmap`] and two [`Vec`]s from an iterator of `Option`. +/// The first buffer corresponds to a offset buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip( + iterator: I, +) -> (Option, Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = Offsets::::with_capacity(len); + let mut values = Vec::::new(); + let mut validity = MutableBitmap::new(); + + extend_from_trusted_len_iter(&mut offsets, &mut values, &mut validity, iterator); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + (validity, offsets, values) +} + +/// Creates two [`Buffer`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is [`TrustedLen`]. +#[inline] +pub(crate) unsafe fn trusted_len_values_iter(iterator: I) -> (Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = Offsets::::with_capacity(len); + let mut values = Vec::::new(); + + extend_from_trusted_len_values_iter(&mut offsets, &mut values, iterator); + + (offsets, values) +} + +// Populates `offsets` and `values` [`Vec`]s with information extracted +// from the incoming `iterator`. +// # Safety +// The caller must ensure the `iterator` is [`TrustedLen`] +#[inline] +pub(crate) unsafe fn extend_from_trusted_len_values_iter( + offsets: &mut Offsets, + values: &mut Vec, + iterator: I, +) where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let lengths = iterator.map(|item| { + let s = item.as_ref(); + // Push new entries for both `values` and `offsets` buffer + values.extend_from_slice(s); + s.len() + }); + offsets.try_extend_from_lengths(lengths).unwrap(); +} + +// Populates `offsets` and `values` [`Vec`]s with information extracted +// from the incoming `iterator`. +// the return value indicates how many items were added. +#[inline] +pub(crate) fn extend_from_values_iter( + offsets: &mut Offsets, + values: &mut Vec, + iterator: I, +) -> usize +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (size_hint, _) = iterator.size_hint(); + + offsets.reserve(size_hint); + + let start_index = offsets.len_proxy(); + + for item in iterator { + let bytes = item.as_ref(); + values.extend_from_slice(bytes); + offsets.try_push(bytes.len()).unwrap(); + } + offsets.len_proxy() - start_index +} + +// Populates `offsets`, `values`, and `validity` [`Vec`]s with +// information extracted from the incoming `iterator`. +// +// # Safety +// The caller must ensure that `iterator` is [`TrustedLen`] +#[inline] +pub(crate) unsafe fn extend_from_trusted_len_iter( + offsets: &mut Offsets, + values: &mut Vec, + validity: &mut MutableBitmap, + iterator: I, +) where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_from_trusted_len_iter requires an upper limit"); + + offsets.reserve(additional); + validity.reserve(additional); + + let lengths = iterator.map(|item| { + if let Some(item) = item { + let bytes = item.as_ref(); + values.extend_from_slice(bytes); + validity.push_unchecked(true); + bytes.len() + } else { + validity.push_unchecked(false); + 0 + } + }); + offsets.try_extend_from_lengths(lengths).unwrap(); +} + +/// Creates two [`Vec`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +#[inline] +pub(crate) fn values_iter(iterator: I) -> (Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (lower, _) = iterator.size_hint(); + + let mut offsets = Offsets::::with_capacity(lower); + let mut values = Vec::::new(); + + for item in iterator { + let s = item.as_ref(); + values.extend_from_slice(s); + offsets.try_push(s.len()).unwrap(); + } + (offsets, values) +} + +/// Extends `validity` with all items from `other` +pub(crate) fn extend_validity( + length: usize, + validity: &mut Option, + other: &Option, +) { + if let Some(other) = other { + if let Some(validity) = validity { + let slice = other.as_slice(); + // SAFETY: invariant offset + length <= slice.len() + unsafe { validity.extend_from_slice_unchecked(slice, 0, other.len()) } + } else { + let mut new_validity = MutableBitmap::from_len_set(length); + new_validity.extend_from_slice(other.as_slice(), 0, other.len()); + *validity = Some(new_validity); + } + } +} diff --git a/crates/polars-arrow/src/array/primitive/builder.rs b/crates/polars-arrow/src/array/primitive/builder.rs new file mode 100644 index 000000000000..405bead755f2 --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/builder.rs @@ -0,0 +1,109 @@ +use polars_utils::IdxSize; +use polars_utils::vec::PushUnchecked; + +use super::PrimitiveArray; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::OptBitmapBuilder; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::types::NativeType; + +pub struct PrimitiveArrayBuilder { + dtype: ArrowDataType, + values: Vec, + validity: OptBitmapBuilder, +} + +impl PrimitiveArrayBuilder { + pub fn new(dtype: ArrowDataType) -> Self { + Self { + dtype, + values: Vec::new(), + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for PrimitiveArrayBuilder { + type Array = PrimitiveArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + self.validity.reserve(additional); + } + + fn freeze(self) -> PrimitiveArray { + let values = Buffer::from(self.values); + let validity = self.validity.into_opt_validity(); + PrimitiveArray::new(self.dtype, values, validity) + } + + fn freeze_reset(&mut self) -> Self::Array { + let values = Buffer::from(core::mem::take(&mut self.values)); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + PrimitiveArray::new(self.dtype.clone(), values, validity) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn extend_nulls(&mut self, length: usize) { + self.values.resize(self.values.len() + length, T::zeroed()); + self.validity.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &PrimitiveArray, + start: usize, + length: usize, + _share: ShareStrategy, + ) { + self.values + .extend_from_slice(&other.values()[start..start + length]); + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + unsafe fn gather_extend( + &mut self, + other: &PrimitiveArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + // TODO: SIMD gather kernels? + let other_values_slice = other.values().as_slice(); + self.values.extend( + idxs.iter() + .map(|idx| *other_values_slice.get_unchecked(*idx as usize)), + ); + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + } + + fn opt_gather_extend( + &mut self, + other: &PrimitiveArray, + idxs: &[IdxSize], + _share: ShareStrategy, + ) { + self.values.reserve(idxs.len()); + unsafe { + for idx in idxs { + let val = if (*idx as usize) < other.len() { + other.value_unchecked(*idx as usize) + } else { + T::zeroed() + }; + self.values.push_unchecked(val); + } + } + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } +} diff --git a/crates/polars-arrow/src/array/primitive/ffi.rs b/crates/polars-arrow/src/array/primitive/ffi.rs new file mode 100644 index 000000000000..6dae1963dd74 --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/ffi.rs @@ -0,0 +1,57 @@ +use polars_error::PolarsResult; + +use super::PrimitiveArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::ffi; +use crate::types::NativeType; + +unsafe impl ToFfi for PrimitiveArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.storage_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.values.offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for PrimitiveArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + Self::try_new(dtype, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/primitive/fmt.rs b/crates/polars-arrow/src/array/primitive/fmt.rs new file mode 100644 index 000000000000..e09d62a1afb8 --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/fmt.rs @@ -0,0 +1,150 @@ +#![allow(clippy::redundant_closure_call)] +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::PrimitiveArray; +use crate::array::Array; +use crate::array::fmt::write_vec; +use crate::datatypes::{IntervalUnit, TimeUnit}; +use crate::temporal_conversions; +use crate::types::{NativeType, days_ms, i256, months_days_ns}; + +macro_rules! dyn_primitive { + ($array:expr, $ty:ty, $expr:expr) => {{ + let array = ($array as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(move |f, index| write!(f, "{}", $expr(array.value(index)))) + }}; +} + +pub fn get_write_value<'a, T: NativeType, F: Write>( + array: &'a PrimitiveArray, +) -> Box Result + 'a> { + use crate::datatypes::ArrowDataType::*; + match array.dtype().to_logical_type() { + Int8 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int16 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int128 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt8 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt16 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Float16 => unreachable!(), + Float32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Float64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Date32 => { + dyn_primitive!(array, i32, temporal_conversions::date32_to_date) + }, + Date64 => { + dyn_primitive!(array, i64, temporal_conversions::date64_to_date) + }, + Time32(TimeUnit::Second) => { + dyn_primitive!(array, i32, temporal_conversions::time32s_to_time) + }, + Time32(TimeUnit::Millisecond) => { + dyn_primitive!(array, i32, temporal_conversions::time32ms_to_time) + }, + Time32(_) => unreachable!(), // remaining are not valid + Time64(TimeUnit::Microsecond) => { + dyn_primitive!(array, i64, temporal_conversions::time64us_to_time) + }, + Time64(TimeUnit::Nanosecond) => { + dyn_primitive!(array, i64, temporal_conversions::time64ns_to_time) + }, + Time64(_) => unreachable!(), // remaining are not valid + Timestamp(time_unit, tz) => { + if let Some(tz) = tz { + let timezone = temporal_conversions::parse_offset(tz.as_str()); + match timezone { + Ok(timezone) => { + dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_datetime(time, *time_unit, &timezone) + }) + }, + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(tz.as_str()); + match timezone { + Ok(timezone) => dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_datetime( + time, *time_unit, &timezone, + ) + }), + Err(_) => { + let tz = tz.clone(); + Box::new(move |f, index| { + write!(f, "{} ({})", array.value(index), tz) + }) + }, + } + }, + #[cfg(not(feature = "chrono-tz"))] + _ => { + let tz = tz.clone(); + Box::new(move |f, index| write!(f, "{} ({})", array.value(index), tz)) + }, + } + } else { + dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_naive_datetime(time, *time_unit) + }) + } + }, + Interval(IntervalUnit::YearMonth) => { + dyn_primitive!(array, i32, |x| format!("{x}m")) + }, + Interval(IntervalUnit::DayTime) => { + dyn_primitive!(array, days_ms, |x: days_ms| format!( + "{}d{}ms", + x.days(), + x.milliseconds() + )) + }, + Interval(IntervalUnit::MonthDayNano) => { + dyn_primitive!(array, months_days_ns, |x: months_days_ns| format!( + "{}m{}d{}ns", + x.months(), + x.days(), + x.ns() + )) + }, + Duration(TimeUnit::Second) => dyn_primitive!(array, i64, |x| format!("{x}s")), + Duration(TimeUnit::Millisecond) => dyn_primitive!(array, i64, |x| format!("{x}ms")), + Duration(TimeUnit::Microsecond) => dyn_primitive!(array, i64, |x| format!("{x}us")), + Duration(TimeUnit::Nanosecond) => dyn_primitive!(array, i64, |x| format!("{x}ns")), + Decimal(_, scale) => { + // The number 999.99 has a precision of 5 and scale of 2 + let scale = *scale as u32; + let factor = 10i128.pow(scale); + let display = move |x: i128| { + let base = x / factor; + let decimals = (x - base * factor).abs(); + format!("{base}.{decimals}") + }; + dyn_primitive!(array, i128, display) + }, + Decimal256(_, scale) => { + let scale = *scale as u32; + let factor = (ethnum::I256::ONE * 10).pow(scale); + let display = move |x: i256| { + let base = x.0 / factor; + let decimals = (x.0 - base * factor).abs(); + format!("{base}.{decimals}") + }; + dyn_primitive!(array, i256, display) + }, + _ => unreachable!(), + } +} + +impl Debug for PrimitiveArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = get_write_value(self); + + write!(f, "{:?}", self.dtype())?; + write_vec(f, &*writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/primitive/from_natural.rs b/crates/polars-arrow/src/array/primitive/from_natural.rs new file mode 100644 index 000000000000..a70259a8eeff --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/from_natural.rs @@ -0,0 +1,14 @@ +use super::{MutablePrimitiveArray, PrimitiveArray}; +use crate::types::NativeType; + +impl]>> From

for PrimitiveArray { + fn from(slice: P) -> Self { + MutablePrimitiveArray::::from(slice).into() + } +} + +impl>> FromIterator for PrimitiveArray { + fn from_iter>(iter: I) -> Self { + MutablePrimitiveArray::::from_iter(iter).into() + } +} diff --git a/crates/polars-arrow/src/array/primitive/iterator.rs b/crates/polars-arrow/src/array/primitive/iterator.rs new file mode 100644 index 000000000000..5b811b0b0c70 --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/iterator.rs @@ -0,0 +1,61 @@ +use super::{MutablePrimitiveArray, PrimitiveArray}; +use crate::array::{ArrayAccessor, MutableArray}; +use crate::bitmap::IntoIter as BitmapIntoIter; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::buffer::IntoIter; +use crate::types::NativeType; + +unsafe impl<'a, T: NativeType> ArrayAccessor<'a> for [T] { + type Item = T; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + *self.get_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + (*self).len() + } +} + +impl IntoIterator for PrimitiveArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIntoIter>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + let (_, values, validity) = self.into_inner(); + let values = values.into_iter(); + let validity = + validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.into_iter())); + ZipValidity::new(values, validity) + } +} + +impl<'a, T: NativeType> IntoIterator for &'a PrimitiveArray { + type Item = Option<&'a T>; + type IntoIter = ZipValidity<&'a T, std::slice::Iter<'a, T>, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T: NativeType> MutablePrimitiveArray { + /// Returns an iterator over `Option` + #[inline] + pub fn iter(&'a self) -> ZipValidity<&'a T, std::slice::Iter<'a, T>, BitmapIter<'a>> { + ZipValidity::new( + self.values().iter(), + self.validity().as_ref().map(|x| x.iter()), + ) + } + + /// Returns an iterator of `T` + #[inline] + pub fn values_iter(&'a self) -> std::slice::Iter<'a, T> { + self.values().iter() + } +} diff --git a/crates/polars-arrow/src/array/primitive/mod.rs b/crates/polars-arrow/src/array/primitive/mod.rs new file mode 100644 index 000000000000..7ceb2b3bb473 --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/mod.rs @@ -0,0 +1,631 @@ +use std::ops::Range; + +use either::Either; + +use super::{Array, Splitable}; +use crate::array::iterator::NonNullValuesIter; +use crate::bitmap::Bitmap; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::buffer::Buffer; +use crate::datatypes::*; +use crate::trusted_len::TrustedLen; +use crate::types::{NativeType, days_ms, f16, i256, months_days_ns}; + +mod ffi; +pub(super) mod fmt; +mod from_natural; +pub mod iterator; + +mod mutable; +pub use mutable::*; +mod builder; +pub use builder::*; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::index::{Bounded, Indexable, NullCount}; +use polars_utils::slice::SliceAble; + +/// A [`PrimitiveArray`] is Arrow's semantically equivalent of an immutable `Vec>` where +/// T is [`NativeType`] (e.g. [`i32`]). It implements [`Array`]. +/// +/// One way to think about a [`PrimitiveArray`] is `(DataType, Arc>, Option>>)` +/// where: +/// * the first item is the array's logical type +/// * the second is the immutable values +/// * the third is the immutable validity (whether a value is null or not as a bitmap). +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. +/// # Example +/// ``` +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::bitmap::Bitmap; +/// use polars_arrow::buffer::Buffer; +/// +/// let array = PrimitiveArray::from([Some(1i32), None, Some(10)]); +/// assert_eq!(array.value(0), 1); +/// assert_eq!(array.iter().collect::>(), vec![Some(&1i32), None, Some(&10)]); +/// assert_eq!(array.values_iter().copied().collect::>(), vec![1, 0, 10]); +/// // the underlying representation +/// assert_eq!(array.values(), &Buffer::from(vec![1i32, 0, 10])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// +/// ``` +#[derive(Clone)] +pub struct PrimitiveArray { + dtype: ArrowDataType, + values: Buffer, + validity: Option, +} + +pub(super) fn check( + dtype: &ArrowDataType, + values: &[T], + validity_len: Option, +) -> PolarsResult<()> { + if validity_len.is_some_and(|len| len != values.len()) { + polars_bail!(ComputeError: "validity mask length must match the number of values") + } + + if dtype.to_physical_type() != PhysicalType::Primitive(T::PRIMITIVE) { + polars_bail!(ComputeError: "PrimitiveArray can only be initialized with a DataType whose physical type is Primitive") + } + Ok(()) +} + +impl PrimitiveArray { + /// The canonical method to create a [`PrimitiveArray`] out of its internal components. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + pub fn try_new( + dtype: ArrowDataType, + values: Buffer, + validity: Option, + ) -> PolarsResult { + check(&dtype, &values, validity.as_ref().map(|v| v.len()))?; + Ok(Self { + dtype, + values, + validity, + }) + } + + /// # Safety + /// Doesn't check invariants + pub unsafe fn new_unchecked( + dtype: ArrowDataType, + values: Buffer, + validity: Option, + ) -> Self { + if cfg!(debug_assertions) { + check(&dtype, &values, validity.as_ref().map(|v| v.len())).unwrap(); + } + + Self { + dtype, + values, + validity, + } + } + + /// Returns a new [`PrimitiveArray`] with a different logical type. + /// + /// This function is useful to assign a different [`ArrowDataType`] to the array. + /// Used to change the arrays' logical type (see example). + /// # Example + /// ``` + /// use polars_arrow::array::Int32Array; + /// use polars_arrow::datatypes::ArrowDataType; + /// + /// let array = Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Date32); + /// assert_eq!( + /// format!("{:?}", array), + /// "Date32[1970-01-02, None, 1970-01-03]" + /// ); + /// ``` + /// # Panics + /// Panics iff the `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + #[inline] + #[must_use] + pub fn to(self, dtype: ArrowDataType) -> Self { + check( + &dtype, + &self.values, + self.validity.as_ref().map(|v| v.len()), + ) + .unwrap(); + Self { + dtype, + values: self.values, + validity: self.validity, + } + } + + /// Creates a (non-null) [`PrimitiveArray`] from a vector of values. + /// This function is `O(1)`. + /// # Examples + /// ``` + /// use polars_arrow::array::PrimitiveArray; + /// + /// let array = PrimitiveArray::from_vec(vec![1, 2, 3]); + /// assert_eq!(format!("{:?}", array), "Int32[1, 2, 3]"); + /// ``` + pub fn from_vec(values: Vec) -> Self { + Self::new(T::PRIMITIVE.into(), values.into(), None) + } + + /// Returns an iterator over the values and validity, `Option<&T>`. + #[inline] + pub fn iter(&self) -> ZipValidity<&T, std::slice::Iter, BitmapIter> { + ZipValidity::new_with_validity(self.values().iter(), self.validity()) + } + + /// Returns an iterator of the values, `&T`, ignoring the arrays' validity. + #[inline] + pub fn values_iter(&self) -> std::slice::Iter { + self.values().iter() + } + + /// Returns an iterator of the non-null values `T`. + #[inline] + pub fn non_null_values_iter(&self) -> NonNullValuesIter<'_, [T]> { + NonNullValuesIter::new(self.values(), self.validity()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// The values [`Buffer`]. + /// Values on null slots are undetermined (they can be anything). + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the arrays' [`ArrowDataType`]. + #[inline] + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + /// Returns the value at slot `i`. + /// + /// Equivalent to `self.values()[i]`. The value of a null slot is undetermined (it can be anything). + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> T { + self.values[i] + } + + /// Returns the value at index `i`. + /// The value on null slots is undetermined (it can be anything). + /// + /// # Safety + /// Caller must be sure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> T { + *self.values.get_unchecked(i) + } + + // /// Returns the element at index `i` or `None` if it is null + // /// # Panics + // /// iff `i >= self.len()` + // #[inline] + // pub fn get(&self, i: usize) -> Option { + // if !self.is_null(i) { + // // soundness: Array::is_null panics if i >= self.len + // unsafe { Some(self.value_unchecked(i)) } + // } else { + // None + // } + // } + + /// Slices this [`PrimitiveArray`] by an offset and length. + /// # Implementation + /// This operation is `O(1)`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "offset + length may not exceed length of array" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`PrimitiveArray`] by an offset and length. + /// # Implementation + /// This operation is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values.slice_unchecked(offset, length); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns this [`PrimitiveArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(mut self, values: Buffer) -> Self { + self.set_values(values); + self + } + + /// Update the values of this [`PrimitiveArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + pub fn set_values(&mut self, values: Buffer) { + assert_eq!( + values.len(), + self.len(), + "values' length must be equal to this arrays' length" + ); + self.values = values; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity Bitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + /// Returns an option of a mutable reference to the values of this [`PrimitiveArray`]. + pub fn get_mut_values(&mut self) -> Option<&mut [T]> { + self.values.get_mut_slice() + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (ArrowDataType, Buffer, Option) { + let Self { + dtype, + values, + validity, + } = self; + (dtype, values, validity) + } + + /// Creates a [`PrimitiveArray`] from its internal representation. + /// This is the inverted from [`PrimitiveArray::into_inner`] + pub fn from_inner( + dtype: ArrowDataType, + values: Buffer, + validity: Option, + ) -> PolarsResult { + check(&dtype, &values, validity.as_ref().map(|v| v.len()))?; + Ok(unsafe { Self::from_inner_unchecked(dtype, values, validity) }) + } + + /// Creates a [`PrimitiveArray`] from its internal representation. + /// This is the inverted from [`PrimitiveArray::into_inner`] + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. + pub unsafe fn from_inner_unchecked( + dtype: ArrowDataType, + values: Buffer, + validity: Option, + ) -> Self { + Self { + dtype, + values, + validity, + } + } + + /// Try to convert this [`PrimitiveArray`] to a [`MutablePrimitiveArray`] via copy-on-write semantics. + /// + /// A [`PrimitiveArray`] is backed by a [`Buffer`] and [`Bitmap`] which are essentially `Arc>`. + /// This function returns a [`MutablePrimitiveArray`] (via [`std::sync::Arc::get_mut`]) iff both values + /// and validity have not been cloned / are unique references to their underlying vectors. + /// + /// This function is primarily used to reuse memory regions. + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + Left(bitmap) => Left(PrimitiveArray::new(self.dtype, self.values, Some(bitmap))), + Right(mutable_bitmap) => match self.values.into_mut() { + Right(values) => Right( + MutablePrimitiveArray::try_new(self.dtype, values, Some(mutable_bitmap)) + .unwrap(), + ), + Left(values) => Left(PrimitiveArray::new( + self.dtype, + values, + Some(mutable_bitmap.into()), + )), + }, + } + } else { + match self.values.into_mut() { + Right(values) => { + Right(MutablePrimitiveArray::try_new(self.dtype, values, None).unwrap()) + }, + Left(values) => Left(PrimitiveArray::new(self.dtype, values, None)), + } + } + } + + /// Returns a new empty (zero-length) [`PrimitiveArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, Buffer::new(), None) + } + + /// Returns a new [`PrimitiveArray`] where all slots are null / `None`. + #[inline] + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + Self::new( + dtype, + vec![T::default(); length].into(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Creates a (non-null) [`PrimitiveArray`] from an iterator of values. + /// # Implementation + /// This does not assume that the iterator has a known length. + pub fn from_values>(iter: I) -> Self { + Self::new(T::PRIMITIVE.into(), Vec::::from_iter(iter).into(), None) + } + + /// Creates a (non-null) [`PrimitiveArray`] from a slice of values. + /// # Implementation + /// This is essentially a memcopy and is thus `O(N)` + pub fn from_slice>(slice: P) -> Self { + Self::new( + T::PRIMITIVE.into(), + Vec::::from(slice.as_ref()).into(), + None, + ) + } + + /// Creates a (non-null) [`PrimitiveArray`] from a [`TrustedLen`] of values. + /// # Implementation + /// This does not assume that the iterator has a known length. + pub fn from_trusted_len_values_iter>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_values_iter(iter).into() + } + + /// Creates a new [`PrimitiveArray`] from an iterator over values + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_values_iter_unchecked(iter).into() + } + + /// Creates a [`PrimitiveArray`] from a [`TrustedLen`] of optional values. + pub fn from_trusted_len_iter>>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_iter(iter).into() + } + + /// Creates a [`PrimitiveArray`] from an iterator of optional values. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_iter_unchecked>>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_iter_unchecked(iter).into() + } + + /// Alias for `Self::try_new(..).unwrap()`. + /// # Panics + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive`]. + pub fn new(dtype: ArrowDataType, values: Buffer, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() + } + + /// Transmute this PrimitiveArray into another PrimitiveArray. + /// + /// T and U must have the same size and alignment. + pub fn transmute(self) -> PrimitiveArray { + let PrimitiveArray { + values, validity, .. + } = self; + + // SAFETY: this is fine, we checked size and alignment, and NativeType + // is always Pod. + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + let new_values = unsafe { std::mem::transmute::, Buffer>(values) }; + PrimitiveArray::new(U::PRIMITIVE.into(), new_values, validity) + } + + /// Fills this entire array with the given value, leaving the validity mask intact. + /// + /// Reuses the memory of the PrimitiveArray if possible. + pub fn fill_with(mut self, value: T) -> Self { + if let Some(values) = self.get_mut_values() { + for x in values.iter_mut() { + *x = value; + } + self + } else { + let values = vec![value; self.len()]; + Self::new(T::PRIMITIVE.into(), values.into(), self.validity) + } + } +} + +impl Array for PrimitiveArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for PrimitiveArray { + #[inline(always)] + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_values, rhs_values) = unsafe { self.values.split_at_unchecked(offset) }; + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + ( + Self { + dtype: self.dtype.clone(), + values: lhs_values, + validity: lhs_validity, + }, + Self { + dtype: self.dtype.clone(), + values: rhs_values, + validity: rhs_validity, + }, + ) + } +} + +impl SliceAble for PrimitiveArray { + unsafe fn slice_unchecked(&self, range: Range) -> Self { + self.clone().sliced_unchecked(range.start, range.len()) + } + + fn slice(&self, range: Range) -> Self { + self.clone().sliced(range.start, range.len()) + } +} + +impl Indexable for PrimitiveArray { + type Item = Option; + + fn get(&self, i: usize) -> Self::Item { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + unsafe fn get_unchecked(&self, i: usize) -> Self::Item { + if !self.is_null_unchecked(i) { + Some(self.value_unchecked(i)) + } else { + None + } + } +} + +/// A type definition [`PrimitiveArray`] for `i8` +pub type Int8Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i16` +pub type Int16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i32` +pub type Int32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i64` +pub type Int64Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i128` +pub type Int128Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i256` +pub type Int256Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for [`days_ms`] +pub type DaysMsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for [`months_days_ns`] +pub type MonthsDaysNsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f16` +pub type Float16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f32` +pub type Float32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f64` +pub type Float64Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u8` +pub type UInt8Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u16` +pub type UInt16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u32` +pub type UInt32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u64` +pub type UInt64Array = PrimitiveArray; + +/// A type definition [`MutablePrimitiveArray`] for `i8` +pub type Int8Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i16` +pub type Int16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i32` +pub type Int32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i64` +pub type Int64Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i128` +pub type Int128Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i256` +pub type Int256Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for [`days_ms`] +pub type DaysMsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for [`months_days_ns`] +pub type MonthsDaysNsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f16` +pub type Float16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f32` +pub type Float32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f64` +pub type Float64Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u8` +pub type UInt8Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u16` +pub type UInt16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u32` +pub type UInt32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u64` +pub type UInt64Vec = MutablePrimitiveArray; + +impl Default for PrimitiveArray { + fn default() -> Self { + PrimitiveArray::new(T::PRIMITIVE.into(), Default::default(), None) + } +} + +impl Bounded for PrimitiveArray { + fn len(&self) -> usize { + self.values.len() + } +} + +impl NullCount for PrimitiveArray { + fn null_count(&self) -> usize { + ::null_count(self) + } +} diff --git a/crates/polars-arrow/src/array/primitive/mutable.rs b/crates/polars-arrow/src/array/primitive/mutable.rs new file mode 100644 index 000000000000..24486235ef53 --- /dev/null +++ b/crates/polars-arrow/src/array/primitive/mutable.rs @@ -0,0 +1,696 @@ +use std::sync::Arc; + +use polars_error::PolarsResult; + +use super::{PrimitiveArray, check}; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::ArrowDataType; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +/// The Arrow's equivalent to `Vec>` where `T` is byte-size (e.g. `i32`). +/// Converting a [`MutablePrimitiveArray`] into a [`PrimitiveArray`] is `O(1)`. +#[derive(Debug, Clone)] +pub struct MutablePrimitiveArray { + dtype: ArrowDataType, + values: Vec, + validity: Option, +} + +impl From> for PrimitiveArray { + fn from(other: MutablePrimitiveArray) -> Self { + let validity = other.validity.and_then(|x| { + let bitmap: Bitmap = x.into(); + if bitmap.unset_bits() == 0 { + None + } else { + Some(bitmap) + } + }); + + PrimitiveArray::::new(other.dtype, other.values.into(), validity) + } +} + +impl]>> From

for MutablePrimitiveArray { + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } +} + +impl MutablePrimitiveArray { + /// Creates a new empty [`MutablePrimitiveArray`]. + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Creates a new [`MutablePrimitiveArray`] with a capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacity_from(capacity, T::PRIMITIVE.into()) + } + + /// The canonical method to create a [`MutablePrimitiveArray`] out of its internal components. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Primitive(T::PRIMITIVE)`] + pub fn try_new( + dtype: ArrowDataType, + values: Vec, + validity: Option, + ) -> PolarsResult { + check(&dtype, &values, validity.as_ref().map(|x| x.len()))?; + Ok(Self { + dtype, + values, + validity, + }) + } + + /// Extract the low-end APIs from the [`MutablePrimitiveArray`]. + pub fn into_inner(self) -> (ArrowDataType, Vec, Option) { + (self.dtype, self.values, self.validity) + } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics iff `f` panics + pub fn apply_values(&mut self, f: F) { + f(&mut self.values); + } +} + +impl Default for MutablePrimitiveArray { + fn default() -> Self { + Self::new() + } +} + +impl From for MutablePrimitiveArray { + fn from(dtype: ArrowDataType) -> Self { + assert!(dtype.to_physical_type().eq_primitive(T::PRIMITIVE)); + Self { + dtype, + values: Vec::::new(), + validity: None, + } + } +} + +impl MutablePrimitiveArray { + /// Creates a new [`MutablePrimitiveArray`] from a capacity and [`ArrowDataType`]. + pub fn with_capacity_from(capacity: usize, dtype: ArrowDataType) -> Self { + assert!(dtype.to_physical_type().eq_primitive(T::PRIMITIVE)); + Self { + dtype, + values: Vec::::with_capacity(capacity), + validity: None, + } + } + + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + #[inline] + pub fn push_value(&mut self, value: T) { + self.values.push(value); + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + /// Adds a new value to the array. + #[inline] + pub fn push(&mut self, value: Option) { + match value { + Some(value) => self.push_value(value), + None => { + self.values.push(T::default()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => { + self.init_validity(); + }, + } + }, + } + } + + /// Pop a value from the array. + /// Note if the values is empty, this method will return None. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| value)) + .unwrap_or_else(|| Some(value)) + } + + /// Extends the [`MutablePrimitiveArray`] with a constant + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: Option) { + if let Some(value) = value { + self.values.resize(self.values.len() + additional, value); + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, true) + } + } else { + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, false) + } else { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.extend_constant(additional, false); + self.validity = Some(validity) + } + self.values + .resize(self.values.len() + additional, T::default()); + } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: Iterator>, + { + if let Some(validity) = self.validity.as_mut() { + extend_trusted_len_unzip(iterator, validity, &mut self.values) + } else { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + extend_trusted_len_unzip(iterator, &mut validity, &mut self.values); + self.validity = Some(validity); + } + } + /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len` which accepts in iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + { + unsafe { self.extend_values(iterator) } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len_unchecked` which accepts in iterator of optional values. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + I: Iterator, + { + self.values.extend(iterator); + self.update_all_valid(); + } + + #[inline] + /// Extends the [`MutablePrimitiveArray`] from a slice + pub fn extend_from_slice(&mut self, items: &[T]) { + self.values.extend_from_slice(items); + self.update_all_valid(); + } + + fn update_all_valid(&mut self) { + // get len before mutable borrow + let len = self.len(); + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(len - validity.len(), true); + } + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Changes the arrays' [`ArrowDataType`], returning a new [`MutablePrimitiveArray`]. + /// Use to change the logical type without changing the corresponding physical Type. + /// # Implementation + /// This operation is `O(1)`. + #[inline] + pub fn to(self, dtype: ArrowDataType) -> Self { + Self::try_new(dtype, self.values, self.validity).unwrap() + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: PrimitiveArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutablePrimitiveArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + /// Returns the capacity of this [`MutablePrimitiveArray`]. + pub fn capacity(&self) -> usize { + self.values.capacity() + } + + pub fn freeze(self) -> PrimitiveArray { + self.into() + } + + /// Clears the array, removing all values. + /// + /// Note that this method has no effect on the allocated capacity + /// of the array. + pub fn clear(&mut self) { + self.values.clear(); + self.validity = None; + } + + /// Apply a function that temporarily freezes this `MutableArray` into a `PrimitiveArray`. + pub fn with_freeze) -> K>(&mut self, f: F) -> K { + let mutable = std::mem::take(self); + let arr = mutable.freeze(); + let out = f(&arr); + *self = arr.into_mut().right().unwrap(); + out + } +} + +/// Accessors +impl MutablePrimitiveArray { + /// Returns its values. + pub fn values(&self) -> &Vec { + &self.values + } + + /// Returns a mutable slice of values. + pub fn values_mut_slice(&mut self) -> &mut [T] { + self.values.as_mut_slice() + } +} + +/// Setters +impl MutablePrimitiveArray { + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Panic + /// Panics iff `index >= self.len()`. + pub fn set(&mut self, index: usize, value: Option) { + assert!(index < self.len()); + // SAFETY: + // we just checked bounds + unsafe { self.set_unchecked(index, value) } + } + + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// + /// # Safety + /// Caller must ensure `index < self.len()` + pub unsafe fn set_unchecked(&mut self, index: usize, value: Option) { + *self.values.get_unchecked_mut(index) = value.unwrap_or_default(); + + if value.is_none() && self.validity.is_none() { + // When the validity is None, all elements so far are valid. When one of the elements is set of null, + // the validity must be initialized. + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + if let Some(x) = self.validity.as_mut() { + x.set_unchecked(index, value.is_some()) + } + } + + /// Sets the validity. + /// # Panic + /// Panics iff the validity's len is not equal to the existing values' length. + pub fn set_validity(&mut self, validity: Option) { + if let Some(validity) = &validity { + assert_eq!(self.values.len(), validity.len()) + } + self.validity = validity; + } + + /// Sets values. + /// # Panic + /// Panics iff the values' length is not equal to the existing values' len. + pub fn set_values(&mut self, values: Vec) { + assert_eq!(values.len(), self.values.len()); + self.values = values; + } +} + +impl Extend> for MutablePrimitiveArray { + fn extend>>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + iter.for_each(|x| self.push(x)) + } +} + +impl TryExtend> for MutablePrimitiveArray { + /// This is infallible and is implemented for consistency with all other types + fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { + self.extend(iter); + Ok(()) + } +} + +impl TryPush> for MutablePrimitiveArray { + /// This is infalible and is implemented for consistency with all other types + #[inline] + fn try_push(&mut self, item: Option) -> PolarsResult<()> { + self.push(item); + Ok(()) + } +} + +impl MutableArray for MutablePrimitiveArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + PrimitiveArray::new( + self.dtype.clone(), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + PrimitiveArray::new( + self.dtype.clone(), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl MutablePrimitiveArray { + /// Creates a [`MutablePrimitiveArray`] from a slice of values. + pub fn from_slice>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter().copied()) + } + + /// Creates a [`MutablePrimitiveArray`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + let (validity, values) = trusted_len_unzip(iterator); + + Self { + dtype: T::PRIMITIVE.into(), + values, + validity, + } + } + + /// Creates a [`MutablePrimitiveArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iter: I, + ) -> std::result::Result + where + P: std::borrow::Borrow, + I: IntoIterator, E>>, + { + let iterator = iter.into_iter(); + + let (validity, values) = try_trusted_len_unzip(iterator)?; + + Ok(Self { + dtype: T::PRIMITIVE.into(), + values, + validity, + }) + } + + /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutablePrimitiveArray`] out an iterator over values + pub fn from_trusted_len_values_iter>(iter: I) -> Self { + Self { + dtype: T::PRIMITIVE.into(), + values: iter.collect(), + validity: None, + } + } + + /// Creates a (non-null) [`MutablePrimitiveArray`] from a vector of values. + /// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. + pub fn from_vec(values: Vec) -> Self { + Self::try_new(T::PRIMITIVE.into(), values, None).unwrap() + } + + /// Creates a new [`MutablePrimitiveArray`] from an iterator over values + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { + Self { + dtype: T::PRIMITIVE.into(), + values: iter.collect(), + validity: None, + } + } +} + +impl>> FromIterator + for MutablePrimitiveArray +{ + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: Vec = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + T::default() + } + }) + .collect(); + + let validity = Some(validity); + + Self { + dtype: T::PRIMITIVE.into(), + values, + validity, + } + } +} + +/// Extends a [`MutableBitmap`] and a [`Vec`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn extend_trusted_len_unzip( + iterator: I, + validity: &mut MutableBitmap, + buffer: &mut Vec, +) where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("trusted_len_unzip requires an upper limit"); + + validity.reserve(additional); + let values = iterator.map(|item| { + if let Some(item) = item { + validity.push_unchecked(true); + *item.borrow() + } else { + validity.push_unchecked(false); + T::default() + } + }); + buffer.extend(values); +} + +/// Creates a [`MutableBitmap`] and a [`Vec`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, Vec) +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let mut validity = MutableBitmap::new(); + let mut buffer = Vec::::new(); + + extend_trusted_len_unzip(iterator, &mut validity, &mut buffer); + + let validity = Some(validity); + + (validity, buffer) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, Vec), E> +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut buffer = Vec::::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let item = if let Some(item) = item? { + null.push(true); + *item.borrow() + } else { + null.push(false); + T::default() + }; + std::ptr::write(dst, item); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + null.set_len(len); + + let validity = Some(null); + + Ok((validity, buffer)) +} + +impl PartialEq for MutablePrimitiveArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutablePrimitiveArray { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + self.values.extend_from_slice(slice); + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/specification.rs b/crates/polars-arrow/src/array/specification.rs new file mode 100644 index 000000000000..edbc494e1efa --- /dev/null +++ b/crates/polars-arrow/src/array/specification.rs @@ -0,0 +1,174 @@ +use polars_error::{PolarsResult, polars_bail, polars_err, to_compute_err}; + +use crate::array::DictionaryKey; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +/// Helper trait to support `Offset` and `OffsetBuffer` +pub trait OffsetsContainer { + fn last(&self) -> usize; + fn as_slice(&self) -> &[O]; +} + +impl OffsetsContainer for OffsetsBuffer { + #[inline] + fn last(&self) -> usize { + self.last().to_usize() + } + + #[inline] + fn as_slice(&self) -> &[O] { + self.buffer() + } +} + +impl OffsetsContainer for Offsets { + #[inline] + fn last(&self) -> usize { + self.last().to_usize() + } + + #[inline] + fn as_slice(&self) -> &[O] { + self.as_slice() + } +} + +pub(crate) fn try_check_offsets_bounds( + offsets: &[O], + values_len: usize, +) -> PolarsResult<()> { + if offsets.last().unwrap().to_usize() > values_len { + polars_bail!(ComputeError: "offsets must not exceed the values length") + } else { + Ok(()) + } +} + +/// # Error +/// * any offset is larger or equal to `values_len`. +/// * any slice of `values` between two consecutive pairs from `offsets` is invalid `utf8`, or +pub fn try_check_utf8(offsets: &[O], values: &[u8]) -> PolarsResult<()> { + if offsets.len() == 1 { + return Ok(()); + } + assert!(offsets.len() > 1); + let end = offsets.last().unwrap().to_usize(); + let start = offsets.first().unwrap().to_usize(); + + try_check_offsets_bounds(offsets, values.len())?; + let values_range = &values[start..end]; + + if values_range.is_ascii() { + Ok(()) + } else { + simdutf8::basic::from_utf8(values_range).map_err(to_compute_err)?; + + // offsets can be == values.len() + // find first offset from the end that is smaller + // Example: + // values.len() = 10 + // offsets = [0, 5, 10, 10] + let last = offsets + .iter() + .enumerate() + .skip(1) + .rev() + .find_map(|(i, offset)| (offset.to_usize() < values.len()).then(|| i)); + + let last = if let Some(last) = last { + // following the example: last = 1 (offset = 5) + last + } else { + // given `l = values.len()`, this branch is hit iff either: + // * `offsets = [0, l, l, ...]`, which was covered by `from_utf8(values)` above + // * `offsets = [0]`, which never happens because offsets.as_slice().len() == 1 is short-circuited above + return Ok(()); + }; + + // truncate to relevant offsets. Note: `=last` because last was computed skipping the first item + // following the example: starts = [0, 5] + let starts = unsafe { offsets.get_unchecked(..=last) }; + + let mut any_invalid = false; + for start in starts { + let start = start.to_usize(); + + // SAFETY: `try_check_offsets_bounds` just checked for bounds + let b = *unsafe { values.get_unchecked(start) }; + + // A valid code-point iff it does not start with 0b10xxxxxx + // Bit-magic taken from `std::str::is_char_boundary` + any_invalid |= (b as i8) < -0x40; + } + if any_invalid { + polars_bail!(ComputeError: "non-valid char boundary detected") + } + Ok(()) + } +} + +/// Check dictionary indexes without checking usize conversion. +/// # Safety +/// The caller must ensure that `K::as_usize` always succeeds. +pub(crate) unsafe fn check_indexes_unchecked( + keys: &[K], + len: usize, +) -> PolarsResult<()> { + let mut invalid = false; + + // this loop is auto-vectorized + keys.iter().for_each(|k| invalid |= k.as_usize() > len); + + if invalid { + let key = keys.iter().map(|k| k.as_usize()).max().unwrap(); + polars_bail!(ComputeError: "one of the dictionary keys is {key} but it must be < than the length of the dictionary values, which is {len}") + } else { + Ok(()) + } +} + +pub fn check_indexes(keys: &[K], len: usize) -> PolarsResult<()> +where + K: std::fmt::Debug + Copy + TryInto, +{ + keys.iter().try_for_each(|key| { + let key: usize = (*key) + .try_into() + .map_err(|_| polars_err!(ComputeError: "The dictionary key must fit in a `usize`, but {key:?} does not") + )?; + if key >= len { + polars_bail!(ComputeError: "one of the dictionary keys is {key} but it must be < than the length of the dictionary values, which is {len}") + } else { + Ok(()) + } + }) +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + pub(crate) fn binary_strategy() -> impl Strategy> { + prop::collection::vec(any::(), 1..100) + } + + proptest! { + // a bit expensive, feel free to run it when changing the code above + // #![proptest_config(ProptestConfig::with_cases(100000))] + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well + fn check_utf8_validation(values in binary_strategy()) { + + for offset in 0..values.len() - 1 { + let offsets: OffsetsBuffer = vec![0, offset as i32, values.len() as i32].try_into().unwrap(); + + let mut is_valid = std::str::from_utf8(&values[..offset]).is_ok(); + is_valid &= std::str::from_utf8(&values[offset..]).is_ok(); + + assert_eq!(try_check_utf8::(&offsets, &values).is_ok(), is_valid) + } + } + } +} diff --git a/crates/polars-arrow/src/array/static_array.rs b/crates/polars-arrow/src/array/static_array.rs new file mode 100644 index 000000000000..262224788395 --- /dev/null +++ b/crates/polars-arrow/src/array/static_array.rs @@ -0,0 +1,418 @@ +use bytemuck::Zeroable; +use polars_utils::no_call_const; + +use crate::array::binview::BinaryViewValueIter; +use crate::array::builder::{ShareStrategy, StaticArrayBuilder, make_builder}; +use crate::array::fixed_size_list::FixedSizeListArrayBuilder; +use crate::array::static_array_collect::ArrayFromIterDtype; +use crate::array::{ + Array, ArrayValuesIter, BinaryArray, BinaryValueIter, BinaryViewArray, BooleanArray, + FixedSizeListArray, ListArray, ListValuesIter, MutableBinaryViewArray, PrimitiveArray, + StructArray, Utf8Array, Utf8ValuesIter, Utf8ViewArray, +}; +use crate::bitmap::Bitmap; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::datatypes::ArrowDataType; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +pub trait StaticArray: + Array + + for<'a> ArrayFromIterDtype> + + for<'a> ArrayFromIterDtype> + + for<'a> ArrayFromIterDtype>> + + Clone +{ + type ValueT<'a>: Clone + where + Self: 'a; + type ZeroableValueT<'a>: Zeroable + From> + where + Self: 'a; + type ValueIterT<'a>: DoubleEndedIterator> + TrustedLen + Send + Sync + where + Self: 'a; + + #[inline] + fn get(&self, idx: usize) -> Option> { + if idx >= self.len() { + None + } else { + unsafe { self.get_unchecked(idx) } + } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + unsafe fn get_unchecked(&self, idx: usize) -> Option> { + if self.is_null_unchecked(idx) { + None + } else { + Some(self.value_unchecked(idx)) + } + } + + #[inline] + fn last(&self) -> Option> { + unsafe { self.get_unchecked(self.len().checked_sub(1)?) } + } + + #[inline] + fn value(&self, idx: usize) -> Self::ValueT<'_> { + assert!(idx < self.len()); + unsafe { self.value_unchecked(idx) } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[allow(unused_variables)] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + no_call_const!() + } + + #[inline(always)] + fn as_slice(&self) -> Option<&[Self::ValueT<'_>]> { + None + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + no_call_const!() + } + fn values_iter(&self) -> Self::ValueIterT<'_> { + no_call_const!() + } + fn with_validity_typed(self, validity: Option) -> Self; + + fn from_vec(v: Vec>, dtype: ArrowDataType) -> Self { + Self::arr_from_iter_with_dtype(dtype, v) + } + + fn from_zeroable_vec(v: Vec>, dtype: ArrowDataType) -> Self { + Self::arr_from_iter_with_dtype(dtype, v) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self; + + fn full(length: usize, value: Self::ValueT<'_>, dtype: ArrowDataType) -> Self { + Self::arr_from_iter_with_dtype(dtype, std::iter::repeat_n(value, length)) + } +} + +pub trait ParameterFreeDtypeStaticArray: StaticArray { + fn get_dtype() -> ArrowDataType; +} + +impl StaticArray for PrimitiveArray { + type ValueT<'a> = T; + type ZeroableValueT<'a> = T; + type ValueIterT<'a> = std::iter::Copied>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter().copied() + } + + #[inline(always)] + fn as_slice(&self) -> Option<&[Self::ValueT<'_>]> { + Some(self.values().as_slice()) + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + ZipValidity::new_with_validity(self.values().iter().copied(), self.validity()) + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn from_vec(v: Vec>, _dtype: ArrowDataType) -> Self { + PrimitiveArray::from_vec(v) + } + + fn from_zeroable_vec(v: Vec>, _dtype: ArrowDataType) -> Self { + PrimitiveArray::from_vec(v) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + PrimitiveArray::from_vec(vec![value; length]) + } +} + +impl ParameterFreeDtypeStaticArray for PrimitiveArray { + fn get_dtype() -> ArrowDataType { + T::PRIMITIVE.into() + } +} + +impl StaticArray for BooleanArray { + type ValueT<'a> = bool; + type ZeroableValueT<'a> = bool; + type ValueIterT<'a> = BitmapIter<'a>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn from_vec(v: Vec>, _dtype: ArrowDataType) -> Self { + BooleanArray::from_slice(v) + } + + fn from_zeroable_vec(v: Vec>, _dtype: ArrowDataType) -> Self { + BooleanArray::from_slice(v) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + Bitmap::new_with_value(value, length).into() + } +} + +impl ParameterFreeDtypeStaticArray for BooleanArray { + fn get_dtype() -> ArrowDataType { + ArrowDataType::Boolean + } +} + +impl StaticArray for Utf8Array { + type ValueT<'a> = &'a str; + type ZeroableValueT<'a> = Option<&'a str>; + type ValueIterT<'a> = Utf8ValuesIter<'a, i64>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } +} + +impl ParameterFreeDtypeStaticArray for Utf8Array { + fn get_dtype() -> ArrowDataType { + ArrowDataType::LargeUtf8 + } +} + +impl StaticArray for BinaryArray { + type ValueT<'a> = &'a [u8]; + type ZeroableValueT<'a> = Option<&'a [u8]>; + type ValueIterT<'a> = BinaryValueIter<'a, i64>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } +} + +impl ParameterFreeDtypeStaticArray for BinaryArray { + fn get_dtype() -> ArrowDataType { + ArrowDataType::LargeBinary + } +} + +impl StaticArray for BinaryViewArray { + type ValueT<'a> = &'a [u8]; + type ZeroableValueT<'a> = Option<&'a [u8]>; + type ValueIterT<'a> = BinaryViewValueIter<'a, [u8]>; + + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + let mut builder = MutableBinaryViewArray::with_capacity(length); + builder.extend_constant(length, Some(value)); + builder.into() + } +} + +impl ParameterFreeDtypeStaticArray for BinaryViewArray { + fn get_dtype() -> ArrowDataType { + ArrowDataType::BinaryView + } +} + +impl StaticArray for Utf8ViewArray { + type ValueT<'a> = &'a str; + type ZeroableValueT<'a> = Option<&'a str>; + type ValueIterT<'a> = BinaryViewValueIter<'a, str>; + + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } + + fn full(length: usize, value: Self::ValueT<'_>, _dtype: ArrowDataType) -> Self { + unsafe { + BinaryViewArray::full(length, value.as_bytes(), ArrowDataType::BinaryView) + .to_utf8view_unchecked() + } + } +} + +impl ParameterFreeDtypeStaticArray for Utf8ViewArray { + fn get_dtype() -> ArrowDataType { + ArrowDataType::Utf8View + } +} + +impl StaticArray for ListArray { + type ValueT<'a> = Box; + type ZeroableValueT<'a> = Option>; + type ValueIterT<'a> = ListValuesIter<'a, i64>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } +} + +impl StaticArray for FixedSizeListArray { + type ValueT<'a> = Box; + type ZeroableValueT<'a> = Option>; + type ValueIterT<'a> = ArrayValuesIter<'a, FixedSizeListArray>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } + + fn full(length: usize, value: Self::ValueT<'_>, dtype: ArrowDataType) -> Self { + let singular_arr = FixedSizeListArray::new(dtype.clone(), 1, value, None); + let inner_dt = dtype.inner_dtype().unwrap(); + let mut builder = FixedSizeListArrayBuilder::new(dtype.clone(), make_builder(inner_dt)); + builder.subslice_extend_repeated(&singular_arr, 0, 1, length, ShareStrategy::Always); + builder.freeze() + } +} + +impl StaticArray for StructArray { + type ValueT<'a> = (); + type ZeroableValueT<'a> = (); + type ValueIterT<'a> = std::iter::Repeat<()>; + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, dtype: ArrowDataType) -> Self { + Self::new_null(dtype, length) + } +} diff --git a/crates/polars-arrow/src/array/static_array_collect.rs b/crates/polars-arrow/src/array/static_array_collect.rs new file mode 100644 index 000000000000..dceceee40325 --- /dev/null +++ b/crates/polars-arrow/src/array/static_array_collect.rs @@ -0,0 +1,952 @@ +use std::borrow::Cow; + +use polars_utils::no_call_const; + +use crate::array::static_array::{ParameterFreeDtypeStaticArray, StaticArray}; +use crate::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, ListArray, + MutableBinaryArray, MutableBinaryValuesArray, MutableBinaryViewArray, PrimitiveArray, + StructArray, Utf8Array, Utf8ViewArray, +}; +use crate::bitmap::BitmapBuilder; +use crate::datatypes::ArrowDataType; +#[cfg(feature = "dtype-array")] +use crate::legacy::prelude::fixed_size_list::AnonymousBuilder as AnonymousFixedSizeListArrayBuilder; +use crate::legacy::prelude::list::AnonymousBuilder as AnonymousListArrayBuilder; +use crate::legacy::trusted_len::TrustedLenPush; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +pub trait ArrayFromIterDtype: Sized { + fn arr_from_iter_with_dtype>(dtype: ArrowDataType, iter: I) -> Self; + + #[inline(always)] + fn arr_from_iter_trusted_with_dtype(dtype: ArrowDataType, iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter_with_dtype(dtype, iter) + } + + fn try_arr_from_iter_with_dtype>>( + dtype: ArrowDataType, + iter: I, + ) -> Result; + + #[inline(always)] + fn try_arr_from_iter_trusted_with_dtype(dtype: ArrowDataType, iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::try_arr_from_iter_with_dtype(dtype, iter) + } +} + +pub trait ArrayFromIter: Sized { + fn arr_from_iter>(iter: I) -> Self; + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter>>(iter: I) -> Result; + + #[inline(always)] + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::try_arr_from_iter(iter) + } +} + +impl> ArrayFromIterDtype for A { + #[inline(always)] + fn arr_from_iter_with_dtype>(dtype: ArrowDataType, iter: I) -> Self { + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } + Self::arr_from_iter(iter) + } + + #[inline(always)] + fn arr_from_iter_trusted_with_dtype(dtype: ArrowDataType, iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } + Self::arr_from_iter_trusted(iter) + } + + #[inline(always)] + fn try_arr_from_iter_with_dtype>>( + dtype: ArrowDataType, + iter: I, + ) -> Result { + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } + Self::try_arr_from_iter(iter) + } + + #[inline(always)] + fn try_arr_from_iter_trusted_with_dtype(dtype: ArrowDataType, iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + // FIXME: currently some Object arrays have Unknown dtype, when this is fixed remove this bypass. + if dtype != ArrowDataType::Unknown { + debug_assert_eq!( + std::mem::discriminant(&dtype), + std::mem::discriminant(&A::get_dtype()) + ); + } + Self::try_arr_from_iter_trusted(iter) + } +} + +pub trait ArrayCollectIterExt: Iterator + Sized { + #[inline(always)] + fn collect_arr(self) -> A + where + A: ArrayFromIter, + { + A::arr_from_iter(self) + } + + #[inline(always)] + fn collect_arr_trusted(self) -> A + where + A: ArrayFromIter, + Self: TrustedLen, + { + A::arr_from_iter_trusted(self) + } + + #[inline(always)] + fn try_collect_arr(self) -> Result + where + A: ArrayFromIter, + Self: Iterator>, + { + A::try_arr_from_iter(self) + } + + #[inline(always)] + fn try_collect_arr_trusted(self) -> Result + where + A: ArrayFromIter, + Self: Iterator> + TrustedLen, + { + A::try_arr_from_iter_trusted(self) + } + + #[inline(always)] + fn collect_arr_with_dtype(self, dtype: ArrowDataType) -> A + where + A: ArrayFromIterDtype, + { + A::arr_from_iter_with_dtype(dtype, self) + } + + #[inline(always)] + fn collect_arr_trusted_with_dtype(self, dtype: ArrowDataType) -> A + where + A: ArrayFromIterDtype, + Self: TrustedLen, + { + A::arr_from_iter_trusted_with_dtype(dtype, self) + } + + #[inline(always)] + fn try_collect_arr_with_dtype(self, dtype: ArrowDataType) -> Result + where + A: ArrayFromIterDtype, + Self: Iterator>, + { + A::try_arr_from_iter_with_dtype(dtype, self) + } + + #[inline(always)] + fn try_collect_arr_trusted_with_dtype(self, dtype: ArrowDataType) -> Result + where + A: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + A::try_arr_from_iter_trusted_with_dtype(dtype, self) + } +} + +impl ArrayCollectIterExt for I {} + +// --------------- +// Implementations +// --------------- + +impl ArrayFromIter for PrimitiveArray { + #[inline] + fn arr_from_iter>(iter: I) -> Self { + PrimitiveArray::from_vec(iter.into_iter().collect()) + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + PrimitiveArray::from_vec(Vec::from_trusted_len_iter(iter)) + } + + #[inline] + fn try_arr_from_iter>>(iter: I) -> Result { + let v: Result, E> = iter.into_iter().collect(); + Ok(PrimitiveArray::from_vec(v?)) + } + + #[inline] + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + let v = Vec::try_from_trusted_len_iter(iter); + Ok(PrimitiveArray::from_vec(v?)) + } +} + +impl ArrayFromIter> for PrimitiveArray { + fn arr_from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let n = iter.size_hint().0; + let mut buf = Vec::with_capacity(n); + let mut validity = BitmapBuilder::with_capacity(n); + unsafe { + for val in iter { + // Use one check for both capacities. + if buf.len() == buf.capacity() { + buf.reserve(1); + validity.reserve(buf.capacity() - buf.len()); + } + buf.push_unchecked(val.unwrap_or_default()); + validity.push_unchecked(val.is_some()); + } + } + PrimitiveArray::new( + T::PRIMITIVE.into(), + buf.into(), + validity.into_opt_validity(), + ) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + let n = iter.size_hint().1.expect("must have an upper bound"); + let mut buf = Vec::with_capacity(n); + let mut validity = BitmapBuilder::with_capacity(n); + unsafe { + for val in iter { + buf.push_unchecked(val.unwrap_or_default()); + validity.push_unchecked(val.is_some()); + } + } + PrimitiveArray::new( + T::PRIMITIVE.into(), + buf.into(), + validity.into_opt_validity(), + ) + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let iter = iter.into_iter(); + let n = iter.size_hint().0; + let mut buf = Vec::with_capacity(n); + let mut validity = BitmapBuilder::with_capacity(n); + unsafe { + for val in iter { + let val = val?; + // Use one check for both capacities. + if buf.len() == buf.capacity() { + buf.reserve(1); + validity.reserve(buf.capacity() - buf.len()); + } + buf.push_unchecked(val.unwrap_or_default()); + validity.push_unchecked(val.is_some()); + } + } + Ok(PrimitiveArray::new( + T::PRIMITIVE.into(), + buf.into(), + validity.into_opt_validity(), + )) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator, E>>, + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + let n = iter.size_hint().1.expect("must have an upper bound"); + let mut buf = Vec::with_capacity(n); + let mut validity = BitmapBuilder::with_capacity(n); + unsafe { + for val in iter { + let val = val?; + buf.push_unchecked(val.unwrap_or_default()); + validity.push_unchecked(val.is_some()); + } + } + Ok(PrimitiveArray::new( + T::PRIMITIVE.into(), + buf.into(), + validity.into_opt_validity(), + )) + } +} + +// We don't use AsRef here because it leads to problems with conflicting implementations, +// as Rust considers that AsRef<[u8]> for Option<&[u8]> could be implemented. +trait IntoBytes { + type AsRefT: AsRef<[u8]>; + fn into_bytes(self) -> Self::AsRefT; +} +trait TrivialIntoBytes: AsRef<[u8]> {} +impl IntoBytes for T { + type AsRefT = Self; + fn into_bytes(self) -> Self { + self + } +} +impl TrivialIntoBytes for Vec {} +impl TrivialIntoBytes for Cow<'_, [u8]> {} +impl TrivialIntoBytes for &[u8] {} +impl TrivialIntoBytes for String {} +impl TrivialIntoBytes for &str {} +impl<'a> IntoBytes for Cow<'a, str> { + type AsRefT = Cow<'a, [u8]>; + fn into_bytes(self) -> Cow<'a, [u8]> { + match self { + Cow::Borrowed(a) => Cow::Borrowed(a.as_bytes()), + Cow::Owned(s) => Cow::Owned(s.into_bytes()), + } + } +} + +impl ArrayFromIter for BinaryArray { + fn arr_from_iter>(iter: I) -> Self { + BinaryArray::from_iter_values(iter.into_iter().map(|s| s.into_bytes())) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: our iterator is TrustedLen. + MutableBinaryArray::from_trusted_len_values_iter_unchecked( + iter.into_iter().map(|s| s.into_bytes()), + ) + .into() + } + } + + fn try_arr_from_iter>>(iter: I) -> Result { + // No built-in for this? + let mut arr = MutableBinaryValuesArray::new(); + let mut iter = iter.into_iter(); + arr.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push(x?.into_bytes()); + Ok(()) + })?; + Ok(arr.into()) + } + + // No faster implementation than this available, fall back to default. + // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +impl ArrayFromIter> for BinaryArray { + #[inline] + fn arr_from_iter>>(iter: I) -> Self { + BinaryArray::from_iter(iter.into_iter().map(|s| Some(s?.into_bytes()))) + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: the iterator is TrustedLen. + BinaryArray::from_trusted_len_iter_unchecked( + iter.into_iter().map(|s| Some(s?.into_bytes())), + ) + } + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + // No built-in for this? + let mut arr = MutableBinaryArray::new(); + let mut iter = iter.into_iter(); + arr.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push(x?.map(|s| s.into_bytes())); + Ok(()) + })?; + Ok(arr.into()) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator, E>>, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: the iterator is TrustedLen. + BinaryArray::try_from_trusted_len_iter_unchecked( + iter.into_iter().map(|s| s.map(|s| Some(s?.into_bytes()))), + ) + } + } +} + +impl ArrayFromIter for BinaryViewArray { + #[inline] + fn arr_from_iter>(iter: I) -> Self { + MutableBinaryViewArray::from_values_iter(iter.into_iter().map(|a| a.into_bytes())).into() + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter>>(iter: I) -> Result { + let mut iter = iter.into_iter(); + let mut arr = MutableBinaryViewArray::with_capacity(iter.size_hint().0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push_value_ignore_validity(x?.into_bytes()); + Ok(()) + })?; + Ok(arr.into()) + } + + // No faster implementation than this available, fall back to default. + // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +impl ArrayFromIter> for BinaryViewArray { + #[inline] + fn arr_from_iter>>(iter: I) -> Self { + MutableBinaryViewArray::from_iter( + iter.into_iter().map(|opt_a| opt_a.map(|a| a.into_bytes())), + ) + .into() + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let mut iter = iter.into_iter(); + let mut arr = MutableBinaryViewArray::with_capacity(iter.size_hint().0); + iter.try_for_each(|x| -> Result<(), E> { + let x = x?; + arr.push(x.map(|x| x.into_bytes())); + Ok(()) + })?; + Ok(arr.into()) + } + + // No faster implementation than this available, fall back to default. + // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +/// We use this to reuse the binary collect implementation for strings. +/// # Safety +/// The array must be valid UTF-8. +unsafe fn into_utf8array(arr: BinaryArray) -> Utf8Array { + unsafe { + let (_dt, offsets, values, validity) = arr.into_inner(); + Utf8Array::new_unchecked(ArrowDataType::LargeUtf8, offsets, values, validity) + } +} + +trait StrIntoBytes: IntoBytes {} +impl StrIntoBytes for String {} +impl StrIntoBytes for &str {} +impl StrIntoBytes for Cow<'_, str> {} + +impl ArrayFromIter for Utf8ViewArray { + #[inline] + fn arr_from_iter>(iter: I) -> Self { + unsafe { BinaryViewArray::arr_from_iter(iter).to_utf8view_unchecked() } + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter>>(iter: I) -> Result { + unsafe { BinaryViewArray::try_arr_from_iter(iter).map(|arr| arr.to_utf8view_unchecked()) } + } + + // No faster implementation than this available, fall back to default. + // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +impl ArrayFromIter> for Utf8ViewArray { + #[inline] + fn arr_from_iter>>(iter: I) -> Self { + unsafe { BinaryViewArray::arr_from_iter(iter).to_utf8view_unchecked() } + } + + #[inline] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + unsafe { BinaryViewArray::try_arr_from_iter(iter).map(|arr| arr.to_utf8view_unchecked()) } + } + + // No faster implementation than this available, fall back to default. + // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +impl ArrayFromIter for Utf8Array { + #[inline(always)] + fn arr_from_iter>(iter: I) -> Self { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn try_arr_from_iter>>(iter: I) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } + + #[inline(always)] + fn try_arr_from_iter_trusted>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } +} + +impl ArrayFromIter> for Utf8Array { + #[inline(always)] + fn arr_from_iter>>(iter: I) -> Self { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } + + #[inline(always)] + fn try_arr_from_iter_trusted, E>>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } +} + +impl ArrayFromIter for BooleanArray { + fn arr_from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let n = iter.size_hint().0; + let mut values = BitmapBuilder::with_capacity(n); + for val in iter { + values.push(val); + } + BooleanArray::new(ArrowDataType::Boolean, values.freeze(), None) + } + + // TODO: are efficient trusted collects for booleans worth it? + // fn arr_from_iter_trusted(iter: I) -> Self + + fn try_arr_from_iter>>(iter: I) -> Result { + let iter = iter.into_iter(); + let n = iter.size_hint().0; + let mut values = BitmapBuilder::with_capacity(n); + for val in iter { + values.push(val?); + } + Ok(BooleanArray::new( + ArrowDataType::Boolean, + values.freeze(), + None, + )) + } + + // fn try_arr_from_iter_trusted>>( +} + +impl ArrayFromIter> for BooleanArray { + fn arr_from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let n = iter.size_hint().0; + let mut values = BitmapBuilder::with_capacity(n); + let mut validity = BitmapBuilder::with_capacity(n); + for val in iter { + values.push(val.unwrap_or(false)); + validity.push(val.is_some()); + } + BooleanArray::new( + ArrowDataType::Boolean, + values.freeze(), + validity.into_opt_validity(), + ) + } + + // fn arr_from_iter_trusted(iter: I) -> Self + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let iter = iter.into_iter(); + let n = iter.size_hint().0; + let mut values = BitmapBuilder::with_capacity(n); + let mut validity = BitmapBuilder::with_capacity(n); + for val in iter { + let val = val?; + values.push(val.unwrap_or(false)); + validity.push(val.is_some()); + } + Ok(BooleanArray::new( + ArrowDataType::Boolean, + values.freeze(), + validity.into_opt_validity(), + )) + } + + // fn try_arr_from_iter_trusted, E>>>( +} + +// We don't use AsRef here because it leads to problems with conflicting implementations, +// as Rust considers that AsRef for Option<&dyn Array> could be implemented. +trait AsArray { + fn as_array(&self) -> &dyn Array; + #[cfg(feature = "dtype-array")] + fn into_boxed_array(self) -> Box; // Prevents unnecessary re-boxing. +} +impl AsArray for Box { + fn as_array(&self) -> &dyn Array { + self.as_ref() + } + #[cfg(feature = "dtype-array")] + fn into_boxed_array(self) -> Box { + self + } +} +impl<'a> AsArray for &'a dyn Array { + fn as_array(&self) -> &'a dyn Array { + *self + } + #[cfg(feature = "dtype-array")] + fn into_boxed_array(self) -> Box { + self.to_boxed() + } +} + +// TODO: more efficient (fixed size) list collect routines. +impl ArrayFromIterDtype for ListArray { + fn arr_from_iter_with_dtype>(dtype: ArrowDataType, iter: I) -> Self { + let iter_values: Vec = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push(arr.as_array()); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.underlying_physical_type())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype>>( + dtype: ArrowDataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +impl ArrayFromIterDtype> for ListArray { + fn arr_from_iter_with_dtype>>( + dtype: ArrowDataType, + iter: I, + ) -> Self { + let iter_values: Vec> = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push_opt(arr.as_ref().map(|a| a.as_array())); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.underlying_physical_type())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + dtype: ArrowDataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push_opt(arr.as_ref().map(|a| a.as_array())); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + Ok(builder + .finish(Some(&inner.underlying_physical_type())) + .unwrap()) + } +} + +impl ArrayFromIter> for ListArray { + fn arr_from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let iter_values: Vec> = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push_opt(arr.as_ref().map(|a| a.as_array())); + } + builder.finish(None).unwrap() + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push_opt(arr.as_ref().map(|a| a.as_array())); + } + Ok(builder.finish(None).unwrap()) + } +} + +impl ArrayFromIterDtype> for FixedSizeListArray { + #[allow(unused_variables)] + fn arr_from_iter_with_dtype>>( + dtype: ArrowDataType, + iter: I, + ) -> Self { + #[cfg(feature = "dtype-array")] + { + let ArrowDataType::FixedSizeList(_, width) = &dtype else { + panic!("FixedSizeListArray::arr_from_iter_with_dtype called with non-Array dtype"); + }; + let iter_values: Vec<_> = iter.into_iter().collect(); + let mut builder = AnonymousFixedSizeListArrayBuilder::new(iter_values.len(), *width); + for arr in iter_values { + builder.push(arr.into_boxed_array()); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.underlying_physical_type())) + .unwrap() + } + #[cfg(not(feature = "dtype-array"))] + panic!("activate 'dtype-array'") + } + + fn try_arr_from_iter_with_dtype, E>>>( + dtype: ArrowDataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +impl ArrayFromIterDtype>> for FixedSizeListArray { + #[allow(unused_variables)] + fn arr_from_iter_with_dtype>>>( + dtype: ArrowDataType, + iter: I, + ) -> Self { + #[cfg(feature = "dtype-array")] + { + let ArrowDataType::FixedSizeList(_, width) = &dtype else { + panic!( + "FixedSizeListArray::arr_from_iter_with_dtype called with non-FixedSizeList dtype" + ); + }; + let iter_values: Vec<_> = iter.into_iter().collect(); + let mut builder = AnonymousFixedSizeListArrayBuilder::new(iter_values.len(), *width); + for arr in iter_values { + match arr { + Some(a) => builder.push(a.into_boxed_array()), + None => builder.push_null(), + } + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.underlying_physical_type())) + .unwrap() + } + #[cfg(not(feature = "dtype-array"))] + panic!("activate 'dtype-array'") + } + + fn try_arr_from_iter_with_dtype< + E, + I: IntoIterator>, E>>, + >( + dtype: ArrowDataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +impl ArrayFromIter> for StructArray { + fn arr_from_iter>>(_iter: I) -> Self { + no_call_const!() + } + + fn try_arr_from_iter, E>>>( + _iter: I, + ) -> Result { + no_call_const!() + } +} + +impl ArrayFromIter<()> for StructArray { + fn arr_from_iter>(_iter: I) -> Self { + no_call_const!() + } + + fn try_arr_from_iter>>(_iter: I) -> Result { + no_call_const!() + } +} + +impl ArrayFromIterDtype<()> for StructArray { + fn arr_from_iter_with_dtype>( + _dtype: ArrowDataType, + _iter: I, + ) -> Self { + no_call_const!() + } + + fn try_arr_from_iter_with_dtype>>( + _dtype: ArrowDataType, + _iter: I, + ) -> Result { + no_call_const!() + } +} + +impl ArrayFromIterDtype> for StructArray { + fn arr_from_iter_with_dtype>>( + _dtype: ArrowDataType, + _iter: I, + ) -> Self { + no_call_const!() + } + + fn try_arr_from_iter_with_dtype, E>>>( + _dtype: ArrowDataType, + _iter: I, + ) -> Result { + no_call_const!() + } +} diff --git a/crates/polars-arrow/src/array/struct_/builder.rs b/crates/polars-arrow/src/array/struct_/builder.rs new file mode 100644 index 000000000000..7acf40caef04 --- /dev/null +++ b/crates/polars-arrow/src/array/struct_/builder.rs @@ -0,0 +1,111 @@ +use polars_utils::IdxSize; + +use super::StructArray; +use crate::array::builder::{ArrayBuilder, ShareStrategy, StaticArrayBuilder}; +use crate::bitmap::OptBitmapBuilder; +use crate::datatypes::ArrowDataType; + +pub struct StructArrayBuilder { + dtype: ArrowDataType, + length: usize, + inner_builders: Vec>, + validity: OptBitmapBuilder, +} + +impl StructArrayBuilder { + pub fn new(dtype: ArrowDataType, inner_builders: Vec>) -> Self { + Self { + dtype, + length: 0, + inner_builders, + validity: OptBitmapBuilder::default(), + } + } +} + +impl StaticArrayBuilder for StructArrayBuilder { + type Array = StructArray; + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn reserve(&mut self, additional: usize) { + for builder in &mut self.inner_builders { + builder.reserve(additional); + } + self.validity.reserve(additional); + } + + fn freeze(self) -> StructArray { + let values = self + .inner_builders + .into_iter() + .map(|b| b.freeze()) + .collect(); + let validity = self.validity.into_opt_validity(); + StructArray::new(self.dtype, self.length, values, validity) + } + + fn freeze_reset(&mut self) -> Self::Array { + let values = self + .inner_builders + .iter_mut() + .map(|b| b.freeze_reset()) + .collect(); + let validity = core::mem::take(&mut self.validity).into_opt_validity(); + let out = StructArray::new(self.dtype.clone(), self.length, values, validity); + self.length = 0; + out + } + + fn len(&self) -> usize { + self.length + } + + fn extend_nulls(&mut self, length: usize) { + for builder in &mut self.inner_builders { + builder.extend_nulls(length); + } + self.validity.extend_constant(length, false); + self.length += length; + } + + fn subslice_extend( + &mut self, + other: &StructArray, + start: usize, + length: usize, + share: ShareStrategy, + ) { + for (builder, other_values) in self.inner_builders.iter_mut().zip(other.values()) { + builder.subslice_extend(&**other_values, start, length, share); + } + self.validity + .subslice_extend_from_opt_validity(other.validity(), start, length); + self.length += length.min(other.len().saturating_sub(start)); + } + + unsafe fn gather_extend( + &mut self, + other: &StructArray, + idxs: &[IdxSize], + share: ShareStrategy, + ) { + for (builder, other_values) in self.inner_builders.iter_mut().zip(other.values()) { + builder.gather_extend(&**other_values, idxs, share); + } + self.validity + .gather_extend_from_opt_validity(other.validity(), idxs); + self.length += idxs.len(); + } + + fn opt_gather_extend(&mut self, other: &StructArray, idxs: &[IdxSize], share: ShareStrategy) { + for (builder, other_values) in self.inner_builders.iter_mut().zip(other.values()) { + builder.opt_gather_extend(&**other_values, idxs, share); + } + self.validity + .opt_gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + self.length += idxs.len(); + } +} diff --git a/crates/polars-arrow/src/array/struct_/ffi.rs b/crates/polars-arrow/src/array/struct_/ffi.rs new file mode 100644 index 000000000000..cc56f0f12cf3 --- /dev/null +++ b/crates/polars-arrow/src/array/struct_/ffi.rs @@ -0,0 +1,73 @@ +use polars_error::PolarsResult; + +use super::super::ffi::ToFfi; +use super::super::{Array, FromFfi}; +use super::StructArray; +use crate::ffi; + +unsafe impl ToFfi for StructArray { + fn buffers(&self) -> Vec> { + vec![self.validity.as_ref().map(|x| x.as_ptr())] + } + + fn children(&self) -> Vec> { + self.values.clone() + } + + fn offset(&self) -> Option { + Some( + self.validity + .as_ref() + .map(|bitmap| bitmap.offset()) + .unwrap_or_default(), + ) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for StructArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let fields = Self::get_fields(&dtype); + + let arrow_array = array.array(); + let validity = unsafe { array.validity() }?; + let len = arrow_array.len(); + let offset = arrow_array.offset(); + let values = (0..fields.len()) + .map(|index| { + let child = array.child(index)?; + ffi::try_from(child).map(|arr| { + // there is a discrepancy with how polars_arrow exports sliced + // struct array and how pyarrow does it. + // # Pyarrow + // ## struct array len 3 + // * slice 1 by with len 2 + // offset on struct array: 1 + // length on struct array: 2 + // offset on value array: 0 + // length on value array: 3 + // # Arrow2 + // ## struct array len 3 + // * slice 1 by with len 2 + // offset on struct array: 0 + // length on struct array: 3 + // offset on value array: 1 + // length on value array: 2 + // + // this branch will ensure both can round trip + if arr.len() >= (len + offset) { + arr.sliced(offset, len) + } else { + arr + } + }) + }) + .collect::>>>()?; + + Self::try_new(dtype, len, values, validity) + } +} diff --git a/crates/polars-arrow/src/array/struct_/fmt.rs b/crates/polars-arrow/src/array/struct_/fmt.rs new file mode 100644 index 000000000000..999cd8b67e08 --- /dev/null +++ b/crates/polars-arrow/src/array/struct_/fmt.rs @@ -0,0 +1,34 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_map, write_vec}; +use super::StructArray; + +pub fn write_value( + array: &StructArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let writer = |f: &mut W, _index| { + for (i, (field, column)) in array.fields().iter().zip(array.values()).enumerate() { + if i != 0 { + write!(f, ", ")?; + } + let writer = get_display(column.as_ref(), null); + write!(f, "{}: ", field.name)?; + writer(f, index)?; + } + Ok(()) + }; + + write_map(f, writer, None, 1, null, false) +} + +impl Debug for StructArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "StructArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/struct_/iterator.rs b/crates/polars-arrow/src/array/struct_/iterator.rs new file mode 100644 index 000000000000..29c428e858a0 --- /dev/null +++ b/crates/polars-arrow/src/array/struct_/iterator.rs @@ -0,0 +1,96 @@ +use super::StructArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::scalar::{Scalar, new_scalar}; +use crate::trusted_len::TrustedLen; + +pub struct StructValueIter<'a> { + array: &'a StructArray, + index: usize, + end: usize, +} + +impl<'a> StructValueIter<'a> { + #[inline] + pub fn new(array: &'a StructArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl Iterator for StructValueIter<'_> { + type Item = Vec>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + + // SAFETY: + // self.end is maximized by the length of the array + Some( + self.array + .values() + .iter() + .map(|v| new_scalar(v.as_ref(), old)) + .collect(), + ) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl TrustedLen for StructValueIter<'_> {} + +impl DoubleEndedIterator for StructValueIter<'_> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + + // SAFETY: + // self.end is maximized by the length of the array + Some( + self.array + .values() + .iter() + .map(|v| new_scalar(v.as_ref(), self.end)) + .collect(), + ) + } + } +} + +type ValuesIter<'a> = StructValueIter<'a>; +type ZipIter<'a> = ZipValidity>, ValuesIter<'a>, BitmapIter<'a>>; + +impl<'a> IntoIterator for &'a StructArray { + type Item = Option>>; + type IntoIter = ZipIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> StructArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a> { + ZipValidity::new_with_validity(StructValueIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> ValuesIter<'a> { + StructValueIter::new(self) + } +} diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs new file mode 100644 index 000000000000..016d0b988112 --- /dev/null +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -0,0 +1,322 @@ +use super::{Array, Splitable, new_empty_array, new_null_array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{ArrowDataType, Field}; + +mod builder; +pub use builder::*; +mod ffi; +pub(super) mod fmt; +mod iterator; +use polars_error::{PolarsResult, polars_bail, polars_ensure}; + +use crate::compute::utils::combine_validities_and; + +/// A [`StructArray`] is a nested [`Array`] with an optional validity representing +/// multiple [`Array`] with the same number of rows. +/// # Example +/// ``` +/// use polars_arrow::array::*; +/// use polars_arrow::datatypes::*; +/// let boolean = BooleanArray::from_slice(&[false, false, true, true]).boxed(); +/// let int = Int32Array::from_slice(&[42, 28, 19, 31]).boxed(); +/// +/// let fields = vec![ +/// Field::new("b".into(), ArrowDataType::Boolean, false), +/// Field::new("c".into(), ArrowDataType::Int32, false), +/// ]; +/// +/// let array = StructArray::new(ArrowDataType::Struct(fields), 4, vec![boolean, int], None); +/// ``` +#[derive(Clone)] +pub struct StructArray { + dtype: ArrowDataType, + // invariant: each array has the same length + values: Vec>, + // invariant: for each v in values: length == v.len() + length: usize, + validity: Option, +} + +impl StructArray { + /// Returns a new [`StructArray`]. + /// # Errors + /// This function errors iff: + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `dtype` are empty + /// * the values's len is different from children's length + /// * any of the values's data type is different from its corresponding children' data type + /// * any element of values has a different length than the first element + /// * the validity's length is not equal to the length of the first element + pub fn try_new( + dtype: ArrowDataType, + length: usize, + values: Vec>, + validity: Option, + ) -> PolarsResult { + let fields = Self::try_get_fields(&dtype)?; + + polars_ensure!( + fields.len() == values.len(), + ComputeError: + "a StructArray must have a number of fields in its DataType equal to the number of child values" + ); + + fields + .iter().map(|a| &a.dtype) + .zip(values.iter().map(|a| a.dtype())) + .enumerate() + .try_for_each(|(index, (dtype, child))| { + if dtype != child { + polars_bail!(ComputeError: + "The children DataTypes of a StructArray must equal the children data types. + However, the field {index} has data type {dtype:?} but the value has data type {child:?}" + ) + } else { + Ok(()) + } + })?; + + values + .iter() + .map(|f| f.len()) + .enumerate() + .try_for_each(|(index, f_length)| { + if f_length != length { + polars_bail!(ComputeError: "The children must have the given number of values. + However, the values at index {index} have a length of {f_length}, which is different from given length {length}.") + } else { + Ok(()) + } + })?; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != length) + { + polars_bail!(ComputeError:"The validity length of a StructArray must match its number of elements") + } + + Ok(Self { + dtype, + length, + values, + validity, + }) + } + + /// Returns a new [`StructArray`] + /// # Panics + /// This function panics iff: + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `dtype` are empty + /// * the values's len is different from children's length + /// * any of the values's data type is different from its corresponding children' data type + /// * any element of values has a different length than the first element + /// * the validity's length is not equal to the length of the first element + pub fn new( + dtype: ArrowDataType, + length: usize, + values: Vec>, + validity: Option, + ) -> Self { + Self::try_new(dtype, length, values, validity).unwrap() + } + + /// Creates an empty [`StructArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + if let ArrowDataType::Struct(fields) = &dtype.to_logical_type() { + let values = fields + .iter() + .map(|field| new_empty_array(field.dtype().clone())) + .collect(); + Self::new(dtype, 0, values, None) + } else { + panic!("StructArray must be initialized with DataType::Struct"); + } + } + + /// Creates a null [`StructArray`] of length `length`. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + if let ArrowDataType::Struct(fields) = &dtype { + let values = fields + .iter() + .map(|field| new_null_array(field.dtype().clone(), length)) + .collect(); + Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) + } else { + panic!("StructArray must be initialized with DataType::Struct"); + } + } +} + +// must use +impl StructArray { + /// Deconstructs the [`StructArray`] into its individual components. + #[must_use] + pub fn into_data(self) -> (Vec, usize, Vec>, Option) { + let Self { + dtype, + length, + values, + validity, + } = self; + let fields = if let ArrowDataType::Struct(fields) = dtype { + fields + } else { + unreachable!() + }; + (fields, length, values, validity) + } + + /// Slices this [`StructArray`]. + /// # Panics + /// panics iff `offset + length > self.len()` + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "offset + length may not exceed length of array" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`StructArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values + .iter_mut() + .for_each(|x| x.slice_unchecked(offset, length)); + self.length = length; + } + + /// Set the outer nulls into the inner arrays. + pub fn propagate_nulls(&self) -> StructArray { + let has_nulls = self.null_count() > 0; + let mut out = self.clone(); + if !has_nulls { + return out; + }; + + for value_arr in &mut out.values { + let new_validity = combine_validities_and(self.validity(), value_arr.validity()); + *value_arr = value_arr.with_validity(new_validity); + } + out + } + + impl_sliced!(); + + impl_mut_validity!(); + + impl_into_array!(); +} + +// Accessors +impl StructArray { + #[inline] + fn len(&self) -> usize { + if cfg!(debug_assertions) { + for arr in self.values.iter() { + assert_eq!( + arr.len(), + self.length, + "StructArray invariant: each array has same length" + ); + } + } + + self.length + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the values of this [`StructArray`]. + pub fn values(&self) -> &[Box] { + &self.values + } + + /// Returns the fields of this [`StructArray`]. + pub fn fields(&self) -> &[Field] { + let fields = Self::get_fields(&self.dtype); + debug_assert_eq!(self.values().len(), fields.len()); + fields + } +} + +impl StructArray { + /// Returns the fields the `DataType::Struct`. + pub(crate) fn try_get_fields(dtype: &ArrowDataType) -> PolarsResult<&[Field]> { + match dtype.to_logical_type() { + ArrowDataType::Struct(fields) => Ok(fields), + _ => { + polars_bail!(ComputeError: "Struct array must be created with a DataType whose physical type is Struct") + }, + } + } + + /// Returns the fields the `DataType::Struct`. + pub fn get_fields(dtype: &ArrowDataType) -> &[Field] { + Self::try_get_fields(dtype).unwrap() + } +} + +impl Array for StructArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl Splitable for StructArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + + let mut lhs_values = Vec::with_capacity(self.values.len()); + let mut rhs_values = Vec::with_capacity(self.values.len()); + + for v in self.values.iter() { + let (lhs, rhs) = unsafe { v.split_at_boxed_unchecked(offset) }; + lhs_values.push(lhs); + rhs_values.push(rhs); + } + + ( + Self { + dtype: self.dtype.clone(), + length: offset, + values: lhs_values, + validity: lhs_validity, + }, + Self { + dtype: self.dtype.clone(), + length: self.length - offset, + values: rhs_values, + validity: rhs_validity, + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/total_ord.rs b/crates/polars-arrow/src/array/total_ord.rs new file mode 100644 index 000000000000..778c488e0a2b --- /dev/null +++ b/crates/polars-arrow/src/array/total_ord.rs @@ -0,0 +1,9 @@ +use polars_utils::total_ord::TotalEq; + +use crate::array::Array; + +impl TotalEq for Box { + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} diff --git a/crates/polars-arrow/src/array/union/ffi.rs b/crates/polars-arrow/src/array/union/ffi.rs new file mode 100644 index 000000000000..bf023f306dda --- /dev/null +++ b/crates/polars-arrow/src/array/union/ffi.rs @@ -0,0 +1,61 @@ +use polars_error::PolarsResult; + +use super::super::Array; +use super::super::ffi::ToFfi; +use super::UnionArray; +use crate::array::FromFfi; +use crate::ffi; + +unsafe impl ToFfi for UnionArray { + fn buffers(&self) -> Vec> { + if let Some(offsets) = &self.offsets { + vec![ + Some(self.types.storage_ptr().cast::()), + Some(offsets.storage_ptr().cast::()), + ] + } else { + vec![Some(self.types.storage_ptr().cast::())] + } + } + + fn children(&self) -> Vec> { + self.fields.clone() + } + + fn offset(&self) -> Option { + Some(self.types.offset()) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for UnionArray { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let fields = Self::get_fields(&dtype); + + let mut types = unsafe { array.buffer::(0) }?; + let offsets = if Self::is_sparse(&dtype) { + None + } else { + Some(unsafe { array.buffer::(1) }?) + }; + + let length = array.array().len(); + let offset = array.array().offset(); + let fields = (0..fields.len()) + .map(|index| { + let child = array.child(index)?; + ffi::try_from(child) + }) + .collect::>>>()?; + + if offset > 0 { + types.slice(offset, length); + }; + + Self::try_new(dtype, types, fields, offsets) + } +} diff --git a/crates/polars-arrow/src/array/union/fmt.rs b/crates/polars-arrow/src/array/union/fmt.rs new file mode 100644 index 000000000000..521201fffd6d --- /dev/null +++ b/crates/polars-arrow/src/array/union/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::UnionArray; + +pub fn write_value( + array: &UnionArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let (field, index) = array.index(index); + + get_display(array.fields()[field].as_ref(), null)(f, index) +} + +impl Debug for UnionArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "UnionArray")?; + write_vec(f, writer, None, self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/union/iterator.rs b/crates/polars-arrow/src/array/union/iterator.rs new file mode 100644 index 000000000000..e93223e46c43 --- /dev/null +++ b/crates/polars-arrow/src/array/union/iterator.rs @@ -0,0 +1,59 @@ +use super::UnionArray; +use crate::scalar::Scalar; +use crate::trusted_len::TrustedLen; + +#[derive(Debug, Clone)] +pub struct UnionIter<'a> { + array: &'a UnionArray, + current: usize, +} + +impl<'a> UnionIter<'a> { + #[inline] + pub fn new(array: &'a UnionArray) -> Self { + Self { array, current: 0 } + } +} + +impl Iterator for UnionIter<'_> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.current == self.array.len() { + None + } else { + let old = self.current; + self.current += 1; + Some(unsafe { self.array.value_unchecked(old) }) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.array.len() - self.current; + (len, Some(len)) + } +} + +impl<'a> IntoIterator for &'a UnionArray { + type Item = Box; + type IntoIter = UnionIter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> UnionArray { + /// constructs a new iterator + #[inline] + pub fn iter(&'a self) -> UnionIter<'a> { + UnionIter::new(self) + } +} + +impl std::iter::ExactSizeIterator for UnionIter<'_> {} + +unsafe impl TrustedLen for UnionIter<'_> {} diff --git a/crates/polars-arrow/src/array/union/mod.rs b/crates/polars-arrow/src/array/union/mod.rs new file mode 100644 index 000000000000..82d03c44108d --- /dev/null +++ b/crates/polars-arrow/src/array/union/mod.rs @@ -0,0 +1,413 @@ +use polars_error::{PolarsResult, polars_bail, polars_err}; + +use super::{Array, Splitable, new_empty_array, new_null_array}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::{ArrowDataType, Field, UnionMode}; +use crate::scalar::{Scalar, new_scalar}; + +mod ffi; +pub(super) mod fmt; +mod iterator; + +type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode); + +/// [`UnionArray`] represents an array whose each slot can contain different values. +/// +// How to read a value at slot i: +// ``` +// let index = self.types()[i] as usize; +// let field = self.fields()[index]; +// let offset = self.offsets().map(|x| x[index]).unwrap_or(i); +// let field = field.as_any().downcast to correct type; +// let value = field.value(offset); +// ``` +#[derive(Clone)] +pub struct UnionArray { + // Invariant: every item in `types` is `> 0 && < fields.len()` + types: Buffer, + // Invariant: `map.len() == fields.len()` + // Invariant: every item in `map` is `> 0 && < fields.len()` + map: Option<[usize; 127]>, + fields: Vec>, + // Invariant: when set, `offsets.len() == types.len()` + offsets: Option>, + dtype: ArrowDataType, + offset: usize, +} + +impl UnionArray { + /// Returns a new [`UnionArray`]. + /// # Errors + /// This function errors iff: + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `dtype`'s children's length + /// * The number of `fields` is larger than `i8::MAX` + /// * any of the values's data type is different from its corresponding children' data type + pub fn try_new( + dtype: ArrowDataType, + types: Buffer, + fields: Vec>, + offsets: Option>, + ) -> PolarsResult { + let (f, ids, mode) = Self::try_get_all(&dtype)?; + + if f.len() != fields.len() { + polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union") + }; + let number_of_fields: i8 = fields.len().try_into().map_err( + |_| polars_err!(ComputeError: "the number of `fields` cannot be larger than i8::MAX"), + )?; + + f + .iter().map(|a| a.dtype()) + .zip(fields.iter().map(|a| a.dtype())) + .enumerate() + .try_for_each(|(index, (dtype, child))| { + if dtype != child { + polars_bail!(ComputeError: + "the children DataTypes of a UnionArray must equal the children data types. + However, the field {index} has data type {dtype:?} but the value has data type {child:?}" + ) + } else { + Ok(()) + } + })?; + + if let Some(offsets) = &offsets { + if offsets.len() != types.len() { + polars_bail!(ComputeError: + "in a UnionArray, the offsets' length must be equal to the number of types" + ) + } + } + if offsets.is_none() != mode.is_sparse() { + polars_bail!(ComputeError: + "in a sparse UnionArray, the offsets must be set (and vice-versa)", + ) + } + + // build hash + let map = if let Some(&ids) = ids.as_ref() { + if ids.len() != fields.len() { + polars_bail!(ComputeError: + "in a union, when the ids are set, their length must be equal to the number of fields", + ) + } + + // example: + // * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5] + // * ids = [5, 7] + // => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...] + let mut hash = [0; 127]; + + for (pos, &id) in ids.iter().enumerate() { + if !(0..=127).contains(&id) { + polars_bail!(ComputeError: + "in a union, when the ids are set, every id must belong to [0, 128[", + ) + } + hash[id as usize] = pos; + } + + types.iter().try_for_each(|&type_| { + if type_ < 0 { + polars_bail!(ComputeError: + "in a union, when the ids are set, every type must be >= 0" + ) + } + let id = hash[type_ as usize]; + if id >= fields.len() { + polars_bail!(ComputeError: + "in a union, when the ids are set, each id must be smaller than the number of fields." + ) + } else { + Ok(()) + } + })?; + + Some(hash) + } else { + // SAFETY: every type in types is smaller than number of fields + let mut is_valid = true; + for &type_ in types.iter() { + if type_ < 0 || type_ >= number_of_fields { + is_valid = false + } + } + if !is_valid { + polars_bail!(ComputeError: + "every type in `types` must be larger than 0 and smaller than the number of fields.", + ) + } + + None + }; + + Ok(Self { + dtype, + map, + fields, + offsets, + types, + offset: 0, + }) + } + + /// Returns a new [`UnionArray`]. + /// # Panics + /// This function panics iff: + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `dtype`'s children's length + /// * any of the values's data type is different from its corresponding children' data type + pub fn new( + dtype: ArrowDataType, + types: Buffer, + fields: Vec>, + offsets: Option>, + ) -> Self { + Self::try_new(dtype, types, fields, offsets).unwrap() + } + + /// Creates a new null [`UnionArray`]. + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + if let ArrowDataType::Union(u) = &dtype { + let fields = u + .fields + .iter() + .map(|x| new_null_array(x.dtype().clone(), length)) + .collect(); + + let offsets = if u.mode.is_sparse() { + None + } else { + Some((0..length as i32).collect::>().into()) + }; + + // all from the same field + let types = vec![0i8; length].into(); + + Self::new(dtype, types, fields, offsets) + } else { + panic!("Union struct must be created with the corresponding Union DataType") + } + } + + /// Creates a new empty [`UnionArray`]. + pub fn new_empty(dtype: ArrowDataType) -> Self { + if let ArrowDataType::Union(u) = dtype.to_logical_type() { + let fields = u + .fields + .iter() + .map(|x| new_empty_array(x.dtype().clone())) + .collect(); + + let offsets = if u.mode.is_sparse() { + None + } else { + Some(Buffer::default()) + }; + + Self { + dtype, + map: None, + fields, + offsets, + types: Buffer::new(), + offset: 0, + } + } else { + panic!("Union struct must be created with the corresponding Union DataType") + } + } +} + +impl UnionArray { + /// Returns a slice of this [`UnionArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a slice of this [`UnionArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); + + self.types.slice_unchecked(offset, length); + if let Some(offsets) = self.offsets.as_mut() { + offsets.slice_unchecked(offset, length) + } + self.offset += offset; + } + + impl_sliced!(); + impl_into_array!(); +} + +impl UnionArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.types.len() + } + + /// The optional offsets. + pub fn offsets(&self) -> Option<&Buffer> { + self.offsets.as_ref() + } + + /// The fields. + pub fn fields(&self) -> &Vec> { + &self.fields + } + + /// The types. + pub fn types(&self) -> &Buffer { + &self.types + } + + #[inline] + unsafe fn field_slot_unchecked(&self, index: usize) -> usize { + self.offsets() + .as_ref() + .map(|x| *x.get_unchecked(index) as usize) + .unwrap_or(index + self.offset) + } + + /// Returns the index and slot of the field to select from `self.fields`. + #[inline] + pub fn index(&self, index: usize) -> (usize, usize) { + assert!(index < self.len()); + unsafe { self.index_unchecked(index) } + } + + /// Returns the index and slot of the field to select from `self.fields`. + /// The first value is guaranteed to be `< self.fields().len()` + /// + /// # Safety + /// This function is safe iff `index < self.len`. + #[inline] + pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) { + debug_assert!(index < self.len()); + // SAFETY: assumption of the function + let type_ = unsafe { *self.types.get_unchecked(index) }; + // SAFETY: assumption of the struct + let type_ = self + .map + .as_ref() + .map(|map| unsafe { *map.get_unchecked(type_ as usize) }) + .unwrap_or(type_ as usize); + // SAFETY: assumption of the function + let index = self.field_slot_unchecked(index); + (type_, index) + } + + /// Returns the slot `index` as a [`Scalar`]. + /// # Panics + /// iff `index >= self.len()` + pub fn value(&self, index: usize) -> Box { + assert!(index < self.len()); + unsafe { self.value_unchecked(index) } + } + + /// Returns the slot `index` as a [`Scalar`]. + /// + /// # Safety + /// This function is safe iff `i < self.len`. + pub unsafe fn value_unchecked(&self, index: usize) -> Box { + debug_assert!(index < self.len()); + let (type_, index) = self.index_unchecked(index); + // SAFETY: assumption of the struct + debug_assert!(type_ < self.fields.len()); + let field = self.fields.get_unchecked(type_).as_ref(); + new_scalar(field, index) + } +} + +impl Array for UnionArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + None + } + + fn with_validity(&self, _: Option) -> Box { + panic!("cannot set validity of a union array") + } +} + +impl UnionArray { + fn try_get_all(dtype: &ArrowDataType) -> PolarsResult { + match dtype.to_logical_type() { + ArrowDataType::Union(u) => Ok((&u.fields, u.ids.as_ref().map(|x| x.as_ref()), u.mode)), + _ => polars_bail!(ComputeError: + "The UnionArray requires a logical type of DataType::Union", + ), + } + } + + fn get_all(dtype: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) { + Self::try_get_all(dtype).unwrap() + } + + /// Returns all fields from [`ArrowDataType::Union`]. + /// # Panic + /// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`]. + pub fn get_fields(dtype: &ArrowDataType) -> &[Field] { + Self::get_all(dtype).0 + } + + /// Returns whether the [`ArrowDataType::Union`] is sparse or not. + /// # Panic + /// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`]. + pub fn is_sparse(dtype: &ArrowDataType) -> bool { + Self::get_all(dtype).2.is_sparse() + } +} + +impl Splitable for UnionArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_types, rhs_types) = unsafe { self.types.split_at_unchecked(offset) }; + let (lhs_offsets, rhs_offsets) = self.offsets.as_ref().map_or((None, None), |v| { + let (lhs, rhs) = unsafe { v.split_at_unchecked(offset) }; + (Some(lhs), Some(rhs)) + }); + + ( + Self { + types: lhs_types, + map: self.map, + fields: self.fields.clone(), + offsets: lhs_offsets, + dtype: self.dtype.clone(), + offset: self.offset, + }, + Self { + types: rhs_types, + map: self.map, + fields: self.fields.clone(), + offsets: rhs_offsets, + dtype: self.dtype.clone(), + offset: self.offset + offset, + }, + ) + } +} diff --git a/crates/polars-arrow/src/array/utf8/ffi.rs b/crates/polars-arrow/src/array/utf8/ffi.rs new file mode 100644 index 000000000000..7181eba91286 --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/ffi.rs @@ -0,0 +1,63 @@ +use polars_error::PolarsResult; + +use super::Utf8Array; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for Utf8Array { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().storage_ptr().cast::()), + Some(self.values.storage_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + dtype: self.dtype.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for Utf8Array { + unsafe fn try_from_ffi(array: A) -> PolarsResult { + let dtype = array.dtype().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let values = unsafe { array.buffer::(2)? }; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new_unchecked(dtype, offsets, values, validity)) + } +} diff --git a/crates/polars-arrow/src/array/utf8/fmt.rs b/crates/polars-arrow/src/array/utf8/fmt.rs new file mode 100644 index 000000000000..4466444ffe3b --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/fmt.rs @@ -0,0 +1,23 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::Utf8Array; +use crate::offset::Offset; + +pub fn write_value(array: &Utf8Array, index: usize, f: &mut W) -> Result { + write!(f, "{}", array.value(index)) +} + +impl Debug for Utf8Array { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + let head = if O::IS_LARGE { + "LargeUtf8Array" + } else { + "Utf8Array" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/polars-arrow/src/array/utf8/from.rs b/crates/polars-arrow/src/array/utf8/from.rs new file mode 100644 index 000000000000..6f90bac99495 --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/from.rs @@ -0,0 +1,9 @@ +use super::{MutableUtf8Array, Utf8Array}; +use crate::offset::Offset; + +impl> FromIterator> for Utf8Array { + #[inline] + fn from_iter>>(iter: I) -> Self { + MutableUtf8Array::::from_iter(iter).into() + } +} diff --git a/crates/polars-arrow/src/array/utf8/iterator.rs b/crates/polars-arrow/src/array/utf8/iterator.rs new file mode 100644 index 000000000000..262b98c10d79 --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/iterator.rs @@ -0,0 +1,79 @@ +use super::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array}; +use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for Utf8Array { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`Utf8Array`]. +pub type Utf8ValuesIter<'a, O> = ArrayValuesIter<'a, Utf8Array>; + +impl<'a, O: Offset> IntoIterator for &'a Utf8Array { + type Item = Option<&'a str>; + type IntoIter = ZipValidity<&'a str, Utf8ValuesIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableUtf8Array { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`MutableUtf8ValuesArray`]. +pub type MutableUtf8ValuesIter<'a, O> = ArrayValuesIter<'a, MutableUtf8ValuesArray>; + +impl<'a, O: Offset> IntoIterator for &'a MutableUtf8Array { + type Item = Option<&'a str>; + type IntoIter = ZipValidity<&'a str, MutableUtf8ValuesIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableUtf8ValuesArray { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +impl<'a, O: Offset> IntoIterator for &'a MutableUtf8ValuesArray { + type Item = &'a str; + type IntoIter = ArrayValuesIter<'a, MutableUtf8ValuesArray>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/polars-arrow/src/array/utf8/mod.rs b/crates/polars-arrow/src/array/utf8/mod.rs new file mode 100644 index 000000000000..886bc357faf6 --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/mod.rs @@ -0,0 +1,563 @@ +use either::Either; + +use super::specification::try_check_utf8; +use super::{Array, GenericBinaryArray, Splitable}; +use crate::array::BinaryArray; +use crate::array::iterator::NonNullValuesIter; +use crate::bitmap::Bitmap; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; + +mod ffi; +pub(super) mod fmt; +mod from; +mod iterator; +mod mutable; +mod mutable_values; +pub use iterator::*; +pub use mutable::*; +pub use mutable_values::MutableUtf8ValuesArray; +use polars_error::*; + +// Auxiliary struct to allow presenting &str as [u8] to a generic function +pub(super) struct StrAsBytes

(P); +impl> AsRef<[u8]> for StrAsBytes { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.0.as_ref().as_bytes() + } +} + +/// A [`Utf8Array`] is arrow's semantic equivalent of an immutable `Vec>`. +/// Cloning and slicing this struct is `O(1)`. +/// # Example +/// ``` +/// use polars_arrow::bitmap::Bitmap; +/// use polars_arrow::buffer::Buffer; +/// use polars_arrow::array::Utf8Array; +/// # fn main() { +/// let array = Utf8Array::::from([Some("hi"), None, Some("there")]); +/// assert_eq!(array.value(0), "hi"); +/// assert_eq!(array.iter().collect::>(), vec![Some("hi"), None, Some("there")]); +/// assert_eq!(array.values_iter().collect::>(), vec!["hi", "", "there"]); +/// // the underlying representation +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// assert_eq!(array.values(), &Buffer::from(b"hithere".to_vec())); +/// assert_eq!(array.offsets().buffer(), &Buffer::from(vec![0, 2, 2, 2 + 5])); +/// # } +/// ``` +/// +/// # Generic parameter +/// The generic parameter [`Offset`] can only be `i32` or `i64` and tradeoffs maximum array length with +/// memory usage: +/// * the sum of lengths of all elements cannot exceed `Offset::MAX` +/// * the total size of the underlying data is `array.len() * size_of::() + sum of lengths of all elements` +/// +/// # Safety +/// The following invariants hold: +/// * Two consecutive `offsets` cast (`as`) to `usize` are valid slices of `values`. +/// * A slice of `values` taken from two consecutive `offsets` is valid `utf8`. +/// * `len` is equal to `validity.len()`, when defined. +#[derive(Clone)] +pub struct Utf8Array { + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, +} + +// constructors +impl Utf8Array { + /// Returns a [`Utf8Array`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is greater than the values' length. + /// * the validity's length is not equal to `offsets.len_proxy()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> PolarsResult { + try_check_utf8(&offsets, &values)?; + if validity + .as_ref() + .is_some_and(|validity| validity.len() != offsets.len_proxy()) + { + polars_bail!(ComputeError: "validity mask length must match the number of values"); + } + + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { + polars_bail!(ComputeError: "Utf8Array can only be initialized with DataType::Utf8 or DataType::LargeUtf8") + } + + Ok(Self { + dtype, + offsets, + values, + validity, + }) + } + + /// Returns a [`Utf8Array`] from a slice of `&str`. + /// + /// A convenience method that uses [`Self::from_trusted_len_values_iter`]. + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter()) + } + + /// Returns a new [`Utf8Array`] from a slice of `&str`. + /// + /// A convenience method that uses [`Self::from_trusted_len_iter`]. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + MutableUtf8Array::::from(slice).into() + } + + /// Returns an iterator of `Option<&str>` + pub fn iter(&self) -> ZipValidity<&str, Utf8ValuesIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity()) + } + + /// Returns an iterator of `&str` + pub fn values_iter(&self) -> Utf8ValuesIter { + Utf8ValuesIter::new(self) + } + + /// Returns an iterator of the non-null values `&str. + #[inline] + pub fn non_null_values_iter(&self) -> NonNullValuesIter<'_, Utf8Array> { + NonNullValuesIter::new(self, self.validity()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &str { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + + // soundness: the invariant of the struct + let slice = self.values.get_unchecked(start..end); + + // soundness: the invariant of the struct + std::str::from_utf8_unchecked(slice) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&str> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns the [`ArrowDataType`] of this array. + #[inline] + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + /// Returns the values of this [`Utf8Array`]. + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the offsets of this [`Utf8Array`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Slices this [`Utf8Array`]. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the arrays' length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`Utf8Array`]. + /// # Implementation + /// This function is `O(1)` + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (ArrowDataType, OffsetsBuffer, Buffer, Option) { + let Self { + dtype, + offsets, + values, + validity, + } = self; + (dtype, offsets, values, validity) + } + + /// Try to convert this `Utf8Array` to a `MutableUtf8Array` + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + // SAFETY: invariants are preserved + Left(bitmap) => Left(unsafe { + Utf8Array::new_unchecked(self.dtype, self.offsets, self.values, Some(bitmap)) + }), + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + // SAFETY: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.dtype, + offsets, + values, + Some(mutable_bitmap.into()), + ) + }) + }, + (Left(values), Right(offsets)) => { + // SAFETY: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.dtype, + offsets.into(), + values, + Some(mutable_bitmap.into()), + ) + }) + }, + (Right(values), Left(offsets)) => { + // SAFETY: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.dtype, + offsets, + values.into(), + Some(mutable_bitmap.into()), + ) + }) + }, + (Right(values), Right(offsets)) => Right(unsafe { + MutableUtf8Array::new_unchecked( + self.dtype, + offsets, + values, + Some(mutable_bitmap), + ) + }), + }, + } + } else { + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(unsafe { Utf8Array::new_unchecked(self.dtype, offsets, values, None) }) + }, + (Left(values), Right(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.dtype, offsets.into(), values, None) + }), + (Right(values), Left(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.dtype, offsets, values.into(), None) + }), + (Right(values), Right(offsets)) => Right(unsafe { + MutableUtf8Array::new_unchecked(self.dtype, offsets, values, None) + }), + } + } + } + + /// Returns a new empty [`Utf8Array`]. + /// + /// The array is guaranteed to have no elements nor validity. + #[inline] + pub fn new_empty(dtype: ArrowDataType) -> Self { + unsafe { Self::new_unchecked(dtype, OffsetsBuffer::new(), Buffer::new(), None) } + } + + /// Returns a new [`Utf8Array`] whose all slots are null / `None`. + #[inline] + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + Self::new( + dtype, + Offsets::new_zeroed(length).into(), + Buffer::new(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns a default [`ArrowDataType`] of this array, which depends on the generic parameter `O`: `DataType::Utf8` or `DataType::LargeUtf8` + pub fn default_dtype() -> ArrowDataType { + if O::IS_LARGE { + ArrowDataType::LargeUtf8 + } else { + ArrowDataType::Utf8 + } + } + + /// Creates a new [`Utf8Array`] without checking for offsets monotinicity nor utf8-validity + /// + /// # Panic + /// This function panics (in debug mode only) iff: + /// * The last offset is greater than the values' length. + /// * the validity's length is not equal to `offsets.len_proxy()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// + /// # Safety + /// This function is unsound iff: + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn new_unchecked( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + debug_assert!( + offsets.last().to_usize() <= values.len(), + "offsets must not exceed the values length" + ); + debug_assert!( + validity + .as_ref() + .is_none_or(|validity| validity.len() == offsets.len_proxy()), + "validity mask length must match the number of values" + ); + debug_assert!( + dtype.to_physical_type() == Self::default_dtype().to_physical_type(), + "Utf8Array can only be initialized with DataType::Utf8 or DataType::LargeUtf8" + ); + + Self { + dtype, + offsets, + values, + validity, + } + } + + /// Creates a new [`Utf8Array`]. + /// # Panics + /// This function panics iff: + /// * `offsets.last()` is greater than `values.len()`. + /// * the validity's length is not equal to `offsets.len_proxy()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn new( + dtype: ArrowDataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new(dtype, offsets, values, validity).unwrap() + } + + /// Returns a (non-null) [`Utf8Array`] created from a [`TrustedLen`] of `&str`. + /// # Implementation + /// This function is `O(N)` + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableUtf8Array::::from_trusted_len_values_iter(iterator).into() + } + + /// Creates a new [`Utf8Array`] from a [`Iterator`] of `&str`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableUtf8Array::::from_iter_values(iterator).into() + } + + /// Creates a [`Utf8Array`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator>, + { + MutableUtf8Array::::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`Utf8Array`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen>, + { + MutableUtf8Array::::from_trusted_len_iter(iterator).into() + } + + /// Creates a [`Utf8Array`] from an falible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef, + I: IntoIterator, E>>, + { + MutableUtf8Array::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) + } + + /// Creates a [`Utf8Array`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> std::result::Result + where + P: AsRef, + I: TrustedLen, E>>, + { + MutableUtf8Array::::try_from_trusted_len_iter(iter).map(|x| x.into()) + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity Bitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + // Convert this [`Utf8Array`] to a [`BinaryArray`]. + pub fn to_binary(&self) -> BinaryArray { + unsafe { + BinaryArray::new_unchecked( + BinaryArray::::default_dtype(), + self.offsets.clone(), + self.values.clone(), + self.validity.clone(), + ) + } + } +} + +impl Splitable for Utf8Array { + #[inline(always)] + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + let (lhs_offsets, rhs_offsets) = unsafe { self.offsets.split_at_unchecked(offset) }; + + ( + Self { + dtype: self.dtype.clone(), + offsets: lhs_offsets, + values: self.values.clone(), + validity: lhs_validity, + }, + Self { + dtype: self.dtype.clone(), + offsets: rhs_offsets, + values: self.values.clone(), + validity: rhs_validity, + }, + ) + } +} + +impl Array for Utf8Array { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +unsafe impl GenericBinaryArray for Utf8Array { + #[inline] + fn values(&self) -> &[u8] { + self.values() + } + + #[inline] + fn offsets(&self) -> &[O] { + self.offsets().buffer() + } +} + +impl Default for Utf8Array { + fn default() -> Self { + let dtype = if O::IS_LARGE { + ArrowDataType::LargeUtf8 + } else { + ArrowDataType::Utf8 + }; + Utf8Array::new(dtype, Default::default(), Default::default(), None) + } +} diff --git a/crates/polars-arrow/src/array/utf8/mutable.rs b/crates/polars-arrow/src/array/utf8/mutable.rs new file mode 100644 index 000000000000..58332f408ff1 --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/mutable.rs @@ -0,0 +1,553 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::{MutableUtf8ValuesArray, MutableUtf8ValuesIter, StrAsBytes, Utf8Array}; +use crate::array::physical_binary::*; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`Utf8Array`]. It differs +/// from [`MutableUtf8ValuesArray`] in that it can build nullable [`Utf8Array`]s. +#[derive(Debug, Clone)] +pub struct MutableUtf8Array { + values: MutableUtf8ValuesArray, + validity: Option, +} + +impl From> for Utf8Array { + fn from(other: MutableUtf8Array) -> Self { + let validity = other.validity.and_then(|x| { + let validity: Option = x.into(); + validity + }); + let array: Utf8Array = other.values.into(); + array.with_validity(validity) + } +} + +impl Default for MutableUtf8Array { + fn default() -> Self { + Self::new() + } +} + +impl MutableUtf8Array { + /// Initializes a new empty [`MutableUtf8Array`]. + pub fn new() -> Self { + Self { + values: Default::default(), + validity: None, + } + } + + /// Returns a [`MutableUtf8Array`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> PolarsResult { + let values = MutableUtf8ValuesArray::try_new(dtype, offsets, values)?; + + if validity + .as_ref() + .is_some_and(|validity| validity.len() != values.len()) + { + polars_bail!(ComputeError: "validity's length must be equal to the number of values") + } + + Ok(Self { values, validity }) + } + + /// Create a [`MutableUtf8Array`] out of low-end APIs. + /// + /// # Safety + /// The caller must ensure that every value between offsets is a valid utf8. + /// # Panics + /// This function panics iff: + /// * The `offsets` and `values` are inconsistent + /// * The validity is not `None` and its length is different from `offsets`'s length minus one. + pub unsafe fn new_unchecked( + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Self { + let values = MutableUtf8ValuesArray::new_unchecked(dtype, offsets, values); + if let Some(ref validity) = validity { + assert_eq!(values.len(), validity.len()); + } + Self { values, validity } + } + + /// Creates a new [`MutableUtf8Array`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } + + fn default_dtype() -> ArrowDataType { + Utf8Array::::default_dtype() + } + + /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + values: MutableUtf8ValuesArray::with_capacities(capacity, values), + validity: None, + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.values.reserve(additional, additional_values); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn capacity(&self) -> usize { + self.values.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// Pushes a new element to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: Option) { + self.try_push(value).unwrap() + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + #[inline] + pub fn value(&self, i: usize) -> &str { + self.values.value(i) + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + self.values.value_unchecked(i) + } + + /// Pop the last entry from [`MutableUtf8Array`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity); + } + + /// Returns an iterator of `Option<&str>` + pub fn iter(&self) -> ZipValidity<&str, MutableUtf8ValuesIter, BitmapIter> { + ZipValidity::new(self.values_iter(), self.validity.as_ref().map(|x| x.iter())) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: Utf8Array = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableUtf8Array`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + /// Extract the low-end APIs from the [`MutableUtf8Array`]. + pub fn into_data(self) -> (ArrowDataType, Offsets, Vec, Option) { + let (dtype, offsets, values) = self.values.into_inner(); + (dtype, offsets, values, self.validity) + } + + /// Returns an iterator of `&str` + pub fn values_iter(&self) -> MutableUtf8ValuesIter { + self.values.iter() + } + + /// Sets the validity. + /// # Panic + /// Panics iff the validity's len is not equal to the existing values' length. + pub fn set_validity(&mut self, validity: Option) { + if let Some(validity) = &validity { + assert_eq!(self.values.len(), validity.len()) + } + self.validity = validity; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity MutableBitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } +} + +impl MutableUtf8Array { + /// returns its values. + pub fn values(&self) -> &Vec { + self.values.values() + } + + /// returns its offsets. + pub fn offsets(&self) -> &Offsets { + self.values.offsets() + } +} + +impl MutableArray for MutableUtf8Array { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: Utf8Array = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: Utf8Array = std::mem::take(self).into(); + array.arced() + } + + fn dtype(&self) -> &ArrowDataType { + if O::IS_LARGE { + &ArrowDataType::LargeUtf8 + } else { + &ArrowDataType::Utf8 + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&str>(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator> for MutableUtf8Array { + fn from_iter>>(iter: I) -> Self { + Self::try_from_iter(iter).unwrap() + } +} + +impl MutableUtf8Array { + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let length = self.values.len(); + self.values.extend(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len_unchecked` which accepts iterator of optional + /// values. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let length = self.values.len(); + self.values.extend_trusted_len_unchecked(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8Array`] from an iterator of trusted len. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + self.values + .extend_from_trusted_len_iter(self.validity.as_mut().unwrap(), iterator); + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator>, + { + let iterator = iterator.map(|x| x.map(StrAsBytes)); + let (validity, offsets, values) = trusted_len_unzip(iterator); + + // soundness: P is `str` + Self::new_unchecked(Self::default_dtype(), offsets, values, validity) + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length of `&str`. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked, I: Iterator>( + iterator: I, + ) -> Self { + MutableUtf8ValuesArray::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a new [`MutableUtf8Array`] from a [`TrustedLen`] of `&str`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_values_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableUtf8Array`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). + fn try_from_iter, I: IntoIterator>>( + iter: I, + ) -> PolarsResult { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + let iterator = iterator.map(|x| x.map(|x| x.map(StrAsBytes))); + let (validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + // soundness: P is `str` + Ok(Self::new_unchecked( + Self::default_dtype(), + offsets, + values, + validity, + )) + } + + /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableUtf8Array`] from a [`Iterator`] of `&str`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableUtf8ValuesArray::from_iter(iterator).into() + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator, E>>, + T: AsRef, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend> for MutableUtf8Array { + fn extend>>(&mut self, iter: I) { + self.try_extend(iter).unwrap(); + } +} + +impl> TryExtend> for MutableUtf8Array { + fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush> for MutableUtf8Array { + #[inline] + fn try_push(&mut self, value: Option) -> PolarsResult<()> { + match value { + Some(value) => { + self.values.try_push(value.as_ref())?; + + if let Some(validity) = &mut self.validity { + validity.push(true) + } + }, + None => { + self.values.push(""); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } +} + +impl PartialEq for MutableUtf8Array { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableUtf8Array { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/polars-arrow/src/array/utf8/mutable_values.rs b/crates/polars-arrow/src/array/utf8/mutable_values.rs new file mode 100644 index 000000000000..a36af300efca --- /dev/null +++ b/crates/polars-arrow/src/array/utf8/mutable_values.rs @@ -0,0 +1,419 @@ +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::{MutableUtf8Array, StrAsBytes, Utf8Array}; +use crate::array::physical_binary::*; +use crate::array::specification::{try_check_offsets_bounds, try_check_utf8}; +use crate::array::{Array, ArrayValuesIter, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::ArrowDataType; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`Utf8Array`]. It differs +/// from [`MutableUtf8Array`] in that it builds non-null [`Utf8Array`]. +#[derive(Debug, Clone)] +pub struct MutableUtf8ValuesArray { + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, +} + +impl From> for Utf8Array { + fn from(other: MutableUtf8ValuesArray) -> Self { + // SAFETY: + // `MutableUtf8ValuesArray` has the same invariants as `Utf8Array` and thus + // `Utf8Array` can be safely created from `MutableUtf8ValuesArray` without checks. + unsafe { + Utf8Array::::new_unchecked( + other.dtype, + other.offsets.into(), + other.values.into(), + None, + ) + } + } +} + +impl From> for MutableUtf8Array { + fn from(other: MutableUtf8ValuesArray) -> Self { + // SAFETY: + // `MutableUtf8ValuesArray` has the same invariants as `MutableUtf8Array` + unsafe { + MutableUtf8Array::::new_unchecked(other.dtype, other.offsets, other.values, None) + } + } +} + +impl Default for MutableUtf8ValuesArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableUtf8ValuesArray { + /// Returns an empty [`MutableUtf8ValuesArray`]. + pub fn new() -> Self { + Self { + dtype: Self::default_dtype(), + offsets: Offsets::new(), + values: Vec::::new(), + } + } + + /// Returns a [`MutableUtf8ValuesArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * `offsets.last()` is greater than `values.len()`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + ) -> PolarsResult { + try_check_utf8(&offsets, &values)?; + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { + polars_bail!(ComputeError: "MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8") + } + + Ok(Self { + dtype, + offsets, + values, + }) + } + + /// Returns a [`MutableUtf8ValuesArray`] created from its internal representation. + /// + /// # Panic + /// This function does not panic iff: + /// * `offsets.last()` is greater than `values.len()` + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. + /// + /// # Safety + /// This function is safe iff: + /// * the offsets are monotonically increasing + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn new_unchecked( + dtype: ArrowDataType, + offsets: Offsets, + values: Vec, + ) -> Self { + try_check_offsets_bounds(&offsets, values.len()) + .expect("The length of the values must be equal to the last offset value"); + + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { + panic!( + "MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8" + ) + } + + Self { + dtype, + offsets, + values, + } + } + + /// Returns the default [`ArrowDataType`] of this container: [`ArrowDataType::Utf8`] or [`ArrowDataType::LargeUtf8`] + /// depending on the generic [`Offset`]. + pub fn default_dtype() -> ArrowDataType { + Utf8Array::::default_dtype() + } + + /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + dtype: Self::default_dtype(), + offsets: Offsets::::with_capacity(capacity), + values: Vec::::with_capacity(values), + } + } + + /// returns its values. + #[inline] + pub fn values(&self) -> &Vec { + &self.values + } + + /// returns its offsets. + #[inline] + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// Reserves `additional` elements and `additional_values` on the values. + #[inline] + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.offsets.reserve(additional + 1); + self.values.reserve(additional_values); + } + + /// Returns the capacity in number of items + pub fn capacity(&self) -> usize { + self.offsets.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Pushes a new item to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: T) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableUtf8ValuesArray`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option { + if self.len() == 0 { + return None; + } + self.offsets.pop()?; + let start = self.offsets.last().to_usize(); + let value = self.values.split_off(start); + // SAFETY: utf8 is validated on initialization + Some(unsafe { String::from_utf8_unchecked(value) }) + } + + /// Returns the value of the element at index `i`. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &str { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`. + /// + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end(i); + + // soundness: the invariant of the struct + let slice = self.values.get_unchecked(start..end); + + // soundness: the invariant of the struct + std::str::from_utf8_unchecked(slice) + } + + /// Returns an iterator of `&str` + pub fn iter(&self) -> ArrayValuesIter { + ArrayValuesIter::new(self) + } + + /// Shrinks the capacity of the [`MutableUtf8ValuesArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + } + + /// Extract the low-end APIs from the [`MutableUtf8ValuesArray`]. + pub fn into_inner(self) -> (ArrowDataType, Offsets, Vec) { + (self.dtype, self.offsets, self.values) + } +} + +impl MutableArray for MutableUtf8ValuesArray { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let array: Utf8Array = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: Utf8Array = std::mem::take(self).into(); + array.arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&str>("") + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator

for MutableUtf8ValuesArray { + fn from_iter>(iter: I) -> Self { + let (offsets, values) = values_iter(iter.into_iter().map(StrAsBytes)); + // soundness: T: AsRef and offsets are monotonically increasing + unsafe { Self::new_unchecked(Self::default_dtype(), offsets, values) } + } +} + +impl MutableUtf8ValuesArray { + pub(crate) unsafe fn extend_from_trusted_len_iter( + &mut self, + validity: &mut MutableBitmap, + iterator: I, + ) where + P: AsRef, + I: Iterator>, + { + let iterator = iterator.map(|x| x.map(StrAsBytes)); + extend_from_trusted_len_iter(&mut self.offsets, &mut self.values, validity, iterator); + } + + /// Extends the [`MutableUtf8ValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8ValuesArray`] from an iterator of trusted len. + /// + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let iterator = iterator.map(StrAsBytes); + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + } + + /// Creates a [`MutableUtf8ValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Returns a new [`MutableUtf8ValuesArray`] from an iterator of trusted length. + /// + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator, + { + let iterator = iterator.map(StrAsBytes); + let (offsets, values) = trusted_len_values_iter(iterator); + + // soundness: P is `str` and offsets are monotonically increasing + Self::new_unchecked(Self::default_dtype(), offsets, values) + } + + /// Returns a new [`MutableUtf8ValuesArray`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). + pub fn try_from_iter, I: IntoIterator>(iter: I) -> PolarsResult { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator>, + T: AsRef, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend for MutableUtf8ValuesArray { + fn extend>(&mut self, iter: I) { + extend_from_values_iter( + &mut self.offsets, + &mut self.values, + iter.into_iter().map(StrAsBytes), + ); + } +} + +impl> TryExtend for MutableUtf8ValuesArray { + fn try_extend>(&mut self, iter: I) -> PolarsResult<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush for MutableUtf8ValuesArray { + #[inline] + fn try_push(&mut self, value: T) -> PolarsResult<()> { + let bytes = value.as_ref().as_bytes(); + self.values.extend_from_slice(bytes); + self.offsets.try_push(bytes.len()) + } +} + +impl TryExtendFromSelf for MutableUtf8ValuesArray { + fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + self.values.extend_from_slice(&other.values); + self.offsets.try_extend_from_self(&other.offsets) + } +} diff --git a/crates/polars-arrow/src/array/values.rs b/crates/polars-arrow/src/array/values.rs new file mode 100644 index 000000000000..1a1f2582771b --- /dev/null +++ b/crates/polars-arrow/src/array/values.rs @@ -0,0 +1,91 @@ +use crate::array::{ + ArrayRef, BinaryArray, BinaryViewArray, FixedSizeListArray, ListArray, Utf8Array, Utf8ViewArray, +}; +use crate::datatypes::ArrowDataType; +use crate::offset::Offset; + +pub trait ValueSize { + /// Get the values size that is still "visible" to the underlying array. + /// E.g. take the offsets into account. + fn get_values_size(&self) -> usize; +} + +impl ValueSize for ListArray { + fn get_values_size(&self) -> usize { + unsafe { + // SAFETY: + // invariant of the struct that offsets always has at least 2 members. + let start = *self.offsets().get_unchecked(0) as usize; + let end = *self.offsets().last() as usize; + end - start + } + } +} + +impl ValueSize for FixedSizeListArray { + fn get_values_size(&self) -> usize { + self.values().len() + } +} + +impl ValueSize for Utf8Array { + fn get_values_size(&self) -> usize { + unsafe { + // SAFETY: + // invariant of the struct that offsets always has at least 2 members. + let start = self.offsets().get_unchecked(0).to_usize(); + let end = self.offsets().last().to_usize(); + end - start + } + } +} + +impl ValueSize for BinaryArray { + fn get_values_size(&self) -> usize { + unsafe { + // SAFETY: + // invariant of the struct that offsets always has at least 2 members. + let start = self.offsets().get_unchecked(0).to_usize(); + let end = self.offsets().last().to_usize(); + end - start + } + } +} + +impl ValueSize for ArrayRef { + fn get_values_size(&self) -> usize { + match self.dtype() { + ArrowDataType::LargeUtf8 => self + .as_any() + .downcast_ref::>() + .unwrap() + .get_values_size(), + ArrowDataType::FixedSizeList(_, _) => self + .as_any() + .downcast_ref::() + .unwrap() + .get_values_size(), + ArrowDataType::LargeList(_) => self + .as_any() + .downcast_ref::>() + .unwrap() + .get_values_size(), + ArrowDataType::LargeBinary => self + .as_any() + .downcast_ref::>() + .unwrap() + .get_values_size(), + ArrowDataType::Utf8View => self + .as_any() + .downcast_ref::() + .unwrap() + .total_bytes_len(), + ArrowDataType::BinaryView => self + .as_any() + .downcast_ref::() + .unwrap() + .total_bytes_len(), + _ => unimplemented!(), + } + } +} diff --git a/crates/polars-arrow/src/bitmap/aligned.rs b/crates/polars-arrow/src/bitmap/aligned.rs new file mode 100644 index 000000000000..ad1baf06631e --- /dev/null +++ b/crates/polars-arrow/src/bitmap/aligned.rs @@ -0,0 +1,129 @@ +use std::iter::Copied; +use std::slice::Iter; + +use crate::bitmap::utils::BitChunk; + +fn load_chunk_le(src: &[u8]) -> T { + if let Ok(chunk) = src.try_into() { + return T::from_le_bytes(chunk); + } + + let mut chunk = T::Bytes::default(); + let len = src.len().min(chunk.as_ref().len()); + chunk.as_mut()[..len].copy_from_slice(&src[..len]); + T::from_le_bytes(chunk) +} + +/// Represents a bitmap split in three portions, a prefix, a suffix and an +/// aligned bulk section in the middle. +#[derive(Default, Clone, Debug)] +pub struct AlignedBitmapSlice<'a, T: BitChunk> { + prefix: T, + prefix_len: u32, + bulk: &'a [T], + suffix: T, + suffix_len: u32, +} + +impl<'a, T: BitChunk> AlignedBitmapSlice<'a, T> { + #[inline(always)] + pub fn prefix(&self) -> T { + self.prefix + } + + #[inline(always)] + pub fn bulk_iter(&self) -> Copied> { + self.bulk.iter().copied() + } + + #[inline(always)] + pub fn bulk(&self) -> &'a [T] { + self.bulk + } + + #[inline(always)] + pub fn suffix(&self) -> T { + self.suffix + } + + /// The length (in bits) of the portion of the bitmap found in prefix. + #[inline(always)] + pub fn prefix_bitlen(&self) -> usize { + self.prefix_len as usize + } + + /// The length (in bits) of the portion of the bitmap found in bulk. + #[inline(always)] + pub fn bulk_bitlen(&self) -> usize { + 8 * size_of::() * self.bulk.len() + } + + /// The length (in bits) of the portion of the bitmap found in suffix. + #[inline(always)] + pub fn suffix_bitlen(&self) -> usize { + self.suffix_len as usize + } + + pub fn new(mut bytes: &'a [u8], mut offset: usize, len: usize) -> Self { + if len == 0 { + return Self::default(); + } + + assert!(bytes.len() * 8 >= offset + len); + + // Strip off useless bytes from start. + let start_byte_idx = offset / 8; + bytes = &bytes[start_byte_idx..]; + offset %= 8; + + // Fast-path: fits entirely in one chunk. + let chunk_len = size_of::(); + let chunk_len_bits = 8 * chunk_len; + if offset + len <= chunk_len_bits { + let mut prefix = load_chunk_le::(bytes) >> offset; + if len < chunk_len_bits { + prefix &= (T::one() << len) - T::one(); + } + return Self { + prefix, + prefix_len: len as u32, + ..Self::default() + }; + } + + // Find how many bytes from the start our aligned section would start. + let mut align_offset = bytes.as_ptr().align_offset(chunk_len); + let mut align_offset_bits = 8 * align_offset; + + // Oops, the original pointer was already aligned, but our offset means + // we can't start there, start one chunk later. + if offset > align_offset_bits { + align_offset_bits += chunk_len_bits; + align_offset += chunk_len; + } + + // Calculate based on this the lengths of our sections (in bits). + let prefix_len = (align_offset_bits - offset).min(len); + let rest_len = len - prefix_len; + let suffix_len = rest_len % chunk_len_bits; + let bulk_len = rest_len - suffix_len; + debug_assert!(prefix_len < chunk_len_bits); + debug_assert!(bulk_len % chunk_len_bits == 0); + debug_assert!(suffix_len < chunk_len_bits); + + // Now we just have to load. + let (prefix_bytes, rest_bytes) = bytes.split_at(align_offset); + let (bulk_bytes, suffix_bytes) = rest_bytes.split_at(bulk_len / 8); + let mut prefix = load_chunk_le::(prefix_bytes) >> offset; + let mut suffix = load_chunk_le::(suffix_bytes); + prefix &= (T::one() << prefix_len) - T::one(); + suffix &= (T::one() << suffix_len) - T::one(); + Self { + prefix, + bulk: bytemuck::cast_slice(bulk_bytes), + suffix, + prefix_len: prefix_len as u32, + suffix_len: suffix_len as u32, + } + } +} diff --git a/crates/polars-arrow/src/bitmap/assign_ops.rs b/crates/polars-arrow/src/bitmap/assign_ops.rs new file mode 100644 index 000000000000..2afe5746c6b5 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/assign_ops.rs @@ -0,0 +1,227 @@ +use super::utils::{BitChunk, BitChunkIterExact, BitChunksExact}; +use crate::bitmap::{Bitmap, MutableBitmap}; + +/// Applies a function to every bit of this [`MutableBitmap`] in chunks +/// +/// This function can be for operations like `!` to a [`MutableBitmap`]. +pub fn unary_assign T>(bitmap: &mut MutableBitmap, op: F) { + let mut chunks = bitmap.bitchunks_exact_mut::(); + + chunks.by_ref().for_each(|chunk| { + let new_chunk: T = match (chunk as &[u8]).try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + let new_chunk = op(new_chunk); + chunk.copy_from_slice(new_chunk.to_ne_bytes().as_ref()); + }); + + if chunks.remainder().is_empty() { + return; + } + let mut new_remainder = T::zero().to_ne_bytes(); + chunks + .remainder() + .iter() + .enumerate() + .for_each(|(index, b)| new_remainder[index] = *b); + new_remainder = op(T::from_ne_bytes(new_remainder)).to_ne_bytes(); + + let len = chunks.remainder().len(); + chunks + .remainder() + .copy_from_slice(&new_remainder.as_ref()[..len]); +} + +impl std::ops::Not for MutableBitmap { + type Output = Self; + + #[inline] + fn not(mut self) -> Self { + unary_assign(&mut self, |a: u64| !a); + self + } +} + +fn binary_assign_impl(lhs: &mut MutableBitmap, mut rhs: I, op: F) +where + I: BitChunkIterExact, + T: BitChunk, + F: Fn(T, T) -> T, +{ + let mut lhs_chunks = lhs.bitchunks_exact_mut::(); + + lhs_chunks + .by_ref() + .zip(rhs.by_ref()) + .for_each(|(lhs, rhs)| { + let new_chunk: T = match (lhs as &[u8]).try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + let new_chunk = op(new_chunk, rhs); + lhs.copy_from_slice(new_chunk.to_ne_bytes().as_ref()); + }); + + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs.remainder(); + if rem_lhs.is_empty() { + return; + } + let mut new_remainder = T::zero().to_ne_bytes(); + lhs_chunks + .remainder() + .iter() + .enumerate() + .for_each(|(index, b)| new_remainder[index] = *b); + new_remainder = op(T::from_ne_bytes(new_remainder), rem_rhs).to_ne_bytes(); + + let len = lhs_chunks.remainder().len(); + lhs_chunks + .remainder() + .copy_from_slice(&new_remainder.as_ref()[..len]); +} + +/// Apply a bitwise binary operation to a [`MutableBitmap`]. +/// +/// This function can be used for operations like `&=` to a [`MutableBitmap`]. +/// # Panics +/// This function panics iff `lhs.len() != `rhs.len()` +pub fn binary_assign(lhs: &mut MutableBitmap, rhs: &Bitmap, op: F) +where + F: Fn(T, T) -> T, +{ + assert_eq!(lhs.len(), rhs.len()); + + let (slice, offset, length) = rhs.as_slice(); + if offset == 0 { + let iter = BitChunksExact::::new(slice, length); + binary_assign_impl(lhs, iter, op) + } else { + let rhs_chunks = rhs.chunks::(); + binary_assign_impl(lhs, rhs_chunks, op) + } +} + +/// Apply a bitwise binary operation to a [`MutableBitmap`]. +/// +/// This function can be used for operations like `&=` to a [`MutableBitmap`]. +/// # Panics +/// This function panics iff `lhs.len() != `rhs.len()` +pub fn binary_assign_mut(lhs: &mut MutableBitmap, rhs: &MutableBitmap, op: F) +where + F: Fn(T, T) -> T, +{ + assert_eq!(lhs.len(), rhs.len()); + + let slice = rhs.as_slice(); + let iter = BitChunksExact::::new(slice, rhs.len()); + binary_assign_impl(lhs, iter, op) +} + +#[inline] +/// Compute bitwise OR operation in-place +fn or_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + if rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), true); + } else if rhs.unset_bits() == rhs.len() { + // bitmap remains + } else { + binary_assign(lhs, rhs, |x: T, y| x | y) + } +} + +#[inline] +/// Compute bitwise OR operation in-place +fn or_assign_mut(lhs: &mut MutableBitmap, rhs: &MutableBitmap) { + if rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), true); + } else if rhs.unset_bits() == rhs.len() { + // bitmap remains + } else { + binary_assign_mut(lhs, rhs, |x: T, y| x | y) + } +} + +impl<'a> std::ops::BitOrAssign<&'a MutableBitmap> for &mut MutableBitmap { + #[inline] + fn bitor_assign(&mut self, rhs: &'a MutableBitmap) { + or_assign_mut::(self, rhs) + } +} + +impl<'a> std::ops::BitOrAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitor_assign(&mut self, rhs: &'a Bitmap) { + or_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitOr<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitor(mut self, rhs: &'a Bitmap) -> Self { + or_assign::(&mut self, rhs); + self + } +} + +#[inline] +/// Compute bitwise `&` between `lhs` and `rhs`, assigning it to `lhs` +fn and_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + if rhs.unset_bits() == 0 { + // bitmap remains + } + if rhs.unset_bits() == rhs.len() { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), false); + } else { + binary_assign(lhs, rhs, |x: T, y| x & y) + } +} + +impl<'a> std::ops::BitAndAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitand_assign(&mut self, rhs: &'a Bitmap) { + and_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitAnd<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitand(mut self, rhs: &'a Bitmap) -> Self { + and_assign::(&mut self, rhs); + self + } +} + +#[inline] +/// Compute bitwise XOR operation +fn xor_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + binary_assign(lhs, rhs, |x: T, y| x ^ y) +} + +impl<'a> std::ops::BitXorAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitxor_assign(&mut self, rhs: &'a Bitmap) { + xor_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitXor<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitxor(mut self, rhs: &'a Bitmap) -> Self { + xor_assign::(&mut self, rhs); + self + } +} diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs new file mode 100644 index 000000000000..8c0aac614cf9 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -0,0 +1,370 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use super::Bitmap; +use super::utils::{BitChunk, BitChunkIterExact, BitChunksExact}; +use crate::bitmap::MutableBitmap; +use crate::trusted_len::TrustedLen; + +#[inline(always)] +pub(crate) fn push_bitchunk(buffer: &mut Vec, value: T) { + buffer.extend(value.to_ne_bytes()) +} + +/// Creates a [`Vec`] from a [`TrustedLen`] of [`BitChunk`]. +pub fn chunk_iter_to_vec>(iter: I) -> Vec { + let cap = iter.size_hint().0 * size_of::(); + let mut buffer = Vec::with_capacity(cap); + for v in iter { + push_bitchunk(&mut buffer, v) + } + buffer +} + +fn chunk_iter_to_vec_and_remainder>( + iter: I, + remainder: T, +) -> Vec { + let cap = (iter.size_hint().0 + 1) * size_of::(); + let mut buffer = Vec::with_capacity(cap); + for v in iter { + push_bitchunk(&mut buffer, v) + } + push_bitchunk(&mut buffer, remainder); + debug_assert_eq!(buffer.len(), cap); + buffer +} + +/// Apply a bitwise operation `op` to four inputs and return the result as a [`Bitmap`]. +pub fn quaternary(a1: &Bitmap, a2: &Bitmap, a3: &Bitmap, a4: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64, u64, u64) -> u64, +{ + assert_eq!(a1.len(), a2.len()); + assert_eq!(a1.len(), a3.len()); + assert_eq!(a1.len(), a4.len()); + let a1_chunks = a1.chunks(); + let a2_chunks = a2.chunks(); + let a3_chunks = a3.chunks(); + let a4_chunks = a4.chunks(); + + let rem_a1 = a1_chunks.remainder(); + let rem_a2 = a2_chunks.remainder(); + let rem_a3 = a3_chunks.remainder(); + let rem_a4 = a4_chunks.remainder(); + + let chunks = a1_chunks + .zip(a2_chunks) + .zip(a3_chunks) + .zip(a4_chunks) + .map(|(((a1, a2), a3), a4)| op(a1, a2, a3, a4)); + + let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3, rem_a4)); + let length = a1.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to three inputs and return the result as a [`Bitmap`]. +pub fn ternary(a1: &Bitmap, a2: &Bitmap, a3: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64, u64) -> u64, +{ + assert_eq!(a1.len(), a2.len()); + assert_eq!(a1.len(), a3.len()); + let a1_chunks = a1.chunks(); + let a2_chunks = a2.chunks(); + let a3_chunks = a3.chunks(); + + let rem_a1 = a1_chunks.remainder(); + let rem_a2 = a2_chunks.remainder(); + let rem_a3 = a3_chunks.remainder(); + + let chunks = a1_chunks + .zip(a2_chunks) + .zip(a3_chunks) + .map(|((a1, a2), a3)| op(a1, a2, a3)); + + let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3)); + let length = a1.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to two inputs and return the result as a [`Bitmap`]. +pub fn binary(lhs: &Bitmap, rhs: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64) -> u64, +{ + assert_eq!(lhs.len(), rhs.len()); + let lhs_chunks = lhs.chunks(); + let rhs_chunks = rhs.chunks(); + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs_chunks.remainder(); + + let chunks = lhs_chunks + .zip(rhs_chunks) + .map(|(left, right)| op(left, right)); + + let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_lhs, rem_rhs)); + let length = lhs.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to two inputs and fold the result. +pub fn binary_fold(lhs: &Bitmap, rhs: &Bitmap, op: F, init: B, fold: R) -> B +where + F: Fn(u64, u64) -> B, + R: Fn(B, B) -> B, +{ + assert_eq!(lhs.len(), rhs.len()); + let lhs_chunks = lhs.chunks(); + let rhs_chunks = rhs.chunks(); + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs_chunks.remainder(); + + let result = lhs_chunks + .zip(rhs_chunks) + .fold(init, |prev, (left, right)| fold(prev, op(left, right))); + + fold(result, op(rem_lhs, rem_rhs)) +} + +/// Apply a bitwise operation `op` to two inputs and fold the result. +pub fn binary_fold_mut( + lhs: &MutableBitmap, + rhs: &MutableBitmap, + op: F, + init: B, + fold: R, +) -> B +where + F: Fn(u64, u64) -> B, + R: Fn(B, B) -> B, +{ + assert_eq!(lhs.len(), rhs.len()); + let lhs_chunks = lhs.chunks(); + let rhs_chunks = rhs.chunks(); + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs_chunks.remainder(); + + let result = lhs_chunks + .zip(rhs_chunks) + .fold(init, |prev, (left, right)| fold(prev, op(left, right))); + + fold(result, op(rem_lhs, rem_rhs)) +} + +fn unary_impl(iter: I, op: F, length: usize) -> Bitmap +where + I: BitChunkIterExact, + F: Fn(u64) -> u64, +{ + let rem = op(iter.remainder()); + let buffer = chunk_iter_to_vec_and_remainder(iter.map(op), rem); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to one input and return the result as a [`Bitmap`]. +pub fn unary(lhs: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64) -> u64, +{ + let (slice, offset, length) = lhs.as_slice(); + if offset == 0 { + let iter = BitChunksExact::::new(slice, length); + unary_impl(iter, op, lhs.len()) + } else { + let iter = lhs.chunks::(); + unary_impl(iter, op, lhs.len()) + } +} + +// create a new [`Bitmap`] semantically equal to ``bitmap`` but with an offset equal to ``offset`` +pub(crate) fn align(bitmap: &Bitmap, new_offset: usize) -> Bitmap { + let length = bitmap.len(); + + let bitmap: Bitmap = std::iter::repeat_n(false, new_offset) + .chain(bitmap.iter()) + .collect(); + + bitmap.sliced(new_offset, length) +} + +/// Compute bitwise A AND B operation. +pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + if lhs.unset_bits() == lhs.len() || rhs.unset_bits() == rhs.len() { + assert_eq!(lhs.len(), rhs.len()); + Bitmap::new_zeroed(lhs.len()) + } else { + binary(lhs, rhs, |x, y| x & y) + } +} + +/// Compute bitwise A AND NOT B operation. +pub fn and_not(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + binary(lhs, rhs, |x, y| x & !y) +} + +/// Compute bitwise A OR B operation. +pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + if lhs.unset_bits() == 0 || rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| x | y) + } +} + +/// Compute bitwise A OR NOT B operation. +pub fn or_not(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + binary(lhs, rhs, |x, y| x | !y) +} + +/// Compute bitwise XOR operation. +pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + let lhs_nulls = lhs.unset_bits(); + let rhs_nulls = rhs.unset_bits(); + + // all false or all true + if lhs_nulls == rhs_nulls && rhs_nulls == rhs.len() || lhs_nulls == 0 && rhs_nulls == 0 { + assert_eq!(lhs.len(), rhs.len()); + Bitmap::new_zeroed(rhs.len()) + } + // all false and all true or vice versa + else if (lhs_nulls == 0 && rhs_nulls == rhs.len()) + || (lhs_nulls == lhs.len() && rhs_nulls == 0) + { + assert_eq!(lhs.len(), rhs.len()); + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| x ^ y) + } +} + +/// Compute bitwise equality (not XOR) operation. +fn eq(lhs: &Bitmap, rhs: &Bitmap) -> bool { + if lhs.len() != rhs.len() { + return false; + } + + let mut lhs_chunks = lhs.chunks::(); + let mut rhs_chunks = rhs.chunks::(); + + let equal_chunks = lhs_chunks + .by_ref() + .zip(rhs_chunks.by_ref()) + .all(|(left, right)| left == right); + + if !equal_chunks { + return false; + } + let lhs_remainder = lhs_chunks.remainder_iter(); + let rhs_remainder = rhs_chunks.remainder_iter(); + lhs_remainder.zip(rhs_remainder).all(|(x, y)| x == y) +} + +pub fn num_intersections_with(lhs: &Bitmap, rhs: &Bitmap) -> usize { + binary_fold( + lhs, + rhs, + |lhs, rhs| (lhs & rhs).count_ones() as usize, + 0, + |lhs, rhs| lhs + rhs, + ) +} + +pub fn intersects_with(lhs: &Bitmap, rhs: &Bitmap) -> bool { + binary_fold( + lhs, + rhs, + |lhs, rhs| lhs & rhs != 0, + false, + |lhs, rhs| lhs || rhs, + ) +} + +pub fn intersects_with_mut(lhs: &MutableBitmap, rhs: &MutableBitmap) -> bool { + binary_fold_mut( + lhs, + rhs, + |lhs, rhs| lhs & rhs != 0, + false, + |lhs, rhs| lhs || rhs, + ) +} + +pub fn num_edges(lhs: &Bitmap) -> usize { + if lhs.is_empty() { + return 0; + } + + // @TODO: If is probably quite inefficient to do it like this because now either one is not + // aligned. Maybe, we can implement a smarter way to do this. + binary_fold( + &unsafe { lhs.clone().sliced_unchecked(0, lhs.len() - 1) }, + &unsafe { lhs.clone().sliced_unchecked(1, lhs.len() - 1) }, + |l, r| (l ^ r).count_ones() as usize, + 0, + |acc, v| acc + v, + ) +} + +/// Compute `out[i] = if selector[i] { truthy[i] } else { falsy }`. +pub fn select_constant(selector: &Bitmap, truthy: &Bitmap, falsy: bool) -> Bitmap { + let falsy_mask: u64 = if falsy { + 0xFFFF_FFFF_FFFF_FFFF + } else { + 0x0000_0000_0000_0000 + }; + + binary(selector, truthy, |s, t| (s & t) | (!s & falsy_mask)) +} + +/// Compute `out[i] = if selector[i] { truthy[i] } else { falsy[i] }`. +pub fn select(selector: &Bitmap, truthy: &Bitmap, falsy: &Bitmap) -> Bitmap { + ternary(selector, truthy, falsy, |s, t, f| (s & t) | (!s & f)) +} + +impl PartialEq for Bitmap { + fn eq(&self, other: &Self) -> bool { + eq(self, other) + } +} + +impl<'b> BitOr<&'b Bitmap> for &Bitmap { + type Output = Bitmap; + + fn bitor(self, rhs: &'b Bitmap) -> Bitmap { + or(self, rhs) + } +} + +impl<'b> BitAnd<&'b Bitmap> for &Bitmap { + type Output = Bitmap; + + fn bitand(self, rhs: &'b Bitmap) -> Bitmap { + and(self, rhs) + } +} + +impl<'b> BitXor<&'b Bitmap> for &Bitmap { + type Output = Bitmap; + + fn bitxor(self, rhs: &'b Bitmap) -> Bitmap { + xor(self, rhs) + } +} + +impl Not for &Bitmap { + type Output = Bitmap; + + fn not(self) -> Bitmap { + unary(self, |a| !a) + } +} diff --git a/crates/polars-arrow/src/bitmap/bitmask.rs b/crates/polars-arrow/src/bitmap/bitmask.rs new file mode 100644 index 000000000000..8f2085a03ccf --- /dev/null +++ b/crates/polars-arrow/src/bitmap/bitmask.rs @@ -0,0 +1,348 @@ +#[cfg(feature = "simd")] +use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount}; + +use polars_utils::slice::load_padded_le_u64; + +use super::iterator::FastU56BitmapIter; +use super::utils::{BitmapIter, count_zeros, fmt}; +use crate::bitmap::Bitmap; + +/// Returns the nth set bit in w, if n+1 bits are set. The indexing is +/// zero-based, nth_set_bit_u32(w, 0) returns the least significant set bit in w. +fn nth_set_bit_u32(w: u32, n: u32) -> Option { + // If we have BMI2's PDEP available, we use it. It takes the lower order + // bits of the first argument and spreads it along its second argument + // where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h. + // We use this by setting the first argument to 1 << n, which means the + // first n-1 zero bits of it will spread to the first n-1 one bits of w, + // after which the one bit will exactly get copied to the nth one bit of w. + #[cfg(all(not(miri), target_feature = "bmi2"))] + { + if n >= 32 { + return None; + } + + let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) }; + if nth_set_bit == 0 { + return None; + } + + Some(nth_set_bit.trailing_zeros()) + } + + #[cfg(any(miri, not(target_feature = "bmi2")))] + { + // Each block of 2/4/8/16 bits contains how many set bits there are in that block. + let set_per_2 = w - ((w >> 1) & 0x55555555); + let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333); + let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f; + let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff; + let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff; + if n >= set_per_32 { + return None; + } + + let mut idx = 0; + let mut n = n; + let next16 = set_per_16 & 0xff; + if n >= next16 { + n -= next16; + idx += 16; + } + let next8 = (set_per_8 >> idx) & 0xff; + if n >= next8 { + n -= next8; + idx += 8; + } + let next4 = (set_per_4 >> idx) & 0b1111; + if n >= next4 { + n -= next4; + idx += 4; + } + let next2 = (set_per_2 >> idx) & 0b11; + if n >= next2 { + n -= next2; + idx += 2; + } + let next1 = (w >> idx) & 0b1; + if n >= next1 { + idx += 1; + } + Some(idx) + } +} + +#[derive(Default, Clone)] +pub struct BitMask<'a> { + bytes: &'a [u8], + offset: usize, + len: usize, +} + +impl std::fmt::Debug for BitMask<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { bytes, offset, len } = self; + let offset_num_bytes = offset / 8; + let offset_in_byte = offset % 8; + fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f) + } +} + +impl<'a> BitMask<'a> { + pub fn from_bitmap(bitmap: &'a Bitmap) -> Self { + let (bytes, offset, len) = bitmap.as_slice(); + Self::new(bytes, offset, len) + } + + pub fn inner(&self) -> (&[u8], usize, usize) { + (self.bytes, self.offset, self.len) + } + + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + // Check length so we can use unsafe access in our get. + assert!(bytes.len() * 8 >= len + offset); + Self { bytes, offset, len } + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.len + } + + #[inline] + pub fn advance_by(&mut self, idx: usize) { + assert!(idx <= self.len); + self.offset += idx; + self.len -= idx; + } + + #[inline] + pub fn split_at(&self, idx: usize) -> (Self, Self) { + assert!(idx <= self.len); + unsafe { self.split_at_unchecked(idx) } + } + + /// # Safety + /// The index must be in-bounds. + #[inline] + pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) { + debug_assert!(idx <= self.len); + let left = Self { len: idx, ..*self }; + let right = Self { + len: self.len - idx, + offset: self.offset + idx, + ..*self + }; + (left, right) + } + + #[inline] + pub fn sliced(&self, offset: usize, length: usize) -> Self { + assert!(offset.checked_add(length).unwrap() <= self.len); + unsafe { self.sliced_unchecked(offset, length) } + } + + /// # Safety + /// The index must be in-bounds. + #[inline] + pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self { + if cfg!(debug_assertions) { + assert!(offset.checked_add(length).unwrap() <= self.len); + } + + Self { + bytes: self.bytes, + offset: self.offset + offset, + len: length, + } + } + + pub fn unset_bits(&self) -> usize { + count_zeros(self.bytes, self.offset, self.len) + } + + pub fn set_bits(&self) -> usize { + self.len - self.unset_bits() + } + + pub fn fast_iter_u56(&self) -> FastU56BitmapIter { + FastU56BitmapIter::new(self.bytes, self.offset, self.len) + } + + #[cfg(feature = "simd")] + #[inline] + pub fn get_simd(&self, idx: usize) -> Mask + where + T: MaskElement, + LaneCount: SupportedLaneCount, + { + // We don't support 64-lane masks because then we couldn't load our + // bitwise mask as a u64 and then do the byteshift on it. + + let lanes = LaneCount::::BITMASK_LEN; + assert!(lanes < 64); + + let start_byte_idx = (self.offset + idx) / 8; + let byte_shift = (self.offset + idx) % 8; + if idx + lanes <= self.len { + // SAFETY: fast path, we know this is completely in-bounds. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + Mask::from_bitmask(mask >> byte_shift) + } else if idx < self.len { + // SAFETY: we know that at least the first byte is in-bounds. + // This is partially out of bounds, we have to do extra masking. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + let num_out_of_bounds = idx + lanes - self.len; + let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift); + Mask::from_bitmask(shifted) + } else { + Mask::from_bitmask(0u64) + } + } + + #[inline] + pub fn get_u32(&self, idx: usize) -> u32 { + let start_byte_idx = (self.offset + idx) / 8; + let byte_shift = (self.offset + idx) % 8; + if idx + 32 <= self.len { + // SAFETY: fast path, we know this is completely in-bounds. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + (mask >> byte_shift) as u32 + } else if idx < self.len { + // SAFETY: we know that at least the first byte is in-bounds. + // This is partially out of bounds, we have to do extra masking. + let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) }); + let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1; + ((mask >> byte_shift) as u32) & out_of_bounds_mask + } else { + 0 + } + } + + /// Computes the index of the nth set bit after start. + /// + /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the + /// first bit set (which can be 0 as well). The returned index is absolute, + /// not relative to start. + pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option { + while start < self.len { + let next_u32_mask = self.get_u32(start); + if next_u32_mask == u32::MAX { + // Happy fast path for dense non-null section. + if n < 32 { + return Some(start + n); + } + n -= 32; + } else { + let ones = next_u32_mask.count_ones() as usize; + if n < ones { + let idx = unsafe { + // SAFETY: we know the nth bit is in the mask. + nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize + }; + return Some(start + idx); + } + n -= ones; + } + + start += 32; + } + + None + } + + /// Computes the index of the nth set bit before end, counting backwards. + /// + /// Both are zero-indexed, so nth_set_bit_idx_rev(0, len) finds the index of + /// the last bit set (which can be 0 as well). The returned index is + /// absolute (and starts at the beginning), not relative to end. + pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option { + while end > 0 { + // We want to find bits *before* end, so if end < 32 we must mask + // out the bits after the endth. + let (u32_mask_start, u32_mask_mask) = if end >= 32 { + (end - 32, u32::MAX) + } else { + (0, (1 << end) - 1) + }; + let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask; + if next_u32_mask == u32::MAX { + // Happy fast path for dense non-null section. + if n < 32 { + return Some(end - 1 - n); + } + n -= 32; + } else { + let ones = next_u32_mask.count_ones() as usize; + if n < ones { + let rev_n = ones - 1 - n; + let idx = unsafe { + // SAFETY: we know the rev_nth bit is in the mask. + nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize + }; + return Some(u32_mask_start + idx); + } + n -= ones; + } + + end = u32_mask_start; + } + + None + } + + #[inline] + pub fn get(&self, idx: usize) -> bool { + let byte_idx = (self.offset + idx) / 8; + let byte_shift = (self.offset + idx) % 8; + + if idx < self.len { + // SAFETY: we know this is in-bounds. + let byte = unsafe { *self.bytes.get_unchecked(byte_idx) }; + (byte >> byte_shift) & 1 == 1 + } else { + false + } + } + + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(self.bytes, self.offset, self.len) + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option { + for i in 0..32 { + if w & (1 << i) != 0 { + if n == 0 { + return Some(i); + } + n -= 1; + w ^= 1 << i; + } + } + None + } + + #[test] + fn test_nth_set_bit_u32() { + for n in 0..256 { + assert_eq!(nth_set_bit_u32(0, n), None); + } + + for i in 0..32 { + assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i)); + assert_eq!(nth_set_bit_u32(1 << i, 1), None); + } + + for i in 0..10000 { + let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32; + for i in 0..=32 { + assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set(rnd, i)); + } + } + } +} diff --git a/crates/polars-arrow/src/bitmap/builder.rs b/crates/polars-arrow/src/bitmap/builder.rs new file mode 100644 index 000000000000..2442c45b4a90 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/builder.rs @@ -0,0 +1,566 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use polars_utils::IdxSize; +use polars_utils::slice::load_padded_le_u64; + +use super::bitmask::BitMask; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::storage::SharedStorage; +use crate::trusted_len::TrustedLen; + +/// Used to build bitmaps bool-by-bool in sequential order. +#[derive(Default, Clone)] +pub struct BitmapBuilder { + buf: u64, // A buffer containing the last self.bit_len % 64 bits. + bit_len: usize, // Length in bits. + bit_cap: usize, // Capacity in bits (always multiple of 64). + set_bits_in_bytes: usize, // The number of bits set in self.bytes, not including self.buf. + bytes: Vec, +} + +impl BitmapBuilder { + pub fn new() -> Self { + Self::default() + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.bit_len + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.bit_len == 0 + } + + #[inline(always)] + pub fn capacity(&self) -> usize { + self.bit_cap + } + + #[inline(always)] + pub fn set_bits(&self) -> usize { + self.set_bits_in_bytes + self.buf.count_ones() as usize + } + + #[inline(always)] + pub fn unset_bits(&self) -> usize { + self.bit_len - self.set_bits() + } + + pub fn with_capacity(bits: usize) -> Self { + let bytes = Vec::with_capacity(bits.div_ceil(64) * 8); + let words_available = bytes.capacity() / 8; + Self { + buf: 0, + bit_len: 0, + bit_cap: words_available * 64, + set_bits_in_bytes: 0, + bytes, + } + } + + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + if self.bit_len + additional > self.bit_cap { + self.reserve_slow(additional) + } + } + + #[cold] + #[inline(never)] + fn reserve_slow(&mut self, additional: usize) { + let bytes_needed = (self.bit_len + additional).div_ceil(64) * 8; + self.bytes.reserve(bytes_needed - self.bytes.len()); + let words_available = self.bytes.capacity() / 8; + self.bit_cap = words_available * 64; + } + + #[inline(always)] + pub fn push(&mut self, x: bool) { + self.reserve(1); + unsafe { self.push_unchecked(x) } + } + + /// Does not update len/set_bits, simply writes to the output buffer. + /// # Safety + /// self.bytes.len() + 8 <= self.bytes.capacity() must hold. + #[inline(always)] + unsafe fn flush_word_unchecked(&mut self, w: u64) { + let cur_len = self.bytes.len(); + let p = self.bytes.as_mut_ptr().add(cur_len).cast::(); + p.write_unaligned(w.to_le()); + self.bytes.set_len(cur_len + 8); + } + + /// # Safety + /// self.len() < self.capacity() must hold. + #[inline(always)] + pub unsafe fn push_unchecked(&mut self, x: bool) { + debug_assert!(self.bit_len < self.bit_cap); + self.buf |= (x as u64) << (self.bit_len % 64); + self.bit_len += 1; + if self.bit_len % 64 == 0 { + self.flush_word_unchecked(self.buf); + self.set_bits_in_bytes += self.buf.count_ones() as usize; + self.buf = 0; + } + } + + #[inline(always)] + pub fn extend_constant(&mut self, length: usize, value: bool) { + // Fast path if the extension still fits in buf with room left to spare. + let bits_in_buf = self.bit_len % 64; + if bits_in_buf + length < 64 { + let bit_block = ((value as u64) << length) - (value as u64); + self.buf |= bit_block << bits_in_buf; + self.bit_len += length; + } else { + self.extend_constant_slow(length, value); + } + } + + #[cold] + fn extend_constant_slow(&mut self, length: usize, value: bool) { + unsafe { + let value_spread = if value { u64::MAX } else { 0 }; // Branchless neg. + + // Extend and flush current buf. + self.reserve(length); + let bits_in_buf = self.bit_len % 64; + let ext_buf = self.buf | (value_spread << bits_in_buf); + self.flush_word_unchecked(ext_buf); + self.set_bits_in_bytes += ext_buf.count_ones() as usize; + + // Write complete words. + let remaining_bits = length - (64 - bits_in_buf); + let remaining_words = remaining_bits / 64; + for _ in 0..remaining_words { + self.flush_word_unchecked(value_spread); + } + self.set_bits_in_bytes += (remaining_words * 64) & value_spread as usize; + + // Put remainder in buf and update length. + self.buf = ((value as u64) << (remaining_bits % 64)) - (value as u64); + self.bit_len += length; + } + } + + /// Pushes the first length bits from the given word, assuming the rest of + /// the bits are zero. + /// # Safety + /// self.len + length <= self.cap and length <= 64 must hold. + pub unsafe fn push_word_with_len_unchecked(&mut self, word: u64, length: usize) { + debug_assert!(self.bit_len + length <= self.bit_cap); + debug_assert!(length <= 64); + debug_assert!(length == 64 || (word >> length) == 0); + let bits_in_buf = self.bit_len % 64; + self.buf |= word << bits_in_buf; + if bits_in_buf + length >= 64 { + self.flush_word_unchecked(self.buf); + self.set_bits_in_bytes += self.buf.count_ones() as usize; + self.buf = if bits_in_buf > 0 { + word >> (64 - bits_in_buf) + } else { + 0 + }; + } + self.bit_len += length; + } + + /// # Safety + /// self.len() + length <= self.capacity() must hold, as well as + /// offset + length <= 8 * slice.len(). + unsafe fn extend_from_slice_unchecked( + &mut self, + mut slice: &[u8], + mut offset: usize, + mut length: usize, + ) { + if length == 0 { + return; + } + + // Deal with slice offset so it's aligned to bytes. + let slice_bit_offset = offset % 8; + if slice_bit_offset > 0 { + let bits_in_first_byte = (8 - slice_bit_offset).min(length); + let first_byte = *slice.get_unchecked(offset / 8) >> slice_bit_offset; + self.push_word_with_len_unchecked( + first_byte as u64 & ((1 << bits_in_first_byte) - 1), + bits_in_first_byte, + ); + length -= bits_in_first_byte; + offset += bits_in_first_byte; + } + slice = slice.get_unchecked(offset / 8..); + + // Write word-by-word. + let bits_in_buf = self.bit_len % 64; + if bits_in_buf > 0 { + while length >= 64 { + let word = u64::from_le_bytes(slice.get_unchecked(0..8).try_into().unwrap()); + self.buf |= word << bits_in_buf; + self.flush_word_unchecked(self.buf); + self.set_bits_in_bytes += self.buf.count_ones() as usize; + self.buf = word >> (64 - bits_in_buf); + self.bit_len += 64; + length -= 64; + slice = slice.get_unchecked(8..); + } + } else { + while length >= 64 { + let word = u64::from_le_bytes(slice.get_unchecked(0..8).try_into().unwrap()); + self.flush_word_unchecked(word); + self.set_bits_in_bytes += word.count_ones() as usize; + self.bit_len += 64; + length -= 64; + slice = slice.get_unchecked(8..); + } + } + + // Just the last word left. + if length > 0 { + let word = load_padded_le_u64(slice); + self.push_word_with_len_unchecked(word & ((1 << length) - 1), length); + } + } + + pub fn extend_from_slice(&mut self, slice: &[u8], offset: usize, length: usize) { + assert!(8 * slice.len() >= offset + length); + self.reserve(length); + unsafe { + self.extend_from_slice_unchecked(slice, offset, length); + } + } + + pub fn extend_from_bitmap(&mut self, bitmap: &Bitmap) { + // TODO: we can perhaps use the bitmaps bitcount here instead of + // recomputing it if it has a known bitcount. + let (slice, offset, length) = bitmap.as_slice(); + self.extend_from_slice(slice, offset, length); + } + + pub fn extend_from_bitmask(&mut self, bitmap: BitMask<'_>) { + let (slice, offset, length) = bitmap.inner(); + self.extend_from_slice(slice, offset, length); + } + + /// Extends this BitmapBuilder with a subslice of a bitmap. + pub fn subslice_extend_from_bitmap(&mut self, bitmap: &Bitmap, start: usize, length: usize) { + let (slice, bm_offset, bm_length) = bitmap.as_slice(); + assert!(start + length <= bm_length); + self.extend_from_slice(slice, bm_offset + start, length); + } + + pub fn subslice_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + start: usize, + length: usize, + ) { + match bitmap { + Some(bm) => self.subslice_extend_from_bitmap(bm, start, length), + None => self.extend_constant(length, true), + } + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend_from_slice( + &mut self, + slice: &[u8], + offset: usize, + length: usize, + idxs: &[IdxSize], + ) { + assert!(8 * slice.len() >= offset + length); + + self.reserve(idxs.len()); + unsafe { + for idx in idxs { + debug_assert!((*idx as usize) < length); + let idx_in_slice = offset + *idx as usize; + let bit = (*slice.get_unchecked(idx_in_slice / 8) >> (idx_in_slice % 8)) & 1; + self.push_unchecked(bit != 0); + } + } + } + + pub fn opt_gather_extend_from_slice( + &mut self, + slice: &[u8], + offset: usize, + length: usize, + idxs: &[IdxSize], + ) { + assert!(8 * slice.len() >= offset + length); + + self.reserve(idxs.len()); + unsafe { + for idx in idxs { + if (*idx as usize) < length { + let idx_in_slice = offset + *idx as usize; + let bit = (*slice.get_unchecked(idx_in_slice / 8) >> (idx_in_slice % 8)) & 1; + self.push_unchecked(bit != 0); + } else { + self.push_unchecked(false); + } + } + } + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend_from_bitmap(&mut self, bitmap: &Bitmap, idxs: &[IdxSize]) { + let (slice, offset, length) = bitmap.as_slice(); + self.gather_extend_from_slice(slice, offset, length, idxs); + } + + pub fn opt_gather_extend_from_bitmap(&mut self, bitmap: &Bitmap, idxs: &[IdxSize]) { + let (slice, offset, length) = bitmap.as_slice(); + self.opt_gather_extend_from_slice(slice, offset, length, idxs); + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + idxs: &[IdxSize], + length: usize, + ) { + if let Some(bm) = bitmap { + let (slice, offset, sl_length) = bm.as_slice(); + debug_assert_eq!(sl_length, length); + self.gather_extend_from_slice(slice, offset, length, idxs); + } else { + self.extend_constant(length, true); + } + } + + pub fn opt_gather_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + idxs: &[IdxSize], + length: usize, + ) { + if let Some(bm) = bitmap { + let (slice, offset, sl_length) = bm.as_slice(); + debug_assert_eq!(sl_length, length); + self.opt_gather_extend_from_slice(slice, offset, sl_length, idxs); + } else { + unsafe { + self.reserve(idxs.len()); + for idx in idxs { + self.push_unchecked((*idx as usize) < length); + } + } + } + } + + /// # Safety + /// May only be called once at the end. + unsafe fn finish(&mut self) { + if self.bit_len % 64 != 0 { + self.bytes.extend_from_slice(&self.buf.to_le_bytes()); + self.set_bits_in_bytes += self.buf.count_ones() as usize; + self.buf = 0; + } + } + + /// Converts this BitmapBuilder into a mutable bitmap. + pub fn into_mut(mut self) -> MutableBitmap { + unsafe { + self.finish(); + MutableBitmap::from_vec(self.bytes, self.bit_len) + } + } + + /// The same as into_mut, but returns None if the bitmap is all-ones. + pub fn into_opt_mut_validity(mut self) -> Option { + unsafe { + self.finish(); + if self.set_bits_in_bytes == self.bit_len { + return None; + } + Some(MutableBitmap::from_vec(self.bytes, self.bit_len)) + } + } + + /// Freezes this BitmapBuilder into an immutable Bitmap. + pub fn freeze(mut self) -> Bitmap { + unsafe { + self.finish(); + let storage = SharedStorage::from_vec(self.bytes); + Bitmap::from_inner_unchecked( + storage, + 0, + self.bit_len, + Some(self.bit_len - self.set_bits_in_bytes), + ) + } + } + + /// The same as freeze, but returns None if the bitmap is all-ones. + pub fn into_opt_validity(mut self) -> Option { + unsafe { + self.finish(); + if self.set_bits_in_bytes == self.bit_len { + return None; + } + let storage = SharedStorage::from_vec(self.bytes); + let bitmap = Bitmap::from_inner_unchecked( + storage, + 0, + self.bit_len, + Some(self.bit_len - self.set_bits_in_bytes), + ); + Some(bitmap) + } + } + + pub fn extend_trusted_len_iter(&mut self, iterator: I) + where + I: Iterator + TrustedLen, + { + self.reserve(iterator.size_hint().1.unwrap()); + for b in iterator { + // SAFETY: we reserved and the iterator's length is trusted. + unsafe { + self.push_unchecked(b); + } + } + } + + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + I: Iterator + TrustedLen, + { + let mut builder = Self::new(); + builder.extend_trusted_len_iter(iterator); + builder + } +} + +/// A wrapper for BitmapBuilder that does not allocate until the first false is +/// pushed. Less efficient if you know there are false values because it must +/// check if it has allocated for each push. +pub enum OptBitmapBuilder { + AllTrue { bit_len: usize, bit_cap: usize }, + MayHaveFalse(BitmapBuilder), +} + +impl Default for OptBitmapBuilder { + fn default() -> Self { + Self::AllTrue { + bit_len: 0, + bit_cap: 0, + } + } +} + +impl OptBitmapBuilder { + pub fn reserve(&mut self, additional: usize) { + match self { + Self::AllTrue { bit_len, bit_cap } => { + *bit_cap = usize::max(*bit_cap, *bit_len + additional); + }, + Self::MayHaveFalse(inner) => inner.reserve(additional), + } + } + + pub fn extend_constant(&mut self, length: usize, value: bool) { + match self { + Self::AllTrue { bit_len, bit_cap } => { + if value { + *bit_cap = usize::max(*bit_cap, *bit_len + length); + *bit_len += length; + } else { + self.get_builder().extend_constant(length, value); + } + }, + Self::MayHaveFalse(inner) => inner.extend_constant(length, value), + } + } + + pub fn into_opt_validity(self) -> Option { + match self { + Self::AllTrue { .. } => None, + Self::MayHaveFalse(inner) => inner.into_opt_validity(), + } + } + + pub fn subslice_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + start: usize, + length: usize, + ) { + match bitmap { + Some(bm) => { + self.get_builder() + .subslice_extend_from_bitmap(bm, start, length); + }, + None => { + self.extend_constant(length, true); + }, + } + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + idxs: &[IdxSize], + ) { + match bitmap { + Some(bm) => { + self.get_builder().gather_extend_from_bitmap(bm, idxs); + }, + None => { + self.extend_constant(idxs.len(), true); + }, + } + } + + pub fn opt_gather_extend_from_opt_validity( + &mut self, + bitmap: Option<&Bitmap>, + idxs: &[IdxSize], + length: usize, + ) { + match bitmap { + Some(bm) => { + self.get_builder().opt_gather_extend_from_bitmap(bm, idxs); + }, + None => { + if let Some(first_oob) = idxs.iter().position(|idx| *idx as usize >= length) { + let builder = self.get_builder(); + builder.extend_constant(first_oob, true); + for idx in idxs.iter().skip(first_oob) { + builder.push((*idx as usize) < length); + } + } else { + self.extend_constant(idxs.len(), true); + } + }, + } + } + + fn get_builder(&mut self) -> &mut BitmapBuilder { + match self { + Self::AllTrue { bit_len, bit_cap } => { + let mut builder = BitmapBuilder::with_capacity(*bit_cap); + builder.extend_constant(*bit_len, true); + *self = Self::MayHaveFalse(builder); + let Self::MayHaveFalse(inner) = self else { + unreachable!() + }; + inner + }, + Self::MayHaveFalse(inner) => inner, + } + } +} diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs new file mode 100644 index 000000000000..a8fd485944ec --- /dev/null +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -0,0 +1,779 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::ops::Deref; +use std::sync::LazyLock; +use std::sync::atomic::{AtomicU64, Ordering}; + +use either::Either; +use polars_error::{PolarsResult, polars_bail}; + +use super::utils::{self, BitChunk, BitChunks, BitmapIter, count_zeros, fmt, get_bit_unchecked}; +use super::{IntoIter, MutableBitmap, chunk_iter_to_vec, intersects_with, num_intersections_with}; +use crate::array::Splitable; +use crate::bitmap::aligned::AlignedBitmapSlice; +use crate::bitmap::iterator::{ + FastU32BitmapIter, FastU56BitmapIter, FastU64BitmapIter, TrueIdxIter, +}; +use crate::legacy::utils::FromTrustedLenIterator; +use crate::storage::SharedStorage; +use crate::trusted_len::TrustedLen; + +const UNKNOWN_BIT_COUNT: u64 = u64::MAX; + +/// An immutable container semantically equivalent to `Arc>` but represented as `Arc>` where +/// each boolean is represented as a single bit. +/// +/// # Examples +/// ``` +/// use polars_arrow::bitmap::{Bitmap, MutableBitmap}; +/// +/// let bitmap = Bitmap::from([true, false, true]); +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true]); +/// +/// // creation directly from bytes +/// let bitmap = Bitmap::try_new(vec![0b00001101], 5).unwrap(); +/// // note: the first bit is the left-most of the first byte +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true, true, false]); +/// // we can also get the slice: +/// assert_eq!(bitmap.as_slice(), ([0b00001101u8].as_ref(), 0, 5)); +/// // debug helps :) +/// assert_eq!(format!("{:?}", bitmap), "Bitmap { len: 5, offset: 0, bytes: [0b___01101] }"); +/// +/// // it supports copy-on-write semantics (to a `MutableBitmap`) +/// let bitmap: MutableBitmap = bitmap.into_mut().right().unwrap(); +/// assert_eq!(bitmap, MutableBitmap::from([true, false, true, true, false])); +/// +/// // slicing is 'O(1)' (data is shared) +/// let bitmap = Bitmap::try_new(vec![0b00001101], 5).unwrap(); +/// let mut sliced = bitmap.clone(); +/// sliced.slice(1, 4); +/// assert_eq!(sliced.as_slice(), ([0b00001101u8].as_ref(), 1, 4)); // 1 here is the offset: +/// assert_eq!(format!("{:?}", sliced), "Bitmap { len: 4, offset: 1, bytes: [0b___0110_] }"); +/// // when sliced (or cloned), it is no longer possible to `into_mut`. +/// let same: Bitmap = sliced.into_mut().left().unwrap(); +/// ``` +#[derive(Default)] +pub struct Bitmap { + storage: SharedStorage, + // Both offset and length are measured in bits. They are used to bound the + // bitmap to a region of Bytes. + offset: usize, + length: usize, + + // A bit field that contains our cache for the number of unset bits. + // If it is u64::MAX, we have no known value at all. + // Other bit patterns where the top bit is set is reserved for future use. + // If the top bit is not set we have an exact count. + unset_bit_count_cache: AtomicU64, +} + +#[inline(always)] +fn has_cached_unset_bit_count(ubcc: u64) -> bool { + ubcc >> 63 == 0 +} + +impl Clone for Bitmap { + fn clone(&self) -> Self { + Self { + storage: self.storage.clone(), + offset: self.offset, + length: self.length, + unset_bit_count_cache: AtomicU64::new( + self.unset_bit_count_cache.load(Ordering::Relaxed), + ), + } + } +} + +impl std::fmt::Debug for Bitmap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (bytes, offset, len) = self.as_slice(); + fmt(bytes, offset, len, f) + } +} + +pub(super) fn check(bytes: &[u8], offset: usize, length: usize) -> PolarsResult<()> { + if offset + length > bytes.len().saturating_mul(8) { + polars_bail!(InvalidOperation: + "The offset + length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", + offset + length, + bytes.len().saturating_mul(8) + ); + } + Ok(()) +} + +impl Bitmap { + /// Initializes an empty [`Bitmap`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Initializes a new [`Bitmap`] from vector of bytes and a length. + /// # Errors + /// This function errors iff `length > bytes.len() * 8` + #[inline] + pub fn try_new(bytes: Vec, length: usize) -> PolarsResult { + check(&bytes, 0, length)?; + Ok(Self { + storage: SharedStorage::from_vec(bytes), + length, + offset: 0, + unset_bit_count_cache: AtomicU64::new(if length == 0 { 0 } else { UNKNOWN_BIT_COUNT }), + }) + } + + /// Returns the length of the [`Bitmap`]. + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether [`Bitmap`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns a new iterator of `bool` over this bitmap + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(&self.storage, self.offset, self.length) + } + + /// Returns an iterator over bits in bit chunks [`BitChunk`]. + /// + /// This iterator is useful to operate over multiple bits via e.g. bitwise. + pub fn chunks(&self) -> BitChunks { + BitChunks::new(&self.storage, self.offset, self.length) + } + + /// Returns a fast iterator that gives 32 bits at a time. + /// Has a remainder that must be handled separately. + pub fn fast_iter_u32(&self) -> FastU32BitmapIter<'_> { + FastU32BitmapIter::new(&self.storage, self.offset, self.length) + } + + /// Returns a fast iterator that gives 56 bits at a time. + /// Has a remainder that must be handled separately. + pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> { + FastU56BitmapIter::new(&self.storage, self.offset, self.length) + } + + /// Returns a fast iterator that gives 64 bits at a time. + /// Has a remainder that must be handled separately. + pub fn fast_iter_u64(&self) -> FastU64BitmapIter<'_> { + FastU64BitmapIter::new(&self.storage, self.offset, self.length) + } + + /// Returns an iterator that only iterates over the set bits. + pub fn true_idx_iter(&self) -> TrueIdxIter<'_> { + TrueIdxIter::new(self.len(), Some(self)) + } + + /// Returns the bits of this [`Bitmap`] as a [`AlignedBitmapSlice`]. + pub fn aligned(&self) -> AlignedBitmapSlice<'_, T> { + AlignedBitmapSlice::new(&self.storage, self.offset, self.length) + } + + /// Returns the byte slice of this [`Bitmap`]. + /// + /// The returned tuple contains: + /// * `.1`: The byte slice, truncated to the start of the first bit. So the start of the slice + /// is within the first 8 bits. + /// * `.2`: The start offset in bits on a range `0 <= offsets < 8`. + /// * `.3`: The length in number of bits. + #[inline] + pub fn as_slice(&self) -> (&[u8], usize, usize) { + let start = self.offset / 8; + let len = (self.offset % 8 + self.length).saturating_add(7) / 8; + ( + &self.storage[start..start + len], + self.offset % 8, + self.length, + ) + } + + /// Returns the number of set bits on this [`Bitmap`]. + /// + /// See `unset_bits` for details. + #[inline] + pub fn set_bits(&self) -> usize { + self.length - self.unset_bits() + } + + /// Returns the number of set bits on this [`Bitmap`] if it is known. + /// + /// See `lazy_unset_bits` for details. + #[inline] + pub fn lazy_set_bits(&self) -> Option { + Some(self.length - self.lazy_unset_bits()?) + } + + /// Returns the number of unset bits on this [`Bitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// + /// # Implementation + /// + /// This function counts the number of unset bits if it is not already + /// computed. Repeated calls use the cached bitcount. + pub fn unset_bits(&self) -> usize { + self.lazy_unset_bits().unwrap_or_else(|| { + let zeros = count_zeros(&self.storage, self.offset, self.length); + self.unset_bit_count_cache + .store(zeros as u64, Ordering::Relaxed); + zeros + }) + } + + /// Returns the number of unset bits on this [`Bitmap`] if it is known. + /// + /// Guaranteed to be `<= self.len()`. + pub fn lazy_unset_bits(&self) -> Option { + let cache = self.unset_bit_count_cache.load(Ordering::Relaxed); + has_cached_unset_bit_count(cache).then_some(cache as usize) + } + + /// Updates the count of the number of set bits on this [`Bitmap`]. + /// + /// # Safety + /// + /// The number of set bits must be correct. + pub unsafe fn update_bit_count(&mut self, bits_set: usize) { + assert!(bits_set <= self.length); + let zeros = self.length - bits_set; + self.unset_bit_count_cache + .store(zeros as u64, Ordering::Relaxed); + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Panic + /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` + /// exceeds the allocated capacity of `self`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!(offset + length <= self.length); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// + /// # Safety + /// The caller must ensure that `self.offset + offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + // Fast path: no-op slice. + if offset == 0 && length == self.length { + return; + } + + // Fast path: we have no nulls or are full-null. + let unset_bit_count_cache = self.unset_bit_count_cache.get_mut(); + if *unset_bit_count_cache == 0 || *unset_bit_count_cache == self.length as u64 { + let new_count = if *unset_bit_count_cache > 0 { + length as u64 + } else { + 0 + }; + *unset_bit_count_cache = new_count; + self.offset += offset; + self.length = length; + return; + } + + if has_cached_unset_bit_count(*unset_bit_count_cache) { + // If we keep all but a small portion of the array it is worth + // doing an eager re-count since we can reuse the old count via the + // inclusion-exclusion principle. + let small_portion = (self.length / 5).max(32); + if length + small_portion >= self.length { + // Subtract the null count of the chunks we slice off. + let slice_end = self.offset + offset + length; + let head_count = count_zeros(&self.storage, self.offset, offset); + let tail_count = + count_zeros(&self.storage, slice_end, self.length - length - offset); + let new_count = *unset_bit_count_cache - head_count as u64 - tail_count as u64; + *unset_bit_count_cache = new_count; + } else { + *unset_bit_count_cache = UNKNOWN_BIT_COUNT; + } + } + + self.offset += offset; + self.length = length; + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Panic + /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` + /// exceeds the allocated capacity of `self`. + #[inline] + #[must_use] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!(offset + length <= self.length); + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// + /// # Safety + /// The caller must ensure that `self.offset + offset + length <= self.len()` + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + + /// Returns whether the bit at position `i` is set. + /// # Panics + /// Panics iff `i >= self.len()`. + #[inline] + pub fn get_bit(&self, i: usize) -> bool { + assert!(i < self.len()); + unsafe { self.get_bit_unchecked(i) } + } + + /// Unsafely returns whether the bit at position `i` is set. + /// + /// # Safety + /// Unsound iff `i >= self.len()`. + #[inline] + pub unsafe fn get_bit_unchecked(&self, i: usize) -> bool { + debug_assert!(i < self.len()); + get_bit_unchecked(&self.storage, self.offset + i) + } + + /// Returns a pointer to the start of this [`Bitmap`] (ignores `offsets`) + /// This pointer is allocated iff `self.len() > 0`. + pub(crate) fn as_ptr(&self) -> *const u8 { + self.storage.deref().as_ptr() + } + + /// Returns a pointer to the start of this [`Bitmap`] (ignores `offsets`) + /// This pointer is allocated iff `self.len() > 0`. + pub(crate) fn offset(&self) -> usize { + self.offset + } + + /// Converts this [`Bitmap`] to [`MutableBitmap`], returning itself if the conversion + /// is not possible + /// + /// This operation returns a [`MutableBitmap`] iff: + /// * this [`Bitmap`] is not an offsetted slice of another [`Bitmap`] + /// * this [`Bitmap`] has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * this [`Bitmap`] was not imported from the c data interface (FFI) + pub fn into_mut(mut self) -> Either { + match self.storage.try_into_vec() { + Ok(v) => Either::Right(MutableBitmap::from_vec(v, self.length)), + Err(storage) => { + self.storage = storage; + Either::Left(self) + }, + } + } + + /// Converts this [`Bitmap`] into a [`MutableBitmap`], cloning its internal + /// buffer if required (clone-on-write). + pub fn make_mut(self) -> MutableBitmap { + match self.into_mut() { + Either::Left(data) => { + if data.offset > 0 { + // re-align the bits (remove the offset) + let chunks = data.chunks::(); + let remainder = chunks.remainder(); + let vec = chunk_iter_to_vec(chunks.chain(std::iter::once(remainder))); + MutableBitmap::from_vec(vec, data.length) + } else { + MutableBitmap::from_vec(data.storage.as_ref().to_vec(), data.length) + } + }, + Either::Right(data) => data, + } + } + + /// Initializes an new [`Bitmap`] filled with unset values. + #[inline] + pub fn new_zeroed(length: usize) -> Self { + // We intentionally leak 1MiB of zeroed memory once so we don't have to + // refcount it. + const GLOBAL_ZERO_SIZE: usize = 1024 * 1024; + static GLOBAL_ZEROES: LazyLock> = LazyLock::new(|| { + let mut ss = SharedStorage::from_vec(vec![0; GLOBAL_ZERO_SIZE]); + ss.leak(); + ss + }); + + let bytes_needed = length.div_ceil(8); + let storage = if bytes_needed <= GLOBAL_ZERO_SIZE { + GLOBAL_ZEROES.clone() + } else { + SharedStorage::from_vec(vec![0; bytes_needed]) + }; + Self { + storage, + offset: 0, + length, + unset_bit_count_cache: AtomicU64::new(length as u64), + } + } + + /// Initializes an new [`Bitmap`] filled with the given value. + #[inline] + pub fn new_with_value(value: bool, length: usize) -> Self { + if !value { + return Self::new_zeroed(length); + } + + unsafe { + Bitmap::from_inner_unchecked( + SharedStorage::from_vec(vec![u8::MAX; length.saturating_add(7) / 8]), + 0, + length, + Some(0), + ) + } + } + + /// Counts the nulls (unset bits) starting from `offset` bits and for `length` bits. + #[inline] + pub fn null_count_range(&self, offset: usize, length: usize) -> usize { + count_zeros(&self.storage, self.offset + offset, length) + } + + /// Creates a new [`Bitmap`] from a slice and length. + /// # Panic + /// Panics iff `length > bytes.len() * 8` + #[inline] + pub fn from_u8_slice>(slice: T, length: usize) -> Self { + Bitmap::try_new(slice.as_ref().to_vec(), length).unwrap() + } + + /// Alias for `Bitmap::try_new().unwrap()` + /// This function is `O(1)` + /// # Panic + /// This function panics iff `length > bytes.len() * 8` + #[inline] + pub fn from_u8_vec(vec: Vec, length: usize) -> Self { + Bitmap::try_new(vec, length).unwrap() + } + + /// Returns whether the bit at position `i` is set. + #[inline] + pub fn get(&self, i: usize) -> Option { + if i < self.len() { + Some(unsafe { self.get_bit_unchecked(i) }) + } else { + None + } + } + + /// Creates a [`Bitmap`] from its internal representation. + /// This is the inverted from [`Bitmap::into_inner`] + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. + pub unsafe fn from_inner_unchecked( + storage: SharedStorage, + offset: usize, + length: usize, + unset_bits: Option, + ) -> Self { + debug_assert!(check(&storage[..], offset, length).is_ok()); + + let unset_bit_count_cache = if let Some(n) = unset_bits { + AtomicU64::new(n as u64) + } else { + AtomicU64::new(UNKNOWN_BIT_COUNT) + }; + Self { + storage, + offset, + length, + unset_bit_count_cache, + } + } + + /// Checks whether two [`Bitmap`]s have shared set bits. + /// + /// This is an optimized version of `(self & other) != 0000..`. + pub fn intersects_with(&self, other: &Self) -> bool { + intersects_with(self, other) + } + + /// Calculates the number of shared set bits between two [`Bitmap`]s. + pub fn num_intersections_with(&self, other: &Self) -> usize { + num_intersections_with(self, other) + } + + /// Select between `truthy` and `falsy` based on `self`. + /// + /// This essentially performs: + /// + /// `out[i] = if self[i] { truthy[i] } else { falsy[i] }` + pub fn select(&self, truthy: &Self, falsy: &Self) -> Self { + super::bitmap_ops::select(self, truthy, falsy) + } + + /// Select between `truthy` and constant `falsy` based on `self`. + /// + /// This essentially performs: + /// + /// `out[i] = if self[i] { truthy[i] } else { falsy }` + pub fn select_constant(&self, truthy: &Self, falsy: bool) -> Self { + super::bitmap_ops::select_constant(self, truthy, falsy) + } + + /// Calculates the number of edges from `0 -> 1` and `1 -> 0`. + pub fn num_edges(&self) -> usize { + super::bitmap_ops::num_edges(self) + } + + /// Returns the number of zero bits from the start before a one bit is seen + pub fn leading_zeros(&self) -> usize { + utils::leading_zeros(&self.storage, self.offset, self.length) + } + /// Returns the number of one bits from the start before a zero bit is seen + pub fn leading_ones(&self) -> usize { + utils::leading_ones(&self.storage, self.offset, self.length) + } + /// Returns the number of zero bits from the back before a one bit is seen + pub fn trailing_zeros(&self) -> usize { + utils::trailing_zeros(&self.storage, self.offset, self.length) + } + /// Returns the number of one bits from the back before a zero bit is seen + pub fn trailing_ones(&mut self) -> usize { + utils::trailing_ones(&self.storage, self.offset, self.length) + } + + /// Take all `0` bits at the start of the [`Bitmap`] before a `1` is seen, returning how many + /// bits were taken + pub fn take_leading_zeros(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == self.length) + { + let leading_zeros = self.length; + self.offset += self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return leading_zeros; + } + + let leading_zeros = self.leading_zeros(); + self.offset += leading_zeros; + self.length -= leading_zeros; + if has_cached_unset_bit_count(*self.unset_bit_count_cache.get_mut()) { + *self.unset_bit_count_cache.get_mut() -= leading_zeros as u64; + } + leading_zeros + } + /// Take all `1` bits at the start of the [`Bitmap`] before a `0` is seen, returning how many + /// bits were taken + pub fn take_leading_ones(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == 0) + { + let leading_ones = self.length; + self.offset += self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return leading_ones; + } + + let leading_ones = self.leading_ones(); + self.offset += leading_ones; + self.length -= leading_ones; + // @NOTE: the unset_bit_count_cache remains unchanged + leading_ones + } + /// Take all `0` bits at the back of the [`Bitmap`] before a `1` is seen, returning how many + /// bits were taken + pub fn take_trailing_zeros(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == self.length) + { + let trailing_zeros = self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return trailing_zeros; + } + + let trailing_zeros = self.trailing_zeros(); + self.length -= trailing_zeros; + if has_cached_unset_bit_count(*self.unset_bit_count_cache.get_mut()) { + *self.unset_bit_count_cache.get_mut() -= trailing_zeros as u64; + } + trailing_zeros + } + /// Take all `1` bits at the back of the [`Bitmap`] before a `0` is seen, returning how many + /// bits were taken + pub fn take_trailing_ones(&mut self) -> usize { + if self + .lazy_unset_bits() + .is_some_and(|unset_bits| unset_bits == 0) + { + let trailing_ones = self.length; + self.length = 0; + *self.unset_bit_count_cache.get_mut() = 0; + return trailing_ones; + } + + let trailing_ones = self.trailing_ones(); + self.length -= trailing_ones; + // @NOTE: the unset_bit_count_cache remains unchanged + trailing_ones + } +} + +impl> From

for Bitmap { + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().copied()) + } +} + +impl FromIterator for Bitmap { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + MutableBitmap::from_iter(iter).into() + } +} + +impl FromTrustedLenIterator for Bitmap { + fn from_iter_trusted_length>(iter: T) -> Self + where + T::IntoIter: TrustedLen, + { + MutableBitmap::from_trusted_len_iter(iter.into_iter()).into() + } +} + +impl Bitmap { + /// Creates a new [`Bitmap`] from an iterator of booleans. + /// + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked>(iterator: I) -> Self { + MutableBitmap::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a new [`Bitmap`] from an iterator of booleans. + #[inline] + pub fn from_trusted_len_iter>(iterator: I) -> Self { + MutableBitmap::from_trusted_len_iter(iterator).into() + } + + /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + #[inline] + pub fn try_from_trusted_len_iter>>( + iterator: I, + ) -> std::result::Result { + Ok(MutableBitmap::try_from_trusted_len_iter(iterator)?.into()) + } + + /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + /// + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked< + E, + I: Iterator>, + >( + iterator: I, + ) -> std::result::Result { + Ok(MutableBitmap::try_from_trusted_len_iter_unchecked(iterator)?.into()) + } +} + +impl<'a> IntoIterator for &'a Bitmap { + type Item = bool; + type IntoIter = BitmapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitmapIter::<'a>::new(&self.storage, self.offset, self.length) + } +} + +impl IntoIterator for Bitmap { + type Item = bool; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +impl Splitable for Bitmap { + #[inline(always)] + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + if offset == 0 { + return (Self::new(), self.clone()); + } + if offset == self.len() { + return (self.clone(), Self::new()); + } + + let ubcc = self.unset_bit_count_cache.load(Ordering::Relaxed); + + let lhs_length = offset; + let rhs_length = self.length - offset; + + let mut lhs_ubcc = UNKNOWN_BIT_COUNT; + let mut rhs_ubcc = UNKNOWN_BIT_COUNT; + + if has_cached_unset_bit_count(ubcc) { + if ubcc == 0 { + lhs_ubcc = 0; + rhs_ubcc = 0; + } else if ubcc == self.length as u64 { + lhs_ubcc = offset as u64; + rhs_ubcc = (self.length - offset) as u64; + } else { + // If we keep all but a small portion of the array it is worth + // doing an eager re-count since we can reuse the old count via the + // inclusion-exclusion principle. + let small_portion = (self.length / 4).max(32); + + if lhs_length <= rhs_length { + if rhs_length + small_portion >= self.length { + let count = count_zeros(&self.storage, self.offset, lhs_length) as u64; + lhs_ubcc = count; + rhs_ubcc = ubcc - count; + } + } else if lhs_length + small_portion >= self.length { + let count = count_zeros(&self.storage, self.offset + offset, rhs_length) as u64; + lhs_ubcc = ubcc - count; + rhs_ubcc = count; + } + } + } + + debug_assert!(lhs_ubcc == UNKNOWN_BIT_COUNT || lhs_ubcc <= ubcc); + debug_assert!(rhs_ubcc == UNKNOWN_BIT_COUNT || rhs_ubcc <= ubcc); + + ( + Self { + storage: self.storage.clone(), + offset: self.offset, + length: lhs_length, + unset_bit_count_cache: AtomicU64::new(lhs_ubcc), + }, + Self { + storage: self.storage.clone(), + offset: self.offset + offset, + length: rhs_length, + unset_bit_count_cache: AtomicU64::new(rhs_ubcc), + }, + ) + } +} diff --git a/crates/polars-arrow/src/bitmap/iterator.rs b/crates/polars-arrow/src/bitmap/iterator.rs new file mode 100644 index 000000000000..e97ec9108012 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/iterator.rs @@ -0,0 +1,419 @@ +use polars_utils::slice::load_padded_le_u64; + +use super::Bitmap; +use super::bitmask::BitMask; +use crate::trusted_len::TrustedLen; + +/// Calculates how many iterations are remaining, assuming: +/// - We have length elements left. +/// - We need max(consume, min_length_for_iter) elements to start a new iteration. +/// - On each iteration we consume the given amount of elements. +fn calc_iters_remaining(length: usize, min_length_for_iter: usize, consume: usize) -> usize { + let min_length_for_iter = min_length_for_iter.max(consume); + if length < min_length_for_iter { + return 0; + } + + let obvious_part = length - min_length_for_iter; + let obvious_iters = obvious_part / consume; + // let obvious_part_remaining = obvious_part % consume; + // let total_remaining = min_length_for_iter + obvious_part_remaining; + // assert!(total_remaining >= min_length_for_iter); // We have at least 1 more iter. + // assert!(obvious_part_remaining < consume); // Basic modulo property. + // assert!(total_remaining < min_length_for_iter + consume); // Add min_length_for_iter to both sides. + // assert!(total_remaining - consume < min_length_for_iter); // Not enough remaining after 1 iter. + 1 + obvious_iters // Thus always exactly 1 more iter. +} + +#[derive(Clone)] +pub struct TrueIdxIter<'a> { + mask: BitMask<'a>, + first_unknown: usize, + i: usize, + len: usize, + remaining: usize, +} + +impl<'a> TrueIdxIter<'a> { + #[inline] + pub fn new(len: usize, validity: Option<&'a Bitmap>) -> Self { + if let Some(bitmap) = validity { + assert!(len == bitmap.len()); + Self { + mask: BitMask::from_bitmap(bitmap), + first_unknown: 0, + i: 0, + remaining: bitmap.len() - bitmap.unset_bits(), + len, + } + } else { + Self { + mask: BitMask::default(), + first_unknown: len, + i: 0, + remaining: len, + len, + } + } + } +} + +impl Iterator for TrueIdxIter<'_> { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + // Fast path for many non-nulls in a row. + if self.i < self.first_unknown { + let ret = self.i; + self.i += 1; + self.remaining -= 1; + return Some(ret); + } + + while self.i < self.len { + let mask = self.mask.get_u32(self.i); + let num_null = mask.trailing_zeros(); + self.i += num_null as usize; + if num_null < 32 { + self.first_unknown = self.i + (mask >> num_null).trailing_ones() as usize; + let ret = self.i; + self.i += 1; + self.remaining -= 1; + return Some(ret); + } + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +unsafe impl TrustedLen for TrueIdxIter<'_> {} + +pub struct FastU32BitmapIter<'a> { + bytes: &'a [u8], + shift: u32, + bits_left: usize, +} + +impl<'a> FastU32BitmapIter<'a> { + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + assert!(bytes.len() * 8 >= offset + len); + let shift = (offset % 8) as u32; + let bytes = &bytes[offset / 8..]; + Self { + bytes, + shift, + bits_left: len, + } + } + + // The iteration logic that would normally follow the fast-path. + fn next_remainder(&mut self) -> Option { + if self.bits_left > 0 { + let word = load_padded_le_u64(self.bytes); + let mask; + if self.bits_left >= 32 { + mask = u32::MAX; + self.bits_left -= 32; + self.bytes = unsafe { self.bytes.get_unchecked(4..) }; + } else { + mask = (1 << self.bits_left) - 1; + self.bits_left = 0; + } + + return Some((word >> self.shift) as u32 & mask); + } + + None + } + + /// Returns the remainder bits and how many there are, + /// assuming the iterator was fully consumed. + pub fn remainder(mut self) -> (u64, usize) { + let bits_left = self.bits_left; + let lo = self.next_remainder().unwrap_or(0); + let hi = self.next_remainder().unwrap_or(0); + (((hi as u64) << 32) | (lo as u64), bits_left) + } +} + +impl Iterator for FastU32BitmapIter<'_> { + type Item = u32; + + #[inline] + fn next(&mut self) -> Option { + // Fast path, can load a whole u64. + if self.bits_left >= 64 { + let chunk; + unsafe { + // SAFETY: bits_left ensures this is in-bounds. + chunk = self.bytes.get_unchecked(0..8); + self.bytes = self.bytes.get_unchecked(4..); + } + self.bits_left -= 32; + let word = u64::from_le_bytes(chunk.try_into().unwrap()); + return Some((word >> self.shift) as u32); + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let hint = calc_iters_remaining(self.bits_left, 64, 32); + (hint, Some(hint)) + } +} + +unsafe impl TrustedLen for FastU32BitmapIter<'_> {} + +#[derive(Clone)] +pub struct FastU56BitmapIter<'a> { + bytes: &'a [u8], + shift: u32, + bits_left: usize, +} + +impl<'a> FastU56BitmapIter<'a> { + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + assert!(bytes.len() * 8 >= offset + len); + let shift = (offset % 8) as u32; + let bytes = &bytes[offset / 8..]; + Self { + bytes, + shift, + bits_left: len, + } + } + + // The iteration logic that would normally follow the fast-path. + fn next_remainder(&mut self) -> Option { + if self.bits_left > 0 { + let word = load_padded_le_u64(self.bytes); + let mask; + if self.bits_left >= 56 { + mask = (1 << 56) - 1; + self.bits_left -= 56; + self.bytes = unsafe { self.bytes.get_unchecked(7..) }; + } else { + mask = (1 << self.bits_left) - 1; + self.bits_left = 0; + }; + + return Some((word >> self.shift) & mask); + } + + None + } + + /// Returns the remainder bits and how many there are, + /// assuming the iterator was fully consumed. Output is safe but + /// not specified if the iterator wasn't fully consumed. + pub fn remainder(mut self) -> (u64, usize) { + let bits_left = self.bits_left; + let lo = self.next_remainder().unwrap_or(0); + let hi = self.next_remainder().unwrap_or(0); + ((hi << 56) | lo, bits_left) + } +} + +impl Iterator for FastU56BitmapIter<'_> { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + // Fast path, can load a whole u64. + if self.bits_left >= 64 { + let chunk; + unsafe { + // SAFETY: bits_left ensures this is in-bounds. + chunk = self.bytes.get_unchecked(0..8); + self.bytes = self.bytes.get_unchecked(7..); + self.bits_left -= 56; + } + + let word = u64::from_le_bytes(chunk.try_into().unwrap()); + let mask = (1 << 56) - 1; + return Some((word >> self.shift) & mask); + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let hint = calc_iters_remaining(self.bits_left, 64, 56); + (hint, Some(hint)) + } +} + +unsafe impl TrustedLen for FastU56BitmapIter<'_> {} + +pub struct FastU64BitmapIter<'a> { + bytes: &'a [u8], + shift: u32, + bits_left: usize, + next_word: u64, +} + +impl<'a> FastU64BitmapIter<'a> { + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + assert!(bytes.len() * 8 >= offset + len); + let shift = (offset % 8) as u32; + let bytes = &bytes[offset / 8..]; + let next_word = load_padded_le_u64(bytes); + let bytes = bytes.get(8..).unwrap_or(&[]); + Self { + bytes, + shift, + bits_left: len, + next_word, + } + } + + #[inline] + fn combine(&self, lo: u64, hi: u64) -> u64 { + // Compiles to 128-bit SHRD instruction on x86-64. + // Yes, the % 64 is important for the compiler to generate optimal code. + let wide = ((hi as u128) << 64) | lo as u128; + (wide >> (self.shift % 64)) as u64 + } + + // The iteration logic that would normally follow the fast-path. + fn next_remainder(&mut self) -> Option { + if self.bits_left > 0 { + let lo = self.next_word; + let hi = load_padded_le_u64(self.bytes); + let mask; + if self.bits_left >= 64 { + mask = u64::MAX; + self.bits_left -= 64; + self.bytes = self.bytes.get(8..).unwrap_or(&[]); + } else { + mask = (1 << self.bits_left) - 1; + self.bits_left = 0; + }; + self.next_word = hi; + + return Some(self.combine(lo, hi) & mask); + } + + None + } + + /// Returns the remainder bits and how many there are, + /// assuming the iterator was fully consumed. Output is safe but + /// not specified if the iterator wasn't fully consumed. + pub fn remainder(mut self) -> ([u64; 2], usize) { + let bits_left = self.bits_left; + let lo = self.next_remainder().unwrap_or(0); + let hi = self.next_remainder().unwrap_or(0); + ([lo, hi], bits_left) + } +} + +impl Iterator for FastU64BitmapIter<'_> { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + // Fast path: can load two u64s in a row. + // (Note that we already loaded one in the form of self.next_word). + if self.bits_left >= 128 { + let chunk; + unsafe { + // SAFETY: bits_left ensures this is in-bounds. + chunk = self.bytes.get_unchecked(0..8); + self.bytes = self.bytes.get_unchecked(8..); + } + let lo = self.next_word; + let hi = u64::from_le_bytes(chunk.try_into().unwrap()); + self.next_word = hi; + self.bits_left -= 64; + + return Some(self.combine(lo, hi)); + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let hint = calc_iters_remaining(self.bits_left, 128, 64); + (hint, Some(hint)) + } +} + +unsafe impl TrustedLen for FastU64BitmapIter<'_> {} + +/// This crates' equivalent of [`std::vec::IntoIter`] for [`Bitmap`]. +#[derive(Debug, Clone)] +pub struct IntoIter { + values: Bitmap, + index: usize, + end: usize, +} + +impl IntoIter { + /// Creates a new [`IntoIter`] from a [`Bitmap`] + #[inline] + pub fn new(values: Bitmap) -> Self { + let end = values.len(); + Self { + values, + index: 0, + end, + } + } +} + +impl Iterator for IntoIter { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(unsafe { self.values.get_bit_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(unsafe { self.values.get_bit_unchecked(self.end) }) + } + } +} + +unsafe impl TrustedLen for IntoIter {} diff --git a/crates/polars-arrow/src/bitmap/mod.rs b/crates/polars-arrow/src/bitmap/mod.rs new file mode 100644 index 000000000000..6d518bf596b4 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/mod.rs @@ -0,0 +1,24 @@ +//! contains [`Bitmap`] and [`MutableBitmap`], containers of `bool`. +mod immutable; +pub use immutable::*; + +pub mod iterator; +pub use iterator::IntoIter; + +mod mutable; +pub use mutable::MutableBitmap; + +mod bitmap_ops; +pub use bitmap_ops::*; + +pub mod aligned; + +mod assign_ops; +pub use assign_ops::*; + +pub mod utils; + +pub mod bitmask; + +mod builder; +pub use builder::*; diff --git a/crates/polars-arrow/src/bitmap/mutable.rs b/crates/polars-arrow/src/bitmap/mutable.rs new file mode 100644 index 000000000000..be3397e0162e --- /dev/null +++ b/crates/polars-arrow/src/bitmap/mutable.rs @@ -0,0 +1,851 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::hint::unreachable_unchecked; + +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::vec::PushUnchecked; + +use super::bitmask::BitMask; +use super::utils::{BitChunk, BitChunks, BitChunksExactMut, BitmapIter, count_zeros, fmt}; +use super::{Bitmap, intersects_with_mut}; +use crate::bitmap::utils::{get_bit_unchecked, merge_reversed, set_bit_in_byte}; +use crate::storage::SharedStorage; +use crate::trusted_len::TrustedLen; + +/// A container of booleans. [`MutableBitmap`] is semantically equivalent +/// to [`Vec`]. +/// +/// The two main differences against [`Vec`] is that each element stored as a single bit, +/// thereby: +/// * it uses 8x less memory +/// * it cannot be represented as `&[bool]` (i.e. no pointer arithmetics). +/// +/// A [`MutableBitmap`] can be converted to a [`Bitmap`] at `O(1)`. +/// # Examples +/// ``` +/// use polars_arrow::bitmap::MutableBitmap; +/// +/// let bitmap = MutableBitmap::from([true, false, true]); +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true]); +/// +/// // creation directly from bytes +/// let mut bitmap = MutableBitmap::try_new(vec![0b00001101], 5).unwrap(); +/// // note: the first bit is the left-most of the first byte +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true, true, false]); +/// // we can also get the slice: +/// assert_eq!(bitmap.as_slice(), [0b00001101u8].as_ref()); +/// // debug helps :) +/// assert_eq!(format!("{:?}", bitmap), "Bitmap { len: 5, offset: 0, bytes: [0b___01101] }"); +/// +/// // It supports mutation in place +/// bitmap.set(0, false); +/// assert_eq!(format!("{:?}", bitmap), "Bitmap { len: 5, offset: 0, bytes: [0b___01100] }"); +/// // and `O(1)` random access +/// assert_eq!(bitmap.get(0), false); +/// ``` +/// # Implementation +/// This container is internally a [`Vec`]. +#[derive(Clone)] +pub struct MutableBitmap { + buffer: Vec, + // invariant: length.saturating_add(7) / 8 == buffer.len(); + length: usize, +} + +impl std::fmt::Debug for MutableBitmap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt(&self.buffer, 0, self.len(), f) + } +} + +impl PartialEq for MutableBitmap { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl MutableBitmap { + /// Initializes an empty [`MutableBitmap`]. + #[inline] + pub fn new() -> Self { + Self { + buffer: Vec::new(), + length: 0, + } + } + + /// Initializes a new [`MutableBitmap`] from a [`Vec`] and a length. + /// # Errors + /// This function errors iff `length > bytes.len() * 8` + #[inline] + pub fn try_new(mut bytes: Vec, length: usize) -> PolarsResult { + if length > bytes.len().saturating_mul(8) { + polars_bail!(InvalidOperation: + "The length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", + length, + bytes.len().saturating_mul(8) + ) + } + + // Ensure invariant holds. + let min_byte_length_needed = length.div_ceil(8); + bytes.drain(min_byte_length_needed..); + Ok(Self { + length, + buffer: bytes, + }) + } + + /// Initializes a [`MutableBitmap`] from a [`Vec`] and a length. + /// This function is `O(1)`. + /// # Panic + /// Panics iff the length is larger than the length of the buffer times 8. + #[inline] + pub fn from_vec(buffer: Vec, length: usize) -> Self { + Self::try_new(buffer, length).unwrap() + } + + /// Initializes a pre-allocated [`MutableBitmap`] with capacity for `capacity` bits. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buffer: Vec::with_capacity(capacity.saturating_add(7) / 8), + length: 0, + } + } + + /// Pushes a new bit to the [`MutableBitmap`], re-sizing it if necessary. + #[inline] + pub fn push(&mut self, value: bool) { + if self.length % 8 == 0 { + self.buffer.push(0); + } + let byte = unsafe { self.buffer.last_mut().unwrap_unchecked() }; + *byte = set_bit_in_byte(*byte, self.length % 8, value); + self.length += 1; + } + + /// Pop the last bit from the [`MutableBitmap`]. + /// Note if the [`MutableBitmap`] is empty, this method will return None. + #[inline] + pub fn pop(&mut self) -> Option { + if self.is_empty() { + return None; + } + + self.length -= 1; + let value = unsafe { self.get_unchecked(self.length) }; + if self.length % 8 == 0 { + self.buffer.pop(); + } + Some(value) + } + + /// Returns whether the position `index` is set. + /// # Panics + /// Panics iff `index >= self.len()`. + #[inline] + pub fn get(&self, index: usize) -> bool { + assert!(index < self.len()); + unsafe { self.get_unchecked(index) } + } + + /// Returns whether the position `index` is set. + /// + /// # Safety + /// The caller must ensure `index < self.len()`. + #[inline] + pub unsafe fn get_unchecked(&self, index: usize) -> bool { + get_bit_unchecked(&self.buffer, index) + } + + /// Sets the position `index` to `value` + /// # Panics + /// Panics iff `index >= self.len()`. + #[inline] + pub fn set(&mut self, index: usize, value: bool) { + assert!(index < self.len()); + unsafe { + self.set_unchecked(index, value); + } + } + + /// Sets the position `index` to the OR of its original value and `value`. + /// + /// # Safety + /// It's undefined behavior if index >= self.len(). + #[inline] + pub unsafe fn or_pos_unchecked(&mut self, index: usize, value: bool) { + *self.buffer.get_unchecked_mut(index / 8) |= (value as u8) << (index % 8); + } + + /// Sets the position `index` to the AND of its original value and `value`. + /// + /// # Safety + /// It's undefined behavior if index >= self.len(). + #[inline] + pub unsafe fn and_pos_unchecked(&mut self, index: usize, value: bool) { + *self.buffer.get_unchecked_mut(index / 8) &= (value as u8) << (index % 8); + } + + /// constructs a new iterator over the bits of [`MutableBitmap`]. + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(&self.buffer, 0, self.length) + } + + /// Empties the [`MutableBitmap`]. + #[inline] + pub fn clear(&mut self) { + self.length = 0; + self.buffer.clear(); + } + + /// Extends [`MutableBitmap`] by `additional` values of constant `value`. + /// # Implementation + /// This function is an order of magnitude faster than pushing element by element. + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: bool) { + if additional == 0 { + return; + } + + if value { + self.extend_set(additional) + } else { + self.extend_unset(additional) + } + } + + /// Resizes the [`MutableBitmap`] to the specified length, inserting value + /// if the length is bigger than the current length. + pub fn resize(&mut self, length: usize, value: bool) { + if let Some(additional) = length.checked_sub(self.len()) { + self.extend_constant(additional, value); + } else { + self.buffer.truncate(length.saturating_add(7) / 8); + self.length = length; + } + } + + /// Initializes a zeroed [`MutableBitmap`]. + #[inline] + pub fn from_len_zeroed(length: usize) -> Self { + Self { + buffer: vec![0; length.saturating_add(7) / 8], + length, + } + } + + /// Initializes a [`MutableBitmap`] with all values set to valid/ true. + #[inline] + pub fn from_len_set(length: usize) -> Self { + Self { + buffer: vec![u8::MAX; length.saturating_add(7) / 8], + length, + } + } + + /// Reserves `additional` bits in the [`MutableBitmap`], potentially re-allocating its buffer. + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + self.buffer + .reserve((self.length + additional).saturating_add(7) / 8 - self.buffer.len()) + } + + /// Returns the capacity of [`MutableBitmap`] in number of bits. + #[inline] + pub fn capacity(&self) -> usize { + self.buffer.capacity() * 8 + } + + /// Pushes a new bit to the [`MutableBitmap`] + /// + /// # Safety + /// The caller must ensure that the [`MutableBitmap`] has sufficient capacity. + #[inline] + pub unsafe fn push_unchecked(&mut self, value: bool) { + if self.length % 8 == 0 { + self.buffer.push_unchecked(0); + } + let byte = self.buffer.last_mut().unwrap_unchecked(); + *byte = set_bit_in_byte(*byte, self.length % 8, value); + self.length += 1; + } + + /// Returns the number of unset bits on this [`MutableBitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// # Implementation + /// This function is `O(N)` + pub fn unset_bits(&self) -> usize { + count_zeros(&self.buffer, 0, self.length) + } + + /// Returns the number of set bits on this [`MutableBitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// # Implementation + /// This function is `O(N)` + pub fn set_bits(&self) -> usize { + self.length - self.unset_bits() + } + + /// Returns the length of the [`MutableBitmap`]. + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether [`MutableBitmap`] is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// # Safety + /// The caller must ensure that the [`MutableBitmap`] was properly initialized up to `len`. + #[inline] + pub(crate) unsafe fn set_len(&mut self, len: usize) { + self.buffer.set_len(len.saturating_add(7) / 8); + self.length = len; + } + + fn extend_set(&mut self, mut additional: usize) { + let offset = self.length % 8; + let added = if offset != 0 { + // offset != 0 => at least one byte in the buffer + let last_index = self.buffer.len() - 1; + let last = &mut self.buffer[last_index]; + + let remaining = 0b11111111u8; + let remaining = remaining >> 8usize.saturating_sub(additional); + let remaining = remaining << offset; + *last |= remaining; + std::cmp::min(additional, 8 - offset) + } else { + 0 + }; + self.length += added; + additional = additional.saturating_sub(added); + if additional > 0 { + debug_assert_eq!(self.length % 8, 0); + let existing = self.length.saturating_add(7) / 8; + let required = (self.length + additional).saturating_add(7) / 8; + // add remaining as full bytes + self.buffer + .extend(std::iter::repeat_n(0b11111111u8, required - existing)); + self.length += additional; + } + } + + fn extend_unset(&mut self, mut additional: usize) { + let offset = self.length % 8; + let added = if offset != 0 { + // offset != 0 => at least one byte in the buffer + let last_index = self.buffer.len() - 1; + let last = &mut self.buffer[last_index]; + *last &= 0b11111111u8 >> (8 - offset); // unset them + std::cmp::min(additional, 8 - offset) + } else { + 0 + }; + self.length += added; + additional = additional.saturating_sub(added); + if additional > 0 { + debug_assert_eq!(self.length % 8, 0); + self.buffer + .resize((self.length + additional).saturating_add(7) / 8, 0); + self.length += additional; + } + } + + /// Sets the position `index` to `value` + /// + /// # Safety + /// Caller must ensure that `index < self.len()` + #[inline] + pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { + debug_assert!(index < self.len()); + let byte = self.buffer.get_unchecked_mut(index / 8); + *byte = set_bit_in_byte(*byte, index % 8, value); + } + + /// Shrinks the capacity of the [`MutableBitmap`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + + /// Returns an iterator over bits in bit chunks [`BitChunk`]. + /// + /// This iterator is useful to operate over multiple bits via e.g. bitwise. + pub fn chunks(&self) -> BitChunks { + BitChunks::new(&self.buffer, 0, self.length) + } + + /// Returns an iterator over mutable slices, [`BitChunksExactMut`] + pub(crate) fn bitchunks_exact_mut(&mut self) -> BitChunksExactMut { + BitChunksExactMut::new(&mut self.buffer, self.length) + } + + pub fn intersects_with(&self, other: &Self) -> bool { + intersects_with_mut(self, other) + } + + pub fn freeze(self) -> Bitmap { + self.into() + } +} + +impl From for Bitmap { + #[inline] + fn from(buffer: MutableBitmap) -> Self { + Bitmap::try_new(buffer.buffer, buffer.length).unwrap() + } +} + +impl From for Option { + #[inline] + fn from(buffer: MutableBitmap) -> Self { + let unset_bits = buffer.unset_bits(); + if unset_bits > 0 { + // SAFETY: invariants of the `MutableBitmap` equal that of `Bitmap`. + let bitmap = unsafe { + Bitmap::from_inner_unchecked( + SharedStorage::from_vec(buffer.buffer), + 0, + buffer.length, + Some(unset_bits), + ) + }; + Some(bitmap) + } else { + None + } + } +} + +impl> From

for MutableBitmap { + #[inline] + fn from(slice: P) -> Self { + MutableBitmap::from_trusted_len_iter(slice.as_ref().iter().copied()) + } +} + +impl Extend for MutableBitmap { + fn extend>(&mut self, iter: T) { + let mut iterator = iter.into_iter(); + + let mut buffer = std::mem::take(&mut self.buffer); + let mut length = std::mem::take(&mut self.length); + + let byte_capacity: usize = iterator.size_hint().0.saturating_add(7) / 8; + buffer.reserve(byte_capacity); + + loop { + let mut exhausted = false; + let mut byte_accum: u8 = 0; + let mut mask: u8 = 1; + + //collect (up to) 8 bits into a byte + while mask != 0 { + if let Some(value) = iterator.next() { + length += 1; + byte_accum |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } else { + exhausted = true; + break; + } + } + + // break if the iterator was exhausted before it provided a bool for this byte + if exhausted && mask == 1 { + break; + } + + //ensure we have capacity to write the byte + if buffer.len() == buffer.capacity() { + //no capacity for new byte, allocate 1 byte more (plus however many more the iterator advertises) + let additional_byte_capacity = 1usize.saturating_add( + iterator.size_hint().0.saturating_add(7) / 8, //convert bit count to byte count, rounding up + ); + buffer.reserve(additional_byte_capacity) + } + + // Soundness: capacity was allocated above + buffer.push(byte_accum); + if exhausted { + break; + } + } + + self.buffer = buffer; + self.length = length; + } +} + +impl FromIterator for MutableBitmap { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut bm = Self::new(); + bm.extend(iter); + bm + } +} + +// [7, 6, 5, 4, 3, 2, 1, 0], [15, 14, 13, 12, 11, 10, 9, 8] +// [00000001_00000000_00000000_00000000_...] // u64 +/// # Safety +/// The iterator must be trustedLen and its len must be least `len`. +#[inline] +unsafe fn get_chunk_unchecked(iterator: &mut impl Iterator) -> u64 { + let mut byte = 0u64; + let mut mask; + for i in 0..8 { + mask = 1u64 << (8 * i); + for _ in 0..8 { + let value = match iterator.next() { + Some(value) => value, + None => unsafe { unreachable_unchecked() }, + }; + + byte |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } + } + byte +} + +/// # Safety +/// The iterator must be trustedLen and its len must be least `len`. +#[inline] +unsafe fn get_byte_unchecked(len: usize, iterator: &mut impl Iterator) -> u8 { + let mut byte_accum: u8 = 0; + let mut mask: u8 = 1; + for _ in 0..len { + let value = match iterator.next() { + Some(value) => value, + None => unsafe { unreachable_unchecked() }, + }; + + byte_accum |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } + byte_accum +} + +/// Extends the [`Vec`] from `iterator` +/// # Safety +/// The iterator MUST be [`TrustedLen`]. +#[inline] +unsafe fn extend_aligned_trusted_iter_unchecked( + buffer: &mut Vec, + mut iterator: impl Iterator, +) -> usize { + let additional_bits = iterator.size_hint().1.unwrap(); + let chunks = additional_bits / 64; + let remainder = additional_bits % 64; + + let additional = additional_bits.div_ceil(8); + assert_eq!( + additional, + // a hint of how the following calculation will be done + chunks * 8 + remainder / 8 + (remainder % 8 > 0) as usize + ); + buffer.reserve(additional); + + // chunks of 64 bits + for _ in 0..chunks { + let chunk = get_chunk_unchecked(&mut iterator); + buffer.extend_from_slice(&chunk.to_le_bytes()); + } + + // remaining complete bytes + for _ in 0..(remainder / 8) { + let byte = unsafe { get_byte_unchecked(8, &mut iterator) }; + buffer.push(byte) + } + + // remaining bits + let remainder = remainder % 8; + if remainder > 0 { + let byte = unsafe { get_byte_unchecked(remainder, &mut iterator) }; + buffer.push(byte) + } + additional_bits +} + +impl MutableBitmap { + /// Extends `self` from a [`TrustedLen`] iterator. + #[inline] + pub fn extend_from_trusted_len_iter>(&mut self, iterator: I) { + // SAFETY: I: TrustedLen + unsafe { self.extend_from_trusted_len_iter_unchecked(iterator) } + } + + /// Extends `self` from an iterator of trusted len. + /// + /// # Safety + /// The caller must guarantee that the iterator has a trusted len. + #[inline] + pub unsafe fn extend_from_trusted_len_iter_unchecked>( + &mut self, + mut iterator: I, + ) { + // the length of the iterator throughout this function. + let mut length = iterator.size_hint().1.unwrap(); + + let bit_offset = self.length % 8; + + if length < 8 - bit_offset { + if bit_offset == 0 { + self.buffer.push(0); + } + // the iterator will not fill the last byte + let byte = self.buffer.last_mut().unwrap(); + let mut i = bit_offset; + for value in iterator { + *byte = set_bit_in_byte(*byte, i, value); + i += 1; + } + self.length += length; + return; + } + + // at this point we know that length will hit a byte boundary and thus + // increase the buffer. + + if bit_offset != 0 { + // we are in the middle of a byte; lets finish it + let byte = self.buffer.last_mut().unwrap(); + (bit_offset..8).for_each(|i| { + *byte = set_bit_in_byte(*byte, i, iterator.next().unwrap()); + }); + self.length += 8 - bit_offset; + length -= 8 - bit_offset; + } + + // everything is aligned; proceed with the bulk operation + debug_assert_eq!(self.length % 8, 0); + + unsafe { extend_aligned_trusted_iter_unchecked(&mut self.buffer, iterator) }; + self.length += length; + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + /// + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + I: Iterator, + { + let mut buffer = Vec::::new(); + + let length = extend_aligned_trusted_iter_unchecked(&mut buffer, iterator); + + Self { buffer, length } + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + I: TrustedLen, + { + // SAFETY: Iterator is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + I: TrustedLen>, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBitmap`] from an falible iterator of booleans. + /// + /// # Safety + /// The caller must guarantee that the iterator is `TrustedLen`. + pub unsafe fn try_from_trusted_len_iter_unchecked( + mut iterator: I, + ) -> std::result::Result + where + I: Iterator>, + { + let length = iterator.size_hint().1.unwrap(); + + let mut buffer = vec![0u8; length.div_ceil(8)]; + + let chunks = length / 8; + let reminder = length % 8; + + let data = buffer.as_mut_slice(); + data[..chunks].iter_mut().try_for_each(|byte| { + (0..8).try_for_each(|i| { + *byte = set_bit_in_byte(*byte, i, iterator.next().unwrap()?); + Ok(()) + }) + })?; + + if reminder != 0 { + let last = &mut data[chunks]; + iterator.enumerate().try_for_each(|(i, value)| { + *last = set_bit_in_byte(*last, i, value?); + Ok(()) + })?; + } + + Ok(Self { buffer, length }) + } + + fn extend_unaligned(&mut self, slice: &[u8], offset: usize, length: usize) { + // e.g. + // [a, b, --101010] <- to be extended + // [00111111, 11010101] <- to extend + // [a, b, 11101010, --001111] expected result + + let aligned_offset = offset / 8; + let own_offset = self.length % 8; + debug_assert_eq!(offset % 8, 0); // assumed invariant + debug_assert!(own_offset != 0); // assumed invariant + + let bytes_len = length.saturating_add(7) / 8; + let items = &slice[aligned_offset..aligned_offset + bytes_len]; + // self has some offset => we need to shift all `items`, and merge the first + let buffer = self.buffer.as_mut_slice(); + let last = &mut buffer[buffer.len() - 1]; + + // --101010 | 00111111 << 6 = 11101010 + // erase previous + *last &= 0b11111111u8 >> (8 - own_offset); // unset before setting + *last |= items[0] << own_offset; + + if length + own_offset <= 8 { + // no new bytes needed + self.length += length; + return; + } + let additional = length - (8 - own_offset); + + let remaining = [items[items.len() - 1], 0]; + let bytes = items + .windows(2) + .chain(std::iter::once(remaining.as_ref())) + .map(|w| merge_reversed(w[0], w[1], 8 - own_offset)) + .take(additional.saturating_add(7) / 8); + self.buffer.extend(bytes); + + self.length += length; + } + + fn extend_aligned(&mut self, slice: &[u8], offset: usize, length: usize) { + let aligned_offset = offset / 8; + let bytes_len = length.saturating_add(7) / 8; + let items = &slice[aligned_offset..aligned_offset + bytes_len]; + self.buffer.extend_from_slice(items); + self.length += length; + } + + /// Extends the [`MutableBitmap`] from a slice of bytes with optional offset. + /// This is the fastest way to extend a [`MutableBitmap`]. + /// # Implementation + /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, + /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + /// + /// # Safety + /// Caller must ensure `offset + length <= slice.len() * 8` + #[inline] + pub unsafe fn extend_from_slice_unchecked( + &mut self, + slice: &[u8], + offset: usize, + length: usize, + ) { + if length == 0 { + return; + }; + let is_aligned = self.length % 8 == 0; + let other_is_aligned = offset % 8 == 0; + match (is_aligned, other_is_aligned) { + (true, true) => self.extend_aligned(slice, offset, length), + (false, true) => self.extend_unaligned(slice, offset, length), + // todo: further optimize the other branches. + _ => self.extend_from_trusted_len_iter(BitmapIter::new(slice, offset, length)), + } + // internal invariant: + debug_assert_eq!(self.length.saturating_add(7) / 8, self.buffer.len()); + } + + /// Extends the [`MutableBitmap`] from a slice of bytes with optional offset. + /// This is the fastest way to extend a [`MutableBitmap`]. + /// # Implementation + /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, + /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + #[inline] + pub fn extend_from_slice(&mut self, slice: &[u8], offset: usize, length: usize) { + assert!(offset + length <= slice.len() * 8); + // SAFETY: invariant is asserted + unsafe { self.extend_from_slice_unchecked(slice, offset, length) } + } + + #[inline] + pub fn extend_from_bitmask(&mut self, bitmask: BitMask<'_>) { + let (slice, offset, length) = bitmask.inner(); + self.extend_from_slice(slice, offset, length) + } + + /// Extends the [`MutableBitmap`] from a [`Bitmap`]. + #[inline] + pub fn extend_from_bitmap(&mut self, bitmap: &Bitmap) { + let (slice, offset, length) = bitmap.as_slice(); + // SAFETY: bitmap.as_slice adheres to the invariant + unsafe { + self.extend_from_slice_unchecked(slice, offset, length); + } + } + + /// Returns the slice of bytes of this [`MutableBitmap`]. + /// Note that the last byte may not be fully used. + #[inline] + pub fn as_slice(&self) -> &[u8] { + let len = (self.length).saturating_add(7) / 8; + &self.buffer[..len] + } + + /// Returns the slice of bytes of this [`MutableBitmap`]. + /// Note that the last byte may not be fully used. + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [u8] { + let len = (self.length).saturating_add(7) / 8; + &mut self.buffer[..len] + } +} + +impl Default for MutableBitmap { + fn default() -> Self { + Self::new() + } +} + +impl<'a> IntoIterator for &'a MutableBitmap { + type Item = bool; + type IntoIter = BitmapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitmapIter::<'a>::new(&self.buffer, 0, self.length) + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs new file mode 100644 index 000000000000..50f4e023ac1f --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -0,0 +1,100 @@ +use std::slice::ChunksExact; + +use super::{BitChunk, BitChunkIterExact}; +use crate::trusted_len::TrustedLen; + +/// An iterator over a slice of bytes in [`BitChunk`]s. +#[derive(Debug)] +pub struct BitChunksExact<'a, T: BitChunk> { + iter: ChunksExact<'a, u8>, + remainder: &'a [u8], + remainder_len: usize, + phantom: std::marker::PhantomData, +} + +impl<'a, T: BitChunk> BitChunksExact<'a, T> { + /// Creates a new [`BitChunksExact`]. + #[inline] + pub fn new(bitmap: &'a [u8], length: usize) -> Self { + assert!(length <= bitmap.len() * 8); + let size_of = size_of::(); + + let bitmap = &bitmap[..length.saturating_add(7) / 8]; + + let split = (length / 8 / size_of) * size_of; + let (chunks, remainder) = bitmap.split_at(split); + let remainder_len = length - chunks.len() * 8; + let iter = chunks.chunks_exact(size_of); + + Self { + iter, + remainder, + remainder_len, + phantom: std::marker::PhantomData, + } + } + + /// Returns the number of chunks of this iterator + #[inline] + pub fn len(&self) -> usize { + self.iter.len() + } + + /// Returns whether there are still elements in this iterator + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the remaining [`BitChunk`]. It is zero iff `len / 8 == 0`. + #[inline] + pub fn remainder(&self) -> T { + let remainder_bytes = self.remainder; + if remainder_bytes.is_empty() { + return T::zero(); + } + let remainder = match remainder_bytes.try_into() { + Ok(a) => a, + Err(_) => { + let mut remainder = T::zero().to_ne_bytes(); + remainder_bytes + .iter() + .enumerate() + .for_each(|(index, b)| remainder[index] = *b); + remainder + }, + }; + T::from_ne_bytes(remainder) + } +} + +impl Iterator for BitChunksExact<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|x| match x.try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +unsafe impl TrustedLen for BitChunksExact<'_, T> {} + +impl BitChunkIterExact for BitChunksExact<'_, T> { + #[inline] + fn remainder(&self) -> T { + self.remainder() + } + + #[inline] + fn remainder_len(&self) -> usize { + self.remainder_len + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs new file mode 100644 index 000000000000..680d3bf96fa4 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/merge.rs @@ -0,0 +1,61 @@ +use super::BitChunk; + +/// Merges 2 [`BitChunk`]s into a single [`BitChunk`] so that the new items represents +/// the bitmap where bits from `next` are placed in `current` according to `offset`. +/// # Panic +/// The caller must ensure that `0 < offset < size_of::() * 8` +/// # Example +/// ```rust,ignore +/// let current = 0b01011001; +/// let next = 0b01011011; +/// let result = merge_reversed(current, next, 1); +/// assert_eq!(result, 0b10101100); +/// ``` +#[inline] +pub fn merge_reversed(mut current: T, mut next: T, offset: usize) -> T +where + T: BitChunk, +{ + // 8 _bits_: + // current = [c0, c1, c2, c3, c4, c5, c6, c7] + // next = [n0, n1, n2, n3, n4, n5, n6, n7] + // offset = 3 + // expected = [n5, n6, n7, c0, c1, c2, c3, c4] + + // 1. unset most significants of `next` up to `offset` + let inverse_offset = size_of::() * 8 - offset; + next <<= inverse_offset; + // next = [n5, n6, n7, 0 , 0 , 0 , 0 , 0 ] + + // 2. unset least significants of `current` up to `offset` + current >>= offset; + // current = [0 , 0 , 0 , c0, c1, c2, c3, c4] + + current | next +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_merge_reversed() { + let current = 0b00000000; + let next = 0b00000001; + let result = merge_reversed::(current, next, 1); + assert_eq!(result, 0b10000000); + + let current = 0b01011001; + let next = 0b01011011; + let result = merge_reversed::(current, next, 1); + assert_eq!(result, 0b10101100); + } + + #[test] + fn test_merge_reversed_offset2() { + let current = 0b00000000; + let next = 0b00000001; + let result = merge_reversed::(current, next, 3); + assert_eq!(result, 0b00100000); + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs new file mode 100644 index 000000000000..ec158d743276 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -0,0 +1,204 @@ +mod chunks_exact; +mod merge; + +pub use chunks_exact::BitChunksExact; +pub(crate) use merge::merge_reversed; + +use crate::trusted_len::TrustedLen; +pub use crate::types::BitChunk; +use crate::types::BitChunkIter; + +/// Trait representing an exact iterator over bytes in [`BitChunk`]. +pub trait BitChunkIterExact: TrustedLen { + /// The remainder of the iterator. + fn remainder(&self) -> B; + + /// The number of items in the remainder + fn remainder_len(&self) -> usize; + + /// An iterator over individual items of the remainder + #[inline] + fn remainder_iter(&self) -> BitChunkIter { + BitChunkIter::new(self.remainder(), self.remainder_len()) + } +} + +/// This struct is used to efficiently iterate over bit masks by loading bytes on +/// the stack with alignments of `uX`. This allows efficient iteration over bitmaps. +#[derive(Debug)] +pub struct BitChunks<'a, T: BitChunk> { + chunk_iterator: std::slice::ChunksExact<'a, u8>, + current: T, + remainder_bytes: &'a [u8], + last_chunk: T, + remaining: usize, + /// offset inside a byte + bit_offset: usize, + len: usize, + phantom: std::marker::PhantomData, +} + +/// writes `bytes` into `dst`. +#[inline] +fn copy_with_merge(dst: &mut T::Bytes, bytes: &[u8], bit_offset: usize) { + bytes + .windows(2) + .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref())) + .take(size_of::()) + .enumerate() + .for_each(|(i, w)| { + let val = merge_reversed(w[0], w[1], bit_offset); + dst[i] = val; + }); +} + +impl<'a, T: BitChunk> BitChunks<'a, T> { + /// Creates a [`BitChunks`]. + pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self { + assert!(offset + len <= slice.len() * 8); + + let slice = &slice[offset / 8..]; + let bit_offset = offset % 8; + let size_of = size_of::(); + + let bytes_len = len / 8; + let bytes_upper_len = (len + bit_offset).div_ceil(8); + let mut chunks = slice[..bytes_len].chunks_exact(size_of); + + let remainder = &slice[bytes_len - chunks.remainder().len()..bytes_upper_len]; + + let remainder_bytes = if chunks.len() == 0 { slice } else { remainder }; + + let last_chunk = remainder_bytes + .first() + .map(|first| { + let mut last = T::zero().to_ne_bytes(); + last[0] = *first; + T::from_ne_bytes(last) + }) + .unwrap_or_else(T::zero); + + let remaining = chunks.size_hint().0; + + let current = chunks + .next() + .map(|x| match x.try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }) + .unwrap_or_else(T::zero); + + Self { + chunk_iterator: chunks, + len, + current, + remaining, + remainder_bytes, + last_chunk, + bit_offset, + phantom: std::marker::PhantomData, + } + } + + #[inline] + fn load_next(&mut self) { + self.current = match self.chunk_iterator.next().unwrap().try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + } + + /// Returns the remainder [`BitChunk`]. + pub fn remainder(&self) -> T { + // remaining bytes may not fit in `size_of::()`. We complement + // them to fit by allocating T and writing to it byte by byte + let mut remainder = T::zero().to_ne_bytes(); + + let remainder = match (self.remainder_bytes.is_empty(), self.bit_offset == 0) { + (true, _) => remainder, + (false, true) => { + // all remaining bytes + self.remainder_bytes + .iter() + .take(size_of::()) + .enumerate() + .for_each(|(i, val)| remainder[i] = *val); + + remainder + }, + (false, false) => { + // all remaining bytes + copy_with_merge::(&mut remainder, self.remainder_bytes, self.bit_offset); + remainder + }, + }; + T::from_ne_bytes(remainder) + } + + /// Returns the remainder bits in [`BitChunks::remainder`]. + pub fn remainder_len(&self) -> usize { + self.len - (size_of::() * ((self.len / 8) / size_of::()) * 8) + } +} + +impl Iterator for BitChunks<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + + let current = self.current; + let combined = if self.bit_offset == 0 { + // fast case where there is no offset. In this case, there is bit-alignment + // at byte boundary and thus the bytes correspond exactly. + if self.remaining >= 2 { + self.load_next(); + } + current + } else { + let next = if self.remaining >= 2 { + // case where `next` is complete and thus we can take it all + self.load_next(); + self.current + } else { + // case where the `next` is incomplete and thus we take the remaining + self.last_chunk + }; + merge_reversed(current, next, self.bit_offset) + }; + + self.remaining -= 1; + Some(combined) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // it contains always one more than the chunk_iterator, which is the last + // one where the remainder is merged into current. + (self.remaining, Some(self.remaining)) + } +} + +impl BitChunkIterExact for BitChunks<'_, T> { + #[inline] + fn remainder(&self) -> T { + self.remainder() + } + + #[inline] + fn remainder_len(&self) -> usize { + self.remainder_len() + } +} + +impl ExactSizeIterator for BitChunks<'_, T> { + #[inline] + fn len(&self) -> usize { + self.chunk_iterator.len() + } +} + +unsafe impl TrustedLen for BitChunks<'_, T> {} diff --git a/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs b/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs new file mode 100644 index 000000000000..04a9f8b661a7 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/chunks_exact_mut.rs @@ -0,0 +1,63 @@ +use super::BitChunk; + +/// An iterator over mutable slices of bytes of exact size. +/// +/// # Safety +/// The slices returned by this iterator are guaranteed to have length equal to +/// `size_of::()`. +#[derive(Debug)] +pub struct BitChunksExactMut<'a, T: BitChunk> { + chunks: std::slice::ChunksExactMut<'a, u8>, + remainder: &'a mut [u8], + remainder_len: usize, + marker: std::marker::PhantomData, +} + +impl<'a, T: BitChunk> BitChunksExactMut<'a, T> { + /// Returns a new [`BitChunksExactMut`] + #[inline] + pub fn new(bitmap: &'a mut [u8], length: usize) -> Self { + assert!(length <= bitmap.len() * 8); + let size_of = size_of::(); + + let bitmap = &mut bitmap[..length.saturating_add(7) / 8]; + + let split = (length / 8 / size_of) * size_of; + let (chunks, remainder) = bitmap.split_at_mut(split); + let remainder_len = length - chunks.len() * 8; + + let chunks = chunks.chunks_exact_mut(size_of); + Self { + chunks, + remainder, + remainder_len, + marker: std::marker::PhantomData, + } + } + + /// The remainder slice + #[inline] + pub fn remainder(&mut self) -> &mut [u8] { + self.remainder + } + + /// The length of the remainder slice in bits. + #[inline] + pub fn remainder_len(&mut self) -> usize { + self.remainder_len + } +} + +impl<'a, T: BitChunk> Iterator for BitChunksExactMut<'a, T> { + type Item = &'a mut [u8]; + + #[inline] + fn next(&mut self) -> Option { + self.chunks.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.chunks.size_hint() + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/fmt.rs b/crates/polars-arrow/src/bitmap/utils/fmt.rs new file mode 100644 index 000000000000..3495820a6d0c --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/fmt.rs @@ -0,0 +1,72 @@ +use std::fmt::Write; + +use super::is_set; + +/// Formats `bytes` taking into account an offset and length of the form +pub fn fmt( + bytes: &[u8], + offset: usize, + length: usize, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + assert!(offset < 8); + + write!(f, "Bitmap {{ len: {length}, offset: {offset}, bytes: [")?; + let mut remaining = length; + if remaining == 0 { + f.write_str("] }")?; + return Ok(()); + } + + let first = bytes[0]; + let bytes = &bytes[1..]; + let empty_before = 8usize.saturating_sub(remaining + offset); + f.write_str("0b")?; + for _ in 0..empty_before { + f.write_char('_')?; + } + let until = std::cmp::min(8, offset + remaining); + for i in offset..until { + if is_set(first, offset + until - 1 - i) { + f.write_char('1')?; + } else { + f.write_char('0')?; + } + } + for _ in 0..offset { + f.write_char('_')?; + } + remaining -= until - offset; + + if remaining == 0 { + f.write_str("] }")?; + return Ok(()); + } + + let number_of_bytes = remaining / 8; + for byte in &bytes[..number_of_bytes] { + f.write_str(", ")?; + f.write_fmt(format_args!("{byte:#010b}"))?; + } + remaining -= number_of_bytes * 8; + if remaining == 0 { + f.write_str("] }")?; + return Ok(()); + } + + let last = bytes[std::cmp::min((length + offset).div_ceil(8), bytes.len() - 1)]; + let remaining = (length + offset) % 8; + f.write_str(", ")?; + f.write_str("0b")?; + for _ in 0..(8 - remaining) { + f.write_char('_')?; + } + for i in 0..remaining { + if is_set(last, remaining - 1 - i) { + f.write_char('1')?; + } else { + f.write_char('0')?; + } + } + f.write_str("] }") +} diff --git a/crates/polars-arrow/src/bitmap/utils/iterator.rs b/crates/polars-arrow/src/bitmap/utils/iterator.rs new file mode 100644 index 000000000000..243372599687 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/iterator.rs @@ -0,0 +1,386 @@ +use polars_utils::slice::load_padded_le_u64; + +use super::get_bit_unchecked; +use crate::bitmap::MutableBitmap; +use crate::trusted_len::TrustedLen; + +/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit), +/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`. +#[derive(Debug, Clone)] +pub struct BitmapIter<'a> { + bytes: &'a [u8], + word: u64, + word_len: usize, + rest_len: usize, +} + +impl<'a> BitmapIter<'a> { + /// Creates a new [`BitmapIter`]. + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + if len == 0 { + return Self { + bytes, + word: 0, + word_len: 0, + rest_len: 0, + }; + } + + assert!(bytes.len() * 8 >= offset + len); + let first_byte_idx = offset / 8; + let bytes = &bytes[first_byte_idx..]; + let offset = offset % 8; + + // Make sure during our hot loop all our loads are full 8-byte loads + // by loading the remainder now if it exists. + let word = load_padded_le_u64(bytes) >> offset; + let mod8 = bytes.len() % 8; + let first_word_bytes = if mod8 > 0 { mod8 } else { 8 }; + let bytes = &bytes[first_word_bytes..]; + + let word_len = (first_word_bytes * 8 - offset).min(len); + let rest_len = len - word_len; + Self { + bytes, + word, + word_len, + rest_len, + } + } + + /// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator. + /// + /// This performs the same operation as `(&mut iter).take_while(|b| b).count()`. + /// + /// This is a lot more efficient than consecutively polling the iterator and should therefore + /// be preferred, if the use-case allows for it. + pub fn take_leading_ones(&mut self) -> usize { + let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize); + self.word_len -= word_ones; + self.word = self.word.wrapping_shr(word_ones as u32); + + if self.word_len != 0 { + return word_ones; + } + + let mut num_leading_ones = word_ones; + + while self.rest_len != 0 { + self.word_len = usize::min(self.rest_len, 64); + self.rest_len -= self.word_len; + + unsafe { + let chunk = self.bytes.get_unchecked(..8).try_into().unwrap(); + self.word = u64::from_le_bytes(chunk); + self.bytes = self.bytes.get_unchecked(8..); + } + + let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize); + self.word_len -= word_ones; + self.word = self.word.wrapping_shr(word_ones as u32); + num_leading_ones += word_ones; + + if self.word_len != 0 { + return num_leading_ones; + } + } + + num_leading_ones + } + + /// Consume and returns the numbers of `0` / `false` values that the start of the iterator. + /// + /// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`. + /// + /// This is a lot more efficient than consecutively polling the iterator and should therefore + /// be preferred, if the use-case allows for it. + pub fn take_leading_zeros(&mut self) -> usize { + let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize); + self.word_len -= word_zeros; + self.word = self.word.wrapping_shr(word_zeros as u32); + + if self.word_len != 0 { + return word_zeros; + } + + let mut num_leading_zeros = word_zeros; + + while self.rest_len != 0 { + self.word_len = usize::min(self.rest_len, 64); + self.rest_len -= self.word_len; + unsafe { + let chunk = self.bytes.get_unchecked(..8).try_into().unwrap(); + self.word = u64::from_le_bytes(chunk); + self.bytes = self.bytes.get_unchecked(8..); + } + + let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize); + self.word_len -= word_zeros; + self.word = self.word.wrapping_shr(word_zeros as u32); + num_leading_zeros += word_zeros; + + if self.word_len != 0 { + return num_leading_zeros; + } + } + + num_leading_zeros + } + + /// Returns the number of remaining elements in the iterator + #[inline] + pub fn num_remaining(&self) -> usize { + self.word_len + self.rest_len + } + + /// Collect at most `n` elements from this iterator into `bitmap` + pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) { + fn collect_word( + word: &mut u64, + word_len: &mut usize, + bitmap: &mut MutableBitmap, + n: &mut usize, + ) { + while *n > 0 && *word_len > 0 { + { + let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32); + let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones); + *word = word.wrapping_shr(shift); + *word_len -= shift as usize; + *n -= shift as usize; + + bitmap.extend_constant(shift as usize, true); + } + + { + let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32); + let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros); + *word = word.wrapping_shr(shift); + *word_len -= shift as usize; + *n -= shift as usize; + + bitmap.extend_constant(shift as usize, false); + } + } + } + + let mut n = usize::min(n, self.num_remaining()); + bitmap.reserve(n); + + collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n); + + if n == 0 { + return; + } + + let num_words = n / 64; + + if num_words > 0 { + assert!(self.bytes.len() >= num_words * size_of::()); + + bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize); + + self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) }; + self.rest_len -= num_words * u64::BITS as usize; + n -= num_words * u64::BITS as usize; + } + + if n == 0 { + return; + } + + assert!(self.bytes.len() >= size_of::()); + + self.word_len = usize::min(self.rest_len, 64); + self.rest_len -= self.word_len; + unsafe { + let chunk = self.bytes.get_unchecked(..8).try_into().unwrap(); + self.word = u64::from_le_bytes(chunk); + self.bytes = self.bytes.get_unchecked(8..); + } + + collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n); + + debug_assert!(self.num_remaining() == 0 || n == 0); + } +} + +impl Iterator for BitmapIter<'_> { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.word_len == 0 { + if self.rest_len == 0 { + return None; + } + + self.word_len = self.rest_len.min(64); + self.rest_len -= self.word_len; + + unsafe { + let chunk = self.bytes.get_unchecked(..8).try_into().unwrap(); + self.word = u64::from_le_bytes(chunk); + self.bytes = self.bytes.get_unchecked(8..); + } + } + + let ret = self.word & 1 != 0; + self.word >>= 1; + self.word_len -= 1; + Some(ret) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let num_remaining = self.num_remaining(); + (num_remaining, Some(num_remaining)) + } +} + +impl DoubleEndedIterator for BitmapIter<'_> { + #[inline] + fn next_back(&mut self) -> Option { + if self.rest_len > 0 { + self.rest_len -= 1; + Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) }) + } else if self.word_len > 0 { + self.word_len -= 1; + Some(self.word & (1 << self.word_len) != 0) + } else { + None + } + } +} + +unsafe impl TrustedLen for BitmapIter<'_> {} +impl ExactSizeIterator for BitmapIter<'_> {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_collect_into_17579() { + let mut bitmap = MutableBitmap::with_capacity(64); + BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128) + .collect_n_into(&mut bitmap, 129); + + let bitmap = bitmap.freeze(); + + assert_eq!(bitmap.set_bits(), 4); + } + + #[test] + #[ignore = "Fuzz test. Too slow"] + fn test_fuzz_collect_into() { + for _ in 0..10_000 { + let mut set_bits = 0; + let mut unset_bits = 0; + + let mut length = 0; + let mut pattern = Vec::new(); + for _ in 0..rand::random::() % 1024 { + let bs = rand::random::() % 4; + + let word = match bs { + 0 => u64::MIN, + 1 => u64::MAX, + 2 | 3 => rand::random(), + _ => unreachable!(), + }; + + pattern.extend_from_slice(&word.to_le_bytes()); + set_bits += word.count_ones(); + unset_bits += word.count_zeros(); + length += 64; + } + + for _ in 0..rand::random::() % 7 { + let b = rand::random::(); + pattern.push(b); + set_bits += b.count_ones(); + unset_bits += b.count_zeros(); + length += 8; + } + + let last_length = rand::random::() % 8; + if last_length != 0 { + let b = rand::random::(); + pattern.push(b); + let ones = (b & ((1 << last_length) - 1)).count_ones(); + set_bits += ones; + unset_bits += last_length as u32 - ones; + length += last_length; + } + + let mut iter = BitmapIter::new(&pattern, 0, length); + let mut bitmap = MutableBitmap::with_capacity(length); + + while iter.num_remaining() > 0 { + let len_before = bitmap.len(); + let n = rand::random::() % iter.num_remaining(); + iter.collect_n_into(&mut bitmap, n); + + // Ensure we are booking the progress we expect + assert_eq!(bitmap.len(), len_before + n); + } + + let bitmap = bitmap.freeze(); + + assert_eq!(bitmap.set_bits(), set_bits as usize); + assert_eq!(bitmap.unset_bits(), unset_bits as usize); + } + } + + #[test] + #[ignore = "Fuzz test. Too slow"] + fn test_fuzz_leading_ops() { + for _ in 0..10_000 { + let mut length = 0; + let mut pattern = Vec::new(); + for _ in 0..rand::random::() % 1024 { + let bs = rand::random::() % 4; + + let word = match bs { + 0 => u64::MIN, + 1 => u64::MAX, + 2 | 3 => rand::random(), + _ => unreachable!(), + }; + + pattern.extend_from_slice(&word.to_le_bytes()); + length += 64; + } + + for _ in 0..rand::random::() % 7 { + pattern.push(rand::random::()); + length += 8; + } + + let last_length = rand::random::() % 8; + if last_length != 0 { + pattern.push(rand::random::()); + length += last_length; + } + + let mut iter = BitmapIter::new(&pattern, 0, length); + + let mut prev_remaining = iter.num_remaining(); + while iter.num_remaining() != 0 { + let num_ones = iter.clone().take_leading_ones(); + assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count()); + + let num_zeros = iter.clone().take_leading_zeros(); + assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count()); + + // Ensure that we are making progress + assert!(iter.num_remaining() < prev_remaining); + prev_remaining = iter.num_remaining(); + } + + assert_eq!(iter.take_leading_zeros(), 0); + assert_eq!(iter.take_leading_ones(), 0); + } + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/mod.rs b/crates/polars-arrow/src/bitmap/utils/mod.rs new file mode 100644 index 000000000000..9b94ebcaa417 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/mod.rs @@ -0,0 +1,298 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! General utilities for bitmaps representing items where LSB is the first item. +mod chunk_iterator; +mod chunks_exact_mut; +mod fmt; +mod iterator; +mod slice_iterator; +mod zip_validity; + +pub(crate) use chunk_iterator::merge_reversed; +pub use chunk_iterator::{BitChunk, BitChunkIterExact, BitChunks, BitChunksExact}; +pub use chunks_exact_mut::BitChunksExactMut; +pub use fmt::fmt; +pub use iterator::BitmapIter; +use polars_utils::slice::load_padded_le_u64; +pub use slice_iterator::SlicesIterator; +pub use zip_validity::{ZipValidity, ZipValidityIter}; + +use crate::bitmap::aligned::AlignedBitmapSlice; + +/// Returns whether bit at position `i` in `byte` is set or not +#[inline] +pub fn is_set(byte: u8, i: usize) -> bool { + debug_assert!(i < 8); + byte & (1 << i) != 0 +} + +/// Sets bit at position `i` in `byte`. +#[inline(always)] +pub fn set_bit_in_byte(byte: u8, i: usize, value: bool) -> u8 { + debug_assert!(i < 8); + let mask = !(1 << i); + let insert = (value as u8) << i; + (byte & mask) | insert +} + +/// Returns whether bit at position `i` in `bytes` is set or not. +/// +/// # Safety +/// `i >= bytes.len() * 8` results in undefined behavior. +#[inline(always)] +pub unsafe fn get_bit_unchecked(bytes: &[u8], i: usize) -> bool { + let byte = *bytes.get_unchecked(i / 8); + let bit = (byte >> (i % 8)) & 1; + bit != 0 +} + +/// Sets bit at position `i` in `bytes` without doing bound checks. +/// # Safety +/// `i >= bytes.len() * 8` results in undefined behavior. +#[inline(always)] +pub unsafe fn set_bit_unchecked(bytes: &mut [u8], i: usize, value: bool) { + let byte = bytes.get_unchecked_mut(i / 8); + *byte = set_bit_in_byte(*byte, i % 8, value); +} + +/// Returns the number of bytes required to hold `bits` bits. +#[inline] +pub fn bytes_for(bits: usize) -> usize { + bits.saturating_add(7) / 8 +} + +/// Returns the number of zero bits in the slice offsetted by `offset` and a length of `length`. +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()``. +pub fn count_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + // Fast-path: fits in a single u64 load. + let first_byte_idx = offset / 8; + let offset_in_byte = offset % 8; + if offset_in_byte + len <= 64 { + let mut word = load_padded_le_u64(&slice[first_byte_idx..]); + word >>= offset_in_byte; + word <<= 64 - len; + return len - word.count_ones() as usize; + } + + let aligned = AlignedBitmapSlice::::new(slice, offset, len); + let ones_in_prefix = aligned.prefix().count_ones() as usize; + let ones_in_bulk: usize = aligned.bulk_iter().map(|w| w.count_ones() as usize).sum(); + let ones_in_suffix = aligned.suffix().count_ones() as usize; + len - ones_in_prefix - ones_in_bulk - ones_in_suffix +} + +/// Returns the number of zero bits before seeing a one bit in the slice offsetted by `offset` and +/// a length of `length`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()``. +pub fn leading_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::::new(slice, offset, len); + let leading_zeros_in_prefix = + (aligned.prefix().trailing_zeros() as usize).min(aligned.prefix_bitlen()); + if leading_zeros_in_prefix < aligned.prefix_bitlen() { + return leading_zeros_in_prefix; + } + if let Some(full_zero_bulk_words) = aligned.bulk_iter().position(|w| w != 0) { + return aligned.prefix_bitlen() + + full_zero_bulk_words * 64 + + aligned.bulk()[full_zero_bulk_words].trailing_zeros() as usize; + } + + aligned.prefix_bitlen() + + aligned.bulk_bitlen() + + (aligned.suffix().trailing_zeros() as usize).min(aligned.suffix_bitlen()) +} + +/// Returns the number of one bits before seeing a zero bit in the slice offsetted by `offset` and +/// a length of `length`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()``. +pub fn leading_ones(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::::new(slice, offset, len); + let leading_ones_in_prefix = aligned.prefix().trailing_ones() as usize; + if leading_ones_in_prefix < aligned.prefix_bitlen() { + return leading_ones_in_prefix; + } + if let Some(full_one_bulk_words) = aligned.bulk_iter().position(|w| w != u64::MAX) { + return aligned.prefix_bitlen() + + full_one_bulk_words * 64 + + aligned.bulk()[full_one_bulk_words].trailing_ones() as usize; + } + + aligned.prefix_bitlen() + aligned.bulk_bitlen() + aligned.suffix().trailing_ones() as usize +} + +/// Returns the number of zero bits before seeing a one bit in the slice offsetted by `offset` and +/// a length of `length`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()``. +pub fn trailing_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::::new(slice, offset, len); + let trailing_zeros_in_suffix = ((aligned.suffix() << ((64 - aligned.suffix_bitlen()) % 64)) + .leading_zeros() as usize) + .min(aligned.suffix_bitlen()); + if trailing_zeros_in_suffix < aligned.suffix_bitlen() { + return trailing_zeros_in_suffix; + } + if let Some(full_zero_bulk_words) = aligned.bulk_iter().rev().position(|w| w != 0) { + return aligned.suffix_bitlen() + + full_zero_bulk_words * 64 + + aligned.bulk()[aligned.bulk().len() - full_zero_bulk_words - 1].leading_zeros() + as usize; + } + + let trailing_zeros_in_prefix = ((aligned.prefix() << ((64 - aligned.prefix_bitlen()) % 64)) + .leading_zeros() as usize) + .min(aligned.prefix_bitlen()); + aligned.suffix_bitlen() + aligned.bulk_bitlen() + trailing_zeros_in_prefix +} + +/// Returns the number of one bits before seeing a zero bit in the slice offsetted by `offset` and +/// a length of `length`. +/// +/// # Panics +/// This function panics iff `offset + len > 8 * slice.len()``. +pub fn trailing_ones(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + } + + assert!(8 * slice.len() >= offset + len); + + let aligned = AlignedBitmapSlice::::new(slice, offset, len); + let trailing_ones_in_suffix = + (aligned.suffix() << ((64 - aligned.suffix_bitlen()) % 64)).leading_ones() as usize; + if trailing_ones_in_suffix < aligned.suffix_bitlen() { + return trailing_ones_in_suffix; + } + if let Some(full_one_bulk_words) = aligned.bulk_iter().rev().position(|w| w != u64::MAX) { + return aligned.suffix_bitlen() + + full_one_bulk_words * 64 + + aligned.bulk()[aligned.bulk().len() - full_one_bulk_words - 1].leading_ones() + as usize; + } + + let trailing_ones_in_prefix = + (aligned.prefix() << ((64 - aligned.prefix_bitlen()) % 64)).leading_ones() as usize; + aligned.suffix_bitlen() + aligned.bulk_bitlen() + trailing_ones_in_prefix +} + +#[cfg(test)] +mod tests { + use rand::Rng; + + use super::*; + use crate::bitmap::Bitmap; + + #[test] + fn leading_trailing() { + macro_rules! testcase { + ($slice:expr, $offset:expr, $length:expr => lz=$lz:expr,lo=$lo:expr,tz=$tz:expr,to=$to:expr) => { + assert_eq!( + leading_zeros($slice, $offset, $length), + $lz, + "leading_zeros" + ); + assert_eq!(leading_ones($slice, $offset, $length), $lo, "leading_ones"); + assert_eq!( + trailing_zeros($slice, $offset, $length), + $tz, + "trailing_zeros" + ); + assert_eq!( + trailing_ones($slice, $offset, $length), + $to, + "trailing_ones" + ); + }; + } + + testcase!(&[], 0, 0 => lz=0,lo=0,tz=0,to=0); + testcase!(&[0], 0, 1 => lz=1,lo=0,tz=1,to=0); + testcase!(&[1], 0, 1 => lz=0,lo=1,tz=0,to=1); + + testcase!(&[0b010], 0, 3 => lz=1,lo=0,tz=1,to=0); + testcase!(&[0b101], 0, 3 => lz=0,lo=1,tz=0,to=1); + testcase!(&[0b100], 0, 3 => lz=2,lo=0,tz=0,to=1); + testcase!(&[0b110], 0, 3 => lz=1,lo=0,tz=0,to=2); + testcase!(&[0b001], 0, 3 => lz=0,lo=1,tz=2,to=0); + testcase!(&[0b011], 0, 3 => lz=0,lo=2,tz=1,to=0); + + testcase!(&[0b010], 1, 2 => lz=0,lo=1,tz=1,to=0); + testcase!(&[0b101], 1, 2 => lz=1,lo=0,tz=0,to=1); + testcase!(&[0b100], 1, 2 => lz=1,lo=0,tz=0,to=1); + testcase!(&[0b110], 1, 2 => lz=0,lo=2,tz=0,to=2); + testcase!(&[0b001], 1, 2 => lz=2,lo=0,tz=2,to=0); + testcase!(&[0b011], 1, 2 => lz=0,lo=1,tz=1,to=0); + } + + #[ignore = "Fuzz test. Too slow"] + #[test] + fn leading_trailing_fuzz() { + let mut rng = rand::thread_rng(); + + const SIZE: usize = 1000; + const REPEATS: usize = 10_000; + + let mut v = Vec::::with_capacity(SIZE); + + for _ in 0..REPEATS { + v.clear(); + let offset = rng.gen_range(0..SIZE); + let length = rng.gen_range(0..SIZE - offset); + let extra_padding = rng.gen_range(0..64); + + let mut num_remaining = usize::min(SIZE, offset + length + extra_padding); + while num_remaining > 0 { + let chunk_size = rng.gen_range(1..=num_remaining); + v.extend( + rng.clone() + .sample_iter(rand::distributions::Slice::new(&[false, true]).unwrap()) + .take(chunk_size), + ); + num_remaining -= chunk_size; + } + + let v_slice = &v[offset..offset + length]; + let lz = v_slice.iter().take_while(|&v| !*v).count(); + let lo = v_slice.iter().take_while(|&v| *v).count(); + let tz = v_slice.iter().rev().take_while(|&v| !*v).count(); + let to = v_slice.iter().rev().take_while(|&v| *v).count(); + + let bm = Bitmap::from_iter(v.iter().copied()); + let (slice, _, _) = bm.as_slice(); + + assert_eq!(leading_zeros(slice, offset, length), lz); + assert_eq!(leading_ones(slice, offset, length), lo); + assert_eq!(trailing_zeros(slice, offset, length), tz); + assert_eq!(trailing_ones(slice, offset, length), to); + } + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs new file mode 100644 index 000000000000..9f43a3dfe89a --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs @@ -0,0 +1,146 @@ +use crate::bitmap::Bitmap; + +/// Internal state of [`SlicesIterator`] +#[derive(Debug, Clone, PartialEq)] +enum State { + // normal iteration + Nominal, + // nothing more to iterate. + Finished, +} + +/// Iterator over a bitmap that returns slices of set regions. +/// +/// This is the most efficient method to extract slices of values from arrays +/// with a validity bitmap. +/// For example, the bitmap `00101111` returns `[(0,4), (6,1)]` +#[derive(Debug, Clone)] +pub struct SlicesIterator<'a> { + values: std::slice::Iter<'a, u8>, + count: usize, + mask: u8, + max_len: usize, + current_byte: &'a u8, + state: State, + len: usize, + start: usize, + on_region: bool, +} + +impl<'a> SlicesIterator<'a> { + /// Creates a new [`SlicesIterator`] + pub fn new(values: &'a Bitmap) -> Self { + let (buffer, offset, _) = values.as_slice(); + let mut iter = buffer.iter(); + + let (current_byte, state) = match iter.next() { + Some(b) => (b, State::Nominal), + None => (&0, State::Finished), + }; + + Self { + state, + count: values.len() - values.unset_bits(), + max_len: values.len(), + values: iter, + mask: 1u8.rotate_left(offset as u32), + current_byte, + len: 0, + start: 0, + on_region: false, + } + } + + #[inline] + fn finish(&mut self) -> Option<(usize, usize)> { + self.state = State::Finished; + if self.on_region { + Some((self.start, self.len)) + } else { + None + } + } + + #[inline] + fn current_len(&self) -> usize { + self.start + self.len + } + + /// Returns the total number of slots. + /// It corresponds to the sum of all lengths of all slices. + #[inline] + pub fn slots(&self) -> usize { + self.count + } +} + +impl Iterator for SlicesIterator<'_> { + type Item = (usize, usize); + + #[inline] + fn next(&mut self) -> Option { + loop { + if self.state == State::Finished { + return None; + } + if self.current_len() == self.max_len { + return self.finish(); + } + + if self.mask == 1 { + // at the beginning of a byte => try to skip it all together + match (self.on_region, self.current_byte) { + (true, &255u8) => { + self.len = std::cmp::min(self.max_len - self.start, self.len + 8); + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + continue; + }, + (false, &0) => { + self.len = std::cmp::min(self.max_len - self.start, self.len + 8); + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + continue; + }, + _ => (), // we need to run over all bits of this byte + } + }; + + let value = (self.current_byte & self.mask) != 0; + self.mask = self.mask.rotate_left(1); + + match (self.on_region, value) { + (true, true) => self.len += 1, + (false, false) => self.len += 1, + (true, false) => { + self.on_region = false; + let result = (self.start, self.len); + self.start += self.len; + self.len = 1; + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + } + return Some(result); + }, + (false, true) => { + self.start += self.len; + self.len = 1; + self.on_region = true; + }, + } + + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + match self.values.next() { + Some(v) => self.current_byte = v, + None => return self.finish(), + }; + } + } + } +} diff --git a/crates/polars-arrow/src/bitmap/utils/zip_validity.rs b/crates/polars-arrow/src/bitmap/utils/zip_validity.rs new file mode 100644 index 000000000000..9f74ac33c674 --- /dev/null +++ b/crates/polars-arrow/src/bitmap/utils/zip_validity.rs @@ -0,0 +1,217 @@ +use crate::bitmap::Bitmap; +use crate::bitmap::utils::BitmapIter; +use crate::trusted_len::TrustedLen; + +/// An [`Iterator`] over validity and values. +#[derive(Debug, Clone)] +pub struct ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + values: I, + validity: V, +} + +impl ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + /// Creates a new [`ZipValidityIter`]. + /// # Panics + /// This function panics if the size_hints of the iterators are different + pub fn new(values: I, validity: V) -> Self { + assert_eq!(values.size_hint(), validity.size_hint()); + Self { values, validity } + } +} + +impl Iterator for ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + let value = self.values.next(); + let is_valid = self.validity.next(); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.values.size_hint() + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let value = self.values.nth(n); + let is_valid = self.validity.nth(n); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } +} + +impl DoubleEndedIterator for ZipValidityIter +where + I: DoubleEndedIterator, + V: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + let value = self.values.next_back(); + let is_valid = self.validity.next_back(); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } +} + +unsafe impl TrustedLen for ZipValidityIter +where + I: TrustedLen, + V: TrustedLen, +{ +} + +impl ExactSizeIterator for ZipValidityIter +where + I: ExactSizeIterator, + V: ExactSizeIterator, +{ +} + +/// An [`Iterator`] over [`Option`] +/// This enum can be used in two distinct ways: +/// * as an iterator, via `Iterator::next` +/// * as an enum of two iterators, via `match self` +/// +/// The latter allows specializalizing to when there are no nulls +#[derive(Debug, Clone)] +pub enum ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// There are no null values + Required(I), + /// There are null values + Optional(ZipValidityIter), +} + +impl ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// Returns a new [`ZipValidity`] + pub fn new(values: I, validity: Option) -> Self { + match validity { + Some(validity) => Self::Optional(ZipValidityIter::new(values, validity)), + _ => Self::Required(values), + } + } +} + +impl<'a, T, I> ZipValidity> +where + I: Iterator, +{ + /// Returns a new [`ZipValidity`] and drops the `validity` if all values + /// are valid. + pub fn new_with_validity(values: I, validity: Option<&'a Bitmap>) -> Self { + // only if the validity has nulls we take the optional branch. + match validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.iter())) { + Some(validity) => Self::Optional(ZipValidityIter::new(values, validity)), + _ => Self::Required(values), + } + } +} + +impl Iterator for ZipValidity +where + I: Iterator, + V: Iterator, +{ + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + match self { + Self::Required(values) => values.next().map(Some), + Self::Optional(zipped) => zipped.next(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + Self::Required(values) => values.size_hint(), + Self::Optional(zipped) => zipped.size_hint(), + } + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + match self { + Self::Required(values) => values.nth(n).map(Some), + Self::Optional(zipped) => zipped.nth(n), + } + } +} + +impl DoubleEndedIterator for ZipValidity +where + I: DoubleEndedIterator, + V: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + match self { + Self::Required(values) => values.next_back().map(Some), + Self::Optional(zipped) => zipped.next_back(), + } + } +} + +impl ExactSizeIterator for ZipValidity +where + I: ExactSizeIterator, + V: ExactSizeIterator, +{ +} + +unsafe impl TrustedLen for ZipValidity +where + I: TrustedLen, + V: TrustedLen, +{ +} + +impl ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// Unwrap into an iterator that has no null values. + pub fn unwrap_required(self) -> I { + match self { + ZipValidity::Required(i) => i, + _ => panic!("Could not 'unwrap_required'. 'ZipValidity' iterator has nulls."), + } + } + + /// Unwrap into an iterator that has null values. + pub fn unwrap_optional(self) -> ZipValidityIter { + match self { + ZipValidity::Optional(i) => i, + _ => panic!("Could not 'unwrap_optional'. 'ZipValidity' iterator has no nulls."), + } + } +} diff --git a/crates/polars-arrow/src/buffer/immutable.rs b/crates/polars-arrow/src/buffer/immutable.rs new file mode 100644 index 000000000000..b23105169ea7 --- /dev/null +++ b/crates/polars-arrow/src/buffer/immutable.rs @@ -0,0 +1,362 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::ops::Deref; + +use bytemuck::Zeroable; +use either::Either; + +use super::IntoIter; +use crate::array::{ArrayAccessor, Splitable}; +use crate::storage::SharedStorage; + +/// [`Buffer`] is a contiguous memory region that can be shared across +/// thread boundaries. +/// +/// The easiest way to think about [`Buffer`] is being equivalent to +/// a `Arc>`, with the following differences: +/// * slicing and cloning is `O(1)`. +/// * it supports external allocated memory +/// +/// The easiest way to create one is to use its implementation of `From>`. +/// +/// # Examples +/// ``` +/// use polars_arrow::buffer::Buffer; +/// +/// let mut buffer: Buffer = vec![1, 2, 3].into(); +/// assert_eq!(buffer.as_ref(), [1, 2, 3].as_ref()); +/// +/// // it supports copy-on-write semantics (i.e. back to a `Vec`) +/// let vec: Vec = buffer.into_mut().right().unwrap(); +/// assert_eq!(vec, vec![1, 2, 3]); +/// +/// // cloning and slicing is `O(1)` (data is shared) +/// let mut buffer: Buffer = vec![1, 2, 3].into(); +/// let mut sliced = buffer.clone(); +/// sliced.slice(1, 1); +/// assert_eq!(sliced.as_ref(), [2].as_ref()); +/// // but cloning forbids getting mut since `slice` and `buffer` now share data +/// assert_eq!(buffer.get_mut_slice(), None); +/// ``` +#[derive(Clone)] +pub struct Buffer { + /// The internal byte buffer. + storage: SharedStorage, + + /// A pointer into the buffer where our data starts. + ptr: *const T, + + // The length of the buffer. + length: usize, +} + +unsafe impl Sync for Buffer {} +unsafe impl Send for Buffer {} + +impl PartialEq for Buffer { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl Eq for Buffer {} + +impl std::hash::Hash for Buffer { + #[inline] + fn hash(&self, state: &mut H) { + self.as_slice().hash(state); + } +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&**self, f) + } +} + +impl Default for Buffer { + #[inline] + fn default() -> Self { + Vec::new().into() + } +} + +impl Buffer { + /// Creates an empty [`Buffer`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Auxiliary method to create a new Buffer + pub fn from_storage(storage: SharedStorage) -> Self { + let ptr = storage.as_ptr(); + let length = storage.len(); + Buffer { + storage, + ptr, + length, + } + } + + pub fn from_static(data: &'static [T]) -> Self { + Self::from_storage(SharedStorage::from_static(data)) + } + + /// Returns the number of bytes in the buffer + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether the buffer is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Returns whether underlying data is sliced. + /// If sliced the [`Buffer`] is backed by + /// more data than the length of `Self`. + pub fn is_sliced(&self) -> bool { + self.storage.len() != self.length + } + + /// Expands this slice to the maximum allowed by the underlying storage. + /// Only expands towards the end, the offset isn't changed. That is, element + /// i before and after this operation refer to the same element. + pub fn expand_end_to_storage(self) -> Self { + unsafe { + let offset = self.ptr.offset_from(self.storage.as_ptr()) as usize; + Self { + ptr: self.ptr, + length: self.storage.len() - offset, + storage: self.storage, + } + } + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[T] { + // SAFETY: + // invariant of this struct `offset + length <= data.len()` + debug_assert!(self.offset() + self.length <= self.storage.len()); + unsafe { std::slice::from_raw_parts(self.ptr, self.length) } + } + + /// Returns the byte slice stored in this buffer + /// + /// # Safety + /// `index` must be smaller than `len` + #[inline] + pub(super) unsafe fn get_unchecked(&self, index: usize) -> &T { + // SAFETY: + // invariant of this function + debug_assert!(index < self.length); + unsafe { &*self.ptr.add(index) } + } + + /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// # Panics + /// Panics iff `offset + length` is larger than `len`. + #[inline] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + // SAFETY: we just checked bounds + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Slices this buffer starting at `offset`. + /// # Panics + /// Panics iff `offset + length` is larger than `len`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + // SAFETY: we just checked bounds + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + debug_assert!(offset + length <= self.len()); + + self.slice_unchecked(offset, length); + self + } + + /// Slices this buffer starting at `offset`. + /// + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.ptr = self.ptr.add(offset); + self.length = length; + } + + /// Returns a pointer to the start of the storage underlying this buffer. + #[inline] + pub(crate) fn storage_ptr(&self) -> *const T { + self.storage.as_ptr() + } + + /// Returns the start offset of this buffer within the underlying storage. + #[inline] + pub fn offset(&self) -> usize { + unsafe { + let ret = self.ptr.offset_from(self.storage.as_ptr()) as usize; + debug_assert!(ret <= self.storage.len()); + ret + } + } + + /// # Safety + /// The caller must ensure that the buffer was properly initialized up to `len`. + #[inline] + pub unsafe fn set_len(&mut self, len: usize) { + self.length = len; + } + + /// Returns a mutable reference to its underlying [`Vec`], if possible. + /// + /// This operation returns [`Either::Right`] iff this [`Buffer`]: + /// * has no alive clones + /// * has not been imported from the C data interface (FFI) + #[inline] + pub fn into_mut(mut self) -> Either> { + // We lose information if the data is sliced. + if self.is_sliced() { + return Either::Left(self); + } + match self.storage.try_into_vec() { + Ok(v) => Either::Right(v), + Err(slf) => { + self.storage = slf; + Either::Left(self) + }, + } + } + + /// Returns a mutable reference to its slice, if possible. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has no alive clones + /// * has not been imported from the C data interface (FFI) + #[inline] + pub fn get_mut_slice(&mut self) -> Option<&mut [T]> { + let offset = self.offset(); + let slice = self.storage.try_as_mut_slice()?; + Some(unsafe { slice.get_unchecked_mut(offset..offset + self.length) }) + } + + /// Since this takes a shared reference to self, beware that others might + /// increment this after you've checked it's equal to 1. + pub fn storage_refcount(&self) -> u64 { + self.storage.refcount() + } +} + +impl Buffer { + pub fn make_mut(self) -> Vec { + match self.into_mut() { + Either::Right(v) => v, + Either::Left(same) => same.as_slice().to_vec(), + } + } +} + +impl Buffer { + pub fn zeroed(len: usize) -> Self { + vec![T::zeroed(); len].into() + } +} + +impl From> for Buffer { + #[inline] + fn from(v: Vec) -> Self { + Self::from_storage(SharedStorage::from_vec(v)) + } +} + +impl Deref for Buffer { + type Target = [T]; + + #[inline(always)] + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl AsRef<[T]> for Buffer { + #[inline(always)] + fn as_ref(&self) -> &[T] { + self.as_slice() + } +} + +impl FromIterator for Buffer { + #[inline] + fn from_iter>(iter: I) -> Self { + Vec::from_iter(iter).into() + } +} + +impl IntoIterator for Buffer { + type Item = T; + + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +unsafe impl<'a, T: 'a> ArrayAccessor<'a> for Buffer { + type Item = &'a T; + + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + unsafe { &*self.ptr.add(index) } + } + + fn len(&self) -> usize { + Buffer::len(self) + } +} + +impl Splitable for Buffer { + #[inline(always)] + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let storage = &self.storage; + + ( + Self { + storage: storage.clone(), + ptr: self.ptr, + length: offset, + }, + Self { + storage: storage.clone(), + ptr: self.ptr.wrapping_add(offset), + length: self.length - offset, + }, + ) + } +} diff --git a/crates/polars-arrow/src/buffer/iterator.rs b/crates/polars-arrow/src/buffer/iterator.rs new file mode 100644 index 000000000000..93511c480284 --- /dev/null +++ b/crates/polars-arrow/src/buffer/iterator.rs @@ -0,0 +1,68 @@ +use super::Buffer; +use crate::trusted_len::TrustedLen; + +/// This crates' equivalent of [`std::vec::IntoIter`] for [`Buffer`]. +#[derive(Debug, Clone)] +pub struct IntoIter { + values: Buffer, + index: usize, + end: usize, +} + +impl IntoIter { + /// Creates a new [`Buffer`] + #[inline] + pub fn new(values: Buffer) -> Self { + let end = values.len(); + Self { + values, + index: 0, + end, + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(*unsafe { self.values.get_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(*unsafe { self.values.get_unchecked(self.end) }) + } + } +} + +unsafe impl TrustedLen for IntoIter {} diff --git a/crates/polars-arrow/src/buffer/mod.rs b/crates/polars-arrow/src/buffer/mod.rs new file mode 100644 index 000000000000..386545482d09 --- /dev/null +++ b/crates/polars-arrow/src/buffer/mod.rs @@ -0,0 +1,7 @@ +//! Contains [`Buffer`], an immutable container for all Arrow physical types (e.g. i32, f64). + +mod immutable; +mod iterator; + +pub use immutable::Buffer; +pub(super) use iterator::IntoIter; diff --git a/crates/polars-arrow/src/compute/README.md b/crates/polars-arrow/src/compute/README.md new file mode 100644 index 000000000000..4662ed0944b8 --- /dev/null +++ b/crates/polars-arrow/src/compute/README.md @@ -0,0 +1,38 @@ +# Design + +This document outlines the design guide lines of this module. + +This module is composed by independent operations common in analytics. Below are some design of its +principles: + +- APIs MUST return an error when either: + - The arguments are incorrect + - The execution results in a predictable error (e.g. divide by zero) + +- APIs MAY error when an operation overflows (e.g. `i32 + i32`) + +- kernels MUST NOT have side-effects + +- kernels MUST NOT take ownership of any of its arguments (i.e. everything must be a reference). + +- APIs SHOULD error when an operation on variable sized containers can overflow the maximum size of + `usize`. + +- Kernels SHOULD use the arrays' logical type to decide whether kernels can be applied on an array. + For example, `Date32 + Date32` is meaningless and SHOULD NOT be implemented. + +- Kernels SHOULD be implemented via `clone`, `slice` or the `iterator` API provided by `Buffer`, + `Bitmap`, `Vec` or `MutableBitmap`. + +- Kernels MUST NOT use any API to read bits other than the ones provided by `Bitmap`. + +- Implementations SHOULD aim for auto-vectorization, which is usually accomplished via + `from_trusted_len_iter`. + +- Implementations MUST feature-gate any implementation that requires external dependencies + +- When a kernel accepts dynamically-typed arrays, it MUST expect them as `&dyn Array`. + +- When an API returns `&dyn Array`, it MUST return `Box`. The rational is that a `Box` is + mutable, while an `Arc` is not. As such, `Box` offers the most flexible API to consumers and the + compiler. Users can cast a `Box` into `Arc` via `.into()`. diff --git a/crates/polars-arrow/src/compute/aggregate/memory.rs b/crates/polars-arrow/src/compute/aggregate/memory.rs new file mode 100644 index 000000000000..1cf7512bbbef --- /dev/null +++ b/crates/polars-arrow/src/compute/aggregate/memory.rs @@ -0,0 +1,131 @@ +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; +pub use crate::types::PrimitiveType; +use crate::{match_integer_type, with_match_primitive_type_full}; +fn validity_size(validity: Option<&Bitmap>) -> usize { + validity.as_ref().map(|b| b.as_slice().0.len()).unwrap_or(0) +} + +macro_rules! dyn_binary { + ($array:expr, $ty:ty, $o:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + let offsets = array.offsets().buffer(); + + // in case of Binary/Utf8/List the offsets are sliced, + // not the values buffer + let values_start = offsets[0] as usize; + let values_end = offsets[offsets.len() - 1] as usize; + + values_end - values_start + + offsets.len() * size_of::<$o>() + + validity_size(array.validity()) + }}; +} + +fn binview_size(array: &BinaryViewArrayGeneric) -> usize { + // We choose the optimal usage as data can be shared across buffers. + // If we would sum all buffers we overestimate memory usage and trigger OOC when not needed. + array.total_bytes_len() +} + +/// Returns the total (heap) allocated size of the array in bytes. +/// # Implementation +/// This estimation is the sum of the size of its buffers, validity, including nested arrays. +/// Multiple arrays may share buffers and bitmaps. Therefore, the size of 2 arrays is not the +/// sum of the sizes computed from this function. In particular, [`StructArray`]'s size is an upper bound. +/// +/// When an array is sliced, its allocated size remains constant because the buffer unchanged. +/// However, this function will yield a smaller number. This is because this function returns +/// the visible size of the buffer, not its total capacity. +/// +/// FFI buffers are included in this estimation. +pub fn estimated_bytes_size(array: &dyn Array) -> usize { + use PhysicalType::*; + match array.dtype().to_physical_type() { + Null => 0, + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + array.values().as_slice().0.len() + validity_size(array.validity()) + }, + Primitive(PrimitiveType::DaysMs) => { + let array = array.as_any().downcast_ref::().unwrap(); + array.values().len() * size_of::() * 2 + validity_size(array.validity()) + }, + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + array.values().len() * size_of::<$T>() + validity_size(array.validity()) + }), + Binary => dyn_binary!(array, BinaryArray, i32), + FixedSizeBinary => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + array.values().len() + validity_size(array.validity()) + }, + LargeBinary => dyn_binary!(array, BinaryArray, i64), + Utf8 => dyn_binary!(array, Utf8Array, i32), + LargeUtf8 => dyn_binary!(array, Utf8Array, i64), + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + estimated_bytes_size(array.values().as_ref()) + + array.offsets().len_proxy() * size_of::() + + validity_size(array.validity()) + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + estimated_bytes_size(array.values().as_ref()) + validity_size(array.validity()) + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + estimated_bytes_size(array.values().as_ref()) + + array.offsets().len_proxy() * size_of::() + + validity_size(array.validity()) + }, + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + array + .values() + .iter() + .map(|x| x.as_ref()) + .map(estimated_bytes_size) + .sum::() + + validity_size(array.validity()) + }, + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + let types = array.types().len() * size_of::(); + let offsets = array + .offsets() + .as_ref() + .map(|x| x.len() * size_of::()) + .unwrap_or_default(); + let fields = array + .fields() + .iter() + .map(|x| x.as_ref()) + .map(estimated_bytes_size) + .sum::(); + types + offsets + fields + }, + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + estimated_bytes_size(array.keys()) + estimated_bytes_size(array.values().as_ref()) + }), + Utf8View => binview_size::(array.as_any().downcast_ref().unwrap()), + BinaryView => binview_size::<[u8]>(array.as_any().downcast_ref().unwrap()), + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let offsets = array.offsets().len_proxy() * size_of::(); + offsets + estimated_bytes_size(array.field().as_ref()) + validity_size(array.validity()) + }, + } +} diff --git a/crates/polars-arrow/src/compute/aggregate/mod.rs b/crates/polars-arrow/src/compute/aggregate/mod.rs new file mode 100644 index 000000000000..879e3a09c6ed --- /dev/null +++ b/crates/polars-arrow/src/compute/aggregate/mod.rs @@ -0,0 +1,2 @@ +mod memory; +pub use memory::*; diff --git a/crates/polars-arrow/src/compute/arity.rs b/crates/polars-arrow/src/compute/arity.rs new file mode 100644 index 000000000000..2670cfb4d031 --- /dev/null +++ b/crates/polars-arrow/src/compute/arity.rs @@ -0,0 +1,70 @@ +//! Defines kernels suitable to perform operations to primitive arrays. + +use super::utils::{check_same_len, combine_validities_and}; +use crate::array::PrimitiveArray; +use crate::datatypes::ArrowDataType; +use crate::types::NativeType; + +/// Applies an unary and infallible function to a [`PrimitiveArray`]. +/// +/// This is the /// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits +/// of a vectorized operation outweighs the cost of branching nulls and non-nulls. +/// +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the +/// corresponding type or this function may panic. +#[inline] +pub fn unary(array: &PrimitiveArray, op: F, dtype: ArrowDataType) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> O, +{ + let values = array.values().iter().map(|v| op(*v)).collect::>(); + + PrimitiveArray::::new(dtype, values.into(), array.validity().cloned()) +} + +/// Applies a binary operations to two primitive arrays. +/// +/// This is the fastest way to perform an operation on two primitive array when the benefits of a +/// vectorized operation outweighs the cost of branching nulls and non-nulls. +/// +/// # Errors +/// This function errors iff the arrays have a different length. +/// +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the +/// corresponding type. +/// The types of the arrays are not checked with this operation. The closure +/// "op" needs to handle the different types in the arrays. The datatype for the +/// resulting array has to be selected by the implementer of the function as +/// an argument for the function. +#[inline] +pub fn binary( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + dtype: ArrowDataType, + op: F, +) -> PrimitiveArray +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> T, +{ + check_same_len(lhs, rhs).unwrap(); + + let validity = combine_validities_and(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>() + .into(); + + PrimitiveArray::::new(dtype, values, validity) +} diff --git a/crates/polars-arrow/src/compute/arity_assign.rs b/crates/polars-arrow/src/compute/arity_assign.rs new file mode 100644 index 000000000000..e1b358d8aebb --- /dev/null +++ b/crates/polars-arrow/src/compute/arity_assign.rs @@ -0,0 +1,96 @@ +//! Defines generics suitable to perform operations to [`PrimitiveArray`] in-place. + +use either::Either; + +use super::utils::check_same_len; +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +/// Applies an unary function to a [`PrimitiveArray`], optionally in-place. +/// +/// # Implementation +/// This function tries to apply the function directly to the values of the array. +/// If that region is shared, this function creates a new region and writes to it. +/// +/// # Panics +/// This function panics iff +/// * the arrays have a different length. +/// * the function itself panics. +#[inline] +pub fn unary(array: &mut PrimitiveArray, op: F) +where + I: NativeType, + F: Fn(I) -> I, +{ + if let Some(values) = array.get_mut_values() { + // mutate in place + values.iter_mut().for_each(|l| *l = op(*l)); + } else { + // alloc and write to new region + let values = array.values().iter().map(|l| op(*l)).collect::>(); + array.set_values(values.into()); + } +} + +/// Applies a binary function to two [`PrimitiveArray`]s, optionally in-place, returning +/// a new [`PrimitiveArray`]. +/// +/// # Implementation +/// This function tries to apply the function directly to the values of the array. +/// If that region is shared, this function creates a new region and writes to it. +/// # Panics +/// This function panics iff +/// * the arrays have a different length. +/// * the function itself panics. +#[inline] +pub fn binary(lhs: &mut PrimitiveArray, rhs: &PrimitiveArray, op: F) +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> T, +{ + check_same_len(lhs, rhs).unwrap(); + + // both for the validity and for the values + // we branch to check if we can mutate in place + // if we can, great that is fastest. + // if we cannot, we allocate a new buffer and assign values to that + // new buffer, that is benchmarked to be ~2x faster than first memcpy and assign in place + // for the validity bits it can be much faster as we might need to iterate all bits if the + // bitmap has an offset. + if let Some(rhs) = rhs.validity() { + if lhs.validity().is_none() { + lhs.set_validity(Some(rhs.clone())); + } else { + lhs.apply_validity(|bitmap| { + match bitmap.into_mut() { + Either::Left(immutable) => { + // alloc new region + &immutable & rhs + }, + Either::Right(mutable) => { + // mutate in place + (mutable & rhs).into() + }, + } + }); + } + }; + + if let Some(values) = lhs.get_mut_values() { + // mutate values in place + values + .iter_mut() + .zip(rhs.values().iter()) + .for_each(|(l, r)| *l = op(*l, *r)); + } else { + // alloc new region + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>(); + lhs.set_values(values.into()); + } +} diff --git a/crates/polars-arrow/src/compute/bitwise.rs b/crates/polars-arrow/src/compute/bitwise.rs new file mode 100644 index 000000000000..1762fb430e58 --- /dev/null +++ b/crates/polars-arrow/src/compute/bitwise.rs @@ -0,0 +1,75 @@ +//! Contains bitwise operators: [`or`], [`and`], [`xor`] and [`not`]. +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use crate::array::PrimitiveArray; +use crate::compute::arity::{binary, unary}; +use crate::types::NativeType; + +/// Performs `OR` operation on two [`PrimitiveArray`]s. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn or(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitOr, +{ + binary(lhs, rhs, lhs.dtype().clone(), |a, b| a | b) +} + +/// Performs `XOR` operation between two [`PrimitiveArray`]s. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn xor(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitXor, +{ + binary(lhs, rhs, lhs.dtype().clone(), |a, b| a ^ b) +} + +/// Performs `AND` operation on two [`PrimitiveArray`]s. +/// # Panic +/// This function panics when the arrays have different lengths. +pub fn and(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitAnd, +{ + binary(lhs, rhs, lhs.dtype().clone(), |a, b| a & b) +} + +/// Returns a new [`PrimitiveArray`] with the bitwise `not`. +pub fn not(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Not, +{ + let op = move |a: T| !a; + unary(array, op, array.dtype().clone()) +} + +/// Performs `OR` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn or_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitOr, +{ + unary(lhs, |a| a | *rhs, lhs.dtype().clone()) +} + +/// Performs `XOR` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn xor_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitXor, +{ + unary(lhs, |a| a ^ *rhs, lhs.dtype().clone()) +} + +/// Performs `AND` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function panics when the arrays have different lengths. +pub fn and_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitAnd, +{ + unary(lhs, |a| a & *rhs, lhs.dtype().clone()) +} diff --git a/crates/polars-arrow/src/compute/boolean.rs b/crates/polars-arrow/src/compute/boolean.rs new file mode 100644 index 000000000000..5152d91ab338 --- /dev/null +++ b/crates/polars-arrow/src/compute/boolean.rs @@ -0,0 +1,284 @@ +//! null-preserving operators such as [`and`], [`or`] and [`not`]. +use super::utils::combine_validities_and; +use crate::array::{Array, BooleanArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::ArrowDataType; +use crate::scalar::BooleanScalar; + +fn assert_lengths(lhs: &BooleanArray, rhs: &BooleanArray) { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); +} + +/// Helper function to implement binary kernels +pub(crate) fn binary_boolean_kernel( + lhs: &BooleanArray, + rhs: &BooleanArray, + op: F, +) -> BooleanArray +where + F: Fn(&Bitmap, &Bitmap) -> Bitmap, +{ + assert_lengths(lhs, rhs); + let validity = combine_validities_and(lhs.validity(), rhs.validity()); + + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + let values = op(left_buffer, right_buffer); + + BooleanArray::new(ArrowDataType::Boolean, values, validity) +} + +/// Performs `&&` operation on two [`BooleanArray`], combining the validities. +/// # Panics +/// This function panics iff the arrays have different lengths. +/// # Examples +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::and; +/// +/// let a = BooleanArray::from(&[Some(false), Some(true), None]); +/// let b = BooleanArray::from(&[Some(true), Some(true), Some(false)]); +/// let and_ab = and(&a, &b); +/// assert_eq!(and_ab, BooleanArray::from(&[Some(false), Some(true), None])); +/// ``` +pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + match (left_buffer.unset_bits(), right_buffer.unset_bits()) { + // all values are `true` on both sides + (0, 0) => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `false` on left side + (l, _) if l == lhs.len() => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `false` on right side + (_, r) if r == rhs.len() => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // ignore the rest + _ => {}, + } + } + + binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs & rhs) +} + +/// Performs `||` operation on two [`BooleanArray`], combining the validities. +/// # Panics +/// This function panics iff the arrays have different lengths. +/// # Examples +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::or; +/// +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); +/// let or_ab = or(&a, &b); +/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None])); +/// ``` +pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + match (left_buffer.unset_bits(), right_buffer.unset_bits()) { + // all values are `true` on left side + (0, _) => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `true` on right side + (_, 0) => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // all values on lhs and rhs are `false` + (l, r) if l == lhs.len() && r == rhs.len() => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // ignore the rest + _ => {}, + } + } + + binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs | rhs) +} + +/// Performs unary `NOT` operation on an arrays. If value is null then the result is also +/// null. +/// # Example +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::not; +/// +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let not_a = not(&a); +/// assert_eq!(not_a, BooleanArray::from(vec![Some(true), Some(false), None])); +/// ``` +pub fn not(array: &BooleanArray) -> BooleanArray { + let values = !array.values(); + let validity = array.validity().cloned(); + BooleanArray::new(ArrowDataType::Boolean, values, validity) +} + +/// Returns a non-null [`BooleanArray`] with whether each value of the array is null. +/// # Example +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::is_null; +/// # fn main() { +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let a_is_null = is_null(&a); +/// assert_eq!(a_is_null, BooleanArray::from_slice(vec![false, false, true])); +/// # } +/// ``` +pub fn is_null(input: &dyn Array) -> BooleanArray { + let len = input.len(); + + let values = match input.validity() { + None => Bitmap::new_zeroed(len), + Some(buffer) => !buffer, + }; + + BooleanArray::new(ArrowDataType::Boolean, values, None) +} + +/// Returns a non-null [`BooleanArray`] with whether each value of the array is not null. +/// # Example +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::is_not_null; +/// +/// let a = BooleanArray::from(&vec![Some(false), Some(true), None]); +/// let a_is_not_null = is_not_null(&a); +/// assert_eq!(a_is_not_null, BooleanArray::from_slice(&vec![true, true, false])); +/// ``` +pub fn is_not_null(input: &dyn Array) -> BooleanArray { + let values = match input.validity() { + None => Bitmap::new_with_value(true, input.len()), + Some(buffer) => buffer.clone(), + }; + BooleanArray::new(ArrowDataType::Boolean, values, None) +} + +/// Performs `AND` operation on an array and a scalar value. If either left or right value +/// is null then the result is also null. +/// # Example +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::and_scalar; +/// use polars_arrow::scalar::BooleanScalar; +/// +/// let array = BooleanArray::from_slice(&[false, false, true, true]); +/// let scalar = BooleanScalar::new(Some(true)); +/// let result = and_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from_slice(&[false, false, true, true])); +/// +/// ``` +pub fn and_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => array.clone(), + Some(false) => { + let values = Bitmap::new_zeroed(array.len()); + BooleanArray::new(ArrowDataType::Boolean, values, array.validity().cloned()) + }, + None => BooleanArray::new_null(ArrowDataType::Boolean, array.len()), + } +} + +/// Performs `OR` operation on an array and a scalar value. If either left or right value +/// is null then the result is also null. +/// # Example +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::or_scalar; +/// use polars_arrow::scalar::BooleanScalar; +/// # fn main() { +/// let array = BooleanArray::from_slice(&[false, false, true, true]); +/// let scalar = BooleanScalar::new(Some(true)); +/// let result = or_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from_slice(&[true, true, true, true])); +/// # } +/// ``` +pub fn or_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => BooleanArray::new( + ArrowDataType::Boolean, + Bitmap::new_with_value(true, array.len()), + array.validity().cloned(), + ), + Some(false) => array.clone(), + None => BooleanArray::new_null(ArrowDataType::Boolean, array.len()), + } +} + +/// Returns whether any of the values in the array are `true`. +/// +/// Null values are ignored. +/// +/// # Example +/// +/// ``` +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::any; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false)]); +/// let b = BooleanArray::from(&[Some(false), Some(false)]); +/// let c = BooleanArray::from(&[None, Some(false)]); +/// +/// assert_eq!(any(&a), true); +/// assert_eq!(any(&b), false); +/// assert_eq!(any(&c), false); +/// ``` +pub fn any(array: &BooleanArray) -> bool { + if array.is_empty() { + false + } else if array.null_count() > 0 { + array.into_iter().any(|v| v == Some(true)) + } else { + let vals = array.values(); + vals.unset_bits() != vals.len() + } +} + +/// Returns whether all values in the array are `true`. +/// +/// Null values are ignored. +/// +/// # Example +/// +/// ``` +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean::all; +/// +/// let a = BooleanArray::from(&[Some(true), Some(true)]); +/// let b = BooleanArray::from(&[Some(false), Some(true)]); +/// let c = BooleanArray::from(&[None, Some(true)]); +/// +/// assert_eq!(all(&a), true); +/// assert_eq!(all(&b), false); +/// assert_eq!(all(&c), true); +/// ``` +pub fn all(array: &BooleanArray) -> bool { + if array.is_empty() { + true + } else if array.null_count() > 0 { + !array.into_iter().any(|v| v == Some(false)) + } else { + let vals = array.values(); + vals.unset_bits() == 0 + } +} diff --git a/crates/polars-arrow/src/compute/boolean_kleene.rs b/crates/polars-arrow/src/compute/boolean_kleene.rs new file mode 100644 index 000000000000..711310154c56 --- /dev/null +++ b/crates/polars-arrow/src/compute/boolean_kleene.rs @@ -0,0 +1,305 @@ +//! Boolean operators of [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics). +use crate::array::{Array, BooleanArray}; +use crate::bitmap::{Bitmap, binary, quaternary, ternary, unary}; +use crate::datatypes::ArrowDataType; +use crate::scalar::BooleanScalar; + +/// Logical 'or' operation on two arrays with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Panics +/// This function panics iff the arrays have a different length +/// # Example +/// +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean_kleene::or; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false), None]); +/// let b = BooleanArray::from(&[None, None, None]); +/// let or_ab = or(&a, &b); +/// assert_eq!(or_ab, BooleanArray::from(&[Some(true), None, None])); +/// ``` +pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); + + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_validity = lhs.validity(); + let rhs_validity = rhs.validity(); + + let validity = match (lhs_validity, rhs_validity) { + (Some(lhs_validity), Some(rhs_validity)) => { + Some(quaternary( + lhs_values, + rhs_values, + lhs_validity, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v, rhs_v| { + // A = T + (lhs & lhs_v) | + // B = T + (rhs & rhs_v) | + // A = F & B = F + (!lhs & lhs_v) & (!rhs & rhs_v) + }, + )) + }, + (Some(lhs_validity), None) => { + // B != U + Some(ternary( + lhs_values, + rhs_values, + lhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v| { + // A = T + (lhs & lhs_v) | + // B = T + rhs | + // A = F & B = F + (!lhs & lhs_v) & !rhs + }, + )) + }, + (None, Some(rhs_validity)) => { + Some(ternary( + lhs_values, + rhs_values, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, rhs_v| { + // A = T + lhs | + // B = T + (rhs & rhs_v) | + // A = F & B = F + !lhs & (!rhs & rhs_v) + }, + )) + }, + (None, None) => None, + }; + BooleanArray::new(ArrowDataType::Boolean, lhs_values | rhs_values, validity) +} + +/// Logical 'and' operation on two arrays with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Panics +/// This function panics iff the arrays have a different length +/// # Example +/// +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean_kleene::and; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false), None]); +/// let b = BooleanArray::from(&[None, None, None]); +/// let and_ab = and(&a, &b); +/// assert_eq!(and_ab, BooleanArray::from(&[None, Some(false), None])); +/// ``` +pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); + + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_validity = lhs.validity(); + let rhs_validity = rhs.validity(); + + let validity = match (lhs_validity, rhs_validity) { + (Some(lhs_validity), Some(rhs_validity)) => { + Some(quaternary( + lhs_values, + rhs_values, + lhs_validity, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v, rhs_v| { + // B = F + (!rhs & rhs_v) | + // A = F + (!lhs & lhs_v) | + // A = T & B = T + (lhs & lhs_v) & (rhs & rhs_v) + }, + )) + }, + (Some(lhs_validity), None) => { + Some(ternary( + lhs_values, + rhs_values, + lhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v| { + // B = F + !rhs | + // A = F + (!lhs & lhs_v) | + // A = T & B = T + (lhs & lhs_v) & rhs + }, + )) + }, + (None, Some(rhs_validity)) => { + Some(ternary( + lhs_values, + rhs_values, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, rhs_v| { + // B = F + (!rhs & rhs_v) | + // A = F + !lhs | + // A = T & B = T + lhs & (rhs & rhs_v) + }, + )) + }, + (None, None) => None, + }; + BooleanArray::new(ArrowDataType::Boolean, lhs_values & rhs_values, validity) +} + +/// Logical 'or' operation on an array and a scalar value with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Example +/// +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::scalar::BooleanScalar; +/// use polars_arrow::compute::boolean_kleene::or_scalar; +/// +/// let array = BooleanArray::from(&[Some(true), Some(false), None]); +/// let scalar = BooleanScalar::new(Some(false)); +/// let result = or_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from(&[Some(true), Some(false), None])); +/// ``` +pub fn or_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => BooleanArray::new( + ArrowDataType::Boolean, + Bitmap::new_with_value(true, array.len()), + None, + ), + Some(false) => array.clone(), + None => { + let values = array.values(); + let validity = match array.validity() { + Some(validity) => binary(values, validity, |value, validity| validity & value), + None => unary(values, |value| value), + }; + BooleanArray::new(ArrowDataType::Boolean, values.clone(), Some(validity)) + }, + } +} + +/// Logical 'and' operation on an array and a scalar value with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Example +/// +/// ```rust +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::scalar::BooleanScalar; +/// use polars_arrow::compute::boolean_kleene::and_scalar; +/// +/// let array = BooleanArray::from(&[Some(true), Some(false), None]); +/// let scalar = BooleanScalar::new(None); +/// let result = and_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from(&[None, Some(false), None])); +/// ``` +pub fn and_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => array.clone(), + Some(false) => { + let values = Bitmap::new_zeroed(array.len()); + BooleanArray::new(ArrowDataType::Boolean, values, None) + }, + None => { + let values = array.values(); + let validity = match array.validity() { + Some(validity) => binary(values, validity, |value, validity| validity & !value), + None => unary(values, |value| !value), + }; + BooleanArray::new( + ArrowDataType::Boolean, + array.values().clone(), + Some(validity), + ) + }, + } +} + +/// Returns whether any of the values in the array are `true`. +/// +/// The output is unknown (`None`) if the array contains any null values and +/// no `true` values. +/// +/// # Example +/// +/// ``` +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean_kleene::any; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false)]); +/// let b = BooleanArray::from(&[Some(false), Some(false)]); +/// let c = BooleanArray::from(&[None, Some(false)]); +/// +/// assert_eq!(any(&a), Some(true)); +/// assert_eq!(any(&b), Some(false)); +/// assert_eq!(any(&c), None); +/// ``` +pub fn any(array: &BooleanArray) -> Option { + if array.is_empty() { + Some(false) + } else if array.null_count() > 0 { + if array.into_iter().any(|v| v == Some(true)) { + Some(true) + } else { + None + } + } else { + let vals = array.values(); + Some(vals.unset_bits() != vals.len()) + } +} + +/// Returns whether all values in the array are `true`. +/// +/// The output is unknown (`None`) if the array contains any null values and +/// no `false` values. +/// +/// # Example +/// +/// ``` +/// use polars_arrow::array::BooleanArray; +/// use polars_arrow::compute::boolean_kleene::all; +/// +/// let a = BooleanArray::from(&[Some(true), Some(true)]); +/// let b = BooleanArray::from(&[Some(false), Some(true)]); +/// let c = BooleanArray::from(&[None, Some(true)]); +/// +/// assert_eq!(all(&a), Some(true)); +/// assert_eq!(all(&b), Some(false)); +/// assert_eq!(all(&c), None); +/// ``` +pub fn all(array: &BooleanArray) -> Option { + if array.is_empty() { + Some(true) + } else if array.null_count() > 0 { + if array.into_iter().any(|v| v == Some(false)) { + Some(false) + } else { + None + } + } else { + let vals = array.values(); + Some(vals.unset_bits() == 0) + } +} diff --git a/crates/polars-arrow/src/compute/concatenate.rs b/crates/polars-arrow/src/compute/concatenate.rs new file mode 100644 index 000000000000..afcd0685a1cd --- /dev/null +++ b/crates/polars-arrow/src/compute/concatenate.rs @@ -0,0 +1,442 @@ +use std::sync::Arc; + +use hashbrown::hash_map::Entry; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; + +use crate::array::*; +use crate::bitmap::{Bitmap, BitmapBuilder}; +use crate::buffer::Buffer; +use crate::datatypes::PhysicalType; +use crate::offset::Offsets; +use crate::types::{NativeType, Offset}; +use crate::with_match_primitive_type_full; + +/// Concatenate multiple [`Array`] of the same type into a single [`Array`]. +pub fn concatenate(arrays: &[&dyn Array]) -> PolarsResult> { + if arrays.is_empty() { + polars_bail!(InvalidOperation: "concat requires input of at least one array") + } + + if arrays + .iter() + .any(|array| array.dtype() != arrays[0].dtype()) + { + polars_bail!(InvalidOperation: "It is not possible to concatenate arrays of different data types.") + } + + concatenate_unchecked(arrays) +} + +fn len_null_count>(arrays: &[A]) -> (usize, usize) { + let mut len = 0; + let mut null_count = 0; + for arr in arrays { + let arr = arr.as_ref(); + len += arr.len(); + null_count += arr.null_count(); + } + (len, null_count) +} + +/// Concatenate the validities of multiple [Array]s into a single Bitmap. +pub fn concatenate_validities>(arrays: &[A]) -> Option { + let (len, null_count) = len_null_count(arrays); + concatenate_validities_with_len_null_count(arrays, len, null_count) +} + +fn concatenate_validities_with_len_null_count>( + arrays: &[A], + len: usize, + null_count: usize, +) -> Option { + if null_count == 0 { + return None; + } + + let mut bitmap = BitmapBuilder::with_capacity(len); + for arr in arrays { + let arr = arr.as_ref(); + if arr.null_count() == arr.len() { + bitmap.extend_constant(arr.len(), false); + } else if arr.null_count() == 0 { + bitmap.extend_constant(arr.len(), true); + } else { + bitmap.extend_from_bitmap(arr.validity().unwrap()); + } + } + bitmap.into_opt_validity() +} + +/// Concatenate multiple [`Array`] of the same type into a single [`Array`]. +/// All arrays must be of the same dtype or a panic can occur. +pub fn concatenate_unchecked>(arrays: &[A]) -> PolarsResult> { + if arrays.is_empty() { + polars_bail!(InvalidOperation: "concat requires input of at least one array") + } + + if arrays.len() == 1 { + return Ok(arrays[0].as_ref().to_boxed()); + } + + use PhysicalType::*; + match arrays[0].as_ref().dtype().to_physical_type() { + Null => Ok(Box::new(concatenate_null(arrays))), + Boolean => Ok(Box::new(concatenate_bool(arrays))), + Primitive(ptype) => { + with_match_primitive_type_full!(ptype, |$T| { + Ok(Box::new(concatenate_primitive::<$T, _>(arrays))) + }) + }, + Binary => Ok(Box::new(concatenate_binary::(arrays)?)), + LargeBinary => Ok(Box::new(concatenate_binary::(arrays)?)), + Utf8 => Ok(Box::new(concatenate_utf8::(arrays)?)), + LargeUtf8 => Ok(Box::new(concatenate_utf8::(arrays)?)), + BinaryView => Ok(Box::new(concatenate_view::<[u8], _>(arrays))), + Utf8View => Ok(Box::new(concatenate_view::(arrays))), + List => Ok(Box::new(concatenate_list::(arrays)?)), + LargeList => Ok(Box::new(concatenate_list::(arrays)?)), + FixedSizeBinary => Ok(Box::new(concatenate_fixed_size_binary(arrays)?)), + FixedSizeList => Ok(Box::new(concatenate_fixed_size_list(arrays)?)), + Struct => Ok(Box::new(concatenate_struct(arrays)?)), + Union => unimplemented!(), + Map => unimplemented!(), + Dictionary(_) => unimplemented!(), + } +} + +fn concatenate_null>(arrays: &[A]) -> NullArray { + let dtype = arrays[0].as_ref().dtype().clone(); + let total_len = arrays.iter().map(|arr| arr.as_ref().len()).sum(); + NullArray::new(dtype, total_len) +} + +fn concatenate_bool>(arrays: &[A]) -> BooleanArray { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let mut bitmap = BitmapBuilder::with_capacity(total_len); + for arr in arrays { + let arr: &BooleanArray = arr.as_ref().as_any().downcast_ref().unwrap(); + bitmap.extend_from_bitmap(arr.values()); + } + BooleanArray::new(dtype, bitmap.freeze(), validity) +} + +fn concatenate_primitive>(arrays: &[A]) -> PrimitiveArray { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let mut out = Vec::with_capacity(total_len); + for arr in arrays { + let arr: &PrimitiveArray = arr.as_ref().as_any().downcast_ref().unwrap(); + out.extend_from_slice(arr.values()); + } + unsafe { PrimitiveArray::new_unchecked(dtype, Buffer::from(out), validity) } +} + +fn concatenate_binary>( + arrays: &[A], +) -> PolarsResult> { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let total_bytes = arrays + .iter() + .map(|arr| { + let arr: &BinaryArray = arr.as_ref().as_any().downcast_ref().unwrap(); + arr.get_values_size() + }) + .sum(); + + let mut values = Vec::with_capacity(total_bytes); + let mut offsets = Offsets::::with_capacity(total_len); + + for arr in arrays { + let arr: &BinaryArray = arr.as_ref().as_any().downcast_ref().unwrap(); + let first_offset = arr.offsets().first().to_usize(); + let last_offset = arr.offsets().last().to_usize(); + values.extend_from_slice(&arr.values()[first_offset..last_offset]); + for len in arr.offsets().lengths() { + offsets.try_push(len)?; + } + } + + Ok(unsafe { BinaryArray::new(dtype, offsets.into(), values.into(), validity) }) +} + +fn concatenate_utf8>(arrays: &[A]) -> PolarsResult> { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let total_bytes = arrays + .iter() + .map(|arr| { + let arr: &Utf8Array = arr.as_ref().as_any().downcast_ref().unwrap(); + arr.get_values_size() + }) + .sum(); + + let mut bytes = Vec::with_capacity(total_bytes); + let mut offsets = Offsets::::with_capacity(total_len); + + for arr in arrays { + let arr: &Utf8Array = arr.as_ref().as_any().downcast_ref().unwrap(); + let first_offset = arr.offsets().first().to_usize(); + let last_offset = arr.offsets().last().to_usize(); + bytes.extend_from_slice(&arr.values()[first_offset..last_offset]); + for len in arr.offsets().lengths() { + offsets.try_push(len)?; + } + } + + Ok(unsafe { Utf8Array::new_unchecked(dtype, offsets.into(), bytes.into(), validity) }) +} + +fn concatenate_view>( + arrays: &[A], +) -> BinaryViewArrayGeneric { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + if total_len == 0 { + return BinaryViewArrayGeneric::new_empty(dtype); + } + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let first_arr: &BinaryViewArrayGeneric = arrays[0].as_ref().as_any().downcast_ref().unwrap(); + let mut total_nondedup_buffers = first_arr.data_buffers().len(); + let mut max_arr_bufferset_len = 0; + let mut all_same_bufs = true; + for arr in arrays { + let arr: &BinaryViewArrayGeneric = arr.as_ref().as_any().downcast_ref().unwrap(); + max_arr_bufferset_len = max_arr_bufferset_len.max(arr.data_buffers().len()); + total_nondedup_buffers += arr.data_buffers().len(); + // Fat pointer equality, checks both start and length. + all_same_bufs &= std::ptr::eq( + Arc::as_ptr(arr.data_buffers()), + Arc::as_ptr(first_arr.data_buffers()), + ); + } + + let mut total_bytes_len = 0; + let mut views = Vec::with_capacity(total_len); + + let mut total_buffer_len = 0; + let buffers = if all_same_bufs { + total_buffer_len = first_arr.total_buffer_len(); + for arr in arrays { + let arr: &BinaryViewArrayGeneric = arr.as_ref().as_any().downcast_ref().unwrap(); + views.extend_from_slice(arr.views()); + total_bytes_len += arr.total_bytes_len(); + } + Arc::clone(first_arr.data_buffers()) + + // There might be way more buffers than elements, so we only dedup if there + // is at least one element per buffer on average. + } else if total_len > total_nondedup_buffers { + assert!(arrays.len() < u32::MAX as usize); + + let mut dedup_buffers = Vec::with_capacity(total_nondedup_buffers); + let mut global_dedup_buffer_idx = PlHashMap::with_capacity(total_nondedup_buffers); + let mut local_dedup_buffer_idx = Vec::new(); + local_dedup_buffer_idx.resize(max_arr_bufferset_len, (0, u32::MAX)); + + for (arr_idx, arr) in arrays.iter().enumerate() { + let arr: &BinaryViewArrayGeneric = arr.as_ref().as_any().downcast_ref().unwrap(); + + unsafe { + for mut view in arr.views().iter().copied() { + if view.length > View::MAX_INLINE_SIZE { + // Translate from old array-local buffer idx to global deduped buffer idx. + let (mut new_buffer_idx, cache_tag) = + *local_dedup_buffer_idx.get_unchecked(view.buffer_idx as usize); + if cache_tag != arr_idx as u32 { + // This buffer index wasn't seen before for this array, do a dedup lookup. + let buffer = arr.data_buffers().get_unchecked(view.buffer_idx as usize); + let buf_id = (buffer.as_slice().as_ptr(), buffer.len()); + let idx = match global_dedup_buffer_idx.entry(buf_id) { + Entry::Occupied(o) => *o.get(), + Entry::Vacant(v) => { + let idx = dedup_buffers.len() as u32; + dedup_buffers.push(buffer.clone()); + total_buffer_len += buffer.len(); + v.insert(idx); + idx + }, + }; + + // Cache result for future lookups. + *local_dedup_buffer_idx.get_unchecked_mut(view.buffer_idx as usize) = + (idx, arr_idx as u32); + new_buffer_idx = idx; + } + view.buffer_idx = new_buffer_idx; + } + + total_bytes_len += view.length as usize; + views.push_unchecked(view); + } + } + } + + dedup_buffers.into_iter().collect() + } else { + // Only very few of the total number of buffers is referenced, simply + // create a new direct buffer. + for arr in arrays { + let arr: &BinaryViewArrayGeneric = arr.as_ref().as_any().downcast_ref().unwrap(); + total_buffer_len += arr + .len_iter() + .map(|l| if l > 12 { l as usize } else { 0 }) + .sum::(); + } + let mut new_buffer = Vec::with_capacity(total_buffer_len); + for arr in arrays { + let arr: &BinaryViewArrayGeneric = arr.as_ref().as_any().downcast_ref().unwrap(); + let buffers = arr.data_buffers(); + + unsafe { + for mut view in arr.views().iter().copied() { + total_bytes_len += view.length as usize; + if view.length > 12 { + let new_offset = new_buffer.len().try_into().unwrap(); + new_buffer.extend_from_slice(view.get_slice_unchecked(buffers)); + view.offset = new_offset; + view.buffer_idx = 0; + } + views.push_unchecked(view); + } + } + } + + Arc::new([Buffer::from(new_buffer)]) as Arc<[_]> + }; + + unsafe { + BinaryViewArrayGeneric::new_unchecked( + dtype, + views.into(), + buffers, + validity, + total_bytes_len, + total_buffer_len, + ) + } +} + +fn concatenate_list>(arrays: &[A]) -> PolarsResult> { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let mut num_sliced = 0; + let mut offsets = Offsets::::with_capacity(total_len); + for arr in arrays { + let arr: &ListArray = arr.as_ref().as_any().downcast_ref().unwrap(); + for len in arr.offsets().lengths() { + offsets.try_push(len)?; + } + let first_offset = arr.offsets().first().to_usize(); + let offset_range = arr.offsets().range().to_usize(); + num_sliced += (first_offset != 0 || offset_range != arr.values().len()) as usize; + } + + let values = if num_sliced > 0 { + let inner_sliced_arrays = arrays + .iter() + .map(|arr| { + let arr: &ListArray = arr.as_ref().as_any().downcast_ref().unwrap(); + let first_offset = arr.offsets().first().to_usize(); + let offset_range = arr.offsets().range().to_usize(); + if first_offset != 0 || offset_range != arr.values().len() { + arr.values().sliced(first_offset, offset_range) + } else { + arr.values().to_boxed() + } + }) + .collect_vec(); + concatenate_unchecked(&inner_sliced_arrays[..])? + } else { + let inner_arrays = arrays + .iter() + .map(|arr| { + let arr: &ListArray = arr.as_ref().as_any().downcast_ref().unwrap(); + &**arr.values() + }) + .collect_vec(); + concatenate_unchecked(&inner_arrays)? + }; + + Ok(ListArray::new(dtype, offsets.into(), values, validity)) +} + +fn concatenate_fixed_size_binary>( + arrays: &[A], +) -> PolarsResult { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let total_bytes = arrays + .iter() + .map(|arr| { + let arr: &FixedSizeBinaryArray = arr.as_ref().as_any().downcast_ref().unwrap(); + arr.values().len() + }) + .sum(); + + let mut bytes = Vec::with_capacity(total_bytes); + for arr in arrays { + let arr: &FixedSizeBinaryArray = arr.as_ref().as_any().downcast_ref().unwrap(); + bytes.extend_from_slice(arr.values()); + } + + Ok(FixedSizeBinaryArray::new(dtype, bytes.into(), validity)) +} + +fn concatenate_fixed_size_list>( + arrays: &[A], +) -> PolarsResult { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + + let inner_arrays = arrays + .iter() + .map(|arr| { + let arr: &FixedSizeListArray = arr.as_ref().as_any().downcast_ref().unwrap(); + &**arr.values() + }) + .collect_vec(); + let values = concatenate_unchecked(&inner_arrays)?; + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + Ok(FixedSizeListArray::new(dtype, total_len, values, validity)) +} + +fn concatenate_struct>(arrays: &[A]) -> PolarsResult { + let dtype = arrays[0].as_ref().dtype().clone(); + let (total_len, null_count) = len_null_count(arrays); + let validity = concatenate_validities_with_len_null_count(arrays, total_len, null_count); + + let first_arr: &StructArray = arrays[0].as_ref().as_any().downcast_ref().unwrap(); + let num_fields = first_arr.values().len(); + + let mut inner_arrays = Vec::with_capacity(arrays.len()); + let values = (0..num_fields) + .map(|f| { + inner_arrays.clear(); + for arr in arrays { + let arr: &StructArray = arr.as_ref().as_any().downcast_ref().unwrap(); + inner_arrays.push(&arr.values()[f]); + } + concatenate_unchecked(&inner_arrays) + }) + .try_collect_vec()?; + + Ok(StructArray::new(dtype, total_len, values, validity)) +} diff --git a/crates/polars-arrow/src/compute/decimal.rs b/crates/polars-arrow/src/compute/decimal.rs new file mode 100644 index 000000000000..c03a6f280c47 --- /dev/null +++ b/crates/polars-arrow/src/compute/decimal.rs @@ -0,0 +1,314 @@ +use std::sync::atomic::{AtomicBool, Ordering}; + +use num_traits::Euclid; + +static TRIM_DECIMAL_ZEROS: AtomicBool = AtomicBool::new(false); + +pub fn get_trim_decimal_zeros() -> bool { + TRIM_DECIMAL_ZEROS.load(Ordering::Relaxed) +} +pub fn set_trim_decimal_zeros(trim: Option) { + TRIM_DECIMAL_ZEROS.store(trim.unwrap_or(false), Ordering::Relaxed) +} + +/// Assuming bytes are a well-formed decimal number (with or without a separator), +/// infer the scale of the number. If no separator is present, the scale is 0. +pub fn infer_scale(bytes: &[u8]) -> u8 { + let Some(separator) = bytes.iter().position(|b| *b == b'.') else { + return 0; + }; + (bytes.len() - (1 + separator)) as u8 +} + +/// Deserialize bytes to a single i128 representing a decimal, at a specified +/// precision (optional) and scale (required). The number is checked to ensure +/// it fits within the specified precision and scale. Consistent with float +/// parsing, no decimal separator is required (eg "500", "500.", and "500.0" are +/// all accepted); this allows mixed integer/decimal sequences to be parsed as +/// decimals. All trailing zeros are assumed to be significant, whether or not +/// a separator is present: 1200 requires precision >= 4, while 1200.200 +/// requires precision >= 7 and scale >= 3. Returns None if the number is not +/// well-formed, or does not fit. Only b'.' is allowed as a decimal separator +/// (issue #6698). +#[inline] +pub fn deserialize_decimal(bytes: &[u8], precision: Option, scale: u8) -> Option { + let precision_digits = precision.unwrap_or(38).min(38) as usize; + if scale as usize > precision_digits { + return None; + } + + let separator = bytes.iter().position(|b| *b == b'.').unwrap_or(bytes.len()); + let (mut int, mut frac) = bytes.split_at(separator); + if frac.len() <= 1 || scale == 0 { + // Only integer fast path. + let n: i128 = atoi_simd::parse(int).ok()?; + let ret = n.checked_mul(POW10[scale as usize] as i128)?; + if precision.is_some() && ret >= POW10[precision_digits] as i128 { + return None; + } + return Some(ret); + } + + // Skip period. + frac = &frac[1..]; + + // Skip sign. + let negative = match bytes.first() { + Some(s @ (b'+' | b'-')) => { + int = &int[1..]; + *s == b'-' + }, + _ => false, + }; + + // Truncate trailing digits that extend beyond the scale. + let frac_scale = if scale as usize <= frac.len() { + frac = &frac[..scale as usize]; + 0 + } else { + scale as usize - frac.len() + }; + + // Parse and combine parts. + let pint: u128 = if int.is_empty() { + 0 + } else { + atoi_simd::parse_pos(int).ok()? + }; + let pfrac: u128 = atoi_simd::parse_pos(frac).ok()?; + + let ret = pint + .checked_mul(POW10[scale as usize])? + .checked_add(pfrac.checked_mul(POW10[frac_scale])?)?; + if precision.is_some() && ret >= POW10[precision_digits] { + return None; + } + if negative { + if ret > (1 << 127) { + None + } else { + Some(ret.wrapping_neg() as i128) + } + } else { + ret.try_into().ok() + } +} + +const MAX_DECIMAL_LEN: usize = 48; + +#[derive(Clone, Copy)] +pub struct DecimalFmtBuffer { + data: [u8; MAX_DECIMAL_LEN], + len: usize, +} + +impl Default for DecimalFmtBuffer { + fn default() -> Self { + Self::new() + } +} + +impl DecimalFmtBuffer { + #[inline] + pub const fn new() -> Self { + Self { + data: [0; MAX_DECIMAL_LEN], + len: 0, + } + } + + pub fn format(&mut self, x: i128, scale: usize, trim_zeros: bool) -> &str { + let factor = POW10[scale]; + let mut itoa_buf = itoa::Buffer::new(); + + self.len = 0; + let (div, rem) = x.unsigned_abs().div_rem_euclid(&factor); + if x < 0 { + self.data[0] = b'-'; + self.len += 1; + } + + let div_fmt = itoa_buf.format(div); + self.data[self.len..self.len + div_fmt.len()].copy_from_slice(div_fmt.as_bytes()); + self.len += div_fmt.len(); + + if scale == 0 { + return unsafe { std::str::from_utf8_unchecked(&self.data[..self.len]) }; + } + + self.data[self.len] = b'.'; + self.len += 1; + + let rem_fmt = itoa_buf.format(rem + factor); // + factor adds leading 1 where period would be. + self.data[self.len..self.len + rem_fmt.len() - 1].copy_from_slice(&rem_fmt.as_bytes()[1..]); + self.len += rem_fmt.len() - 1; + + if trim_zeros { + while self.data.get(self.len - 1) == Some(&b'0') { + self.len -= 1; + } + if self.data.get(self.len - 1) == Some(&b'.') { + self.len -= 1; + } + } + + unsafe { std::str::from_utf8_unchecked(&self.data[..self.len]) } + } +} + +const POW10: [u128; 39] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, + 100000000000000000000000000000000000000, +]; + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_decimal() { + let precision = Some(8); + let scale = 2; + + let val = "12.09"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(1209) + ); + + let val = "1200.90"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(120090) + ); + + let val = "143.9"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(14390) + ); + + let val = "+000000.5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(50) + ); + + let val = "-0.5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(-50) + ); + + let val = "-1.5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(-150) + ); + + let scale = 20; + let val = "0.01"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + assert_eq!( + deserialize_decimal(val.as_bytes(), None, scale), + Some(1000000000000000000) + ); + + let scale = 5; + let val = "12ABC.34"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "1ABC2.34"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "12.3ABC4"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "12.3.ABC4"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "12.-3"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = ""; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "5."; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(500000i128) + ); + + let val = "5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(500000i128) + ); + + let val = ".5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(50000i128) + ); + + // Precision and scale fitting: + let val = b"1200"; + assert_eq!(deserialize_decimal(val, None, 0), Some(1200)); + assert_eq!(deserialize_decimal(val, Some(4), 0), Some(1200)); + assert_eq!(deserialize_decimal(val, Some(3), 0), None); + assert_eq!(deserialize_decimal(val, Some(4), 1), None); + + let val = b"1200.010"; + assert_eq!(deserialize_decimal(val, None, 0), Some(1200)); // truncate scale + assert_eq!(deserialize_decimal(val, None, 3), Some(1200010)); // exact scale + assert_eq!(deserialize_decimal(val, None, 6), Some(1200010000)); // excess scale + assert_eq!(deserialize_decimal(val, Some(7), 0), Some(1200)); // sufficient precision and truncate scale + assert_eq!(deserialize_decimal(val, Some(7), 3), Some(1200010)); // exact precision and scale + assert_eq!(deserialize_decimal(val, Some(10), 6), Some(1200010000)); // exact precision, excess scale + assert_eq!(deserialize_decimal(val, Some(5), 6), None); // insufficient precision, excess scale + assert_eq!(deserialize_decimal(val, Some(5), 3), None); // insufficient precision, exact scale + assert_eq!(deserialize_decimal(val, Some(12), 5), Some(120001000)); // excess precision, excess scale + assert_eq!( + deserialize_decimal(val, None, 35), + Some(120001000000000000000000000000000000000) + ); + assert_eq!(deserialize_decimal(val, None, 36), None); + assert_eq!(deserialize_decimal(val, Some(38), 35), None); // scale causes insufficient precision + } +} diff --git a/crates/polars-arrow/src/compute/mod.rs b/crates/polars-arrow/src/compute/mod.rs new file mode 100644 index 000000000000..4c7ade50fc5b --- /dev/null +++ b/crates/polars-arrow/src/compute/mod.rs @@ -0,0 +1,34 @@ +//! contains a wide range of compute operations (e.g. +//! [`arithmetics`], [`aggregate`], +//! [`filter`], [`comparison`], and [`sort`]) +//! +//! This module's general design is +//! that each operator has two interfaces, a statically-typed version and a dynamically-typed +//! version. +//! The statically-typed version expects concrete arrays (such as [`PrimitiveArray`](crate::array::PrimitiveArray)); +//! the dynamically-typed version expects `&dyn Array` and errors if the type is not +//! supported. +//! Some dynamically-typed operators have an auxiliary function, `can_*`, that returns +//! true if the operator can be applied to the particular `DataType`. + +#[cfg(feature = "compute_aggregate")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_aggregate")))] +pub mod aggregate; +pub mod arity; +pub mod arity_assign; +#[cfg(feature = "compute_bitwise")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_bitwise")))] +pub mod bitwise; +#[cfg(feature = "compute_boolean")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_boolean")))] +pub mod boolean; +#[cfg(feature = "compute_boolean_kleene")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_boolean_kleene")))] +pub mod boolean_kleene; +pub mod concatenate; +#[cfg(feature = "dtype-decimal")] +pub mod decimal; +#[cfg(feature = "compute_temporal")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_temporal")))] +pub mod temporal; +pub mod utils; diff --git a/crates/polars-arrow/src/compute/temporal.rs b/crates/polars-arrow/src/compute/temporal.rs new file mode 100644 index 000000000000..98bc920caeb3 --- /dev/null +++ b/crates/polars-arrow/src/compute/temporal.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines temporal kernels for time and date related functions. + +use chrono::{Datelike, Timelike}; +use polars_error::PolarsResult; + +use super::arity::unary; +use crate::array::*; +use crate::datatypes::*; +use crate::temporal_conversions::*; +use crate::types::NativeType; + +// Create and implement a trait that converts chrono's `Weekday` +// type into `i8` +trait Int8Weekday: Datelike { + fn i8_weekday(&self) -> i8 { + self.weekday().number_from_monday().try_into().unwrap() + } +} + +impl Int8Weekday for chrono::NaiveDateTime {} +impl Int8Weekday for chrono::DateTime {} + +// Create and implement a trait that converts chrono's `IsoWeek` +// type into `i8` +trait Int8IsoWeek: Datelike { + fn i8_iso_week(&self) -> i8 { + self.iso_week().week().try_into().unwrap() + } +} + +impl Int8IsoWeek for chrono::NaiveDateTime {} +impl Int8IsoWeek for chrono::DateTime {} + +// Macro to avoid repetition in functions, that apply +// `chrono::Datelike` methods on Arrays +macro_rules! date_like { + ($extract:ident, $array:ident, $dtype:path) => { + match $array.dtype().to_logical_type() { + ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, None) => { + date_variants($array, $dtype, |x| x.$extract().try_into().unwrap()) + }, + ArrowDataType::Timestamp(time_unit, Some(timezone_str)) => { + let array = $array.as_any().downcast_ref().unwrap(); + + if let Ok(timezone) = parse_offset(timezone_str.as_str()) { + Ok(extract_impl(array, *time_unit, timezone, |x| { + x.$extract().try_into().unwrap() + })) + } else { + chrono_tz(array, *time_unit, timezone_str.as_str(), |x| { + x.$extract().try_into().unwrap() + }) + } + }, + _ => unimplemented!(), + } + }; +} + +/// Extracts the years of a temporal array as [`PrimitiveArray`]. +pub fn year(array: &dyn Array) -> PolarsResult> { + date_like!(year, array, ArrowDataType::Int32) +} + +/// Extracts the months of a temporal array as [`PrimitiveArray`]. +/// +/// Value ranges from 1 to 12. +pub fn month(array: &dyn Array) -> PolarsResult> { + date_like!(month, array, ArrowDataType::Int8) +} + +/// Extracts the days of a temporal array as [`PrimitiveArray`]. +/// +/// Value ranges from 1 to 32 (Last day depends on month). +pub fn day(array: &dyn Array) -> PolarsResult> { + date_like!(day, array, ArrowDataType::Int8) +} + +/// Extracts weekday of a temporal array as [`PrimitiveArray`]. +/// +/// Monday is 1, Tuesday is 2, ..., Sunday is 7. +pub fn weekday(array: &dyn Array) -> PolarsResult> { + date_like!(i8_weekday, array, ArrowDataType::Int8) +} + +/// Extracts ISO week of a temporal array as [`PrimitiveArray`]. +/// +/// Value ranges from 1 to 53 (Last week depends on the year). +pub fn iso_week(array: &dyn Array) -> PolarsResult> { + date_like!(i8_iso_week, array, ArrowDataType::Int8) +} + +// Macro to avoid repetition in functions, that apply +// `chrono::Timelike` methods on Arrays +macro_rules! time_like { + ($extract:ident, $array:ident, $dtype:path) => { + match $array.dtype().to_logical_type() { + ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, None) => { + date_variants($array, $dtype, |x| x.$extract().try_into().unwrap()) + }, + ArrowDataType::Time32(_) | ArrowDataType::Time64(_) => { + time_variants($array, ArrowDataType::UInt32, |x| { + x.$extract().try_into().unwrap() + }) + }, + ArrowDataType::Timestamp(time_unit, Some(timezone_str)) => { + let array = $array.as_any().downcast_ref().unwrap(); + + if let Ok(timezone) = parse_offset(timezone_str.as_str()) { + Ok(extract_impl(array, *time_unit, timezone, |x| { + x.$extract().try_into().unwrap() + })) + } else { + chrono_tz(array, *time_unit, timezone_str.as_str(), |x| { + x.$extract().try_into().unwrap() + }) + } + }, + _ => unimplemented!(), + } + }; +} + +/// Extracts the hours of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 23. +/// Use [`can_hour`] to check if this operation is supported for the target [`ArrowDataType`]. +pub fn hour(array: &dyn Array) -> PolarsResult> { + time_like!(hour, array, ArrowDataType::Int8) +} + +/// Extracts the minutes of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 59. +/// Use [`can_minute`] to check if this operation is supported for the target [`ArrowDataType`]. +pub fn minute(array: &dyn Array) -> PolarsResult> { + time_like!(minute, array, ArrowDataType::Int8) +} + +/// Extracts the seconds of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 59. +/// Use [`can_second`] to check if this operation is supported for the target [`ArrowDataType`]. +pub fn second(array: &dyn Array) -> PolarsResult> { + time_like!(second, array, ArrowDataType::Int8) +} + +/// Extracts the nanoseconds of a temporal array as [`PrimitiveArray`]. +/// +/// Value ranges from 0 to 1_999_999_999. +/// The range from 1_000_000_000 to 1_999_999_999 represents the leap second. +/// Use [`can_nanosecond`] to check if this operation is supported for the target [`ArrowDataType`]. +pub fn nanosecond(array: &dyn Array) -> PolarsResult> { + time_like!(nanosecond, array, ArrowDataType::Int32) +} + +fn date_variants( + array: &dyn Array, + dtype: ArrowDataType, + op: F, +) -> PolarsResult> +where + O: NativeType, + F: Fn(chrono::NaiveDateTime) -> O, +{ + match array.dtype().to_logical_type() { + ArrowDataType::Date32 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(date32_to_datetime(x)), dtype)) + }, + ArrowDataType::Date64 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(date64_to_datetime(x)), dtype)) + }, + ArrowDataType::Timestamp(time_unit, None) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let func = match time_unit { + TimeUnit::Second => timestamp_s_to_datetime, + TimeUnit::Millisecond => timestamp_ms_to_datetime, + TimeUnit::Microsecond => timestamp_us_to_datetime, + TimeUnit::Nanosecond => timestamp_ns_to_datetime, + }; + Ok(PrimitiveArray::::from_trusted_len_iter( + array.iter().map(|v| v.map(|x| op(func(*x)))), + )) + }, + _ => unreachable!(), + } +} + +fn time_variants( + array: &dyn Array, + dtype: ArrowDataType, + op: F, +) -> PolarsResult> +where + O: NativeType, + F: Fn(chrono::NaiveTime) -> O, +{ + match array.dtype().to_logical_type() { + ArrowDataType::Time32(TimeUnit::Second) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time32s_to_time(x)), dtype)) + }, + ArrowDataType::Time32(TimeUnit::Millisecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time32ms_to_time(x)), dtype)) + }, + ArrowDataType::Time64(TimeUnit::Microsecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time64us_to_time(x)), dtype)) + }, + ArrowDataType::Time64(TimeUnit::Nanosecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time64ns_to_time(x)), dtype)) + }, + _ => unreachable!(), + } +} + +#[cfg(feature = "chrono-tz")] +fn chrono_tz( + array: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, + op: F, +) -> PolarsResult> +where + O: NativeType, + F: Fn(chrono::DateTime) -> O, +{ + let timezone = parse_offset_tz(timezone_str)?; + Ok(extract_impl(array, time_unit, timezone, op)) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz( + _: &PrimitiveArray, + _: TimeUnit, + timezone_str: &str, + _: F, +) -> PolarsResult> +where + O: NativeType, + F: Fn(chrono::DateTime) -> O, +{ + panic!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ) +} + +fn extract_impl( + array: &PrimitiveArray, + time_unit: TimeUnit, + timezone: T, + extract: F, +) -> PrimitiveArray +where + T: chrono::TimeZone, + A: NativeType, + F: Fn(chrono::DateTime) -> A, +{ + match time_unit { + TimeUnit::Second => { + let op = |x| { + let datetime = timestamp_s_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Millisecond => { + let op = |x| { + let datetime = timestamp_ms_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Microsecond => { + let op = |x| { + let datetime = timestamp_us_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Nanosecond => { + let op = |x| { + let datetime = timestamp_ns_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + } +} diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs new file mode 100644 index 000000000000..2faf1f12b8b5 --- /dev/null +++ b/crates/polars-arrow/src/compute/utils.rs @@ -0,0 +1,118 @@ +use std::borrow::Borrow; +use std::ops::{BitAnd, BitOr}; + +use polars_error::{PolarsResult, polars_ensure}; + +use crate::array::Array; +use crate::bitmap::{Bitmap, and_not, push_bitchunk, ternary}; + +pub fn combine_validities_and3( + opt1: Option<&Bitmap>, + opt2: Option<&Bitmap>, + opt3: Option<&Bitmap>, +) -> Option { + match (opt1, opt2, opt3) { + (Some(a), Some(b), Some(c)) => Some(ternary(a, b, c, |x, y, z| x & y & z)), + (Some(a), Some(b), None) => Some(a.bitand(b)), + (Some(a), None, Some(c)) => Some(a.bitand(c)), + (None, Some(b), Some(c)) => Some(b.bitand(c)), + (Some(a), None, None) => Some(a.clone()), + (None, Some(b), None) => Some(b.clone()), + (None, None, Some(c)) => Some(c.clone()), + (None, None, None) => None, + } +} + +pub fn combine_validities_and(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> Option { + match (opt_l, opt_r) { + (Some(l), Some(r)) => Some(l.bitand(r)), + (None, Some(r)) => Some(r.clone()), + (Some(l), None) => Some(l.clone()), + (None, None) => None, + } +} + +pub fn combine_validities_or(opt_l: Option<&Bitmap>, opt_r: Option<&Bitmap>) -> Option { + match (opt_l, opt_r) { + (Some(l), Some(r)) => Some(l.bitor(r)), + _ => None, + } +} + +pub fn combine_validities_and_not( + opt_l: Option<&Bitmap>, + opt_r: Option<&Bitmap>, +) -> Option { + match (opt_l, opt_r) { + (Some(l), Some(r)) => Some(and_not(l, r)), + (None, Some(r)) => Some(!r), + (Some(l), None) => Some(l.clone()), + (None, None) => None, + } +} + +pub fn combine_validities_and_many>(bitmaps: &[Option]) -> Option { + let mut bitmaps = bitmaps + .iter() + .flatten() + .map(|b| b.borrow()) + .collect::>(); + + match bitmaps.len() { + 0 => None, + 1 => bitmaps.pop().cloned(), + 2 => combine_validities_and(bitmaps.pop(), bitmaps.pop()), + 3 => combine_validities_and3(bitmaps.pop(), bitmaps.pop(), bitmaps.pop()), + _ => { + let mut iterators = bitmaps + .iter() + .map(|v| v.fast_iter_u64()) + .collect::>(); + let mut buffer = Vec::with_capacity(iterators.first().unwrap().size_hint().0 + 2); + + 'rows: loop { + // All ones so as identity for & operation + let mut out = u64::MAX; + for iter in iterators.iter_mut() { + if let Some(v) = iter.next() { + out &= v + } else { + break 'rows; + } + } + push_bitchunk(&mut buffer, out); + } + + // All ones so as identity for & operation + let mut out = [u64::MAX, u64::MAX]; + let mut len = 0; + for iter in iterators.into_iter() { + let (rem, rem_len) = iter.remainder(); + len = rem_len; + + for (out, rem) in out.iter_mut().zip(rem) { + *out &= rem; + } + } + push_bitchunk(&mut buffer, out[0]); + if len > 64 { + push_bitchunk(&mut buffer, out[1]); + } + let bitmap = Bitmap::from_u8_vec(buffer, bitmaps[0].len()); + if bitmap.unset_bits() == bitmap.len() { + None + } else { + Some(bitmap) + } + }, + } +} + +// Errors iff the two arrays have a different length. +#[inline] +pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> PolarsResult<()> { + polars_ensure!(lhs.len() == rhs.len(), ComputeError: + "arrays must have the same length" + ); + Ok(()) +} diff --git a/crates/polars-arrow/src/datatypes/field.rs b/crates/polars-arrow/src/datatypes/field.rs new file mode 100644 index 000000000000..4957b2b69a5f --- /dev/null +++ b/crates/polars-arrow/src/datatypes/field.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::{ArrowDataType, Metadata}; + +pub static DTYPE_ENUM_VALUES: &str = "_PL_ENUM_VALUES"; +pub static DTYPE_CATEGORICAL: &str = "_PL_CATEGORICAL"; + +/// Represents Arrow's metadata of a "column". +/// +/// A [`Field`] is the closest representation of the traditional "column": a logical type +/// ([`ArrowDataType`]) with a name and nullability. +/// A Field has optional [`Metadata`] that can be used to annotate the field with custom metadata. +/// +/// Almost all IO in this crate uses [`Field`] to represent logical information about the data +/// to be serialized. +#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Field { + /// Its name + pub name: PlSmallStr, + /// Its logical [`ArrowDataType`] + pub dtype: ArrowDataType, + /// Its nullability + pub is_nullable: bool, + /// Additional custom (opaque) metadata. + pub metadata: Option>, +} + +/// Support for `ArrowSchema::from_iter([field, ..])` +impl From for (PlSmallStr, Field) { + fn from(value: Field) -> Self { + (value.name.clone(), value) + } +} + +impl Field { + /// Creates a new [`Field`]. + pub fn new(name: PlSmallStr, dtype: ArrowDataType, is_nullable: bool) -> Self { + Field { + name, + dtype, + is_nullable, + metadata: Default::default(), + } + } + + /// Creates a new [`Field`] with metadata. + #[inline] + pub fn with_metadata(self, metadata: Metadata) -> Self { + if metadata.is_empty() { + return self; + } + Self { + name: self.name, + dtype: self.dtype, + is_nullable: self.is_nullable, + metadata: Some(Arc::new(metadata)), + } + } + + /// Returns the [`Field`]'s [`ArrowDataType`]. + #[inline] + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + pub fn is_enum(&self) -> bool { + if let Some(md) = &self.metadata { + md.get(DTYPE_ENUM_VALUES).is_some() + } else { + false + } + } + + pub fn is_categorical(&self) -> bool { + if let Some(md) = &self.metadata { + md.get(DTYPE_CATEGORICAL).is_some() + } else { + false + } + } +} diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs new file mode 100644 index 000000000000..feb72f974ce2 --- /dev/null +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -0,0 +1,515 @@ +//! Contains all metadata, such as [`PhysicalType`], [`ArrowDataType`], [`Field`] and [`ArrowSchema`]. + +mod field; +mod physical_type; +pub mod reshape; +mod schema; + +use std::collections::BTreeMap; +use std::sync::Arc; + +pub use field::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field}; +pub use physical_type::*; +use polars_utils::pl_str::PlSmallStr; +pub use schema::{ArrowSchema, ArrowSchemaRef}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// typedef for [BTreeMap] denoting [`Field`]'s and [`ArrowSchema`]'s metadata. +pub type Metadata = BTreeMap; +/// typedef for [Option<(PlSmallStr, Option)>] descr +pub(crate) type Extension = Option<(PlSmallStr, Option)>; + +/// The set of supported logical types in this crate. +/// +/// Each variant uniquely identifies a logical type, which define specific semantics to the data +/// (e.g. how it should be represented). +/// Each variant has a corresponding [`PhysicalType`], obtained via [`ArrowDataType::to_physical_type`], +/// which declares the in-memory representation of data. +/// The [`ArrowDataType::Extension`] is special in that it augments a [`ArrowDataType`] with metadata to support custom types. +/// Use `to_logical_type` to desugar such type and return its corresponding logical type. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ArrowDataType { + /// Null type + #[default] + Null, + /// `true` and `false`. + Boolean, + /// An [`i8`] + Int8, + /// An [`i16`] + Int16, + /// An [`i32`] + Int32, + /// An [`i64`] + Int64, + /// An [`i128`] + Int128, + /// An [`u8`] + UInt8, + /// An [`u16`] + UInt16, + /// An [`u32`] + UInt32, + /// An [`u64`] + UInt64, + /// An 16-bit float + Float16, + /// A [`f32`] + Float32, + /// A [`f64`] + Float64, + /// A [`i64`] representing a timestamp measured in [`TimeUnit`] with an optional timezone. + /// + /// Time is measured as a Unix epoch, counting the seconds from + /// 00:00:00.000 on 1 January 1970, excluding leap seconds, + /// as a 64-bit signed integer. + /// + /// The time zone is a string indicating the name of a time zone, one of: + /// + /// * As used in the Olson time zone database (the "tz database" or + /// "tzdata"), such as "America/New_York" + /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + /// + /// When the timezone is not specified, the timestamp is considered to have no timezone + /// and is represented _as is_ + Timestamp(TimeUnit, Option), + /// An [`i32`] representing the elapsed time since UNIX epoch (1970-01-01) + /// in days. + Date32, + /// An [`i64`] representing the elapsed time since UNIX epoch (1970-01-01) + /// in milliseconds. Values are evenly divisible by 86400000. + Date64, + /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + /// Only [`TimeUnit::Second`] and [`TimeUnit::Millisecond`] are supported on this variant. + Time32(TimeUnit), + /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + /// Only [`TimeUnit::Microsecond`] and [`TimeUnit::Nanosecond`] are supported on this variant. + Time64(TimeUnit), + /// Measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.) + Duration(TimeUnit), + /// A "calendar" interval modeling elapsed time that takes into account calendar shifts. + /// For example an interval of 1 day may represent more than 24 hours. + Interval(IntervalUnit), + /// Opaque binary data of variable length whose offsets are represented as [`i32`]. + Binary, + /// Opaque binary data of fixed size. + /// Enum parameter specifies the number of bytes per value. + FixedSizeBinary(usize), + /// Opaque binary data of variable length whose offsets are represented as [`i64`]. + LargeBinary, + /// A variable-length UTF-8 encoded string whose offsets are represented as [`i32`]. + Utf8, + /// A variable-length UTF-8 encoded string whose offsets are represented as [`i64`]. + LargeUtf8, + /// A list of some logical data type whose offsets are represented as [`i32`]. + List(Box), + /// A list of some logical data type with a fixed number of elements. + FixedSizeList(Box, usize), + /// A list of some logical data type whose offsets are represented as [`i64`]. + LargeList(Box), + /// A nested [`ArrowDataType`] with a given number of [`Field`]s. + Struct(Vec), + /// A nested type that is represented as + /// + /// List> + /// + /// In this layout, the keys and values are each respectively contiguous. We do + /// not constrain the key and value types, so the application is responsible + /// for ensuring that the keys are hashable and unique. Whether the keys are sorted + /// may be set in the metadata for this field. + /// + /// In a field with Map type, the field has a child Struct field, which then + /// has two children: key type and the second the value type. The names of the + /// child fields may be respectively "entries", "key", and "value", but this is + /// not enforced. + /// + /// Map + /// ```text + /// - child[0] entries: Struct + /// - child[0] key: K + /// - child[1] value: V + /// ``` + /// Neither the "entries" field nor the "key" field may be nullable. + /// + /// The metadata is structured so that Arrow systems without special handling + /// for Map can make Map an alias for List. The "layout" attribute for the Map + /// field must have the same contents as a List. + /// - Field + /// - ordered + Map(Box, bool), + /// A dictionary encoded array (`key_type`, `value_type`), where + /// each array element is an index of `key_type` into an + /// associated dictionary of `value_type`. + /// + /// Dictionary arrays are used to store columns of `value_type` + /// that contain many repeated values using less memory, but with + /// a higher CPU overhead for some operations. + /// + /// This type mostly used to represent low cardinality string + /// arrays or a limited set of primitive types as integers. + /// + /// The `bool` value indicates the `Dictionary` is sorted if set to `true`. + Dictionary(IntegerType, Box, bool), + /// Decimal value with precision and scale + /// precision is the number of digits in the number and + /// scale is the number of decimal places. + /// The number 999.99 has a precision of 5 and scale of 2. + Decimal(usize, usize), + /// Decimal backed by 256 bits + Decimal256(usize, usize), + /// Extension type. + Extension(Box), + /// A binary type that inlines small values + /// and can intern bytes. + BinaryView, + /// A string type that inlines small values + /// and can intern strings. + Utf8View, + /// A type unknown to Arrow. + Unknown, + /// A nested datatype that can represent slots of differing types. + /// Third argument represents mode + #[cfg_attr(feature = "serde", serde(skip))] + Union(Box), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ExtensionType { + pub name: PlSmallStr, + pub inner: ArrowDataType, + pub metadata: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UnionType { + pub fields: Vec, + pub ids: Option>, + pub mode: UnionMode, +} + +/// Mode of [`ArrowDataType::Union`] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum UnionMode { + /// Dense union + Dense, + /// Sparse union + Sparse, +} + +impl UnionMode { + /// Constructs a [`UnionMode::Sparse`] if the input bool is true, + /// or otherwise constructs a [`UnionMode::Dense`] + pub fn sparse(is_sparse: bool) -> Self { + if is_sparse { Self::Sparse } else { Self::Dense } + } + + /// Returns whether the mode is sparse + pub fn is_sparse(&self) -> bool { + matches!(self, Self::Sparse) + } + + /// Returns whether the mode is dense + pub fn is_dense(&self) -> bool { + matches!(self, Self::Dense) + } +} + +/// The time units defined in Arrow. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum TimeUnit { + /// Time in seconds. + Second, + /// Time in milliseconds. + Millisecond, + /// Time in microseconds. + Microsecond, + /// Time in nanoseconds. + Nanosecond, +} + +/// Interval units defined in Arrow +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum IntervalUnit { + /// The number of elapsed whole months. + YearMonth, + /// The number of elapsed days and milliseconds, + /// stored as 2 contiguous `i32` + DayTime, + /// The number of elapsed months (i32), days (i32) and nanoseconds (i64). + MonthDayNano, +} + +impl ArrowDataType { + /// the [`PhysicalType`] of this [`ArrowDataType`]. + pub fn to_physical_type(&self) -> PhysicalType { + use ArrowDataType::*; + match self { + Null => PhysicalType::Null, + Boolean => PhysicalType::Boolean, + Int8 => PhysicalType::Primitive(PrimitiveType::Int8), + Int16 => PhysicalType::Primitive(PrimitiveType::Int16), + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + PhysicalType::Primitive(PrimitiveType::Int32) + }, + Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => { + PhysicalType::Primitive(PrimitiveType::Int64) + }, + Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128), + Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256), + UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8), + UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16), + UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32), + UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64), + Float16 => PhysicalType::Primitive(PrimitiveType::Float16), + Float32 => PhysicalType::Primitive(PrimitiveType::Float32), + Float64 => PhysicalType::Primitive(PrimitiveType::Float64), + Int128 => PhysicalType::Primitive(PrimitiveType::Int128), + Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs), + Interval(IntervalUnit::MonthDayNano) => { + PhysicalType::Primitive(PrimitiveType::MonthDayNano) + }, + Binary => PhysicalType::Binary, + FixedSizeBinary(_) => PhysicalType::FixedSizeBinary, + LargeBinary => PhysicalType::LargeBinary, + Utf8 => PhysicalType::Utf8, + LargeUtf8 => PhysicalType::LargeUtf8, + BinaryView => PhysicalType::BinaryView, + Utf8View => PhysicalType::Utf8View, + List(_) => PhysicalType::List, + FixedSizeList(_, _) => PhysicalType::FixedSizeList, + LargeList(_) => PhysicalType::LargeList, + Struct(_) => PhysicalType::Struct, + Union(_) => PhysicalType::Union, + Map(_, _) => PhysicalType::Map, + Dictionary(key, _, _) => PhysicalType::Dictionary(*key), + Extension(ext) => ext.inner.to_physical_type(), + Unknown => unimplemented!(), + } + } + + // The datatype underlying this (possibly logical) arrow data type. + pub fn underlying_physical_type(&self) -> ArrowDataType { + use ArrowDataType::*; + match self { + Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => Int32, + Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(IntervalUnit::DayTime) => Int64, + Interval(IntervalUnit::MonthDayNano) => unimplemented!(), + Binary => Binary, + List(field) => List(Box::new(Field { + dtype: field.dtype.underlying_physical_type(), + ..*field.clone() + })), + LargeList(field) => LargeList(Box::new(Field { + dtype: field.dtype.underlying_physical_type(), + ..*field.clone() + })), + FixedSizeList(field, width) => FixedSizeList( + Box::new(Field { + dtype: field.dtype.underlying_physical_type(), + ..*field.clone() + }), + *width, + ), + Struct(fields) => Struct( + fields + .iter() + .map(|field| Field { + dtype: field.dtype.underlying_physical_type(), + ..field.clone() + }) + .collect(), + ), + Dictionary(keys, _, _) => (*keys).into(), + Union(_) => unimplemented!(), + Map(_, _) => unimplemented!(), + Extension(ext) => ext.inner.underlying_physical_type(), + _ => self.clone(), + } + } + + /// Returns `&self` for all but [`ArrowDataType::Extension`]. For [`ArrowDataType::Extension`], + /// (recursively) returns the inner [`ArrowDataType`]. + /// Never returns the variant [`ArrowDataType::Extension`]. + pub fn to_logical_type(&self) -> &ArrowDataType { + use ArrowDataType::*; + match self { + Extension(ext) => ext.inner.to_logical_type(), + _ => self, + } + } + + pub fn inner_dtype(&self) -> Option<&ArrowDataType> { + match self { + ArrowDataType::List(inner) => Some(inner.dtype()), + ArrowDataType::LargeList(inner) => Some(inner.dtype()), + ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()), + _ => None, + } + } + + pub fn is_nested(&self) -> bool { + use ArrowDataType as D; + + matches!( + self, + D::List(_) + | D::LargeList(_) + | D::FixedSizeList(_, _) + | D::Struct(_) + | D::Union(_) + | D::Map(_, _) + | D::Dictionary(_, _, _) + | D::Extension(_) + ) + } + + pub fn is_view(&self) -> bool { + matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) + } + + pub fn is_numeric(&self) -> bool { + use ArrowDataType as D; + matches!( + self, + D::Int8 + | D::Int16 + | D::Int32 + | D::Int64 + | D::Int128 + | D::UInt8 + | D::UInt16 + | D::UInt32 + | D::UInt64 + | D::Float32 + | D::Float64 + | D::Decimal(_, _) + | D::Decimal256(_, _) + ) + } + + pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType { + ArrowDataType::FixedSizeList( + Box::new(Field::new( + PlSmallStr::from_static("item"), + self, + is_nullable, + )), + size, + ) + } + + /// Check (recursively) whether datatype contains an [`ArrowDataType::Dictionary`] type. + pub fn contains_dictionary(&self) -> bool { + use ArrowDataType as D; + match self { + D::Null + | D::Boolean + | D::Int8 + | D::Int16 + | D::Int32 + | D::Int64 + | D::UInt8 + | D::UInt16 + | D::UInt32 + | D::UInt64 + | D::Int128 + | D::Float16 + | D::Float32 + | D::Float64 + | D::Timestamp(_, _) + | D::Date32 + | D::Date64 + | D::Time32(_) + | D::Time64(_) + | D::Duration(_) + | D::Interval(_) + | D::Binary + | D::FixedSizeBinary(_) + | D::LargeBinary + | D::Utf8 + | D::LargeUtf8 + | D::Decimal(_, _) + | D::Decimal256(_, _) + | D::BinaryView + | D::Utf8View + | D::Unknown => false, + D::List(field) + | D::FixedSizeList(field, _) + | D::Map(field, _) + | D::LargeList(field) => field.dtype().contains_dictionary(), + D::Struct(fields) => fields.iter().any(|f| f.dtype().contains_dictionary()), + D::Union(union) => union.fields.iter().any(|f| f.dtype().contains_dictionary()), + D::Dictionary(_, _, _) => true, + D::Extension(ext) => ext.inner.contains_dictionary(), + } + } +} + +impl From for ArrowDataType { + fn from(item: IntegerType) -> Self { + match item { + IntegerType::Int8 => ArrowDataType::Int8, + IntegerType::Int16 => ArrowDataType::Int16, + IntegerType::Int32 => ArrowDataType::Int32, + IntegerType::Int64 => ArrowDataType::Int64, + IntegerType::Int128 => ArrowDataType::Int128, + IntegerType::UInt8 => ArrowDataType::UInt8, + IntegerType::UInt16 => ArrowDataType::UInt16, + IntegerType::UInt32 => ArrowDataType::UInt32, + IntegerType::UInt64 => ArrowDataType::UInt64, + } + } +} + +impl From for ArrowDataType { + fn from(item: PrimitiveType) -> Self { + match item { + PrimitiveType::Int8 => ArrowDataType::Int8, + PrimitiveType::Int16 => ArrowDataType::Int16, + PrimitiveType::Int32 => ArrowDataType::Int32, + PrimitiveType::Int64 => ArrowDataType::Int64, + PrimitiveType::UInt8 => ArrowDataType::UInt8, + PrimitiveType::UInt16 => ArrowDataType::UInt16, + PrimitiveType::UInt32 => ArrowDataType::UInt32, + PrimitiveType::UInt64 => ArrowDataType::UInt64, + PrimitiveType::Int128 => ArrowDataType::Int128, + PrimitiveType::Int256 => ArrowDataType::Decimal256(32, 32), + PrimitiveType::Float16 => ArrowDataType::Float16, + PrimitiveType::Float32 => ArrowDataType::Float32, + PrimitiveType::Float64 => ArrowDataType::Float64, + PrimitiveType::DaysMs => ArrowDataType::Interval(IntervalUnit::DayTime), + PrimitiveType::MonthDayNano => ArrowDataType::Interval(IntervalUnit::MonthDayNano), + PrimitiveType::UInt128 => unimplemented!(), + } + } +} + +/// typedef for [`Arc`]. +pub type SchemaRef = Arc; + +/// support get extension for metadata +pub fn get_extension(metadata: &Metadata) -> Extension { + if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) { + let metadata = metadata + .get(&PlSmallStr::from_static("ARROW:extension:metadata")) + .cloned(); + Some((name.clone(), metadata)) + } else { + None + } +} + +#[cfg(not(feature = "bigidx"))] +pub type IdxArr = super::array::UInt32Array; +#[cfg(feature = "bigidx")] +pub type IdxArr = super::array::UInt64Array; diff --git a/crates/polars-arrow/src/datatypes/physical_type.rs b/crates/polars-arrow/src/datatypes/physical_type.rs new file mode 100644 index 000000000000..f75d8e644f4c --- /dev/null +++ b/crates/polars-arrow/src/datatypes/physical_type.rs @@ -0,0 +1,89 @@ +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +pub use crate::types::PrimitiveType; + +/// The set of physical types: unique in-memory representations of an Arrow array. +/// +/// A physical type has a one-to-many relationship with a [`crate::datatypes::ArrowDataType`] and +/// a one-to-one mapping to each struct in this crate that implements [`crate::array::Array`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum PhysicalType { + /// A Null with no allocation. + Null, + /// A boolean represented as a single bit. + Boolean, + /// An array where each slot has a known compile-time size. + Primitive(PrimitiveType), + /// Opaque binary data of variable length. + Binary, + /// Opaque binary data of fixed size. + FixedSizeBinary, + /// Opaque binary data of variable length and 64-bit offsets. + LargeBinary, + /// A variable-length string in Unicode with UTF-8 encoding. + Utf8, + /// A variable-length string in Unicode with UFT-8 encoding and 64-bit offsets. + LargeUtf8, + /// A list of some data type with variable length. + List, + /// A list of some data type with fixed length. + FixedSizeList, + /// A list of some data type with variable length and 64-bit offsets. + LargeList, + /// A nested type that contains an arbitrary number of fields. + Struct, + /// A nested type that represents slots of differing types. + Union, + /// A nested type. + Map, + /// A dictionary encoded array by `IntegerType`. + Dictionary(IntegerType), + /// A binary type that inlines small values + /// and can intern bytes. + BinaryView, + /// A string type that inlines small values + /// and can intern strings. + Utf8View, +} + +impl PhysicalType { + /// Whether this physical type equals [`PhysicalType::Primitive`] of type `primitive`. + pub fn eq_primitive(&self, primitive: PrimitiveType) -> bool { + if let Self::Primitive(o) = self { + o == &primitive + } else { + false + } + } + + pub fn is_primitive(&self) -> bool { + matches!(self, Self::Primitive(_)) + } +} + +/// the set of valid indices types of a dictionary-encoded Array. +/// Each type corresponds to a variant of [`crate::array::DictionaryArray`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum IntegerType { + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// A signed 128-bit integer. + Int128, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, +} diff --git a/crates/polars-arrow/src/datatypes/reshape.rs b/crates/polars-arrow/src/datatypes/reshape.rs new file mode 100644 index 000000000000..7fa5e3dcfa7f --- /dev/null +++ b/crates/polars-arrow/src/datatypes/reshape.rs @@ -0,0 +1,118 @@ +use std::fmt; +use std::hash::Hash; +use std::num::NonZeroU64; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[repr(transparent)] +pub struct Dimension(NonZeroU64); + +/// A dimension in a reshape. +/// +/// Any dimension smaller than 0 is seen as an `infer`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ReshapeDimension { + Infer, + Specified(Dimension), +} + +impl fmt::Debug for Dimension { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl fmt::Display for ReshapeDimension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Infer => f.write_str("inferred"), + Self::Specified(v) => v.get().fmt(f), + } + } +} + +impl Hash for ReshapeDimension { + fn hash(&self, state: &mut H) { + self.to_repr().hash(state) + } +} + +impl Dimension { + #[inline] + pub const fn new(v: u64) -> Self { + assert!(v <= i64::MAX as u64); + + // SAFETY: Bounds check done before + let dim = unsafe { NonZeroU64::new_unchecked(v.wrapping_add(1)) }; + Self(dim) + } + + #[inline] + pub const fn get(self) -> u64 { + self.0.get() - 1 + } +} + +impl ReshapeDimension { + #[inline] + pub const fn new(v: i64) -> Self { + if v < 0 { + Self::Infer + } else { + // SAFETY: We have bounds checked for -1 + let dim = unsafe { NonZeroU64::new_unchecked((v as u64).wrapping_add(1)) }; + Self::Specified(Dimension(dim)) + } + } + + #[inline] + fn to_repr(self) -> u64 { + match self { + Self::Infer => 0, + Self::Specified(dim) => dim.0.get(), + } + } + + #[inline] + pub const fn get(self) -> Option { + match self { + ReshapeDimension::Infer => None, + ReshapeDimension::Specified(dim) => Some(dim.get()), + } + } + + #[inline] + pub const fn get_or_infer(self, inferred: u64) -> u64 { + match self { + ReshapeDimension::Infer => inferred, + ReshapeDimension::Specified(dim) => dim.get(), + } + } + + #[inline] + pub fn get_or_infer_with(self, f: impl Fn() -> u64) -> u64 { + match self { + ReshapeDimension::Infer => f(), + ReshapeDimension::Specified(dim) => dim.get(), + } + } + + pub const fn new_dimension(dimension: u64) -> ReshapeDimension { + Self::Specified(Dimension::new(dimension)) + } +} + +impl TryFrom for Dimension { + type Error = (); + + #[inline] + fn try_from(value: i64) -> Result { + let ReshapeDimension::Specified(v) = ReshapeDimension::new(value) else { + return Err(()); + }; + + Ok(v) + } +} diff --git a/crates/polars-arrow/src/datatypes/schema.rs b/crates/polars-arrow/src/datatypes/schema.rs new file mode 100644 index 000000000000..02920204b4dc --- /dev/null +++ b/crates/polars-arrow/src/datatypes/schema.rs @@ -0,0 +1,11 @@ +use std::sync::Arc; + +use super::Field; + +/// An ordered sequence of [`Field`]s +/// +/// [`ArrowSchema`] is an abstraction used to read from, and write to, Arrow IPC format, +/// Apache Parquet, and Apache Avro. All these formats have a concept of a schema +/// with fields and metadata. +pub type ArrowSchema = polars_schema::Schema; +pub type ArrowSchemaRef = Arc; diff --git a/crates/polars-arrow/src/doc/lib.md b/crates/polars-arrow/src/doc/lib.md new file mode 100644 index 000000000000..6d22c121890d --- /dev/null +++ b/crates/polars-arrow/src/doc/lib.md @@ -0,0 +1,80 @@ +Welcome to polars_arrow's documentation. Thanks for checking it out! + +This is a library for efficient in-memory data operations with +[Arrow in-memory format](https://arrow.apache.org/docs/format/Columnar.html). It is a re-write from +the bottom up of the official `arrow` crate with soundness and type safety in mind. + +Check out [the guide](https://jorgecarleitao.github.io/polars_arrow/main/guide/) for an +introduction. Below is an example of some of the things you can do with it: + +```rust +use std::sync::Arc; + +use polars_arrow::array::*; +use polars_arrow::datatypes::{Field, DataType, Schema}; +use polars_arrow::compute::arithmetics; +use polars_arrow::error::Result; +use polars_arrow::io::parquet::write::*; +use polars_arrow::chunk::Chunk; + +fn main() -> Result<()> { + // declare arrays + let a = Int32Array::from(&[Some(1), None, Some(3)]); + let b = Int32Array::from(&[Some(2), None, Some(6)]); + + // compute (probably the fastest implementation of a nullable op you can find out there) + let c = arithmetics::basic::mul_scalar(&a, &2); + assert_eq!(c, b); + + // declare a schema with fields + let schema = Schema::from(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]); + + // declare chunk + let chunk = Chunk::new(vec![a.arced(), b.arced()]); + + // write to parquet (probably the fastest implementation of writing to parquet out there) + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Snappy, + version: Version::V1, + data_page_size: None, + }; + + let row_groups = RowGroupIterator::try_new( + vec![Ok(chunk)].into_iter(), + &schema, + options, + vec![vec![Encoding::Plain], vec![Encoding::Plain]], + )?; + + // anything implementing `std::io::Write` works + let mut file = vec![]; + + let mut writer = FileWriter::try_new(file, schema, options)?; + + // Write the file. + for group in row_groups { + writer.write(group?)?; + } + let _ = writer.end(None)?; + Ok(()) +} +``` + +## Cargo features + +This crate has a significant number of cargo features to reduce compilation time and number of +dependencies. The feature `"full"` activates most functionality, such as: + +- `io_ipc`: to interact with the Arrow IPC format +- `io_ipc_compression`: to read and write compressed Arrow IPC (v2) +- `io_flight` to read and write to Arrow's Flight protocol +- `compute` to operate on arrays (addition, sum, sort, etc.) + +The feature `simd` (not part of `full`) produces more explicit SIMD instructions via +[`std::simd`](https://doc.rust-lang.org/nightly/std/simd/index.html), but requires the nightly +channel. diff --git a/crates/polars-arrow/src/ffi/array.rs b/crates/polars-arrow/src/ffi/array.rs new file mode 100644 index 000000000000..563c0a3809b4 --- /dev/null +++ b/crates/polars-arrow/src/ffi/array.rs @@ -0,0 +1,649 @@ +//! Contains functionality to load an ArrayData from the C Data Interface +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::ArrowArray; +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::bitmap::utils::bytes_for; +use crate::buffer::Buffer; +use crate::datatypes::{ArrowDataType, PhysicalType}; +use crate::ffi::schema::get_child; +use crate::storage::SharedStorage; +use crate::types::NativeType; +use crate::{ffi, match_integer_type, with_match_primitive_type_full}; + +/// Reads a valid `ffi` interface into a `Box` +/// # Errors +/// If and only if: +/// * the interface is not valid (e.g. a null pointer) +pub unsafe fn try_from(array: A) -> PolarsResult> { + use PhysicalType::*; + Ok(match array.dtype().to_physical_type() { + Null => Box::new(NullArray::try_from_ffi(array)?), + Boolean => Box::new(BooleanArray::try_from_ffi(array)?), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::try_from_ffi(array)?) + }), + Utf8 => Box::new(Utf8Array::::try_from_ffi(array)?), + LargeUtf8 => Box::new(Utf8Array::::try_from_ffi(array)?), + Binary => Box::new(BinaryArray::::try_from_ffi(array)?), + LargeBinary => Box::new(BinaryArray::::try_from_ffi(array)?), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::try_from_ffi(array)?), + List => Box::new(ListArray::::try_from_ffi(array)?), + LargeList => Box::new(ListArray::::try_from_ffi(array)?), + FixedSizeList => Box::new(FixedSizeListArray::try_from_ffi(array)?), + Struct => Box::new(StructArray::try_from_ffi(array)?), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::try_from_ffi(array)?) + }) + }, + Union => Box::new(UnionArray::try_from_ffi(array)?), + Map => Box::new(MapArray::try_from_ffi(array)?), + BinaryView => Box::new(BinaryViewArray::try_from_ffi(array)?), + Utf8View => Box::new(Utf8ViewArray::try_from_ffi(array)?), + }) +} + +// Sound because the arrow specification does not allow multiple implementations +// to change this struct +// This is intrinsically impossible to prove because the implementations agree +// on this as part of the Arrow specification +unsafe impl Send for ArrowArray {} +unsafe impl Sync for ArrowArray {} + +impl Drop for ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +// callback used to drop [ArrowArray] when it is exported +unsafe extern "C" fn c_release_array(array: *mut ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it + let private = Box::from_raw(array.private_data as *mut PrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + + array.release = None; +} + +#[allow(dead_code)] +struct PrivateData { + array: Box, + buffers_ptr: Box<[*const std::os::raw::c_void]>, + children_ptr: Box<[*mut ArrowArray]>, + dictionary_ptr: Option<*mut ArrowArray>, + variadic_buffer_sizes: Box<[i64]>, +} + +impl ArrowArray { + /// creates a new `ArrowArray` from existing data. + /// + /// # Safety + /// This method releases `buffers`. Consumers of this struct *must* call `release` before + /// releasing this struct, or contents in `buffers` leak. + pub(crate) fn new(array: Box) -> Self { + #[allow(unused_mut)] + let (offset, mut buffers, children, dictionary) = + offset_buffers_children_dictionary(array.as_ref()); + + let variadic_buffer_sizes = match array.dtype() { + ArrowDataType::BinaryView => { + let arr = array.as_any().downcast_ref::().unwrap(); + let boxed = arr.variadic_buffer_lengths().into_boxed_slice(); + let ptr = boxed.as_ptr().cast::(); + buffers.push(Some(ptr)); + boxed + }, + ArrowDataType::Utf8View => { + let arr = array.as_any().downcast_ref::().unwrap(); + let boxed = arr.variadic_buffer_lengths().into_boxed_slice(); + let ptr = boxed.as_ptr().cast::(); + buffers.push(Some(ptr)); + boxed + }, + _ => Box::new([]), + }; + + let buffers_ptr = buffers + .iter() + .map(|maybe_buffer| match maybe_buffer { + Some(b) => *b as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let n_buffers = buffers.len() as i64; + + let children_ptr = children + .into_iter() + .map(|child| { + Box::into_raw(Box::new(ArrowArray::new(ffi::align_to_c_data_interface( + child, + )))) + }) + .collect::>(); + let n_children = children_ptr.len() as i64; + + let dictionary_ptr = dictionary.map(|array| { + Box::into_raw(Box::new(ArrowArray::new(ffi::align_to_c_data_interface( + array, + )))) + }); + + let length = array.len() as i64; + let null_count = array.null_count() as i64; + + let mut private_data = Box::new(PrivateData { + array, + buffers_ptr, + children_ptr, + dictionary_ptr, + variadic_buffer_sizes, + }); + + Self { + length, + null_count, + offset: offset as i64, + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children_ptr.as_mut_ptr(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), + release: Some(c_release_array), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } + } + + /// creates an empty [`ArrowArray`], which can be used to import data into + pub fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// the length of the array + pub(crate) fn len(&self) -> usize { + self.length as usize + } + + /// the offset of the array + pub(crate) fn offset(&self) -> usize { + self.offset as usize + } + + /// the null count of the array + pub(crate) fn null_count(&self) -> usize { + self.null_count as usize + } +} + +/// # Safety +/// The caller must ensure that the buffer at index `i` is not mutably shared. +unsafe fn get_buffer_ptr( + array: &ArrowArray, + dtype: &ArrowDataType, + index: usize, +) -> PolarsResult<*mut T> { + if array.buffers.is_null() { + polars_bail!( ComputeError: + "an ArrowArray of type {dtype:?} must have non-null buffers" + ); + } + + if array.buffers.align_offset(align_of::<*mut *const u8>()) != 0 { + polars_bail!( ComputeError: + "an ArrowArray of type {dtype:?} + must have buffer {index} aligned to type {}", + std::any::type_name::<*mut *const u8>() + ); + } + let buffers = array.buffers as *mut *const u8; + + if index >= array.n_buffers as usize { + polars_bail!(ComputeError: + "An ArrowArray of type {dtype:?} + must have buffer {index}." + ) + } + + let ptr = *buffers.add(index); + if ptr.is_null() { + polars_bail!(ComputeError: + "An array of type {dtype:?} + must have a non-null buffer {index}" + ) + } + + // note: we can't prove that this pointer is not mutably shared - part of the safety invariant + Ok(ptr as *mut T) +} + +unsafe fn create_buffer_known_len( + array: &ArrowArray, + dtype: &ArrowDataType, + owner: InternalArrowArray, + len: usize, + index: usize, +) -> PolarsResult> { + if len == 0 { + return Ok(Buffer::new()); + } + let ptr: *mut T = get_buffer_ptr(array, dtype, index)?; + let storage = SharedStorage::from_internal_arrow_array(ptr, len, owner); + Ok(Buffer::from_storage(storage)) +} + +/// returns the buffer `i` of `array` interpreted as a [`Buffer`]. +/// # Safety +/// This function is safe iff: +/// * the buffers up to position `index` are valid for the declared length +/// * the buffers' pointers are not mutably shared for the lifetime of `owner` +unsafe fn create_buffer( + array: &ArrowArray, + dtype: &ArrowDataType, + owner: InternalArrowArray, + index: usize, +) -> PolarsResult> { + let len = buffer_len(array, dtype, index)?; + + if len == 0 { + return Ok(Buffer::new()); + } + + let offset = buffer_offset(array, dtype, index); + let ptr: *mut T = get_buffer_ptr(array, dtype, index)?; + + // We have to check alignment. + // This is the zero-copy path. + if ptr.align_offset(align_of::()) == 0 { + let storage = SharedStorage::from_internal_arrow_array(ptr, len, owner); + Ok(Buffer::from_storage(storage).sliced(offset, len - offset)) + } + // This is the path where alignment isn't correct. + // We copy the data to a new vec + else { + let buf = std::slice::from_raw_parts(ptr, len - offset).to_vec(); + Ok(Buffer::from(buf)) + } +} + +/// returns the buffer `i` of `array` interpreted as a [`Bitmap`]. +/// # Safety +/// This function is safe iff: +/// * the buffer at position `index` is valid for the declared length +/// * the buffers' pointer is not mutable for the lifetime of `owner` +unsafe fn create_bitmap( + array: &ArrowArray, + dtype: &ArrowDataType, + owner: InternalArrowArray, + index: usize, + // if this is the validity bitmap + // we can use the null count directly + is_validity: bool, +) -> PolarsResult { + let len: usize = array.length.try_into().expect("length to fit in `usize`"); + if len == 0 { + return Ok(Bitmap::new()); + } + let ptr = get_buffer_ptr(array, dtype, index)?; + + // Pointer of u8 has alignment 1, so we don't have to check alignment. + + let offset: usize = array.offset.try_into().expect("offset to fit in `usize`"); + let bytes_len = bytes_for(offset + len); + let storage = SharedStorage::from_internal_arrow_array(ptr, bytes_len, owner); + + let null_count = if is_validity { + Some(array.null_count()) + } else { + None + }; + Ok(Bitmap::from_inner_unchecked( + storage, offset, len, null_count, + )) +} + +fn buffer_offset(array: &ArrowArray, dtype: &ArrowDataType, i: usize) -> usize { + use PhysicalType::*; + match (dtype.to_physical_type(), i) { + (LargeUtf8, 2) | (LargeBinary, 2) | (Utf8, 2) | (Binary, 2) => 0, + (FixedSizeBinary, 1) => { + if let ArrowDataType::FixedSizeBinary(size) = dtype.to_logical_type() { + let offset: usize = array.offset.try_into().expect("Offset to fit in `usize`"); + offset * *size + } else { + unreachable!() + } + }, + _ => array.offset.try_into().expect("Offset to fit in `usize`"), + } +} + +/// Returns the length, in slots, of the buffer `i` (indexed according to the C data interface) +unsafe fn buffer_len(array: &ArrowArray, dtype: &ArrowDataType, i: usize) -> PolarsResult { + Ok(match (dtype.to_physical_type(), i) { + (PhysicalType::FixedSizeBinary, 1) => { + if let ArrowDataType::FixedSizeBinary(size) = dtype.to_logical_type() { + *size * (array.offset as usize + array.length as usize) + } else { + unreachable!() + } + }, + (PhysicalType::FixedSizeList, 1) => { + if let ArrowDataType::FixedSizeList(_, size) = dtype.to_logical_type() { + *size * (array.offset as usize + array.length as usize) + } else { + unreachable!() + } + }, + (PhysicalType::Utf8, 1) + | (PhysicalType::LargeUtf8, 1) + | (PhysicalType::Binary, 1) + | (PhysicalType::LargeBinary, 1) + | (PhysicalType::List, 1) + | (PhysicalType::LargeList, 1) + | (PhysicalType::Map, 1) => { + // the len of the offset buffer (buffer 1) equals length + 1 + array.offset as usize + array.length as usize + 1 + }, + (PhysicalType::BinaryView, 1) | (PhysicalType::Utf8View, 1) => { + array.offset as usize + array.length as usize + }, + (PhysicalType::Utf8, 2) | (PhysicalType::Binary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = buffer_len(array, dtype, 1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; + // interpret as i32 + let offset_buffer = offset_buffer as *const i32; + // get last offset + + (unsafe { *offset_buffer.add(len - 1) }) as usize + }, + (PhysicalType::LargeUtf8, 2) | (PhysicalType::LargeBinary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = buffer_len(array, dtype, 1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; + // interpret as i64 + let offset_buffer = offset_buffer as *const i64; + // get last offset + (unsafe { *offset_buffer.add(len - 1) }) as usize + }, + // buffer len of primitive types + _ => array.offset as usize + array.length as usize, + }) +} + +/// # Safety +/// +/// This function is safe iff: +/// * `array.children` at `index` is valid +/// * `array.children` is not mutably shared for the lifetime of `parent` +/// * the pointer of `array.children` at `index` is valid +/// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` +unsafe fn create_child( + array: &ArrowArray, + dtype: &ArrowDataType, + parent: InternalArrowArray, + index: usize, +) -> PolarsResult> { + let dtype = get_child(dtype, index)?; + + // catch what we can + if array.children.is_null() { + polars_bail!(ComputeError: "an ArrowArray of type {dtype:?} must have non-null children"); + } + + if index >= array.n_children as usize { + polars_bail!(ComputeError: + "an ArrowArray of type {dtype:?} + must have child {index}." + ); + } + + // SAFETY: part of the invariant + let arr_ptr = unsafe { *array.children.add(index) }; + + // catch what we can + if arr_ptr.is_null() { + polars_bail!(ComputeError: + "an array of type {dtype:?} + must have a non-null child {index}" + ) + } + + // SAFETY: invariant of this function + let arr_ptr = unsafe { &*arr_ptr }; + Ok(ArrowArrayChild::new(arr_ptr, dtype, parent)) +} + +/// # Safety +/// +/// This function is safe iff: +/// * `array.dictionary` is valid +/// * `array.dictionary` is not mutably shared for the lifetime of `parent` +unsafe fn create_dictionary( + array: &ArrowArray, + dtype: &ArrowDataType, + parent: InternalArrowArray, +) -> PolarsResult>> { + if let ArrowDataType::Dictionary(_, values, _) = dtype { + let dtype = values.as_ref().clone(); + // catch what we can + if array.dictionary.is_null() { + polars_bail!(ComputeError: + "an array of type {dtype:?} + must have a non-null dictionary" + ) + } + + // SAFETY: part of the invariant + let array = unsafe { &*array.dictionary }; + Ok(Some(ArrowArrayChild::new(array, dtype, parent))) + } else { + Ok(None) + } +} + +pub trait ArrowArrayRef: std::fmt::Debug { + fn owner(&self) -> InternalArrowArray { + (*self.parent()).clone() + } + + /// returns the null bit buffer. + /// Rust implementation uses a buffer that is not part of the array of buffers. + /// The C Data interface's null buffer is part of the array of buffers. + /// + /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a bitmap. + /// This function assumes that the bitmap created from FFI is valid; this is impossible to prove. + unsafe fn validity(&self) -> PolarsResult> { + if self.array().null_count() == 0 { + Ok(None) + } else { + create_bitmap(self.array(), self.dtype(), self.owner(), 0, true).map(Some) + } + } + + /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a buffer. + /// This function assumes that the buffer created from FFI is valid; this is impossible to prove. + unsafe fn buffer(&self, index: usize) -> PolarsResult> { + create_buffer::(self.array(), self.dtype(), self.owner(), index) + } + + /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a buffer. + /// This function assumes that the buffer created from FFI is valid; this is impossible to prove. + unsafe fn buffer_known_len( + &self, + index: usize, + len: usize, + ) -> PolarsResult> { + create_buffer_known_len::(self.array(), self.dtype(), self.owner(), len, index) + } + + /// # Safety + /// This function is safe iff: + /// * the buffer at position `index` is valid for the declared length + /// * the buffers' pointer is not mutable for the lifetime of `owner` + unsafe fn bitmap(&self, index: usize) -> PolarsResult { + create_bitmap(self.array(), self.dtype(), self.owner(), index, false) + } + + /// # Safety + /// * `array.children` at `index` is valid + /// * `array.children` is not mutably shared for the lifetime of `parent` + /// * the pointer of `array.children` at `index` is valid + /// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` + unsafe fn child(&self, index: usize) -> PolarsResult { + create_child(self.array(), self.dtype(), self.parent().clone(), index) + } + + unsafe fn dictionary(&self) -> PolarsResult> { + create_dictionary(self.array(), self.dtype(), self.parent().clone()) + } + + fn n_buffers(&self) -> usize; + + fn offset(&self) -> usize; + fn length(&self) -> usize; + + fn parent(&self) -> &InternalArrowArray; + fn array(&self) -> &ArrowArray; + fn dtype(&self) -> &ArrowDataType; +} + +/// Struct used to move an Array from and to the C Data Interface. +/// Its main responsibility is to expose functionality that requires +/// both [ArrowArray] and [ArrowSchema]. +/// +/// This struct has two main paths: +/// +/// ## Import from the C Data Interface +/// * [InternalArrowArray::empty] to allocate memory to be filled by an external call +/// * [InternalArrowArray::try_from_raw] to consume two non-null allocated pointers +/// ## Export to the C Data Interface +/// * [InternalArrowArray::try_new] to create a new [InternalArrowArray] from Rust-specific information +/// * [InternalArrowArray::into_raw] to expose two pointers for [ArrowArray] and [ArrowSchema]. +/// +/// # Safety +/// Whoever creates this struct is responsible for releasing their resources. Specifically, +/// consumers *must* call [InternalArrowArray::into_raw] and take ownership of the individual pointers, +/// calling [ArrowArray::release] and [ArrowSchema::release] accordingly. +/// +/// Furthermore, this struct assumes that the incoming data agrees with the C data interface. +#[derive(Debug, Clone)] +pub struct InternalArrowArray { + // Arc is used for sharability since this is immutable + array: Arc, + // Arced to reduce cost of cloning + dtype: Arc, +} + +impl InternalArrowArray { + pub fn new(array: ArrowArray, dtype: ArrowDataType) -> Self { + Self { + array: Arc::new(array), + dtype: Arc::new(dtype), + } + } +} + +impl ArrowArrayRef for InternalArrowArray { + /// the dtype as declared in the schema + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn parent(&self) -> &InternalArrowArray { + self + } + + fn array(&self) -> &ArrowArray { + self.array.as_ref() + } + + fn n_buffers(&self) -> usize { + self.array.n_buffers as usize + } + + fn offset(&self) -> usize { + self.array.offset as usize + } + + fn length(&self) -> usize { + self.array.length as usize + } +} + +#[derive(Debug)] +pub struct ArrowArrayChild<'a> { + array: &'a ArrowArray, + dtype: ArrowDataType, + parent: InternalArrowArray, +} + +impl ArrowArrayRef for ArrowArrayChild<'_> { + /// the dtype as declared in the schema + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn parent(&self) -> &InternalArrowArray { + &self.parent + } + + fn array(&self) -> &ArrowArray { + self.array + } + + fn n_buffers(&self) -> usize { + self.array.n_buffers as usize + } + + fn offset(&self) -> usize { + self.array.offset as usize + } + + fn length(&self) -> usize { + self.array.length as usize + } +} + +impl<'a> ArrowArrayChild<'a> { + fn new(array: &'a ArrowArray, dtype: ArrowDataType, parent: InternalArrowArray) -> Self { + Self { + array, + dtype, + parent, + } + } +} diff --git a/crates/polars-arrow/src/ffi/bridge.rs b/crates/polars-arrow/src/ffi/bridge.rs new file mode 100644 index 000000000000..c23c21643214 --- /dev/null +++ b/crates/polars-arrow/src/ffi/bridge.rs @@ -0,0 +1,42 @@ +use crate::array::*; +use crate::{match_integer_type, with_match_primitive_type_full}; + +macro_rules! ffi_dyn { + ($array:expr, $ty:ty) => {{ + let a = $array.as_any().downcast_ref::<$ty>().unwrap(); + if a.offset().is_some() { + $array + } else { + Box::new(a.to_ffi_aligned()) + } + }}; +} + +pub fn align_to_c_data_interface(array: Box) -> Box { + use crate::datatypes::PhysicalType::*; + match array.dtype().to_physical_type() { + Null => ffi_dyn!(array, NullArray), + Boolean => ffi_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + ffi_dyn!(array, PrimitiveArray<$T>) + }), + Binary => ffi_dyn!(array, BinaryArray), + LargeBinary => ffi_dyn!(array, BinaryArray), + FixedSizeBinary => ffi_dyn!(array, FixedSizeBinaryArray), + Utf8 => ffi_dyn!(array, Utf8Array::), + LargeUtf8 => ffi_dyn!(array, Utf8Array::), + List => ffi_dyn!(array, ListArray::), + LargeList => ffi_dyn!(array, ListArray::), + FixedSizeList => ffi_dyn!(array, FixedSizeListArray), + Struct => ffi_dyn!(array, StructArray), + Union => ffi_dyn!(array, UnionArray), + Map => ffi_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + ffi_dyn!(array, DictionaryArray<$T>) + }) + }, + BinaryView => ffi_dyn!(array, BinaryViewArray), + Utf8View => ffi_dyn!(array, Utf8ViewArray), + } +} diff --git a/crates/polars-arrow/src/ffi/generated.rs b/crates/polars-arrow/src/ffi/generated.rs new file mode 100644 index 000000000000..cd4953b7198a --- /dev/null +++ b/crates/polars-arrow/src/ffi/generated.rs @@ -0,0 +1,55 @@ +/* automatically generated by rust-bindgen 0.59.2 */ + +/// ABI-compatible struct for [`ArrowSchema`](https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions) +#[repr(C)] +#[derive(Debug)] +pub struct ArrowSchema { + pub(super) format: *const ::std::os::raw::c_char, + pub(super) name: *const ::std::os::raw::c_char, + pub(super) metadata: *const ::std::os::raw::c_char, + pub(super) flags: i64, + pub(super) n_children: i64, + pub(super) children: *mut *mut ArrowSchema, + pub(super) dictionary: *mut ArrowSchema, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} + +/// ABI-compatible struct for [`ArrowArray`](https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions) +#[repr(C)] +#[derive(Debug)] +pub struct ArrowArray { + pub(super) length: i64, + pub(super) null_count: i64, + pub(super) offset: i64, + pub(super) n_buffers: i64, + pub(super) n_children: i64, + pub(super) buffers: *mut *const ::std::os::raw::c_void, + pub(super) children: *mut *mut ArrowArray, + pub(super) dictionary: *mut ArrowArray, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} + +/// ABI-compatible struct for [`ArrowArrayStream`](https://arrow.apache.org/docs/format/CStreamInterface.html). +#[repr(C)] +#[derive(Debug)] +pub struct ArrowArrayStream { + pub(super) get_schema: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut ArrowArrayStream, + out: *mut ArrowSchema, + ) -> ::std::os::raw::c_int, + >, + pub(super) get_next: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut ArrowArrayStream, + out: *mut ArrowArray, + ) -> ::std::os::raw::c_int, + >, + pub(super) get_last_error: ::std::option::Option< + unsafe extern "C" fn(arg1: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char, + >, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} diff --git a/crates/polars-arrow/src/ffi/mmap.rs b/crates/polars-arrow/src/ffi/mmap.rs new file mode 100644 index 000000000000..8fdbedf24cea --- /dev/null +++ b/crates/polars-arrow/src/ffi/mmap.rs @@ -0,0 +1,214 @@ +//! Functionality to mmap in-memory data regions. +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail}; + +use super::{ArrowArray, InternalArrowArray}; +use crate::array::{BooleanArray, FromFfi, PrimitiveArray}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::storage::SharedStorage; +use crate::types::NativeType; + +#[allow(dead_code)] +struct PrivateData { + // the owner of the pointers' regions + data: T, + buffers_ptr: Box<[*const std::os::raw::c_void]>, + children_ptr: Box<[*mut ArrowArray]>, + dictionary_ptr: Option<*mut ArrowArray>, +} + +pub(crate) unsafe fn create_array< + T, + I: Iterator>, + II: Iterator, +>( + data: Arc, + num_rows: usize, + null_count: usize, + buffers: I, + children: II, + dictionary: Option, + offset: Option, +) -> ArrowArray { + let buffers_ptr = buffers + .map(|maybe_buffer| match maybe_buffer { + Some(b) => b as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let n_buffers = buffers_ptr.len() as i64; + + let children_ptr = children + .map(|child| Box::into_raw(Box::new(child))) + .collect::>(); + let n_children = children_ptr.len() as i64; + + let dictionary_ptr = dictionary.map(|array| Box::into_raw(Box::new(array))); + + let mut private_data = Box::new(PrivateData::> { + data, + buffers_ptr, + children_ptr, + dictionary_ptr, + }); + + ArrowArray { + length: num_rows as i64, + null_count: null_count as i64, + offset: offset.unwrap_or(0) as i64, // Unwrap: IPC files are by definition not offset + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children_ptr.as_mut_ptr(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), + release: Some(release::>), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } +} + +/// callback used to drop [`ArrowArray`] when it is exported specified for [`PrivateData`]. +unsafe extern "C" fn release(array: *mut ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it + let private = Box::from_raw(array.private_data as *mut PrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + + array.release = None; +} + +/// Creates a (non-null) [`PrimitiveArray`] from a slice of values. +/// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without incurring +/// a memcopy cost. +/// +/// # Safety +/// +/// Using this function is not unsafe, but the returned PrimitiveArray's lifetime is bound to the lifetime +/// of the slice. The returned [`PrimitiveArray`] _must not_ outlive the passed slice. +pub unsafe fn slice(values: &[T]) -> PrimitiveArray { + let static_values = std::mem::transmute::<&[T], &'static [T]>(values); + let storage = SharedStorage::from_static(static_values); + let buffer = Buffer::from_storage(storage); + PrimitiveArray::new_unchecked(T::PRIMITIVE.into(), buffer, None) +} + +/// Creates a (non-null) [`PrimitiveArray`] from a slice of values. +/// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without incurring +/// a memcopy cost. +/// +/// # Safety +/// +/// The caller must ensure the passed `owner` ensures the data remains alive. +pub unsafe fn slice_and_owner(slice: &[T], owner: O) -> PrimitiveArray { + let num_rows = slice.len(); + let null_count = 0; + let validity = None; + + let data: &[u8] = bytemuck::cast_slice(slice); + let ptr = data.as_ptr(); + let data = Arc::new(owner); + + // SAFETY: the underlying assumption of this function: the array will not be used + // beyond the + let array = create_array( + data, + num_rows, + null_count, + [validity, Some(ptr)].into_iter(), + [].into_iter(), + None, + None, + ); + let array = InternalArrowArray::new(array, T::PRIMITIVE.into()); + + // SAFETY: we just created a valid array + unsafe { PrimitiveArray::::try_from_ffi(array) }.unwrap() +} + +/// Creates a (non-null) [`BooleanArray`] from a slice of bits. +/// This does not have memcopy and is the fastest way to create a [`BooleanArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without +/// incurring a memcopy cost. +/// +/// The `offset` indicates where the first bit starts in the first byte. +/// +/// # Safety +/// +/// Using this function is not unsafe, but the returned BooleanArrays's lifetime +/// is bound to the lifetime of the slice. The returned [`BooleanArray`] _must +/// not_ outlive the passed slice. +pub unsafe fn bitmap(data: &[u8], offset: usize, length: usize) -> PolarsResult { + if offset >= 8 { + polars_bail!(InvalidOperation: "offset should be < 8") + }; + if length > data.len() * 8 - offset { + polars_bail!(InvalidOperation: "given length is oob") + } + let static_data = std::mem::transmute::<&[u8], &'static [u8]>(data); + let storage = SharedStorage::from_static(static_data); + let bitmap = Bitmap::from_inner_unchecked(storage, offset, length, None); + Ok(BooleanArray::new(ArrowDataType::Boolean, bitmap, None)) +} + +/// Creates a (non-null) [`BooleanArray`] from a slice of bits. +/// This does not have memcopy and is the fastest way to create a [`BooleanArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without +/// incurring a memcopy cost. +/// +/// The `offset` indicates where the first bit starts in the first byte. +/// +/// # Safety +/// +/// The caller must ensure the passed `owner` ensures the data remains alive. +pub unsafe fn bitmap_and_owner( + data: &[u8], + offset: usize, + length: usize, + owner: O, +) -> PolarsResult { + if offset >= 8 { + polars_bail!(InvalidOperation: "offset should be < 8") + }; + if length > data.len() * 8 - offset { + polars_bail!(InvalidOperation: "given length is oob") + } + let null_count = 0; + let validity = None; + + let ptr = data.as_ptr(); + let data = Arc::new(owner); + + // SAFETY: the underlying assumption of this function: the array will not be used + // beyond the + let array = create_array( + data, + length, + null_count, + [validity, Some(ptr)].into_iter(), + [].into_iter(), + None, + Some(offset), + ); + let array = InternalArrowArray::new(array, ArrowDataType::Boolean); + + // SAFETY: we just created a valid array + Ok(unsafe { BooleanArray::try_from_ffi(array) }.unwrap()) +} diff --git a/crates/polars-arrow/src/ffi/mod.rs b/crates/polars-arrow/src/ffi/mod.rs new file mode 100644 index 000000000000..1e9802a8c532 --- /dev/null +++ b/crates/polars-arrow/src/ffi/mod.rs @@ -0,0 +1,48 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! contains FFI bindings to import and export [`Array`](crate::array::Array) via +//! Arrow's [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +mod array; +mod bridge; +mod generated; +pub mod mmap; +mod schema; +mod stream; + +pub(crate) use array::{ArrowArrayRef, InternalArrowArray, try_from}; +pub(crate) use bridge::align_to_c_data_interface; +pub use generated::{ArrowArray, ArrowArrayStream, ArrowSchema}; +use polars_error::PolarsResult; +pub use stream::{ArrowArrayStreamReader, export_iterator}; + +use self::schema::to_field; +use crate::array::Array; +use crate::datatypes::{ArrowDataType, Field}; + +/// Exports an [`Box`] to the C data interface. +pub fn export_array_to_c(array: Box) -> ArrowArray { + ArrowArray::new(bridge::align_to_c_data_interface(array)) +} + +/// Exports a [`Field`] to the C data interface. +pub fn export_field_to_c(field: &Field) -> ArrowSchema { + ArrowSchema::new(field) +} + +/// Imports a [`Field`] from the C data interface. +/// # Safety +/// This function is intrinsically `unsafe` and relies on a [`ArrowSchema`] +/// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub unsafe fn import_field_from_c(field: &ArrowSchema) -> PolarsResult { + to_field(field) +} + +/// Imports an [`Array`] from the C data interface. +/// # Safety +/// This function is intrinsically `unsafe` and relies on a [`ArrowArray`] +/// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub unsafe fn import_array_from_c( + array: ArrowArray, + dtype: ArrowDataType, +) -> PolarsResult> { + try_from(InternalArrowArray::new(array, dtype)) +} diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs new file mode 100644 index 000000000000..1010b890fda3 --- /dev/null +++ b/crates/polars-arrow/src/ffi/schema.rs @@ -0,0 +1,720 @@ +use std::collections::BTreeMap; +use std::ffi::{CStr, CString}; +use std::ptr; + +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::pl_str::PlSmallStr; + +use super::ArrowSchema; +use crate::datatypes::{ + ArrowDataType, Extension, ExtensionType, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, + UnionMode, UnionType, +}; + +#[allow(dead_code)] +struct SchemaPrivateData { + name: CString, + format: CString, + metadata: Option>, + children_ptr: Box<[*mut ArrowSchema]>, + dictionary: Option<*mut ArrowSchema>, +} + +// callback used to drop [ArrowSchema] when it is exported. +unsafe extern "C" fn c_release_schema(schema: *mut ArrowSchema) { + if schema.is_null() { + return; + } + let schema = &mut *schema; + + let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary { + let _ = Box::from_raw(ptr); + } + + schema.release = None; +} + +/// allocate (and hold) the children +fn schema_children(dtype: &ArrowDataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> { + match dtype { + ArrowDataType::List(field) + | ArrowDataType::FixedSizeList(field, _) + | ArrowDataType::LargeList(field) => { + Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))]) + }, + ArrowDataType::Map(field, is_sorted) => { + *flags += (*is_sorted as i64) * 4; + Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))]) + }, + ArrowDataType::Struct(fields) => fields + .iter() + .map(|field| Box::into_raw(Box::new(ArrowSchema::new(field)))) + .collect::>(), + ArrowDataType::Union(u) => u + .fields + .iter() + .map(|field| Box::into_raw(Box::new(ArrowSchema::new(field)))) + .collect::>(), + ArrowDataType::Extension(ext) => schema_children(&ext.inner, flags), + _ => Box::new([]), + } +} + +impl ArrowSchema { + /// creates a new [ArrowSchema] + pub(crate) fn new(field: &Field) -> Self { + let format = to_format(field.dtype()); + let name = field.name.clone(); + + let mut flags = field.is_nullable as i64 * 2; + + // note: this cannot be done along with the above because the above is fallible and this op leaks. + let children_ptr = schema_children(field.dtype(), &mut flags); + let n_children = children_ptr.len() as i64; + + let dictionary = if let ArrowDataType::Dictionary(_, values, is_ordered) = field.dtype() { + flags += *is_ordered as i64; + // we do not store field info in the dict values, so can't recover it all :( + let field = Field::new(PlSmallStr::EMPTY, values.as_ref().clone(), true); + Some(Box::new(ArrowSchema::new(&field))) + } else { + None + }; + + let metadata = field + .metadata + .as_ref() + .map(|inner| (**inner).clone()) + .unwrap_or_default(); + + let metadata = if let ArrowDataType::Extension(ext) = field.dtype() { + // append extension information. + let mut metadata = metadata.clone(); + + // metadata + if let Some(extension_metadata) = &ext.metadata { + metadata.insert( + PlSmallStr::from_static("ARROW:extension:metadata"), + extension_metadata.clone(), + ); + } + + metadata.insert( + PlSmallStr::from_static("ARROW:extension:name"), + ext.name.clone(), + ); + + Some(metadata_to_bytes(&metadata)) + } else if !metadata.is_empty() { + Some(metadata_to_bytes(&metadata)) + } else { + None + }; + + let name = CString::new(name.as_bytes()).unwrap(); + let format = CString::new(format).unwrap(); + + let mut private = Box::new(SchemaPrivateData { + name, + format, + metadata, + children_ptr, + dictionary: dictionary.map(Box::into_raw), + }); + + // + Self { + format: private.format.as_ptr(), + name: private.name.as_ptr(), + metadata: private + .metadata + .as_ref() + .map(|x| x.as_ptr()) + .unwrap_or(std::ptr::null()) as *const ::std::os::raw::c_char, + flags, + n_children, + children: private.children_ptr.as_mut_ptr(), + dictionary: private.dictionary.unwrap_or(std::ptr::null_mut()), + release: Some(c_release_schema), + private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void, + } + } + + /// create an empty [ArrowSchema] + pub fn empty() -> Self { + Self { + format: std::ptr::null_mut(), + name: std::ptr::null_mut(), + metadata: std::ptr::null_mut(), + flags: 0, + n_children: 0, + children: ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + pub fn is_null(&self) -> bool { + self.private_data.is_null() + } + + /// returns the format of this schema. + pub(crate) fn format(&self) -> &str { + assert!(!self.format.is_null()); + // safe because the lifetime of `self.format` equals `self` + unsafe { CStr::from_ptr(self.format) } + .to_str() + .expect("The external API has a non-utf8 as format") + } + + /// returns the name of this schema. + /// + /// Since this field is optional, `""` is returned if it is not set (as per the spec). + pub(crate) fn name(&self) -> &str { + if self.name.is_null() { + return ""; + } + // safe because the lifetime of `self.name` equals `self` + unsafe { CStr::from_ptr(self.name) }.to_str().unwrap() + } + + pub(crate) fn child(&self, index: usize) -> &'static Self { + assert!(index < self.n_children as usize); + unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } + } + + pub(crate) fn dictionary(&self) -> Option<&'static Self> { + if self.dictionary.is_null() { + return None; + }; + Some(unsafe { self.dictionary.as_ref().unwrap() }) + } + + pub(crate) fn nullable(&self) -> bool { + (self.flags / 2) & 1 == 1 + } +} + +impl Drop for ArrowSchema { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> PolarsResult { + let dictionary = schema.dictionary(); + let dtype = if let Some(dictionary) = dictionary { + let indices = to_integer_type(schema.format())?; + let values = to_field(dictionary)?; + let is_ordered = schema.flags & 1 == 1; + ArrowDataType::Dictionary(indices, Box::new(values.dtype().clone()), is_ordered) + } else { + to_dtype(schema)? + }; + let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) }; + + let dtype = if let Some((name, extension_metadata)) = extension { + ArrowDataType::Extension(Box::new(ExtensionType { + name, + inner: dtype, + metadata: extension_metadata, + })) + } else { + dtype + }; + + Ok(Field::new( + PlSmallStr::from_str(schema.name()), + dtype, + schema.nullable(), + ) + .with_metadata(metadata)) +} + +fn to_integer_type(format: &str) -> PolarsResult { + use IntegerType::*; + Ok(match format { + "c" => Int8, + "C" => UInt8, + "s" => Int16, + "S" => UInt16, + "i" => Int32, + "I" => UInt32, + "l" => Int64, + "L" => UInt64, + _ => { + polars_bail!( + ComputeError: + "dictionary indices can only be integers" + ) + }, + }) +} + +unsafe fn to_dtype(schema: &ArrowSchema) -> PolarsResult { + Ok(match schema.format() { + "n" => ArrowDataType::Null, + "b" => ArrowDataType::Boolean, + "c" => ArrowDataType::Int8, + "C" => ArrowDataType::UInt8, + "s" => ArrowDataType::Int16, + "S" => ArrowDataType::UInt16, + "i" => ArrowDataType::Int32, + "I" => ArrowDataType::UInt32, + "l" => ArrowDataType::Int64, + "L" => ArrowDataType::UInt64, + "_pli128" => ArrowDataType::Int128, + "e" => ArrowDataType::Float16, + "f" => ArrowDataType::Float32, + "g" => ArrowDataType::Float64, + "z" => ArrowDataType::Binary, + "Z" => ArrowDataType::LargeBinary, + "u" => ArrowDataType::Utf8, + "U" => ArrowDataType::LargeUtf8, + "tdD" => ArrowDataType::Date32, + "tdm" => ArrowDataType::Date64, + "tts" => ArrowDataType::Time32(TimeUnit::Second), + "ttm" => ArrowDataType::Time32(TimeUnit::Millisecond), + "ttu" => ArrowDataType::Time64(TimeUnit::Microsecond), + "ttn" => ArrowDataType::Time64(TimeUnit::Nanosecond), + "tDs" => ArrowDataType::Duration(TimeUnit::Second), + "tDm" => ArrowDataType::Duration(TimeUnit::Millisecond), + "tDu" => ArrowDataType::Duration(TimeUnit::Microsecond), + "tDn" => ArrowDataType::Duration(TimeUnit::Nanosecond), + "tiM" => ArrowDataType::Interval(IntervalUnit::YearMonth), + "tiD" => ArrowDataType::Interval(IntervalUnit::DayTime), + "vu" => ArrowDataType::Utf8View, + "vz" => ArrowDataType::BinaryView, + "+l" => { + let child = schema.child(0); + ArrowDataType::List(Box::new(to_field(child)?)) + }, + "+L" => { + let child = schema.child(0); + ArrowDataType::LargeList(Box::new(to_field(child)?)) + }, + "+m" => { + let child = schema.child(0); + + let is_sorted = (schema.flags & 4) != 0; + ArrowDataType::Map(Box::new(to_field(child)?), is_sorted) + }, + "+s" => { + let children = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + ArrowDataType::Struct(children) + }, + other => { + match other.splitn(2, ':').collect::>()[..] { + // Timestamps with no timezone + ["tss", ""] => ArrowDataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => ArrowDataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => ArrowDataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), + + // Timestamps with timezone + ["tss", tz] => { + ArrowDataType::Timestamp(TimeUnit::Second, Some(PlSmallStr::from_str(tz))) + }, + ["tsm", tz] => { + ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(PlSmallStr::from_str(tz))) + }, + ["tsu", tz] => { + ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(PlSmallStr::from_str(tz))) + }, + ["tsn", tz] => { + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(PlSmallStr::from_str(tz))) + }, + + ["w", size_raw] => { + // Example: "w:42" fixed-width binary [42 bytes] + let size = size_raw + .parse::() + .map_err(|_| polars_err!(ComputeError: "size is not a valid integer"))?; + ArrowDataType::FixedSizeBinary(size) + }, + ["+w", size_raw] => { + // Example: "+w:123" fixed-sized list [123 items] + let size = size_raw + .parse::() + .map_err(|_| polars_err!(ComputeError: "size is not a valid integer"))?; + let child = to_field(schema.child(0))?; + ArrowDataType::FixedSizeList(Box::new(child), size) + }, + ["d", raw] => { + // Decimal + let (precision, scale) = match raw.split(',').collect::>()[..] { + [precision_raw, scale_raw] => { + // Example: "d:19,10" decimal128 [precision 19, scale 10] + (precision_raw, scale_raw) + }, + [precision_raw, scale_raw, width_raw] => { + // Example: "d:19,10,NNN" decimal bitwidth = NNN [precision 19, scale 10] + // Only bitwdth of 128 currently supported + let bit_width = width_raw.parse::().map_err(|_| { + polars_err!(ComputeError: "Decimal bit width is not a valid integer") + })?; + if bit_width == 256 { + return Ok(ArrowDataType::Decimal256( + precision_raw.parse::().map_err(|_| { + polars_err!(ComputeError: "Decimal precision is not a valid integer") + })?, + scale_raw.parse::().map_err(|_| { + polars_err!(ComputeError: "Decimal scale is not a valid integer") + })?, + )); + } + (precision_raw, scale_raw) + }, + _ => { + polars_bail!(ComputeError: + "Decimal must contain 2 or 3 comma-separated values" + ) + }, + }; + + ArrowDataType::Decimal( + precision.parse::().map_err(|_| { + polars_err!(ComputeError: + "Decimal precision is not a valid integer" + ) + })?, + scale.parse::().map_err(|_| { + polars_err!(ComputeError: + "Decimal scale is not a valid integer" + ) + })?, + ) + }, + [union_type @ "+us", union_parts] | [union_type @ "+ud", union_parts] => { + // union, sparse + // Example "+us:I,J,..." sparse union with type ids I,J... + // Example: "+ud:I,J,..." dense union with type ids I,J... + let mode = UnionMode::sparse(union_type == "+us"); + let type_ids = union_parts + .split(',') + .map(|x| { + x.parse::().map_err(|_| { + polars_err!(ComputeError: + "Union type id is not a valid integer" + ) + }) + }) + .collect::>>()?; + let fields = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + ArrowDataType::Union(Box::new(UnionType { + fields, + ids: Some(type_ids), + mode, + })) + }, + _ => { + polars_bail!(ComputeError: + "The datatype \"{other}\" is still not supported in Rust implementation", + ) + }, + } + }, + }) +} + +/// the inverse of [to_field] +fn to_format(dtype: &ArrowDataType) -> String { + match dtype { + ArrowDataType::Null => "n".to_string(), + ArrowDataType::Boolean => "b".to_string(), + ArrowDataType::Int8 => "c".to_string(), + ArrowDataType::UInt8 => "C".to_string(), + ArrowDataType::Int16 => "s".to_string(), + ArrowDataType::UInt16 => "S".to_string(), + ArrowDataType::Int32 => "i".to_string(), + ArrowDataType::UInt32 => "I".to_string(), + ArrowDataType::Int64 => "l".to_string(), + ArrowDataType::UInt64 => "L".to_string(), + // Doesn't exist in arrow, '_pl' prefixed is Polars specific + ArrowDataType::Int128 => "_pli128".to_string(), + ArrowDataType::Float16 => "e".to_string(), + ArrowDataType::Float32 => "f".to_string(), + ArrowDataType::Float64 => "g".to_string(), + ArrowDataType::Binary => "z".to_string(), + ArrowDataType::LargeBinary => "Z".to_string(), + ArrowDataType::Utf8 => "u".to_string(), + ArrowDataType::LargeUtf8 => "U".to_string(), + ArrowDataType::Date32 => "tdD".to_string(), + ArrowDataType::Date64 => "tdm".to_string(), + ArrowDataType::Time32(TimeUnit::Second) => "tts".to_string(), + ArrowDataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(), + ArrowDataType::Time32(_) => { + unreachable!("Time32 is only supported for seconds and milliseconds") + }, + ArrowDataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(), + ArrowDataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(), + ArrowDataType::Time64(_) => { + unreachable!("Time64 is only supported for micro and nanoseconds") + }, + ArrowDataType::Duration(TimeUnit::Second) => "tDs".to_string(), + ArrowDataType::Duration(TimeUnit::Millisecond) => "tDm".to_string(), + ArrowDataType::Duration(TimeUnit::Microsecond) => "tDu".to_string(), + ArrowDataType::Duration(TimeUnit::Nanosecond) => "tDn".to_string(), + ArrowDataType::Interval(IntervalUnit::YearMonth) => "tiM".to_string(), + ArrowDataType::Interval(IntervalUnit::DayTime) => "tiD".to_string(), + ArrowDataType::Interval(IntervalUnit::MonthDayNano) => { + todo!("Spec for FFI for MonthDayNano still not defined.") + }, + ArrowDataType::Timestamp(unit, tz) => { + let unit = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "m", + TimeUnit::Microsecond => "u", + TimeUnit::Nanosecond => "n", + }; + format!( + "ts{}:{}", + unit, + tz.as_ref().map(|x| x.as_str()).unwrap_or("") + ) + }, + ArrowDataType::Utf8View => "vu".to_string(), + ArrowDataType::BinaryView => "vz".to_string(), + ArrowDataType::Decimal(precision, scale) => format!("d:{precision},{scale}"), + ArrowDataType::Decimal256(precision, scale) => format!("d:{precision},{scale},256"), + ArrowDataType::List(_) => "+l".to_string(), + ArrowDataType::LargeList(_) => "+L".to_string(), + ArrowDataType::Struct(_) => "+s".to_string(), + ArrowDataType::FixedSizeBinary(size) => format!("w:{size}"), + ArrowDataType::FixedSizeList(_, size) => format!("+w:{size}"), + ArrowDataType::Union(u) => { + let sparsness = if u.mode.is_sparse() { 's' } else { 'd' }; + let mut r = format!("+u{sparsness}:"); + let ids = if let Some(ids) = &u.ids { + ids.iter() + .fold(String::new(), |a, b| a + b.to_string().as_str() + ",") + } else { + (0..u.fields.len()).fold(String::new(), |a, b| a + b.to_string().as_str() + ",") + }; + let ids = &ids[..ids.len() - 1]; // take away last "," + r.push_str(ids); + r + }, + ArrowDataType::Map(_, _) => "+m".to_string(), + ArrowDataType::Dictionary(index, _, _) => to_format(&(*index).into()), + ArrowDataType::Extension(ext) => to_format(&ext.inner), + ArrowDataType::Unknown => unimplemented!(), + } +} + +pub(super) fn get_child(dtype: &ArrowDataType, index: usize) -> PolarsResult { + match (index, dtype) { + (0, ArrowDataType::List(field)) => Ok(field.dtype().clone()), + (0, ArrowDataType::FixedSizeList(field, _)) => Ok(field.dtype().clone()), + (0, ArrowDataType::LargeList(field)) => Ok(field.dtype().clone()), + (0, ArrowDataType::Map(field, _)) => Ok(field.dtype().clone()), + (index, ArrowDataType::Struct(fields)) => Ok(fields[index].dtype().clone()), + (index, ArrowDataType::Union(u)) => Ok(u.fields[index].dtype().clone()), + (index, ArrowDataType::Extension(ext)) => get_child(&ext.inner, index), + (child, dtype) => polars_bail!(ComputeError: + "Requested child {child} to type {dtype:?} that has no such child", + ), + } +} + +fn metadata_to_bytes(metadata: &BTreeMap) -> Vec { + let a = (metadata.len() as i32).to_ne_bytes().to_vec(); + metadata.iter().fold(a, |mut acc, (key, value)| { + acc.extend((key.len() as i32).to_ne_bytes()); + acc.extend(key.as_bytes()); + acc.extend((value.len() as i32).to_ne_bytes()); + acc.extend(value.as_bytes()); + acc + }) +} + +unsafe fn read_ne_i32(ptr: *const u8) -> i32 { + let slice = std::slice::from_raw_parts(ptr, 4); + i32::from_ne_bytes(slice.try_into().unwrap()) +} + +unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str { + let slice = std::slice::from_raw_parts(ptr, len); + simdutf8::basic::from_utf8(slice).unwrap() +} + +unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) { + let mut data = data as *const u8; // u8 = i8 + if data.is_null() { + return (Metadata::default(), None); + }; + let len = read_ne_i32(data); + data = data.add(4); + + let mut result = BTreeMap::new(); + let mut extension_name = None; + let mut extension_metadata = None; + for _ in 0..len { + let key_len = read_ne_i32(data) as usize; + data = data.add(4); + let key = read_bytes(data, key_len); + data = data.add(key_len); + let value_len = read_ne_i32(data) as usize; + data = data.add(4); + let value = read_bytes(data, value_len); + data = data.add(value_len); + match key { + "ARROW:extension:name" => { + extension_name = Some(PlSmallStr::from_str(value)); + }, + "ARROW:extension:metadata" => { + extension_metadata = Some(PlSmallStr::from_str(value)); + }, + _ => { + result.insert(PlSmallStr::from_str(key), PlSmallStr::from_str(value)); + }, + }; + } + let extension = extension_name.map(|name| (name, extension_metadata)); + (result, extension) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all() { + let mut dts = vec![ + ArrowDataType::Null, + ArrowDataType::Boolean, + ArrowDataType::UInt8, + ArrowDataType::UInt16, + ArrowDataType::UInt32, + ArrowDataType::UInt64, + ArrowDataType::Int8, + ArrowDataType::Int16, + ArrowDataType::Int32, + ArrowDataType::Int64, + ArrowDataType::Float32, + ArrowDataType::Float64, + ArrowDataType::Date32, + ArrowDataType::Date64, + ArrowDataType::Time32(TimeUnit::Second), + ArrowDataType::Time32(TimeUnit::Millisecond), + ArrowDataType::Time64(TimeUnit::Microsecond), + ArrowDataType::Time64(TimeUnit::Nanosecond), + ArrowDataType::Decimal(5, 5), + ArrowDataType::Utf8, + ArrowDataType::LargeUtf8, + ArrowDataType::Binary, + ArrowDataType::LargeBinary, + ArrowDataType::FixedSizeBinary(2), + ArrowDataType::List(Box::new(Field::new( + PlSmallStr::from_static("example"), + ArrowDataType::Boolean, + false, + ))), + ArrowDataType::FixedSizeList( + Box::new(Field::new( + PlSmallStr::from_static("example"), + ArrowDataType::Boolean, + false, + )), + 2, + ), + ArrowDataType::LargeList(Box::new(Field::new( + PlSmallStr::from_static("example"), + ArrowDataType::Boolean, + false, + ))), + ArrowDataType::Struct(vec![ + Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true), + Field::new( + PlSmallStr::from_static("b"), + ArrowDataType::List(Box::new(Field::new( + PlSmallStr::from_static("item"), + ArrowDataType::Int32, + true, + ))), + true, + ), + ]), + ArrowDataType::Map( + Box::new(Field::new( + PlSmallStr::from_static("a"), + ArrowDataType::Int64, + true, + )), + true, + ), + ArrowDataType::Union(Box::new(UnionType { + fields: vec![ + Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true), + Field::new( + PlSmallStr::from_static("b"), + ArrowDataType::List(Box::new(Field::new( + PlSmallStr::from_static("item"), + ArrowDataType::Int32, + true, + ))), + true, + ), + ], + ids: Some(vec![1, 2]), + mode: UnionMode::Dense, + })), + ArrowDataType::Union(Box::new(UnionType { + fields: vec![ + Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true), + Field::new( + PlSmallStr::from_static("b"), + ArrowDataType::List(Box::new(Field::new( + PlSmallStr::from_static("item"), + ArrowDataType::Int32, + true, + ))), + true, + ), + ], + ids: Some(vec![0, 1]), + mode: UnionMode::Sparse, + })), + ]; + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + dts.push(ArrowDataType::Timestamp(time_unit, None)); + dts.push(ArrowDataType::Timestamp( + time_unit, + Some(PlSmallStr::from_static("00:00")), + )); + dts.push(ArrowDataType::Duration(time_unit)); + } + for interval_type in [ + IntervalUnit::DayTime, + IntervalUnit::YearMonth, + //IntervalUnit::MonthDayNano, // not yet defined on the C data interface + ] { + dts.push(ArrowDataType::Interval(interval_type)); + } + + for expected in dts { + let field = Field::new(PlSmallStr::from_static("a"), expected.clone(), true); + let schema = ArrowSchema::new(&field); + let result = unsafe { super::to_dtype(&schema).unwrap() }; + assert_eq!(result, expected); + } + } +} diff --git a/crates/polars-arrow/src/ffi/stream.rs b/crates/polars-arrow/src/ffi/stream.rs new file mode 100644 index 000000000000..e3ed1d347c29 --- /dev/null +++ b/crates/polars-arrow/src/ffi/stream.rs @@ -0,0 +1,219 @@ +use std::ffi::{CStr, CString}; +use std::ops::DerefMut; + +use polars_error::{PolarsError, PolarsResult, polars_bail, polars_err}; + +use super::{ + ArrowArray, ArrowArrayStream, ArrowSchema, export_array_to_c, export_field_to_c, + import_array_from_c, import_field_from_c, +}; +use crate::array::Array; +use crate::datatypes::Field; + +impl Drop for ArrowArrayStream { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +unsafe impl Send for ArrowArrayStream {} + +impl ArrowArrayStream { + /// Creates an empty [`ArrowArrayStream`] used to import from a producer. + pub fn empty() -> Self { + Self { + get_schema: None, + get_next: None, + get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + } + } +} + +unsafe fn handle_error(iter: &mut ArrowArrayStream) -> PolarsError { + let error = unsafe { (iter.get_last_error.unwrap())(&mut *iter) }; + + if error.is_null() { + return polars_err!(ComputeError: "got unspecified external error"); + } + + let error = unsafe { CStr::from_ptr(error) }; + polars_err!(ComputeError: "got external error: {}", error.to_str().unwrap()) +} + +/// Implements an iterator of [`Array`] consumed from the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html). +pub struct ArrowArrayStreamReader> { + iter: Iter, + field: Field, +} + +impl> ArrowArrayStreamReader { + /// Returns a new [`ArrowArrayStreamReader`] + /// # Error + /// Errors iff the [`ArrowArrayStream`] is out of specification, + /// or was already released prior to calling this function. + /// + /// # Safety + /// This method is intrinsically `unsafe` since it assumes that the `ArrowArrayStream` + /// contains a valid Arrow C stream interface. + /// In particular: + /// * The `ArrowArrayStream` fulfills the invariants of the C stream interface + /// * The schema `get_schema` produces fulfills the C data interface + pub unsafe fn try_new(mut iter: Iter) -> PolarsResult { + if iter.release.is_none() { + polars_bail!(InvalidOperation: "the C stream was already released") + }; + + if iter.get_next.is_none() { + polars_bail!(InvalidOperation: "the C stream must contain a non-null get_next") + }; + + if iter.get_last_error.is_none() { + polars_bail!(InvalidOperation: "The C stream MUST contain a non-null get_last_error") + }; + + let mut field = ArrowSchema::empty(); + let status = if let Some(f) = iter.get_schema { + unsafe { (f)(&mut *iter, &mut field) } + } else { + polars_bail!(InvalidOperation: + "The C stream MUST contain a non-null get_schema" + ) + }; + + if status != 0 { + return Err(unsafe { handle_error(&mut iter) }); + } + + let field = unsafe { import_field_from_c(&field)? }; + + Ok(Self { iter, field }) + } + + /// Returns the field provided by the stream + pub fn field(&self) -> &Field { + &self.field + } + + /// Advances this iterator by one array + /// # Error + /// Errors iff: + /// * The C stream interface returns an error + /// * The C stream interface returns an invalid array (that we can identify, see Safety below) + /// + /// # Safety + /// Calling this iterator's `next` assumes that the [`ArrowArrayStream`] produces arrow arrays + /// that fulfill the C data interface + pub unsafe fn next(&mut self) -> Option>> { + let mut array = ArrowArray::empty(); + let status = unsafe { (self.iter.get_next.unwrap())(&mut *self.iter, &mut array) }; + + if status != 0 { + return Some(Err(unsafe { handle_error(&mut self.iter) })); + } + + // last paragraph of https://arrow.apache.org/docs/format/CStreamInterface.html#c.ArrowArrayStream.get_next + array.release?; + + // SAFETY: assumed from the C stream interface + unsafe { import_array_from_c(array, self.field.dtype.clone()) } + .map(Some) + .transpose() + } +} + +struct PrivateData { + iter: Box>>>, + field: Field, + error: Option, +} + +unsafe extern "C" fn get_next(iter: *mut ArrowArrayStream, array: *mut ArrowArray) -> i32 { + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + match private.iter.next() { + Some(Ok(item)) => { + // check that the array has the same dtype as field + let item_dt = item.dtype(); + let expected_dt = private.field.dtype(); + if item_dt != expected_dt { + private.error = Some(CString::new(format!("The iterator produced an item of data type {item_dt:?} but the producer expects data type {expected_dt:?}").as_bytes().to_vec()).unwrap()); + return 2001; // custom application specific error (since this is never a result of this interface) + } + + std::ptr::write(array, export_array_to_c(item)); + + private.error = None; + 0 + }, + Some(Err(err)) => { + private.error = Some(CString::new(err.to_string().as_bytes().to_vec()).unwrap()); + 2001 // custom application specific error (since this is never a result of this interface) + }, + None => { + let a = ArrowArray::empty(); + std::ptr::write_unaligned(array, a); + private.error = None; + 0 + }, + } +} + +unsafe extern "C" fn get_schema(iter: *mut ArrowArrayStream, schema: *mut ArrowSchema) -> i32 { + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + std::ptr::write(schema, export_field_to_c(&private.field)); + 0 +} + +unsafe extern "C" fn get_last_error(iter: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char { + if iter.is_null() { + return std::ptr::null(); + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + private + .error + .as_ref() + .map(|x| x.as_ptr()) + .unwrap_or(std::ptr::null()) +} + +unsafe extern "C" fn release(iter: *mut ArrowArrayStream) { + if iter.is_null() { + return; + } + let _ = Box::from_raw((*iter).private_data as *mut PrivateData); + (*iter).release = None; + // private drops automatically +} + +/// Exports an iterator to the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html) +pub fn export_iterator( + iter: Box>>>, + field: Field, +) -> ArrowArrayStream { + let private_data = Box::new(PrivateData { + iter, + field, + error: None, + }); + + ArrowArrayStream { + get_schema: Some(get_schema), + get_next: Some(get_next), + get_last_error: Some(get_last_error), + release: Some(release), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } +} diff --git a/crates/polars-arrow/src/io/README.md b/crates/polars-arrow/src/io/README.md new file mode 100644 index 000000000000..52e236220f13 --- /dev/null +++ b/crates/polars-arrow/src/io/README.md @@ -0,0 +1,28 @@ +# IO module + +This document describes the overall design of this module. + +## Rules: + +- Each directory in this module corresponds to a specific format such as `csv` and `json`. +- directories that depend on external dependencies MUST be feature gated, with a feature named with + a prefix `io_`. +- modules MUST re-export any API of external dependencies they require as part of their public API. + E.g. + - if a module as an API `write(writer: &mut csv:Writer, ...)`, it MUST contain + `pub use csv::Writer;`. + + The rational is that adding this crate to `cargo.toml` must be sufficient to use it. +- Each directory SHOULD contain two directories, `read` and `write`, corresponding to functionality + about reading from the format and writing to the format respectively. +- The base module SHOULD contain `use pub read;` and `use pub write;`. +- Implementations SHOULD separate reading of "data" from reading of "metadata". Examples: + - schema read or inference SHOULD be a separate function + - functions that read "data" SHOULD consume a schema typically pre-read. +- Implementations SHOULD separate IO-bounded operations from CPU-bounded operations. I.e. + implementations SHOULD: + - contain functions that consume a `Read` implementor and output a "raw" struct, i.e. a struct + that is e.g. compressed and serialized + - contain functions that consume a "raw" struct and convert it into Arrow. + - offer each of these functions as independent public APIs, so that consumers can decide how to + balance CPU-bounds and IO-bounds. diff --git a/crates/polars-arrow/src/io/avro/mod.rs b/crates/polars-arrow/src/io/avro/mod.rs new file mode 100644 index 000000000000..f535f4ab68a1 --- /dev/null +++ b/crates/polars-arrow/src/io/avro/mod.rs @@ -0,0 +1,34 @@ +//! Read and write from and to Apache Avro + +pub use avro_schema; + +pub mod read; +pub mod write; + +// macros that can operate in sync and async code. +macro_rules! avro_decode { + ($reader:ident $($_await:tt)*) => { + { + let mut i = 0u64; + let mut buf = [0u8; 1]; + let mut j = 0; + loop { + if j > 9 { + // if j * 7 > 64 + polars_error::polars_bail!(oos = "zigzag decoding failed - corrupt avro file") + } + $reader.read_exact(&mut buf[..])$($_await)*?; + i |= (u64::from(buf[0] & 0x7F)) << (j * 7); + if (buf[0] >> 7) == 0 { + break; + } else { + j += 1; + } + } + + Ok(i) + } + } +} + +pub(crate) use avro_decode; diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs b/crates/polars-arrow/src/io/avro/read/deserialize.rs new file mode 100644 index 000000000000..c775f7de810f --- /dev/null +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -0,0 +1,528 @@ +use std::sync::Arc; + +use avro_schema::file::Block; +use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema}; +use polars_error::{PolarsResult, polars_bail, polars_err}; + +use super::nested::*; +use super::util; +use crate::array::*; +use crate::datatypes::*; +use crate::record_batch::RecordBatchT; +use crate::types::months_days_ns; +use crate::with_match_primitive_type_full; + +fn make_mutable( + dtype: &ArrowDataType, + avro_field: Option<&AvroSchema>, + capacity: usize, +) -> PolarsResult> { + Ok(match dtype.to_physical_type() { + PhysicalType::Boolean => { + Box::new(MutableBooleanArray::with_capacity(capacity)) as Box + }, + PhysicalType::Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(dtype.clone())) + as Box + }), + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::Utf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::Dictionary(_) => { + if let Some(AvroSchema::Enum(Enum { symbols, .. })) = avro_field { + let values = Utf8Array::::from_slice(symbols); + Box::new(FixedItemsUtf8Dictionary::with_capacity(values, capacity)) + as Box + } else { + unreachable!() + } + }, + _ => match dtype { + ArrowDataType::List(inner) => { + let values = make_mutable(inner.dtype(), None, 0)?; + Box::new(DynMutableListArray::::new_from( + values, + dtype.clone(), + capacity, + )) as Box + }, + ArrowDataType::FixedSizeBinary(size) => { + Box::new(MutableFixedSizeBinaryArray::with_capacity(*size, capacity)) + as Box + }, + ArrowDataType::Struct(fields) => { + let values = fields + .iter() + .map(|field| make_mutable(field.dtype(), None, capacity)) + .collect::>>()?; + Box::new(DynMutableStructArray::new(values, dtype.clone())) as Box + }, + other => { + polars_bail!(nyi = "Deserializing type {other:#?} is still not implemented") + }, + }, + }) +} + +fn is_union_null_first(avro_field: &AvroSchema) -> bool { + if let AvroSchema::Union(schemas) = avro_field { + schemas[0] == AvroSchema::Null + } else { + unreachable!() + } +} + +fn deserialize_item<'a>( + array: &mut dyn MutableArray, + is_nullable: bool, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> PolarsResult<&'a [u8]> { + if is_nullable { + let variant = util::zigzag_i64(&mut block)?; + let is_null_first = is_union_null_first(avro_field); + if is_null_first && variant == 0 || !is_null_first && variant != 0 { + array.push_null(); + return Ok(block); + } + } + deserialize_value(array, avro_field, block) +} + +fn deserialize_value<'a>( + array: &mut dyn MutableArray, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> PolarsResult<&'a [u8]> { + let dtype = array.dtype(); + match dtype { + ArrowDataType::List(inner) => { + let is_nullable = inner.is_nullable; + let avro_inner = match avro_field { + AvroSchema::Array(inner) => inner.as_ref(), + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => { + inner.as_ref() + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + // Arrays are encoded as a series of blocks. + loop { + // Each block consists of a long count value, followed by that many array items. + let len = util::zigzag_i64(&mut block)?; + let len = if len < 0 { + // Avro spec: If a block's count is negative, its absolute value is used, + // and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields. + let _ = util::zigzag_i64(&mut block)?; + + -len + } else { + len + }; + + // A block with count zero indicates the end of the array. + if len == 0 { + break; + } + + // Each item is encoded per the array’s item schema. + let values = array.mut_values(); + for _ in 0..len { + block = deserialize_item(values, is_nullable, avro_inner, block)?; + } + } + array.try_push_valid()?; + }, + ArrowDataType::Struct(inner_fields) => { + let fields = match avro_field { + AvroSchema::Record(Record { fields, .. }) => fields, + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Record(Record { fields, .. }), _] + | &[_, AvroSchema::Record(Record { fields, .. })] => fields, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let is_nullable = inner_fields + .iter() + .map(|x| x.is_nullable) + .collect::>(); + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + + for (index, (field, is_nullable)) in fields.iter().zip(is_nullable.iter()).enumerate() { + let values = array.mut_values(index); + block = deserialize_item(values, *is_nullable, &field.schema, block)?; + } + array.try_push_valid()?; + }, + _ => match dtype.to_physical_type() { + PhysicalType::Boolean => { + let is_valid = block[0] == 1; + block = &block[1..]; + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + array.push(Some(is_valid)) + }, + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int32 => { + let value = util::zigzag_i64(&mut block)? as i32; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Int64 => { + let value = util::zigzag_i64(&mut block)?; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Float32 => { + let value = f32::from_le_bytes(block[..size_of::()].try_into().unwrap()); + block = &block[size_of::()..]; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Float64 => { + let value = f64::from_le_bytes(block[..size_of::()].try_into().unwrap()); + block = &block[size_of::()..]; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::MonthDayNano => { + // https://avro.apache.org/docs/current/spec.html#Duration + // 12 bytes, months, days, millis in LE + let data = &block[..12]; + block = &block[12..]; + + let value = months_days_ns::new( + i32::from_le_bytes([data[0], data[1], data[2], data[3]]), + i32::from_le_bytes([data[4], data[5], data[6], data[7]]), + i32::from_le_bytes([data[8], data[9], data[10], data[11]]) as i64 + * 1_000_000, + ); + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Int128 => { + let avro_inner = match avro_field { + AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field, + AvroSchema::Union(u) => match &u.as_slice() { + &[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let len = match avro_inner { + AvroSchema::Bytes(_) => { + util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + polars_err!( + oos = "Avro format contains a non-usize number of bytes" + ) + })? + }, + AvroSchema::Fixed(b) => b.size, + _ => unreachable!(), + }; + if len > 16 { + polars_bail!(oos = "Avro decimal bytes return more than 16 bytes") + } + let mut bytes = [0u8; 16]; + bytes[..len].copy_from_slice(&block[..len]); + block = &block[len..]; + let data = i128::from_be_bytes(bytes) >> (8 * (16 - len)); + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)) + }, + _ => unreachable!(), + }, + PhysicalType::Utf8 => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + polars_err!(oos = "Avro format contains a non-usize number of bytes") + })?; + let data = simdutf8::basic::from_utf8(&block[..len])?; + block = &block[len..]; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)) + }, + PhysicalType::Binary => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + polars_err!(oos = "Avro format contains a non-usize number of bytes") + })?; + let data = &block[..len]; + block = &block[len..]; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)); + }, + PhysicalType::FixedSizeBinary => { + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + let len = array.size(); + let data = &block[..len]; + block = &block[len..]; + array.push(Some(data)); + }, + PhysicalType::Dictionary(_) => { + let index = util::zigzag_i64(&mut block)? as i32; + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + array.push_valid(index); + }, + _ => todo!(), + }, + }; + Ok(block) +} + +fn skip_item<'a>( + field: &Field, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> PolarsResult<&'a [u8]> { + if field.is_nullable { + let variant = util::zigzag_i64(&mut block)?; + let is_null_first = is_union_null_first(avro_field); + if is_null_first && variant == 0 || !is_null_first && variant != 0 { + return Ok(block); + } + } + match &field.dtype { + ArrowDataType::List(inner) => { + let avro_inner = match avro_field { + AvroSchema::Array(inner) => inner.as_ref(), + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => { + inner.as_ref() + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + loop { + let len = util::zigzag_i64(&mut block)?; + let (len, bytes) = if len < 0 { + // Avro spec: If a block's count is negative, its absolute value is used, + // and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields. + let bytes = util::zigzag_i64(&mut block)?; + + (-len, Some(bytes)) + } else { + (len, None) + }; + + let bytes: Option = bytes + .map(|bytes| { + bytes + .try_into() + .map_err(|_| polars_err!(oos = "Avro block size negative or too large")) + }) + .transpose()?; + + if len == 0 { + break; + } + + if let Some(bytes) = bytes { + block = &block[bytes..]; + } else { + for _ in 0..len { + block = skip_item(inner, avro_inner, block)?; + } + } + } + }, + ArrowDataType::Struct(inner_fields) => { + let fields = match avro_field { + AvroSchema::Record(Record { fields, .. }) => fields, + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Record(Record { fields, .. }), _] + | &[_, AvroSchema::Record(Record { fields, .. })] => fields, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + for (field, avro_field) in inner_fields.iter().zip(fields.iter()) { + block = skip_item(field, &avro_field.schema, block)?; + } + }, + _ => match field.dtype.to_physical_type() { + PhysicalType::Boolean => { + let _ = block[0] == 1; + block = &block[1..]; + }, + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int32 => { + let _ = util::zigzag_i64(&mut block)?; + }, + PrimitiveType::Int64 => { + let _ = util::zigzag_i64(&mut block)?; + }, + PrimitiveType::Float32 => { + block = &block[size_of::()..]; + }, + PrimitiveType::Float64 => { + block = &block[size_of::()..]; + }, + PrimitiveType::MonthDayNano => { + block = &block[12..]; + }, + PrimitiveType::Int128 => { + let avro_inner = match avro_field { + AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field, + AvroSchema::Union(u) => match &u.as_slice() { + &[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let len = match avro_inner { + AvroSchema::Bytes(_) => { + util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + polars_err!( + oos = "Avro format contains a non-usize number of bytes" + ) + })? + }, + AvroSchema::Fixed(b) => b.size, + _ => unreachable!(), + }; + block = &block[len..]; + }, + _ => unreachable!(), + }, + PhysicalType::Utf8 | PhysicalType::Binary => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + polars_err!(oos = "Avro format contains a non-usize number of bytes") + })?; + block = &block[len..]; + }, + PhysicalType::FixedSizeBinary => { + let len = if let ArrowDataType::FixedSizeBinary(len) = &field.dtype { + *len + } else { + unreachable!() + }; + + block = &block[len..]; + }, + PhysicalType::Dictionary(_) => { + let _ = util::zigzag_i64(&mut block)? as i32; + }, + _ => todo!(), + }, + } + Ok(block) +} + +/// Deserializes a [`Block`] assumed to be encoded according to [`AvroField`] into [`RecordBatchT`], +/// using `projection` to ignore `avro_fields`. +/// # Panics +/// `fields`, `avro_fields` and `projection` must have the same length. +pub fn deserialize( + block: &Block, + fields: &ArrowSchema, + avro_fields: &[AvroField], + projection: &[bool], +) -> PolarsResult>> { + assert_eq!(fields.len(), avro_fields.len()); + assert_eq!(fields.len(), projection.len()); + + let rows = block.number_of_rows; + let mut block = block.data.as_ref(); + + // create mutables, one per field + let mut arrays: Vec> = fields + .iter_values() + .zip(avro_fields.iter()) + .zip(projection.iter()) + .map(|((field, avro_field), projection)| { + if *projection { + make_mutable(&field.dtype, Some(&avro_field.schema), rows) + } else { + // just something; we are not going to use it + make_mutable(&ArrowDataType::Int32, None, 0) + } + }) + .collect::>()?; + + // this is _the_ expensive transpose (rows -> columns) + for _ in 0..rows { + let iter = arrays + .iter_mut() + .zip(fields.iter_values()) + .zip(avro_fields.iter()) + .zip(projection.iter()); + + for (((array, field), avro_field), projection) in iter { + block = if *projection { + deserialize_item(array.as_mut(), field.is_nullable, &avro_field.schema, block) + } else { + skip_item(field, &avro_field.schema, block) + }? + } + } + + let projected_schema = fields + .iter_values() + .zip(projection) + .filter_map(|(f, p)| (*p).then_some(f)) + .cloned() + .collect(); + + RecordBatchT::try_new( + rows, + Arc::new(projected_schema), + arrays + .iter_mut() + .zip(projection.iter()) + .filter_map(|x| x.1.then(|| x.0)) + .map(|array| array.as_box()) + .collect(), + ) +} diff --git a/crates/polars-arrow/src/io/avro/read/mod.rs b/crates/polars-arrow/src/io/avro/read/mod.rs new file mode 100644 index 000000000000..9552c632ac46 --- /dev/null +++ b/crates/polars-arrow/src/io/avro/read/mod.rs @@ -0,0 +1,68 @@ +//! APIs to read from Avro format to arrow. +use std::io::Read; + +use avro_schema::file::FileMetadata; +use avro_schema::read::fallible_streaming_iterator::FallibleStreamingIterator; +use avro_schema::read::{BlockStreamingIterator, block_iterator}; +use avro_schema::schema::Field as AvroField; + +mod deserialize; +pub use deserialize::deserialize; +use polars_error::PolarsResult; + +mod nested; +mod schema; +mod util; + +pub use schema::infer_schema; + +use crate::array::Array; +use crate::datatypes::ArrowSchema; +use crate::record_batch::RecordBatchT; + +/// Single threaded, blocking reader of Avro; [`Iterator`] of [`RecordBatchT`]. +pub struct Reader { + iter: BlockStreamingIterator, + avro_fields: Vec, + fields: ArrowSchema, + projection: Vec, +} + +impl Reader { + /// Creates a new [`Reader`]. + pub fn new( + reader: R, + metadata: FileMetadata, + fields: ArrowSchema, + projection: Option>, + ) -> Self { + let projection = projection.unwrap_or_else(|| fields.iter().map(|_| true).collect()); + + Self { + iter: block_iterator(reader, metadata.compression, metadata.marker), + avro_fields: metadata.record.fields, + fields, + projection, + } + } + + /// Deconstructs itself into its internal reader + pub fn into_inner(self) -> R { + self.iter.into_inner() + } +} + +impl Iterator for Reader { + type Item = PolarsResult>>; + + fn next(&mut self) -> Option { + let fields = &self.fields; + let avro_fields = &self.avro_fields; + let projection = &self.projection; + + self.iter + .next() + .transpose() + .map(|maybe_block| deserialize(maybe_block?, fields, avro_fields, projection)) + } +} diff --git a/crates/polars-arrow/src/io/avro/read/nested.rs b/crates/polars-arrow/src/io/avro/read/nested.rs new file mode 100644 index 000000000000..22c96d76a621 --- /dev/null +++ b/crates/polars-arrow/src/io/avro/read/nested.rs @@ -0,0 +1,318 @@ +use polars_error::{PolarsResult, polars_err}; + +use crate::array::*; +use crate::bitmap::*; +use crate::datatypes::*; +use crate::offset::{Offset, Offsets}; + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableListArray { + dtype: ArrowDataType, + offsets: Offsets, + values: Box, + validity: Option, +} + +impl DynMutableListArray { + pub fn new_from(values: Box, dtype: ArrowDataType, capacity: usize) -> Self { + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&dtype); + Self { + dtype, + offsets: Offsets::::with_capacity(capacity), + values, + validity: None, + } + } + + /// The values + pub fn mut_values(&mut self) -> &mut dyn MutableArray { + self.values.as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> PolarsResult<()> { + let total_length = self.values.len(); + let offset = self.offsets.last().to_usize(); + let length = total_length + .checked_sub(offset) + .ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + + self.offsets.try_push(length)?; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.extend_constant(1); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + fn init_validity(&mut self) { + let len = self.offsets.len_proxy(); + + let mut validity = MutableBitmap::new(); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableListArray { + fn len(&self) -> usize { + self.offsets.len_proxy() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + ListArray::new( + self.dtype.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> std::sync::Arc { + ListArray::new( + self.dtype.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} + +#[derive(Debug)] +pub struct FixedItemsUtf8Dictionary { + dtype: ArrowDataType, + keys: MutablePrimitiveArray, + values: Utf8Array, +} + +impl FixedItemsUtf8Dictionary { + pub fn with_capacity(values: Utf8Array, capacity: usize) -> Self { + Self { + dtype: ArrowDataType::Dictionary( + IntegerType::Int32, + Box::new(values.dtype().clone()), + false, + ), + keys: MutablePrimitiveArray::::with_capacity(capacity), + values, + } + } + + pub fn push_valid(&mut self, key: i32) { + self.keys.push(Some(key)) + } + + /// pushes a null value + pub fn push_null(&mut self) { + self.keys.push(None) + } +} + +impl MutableArray for FixedItemsUtf8Dictionary { + fn len(&self) -> usize { + self.keys.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.keys.validity() + } + + fn as_box(&mut self) -> Box { + Box::new( + DictionaryArray::try_new( + self.dtype.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) + } + + fn as_arc(&mut self) -> std::sync::Arc { + std::sync::Arc::new( + DictionaryArray::try_new( + self.dtype.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableStructArray { + dtype: ArrowDataType, + length: usize, + values: Vec>, + validity: Option, +} + +impl DynMutableStructArray { + pub fn new(values: Vec>, dtype: ArrowDataType) -> Self { + Self { + dtype, + length: 0, + values, + validity: None, + } + } + + /// The values + pub fn mut_values(&mut self, field: usize) -> &mut dyn MutableArray { + self.values[field].as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> PolarsResult<()> { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + self.length += 1; + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.values.iter_mut().for_each(|x| x.push_null()); + self.length += 1; + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + fn init_validity(&mut self) { + let len = self.len(); + + let mut validity = MutableBitmap::new(); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableStructArray { + fn len(&self) -> usize { + self.length + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let values = self.values.iter_mut().map(|x| x.as_box()).collect(); + + Box::new(StructArray::new( + self.dtype.clone(), + self.length, + values, + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn as_arc(&mut self) -> std::sync::Arc { + let values = self.values.iter_mut().map(|x| x.as_box()).collect(); + + std::sync::Arc::new(StructArray::new( + self.dtype.clone(), + self.length, + values, + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} diff --git a/crates/polars-arrow/src/io/avro/read/schema.rs b/crates/polars-arrow/src/io/avro/read/schema.rs new file mode 100644 index 000000000000..23dc64bb3ad9 --- /dev/null +++ b/crates/polars-arrow/src/io/avro/read/schema.rs @@ -0,0 +1,161 @@ +use avro_schema::schema::{Enum, Fixed, Record, Schema as AvroSchema}; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::pl_str::PlSmallStr; + +use crate::datatypes::*; + +fn external_props(schema: &AvroSchema) -> Metadata { + let mut props = Metadata::new(); + match schema { + AvroSchema::Record(Record { doc: Some(doc), .. }) + | AvroSchema::Enum(Enum { doc: Some(doc), .. }) => { + props.insert( + PlSmallStr::from_static("avro::doc"), + PlSmallStr::from_str(doc.as_str()), + ); + }, + _ => {}, + } + props +} + +/// Infers an [`ArrowSchema`] from the root [`Record`]. +/// This +pub fn infer_schema(record: &Record) -> PolarsResult { + record + .fields + .iter() + .map(|field| { + let field = schema_to_field( + &field.schema, + Some(&field.name), + external_props(&field.schema), + )?; + + Ok((field.name.clone(), field)) + }) + .collect::>() +} + +fn schema_to_field( + schema: &AvroSchema, + name: Option<&str>, + props: Metadata, +) -> PolarsResult { + let mut nullable = false; + let dtype = match schema { + AvroSchema::Null => ArrowDataType::Null, + AvroSchema::Boolean => ArrowDataType::Boolean, + AvroSchema::Int(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::IntLogical::Date => ArrowDataType::Date32, + avro_schema::schema::IntLogical::Time => { + ArrowDataType::Time32(TimeUnit::Millisecond) + }, + }, + None => ArrowDataType::Int32, + }, + AvroSchema::Long(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::LongLogical::Time => { + ArrowDataType::Time64(TimeUnit::Microsecond) + }, + avro_schema::schema::LongLogical::TimestampMillis => ArrowDataType::Timestamp( + TimeUnit::Millisecond, + Some(PlSmallStr::from_static("00:00")), + ), + avro_schema::schema::LongLogical::TimestampMicros => ArrowDataType::Timestamp( + TimeUnit::Microsecond, + Some(PlSmallStr::from_static("00:00")), + ), + avro_schema::schema::LongLogical::LocalTimestampMillis => { + ArrowDataType::Timestamp(TimeUnit::Millisecond, None) + }, + avro_schema::schema::LongLogical::LocalTimestampMicros => { + ArrowDataType::Timestamp(TimeUnit::Microsecond, None) + }, + }, + None => ArrowDataType::Int64, + }, + AvroSchema::Float => ArrowDataType::Float32, + AvroSchema::Double => ArrowDataType::Float64, + AvroSchema::Bytes(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::BytesLogical::Decimal(precision, scale) => { + ArrowDataType::Decimal(*precision, *scale) + }, + }, + None => ArrowDataType::Binary, + }, + AvroSchema::String(_) => ArrowDataType::Utf8, + AvroSchema::Array(item_schema) => ArrowDataType::List(Box::new(schema_to_field( + item_schema, + Some("item"), // default name for list items + Metadata::default(), + )?)), + AvroSchema::Map(_) => todo!("Avro maps are mapped to MapArrays"), + AvroSchema::Union(schemas) => { + // If there are only two variants and one of them is null, set the other type as the field data type + let has_nullable = schemas.iter().any(|x| x == &AvroSchema::Null); + if has_nullable && schemas.len() == 2 { + nullable = true; + if let Some(schema) = schemas + .iter() + .find(|&schema| !matches!(schema, AvroSchema::Null)) + { + schema_to_field(schema, None, Metadata::default())?.dtype + } else { + polars_bail!(nyi = "Can't read avro union {schema:?}"); + } + } else { + let fields = schemas + .iter() + .map(|s| schema_to_field(s, None, Metadata::default())) + .collect::>>()?; + ArrowDataType::Union(Box::new(UnionType { + fields, + ids: None, + mode: UnionMode::Dense, + })) + } + }, + AvroSchema::Record(Record { fields, .. }) => { + let fields = fields + .iter() + .map(|field| { + let mut props = Metadata::new(); + if let Some(doc) = &field.doc { + props.insert( + PlSmallStr::from_static("avro::doc"), + PlSmallStr::from_str(doc), + ); + } + schema_to_field(&field.schema, Some(&field.name), props) + }) + .collect::>()?; + ArrowDataType::Struct(fields) + }, + AvroSchema::Enum { .. } => { + return Ok(Field::new( + PlSmallStr::from_str(name.unwrap_or_default()), + ArrowDataType::Dictionary(IntegerType::Int32, Box::new(ArrowDataType::Utf8), false), + false, + )); + }, + AvroSchema::Fixed(Fixed { size, logical, .. }) => match logical { + Some(logical) => match logical { + avro_schema::schema::FixedLogical::Decimal(precision, scale) => { + ArrowDataType::Decimal(*precision, *scale) + }, + avro_schema::schema::FixedLogical::Duration => { + ArrowDataType::Interval(IntervalUnit::MonthDayNano) + }, + }, + None => ArrowDataType::FixedSizeBinary(*size), + }, + }; + + let name = name.unwrap_or_default(); + + Ok(Field::new(PlSmallStr::from_str(name), dtype, nullable).with_metadata(props)) +} diff --git a/crates/polars-arrow/src/io/avro/read/util.rs b/crates/polars-arrow/src/io/avro/read/util.rs new file mode 100644 index 000000000000..f097326694ee --- /dev/null +++ b/crates/polars-arrow/src/io/avro/read/util.rs @@ -0,0 +1,18 @@ +use std::io::Read; + +use polars_error::PolarsResult; + +use super::super::avro_decode; + +pub fn zigzag_i64(reader: &mut R) -> PolarsResult { + let z = decode_variable(reader)?; + Ok(if z & 0x1 == 0 { + (z >> 1) as i64 + } else { + !(z >> 1) as i64 + }) +} + +fn decode_variable(reader: &mut R) -> PolarsResult { + avro_decode!(reader) +} diff --git a/crates/polars-arrow/src/io/avro/write/mod.rs b/crates/polars-arrow/src/io/avro/write/mod.rs new file mode 100644 index 000000000000..85cc3aa9d808 --- /dev/null +++ b/crates/polars-arrow/src/io/avro/write/mod.rs @@ -0,0 +1,28 @@ +//! APIs to write to Avro format. +use avro_schema::file::Block; + +mod schema; +pub use schema::to_record; +mod serialize; +pub use serialize::{BoxSerializer, can_serialize, new_serializer}; + +/// consumes a set of [`BoxSerializer`] into an [`Block`]. +/// # Panics +/// Panics iff the number of items in any of the serializers is not equal to the number of rows +/// declared in the `block`. +pub fn serialize(serializers: &mut [BoxSerializer], block: &mut Block) { + let Block { + data, + number_of_rows, + } = block; + + data.clear(); // restart it + + // _the_ transpose (columns -> rows) + for _ in 0..*number_of_rows { + for serializer in &mut *serializers { + let item_data = serializer.next().unwrap(); + data.extend(item_data); + } + } +} diff --git a/crates/polars-arrow/src/io/avro/write/schema.rs b/crates/polars-arrow/src/io/avro/write/schema.rs new file mode 100644 index 000000000000..d023851089ff --- /dev/null +++ b/crates/polars-arrow/src/io/avro/write/schema.rs @@ -0,0 +1,94 @@ +use avro_schema::schema::{ + BytesLogical, Field as AvroField, Fixed, FixedLogical, IntLogical, LongLogical, Record, + Schema as AvroSchema, +}; +use polars_error::{PolarsResult, polars_bail}; + +use crate::datatypes::*; + +/// Converts a [`ArrowSchema`] to an Avro [`Record`]. +pub fn to_record(schema: &ArrowSchema, name: String) -> PolarsResult { + let mut name_counter: i32 = 0; + let fields = schema + .iter_values() + .map(|f| field_to_field(f, &mut name_counter)) + .collect::>()?; + Ok(Record { + name, + namespace: None, + doc: None, + aliases: vec![], + fields, + }) +} + +fn field_to_field(field: &Field, name_counter: &mut i32) -> PolarsResult { + let schema = type_to_schema(field.dtype(), field.is_nullable, name_counter)?; + Ok(AvroField::new(field.name.to_string(), schema)) +} + +fn type_to_schema( + dtype: &ArrowDataType, + is_nullable: bool, + name_counter: &mut i32, +) -> PolarsResult { + Ok(if is_nullable { + AvroSchema::Union(vec![ + AvroSchema::Null, + _type_to_schema(dtype, name_counter)?, + ]) + } else { + _type_to_schema(dtype, name_counter)? + }) +} + +fn _get_field_name(name_counter: &mut i32) -> String { + *name_counter += 1; + format!("r{name_counter}") +} + +fn _type_to_schema(dtype: &ArrowDataType, name_counter: &mut i32) -> PolarsResult { + Ok(match dtype.to_logical_type() { + ArrowDataType::Null => AvroSchema::Null, + ArrowDataType::Boolean => AvroSchema::Boolean, + ArrowDataType::Int32 => AvroSchema::Int(None), + ArrowDataType::Int64 => AvroSchema::Long(None), + ArrowDataType::Float32 => AvroSchema::Float, + ArrowDataType::Float64 => AvroSchema::Double, + ArrowDataType::Binary => AvroSchema::Bytes(None), + ArrowDataType::LargeBinary => AvroSchema::Bytes(None), + ArrowDataType::Utf8 => AvroSchema::String(None), + ArrowDataType::LargeUtf8 => AvroSchema::String(None), + ArrowDataType::LargeList(inner) | ArrowDataType::List(inner) => { + AvroSchema::Array(Box::new(type_to_schema( + &inner.dtype, + inner.is_nullable, + name_counter, + )?)) + }, + ArrowDataType::Struct(fields) => AvroSchema::Record(Record::new( + _get_field_name(name_counter), + fields + .iter() + .map(|f| field_to_field(f, name_counter)) + .collect::>>()?, + )), + ArrowDataType::Date32 => AvroSchema::Int(Some(IntLogical::Date)), + ArrowDataType::Time32(TimeUnit::Millisecond) => AvroSchema::Int(Some(IntLogical::Time)), + ArrowDataType::Time64(TimeUnit::Microsecond) => AvroSchema::Long(Some(LongLogical::Time)), + ArrowDataType::Timestamp(TimeUnit::Millisecond, None) => { + AvroSchema::Long(Some(LongLogical::LocalTimestampMillis)) + }, + ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { + AvroSchema::Long(Some(LongLogical::LocalTimestampMicros)) + }, + ArrowDataType::Interval(IntervalUnit::MonthDayNano) => { + let mut fixed = Fixed::new("", 12); + fixed.logical = Some(FixedLogical::Duration); + AvroSchema::Fixed(fixed) + }, + ArrowDataType::FixedSizeBinary(size) => AvroSchema::Fixed(Fixed::new("", *size)), + ArrowDataType::Decimal(p, s) => AvroSchema::Bytes(Some(BytesLogical::Decimal(*p, *s))), + other => polars_bail!(nyi = "write {other:?} to avro"), + }) +} diff --git a/crates/polars-arrow/src/io/avro/write/serialize.rs b/crates/polars-arrow/src/io/avro/write/serialize.rs new file mode 100644 index 000000000000..ba287521b677 --- /dev/null +++ b/crates/polars-arrow/src/io/avro/write/serialize.rs @@ -0,0 +1,535 @@ +use avro_schema::schema::{Record, Schema as AvroSchema}; +use avro_schema::write::encode; + +use super::super::super::iterator::*; +use crate::array::*; +use crate::bitmap::utils::ZipValidity; +use crate::datatypes::{ArrowDataType, IntervalUnit, PhysicalType, PrimitiveType}; +use crate::offset::Offset; +use crate::types::months_days_ns; + +// Zigzag representation of false and true respectively. +const IS_NULL: u8 = 0; +const IS_VALID: u8 = 2; + +/// A type alias for a boxed [`StreamingIterator`], used to write arrays into avro rows +/// (i.e. a column -> row transposition of types known at run-time) +pub type BoxSerializer<'a> = Box + 'a + Send + Sync>; + +fn utf8_required(array: &Utf8Array) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x.as_bytes()); + }, + vec![], + )) +} + +fn utf8_optional(array: &Utf8Array) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x.as_bytes()); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn binary_required(array: &BinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x); + }, + vec![], + )) +} + +fn binary_optional(array: &BinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn fixed_size_binary_required(array: &FixedSizeBinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + buf.extend_from_slice(x); + }, + vec![], + )) +} + +fn fixed_size_binary_optional(array: &FixedSizeBinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend_from_slice(x); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn list_required<'a, O: Offset>(array: &'a ListArray, schema: &AvroSchema) -> BoxSerializer<'a> { + let mut inner = new_serializer(array.values().as_ref(), schema); + let lengths = array + .offsets() + .buffer() + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64); + + Box::new(BufStreamingIterator::new( + lengths, + move |length, buf| { + encode::zigzag_encode(length, buf).unwrap(); + let mut rows = 0; + while let Some(item) = inner.next() { + buf.extend_from_slice(item); + rows += 1; + if rows == length { + encode::zigzag_encode(0, buf).unwrap(); + break; + } + } + }, + vec![], + )) +} + +fn list_optional<'a, O: Offset>(array: &'a ListArray, schema: &AvroSchema) -> BoxSerializer<'a> { + let mut inner = new_serializer(array.values().as_ref(), schema); + let lengths = array + .offsets() + .buffer() + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64); + let lengths = ZipValidity::new_with_validity(lengths, array.validity()); + + Box::new(BufStreamingIterator::new( + lengths, + move |length, buf| { + if let Some(length) = length { + buf.push(IS_VALID); + encode::zigzag_encode(length, buf).unwrap(); + let mut rows = 0; + while let Some(item) = inner.next() { + buf.extend_from_slice(item); + rows += 1; + if rows == length { + encode::zigzag_encode(0, buf).unwrap(); + break; + } + } + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn struct_required<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> { + let schemas = schema.fields.iter().map(|x| &x.schema); + let mut inner = array + .values() + .iter() + .zip(schemas) + .map(|(x, schema)| new_serializer(x.as_ref(), schema)) + .collect::>(); + + Box::new(BufStreamingIterator::new( + 0..array.len(), + move |_, buf| { + inner + .iter_mut() + .for_each(|item| buf.extend_from_slice(item.next().unwrap())) + }, + vec![], + )) +} + +fn struct_optional<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> { + let schemas = schema.fields.iter().map(|x| &x.schema); + let mut inner = array + .values() + .iter() + .zip(schemas) + .map(|(x, schema)| new_serializer(x.as_ref(), schema)) + .collect::>(); + + let iterator = ZipValidity::new_with_validity(0..array.len(), array.validity()); + + Box::new(BufStreamingIterator::new( + iterator, + move |maybe, buf| { + if maybe.is_some() { + buf.push(IS_VALID); + inner + .iter_mut() + .for_each(|item| buf.extend_from_slice(item.next().unwrap())) + } else { + buf.push(IS_NULL); + // skip the item + inner.iter_mut().for_each(|item| { + let _ = item.next().unwrap(); + }); + } + }, + vec![], + )) +} + +/// Creates a [`StreamingIterator`] trait object that presents items from `array` +/// encoded according to `schema`. +/// # Panic +/// This function panics iff the `dtype` is not supported (use [`can_serialize`] to check) +/// # Implementation +/// This function performs minimal CPU work: it dynamically dispatches based on the schema +/// and arrow type. +pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSerializer<'a> { + let dtype = array.dtype().to_physical_type(); + + match (dtype, schema) { + (PhysicalType::Boolean, AvroSchema::Boolean) => { + let values = array.as_any().downcast_ref::().unwrap(); + Box::new(BufStreamingIterator::new( + values.values_iter(), + |x, buf| { + buf.push(x as u8); + }, + vec![], + )) + }, + (PhysicalType::Boolean, AvroSchema::Union(_)) => { + let values = array.as_any().downcast_ref::().unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.extend_from_slice(&[IS_VALID, x as u8]); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Utf8, AvroSchema::Union(_)) => { + utf8_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeUtf8, AvroSchema::Union(_)) => { + utf8_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Utf8, AvroSchema::String(_)) => { + utf8_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeUtf8, AvroSchema::String(_)) => { + utf8_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Binary, AvroSchema::Union(_)) => { + binary_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeBinary, AvroSchema::Union(_)) => { + binary_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::FixedSizeBinary, AvroSchema::Union(_)) => { + fixed_size_binary_optional(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Binary, AvroSchema::Bytes(_)) => { + binary_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeBinary, AvroSchema::Bytes(_)) => { + binary_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::FixedSizeBinary, AvroSchema::Fixed(_)) => { + fixed_size_binary_required(array.as_any().downcast_ref().unwrap()) + }, + + (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(*x as i64, buf).unwrap(); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Int(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + encode::zigzag_encode(*x as i64, buf).unwrap(); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(*x, buf).unwrap(); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Long(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + encode::zigzag_encode(*x, buf).unwrap(); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend(x.to_le_bytes()) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Float) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + buf.extend_from_slice(&x.to_le_bytes()); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend(x.to_le_bytes()) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Double) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + buf.extend_from_slice(&x.to_le_bytes()); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Bytes(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + let len = ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize; + encode::zigzag_encode((16 - len) as i64, buf).unwrap(); + buf.extend_from_slice(&x.to_be_bytes()[len..]); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + let len = + ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize; + encode::zigzag_encode((16 - len) as i64, buf).unwrap(); + buf.extend_from_slice(&x.to_be_bytes()[len..]); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Fixed(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + interval_write, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + interval_write(x, buf) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + + (PhysicalType::List, AvroSchema::Array(schema)) => { + list_required::(array.as_any().downcast_ref().unwrap(), schema.as_ref()) + }, + (PhysicalType::LargeList, AvroSchema::Array(schema)) => { + list_required::(array.as_any().downcast_ref().unwrap(), schema.as_ref()) + }, + (PhysicalType::List, AvroSchema::Union(inner)) => { + let schema = if let AvroSchema::Array(schema) = &inner[1] { + schema.as_ref() + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + list_optional::(array.as_any().downcast_ref().unwrap(), schema) + }, + (PhysicalType::LargeList, AvroSchema::Union(inner)) => { + let schema = if let AvroSchema::Array(schema) = &inner[1] { + schema.as_ref() + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + list_optional::(array.as_any().downcast_ref().unwrap(), schema) + }, + (PhysicalType::Struct, AvroSchema::Record(inner)) => { + struct_required(array.as_any().downcast_ref().unwrap(), inner) + }, + (PhysicalType::Struct, AvroSchema::Union(inner)) => { + let inner = if let AvroSchema::Record(inner) = &inner[1] { + inner + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + struct_optional(array.as_any().downcast_ref().unwrap(), inner) + }, + (a, b) => todo!("{:?} -> {:?} not supported", a, b), + } +} + +/// Whether [`new_serializer`] supports `dtype`. +pub fn can_serialize(dtype: &ArrowDataType) -> bool { + use ArrowDataType::*; + match dtype.to_logical_type() { + List(inner) => return can_serialize(&inner.dtype), + LargeList(inner) => return can_serialize(&inner.dtype), + Struct(inner) => return inner.iter().all(|inner| can_serialize(&inner.dtype)), + _ => {}, + }; + + matches!( + dtype, + Boolean + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal(_, _) + | Utf8 + | Binary + | FixedSizeBinary(_) + | LargeUtf8 + | LargeBinary + | Interval(IntervalUnit::MonthDayNano) + ) +} + +#[inline] +fn interval_write(x: &months_days_ns, buf: &mut Vec) { + // https://avro.apache.org/docs/current/spec.html#Duration + // 12 bytes, months, days, millis in LE + buf.reserve(12); + buf.extend(x.months().to_le_bytes()); + buf.extend(x.days().to_le_bytes()); + buf.extend(((x.ns() / 1_000_000) as i32).to_le_bytes()); +} diff --git a/crates/polars-arrow/src/io/ipc/append/mod.rs b/crates/polars-arrow/src/io/ipc/append/mod.rs new file mode 100644 index 000000000000..075ce78b4a4f --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/append/mod.rs @@ -0,0 +1,72 @@ +//! A struct adapter of Read+Seek+Write to append to IPC files +// read header and convert to writer information +// seek to first byte of header - 1 +// write new batch +// write new footer +use std::io::{Read, Seek, SeekFrom, Write}; + +use polars_error::{PolarsResult, polars_bail, polars_err}; + +use super::endianness::is_native_little_endian; +use super::read::{self, FileMetadata}; +use super::write::common::DictionaryTracker; +use super::write::writer::*; +use super::write::*; + +impl FileWriter { + /// Creates a new [`FileWriter`] from an existing file, seeking to the last message + /// and appending new messages afterwards. Users call `finish` to write the footer (with both) + /// the existing and appended messages on it. + /// # Error + /// This function errors iff: + /// * the file's endianness is not the native endianness (not yet supported) + /// * the file is not a valid Arrow IPC file + pub fn try_from_file( + mut writer: R, + metadata: FileMetadata, + options: WriteOptions, + ) -> PolarsResult> { + if metadata.ipc_schema.is_little_endian != is_native_little_endian() { + polars_bail!(ComputeError: "appending to a file of a non-native endianness is not supported") + } + + let dictionaries = + read::read_file_dictionaries(&mut writer, &metadata, &mut Default::default())?; + + let last_block = metadata.blocks.last().ok_or_else(|| { + polars_err!(oos = "an Arrow IPC file must have at least 1 message (the schema message)") + })?; + let offset: u64 = last_block + .offset + .try_into() + .map_err(|_| polars_err!(oos = "the block's offset must be a positive number"))?; + let meta_data_length: u64 = last_block + .meta_data_length + .try_into() + .map_err(|_| polars_err!(oos = "the block's offset must be a positive number"))?; + let body_length: u64 = last_block + .body_length + .try_into() + .map_err(|_| polars_err!(oos = "the block's body length must be a positive number"))?; + let offset: u64 = offset + meta_data_length + body_length; + + writer.seek(SeekFrom::Start(offset))?; + + Ok(FileWriter { + writer, + options, + schema: metadata.schema, + ipc_fields: metadata.ipc_schema.fields, + block_offsets: offset as usize, + dictionary_blocks: metadata.dictionaries.unwrap_or_default(), + record_blocks: metadata.blocks, + state: State::Started, // file already exists, so we are ready + dictionary_tracker: DictionaryTracker { + dictionaries, + cannot_replace: true, + }, + encoded_message: Default::default(), + custom_schema_metadata: None, + }) + } +} diff --git a/crates/polars-arrow/src/io/ipc/compression.rs b/crates/polars-arrow/src/io/ipc/compression.rs new file mode 100644 index 000000000000..4c42e776f301 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/compression.rs @@ -0,0 +1,96 @@ +use polars_error::PolarsResult; +#[cfg(feature = "io_ipc_compression")] +use polars_error::to_compute_err; + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn decompress_lz4(input_buf: &[u8], output_buf: &mut [u8]) -> PolarsResult<()> { + use std::io::Read; + let mut decoder = lz4::Decoder::new(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn decompress_zstd(input_buf: &[u8], output_buf: &mut [u8]) -> PolarsResult<()> { + use std::io::Read; + let mut decoder = zstd::Decoder::with_buffer(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn decompress_lz4(_input_buf: &[u8], _output_buf: &mut [u8]) -> PolarsResult<()> { + panic!( + "The crate was compiled without IPC compression. Use `io_ipc_compression` to read compressed IPC." + ); +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn decompress_zstd(_input_buf: &[u8], _output_buf: &mut [u8]) -> PolarsResult<()> { + panic!( + "The crate was compiled without IPC compression. Use `io_ipc_compression` to read compressed IPC." + ); +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn compress_lz4(input_buf: &[u8], output_buf: &mut Vec) -> PolarsResult<()> { + use std::io::Write; + + let mut encoder = lz4::EncoderBuilder::new() + .build(output_buf) + .map_err(to_compute_err)?; + encoder.write_all(input_buf)?; + encoder.finish().1.map_err(|e| e.into()) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn compress_zstd(input_buf: &[u8], output_buf: &mut Vec) -> PolarsResult<()> { + zstd::stream::copy_encode(input_buf, output_buf, 0).map_err(|e| e.into()) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn compress_lz4(_input_buf: &[u8], _output_buf: &[u8]) -> PolarsResult<()> { + panic!( + "The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC." + ) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn compress_zstd(_input_buf: &[u8], _output_buf: &[u8]) -> PolarsResult<()> { + panic!( + "The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC." + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "io_ipc_compression")] + #[test] + #[cfg_attr(miri, ignore)] // ZSTD uses foreign calls that miri does not support + fn round_trip_zstd() { + let data: Vec = (0..200u8).map(|x| x % 10).collect(); + let mut buffer = vec![]; + compress_zstd(&data, &mut buffer).unwrap(); + + let mut result = vec![0; 200]; + decompress_zstd(&buffer, &mut result).unwrap(); + assert_eq!(data, result); + } + + #[cfg(feature = "io_ipc_compression")] + #[test] + #[cfg_attr(miri, ignore)] // LZ4 uses foreign calls that miri does not support + fn round_trip_lz4() { + let data: Vec = (0..200u8).map(|x| x % 10).collect(); + let mut buffer = vec![]; + compress_lz4(&data, &mut buffer).unwrap(); + + let mut result = vec![0; 200]; + decompress_lz4(&buffer, &mut result).unwrap(); + assert_eq!(data, result); + } +} diff --git a/crates/polars-arrow/src/io/ipc/endianness.rs b/crates/polars-arrow/src/io/ipc/endianness.rs new file mode 100644 index 000000000000..61b3f9b7c51c --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/endianness.rs @@ -0,0 +1,11 @@ +#[cfg(target_endian = "little")] +#[inline] +pub fn is_native_little_endian() -> bool { + true +} + +#[cfg(target_endian = "big")] +#[inline] +pub fn is_native_little_endian() -> bool { + false +} diff --git a/crates/polars-arrow/src/io/ipc/mod.rs b/crates/polars-arrow/src/io/ipc/mod.rs new file mode 100644 index 000000000000..2cadac90f90d --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/mod.rs @@ -0,0 +1,64 @@ +//! APIs to read from and write to Arrow's IPC format. +//! +//! Inter-process communication is a method through which different processes +//! share and pass data between them. Its use-cases include parallel +//! processing of chunks of data across different CPU cores, transferring +//! data between different Apache Arrow implementations in other languages and +//! more. Under the hood Apache Arrow uses [FlatBuffers](https://google.github.io/flatbuffers/) +//! as its binary protocol, so every Arrow-centered streaming or serialiation +//! problem that could be solved using FlatBuffers could probably be solved +//! using the more integrated approach that is exposed in this module. +//! +//! [Arrow's IPC protocol](https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc) +//! allows only batch or dictionary columns to be passed +//! around due to its reliance on a pre-defined data scheme. This constraint +//! provides a large performance gain because serialized data will always have a +//! known structutre, i.e. the same fields and datatypes, with the only variance +//! being the number of rows and the actual data inside the Batch. This dramatically +//! increases the deserialization rate, as the bytes in the file or stream are already +//! structured "correctly". +//! +//! Reading and writing IPC messages is done using one of two variants - either +//! [`FileReader`](read::FileReader) <-> [`FileWriter`](struct@write::FileWriter) or +//! [`StreamReader`](read::StreamReader) <-> [`StreamWriter`](struct@write::StreamWriter). +//! These two variants wrap a type `T` that implements [`Read`](std::io::Read), and in +//! the case of the `File` variant it also implements [`Seek`](std::io::Seek). In +//! practice it means that `File`s can be arbitrarily accessed while `Stream`s are only +//! read in certain order - the one they were written in (first in, first out). +mod compression; +mod endianness; + +pub mod append; +pub mod read; +pub mod write; +pub use arrow_format as format; + +const ARROW_MAGIC_V1: [u8; 4] = [b'F', b'E', b'A', b'1']; +const ARROW_MAGIC_V2: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +pub(crate) const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Struct containing `dictionary_id` and nested `IpcField`, allowing users +/// to specify the dictionary ids of the IPC fields when writing to IPC. +#[derive(Debug, Clone, PartialEq, Default)] +pub struct IpcField { + /// optional children + pub fields: Vec, + /// dictionary id + pub dictionary_id: Option, +} + +impl IpcField { + /// Check (recursively) whether the [`IpcField`] contains a dictionary. + pub fn contains_dictionary(&self) -> bool { + self.dictionary_id.is_some() || self.fields.iter().any(|f| f.contains_dictionary()) + } +} + +/// Struct containing fields and whether the file is written in little or big endian. +#[derive(Debug, Clone, PartialEq)] +pub struct IpcSchema { + /// The fields in the schema + pub fields: Vec, + /// Endianness of the file + pub is_little_endian: bool, +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/binary.rs b/crates/polars-arrow/src/io/ipc/read/array/binary.rs new file mode 100644 index 000000000000..d5436502593e --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/binary.rs @@ -0,0 +1,87 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node}; +use crate::array::BinaryArray; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_binary( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult> { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let offsets: Buffer = read_buffer( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| PolarsResult::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + let values = read_buffer( + buffers, + last_offset, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + BinaryArray::::try_new(dtype, offsets.try_into()?, values, validity) +} + +pub fn skip_binary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for binary. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/binview.rs b/crates/polars-arrow/src/io/ipc/read/array/binview.rs new file mode 100644 index 000000000000..4423cdaab6e4 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/binview.rs @@ -0,0 +1,98 @@ +use std::io::{Read, Seek}; +use std::sync::Arc; + +use polars_error::polars_err; + +use super::super::read_basic::*; +use super::*; +use crate::array::{ArrayRef, BinaryViewArrayGeneric, View, ViewType}; +use crate::buffer::Buffer; + +#[allow(clippy::too_many_arguments)] +pub fn read_binview( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + dtype: ArrowDataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + let views: Buffer = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + let n_variadic = variadic_buffer_counts.pop_front().ok_or_else( + || polars_err!(ComputeError: "IPC: unable to fetch the variadic buffers\n\nThe file or stream is corrupted.") + )?; + + let variadic_buffers = (0..n_variadic) + .map(|_| { + read_bytes( + buffers, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + }) + .collect::>>>()?; + + BinaryViewArrayGeneric::::try_new(dtype, views, Arc::from(variadic_buffers), validity) + .map(|arr| arr.boxed()) +} + +pub fn skip_binview( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for utf8. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing views buffer."))?; + + let n_variadic = variadic_buffer_counts.pop_front().ok_or_else( + || polars_err!(ComputeError: "IPC: unable to fetch the variadic buffers\n\nThe file or stream is corrupted.") + )?; + + for _ in 0..n_variadic { + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing variadic buffer"))?; + } + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/boolean.rs b/crates/polars-arrow/src/io/ipc/read/array/boolean.rs new file mode 100644 index 000000000000..6131e2acc8cf --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/boolean.rs @@ -0,0 +1,68 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node}; +use crate::array::BooleanArray; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; + +#[allow(clippy::too_many_arguments)] +pub fn read_boolean( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let values = read_bitmap( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + BooleanArray::try_new(dtype, values, validity) +} + +pub fn skip_boolean( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for boolean. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs new file mode 100644 index 000000000000..5b9910b960a8 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs @@ -0,0 +1,64 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::aliases::PlHashSet; + +use super::super::{Compression, Dictionaries, IpcBuffer, Node}; +use super::{read_primitive, skip_primitive}; +use crate::array::{DictionaryArray, DictionaryKey}; +use crate::datatypes::ArrowDataType; + +#[allow(clippy::too_many_arguments)] +pub fn read_dictionary( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + id: Option, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + compression: Option, + limit: Option, + is_little_endian: bool, + scratch: &mut Vec, +) -> PolarsResult> +where + Vec: TryInto, +{ + let id = if let Some(id) = id { + id + } else { + polars_bail!(oos = "Dictionary has no id."); + }; + let values = dictionaries + .get(&id) + .ok_or_else(|| { + let valid_ids = dictionaries.keys().collect::>(); + polars_err!(ComputeError: + "Dictionary id {id} not found. Valid ids: {valid_ids:?}" + ) + })? + .clone(); + + let keys = read_primitive( + field_nodes, + T::PRIMITIVE.into(), + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + DictionaryArray::::try_new(dtype, keys, values) +} + +pub fn skip_dictionary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult<()> { + skip_primitive(field_nodes, buffers) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs new file mode 100644 index 000000000000..a9af5255d241 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs @@ -0,0 +1,70 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node}; +use crate::array::FixedSizeBinaryArray; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; + +#[allow(clippy::too_many_arguments)] +pub fn read_fixed_size_binary( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let length = length.saturating_mul(FixedSizeBinaryArray::maybe_get_size(&dtype)?); + let values = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + FixedSizeBinaryArray::try_new(dtype, values, validity) +} + +pub fn skip_fixed_size_binary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!(oos = + "IPC: unable to fetch the field for fixed-size binary. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs new file mode 100644 index 000000000000..311fb16fece0 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -0,0 +1,85 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_ensure, polars_err}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::FixedSizeListArray; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::try_get_field_node; + +#[allow(clippy::too_many_arguments)] +pub fn read_fixed_size_list( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + dtype: ArrowDataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let (field, size) = FixedSizeListArray::get_child_and_size(&dtype); + polars_ensure!(size > 0, nyi = "Cannot read zero sized arrays from IPC"); + + let limit = limit.map(|x| x.saturating_mul(size)); + + let values = read( + field_nodes, + variadic_buffer_counts, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + )?; + FixedSizeListArray::try_new(dtype, values.len() / size, values, validity) +} + +pub fn skip_fixed_size_list( + field_nodes: &mut VecDeque, + dtype: &ArrowDataType, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!(oos = + "IPC: unable to fetch the field for fixed-size list. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + + let (field, _) = FixedSizeListArray::get_child_and_size(dtype); + + skip(field_nodes, field.dtype(), buffers, variadic_buffer_counts) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/list.rs b/crates/polars-arrow/src/io/ipc/read/array/list.rs new file mode 100644 index 000000000000..2ef05109c983 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/list.rs @@ -0,0 +1,106 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::ListArray; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_list( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + dtype: ArrowDataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> PolarsResult> +where + Vec: TryInto, +{ + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let offsets = read_buffer::( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| PolarsResult::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + + let field = ListArray::::get_child_field(&dtype); + + let values = read( + field_nodes, + variadic_buffer_counts, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + Some(last_offset), + version, + scratch, + )?; + ListArray::try_new(dtype, offsets.try_into()?, values, validity) +} + +pub fn skip_list( + field_nodes: &mut VecDeque, + dtype: &ArrowDataType, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for list. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; + + let dtype = ListArray::::get_child_type(dtype); + + skip(field_nodes, dtype, buffers, variadic_buffer_counts) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/map.rs b/crates/polars-arrow/src/io/ipc/read/array/map.rs new file mode 100644 index 000000000000..a7ef9cd35225 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/map.rs @@ -0,0 +1,102 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::MapArray; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; + +#[allow(clippy::too_many_arguments)] +pub fn read_map( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + dtype: ArrowDataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let offsets = read_buffer::( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| PolarsResult::Ok(Buffer::::from(vec![0i32])))?; + + let field = MapArray::get_field(&dtype); + + let last_offset: usize = offsets.last().copied().unwrap() as usize; + + let field = read( + field_nodes, + variadic_buffer_counts, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + Some(last_offset), + version, + scratch, + )?; + MapArray::try_new(dtype, offsets.try_into()?, field, validity) +} + +pub fn skip_map( + field_nodes: &mut VecDeque, + dtype: &ArrowDataType, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for map. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; + + let dtype = MapArray::get_field(dtype).dtype(); + + skip(field_nodes, dtype, buffers, variadic_buffer_counts) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/mod.rs b/crates/polars-arrow/src/io/ipc/read/array/mod.rs new file mode 100644 index 000000000000..21c393a2869e --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/mod.rs @@ -0,0 +1,50 @@ +mod primitive; + +use std::collections::VecDeque; + +pub use primitive::*; +mod boolean; +pub use boolean::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod fixed_size_binary; +pub use fixed_size_binary::*; +mod list; +pub use list::*; +mod fixed_size_list; +pub use fixed_size_list::*; +mod struct_; +pub use struct_::*; +mod null; +pub use null::*; +mod dictionary; +pub use dictionary::*; +mod union; +pub use union::*; +mod binview; +mod map; +pub use binview::*; +pub use map::*; +use polars_error::{PolarsResult, *}; + +use super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::datatypes::ArrowDataType; + +fn try_get_field_node<'a>( + field_nodes: &mut VecDeque>, + dtype: &ArrowDataType, +) -> PolarsResult> { + field_nodes.pop_front().ok_or_else(|| { + polars_err!(ComputeError: "IPC: unable to fetch the field for {:?}\n\nThe file or stream is corrupted.", dtype) + }) +} + +fn try_get_array_length(field_node: Node, limit: Option) -> PolarsResult { + let length: usize = field_node + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + Ok(limit.map(|limit| limit.min(length)).unwrap_or(length)) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/null.rs b/crates/polars-arrow/src/io/ipc/read/array/null.rs new file mode 100644 index 000000000000..d0484a9b25a2 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/null.rs @@ -0,0 +1,29 @@ +use std::collections::VecDeque; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::Node; +use crate::array::NullArray; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; + +pub fn read_null( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + limit: Option, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let length = try_get_array_length(field_node, limit)?; + + NullArray::try_new(dtype, length) +} + +pub fn skip_null(field_nodes: &mut VecDeque) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for null. The file or stream is corrupted." + ) + })?; + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs new file mode 100644 index 000000000000..aed0896ab552 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs @@ -0,0 +1,72 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node}; +use crate::array::PrimitiveArray; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; +use crate::types::NativeType; + +#[allow(clippy::too_many_arguments)] +pub fn read_primitive( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult> +where + Vec: TryInto, +{ + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let values = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + PrimitiveArray::::try_new(dtype, values, validity) +} + +pub fn skip_primitive( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for primitive. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs new file mode 100644 index 000000000000..991b7ec3ef67 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs @@ -0,0 +1,93 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use super::try_get_array_length; +use crate::array::StructArray; +use crate::datatypes::ArrowDataType; +use crate::io::ipc::read::array::try_get_field_node; + +#[allow(clippy::too_many_arguments)] +pub fn read_struct( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + dtype: ArrowDataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + let length = try_get_array_length(field_node, limit)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let fields = StructArray::get_fields(&dtype); + + let values = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { + read( + field_nodes, + variadic_buffer_counts, + field, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + }) + .collect::>>()?; + + StructArray::try_new(dtype, length, values, validity) +} + +pub fn skip_struct( + field_nodes: &mut VecDeque, + dtype: &ArrowDataType, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for struct. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + + let fields = StructArray::get_fields(dtype); + + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.dtype(), buffers, variadic_buffer_counts)) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/union.rs b/crates/polars-arrow/src/io/ipc/read/array/union.rs new file mode 100644 index 000000000000..595b57e1d9f2 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/union.rs @@ -0,0 +1,124 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use polars_error::{PolarsResult, polars_err}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::UnionArray; +use crate::datatypes::{ArrowDataType, UnionMode}; +use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; + +#[allow(clippy::too_many_arguments)] +pub fn read_union( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + dtype: ArrowDataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> PolarsResult { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + if version != Version::V5 { + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + }; + + let length = try_get_array_length(field_node, limit)?; + + let types = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + let offsets = if let ArrowDataType::Union(u) = &dtype { + if !u.mode.is_sparse() { + Some(read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?) + } else { + None + } + } else { + unreachable!() + }; + + let fields = UnionArray::get_fields(&dtype); + + let fields = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { + read( + field_nodes, + variadic_buffer_counts, + field, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + None, + version, + scratch, + ) + }) + .collect::>>()?; + + UnionArray::try_new(dtype, types, fields, offsets) +} + +pub fn skip_union( + field_nodes: &mut VecDeque, + dtype: &ArrowDataType, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for struct. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + if let ArrowDataType::Union(u) = dtype { + assert!(u.mode == UnionMode::Dense); + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; + } else { + unreachable!() + }; + + let fields = UnionArray::get_fields(dtype); + + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.dtype(), buffers, variadic_buffer_counts)) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs new file mode 100644 index 000000000000..33f43baf7f1b --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs @@ -0,0 +1,84 @@ +use std::io::{Read, Seek}; + +use polars_error::polars_err; + +use super::super::read_basic::*; +use super::*; +use crate::array::Utf8Array; +use crate::buffer::Buffer; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_utf8( + field_nodes: &mut VecDeque, + dtype: ArrowDataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult> { + let field_node = try_get_field_node(field_nodes, &dtype)?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length = try_get_array_length(field_node, limit)?; + + let offsets: Buffer = read_buffer( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| PolarsResult::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + let values = read_buffer( + buffers, + last_offset, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + Utf8Array::::try_new(dtype, offsets.try_into()?, values, validity) +} + +pub fn skip_utf8( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for utf8. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs new file mode 100644 index 000000000000..d802d8d55803 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -0,0 +1,427 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::aliases::PlHashMap; +use polars_utils::pl_str::PlSmallStr; + +use super::Dictionaries; +use super::deserialize::{read, skip}; +use crate::array::*; +use crate::datatypes::{ArrowDataType, ArrowSchema, Field}; +use crate::io::ipc::read::OutOfSpecKind; +use crate::io::ipc::{IpcField, IpcSchema}; +use crate::record_batch::RecordBatchT; + +#[derive(Debug, Eq, PartialEq, Hash)] +enum ProjectionResult { + Selected(A), + NotSelected(A), +} + +/// An iterator adapter that will return `Some(x)` or `None` +/// # Panics +/// The iterator panics iff the `projection` is not strictly increasing. +struct ProjectionIter<'a, A, I: Iterator> { + projection: &'a [usize], + iter: I, + current_count: usize, + current_projection: usize, +} + +impl<'a, A, I: Iterator> ProjectionIter<'a, A, I> { + /// # Panics + /// iff `projection` is empty + pub fn new(projection: &'a [usize], iter: I) -> Self { + Self { + projection: &projection[1..], + iter, + current_count: 0, + current_projection: projection[0], + } + } +} + +impl> Iterator for ProjectionIter<'_, A, I> { + type Item = ProjectionResult; + + fn next(&mut self) -> Option { + if let Some(item) = self.iter.next() { + let result = if self.current_count == self.current_projection { + if !self.projection.is_empty() { + assert!(self.projection[0] > self.current_projection); + self.current_projection = self.projection[0]; + self.projection = &self.projection[1..]; + } else { + self.current_projection = 0 // a value that most likely already passed + }; + Some(ProjectionResult::Selected(item)) + } else { + Some(ProjectionResult::NotSelected(item)) + }; + self.current_count += 1; + result + } else { + None + } + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +/// Returns a [`RecordBatchT`] from a reader. +/// # Panic +/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) +#[allow(clippy::too_many_arguments)] +pub fn read_record_batch( + batch: arrow_format::ipc::RecordBatchRef, + fields: &ArrowSchema, + ipc_schema: &IpcSchema, + projection: Option<&[usize]>, + limit: Option, + dictionaries: &Dictionaries, + version: arrow_format::ipc::MetadataVersion, + reader: &mut R, + block_offset: u64, + file_size: u64, + scratch: &mut Vec, +) -> PolarsResult>> { + assert_eq!(fields.len(), ipc_schema.fields.len()); + let buffers = batch + .buffers() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBuffers(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageBuffers))?; + let mut variadic_buffer_counts = batch + .variadic_buffer_counts() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .map(|v| v.iter().map(|v| v as usize).collect::>()) + .unwrap_or_else(VecDeque::new); + let mut buffers: VecDeque = buffers.iter().collect(); + + // check that the sum of the sizes of all buffers is <= than the size of the file + let buffers_size = buffers + .iter() + .map(|buffer| { + let buffer_size: u64 = buffer + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + Ok(buffer_size) + }) + .sum::>()?; + if buffers_size > file_size { + return Err(polars_err!( + oos = OutOfSpecKind::InvalidBuffersLength { + buffers_size, + file_size, + } + )); + } + + let field_nodes = batch + .nodes() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageNodes))?; + let mut field_nodes = field_nodes.iter().collect::>(); + + let columns = if let Some(projection) = projection { + let projection = ProjectionIter::new( + projection, + fields.iter_values().zip(ipc_schema.fields.iter()), + ); + + projection + .map(|maybe_field| match maybe_field { + ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read( + &mut field_nodes, + &mut variadic_buffer_counts, + field, + ipc_field, + &mut buffers, + reader, + dictionaries, + block_offset, + ipc_schema.is_little_endian, + batch.compression().map_err(|err| { + polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, + limit, + version, + scratch, + )?)), + ProjectionResult::NotSelected((field, _)) => { + skip( + &mut field_nodes, + &field.dtype, + &mut buffers, + &mut variadic_buffer_counts, + )?; + Ok(None) + }, + }) + .filter_map(|x| x.transpose()) + .collect::>>()? + } else { + fields + .iter_values() + .zip(ipc_schema.fields.iter()) + .map(|(field, ipc_field)| { + read( + &mut field_nodes, + &mut variadic_buffer_counts, + field, + ipc_field, + &mut buffers, + reader, + dictionaries, + block_offset, + ipc_schema.is_little_endian, + batch.compression().map_err(|err| { + polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, + limit, + version, + scratch, + ) + }) + .collect::>>()? + }; + + let length = batch + .length() + .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData)) + .unwrap() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let mut schema: ArrowSchema = fields.iter_values().cloned().collect(); + if let Some(projection) = projection { + schema = schema.try_project_indices(projection).unwrap(); + } + RecordBatchT::try_new(length, Arc::new(schema), columns) +} + +fn find_first_dict_field_d<'a>( + id: i64, + dtype: &'a ArrowDataType, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + use ArrowDataType::*; + match dtype { + Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field), + List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => { + find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0]) + }, + Struct(fields) => { + for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) { + if let Some(f) = find_first_dict_field(id, field, ipc_field) { + return Some(f); + } + } + None + }, + Union(u) => { + for (field, ipc_field) in u.fields.iter().zip(ipc_field.fields.iter()) { + if let Some(f) = find_first_dict_field(id, field, ipc_field) { + return Some(f); + } + } + None + }, + _ => None, + } +} + +fn find_first_dict_field<'a>( + id: i64, + field: &'a Field, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + if let Some(field_id) = ipc_field.dictionary_id { + if id == field_id { + return Some((field, ipc_field)); + } + } + find_first_dict_field_d(id, &field.dtype, ipc_field) +} + +pub(crate) fn first_dict_field<'a>( + id: i64, + fields: &'a ArrowSchema, + ipc_fields: &'a [IpcField], +) -> PolarsResult<(&'a Field, &'a IpcField)> { + assert_eq!(fields.len(), ipc_fields.len()); + for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) { + if let Some(field) = find_first_dict_field(id, field, ipc_field) { + return Ok(field); + } + } + Err(polars_err!( + oos = OutOfSpecKind::InvalidId { requested_id: id } + )) +} + +/// Reads a dictionary from the reader, +/// updating `dictionaries` with the resulting dictionary +#[allow(clippy::too_many_arguments)] +pub fn read_dictionary( + batch: arrow_format::ipc::DictionaryBatchRef, + fields: &ArrowSchema, + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, + reader: &mut R, + block_offset: u64, + file_size: u64, + scratch: &mut Vec, +) -> PolarsResult<()> { + if batch + .is_delta() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferIsDelta(err)))? + { + polars_bail!(ComputeError: "delta dictionary batches not supported") + } + + let id = batch + .id() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferId(err)))?; + let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?; + + let batch = batch + .data() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?; + + let value_type = + if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_logical_type() { + value_type.as_ref() + } else { + polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id }) + }; + + // Make a fake schema for the dictionary batch. + let fields = std::iter::once(( + PlSmallStr::EMPTY, + Field::new(PlSmallStr::EMPTY, value_type.clone(), false), + )) + .collect(); + let ipc_schema = IpcSchema { + fields: vec![first_ipc_field.clone()], + is_little_endian: ipc_schema.is_little_endian, + }; + let chunk = read_record_batch( + batch, + &fields, + &ipc_schema, + None, + None, // we must read the whole dictionary + dictionaries, + arrow_format::ipc::MetadataVersion::V5, + reader, + block_offset, + file_size, + scratch, + )?; + + dictionaries.insert(id, chunk.into_arrays().pop().unwrap()); + + Ok(()) +} + +#[derive(Clone)] +pub struct ProjectionInfo { + pub columns: Vec, + pub map: PlHashMap, + pub schema: ArrowSchema, +} + +pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec) -> ProjectionInfo { + let schema = projection + .iter() + .map(|x| { + let (k, v) = schema.get_at_index(*x).unwrap(); + (k.clone(), v.clone()) + }) + .collect(); + + // todo: find way to do this more efficiently + let mut indices = (0..projection.len()).collect::>(); + indices.sort_unstable_by_key(|&i| &projection[i]); + let map = indices.iter().copied().enumerate().fold( + PlHashMap::default(), + |mut acc, (index, new_index)| { + acc.insert(index, new_index); + acc + }, + ); + projection.sort_unstable(); + + // check unique + if !projection.is_empty() { + let mut previous = projection[0]; + + for &i in &projection[1..] { + assert!( + previous < i, + "The projection on IPC must not contain duplicates" + ); + previous = i; + } + } + + ProjectionInfo { + columns: projection, + map, + schema, + } +} + +pub fn apply_projection( + chunk: RecordBatchT>, + map: &PlHashMap, +) -> RecordBatchT> { + let length = chunk.len(); + + // re-order according to projection + let (schema, arrays) = chunk.into_schema_and_arrays(); + let mut new_schema = schema.as_ref().clone(); + let mut new_arrays = arrays.clone(); + + map.iter().for_each(|(old, new)| { + let (old_name, old_field) = schema.get_at_index(*old).unwrap(); + let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap(); + + *new_name = old_name.clone(); + *new_field = old_field.clone(); + + new_arrays[*new] = arrays[*old].clone(); + }); + + RecordBatchT::new(length, Arc::new(new_schema), new_arrays) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn project_iter() { + let iter = 1..6; + let iter = ProjectionIter::new(&[0, 2, 4], iter); + let result: Vec<_> = iter.collect(); + use ProjectionResult::*; + assert_eq!( + result, + vec![ + Selected(1), + NotSelected(2), + Selected(3), + NotSelected(4), + Selected(5) + ] + ) + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/deserialize.rs b/crates/polars-arrow/src/io/ipc/read/deserialize.rs new file mode 100644 index 000000000000..1a57ac487c70 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/deserialize.rs @@ -0,0 +1,285 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use arrow_format::ipc::{BodyCompressionRef, MetadataVersion}; +use polars_error::PolarsResult; + +use super::array::*; +use super::{Dictionaries, IpcBuffer, Node}; +use crate::array::*; +use crate::datatypes::{ArrowDataType, Field, PhysicalType}; +use crate::io::ipc::IpcField; +use crate::{match_integer_type, with_match_primitive_type_full}; + +#[allow(clippy::too_many_arguments)] +pub fn read( + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + field: &Field, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: MetadataVersion, + scratch: &mut Vec, +) -> PolarsResult> { + use PhysicalType::*; + let dtype = field.dtype.clone(); + + match dtype.to_physical_type() { + Null => read_null(field_nodes, dtype, limit).map(|x| x.boxed()), + Boolean => read_boolean( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + read_primitive::<$T, _>( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()) + }), + Binary => read_binary::( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + LargeBinary => read_binary::( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + FixedSizeBinary => read_fixed_size_binary( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + Utf8 => read_utf8::( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + LargeUtf8 => read_utf8::( + field_nodes, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + List => read_list::( + field_nodes, + variadic_buffer_counts, + dtype, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + LargeList => read_list::( + field_nodes, + variadic_buffer_counts, + dtype, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + FixedSizeList => read_fixed_size_list( + field_nodes, + variadic_buffer_counts, + dtype, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Struct => read_struct( + field_nodes, + variadic_buffer_counts, + dtype, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + read_dictionary::<$T, _>( + field_nodes, + dtype, + ipc_field.dictionary_id, + buffers, + reader, + dictionaries, + block_offset, + compression, + limit, + is_little_endian, + scratch, + ) + .map(|x| x.boxed()) + }) + }, + Union => read_union( + field_nodes, + variadic_buffer_counts, + dtype, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Map => read_map( + field_nodes, + variadic_buffer_counts, + dtype, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Utf8View => read_binview::( + field_nodes, + variadic_buffer_counts, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ), + BinaryView => read_binview::<[u8], _>( + field_nodes, + variadic_buffer_counts, + dtype, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ), + } +} + +pub fn skip( + field_nodes: &mut VecDeque, + dtype: &ArrowDataType, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + use PhysicalType::*; + match dtype.to_physical_type() { + Null => skip_null(field_nodes), + Boolean => skip_boolean(field_nodes, buffers), + Primitive(_) => skip_primitive(field_nodes, buffers), + LargeBinary | Binary => skip_binary(field_nodes, buffers), + LargeUtf8 | Utf8 => skip_utf8(field_nodes, buffers), + FixedSizeBinary => skip_fixed_size_binary(field_nodes, buffers), + List => skip_list::(field_nodes, dtype, buffers, variadic_buffer_counts), + LargeList => skip_list::(field_nodes, dtype, buffers, variadic_buffer_counts), + FixedSizeList => skip_fixed_size_list(field_nodes, dtype, buffers, variadic_buffer_counts), + Struct => skip_struct(field_nodes, dtype, buffers, variadic_buffer_counts), + Dictionary(_) => skip_dictionary(field_nodes, buffers), + Union => skip_union(field_nodes, dtype, buffers, variadic_buffer_counts), + Map => skip_map(field_nodes, dtype, buffers, variadic_buffer_counts), + BinaryView | Utf8View => skip_binview(field_nodes, buffers, variadic_buffer_counts), + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/error.rs b/crates/polars-arrow/src/io/ipc/read/error.rs new file mode 100644 index 000000000000..837388741f1d --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/error.rs @@ -0,0 +1,106 @@ +use std::fmt::{Display, Formatter}; + +/// The different types of errors that reading from IPC can cause +#[derive(Debug)] +#[non_exhaustive] +pub enum OutOfSpecKind { + /// The IPC file does not start with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidHeader, + /// The IPC file does not end with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidFooter, + /// The first 4 bytes of the last 10 bytes is < 0 + NegativeFooterLength, + /// The footer is an invalid flatbuffer + InvalidFlatbufferFooter(arrow_format::ipc::planus::Error), + /// The file's footer does not contain record batches + MissingRecordBatches, + /// The footer's record batches is an invalid flatbuffer + InvalidFlatbufferRecordBatches(arrow_format::ipc::planus::Error), + /// The file's footer does not contain a schema + MissingSchema, + /// The footer's schema is an invalid flatbuffer + InvalidFlatbufferSchema(arrow_format::ipc::planus::Error), + /// The file's schema does not contain fields + MissingFields, + /// The footer's dictionaries is an invalid flatbuffer + InvalidFlatbufferDictionaries(arrow_format::ipc::planus::Error), + /// The block is an invalid flatbuffer + InvalidFlatbufferBlock(arrow_format::ipc::planus::Error), + /// The dictionary message is an invalid flatbuffer + InvalidFlatbufferMessage(arrow_format::ipc::planus::Error), + /// The message does not contain a header + MissingMessageHeader, + /// The message's header is an invalid flatbuffer + InvalidFlatbufferHeader(arrow_format::ipc::planus::Error), + /// Relative positions in the file is < 0 + UnexpectedNegativeInteger, + /// dictionaries can only contain dictionary messages; record batches can only contain records + UnexpectedMessageType, + /// RecordBatch messages do not contain buffers + MissingMessageBuffers, + /// The message's buffers is an invalid flatbuffer + InvalidFlatbufferBuffers(arrow_format::ipc::planus::Error), + /// RecordBatch messages does not contain nodes + MissingMessageNodes, + /// The message's nodes is an invalid flatbuffer + InvalidFlatbufferNodes(arrow_format::ipc::planus::Error), + /// The message's body length is an invalid flatbuffer + InvalidFlatbufferBodyLength(arrow_format::ipc::planus::Error), + /// The message does not contain data + MissingData, + /// The message's data is an invalid flatbuffer + InvalidFlatbufferData(arrow_format::ipc::planus::Error), + /// The version is an invalid flatbuffer + InvalidFlatbufferVersion(arrow_format::ipc::planus::Error), + /// The compression is an invalid flatbuffer + InvalidFlatbufferCompression(arrow_format::ipc::planus::Error), + /// The record contains a number of buffers that does not match the required number by the data type + ExpectedBuffer, + /// A buffer's size is smaller than the required for the number of elements + InvalidBuffer { + /// Declared number of elements in the buffer + length: usize, + /// The name of the `NativeType` + type_name: &'static str, + /// Bytes required for the `length` and `type` + required_number_of_bytes: usize, + /// The size of the IPC buffer + buffer_length: usize, + }, + /// A buffer's size is larger than the file size + InvalidBuffersLength { + /// number of bytes of all buffers in the record + buffers_size: u64, + /// the size of the file + file_size: u64, + }, + /// A bitmap's size is smaller than the required for the number of elements + InvalidBitmap { + /// Declared length of the bitmap + length: usize, + /// Number of bits on the IPC buffer + number_of_bits: usize, + }, + /// The dictionary is_delta is an invalid flatbuffer + InvalidFlatbufferIsDelta(arrow_format::ipc::planus::Error), + /// The dictionary id is an invalid flatbuffer + InvalidFlatbufferId(arrow_format::ipc::planus::Error), + /// Invalid dictionary id + InvalidId { + /// The requested dictionary id + requested_id: i64, + }, + /// Field id is not a dictionary + InvalidIdDataType { + /// The requested dictionary id + requested_id: i64, + }, + /// FixedSizeBinaryArray has invalid datatype. + InvalidDataType, +} + +impl Display for OutOfSpecKind { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs new file mode 100644 index 000000000000..94eef3549f8c --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -0,0 +1,382 @@ +use std::io::{Read, Seek, SeekFrom}; +use std::sync::Arc; + +use arrow_format::ipc::FooterRef; +use arrow_format::ipc::planus::ReadAsRoot; +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; + +use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER}; +use super::common::*; +use super::schema::fb_to_schema; +use super::{Dictionaries, OutOfSpecKind, SendableIterator}; +use crate::array::Array; +use crate::datatypes::{ArrowSchemaRef, Metadata}; +use crate::io::ipc::IpcSchema; +use crate::record_batch::RecordBatchT; + +/// Metadata of an Arrow IPC file, written in the footer of the file. +#[derive(Debug, Clone)] +pub struct FileMetadata { + /// The schema that is read from the file footer + pub schema: ArrowSchemaRef, + + /// The custom metadata that is read from the schema + pub custom_schema_metadata: Option>, + + /// The files' [`IpcSchema`] + pub ipc_schema: IpcSchema, + + /// The blocks in the file + /// + /// A block indicates the regions in the file to read to get data + pub blocks: Vec, + + /// Dictionaries associated to each dict_id + pub(crate) dictionaries: Option>, + + /// The total size of the file in bytes + pub size: u64, +} + +/// Read the row count by summing the length of the of the record batches +pub fn get_row_count(reader: &mut R) -> PolarsResult { + let (_, footer_len) = read_footer_len(reader)?; + let footer = read_footer(reader, footer_len)?; + let (_, blocks) = deserialize_footer_blocks(&footer)?; + + get_row_count_from_blocks(reader, &blocks) +} + +/// Read the row count by summing the length of the of the record batches in blocks +pub fn get_row_count_from_blocks( + reader: &mut R, + blocks: &[arrow_format::ipc::Block], +) -> PolarsResult { + let mut message_scratch: Vec = Default::default(); + + blocks + .iter() + .map(|block| { + let message = get_message_from_block(reader, block, &mut message_scratch)?; + let record_batch = get_record_batch(message)?; + record_batch.length().map_err(|e| e.into()) + }) + .sum() +} + +pub(crate) fn get_dictionary_batch<'a>( + message: &'a arrow_format::ipc::MessageRef, +) -> PolarsResult> { + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + match header { + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => Ok(batch), + _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType), + } +} + +fn read_dictionary_block( + reader: &mut R, + metadata: &FileMetadata, + block: &arrow_format::ipc::Block, + dictionaries: &mut Dictionaries, + message_scratch: &mut Vec, + dictionary_scratch: &mut Vec, +) -> PolarsResult<()> { + let message = get_message_from_block(reader, block, message_scratch)?; + let batch = get_dictionary_batch(&message)?; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + + read_dictionary( + batch, + &metadata.schema, + &metadata.ipc_schema, + dictionaries, + reader, + offset + length, + metadata.size, + dictionary_scratch, + ) +} + +/// Reads all file's dictionaries, if any +/// This function is IO-bounded +pub fn read_file_dictionaries( + reader: &mut R, + metadata: &FileMetadata, + scratch: &mut Vec, +) -> PolarsResult { + let mut dictionaries = Default::default(); + + let blocks = if let Some(blocks) = &metadata.dictionaries { + blocks + } else { + return Ok(PlHashMap::new()); + }; + // use a temporary smaller scratch for the messages + let mut message_scratch = Default::default(); + + for block in blocks { + read_dictionary_block( + reader, + metadata, + block, + &mut dictionaries, + &mut message_scratch, + scratch, + )?; + } + Ok(dictionaries) +} + +pub(super) fn decode_footer_len(footer: [u8; 10], end: u64) -> PolarsResult<(u64, usize)> { + let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); + + if footer[4..] != ARROW_MAGIC_V2 { + if footer[..4] == ARROW_MAGIC_V1 { + polars_bail!(ComputeError: "feather v1 not supported"); + } + return Err(polars_err!(oos = OutOfSpecKind::InvalidFooter)); + } + let footer_len = footer_len + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + Ok((end, footer_len)) +} + +/// Reads the footer's length and magic number in footer +fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10))? + 10; + + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer)?; + decode_footer_len(footer, end) +} + +fn read_footer(reader: &mut R, footer_len: usize) -> PolarsResult> { + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64))?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + reader + .by_ref() + .take(footer_len as u64) + .read_to_end(&mut serialized_footer)?; + Ok(serialized_footer) +} + +fn deserialize_footer_blocks( + footer_data: &[u8], +) -> PolarsResult<(FooterRef, Vec)> { + let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + + let blocks = footer + .record_batches() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingRecordBatches))?; + + let blocks = blocks + .iter() + .map(|block| { + block.try_into().map_err(|err| { + polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) + }) + .collect::>>()?; + Ok((footer, blocks)) +} + +pub(super) fn deserialize_footer_ref(footer_data: &[u8]) -> PolarsResult { + arrow_format::ipc::FooterRef::read_as_root(footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err))) +} + +pub(super) fn deserialize_schema_ref_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult { + footer + .schema() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferSchema(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingSchema)) +} + +/// Get the IPC blocks from the footer containing record batches +pub(super) fn iter_recordbatch_blocks_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult> + '_> { + let blocks = footer + .record_batches() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingRecordBatches))?; + + Ok(blocks.iter().map(|block| { + block + .try_into() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err))) + })) +} + +pub(super) fn iter_dictionary_blocks_from_footer( + footer: arrow_format::ipc::FooterRef, +) -> PolarsResult> + '_>> +{ + let dictionaries = footer + .dictionaries() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))?; + + Ok(dictionaries.map(|dicts| { + dicts.into_iter().map(|block| { + block.try_into().map_err(|err| { + polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) + }) + })) +} + +pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { + let footer = deserialize_footer_ref(footer_data)?; + let blocks = iter_recordbatch_blocks_from_footer(footer)?.collect::>>()?; + let dictionaries = iter_dictionary_blocks_from_footer(footer)? + .map(|dicts| dicts.collect::>>()) + .transpose()?; + let ipc_schema = deserialize_schema_ref_from_footer(footer)?; + let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(ipc_schema)?; + + Ok(FileMetadata { + schema: Arc::new(schema), + ipc_schema, + blocks, + dictionaries, + size, + custom_schema_metadata: custom_schema_metadata.map(Arc::new), + }) +} + +/// Read the Arrow IPC file's metadata +pub fn read_file_metadata(reader: &mut R) -> PolarsResult { + let start = reader.stream_position()?; + let (end, footer_len) = read_footer_len(reader)?; + let serialized_footer = read_footer(reader, footer_len)?; + deserialize_footer(&serialized_footer, end - start) +} + +pub(crate) fn get_record_batch( + message: arrow_format::ipc::MessageRef, +) -> PolarsResult { + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => Ok(batch), + _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType), + } +} + +fn get_message_from_block_offset<'a, R: Read + Seek>( + reader: &mut R, + offset: u64, + message_scratch: &'a mut Vec, +) -> PolarsResult> { + // read length + reader.seek(SeekFrom::Start(offset))?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf)?; + if meta_buf == CONTINUATION_MARKER { + // continuation marker encountered, read message next + reader.read_exact(&mut meta_buf)?; + } + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + + message_scratch.clear(); + message_scratch.try_reserve(meta_len)?; + reader + .by_ref() + .take(meta_len as u64) + .read_to_end(message_scratch)?; + + arrow_format::ipc::MessageRef::read_as_root(message_scratch) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err))) +} + +pub(super) fn get_message_from_block<'a, R: Read + Seek>( + reader: &mut R, + block: &arrow_format::ipc::Block, + message_scratch: &'a mut Vec, +) -> PolarsResult> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + get_message_from_block_offset(reader, offset, message_scratch) +} + +/// Reads the record batch at position `index` from the reader. +/// +/// This function is useful for random access to the file. For example, if +/// you have indexed the file somewhere else, this allows pruning +/// certain parts of the file. +/// # Panics +/// This function panics iff `index >= metadata.blocks.len()` +#[allow(clippy::too_many_arguments)] +pub fn read_batch( + reader: &mut R, + dictionaries: &Dictionaries, + metadata: &FileMetadata, + projection: Option<&[usize]>, + limit: Option, + index: usize, + message_scratch: &mut Vec, + data_scratch: &mut Vec, +) -> PolarsResult>> { + let block = metadata.blocks[index]; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let message = get_message_from_block_offset(reader, offset, message_scratch)?; + let batch = get_record_batch(message)?; + + read_record_batch( + batch, + &metadata.schema, + &metadata.ipc_schema, + projection, + limit, + dictionaries, + message + .version() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferVersion(err)))?, + reader, + offset + length, + metadata.size, + data_scratch, + ) +} diff --git a/crates/polars-arrow/src/io/ipc/read/flight.rs b/crates/polars-arrow/src/io/ipc/read/flight.rs new file mode 100644 index 000000000000..257037ed8e40 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/flight.rs @@ -0,0 +1,457 @@ +use std::io::SeekFrom; +use std::pin::Pin; +use std::sync::Arc; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef}; +use futures::{Stream, StreamExt}; +use polars_error::{PolarsResult, polars_bail, polars_err}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; + +use crate::datatypes::ArrowSchema; +use crate::io::ipc::read::common::read_record_batch; +use crate::io::ipc::read::file::{ + decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer, + iter_recordbatch_blocks_from_footer, +}; +use crate::io::ipc::read::schema::deserialize_stream_metadata; +use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata}; +use crate::io::ipc::write::common::EncodedData; +use crate::mmap::{mmap_dictionary_from_batch, mmap_record}; +use crate::record_batch::RecordBatch; + +async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>( + reader: &mut R, + block: &arrow_format::ipc::Block, + scratch: &'a mut Vec, +) -> PolarsResult> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + reader.seek(SeekFrom::Start(offset)).await?; + read_ipc_message(reader, scratch).await +} + +/// Read an encapsulated IPC Message from the reader +async fn read_ipc_message<'a, R: AsyncRead + Unpin>( + reader: &mut R, + scratch: &'a mut Vec, +) -> PolarsResult> { + let mut message_size: [u8; 4] = [0; 4]; + + reader.read_exact(&mut message_size).await?; + if message_size == crate::io::ipc::CONTINUATION_MARKER { + reader.read_exact(&mut message_size).await?; + }; + let message_length = i32::from_le_bytes(message_size); + + let message_length: usize = message_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + scratch.clear(); + scratch.try_reserve(message_length)?; + reader + .take(message_length as u64) + .read_to_end(scratch) + .await?; + + arrow_format::ipc::MessageRef::read_as_root(scratch) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err))) +} + +async fn read_footer_len( + reader: &mut R, +) -> PolarsResult<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10)).await? + 10; + + let mut footer: [u8; 10] = [0; 10]; + reader.read_exact(&mut footer).await?; + + decode_footer_len(footer, end) +} + +async fn read_footer( + reader: &mut R, + footer_len: usize, +) -> PolarsResult> { + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64)).await?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + + reader + .take(footer_len as u64) + .read_to_end(&mut serialized_footer) + .await?; + Ok(serialized_footer) +} + +fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData { + // Turn the IPC schema into an encapsulated message + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + // Assumed the conversion is infallible. + header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new( + schema.try_into().unwrap(), + ))), + body_length: 0, + custom_metadata: None, // todo: allow writing custom metadata + }; + let mut builder = arrow_format::ipc::planus::Builder::new(); + let header = builder.finish(&message, None).to_vec(); + + // Use `EncodedData` directly instead of `FlightData`. In FlightData we would only use + // `data_header` and `data_body`. + EncodedData { + ipc_message: header, + arrow_data: vec![], + } +} + +async fn block_to_raw_message<'a, R>( + reader: &mut R, + block: &arrow_format::ipc::Block, + encoded_data: &mut EncodedData, +) -> PolarsResult<()> +where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, +{ + debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty()); + let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?; + + let block_length: u64 = message + .body_length() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + reader + .take(block_length) + .read_to_end(&mut encoded_data.arrow_data) + .await?; + + Ok(()) +} + +pub async fn into_flight_stream( + reader: &mut R, +) -> PolarsResult> + '_> { + Ok(async_stream::try_stream! { + let (_end, len) = read_footer_len(reader).await?; + let footer_data = read_footer(reader, len).await?; + let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + let data_blocks = iter_recordbatch_blocks_from_footer(footer)?; + let dict_blocks = iter_dictionary_blocks_from_footer(footer)?; + + let schema_ref = deserialize_schema_ref_from_footer(footer)?; + let schema = schema_to_raw_message(schema_ref); + + yield schema; + + if let Some(dict_blocks_iter) = dict_blocks { + for d in dict_blocks_iter { + let mut ed: EncodedData = Default::default(); + block_to_raw_message(reader, &d?, &mut ed).await?; + yield ed + } + }; + + for d in data_blocks { + let mut ed: EncodedData = Default::default(); + block_to_raw_message(reader, &d?, &mut ed).await?; + yield ed + } + }) +} + +pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> { + footer: Option<*const FooterRef<'static>>, + footer_data: Vec, + dict_blocks: Option>>>, + data_blocks: Option>>>, + reader: &'a mut R, +} + +impl Drop for FlightStreamProducer<'_, R> { + fn drop(&mut self) { + if let Some(p) = self.footer { + unsafe { + let _ = Box::from_raw(p as *mut FooterRef<'static>); + } + } + } +} + +unsafe impl Send for FlightStreamProducer<'_, R> {} + +impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { + pub async fn new(reader: &'a mut R) -> PolarsResult>> { + let (_end, len) = read_footer_len(reader).await?; + let footer_data = read_footer(reader, len).await?; + + Ok(Box::pin(Self { + footer: None, + footer_data, + dict_blocks: None, + data_blocks: None, + reader, + })) + } + + pub fn init(self: &mut Pin>) -> PolarsResult<()> { + let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + + let footer = Box::new(footer); + + #[allow(clippy::unnecessary_cast)] + let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>; + + self.footer = Some(ptr); + let footer = &unsafe { **self.footer.as_ref().unwrap() }; + + self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?) + as Box>); + self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)? + .map(|i| Box::new(i) as Box>); + + Ok(()) + } + + pub fn get_schema(self: &Pin>) -> PolarsResult { + let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") }; + + let schema_ref = deserialize_schema_ref_from_footer(*footer)?; + let schema = schema_to_raw_message(schema_ref); + + Ok(schema) + } + + pub async fn next_dict( + self: &mut Pin>, + encoded_data: &mut EncodedData, + ) -> PolarsResult> { + assert!(self.data_blocks.is_some(), "init must be called first"); + encoded_data.ipc_message.clear(); + encoded_data.arrow_data.clear(); + + if let Some(iter) = &mut self.dict_blocks { + let Some(value) = iter.next() else { + return Ok(None); + }; + let block = value?; + + block_to_raw_message(&mut self.reader, &block, encoded_data).await?; + Ok(Some(())) + } else { + Ok(None) + } + } + + pub async fn next_data( + self: &mut Pin>, + encoded_data: &mut EncodedData, + ) -> PolarsResult> { + encoded_data.ipc_message.clear(); + encoded_data.arrow_data.clear(); + + let iter = self + .data_blocks + .as_mut() + .expect("init must be called first"); + let Some(value) = iter.next() else { + return Ok(None); + }; + let block = value?; + + block_to_raw_message(&mut self.reader, &block, encoded_data).await?; + Ok(Some(())) + } +} + +pub struct FlightConsumer { + dictionaries: Dictionaries, + md: StreamMetadata, + scratch: Vec, +} + +impl FlightConsumer { + pub fn new(first: EncodedData) -> PolarsResult { + let md = deserialize_stream_metadata(&first.ipc_message)?; + Ok(Self { + dictionaries: Default::default(), + md, + scratch: vec![], + }) + } + + pub fn schema(&self) -> &ArrowSchema { + &self.md.schema + } + + pub fn consume(&mut self, msg: EncodedData) -> PolarsResult> { + // Parse the header + let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + + // Either append to the dictionaries and return None or return Some(ArrowChunk) + match header { + MessageHeaderRef::Schema(_) => { + polars_bail!(ComputeError: "Unexpected schema message while parsing Stream"); + }, + // Add to dictionary state and continue iteration + MessageHeaderRef::DictionaryBatch(batch) => unsafe { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + mmap_dictionary_from_batch( + &self.md.schema, + &self.md.ipc_schema.fields, + &arrow_data, + batch, + &mut self.dictionaries, + 0, + ) + .map(|_| None) + }, + // Return Batch + MessageHeaderRef::RecordBatch(batch) => { + if batch.compression()?.is_some() { + let data_size = msg.arrow_data.len() as u64; + let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice()); + read_record_batch( + batch, + &self.md.schema, + &self.md.ipc_schema, + None, + None, + &self.dictionaries, + self.md.version, + &mut reader, + 0, + data_size, + &mut self.scratch, + ) + .map(Some) + } else { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + unsafe { + mmap_record( + &self.md.schema, + &self.md.ipc_schema.fields, + arrow_data.clone(), + batch, + 0, + &self.dictionaries, + ) + .map(Some) + } + } + }, + _ => unimplemented!(), + } + } +} + +pub struct FlightstreamConsumer> + Unpin> { + inner: FlightConsumer, + stream: S, +} + +impl> + Unpin> FlightstreamConsumer { + pub async fn new(mut stream: S) -> PolarsResult { + let Some(first) = stream.next().await else { + polars_bail!(ComputeError: "expected the schema") + }; + let first = first?; + + Ok(FlightstreamConsumer { + inner: FlightConsumer::new(first)?, + stream, + }) + } + + pub async fn next_batch(&mut self) -> PolarsResult> { + while let Some(msg) = self.stream.next().await { + let msg = msg?; + let option_recordbatch = self.inner.consume(msg)?; + if option_recordbatch.is_some() { + return Ok(option_recordbatch); + } + } + Ok(None) + } +} + +#[cfg(test)] +mod test { + use std::path::{Path, PathBuf}; + + use tokio::fs::File; + + use super::*; + use crate::record_batch::RecordBatch; + + fn get_file_path() -> PathBuf { + let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc") + } + + fn read_file(path: &Path) -> RecordBatch { + let mut file = std::fs::File::open(path).unwrap(); + let md = crate::io::ipc::read::read_file_metadata(&mut file).unwrap(); + let mut ipc_reader = crate::io::ipc::read::FileReader::new(&mut file, md, None, None); + ipc_reader.next().unwrap().unwrap() + } + + #[tokio::test] + async fn test_file_flight_simple() { + let path = &get_file_path(); + let mut file = tokio::fs::File::open(path).await.unwrap(); + let stream = into_flight_stream(&mut file).await.unwrap(); + + let mut c = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap(); + let b = c.next_batch().await.unwrap().unwrap(); + + assert_eq!(b, read_file(path)); + } + + #[tokio::test] + async fn test_file_flight_amortized() { + let path = &get_file_path(); + let mut file = File::open(path).await.unwrap(); + let mut p = FlightStreamProducer::new(&mut file).await.unwrap(); + p.init().unwrap(); + + let mut batches = vec![]; + + let schema = p.get_schema().unwrap(); + batches.push(schema); + + let mut ed = EncodedData::default(); + if p.next_dict(&mut ed).await.unwrap().is_some() { + batches.push(ed); + } + + let mut ed = EncodedData::default(); + p.next_data(&mut ed).await.unwrap(); + batches.push(ed); + + let mut c = + FlightstreamConsumer::new(Box::pin(futures::stream::iter(batches.into_iter().map(Ok)))) + .await + .unwrap(); + let b = c.next_batch().await.unwrap().unwrap(); + + assert_eq!(b, read_file(path)); + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/mod.rs b/crates/polars-arrow/src/io/ipc/read/mod.rs new file mode 100644 index 000000000000..da32856a2775 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/mod.rs @@ -0,0 +1,46 @@ +//! APIs to read Arrow's IPC format. +//! +//! The two important structs here are the [`FileReader`](reader::FileReader), +//! which provides arbitrary access to any of its messages, and the +//! [`StreamReader`](stream::StreamReader), which only supports reading +//! data in the order it was written in. +use crate::array::Array; + +mod array; +mod common; +mod deserialize; +mod error; +pub(crate) mod file; +#[cfg(feature = "io_flight")] +mod flight; +mod read_basic; +mod reader; +mod schema; +mod stream; + +pub(crate) use common::first_dict_field; +pub use common::{ProjectionInfo, prepare_projection}; +pub use error::OutOfSpecKind; +pub use file::{ + FileMetadata, deserialize_footer, get_row_count, get_row_count_from_blocks, read_batch, + read_file_dictionaries, read_file_metadata, +}; +use polars_utils::aliases::PlHashMap; +pub use reader::FileReader; +pub use schema::deserialize_schema; +pub use stream::{StreamMetadata, StreamReader, StreamState, read_stream_metadata}; + +/// how dictionaries are tracked in this crate +pub type Dictionaries = PlHashMap>; + +pub(crate) type Node<'a> = arrow_format::ipc::FieldNodeRef<'a>; +pub(crate) type IpcBuffer<'a> = arrow_format::ipc::BufferRef<'a>; +pub(crate) type Compression<'a> = arrow_format::ipc::BodyCompressionRef<'a>; +pub(crate) type Version = arrow_format::ipc::MetadataVersion; + +#[cfg(feature = "io_flight")] +pub use flight::*; + +pub trait SendableIterator: Send + Iterator {} + +impl SendableIterator for T {} diff --git a/crates/polars-arrow/src/io/ipc/read/read_basic.rs b/crates/polars-arrow/src/io/ipc/read/read_basic.rs new file mode 100644 index 000000000000..1679067d749b --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/read_basic.rs @@ -0,0 +1,360 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek, SeekFrom}; + +use polars_error::{PolarsResult, polars_bail, polars_err}; + +use super::super::compression; +use super::super::endianness::is_native_little_endian; +use super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::types::NativeType; + +fn read_swapped( + reader: &mut R, + length: usize, + buffer: &mut Vec, + is_little_endian: bool, +) -> PolarsResult<()> { + // slow case where we must reverse bits + let mut slice = vec![0u8; length * size_of::()]; + reader.read_exact(&mut slice)?; + + let chunks = slice.chunks_exact(size_of::()); + if !is_little_endian { + // machine is little endian, file is big endian + buffer + .as_mut_slice() + .iter_mut() + .zip(chunks) + .try_for_each(|(slot, chunk)| { + let a: T::Bytes = match chunk.try_into() { + Ok(a) => a, + Err(_) => unreachable!(), + }; + *slot = T::from_be_bytes(a); + PolarsResult::Ok(()) + })?; + } else { + // machine is big endian, file is little endian + polars_bail!(ComputeError: + "Reading little endian files from big endian machines", + ) + } + Ok(()) +} + +fn read_uncompressed_bytes( + reader: &mut R, + buffer_length: usize, + is_little_endian: bool, +) -> PolarsResult> { + if is_native_little_endian() == is_little_endian { + let mut buffer = Vec::with_capacity(buffer_length); + let _ = reader + .take(buffer_length as u64) + .read_to_end(&mut buffer) + .unwrap(); + Ok(buffer) + } else { + unreachable!() + } +} + +fn read_uncompressed_buffer( + reader: &mut R, + buffer_length: usize, + length: usize, + is_little_endian: bool, +) -> PolarsResult> { + let required_number_of_bytes = length.saturating_mul(size_of::()); + if required_number_of_bytes > buffer_length { + polars_bail!( + oos = OutOfSpecKind::InvalidBuffer { + length, + type_name: std::any::type_name::(), + required_number_of_bytes, + buffer_length, + } + ); + } + + // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + + if is_native_little_endian() == is_little_endian { + // fast case where we can just copy the contents + let slice = bytemuck::cast_slice_mut(&mut buffer); + reader.read_exact(slice)?; + } else { + read_swapped(reader, length, &mut buffer, is_little_endian)?; + } + Ok(buffer) +} + +fn read_compressed_buffer( + reader: &mut R, + buffer_length: usize, + output_length: Option, + is_little_endian: bool, + compression: Compression, + scratch: &mut Vec, +) -> PolarsResult> { + if output_length == Some(0) { + return Ok(vec![]); + } + + if is_little_endian != is_native_little_endian() { + polars_bail!(ComputeError: + "Reading compressed and big endian IPC".to_string(), + ) + } + + // decompress first + scratch.clear(); + scratch.try_reserve(buffer_length)?; + reader + .by_ref() + .take(buffer_length as u64) + .read_to_end(scratch)?; + + let length = output_length + .unwrap_or_else(|| i64::from_le_bytes(scratch[..8].try_into().unwrap()) as usize); + + // It is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + + let out_slice = bytemuck::cast_slice_mut(&mut buffer); + + let compression = compression + .codec() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { + arrow_format::ipc::CompressionType::Lz4Frame => { + compression::decompress_lz4(&scratch[8..], out_slice)?; + }, + arrow_format::ipc::CompressionType::Zstd => { + compression::decompress_zstd(&scratch[8..], out_slice)?; + }, + } + Ok(buffer) +} + +fn read_compressed_bytes( + reader: &mut R, + buffer_length: usize, + is_little_endian: bool, + compression: Compression, + scratch: &mut Vec, +) -> PolarsResult> { + read_compressed_buffer::( + reader, + buffer_length, + None, + is_little_endian, + compression, + scratch, + ) +} + +pub fn read_bytes( + buf: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + scratch: &mut Vec, +) -> PolarsResult> { + let buf = buf + .pop_front() + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let buffer_length: usize = buf + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + if let Some(compression) = compression { + Ok(read_compressed_bytes( + reader, + buffer_length, + is_little_endian, + compression, + scratch, + )? + .into()) + } else { + Ok(read_uncompressed_bytes(reader, buffer_length, is_little_endian)?.into()) + } +} + +pub fn read_buffer( + buf: &mut VecDeque, + length: usize, // in slots + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + scratch: &mut Vec, +) -> PolarsResult> { + let buf = buf + .pop_front() + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let buffer_length: usize = buf + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + if let Some(compression) = compression { + Ok(read_compressed_buffer( + reader, + buffer_length, + Some(length), + is_little_endian, + compression, + scratch, + )? + .into()) + } else { + Ok(read_uncompressed_buffer(reader, buffer_length, length, is_little_endian)?.into()) + } +} + +fn read_uncompressed_bitmap( + length: usize, + bytes: usize, + reader: &mut R, +) -> PolarsResult> { + if length > bytes * 8 { + polars_bail!( + oos = OutOfSpecKind::InvalidBitmap { + length, + number_of_bits: bytes * 8, + } + ) + } + + let mut buffer = vec![]; + buffer.try_reserve(bytes)?; + reader + .by_ref() + .take(bytes as u64) + .read_to_end(&mut buffer)?; + + Ok(buffer) +} + +fn read_compressed_bitmap( + length: usize, + bytes: usize, + compression: Compression, + reader: &mut R, + scratch: &mut Vec, +) -> PolarsResult> { + let mut buffer = vec![0; length.div_ceil(8)]; + + scratch.clear(); + scratch.try_reserve(bytes)?; + reader.by_ref().take(bytes as u64).read_to_end(scratch)?; + + let compression = compression + .codec() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { + arrow_format::ipc::CompressionType::Lz4Frame => { + compression::decompress_lz4(&scratch[8..], &mut buffer)?; + }, + arrow_format::ipc::CompressionType::Zstd => { + compression::decompress_zstd(&scratch[8..], &mut buffer)?; + }, + } + Ok(buffer) +} + +pub fn read_bitmap( + buf: &mut VecDeque, + length: usize, + reader: &mut R, + block_offset: u64, + _: bool, + compression: Option, + scratch: &mut Vec, +) -> PolarsResult { + let buf = buf + .pop_front() + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let bytes: usize = buf + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + let buffer = if let Some(compression) = compression { + read_compressed_bitmap(length, bytes, compression, reader, scratch) + } else { + read_uncompressed_bitmap(length, bytes, reader) + }?; + + Bitmap::try_new(buffer, length) +} + +#[allow(clippy::too_many_arguments)] +pub fn read_validity( + buffers: &mut VecDeque, + field_node: Node, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> PolarsResult> { + let length: usize = field_node + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + Ok(if field_node.null_count() > 0 { + Some(read_bitmap( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?) + } else { + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::ExpectedBuffer))?; + None + }) +} diff --git a/crates/polars-arrow/src/io/ipc/read/reader.rs b/crates/polars-arrow/src/io/ipc/read/reader.rs new file mode 100644 index 000000000000..ebfdd6bc96e5 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/reader.rs @@ -0,0 +1,203 @@ +use std::io::{Read, Seek}; + +use polars_error::PolarsResult; + +use super::common::*; +use super::file::{get_message_from_block, get_record_batch}; +use super::{Dictionaries, FileMetadata, read_batch, read_file_dictionaries}; +use crate::array::Array; +use crate::datatypes::ArrowSchema; +use crate::record_batch::RecordBatchT; + +/// An iterator of [`RecordBatchT`]s from an Arrow IPC file. +pub struct FileReader { + reader: R, + metadata: FileMetadata, + // the dictionaries are going to be read + dictionaries: Option, + current_block: usize, + projection: Option, + remaining: usize, + data_scratch: Vec, + message_scratch: Vec, +} + +impl FileReader { + /// Creates a new [`FileReader`]. Use `projection` to only take certain columns. + /// # Panic + /// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) + pub fn new( + reader: R, + metadata: FileMetadata, + projection: Option>, + limit: Option, + ) -> Self { + let projection = + projection.map(|projection| prepare_projection(&metadata.schema, projection)); + Self { + reader, + metadata, + dictionaries: Default::default(), + projection, + remaining: limit.unwrap_or(usize::MAX), + current_block: 0, + data_scratch: Default::default(), + message_scratch: Default::default(), + } + } + + /// Creates a new [`FileReader`]. Use `projection` to only take certain columns. + /// # Panic + /// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) + pub fn new_with_projection_info( + reader: R, + metadata: FileMetadata, + projection: Option, + limit: Option, + ) -> Self { + Self { + reader, + metadata, + dictionaries: Default::default(), + projection, + remaining: limit.unwrap_or(usize::MAX), + current_block: 0, + data_scratch: Default::default(), + message_scratch: Default::default(), + } + } + + /// Return the schema of the file + pub fn schema(&self) -> &ArrowSchema { + self.projection + .as_ref() + .map(|x| &x.schema) + .unwrap_or(&self.metadata.schema) + } + + /// Returns the [`FileMetadata`] + pub fn metadata(&self) -> &FileMetadata { + &self.metadata + } + + /// Consumes this FileReader, returning the underlying reader + pub fn into_inner(self) -> R { + self.reader + } + + pub fn set_current_block(&mut self, idx: usize) { + self.current_block = idx; + } + + pub fn get_current_block(&self) -> usize { + self.current_block + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn take_projection_info(&mut self) -> Option { + std::mem::take(&mut self.projection) + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn take_scratches(&mut self) -> (Vec, Vec) { + ( + std::mem::take(&mut self.data_scratch), + std::mem::take(&mut self.message_scratch), + ) + } + + /// Set the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn set_scratches(&mut self, scratches: (Vec, Vec)) { + (self.data_scratch, self.message_scratch) = scratches; + } + + fn read_dictionaries(&mut self) -> PolarsResult<()> { + if self.dictionaries.is_none() { + self.dictionaries = Some(read_file_dictionaries( + &mut self.reader, + &self.metadata, + &mut self.data_scratch, + )?); + }; + Ok(()) + } + + /// Skip over blocks until we have seen at most `offset` rows, returning how many rows we are + /// still too see. + /// + /// This will never go over the `offset`. Meaning that if the `offset < current_block.len()`, + /// the block will not be skipped. + pub fn skip_blocks_till_limit(&mut self, offset: u64) -> PolarsResult { + let mut remaining_offset = offset; + + for (i, block) in self.metadata.blocks.iter().enumerate() { + let message = + get_message_from_block(&mut self.reader, block, &mut self.message_scratch)?; + let record_batch = get_record_batch(message)?; + + let length = record_batch.length()?; + let length = length as u64; + + if length > remaining_offset { + self.current_block = i; + return Ok(remaining_offset); + } + + remaining_offset -= length; + } + + self.current_block = self.metadata.blocks.len(); + Ok(remaining_offset) + } + + pub fn next_record_batch( + &mut self, + ) -> Option>> { + let block = self.metadata.blocks.get(self.current_block)?; + self.current_block += 1; + let message = get_message_from_block(&mut self.reader, block, &mut self.message_scratch); + Some(message.and_then(|m| get_record_batch(m))) + } +} + +impl Iterator for FileReader { + type Item = PolarsResult>>; + + fn next(&mut self) -> Option { + // get current block + if self.current_block == self.metadata.blocks.len() { + return None; + } + + match self.read_dictionaries() { + Ok(_) => {}, + Err(e) => return Some(Err(e)), + }; + + let block = self.current_block; + self.current_block += 1; + + let chunk = read_batch( + &mut self.reader, + self.dictionaries.as_ref().unwrap(), + &self.metadata, + self.projection.as_ref().map(|x| x.columns.as_ref()), + Some(self.remaining), + block, + &mut self.message_scratch, + &mut self.data_scratch, + ); + self.remaining -= chunk.as_ref().map(|x| x.len()).unwrap_or_default(); + + let chunk = if let Some(ProjectionInfo { map, .. }) = &self.projection { + // re-order according to projection + chunk.map(|chunk| apply_projection(chunk, map)) + } else { + chunk + }; + Some(chunk) + } +} diff --git a/crates/polars-arrow/src/io/ipc/read/schema.rs b/crates/polars-arrow/src/io/ipc/read/schema.rs new file mode 100644 index 000000000000..6c427bf8e859 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/schema.rs @@ -0,0 +1,459 @@ +use std::sync::Arc; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef}; +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::pl_str::PlSmallStr; + +use super::super::{IpcField, IpcSchema}; +use super::{OutOfSpecKind, StreamMetadata}; +use crate::datatypes::{ + ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType, IntervalUnit, + Metadata, TimeUnit, UnionMode, UnionType, get_extension, +}; + +fn try_unzip_vec>>( + iter: I, +) -> PolarsResult<(Vec, Vec)> { + let mut a = vec![]; + let mut b = vec![]; + for maybe_item in iter { + let (a_i, b_i) = maybe_item?; + a.push(a_i); + b.push(b_i); + } + + Ok((a, b)) +} + +fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> { + let metadata = read_metadata(&ipc_field)?; + + let extension = metadata.as_ref().and_then(get_extension); + + let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?; + + let field = Field { + name: PlSmallStr::from_str( + ipc_field + .name()? + .ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?, + ), + dtype, + is_nullable: ipc_field.nullable()?, + metadata: metadata.map(Arc::new), + }; + + Ok((field, ipc_field_)) +} + +fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult> { + Ok(if let Some(list) = field.custom_metadata()? { + let mut metadata_map = Metadata::new(); + for kv in list { + let kv = kv?; + if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) { + metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v)); + } + } + Some(metadata_map) + } else { + None + }) +} + +fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult { + Ok(match (int.bit_width()?, int.is_signed()?) { + (8, true) => IntegerType::Int8, + (8, false) => IntegerType::UInt8, + (16, true) => IntegerType::Int16, + (16, false) => IntegerType::UInt16, + (32, true) => IntegerType::Int32, + (32, false) => IntegerType::UInt32, + (64, true) => IntegerType::Int64, + (64, false) => IntegerType::UInt64, + (128, true) => IntegerType::Int128, + _ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."), + }) +} + +fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult { + use arrow_format::ipc::TimeUnit::*; + Ok(match time_unit { + Second => TimeUnit::Second, + Millisecond => TimeUnit::Millisecond, + Microsecond => TimeUnit::Microsecond, + Nanosecond => TimeUnit::Nanosecond, + }) +} + +fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let unit = deserialize_timeunit(time.unit()?)?; + + let dtype = match (time.bit_width()?, unit) { + (32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second), + (32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond), + (64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond), + (64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond), + (bits, precision) => { + polars_bail!(ComputeError: + "Time type with bit width of {bits} and unit of {precision:?}" + ) + }, + }; + Ok((dtype, IpcField::default())) +} + +fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let timezone = timestamp.timezone()?; + let time_unit = deserialize_timeunit(timestamp.unit()?)?; + Ok(( + ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)), + IpcField::default(), + )) +} + +fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse); + let ids = union_.type_ids()?.map(|x| x.iter().collect()); + + let fields = field + .children()? + .ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?; + if fields.is_empty() { + polars_bail!(oos = "IPC: Union must contain at least one child"); + } + + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + Ok(( + ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })), + ipc_field, + )) +} + +fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let is_sorted = map.keys_sorted()?; + + let children = field + .children()? + .ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + let dtype = ArrowDataType::Map(Box::new(field), is_sorted); + Ok(( + dtype, + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let fields = field + .children()? + .ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?; + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + Ok((ArrowDataType::Struct(fields), ipc_field)) +} + +fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + Ok(( + ArrowDataType::List(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + Ok(( + ArrowDataType::LargeList(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_fixed_size_list( + list: FixedSizeListRef, + field: FieldRef, +) -> PolarsResult<(ArrowDataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + let size = list + .list_size()? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + Ok(( + ArrowDataType::FixedSizeList(Box::new(field), size), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +/// Get the Arrow data type from the flatbuffer Field table +fn get_dtype( + field: arrow_format::ipc::FieldRef, + extension: Extension, + may_be_dictionary: bool, +) -> PolarsResult<(ArrowDataType, IpcField)> { + if let Some(dictionary) = field.dictionary()? { + if may_be_dictionary { + let int = dictionary + .index_type()? + .ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?; + let index_type = deserialize_integer(int)?; + let (inner, mut ipc_field) = get_dtype(field, extension, false)?; + ipc_field.dictionary_id = Some(dictionary.id()?); + return Ok(( + ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?), + ipc_field, + )); + } + } + + if let Some(extension) = extension { + let (name, metadata) = extension; + let (dtype, fields) = get_dtype(field, None, false)?; + return Ok(( + ArrowDataType::Extension(Box::new(ExtensionType { + name, + inner: dtype, + metadata, + })), + fields, + )); + } + + let type_ = field + .type_()? + .ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?; + + use arrow_format::ipc::TypeRef::*; + Ok(match type_ { + Null(_) => (ArrowDataType::Null, IpcField::default()), + Bool(_) => (ArrowDataType::Boolean, IpcField::default()), + Int(int) => { + let dtype = deserialize_integer(int)?.into(); + (dtype, IpcField::default()) + }, + Binary(_) => (ArrowDataType::Binary, IpcField::default()), + LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()), + Utf8(_) => (ArrowDataType::Utf8, IpcField::default()), + LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()), + BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()), + Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()), + FixedSizeBinary(fixed) => ( + ArrowDataType::FixedSizeBinary( + fixed + .byte_width()? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?, + ), + IpcField::default(), + ), + FloatingPoint(float) => { + let dtype = match float.precision()? { + arrow_format::ipc::Precision::Half => ArrowDataType::Float16, + arrow_format::ipc::Precision::Single => ArrowDataType::Float32, + arrow_format::ipc::Precision::Double => ArrowDataType::Float64, + }; + (dtype, IpcField::default()) + }, + Date(date) => { + let dtype = match date.unit()? { + arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32, + arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64, + }; + (dtype, IpcField::default()) + }, + Time(time) => deserialize_time(time)?, + Timestamp(timestamp) => deserialize_timestamp(timestamp)?, + Interval(interval) => { + let dtype = match interval.unit()? { + arrow_format::ipc::IntervalUnit::YearMonth => { + ArrowDataType::Interval(IntervalUnit::YearMonth) + }, + arrow_format::ipc::IntervalUnit::DayTime => { + ArrowDataType::Interval(IntervalUnit::DayTime) + }, + arrow_format::ipc::IntervalUnit::MonthDayNano => { + ArrowDataType::Interval(IntervalUnit::MonthDayNano) + }, + }; + (dtype, IpcField::default()) + }, + Duration(duration) => { + let time_unit = deserialize_timeunit(duration.unit()?)?; + (ArrowDataType::Duration(time_unit), IpcField::default()) + }, + Decimal(decimal) => { + let bit_width: usize = decimal + .bit_width()? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + let precision: usize = decimal + .precision()? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + let scale: usize = decimal + .scale()? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let dtype = match bit_width { + 128 => ArrowDataType::Decimal(precision, scale), + 256 => ArrowDataType::Decimal256(precision, scale), + _ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)), + }; + + (dtype, IpcField::default()) + }, + List(_) => deserialize_list(field)?, + LargeList(_) => deserialize_large_list(field)?, + FixedSizeList(list) => deserialize_fixed_size_list(list, field)?, + Struct(_) => deserialize_struct(field)?, + Union(union_) => deserialize_union(union_, field)?, + Map(map) => deserialize_map(map, field)?, + RunEndEncoded(_) => todo!(), + LargeListView(_) | ListView(_) => todo!(), + }) +} + +/// Deserialize an flatbuffers-encoded Schema message into [`ArrowSchema`] and [`IpcSchema`]. +pub fn deserialize_schema( + message: &[u8], +) -> PolarsResult<(ArrowSchema, IpcSchema, Option)> { + let message = arrow_format::ipc::MessageRef::read_as_root(message) + .map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?; + + let schema = match message + .header()? + .ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))? + { + arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema), + _ => polars_bail!(ComputeError: "The message is expected to be a Schema message"), + }?; + + fb_to_schema(schema) +} + +/// Deserialize the raw Schema table from IPC format to Schema data type +pub(super) fn fb_to_schema( + schema: arrow_format::ipc::SchemaRef, +) -> PolarsResult<(ArrowSchema, IpcSchema, Option)> { + let fields = schema + .fields()? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?; + + let mut arrow_schema = ArrowSchema::with_capacity(fields.len()); + let mut ipc_fields = Vec::with_capacity(fields.len()); + + for field in fields { + let (field, ipc_field) = deserialize_field(field?)?; + arrow_schema.insert(field.name.clone(), field); + ipc_fields.push(ipc_field); + } + + let is_little_endian = match schema.endianness()? { + arrow_format::ipc::Endianness::Little => true, + arrow_format::ipc::Endianness::Big => false, + }; + + let custom_schema_metadata = match schema.custom_metadata()? { + None => None, + Some(metadata) => { + let metadata: Metadata = metadata + .into_iter() + .filter_map(|kv_result| { + // FIXME: silently hiding errors here + let kv_ref = kv_result.ok()?; + Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into())) + }) + .collect(); + + if metadata.is_empty() { + None + } else { + Some(metadata) + } + }, + }; + + Ok(( + arrow_schema, + IpcSchema { + fields: ipc_fields, + is_little_endian, + }, + custom_schema_metadata, + )) +} + +pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult { + let message = arrow_format::ipc::MessageRef::read_as_root(meta) + .map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?; + let version = message.version()?; + // message header is a Schema, so read it + let header = message + .header()? + .ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?; + let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header { + schema + } else { + polars_bail!(oos = "The first IPC message of the stream must be a schema") + }; + let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?; + + Ok(StreamMetadata { + schema, + version, + ipc_schema, + custom_schema_metadata, + }) +} diff --git a/crates/polars-arrow/src/io/ipc/read/stream.rs b/crates/polars-arrow/src/io/ipc/read/stream.rs new file mode 100644 index 000000000000..64b8325d368e --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/read/stream.rs @@ -0,0 +1,310 @@ +use std::io::Read; + +use arrow_format::ipc::planus::ReadAsRoot; +use polars_error::{PolarsError, PolarsResult, polars_bail, polars_err}; + +use super::super::CONTINUATION_MARKER; +use super::common::*; +use super::schema::deserialize_stream_metadata; +use super::{Dictionaries, OutOfSpecKind}; +use crate::array::Array; +use crate::datatypes::{ArrowSchema, Metadata}; +use crate::io::ipc::IpcSchema; +use crate::record_batch::RecordBatchT; + +/// Metadata of an Arrow IPC stream, written at the start of the stream +#[derive(Debug, Clone)] +pub struct StreamMetadata { + /// The schema that is read from the stream's first message + pub schema: ArrowSchema, + + /// The custom metadata that is read from the schema + pub custom_schema_metadata: Option, + + /// The IPC version of the stream + pub version: arrow_format::ipc::MetadataVersion, + + /// The IPC fields tracking dictionaries + pub ipc_schema: IpcSchema, +} + +/// Reads the metadata of the stream +pub fn read_stream_metadata(reader: &mut dyn std::io::Read) -> PolarsResult { + // determine metadata length + let mut meta_size: [u8; 4] = [0; 4]; + reader.read_exact(&mut meta_size)?; + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_size == CONTINUATION_MARKER { + reader.read_exact(&mut meta_size)?; + } + i32::from_le_bytes(meta_size) + }; + + let length: usize = meta_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let mut buffer = vec![]; + buffer.try_reserve(length)?; + reader.take(length as u64).read_to_end(&mut buffer)?; + + deserialize_stream_metadata(&buffer) +} + +/// Encodes the stream's status after each read. +/// +/// A stream is an iterator, and an iterator returns `Option`. The `Item` +/// type in the [`StreamReader`] case is `StreamState`, which means that an Arrow +/// stream may yield one of three values: (1) `None`, which signals that the stream +/// is done; (2) [`StreamState::Some`], which signals that there was +/// data waiting in the stream and we read it; and finally (3) +/// [`Some(StreamState::Waiting)`], which means that the stream is still "live", it +/// just doesn't hold any data right now. +pub enum StreamState { + /// A live stream without data + Waiting, + /// Next item in the stream + Some(RecordBatchT>), +} + +impl StreamState { + /// Return the data inside this wrapper. + /// + /// # Panics + /// + /// If the `StreamState` was `Waiting`. + pub fn unwrap(self) -> RecordBatchT> { + if let StreamState::Some(batch) = self { + batch + } else { + panic!("The batch is not available") + } + } +} + +/// Reads the next item, yielding `None` if the stream is done, +/// and a [`StreamState`] otherwise. +fn read_next( + reader: &mut R, + metadata: &StreamMetadata, + dictionaries: &mut Dictionaries, + message_buffer: &mut Vec, + data_buffer: &mut Vec, + projection: &Option, + scratch: &mut Vec, +) -> PolarsResult> { + // determine metadata length + let mut meta_length: [u8; 4] = [0; 4]; + + match reader.read_exact(&mut meta_length) { + Ok(()) => (), + Err(e) => { + return if e.kind() == std::io::ErrorKind::UnexpectedEof { + // Handle EOF without the "0xFFFFFFFF 0x00000000" + // valid according to: + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + Ok(Some(StreamState::Waiting)) + } else { + Err(PolarsError::from(e)) + }; + }, + } + + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_length == CONTINUATION_MARKER { + reader.read_exact(&mut meta_length)?; + } + i32::from_le_bytes(meta_length) + }; + + let meta_length: usize = meta_length + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + if meta_length == 0 { + // the stream has ended, mark the reader as finished + return Ok(None); + } + + message_buffer.clear(); + message_buffer.try_reserve(meta_length)?; + reader + .by_ref() + .take(meta_length as u64) + .read_to_end(message_buffer)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer.as_ref()) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { + data_buffer.clear(); + data_buffer.try_reserve(block_length)?; + reader + .by_ref() + .take(block_length as u64) + .read_to_end(data_buffer)?; + + let file_size = data_buffer.len() as u64; + + let mut reader = std::io::Cursor::new(data_buffer); + + let chunk = read_record_batch( + batch, + &metadata.schema, + &metadata.ipc_schema, + projection.as_ref().map(|x| x.columns.as_ref()), + None, + dictionaries, + metadata.version, + &mut reader, + 0, + file_size, + scratch, + ); + + if let Some(ProjectionInfo { map, .. }) = projection { + // re-order according to projection + chunk + .map(|chunk| apply_projection(chunk, map)) + .map(|x| Some(StreamState::Some(x))) + } else { + chunk.map(|x| Some(StreamState::Some(x))) + } + }, + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { + data_buffer.clear(); + data_buffer.try_reserve(block_length)?; + reader + .by_ref() + .take(block_length as u64) + .read_to_end(data_buffer)?; + + let file_size = data_buffer.len() as u64; + let mut dict_reader = std::io::Cursor::new(&data_buffer); + + read_dictionary( + batch, + &metadata.schema, + &metadata.ipc_schema, + dictionaries, + &mut dict_reader, + 0, + file_size, + scratch, + )?; + + // read the next message until we encounter a RecordBatch message + read_next( + reader, + metadata, + dictionaries, + message_buffer, + data_buffer, + projection, + scratch, + ) + }, + _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType), + } +} + +/// Arrow Stream reader. +/// +/// An [`Iterator`] over an Arrow stream that yields a result of [`StreamState`]s. +/// This is the recommended way to read an arrow stream (by iterating over its data). +/// +/// For a more thorough walkthrough consult [this example](https://github.com/jorgecarleitao/polars_arrow/tree/main/examples/ipc_pyarrow). +pub struct StreamReader { + reader: R, + metadata: StreamMetadata, + dictionaries: Dictionaries, + finished: bool, + data_buffer: Vec, + message_buffer: Vec, + projection: Option, + scratch: Vec, +} + +impl StreamReader { + /// Try to create a new stream reader + /// + /// The first message in the stream is the schema, the reader will fail if it does not + /// encounter a schema. + /// To check if the reader is done, use `is_finished(self)` + pub fn new(reader: R, metadata: StreamMetadata, projection: Option>) -> Self { + let projection = + projection.map(|projection| prepare_projection(&metadata.schema, projection)); + + Self { + reader, + metadata, + dictionaries: Default::default(), + finished: false, + data_buffer: Default::default(), + message_buffer: Default::default(), + projection, + scratch: Default::default(), + } + } + + /// Return the schema of the stream + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata + } + + /// Return the schema of the file + pub fn schema(&self) -> &ArrowSchema { + self.projection + .as_ref() + .map(|x| &x.schema) + .unwrap_or(&self.metadata.schema) + } + + /// Check if the stream is finished + pub fn is_finished(&self) -> bool { + self.finished + } + + fn maybe_next(&mut self) -> PolarsResult> { + if self.finished { + return Ok(None); + } + let batch = read_next( + &mut self.reader, + &self.metadata, + &mut self.dictionaries, + &mut self.message_buffer, + &mut self.data_buffer, + &self.projection, + &mut self.scratch, + )?; + if batch.is_none() { + self.finished = true; + } + Ok(batch) + } +} + +impl Iterator for StreamReader { + type Item = PolarsResult; + + fn next(&mut self) -> Option { + self.maybe_next().transpose() + } +} diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs new file mode 100644 index 000000000000..79497f4f89e3 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -0,0 +1,584 @@ +use std::borrow::{Borrow, Cow}; + +use arrow_format::ipc; +use arrow_format::ipc::planus::Builder; +use polars_error::{PolarsResult, polars_bail, polars_err}; + +use super::super::IpcField; +use super::{write, write_dictionary}; +use crate::array::*; +use crate::datatypes::*; +use crate::io::ipc::endianness::is_native_little_endian; +use crate::io::ipc::read::Dictionaries; +use crate::legacy::prelude::LargeListArray; +use crate::match_integer_type; +use crate::record_batch::RecordBatchT; +use crate::types::Index; + +/// Compression codec +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Compression { + /// LZ4 (framed) + LZ4, + /// ZSTD + ZSTD, +} + +/// Options declaring the behaviour of writing to IPC +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct WriteOptions { + /// Whether the buffers should be compressed and which codec to use. + /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`. + pub compression: Option, +} + +/// Find the dictionary that are new and need to be encoded. +pub fn dictionaries_to_encode( + field: &IpcField, + array: &dyn Array, + dictionary_tracker: &mut DictionaryTracker, + dicts_to_encode: &mut Vec<(i64, Box)>, +) -> PolarsResult<()> { + use PhysicalType::*; + match array.dtype().to_physical_type() { + Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null + | FixedSizeBinary | BinaryView | Utf8View => Ok(()), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let dict_id = field.dictionary_id + .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?; + + if dictionary_tracker.insert(dict_id, array)? { + dicts_to_encode.push((dict_id, array.to_boxed())); + } + + let array = array.as_any().downcast_ref::>().unwrap(); + let values = array.values(); + // @Q? Should this not pick fields[0]? + dictionaries_to_encode(field, + values.as_ref(), + dictionary_tracker, + dicts_to_encode, + )?; + + Ok(()) + }), + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + let fields = field.fields.as_slice(); + if array.fields().len() != fields.len() { + polars_bail!(InvalidOperation: + "The number of fields in a struct must equal the number of children in IpcField".to_string(), + ); + } + fields + .iter() + .zip(array.values().iter()) + .try_for_each(|(field, values)| { + dictionaries_to_encode( + field, + values.as_ref(), + dictionary_tracker, + dicts_to_encode, + ) + }) + }, + List => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode) + }, + LargeList => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode) + }, + FixedSizeList => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode) + }, + Union => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .fields(); + let fields = &field.fields[..]; // todo: error instead + if values.len() != fields.len() { + polars_bail!(InvalidOperation: + "The number of fields in a union must equal the number of children in IpcField" + ); + } + fields + .iter() + .zip(values.iter()) + .try_for_each(|(field, values)| { + dictionaries_to_encode( + field, + values.as_ref(), + dictionary_tracker, + dicts_to_encode, + ) + }) + }, + Map => { + let values = array.as_any().downcast_ref::().unwrap().field(); + let field = &field.fields[0]; // todo: error instead + dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode) + }, + } +} + +/// Encode a dictionary array with a certain id. +/// +/// # Panics +/// +/// This will panic if the given array is not a [`DictionaryArray`]. +pub fn encode_dictionary( + dict_id: i64, + array: &dyn Array, + options: &WriteOptions, + encoded_dictionaries: &mut Vec, +) -> PolarsResult<()> { + let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else { + panic!("Given array is not a DictionaryArray") + }; + + match_integer_type!(key_type, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>( + dict_id, + array, + options, + is_native_little_endian(), + )); + }); + + Ok(()) +} + +pub fn encode_new_dictionaries( + field: &IpcField, + array: &dyn Array, + options: &WriteOptions, + dictionary_tracker: &mut DictionaryTracker, + encoded_dictionaries: &mut Vec, +) -> PolarsResult<()> { + let mut dicts_to_encode = Vec::new(); + dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?; + for (dict_id, dict_array) in dicts_to_encode { + encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?; + } + Ok(()) +} + +pub fn encode_chunk( + chunk: &RecordBatchT>, + fields: &[IpcField], + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, +) -> PolarsResult<(Vec, EncodedData)> { + let mut encoded_message = EncodedData::default(); + let encoded_dictionaries = encode_chunk_amortized( + chunk, + fields, + dictionary_tracker, + options, + &mut encoded_message, + )?; + Ok((encoded_dictionaries, encoded_message)) +} + +// Amortizes `EncodedData` allocation. +pub fn encode_chunk_amortized( + chunk: &RecordBatchT>, + fields: &[IpcField], + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, + encoded_message: &mut EncodedData, +) -> PolarsResult> { + let mut encoded_dictionaries = vec![]; + + for (field, array) in fields.iter().zip(chunk.as_ref()) { + encode_new_dictionaries( + field, + array.as_ref(), + options, + dictionary_tracker, + &mut encoded_dictionaries, + )?; + } + encode_record_batch(chunk, options, encoded_message); + + Ok(encoded_dictionaries) +} + +fn serialize_compression( + compression: Option, +) -> Option> { + if let Some(compression) = compression { + let codec = match compression { + Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame, + Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd, + }; + Some(Box::new(arrow_format::ipc::BodyCompression { + codec, + method: arrow_format::ipc::BodyCompressionMethod::Buffer, + })) + } else { + None + } +} + +fn set_variadic_buffer_counts(counts: &mut Vec, array: &dyn Array) { + match array.dtype() { + ArrowDataType::Utf8View => { + let array = array.as_any().downcast_ref::().unwrap(); + counts.push(array.data_buffers().len() as i64); + }, + ArrowDataType::BinaryView => { + let array = array.as_any().downcast_ref::().unwrap(); + counts.push(array.data_buffers().len() as i64); + }, + ArrowDataType::Struct(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + for array in array.values() { + set_variadic_buffer_counts(counts, array.as_ref()) + } + }, + ArrowDataType::LargeList(_) => { + // Subslicing can change the variadic buffer count, so we have to + // slice here as well to stay synchronized. + let array = array.as_any().downcast_ref::().unwrap(); + let offsets = array.offsets().buffer(); + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + let subslice = array + .values() + .sliced(first.to_usize(), last.to_usize() - first.to_usize()); + set_variadic_buffer_counts(counts, &*subslice) + }, + ArrowDataType::FixedSizeList(_, _) => { + let array = array.as_any().downcast_ref::().unwrap(); + set_variadic_buffer_counts(counts, array.values().as_ref()) + }, + // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct + // is read. + ArrowDataType::Dictionary(_, _, _) => (), + _ => (), + } +} + +fn gc_bin_view<'a, T: ViewType + ?Sized>( + arr: &'a Box, + concrete_arr: &'a BinaryViewArrayGeneric, +) -> Cow<'a, Box> { + let bytes_len = concrete_arr.total_bytes_len(); + let buffer_len = concrete_arr.total_buffer_len(); + let extra_len = buffer_len.saturating_sub(bytes_len); + if extra_len < bytes_len.min(1024) { + // We can afford some tiny waste. + Cow::Borrowed(arr) + } else { + // Force GC it. + Cow::Owned(concrete_arr.clone().gc().boxed()) + } +} + +pub fn encode_array( + array: &Box, + options: &WriteOptions, + variadic_buffer_counts: &mut Vec, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, +) { + // We don't want to write all buffers in sliced arrays. + let array = match array.dtype() { + ArrowDataType::BinaryView => { + let concrete_arr = array.as_any().downcast_ref::().unwrap(); + gc_bin_view(array, concrete_arr) + }, + ArrowDataType::Utf8View => { + let concrete_arr = array.as_any().downcast_ref::().unwrap(); + gc_bin_view(array, concrete_arr) + }, + _ => Cow::Borrowed(array), + }; + let array = array.as_ref().as_ref(); + + set_variadic_buffer_counts(variadic_buffer_counts, array); + + write( + array, + buffers, + arrow_data, + nodes, + offset, + is_native_little_endian(), + options.compression, + ) +} + +/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the +/// other for the batch's data +pub fn encode_record_batch( + chunk: &RecordBatchT>, + options: &WriteOptions, + encoded_message: &mut EncodedData, +) { + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + encoded_message.arrow_data.clear(); + + let mut offset = 0; + let mut variadic_buffer_counts = vec![]; + for array in chunk.arrays() { + encode_array( + array, + options, + &mut variadic_buffer_counts, + &mut buffers, + &mut encoded_message.arrow_data, + &mut nodes, + &mut offset, + ); + } + + commit_encoded_arrays( + chunk.len(), + options, + variadic_buffer_counts, + buffers, + nodes, + encoded_message, + ); +} + +pub fn commit_encoded_arrays( + array_len: usize, + options: &WriteOptions, + variadic_buffer_counts: Vec, + buffers: Vec, + nodes: Vec, + encoded_message: &mut EncodedData, +) { + let variadic_buffer_counts = if variadic_buffer_counts.is_empty() { + None + } else { + Some(variadic_buffer_counts) + }; + + let compression = serialize_compression(options.compression); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new( + arrow_format::ipc::RecordBatch { + length: array_len as i64, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + variadic_buffer_counts, + }, + ))), + body_length: encoded_message.arrow_data.len() as i64, + custom_metadata: None, + }; + + let mut builder = Builder::new(); + let ipc_message = builder.finish(&message, None); + encoded_message.ipc_message = ipc_message.to_vec(); +} + +/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the +/// other for the data +fn dictionary_batch_to_bytes( + dict_id: i64, + array: &DictionaryArray, + options: &WriteOptions, + is_little_endian: bool, +) -> EncodedData { + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + let mut variadic_buffer_counts = vec![]; + set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref()); + + let variadic_buffer_counts = if variadic_buffer_counts.is_empty() { + None + } else { + Some(variadic_buffer_counts) + }; + + let length = write_dictionary( + array, + &mut buffers, + &mut arrow_data, + &mut nodes, + &mut 0, + is_little_endian, + options.compression, + false, + ); + + let compression = serialize_compression(options.compression); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new( + arrow_format::ipc::DictionaryBatch { + id: dict_id, + data: Some(Box::new(arrow_format::ipc::RecordBatch { + length: length as i64, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + variadic_buffer_counts, + })), + is_delta: false, + }, + ))), + body_length: arrow_data.len() as i64, + custom_metadata: None, + }; + + let mut builder = Builder::new(); + let ipc_message = builder.finish(&message, None); + + EncodedData { + ipc_message: ipc_message.to_vec(), + arrow_data, + } +} + +/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary +/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which +/// isn't allowed in the `FileWriter`. +pub struct DictionaryTracker { + pub dictionaries: Dictionaries, + pub cannot_replace: bool, +} + +impl DictionaryTracker { + /// Keep track of the dictionary with the given ID and values. Behavior: + /// + /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate + /// that the dictionary was not actually inserted (because it's already been seen). + /// * If this ID has been written already but with different data, and this tracker is + /// configured to return an error, return an error. + /// * If the tracker has not been configured to error on replacement or this dictionary + /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just + /// inserted. + pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult { + let values = match array.dtype() { + ArrowDataType::Dictionary(key_type, _, _) => { + match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + array.values() + }) + }, + _ => unreachable!(), + }; + + // If a dictionary with this id was already emitted, check if it was the same. + if let Some(last) = self.dictionaries.get(&dict_id) { + if last.as_ref() == values.as_ref() { + // Same dictionary values => no need to emit it again + return Ok(false); + } else if self.cannot_replace { + polars_bail!(InvalidOperation: + "Dictionary replacement detected when writing IPC file format. \ + Arrow IPC files only support a single dictionary for a given field \ + across all batches." + ); + } + }; + + self.dictionaries.insert(dict_id, values.clone()); + Ok(true) + } +} + +/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data +#[derive(Debug, Default)] +pub struct EncodedData { + /// An encoded ipc::Schema::Message + pub ipc_message: Vec, + /// Arrow buffers to be written, should be an empty vec for schema messages + pub arrow_data: Vec, +} + +/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes +#[inline] +pub(crate) fn pad_to_64(len: usize) -> usize { + ((len + 63) & !63) - len +} + +/// An array [`RecordBatchT`] with optional accompanying IPC fields. +#[derive(Debug, Clone, PartialEq)] +pub struct Record<'a> { + columns: Cow<'a, RecordBatchT>>, + fields: Option>, +} + +impl Record<'_> { + /// Get the IPC fields for this record. + pub fn fields(&self) -> Option<&[IpcField]> { + self.fields.as_deref() + } + + /// Get the Arrow columns in this record. + pub fn columns(&self) -> &RecordBatchT> { + self.columns.borrow() + } +} + +impl From>> for Record<'static> { + fn from(columns: RecordBatchT>) -> Self { + Self { + columns: Cow::Owned(columns), + fields: None, + } + } +} + +impl<'a, F> From<(RecordBatchT>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (RecordBatchT>, Option)) -> Self { + Self { + columns: Cow::Owned(columns), + fields: fields.map(|f| f.into()), + } + } +} + +impl<'a, F> From<(&'a RecordBatchT>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (&'a RecordBatchT>, Option)) -> Self { + Self { + columns: Cow::Borrowed(columns), + fields: fields.map(|f| f.into()), + } + } +} diff --git a/crates/polars-arrow/src/io/ipc/write/common_sync.rs b/crates/polars-arrow/src/io/ipc/write/common_sync.rs new file mode 100644 index 000000000000..074da477b5c4 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/common_sync.rs @@ -0,0 +1,63 @@ +use std::io::Write; + +use polars_error::PolarsResult; + +use super::super::CONTINUATION_MARKER; +use super::common::{EncodedData, pad_to_64}; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub fn write_message( + writer: &mut W, + encoded: &EncodedData, +) -> PolarsResult<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + + let a = 8 - 1; + let buffer = &encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(writer, (aligned_size - prefix_size) as i32)?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(buffer)?; + } + // write padding + // aligned to a 8 byte boundary, so maximum is [u8;8] + const PADDING_MAX: [u8; 8] = [0u8; 8]; + writer.write_all(&PADDING_MAX[..padding_bytes])?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data)? + } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +fn write_body_buffers(mut writer: W, data: &[u8]) -> PolarsResult { + let len = data.len(); + let pad_len = pad_to_64(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data)?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..])?; + } + + Ok(total_len) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub fn write_continuation(writer: &mut W, total_len: i32) -> PolarsResult { + writer.write_all(&CONTINUATION_MARKER)?; + writer.write_all(&total_len.to_le_bytes()[..])?; + Ok(8) +} diff --git a/crates/polars-arrow/src/io/ipc/write/mod.rs b/crates/polars-arrow/src/io/ipc/write/mod.rs new file mode 100644 index 000000000000..5fb2540dba67 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/mod.rs @@ -0,0 +1,72 @@ +//! APIs to write to Arrow's IPC format. +pub(crate) mod common; +mod schema; +mod serialize; +mod stream; +pub(crate) mod writer; + +pub use common::{ + Compression, DictionaryTracker, EncodedData, Record, WriteOptions, commit_encoded_arrays, + dictionaries_to_encode, encode_array, encode_dictionary, encode_new_dictionaries, + encode_record_batch, +}; +pub use schema::schema_to_bytes; +pub use serialize::write; +use serialize::write_dictionary; +pub use stream::StreamWriter; +pub use writer::FileWriter; + +pub(crate) mod common_sync; + +use super::IpcField; +use crate::datatypes::{ArrowDataType, Field}; + +fn default_ipc_field(dtype: &ArrowDataType, current_id: &mut i64) -> IpcField { + use crate::datatypes::ArrowDataType::*; + match dtype.to_logical_type() { + // single child => recurse + Map(inner, ..) | FixedSizeList(inner, _) | LargeList(inner) | List(inner) => IpcField { + fields: vec![default_ipc_field(inner.dtype(), current_id)], + dictionary_id: None, + }, + // multiple children => recurse + Struct(fields) => IpcField { + fields: fields + .iter() + .map(|f| default_ipc_field(f.dtype(), current_id)) + .collect(), + dictionary_id: None, + }, + // multiple children => recurse + Union(u) => IpcField { + fields: u + .fields + .iter() + .map(|f| default_ipc_field(f.dtype(), current_id)) + .collect(), + dictionary_id: None, + }, + // dictionary => current_id + Dictionary(_, dtype, _) => { + let dictionary_id = Some(*current_id); + *current_id += 1; + IpcField { + fields: vec![default_ipc_field(dtype, current_id)], + dictionary_id, + } + }, + // no children => do nothing + _ => IpcField { + fields: vec![], + dictionary_id: None, + }, + } +} + +/// Assigns every dictionary field a unique ID +pub fn default_ipc_fields<'a>(fields: impl ExactSizeIterator) -> Vec { + let mut dictionary_id = 0i64; + fields + .map(|field| default_ipc_field(field.dtype().to_logical_type(), &mut dictionary_id)) + .collect() +} diff --git a/crates/polars-arrow/src/io/ipc/write/schema.rs b/crates/polars-arrow/src/io/ipc/write/schema.rs new file mode 100644 index 000000000000..3e1ccb6bbea9 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/schema.rs @@ -0,0 +1,359 @@ +use arrow_format::ipc::planus::Builder; + +use super::super::IpcField; +use crate::datatypes::{ + ArrowDataType, ArrowSchema, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, UnionMode, +}; +use crate::io::ipc::endianness::is_native_little_endian; + +/// Converts a [ArrowSchema] and [IpcField]s to a flatbuffers-encoded [arrow_format::ipc::Message]. +pub fn schema_to_bytes( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + custom_metadata: Option<&Metadata>, +) -> Vec { + let schema = serialize_schema(schema, ipc_fields, custom_metadata); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(schema))), + body_length: 0, + custom_metadata: None, // todo: allow writing custom metadata + }; + let mut builder = Builder::new(); + let footer_data = builder.finish(&message, None); + footer_data.to_vec() +} + +pub fn serialize_schema( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + custom_schema_metadata: Option<&Metadata>, +) -> arrow_format::ipc::Schema { + let endianness = if is_native_little_endian() { + arrow_format::ipc::Endianness::Little + } else { + arrow_format::ipc::Endianness::Big + }; + + let fields = schema + .iter_values() + .zip(ipc_fields.iter()) + .map(|(field, ipc_field)| serialize_field(field, ipc_field)) + .collect::>(); + + let custom_metadata = custom_schema_metadata.and_then(|custom_meta| { + let as_kv = custom_meta + .iter() + .map(|(key, val)| key_value(key.clone().into_string(), val.clone().into_string())) + .collect::>(); + (!as_kv.is_empty()).then_some(as_kv) + }); + + arrow_format::ipc::Schema { + endianness, + fields: Some(fields), + custom_metadata, + features: None, // todo add this one + } +} + +fn key_value(key: impl Into, val: impl Into) -> arrow_format::ipc::KeyValue { + arrow_format::ipc::KeyValue { + key: Some(key.into()), + value: Some(val.into()), + } +} + +fn write_metadata(metadata: &Metadata, kv_vec: &mut Vec) { + for (k, v) in metadata { + if k.as_str() != "ARROW:extension:name" && k.as_str() != "ARROW:extension:metadata" { + kv_vec.push(key_value(k.clone().into_string(), v.clone().into_string())); + } + } +} + +fn write_extension( + name: &str, + metadata: Option<&str>, + kv_vec: &mut Vec, +) { + if let Some(metadata) = metadata { + kv_vec.push(key_value("ARROW:extension:metadata".to_string(), metadata)); + } + + kv_vec.push(key_value("ARROW:extension:name".to_string(), name)); +} + +/// Create an IPC Field from an Arrow Field +pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_format::ipc::Field { + // custom metadata. + let mut kv_vec = vec![]; + if let ArrowDataType::Extension(ext) = field.dtype() { + write_extension( + &ext.name, + ext.metadata.as_ref().map(|x| x.as_str()), + &mut kv_vec, + ); + } + + let type_ = serialize_type(field.dtype()); + let children = serialize_children(field.dtype(), ipc_field); + + let dictionary = if let ArrowDataType::Dictionary(index_type, inner, is_ordered) = field.dtype() + { + if let ArrowDataType::Extension(ext) = inner.as_ref() { + write_extension( + ext.name.as_str(), + ext.metadata.as_ref().map(|x| x.as_str()), + &mut kv_vec, + ); + } + Some(serialize_dictionary( + index_type, + ipc_field + .dictionary_id + .expect("All Dictionary types have `dict_id`"), + *is_ordered, + )) + } else { + None + }; + + if let Some(metadata) = &field.metadata { + write_metadata(metadata, &mut kv_vec); + } + + let custom_metadata = if !kv_vec.is_empty() { + Some(kv_vec) + } else { + None + }; + + arrow_format::ipc::Field { + name: Some(field.name.to_string()), + nullable: field.is_nullable, + type_: Some(type_), + dictionary: dictionary.map(Box::new), + children: Some(children), + custom_metadata, + } +} + +fn serialize_time_unit(unit: &TimeUnit) -> arrow_format::ipc::TimeUnit { + match unit { + TimeUnit::Second => arrow_format::ipc::TimeUnit::Second, + TimeUnit::Millisecond => arrow_format::ipc::TimeUnit::Millisecond, + TimeUnit::Microsecond => arrow_format::ipc::TimeUnit::Microsecond, + TimeUnit::Nanosecond => arrow_format::ipc::TimeUnit::Nanosecond, + } +} + +fn serialize_type(dtype: &ArrowDataType) -> arrow_format::ipc::Type { + use ArrowDataType::*; + use arrow_format::ipc; + match dtype { + Null => ipc::Type::Null(Box::new(ipc::Null {})), + Boolean => ipc::Type::Bool(Box::new(ipc::Bool {})), + UInt8 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 8, + is_signed: false, + })), + UInt16 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 16, + is_signed: false, + })), + UInt32 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 32, + is_signed: false, + })), + UInt64 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 64, + is_signed: false, + })), + Int8 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 8, + is_signed: true, + })), + Int16 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 16, + is_signed: true, + })), + Int32 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 32, + is_signed: true, + })), + Int64 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 64, + is_signed: true, + })), + Int128 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 128, + is_signed: true, + })), + Float16 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Half, + })), + Float32 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Single, + })), + Float64 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Double, + })), + Decimal(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width: 128, + })), + Decimal256(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width: 256, + })), + Binary => ipc::Type::Binary(Box::new(ipc::Binary {})), + LargeBinary => ipc::Type::LargeBinary(Box::new(ipc::LargeBinary {})), + Utf8 => ipc::Type::Utf8(Box::new(ipc::Utf8 {})), + LargeUtf8 => ipc::Type::LargeUtf8(Box::new(ipc::LargeUtf8 {})), + FixedSizeBinary(size) => ipc::Type::FixedSizeBinary(Box::new(ipc::FixedSizeBinary { + byte_width: *size as i32, + })), + Date32 => ipc::Type::Date(Box::new(ipc::Date { + unit: ipc::DateUnit::Day, + })), + Date64 => ipc::Type::Date(Box::new(ipc::Date { + unit: ipc::DateUnit::Millisecond, + })), + Duration(unit) => ipc::Type::Duration(Box::new(ipc::Duration { + unit: serialize_time_unit(unit), + })), + Time32(unit) => ipc::Type::Time(Box::new(ipc::Time { + unit: serialize_time_unit(unit), + bit_width: 32, + })), + Time64(unit) => ipc::Type::Time(Box::new(ipc::Time { + unit: serialize_time_unit(unit), + bit_width: 64, + })), + Timestamp(unit, tz) => ipc::Type::Timestamp(Box::new(ipc::Timestamp { + unit: serialize_time_unit(unit), + timezone: tz.as_ref().map(|x| x.to_string()), + })), + Interval(unit) => ipc::Type::Interval(Box::new(ipc::Interval { + unit: match unit { + IntervalUnit::YearMonth => ipc::IntervalUnit::YearMonth, + IntervalUnit::DayTime => ipc::IntervalUnit::DayTime, + IntervalUnit::MonthDayNano => ipc::IntervalUnit::MonthDayNano, + }, + })), + List(_) => ipc::Type::List(Box::new(ipc::List {})), + LargeList(_) => ipc::Type::LargeList(Box::new(ipc::LargeList {})), + FixedSizeList(_, size) => ipc::Type::FixedSizeList(Box::new(ipc::FixedSizeList { + list_size: *size as i32, + })), + Union(u) => ipc::Type::Union(Box::new(ipc::Union { + mode: match u.mode { + UnionMode::Dense => ipc::UnionMode::Dense, + UnionMode::Sparse => ipc::UnionMode::Sparse, + }, + type_ids: u.ids.clone(), + })), + Map(_, keys_sorted) => ipc::Type::Map(Box::new(ipc::Map { + keys_sorted: *keys_sorted, + })), + Struct(_) => ipc::Type::Struct(Box::new(ipc::Struct {})), + Dictionary(_, v, _) => serialize_type(v), + Extension(ext) => serialize_type(&ext.inner), + Utf8View => ipc::Type::Utf8View(Box::new(ipc::Utf8View {})), + BinaryView => ipc::Type::BinaryView(Box::new(ipc::BinaryView {})), + Unknown => unimplemented!(), + } +} + +fn serialize_children( + dtype: &ArrowDataType, + ipc_field: &IpcField, +) -> Vec { + use ArrowDataType::*; + match dtype { + Null + | Boolean + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int128 + | Float16 + | Float32 + | Float64 + | Timestamp(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Duration(_) + | Interval(_) + | Binary + | FixedSizeBinary(_) + | LargeBinary + | Utf8 + | LargeUtf8 + | Decimal(_, _) + | Utf8View + | BinaryView + | Decimal256(_, _) => vec![], + FixedSizeList(inner, _) | LargeList(inner) | List(inner) | Map(inner, _) => { + vec![serialize_field(inner, &ipc_field.fields[0])] + }, + Struct(fields) => fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc)| serialize_field(field, ipc)) + .collect(), + Union(u) => u + .fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc)| serialize_field(field, ipc)) + .collect(), + Dictionary(_, inner, _) => serialize_children(inner, ipc_field), + Extension(ext) => serialize_children(&ext.inner, ipc_field), + Unknown => unimplemented!(), + } +} + +/// Create an IPC dictionary encoding +pub(crate) fn serialize_dictionary( + index_type: &IntegerType, + dict_id: i64, + dict_is_ordered: bool, +) -> arrow_format::ipc::DictionaryEncoding { + use IntegerType::*; + let is_signed = match index_type { + Int8 | Int16 | Int32 | Int64 | Int128 => true, + UInt8 | UInt16 | UInt32 | UInt64 => false, + }; + + let bit_width = match index_type { + Int8 | UInt8 => 8, + Int16 | UInt16 => 16, + Int32 | UInt32 => 32, + Int64 | UInt64 => 64, + Int128 => 128, + }; + + let index_type = arrow_format::ipc::Int { + bit_width, + is_signed, + }; + + arrow_format::ipc::DictionaryEncoding { + id: dict_id, + index_type: Some(Box::new(index_type)), + is_ordered: dict_is_ordered, + dictionary_kind: arrow_format::ipc::DictionaryKind::DenseArray, + } +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/binary.rs b/crates/polars-arrow/src/io/ipc/write/serialize/binary.rs new file mode 100644 index 000000000000..9642ded1f78b --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/binary.rs @@ -0,0 +1,93 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +fn write_generic_binary( + validity: Option<&Bitmap>, + offsets: &OffsetsBuffer, + values: &[u8], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = offsets.buffer(); + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == O::default() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write_bytes( + &values[first.to_usize()..last.to_usize()], + buffers, + arrow_data, + offset, + compression, + ); +} + +pub(super) fn write_binary( + array: &BinaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_generic_binary( + array.validity(), + array.offsets(), + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); +} + +pub(super) fn write_utf8( + array: &Utf8Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_generic_binary( + array.validity(), + array.offsets(), + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/binview.rs b/crates/polars-arrow/src/io/ipc/write/serialize/binview.rs new file mode 100644 index 000000000000..66afafbd0e69 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/binview.rs @@ -0,0 +1,34 @@ +use super::*; +use crate::array; + +#[allow(clippy::too_many_arguments)] +pub(super) fn write_binview( + array: &BinaryViewArrayGeneric, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array::Array::len(array), + buffers, + arrow_data, + offset, + compression, + ); + + write_buffer( + array.views(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + + for data in array.data_buffers().as_ref() { + write_bytes(data, buffers, arrow_data, offset, compression); + } +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/boolean.rs b/crates/polars-arrow/src/io/ipc/write/serialize/boolean.rs new file mode 100644 index 000000000000..f699860b89cd --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/boolean.rs @@ -0,0 +1,27 @@ +use super::*; + +pub(super) fn write_boolean( + array: &BooleanArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + _: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write_bitmap( + Some(&array.values().clone()), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/dictionary.rs b/crates/polars-arrow/src/io/ipc/write/serialize/dictionary.rs new file mode 100644 index 000000000000..0d1eb96ea7e3 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/dictionary.rs @@ -0,0 +1,37 @@ +use super::*; + +// use `write_keys` to either write keys or values +#[allow(clippy::too_many_arguments)] +pub fn write_dictionary( + array: &DictionaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, + write_keys: bool, +) -> usize { + if write_keys { + write_primitive( + array.keys(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + array.keys().len() + } else { + write( + array.values().as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + array.values().len() + } +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/fixed_size_binary.rs b/crates/polars-arrow/src/io/ipc/write/serialize/fixed_size_binary.rs new file mode 100644 index 000000000000..dc1e973b4d4a --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/fixed_size_binary.rs @@ -0,0 +1,20 @@ +use super::*; + +pub(super) fn write_fixed_size_binary( + array: &FixedSizeBinaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + _is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write_bytes(array.values(), buffers, arrow_data, offset, compression); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/fixed_sized_list.rs b/crates/polars-arrow/src/io/ipc/write/serialize/fixed_sized_list.rs new file mode 100644 index 000000000000..da8fa7db962b --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/fixed_sized_list.rs @@ -0,0 +1,29 @@ +use super::*; + +pub(super) fn write_fixed_size_list( + array: &FixedSizeListArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write( + array.values().as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/list.rs b/crates/polars-arrow/src/io/ipc/write/serialize/list.rs new file mode 100644 index 000000000000..8cca7eba1b87 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/list.rs @@ -0,0 +1,58 @@ +use super::*; + +pub(super) fn write_list( + array: &ListArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = array.offsets().buffer(); + let validity = array.validity(); + + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == O::zero() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write( + array + .values() + .sliced(first.to_usize(), last.to_usize() - first.to_usize()) + .as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/map.rs b/crates/polars-arrow/src/io/ipc/write/serialize/map.rs new file mode 100644 index 000000000000..19492679e418 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/map.rs @@ -0,0 +1,58 @@ +use super::*; + +pub(super) fn write_map( + array: &MapArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = array.offsets().buffer(); + let validity = array.validity(); + + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == 0 { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write( + array + .field() + .sliced(first as usize, last as usize - first as usize) + .as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs new file mode 100644 index 000000000000..9278350fd65d --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs @@ -0,0 +1,392 @@ +#![allow(clippy::ptr_arg)] // false positive in clippy, see https://github.com/rust-lang/rust-clippy/issues/8463 +use arrow_format::ipc; + +use super::super::compression; +use super::super::endianness::is_native_little_endian; +use super::common::{Compression, pad_to_64}; +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; +use crate::offset::{Offset, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; +use crate::{match_integer_type, with_match_primitive_type_full}; +mod binary; +mod binview; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod fixed_sized_list; +mod list; +mod map; +mod primitive; +mod struct_; +mod union; + +use binary::*; +use binview::*; +use boolean::*; +pub(super) use dictionary::*; +use fixed_size_binary::*; +use fixed_sized_list::*; +use list::*; +use map::*; +use primitive::*; +use struct_::*; +use union::*; + +/// Writes an [`Array`] to `arrow_data` +pub fn write( + array: &dyn Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + nodes.push(ipc::FieldNode { + length: array.len() as i64, + null_count: array.null_count() as i64, + }); + use PhysicalType::*; + match array.dtype().to_physical_type() { + Null => (), + Boolean => write_boolean( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + write_primitive::<$T>(array, buffers, arrow_data, offset, is_little_endian, compression) + }), + Binary => write_binary::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + LargeBinary => write_binary::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + FixedSizeBinary => write_fixed_size_binary( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + Utf8 => write_utf8::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + LargeUtf8 => write_utf8::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + List => write_list::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + LargeList => write_list::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + FixedSizeList => write_fixed_size_list( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + Struct => write_struct( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + write_dictionary::<$T>( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + true, + ); + }), + Union => { + write_union( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }, + Map => { + write_map( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }, + Utf8View => write_binview( + array.as_any().downcast_ref::().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + BinaryView => write_binview( + array.as_any().downcast_ref::().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + } +} + +#[inline] +fn pad_buffer_to_64(buffer: &mut Vec, length: usize) { + let pad_len = pad_to_64(length); + for _ in 0..pad_len { + buffer.push(0u8); + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +fn write_bytes( + bytes: &[u8], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + compression: Option, +) { + let start = arrow_data.len(); + if let Some(compression) = compression { + arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(bytes, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(bytes, arrow_data).unwrap(); + }, + } + } else { + arrow_data.extend_from_slice(bytes); + }; + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +fn write_bitmap( + bitmap: Option<&Bitmap>, + length: usize, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + compression: Option, +) { + match bitmap { + Some(bitmap) => { + assert_eq!(bitmap.len(), length); + let (slice, slice_offset, _) = bitmap.as_slice(); + if slice_offset != 0 { + // case where we can't slice the bitmap as the offsets are not multiple of 8 + let bytes = Bitmap::from_trusted_len_iter(bitmap.iter()); + let (slice, _, _) = bytes.as_slice(); + write_bytes(slice, buffers, arrow_data, offset, compression) + } else { + write_bytes(slice, buffers, arrow_data, offset, compression) + } + }, + None => { + buffers.push(ipc::Buffer { + offset: *offset, + length: 0, + }); + }, + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +fn write_buffer( + buffer: &[T], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let start = arrow_data.len(); + if let Some(compression) = compression { + _write_compressed_buffer(buffer, arrow_data, is_little_endian, compression); + } else { + _write_buffer(buffer, arrow_data, is_little_endian); + }; + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +#[inline] +fn _write_buffer_from_iter>( + buffer: I, + arrow_data: &mut Vec, + is_little_endian: bool, +) { + let len = buffer.size_hint().0; + arrow_data.reserve(len * size_of::()); + if is_little_endian { + buffer + .map(|x| T::to_le_bytes(&x)) + .for_each(|x| arrow_data.extend_from_slice(x.as_ref())) + } else { + buffer + .map(|x| T::to_be_bytes(&x)) + .for_each(|x| arrow_data.extend_from_slice(x.as_ref())) + } +} + +#[inline] +fn _write_compressed_buffer_from_iter>( + buffer: I, + arrow_data: &mut Vec, + is_little_endian: bool, + compression: Compression, +) { + let len = buffer.size_hint().0; + let mut swapped = Vec::with_capacity(len * size_of::()); + if is_little_endian { + buffer + .map(|x| T::to_le_bytes(&x)) + .for_each(|x| swapped.extend_from_slice(x.as_ref())); + } else { + buffer + .map(|x| T::to_be_bytes(&x)) + .for_each(|x| swapped.extend_from_slice(x.as_ref())) + }; + arrow_data.extend_from_slice(&(swapped.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(&swapped, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(&swapped, arrow_data).unwrap(); + }, + } +} + +fn _write_buffer(buffer: &[T], arrow_data: &mut Vec, is_little_endian: bool) { + if is_little_endian == is_native_little_endian() { + // in native endianness we can use the bytes directly. + let buffer = bytemuck::cast_slice(buffer); + arrow_data.extend_from_slice(buffer); + } else { + _write_buffer_from_iter(buffer.iter().copied(), arrow_data, is_little_endian) + } +} + +fn _write_compressed_buffer( + buffer: &[T], + arrow_data: &mut Vec, + is_little_endian: bool, + compression: Compression, +) { + if is_little_endian == is_native_little_endian() { + let bytes = bytemuck::cast_slice(buffer); + arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(bytes, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(bytes, arrow_data).unwrap(); + }, + } + } else { + todo!() + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +#[inline] +fn write_buffer_from_iter>( + buffer: I, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let start = arrow_data.len(); + + if let Some(compression) = compression { + _write_compressed_buffer_from_iter(buffer, arrow_data, is_little_endian, compression); + } else { + _write_buffer_from_iter(buffer, arrow_data, is_little_endian); + } + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +fn finish_buffer(arrow_data: &mut Vec, start: usize, offset: &mut i64) -> ipc::Buffer { + let buffer_len = (arrow_data.len() - start) as i64; + + pad_buffer_to_64(arrow_data, arrow_data.len() - start); + let total_len = (arrow_data.len() - start) as i64; + + let buffer = ipc::Buffer { + offset: *offset, + length: buffer_len, + }; + *offset += total_len; + buffer +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/primitive.rs b/crates/polars-arrow/src/io/ipc/write/serialize/primitive.rs new file mode 100644 index 000000000000..acd3ad672f78 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/primitive.rs @@ -0,0 +1,28 @@ +use super::*; + +pub(super) fn write_primitive( + array: &PrimitiveArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + + write_buffer( + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ) +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/struct_.rs b/crates/polars-arrow/src/io/ipc/write/serialize/struct_.rs new file mode 100644 index 000000000000..67353746d4cd --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/struct_.rs @@ -0,0 +1,31 @@ +use super::*; + +pub(super) fn write_struct( + array: &StructArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + array.values().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }); +} diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/union.rs b/crates/polars-arrow/src/io/ipc/write/serialize/union.rs new file mode 100644 index 000000000000..9f0e53fcf67b --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/serialize/union.rs @@ -0,0 +1,42 @@ +use super::*; + +pub(super) fn write_union( + array: &UnionArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_buffer( + array.types(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + + if let Some(offsets) = array.offsets() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + array.fields().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ) + }); +} diff --git a/crates/polars-arrow/src/io/ipc/write/stream.rs b/crates/polars-arrow/src/io/ipc/write/stream.rs new file mode 100644 index 000000000000..64f22e4802bf --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/stream.rs @@ -0,0 +1,132 @@ +//! Arrow IPC File and Stream Writers +//! +//! The `FileWriter` and `StreamWriter` have similar interfaces, +//! however the `FileWriter` expects a reader that supports `Seek`ing + +use std::io::Write; +use std::sync::Arc; + +use polars_error::{PolarsError, PolarsResult}; + +use super::super::IpcField; +use super::common::{DictionaryTracker, EncodedData, WriteOptions, encode_chunk}; +use super::common_sync::{write_continuation, write_message}; +use super::{default_ipc_fields, schema_to_bytes}; +use crate::array::Array; +use crate::datatypes::*; +use crate::record_batch::RecordBatchT; + +/// Arrow stream writer +/// +/// The data written by this writer must be read in order. To signal that no more +/// data is arriving through the stream call [`self.finish()`](StreamWriter::finish); +/// +/// For a usage walkthrough consult [this example](https://github.com/jorgecarleitao/polars_arrow/tree/main/examples/ipc_pyarrow). +pub struct StreamWriter { + /// The object to write to + writer: W, + /// IPC write options + write_options: WriteOptions, + /// Whether the stream has been finished + finished: bool, + /// Keeps track of dictionaries that have been written + dictionary_tracker: DictionaryTracker, + /// Custom schema-level metadata + custom_schema_metadata: Option>, + + ipc_fields: Option>, +} + +impl StreamWriter { + /// Creates a new [`StreamWriter`] + pub fn new(writer: W, write_options: WriteOptions) -> Self { + Self { + writer, + write_options, + finished: false, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }, + ipc_fields: None, + custom_schema_metadata: None, + } + } + + /// Sets custom schema metadata. Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } + + /// Starts the stream by writing a Schema message to it. + /// Use `ipc_fields` to declare dictionary ids in the schema, for dictionary-reuse + pub fn start( + &mut self, + schema: &ArrowSchema, + ipc_fields: Option>, + ) -> PolarsResult<()> { + self.ipc_fields = Some(if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(schema.iter_values()) + }); + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes( + schema, + self.ipc_fields.as_ref().unwrap(), + self.custom_schema_metadata.as_deref(), + ), + arrow_data: vec![], + }; + write_message(&mut self.writer, &encoded_message)?; + Ok(()) + } + + /// Writes [`RecordBatchT`] to the stream + pub fn write( + &mut self, + columns: &RecordBatchT>, + ipc_fields: Option<&[IpcField]>, + ) -> PolarsResult<()> { + if self.finished { + let io_err = std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Cannot write to a finished stream".to_string(), + ); + return Err(PolarsError::from(io_err)); + } + + // we can't make it a closure because it borrows (and it can't borrow mut and non-mut below) + #[allow(clippy::or_fun_call)] + let fields = ipc_fields.unwrap_or(self.ipc_fields.as_ref().unwrap()); + + let (encoded_dictionaries, encoded_message) = encode_chunk( + columns, + fields, + &mut self.dictionary_tracker, + &self.write_options, + )?; + + for encoded_dictionary in encoded_dictionaries { + write_message(&mut self.writer, &encoded_dictionary)?; + } + + write_message(&mut self.writer, &encoded_message)?; + Ok(()) + } + + /// Write continuation bytes, and mark the stream as done + pub fn finish(&mut self) -> PolarsResult<()> { + write_continuation(&mut self.writer, 0)?; + + self.finished = true; + + Ok(()) + } + + /// Consumes itself, returning the inner writer. + pub fn into_inner(self) -> W { + self.writer + } +} diff --git a/crates/polars-arrow/src/io/ipc/write/writer.rs b/crates/polars-arrow/src/io/ipc/write/writer.rs new file mode 100644 index 000000000000..852123bdedb3 --- /dev/null +++ b/crates/polars-arrow/src/io/ipc/write/writer.rs @@ -0,0 +1,256 @@ +use std::io::Write; +use std::sync::Arc; + +use arrow_format::ipc::planus::Builder; +use polars_error::{PolarsResult, polars_bail}; + +use super::super::{ARROW_MAGIC_V2, IpcField}; +use super::common::{DictionaryTracker, EncodedData, WriteOptions}; +use super::common_sync::{write_continuation, write_message}; +use super::{default_ipc_fields, schema, schema_to_bytes}; +use crate::array::Array; +use crate::datatypes::*; +use crate::io::ipc::write::common::encode_chunk_amortized; +use crate::record_batch::RecordBatchT; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) enum State { + None, + Started, + Finished, +} + +/// Arrow file writer +pub struct FileWriter { + /// The object to write to + pub(crate) writer: W, + /// IPC write options + pub(crate) options: WriteOptions, + /// A reference to the schema, used in validating record batches + pub(crate) schema: ArrowSchemaRef, + pub(crate) ipc_fields: Vec, + /// The number of bytes between each block of bytes, as an offset for random access + pub(crate) block_offsets: usize, + /// Dictionary blocks that will be written as part of the IPC footer + pub(crate) dictionary_blocks: Vec, + /// Record blocks that will be written as part of the IPC footer + pub(crate) record_blocks: Vec, + /// Whether the writer footer has been written, and the writer is finished + pub(crate) state: State, + /// Keeps track of dictionaries that have been written + pub(crate) dictionary_tracker: DictionaryTracker, + /// Buffer/scratch that is reused between writes + pub(crate) encoded_message: EncodedData, + /// Custom schema-level metadata + pub(crate) custom_schema_metadata: Option>, +} + +impl FileWriter { + /// Creates a new [`FileWriter`] and writes the header to `writer` + pub fn try_new( + writer: W, + schema: ArrowSchemaRef, + ipc_fields: Option>, + options: WriteOptions, + ) -> PolarsResult { + let mut slf = Self::new(writer, schema, ipc_fields, options); + slf.start()?; + + Ok(slf) + } + + /// Creates a new [`FileWriter`]. + pub fn new( + writer: W, + schema: ArrowSchemaRef, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(schema.iter_values()) + }; + + Self { + writer, + options, + schema, + ipc_fields, + block_offsets: 0, + dictionary_blocks: vec![], + record_blocks: vec![], + state: State::None, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: true, + }, + encoded_message: Default::default(), + custom_schema_metadata: None, + } + } + + /// Consumes itself into the inner writer + pub fn into_inner(self) -> W { + self.writer + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn get_scratches(&mut self) -> EncodedData { + std::mem::take(&mut self.encoded_message) + } + /// Set the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn set_scratches(&mut self, scratches: EncodedData) { + self.encoded_message = scratches; + } + + /// Writes the header and first (schema) message to the file. + /// # Errors + /// Errors if the file has been started or has finished. + pub fn start(&mut self) -> PolarsResult<()> { + if self.state != State::None { + polars_bail!(oos = "The IPC file can only be started once"); + } + // write magic to header + self.writer.write_all(&ARROW_MAGIC_V2[..])?; + // create an 8-byte boundary after the header + self.writer.write_all(&[0, 0])?; + // write the schema, set the written bytes to the schema + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes( + &self.schema, + &self.ipc_fields, + // No need to pass metadata here, as it is already written to the footer in `finish` + None, + ), + arrow_data: vec![], + }; + + let (meta, data) = write_message(&mut self.writer, &encoded_message)?; + self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment + self.state = State::Started; + Ok(()) + } + + /// Writes [`RecordBatchT`] to the file + pub fn write( + &mut self, + chunk: &RecordBatchT>, + ipc_fields: Option<&[IpcField]>, + ) -> PolarsResult<()> { + if self.state != State::Started { + polars_bail!( + oos = "The IPC file must be started before it can be written to. Call `start` before `write`" + ); + } + + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + self.ipc_fields.as_ref() + }; + let encoded_dictionaries = encode_chunk_amortized( + chunk, + ipc_fields, + &mut self.dictionary_tracker, + &self.options, + &mut self.encoded_message, + )?; + + let encoded_message = std::mem::take(&mut self.encoded_message); + self.write_encoded(&encoded_dictionaries[..], &encoded_message)?; + self.encoded_message = encoded_message; + + Ok(()) + } + + pub fn write_encoded( + &mut self, + encoded_dictionaries: &[EncodedData], + encoded_message: &EncodedData, + ) -> PolarsResult<()> { + if self.state != State::Started { + polars_bail!( + oos = "The IPC file must be started before it can be written to. Call `start` before `write`" + ); + } + + // add all dictionaries + for encoded_dictionary in encoded_dictionaries { + let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?; + + let block = arrow_format::ipc::Block { + offset: self.block_offsets as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + self.dictionary_blocks.push(block); + self.block_offsets += meta + data; + } + + self.write_encoded_record_batch(encoded_message)?; + + Ok(()) + } + + pub fn write_encoded_record_batch( + &mut self, + encoded_message: &EncodedData, + ) -> PolarsResult<()> { + let (meta, data) = write_message(&mut self.writer, encoded_message)?; + // add a record block for the footer + let block = arrow_format::ipc::Block { + offset: self.block_offsets as i64, + meta_data_length: meta as i32, // TODO: is this still applicable? + body_length: data as i64, + }; + self.record_blocks.push(block); + self.block_offsets += meta + data; + + Ok(()) + } + + /// Write footer and closing tag, then mark the writer as done + pub fn finish(&mut self) -> PolarsResult<()> { + if self.state != State::Started { + polars_bail!( + oos = "The IPC file must be started before it can be finished. Call `start` before `finish`" + ); + } + + // write EOS + write_continuation(&mut self.writer, 0)?; + + let schema = schema::serialize_schema( + &self.schema, + &self.ipc_fields, + self.custom_schema_metadata.as_deref(), + ); + + let root = arrow_format::ipc::Footer { + version: arrow_format::ipc::MetadataVersion::V5, + schema: Some(Box::new(schema)), + dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)), + record_batches: Some(std::mem::take(&mut self.record_blocks)), + custom_metadata: None, + }; + let mut builder = Builder::new(); + let footer_data = builder.finish(&root, None); + self.writer.write_all(footer_data)?; + self.writer + .write_all(&(footer_data.len() as i32).to_le_bytes())?; + self.writer.write_all(&ARROW_MAGIC_V2)?; + self.writer.flush()?; + self.state = State::Finished; + + Ok(()) + } + + /// Sets custom schema metadata. Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } +} diff --git a/crates/polars-arrow/src/io/iterator.rs b/crates/polars-arrow/src/io/iterator.rs new file mode 100644 index 000000000000..91ec86fc2e04 --- /dev/null +++ b/crates/polars-arrow/src/io/iterator.rs @@ -0,0 +1,65 @@ +pub use streaming_iterator::StreamingIterator; + +/// A [`StreamingIterator`] with an internal buffer of [`Vec`] used to efficiently +/// present items of type `T` as `&[u8]`. +/// It is generic over the type `T` and the transformation `F: T -> &[u8]`. +pub struct BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + iterator: I, + f: F, + buffer: Vec, + is_valid: bool, +} + +impl BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + #[inline] + pub fn new(iterator: I, f: F, buffer: Vec) -> Self { + Self { + iterator, + f, + buffer, + is_valid: false, + } + } +} + +impl StreamingIterator for BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + type Item = [u8]; + + #[inline] + fn advance(&mut self) { + let a = self.iterator.next(); + if let Some(a) = a { + self.is_valid = true; + self.buffer.clear(); + (self.f)(a, &mut self.buffer); + } else { + self.is_valid = false; + } + } + + #[inline] + fn get(&self) -> Option<&Self::Item> { + if self.is_valid { + Some(&self.buffer) + } else { + None + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iterator.size_hint() + } +} diff --git a/crates/polars-arrow/src/io/mod.rs b/crates/polars-arrow/src/io/mod.rs new file mode 100644 index 000000000000..1ae39a8ef766 --- /dev/null +++ b/crates/polars-arrow/src/io/mod.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "io_ipc")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] +pub mod ipc; + +#[cfg(feature = "io_avro")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_avro")))] +pub mod avro; + +pub mod iterator; diff --git a/crates/polars-arrow/src/legacy/array/default_arrays.rs b/crates/polars-arrow/src/legacy/array/default_arrays.rs new file mode 100644 index 000000000000..02dfd6600a4d --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/default_arrays.rs @@ -0,0 +1,65 @@ +use crate::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::offset::OffsetsBuffer; +use crate::types::NativeType; + +pub trait FromData { + fn from_data_default(values: T, validity: Option) -> Self; +} + +impl FromData for BooleanArray { + fn from_data_default(values: Bitmap, validity: Option) -> BooleanArray { + BooleanArray::new(ArrowDataType::Boolean, values, validity) + } +} + +impl FromData> for PrimitiveArray { + fn from_data_default(values: Buffer, validity: Option) -> Self { + let dt = T::PRIMITIVE; + PrimitiveArray::new(dt.into(), values, validity) + } +} + +pub trait FromDataUtf8 { + /// # Safety + /// `values` buffer must contain valid utf8 between every `offset` + unsafe fn from_data_unchecked_default( + offsets: Buffer, + values: Buffer, + validity: Option, + ) -> Self; +} + +impl FromDataUtf8 for Utf8Array { + unsafe fn from_data_unchecked_default( + offsets: Buffer, + values: Buffer, + validity: Option, + ) -> Self { + let offsets = OffsetsBuffer::new_unchecked(offsets); + Utf8Array::new_unchecked(ArrowDataType::LargeUtf8, offsets, values, validity) + } +} + +pub trait FromDataBinary { + /// # Safety + /// `values` buffer must contain valid utf8 between every `offset` + unsafe fn from_data_unchecked_default( + offsets: Buffer, + values: Buffer, + validity: Option, + ) -> Self; +} + +impl FromDataBinary for BinaryArray { + unsafe fn from_data_unchecked_default( + offsets: Buffer, + values: Buffer, + validity: Option, + ) -> Self { + let offsets = OffsetsBuffer::new_unchecked(offsets); + BinaryArray::new(ArrowDataType::LargeBinary, offsets, values, validity) + } +} diff --git a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs new file mode 100644 index 000000000000..85b5a438281a --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs @@ -0,0 +1,99 @@ +use polars_error::PolarsResult; + +use crate::array::{ArrayRef, FixedSizeListArray, NullArray, new_null_array}; +use crate::bitmap::BitmapBuilder; +use crate::compute::concatenate::concatenate_unchecked; +use crate::datatypes::ArrowDataType; +use crate::legacy::array::{convert_inner_type, is_nested_null}; + +#[derive(Default)] +pub struct AnonymousBuilder { + arrays: Vec, + validity: Option, + length: usize, + pub width: usize, +} + +impl AnonymousBuilder { + pub fn new(capacity: usize, width: usize) -> Self { + Self { + arrays: Vec::with_capacity(capacity), + validity: None, + width, + length: 0, + } + } + pub fn is_empty(&self) -> bool { + self.arrays.is_empty() + } + + #[inline] + pub fn push(&mut self, arr: ArrayRef) { + self.arrays.push(arr); + + if let Some(validity) = &mut self.validity { + validity.push(true) + } + + self.length += 1; + } + + pub fn push_null(&mut self) { + self.arrays + .push(NullArray::new(ArrowDataType::Null, self.width).boxed()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + + self.length += 1; + } + + fn init_validity(&mut self) { + let mut validity = BitmapBuilder::with_capacity(self.arrays.capacity()); + if !self.arrays.is_empty() { + validity.extend_constant(self.arrays.len() - 1, true); + validity.push(false); + } + self.validity = Some(validity) + } + + pub fn finish(self, inner_dtype: Option<&ArrowDataType>) -> PolarsResult { + let mut inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].dtype()); + + if is_nested_null(inner_dtype) { + for arr in &self.arrays { + if !is_nested_null(arr.dtype()) { + inner_dtype = arr.dtype(); + break; + } + } + }; + + // convert nested null arrays to the correct dtype. + let arrays = self + .arrays + .iter() + .map(|arr| { + if matches!(arr.dtype(), ArrowDataType::Null) { + new_null_array(inner_dtype.clone(), arr.len()) + } else if is_nested_null(arr.dtype()) { + convert_inner_type(&**arr, inner_dtype) + } else { + arr.to_boxed() + } + }) + .collect::>(); + + let values = concatenate_unchecked(&arrays)?; + + let dtype = FixedSizeListArray::default_datatype(inner_dtype.clone(), self.width); + Ok(FixedSizeListArray::new( + dtype, + self.length, + values, + self.validity + .and_then(|validity| validity.into_opt_validity()), + )) + } +} diff --git a/crates/polars-arrow/src/legacy/array/list.rs b/crates/polars-arrow/src/legacy/array/list.rs new file mode 100644 index 000000000000..4f0ebed03fb9 --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/list.rs @@ -0,0 +1,167 @@ +use polars_error::PolarsResult; + +use crate::array::{Array, ArrayRef, ListArray, NullArray, new_null_array}; +use crate::bitmap::BitmapBuilder; +use crate::compute::concatenate; +use crate::datatypes::ArrowDataType; +use crate::legacy::array::is_nested_null; +use crate::legacy::prelude::*; +use crate::offset::Offsets; + +pub struct AnonymousBuilder<'a> { + arrays: Vec<&'a dyn Array>, + offsets: Vec, + validity: Option, + size: i64, +} + +impl<'a> AnonymousBuilder<'a> { + pub fn new(size: usize) -> Self { + let mut offsets = Vec::with_capacity(size + 1); + offsets.push(0i64); + Self { + arrays: Vec::with_capacity(size), + offsets, + validity: None, + size: 0, + } + } + #[inline] + fn last_offset(&self) -> i64 { + *self.offsets.last().unwrap() + } + + pub fn is_empty(&self) -> bool { + self.offsets.len() == 1 + } + + pub fn offsets(&self) -> &[i64] { + &self.offsets + } + + pub fn take_offsets(self) -> Offsets { + // SAFETY: offsets are correct + unsafe { Offsets::new_unchecked(self.offsets) } + } + + #[inline] + pub fn push(&mut self, arr: &'a dyn Array) { + self.size += arr.len() as i64; + self.offsets.push(self.size); + self.arrays.push(arr); + + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + pub fn push_multiple(&mut self, arrs: &'a [ArrayRef]) { + for arr in arrs { + self.size += arr.len() as i64; + self.arrays.push(arr.as_ref()); + } + self.offsets.push(self.size); + self.update_validity() + } + + #[inline] + pub fn push_null(&mut self) { + self.offsets.push(self.last_offset()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + #[inline] + pub fn push_opt(&mut self, arr: Option<&'a dyn Array>) { + match arr { + None => self.push_null(), + Some(arr) => self.push(arr), + } + } + + pub fn push_empty(&mut self) { + self.offsets.push(self.last_offset()); + self.update_validity() + } + + fn init_validity(&mut self) { + let len = self.offsets.len() - 1; + let mut validity = BitmapBuilder::with_capacity(self.offsets.capacity()); + if len > 0 { + validity.extend_constant(len - 1, true); + validity.push(false); + } + self.validity = Some(validity) + } + + fn update_validity(&mut self) { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + pub fn finish(self, inner_dtype: Option<&ArrowDataType>) -> PolarsResult> { + // SAFETY: + // offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(self.offsets) }; + let (inner_dtype, values) = if self.arrays.is_empty() { + let len = *offsets.last() as usize; + match inner_dtype { + None => { + let values = NullArray::new(ArrowDataType::Null, len).boxed(); + (ArrowDataType::Null, values) + }, + Some(inner_dtype) => { + let values = new_null_array(inner_dtype.clone(), len); + (inner_dtype.clone(), values) + }, + } + } else { + let inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].dtype()); + + // check if there is a dtype that is not `Null` + // if we find it, we will convert the null arrays + // to empty arrays of this dtype, otherwise the concat kernel fails. + let mut non_null_dtype = None; + if is_nested_null(inner_dtype) { + for arr in &self.arrays { + if !is_nested_null(arr.dtype()) { + non_null_dtype = Some(arr.dtype()); + break; + } + } + }; + + // there are null arrays found, ensure the types are correct. + if let Some(dtype) = non_null_dtype { + let arrays = self + .arrays + .iter() + .map(|arr| { + if is_nested_null(arr.dtype()) { + convert_inner_type(&**arr, dtype) + } else { + arr.to_boxed() + } + }) + .collect::>(); + + let values = concatenate::concatenate_unchecked(&arrays)?; + (dtype.clone(), values) + } else { + let values = concatenate::concatenate(&self.arrays)?; + (inner_dtype.clone(), values) + } + }; + let dtype = ListArray::::default_datatype(inner_dtype); + Ok(ListArray::::new( + dtype, + offsets.into(), + values, + self.validity + .and_then(|validity| validity.into_opt_validity()), + )) + } +} diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs new file mode 100644 index 000000000000..7f1358f92d03 --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -0,0 +1,252 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use crate::array::{ + Array, BooleanArray, FixedSizeListArray, ListArray, MutableBinaryViewArray, PrimitiveArray, + StructArray, ViewType, new_null_array, +}; +use crate::bitmap::BitmapBuilder; +use crate::datatypes::ArrowDataType; +use crate::legacy::utils::CustomIterTools; +use crate::offset::Offsets; +use crate::types::NativeType; + +pub mod default_arrays; +#[cfg(feature = "dtype-array")] +pub mod fixed_size_list; +pub mod list; +pub mod null; +pub mod slice; +pub mod utf8; + +pub use slice::*; + +use crate::legacy::prelude::LargeListArray; + +macro_rules! iter_to_values { + ($iterator:expr, $validity:expr, $offsets:expr, $length_so_far:expr) => {{ + $iterator + .filter_map(|opt_iter| match opt_iter { + Some(x) => { + let it = x.into_iter(); + $length_so_far += it.size_hint().0 as i64; + $validity.push(true); + $offsets.push($length_so_far); + Some(it) + }, + None => { + $validity.push(false); + $offsets.push($length_so_far); + None + }, + }) + .flatten() + .collect() + }}; +} + +pub trait ListFromIter { + /// Create a list-array from an iterator. + /// Used in group_by agg-list + /// + /// # Safety + /// Will produce incorrect arrays if size hint is incorrect. + unsafe fn from_iter_primitive_trusted_len( + iter: I, + dtype: ArrowDataType, + ) -> ListArray + where + T: NativeType, + P: IntoIterator>, + I: IntoIterator>, + { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + + let mut validity = BitmapBuilder::with_capacity(lower); + let mut offsets = Vec::::with_capacity(lower + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let values: PrimitiveArray = iter_to_values!(iterator, validity, offsets, length_so_far); + + // SAFETY: + // offsets are monotonically increasing + ListArray::new( + ListArray::::default_datatype(dtype.clone()), + Offsets::new_unchecked(offsets).into(), + Box::new(values.to(dtype)), + validity.into_opt_validity(), + ) + } + + /// Create a list-array from an iterator. + /// Used in group_by agg-list + /// + /// # Safety + /// Will produce incorrect arrays if size hint is incorrect. + unsafe fn from_iter_bool_trusted_len(iter: I) -> ListArray + where + I: IntoIterator>, + P: IntoIterator>, + { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + + let mut validity = Vec::with_capacity(lower); + let mut offsets = Vec::::with_capacity(lower + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let values: BooleanArray = iter_to_values!(iterator, validity, offsets, length_so_far); + + // SAFETY: + // Offsets are monotonically increasing. + ListArray::new( + ListArray::::default_datatype(ArrowDataType::Boolean), + Offsets::new_unchecked(offsets).into(), + Box::new(values), + Some(validity.into()), + ) + } + + /// # Safety + /// Will produce incorrect arrays if size hint is incorrect. + unsafe fn from_iter_binview_trusted_len( + iter: I, + n_elements: usize, + ) -> ListArray + where + I: IntoIterator>, + P: IntoIterator>, + Ref: AsRef, + { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + + let mut validity = BitmapBuilder::with_capacity(lower); + let mut offsets = Vec::::with_capacity(lower + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let values: MutableBinaryViewArray = iterator + .filter_map(|opt_iter| match opt_iter { + Some(x) => { + let it = x.into_iter(); + length_so_far += it.size_hint().0 as i64; + validity.push(true); + offsets.push(length_so_far); + Some(it) + }, + None => { + validity.push(false); + offsets.push(length_so_far); + None + }, + }) + .flatten() + .trust_my_length(n_elements) + .collect(); + + // SAFETY: + // offsets are monotonically increasing + ListArray::new( + ListArray::::default_datatype(T::DATA_TYPE), + Offsets::new_unchecked(offsets).into(), + values.freeze().boxed(), + validity.into_opt_validity(), + ) + } + + /// Create a list-array from an iterator. + /// Used in group_by agg-list + /// + /// # Safety + /// Will produce incorrect arrays if size hint is incorrect. + unsafe fn from_iter_utf8_trusted_len(iter: I, n_elements: usize) -> ListArray + where + I: IntoIterator>, + P: IntoIterator>, + Ref: AsRef, + { + Self::from_iter_binview_trusted_len(iter, n_elements) + } + + /// Create a list-array from an iterator. + /// Used in group_by agg-list + /// + /// # Safety + /// Will produce incorrect arrays if size hint is incorrect. + unsafe fn from_iter_binary_trusted_len(iter: I, n_elements: usize) -> ListArray + where + I: IntoIterator>, + P: IntoIterator>, + Ref: AsRef<[u8]>, + { + Self::from_iter_binview_trusted_len(iter, n_elements) + } +} +impl ListFromIter for ListArray {} + +fn is_nested_null(dtype: &ArrowDataType) -> bool { + match dtype { + ArrowDataType::Null => true, + ArrowDataType::LargeList(field) => is_nested_null(field.dtype()), + ArrowDataType::FixedSizeList(field, _) => is_nested_null(field.dtype()), + ArrowDataType::Struct(fields) => fields.iter().all(|field| is_nested_null(field.dtype())), + _ => false, + } +} + +/// Cast null arrays to inner type and ensure that all offsets remain correct +pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { + match dtype { + ArrowDataType::LargeList(field) => { + let array = array.as_any().downcast_ref::().unwrap(); + let inner = array.values(); + let new_values = convert_inner_type(inner.as_ref(), field.dtype()); + let dtype = LargeListArray::default_datatype(new_values.dtype().clone()); + LargeListArray::new( + dtype, + array.offsets().clone(), + new_values, + array.validity().cloned(), + ) + .boxed() + }, + ArrowDataType::FixedSizeList(field, width) => { + let width = *width; + + let array = array.as_any().downcast_ref::().unwrap(); + let inner = array.values(); + let length = if width == array.size() { + array.len() + } else { + assert!(!array.values().is_empty() || width != 0); + if width == 0 { + 0 + } else { + array.values().len() / width + } + }; + let new_values = convert_inner_type(inner.as_ref(), field.dtype()); + let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), width); + FixedSizeListArray::new(dtype, length, new_values, array.validity().cloned()).boxed() + }, + ArrowDataType::Struct(fields) => { + let array = array.as_any().downcast_ref::().unwrap(); + let inner = array.values(); + let new_values = inner + .iter() + .zip(fields) + .map(|(arr, field)| convert_inner_type(arr.as_ref(), field.dtype())) + .collect::>(); + StructArray::new( + dtype.clone(), + array.len(), + new_values, + array.validity().cloned(), + ) + .boxed() + }, + _ => new_null_array(dtype.clone(), array.len()), + } +} diff --git a/crates/polars-arrow/src/legacy/array/null.rs b/crates/polars-arrow/src/legacy/array/null.rs new file mode 100644 index 000000000000..ec630250e0c5 --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/null.rs @@ -0,0 +1,58 @@ +use std::any::Any; + +use crate::array::{Array, MutableArray, NullArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::ArrowDataType; + +#[derive(Debug, Default, Clone)] +pub struct MutableNullArray { + len: usize, +} + +impl MutableArray for MutableNullArray { + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::Null + } + + fn len(&self) -> usize { + self.len + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + Box::new(NullArray::new_null(ArrowDataType::Null, self.len)) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn push_null(&mut self) { + self.len += 1; + } + + fn reserve(&mut self, _additional: usize) { + // no-op + } + + fn shrink_to_fit(&mut self) { + // no-op + } +} + +impl MutableNullArray { + pub fn new(len: usize) -> Self { + MutableNullArray { len } + } + + pub fn extend_nulls(&mut self, null_count: usize) { + self.len += null_count; + } +} diff --git a/crates/polars-arrow/src/legacy/array/slice.rs b/crates/polars-arrow/src/legacy/array/slice.rs new file mode 100644 index 000000000000..720723c901a8 --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/slice.rs @@ -0,0 +1,38 @@ +use crate::array::Array; + +/// Utility trait to slice concrete arrow arrays whilst keeping their +/// concrete type. E.g. don't return `Box`. +pub trait SlicedArray { + /// Slices this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + fn slice_typed(&self, offset: usize, length: usize) -> Self + where + Self: Sized; + + /// Slices the [`Array`]. + /// # Implementation + /// This operation is `O(1)`. + /// + /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + unsafe fn slice_typed_unchecked(&self, offset: usize, length: usize) -> Self + where + Self: Sized; +} + +impl SlicedArray for T { + fn slice_typed(&self, offset: usize, length: usize) -> Self { + let mut arr = self.clone(); + arr.slice(offset, length); + arr + } + + unsafe fn slice_typed_unchecked(&self, offset: usize, length: usize) -> Self { + let mut arr = self.clone(); + arr.slice_unchecked(offset, length); + arr + } +} diff --git a/crates/polars-arrow/src/legacy/array/utf8.rs b/crates/polars-arrow/src/legacy/array/utf8.rs new file mode 100644 index 000000000000..db931bc88d20 --- /dev/null +++ b/crates/polars-arrow/src/legacy/array/utf8.rs @@ -0,0 +1,94 @@ +use crate::array::{BinaryArray, Utf8Array}; +use crate::datatypes::ArrowDataType; +use crate::legacy::trusted_len::TrustedLenPush; +use crate::offset::Offsets; + +#[inline] +unsafe fn extend_from_trusted_len_values_iter( + offsets: &mut Vec, + values: &mut Vec, + iterator: I, +) where + P: AsRef<[u8]>, + I: Iterator, +{ + let mut total_length = 0; + offsets.push(total_length); + iterator.for_each(|item| { + let s = item.as_ref(); + // Push new entries for both `values` and `offsets` buffer + values.extend_from_slice(s); + + total_length += s.len() as i64; + offsets.push_unchecked(total_length); + }); +} + +/// # Safety +/// reported `len` must be correct. +#[inline] +unsafe fn fill_offsets_and_values( + iterator: I, + value_capacity: usize, + len: usize, +) -> (Offsets, Vec) +where + P: AsRef<[u8]>, + I: Iterator, +{ + let mut offsets = Vec::with_capacity(len + 1); + let mut values = Vec::::with_capacity(value_capacity); + + extend_from_trusted_len_values_iter(&mut offsets, &mut values, iterator); + + (Offsets::new_unchecked(offsets), values) +} + +struct StrAsBytes

(P); +impl> AsRef<[u8]> for StrAsBytes { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.0.as_ref().as_bytes() + } +} + +pub trait Utf8FromIter { + #[inline] + fn from_values_iter(iter: I, len: usize, size_hint: usize) -> Utf8Array + where + S: AsRef, + I: Iterator, + { + let iter = iter.map(StrAsBytes); + let (offsets, values) = unsafe { fill_offsets_and_values(iter, size_hint, len) }; + unsafe { + Utf8Array::new_unchecked( + ArrowDataType::LargeUtf8, + offsets.into(), + values.into(), + None, + ) + } + } +} + +impl Utf8FromIter for Utf8Array {} + +pub trait BinaryFromIter { + #[inline] + fn from_values_iter(iter: I, len: usize, value_cap: usize) -> BinaryArray + where + S: AsRef<[u8]>, + I: Iterator, + { + let (offsets, values) = unsafe { fill_offsets_and_values(iter, value_cap, len) }; + BinaryArray::new( + ArrowDataType::LargeBinary, + offsets.into(), + values.into(), + None, + ) + } +} + +impl BinaryFromIter for BinaryArray {} diff --git a/crates/polars-arrow/src/legacy/bit_util.rs b/crates/polars-arrow/src/legacy/bit_util.rs new file mode 100644 index 000000000000..1680bfe7330a --- /dev/null +++ b/crates/polars-arrow/src/legacy/bit_util.rs @@ -0,0 +1,163 @@ +use crate::bitmap::Bitmap; +/// Forked from Arrow until their API stabilizes. +/// +/// Note that the bound checks are optimized away. +/// +use crate::bitmap::utils::{BitChunkIterExact, BitChunks, BitChunksExact}; + +fn first_set_bit_impl(mut mask_chunks: I) -> usize +where + I: BitChunkIterExact, +{ + let mut total = 0usize; + const SIZE: u32 = 64; + for chunk in &mut mask_chunks { + let pos = chunk.trailing_zeros(); + if pos != SIZE { + return total + pos as usize; + } else { + total += SIZE as usize + } + } + if let Some(pos) = mask_chunks.remainder_iter().position(|v| v) { + total += pos; + return total; + } + // all null, return the first + 0 +} + +pub fn first_set_bit(mask: &Bitmap) -> usize { + if mask.unset_bits() == 0 || mask.unset_bits() == mask.len() { + return 0; + } + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + first_set_bit_impl(mask_chunks) + } else { + let mask_chunks = mask.chunks::(); + first_set_bit_impl(mask_chunks) + } +} + +fn first_unset_bit_impl(mut mask_chunks: I) -> usize +where + I: BitChunkIterExact, +{ + let mut total = 0usize; + const SIZE: u32 = 64; + for chunk in &mut mask_chunks { + let pos = chunk.trailing_ones(); + if pos != SIZE { + return total + pos as usize; + } else { + total += SIZE as usize + } + } + if let Some(pos) = mask_chunks.remainder_iter().position(|v| !v) { + total += pos; + return total; + } + // all null, return the first + 0 +} + +pub fn first_unset_bit(mask: &Bitmap) -> usize { + if mask.unset_bits() == 0 || mask.unset_bits() == mask.len() { + return 0; + } + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + first_unset_bit_impl(mask_chunks) + } else { + let mask_chunks = mask.chunks::(); + first_unset_bit_impl(mask_chunks) + } +} + +pub fn find_first_true_false_null( + mut bit_chunks: BitChunks, + mut validity_chunks: BitChunks, +) -> (Option, Option, Option) { + let (mut true_index, mut false_index, mut null_index) = (None, None, None); + let (mut true_not_found_mask, mut false_not_found_mask, mut null_not_found_mask) = + (!0u64, !0u64, !0u64); // All ones while not found. + let mut offset: usize = 0; + let mut all_found = false; + for (truth_mask, null_mask) in (&mut bit_chunks).zip(&mut validity_chunks) { + let mask = null_mask & truth_mask & true_not_found_mask; + if mask > 0 { + true_index = Some(offset + mask.trailing_zeros() as usize); + true_not_found_mask = 0; + } + let mask = null_mask & !truth_mask & false_not_found_mask; + if mask > 0 { + false_index = Some(offset + mask.trailing_zeros() as usize); + false_not_found_mask = 0; + } + if !null_mask & null_not_found_mask > 0 { + null_index = Some(offset + null_mask.trailing_ones() as usize); + null_not_found_mask = 0; + } + if null_not_found_mask | true_not_found_mask | false_not_found_mask == 0 { + all_found = true; + break; + } + offset += 64; + } + if !all_found { + for (val, not_null) in bit_chunks + .remainder_iter() + .zip(validity_chunks.remainder_iter()) + { + if true_index.is_none() && not_null && val { + true_index = Some(offset); + } else if false_index.is_none() && not_null && !val { + false_index = Some(offset); + } else if null_index.is_none() && !not_null { + null_index = Some(offset); + } + offset += 1; + } + } + (true_index, false_index, null_index) +} + +pub fn find_first_true_false_no_null( + mut bit_chunks: BitChunks, +) -> (Option, Option) { + let (mut true_index, mut false_index) = (None, None); + let (mut true_not_found_mask, mut false_not_found_mask) = (!0u64, !0u64); // All ones while not found. + let mut offset: usize = 0; + let mut all_found = false; + for truth_mask in &mut bit_chunks { + let mask = truth_mask & true_not_found_mask; + if mask > 0 { + true_index = Some(offset + mask.trailing_zeros() as usize); + true_not_found_mask = 0; + } + let mask = !truth_mask & false_not_found_mask; + if mask > 0 { + false_index = Some(offset + mask.trailing_zeros() as usize); + false_not_found_mask = 0; + } + if true_not_found_mask | false_not_found_mask == 0 { + all_found = true; + break; + } + offset += 64; + } + if !all_found { + for val in bit_chunks.remainder_iter() { + if true_index.is_none() && val { + true_index = Some(offset); + } else if false_index.is_none() && !val { + false_index = Some(offset); + } + offset += 1; + } + } + (true_index, false_index) +} diff --git a/crates/polars-arrow/src/legacy/conversion.rs b/crates/polars-arrow/src/legacy/conversion.rs new file mode 100644 index 000000000000..9504e45d878c --- /dev/null +++ b/crates/polars-arrow/src/legacy/conversion.rs @@ -0,0 +1,24 @@ +use crate::array::{ArrayRef, PrimitiveArray, StructArray}; +use crate::datatypes::{ArrowDataType, Field}; +use crate::record_batch::RecordBatchT; +use crate::types::NativeType; + +pub fn chunk_to_struct(chunk: RecordBatchT, fields: Vec) -> StructArray { + let dtype = ArrowDataType::Struct(fields); + StructArray::new(dtype, chunk.len(), chunk.into_arrays(), None) +} + +/// Returns its underlying [`Vec`], if possible. +/// +/// This operation returns [`Some`] iff this [`PrimitiveArray`]: +/// * has not been sliced with an offset +/// * has not been cloned (i.e. [`Arc::get_mut`][Arc::get_mut] yields [`Some`]) +/// * has not been imported from the c data interface (FFI) +/// +/// [Arc::get_mut]: std::sync::Arc::get_mut +pub fn primitive_to_vec(arr: ArrayRef) -> Option> { + let arr_ref = arr.as_any().downcast_ref::>().unwrap(); + let buffer = arr_ref.values().clone(); + drop(arr); // Drop original reference so refcount becomes 1 if possible. + buffer.into_mut().right() +} diff --git a/crates/polars-arrow/src/legacy/error.rs b/crates/polars-arrow/src/legacy/error.rs new file mode 100644 index 000000000000..415c77c3bd2c --- /dev/null +++ b/crates/polars-arrow/src/legacy/error.rs @@ -0,0 +1 @@ +pub use polars_error::*; diff --git a/crates/polars-arrow/src/legacy/index.rs b/crates/polars-arrow/src/legacy/index.rs new file mode 100644 index 000000000000..9ee157c08d19 --- /dev/null +++ b/crates/polars-arrow/src/legacy/index.rs @@ -0,0 +1,49 @@ +use std::fmt::Display; + +use num_traits::{NumCast, Signed, Zero}; +use polars_error::{PolarsResult, polars_err}; +use polars_utils::IdxSize; + +use crate::array::PrimitiveArray; + +pub trait IndexToUsize: Display { + /// Translate the negative index to an offset. + fn negative_to_usize(self, len: usize) -> Option; + + fn try_negative_to_usize(self, len: usize) -> PolarsResult + where + Self: Sized + Copy, + { + self.negative_to_usize(len) + .ok_or_else(|| polars_err!(OutOfBounds: "index {} for length: {}", self, len)) + } +} + +impl IndexToUsize for I +where + I: PartialOrd + PartialEq + NumCast + Signed + Zero + Display, +{ + #[inline] + fn negative_to_usize(self, len: usize) -> Option { + if self >= Zero::zero() { + if (self.to_usize().unwrap()) < len { + Some(self.to_usize().unwrap()) + } else { + None + } + } else { + let subtract = self.abs().to_usize().unwrap(); + if subtract > len { + None + } else { + Some(len - subtract) + } + } + } +} + +pub fn indexes_to_usizes(idx: &[IdxSize]) -> impl Iterator + '_ { + idx.iter().map(|idx| *idx as usize) +} + +pub type IdxArr = PrimitiveArray; diff --git a/crates/polars-arrow/src/legacy/is_valid.rs b/crates/polars-arrow/src/legacy/is_valid.rs new file mode 100644 index 000000000000..34cd98d6ba28 --- /dev/null +++ b/crates/polars-arrow/src/legacy/is_valid.rs @@ -0,0 +1,27 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use crate::array::{ + Array, BinaryArray, BooleanArray, FixedSizeListArray, ListArray, PrimitiveArray, Utf8Array, +}; +use crate::types::NativeType; + +pub trait IsValid { + /// # Safety + /// no bound checks + unsafe fn is_valid_unchecked(&self, i: usize) -> bool; +} + +pub trait ArrowArray: Array {} + +impl ArrowArray for BinaryArray {} +impl ArrowArray for Utf8Array {} +impl ArrowArray for PrimitiveArray {} +impl ArrowArray for BooleanArray {} +impl ArrowArray for ListArray {} +impl ArrowArray for FixedSizeListArray {} + +impl IsValid for A { + #[inline] + unsafe fn is_valid_unchecked(&self, i: usize) -> bool { + !self.is_null_unchecked(i) + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/average.rs b/crates/polars-arrow/src/legacy/kernels/ewm/average.rs new file mode 100644 index 000000000000..6d5aed26061f --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/ewm/average.rs @@ -0,0 +1,165 @@ +use std::ops::{AddAssign, MulAssign}; + +use num_traits::Float; + +use crate::array::PrimitiveArray; +use crate::legacy::utils::CustomIterTools; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +pub fn ewm_mean( + xs: I, + alpha: T, + adjust: bool, + min_periods: usize, + ignore_nulls: bool, +) -> PrimitiveArray +where + I: IntoIterator>, + I::IntoIter: TrustedLen, + T: Float + NativeType + AddAssign + MulAssign, +{ + let new_wt = if adjust { T::one() } else { alpha }; + let old_wt_factor = T::one() - alpha; + let mut old_wt = T::one(); + let mut weighted_avg = None; + let mut non_null_cnt = 0usize; + + xs.into_iter() + .enumerate() + .map(|(i, opt_x)| { + if opt_x.is_some() { + non_null_cnt += 1; + } + match (i, weighted_avg) { + (0, _) | (_, None) => weighted_avg = opt_x, + (_, Some(w_avg)) => { + if opt_x.is_some() || !ignore_nulls { + old_wt *= old_wt_factor; + if let Some(x) = opt_x { + if w_avg != x { + weighted_avg = + Some((old_wt * w_avg + new_wt * x) / (old_wt + new_wt)); + } + old_wt = if adjust { old_wt + new_wt } else { T::one() }; + } + } + }, + } + match (non_null_cnt < min_periods, opt_x.is_some()) { + (_, false) => None, + (true, true) => None, + (false, true) => weighted_avg, + } + }) + .collect_trusted() +} + +#[cfg(test)] +mod test { + use super::super::assert_allclose; + use super::*; + const ALPHA: f64 = 0.5; + const EPS: f64 = 1e-15; + + #[test] + fn test_ewm_mean_without_null() { + let xs: Vec> = vec![Some(1.0), Some(2.0), Some(3.0)]; + for adjust in [false, true] { + for ignore_nulls in [false, true] { + for min_periods in [0, 1] { + let result = ewm_mean(xs.clone(), ALPHA, adjust, min_periods, ignore_nulls); + let expected = match adjust { + false => PrimitiveArray::from([Some(1.0f64), Some(1.5f64), Some(2.25f64)]), + true => PrimitiveArray::from([ + Some(1.0), + Some(1.666_666_666_666_666_7), + Some(2.428_571_428_571_428_4), + ]), + }; + assert_allclose!(result, expected, 1e-15); + } + let result = ewm_mean(xs.clone(), ALPHA, adjust, 2, ignore_nulls); + let expected = match adjust { + false => PrimitiveArray::from([None, Some(1.5f64), Some(2.25f64)]), + true => PrimitiveArray::from([ + None, + Some(1.666_666_666_666_666_7), + Some(2.428_571_428_571_428_4), + ]), + }; + assert_allclose!(result, expected, EPS); + } + } + } + + #[test] + fn test_ewm_mean_with_null() { + let xs1 = vec![ + None, + None, + Some(5.0f64), + Some(7.0f64), + None, + Some(2.0f64), + Some(1.0f64), + Some(4.0f64), + ]; + assert_allclose!( + ewm_mean(xs1.clone(), 0.5, true, 0, true), + PrimitiveArray::from([ + None, + None, + Some(5.0), + Some(6.333_333_333_333_333), + None, + Some(3.857_142_857_142_857), + Some(2.333_333_333_333_333_5), + Some(3.193_548_387_096_774), + ]), + EPS + ); + assert_allclose!( + ewm_mean(xs1.clone(), 0.5, true, 0, false), + PrimitiveArray::from([ + None, + None, + Some(5.0), + Some(6.333_333_333_333_333), + None, + Some(3.181_818_181_818_181_7), + Some(1.888_888_888_888_888_8), + Some(3.033_898_305_084_745_7), + ]), + EPS + ); + assert_allclose!( + ewm_mean(xs1.clone(), 0.5, false, 0, true), + PrimitiveArray::from([ + None, + None, + Some(5.0), + Some(6.0), + None, + Some(4.0), + Some(2.5), + Some(3.25), + ]), + EPS + ); + assert_allclose!( + ewm_mean(xs1, 0.5, false, 0, false), + PrimitiveArray::from([ + None, + None, + Some(5.0), + Some(6.0), + None, + Some(3.333_333_333_333_333_5), + Some(2.166_666_666_666_667), + Some(3.083_333_333_333_333_5), + ]), + EPS + ); + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs new file mode 100644 index 000000000000..5b8f3a5fb3fb --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs @@ -0,0 +1,92 @@ +mod average; +mod variance; + +use std::hash::{Hash, Hasher}; + +pub use average::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +pub use variance::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq)] +#[must_use] +pub struct EWMOptions { + pub alpha: f64, + pub adjust: bool, + pub bias: bool, + pub min_periods: usize, + pub ignore_nulls: bool, +} + +impl Default for EWMOptions { + fn default() -> Self { + Self { + alpha: 0.5, + adjust: true, + bias: false, + min_periods: 1, + ignore_nulls: true, + } + } +} + +impl Hash for EWMOptions { + fn hash(&self, state: &mut H) { + self.alpha.to_bits().hash(state); + self.adjust.hash(state); + self.bias.hash(state); + self.min_periods.hash(state); + self.ignore_nulls.hash(state); + } +} + +impl EWMOptions { + pub fn and_min_periods(mut self, min_periods: usize) -> Self { + self.min_periods = min_periods; + self + } + pub fn and_adjust(mut self, adjust: bool) -> Self { + self.adjust = adjust; + self + } + pub fn and_span(mut self, span: usize) -> Self { + assert!(span >= 1); + self.alpha = 2.0 / (span as f64 + 1.0); + self + } + pub fn and_half_life(mut self, half_life: f64) -> Self { + assert!(half_life > 0.0); + self.alpha = 1.0 - (-(2.0f64.ln()) / half_life).exp(); + self + } + pub fn and_com(mut self, com: f64) -> Self { + assert!(com > 0.0); + self.alpha = 1.0 / (1.0 + com); + self + } + pub fn and_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } +} + +#[cfg(test)] +macro_rules! assert_allclose { + ($xs:expr, $ys:expr, $tol:expr) => { + assert!( + $xs.iter() + .zip($ys.iter()) + .map(|(x, z)| { + match (x, z) { + (Some(a), Some(b)) => (a - b).abs() < $tol, + (None, None) => true, + _ => false, + } + }) + .fold(true, |acc, b| acc && b) + ); + }; +} +#[cfg(test)] +pub(crate) use assert_allclose; diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs b/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs new file mode 100644 index 000000000000..42a1d0c95e21 --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/ewm/variance.rs @@ -0,0 +1,796 @@ +use std::ops::{AddAssign, DivAssign, MulAssign}; + +use num_traits::Float; + +use crate::array::PrimitiveArray; +use crate::legacy::utils::CustomIterTools; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +#[allow(clippy::too_many_arguments)] +fn ewm_cov_internal( + xs: I, + ys: I, + alpha: T, + adjust: bool, + bias: bool, + min_periods: usize, + ignore_nulls: bool, + do_sqrt: bool, +) -> PrimitiveArray +where + I: IntoIterator>, + I::IntoIter: TrustedLen, + T: Float + NativeType + AddAssign + MulAssign + DivAssign, +{ + let old_wt_factor = T::one() - alpha; + let new_wt = if adjust { T::one() } else { alpha }; + let mut sum_wt = T::one(); + let mut sum_wt2 = T::one(); + let mut old_wt = T::one(); + + let mut opt_mean_x = None; + let mut opt_mean_y = None; + let mut cov = T::zero(); + let mut non_na_cnt = 0usize; + let min_periods_fixed = if min_periods == 0 { 1 } else { min_periods }; + + let res = xs + .into_iter() + .zip(ys) + .enumerate() + .map(|(i, (opt_x, opt_y))| { + let is_observation = opt_x.is_some() && opt_y.is_some(); + if is_observation { + non_na_cnt += 1; + } + match (i, opt_mean_x, opt_mean_y) { + (0, _, _) => { + if is_observation { + opt_mean_x = opt_x; + opt_mean_y = opt_y; + } + }, + (_, Some(mean_x), Some(mean_y)) => { + if is_observation || !ignore_nulls { + sum_wt *= old_wt_factor; + sum_wt2 *= old_wt_factor * old_wt_factor; + old_wt *= old_wt_factor; + if is_observation { + let x = opt_x.unwrap(); + let y = opt_y.unwrap(); + let old_mean_x = mean_x; + let old_mean_y = mean_y; + + // avoid numerical errors on constant series + if mean_x != x { + opt_mean_x = + Some((old_wt * old_mean_x + new_wt * x) / (old_wt + new_wt)); + } + + // avoid numerical errors on constant series + if mean_y != y { + opt_mean_y = + Some((old_wt * old_mean_y + new_wt * y) / (old_wt + new_wt)); + } + + cov = ((old_wt + * (cov + + ((old_mean_x - opt_mean_x.unwrap()) + * (old_mean_y - opt_mean_y.unwrap())))) + + (new_wt + * ((x - opt_mean_x.unwrap()) * (y - opt_mean_y.unwrap())))) + / (old_wt + new_wt); + + sum_wt += new_wt; + sum_wt2 += new_wt * new_wt; + old_wt += new_wt; + if !adjust { + sum_wt /= old_wt; + sum_wt2 /= old_wt * old_wt; + old_wt = T::one(); + } + } + } + }, + _ => { + if is_observation { + opt_mean_x = opt_x; + opt_mean_y = opt_y; + } + }, + } + match (non_na_cnt >= min_periods_fixed, bias, is_observation) { + (_, _, false) => None, + (false, _, true) => None, + (true, false, true) => { + if non_na_cnt == 1 { + Some(cov) + } else { + let numerator = sum_wt * sum_wt; + let denominator = numerator - sum_wt2; + if denominator > T::zero() { + Some((numerator / denominator) * cov) + } else { + None + } + } + }, + (true, true, true) => Some(cov), + } + }); + + if do_sqrt { + res.map(|opt_x| opt_x.map(|x| x.sqrt())).collect_trusted() + } else { + res.collect_trusted() + } +} + +pub fn ewm_cov( + xs: I, + ys: I, + alpha: T, + adjust: bool, + bias: bool, + min_periods: usize, + ignore_nulls: bool, +) -> PrimitiveArray +where + I: IntoIterator>, + I::IntoIter: TrustedLen, + T: Float + NativeType + AddAssign + MulAssign + DivAssign, +{ + ewm_cov_internal( + xs, + ys, + alpha, + adjust, + bias, + min_periods, + ignore_nulls, + false, + ) +} + +pub fn ewm_var( + xs: I, + alpha: T, + adjust: bool, + bias: bool, + min_periods: usize, + ignore_nulls: bool, +) -> PrimitiveArray +where + I: IntoIterator> + Clone, + I::IntoIter: TrustedLen, + T: Float + NativeType + AddAssign + MulAssign + DivAssign, +{ + ewm_cov_internal( + xs.clone(), + xs, + alpha, + adjust, + bias, + min_periods, + ignore_nulls, + false, + ) +} + +pub fn ewm_std( + xs: I, + alpha: T, + adjust: bool, + bias: bool, + min_periods: usize, + ignore_nulls: bool, +) -> PrimitiveArray +where + I: IntoIterator> + Clone, + I::IntoIter: TrustedLen, + T: Float + NativeType + AddAssign + MulAssign + DivAssign, +{ + ewm_cov_internal( + xs.clone(), + xs, + alpha, + adjust, + bias, + min_periods, + ignore_nulls, + true, + ) +} + +#[cfg(test)] +mod test { + use super::super::assert_allclose; + use super::*; + const ALPHA: f64 = 0.5; + const EPS: f64 = 1e-15; + use std::f64::consts::SQRT_2; + + const XS: [Option; 7] = [ + Some(1.0), + Some(5.0), + Some(7.0), + Some(1.0), + Some(2.0), + Some(1.0), + Some(4.0), + ]; + const YS: [Option; 7] = [None, Some(5.0), Some(7.0), None, None, Some(1.0), Some(4.0)]; + + #[test] + fn test_ewm_var() { + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, true, true, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(3.555_555_555_555_556), + Some(4.244_897_959_183_674), + Some(7.182_222_222_222_221), + Some(3.796_045_785_639_958), + Some(2.467_120_181_405_896), + Some(2.476_036_952_073_904_3), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, true, true, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(3.555_555_555_555_556), + Some(4.244_897_959_183_674), + Some(7.182_222_222_222_221), + Some(3.796_045_785_639_958), + Some(2.467_120_181_405_896), + Some(2.476_036_952_073_904_3), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, true, false, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(8.0), + Some(7.428_571_428_571_429), + Some(11.542_857_142_857_143), + Some(5.883_870_967_741_934_5), + Some(3.760_368_663_594_470_6), + Some(3.743_532_058_492_688_6), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, true, false, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(8.0), + Some(7.428_571_428_571_429), + Some(11.542_857_142_857_143), + Some(5.883_870_967_741_934_5), + Some(3.760_368_663_594_470_6), + Some(3.743_532_058_492_688_6), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, false, true, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(4.0), + Some(6.0), + Some(7.0), + Some(3.75), + Some(2.437_5), + Some(2.484_375), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, false, true, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(4.0), + Some(6.0), + Some(7.0), + Some(3.75), + Some(2.437_5), + Some(2.484_375), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, false, true, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(4.0), + Some(6.0), + Some(7.0), + Some(3.75), + Some(2.437_5), + Some(2.484_375), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, false, false, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(8.0), + Some(9.600_000_000_000_001), + Some(10.666_666_666_666_666), + Some(5.647_058_823_529_411), + Some(3.659_824_046_920_821), + Some(3.727_472_527_472_527_6), + ]), + EPS + ); + assert_allclose!( + ewm_var(XS.to_vec(), ALPHA, false, false, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(8.0), + Some(9.600_000_000_000_001), + Some(10.666_666_666_666_666), + Some(5.647_058_823_529_411), + Some(3.659_824_046_920_821), + Some(3.727_472_527_472_527_6), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, true, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.888_888_888_888_889), + None, + None, + Some(7.346_938_775_510_203), + Some(3.555_555_555_555_555_4), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, true, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.888_888_888_888_889), + None, + None, + Some(3.922_437_673_130_193_3), + Some(2.549_788_542_868_127_3), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, false, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(12.857_142_857_142_856), + Some(5.714_285_714_285_714), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, false, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(14.159_999_999_999_997), + Some(5.039_513_677_811_549_5), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, false, true, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(1.0), + None, + None, + Some(6.75), + Some(3.437_5), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, false, true, 0, false), + PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, false, false, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(10.8), + Some(5.238_095_238_095_238), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, false, false, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(12.352_941_176_470_589), + Some(5.299_145_299_145_3), + ]), + EPS + ); + } + + #[test] + fn test_ewm_cov() { + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.888_888_888_888_889), + None, + None, + Some(7.346_938_775_510_203), + Some(3.555_555_555_555_555_4) + ]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, true, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.888_888_888_888_889), + None, + None, + Some(3.922_437_673_130_193_3), + Some(2.549_788_542_868_127_3) + ]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(12.857_142_857_142_856), + Some(5.714_285_714_285_714) + ]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, true, false, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(14.159_999_999_999_997), + Some(5.039_513_677_811_549_5) + ]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(1.0), + None, + None, + Some(6.75), + Some(3.437_5) + ]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, true, 0, false), + PrimitiveArray::from([None, Some(0.0), Some(1.0), None, None, Some(4.2), Some(3.1)]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(10.8), + Some(5.238_095_238_095_238) + ]), + EPS + ); + assert_allclose!( + ewm_cov(XS.to_vec(), YS.to_vec(), ALPHA, false, false, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(2.0), + None, + None, + Some(12.352_941_176_470_589), + Some(5.299_145_299_145_3) + ]), + EPS + ); + } + + #[test] + fn test_ewm_std() { + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, true, true, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(1.885_618_083_164_126_7), + Some(2.060_315_014_550_851_3), + Some(2.679_966_832_298_904), + Some(1.948_344_370_392_451_5), + Some(1.570_706_904_997_204_2), + Some(1.573_542_802_746_053_2), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, true, true, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(1.885_618_083_164_126_7), + Some(2.060_315_014_550_851_3), + Some(2.679_966_832_298_904), + Some(1.948_344_370_392_451_5), + Some(1.570_706_904_997_204_2), + Some(1.573_542_802_746_053_2), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, true, false, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(2.828_427_124_746_190_3), + Some(2.725_540_575_476_987_5), + Some(3.397_478_056_273_085_3), + Some(2.425_669_179_369_259), + Some(1.939_167_002_502_484_5), + Some(1.934_820_937_061_796_6), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, true, false, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(2.828_427_124_746_190_3), + Some(2.725_540_575_476_987_5), + Some(3.397_478_056_273_085_3), + Some(2.425_669_179_369_259), + Some(1.939_167_002_502_484_5), + Some(1.934_820_937_061_796_6), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, false, true, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(2.0), + Some(2.449_489_742_783_178), + Some(2.645_751_311_064_590_7), + Some(1.936_491_673_103_708_5), + Some(1.561_249_499_599_599_6), + Some(1.576_190_026_614_811_4), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, false, true, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(2.0), + Some(2.449_489_742_783_178), + Some(2.645_751_311_064_590_7), + Some(1.936_491_673_103_708_5), + Some(1.561_249_499_599_599_6), + Some(1.576_190_026_614_811_4), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, false, false, 0, true), + PrimitiveArray::from([ + Some(0.0), + Some(2.828_427_124_746_190_3), + Some(3.098_386_676_965_933_6), + Some(3.265_986_323_710_904), + Some(2.376_354_103_144_018_3), + Some(1.913_066_660_344_281_2), + Some(1.930_666_342_865_210_7), + ]), + EPS + ); + assert_allclose!( + ewm_std(XS.to_vec(), ALPHA, false, false, 0, false), + PrimitiveArray::from([ + Some(0.0), + Some(2.828_427_124_746_190_3), + Some(3.098_386_676_965_933_6), + Some(3.265_986_323_710_904), + Some(2.376_354_103_144_018_3), + Some(1.913_066_660_344_281_2), + Some(1.930_666_342_865_210_7), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, true, true, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.942_809_041_582_063_4), + None, + None, + Some(2.710_523_708_715_753_4), + Some(1.885_618_083_164_126_7), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, true, true, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.942_809_041_582_063_4), + None, + None, + Some(1.980_514_497_076_503), + Some(1.596_805_731_098_222), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, true, false, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(SQRT_2), + None, + None, + Some(3.585_685_828_003_181), + Some(2.390_457_218_668_787), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, true, false, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(SQRT_2), + None, + None, + Some(3.762_977_544_445_355_3), + Some(2.244_886_116_891_356), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, false, true, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(1.0), + None, + None, + Some(2.598_076_211_353_316), + Some(1.854_049_621_773_915_7), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, false, true, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(1.0), + None, + None, + Some(2.049_390_153_191_92), + Some(1.760_681_686_165_901), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, false, false, 0, true), + PrimitiveArray::from([ + None, + Some(0.0), + Some(SQRT_2), + None, + None, + Some(3.286_335_345_030_997), + Some(2.288_688_541_085_317_5), + ]), + EPS + ); + assert_allclose!( + ewm_std(YS.to_vec(), ALPHA, false, false, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(SQRT_2), + None, + None, + Some(3.514_675_116_774_036_7), + Some(2.301_987_249_996_250_4), + ]), + EPS + ); + } + + #[test] + fn test_ewm_min_periods() { + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, true, 0, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.888_888_888_888_889), + None, + None, + Some(3.922_437_673_130_193_3), + Some(2.549_788_542_868_127_3), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, true, 1, false), + PrimitiveArray::from([ + None, + Some(0.0), + Some(0.888_888_888_888_889), + None, + None, + Some(3.922_437_673_130_193_3), + Some(2.549_788_542_868_127_3), + ]), + EPS + ); + assert_allclose!( + ewm_var(YS.to_vec(), ALPHA, true, true, 2, false), + PrimitiveArray::from([ + None, + None, + Some(0.888_888_888_888_889), + None, + None, + Some(3.922_437_673_130_193_3), + Some(2.549_788_542_868_127_3), + ]), + EPS + ); + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/mod.rs b/crates/polars-arrow/src/legacy/kernels/mod.rs new file mode 100644 index 000000000000..651faf789347 --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/mod.rs @@ -0,0 +1,301 @@ +use std::iter::Enumerate; + +use crate::array::BooleanArray; +use crate::bitmap::utils::BitChunks; +pub mod ewm; +pub mod set; +pub mod sort_partition; +#[cfg(feature = "performant")] +pub mod sorted_join; +#[cfg(feature = "strings")] +pub mod string; +pub mod take_agg; +mod time; + +pub use time::{Ambiguous, NonExistent}; +#[cfg(feature = "timezones")] +pub use time::{convert_to_naive_local, convert_to_naive_local_opt}; + +/// Internal state of [SlicesIterator] +#[derive(Debug, PartialEq)] +enum State { + // it is iterating over bits of a mask (`u64`, steps of size of 1 slot) + Bits(u64), + // it is iterating over chunks (steps of size of 64 slots) + Chunks, + // it is iterating over the remaining bits (steps of size of 1 slot) + Remainder, + // nothing more to iterate. + Finish, +} + +/// Forked and modified from arrow crate. +/// +/// An iterator of `(usize, usize)` each representing an interval `[start,end[` whose +/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory to be +/// "taken" from an array to be filtered. +#[derive(Debug)] +struct MaskedSlicesIterator<'a> { + iter: Enumerate>, + state: State, + remainder_mask: u64, + remainder_len: usize, + chunk_len: usize, + len: usize, + start: usize, + on_region: bool, + current_chunk: usize, + current_bit: usize, + total_len: usize, +} + +impl<'a> MaskedSlicesIterator<'a> { + pub(crate) fn new(mask: &'a BooleanArray) -> Self { + let chunks = mask.values().chunks::(); + + let chunk_len = mask.len() / 64; + let remainder_len = chunks.remainder_len(); + let remainder_mask = chunks.remainder(); + + Self { + iter: chunks.enumerate(), + state: State::Chunks, + remainder_len, + chunk_len, + remainder_mask, + len: 0, + start: 0, + on_region: false, + current_chunk: 0, + current_bit: 0, + total_len: mask.len(), + } + } + + #[inline] + fn current_start(&self) -> usize { + self.current_chunk * 64 + self.current_bit + } + + #[inline] + fn iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)> { + while self.current_bit < max { + if (mask & (1 << self.current_bit)) != 0 { + if !self.on_region { + self.start = self.current_start(); + self.on_region = true; + } + self.len += 1; + } else if self.on_region { + let result = (self.start, self.start + self.len); + self.len = 0; + self.on_region = false; + self.current_bit += 1; + return Some(result); + } + self.current_bit += 1; + } + self.current_bit = 0; + None + } + + /// iterates over chunks. + #[inline] + fn iterate_chunks(&mut self) -> Option<(usize, usize)> { + while let Some((i, mask)) = self.iter.next() { + self.current_chunk = i; + if mask == 0 { + if self.on_region { + let result = (self.start, self.start + self.len); + self.len = 0; + self.on_region = false; + return Some(result); + } + } else if mask == u64::MAX { + // = !0u64 + if !self.on_region { + self.start = self.current_start(); + self.on_region = true; + } + self.len += 64; + } else { + // there is a chunk that has a non-trivial mask => iterate over bits. + self.state = State::Bits(mask); + return None; + } + } + // no more chunks => start iterating over the remainder + self.current_chunk = self.chunk_len; + self.state = State::Remainder; + None + } +} + +impl Iterator for MaskedSlicesIterator<'_> { + type Item = (usize, usize); + + fn next(&mut self) -> Option { + match self.state { + State::Chunks => { + match self.iterate_chunks() { + None => { + // iterating over chunks does not yield any new slice => continue to the next + self.current_bit = 0; + self.next() + }, + other => other, + } + }, + State::Bits(mask) => { + match self.iterate_bits(mask, 64) { + None => { + // iterating over bits does not yield any new slice => change back + // to chunks and continue to the next + self.state = State::Chunks; + self.next() + }, + other => other, + } + }, + State::Remainder => match self.iterate_bits(self.remainder_mask, self.remainder_len) { + None => { + self.state = State::Finish; + if self.on_region { + Some((self.start, self.start + self.len)) + } else { + None + } + }, + other => other, + }, + State::Finish => None, + } + } +} + +#[derive(Eq, PartialEq, Debug)] +enum BinaryMaskedState { + Start, + // Last masks were false values + LastFalse, + // Last masks were true values + LastTrue, + Finish, +} + +pub(crate) struct BinaryMaskedSliceIterator<'a> { + slice_iter: MaskedSlicesIterator<'a>, + filled: usize, + low: usize, + high: usize, + state: BinaryMaskedState, +} + +impl<'a> BinaryMaskedSliceIterator<'a> { + pub(crate) fn new(mask: &'a BooleanArray) -> Self { + Self { + slice_iter: MaskedSlicesIterator::new(mask), + filled: 0, + low: 0, + high: 0, + state: BinaryMaskedState::Start, + } + } +} + +impl Iterator for BinaryMaskedSliceIterator<'_> { + type Item = (usize, usize, bool); + + fn next(&mut self) -> Option { + use BinaryMaskedState::*; + + match self.state { + Start => { + // first iteration + if self.low == 0 && self.high == 0 { + match self.slice_iter.next() { + Some((low, high)) => { + self.low = low; + self.high = high; + + if low > 0 { + // do another start iteration. + Some((0, low, false)) + } else { + self.state = LastTrue; + self.filled = high; + Some((low, high, true)) + } + }, + None => { + self.state = Finish; + Some((self.filled, self.slice_iter.total_len, false)) + }, + } + } else { + self.filled = self.high; + self.state = LastTrue; + Some((self.low, self.high, true)) + } + }, + LastFalse => { + self.state = LastTrue; + self.filled = self.high; + Some((self.low, self.high, true)) + }, + LastTrue => match self.slice_iter.next() { + Some((low, high)) => { + self.low = low; + self.high = high; + self.state = LastFalse; + let last_filled = self.filled; + self.filled = low; + Some((last_filled, low, false)) + }, + None => { + self.state = Finish; + if self.filled != self.slice_iter.total_len { + Some((self.filled, self.slice_iter.total_len, false)) + } else { + None + } + }, + }, + Finish => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_binary_masked_slice_iter() { + let mask = BooleanArray::from_slice([false, false, true, true, true, false, false]); + + let out = BinaryMaskedSliceIterator::new(&mask) + .map(|(_, _, b)| b) + .collect::>(); + assert_eq!(out, &[false, true, false]); + let out = BinaryMaskedSliceIterator::new(&mask).collect::>(); + assert_eq!(out, &[(0, 2, false), (2, 5, true), (5, 7, false)]); + let mask = BooleanArray::from_slice([true, true, false, true]); + let out = BinaryMaskedSliceIterator::new(&mask) + .map(|(_, _, b)| b) + .collect::>(); + assert_eq!(out, &[true, false, true]); + let mask = BooleanArray::from_slice([true, true, false, true, true]); + let out = BinaryMaskedSliceIterator::new(&mask) + .map(|(_, _, b)| b) + .collect::>(); + assert_eq!(out, &[true, false, true]); + } + + #[test] + fn test_binary_slice_mask_iter_with_false() { + let mask = BooleanArray::from_slice([false, false]); + let out = BinaryMaskedSliceIterator::new(&mask).collect::>(); + assert_eq!(out, &[(0, 2, false)]); + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/set.rs b/crates/polars-arrow/src/legacy/kernels/set.rs new file mode 100644 index 000000000000..46a53ff75955 --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/set.rs @@ -0,0 +1,139 @@ +use std::ops::BitOr; + +use polars_error::polars_err; +use polars_utils::IdxSize; + +use crate::array::*; +use crate::datatypes::ArrowDataType; +use crate::legacy::array::default_arrays::FromData; +use crate::legacy::error::PolarsResult; +use crate::legacy::kernels::BinaryMaskedSliceIterator; +use crate::legacy::trusted_len::TrustedLenPush; +use crate::types::NativeType; + +/// Set values in a primitive array where the primitive array has null values. +/// this is faster because we don't have to invert and combine bitmaps +pub fn set_at_nulls(array: &PrimitiveArray, value: T) -> PrimitiveArray +where + T: NativeType, +{ + let values = array.values(); + if array.null_count() == 0 { + return array.clone(); + } + + let validity = array.validity().unwrap(); + let validity = BooleanArray::from_data_default(validity.clone(), None); + + let mut av = Vec::with_capacity(array.len()); + BinaryMaskedSliceIterator::new(&validity).for_each(|(lower, upper, truthy)| { + if truthy { + av.extend_from_slice(&values[lower..upper]) + } else { + av.extend_trusted_len(std::iter::repeat_n(value, upper - lower)) + } + }); + + PrimitiveArray::new(array.dtype().clone(), av.into(), None) +} + +/// Set values in a primitive array based on a mask array. This is fast when large chunks of bits are set or unset. +pub fn set_with_mask( + array: &PrimitiveArray, + mask: &BooleanArray, + value: T, + dtype: ArrowDataType, +) -> PrimitiveArray { + let values = array.values(); + + let mut buf = Vec::with_capacity(array.len()); + BinaryMaskedSliceIterator::new(mask).for_each(|(lower, upper, truthy)| { + if truthy { + buf.extend_trusted_len(std::iter::repeat_n(value, upper - lower)) + } else { + buf.extend_from_slice(&values[lower..upper]) + } + }); + // make sure that where the mask is set to true, the validity buffer is also set to valid + // after we have applied the or operation we have new buffer with no offsets + let validity = array.validity().as_ref().map(|valid| { + let mask_bitmap = mask.values(); + valid.bitor(mask_bitmap) + }); + + PrimitiveArray::new(dtype, buf.into(), validity) +} + +/// Efficiently sets value at the indices from the iterator to `set_value`. +/// The new array is initialized with a `memcpy` from the old values. +pub fn scatter_single_non_null( + array: &PrimitiveArray, + idx: I, + set_value: T, + dtype: ArrowDataType, +) -> PolarsResult> +where + T: NativeType, + I: IntoIterator, +{ + let mut buf = Vec::with_capacity(array.len()); + buf.extend_from_slice(array.values().as_slice()); + let mut_slice = buf.as_mut_slice(); + + idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| { + let val = mut_slice + .get_mut(idx as usize) + .ok_or_else(|| polars_err!(ComputeError: "index is out of bounds"))?; + *val = set_value; + Ok(()) + })?; + + Ok(PrimitiveArray::new( + dtype, + buf.into(), + array.validity().cloned(), + )) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_set_mask() { + let mask = BooleanArray::from_iter((0..86).map(|v| v > 68 && v != 85).map(Some)); + let val = UInt32Array::from_iter((0..86).map(Some)); + let a = set_with_mask(&val, &mask, 100, ArrowDataType::UInt32); + let slice = a.values(); + + assert_eq!(slice[a.len() - 1], 85); + assert_eq!(slice[a.len() - 2], 100); + assert_eq!(slice[67], 67); + assert_eq!(slice[68], 68); + assert_eq!(slice[1], 1); + assert_eq!(slice[0], 0); + + let mask = BooleanArray::from_slice([ + false, true, false, true, false, true, false, true, false, false, + ]); + let val = UInt32Array::from_slice([0; 10]); + let out = set_with_mask(&val, &mask, 1, ArrowDataType::UInt32); + assert_eq!(out.values().as_slice(), &[0, 1, 0, 1, 0, 1, 0, 1, 0, 0]); + + let val = UInt32Array::from(&[None, None, None]); + let mask = BooleanArray::from(&[Some(true), Some(true), None]); + let out = set_with_mask(&val, &mask, 1, ArrowDataType::UInt32); + let out: Vec<_> = out.iter().map(|v| v.copied()).collect(); + assert_eq!(out, &[Some(1), Some(1), None]) + } + + #[test] + fn test_scatter_single_non_null() { + let val = UInt32Array::from_slice([1, 2, 3]); + let out = + scatter_single_non_null(&val, std::iter::once(1), 100, ArrowDataType::UInt32).unwrap(); + assert_eq!(out.values().as_slice(), &[1, 100, 3]); + let out = scatter_single_non_null(&val, std::iter::once(100), 100, ArrowDataType::UInt32); + assert!(out.is_err()) + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/sort_partition.rs b/crates/polars-arrow/src/legacy/kernels/sort_partition.rs new file mode 100644 index 000000000000..cf7de348f72a --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/sort_partition.rs @@ -0,0 +1,207 @@ +use std::fmt::Debug; + +use polars_utils::IdxSize; +use polars_utils::itertools::Itertools; +use polars_utils::total_ord::TotalEq; + +use crate::types::NativeType; + +/// Find partition indexes such that every partition contains unique groups. +fn find_partition_points(values: &[T], n: usize, descending: bool) -> Vec +where + T: Debug + NativeType, +{ + let len = values.len(); + if n > len { + return find_partition_points(values, len / 2, descending); + } + if n < 2 { + return vec![]; + } + let chunk_size = len / n; + + let mut partition_points = Vec::with_capacity(n + 1); + + let mut start_idx = 0; + loop { + let end_idx = start_idx + chunk_size; + if end_idx >= len { + break; + } + // first take that partition as a slice + // and then find the location where the group of the latest value starts + let part = &values[start_idx..end_idx]; + + let latest_val = values[end_idx]; + let idx = if descending { + part.partition_point(|v| v.tot_gt(&latest_val)) + } else { + part.partition_point(|v| v.tot_lt(&latest_val)) + }; + + if idx != 0 { + partition_points.push(idx + start_idx) + } + + start_idx += chunk_size; + } + partition_points +} + +pub fn create_clean_partitions(values: &[T], n: usize, descending: bool) -> Vec<&[T]> +where + T: Debug + NativeType, +{ + let part_idx = find_partition_points(values, n, descending); + let mut out = Vec::with_capacity(n + 1); + + let mut start_idx = 0_usize; + for end_idx in part_idx { + if end_idx != start_idx { + out.push(&values[start_idx..end_idx]); + start_idx = end_idx; + } + } + let latest = &values[start_idx..]; + if !latest.is_empty() { + out.push(latest) + } + + out +} + +pub fn partition_to_groups_amortized_varsize( + values: I, + values_len: IdxSize, + first_group_offset: IdxSize, + nulls_first: bool, + offset: IdxSize, + out: &mut Vec<[IdxSize; 2]>, +) where + T: Debug + TotalEq, + I: IntoIterator, +{ + let mut values = values.into_iter().enumerate_idx(); + if let Some((i, mut first)) = values.next() { + out.clear(); + if nulls_first && first_group_offset > 0 { + out.push([0, first_group_offset]) + } + + let mut first_idx = if nulls_first { first_group_offset } else { 0 } + offset; + let mut start = i; + + for (i, val) in values { + // new group reached + if val.tot_ne(&first) { + let len = i - start; + start = i; + out.push([first_idx, len]); + first_idx += len; + first = val; + } + } + // add last group + if nulls_first { + out.push([first_idx, values_len + first_group_offset - first_idx]); + } else { + out.push([first_idx, values_len - (first_idx - offset)]); + } + + if !nulls_first && first_group_offset > 0 { + out.push([values_len + offset, first_group_offset]) + } + } +} + +pub fn partition_to_groups_amortized( + values: &[T], + first_group_offset: IdxSize, + nulls_first: bool, + offset: IdxSize, + out: &mut Vec<[IdxSize; 2]>, +) where + T: Debug + TotalEq + Sized, +{ + if let Some(mut first) = values.first() { + out.clear(); + if nulls_first && first_group_offset > 0 { + out.push([0, first_group_offset]) + } + + let mut first_idx = if nulls_first { first_group_offset } else { 0 } + offset; + + for val in values { + // new group reached + if val.tot_ne(first) { + let val_ptr = val as *const T; + let first_ptr = first as *const T; + + // SAFETY: + // all pointers suffice the invariants + let len = unsafe { val_ptr.offset_from(first_ptr) } as IdxSize; + out.push([first_idx, len]); + first_idx += len; + first = val; + } + } + // add last group + if nulls_first { + out.push([ + first_idx, + values.len() as IdxSize + first_group_offset - first_idx, + ]); + } else { + out.push([first_idx, values.len() as IdxSize - (first_idx - offset)]); + } + + if !nulls_first && first_group_offset > 0 { + out.push([values.len() as IdxSize + offset, first_group_offset]) + } + } +} + +/// Take a clean-partitioned slice and return the groups slices +/// With clean-partitioned we mean that the slice contains all groups and are not spilled to another partition. +/// +/// `first_group_offset` can be used to add insert the `null` values group. +pub fn partition_to_groups( + values: &[T], + first_group_offset: IdxSize, + nulls_first: bool, + offset: IdxSize, +) -> Vec<[IdxSize; 2]> +where + T: Debug + NativeType, +{ + if values.is_empty() { + return vec![]; + } + let mut out = Vec::with_capacity(values.len() / 10); + partition_to_groups_amortized(values, first_group_offset, nulls_first, offset, &mut out); + out +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_partition_points() { + let values = &[1, 3, 3, 3, 3, 5, 5, 5, 9, 9, 10]; + + assert_eq!(find_partition_points(values, 4, false), &[1, 5, 8, 10]); + assert_eq!( + partition_to_groups(values, 0, true, 0), + &[[0, 1], [1, 4], [5, 3], [8, 2], [10, 1]] + ); + assert_eq!( + partition_to_groups(values, 5, true, 0), + &[[0, 5], [5, 1], [6, 4], [10, 3], [13, 2], [15, 1]] + ); + assert_eq!( + partition_to_groups(values, 5, false, 0), + &[[0, 1], [1, 4], [5, 3], [8, 2], [10, 1], [11, 5]] + ); + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/sorted_join/inner.rs b/crates/polars-arrow/src/legacy/kernels/sorted_join/inner.rs new file mode 100644 index 000000000000..e87e6b2ca1ee --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/sorted_join/inner.rs @@ -0,0 +1,95 @@ +use super::*; + +pub fn join( + left: &[T], + right: &[T], + left_offset: IdxSize, +) -> InnerJoinIds { + if left.is_empty() || right.is_empty() { + return (vec![], vec![]); + } + + // * 1.5 because of possible duplicates + let cap = (std::cmp::min(left.len(), right.len()) as f32 * 1.5) as usize; + let mut out_rhs = Vec::with_capacity(cap); + let mut out_lhs = Vec::with_capacity(cap); + + let mut right_idx = 0 as IdxSize; + // left array could start lower than right; + // left: [-1, 0, 1, 2], + // right: [1, 2, 3] + let first_right = right[0]; + let mut left_idx = left.partition_point(|v| v < &first_right) as IdxSize; + + for &val_l in &left[left_idx as usize..] { + while let Some(&val_r) = right.get(right_idx as usize) { + // matching join key + if val_l == val_r { + out_lhs.push(left_idx + left_offset); + out_rhs.push(right_idx); + let current_idx = right_idx; + + loop { + right_idx += 1; + match right.get(right_idx as usize) { + // rhs depleted + None => { + // reset right index because the next lhs value can be the same + right_idx = current_idx; + break; + }, + Some(&val_r) => { + if val_l == val_r { + out_lhs.push(left_idx + left_offset); + out_rhs.push(right_idx); + } else { + // reset right index because the next lhs value can be the same + right_idx = current_idx; + break; + } + }, + } + } + break; + } + + // right is larger than left. + if val_r > val_l { + break; + } + // continue looping the right side + right_idx += 1; + } + left_idx += 1; + } + (out_lhs, out_rhs) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_inner_join() { + let lhs = &[0, 1, 1, 2, 3, 5]; + let rhs = &[0, 1, 1, 3, 4]; + + let (l_idx, r_idx) = join(lhs, rhs, 0); + + assert_eq!(&l_idx, &[0, 1, 1, 2, 2, 4]); + assert_eq!(&r_idx, &[0, 1, 2, 1, 2, 3]); + + let lhs = &[4, 4, 4, 4, 5, 6, 6, 7, 7, 7]; + let rhs = &[0, 1, 2, 3, 4, 4, 4, 6, 7, 7]; + let (l_idx, r_idx) = join(lhs, rhs, 0); + + assert_eq!( + &l_idx, + &[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 5, 6, 7, 7, 8, 8, 9, 9] + ); + assert_eq!( + &r_idx, + &[4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 7, 8, 9, 8, 9, 8, 9] + ); + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs b/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs new file mode 100644 index 000000000000..6e35ba7c48bc --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs @@ -0,0 +1,195 @@ +use super::*; + +pub fn join( + left: &[T], + right: &[T], + left_offset: IdxSize, +) -> LeftJoinIds { + if left.is_empty() { + return (vec![], vec![]); + } + if right.is_empty() { + return ( + (left_offset..left.len() as IdxSize + left_offset).collect(), + vec![NullableIdxSize::null(); left.len()], + ); + } + // * 1.5 because there can be duplicates + let cap = (left.len() as f32 * 1.5) as usize; + let mut out_rhs = Vec::with_capacity(cap); + let mut out_lhs = Vec::with_capacity(cap); + + let mut right_idx = 0 as IdxSize; + // left array could start lower than right; + // left: [-1, 0, 1, 2], + // right: [1, 2, 3] + // first values should be None, until left has caught up + + let first_right = right[right_idx as usize]; + let mut left_idx = left.partition_point(|v| v < &first_right) as IdxSize; + out_rhs.extend(std::iter::repeat_n( + NullableIdxSize::null(), + left_idx as usize, + )); + out_lhs.extend(left_offset..(left_idx + left_offset)); + + for &val_l in &left[left_idx as usize..] { + loop { + match right.get(right_idx as usize) { + Some(&val_r) => { + // matching join key + if val_l == val_r { + out_lhs.push(left_idx + left_offset); + out_rhs.push(right_idx.into()); + let current_idx = right_idx; + + loop { + right_idx += 1; + match right.get(right_idx as usize) { + // rhs depleted + None => { + // reset right index because the next lhs value can be the same + right_idx = current_idx; + break; + }, + Some(&val_r) => { + if val_l == val_r { + out_lhs.push(left_idx + left_offset); + out_rhs.push(right_idx.into()); + } else { + // reset right index because the next lhs value can be the same + right_idx = current_idx; + break; + } + }, + } + } + break; + } + + // right is larger than left. + if val_r > val_l { + out_lhs.push(left_idx + left_offset); + out_rhs.push(NullableIdxSize::null()); + break; + } + // continue looping the right side + right_idx += 1; + }, + // we depleted the right array + None => { + out_lhs.push(left_idx + left_offset); + out_rhs.push(NullableIdxSize::null()); + break; + }, + } + } + left_idx += 1; + } + (out_lhs, out_rhs) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_left_join() { + let lhs = &[0, 1, 1, 2, 3, 5]; + let rhs = &[0, 1, 1, 3, 4]; + + let (l_idx, r_idx) = join(lhs, rhs, 0); + let out_left = &[0, 1, 1, 2, 2, 3, 4, 5]; + let out_right = &[ + 0.into(), + 1.into(), + 2.into(), + 1.into(), + 2.into(), + NullableIdxSize::null(), + 3.into(), + NullableIdxSize::null(), + ]; + assert_eq!(&l_idx, out_left); + assert_eq!(&r_idx, out_right); + + let offset = 2; + let (l_idx, r_idx) = join(&lhs[offset..], rhs, offset as IdxSize); + assert_eq!(l_idx, out_left[3..]); + assert_eq!(r_idx, out_right[3..]); + + let offset = 3; + let (l_idx, r_idx) = join(&lhs[offset..], rhs, offset as IdxSize); + assert_eq!(l_idx, out_left[5..]); + assert_eq!(r_idx, out_right[5..]); + + let lhs = &[0, 0, 1, 3, 4, 5, 6, 6, 6, 7]; + let rhs = &[0, 0, 1, 3, 4, 6, 6]; + + let (l_idx, r_idx) = join(lhs, rhs, 0); + assert_eq!(&l_idx, &[0, 0, 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8, 8, 9]); + assert_eq!( + &r_idx, + &[ + 0.into(), + 1.into(), + 0.into(), + 1.into(), + 2.into(), + 3.into(), + 4.into(), + NullableIdxSize::null(), + 5.into(), + 6.into(), + 5.into(), + 6.into(), + 5.into(), + 6.into(), + NullableIdxSize::null(), + ] + ); + + let lhs = &[1, 3, 4, 5, 5, 5, 5, 6, 7, 7]; + let rhs = &[2, 4, 5, 6, 7, 8, 10, 11, 11, 12, 12, 12, 12, 13]; + let (l_idx, r_idx) = join(lhs, rhs, 0); + assert_eq!(&l_idx, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + assert_eq!( + &r_idx, + &[ + NullableIdxSize::null(), + NullableIdxSize::null(), + 1.into(), + 2.into(), + 2.into(), + 2.into(), + 2.into(), + 3.into(), + 4.into(), + 4.into() + ] + ); + let lhs = &[0, 1, 2, 2, 3, 4, 4, 6, 6, 7]; + let rhs = &[4, 4, 4, 8]; + let (l_idx, r_idx) = join(lhs, rhs, 0); + assert_eq!(&l_idx, &[0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 6, 7, 8, 9]); + assert_eq!( + &r_idx, + &[ + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + 0.into(), + 1.into(), + 2.into(), + 0.into(), + 1.into(), + 2.into(), + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + ] + ) + } +} diff --git a/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs b/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs new file mode 100644 index 000000000000..5aea170f30d5 --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs @@ -0,0 +1,11 @@ +pub mod inner; +pub mod left; + +use std::fmt::Debug; + +use polars_utils::{IdxSize, NullableIdxSize}; + +type JoinOptIds = Vec; +type JoinIds = Vec; +type LeftJoinIds = (JoinIds, JoinOptIds); +type InnerJoinIds = (JoinIds, JoinIds); diff --git a/crates/polars-arrow/src/legacy/kernels/string.rs b/crates/polars-arrow/src/legacy/kernels/string.rs new file mode 100644 index 000000000000..4733605030ea --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/string.rs @@ -0,0 +1,18 @@ +use crate::array::{Array, ArrayRef, UInt32Array, Utf8ViewArray}; +use crate::buffer::Buffer; +use crate::datatypes::ArrowDataType; +use crate::legacy::trusted_len::TrustedLenPush; + +pub fn utf8view_len_bytes(array: &Utf8ViewArray) -> ArrayRef { + let values = array.len_iter().collect::>(); + let values: Buffer<_> = values.into(); + let array = UInt32Array::new(ArrowDataType::UInt32, values, array.validity().cloned()); + Box::new(array) +} + +pub fn string_len_chars(array: &Utf8ViewArray) -> ArrayRef { + let values = array.values_iter().map(|x| x.chars().count() as u32); + let values: Buffer<_> = Vec::from_trusted_len_iter(values).into(); + let array = UInt32Array::new(ArrowDataType::UInt32, values, array.validity().cloned()); + Box::new(array) +} diff --git a/crates/polars-arrow/src/legacy/kernels/take_agg/boolean.rs b/crates/polars-arrow/src/legacy/kernels/take_agg/boolean.rs new file mode 100644 index 000000000000..8397666e40fa --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/take_agg/boolean.rs @@ -0,0 +1,90 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use super::*; + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_min_bool_iter_unchecked_nulls>( + arr: &BooleanArray, + indices: I, + len: IdxSize, +) -> Option { + let mut null_count = 0 as IdxSize; + let validity = arr.validity().unwrap(); + + for idx in indices { + if validity.get_bit_unchecked(idx) { + if !arr.value_unchecked(idx) { + return Some(false); + } + } else { + null_count += 1; + } + } + if null_count == len { None } else { Some(true) } +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_min_bool_iter_unchecked_no_nulls>( + arr: &BooleanArray, + indices: I, +) -> Option { + if arr.is_empty() { + return None; + } + + for idx in indices { + if !arr.value_unchecked(idx) { + return Some(false); + } + } + Some(true) +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_max_bool_iter_unchecked_nulls>( + arr: &BooleanArray, + indices: I, + len: IdxSize, +) -> Option { + let mut null_count = 0 as IdxSize; + let validity = arr.validity().unwrap(); + + for idx in indices { + if validity.get_bit_unchecked(idx) { + if arr.value_unchecked(idx) { + return Some(true); + } + } else { + null_count += 1; + } + } + if null_count == len { None } else { Some(false) } +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_max_bool_iter_unchecked_no_nulls>( + arr: &BooleanArray, + indices: I, +) -> Option { + if arr.is_empty() { + return None; + } + + for idx in indices { + if arr.value_unchecked(idx) { + return Some(true); + } + } + Some(false) +} diff --git a/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs b/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs new file mode 100644 index 000000000000..b25ae16f3c91 --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs @@ -0,0 +1,156 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! kernels that combine take and aggregations. +mod boolean; +mod var; + +pub use boolean::*; +use num_traits::{NumCast, ToPrimitive}; +use polars_utils::IdxSize; +pub use var::*; + +use crate::array::{Array, BinaryViewArray, BooleanArray, PrimitiveArray}; +use crate::types::NativeType; + +/// Take kernel for single chunk without nulls and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_agg_no_null_primitive_iter_unchecked< + T: NativeType + ToPrimitive, + TOut: NumCast + NativeType, + I: IntoIterator, + F: Fn(TOut, TOut) -> TOut, +>( + arr: &PrimitiveArray, + indices: I, + f: F, +) -> Option { + debug_assert!(arr.null_count() == 0); + let array_values = arr.values().as_slice(); + + indices + .into_iter() + .map(|idx| TOut::from(*array_values.get_unchecked(idx)).unwrap_unchecked()) + .reduce(f) +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_agg_primitive_iter_unchecked< + T: NativeType, + I: IntoIterator, + F: Fn(T, T) -> T, +>( + arr: &PrimitiveArray, + indices: I, + f: F, +) -> Option { + let array_values = arr.values().as_slice(); + let validity = arr.validity().unwrap(); + + indices + .into_iter() + .filter(|&idx| validity.get_bit_unchecked(idx)) + .map(|idx| *array_values.get_unchecked(idx)) + .reduce(f) +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must enure iterators indexes are in bounds +#[inline] +pub unsafe fn take_agg_primitive_iter_unchecked_count_nulls< + T: NativeType + ToPrimitive, + TOut: NumCast + NativeType, + I: IntoIterator, + F: Fn(TOut, TOut) -> TOut, +>( + arr: &PrimitiveArray, + indices: I, + f: F, + init: TOut, + len: IdxSize, +) -> Option<(TOut, IdxSize)> { + let array_values = arr.values().as_slice(); + let validity = arr.validity().expect("null buffer should be there"); + + let mut null_count = 0 as IdxSize; + let out = indices.into_iter().fold(init, |acc, idx| { + if validity.get_bit_unchecked(idx) { + f( + acc, + NumCast::from(*array_values.get_unchecked(idx)).unwrap_unchecked(), + ) + } else { + null_count += 1; + acc + } + }); + if null_count == len { + None + } else { + Some((out, null_count)) + } +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_agg_bin_iter_unchecked< + 'a, + I: IntoIterator, + F: Fn(&'a [u8], &'a [u8]) -> &'a [u8], +>( + arr: &'a BinaryViewArray, + indices: I, + f: F, + len: IdxSize, +) -> Option<&'a [u8]> { + let mut null_count = 0 as IdxSize; + let validity = arr.validity().unwrap(); + + let out = indices + .into_iter() + .map(|idx| { + if validity.get_bit_unchecked(idx) { + Some(arr.value_unchecked(idx)) + } else { + None + } + }) + .reduce(|acc, opt_val| match (acc, opt_val) { + (Some(acc), Some(str_val)) => Some(f(acc, str_val)), + (_, None) => { + null_count += 1; + acc + }, + (None, Some(str_val)) => Some(str_val), + }); + if null_count == len { + None + } else { + out.flatten() + } +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +#[inline] +pub unsafe fn take_agg_bin_iter_unchecked_no_null< + 'a, + I: IntoIterator, + F: Fn(&'a [u8], &'a [u8]) -> &'a [u8], +>( + arr: &'a BinaryViewArray, + indices: I, + f: F, +) -> Option<&'a [u8]> { + indices + .into_iter() + .map(|idx| arr.value_unchecked(idx)) + .reduce(|acc, str_val| f(acc, str_val)) +} diff --git a/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs b/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs new file mode 100644 index 000000000000..3a15062ab674 --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs @@ -0,0 +1,92 @@ +use super::*; + +/// Numerical stable online variance aggregation. +/// +/// See: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// and: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. +pub fn online_variance( + // iterator producing values + iter: I, + ddof: u8, +) -> Option +where + I: IntoIterator, +{ + let mut m2 = 0.0; + let mut mean = 0.0; + let mut count = 0u64; + + for value in iter { + let new_count = count + 1; + let delta_1 = value - mean; + let new_mean = delta_1 / new_count as f64 + mean; + let delta_2 = value - new_mean; + let new_m2 = m2 + delta_1 * delta_2; + + count += 1; + mean = new_mean; + m2 = new_m2; + } + + if count <= ddof as u64 { + return None; + } + + Some(m2 / (count as f64 - ddof as f64)) +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +pub unsafe fn take_var_no_null_primitive_iter_unchecked( + arr: &PrimitiveArray, + indices: I, + ddof: u8, +) -> Option +where + T: NativeType + ToPrimitive, + I: IntoIterator, +{ + debug_assert!(arr.null_count() == 0); + let array_values = arr.values().as_slice(); + let iter = unsafe { + indices.into_iter().map(|idx| { + let value = *array_values.get_unchecked(idx); + value.to_f64().unwrap_unchecked() + }) + }; + online_variance(iter, ddof) +} + +/// Take kernel for single chunk and an iterator as index. +/// # Safety +/// caller must ensure iterators indexes are in bounds +pub unsafe fn take_var_nulls_primitive_iter_unchecked( + arr: &PrimitiveArray, + indices: I, + ddof: u8, +) -> Option +where + T: NativeType + ToPrimitive, + I: IntoIterator, +{ + debug_assert!(arr.null_count() > 0); + let array_values = arr.values().as_slice(); + let validity = arr.validity().unwrap(); + + let iter = unsafe { + indices.into_iter().flat_map(|idx| { + if validity.get_bit_unchecked(idx) { + let value = *array_values.get_unchecked(idx); + value.to_f64() + } else { + None + } + }) + }; + online_variance(iter, ddof) +} diff --git a/crates/polars-arrow/src/legacy/kernels/time.rs b/crates/polars-arrow/src/legacy/kernels/time.rs new file mode 100644 index 000000000000..65201fd59b2c --- /dev/null +++ b/crates/polars-arrow/src/legacy/kernels/time.rs @@ -0,0 +1,93 @@ +use std::str::FromStr; + +#[cfg(feature = "timezones")] +use chrono::{LocalResult, NaiveDateTime, TimeZone}; +#[cfg(feature = "timezones")] +use chrono_tz::Tz; +#[cfg(feature = "timezones")] +use polars_error::PolarsResult; +use polars_error::{PolarsError, polars_bail}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +pub enum Ambiguous { + Earliest, + Latest, + Null, + Raise, +} +impl FromStr for Ambiguous { + type Err = PolarsError; + + fn from_str(s: &str) -> Result { + match s { + "earliest" => Ok(Ambiguous::Earliest), + "latest" => Ok(Ambiguous::Latest), + "raise" => Ok(Ambiguous::Raise), + "null" => Ok(Ambiguous::Null), + s => polars_bail!(InvalidOperation: + "Invalid argument {}, expected one of: \"earliest\", \"latest\", \"null\", \"raise\"", s + ), + } + } +} + +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum NonExistent { + Null, + Raise, +} + +#[cfg(feature = "timezones")] +pub fn convert_to_naive_local( + from_tz: &Tz, + to_tz: &Tz, + ndt: NaiveDateTime, + ambiguous: Ambiguous, + non_existent: NonExistent, +) -> PolarsResult> { + let ndt = from_tz.from_utc_datetime(&ndt).naive_local(); + match to_tz.from_local_datetime(&ndt) { + LocalResult::Single(dt) => Ok(Some(dt.naive_utc())), + LocalResult::Ambiguous(dt_earliest, dt_latest) => match ambiguous { + Ambiguous::Earliest => Ok(Some(dt_earliest.naive_utc())), + Ambiguous::Latest => Ok(Some(dt_latest.naive_utc())), + Ambiguous::Null => Ok(None), + Ambiguous::Raise => { + polars_bail!(ComputeError: "datetime '{}' is ambiguous in time zone '{}'. Please use `ambiguous` to tell how it should be localized.", ndt, to_tz) + }, + }, + LocalResult::None => match non_existent { + NonExistent::Raise => polars_bail!(ComputeError: + "datetime '{}' is non-existent in time zone '{}'. You may be able to use `non_existent='null'` to return `null` in this case.", + ndt, to_tz + ), + NonExistent::Null => Ok(None), + }, + } +} + +/// Same as convert_to_naive_local, but return `None` instead +/// raising - in some cases this can be used to save a string allocation. +#[cfg(feature = "timezones")] +pub fn convert_to_naive_local_opt( + from_tz: &Tz, + to_tz: &Tz, + ndt: NaiveDateTime, + ambiguous: Ambiguous, +) -> Option> { + let ndt = from_tz.from_utc_datetime(&ndt).naive_local(); + match to_tz.from_local_datetime(&ndt) { + LocalResult::Single(dt) => Some(Some(dt.naive_utc())), + LocalResult::Ambiguous(dt_earliest, dt_latest) => match ambiguous { + Ambiguous::Earliest => Some(Some(dt_earliest.naive_utc())), + Ambiguous::Latest => Some(Some(dt_latest.naive_utc())), + Ambiguous::Null => Some(None), + Ambiguous::Raise => None, + }, + LocalResult::None => None, + } +} diff --git a/crates/polars-arrow/src/legacy/mod.rs b/crates/polars-arrow/src/legacy/mod.rs new file mode 100644 index 000000000000..96815a6226d6 --- /dev/null +++ b/crates/polars-arrow/src/legacy/mod.rs @@ -0,0 +1,11 @@ +pub mod array; +pub mod bit_util; +pub mod conversion; +pub mod error; +pub mod index; +pub mod is_valid; +pub mod kernels; +pub mod prelude; +pub mod time_zone; +pub mod trusted_len; +pub mod utils; diff --git a/crates/polars-arrow/src/legacy/prelude.rs b/crates/polars-arrow/src/legacy/prelude.rs new file mode 100644 index 000000000000..ffc82aacd622 --- /dev/null +++ b/crates/polars-arrow/src/legacy/prelude.rs @@ -0,0 +1,9 @@ +use crate::array::{BinaryArray, ListArray, Utf8Array}; +pub use crate::legacy::array::default_arrays::*; +pub use crate::legacy::array::*; +pub use crate::legacy::index::*; +pub use crate::legacy::kernels::{Ambiguous, NonExistent}; + +pub type LargeStringArray = Utf8Array; +pub type LargeBinaryArray = BinaryArray; +pub type LargeListArray = ListArray; diff --git a/crates/polars-arrow/src/legacy/time_zone.rs b/crates/polars-arrow/src/legacy/time_zone.rs new file mode 100644 index 000000000000..8b8679780191 --- /dev/null +++ b/crates/polars-arrow/src/legacy/time_zone.rs @@ -0,0 +1,6 @@ +// a placeholder type for when timezones are not enabled +#[cfg(not(feature = "timezones"))] +#[derive(Copy, Clone)] +pub enum Tz {} +#[cfg(feature = "timezones")] +pub use chrono_tz::Tz; diff --git a/crates/polars-arrow/src/legacy/trusted_len/boolean.rs b/crates/polars-arrow/src/legacy/trusted_len/boolean.rs new file mode 100644 index 000000000000..f6ac86dafc21 --- /dev/null +++ b/crates/polars-arrow/src/legacy/trusted_len/boolean.rs @@ -0,0 +1,85 @@ +use crate::array::BooleanArray; +use crate::bitmap::utils::set_bit_unchecked; +use crate::bitmap::{BitmapBuilder, MutableBitmap}; +use crate::datatypes::ArrowDataType; +use crate::legacy::array::default_arrays::FromData; +use crate::legacy::trusted_len::FromIteratorReversed; +use crate::legacy::utils::FromTrustedLenIterator; +use crate::trusted_len::TrustedLen; + +impl FromTrustedLenIterator> for BooleanArray { + fn from_iter_trusted_length>>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + // Soundness + // Trait system bounded to TrustedLen + unsafe { BooleanArray::from_trusted_len_iter_unchecked(iter.into_iter()) } + } +} +impl FromTrustedLenIterator for BooleanArray { + fn from_iter_trusted_length>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + // Soundness + // Trait system bounded to TrustedLen + unsafe { + BooleanArray::from_data_default( + BitmapBuilder::from_trusted_len_iter(iter.into_iter()).freeze(), + None, + ) + } + } +} + +impl FromIteratorReversed for BooleanArray { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let mut vals = MutableBitmap::from_len_zeroed(size); + let vals_slice = vals.as_mut_slice(); + unsafe { + let mut offset = size; + iter.for_each(|item| { + offset -= 1; + if item { + set_bit_unchecked(vals_slice, offset, true); + } + }); + } + BooleanArray::new(ArrowDataType::Boolean, vals.into(), None) + } +} + +impl FromIteratorReversed> for BooleanArray { + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let mut vals = MutableBitmap::from_len_zeroed(size); + let mut validity = MutableBitmap::with_capacity(size); + validity.extend_constant(size, true); + let validity_slice = validity.as_mut_slice(); + let vals_slice = vals.as_mut_slice(); + unsafe { + let mut offset = size; + + iter.for_each(|opt_item| { + offset -= 1; + match opt_item { + Some(item) => { + if item { + // Set value (validity bit is already true). + set_bit_unchecked(vals_slice, offset, true); + } + }, + None => { + // Unset validity bit. + set_bit_unchecked(validity_slice, offset, false) + }, + } + }); + } + BooleanArray::new(ArrowDataType::Boolean, vals.into(), Some(validity.into())) + } +} diff --git a/crates/polars-arrow/src/legacy/trusted_len/mod.rs b/crates/polars-arrow/src/legacy/trusted_len/mod.rs new file mode 100644 index 000000000000..9967ecebc594 --- /dev/null +++ b/crates/polars-arrow/src/legacy/trusted_len/mod.rs @@ -0,0 +1,6 @@ +mod boolean; +mod push_unchecked; +mod rev; + +pub use push_unchecked::*; +pub use rev::FromIteratorReversed; diff --git a/crates/polars-arrow/src/legacy/trusted_len/push_unchecked.rs b/crates/polars-arrow/src/legacy/trusted_len/push_unchecked.rs new file mode 100644 index 000000000000..a78cb334fd11 --- /dev/null +++ b/crates/polars-arrow/src/legacy/trusted_len/push_unchecked.rs @@ -0,0 +1,120 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use crate::trusted_len::TrustedLen; + +pub trait TrustedLenPush { + /// Will push an item and not check if there is enough capacity. + /// + /// # Safety + /// Caller must ensure the array has enough capacity to hold `T`. + unsafe fn push_unchecked(&mut self, value: T); + + /// Extend the array with an iterator who's length can be trusted. + fn extend_trusted_len, J: TrustedLen>( + &mut self, + iter: I, + ) { + unsafe { self.extend_trusted_len_unchecked(iter) } + } + + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn extend_trusted_len_unchecked>(&mut self, iter: I); + + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn try_extend_trusted_len_unchecked>>( + &mut self, + iter: I, + ) -> Result<(), E>; + + fn from_trusted_len_iter, J: TrustedLen>( + iter: I, + ) -> Self + where + Self: Sized, + { + unsafe { Self::from_trusted_len_iter_unchecked(iter) } + } + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn from_trusted_len_iter_unchecked>(iter: I) -> Self; + + fn try_from_trusted_len_iter< + E, + I: IntoIterator, IntoIter = J>, + J: TrustedLen, + >( + iter: I, + ) -> Result + where + Self: Sized, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn try_from_trusted_len_iter_unchecked>>( + iter: I, + ) -> Result + where + Self: Sized; +} + +impl TrustedLenPush for Vec { + #[inline(always)] + unsafe fn push_unchecked(&mut self, value: T) { + debug_assert!(self.capacity() > self.len()); + let end = self.as_mut_ptr().add(self.len()); + std::ptr::write(end, value); + self.set_len(self.len() + 1); + } + + #[inline] + unsafe fn extend_trusted_len_unchecked>(&mut self, iter: I) { + let iter = iter.into_iter(); + let upper = iter.size_hint().1.expect("must have an upper bound"); + self.reserve(upper); + + let mut dst = self.as_mut_ptr().add(self.len()); + for value in iter { + std::ptr::write(dst, value); + dst = dst.add(1) + } + self.set_len(self.len() + upper) + } + + unsafe fn try_extend_trusted_len_unchecked>>( + &mut self, + iter: I, + ) -> Result<(), E> { + let iter = iter.into_iter(); + let upper = iter.size_hint().1.expect("must have an upper bound"); + self.reserve(upper); + + let mut dst = self.as_mut_ptr().add(self.len()); + for value in iter { + std::ptr::write(dst, value?); + dst = dst.add(1) + } + self.set_len(self.len() + upper); + Ok(()) + } + + #[inline] + unsafe fn from_trusted_len_iter_unchecked>(iter: I) -> Self { + let mut v = vec![]; + v.extend_trusted_len_unchecked(iter); + v + } + + unsafe fn try_from_trusted_len_iter_unchecked>>( + iter: I, + ) -> Result + where + Self: Sized, + { + let mut v = vec![]; + v.try_extend_trusted_len_unchecked(iter)?; + Ok(v) + } +} diff --git a/crates/polars-arrow/src/legacy/trusted_len/rev.rs b/crates/polars-arrow/src/legacy/trusted_len/rev.rs new file mode 100644 index 000000000000..0677ced9f7df --- /dev/null +++ b/crates/polars-arrow/src/legacy/trusted_len/rev.rs @@ -0,0 +1,5 @@ +use crate::trusted_len::TrustedLen; + +pub trait FromIteratorReversed: Sized { + fn from_trusted_len_iter_rev>(iter: I) -> Self; +} diff --git a/crates/polars-arrow/src/legacy/utils.rs b/crates/polars-arrow/src/legacy/utils.rs new file mode 100644 index 000000000000..767e0361e592 --- /dev/null +++ b/crates/polars-arrow/src/legacy/utils.rs @@ -0,0 +1,139 @@ +use crate::array::PrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::bitmap::utils::set_bit_unchecked; +use crate::datatypes::ArrowDataType; +use crate::legacy::trusted_len::{FromIteratorReversed, TrustedLenPush}; +use crate::trusted_len::{TrustMyLength, TrustedLen}; +use crate::types::NativeType; + +pub trait CustomIterTools: Iterator { + /// Turn any iterator in a trusted length iterator + /// + /// # Safety + /// The given length must be correct. + #[inline] + unsafe fn trust_my_length(self, length: usize) -> TrustMyLength + where + Self: Sized, + { + unsafe { TrustMyLength::new(self, length) } + } + + fn collect_trusted>(self) -> T + where + Self: Sized + TrustedLen, + { + FromTrustedLenIterator::from_iter_trusted_length(self) + } + + fn collect_reversed>(self) -> T + where + Self: Sized + TrustedLen, + { + FromIteratorReversed::from_trusted_len_iter_rev(self) + } +} + +pub trait CustomIterToolsSized: Iterator + Sized {} + +impl CustomIterTools for T where T: Iterator {} + +pub trait FromTrustedLenIterator: Sized { + fn from_iter_trusted_length>(iter: T) -> Self + where + T::IntoIter: TrustedLen; +} + +impl FromTrustedLenIterator for Vec { + fn from_iter_trusted_length>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + let len = iter.size_hint().0; + let mut v = Vec::with_capacity(len); + v.extend_trusted_len(iter); + v + } +} + +impl FromTrustedLenIterator> for PrimitiveArray { + fn from_iter_trusted_length>>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + unsafe { PrimitiveArray::from_trusted_len_iter_unchecked(iter) } + } +} + +impl FromTrustedLenIterator for PrimitiveArray { + fn from_iter_trusted_length>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + unsafe { PrimitiveArray::from_trusted_len_values_iter_unchecked(iter) } + } +} + +impl FromIteratorReversed for Vec { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + unsafe { + let len = iter.size_hint().1.unwrap(); + let mut out: Vec = Vec::with_capacity(len); + let mut idx = len; + for x in iter { + debug_assert!(idx > 0); + idx -= 1; + out.as_mut_ptr().add(idx).write(x); + } + debug_assert!(idx == 0); + out.set_len(len); + out + } + } +} + +impl FromIteratorReversed for PrimitiveArray { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let vals: Vec = iter.collect_reversed(); + PrimitiveArray::new(ArrowDataType::from(T::PRIMITIVE), vals.into(), None) + } +} + +impl FromIteratorReversed> for PrimitiveArray { + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let size = iter.size_hint().1.unwrap(); + + let mut vals: Vec = Vec::with_capacity(size); + let mut validity = MutableBitmap::with_capacity(size); + validity.extend_constant(size, true); + let validity_slice = validity.as_mut_slice(); + unsafe { + // Set to end of buffer. + let mut ptr = vals.as_mut_ptr().add(size); + let mut offset = size; + + iter.for_each(|opt_item| { + offset -= 1; + ptr = ptr.sub(1); + match opt_item { + Some(item) => { + std::ptr::write(ptr, item); + }, + None => { + std::ptr::write(ptr, T::default()); + set_bit_unchecked(validity_slice, offset, false); + }, + } + }); + vals.set_len(size) + } + PrimitiveArray::new( + ArrowDataType::from(T::PRIMITIVE), + vals.into(), + Some(validity.into()), + ) + } +} diff --git a/crates/polars-arrow/src/lib.rs b/crates/polars-arrow/src/lib.rs new file mode 100644 index 000000000000..8c9b5c0d1af5 --- /dev/null +++ b/crates/polars-arrow/src/lib.rs @@ -0,0 +1,46 @@ +// So that we have more control over what is `unsafe` inside an `unsafe` block +#![allow(unused_unsafe)] +// +#![allow(clippy::len_without_is_empty)] +// this landed on 1.60. Let's not force everyone to bump just yet +#![allow(clippy::unnecessary_lazy_evaluations)] +// Trait objects must be returned as a &Box so that they can be cloned +#![allow(clippy::borrowed_box)] +// Allow type complexity warning to avoid API break. +#![allow(clippy::type_complexity)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "nightly", allow(clippy::non_canonical_partial_ord_impl))] // Remove once stable. +#![cfg_attr(feature = "nightly", allow(clippy::blocks_in_conditions))] // Remove once stable. + +extern crate core; + +#[macro_use] +pub mod array; +pub mod bitmap; +pub mod buffer; +#[cfg(feature = "io_ipc")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] +pub mod mmap; +pub mod record_batch; + +pub mod offset; +pub mod scalar; +pub mod storage; +pub mod trusted_len; +pub mod types; + +pub mod compute; +pub mod io; +pub mod temporal_conversions; + +pub mod datatypes; + +pub mod ffi; +pub mod legacy; +pub mod pushable; +pub mod util; + +// re-exported because we return `Either` in our public API +// re-exported to construct dictionaries +pub use either::Either; diff --git a/crates/polars-arrow/src/mmap/array.rs b/crates/polars-arrow/src/mmap/array.rs new file mode 100644 index 000000000000..cec4793e9d42 --- /dev/null +++ b/crates/polars-arrow/src/mmap/array.rs @@ -0,0 +1,624 @@ +use std::collections::VecDeque; +use std::sync::Arc; + +use polars_error::{PolarsResult, polars_bail, polars_err}; + +use crate::array::{Array, DictionaryKey, FixedSizeListArray, ListArray, StructArray, View}; +use crate::datatypes::ArrowDataType; +use crate::ffi::mmap::create_array; +use crate::ffi::{ArrowArray, InternalArrowArray, export_array_to_c, try_from}; +use crate::io::ipc::IpcField; +use crate::io::ipc::read::{Dictionaries, IpcBuffer, Node, OutOfSpecKind}; +use crate::offset::Offset; +use crate::types::NativeType; +use crate::{match_integer_type, with_match_primitive_type_full}; + +fn get_buffer_bounds(buffers: &mut VecDeque) -> PolarsResult<(usize, usize)> { + let buffer = buffers.pop_front().ok_or_else( + || polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::ExpectedBuffer), + )?; + + let offset: usize = buffer.offset().try_into().map_err( + |_| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), + )?; + + let length: usize = buffer.length().try_into().map_err( + |_| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), + )?; + + Ok((offset, length)) +} + +/// Checks that the length of `bytes` is at least `size_of::() * expected_len`, and +/// returns a boolean indicating whether it is aligned. +fn check_bytes_len_and_is_aligned( + bytes: &[u8], + expected_len: usize, +) -> PolarsResult { + if bytes.len() < size_of::() * expected_len { + polars_bail!(ComputeError: "buffer's length is too small in mmap") + }; + + Ok(bytemuck::try_cast_slice::<_, T>(bytes).is_ok()) +} + +fn get_buffer<'a, T: NativeType>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, + num_rows: usize, +) -> PolarsResult<&'a [u8]> { + let (offset, length) = get_buffer_bounds(buffers)?; + + // verify that they are in-bounds + let values = data + .get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| polars_err!(ComputeError: "buffer out of bounds"))?; + + if !check_bytes_len_and_is_aligned::(values, num_rows)? { + polars_bail!(ComputeError: "buffer not aligned for mmap"); + } + + Ok(values) +} + +fn get_bytes<'a>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, +) -> PolarsResult<&'a [u8]> { + let (offset, length) = get_buffer_bounds(buffers)?; + + // verify that they are in-bounds + data.get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| polars_err!(ComputeError: "buffer out of bounds")) +} + +fn get_validity<'a>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, + null_count: usize, +) -> PolarsResult> { + let validity = get_buffer_bounds(buffers)?; + let (offset, length) = validity; + + Ok(if null_count > 0 { + // verify that they are in-bounds and get its pointer + Some( + data.get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| polars_err!(ComputeError: "buffer out of bounds"))?, + ) + } else { + None + }) +} + +fn get_num_rows_and_null_count(node: &Node) -> PolarsResult<(usize, usize)> { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + Ok((num_rows, null_count)) +} + +fn mmap_binary>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> PolarsResult { + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let offsets = get_buffer::(data_ref, block_offset, buffers, num_rows + 1)?.as_ptr(); + let values = get_buffer::(data_ref, block_offset, buffers, 0)?.as_ptr(); + + // NOTE: offsets and values invariants are _not_ validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(offsets), Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_binview>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult { + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let views = get_buffer::(data_ref, block_offset, buffers, num_rows)?; + + let n_variadic = variadic_buffer_counts + .pop_front() + .ok_or_else(|| polars_err!(ComputeError: "expected variadic_buffer_count"))?; + + let mut buffer_ptrs = Vec::with_capacity(n_variadic + 2); + buffer_ptrs.push(validity); + buffer_ptrs.push(Some(views.as_ptr())); + + let mut variadic_buffer_sizes = Vec::with_capacity(n_variadic); + for _ in 0..n_variadic { + let variadic_buffer = get_bytes(data_ref, block_offset, buffers)?; + variadic_buffer_sizes.push(variadic_buffer.len() as i64); + buffer_ptrs.push(Some(variadic_buffer.as_ptr())); + } + buffer_ptrs.push(Some(variadic_buffer_sizes.as_ptr().cast::())); + + // Move variadic buffer sizes in an Arc, so that it stays alive. + let data = Arc::new((data, variadic_buffer_sizes)); + + // NOTE: invariants are not validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + buffer_ptrs.into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_fixed_size_binary>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, + dtype: &ArrowDataType, +) -> PolarsResult { + let bytes_per_row = if let ArrowDataType::FixedSizeBinary(bytes_per_row) = dtype { + bytes_per_row + } else { + polars_bail!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidDataType); + }; + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + let values = + get_buffer::(data_ref, block_offset, buffers, num_rows * bytes_per_row)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_null>( + data: Arc, + node: &Node, + _block_offset: usize, + _buffers: &mut VecDeque, +) -> PolarsResult { + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_boolean>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> PolarsResult { + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer_bounds(buffers)?; + let (offset, length) = values; + + // verify that they are in-bounds and get its pointer + let values = data_ref[block_offset + offset..block_offset + offset + length].as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_primitive>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> PolarsResult { + let data_ref = data.as_ref().as_ref(); + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let bytes = get_bytes(data_ref, block_offset, buffers)?; + let is_aligned = check_bytes_len_and_is_aligned::

(bytes, num_rows)?; + + let out = if is_aligned || size_of::() <= 8 { + assert!( + is_aligned, + "primitive type with size <= 8 bytes should have been aligned" + ); + let bytes_ptr = bytes.as_ptr(); + + unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(bytes_ptr)].into_iter(), + [].into_iter(), + None, + None, + ) + } + } else { + let mut values = vec![P::default(); num_rows]; + unsafe { + std::ptr::copy_nonoverlapping( + bytes.as_ptr(), + values.as_mut_ptr() as *mut u8, + bytes.len(), + ) + }; + // Now we need to keep the new buffer alive + let owned_data = Arc::new(( + // We can drop the original ref if we don't have a validity + validity.and(Some(data)), + values, + )); + let bytes_ptr = owned_data.1.as_ptr() as *mut u8; + + unsafe { + create_array( + owned_data, + num_rows, + null_count, + [validity, Some(bytes_ptr)].into_iter(), + [].into_iter(), + None, + None, + ) + } + }; + + Ok(out) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_list>( + data: Arc, + node: &Node, + block_offset: usize, + dtype: &ArrowDataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult { + let child = ListArray::::try_get_child(dtype)?.dtype(); + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let offsets = get_buffer::(data_ref, block_offset, buffers, num_rows + 1)?.as_ptr(); + + let values = get_array( + data.clone(), + block_offset, + child, + &ipc_field.fields[0], + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + )?; + + // NOTE: offsets and values invariants are _not_ validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(offsets)].into_iter(), + [values].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_fixed_size_list>( + data: Arc, + node: &Node, + block_offset: usize, + dtype: &ArrowDataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult { + let child = FixedSizeListArray::try_child_and_size(dtype)?.0.dtype(); + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_array( + data.clone(), + block_offset, + child, + &ipc_field.fields[0], + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + )?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity].into_iter(), + [values].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_struct>( + data: Arc, + node: &Node, + block_offset: usize, + dtype: &ArrowDataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult { + let children = StructArray::try_get_fields(dtype)?; + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = children + .iter() + .map(|f| &f.dtype) + .zip(ipc_field.fields.iter()) + .map(|(child, ipc)| { + get_array( + data.clone(), + block_offset, + child, + ipc, + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + ) + }) + .collect::>>()?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity].into_iter(), + values.into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_dict>( + data: Arc, + node: &Node, + block_offset: usize, + _: &ArrowDataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + _: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult { + let (num_rows, null_count) = get_num_rows_and_null_count(node)?; + + let data_ref = data.as_ref().as_ref(); + + let dictionary = dictionaries + .get(&ipc_field.dictionary_id.unwrap()) + .ok_or_else(|| polars_err!(ComputeError: "out-of-spec: missing dictionary"))? + .clone(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer::(data_ref, block_offset, buffers, num_rows)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + Some(export_array_to_c(dictionary)), + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn get_array>( + data: Arc, + block_offset: usize, + dtype: &ArrowDataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult { + use crate::datatypes::PhysicalType::*; + let node = field_nodes.pop_front().ok_or_else( + || polars_err!(ComputeError: "out-of-spec: {:?}", OutOfSpecKind::ExpectedBuffer), + )?; + + match dtype.to_physical_type() { + Null => mmap_null(data, &node, block_offset, buffers), + Boolean => mmap_boolean(data, &node, block_offset, buffers), + Primitive(p) => with_match_primitive_type_full!(p, |$T| { + mmap_primitive::<$T, _>(data, &node, block_offset, buffers) + }), + Utf8 | Binary => mmap_binary::(data, &node, block_offset, buffers), + Utf8View | BinaryView => { + mmap_binview(data, &node, block_offset, buffers, variadic_buffer_counts) + }, + FixedSizeBinary => mmap_fixed_size_binary(data, &node, block_offset, buffers, dtype), + LargeBinary | LargeUtf8 => mmap_binary::(data, &node, block_offset, buffers), + List => mmap_list::( + data, + &node, + block_offset, + dtype, + ipc_field, + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + ), + LargeList => mmap_list::( + data, + &node, + block_offset, + dtype, + ipc_field, + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + ), + FixedSizeList => mmap_fixed_size_list( + data, + &node, + block_offset, + dtype, + ipc_field, + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + ), + Struct => mmap_struct( + data, + &node, + block_offset, + dtype, + ipc_field, + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + ), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + mmap_dict::<$T, _>( + data, + &node, + block_offset, + dtype, + ipc_field, + dictionaries, + field_nodes, + buffers, + ) + }), + _ => todo!(), + } +} + +#[allow(clippy::too_many_arguments)] +/// Maps a memory region to an [`Array`]. +pub(crate) unsafe fn mmap>( + data: Arc, + block_offset: usize, + dtype: ArrowDataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, + buffers: &mut VecDeque, +) -> PolarsResult> { + let array = get_array( + data, + block_offset, + &dtype, + ipc_field, + dictionaries, + field_nodes, + variadic_buffer_counts, + buffers, + )?; + // The unsafety comes from the fact that `array` is not necessarily valid - + // the IPC file may be corrupted (e.g. invalid offsets or non-utf8 data) + unsafe { try_from(InternalArrowArray::new(array, dtype)) } +} diff --git a/crates/polars-arrow/src/mmap/mod.rs b/crates/polars-arrow/src/mmap/mod.rs new file mode 100644 index 000000000000..a5614f48e089 --- /dev/null +++ b/crates/polars-arrow/src/mmap/mod.rs @@ -0,0 +1,245 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! Memory maps regions defined on the IPC format into [`Array`]. +use std::collections::VecDeque; +use std::sync::Arc; + +mod array; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, DictionaryBatchRef, MessageRef, RecordBatchRef}; +use polars_error::{PolarsResult, polars_bail, polars_err, to_compute_err}; +use polars_utils::pl_str::PlSmallStr; + +use crate::array::Array; +use crate::datatypes::{ArrowDataType, ArrowSchema, Field}; +use crate::io::ipc::read::file::{get_dictionary_batch, get_record_batch}; +use crate::io::ipc::read::{ + Dictionaries, FileMetadata, IpcBuffer, Node, OutOfSpecKind, first_dict_field, +}; +use crate::io::ipc::{CONTINUATION_MARKER, IpcField}; +use crate::record_batch::RecordBatchT; + +fn read_message( + mut bytes: &[u8], + block: arrow_format::ipc::Block, +) -> PolarsResult<(MessageRef, usize)> { + let offset: usize = block.offset.try_into().map_err( + |_err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), + )?; + + let block_length: usize = block.meta_data_length.try_into().map_err( + |_err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), + )?; + + bytes = &bytes[offset..]; + let mut message_length = bytes[..4].try_into().unwrap(); + bytes = &bytes[4..]; + + if message_length == CONTINUATION_MARKER { + // continuation marker encountered, read message next + message_length = bytes[..4].try_into().unwrap(); + bytes = &bytes[4..]; + }; + + let message_length: usize = i32::from_le_bytes(message_length).try_into().map_err( + |_err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::NegativeFooterLength), + )?; + + let message = arrow_format::ipc::MessageRef::read_as_root(&bytes[..message_length]) + .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + Ok((message, offset + block_length)) +} + +fn get_buffers_nodes(batch: RecordBatchRef) -> PolarsResult<(VecDeque, VecDeque)> { + let compression = batch.compression().map_err(to_compute_err)?; + if compression.is_some() { + polars_bail!(ComputeError: "memory_map can only be done on uncompressed IPC files") + } + + let buffers = batch + .buffers() + .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferBuffers(err)))? + .ok_or_else(|| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::MissingMessageBuffers))?; + let buffers = buffers.iter().collect::>(); + + let field_nodes = batch + .nodes() + .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::MissingMessageNodes))?; + let field_nodes = field_nodes.iter().collect::>(); + + Ok((buffers, field_nodes)) +} + +pub(crate) unsafe fn mmap_record>( + fields: &ArrowSchema, + ipc_fields: &[IpcField], + data: Arc, + batch: RecordBatchRef, + offset: usize, + dictionaries: &Dictionaries, +) -> PolarsResult>> { + let (mut buffers, mut field_nodes) = get_buffers_nodes(batch)?; + let mut variadic_buffer_counts = batch + .variadic_buffer_counts() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .map(|v| v.iter().map(|v| v as usize).collect::>()) + .unwrap_or_else(VecDeque::new); + + let length = batch + .length() + .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData)) + .unwrap() + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + fields + .iter_values() + .map(|f| &f.dtype) + .cloned() + .zip(ipc_fields) + .map(|(dtype, ipc_field)| { + array::mmap( + data.clone(), + offset, + dtype, + ipc_field, + dictionaries, + &mut field_nodes, + &mut variadic_buffer_counts, + &mut buffers, + ) + }) + .collect::>() + .and_then(|arr| { + RecordBatchT::try_new( + length, + Arc::new(fields.iter_values().cloned().collect()), + arr, + ) + }) +} + +/// Memory maps an record batch from an IPC file into a [`RecordBatchT`]. +/// # Errors +/// This function errors when: +/// * The IPC file is not valid +/// * the buffers on the file are un-aligned with their corresponding data. This can happen when: +/// * the file was written with 8-bit alignment +/// * the file contains type decimal 128 or 256 +/// # Safety +/// The caller must ensure that `data` contains a valid buffers, for example: +/// * Offsets in variable-sized containers must be in-bounds and increasing +/// * Utf8 data is valid +pub unsafe fn mmap_unchecked>( + metadata: &FileMetadata, + dictionaries: &Dictionaries, + data: Arc, + chunk: usize, +) -> PolarsResult>> { + let block = metadata.blocks[chunk]; + + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_record_batch(message)?; + mmap_record( + &metadata.schema, + &metadata.ipc_schema.fields, + data.clone(), + batch, + offset, + dictionaries, + ) +} + +unsafe fn mmap_dictionary>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + data: Arc, + block: Block, + dictionaries: &mut Dictionaries, +) -> PolarsResult<()> { + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_dictionary_batch(&message)?; + mmap_dictionary_from_batch(schema, ipc_fields, &data, batch, dictionaries, offset) +} + +pub(crate) unsafe fn mmap_dictionary_from_batch>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + data: &Arc, + batch: DictionaryBatchRef, + dictionaries: &mut Dictionaries, + offset: usize, +) -> PolarsResult<()> { + let id = batch + .id() + .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferId(err)))?; + let (first_field, first_ipc_field) = first_dict_field(id, schema, ipc_fields)?; + + let batch = batch + .data() + .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::MissingData))?; + + let value_type = if let ArrowDataType::Dictionary(_, value_type, _) = + first_field.dtype.to_logical_type() + { + value_type.as_ref() + } else { + polars_bail!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidIdDataType {requested_id: id} ) + }; + + // Make a fake schema for the dictionary batch. + let field = Field::new(PlSmallStr::EMPTY, value_type.clone(), false); + + let chunk = mmap_record( + &std::iter::once((field.name.clone(), field)).collect(), + &[first_ipc_field.clone()], + data.clone(), + batch, + offset, + dictionaries, + )?; + + dictionaries.insert(id, chunk.into_arrays().pop().unwrap()); + + Ok(()) +} + +/// Memory maps dictionaries from an IPC file into +/// # Safety +/// The caller must ensure that `data` contains a valid buffers, for example: +/// * Offsets in variable-sized containers must be in-bounds and increasing +/// * Utf8 data is valid +pub unsafe fn mmap_dictionaries_unchecked>( + metadata: &FileMetadata, + data: Arc, +) -> PolarsResult { + mmap_dictionaries_unchecked2( + metadata.schema.as_ref(), + &metadata.ipc_schema.fields, + metadata.dictionaries.as_ref(), + data, + ) +} + +pub(crate) unsafe fn mmap_dictionaries_unchecked2>( + schema: &ArrowSchema, + ipc_fields: &[IpcField], + dictionaries: Option<&Vec>, + data: Arc, +) -> PolarsResult { + let blocks = if let Some(blocks) = &dictionaries { + blocks + } else { + return Ok(Default::default()); + }; + + let mut dictionaries = Default::default(); + + blocks.iter().cloned().try_for_each(|block| { + mmap_dictionary(schema, ipc_fields, data.clone(), block, &mut dictionaries) + })?; + Ok(dictionaries) +} diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs new file mode 100644 index 000000000000..3781a997a5a9 --- /dev/null +++ b/crates/polars-arrow/src/offset.rs @@ -0,0 +1,662 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! Contains the declaration of [`Offset`] +use std::hint::unreachable_unchecked; +use std::ops::Deref; + +use polars_error::{PolarsError, PolarsResult, polars_bail, polars_err}; + +use crate::array::Splitable; +use crate::buffer::Buffer; +pub use crate::types::Offset; + +/// A wrapper type of [`Vec`] representing the invariants of Arrow's offsets. +/// It is guaranteed to (sound to assume that): +/// * every element is `>= 0` +/// * element at position `i` is >= than element at position `i-1`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Offsets(Vec); + +impl Default for Offsets { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Deref for Offsets { + type Target = [O]; + + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +impl TryFrom> for Offsets { + type Error = PolarsError; + + #[inline] + fn try_from(offsets: Vec) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets)) + } +} + +impl TryFrom> for OffsetsBuffer { + type Error = PolarsError; + + #[inline] + fn try_from(offsets: Buffer) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets)) + } +} + +impl TryFrom> for OffsetsBuffer { + type Error = PolarsError; + + #[inline] + fn try_from(offsets: Vec) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets.into())) + } +} + +impl From> for OffsetsBuffer { + #[inline] + fn from(offsets: Offsets) -> Self { + Self(offsets.0.into()) + } +} + +impl Offsets { + /// Returns an empty [`Offsets`] (i.e. with a single element, the zero) + #[inline] + pub fn new() -> Self { + Self(vec![O::zero()]) + } + + /// Returns an [`Offsets`] whose all lengths are zero. + #[inline] + pub fn new_zeroed(length: usize) -> Self { + Self(vec![O::zero(); length + 1]) + } + + /// Creates a new [`Offsets`] from an iterator of lengths + #[inline] + pub fn try_from_iter>(iter: I) -> PolarsResult { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut offsets = Self::with_capacity(lower); + for item in iterator { + offsets.try_push(item)? + } + Ok(offsets) + } + + /// Returns a new [`Offsets`] with a capacity, allocating at least `capacity + 1` entries. + pub fn with_capacity(capacity: usize) -> Self { + let mut offsets = Vec::with_capacity(capacity + 1); + offsets.push(O::zero()); + Self(offsets) + } + + /// Returns the capacity of [`Offsets`]. + pub fn capacity(&self) -> usize { + self.0.capacity() - 1 + } + + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + self.0.reserve(additional); + } + + /// Shrinks the capacity of self to fit. + pub fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit(); + } + + /// Pushes a new element with a given length. + /// # Error + /// This function errors iff the new last item is larger than what `O` supports. + /// # Implementation + /// This function: + /// * checks that this length does not overflow + #[inline] + pub fn try_push(&mut self, length: usize) -> PolarsResult<()> { + if O::IS_LARGE { + let length = O::from_as_usize(length); + let old_length = self.last(); + let new_length = *old_length + length; + self.0.push(new_length); + Ok(()) + } else { + let length = + O::from_usize(length).ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + + let old_length = self.last(); + let new_length = old_length + .checked_add(&length) + .ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + self.0.push(new_length); + Ok(()) + } + } + + /// Returns [`Offsets`] assuming that `offsets` fulfills its invariants + /// + /// # Safety + /// This is safe iff the invariants of this struct are guaranteed in `offsets`. + #[inline] + pub unsafe fn new_unchecked(offsets: Vec) -> Self { + #[cfg(debug_assertions)] + { + let mut prev_offset = O::default(); + let mut is_monotonely_increasing = true; + for offset in &offsets { + is_monotonely_increasing &= *offset >= prev_offset; + prev_offset = *offset; + } + assert!( + is_monotonely_increasing, + "Unsafe precondition violated. Invariant of offsets broken." + ); + } + + Self(offsets) + } + + /// Returns the last offset of this container. + #[inline] + pub fn last(&self) -> &O { + match self.0.last() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns a `length` corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len_proxy()` + #[inline] + pub fn length_at(&self, index: usize) -> usize { + let (start, end) = self.start_end(index); + end - start + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len_proxy()` + #[inline] + pub fn start_end(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + assert!(index < self.len_proxy()); + unsafe { self.start_end_unchecked(index) } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// + /// # Safety + /// `index` must be `< self.len()` + #[inline] + pub unsafe fn start_end_unchecked(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + let start = self.0.get_unchecked(index).to_usize(); + let end = self.0.get_unchecked(index + 1).to_usize(); + (start, end) + } + + /// Returns the length an array with these offsets would be. + #[inline] + pub fn len_proxy(&self) -> usize { + self.0.len() - 1 + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[O] { + self.0.as_slice() + } + + /// Pops the last element + #[inline] + pub fn pop(&mut self) -> Option { + if self.len_proxy() == 0 { + None + } else { + self.0.pop() + } + } + + /// Extends itself with `additional` elements equal to the last offset. + /// This is useful to extend offsets with empty values, e.g. for null slots. + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + let offset = *self.last(); + if additional == 1 { + self.0.push(offset) + } else { + self.0.resize(self.0.len() + additional, offset) + } + } + + /// Try to create a new [`Offsets`] from a sequence of `lengths` + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + #[inline] + pub fn try_from_lengths>(lengths: I) -> PolarsResult { + let mut self_ = Self::with_capacity(lengths.size_hint().0); + self_.try_extend_from_lengths(lengths)?; + Ok(self_) + } + + /// Try extend from an iterator of lengths + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + #[inline] + pub fn try_extend_from_lengths>( + &mut self, + lengths: I, + ) -> PolarsResult<()> { + let mut total_length = 0; + let mut offset = *self.last(); + let original_offset = offset.to_usize(); + + let lengths = lengths.map(|length| { + total_length += length; + O::from_as_usize(length) + }); + + let offsets = lengths.map(|length| { + offset += length; // this may overflow, checked below + offset + }); + self.0.extend(offsets); + + let last_offset = original_offset + .checked_add(total_length) + .ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + O::from_usize(last_offset).ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + Ok(()) + } + + /// Extends itself from another [`Offsets`] + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + pub fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { + let mut length = *self.last(); + let other_length = *other.last(); + // check if the operation would overflow + length + .checked_add(&other_length) + .ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + + let lengths = other.as_slice().windows(2).map(|w| w[1] - w[0]); + let offsets = lengths.map(|new_length| { + length += new_length; + length + }); + self.0.extend(offsets); + Ok(()) + } + + /// Extends itself from another [`Offsets`] sliced by `start, length` + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + pub fn try_extend_from_slice( + &mut self, + other: &OffsetsBuffer, + start: usize, + length: usize, + ) -> PolarsResult<()> { + if length == 0 { + return Ok(()); + } + let other = &other.0[start..start + length + 1]; + let other_length = other.last().expect("Length to be non-zero"); + let mut length = *self.last(); + // check if the operation would overflow + length + .checked_add(other_length) + .ok_or_else(|| polars_err!(ComputeError: "overflow"))?; + + let lengths = other.windows(2).map(|w| w[1] - w[0]); + let offsets = lengths.map(|new_length| { + length += new_length; + length + }); + self.0.extend(offsets); + Ok(()) + } + + /// Returns the inner [`Vec`]. + #[inline] + pub fn into_inner(self) -> Vec { + self.0 + } +} + +/// Checks that `offsets` is monotonically increasing. +fn try_check_offsets(offsets: &[O]) -> PolarsResult<()> { + // this code is carefully constructed to auto-vectorize, don't change naively! + match offsets.first() { + None => polars_bail!(ComputeError: "offsets must have at least one element"), + Some(first) => { + if *first < O::zero() { + polars_bail!(ComputeError: "offsets must be larger than 0") + } + let mut previous = *first; + let mut any_invalid = false; + + // This loop will auto-vectorize because there is not any break, + // an invalid value will be returned once the whole offsets buffer is processed. + for offset in offsets { + if previous > *offset { + any_invalid = true + } + previous = *offset; + } + + if any_invalid { + polars_bail!(ComputeError: "offsets must be monotonically increasing") + } else { + Ok(()) + } + }, + } +} + +/// A wrapper type of [`Buffer`] that is guaranteed to: +/// * Always contain an element +/// * Every element is `>= 0` +/// * element at position `i` is >= than element at position `i-1`. +#[derive(Clone, PartialEq, Debug)] +pub struct OffsetsBuffer(Buffer); + +impl Default for OffsetsBuffer { + #[inline] + fn default() -> Self { + Self(vec![O::zero()].into()) + } +} + +impl OffsetsBuffer { + /// # Safety + /// This is safe iff the invariants of this struct are guaranteed in `offsets`. + #[inline] + pub unsafe fn new_unchecked(offsets: Buffer) -> Self { + Self(offsets) + } + + /// Returns an empty [`OffsetsBuffer`] (i.e. with a single element, the zero) + #[inline] + pub fn new() -> Self { + Self(vec![O::zero()].into()) + } + + #[inline] + pub fn one_with_length(length: O) -> Self { + Self(vec![O::zero(), length].into()) + } + + /// Copy-on-write API to convert [`OffsetsBuffer`] into [`Offsets`]. + #[inline] + pub fn into_mut(self) -> either::Either> { + self.0 + .into_mut() + // SAFETY: Offsets and OffsetsBuffer share invariants + .map_right(|offsets| unsafe { Offsets::new_unchecked(offsets) }) + .map_left(Self) + } + + /// Returns a reference to its internal [`Buffer`]. + #[inline] + pub fn buffer(&self) -> &Buffer { + &self.0 + } + + /// Returns what the length an array with these offsets would be. + #[inline] + pub fn len_proxy(&self) -> usize { + self.0.len() - 1 + } + + /// Returns the number of offsets in this container. + #[inline] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[O] { + self.0.as_slice() + } + + /// Returns the range of the offsets. + #[inline] + pub fn range(&self) -> O { + *self.last() - *self.first() + } + + /// Returns the first offset. + #[inline] + pub fn first(&self) -> &O { + match self.0.first() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns the last offset. + #[inline] + pub fn last(&self) -> &O { + match self.0.last() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns a `length` corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len_proxy()` + #[inline] + pub fn length_at(&self, index: usize) -> usize { + let (start, end) = self.start_end(index); + end - start + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len_proxy()` + #[inline] + pub fn start_end(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + assert!(index < self.len_proxy()); + unsafe { self.start_end_unchecked(index) } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// + /// # Safety + /// `index` must be `< self.len()` + #[inline] + pub unsafe fn start_end_unchecked(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + let start = self.0.get_unchecked(index).to_usize(); + let end = self.0.get_unchecked(index + 1).to_usize(); + (start, end) + } + + /// Slices this [`OffsetsBuffer`]. + /// # Panics + /// Panics if `offset + length` is larger than `len` + /// or `length == 0`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!(length > 0); + self.0.slice(offset, length); + } + + /// Slices this [`OffsetsBuffer`] starting at `offset`. + /// + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.0.slice_unchecked(offset, length); + } + + /// Returns an iterator with the lengths of the offsets + #[inline] + pub fn lengths(&self) -> impl ExactSizeIterator + '_ { + self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) + } + + /// Returns `(offset, len)` pairs. + #[inline] + pub fn offset_and_length_iter(&self) -> impl ExactSizeIterator + '_ { + self.windows(2).map(|x| { + let [l, r] = x else { unreachable!() }; + let l = l.to_usize(); + let r = r.to_usize(); + (l, r - l) + }) + } + + /// Offset and length of the primitive (leaf) array for a double+ nested list for every outer + /// row. + pub fn leaf_ranges_iter( + offsets: &[Self], + ) -> impl Iterator> + '_ { + let others = &offsets[1..]; + + offsets[0].windows(2).map(move |x| { + let [l, r] = x else { unreachable!() }; + let mut l = l.to_usize(); + let mut r = r.to_usize(); + + for o in others { + let slc = o.as_slice(); + l = slc[l].to_usize(); + r = slc[r].to_usize(); + } + + l..r + }) + } + + /// Return the full range of the leaf array used by the list. + pub fn leaf_full_start_end(offsets: &[Self]) -> core::ops::Range { + let mut l = offsets[0].first().to_usize(); + let mut r = offsets[0].last().to_usize(); + + for o in &offsets[1..] { + let slc = o.as_slice(); + l = slc[l].to_usize(); + r = slc[r].to_usize(); + } + + l..r + } + + /// Returns the inner [`Buffer`]. + #[inline] + pub fn into_inner(self) -> Buffer { + self.0 + } + + /// Returns the offset difference between `start` and `end`. + #[inline] + pub fn delta(&self, start: usize, end: usize) -> usize { + assert!(start <= end); + + (self.0[end + 1] - self.0[start]).to_usize() + } +} + +impl From<&OffsetsBuffer> for OffsetsBuffer { + fn from(offsets: &OffsetsBuffer) -> Self { + // this conversion is lossless and uphelds all invariants + Self( + offsets + .buffer() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(), + ) + } +} + +impl TryFrom<&OffsetsBuffer> for OffsetsBuffer { + type Error = PolarsError; + + fn try_from(offsets: &OffsetsBuffer) -> Result { + i32::try_from(*offsets.last()).map_err(|_| polars_err!(ComputeError: "overflow"))?; + + // this conversion is lossless and uphelds all invariants + Ok(Self( + offsets + .buffer() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(), + )) + } +} + +impl From> for Offsets { + fn from(offsets: Offsets) -> Self { + // this conversion is lossless and uphelds all invariants + Self( + offsets + .as_slice() + .iter() + .map(|x| *x as i64) + .collect::>(), + ) + } +} + +impl TryFrom> for Offsets { + type Error = PolarsError; + + fn try_from(offsets: Offsets) -> Result { + i32::try_from(*offsets.last()).map_err(|_| polars_err!(ComputeError: "overflow"))?; + + // this conversion is lossless and uphelds all invariants + Ok(Self( + offsets + .as_slice() + .iter() + .map(|x| *x as i32) + .collect::>(), + )) + } +} + +impl std::ops::Deref for OffsetsBuffer { + type Target = [O]; + + #[inline] + fn deref(&self) -> &[O] { + self.0.as_slice() + } +} + +impl Splitable for OffsetsBuffer { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len_proxy() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let mut lhs = self.0.clone(); + let mut rhs = self.0.clone(); + + lhs.slice(0, offset + 1); + rhs.slice(offset, self.0.len() - offset); + + (Self(lhs), Self(rhs)) + } +} diff --git a/crates/polars-arrow/src/pushable.rs b/crates/polars-arrow/src/pushable.rs new file mode 100644 index 000000000000..29464b8df679 --- /dev/null +++ b/crates/polars-arrow/src/pushable.rs @@ -0,0 +1,333 @@ +use crate::array::{ + BinaryViewArrayGeneric, BooleanArray, MutableBinaryViewArray, MutableBooleanArray, + MutablePrimitiveArray, PrimitiveArray, ViewType, +}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::types::NativeType; + +/// A private trait representing structs that can receive elements. +pub trait Pushable: Sized + Default { + type Freeze; + + fn with_capacity(capacity: usize) -> Self { + let mut new = Self::default(); + new.reserve(capacity); + new + } + fn reserve(&mut self, additional: usize); + fn push(&mut self, value: T); + fn len(&self) -> usize; + fn push_null(&mut self); + #[inline] + fn extend_n(&mut self, n: usize, iter: impl Iterator) { + for item in iter.take(n) { + self.push(item); + } + } + fn extend_constant(&mut self, additional: usize, value: T); + fn extend_null_constant(&mut self, additional: usize); + fn freeze(self) -> Self::Freeze; +} + +impl Pushable for MutableBitmap { + type Freeze = Bitmap; + + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBitmap::reserve(self, additional) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push(&mut self, value: bool) { + self.push(value) + } + + #[inline] + fn push_null(&mut self) { + self.push(false) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: bool) { + self.extend_constant(additional, value) + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional, false) + } + + fn freeze(self) -> Self::Freeze { + self.into() + } +} + +impl Pushable for Vec { + type Freeze = Vec; + #[inline] + fn reserve(&mut self, additional: usize) { + Vec::reserve(self, additional) + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push_null(&mut self) { + self.push(T::default()) + } + + #[inline] + fn push(&mut self, value: T) { + self.push(value) + } + + #[inline] + fn extend_n(&mut self, n: usize, iter: impl Iterator) { + self.extend(iter.take(n)); + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: T) { + self.resize(self.len() + additional, value); + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional, T::default()) + } + fn freeze(self) -> Self::Freeze { + self + } +} +impl Pushable for Offsets { + type Freeze = OffsetsBuffer; + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + #[inline] + fn len(&self) -> usize { + self.len_proxy() + } + + #[inline] + fn push(&mut self, value: usize) { + self.try_push(value).unwrap() + } + + #[inline] + fn push_null(&mut self) { + self.extend_constant(1); + } + + #[inline] + fn extend_constant(&mut self, additional: usize, _: usize) { + self.extend_constant(additional) + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } + fn freeze(self) -> Self::Freeze { + self.into() + } +} + +impl Pushable> for MutablePrimitiveArray { + type Freeze = PrimitiveArray; + + #[inline] + fn reserve(&mut self, additional: usize) { + MutablePrimitiveArray::reserve(self, additional) + } + + #[inline] + fn push(&mut self, value: Option) { + MutablePrimitiveArray::push(self, value) + } + + #[inline] + fn len(&self) -> usize { + self.values().len() + } + + #[inline] + fn push_null(&mut self) { + self.push(None) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: Option) { + MutablePrimitiveArray::extend_constant(self, additional, value) + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + MutablePrimitiveArray::extend_constant(self, additional, None) + } + fn freeze(self) -> Self::Freeze { + self.into() + } +} + +pub trait NoOption {} +impl NoOption for &str {} +impl NoOption for &[u8] {} +impl NoOption for Vec {} + +impl Pushable for MutableBinaryViewArray +where + T: AsRef + NoOption, + K: ViewType + ?Sized, +{ + type Freeze = BinaryViewArrayGeneric; + + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBinaryViewArray::reserve(self, additional) + } + + #[inline] + fn push(&mut self, value: T) { + MutableBinaryViewArray::push_value(self, value) + } + + #[inline] + fn len(&self) -> usize { + MutableBinaryViewArray::len(self) + } + + fn push_null(&mut self) { + MutableBinaryViewArray::push_null(self) + } + + fn extend_constant(&mut self, additional: usize, value: T) { + MutableBinaryViewArray::extend_constant(self, additional, Some(value)); + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_null(additional); + } + fn freeze(self) -> Self::Freeze { + self.into() + } +} + +impl Pushable> for MutableBinaryViewArray +where + T: AsRef, + K: ViewType + ?Sized, +{ + type Freeze = BinaryViewArrayGeneric; + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBinaryViewArray::reserve(self, additional) + } + + #[inline] + fn push(&mut self, value: Option) { + MutableBinaryViewArray::push(self, value.as_ref()) + } + + #[inline] + fn len(&self) -> usize { + MutableBinaryViewArray::len(self) + } + + fn push_null(&mut self) { + MutableBinaryViewArray::push_null(self) + } + + fn extend_constant(&mut self, additional: usize, value: Option) { + MutableBinaryViewArray::extend_constant(self, additional, value); + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_null(additional); + } + fn freeze(self) -> Self::Freeze { + self.into() + } +} + +impl Pushable for MutableBooleanArray { + type Freeze = BooleanArray; + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBooleanArray::reserve(self, additional) + } + + #[inline] + fn push(&mut self, value: bool) { + MutableBooleanArray::push_value(self, value) + } + + #[inline] + fn len(&self) -> usize { + self.values().len() + } + + #[inline] + fn push_null(&mut self) { + unimplemented!() + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: bool) { + MutableBooleanArray::extend_constant(self, additional, Some(value)) + } + + #[inline] + fn extend_null_constant(&mut self, _additional: usize) { + unimplemented!() + } + fn freeze(self) -> Self::Freeze { + self.into() + } +} + +impl Pushable> for MutableBooleanArray { + type Freeze = BooleanArray; + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBooleanArray::reserve(self, additional) + } + + #[inline] + fn push(&mut self, value: Option) { + MutableBooleanArray::push(self, value) + } + + #[inline] + fn len(&self) -> usize { + self.values().len() + } + + #[inline] + fn push_null(&mut self) { + MutableBooleanArray::push_null(self) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: Option) { + MutableBooleanArray::extend_constant(self, additional, value) + } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + MutableBooleanArray::extend_constant(self, additional, None) + } + fn freeze(self) -> Self::Freeze { + self.into() + } +} diff --git a/crates/polars-arrow/src/record_batch.rs b/crates/polars-arrow/src/record_batch.rs new file mode 100644 index 000000000000..a96e778a8bc6 --- /dev/null +++ b/crates/polars-arrow/src/record_batch.rs @@ -0,0 +1,113 @@ +//! Contains [`RecordBatchT`], a container of [`Array`] where every array has the +//! same length. + +use polars_error::{PolarsResult, polars_ensure}; + +use crate::array::{Array, ArrayRef}; +use crate::datatypes::{ArrowSchema, ArrowSchemaRef}; + +/// A vector of trait objects of [`Array`] where every item has +/// the same length, [`RecordBatchT::len`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RecordBatchT> { + height: usize, + schema: ArrowSchemaRef, + arrays: Vec, +} + +pub type RecordBatch = RecordBatchT; + +impl> RecordBatchT { + /// Creates a new [`RecordBatchT`]. + /// + /// # Panics + /// + /// I.f.f. the length does not match the length of any of the arrays + pub fn new(length: usize, schema: ArrowSchemaRef, arrays: Vec) -> Self { + Self::try_new(length, schema, arrays).unwrap() + } + + /// Creates a new [`RecordBatchT`]. + /// + /// # Error + /// + /// I.f.f. the height does not match the length of any of the arrays + pub fn try_new(height: usize, schema: ArrowSchemaRef, arrays: Vec) -> PolarsResult { + polars_ensure!( + schema.len() == arrays.len(), + ComputeError: "RecordBatch requires an equal number of fields and arrays", + ); + polars_ensure!( + arrays.iter().all(|arr| arr.as_ref().len() == height), + ComputeError: "RecordBatch requires all its arrays to have an equal number of rows", + ); + + Ok(Self { + height, + schema, + arrays, + }) + } + + /// returns the [`Array`]s in [`RecordBatchT`] + pub fn arrays(&self) -> &[A] { + &self.arrays + } + + /// returns the [`ArrowSchema`]s in [`RecordBatchT`] + pub fn schema(&self) -> &ArrowSchema { + &self.schema + } + + /// returns the [`Array`]s in [`RecordBatchT`] + pub fn columns(&self) -> &[A] { + &self.arrays + } + + /// returns the number of rows of every array + pub fn len(&self) -> usize { + self.height + } + + /// returns the number of rows of every array + pub fn height(&self) -> usize { + self.height + } + + /// returns the number of arrays + pub fn width(&self) -> usize { + self.arrays.len() + } + + /// returns whether the columns have any rows + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Consumes [`RecordBatchT`] into its underlying arrays. + /// The arrays are guaranteed to have the same length + pub fn into_arrays(self) -> Vec { + self.arrays + } + + /// Consumes [`RecordBatchT`] into its underlying schema and arrays. + /// The arrays are guaranteed to have the same length + pub fn into_schema_and_arrays(self) -> (ArrowSchemaRef, Vec) { + (self.schema, self.arrays) + } +} + +impl> From> for Vec { + fn from(c: RecordBatchT) -> Self { + c.into_arrays() + } +} + +impl> std::ops::Deref for RecordBatchT { + type Target = [A]; + + #[inline] + fn deref(&self) -> &[A] { + self.arrays() + } +} diff --git a/crates/polars-arrow/src/scalar/README.md b/crates/polars-arrow/src/scalar/README.md new file mode 100644 index 000000000000..9ff93cd86df7 --- /dev/null +++ b/crates/polars-arrow/src/scalar/README.md @@ -0,0 +1,29 @@ +# Scalar API + +Design choices: + +### `Scalar` is trait object + +There are three reasons: + +- a scalar should have a small memory footprint, which an enum would not ensure given the different + physical types available. +- forward-compatibility: a new entry on an `enum` is backward-incompatible +- do not expose implementation details to users (reduce the surface of the public API) + +### `Scalar` MUST contain nullability information + +This is to be aligned with the general notion of arrow's `Array`. + +This API is a companion to the `Array`, and follows the same design as `Array`. Specifically, a +`Scalar` is a trait object that can be downcasted to concrete implementations. + +Like `Array`, `Scalar` implements + +- `dtype`, which is used to perform the correct downcast +- `is_valid`, to tell whether the scalar is null or not + +### There is one implementation per arrows' physical type + +- Reduces the number of `match` that users need to write +- Allows casting of logical types without changing the underlying physical type diff --git a/crates/polars-arrow/src/scalar/binary.rs b/crates/polars-arrow/src/scalar/binary.rs new file mode 100644 index 000000000000..f758cf021b1c --- /dev/null +++ b/crates/polars-arrow/src/scalar/binary.rs @@ -0,0 +1,55 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; +use crate::offset::Offset; + +/// The [`Scalar`] implementation of binary ([`Option>`]). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BinaryScalar { + value: Option>, + phantom: std::marker::PhantomData, +} + +impl BinaryScalar { + /// Returns a new [`BinaryScalar`]. + #[inline] + pub fn new>>(value: Option

) -> Self { + Self { + value: value.map(|x| x.into()), + phantom: std::marker::PhantomData, + } + } + + /// Its value + #[inline] + pub fn value(&self) -> Option<&[u8]> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl>> From> for BinaryScalar { + #[inline] + fn from(v: Option

) -> Self { + Self::new(v) + } +} + +impl Scalar for BinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + if O::IS_LARGE { + &ArrowDataType::LargeBinary + } else { + &ArrowDataType::Binary + } + } +} diff --git a/crates/polars-arrow/src/scalar/binview.rs b/crates/polars-arrow/src/scalar/binview.rs new file mode 100644 index 000000000000..958037041623 --- /dev/null +++ b/crates/polars-arrow/src/scalar/binview.rs @@ -0,0 +1,72 @@ +use std::fmt::{Debug, Formatter}; + +use super::Scalar; +use crate::array::ViewType; +use crate::datatypes::ArrowDataType; + +/// The implementation of [`Scalar`] for utf8, semantically equivalent to [`Option`]. +#[derive(PartialEq, Eq)] +pub struct BinaryViewScalar { + value: Option, + phantom: std::marker::PhantomData, +} + +impl Debug for BinaryViewScalar { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Scalar({:?})", self.value) + } +} + +impl Clone for BinaryViewScalar { + fn clone(&self) -> Self { + Self { + value: self.value.clone(), + phantom: Default::default(), + } + } +} + +impl BinaryViewScalar { + /// Returns a new [`BinaryViewScalar`] + #[inline] + pub fn new(value: Option<&T>) -> Self { + Self { + value: value.map(|x| x.into_owned()), + phantom: std::marker::PhantomData, + } + } + + /// Returns the value irrespectively of the validity. + #[inline] + pub fn value(&self) -> Option<&T> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl From> for BinaryViewScalar { + #[inline] + fn from(v: Option<&T>) -> Self { + Self::new(v) + } +} + +impl Scalar for BinaryViewScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + if T::IS_UTF8 { + &ArrowDataType::Utf8View + } else { + &ArrowDataType::BinaryView + } + } +} diff --git a/crates/polars-arrow/src/scalar/boolean.rs b/crates/polars-arrow/src/scalar/boolean.rs new file mode 100644 index 000000000000..44158d8c3636 --- /dev/null +++ b/crates/polars-arrow/src/scalar/boolean.rs @@ -0,0 +1,46 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; + +/// The [`Scalar`] implementation of a boolean. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BooleanScalar { + value: Option, +} + +impl BooleanScalar { + /// Returns a new [`BooleanScalar`] + #[inline] + pub fn new(value: Option) -> Self { + Self { value } + } + + /// The value + #[inline] + pub fn value(&self) -> Option { + self.value + } +} + +impl Scalar for BooleanScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::Boolean + } +} + +impl From> for BooleanScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(v) + } +} diff --git a/crates/polars-arrow/src/scalar/dictionary.rs b/crates/polars-arrow/src/scalar/dictionary.rs new file mode 100644 index 000000000000..b92a99355559 --- /dev/null +++ b/crates/polars-arrow/src/scalar/dictionary.rs @@ -0,0 +1,54 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::ArrowDataType; + +/// The [`DictionaryArray`] equivalent of [`Array`] for [`Scalar`]. +#[derive(Debug, Clone)] +pub struct DictionaryScalar { + value: Option>, + phantom: std::marker::PhantomData, + dtype: ArrowDataType, +} + +impl PartialEq for DictionaryScalar { + fn eq(&self, other: &Self) -> bool { + (self.dtype == other.dtype) && (self.value.as_ref() == other.value.as_ref()) + } +} + +impl DictionaryScalar { + /// returns a new [`DictionaryScalar`] + /// # Panics + /// iff + /// * the `dtype` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `dtype` is not equal to the `values` + #[inline] + pub fn new(dtype: ArrowDataType, value: Option>) -> Self { + Self { + value, + phantom: std::marker::PhantomData, + dtype, + } + } + + /// The values of the [`DictionaryScalar`] + pub fn value(&self) -> Option<&Box> { + self.value.as_ref() + } +} + +impl Scalar for DictionaryScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.value.is_some() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/equal.rs b/crates/polars-arrow/src/scalar/equal.rs new file mode 100644 index 000000000000..3978765fe73a --- /dev/null +++ b/crates/polars-arrow/src/scalar/equal.rs @@ -0,0 +1,58 @@ +use std::sync::Arc; + +use super::*; +use crate::{match_integer_type, with_match_primitive_type_full}; + +impl PartialEq for dyn Scalar + '_ { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(self, that) + } +} + +impl PartialEq for Arc { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for Box { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(&**self, that) + } +} + +macro_rules! dyn_eq { + ($ty:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap(); + let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap(); + lhs == rhs + }}; +} + +fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { + if lhs.dtype() != rhs.dtype() { + return false; + } + + use PhysicalType::*; + match lhs.dtype().to_physical_type() { + Null => dyn_eq!(NullScalar, lhs, rhs), + Boolean => dyn_eq!(BooleanScalar, lhs, rhs), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + dyn_eq!(PrimitiveScalar<$T>, lhs, rhs) + }), + LargeUtf8 => dyn_eq!(Utf8Scalar, lhs, rhs), + LargeBinary => dyn_eq!(BinaryScalar, lhs, rhs), + LargeList => dyn_eq!(ListScalar, lhs, rhs), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + dyn_eq!(DictionaryScalar<$T>, lhs, rhs) + }), + Struct => dyn_eq!(StructScalar, lhs, rhs), + FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs), + FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), + Union => dyn_eq!(UnionScalar, lhs, rhs), + Map => dyn_eq!(MapScalar, lhs, rhs), + Utf8View => dyn_eq!(BinaryViewScalar, lhs, rhs), + _ => unimplemented!(), + } +} diff --git a/crates/polars-arrow/src/scalar/fixed_size_binary.rs b/crates/polars-arrow/src/scalar/fixed_size_binary.rs new file mode 100644 index 000000000000..a14d2886d75d --- /dev/null +++ b/crates/polars-arrow/src/scalar/fixed_size_binary.rs @@ -0,0 +1,58 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// The [`Scalar`] implementation of fixed size binary ([`Option>`]). +pub struct FixedSizeBinaryScalar { + value: Option>, + dtype: ArrowDataType, +} + +impl FixedSizeBinaryScalar { + /// Returns a new [`FixedSizeBinaryScalar`]. + /// # Panics + /// iff + /// * the `dtype` is not `FixedSizeBinary` + /// * the size of child binary is not equal + #[inline] + pub fn new>>(dtype: ArrowDataType, value: Option

) -> Self { + assert_eq!( + dtype.to_physical_type(), + crate::datatypes::PhysicalType::FixedSizeBinary + ); + Self { + value: value.map(|x| { + let x: Vec = x.into(); + assert_eq!( + dtype.to_logical_type(), + &ArrowDataType::FixedSizeBinary(x.len()) + ); + x.into_boxed_slice() + }), + dtype, + } + } + + /// Its value + #[inline] + pub fn value(&self) -> Option<&[u8]> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl Scalar for FixedSizeBinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/fixed_size_list.rs b/crates/polars-arrow/src/scalar/fixed_size_list.rs new file mode 100644 index 000000000000..5810eeab2dfc --- /dev/null +++ b/crates/polars-arrow/src/scalar/fixed_size_list.rs @@ -0,0 +1,59 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::ArrowDataType; + +/// The scalar equivalent of [`FixedSizeListArray`]. Like [`FixedSizeListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct FixedSizeListScalar { + values: Option>, + dtype: ArrowDataType, +} + +impl PartialEq for FixedSizeListScalar { + fn eq(&self, other: &Self) -> bool { + (self.dtype == other.dtype) + && (self.values.is_some() == other.values.is_some()) + && ((self.values.is_none()) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl FixedSizeListScalar { + /// returns a new [`FixedSizeListScalar`] + /// # Panics + /// iff + /// * the `dtype` is not `FixedSizeList` + /// * the child of the `dtype` is not equal to the `values` + /// * the size of child array is not equal + #[inline] + pub fn new(dtype: ArrowDataType, values: Option>) -> Self { + let (field, size) = FixedSizeListArray::get_child_and_size(&dtype); + let inner_dtype = field.dtype(); + let values = values.inspect(|x| { + assert_eq!(inner_dtype, x.dtype()); + assert_eq!(size, x.len()); + }); + Self { values, dtype } + } + + /// The values of the [`FixedSizeListScalar`] + pub fn values(&self) -> Option<&Box> { + self.values.as_ref() + } +} + +impl Scalar for FixedSizeListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.values.is_some() + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/list.rs b/crates/polars-arrow/src/scalar/list.rs new file mode 100644 index 000000000000..6978c6e61860 --- /dev/null +++ b/crates/polars-arrow/src/scalar/list.rs @@ -0,0 +1,68 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::ArrowDataType; +use crate::offset::Offset; + +/// The scalar equivalent of [`ListArray`]. Like [`ListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct ListScalar { + values: Box, + is_valid: bool, + phantom: std::marker::PhantomData, + dtype: ArrowDataType, +} + +impl PartialEq for ListScalar { + fn eq(&self, other: &Self) -> bool { + (self.dtype == other.dtype) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl ListScalar { + /// returns a new [`ListScalar`] + /// # Panics + /// iff + /// * the `dtype` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `dtype` is not equal to the `values` + #[inline] + pub fn new(dtype: ArrowDataType, values: Option>) -> Self { + let inner_dtype = ListArray::::get_child_type(&dtype); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_dtype, values.dtype()); + (true, values) + }, + None => (false, new_empty_array(inner_dtype.clone())), + }; + Self { + values, + is_valid, + phantom: std::marker::PhantomData, + dtype, + } + } + + /// The values of the [`ListScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for ListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/map.rs b/crates/polars-arrow/src/scalar/map.rs new file mode 100644 index 000000000000..f9e7b238c481 --- /dev/null +++ b/crates/polars-arrow/src/scalar/map.rs @@ -0,0 +1,66 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::ArrowDataType; + +/// The scalar equivalent of [`MapArray`]. Like [`MapArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct MapScalar { + values: Box, + is_valid: bool, + dtype: ArrowDataType, +} + +impl PartialEq for MapScalar { + fn eq(&self, other: &Self) -> bool { + (self.dtype == other.dtype) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl MapScalar { + /// returns a new [`MapScalar`] + /// # Panics + /// iff + /// * the `dtype` is not `Map` + /// * the child of the `dtype` is not equal to the `values` + #[inline] + pub fn new(dtype: ArrowDataType, values: Option>) -> Self { + let inner_field = MapArray::try_get_field(&dtype).unwrap(); + let inner_dtype = inner_field.dtype(); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_dtype, values.dtype()); + (true, values) + }, + None => (false, new_empty_array(inner_dtype.clone())), + }; + Self { + values, + is_valid, + dtype, + } + } + + /// The values of the [`MapScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for MapScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/mod.rs b/crates/polars-arrow/src/scalar/mod.rs new file mode 100644 index 000000000000..adcf006862db --- /dev/null +++ b/crates/polars-arrow/src/scalar/mod.rs @@ -0,0 +1,209 @@ +//! contains the [`Scalar`] trait object representing individual items of [`Array`](crate::array::Array)s, +//! as well as concrete implementations such as [`BooleanScalar`]. +use std::any::Any; + +use crate::array::*; +use crate::datatypes::*; + +mod dictionary; +pub use dictionary::*; +mod equal; +mod primitive; +pub use primitive::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod boolean; +pub use boolean::*; +mod list; +pub use list::*; +mod map; +pub use map::*; +mod null; +pub use null::*; +mod struct_; +pub use struct_::*; +mod fixed_size_list; +pub use fixed_size_list::*; +mod fixed_size_binary; +pub use binview::*; +pub use fixed_size_binary::*; +mod binview; +mod union; + +pub use union::UnionScalar; + +use crate::{match_integer_type, with_match_primitive_type_full}; + +/// Trait object declaring an optional value with a [`ArrowDataType`]. +/// This strait is often used in APIs that accept multiple scalar types. +pub trait Scalar: std::fmt::Debug + Send + Sync + dyn_clone::DynClone + 'static { + /// convert itself to + fn as_any(&self) -> &dyn Any; + + /// whether it is valid + fn is_valid(&self) -> bool; + + /// the logical type. + fn dtype(&self) -> &ArrowDataType; +} + +dyn_clone::clone_trait_object!(Scalar); + +macro_rules! dyn_new_utf8 { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(Utf8Scalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_binview { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(BinaryViewScalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_binary { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(BinaryScalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_list { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index).into()) + } else { + None + }; + Box::new(ListScalar::<$type>::new(array.dtype().clone(), value)) + }}; +} + +/// creates a new [`Scalar`] from an [`Array`]. +pub fn new_scalar(array: &dyn Array, index: usize) -> Box { + use PhysicalType::*; + match array.dtype().to_physical_type() { + Null => Box::new(NullScalar::new()), + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(BooleanScalar::new(value)) + }, + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(PrimitiveScalar::new(array.dtype().clone(), value)) + }), + BinaryView => dyn_new_binview!(array, index, [u8]), + Utf8View => dyn_new_binview!(array, index, str), + Utf8 => dyn_new_utf8!(array, index, i32), + LargeUtf8 => dyn_new_utf8!(array, index, i64), + Binary => dyn_new_binary!(array, index, i32), + LargeBinary => dyn_new_binary!(array, index, i64), + List => dyn_new_list!(array, index, i32), + LargeList => dyn_new_list!(array, index, i64), + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_valid(index) { + let values = array + .values() + .iter() + .map(|x| new_scalar(x.as_ref(), index)) + .collect(); + Box::new(StructScalar::new(array.dtype().clone(), Some(values))) + } else { + Box::new(StructScalar::new(array.dtype().clone(), None)) + } + }, + FixedSizeBinary => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(FixedSizeBinaryScalar::new(array.dtype().clone(), value)) + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(FixedSizeListScalar::new(array.dtype().clone(), value)) + }, + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + Box::new(UnionScalar::new( + array.dtype().clone(), + array.types()[index], + array.value(index), + )) + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(MapScalar::new(array.dtype().clone(), value)) + }, + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index).into()) + } else { + None + }; + Box::new(DictionaryScalar::<$T>::new( + array.dtype().clone(), + value, + )) + }), + } +} diff --git a/crates/polars-arrow/src/scalar/null.rs b/crates/polars-arrow/src/scalar/null.rs new file mode 100644 index 000000000000..2071f0d4584e --- /dev/null +++ b/crates/polars-arrow/src/scalar/null.rs @@ -0,0 +1,37 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; + +/// The representation of a single entry of a [`crate::array::NullArray`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NullScalar {} + +impl NullScalar { + /// A new [`NullScalar`] + #[inline] + pub fn new() -> Self { + Self {} + } +} + +impl Default for NullScalar { + fn default() -> Self { + Self::new() + } +} + +impl Scalar for NullScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + false + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::Null + } +} diff --git a/crates/polars-arrow/src/scalar/primitive.rs b/crates/polars-arrow/src/scalar/primitive.rs new file mode 100644 index 000000000000..35214b270032 --- /dev/null +++ b/crates/polars-arrow/src/scalar/primitive.rs @@ -0,0 +1,63 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; +use crate::types::NativeType; + +/// The implementation of [`Scalar`] for primitive, semantically equivalent to [`Option`] +/// with [`ArrowDataType`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrimitiveScalar { + value: Option, + dtype: ArrowDataType, +} + +impl PrimitiveScalar { + /// Returns a new [`PrimitiveScalar`]. + #[inline] + pub fn new(dtype: ArrowDataType, value: Option) -> Self { + if !dtype.to_physical_type().eq_primitive(T::PRIMITIVE) { + panic!( + "Type {} does not support logical type {:?}", + std::any::type_name::(), + dtype + ) + } + Self { value, dtype } + } + + /// Returns the optional value. + #[inline] + pub fn value(&self) -> &Option { + &self.value + } + + /// Returns a new `PrimitiveScalar` with the same value but different [`ArrowDataType`] + /// # Panic + /// This function panics if the `dtype` is not valid for self's physical type `T`. + pub fn to(self, dtype: ArrowDataType) -> Self { + Self::new(dtype, self.value) + } +} + +impl From> for PrimitiveScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(T::PRIMITIVE.into(), v) + } +} + +impl Scalar for PrimitiveScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/struct_.rs b/crates/polars-arrow/src/scalar/struct_.rs new file mode 100644 index 000000000000..c9ba6a8e66c0 --- /dev/null +++ b/crates/polars-arrow/src/scalar/struct_.rs @@ -0,0 +1,54 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; + +/// A single entry of a [`crate::array::StructArray`]. +#[derive(Debug, Clone)] +pub struct StructScalar { + values: Vec>, + is_valid: bool, + dtype: ArrowDataType, +} + +impl PartialEq for StructScalar { + fn eq(&self, other: &Self) -> bool { + (self.dtype == other.dtype) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values == other.values)) + } +} + +impl StructScalar { + /// Returns a new [`StructScalar`] + #[inline] + pub fn new(dtype: ArrowDataType, values: Option>>) -> Self { + let is_valid = values.is_some(); + Self { + values: values.unwrap_or_default(), + is_valid, + dtype, + } + } + + /// Returns the values irrespectively of the validity. + #[inline] + pub fn values(&self) -> &[Box] { + &self.values + } +} + +impl Scalar for StructScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/union.rs b/crates/polars-arrow/src/scalar/union.rs new file mode 100644 index 000000000000..95f4ebba6e3e --- /dev/null +++ b/crates/polars-arrow/src/scalar/union.rs @@ -0,0 +1,51 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; + +/// A single entry of a [`crate::array::UnionArray`]. +#[derive(Debug, Clone, PartialEq)] +pub struct UnionScalar { + value: Box, + type_: i8, + dtype: ArrowDataType, +} + +impl UnionScalar { + /// Returns a new [`UnionScalar`] + #[inline] + pub fn new(dtype: ArrowDataType, type_: i8, value: Box) -> Self { + Self { + value, + type_, + dtype, + } + } + + /// Returns the inner value + #[inline] + pub fn value(&self) -> &Box { + &self.value + } + + /// Returns the type of the union scalar + #[inline] + pub fn type_(&self) -> i8 { + self.type_ + } +} + +impl Scalar for UnionScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + true + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + &self.dtype + } +} diff --git a/crates/polars-arrow/src/scalar/utf8.rs b/crates/polars-arrow/src/scalar/utf8.rs new file mode 100644 index 000000000000..986477d5bb5c --- /dev/null +++ b/crates/polars-arrow/src/scalar/utf8.rs @@ -0,0 +1,55 @@ +use super::Scalar; +use crate::datatypes::ArrowDataType; +use crate::offset::Offset; + +/// The implementation of [`Scalar`] for utf8, semantically equivalent to [`Option`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Utf8Scalar { + value: Option, + phantom: std::marker::PhantomData, +} + +impl Utf8Scalar { + /// Returns a new [`Utf8Scalar`] + #[inline] + pub fn new>(value: Option

) -> Self { + Self { + value: value.map(|x| x.into()), + phantom: std::marker::PhantomData, + } + } + + /// Returns the value irrespectively of the validity. + #[inline] + pub fn value(&self) -> Option<&str> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl> From> for Utf8Scalar { + #[inline] + fn from(v: Option

) -> Self { + Self::new(v) + } +} + +impl Scalar for Utf8Scalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn dtype(&self) -> &ArrowDataType { + if O::IS_LARGE { + &ArrowDataType::LargeUtf8 + } else { + &ArrowDataType::Utf8 + } + } +} diff --git a/crates/polars-arrow/src/storage.rs b/crates/polars-arrow/src/storage.rs new file mode 100644 index 000000000000..538336313daa --- /dev/null +++ b/crates/polars-arrow/src/storage.rs @@ -0,0 +1,400 @@ +use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops::{Deref, DerefMut}; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicU64, Ordering}; + +use bytemuck::Pod; + +// Allows us to transmute between types while also keeping the original +// stats and drop method of the Vec around. +struct VecVTable { + size: usize, + align: usize, + drop_buffer: unsafe fn(*mut (), usize), +} + +impl VecVTable { + const fn new() -> Self { + unsafe fn drop_buffer(ptr: *mut (), cap: usize) { + unsafe { drop(Vec::from_raw_parts(ptr.cast::(), 0, cap)) } + } + + Self { + size: size_of::(), + align: align_of::(), + drop_buffer: drop_buffer::, + } + } + + fn new_static() -> &'static Self { + const { &Self::new::() } + } +} + +use crate::ffi::InternalArrowArray; + +enum BackingStorage { + Vec { + original_capacity: usize, // Elements, not bytes. + vtable: &'static VecVTable, + }, + InternalArrowArray(InternalArrowArray), + + /// Backed by some external method which we do not need to take care of, + /// but we still should refcount and drop the SharedStorageInner. + External, + + /// Both the backing storage and the SharedStorageInner are leaked, no + /// refcounting is done. This technically should be a flag on + /// SharedStorageInner instead of being here, but that would add 8 more + /// bytes to SharedStorageInner, so here it is. + Leaked, +} + +struct SharedStorageInner { + ref_count: AtomicU64, + ptr: *mut T, + length_in_bytes: usize, + backing: BackingStorage, + // https://github.com/rust-lang/rfcs/blob/master/text/0769-sound-generic-drop.md#phantom-data + phantom: PhantomData, +} + +unsafe impl Sync for SharedStorageInner {} + +impl SharedStorageInner { + pub fn from_vec(mut v: Vec) -> Self { + let length_in_bytes = v.len() * size_of::(); + let original_capacity = v.capacity(); + let ptr = v.as_mut_ptr(); + core::mem::forget(v); + Self { + ref_count: AtomicU64::new(1), + ptr, + length_in_bytes, + backing: BackingStorage::Vec { + original_capacity, + vtable: VecVTable::new_static::(), + }, + phantom: PhantomData, + } + } +} + +impl Drop for SharedStorageInner { + fn drop(&mut self) { + match core::mem::replace(&mut self.backing, BackingStorage::External) { + BackingStorage::InternalArrowArray(a) => drop(a), + BackingStorage::Vec { + original_capacity, + vtable, + } => unsafe { + // Drop the elements in our slice. + if std::mem::needs_drop::() { + core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut( + self.ptr, + self.length_in_bytes / size_of::(), + )); + } + + // Free the buffer. + if original_capacity > 0 { + (vtable.drop_buffer)(self.ptr.cast(), original_capacity); + } + }, + BackingStorage::External | BackingStorage::Leaked => {}, + } + } +} + +pub struct SharedStorage { + inner: NonNull>, + phantom: PhantomData>, +} + +unsafe impl Send for SharedStorage {} +unsafe impl Sync for SharedStorage {} + +impl Default for SharedStorage { + fn default() -> Self { + Self::empty() + } +} + +impl SharedStorage { + const fn empty() -> Self { + assert!(align_of::() <= 1 << 30); + static INNER: SharedStorageInner<()> = SharedStorageInner { + ref_count: AtomicU64::new(1), + ptr: core::ptr::without_provenance_mut(1 << 30), // Very overaligned for any T. + length_in_bytes: 0, + backing: BackingStorage::Leaked, + phantom: PhantomData, + }; + + Self { + inner: NonNull::new(&raw const INNER as *mut SharedStorageInner).unwrap(), + phantom: PhantomData, + } + } + + pub fn from_static(slice: &'static [T]) -> Self { + #[expect(clippy::manual_slice_size_calculation)] + let length_in_bytes = slice.len() * size_of::(); + let ptr = slice.as_ptr().cast_mut(); + let inner = SharedStorageInner { + ref_count: AtomicU64::new(1), + ptr, + length_in_bytes, + backing: BackingStorage::External, + phantom: PhantomData, + }; + Self { + inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(), + phantom: PhantomData, + } + } + + pub fn from_vec(v: Vec) -> Self { + Self { + inner: NonNull::new(Box::into_raw(Box::new(SharedStorageInner::from_vec(v)))).unwrap(), + phantom: PhantomData, + } + } + + pub fn from_internal_arrow_array(ptr: *const T, len: usize, arr: InternalArrowArray) -> Self { + let inner = SharedStorageInner { + ref_count: AtomicU64::new(1), + ptr: ptr.cast_mut(), + length_in_bytes: len * size_of::(), + backing: BackingStorage::InternalArrowArray(arr), + phantom: PhantomData, + }; + Self { + inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(), + phantom: PhantomData, + } + } + + /// Leaks this SharedStorage such that it and its inner value is never + /// dropped. In return no refcounting needs to be performed. + /// + /// The SharedStorage must be exclusive. + pub fn leak(&mut self) { + assert!(self.is_exclusive()); + unsafe { + let inner = &mut *self.inner.as_ptr(); + core::mem::forget(core::mem::replace( + &mut inner.backing, + BackingStorage::Leaked, + )); + } + } +} + +pub struct SharedStorageAsVecMut<'a, T> { + ss: &'a mut SharedStorage, + vec: ManuallyDrop>, +} + +impl Deref for SharedStorageAsVecMut<'_, T> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.vec + } +} + +impl DerefMut for SharedStorageAsVecMut<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.vec + } +} + +impl Drop for SharedStorageAsVecMut<'_, T> { + fn drop(&mut self) { + unsafe { + // Restore the SharedStorage. + let vec = ManuallyDrop::take(&mut self.vec); + let inner = self.ss.inner.as_ptr(); + inner.write(SharedStorageInner::from_vec(vec)); + } + } +} + +impl SharedStorage { + #[inline(always)] + pub fn len(&self) -> usize { + self.inner().length_in_bytes / size_of::() + } + + #[inline(always)] + pub fn as_ptr(&self) -> *const T { + self.inner().ptr + } + + #[inline(always)] + pub fn is_exclusive(&mut self) -> bool { + // Ordering semantics copied from Arc. + self.inner().ref_count.load(Ordering::Acquire) == 1 + } + + /// Gets the reference count of this storage. + /// + /// Because this function takes a shared reference this should not be used + /// in cases where we are checking if the refcount is one for safety, + /// someone else could increment it in the meantime. + #[inline(always)] + pub fn refcount(&self) -> u64 { + // Ordering semantics copied from Arc. + self.inner().ref_count.load(Ordering::Acquire) + } + + pub fn try_as_mut_slice(&mut self) -> Option<&mut [T]> { + self.is_exclusive().then(|| { + let inner = self.inner(); + let len = inner.length_in_bytes / size_of::(); + unsafe { core::slice::from_raw_parts_mut(inner.ptr, len) } + }) + } + + /// Try to take the vec backing this SharedStorage, leaving this as an empty slice. + pub fn try_take_vec(&mut self) -> Option> { + // If there are other references we can't get an exclusive reference. + if !self.is_exclusive() { + return None; + } + + let ret; + unsafe { + let inner = &mut *self.inner.as_ptr(); + + // We may only go back to a Vec if we originally came from a Vec + // where the desired size/align matches the original. + let BackingStorage::Vec { + original_capacity, + vtable, + } = &mut inner.backing + else { + return None; + }; + + if vtable.size != size_of::() || vtable.align != align_of::() { + return None; + } + + // Steal vec from inner. + let len = inner.length_in_bytes / size_of::(); + ret = Vec::from_raw_parts(inner.ptr, len, *original_capacity); + *original_capacity = 0; + inner.length_in_bytes = 0; + } + Some(ret) + } + + /// Attempts to call the given function with this SharedStorage as a + /// reference to a mutable Vec. If this SharedStorage can't be converted to + /// a Vec the function is not called and instead returned as an error. + pub fn try_as_mut_vec(&mut self) -> Option> { + Some(SharedStorageAsVecMut { + vec: ManuallyDrop::new(self.try_take_vec()?), + ss: self, + }) + } + + pub fn try_into_vec(mut self) -> Result, Self> { + self.try_take_vec().ok_or(self) + } + + #[inline(always)] + fn inner(&self) -> &SharedStorageInner { + unsafe { &*self.inner.as_ptr() } + } + + /// # Safety + /// May only be called once. + #[cold] + unsafe fn drop_slow(&mut self) { + unsafe { drop(Box::from_raw(self.inner.as_ptr())) } + } +} + +impl SharedStorage { + fn try_transmute(self) -> Result, Self> { + let inner = self.inner(); + + // The length of the array in bytes must be a multiple of the target size. + // We can skip this check if the size of U divides the size of T. + if size_of::() % size_of::() != 0 && inner.length_in_bytes % size_of::() != 0 { + return Err(self); + } + + // The pointer must be properly aligned for U. + // We can skip this check if the alignment of U divides the alignment of T. + if align_of::() % align_of::() != 0 && !inner.ptr.cast::().is_aligned() { + return Err(self); + } + + let storage = SharedStorage { + inner: self.inner.cast(), + phantom: PhantomData, + }; + std::mem::forget(self); + Ok(storage) + } +} + +impl SharedStorage { + /// Create a [`SharedStorage`][SharedStorage] from a [`Vec`] of [`Pod`]. + pub fn bytes_from_pod_vec(v: Vec) -> Self { + // This can't fail, bytes is compatible with everything. + SharedStorage::from_vec(v) + .try_transmute::() + .unwrap_or_else(|_| unreachable!()) + } +} + +impl Deref for SharedStorage { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + unsafe { + let inner = self.inner(); + let len = inner.length_in_bytes / size_of::(); + core::slice::from_raw_parts(inner.ptr, len) + } + } +} + +impl Clone for SharedStorage { + fn clone(&self) -> Self { + let inner = self.inner(); + if !matches!(inner.backing, BackingStorage::Leaked) { + // Ordering semantics copied from Arc. + inner.ref_count.fetch_add(1, Ordering::Relaxed); + } + Self { + inner: self.inner, + phantom: PhantomData, + } + } +} + +impl Drop for SharedStorage { + fn drop(&mut self) { + let inner = self.inner(); + if matches!(inner.backing, BackingStorage::Leaked) { + return; + } + + // Ordering semantics copied from Arc. + if inner.ref_count.fetch_sub(1, Ordering::Release) == 1 { + std::sync::atomic::fence(Ordering::Acquire); + unsafe { + self.drop_slow(); + } + } + } +} diff --git a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs new file mode 100644 index 000000000000..aac37442dd2b --- /dev/null +++ b/crates/polars-arrow/src/temporal_conversions.rs @@ -0,0 +1,306 @@ +//! Conversion methods for dates and times. + +use chrono::format::{Parsed, StrftimeItems, parse}; +use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta}; +use polars_error::{PolarsResult, polars_err}; + +use crate::datatypes::TimeUnit; + +/// Number of seconds in a day +pub const SECONDS_IN_DAY: i64 = 86_400; +/// Number of milliseconds in a second +pub const MILLISECONDS: i64 = 1_000; +/// Number of microseconds in a second +pub const MICROSECONDS: i64 = 1_000_000; +/// Number of nanoseconds in a second +pub const NANOSECONDS: i64 = 1_000_000_000; +/// Number of milliseconds in a day +pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; +/// Number of microseconds in a day +pub const MICROSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MICROSECONDS; +/// Number of nanoseconds in a day +pub const NANOSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * NANOSECONDS; +/// Number of days between 0001-01-01 and 1970-01-01 +pub const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime(v: i32) -> NaiveDateTime { + date32_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime_opt(v: i32) -> Option { + let delta = TimeDelta::try_days(v.into())?; + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) +} + +/// converts a `i32` representing a `date32` to [`NaiveDate`] +#[inline] +pub fn date32_to_date(days: i32) -> NaiveDate { + date32_to_date_opt(days).expect("out-of-range date") +} + +/// converts a `i32` representing a `date32` to [`NaiveDate`] +#[inline] +pub fn date32_to_date_opt(days: i32) -> Option { + NaiveDate::from_num_days_from_ce_opt(EPOCH_DAYS_FROM_CE + days) +} + +/// converts a `i64` representing a `date64` to [`NaiveDateTime`] +#[inline] +pub fn date64_to_datetime(v: i64) -> NaiveDateTime { + TimeDelta::try_milliseconds(v) + .and_then(|delta| NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta)) + .expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `date64` to [`NaiveDate`] +#[inline] +pub fn date64_to_date(milliseconds: i64) -> NaiveDate { + date64_to_datetime(milliseconds).date() +} + +/// converts a `i32` representing a `time32(s)` to [`NaiveTime`] +#[inline] +pub fn time32s_to_time(v: i32) -> NaiveTime { + NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0).expect("invalid time") +} + +/// converts a `i64` representing a `duration(s)` to [`Duration`] +#[inline] +pub fn duration_s_to_duration(v: i64) -> Duration { + Duration::try_seconds(v).expect("out-of-range duration") +} + +/// converts a `i64` representing a `duration(ms)` to [`Duration`] +#[inline] +pub fn duration_ms_to_duration(v: i64) -> Duration { + Duration::try_milliseconds(v).expect("out-of-range in duration conversion") +} + +/// converts a `i64` representing a `duration(us)` to [`Duration`] +#[inline] +pub fn duration_us_to_duration(v: i64) -> Duration { + Duration::microseconds(v) +} + +/// converts a `i64` representing a `duration(ns)` to [`Duration`] +#[inline] +pub fn duration_ns_to_duration(v: i64) -> Duration { + Duration::nanoseconds(v) +} + +/// converts a `i32` representing a `time32(ms)` to [`NaiveTime`] +#[inline] +pub fn time32ms_to_time(v: i32) -> NaiveTime { + let v = v as i64; + let seconds = v / MILLISECONDS; + + let milli_to_nano = 1_000_000; + let nano = (v - seconds * MILLISECONDS) * milli_to_nano; + NaiveTime::from_num_seconds_from_midnight_opt(seconds as u32, nano as u32) + .expect("invalid time") +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time(v: i64) -> NaiveTime { + time64us_to_time_opt(v).expect("invalid time") +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time_opt(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from microseconds + (v / MICROSECONDS) as u32, + // discard extracted seconds and convert microseconds to + // nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time(v: i64) -> NaiveTime { + time64ns_to_time_opt(v).expect("invalid time") +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time_opt(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from nanoseconds + (v / NANOSECONDS) as u32, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime(seconds: i64) -> NaiveDateTime { + timestamp_s_to_datetime_opt(seconds).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime_opt(seconds: i64) -> Option { + Some(DateTime::from_timestamp(seconds, 0)?.naive_utc()) +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime(v: i64) -> NaiveDateTime { + timestamp_ms_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime_opt(v: i64) -> Option { + let delta = TimeDelta::try_milliseconds(v)?; + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime(v: i64) -> NaiveDateTime { + timestamp_us_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime_opt(v: i64) -> Option { + let delta = TimeDelta::microseconds(v); + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime(v: i64) -> NaiveDateTime { + timestamp_ns_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime_opt(v: i64) -> Option { + let delta = TimeDelta::nanoseconds(v); + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) +} + +/// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. +#[inline] +pub(crate) fn timestamp_to_naive_datetime( + timestamp: i64, + time_unit: TimeUnit, +) -> chrono::NaiveDateTime { + match time_unit { + TimeUnit::Second => timestamp_s_to_datetime(timestamp), + TimeUnit::Millisecond => timestamp_ms_to_datetime(timestamp), + TimeUnit::Microsecond => timestamp_us_to_datetime(timestamp), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(timestamp), + } +} + +/// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. +#[inline] +pub fn timestamp_to_datetime( + timestamp: i64, + time_unit: TimeUnit, + timezone: &T, +) -> chrono::DateTime { + timezone.from_utc_datetime(×tamp_to_naive_datetime(timestamp, time_unit)) +} + +/// Calculates the scale factor between two TimeUnits. The function returns the +/// scale that should multiply the TimeUnit "b" to have the same time scale as +/// the TimeUnit "a". +pub fn timeunit_scale(a: TimeUnit, b: TimeUnit) -> f64 { + match (a, b) { + (TimeUnit::Second, TimeUnit::Second) => 1.0, + (TimeUnit::Second, TimeUnit::Millisecond) => 0.001, + (TimeUnit::Second, TimeUnit::Microsecond) => 0.000_001, + (TimeUnit::Second, TimeUnit::Nanosecond) => 0.000_000_001, + (TimeUnit::Millisecond, TimeUnit::Second) => 1_000.0, + (TimeUnit::Millisecond, TimeUnit::Millisecond) => 1.0, + (TimeUnit::Millisecond, TimeUnit::Microsecond) => 0.001, + (TimeUnit::Millisecond, TimeUnit::Nanosecond) => 0.000_001, + (TimeUnit::Microsecond, TimeUnit::Second) => 1_000_000.0, + (TimeUnit::Microsecond, TimeUnit::Millisecond) => 1_000.0, + (TimeUnit::Microsecond, TimeUnit::Microsecond) => 1.0, + (TimeUnit::Microsecond, TimeUnit::Nanosecond) => 0.001, + (TimeUnit::Nanosecond, TimeUnit::Second) => 1_000_000_000.0, + (TimeUnit::Nanosecond, TimeUnit::Millisecond) => 1_000_000.0, + (TimeUnit::Nanosecond, TimeUnit::Microsecond) => 1_000.0, + (TimeUnit::Nanosecond, TimeUnit::Nanosecond) => 1.0, + } +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// +/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). +/// Returns in scale `tz` of `TimeUnit`. +#[inline] +pub fn utf8_to_timestamp_scalar( + value: &str, + fmt: &str, + tz: &T, + tu: &TimeUnit, +) -> Option { + let mut parsed = Parsed::new(); + let fmt = StrftimeItems::new(fmt); + let r = parse(&mut parsed, value, fmt).ok(); + if r.is_some() { + parsed + .to_datetime() + .map(|x| x.naive_utc()) + .map(|x| tz.from_utc_datetime(&x)) + .map(|x| match tu { + TimeUnit::Second => x.timestamp(), + TimeUnit::Millisecond => x.timestamp_millis(), + TimeUnit::Microsecond => x.timestamp_micros(), + TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + }) + .ok() + } else { + None + } +} + +/// Parses an offset of the form `"+WX:YZ"` or `"UTC"` into [`FixedOffset`]. +/// # Errors +/// If the offset is not in any of the allowed forms. +pub fn parse_offset(offset: &str) -> PolarsResult { + if offset == "UTC" { + return Ok(FixedOffset::east_opt(0).expect("FixedOffset::east out of bounds")); + } + let error = "timezone offset must be of the form [-]00:00"; + + let mut a = offset.split(':'); + let first: &str = a + .next() + .ok_or_else(|| polars_err!(InvalidOperation: error))?; + let last = a + .next() + .ok_or_else(|| polars_err!(InvalidOperation: error))?; + let hours: i32 = first + .parse() + .map_err(|_| polars_err!(InvalidOperation: error))?; + let minutes: i32 = last + .parse() + .map_err(|_| polars_err!(InvalidOperation: error))?; + + Ok(FixedOffset::east_opt(hours * 60 * 60 + minutes * 60) + .expect("FixedOffset::east out of bounds")) +} + +/// Parses `value` to a [`chrono_tz::Tz`] with the Arrow's definition of timestamp with a timezone. +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +pub fn parse_offset_tz(timezone: &str) -> PolarsResult { + timezone + .parse::() + .map_err(|_| polars_err!(InvalidOperation: "timezone \"{timezone}\" cannot be parsed")) +} diff --git a/crates/polars-arrow/src/trusted_len.rs b/crates/polars-arrow/src/trusted_len.rs new file mode 100644 index 000000000000..794f995caca6 --- /dev/null +++ b/crates/polars-arrow/src/trusted_len.rs @@ -0,0 +1,140 @@ +//! Declares [`TrustedLen`]. +use std::iter::Scan; +use std::slice::{Iter, IterMut}; + +/// An iterator of known, fixed size. +/// +/// A trait denoting Rusts' unstable [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). +/// This is re-defined here and implemented for some iterators until `std::iter::TrustedLen` +/// is stabilized. +/// +/// # Safety +/// This trait must only be implemented when the contract is upheld. +/// Consumers of this trait must inspect Iterator::size_hint()’s upper bound. +pub unsafe trait TrustedLen: Iterator {} + +unsafe impl TrustedLen for Iter<'_, T> {} +unsafe impl TrustedLen for IterMut<'_, T> {} + +unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Copied +where + I: TrustedLen, + T: Copy, +{ +} +unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Cloned +where + I: TrustedLen, + T: Clone, +{ +} + +unsafe impl TrustedLen for std::iter::Enumerate where I: TrustedLen {} + +unsafe impl TrustedLen for std::iter::Zip +where + A: TrustedLen, + B: TrustedLen, +{ +} + +unsafe impl TrustedLen for std::slice::ChunksExact<'_, T> {} + +unsafe impl TrustedLen for std::slice::Windows<'_, T> {} + +unsafe impl TrustedLen for std::iter::Chain +where + A: TrustedLen, + B: TrustedLen, +{ +} + +unsafe impl TrustedLen for std::iter::Once {} + +unsafe impl TrustedLen for std::vec::IntoIter {} + +unsafe impl TrustedLen for std::iter::Repeat {} +unsafe impl A> TrustedLen for std::iter::RepeatWith {} +unsafe impl TrustedLen for std::iter::Take {} + +unsafe impl TrustedLen for std::iter::RepeatN {} + +unsafe impl TrustedLen for &mut dyn TrustedLen {} +unsafe impl TrustedLen for Box + '_> {} + +unsafe impl B> TrustedLen for std::iter::Map {} + +unsafe impl TrustedLen for std::iter::Rev {} + +unsafe impl, J> TrustedLen for TrustMyLength {} +unsafe impl TrustedLen for std::ops::Range where std::ops::Range: Iterator {} +unsafe impl TrustedLen for std::ops::RangeInclusive where std::ops::RangeInclusive: Iterator +{} +unsafe impl TrustedLen for std::iter::StepBy {} + +unsafe impl TrustedLen for Scan +where + F: FnMut(&mut St, I::Item) -> Option, + I: TrustedLen, +{ +} + +unsafe impl TrustedLen for hashbrown::hash_map::IntoIter {} + +#[derive(Clone)] +pub struct TrustMyLength, J> { + iter: I, + len: usize, +} + +impl TrustMyLength +where + I: Iterator, +{ + /// Create a new `TrustMyLength` iterator + /// + /// # Safety + /// + /// This is safe if the iterator always has the exact length given by `len`. + #[inline] + pub unsafe fn new(iter: I, len: usize) -> Self { + Self { iter, len } + } +} + +impl TrustMyLength, J> { + /// Create a new `TrustMyLength` iterator that repeats `value` `len` times. + pub fn new_repeat_n(value: J, len: usize) -> Self { + // SAFETY: This is always safe since repeat(..).take(n) always repeats exactly `n` times`. + unsafe { Self::new(std::iter::repeat_n(value, len), len) } + } +} + +impl Iterator for TrustMyLength +where + I: Iterator, +{ + type Item = J; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.len)) + } +} + +impl ExactSizeIterator for TrustMyLength where I: Iterator {} + +impl DoubleEndedIterator for TrustMyLength +where + I: Iterator + DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + self.iter.next_back() + } +} diff --git a/crates/polars-arrow/src/types/aligned_bytes.rs b/crates/polars-arrow/src/types/aligned_bytes.rs new file mode 100644 index 000000000000..95aa84cfd349 --- /dev/null +++ b/crates/polars-arrow/src/types/aligned_bytes.rs @@ -0,0 +1,152 @@ +use bytemuck::{Pod, Zeroable}; + +use super::{days_ms, f16, i256, months_days_ns}; +use crate::array::View; + +/// Define that a type has the same byte alignment and size as `B`. +/// +/// # Safety +/// +/// This is safe to implement if both types have the same alignment and size. +pub unsafe trait AlignedBytesCast: Pod {} + +/// A representation of a type as raw bytes with the same alignment as the original type. +pub trait AlignedBytes: Pod + Zeroable + Copy + Default + Eq { + const ALIGNMENT: usize; + const SIZE: usize; + const SIZE_ALIGNMENT_PAIR: PrimitiveSizeAlignmentPair; + + type Unaligned: AsRef<[u8]> + + AsMut<[u8]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> + + std::fmt::Debug + + Default + + IntoIterator + + Pod; + + fn to_unaligned(&self) -> Self::Unaligned; + fn from_unaligned(unaligned: Self::Unaligned) -> Self; + + /// Safely cast a mutable reference to a [`Vec`] of `T` to a mutable reference of `Self`. + fn cast_vec_ref_mut>(vec: &mut Vec) -> &mut Vec { + if cfg!(debug_assertions) { + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + } + + // SAFETY: SameBytes guarantees that T: + // 1. has the same size + // 2. has the same alignment + // 3. is Pod (therefore has no life-time issues) + unsafe { std::mem::transmute(vec) } + } +} + +macro_rules! impl_aligned_bytes { + ( + $(($name:ident, $size:literal, $alignment:literal, $sap:ident, [$($eq_type:ty),*]),)+ + ) => { + $( + /// Bytes with a size and alignment. + /// + /// This is used to reduce the monomorphizations for routines that solely rely on the size + /// and alignment of types. + #[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Pod, Zeroable)] + #[repr(C, align($alignment))] + pub struct $name([u8; $size]); + + impl AlignedBytes for $name { + const ALIGNMENT: usize = $alignment; + const SIZE: usize = $size; + const SIZE_ALIGNMENT_PAIR: PrimitiveSizeAlignmentPair = PrimitiveSizeAlignmentPair::$sap; + + type Unaligned = [u8; $size]; + + #[inline(always)] + fn to_unaligned(&self) -> Self::Unaligned { + self.0 + } + #[inline(always)] + fn from_unaligned(unaligned: Self::Unaligned) -> Self { + Self(unaligned) + } + } + + impl AsRef<[u8; $size]> for $name { + #[inline(always)] + fn as_ref(&self) -> &[u8; $size] { + &self.0 + } + } + + $( + impl From<$eq_type> for $name { + #[inline(always)] + fn from(value: $eq_type) -> Self { + bytemuck::must_cast(value) + } + } + impl From<$name> for $eq_type { + #[inline(always)] + fn from(value: $name) -> Self { + bytemuck::must_cast(value) + } + } + unsafe impl AlignedBytesCast<$name> for $eq_type {} + )* + )+ + } +} + +#[derive(Clone, Copy)] +pub enum PrimitiveSizeAlignmentPair { + S1A1, + S2A2, + S4A4, + S8A4, + S8A8, + S12A4, + S16A4, + S16A8, + S16A16, + S32A16, +} + +impl PrimitiveSizeAlignmentPair { + pub const fn size(self) -> usize { + match self { + Self::S1A1 => 1, + Self::S2A2 => 2, + Self::S4A4 => 4, + Self::S8A4 | Self::S8A8 => 8, + Self::S12A4 => 12, + Self::S16A4 | Self::S16A8 | Self::S16A16 => 16, + Self::S32A16 => 32, + } + } + + pub const fn alignment(self) -> usize { + match self { + Self::S1A1 => 1, + Self::S2A2 => 2, + Self::S4A4 | Self::S8A4 | Self::S12A4 | Self::S16A4 => 4, + Self::S8A8 | Self::S16A8 => 8, + Self::S16A16 | Self::S32A16 => 16, + } + } +} + +impl_aligned_bytes! { + (Bytes1Alignment1, 1, 1, S1A1, [u8, i8]), + (Bytes2Alignment2, 2, 2, S2A2, [u16, i16, f16]), + (Bytes4Alignment4, 4, 4, S4A4, [u32, i32, f32]), + (Bytes8Alignment8, 8, 8, S8A8, [u64, i64, f64]), + (Bytes8Alignment4, 8, 4, S8A4, [days_ms]), + (Bytes12Alignment4, 12, 4, S12A4, [[u32; 3]]), + (Bytes16Alignment4, 16, 4, S16A4, [View]), + (Bytes16Alignment8, 16, 8, S16A8, [months_days_ns]), + (Bytes16Alignment16, 16, 16, S16A16, [u128, i128]), + (Bytes32Alignment16, 32, 16, S32A16, [i256]), +} diff --git a/crates/polars-arrow/src/types/bit_chunk.rs b/crates/polars-arrow/src/types/bit_chunk.rs new file mode 100644 index 000000000000..60892689060f --- /dev/null +++ b/crates/polars-arrow/src/types/bit_chunk.rs @@ -0,0 +1,162 @@ +use std::fmt::Binary; +use std::ops::{BitAndAssign, Not, Shl, ShlAssign, ShrAssign}; + +use num_traits::PrimInt; + +use super::NativeType; + +/// A chunk of bits. This is used to create masks of a given length +/// whose width is `1` bit. In `portable_simd` notation, this corresponds to `m1xY`. +/// +/// This (sealed) trait is implemented for [`u8`], [`u16`], [`u32`] and [`u64`]. +pub trait BitChunk: + super::private::Sealed + + PrimInt + + NativeType + + Binary + + ShlAssign + + Not + + ShrAssign + + ShlAssign + + Shl + + BitAndAssign +{ + /// convert itself into bytes. + fn to_ne_bytes(self) -> Self::Bytes; + /// convert itself from bytes. + fn from_ne_bytes(v: Self::Bytes) -> Self; +} + +macro_rules! bit_chunk { + ($ty:ty) => { + impl BitChunk for $ty { + #[inline(always)] + fn to_ne_bytes(self) -> Self::Bytes { + self.to_ne_bytes() + } + + #[inline(always)] + fn from_ne_bytes(v: Self::Bytes) -> Self { + Self::from_ne_bytes(v) + } + } + }; +} + +bit_chunk!(u8); +bit_chunk!(u16); +bit_chunk!(u32); +bit_chunk!(u64); + +/// An [`Iterator`] over a [`BitChunk`]. +/// +/// This iterator is often compiled to SIMD. +/// +/// The [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit) corresponds +/// to the first slot, as defined by the arrow specification. +/// # Example +/// ``` +/// use polars_arrow::types::BitChunkIter; +/// let a = 0b00010000u8; +/// let iter = BitChunkIter::new(a, 7); +/// let r = iter.collect::>(); +/// assert_eq!(r, vec![false, false, false, false, true, false, false]); +/// ``` +pub struct BitChunkIter { + value: T, + mask: T, + remaining: usize, +} + +impl BitChunkIter { + /// Creates a new [`BitChunkIter`] with `len` bits. + #[inline] + pub fn new(value: T, len: usize) -> Self { + assert!(len <= size_of::() * 8); + Self { + value, + remaining: len, + mask: T::one(), + } + } +} + +impl Iterator for BitChunkIter { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + }; + let result = Some(self.value & self.mask != T::zero()); + self.remaining -= 1; + self.mask <<= 1; + result + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +// # Safety +// a mathematical invariant of this iterator +unsafe impl crate::trusted_len::TrustedLen for BitChunkIter {} + +/// An [`Iterator`] over a [`BitChunk`] returning the index of each bit set in the chunk +/// See for details +/// # Example +/// ``` +/// use polars_arrow::types::BitChunkOnes; +/// let a = 0b00010000u8; +/// let iter = BitChunkOnes::new(a); +/// let r = iter.collect::>(); +/// assert_eq!(r, vec![4]); +/// ``` +pub struct BitChunkOnes { + value: T, + remaining: usize, +} + +impl BitChunkOnes { + /// Creates a new [`BitChunkOnes`] with `len` bits. + #[inline] + pub fn new(value: T) -> Self { + Self { + value, + remaining: value.count_ones() as usize, + } + } + + #[inline] + pub fn from_known_count(value: T, remaining: usize) -> Self { + Self { value, remaining } + } +} + +impl Iterator for BitChunkOnes { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + let v = self.value.trailing_zeros() as usize; + self.value &= self.value - T::one(); + + self.remaining -= 1; + Some(v) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +// # Safety +// a mathematical invariant of this iterator +unsafe impl crate::trusted_len::TrustedLen for BitChunkOnes {} diff --git a/crates/polars-arrow/src/types/index.rs b/crates/polars-arrow/src/types/index.rs new file mode 100644 index 000000000000..83299a76980f --- /dev/null +++ b/crates/polars-arrow/src/types/index.rs @@ -0,0 +1,103 @@ +use super::NativeType; +use crate::trusted_len::TrustedLen; + +/// Sealed trait describing the subset of [`NativeType`] (`i32`, `i64`, `u32` and `u64`) +/// that can be used to index a slot of an array. +pub trait Index: + NativeType + + std::ops::AddAssign + + std::ops::Sub + + num_traits::One + + num_traits::Num + + num_traits::CheckedAdd + + PartialOrd + + Ord +{ + /// Convert itself to [`usize`]. + fn to_usize(&self) -> usize; + /// Convert itself from [`usize`]. + fn from_usize(index: usize) -> Option; + + /// Convert itself from [`usize`]. + fn from_as_usize(index: usize) -> Self; + + /// An iterator from (inclusive) `start` to (exclusive) `end`. + fn range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => Some(IndexRange::new(start, end)), + _ => None, + } + } +} + +macro_rules! index { + ($t:ty) => { + impl Index for $t { + #[inline] + fn to_usize(&self) -> usize { + *self as usize + } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } + + #[inline] + fn from_as_usize(value: usize) -> Self { + value as $t + } + } + }; +} + +index!(i8); +index!(i16); +index!(i32); +index!(i64); +index!(u8); +index!(u16); +index!(u32); +index!(u64); + +/// Range of [`Index`], equivalent to `(a..b)`. +/// `Step` is unstable in Rust, which does not allow us to implement (a..b) for [`Index`]. +pub struct IndexRange { + start: I, + end: I, +} + +impl IndexRange { + /// Returns a new [`IndexRange`]. + pub fn new(start: I, end: I) -> Self { + assert!(end >= start); + Self { start, end } + } +} + +impl Iterator for IndexRange { + type Item = I; + + #[inline] + fn next(&mut self) -> Option { + if self.start == self.end { + return None; + } + let old = self.start; + self.start += I::one(); + Some(old) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = (self.end - self.start).to_usize(); + (len, Some(len)) + } +} + +/// # Safety +/// +/// A range is always of known length. +unsafe impl TrustedLen for IndexRange {} diff --git a/crates/polars-arrow/src/types/mod.rs b/crates/polars-arrow/src/types/mod.rs new file mode 100644 index 000000000000..7a22102edc16 --- /dev/null +++ b/crates/polars-arrow/src/types/mod.rs @@ -0,0 +1,93 @@ +//! Sealed traits and implementations to handle all _physical types_ used in this crate. +//! +//! Most physical types used in this crate are native Rust types, such as `i32`. +//! The trait [`NativeType`] describes the interfaces required by this crate to be conformant +//! with Arrow. +//! +//! Every implementation of [`NativeType`] has an associated variant in [`PrimitiveType`], +//! available via [`NativeType::PRIMITIVE`]. +//! Combined, these allow structs generic over [`NativeType`] to be trait objects downcastable +//! to concrete implementations based on the matched [`NativeType::PRIMITIVE`] variant. +//! +//! Another important trait in this module is [`Offset`], the subset of [`NativeType`] that can +//! be used in Arrow offsets (`i32` and `i64`). +//! +//! Another important trait in this module is [`BitChunk`], describing types that can be used to +//! represent chunks of bits (e.g. 8 bits via `u8`, 16 via `u16`), and [`BitChunkIter`], +//! that can be used to iterate over bitmaps in [`BitChunk`]s according to +//! Arrow's definition of bitmaps. + +mod aligned_bytes; +pub use aligned_bytes::*; +mod bit_chunk; +pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes}; +mod index; +pub use index::*; +mod native; +pub use native::*; +mod offset; +pub use offset::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The set of all implementations of the sealed trait [`NativeType`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum PrimitiveType { + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// A signed 128-bit integer. + Int128, + /// A signed 256-bit integer. + Int256, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, + /// An unsigned 128-bit integer. + UInt128, + /// A 16-bit floating point number. + Float16, + /// A 32-bit floating point number. + Float32, + /// A 64-bit floating point number. + Float64, + /// Two i32 representing days and ms + DaysMs, + /// months_days_ns(i32, i32, i64) + MonthDayNano, +} + +mod private { + use crate::array::View; + + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for i128 {} + impl Sealed for u128 {} + impl Sealed for super::i256 {} + impl Sealed for super::f16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for super::days_ms {} + impl Sealed for super::months_days_ns {} + impl Sealed for View {} +} diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs new file mode 100644 index 000000000000..ccbba3234663 --- /dev/null +++ b/crates/polars-arrow/src/types/native.rs @@ -0,0 +1,860 @@ +use std::hash::{Hash, Hasher}; +use std::ops::Neg; +use std::panic::RefUnwindSafe; + +use bytemuck::{Pod, Zeroable}; +use polars_utils::min_max::MinMax; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrd, TotalOrdWrap}; + +use super::PrimitiveType; +use super::aligned_bytes::*; + +/// Sealed trait implemented by all physical types that can be allocated, +/// serialized and deserialized by this crate. +/// All O(N) allocations in this crate are done for this trait alone. +pub trait NativeType: + super::private::Sealed + + Pod + + Send + + Sync + + Sized + + RefUnwindSafe + + std::fmt::Debug + + std::fmt::Display + + PartialEq + + Default + + Copy + + TotalOrd + + IsNull + + MinMax + + AlignedBytesCast +{ + /// The corresponding variant of [`PrimitiveType`]. + const PRIMITIVE: PrimitiveType; + + /// Type denoting its representation as bytes. + /// This is `[u8; N]` where `N = size_of::`. + type Bytes: AsRef<[u8]> + + AsMut<[u8]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> + + std::fmt::Debug + + Default + + IntoIterator; + + /// Type denoting its representation as aligned bytes. + /// + /// This is `[u8; N]` where `N = size_of::` and has alignment `align_of::`. + type AlignedBytes: AlignedBytes + From + Into; + + /// To bytes in little endian + fn to_le_bytes(&self) -> Self::Bytes; + + /// To bytes in big endian + fn to_be_bytes(&self) -> Self::Bytes; + + /// From bytes in little endian + fn from_le_bytes(bytes: Self::Bytes) -> Self; + + /// From bytes in big endian + fn from_be_bytes(bytes: Self::Bytes) -> Self; +} + +macro_rules! native_type { + ($type:ty, $aligned:ty, $primitive_type:expr) => { + impl NativeType for $type { + const PRIMITIVE: PrimitiveType = $primitive_type; + + type Bytes = [u8; std::mem::size_of::()]; + type AlignedBytes = $aligned; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + Self::to_be_bytes(*self) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self::from_le_bytes(bytes) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self::from_be_bytes(bytes) + } + } + }; +} + +native_type!(u8, Bytes1Alignment1, PrimitiveType::UInt8); +native_type!(u16, Bytes2Alignment2, PrimitiveType::UInt16); +native_type!(u32, Bytes4Alignment4, PrimitiveType::UInt32); +native_type!(u64, Bytes8Alignment8, PrimitiveType::UInt64); +native_type!(i8, Bytes1Alignment1, PrimitiveType::Int8); +native_type!(i16, Bytes2Alignment2, PrimitiveType::Int16); +native_type!(i32, Bytes4Alignment4, PrimitiveType::Int32); +native_type!(i64, Bytes8Alignment8, PrimitiveType::Int64); +native_type!(f32, Bytes4Alignment4, PrimitiveType::Float32); +native_type!(f64, Bytes8Alignment8, PrimitiveType::Float64); +native_type!(i128, Bytes16Alignment16, PrimitiveType::Int128); +native_type!(u128, Bytes16Alignment16, PrimitiveType::UInt128); + +/// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct days_ms(pub i32, pub i32); + +impl days_ms { + /// A new [`days_ms`]. + #[inline] + pub fn new(days: i32, milliseconds: i32) -> Self { + Self(days, milliseconds) + } + + /// The number of days + #[inline] + pub fn days(&self) -> i32 { + self.0 + } + + /// The number of milliseconds + #[inline] + pub fn milliseconds(&self) -> i32 { + self.1 + } +} + +impl TotalEq for days_ms { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +impl TotalOrd for days_ms { + #[inline] + fn tot_cmp(&self, other: &Self) -> std::cmp::Ordering { + self.days() + .cmp(&other.days()) + .then(self.milliseconds().cmp(&other.milliseconds())) + } +} + +impl MinMax for days_ms { + fn nan_min_lt(&self, other: &Self) -> bool { + self < other + } + + fn nan_max_lt(&self, other: &Self) -> bool { + self < other + } +} + +impl NativeType for days_ms { + const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs; + + type Bytes = [u8; 8]; + type AlignedBytes = Bytes8Alignment4; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let days = self.0.to_le_bytes(); + let ms = self.1.to_le_bytes(); + let mut result = [0; 8]; + result[0] = days[0]; + result[1] = days[1]; + result[2] = days[2]; + result[3] = days[3]; + result[4] = ms[0]; + result[5] = ms[1]; + result[6] = ms[2]; + result[7] = ms[3]; + result + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let days = self.0.to_be_bytes(); + let ms = self.1.to_be_bytes(); + let mut result = [0; 8]; + result[0] = days[0]; + result[1] = days[1]; + result[2] = days[2]; + result[3] = days[3]; + result[4] = ms[0]; + result[5] = ms[1]; + result[6] = ms[2]; + result[7] = ms[3]; + result + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_le_bytes(days), i32::from_le_bytes(ms)) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_be_bytes(days), i32::from_be_bytes(ms)) + } +} + +/// The in-memory representation of the MonthDayNano variant of the "Interval" logical type. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct months_days_ns(pub i32, pub i32, pub i64); + +impl IsNull for months_days_ns { + const HAS_NULLS: bool = false; + type Inner = months_days_ns; + + fn is_null(&self) -> bool { + false + } + + fn unwrap_inner(self) -> Self::Inner { + self + } +} + +impl months_days_ns { + /// A new [`months_days_ns`]. + #[inline] + pub fn new(months: i32, days: i32, nanoseconds: i64) -> Self { + Self(months, days, nanoseconds) + } + + /// The number of months + #[inline] + pub fn months(&self) -> i32 { + self.0 + } + + /// The number of days + #[inline] + pub fn days(&self) -> i32 { + self.1 + } + + /// The number of nanoseconds + #[inline] + pub fn ns(&self) -> i64 { + self.2 + } +} + +impl TotalEq for months_days_ns { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +impl TotalOrd for months_days_ns { + #[inline] + fn tot_cmp(&self, other: &Self) -> std::cmp::Ordering { + self.months() + .cmp(&other.months()) + .then(self.days().cmp(&other.days())) + .then(self.ns().cmp(&other.ns())) + } +} + +impl MinMax for months_days_ns { + fn nan_min_lt(&self, other: &Self) -> bool { + self < other + } + + fn nan_max_lt(&self, other: &Self) -> bool { + self < other + } +} + +impl NativeType for months_days_ns { + const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano; + + type Bytes = [u8; 16]; + type AlignedBytes = Bytes16Alignment8; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let months = self.months().to_le_bytes(); + let days = self.days().to_le_bytes(); + let ns = self.ns().to_le_bytes(); + let mut result = [0; 16]; + result[0] = months[0]; + result[1] = months[1]; + result[2] = months[2]; + result[3] = months[3]; + result[4] = days[0]; + result[5] = days[1]; + result[6] = days[2]; + result[7] = days[3]; + (0..8).for_each(|i| { + result[8 + i] = ns[i]; + }); + result + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let months = self.months().to_be_bytes(); + let days = self.days().to_be_bytes(); + let ns = self.ns().to_be_bytes(); + let mut result = [0; 16]; + result[0] = months[0]; + result[1] = months[1]; + result[2] = months[2]; + result[3] = months[3]; + result[4] = days[0]; + result[5] = days[1]; + result[6] = days[2]; + result[7] = days[3]; + (0..8).for_each(|i| { + result[8 + i] = ns[i]; + }); + result + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_le_bytes(months), + i32::from_le_bytes(days), + i64::from_le_bytes(ns), + ) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_be_bytes(months), + i32::from_be_bytes(days), + i64::from_be_bytes(ns), + ) + } +} + +impl IsNull for days_ms { + const HAS_NULLS: bool = false; + type Inner = days_ms; + fn is_null(&self) -> bool { + false + } + fn unwrap_inner(self) -> Self::Inner { + self + } +} + +impl std::fmt::Display for days_ms { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}d {}ms", self.days(), self.milliseconds()) + } +} + +impl std::fmt::Display for months_days_ns { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}m {}d {}ns", self.months(), self.days(), self.ns()) + } +} + +impl Neg for days_ms { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self::new(-self.days(), -self.milliseconds()) + } +} + +impl Neg for months_days_ns { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self::new(-self.months(), -self.days(), -self.ns()) + } +} + +/// Type representation of the Float16 physical type +#[derive(Copy, Clone, Default, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct f16(pub u16); + +impl PartialEq for f16 { + #[inline] + fn eq(&self, other: &f16) -> bool { + if self.is_nan() || other.is_nan() { + false + } else { + (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0) + } + } +} + +/// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +#[inline] +pub fn canonical_f16(x: f16) -> f16 { + // zero out the sign bit if the f16 is zero. + let convert_zero = f16(x.0 & (0x7FFF | (u16::from(x.0 & 0x7FFF == 0) << 15))); + if convert_zero.is_nan() { + f16::from_bits(0x7c00) // Canonical quiet NaN. + } else { + convert_zero + } +} + +impl TotalHash for f16 { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f16(*self).to_bits().hash(state) + } +} + +impl ToTotalOrd for f16 { + type TotalOrdItem = TotalOrdWrap; + type SourceItem = f16; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + +impl IsNull for f16 { + const HAS_NULLS: bool = false; + type Inner = f16; + + #[inline(always)] + fn is_null(&self) -> bool { + false + } + fn unwrap_inner(self) -> Self::Inner { + self + } +} + +// see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs +impl f16 { + /// The difference between 1.0 and the next largest representable number. + pub const EPSILON: f16 = f16(0x1400u16); + + #[inline] + #[must_use] + pub(crate) const fn is_nan(self) -> bool { + self.0 & 0x7FFFu16 > 0x7C00u16 + } + + /// Casts from u16. + #[inline] + pub const fn from_bits(bits: u16) -> f16 { + f16(bits) + } + + /// Casts to u16. + #[inline] + pub const fn to_bits(self) -> u16 { + self.0 + } + + /// Casts this `f16` to `f32` + pub fn to_f32(self) -> f32 { + let i = self.0; + // Check for signed zero + if i & 0x7FFFu16 == 0 { + return f32::from_bits((i as u32) << 16); + } + + let half_sign = (i & 0x8000u16) as u32; + let half_exp = (i & 0x7C00u16) as u32; + let half_man = (i & 0x03FFu16) as u32; + + // Check for an infinity or NaN when all exponent bits set + if half_exp == 0x7C00u32 { + // Check for signed infinity if mantissa is zero + if half_man == 0 { + let number = (half_sign << 16) | 0x7F80_0000u32; + return f32::from_bits(number); + } else { + // NaN, keep current mantissa but also set most significiant mantissa bit + let number = (half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13); + return f32::from_bits(number); + } + } + + // Calculate single-precision components with adjusted exponent + let sign = half_sign << 16; + // Unbias exponent + let unbiased_exp = ((half_exp as i32) >> 10) - 15; + + // Check for subnormals, which will be normalized by adjusting exponent + if half_exp == 0 { + // Calculate how much to adjust the exponent by + let e = (half_man as u16).leading_zeros() - 6; + + // Rebias and adjust exponent + let exp = (127 - 15 - e) << 23; + let man = (half_man << (14 + e)) & 0x7F_FF_FFu32; + return f32::from_bits(sign | exp | man); + } + + // Rebias exponent for a normalized normal + let exp = ((unbiased_exp + 127) as u32) << 23; + let man = (half_man & 0x03FFu32) << 13; + f32::from_bits(sign | exp | man) + } + + /// Casts an `f32` into `f16` + pub fn from_f32(value: f32) -> Self { + let x: u32 = value.to_bits(); + + // Extract IEEE754 components + let sign = x & 0x8000_0000u32; + let exp = x & 0x7F80_0000u32; + let man = x & 0x007F_FFFFu32; + + // Check for all exponent bits being set, which is Infinity or NaN + if exp == 0x7F80_0000u32 { + // Set mantissa MSB for NaN (and also keep shifted mantissa bits) + let nan_bit = if man == 0 { 0 } else { 0x0200u32 }; + return f16(((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16); + } + + // The number is normalized, start assembling half precision version + let half_sign = sign >> 16; + // Unbias the exponent, then bias for half precision + let unbiased_exp = ((exp >> 23) as i32) - 127; + let half_exp = unbiased_exp + 15; + + // Check for exponent overflow, return +infinity + if half_exp >= 0x1F { + return f16((half_sign | 0x7C00u32) as u16); + } + + // Check for underflow + if half_exp <= 0 { + // Check mantissa for what we can do + if 14 - half_exp > 24 { + // No rounding possibility, so this is a full underflow, return signed zero + return f16(half_sign as u16); + } + // Don't forget about hidden leading mantissa bit when assembling mantissa + let man = man | 0x0080_0000u32; + let mut half_man = man >> (14 - half_exp); + // Check for rounding (see comment above functions) + let round_bit = 1 << (13 - half_exp); + if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 { + half_man += 1; + } + // No exponent for subnormals + return f16((half_sign | half_man) as u16); + } + + // Rebias the exponent + let half_exp = (half_exp as u32) << 10; + let half_man = man >> 13; + // Check for rounding (see comment above functions) + let round_bit = 0x0000_1000u32; + if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 { + // Round it + f16(((half_sign | half_exp | half_man) + 1) as u16) + } else { + f16((half_sign | half_exp | half_man) as u16) + } + } +} + +impl std::fmt::Debug for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.to_f32()) + } +} + +impl std::fmt::Display for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl TotalEq for f16 { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + if self.is_nan() { + other.is_nan() + } else { + self == other + } + } +} + +impl TotalOrd for f16 { + #[inline] + fn tot_cmp(&self, _other: &Self) -> std::cmp::Ordering { + unimplemented!() + } +} + +impl MinMax for f16 { + fn nan_min_lt(&self, _other: &Self) -> bool { + unimplemented!() + } + + fn nan_max_lt(&self, _other: &Self) -> bool { + unimplemented!() + } +} + +impl NativeType for f16 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Float16; + + type Bytes = [u8; 2]; + type AlignedBytes = Bytes2Alignment2; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + self.0.to_le_bytes() + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + self.0.to_be_bytes() + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_be_bytes(bytes)) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_le_bytes(bytes)) + } +} + +/// Physical representation of a decimal +#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct i256(pub ethnum::I256); + +impl i256 { + /// Returns a new [`i256`] from two `i128`. + pub fn from_words(hi: i128, lo: i128) -> Self { + Self(ethnum::I256::from_words(hi, lo)) + } +} + +impl IsNull for i256 { + const HAS_NULLS: bool = false; + type Inner = i256; + #[inline(always)] + fn is_null(&self) -> bool { + false + } + fn unwrap_inner(self) -> Self::Inner { + self + } +} + +impl Neg for i256 { + type Output = Self; + + #[inline] + fn neg(self) -> Self::Output { + let (a, b) = self.0.into_words(); + Self(ethnum::I256::from_words(-a, b)) + } +} + +impl std::fmt::Debug for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl std::fmt::Display for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +unsafe impl Pod for i256 {} +unsafe impl Zeroable for i256 {} + +impl TotalEq for i256 { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +impl TotalOrd for i256 { + #[inline] + fn tot_cmp(&self, other: &Self) -> std::cmp::Ordering { + self.cmp(other) + } +} + +impl MinMax for i256 { + fn nan_min_lt(&self, other: &Self) -> bool { + self < other + } + + fn nan_max_lt(&self, other: &Self) -> bool { + self < other + } +} + +impl NativeType for i256 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Int256; + + type Bytes = [u8; 32]; + type AlignedBytes = Bytes32Alignment16; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let mut bytes = [0u8; 32]; + let (a, b) = self.0.into_words(); + let a = a.to_le_bytes(); + (0..16).for_each(|i| { + bytes[i] = a[i]; + }); + + let b = b.to_le_bytes(); + (0..16).for_each(|i| { + bytes[i + 16] = b[i]; + }); + + bytes + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let mut bytes = [0u8; 32]; + let (a, b) = self.0.into_words(); + + let a = a.to_be_bytes(); + (0..16).for_each(|i| { + bytes[i] = a[i]; + }); + + let b = b.to_be_bytes(); + (0..16).for_each(|i| { + bytes[i + 16] = b[i]; + }); + + bytes + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let (a, b) = bytes.split_at(16); + let a: [u8; 16] = a.try_into().unwrap(); + let b: [u8; 16] = b.try_into().unwrap(); + let a = i128::from_be_bytes(a); + let b = i128::from_be_bytes(b); + Self(ethnum::I256::from_words(a, b)) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let (b, a) = bytes.split_at(16); + let a: [u8; 16] = a.try_into().unwrap(); + let b: [u8; 16] = b.try_into().unwrap(); + let a = i128::from_le_bytes(a); + let b = i128::from_le_bytes(b); + Self(ethnum::I256::from_words(a, b)) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_f16_to_f32() { + let f = f16::from_f32(7.0); + assert_eq!(f.to_f32(), 7.0f32); + + // 7.1 is NOT exactly representable in 16-bit, it's rounded + let f = f16::from_f32(7.1); + let diff = (f.to_f32() - 7.1f32).abs(); + // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1 + assert!(diff <= 4.0 * f16::EPSILON.to_f32()); + + assert_eq!(f16(0x0000_0001).to_f32(), 2.0f32.powi(-24)); + assert_eq!(f16(0x0000_0005).to_f32(), 5.0 * 2.0f32.powi(-24)); + + assert_eq!(f16(0x0000_0001), f16::from_f32(2.0f32.powi(-24))); + assert_eq!(f16(0x0000_0005), f16::from_f32(5.0 * 2.0f32.powi(-24))); + + assert_eq!(format!("{}", f16::from_f32(7.0)), "7".to_string()); + assert_eq!(format!("{:?}", f16::from_f32(7.0)), "7.0".to_string()); + } +} diff --git a/crates/polars-arrow/src/types/offset.rs b/crates/polars-arrow/src/types/offset.rs new file mode 100644 index 000000000000..e68bb7ceb6bd --- /dev/null +++ b/crates/polars-arrow/src/types/offset.rs @@ -0,0 +1,16 @@ +use super::Index; + +/// Sealed trait describing the subset (`i32` and `i64`) of [`Index`] that can be used +/// as offsets of variable-length Arrow arrays. +pub trait Offset: super::private::Sealed + Index { + /// Whether it is `i32` (false) or `i64` (true). + const IS_LARGE: bool; +} + +impl Offset for i32 { + const IS_LARGE: bool = false; +} + +impl Offset for i64 { + const IS_LARGE: bool = true; +} diff --git a/crates/polars-arrow/src/util/macros.rs b/crates/polars-arrow/src/util/macros.rs new file mode 100644 index 000000000000..2153d2cb3a07 --- /dev/null +++ b/crates/polars-arrow/src/util/macros.rs @@ -0,0 +1,66 @@ +#[macro_export] +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Int128 => __with_ty__! { i128 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => panic!("operator does not support primitive `{:?}`", + $key_type) + } +})} + +#[macro_export] +macro_rules! with_match_primitive_type_full {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::PrimitiveType::*; + use $crate::types::{f16}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Int128 => __with_ty__! { i128 }, + Float16 => __with_ty__! { f16 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => panic!("operator does not support primitive `{:?}`", + $key_type) + } +})} + +#[macro_export] +macro_rules! match_integer_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::IntegerType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + } +})} diff --git a/crates/polars-arrow/src/util/mod.rs b/crates/polars-arrow/src/util/mod.rs new file mode 100644 index 000000000000..3940dd45a2fe --- /dev/null +++ b/crates/polars-arrow/src/util/mod.rs @@ -0,0 +1,2 @@ +//! Misc utilities used in different places in the crate. +pub mod macros; diff --git a/crates/polars-compute/Cargo.toml b/crates/polars-compute/Cargo.toml new file mode 100644 index 000000000000..1c622d8f8e49 --- /dev/null +++ b/crates/polars-compute/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "polars-compute" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Private compute kernels for the Polars DataFrame library" + +[dependencies] +arrow = { workspace = true } +atoi_simd = { workspace = true, optional = true } +bytemuck = { workspace = true } +chrono = { workspace = true, optional = true } +either = { workspace = true } +fast-float2 = { workspace = true, optional = true } +hashbrown = { workspace = true } +itoa = { workspace = true, optional = true } +num-traits = { workspace = true } +polars-error = { workspace = true } +polars-utils = { workspace = true } +rand = { workspace = true } +ryu = { workspace = true, optional = true } +serde = { workspace = true, optional = true } +skiplist = { workspace = true } +strength_reduce = { workspace = true } +strum_macros = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +cast = [ + "gather", + "arrow/chrono-tz", + "dep:atoi_simd", + "dep:chrono", + "dep:fast-float2", + "dep:itoa", + "dep:ryu", +] +gather = [] +nightly = [] +simd = ["arrow/simd"] +approx_unique = [] +dtype-array = [] +dtype-decimal = ["arrow/dtype-decimal", "dtype-i128"] +dtype-i128 = [] diff --git a/crates/polars-compute/LICENSE b/crates/polars-compute/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-compute/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-compute/README.md b/crates/polars-compute/README.md new file mode 100644 index 000000000000..8b5cc0e0be09 --- /dev/null +++ b/crates/polars-compute/README.md @@ -0,0 +1,7 @@ +# polars-compute + +`polars-compute` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) +library, supplying private compute kernels. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-compute/build.rs b/crates/polars-compute/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-compute/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-compute/src/arithmetic/float.rs b/crates/polars-compute/src/arithmetic/float.rs new file mode 100644 index 000000000000..751e9af32544 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/float.rs @@ -0,0 +1,119 @@ +use arrow::array::PrimitiveArray as PArr; + +use super::PrimitiveArithmeticKernelImpl; +use crate::arity::{prim_binary_values, prim_unary_values}; + +macro_rules! impl_float_arith_kernel { + ($T:ty) => { + impl PrimitiveArithmeticKernelImpl for $T { + type TrueDivT = $T; + + fn prim_wrapping_abs(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| x.abs()) + } + + fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| -x) + } + + fn prim_wrapping_add(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l + r) + } + + fn prim_wrapping_sub(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l - r) + } + + fn prim_wrapping_mul(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l * r) + } + + fn prim_wrapping_floor_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| (l / r).floor()) + } + + fn prim_wrapping_trunc_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| (l / r).trunc()) + } + + fn prim_wrapping_mod(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, rhs, |l, r| l - r * (l / r).floor()) + } + + fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0.0 { + return lhs; + } + prim_unary_values(lhs, |x| x + rhs) + } + + fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0.0 { + return lhs; + } + Self::prim_wrapping_add_scalar(lhs, -rhs) + } + + fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + if lhs == 0.0 { + Self::prim_wrapping_neg(rhs) + } else { + prim_unary_values(rhs, |x| lhs - x) + } + } + + fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + // No optimization for multiplication by zero, would invalidate NaNs/infinities. + if rhs == 1.0 { + lhs + } else if rhs == -1.0 { + Self::prim_wrapping_neg(lhs) + } else { + prim_unary_values(lhs, |x| x * rhs) + } + } + + fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let inv = 1.0 / rhs; + prim_unary_values(lhs, |x| (x * inv).floor()) + } + + fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| (lhs / x).floor()) + } + + fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let inv = 1.0 / rhs; + prim_unary_values(lhs, |x| (x * inv).trunc()) + } + + fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| (lhs / x).trunc()) + } + + fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let inv = 1.0 / rhs; + prim_unary_values(lhs, |x| x - rhs * (x * inv).floor()) + } + + fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| lhs - x * (lhs / x).floor()) + } + + fn prim_true_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr { + prim_binary_values(lhs, rhs, |l, r| l / r) + } + + fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr { + Self::prim_wrapping_mul_scalar(lhs, 1.0 / rhs) + } + + fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr { + prim_unary_values(rhs, |x| lhs / x) + } + } + }; +} + +impl_float_arith_kernel!(f32); +impl_float_arith_kernel!(f64); diff --git a/crates/polars-compute/src/arithmetic/mod.rs b/crates/polars-compute/src/arithmetic/mod.rs new file mode 100644 index 000000000000..cb74881ed4a5 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/mod.rs @@ -0,0 +1,146 @@ +use std::any::TypeId; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType; + +// Low-level comparison kernel. +pub trait ArithmeticKernel: Sized + Array { + type Scalar; + type TrueDivT: NativeType; + + fn wrapping_abs(self) -> Self; + fn wrapping_neg(self) -> Self; + fn wrapping_add(self, rhs: Self) -> Self; + fn wrapping_sub(self, rhs: Self) -> Self; + fn wrapping_mul(self, rhs: Self) -> Self; + fn wrapping_floor_div(self, rhs: Self) -> Self; + fn wrapping_trunc_div(self, rhs: Self) -> Self; + fn wrapping_mod(self, rhs: Self) -> Self; + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self; + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self; + + fn true_div(self, rhs: Self) -> PrimitiveArray; + fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray; + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray; + + // TODO: remove these. + // These are flooring division for integer types, true division for floating point types. + fn legacy_div(self, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::>() { + let ret = self.true_div(rhs); + unsafe { + let cast_ret = std::mem::transmute_copy(&ret); + std::mem::forget(ret); + cast_ret + } + } else { + self.wrapping_floor_div(rhs) + } + } + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self { + if TypeId::of::() == TypeId::of::>() { + let ret = self.true_div_scalar(rhs); + unsafe { + let cast_ret = std::mem::transmute_copy(&ret); + std::mem::forget(ret); + cast_ret + } + } else { + self.wrapping_floor_div_scalar(rhs) + } + } + + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::>() { + let ret = ArithmeticKernel::true_div_scalar_lhs(lhs, rhs); + unsafe { + let cast_ret = std::mem::transmute_copy(&ret); + std::mem::forget(ret); + cast_ret + } + } else { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, rhs) + } + } +} + +// Proxy trait so one can bound T: HasPrimitiveArithmeticKernel. Sadly Rust +// doesn't support adding supertraits for other types. +#[allow(private_bounds)] +pub trait HasPrimitiveArithmeticKernel: NativeType + PrimitiveArithmeticKernelImpl {} +impl HasPrimitiveArithmeticKernel for T {} + +use PrimitiveArray as PArr; + +#[doc(hidden)] +pub trait PrimitiveArithmeticKernelImpl: NativeType { + type TrueDivT: NativeType; + + fn prim_wrapping_abs(lhs: PArr) -> PArr; + fn prim_wrapping_neg(lhs: PArr) -> PArr; + fn prim_wrapping_add(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_sub(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_mul(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_floor_div(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_trunc_div(lhs: PArr, rhs: PArr) -> PArr; + fn prim_wrapping_mod(lhs: PArr, rhs: PArr) -> PArr; + + fn prim_wrapping_add_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_sub_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_sub_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + fn prim_wrapping_mul_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_floor_div_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_floor_div_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + fn prim_wrapping_trunc_div_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_trunc_div_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + fn prim_wrapping_mod_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_wrapping_mod_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; + + fn prim_true_div(lhs: PArr, rhs: PArr) -> PArr; + fn prim_true_div_scalar(lhs: PArr, rhs: Self) -> PArr; + fn prim_true_div_scalar_lhs(lhs: Self, rhs: PArr) -> PArr; +} + +#[rustfmt::skip] +impl ArithmeticKernel for PrimitiveArray { + type Scalar = T; + type TrueDivT = T::TrueDivT; + + fn wrapping_abs(self) -> Self { T::prim_wrapping_abs(self) } + fn wrapping_neg(self) -> Self { T::prim_wrapping_neg(self) } + fn wrapping_add(self, rhs: Self) -> Self { T::prim_wrapping_add(self, rhs) } + fn wrapping_sub(self, rhs: Self) -> Self { T::prim_wrapping_sub(self, rhs) } + fn wrapping_mul(self, rhs: Self) -> Self { T::prim_wrapping_mul(self, rhs) } + fn wrapping_floor_div(self, rhs: Self) -> Self { T::prim_wrapping_floor_div(self, rhs) } + fn wrapping_trunc_div(self, rhs: Self) -> Self { T::prim_wrapping_trunc_div(self, rhs) } + fn wrapping_mod(self, rhs: Self) -> Self { T::prim_wrapping_mod(self, rhs) } + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_add_scalar(self, rhs) } + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_sub_scalar(self, rhs) } + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_sub_scalar_lhs(lhs, rhs) } + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mul_scalar(self, rhs) } + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_floor_div_scalar(self, rhs) } + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) } + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_trunc_div_scalar(self, rhs) } + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_trunc_div_scalar_lhs(lhs, rhs) } + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mod_scalar(self, rhs) } + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_mod_scalar_lhs(lhs, rhs) } + + fn true_div(self, rhs: Self) -> PrimitiveArray { T::prim_true_div(self, rhs) } + fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray { T::prim_true_div_scalar(self, rhs) } + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray { T::prim_true_div_scalar_lhs(lhs, rhs) } +} + +mod float; +pub mod pl_num; +mod signed; +mod unsigned; diff --git a/crates/polars-compute/src/arithmetic/pl_num.rs b/crates/polars-compute/src/arithmetic/pl_num.rs new file mode 100644 index 000000000000..81855cc8f181 --- /dev/null +++ b/crates/polars-compute/src/arithmetic/pl_num.rs @@ -0,0 +1,217 @@ +use core::any::TypeId; + +use arrow::types::NativeType; +use polars_utils::floor_divmod::FloorDivMod; + +/// Implements basic arithmetic between scalars with the same behavior as `ArithmeticKernel`. +/// +/// Note, however, that the user is responsible for setting the validity of +/// results for e.g. div/mod operations with 0 in the denominator. +/// +/// This is intended as a low-level utility for custom arithmetic loops +/// (e.g. in list arithmetic). In most cases prefer using `ArithmeticKernel` or +/// `ArithmeticChunked` instead. +pub trait PlNumArithmetic: Sized + Copy + 'static { + type TrueDivT: NativeType; + + fn wrapping_abs(self) -> Self; + fn wrapping_neg(self) -> Self; + fn wrapping_add(self, rhs: Self) -> Self; + fn wrapping_sub(self, rhs: Self) -> Self; + fn wrapping_mul(self, rhs: Self) -> Self; + fn wrapping_floor_div(self, rhs: Self) -> Self; + fn wrapping_trunc_div(self, rhs: Self) -> Self; + fn wrapping_mod(self, rhs: Self) -> Self; + + fn true_div(self, rhs: Self) -> Self::TrueDivT; + + #[inline(always)] + fn legacy_div(self, rhs: Self) -> Self { + if TypeId::of::() == TypeId::of::() { + let ret = self.true_div(rhs); + unsafe { core::mem::transmute_copy(&ret) } + } else { + self.wrapping_floor_div(rhs) + } + } +} + +macro_rules! impl_signed_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = f64; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self.wrapping_abs() + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + self.wrapping_floor_div_mod(rhs).0 + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + if rhs != 0 { self.wrapping_div(rhs) } else { 0 } + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + self.wrapping_floor_div_mod(rhs).1 + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self as f64 / rhs as f64 + } + } + }; +} + +impl_signed_pl_num_arith!(i8); +impl_signed_pl_num_arith!(i16); +impl_signed_pl_num_arith!(i32); +impl_signed_pl_num_arith!(i64); +impl_signed_pl_num_arith!(i128); + +macro_rules! impl_unsigned_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = f64; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + if rhs != 0 { self / rhs } else { 0 } + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + self.wrapping_floor_div(rhs) + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + if rhs != 0 { self % rhs } else { 0 } + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self as f64 / rhs as f64 + } + } + }; +} + +impl_unsigned_pl_num_arith!(u8); +impl_unsigned_pl_num_arith!(u16); +impl_unsigned_pl_num_arith!(u32); +impl_unsigned_pl_num_arith!(u64); +impl_unsigned_pl_num_arith!(u128); + +macro_rules! impl_float_pl_num_arith { + ($T:ty) => { + impl PlNumArithmetic for $T { + type TrueDivT = $T; + + #[inline(always)] + fn wrapping_abs(self) -> Self { + self.abs() + } + + #[inline(always)] + fn wrapping_neg(self) -> Self { + -self + } + + #[inline(always)] + fn wrapping_add(self, rhs: Self) -> Self { + self + rhs + } + + #[inline(always)] + fn wrapping_sub(self, rhs: Self) -> Self { + self - rhs + } + + #[inline(always)] + fn wrapping_mul(self, rhs: Self) -> Self { + self * rhs + } + + #[inline(always)] + fn wrapping_floor_div(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + (l / r).floor() + } + + #[inline(always)] + fn wrapping_trunc_div(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + (l / r).trunc() + } + + #[inline(always)] + fn wrapping_mod(self, rhs: Self) -> Self { + let l = self; + let r = rhs; + l - r * (l / r).floor() + } + + #[inline(always)] + fn true_div(self, rhs: Self) -> Self::TrueDivT { + self / rhs + } + } + }; +} + +impl_float_pl_num_arith!(f32); +impl_float_pl_num_arith!(f64); diff --git a/crates/polars-compute/src/arithmetic/signed.rs b/crates/polars-compute/src/arithmetic/signed.rs new file mode 100644 index 000000000000..073a92a8968a --- /dev/null +++ b/crates/polars-compute/src/arithmetic/signed.rs @@ -0,0 +1,230 @@ +use arrow::array::{PrimitiveArray as PArr, StaticArray}; +use arrow::compute::utils::{combine_validities_and, combine_validities_and3}; +use polars_utils::floor_divmod::FloorDivMod; +use strength_reduce::*; + +use super::PrimitiveArithmeticKernelImpl; +use crate::arity::{prim_binary_values, prim_unary_values}; +use crate::comparisons::TotalEqKernel; + +macro_rules! impl_signed_arith_kernel { + ($T:ty, $StrRed:ty) => { + impl PrimitiveArithmeticKernelImpl for $T { + type TrueDivT = f64; + + fn prim_wrapping_abs(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_abs()) + } + + fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_neg()) + } + + fn prim_wrapping_add(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_add(b)) + } + + fn prim_wrapping_sub(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_sub(b)) + } + + fn prim_wrapping_mul(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_mul(b)) + } + + fn prim_wrapping_floor_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = + prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).0); + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = prim_binary_values(lhs, other, |lhs, rhs| { + if rhs != 0 { lhs.wrapping_div(rhs) } else { 0 } + }); + ret.with_validity(valid) + } + + fn prim_wrapping_mod(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = + prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).1); + ret.with_validity(valid) + } + + fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_add(rhs)) + } + + fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + Self::prim_wrapping_add_scalar(lhs, rhs.wrapping_neg()) + } + + fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| lhs.wrapping_sub(x)) + } + + fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + let scalar_u = rhs.unsigned_abs(); + if rhs == 0 { + lhs.fill_with(0) + } else if rhs == 1 { + lhs + } else if scalar_u.is_power_of_two() { + // Power of two. + let shift = scalar_u.trailing_zeros(); + if rhs > 0 { + prim_unary_values(lhs, |x| x << shift) + } else { + prim_unary_values(lhs, |x| (x << shift).wrapping_neg()) + } + } else { + prim_unary_values(lhs, |x| x.wrapping_mul(rhs)) + } + } + + fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.dtype().clone()) + } else if rhs == -1 { + Self::prim_wrapping_neg(lhs) + } else if rhs == 1 { + lhs + } else { + let red = <$StrRed>::new(rhs.unsigned_abs()); + prim_unary_values(lhs, |x| { + let (quot, rem) = <$StrRed>::div_rem(x.unsigned_abs(), red); + if (x < 0) != (rhs < 0) { + // Different signs: result should be negative. + // Since we handled rhs.abs() <= 1, quot fits. + let mut ret = -(quot as $T); + if rem != 0 { + // Division had remainder, subtract 1 to floor to + // negative infinity, as we truncated to zero. + ret -= 1; + } + ret + } else { + quot as $T + } + }) + } + } + + fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0) + }; + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.dtype().clone()) + } else if rhs == -1 { + Self::prim_wrapping_neg(lhs) + } else if rhs == 1 { + lhs + } else { + let red = <$StrRed>::new(rhs.unsigned_abs()); + prim_unary_values(lhs, |x| { + let quot = x.unsigned_abs() / red; + if (x < 0) != (rhs < 0) { + // Different signs: result should be negative. + -(quot as $T) + } else { + quot as $T + } + }) + } + } + + fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| if x != 0 { lhs.wrapping_div(x) } else { 0 }) + }; + ret.with_validity(valid) + } + + fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.dtype().clone()) + } else if rhs == -1 || rhs == 1 { + lhs.fill_with(0) + } else { + let scalar_u = rhs.unsigned_abs(); + let red = <$StrRed>::new(scalar_u); + prim_unary_values(lhs, |x| { + // Remainder fits in signed type after reduction. + // Largest possible modulo -I::MIN, with + // -I::MIN-1 == I::MAX as largest remainder. + let mut rem_u = x.unsigned_abs() % red; + + // Mixed signs: swap direction of remainder. + if rem_u != 0 && (rhs < 0) != (x < 0) { + rem_u = scalar_u - rem_u; + } + + // Remainder should have sign of RHS. + if rhs < 0 { -(rem_u as $T) } else { rem_u as $T } + }) + } + } + + fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1) + }; + ret.with_validity(valid) + } + + fn prim_true_div(lhs: PArr<$T>, other: PArr<$T>) -> PArr { + prim_binary_values(lhs, other, |a, b| a as f64 / b as f64) + } + + fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr { + let inv = 1.0 / rhs as f64; + prim_unary_values(lhs, |x| x as f64 * inv) + } + + fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr { + prim_unary_values(rhs, |x| lhs as f64 / x as f64) + } + } + }; +} + +impl_signed_arith_kernel!(i8, StrengthReducedU8); +impl_signed_arith_kernel!(i16, StrengthReducedU16); +impl_signed_arith_kernel!(i32, StrengthReducedU32); +impl_signed_arith_kernel!(i64, StrengthReducedU64); +impl_signed_arith_kernel!(i128, StrengthReducedU128); diff --git a/crates/polars-compute/src/arithmetic/unsigned.rs b/crates/polars-compute/src/arithmetic/unsigned.rs new file mode 100644 index 000000000000..46fc0037597d --- /dev/null +++ b/crates/polars-compute/src/arithmetic/unsigned.rs @@ -0,0 +1,158 @@ +use arrow::array::{PrimitiveArray as PArr, StaticArray}; +use arrow::compute::utils::{combine_validities_and, combine_validities_and3}; +use strength_reduce::*; + +use super::PrimitiveArithmeticKernelImpl; +use crate::arity::{prim_binary_values, prim_unary_values}; +use crate::comparisons::TotalEqKernel; + +macro_rules! impl_unsigned_arith_kernel { + ($T:ty, $StrRed:ty) => { + impl PrimitiveArithmeticKernelImpl for $T { + type TrueDivT = f64; + + fn prim_wrapping_abs(lhs: PArr<$T>) -> PArr<$T> { + lhs + } + + fn prim_wrapping_neg(lhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_neg()) + } + + fn prim_wrapping_add(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_add(b)) + } + + fn prim_wrapping_sub(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_sub(b)) + } + + fn prim_wrapping_mul(lhs: PArr<$T>, other: PArr<$T>) -> PArr<$T> { + prim_binary_values(lhs, other, |a, b| a.wrapping_mul(b)) + } + + fn prim_wrapping_floor_div(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = prim_binary_values(lhs, other, |a, b| if b != 0 { a / b } else { 0 }); + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div(lhs: PArr<$T>, rhs: PArr<$T>) -> PArr<$T> { + Self::prim_wrapping_floor_div(lhs, rhs) + } + + fn prim_wrapping_mod(mut lhs: PArr<$T>, mut other: PArr<$T>) -> PArr<$T> { + let mask = other.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and3( + lhs.take_validity().as_ref(), // Take validity so we don't + other.take_validity().as_ref(), // compute combination twice. + Some(&mask), + ); + let ret = prim_binary_values(lhs, other, |a, b| if b != 0 { a % b } else { 0 }); + ret.with_validity(valid) + } + + fn prim_wrapping_add_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + prim_unary_values(lhs, |x| x.wrapping_add(rhs)) + } + + fn prim_wrapping_sub_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + Self::prim_wrapping_add_scalar(lhs, rhs.wrapping_neg()) + } + + fn prim_wrapping_sub_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + prim_unary_values(rhs, |x| lhs.wrapping_sub(x)) + } + + fn prim_wrapping_mul_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + lhs.fill_with(0) + } else if rhs == 1 { + lhs + } else if rhs.is_power_of_two() { + // Power of two. + let shift = rhs.trailing_zeros(); + prim_unary_values(lhs, |x| x << shift) + } else { + prim_unary_values(lhs, |x| x.wrapping_mul(rhs)) + } + } + + fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.dtype().clone()) + } else if rhs == 1 { + lhs + } else { + let red = <$StrRed>::new(rhs); + prim_unary_values(lhs, |x| x / red) + } + } + + fn prim_wrapping_floor_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| if x != 0 { lhs / x } else { 0 }) + }; + ret.with_validity(valid) + } + + fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + Self::prim_wrapping_floor_div_scalar(lhs, rhs) + } + + fn prim_wrapping_trunc_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + Self::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) + } + + fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { + if rhs == 0 { + PArr::full_null(lhs.len(), lhs.dtype().clone()) + } else if rhs == 1 { + lhs.fill_with(0) + } else { + let red = <$StrRed>::new(rhs); + prim_unary_values(lhs, |x| x % red) + } + } + + fn prim_wrapping_mod_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr<$T> { + let mask = rhs.tot_ne_kernel_broadcast(&0); + let valid = combine_validities_and(rhs.validity(), Some(&mask)); + let ret = if lhs == 0 { + rhs.fill_with(0) + } else { + prim_unary_values(rhs, |x| if x != 0 { lhs % x } else { 0 }) + }; + ret.with_validity(valid) + } + + fn prim_true_div(lhs: PArr<$T>, other: PArr<$T>) -> PArr { + prim_binary_values(lhs, other, |a, b| a as f64 / b as f64) + } + + fn prim_true_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr { + let inv = 1.0 / rhs as f64; + prim_unary_values(lhs, |x| x as f64 * inv) + } + + fn prim_true_div_scalar_lhs(lhs: $T, rhs: PArr<$T>) -> PArr { + prim_unary_values(rhs, |x| lhs as f64 / x as f64) + } + } + }; +} + +impl_unsigned_arith_kernel!(u8, StrengthReducedU8); +impl_unsigned_arith_kernel!(u16, StrengthReducedU16); +impl_unsigned_arith_kernel!(u32, StrengthReducedU32); +impl_unsigned_arith_kernel!(u64, StrengthReducedU64); +impl_unsigned_arith_kernel!(u128, StrengthReducedU128); diff --git a/crates/polars-compute/src/arity.rs b/crates/polars-compute/src/arity.rs new file mode 100644 index 000000000000..026bcb169210 --- /dev/null +++ b/crates/polars-compute/src/arity.rs @@ -0,0 +1,127 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::PrimitiveArray; +use arrow::compute::utils::combine_validities_and; +use arrow::types::NativeType; + +/// To reduce codegen we use these helpers where the input and output arrays +/// may overlap. These are marked to never be inlined, this way only a single +/// unrolled kernel gets generated, even if we call it in multiple ways. +/// +/// # Safety +/// - arr must point to a readable slice of length len. +/// - out must point to a writable slice of length len. +#[inline(never)] +unsafe fn ptr_apply_unary_kernel O>( + arr: *const I, + out: *mut O, + len: usize, + op: F, +) { + for i in 0..len { + let ret = op(arr.add(i).read()); + out.add(i).write(ret); + } +} + +/// # Safety +/// - left must point to a readable slice of length len. +/// - right must point to a readable slice of length len. +/// - out must point to a writable slice of length len. +#[inline(never)] +unsafe fn ptr_apply_binary_kernel O>( + left: *const L, + right: *const R, + out: *mut O, + len: usize, + op: F, +) { + for i in 0..len { + let ret = op(left.add(i).read(), right.add(i).read()); + out.add(i).write(ret); + } +} + +/// Applies a function to all the values (regardless of nullability). +/// +/// May reuse the memory of the array if possible. +pub fn prim_unary_values(mut arr: PrimitiveArray, op: F) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> O, +{ + let len = arr.len(); + + // Reuse memory if possible. + if size_of::() == size_of::() && align_of::() == align_of::() { + if let Some(values) = arr.get_mut_values() { + let ptr = values.as_mut_ptr(); + // SAFETY: checked same size & alignment I/O, NativeType is always Pod. + unsafe { ptr_apply_unary_kernel(ptr, ptr as *mut O, len, op) } + return arr.transmute::(); + } + } + + let mut out = Vec::with_capacity(len); + unsafe { + // SAFETY: checked pointers point to slices of length len. + ptr_apply_unary_kernel(arr.values().as_ptr(), out.as_mut_ptr(), len, op); + out.set_len(len); + } + PrimitiveArray::from_vec(out).with_validity(arr.take_validity()) +} + +/// Apply a binary function to all the values (regardless of nullability) +/// in (lhs, rhs). Combines the validities with a bitand. +/// +/// May reuse the memory of one of its arguments if possible. +pub fn prim_binary_values( + mut lhs: PrimitiveArray, + mut rhs: PrimitiveArray, + op: F, +) -> PrimitiveArray +where + L: NativeType, + R: NativeType, + O: NativeType, + F: Fn(L, R) -> O, +{ + assert_eq!(lhs.len(), rhs.len()); + let len = lhs.len(); + + let validity = combine_validities_and(lhs.validity(), rhs.validity()); + + // Reuse memory if possible. + if size_of::() == size_of::() && align_of::() == align_of::() { + if let Some(lv) = lhs.get_mut_values() { + let lp = lv.as_mut_ptr(); + let rp = rhs.values().as_ptr(); + unsafe { + // SAFETY: checked same size & alignment L/O, NativeType is always Pod. + ptr_apply_binary_kernel(lp, rp, lp as *mut O, len, op); + } + return lhs.transmute::().with_validity(validity); + } + } + if size_of::() == size_of::() && align_of::() == align_of::() { + if let Some(rv) = rhs.get_mut_values() { + let lp = lhs.values().as_ptr(); + let rp = rv.as_mut_ptr(); + unsafe { + // SAFETY: checked same size & alignment R/O, NativeType is always Pod. + ptr_apply_binary_kernel(lp, rp, rp as *mut O, len, op); + } + return rhs.transmute::().with_validity(validity); + } + } + + let mut out = Vec::with_capacity(len); + unsafe { + // SAFETY: checked pointers point to slices of length len. + let lp = lhs.values().as_ptr(); + let rp = rhs.values().as_ptr(); + ptr_apply_binary_kernel(lp, rp, out.as_mut_ptr(), len, op); + out.set_len(len); + } + PrimitiveArray::from_vec(out).with_validity(validity) +} diff --git a/crates/polars-compute/src/binview_index_map.rs b/crates/polars-compute/src/binview_index_map.rs new file mode 100644 index 000000000000..233a19cfa823 --- /dev/null +++ b/crates/polars-compute/src/binview_index_map.rs @@ -0,0 +1,348 @@ +use arrow::array::View; +use hashbrown::hash_table::{ + Entry as TEntry, HashTable, OccupiedEntry as TOccupiedEntry, VacantEntry as TVacantEntry, +}; +use polars_utils::IdxSize; + +const BASE_KEY_BUFFER_CAPACITY: usize = 1024; + +struct Key { + hash: u64, + view: View, +} + +/// An IndexMap where the keys are [u8] slices or `View`s which are pre-hashed. +/// Does not support deletion. +pub struct BinaryViewIndexMap { + table: HashTable, + tuples: Vec<(Key, V)>, + buffers: Vec>, + + // Internal random seed used to keep hash iteration order decorrelated. + // We simply store a random odd number and multiply the canonical hash by it. + seed: u64, +} + +impl Default for BinaryViewIndexMap { + fn default() -> Self { + Self { + table: HashTable::new(), + tuples: Vec::new(), + buffers: vec![], + seed: rand::random::() | 1, + } + } +} + +impl BinaryViewIndexMap { + pub fn new() -> Self { + Self::default() + } + + pub fn reserve(&mut self, additional: usize) { + self.table.reserve(additional, |i| unsafe { + let tuple = self.tuples.get_unchecked(*i as usize); + tuple.0.hash.wrapping_mul(self.seed) + }); + self.tuples.reserve(additional); + } + + pub fn len(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + pub fn is_empty(&self) -> bool { + self.tuples.is_empty() + } + + pub fn buffers(&self) -> &[Vec] { + &self.buffers + } + + #[inline] + pub fn get(&self, hash: u64, key: &[u8]) -> Option<&V> { + unsafe { + if key.len() <= View::MAX_INLINE_SIZE as usize { + self.get_inline_view(hash, &View::new_inline_unchecked(key)) + } else { + self.get_long_key(hash, key) + } + } + } + + /// # Safety + /// The view must be valid in combination with the given buffers. + #[inline] + pub unsafe fn get_view>( + &self, + hash: u64, + key: &View, + buffers: &[B], + ) -> Option<&V> { + unsafe { + if key.length <= View::MAX_INLINE_SIZE { + self.get_inline_view(hash, key) + } else { + self.get_long_key(hash, key.get_external_slice_unchecked(buffers)) + } + } + } + + /// # Safety + /// The view must be inlined. + pub unsafe fn get_inline_view(&self, hash: u64, key: &View) -> Option<&V> { + unsafe { + debug_assert!(key.length <= View::MAX_INLINE_SIZE); + let idx = self.table.find(hash.wrapping_mul(self.seed), |i| { + let t = self.tuples.get_unchecked(*i as usize); + *key == t.0.view + })?; + Some(&self.tuples.get_unchecked(*idx as usize).1) + } + } + + /// # Safety + /// key.len() > View::MAX_INLINE_SIZE + pub unsafe fn get_long_key(&self, hash: u64, key: &[u8]) -> Option<&V> { + unsafe { + debug_assert!(key.len() > View::MAX_INLINE_SIZE as usize); + let idx = self.table.find(hash.wrapping_mul(self.seed), |i| { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.hash + && key.len() == t.0.view.length as usize + && key == t.0.view.get_external_slice_unchecked(&self.buffers) + })?; + Some(&self.tuples.get_unchecked(*idx as usize).1) + } + } + + #[inline] + pub fn entry<'k>(&mut self, hash: u64, key: &'k [u8]) -> Entry<'_, 'k, V> { + unsafe { + if key.len() <= View::MAX_INLINE_SIZE as usize { + self.entry_inline_view(hash, View::new_inline_unchecked(key)) + } else { + self.entry_long_key(hash, key) + } + } + } + + /// # Safety + /// The view must be valid in combination with the given buffers. + #[inline] + pub unsafe fn entry_view<'k, B: AsRef<[u8]>>( + &mut self, + hash: u64, + key: View, + buffers: &'k [B], + ) -> Entry<'_, 'k, V> { + unsafe { + if key.length <= View::MAX_INLINE_SIZE { + self.entry_inline_view(hash, key) + } else { + self.entry_long_key(hash, key.get_external_slice_unchecked(buffers)) + } + } + } + + /// # Safety + /// The view must be inlined. + pub unsafe fn entry_inline_view<'k>(&mut self, hash: u64, key: View) -> Entry<'_, 'k, V> { + debug_assert!(key.length <= View::MAX_INLINE_SIZE); + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + key == t.0.view + }, + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + t.0.hash.wrapping_mul(self.seed) + }, + ); + + match entry { + TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry { + entry: o, + tuples: &mut self.tuples, + }), + TEntry::Vacant(v) => Entry::Vacant(VacantEntry { + view: key, + external: None, + hash, + entry: v, + tuples: &mut self.tuples, + buffers: &mut self.buffers, + }), + } + } + + /// # Safety + /// key.len() > View::MAX_INLINE_SIZE + pub unsafe fn entry_long_key<'k>(&mut self, hash: u64, key: &'k [u8]) -> Entry<'_, 'k, V> { + debug_assert!(key.len() > View::MAX_INLINE_SIZE as usize); + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.hash + && key.len() == t.0.view.length as usize + && key == t.0.view.get_external_slice_unchecked(&self.buffers) + }, + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + t.0.hash.wrapping_mul(self.seed) + }, + ); + + match entry { + TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry { + entry: o, + tuples: &mut self.tuples, + }), + TEntry::Vacant(v) => Entry::Vacant(VacantEntry { + view: View::default(), + external: Some(key), + hash, + entry: v, + tuples: &mut self.tuples, + buffers: &mut self.buffers, + }), + } + } + + /// Insert an empty entry which will never be mapped to. Returns the index of the entry. + /// + /// This is useful for entries which are handled externally. + pub fn push_unmapped_empty_entry(&mut self, value: V) -> IdxSize { + let ret = self.tuples.len() as IdxSize; + let key = Key { + hash: 0, + view: View::default(), + }; + self.tuples.push((key, value)); + ret + } + + /// Gets the hash, key and value at the given index by insertion order. + #[inline(always)] + pub fn get_index(&self, idx: IdxSize) -> Option<(u64, &[u8], &V)> { + let t = self.tuples.get(idx as usize)?; + Some(( + t.0.hash, + unsafe { t.0.view.get_slice_unchecked(&self.buffers) }, + &t.1, + )) + } + + /// Gets the hash, key and value at the given index by insertion order. + /// + /// # Safety + /// The index must be less than len(). + #[inline(always)] + pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (u64, &[u8], &V) { + let t = unsafe { self.tuples.get_unchecked(idx as usize) }; + unsafe { (t.0.hash, t.0.view.get_slice_unchecked(&self.buffers), &t.1) } + } + + /// Gets the hash, view and value at the given index by insertion order. + /// + /// # Safety + /// The index must be less than len(). + #[inline(always)] + pub unsafe fn get_index_view_unchecked(&self, idx: IdxSize) -> (u64, View, &V) { + let t = unsafe { self.tuples.get_unchecked(idx as usize) }; + (t.0.hash, t.0.view, &t.1) + } + + /// Iterates over the (hash, key) pairs in insertion order. + pub fn iter_hash_keys(&self) -> impl Iterator { + self.tuples + .iter() + .map(|t| unsafe { (t.0.hash, t.0.view.get_slice_unchecked(&self.buffers)) }) + } + + /// Iterates over the (hash, key_view) pairs in insertion order. + pub fn iter_hash_views(&self) -> impl Iterator { + self.tuples.iter().map(|t| (t.0.hash, t.0.view)) + } + + /// Iterates over the values in insertion order. + pub fn iter_values(&self) -> impl Iterator { + self.tuples.iter().map(|t| &t.1) + } +} + +pub enum Entry<'a, 'k, V> { + Occupied(OccupiedEntry<'a, V>), + Vacant(VacantEntry<'a, 'k, V>), +} + +pub struct OccupiedEntry<'a, V> { + entry: TOccupiedEntry<'a, IdxSize>, + tuples: &'a mut Vec<(Key, V)>, +} + +impl<'a, V> OccupiedEntry<'a, V> { + #[inline] + pub fn index(&self) -> IdxSize { + *self.entry.get() + } + + #[inline] + pub fn into_mut(self) -> &'a mut V { + let idx = self.index(); + unsafe { &mut self.tuples.get_unchecked_mut(idx as usize).1 } + } +} + +pub struct VacantEntry<'a, 'k, V> { + hash: u64, + view: View, // Empty when key is not inlined. + external: Option<&'k [u8]>, // Only set when not inlined. + entry: TVacantEntry<'a, IdxSize>, + tuples: &'a mut Vec<(Key, V)>, + buffers: &'a mut Vec>, +} + +#[allow(clippy::needless_lifetimes)] +impl<'a, 'k, V> VacantEntry<'a, 'k, V> { + #[inline] + pub fn index(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + #[inline] + pub fn insert(self, value: V) -> &'a mut V { + unsafe { + let tuple_idx: IdxSize = self.tuples.len().try_into().unwrap(); + let view = if let Some(key) = self.external { + if self + .buffers + .last() + .is_none_or(|buf| buf.len() + key.len() > buf.capacity()) + { + let ideal_next_cap = BASE_KEY_BUFFER_CAPACITY + .checked_shl(self.buffers.len() as u32) + .unwrap(); + let next_capacity = std::cmp::max(ideal_next_cap, key.len()); + self.buffers.push(Vec::with_capacity(next_capacity)); + } + let buffer_idx = (self.buffers.len() - 1) as u32; + let active_buf = self.buffers.last_mut().unwrap_unchecked(); + let offset = active_buf.len() as u32; + active_buf.extend_from_slice(key); + View::new_from_bytes(key, buffer_idx, offset) + } else { + self.view + }; + let tuple_key = Key { + hash: self.hash, + view, + }; + self.tuples.push((tuple_key, value)); + self.entry.insert(tuple_idx); + &mut self.tuples.last_mut().unwrap_unchecked().1 + } + } +} diff --git a/crates/polars-compute/src/bitwise/mod.rs b/crates/polars-compute/src/bitwise/mod.rs new file mode 100644 index 000000000000..07f5827fa3f5 --- /dev/null +++ b/crates/polars-compute/src/bitwise/mod.rs @@ -0,0 +1,251 @@ +use std::convert::identity; + +use arrow::array::{Array, BooleanArray, PrimitiveArray}; +use arrow::datatypes::ArrowDataType; +use arrow::legacy::utils::CustomIterTools; + +pub trait BitwiseKernel { + type Scalar; + + fn count_ones(&self) -> PrimitiveArray; + fn count_zeros(&self) -> PrimitiveArray; + + fn leading_ones(&self) -> PrimitiveArray; + fn leading_zeros(&self) -> PrimitiveArray; + + fn trailing_ones(&self) -> PrimitiveArray; + fn trailing_zeros(&self) -> PrimitiveArray; + + fn reduce_and(&self) -> Option; + fn reduce_or(&self) -> Option; + fn reduce_xor(&self) -> Option; + + fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar; + fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar; + fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar; +} + +macro_rules! impl_bitwise_kernel { + ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => { + $( + impl BitwiseKernel for PrimitiveArray<$T> { + type Scalar = $T; + + #[inline(never)] + fn count_ones(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(|&v| $to_bits(v).count_ones()) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn count_zeros(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(|&v| $to_bits(v).count_zeros()) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn leading_ones(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(|&v| $to_bits(v).leading_ones()) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn leading_zeros(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(|&v| $to_bits(v).leading_zeros()) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn trailing_ones(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(|&v| $to_bits(v).trailing_ones()) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn trailing_zeros(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values().iter() + .map(|&v| $to_bits(v).trailing_zeros()) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn reduce_and(&self) -> Option { + if !self.has_nulls() { + self.values_iter().copied().map($to_bits).reduce(|a, b| a & b).map($from_bits) + } else { + self.non_null_values_iter().map($to_bits).reduce(|a, b| a & b).map($from_bits) + } + } + + #[inline(never)] + fn reduce_or(&self) -> Option { + if !self.has_nulls() { + self.values_iter().copied().map($to_bits).reduce(|a, b| a | b).map($from_bits) + } else { + self.non_null_values_iter().map($to_bits).reduce(|a, b| a | b).map($from_bits) + } + } + + #[inline(never)] + fn reduce_xor(&self) -> Option { + if !self.has_nulls() { + self.values_iter().copied().map($to_bits).reduce(|a, b| a ^ b).map($from_bits) + } else { + self.non_null_values_iter().map($to_bits).reduce(|a, b| a ^ b).map($from_bits) + } + } + + fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar { + $from_bits($to_bits(lhs) & $to_bits(rhs)) + } + fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar { + $from_bits($to_bits(lhs) | $to_bits(rhs)) + } + fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar { + $from_bits($to_bits(lhs) ^ $to_bits(rhs)) + } + } + )+ + }; +} + +impl_bitwise_kernel! { + (i8, identity, identity), + (i16, identity, identity), + (i32, identity, identity), + (i64, identity, identity), + (u8, identity, identity), + (u16, identity, identity), + (u32, identity, identity), + (u64, identity, identity), + (f32, f32::to_bits, f32::from_bits), + (f64, f64::to_bits, f64::from_bits), +} + +#[cfg(feature = "dtype-i128")] +impl_bitwise_kernel! { + (i128, identity, identity), +} + +impl BitwiseKernel for BooleanArray { + type Scalar = bool; + + #[inline(never)] + fn count_ones(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(u32::from) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(never)] + fn count_zeros(&self) -> PrimitiveArray { + PrimitiveArray::new( + ArrowDataType::UInt32, + self.values_iter() + .map(|v| u32::from(!v)) + .collect_trusted::>() + .into(), + self.validity().cloned(), + ) + } + + #[inline(always)] + fn leading_ones(&self) -> PrimitiveArray { + self.count_ones() + } + + #[inline(always)] + fn leading_zeros(&self) -> PrimitiveArray { + self.count_zeros() + } + + #[inline(always)] + fn trailing_ones(&self) -> PrimitiveArray { + self.count_ones() + } + + #[inline(always)] + fn trailing_zeros(&self) -> PrimitiveArray { + self.count_zeros() + } + + fn reduce_and(&self) -> Option { + if self.len() == self.null_count() { + None + } else if !self.has_nulls() { + Some(self.values().unset_bits() == 0) + } else { + Some((self.values() & self.validity().unwrap()).unset_bits() == 0) + } + } + + fn reduce_or(&self) -> Option { + if self.len() == self.null_count() { + None + } else if !self.has_nulls() { + Some(self.values().set_bits() > 0) + } else { + Some((self.values() & self.validity().unwrap()).set_bits() > 0) + } + } + + fn reduce_xor(&self) -> Option { + if self.len() == self.null_count() { + None + } else if !self.has_nulls() { + Some(self.values().set_bits() % 2 == 1) + } else { + Some((self.values() & self.validity().unwrap()).set_bits() % 2 == 1) + } + } + + fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar { + lhs & rhs + } + fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar { + lhs | rhs + } + fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar { + lhs ^ rhs + } +} diff --git a/crates/polars-compute/src/cardinality.rs b/crates/polars-compute/src/cardinality.rs new file mode 100644 index 000000000000..d28efa9d051e --- /dev/null +++ b/crates/polars-compute/src/cardinality.rs @@ -0,0 +1,159 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeBinaryArray, PrimitiveArray, + Utf8Array, Utf8ViewArray, +}; +use arrow::datatypes::PhysicalType; +use arrow::types::Offset; +use arrow::with_match_primitive_type_full; +use polars_utils::total_ord::ToTotalOrd; + +use crate::hyperloglogplus::HyperLogLog; + +/// Get an estimate for the *cardinality* of the array (i.e. the number of unique values) +/// +/// This is not currently implemented for nested types. +pub fn estimate_cardinality(array: &dyn Array) -> usize { + if array.is_empty() { + return 0; + } + + if array.null_count() == array.len() { + return 1; + } + + // Estimate the cardinality with HyperLogLog + use PhysicalType as PT; + match array.dtype().to_physical_type() { + PT::Null => 1, + + PT::Boolean => { + let mut cardinality = 0; + + let array = array.as_any().downcast_ref::().unwrap(); + + cardinality += usize::from(array.has_nulls()); + + if let Some(unset_bits) = array.values().lazy_unset_bits() { + cardinality += 1 + usize::from(unset_bits != array.len()); + } else { + cardinality += 2; + } + + cardinality + }, + + PT::Primitive(primitive_type) => with_match_primitive_type_full!(primitive_type, |$T| { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.copied().unwrap_or_default(); + hll.add(&v.to_total_ord()); + } + } else { + for v in array.values_iter() { + hll.add(&v.to_total_ord()); + } + } + + hll.count() + }), + PT::FixedSizeBinary => { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() + }, + PT::Binary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::LargeBinary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::Utf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::LargeUtf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::BinaryView => { + binary_view_array_estimate(array.as_any().downcast_ref::().unwrap()) + }, + PT::Utf8View => binary_view_array_estimate( + &array + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(), + ), + PT::List => unimplemented!(), + PT::FixedSizeList => unimplemented!(), + PT::LargeList => unimplemented!(), + PT::Struct => unimplemented!(), + PT::Union => unimplemented!(), + PT::Map => unimplemented!(), + PT::Dictionary(_) => unimplemented!(), + } +} + +fn binary_offset_array_estimate(array: &BinaryArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} + +fn binary_view_array_estimate(array: &BinaryViewArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} diff --git a/crates/polars-compute/src/cast/binary_to.rs b/crates/polars-compute/src/cast/binary_to.rs new file mode 100644 index 000000000000..ff679196ede2 --- /dev/null +++ b/crates/polars-compute/src/cast/binary_to.rs @@ -0,0 +1,249 @@ +use std::sync::Arc; + +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::{Offset, Offsets}; +use arrow::types::NativeType; +use polars_error::PolarsResult; + +use super::CastOptionsImpl; + +pub(super) trait Parse { + fn parse(val: &[u8]) -> Option + where + Self: Sized; +} + +macro_rules! impl_parse { + ($primitive_type:ident) => { + impl Parse for $primitive_type { + fn parse(val: &[u8]) -> Option { + atoi_simd::parse_skipped(val).ok() + } + } + }; +} +impl_parse!(i8); +impl_parse!(i16); +impl_parse!(i32); +impl_parse!(i64); + +impl_parse!(u8); +impl_parse!(u16); +impl_parse!(u32); +impl_parse!(u64); + +#[cfg(feature = "dtype-i128")] +impl_parse!(i128); + +impl Parse for f32 { + fn parse(val: &[u8]) -> Option + where + Self: Sized, + { + fast_float2::parse(val).ok() + } +} +impl Parse for f64 { + fn parse(val: &[u8]) -> Option + where + Self: Sized, + { + fast_float2::parse(val).ok() + } +} + +/// Conversion of binary +pub fn binary_to_large_binary( + from: &BinaryArray, + to_dtype: ArrowDataType, +) -> BinaryArray { + let values = from.values().clone(); + BinaryArray::::new( + to_dtype, + from.offsets().into(), + values, + from.validity().cloned(), + ) +} + +/// Conversion of binary +pub fn binary_large_to_binary( + from: &BinaryArray, + to_dtype: ArrowDataType, +) -> PolarsResult> { + let values = from.values().clone(); + let offsets = from.offsets().try_into()?; + Ok(BinaryArray::::new( + to_dtype, + offsets, + values, + from.validity().cloned(), + )) +} + +/// Conversion to utf8 +pub fn binary_to_utf8( + from: &BinaryArray, + to_dtype: ArrowDataType, +) -> PolarsResult> { + Utf8Array::::try_new( + to_dtype, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub(super) fn binary_to_primitive( + from: &BinaryArray, + to: &ArrowDataType, +) -> PrimitiveArray +where + T: NativeType + Parse, +{ + let iter = from.iter().map(|x| x.and_then::(|x| T::parse(x))); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn binary_to_primitive_dyn( + from: &dyn Array, + to: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> +where + T: NativeType + Parse, +{ + let from = from.as_any().downcast_ref().unwrap(); + if options.partial { + unimplemented!() + } else { + Ok(Box::new(binary_to_primitive::(from, to))) + } +} + +/// Cast [`BinaryArray`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub fn binary_to_dictionary( + from: &BinaryArray, +) -> PolarsResult> { + let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn binary_to_dictionary_dyn( + from: &dyn Array, +) -> PolarsResult> { + let values = from.as_any().downcast_ref().unwrap(); + binary_to_dictionary::(values).map(|x| Box::new(x) as Box) +} + +fn fixed_size_to_offsets(values_len: usize, fixed_size: usize) -> Offsets { + let offsets = (0..(values_len + 1)) + .step_by(fixed_size) + .map(|v| O::from_as_usize(v)) + .collect(); + // SAFETY: + // * every element is `>= 0` + // * element at position `i` is >= than element at position `i-1`. + unsafe { Offsets::new_unchecked(offsets) } +} + +/// Conversion of `FixedSizeBinary` to `Binary`. +pub fn fixed_size_binary_binary( + from: &FixedSizeBinaryArray, + to_dtype: ArrowDataType, +) -> BinaryArray { + let values = from.values().clone(); + let offsets = fixed_size_to_offsets(values.len(), from.size()); + BinaryArray::::new(to_dtype, offsets.into(), values, from.validity().cloned()) +} + +pub fn fixed_size_binary_to_binview(from: &FixedSizeBinaryArray) -> BinaryViewArray { + let datatype = <[u8] as ViewType>::DATA_TYPE; + + // Fast path: all the views are inlineable + if from.size() <= View::MAX_INLINE_SIZE as usize { + // @NOTE: There is something with the code-generation of `View::new_inline_unchecked` that + // prevents it from properly SIMD-ing this loop. It insists on memcpying while it should + // know that the size is really small. Dispatching over the `from.size()` and making it + // constant does make loop SIMD, but it does not actually speed anything up and the code it + // generates is still horrible. + // + // This is really slow, and I don't think it has to be. + + // SAFETY: We checked that slice.len() <= View::MAX_INLINE_SIZE before + let mut views = Vec::new(); + View::extend_with_inlinable_strided( + &mut views, + from.values().as_slice(), + from.size() as u8, + ); + let views = Buffer::from(views); + return BinaryViewArray::try_new(datatype, views, Arc::default(), from.validity().cloned()) + .unwrap(); + } + + const MAX_BYTES_PER_BUFFER: usize = u32::MAX as usize; + + let size = from.size(); + let num_bytes = from.len() * size; + let num_buffers = num_bytes.div_ceil(MAX_BYTES_PER_BUFFER); + assert!(num_buffers < u32::MAX as usize); + + let num_elements_per_buffer = MAX_BYTES_PER_BUFFER / size; + // This is NOT equal to MAX_BYTES_PER_BUFFER because of integer division + let split_point = num_elements_per_buffer * size; + + // This is zero-copy for the buffer since split just increases the data since + let mut buffer = from.values().clone(); + let mut buffers = Vec::with_capacity(num_buffers); + + if let Some(num_buffers) = num_buffers.checked_sub(1) { + for _ in 0..num_buffers { + let slice; + (slice, buffer) = buffer.split_at(split_point); + buffers.push(slice); + } + buffers.push(buffer); + } + + let mut iter = from.values_iter(); + let iter = iter.by_ref(); + let mut views = Vec::with_capacity(from.len()); + for buffer_idx in 0..num_buffers { + views.extend( + iter.take(num_elements_per_buffer) + .enumerate() + .map(|(i, slice)| { + // SAFETY: We checked that slice.len() > View::MAX_INLINE_SIZE before + unsafe { + View::new_noninline_unchecked(slice, buffer_idx as u32, (i * size) as u32) + } + }), + ); + } + let views = views.into(); + + BinaryViewArray::try_new(datatype, views, buffers.into(), from.validity().cloned()).unwrap() +} + +/// Conversion of binary +pub fn binary_to_list(from: &BinaryArray, to_dtype: ArrowDataType) -> ListArray { + let values = from.values().clone(); + let values = PrimitiveArray::new(ArrowDataType::UInt8, values, None); + ListArray::::new( + to_dtype, + from.offsets().clone(), + values.boxed(), + from.validity().cloned(), + ) +} diff --git a/crates/polars-compute/src/cast/binview_to.rs b/crates/polars-compute/src/cast/binview_to.rs new file mode 100644 index 000000000000..152f4bdf212c --- /dev/null +++ b/crates/polars-compute/src/cast/binview_to.rs @@ -0,0 +1,141 @@ +use arrow::array::*; +#[cfg(feature = "dtype-decimal")] +use arrow::compute::decimal::deserialize_decimal; +use arrow::datatypes::{ArrowDataType, TimeUnit}; +use arrow::offset::Offset; +use arrow::types::NativeType; +use chrono::Datelike; +use polars_error::PolarsResult; + +use super::CastOptionsImpl; +use super::binary_to::Parse; +use super::temporal::EPOCH_DAYS_FROM_CE; + +pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; + +/// Cast [`BinaryViewArray`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub(super) fn binview_to_dictionary( + from: &BinaryViewArray, +) -> PolarsResult> { + let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn utf8view_to_dictionary( + from: &Utf8ViewArray, +) -> PolarsResult> { + let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn view_to_binary(array: &BinaryViewArray) -> BinaryArray { + let len: usize = Array::len(array); + let mut mutable = MutableBinaryValuesArray::::with_capacities(len, array.total_bytes_len()); + for slice in array.values_iter() { + mutable.push(slice) + } + let out: BinaryArray = mutable.into(); + out.with_validity(array.validity().cloned()) +} + +pub fn utf8view_to_utf8(array: &Utf8ViewArray) -> Utf8Array { + let array = array.to_binview(); + let out = view_to_binary::(&array); + + let dtype = Utf8Array::::default_dtype(); + unsafe { + Utf8Array::new_unchecked( + dtype, + out.offsets().clone(), + out.values().clone(), + out.validity().cloned(), + ) + } +} +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub(super) fn binview_to_primitive( + from: &BinaryViewArray, + to: &ArrowDataType, +) -> PrimitiveArray +where + T: NativeType + Parse, +{ + let iter = from.iter().map(|x| x.and_then::(|x| T::parse(x))); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn binview_to_primitive_dyn( + from: &dyn Array, + to: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> +where + T: NativeType + Parse, +{ + let from = from.as_any().downcast_ref().unwrap(); + if options.partial { + unimplemented!() + } else { + Ok(Box::new(binview_to_primitive::(from, to))) + } +} + +#[cfg(feature = "dtype-decimal")] +pub fn binview_to_decimal( + array: &BinaryViewArray, + precision: Option, + scale: usize, +) -> PrimitiveArray { + let precision = precision.map(|p| p as u8); + PrimitiveArray::::from_trusted_len_iter( + array + .iter() + .map(|val| val.and_then(|val| deserialize_decimal(val, precision, scale as u8))), + ) + .to(ArrowDataType::Decimal( + precision.unwrap_or(38).into(), + scale, + )) +} + +pub(super) fn utf8view_to_naive_timestamp_dyn( + from: &dyn Array, + time_unit: TimeUnit, +) -> PolarsResult> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8view_to_naive_timestamp(from, time_unit))) +} + +/// [`super::temporal::utf8view_to_timestamp`] applied for RFC3339 formatting +pub fn utf8view_to_naive_timestamp( + from: &Utf8ViewArray, + time_unit: TimeUnit, +) -> PrimitiveArray { + super::temporal::utf8view_to_naive_timestamp(from, RFC3339, time_unit) +} + +pub(super) fn utf8view_to_date32(from: &Utf8ViewArray) -> PrimitiveArray { + let iter = from.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + }) + }); + PrimitiveArray::::from_trusted_len_iter(iter).to(ArrowDataType::Date32) +} + +pub(super) fn utf8view_to_date32_dyn(from: &dyn Array) -> PolarsResult> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8view_to_date32(from))) +} diff --git a/crates/polars-compute/src/cast/boolean_to.rs b/crates/polars-compute/src/cast/boolean_to.rs new file mode 100644 index 000000000000..251178711774 --- /dev/null +++ b/crates/polars-compute/src/cast/boolean_to.rs @@ -0,0 +1,51 @@ +use arrow::array::{Array, BooleanArray, PrimitiveArray}; +use arrow::types::NativeType; +use polars_error::PolarsResult; + +use super::{ArrayFromIter, BinaryViewArray, Utf8ViewArray}; + +pub(super) fn boolean_to_primitive_dyn(array: &dyn Array) -> PolarsResult> +where + T: NativeType + num_traits::One, +{ + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_primitive::(array))) +} + +/// Casts the [`BooleanArray`] to a [`PrimitiveArray`]. +pub fn boolean_to_primitive(from: &BooleanArray) -> PrimitiveArray +where + T: NativeType + num_traits::One, +{ + let values = from + .values() + .iter() + .map(|x| if x { T::one() } else { T::default() }) + .collect::>(); + + PrimitiveArray::::new(T::PRIMITIVE.into(), values.into(), from.validity().cloned()) +} + +pub fn boolean_to_utf8view(from: &BooleanArray) -> Utf8ViewArray { + unsafe { boolean_to_binaryview(from).to_utf8view_unchecked() } +} + +pub(super) fn boolean_to_utf8view_dyn(array: &dyn Array) -> PolarsResult> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(boolean_to_utf8view(array).boxed()) +} + +/// Casts the [`BooleanArray`] to a [`BinaryArray`], casting trues to `"1"` and falses to `"0"` +pub fn boolean_to_binaryview(from: &BooleanArray) -> BinaryViewArray { + let iter = from.iter().map(|opt_b| match opt_b { + Some(true) => Some("true".as_bytes()), + Some(false) => Some("false".as_bytes()), + None => None, + }); + BinaryViewArray::arr_from_iter_trusted(iter) +} + +pub(super) fn boolean_to_binaryview_dyn(array: &dyn Array) -> PolarsResult> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(boolean_to_binaryview(array).boxed()) +} diff --git a/crates/polars-compute/src/cast/decimal_to.rs b/crates/polars-compute/src/cast/decimal_to.rs new file mode 100644 index 000000000000..e19481ff6154 --- /dev/null +++ b/crates/polars-compute/src/cast/decimal_to.rs @@ -0,0 +1,162 @@ +use arrow::array::*; +use arrow::datatypes::ArrowDataType; +use arrow::types::NativeType; +use num_traits::{AsPrimitive, Float, NumCast}; +use polars_error::PolarsResult; + +#[inline] +fn decimal_to_decimal_impl Option>( + from: &PrimitiveArray, + op: F, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let upper_bound_for_precision = 10_i128.saturating_pow(to_precision as u32); + let lower_bound_for_precision = upper_bound_for_precision.saturating_neg(); + + let values = from.iter().map(|x| { + x.and_then(|x| { + op(*x).and_then(|x| { + if x >= upper_bound_for_precision || x <= lower_bound_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + PrimitiveArray::::from_trusted_len_iter(values) + .to(ArrowDataType::Decimal(to_precision, to_scale)) +} + +/// Returns a [`PrimitiveArray`] with the cast values. Values are `None` on overflow +pub fn decimal_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let (from_precision, from_scale) = + if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + if to_scale == from_scale && to_precision >= from_precision { + // fast path + return from + .clone() + .to(ArrowDataType::Decimal(to_precision, to_scale)); + } + // todo: other fast paths include increasing scale and precision by so that + // a number will never overflow (validity is preserved) + + if from_scale > to_scale { + let factor = 10_i128.pow((from_scale - to_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_div(factor), + to_precision, + to_scale, + ) + } else { + let factor = 10_i128.pow((to_scale - from_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_mul(factor), + to_precision, + to_scale, + ) + } +} + +pub(super) fn decimal_to_decimal_dyn( + from: &dyn Array, + to_precision: usize, + to_scale: usize, +) -> PolarsResult> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_decimal(from, to_precision, to_scale))) +} + +/// Returns a [`PrimitiveArray`] with the cast values. Values are `None` on overflow +pub fn decimal_to_float(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let div = 10_f64.powi(from_scale as i32); + let values = from + .values() + .iter() + .map(|x| (*x as f64 / div).as_()) + .collect(); + + PrimitiveArray::::new(T::PRIMITIVE.into(), values, from.validity().cloned()) +} + +pub(super) fn decimal_to_float_dyn(from: &dyn Array) -> PolarsResult> +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_float::(from))) +} + +/// Returns a [`PrimitiveArray`] with the cast values. Values are `None` on overflow +pub fn decimal_to_integer(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + NumCast, +{ + let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let factor = 10_i128.pow(from_scale as u32); + let values = from.iter().map(|x| x.and_then(|x| T::from(*x / factor))); + + PrimitiveArray::from_trusted_len_iter(values) +} + +pub(super) fn decimal_to_integer_dyn(from: &dyn Array) -> PolarsResult> +where + T: NativeType + NumCast, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_integer::(from))) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the decimal. +#[cfg(feature = "dtype-decimal")] +pub(super) fn decimal_to_utf8view(from: &PrimitiveArray) -> Utf8ViewArray { + use arrow::compute::decimal::DecimalFmtBuffer; + + let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let mut mutable = MutableBinaryViewArray::with_capacity(from.len()); + let mut fmt_buf = DecimalFmtBuffer::new(); + for &x in from.values().iter() { + mutable.push_value_ignore_validity(fmt_buf.format(x, from_scale, false)) + } + + mutable.freeze().with_validity(from.validity().cloned()) +} + +#[cfg(feature = "dtype-decimal")] +pub(super) fn decimal_to_utf8view_dyn(from: &dyn Array) -> Utf8ViewArray { + let from = from.as_any().downcast_ref().unwrap(); + decimal_to_utf8view(from) +} diff --git a/crates/polars-compute/src/cast/dictionary_to.rs b/crates/polars-compute/src/cast/dictionary_to.rs new file mode 100644 index 000000000000..16622d5b7703 --- /dev/null +++ b/crates/polars-compute/src/cast/dictionary_to.rs @@ -0,0 +1,49 @@ +use arrow::array::{Array, DictionaryArray, DictionaryKey}; +use arrow::datatypes::ArrowDataType; +use arrow::match_integer_type; +use polars_error::{PolarsResult, polars_bail}; + +use super::{CastOptionsImpl, cast, primitive_to_primitive}; + +macro_rules! key_cast { + ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty, $to_datatype:expr) => {{ + let cast_keys = primitive_to_primitive::<_, $to_type>($keys, $to_keys_type); + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > $keys.null_count() { + polars_bail!(ComputeError: "overflow") + } + // SAFETY: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { + DictionaryArray::try_new_unchecked($to_datatype, cast_keys, $values.clone()) + } + .map(|x| x.boxed()) + }}; +} + +pub(super) fn dictionary_cast_dyn( + array: &dyn Array, + to_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> { + let array = array.as_any().downcast_ref::>().unwrap(); + let keys = array.keys(); + let values = array.values(); + + match to_type { + ArrowDataType::Dictionary(to_keys_type, to_values_type, _) => { + let values = cast(values.as_ref(), to_values_type, options)?; + + // create the appropriate array type + let to_key_type = (*to_keys_type).into(); + + // SAFETY: + // we return an error on overflow so the integers remain within bounds + match_integer_type!(to_keys_type, |$T| { + key_cast!(keys, values, array, &to_key_type, $T, to_type.clone()) + }) + }, + _ => unimplemented!(), + } +} diff --git a/crates/polars-compute/src/cast/mod.rs b/crates/polars-compute/src/cast/mod.rs new file mode 100644 index 000000000000..a70f96200c3c --- /dev/null +++ b/crates/polars-compute/src/cast/mod.rs @@ -0,0 +1,855 @@ +//! Defines different casting operators such as [`cast`] or [`primitive_to_binary`]. + +mod binary_to; +mod binview_to; +mod boolean_to; +mod decimal_to; +mod dictionary_to; +mod primitive_to; +mod utf8_to; + +pub use binary_to::*; +#[cfg(feature = "dtype-decimal")] +pub use binview_to::binview_to_decimal; +use binview_to::binview_to_primitive_dyn; +pub use binview_to::utf8view_to_utf8; +pub use boolean_to::*; +pub use decimal_to::*; +pub mod temporal; +use arrow::array::*; +use arrow::datatypes::*; +use arrow::match_integer_type; +use arrow::offset::{Offset, Offsets}; +use binview_to::{ + binview_to_dictionary, utf8view_to_date32_dyn, utf8view_to_dictionary, + utf8view_to_naive_timestamp_dyn, view_to_binary, +}; +use dictionary_to::*; +use polars_error::{PolarsResult, polars_bail, polars_ensure, polars_err}; +use polars_utils::IdxSize; +pub use primitive_to::*; +use temporal::utf8view_to_timestamp; +pub use utf8_to::*; + +/// options defining how Cast kernels behave +#[derive(Clone, Copy, Debug, Default)] +pub struct CastOptionsImpl { + /// default to false + /// whether an overflowing cast should be converted to `None` (default), or be wrapped (i.e. `256i16 as u8 = 0` vectorized). + /// Settings this to `true` is 5-6x faster for numeric types. + pub wrapped: bool, + /// default to false + /// whether to cast to an integer at the best-effort + pub partial: bool, +} + +impl CastOptionsImpl { + pub fn unchecked() -> Self { + Self { + wrapped: true, + partial: false, + } + } +} + +impl CastOptionsImpl { + fn with_wrapped(&self, v: bool) -> Self { + let mut option = *self; + option.wrapped = v; + option + } +} + +macro_rules! primitive_dyn { + ($from:expr, $expr:tt) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from))) + }}; + ($from:expr, $expr:tt, $to:expr) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from, $to))) + }}; + ($from:expr, $expr:tt, $from_t:expr, $to:expr) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from, $from_t, $to))) + }}; + ($from:expr, $expr:tt, $arg1:expr, $arg2:expr, $arg3:expr) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from, $arg1, $arg2, $arg3))) + }}; +} + +fn cast_struct( + array: &StructArray, + to_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult { + let values = array.values(); + let fields = StructArray::get_fields(to_type); + let new_values = values + .iter() + .zip(fields) + .map(|(arr, field)| cast(arr.as_ref(), field.dtype(), options)) + .collect::>>()?; + + Ok(StructArray::new( + to_type.clone(), + array.len(), + new_values, + array.validity().cloned(), + )) +} + +fn cast_list( + array: &ListArray, + to_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> { + let values = array.values(); + let new_values = cast( + values.as_ref(), + ListArray::::get_child_type(to_type), + options, + )?; + + Ok(ListArray::::new( + to_type.clone(), + array.offsets().clone(), + new_values, + array.validity().cloned(), + )) +} + +fn cast_list_to_large_list(array: &ListArray, to_type: &ArrowDataType) -> ListArray { + let offsets = array.offsets().into(); + + ListArray::::new( + to_type.clone(), + offsets, + array.values().clone(), + array.validity().cloned(), + ) +} + +fn cast_large_to_list(array: &ListArray, to_type: &ArrowDataType) -> ListArray { + let offsets = array.offsets().try_into().expect("Convertme to error"); + + ListArray::::new( + to_type.clone(), + offsets, + array.values().clone(), + array.validity().cloned(), + ) +} + +fn cast_fixed_size_list_to_list( + fixed: &FixedSizeListArray, + to_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> { + let new_values = cast( + fixed.values().as_ref(), + ListArray::::get_child_type(to_type), + options, + )?; + + let offsets = (0..=fixed.len()) + .map(|ix| O::from_as_usize(ix * fixed.size())) + .collect::>(); + // SAFETY: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + Ok(ListArray::::new( + to_type.clone(), + offsets.into(), + new_values, + fixed.validity().cloned(), + )) +} + +fn cast_list_to_fixed_size_list( + list: &ListArray, + inner: &Field, + size: usize, + options: CastOptionsImpl, +) -> PolarsResult { + let null_cnt = list.null_count(); + let new_values = if null_cnt == 0 { + let start_offset = list.offsets().first().to_usize(); + let offsets = list.offsets().buffer(); + + let mut is_valid = true; + for (i, offset) in offsets.iter().enumerate() { + is_valid &= offset.to_usize() == start_offset + i * size; + } + + polars_ensure!(is_valid, ComputeError: "not all elements have the specified width {size}"); + + let sliced_values = list + .values() + .sliced(start_offset, list.offsets().range().to_usize()); + cast(sliced_values.as_ref(), inner.dtype(), options)? + } else { + let offsets = list.offsets().as_slice(); + // Check the lengths of each list are equal to the fixed size. + // SAFETY: we know the index is in bound. + let mut expected_offset = unsafe { *offsets.get_unchecked(0) } + O::from_as_usize(size); + for i in 1..=list.len() { + // SAFETY: we know the index is in bound. + let current_offset = unsafe { *offsets.get_unchecked(i) }; + if list.is_null(i - 1) { + expected_offset = current_offset + O::from_as_usize(size); + } else { + polars_ensure!(current_offset == expected_offset, ComputeError: + "not all elements have the specified width {size}"); + expected_offset += O::from_as_usize(size); + } + } + + // Build take indices for the values. This is used to fill in the null slots. + let mut indices = + MutablePrimitiveArray::::with_capacity(list.values().len() + null_cnt * size); + for i in 0..list.len() { + if list.is_null(i) { + indices.extend_constant(size, None) + } else { + // SAFETY: we know the index is in bound. + let current_offset = unsafe { *offsets.get_unchecked(i) }; + for j in 0..size { + indices.push(Some( + (current_offset + O::from_as_usize(j)).to_usize() as IdxSize + )); + } + } + } + let take_values = + unsafe { crate::gather::take_unchecked(list.values().as_ref(), &indices.freeze()) }; + + cast(take_values.as_ref(), inner.dtype(), options)? + }; + + FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList(Box::new(inner.clone()), size), + list.len(), + new_values, + list.validity().cloned(), + ) + .map_err(|_| polars_err!(ComputeError: "not all elements have the specified width {size}")) +} + +pub fn cast_default(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResult> { + cast(array, to_type, Default::default()) +} + +pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResult> { + cast(array, to_type, CastOptionsImpl::unchecked()) +} + +/// Cast `array` to the provided data type and return a new [`Array`] with +/// type `to_type`, if possible. +/// +/// Behavior: +/// * PrimitiveArray to PrimitiveArray: overflowing cast will be None +/// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings +/// in integer casts return null +/// * Numeric to boolean: 0 returns `false`, any other value returns `true` +/// * List to List: the underlying data type is cast +/// * Fixed Size List to List: the underlying data type is cast +/// * List to Fixed Size List: the offsets are checked for valid order, then the +/// underlying type is cast. +/// * Struct to Struct: the underlying fields are cast. +/// * PrimitiveArray to List: a list array with 1 value per slot is created +/// * Date32 and Date64: precision lost when going to higher interval +/// * Time32 and Time64: precision lost when going to higher interval +/// * Timestamp and Date{32|64}: precision lost when going to higher interval +/// * Temporal to/from backing primitive: zero-copy with data type change +/// +/// Unsupported Casts +/// * non-`StructArray` to `StructArray` or `StructArray` to non-`StructArray` +/// * List to primitive +/// * Utf8 to boolean +/// * Interval and duration +pub fn cast( + array: &dyn Array, + to_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> { + use ArrowDataType::*; + let from_type = array.dtype(); + + // clone array if types are the same + if from_type == to_type { + return Ok(clone(array)); + } + + let as_options = options.with_wrapped(true); + match (from_type, to_type) { + (Null, _) | (_, Null) => Ok(new_null_array(to_type.clone(), array.len())), + (Struct(from_fd), Struct(to_fd)) => { + polars_ensure!(from_fd.len() == to_fd.len(), InvalidOperation: "Cannot cast struct with different number of fields."); + cast_struct(array.as_any().downcast_ref().unwrap(), to_type, options).map(|x| x.boxed()) + }, + (Struct(_), _) | (_, Struct(_)) => polars_bail!(InvalidOperation: + "Cannot cast from struct to other types" + ), + (Dictionary(index_type, ..), _) => match_integer_type!(index_type, |$T| { + dictionary_cast_dyn::<$T>(array, to_type, options) + }), + (_, Dictionary(index_type, value_type, _)) => match_integer_type!(index_type, |$T| { + cast_to_dictionary::<$T>(array, value_type, options) + }), + // not supported by polars + // (List(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( + // array.as_any().downcast_ref().unwrap(), + // inner.as_ref(), + // *size, + // options, + // ) + // .map(|x| x.boxed()), + (LargeList(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( + array.as_any().downcast_ref().unwrap(), + inner.as_ref(), + *size, + options, + ) + .map(|x| x.boxed()), + (FixedSizeList(_, _), List(_)) => cast_fixed_size_list_to_list::( + array.as_any().downcast_ref().unwrap(), + to_type, + options, + ) + .map(|x| x.boxed()), + (FixedSizeList(_, _), LargeList(_)) => cast_fixed_size_list_to_list::( + array.as_any().downcast_ref().unwrap(), + to_type, + options, + ) + .map(|x| x.boxed()), + (BinaryView, _) => match to_type { + Utf8View => array + .as_any() + .downcast_ref::() + .unwrap() + .to_utf8view() + .map(|arr| arr.boxed()), + LargeBinary => Ok(binview_to::view_to_binary::( + array.as_any().downcast_ref().unwrap(), + ) + .boxed()), + LargeList(inner) if matches!(inner.dtype, ArrowDataType::UInt8) => { + let bin_array = view_to_binary::(array.as_any().downcast_ref().unwrap()); + Ok(binary_to_list(&bin_array, to_type.clone()).boxed()) + }, + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + (LargeList(_), LargeList(_)) => { + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) + .map(|x| x.boxed()) + }, + (List(lhs), LargeList(rhs)) if lhs == rhs => { + Ok(cast_list_to_large_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + }, + (LargeList(lhs), List(rhs)) if lhs == rhs => { + Ok(cast_large_to_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + }, + + (_, List(to)) => { + // cast primitive to list's primitive + let values = cast(array, &to.dtype, options)?; + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets = (0..=array.len() as i32).collect::>(); + // SAFETY: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let list_array = ListArray::::new(to_type.clone(), offsets.into(), values, None); + + Ok(Box::new(list_array)) + }, + + (_, LargeList(to)) if from_type != &LargeBinary => { + // cast primitive to list's primitive + let values = cast(array, &to.dtype, options)?; + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets = (0..=array.len() as i64).collect::>(); + // SAFETY: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let list_array = ListArray::::new( + to_type.clone(), + offsets.into(), + values, + array.validity().cloned(), + ); + + Ok(Box::new(list_array)) + }, + + (Utf8View, _) => { + let arr = array.as_any().downcast_ref::().unwrap(); + + match to_type { + BinaryView => Ok(arr.to_binview().boxed()), + LargeUtf8 => Ok(binview_to::utf8view_to_utf8::(arr).boxed()), + UInt8 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt16 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + UInt64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int8 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int16 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Int64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + #[cfg(feature = "dtype-i128")] + Int128 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Float32 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Float64 => binview_to_primitive_dyn::(&arr.to_binview(), to_type, options), + Timestamp(time_unit, None) => { + utf8view_to_naive_timestamp_dyn(array, time_unit.to_owned()) + }, + Timestamp(time_unit, Some(time_zone)) => utf8view_to_timestamp( + array.as_any().downcast_ref().unwrap(), + RFC3339, + time_zone.clone(), + time_unit.to_owned(), + ) + .map(|arr| arr.boxed()), + Date32 => utf8view_to_date32_dyn(array), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => { + Ok(binview_to_decimal(&arr.to_binview(), Some(*precision), *scale).to_boxed()) + }, + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + } + }, + + (_, Boolean) => match from_type { + UInt8 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt16 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt32 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int8 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int16 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int32 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int64 => primitive_to_boolean_dyn::(array, to_type.clone()), + #[cfg(feature = "dtype-i128")] + Int128 => primitive_to_boolean_dyn::(array, to_type.clone()), + Float32 => primitive_to_boolean_dyn::(array, to_type.clone()), + Float64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Decimal(_, _) => primitive_to_boolean_dyn::(array, to_type.clone()), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + (Boolean, _) => match to_type { + UInt8 => boolean_to_primitive_dyn::(array), + UInt16 => boolean_to_primitive_dyn::(array), + UInt32 => boolean_to_primitive_dyn::(array), + UInt64 => boolean_to_primitive_dyn::(array), + Int8 => boolean_to_primitive_dyn::(array), + Int16 => boolean_to_primitive_dyn::(array), + Int32 => boolean_to_primitive_dyn::(array), + Int64 => boolean_to_primitive_dyn::(array), + #[cfg(feature = "dtype-i128")] + Int128 => boolean_to_primitive_dyn::(array), + Float32 => boolean_to_primitive_dyn::(array), + Float64 => boolean_to_primitive_dyn::(array), + Utf8View => boolean_to_utf8view_dyn(array), + BinaryView => boolean_to_binaryview_dyn(array), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + (_, BinaryView) => from_to_binview(array, from_type, to_type).map(|arr| arr.boxed()), + (_, Utf8View) => match from_type { + LargeUtf8 => Ok(utf8_to_utf8view( + array.as_any().downcast_ref::>().unwrap(), + ) + .boxed()), + Utf8 => Ok( + utf8_to_utf8view(array.as_any().downcast_ref::>().unwrap()).boxed(), + ), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Ok(decimal_to_utf8view_dyn(array).boxed()), + _ => from_to_binview(array, from_type, to_type) + .map(|arr| unsafe { arr.to_utf8view_unchecked() }.boxed()), + }, + (Utf8, _) => match to_type { + LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( + array.as_any().downcast_ref().unwrap(), + ))), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + (LargeUtf8, _) => match to_type { + LargeBinary => Ok(utf8_to_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + (_, LargeUtf8) => match from_type { + UInt8 => primitive_to_utf8_dyn::(array), + LargeBinary => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + + (Binary, _) => match to_type { + LargeBinary => Ok(Box::new(binary_to_large_binary( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ))), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + + (LargeBinary, _) => match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type, options), + UInt16 => binary_to_primitive_dyn::(array, to_type, options), + UInt32 => binary_to_primitive_dyn::(array, to_type, options), + UInt64 => binary_to_primitive_dyn::(array, to_type, options), + Int8 => binary_to_primitive_dyn::(array, to_type, options), + Int16 => binary_to_primitive_dyn::(array, to_type, options), + Int32 => binary_to_primitive_dyn::(array, to_type, options), + Int64 => binary_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + Int128 => binary_to_primitive_dyn::(array, to_type, options), + Float32 => binary_to_primitive_dyn::(array, to_type, options), + Float64 => binary_to_primitive_dyn::(array, to_type, options), + Binary => { + binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + LargeUtf8 => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + (FixedSizeBinary(_), _) => match to_type { + Binary => Ok(fixed_size_binary_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + LargeBinary => Ok(fixed_size_binary_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }, + // start numeric casts + (UInt8, UInt16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (UInt8, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (UInt16, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (UInt32, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (UInt64, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, Int16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + #[cfg(feature = "dtype-i128")] + (Int8, Int128) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + #[cfg(feature = "dtype-i128")] + (Int16, Int128) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + #[cfg(feature = "dtype-i128")] + (Int32, Int128) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int64, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + #[cfg(feature = "dtype-i128")] + (Int128, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + #[cfg(feature = "dtype-i128")] + (Int128, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + #[cfg(feature = "dtype-i128")] + (Int128, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Float16, Float32) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(f16_to_f32(from).boxed()) + }, + + (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Float32, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int128) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Decimal(_, _), UInt8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt16) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int16) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int128) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Float32) => decimal_to_float_dyn::(array), + (Decimal(_, _), Float64) => decimal_to_float_dyn::(array), + (Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s), + // end numeric casts + + // temporal casts + (Int32, Date32) => primitive_to_same_primitive_dyn::(array, to_type), + (Int32, Time32(TimeUnit::Second)) => primitive_dyn!(array, int32_to_time32s), + (Int32, Time32(TimeUnit::Millisecond)) => primitive_dyn!(array, int32_to_time32ms), + // No support for microsecond/nanosecond with i32 + (Date32, Int32) => primitive_to_same_primitive_dyn::(array, to_type), + (Date32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Time32(_), Int32) => primitive_to_same_primitive_dyn::(array, to_type), + (Int64, Date64) => primitive_to_same_primitive_dyn::(array, to_type), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => primitive_dyn!(array, int64_to_time64us), + (Int64, Time64(TimeUnit::Nanosecond)) => primitive_dyn!(array, int64_to_time64ns), + + (Date64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Date64, Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Time64(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Date32, Date64) => primitive_dyn!(array, date32_to_date64), + (Date64, Date32) => primitive_dyn!(array, date64_to_date32), + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => { + primitive_dyn!(array, time32s_to_time32ms) + }, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => { + primitive_dyn!(array, time32ms_to_time32s) + }, + (Time32(from_unit), Time64(to_unit)) => { + primitive_dyn!(array, time32_to_time64, *from_unit, *to_unit) + }, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => { + primitive_dyn!(array, time64us_to_time64ns) + }, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => { + primitive_dyn!(array, time64ns_to_time64us) + }, + (Time64(from_unit), Time32(to_unit)) => { + primitive_dyn!(array, time64_to_time32, *from_unit, *to_unit) + }, + (Timestamp(_, _), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Int64, Timestamp(_, _)) => primitive_to_same_primitive_dyn::(array, to_type), + (Timestamp(from_unit, _), Timestamp(to_unit, tz)) => { + primitive_dyn!(array, timestamp_to_timestamp, *from_unit, *to_unit, tz) + }, + (Timestamp(from_unit, _), Date32) => primitive_dyn!(array, timestamp_to_date32, *from_unit), + (Timestamp(from_unit, _), Date64) => primitive_dyn!(array, timestamp_to_date64, *from_unit), + + (Int64, Duration(_)) => primitive_to_same_primitive_dyn::(array, to_type), + (Duration(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + + // Not supported by Polars. + // (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::MonthDayNano)) => { + // primitive_dyn!(array, days_ms_to_months_days_ns) + // }, + // (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::MonthDayNano)) => { + // primitive_dyn!(array, months_to_months_days_ns) + // }, + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + } +} + +/// Attempts to encode an array into an `ArrayDictionary` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +fn cast_to_dictionary( + array: &dyn Array, + dict_value_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> { + let array = cast(array, dict_value_type, options)?; + let array = array.as_ref(); + match *dict_value_type { + ArrowDataType::Int8 => primitive_to_dictionary_dyn::(array), + ArrowDataType::Int16 => primitive_to_dictionary_dyn::(array), + ArrowDataType::Int32 => primitive_to_dictionary_dyn::(array), + ArrowDataType::Int64 => primitive_to_dictionary_dyn::(array), + ArrowDataType::UInt8 => primitive_to_dictionary_dyn::(array), + ArrowDataType::UInt16 => primitive_to_dictionary_dyn::(array), + ArrowDataType::UInt32 => primitive_to_dictionary_dyn::(array), + ArrowDataType::UInt64 => primitive_to_dictionary_dyn::(array), + ArrowDataType::BinaryView => { + binview_to_dictionary::(array.as_any().downcast_ref().unwrap()) + .map(|arr| arr.boxed()) + }, + ArrowDataType::Utf8View => { + utf8view_to_dictionary::(array.as_any().downcast_ref().unwrap()) + .map(|arr| arr.boxed()) + }, + ArrowDataType::LargeUtf8 => utf8_to_dictionary_dyn::(array), + ArrowDataType::LargeBinary => binary_to_dictionary_dyn::(array), + ArrowDataType::Time64(_) => primitive_to_dictionary_dyn::(array), + ArrowDataType::Timestamp(_, _) => primitive_to_dictionary_dyn::(array), + ArrowDataType::Date32 => primitive_to_dictionary_dyn::(array), + _ => polars_bail!(ComputeError: + "unsupported output type for dictionary packing: {dict_value_type:?}" + ), + } +} + +fn from_to_binview( + array: &dyn Array, + from_type: &ArrowDataType, + to_type: &ArrowDataType, +) -> PolarsResult { + use ArrowDataType::*; + let binview = match from_type { + UInt8 => primitive_to_binview_dyn::(array), + UInt16 => primitive_to_binview_dyn::(array), + UInt32 => primitive_to_binview_dyn::(array), + UInt64 => primitive_to_binview_dyn::(array), + Int8 => primitive_to_binview_dyn::(array), + Int16 => primitive_to_binview_dyn::(array), + Int32 => primitive_to_binview_dyn::(array), + Int64 => primitive_to_binview_dyn::(array), + Int128 => primitive_to_binview_dyn::(array), + Float32 => primitive_to_binview_dyn::(array), + Float64 => primitive_to_binview_dyn::(array), + Binary => binary_to_binview::(array.as_any().downcast_ref().unwrap()), + FixedSizeBinary(_) => fixed_size_binary_to_binview(array.as_any().downcast_ref().unwrap()), + LargeBinary => binary_to_binview::(array.as_any().downcast_ref().unwrap()), + _ => polars_bail!(InvalidOperation: + "casting from {from_type:?} to {to_type:?} not supported", + ), + }; + Ok(binview) +} diff --git a/crates/polars-compute/src/cast/primitive_to.rs b/crates/polars-compute/src/cast/primitive_to.rs new file mode 100644 index 000000000000..5f973ddfbd14 --- /dev/null +++ b/crates/polars-compute/src/cast/primitive_to.rs @@ -0,0 +1,556 @@ +use std::hash::Hash; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::compute::arity::unary; +use arrow::datatypes::{ArrowDataType, TimeUnit}; +use arrow::offset::{Offset, Offsets}; +use arrow::types::{NativeType, f16}; +use num_traits::{AsPrimitive, Float, ToPrimitive}; +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; + +use super::CastOptionsImpl; +use super::temporal::*; + +pub trait SerPrimitive { + fn write(f: &mut Vec, val: Self) -> usize + where + Self: Sized; +} + +macro_rules! impl_ser_primitive { + ($ptype:ident) => { + impl SerPrimitive for $ptype { + fn write(f: &mut Vec, val: Self) -> usize + where + Self: Sized, + { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()); + value.len() + } + } + }; +} + +impl_ser_primitive!(i8); +impl_ser_primitive!(i16); +impl_ser_primitive!(i32); +impl_ser_primitive!(i64); +impl_ser_primitive!(i128); +impl_ser_primitive!(u8); +impl_ser_primitive!(u16); +impl_ser_primitive!(u32); +impl_ser_primitive!(u64); + +impl SerPrimitive for f32 { + fn write(f: &mut Vec, val: Self) -> usize + where + Self: Sized, + { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()); + value.len() + } +} + +impl SerPrimitive for f64 { + fn write(f: &mut Vec, val: Self) -> usize + where + Self: Sized, + { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()); + value.len() + } +} + +fn primitive_to_values_and_offsets( + from: &PrimitiveArray, +) -> (Vec, Offsets) { + let mut values: Vec = Vec::with_capacity(from.len()); + let mut offsets: Vec = Vec::with_capacity(from.len() + 1); + offsets.push(O::default()); + + let mut offset: usize = 0; + + unsafe { + for &x in from.values().iter() { + let len = T::write(&mut values, x); + + offset += len; + offsets.push(O::from_as_usize(offset)); + } + values.set_len(offset); + values.shrink_to_fit(); + // SAFETY: offsets _are_ monotonically increasing + let offsets = Offsets::new_unchecked(offsets); + + (values, offsets) + } +} + +/// Returns a [`BooleanArray`] where every element is different from zero. +/// Validity is preserved. +pub fn primitive_to_boolean( + from: &PrimitiveArray, + to_type: ArrowDataType, +) -> BooleanArray { + let iter = from.values().iter().map(|v| *v != T::default()); + let values = Bitmap::from_trusted_len_iter(iter); + + BooleanArray::new(to_type, values, from.validity().cloned()) +} + +pub(super) fn primitive_to_boolean_dyn( + from: &dyn Array, + to_type: ArrowDataType, +) -> PolarsResult> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_boolean::(from, to_type))) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number. +pub(super) fn primitive_to_utf8( + from: &PrimitiveArray, +) -> Utf8Array { + let (values, offsets) = primitive_to_values_and_offsets(from); + unsafe { + Utf8Array::::new_unchecked( + Utf8Array::::default_dtype(), + offsets.into(), + values.into(), + from.validity().cloned(), + ) + } +} + +pub(super) fn primitive_to_utf8_dyn(from: &dyn Array) -> PolarsResult> +where + O: Offset, + T: NativeType + SerPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_utf8::(from))) +} + +pub(super) fn primitive_to_primitive_dyn( + from: &dyn Array, + to_type: &ArrowDataType, + options: CastOptionsImpl, +) -> PolarsResult> +where + I: NativeType + num_traits::NumCast + num_traits::AsPrimitive, + O: NativeType + num_traits::NumCast, +{ + let from = from.as_any().downcast_ref::>().unwrap(); + if options.wrapped { + Ok(Box::new(primitive_as_primitive::(from, to_type))) + } else { + Ok(Box::new(primitive_to_primitive::(from, to_type))) + } +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of another physical type via numeric conversion. +pub fn primitive_to_primitive( + from: &PrimitiveArray, + to_type: &ArrowDataType, +) -> PrimitiveArray +where + I: NativeType + num_traits::NumCast, + O: NativeType + num_traits::NumCast, +{ + let iter = from + .iter() + .map(|v| v.and_then(|x| num_traits::cast::cast::(*x))); + PrimitiveArray::::from_trusted_len_iter(iter).to(to_type.clone()) +} + +/// Returns a [`PrimitiveArray`] with the cast values. Values are `None` on overflow +pub fn integer_to_decimal>( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let multiplier = 10_i128.pow(to_scale as u32); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + x.as_().checked_mul(multiplier).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(ArrowDataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn integer_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> PolarsResult> +where + T: NativeType + AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(integer_to_decimal::(from, precision, scale))) +} + +/// Returns a [`PrimitiveArray`] with the cast values. Values are `None` on overflow +pub fn float_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray +where + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, +{ + // 1.2 => 12 + let multiplier: T = (10_f64).powi(to_scale as i32).as_(); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + let x = (*x * multiplier).to_i128()?; + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(ArrowDataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn float_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> PolarsResult> +where + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(float_to_decimal::(from, precision, scale))) +} + +/// Cast [`PrimitiveArray`] as a [`PrimitiveArray`] +/// Same as `number as to_number_type` in rust +pub fn primitive_as_primitive( + from: &PrimitiveArray, + to_type: &ArrowDataType, +) -> PrimitiveArray +where + I: NativeType + num_traits::AsPrimitive, + O: NativeType, +{ + unary(from, num_traits::AsPrimitive::::as_, to_type.clone()) +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. +/// This is O(1). +pub fn primitive_to_same_primitive( + from: &PrimitiveArray, + to_type: &ArrowDataType, +) -> PrimitiveArray +where + T: NativeType, +{ + PrimitiveArray::::new( + to_type.clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. +/// This is O(1). +pub(super) fn primitive_to_same_primitive_dyn( + from: &dyn Array, + to_type: &ArrowDataType, +) -> PolarsResult> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_same_primitive::(from, to_type))) +} + +pub(super) fn primitive_to_dictionary_dyn( + from: &dyn Array, +) -> PolarsResult> { + let from = from.as_any().downcast_ref().unwrap(); + primitive_to_dictionary::(from).map(|x| Box::new(x) as Box) +} + +/// Cast [`PrimitiveArray`] to [`DictionaryArray`]. Also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub fn primitive_to_dictionary( + from: &PrimitiveArray, +) -> PolarsResult> { + let iter = from.iter().map(|x| x.copied()); + let mut array = MutableDictionaryArray::::try_empty(MutablePrimitiveArray::::from( + from.dtype().clone(), + ))?; + array.reserve(from.len()); + array.try_extend(iter)?; + + Ok(array.into()) +} + +/// # Safety +/// +/// `dtype` should be valid for primitive. +pub unsafe fn primitive_map_is_valid( + from: &PrimitiveArray, + f: impl Fn(T) -> bool, + dtype: ArrowDataType, +) -> PrimitiveArray { + let values = from.values().clone(); + + let validity: Bitmap = values.iter().map(|&v| f(v)).collect(); + + let validity = if validity.unset_bits() > 0 { + let new_validity = match from.validity() { + None => validity, + Some(v) => v & &validity, + }; + + Some(new_validity) + } else { + from.validity().cloned() + }; + + // SAFETY: + // - Validity did not change length + // - dtype should be valid + unsafe { PrimitiveArray::new_unchecked(dtype, values, validity) } +} + +/// Conversion of `Int32` to `Time32(TimeUnit::Second)` +pub fn int32_to_time32s(from: &PrimitiveArray) -> PrimitiveArray { + // SAFETY: Time32(TimeUnit::Second) is valid for Int32 + unsafe { + primitive_map_is_valid( + from, + |v| (0..SECONDS_IN_DAY as i32).contains(&v), + ArrowDataType::Time32(TimeUnit::Second), + ) + } +} + +/// Conversion of `Int32` to `Time32(TimeUnit::Millisecond)` +pub fn int32_to_time32ms(from: &PrimitiveArray) -> PrimitiveArray { + // SAFETY: Time32(TimeUnit::Millisecond) is valid for Int32 + unsafe { + primitive_map_is_valid( + from, + |v| (0..MILLISECONDS_IN_DAY as i32).contains(&v), + ArrowDataType::Time32(TimeUnit::Millisecond), + ) + } +} + +/// Conversion of `Int64` to `Time32(TimeUnit::Microsecond)` +pub fn int64_to_time64us(from: &PrimitiveArray) -> PrimitiveArray { + // SAFETY: Time64(TimeUnit::Microsecond) is valid for Int64 + unsafe { + primitive_map_is_valid( + from, + |v| (0..MICROSECONDS_IN_DAY).contains(&v), + ArrowDataType::Time32(TimeUnit::Microsecond), + ) + } +} + +/// Conversion of `Int64` to `Time32(TimeUnit::Nanosecond)` +pub fn int64_to_time64ns(from: &PrimitiveArray) -> PrimitiveArray { + // SAFETY: Time64(TimeUnit::Nanosecond) is valid for Int64 + unsafe { + primitive_map_is_valid( + from, + |v| (0..NANOSECONDS_IN_DAY).contains(&v), + ArrowDataType::Time64(TimeUnit::Nanosecond), + ) + } +} + +/// Conversion of dates +pub fn date32_to_date64(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + |x| x as i64 * MILLISECONDS_IN_DAY, + ArrowDataType::Date64, + ) +} + +/// Conversion of dates +pub fn date64_to_date32(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + |x| (x / MILLISECONDS_IN_DAY) as i32, + ArrowDataType::Date32, + ) +} + +/// Conversion of times +pub fn time32s_to_time32ms(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + |x| x * 1000, + ArrowDataType::Time32(TimeUnit::Millisecond), + ) +} + +/// Conversion of times +pub fn time32ms_to_time32s(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x / 1000, ArrowDataType::Time32(TimeUnit::Second)) +} + +/// Conversion of times +pub fn time64us_to_time64ns(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + |x| x * 1000, + ArrowDataType::Time64(TimeUnit::Nanosecond), + ) +} + +/// Conversion of times +pub fn time64ns_to_time64us(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + |x| x / 1000, + ArrowDataType::Time64(TimeUnit::Microsecond), + ) +} + +/// Conversion of timestamp +pub fn timestamp_to_date64(from: &PrimitiveArray, from_unit: TimeUnit) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = MILLISECONDS; + let to_type = ArrowDataType::Date64; + + // Scale time_array by (to_size / from_size) using a + // single integer operation, but need to avoid integer + // math rounding down to zero + + match to_size.cmp(&from_size) { + std::cmp::Ordering::Less => unary(from, |x| (x / (from_size / to_size)), to_type), + std::cmp::Ordering::Equal => primitive_to_same_primitive(from, &to_type), + std::cmp::Ordering::Greater => unary(from, |x| (x * (to_size / from_size)), to_type), + } +} + +/// Conversion of timestamp +pub fn timestamp_to_date32(from: &PrimitiveArray, from_unit: TimeUnit) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; + unary(from, |x| (x / from_size) as i32, ArrowDataType::Date32) +} + +/// Conversion of time +pub fn time32_to_time64( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let divisor = to_size / from_size; + unary( + from, + |x| (x as i64 * divisor), + ArrowDataType::Time64(to_unit), + ) +} + +/// Conversion of time +pub fn time64_to_time32( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let divisor = from_size / to_size; + unary( + from, + |x| (x / divisor) as i32, + ArrowDataType::Time32(to_unit), + ) +} + +/// Conversion of timestamp +pub fn timestamp_to_timestamp( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, + tz: &Option, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let to_type = ArrowDataType::Timestamp(to_unit, tz.clone()); + // we either divide or multiply, depending on size of each unit + if from_size >= to_size { + unary(from, |x| (x / (from_size / to_size)), to_type) + } else { + unary(from, |x| (x * (to_size / from_size)), to_type) + } +} + +/// Casts f16 into f32 +pub fn f16_to_f32(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x.to_f32(), ArrowDataType::Float32) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number. +pub(super) fn primitive_to_binview( + from: &PrimitiveArray, +) -> BinaryViewArray { + let mut mutable = MutableBinaryViewArray::with_capacity(from.len()); + + let mut scratch = vec![]; + for &x in from.values().iter() { + unsafe { scratch.set_len(0) }; + T::write(&mut scratch, x); + mutable.push_value_ignore_validity(&scratch) + } + + mutable.freeze().with_validity(from.validity().cloned()) +} + +pub(super) fn primitive_to_binview_dyn(from: &dyn Array) -> BinaryViewArray +where + T: NativeType + SerPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + primitive_to_binview::(from) +} diff --git a/crates/polars-compute/src/cast/temporal.rs b/crates/polars-compute/src/cast/temporal.rs new file mode 100644 index 000000000000..442c1cd5e9a8 --- /dev/null +++ b/crates/polars-compute/src/cast/temporal.rs @@ -0,0 +1,142 @@ +use arrow::array::{PrimitiveArray, Utf8ViewArray}; +use arrow::datatypes::{ArrowDataType, TimeUnit}; +pub use arrow::temporal_conversions::{ + EPOCH_DAYS_FROM_CE, MICROSECONDS, MICROSECONDS_IN_DAY, MILLISECONDS, MILLISECONDS_IN_DAY, + NANOSECONDS, NANOSECONDS_IN_DAY, SECONDS_IN_DAY, +}; +use arrow::temporal_conversions::{parse_offset, parse_offset_tz}; +use chrono::format::{Parsed, StrftimeItems}; +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; + +/// Get the time unit as a multiple of a second +pub const fn time_unit_multiple(unit: TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +fn chrono_tz_utf_to_timestamp( + array: &Utf8ViewArray, + fmt: &str, + time_zone: PlSmallStr, + time_unit: TimeUnit, +) -> PolarsResult> { + let tz = parse_offset_tz(time_zone.as_str())?; + Ok(utf8view_to_timestamp_impl( + array, fmt, time_zone, tz, time_unit, + )) +} + +fn utf8view_to_timestamp_impl( + array: &Utf8ViewArray, + fmt: &str, + time_zone: PlSmallStr, + tz: T, + time_unit: TimeUnit, +) -> PrimitiveArray { + let iter = array + .iter() + .map(|x| x.and_then(|x| utf8_to_timestamp_scalar(x, fmt, &tz, &time_unit))); + + PrimitiveArray::from_trusted_len_iter(iter) + .to(ArrowDataType::Timestamp(time_unit, Some(time_zone))) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// +/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). +/// Returns in scale `tz` of `TimeUnit`. +#[inline] +pub fn utf8_to_timestamp_scalar( + value: &str, + fmt: &str, + tz: &T, + tu: &TimeUnit, +) -> Option { + let mut parsed = Parsed::new(); + let fmt = StrftimeItems::new(fmt); + let r = chrono::format::parse(&mut parsed, value, fmt).ok(); + if r.is_some() { + parsed + .to_datetime() + .map(|x| x.naive_utc()) + .map(|x| tz.from_utc_datetime(&x)) + .map(|x| match tu { + TimeUnit::Second => x.timestamp(), + TimeUnit::Millisecond => x.timestamp_millis(), + TimeUnit::Microsecond => x.timestamp_micros(), + TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + }) + .ok() + } else { + None + } +} + +/// Parses a [`Utf8Array`] to a timeozone-aware timestamp, i.e. [`PrimitiveArray`] with type `Timestamp(Nanosecond, Some(timezone))`. +/// +/// # Implementation +/// +/// * parsed values with timezone other than `timezone` are converted to `timezone`. +/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp`] to parse naive timezones. +/// * Null elements remain null; non-parsable elements are null. +/// +/// The feature `"chrono-tz"` enables IANA and zoneinfo formats for `timezone`. +/// +/// # Error +/// +/// This function errors iff `timezone` is not parsable to an offset. +pub(crate) fn utf8view_to_timestamp( + array: &Utf8ViewArray, + fmt: &str, + time_zone: PlSmallStr, + time_unit: TimeUnit, +) -> PolarsResult> { + let tz = parse_offset(time_zone.as_str()); + + if let Ok(tz) = tz { + Ok(utf8view_to_timestamp_impl( + array, fmt, time_zone, tz, time_unit, + )) + } else { + chrono_tz_utf_to_timestamp(array, fmt, time_zone, time_unit) + } +} + +/// Parses a [`Utf8Array`] to naive timestamp, i.e. +/// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`. +/// Timezones are ignored. +/// Null elements remain null; non-parsable elements are set to null. +pub(crate) fn utf8view_to_naive_timestamp( + array: &Utf8ViewArray, + fmt: &str, + time_unit: TimeUnit, +) -> PrimitiveArray { + let iter = array + .iter() + .map(|x| x.and_then(|x| utf8_to_naive_timestamp_scalar(x, fmt, &time_unit))); + + PrimitiveArray::from_trusted_len_iter(iter).to(ArrowDataType::Timestamp(time_unit, None)) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. +/// Returns in scale `tz` of `TimeUnit`. +#[inline] +pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> Option { + let fmt = StrftimeItems::new(fmt); + let mut parsed = Parsed::new(); + chrono::format::parse(&mut parsed, value, fmt.clone()).ok(); + parsed + .to_naive_datetime_with_offset(0) + .map(|x| match tu { + TimeUnit::Second => x.and_utc().timestamp(), + TimeUnit::Millisecond => x.and_utc().timestamp_millis(), + TimeUnit::Microsecond => x.and_utc().timestamp_micros(), + TimeUnit::Nanosecond => x.and_utc().timestamp_nanos_opt().unwrap(), + }) + .ok() +} diff --git a/crates/polars-compute/src/cast/utf8_to.rs b/crates/polars-compute/src/cast/utf8_to.rs new file mode 100644 index 000000000000..7278d16236f4 --- /dev/null +++ b/crates/polars-compute/src/cast/utf8_to.rs @@ -0,0 +1,199 @@ +use std::sync::Arc; + +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::Offset; +use arrow::types::NativeType; +use polars_error::PolarsResult; +use polars_utils::vec::PushUnchecked; + +pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; + +pub(super) fn utf8_to_dictionary_dyn( + from: &dyn Array, +) -> PolarsResult> { + let values = from.as_any().downcast_ref().unwrap(); + utf8_to_dictionary::(values).map(|x| Box::new(x) as Box) +} + +/// Cast [`Utf8Array`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub fn utf8_to_dictionary( + from: &Utf8Array, +) -> PolarsResult> { + let mut array = MutableDictionaryArray::>::new(); + array.reserve(from.len()); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +/// Conversion of utf8 +pub fn utf8_to_large_utf8(from: &Utf8Array) -> Utf8Array { + let dtype = Utf8Array::::default_dtype(); + let validity = from.validity().cloned(); + let values = from.values().clone(); + + let offsets = from.offsets().into(); + // SAFETY: sound because `values` fulfills the same invariants as `from.values()` + unsafe { Utf8Array::::new_unchecked(dtype, offsets, values, validity) } +} + +/// Conversion of utf8 +pub fn utf8_large_to_utf8(from: &Utf8Array) -> PolarsResult> { + let dtype = Utf8Array::::default_dtype(); + let validity = from.validity().cloned(); + let values = from.values().clone(); + let offsets = from.offsets().try_into()?; + + // SAFETY: sound because `values` fulfills the same invariants as `from.values()` + Ok(unsafe { Utf8Array::::new_unchecked(dtype, offsets, values, validity) }) +} + +/// Conversion to binary +pub fn utf8_to_binary(from: &Utf8Array, to_dtype: ArrowDataType) -> BinaryArray { + // SAFETY: erasure of an invariant is always safe + BinaryArray::::new( + to_dtype, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +// Different types to test the overflow path. +#[cfg(not(test))] +type OffsetType = u32; + +// To trigger overflow +#[cfg(test)] +type OffsetType = i8; + +// If we don't do this the GC of binview will trigger. As we will split up buffers into multiple +// chunks so that we don't overflow the offset u32. +fn truncate_buffer(buf: &Buffer) -> Buffer { + // * 2, as it must be able to hold u32::MAX offset + u32::MAX len. + buf.clone().sliced( + 0, + std::cmp::min(buf.len(), ((OffsetType::MAX as u64) * 2) as usize), + ) +} + +pub fn binary_to_binview(arr: &BinaryArray) -> BinaryViewArray { + // Ensure we didn't accidentally set wrong type + #[cfg(not(debug_assertions))] + let _ = std::mem::transmute::; + + let mut views = Vec::with_capacity(arr.len()); + let mut uses_buffer = false; + + let mut base_buffer = arr.values().clone(); + // Offset into the buffer + let mut base_ptr = base_buffer.as_ptr() as usize; + + // Offset into the binview buffers + let mut buffer_idx = 0_u32; + + // Binview buffers + // Note that the buffer may look far further than u32::MAX, but as we don't clone data + let mut buffers = vec![truncate_buffer(&base_buffer)]; + + for bytes in arr.values_iter() { + let len: u32 = bytes + .len() + .try_into() + .expect("max string/binary length exceeded"); + + let mut payload = [0; 16]; + payload[0..4].copy_from_slice(&len.to_le_bytes()); + + if len <= 12 { + payload[4..4 + bytes.len()].copy_from_slice(bytes); + } else { + uses_buffer = true; + + // Copy the parts we know are correct. + unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked(0..4)) }; + payload[0..4].copy_from_slice(&len.to_le_bytes()); + + let current_bytes_ptr = bytes.as_ptr() as usize; + let offset = current_bytes_ptr - base_ptr; + + // Here we check the overflow of the buffer offset. + if let Ok(offset) = OffsetType::try_from(offset) { + #[allow(clippy::unnecessary_cast)] + let offset = offset as u32; + payload[12..16].copy_from_slice(&offset.to_le_bytes()); + payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); + } else { + let len = base_buffer.len() - offset; + + // Set new buffer + base_buffer = base_buffer.clone().sliced(offset, len); + base_ptr = base_buffer.as_ptr() as usize; + + // And add the (truncated) one to the buffers + buffers.push(truncate_buffer(&base_buffer)); + buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded"); + + let offset = 0u32; + payload[12..16].copy_from_slice(&offset.to_le_bytes()); + payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); + } + } + + let value = View::from_le_bytes(payload); + unsafe { views.push_unchecked(value) }; + } + let buffers = if uses_buffer { + Arc::from(buffers) + } else { + Arc::from([]) + }; + unsafe { + BinaryViewArray::new_unchecked_unknown_md( + ArrowDataType::BinaryView, + views.into(), + buffers, + arr.validity().cloned(), + None, + ) + } +} + +pub fn utf8_to_utf8view(arr: &Utf8Array) -> Utf8ViewArray { + unsafe { binary_to_binview(&arr.to_binary()).to_utf8view_unchecked() } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn overflowing_utf8_to_binview() { + let values = [ + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (offset) + "123", // inline + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74 + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer) + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74 + "234", // inline + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer) + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74 + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer) + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74 + "324", // inline + ]; + let array = Utf8Array::::from_slice(values); + + let out = utf8_to_utf8view(&array); + // Ensure we hit the multiple buffers part. + assert_eq!(out.data_buffers().len(), 4); + // Ensure we created a valid binview + let out = out.values_iter().collect::>(); + assert_eq!(out, values); + } +} diff --git a/crates/polars-compute/src/comparisons/array.rs b/crates/polars-compute/src/comparisons/array.rs new file mode 100644 index 000000000000..870c8ed9c1e2 --- /dev/null +++ b/crates/polars-compute/src/comparisons/array.rs @@ -0,0 +1,255 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, + Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::count_zeros; +use arrow::datatypes::ArrowDataType; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::{days_ms, f16, i256, months_days_ns}; + +use super::TotalEqKernel; +use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; + +/// Condenses a bitmap of n * width elements into one with n elements. +/// +/// For each block of width bits a zero count is done. The block of bits is then +/// replaced with a single bit: the result of true_zero_count(zero_count). +fn agg_array_bitmap(bm: Bitmap, width: usize, true_zero_count: F) -> Bitmap +where + F: Fn(usize) -> bool, +{ + if bm.len() == 1 { + bm + } else { + assert!(width > 0 && bm.len() % width == 0); + + let (slice, offset, _len) = bm.as_slice(); + (0..bm.len() / width) + .map(|i| true_zero_count(count_zeros(slice, offset + i * width, width))) + .collect() + } +} + +impl TotalEqKernel for FixedSizeListArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + // Nested comparison always done with eq_missing, propagating doesn't + // make any sense. + + assert_eq!(self.len(), other.len()); + let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_logical_type() + else { + panic!("array comparison called with non-array type"); + }; + let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_logical_type() + else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.dtype(), other_type.dtype()); + + if self_width != other_width { + return Bitmap::new_with_value(false, self.len()); + } + + if *self_width == 0 { + return Bitmap::new_with_value(true, self.len()); + } + + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. + let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref()); + + agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_logical_type() + else { + panic!("array comparison called with non-array type"); + }; + let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_logical_type() + else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.dtype(), other_type.dtype()); + + if self_width != other_width { + return Bitmap::new_with_value(true, self.len()); + } + + if *self_width == 0 { + return Bitmap::new_with_value(false, self.len()); + } + + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. + let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref()); + + agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size()) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_logical_type() else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.dtype(), other.dtype().to_logical_type()); + + let width = *width; + + if width != other.len() { + return Bitmap::new_with_value(false, self.len()); + } + + if width == 0 { + return Bitmap::new_with_value(true, self.len()); + } + + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. + array_fsl_tot_eq_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_logical_type() else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.dtype(), other.dtype().to_logical_type()); + + let width = *width; + + if width != other.len() { + return Bitmap::new_with_value(true, self.len()); + } + + if width == 0 { + return Bitmap::new_with_value(false, self.len()); + } + + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. + array_fsl_tot_ne_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width) + } +} + +macro_rules! compare { + ($lhs:expr, $rhs:expr, $length:expr, $width:expr, $op:path, $true_op:expr) => {{ + let lhs = $lhs; + let rhs = $rhs; + + macro_rules! call_binary { + ($T:ty) => {{ + let values: &$T = $lhs.as_any().downcast_ref().unwrap(); + let scalar: &$T = $rhs.as_any().downcast_ref().unwrap(); + + (0..$length) + .map(move |i| { + // @TODO: I feel like there is a better way to do this. + let mut values: $T = values.clone(); + <$T>::slice(&mut values, i * $width, $width); + + $true_op($op(&values, scalar)) + }) + .collect_trusted() + }}; + } + + assert_eq!(lhs.dtype(), rhs.dtype()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int128) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} + +fn array_fsl_tot_eq_missing_kernel( + values: &dyn Array, + scalar: &dyn Array, + length: usize, + width: usize, +) -> Bitmap { + // @NOTE: Zero-Width Array are handled before + debug_assert_eq!(values.len(), length * width); + debug_assert_eq!(scalar.len(), width); + + compare!( + values, + scalar, + length, + width, + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0 + ) +} + +fn array_fsl_tot_ne_missing_kernel( + values: &dyn Array, + scalar: &dyn Array, + length: usize, + width: usize, +) -> Bitmap { + // @NOTE: Zero-Width Array are handled before + debug_assert_eq!(values.len(), length * width); + debug_assert_eq!(scalar.len(), width); + + compare!( + values, + scalar, + length, + width, + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0 + ) +} diff --git a/crates/polars-compute/src/comparisons/binary.rs b/crates/polars-compute/src/comparisons/binary.rs new file mode 100644 index 000000000000..656057486aba --- /dev/null +++ b/crates/polars-compute/src/comparisons/binary.rs @@ -0,0 +1,114 @@ +use arrow::array::{BinaryArray, FixedSizeBinaryArray}; +use arrow::bitmap::Bitmap; +use arrow::types::Offset; +use polars_utils::total_ord::{TotalEq, TotalOrd}; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for BinaryArray { + type Scalar = [u8]; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_eq(&r)) + .collect() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_ne(&r)) + .collect() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_eq(&other)).collect() + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_ne(&other)).collect() + } +} + +impl TotalOrdKernel for BinaryArray { + type Scalar = [u8]; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_lt(&r)) + .collect() + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_le(&r)) + .collect() + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_lt(&other)).collect() + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_le(&other)).collect() + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_gt(&other)).collect() + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_ge(&other)).collect() + } +} + +impl TotalEqKernel for FixedSizeBinaryArray { + type Scalar = [u8]; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + + if self.size() != other.size() { + return Bitmap::new_zeroed(self.len()); + } + + (0..self.len()) + .map(|i| self.value(i) == other.value(i)) + .collect() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + + if self.size() != other.size() { + return Bitmap::new_with_value(true, self.len()); + } + + (0..self.len()) + .map(|i| self.value(i) != other.value(i)) + .collect() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if self.size() != other.len() { + return Bitmap::new_zeroed(self.len()); + } + + (0..self.len()).map(|i| self.value(i) == other).collect() + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if self.size() != other.len() { + return Bitmap::new_with_value(true, self.len()); + } + + (0..self.len()).map(|i| self.value(i) != other).collect() + } +} diff --git a/crates/polars-compute/src/comparisons/boolean.rs b/crates/polars-compute/src/comparisons/boolean.rs new file mode 100644 index 000000000000..39a8f9b3814a --- /dev/null +++ b/crates/polars-compute/src/comparisons/boolean.rs @@ -0,0 +1,72 @@ +use arrow::array::BooleanArray; +use arrow::bitmap::{self, Bitmap}; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for BooleanArray { + type Scalar = bool; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + bitmap::binary(self.values(), other.values(), |l, r| !(l ^ r)) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + self.values() ^ other.values() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + self.values().clone() + } else { + !self.values() + } + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.tot_eq_kernel_broadcast(&!*other) + } +} + +impl TotalOrdKernel for BooleanArray { + type Scalar = bool; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + bitmap::binary(self.values(), other.values(), |l, r| !l & r) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + bitmap::binary(self.values(), other.values(), |l, r| !l | r) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + !self.values() + } else { + Bitmap::new_zeroed(self.len()) + } + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + Bitmap::new_with_value(true, self.len()) + } else { + !self.values() + } + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + Bitmap::new_zeroed(self.len()) + } else { + self.values().clone() + } + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + self.values().clone() + } else { + Bitmap::new_with_value(true, self.len()) + } + } +} diff --git a/crates/polars-compute/src/comparisons/dictionary.rs b/crates/polars-compute/src/comparisons/dictionary.rs new file mode 100644 index 000000000000..8619a5048418 --- /dev/null +++ b/crates/polars-compute/src/comparisons/dictionary.rs @@ -0,0 +1,75 @@ +use arrow::array::{Array, DictionaryArray, DictionaryKey}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; + +use super::TotalEqKernel; +use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; + +impl TotalEqKernel for DictionaryArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + + let mut bitmap = BitmapBuilder::with_capacity(self.len()); + + for i in 0..self.len() { + let lval = self.validity().is_none_or(|v| v.get(i).unwrap()); + let rval = other.validity().is_none_or(|v| v.get(i).unwrap()); + + if !lval || !rval { + bitmap.push(true); + continue; + } + + let lkey = self.key_value(i); + let rkey = other.key_value(i); + + let mut lhs_value = self.values().clone(); + lhs_value.slice(lkey, 1); + let mut rhs_value = other.values().clone(); + rhs_value.slice(rkey, 1); + + let result = array_tot_eq_missing_kernel(lhs_value.as_ref(), rhs_value.as_ref()); + bitmap.push(result.unset_bits() == 0); + } + + bitmap.freeze() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + + let mut bitmap = BitmapBuilder::with_capacity(self.len()); + + for i in 0..self.len() { + let lval = self.validity().is_none_or(|v| v.get(i).unwrap()); + let rval = other.validity().is_none_or(|v| v.get(i).unwrap()); + + if !lval || !rval { + bitmap.push(false); + continue; + } + + let lkey = self.key_value(i); + let rkey = other.key_value(i); + + let mut lhs_value = self.values().clone(); + lhs_value.slice(lkey, 1); + let mut rhs_value = other.values().clone(); + rhs_value.slice(rkey, 1); + + let result = array_tot_ne_missing_kernel(lhs_value.as_ref(), rhs_value.as_ref()); + bitmap.push(result.set_bits() > 0); + } + + bitmap.freeze() + } + + fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> arrow::bitmap::Bitmap { + todo!() + } + + fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> arrow::bitmap::Bitmap { + todo!() + } +} diff --git a/crates/polars-compute/src/comparisons/dyn_array.rs b/crates/polars-compute/src/comparisons/dyn_array.rs new file mode 100644 index 000000000000..07fd4bbd9a9d --- /dev/null +++ b/crates/polars-compute/src/comparisons/dyn_array.rs @@ -0,0 +1,86 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::types::{days_ms, f16, i256, months_days_ns}; + +use crate::comparisons::TotalEqKernel; + +macro_rules! call_binary { + ($T:ty, $lhs:expr, $rhs:expr, $op:path) => {{ + let lhs: &$T = $lhs.as_any().downcast_ref().unwrap(); + let rhs: &$T = $rhs.as_any().downcast_ref().unwrap(); + $op(lhs, rhs) + }}; +} + +macro_rules! compare { + ($lhs:expr, $rhs:expr, $op:path) => {{ + let lhs = $lhs; + let rhs = $rhs; + + assert_eq!(lhs.dtype(), rhs.dtype()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray, lhs, rhs, $op), + PH::BinaryView => call_binary!(BinaryViewArray, lhs, rhs, $op), + PH::Utf8View => call_binary!(Utf8ViewArray, lhs, rhs, $op), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray, lhs, rhs, $op) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray, lhs, rhs, $op), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray, lhs, rhs, $op), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray, lhs, rhs, $op), + PH::Binary => call_binary!(BinaryArray, lhs, rhs, $op), + PH::LargeBinary => call_binary!(BinaryArray, lhs, rhs, $op), + PH::Utf8 => call_binary!(Utf8Array, lhs, rhs, $op), + PH::LargeUtf8 => call_binary!(Utf8Array, lhs, rhs, $op), + PH::List => call_binary!(ListArray, lhs, rhs, $op), + PH::LargeList => call_binary!(ListArray, lhs, rhs, $op), + PH::Struct => call_binary!(StructArray, lhs, rhs, $op), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int128) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray, lhs, rhs, $op), + } + }}; +} + +pub fn array_tot_eq_missing_kernel(lhs: &dyn Array, rhs: &dyn Array) -> Bitmap { + compare!(lhs, rhs, TotalEqKernel::tot_eq_missing_kernel) +} + +pub fn array_tot_ne_missing_kernel(lhs: &dyn Array, rhs: &dyn Array) -> Bitmap { + compare!(lhs, rhs, TotalEqKernel::tot_ne_missing_kernel) +} diff --git a/crates/polars-compute/src/comparisons/list.rs b/crates/polars-compute/src/comparisons/list.rs new file mode 100644 index 000000000000..4b2141274ed1 --- /dev/null +++ b/crates/polars-compute/src/comparisons/list.rs @@ -0,0 +1,259 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::{Offset, days_ms, f16, i256, months_days_ns}; + +use super::TotalEqKernel; + +macro_rules! compare { + ( + $lhs:expr, $rhs:expr, + $op:path, $true_op:expr, + $ineq_len_rv:literal, $invalid_rv:literal + ) => {{ + let lhs = $lhs; + let rhs = $rhs; + + assert_eq!(lhs.len(), rhs.len()); + assert_eq!(lhs.dtype(), rhs.dtype()); + + macro_rules! call_binary { + ($T:ty) => {{ + let lhs_values: &$T = $lhs.values().as_any().downcast_ref().unwrap(); + let rhs_values: &$T = $rhs.values().as_any().downcast_ref().unwrap(); + + (0..$lhs.len()) + .map(|i| { + let lval = $lhs.validity().is_none_or(|v| v.get(i).unwrap()); + let rval = $rhs.validity().is_none_or(|v| v.get(i).unwrap()); + + if !lval || !rval { + return $invalid_rv; + } + + // SAFETY: ListArray's invariant offsets.len_proxy() == len + let (lstart, lend) = unsafe { $lhs.offsets().start_end_unchecked(i) }; + let (rstart, rend) = unsafe { $rhs.offsets().start_end_unchecked(i) }; + + if lend - lstart != rend - rstart { + return $ineq_len_rv; + } + + let mut lhs_values = lhs_values.clone(); + lhs_values.slice(lstart, lend - lstart); + let mut rhs_values = rhs_values.clone(); + rhs_values.slice(rstart, rend - rstart); + + $true_op($op(&lhs_values, &rhs_values)) + }) + .collect_trusted() + }}; + } + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.values().dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int128) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} + +macro_rules! compare_broadcast { + ( + $lhs:expr, $rhs:expr, + $offsets:expr, $validity:expr, + $op:path, $true_op:expr, + $ineq_len_rv:literal, $invalid_rv:literal + ) => {{ + let lhs = $lhs; + let rhs = $rhs; + + macro_rules! call_binary { + ($T:ty) => {{ + let values: &$T = $lhs.as_any().downcast_ref().unwrap(); + let scalar: &$T = $rhs.as_any().downcast_ref().unwrap(); + + let length = $offsets.len_proxy(); + + (0..length) + .map(move |i| { + let v = $validity.is_none_or(|v| v.get(i).unwrap()); + + if !v { + return $invalid_rv; + } + + let (start, end) = unsafe { $offsets.start_end_unchecked(i) }; + + if end - start != scalar.len() { + return $ineq_len_rv; + } + + // @TODO: I feel like there is a better way to do this. + let mut values: $T = values.clone(); + <$T>::slice(&mut values, start, end - start); + + $true_op($op(&values, scalar)) + }) + .collect_trusted() + }}; + } + + assert_eq!(lhs.dtype(), rhs.dtype()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int128) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} + +impl TotalEqKernel for ListArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + compare!( + self, + other, + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0, + false, + true + ) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + compare!( + self, + other, + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0, + true, + false + ) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + compare_broadcast!( + self.values().as_ref(), + other.as_ref(), + self.offsets(), + self.validity(), + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0, + false, + true + ) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + compare_broadcast!( + self.values().as_ref(), + other.as_ref(), + self.offsets(), + self.validity(), + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0, + true, + false + ) + } +} diff --git a/crates/polars-compute/src/comparisons/mod.rs b/crates/polars-compute/src/comparisons/mod.rs new file mode 100644 index 000000000000..5c58a9f290b0 --- /dev/null +++ b/crates/polars-compute/src/comparisons/mod.rs @@ -0,0 +1,104 @@ +use arrow::array::Array; +use arrow::bitmap::{self, Bitmap}; + +pub trait TotalEqKernel: Sized + Array { + type Scalar: ?Sized; + + // These kernels ignore validity entirely (results for nulls are unspecified + // but initialized). + fn tot_eq_kernel(&self, other: &Self) -> Bitmap; + fn tot_ne_kernel(&self, other: &Self) -> Bitmap; + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + + // These kernels treat null as any other value equal to itself but unequal + // to anything else. + fn tot_eq_missing_kernel(&self, other: &Self) -> Bitmap { + let q = self.tot_eq_kernel(other); + match (self.validity(), other.validity()) { + (None, None) => q, + (None, Some(r)) => &q & r, + (Some(l), None) => &q & l, + (Some(l), Some(r)) => bitmap::ternary(&q, l, r, |q, l, r| (q & l & r) | !(l | r)), + } + } + + fn tot_ne_missing_kernel(&self, other: &Self) -> Bitmap { + let q = self.tot_ne_kernel(other); + match (self.validity(), other.validity()) { + (None, None) => q, + (None, Some(r)) => &q | &!r, + (Some(l), None) => &q | &!l, + (Some(l), Some(r)) => bitmap::ternary(&q, l, r, |q, l, r| (q & l & r) | (l ^ r)), + } + } + fn tot_eq_missing_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let q = self.tot_eq_kernel_broadcast(other); + if let Some(valid) = self.validity() { + bitmap::binary(&q, valid, |q, v| q & v) + } else { + q + } + } + + fn tot_ne_missing_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let q = self.tot_ne_kernel_broadcast(other); + if let Some(valid) = self.validity() { + bitmap::binary(&q, valid, |q, v| q | !v) + } else { + q + } + } +} + +// Low-level comparison kernel. +pub trait TotalOrdKernel: Sized + Array { + type Scalar: ?Sized; + + // These kernels ignore validity entirely (results for nulls are unspecified + // but initialized). + fn tot_lt_kernel(&self, other: &Self) -> Bitmap; + fn tot_le_kernel(&self, other: &Self) -> Bitmap; + fn tot_gt_kernel(&self, other: &Self) -> Bitmap { + other.tot_lt_kernel(self) + } + fn tot_ge_kernel(&self, other: &Self) -> Bitmap { + other.tot_le_kernel(self) + } + + // These kernels ignore validity entirely (results for nulls are unspecified + // but initialized). + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; +} + +mod binary; +mod boolean; +mod dictionary; +mod dyn_array; +mod list; +mod null; +mod scalar; +mod struct_; +mod utf8; +mod view; + +#[cfg(feature = "simd")] +mod _simd_dtypes { + use arrow::types::{days_ms, f16, i256, months_days_ns}; + + use crate::NotSimdPrimitive; + + impl NotSimdPrimitive for f16 {} + impl NotSimdPrimitive for i256 {} + impl NotSimdPrimitive for days_ms {} + impl NotSimdPrimitive for months_days_ns {} +} + +#[cfg(feature = "simd")] +mod simd; + +#[cfg(feature = "dtype-array")] +mod array; diff --git a/crates/polars-compute/src/comparisons/null.rs b/crates/polars-compute/src/comparisons/null.rs new file mode 100644 index 000000000000..9d6e9e3dcefd --- /dev/null +++ b/crates/polars-compute/src/comparisons/null.rs @@ -0,0 +1,54 @@ +use arrow::array::{Array, NullArray}; +use arrow::bitmap::Bitmap; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for NullArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + Bitmap::new_with_value(true, self.len()) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + Bitmap::new_zeroed(self.len()) + } + + fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() + } + + fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() + } +} + +impl TotalOrdKernel for NullArray { + type Scalar = Box; + + fn tot_lt_kernel(&self, _other: &Self) -> Bitmap { + unimplemented!() + } + + fn tot_le_kernel(&self, _other: &Self) -> Bitmap { + unimplemented!() + } + + fn tot_lt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + unimplemented!() + } + + fn tot_le_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + unimplemented!() + } + + fn tot_gt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + unimplemented!() + } + + fn tot_ge_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + unimplemented!() + } +} diff --git a/crates/polars-compute/src/comparisons/scalar.rs b/crates/polars-compute/src/comparisons/scalar.rs new file mode 100644 index 000000000000..d792503b4225 --- /dev/null +++ b/crates/polars-compute/src/comparisons/scalar.rs @@ -0,0 +1,74 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use polars_utils::total_ord::TotalOrd; + +use super::{TotalEqKernel, TotalOrdKernel}; +use crate::NotSimdPrimitive; + +impl TotalEqKernel for PrimitiveArray { + type Scalar = T; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values() + .iter() + .zip(other.values().iter()) + .map(|(l, r)| l.tot_eq(r)) + .collect() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values() + .iter() + .zip(other.values().iter()) + .map(|(l, r)| l.tot_ne(r)) + .collect() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values().iter().map(|l| l.tot_eq(other)).collect() + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values().iter().map(|l| l.tot_ne(other)).collect() + } +} + +impl TotalOrdKernel for PrimitiveArray { + type Scalar = T; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values() + .iter() + .zip(other.values().iter()) + .map(|(l, r)| l.tot_lt(r)) + .collect() + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values() + .iter() + .zip(other.values().iter()) + .map(|(l, r)| l.tot_le(r)) + .collect() + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values().iter().map(|l| l.tot_lt(other)).collect() + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values().iter().map(|l| l.tot_le(other)).collect() + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values().iter().map(|l| l.tot_gt(other)).collect() + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values().iter().map(|l| l.tot_ge(other)).collect() + } +} diff --git a/crates/polars-compute/src/comparisons/simd.rs b/crates/polars-compute/src/comparisons/simd.rs new file mode 100644 index 000000000000..95ea9707744a --- /dev/null +++ b/crates/polars-compute/src/comparisons/simd.rs @@ -0,0 +1,293 @@ +use std::ptr; +use std::simd::prelude::{Simd, SimdPartialEq, SimdPartialOrd}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use arrow::types::NativeType; +use bytemuck::Pod; + +use super::{TotalEqKernel, TotalOrdKernel}; + +fn apply_binary_kernel( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + mut f: F, +) -> Bitmap +where + T: NativeType, + F: FnMut(&[T; N], &[T; N]) -> M, +{ + assert_eq!(N, size_of::() * 8); + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + + let lhs_buf = lhs.values().as_slice(); + let rhs_buf = rhs.values().as_slice(); + let lhs_chunks = lhs_buf.chunks_exact(N); + let rhs_chunks = rhs_buf.chunks_exact(N); + let lhs_rest = lhs_chunks.remainder(); + let rhs_rest = rhs_chunks.remainder(); + + let num_masks = n.div_ceil(N); + let mut v: Vec = Vec::with_capacity(num_masks * size_of::()); + let mut p = v.as_mut_ptr() as *mut M; + for (l, r) in lhs_chunks.zip(rhs_chunks) { + unsafe { + let mask = f( + l.try_into().unwrap_unchecked(), + r.try_into().unwrap_unchecked(), + ); + p.write_unaligned(mask); + p = p.wrapping_add(1); + } + } + + if n % N > 0 { + let mut l: [T; N] = [T::zeroed(); N]; + let mut r: [T; N] = [T::zeroed(); N]; + unsafe { + ptr::copy_nonoverlapping(lhs_rest.as_ptr(), l.as_mut_ptr(), n % N); + ptr::copy_nonoverlapping(rhs_rest.as_ptr(), r.as_mut_ptr(), n % N); + p.write_unaligned(f(&l, &r)); + } + } + + unsafe { + v.set_len(num_masks * size_of::()); + } + + Bitmap::from_u8_vec(v, n) +} + +fn apply_unary_kernel(arg: &PrimitiveArray, mut f: F) -> Bitmap +where + T: NativeType, + F: FnMut(&[T; N]) -> M, +{ + assert_eq!(N, size_of::() * 8); + let n = arg.len(); + + let arg_buf = arg.values().as_slice(); + let arg_chunks = arg_buf.chunks_exact(N); + let arg_rest = arg_chunks.remainder(); + + let num_masks = n.div_ceil(N); + let mut v: Vec = Vec::with_capacity(num_masks * size_of::()); + let mut p = v.as_mut_ptr() as *mut M; + for a in arg_chunks { + unsafe { + let mask = f(a.try_into().unwrap_unchecked()); + p.write_unaligned(mask); + p = p.wrapping_add(1); + } + } + + if n % N > 0 { + let mut a: [T; N] = [T::zeroed(); N]; + unsafe { + ptr::copy_nonoverlapping(arg_rest.as_ptr(), a.as_mut_ptr(), n % N); + p.write_unaligned(f(&a)); + } + } + + unsafe { + v.set_len(num_masks * size_of::()); + } + + Bitmap::from_u8_vec(v, n) +} + +macro_rules! impl_int_total_ord_kernel { + ($T: ty, $width: literal, $mask: ty) => { + impl TotalEqKernel for PrimitiveArray<$T> { + type Scalar = $T; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + Simd::from(*l).simd_eq(Simd::from(*r)).to_bitmask() as $mask + }) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + Simd::from(*l).simd_ne(Simd::from(*r)).to_bitmask() as $mask + }) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let r = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + Simd::from(*l).simd_eq(r).to_bitmask() as $mask + }) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let r = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + Simd::from(*l).simd_ne(r).to_bitmask() as $mask + }) + } + } + + impl TotalOrdKernel for PrimitiveArray<$T> { + type Scalar = $T; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + Simd::from(*l).simd_lt(Simd::from(*r)).to_bitmask() as $mask + }) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + Simd::from(*l).simd_le(Simd::from(*r)).to_bitmask() as $mask + }) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let r = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + Simd::from(*l).simd_lt(r).to_bitmask() as $mask + }) + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let r = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + Simd::from(*l).simd_le(r).to_bitmask() as $mask + }) + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let r = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + Simd::from(*l).simd_gt(r).to_bitmask() as $mask + }) + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let r = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + Simd::from(*l).simd_ge(r).to_bitmask() as $mask + }) + } + } + }; +} + +macro_rules! impl_float_total_ord_kernel { + ($T: ty, $width: literal, $mask: ty) => { + impl TotalEqKernel for PrimitiveArray<$T> { + type Scalar = $T; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + let ls = Simd::from(*l); + let rs = Simd::from(*r); + let lhs_is_nan = ls.simd_ne(ls); + let rhs_is_nan = rs.simd_ne(rs); + ((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs)).to_bitmask() as $mask + }) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + let ls = Simd::from(*l); + let rs = Simd::from(*r); + let lhs_is_nan = ls.simd_ne(ls); + let rhs_is_nan = rs.simd_ne(rs); + (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask + }) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let rs = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + let ls = Simd::from(*l); + let lhs_is_nan = ls.simd_ne(ls); + let rhs_is_nan = rs.simd_ne(rs); + ((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs)).to_bitmask() as $mask + }) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let rs = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + let ls = Simd::from(*l); + let lhs_is_nan = ls.simd_ne(ls); + let rhs_is_nan = rs.simd_ne(rs); + (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask + }) + } + } + + impl TotalOrdKernel for PrimitiveArray<$T> { + type Scalar = $T; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + let ls = Simd::from(*l); + let rs = Simd::from(*r); + let lhs_is_nan = ls.simd_ne(ls); + (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask + }) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + let ls = Simd::from(*l); + let rs = Simd::from(*r); + let rhs_is_nan = rs.simd_ne(rs); + (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask + }) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let rs = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + let ls = Simd::from(*l); + let lhs_is_nan = ls.simd_ne(ls); + (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask + }) + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let rs = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + let ls = Simd::from(*l); + let rhs_is_nan = rs.simd_ne(rs); + (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask + }) + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let rs = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + let ls = Simd::from(*l); + let rhs_is_nan = rs.simd_ne(rs); + (!(rhs_is_nan | rs.simd_ge(ls))).to_bitmask() as $mask + }) + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let rs = Simd::splat(*other); + apply_unary_kernel::<$width, $mask, _, _>(self, |l| { + let ls = Simd::from(*l); + let lhs_is_nan = ls.simd_ne(ls); + (lhs_is_nan | rs.simd_le(ls)).to_bitmask() as $mask + }) + } + } + }; +} + +impl_int_total_ord_kernel!(u8, 32, u32); +impl_int_total_ord_kernel!(u16, 16, u16); +impl_int_total_ord_kernel!(u32, 8, u8); +impl_int_total_ord_kernel!(u64, 8, u8); +impl_int_total_ord_kernel!(i8, 32, u32); +impl_int_total_ord_kernel!(i16, 16, u16); +impl_int_total_ord_kernel!(i32, 8, u8); +impl_int_total_ord_kernel!(i64, 8, u8); +impl_float_total_ord_kernel!(f32, 8, u8); +impl_float_total_ord_kernel!(f64, 8, u8); diff --git a/crates/polars-compute/src/comparisons/struct_.rs b/crates/polars-compute/src/comparisons/struct_.rs new file mode 100644 index 000000000000..767d502e9af6 --- /dev/null +++ b/crates/polars-compute/src/comparisons/struct_.rs @@ -0,0 +1,108 @@ +use arrow::array::{Array, StructArray}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; + +use super::TotalEqKernel; +use crate::comparisons::dyn_array::array_tot_eq_missing_kernel; + +impl TotalEqKernel for StructArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + let lhs = self; + let rhs = other; + assert_eq!(lhs.len(), rhs.len()); + + if lhs.fields() != rhs.fields() { + return Bitmap::new_zeroed(lhs.len()); + } + + let ln = lhs.validity(); + let rn = rhs.validity(); + + let lv = lhs.values(); + let rv = rhs.values(); + + let mut bitmap = BitmapBuilder::with_capacity(lhs.len()); + + for i in 0..lhs.len() { + let mut is_equal = true; + + if !ln.is_none_or(|v| v.get(i).unwrap()) || !rn.is_none_or(|v| v.get(i).unwrap()) { + bitmap.push(true); + continue; + } + + for j in 0..lhs.values().len() { + if lv[j].len() != rv[j].len() { + is_equal = false; + break; + } + + let result = array_tot_eq_missing_kernel(lv[j].as_ref(), rv[j].as_ref()); + if result.unset_bits() != 0 { + is_equal = false; + break; + } + } + + bitmap.push(is_equal); + } + + bitmap.freeze() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + let lhs = self; + let rhs = other; + + if lhs.fields() != rhs.fields() { + return Bitmap::new_with_value(true, lhs.len()); + } + + if lhs.values().len() != rhs.values().len() { + return Bitmap::new_with_value(true, lhs.len()); + } + + let ln = lhs.validity(); + let rn = rhs.validity(); + + let lv = lhs.values(); + let rv = rhs.values(); + + let mut bitmap = BitmapBuilder::with_capacity(lhs.len()); + + for i in 0..lhs.len() { + let mut is_equal = true; + + if !ln.is_none_or(|v| v.get(i).unwrap()) || !rn.is_none_or(|v| v.get(i).unwrap()) { + bitmap.push(false); + continue; + } + + for j in 0..lhs.values().len() { + if lv[j].len() != rv[j].len() { + is_equal = false; + break; + } + + let result = array_tot_eq_missing_kernel(lv[j].as_ref(), rv[j].as_ref()); + if result.unset_bits() != 0 { + is_equal = false; + break; + } + } + + bitmap.push(!is_equal); + } + + bitmap.freeze() + } + + fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() + } + + fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() + } +} diff --git a/crates/polars-compute/src/comparisons/utf8.rs b/crates/polars-compute/src/comparisons/utf8.rs new file mode 100644 index 000000000000..dbb8aa2cc3fa --- /dev/null +++ b/crates/polars-compute/src/comparisons/utf8.rs @@ -0,0 +1,53 @@ +use arrow::array::Utf8Array; +use arrow::bitmap::Bitmap; +use arrow::types::Offset; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for Utf8Array { + type Scalar = str; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_eq_kernel(&other.to_binary()) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_ne_kernel(&other.to_binary()) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_eq_kernel_broadcast(other.as_bytes()) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_ne_kernel_broadcast(other.as_bytes()) + } +} + +impl TotalOrdKernel for Utf8Array { + type Scalar = str; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_lt_kernel(&other.to_binary()) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_le_kernel(&other.to_binary()) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_lt_kernel_broadcast(other.as_bytes()) + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_le_kernel_broadcast(other.as_bytes()) + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_gt_kernel_broadcast(other.as_bytes()) + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_ge_kernel_broadcast(other.as_bytes()) + } +} diff --git a/crates/polars-compute/src/comparisons/view.rs b/crates/polars-compute/src/comparisons/view.rs new file mode 100644 index 000000000000..c39187e90c60 --- /dev/null +++ b/crates/polars-compute/src/comparisons/view.rs @@ -0,0 +1,253 @@ +use arrow::array::{BinaryViewArray, Utf8ViewArray}; +use arrow::bitmap::Bitmap; + +use super::TotalEqKernel; +use crate::comparisons::TotalOrdKernel; + +// If s fits in 12 bytes, returns the view encoding it would have in a +// BinaryViewArray. +fn small_view_encoding(s: &[u8]) -> Option { + if s.len() > 12 { + return None; + } + + let mut tmp = [0u8; 16]; + tmp[0] = s.len() as u8; + tmp[4..4 + s.len()].copy_from_slice(s); + Some(u128::from_le_bytes(tmp)) +} + +// Loads (up to) the first 4 bytes of s as little-endian, padded with zeros. +fn load_prefix(s: &[u8]) -> u32 { + let start = &s[..s.len().min(4)]; + let mut tmp = [0u8; 4]; + tmp[..start.len()].copy_from_slice(start); + u32::from_le_bytes(tmp) +} + +fn broadcast_inequality( + arr: &BinaryViewArray, + scalar: &[u8], + cmp_prefix: impl Fn(u32, u32) -> bool, + cmp_str: impl Fn(&[u8], &[u8]) -> bool, +) -> Bitmap { + let views = arr.views().as_slice(); + let prefix = load_prefix(scalar); + let be_prefix = prefix.to_be(); + Bitmap::from_trusted_len_iter((0..arr.len()).map(|i| unsafe { + let v_prefix = (views.get_unchecked(i).as_u128() >> 32) as u32; + if v_prefix != prefix { + cmp_prefix(v_prefix.to_be(), be_prefix) + } else { + cmp_str(arr.value_unchecked(i), scalar) + } + })) +} + +impl TotalEqKernel for BinaryViewArray { + type Scalar = [u8]; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + debug_assert!(self.len() == other.len()); + + let slf_views = self.views().as_slice(); + let other_views = other.views().as_slice(); + + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let av = slf_views.get_unchecked(i).as_u128(); + let bv = other_views.get_unchecked(i).as_u128(); + + // First 64 bits contain length and prefix. + let a_len_prefix = av as u64; + let b_len_prefix = bv as u64; + if a_len_prefix != b_len_prefix { + return false; + } + + let alen = av as u32; + if alen <= 12 { + // String is fully inlined, compare top 64 bits. Bottom bits were + // tested equal before, which also ensures the lengths are equal. + (av >> 64) as u64 == (bv >> 64) as u64 + } else { + self.value_unchecked(i) == other.value_unchecked(i) + } + })) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + debug_assert!(self.len() == other.len()); + + let slf_views = self.views().as_slice(); + let other_views = other.views().as_slice(); + + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let av = slf_views.get_unchecked(i).as_u128(); + let bv = other_views.get_unchecked(i).as_u128(); + + // First 64 bits contain length and prefix. + let a_len_prefix = av as u64; + let b_len_prefix = bv as u64; + if a_len_prefix != b_len_prefix { + return true; + } + + let alen = av as u32; + if alen <= 12 { + // String is fully inlined, compare top 64 bits. Bottom bits were + // tested equal before, which also ensures the lengths are equal. + (av >> 64) as u64 != (bv >> 64) as u64 + } else { + self.value_unchecked(i) != other.value_unchecked(i) + } + })) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if let Some(val) = small_view_encoding(other) { + Bitmap::from_trusted_len_iter(self.views().iter().map(|v| v.as_u128() == val)) + } else { + let slf_views = self.views().as_slice(); + let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); + let prefix_len = ((prefix as u64) << 32) | other.len() as u64; + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let v_prefix_len = slf_views.get_unchecked(i).as_u128() as u64; + if v_prefix_len != prefix_len { + false + } else { + self.value_unchecked(i) == other + } + })) + } + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if let Some(val) = small_view_encoding(other) { + Bitmap::from_trusted_len_iter(self.views().iter().map(|v| v.as_u128() != val)) + } else { + let slf_views = self.views().as_slice(); + let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); + let prefix_len = ((prefix as u64) << 32) | other.len() as u64; + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let v_prefix_len = slf_views.get_unchecked(i).as_u128() as u64; + if v_prefix_len != prefix_len { + true + } else { + self.value_unchecked(i) != other + } + })) + } + } +} + +impl TotalOrdKernel for BinaryViewArray { + type Scalar = [u8]; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + debug_assert!(self.len() == other.len()); + + let slf_views = self.views().as_slice(); + let other_views = other.views().as_slice(); + + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let av = slf_views.get_unchecked(i).as_u128(); + let bv = other_views.get_unchecked(i).as_u128(); + + // First 64 bits contain length and prefix. + // Only check prefix. + let a_prefix = (av >> 32) as u32; + let b_prefix = (bv >> 32) as u32; + if a_prefix != b_prefix { + a_prefix.to_be() < b_prefix.to_be() + } else { + self.value_unchecked(i) < other.value_unchecked(i) + } + })) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + debug_assert!(self.len() == other.len()); + + let slf_views = self.views().as_slice(); + let other_views = other.views().as_slice(); + + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let av = slf_views.get_unchecked(i).as_u128(); + let bv = other_views.get_unchecked(i).as_u128(); + + // First 64 bits contain length and prefix. + // Only check prefix. + let a_prefix = (av >> 32) as u32; + let b_prefix = (bv >> 32) as u32; + if a_prefix != b_prefix { + a_prefix.to_be() < b_prefix.to_be() + } else { + self.value_unchecked(i) <= other.value_unchecked(i) + } + })) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + broadcast_inequality(self, other, |a, b| a < b, |a, b| a < b) + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + broadcast_inequality(self, other, |a, b| a <= b, |a, b| a <= b) + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + broadcast_inequality(self, other, |a, b| a > b, |a, b| a > b) + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + broadcast_inequality(self, other, |a, b| a >= b, |a, b| a >= b) + } +} + +impl TotalEqKernel for Utf8ViewArray { + type Scalar = str; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + self.to_binview().tot_eq_kernel(&other.to_binview()) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + self.to_binview().tot_ne_kernel(&other.to_binview()) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binview().tot_eq_kernel_broadcast(other.as_bytes()) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binview().tot_ne_kernel_broadcast(other.as_bytes()) + } +} + +impl TotalOrdKernel for Utf8ViewArray { + type Scalar = str; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + self.to_binview().tot_lt_kernel(&other.to_binview()) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + self.to_binview().tot_le_kernel(&other.to_binview()) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binview().tot_lt_kernel_broadcast(other.as_bytes()) + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binview().tot_le_kernel_broadcast(other.as_bytes()) + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binview().tot_gt_kernel_broadcast(other.as_bytes()) + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binview().tot_ge_kernel_broadcast(other.as_bytes()) + } +} diff --git a/crates/polars-compute/src/filter/avx512.rs b/crates/polars-compute/src/filter/avx512.rs new file mode 100644 index 000000000000..a439f131f60e --- /dev/null +++ b/crates/polars-compute/src/filter/avx512.rs @@ -0,0 +1,115 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use core::arch::x86_64::*; + +// It's not possible to inline target_feature(enable = ...) functions into other +// functions without that enabled, so we use a macro for these very-similarly +// structured functions. +macro_rules! simd_filter { + ($values: ident, $mask_bytes: ident, $out: ident, |$subchunk: ident, $m: ident: $MaskT: ty| $body:block) => {{ + const MASK_BITS: usize = <$MaskT>::BITS as usize; + + // Do a 64-element loop for sparse fast path. + let chunks = $values.chunks_exact(64); + $values = chunks.remainder(); + for chunk in chunks { + let mask_chunk = $mask_bytes.get_unchecked(..8); + $mask_bytes = $mask_bytes.get_unchecked(8..); + let mut m64 = u64::from_le_bytes(mask_chunk.try_into().unwrap()); + + // Fast-path: skip entire 64-element chunk. + if m64 == 0 { + continue; + } + + for $subchunk in chunk.chunks_exact(MASK_BITS) { + let $m = m64 as $MaskT; + $body; + m64 >>= MASK_BITS % 64; + } + } + + // Handle the SIMD-block-sized remainder. + let subchunks = $values.chunks_exact(MASK_BITS); + $values = subchunks.remainder(); + for $subchunk in subchunks { + let mask_chunk = $mask_bytes.get_unchecked(..MASK_BITS / 8); + $mask_bytes = $mask_bytes.get_unchecked(MASK_BITS / 8..); + let $m = <$MaskT>::from_le_bytes(mask_chunk.try_into().unwrap()); + $body; + } + + ($values, $mask_bytes, $out) + }}; +} + +/// # Safety +/// out must be valid for 64 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512_VBMI2 must be enabled. +#[target_feature(enable = "avx512f")] +#[target_feature(enable = "avx512vbmi2")] +pub unsafe fn filter_u8_avx512vbmi2<'a>( + mut values: &'a [u8], + mut mask_bytes: &'a [u8], + mut out: *mut u8, +) -> (&'a [u8], &'a [u8], *mut u8) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u64| { + // We don't use compress-store instructions because they are very slow + // on Zen. We are allowed to overshoot anyway. + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi8(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} + +/// # Safety +/// out must be valid for 32 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512_VBMI2 must be enabled. +#[target_feature(enable = "avx512f")] +#[target_feature(enable = "avx512vbmi2")] +pub unsafe fn filter_u16_avx512vbmi2<'a>( + mut values: &'a [u16], + mut mask_bytes: &'a [u8], + mut out: *mut u16, +) -> (&'a [u16], &'a [u8], *mut u16) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u32| { + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi16(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} + +/// # Safety +/// out must be valid for 16 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512F must be enabled. +#[target_feature(enable = "avx512f")] +pub unsafe fn filter_u32_avx512f<'a>( + mut values: &'a [u32], + mut mask_bytes: &'a [u8], + mut out: *mut u32, +) -> (&'a [u32], &'a [u8], *mut u32) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u16| { + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi32(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} + +/// # Safety +/// out must be valid for 8 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512F must be enabled. +#[target_feature(enable = "avx512f")] +pub unsafe fn filter_u64_avx512f<'a>( + mut values: &'a [u64], + mut mask_bytes: &'a [u8], + mut out: *mut u64, +) -> (&'a [u64], &'a [u8], *mut u64) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u8| { + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi64(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} diff --git a/crates/polars-compute/src/filter/boolean.rs b/crates/polars-compute/src/filter/boolean.rs new file mode 100644 index 000000000000..676744a48fb1 --- /dev/null +++ b/crates/polars-compute/src/filter/boolean.rs @@ -0,0 +1,294 @@ +use arrow::bitmap::Bitmap; +use polars_utils::clmul::prefix_xorsum; + +const U56_MAX: u64 = (1 << 56) - 1; + +fn pext64_polyfill(mut v: u64, mut m: u64, m_popcnt: u32) -> u64 { + // Fast path: popcount is low. + if m_popcnt <= 4 { + // Not a "while m != 0" but a for loop instead so the compiler fully + // unrolls the loop, this makes bit << i much faster. + let mut out = 0; + for i in 0..4 { + if m == 0 { + break; + }; + + let bit = (v >> m.trailing_zeros()) & 1; + out |= bit << i; + m &= m.wrapping_sub(1); // Clear least significant bit. + } + return out; + } + + // Fast path: all the masked bits in v are 0 or 1. + // Despite this fast path being simpler than the above popcount-based one, + // we do it afterwards because if m has a low popcount these branches become + // very unpredictable. + v &= m; + if v == 0 { + return 0; + } else if v == m { + return (1 << m_popcnt) - 1; + } + + // This algorithm is too involved to explain here, see https://github.com/zwegner/zp7. + // That is an optimized version of Hacker's Delight Chapter 7-4, parallel suffix method for compress(). + let mut invm = !m; + + for i in 0..6 { + let shift = 1 << i; + let prefix_count_bit = if i < 5 { + prefix_xorsum(invm) + } else { + invm.wrapping_neg() << 1 + }; + let keep_in_place = v & !prefix_count_bit; + let shift_down = v & prefix_count_bit; + v = keep_in_place | (shift_down >> shift); + invm &= prefix_count_bit; + } + v +} + +pub fn filter_boolean_kernel(values: &Bitmap, mask: &Bitmap) -> Bitmap { + assert_eq!(values.len(), mask.len()); + let mask_bits_set = mask.set_bits(); + + // Fast path: values is all-0s or all-1s. + if let Some(num_values_bits) = values.lazy_set_bits() { + if num_values_bits == 0 || num_values_bits == values.len() { + return Bitmap::new_with_value(num_values_bits == values.len(), mask_bits_set); + } + } + + // Fast path: mask is all-0s or all-1s. + if mask_bits_set == 0 { + return Bitmap::new(); + } else if mask_bits_set == mask.len() { + return values.clone(); + } + + // Overallocate by 1 u64 so we can always do a full u64 write. + let num_words = mask_bits_set.div_ceil(64); + let num_bytes = 8 * (num_words + 1); + let mut out_vec: Vec = Vec::with_capacity(num_bytes); + + unsafe { + if mask_bits_set <= mask.len() / (64 * 4) { + // Less than one in 1 in 4 words has a bit set on average, use sparse kernel. + filter_boolean_kernel_sparse(values, mask, out_vec.as_mut_ptr()); + } else if polars_utils::cpuid::has_fast_bmi2() { + #[cfg(target_arch = "x86_64")] + filter_boolean_kernel_pext::(values, mask, out_vec.as_mut_ptr(), |v, m, _| { + // SAFETY: has_fast_bmi2 ensures this is a legal instruction. + core::arch::x86_64::_pext_u64(v, m) + }); + } else { + filter_boolean_kernel_pext::( + values, + mask, + out_vec.as_mut_ptr(), + pext64_polyfill, + ) + } + + // SAFETY: the above filters must have initialized these bytes. + out_vec.set_len(mask_bits_set.div_ceil(8)); + } + + Bitmap::from_u8_vec(out_vec, mask_bits_set) +} + +/// # Safety +/// out_ptr must point to a buffer of length >= 8 + 8 * ceil(mask.set_bits() / 64). +/// This function will initialize at least the first ceil(mask.set_bits() / 8) bytes. +unsafe fn filter_boolean_kernel_sparse(values: &Bitmap, mask: &Bitmap, mut out_ptr: *mut u8) { + assert_eq!(values.len(), mask.len()); + + let mut value_idx = 0; + let mut bits_in_word = 0usize; + let mut word = 0u64; + + macro_rules! loop_body { + ($m: expr) => {{ + let mut m = $m; + while m > 0 { + let idx_in_m = m.trailing_zeros() as usize; + let bit = unsafe { values.get_bit_unchecked(value_idx + idx_in_m) }; + word |= (bit as u64) << bits_in_word; + bits_in_word += 1; + + if bits_in_word == 64 { + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + out_ptr = out_ptr.add(8); + bits_in_word = 0; + word = 0; + } + } + + m &= m.wrapping_sub(1); // Clear least significant bit. + } + }}; + } + + let mask_aligned = mask.aligned::(); + if mask_aligned.prefix_bitlen() > 0 { + loop_body!(mask_aligned.prefix()); + value_idx += mask_aligned.prefix_bitlen(); + } + + for m in mask_aligned.bulk_iter() { + loop_body!(m); + value_idx += 64; + } + + if mask_aligned.suffix_bitlen() > 0 { + loop_body!(mask_aligned.suffix()); + } + + if bits_in_word > 0 { + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + } + } +} + +/// # Safety +/// See filter_boolean_kernel_sparse. +unsafe fn filter_boolean_kernel_pext u64>( + values: &Bitmap, + mask: &Bitmap, + mut out_ptr: *mut u8, + pext: F, +) { + assert_eq!(values.len(), mask.len()); + + let mut bits_in_word = 0usize; + let mut word = 0u64; + + macro_rules! loop_body { + ($v: expr, $m: expr) => {{ + let (v, m) = ($v, $m); + + // Fast-path, all-0 mask. + if m == 0 { + continue; + } + + // Fast path, all-1 mask. + // This is only worth it if we don't have a native pext. + if !HAS_NATIVE_PEXT && m == U56_MAX { + word |= v << bits_in_word; + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + out_ptr = out_ptr.add(7); + } + word >>= 56; + continue; + } + + let mask_popcnt = m.count_ones(); + let bits = pext(v, m, mask_popcnt); + + // Because we keep bits_in_word < 8 and we iterate over u56s, + // this never loses output bits. + word |= bits << bits_in_word; + bits_in_word += mask_popcnt as usize; + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + + let full_bytes_written = bits_in_word / 8; + out_ptr = out_ptr.add(full_bytes_written); + word >>= full_bytes_written * 8; + bits_in_word %= 8; + } + }}; + } + + let mut v_iter = values.fast_iter_u56(); + let mut m_iter = mask.fast_iter_u56(); + for v in &mut v_iter { + // SAFETY: we checked values and mask have same length. + let m = unsafe { m_iter.next().unwrap_unchecked() }; + loop_body!(v, m); + } + let mut v_rem = v_iter.remainder().0; + let mut m_rem = m_iter.remainder().0; + while m_rem != 0 { + let v = v_rem & U56_MAX; + let m = m_rem & U56_MAX; + v_rem >>= 56; + m_rem >>= 56; + loop_body!(v, m); // Careful, contains 'continue', increment loop variables first. + } +} + +pub fn filter_bitmap_and_validity( + values: &Bitmap, + validity: Option<&Bitmap>, + mask: &Bitmap, +) -> (Bitmap, Option) { + let filtered_values = filter_boolean_kernel(values, mask); + if let Some(validity) = validity { + // TODO: we could theoretically be faster by computing these two filters + // at once. Unsure if worth duplicating all the code above. + let filtered_validity = filter_boolean_kernel(validity, mask); + (filtered_values, Some(filtered_validity)) + } else { + (filtered_values, None) + } +} + +#[cfg(test)] +mod test { + use rand::prelude::*; + + use super::*; + + fn naive_pext64(word: u64, mask: u64) -> u64 { + let mut out = 0; + let mut out_idx = 0; + + for i in 0..64 { + let ith_mask_bit = (mask >> i) & 1; + let ith_word_bit = (word >> i) & 1; + if ith_mask_bit == 1 { + out |= ith_word_bit << out_idx; + out_idx += 1; + } + } + + out + } + + #[test] + fn test_pext64() { + // Verify polyfill against naive implementation. + let mut rng = StdRng::seed_from_u64(0xdeadbeef); + for _ in 0..100 { + let x = rng.r#gen(); + let y = rng.r#gen(); + assert_eq!(naive_pext64(x, y), pext64_polyfill(x, y, y.count_ones())); + + // Test all-zeros and all-ones. + assert_eq!(naive_pext64(0, y), pext64_polyfill(0, y, y.count_ones())); + assert_eq!( + naive_pext64(u64::MAX, y), + pext64_polyfill(u64::MAX, y, y.count_ones()) + ); + assert_eq!(naive_pext64(x, 0), pext64_polyfill(x, 0, 0)); + assert_eq!(naive_pext64(x, u64::MAX), pext64_polyfill(x, u64::MAX, 64)); + + // Test low popcount mask. + let popcnt = rng.gen_range(0..=8); + // Not perfect (can generate same bit twice) but it'll do. + let mask = (0..popcnt).map(|_| 1 << rng.gen_range(0..64)).sum(); + assert_eq!( + naive_pext64(x, mask), + pext64_polyfill(x, mask, mask.count_ones()) + ); + } + } +} diff --git a/crates/polars-compute/src/filter/mod.rs b/crates/polars-compute/src/filter/mod.rs new file mode 100644 index 000000000000..ab5bf6bba72d --- /dev/null +++ b/crates/polars-compute/src/filter/mod.rs @@ -0,0 +1,110 @@ +//! Contains operators to filter arrays such as [`filter`]. +mod boolean; +mod primitive; +mod scalar; + +#[cfg(all(target_arch = "x86_64", feature = "simd"))] +mod avx512; + +use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder}; +use arrow::array::{ + Array, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8ViewArray, new_empty_array, +}; +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::SlicesIterator; +use arrow::with_match_primitive_type_full; +pub use boolean::filter_boolean_kernel; + +pub fn filter(array: &dyn Array, mask: &BooleanArray) -> Box { + assert_eq!(array.len(), mask.len()); + + // Treat null mask values as false. + if let Some(validities) = mask.validity() { + let combined_mask = mask.values() & validities; + filter_with_bitmap(array, &combined_mask) + } else { + filter_with_bitmap(array, mask.values()) + } +} + +pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box { + // Many filters involve filtering values in a subsection of the array. When we trim the leading + // and trailing filtered items, we can close in on those items and not have to perform and + // thinking about those. The overhead for when there are no leading or trailing filtered values + // is very minimal: only a clone of the mask and the array. + // + // This also allows dispatching to the fast paths way, way, way more often. + let mut mask = mask.clone(); + let leading_zeros = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + let array = array.sliced(leading_zeros, mask.len()); + + let mask = &mask; + let array = array.as_ref(); + + // Fast-path: completely empty or completely full mask. + let false_count = mask.unset_bits(); + if false_count == mask.len() { + return new_empty_array(array.dtype().clone()); + } + if false_count == 0 { + return array.to_boxed(); + } + + use arrow::datatypes::PhysicalType::*; + match array.dtype().to_physical_type() { + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap(); + let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask); + Box::new(PrimitiveArray::from_vec(values).with_validity(validity)) + }), + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + let (values, validity) = + boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask); + BooleanArray::new(array.dtype().clone(), values, validity).boxed() + }, + BinaryView => { + let array = array.as_any().downcast_ref::().unwrap(); + let views = array.views(); + let validity = array.validity(); + let (views, validity) = primitive::filter_values_and_validity(views, validity, mask); + unsafe { + BinaryViewArray::new_unchecked_unknown_md( + array.dtype().clone(), + views.into(), + array.data_buffers().clone(), + validity, + Some(array.total_buffer_len()), + ) + } + .boxed() + }, + Utf8View => { + let array = array.as_any().downcast_ref::().unwrap(); + let views = array.views(); + let validity = array.validity(); + let (views, validity) = primitive::filter_values_and_validity(views, validity, mask); + unsafe { + BinaryViewArray::new_unchecked_unknown_md( + arrow::datatypes::ArrowDataType::BinaryView, + views.into(), + array.data_buffers().clone(), + validity, + Some(array.total_buffer_len()), + ) + .to_utf8view_unchecked() + } + .boxed() + }, + _ => { + let iter = SlicesIterator::new(mask); + let mut mutable = make_builder(array.dtype()); + mutable.reserve(iter.slots()); + iter.for_each(|(start, len)| { + mutable.subslice_extend(array, start, len, ShareStrategy::Always) + }); + mutable.freeze() + }, + } +} diff --git a/crates/polars-compute/src/filter/primitive.rs b/crates/polars-compute/src/filter/primitive.rs new file mode 100644 index 000000000000..6f2bd792c1d8 --- /dev/null +++ b/crates/polars-compute/src/filter/primitive.rs @@ -0,0 +1,94 @@ +use arrow::bitmap::Bitmap; +use bytemuck::{Pod, cast_slice, cast_vec}; +#[cfg(all(target_arch = "x86_64", feature = "simd"))] +use polars_utils::cpuid::is_avx512_enabled; + +#[cfg(all(target_arch = "x86_64", feature = "simd"))] +use super::avx512; +use super::boolean::filter_boolean_kernel; +use super::scalar::{scalar_filter, scalar_filter_offset}; + +type FilterFn = for<'a> unsafe fn(&'a [T], &'a [u8], *mut T) -> (&'a [T], &'a [u8], *mut T); + +fn nop_filter<'a, T: Pod>( + values: &'a [T], + mask: &'a [u8], + out: *mut T, +) -> (&'a [T], &'a [u8], *mut T) { + (values, mask, out) +} + +pub fn filter_values(values: &[T], mask: &Bitmap) -> Vec { + match (size_of::(), align_of::()) { + (1, 1) => cast_vec(filter_values_u8(cast_slice(values), mask)), + (2, 2) => cast_vec(filter_values_u16(cast_slice(values), mask)), + (4, 4) => cast_vec(filter_values_u32(cast_slice(values), mask)), + (8, 8) => cast_vec(filter_values_u64(cast_slice(values), mask)), + _ => filter_values_generic(values, mask, 1, nop_filter), + } +} + +fn filter_values_u8(values: &[u8], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2") { + return filter_values_generic(values, mask, 64, avx512::filter_u8_avx512vbmi2); + } + + filter_values_generic(values, mask, 1, nop_filter) +} + +fn filter_values_u16(values: &[u16], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if is_avx512_enabled() && std::arch::is_x86_feature_detected!("avx512vbmi2") { + return filter_values_generic(values, mask, 32, avx512::filter_u16_avx512vbmi2); + } + + filter_values_generic(values, mask, 1, nop_filter) +} + +fn filter_values_u32(values: &[u32], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if is_avx512_enabled() { + return filter_values_generic(values, mask, 16, avx512::filter_u32_avx512f); + } + + filter_values_generic(values, mask, 1, nop_filter) +} + +fn filter_values_u64(values: &[u64], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if is_avx512_enabled() { + return filter_values_generic(values, mask, 8, avx512::filter_u64_avx512f); + } + + filter_values_generic(values, mask, 1, nop_filter) +} + +fn filter_values_generic( + values: &[T], + mask: &Bitmap, + pad: usize, + bulk_filter: FilterFn, +) -> Vec { + assert_eq!(values.len(), mask.len()); + let mask_bits_set = mask.set_bits(); + let mut out = Vec::with_capacity(mask_bits_set + pad); + unsafe { + let (values, mask_bytes, out_ptr) = scalar_filter_offset(values, mask, out.as_mut_ptr()); + let (values, mask_bytes, out_ptr) = bulk_filter(values, mask_bytes, out_ptr); + scalar_filter(values, mask_bytes, out_ptr); + out.set_len(mask_bits_set); + } + out +} + +pub fn filter_values_and_validity( + values: &[T], + validity: Option<&Bitmap>, + mask: &Bitmap, +) -> (Vec, Option) { + ( + filter_values(values, mask), + validity.map(|v| filter_boolean_kernel(v, mask)), + ) +} diff --git a/crates/polars-compute/src/filter/scalar.rs b/crates/polars-compute/src/filter/scalar.rs new file mode 100644 index 000000000000..00de0c9d0741 --- /dev/null +++ b/crates/polars-compute/src/filter/scalar.rs @@ -0,0 +1,138 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::bitmap::Bitmap; +use bytemuck::Pod; +use polars_utils::slice::load_padded_le_u64; + +/// # Safety +/// If the ith bit of m is set (m & (1 << i)), then v[i] must be in-bounds. +/// out must be valid for at least m.count_ones() + 1 writes. +unsafe fn scalar_sparse_filter64(v: &[T], mut m: u64, out: *mut T) { + let mut written = 0usize; + + while m > 0 { + // Unroll loop manually twice. + let idx = m.trailing_zeros() as usize; + *out.add(written) = *v.get_unchecked(idx); + m &= m.wrapping_sub(1); // Clear least significant bit. + written += 1; + + // tz % 64 otherwise we could go out of bounds + let idx = (m.trailing_zeros() % 64) as usize; + *out.add(written) = *v.get_unchecked(idx); + m &= m.wrapping_sub(1); // Clear least significant bit. + written += 1; + } +} + +/// # Safety +/// v.len() >= 64 must hold. +/// out must be valid for at least m.count_ones() + 1 writes. +unsafe fn scalar_dense_filter64(v: &[T], mut m: u64, out: *mut T) { + // Rust generated significantly better code if we write the below loop + // with v as a pointer, and out.add(written) instead of incrementing out + // directly. + let mut written = 0usize; + let mut src = v.as_ptr(); + + // We hope the outer loop doesn't get unrolled, but the inner loop does. + for _ in 0..16 { + for i in 0..4 { + *out.add(written) = *src; + written += ((m >> i) & 1) as usize; + src = src.add(1); + } + m >>= 4; + } +} + +/// Handles the offset portion of a Bitmap to start an efficient filter operation. +/// Returns the remaining values and mask bytes for the filter, as well as where +/// to continue writing to out. +/// +/// # Safety +/// out must be valid for at least mask.set_bits() + 1 writes. +pub unsafe fn scalar_filter_offset<'a, T: Pod>( + values: &'a [T], + mask: &'a Bitmap, + mut out: *mut T, +) -> (&'a [T], &'a [u8], *mut T) { + assert_eq!(values.len(), mask.len()); + + let (mut mask_bytes, offset, len) = mask.as_slice(); + let mut value_idx = 0; + if offset > 0 { + let first_byte = mask_bytes[0]; + mask_bytes = &mask_bytes[1..]; + + for byte_idx in offset..8 { + if value_idx < len { + unsafe { + // SAFETY: we checked that value_idx < len. + let bit_is_set = first_byte & (1 << byte_idx) != 0; + *out = *values.get_unchecked(value_idx); + out = out.add(bit_is_set as usize); + } + value_idx += 1; + } + } + } + + (&values[value_idx..], mask_bytes, out) +} + +/// # Safety +/// out must be valid for 1 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +pub unsafe fn scalar_filter(values: &[T], mut mask_bytes: &[u8], mut out: *mut T) { + assert!(mask_bytes.len() * 8 >= values.len()); + + // Handle bulk. + let mut value_idx = 0; + while value_idx + 64 <= values.len() { + let (mask_chunk, value_chunk); + unsafe { + // SAFETY: we checked that value_idx + 64 <= values.len(), so these are + // all in-bounds. + mask_chunk = mask_bytes.get_unchecked(0..8); + mask_bytes = mask_bytes.get_unchecked(8..); + value_chunk = values.get_unchecked(value_idx..value_idx + 64); + value_idx += 64; + }; + let m = u64::from_le_bytes(mask_chunk.try_into().unwrap()); + + // Fast-path: empty mask. + if m == 0 { + continue; + } + + unsafe { + // SAFETY: we will only write at most m_popcnt + 1 to out, which + // is allowed. + + // Fast-path: completely full mask. + if m == u64::MAX { + core::ptr::copy_nonoverlapping(value_chunk.as_ptr(), out, 64); + out = out.add(64); + continue; + } + + let m_popcnt = m.count_ones(); + if m_popcnt <= 16 { + scalar_sparse_filter64(value_chunk, m, out) + } else { + scalar_dense_filter64(value_chunk, m, out) + }; + out = out.add(m_popcnt as usize); + } + } + + // Handle remainder. + if value_idx < values.len() { + let rest_len = values.len() - value_idx; + assert!(rest_len < 64); + let m = load_padded_le_u64(mask_bytes) & ((1 << rest_len) - 1); + unsafe { + let value_chunk = values.get_unchecked(value_idx..); + scalar_sparse_filter64(value_chunk, m, out); + } + } +} diff --git a/crates/polars-compute/src/float_sum.rs b/crates/polars-compute/src/float_sum.rs new file mode 100644 index 000000000000..4c44aa732ffe --- /dev/null +++ b/crates/polars-compute/src/float_sum.rs @@ -0,0 +1,277 @@ +use std::ops::{Add, IndexMut}; +#[cfg(feature = "simd")] +use std::simd::{prelude::*, *}; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +use arrow::types::NativeType; +use num_traits::{AsPrimitive, Float}; + +const STRIPE: usize = 16; +const PAIRWISE_RECURSION_LIMIT: usize = 128; + +// We want to be generic over both integers and floats, requiring this helper trait. +#[cfg(feature = "simd")] +pub trait SimdCastGeneric +where + LaneCount: SupportedLaneCount, +{ + fn cast_generic(self) -> Simd; +} + +macro_rules! impl_cast_custom { + ($_type:ty) => { + #[cfg(feature = "simd")] + impl SimdCastGeneric for Simd<$_type, N> + where + LaneCount: SupportedLaneCount, + { + fn cast_generic(self) -> Simd { + self.cast::() + } + } + }; +} + +impl_cast_custom!(u8); +impl_cast_custom!(u16); +impl_cast_custom!(u32); +impl_cast_custom!(u64); +impl_cast_custom!(i8); +impl_cast_custom!(i16); +impl_cast_custom!(i32); +impl_cast_custom!(i64); +impl_cast_custom!(f32); +impl_cast_custom!(f64); + +fn vector_horizontal_sum(mut v: V) -> T +where + V: IndexMut, + T: Add + Sized + Copy, +{ + // We have to be careful about this reduction, floating + // point math is NOT associative so we have to write this + // in a form that maps to good shuffle instructions. + // We fold the vector onto itself, halved, until we are down to + // four elements which we add in a shuffle-friendly way. + let mut width = STRIPE; + while width > 4 { + for j in 0..width / 2 { + v[j] = v[j] + v[width / 2 + j]; + } + width /= 2; + } + + (v[0] + v[2]) + (v[1] + v[3]) +} + +// As a trait to not proliferate SIMD bounds. +pub trait SumBlock { + fn sum_block_vectorized(&self) -> F; + fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F; +} + +#[cfg(feature = "simd")] +impl SumBlock for [T; PAIRWISE_RECURSION_LIMIT] +where + T: SimdElement, + F: SimdElement + SimdCast + Add + Default, + Simd: SimdCastGeneric, + Simd: std::iter::Sum, +{ + fn sum_block_vectorized(&self) -> F { + let vsum = self + .chunks_exact(STRIPE) + .map(|a| Simd::::from_slice(a).cast_generic::()) + .sum::>(); + vector_horizontal_sum(vsum) + } + + fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F { + let zero = Simd::default(); + let vsum = self + .chunks_exact(STRIPE) + .enumerate() + .map(|(i, a)| { + let m: Mask<_, STRIPE> = mask.get_simd(i * STRIPE); + m.select(Simd::from_slice(a).cast_generic::(), zero) + }) + .sum::>(); + vector_horizontal_sum(vsum) + } +} + +#[cfg(feature = "simd")] +impl SumBlock for [i128; PAIRWISE_RECURSION_LIMIT] +where + i128: AsPrimitive, + F: Float + std::iter::Sum + 'static, +{ + fn sum_block_vectorized(&self) -> F { + self.iter().map(|x| x.as_()).sum() + } + + fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F { + self.iter() + .enumerate() + .map(|(idx, x)| if mask.get(idx) { x.as_() } else { F::zero() }) + .sum() + } +} + +#[cfg(not(feature = "simd"))] +impl SumBlock for [T; PAIRWISE_RECURSION_LIMIT] +where + T: AsPrimitive + 'static, + F: Default + Add + Copy + 'static, +{ + fn sum_block_vectorized(&self) -> F { + let mut vsum = [F::default(); STRIPE]; + for chunk in self.chunks_exact(STRIPE) { + for j in 0..STRIPE { + vsum[j] = vsum[j] + chunk[j].as_(); + } + } + vector_horizontal_sum(vsum) + } + + fn sum_block_vectorized_with_mask(&self, mask: BitMask<'_>) -> F { + let mut vsum = [F::default(); STRIPE]; + for (i, chunk) in self.chunks_exact(STRIPE).enumerate() { + for j in 0..STRIPE { + // Unconditional add with select for better branch-free opts. + let addend = if mask.get(i * STRIPE + j) { + chunk[j].as_() + } else { + F::default() + }; + vsum[j] = vsum[j] + addend; + } + } + vector_horizontal_sum(vsum) + } +} + +/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0. +unsafe fn pairwise_sum(f: &[T]) -> F +where + [T; PAIRWISE_RECURSION_LIMIT]: SumBlock, + F: Add, +{ + debug_assert!(!f.is_empty() && f.len() % PAIRWISE_RECURSION_LIMIT == 0); + + let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok(); + if let Some(block) = block { + return block.sum_block_vectorized(); + } + + // SAFETY: we maintain the invariant. `try_into` array of len PAIRWISE_RECURSION_LIMIT + // failed so we know f.len() >= 2*PAIRWISE_RECURSION_LIMIT, and thus blocks >= 2. + // This means 0 < left_len < f.len() and left_len is divisible by PAIRWISE_RECURSION_LIMIT, + // maintaining the invariant for both recursive calls. + unsafe { + let blocks = f.len() / PAIRWISE_RECURSION_LIMIT; + let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT; + let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..)); + pairwise_sum(left) + pairwise_sum(right) + } +} + +/// Invariant: f.len() % PAIRWISE_RECURSION_LIMIT == 0 and f.len() > 0. +/// Also, f.len() == mask.len(). +unsafe fn pairwise_sum_with_mask(f: &[T], mask: BitMask<'_>) -> F +where + [T; PAIRWISE_RECURSION_LIMIT]: SumBlock, + F: Add, +{ + debug_assert!(!f.is_empty() && f.len() % PAIRWISE_RECURSION_LIMIT == 0); + debug_assert!(f.len() == mask.len()); + + let block: Option<&[T; PAIRWISE_RECURSION_LIMIT]> = f.try_into().ok(); + if let Some(block) = block { + return block.sum_block_vectorized_with_mask(mask); + } + + // SAFETY: see pairwise_sum. + unsafe { + let blocks = f.len() / PAIRWISE_RECURSION_LIMIT; + let left_len = (blocks / 2) * PAIRWISE_RECURSION_LIMIT; + let (left, right) = (f.get_unchecked(..left_len), f.get_unchecked(left_len..)); + let (left_mask, right_mask) = mask.split_at_unchecked(left_len); + pairwise_sum_with_mask(left, left_mask) + pairwise_sum_with_mask(right, right_mask) + } +} + +pub trait FloatSum: Sized { + fn sum(f: &[Self]) -> F; + fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F; +} + +impl FloatSum for T +where + F: Float + std::iter::Sum + 'static, + T: AsPrimitive, + [T; PAIRWISE_RECURSION_LIMIT]: SumBlock, +{ + fn sum(f: &[Self]) -> F { + let remainder = f.len() % PAIRWISE_RECURSION_LIMIT; + let (rest, main) = f.split_at(remainder); + let mainsum = if f.len() > remainder { + unsafe { pairwise_sum(main) } + } else { + F::zero() + }; + // TODO: faster remainder. + let restsum: F = rest.iter().map(|x| x.as_()).sum(); + mainsum + restsum + } + + fn sum_with_validity(f: &[Self], validity: &Bitmap) -> F { + let mask = BitMask::from_bitmap(validity); + assert!(f.len() == mask.len()); + + let remainder = f.len() % PAIRWISE_RECURSION_LIMIT; + let (rest, main) = f.split_at(remainder); + let (rest_mask, main_mask) = mask.split_at(remainder); + let mainsum = if f.len() > remainder { + unsafe { pairwise_sum_with_mask(main, main_mask) } + } else { + F::zero() + }; + // TODO: faster remainder. + let restsum: F = rest + .iter() + .enumerate() + .map(|(i, x)| { + // No filter but rather select of 0.0 for cmov opt. + if rest_mask.get(i) { x.as_() } else { F::zero() } + }) + .sum(); + mainsum + restsum + } +} + +pub fn sum_arr_as_f32(arr: &PrimitiveArray) -> f32 +where + T: NativeType + FloatSum, +{ + let validity = arr.validity().filter(|_| arr.null_count() > 0); + if let Some(mask) = validity { + FloatSum::sum_with_validity(arr.values(), mask) + } else { + FloatSum::sum(arr.values()) + } +} + +pub fn sum_arr_as_f64(arr: &PrimitiveArray) -> f64 +where + T: NativeType + FloatSum, +{ + let validity = arr.validity().filter(|_| arr.null_count() > 0); + if let Some(mask) = validity { + FloatSum::sum_with_validity(arr.values(), mask) + } else { + FloatSum::sum(arr.values()) + } +} diff --git a/crates/polars-compute/src/gather/binary.rs b/crates/polars-compute/src/gather/binary.rs new file mode 100644 index 000000000000..d34ee69316bd --- /dev/null +++ b/crates/polars-compute/src/gather/binary.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, BinaryArray, PrimitiveArray}; +use arrow::offset::Offset; + +use super::Index; +use super::generic_binary::*; + +/// `take` implementation for utf8 arrays +/// # Safety +/// The indices must be in-bounds. +pub unsafe fn take_unchecked( + values: &BinaryArray, + indices: &PrimitiveArray, +) -> BinaryArray { + let dtype = values.dtype().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => { + take_no_validity_unchecked::(values.offsets(), values.values(), indices.values()) + }, + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.offsets(), values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + BinaryArray::::new_unchecked(dtype, offsets, values, validity) +} diff --git a/crates/polars-compute/src/gather/binview.rs b/crates/polars-compute/src/gather/binview.rs new file mode 100644 index 000000000000..05c16845a075 --- /dev/null +++ b/crates/polars-compute/src/gather/binview.rs @@ -0,0 +1,23 @@ +use arrow::array::{BinaryViewArrayGeneric, ViewType}; + +use self::primitive::take_values_and_validity_unchecked; +use super::*; + +/// # Safety +/// No bound checks +pub unsafe fn take_binview_unchecked( + arr: &BinaryViewArrayGeneric, + indices: &IdxArr, +) -> BinaryViewArrayGeneric { + let (views, validity) = + take_values_and_validity_unchecked(arr.views(), arr.validity(), indices); + + BinaryViewArrayGeneric::new_unchecked_unknown_md( + arr.dtype().clone(), + views.into(), + arr.data_buffers().clone(), + validity, + Some(arr.total_buffer_len()), + ) + .maybe_gc() +} diff --git a/crates/polars-compute/src/gather/bitmap.rs b/crates/polars-compute/src/gather/bitmap.rs new file mode 100644 index 000000000000..e64bd46c650d --- /dev/null +++ b/crates/polars-compute/src/gather/bitmap.rs @@ -0,0 +1,37 @@ +use arrow::array::Array; +use arrow::bitmap::Bitmap; +use arrow::datatypes::IdxArr; +use polars_utils::IdxSize; + +/// # Safety +/// Doesn't do any bound checks. +pub unsafe fn take_bitmap_unchecked(values: &Bitmap, indices: &[IdxSize]) -> Bitmap { + let values = indices.iter().map(|&index| { + debug_assert!((index as usize) < values.len()); + values.get_bit_unchecked(index as usize) + }); + Bitmap::from_trusted_len_iter(values) +} + +/// # Safety +/// Doesn't check bounds for non-null elements. +pub unsafe fn take_bitmap_nulls_unchecked(values: &Bitmap, indices: &IdxArr) -> Bitmap { + // Fast-path: no need to bother with null indices. + if indices.null_count() == 0 { + return take_bitmap_unchecked(values, indices.values()); + } + + if values.is_empty() { + // Nothing can be in-bounds, assume indices is full-null. + debug_assert!(indices.null_count() == indices.len()); + return Bitmap::new_zeroed(indices.len()); + } + + let values = indices.iter().map(|opt_index| { + // We checked that values.len() > 0 so we can use index 0 for nulls. + let index = opt_index.copied().unwrap_or(0) as usize; + debug_assert!(index < values.len()); + values.get_bit_unchecked(index) + }); + Bitmap::from_trusted_len_iter(values) +} diff --git a/crates/polars-compute/src/gather/boolean.rs b/crates/polars-compute/src/gather/boolean.rs new file mode 100644 index 000000000000..d118ad1f61e2 --- /dev/null +++ b/crates/polars-compute/src/gather/boolean.rs @@ -0,0 +1,80 @@ +use arrow::array::{Array, BooleanArray, PrimitiveArray}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use polars_utils::IdxSize; + +use super::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked}; + +// Take implementation when neither values nor indices contain nulls. +unsafe fn take_no_validity(values: &Bitmap, indices: &[IdxSize]) -> (Bitmap, Option) { + (take_bitmap_unchecked(values, indices), None) +} + +// Take implementation when only values contain nulls. +unsafe fn take_values_validity( + values: &BooleanArray, + indices: &[IdxSize], +) -> (Bitmap, Option) { + let validity_values = values.validity().unwrap(); + let validity = take_bitmap_unchecked(validity_values, indices); + + let values_values = values.values(); + let buffer = take_bitmap_unchecked(values_values, indices); + + (buffer, validity.into()) +} + +// Take implementation when only indices contain nulls. +unsafe fn take_indices_validity( + values: &Bitmap, + indices: &PrimitiveArray, +) -> (Bitmap, Option) { + let buffer = take_bitmap_nulls_unchecked(values, indices); + (buffer, indices.validity().cloned()) +} + +// Take implementation when both values and indices contain nulls. +unsafe fn take_values_indices_validity( + values: &BooleanArray, + indices: &PrimitiveArray, +) -> (Bitmap, Option) { + let mut validity = BitmapBuilder::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + + let values_values = values.values(); + let values = indices.iter().map(|index| match index { + Some(&index) => { + let index = index as usize; + debug_assert!(index < values.len()); + validity.push(values_validity.get_bit_unchecked(index)); + values_values.get_bit_unchecked(index) + }, + None => { + validity.push(false); + false + }, + }); + let values = Bitmap::from_trusted_len_iter(values); + (values, validity.into_opt_validity()) +} + +/// `take` implementation for boolean arrays +/// # Safety +/// The indices must be in-bounds. +pub unsafe fn take_unchecked( + values: &BooleanArray, + indices: &PrimitiveArray, +) -> BooleanArray { + let dtype = values.dtype().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => take_no_validity(values.values(), indices.values()), + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + + BooleanArray::new(dtype, values, validity) +} diff --git a/crates/polars-compute/src/gather/fixed_size_list.rs b/crates/polars-compute/src/gather/fixed_size_list.rs new file mode 100644 index 000000000000..8cc74d5a78ad --- /dev/null +++ b/crates/polars-compute/src/gather/fixed_size_list.rs @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::ManuallyDrop; + +use arrow::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray, StaticArray}; +use arrow::bitmap::MutableBitmap; +use arrow::compute::utils::combine_validities_and; +use arrow::datatypes::reshape::{Dimension, ReshapeDimension}; +use arrow::datatypes::{ArrowDataType, IdxArr, PhysicalType}; +use arrow::legacy::prelude::FromData; +use arrow::with_match_primitive_type; +use polars_utils::itertools::Itertools; + +use super::Index; +use crate::gather::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked}; + +fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) { + if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype { + get_stride_and_leaf_type(inner.dtype(), *size_inner * size) + } else { + (size, dtype) + } +} + +fn get_leaves(array: &FixedSizeListArray) -> &dyn Array { + if let Some(array) = array.values().as_any().downcast_ref::() { + get_leaves(array) + } else { + &**array.values() + } +} + +fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) { + match array.dtype().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let arr = array.as_any().downcast_ref::>().unwrap(); + let values = arr.values(); + (bytemuck::cast_slice(values), size_of::<$T>()) + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn from_buffer(mut buf: ManuallyDrop>, dtype: &ArrowDataType) -> ArrayRef { + match dtype.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let ptr = buf.as_mut_ptr(); + let len_units = buf.len(); + let cap_units = buf.capacity(); + + let buf = Vec::from_raw_parts( + ptr as *mut $T, + len_units / size_of::<$T>(), + cap_units / size_of::<$T>(), + ); + + PrimitiveArray::from_data_default(buf.into(), None).boxed() + + }), + _ => { + unimplemented!() + }, + } +} + +unsafe fn aligned_vec(dt: &ArrowDataType, n_bytes: usize) -> Vec { + match dt.to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + + let n_units = (n_bytes / size_of::<$T>()) + 1; + + let mut aligned: Vec<$T> = Vec::with_capacity(n_units); + + let ptr = aligned.as_mut_ptr(); + let len_units = aligned.len(); + let cap_units = aligned.capacity(); + + std::mem::forget(aligned); + + Vec::from_raw_parts( + ptr as *mut u8, + len_units * size_of::<$T>(), + cap_units * size_of::<$T>(), + ) + + }), + _ => { + unimplemented!() + }, + } +} + +fn arr_no_validities_recursive(arr: &dyn Array) -> bool { + arr.validity().is_none() + && arr + .as_any() + .downcast_ref::() + .is_none_or(|x| arr_no_validities_recursive(x.values().as_ref())) +} + +/// `take` implementation for FixedSizeListArrays +pub(super) unsafe fn take_unchecked(values: &FixedSizeListArray, indices: &IdxArr) -> ArrayRef { + let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1); + if leaf_type.to_physical_type().is_primitive() + && arr_no_validities_recursive(values.values().as_ref()) + { + let leaves = get_leaves(values); + + let (leaves_buf, leave_size) = get_buffer_and_size(leaves); + let bytes_per_element = leave_size * stride; + + let n_idx = indices.len(); + let total_bytes = bytes_per_element * n_idx; + + let mut buf = ManuallyDrop::new(aligned_vec(leaves.dtype(), total_bytes)); + let dst = buf.spare_capacity_mut(); + + let mut count = 0; + let outer_validity = if indices.null_count() == 0 { + for i in indices.values().iter() { + let i = i.to_usize(); + + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); + count += 1; + } + None + } else { + let mut new_validity = MutableBitmap::with_capacity(indices.len()); + new_validity.extend_constant(indices.len(), true); + for i in indices.iter() { + if let Some(i) = i { + let i = i.to_usize(); + std::ptr::copy_nonoverlapping( + leaves_buf.as_ptr().add(i * bytes_per_element), + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + bytes_per_element, + ); + } else { + new_validity.set_unchecked(count, false); + std::ptr::write_bytes( + dst.as_mut_ptr().add(count * bytes_per_element) as *mut _, + 0, + bytes_per_element, + ); + } + + count += 1; + } + Some(new_validity.freeze()) + }; + + assert_eq!(count * bytes_per_element, total_bytes); + buf.set_len(total_bytes); + + let outer_validity = combine_validities_and( + outer_validity.as_ref(), + values + .validity() + .map(|x| { + if indices.has_nulls() { + take_bitmap_nulls_unchecked(x, indices) + } else { + take_bitmap_unchecked(x, indices.as_slice().unwrap()) + } + }) + .as_ref(), + ); + + let leaves = from_buffer(buf, leaves.dtype()); + let mut shape = values.get_dims(); + shape[0] = Dimension::new(indices.len() as _); + let shape = shape + .into_iter() + .map(ReshapeDimension::Specified) + .collect_vec(); + + FixedSizeListArray::from_shape(leaves.clone(), &shape) + .unwrap() + .with_validity(outer_validity) + } else { + super::take_unchecked_impl_generic(values, indices, &FixedSizeListArray::new_null).boxed() + } +} + +#[cfg(test)] +mod tests { + use arrow::array::StaticArray; + use arrow::datatypes::ArrowDataType; + + /// Test gather for FixedSizeListArray with outer validity but no inner validities. + #[test] + fn test_arr_gather_nulls_outer_validity_19482() { + use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray}; + use arrow::bitmap::Bitmap; + use arrow::datatypes::reshape::{Dimension, ReshapeDimension}; + use polars_utils::IdxSize; + + use super::take_unchecked; + + unsafe { + let dyn_arr = FixedSizeListArray::from_shape( + Box::new(Int64Array::from_slice([1, 2, 3, 4])), + &[ + ReshapeDimension::Specified(Dimension::new(2)), + ReshapeDimension::Specified(Dimension::new(2)), + ], + ) + .unwrap() + .with_validity(Some(Bitmap::from_iter([true, false]))); // FixedSizeListArray[[1, 2], None] + + let arr = dyn_arr + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + [arr.validity().is_some(), arr.values().validity().is_some()], + [true, false] + ); + + assert_eq!( + take_unchecked(arr, &PrimitiveArray::::from_slice([0, 1])), + dyn_arr + ) + } + } + + #[test] + fn test_arr_gather_nulls_inner_validity() { + use arrow::array::{FixedSizeListArray, Int64Array, PrimitiveArray}; + use arrow::datatypes::reshape::{Dimension, ReshapeDimension}; + use polars_utils::IdxSize; + + use super::take_unchecked; + + unsafe { + let dyn_arr = FixedSizeListArray::from_shape( + Box::new(Int64Array::full_null(4, ArrowDataType::Int64)), + &[ + ReshapeDimension::Specified(Dimension::new(2)), + ReshapeDimension::Specified(Dimension::new(2)), + ], + ) + .unwrap(); // FixedSizeListArray[[None, None], [None, None]] + + let arr = dyn_arr + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + [arr.validity().is_some(), arr.values().validity().is_some()], + [false, true] + ); + + assert_eq!( + take_unchecked(arr, &PrimitiveArray::::from_slice([0, 1])), + dyn_arr + ) + } + } +} diff --git a/crates/polars-compute/src/gather/generic_binary.rs b/crates/polars-compute/src/gather/generic_binary.rs new file mode 100644 index 000000000000..652a30d3fe10 --- /dev/null +++ b/crates/polars-compute/src/gather/generic_binary.rs @@ -0,0 +1,173 @@ +use arrow::array::{GenericBinaryArray, PrimitiveArray}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::buffer::Buffer; +use arrow::offset::{Offset, Offsets, OffsetsBuffer}; +use polars_utils::vec::{CapacityByFactor, PushUnchecked}; + +use super::Index; + +fn create_offsets, O: Offset>( + lengths: I, + idx_len: usize, +) -> OffsetsBuffer { + let mut length_so_far = O::default(); + let mut offsets = Vec::with_capacity(idx_len + 1); + offsets.push(length_so_far); + + for len in lengths { + unsafe { + length_so_far += O::from_usize(len).unwrap_unchecked(); + offsets.push_unchecked(length_so_far) + }; + } + unsafe { Offsets::new_unchecked(offsets).into() } +} + +pub(super) unsafe fn take_values( + length: O, + starts: &[O], + offsets: &OffsetsBuffer, + values: &[u8], +) -> Buffer { + let new_len = length.to_usize(); + let mut buffer = Vec::with_capacity(new_len); + starts + .iter() + .map(|start| start.to_usize()) + .zip(offsets.lengths()) + .for_each(|(start, length)| { + let end = start + length; + buffer.extend_from_slice(values.get_unchecked(start..end)); + }); + buffer.into() +} + +// take implementation when neither values nor indices contain nulls +pub(super) unsafe fn take_no_validity_unchecked( + offsets: &OffsetsBuffer, + values: &[u8], + indices: &[I], +) -> (OffsetsBuffer, Buffer, Option) { + let values_len = offsets.last().to_usize(); + let fraction_estimate = indices.len() as f64 / offsets.len() as f64 + 0.3; + let mut buffer = Vec::::with_capacity_by_factor(values_len, fraction_estimate); + + let lengths = indices.iter().map(|index| index.to_usize()).map(|index| { + let (start, end) = offsets.start_end_unchecked(index); + buffer.extend_from_slice(values.get_unchecked(start..end)); + end - start + }); + let offsets = create_offsets(lengths, indices.len()); + + (offsets, buffer.into(), None) +} + +// take implementation when only values contain nulls +pub(super) unsafe fn take_values_validity>( + values: &A, + indices: &[I], +) -> (OffsetsBuffer, Buffer, Option) { + let validity_values = values.validity().unwrap(); + let validity = indices + .iter() + .map(|index| validity_values.get_bit_unchecked(index.to_usize())); + let validity = Bitmap::from_trusted_len_iter(validity); + + let mut total_length = O::default(); + + let offsets = values.offsets(); + let values_values = values.values(); + + let mut starts = Vec::::with_capacity(indices.len()); + let lengths = indices.iter().map(|index| { + let index = index.to_usize(); + let start = *offsets.get_unchecked(index); + let length = *offsets.get_unchecked(index + 1) - start; + total_length += length; + starts.push_unchecked(start); + length.to_usize() + }); + let offsets = create_offsets(lengths, indices.len()); + let buffer = take_values(total_length, starts.as_slice(), &offsets, values_values); + + (offsets, buffer, validity.into()) +} + +// take implementation when only indices contain nulls +pub(super) unsafe fn take_indices_validity( + offsets: &OffsetsBuffer, + values: &[u8], + indices: &PrimitiveArray, +) -> (OffsetsBuffer, Buffer, Option) { + let mut total_length = O::default(); + + let offsets = offsets.buffer(); + + let mut starts = Vec::::with_capacity(indices.len()); + let lengths = indices.values().iter().map(|index| { + let index = index.to_usize(); + let length; + match offsets.get(index + 1) { + Some(&next) => { + let start = *offsets.get_unchecked(index); + length = next - start; + total_length += length; + starts.push_unchecked(start); + }, + None => { + length = O::zero(); + starts.push_unchecked(O::default()); + }, + }; + length.to_usize() + }); + let offsets = create_offsets(lengths, indices.len()); + + let buffer = take_values(total_length, &starts, &offsets, values); + + (offsets, buffer, indices.validity().cloned()) +} + +// take implementation when both indices and values contain nulls +pub(super) unsafe fn take_values_indices_validity>( + values: &A, + indices: &PrimitiveArray, +) -> (OffsetsBuffer, Buffer, Option) { + let mut total_length = O::default(); + let mut validity = BitmapBuilder::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + let offsets = values.offsets(); + let values_values = values.values(); + + let mut starts = Vec::::with_capacity(indices.len()); + let lengths = indices.iter().map(|index| { + let length; + match index { + Some(index) => { + let index = index.to_usize(); + if values_validity.get_bit(index) { + validity.push(true); + length = *offsets.get_unchecked(index + 1) - *offsets.get_unchecked(index); + starts.push_unchecked(*offsets.get_unchecked(index)); + } else { + validity.push(false); + length = O::zero(); + starts.push_unchecked(O::default()); + } + }, + None => { + validity.push(false); + length = O::zero(); + starts.push_unchecked(O::default()); + }, + }; + total_length += length; + length.to_usize() + }); + let offsets = create_offsets(lengths, indices.len()); + + let buffer = take_values(total_length, &starts, &offsets, values_values); + + (offsets, buffer, validity.into_opt_validity()) +} diff --git a/crates/polars-compute/src/gather/list.rs b/crates/polars-compute/src/gather/list.rs new file mode 100644 index 000000000000..fb0863c6b268 --- /dev/null +++ b/crates/polars-compute/src/gather/list.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{self, ArrayFromIterDtype, ListArray, StaticArray}; +use arrow::datatypes::IdxArr; +use arrow::offset::Offset; + +/// `take` implementation for ListArrays +pub(super) unsafe fn take_unchecked( + values: &ListArray, + indices: &IdxArr, +) -> ListArray +where + ListArray: StaticArray + ArrayFromIterDtype>>, +{ + super::take_unchecked_impl_generic(values, indices, &ListArray::new_null) +} diff --git a/crates/polars-compute/src/gather/mod.rs b/crates/polars-compute/src/gather/mod.rs new file mode 100644 index 000000000000..71b1eea1fa60 --- /dev/null +++ b/crates/polars-compute/src/gather/mod.rs @@ -0,0 +1,151 @@ +#![allow(unsafe_op_in_unsafe_fn)] +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines take kernel for [`Array`] + +use arrow::array::{ + self, Array, ArrayCollectIterExt, ArrayFromIterDtype, BinaryViewArray, NullArray, StaticArray, + Utf8ViewArray, new_empty_array, +}; +use arrow::datatypes::{ArrowDataType, IdxArr}; +use arrow::types::Index; + +pub mod binary; +pub mod binview; +pub mod bitmap; +pub mod boolean; +pub mod fixed_size_list; +pub mod generic_binary; +pub mod list; +pub mod primitive; +pub mod structure; +pub mod sublist; + +use arrow::with_match_primitive_type_full; + +/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls. +/// The returned array has a length equal to `indices.len()`. +/// # Safety +/// Doesn't do bound checks +pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box { + if indices.len() == 0 { + return new_empty_array(values.dtype().clone()); + } + + use arrow::datatypes::PhysicalType::*; + match values.dtype().to_physical_type() { + Null => Box::new(NullArray::new(values.dtype().clone(), indices.len())), + Boolean => { + let values = values.as_any().downcast_ref().unwrap(); + Box::new(boolean::take_unchecked(values, indices)) + }, + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + let values = values.as_any().downcast_ref().unwrap(); + Box::new(primitive::take_primitive_unchecked::<$T>(&values, indices)) + }), + LargeBinary => { + let values = values.as_any().downcast_ref().unwrap(); + Box::new(binary::take_unchecked::(values, indices)) + }, + Struct => { + let array = values.as_any().downcast_ref().unwrap(); + structure::take_unchecked(array, indices).boxed() + }, + LargeList => { + let array = values.as_any().downcast_ref().unwrap(); + Box::new(list::take_unchecked::(array, indices)) + }, + FixedSizeList => { + let array = values.as_any().downcast_ref().unwrap(); + fixed_size_list::take_unchecked(array, indices) + }, + BinaryView => { + let array: &BinaryViewArray = values.as_any().downcast_ref().unwrap(); + binview::take_binview_unchecked(array, indices).boxed() + }, + Utf8View => { + let array: &Utf8ViewArray = values.as_any().downcast_ref().unwrap(); + binview::take_binview_unchecked(array, indices).boxed() + }, + t => unimplemented!("Take not supported for data type {:?}", t), + } +} + +/// Naive default implementation +unsafe fn take_unchecked_impl_generic( + values: &T, + indices: &IdxArr, + new_null_func: &dyn Fn(ArrowDataType, usize) -> T, +) -> T +where + T: StaticArray + ArrayFromIterDtype>>, +{ + if values.null_count() == values.len() || indices.null_count() == indices.len() { + return new_null_func(values.dtype().clone(), indices.len()); + } + + match (indices.has_nulls(), values.has_nulls()) { + (true, true) => { + let values_validity = values.validity().unwrap(); + + indices + .iter() + .map(|i| { + if let Some(i) = i { + let i = *i as usize; + if values_validity.get_bit_unchecked(i) { + return Some(values.value_unchecked(i)); + } + } + None + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()) + }, + (true, false) => indices + .iter() + .map(|i| { + if let Some(i) = i { + let i = *i as usize; + return Some(values.value_unchecked(i)); + } + None + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()), + (false, true) => { + let values_validity = values.validity().unwrap(); + + indices + .values_iter() + .map(|i| { + let i = *i as usize; + if values_validity.get_bit_unchecked(i) { + return Some(values.value_unchecked(i)); + } + None + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()) + }, + (false, false) => indices + .values_iter() + .map(|i| { + let i = *i as usize; + Some(values.value_unchecked(i)) + }) + .collect_arr_trusted_with_dtype(values.dtype().clone()), + } +} diff --git a/crates/polars-compute/src/gather/primitive.rs b/crates/polars-compute/src/gather/primitive.rs new file mode 100644 index 000000000000..e8b1d8c8230e --- /dev/null +++ b/crates/polars-compute/src/gather/primitive.rs @@ -0,0 +1,78 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::utils::set_bit_unchecked; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::legacy::index::IdxArr; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::NativeType; +use polars_utils::index::NullCount; + +pub(super) unsafe fn take_values_and_validity_unchecked( + values: &[T], + validity_values: Option<&Bitmap>, + indices: &IdxArr, +) -> (Vec, Option) { + let index_values = indices.values().as_slice(); + + let null_count = validity_values.map(|b| b.unset_bits()).unwrap_or(0); + + // first take the values, these are always needed + let values: Vec = if indices.null_count() == 0 { + index_values + .iter() + .map(|idx| *values.get_unchecked(*idx as usize)) + .collect_trusted() + } else { + indices + .iter() + .map(|idx| match idx { + Some(idx) => *values.get_unchecked(*idx as usize), + None => T::default(), + }) + .collect_trusted() + }; + + if null_count > 0 { + let validity_values = validity_values.unwrap(); + // the validity buffer we will fill with all valid. And we unset the ones that are null + // in later checks + // this is in the assumption that most values will be valid. + // Maybe we could add another branch based on the null count + let mut validity = MutableBitmap::with_capacity(indices.len()); + validity.extend_constant(indices.len(), true); + let validity_slice = validity.as_mut_slice(); + + if let Some(validity_indices) = indices.validity().as_ref() { + index_values.iter().enumerate().for_each(|(i, idx)| { + // i is iteration count + // idx is the index that we take from the values array. + let idx = *idx as usize; + if !validity_indices.get_bit_unchecked(i) || !validity_values.get_bit_unchecked(idx) + { + set_bit_unchecked(validity_slice, i, false); + } + }); + } else { + index_values.iter().enumerate().for_each(|(i, idx)| { + let idx = *idx as usize; + if !validity_values.get_bit_unchecked(idx) { + set_bit_unchecked(validity_slice, i, false); + } + }); + }; + (values, Some(validity.freeze())) + } else { + (values, indices.validity().cloned()) + } +} + +/// Take kernel for single chunk with nulls and arrow array as index that may have nulls. +/// # Safety +/// caller must ensure indices are in bounds +pub unsafe fn take_primitive_unchecked( + arr: &PrimitiveArray, + indices: &IdxArr, +) -> PrimitiveArray { + let (values, validity) = + take_values_and_validity_unchecked(arr.values(), arr.validity(), indices); + PrimitiveArray::new_unchecked(arr.dtype().clone(), values.into(), validity) +} diff --git a/crates/polars-compute/src/gather/structure.rs b/crates/polars-compute/src/gather/structure.rs new file mode 100644 index 000000000000..dc71e013294a --- /dev/null +++ b/crates/polars-compute/src/gather/structure.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, StructArray}; +use arrow::compute::utils::combine_validities_and; +use arrow::datatypes::IdxArr; + +pub(super) unsafe fn take_unchecked(array: &StructArray, indices: &IdxArr) -> StructArray { + let values: Vec> = array + .values() + .iter() + .map(|a| super::take_unchecked(a.as_ref(), indices)) + .collect(); + + let validity = array + .validity() + .map(|b| super::bitmap::take_bitmap_nulls_unchecked(b, indices)); + let validity = combine_validities_and(validity.as_ref(), indices.validity()); + StructArray::new(array.dtype().clone(), indices.len(), values, validity) +} diff --git a/crates/polars-compute/src/gather/sublist/fixed_size_list.rs b/crates/polars-compute/src/gather/sublist/fixed_size_list.rs new file mode 100644 index 000000000000..57bf4350a88f --- /dev/null +++ b/crates/polars-compute/src/gather/sublist/fixed_size_list.rs @@ -0,0 +1,73 @@ +use arrow::array::{ArrayRef, FixedSizeListArray, PrimitiveArray}; +use arrow::legacy::prelude::*; +use arrow::legacy::utils::CustomIterTools; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::IdxSize; +use polars_utils::index::NullCount; + +use crate::gather::take_unchecked; + +fn sub_fixed_size_list_get_indexes_literal(width: usize, len: usize, index: i64) -> IdxArr { + (0..len) + .map(|i| { + if index >= width as i64 { + return None; + } + + index + .negative_to_usize(width) + .map(|idx| (idx + i * width) as IdxSize) + }) + .collect_trusted() +} + +fn sub_fixed_size_list_get_indexes(width: usize, index: &PrimitiveArray) -> IdxArr { + index + .iter() + .enumerate() + .map(|(i, idx)| { + if let Some(idx) = idx { + if *idx >= width as i64 { + return None; + } + + idx.negative_to_usize(width) + .map(|idx| (idx + i * width) as IdxSize) + } else { + None + } + }) + .collect_trusted() +} + +pub fn sub_fixed_size_list_get_literal( + arr: &FixedSizeListArray, + index: i64, + null_on_oob: bool, +) -> PolarsResult { + let take_by = sub_fixed_size_list_get_indexes_literal(arr.size(), arr.len(), index); + if !null_on_oob && take_by.null_count() > 0 { + polars_bail!(ComputeError: "get index is out of bounds"); + } + + let values = arr.values(); + // SAFETY: + // the indices we generate are in bounds + unsafe { Ok(take_unchecked(&**values, &take_by)) } +} + +pub fn sub_fixed_size_list_get( + arr: &FixedSizeListArray, + index: &PrimitiveArray, + null_on_oob: bool, +) -> PolarsResult { + let take_by = sub_fixed_size_list_get_indexes(arr.size(), index); + if !null_on_oob && take_by.null_count() > 0 { + polars_bail!(ComputeError: "get index is out of bounds"); + } + + let values = arr.values(); + // SAFETY: + // the indices we generate are in bounds + unsafe { Ok(take_unchecked(&**values, &take_by)) } +} diff --git a/crates/polars-compute/src/gather/sublist/list.rs b/crates/polars-compute/src/gather/sublist/list.rs new file mode 100644 index 000000000000..d54305fb5aac --- /dev/null +++ b/crates/polars-compute/src/gather/sublist/list.rs @@ -0,0 +1,209 @@ +use arrow::array::{Array, ArrayRef, ListArray}; +use arrow::legacy::prelude::*; +use arrow::legacy::trusted_len::TrustedLenPush; +use arrow::legacy::utils::CustomIterTools; +use arrow::offset::{Offsets, OffsetsBuffer}; +use polars_utils::IdxSize; + +use crate::gather::take_unchecked; + +/// Get the indices that would result in a get operation on the lists values. +/// for example, consider this list: +/// ```text +/// [[1, 2, 3], +/// [4, 5], +/// [6]] +/// +/// This contains the following values array: +/// [1, 2, 3, 4, 5, 6] +/// +/// get index 0 +/// would lead to the following indexes: +/// [0, 3, 5]. +/// if we use those in a take operation on the values array we get: +/// [1, 4, 6] +/// +/// +/// get index -1 +/// would lead to the following indexes: +/// [2, 4, 5]. +/// if we use those in a take operation on the values array we get: +/// [3, 5, 6] +/// +/// ``` +fn sublist_get_indexes(arr: &ListArray, index: i64) -> IdxArr { + let offsets = arr.offsets().as_slice(); + let mut iter = offsets.iter(); + + // the indices can be sliced, so we should not start at 0. + let mut cum_offset = (*offsets.first().unwrap_or(&0)) as IdxSize; + + if let Some(mut previous) = iter.next().copied() { + if arr.null_count() == 0 { + iter.map(|&offset| { + let len = offset - previous; + previous = offset; + // make sure that empty lists don't get accessed + // and out of bounds return null + if len == 0 { + return None; + } + if index >= len { + cum_offset += len as IdxSize; + return None; + } + + let out = index + .negative_to_usize(len as usize) + .map(|idx| idx as IdxSize + cum_offset); + cum_offset += len as IdxSize; + out + }) + .collect_trusted() + } else { + // we can ensure that validity is not none as we have null value. + let validity = arr.validity().unwrap(); + iter.enumerate() + .map(|(i, &offset)| { + let len = offset - previous; + previous = offset; + // make sure that empty and null lists don't get accessed and return null. + // SAFETY, we are within bounds + if len == 0 || !unsafe { validity.get_bit_unchecked(i) } { + cum_offset += len as IdxSize; + return None; + } + + // make sure that out of bounds return null + if index >= len { + cum_offset += len as IdxSize; + return None; + } + + let out = index + .negative_to_usize(len as usize) + .map(|idx| idx as IdxSize + cum_offset); + cum_offset += len as IdxSize; + out + }) + .collect_trusted() + } + } else { + IdxArr::from_slice([]) + } +} + +pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { + let take_by = sublist_get_indexes(arr, index); + let values = arr.values(); + // SAFETY: + // the indices we generate are in bounds + unsafe { take_unchecked(&**values, &take_by) } +} + +/// Check if an index is out of bounds for at least one sublist. +pub fn index_is_oob(arr: &ListArray, index: i64) -> bool { + if arr.null_count() == 0 { + arr.offsets() + .lengths() + .any(|len| index.negative_to_usize(len).is_none()) + } else { + arr.offsets() + .lengths() + .zip(arr.validity().unwrap()) + .any(|(len, valid)| { + if valid { + index.negative_to_usize(len).is_none() + } else { + // skip nulls + false + } + }) + } +} + +/// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]` +pub fn array_to_unit_list(array: ArrayRef) -> ListArray { + let len = array.len(); + let mut offsets = Vec::with_capacity(len + 1); + // SAFETY: we allocated enough + unsafe { + offsets.push_unchecked(0i64); + + for _ in 0..len { + offsets.push_unchecked(offsets.len() as i64) + } + }; + + // SAFETY: + // offsets are monotonically increasing + unsafe { + let offsets: OffsetsBuffer = Offsets::new_unchecked(offsets).into(); + let dtype = ListArray::::default_datatype(array.dtype().clone()); + ListArray::::new(dtype, offsets, array, None) + } +} + +#[cfg(test)] +mod test { + use arrow::array::{Int32Array, PrimitiveArray}; + use arrow::datatypes::ArrowDataType; + + use super::*; + + fn get_array() -> ListArray { + let values = Int32Array::from_slice([1, 2, 3, 4, 5, 6]); + let offsets = OffsetsBuffer::try_from(vec![0i64, 3, 5, 6]).unwrap(); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + ListArray::::new(dtype, offsets, Box::new(values), None) + } + + #[test] + fn test_sublist_get_indexes() { + let arr = get_array(); + let out = sublist_get_indexes(&arr, 0); + assert_eq!(out.values().as_slice(), &[0, 3, 5]); + let out = sublist_get_indexes(&arr, -1); + assert_eq!(out.values().as_slice(), &[2, 4, 5]); + let out = sublist_get_indexes(&arr, 3); + assert_eq!(out.null_count(), 3); + + let values = Int32Array::from_iter([ + Some(1), + Some(1), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + None, + Some(11), + ]); + let offsets = OffsetsBuffer::try_from(vec![0i64, 1, 2, 3, 6, 9, 11]).unwrap(); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let arr = ListArray::::new(dtype, offsets, Box::new(values), None); + + let out = sublist_get_indexes(&arr, 1); + assert_eq!( + out.into_iter().collect::>(), + &[None, None, None, Some(4), Some(7), Some(10)] + ); + } + + #[test] + fn test_sublist_get() { + let arr = get_array(); + + let out = sublist_get(&arr, 0); + let out = out.as_any().downcast_ref::>().unwrap(); + + assert_eq!(out.values().as_slice(), &[1, 4, 6]); + let out = sublist_get(&arr, -1); + let out = out.as_any().downcast_ref::>().unwrap(); + assert_eq!(out.values().as_slice(), &[3, 5, 6]); + } +} diff --git a/crates/polars-compute/src/gather/sublist/mod.rs b/crates/polars-compute/src/gather/sublist/mod.rs new file mode 100644 index 000000000000..24dd6560d6c1 --- /dev/null +++ b/crates/polars-compute/src/gather/sublist/mod.rs @@ -0,0 +1,3 @@ +//! Kernels for gathering values contained within lists. +pub mod fixed_size_list; +pub mod list; diff --git a/crates/polars-compute/src/horizontal_flatten/mod.rs b/crates/polars-compute/src/horizontal_flatten/mod.rs new file mode 100644 index 000000000000..032925c8fa0c --- /dev/null +++ b/crates/polars-compute/src/horizontal_flatten/mod.rs @@ -0,0 +1,185 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::{ + Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, + ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, PhysicalType}; +use arrow::with_match_primitive_type_full; +use strength_reduce::StrengthReducedUsize; +mod struct_; + +/// Low-level operation used by `concat_arr`. This should be called with the inner values array of +/// every FixedSizeList array. +/// +/// # Safety +/// * `arrays` is non-empty +/// * `arrays` and `widths` have equal length +/// * All widths in `widths` are non-zero +/// * Every array `arrays[i]` has a length of either +/// * `widths[i] * output_height` +/// * `widths[i]` (this would be broadcasted) +/// * All arrays in `arrays` have the same type +pub unsafe fn horizontal_flatten_unchecked( + arrays: &[Box], + widths: &[usize], + output_height: usize, +) -> Box { + use PhysicalType::*; + + let dtype = arrays[0].dtype(); + + match dtype.to_physical_type() { + Null => Box::new(NullArray::new( + dtype.clone(), + output_height * widths.iter().copied().sum::(), + )), + Boolean => Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| x.as_any().downcast_ref::().unwrap().clone()) + .collect::>(), + widths, + output_height, + dtype, + )), + Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { + Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| x.as_any().downcast_ref::>().unwrap().clone()) + .collect::>(), + widths, + output_height, + dtype + )) + }), + LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| { + x.as_any() + .downcast_ref::>() + .unwrap() + .clone() + }) + .collect::>(), + widths, + output_height, + dtype, + )), + Struct => Box::new(struct_::horizontal_flatten_unchecked( + &arrays + .iter() + .map(|x| x.as_any().downcast_ref::().unwrap().clone()) + .collect::>(), + widths, + output_height, + )), + LargeList => Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| x.as_any().downcast_ref::>().unwrap().clone()) + .collect::>(), + widths, + output_height, + dtype, + )), + FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| { + x.as_any() + .downcast_ref::() + .unwrap() + .clone() + }) + .collect::>(), + widths, + output_height, + dtype, + )), + BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| { + x.as_any() + .downcast_ref::() + .unwrap() + .clone() + }) + .collect::>(), + widths, + output_height, + dtype, + )), + Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic( + &arrays + .iter() + .map(|x| x.as_any().downcast_ref::().unwrap().clone()) + .collect::>(), + widths, + output_height, + dtype, + )), + t => unimplemented!("horizontal_flatten not supported for data type {:?}", t), + } +} + +unsafe fn horizontal_flatten_unchecked_impl_generic( + arrays: &[T], + widths: &[usize], + output_height: usize, + dtype: &ArrowDataType, +) -> T +where + T: StaticArray, +{ + assert!(!arrays.is_empty()); + assert_eq!(widths.len(), arrays.len()); + + debug_assert!(widths.iter().all(|x| *x > 0)); + debug_assert!( + arrays + .iter() + .zip(widths) + .all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width) + ); + + // We modulo the array length to support broadcasting. + let lengths = arrays + .iter() + .map(|x| StrengthReducedUsize::new(x.len())) + .collect::>(); + let out_row_width: usize = widths.iter().cloned().sum(); + let out_len = out_row_width.checked_mul(output_height).unwrap(); + + let mut col_idx = 0; + let mut row_idx = 0; + let mut until = widths[0]; + let mut outer_row_idx = 0; + + // We do `0..out_len` to get an `ExactSizeIterator`. + (0..out_len) + .map(|_| { + let arr = arrays.get_unchecked(col_idx); + let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx)); + + row_idx += 1; + + if row_idx == until { + // Safety: All widths are non-zero so we only need to increment once. + col_idx = if 1 + col_idx == widths.len() { + outer_row_idx += 1; + 0 + } else { + 1 + col_idx + }; + row_idx = outer_row_idx * *widths.get_unchecked(col_idx); + until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx) + } + + out + }) + .collect_arr_trusted_with_dtype(dtype.clone()) +} diff --git a/crates/polars-compute/src/horizontal_flatten/struct_.rs b/crates/polars-compute/src/horizontal_flatten/struct_.rs new file mode 100644 index 000000000000..6dc12d93a1eb --- /dev/null +++ b/crates/polars-compute/src/horizontal_flatten/struct_.rs @@ -0,0 +1,88 @@ +use super::*; + +/// # Safety +/// All preconditions in [`super::horizontal_flatten_unchecked`] +pub(super) unsafe fn horizontal_flatten_unchecked( + arrays: &[StructArray], + widths: &[usize], + output_height: usize, +) -> StructArray { + // For StructArrays, we perform the flatten operation individually for every field in the struct + // as well as on the outer validity. We then construct the result array from the individual + // result parts. + + let dtype = arrays[0].dtype(); + + let field_arrays: Vec<&[Box]> = arrays + .iter() + .inspect(|x| debug_assert_eq!(x.dtype(), dtype)) + .map(|x| x.values()) + .collect::>(); + + let n_fields = field_arrays[0].len(); + + let mut scratch = Vec::with_capacity(field_arrays.len()); + // Safety: We can take by index as all struct arrays have the same columns names in the same + // order. + // Note: `field_arrays` can be empty for 0-field structs. + let field_arrays = (0..n_fields) + .map(|i| { + scratch.clear(); + scratch.extend(field_arrays.iter().map(|v| v[i].clone())); + + super::horizontal_flatten_unchecked(&scratch, widths, output_height) + }) + .collect::>(); + + let validity = if arrays.iter().any(|x| x.validity().is_some()) { + let max_height = output_height * widths.iter().fold(0usize, |a, b| a.max(*b)); + let mut shared_validity = None; + + // We need to create BooleanArrays from the Bitmaps for dispatch. + let validities: Vec = arrays + .iter() + .map(|x| { + x.validity().cloned().unwrap_or_else(|| { + if shared_validity.is_none() { + shared_validity = Some(Bitmap::new_with_value(true, max_height)) + }; + // We have to slice to exact length to pass an assertion. + shared_validity.clone().unwrap().sliced(0, x.len()) + }) + }) + .map(|x| BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None)) + .collect::>(); + + Some( + super::horizontal_flatten_unchecked_impl_generic::( + &validities, + widths, + output_height, + &ArrowDataType::Boolean, + ) + .as_any() + .downcast_ref::() + .unwrap() + .values() + .clone(), + ) + } else { + None + }; + + StructArray::new( + dtype.clone(), + if n_fields == 0 { + output_height * widths.iter().copied().sum::() + } else { + debug_assert_eq!( + field_arrays[0].len(), + output_height * widths.iter().copied().sum::() + ); + + field_arrays[0].len() + }, + field_arrays, + validity, + ) +} diff --git a/crates/polars-compute/src/hyperloglogplus.rs b/crates/polars-compute/src/hyperloglogplus.rs new file mode 100644 index 000000000000..70519e3752f4 --- /dev/null +++ b/crates/polars-compute/src/hyperloglogplus.rs @@ -0,0 +1,345 @@ +//! # HyperLogLogPlus +//! +//! `hyperloglogplus` module contains implementation of HyperLogLogPlus +//! algorithm for cardinality estimation so that [`crate::series::approx_n_unique`] function can +//! be efficiently implemented. +//! +//! This module borrows code from [arrow-datafusion](https://github.com/apache/arrow-datafusion/blob/93771052c5ac31f2cf22b8c25bf938656afe1047/datafusion/physical-expr/src/aggregate/hyperloglog.rs). +//! +//! # Examples +//! +//! ``` +//! # use polars_compute::hyperloglogplus::*; +//! let mut hllp = HyperLogLog::new(); +//! hllp.add(&12345); +//! hllp.add(&23456); +//! +//! assert_eq!(hllp.count(), 2); +//! ``` + +use std::hash::{BuildHasher, Hash}; +use std::marker::PhantomData; + +use polars_utils::aliases::PlFixedStateQuality; + +/// The greater is P, the smaller the error. +const HLL_P: usize = 14_usize; +/// The number of bits of the hash value used determining the number of leading zeros +const HLL_Q: usize = 64_usize - HLL_P; +const NUM_REGISTERS: usize = 1_usize << HLL_P; +/// Mask to obtain index into the registers +const HLL_P_MASK: u64 = (NUM_REGISTERS as u64) - 1; + +#[derive(Clone, Debug)] +pub struct HyperLogLog +where + T: Hash + ?Sized, +{ + registers: [u8; NUM_REGISTERS], + phantom: PhantomData, +} + +impl Default for HyperLogLog +where + T: Hash + ?Sized, +{ + fn default() -> Self { + Self::new() + } +} + +/// Fixed seed for the hashing so that values are consistent across runs +/// +/// Note that when we later move on to have serialized HLL register binaries +/// shared across cluster, this SEED will have to be consistent across all +/// parties otherwise we might have corruption. So ideally for later this seed +/// shall be part of the serialized form (or stay unchanged across versions). +const SEED: PlFixedStateQuality = PlFixedStateQuality::with_seed(0); + +impl HyperLogLog +where + T: Hash + ?Sized, +{ + /// Creates a new, empty HyperLogLog. + pub fn new() -> Self { + let registers = [0; NUM_REGISTERS]; + Self::new_with_registers(registers) + } + + /// Creates a HyperLogLog from already populated registers + /// note that this method should not be invoked in untrusted environment + /// because the internal structure of registers are not examined. + pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self { + Self { + registers, + phantom: PhantomData, + } + } + + #[inline] + fn hash_value(&self, obj: &T) -> u64 { + SEED.hash_one(obj) + } + + /// Adds an element to the HyperLogLog. + pub fn add(&mut self, obj: &T) { + let hash = self.hash_value(obj); + let index = (hash & HLL_P_MASK) as usize; + let p = ((hash >> HLL_P) | (1_u64 << HLL_Q)).trailing_zeros() + 1; + self.registers[index] = self.registers[index].max(p as u8); + } + + /// Get the register histogram (each value in register index into + /// the histogram; u32 is enough because we only have 2**14=16384 registers + #[inline] + fn get_histogram(&self) -> [u32; HLL_Q + 2] { + let mut histogram = [0; HLL_Q + 2]; + // hopefully this can be unrolled + for r in self.registers { + histogram[r as usize] += 1; + } + histogram + } + + /// Merge the other [`HyperLogLog`] into this one + pub fn merge(&mut self, other: &HyperLogLog) { + assert!( + self.registers.len() == other.registers.len(), + "unexpected got unequal register size, expect {}, got {}", + self.registers.len(), + other.registers.len() + ); + for i in 0..self.registers.len() { + self.registers[i] = self.registers[i].max(other.registers[i]); + } + } + + /// Guess the number of unique elements seen by the HyperLogLog. + pub fn count(&self) -> usize { + let histogram = self.get_histogram(); + let m = NUM_REGISTERS as f64; + let mut z = m * hll_tau((m - histogram[HLL_Q + 1] as f64) / m); + for i in histogram[1..=HLL_Q].iter().rev() { + z += *i as f64; + z *= 0.5; + } + z += m * hll_sigma(histogram[0] as f64 / m); + (0.5 / 2_f64.ln() * m * m / z).round() as usize + } +} + +/// Helper function sigma as defined in +/// "New cardinality estimation algorithms for HyperLogLog sketches" +/// Otmar Ertl, arXiv:1702.01284 +#[inline] +fn hll_sigma(x: f64) -> f64 { + if x == 1. { + f64::INFINITY + } else { + let mut y = 1.0; + let mut z = x; + let mut x = x; + loop { + x *= x; + let z_prime = z; + z += x * y; + y += y; + if z_prime == z { + break; + } + } + z + } +} + +/// Helper function tau as defined in +/// "New cardinality estimation algorithms for HyperLogLog sketches" +/// Otmar Ertl, arXiv:1702.01284 +#[inline] +fn hll_tau(x: f64) -> f64 { + if x == 0.0 || x == 1.0 { + 0.0 + } else { + let mut y = 1.0; + let mut z = 1.0 - x; + let mut x = x; + loop { + x = x.sqrt(); + let z_prime = z; + y *= 0.5; + z -= (1.0 - x).powi(2) * y; + if z_prime == z { + break; + } + } + z / 3.0 + } +} + +impl AsRef<[u8]> for HyperLogLog +where + T: Hash + ?Sized, +{ + fn as_ref(&self) -> &[u8] { + &self.registers + } +} + +impl Extend for HyperLogLog +where + T: Hash, +{ + fn extend>(&mut self, iter: S) { + for elem in iter { + self.add(&elem); + } + } +} + +impl<'a, T> Extend<&'a T> for HyperLogLog +where + T: 'a + Hash + ?Sized, +{ + fn extend>(&mut self, iter: S) { + for elem in iter { + self.add(elem); + } + } +} + +#[cfg(test)] +mod tests { + use super::{HyperLogLog, NUM_REGISTERS}; + + fn compare_with_delta(got: usize, expected: usize) { + let expected = expected as f64; + let diff = (got as f64) - expected; + let diff = diff.abs() / expected; + // times 6 because we want the tests to be stable + // so we allow a rather large margin of error + // this is adopted from redis's unit test version as well + let margin = 1.04 / ((NUM_REGISTERS as f64).sqrt()) * 6.0; + assert!( + diff <= margin, + "{} is not near {} percent of {} which is ({}, {})", + got, + margin, + expected, + expected * (1.0 - margin), + expected * (1.0 + margin) + ); + } + + macro_rules! sized_number_test { + ($SIZE: expr, $T: tt) => {{ + let mut hll = HyperLogLog::<$T>::new(); + for i in 0..$SIZE { + hll.add(&i); + } + compare_with_delta(hll.count(), $SIZE); + }}; + } + + macro_rules! typed_large_number_test { + ($SIZE: expr) => {{ + sized_number_test!($SIZE, u64); + sized_number_test!($SIZE, u128); + sized_number_test!($SIZE, i64); + sized_number_test!($SIZE, i128); + }}; + } + + macro_rules! typed_number_test { + ($SIZE: expr) => {{ + sized_number_test!($SIZE, u16); + sized_number_test!($SIZE, u32); + sized_number_test!($SIZE, i16); + sized_number_test!($SIZE, i32); + typed_large_number_test!($SIZE); + }}; + } + + #[test] + fn test_empty() { + let hll = HyperLogLog::::new(); + assert_eq!(hll.count(), 0); + } + + #[test] + fn test_one() { + let mut hll = HyperLogLog::::new(); + hll.add(&1); + assert_eq!(hll.count(), 1); + } + + #[test] + fn test_number_100() { + typed_number_test!(100); + } + + #[test] + fn test_number_1k() { + typed_number_test!(1_000); + } + + #[test] + fn test_number_10k() { + typed_number_test!(10_000); + } + + #[test] + fn test_number_100k() { + typed_large_number_test!(100_000); + } + + #[test] + fn test_number_1m() { + typed_large_number_test!(1_000_000); + } + + #[test] + fn test_u8() { + let mut hll = HyperLogLog::<[u8]>::new(); + for i in 0..1000 { + let s = i.to_string(); + let b = s.as_bytes(); + hll.add(b); + } + compare_with_delta(hll.count(), 1000); + } + + #[test] + fn test_string() { + let mut hll = HyperLogLog::::new(); + hll.extend((0..1000).map(|i| i.to_string())); + compare_with_delta(hll.count(), 1000); + } + + #[test] + fn test_empty_merge() { + let mut hll = HyperLogLog::::new(); + hll.merge(&HyperLogLog::::new()); + assert_eq!(hll.count(), 0); + } + + #[test] + fn test_merge_overlapped() { + let mut hll = HyperLogLog::::new(); + hll.extend((0..1000).map(|i| i.to_string())); + + let mut other = HyperLogLog::::new(); + other.extend((0..1000).map(|i| i.to_string())); + + hll.merge(&other); + compare_with_delta(hll.count(), 1000); + } + + #[test] + fn test_repetition() { + let mut hll = HyperLogLog::::new(); + for i in 0..1_000_000 { + hll.add(&(i % 1000)); + } + compare_with_delta(hll.count(), 1000); + } +} diff --git a/crates/polars-compute/src/if_then_else/array.rs b/crates/polars-compute/src/if_then_else/array.rs new file mode 100644 index 000000000000..49fc9efee17c --- /dev/null +++ b/crates/polars-compute/src/if_then_else/array.rs @@ -0,0 +1,89 @@ +use arrow::array::builder::{ShareStrategy, StaticArrayBuilder, make_builder}; +use arrow::array::{Array, ArrayCollectIterExt, FixedSizeListArray, FixedSizeListArrayBuilder}; +use arrow::bitmap::Bitmap; + +use super::{IfThenElseKernel, if_then_else_extend}; + +impl IfThenElseKernel for FixedSizeListArray { + type Scalar<'a> = Box; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let inner_dt = if_true.dtype().inner_dtype().unwrap(); + let mut builder = + FixedSizeListArrayBuilder::new(if_true.dtype().clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, off, len| b.subslice_extend(if_true, off, len, ShareStrategy::Always), + |b, off, len| b.subslice_extend(if_false, off, len, ShareStrategy::Always), + ); + builder.freeze() + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let if_true_list: FixedSizeListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.dtype().clone()); + let inner_dt = if_false.dtype().inner_dtype().unwrap(); + let mut builder = + FixedSizeListArrayBuilder::new(if_false.dtype().clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, _, len| b.subslice_extend_repeated(&if_true_list, 0, 1, len, ShareStrategy::Always), + |b, off, len| b.subslice_extend(if_false, off, len, ShareStrategy::Always), + ); + builder.freeze() + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_false_list: FixedSizeListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.dtype().clone()); + let inner_dt = if_true.dtype().inner_dtype().unwrap(); + let mut builder = + FixedSizeListArrayBuilder::new(if_true.dtype().clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, off, len| b.subslice_extend(if_true, off, len, ShareStrategy::Always), + |b, _, len| { + b.subslice_extend_repeated(&if_false_list, 0, 1, len, ShareStrategy::Always) + }, + ); + builder.freeze() + } + + fn if_then_else_broadcast_both( + dtype: arrow::datatypes::ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_true_list: FixedSizeListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(dtype.clone()); + let if_false_list: FixedSizeListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(dtype.clone()); + let inner_dt = dtype.inner_dtype().unwrap(); + let mut builder = FixedSizeListArrayBuilder::new(dtype.clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, _, len| b.subslice_extend_repeated(&if_true_list, 0, 1, len, ShareStrategy::Always), + |b, _, len| { + b.subslice_extend_repeated(&if_false_list, 0, 1, len, ShareStrategy::Always) + }, + ); + builder.freeze() + } +} diff --git a/crates/polars-compute/src/if_then_else/boolean.rs b/crates/polars-compute/src/if_then_else/boolean.rs new file mode 100644 index 000000000000..7aabeff8d7ac --- /dev/null +++ b/crates/polars-compute/src/if_then_else/boolean.rs @@ -0,0 +1,60 @@ +use arrow::array::BooleanArray; +use arrow::bitmap::{self, Bitmap}; +use arrow::datatypes::ArrowDataType; + +use super::{IfThenElseKernel, if_then_else_validity}; + +impl IfThenElseKernel for BooleanArray { + type Scalar<'a> = bool; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let values = bitmap::ternary(mask, if_true.values(), if_false.values(), |m, t, f| { + (m & t) | (!m & f) + }); + let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity()); + BooleanArray::from(values).with_validity(validity) + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let values = if if_true { + bitmap::or(if_false.values(), mask) // (m & true) | (!m & f) -> f | m + } else { + bitmap::and_not(if_false.values(), mask) // (m & false) | (!m & f) -> f & !m + }; + let validity = if_then_else_validity(mask, None, if_false.validity()); + BooleanArray::from(values).with_validity(validity) + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if if_false { + bitmap::or_not(if_true.values(), mask) // (m & t) | (!m & true) -> t | !m + } else { + bitmap::and(if_true.values(), mask) // (m & t) | (!m & false) -> t & m + }; + let validity = if_then_else_validity(mask, if_true.validity(), None); + BooleanArray::from(values).with_validity(validity) + } + + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = match (if_true, if_false) { + (false, false) => Bitmap::new_with_value(false, mask.len()), + (false, true) => !mask, + (true, false) => mask.clone(), + (true, true) => Bitmap::new_with_value(true, mask.len()), + }; + BooleanArray::from(values) + } +} diff --git a/crates/polars-compute/src/if_then_else/list.rs b/crates/polars-compute/src/if_then_else/list.rs new file mode 100644 index 000000000000..39f549b1514b --- /dev/null +++ b/crates/polars-compute/src/if_then_else/list.rs @@ -0,0 +1,86 @@ +use arrow::array::builder::{ShareStrategy, StaticArrayBuilder, make_builder}; +use arrow::array::{Array, ArrayCollectIterExt, ListArray, ListArrayBuilder}; +use arrow::bitmap::Bitmap; + +use super::{IfThenElseKernel, if_then_else_extend}; + +impl IfThenElseKernel for ListArray { + type Scalar<'a> = Box; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let inner_dt = if_true.dtype().inner_dtype().unwrap(); + let mut builder = ListArrayBuilder::new(if_true.dtype().clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, off, len| b.subslice_extend(if_true, off, len, ShareStrategy::Always), + |b, off, len| b.subslice_extend(if_false, off, len, ShareStrategy::Always), + ); + builder.freeze() + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let if_true_list: ListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.dtype().clone()); + let inner_dt = if_false.dtype().inner_dtype().unwrap(); + let mut builder = ListArrayBuilder::new(if_false.dtype().clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, _, len| b.subslice_extend_repeated(&if_true_list, 0, 1, len, ShareStrategy::Always), + |b, off, len| b.subslice_extend(if_false, off, len, ShareStrategy::Always), + ); + builder.freeze() + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_false_list: ListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.dtype().clone()); + let inner_dt = if_true.dtype().inner_dtype().unwrap(); + let mut builder = ListArrayBuilder::new(if_true.dtype().clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, off, len| b.subslice_extend(if_true, off, len, ShareStrategy::Always), + |b, _, len| { + b.subslice_extend_repeated(&if_false_list, 0, 1, len, ShareStrategy::Always) + }, + ); + builder.freeze() + } + + fn if_then_else_broadcast_both( + dtype: arrow::datatypes::ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let if_true_list: ListArray = + std::iter::once(if_true).collect_arr_trusted_with_dtype(dtype.clone()); + let if_false_list: ListArray = + std::iter::once(if_false).collect_arr_trusted_with_dtype(dtype.clone()); + let inner_dt = dtype.inner_dtype().unwrap(); + let mut builder = ListArrayBuilder::new(dtype.clone(), make_builder(inner_dt)); + builder.reserve(mask.len()); + if_then_else_extend( + &mut builder, + mask, + |b, _, len| b.subslice_extend_repeated(&if_true_list, 0, 1, len, ShareStrategy::Always), + |b, _, len| { + b.subslice_extend_repeated(&if_false_list, 0, 1, len, ShareStrategy::Always) + }, + ); + builder.freeze() + } +} diff --git a/crates/polars-compute/src/if_then_else/mod.rs b/crates/polars-compute/src/if_then_else/mod.rs new file mode 100644 index 000000000000..4d3723ae426a --- /dev/null +++ b/crates/polars-compute/src/if_then_else/mod.rs @@ -0,0 +1,292 @@ +use std::mem::MaybeUninit; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::utils::SlicesIterator; +use arrow::bitmap::{self, Bitmap}; +use arrow::datatypes::ArrowDataType; + +use crate::NotSimdPrimitive; + +mod array; +mod boolean; +mod list; +mod scalar; +#[cfg(feature = "simd")] +mod simd; +mod view; + +pub trait IfThenElseKernel: Sized + Array { + type Scalar<'a>; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self; + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self; + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self; + fn if_then_else_broadcast_both( + dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self; +} + +impl IfThenElseKernel for PrimitiveArray { + type Scalar<'a> = T; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let values = if_then_else_loop( + mask, + if_true.values(), + if_false.values(), + scalar::if_then_else_scalar_rest, + scalar::if_then_else_scalar_64, + ); + let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity()); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let values = if_then_else_loop_broadcast_false( + true, + mask, + if_false.values(), + if_true, + scalar::if_then_else_broadcast_false_scalar_64, + ); + let validity = if_then_else_validity(mask, None, if_false.validity()); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if_then_else_loop_broadcast_false( + false, + mask, + if_true.values(), + if_false, + scalar::if_then_else_broadcast_false_scalar_64, + ); + let validity = if_then_else_validity(mask, if_true.validity(), None); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if_then_else_loop_broadcast_both( + mask, + if_true, + if_false, + scalar::if_then_else_broadcast_both_scalar_64, + ); + PrimitiveArray::from_vec(values) + } +} + +pub fn if_then_else_validity( + mask: &Bitmap, + if_true: Option<&Bitmap>, + if_false: Option<&Bitmap>, +) -> Option { + match (if_true, if_false) { + (None, None) => None, + (None, Some(f)) => Some(mask | f), + (Some(t), None) => Some(bitmap::binary(mask, t, |m, t| !m | t)), + (Some(t), Some(f)) => Some(bitmap::ternary(mask, t, f, |m, t, f| (m & t) | (!m & f))), + } +} + +fn if_then_else_extend( + builder: &mut B, + mask: &Bitmap, + extend_true: ET, + extend_false: EF, +) { + let mut last_true_end = 0; + for (start, len) in SlicesIterator::new(mask) { + if start != last_true_end { + extend_false(builder, last_true_end, start - last_true_end); + }; + extend_true(builder, start, len); + last_true_end = start + len; + } + if last_true_end != mask.len() { + extend_false(builder, last_true_end, mask.len() - last_true_end) + } +} + +fn if_then_else_loop( + mask: &Bitmap, + if_true: &[T], + if_false: &[T], + process_var: F, + process_chunk: F64, +) -> Vec +where + T: Copy, + F: Fn(u64, &[T], &[T], &mut [MaybeUninit]), + F64: Fn(u64, &[T; 64], &[T; 64], &mut [MaybeUninit; 64]), +{ + assert_eq!(mask.len(), if_true.len()); + assert_eq!(mask.len(), if_false.len()); + + let mut ret = Vec::with_capacity(mask.len()); + let out = &mut ret.spare_capacity_mut()[..mask.len()]; + + // Handle prefix. + let aligned = mask.aligned::(); + let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen()); + let (start_false, rest_false) = if_false.split_at(aligned.prefix_bitlen()); + let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen()); + if aligned.prefix_bitlen() > 0 { + process_var(aligned.prefix(), start_true, start_false, start_out); + } + + // Handle bulk. + let mut true_chunks = rest_true.chunks_exact(64); + let mut false_chunks = rest_false.chunks_exact(64); + let mut out_chunks = rest_out.chunks_exact_mut(64); + let combined = true_chunks + .by_ref() + .zip(false_chunks.by_ref()) + .zip(out_chunks.by_ref()); + for (i, ((tc, fc), oc)) in combined.enumerate() { + let m = unsafe { *aligned.bulk().get_unchecked(i) }; + process_chunk( + m, + tc.try_into().unwrap(), + fc.try_into().unwrap(), + oc.try_into().unwrap(), + ); + } + + // Handle suffix. + if aligned.suffix_bitlen() > 0 { + process_var( + aligned.suffix(), + true_chunks.remainder(), + false_chunks.remainder(), + out_chunks.into_remainder(), + ); + } + + unsafe { + ret.set_len(mask.len()); + } + ret +} + +fn if_then_else_loop_broadcast_false( + invert_mask: bool, // Allows code reuse for both false and true broadcasts. + mask: &Bitmap, + if_true: &[T], + if_false: T, + process_chunk: F64, +) -> Vec +where + T: Copy, + F64: Fn(u64, &[T; 64], T, &mut [MaybeUninit; 64]), +{ + assert_eq!(mask.len(), if_true.len()); + + let mut ret = Vec::with_capacity(mask.len()); + let out = &mut ret.spare_capacity_mut()[..mask.len()]; + + // XOR with all 1's inverts the mask. + let xor_inverter = if invert_mask { u64::MAX } else { 0 }; + + // Handle prefix. + let aligned = mask.aligned::(); + let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen()); + let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen()); + if aligned.prefix_bitlen() > 0 { + scalar::if_then_else_broadcast_false_scalar_rest( + aligned.prefix() ^ xor_inverter, + start_true, + if_false, + start_out, + ); + } + + // Handle bulk. + let mut true_chunks = rest_true.chunks_exact(64); + let mut out_chunks = rest_out.chunks_exact_mut(64); + let combined = true_chunks.by_ref().zip(out_chunks.by_ref()); + for (i, (tc, oc)) in combined.enumerate() { + let m = unsafe { *aligned.bulk().get_unchecked(i) } ^ xor_inverter; + process_chunk(m, tc.try_into().unwrap(), if_false, oc.try_into().unwrap()); + } + + // Handle suffix. + if aligned.suffix_bitlen() > 0 { + scalar::if_then_else_broadcast_false_scalar_rest( + aligned.suffix() ^ xor_inverter, + true_chunks.remainder(), + if_false, + out_chunks.into_remainder(), + ); + } + + unsafe { + ret.set_len(mask.len()); + } + ret +} + +fn if_then_else_loop_broadcast_both( + mask: &Bitmap, + if_true: T, + if_false: T, + generate_chunk: F64, +) -> Vec +where + T: Copy, + F64: Fn(u64, T, T, &mut [MaybeUninit; 64]), +{ + let mut ret = Vec::with_capacity(mask.len()); + let out = &mut ret.spare_capacity_mut()[..mask.len()]; + + // Handle prefix. + let aligned = mask.aligned::(); + let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen()); + scalar::if_then_else_broadcast_both_scalar_rest(aligned.prefix(), if_true, if_false, start_out); + + // Handle bulk. + let mut out_chunks = rest_out.chunks_exact_mut(64); + for (i, oc) in out_chunks.by_ref().enumerate() { + let m = unsafe { *aligned.bulk().get_unchecked(i) }; + generate_chunk(m, if_true, if_false, oc.try_into().unwrap()); + } + + // Handle suffix. + if aligned.suffix_bitlen() > 0 { + scalar::if_then_else_broadcast_both_scalar_rest( + aligned.suffix(), + if_true, + if_false, + out_chunks.into_remainder(), + ); + } + + unsafe { + ret.set_len(mask.len()); + } + ret +} diff --git a/crates/polars-compute/src/if_then_else/scalar.rs b/crates/polars-compute/src/if_then_else/scalar.rs new file mode 100644 index 000000000000..310da8cfb121 --- /dev/null +++ b/crates/polars-compute/src/if_then_else/scalar.rs @@ -0,0 +1,76 @@ +use std::mem::MaybeUninit; + +pub fn if_then_else_scalar_rest( + mask: u64, + if_true: &[T], + if_false: &[T], + out: &mut [MaybeUninit], +) { + assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop. + let true_it = if_true.iter().copied(); + let false_it = if_false.iter().copied(); + for (i, (t, f)) in true_it.zip(false_it).enumerate() { + let src = if (mask >> i) & 1 != 0 { t } else { f }; + out[i] = MaybeUninit::new(src); + } +} + +pub fn if_then_else_broadcast_false_scalar_rest( + mask: u64, + if_true: &[T], + if_false: T, + out: &mut [MaybeUninit], +) { + assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop. + let true_it = if_true.iter().copied(); + for (i, t) in true_it.enumerate() { + let src = if (mask >> i) & 1 != 0 { t } else { if_false }; + out[i] = MaybeUninit::new(src); + } +} + +pub fn if_then_else_broadcast_both_scalar_rest( + mask: u64, + if_true: T, + if_false: T, + out: &mut [MaybeUninit], +) { + for (i, dst) in out.iter_mut().enumerate() { + let src = if (mask >> i) & 1 != 0 { + if_true + } else { + if_false + }; + *dst = MaybeUninit::new(src); + } +} + +pub fn if_then_else_scalar_64( + mask: u64, + if_true: &[T; 64], + if_false: &[T; 64], + out: &mut [MaybeUninit; 64], +) { + // This generated the best autovectorized code on ARM, and branchless everywhere. + if_then_else_scalar_rest(mask, if_true, if_false, out) +} + +pub fn if_then_else_broadcast_false_scalar_64( + mask: u64, + if_true: &[T; 64], + if_false: T, + out: &mut [MaybeUninit; 64], +) { + // This generated the best autovectorized code on ARM, and branchless everywhere. + if_then_else_broadcast_false_scalar_rest(mask, if_true, if_false, out) +} + +pub fn if_then_else_broadcast_both_scalar_64( + mask: u64, + if_true: T, + if_false: T, + out: &mut [MaybeUninit; 64], +) { + // This generated the best autovectorized code on ARM, and branchless everywhere. + if_then_else_broadcast_both_scalar_rest(mask, if_true, if_false, out) +} diff --git a/crates/polars-compute/src/if_then_else/simd.rs b/crates/polars-compute/src/if_then_else/simd.rs new file mode 100644 index 000000000000..e86e91bc72f6 --- /dev/null +++ b/crates/polars-compute/src/if_then_else/simd.rs @@ -0,0 +1,157 @@ +#[cfg(target_arch = "x86_64")] +use std::mem::MaybeUninit; +#[cfg(target_arch = "x86_64")] +use std::simd::{Mask, Simd, SimdElement}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +use super::{ + IfThenElseKernel, if_then_else_loop, if_then_else_loop_broadcast_both, + if_then_else_loop_broadcast_false, if_then_else_validity, scalar, +}; + +#[cfg(target_arch = "x86_64")] +fn select_simd_64( + mask: u64, + if_true: Simd, + if_false: Simd, + out: &mut [MaybeUninit; 64], +) { + let mv = Mask::<::Mask, 64>::from_bitmask(mask); + let ret = mv.select(if_true, if_false); + unsafe { + let src = ret.as_array().as_ptr() as *const MaybeUninit; + core::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), 64); + } +} + +#[cfg(target_arch = "x86_64")] +fn if_then_else_simd_64( + mask: u64, + if_true: &[T; 64], + if_false: &[T; 64], + out: &mut [MaybeUninit; 64], +) { + select_simd_64( + mask, + Simd::from_slice(if_true), + Simd::from_slice(if_false), + out, + ) +} + +#[cfg(target_arch = "x86_64")] +fn if_then_else_broadcast_false_simd_64( + mask: u64, + if_true: &[T; 64], + if_false: T, + out: &mut [MaybeUninit; 64], +) { + select_simd_64(mask, Simd::from_slice(if_true), Simd::splat(if_false), out) +} + +#[cfg(target_arch = "x86_64")] +fn if_then_else_broadcast_both_simd_64( + mask: u64, + if_true: T, + if_false: T, + out: &mut [MaybeUninit; 64], +) { + select_simd_64(mask, Simd::splat(if_true), Simd::splat(if_false), out) +} + +macro_rules! impl_if_then_else { + ($T: ty) => { + impl IfThenElseKernel for PrimitiveArray<$T> { + type Scalar<'a> = $T; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let values = if_then_else_loop( + mask, + if_true.values(), + if_false.values(), + scalar::if_then_else_scalar_rest, + // Auto-generated SIMD was slower on ARM. + #[cfg(target_arch = "x86_64")] + if_then_else_simd_64, + #[cfg(not(target_arch = "x86_64"))] + scalar::if_then_else_scalar_64, + ); + let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity()); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let values = if_then_else_loop_broadcast_false( + true, + mask, + if_false.values(), + if_true, + // Auto-generated SIMD was slower on ARM. + #[cfg(target_arch = "x86_64")] + if_then_else_broadcast_false_simd_64, + #[cfg(not(target_arch = "x86_64"))] + scalar::if_then_else_broadcast_false_scalar_64, + ); + let validity = if_then_else_validity(mask, None, if_false.validity()); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if_then_else_loop_broadcast_false( + false, + mask, + if_true.values(), + if_false, + // Auto-generated SIMD was slower on ARM. + #[cfg(target_arch = "x86_64")] + if_then_else_broadcast_false_simd_64, + #[cfg(not(target_arch = "x86_64"))] + scalar::if_then_else_broadcast_false_scalar_64, + ); + let validity = if_then_else_validity(mask, if_true.validity(), None); + PrimitiveArray::from_vec(values).with_validity(validity) + } + + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let values = if_then_else_loop_broadcast_both( + mask, + if_true, + if_false, + // Auto-generated SIMD was slower on ARM. + #[cfg(target_arch = "x86_64")] + if_then_else_broadcast_both_simd_64, + #[cfg(not(target_arch = "x86_64"))] + scalar::if_then_else_broadcast_both_scalar_64, + ); + PrimitiveArray::from_vec(values) + } + } + }; +} + +impl_if_then_else!(i8); +impl_if_then_else!(i16); +impl_if_then_else!(i32); +impl_if_then_else!(i64); +impl_if_then_else!(u8); +impl_if_then_else!(u16); +impl_if_then_else!(u32); +impl_if_then_else!(u64); +impl_if_then_else!(f32); +impl_if_then_else!(f64); diff --git a/crates/polars-compute/src/if_then_else/view.rs b/crates/polars-compute/src/if_then_else/view.rs new file mode 100644 index 000000000000..5b3fd8fc4df9 --- /dev/null +++ b/crates/polars-compute/src/if_then_else/view.rs @@ -0,0 +1,298 @@ +use std::mem::MaybeUninit; +use std::ops::Deref; +use std::sync::Arc; + +use arrow::array::{Array, BinaryViewArray, MutablePlBinary, Utf8ViewArray, View}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; + +use super::IfThenElseKernel; +use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64; + +// Makes a buffer and a set of views into that buffer from a set of strings. +// Does not allocate a buffer if not necessary. +fn make_buffer_and_views( + strings: [&[u8]; N], + buffer_idx: u32, +) -> ([View; N], Option>) { + let mut buf_data = Vec::new(); + let views = strings.map(|s| { + let offset = buf_data.len().try_into().unwrap(); + if s.len() > 12 { + buf_data.extend(s); + } + View::new_from_bytes(s, buffer_idx, offset) + }); + let buf = (!buf_data.is_empty()).then(|| buf_data.into()); + (views, buf) +} + +fn has_duplicate_buffers(bufs: &[Buffer]) -> bool { + let mut has_duplicate_buffers = false; + let mut bufset = PlHashSet::new(); + for buf in bufs { + if !bufset.insert(buf.as_ptr()) { + has_duplicate_buffers = true; + break; + } + } + has_duplicate_buffers +} + +impl IfThenElseKernel for BinaryViewArray { + type Scalar<'a> = &'a [u8]; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let combined_buffers: Arc<_>; + let false_buffer_idx_offset: u32; + let mut has_duplicate_bufs = false; + if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) { + // Share exact same buffers, no need to combine. + combined_buffers = if_true.data_buffers().clone(); + false_buffer_idx_offset = 0; + } else { + // Put false buffers after true buffers. + let true_buffers = if_true.data_buffers().iter().cloned(); + let false_buffers = if_false.data_buffers().iter().cloned(); + + combined_buffers = true_buffers.chain(false_buffers).collect(); + has_duplicate_bufs = has_duplicate_buffers(&combined_buffers); + false_buffer_idx_offset = if_true.data_buffers().len() as u32; + } + + let views = super::if_then_else_loop( + mask, + if_true.views(), + if_false.views(), + |m, t, f, o| if_then_else_view_rest(m, t, f, o, false_buffer_idx_offset), + |m, t, f, o| if_then_else_view_64(m, t, f, o, false_buffer_idx_offset), + ); + + let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity()); + + let mut builder = MutablePlBinary::with_capacity(views.len()); + + if has_duplicate_bufs { + unsafe { + builder.extend_non_null_views_unchecked_dedupe( + views.into_iter(), + combined_buffers.deref(), + ) + }; + } else { + unsafe { + builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref()) + }; + } + builder + .freeze_with_dtype(if_true.dtype().clone()) + .with_validity(validity) + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + // It's cheaper if we put the false buffers first, that way we don't need to modify any views in the loop. + let false_buffers = if_false.data_buffers().iter().cloned(); + let true_buffer_idx_offset: u32 = if_false.data_buffers().len() as u32; + let ([true_view], true_buffer) = make_buffer_and_views([if_true], true_buffer_idx_offset); + let combined_buffers: Arc<_> = false_buffers.chain(true_buffer).collect(); + + let views = super::if_then_else_loop_broadcast_false( + true, // Invert the mask so we effectively broadcast true. + mask, + if_false.views(), + true_view, + if_then_else_broadcast_false_view_64, + ); + + let validity = super::if_then_else_validity(mask, None, if_false.validity()); + + let mut builder = MutablePlBinary::with_capacity(views.len()); + + unsafe { + if has_duplicate_buffers(&combined_buffers) { + builder.extend_non_null_views_unchecked_dedupe( + views.into_iter(), + combined_buffers.deref(), + ) + } else { + builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref()) + } + } + builder + .freeze_with_dtype(if_false.dtype().clone()) + .with_validity(validity) + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + // It's cheaper if we put the true buffers first, that way we don't need to modify any views in the loop. + let true_buffers = if_true.data_buffers().iter().cloned(); + let false_buffer_idx_offset: u32 = if_true.data_buffers().len() as u32; + let ([false_view], false_buffer) = + make_buffer_and_views([if_false], false_buffer_idx_offset); + let combined_buffers: Arc<_> = true_buffers.chain(false_buffer).collect(); + + let views = super::if_then_else_loop_broadcast_false( + false, + mask, + if_true.views(), + false_view, + if_then_else_broadcast_false_view_64, + ); + + let validity = super::if_then_else_validity(mask, if_true.validity(), None); + + let mut builder = MutablePlBinary::with_capacity(views.len()); + unsafe { + if has_duplicate_buffers(&combined_buffers) { + builder.extend_non_null_views_unchecked_dedupe( + views.into_iter(), + combined_buffers.deref(), + ) + } else { + builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref()) + } + }; + builder + .freeze_with_dtype(if_true.dtype().clone()) + .with_validity(validity) + } + + fn if_then_else_broadcast_both( + dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let ([true_view, false_view], buffer) = make_buffer_and_views([if_true, if_false], 0); + let buffers: Arc<_> = buffer.into_iter().collect(); + let views = super::if_then_else_loop_broadcast_both( + mask, + true_view, + false_view, + if_then_else_broadcast_both_scalar_64, + ); + + let mut builder = MutablePlBinary::with_capacity(views.len()); + unsafe { + if has_duplicate_buffers(&buffers) { + builder.extend_non_null_views_unchecked_dedupe(views.into_iter(), buffers.deref()) + } else { + builder.extend_non_null_views_unchecked(views.into_iter(), buffers.deref()) + } + }; + builder.freeze_with_dtype(dtype) + } +} + +impl IfThenElseKernel for Utf8ViewArray { + type Scalar<'a> = &'a str; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + let ret = + IfThenElseKernel::if_then_else(mask, &if_true.to_binview(), &if_false.to_binview()); + unsafe { ret.to_utf8view_unchecked() } + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + let ret = IfThenElseKernel::if_then_else_broadcast_true( + mask, + if_true.as_bytes(), + &if_false.to_binview(), + ); + unsafe { ret.to_utf8view_unchecked() } + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + let ret = IfThenElseKernel::if_then_else_broadcast_false( + mask, + &if_true.to_binview(), + if_false.as_bytes(), + ); + unsafe { ret.to_utf8view_unchecked() } + } + + fn if_then_else_broadcast_both( + dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + let ret: BinaryViewArray = IfThenElseKernel::if_then_else_broadcast_both( + dtype, + mask, + if_true.as_bytes(), + if_false.as_bytes(), + ); + unsafe { ret.to_utf8view_unchecked() } + } +} + +pub fn if_then_else_view_rest( + mask: u64, + if_true: &[View], + if_false: &[View], + out: &mut [MaybeUninit], + false_buffer_idx_offset: u32, +) { + assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop. + let true_it = if_true.iter(); + let false_it = if_false.iter(); + for (i, (t, f)) in true_it.zip(false_it).enumerate() { + // Written like this, this loop *should* be branchless. + // Unfortunately we're still dependent on the compiler. + let m = (mask >> i) & 1 != 0; + let src = if m { t } else { f }; + let mut v = *src; + let offset = if m | (v.length <= 12) { + // Yes, | instead of || is intentional. + 0 + } else { + false_buffer_idx_offset + }; + v.buffer_idx += offset; + out[i] = MaybeUninit::new(v); + } +} + +pub fn if_then_else_view_64( + mask: u64, + if_true: &[View; 64], + if_false: &[View; 64], + out: &mut [MaybeUninit; 64], + false_buffer_idx_offset: u32, +) { + if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset) +} + +// Using the scalar variant of this works, but was slower, we want to select a source pointer and +// then copy it. Using this version for the integers results in branches. +pub fn if_then_else_broadcast_false_view_64( + mask: u64, + if_true: &[View; 64], + if_false: View, + out: &mut [MaybeUninit; 64], +) { + assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop. + for (i, t) in if_true.iter().enumerate() { + let src = if (mask >> i) & 1 != 0 { t } else { &if_false }; + out[i] = MaybeUninit::new(*src); + } +} diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs new file mode 100644 index 000000000000..92b017846e6b --- /dev/null +++ b/crates/polars-compute/src/lib.rs @@ -0,0 +1,66 @@ +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "simd", feature(avx512_target_feature))] +#![cfg_attr( + all(feature = "simd", target_arch = "x86_64"), + feature(stdarch_x86_avx512) +)] + +use arrow::types::NativeType; + +pub mod arithmetic; +pub mod arity; +pub mod binview_index_map; +pub mod bitwise; +#[cfg(feature = "approx_unique")] +pub mod cardinality; +#[cfg(feature = "cast")] +pub mod cast; +pub mod comparisons; +pub mod filter; +pub mod float_sum; +#[cfg(feature = "gather")] +pub mod gather; +pub mod horizontal_flatten; +#[cfg(feature = "approx_unique")] +pub mod hyperloglogplus; +pub mod if_then_else; +pub mod min_max; +pub mod moment; +pub mod propagate_dictionary; +pub mod rolling; +pub mod size; +pub mod sum; +pub mod unique; + +// Trait to enable the scalar blanket implementation. +pub trait NotSimdPrimitive: NativeType {} + +#[cfg(not(feature = "simd"))] +impl NotSimdPrimitive for T {} + +#[cfg(feature = "simd")] +impl NotSimdPrimitive for u128 {} +#[cfg(feature = "simd")] +impl NotSimdPrimitive for i128 {} + +// Trait to allow blanket impl for all SIMD types when simd is enabled. +#[cfg(feature = "simd")] +mod _simd_primitive { + use std::simd::SimdElement; + pub trait SimdPrimitive: SimdElement {} + impl SimdPrimitive for u8 {} + impl SimdPrimitive for u16 {} + impl SimdPrimitive for u32 {} + impl SimdPrimitive for u64 {} + impl SimdPrimitive for usize {} + impl SimdPrimitive for i8 {} + impl SimdPrimitive for i16 {} + impl SimdPrimitive for i32 {} + impl SimdPrimitive for i64 {} + impl SimdPrimitive for isize {} + impl SimdPrimitive for f32 {} + impl SimdPrimitive for f64 {} +} + +#[cfg(feature = "simd")] +pub use _simd_primitive::SimdPrimitive; diff --git a/crates/polars-compute/src/min_max/dyn_array.rs b/crates/polars-compute/src/min_max/dyn_array.rs new file mode 100644 index 000000000000..20af38dedc63 --- /dev/null +++ b/crates/polars-compute/src/min_max/dyn_array.rs @@ -0,0 +1,92 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray, +}; +use arrow::scalar::{BinaryScalar, BinaryViewScalar, BooleanScalar, PrimitiveScalar, Scalar}; + +use crate::min_max::MinMaxKernel; + +macro_rules! call_op { + ($T:ty, $scalar:ty, $arr:expr, $op:path) => {{ + let arr: &$T = $arr.as_any().downcast_ref().unwrap(); + $op(arr).map(|v| Box::new(<$scalar>::new(Some(v))) as Box) + }}; + (dt: $T:ty, $scalar:ty, $arr:expr, $op:path) => {{ + let arr: &$T = $arr.as_any().downcast_ref().unwrap(); + $op(arr).map(|v| Box::new(<$scalar>::new(arr.dtype().clone(), Some(v))) as Box) + }}; + ($T:ty, $scalar:ty, $arr:expr, $op:path, ret_two) => {{ + let arr: &$T = $arr.as_any().downcast_ref().unwrap(); + $op(arr).map(|(l, r)| { + ( + Box::new(<$scalar>::new(Some(l))) as Box, + Box::new(<$scalar>::new(Some(r))) as Box, + ) + }) + }}; + (dt: $T:ty, $scalar:ty, $arr:expr, $op:path, ret_two) => {{ + let arr: &$T = $arr.as_any().downcast_ref().unwrap(); + $op(arr).map(|(l, r)| { + ( + Box::new(<$scalar>::new(arr.dtype().clone(), Some(l))) as Box, + Box::new(<$scalar>::new(arr.dtype().clone(), Some(r))) as Box, + ) + }) + }}; +} + +macro_rules! call { + ($arr:expr, $op:path$(, $variant:ident)?) => {{ + let arr = $arr; + + use arrow::datatypes::{PhysicalType as PH, PrimitiveType as PR}; + use PrimitiveArray as PArr; + use PrimitiveScalar as PScalar; + match arr.dtype().to_physical_type() { + PH::Boolean => call_op!(BooleanArray, BooleanScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Int8) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Int16) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Int32) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Int64) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Int128) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::UInt8) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::UInt16) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::UInt32) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::UInt64) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::UInt128) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Float32) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + PH::Primitive(PR::Float64) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), + + PH::BinaryView => call_op!(BinaryViewArray, BinaryViewScalar<[u8]>, arr, $op$(, $variant)?), + PH::Utf8View => call_op!(Utf8ViewArray, BinaryViewScalar, arr, $op$(, $variant)?), + + PH::Binary => call_op!(BinaryArray, BinaryScalar, arr, $op$(, $variant)?), + PH::LargeBinary => call_op!(BinaryArray, BinaryScalar, arr, $op$(, $variant)?), + PH::Utf8 => call_op!(Utf8Array, BinaryScalar, arr, $op$(, $variant)?), + PH::LargeUtf8 => call_op!(Utf8Array, BinaryScalar, arr, $op$(, $variant)?), + + _ => todo!("Dynamic MinMax is not yet implemented for {:?}", arr.dtype()), + } + }}; +} + +pub fn dyn_array_min_ignore_nan(arr: &dyn Array) -> Option> { + call!(arr, MinMaxKernel::min_ignore_nan_kernel) +} + +pub fn dyn_array_max_ignore_nan(arr: &dyn Array) -> Option> { + call!(arr, MinMaxKernel::max_ignore_nan_kernel) +} + +pub fn dyn_array_min_propagate_nan(arr: &dyn Array) -> Option> { + call!(arr, MinMaxKernel::min_propagate_nan_kernel) +} + +pub fn dyn_array_max_propagate_nan(arr: &dyn Array) -> Option> { + call!(arr, MinMaxKernel::max_propagate_nan_kernel) +} + +pub fn dyn_array_min_max_propagate_nan( + arr: &dyn Array, +) -> Option<(Box, Box)> { + call!(arr, MinMaxKernel::min_max_propagate_nan_kernel, ret_two) +} diff --git a/crates/polars-compute/src/min_max/mod.rs b/crates/polars-compute/src/min_max/mod.rs new file mode 100644 index 000000000000..0ffb15499449 --- /dev/null +++ b/crates/polars-compute/src/min_max/mod.rs @@ -0,0 +1,45 @@ +use polars_utils::min_max::MinMax; + +pub use self::dyn_array::{ + dyn_array_max_ignore_nan, dyn_array_max_propagate_nan, dyn_array_min_ignore_nan, + dyn_array_min_max_propagate_nan, dyn_array_min_propagate_nan, +}; + +/// Low-level min/max kernel. +pub trait MinMaxKernel { + type Scalar<'a>: MinMax + where + Self: 'a; + + fn min_ignore_nan_kernel(&self) -> Option>; + fn max_ignore_nan_kernel(&self) -> Option>; + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + Some((self.min_ignore_nan_kernel()?, self.max_ignore_nan_kernel()?)) + } + + fn min_propagate_nan_kernel(&self) -> Option>; + fn max_propagate_nan_kernel(&self) -> Option>; + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + Some(( + self.min_propagate_nan_kernel()?, + self.max_propagate_nan_kernel()?, + )) + } +} + +// Trait to enable the scalar blanket implementation. +trait NotSimdPrimitive {} + +#[cfg(not(feature = "simd"))] +impl NotSimdPrimitive for T {} + +#[cfg(feature = "simd")] +impl NotSimdPrimitive for u128 {} +#[cfg(feature = "simd")] +impl NotSimdPrimitive for i128 {} + +mod dyn_array; +mod scalar; + +#[cfg(feature = "simd")] +mod simd; diff --git a/crates/polars-compute/src/min_max/scalar.rs b/crates/polars-compute/src/min_max/scalar.rs new file mode 100644 index 000000000000..8134b53daabd --- /dev/null +++ b/crates/polars-compute/src/min_max/scalar.rs @@ -0,0 +1,263 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8Array, Utf8ViewArray, +}; +use arrow::types::{NativeType, Offset}; +use polars_utils::min_max::MinMax; + +use super::MinMaxKernel; + +fn min_max_ignore_nan((cur_min, cur_max): (T, T), (min, max): (T, T)) -> (T, T) { + ( + MinMax::min_ignore_nan(cur_min, min), + MinMax::max_ignore_nan(cur_max, max), + ) +} + +fn min_max_propagate_nan((cur_min, cur_max): (T, T), (min, max): (T, T)) -> (T, T) { + ( + MinMax::min_propagate_nan(cur_min, min), + MinMax::max_propagate_nan(cur_max, max), + ) +} + +fn reduce_vals(v: &PrimitiveArray, f: F) -> Option +where + T: NativeType, + F: Fn(T, T) -> T, +{ + if v.null_count() == 0 { + v.values_iter().copied().reduce(f) + } else { + v.non_null_values_iter().reduce(f) + } +} + +fn reduce_tuple_vals(v: &PrimitiveArray, f: F) -> Option<(T, T)> +where + T: NativeType, + F: Fn((T, T), (T, T)) -> (T, T), +{ + if v.null_count() == 0 { + v.values_iter().copied().map(|v| (v, v)).reduce(f) + } else { + v.non_null_values_iter().map(|v| (v, v)).reduce(f) + } +} + +impl MinMaxKernel for PrimitiveArray { + type Scalar<'a> = T; + + fn min_ignore_nan_kernel(&self) -> Option> { + reduce_vals(self, MinMax::min_ignore_nan) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + reduce_vals(self, MinMax::max_ignore_nan) + } + + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + reduce_tuple_vals(self, min_max_ignore_nan) + } + + fn min_propagate_nan_kernel(&self) -> Option> { + reduce_vals(self, MinMax::min_propagate_nan) + } + + fn max_propagate_nan_kernel(&self) -> Option> { + reduce_vals(self, MinMax::max_propagate_nan) + } + + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + reduce_tuple_vals(self, min_max_propagate_nan) + } +} + +impl MinMaxKernel for [T] { + type Scalar<'a> = T; + + fn min_ignore_nan_kernel(&self) -> Option> { + self.iter().copied().reduce(MinMax::min_ignore_nan) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + self.iter().copied().reduce(MinMax::max_ignore_nan) + } + + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + self.iter() + .copied() + .map(|v| (v, v)) + .reduce(min_max_ignore_nan) + } + + fn min_propagate_nan_kernel(&self) -> Option> { + self.iter().copied().reduce(MinMax::min_propagate_nan) + } + + fn max_propagate_nan_kernel(&self) -> Option> { + self.iter().copied().reduce(MinMax::max_propagate_nan) + } + + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + self.iter() + .copied() + .map(|v| (v, v)) + .reduce(min_max_propagate_nan) + } +} + +impl MinMaxKernel for BooleanArray { + type Scalar<'a> = bool; + + fn min_ignore_nan_kernel(&self) -> Option> { + if self.len() - self.null_count() == 0 { + return None; + } + + let unset_bits = self.values().unset_bits(); + Some(unset_bits == 0) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + if self.len() - self.null_count() == 0 { + return None; + } + + let set_bits = self.values().set_bits(); + Some(set_bits > 0) + } + + #[inline(always)] + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + #[inline(always)] + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } +} + +impl MinMaxKernel for BinaryViewArray { + type Scalar<'a> = &'a [u8]; + + fn min_ignore_nan_kernel(&self) -> Option> { + if self.null_count() == 0 { + self.values_iter().reduce(MinMax::min_ignore_nan) + } else { + self.non_null_values_iter().reduce(MinMax::min_ignore_nan) + } + } + + fn max_ignore_nan_kernel(&self) -> Option> { + if self.null_count() == 0 { + self.values_iter().reduce(MinMax::max_ignore_nan) + } else { + self.non_null_values_iter().reduce(MinMax::max_ignore_nan) + } + } + + #[inline(always)] + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + #[inline(always)] + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } +} + +impl MinMaxKernel for Utf8ViewArray { + type Scalar<'a> = &'a str; + + #[inline(always)] + fn min_ignore_nan_kernel(&self) -> Option> { + self.to_binview().min_ignore_nan_kernel().map(|s| unsafe { + // SAFETY: the lifetime is the same, and it is valid UTF-8. + #[allow(clippy::transmute_bytes_to_str)] + std::mem::transmute::<&[u8], &str>(s) + }) + } + + #[inline(always)] + fn max_ignore_nan_kernel(&self) -> Option> { + self.to_binview().max_ignore_nan_kernel().map(|s| unsafe { + // SAFETY: the lifetime is the same, and it is valid UTF-8. + #[allow(clippy::transmute_bytes_to_str)] + std::mem::transmute::<&[u8], &str>(s) + }) + } + + #[inline(always)] + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + #[inline(always)] + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } +} + +impl MinMaxKernel for BinaryArray { + type Scalar<'a> = &'a [u8]; + + fn min_ignore_nan_kernel(&self) -> Option> { + if self.null_count() == 0 { + self.values_iter().reduce(MinMax::min_ignore_nan) + } else { + self.non_null_values_iter().reduce(MinMax::min_ignore_nan) + } + } + + fn max_ignore_nan_kernel(&self) -> Option> { + if self.null_count() == 0 { + self.values_iter().reduce(MinMax::max_ignore_nan) + } else { + self.non_null_values_iter().reduce(MinMax::max_ignore_nan) + } + } + + #[inline(always)] + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + #[inline(always)] + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } +} + +impl MinMaxKernel for Utf8Array { + type Scalar<'a> = &'a str; + + #[inline(always)] + fn min_ignore_nan_kernel(&self) -> Option> { + self.to_binary().min_ignore_nan_kernel().map(|s| unsafe { + // SAFETY: the lifetime is the same, and it is valid UTF-8. + #[allow(clippy::transmute_bytes_to_str)] + std::mem::transmute::<&[u8], &str>(s) + }) + } + + #[inline(always)] + fn max_ignore_nan_kernel(&self) -> Option> { + self.to_binary().max_ignore_nan_kernel().map(|s| unsafe { + // SAFETY: the lifetime is the same, and it is valid UTF-8. + #[allow(clippy::transmute_bytes_to_str)] + std::mem::transmute::<&[u8], &str>(s) + }) + } + + #[inline(always)] + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + #[inline(always)] + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } +} diff --git a/crates/polars-compute/src/min_max/simd.rs b/crates/polars-compute/src/min_max/simd.rs new file mode 100644 index 000000000000..518d8e85e4d0 --- /dev/null +++ b/crates/polars-compute/src/min_max/simd.rs @@ -0,0 +1,382 @@ +use std::simd::prelude::*; +use std::simd::{LaneCount, SimdElement, SupportedLaneCount}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +use arrow::types::NativeType; +use polars_utils::min_max::MinMax; + +use super::MinMaxKernel; + +fn scalar_reduce_min_propagate_nan(arr: &[T; N]) -> T { + let it = arr.iter().copied(); + it.reduce(MinMax::min_propagate_nan).unwrap() +} + +fn scalar_reduce_max_propagate_nan(arr: &[T; N]) -> T { + let it = arr.iter().copied(); + it.reduce(MinMax::max_propagate_nan).unwrap() +} + +fn fold_agg_kernel( + arr: &[T], + validity: Option<&Bitmap>, + scalar_identity: T, + mut simd_f: F, +) -> Option> +where + T: SimdElement + NativeType, + F: FnMut(Simd, Simd) -> Simd, + LaneCount: SupportedLaneCount, +{ + if arr.is_empty() { + return None; + } + + let mut arr_chunks = arr.chunks_exact(N); + + let identity = Simd::splat(scalar_identity); + let mut state = identity; + if let Some(valid) = validity { + if valid.unset_bits() == arr.len() { + return None; + } + + let mask = BitMask::from_bitmap(valid); + let mut offset = 0; + for c in arr_chunks.by_ref() { + let m: Mask<_, N> = mask.get_simd(offset); + state = simd_f(state, m.select(Simd::from_slice(c), identity)); + offset += N; + } + if arr.len() % N > 0 { + let mut rest: [T; N] = identity.to_array(); + let arr_rest = arr_chunks.remainder(); + rest[..arr_rest.len()].copy_from_slice(arr_rest); + let m: Mask<_, N> = mask.get_simd(offset); + state = simd_f(state, m.select(Simd::from_array(rest), identity)); + } + } else { + for c in arr_chunks.by_ref() { + state = simd_f(state, Simd::from_slice(c)); + } + if arr.len() % N > 0 { + let mut rest: [T; N] = identity.to_array(); + let arr_rest = arr_chunks.remainder(); + rest[..arr_rest.len()].copy_from_slice(arr_rest); + state = simd_f(state, Simd::from_array(rest)); + } + } + + Some(state) +} + +fn fold_agg_min_max_kernel( + arr: &[T], + validity: Option<&Bitmap>, + min_scalar_identity: T, + max_scalar_identity: T, + mut simd_f: F, +) -> Option<(Simd, Simd)> +where + T: SimdElement + NativeType, + F: FnMut((Simd, Simd), (Simd, Simd)) -> (Simd, Simd), + LaneCount: SupportedLaneCount, +{ + if arr.is_empty() { + return None; + } + + let mut arr_chunks = arr.chunks_exact(N); + + let min_identity = Simd::splat(min_scalar_identity); + let max_identity = Simd::splat(max_scalar_identity); + let mut state = (min_identity, max_identity); + if let Some(valid) = validity { + if valid.unset_bits() == arr.len() { + return None; + } + + let mask = BitMask::from_bitmap(valid); + let mut offset = 0; + for c in arr_chunks.by_ref() { + let m: Mask<_, N> = mask.get_simd(offset); + let slice = Simd::from_slice(c); + state = simd_f( + state, + (m.select(slice, min_identity), m.select(slice, max_identity)), + ); + offset += N; + } + if arr.len() % N > 0 { + let mut min_rest: [T; N] = min_identity.to_array(); + let mut max_rest: [T; N] = max_identity.to_array(); + + let arr_rest = arr_chunks.remainder(); + min_rest[..arr_rest.len()].copy_from_slice(arr_rest); + max_rest[..arr_rest.len()].copy_from_slice(arr_rest); + + let m: Mask<_, N> = mask.get_simd(offset); + + let min_rest = Simd::from_array(min_rest); + let max_rest = Simd::from_array(max_rest); + + state = simd_f( + state, + ( + m.select(min_rest, min_identity), + m.select(max_rest, max_identity), + ), + ); + } + } else { + for c in arr_chunks.by_ref() { + let slice = Simd::from_slice(c); + state = simd_f(state, (slice, slice)); + } + if arr.len() % N > 0 { + let mut min_rest: [T; N] = min_identity.to_array(); + let mut max_rest: [T; N] = max_identity.to_array(); + + let arr_rest = arr_chunks.remainder(); + min_rest[..arr_rest.len()].copy_from_slice(arr_rest); + max_rest[..arr_rest.len()].copy_from_slice(arr_rest); + + let min_rest = Simd::from_array(min_rest); + let max_rest = Simd::from_array(max_rest); + + state = simd_f(state, (min_rest, max_rest)); + } + } + + Some(state) +} + +macro_rules! impl_min_max_kernel_int { + ($T:ty, $N:literal) => { + impl MinMaxKernel for PrimitiveArray<$T> { + type Scalar<'a> = $T; + + fn min_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MAX, |a, b| { + a.simd_min(b) + }) + .map(|s| s.reduce_min()) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MIN, |a, b| { + a.simd_max(b) + }) + .map(|s| s.reduce_max()) + } + + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + fold_agg_min_max_kernel::<$N, $T, _>( + self.values(), + self.validity(), + <$T>::MAX, + <$T>::MIN, + |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)), + ) + .map(|(min, max)| (min.reduce_min(), max.reduce_max())) + } + + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } + + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + self.min_max_ignore_nan_kernel() + } + } + + impl MinMaxKernel for [$T] { + type Scalar<'a> = $T; + + fn min_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MAX, |a, b| a.simd_min(b)) + .map(|s| s.reduce_min()) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MIN, |a, b| a.simd_max(b)) + .map(|s| s.reduce_max()) + } + + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + fold_agg_min_max_kernel::<$N, $T, _>( + self, + None, + <$T>::MAX, + <$T>::MIN, + |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)), + ) + .map(|(min, max)| (min.reduce_min(), max.reduce_max())) + } + + fn min_propagate_nan_kernel(&self) -> Option> { + self.min_ignore_nan_kernel() + } + + fn max_propagate_nan_kernel(&self) -> Option> { + self.max_ignore_nan_kernel() + } + + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + self.min_max_ignore_nan_kernel() + } + } + }; +} + +impl_min_max_kernel_int!(u8, 32); +impl_min_max_kernel_int!(u16, 16); +impl_min_max_kernel_int!(u32, 16); +impl_min_max_kernel_int!(u64, 8); +impl_min_max_kernel_int!(i8, 32); +impl_min_max_kernel_int!(i16, 16); +impl_min_max_kernel_int!(i32, 16); +impl_min_max_kernel_int!(i64, 8); + +macro_rules! impl_min_max_kernel_float { + ($T:ty, $N:literal) => { + impl MinMaxKernel for PrimitiveArray<$T> { + type Scalar<'a> = $T; + + fn min_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| { + a.simd_min(b) + }) + .map(|s| s.reduce_min()) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| { + a.simd_max(b) + }) + .map(|s| s.reduce_max()) + } + + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + fold_agg_min_max_kernel::<$N, $T, _>( + self.values(), + self.validity(), + <$T>::NAN, + <$T>::NAN, + |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)), + ) + .map(|(min, max)| (min.reduce_min(), max.reduce_max())) + } + + fn min_propagate_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>( + self.values(), + self.validity(), + <$T>::INFINITY, + |a, b| (a.simd_lt(b) | a.simd_ne(a)).select(a, b), + ) + .map(|s| scalar_reduce_min_propagate_nan(s.as_array())) + } + + fn max_propagate_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>( + self.values(), + self.validity(), + <$T>::NEG_INFINITY, + |a, b| (a.simd_gt(b) | a.simd_ne(a)).select(a, b), + ) + .map(|s| scalar_reduce_max_propagate_nan(s.as_array())) + } + + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + fold_agg_min_max_kernel::<$N, $T, _>( + self.values(), + self.validity(), + <$T>::INFINITY, + <$T>::NEG_INFINITY, + |(cmin, cmax), (min, max)| { + ( + (cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min), + (cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max), + ) + }, + ) + .map(|(min, max)| { + ( + scalar_reduce_min_propagate_nan(min.as_array()), + scalar_reduce_max_propagate_nan(max.as_array()), + ) + }) + } + } + + impl MinMaxKernel for [$T] { + type Scalar<'a> = $T; + + fn min_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_min(b)) + .map(|s| s.reduce_min()) + } + + fn max_ignore_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_max(b)) + .map(|s| s.reduce_max()) + } + + fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + fold_agg_min_max_kernel::<$N, $T, _>( + self, + None, + <$T>::NAN, + <$T>::NAN, + |(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)), + ) + .map(|(min, max)| (min.reduce_min(), max.reduce_max())) + } + + fn min_propagate_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::INFINITY, |a, b| { + (a.simd_lt(b) | a.simd_ne(a)).select(a, b) + }) + .map(|s| scalar_reduce_min_propagate_nan(s.as_array())) + } + + fn max_propagate_nan_kernel(&self) -> Option> { + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NEG_INFINITY, |a, b| { + (a.simd_gt(b) | a.simd_ne(a)).select(a, b) + }) + .map(|s| scalar_reduce_max_propagate_nan(s.as_array())) + } + + fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> { + fold_agg_min_max_kernel::<$N, $T, _>( + self, + None, + <$T>::INFINITY, + <$T>::NEG_INFINITY, + |(cmin, cmax), (min, max)| { + ( + (cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min), + (cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max), + ) + }, + ) + .map(|(min, max)| { + ( + scalar_reduce_min_propagate_nan(min.as_array()), + scalar_reduce_max_propagate_nan(max.as_array()), + ) + }) + } + } + }; +} + +impl_min_max_kernel_float!(f32, 16); +impl_min_max_kernel_float!(f64, 8); diff --git a/crates/polars-compute/src/moment.rs b/crates/polars-compute/src/moment.rs new file mode 100644 index 000000000000..2cbac2151bab --- /dev/null +++ b/crates/polars-compute/src/moment.rs @@ -0,0 +1,680 @@ +// Some formulae: +// mean_x = sum(weight[i] * x[i]) / sum(weight) +// dp_xy = weighted sum of deviation products of variables x, y, written in +// the paper as simply XY. +// dp_xy = sum(weight[i] * (x[i] - mean_x) * (y[i] - mean_y)) +// +// cov(x, y) = dp_xy / sum(weight) +// var(x) = cov(x, x) +// +// Algorithms from: +// Numerically stable parallel computation of (co-)variance. +// Schubert, E. & Gertz, M. (2018). +// +// Key equations from the paper: +// (17) for mean update, (23) for dp update (and also Table 1). +// +// +// For higher moments we refer to: +// Numerically Stable, Scalable Formulas for Parallel and Online Computation of +// Higher-Order Multivariate Central Moments with Arbitrary Weights. +// Pébay, P. & Terriberry, T. B. & Kolla, H. & Bennett J. (2016) +// +// Key equations from paper: +// (3.26) mean update, (3.27) moment update. +// +// Here we use mk to mean the weighted kth central moment: +// mk = sum(weight[i] * (x[i] - mean_x)**k) +// Note that we'll use the terms m2 = dp = dp_xx if unambiguous. + +#![allow(clippy::collapsible_else_if)] + +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType; +use num_traits::AsPrimitive; +use polars_utils::algebraic_ops::*; + +const CHUNK_SIZE: usize = 128; + +#[derive(Default, Clone)] +pub struct VarState { + weight: f64, + mean: f64, + dp: f64, +} + +#[derive(Default, Clone)] +pub struct CovState { + weight: f64, + mean_x: f64, + mean_y: f64, + dp_xy: f64, +} + +#[derive(Default, Clone)] +pub struct PearsonState { + weight: f64, + mean_x: f64, + mean_y: f64, + dp_xx: f64, + dp_xy: f64, + dp_yy: f64, +} + +impl VarState { + fn new(x: &[f64]) -> Self { + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let mean = alg_sum_f64(x.iter().copied()) / weight; + Self { + weight, + mean, + dp: alg_sum_f64(x.iter().map(|&xi| (xi - mean) * (xi - mean))), + } + } + + fn clear_zero_weight_nan(&mut self) { + // Clear NaNs due to division by zero. + if self.weight == 0.0 { + self.mean = 0.0; + self.dp = 0.0; + } + } + + pub fn insert_one(&mut self, x: f64) { + // Just a specialized version of + // self.combine(&Self { weight: 1.0, mean: x, dp: 0.0 }) + let new_weight = self.weight + 1.0; + let delta_mean = x - self.mean; + let new_mean = self.mean + delta_mean / new_weight; + self.dp += (x - new_mean) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + self.clear_zero_weight_nan(); + } + + pub fn remove_one(&mut self, x: f64) { + // Just a specialized version of + // self.combine(&Self { weight: -1.0, mean: x, dp: 0.0 }) + let new_weight = self.weight - 1.0; + let delta_mean = x - self.mean; + let new_mean = self.mean - delta_mean / new_weight; + self.dp -= (x - new_mean) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + self.clear_zero_weight_nan(); + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean = other.mean - self.mean; + let new_mean = self.mean + delta_mean * other_weight_frac; + self.dp += other.dp + other.weight * (other.mean - new_mean) * delta_mean; + self.weight = new_weight; + self.mean = new_mean; + self.clear_zero_weight_nan(); + } + + pub fn finalize(&self, ddof: u8) -> Option { + if self.weight <= ddof as f64 { + None + } else { + let var = self.dp / (self.weight - ddof as f64); + Some(if var < 0.0 { + // Variance can't be negative, except through numerical instability. + // We don't use f64::max here so we propagate nans. + 0.0 + } else { + var + }) + } + } +} + +impl CovState { + fn new(x: &[f64], y: &[f64]) -> Self { + assert!(x.len() == y.len()); + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let inv_weight = 1.0 / weight; + let mean_x = alg_sum_f64(x.iter().copied()) * inv_weight; + let mean_y = alg_sum_f64(y.iter().copied()) * inv_weight; + Self { + weight, + mean_x, + mean_y, + dp_xy: alg_sum_f64( + x.iter() + .zip(y) + .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y)), + ), + } + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean_x = other.mean_x - self.mean_x; + let delta_mean_y = other.mean_y - self.mean_y; + let new_mean_x = self.mean_x + delta_mean_x * other_weight_frac; + let new_mean_y = self.mean_y + delta_mean_y * other_weight_frac; + self.dp_xy += other.dp_xy + other.weight * (other.mean_x - new_mean_x) * delta_mean_y; + self.weight = new_weight; + self.mean_x = new_mean_x; + self.mean_y = new_mean_y; + } + + pub fn finalize(&self, ddof: u8) -> Option { + if self.weight <= ddof as f64 { + None + } else { + Some(self.dp_xy / (self.weight - ddof as f64)) + } + } +} + +impl PearsonState { + fn new(x: &[f64], y: &[f64]) -> Self { + assert!(x.len() == y.len()); + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let inv_weight = 1.0 / weight; + let mean_x = alg_sum_f64(x.iter().copied()) * inv_weight; + let mean_y = alg_sum_f64(y.iter().copied()) * inv_weight; + let mut dp_xx = 0.0; + let mut dp_xy = 0.0; + let mut dp_yy = 0.0; + for (xi, yi) in x.iter().zip(y.iter()) { + dp_xx = alg_add_f64(dp_xx, (xi - mean_x) * (xi - mean_x)); + dp_xy = alg_add_f64(dp_xy, (xi - mean_x) * (yi - mean_y)); + dp_yy = alg_add_f64(dp_yy, (yi - mean_y) * (yi - mean_y)); + } + Self { + weight, + mean_x, + mean_y, + dp_xx, + dp_xy, + dp_yy, + } + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let other_weight_frac = other.weight / new_weight; + let delta_mean_x = other.mean_x - self.mean_x; + let delta_mean_y = other.mean_y - self.mean_y; + let new_mean_x = self.mean_x + delta_mean_x * other_weight_frac; + let new_mean_y = self.mean_y + delta_mean_y * other_weight_frac; + self.dp_xx += other.dp_xx + other.weight * (other.mean_x - new_mean_x) * delta_mean_x; + self.dp_xy += other.dp_xy + other.weight * (other.mean_x - new_mean_x) * delta_mean_y; + self.dp_yy += other.dp_yy + other.weight * (other.mean_y - new_mean_y) * delta_mean_y; + self.weight = new_weight; + self.mean_x = new_mean_x; + self.mean_y = new_mean_y; + } + + pub fn finalize(&self) -> f64 { + let denom_sq = self.dp_xx * self.dp_yy; + if denom_sq > 0.0 { + self.dp_xy / denom_sq.sqrt() + } else { + f64::NAN + } + } +} + +#[derive(Default, Clone)] +pub struct SkewState { + weight: f64, + mean: f64, + m2: f64, + m3: f64, +} + +impl SkewState { + fn new(x: &[f64]) -> Self { + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let mean = alg_sum_f64(x.iter().copied()) / weight; + let mut m2 = 0.0; + let mut m3 = 0.0; + for xi in x.iter() { + let d = xi - mean; + let d2 = d * d; + let d3 = d * d2; + m2 = alg_add_f64(m2, d2); + m3 = alg_add_f64(m3, d3); + } + Self { + weight, + mean, + m2, + m3, + } + } + + fn clear_zero_weight_nan(&mut self) { + // Clear NaNs due to division by zero. + if self.weight == 0.0 { + self.mean = 0.0; + self.m2 = 0.0; + self.m3 = 0.0; + } + } + + pub fn insert_one(&mut self, x: f64) { + // Specialization of self.combine(&SkewState { weight: 1.0, mean: x, m2: 0.0, m3: 0.0 }); + let new_weight = self.weight + 1.0; + let delta_mean = x - self.mean; + let delta_mean_weight = delta_mean / new_weight; + let new_mean = self.mean + delta_mean_weight; + + let weight_diff = self.weight - 1.0; + let m2_update = (x - new_mean) * delta_mean; + let new_m2 = self.m2 + m2_update; + let new_m3 = self.m3 + delta_mean_weight * (m2_update * weight_diff - 3.0 * self.m2); + + self.weight = new_weight; + self.mean = new_mean; + self.m2 = new_m2; + self.m3 = new_m3; + self.clear_zero_weight_nan(); + } + + pub fn remove_one(&mut self, x: f64) { + // Specialization of self.combine(&SkewState { weight: -1.0, mean: x, m2: 0.0, m3: 0.0 }); + let new_weight = self.weight - 1.0; + let delta_mean = x - self.mean; + let delta_mean_weight = delta_mean / new_weight; + let new_mean = self.mean - delta_mean_weight; + + let weight_diff = self.weight + 1.0; + let m2_update = (new_mean - x) * delta_mean; + let new_m2 = self.m2 + m2_update; + let new_m3 = self.m3 + delta_mean_weight * (m2_update * weight_diff + 3.0 * self.m2); + + self.weight = new_weight; + self.mean = new_mean; + self.m2 = new_m2; + self.m3 = new_m3; + self.clear_zero_weight_nan(); + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let delta_mean = other.mean - self.mean; + let delta_mean_weight = delta_mean / new_weight; + let new_mean = self.mean + other.weight * delta_mean_weight; + + let weight_diff = self.weight - other.weight; + let self_weight_other_m2 = self.weight * other.m2; + let other_weight_self_m2 = other.weight * self.m2; + let m2_update = other.weight * (other.mean - new_mean) * delta_mean; + let new_m2 = self.m2 + other.m2 + m2_update; + let new_m3 = self.m3 + + other.m3 + + delta_mean_weight + * (m2_update * weight_diff + 3.0 * (self_weight_other_m2 - other_weight_self_m2)); + + self.weight = new_weight; + self.mean = new_mean; + self.m2 = new_m2; + self.m3 = new_m3; + self.clear_zero_weight_nan(); + } + + pub fn finalize(&self, bias: bool) -> Option { + let m2 = self.m2 / self.weight; + let m3 = self.m3 / self.weight; + let is_zero = m2 <= (f64::EPSILON * self.mean).powi(2); + let biased_est = if is_zero { f64::NAN } else { m3 / m2.powf(1.5) }; + if bias { + if self.weight == 0.0 { + None + } else { + Some(biased_est) + } + } else { + if self.weight <= 2.0 { + None + } else { + let correction = (self.weight * (self.weight - 1.0)).sqrt() / (self.weight - 2.0); + Some(correction * biased_est) + } + } + } +} + +#[derive(Default, Clone)] +pub struct KurtosisState { + weight: f64, + mean: f64, + m2: f64, + m3: f64, + m4: f64, +} + +impl KurtosisState { + fn new(x: &[f64]) -> Self { + if x.is_empty() { + return Self::default(); + } + + let weight = x.len() as f64; + let mean = alg_sum_f64(x.iter().copied()) / weight; + let mut m2 = 0.0; + let mut m3 = 0.0; + let mut m4 = 0.0; + for xi in x.iter() { + let d = xi - mean; + let d2 = d * d; + let d3 = d * d2; + let d4 = d2 * d2; + m2 = alg_add_f64(m2, d2); + m3 = alg_add_f64(m3, d3); + m4 = alg_add_f64(m4, d4); + } + Self { + weight, + mean, + m2, + m3, + m4, + } + } + + fn clear_zero_weight_nan(&mut self) { + // Clear NaNs due to division by zero. + if self.weight == 0.0 { + self.mean = 0.0; + self.m2 = 0.0; + self.m3 = 0.0; + self.m4 = 0.0; + } + } + + pub fn insert_one(&mut self, x: f64) { + // Specialization of self.combine(&KurtosisState { weight: 1.0, mean: x, m2: 0.0, m3: 0.0, m4: 0.0 }); + let new_weight = self.weight + 1.0; + let delta_mean = x - self.mean; + let delta_mean_weight = delta_mean / new_weight; + let new_mean = self.mean + delta_mean_weight; + + let weight_diff = self.weight - 1.0; + let m2_update = (x - new_mean) * delta_mean; + let new_m2 = self.m2 + m2_update; + let new_m3 = self.m3 + delta_mean_weight * (m2_update * weight_diff - 3.0 * self.m2); + let new_m4 = self.m4 + + delta_mean_weight + * (delta_mean_weight + * (m2_update * (self.weight * weight_diff + 1.0) + 6.0 * self.m2) + - 4.0 * self.m3); + + self.weight = new_weight; + self.mean = new_mean; + self.m2 = new_m2; + self.m3 = new_m3; + self.m4 = new_m4; + self.clear_zero_weight_nan(); + } + + pub fn remove_one(&mut self, x: f64) { + // Specialization of self.combine(&KurtosisState { weight: -1.0, mean: x, m2: 0.0, m3: 0.0, m4: 0.0 }); + let new_weight = self.weight - 1.0; + let delta_mean = x - self.mean; + let delta_mean_weight = delta_mean / new_weight; + let new_mean = self.mean - delta_mean_weight; + + let weight_diff = self.weight + 1.0; + let m2_update = (new_mean - x) * delta_mean; + let new_m2 = self.m2 + m2_update; + let new_m3 = self.m3 + delta_mean_weight * (m2_update * weight_diff + 3.0 * self.m2); + let new_m4 = self.m4 + + delta_mean_weight + * (delta_mean_weight + * (m2_update * (self.weight * weight_diff + 1.0) + 6.0 * self.m2) + + 4.0 * self.m3); + + self.weight = new_weight; + self.mean = new_mean; + self.m2 = new_m2; + self.m3 = new_m3; + self.m4 = new_m4; + self.clear_zero_weight_nan(); + } + + pub fn combine(&mut self, other: &Self) { + if other.weight == 0.0 { + return; + } + + let new_weight = self.weight + other.weight; + let delta_mean = other.mean - self.mean; + let delta_mean_weight = delta_mean / new_weight; + let new_mean = self.mean + other.weight * delta_mean_weight; + + let weight_diff = self.weight - other.weight; + let self_weight_other_m2 = self.weight * other.m2; + let other_weight_self_m2 = other.weight * self.m2; + let m2_update = other.weight * (other.mean - new_mean) * delta_mean; + let new_m2 = self.m2 + other.m2 + m2_update; + let new_m3 = self.m3 + + other.m3 + + delta_mean_weight + * (m2_update * weight_diff + 3.0 * (self_weight_other_m2 - other_weight_self_m2)); + let new_m4 = self.m4 + + other.m4 + + delta_mean_weight + * (delta_mean_weight + * (m2_update * (self.weight * weight_diff + other.weight * other.weight) + + 6.0 + * (self.weight * self_weight_other_m2 + + other.weight * other_weight_self_m2)) + + 4.0 * (self.weight * other.m3 - other.weight * self.m3)); + + self.weight = new_weight; + self.mean = new_mean; + self.m2 = new_m2; + self.m3 = new_m3; + self.m4 = new_m4; + self.clear_zero_weight_nan(); + } + + pub fn finalize(&self, fisher: bool, bias: bool) -> Option { + let m4 = self.m4 / self.weight; + let m2 = self.m2 / self.weight; + let is_zero = m2 <= (f64::EPSILON * self.mean).powi(2); + let biased_est = if is_zero { f64::NAN } else { m4 / (m2 * m2) }; + let out = if bias { + if self.weight == 0.0 { + return None; + } + + biased_est + } else { + if self.weight <= 3.0 { + return None; + } + + let n = self.weight; + let nm1_nm2 = (n - 1.0) / (n - 2.0); + let np1_nm3 = (n + 1.0) / (n - 3.0); + let nm1_nm3 = (n - 1.0) / (n - 3.0); + nm1_nm2 * (np1_nm3 * biased_est - 3.0 * nm1_nm3) + 3.0 + }; + + if fisher { Some(out - 3.0) } else { Some(out) } + } +} + +fn chunk_as_float(it: I, mut f: F) +where + T: NativeType + AsPrimitive, + I: IntoIterator, + F: FnMut(&[f64]), +{ + let mut chunk = [0.0; CHUNK_SIZE]; + let mut i = 0; + for val in it { + if i >= CHUNK_SIZE { + f(&chunk); + i = 0; + } + chunk[i] = val.as_(); + i += 1; + } + if i > 0 { + f(&chunk[..i]); + } +} + +fn chunk_as_float_binary(it: I, mut f: F) +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, + I: IntoIterator, + F: FnMut(&[f64], &[f64]), +{ + let mut left_chunk = [0.0; CHUNK_SIZE]; + let mut right_chunk = [0.0; CHUNK_SIZE]; + let mut i = 0; + for (l, r) in it { + if i >= CHUNK_SIZE { + f(&left_chunk, &right_chunk); + i = 0; + } + left_chunk[i] = l.as_(); + right_chunk[i] = r.as_(); + i += 1; + } + if i > 0 { + f(&left_chunk[..i], &right_chunk[..i]); + } +} + +pub fn var(arr: &PrimitiveArray) -> VarState +where + T: NativeType + AsPrimitive, +{ + let mut out = VarState::default(); + if arr.has_nulls() { + chunk_as_float(arr.non_null_values_iter(), |chunk| { + out.combine(&VarState::new(chunk)) + }); + } else { + chunk_as_float(arr.values().iter().copied(), |chunk| { + out.combine(&VarState::new(chunk)) + }); + } + out +} + +pub fn cov(x: &PrimitiveArray, y: &PrimitiveArray) -> CovState +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, +{ + assert!(x.len() == y.len()); + let mut out = CovState::default(); + if x.has_nulls() || y.has_nulls() { + chunk_as_float_binary( + x.iter() + .zip(y.iter()) + .filter_map(|(l, r)| l.copied().zip(r.copied())), + |l, r| out.combine(&CovState::new(l, r)), + ); + } else { + chunk_as_float_binary( + x.values().iter().copied().zip(y.values().iter().copied()), + |l, r| out.combine(&CovState::new(l, r)), + ); + } + out +} + +pub fn pearson_corr(x: &PrimitiveArray, y: &PrimitiveArray) -> PearsonState +where + T: NativeType + AsPrimitive, + U: NativeType + AsPrimitive, +{ + assert!(x.len() == y.len()); + let mut out = PearsonState::default(); + if x.has_nulls() || y.has_nulls() { + chunk_as_float_binary( + x.iter() + .zip(y.iter()) + .filter_map(|(l, r)| l.copied().zip(r.copied())), + |l, r| out.combine(&PearsonState::new(l, r)), + ); + } else { + chunk_as_float_binary( + x.values().iter().copied().zip(y.values().iter().copied()), + |l, r| out.combine(&PearsonState::new(l, r)), + ); + } + out +} + +pub fn skew(arr: &PrimitiveArray) -> SkewState +where + T: NativeType + AsPrimitive, +{ + let mut out = SkewState::default(); + if arr.has_nulls() { + chunk_as_float(arr.non_null_values_iter(), |chunk| { + out.combine(&SkewState::new(chunk)) + }); + } else { + chunk_as_float(arr.values().iter().copied(), |chunk| { + out.combine(&SkewState::new(chunk)) + }); + } + out +} + +pub fn kurtosis(arr: &PrimitiveArray) -> KurtosisState +where + T: NativeType + AsPrimitive, +{ + let mut out = KurtosisState::default(); + if arr.has_nulls() { + chunk_as_float(arr.non_null_values_iter(), |chunk| { + out.combine(&KurtosisState::new(chunk)) + }); + } else { + chunk_as_float(arr.values().iter().copied(), |chunk| { + out.combine(&KurtosisState::new(chunk)) + }); + } + out +} diff --git a/crates/polars-compute/src/propagate_dictionary.rs b/crates/polars-compute/src/propagate_dictionary.rs new file mode 100644 index 000000000000..e005c98bfe16 --- /dev/null +++ b/crates/polars-compute/src/propagate_dictionary.rs @@ -0,0 +1,83 @@ +use arrow::array::{Array, BinaryViewArray, PrimitiveArray, Utf8ViewArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType::UInt32; + +/// Propagate the nulls from the dictionary values into the keys and remove those nulls from the +/// values. +pub fn propagate_dictionary_value_nulls( + keys: &PrimitiveArray, + values: &Utf8ViewArray, +) -> (PrimitiveArray, Utf8ViewArray) { + let Some(values_validity) = values.validity() else { + return (keys.clone(), values.clone().with_validity(None)); + }; + if values_validity.unset_bits() == 0 { + return (keys.clone(), values.clone().with_validity(None)); + } + + let num_values = values.len(); + + // Create a map from the old indices to indices with nulls filtered out + let mut offset = 0; + let new_idx_map: Vec = (0..num_values) + .map(|i| { + let is_valid = unsafe { values_validity.get_bit_unchecked(i) }; + offset += usize::from(!is_valid); + if is_valid { (i - offset) as u32 } else { 0 } + }) + .collect(); + + let keys = match keys.validity() { + None => { + let values = keys + .values() + .iter() + .map(|&k| unsafe { + // SAFETY: Arrow invariant that all keys are in range of values + *new_idx_map.get_unchecked(k as usize) + }) + .collect(); + let validity = Bitmap::from_iter(keys.values().iter().map(|&k| unsafe { + // SAFETY: Arrow invariant that all keys are in range of values + values_validity.get_bit_unchecked(k as usize) + })); + + PrimitiveArray::new(UInt32, values, Some(validity)) + }, + Some(keys_validity) => { + let values = keys + .values() + .iter() + .map(|&k| { + // deal with nulls in keys + let idx = (k as usize).min(num_values); + // SAFETY: Arrow invariant that all keys are in range of values + *unsafe { new_idx_map.get_unchecked(idx) } + }) + .collect(); + let propagated_validity = Bitmap::from_iter(keys.values().iter().map(|&k| { + // deal with nulls in keys + let idx = (k as usize).min(num_values); + // SAFETY: Arrow invariant that all keys are in range of values + unsafe { values_validity.get_bit_unchecked(idx) } + })); + + let validity = &propagated_validity & keys_validity; + PrimitiveArray::new(UInt32, values, Some(validity)) + }, + }; + + // Filter only handles binary + let values = values.to_binview(); + + // Filter out the null values + let values = crate::filter::filter_with_bitmap(&values, values_validity); + let values = values.as_any().downcast_ref::().unwrap(); + let values = unsafe { values.to_utf8view_unchecked() }.clone(); + + // Explicitly set the values validity to none. + assert_eq!(values.null_count(), 0); + let values = values.with_validity(None); + + (keys, values) +} diff --git a/crates/polars-compute/src/rolling/min_max.rs b/crates/polars-compute/src/rolling/min_max.rs new file mode 100644 index 000000000000..6e306812fad2 --- /dev/null +++ b/crates/polars-compute/src/rolling/min_max.rs @@ -0,0 +1,137 @@ +use std::collections::VecDeque; +use std::marker::PhantomData; + +use arrow::bitmap::Bitmap; +use arrow::types::NativeType; +use polars_utils::min_max::MinMaxPolicy; + +use super::RollingFnParams; +use super::no_nulls::RollingAggWindowNoNulls; +use super::nulls::RollingAggWindowNulls; + +// Algorithm: https://cs.stackexchange.com/questions/120915/interview-question-with-arrays-and-consecutive-subintervals/120936#120936 +pub struct MinMaxWindow<'a, T, P> { + values: &'a [T], + validity: Option<&'a Bitmap>, + // values[monotonic_idxs[i]] is better than values[monotonic_idxs[i+1]] for + // all i, as per the policy. + monotonic_idxs: VecDeque, + nonnulls_in_window: usize, + last_end: usize, + policy: PhantomData

, +} + +impl MinMaxWindow<'_, T, P> { + /// # Safety + /// The index must be in-bounds. + unsafe fn insert_nonnull_value(&mut self, idx: usize) { + unsafe { + let value = self.values.get_unchecked(idx); + + // Remove values which are older and worse. + while let Some(tail_idx) = self.monotonic_idxs.back() { + let tail_value = self.values.get_unchecked(*tail_idx); + if !P::is_better(value, tail_value) { + break; + } + self.monotonic_idxs.pop_back(); + } + + self.monotonic_idxs.push_back(idx); + self.nonnulls_in_window += 1; + } + } + + fn remove_old_values(&mut self, window_start: usize) { + // Remove values which have fallen outside the window start. + while let Some(head_idx) = self.monotonic_idxs.front() { + if *head_idx >= window_start { + break; + } + self.monotonic_idxs.pop_front(); + self.nonnulls_in_window -= 1; + } + } +} + +impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<'a, T> for MinMaxWindow<'a, T, P> { + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + params: Option, + _window_size: Option, + ) -> Self { + assert!(params.is_none()); + let mut slf = Self { + values: slice, + validity: Some(validity), + monotonic_idxs: VecDeque::new(), + nonnulls_in_window: 0, + last_end: 0, + policy: PhantomData, + }; + unsafe { + RollingAggWindowNulls::update(&mut slf, start, end); + } + slf + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + unsafe { + let v = self.validity.unwrap_unchecked(); + self.remove_old_values(start); + for i in start.max(self.last_end)..end { + if v.get_bit_unchecked(i) { + self.insert_nonnull_value(i); + } + } + self.last_end = end; + self.monotonic_idxs + .front() + .map(|idx| *self.values.get_unchecked(*idx)) + } + } + + fn is_valid(&self, min_periods: usize) -> bool { + self.nonnulls_in_window >= min_periods + } +} + +impl<'a, T: NativeType, P: MinMaxPolicy> RollingAggWindowNoNulls<'a, T> for MinMaxWindow<'a, T, P> { + fn new( + slice: &'a [T], + start: usize, + end: usize, + params: Option, + _window_size: Option, + ) -> Self { + assert!(params.is_none()); + let mut slf = Self { + values: slice, + validity: None, + monotonic_idxs: VecDeque::new(), + nonnulls_in_window: 0, + last_end: 0, + policy: PhantomData, + }; + unsafe { + RollingAggWindowNoNulls::update(&mut slf, start, end); + } + slf + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + unsafe { + self.remove_old_values(start); + for i in start.max(self.last_end)..end { + self.insert_nonnull_value(i); + } + self.last_end = end; + self.monotonic_idxs + .front() + .map(|idx| *self.values.get_unchecked(*idx)) + } + } +} diff --git a/crates/polars-compute/src/rolling/mod.rs b/crates/polars-compute/src/rolling/mod.rs new file mode 100644 index 000000000000..e9aae2f2ca64 --- /dev/null +++ b/crates/polars-compute/src/rolling/mod.rs @@ -0,0 +1,122 @@ +mod min_max; +pub mod moment; +pub mod no_nulls; +pub mod nulls; +pub mod quantile_filter; +pub(super) mod window; +use std::hash::Hash; +use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign}; + +use arrow::array::{ArrayRef, PrimitiveArray}; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::types::NativeType; +use num_traits::{Bounded, Float, NumCast, One, Zero}; +use polars_utils::float::IsFloat; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; +use window::*; + +type Start = usize; +type End = usize; +type Idx = usize; +type WindowSize = usize; +type Len = usize; + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum QuantileMethod { + #[default] + Nearest, + Lower, + Higher, + Midpoint, + Linear, + Equiprobable, +} + +#[deprecated(note = "use QuantileMethod instead")] +pub type QuantileInterpolOptions = QuantileMethod; + +#[derive(Clone, Copy, Debug, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum RollingFnParams { + Quantile(RollingQuantileParams), + Var(RollingVarParams), + Skew { bias: bool }, + Kurtosis { fisher: bool, bias: bool }, +} + +fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) { + (i.saturating_sub(window_size - 1), i + 1) +} +fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) { + let right_window = window_size.div_ceil(2); + ( + i.saturating_sub(window_size - right_window), + std::cmp::min(len, i + right_window), + ) +} + +fn create_validity( + min_periods: usize, + len: usize, + window_size: usize, + det_offsets_fn: Fo, +) -> Option +where + Fo: Fn(Idx, WindowSize, Len) -> (Start, End), +{ + if min_periods > 1 { + let mut validity = MutableBitmap::with_capacity(len); + validity.extend_constant(len, true); + + // Set the null values at the boundaries + + // Head. + for i in 0..len { + let (start, end) = det_offsets_fn(i, window_size, len); + if (end - start) < min_periods { + validity.set(i, false) + } else { + break; + } + } + // Tail. + for i in (0..len).rev() { + let (start, end) = det_offsets_fn(i, window_size, len); + if (end - start) < min_periods { + validity.set(i, false) + } else { + break; + } + } + + Some(validity) + } else { + None + } +} + +// Parameters allowed for rolling operations. +#[derive(Clone, Copy, Debug, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct RollingVarParams { + pub ddof: u8, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct RollingQuantileParams { + pub prob: f64, + pub method: QuantileMethod, +} + +impl Hash for RollingQuantileParams { + fn hash(&self, state: &mut H) { + // Will not be NaN, so hash + eq symmetry will hold. + self.prob.to_bits().hash(state); + self.method.hash(state); + } +} diff --git a/crates/polars-compute/src/rolling/moment.rs b/crates/polars-compute/src/rolling/moment.rs new file mode 100644 index 000000000000..2ca1ed5e2bb6 --- /dev/null +++ b/crates/polars-compute/src/rolling/moment.rs @@ -0,0 +1,106 @@ +use super::RollingFnParams; +use crate::moment::{KurtosisState, SkewState, VarState}; + +pub trait StateUpdate { + fn new(params: Option) -> Self; + fn insert_one(&mut self, x: f64); + + fn remove_one(&mut self, x: f64); + + fn finalize(&self) -> Option; +} + +pub struct VarianceMoment { + state: VarState, + ddof: u8, +} + +impl StateUpdate for VarianceMoment { + fn new(params: Option) -> Self { + let ddof = if let Some(RollingFnParams::Var(params)) = params { + params.ddof + } else { + 1 + }; + + Self { + state: VarState::default(), + ddof, + } + } + + fn insert_one(&mut self, x: f64) { + self.state.insert_one(x); + } + + fn remove_one(&mut self, x: f64) { + self.state.remove_one(x); + } + fn finalize(&self) -> Option { + self.state.finalize(self.ddof) + } +} + +pub struct KurtosisMoment { + state: KurtosisState, + fisher: bool, + bias: bool, +} + +impl StateUpdate for KurtosisMoment { + fn new(params: Option) -> Self { + let (fisher, bias) = if let Some(RollingFnParams::Kurtosis { fisher, bias }) = params { + (fisher, bias) + } else { + (false, false) + }; + + Self { + state: KurtosisState::default(), + fisher, + bias, + } + } + + fn insert_one(&mut self, x: f64) { + self.state.insert_one(x); + } + + fn remove_one(&mut self, x: f64) { + self.state.remove_one(x); + } + fn finalize(&self) -> Option { + self.state.finalize(self.fisher, self.bias) + } +} + +pub struct SkewMoment { + state: SkewState, + bias: bool, +} + +impl StateUpdate for SkewMoment { + fn new(params: Option) -> Self { + let bias = if let Some(RollingFnParams::Skew { bias }) = params { + bias + } else { + false + }; + + Self { + state: SkewState::default(), + bias, + } + } + + fn insert_one(&mut self, x: f64) { + self.state.insert_one(x); + } + + fn remove_one(&mut self, x: f64) { + self.state.remove_one(x); + } + fn finalize(&self) -> Option { + self.state.finalize(self.bias) + } +} diff --git a/crates/polars-compute/src/rolling/no_nulls/mean.rs b/crates/polars-compute/src/rolling/no_nulls/mean.rs new file mode 100644 index 000000000000..ebb22693f899 --- /dev/null +++ b/crates/polars-compute/src/rolling/no_nulls/mean.rs @@ -0,0 +1,83 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use polars_error::polars_ensure; + +use super::*; + +pub struct MeanWindow<'a, T> { + sum: SumWindow<'a, T>, +} + +impl< + 'a, + T: NativeType + + IsFloat + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + Add + + Sub, +> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T> +{ + fn new( + slice: &'a [T], + start: usize, + end: usize, + params: Option, + window_size: Option, + ) -> Self { + Self { + sum: SumWindow::new(slice, start, end, params, window_size), + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let sum = self.sum.update(start, end).unwrap_unchecked(); + Some(sum / NumCast::from(end - start).unwrap()) + } +} + +pub fn rolling_mean( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, +) -> PolarsResult +where + T: NativeType + Float + std::iter::Sum + SubAssign + AddAssign + IsFloat, +{ + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + match weights { + None => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + offset_fn, + None, + ), + Some(weights) => { + // A weighted mean is a weighted sum with normalized weights + let mut wts = no_nulls::coerce_weights(weights); + let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x); + polars_ensure!( + wsum != T::zero(), + ComputeError: "Weighted mean is undefined if weights sum to 0" + ); + wts.iter_mut().for_each(|w| *w = *w / wsum); + no_nulls::rolling_apply_weights( + values, + window_size, + min_periods, + offset_fn, + no_nulls::compute_sum_weights, + &wts, + ) + }, + } +} diff --git a/crates/polars-compute/src/rolling/no_nulls/min_max.rs b/crates/polars-compute/src/rolling/no_nulls/min_max.rs new file mode 100644 index 000000000000..540070c60f80 --- /dev/null +++ b/crates/polars-compute/src/rolling/no_nulls/min_max.rs @@ -0,0 +1,145 @@ +use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan}; + +use super::super::min_max::MinMaxWindow; +use super::*; + +pub type MinWindow<'a, T> = MinMaxWindow<'a, T, MinPropagateNan>; +pub type MaxWindow<'a, T> = MinMaxWindow<'a, T, MaxPropagateNan>; + +fn weighted_min_max(values: &[T], weights: &[T]) -> T +where + T: NativeType + std::ops::Mul, + P: MinMaxPolicy, +{ + values + .iter() + .zip(weights) + .map(|(v, w)| *v * *w) + .reduce(P::best) + .unwrap() +} + +macro_rules! rolling_minmax_func { + ($rolling_m:ident, $policy:ident) => { + pub fn $rolling_m( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, + ) -> PolarsResult + where + T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul + Num, + { + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + match weights { + None => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + offset_fn, + None, + ), + Some(weights) => { + assert!( + T::is_float(), + "implementation error, should only be reachable by float types" + ); + let weights = weights + .iter() + .map(|v| NumCast::from(*v).unwrap()) + .collect::>(); + no_nulls::rolling_apply_weights( + values, + window_size, + min_periods, + offset_fn, + weighted_min_max::, + &weights, + ) + }, + } + } + }; +} + +rolling_minmax_func!(rolling_min, MinPropagateNan); +rolling_minmax_func!(rolling_max, MaxPropagateNan); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_rolling_min_max() { + let values = &[1.0f64, 5.0, 3.0, 4.0]; + + let out = rolling_min(values, 2, 2, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(1.0), Some(3.0), Some(3.0)]); + let out = rolling_max(values, 2, 2, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(5.0), Some(5.0), Some(4.0)]); + + let out = rolling_min(values, 2, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]); + let out = rolling_max(values, 2, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]); + + let out = rolling_max(values, 3, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]); + + // test nan handling. + let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0]; + let out = rolling_min(values, 3, 3, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + // we cannot compare nans, so we compare the string values + assert_eq!( + format!("{:?}", out.as_slice()), + format!( + "{:?}", + &[ + None, + None, + Some(1.0), + Some(f64::nan()), + Some(f64::nan()), + Some(f64::nan()), + Some(5.0) + ] + ) + ); + + let out = rolling_max(values, 3, 3, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!( + format!("{:?}", out.as_slice()), + format!( + "{:?}", + &[ + None, + None, + Some(3.0), + Some(f64::nan()), + Some(f64::nan()), + Some(f64::nan()), + Some(7.0) + ] + ) + ); + } +} diff --git a/crates/polars-compute/src/rolling/no_nulls/mod.rs b/crates/polars-compute/src/rolling/no_nulls/mod.rs new file mode 100644 index 000000000000..8d9840e5768c --- /dev/null +++ b/crates/polars-compute/src/rolling/no_nulls/mod.rs @@ -0,0 +1,142 @@ +mod mean; +mod min_max; +mod moment; +mod quantile; +mod sum; +use std::fmt::Debug; + +use arrow::array::PrimitiveArray; +use arrow::datatypes::ArrowDataType; +use arrow::legacy::error::PolarsResult; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::NativeType; +pub use mean::*; +pub use min_max::*; +pub use moment::*; +use num_traits::{Float, Num, NumCast}; +pub use quantile::*; +pub use sum::*; + +use super::*; + +pub trait RollingAggWindowNoNulls<'a, T: NativeType> { + fn new( + slice: &'a [T], + start: usize, + end: usize, + params: Option, + window_size: Option, + ) -> Self; + + /// Update and recompute the window + /// + /// # Safety + /// `start` and `end` must be within the windows bounds + unsafe fn update(&mut self, start: usize, end: usize) -> Option; +} + +// Use an aggregation window that maintains the state +pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>( + values: &'a [T], + window_size: usize, + min_periods: usize, + det_offsets_fn: Fo, + params: Option, +) -> PolarsResult +where + Fo: Fn(Idx, WindowSize, Len) -> (Start, End), + Agg: RollingAggWindowNoNulls<'a, T>, + T: Debug + NativeType + Num, +{ + let len = values.len(); + let (start, end) = det_offsets_fn(0, window_size, len); + let mut agg_window = Agg::new(values, start, end, params, Some(window_size)); + if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) { + if validity.iter().all(|x| !x) { + return Ok(Box::new(PrimitiveArray::::new_null( + T::PRIMITIVE.into(), + len, + ))); + } + } + + let out = (0..len).map(|idx| { + let (start, end) = det_offsets_fn(idx, window_size, len); + if end - start < min_periods { + None + } else { + // SAFETY: + // we are in bounds + unsafe { agg_window.update(start, end) } + } + }); + let arr = PrimitiveArray::from_trusted_len_iter(out); + Ok(Box::new(arr)) +} + +pub(super) fn rolling_apply_weights( + values: &[T], + window_size: usize, + min_periods: usize, + det_offsets_fn: Fo, + aggregator: Fa, + weights: &[T], +) -> PolarsResult +where + T: NativeType, + Fo: Fn(Idx, WindowSize, Len) -> (Start, End), + Fa: Fn(&[T], &[T]) -> T, +{ + assert_eq!(weights.len(), window_size); + let len = values.len(); + let out = (0..len) + .map(|idx| { + let (start, end) = det_offsets_fn(idx, window_size, len); + let vals = unsafe { values.get_unchecked(start..end) }; + + aggregator(vals, weights) + }) + .collect_trusted::>(); + + let validity = create_validity(min_periods, len, window_size, det_offsets_fn); + Ok(Box::new(PrimitiveArray::new( + ArrowDataType::from(T::PRIMITIVE), + out.into(), + validity.map(|b| b.into()), + ))) +} + +fn compute_var_weights(vals: &[T], weights: &[T]) -> T +where + T: Float + std::ops::AddAssign, +{ + // Assumes the weights have already been standardized to 1 + debug_assert!( + weights.iter().fold(T::zero(), |acc, x| acc + *x) == T::one(), + "Rolling weighted variance Weights don't sum to 1" + ); + let (wssq, wmean) = vals + .iter() + .zip(weights) + .fold((T::zero(), T::zero()), |(wssq, wsum), (&v, &w)| { + (wssq + v * v * w, wsum + v * w) + }); + + wssq - wmean * wmean +} + +pub(crate) fn compute_sum_weights(values: &[T], weights: &[T]) -> T +where + T: std::iter::Sum + Copy + std::ops::Mul, +{ + values.iter().zip(weights).map(|(v, w)| *v * *w).sum() +} + +pub(super) fn coerce_weights(weights: &[f64]) -> Vec +where +{ + weights + .iter() + .map(|v| NumCast::from(*v).unwrap()) + .collect::>() +} diff --git a/crates/polars-compute/src/rolling/no_nulls/moment.rs b/crates/polars-compute/src/rolling/no_nulls/moment.rs new file mode 100644 index 000000000000..56233dd9f0a6 --- /dev/null +++ b/crates/polars-compute/src/rolling/no_nulls/moment.rs @@ -0,0 +1,226 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use num_traits::{FromPrimitive, ToPrimitive}; +use polars_error::polars_ensure; + +pub use super::super::moment::*; +use super::*; + +pub struct MomentWindow<'a, T, M: StateUpdate> { + slice: &'a [T], + moment: M, + last_start: usize, + last_end: usize, + params: Option, +} + +impl MomentWindow<'_, T, M> { + fn compute_var(&mut self, start: usize, end: usize) { + self.moment = M::new(self.params); + for value in &self.slice[start..end] { + let value: f64 = NumCast::from(*value).unwrap(); + self.moment.insert_one(value); + } + } +} + +impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive, M: StateUpdate> + RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M> +{ + fn new( + slice: &'a [T], + start: usize, + end: usize, + params: Option, + _window_size: Option, + ) -> Self { + let mut out = Self { + slice, + moment: M::new(params), + last_start: start, + last_end: end, + params, + }; + out.compute_var(start, end); + out + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let recompute_var = if start >= self.last_end { + true + } else { + // remove elements that should leave the window + let mut recompute_var = false; + for idx in self.last_start..start { + // SAFETY: we are in bounds + let leaving_value = *self.slice.get_unchecked(idx); + + // if the leaving value is nan we need to recompute the window + if T::is_float() && !leaving_value.is_finite() { + recompute_var = true; + break; + } + let leaving_value: f64 = NumCast::from(leaving_value).unwrap(); + self.moment.remove_one(leaving_value); + } + recompute_var + }; + + self.last_start = start; + + // we traverse all values and compute + if recompute_var { + self.compute_var(start, end); + } else { + for idx in self.last_end..end { + let entering_value = *self.slice.get_unchecked(idx); + let entering_value: f64 = NumCast::from(entering_value).unwrap(); + + self.moment.insert_one(entering_value); + } + } + self.last_end = end; + self.moment.finalize().map(|v| T::from_f64(v).unwrap()) + } +} + +pub fn rolling_var( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + params: Option, +) -> PolarsResult +where + T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign, +{ + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + match weights { + None => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + offset_fn, + params, + ), + Some(weights) => { + // Validate and standardize the weights like we do for the mean. This definition is fine + // because frequency weights and unbiasing don't make sense for rolling operations. + let mut wts = no_nulls::coerce_weights(weights); + let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x); + polars_ensure!( + wsum != T::zero(), + ComputeError: "Weighted variance is undefined if weights sum to 0" + ); + wts.iter_mut().for_each(|w| *w = *w / wsum); + super::rolling_apply_weights( + values, + window_size, + min_periods, + offset_fn, + compute_var_weights, + &wts, + ) + }, + } +} + +pub fn rolling_skew( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + params: Option, +) -> PolarsResult +where + T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign, +{ + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + offset_fn, + params, + ) +} + +pub fn rolling_kurtosis( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + params: Option, +) -> PolarsResult +where + T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign, +{ + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + offset_fn, + params, + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_rolling_var() { + let values = &[1.0f64, 5.0, 3.0, 4.0]; + + let out = rolling_var(values, 2, 2, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]); + + let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 })); + let out = rolling_var(values, 2, 2, false, None, testpars).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(4.0), Some(1.0), Some(0.25)]); + + let out = rolling_var(values, 2, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + // we cannot compare nans, so we compare the string values + assert_eq!( + format!("{:?}", out.as_slice()), + format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)]) + ); + // test nan handling. + let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0]; + let out = rolling_var(values, 3, 3, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + // we cannot compare nans, so we compare the string values + assert_eq!( + format!("{:?}", out.as_slice()), + format!( + "{:?}", + &[ + None, + None, + Some(52.33333333333333), + Some(f64::nan()), + Some(f64::nan()), + Some(f64::nan()), + Some(1.0) + ] + ) + ); + } +} diff --git a/crates/polars-compute/src/rolling/no_nulls/quantile.rs b/crates/polars-compute/src/rolling/no_nulls/quantile.rs new file mode 100644 index 000000000000..c83c73f2e7f9 --- /dev/null +++ b/crates/polars-compute/src/rolling/no_nulls/quantile.rs @@ -0,0 +1,347 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::legacy::utils::CustomIterTools; +use num_traits::ToPrimitive; +use polars_error::polars_ensure; + +use super::QuantileMethod::*; +use super::*; +use crate::rolling::quantile_filter::SealedRolling; + +pub struct QuantileWindow<'a, T: NativeType> { + sorted: SortedBuf<'a, T>, + prob: f64, + method: QuantileMethod, +} + +impl< + 'a, + T: NativeType + + Float + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Zero + + SealedRolling + + Sub, +> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T> +{ + fn new( + slice: &'a [T], + start: usize, + end: usize, + params: Option, + window_size: Option, + ) -> Self { + let params = params.unwrap(); + let RollingFnParams::Quantile(params) = params else { + unreachable!("expected Quantile params"); + }; + + Self { + sorted: SortedBuf::new(slice, start, end, window_size), + prob: params.prob, + method: params.method, + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + self.sorted.update(start, end); + let length = self.sorted.len(); + + let idx = match self.method { + Linear => { + // Maybe add a fast path for median case? They could branch depending on odd/even. + let length_f = length as f64; + let idx = ((length_f - 1.0) * self.prob).floor() as usize; + + let float_idx_top = (length_f - 1.0) * self.prob; + let top_idx = float_idx_top.ceil() as usize; + return if idx == top_idx { + Some(self.sorted.get(idx)) + } else { + let proportion = T::from(float_idx_top - idx as f64).unwrap(); + let vi = self.sorted.get(idx); + let vj = self.sorted.get(top_idx); + + Some(proportion * (vj - vi) + vi) + }; + }, + Midpoint => { + let length_f = length as f64; + let idx = (length_f * self.prob) as usize; + let idx = std::cmp::min(idx, length - 1); + + let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize; + return if top_idx == idx { + Some(self.sorted.get(idx)) + } else { + let (mid, mid_plus_1) = (self.sorted.get(idx), (self.sorted.get(idx + 1))); + + Some((mid + mid_plus_1) / (T::one() + T::one())) + }; + }, + Nearest => { + let idx = ((length as f64) * self.prob) as usize; + std::cmp::min(idx, length - 1) + }, + Lower => ((length as f64 - 1.0) * self.prob).floor() as usize, + Higher => { + let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; + std::cmp::min(idx, length - 1) + }, + Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize, + }; + + Some(self.sorted.get(idx)) + } +} + +pub fn rolling_quantile( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + params: Option, +) -> PolarsResult +where + T: NativeType + + IsFloat + + Float + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Zero + + SealedRolling + + PartialOrd + + Sub, +{ + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + match weights { + None => { + if !center { + let params = params.as_ref().unwrap(); + let RollingFnParams::Quantile(params) = params else { + unreachable!("expected Quantile params"); + }; + let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>( + params.method, + min_periods, + window_size, + values, + params.prob, + ); + let validity = create_validity(min_periods, values.len(), window_size, offset_fn); + return Ok(Box::new(PrimitiveArray::new( + T::PRIMITIVE.into(), + out.into(), + validity.map(|b| b.into()), + ))); + } + + rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + offset_fn, + params, + ) + }, + Some(weights) => { + let wsum = weights.iter().sum(); + polars_ensure!( + wsum != 0.0, + ComputeError: "Weighted quantile is undefined if weights sum to 0" + ); + let params = params.unwrap(); + let RollingFnParams::Quantile(params) = params else { + unreachable!("expected Quantile params"); + }; + + Ok(rolling_apply_weighted_quantile( + values, + params.prob, + params.method, + window_size, + min_periods, + offset_fn, + weights, + wsum, + )) + }, + } +} + +#[inline] +fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T +where + T: Debug + NativeType + Mul + Sub + NumCast + ToPrimitive + Zero, +{ + // There are a few ways to compute a weighted quantile but no "canonical" way. + // This is mostly taken from the Julia implementation which was readable and reasonable + // https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1 + let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero()); + + // Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look + // odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1. + let h: f64 = p * (wsum - buf[0].1) + buf[0].1; + for &(v, w) in buf.iter() { + if s > h { + break; + } + (s_old, v_old, vk) = (s, vk, v); + s += w; + } + match (h == s_old, method) { + (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter + (_, Lower) => v_old, + (_, Higher) => vk, + (_, Nearest) => { + if s - h > h - s_old { + v_old + } else { + vk + } + }, + (_, Equiprobable) => { + let threshold = (wsum * p).ceil() - 1.0; + if s > threshold { vk } else { v_old } + }, + (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(), + // This is seemingly the canonical way to do it. + (_, Linear) => { + v_old + ::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old) + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn rolling_apply_weighted_quantile( + values: &[T], + p: f64, + method: QuantileMethod, + window_size: usize, + min_periods: usize, + det_offsets_fn: Fo, + weights: &[f64], + wsum: f64, +) -> ArrayRef +where + Fo: Fn(Idx, WindowSize, Len) -> (Start, End), + T: Debug + NativeType + Mul + Sub + NumCast + ToPrimitive + Zero, +{ + assert_eq!(weights.len(), window_size); + // Keep nonzero weights and their indices to know which values we need each iteration. + let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect(); + let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()]; + let len = values.len(); + let out = (0..len) + .map(|idx| { + // Don't need end. Window size is constant and we computed offsets from start above. + let (start, _) = det_offsets_fn(idx, window_size, len); + + // Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster + unsafe { + buf.iter_mut() + .zip(nz_idx_wts.iter()) + .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w)); + } + buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0)); + compute_wq(&buf, p, wsum, method) + }) + .collect_trusted::>(); + + let validity = create_validity(min_periods, len, window_size, det_offsets_fn); + Box::new(PrimitiveArray::new( + T::PRIMITIVE.into(), + out.into(), + validity.map(|b| b.into()), + )) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_rolling_median() { + let values = &[1.0, 2.0, 3.0, 4.0]; + let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: 0.5, + method: Linear, + })); + let out = rolling_quantile(values, 2, 2, false, None, med_pars).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(1.5), Some(2.5), Some(3.5)]); + + let out = rolling_quantile(values, 2, 1, false, None, med_pars).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]); + + let out = rolling_quantile(values, 4, 1, false, None, med_pars).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]); + + let out = rolling_quantile(values, 4, 1, true, None, med_pars).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]); + + let out = rolling_quantile(values, 4, 4, true, None, med_pars).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, Some(2.5), None]); + } + + #[test] + fn test_rolling_quantile_limits() { + let values = &[1.0f64, 2.0, 3.0, 4.0]; + + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, + ]; + + for method in methods { + let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: 0.0, + method, + })); + let out1 = rolling_min(values, 2, 2, false, None, None).unwrap(); + let out1 = out1.as_any().downcast_ref::>().unwrap(); + let out1 = out1.into_iter().map(|v| v.copied()).collect::>(); + let out2 = rolling_quantile(values, 2, 2, false, None, min_pars).unwrap(); + let out2 = out2.as_any().downcast_ref::>().unwrap(); + let out2 = out2.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out1, out2); + + let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: 1.0, + method, + })); + let out1 = rolling_max(values, 2, 2, false, None, None).unwrap(); + let out1 = out1.as_any().downcast_ref::>().unwrap(); + let out1 = out1.into_iter().map(|v| v.copied()).collect::>(); + let out2 = rolling_quantile(values, 2, 2, false, None, max_pars).unwrap(); + let out2 = out2.as_any().downcast_ref::>().unwrap(); + let out2 = out2.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out1, out2); + } + } +} diff --git a/crates/polars-compute/src/rolling/no_nulls/sum.rs b/crates/polars-compute/src/rolling/no_nulls/sum.rs new file mode 100644 index 000000000000..025982d01089 --- /dev/null +++ b/crates/polars-compute/src/rolling/no_nulls/sum.rs @@ -0,0 +1,249 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use super::*; + +fn sum_kahan< + T: NativeType + + IsFloat + + std::iter::Sum + + AddAssign + + SubAssign + + Sub + + Add, +>( + vals: &[T], +) -> (T, T) { + if T::is_float() { + let mut sum = T::zeroed(); + let mut err = T::zeroed(); + + for val in vals.iter().copied() { + if val.is_finite() { + let y = val - err; + let new_sum = sum + y; + err = (new_sum - sum) - y; + sum = new_sum; + } else { + sum += val + } + } + (sum, err) + } else { + (vals.iter().copied().sum::(), T::zeroed()) + } +} + +pub struct SumWindow<'a, T> { + slice: &'a [T], + sum: T, + err: T, + last_start: usize, + last_end: usize, +} + +impl + Add> + SumWindow<'_, T> +{ + // Kahan summation + fn add(&mut self, val: T) { + if T::is_float() && val.is_finite() { + let y = val - self.err; + let new_sum = self.sum + y; + self.err = (new_sum - self.sum) - y; + self.sum = new_sum; + } else { + self.sum += val; + } + } + + fn sub(&mut self, val: T) { + if T::is_float() { + self.add(T::zeroed() - val) + } else { + self.sum -= val; + } + } +} + +impl< + 'a, + T: NativeType + + IsFloat + + std::iter::Sum + + AddAssign + + SubAssign + + Sub + + Add, +> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T> +{ + fn new( + slice: &'a [T], + start: usize, + end: usize, + _params: Option, + _window_size: Option, + ) -> Self { + let (sum, err) = sum_kahan(&slice[start..end]); + Self { + slice, + sum, + err, + last_start: start, + last_end: end, + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + // if we exceed the end, we have a completely new window + // so we recompute + let recompute_sum = if start >= self.last_end { + true + } else { + // remove elements that should leave the window + let mut recompute_sum = false; + for idx in self.last_start..start { + // SAFETY: + // we are in bounds + let leaving_value = self.slice.get_unchecked(idx); + + if T::is_float() && !leaving_value.is_finite() { + recompute_sum = true; + break; + } + + self.sub(*leaving_value); + } + recompute_sum + }; + self.last_start = start; + + // we traverse all values and compute + if recompute_sum { + let vals = self.slice.get_unchecked(start..end); + let (sum, err) = sum_kahan(vals); + self.sum = sum; + self.err = err; + } + // add entering values. + else { + for idx in self.last_end..end { + self.add(*self.slice.get_unchecked(idx)) + } + } + self.last_end = end; + Some(self.sum) + } +} + +pub fn rolling_sum( + values: &[T], + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, +) -> PolarsResult +where + T: NativeType + + std::iter::Sum + + NumCast + + Mul + + AddAssign + + SubAssign + + IsFloat + + Num, +{ + match (center, weights) { + (true, None) => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + det_offsets_center, + None, + ), + (false, None) => rolling_apply_agg_window::, _, _>( + values, + window_size, + min_periods, + det_offsets, + None, + ), + (true, Some(weights)) => { + let weights = no_nulls::coerce_weights(weights); + no_nulls::rolling_apply_weights( + values, + window_size, + min_periods, + det_offsets_center, + no_nulls::compute_sum_weights, + &weights, + ) + }, + (false, Some(weights)) => { + let weights = no_nulls::coerce_weights(weights); + no_nulls::rolling_apply_weights( + values, + window_size, + min_periods, + det_offsets, + no_nulls::compute_sum_weights, + &weights, + ) + }, + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_rolling_sum() { + let values = &[1.0f64, 2.0, 3.0, 4.0]; + + let out = rolling_sum(values, 2, 2, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(3.0), Some(5.0), Some(7.0)]); + + let out = rolling_sum(values, 2, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]); + + let out = rolling_sum(values, 4, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]); + + let out = rolling_sum(values, 4, 1, true, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]); + + let out = rolling_sum(values, 4, 4, true, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, Some(10.0), None]); + + // test nan handling. + let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0]; + let out = rolling_sum(values, 3, 3, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + + assert_eq!( + format!("{:?}", out.as_slice()), + format!( + "{:?}", + &[ + None, + None, + Some(6.0), + Some(f64::nan()), + Some(f64::nan()), + Some(f64::nan()), + Some(18.0) + ] + ) + ); + } +} diff --git a/crates/polars-compute/src/rolling/nulls/mean.rs b/crates/polars-compute/src/rolling/nulls/mean.rs new file mode 100644 index 000000000000..3ffcc3aeb9e5 --- /dev/null +++ b/crates/polars-compute/src/rolling/nulls/mean.rs @@ -0,0 +1,83 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use super::*; + +pub struct MeanWindow<'a, T> { + sum: SumWindow<'a, T>, +} + +impl< + 'a, + T: NativeType + + IsFloat + + Add + + Sub + + NumCast + + Div + + AddAssign + + SubAssign, +> RollingAggWindowNulls<'a, T> for MeanWindow<'a, T> +{ + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + params: Option, + window_size: Option, + ) -> Self { + Self { + sum: SumWindow::new(slice, validity, start, end, params, window_size), + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let sum = self.sum.update(start, end); + sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap()) + } + fn is_valid(&self, min_periods: usize) -> bool { + self.sum.is_valid(min_periods) + } +} + +pub fn rolling_mean( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, +) -> ArrayRef +where + T: NativeType + + IsFloat + + PartialOrd + + Add + + Sub + + NumCast + + AddAssign + + SubAssign + + Div, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + if center { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets_center, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets, + None, + ) + } +} diff --git a/crates/polars-compute/src/rolling/nulls/min_max.rs b/crates/polars-compute/src/rolling/nulls/min_max.rs new file mode 100644 index 000000000000..9db62256b7dc --- /dev/null +++ b/crates/polars-compute/src/rolling/nulls/min_max.rs @@ -0,0 +1,78 @@ +use polars_utils::min_max::{MaxPropagateNan, MinPropagateNan}; + +use super::super::min_max::MinMaxWindow; + +pub type MinWindow<'a, T> = MinMaxWindow<'a, T, MinPropagateNan>; +pub type MaxWindow<'a, T> = MinMaxWindow<'a, T, MaxPropagateNan>; + +use super::*; + +pub fn rolling_min( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, +) -> ArrayRef +where + T: NativeType + IsFloat, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + if center { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets_center, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets, + None, + ) + } +} + +pub fn rolling_max( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, +) -> ArrayRef +where + T: NativeType + std::iter::Sum + Zero + AddAssign + Copy + PartialOrd + Bounded + IsFloat, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + if center { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets_center, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets, + None, + ) + } +} diff --git a/crates/polars-compute/src/rolling/nulls/mod.rs b/crates/polars-compute/src/rolling/nulls/mod.rs new file mode 100644 index 000000000000..9285a0ba5ed6 --- /dev/null +++ b/crates/polars-compute/src/rolling/nulls/mod.rs @@ -0,0 +1,267 @@ +mod mean; +mod min_max; +mod moment; +mod quantile; +mod sum; + +use arrow::legacy::utils::CustomIterTools; +pub use mean::*; +pub use min_max::*; +pub use moment::*; +pub use quantile::*; +pub use sum::*; + +use super::*; + +pub trait RollingAggWindowNulls<'a, T: NativeType> { + /// # Safety + /// `start` and `end` must be in bounds for `slice` and `validity` + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + params: Option, + window_size: Option, + ) -> Self; + + /// # Safety + /// `start` and `end` must be in bounds of `slice` and `bitmap` + unsafe fn update(&mut self, start: usize, end: usize) -> Option; + + fn is_valid(&self, min_periods: usize) -> bool; +} + +// Use an aggregation window that maintains the state +pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>( + values: &'a [T], + validity: &'a Bitmap, + window_size: usize, + min_periods: usize, + det_offsets_fn: Fo, + params: Option, +) -> ArrayRef +where + Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy, + Agg: RollingAggWindowNulls<'a, T>, + T: IsFloat + NativeType, +{ + let len = values.len(); + let (start, end) = det_offsets_fn(0, window_size, len); + // SAFETY; we are in bounds + let mut agg_window = + unsafe { Agg::new(values, validity, start, end, params, Some(window_size)) }; + + let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn) + .unwrap_or_else(|| { + let mut validity = MutableBitmap::with_capacity(len); + validity.extend_constant(len, true); + validity + }); + + let out = (0..len) + .map(|idx| { + let (start, end) = det_offsets_fn(idx, window_size, len); + // SAFETY: + // we are in bounds + let agg = unsafe { agg_window.update(start, end) }; + match agg { + Some(val) => { + if agg_window.is_valid(min_periods) { + val + } else { + // SAFETY: we are in bounds + unsafe { validity.set_unchecked(idx, false) }; + T::default() + } + }, + None => { + // SAFETY: we are in bounds + unsafe { validity.set_unchecked(idx, false) }; + T::default() + }, + } + }) + .collect_trusted::>(); + + Box::new(PrimitiveArray::new( + T::PRIMITIVE.into(), + out.into(), + Some(validity.into()), + )) +} + +#[cfg(test)] +mod test { + use arrow::array::{Array, Int32Array}; + use arrow::buffer::Buffer; + use arrow::datatypes::ArrowDataType; + use polars_utils::min_max::MaxIgnoreNan; + + use super::*; + use crate::rolling::min_max::MinMaxWindow; + + fn get_null_arr() -> PrimitiveArray { + // 1, None, -1, 4 + let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]); + PrimitiveArray::new( + ArrowDataType::Float64, + buf, + Some(Bitmap::from(&[true, false, true, true])), + ) + } + + #[test] + fn test_rolling_sum_nulls() { + let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]); + let arr = &PrimitiveArray::new( + ArrowDataType::Float64, + buf, + Some(Bitmap::from(&[true, false, true, true])), + ); + + let out = rolling_sum(arr, 2, 2, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, None, Some(7.0)]); + + let out = rolling_sum(arr, 2, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]); + + let out = rolling_sum(arr, 4, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]); + + let out = rolling_sum(arr, 4, 1, true, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]); + + let out = rolling_sum(arr, 4, 4, true, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, None, None]); + } + + #[test] + fn test_rolling_mean_nulls() { + let arr = get_null_arr(); + let arr = &arr; + + let out = rolling_mean(arr, 2, 2, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, None, Some(1.5)]); + + let out = rolling_mean(arr, 2, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]); + + let out = rolling_mean(arr, 4, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]); + } + + #[test] + fn test_rolling_var_nulls() { + let arr = get_null_arr(); + let arr = &arr; + + let out = rolling_var(arr, 3, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + + assert_eq!(out, &[None, None, Some(2.0), Some(12.5)]); + + let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 })); + let out = rolling_var(arr, 3, 1, false, None, testpars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + + assert_eq!(out, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]); + + let out = rolling_var(arr, 4, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, Some(2.0), Some(6.333333333333334)]); + + let out = rolling_var(arr, 4, 1, false, None, testpars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!( + out, + &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)] + ); + } + + #[test] + fn test_rolling_max_no_nulls() { + let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]); + let arr = &PrimitiveArray::new( + ArrowDataType::Float64, + buf, + Some(Bitmap::from(&[true, true, true, true])), + ); + let out = rolling_max(arr, 4, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); + + let out = rolling_max(arr, 2, 2, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, Some(2.0), Some(3.0), Some(4.0)]); + + let out = rolling_max(arr, 4, 4, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, None, Some(4.0)]); + + let buf = Buffer::from(vec![4.0, 3.0, 2.0, 1.0]); + let arr = &PrimitiveArray::new( + ArrowDataType::Float64, + buf, + Some(Bitmap::from(&[true, true, true, true])), + ); + let out = rolling_max(arr, 2, 1, false, None, None); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]); + + let out = + super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None).unwrap(); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]); + } + + #[test] + fn test_rolling_extrema_nulls() { + let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]; + let validity = Bitmap::new_with_value(true, vals.len()); + let window_size = 3; + let min_periods = 3; + + let arr = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity)); + + let out = rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets, + None, + ); + let arr = out.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.null_count(), 2); + assert_eq!( + &arr.values().as_slice()[2..], + &[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3] + ); + } +} diff --git a/crates/polars-compute/src/rolling/nulls/moment.rs b/crates/polars-compute/src/rolling/nulls/moment.rs new file mode 100644 index 000000000000..af60df5d48fe --- /dev/null +++ b/crates/polars-compute/src/rolling/nulls/moment.rs @@ -0,0 +1,210 @@ +#![allow(unsafe_op_in_unsafe_fn)] + +use num_traits::{FromPrimitive, ToPrimitive}; + +pub use super::super::moment::*; +use super::*; + +pub struct MomentWindow<'a, T, M: StateUpdate> { + slice: &'a [T], + validity: &'a Bitmap, + moment: Option, + last_start: usize, + last_end: usize, + null_count: usize, + params: Option, +} + +impl MomentWindow<'_, T, M> { + // compute sum from the entire window + unsafe fn compute_moment_and_null_count(&mut self, start: usize, end: usize) { + self.moment = None; + let mut idx = start; + self.null_count = 0; + for value in &self.slice[start..end] { + let valid = self.validity.get_bit_unchecked(idx); + if valid { + let value: f64 = NumCast::from(*value).unwrap(); + self.moment + .get_or_insert_with(|| M::new(self.params)) + .insert_one(value); + } else { + self.null_count += 1; + } + idx += 1; + } + } +} + +impl<'a, T: NativeType + ToPrimitive + IsFloat + FromPrimitive, M: StateUpdate> + RollingAggWindowNulls<'a, T> for MomentWindow<'a, T, M> +{ + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + params: Option, + _window_size: Option, + ) -> Self { + let mut out = Self { + slice, + validity, + moment: None, + last_start: start, + last_end: end, + null_count: 0, + params, + }; + out.compute_moment_and_null_count(start, end); + out + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let recompute_var = if start >= self.last_end { + true + } else { + // remove elements that should leave the window + let mut recompute_var = false; + for idx in self.last_start..start { + // SAFETY: + // we are in bounds + let valid = self.validity.get_bit_unchecked(idx); + if valid { + let leaving_value = *self.slice.get_unchecked(idx); + + // if the leaving value is nan we need to recompute the window + if T::is_float() && !leaving_value.is_finite() { + recompute_var = true; + break; + } + let leaving_value: f64 = NumCast::from(leaving_value).unwrap(); + if let Some(v) = self.moment.as_mut() { + v.remove_one(leaving_value) + } + } else { + // null value leaving the window + self.null_count -= 1; + + // self.sum is None and the leaving value is None + // if the entering value is valid, we might get a new sum. + if self.moment.is_none() { + recompute_var = true; + break; + } + } + } + recompute_var + }; + + self.last_start = start; + + // we traverse all values and compute + if recompute_var { + self.compute_moment_and_null_count(start, end); + } else { + for idx in self.last_end..end { + let valid = self.validity.get_bit_unchecked(idx); + + if valid { + let entering_value = *self.slice.get_unchecked(idx); + let entering_value: f64 = NumCast::from(entering_value).unwrap(); + self.moment + .get_or_insert_with(|| M::new(self.params)) + .insert_one(entering_value); + } else { + // null value entering the window + self.null_count += 1; + } + } + } + self.last_end = end; + self.moment.as_ref().and_then(|v| { + let out = v.finalize(); + out.map(|v| T::from_f64(v).unwrap()) + }) + } + + fn is_valid(&self, min_periods: usize) -> bool { + ((self.last_end - self.last_start) - self.null_count) >= min_periods + } +} + +pub fn rolling_var( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + params: Option, +) -> ArrayRef +where + T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + let offsets_fn = if center { + det_offsets_center + } else { + det_offsets + }; + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + offsets_fn, + params, + ) +} + +pub fn rolling_skew( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + params: Option, +) -> ArrayRef +where + T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float, +{ + let offsets_fn = if center { + det_offsets_center + } else { + det_offsets + }; + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + offsets_fn, + params, + ) +} + +pub fn rolling_kurtosis( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + params: Option, +) -> ArrayRef +where + T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float, +{ + let offsets_fn = if center { + det_offsets_center + } else { + det_offsets + }; + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + offsets_fn, + params, + ) +} diff --git a/crates/polars-compute/src/rolling/nulls/quantile.rs b/crates/polars-compute/src/rolling/nulls/quantile.rs new file mode 100644 index 000000000000..d03fa9f58f8c --- /dev/null +++ b/crates/polars-compute/src/rolling/nulls/quantile.rs @@ -0,0 +1,253 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::MutablePrimitiveArray; + +use super::*; +use crate::rolling::quantile_filter::SealedRolling; + +pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> { + sorted: SortedBufNulls<'a, T>, + prob: f64, + method: QuantileMethod, +} + +impl< + 'a, + T: NativeType + + IsFloat + + Float + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Zero + + SealedRolling + + PartialOrd + + Sub, +> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T> +{ + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + params: Option, + window_size: Option, + ) -> Self { + let params = params.unwrap(); + let RollingFnParams::Quantile(params) = params else { + unreachable!("expected Quantile params"); + }; + Self { + sorted: SortedBufNulls::new(slice, validity, start, end, window_size), + prob: params.prob, + method: params.method, + } + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let null_count = self.sorted.update(start, end); + let mut length = self.sorted.len(); + // The min periods_issue will be taken care of when actually rolling + if null_count == length { + return None; + } + // Nulls are guaranteed to be at the front + length -= null_count; + let mut idx = match self.method { + QuantileMethod::Nearest => ((length as f64) * self.prob) as usize, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { + ((length as f64 - 1.0) * self.prob).floor() as usize + }, + QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Equiprobable => { + ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize + }, + }; + + idx = std::cmp::min(idx, length - 1); + + // we can unwrap because we sliced of the nulls + match self.method { + QuantileMethod::Midpoint => { + let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; + Some( + (self.sorted.get(idx + null_count).unwrap() + + self.sorted.get(top_idx + null_count).unwrap()) + / T::from::(2.0f64).unwrap(), + ) + }, + QuantileMethod::Linear => { + let float_idx = (length as f64 - 1.0) * self.prob; + let top_idx = f64::ceil(float_idx) as usize; + + if top_idx == idx { + Some(self.sorted.get(idx + null_count).unwrap()) + } else { + let proportion = T::from(float_idx - idx as f64).unwrap(); + Some( + proportion + * (self.sorted.get(top_idx + null_count).unwrap() + - self.sorted.get(idx + null_count).unwrap()) + + self.sorted.get(idx + null_count).unwrap(), + ) + } + }, + _ => Some(self.sorted.get(idx).unwrap()), + } + } + + fn is_valid(&self, min_periods: usize) -> bool { + self.sorted.is_valid(min_periods) + } +} + +pub fn rolling_quantile( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + params: Option, +) -> ArrayRef +where + T: NativeType + + IsFloat + + Float + + std::iter::Sum + + AddAssign + + SubAssign + + Div + + NumCast + + One + + Zero + + SealedRolling + + PartialOrd + + Sub, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + let offset_fn = match center { + true => det_offsets_center, + false => det_offsets, + }; + if !center { + let params = params.as_ref().unwrap(); + let RollingFnParams::Quantile(params) = params else { + unreachable!("expected Quantile params"); + }; + + let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>( + params.method, + min_periods, + window_size, + arr.clone(), + params.prob, + ); + let out: PrimitiveArray = out.into(); + return Box::new(out); + } + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + offset_fn, + params, + ) +} + +#[cfg(test)] +mod test { + use arrow::buffer::Buffer; + use arrow::datatypes::ArrowDataType; + + use super::*; + + #[test] + fn test_rolling_median_nulls() { + let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]); + let arr = &PrimitiveArray::new( + ArrowDataType::Float64, + buf, + Some(Bitmap::from(&[true, false, true, true])), + ); + let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: 0.5, + method: QuantileMethod::Linear, + })); + + let out = rolling_quantile(arr, 2, 2, false, None, med_pars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, None, Some(3.5)]); + + let out = rolling_quantile(arr, 2, 1, false, None, med_pars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]); + + let out = rolling_quantile(arr, 4, 1, false, None, med_pars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]); + + let out = rolling_quantile(arr, 4, 1, true, None, med_pars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]); + + let out = rolling_quantile(arr, 4, 4, true, None, med_pars); + let out = out.as_any().downcast_ref::>().unwrap(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out, &[None, None, None, None]); + } + + #[test] + fn test_rolling_quantile_nulls_limits() { + // compare quantiles to corresponding min/max/median values + let buf = Buffer::::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); + let values = &PrimitiveArray::new( + ArrowDataType::Float64, + buf, + Some(Bitmap::from(&[true, false, false, true, true])), + ); + + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, + ]; + + for method in methods { + let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: 0.0, + method, + })); + let out1 = rolling_min(values, 2, 1, false, None, None); + let out1 = out1.as_any().downcast_ref::>().unwrap(); + let out1 = out1.into_iter().map(|v| v.copied()).collect::>(); + let out2 = rolling_quantile(values, 2, 1, false, None, min_pars); + let out2 = out2.as_any().downcast_ref::>().unwrap(); + let out2 = out2.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out1, out2); + + let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: 1.0, + method, + })); + let out1 = rolling_max(values, 2, 1, false, None, None); + let out1 = out1.as_any().downcast_ref::>().unwrap(); + let out1 = out1.into_iter().map(|v| v.copied()).collect::>(); + let out2 = rolling_quantile(values, 2, 1, false, None, max_pars); + let out2 = out2.as_any().downcast_ref::>().unwrap(); + let out2 = out2.into_iter().map(|v| v.copied()).collect::>(); + assert_eq!(out1, out2); + } + } +} diff --git a/crates/polars-compute/src/rolling/nulls/sum.rs b/crates/polars-compute/src/rolling/nulls/sum.rs new file mode 100644 index 000000000000..c16d1be24c0c --- /dev/null +++ b/crates/polars-compute/src/rolling/nulls/sum.rs @@ -0,0 +1,192 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use super::*; + +pub struct SumWindow<'a, T> { + slice: &'a [T], + validity: &'a Bitmap, + sum: Option, + err: T, + last_start: usize, + last_end: usize, + pub(super) null_count: usize, +} + +impl + Add> + SumWindow<'_, T> +{ + // Kahan summation + fn add(&mut self, val: T) { + if T::is_float() && val.is_finite() { + self.sum = self.sum.map(|sum| { + let y = val - self.err; + let new_sum = sum + y; + self.err = (new_sum - sum) - y; + new_sum + }); + } else { + self.sum = self.sum.map(|v| v + val) + } + } + + fn sub(&mut self, val: T) { + if T::is_float() && val.is_finite() { + self.add(T::zeroed() - val) + } else { + self.sum = self.sum.map(|v| v - val) + } + } +} + +impl + Sub> SumWindow<'_, T> { + // compute sum from the entire window + unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option { + let mut sum = None; + let mut idx = start; + self.null_count = 0; + for value in &self.slice[start..end] { + let valid = self.validity.get_bit_unchecked(idx); + if valid { + match sum { + None => sum = Some(*value), + Some(current) => sum = Some(*value + current), + } + } else { + self.null_count += 1; + } + idx += 1; + } + self.sum = sum; + sum + } +} + +impl<'a, T: NativeType + IsFloat + Add + Sub + AddAssign + SubAssign> + RollingAggWindowNulls<'a, T> for SumWindow<'a, T> +{ + unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + _params: Option, + _window_size: Option, + ) -> Self { + let mut out = Self { + slice, + validity, + sum: None, + err: T::zeroed(), + last_start: start, + last_end: end, + null_count: 0, + }; + out.compute_sum_and_null_count(start, end); + out + } + + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + // if we exceed the end, we have a completely new window + // so we recompute + let recompute_sum = if start >= self.last_end { + true + } else { + // remove elements that should leave the window + let mut recompute_sum = false; + for idx in self.last_start..start { + // SAFETY: + // we are in bounds + let valid = self.validity.get_bit_unchecked(idx); + if valid { + let leaving_value = self.slice.get_unchecked(idx); + + // if the leaving value is nan we need to recompute the window + if T::is_float() && !leaving_value.is_finite() { + recompute_sum = true; + break; + } + self.sub(*leaving_value); + } else { + // null value leaving the window + self.null_count -= 1; + + // self.sum is None and the leaving value is None + // if the entering value is valid, we might get a new sum. + if self.sum.is_none() { + recompute_sum = true; + break; + } + } + } + recompute_sum + }; + + self.last_start = start; + + // we traverse all values and compute + if recompute_sum { + self.compute_sum_and_null_count(start, end); + } else { + for idx in self.last_end..end { + let valid = self.validity.get_bit_unchecked(idx); + + if valid { + let value = *self.slice.get_unchecked(idx); + match self.sum { + None => self.sum = Some(value), + _ => self.add(value), + } + } else { + // null value entering the window + self.null_count += 1; + } + } + } + self.last_end = end; + self.sum + } + + fn is_valid(&self, min_periods: usize) -> bool { + ((self.last_end - self.last_start) - self.null_count) >= min_periods + } +} + +pub fn rolling_sum( + arr: &PrimitiveArray, + window_size: usize, + min_periods: usize, + center: bool, + weights: Option<&[f64]>, + _params: Option, +) -> ArrayRef +where + T: NativeType + + IsFloat + + PartialOrd + + Add + + Sub + + SubAssign + + AddAssign, +{ + if weights.is_some() { + panic!("weights not yet supported on array with null values") + } + if center { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets_center, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + arr.values().as_slice(), + arr.validity().as_ref().unwrap(), + window_size, + min_periods, + det_offsets, + None, + ) + } +} diff --git a/crates/polars-compute/src/rolling/quantile_filter.rs b/crates/polars-compute/src/rolling/quantile_filter.rs new file mode 100644 index 000000000000..dc3d51477ecd --- /dev/null +++ b/crates/polars-compute/src/rolling/quantile_filter.rs @@ -0,0 +1,1112 @@ +// Combine dancing links with sort merge. +// https://arxiv.org/abs/1406.1717 +#![allow(unsafe_op_in_unsafe_fn)] +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::ops::{Add, Div, Mul, Sub}; + +use arrow::pushable::Pushable; +use arrow::types::NativeType; +use num_traits::NumCast; +use polars_utils::index::{Bounded, Indexable, NullCount}; +use polars_utils::nulls::IsNull; +use polars_utils::slice::SliceAble; +use polars_utils::sort::arg_sort_ascending; +use polars_utils::total_ord::TotalOrd; + +use super::QuantileMethod; + +struct Block<'a, A> { + k: usize, + tail: usize, + n_element: usize, + // Values buffer + alpha: A, + // Permutation + pi: &'a mut [u32], + prev: &'a mut Vec, + next: &'a mut Vec, + // permutation index in alpha + m: usize, + // index in the list + current_index: usize, + nulls_in_window: usize, +} + +impl Debug for Block<'_, A> +where + A: Indexable, + A::Item: Debug + Copy, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.n_element == 0 { + return writeln!(f, "empty block"); + } + writeln!(f, "elements in list: {}", self.n_element)?; + writeln!(f, "m: {}", self.m)?; + if self.current_index != self.n_element { + writeln!(f, "m_index: {}", self.current_index)?; + writeln!(f, "α[m]: {:?}", self.alpha.get(self.m))?; + } else { + // Index is at tail, so OOB. + writeln!(f, "m_index: tail")?; + writeln!(f, "α[m]: tail")?; + } + + let mut p = self.m as u32; + + // Find start. + loop { + p = self.prev[p as usize]; + if p as usize == self.tail { + p = self.next[p as usize]; + break; + } + } + + // Find all elements from start. + let mut current = Vec::with_capacity(self.n_element); + for _ in 0..self.n_element { + current.push(self.alpha.get(p as usize)); + p = self.next[p as usize]; + } + + write!(f, "current buffer sorted: [")?; + for (i, v) in current.iter().enumerate() { + if i == self.current_index { + write!(f, "[{v:?}], ")?; + } else { + let chars = if i == self.n_element - 1 { "" } else { ", " }; + write!(f, "{v:?}{chars}")?; + } + } + write!(f, "]") + } +} + +impl<'a, A> Block<'a, A> +where + A: Indexable + Bounded + NullCount + Clone, + ::Item: TotalOrd + Copy + IsNull + Debug + 'a, +{ + fn new( + alpha: A, + scratch: &'a mut Vec, + prev: &'a mut Vec, + next: &'a mut Vec, + ) -> Self { + debug_assert!(!alpha.is_empty()); + let k = alpha.len(); + let pi = arg_sort_ascending((0..alpha.len()).map(|i| alpha.get(i)), scratch, alpha.len()); + + let nulls_in_window = alpha.null_count(); + let m_index = k / 2; + let m = pi[m_index] as usize; + + prev.resize(k + 1, 0); + next.resize(k + 1, 0); + let mut b = Self { + k, + pi, + prev, + next, + m, + current_index: m_index, + n_element: k, + tail: k, + alpha, + nulls_in_window, + }; + b.init_links(); + b + } + + fn capacity(&self) -> usize { + self.alpha.len() + } + + fn init_links(&mut self) { + let mut p = self.tail; + + for &q in self.pi.iter() { + // SAFETY: bounded by pi + unsafe { + *self.next.get_unchecked_mut(p) = q; + *self.prev.get_unchecked_mut(q as usize) = p as u32; + } + + p = q as usize; + } + unsafe { + *self.next.get_unchecked_mut(p) = self.tail as u32; + *self.prev.get_unchecked_mut(self.tail) = p as u32; + } + } + + unsafe fn delete_link(&mut self, i: usize) { + if ::Item::HAS_NULLS && self.alpha.get_unchecked(i).is_null() { + self.nulls_in_window -= 1 + } + + *self + .next + .get_unchecked_mut(*self.prev.get_unchecked(i) as usize) = *self.next.get_unchecked(i); + *self + .prev + .get_unchecked_mut(*self.next.get_unchecked(i) as usize) = *self.prev.get_unchecked(i); + } + + unsafe fn undelete_link(&mut self, i: usize) { + if ::Item::HAS_NULLS && self.alpha.get_unchecked(i).is_null() { + self.nulls_in_window += 1 + } + + *self + .next + .get_unchecked_mut(*self.prev.get_unchecked(i) as usize) = i as u32; + *self + .prev + .get_unchecked_mut(*self.next.get_unchecked(i) as usize) = i as u32; + } + + fn unwind(&mut self) { + for i in (0..self.k).rev() { + // SAFETY: k is upper bound + unsafe { self.delete_link(i) } + } + self.m = self.tail; + self.n_element = 0; + } + + #[cfg(test)] + fn set_median(&mut self) { + // median index position + let new_index = self.n_element / 2; + // SAFETY: only used in tests. + unsafe { self.traverse_to_index(new_index) } + } + + unsafe fn traverse_to_index(&mut self, i: usize) { + match i as i64 - self.current_index as i64 { + 0 => { + // pass + }, + -1 => { + self.current_index -= 1; + self.m = *self.prev.get_unchecked(self.m) as usize; + }, + 1 => self.advance(), + i64::MIN..=0 => { + for _ in i..self.current_index { + self.m = *self.prev.get_unchecked(self.m) as usize; + } + self.current_index = i; + }, + _ => { + for _ in self.current_index..i { + self.m = *self.next.get_unchecked(self.m) as usize; + } + self.current_index = i; + }, + } + } + + fn reverse(&mut self) { + if self.current_index > 0 { + self.current_index -= 1; + self.m = unsafe { *self.prev.get_unchecked(self.m) as usize }; + } + } + + fn advance(&mut self) { + if self.current_index < self.n_element { + self.current_index += 1; + self.m = unsafe { *self.next.get_unchecked(self.m) as usize }; + } + } + + #[cfg(test)] + fn reset(&mut self) { + self.current_index = 0; + self.m = self.next[self.tail] as usize; + } + + unsafe fn delete(&mut self, i: usize) { + if self.at_end() { + self.reverse() + } + let delete = self.get_pair(i); + + let current = self.get_pair(self.m); + + // delete from links + self.delete_link(i); + + self.n_element -= 1; + + match delete.tot_cmp(¤t) { + Ordering::Less => { + // 1, 2, [3], 4, 5 + // 2, [3], 4, 5 + // the del changes index + self.current_index -= 1 + }, + Ordering::Greater => { + // 1, 2, [3], 4, 5 + // 1, 2, [3], 4 + // index position remains unaffected + }, + Ordering::Equal => { + // 1, 2, [3], 4, 5 + // 1, 2, [4], 5 + // go to next position because the link was deleted + if self.n_element >= self.current_index { + let next_m = *self.next.get_unchecked(self.m) as usize; + + if next_m == self.tail && self.n_element > 0 { + // The index points to tail, set the index in the array again. + self.current_index -= 1; + self.m = *self.prev.get_unchecked(self.m) as usize + } else { + self.m = *self.next.get_unchecked(self.m) as usize; + } + } else { + // move to previous position because the link was deleted + // 1, [2], + // [1] + self.m = *self.prev.get_unchecked(self.m) as usize + } + }, + }; + } + + unsafe fn undelete(&mut self, i: usize) { + if !self.is_empty() && self.at_end() { + self.reverse() + } + // undelete from links + self.undelete_link(i); + + if self.is_empty() { + self.m = self.prev[self.m] as usize; + self.n_element = 1; + self.current_index = 0; + return; + } + let added = self.get_pair(i); + let current = self.get_pair(self.m); + + self.n_element += 1; + + match added.tot_cmp(¤t) { + Ordering::Less => { + // 2, [3], 4, 5 + // 1, 2, [3], 4, 5 + // the addition changes index + self.current_index += 1 + }, + Ordering::Greater => { + // 1, 2, [3], 4 + // 1, 2, [3], 4, 5 + // index position remains unaffected + }, + Ordering::Equal => { + // 1, 2, 4, 5 + // 1, 2, [3], 4, 5 + // go to prev position because the link was added + // self.m = self.prev[self.m as usize] as usize; + }, + }; + } + + #[cfg(test)] + fn delete_set_median(&mut self, i: usize) { + // SAFETY: only used in testing + unsafe { self.delete(i) }; + self.set_median() + } + + #[cfg(test)] + fn undelete_set_median(&mut self, i: usize) { + // SAFETY: only used in testing + unsafe { self.undelete(i) }; + self.set_median() + } + + fn at_end(&self) -> bool { + self.m == self.tail + } + + fn is_empty(&self) -> bool { + self.n_element == 0 + } + + fn peek(&self) -> Option<::Item> { + if self.at_end() { + None + } else { + Some(self.alpha.get(self.m)) + } + } + + fn peek_previous(&self) -> Option<::Item> { + let m = self.prev[self.m]; + if m == self.tail as u32 { + None + } else { + Some(self.alpha.get(m as usize)) + } + } + + fn get_pair(&self, i: usize) -> (::Item, u32) { + unsafe { (self.alpha.get_unchecked(i), i as u32) } + } +} + +trait LenGet { + type Item; + fn len(&self) -> usize; + + fn get(&mut self, i: usize) -> Self::Item; + + fn null_count(&self) -> usize; +} + +impl<'a, A> LenGet for &mut Block<'a, A> +where + A: Indexable + Bounded + NullCount + Clone, + ::Item: Copy + TotalOrd + Debug + 'a, +{ + type Item = ::Item; + + fn len(&self) -> usize { + self.n_element + } + + fn get(&mut self, i: usize) -> Self::Item { + // ONLY PRIVATE USE + unsafe { self.traverse_to_index(i) }; + self.peek().unwrap() + } + + fn null_count(&self) -> usize { + self.nulls_in_window + } +} + +struct BlockUnion<'a, A: Indexable> +where + A::Item: TotalOrd + Copy, +{ + block_left: &'a mut Block<'a, A>, + block_right: &'a mut Block<'a, A>, +} + +impl<'a, A> BlockUnion<'a, A> +where + A: Indexable + Bounded + NullCount + Clone, + ::Item: TotalOrd + Copy + Debug, +{ + fn new(block_left: &'a mut Block<'a, A>, block_right: &'a mut Block<'a, A>) -> Self { + Self { + block_left, + block_right, + } + } + + unsafe fn set_state(&mut self, i: usize) { + self.block_left.delete(i); + self.block_right.undelete(i); + } + + fn reverse(&mut self) { + let left = self.block_left.peek_previous(); + let right = self.block_right.peek_previous(); + match (left, right) { + (Some(_), None) => { + self.block_left.reverse(); + }, + (None, Some(_)) => { + self.block_right.reverse(); + }, + (Some(left), Some(right)) => match left.tot_cmp(&right) { + Ordering::Equal | Ordering::Less => { + self.block_right.reverse(); + }, + Ordering::Greater => { + self.block_left.reverse(); + }, + }, + (None, None) => {}, + } + } +} + +impl LenGet for BlockUnion<'_, A> +where + A: Indexable + Bounded + NullCount + Clone, + ::Item: TotalOrd + Copy + Debug, +{ + type Item = ::Item; + + fn len(&self) -> usize { + self.block_left.n_element + self.block_right.n_element + } + + fn get(&mut self, i: usize) -> Self::Item { + debug_assert!(i < self.block_left.len() + self.block_right.len()); + // Simple case, all elements are left. + if self.block_right.n_element == 0 { + unsafe { self.block_left.traverse_to_index(i) }; + return self.block_left.peek().unwrap(); + } else if self.block_left.n_element == 0 { + unsafe { self.block_right.traverse_to_index(i) }; + return self.block_right.peek().unwrap(); + } + + // Needed: one of the block can point too far depending on what was (un)deleted in the other + // block. + let mut peek_index = self.block_left.current_index + self.block_right.current_index + 1; + while i <= peek_index { + self.reverse(); + peek_index = self.block_left.current_index + self.block_right.current_index + 1; + if peek_index <= 1 && i <= 1 { + break; + } + } + + loop { + // Current index position of merge sort + let s = self.block_left.current_index + self.block_right.current_index; + + let left = self.block_left.peek(); + let right = self.block_right.peek(); + match (left, right) { + (Some(left), None) => { + if s == i { + return left; + } + // Only advance on next iteration as the state can change when a new + // delete/undelete occurs. So next get call we might hit a different branch. + self.block_left.advance(); + }, + (None, Some(right)) => { + if s == i { + return right; + } + self.block_right.advance(); + }, + (Some(left), Some(right)) => { + match left.tot_cmp(&right) { + // On equality, take the left as that one was first. + Ordering::Equal | Ordering::Less => { + if s == i { + return left; + } + self.block_left.advance(); + }, + Ordering::Greater => { + if s == i { + return right; + } + self.block_right.advance(); + }, + } + }, + (None, None) => {}, + } + } + } + + fn null_count(&self) -> usize { + self.block_left.nulls_in_window + self.block_right.nulls_in_window + } +} + +pub(super) trait FinishLinear { + fn finish(proportion: f64, lower: Self, upper: Self) -> Self; + fn finish_midpoint(lower: Self, upper: Self) -> Self; +} + +pub trait SealedRolling {} + +impl SealedRolling for i8 {} +impl SealedRolling for i16 {} +impl SealedRolling for i32 {} +impl SealedRolling for i64 {} +impl SealedRolling for u8 {} +impl SealedRolling for u16 {} +impl SealedRolling for u32 {} +impl SealedRolling for u64 {} +impl SealedRolling for i128 {} +impl SealedRolling for f32 {} +impl SealedRolling for f64 {} + +impl< + T: NativeType + + NumCast + + Add + + Sub + + Div + + Mul + + SealedRolling + + Debug, +> FinishLinear for T +{ + fn finish(proportion: f64, lower: Self, upper: Self) -> Self { + debug_assert!(proportion >= 0.0); + debug_assert!(proportion <= 1.0); + let proportion: T = NumCast::from(proportion).unwrap(); + proportion * (upper - lower) + lower + } + fn finish_midpoint(lower: Self, upper: Self) -> Self { + (lower + upper) / NumCast::from(2).unwrap() + } +} + +impl FinishLinear for Option { + fn finish(proportion: f64, lower: Self, upper: Self) -> Self { + match (lower, upper) { + (Some(lower), Some(upper)) => Some(T::finish(proportion, lower, upper)), + (Some(lower), _) => Some(lower), + (None, Some(upper)) => Some(upper), + _ => None, + } + } + fn finish_midpoint(lower: Self, upper: Self) -> Self { + match (lower, upper) { + (Some(lower), Some(upper)) => Some(T::finish_midpoint(lower, upper)), + (Some(lower), _) => Some(lower), + (None, Some(upper)) => Some(upper), + _ => None, + } + } +} + +struct QuantileUpdate { + inner: M, + quantile: f64, + min_periods: usize, + method: QuantileMethod, +} + +impl QuantileUpdate +where + M: LenGet, + ::Item: Default + IsNull + Copy + FinishLinear + Debug, +{ + fn new(method: QuantileMethod, min_periods: usize, quantile: f64, inner: M) -> Self { + Self { + min_periods, + quantile, + inner, + method, + } + } + + fn quantile(&mut self) -> M::Item { + // nulls are ignored in median position. + let null_count = self.inner.null_count(); + let valid_length = self.inner.len() - null_count; + + if M::Item::HAS_NULLS && valid_length < self.min_periods { + // Default is None + return M::Item::default(); + } + + let valid_length_f = valid_length as f64; + + use QuantileMethod::*; + match self.method { + Linear => { + let float_idx_top = (valid_length_f - 1.0) * self.quantile; + let idx = float_idx_top.floor() as usize; + let top_idx = float_idx_top.ceil() as usize; + + if idx == top_idx { + self.inner.get(idx + null_count) + } else { + let vi = self.inner.get(idx + null_count); + let vj = self.inner.get(top_idx + null_count); + let proportion = float_idx_top - idx as f64; + <::Item>::finish(proportion, vi, vj) + } + }, + Nearest => { + let idx = (valid_length_f * self.quantile) as usize; + let idx = std::cmp::min(idx, valid_length - 1); + self.inner.get(idx + null_count) + }, + Equiprobable => { + let idx = ((valid_length_f * self.quantile).ceil() - 1.0).max(0.0) as usize; + self.inner.get(idx + null_count) + }, + Midpoint => { + let idx = (valid_length_f * self.quantile) as usize; + let idx = std::cmp::min(idx, valid_length - 1); + + let top_idx = ((valid_length_f - 1.0) * self.quantile).ceil() as usize; + if top_idx == idx { + self.inner.get(idx + null_count) + } else { + let mid = self.inner.get(idx + null_count); + let mid_1 = self.inner.get(top_idx + null_count); + <::Item>::finish_midpoint(mid, mid_1) + } + }, + Lower => { + let idx = ((valid_length_f - 1.0) * self.quantile).floor() as usize; + let idx = std::cmp::min(idx, valid_length - 1); + self.inner.get(idx + null_count) + }, + Higher => { + let idx = ((valid_length_f - 1.0) * self.quantile).ceil() as usize; + let idx = std::cmp::min(idx, valid_length - 1); + self.inner.get(idx + null_count) + }, + } + } +} + +pub(super) fn rolling_quantile::Item>>( + method: QuantileMethod, + min_periods: usize, + k: usize, + values: A, + quantile: f64, +) -> Out +where + A: Indexable + SliceAble + Bounded + NullCount + Clone, + ::Item: Default + TotalOrd + Copy + FinishLinear + Debug, +{ + let mut scratch_left = vec![]; + let mut prev_left = vec![]; + let mut next_left = vec![]; + + let mut scratch_right = vec![]; + let mut prev_right = vec![]; + let mut next_right = vec![]; + + let k = std::cmp::min(k, values.len()); + let alpha = values.slice(0..k); + + let mut out = Out::with_capacity(values.len()); + + let scratch_right_ptr = &mut scratch_right as *mut Vec; + let scratch_left_ptr = &mut scratch_left as *mut Vec; + let prev_right_ptr = &mut prev_right as *mut Vec<_>; + let prev_left_ptr = &mut prev_left as *mut Vec<_>; + let next_right_ptr = &mut next_right as *mut Vec<_>; + let next_left_ptr = &mut next_left as *mut Vec<_>; + + let n_blocks = values.len() / k; + + let mut block_left = unsafe { + Block::new( + alpha, + &mut *scratch_left_ptr, + &mut *prev_left_ptr, + &mut *next_left_ptr, + ) + }; + let mut block_right = unsafe { + Block::new( + values.slice(0..1), + &mut *scratch_right_ptr, + &mut *prev_right_ptr, + &mut *next_right_ptr, + ) + }; + + let ptr_left = &mut block_left as *mut Block<'_, _>; + let ptr_right = &mut block_right as *mut Block<'_, _>; + + block_left.unwind(); + + for i in 0..block_left.capacity() { + // SAFETY: bounded by capacity + unsafe { block_left.undelete(i) }; + + let mut mu = QuantileUpdate::new(method, min_periods, quantile, &mut block_left); + out.push(mu.quantile()); + } + for i in 1..n_blocks + 1 { + // Block left is now completely full as it is completely filled coming from the boundary effects. + debug_assert!(block_left.n_element == k); + + // Windows state at this point. + // + // - BLOCK_LEFT -- BLOCK_RIGHT - + // |-------------||-------------| + // - WINDOW - + // |--------------| + let end = std::cmp::min((i + 1) * k, values.len()); + let alpha = unsafe { values.slice_unchecked(i * k..end) }; + + if alpha.is_empty() { + break; + } + + // Find the scratch that belongs to the left window that has gone out of scope + let (scratch, prev, next) = if i % 2 == 0 { + (scratch_left_ptr, prev_left_ptr, next_left_ptr) + } else { + (scratch_right_ptr, prev_right_ptr, next_right_ptr) + }; + + block_right = unsafe { Block::new(alpha, &mut *scratch, &mut *prev, &mut *next) }; + + // Time reverse the rhs so we can undelete in sorted order. + block_right.unwind(); + + // Here the window will move from BLOCK_LEFT into BLOCK_RIGHT + for j in 0..block_right.capacity() { + unsafe { + let mut union = BlockUnion::new(&mut *ptr_left, &mut *ptr_right); + union.set_state(j); + let q: ::Item = + QuantileUpdate::new(method, min_periods, quantile, union).quantile(); + out.push(q); + } + } + + std::mem::swap(&mut block_left, &mut block_right); + } + out +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_block_1() { + // 0, 1, 2, 3, 4, 5, 6, 7 + let values = [2, 8, 5, 9, 1, 3, 4, 10].as_ref(); + let mut scratch = vec![]; + let mut prev = vec![]; + let mut next = vec![]; + let mut b = Block::new(values, &mut scratch, &mut prev, &mut next); + + // Unwind to get temporal window + b.unwind(); + + // Insert window in the right order + b.undelete_set_median(0); + // [[2]] + assert_eq!(b.peek(), Some(2)); + b.undelete_set_median(1); + // [2, [8]] + assert_eq!(b.peek(), Some(8)); + b.undelete_set_median(2); + // [2, [5], 8] + assert_eq!(b.peek(), Some(5)); + b.undelete_set_median(3); + // [2, 5, [8], 9] + assert_eq!(b.peek(), Some(8)); + b.undelete_set_median(4); + // [1, 2, [5], 8, 9] + assert_eq!(b.peek(), Some(5)); + b.undelete_set_median(5); + // [1, 2, 3, [5], 8, 9] + assert_eq!(b.peek(), Some(5)); + b.undelete_set_median(6); + // [1, 2, 3, [4], 5, 8, 9] + assert_eq!(b.peek(), Some(4)); + b.undelete_set_median(7); + // [1, 2, 3, 4, [5], 8, 9, 10] + assert_eq!(b.peek(), Some(5)); + + // Now we will delete as the block` will leave the window. + b.delete_set_median(0); + // [1, 3, 4, [5], 8, 9, 10] + assert_eq!(b.peek(), Some(5)); + b.delete_set_median(1); + // [1, 3, 4, [5], 9, 10] + assert_eq!(b.peek(), Some(5)); + b.delete_set_median(2); + // [1, 3, [4], 9, 10] + assert_eq!(b.peek(), Some(4)); + b.delete_set_median(3); + // [1, 3, [4], 10] + assert_eq!(b.peek(), Some(4)); + b.delete_set_median(4); + // [3, [4], 10] + assert_eq!(b.peek(), Some(4)); + b.delete_set_median(5); + // [4, [10]] + assert_eq!(b.peek(), Some(10)); + b.delete_set_median(6); + // [[10]] + assert_eq!(b.peek(), Some(10)); + } + + #[test] + fn test_block_2() { + let values = [9, 1, 2].as_ref(); + let mut scratch = vec![]; + let mut prev = vec![]; + let mut next = vec![]; + let mut b = Block::new(values, &mut scratch, &mut prev, &mut next); + + b.unwind(); + b.undelete_set_median(0); + assert_eq!(b.peek(), Some(9)); + b.undelete_set_median(1); + assert_eq!(b.peek(), Some(9)); + b.undelete_set_median(2); + assert_eq!(b.peek(), Some(2)); + } + + #[test] + fn test_block_union_1() { + let alpha_a = [10, 4, 2]; + let alpha_b = [3, 4, 1]; + + let mut scratch = vec![]; + let mut prev = vec![]; + let mut next = vec![]; + let mut a = Block::new(alpha_a.as_ref(), &mut scratch, &mut prev, &mut next); + + let mut scratch = vec![]; + let mut prev = vec![]; + let mut next = vec![]; + let mut b = Block::new(alpha_b.as_ref(), &mut scratch, &mut prev, &mut next); + + b.unwind(); + let mut aub = BlockUnion::new(&mut a, &mut b); + assert_eq!(aub.len(), 3); + // STEP 0 + // block 1: + // i: 10, 4, 2 + // s: 2, 4, 10 + // block 2: empty + assert_eq!(aub.get(0), 2); + assert_eq!(aub.get(1), 4); + assert_eq!(aub.get(2), 10); + + unsafe { + // STEP 1 + aub.block_left.reset(); + aub.set_state(0); + assert_eq!(aub.len(), 3); + // block 1: + // i: 4, 2 + // s: 2, 4 + // block 2: + // i: 3 + // s: 3 + // union s: [2, 3, 4] + assert_eq!(aub.get(0), 2); + assert_eq!(aub.get(1), 3); + assert_eq!(aub.get(2), 4); + + // STEP 2 + // i: 2 + // s: 2 + // block 2: + // i: 3, 4 + // s: 3, 4 + // union s: [2, 3, 4] + aub.set_state(1); + assert_eq!(aub.get(0), 2); + assert_eq!(aub.get(1), 3); + assert_eq!(aub.get(2), 4); + } + } + + #[test] + fn test_block_union_2() { + let alpha_a = [3, 4, 5, 7, 3, 9, 2, 6, 9, 8].as_ref(); + let alpha_b = [2, 2, 1, 7, 5, 3, 2, 6, 1, 7].as_ref(); + + let mut scratch = vec![]; + let mut prev = vec![]; + let mut next = vec![]; + let mut a = Block::new(alpha_a, &mut scratch, &mut prev, &mut next); + + let mut scratch = vec![]; + let mut prev = vec![]; + let mut next = vec![]; + let mut b = Block::new(alpha_b, &mut scratch, &mut prev, &mut next); + + b.unwind(); + let mut aub = BlockUnion::new(&mut a, &mut b); + assert_eq!(aub.len(), 10); + // STEP 0 + // block 1: + // i: 3, 4, 5, 7, 3, 9, 2, 6, 9, 8 + // s: 2, 3, 3, 4, 5, 6, 7, 8, 9, 9 + // block 2: empty + assert_eq!(aub.get(0), 2); + assert_eq!(aub.get(1), 3); + assert_eq!(aub.get(2), 3); + // skip a step + assert_eq!(aub.get(4), 5); + // skip to end + assert_eq!(aub.get(9), 9); + + // get median + assert_eq!(aub.get(5), 6); + + unsafe { + // STEP 1 + aub.set_state(0); + assert_eq!(aub.len(), 10); + // block 1: + // i: 4, 5, 7, 3, 9, 2, 6, 9, 8 + // s: 2, 3, 4, 5, 6, 7, 8, 9, 9 + // block 2: + // i: 2 + // s: 2 + // union s: 2, 2, 3, 4, 5, [6], 7, 8, 9, 9 + assert_eq!(aub.get(5), 6); + assert_eq!(aub.get(7), 8); + + // STEP 2 + aub.set_state(1); + + // Back to index 4 + aub.block_left.reset(); + aub.block_right.reset(); + assert_eq!(aub.get(4), 5); + // block 1: + // i: 5, 7, 3, 9, 2, 6, 9, 8 + // s: 2, 3, 5, 6, 7, 8, 9, 9 + // block 2: + // i: 2, 2 + // s: 2, 2 + // union s: 2, 2, 3, 4, 5, [6], 7, 8, 9, 9 + assert_eq!(aub.get(5), 6); + + // STEP 3 + aub.set_state(2); + // block 1: + // i: 7, 3, 9, 2, 6, 9, 8 + // s: 2, 3, 6, 7, 8, 9, 9 + // block 2: + // i: 2, 2, 1 + // s: 1, 2, 2 + // union s: 1, 2, 2, 3, 4, [6], 7, 8, 9, 9 + assert_eq!(aub.get(5), 6); + + // STEP 4 + aub.set_state(3); + // block 1: + // i: 3, 9, 2, 6, 9, 8 + // s: 2, 3, 6, 8, 9, 9 + // block 2: + // i: 2, 2, 1, 7 + // s: 1, 2, 2, 7 + // union s: 1, 2, 2, 3, 4, [6], 7, 8, 9, 9 + assert_eq!(aub.get(5), 6); + + // STEP 5 + aub.set_state(4); + // block 1: + // i: 9, 2, 6, 9, 8 + // s: 2, 6, 8, 9, 9 + // block 2: + // i: 2, 2, 1, 7, 5 + // s: 1, 2, 2, 5, 7 + // union s: 1, 2, 2, 2, 5, [6], 7, 8, 9, 9 + assert_eq!(aub.get(5), 6); + assert_eq!(aub.len(), 10); + + // STEP 6 + aub.set_state(5); + // LEFT IS phasing out + // block 1: + // i: 2, 6, 9, 8 + // s: 2, 6, 8, 9 + // block 2: + // i: 2, 2, 1, 7, 5, 3 + // s: 1, 2, 2, 3, 5, 7 + // union s: 1, 2, 2, 2, 4, [5], 6, 7, 8, 9 + assert_eq!(aub.len(), 10); + assert_eq!(aub.get(5), 5); + + // STEP 7 + aub.set_state(6); + // block 1: + // i: 6, 9, 8 + // s: 6, 8, 9 + // block 2: + // i: 2, 2, 1, 7, 5, 3, 2 + // s: 1, 2, 2, 2, 3, 5, 7 + // union s: 1, 2, 2, 2, 3, [5], 6, 7, 8, 9 + assert_eq!(aub.len(), 10); + assert_eq!(aub.get(5), 5); + + // STEP 8 + aub.set_state(7); + // block 1: + // i: 9, 8 + // s: 8, 9 + // block 2: + // i: 2, 2, 1, 7, 5, 3, 2, 6 + // s: 1, 2, 2, 2, 3, 5, 6, 7 + // union s: 1, 2, 2, 2, 3, [5], 6, 7, 8, 9 + assert_eq!(aub.len(), 10); + assert_eq!(aub.get(5), 5); + + // STEP 9 + aub.set_state(8); + // block 1: + // i: 8 + // s: 8 + // block 2: + // i: 2, 2, 1, 7, 5, 3, 2, 6, 1 + // s: 1, 1, 2, 2, 2, 3, 5, 6, 7 + // union s: 1, 1, 2, 2, 2, [3], 5, 6, 7, 8 + assert_eq!(aub.len(), 10); + assert_eq!(aub.get(5), 3); + + // STEP 10 + aub.set_state(9); + // block 1: empty + // block 2: + // i: 2, 2, 1, 7, 5, 3, 2, 6, 1, 7 + // s: 1, 1, 2, 2, 2, 3, 5, 6, 7 + // union s: 1, 1, 2, 2, 2, [3], 5, 6, 7, 7 + assert_eq!(aub.len(), 10); + assert_eq!(aub.get(5), 3); + } + } + + #[test] + fn test_median_1() { + let values = [ + 2.0, 8.0, 5.0, 9.0, 1.0, 2.0, 4.0, 2.0, 4.0, 8.1, -1.0, 2.9, 1.2, 23.0, + ] + .as_ref(); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); + let expected = [ + 2.0, 5.0, 5.0, 8.0, 5.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 2.9, 1.2, 2.9, + ]; + assert_eq!(out, expected); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 5, values, 0.5); + let expected = [ + 2.0, 5.0, 5.0, 6.5, 5.0, 5.0, 4.0, 2.0, 2.0, 4.0, 4.0, 2.9, 2.9, 2.9, + ]; + assert_eq!(out, expected); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 7, values, 0.5); + let expected = [ + 2.0, 5.0, 5.0, 6.5, 5.0, 3.5, 4.0, 4.0, 4.0, 4.0, 2.0, 2.9, 2.9, 2.9, + ]; + assert_eq!(out, expected); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 4, values, 0.5); + let expected = [ + 2.0, 5.0, 5.0, 6.5, 6.5, 3.5, 3.0, 2.0, 3.0, 4.0, 3.0, 3.45, 2.05, 2.05, + ]; + assert_eq!(out, expected); + } + + #[test] + fn test_median_2() { + let values = [10, 10, 15, 13, 9, 5, 3, 13, 19, 15, 19].as_ref(); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); + let expected = [10, 10, 10, 13, 13, 9, 5, 5, 13, 15, 19]; + assert_eq!(out, expected); + } +} diff --git a/crates/polars-compute/src/rolling/window.rs b/crates/polars-compute/src/rolling/window.rs new file mode 100644 index 000000000000..37a2b167d4ba --- /dev/null +++ b/crates/polars-compute/src/rolling/window.rs @@ -0,0 +1,192 @@ +use ::skiplist::OrderedSkipList; +use polars_utils::total_ord::TotalOrd; + +use super::*; + +pub(super) struct SortedBuf<'a, T: NativeType> { + // slice over which the window slides + slice: &'a [T], + last_start: usize, + last_end: usize, + // values within the window that we keep sorted + pub buf: OrderedSkipList, +} + +impl<'a, T: NativeType + PartialOrd + Copy> SortedBuf<'a, T> { + pub(super) fn new( + slice: &'a [T], + start: usize, + end: usize, + max_window_size: Option, + ) -> Self { + let mut buf = if let Some(max_window_size) = max_window_size { + OrderedSkipList::with_capacity(max_window_size) + } else { + OrderedSkipList::new() + }; + unsafe { buf.sort_by(TotalOrd::tot_cmp) }; + let mut out = Self { + slice, + last_start: start, + last_end: end, + buf, + }; + let init = &slice[start..end]; + out.reset(init); + out + } + + fn reset(&mut self, slice: &[T]) { + self.buf.clear(); + self.buf.extend(slice.iter().copied()); + } + + /// Update the window position by setting the `start` index and the `end` index. + /// + /// # Safety + /// The caller must ensure that `start` and `end` are within bounds of `self.slice` + /// + pub(super) unsafe fn update(&mut self, start: usize, end: usize) { + // swap the whole buffer + if start >= self.last_end { + self.buf.clear(); + let new_window = unsafe { self.slice.get_unchecked(start..end) }; + self.reset(new_window); + } else { + // remove elements that should leave the window + for idx in self.last_start..start { + // SAFETY: + // in bounds + let val = unsafe { self.slice.get_unchecked(idx) }; + self.buf.remove(val); + } + + // insert elements that enter the window, but insert them sorted + for idx in self.last_end..end { + // SAFETY: + // we are in bounds + let val = unsafe { *self.slice.get_unchecked(idx) }; + self.buf.insert(val); + } + } + self.last_start = start; + self.last_end = end; + } + + pub(super) fn get(&self, index: usize) -> T { + self.buf[index] + } + + pub(super) fn len(&self) -> usize { + self.buf.len() + } +} + +pub(super) struct SortedBufNulls<'a, T: NativeType> { + // slice over which the window slides + slice: &'a [T], + validity: &'a Bitmap, + last_start: usize, + last_end: usize, + // values within the window that we keep sorted + buf: OrderedSkipList>, + pub null_count: usize, +} + +impl<'a, T: NativeType + PartialOrd> SortedBufNulls<'a, T> { + unsafe fn fill_and_sort_buf(&mut self, start: usize, end: usize) { + self.null_count = 0; + let iter = (start..end).map(|idx| unsafe { + if self.validity.get_bit_unchecked(idx) { + Some(*self.slice.get_unchecked(idx)) + } else { + self.null_count += 1; + None + } + }); + + self.buf.clear(); + self.buf.extend(iter); + } + + pub(super) unsafe fn new( + slice: &'a [T], + validity: &'a Bitmap, + start: usize, + end: usize, + max_window_size: Option, + ) -> Self { + let mut buf = if let Some(max_window_size) = max_window_size { + OrderedSkipList::with_capacity(max_window_size) + } else { + OrderedSkipList::new() + }; + unsafe { buf.sort_by(TotalOrd::tot_cmp) }; + + // sort_opt_buf(&mut buf); + let mut out = Self { + slice, + validity, + last_start: start, + last_end: end, + buf, + null_count: 0, + }; + unsafe { out.fill_and_sort_buf(start, end) }; + out + } + + /// Update the window position by setting the `start` index and the `end` index. + /// + /// # Safety + /// The caller must ensure that `start` and `end` are within bounds of `self.slice` + /// + pub(super) unsafe fn update(&mut self, start: usize, end: usize) -> usize { + // swap the whole buffer + if start >= self.last_end { + unsafe { self.fill_and_sort_buf(start, end) }; + } else { + // remove elements that should leave the window + for idx in self.last_start..start { + // SAFETY: + // we are in bounds + let val = if unsafe { self.validity.get_bit_unchecked(idx) } { + unsafe { Some(*self.slice.get_unchecked(idx)) } + } else { + self.null_count -= 1; + None + }; + self.buf.remove(&val); + } + + // insert elements that enter the window, but insert them sorted + for idx in self.last_end..end { + // SAFETY: + // we are in bounds + let val = if unsafe { self.validity.get_bit_unchecked(idx) } { + unsafe { Some(*self.slice.get_unchecked(idx)) } + } else { + self.null_count += 1; + None + }; + + self.buf.insert(val); + } + } + self.last_start = start; + self.last_end = end; + self.null_count + } + + pub(super) fn is_valid(&self, min_periods: usize) -> bool { + ((self.last_end - self.last_start) - self.null_count) >= min_periods + } + + pub(super) fn len(&self) -> usize { + self.buf.len() + } + + pub(super) fn get(&self, idx: usize) -> Option { + self.buf[idx] + } +} diff --git a/crates/polars-compute/src/size.rs b/crates/polars-compute/src/size.rs new file mode 100644 index 000000000000..199a945e1d10 --- /dev/null +++ b/crates/polars-compute/src/size.rs @@ -0,0 +1,9 @@ +use arrow::array::{Array, ArrayRef, BinaryViewArray, UInt32Array}; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +pub fn binary_size_bytes(array: &BinaryViewArray) -> ArrayRef { + let values: Buffer<_> = array.len_iter().collect(); + let array = UInt32Array::new(ArrowDataType::UInt32, values, array.validity().cloned()); + Box::new(array) +} diff --git a/crates/polars-compute/src/sum.rs b/crates/polars-compute/src/sum.rs new file mode 100644 index 000000000000..e9aa1b0eb17b --- /dev/null +++ b/crates/polars-compute/src/sum.rs @@ -0,0 +1,156 @@ +use std::ops::Add; +#[cfg(feature = "simd")] +use std::simd::prelude::*; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::bitmask::BitMask; +use arrow::types::NativeType; +use num_traits::Zero; + +macro_rules! wrapping_impl { + ($trait_name:ident, $method:ident, $t:ty) => { + impl $trait_name for $t { + #[inline(always)] + fn wrapping_add(&self, v: &Self) -> Self { + <$t>::$method(*self, *v) + } + } + }; +} + +/// Performs addition that wraps around on overflow. +/// +/// Differs from num::WrappingAdd in that this is also implemented for floats. +pub trait WrappingAdd: Sized { + /// Wrapping (modular) addition. Computes `self + other`, wrapping around at + /// the boundary of the type. + fn wrapping_add(&self, v: &Self) -> Self; +} + +wrapping_impl!(WrappingAdd, wrapping_add, u8); +wrapping_impl!(WrappingAdd, wrapping_add, u16); +wrapping_impl!(WrappingAdd, wrapping_add, u32); +wrapping_impl!(WrappingAdd, wrapping_add, u64); +wrapping_impl!(WrappingAdd, wrapping_add, usize); +wrapping_impl!(WrappingAdd, wrapping_add, u128); + +wrapping_impl!(WrappingAdd, wrapping_add, i8); +wrapping_impl!(WrappingAdd, wrapping_add, i16); +wrapping_impl!(WrappingAdd, wrapping_add, i32); +wrapping_impl!(WrappingAdd, wrapping_add, i64); +wrapping_impl!(WrappingAdd, wrapping_add, isize); +wrapping_impl!(WrappingAdd, wrapping_add, i128); + +wrapping_impl!(WrappingAdd, add, f32); +wrapping_impl!(WrappingAdd, add, f64); + +#[cfg(feature = "simd")] +const STRIPE: usize = 16; + +fn wrapping_sum_with_mask_scalar(vals: &[T], mask: &BitMask) -> T { + assert!(vals.len() == mask.len()); + vals.iter() + .enumerate() + .map(|(i, x)| { + // No filter but rather select of 0 for cmov opt. + if mask.get(i) { *x } else { T::zero() } + }) + .fold(T::zero(), |a, b| a.wrapping_add(&b)) +} + +#[cfg(not(feature = "simd"))] +impl WrappingSum for T +where + T: NativeType + WrappingAdd + Zero, +{ + fn wrapping_sum(vals: &[Self]) -> Self { + vals.iter() + .copied() + .fold(T::zero(), |a, b| a.wrapping_add(&b)) + } + + fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self { + wrapping_sum_with_mask_scalar(vals, mask) + } +} + +#[cfg(feature = "simd")] +impl WrappingSum for T +where + T: NativeType + WrappingAdd + Zero + crate::SimdPrimitive, +{ + fn wrapping_sum(vals: &[Self]) -> Self { + vals.iter() + .copied() + .fold(T::zero(), |a, b| a.wrapping_add(&b)) + } + + fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self { + assert!(vals.len() == mask.len()); + let remainder = vals.len() % STRIPE; + let (rest, main) = vals.split_at(remainder); + let (rest_mask, main_mask) = mask.split_at(remainder); + let zero: Simd = Simd::default(); + + let vsum = main + .chunks_exact(STRIPE) + .enumerate() + .map(|(i, a)| { + let m: Mask<_, STRIPE> = main_mask.get_simd(i * STRIPE); + m.select(Simd::from_slice(a), zero) + }) + .fold(zero, |a, b| { + let a = a.to_array(); + let b = b.to_array(); + Simd::from_array(std::array::from_fn(|i| a[i].wrapping_add(&b[i]))) + }); + + let mainsum = vsum + .to_array() + .into_iter() + .fold(T::zero(), |a, b| a.wrapping_add(&b)); + + // TODO: faster remainder. + let restsum = wrapping_sum_with_mask_scalar(rest, &rest_mask); + mainsum.wrapping_add(&restsum) + } +} + +#[cfg(feature = "simd")] +impl WrappingSum for u128 { + fn wrapping_sum(vals: &[Self]) -> Self { + vals.iter().copied().fold(0, |a, b| a.wrapping_add(b)) + } + + fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self { + wrapping_sum_with_mask_scalar(vals, mask) + } +} + +#[cfg(feature = "simd")] +impl WrappingSum for i128 { + fn wrapping_sum(vals: &[Self]) -> Self { + vals.iter().copied().fold(0, |a, b| a.wrapping_add(b)) + } + + fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self { + wrapping_sum_with_mask_scalar(vals, mask) + } +} + +pub trait WrappingSum: Sized { + fn wrapping_sum(vals: &[Self]) -> Self; + fn wrapping_sum_with_validity(vals: &[Self], mask: &BitMask) -> Self; +} + +pub fn wrapping_sum_arr(arr: &PrimitiveArray) -> T +where + T: NativeType + WrappingSum, +{ + let validity = arr.validity().filter(|_| arr.null_count() > 0); + if let Some(mask) = validity { + WrappingSum::wrapping_sum_with_validity(arr.values(), &BitMask::from_bitmap(mask)) + } else { + WrappingSum::wrapping_sum(arr.values()) + } +} diff --git a/crates/polars-compute/src/unique/boolean.rs b/crates/polars-compute/src/unique/boolean.rs new file mode 100644 index 000000000000..0dad1d704522 --- /dev/null +++ b/crates/polars-compute/src/unique/boolean.rs @@ -0,0 +1,128 @@ +use arrow::array::{Array, BooleanArray}; +use arrow::bitmap::BitmapBuilder; +use arrow::datatypes::ArrowDataType; + +use super::{GenericUniqueKernel, RangedUniqueKernel}; + +#[derive(Default)] +pub struct BooleanUniqueKernelState { + seen: u32, +} + +impl BooleanUniqueKernelState { + pub fn new() -> Self { + Self::default() + } +} + +impl RangedUniqueKernel for BooleanUniqueKernelState { + type Array = BooleanArray; + + fn has_seen_all(&self) -> bool { + self.seen == 0b111 + } + + fn append(&mut self, array: &Self::Array) { + if array.len() == 0 { + return; + } + + let null_count = array.null_count(); + self.seen |= u32::from(null_count > 0) << 2; + let set_bits = if null_count > 0 { + array + .values() + .num_intersections_with(array.validity().unwrap()) + } else { + array.values().set_bits() + }; + + self.seen |= u32::from(set_bits != array.len() - null_count); + self.seen |= u32::from(set_bits != 0) << 1; + } + + fn append_state(&mut self, other: &Self) { + self.seen |= other.seen; + } + + fn finalize_unique(self) -> Self::Array { + let mut values = BitmapBuilder::with_capacity(self.seen.count_ones() as usize); + + if self.seen & 0b001 != 0 { + values.push(false); + } + if self.seen & 0b010 != 0 { + values.push(true); + } + let validity = if self.seen & 0b100 != 0 { + let mut validity = BitmapBuilder::with_capacity(values.len() + 1); + validity.extend_constant(values.len(), true); + validity.push(false); + values.push(false); + Some(validity.freeze()) + } else { + None + }; + + let values = values.freeze(); + BooleanArray::new(ArrowDataType::Boolean, values, validity) + } + + fn finalize_n_unique(&self) -> usize { + self.seen.count_ones() as usize + } + + fn finalize_n_unique_non_null(&self) -> usize { + (self.seen & 0b011).count_ones() as usize + } +} + +impl GenericUniqueKernel for BooleanArray { + fn unique(&self) -> Self { + let mut state = BooleanUniqueKernelState::new(); + state.append(self); + state.finalize_unique() + } + + fn n_unique(&self) -> usize { + let mut state = BooleanUniqueKernelState::new(); + state.append(self); + state.finalize_n_unique() + } + + fn n_unique_non_null(&self) -> usize { + let mut state = BooleanUniqueKernelState::new(); + state.append(self); + state.finalize_n_unique_non_null() + } +} + +#[test] +fn test_boolean_distinct_count() { + use arrow::bitmap::Bitmap; + use arrow::datatypes::ArrowDataType; + + macro_rules! assert_bool_dc { + ($values:expr, $validity:expr => $dc:expr) => { + let validity: Option = + >>::map($validity, |v| Bitmap::from_iter(v)); + let arr = + BooleanArray::new(ArrowDataType::Boolean, Bitmap::from_iter($values), validity); + assert_eq!(arr.n_unique(), $dc); + }; + } + + assert_bool_dc!(vec![], None => 0); + assert_bool_dc!(vec![], Some(vec![]) => 0); + assert_bool_dc!(vec![true], None => 1); + assert_bool_dc!(vec![true], Some(vec![true]) => 1); + assert_bool_dc!(vec![true], Some(vec![false]) => 1); + assert_bool_dc!(vec![true, false], None => 2); + assert_bool_dc!(vec![true, false, false], None => 2); + assert_bool_dc!(vec![true, false, false], Some(vec![true, true, false]) => 3); + + // Copied from https://github.com/pola-rs/polars/pull/16765#discussion_r1629426159 + assert_bool_dc!(vec![true, true, true, true, true], Some(vec![true, false, true, false, false]) => 2); + assert_bool_dc!(vec![false, true, false, true, true], Some(vec![true, false, true, false, false]) => 2); + assert_bool_dc!(vec![true, false, true, false, true, true], Some(vec![true, true, false, true, false, false]) => 3); +} diff --git a/crates/polars-compute/src/unique/dictionary.rs b/crates/polars-compute/src/unique/dictionary.rs new file mode 100644 index 000000000000..9ad967ed38c2 --- /dev/null +++ b/crates/polars-compute/src/unique/dictionary.rs @@ -0,0 +1,63 @@ +use arrow::array::{Array, DictionaryArray}; +use arrow::datatypes::ArrowDataType; + +use super::{PrimitiveRangedUniqueState, RangedUniqueKernel}; + +/// A specialized unique kernel for [`DictionaryArray`] for when all values are in a small known +/// range. +pub struct DictionaryRangedUniqueState { + key_state: PrimitiveRangedUniqueState, + values: Box, +} + +impl DictionaryRangedUniqueState { + pub fn new(values: Box) -> Self { + Self { + key_state: PrimitiveRangedUniqueState::new(0, values.len() as u32 + 1), + values, + } + } + + pub fn key_state(&mut self) -> &mut PrimitiveRangedUniqueState { + &mut self.key_state + } +} + +impl RangedUniqueKernel for DictionaryRangedUniqueState { + type Array = DictionaryArray; + + fn has_seen_all(&self) -> bool { + self.key_state.has_seen_all() + } + + fn append(&mut self, array: &Self::Array) { + self.key_state.append(array.keys()); + } + + fn append_state(&mut self, other: &Self) { + debug_assert_eq!(self.values, other.values); + self.key_state.append_state(&other.key_state); + } + + fn finalize_unique(self) -> Self::Array { + let keys = self.key_state.finalize_unique(); + DictionaryArray::::try_new( + ArrowDataType::Dictionary( + arrow::datatypes::IntegerType::UInt32, + Box::new(self.values.dtype().clone()), + false, + ), + keys, + self.values, + ) + .unwrap() + } + + fn finalize_n_unique(&self) -> usize { + self.key_state.finalize_n_unique() + } + + fn finalize_n_unique_non_null(&self) -> usize { + self.key_state.finalize_n_unique_non_null() + } +} diff --git a/crates/polars-compute/src/unique/mod.rs b/crates/polars-compute/src/unique/mod.rs new file mode 100644 index 000000000000..645daea39126 --- /dev/null +++ b/crates/polars-compute/src/unique/mod.rs @@ -0,0 +1,69 @@ +use arrow::array::Array; + +/// Kernel to calculate the number of unique elements where the elements are already sorted. +pub trait SortedUniqueKernel: Array { + /// Calculate the set of unique elements in `fst` and `others` and fold the result into one + /// array. + fn unique_fold<'a>(fst: &'a Self, others: impl Iterator) -> Self; + + /// Calculate the set of unique elements in [`Self`] where we have no further information about + /// `self`. + fn unique(&self) -> Self; + + /// Calculate the number of unique elements in [`Self`] + /// + /// A null is also considered a unique value + fn n_unique(&self) -> usize; + + /// Calculate the number of unique non-null elements in [`Self`] + fn n_unique_non_null(&self) -> usize; +} + +/// Optimized kernel to calculate the unique elements of an array. +/// +/// This kernel is a specialized for where all values are known to be in some small range of +/// values. In this case, you can usually get by with a bitset and bit-arithmetic instead of using +/// vectors and hashsets. Consequently, this kernel is usually called when further information is +/// known about the underlying array. +/// +/// This trait is not implemented directly on the `Array` as with many other kernels. Rather, it is +/// implemented on a `State` struct to which `Array`s can be appended. This allows for sharing of +/// `State` between many chunks and allows for different implementations for the same array (e.g. a +/// maintain order and no maintain-order variant). +pub trait RangedUniqueKernel { + type Array: Array; + + /// Returns whether all the values in the whole range are in the state + fn has_seen_all(&self) -> bool; + + /// Append an `Array`'s values to the `State` + fn append(&mut self, array: &Self::Array); + + /// Append another state into the `State` + fn append_state(&mut self, other: &Self); + + /// Consume the state to get the unique elements + fn finalize_unique(self) -> Self::Array; + /// Consume the state to get the number of unique elements including null + fn finalize_n_unique(&self) -> usize; + /// Consume the state to get the number of unique elements excluding null + fn finalize_n_unique_non_null(&self) -> usize; +} + +/// A generic unique kernel that selects the generally applicable unique kernel for an `Array`. +pub trait GenericUniqueKernel { + /// Calculate the set of unique elements + fn unique(&self) -> Self; + /// Calculate the number of unique elements including null + fn n_unique(&self) -> usize; + /// Calculate the number of unique elements excluding null + fn n_unique_non_null(&self) -> usize; +} + +mod boolean; +mod dictionary; +mod primitive; + +pub use boolean::BooleanUniqueKernelState; +pub use dictionary::DictionaryRangedUniqueState; +pub use primitive::PrimitiveRangedUniqueState; diff --git a/crates/polars-compute/src/unique/primitive.rs b/crates/polars-compute/src/unique/primitive.rs new file mode 100644 index 000000000000..541bb2750c30 --- /dev/null +++ b/crates/polars-compute/src/unique/primitive.rs @@ -0,0 +1,236 @@ +use std::ops::{Add, RangeInclusive, Sub}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::{BitmapBuilder, MutableBitmap}; +use arrow::datatypes::ArrowDataType; +use arrow::types::NativeType; +use num_traits::{FromPrimitive, ToPrimitive}; +use polars_utils::total_ord::TotalOrd; + +use super::RangedUniqueKernel; + +/// A specialized unique kernel for [`PrimitiveArray`] for when all values are in a small known +/// range. +pub struct PrimitiveRangedUniqueState { + seen: Seen, + range: RangeInclusive, +} + +enum Seen { + Small(u128), + Large(MutableBitmap), +} + +impl Seen { + pub fn from_size(size: usize) -> Self { + if size <= 128 { + Self::Small(0) + } else { + Self::Large(MutableBitmap::from_len_zeroed(size)) + } + } + + fn num_seen(&self) -> usize { + match self { + Seen::Small(v) => v.count_ones() as usize, + Seen::Large(v) => v.set_bits(), + } + } + + fn has_seen_null(&self, size: usize) -> bool { + match self { + Self::Small(v) => v >> (size - 1) != 0, + Self::Large(v) => v.get(size - 1), + } + } +} + +impl PrimitiveRangedUniqueState +where + T: Add + Sub + ToPrimitive + FromPrimitive, +{ + pub fn new(min_value: T, max_value: T) -> Self { + let size = (max_value - min_value).to_usize().unwrap(); + // Range is inclusive + let size = size + 1; + // One value is left for null + let size = size + 1; + + Self { + seen: Seen::from_size(size), + range: min_value..=max_value, + } + } + + fn size(&self) -> usize { + (*self.range.end() - *self.range.start()) + .to_usize() + .unwrap() + + 1 + } +} + +impl RangedUniqueKernel for PrimitiveRangedUniqueState +where + T: Add + Sub + ToPrimitive + FromPrimitive, +{ + type Array = PrimitiveArray; + + fn has_seen_all(&self) -> bool { + let size = self.size(); + match &self.seen { + Seen::Small(v) if size == 128 => !v == 0, + Seen::Small(v) => *v == ((1 << size) - 1), + Seen::Large(v) => BitMask::new(v.as_slice(), 0, size).unset_bits() == 0, + } + } + + fn append(&mut self, array: &Self::Array) { + let size = self.size(); + match array.validity().as_ref().filter(|v| v.unset_bits() > 0) { + None => { + const STEP_SIZE: usize = 512; + + let mut i = 0; + let values = array.values().as_slice(); + + match self.seen { + Seen::Small(ref mut seen) => { + // Check every so often whether we have already seen all the values. + while *seen != ((1 << (size - 1)) - 1) && i < values.len() { + for v in values[i..].iter().take(STEP_SIZE) { + if cfg!(debug_assertions) { + assert!(TotalOrd::tot_ge(v, self.range.start())); + assert!(TotalOrd::tot_le(v, self.range.end())); + } + + let v = *v - *self.range.start(); + let v = unsafe { v.to_usize().unwrap_unchecked() }; + *seen |= 1 << v; + } + + i += STEP_SIZE; + } + }, + Seen::Large(ref mut seen) => { + // Check every so often whether we have already seen all the values. + while BitMask::new(seen.as_slice(), 0, size - 1).unset_bits() > 0 + && i < values.len() + { + for v in values[i..].iter().take(STEP_SIZE) { + if cfg!(debug_assertions) { + assert!(TotalOrd::tot_ge(v, self.range.start())); + assert!(TotalOrd::tot_le(v, self.range.end())); + } + + let v = *v - *self.range.start(); + let v = unsafe { v.to_usize().unwrap_unchecked() }; + seen.set(v, true); + } + + i += STEP_SIZE; + } + }, + } + }, + Some(_) => { + let iter = array.non_null_values_iter(); + + match self.seen { + Seen::Small(ref mut seen) => { + *seen |= 1 << (size - 1); + + for v in iter { + if cfg!(debug_assertions) { + assert!(TotalOrd::tot_ge(&v, self.range.start())); + assert!(TotalOrd::tot_le(&v, self.range.end())); + } + + let v = v - *self.range.start(); + let v = unsafe { v.to_usize().unwrap_unchecked() }; + *seen |= 1 << v; + } + }, + Seen::Large(ref mut seen) => { + seen.set(size - 1, true); + + for v in iter { + if cfg!(debug_assertions) { + assert!(TotalOrd::tot_ge(&v, self.range.start())); + assert!(TotalOrd::tot_le(&v, self.range.end())); + } + + let v = v - *self.range.start(); + let v = unsafe { v.to_usize().unwrap_unchecked() }; + seen.set(v, true); + } + }, + } + }, + } + } + + fn append_state(&mut self, other: &Self) { + debug_assert_eq!(self.size(), other.size()); + match (&mut self.seen, &other.seen) { + (Seen::Small(lhs), Seen::Small(rhs)) => *lhs |= rhs, + (Seen::Large(lhs), Seen::Large(rhs)) => { + let mut lhs = lhs; + <&mut MutableBitmap as std::ops::BitOrAssign<&MutableBitmap>>::bitor_assign( + &mut lhs, rhs, + ) + }, + _ => unreachable!(), + } + } + + fn finalize_unique(self) -> Self::Array { + let size = self.size(); + let seen = self.seen; + + let has_null = seen.has_seen_null(size); + let num_values = seen.num_seen(); + let mut values = Vec::with_capacity(num_values); + + let mut offset = 0; + match seen { + Seen::Small(mut v) => { + while v != 0 { + let shift = v.trailing_zeros(); + offset += shift as u8; + values.push(*self.range.start() + T::from_u8(offset).unwrap()); + + v >>= shift + 1; + offset += 1; + } + }, + Seen::Large(v) => { + for offset in v.freeze().true_idx_iter() { + values.push(*self.range.start() + T::from_usize(offset).unwrap()); + } + }, + } + + let validity = if has_null { + let mut validity = BitmapBuilder::new(); + validity.extend_constant(values.len() - 1, true); + validity.push(false); + // The null has already been pushed. + *values.last_mut().unwrap() = T::zeroed(); + Some(validity.freeze()) + } else { + None + }; + + PrimitiveArray::new(ArrowDataType::from(T::PRIMITIVE), values.into(), validity) + } + + fn finalize_n_unique(&self) -> usize { + self.seen.num_seen() + } + + fn finalize_n_unique_non_null(&self) -> usize { + self.seen.num_seen() - usize::from(self.seen.has_seen_null(self.size())) + } +} diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml new file mode 100644 index 000000000000..4906a2831d16 --- /dev/null +++ b/crates/polars-core/Cargo.toml @@ -0,0 +1,169 @@ +[package] +name = "polars-core" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Core of the Polars DataFrame library" + +[dependencies] +polars-compute = { workspace = true, features = ["gather"] } +polars-error = { workspace = true } +polars-row = { workspace = true } +polars-schema = { workspace = true } +polars-utils = { workspace = true } + +arrow = { workspace = true } +bitflags = { workspace = true } +bytemuck = { workspace = true } +chrono = { workspace = true, optional = true } +chrono-tz = { workspace = true, optional = true } +comfy-table = { version = "7.1.1", default-features = false, optional = true } +either = { workspace = true } +hashbrown = { workspace = true } +hashbrown_old_nightly_hack = { workspace = true } +indexmap = { workspace = true } +itoa = { workspace = true } +ndarray = { workspace = true, optional = true } +num-traits = { workspace = true } +rand = { workspace = true, optional = true, features = ["small_rng", "std"] } +rand_distr = { workspace = true, optional = true } +rayon = { workspace = true } +regex = { workspace = true, optional = true } +# activate if you want serde support for Series and DataFrames +serde = { workspace = true, optional = true } +serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } +xxhash-rust = { workspace = true } + +[dev-dependencies] +bincode = { workspace = true } +serde_json = { workspace = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +simd = ["arrow/simd", "polars-compute/simd"] +nightly = ["simd", "hashbrown/nightly", "hashbrown_old_nightly_hack/nightly", "polars-utils/nightly", "arrow/nightly"] +avx512 = [] +docs = [] +temporal = ["regex", "chrono", "polars-error/regex"] +random = ["rand", "rand_distr"] +algorithm_group_by = [] +default = ["algorithm_group_by"] +lazy = [] + +# ~40% faster collect, needed until trustedlength iter stabilizes +# more fast paths, slower compilation +performant = ["arrow/performant", "reinterpret"] + +# extra utilities for StringChunked +strings = ["regex", "arrow/strings", "polars-error/regex"] +# support for ObjectChunked (downcastable Series of any type) +object = ["serde_json", "algorithm_group_by"] + +fmt = ["comfy-table/tty"] +fmt_no_tty = ["comfy-table"] + +# opt-in features +# create from row values +# and include pivot operation +rows = [] + +# operations +approx_unique = ["polars-compute/approx_unique"] +bitwise = ["algorithm_group_by"] +zip_with = [] +round_series = [] +checked_arithmetic = [] +is_first_distinct = [] +is_last_distinct = [] +dot_product = [] +row_hash = [] +reinterpret = [] +take_opt_iter = [] +# allow group_by operation on list type +group_by_list = [] +# rolling window functions +rolling_window = [] +rolling_window_by = [] +diagonal_concat = [] +dataframe_arithmetic = [] +product = [] +unique_counts = [] +partition_by = ["algorithm_group_by"] +describe = [] +timezones = ["temporal", "chrono", "chrono-tz", "arrow/chrono-tz", "arrow/timezones"] +dynamic_group_by = ["dtype-datetime", "dtype-date"] +list_arithmetic = [] +array_arithmetic = ["dtype-array"] + +# opt-in datatypes for Series +dtype-date = ["temporal"] +dtype-datetime = ["temporal"] +dtype-duration = ["temporal"] +dtype-time = ["temporal"] +dtype-array = ["arrow/dtype-array", "polars-compute/dtype-array"] +dtype-i8 = [] +dtype-i16 = [] +dtype-i128 = ["polars-compute/dtype-i128"] +dtype-decimal = ["arrow/dtype-decimal", "polars-compute/cast", "polars-compute/dtype-decimal", "dtype-i128"] +dtype-u8 = [] +dtype-u16 = [] +dtype-categorical = [] +dtype-struct = [] + +# scale to terabytes? +bigidx = ["arrow/bigidx", "polars-utils/bigidx"] +python = [] + +serde = [ + "dep:serde", + "bitflags/serde", + "polars-schema/serde", + "polars-utils/serde", + "arrow/io_ipc", + "arrow/io_ipc_compression", + "serde_json", +] +serde-lazy = ["serde", "arrow/serde", "indexmap/serde", "chrono/serde"] + +docs-selection = [ + "ndarray", + "rows", + "docs", + "strings", + "object", + "lazy", + "temporal", + "random", + "zip_with", + "checked_arithmetic", + "is_first_distinct", + "is_last_distinct", + "dot_product", + "row_hash", + "rolling_window", + "rolling_window_by", + "serde", + "dtype-categorical", + "dtype-decimal", + "diagonal_concat", + "dataframe_arithmetic", + "product", + "describe", + "partition_by", + "algorithm_group_by", + "list_arithmetic", + "array_arithmetic", +] + +[package.metadata.docs.rs] +# not all because arrow 4.3 does not compile with simd +# all-features = true +features = ["docs-selection"] +# defines the configuration attribute `docsrs` +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars-core/LICENSE b/crates/polars-core/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-core/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-core/README.md b/crates/polars-core/README.md new file mode 100644 index 000000000000..81aabbc979e1 --- /dev/null +++ b/crates/polars-core/README.md @@ -0,0 +1,7 @@ +# polars-core + +`polars-core` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, +providing its core functionalities. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-core/build.rs b/crates/polars-core/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-core/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-core/src/chunked_array/arithmetic/decimal.rs b/crates/polars-core/src/chunked_array/arithmetic/decimal.rs new file mode 100644 index 000000000000..2b7eac56f70c --- /dev/null +++ b/crates/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -0,0 +1,56 @@ +use super::*; + +impl Add for &DecimalChunked { + type Output = PolarsResult; + + fn add(self, rhs: Self) -> Self::Output { + let scale = _get_decimal_scale_add_sub(self.scale(), rhs.scale()); + let lhs = self.to_scale(scale)?; + let rhs = rhs.to_scale(scale)?; + Ok((&lhs.0 + &rhs.0).into_decimal_unchecked(None, scale)) + } +} + +impl Sub for &DecimalChunked { + type Output = PolarsResult; + + fn sub(self, rhs: Self) -> Self::Output { + let scale = _get_decimal_scale_add_sub(self.scale(), rhs.scale()); + let lhs = self.to_scale(scale)?; + let rhs = rhs.to_scale(scale)?; + Ok((&lhs.0 - &rhs.0).into_decimal_unchecked(None, scale)) + } +} + +impl Mul for &DecimalChunked { + type Output = PolarsResult; + + fn mul(self, rhs: Self) -> Self::Output { + let scale = _get_decimal_scale_mul(self.scale(), rhs.scale()); + Ok((&self.0 * &rhs.0).into_decimal_unchecked(None, scale)) + } +} + +impl Div for &DecimalChunked { + type Output = PolarsResult; + + fn div(self, rhs: Self) -> Self::Output { + let scale = _get_decimal_scale_div(self.scale()); + let lhs = self.to_scale(scale + rhs.scale())?; + Ok((&lhs.0 / &rhs.0).into_decimal_unchecked(None, scale)) + } +} + +// Used by polars-plan to determine schema. +pub fn _get_decimal_scale_add_sub(scale_left: usize, scale_right: usize) -> usize { + scale_left.max(scale_right) +} + +pub fn _get_decimal_scale_mul(scale_left: usize, scale_right: usize) -> usize { + scale_left + scale_right +} + +pub fn _get_decimal_scale_div(scale_left: usize) -> usize { + // Follow postgres and MySQL adding a fixed scale increment of 4 + scale_left + 4 +} diff --git a/crates/polars-core/src/chunked_array/arithmetic/mod.rs b/crates/polars-core/src/chunked_array/arithmetic/mod.rs new file mode 100644 index 000000000000..0285f2b6a81b --- /dev/null +++ b/crates/polars-core/src/chunked_array/arithmetic/mod.rs @@ -0,0 +1,189 @@ +//! Implementations of arithmetic operations on ChunkedArrays. +#[cfg(feature = "dtype-decimal")] +mod decimal; +mod numeric; + +use std::ops::{Add, Div, Mul, Rem, Sub}; + +use arrow::compute::utils::combine_validities_and; +#[cfg(feature = "dtype-decimal")] +pub use decimal::{_get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul}; +use num_traits::{Num, NumCast, ToPrimitive}; +pub use numeric::ArithmeticChunked; + +use crate::prelude::arity::unary_elementwise_values; +use crate::prelude::*; + +#[inline] +fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { + buf.clear(); + + buf.extend_from_slice(l); + buf.extend_from_slice(r); +} + +impl Add for &StringChunked { + type Output = StringChunked; + + fn add(self, rhs: Self) -> Self::Output { + unsafe { (self.as_binary() + rhs.as_binary()).to_string_unchecked() } + } +} + +impl Add for StringChunked { + type Output = StringChunked; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl Add<&str> for &StringChunked { + type Output = StringChunked; + + fn add(self, rhs: &str) -> Self::Output { + unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_string_unchecked() } + } +} + +fn concat_binview(a: &BinaryViewArray, b: &BinaryViewArray) -> BinaryViewArray { + let validity = combine_validities_and(a.validity(), b.validity()); + + let mut mutable = MutableBinaryViewArray::with_capacity(a.len()); + + let mut scratch = vec![]; + for (a, b) in a.values_iter().zip(b.values_iter()) { + concat_binary_arrs(a, b, &mut scratch); + mutable.push_value(&scratch) + } + + mutable.freeze().with_validity(validity) +} + +impl Add for &BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: Self) -> Self::Output { + // broadcasting path rhs + if rhs.len() == 1 { + let rhs = rhs.get(0); + let mut buf = vec![]; + return match rhs { + Some(rhs) => { + self.apply_mut(|s| { + concat_binary_arrs(s, rhs, &mut buf); + let out = buf.as_slice(); + // SAFETY: lifetime is bound to the outer scope and the + // ref is valid for the lifetime of this closure. + unsafe { std::mem::transmute::<_, &'static [u8]>(out) } + }) + }, + None => BinaryChunked::full_null(self.name().clone(), self.len()), + }; + } + // broadcasting path lhs + if self.len() == 1 { + let lhs = self.get(0); + let mut buf = vec![]; + return match lhs { + Some(lhs) => rhs.apply_mut(|s| { + concat_binary_arrs(lhs, s, &mut buf); + let out = buf.as_slice(); + // SAFETY: lifetime is bound to the outer scope and the + // ref is valid for the lifetime of this closure. + unsafe { std::mem::transmute::<_, &'static [u8]>(out) } + }), + None => BinaryChunked::full_null(self.name().clone(), rhs.len()), + }; + } + + arity::binary(self, rhs, concat_binview) + } +} + +impl Add for BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl Add<&[u8]> for &BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: &[u8]) -> Self::Output { + let arr = BinaryViewArray::from_slice_values([rhs]); + let rhs: BinaryChunked = arr.into(); + self.add(&rhs) + } +} + +fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray { + let validity = combine_validities_and(a.validity(), b.validity()); + + let values = a + .values_iter() + .zip(b.values_iter()) + .map(|(a, b)| a as IdxSize + b as IdxSize) + .collect::>(); + PrimitiveArray::from_data_default(values.into(), validity) +} + +impl Add for &BooleanChunked { + type Output = IdxCa; + + fn add(self, rhs: Self) -> Self::Output { + // Broadcasting path rhs. + if rhs.len() == 1 { + let rhs = rhs.get(0); + return match rhs { + Some(rhs) => unary_elementwise_values(self, |v| v as IdxSize + rhs as IdxSize), + None => IdxCa::full_null(self.name().clone(), self.len()), + }; + } + // Broadcasting path lhs. + if self.len() == 1 { + return rhs.add(self); + } + arity::binary(self, rhs, add_boolean) + } +} + +impl Add for BooleanChunked { + type Output = IdxCa; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +#[cfg(test)] +pub(crate) mod test { + use crate::prelude::*; + + pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { + let mut a1 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let a2 = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]); + let a3 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3, 4, 5, 6]); + a1.append(&a2).unwrap(); + (a1, a3) + } + + #[test] + #[allow(clippy::eq_op)] + fn test_chunk_mismatch() { + let (a1, a2) = create_two_chunked(); + // With different chunks. + let _ = &a1 + &a2; + let _ = &a1 - &a2; + let _ = &a1 / &a2; + let _ = &a1 * &a2; + + // With same chunks. + let _ = &a1 + &a1; + let _ = &a1 - &a1; + let _ = &a1 / &a1; + let _ = &a1 * &a1; + } +} diff --git a/crates/polars-core/src/chunked_array/arithmetic/numeric.rs b/crates/polars-core/src/chunked_array/arithmetic/numeric.rs new file mode 100644 index 000000000000..8fe054b06e39 --- /dev/null +++ b/crates/polars-core/src/chunked_array/arithmetic/numeric.rs @@ -0,0 +1,422 @@ +use polars_compute::arithmetic::ArithmeticKernel; + +use super::*; +use crate::chunked_array::arity::{ + apply_binary_kernel_broadcast, apply_binary_kernel_broadcast_owned, unary_kernel, + unary_kernel_owned, +}; + +macro_rules! impl_op_overload { + ($op: ident, $trait_method: ident, $ca_method: ident, $ca_method_scalar: ident) => { + impl $op for ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: Self) -> Self::Output { + ArithmeticChunked::$ca_method(self, rhs) + } + } + + impl $op for &ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: Self) -> Self::Output { + ArithmeticChunked::$ca_method(self, rhs) + } + } + + // TODO: make this more strict instead of casting. + impl $op for ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + ArithmeticChunked::$ca_method_scalar(self, rhs) + } + } + + impl $op for &ChunkedArray { + type Output = ChunkedArray; + + fn $trait_method(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + ArithmeticChunked::$ca_method_scalar(self, rhs) + } + } + }; +} + +impl_op_overload!(Add, add, wrapping_add, wrapping_add_scalar); +impl_op_overload!(Sub, sub, wrapping_sub, wrapping_sub_scalar); +impl_op_overload!(Mul, mul, wrapping_mul, wrapping_mul_scalar); +impl_op_overload!(Div, div, legacy_div, legacy_div_scalar); // FIXME: replace this with true division. +impl_op_overload!(Rem, rem, wrapping_mod, wrapping_mod_scalar); + +pub trait ArithmeticChunked { + type Scalar; + type Out; + type TrueDivOut; + + fn wrapping_abs(self) -> Self::Out; + fn wrapping_neg(self) -> Self::Out; + fn wrapping_add(self, rhs: Self) -> Self::Out; + fn wrapping_sub(self, rhs: Self) -> Self::Out; + fn wrapping_mul(self, rhs: Self) -> Self::Out; + fn wrapping_floor_div(self, rhs: Self) -> Self::Out; + fn wrapping_trunc_div(self, rhs: Self) -> Self::Out; + fn wrapping_mod(self, rhs: Self) -> Self::Out; + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; + + fn true_div(self, rhs: Self) -> Self::TrueDivOut; + fn true_div_scalar(self, rhs: Self::Scalar) -> Self::TrueDivOut; + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::TrueDivOut; + + // TODO: remove these. + // These are flooring division for integer types, true division for floating point types. + fn legacy_div(self, rhs: Self) -> Self::Out; + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self::Out; + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out; +} + +impl ArithmeticChunked for ChunkedArray { + type Scalar = T::Native; + type Out = ChunkedArray; + type TrueDivOut = ChunkedArray<::TrueDivPolarsType>; + + fn wrapping_abs(self) -> Self::Out { + unary_kernel_owned(self, ArithmeticKernel::wrapping_abs) + } + + fn wrapping_neg(self) -> Self::Out { + unary_kernel_owned(self, ArithmeticKernel::wrapping_neg) + } + + fn wrapping_add(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::wrapping_add, + |l, r| ArithmeticKernel::wrapping_add_scalar(r, l), + ArithmeticKernel::wrapping_add_scalar, + ) + } + + fn wrapping_sub(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::wrapping_sub, + ArithmeticKernel::wrapping_sub_scalar_lhs, + ArithmeticKernel::wrapping_sub_scalar, + ) + } + + fn wrapping_mul(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::wrapping_mul, + |l, r| ArithmeticKernel::wrapping_mul_scalar(r, l), + ArithmeticKernel::wrapping_mul_scalar, + ) + } + + fn wrapping_floor_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::wrapping_floor_div, + ArithmeticKernel::wrapping_floor_div_scalar_lhs, + ArithmeticKernel::wrapping_floor_div_scalar, + ) + } + + fn wrapping_trunc_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::wrapping_trunc_div, + ArithmeticKernel::wrapping_trunc_div_scalar_lhs, + ArithmeticKernel::wrapping_trunc_div_scalar, + ) + } + + fn wrapping_mod(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::wrapping_mod, + ArithmeticKernel::wrapping_mod_scalar_lhs, + ArithmeticKernel::wrapping_mod_scalar, + ) + } + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_add_scalar(a, rhs)) + } + + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_sub_scalar(a, rhs)) + } + + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| ArithmeticKernel::wrapping_sub_scalar_lhs(lhs, a)) + } + + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_mul_scalar(a, rhs)) + } + + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| { + ArithmeticKernel::wrapping_floor_div_scalar(a, rhs) + }) + } + + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, a) + }) + } + + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar(a, rhs) + }) + } + + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar_lhs(lhs, a) + }) + } + + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::wrapping_mod_scalar(a, rhs)) + } + + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| ArithmeticKernel::wrapping_mod_scalar_lhs(lhs, a)) + } + + fn true_div(self, rhs: Self) -> Self::TrueDivOut { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::true_div, + ArithmeticKernel::true_div_scalar_lhs, + ArithmeticKernel::true_div_scalar, + ) + } + + fn true_div_scalar(self, rhs: Self::Scalar) -> Self::TrueDivOut { + unary_kernel_owned(self, |a| ArithmeticKernel::true_div_scalar(a, rhs)) + } + + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::TrueDivOut { + unary_kernel_owned(rhs, |a| ArithmeticKernel::true_div_scalar_lhs(lhs, a)) + } + + fn legacy_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast_owned( + self, + rhs, + ArithmeticKernel::legacy_div, + ArithmeticKernel::legacy_div_scalar_lhs, + ArithmeticKernel::legacy_div_scalar, + ) + } + + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel_owned(self, |a| ArithmeticKernel::legacy_div_scalar(a, rhs)) + } + + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel_owned(rhs, |a| ArithmeticKernel::legacy_div_scalar_lhs(lhs, a)) + } +} + +impl ArithmeticChunked for &ChunkedArray { + type Scalar = T::Native; + type Out = ChunkedArray; + type TrueDivOut = ChunkedArray<::TrueDivPolarsType>; + + fn wrapping_abs(self) -> Self::Out { + unary_kernel(self, |a| ArithmeticKernel::wrapping_abs(a.clone())) + } + + fn wrapping_neg(self) -> Self::Out { + unary_kernel(self, |a| ArithmeticKernel::wrapping_neg(a.clone())) + } + + fn wrapping_add(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_add(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_add_scalar(r.clone(), l), + |l, r| ArithmeticKernel::wrapping_add_scalar(l.clone(), r), + ) + } + + fn wrapping_sub(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_sub(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_sub_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_sub_scalar(l.clone(), r), + ) + } + + fn wrapping_mul(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_mul(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_mul_scalar(r.clone(), l), + |l, r| ArithmeticKernel::wrapping_mul_scalar(l.clone(), r), + ) + } + + fn wrapping_floor_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_floor_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar(l.clone(), r), + ) + } + + fn wrapping_trunc_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_trunc_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_trunc_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_trunc_div_scalar(l.clone(), r), + ) + } + + fn wrapping_mod(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::wrapping_mod(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_mod_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_mod_scalar(l.clone(), r), + ) + } + + fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_add_scalar(a.clone(), rhs) + }) + } + + fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_sub_scalar(a.clone(), rhs) + }) + } + + fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_sub_scalar_lhs(lhs, a.clone()) + }) + } + + fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_mul_scalar(a.clone(), rhs) + }) + } + + fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_floor_div_scalar(a.clone(), rhs) + }) + } + + fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, a.clone()) + }) + } + + fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar(a.clone(), rhs) + }) + } + + fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_trunc_div_scalar_lhs(lhs, a.clone()) + }) + } + + fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::wrapping_mod_scalar(a.clone(), rhs) + }) + } + + fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::wrapping_mod_scalar_lhs(lhs, a.clone()) + }) + } + + fn true_div(self, rhs: Self) -> Self::TrueDivOut { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::true_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::true_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::true_div_scalar(l.clone(), r), + ) + } + + fn true_div_scalar(self, rhs: Self::Scalar) -> Self::TrueDivOut { + unary_kernel(self, |a| ArithmeticKernel::true_div_scalar(a.clone(), rhs)) + } + + fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::TrueDivOut { + unary_kernel(rhs, |a| { + ArithmeticKernel::true_div_scalar_lhs(lhs, a.clone()) + }) + } + + fn legacy_div(self, rhs: Self) -> Self::Out { + apply_binary_kernel_broadcast( + self, + rhs, + |l, r| ArithmeticKernel::legacy_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::legacy_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::legacy_div_scalar(l.clone(), r), + ) + } + + fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self::Out { + unary_kernel(self, |a| { + ArithmeticKernel::legacy_div_scalar(a.clone(), rhs) + }) + } + + fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self::Out { + unary_kernel(rhs, |a| { + ArithmeticKernel::legacy_div_scalar_lhs(lhs, a.clone()) + }) + } +} diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs new file mode 100644 index 000000000000..98576f8ab5d7 --- /dev/null +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -0,0 +1,230 @@ +use std::ptr::NonNull; + +use super::*; +use crate::chunked_array::list::iterator::AmortizedListIter; +use crate::series::amortized_iter::{AmortSeries, ArrayBox, unstable_series_container_and_ptr}; + +impl ArrayChunked { + /// This is an iterator over a [`ArrayChunked`] that save allocations. + /// A Series is: + /// 1. [`Arc`] + /// ChunkedArray is: + /// 2. Vec< 3. ArrayRef> + /// + /// The [`ArrayRef`] we indicated with 3. will be updated during iteration. + /// The Series will be pinned in memory, saving an allocation for + /// 1. Arc<..> + /// 2. Vec<...> + /// + /// # Warning + /// Though memory safe in the sense that it will not read unowned memory, UB, or memory leaks + /// this function still needs precautions. The returned should never be cloned or taken longer + /// than a single iteration, as every call on `next` of the iterator will change the contents of + /// that Series. + /// + /// # Safety + /// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive + /// longer than the iterator is UB. + pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { + self.amortized_iter_with_name(PlSmallStr::EMPTY) + } + + /// This is an iterator over a [`ArrayChunked`] that save allocations. + /// A Series is: + /// 1. [`Arc`] + /// ChunkedArray is: + /// 2. Vec< 3. ArrayRef> + /// + /// The ArrayRef we indicated with 3. will be updated during iteration. + /// The Series will be pinned in memory, saving an allocation for + /// 1. Arc<..> + /// 2. Vec<...> + /// + /// If the returned `AmortSeries` is cloned, the local copy will be replaced and a new container + /// will be set. + pub fn amortized_iter_with_name( + &self, + name: PlSmallStr, + ) -> AmortizedListIter> + '_> { + // we create the series container from the inner array + // so that the container has the proper dtype. + let arr = self.downcast_iter().next().unwrap(); + let inner_values = arr.values(); + + let inner_dtype = self.inner_dtype(); + let iter_dtype = match inner_dtype { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => inner_dtype.to_physical(), + // TODO: figure out how to deal with physical/logical distinction + // physical primitives like time, date etc. work + // physical nested need more + _ => inner_dtype.clone(), + }; + + // SAFETY: + // inner type passed as physical type + let (s, ptr) = + unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) }; + + // SAFETY: `ptr` belongs to the `Series`. + unsafe { + AmortizedListIter::new( + self.len(), + s, + NonNull::new(ptr).unwrap(), + self.downcast_iter().flat_map(|arr| arr.iter()), + inner_dtype.clone(), + ) + } + } + + pub fn try_apply_amortized_to_list(&self, mut f: F) -> PolarsResult + where + F: FnMut(AmortSeries) -> PolarsResult, + { + if self.is_empty() { + return Ok(Series::new_empty( + self.name().clone(), + &DataType::List(Box::new(self.inner_dtype().clone())), + ) + .list() + .unwrap() + .clone()); + } + let mut fast_explode = self.null_count() == 0; + let mut ca: ListChunked = { + self.amortized_iter() + .map(|opt_v| { + opt_v + .map(|v| { + let out = f(v); + if let Ok(out) = &out { + if out.is_empty() { + fast_explode = false + } + }; + out + }) + .transpose() + }) + .collect::>()? + }; + ca.rename(self.name().clone()); + if fast_explode { + ca.set_fast_explode(); + } + Ok(ca) + } + + /// Apply a closure `F` to each array. + /// + /// # Safety + /// Return series of `F` must has the same dtype and number of elements as input. + #[must_use] + pub unsafe fn apply_amortized_same_type(&self, mut f: F) -> Self + where + F: FnMut(AmortSeries) -> Series, + { + if self.is_empty() { + return self.clone(); + } + self.amortized_iter() + .map(|opt_v| { + opt_v.map(|v| { + let out = f(v); + to_arr(&out) + }) + }) + .collect_ca_with_dtype(self.name().clone(), self.dtype().clone()) + } + + /// Try apply a closure `F` to each array. + /// + /// # Safety + /// Return series of `F` must has the same dtype and number of elements as input if it is Ok. + pub unsafe fn try_apply_amortized_same_type(&self, mut f: F) -> PolarsResult + where + F: FnMut(AmortSeries) -> PolarsResult, + { + if self.is_empty() { + return Ok(self.clone()); + } + self.amortized_iter() + .map(|opt_v| { + opt_v + .map(|v| { + let out = f(v)?; + Ok(to_arr(&out)) + }) + .transpose() + }) + .try_collect_ca_with_dtype(self.name().clone(), self.dtype().clone()) + } + + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + /// + /// # Safety + // Return series of `F` must has the same dtype and number of elements as input series. + #[must_use] + pub unsafe fn zip_and_apply_amortized_same_type<'a, T, F>( + &'a self, + ca: &'a ChunkedArray, + mut f: F, + ) -> Self + where + T: PolarsDataType, + F: FnMut(Option, Option>) -> Option, + { + if self.is_empty() { + return self.clone(); + } + self.amortized_iter() + .zip(ca.iter()) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v); + out.map(|s| to_arr(&s)) + }) + .collect_ca_with_dtype(self.name().clone(), self.dtype().clone()) + } + + /// Apply a closure `F` elementwise. + #[must_use] + pub fn apply_amortized_generic(&self, f: F) -> ChunkedArray + where + V: PolarsDataType, + F: FnMut(Option) -> Option + Copy, + V::Array: ArrayFromIter>, + { + self.amortized_iter().map(f).collect_ca(self.name().clone()) + } + + /// Try apply a closure `F` elementwise. + pub fn try_apply_amortized_generic(&self, f: F) -> PolarsResult> + where + V: PolarsDataType, + F: FnMut(Option) -> PolarsResult> + Copy, + V::Array: ArrayFromIter>, + { + { + self.amortized_iter() + .map(f) + .try_collect_ca(self.name().clone()) + } + } + + pub fn for_each_amortized(&self, f: F) + where + F: FnMut(Option), + { + self.amortized_iter().for_each(f) + } +} + +fn to_arr(s: &Series) -> ArrayRef { + if s.chunks().len() > 1 { + let s = s.rechunk(); + s.chunks()[0].clone() + } else { + s.chunks()[0].clone() + } +} diff --git a/crates/polars-core/src/chunked_array/array/mod.rs b/crates/polars-core/src/chunked_array/array/mod.rs new file mode 100644 index 000000000000..cd5e23a4c59a --- /dev/null +++ b/crates/polars-core/src/chunked_array/array/mod.rs @@ -0,0 +1,185 @@ +//! Special fixed-size-list utility methods + +mod iterator; + +use std::borrow::Cow; + +use either::Either; + +use crate::prelude::*; + +impl ArrayChunked { + /// Get the inner data type of the fixed size list. + pub fn inner_dtype(&self) -> &DataType { + match self.dtype() { + DataType::Array(dt, _size) => dt.as_ref(), + _ => unreachable!(), + } + } + + pub fn width(&self) -> usize { + match self.dtype() { + DataType::Array(_dt, size) => *size, + _ => unreachable!(), + } + } + + /// # Safety + /// The caller must ensure that the logical type given fits the physical type of the array. + pub unsafe fn to_logical(&mut self, inner_dtype: DataType) { + debug_assert_eq!(&inner_dtype.to_physical(), self.inner_dtype()); + let width = self.width(); + let fld = Arc::make_mut(&mut self.field); + fld.coerce(DataType::Array(Box::new(inner_dtype), width)) + } + + /// Convert the datatype of the array into the physical datatype. + pub fn to_physical_repr(&self) -> Cow { + let Cow::Owned(physical_repr) = self.get_inner().to_physical_repr() else { + return Cow::Borrowed(self); + }; + + let chunk_len_validity_iter = + if physical_repr.chunks().len() == 1 && self.chunks().len() > 1 { + // Physical repr got rechunked, rechunk our validity as well. + Either::Left(std::iter::once((self.len(), self.rechunk_validity()))) + } else { + // No rechunking, expect the same number of chunks. + assert_eq!(self.chunks().len(), physical_repr.chunks().len()); + Either::Right( + self.chunks() + .iter() + .map(|c| (c.len(), c.validity().cloned())), + ) + }; + + let width = self.width(); + let chunks: Vec<_> = chunk_len_validity_iter + .zip(physical_repr.into_chunks()) + .map(|((len, validity), values)| { + FixedSizeListArray::new( + ArrowDataType::FixedSizeList( + Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + values.dtype().clone(), + true, + )), + width, + ), + len, + values, + validity, + ) + .to_boxed() + }) + .collect(); + + let name = self.name().clone(); + let dtype = DataType::Array(Box::new(self.inner_dtype().to_physical()), width); + Cow::Owned(unsafe { ArrayChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype) }) + } + + /// Convert a non-logical [`ArrayChunked`] back into a logical [`ArrayChunked`] without casting. + /// + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn from_physical_unchecked(&self, to_inner_dtype: DataType) -> PolarsResult { + debug_assert!(!self.inner_dtype().is_logical()); + + let chunks = self + .downcast_iter() + .map(|chunk| chunk.values()) + .cloned() + .collect(); + + let inner = unsafe { + Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, chunks, self.inner_dtype()) + }; + let inner = unsafe { inner.from_physical_unchecked(&to_inner_dtype) }?; + + let chunks: Vec<_> = self + .downcast_iter() + .zip(inner.into_chunks()) + .map(|(chunk, values)| { + FixedSizeListArray::new( + ArrowDataType::FixedSizeList( + Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + values.dtype().clone(), + true, + )), + self.width(), + ), + chunk.len(), + values, + chunk.validity().cloned(), + ) + .to_boxed() + }) + .collect(); + + let name = self.name().clone(); + let dtype = DataType::Array(Box::new(to_inner_dtype), self.width()); + Ok(unsafe { Self::from_chunks_and_dtype_unchecked(name, chunks, dtype) }) + } + + /// Get the inner values as `Series` + pub fn get_inner(&self) -> Series { + let chunks: Vec<_> = self.downcast_iter().map(|c| c.values().clone()).collect(); + + // SAFETY: Data type of arrays matches because they are chunks from the same array. + unsafe { + Series::from_chunks_and_dtype_unchecked(self.name().clone(), chunks, self.inner_dtype()) + } + } + + /// Ignore the list indices and apply `func` to the inner type as [`Series`]. + pub fn apply_to_inner( + &self, + func: &dyn Fn(Series) -> PolarsResult, + ) -> PolarsResult { + // Rechunk or the generated Series will have wrong length. + let ca = self.rechunk(); + let arr = ca.downcast_as_array(); + + // SAFETY: + // Inner dtype is passed correctly + let elements = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr.values().clone()], + ca.inner_dtype(), + ) + }; + + let expected_len = elements.len(); + let out: Series = func(elements)?; + polars_ensure!( + out.len() == expected_len, + ComputeError: "the function should apply element-wise, it removed elements instead" + ); + let out = out.rechunk(); + let values = out.chunks()[0].clone(); + + let inner_dtype = FixedSizeListArray::default_datatype(values.dtype().clone(), ca.width()); + let arr = FixedSizeListArray::new(inner_dtype, arr.len(), values, arr.validity().cloned()); + + Ok(unsafe { + ArrayChunked::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr.into_boxed()], + DataType::Array(Box::new(out.dtype().clone()), self.width()), + ) + }) + } + + /// Recurse nested types until we are at the leaf array. + pub fn get_leaf_array(&self) -> Series { + let mut current = self.get_inner(); + while let Some(child_array) = current.try_array() { + current = child_array.get_inner(); + } + current + } +} diff --git a/crates/polars-core/src/chunked_array/binary.rs b/crates/polars-core/src/chunked_array/binary.rs new file mode 100644 index 000000000000..43ace422562f --- /dev/null +++ b/crates/polars-core/src/chunked_array/binary.rs @@ -0,0 +1,78 @@ +use std::hash::BuildHasher; + +use polars_utils::aliases::PlRandomState; +use polars_utils::hashing::BytesHash; +use rayon::prelude::*; + +use crate::POOL; +use crate::prelude::*; +use crate::utils::{_set_partition_size, _split_offsets}; + +#[inline] +fn fill_bytes_hashes<'a, T>( + ca: &'a ChunkedArray, + null_h: u64, + hb: PlRandomState, +) -> Vec> +where + T: PolarsDataType, + <::Array as StaticArray>::ValueT<'a>: AsRef<[u8]>, +{ + let mut byte_hashes = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + for opt_b in arr.iter() { + let opt_b = opt_b.as_ref().map(|v| v.as_ref()); + // SAFETY: + // the underlying data is tied to self + let opt_b = unsafe { std::mem::transmute::, Option<&'a [u8]>>(opt_b) }; + let hash = match opt_b { + Some(s) => hb.hash_one(s), + None => null_h, + }; + byte_hashes.push(BytesHash::new(opt_b, hash)) + } + } + byte_hashes +} + +impl ChunkedArray +where + T: PolarsDataType, + for<'a> ::ValueT<'a>: AsRef<[u8]>, +{ + #[allow(clippy::needless_lifetimes)] + pub fn to_bytes_hashes<'a>( + &'a self, + mut multithreaded: bool, + hb: PlRandomState, + ) -> Vec>> { + multithreaded &= POOL.current_num_threads() > 1; + let null_h = hb.hash_one(0xde259df92c607d49_u64); + + if multithreaded { + let n_partitions = _set_partition_size(); + + let split = _split_offsets(self.len(), n_partitions); + + POOL.install(|| { + split + .into_par_iter() + .map(|(offset, len)| { + let ca = self.slice(offset as i64, len); + let byte_hashes = fill_bytes_hashes(&ca, null_h, hb); + + // SAFETY: + // the underlying data is tied to self + unsafe { + std::mem::transmute::>, Vec>>( + byte_hashes, + ) + } + }) + .collect::>() + }) + } else { + vec![fill_bytes_hashes(self, null_h, hb)] + } + } +} diff --git a/crates/polars-core/src/chunked_array/bitwise.rs b/crates/polars-core/src/chunked_array/bitwise.rs new file mode 100644 index 000000000000..ad179e0c1c28 --- /dev/null +++ b/crates/polars-core/src/chunked_array/bitwise.rs @@ -0,0 +1,205 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use arrow::compute; +use arrow::compute::bitwise; +use arrow::compute::utils::combine_validities_and; + +use super::*; +use crate::chunked_array::arity::apply_binary_kernel_broadcast; + +impl BitAnd for &ChunkedArray +where + T: PolarsIntegerType, + T::Native: BitAnd, +{ + type Output = ChunkedArray; + + fn bitand(self, rhs: Self) -> Self::Output { + apply_binary_kernel_broadcast( + self, + rhs, + bitwise::and, + |l, r| bitwise::and_scalar(r, &l), + |l, r| bitwise::and_scalar(l, &r), + ) + } +} + +impl BitOr for &ChunkedArray +where + T: PolarsIntegerType, + T::Native: BitOr, +{ + type Output = ChunkedArray; + + fn bitor(self, rhs: Self) -> Self::Output { + apply_binary_kernel_broadcast( + self, + rhs, + bitwise::or, + |l, r| bitwise::or_scalar(r, &l), + |l, r| bitwise::or_scalar(l, &r), + ) + } +} + +impl BitXor for &ChunkedArray +where + T: PolarsIntegerType, + T::Native: BitXor, +{ + type Output = ChunkedArray; + + fn bitxor(self, rhs: Self) -> Self::Output { + apply_binary_kernel_broadcast( + self, + rhs, + bitwise::xor, + |l, r| bitwise::xor_scalar(r, &l), + |l, r| bitwise::xor_scalar(l, &r), + ) + } +} + +impl BitOr for &BooleanChunked { + type Output = BooleanChunked; + + fn bitor(self, rhs: Self) -> Self::Output { + match (self.len(), rhs.len()) { + // make sure that we fall through if both are equal unit lengths + // otherwise we stackoverflow + (1, 1) => {}, + (1, _) => { + return match self.get(0) { + Some(true) => BooleanChunked::full(self.name().clone(), true, rhs.len()), + Some(false) => { + let mut rhs = rhs.clone(); + rhs.rename(self.name().clone()); + rhs + }, + None => &self.new_from_index(0, rhs.len()) | rhs, + }; + }, + (_, 1) => { + return match rhs.get(0) { + Some(true) => BooleanChunked::full(self.name().clone(), true, self.len()), + Some(false) => self.clone(), + None => self | &rhs.new_from_index(0, self.len()), + }; + }, + _ => {}, + } + + arity::binary(self, rhs, compute::boolean_kleene::or) + } +} + +impl BitOr for BooleanChunked { + type Output = BooleanChunked; + + fn bitor(self, rhs: Self) -> Self::Output { + (&self).bitor(&rhs) + } +} + +impl BitXor for &BooleanChunked { + type Output = BooleanChunked; + + fn bitxor(self, rhs: Self) -> Self::Output { + match (self.len(), rhs.len()) { + // make sure that we fall through if both are equal unit lengths + // otherwise we stackoverflow + (1, 1) => {}, + (1, _) => { + return match self.get(0) { + Some(true) => { + let mut rhs = rhs.not(); + rhs.rename(self.name().clone()); + rhs + }, + Some(false) => { + let mut rhs = rhs.clone(); + rhs.rename(self.name().clone()); + rhs + }, + None => &self.new_from_index(0, rhs.len()) | rhs, + }; + }, + (_, 1) => { + return match rhs.get(0) { + Some(true) => self.not(), + Some(false) => self.clone(), + None => self | &rhs.new_from_index(0, self.len()), + }; + }, + _ => {}, + } + + arity::binary(self, rhs, |l_arr, r_arr| { + let validity = combine_validities_and(l_arr.validity(), r_arr.validity()); + let values = l_arr.values() ^ r_arr.values(); + BooleanArray::from_data_default(values, validity) + }) + } +} + +impl BitXor for BooleanChunked { + type Output = BooleanChunked; + + fn bitxor(self, rhs: Self) -> Self::Output { + (&self).bitxor(&rhs) + } +} + +impl BitAnd for &BooleanChunked { + type Output = BooleanChunked; + + fn bitand(self, rhs: Self) -> Self::Output { + match (self.len(), rhs.len()) { + // make sure that we fall through if both are equal unit lengths + // otherwise we stackoverflow + (1, 1) => {}, + (1, _) => { + return match self.get(0) { + Some(true) => rhs.clone().with_name(self.name().clone()), + Some(false) => BooleanChunked::full(self.name().clone(), false, rhs.len()), + None => &self.new_from_index(0, rhs.len()) & rhs, + }; + }, + (_, 1) => { + return match rhs.get(0) { + Some(true) => self.clone(), + Some(false) => BooleanChunked::full(self.name().clone(), false, self.len()), + None => self & &rhs.new_from_index(0, self.len()), + }; + }, + _ => {}, + } + + arity::binary(self, rhs, compute::boolean_kleene::and) + } +} + +impl BitAnd for BooleanChunked { + type Output = BooleanChunked; + + fn bitand(self, rhs: Self) -> Self::Output { + (&self).bitand(&rhs) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn guard_so_issue_2494() { + // this cause a stack overflow + let a = BooleanChunked::new(PlSmallStr::from_static("a"), [None]); + let b = BooleanChunked::new(PlSmallStr::from_static("b"), [None]); + + assert_eq!((&a).bitand(&b).null_count(), 1); + assert_eq!((&a).bitor(&b).null_count(), 1); + assert_eq!((&a).bitxor(&b).null_count(), 1); + } +} diff --git a/crates/polars-core/src/chunked_array/builder/boolean.rs b/crates/polars-core/src/chunked_array/builder/boolean.rs new file mode 100644 index 000000000000..649db3d6252e --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/boolean.rs @@ -0,0 +1,39 @@ +use super::*; + +#[derive(Clone)] +pub struct BooleanChunkedBuilder { + pub(crate) array_builder: MutableBooleanArray, + pub(crate) field: Field, +} + +impl ChunkedBuilder for BooleanChunkedBuilder { + /// Appends a value of type `T` into the builder + #[inline] + fn append_value(&mut self, v: bool) { + self.array_builder.push_value(v); + } + + /// Appends a null slot into the builder + #[inline] + fn append_null(&mut self) { + self.array_builder.push_null(); + } + + fn finish(mut self) -> BooleanChunked { + let arr = self.array_builder.as_box(); + ChunkedArray::new_with_compute_len(Arc::new(self.field), vec![arr]) + } + + fn shrink_to_fit(&mut self) { + self.array_builder.shrink_to_fit() + } +} + +impl BooleanChunkedBuilder { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { + BooleanChunkedBuilder { + array_builder: MutableBooleanArray::with_capacity(capacity), + field: Field::new(name, DataType::Boolean), + } + } +} diff --git a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs new file mode 100644 index 000000000000..9bdc6fb98d1a --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs @@ -0,0 +1,158 @@ +use arrow::types::NativeType; +use polars_utils::pl_str::PlSmallStr; + +use crate::prelude::*; + +pub(crate) struct FixedSizeListNumericBuilder { + inner: Option>>, + width: usize, + name: PlSmallStr, + logical_dtype: DataType, +} + +impl FixedSizeListNumericBuilder { + /// # Safety + /// + /// The caller must ensure that the physical numerical type match logical type. + pub(crate) unsafe fn new( + name: PlSmallStr, + width: usize, + capacity: usize, + logical_dtype: DataType, + ) -> Self { + let mp = MutablePrimitiveArray::::with_capacity(capacity * width); + let inner = Some(MutableFixedSizeListArray::new(mp, width)); + Self { + inner, + width, + name, + logical_dtype, + } + } +} + +pub(crate) trait FixedSizeListBuilder { + unsafe fn push_unchecked(&mut self, arr: &dyn Array, offset: usize); + unsafe fn push_null(&mut self); + fn finish(&mut self) -> ArrayChunked; +} + +impl FixedSizeListBuilder for FixedSizeListNumericBuilder { + #[inline] + unsafe fn push_unchecked(&mut self, arr: &dyn Array, offset: usize) { + let start = offset * self.width; + let end = start + self.width; + let arr = arr + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + let inner = self.inner.as_mut().unwrap_unchecked(); + + let values = arr.values().as_slice(); + let validity = arr.validity(); + if let Some(validity) = validity { + let iter = (start..end).map(|i| { + if validity.get_bit_unchecked(i) { + Some(*values.get_unchecked(i)) + } else { + None + } + }); + inner.push_unchecked(Some(iter)) + } else { + let iter = (start..end).map(|i| Some(*values.get_unchecked(i))); + inner.push_unchecked(Some(iter)) + } + } + + #[inline] + unsafe fn push_null(&mut self) { + let inner = self.inner.as_mut().unwrap_unchecked(); + inner.push_null() + } + + fn finish(&mut self) -> ArrayChunked { + let arr: FixedSizeListArray = self.inner.take().unwrap().into(); + // SAFETY: physical type matches the logical + unsafe { + ChunkedArray::from_chunks_and_dtype( + self.name.clone(), + vec![Box::new(arr)], + DataType::Array(Box::new(self.logical_dtype.clone()), self.width), + ) + } + } +} + +pub(crate) struct AnonymousOwnedFixedSizeListBuilder { + inner: fixed_size_list::AnonymousBuilder, + name: PlSmallStr, + inner_dtype: Option, +} + +impl AnonymousOwnedFixedSizeListBuilder { + pub(crate) fn new( + name: PlSmallStr, + width: usize, + capacity: usize, + inner_dtype: Option, + ) -> Self { + let inner = fixed_size_list::AnonymousBuilder::new(capacity, width); + Self { + inner, + name, + inner_dtype, + } + } +} + +impl FixedSizeListBuilder for AnonymousOwnedFixedSizeListBuilder { + #[inline] + unsafe fn push_unchecked(&mut self, arr: &dyn Array, offset: usize) { + let arr = arr.sliced_unchecked(offset * self.inner.width, self.inner.width); + self.inner.push(arr) + } + + #[inline] + unsafe fn push_null(&mut self) { + self.inner.push_null() + } + + fn finish(&mut self) -> ArrayChunked { + let arr = std::mem::take(&mut self.inner) + .finish( + self.inner_dtype + .as_ref() + .map(|dt| dt.to_arrow(CompatLevel::newest())) + .as_ref(), + ) + .unwrap(); + ChunkedArray::with_chunk(self.name.clone(), arr) + } +} + +pub(crate) fn get_fixed_size_list_builder( + inner_type_logical: &DataType, + capacity: usize, + width: usize, + name: PlSmallStr, +) -> PolarsResult> { + let phys_dtype = inner_type_logical.to_physical(); + + let builder = if phys_dtype.is_primitive_numeric() { + with_match_physical_numeric_type!(phys_dtype, |$T| { + // SAFETY: physical type match logical type + unsafe { + Box::new(FixedSizeListNumericBuilder::<$T>::new(name, width, capacity,inner_type_logical.clone())) as Box + } + }) + } else { + Box::new(AnonymousOwnedFixedSizeListBuilder::new( + name, + width, + capacity, + Some(inner_type_logical.clone()), + )) + }; + Ok(builder) +} diff --git a/crates/polars-core/src/chunked_array/builder/list/anonymous.rs b/crates/polars-core/src/chunked_array/builder/list/anonymous.rs new file mode 100644 index 000000000000..eab5df634be2 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/anonymous.rs @@ -0,0 +1,179 @@ +use super::*; + +pub struct AnonymousListBuilder<'a> { + name: PlSmallStr, + builder: AnonymousBuilder<'a>, + fast_explode: bool, + inner_dtype: DtypeMerger, +} + +impl Default for AnonymousListBuilder<'_> { + fn default() -> Self { + Self::new(PlSmallStr::EMPTY, 0, None) + } +} + +impl<'a> AnonymousListBuilder<'a> { + pub fn new(name: PlSmallStr, capacity: usize, inner_dtype: Option) -> Self { + Self { + name, + builder: AnonymousBuilder::new(capacity), + fast_explode: true, + inner_dtype: DtypeMerger::new(inner_dtype), + } + } + + pub fn append_opt_series(&mut self, opt_s: Option<&'a Series>) -> PolarsResult<()> { + match opt_s { + Some(s) => return self.append_series(s), + None => { + self.append_null(); + }, + } + Ok(()) + } + + pub fn append_opt_array(&mut self, opt_s: Option<&'a dyn Array>) { + match opt_s { + Some(s) => self.append_array(s), + None => { + self.append_null(); + }, + } + } + + pub fn append_array(&mut self, arr: &'a dyn Array) { + self.builder.push(arr) + } + + #[inline] + pub fn append_null(&mut self) { + self.fast_explode = false; + self.builder.push_null(); + } + + #[inline] + pub fn append_empty(&mut self) { + self.fast_explode = false; + self.builder.push_empty() + } + + pub fn append_series(&mut self, s: &'a Series) -> PolarsResult<()> { + match s.dtype() { + // Empty arrays tend to be null type and thus differ + // if we would push it the concat would fail. + DataType::Null if s.is_empty() => self.append_empty(), + dt => self.inner_dtype.update(dt)?, + } + self.builder.push_multiple(s.chunks()); + Ok(()) + } + + pub fn finish(&mut self) -> ListChunked { + // Don't use self from here on out. + let slf = std::mem::take(self); + if slf.builder.is_empty() { + ListChunked::full_null_with_dtype( + slf.name.clone(), + 0, + &slf.inner_dtype.materialize().unwrap_or(DataType::Null), + ) + } else { + let inner_dtype = slf.inner_dtype.materialize(); + + let inner_dtype_physical = inner_dtype + .as_ref() + .map(|dt| dt.to_physical().to_arrow(CompatLevel::newest())); + let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap(); + + let list_dtype_logical = match inner_dtype { + None => DataType::from_arrow_dtype(arr.dtype()), + Some(dt) => DataType::List(Box::new(dt)), + }; + + let mut ca = ListChunked::with_chunk(PlSmallStr::EMPTY, arr); + if slf.fast_explode { + ca.set_fast_explode(); + } + ca.field = Arc::new(Field::new(slf.name.clone(), list_dtype_logical)); + ca + } + } +} + +pub struct AnonymousOwnedListBuilder { + name: PlSmallStr, + builder: AnonymousBuilder<'static>, + owned: Vec, + inner_dtype: DtypeMerger, + fast_explode: bool, +} + +impl Default for AnonymousOwnedListBuilder { + fn default() -> Self { + Self::new(PlSmallStr::EMPTY, 0, None) + } +} + +impl ListBuilderTrait for AnonymousOwnedListBuilder { + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + if s.is_empty() { + self.append_empty(); + } else { + unsafe { + self.inner_dtype.update(s.dtype())?; + self.builder + .push_multiple(&*(s.chunks().as_ref() as *const [ArrayRef])); + } + // This make sure that the underlying ArrayRef's are not dropped. + self.owned.push(s.clone()); + } + Ok(()) + } + + #[inline] + fn append_null(&mut self) { + self.fast_explode = false; + self.builder.push_null() + } + + fn finish(&mut self) -> ListChunked { + let inner_dtype = std::mem::take(&mut self.inner_dtype).materialize(); + // Don't use self from here on out. + let slf = std::mem::take(self); + let inner_dtype_physical = inner_dtype + .as_ref() + .map(|dt| dt.to_physical().to_arrow(CompatLevel::newest())); + let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap(); + + let list_dtype_logical = match inner_dtype { + None => DataType::from_arrow_dtype(arr.dtype()), + Some(dt) => DataType::List(Box::new(dt)), + }; + + let mut ca = ListChunked::with_chunk(PlSmallStr::EMPTY, arr); + if slf.fast_explode { + ca.set_fast_explode(); + } + ca.field = Arc::new(Field::new(slf.name.clone(), list_dtype_logical)); + ca + } +} + +impl AnonymousOwnedListBuilder { + pub fn new(name: PlSmallStr, capacity: usize, inner_dtype: Option) -> Self { + Self { + name, + builder: AnonymousBuilder::new(capacity), + owned: Vec::with_capacity(capacity), + inner_dtype: DtypeMerger::new(inner_dtype), + fast_explode: true, + } + } + + #[inline] + pub fn append_empty(&mut self) { + self.fast_explode = false; + self.builder.push_empty() + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/binary.rs b/crates/polars-core/src/chunked_array/builder/list/binary.rs new file mode 100644 index 000000000000..d55a69a2eace --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/binary.rs @@ -0,0 +1,177 @@ +use super::*; + +pub struct ListStringChunkedBuilder { + builder: LargeListBinViewBuilder, + field: Field, + fast_explode: bool, +} + +impl ListStringChunkedBuilder { + pub fn new(name: PlSmallStr, capacity: usize, values_capacity: usize) -> Self { + let values = MutableBinaryViewArray::with_capacity(values_capacity); + let builder = LargeListBinViewBuilder::new_with_capacity(values, capacity); + let field = Field::new(name, DataType::List(Box::new(DataType::String))); + + ListStringChunkedBuilder { + builder, + field, + fast_explode: true, + } + } + + #[inline] + pub fn append_trusted_len_iter<'a, I: Iterator> + TrustedLen>( + &mut self, + iter: I, + ) { + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + // SAFETY: + // trusted len, trust the type system + self.builder.mut_values().extend_trusted_len(iter); + self.builder.try_push_valid().unwrap(); + } + + #[inline] + pub fn append_values_iter<'a, I: Iterator>(&mut self, iter: I) { + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + self.builder.mut_values().extend_values(iter); + self.builder.try_push_valid().unwrap(); + } + + #[inline] + pub(crate) fn append(&mut self, ca: &StringChunked) { + if ca.is_empty() { + self.fast_explode = false; + } + for arr in ca.downcast_iter() { + if arr.null_count() == 0 { + self.builder + .mut_values() + .extend_values(arr.non_null_values_iter()); + } else { + self.builder.mut_values().extend_trusted_len(arr.iter()) + } + } + self.builder.try_push_valid().unwrap(); + } +} + +impl ListBuilderTrait for ListStringChunkedBuilder { + #[inline] + fn append_null(&mut self) { + self.fast_explode = false; + self.builder.push_null(); + } + + #[inline] + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + if s.is_empty() { + self.fast_explode = false; + } + let ca = s.str()?; + self.append(ca); + Ok(()) + } + + fn field(&self) -> &Field { + &self.field + } + + fn inner_array(&mut self) -> ArrayRef { + self.builder.as_box() + } + + fn fast_explode(&self) -> bool { + self.fast_explode + } +} + +pub struct ListBinaryChunkedBuilder { + builder: LargeListBinViewBuilder<[u8]>, + field: Field, + fast_explode: bool, +} + +impl ListBinaryChunkedBuilder { + pub fn new(name: PlSmallStr, capacity: usize, values_capacity: usize) -> Self { + let values = MutablePlBinary::with_capacity(values_capacity); + let builder = LargeListBinViewBuilder::new_with_capacity(values, capacity); + let field = Field::new(name, DataType::List(Box::new(DataType::Binary))); + + ListBinaryChunkedBuilder { + builder, + field, + fast_explode: true, + } + } + + pub fn append_trusted_len_iter<'a, I: Iterator> + TrustedLen>( + &mut self, + iter: I, + ) { + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + // SAFETY: + // trusted len, trust the type system + self.builder.mut_values().extend_trusted_len(iter); + self.builder.try_push_valid().unwrap(); + } + + pub fn append_values_iter<'a, I: Iterator>(&mut self, iter: I) { + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + self.builder.mut_values().extend_values(iter); + self.builder.try_push_valid().unwrap(); + } + + pub(crate) fn append(&mut self, ca: &BinaryChunked) { + if ca.is_empty() { + self.fast_explode = false; + } + for arr in ca.downcast_iter() { + if arr.null_count() == 0 { + self.builder + .mut_values() + .extend_values(arr.non_null_values_iter()); + } else { + self.builder.mut_values().extend_trusted_len(arr.iter()) + } + } + self.builder.try_push_valid().unwrap(); + } +} + +impl ListBuilderTrait for ListBinaryChunkedBuilder { + #[inline] + fn append_null(&mut self) { + self.fast_explode = false; + self.builder.push_null(); + } + + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + if s.is_empty() { + self.fast_explode = false; + } + let ca = s.binary()?; + self.append(ca); + Ok(()) + } + + fn field(&self) -> &Field { + &self.field + } + + fn inner_array(&mut self) -> ArrayRef { + self.builder.as_box() + } + + fn fast_explode(&self) -> bool { + self.fast_explode + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/boolean.rs b/crates/polars-core/src/chunked_array/builder/list/boolean.rs new file mode 100644 index 000000000000..8142d1a50954 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/boolean.rs @@ -0,0 +1,71 @@ +use super::*; + +pub struct ListBooleanChunkedBuilder { + builder: LargeListBooleanBuilder, + field: Field, + fast_explode: bool, +} + +impl ListBooleanChunkedBuilder { + pub fn new(name: PlSmallStr, capacity: usize, values_capacity: usize) -> Self { + let values = MutableBooleanArray::with_capacity(values_capacity); + let builder = LargeListBooleanBuilder::new_with_capacity(values, capacity); + let field = Field::new(name, DataType::List(Box::new(DataType::Boolean))); + + Self { + builder, + field, + fast_explode: true, + } + } + + #[inline] + pub fn append_iter> + TrustedLen>(&mut self, iter: I) { + let values = self.builder.mut_values(); + + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + // SAFETY: + // trusted len, trust the type system + unsafe { values.extend_trusted_len_unchecked(iter) }; + self.builder.try_push_valid().unwrap(); + } + + #[inline] + pub(crate) fn append(&mut self, ca: &BooleanChunked) { + if ca.is_empty() { + self.fast_explode = false; + } + let value_builder = self.builder.mut_values(); + value_builder.extend(ca); + self.builder.try_push_valid().unwrap(); + } +} + +impl ListBuilderTrait for ListBooleanChunkedBuilder { + #[inline] + fn append_null(&mut self) { + self.fast_explode = false; + self.builder.push_null(); + } + + #[inline] + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + let ca = s.bool()?; + self.append(ca); + Ok(()) + } + + fn field(&self) -> &Field { + &self.field + } + + fn inner_array(&mut self) -> ArrayRef { + self.builder.as_box() + } + + fn fast_explode(&self) -> bool { + self.fast_explode + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/categorical.rs b/crates/polars-core/src/chunked_array/builder/list/categorical.rs new file mode 100644 index 000000000000..1175b669bf06 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/categorical.rs @@ -0,0 +1,235 @@ +use std::hash::BuildHasher; + +use hashbrown::HashTable; +use hashbrown::hash_table::Entry; + +use super::*; + +pub fn create_categorical_chunked_listbuilder( + name: PlSmallStr, + ordering: CategoricalOrdering, + capacity: usize, + values_capacity: usize, + rev_map: Arc, +) -> Box { + match &*rev_map { + RevMapping::Local(_, h) => Box::new(ListLocalCategoricalChunkedBuilder::new( + name, + ordering, + capacity, + values_capacity, + *h, + )), + RevMapping::Global(_, _, _) => Box::new(ListGlobalCategoricalChunkedBuilder::new( + name, + ordering, + capacity, + values_capacity, + rev_map, + )), + } +} + +pub struct ListEnumCategoricalChunkedBuilder { + inner: ListPrimitiveChunkedBuilder, + ordering: CategoricalOrdering, + rev_map: RevMapping, +} + +impl ListEnumCategoricalChunkedBuilder { + pub(super) fn new( + name: PlSmallStr, + ordering: CategoricalOrdering, + capacity: usize, + values_capacity: usize, + rev_map: RevMapping, + ) -> Self { + Self { + inner: ListPrimitiveChunkedBuilder::new( + name, + capacity, + values_capacity, + DataType::UInt32, + ), + ordering, + rev_map, + } + } +} + +impl ListBuilderTrait for ListEnumCategoricalChunkedBuilder { + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + let DataType::Enum(Some(rev_map), _) = s.dtype() else { + polars_bail!(ComputeError: "expected enum type") + }; + polars_ensure!(rev_map.same_src(&self.rev_map),ComputeError: "incompatible enum types"); + self.inner.append_series(s) + } + + fn append_null(&mut self) { + self.inner.append_null() + } + + fn finish(&mut self) -> ListChunked { + let inner_dtype = DataType::Enum(Some(Arc::new(self.rev_map.clone())), self.ordering); + let mut ca = self.inner.finish(); + unsafe { ca.set_dtype(DataType::List(Box::new(inner_dtype))) } + ca + } +} + +struct ListLocalCategoricalChunkedBuilder { + inner: ListPrimitiveChunkedBuilder, + idx_lookup: HashTable, + ordering: CategoricalOrdering, + categories: MutablePlString, + categories_hash: u128, +} + +impl ListLocalCategoricalChunkedBuilder { + #[inline] + pub fn get_hash_builder() -> PlFixedStateQuality { + PlFixedStateQuality::with_seed(0) + } + + pub(super) fn new( + name: PlSmallStr, + ordering: CategoricalOrdering, + capacity: usize, + values_capacity: usize, + hash: u128, + ) -> Self { + Self { + inner: ListPrimitiveChunkedBuilder::new( + name, + capacity, + values_capacity, + DataType::UInt32, + ), + idx_lookup: HashTable::with_capacity(capacity), + ordering, + categories: MutablePlString::with_capacity(capacity), + categories_hash: hash, + } + } +} + +impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder { + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + let DataType::Categorical(Some(rev_map), _) = s.dtype() else { + polars_bail!(ComputeError: "expected categorical type") + }; + let RevMapping::Local(cats_right, new_hash) = &**rev_map else { + polars_bail!(string_cache_mismatch) + }; + let ca = s.categorical().unwrap(); + + // Fast path rev_maps are compatible & lookup is initialized + if self.categories_hash == *new_hash && !self.idx_lookup.is_empty() { + return self.inner.append_series(s); + } + + let hash_builder = ListLocalCategoricalChunkedBuilder::get_hash_builder(); + + // Map the physical of the appended series to be compatible with the existing rev map + let mut idx_mapping = PlHashMap::with_capacity(ca.len()); + + for (idx, cat) in cats_right.values_iter().enumerate() { + let hash_cat = hash_builder.hash_one(cat); + let len = self.idx_lookup.len(); + + // Custom hashing / equality functions for comparing the &str to the idx + // SAFETY: index in hashmap are within bounds of categories + unsafe { + let r = self.idx_lookup.entry( + hash_cat, + |k| self.categories.value_unchecked(*k as usize) == cat, + |k| hash_builder.hash_one(self.categories.value_unchecked(*k as usize)), + ); + + match r { + Entry::Occupied(v) => { + // SAFETY: bucket is initialized. + idx_mapping.insert_unique_unchecked(idx as u32, *v.get()); + }, + Entry::Vacant(slot) => { + idx_mapping.insert_unique_unchecked(idx as u32, len as u32); + self.categories.push(Some(cat)); + slot.insert(len as u32); + }, + } + } + } + + let op = |opt_v: Option<&u32>| opt_v.map(|v| *idx_mapping.get(v).unwrap()); + // SAFETY: length is correct as we do one-one mapping over ca. + let iter = unsafe { + ca.physical() + .downcast_iter() + .flat_map(|arr| arr.iter().map(op)) + .trust_my_length(ca.len()) + }; + self.inner.append_iter(iter); + + Ok(()) + } + + fn append_null(&mut self) { + self.inner.append_null() + } + + fn finish(&mut self) -> ListChunked { + let categories: Utf8ViewArray = std::mem::take(&mut self.categories).into(); + let rev_map = RevMapping::build_local(categories); + let inner_dtype = DataType::Categorical(Some(Arc::new(rev_map)), self.ordering); + let mut ca = self.inner.finish(); + unsafe { ca.set_dtype(DataType::List(Box::new(inner_dtype))) } + ca + } +} + +struct ListGlobalCategoricalChunkedBuilder { + inner: ListPrimitiveChunkedBuilder, + ordering: CategoricalOrdering, + map_merger: GlobalRevMapMerger, +} + +impl ListGlobalCategoricalChunkedBuilder { + pub(super) fn new( + name: PlSmallStr, + ordering: CategoricalOrdering, + capacity: usize, + values_capacity: usize, + rev_map: Arc, + ) -> Self { + let inner = + ListPrimitiveChunkedBuilder::new(name, capacity, values_capacity, DataType::UInt32); + Self { + inner, + ordering, + map_merger: GlobalRevMapMerger::new(rev_map), + } + } +} + +impl ListBuilderTrait for ListGlobalCategoricalChunkedBuilder { + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + let DataType::Categorical(Some(rev_map), _) = s.dtype() else { + polars_bail!(ComputeError: "expected categorical type") + }; + self.map_merger.merge_map(rev_map)?; + self.inner.append_series(s) + } + + fn append_null(&mut self) { + self.inner.append_null() + } + + fn finish(&mut self) -> ListChunked { + let rev_map = std::mem::take(&mut self.map_merger).finish(); + let inner_dtype = DataType::Categorical(Some(rev_map), self.ordering); + let mut ca = self.inner.finish(); + unsafe { ca.set_dtype(DataType::List(Box::new(inner_dtype))) } + ca + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/dtypes.rs b/crates/polars-core/src/chunked_array/builder/list/dtypes.rs new file mode 100644 index 000000000000..9808ae0841ca --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/dtypes.rs @@ -0,0 +1,56 @@ +use super::*; + +// Allow large enum as this shouldn't be moved much +#[allow(clippy::large_enum_variant)] +pub(super) enum DtypeMerger { + #[cfg(feature = "dtype-categorical")] + Categorical(GlobalRevMapMerger, CategoricalOrdering), + Other(Option), +} + +impl Default for DtypeMerger { + fn default() -> Self { + DtypeMerger::Other(None) + } +} + +impl DtypeMerger { + pub(super) fn new(dtype: Option) -> Self { + match dtype { + #[cfg(feature = "dtype-categorical")] + Some(DataType::Categorical(Some(rev_map), ordering)) if rev_map.is_global() => { + DtypeMerger::Categorical(GlobalRevMapMerger::new(rev_map), ordering) + }, + _ => DtypeMerger::Other(dtype), + } + } + + #[inline] + pub(super) fn update(&mut self, dtype: &DataType) -> PolarsResult<()> { + match self { + #[cfg(feature = "dtype-categorical")] + DtypeMerger::Categorical(merger, _) => { + let DataType::Categorical(Some(rev_map), _) = dtype else { + polars_bail!(ComputeError: "expected categorical rev-map") + }; + polars_ensure!(rev_map.is_global(), string_cache_mismatch); + return merger.merge_map(rev_map); + }, + DtypeMerger::Other(Some(set_dtype)) => { + polars_ensure!(set_dtype == dtype, ComputeError: "dtypes don't match, got {}, expected: {}", dtype, set_dtype) + }, + _ => {}, + } + Ok(()) + } + + pub(super) fn materialize(self) -> Option { + match self { + #[cfg(feature = "dtype-categorical")] + DtypeMerger::Categorical(merger, ordering) => { + Some(DataType::Categorical(Some(merger.finish()), ordering)) + }, + DtypeMerger::Other(dtype) => dtype, + } + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/mod.rs b/crates/polars-core/src/chunked_array/builder/list/mod.rs new file mode 100644 index 000000000000..c5eb91834a53 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/mod.rs @@ -0,0 +1,198 @@ +mod anonymous; +mod binary; +mod boolean; +#[cfg(feature = "dtype-categorical")] +mod categorical; +mod dtypes; +mod null; +mod primitive; + +pub use anonymous::*; +use arrow::legacy::array::list::AnonymousBuilder; +use arrow::legacy::array::null::MutableNullArray; +pub use binary::*; +pub use boolean::*; +#[cfg(feature = "dtype-categorical")] +use categorical::*; +use dtypes::*; +pub use null::*; +pub use primitive::*; + +use super::*; +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_builder; + +pub trait ListBuilderTrait { + fn append_opt_series(&mut self, opt_s: Option<&Series>) -> PolarsResult<()> { + match opt_s { + Some(s) => return self.append_series(s), + None => self.append_null(), + } + Ok(()) + } + fn append_series(&mut self, s: &Series) -> PolarsResult<()>; + fn append_null(&mut self); + + fn field(&self) -> &Field { + unimplemented!() + } + + fn inner_array(&mut self) -> ArrayRef { + unimplemented!() + } + + fn fast_explode(&self) -> bool { + unimplemented!() + } + + fn finish(&mut self) -> ListChunked { + let arr = self.inner_array(); + + let mut ca = ListChunked::new_with_compute_len(Arc::new(self.field().clone()), vec![arr]); + if self.fast_explode() { + ca.set_fast_explode() + } + ca + } +} + +impl ListBuilderTrait for Box +where + S: ListBuilderTrait, +{ + fn append_opt_series(&mut self, opt_s: Option<&Series>) -> PolarsResult<()> { + (**self).append_opt_series(opt_s) + } + + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + (**self).append_series(s) + } + + fn append_null(&mut self) { + (**self).append_null() + } + + fn finish(&mut self) -> ListChunked { + (**self).finish() + } +} + +type LargePrimitiveBuilder = MutableListArray>; +type LargeListBinViewBuilder = MutableListArray>; +type LargeListBooleanBuilder = MutableListArray; +type LargeListNullBuilder = MutableListArray; + +pub fn get_list_builder( + inner_type_logical: &DataType, + value_capacity: usize, + list_capacity: usize, + name: PlSmallStr, +) -> Box { + match inner_type_logical { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(Some(rev_map), ordering) => { + return create_categorical_chunked_listbuilder( + name, + *ordering, + list_capacity, + value_capacity, + rev_map.clone(), + ); + }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(Some(rev_map), ordering) => { + let list_builder = ListEnumCategoricalChunkedBuilder::new( + name, + *ordering, + list_capacity, + value_capacity, + (**rev_map).clone(), + ); + return Box::new(list_builder); + }, + _ => {}, + } + + let physical_type = inner_type_logical.to_physical(); + + match &physical_type { + #[cfg(feature = "object")] + DataType::Object(_) => { + let builder = get_object_builder(PlSmallStr::EMPTY, 0).get_list_builder( + name, + value_capacity, + list_capacity, + ); + Box::new(builder) + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => Box::new(AnonymousOwnedListBuilder::new( + name, + list_capacity, + Some(inner_type_logical.clone()), + )), + DataType::Null => Box::new(ListNullChunkedBuilder::new(name, list_capacity)), + DataType::List(_) => Box::new(AnonymousOwnedListBuilder::new( + name, + list_capacity, + Some(inner_type_logical.clone()), + )), + #[cfg(feature = "dtype-array")] + DataType::Array(..) => Box::new(AnonymousOwnedListBuilder::new( + name, + list_capacity, + Some(inner_type_logical.clone()), + )), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => Box::new( + ListPrimitiveChunkedBuilder::::new_with_values_type( + name, + list_capacity, + value_capacity, + physical_type, + inner_type_logical.clone(), + ), + ), + _ => { + macro_rules! get_primitive_builder { + ($type:ty) => {{ + let builder = ListPrimitiveChunkedBuilder::<$type>::new( + name, + list_capacity, + value_capacity, + inner_type_logical.clone(), + ); + Box::new(builder) + }}; + } + macro_rules! get_bool_builder { + () => {{ + let builder = + ListBooleanChunkedBuilder::new(name, list_capacity, value_capacity); + Box::new(builder) + }}; + } + macro_rules! get_string_builder { + () => {{ + let builder = + ListStringChunkedBuilder::new(name, list_capacity, 5 * value_capacity); + Box::new(builder) + }}; + } + macro_rules! get_binary_builder { + () => {{ + let builder = + ListBinaryChunkedBuilder::new(name, list_capacity, 5 * value_capacity); + Box::new(builder) + }}; + } + match_dtype_to_logical_apply_macro!( + physical_type, + get_primitive_builder, + get_string_builder, + get_binary_builder, + get_bool_builder + ) + }, + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/null.rs b/crates/polars-core/src/chunked_array/builder/list/null.rs new file mode 100644 index 000000000000..233f53e17412 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/null.rs @@ -0,0 +1,50 @@ +use super::*; + +pub struct ListNullChunkedBuilder { + builder: LargeListNullBuilder, + name: PlSmallStr, +} + +impl ListNullChunkedBuilder { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { + ListNullChunkedBuilder { + builder: LargeListNullBuilder::with_capacity(capacity), + name, + } + } + + pub(crate) fn append(&mut self, s: &Series) { + let value_builder = self.builder.mut_values(); + value_builder.extend_nulls(s.len()); + self.builder.try_push_valid().unwrap(); + } + + pub(crate) fn append_with_len(&mut self, len: usize) { + let value_builder = self.builder.mut_values(); + value_builder.extend_nulls(len); + self.builder.try_push_valid().unwrap(); + } +} + +impl ListBuilderTrait for ListNullChunkedBuilder { + #[inline] + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + self.append(s); + Ok(()) + } + + #[inline] + fn append_null(&mut self) { + self.builder.push_null(); + } + + fn finish(&mut self) -> ListChunked { + unsafe { + ListChunked::from_chunks_and_dtype_unchecked( + self.name.clone(), + vec![self.builder.as_box()], + DataType::List(Box::new(DataType::Null)), + ) + } + } +} diff --git a/crates/polars-core/src/chunked_array/builder/list/primitive.rs b/crates/polars-core/src/chunked_array/builder/list/primitive.rs new file mode 100644 index 000000000000..ba4cbcc1ea19 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/list/primitive.rs @@ -0,0 +1,169 @@ +use super::*; + +pub struct ListPrimitiveChunkedBuilder +where + T: PolarsNumericType, +{ + pub builder: LargePrimitiveBuilder, + field: Field, + fast_explode: bool, +} + +impl ListPrimitiveChunkedBuilder +where + T: PolarsNumericType, +{ + pub fn new( + name: PlSmallStr, + capacity: usize, + values_capacity: usize, + inner_type: DataType, + ) -> Self { + debug_assert!( + inner_type.to_physical().is_primitive_numeric(), + "inner type must be primitive, got {}", + inner_type + ); + let values = MutablePrimitiveArray::::with_capacity(values_capacity); + let builder = LargePrimitiveBuilder::::new_with_capacity(values, capacity); + let field = Field::new(name, DataType::List(Box::new(inner_type))); + + Self { + builder, + field, + fast_explode: true, + } + } + + pub fn new_with_values_type( + name: PlSmallStr, + capacity: usize, + values_capacity: usize, + values_type: DataType, + logical_type: DataType, + ) -> Self { + let values = MutablePrimitiveArray::::with_capacity_from( + values_capacity, + values_type.to_arrow(CompatLevel::newest()), + ); + let builder = LargePrimitiveBuilder::::new_with_capacity(values, capacity); + let field = Field::new(name, DataType::List(Box::new(logical_type))); + Self { + builder, + field, + fast_explode: true, + } + } + + #[inline] + pub fn append_slice(&mut self, items: &[T::Native]) { + let values = self.builder.mut_values(); + values.extend_from_slice(items); + self.builder.try_push_valid().unwrap(); + + if items.is_empty() { + self.fast_explode = false; + } + } + + #[inline] + pub fn append_opt_slice(&mut self, opt_v: Option<&[T::Native]>) { + match opt_v { + Some(items) => self.append_slice(items), + None => { + self.builder.push_null(); + }, + } + } + /// Appends from an iterator over values + #[inline] + pub fn append_values_iter_trusted_len + TrustedLen>( + &mut self, + iter: I, + ) { + let values = self.builder.mut_values(); + + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + // SAFETY: + // trusted len, trust the type system + values.extend_values(iter); + self.builder.try_push_valid().unwrap(); + } + + #[inline] + pub fn append_values_iter>(&mut self, iter: I) { + let values = self.builder.mut_values(); + + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + values.extend_values(iter); + self.builder.try_push_valid().unwrap(); + } + + /// Appends from an iterator over values + #[inline] + pub fn append_iter> + TrustedLen>(&mut self, iter: I) { + let values = self.builder.mut_values(); + + if iter.size_hint().0 == 0 { + self.fast_explode = false; + } + // SAFETY: + // trusted len, trust the type system + unsafe { values.extend_trusted_len_unchecked(iter) }; + self.builder.try_push_valid().unwrap(); + } +} + +impl ListBuilderTrait for ListPrimitiveChunkedBuilder +where + T: PolarsNumericType, +{ + #[inline] + fn append_null(&mut self) { + self.fast_explode = false; + self.builder.push_null(); + } + + #[inline] + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + if s.is_empty() { + self.fast_explode = false; + } + let physical = s.to_physical_repr(); + let ca = physical.unpack::().map_err(|_| { + polars_err!(SchemaMismatch: "cannot build list with different dtypes + +Expected {}, got {}.", self.field.dtype(), s.dtype()) + })?; + let values = self.builder.mut_values(); + + ca.downcast_iter().for_each(|arr| { + if arr.null_count() == 0 { + values.extend_from_slice(arr.values().as_slice()) + } else { + // SAFETY: + // Arrow arrays are trusted length iterators. + unsafe { values.extend_trusted_len_unchecked(arr.into_iter()) } + } + }); + // overflow of i64 is far beyond polars capable lengths. + unsafe { self.builder.try_push_valid().unwrap_unchecked() }; + Ok(()) + } + + fn field(&self) -> &Field { + &self.field + } + + fn inner_array(&mut self) -> ArrayRef { + self.builder.as_box() + } + + fn fast_explode(&self) -> bool { + self.fast_explode + } +} diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs new file mode 100644 index 000000000000..05742cc0dba2 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -0,0 +1,241 @@ +mod boolean; +#[cfg(feature = "dtype-array")] +pub mod fixed_size_list; +pub mod list; +mod null; +mod primitive; +mod string; + +use std::sync::Arc; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +pub use boolean::*; +#[cfg(feature = "dtype-array")] +pub(crate) use fixed_size_list::*; +pub use list::*; +pub use null::*; +pub use primitive::*; +pub use string::*; + +use crate::chunked_array::to_primitive; +use crate::prelude::*; +use crate::utils::{NoNull, get_iter_capacity}; + +// N: the value type; T: the sentinel type +pub trait ChunkedBuilder { + fn append_value(&mut self, val: N); + fn append_null(&mut self); + fn append_option(&mut self, opt_val: Option) { + match opt_val { + Some(v) => self.append_value(v), + None => self.append_null(), + } + } + fn finish(self) -> ChunkedArray; + + fn shrink_to_fit(&mut self); +} + +// Used in polars/src/chunked_array/apply.rs:24 to collect from aligned vecs and null bitmaps +impl FromIterator<(Vec, Option)> for ChunkedArray +where + T: PolarsNumericType, +{ + fn from_iter, Option)>>(iter: I) -> Self { + let chunks = iter + .into_iter() + .map(|(values, opt_buffer)| to_primitive::(values, opt_buffer)); + ChunkedArray::from_chunk_iter(PlSmallStr::EMPTY, chunks) + } +} + +pub trait NewChunkedArray { + fn from_slice(name: PlSmallStr, v: &[N]) -> Self; + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self; + + /// Create a new ChunkedArray from an iterator. + fn from_iter_options(name: PlSmallStr, it: impl Iterator>) -> Self; + + /// Create a new ChunkedArray from an iterator. + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> Self; +} + +impl NewChunkedArray for ChunkedArray +where + T: PolarsNumericType, +{ + fn from_slice(name: PlSmallStr, v: &[T::Native]) -> Self { + let arr = PrimitiveArray::from_slice(v).to(T::get_dtype().to_arrow(CompatLevel::newest())); + ChunkedArray::with_chunk(name, arr) + } + + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { + Self::from_iter_options(name, opt_v.iter().copied()) + } + + fn from_iter_options( + name: PlSmallStr, + it: impl Iterator>, + ) -> ChunkedArray { + let mut builder = PrimitiveChunkedBuilder::new(name, get_iter_capacity(&it)); + it.for_each(|opt| builder.append_option(opt)); + builder.finish() + } + + /// Create a new ChunkedArray from an iterator. + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> ChunkedArray { + let ca: NoNull> = it.collect(); + let mut ca = ca.into_inner(); + ca.rename(name); + ca + } +} + +impl NewChunkedArray for BooleanChunked { + fn from_slice(name: PlSmallStr, v: &[bool]) -> Self { + Self::from_iter_values(name, v.iter().copied()) + } + + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { + Self::from_iter_options(name, opt_v.iter().copied()) + } + + fn from_iter_options( + name: PlSmallStr, + it: impl Iterator>, + ) -> ChunkedArray { + let mut builder = BooleanChunkedBuilder::new(name, get_iter_capacity(&it)); + it.for_each(|opt| builder.append_option(opt)); + builder.finish() + } + + /// Create a new ChunkedArray from an iterator. + fn from_iter_values( + name: PlSmallStr, + it: impl Iterator, + ) -> ChunkedArray { + let mut ca: ChunkedArray<_> = it.collect(); + ca.rename(name); + ca + } +} + +impl NewChunkedArray for StringChunked +where + S: AsRef, +{ + fn from_slice(name: PlSmallStr, v: &[S]) -> Self { + let arr = Utf8ViewArray::from_slice_values(v); + ChunkedArray::with_chunk(name, arr) + } + + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { + let arr = Utf8ViewArray::from_slice(opt_v); + ChunkedArray::with_chunk(name, arr) + } + + fn from_iter_options(name: PlSmallStr, it: impl Iterator>) -> Self { + let arr = MutableBinaryViewArray::from_iterator(it).freeze(); + ChunkedArray::with_chunk(name, arr) + } + + /// Create a new ChunkedArray from an iterator. + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> Self { + let arr = MutableBinaryViewArray::from_values_iter(it).freeze(); + ChunkedArray::with_chunk(name, arr) + } +} + +impl NewChunkedArray for BinaryChunked +where + B: AsRef<[u8]>, +{ + fn from_slice(name: PlSmallStr, v: &[B]) -> Self { + let arr = BinaryViewArray::from_slice_values(v); + ChunkedArray::with_chunk(name, arr) + } + + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { + let arr = BinaryViewArray::from_slice(opt_v); + ChunkedArray::with_chunk(name, arr) + } + + fn from_iter_options(name: PlSmallStr, it: impl Iterator>) -> Self { + let arr = MutableBinaryViewArray::from_iterator(it).freeze(); + ChunkedArray::with_chunk(name, arr) + } + + /// Create a new ChunkedArray from an iterator. + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> Self { + let arr = MutableBinaryViewArray::from_values_iter(it).freeze(); + ChunkedArray::with_chunk(name, arr) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_primitive_builder() { + let mut builder = + PrimitiveChunkedBuilder::::new(PlSmallStr::from_static("foo"), 6); + let values = &[Some(1), None, Some(2), Some(3), None, Some(4)]; + for val in values { + builder.append_option(*val); + } + let ca = builder.finish(); + assert_eq!(Vec::from(&ca), values); + } + + #[test] + fn test_list_builder() { + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 5, + DataType::Int32, + ); + + // Create a series containing two chunks. + let mut s1 = + Int32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]).into_series(); + let s2 = Int32Chunked::from_slice(PlSmallStr::from_static("b"), &[4, 5, 6]).into_series(); + s1.append(&s2).unwrap(); + + builder.append_series(&s1).unwrap(); + builder.append_series(&s2).unwrap(); + let ls = builder.finish(); + if let AnyValue::List(s) = ls.get_any_value(0).unwrap() { + // many chunks are aggregated to one in the ListArray + assert_eq!(s.len(), 6) + } else { + panic!() + } + if let AnyValue::List(s) = ls.get_any_value(1).unwrap() { + assert_eq!(s.len(), 3) + } else { + panic!() + } + + // Test list collect. + let out = [&s1, &s2].iter().copied().collect::(); + assert_eq!(out.get_as_series(0).unwrap().len(), 6); + assert_eq!(out.get_as_series(1).unwrap().len(), 3); + + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 5, + DataType::Int32, + ); + builder.append_series(&s1).unwrap(); + builder.append_null(); + + let out = builder.finish(); + let out = out.explode().unwrap(); + assert_eq!(out.len(), 7); + assert_eq!(out.get(6).unwrap(), AnyValue::Null); + } +} diff --git a/crates/polars-core/src/chunked_array/builder/null.rs b/crates/polars-core/src/chunked_array/builder/null.rs new file mode 100644 index 000000000000..2c7eb4317bcc --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/null.rs @@ -0,0 +1,36 @@ +use arrow::legacy::array::null::MutableNullArray; + +use super::*; +use crate::series::implementations::null::NullChunked; + +#[derive(Clone)] +pub struct NullChunkedBuilder { + array_builder: MutableNullArray, + pub(crate) field: Field, +} + +impl NullChunkedBuilder { + pub fn new(name: PlSmallStr, len: usize) -> Self { + let array_builder = MutableNullArray::new(len); + + NullChunkedBuilder { + array_builder, + field: Field::new(name, DataType::Null), + } + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.array_builder.push_null() + } + + pub fn finish(mut self) -> NullChunked { + let arr = self.array_builder.as_box(); + NullChunked::new(self.field.name().clone(), arr.len()) + } + + pub fn shrink_to_fit(&mut self) { + self.array_builder.shrink_to_fit() + } +} diff --git a/crates/polars-core/src/chunked_array/builder/primitive.rs b/crates/polars-core/src/chunked_array/builder/primitive.rs new file mode 100644 index 000000000000..f310d4145a19 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/primitive.rs @@ -0,0 +1,51 @@ +use super::*; + +#[derive(Clone)] +pub struct PrimitiveChunkedBuilder +where + T: PolarsNumericType, +{ + array_builder: MutablePrimitiveArray, + pub(crate) field: Field, +} + +impl ChunkedBuilder for PrimitiveChunkedBuilder +where + T: PolarsNumericType, +{ + /// Appends a value of type `T` into the builder + #[inline] + fn append_value(&mut self, v: T::Native) { + self.array_builder.push(Some(v)) + } + + /// Appends a null slot into the builder + #[inline] + fn append_null(&mut self) { + self.array_builder.push(None) + } + + fn finish(mut self) -> ChunkedArray { + let arr = self.array_builder.as_box(); + ChunkedArray::new_with_compute_len(Arc::new(self.field), vec![arr]) + } + + fn shrink_to_fit(&mut self) { + self.array_builder.shrink_to_fit() + } +} + +impl PrimitiveChunkedBuilder +where + T: PolarsNumericType, +{ + pub fn new(name: PlSmallStr, capacity: usize) -> Self { + let array_builder = MutablePrimitiveArray::::with_capacity(capacity) + .to(T::get_dtype().to_arrow(CompatLevel::newest())); + + PrimitiveChunkedBuilder { + array_builder, + field: Field::new(name, T::get_dtype()), + } + } +} diff --git a/crates/polars-core/src/chunked_array/builder/string.rs b/crates/polars-core/src/chunked_array/builder/string.rs new file mode 100644 index 000000000000..49c298f36865 --- /dev/null +++ b/crates/polars-core/src/chunked_array/builder/string.rs @@ -0,0 +1,62 @@ +use super::*; + +pub struct BinViewChunkedBuilder { + pub(crate) chunk_builder: MutableBinaryViewArray, + pub(crate) field: FieldRef, +} + +impl Clone for BinViewChunkedBuilder { + fn clone(&self) -> Self { + Self { + chunk_builder: self.chunk_builder.clone(), + field: self.field.clone(), + } + } +} + +pub type StringChunkedBuilder = BinViewChunkedBuilder; +pub type BinaryChunkedBuilder = BinViewChunkedBuilder<[u8]>; + +impl BinViewChunkedBuilder { + /// Create a new BinViewChunkedBuilder + /// + /// # Arguments + /// + /// * `capacity` - Number of string elements in the final array. + pub fn new(name: PlSmallStr, capacity: usize) -> Self { + Self { + chunk_builder: MutableBinaryViewArray::with_capacity(capacity), + field: Arc::new(Field::new(name, DataType::from_arrow_dtype(&T::DATA_TYPE))), + } + } + + /// Appends a value of type `T` into the builder + #[inline] + pub fn append_value>(&mut self, v: S) { + self.chunk_builder.push_value(v.as_ref()); + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.chunk_builder.push_null() + } + + #[inline] + pub fn append_option>(&mut self, opt: Option) { + self.chunk_builder.push(opt); + } +} + +impl StringChunkedBuilder { + pub fn finish(mut self) -> StringChunked { + let arr = self.chunk_builder.as_box(); + ChunkedArray::new_with_compute_len(self.field, vec![arr]) + } +} +impl BinaryChunkedBuilder { + pub fn finish(mut self) -> BinaryChunked { + let arr = self.chunk_builder.as_box(); + ChunkedArray::new_with_compute_len(self.field, vec![arr]) + } +} diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs new file mode 100644 index 000000000000..002139af6f3c --- /dev/null +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -0,0 +1,731 @@ +//! Implementations of the ChunkCast Trait. + +use polars_compute::cast::CastOptionsImpl; +#[cfg(feature = "serde-lazy")] +use serde::{Deserialize, Serialize}; + +use super::flags::StatisticsFlags; +#[cfg(feature = "timezones")] +use crate::chunked_array::temporal::validate_time_zone; +#[cfg(feature = "dtype-datetime")] +use crate::prelude::DataType::Datetime; +use crate::prelude::*; + +#[derive(Copy, Clone, Debug, Default, PartialEq, Hash, Eq)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +#[repr(u8)] +pub enum CastOptions { + /// Raises on overflow + #[default] + Strict, + /// Overflow is replaced with null + NonStrict, + /// Allows wrapping overflow + Overflowing, +} + +impl CastOptions { + pub fn strict(&self) -> bool { + matches!(self, CastOptions::Strict) + } +} + +impl From for CastOptionsImpl { + fn from(value: CastOptions) -> Self { + let wrapped = match value { + CastOptions::Strict | CastOptions::NonStrict => false, + CastOptions::Overflowing => true, + }; + CastOptionsImpl { + wrapped, + partial: false, + } + } +} + +pub(crate) fn cast_chunks( + chunks: &[ArrayRef], + dtype: &DataType, + options: CastOptions, +) -> PolarsResult> { + let check_nulls = matches!(options, CastOptions::Strict); + let options = options.into(); + + let arrow_dtype = dtype.try_to_arrow(CompatLevel::newest())?; + chunks + .iter() + .map(|arr| { + let out = polars_compute::cast::cast(arr.as_ref(), &arrow_dtype, options); + if check_nulls { + out.and_then(|new| { + polars_ensure!(arr.null_count() == new.null_count(), ComputeError: "strict cast failed"); + Ok(new) + }) + + } else { + out + } + }) + .collect::>>() +} + +fn cast_impl_inner( + name: PlSmallStr, + chunks: &[ArrayRef], + dtype: &DataType, + options: CastOptions, +) -> PolarsResult { + let chunks = match dtype { + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + let mut chunks = cast_chunks(chunks, dtype, options)?; + // @NOTE: We cannot cast here as that will lower the scale. + for chunk in chunks.iter_mut() { + *chunk = std::mem::take( + chunk + .as_any_mut() + .downcast_mut::>() + .unwrap(), + ) + .to(ArrowDataType::Int128) + .to_boxed(); + } + chunks + }, + _ => cast_chunks(chunks, &dtype.to_physical(), options)?, + }; + + let out = Series::try_from((name, chunks))?; + use DataType::*; + let out = match dtype { + Date => out.into_date(), + Datetime(tu, tz) => match tz { + #[cfg(feature = "timezones")] + Some(tz) => { + validate_time_zone(tz)?; + out.into_datetime(*tu, Some(tz.clone())) + }, + _ => out.into_datetime(*tu, None), + }, + Duration(tu) => out.into_duration(*tu), + #[cfg(feature = "dtype-time")] + Time => out.into_time(), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => out.into_decimal(*precision, scale.unwrap_or(0))?, + _ => out, + }; + + Ok(out) +} + +fn cast_impl( + name: PlSmallStr, + chunks: &[ArrayRef], + dtype: &DataType, + options: CastOptions, +) -> PolarsResult { + cast_impl_inner(name, chunks, dtype, options) +} + +#[cfg(feature = "dtype-struct")] +fn cast_single_to_struct( + name: PlSmallStr, + chunks: &[ArrayRef], + fields: &[Field], + options: CastOptions, +) -> PolarsResult { + polars_ensure!(fields.len() == 1, InvalidOperation: "must specify one field in the struct"); + let mut new_fields = Vec::with_capacity(fields.len()); + // cast to first field dtype + let mut fields = fields.iter(); + let fld = fields.next().unwrap(); + let s = cast_impl_inner(fld.name.clone(), chunks, &fld.dtype, options)?; + let length = s.len(); + new_fields.push(s); + + for fld in fields { + new_fields.push(Series::full_null(fld.name.clone(), length, &fld.dtype)); + } + + StructChunked::from_series(name, length, new_fields.iter()).map(|ca| ca.into_series()) +} + +impl ChunkedArray +where + T: PolarsNumericType, +{ + fn cast_impl(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + if self.dtype() == dtype { + // SAFETY: chunks are correct dtype + let mut out = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + self.chunks.clone(), + dtype, + ) + }; + out.set_sorted_flag(self.is_sorted_flag()); + return Ok(out); + } + match dtype { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, ordering) => { + polars_ensure!( + self.dtype() == &DataType::UInt32, + ComputeError: "cannot cast numeric types to 'Categorical'" + ); + // SAFETY: + // we are guarded by the type system + let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; + + CategoricalChunked::from_global_indices(ca.clone(), *ordering) + .map(|ca| ca.into_series()) + }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(rev_map, ordering) => { + let ca = match self.dtype() { + DataType::UInt32 => { + // SAFETY: we are guarded by the type system + unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) } + .clone() + }, + dt if dt.is_integer() => self + .cast_with_options(self.dtype(), options)? + .strict_cast(&DataType::UInt32)? + .u32()? + .clone(), + _ => { + polars_bail!(ComputeError: "cannot cast non integer types to 'Enum'") + }, + }; + let Some(rev_map) = rev_map else { + polars_bail!(ComputeError: "cannot cast to Enum without categories"); + }; + let categories = rev_map.get_categories(); + // Check if indices are in bounds + if let Some(m) = ChunkAgg::max(&ca) { + if m >= categories.len() as u32 { + polars_bail!(OutOfBounds: "index {} is bigger than the number of categories {}",m,categories.len()); + } + } + // SAFETY: indices are in bound + unsafe { + Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + ca.clone(), + rev_map.clone(), + true, + *ordering, + ) + .into_series()) + } + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) + }, + _ => cast_impl_inner(self.name().clone(), &self.chunks, dtype, options).map(|mut s| { + // maintain sorted if data types + // - remain signed + // - unsigned -> signed + // this may still fail with overflow? + let to_signed = dtype.is_signed_integer(); + let unsigned2unsigned = + self.dtype().is_unsigned_integer() && dtype.is_unsigned_integer(); + let allowed = to_signed || unsigned2unsigned; + + if (allowed) + && (s.null_count() == self.null_count()) + // physical to logicals + || (self.dtype().to_physical() == dtype.to_physical()) + { + let is_sorted = self.is_sorted_flag(); + s.set_sorted_flag(is_sorted) + } + s + }), + } + } +} + +impl ChunkCast for ChunkedArray +where + T: PolarsNumericType, +{ + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.cast_impl(dtype, options) + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + match dtype { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(Some(rev_map), ordering) + | DataType::Enum(Some(rev_map), ordering) => { + if self.dtype() == &DataType::UInt32 { + // SAFETY: + // we are guarded by the type system. + let ca = unsafe { &*(self as *const ChunkedArray as *const UInt32Chunked) }; + Ok(unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + ca.clone(), + rev_map.clone(), + matches!(dtype, DataType::Enum(_, _)), + *ordering, + ) + } + .into_series()) + } else { + polars_bail!(ComputeError: "cannot cast numeric types to 'Categorical'"); + } + }, + _ => self.cast_impl(dtype, CastOptions::Overflowing), + } + } +} + +impl ChunkCast for StringChunked { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(rev_map, ordering) => match rev_map { + None => { + // SAFETY: length is correct + let iter = + unsafe { self.downcast_iter().flatten().trust_my_length(self.len()) }; + let builder = + CategoricalChunkedBuilder::new(self.name().clone(), self.len(), *ordering); + let ca = builder.drain_iter_and_finish(iter); + Ok(ca.into_series()) + }, + Some(_) => { + polars_bail!(InvalidOperation: "casting to a categorical with rev map is not allowed"); + }, + }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(rev_map, ordering) => { + let Some(rev_map) = rev_map else { + polars_bail!(InvalidOperation: "cannot cast / initialize Enum without categories present") + }; + CategoricalChunked::from_string_to_enum(self, rev_map.get_categories(), *ordering) + .map(|ca| { + let mut s = ca.into_series(); + s.rename(self.name().clone()); + s + }) + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => match (precision, scale) { + (precision, Some(scale)) => { + let chunks = self.downcast_iter().map(|arr| { + polars_compute::cast::binview_to_decimal( + &arr.to_binview(), + *precision, + *scale, + ) + .to(ArrowDataType::Int128) + }); + Ok(Int128Chunked::from_chunk_iter(self.name().clone(), chunks) + .into_decimal_unchecked(*precision, *scale) + .into_series()) + }, + (None, None) => self.to_decimal(100), + _ => { + polars_bail!(ComputeError: "expected 'precision' or 'scale' when casting to Decimal") + }, + }, + #[cfg(feature = "dtype-date")] + DataType::Date => { + let result = cast_chunks(&self.chunks, dtype, options)?; + let out = Series::try_from((self.name().clone(), result))?; + Ok(out) + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(time_unit, time_zone) => match time_zone { + #[cfg(feature = "timezones")] + Some(time_zone) => { + validate_time_zone(time_zone)?; + let result = cast_chunks( + &self.chunks, + &Datetime(time_unit.to_owned(), Some(time_zone.clone())), + options, + )?; + Series::try_from((self.name().clone(), result)) + }, + _ => { + let result = + cast_chunks(&self.chunks, &Datetime(time_unit.to_owned(), None), options)?; + Series::try_from((self.name().clone(), result)) + }, + }, + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), + } + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) + } +} + +impl BinaryChunked { + /// # Safety + /// String is not validated + pub unsafe fn to_string_unchecked(&self) -> StringChunked { + let chunks = self + .downcast_iter() + .map(|arr| unsafe { arr.to_utf8view_unchecked() }.boxed()) + .collect(); + let field = Arc::new(Field::new(self.name().clone(), DataType::String)); + + let mut ca = StringChunked::new_with_compute_len(field, chunks); + + use StatisticsFlags as F; + ca.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + ca + } +} + +impl StringChunked { + pub fn as_binary(&self) -> BinaryChunked { + let chunks = self + .downcast_iter() + .map(|arr| arr.to_binview().boxed()) + .collect(); + let field = Arc::new(Field::new(self.name().clone(), DataType::Binary)); + + let mut ca = BinaryChunked::new_with_compute_len(field, chunks); + + use StatisticsFlags as F; + ca.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + ca + } +} + +impl ChunkCast for BinaryChunked { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) + }, + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), + } + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + match dtype { + DataType::String => unsafe { Ok(self.to_string_unchecked().into_series()) }, + _ => self.cast_with_options(dtype, CastOptions::Overflowing), + } + } +} + +impl ChunkCast for BinaryOffsetChunked { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) + }, + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), + } + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) + } +} + +impl ChunkCast for BooleanChunked { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) + }, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + polars_bail!(InvalidOperation: "cannot cast Boolean to Categorical"); + }, + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), + } + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) + } +} + +/// We cannot cast anything to or from List/LargeList +/// So this implementation casts the inner type +impl ChunkCast for ListChunked { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + use DataType::*; + match dtype { + List(child_type) => { + match (self.inner_dtype(), &**child_type) { + (old, new) if old == new => Ok(self.clone().into_series()), + #[cfg(feature = "dtype-categorical")] + (dt, Categorical(None, _) | Enum(_, _)) + if !matches!(dt, Categorical(_, _) | Enum(_, _) | String | Null) => + { + polars_bail!(InvalidOperation: "cannot cast List inner type: '{:?}' to Categorical", dt) + }, + _ => { + // ensure the inner logical type bubbles up + let (arr, child_type) = cast_list(self, child_type, options)?; + // SAFETY: we just cast so the dtype matches. + // we must take this path to correct for physical types. + unsafe { + Ok(Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr], + &List(Box::new(child_type)), + )) + } + }, + } + }, + #[cfg(feature = "dtype-array")] + Array(child_type, width) => { + let physical_type = dtype.to_physical(); + + // TODO!: properly implement this recursively. + #[cfg(feature = "dtype-categorical")] + polars_ensure!(!matches!(&**child_type, Categorical(_, _)), InvalidOperation: "array of categorical is not yet supported"); + + // cast to the physical type to avoid logical chunks. + let chunks = cast_chunks(self.chunks(), &physical_type, options)?; + // SAFETY: we just cast so the dtype matches. + // we must take this path to correct for physical types. + unsafe { + Ok(Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + chunks, + &Array(child_type.clone(), *width), + )) + } + }, + _ => { + polars_bail!( + InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')", + self.inner_dtype(), + dtype, + ) + }, + } + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + use DataType::*; + match dtype { + List(child_type) => cast_list_unchecked(self, child_type), + _ => self.cast_with_options(dtype, CastOptions::Overflowing), + } + } +} + +/// We cannot cast anything to or from List/LargeList +/// So this implementation casts the inner type +#[cfg(feature = "dtype-array")] +impl ChunkCast for ArrayChunked { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + use DataType::*; + match dtype { + Array(child_type, width) => { + polars_ensure!( + *width == self.width(), + InvalidOperation: "cannot cast Array to a different width" + ); + + match (self.inner_dtype(), &**child_type) { + (old, new) if old == new => Ok(self.clone().into_series()), + #[cfg(feature = "dtype-categorical")] + (dt, Categorical(None, _) | Enum(_, _)) if !matches!(dt, String) => { + polars_bail!(InvalidOperation: "cannot cast Array inner type: '{:?}' to dtype: {:?}", dt, child_type) + }, + _ => { + // ensure the inner logical type bubbles up + let (arr, child_type) = cast_fixed_size_list(self, child_type, options)?; + // SAFETY: we just cast so the dtype matches. + // we must take this path to correct for physical types. + unsafe { + Ok(Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr], + &Array(Box::new(child_type), *width), + )) + } + }, + } + }, + List(child_type) => { + let physical_type = dtype.to_physical(); + // cast to the physical type to avoid logical chunks. + let chunks = cast_chunks(self.chunks(), &physical_type, options)?; + // SAFETY: we just cast so the dtype matches. + // we must take this path to correct for physical types. + unsafe { + Ok(Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + chunks, + &List(child_type.clone()), + )) + } + }, + _ => { + polars_bail!( + InvalidOperation: "cannot cast Array type (inner: '{:?}', to: '{:?}')", + self.inner_dtype(), + dtype, + ) + }, + } + } + + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) + } +} + +// Returns inner data type. This is needed because a cast can instantiate the dtype inner +// values for instance with categoricals +fn cast_list( + ca: &ListChunked, + child_type: &DataType, + options: CastOptions, +) -> PolarsResult<(ArrayRef, DataType)> { + // We still rechunk because we must bubble up a single data-type + // TODO!: consider a version that works on chunks and merges the data-types and arrays. + let ca = ca.rechunk(); + let arr = ca.downcast_as_array(); + // SAFETY: inner dtype is passed correctly + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + ca.inner_dtype(), + ) + }; + let new_inner = s.cast_with_options(child_type, options)?; + + let inner_dtype = new_inner.dtype().clone(); + debug_assert_eq!(&inner_dtype, child_type); + + let new_values = new_inner.array_ref(0).clone(); + + let dtype = ListArray::::default_datatype(new_values.dtype().clone()); + let new_arr = ListArray::::new( + dtype, + arr.offsets().clone(), + new_values, + arr.validity().cloned(), + ); + Ok((new_arr.boxed(), inner_dtype)) +} + +unsafe fn cast_list_unchecked(ca: &ListChunked, child_type: &DataType) -> PolarsResult { + // TODO! add chunked, but this must correct for list offsets. + let ca = ca.rechunk(); + let arr = ca.downcast_as_array(); + // SAFETY: inner dtype is passed correctly + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + ca.inner_dtype(), + ) + }; + let new_inner = s.cast_unchecked(child_type)?; + let new_values = new_inner.array_ref(0).clone(); + + let dtype = ListArray::::default_datatype(new_values.dtype().clone()); + let new_arr = ListArray::::new( + dtype, + arr.offsets().clone(), + new_values, + arr.validity().cloned(), + ); + Ok(ListChunked::from_chunks_and_dtype_unchecked( + ca.name().clone(), + vec![Box::new(new_arr)], + DataType::List(Box::new(child_type.clone())), + ) + .into_series()) +} + +// Returns inner data type. This is needed because a cast can instantiate the dtype inner +// values for instance with categoricals +#[cfg(feature = "dtype-array")] +fn cast_fixed_size_list( + ca: &ArrayChunked, + child_type: &DataType, + options: CastOptions, +) -> PolarsResult<(ArrayRef, DataType)> { + let ca = ca.rechunk(); + let arr = ca.downcast_as_array(); + // SAFETY: inner dtype is passed correctly + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + ca.inner_dtype(), + ) + }; + let new_inner = s.cast_with_options(child_type, options)?; + + let inner_dtype = new_inner.dtype().clone(); + debug_assert_eq!(&inner_dtype, child_type); + + let new_values = new_inner.array_ref(0).clone(); + + let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), ca.width()); + let new_arr = FixedSizeListArray::new(dtype, ca.len(), new_values, arr.validity().cloned()); + Ok((Box::new(new_arr), inner_dtype)) +} + +#[cfg(test)] +mod test { + use crate::chunked_array::cast::CastOptions; + use crate::prelude::*; + + #[test] + fn test_cast_list() -> PolarsResult<()> { + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); + builder.append_opt_slice(Some(&[1i32, 2, 3])); + builder.append_opt_slice(Some(&[1i32, 2, 3])); + let ca = builder.finish(); + + let new = ca.cast_with_options( + &DataType::List(DataType::Float64.into()), + CastOptions::Strict, + )?; + + assert_eq!(new.dtype(), &DataType::List(DataType::Float64.into())); + Ok(()) + } + + #[test] + #[cfg(feature = "dtype-categorical")] + fn test_cast_noop() { + // check if we can cast categorical twice without panic + let ca = StringChunked::new(PlSmallStr::from_static("foo"), &["bar", "ham"]); + let out = ca + .cast_with_options( + &DataType::Categorical(None, Default::default()), + CastOptions::Strict, + ) + .unwrap(); + let out = out + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + assert!(matches!(out.dtype(), &DataType::Categorical(_, _))) + } +} diff --git a/crates/polars-core/src/chunked_array/collect.rs b/crates/polars-core/src/chunked_array/collect.rs new file mode 100644 index 000000000000..eb882b1acc13 --- /dev/null +++ b/crates/polars-core/src/chunked_array/collect.rs @@ -0,0 +1,175 @@ +//! Methods for collecting into a ChunkedArray. +//! +//! For types that don't have dtype parameters: +//! iter.(try_)collect_ca(_trusted) (name) +//! +//! For all types: +//! iter.(try_)collect_ca(_trusted)_like (other_df) Copies name/dtype from other_df +//! iter.(try_)collect_ca(_trusted)_with_dtype (name, df) +//! +//! The try variants work on iterators of Results, the trusted variants do not +//! check the length of the iterator. + +use std::sync::Arc; + +use arrow::trusted_len::TrustedLen; +use polars_utils::pl_str::PlSmallStr; + +use crate::chunked_array::ChunkedArray; +use crate::datatypes::{ + ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, DataType, Field, PolarsDataType, +}; +use crate::prelude::CompatLevel; + +pub trait ChunkedCollectIterExt: Iterator + Sized { + #[inline] + fn collect_ca_with_dtype(self, name: PlSmallStr, dtype: DataType) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.collect_arr_with_dtype(field.dtype.to_arrow(CompatLevel::newest())); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_like(self, name_dtype_src: &ChunkedArray) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.collect_arr_with_dtype(field.dtype.to_arrow(CompatLevel::newest())); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted_with_dtype(self, name: PlSmallStr, dtype: DataType) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + Self: TrustedLen, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.collect_arr_trusted_with_dtype(field.dtype.to_arrow(CompatLevel::newest())); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted_like(self, name_dtype_src: &ChunkedArray) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + Self: TrustedLen, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.collect_arr_trusted_with_dtype(field.dtype.to_arrow(CompatLevel::newest())); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn try_collect_ca_with_dtype( + self, + name: PlSmallStr, + dtype: DataType, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator>, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.try_collect_arr_with_dtype(field.dtype.to_arrow(CompatLevel::newest()))?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_like( + self, + name_dtype_src: &ChunkedArray, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator>, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.try_collect_arr_with_dtype(field.dtype.to_arrow(CompatLevel::newest()))?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted_with_dtype( + self, + name: PlSmallStr, + dtype: DataType, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = + self.try_collect_arr_trusted_with_dtype(field.dtype.to_arrow(CompatLevel::newest()))?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted_like( + self, + name_dtype_src: &ChunkedArray, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = + self.try_collect_arr_trusted_with_dtype(field.dtype.to_arrow(CompatLevel::newest()))?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } +} + +impl ChunkedCollectIterExt for I {} + +pub trait ChunkedCollectInferIterExt: Iterator + Sized { + #[inline] + fn collect_ca(self, name: PlSmallStr) -> ChunkedArray + where + T::Array: ArrayFromIter, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.collect_arr(); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted(self, name: PlSmallStr) -> ChunkedArray + where + T::Array: ArrayFromIter, + Self: TrustedLen, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.collect_arr_trusted(); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn try_collect_ca(self, name: PlSmallStr) -> Result, E> + where + T::Array: ArrayFromIter, + Self: Iterator>, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.try_collect_arr()?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted(self, name: PlSmallStr) -> Result, E> + where + T::Array: ArrayFromIter, + Self: Iterator> + TrustedLen, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.try_collect_arr_trusted()?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } +} + +impl ChunkedCollectInferIterExt for I {} diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs new file mode 100644 index 000000000000..09573c5fbd32 --- /dev/null +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -0,0 +1,481 @@ +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::FromTrustedLenIterator; +use polars_compute::comparisons::TotalOrdKernel; + +use crate::chunked_array::cast::CastOptions; +use crate::prelude::nulls::replace_non_null; +use crate::prelude::*; + +#[cfg(feature = "dtype-categorical")] +fn cat_equality_helper<'a, Compare, Missing>( + lhs: &'a CategoricalChunked, + rhs: &'a CategoricalChunked, + missing_function: Missing, + compare_function: Compare, +) -> PolarsResult +where + Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked, + Missing: Fn(&'a CategoricalChunked) -> BooleanChunked, +{ + let rev_map_l = lhs.get_rev_map(); + polars_ensure!(rev_map_l.same_src(rhs.get_rev_map()), string_cache_mismatch); + let rhs = rhs.physical(); + + // Fast path for globals + if rhs.len() == 1 && rhs.null_count() == 0 { + let rhs = rhs.get(0).unwrap(); + if rev_map_l.get_optional(rhs).is_none() { + return Ok(missing_function(lhs)); + } + } + Ok(compare_function(lhs.physical(), rhs)) +} + +fn cat_compare_helper<'a, Compare, CompareString>( + lhs: &'a CategoricalChunked, + rhs: &'a CategoricalChunked, + compare_function: Compare, + compare_str_function: CompareString, +) -> PolarsResult +where + Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked, + CompareString: Fn(&str, &str) -> bool, +{ + let rev_map_l = lhs.get_rev_map(); + let rev_map_r = rhs.get_rev_map(); + polars_ensure!(rev_map_l.same_src(rev_map_r), ComputeError: "can only compare categoricals of the same type with the same categories"); + + if lhs.is_enum() || !lhs.uses_lexical_ordering() { + Ok(compare_function(lhs.physical(), rhs.physical())) + } else { + match (lhs.len(), rhs.len()) { + (lhs_len, 1) => { + // SAFETY: physical is in range of revmap + let v = unsafe { + rhs.physical() + .get(0) + .map(|phys| rev_map_r.get_unchecked(phys)) + }; + let Some(v) = v else { + return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len)); + }; + + Ok(lhs + .iter_str() + .map(|opt_s| opt_s.map(|s| compare_str_function(s, v))) + .collect_ca_trusted(lhs.name().clone())) + }, + (1, rhs_len) => { + // SAFETY: physical is in range of revmap + let v = unsafe { + lhs.physical() + .get(0) + .map(|phys| rev_map_l.get_unchecked(phys)) + }; + let Some(v) = v else { + return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len)); + }; + Ok(rhs + .iter_str() + .map(|opt_s| opt_s.map(|s| compare_str_function(v, s))) + .collect_ca_trusted(lhs.name().clone())) + }, + (lhs_len, rhs_len) if lhs_len == rhs_len => Ok(lhs + .iter_str() + .zip(rhs.iter_str()) + .map(|(l, r)| match (l, r) { + (None, _) => None, + (_, None) => None, + (Some(l), Some(r)) => Some(compare_str_function(l, r)), + }) + .collect_ca_trusted(lhs.name().clone())), + (lhs_len, rhs_len) => { + polars_bail!(ComputeError: "Columns are of unequal length: {} vs {}",lhs_len,rhs_len) + }, + } + } +} + +impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked { + type Item = PolarsResult; + + fn equal(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_equality_helper( + self, + rhs, + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), + UInt32Chunked::equal, + ) + } + + fn equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_equality_helper( + self, + rhs, + |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), + UInt32Chunked::equal_missing, + ) + } + + fn not_equal(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_equality_helper( + self, + rhs, + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), + UInt32Chunked::not_equal, + ) + } + + fn not_equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_equality_helper( + self, + rhs, + |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), + UInt32Chunked::not_equal_missing, + ) + } +} + +impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked { + type Item = PolarsResult; + + fn gt(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r) + } + + fn gt_eq(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, UInt32Chunked::gt_eq, |l, r| l >= r) + } + + fn lt(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, UInt32Chunked::lt, |l, r| l < r) + } + + fn lt_eq(&self, rhs: &CategoricalChunked) -> Self::Item { + cat_compare_helper(self, rhs, UInt32Chunked::lt_eq, |l, r| l <= r) + } +} + +fn cat_str_equality_helper<'a, Missing, CompareNone, CompareCat, ComparePhys, CompareString>( + lhs: &'a CategoricalChunked, + rhs: &'a StringChunked, + missing_function: Missing, + compare_to_none: CompareNone, + cat_compare_function: CompareCat, + phys_compare_function: ComparePhys, + str_compare_function: CompareString, +) -> PolarsResult +where + Missing: Fn(&CategoricalChunked) -> BooleanChunked, + CompareNone: Fn(&CategoricalChunked) -> BooleanChunked, + ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, + CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult, + CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, +{ + if lhs.is_enum() { + let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?; + cat_compare_function(lhs, rhs_cat.categorical().unwrap()) + } else if rhs.len() == 1 { + match rhs.get(0) { + None => Ok(compare_to_none(lhs)), + Some(s) => { + cat_single_str_equality_helper(lhs, s, missing_function, phys_compare_function) + }, + } + } else { + let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?; + Ok(str_compare_function(lhs_string.str().unwrap(), rhs)) + } +} + +fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, CompareString>( + lhs: &'a CategoricalChunked, + rhs: &'a StringChunked, + cat_compare_function: CompareCat, + phys_compare_function: ComparePhys, + str_single_compare_function: CompareStringSingle, + str_compare_function: CompareString, +) -> PolarsResult +where + CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap, + ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, + CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult, + CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, +{ + if lhs.is_enum() { + let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?; + cat_compare_function(lhs, rhs_cat.categorical().unwrap()) + } else if rhs.len() == 1 { + match rhs.get(0) { + None => Ok(BooleanChunked::full_null(lhs.name().clone(), lhs.len())), + Some(s) => cat_single_str_compare_helper( + lhs, + s, + phys_compare_function, + str_single_compare_function, + ), + } + } else { + let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?; + Ok(str_compare_function(lhs_string.str().unwrap(), rhs)) + } +} + +impl ChunkCompareEq<&StringChunked> for CategoricalChunked { + type Item = PolarsResult; + + fn equal(&self, rhs: &StringChunked) -> Self::Item { + cat_str_equality_helper( + self, + rhs, + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), + |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()), + |s1, s2| CategoricalChunked::equal(s1, s2), + UInt32Chunked::equal, + StringChunked::equal, + ) + } + fn equal_missing(&self, rhs: &StringChunked) -> Self::Item { + cat_str_equality_helper( + self, + rhs, + |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), + |lhs| lhs.physical().is_null(), + |s1, s2| CategoricalChunked::equal_missing(s1, s2), + UInt32Chunked::equal_missing, + StringChunked::equal_missing, + ) + } + + fn not_equal(&self, rhs: &StringChunked) -> Self::Item { + cat_str_equality_helper( + self, + rhs, + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), + |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()), + |s1, s2| CategoricalChunked::not_equal(s1, s2), + UInt32Chunked::not_equal, + StringChunked::not_equal, + ) + } + fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item { + cat_str_equality_helper( + self, + rhs, + |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), + |lhs| !lhs.physical().is_null(), + |s1, s2| CategoricalChunked::not_equal_missing(s1, s2), + UInt32Chunked::not_equal_missing, + StringChunked::not_equal_missing, + ) + } +} + +impl ChunkCompareIneq<&StringChunked> for CategoricalChunked { + type Item = PolarsResult; + + fn gt(&self, rhs: &StringChunked) -> Self::Item { + cat_str_compare_helper( + self, + rhs, + |s1, s2| CategoricalChunked::gt(s1, s2), + UInt32Chunked::gt, + Utf8ViewArray::tot_gt_kernel_broadcast, + StringChunked::gt, + ) + } + + fn gt_eq(&self, rhs: &StringChunked) -> Self::Item { + cat_str_compare_helper( + self, + rhs, + |s1, s2| CategoricalChunked::gt_eq(s1, s2), + UInt32Chunked::gt_eq, + Utf8ViewArray::tot_ge_kernel_broadcast, + StringChunked::gt_eq, + ) + } + + fn lt(&self, rhs: &StringChunked) -> Self::Item { + cat_str_compare_helper( + self, + rhs, + |s1, s2| CategoricalChunked::lt(s1, s2), + UInt32Chunked::lt, + Utf8ViewArray::tot_lt_kernel_broadcast, + StringChunked::lt, + ) + } + + fn lt_eq(&self, rhs: &StringChunked) -> Self::Item { + cat_str_compare_helper( + self, + rhs, + |s1, s2| CategoricalChunked::lt_eq(s1, s2), + UInt32Chunked::lt_eq, + Utf8ViewArray::tot_le_kernel_broadcast, + StringChunked::lt_eq, + ) + } +} + +fn cat_single_str_equality_helper<'a, ComparePhys, Missing>( + lhs: &'a CategoricalChunked, + rhs: &'a str, + missing_function: Missing, + phys_compare_function: ComparePhys, +) -> PolarsResult +where + ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, + Missing: Fn(&CategoricalChunked) -> BooleanChunked, +{ + let rev_map = lhs.get_rev_map(); + let idx = rev_map.find(rhs); + if lhs.is_enum() { + let Some(idx) = idx else { + polars_bail!( + not_in_enum, + value = rhs, + categories = rev_map.get_categories() + ) + }; + Ok(phys_compare_function(lhs.physical(), idx)) + } else { + match rev_map.find(rhs) { + None => Ok(missing_function(lhs)), + Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)), + } + } +} + +fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>( + lhs: &'a CategoricalChunked, + rhs: &'a str, + phys_compare_function: ComparePhys, + str_single_compare_function: CompareStringSingle, +) -> PolarsResult +where + CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap, + ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, +{ + let rev_map = lhs.get_rev_map(); + if lhs.is_enum() { + match rev_map.find(rhs) { + None => { + polars_bail!( + not_in_enum, + value = rhs, + categories = rev_map.get_categories() + ) + }, + Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)), + } + } else { + // Apply comparison on categories map and then do a lookup + let bitmap = str_single_compare_function(lhs.get_rev_map().get_categories(), rhs); + + let mask = match lhs.get_rev_map().as_ref() { + RevMapping::Local(_, _) => { + BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map( + |opt_idx| { + // SAFETY: indexing into bitmap with same length as original array + opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) }) + }, + )) + }, + RevMapping::Global(idx_map, _, _) => { + BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map( + |opt_idx| { + // SAFETY: indexing into bitmap with same length as original array + opt_idx.map(|idx| unsafe { + let idx = *idx_map.get(&idx).unwrap(); + bitmap.get_bit_unchecked(idx as usize) + }) + }, + )) + }, + }; + + Ok(mask.with_name(lhs.name().clone())) + } +} + +impl ChunkCompareEq<&str> for CategoricalChunked { + type Item = PolarsResult; + + fn equal(&self, rhs: &str) -> Self::Item { + cat_single_str_equality_helper( + self, + rhs, + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), + UInt32Chunked::equal, + ) + } + + fn equal_missing(&self, rhs: &str) -> Self::Item { + cat_single_str_equality_helper( + self, + rhs, + |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), + UInt32Chunked::equal_missing, + ) + } + + fn not_equal(&self, rhs: &str) -> Self::Item { + cat_single_str_equality_helper( + self, + rhs, + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), + UInt32Chunked::not_equal, + ) + } + + fn not_equal_missing(&self, rhs: &str) -> Self::Item { + cat_single_str_equality_helper( + self, + rhs, + |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), + UInt32Chunked::equal_missing, + ) + } +} + +impl ChunkCompareIneq<&str> for CategoricalChunked { + type Item = PolarsResult; + + fn gt(&self, rhs: &str) -> Self::Item { + cat_single_str_compare_helper( + self, + rhs, + UInt32Chunked::gt, + Utf8ViewArray::tot_gt_kernel_broadcast, + ) + } + + fn gt_eq(&self, rhs: &str) -> Self::Item { + cat_single_str_compare_helper( + self, + rhs, + UInt32Chunked::gt_eq, + Utf8ViewArray::tot_ge_kernel_broadcast, + ) + } + + fn lt(&self, rhs: &str) -> Self::Item { + cat_single_str_compare_helper( + self, + rhs, + UInt32Chunked::lt, + Utf8ViewArray::tot_lt_kernel_broadcast, + ) + } + + fn lt_eq(&self, rhs: &str) -> Self::Item { + cat_single_str_compare_helper( + self, + rhs, + UInt32Chunked::lt_eq, + Utf8ViewArray::tot_le_kernel_broadcast, + ) + } +} diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs new file mode 100644 index 000000000000..5dc1f9bb5d43 --- /dev/null +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -0,0 +1,1557 @@ +mod scalar; + +#[cfg(feature = "dtype-categorical")] +mod categorical; + +use std::ops::{BitAnd, Not}; + +use arrow::array::BooleanArray; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::compute; +use num_traits::{NumCast, ToPrimitive}; +use polars_compute::comparisons::{TotalEqKernel, TotalOrdKernel}; + +use crate::prelude::*; +use crate::series::IsSorted; +use crate::series::implementations::null::NullChunked; +use crate::utils::align_chunks_binary; + +impl ChunkCompareEq<&ChunkedArray> for ChunkedArray +where + T: PolarsNumericType, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; + + fn equal(&self, rhs: &ChunkedArray) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn equal_missing(&self, rhs: &ChunkedArray) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.equal_missing(value) + } else { + self.is_null() + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.equal_missing(value) + } else { + rhs.is_null() + } + }, + _ => arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_eq_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn not_equal(&self, rhs: &ChunkedArray) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.not_equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.not_equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn not_equal_missing(&self, rhs: &ChunkedArray) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.not_equal_missing(value) + } else { + self.is_not_null() + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.not_equal_missing(value) + } else { + rhs.is_not_null() + } + }, + _ => arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_ne_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } +} + +impl ChunkCompareIneq<&ChunkedArray> for ChunkedArray +where + T: PolarsNumericType, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; + + fn lt(&self, rhs: &ChunkedArray) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.lt(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.gt(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_lt_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn lt_eq(&self, rhs: &ChunkedArray) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.lt_eq(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.gt_eq(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_le_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn gt(&self, rhs: &Self) -> BooleanChunked { + rhs.lt(self) + } + + fn gt_eq(&self, rhs: &Self) -> BooleanChunked { + rhs.lt_eq(self) + } +} + +impl ChunkCompareEq<&NullChunked> for NullChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) + } + + fn equal_missing(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full(self.name().clone(), true, get_broadcast_length(self, rhs)) + } + + fn not_equal(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) + } + + fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full(self.name().clone(), false, get_broadcast_length(self, rhs)) + } +} + +impl ChunkCompareIneq<&NullChunked> for NullChunked { + type Item = BooleanChunked; + + fn gt(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) + } + + fn gt_eq(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) + } + + fn lt(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) + } + + fn lt_eq(&self, rhs: &NullChunked) -> Self::Item { + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) + } +} + +#[inline] +fn get_broadcast_length(lhs: &NullChunked, rhs: &NullChunked) -> usize { + match (lhs.len(), rhs.len()) { + (1, len_r) => len_r, + (len_l, 1) => len_l, + (len_l, len_r) if len_l == len_r => len_l, + _ => panic!("Cannot compare two series of different lengths."), + } +} + +impl ChunkCompareEq<&BooleanChunked> for BooleanChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &BooleanChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + arity::unary_mut_values(rhs, |arr| arr.tot_eq_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn equal_missing(&self, rhs: &BooleanChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + arity::unary_mut_with_options(self, |arr| { + arr.tot_eq_missing_kernel_broadcast(&value).into() + }) + } else { + self.is_null() + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + arity::unary_mut_with_options(rhs, |arr| { + arr.tot_eq_missing_kernel_broadcast(&value).into() + }) + } else { + rhs.is_null() + } + }, + _ => arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_eq_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn not_equal(&self, rhs: &BooleanChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + arity::unary_mut_values(rhs, |arr| arr.tot_ne_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn not_equal_missing(&self, rhs: &BooleanChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + arity::unary_mut_with_options(self, |arr| { + arr.tot_ne_missing_kernel_broadcast(&value).into() + }) + } else { + self.is_not_null() + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + arity::unary_mut_with_options(rhs, |arr| { + arr.tot_ne_missing_kernel_broadcast(&value).into() + }) + } else { + rhs.is_not_null() + } + }, + _ => arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_ne_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } +} + +impl ChunkCompareIneq<&BooleanChunked> for BooleanChunked { + type Item = BooleanChunked; + + fn lt(&self, rhs: &BooleanChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + arity::unary_mut_values(rhs, |arr| arr.tot_gt_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_lt_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn lt_eq(&self, rhs: &BooleanChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + arity::unary_mut_values(rhs, |arr| arr.tot_ge_kernel_broadcast(&value).into()) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_le_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn gt(&self, rhs: &Self) -> BooleanChunked { + rhs.lt(self) + } + + fn gt_eq(&self, rhs: &Self) -> BooleanChunked { + rhs.lt_eq(self) + } +} + +impl ChunkCompareEq<&StringChunked> for StringChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().equal(&rhs.as_binary()) + } + + fn equal_missing(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().equal_missing(&rhs.as_binary()) + } + + fn not_equal(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().not_equal(&rhs.as_binary()) + } + + fn not_equal_missing(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().not_equal_missing(&rhs.as_binary()) + } +} + +impl ChunkCompareIneq<&StringChunked> for StringChunked { + type Item = BooleanChunked; + + fn gt(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().gt(&rhs.as_binary()) + } + + fn gt_eq(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().gt_eq(&rhs.as_binary()) + } + + fn lt(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().lt(&rhs.as_binary()) + } + + fn lt_eq(&self, rhs: &StringChunked) -> BooleanChunked { + self.as_binary().lt_eq(&rhs.as_binary()) + } +} + +impl ChunkCompareEq<&BinaryChunked> for BinaryChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &BinaryChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn equal_missing(&self, rhs: &BinaryChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.equal_missing(value) + } else { + self.is_null() + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.equal_missing(value) + } else { + rhs.is_null() + } + }, + _ => arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_eq_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn not_equal(&self, rhs: &BinaryChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.not_equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.not_equal(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn not_equal_missing(&self, rhs: &BinaryChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.not_equal_missing(value) + } else { + self.is_not_null() + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.not_equal_missing(value) + } else { + rhs.is_not_null() + } + }, + _ => arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_ne_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } +} + +impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked { + type Item = BooleanChunked; + + fn lt(&self, rhs: &BinaryChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.lt(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.gt(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_lt_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn lt_eq(&self, rhs: &BinaryChunked) -> BooleanChunked { + // Broadcast. + match (self.len(), rhs.len()) { + (_, 1) => { + if let Some(value) = rhs.get(0) { + self.lt_eq(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) + } + }, + (1, _) => { + if let Some(value) = self.get(0) { + rhs.gt_eq(value) + } else { + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) + } + }, + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_le_kernel(b).into(), + PlSmallStr::EMPTY, + ), + } + } + + fn gt(&self, rhs: &Self) -> BooleanChunked { + rhs.lt(self) + } + + fn gt_eq(&self, rhs: &Self) -> BooleanChunked { + rhs.lt_eq(self) + } +} + +fn _list_comparison_helper( + lhs: &ListChunked, + rhs: &ListChunked, + op: F, + broadcast_op: B, + missing: bool, + is_ne: bool, +) -> BooleanChunked +where + F: Fn(&ListArray, &ListArray) -> Bitmap, + B: Fn(&ListArray, &Box) -> Bitmap, +{ + match (lhs.len(), rhs.len()) { + (_, 1) => { + let right = rhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + + if !right.validity().is_none_or(|v| v.get(0).unwrap()) { + if missing { + if is_ne { + return lhs.is_not_null(); + } else { + return lhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, lhs.len()); + } + } + + let values = right.values().sliced( + (*right.offsets().first()).try_into().unwrap(), + right.offsets().range().try_into().unwrap(), + ); + + if missing { + arity::unary_mut_with_options(lhs, |a| broadcast_op(a, &values).into()) + } else { + arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into()) + } + }, + (1, _) => { + let left = lhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + + if !left.validity().is_none_or(|v| v.get(0).unwrap()) { + if missing { + if is_ne { + return rhs.is_not_null(); + } else { + return rhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()); + } + } + + let values = left.values().sliced( + (*left.offsets().first()).try_into().unwrap(), + left.offsets().range().try_into().unwrap(), + ); + + if missing { + arity::unary_mut_with_options(rhs, |a| broadcast_op(a, &values).into()) + } else { + arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into()) + } + }, + _ => { + if missing { + arity::binary_mut_with_options(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } else { + arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } + }, + } +} + +impl ChunkCompareEq<&ListChunked> for ListChunked { + type Item = BooleanChunked; + fn equal(&self, rhs: &ListChunked) -> BooleanChunked { + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_kernel, + TotalEqKernel::tot_eq_kernel_broadcast, + false, + false, + ) + } + + fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_missing_kernel, + TotalEqKernel::tot_eq_missing_kernel_broadcast, + true, + false, + ) + } + + fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked { + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_kernel, + TotalEqKernel::tot_ne_kernel_broadcast, + false, + true, + ) + } + + fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_missing_kernel, + TotalEqKernel::tot_ne_missing_kernel_broadcast, + true, + true, + ) + } +} + +#[cfg(feature = "dtype-struct")] +fn struct_helper( + a: &StructChunked, + b: &StructChunked, + op: F, + reduce: R, + op_is_ne: bool, + is_missing: bool, +) -> BooleanChunked +where + F: Fn(&Series, &Series) -> BooleanChunked, + R: Fn(BooleanChunked, BooleanChunked) -> BooleanChunked, +{ + let len_a = a.len(); + let len_b = b.len(); + let broadcasts = len_a == 1 || len_b == 1; + if (a.len() != b.len() && !broadcasts) || a.struct_fields().len() != b.struct_fields().len() { + BooleanChunked::full(PlSmallStr::EMPTY, op_is_ne, a.len()) + } else { + let (a, b) = align_chunks_binary(a, b); + + let mut out = a + .fields_as_series() + .iter() + .zip(b.fields_as_series().iter()) + .map(|(l, r)| op(l, r)) + .reduce(&reduce) + .unwrap_or_else(|| BooleanChunked::full(PlSmallStr::EMPTY, !op_is_ne, a.len())); + + if is_missing && (a.has_nulls() || b.has_nulls()) { + // Do some allocations so that we can use the Series dispatch, it otherwise + // gets complicated dealing with combinations of ==, != and broadcasting. + let default = || { + BooleanChunked::with_chunk(PlSmallStr::EMPTY, BooleanArray::from_slice([true])) + .into_series() + }; + let validity_to_series = |x| unsafe { + BooleanChunked::with_chunk( + PlSmallStr::EMPTY, + BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None), + ) + .into_series() + }; + + out = reduce( + out, + op( + &a.rechunk_validity() + .map_or_else(default, validity_to_series), + &b.rechunk_validity() + .map_or_else(default, validity_to_series), + ), + ) + } + + if !is_missing && (a.null_count() > 0 || b.null_count() > 0) { + let mut a = a; + let mut b = b; + + if broadcasts { + if a.len() == 1 { + a = std::borrow::Cow::Owned(a.new_from_index(0, b.len())); + } + if b.len() == 1 { + b = std::borrow::Cow::Owned(b.new_from_index(0, a.len())); + } + } + + let mut a = a.into_owned(); + a.zip_outer_validity(&b); + unsafe { + for (arr, a) in out.downcast_iter_mut().zip(a.downcast_iter()) { + arr.set_validity(a.validity().cloned()) + } + } + } + + out + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkCompareEq<&StructChunked> for StructChunked { + type Item = BooleanChunked; + fn equal(&self, rhs: &StructChunked) -> BooleanChunked { + struct_helper( + self, + rhs, + |l, r| l.equal_missing(r).unwrap(), + |a, b| a.bitand(b), + false, + false, + ) + } + + fn equal_missing(&self, rhs: &StructChunked) -> BooleanChunked { + struct_helper( + self, + rhs, + |l, r| l.equal_missing(r).unwrap(), + |a, b| a.bitand(b), + false, + true, + ) + } + + fn not_equal(&self, rhs: &StructChunked) -> BooleanChunked { + struct_helper( + self, + rhs, + |l, r| l.not_equal_missing(r).unwrap(), + |a, b| a | b, + true, + false, + ) + } + + fn not_equal_missing(&self, rhs: &StructChunked) -> BooleanChunked { + struct_helper( + self, + rhs, + |l, r| l.not_equal_missing(r).unwrap(), + |a, b| a | b, + true, + true, + ) + } +} + +#[cfg(feature = "dtype-array")] +fn _array_comparison_helper( + lhs: &ArrayChunked, + rhs: &ArrayChunked, + op: F, + broadcast_op: B, + missing: bool, + is_ne: bool, +) -> BooleanChunked +where + F: Fn(&FixedSizeListArray, &FixedSizeListArray) -> Bitmap, + B: Fn(&FixedSizeListArray, &Box) -> Bitmap, +{ + match (lhs.len(), rhs.len()) { + (_, 1) => { + let right = rhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + if !right.validity().is_none_or(|v| v.get(0).unwrap()) { + if missing { + if is_ne { + return lhs.is_not_null(); + } else { + return lhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, lhs.len()); + } + } + + if missing { + arity::unary_mut_with_options(lhs, |a| broadcast_op(a, right.values()).into()) + } else { + arity::unary_mut_values(lhs, |a| broadcast_op(a, right.values()).into()) + } + }, + (1, _) => { + let left = lhs + .downcast_iter() + .find(|x| !x.is_empty()) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + if !left.validity().is_none_or(|v| v.get(0).unwrap()) { + if missing { + if is_ne { + return rhs.is_not_null(); + } else { + return rhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()); + } + } + + if missing { + arity::unary_mut_with_options(rhs, |a| broadcast_op(a, left.values()).into()) + } else { + arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into()) + } + }, + _ => { + if missing { + arity::binary_mut_with_options(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } else { + arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY) + } + }, + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { + type Item = BooleanChunked; + fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked { + _array_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_kernel, + TotalEqKernel::tot_eq_kernel_broadcast, + false, + false, + ) + } + + fn equal_missing(&self, rhs: &ArrayChunked) -> BooleanChunked { + _array_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_missing_kernel, + TotalEqKernel::tot_eq_missing_kernel_broadcast, + true, + false, + ) + } + + fn not_equal(&self, rhs: &ArrayChunked) -> BooleanChunked { + _array_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_kernel, + TotalEqKernel::tot_ne_kernel_broadcast, + false, + true, + ) + } + + fn not_equal_missing(&self, rhs: &ArrayChunked) -> Self::Item { + _array_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_missing_kernel, + TotalEqKernel::tot_ne_missing_kernel_broadcast, + true, + true, + ) + } +} + +impl Not for &BooleanChunked { + type Output = BooleanChunked; + + fn not(self) -> Self::Output { + let chunks = self.downcast_iter().map(compute::boolean::not); + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) + } +} + +impl Not for BooleanChunked { + type Output = BooleanChunked; + + fn not(self) -> Self::Output { + (&self).not() + } +} + +impl BooleanChunked { + /// Returns whether any of the values in the column are `true`. + /// + /// Null values are ignored. + pub fn any(&self) -> bool { + self.downcast_iter().any(compute::boolean::any) + } + + /// Returns whether all values in the array are `true`. + /// + /// Null values are ignored. + pub fn all(&self) -> bool { + self.downcast_iter().all(compute::boolean::all) + } + + /// Returns whether any of the values in the column are `true`. + /// + /// The output is unknown (`None`) if the array contains any null values and + /// no `true` values. + pub fn any_kleene(&self) -> Option { + let mut result = Some(false); + for arr in self.downcast_iter() { + match compute::boolean_kleene::any(arr) { + Some(true) => return Some(true), + None => result = None, + _ => (), + }; + } + result + } + + /// Returns whether all values in the column are `true`. + /// + /// The output is unknown (`None`) if the array contains any null values and + /// no `false` values. + pub fn all_kleene(&self) -> Option { + let mut result = Some(true); + for arr in self.downcast_iter() { + match compute::boolean_kleene::all(arr) { + Some(false) => return Some(false), + None => result = None, + _ => (), + }; + } + result + } +} + +// private +pub(crate) trait ChunkEqualElement { + /// Only meant for physical types. + /// Check if element in self is equal to element in other, assumes same dtypes + /// + /// # Safety + /// + /// No type checks. + unsafe fn equal_element(&self, _idx_self: usize, _idx_other: usize, _other: &Series) -> bool { + unimplemented!() + } +} + +impl ChunkEqualElement for ChunkedArray +where + T: PolarsNumericType, +{ + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + let ca_other = other.as_ref().as_ref(); + debug_assert!(self.dtype() == other.dtype()); + let ca_other = &*(ca_other as *const ChunkedArray); + // Should be get and not get_unchecked, because there could be nulls + self.get_unchecked(idx_self) + .tot_eq(&ca_other.get_unchecked(idx_other)) + } +} + +impl ChunkEqualElement for BooleanChunked { + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + let ca_other = other.as_ref().as_ref(); + debug_assert!(self.dtype() == other.dtype()); + let ca_other = &*(ca_other as *const BooleanChunked); + self.get_unchecked(idx_self) == ca_other.get_unchecked(idx_other) + } +} + +impl ChunkEqualElement for StringChunked { + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + let ca_other = other.as_ref().as_ref(); + debug_assert!(self.dtype() == other.dtype()); + let ca_other = &*(ca_other as *const StringChunked); + self.get_unchecked(idx_self) == ca_other.get_unchecked(idx_other) + } +} + +impl ChunkEqualElement for BinaryChunked { + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + let ca_other = other.as_ref().as_ref(); + debug_assert!(self.dtype() == other.dtype()); + let ca_other = &*(ca_other as *const BinaryChunked); + self.get_unchecked(idx_self) == ca_other.get_unchecked(idx_other) + } +} + +impl ChunkEqualElement for BinaryOffsetChunked { + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + let ca_other = other.as_ref().as_ref(); + debug_assert!(self.dtype() == other.dtype()); + let ca_other = &*(ca_other as *const BinaryOffsetChunked); + self.get_unchecked(idx_self) == ca_other.get_unchecked(idx_other) + } +} + +impl ChunkEqualElement for ListChunked {} +#[cfg(feature = "dtype-array")] +impl ChunkEqualElement for ArrayChunked {} + +#[cfg(test)] +#[cfg_attr(feature = "nightly", allow(clippy::manual_repeat_n))] // remove once stable +mod test { + use std::iter::repeat_n; + + use super::super::test::get_chunked_array; + use crate::prelude::*; + + pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { + let mut a1 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let a2 = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]); + let a3 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3, 4, 5, 6]); + a1.append(&a2).unwrap(); + (a1, a3) + } + + #[test] + fn test_bitwise_ops() { + let a = BooleanChunked::new(PlSmallStr::from_static("a"), &[true, false, false]); + let b = BooleanChunked::new( + PlSmallStr::from_static("b"), + &[Some(true), Some(true), None], + ); + assert_eq!(Vec::from(&a | &b), &[Some(true), Some(true), None]); + assert_eq!(Vec::from(&a & &b), &[Some(true), Some(false), Some(false)]); + assert_eq!(Vec::from(!b), &[Some(false), Some(false), None]); + } + + #[test] + fn test_compare_chunk_diff() { + let (a1, a2) = create_two_chunked(); + + assert_eq!( + a1.equal(&a2).into_iter().collect::>(), + repeat_n(Some(true), 6).collect::>() + ); + assert_eq!( + a2.equal(&a1).into_iter().collect::>(), + repeat_n(Some(true), 6).collect::>() + ); + assert_eq!( + a1.not_equal(&a2).into_iter().collect::>(), + repeat_n(Some(false), 6).collect::>() + ); + assert_eq!( + a2.not_equal(&a1).into_iter().collect::>(), + repeat_n(Some(false), 6).collect::>() + ); + assert_eq!( + a1.gt(&a2).into_iter().collect::>(), + repeat_n(Some(false), 6).collect::>() + ); + assert_eq!( + a2.gt(&a1).into_iter().collect::>(), + repeat_n(Some(false), 6).collect::>() + ); + assert_eq!( + a1.gt_eq(&a2).into_iter().collect::>(), + repeat_n(Some(true), 6).collect::>() + ); + assert_eq!( + a2.gt_eq(&a1).into_iter().collect::>(), + repeat_n(Some(true), 6).collect::>() + ); + assert_eq!( + a1.lt_eq(&a2).into_iter().collect::>(), + repeat_n(Some(true), 6).collect::>() + ); + assert_eq!( + a2.lt_eq(&a1).into_iter().collect::>(), + repeat_n(Some(true), 6).collect::>() + ); + assert_eq!( + a1.lt(&a2).into_iter().collect::>(), + repeat_n(Some(false), 6).collect::>() + ); + assert_eq!( + a2.lt(&a1).into_iter().collect::>(), + repeat_n(Some(false), 6).collect::>() + ); + } + + #[test] + fn test_equal_chunks() { + let a1 = get_chunked_array(); + let a2 = get_chunked_array(); + + assert_eq!( + a1.equal(&a2).into_iter().collect::>(), + repeat_n(Some(true), 3).collect::>() + ); + assert_eq!( + a2.equal(&a1).into_iter().collect::>(), + repeat_n(Some(true), 3).collect::>() + ); + assert_eq!( + a1.not_equal(&a2).into_iter().collect::>(), + repeat_n(Some(false), 3).collect::>() + ); + assert_eq!( + a2.not_equal(&a1).into_iter().collect::>(), + repeat_n(Some(false), 3).collect::>() + ); + assert_eq!( + a1.gt(&a2).into_iter().collect::>(), + repeat_n(Some(false), 3).collect::>() + ); + assert_eq!( + a2.gt(&a1).into_iter().collect::>(), + repeat_n(Some(false), 3).collect::>() + ); + assert_eq!( + a1.gt_eq(&a2).into_iter().collect::>(), + repeat_n(Some(true), 3).collect::>() + ); + assert_eq!( + a2.gt_eq(&a1).into_iter().collect::>(), + repeat_n(Some(true), 3).collect::>() + ); + assert_eq!( + a1.lt_eq(&a2).into_iter().collect::>(), + repeat_n(Some(true), 3).collect::>() + ); + assert_eq!( + a2.lt_eq(&a1).into_iter().collect::>(), + repeat_n(Some(true), 3).collect::>() + ); + assert_eq!( + a1.lt(&a2).into_iter().collect::>(), + repeat_n(Some(false), 3).collect::>() + ); + assert_eq!( + a2.lt(&a1).into_iter().collect::>(), + repeat_n(Some(false), 3).collect::>() + ); + } + + #[test] + fn test_null_handling() { + // assert we comply with arrows way of handling null data + // we check comparison on two arrays with one chunk and verify it is equal to a differently + // chunked array comparison. + + // two same chunked arrays + let a1: Int32Chunked = [Some(1), None, Some(3)].iter().copied().collect(); + let a2: Int32Chunked = [Some(1), Some(2), Some(3)].iter().copied().collect(); + + let mut a2_2chunks: Int32Chunked = [Some(1), Some(2)].iter().copied().collect(); + a2_2chunks + .append(&[Some(3)].iter().copied().collect()) + .unwrap(); + + assert_eq!( + a1.equal(&a2).into_iter().collect::>(), + a1.equal(&a2_2chunks).into_iter().collect::>() + ); + + assert_eq!( + a1.not_equal(&a2).into_iter().collect::>(), + a1.not_equal(&a2_2chunks).into_iter().collect::>() + ); + assert_eq!( + a1.not_equal(&a2).into_iter().collect::>(), + a2_2chunks.not_equal(&a1).into_iter().collect::>() + ); + + assert_eq!( + a1.gt(&a2).into_iter().collect::>(), + a1.gt(&a2_2chunks).into_iter().collect::>() + ); + assert_eq!( + a1.gt(&a2).into_iter().collect::>(), + a2_2chunks.gt(&a1).into_iter().collect::>() + ); + + assert_eq!( + a1.gt_eq(&a2).into_iter().collect::>(), + a1.gt_eq(&a2_2chunks).into_iter().collect::>() + ); + assert_eq!( + a1.gt_eq(&a2).into_iter().collect::>(), + a2_2chunks.gt_eq(&a1).into_iter().collect::>() + ); + + assert_eq!( + a1.lt_eq(&a2).into_iter().collect::>(), + a1.lt_eq(&a2_2chunks).into_iter().collect::>() + ); + assert_eq!( + a1.lt_eq(&a2).into_iter().collect::>(), + a2_2chunks.lt_eq(&a1).into_iter().collect::>() + ); + + assert_eq!( + a1.lt(&a2).into_iter().collect::>(), + a1.lt(&a2_2chunks).into_iter().collect::>() + ); + assert_eq!( + a1.lt(&a2).into_iter().collect::>(), + a2_2chunks.lt(&a1).into_iter().collect::>() + ); + } + + #[test] + fn test_left_right() { + // This failed with arrow comparisons. + // sliced + let a1: Int32Chunked = [Some(1), Some(2)].iter().copied().collect(); + let a1 = a1.slice(1, 1); + let a2: Int32Chunked = [Some(2)].iter().copied().collect(); + assert_eq!(a1.equal(&a2).sum(), a2.equal(&a1).sum()); + assert_eq!(a1.not_equal(&a2).sum(), a2.not_equal(&a1).sum()); + assert_eq!(a1.gt(&a2).sum(), a2.gt(&a1).sum()); + assert_eq!(a1.lt(&a2).sum(), a2.lt(&a1).sum()); + assert_eq!(a1.lt_eq(&a2).sum(), a2.lt_eq(&a1).sum()); + assert_eq!(a1.gt_eq(&a2).sum(), a2.gt_eq(&a1).sum()); + + let a1: StringChunked = ["a", "b"].iter().copied().collect(); + let a1 = a1.slice(1, 1); + let a2: StringChunked = ["b"].iter().copied().collect(); + assert_eq!(a1.equal(&a2).sum(), a2.equal(&a1).sum()); + assert_eq!(a1.not_equal(&a2).sum(), a2.not_equal(&a1).sum()); + assert_eq!(a1.gt(&a2).sum(), a2.gt(&a1).sum()); + assert_eq!(a1.lt(&a2).sum(), a2.lt(&a1).sum()); + assert_eq!(a1.lt_eq(&a2).sum(), a2.lt_eq(&a1).sum()); + assert_eq!(a1.gt_eq(&a2).sum(), a2.gt_eq(&a1).sum()); + } + + #[test] + fn test_kleene() { + let a = BooleanChunked::new(PlSmallStr::EMPTY, &[Some(true), Some(false), None]); + let trues = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true, true, true]); + let falses = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false, false, false]); + + let c = &a | &trues; + assert_eq!(Vec::from(&c), &[Some(true), Some(true), Some(true)]); + + let c = &a | &falses; + assert_eq!(Vec::from(&c), &[Some(true), Some(false), None]) + } + + #[test] + fn list_broadcasting_lists() { + let s_el = Series::new(PlSmallStr::EMPTY, &[1, 2, 3]); + let s_lhs = Series::new(PlSmallStr::EMPTY, &[s_el.clone(), s_el.clone()]); + let s_rhs = Series::new(PlSmallStr::EMPTY, &[s_el.clone()]); + + let result = s_lhs.list().unwrap().equal(s_rhs.list().unwrap()); + assert_eq!(result.len(), 2); + assert!(result.all()); + } + + #[test] + fn test_broadcasting_bools() { + let a = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true, false, true]); + let true_ = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true]); + let false_ = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false]); + + let out = a.equal(&true_); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + let out = true_.equal(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + let out = a.equal(&false_); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + let out = false_.equal(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + + let out = a.not_equal(&true_); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + let out = true_.not_equal(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + let out = a.not_equal(&false_); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + let out = false_.not_equal(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + + let out = a.gt(&true_); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + let out = true_.gt(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + let out = a.gt(&false_); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + let out = false_.gt(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + + let out = a.gt_eq(&true_); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + let out = true_.gt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + let out = a.gt_eq(&false_); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + let out = false_.gt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + + let out = a.lt(&true_); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + let out = true_.lt(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + let out = a.lt(&false_); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + let out = false_.lt(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + + let out = a.lt_eq(&true_); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + let out = true_.lt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); + let out = a.lt_eq(&false_); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(false)]); + let out = false_.lt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + + let a = + BooleanChunked::from_slice_options(PlSmallStr::EMPTY, &[Some(true), Some(false), None]); + let all_true = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true, true, true]); + let all_false = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false, false, false]); + let out = a.equal(&true_); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), None]); + let out = a.not_equal(&true_); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), None]); + + let out = a.equal(&all_true); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), None]); + let out = a.not_equal(&all_true); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), None]); + let out = a.equal(&false_); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), None]); + let out = a.not_equal(&false_); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), None]); + let out = a.equal(&all_false); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), None]); + let out = a.not_equal(&all_false); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), None]); + } + + #[test] + fn test_broadcasting_numeric() { + let a = Int32Chunked::from_slice(PlSmallStr::EMPTY, &[1, 2, 3]); + let one = Int32Chunked::from_slice(PlSmallStr::EMPTY, &[1]); + let three = Int32Chunked::from_slice(PlSmallStr::EMPTY, &[3]); + + let out = a.equal(&one); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); + let out = one.equal(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); + let out = a.equal(&three); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(true)]); + let out = three.equal(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(true)]); + + let out = a.not_equal(&one); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(true)]); + let out = one.not_equal(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(true)]); + let out = a.not_equal(&three); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(false)]); + let out = three.not_equal(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(false)]); + + let out = a.gt(&one); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(true)]); + let out = one.gt(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + let out = a.gt(&three); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + let out = three.gt(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(false)]); + + let out = a.lt(&one); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + let out = one.lt(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(true)]); + let out = a.lt(&three); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(false)]); + let out = three.lt(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(false)]); + + let out = a.gt_eq(&one); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + let out = one.gt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); + let out = a.gt_eq(&three); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(true)]); + let out = three.gt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + + let out = a.lt_eq(&one); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); + let out = one.lt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + let out = a.lt_eq(&three); + assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); + let out = three.lt_eq(&a); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(true)]); + } +} diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs new file mode 100644 index 000000000000..86397de18aa8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -0,0 +1,337 @@ +use super::*; + +#[derive(Clone, Copy)] +enum CmpOp { + Lt, + Le, + Gt, + Ge, +} + +// Given two monotonic functions f_a and f_d where f_a is ascending +// (f_a(x[0]) <= f_a(x[1]) <= .. <= f_a(x[n-1])) and f_d is descending +// (f_d(x[0]) >= f_d(x[1]) >= .. >= f_d(x[n-1])), +// outputs a mask where both are true. +// +// If a function is not given it is always assumed to be true. If invert is +// true the output mask is inverted. +fn bitonic_mask( + ca: &ChunkedArray, + f_a: Option, + f_d: Option, + rhs: &T::Native, + invert: bool, +) -> BooleanChunked { + fn apply(op: CmpOp, x: T::Native, rhs: &T::Native) -> bool { + match op { + CmpOp::Lt => x.tot_lt(rhs), + CmpOp::Le => x.tot_le(rhs), + CmpOp::Gt => x.tot_gt(rhs), + CmpOp::Ge => x.tot_ge(rhs), + } + } + let mut output_order: Option = None; + let mut last_value: Option = None; + let mut logical_extend = |len: usize, val: bool| { + if len != 0 { + if let Some(last_value) = last_value { + output_order = match (last_value, val, output_order) { + (false, true, None) => Some(IsSorted::Ascending), + (false, true, _) => Some(IsSorted::Not), + (true, false, None) => Some(IsSorted::Descending), + (true, false, _) => Some(IsSorted::Not), + _ => output_order, + }; + } + last_value = Some(val); + } + }; + + let chunks = ca.downcast_iter().map(|arr| { + let values = arr.values(); + let true_range_start = if let Some(f_a) = f_a { + values.partition_point(|x| !apply::(f_a, *x, rhs)) + } else { + 0 + }; + let true_range_end = if let Some(f_d) = f_d { + true_range_start + + values[true_range_start..].partition_point(|x| apply::(f_d, *x, rhs)) + } else { + values.len() + }; + let mut mask = BitmapBuilder::with_capacity(arr.len()); + mask.extend_constant(true_range_start, invert); + mask.extend_constant(true_range_end - true_range_start, !invert); + mask.extend_constant(arr.len() - true_range_end, invert); + logical_extend(true_range_start, invert); + logical_extend(true_range_end - true_range_start, !invert); + logical_extend(arr.len() - true_range_end, invert); + BooleanArray::from_data_default(mask.freeze(), None) + }); + + let mut ca = BooleanChunked::from_chunk_iter(ca.name().clone(), chunks); + ca.set_sorted_flag(output_order.unwrap_or(IsSorted::Ascending)); + ca +} + +impl ChunkCompareEq for ChunkedArray +where + T: PolarsNumericType, + Rhs: ToPrimitive, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; + + fn equal(&self, rhs: Rhs) -> BooleanChunked { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + let fa = Some(CmpOp::Ge); + let fd = Some(CmpOp::Le); + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), + _ => arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&rhs).into()), + } + } + + fn equal_missing(&self, rhs: Rhs) -> BooleanChunked { + if self.null_count() == 0 { + self.equal(rhs) + } else { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + arity::unary_mut_with_options(self, |arr| { + arr.tot_eq_missing_kernel_broadcast(&rhs).into() + }) + } + } + + fn not_equal(&self, rhs: Rhs) -> BooleanChunked { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + let fa = Some(CmpOp::Ge); + let fd = Some(CmpOp::Le); + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, true), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, true), + _ => arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&rhs).into()), + } + } + + fn not_equal_missing(&self, rhs: Rhs) -> BooleanChunked { + if self.null_count() == 0 { + self.not_equal(rhs) + } else { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + arity::unary_mut_with_options(self, |arr| { + arr.tot_ne_missing_kernel_broadcast(&rhs).into() + }) + } + } +} + +impl ChunkCompareIneq for ChunkedArray +where + T: PolarsNumericType, + Rhs: ToPrimitive, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; + + fn gt(&self, rhs: Rhs) -> BooleanChunked { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + let fa = Some(CmpOp::Gt); + let fd = None; + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), + _ => arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(&rhs).into()), + } + } + + fn gt_eq(&self, rhs: Rhs) -> BooleanChunked { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + let fa = Some(CmpOp::Ge); + let fd = None; + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), + _ => arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(&rhs).into()), + } + } + + fn lt(&self, rhs: Rhs) -> BooleanChunked { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + let fa = None; + let fd = Some(CmpOp::Lt); + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), + _ => arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&rhs).into()), + } + } + + fn lt_eq(&self, rhs: Rhs) -> BooleanChunked { + let rhs: T::Native = NumCast::from(rhs).unwrap(); + let fa = None; + let fd = Some(CmpOp::Le); + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), + _ => arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&rhs).into()), + } + } +} + +impl ChunkCompareEq<&[u8]> for BinaryChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into()) + } + + fn equal_missing(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into()) + } + + fn not_equal(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into()) + } + + fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into()) + } +} + +impl ChunkCompareIneq<&[u8]> for BinaryChunked { + type Item = BooleanChunked; + + fn gt(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into()) + } + + fn gt_eq(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into()) + } + + fn lt(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into()) + } + + fn lt_eq(&self, rhs: &[u8]) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into()) + } +} + +impl ChunkCompareEq<&str> for StringChunked { + type Item = BooleanChunked; + + fn equal(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into()) + } + + fn equal_missing(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into()) + } + + fn not_equal(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into()) + } + + fn not_equal_missing(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into()) + } +} + +impl ChunkCompareIneq<&str> for StringChunked { + type Item = BooleanChunked; + + fn gt(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into()) + } + + fn gt_eq(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into()) + } + + fn lt(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into()) + } + + fn lt_eq(&self, rhs: &str) -> BooleanChunked { + arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_binary_search_cmp() { + let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]); + s.set_sorted_flag(IsSorted::Ascending); + let out = s.gt(10).unwrap(); + assert!(!out.any()); + + let out = s.gt(0).unwrap(); + assert!(out.all()); + + let out = s.gt(2).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true]) + ); + let out = s.gt(3).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true]) + ); + + let out = s.gt_eq(10).unwrap(); + assert!(!out.any()); + let out = s.gt_eq(0).unwrap(); + assert!(out.all()); + + let out = s.gt_eq(2).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [false, false, true, true, true, true]) + ); + let out = s.gt_eq(3).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true]) + ); + + let out = s.lt(10).unwrap(); + assert!(out.all()); + let out = s.lt(0).unwrap(); + assert!(!out.any()); + + let out = s.lt(2).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [true, true, false, false, false, false]) + ); + let out = s.lt(3).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false]) + ); + + let out = s.lt_eq(10).unwrap(); + assert!(out.all()); + let out = s.lt_eq(0).unwrap(); + assert!(!out.any()); + + let out = s.lt_eq(2).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false]) + ); + let out = s.lt(3).unwrap(); + assert_eq!( + out.into_series(), + Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false]) + ); + } +} diff --git a/crates/polars-core/src/chunked_array/drop.rs b/crates/polars-core/src/chunked_array/drop.rs new file mode 100644 index 000000000000..f8276b78626f --- /dev/null +++ b/crates/polars-core/src/chunked_array/drop.rs @@ -0,0 +1,24 @@ +use crate::chunked_array::object::extension::drop::drop_list; +use crate::prelude::*; + +#[inline(never)] +#[cold] +fn drop_slow(ca: &mut ChunkedArray) { + // SAFETY: + // guarded by the type system + // the transmute only convinces the type system that we are a list + #[allow(clippy::transmute_undefined_repr)] + unsafe { + drop_list(std::mem::transmute::<&mut ChunkedArray, &ListChunked>( + ca, + )) + } +} + +impl Drop for ChunkedArray { + fn drop(&mut self) { + if matches!(self.dtype(), DataType::List(_)) { + drop_slow(self); + } + } +} diff --git a/crates/polars-core/src/chunked_array/flags.rs b/crates/polars-core/src/chunked_array/flags.rs new file mode 100644 index 000000000000..73522fe3baff --- /dev/null +++ b/crates/polars-core/src/chunked_array/flags.rs @@ -0,0 +1,116 @@ +use std::sync::atomic::{AtomicU32, Ordering}; + +use crate::series::IsSorted; + +/// An interior mutable version of [`StatisticsFlags`] +pub struct StatisticsFlagsIM { + inner: AtomicU32, +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + pub struct StatisticsFlags: u32 { + const IS_SORTED_ANY = 0x03; + + const IS_SORTED_ASC = 0x01; + const IS_SORTED_DSC = 0x02; + const CAN_FAST_EXPLODE_LIST = 0x04; + } +} + +impl std::fmt::Debug for StatisticsFlagsIM { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("ChunkedArrayFlagsIM") + .field(&self.get()) + .finish() + } +} + +impl Clone for StatisticsFlagsIM { + fn clone(&self) -> Self { + Self::new(self.get()) + } +} + +impl PartialEq for StatisticsFlagsIM { + fn eq(&self, other: &Self) -> bool { + self.get() == other.get() + } +} +impl Eq for StatisticsFlagsIM {} + +impl From for StatisticsFlagsIM { + fn from(value: StatisticsFlags) -> Self { + Self { + inner: AtomicU32::new(value.bits()), + } + } +} + +impl StatisticsFlagsIM { + pub fn new(value: StatisticsFlags) -> Self { + Self { + inner: AtomicU32::new(value.bits()), + } + } + + pub fn empty() -> Self { + Self::new(StatisticsFlags::empty()) + } + + pub fn get_mut(&mut self) -> StatisticsFlags { + StatisticsFlags::from_bits(*self.inner.get_mut()).unwrap() + } + pub fn set_mut(&mut self, value: StatisticsFlags) { + *self.inner.get_mut() = value.bits(); + } + + pub fn get(&self) -> StatisticsFlags { + StatisticsFlags::from_bits(self.inner.load(Ordering::Relaxed)).unwrap() + } + pub fn set(&self, value: StatisticsFlags) { + self.inner.store(value.bits(), Ordering::Relaxed); + } +} + +impl StatisticsFlags { + pub fn is_sorted(&self) -> IsSorted { + let is_sorted_asc = self.contains(Self::IS_SORTED_ASC); + let is_sorted_dsc = self.contains(Self::IS_SORTED_DSC); + + assert!(!is_sorted_asc || !is_sorted_dsc); + + if is_sorted_asc { + IsSorted::Ascending + } else if is_sorted_dsc { + IsSorted::Descending + } else { + IsSorted::Not + } + } + + pub fn set_sorted(&mut self, is_sorted: IsSorted) { + let is_sorted = match is_sorted { + IsSorted::Not => Self::empty(), + IsSorted::Ascending => Self::IS_SORTED_ASC, + IsSorted::Descending => Self::IS_SORTED_DSC, + }; + self.remove(Self::IS_SORTED_ASC | Self::IS_SORTED_DSC); + self.insert(is_sorted); + } + + pub fn is_sorted_any(&self) -> bool { + self.contains(Self::IS_SORTED_ASC) | self.contains(Self::IS_SORTED_DSC) + } + pub fn is_sorted_ascending(&self) -> bool { + self.contains(Self::IS_SORTED_ASC) + } + pub fn is_sorted_descending(&self) -> bool { + self.contains(Self::IS_SORTED_DSC) + } + + pub fn can_fast_explode_list(&self) -> bool { + self.contains(Self::CAN_FAST_EXPLODE_LIST) + } +} diff --git a/crates/polars-core/src/chunked_array/float.rs b/crates/polars-core/src/chunked_array/float.rs new file mode 100644 index 000000000000..5d9bae240062 --- /dev/null +++ b/crates/polars-core/src/chunked_array/float.rs @@ -0,0 +1,62 @@ +use arrow::legacy::kernels::set::set_at_nulls; +use num_traits::Float; +use polars_utils::total_ord::{canonical_f32, canonical_f64}; + +use crate::prelude::arity::unary_elementwise_values; +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsFloatType, + T::Native: Float, +{ + pub fn is_nan(&self) -> BooleanChunked { + unary_elementwise_values(self, |x| x.is_nan()) + } + pub fn is_not_nan(&self) -> BooleanChunked { + unary_elementwise_values(self, |x| !x.is_nan()) + } + pub fn is_finite(&self) -> BooleanChunked { + unary_elementwise_values(self, |x| x.is_finite()) + } + pub fn is_infinite(&self) -> BooleanChunked { + unary_elementwise_values(self, |x| x.is_infinite()) + } + + #[must_use] + /// Convert missing values to `NaN` values. + pub fn none_to_nan(&self) -> Self { + let chunks = self + .downcast_iter() + .map(|arr| set_at_nulls(arr, T::Native::nan())); + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) + } +} + +pub trait Canonical { + fn canonical(self) -> Self; +} + +impl Canonical for f32 { + #[inline] + fn canonical(self) -> Self { + canonical_f32(self) + } +} + +impl Canonical for f64 { + #[inline] + fn canonical(self) -> Self { + canonical_f64(self) + } +} + +impl ChunkedArray +where + T: PolarsFloatType, + T::Native: Float + Canonical, +{ + pub fn to_canonical(&self) -> Self { + unary_elementwise_values(self, |v| v.canonical()) + } +} diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs new file mode 100644 index 000000000000..61d44d7cea67 --- /dev/null +++ b/crates/polars-core/src/chunked_array/from.rs @@ -0,0 +1,323 @@ +use arrow::compute::concatenate::concatenate_unchecked; + +use super::*; + +#[allow(clippy::all)] +fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataType { + // ensure we don't get List + let dtype = if let Some(arr) = chunks.get(0) { + DataType::from_arrow_dtype(arr.dtype()) + } else { + dtype + }; + + match dtype { + #[cfg(feature = "dtype-categorical")] + // arrow dictionaries are not nested as dictionaries, but only by their keys, so we must + // change the list-value array to the keys and store the dictionary values in the datatype. + // if a global string cache is set, we also must modify the keys. + DataType::List(inner) + if matches!( + *inner, + DataType::Categorical(None, _) | DataType::Enum(None, _) + ) => + { + let array = concatenate_unchecked(chunks).unwrap(); + let list_arr = array.as_any().downcast_ref::>().unwrap(); + let values_arr = list_arr.values(); + let cat = unsafe { + Series::_try_from_arrow_unchecked( + PlSmallStr::EMPTY, + vec![values_arr.clone()], + values_arr.dtype(), + ) + .unwrap() + }; + + // we nest only the physical representation + // the mapping is still in our rev-map + let arrow_dtype = ListArray::::default_datatype(ArrowDataType::UInt32); + let new_array = ListArray::new( + arrow_dtype, + list_arr.offsets().clone(), + cat.array_ref(0).clone(), + list_arr.validity().cloned(), + ); + chunks.clear(); + chunks.push(Box::new(new_array)); + DataType::List(Box::new(cat.dtype().clone())) + }, + #[cfg(all(feature = "dtype-array", feature = "dtype-categorical"))] + DataType::Array(inner, width) + if matches!( + *inner, + DataType::Categorical(None, _) | DataType::Enum(None, _) + ) => + { + let array = concatenate_unchecked(chunks).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + let values_arr = list_arr.values(); + let cat = unsafe { + Series::_try_from_arrow_unchecked( + PlSmallStr::EMPTY, + vec![values_arr.clone()], + values_arr.dtype(), + ) + .unwrap() + }; + + // we nest only the physical representation + // the mapping is still in our rev-map + let arrow_dtype = FixedSizeListArray::default_datatype(ArrowDataType::UInt32, width); + let new_array = FixedSizeListArray::new( + arrow_dtype, + values_arr.len(), + cat.array_ref(0).clone(), + list_arr.validity().cloned(), + ); + chunks.clear(); + chunks.push(Box::new(new_array)); + DataType::Array(Box::new(cat.dtype().clone()), width) + }, + _ => dtype, + } +} + +impl From for ChunkedArray +where + T: PolarsDataType, + A: Array, +{ + fn from(arr: A) -> Self { + Self::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn with_chunk(name: PlSmallStr, arr: A) -> Self + where + A: Array, + T: PolarsDataType, + { + unsafe { Self::from_chunks(name, vec![Box::new(arr)]) } + } + + pub fn with_chunk_like(ca: &Self, arr: A) -> Self + where + A: Array, + T: PolarsDataType, + { + Self::from_chunk_iter_like(ca, std::iter::once(arr)) + } + + pub fn from_chunk_iter(name: PlSmallStr, iter: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + let chunks = iter + .into_iter() + .map(|x| Box::new(x) as Box) + .collect(); + unsafe { Self::from_chunks(name, chunks) } + } + + pub fn from_chunk_iter_like(ca: &Self, iter: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + let chunks = iter + .into_iter() + .map(|x| Box::new(x) as Box) + .collect(); + unsafe { + Self::from_chunks_and_dtype_unchecked(ca.name().clone(), chunks, ca.dtype().clone()) + } + } + + pub fn try_from_chunk_iter(name: PlSmallStr, iter: I) -> Result + where + I: IntoIterator>, + T: PolarsDataType, + A: Array, + { + let chunks: Result<_, _> = iter + .into_iter() + .map(|x| Ok(Box::new(x?) as Box)) + .collect(); + unsafe { Ok(Self::from_chunks(name, chunks?)) } + } + + pub(crate) fn from_chunk_iter_and_field(field: Arc, chunks: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + assert_eq!( + std::mem::discriminant(&T::get_dtype()), + std::mem::discriminant(&field.dtype) + ); + + let mut length = 0; + let mut null_count = 0; + let chunks = chunks + .into_iter() + .map(|x| { + length += x.len(); + null_count += x.null_count(); + Box::new(x) as Box + }) + .collect(); + + unsafe { ChunkedArray::new_with_dims(field, chunks, length, null_count) } + } + + /// Create a new [`ChunkedArray`] from existing chunks. + /// + /// # Safety + /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. + pub unsafe fn from_chunks(name: PlSmallStr, mut chunks: Vec) -> Self { + let dtype = match T::get_dtype() { + dtype @ DataType::List(_) => from_chunks_list_dtype(&mut chunks, dtype), + #[cfg(feature = "dtype-array")] + dtype @ DataType::Array(_, _) => from_chunks_list_dtype(&mut chunks, dtype), + #[cfg(feature = "dtype-struct")] + dtype @ DataType::Struct(_) => from_chunks_list_dtype(&mut chunks, dtype), + dt => dt, + }; + Self::from_chunks_and_dtype(name, chunks, dtype) + } + + /// # Safety + /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. + pub unsafe fn with_chunks(&self, chunks: Vec) -> Self { + ChunkedArray::new_with_compute_len(self.field.clone(), chunks) + } + + /// Create a new [`ChunkedArray`] from existing chunks. + /// + /// # Safety + /// + /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. + pub unsafe fn from_chunks_and_dtype( + name: PlSmallStr, + chunks: Vec, + dtype: DataType, + ) -> Self { + // assertions in debug mode + // that check if the data types in the arrays are as expected + #[cfg(debug_assertions)] + { + if !chunks.is_empty() && !chunks[0].is_empty() && dtype.is_primitive() { + assert_eq!(chunks[0].dtype(), &dtype.to_arrow(CompatLevel::newest())) + } + } + + Self::from_chunks_and_dtype_unchecked(name, chunks, dtype) + } + + /// Create a new [`ChunkedArray`] from existing chunks. + /// + /// # Safety + /// + /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. + pub(crate) unsafe fn from_chunks_and_dtype_unchecked( + name: PlSmallStr, + chunks: Vec, + dtype: DataType, + ) -> Self { + let field = Arc::new(Field::new(name, dtype)); + ChunkedArray::new_with_compute_len(field, chunks) + } + + pub fn full_null_like(ca: &Self, length: usize) -> Self { + let chunks = std::iter::once(T::Array::full_null( + length, + ca.dtype().to_arrow(CompatLevel::newest()), + )); + Self::from_chunk_iter_like(ca, chunks) + } +} + +impl ChunkedArray +where + T: PolarsNumericType, +{ + /// Create a new ChunkedArray by taking ownership of the Vec. This operation is zero copy. + pub fn from_vec(name: PlSmallStr, v: Vec) -> Self { + Self::with_chunk(name, to_primitive::(v, None)) + } + + /// Create a new ChunkedArray from a Vec and a validity mask. + pub fn from_vec_validity( + name: PlSmallStr, + values: Vec, + buffer: Option, + ) -> Self { + let arr = to_array::(values, buffer); + ChunkedArray::new_with_compute_len(Arc::new(Field::new(name, T::get_dtype())), vec![arr]) + } + + /// Create a temporary [`ChunkedArray`] from a slice. + /// + /// # Safety + /// The lifetime will be bound to the lifetime of the slice. + /// This will not be checked by the borrowchecker. + pub unsafe fn mmap_slice(name: PlSmallStr, values: &[T::Native]) -> Self { + Self::with_chunk(name, arrow::ffi::mmap::slice(values)) + } +} + +impl BooleanChunked { + /// Create a temporary [`ChunkedArray`] from a slice. + /// + /// # Safety + /// The lifetime will be bound to the lifetime of the slice. + /// This will not be checked by the borrowchecker. + pub unsafe fn mmap_slice(name: PlSmallStr, values: &[u8], offset: usize, len: usize) -> Self { + let arr = arrow::ffi::mmap::bitmap(values, offset, len).unwrap(); + Self::with_chunk(name, arr) + } + + pub fn from_bitmap(name: PlSmallStr, bitmap: Bitmap) -> Self { + Self::with_chunk( + name, + BooleanArray::new(ArrowDataType::Boolean, bitmap, None), + ) + } +} + +impl<'a, T> From<&'a ChunkedArray> for Vec>> +where + T: PolarsDataType, +{ + fn from(ca: &'a ChunkedArray) -> Self { + let mut out = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + out.extend(arr.iter()) + } + out + } +} +impl From for Vec> { + fn from(ca: StringChunked) -> Self { + ca.iter().map(|opt| opt.map(|s| s.to_string())).collect() + } +} + +impl From for Vec> { + fn from(ca: BooleanChunked) -> Self { + let mut out = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + out.extend(arr.iter()) + } + out + } +} diff --git a/crates/polars-core/src/chunked_array/from_iterator.rs b/crates/polars-core/src/chunked_array/from_iterator.rs new file mode 100644 index 000000000000..f55f2f045ae8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/from_iterator.rs @@ -0,0 +1,268 @@ +//! Implementations of upstream traits for [`ChunkedArray`] +use std::borrow::{Borrow, Cow}; + +#[cfg(feature = "object")] +use arrow::bitmap::BitmapBuilder; + +use crate::chunked_array::builder::{AnonymousOwnedListBuilder, get_list_builder}; +#[cfg(feature = "object")] +use crate::chunked_array::object::ObjectArray; +#[cfg(feature = "object")] +use crate::chunked_array::object::builder::get_object_type; +use crate::prelude::*; +use crate::utils::{NoNull, get_iter_capacity}; + +/// FromIterator trait +impl FromIterator> for ChunkedArray +where + T: PolarsNumericType, +{ + #[inline] + fn from_iter>>(iter: I) -> Self { + // TODO: eliminate this FromIterator implementation entirely. + iter.into_iter().collect_ca(PlSmallStr::EMPTY) + } +} + +// NoNull is only a wrapper needed for specialization +impl FromIterator for NoNull> +where + T: PolarsNumericType, +{ + // We use Vec because it is way faster than Arrows builder. We can do this because we + // know we don't have null values. + #[inline] + fn from_iter>(iter: I) -> Self { + // 2021-02-07: aligned vec was ~2x faster than arrow collect. + let av = iter.into_iter().collect::>(); + NoNull::new(ChunkedArray::from_vec(PlSmallStr::EMPTY, av)) + } +} + +impl FromIterator> for ChunkedArray { + #[inline] + fn from_iter>>(iter: I) -> Self { + BooleanArray::from_iter(iter).into() + } +} + +impl FromIterator for BooleanChunked { + #[inline] + fn from_iter>(iter: I) -> Self { + iter.into_iter().collect_ca(PlSmallStr::EMPTY) + } +} + +impl FromIterator for NoNull { + #[inline] + fn from_iter>(iter: I) -> Self { + NoNull::new(iter.into_iter().collect_ca(PlSmallStr::EMPTY)) + } +} + +// FromIterator for StringChunked variants. + +impl FromIterator> for StringChunked +where + Ptr: AsRef, +{ + #[inline] + fn from_iter>>(iter: I) -> Self { + let arr = MutableBinaryViewArray::from_iterator(iter.into_iter()).freeze(); + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +/// Local [`AsRef`] trait to circumvent the orphan rule. +pub trait PolarsAsRef: AsRef {} + +impl PolarsAsRef for String {} +impl PolarsAsRef for &str {} +// &["foo", "bar"] +impl PolarsAsRef for &&str {} + +impl PolarsAsRef for Cow<'_, str> {} +impl PolarsAsRef<[u8]> for Vec {} +impl PolarsAsRef<[u8]> for &[u8] {} +// TODO: remove! +impl PolarsAsRef<[u8]> for &&[u8] {} +impl PolarsAsRef<[u8]> for Cow<'_, [u8]> {} + +impl FromIterator for StringChunked +where + Ptr: PolarsAsRef, +{ + #[inline] + fn from_iter>(iter: I) -> Self { + let arr = MutableBinaryViewArray::from_values_iter(iter.into_iter()).freeze(); + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +// FromIterator for BinaryChunked variants. +impl FromIterator> for BinaryChunked +where + Ptr: AsRef<[u8]>, +{ + #[inline] + fn from_iter>>(iter: I) -> Self { + let arr = MutableBinaryViewArray::from_iter(iter).freeze(); + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +impl FromIterator for BinaryChunked +where + Ptr: PolarsAsRef<[u8]>, +{ + #[inline] + fn from_iter>(iter: I) -> Self { + let arr = MutableBinaryViewArray::from_values_iter(iter.into_iter()).freeze(); + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +impl FromIterator for ListChunked +where + Ptr: Borrow, +{ + #[inline] + fn from_iter>(iter: I) -> Self { + let mut it = iter.into_iter(); + let capacity = get_iter_capacity(&it); + + // first take one to get the dtype. + let v = match it.next() { + Some(v) => v, + None => return ListChunked::full_null(PlSmallStr::EMPTY, 0), + }; + // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. + let mut builder = get_list_builder( + v.borrow().dtype(), + capacity * 5, + capacity, + PlSmallStr::EMPTY, + ); + + builder.append_series(v.borrow()).unwrap(); + for s in it { + builder.append_series(s.borrow()).unwrap(); + } + builder.finish() + } +} + +impl FromIterator> for ListChunked { + fn from_iter>>(iter: T) -> Self { + ListChunked::from_iter( + iter.into_iter() + .map(|c| c.map(|c| c.take_materialized_series())), + ) + } +} + +impl FromIterator> for ListChunked { + #[inline] + fn from_iter>>(iter: I) -> Self { + let mut it = iter.into_iter(); + let capacity = get_iter_capacity(&it); + + // get first non None from iter + let first_value; + let mut init_null_count = 0; + loop { + match it.next() { + Some(Some(s)) => { + first_value = Some(s); + break; + }, + Some(None) => { + init_null_count += 1; + }, + None => return ListChunked::full_null(PlSmallStr::EMPTY, init_null_count), + } + } + + match first_value { + None => { + // already returned full_null above + unreachable!() + }, + Some(ref first_s) => { + // AnyValues with empty lists in python can create + // Series of an unknown dtype. + // We use the anonymousbuilder without a dtype + // the empty arrays is then not added (we add an extra offset instead) + // the next non-empty series then must have the correct dtype. + if matches!(first_s.dtype(), DataType::Null) && first_s.is_empty() { + let mut builder = + AnonymousOwnedListBuilder::new(PlSmallStr::EMPTY, capacity, None); + for _ in 0..init_null_count { + builder.append_null(); + } + builder.append_empty(); + + for opt_s in it { + builder.append_opt_series(opt_s.as_ref()).unwrap(); + } + builder.finish() + } else { + // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. + let mut builder = get_list_builder( + first_s.dtype(), + capacity * 5, + capacity, + PlSmallStr::EMPTY, + ); + + for _ in 0..init_null_count { + builder.append_null(); + } + builder.append_series(first_s).unwrap(); + + for opt_s in it { + builder.append_opt_series(opt_s.as_ref()).unwrap(); + } + builder.finish() + } + }, + } + } +} + +impl FromIterator>> for ListChunked { + #[inline] + fn from_iter>>>(iter: I) -> Self { + iter.into_iter().collect_ca(PlSmallStr::EMPTY) + } +} + +#[cfg(feature = "object")] +impl FromIterator> for ObjectChunked { + fn from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let size = iter.size_hint().0; + let mut null_mask_builder = BitmapBuilder::with_capacity(size); + + let values: Vec = iter + .map(|value| match value { + Some(value) => { + null_mask_builder.push(true); + value + }, + None => { + null_mask_builder.push(false); + T::default() + }, + }) + .collect(); + + let arr = Box::new( + ObjectArray::from(values).with_validity(null_mask_builder.into_opt_validity()), + ); + ChunkedArray::new_with_compute_len( + Arc::new(Field::new(PlSmallStr::EMPTY, get_object_type::())), + vec![arr], + ) + } +} diff --git a/crates/polars-core/src/chunked_array/from_iterator_par.rs b/crates/polars-core/src/chunked_array/from_iterator_par.rs new file mode 100644 index 000000000000..a8287bd60bac --- /dev/null +++ b/crates/polars-core/src/chunked_array/from_iterator_par.rs @@ -0,0 +1,321 @@ +//! Implementations of upstream traits for [`ChunkedArray`] +use std::collections::LinkedList; +use std::sync::Mutex; + +use arrow::pushable::{NoOption, Pushable}; +use rayon::prelude::*; + +use super::from_iterator::PolarsAsRef; +use crate::chunked_array::builder::get_list_builder; +use crate::prelude::*; +use crate::utils::NoNull; +use crate::utils::flatten::flatten_par; + +/// FromParallelIterator trait +// Code taken from https://docs.rs/rayon/1.3.1/src/rayon/iter/extend.rs.html#356-366 +fn vec_push(mut vec: Vec, elem: T) -> Vec { + vec.push(elem); + vec +} + +fn as_list(item: T) -> LinkedList { + let mut list = LinkedList::new(); + list.push_back(item); + list +} + +fn list_append(mut list1: LinkedList, mut list2: LinkedList) -> LinkedList { + list1.append(&mut list2); + list1 +} + +fn collect_into_linked_list_vec(par_iter: I) -> LinkedList> +where + I: IntoParallelIterator, +{ + let it = par_iter.into_par_iter(); + // be careful optimizing allocations. Its hard to figure out the size + // needed + // https://github.com/pola-rs/polars/issues/1562 + it.fold(Vec::new, vec_push) + .map(as_list) + .reduce(LinkedList::new, list_append) +} + +fn collect_into_linked_list(par_iter: I, identity: F) -> LinkedList +where + I: IntoParallelIterator, + P: Pushable + Send + Sync, + F: Fn() -> P + Sync + Send, + P::Freeze: Send, +{ + let it = par_iter.into_par_iter(); + it.fold(identity, |mut v, item| { + v.push(item); + v + }) + // The freeze on this line, ensures the null count is done in parallel + .map(|p| as_list(p.freeze())) + .reduce(LinkedList::new, list_append) +} + +fn get_capacity_from_par_results(ll: &LinkedList>) -> usize { + ll.iter().map(|list| list.len()).sum() +} + +impl FromParallelIterator for NoNull> +where + T: PolarsNumericType, +{ + fn from_par_iter>(iter: I) -> Self { + // Get linkedlist filled with different vec result from different threads + let vectors = collect_into_linked_list_vec(iter); + let vectors = vectors.into_iter().collect::>(); + let values = flatten_par(&vectors); + NoNull::new(ChunkedArray::new_vec(PlSmallStr::EMPTY, values)) + } +} + +impl FromParallelIterator> for ChunkedArray +where + T: PolarsNumericType, +{ + fn from_par_iter>>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutablePrimitiveArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +impl FromParallelIterator for BooleanChunked { + fn from_par_iter>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutableBooleanArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +impl FromParallelIterator> for BooleanChunked { + fn from_par_iter>>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutableBooleanArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +impl FromParallelIterator for StringChunked +where + Ptr: PolarsAsRef + Send + Sync + NoOption, +{ + fn from_par_iter>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +impl FromParallelIterator for BinaryChunked +where + Ptr: PolarsAsRef<[u8]> + Send + Sync + NoOption, +{ + fn from_par_iter>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +impl FromParallelIterator> for StringChunked +where + Ptr: AsRef + Send + Sync, +{ + fn from_par_iter>>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +impl FromParallelIterator> for BinaryChunked +where + Ptr: AsRef<[u8]> + Send + Sync, +{ + fn from_par_iter>>(iter: I) -> Self { + let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() + } +} + +pub trait FromParIterWithDtype { + fn from_par_iter_with_dtype(iter: I, name: PlSmallStr, dtype: DataType) -> Self + where + I: IntoParallelIterator, + Self: Sized; +} + +fn get_value_cap(vectors: &LinkedList>>) -> usize { + vectors + .iter() + .map(|list| { + list.iter() + .map(|opt_s| opt_s.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum::() + }) + .sum::() +} + +fn get_dtype(vectors: &LinkedList>>) -> DataType { + for v in vectors { + for s in v.iter().flatten() { + let dtype = s.dtype(); + if !matches!(dtype, DataType::Null) { + return dtype.clone(); + } + } + } + DataType::Null +} + +fn materialize_list( + name: PlSmallStr, + vectors: &LinkedList>>, + dtype: DataType, + value_capacity: usize, + list_capacity: usize, +) -> PolarsResult { + let mut builder = get_list_builder(&dtype, value_capacity, list_capacity, name); + for v in vectors { + for val in v { + builder.append_opt_series(val.as_ref())?; + } + } + Ok(builder.finish()) +} + +impl FromParallelIterator> for ListChunked { + fn from_par_iter(par_iter: I) -> Self + where + I: IntoParallelIterator>, + { + list_from_par_iter(par_iter, PlSmallStr::EMPTY).unwrap() + } +} + +pub fn list_from_par_iter(par_iter: I, name: PlSmallStr) -> PolarsResult +where + I: IntoParallelIterator>, +{ + let vectors = collect_into_linked_list_vec(par_iter); + + let list_capacity: usize = get_capacity_from_par_results(&vectors); + let value_capacity = get_value_cap(&vectors); + let dtype = get_dtype(&vectors); + if let DataType::Null = dtype { + Ok(ListChunked::full_null_with_dtype( + name, + list_capacity, + &DataType::Null, + )) + } else { + materialize_list(name, &vectors, dtype, value_capacity, list_capacity) + } +} + +pub fn try_list_from_par_iter(par_iter: I, name: PlSmallStr) -> PolarsResult +where + I: IntoParallelIterator>>, +{ + fn ok(saved: &Mutex>) -> impl Fn(Result) -> Option + '_ { + move |item| match item { + Ok(item) => Some(item), + Err(error) => { + // We don't need a blocking `lock()`, as anybody + // else holding the lock will also be writing + // `Some(error)`, and then ours is irrelevant. + if let Ok(mut guard) = saved.try_lock() { + if guard.is_none() { + *guard = Some(error); + } + } + None + }, + } + } + + let saved_error = Mutex::new(None); + let iter = par_iter.into_par_iter().map(ok(&saved_error)).while_some(); + + let collection = list_from_par_iter(iter, name)?; + + match saved_error.into_inner().unwrap() { + Some(error) => Err(error), + None => Ok(collection), + } +} + +impl FromParIterWithDtype> for ListChunked { + fn from_par_iter_with_dtype(iter: I, name: PlSmallStr, dtype: DataType) -> Self + where + I: IntoParallelIterator>, + Self: Sized, + { + let vectors = collect_into_linked_list_vec(iter); + + let list_capacity: usize = get_capacity_from_par_results(&vectors); + let value_capacity = get_value_cap(&vectors); + if let DataType::List(dtype) = dtype { + materialize_list(name, &vectors, *dtype, value_capacity, list_capacity).unwrap() + } else { + panic!("expected list dtype") + } + } +} + +pub trait ChunkedCollectParIterExt: ParallelIterator { + fn collect_ca_with_dtype>( + self, + name: PlSmallStr, + dtype: DataType, + ) -> B + where + Self: Sized, + { + B::from_par_iter_with_dtype(self, name, dtype) + } +} + +impl ChunkedCollectParIterExt for I {} + +// Adapted from rayon +impl FromParIterWithDtype> for Result +where + C: FromParIterWithDtype, + T: Send, + E: Send, +{ + fn from_par_iter_with_dtype(par_iter: I, name: PlSmallStr, dtype: DataType) -> Self + where + I: IntoParallelIterator>, + { + fn ok(saved: &Mutex>) -> impl Fn(Result) -> Option + '_ { + move |item| match item { + Ok(item) => Some(item), + Err(error) => { + // We don't need a blocking `lock()`, as anybody + // else holding the lock will also be writing + // `Some(error)`, and then ours is irrelevant. + if let Ok(mut guard) = saved.try_lock() { + if guard.is_none() { + *guard = Some(error); + } + } + None + }, + } + } + + let saved_error = Mutex::new(None); + let iter = par_iter.into_par_iter().map(ok(&saved_error)).while_some(); + + let collection = C::from_par_iter_with_dtype(iter, name, dtype); + + match saved_error.into_inner().unwrap() { + Some(error) => Err(error), + None => Ok(collection), + } + } +} diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs new file mode 100644 index 000000000000..a87d888968f3 --- /dev/null +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -0,0 +1,1084 @@ +use arrow::array::*; + +use crate::prelude::*; + +pub mod par; + +impl ChunkedArray +where + T: PolarsDataType, +{ + #[inline] + pub fn iter(&self) -> impl PolarsIterator>> { + // SAFETY: we set the correct length of the iterator. + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.iter()) + .trust_my_length(self.len()) + } + } +} + +/// A [`PolarsIterator`] is an iterator over a [`ChunkedArray`] which contains polars types. A [`PolarsIterator`] +/// must implement [`ExactSizeIterator`] and [`DoubleEndedIterator`]. +pub trait PolarsIterator: + ExactSizeIterator + DoubleEndedIterator + Send + Sync + TrustedLen +{ +} +unsafe impl TrustedLen for Box + '_> {} + +/// Implement [`PolarsIterator`] for every iterator that implements the needed traits. +impl PolarsIterator for T where + T: ExactSizeIterator + DoubleEndedIterator + Send + Sync + TrustedLen +{ +} + +impl<'a, T> IntoIterator for &'a ChunkedArray +where + T: PolarsNumericType, +{ + type Item = Option; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + Box::new( + // we know that we only iterate over length == self.len() + unsafe { + self.downcast_iter() + .flatten() + .map(|x| x.copied()) + .trust_my_length(self.len()) + }, + ) + } +} + +impl<'a> IntoIterator for &'a BooleanChunked { + type Item = Option; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + // we know that we only iterate over length == self.len() + unsafe { Box::new(self.downcast_iter().flatten().trust_my_length(self.len())) } + } +} + +/// The no null iterator for a [`BooleanArray`] +pub struct BoolIterNoNull<'a> { + array: &'a BooleanArray, + current: usize, + current_end: usize, +} + +impl<'a> BoolIterNoNull<'a> { + /// create a new iterator + pub fn new(array: &'a BooleanArray) -> Self { + BoolIterNoNull { + array, + current: 0, + current_end: array.len(), + } + } +} + +impl Iterator for BoolIterNoNull<'_> { + type Item = bool; + + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + } else { + let old = self.current; + self.current += 1; + unsafe { Some(self.array.value_unchecked(old)) } + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.array.len() - self.current, + Some(self.array.len() - self.current), + ) + } +} + +impl DoubleEndedIterator for BoolIterNoNull<'_> { + fn next_back(&mut self) -> Option { + if self.current_end == self.current { + None + } else { + self.current_end -= 1; + unsafe { Some(self.array.value_unchecked(self.current_end)) } + } + } +} + +/// all arrays have known size. +impl ExactSizeIterator for BoolIterNoNull<'_> {} + +impl BooleanChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // we know that we only iterate over length == self.len() + unsafe { + self.downcast_iter() + .flat_map(BoolIterNoNull::new) + .trust_my_length(self.len()) + } + } +} + +impl<'a> IntoIterator for &'a StringChunked { + type Item = Option<&'a str>; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + // we know that we only iterate over length == self.len() + unsafe { Box::new(self.downcast_iter().flatten().trust_my_length(self.len())) } + } +} + +impl StringChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // we know that we only iterate over length == self.len() + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.values_iter()) + .trust_my_length(self.len()) + } + } +} + +impl<'a> IntoIterator for &'a BinaryChunked { + type Item = Option<&'a [u8]>; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + // we know that we only iterate over length == self.len() + unsafe { Box::new(self.downcast_iter().flatten().trust_my_length(self.len())) } + } +} + +impl BinaryChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // we know that we only iterate over length == self.len() + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.values_iter()) + .trust_my_length(self.len()) + } + } +} + +impl<'a> IntoIterator for &'a BinaryOffsetChunked { + type Item = Option<&'a [u8]>; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + // we know that we only iterate over length == self.len() + unsafe { Box::new(self.downcast_iter().flatten().trust_my_length(self.len())) } + } +} + +impl BinaryOffsetChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // we know that we only iterate over length == self.len() + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.values_iter()) + .trust_my_length(self.len()) + } + } +} + +impl<'a> IntoIterator for &'a ListChunked { + type Item = Option; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + let dtype = self.inner_dtype(); + + if self.null_count() == 0 { + // we know that we only iterate over length == self.len() + unsafe { + Box::new( + self.downcast_iter() + .flat_map(|arr| arr.iter().unwrap_required()) + .trust_my_length(self.len()) + .map(move |arr| { + Some(Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + dtype, + )) + }), + ) + } + } else { + // we know that we only iterate over length == self.len() + unsafe { + Box::new( + self.downcast_iter() + .flat_map(|arr| arr.iter()) + .trust_my_length(self.len()) + .map(move |arr| { + arr.map(|arr| { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + dtype, + ) + }) + }), + ) + } + } + } +} + +impl ListChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + let inner_type = self.inner_dtype(); + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.values_iter()) + .map(move |arr| { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + inner_type, + ) + }) + .trust_my_length(self.len()) + } + } +} + +#[cfg(feature = "dtype-array")] +impl<'a> IntoIterator for &'a ArrayChunked { + type Item = Option; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + let dtype = self.inner_dtype(); + + if self.null_count() == 0 { + // we know that we only iterate over length == self.len() + unsafe { + Box::new( + self.downcast_iter() + .flat_map(|arr| arr.iter().unwrap_required()) + .trust_my_length(self.len()) + .map(move |arr| { + Some(Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + dtype, + )) + }), + ) + } + } else { + // we know that we only iterate over length == self.len() + unsafe { + Box::new( + self.downcast_iter() + .flat_map(|arr| arr.iter()) + .trust_my_length(self.len()) + .map(move |arr| { + arr.map(|arr| { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + dtype, + ) + }) + }), + ) + } + } + } +} + +#[cfg(feature = "dtype-array")] +pub struct FixedSizeListIterNoNull<'a> { + array: &'a FixedSizeListArray, + inner_type: DataType, + current: usize, + current_end: usize, +} + +#[cfg(feature = "dtype-array")] +impl<'a> FixedSizeListIterNoNull<'a> { + /// create a new iterator + pub fn new(array: &'a FixedSizeListArray, inner_type: DataType) -> Self { + FixedSizeListIterNoNull { + array, + inner_type, + current: 0, + current_end: array.len(), + } + } +} + +#[cfg(feature = "dtype-array")] +impl Iterator for FixedSizeListIterNoNull<'_> { + type Item = Series; + + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + } else { + let old = self.current; + self.current += 1; + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![self.array.value_unchecked(old)], + &self.inner_type, + )) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.array.len() - self.current, + Some(self.array.len() - self.current), + ) + } +} + +#[cfg(feature = "dtype-array")] +impl DoubleEndedIterator for FixedSizeListIterNoNull<'_> { + fn next_back(&mut self) -> Option { + if self.current_end == self.current { + None + } else { + self.current_end -= 1; + unsafe { + Some( + Series::try_from(( + PlSmallStr::EMPTY, + self.array.value_unchecked(self.current_end), + )) + .unwrap(), + ) + } + } + } +} + +/// all arrays have known size. +#[cfg(feature = "dtype-array")] +impl ExactSizeIterator for FixedSizeListIterNoNull<'_> {} + +#[cfg(feature = "dtype-array")] +impl ArrayChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // we know that we only iterate over length == self.len() + let inner_type = self.inner_dtype(); + unsafe { + self.downcast_iter() + .flat_map(move |arr| FixedSizeListIterNoNull::new(arr, inner_type.clone())) + .trust_my_length(self.len()) + } + } +} + +#[cfg(feature = "object")] +impl<'a, T> IntoIterator for &'a ObjectChunked +where + T: PolarsObject, +{ + type Item = Option<&'a T>; + type IntoIter = Box + 'a>; + fn into_iter(self) -> Self::IntoIter { + // we know that we only iterate over length == self.len() + unsafe { Box::new(self.downcast_iter().flatten().trust_my_length(self.len())) } + } +} + +#[cfg(feature = "object")] +impl ObjectChunked { + #[allow(clippy::wrong_self_convention)] + #[doc(hidden)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // we know that we only iterate over length == self.len() + unsafe { + self.downcast_iter() + .flat_map(|arr| arr.values_iter()) + .trust_my_length(self.len()) + } + } +} + +/// Wrapper struct to convert an iterator of type `T` into one of type [`Option`]. It is useful to make the +/// [`IntoIterator`] trait, in which every iterator shall return an [`Option`]. +pub struct SomeIterator(I) +where + I: Iterator; + +impl Iterator for SomeIterator +where + I: Iterator, +{ + type Item = Option; + + fn next(&mut self) -> Option { + self.0.next().map(Some) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +impl DoubleEndedIterator for SomeIterator +where + I: DoubleEndedIterator, +{ + fn next_back(&mut self) -> Option { + self.0.next_back().map(Some) + } +} + +impl ExactSizeIterator for SomeIterator where I: ExactSizeIterator {} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + fn out_of_bounds() { + let mut a = UInt32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]); + let b = UInt32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]); + a.append(&b).unwrap(); + + let v = a.into_iter().collect::>(); + assert_eq!( + vec![Some(1u32), Some(2), Some(3), Some(1), Some(2), Some(3)], + v + ) + } + + /// Generate test for [`IntoIterator`] trait for chunked arrays with just one chunk and no null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where + /// `T` is the chunked array type. + /// + /// # Input + /// + /// test_name: The name of the test to generate. + /// ca_type: The chunked array to use for this test. Ex: [`StringChunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. + /// second_val: The second value contained in the chunked array. + /// third_val: The third value contained in the chunked array. + macro_rules! impl_test_iter_single_chunk { + ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { + #[test] + fn $test_name() { + let a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val, $third_val], + ); + + // normal iterator + let mut it = a.into_iter(); + assert_eq!(it.next(), Some(Some($first_val))); + assert_eq!(it.next(), Some(Some($second_val))); + assert_eq!(it.next(), Some(Some($third_val))); + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // reverse iterator + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some(Some($third_val))); + assert_eq!(it.next_back(), Some(Some($second_val))); + assert_eq!(it.next_back(), Some(Some($first_val))); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + + // iterators should not cross + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some(Some($third_val))); + assert_eq!(it.next(), Some(Some($first_val))); + assert_eq!(it.next(), Some(Some($second_val))); + // should stop here as we took this one from the back + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // do the same from the right side + let mut it = a.into_iter(); + assert_eq!(it.next(), Some(Some($first_val))); + assert_eq!(it.next_back(), Some(Some($third_val))); + assert_eq!(it.next_back(), Some(Some($second_val))); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + } + }; + } + + impl_test_iter_single_chunk!(num_iter_single_chunk, UInt32Chunked, 1, 2, 3); + impl_test_iter_single_chunk!(utf8_iter_single_chunk, StringChunked, "a", "b", "c"); + impl_test_iter_single_chunk!(bool_iter_single_chunk, BooleanChunked, true, true, false); + + /// Generate test for [`IntoIterator`] trait for chunked arrays with just one chunk and null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where + /// `T` is the chunked array type. + /// + /// # Input + /// + /// test_name: The name of the test to generate. + /// ca_type: The chunked array to use for this test. Ex: [`StringChunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. Must be an [`Option`]. + /// second_val: The second value contained in the chunked array. Must be an [`Option`]. + /// third_val: The third value contained in the chunked array. Must be an [`Option`]. + macro_rules! impl_test_iter_single_chunk_null_check { + ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { + #[test] + fn $test_name() { + let a = <$ca_type>::new( + PlSmallStr::from_static("test"), + &[$first_val, $second_val, $third_val], + ); + + // normal iterator + let mut it = a.into_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + assert_eq!(it.next(), Some($third_val)); + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // reverse iterator + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), Some($first_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + + // iterators should not cross + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + // should stop here as we took this one from the back + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // do the same from the right side + let mut it = a.into_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + } + }; + } + + impl_test_iter_single_chunk_null_check!( + num_iter_single_chunk_null_check, + UInt32Chunked, + Some(1), + None, + Some(3) + ); + impl_test_iter_single_chunk_null_check!( + utf8_iter_single_chunk_null_check, + StringChunked, + Some("a"), + None, + Some("c") + ); + impl_test_iter_single_chunk_null_check!( + bool_iter_single_chunk_null_check, + BooleanChunked, + Some(true), + None, + Some(false) + ); + + /// Generate test for [`IntoIterator`] trait for chunked arrays with many chunks and no null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where + /// `T` is the chunked array type. + /// + /// # Input + /// + /// test_name: The name of the test to generate. + /// ca_type: The chunked array to use for this test. Ex: [`StringChunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. + /// second_val: The second value contained in the chunked array. + /// third_val: The third value contained in the chunked array. + macro_rules! impl_test_iter_many_chunk { + ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { + #[test] + fn $test_name() { + let mut a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val], + ); + let a_b = <$ca_type>::from_slice(PlSmallStr::EMPTY, &[$third_val]); + a.append(&a_b).unwrap(); + + // normal iterator + let mut it = a.into_iter(); + assert_eq!(it.next(), Some(Some($first_val))); + assert_eq!(it.next(), Some(Some($second_val))); + assert_eq!(it.next(), Some(Some($third_val))); + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // reverse iterator + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some(Some($third_val))); + assert_eq!(it.next_back(), Some(Some($second_val))); + assert_eq!(it.next_back(), Some(Some($first_val))); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + + // iterators should not cross + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some(Some($third_val))); + assert_eq!(it.next(), Some(Some($first_val))); + assert_eq!(it.next(), Some(Some($second_val))); + // should stop here as we took this one from the back + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // do the same from the right side + let mut it = a.into_iter(); + assert_eq!(it.next(), Some(Some($first_val))); + assert_eq!(it.next_back(), Some(Some($third_val))); + assert_eq!(it.next_back(), Some(Some($second_val))); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + } + }; + } + + impl_test_iter_many_chunk!(num_iter_many_chunk, UInt32Chunked, 1, 2, 3); + impl_test_iter_many_chunk!(utf8_iter_many_chunk, StringChunked, "a", "b", "c"); + impl_test_iter_many_chunk!(bool_iter_many_chunk, BooleanChunked, true, true, false); + + /// Generate test for [`IntoIterator`] trait for chunked arrays with many chunk and null values. + /// The expected return value of the iterator generated by [`IntoIterator`] trait is [`Option`], where + /// `T` is the chunked array type. + /// + /// # Input + /// + /// test_name: The name of the test to generate. + /// ca_type: The chunked array to use for this test. Ex: [`StringChunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. Must be an [`Option`]. + /// second_val: The second value contained in the chunked array. Must be an [`Option`]. + /// third_val: The third value contained in the chunked array. Must be an [`Option`]. + macro_rules! impl_test_iter_many_chunk_null_check { + ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { + #[test] + fn $test_name() { + let mut a = + <$ca_type>::new(PlSmallStr::from_static("test"), &[$first_val, $second_val]); + let a_b = <$ca_type>::new(PlSmallStr::EMPTY, &[$third_val]); + a.append(&a_b).unwrap(); + + // normal iterator + let mut it = a.into_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + assert_eq!(it.next(), Some($third_val)); + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // reverse iterator + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), Some($first_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + + // iterators should not cross + let mut it = a.into_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + // should stop here as we took this one from the back + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // do the same from the right side + let mut it = a.into_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + } + }; + } + + impl_test_iter_many_chunk_null_check!( + num_iter_many_chunk_null_check, + UInt32Chunked, + Some(1), + None, + Some(3) + ); + impl_test_iter_many_chunk_null_check!( + utf8_iter_many_chunk_null_check, + StringChunked, + Some("a"), + None, + Some("c") + ); + impl_test_iter_many_chunk_null_check!( + bool_iter_many_chunk_null_check, + BooleanChunked, + Some(true), + None, + Some(false) + ); + + /// Generate test for [`IntoNoNullIterator`] trait for chunked arrays with just one chunk and no null values. + /// The expected return value of the iterator generated by [`IntoNoNullIterator`] trait is `T`, where + /// `T` is the chunked array type. + /// + /// # Input + /// + /// test_name: The name of the test to generate. + /// ca_type: The chunked array to use for this test. Ex: [`StringChunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. + /// second_val: The second value contained in the chunked array. + /// third_val: The third value contained in the chunked array. + macro_rules! impl_test_no_null_iter_single_chunk { + ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { + #[test] + fn $test_name() { + let a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val, $third_val], + ); + + // normal iterator + let mut it = a.into_no_null_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + assert_eq!(it.next(), Some($third_val)); + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // reverse iterator + let mut it = a.into_no_null_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), Some($first_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + + // iterators should not cross + let mut it = a.into_no_null_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + // should stop here as we took this one from the back + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // do the same from the right side + let mut it = a.into_no_null_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + } + }; + } + + impl_test_no_null_iter_single_chunk!(num_no_null_iter_single_chunk, UInt32Chunked, 1, 2, 3); + impl_test_no_null_iter_single_chunk!( + utf8_no_null_iter_single_chunk, + StringChunked, + "a", + "b", + "c" + ); + impl_test_no_null_iter_single_chunk!( + bool_no_null_iter_single_chunk, + BooleanChunked, + true, + true, + false + ); + + /// Generate test for [`IntoNoNullIterator`] trait for chunked arrays with many chunks and no null values. + /// The expected return value of the iterator generated by [`IntoNoNullIterator`] trait is `T`, where + /// `T` is the chunked array type. + /// + /// # Input + /// + /// test_name: The name of the test to generate. + /// ca_type: The chunked array to use for this test. Ex: [`StringChunked`], [`UInt32Chunked`] ... + /// first_val: The first value contained in the chunked array. + /// second_val: The second value contained in the chunked array. + /// third_val: The third value contained in the chunked array. + macro_rules! impl_test_no_null_iter_many_chunk { + ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { + #[test] + fn $test_name() { + let mut a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val], + ); + let a_b = <$ca_type>::from_slice(PlSmallStr::EMPTY, &[$third_val]); + a.append(&a_b).unwrap(); + + // normal iterator + let mut it = a.into_no_null_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + assert_eq!(it.next(), Some($third_val)); + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // reverse iterator + let mut it = a.into_no_null_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), Some($first_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + + // iterators should not cross + let mut it = a.into_no_null_iter(); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next(), Some($second_val)); + // should stop here as we took this one from the back + assert_eq!(it.next(), None); + // ensure both sides are consumes. + assert_eq!(it.next_back(), None); + + // do the same from the right side + let mut it = a.into_no_null_iter(); + assert_eq!(it.next(), Some($first_val)); + assert_eq!(it.next_back(), Some($third_val)); + assert_eq!(it.next_back(), Some($second_val)); + assert_eq!(it.next_back(), None); + // ensure both sides are consumes. + assert_eq!(it.next(), None); + } + }; + } + + impl_test_no_null_iter_many_chunk!(num_no_null_iter_many_chunk, UInt32Chunked, 1, 2, 3); + impl_test_no_null_iter_many_chunk!(utf8_no_null_iter_many_chunk, StringChunked, "a", "b", "c"); + impl_test_no_null_iter_many_chunk!( + bool_no_null_iter_many_chunk, + BooleanChunked, + true, + true, + false + ); + + /// The size of the skip iterator. + const SKIP_ITERATOR_SIZE: usize = 10; + + /// Generates tests to verify the correctness of the `skip` method. + /// + /// # Input + /// + /// test_name: The name of the test to implement, it is a function name so it shall be unique. + /// skip_values: The number of values to skip. Keep in mind that is the number of values to skip + /// after performing the first next, then, skip_values = 8, will skip until index 1 + skip_values = 9. + /// first_val: The value before skip. + /// second_val: The value after skip. + /// ca_init_block: The block which initialize the chunked array. It shall return the chunked array. + macro_rules! impl_test_iter_skip { + ($test_name:ident, $skip_values:expr, $first_val:expr, $second_val:expr, $ca_init_block:block) => { + #[test] + fn $test_name() { + let a = $ca_init_block; + + // Consume first position of iterator. + let mut it = a.into_iter(); + assert_eq!(it.next(), Some($first_val)); + + // Consume `$skip_values` and check the result. + let mut it = it.skip($skip_values); + assert_eq!(it.next(), Some($second_val)); + + // Consume more values than available and check result is None. + let mut it = it.skip(SKIP_ITERATOR_SIZE); + assert_eq!(it.next(), None); + } + }; + } + + /// Generates a `Vec` of `Strings`, where every position is the `String` representation of its index. + fn generate_utf8_vec(size: usize) -> Vec { + (0..size).map(|n| n.to_string()).collect() + } + + /// Generate a `Vec` of `Option`, where even indexes are `Some("{idx}")` and odd indexes are `None`. + fn generate_opt_utf8_vec(size: usize) -> Vec> { + (0..size) + .map(|n| { + if n % 2 == 0 { + Some(n.to_string()) + } else { + None + } + }) + .collect() + } + + impl_test_iter_skip!(utf8_iter_single_chunk_skip, 8, Some("0"), Some("9"), { + StringChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_utf8_vec(SKIP_ITERATOR_SIZE), + ) + }); + + impl_test_iter_skip!( + utf8_iter_single_chunk_null_check_skip, + 8, + Some("0"), + None, + { + StringChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE), + ) + } + ); + + impl_test_iter_skip!(utf8_iter_many_chunk_skip, 18, Some("0"), Some("9"), { + let mut a = StringChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_utf8_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = StringChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_utf8_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); + a + }); + + impl_test_iter_skip!(utf8_iter_many_chunk_null_check_skip, 18, Some("0"), None, { + let mut a = StringChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = StringChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); + a + }); + + /// Generates a [`Vec`] of [`bool`], with even indexes are true, and odd indexes are false. + fn generate_boolean_vec(size: usize) -> Vec { + (0..size).map(|n| n % 2 == 0).collect() + } + + /// Generate a [`Vec`] of [`Option`], where: + /// - If the index is divisible by 3, then, the value is `None`. + /// - If the index is not divisible by 3 and it is even, then, the value is `Some(true)`. + /// - Otherwise, the value is `Some(false)`. + fn generate_opt_boolean_vec(size: usize) -> Vec> { + (0..size) + .map(|n| if n % 3 == 0 { None } else { Some(n % 2 == 0) }) + .collect() + } + + impl_test_iter_skip!(bool_iter_single_chunk_skip, 8, Some(true), Some(false), { + BooleanChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_boolean_vec(SKIP_ITERATOR_SIZE), + ) + }); + + impl_test_iter_skip!(bool_iter_single_chunk_null_check_skip, 8, None, None, { + BooleanChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE), + ) + }); + + impl_test_iter_skip!(bool_iter_many_chunk_skip, 18, Some(true), Some(false), { + let mut a = BooleanChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_boolean_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = BooleanChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_boolean_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); + a + }); + + impl_test_iter_skip!(bool_iter_many_chunk_null_check_skip, 18, None, None, { + let mut a = BooleanChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = BooleanChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); + a + }); +} diff --git a/crates/polars-core/src/chunked_array/iterator/par/list.rs b/crates/polars-core/src/chunked_array/iterator/par/list.rs new file mode 100644 index 000000000000..438cf1779b59 --- /dev/null +++ b/crates/polars-core/src/chunked_array/iterator/par/list.rs @@ -0,0 +1,41 @@ +use rayon::prelude::*; + +use crate::prelude::*; + +unsafe fn idx_to_array(idx: usize, arr: &ListArray, dtype: &DataType) -> Option { + if arr.is_valid(idx) { + Some(arr.value_unchecked(idx)).map(|arr: ArrayRef| { + Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![arr], dtype) + }) + } else { + None + } +} + +impl ListChunked { + // Get a parallel iterator over the [`Series`] in this [`ListChunked`]. + pub fn par_iter(&self) -> impl ParallelIterator> + '_ { + self.chunks.par_iter().flat_map(move |arr| { + let dtype = self.inner_dtype(); + // SAFETY: + // guarded by the type system + let arr = &**arr; + let arr = unsafe { &*(arr as *const dyn Array as *const ListArray) }; + (0..arr.len()) + .into_par_iter() + .map(move |idx| unsafe { idx_to_array(idx, arr, dtype) }) + }) + } + + // Get an indexed parallel iterator over the [`Series`] in this [`ListChunked`]. + // Also might be faster as it doesn't use `flat_map`. + pub fn par_iter_indexed(&mut self) -> impl IndexedParallelIterator> + '_ { + self.rechunk_mut(); + let arr = self.downcast_iter().next().unwrap(); + + let dtype = self.inner_dtype(); + (0..arr.len()) + .into_par_iter() + .map(move |idx| unsafe { idx_to_array(idx, arr, dtype) }) + } +} diff --git a/crates/polars-core/src/chunked_array/iterator/par/mod.rs b/crates/polars-core/src/chunked_array/iterator/par/mod.rs new file mode 100644 index 000000000000..ccdb46c78411 --- /dev/null +++ b/crates/polars-core/src/chunked_array/iterator/par/mod.rs @@ -0,0 +1,2 @@ +pub mod list; +pub mod string; diff --git a/crates/polars-core/src/chunked_array/iterator/par/string.rs b/crates/polars-core/src/chunked_array/iterator/par/string.rs new file mode 100644 index 000000000000..6130fb711b4e --- /dev/null +++ b/crates/polars-core/src/chunked_array/iterator/par/string.rs @@ -0,0 +1,38 @@ +use rayon::prelude::*; + +use crate::prelude::*; + +#[inline] +unsafe fn idx_to_str(idx: usize, arr: &Utf8ViewArray) -> Option<&str> { + if arr.is_valid(idx) { + Some(arr.value_unchecked(idx)) + } else { + None + } +} + +impl StringChunked { + pub fn par_iter_indexed(&self) -> impl IndexedParallelIterator> { + assert_eq!(self.chunks.len(), 1); + let arr = &*self.chunks[0]; + + // SAFETY: + // guarded by the type system + let arr = unsafe { &*(arr as *const dyn Array as *const Utf8ViewArray) }; + (0..arr.len()) + .into_par_iter() + .map(move |idx| unsafe { idx_to_str(idx, arr) }) + } + + pub fn par_iter(&self) -> impl ParallelIterator> + '_ { + self.chunks.par_iter().flat_map(move |arr| { + // SAFETY: + // guarded by the type system + let arr = &**arr; + let arr = unsafe { &*(arr as *const dyn Array as *const Utf8ViewArray) }; + (0..arr.len()) + .into_par_iter() + .map(move |idx| unsafe { idx_to_str(idx, arr) }) + }) + } +} diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs new file mode 100644 index 000000000000..c76eda862989 --- /dev/null +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -0,0 +1,459 @@ +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::rc::Rc; + +use crate::chunked_array::flags::StatisticsFlags; +use crate::prelude::*; +use crate::series::amortized_iter::{AmortSeries, ArrayBox, unstable_series_container_and_ptr}; + +pub struct AmortizedListIter<'a, I: Iterator>> { + len: usize, + series_container: Rc, + inner: NonNull, + lifetime: PhantomData<&'a ArrayRef>, + iter: I, + // used only if feature="dtype-struct" + #[allow(dead_code)] + inner_dtype: DataType, +} + +impl>> AmortizedListIter<'_, I> { + pub(crate) unsafe fn new( + len: usize, + series_container: Series, + inner: NonNull, + iter: I, + inner_dtype: DataType, + ) -> Self { + Self { + len, + series_container: Rc::new(series_container), + inner, + lifetime: PhantomData, + iter, + inner_dtype, + } + } +} + +impl>> Iterator for AmortizedListIter<'_, I> { + type Item = Option; + + fn next(&mut self) -> Option { + self.iter.next().map(|opt_val| { + opt_val.map(|array_ref| { + #[cfg(feature = "dtype-struct")] + // structs arrays are bound to the series not to the arrayref + // so we must get a hold to the new array + if matches!(self.inner_dtype, DataType::Struct(_)) { + // SAFETY: + // dtype is known + unsafe { + let s = Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![array_ref], + &self.inner_dtype.to_physical(), + ) + .from_physical_unchecked(&self.inner_dtype) + .unwrap(); + let inner = Rc::make_mut(&mut self.series_container); + *inner = s; + + return AmortSeries::new(self.series_container.clone()); + } + } + // The series is cloned, we make a new container. + if Arc::strong_count(&self.series_container.0) > 1 + || Rc::strong_count(&self.series_container) > 1 + { + let (s, ptr) = unsafe { + unstable_series_container_and_ptr( + self.series_container.name().clone(), + array_ref, + self.series_container.dtype(), + ) + }; + self.series_container = Rc::new(s); + self.inner = NonNull::new(ptr).unwrap(); + } else { + // SAFETY: we checked the RC above; + let series_mut = + unsafe { Rc::get_mut(&mut self.series_container).unwrap_unchecked() }; + // update the inner state + unsafe { *self.inner.as_mut() = array_ref }; + + // As an optimization, we try to minimize how many calls to + // _get_inner_mut() we do. + let series_mut_inner = series_mut._get_inner_mut(); + // last iteration could have set the sorted flag (e.g. in compute_len) + series_mut_inner._set_flags(StatisticsFlags::empty()); + // make sure that the length is correct + series_mut_inner.compute_len(); + } + + // SAFETY: + // inner belongs to Series. + unsafe { + AmortSeries::new_with_chunk(self.series_container.clone(), self.inner.as_ref()) + } + }) + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.len)) + } +} + +// # Safety +// we correctly implemented size_hint +unsafe impl>> TrustedLen for AmortizedListIter<'_, I> {} +impl>> ExactSizeIterator for AmortizedListIter<'_, I> {} + +impl ListChunked { + /// This is an iterator over a [`ListChunked`] that saves allocations. + /// A Series is: + /// 1. [`Arc`] + /// ChunkedArray is: + /// 2. Vec< 3. ArrayRef> + /// + /// The ArrayRef we indicated with 3. will be updated during iteration. + /// The Series will be pinned in memory, saving an allocation for + /// 1. Arc<..> + /// 2. Vec<...> + /// + /// If the returned `AmortSeries` is cloned, the local copy will be replaced and a new container + /// will be set. + pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { + self.amortized_iter_with_name(PlSmallStr::EMPTY) + } + + /// See `amortized_iter`. + pub fn amortized_iter_with_name( + &self, + name: PlSmallStr, + ) -> AmortizedListIter> + '_> { + // we create the series container from the inner array + // so that the container has the proper dtype. + let arr = self.downcast_iter().next().unwrap(); + let inner_values = arr.values(); + + let inner_dtype = self.inner_dtype(); + let iter_dtype = match inner_dtype { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => inner_dtype.to_physical(), + // TODO: figure out how to deal with physical/logical distinction + // physical primitives like time, date etc. work + // physical nested need more + _ => inner_dtype.clone(), + }; + + // SAFETY: + // inner type passed as physical type + let (s, ptr) = + unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) }; + + // SAFETY: ptr belongs the Series.. + unsafe { + AmortizedListIter::new( + self.len(), + s, + NonNull::new(ptr).unwrap(), + self.downcast_iter().flat_map(|arr| arr.iter()), + inner_dtype.clone(), + ) + } + } + + /// Apply a closure `F` elementwise. + #[must_use] + pub fn apply_amortized_generic(&self, f: F) -> ChunkedArray + where + V: PolarsDataType, + F: FnMut(Option) -> Option + Copy, + V::Array: ArrayFromIter>, + { + // TODO! make an amortized iter that does not flatten + self.amortized_iter().map(f).collect_ca(self.name().clone()) + } + + pub fn try_apply_amortized_generic(&self, f: F) -> PolarsResult> + where + V: PolarsDataType, + F: FnMut(Option) -> PolarsResult> + Copy, + V::Array: ArrayFromIter>, + { + // TODO! make an amortized iter that does not flatten + self.amortized_iter() + .map(f) + .try_collect_ca(self.name().clone()) + } + + pub fn for_each_amortized(&self, f: F) + where + F: FnMut(Option), + { + self.amortized_iter().for_each(f) + } + + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + #[must_use] + pub fn zip_and_apply_amortized<'a, T, I, F>(&'a self, ca: &'a ChunkedArray, mut f: F) -> Self + where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen>>, + F: FnMut(Option, Option>) -> Option, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = self.null_count() == 0; + let mut out: ListChunked = { + self.amortized_iter() + .zip(ca) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v); + match out { + Some(out) => { + fast_explode &= !out.is_empty(); + Some(out) + }, + None => { + fast_explode = false; + out + }, + } + }) + .collect_trusted() + }; + + out.rename(self.name().clone()); + if fast_explode { + out.set_fast_explode(); + } + out + } + + #[must_use] + pub fn binary_zip_and_apply_amortized<'a, T, U, F>( + &'a self, + ca1: &'a ChunkedArray, + ca2: &'a ChunkedArray, + mut f: F, + ) -> Self + where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut( + Option, + Option>, + Option>, + ) -> Option, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = self.null_count() == 0; + let mut out: ListChunked = { + self.amortized_iter() + .zip(ca1.iter()) + .zip(ca2.iter()) + .map(|((opt_s, opt_u), opt_v)| { + let out = f(opt_s, opt_u, opt_v); + match out { + Some(out) => { + fast_explode &= !out.is_empty(); + Some(out) + }, + None => { + fast_explode = false; + out + }, + } + }) + .collect_trusted() + }; + + out.rename(self.name().clone()); + if fast_explode { + out.set_fast_explode(); + } + out + } + + pub fn try_binary_zip_and_apply_amortized<'a, T, U, F>( + &'a self, + ca1: &'a ChunkedArray, + ca2: &'a ChunkedArray, + mut f: F, + ) -> PolarsResult + where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut( + Option, + Option>, + Option>, + ) -> PolarsResult>, + { + if self.is_empty() { + return Ok(self.clone()); + } + let mut fast_explode = self.null_count() == 0; + let mut out: ListChunked = { + self.amortized_iter() + .zip(ca1.iter()) + .zip(ca2.iter()) + .map(|((opt_s, opt_u), opt_v)| { + let out = f(opt_s, opt_u, opt_v)?; + match out { + Some(out) => { + fast_explode &= !out.is_empty(); + Ok(Some(out)) + }, + None => { + fast_explode = false; + Ok(out) + }, + } + }) + .collect::>()? + }; + + out.rename(self.name().clone()); + if fast_explode { + out.set_fast_explode(); + } + Ok(out) + } + + pub fn try_zip_and_apply_amortized<'a, T, I, F>( + &'a self, + ca: &'a ChunkedArray, + mut f: F, + ) -> PolarsResult + where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen>>, + F: FnMut(Option, Option>) -> PolarsResult>, + { + if self.is_empty() { + return Ok(self.clone()); + } + let mut fast_explode = self.null_count() == 0; + let mut out: ListChunked = { + self.amortized_iter() + .zip(ca) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v)?; + match out { + Some(out) => { + fast_explode &= !out.is_empty(); + Ok(Some(out)) + }, + None => { + fast_explode = false; + Ok(out) + }, + } + }) + .collect::>()? + }; + + out.rename(self.name().clone()); + if fast_explode { + out.set_fast_explode(); + } + Ok(out) + } + + /// Apply a closure `F` elementwise. + #[must_use] + pub fn apply_amortized(&self, mut f: F) -> Self + where + F: FnMut(AmortSeries) -> Series, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = self.null_count() == 0; + let mut ca: ListChunked = { + self.amortized_iter() + .map(|opt_v| { + opt_v.map(|v| { + let out = f(v); + if out.is_empty() { + fast_explode = false; + } + out + }) + }) + .collect_trusted() + }; + + ca.rename(self.name().clone()); + if fast_explode { + ca.set_fast_explode(); + } + ca + } + + pub fn try_apply_amortized(&self, mut f: F) -> PolarsResult + where + F: FnMut(AmortSeries) -> PolarsResult, + { + if self.is_empty() { + return Ok(self.clone()); + } + let mut fast_explode = self.null_count() == 0; + let mut ca: ListChunked = { + self.amortized_iter() + .map(|opt_v| { + opt_v + .map(|v| { + let out = f(v); + if let Ok(out) = &out { + if out.is_empty() { + fast_explode = false + } + }; + out + }) + .transpose() + }) + .collect::>()? + }; + ca.rename(self.name().clone()); + if fast_explode { + ca.set_fast_explode(); + } + Ok(ca) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::chunked_array::builder::get_list_builder; + + #[test] + fn test_iter_list() { + let mut builder = get_list_builder(&DataType::Int32, 10, 10, PlSmallStr::EMPTY); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 2, 3])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[3, 2, 1])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 1])) + .unwrap(); + let ca = builder.finish(); + + ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| { + assert!(s1.unwrap().as_ref().equals(&s2.unwrap())); + }) + } +} diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs new file mode 100644 index 000000000000..b00589fa104b --- /dev/null +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -0,0 +1,193 @@ +//! Special list utility methods +pub(super) mod iterator; + +use std::borrow::Cow; + +use crate::prelude::*; + +impl ListChunked { + /// Get the inner data type of the list. + pub fn inner_dtype(&self) -> &DataType { + match self.dtype() { + DataType::List(dt) => dt.as_ref(), + _ => unreachable!(), + } + } + + pub fn set_inner_dtype(&mut self, dtype: DataType) { + assert_eq!(dtype.to_physical(), self.inner_dtype().to_physical()); + let field = Arc::make_mut(&mut self.field); + field.coerce(DataType::List(Box::new(dtype))); + } + + pub fn set_fast_explode(&mut self) { + self.set_fast_explode_list(true) + } + + pub fn _can_fast_explode(&self) -> bool { + self.get_fast_explode_list() + } + + /// Set the logical type of the [`ListChunked`]. + /// + /// # Safety + /// The caller must ensure that the logical type given fits the physical type of the array. + pub unsafe fn to_logical(&mut self, inner_dtype: DataType) { + debug_assert_eq!(&inner_dtype.to_physical(), self.inner_dtype()); + let fld = Arc::make_mut(&mut self.field); + fld.coerce(DataType::List(Box::new(inner_dtype))) + } + + /// Convert the datatype of the list into the physical datatype. + pub fn to_physical_repr(&self) -> Cow { + let Cow::Owned(physical_repr) = self.get_inner().to_physical_repr() else { + return Cow::Borrowed(self); + }; + + let ca = if physical_repr.chunks().len() == 1 && self.chunks().len() > 1 { + // Physical repr got rechunked, rechunk self as well. + self.rechunk() + } else { + Cow::Borrowed(self) + }; + + assert_eq!(ca.chunks().len(), physical_repr.chunks().len()); + + let chunks: Vec<_> = ca + .downcast_iter() + .zip(physical_repr.into_chunks()) + .map(|(chunk, values)| { + LargeListArray::new( + ArrowDataType::LargeList(Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + values.dtype().clone(), + true, + ))), + chunk.offsets().clone(), + values, + chunk.validity().cloned(), + ) + .to_boxed() + }) + .collect(); + + let name = self.name().clone(); + let dtype = DataType::List(Box::new(self.inner_dtype().to_physical())); + Cow::Owned(unsafe { ListChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype) }) + } + + /// Convert a non-logical [`ListChunked`] back into a logical [`ListChunked`] without casting. + /// + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn from_physical_unchecked( + &self, + to_inner_dtype: DataType, + ) -> PolarsResult { + debug_assert!(!self.inner_dtype().is_logical()); + + let inner_chunks = self + .downcast_iter() + .map(|chunk| chunk.values()) + .cloned() + .collect(); + + let inner = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + inner_chunks, + self.inner_dtype(), + ) + }; + let inner = unsafe { inner.from_physical_unchecked(&to_inner_dtype) }?; + + let chunks: Vec<_> = self + .downcast_iter() + .zip(inner.into_chunks()) + .map(|(chunk, values)| { + LargeListArray::new( + ArrowDataType::LargeList(Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + values.dtype().clone(), + true, + ))), + chunk.offsets().clone(), + values, + chunk.validity().cloned(), + ) + .to_boxed() + }) + .collect(); + + let name = self.name().clone(); + let dtype = DataType::List(Box::new(to_inner_dtype)); + Ok(unsafe { ListChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype) }) + } + + /// Get the inner values as [`Series`], ignoring the list offsets. + pub fn get_inner(&self) -> Series { + let chunks: Vec<_> = self.downcast_iter().map(|c| c.values().clone()).collect(); + + // SAFETY: Data type of arrays matches because they are chunks from the same array. + unsafe { + Series::from_chunks_and_dtype_unchecked(self.name().clone(), chunks, self.inner_dtype()) + } + } + + /// Ignore the list indices and apply `func` to the inner type as [`Series`]. + pub fn apply_to_inner( + &self, + func: &dyn Fn(Series) -> PolarsResult, + ) -> PolarsResult { + // generated Series will have wrong length otherwise. + let ca = self.rechunk(); + let arr = ca.downcast_as_array(); + + // SAFETY: + // Inner dtype is passed correctly + let elements = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr.values().clone()], + ca.inner_dtype(), + ) + }; + + let expected_len = elements.len(); + let out: Series = func(elements)?; + polars_ensure!( + out.len() == expected_len, + ComputeError: "the function should apply element-wise, it removed elements instead" + ); + let out = out.rechunk(); + let values = out.chunks()[0].clone(); + + let inner_dtype = LargeListArray::default_datatype(values.dtype().clone()); + let arr = LargeListArray::new( + inner_dtype, + (*arr.offsets()).clone(), + values, + arr.validity().cloned(), + ); + + // SAFETY: arr's inner dtype is derived from out dtype. + Ok(unsafe { + ListChunked::from_chunks_and_dtype_unchecked( + ca.name().clone(), + vec![Box::new(arr)], + DataType::List(Box::new(out.dtype().clone())), + ) + }) + } + + pub fn rechunk_and_trim_to_normalized_offsets(&self) -> Self { + Self::with_chunk( + self.name().clone(), + self.rechunk() + .downcast_get(0) + .unwrap() + .trim_to_normalized_offsets_recursive(), + ) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs new file mode 100644 index 000000000000..e44e0cda8748 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -0,0 +1,507 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::hash::BuildHasher; + +use arrow::array::*; +use arrow::legacy::trusted_len::TrustedLenPush; +use hashbrown::hash_map::Entry; +use hashbrown::hash_table::{Entry as HTEntry, HashTable}; +use polars_utils::itertools::Itertools; + +use crate::hashing::_HASHMAP_INIT_SIZE; +use crate::prelude::*; +use crate::{POOL, StringCache, using_string_cache}; + +pub struct CategoricalChunkedBuilder { + cat_builder: UInt32Vec, + name: PlSmallStr, + ordering: CategoricalOrdering, + categories: MutablePlString, + local_mapping: HashTable, + local_hasher: PlFixedStateQuality, +} + +impl CategoricalChunkedBuilder { + pub fn new(name: PlSmallStr, capacity: usize, ordering: CategoricalOrdering) -> Self { + Self { + cat_builder: UInt32Vec::with_capacity(capacity), + name, + ordering, + categories: MutablePlString::with_capacity(_HASHMAP_INIT_SIZE), + local_mapping: HashTable::with_capacity(capacity / 10), + local_hasher: StringCache::get_hash_builder(), + } + } + + fn get_cat_idx(&mut self, s: &str, h: u64) -> (u32, bool) { + let len = self.local_mapping.len() as u32; + + // SAFETY: index in hashmap are within bounds of categories + unsafe { + let r = self.local_mapping.entry( + h, + |k| self.categories.value_unchecked(*k as usize) == s, + |k| { + self.local_hasher + .hash_one(self.categories.value_unchecked(*k as usize)) + }, + ); + + match r { + HTEntry::Occupied(v) => (*v.get(), false), + HTEntry::Vacant(slot) => { + self.categories.push(Some(s)); + slot.insert(len); + (len, true) + }, + } + } + } + + fn try_get_cat_idx(&mut self, s: &str, h: u64) -> Option { + // SAFETY: index in hashmap are within bounds of categories + unsafe { + let r = self.local_mapping.entry( + h, + |k| self.categories.value_unchecked(*k as usize) == s, + |k| { + self.local_hasher + .hash_one(self.categories.value_unchecked(*k as usize)) + }, + ); + + match r { + HTEntry::Occupied(v) => Some(*v.get()), + HTEntry::Vacant(_) => None, + } + } + } + + /// Append a new category, but fail if it didn't exist yet in the category state. + /// You can register categories up front with `register_value`, or via `append`. + #[inline] + pub fn try_append_value(&mut self, s: &str) -> PolarsResult<()> { + let h = self.local_hasher.hash_one(s); + let idx = self.try_get_cat_idx(s, h).ok_or_else( + || polars_err!(ComputeError: "category {} doesn't exist in Enum dtype", s), + )?; + self.cat_builder.push(Some(idx)); + Ok(()) + } + + /// Append a new category, but fail if it didn't exist yet in the category state. + /// You can register categories up front with `register_value`, or via `append`. + #[inline] + pub fn try_append(&mut self, opt_s: Option<&str>) -> PolarsResult<()> { + match opt_s { + None => self.append_null(), + Some(s) => self.try_append_value(s)?, + } + Ok(()) + } + + /// Registers a value to a categorical index without pushing it. + /// Returns the index and if the value was new. + #[inline] + pub fn register_value(&mut self, s: &str) -> (u32, bool) { + let h = self.local_hasher.hash_one(s); + self.get_cat_idx(s, h) + } + + #[inline] + pub fn append_value(&mut self, s: &str) { + let h = self.local_hasher.hash_one(s); + let idx = self.get_cat_idx(s, h).0; + self.cat_builder.push(Some(idx)); + } + + #[inline] + pub fn append_null(&mut self) { + self.cat_builder.push(None) + } + + #[inline] + pub fn append(&mut self, opt_s: Option<&str>) { + match opt_s { + None => self.append_null(), + Some(s) => self.append_value(s), + } + } + + fn drain_iter<'a, I>(&mut self, i: I) + where + I: IntoIterator>, + { + for opt_s in i.into_iter() { + self.append(opt_s); + } + } + + /// Fast path for global categorical which preserves hashes and saves an allocation by + /// altering the keys in place. + fn drain_iter_global_and_finish<'a, I>(&mut self, i: I) -> CategoricalChunked + where + I: IntoIterator>, + { + let iter = i.into_iter(); + // Save hashes for later when inserting into the global hashmap. + let mut hashes = Vec::with_capacity(_HASHMAP_INIT_SIZE); + for s in self.categories.values_iter() { + hashes.push(self.local_hasher.hash_one(s)); + } + + for opt_s in iter { + match opt_s { + None => self.append_null(), + Some(s) => { + let hash = self.local_hasher.hash_one(s); + let (cat_idx, new) = self.get_cat_idx(s, hash); + self.cat_builder.push(Some(cat_idx)); + if new { + // We appended a value to the map. + hashes.push(hash); + } + }, + } + } + + let categories = std::mem::take(&mut self.categories).freeze(); + + // We will create a mapping from our local categoricals to global categoricals + // and a mapping from global categoricals to our local categoricals. + let mut local_to_global: Vec = Vec::with_capacity(categories.len()); + let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| { + for (s, h) in categories.values_iter().zip(hashes) { + // SAFETY: we allocated enough. + unsafe { local_to_global.push_unchecked(cache.insert_from_hash(h, s)) } + } + local_to_global + }); + + // Change local indices inplace to their global counterparts. + let update_cats = || { + if !local_to_global.is_empty() { + // when all categorical are null, `local_to_global` is empty and all cats physical values are 0. + self.cat_builder.apply_values(|cats| { + for cat in cats { + debug_assert!((*cat as usize) < local_to_global.len()); + *cat = *unsafe { local_to_global.get_unchecked(*cat as usize) }; + } + }) + } + }; + + let mut global_to_local = PlHashMap::with_capacity(local_to_global.len()); + POOL.join( + || fill_global_to_local(&local_to_global, &mut global_to_local), + update_cats, + ); + + let indices = std::mem::take(&mut self.cat_builder).into(); + let indices = UInt32Chunked::with_chunk(self.name.clone(), indices); + + // SAFETY: indices are in bounds of new rev_map + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + indices, + Arc::new(RevMapping::Global(global_to_local, categories, id)), + false, + self.ordering, + ) + .with_fast_unique(true) + } + } + + pub fn drain_iter_and_finish<'a, I>(mut self, i: I) -> CategoricalChunked + where + I: IntoIterator>, + { + if using_string_cache() { + self.drain_iter_global_and_finish(i) + } else { + self.drain_iter(i); + self.finish() + } + } + + pub fn finish(self) -> CategoricalChunked { + // SAFETY: keys and values are in bounds + unsafe { + CategoricalChunked::from_keys_and_values( + self.name.clone(), + &self.cat_builder.into(), + &self.categories.into(), + self.ordering, + ) + .with_fast_unique(true) + } + } +} + +fn fill_global_to_local(local_to_global: &[u32], global_to_local: &mut PlHashMap) { + let mut local_idx = 0; + #[allow(clippy::explicit_counter_loop)] + for global_idx in local_to_global { + // we know the keys are unique so this is much faster + unsafe { + global_to_local.insert_unique_unchecked(*global_idx, local_idx); + } + local_idx += 1; + } +} + +impl CategoricalChunked { + /// Create a [`CategoricalChunked`] from a categorical indices. The indices will + /// probe the global string cache. + pub(crate) fn from_global_indices( + cats: UInt32Chunked, + ordering: CategoricalOrdering, + ) -> PolarsResult { + let len = crate::STRING_CACHE.read_map().len() as u32; + let oob = cats.into_iter().flatten().any(|cat| cat >= len); + polars_ensure!( + !oob, + ComputeError: + "cannot construct Categorical from these categories; at least one of them is out of bounds" + ); + Ok(unsafe { Self::from_global_indices_unchecked(cats, ordering) }) + } + + /// Create a [`CategoricalChunked`] from a categorical indices. The indices will + /// probe the global string cache. + /// + /// # Safety + /// This does not do any bound checks + pub unsafe fn from_global_indices_unchecked( + cats: UInt32Chunked, + ordering: CategoricalOrdering, + ) -> CategoricalChunked { + let cache = crate::STRING_CACHE.read_map(); + + let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), _HASHMAP_INIT_SIZE); + let mut rev_map = PlHashMap::with_capacity(cap); + let mut str_values = MutablePlString::with_capacity(cap); + + for arr in cats.downcast_iter() { + for cat in arr.into_iter().flatten().copied() { + let offset = str_values.len() as u32; + + if let Entry::Vacant(entry) = rev_map.entry(cat) { + entry.insert(offset); + let str_val = cache.get_unchecked(cat); + str_values.push(Some(str_val)) + } + } + } + + let rev_map = RevMapping::Global(rev_map, str_values.into(), cache.uuid); + + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + Arc::new(rev_map), + false, + ordering, + ) + } + + pub(crate) unsafe fn from_keys_and_values_global( + name: PlSmallStr, + keys: impl IntoIterator> + Send, + capacity: usize, + values: &Utf8ViewArray, + ordering: CategoricalOrdering, + ) -> Self { + // Vec where the index is local and the value is the global index + let mut local_to_global: Vec = Vec::with_capacity(values.len()); + let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| { + // locally we don't need a hashmap because we all categories are 1 integer apart + // so the index is local, and the values is global + for s in values.values_iter() { + // SAFETY: we allocated enough + unsafe { local_to_global.push_unchecked(cache.insert(s)) } + } + local_to_global + }); + + let compute_cats = || { + let mut result = UInt32Vec::with_capacity(capacity); + + for opt_value in keys.into_iter() { + result.push(opt_value.map(|cat| { + debug_assert!((cat as usize) < local_to_global.len()); + *unsafe { local_to_global.get_unchecked(cat as usize) } + })); + } + result + }; + + let mut global_to_local = PlHashMap::with_capacity(local_to_global.len()); + let (_, cats) = POOL.join( + || fill_global_to_local(&local_to_global, &mut global_to_local), + compute_cats, + ); + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + UInt32Chunked::with_chunk(name, cats.into()), + Arc::new(RevMapping::Global(global_to_local, values.clone(), id)), + false, + ordering, + ) + } + } + + pub(crate) unsafe fn from_keys_and_values_local( + name: PlSmallStr, + keys: &PrimitiveArray, + values: &Utf8ViewArray, + ordering: CategoricalOrdering, + ) -> CategoricalChunked { + CategoricalChunked::from_cats_and_rev_map_unchecked( + UInt32Chunked::with_chunk(name, keys.clone()), + Arc::new(RevMapping::build_local(values.clone())), + false, + ordering, + ) + } + + /// # Safety + /// The caller must ensure that index values in the `keys` are in within bounds of the `values` length. + pub(crate) unsafe fn from_keys_and_values( + name: PlSmallStr, + keys: &PrimitiveArray, + values: &Utf8ViewArray, + ordering: CategoricalOrdering, + ) -> Self { + if !using_string_cache() { + CategoricalChunked::from_keys_and_values_local(name, keys, values, ordering) + } else { + CategoricalChunked::from_keys_and_values_global( + name, + keys.into_iter().map(|c| c.copied()), + keys.len(), + values, + ordering, + ) + } + } + + /// Create a [`CategoricalChunked`] from a fixed list of categories and a List of strings. + /// This will error if a string is not in the fixed list of categories + pub fn from_string_to_enum( + values: &StringChunked, + categories: &Utf8ViewArray, + ordering: CategoricalOrdering, + ) -> PolarsResult { + polars_ensure!(categories.null_count() == 0, ComputeError: "categories can not contain null values"); + + // Build a mapping string -> idx + let mut map = PlHashMap::with_capacity(categories.len()); + for (idx, cat) in categories.values_iter().enumerate_idx() { + #[allow(clippy::unnecessary_cast)] + map.insert(cat, idx as u32); + } + // Find idx of every value in the map + let iter = values.downcast_iter().map(|arr| { + arr.iter() + .map(|opt_s: Option<&str>| opt_s.and_then(|s| map.get(s).copied())) + .collect_arr() + }); + let mut keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name().clone(), iter); + keys.rename(values.name().clone()); + let rev_map = RevMapping::build_local(categories.clone()); + unsafe { + Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + keys, + Arc::new(rev_map), + true, + ordering, + ) + .with_fast_unique(false)) + } + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache}; + + #[test] + fn test_categorical_rev() -> PolarsResult<()> { + let _lock = SINGLE_LOCK.lock(); + disable_string_cache(); + let slice = &[ + Some("foo"), + None, + Some("bar"), + Some("foo"), + Some("foo"), + Some("bar"), + ]; + let ca = StringChunked::new(PlSmallStr::from_static("a"), slice); + let out = ca.cast(&DataType::Categorical(None, Default::default()))?; + let out = out.categorical().unwrap().clone(); + assert_eq!(out.get_rev_map().len(), 2); + + // test the global branch + enable_string_cache(); + // empty global cache + let out = ca.cast(&DataType::Categorical(None, Default::default()))?; + let out = out.categorical().unwrap().clone(); + assert_eq!(out.get_rev_map().len(), 2); + // full global cache + let out = ca.cast(&DataType::Categorical(None, Default::default()))?; + let out = out.categorical().unwrap().clone(); + assert_eq!(out.get_rev_map().len(), 2); + + // Check that we don't panic if we append two categorical arrays + // build under the same string cache + // https://github.com/pola-rs/polars/issues/1115 + let ca1 = StringChunked::new(PlSmallStr::from_static("a"), slice) + .cast(&DataType::Categorical(None, Default::default()))?; + let mut ca1 = ca1.categorical().unwrap().clone(); + let ca2 = StringChunked::new(PlSmallStr::from_static("a"), slice) + .cast(&DataType::Categorical(None, Default::default()))?; + let ca2 = ca2.categorical().unwrap(); + ca1.append(ca2).unwrap(); + + Ok(()) + } + + #[test] + fn test_categorical_builder() { + use crate::{disable_string_cache, enable_string_cache}; + let _lock = crate::SINGLE_LOCK.lock(); + for use_string_cache in [false, true] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + + // Use 2 builders to check if the global string cache + // does not interfere with the index mapping + let builder1 = CategoricalChunkedBuilder::new( + PlSmallStr::from_static("foo"), + 10, + Default::default(), + ); + let builder2 = CategoricalChunkedBuilder::new( + PlSmallStr::from_static("foo"), + 10, + Default::default(), + ); + let s = builder1 + .drain_iter_and_finish(vec![None, Some("hello"), Some("vietnam")]) + .into_series(); + assert_eq!(s.str_value(0).unwrap(), "null"); + assert_eq!(s.str_value(1).unwrap(), "hello"); + assert_eq!(s.str_value(2).unwrap(), "vietnam"); + + let s = builder2 + .drain_iter_and_finish(vec![Some("hello"), None, Some("world")]) + .into_series(); + assert_eq!(s.str_value(0).unwrap(), "hello"); + assert_eq!(s.str_value(1).unwrap(), "null"); + assert_eq!(s.str_value(2).unwrap(), "world"); + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs new file mode 100644 index 000000000000..94ddaab1bfb4 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -0,0 +1,100 @@ +use arrow::datatypes::IntegerType; +use polars_compute::cast::{CastOptionsImpl, cast, utf8view_to_utf8}; + +use super::*; + +fn convert_values(arr: &Utf8ViewArray, compat_level: CompatLevel) -> ArrayRef { + if compat_level.0 >= 1 { + arr.clone().boxed() + } else { + utf8view_to_utf8::(arr).boxed() + } +} + +impl CategoricalChunked { + pub fn to_arrow(&self, compat_level: CompatLevel, as_i64: bool) -> ArrayRef { + if as_i64 { + self.to_i64(compat_level).boxed() + } else { + self.to_u32(compat_level).boxed() + } + } + + fn to_u32(&self, compat_level: CompatLevel) -> DictionaryArray { + let values_dtype = if compat_level.0 >= 1 { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + let keys = self.physical().rechunk(); + let keys = keys.downcast_as_array(); + let map = &**self.get_rev_map(); + let dtype = ArrowDataType::Dictionary(IntegerType::UInt32, Box::new(values_dtype), false); + match map { + RevMapping::Local(arr, _) => { + let values = convert_values(arr, compat_level); + + // SAFETY: + // the keys are in bounds + unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() } + }, + RevMapping::Global(reverse_map, values, _uuid) => { + let iter = keys + .iter() + .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap())); + let keys = PrimitiveArray::from_trusted_len_iter(iter); + + let values = convert_values(values, compat_level); + + // SAFETY: + // the keys are in bounds + unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } + }, + } + } + + fn to_i64(&self, compat_level: CompatLevel) -> DictionaryArray { + let values_dtype = if compat_level.0 >= 1 { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + let keys = self.physical().rechunk(); + let keys = keys.downcast_as_array(); + let map = &**self.get_rev_map(); + let dtype = ArrowDataType::Dictionary(IntegerType::Int64, Box::new(values_dtype), false); + match map { + RevMapping::Local(arr, _) => { + let values = convert_values(arr, compat_level); + + // SAFETY: + // the keys are in bounds + unsafe { + DictionaryArray::try_new_unchecked( + dtype, + cast(keys, &ArrowDataType::Int64, CastOptionsImpl::unchecked()) + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap() + .clone(), + values, + ) + .unwrap() + } + }, + RevMapping::Global(reverse_map, values, _uuid) => { + let iter = keys + .iter() + .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap() as i64)); + let keys = PrimitiveArray::from_trusted_len_iter(iter); + + let values = convert_values(values, compat_level); + + // SAFETY: + // the keys are in bounds + unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } + }, + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs new file mode 100644 index 000000000000..28e92cb4f11e --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -0,0 +1,260 @@ +use std::borrow::Cow; + +use super::*; +use crate::series::IsSorted; +use crate::utils::align_chunks_binary; + +fn slots_to_mut(slots: &Utf8ViewArray) -> MutablePlString { + slots.clone().make_mut() +} + +struct State { + map: PlHashMap, + slots: MutablePlString, +} + +#[derive(Default)] +pub struct GlobalRevMapMerger { + id: u32, + original: Arc, + // only initiate state when + // we encounter a rev-map from a different source, + // but from the same string cache + state: Option, +} + +impl GlobalRevMapMerger { + pub fn new(rev_map: Arc) -> Self { + let RevMapping::Global(_, _, id) = rev_map.as_ref() else { + unreachable!() + }; + + GlobalRevMapMerger { + state: None, + id: *id, + original: rev_map, + } + } + + fn init_state(&mut self) { + let RevMapping::Global(map, slots, _) = self.original.as_ref() else { + unreachable!() + }; + self.state = Some(State { + map: (*map).clone(), + slots: slots_to_mut(slots), + }) + } + + pub fn merge_map(&mut self, rev_map: &Arc) -> PolarsResult<()> { + // happy path they come from the same source + if Arc::ptr_eq(&self.original, rev_map) { + return Ok(()); + } + + let RevMapping::Global(map, slots, id) = rev_map.as_ref() else { + polars_bail!(string_cache_mismatch) + }; + polars_ensure!(*id == self.id, string_cache_mismatch); + + if self.state.is_none() { + self.init_state() + } + let state = self.state.as_mut().unwrap(); + + for (cat, idx) in map.iter() { + state.map.entry(*cat).or_insert_with(|| { + // SAFETY: + // within bounds + let str_val = unsafe { slots.value_unchecked(*idx as usize) }; + let new_idx = state.slots.len() as u32; + state.slots.push(Some(str_val)); + + new_idx + }); + } + Ok(()) + } + + pub fn finish(self) -> Arc { + match self.state { + None => self.original, + Some(state) => { + let new_rev = RevMapping::Global(state.map, state.slots.into(), self.id); + Arc::new(new_rev) + }, + } + } +} + +fn merge_local_rhs_categorical<'a>( + categories: &'a Utf8ViewArray, + ca_right: &'a CategoricalChunked, +) -> Result<(UInt32Chunked, Arc), PolarsError> { + // Counterpart of the GlobalRevmapMerger. + // In case of local categorical we also need to change the physicals not only the revmap + + polars_warn!( + CategoricalRemappingWarning, + "Local categoricals have different encodings, expensive re-encoding is done \ + to perform this merge operation. Consider using a StringCache or an Enum type \ + if the categories are known in advance" + ); + + let RevMapping::Local(cats_right, _) = &**ca_right.get_rev_map() else { + unreachable!() + }; + + let cats_left_hashmap = PlHashMap::from_iter( + categories + .values_iter() + .enumerate() + .map(|(k, v)| (v, k as u32)), + ); + let mut new_categories = slots_to_mut(categories); + let mut idx_mapping = PlHashMap::with_capacity(cats_right.len()); + + for (idx, s) in cats_right.values_iter().enumerate() { + if let Some(v) = cats_left_hashmap.get(&s) { + idx_mapping.insert(idx as u32, *v); + } else { + idx_mapping.insert(idx as u32, new_categories.len() as u32); + new_categories.push(Some(s)); + } + } + let new_rev_map = Arc::new(RevMapping::build_local(new_categories.into())); + Ok(( + ca_right + .physical + .apply(|opt_v| opt_v.map(|v| *idx_mapping.get(&v).unwrap())), + new_rev_map, + )) +} + +pub trait CategoricalMergeOperation { + fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult; +} + +// Make the right categorical compatible with the left while applying the merge operation +pub fn call_categorical_merge_operation( + cat_left: &CategoricalChunked, + cat_right: &CategoricalChunked, + merge_ops: I, +) -> PolarsResult { + let rev_map_left = cat_left.get_rev_map(); + let rev_map_right = cat_right.get_rev_map(); + let (mut new_physical, new_rev_map) = match (&**rev_map_left, &**rev_map_right) { + (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { + let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_left.clone()); + rev_map_merger.merge_map(rev_map_right)?; + ( + merge_ops.finish(cat_left.physical(), cat_right.physical())?, + rev_map_merger.finish(), + ) + }, + (RevMapping::Local(_, idl), RevMapping::Local(_, idr)) + if idl == idr && cat_left.is_enum() == cat_right.is_enum() => + { + ( + merge_ops.finish(cat_left.physical(), cat_right.physical())?, + rev_map_left.clone(), + ) + }, + (RevMapping::Local(categorical, _), RevMapping::Local(_, _)) + if !cat_left.is_enum() && !cat_right.is_enum() => + { + let (rhs_physical, rev_map) = merge_local_rhs_categorical(categorical, cat_right)?; + ( + merge_ops.finish(cat_left.physical(), &rhs_physical)?, + rev_map, + ) + }, + (RevMapping::Local(_, _), RevMapping::Local(_, _)) + if cat_left.is_enum() | cat_right.is_enum() => + { + polars_bail!(ComputeError: "can not merge incompatible Enum types") + }, + _ => polars_bail!(string_cache_mismatch), + }; + // During merge operation, the sorted flag might get set on the underlying physical. + // Ensure that the sorted flag is not set if we use lexical order + if cat_left.uses_lexical_ordering() { + new_physical.set_sorted_flag(IsSorted::Not) + } + + // SAFETY: physical and rev map are correctly constructed above + unsafe { + Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + new_physical, + new_rev_map, + cat_left.is_enum(), + cat_left.get_ordering(), + )) + } +} + +struct DoNothing; +impl CategoricalMergeOperation for DoNothing { + fn finish(self, _lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { + Ok(rhs.clone()) + } +} + +// Make the right categorical compatible with the left +pub fn make_rhs_categoricals_compatible( + ca_left: &CategoricalChunked, + ca_right: &CategoricalChunked, +) -> PolarsResult<(CategoricalChunked, CategoricalChunked)> { + let new_ca_right = call_categorical_merge_operation(ca_left, ca_right, DoNothing)?; + + // Alter rev map of left + let mut new_ca_left = ca_left.clone(); + // SAFETY: We just made both rev maps compatible only appended categories + unsafe { + new_ca_left.set_rev_map( + new_ca_right.get_rev_map().clone(), + ca_left.get_rev_map().len() == new_ca_right.get_rev_map().len(), + ) + }; + + Ok((new_ca_left, new_ca_right)) +} + +pub fn make_rhs_list_categoricals_compatible( + mut list_ca_left: ListChunked, + list_ca_right: ListChunked, +) -> PolarsResult<(ListChunked, ListChunked)> { + // Make categoricals compatible + + let cat_left = list_ca_left.get_inner(); + let cat_right = list_ca_right.get_inner(); + let (cat_left, cat_right) = + make_rhs_categoricals_compatible(cat_left.categorical()?, cat_right.categorical()?)?; + + // we only appended categories to the rev_map at the end, so only change the inner dtype + list_ca_left.set_inner_dtype(cat_left.dtype().clone()); + + // We changed the physicals and the rev_map, offsets and validity buffers are still good + let (list_ca_right, cat_physical): (Cow, Cow) = + align_chunks_binary(&list_ca_right, cat_right.physical()); + let mut list_ca_right = list_ca_right.into_owned(); + // SAFETY: + // Chunks are aligned, length / dtype remains correct + unsafe { + list_ca_right + .downcast_iter_mut() + .zip(cat_physical.chunks()) + .for_each(|(arr, new_phys)| { + *arr = ListArray::new( + arr.dtype().clone(), + arr.offsets().clone(), + new_phys.clone(), + arr.validity().cloned(), + ) + }); + } + // reset the sorted flag and add extra categories back in + list_ca_right.set_sorted_flag(IsSorted::Not); + list_ca_right.set_inner_dtype(cat_right.dtype().clone()); + Ok((list_ca_left, list_ca_right)) +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs new file mode 100644 index 000000000000..1cb51d5ea8f1 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -0,0 +1,606 @@ +mod builder; +mod from; +mod merge; +mod ops; +pub mod revmap; +pub mod string_cache; + +use bitflags::bitflags; +pub use builder::*; +pub use merge::*; +use polars_utils::itertools::Itertools; +use polars_utils::sync::SyncPtr; +pub use revmap::*; + +use super::*; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::flags::StatisticsFlags; +use crate::prelude::*; +use crate::series::IsSorted; +use crate::using_string_cache; + +bitflags! { + #[derive(Default, Clone)] + struct BitSettings: u8 { + const ORIGINAL = 0x01; + } +} + +#[derive(Default, Clone)] +pub struct CategoricalChunked { + physical: Logical, + /// 1st bit: original local categorical + /// meaning that n_unique is the same as the cat map length + bit_settings: BitSettings, +} + +impl CategoricalChunked { + pub(crate) fn field(&self) -> Field { + let name = self.physical().name(); + Field::new(name.clone(), self.dtype().clone()) + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn len(&self) -> usize { + self.physical.len() + } + + #[inline] + pub fn null_count(&self) -> usize { + self.physical.null_count() + } + + pub fn name(&self) -> &PlSmallStr { + self.physical.name() + } + + /// Get the physical array (the category indexes). + pub fn into_physical(self) -> UInt32Chunked { + self.physical.0 + } + + // TODO: Rename this + /// Get a reference to the physical array (the categories). + pub fn physical(&self) -> &UInt32Chunked { + &self.physical + } + + /// Get a mutable reference to the physical array (the categories). + pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked { + &mut self.physical + } + + pub fn is_enum(&self) -> bool { + matches!(self.dtype(), DataType::Enum(_, _)) + } + + /// Convert a categorical column to its local representation. + pub fn to_local(&self) -> Self { + let rev_map = self.get_rev_map(); + let (physical_map, categories) = match rev_map.as_ref() { + RevMapping::Global(m, c, _) => (m, c), + RevMapping::Local(_, _) if !self.is_enum() => return self.clone(), + RevMapping::Local(_, _) => { + // Change dtype from Enum to Categorical + let mut local = self.clone(); + local.physical.2 = Some(DataType::Categorical( + Some(rev_map.clone()), + self.get_ordering(), + )); + return local; + }, + }; + + let local_rev_map = RevMapping::build_local(categories.clone()); + // TODO: A fast path can possibly be implemented here: + // if all physical map keys are equal to their values, + // we can skip the apply and only update the rev_map + let local_ca = self + .physical() + .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap())); + + let mut out = unsafe { + Self::from_cats_and_rev_map_unchecked( + local_ca, + local_rev_map.into(), + false, + self.get_ordering(), + ) + }; + out.set_fast_unique(self._can_fast_unique()); + + out + } + + pub fn to_global(&self) -> PolarsResult { + polars_ensure!(using_string_cache(), string_cache_mismatch); + // Fast path + let categories = match &**self.get_rev_map() { + RevMapping::Global(_, _, _) => return Ok(self.clone()), + RevMapping::Local(categories, _) => categories, + }; + + // SAFETY: keys and values are in bounds + unsafe { + Ok(CategoricalChunked::from_keys_and_values_global( + self.name().clone(), + self.physical(), + self.len(), + categories, + self.get_ordering(), + )) + } + } + + // Convert to fixed enum. Values not in categories are mapped to None. + pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self { + // Fast paths + match self.get_rev_map().as_ref() { + RevMapping::Local(_, cur_hash) if hash == *cur_hash => { + return unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + self.physical().clone(), + self.get_rev_map().clone(), + true, + self.get_ordering(), + ) + }; + }, + _ => (), + }; + // Make a mapping from old idx to new idx + let old_rev_map = self.get_rev_map(); + + // Create map of old category -> idx for fast lookup. + let old_categories = old_rev_map.get_categories(); + let old_idx_map: PlHashMap<&str, u32> = old_categories + .values_iter() + .zip(0..old_categories.len() as u32) + .collect(); + + #[allow(clippy::unnecessary_cast)] + let idx_map: PlHashMap = categories + .values_iter() + .enumerate_idx() + .filter_map(|(new_idx, s)| old_idx_map.get(s).map(|old_idx| (*old_idx, new_idx as u32))) + .collect(); + + // Loop over the physicals and try get new idx + let new_phys: UInt32Chunked = self + .physical() + .into_iter() + .map(|opt_v: Option| opt_v.and_then(|v| idx_map.get(&v).copied())) + .collect(); + + // SAFETY: we created the physical from the enum categories + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + new_phys, + Arc::new(RevMapping::Local(categories.clone(), hash)), + true, + self.get_ordering(), + ) + } + } + + pub(crate) fn get_flags(&self) -> StatisticsFlags { + self.physical().get_flags() + } + + /// Set flags for the Chunked Array + pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) { + // We should not set the sorted flag if we are sorting in lexical order + if self.uses_lexical_ordering() { + flags.set_sorted(IsSorted::Not) + } + self.physical_mut().set_flags(flags) + } + + /// Return whether or not the [`CategoricalChunked`] uses the lexical order + /// of the string values when sorting. + pub fn uses_lexical_ordering(&self) -> bool { + self.get_ordering() == CategoricalOrdering::Lexical + } + + pub fn get_ordering(&self) -> CategoricalOrdering { + if let DataType::Categorical(_, ordering) | DataType::Enum(_, ordering) = + &self.physical.2.as_ref().unwrap() + { + *ordering + } else { + panic!("implementation error") + } + } + + /// Create a [`CategoricalChunked`] from a physical array and dtype. + /// + /// # Safety + /// It's not checked that the indices are in-bounds or that the dtype is + /// correct. + pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self { + debug_assert!(matches!( + dtype, + DataType::Enum { .. } | DataType::Categorical { .. } + )); + let mut logical = Logical::::new_logical::(idx); + logical.2 = Some(dtype); + Self { + physical: logical, + bit_settings: Default::default(), + } + } + + /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`. + /// + /// # Safety + /// Invariant in `v < rev_map.len() for v in idx` must hold. + pub unsafe fn from_cats_and_rev_map_unchecked( + idx: UInt32Chunked, + rev_map: Arc, + is_enum: bool, + ordering: CategoricalOrdering, + ) -> Self { + let mut logical = Logical::::new_logical::(idx); + if is_enum { + logical.2 = Some(DataType::Enum(Some(rev_map), ordering)); + } else { + logical.2 = Some(DataType::Categorical(Some(rev_map), ordering)); + } + Self { + physical: logical, + bit_settings: Default::default(), + } + } + + pub(crate) fn set_ordering( + mut self, + ordering: CategoricalOrdering, + keep_fast_unique: bool, + ) -> Self { + self.physical.2 = match self.dtype() { + DataType::Enum(_, _) => { + Some(DataType::Enum(Some(self.get_rev_map().clone()), ordering)) + }, + DataType::Categorical(_, _) => Some(DataType::Categorical( + Some(self.get_rev_map().clone()), + ordering, + )), + _ => panic!("implementation error"), + }; + + if !keep_fast_unique { + self.set_fast_unique(false) + } + self + } + + /// # Safety + /// The existing index values must be in bounds of the new [`RevMapping`]. + pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc, keep_fast_unique: bool) { + self.physical.2 = match self.dtype() { + DataType::Enum(_, _) => Some(DataType::Enum(Some(rev_map), self.get_ordering())), + DataType::Categorical(_, _) => { + Some(DataType::Categorical(Some(rev_map), self.get_ordering())) + }, + _ => panic!("implementation error"), + }; + + if !keep_fast_unique { + self.set_fast_unique(false) + } + } + + /// True if all categories are represented in this array. When this is the case, the unique + /// values of the array are the categories. + pub fn _can_fast_unique(&self) -> bool { + self.bit_settings.contains(BitSettings::ORIGINAL) + && self.physical.chunks.len() == 1 + && self.null_count() == 0 + } + + pub(crate) fn set_fast_unique(&mut self, toggle: bool) { + if toggle { + self.bit_settings.insert(BitSettings::ORIGINAL); + } else { + self.bit_settings.remove(BitSettings::ORIGINAL); + } + } + + /// Set `FAST_UNIQUE` metadata + /// # Safety + /// This invariant must hold `unique(categories) == unique(self)` + pub(crate) unsafe fn with_fast_unique(mut self, toggle: bool) -> Self { + self.set_fast_unique(toggle); + self + } + + /// Set `FAST_UNIQUE` metadata + /// # Safety + /// This invariant must hold `unique(categories) == unique(self)` + pub unsafe fn _with_fast_unique(self, toggle: bool) -> Self { + self.with_fast_unique(toggle) + } + + /// Get a reference to the mapping of categorical types to the string values. + pub fn get_rev_map(&self) -> &Arc { + if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) = + &self.physical.2.as_ref().unwrap() + { + rev_map + } else { + panic!("implementation error") + } + } + + /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`]. + pub fn iter_str(&self) -> CatIter<'_> { + let iter = self.physical().into_iter(); + CatIter { + rev: self.get_rev_map(), + iter, + } + } +} + +impl LogicalType for CategoricalChunked { + fn dtype(&self) -> &DataType { + self.physical.2.as_ref().unwrap() + } + + fn get_any_value(&self, i: usize) -> PolarsResult> { + polars_ensure!(i < self.len(), oob = i, self.len()); + Ok(unsafe { self.get_any_value_unchecked(i) }) + } + + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + match self.physical.0.get_unchecked(i) { + Some(i) => match self.dtype() { + DataType::Enum(_, _) => AnyValue::Enum(i, self.get_rev_map(), SyncPtr::new_null()), + DataType::Categorical(_, _) => { + AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null()) + }, + _ => unimplemented!(), + }, + None => AnyValue::Null, + } + } + + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { + DataType::String => { + let mapping = &**self.get_rev_map(); + + let mut builder = + StringChunkedBuilder::new(self.physical.name().clone(), self.len()); + + let f = |idx: u32| mapping.get(idx); + + if !self.physical.has_nulls() { + self.physical + .into_no_null_iter() + .for_each(|idx| builder.append_value(f(idx))); + } else { + self.physical.into_iter().for_each(|opt_idx| { + builder.append_option(opt_idx.map(f)); + }); + } + + let ca = builder.finish(); + Ok(ca.into_series()) + }, + DataType::UInt32 => { + let ca = unsafe { + UInt32Chunked::from_chunks( + self.physical.name().clone(), + self.physical.chunks.clone(), + ) + }; + Ok(ca.into_series()) + }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(Some(rev_map), ordering) => { + let RevMapping::Local(categories, hash) = &**rev_map else { + polars_bail!(ComputeError: "can not cast to enum with global mapping") + }; + Ok(self + .to_enum(categories, *hash) + .set_ordering(*ordering, true) + .into_series() + .with_name(self.name().clone())) + }, + DataType::Enum(None, _) => { + polars_bail!(ComputeError: "can not cast to enum without categories present") + }, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(rev_map, ordering) => { + // Casting from an Enum to a local or global + if matches!(self.dtype(), DataType::Enum(_, _)) && rev_map.is_none() { + if using_string_cache() { + return Ok(self + .to_global()? + .set_ordering(*ordering, true) + .into_series()); + } else { + return Ok(self.to_local().set_ordering(*ordering, true).into_series()); + } + } + // If casting to lexical categorical, set sorted flag as not set + + let mut ca = self.clone().set_ordering(*ordering, true); + if ca.uses_lexical_ordering() { + ca.physical.set_sorted_flag(IsSorted::Not); + } + Ok(ca.into_series()) + }, + dt if dt.is_primitive_numeric() => { + // Apply the cast to the categories and then index into the casted series. + // This has to be local for the gather. + let slf = self.to_local(); + let categories = StringChunked::with_chunk( + slf.physical.name().clone(), + slf.get_rev_map().get_categories().clone(), + ); + let casted_series = categories.cast_with_options(dtype, options)?; + + #[cfg(feature = "bigidx")] + { + let s = slf.physical.cast_with_options(&DataType::UInt64, options)?; + Ok(unsafe { casted_series.take_unchecked(s.u64()?) }) + } + #[cfg(not(feature = "bigidx"))] + { + // SAFETY: Invariant of categorical means indices are in bound + Ok(unsafe { casted_series.take_unchecked(&slf.physical) }) + } + }, + _ => self.physical.cast_with_options(dtype, options), + } + } +} + +pub struct CatIter<'a> { + rev: &'a RevMapping, + iter: Box> + 'a>, +} + +unsafe impl TrustedLen for CatIter<'_> {} + +impl<'a> Iterator for CatIter<'a> { + type Item = Option<&'a str>; + + fn next(&mut self) -> Option { + self.iter.next().map(|item| { + item.map(|idx| { + // SAFETY: + // all categories are in bound + unsafe { self.rev.get_unchecked(idx) } + }) + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +impl DoubleEndedIterator for CatIter<'_> { + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|item| { + item.map(|idx| { + // SAFETY: + // all categories are in bound + unsafe { self.rev.get_unchecked(idx) } + }) + }) + } +} + +impl ExactSizeIterator for CatIter<'_> {} + +#[cfg(test)] +mod test { + use super::*; + use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache}; + + #[test] + fn test_categorical_round_trip() -> PolarsResult<()> { + let _lock = SINGLE_LOCK.lock(); + disable_string_cache(); + let slice = &[ + Some("foo"), + None, + Some("bar"), + Some("foo"), + Some("foo"), + Some("bar"), + ]; + let ca = StringChunked::new(PlSmallStr::from_static("a"), slice); + let ca = ca.cast(&DataType::Categorical(None, Default::default()))?; + let ca = ca.categorical().unwrap(); + + let arr = ca.to_arrow(CompatLevel::newest(), false); + let s = Series::try_from((PlSmallStr::from_static("foo"), arr))?; + assert!(matches!(s.dtype(), &DataType::Categorical(_, _))); + assert_eq!(s.null_count(), 1); + assert_eq!(s.len(), 6); + + Ok(()) + } + + #[test] + fn test_append_categorical() { + let _lock = SINGLE_LOCK.lock(); + disable_string_cache(); + enable_string_cache(); + + let mut s1 = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"]) + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + let s2 = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"]) + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + let appended = s1.append(&s2).unwrap(); + assert_eq!(appended.str_value(0).unwrap(), "a"); + assert_eq!(appended.str_value(1).unwrap(), "b"); + assert_eq!(appended.str_value(4).unwrap(), "x"); + assert_eq!(appended.str_value(5).unwrap(), "y"); + } + + #[test] + fn test_fast_unique() { + let _lock = SINGLE_LOCK.lock(); + let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"]) + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + + assert_eq!(s.n_unique().unwrap(), 3); + // Make sure that it does not take the fast path after take/slice. + let out = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap(); + assert_eq!(out.n_unique().unwrap(), 2); + let out = s.slice(1, 2); + assert_eq!(out.n_unique().unwrap(), 2); + } + + #[test] + fn test_categorical_flow() -> PolarsResult<()> { + let _lock = SINGLE_LOCK.lock(); + disable_string_cache(); + + // tests several things that may lose the dtype information + let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"]) + .cast(&DataType::Categorical(None, Default::default()))?; + + assert_eq!( + s.field().into_owned(), + Field::new( + PlSmallStr::from_static("a"), + DataType::Categorical(None, Default::default()) + ) + ); + assert!(matches!( + s.get(0)?, + AnyValue::Categorical(0, RevMapping::Local(_, _), _) + )); + + let groups = s.group_tuples(false, true); + let aggregated = unsafe { s.agg_list(&groups?) }; + match aggregated.get(0)? { + AnyValue::List(s) => { + assert!(matches!(s.dtype(), DataType::Categorical(_, _))); + let str_s = s.cast(&DataType::String).unwrap(); + assert_eq!(str_s.get(0)?, AnyValue::String("a")); + assert_eq!(s.len(), 1); + }, + _ => panic!(), + } + let flat = aggregated.explode()?; + let ca = flat.categorical().unwrap(); + let vals = ca.iter_str().map(|v| v.unwrap()).collect::>(); + assert_eq!(vals, &["a", "b", "c"]); + Ok(()) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs new file mode 100644 index 000000000000..4339f8d9e3fa --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs @@ -0,0 +1,49 @@ +use polars_error::constants::LENGTH_LIMIT_MSG; + +use super::*; +use crate::chunked_array::ops::append::new_chunks; + +struct CategoricalAppend; + +impl CategoricalMergeOperation for CategoricalAppend { + fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { + let mut lhs_mut = lhs.clone(); + lhs_mut.append(rhs)?; + Ok(lhs_mut) + } +} + +impl CategoricalChunked { + fn set_lengths(&mut self, other: &Self) { + let length_self = &mut self.physical_mut().length; + *length_self = length_self + .checked_add(other.len()) + .expect(LENGTH_LIMIT_MSG); + + assert!( + IdxSize::try_from(*length_self).is_ok(), + "{}", + LENGTH_LIMIT_MSG + ); + self.physical_mut().null_count += other.null_count(); + } + + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + // fast path all nulls + if self.physical.null_count() == self.len() && other.physical.null_count() == other.len() { + let len = self.len(); + self.set_lengths(other); + new_chunks(&mut self.physical.chunks, &other.physical().chunks, len); + return Ok(()); + } + + let mut new_self = call_categorical_merge_operation(self, other, CategoricalAppend)?; + std::mem::swap(self, &mut new_self); + Ok(()) + } + + pub fn append_owned(&mut self, other: Self) -> PolarsResult<()> { + // @TODO: Move the implementation to append_owned and make append dispatch here. + self.append(&other) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs new file mode 100644 index 000000000000..ed53722d163e --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs @@ -0,0 +1,21 @@ +use super::*; + +impl CategoricalChunked { + pub fn full_null( + name: PlSmallStr, + is_enum: bool, + length: usize, + ordering: CategoricalOrdering, + ) -> CategoricalChunked { + let cats = UInt32Chunked::full_null(name, length); + + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + Arc::new(RevMapping::default()), + is_enum, + ordering, + ) + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs new file mode 100644 index 000000000000..759628b322cb --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs @@ -0,0 +1,8 @@ +mod append; +mod full; +#[cfg(feature = "algorithm_group_by")] +mod unique; +#[cfg(feature = "zip_with")] +mod zip; + +use super::*; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs new file mode 100644 index 000000000000..17752f828d8d --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -0,0 +1,89 @@ +use super::*; + +impl CategoricalChunked { + pub fn unique(&self) -> PolarsResult { + let cat_map = self.get_rev_map(); + if self.is_empty() { + // SAFETY: rev map is valid. + unsafe { + return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + UInt32Chunked::full_null(self.name().clone(), 0), + cat_map.clone(), + self.is_enum(), + self.get_ordering(), + )); + } + }; + + if self._can_fast_unique() { + let ca = match &**cat_map { + RevMapping::Local(a, _) => UInt32Chunked::from_iter_values( + self.physical().name().clone(), + 0..(a.len() as u32), + ), + RevMapping::Global(map, _, _) => UInt32Chunked::from_iter_values( + self.physical().name().clone(), + map.keys().copied(), + ), + }; + // SAFETY: + // we only removed some indexes so we are still in bounds + unsafe { + let mut out = CategoricalChunked::from_cats_and_rev_map_unchecked( + ca, + cat_map.clone(), + self.is_enum(), + self.get_ordering(), + ); + out.set_fast_unique(true); + Ok(out) + } + } else { + let ca = self.physical().unique()?; + // SAFETY: + // we only removed some indexes so we are still in bounds + unsafe { + Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + ca, + cat_map.clone(), + self.is_enum(), + self.get_ordering(), + )) + } + } + } + + pub fn n_unique(&self) -> PolarsResult { + if self._can_fast_unique() { + Ok(self.get_rev_map().len()) + } else { + self.physical().n_unique() + } + } + + pub fn value_counts(&self) -> PolarsResult { + let groups = self.physical().group_tuples(true, false).unwrap(); + let physical_values = unsafe { + self.physical() + .clone() + .into_series() + .agg_first(&groups) + .u32() + .unwrap() + .clone() + }; + + let mut values = self.clone(); + *values.physical_mut() = physical_values; + + let mut counts = groups.group_count(); + counts.rename(PlSmallStr::from_static("counts")); + let height = counts.len(); + let cols = vec![values.into_series().into(), counts.into_series().into()]; + let df = unsafe { DataFrame::new_no_checks(height, cols) }; + df.sort( + ["counts"], + SortMultipleOptions::default().with_order_descending(true), + ) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs new file mode 100644 index 000000000000..b92f46c68e17 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs @@ -0,0 +1,18 @@ +use super::*; + +struct CategoricalZipWith<'a>(&'a BooleanChunked); + +impl CategoricalMergeOperation for CategoricalZipWith<'_> { + fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { + lhs.zip_with(self.0, rhs) + } +} +impl CategoricalChunked { + pub(crate) fn zip_with( + &self, + mask: &BooleanChunked, + other: &CategoricalChunked, + ) -> PolarsResult { + call_categorical_merge_operation(self, other, CategoricalZipWith(mask)) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs new file mode 100644 index 000000000000..0b27aaec5bbe --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs @@ -0,0 +1,170 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::fmt::{Debug, Formatter}; +use std::hash::{BuildHasher, Hash, Hasher}; + +use arrow::array::*; +use polars_utils::aliases::PlFixedStateQuality; + +use crate::datatypes::PlHashMap; +use crate::{StringCache, using_string_cache}; + +#[derive(Clone)] +pub enum RevMapping { + /// Hashmap: maps the indexes from the global cache/categorical array to indexes in the local Utf8Array + /// Utf8Array: caches the string values + Global(PlHashMap, Utf8ViewArray, u32), + /// Utf8Array: caches the string values and a hash of all values for quick comparison + Local(Utf8ViewArray, u128), +} + +impl Debug for RevMapping { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + RevMapping::Global(_, _, _) => { + write!(f, "global") + }, + RevMapping::Local(_, _) => { + write!(f, "local") + }, + } + } +} + +impl Default for RevMapping { + fn default() -> Self { + let slice: &[Option<&str>] = &[]; + let cats = Utf8ViewArray::from_slice(slice); + if using_string_cache() { + let cache = &mut crate::STRING_CACHE.lock_map(); + let id = cache.uuid; + RevMapping::Global(Default::default(), cats, id) + } else { + RevMapping::build_local(cats) + } + } +} + +#[allow(clippy::len_without_is_empty)] +impl RevMapping { + pub fn is_active_global(&self) -> bool { + match self { + Self::Global(_, _, id) => *id == StringCache::active_cache_id(), + _ => false, + } + } + + pub fn is_global(&self) -> bool { + matches!(self, Self::Global(_, _, _)) + } + + pub fn is_local(&self) -> bool { + matches!(self, Self::Local(_, _)) + } + + /// Get the categories in this [`RevMapping`] + pub fn get_categories(&self) -> &Utf8ViewArray { + match self { + Self::Global(_, a, _) => a, + Self::Local(a, _) => a, + } + } + + fn build_hash(categories: &Utf8ViewArray) -> u128 { + // TODO! we must also validate the cases of duplicates! + let mut hb = PlFixedStateQuality::with_seed(0).build_hasher(); + categories.values_iter().for_each(|val| { + val.hash(&mut hb); + }); + let hash = hb.finish(); + ((hash as u128) << 64) | (categories.total_buffer_len() as u128) + } + + pub fn build_local(categories: Utf8ViewArray) -> Self { + debug_assert_eq!(categories.null_count(), 0); + let hash = Self::build_hash(&categories); + Self::Local(categories, hash) + } + + /// Get the length of the [`RevMapping`] + pub fn len(&self) -> usize { + self.get_categories().len() + } + + /// [`Categorical`] to [`str`] + /// + /// [`Categorical`]: crate::datatypes::DataType::Categorical + pub fn get(&self, idx: u32) -> &str { + match self { + Self::Global(map, a, _) => { + let idx = *map.get(&idx).unwrap(); + a.value(idx as usize) + }, + Self::Local(a, _) => a.value(idx as usize), + } + } + + pub fn get_optional(&self, idx: u32) -> Option<&str> { + match self { + Self::Global(map, a, _) => { + let idx = *map.get(&idx)?; + a.get(idx as usize) + }, + Self::Local(a, _) => a.get(idx as usize), + } + } + + /// [`Categorical`] to [`str`] + /// + /// [`Categorical`]: crate::datatypes::DataType::Categorical + /// + /// # Safety + /// This doesn't do any bound checking + pub(crate) unsafe fn get_unchecked(&self, idx: u32) -> &str { + match self { + Self::Global(map, a, _) => { + let idx = *map.get(&idx).unwrap(); + a.value_unchecked(idx as usize) + }, + Self::Local(a, _) => a.value_unchecked(idx as usize), + } + } + /// Check if the categoricals have a compatible mapping + #[inline] + pub fn same_src(&self, other: &Self) -> bool { + match (self, other) { + (RevMapping::Global(_, _, l), RevMapping::Global(_, _, r)) => *l == *r, + (RevMapping::Local(_, l_hash), RevMapping::Local(_, r_hash)) => l_hash == r_hash, + _ => false, + } + } + + /// [`str`] to [`Categorical`] + /// + /// + /// [`Categorical`]: crate::datatypes::DataType::Categorical + pub fn find(&self, value: &str) -> Option { + match self { + Self::Global(rev_map, a, id) => { + // fast path is check + if using_string_cache() { + let map = crate::STRING_CACHE.read_map(); + if map.uuid == *id { + return map.get_cat(value); + } + } + rev_map + .iter() + // SAFETY: + // value is always within bounds + .find(|&(_k, &v)| (unsafe { a.value_unchecked(v as usize) } == value)) + .map(|(k, _v)| *k) + }, + + Self::Local(a, _) => { + // SAFETY: within bounds + unsafe { (0..a.len()).find(|idx| a.value_unchecked(*idx) == value) } + .map(|idx| idx as u32) + }, + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs new file mode 100644 index 000000000000..1b70b3efa894 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs @@ -0,0 +1,258 @@ +use std::hash::{BuildHasher, Hash, Hasher}; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::{LazyLock, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +use hashbrown::HashTable; +use hashbrown::hash_table::Entry; +use polars_utils::aliases::PlFixedStateQuality; +use polars_utils::pl_str::PlSmallStr; + +use crate::hashing::_HASHMAP_INIT_SIZE; + +/// We use atomic reference counting to determine how many threads use the +/// string cache. If the refcount is zero, we may clear the string cache. +static STRING_CACHE_REFCOUNT: Mutex = Mutex::new(0); +static STRING_CACHE_ENABLED_GLOBALLY: AtomicBool = AtomicBool::new(false); +static STRING_CACHE_UUID_CTR: AtomicU32 = AtomicU32::new(0); + +/// Enable the global string cache as long as the object is alive ([RAII]). +/// +/// # Examples +/// +/// Enable the string cache by initializing the object: +/// +/// ``` +/// use polars_core::StringCacheHolder; +/// +/// let _sc = StringCacheHolder::hold(); +/// ``` +/// +/// The string cache is enabled until `handle` is dropped. +/// +/// # De-allocation +/// +/// Multiple threads can hold the string cache at the same time. +/// The contents of the cache will only get dropped when no thread holds it. +/// +/// [RAII]: https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization +pub struct StringCacheHolder { + // only added so that it will never be constructed directly + #[allow(dead_code)] + private_zst: (), +} + +impl Default for StringCacheHolder { + fn default() -> Self { + Self::hold() + } +} + +impl StringCacheHolder { + /// Hold the StringCache + pub fn hold() -> StringCacheHolder { + increment_string_cache_refcount(); + StringCacheHolder { private_zst: () } + } +} + +impl Drop for StringCacheHolder { + fn drop(&mut self) { + decrement_string_cache_refcount(); + } +} + +fn increment_string_cache_refcount() { + let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount += 1; +} +fn decrement_string_cache_refcount() { + let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount -= 1; + if *refcount == 0 { + STRING_CACHE.clear() + } +} + +/// Enable the global string cache. +/// +/// [`Categorical`] columns created under the same global string cache have the +/// same underlying physical value when string values are equal. This allows the +/// columns to be concatenated or used in a join operation, for example. +/// +/// Note that enabling the global string cache introduces some overhead. +/// The amount of overhead depends on the number of categories in your data. +/// It is advised to enable the global string cache only when strictly necessary. +/// +/// [`Categorical`]: crate::datatypes::DataType::Categorical +pub fn enable_string_cache() { + let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(true, Ordering::AcqRel); + if !was_enabled { + increment_string_cache_refcount(); + } +} + +/// Disable and clear the global string cache. +/// +/// Note: Consider using [`StringCacheHolder`] for a more reliable way of +/// enabling and disabling the string cache. +pub fn disable_string_cache() { + let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(false, Ordering::AcqRel); + if was_enabled { + decrement_string_cache_refcount(); + } +} + +/// Check whether the global string cache is enabled. +pub fn using_string_cache() -> bool { + let refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount > 0 +} + +// This is the hash and the Index offset in the linear buffer +#[derive(Copy, Clone)] +struct Key { + pub(super) hash: u64, + pub(super) idx: u32, +} + +impl Key { + #[inline] + pub(super) fn new(hash: u64, idx: u32) -> Self { + Self { hash, idx } + } +} + +impl Hash for Key { + #[inline] + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} + +pub(crate) struct SCacheInner { + map: HashTable, + pub(crate) uuid: u32, + payloads: Vec, +} + +impl SCacheInner { + #[inline] + pub(crate) unsafe fn get_unchecked(&self, cat: u32) -> &str { + self.payloads.get_unchecked(cat as usize).as_str() + } + + pub(crate) fn len(&self) -> usize { + self.map.len() + } + + #[inline] + pub(crate) fn insert_from_hash(&mut self, h: u64, s: &str) -> u32 { + let mut global_idx = self.payloads.len() as u32; + let entry = self.map.entry( + h, + |k| { + let value = unsafe { self.payloads.get_unchecked(k.idx as usize) }; + s == value.as_str() + }, + |k| k.hash, + ); + + match entry { + Entry::Occupied(entry) => { + global_idx = entry.get().idx; + }, + Entry::Vacant(entry) => { + let idx = self.payloads.len() as u32; + let key = Key::new(h, idx); + entry.insert(key); + self.payloads.push(PlSmallStr::from_str(s)); + }, + } + global_idx + } + + #[inline] + pub(crate) fn get_cat(&self, s: &str) -> Option { + let h = StringCache::get_hash_builder().hash_one(s); + self.map + .find(h, |k| { + let value = unsafe { self.payloads.get_unchecked(k.idx as usize) }; + s == value.as_str() + }) + .map(|k| k.idx) + } + + #[inline] + pub(crate) fn insert(&mut self, s: &str) -> u32 { + let h = StringCache::get_hash_builder().hash_one(s); + self.insert_from_hash(h, s) + } + + #[inline] + pub(crate) fn get_current_payloads(&self) -> &[PlSmallStr] { + &self.payloads + } +} + +impl Default for SCacheInner { + fn default() -> Self { + Self { + map: HashTable::with_capacity(_HASHMAP_INIT_SIZE), + uuid: STRING_CACHE_UUID_CTR.fetch_add(1, Ordering::AcqRel), + payloads: Vec::with_capacity(_HASHMAP_INIT_SIZE), + } + } +} + +/// Used by categorical data that need to share global categories. +/// In *eager* you need to specifically toggle global string cache to have a global effect. +/// In *lazy* it is toggled on at the start of a computation run and turned of (deleted) when a +/// result is produced. +#[derive(Default)] +pub(crate) struct StringCache(pub(crate) RwLock); + +impl StringCache { + /// The global `StringCache` will always use a predictable seed. This allows local builders to mimic + /// the hashes in case of contention. + #[inline] + pub(crate) fn get_hash_builder() -> PlFixedStateQuality { + PlFixedStateQuality::with_seed(0) + } + + pub(crate) fn active_cache_id() -> u32 { + STRING_CACHE_UUID_CTR + .load(Ordering::Relaxed) + .wrapping_sub(1) + } + + /// Lock the string cache + pub(crate) fn lock_map(&self) -> RwLockWriteGuard { + self.0.write().unwrap() + } + + pub(crate) fn read_map(&self) -> RwLockReadGuard { + self.0.read().unwrap() + } + + pub(crate) fn clear(&self) { + let mut lock = self.lock_map(); + *lock = Default::default(); + } + + pub(crate) fn apply(&self, fun: F) -> (u32, T) + where + F: FnOnce(&mut RwLockWriteGuard) -> T, + { + let cache = &mut crate::STRING_CACHE.lock_map(); + + let result = fun(cache); + + if cache.len() > u32::MAX as usize { + panic!("not more than {} categories supported", u32::MAX) + }; + + (cache.uuid, result) + } +} + +pub(crate) static STRING_CACHE: LazyLock = LazyLock::new(Default::default); diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs new file mode 100644 index 000000000000..ae67fd09ed0e --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -0,0 +1,61 @@ +use super::*; +use crate::prelude::*; +pub type DateChunked = Logical; + +impl From for DateChunked { + fn from(ca: Int32Chunked) -> Self { + DateChunked::new_logical(ca) + } +} + +impl Int32Chunked { + pub fn into_date(self) -> DateChunked { + DateChunked::new_logical(self) + } +} + +impl LogicalType for DateChunked { + fn dtype(&self) -> &DataType { + &DataType::Date + } + + fn get_any_value(&self, i: usize) -> PolarsResult> { + self.0.get_any_value(i).map(|av| av.as_date()) + } + + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + self.0.get_any_value_unchecked(i).as_date() + } + + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + use DataType::*; + match dtype { + Date => Ok(self.clone().into_series()), + #[cfg(feature = "dtype-datetime")] + Datetime(tu, tz) => { + let casted = self.0.cast_with_options(dtype, cast_options)?; + let casted = casted.datetime().unwrap(); + let conversion = match tu { + TimeUnit::Nanoseconds => NS_IN_DAY, + TimeUnit::Microseconds => US_IN_DAY, + TimeUnit::Milliseconds => MS_IN_DAY, + }; + Ok((casted.deref() * conversion) + .into_datetime(*tu, tz.clone()) + .into_series()) + }, + dt if dt.is_primitive_numeric() => self.0.cast_with_options(dtype, cast_options), + dt => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dt + ) + }, + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/datetime.rs b/crates/polars-core/src/chunked_array/logical/datetime.rs new file mode 100644 index 000000000000..b5b6434d41a5 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/datetime.rs @@ -0,0 +1,121 @@ +use super::*; +use crate::prelude::*; + +pub type DatetimeChunked = Logical; + +impl Int64Chunked { + pub fn into_datetime(self, timeunit: TimeUnit, tz: Option) -> DatetimeChunked { + let mut dt = DatetimeChunked::new_logical(self); + dt.2 = Some(DataType::Datetime(timeunit, tz)); + dt + } +} + +impl LogicalType for DatetimeChunked { + fn dtype(&self) -> &DataType { + self.2.as_ref().unwrap() + } + + fn get_any_value(&self, i: usize) -> PolarsResult> { + self.0 + .get_any_value(i) + .map(|av| av.as_datetime(self.time_unit(), self.time_zone().as_ref())) + } + + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + self.0 + .get_any_value_unchecked(i) + .as_datetime(self.time_unit(), self.time_zone().as_ref()) + } + + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + use DataType::*; + use TimeUnit::*; + let out = match dtype { + Datetime(to_unit, tz) => { + let from_unit = self.time_unit(); + let (multiplier, divisor) = match (from_unit, to_unit) { + // scaling from lower precision to higher precision + (Milliseconds, Nanoseconds) => (Some(1_000_000i64), None), + (Milliseconds, Microseconds) => (Some(1_000i64), None), + (Microseconds, Nanoseconds) => (Some(1_000i64), None), + // scaling from higher precision to lower precision + (Nanoseconds, Milliseconds) => (None, Some(1_000_000i64)), + (Nanoseconds, Microseconds) => (None, Some(1_000i64)), + (Microseconds, Milliseconds) => (None, Some(1_000i64)), + _ => return self.0.cast_with_options(dtype, cast_options), + }; + match multiplier { + // scale to higher precision (eg: ms → us, ms → ns, us → ns) + Some(m) => Ok((self.0.as_ref() * m) + .into_datetime(*to_unit, tz.clone()) + .into_series()), + // scale to lower precision (eg: ns → us, ns → ms, us → ms) + None => match divisor { + Some(d) => Ok(self + .0 + .apply_values(|v| v.div_euclid(d)) + .into_datetime(*to_unit, tz.clone()) + .into_series()), + None => unreachable!("must always have a time unit divisor here"), + }, + } + }, + #[cfg(feature = "dtype-date")] + Date => { + let cast_to_date = |tu_in_day: i64| { + let mut dt = self + .0 + .apply_values(|v| v.div_euclid(tu_in_day)) + .cast_with_options(&Int32, cast_options) + .unwrap() + .into_date() + .into_series(); + dt.set_sorted_flag(self.is_sorted_flag()); + Ok(dt) + }; + match self.time_unit() { + Nanoseconds => cast_to_date(NS_IN_DAY), + Microseconds => cast_to_date(US_IN_DAY), + Milliseconds => cast_to_date(MS_IN_DAY), + } + }, + #[cfg(feature = "dtype-time")] + Time => { + let (scaled_mod, multiplier) = match self.time_unit() { + Nanoseconds => (NS_IN_DAY, 1i64), + Microseconds => (US_IN_DAY, 1_000i64), + Milliseconds => (MS_IN_DAY, 1_000_000i64), + }; + return Ok(self + .0 + .apply_values(|v| { + let t = v % scaled_mod * multiplier; + t + (NS_IN_DAY * (t < 0) as i64) + }) + .into_time() + .into_series()); + }, + dt if dt.is_primitive_numeric() => { + return self.0.cast_with_options(dtype, cast_options); + }, + dt => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dt + ) + }, + }; + out.map(|mut s| { + // TODO!; implement the divisions/multipliers above + // in a checked manner so that we raise on overflow + s.set_sorted_flag(self.is_sorted_flag()); + s + }) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/decimal.rs b/crates/polars-core/src/chunked_array/logical/decimal.rs new file mode 100644 index 000000000000..306e164dba9e --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/decimal.rs @@ -0,0 +1,130 @@ +use std::borrow::Cow; + +use super::*; +use crate::chunked_array::cast::cast_chunks; +use crate::prelude::*; + +pub type DecimalChunked = Logical; + +impl Int128Chunked { + #[inline] + pub fn into_decimal_unchecked(self, precision: Option, scale: usize) -> DecimalChunked { + let mut dt = DecimalChunked::new_logical(self); + dt.2 = Some(DataType::Decimal(precision, Some(scale))); + dt + } + + pub fn into_decimal( + self, + precision: Option, + scale: usize, + ) -> PolarsResult { + // TODO: if precision is None, do we check that the value fits within precision of 38?... + if let Some(precision) = precision { + let precision_max = 10_i128.pow(precision as u32); + if let Some((min, max)) = self.min_max() { + let max_abs = max.abs().max(min.abs()); + polars_ensure!( + max_abs < precision_max, + ComputeError: "decimal precision {} can't fit values with {} digits", + precision, + max_abs.to_string().len() + ); + } + } + Ok(self.into_decimal_unchecked(precision, scale)) + } +} + +impl LogicalType for DecimalChunked { + fn dtype(&self) -> &DataType { + self.2.as_ref().unwrap() + } + + #[inline] + fn get_any_value(&self, i: usize) -> PolarsResult> { + polars_ensure!(i < self.len(), oob = i, self.len()); + Ok(unsafe { self.get_any_value_unchecked(i) }) + } + + #[inline] + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + match self.0.get_unchecked(i) { + Some(v) => AnyValue::Decimal(v, self.scale()), + None => AnyValue::Null, + } + } + + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + let mut dtype = Cow::Borrowed(dtype); + if let DataType::Decimal(to_precision, to_scale) = dtype.as_ref() { + let from_precision = self.precision(); + let from_scale = self.scale(); + + let to_precision = to_precision.or(from_precision); + let to_scale = to_scale.unwrap_or(from_scale); + + if to_precision == from_precision && to_scale == from_scale { + return Ok(self.clone().into_series()); + } + + dtype = Cow::Owned(DataType::Decimal(to_precision, Some(to_scale))); + } + + let arrow_dtype = self.dtype().to_arrow(CompatLevel::newest()); + let chunks = self + .chunks + .iter() + .map(|arr| { + arr.as_any() + .downcast_ref::>() + .unwrap() + .clone() + .to(arrow_dtype.clone()) + .to_boxed() + }) + .collect::>(); + let chunks = cast_chunks(&chunks, dtype.as_ref(), cast_options)?; + Series::try_from((self.name().clone(), chunks)) + } +} + +impl DecimalChunked { + pub fn precision(&self) -> Option { + match self.2.as_ref().unwrap() { + DataType::Decimal(precision, _) => *precision, + _ => unreachable!(), + } + } + + pub fn scale(&self) -> usize { + match self.2.as_ref().unwrap() { + DataType::Decimal(_, scale) => scale.unwrap_or_else(|| unreachable!()), + _ => unreachable!(), + } + } + + pub fn to_scale(&self, scale: usize) -> PolarsResult> { + if self.scale() == scale { + return Ok(Cow::Borrowed(self)); + } + + let mut precision = self.precision(); + if let Some(ref mut precision) = precision { + if self.scale() < scale { + *precision += scale; + *precision = (*precision).min(38); + } + } + + let s = self.cast_with_options( + &DataType::Decimal(precision, Some(scale)), + CastOptions::NonStrict, + )?; + Ok(Cow::Owned(s.decimal().unwrap().clone())) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/duration.rs b/crates/polars-core/src/chunked_array/logical/duration.rs new file mode 100644 index 000000000000..1dc0eab17c5d --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/duration.rs @@ -0,0 +1,67 @@ +use super::*; +use crate::prelude::*; + +pub type DurationChunked = Logical; + +impl Int64Chunked { + pub fn into_duration(self, timeunit: TimeUnit) -> DurationChunked { + let mut dt = DurationChunked::new_logical(self); + dt.2 = Some(DataType::Duration(timeunit)); + dt + } +} + +impl LogicalType for DurationChunked { + fn dtype(&self) -> &DataType { + self.2.as_ref().unwrap() + } + + fn get_any_value(&self, i: usize) -> PolarsResult> { + self.0 + .get_any_value(i) + .map(|av| av.as_duration(self.time_unit())) + } + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + self.0 + .get_any_value_unchecked(i) + .as_duration(self.time_unit()) + } + + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + use DataType::*; + use TimeUnit::*; + match dtype { + Duration(tu) => { + let to_unit = *tu; + let out = match (self.time_unit(), to_unit) { + (Milliseconds, Microseconds) => self.0.as_ref() * 1_000i64, + (Milliseconds, Nanoseconds) => self.0.as_ref() * 1_000_000i64, + (Microseconds, Milliseconds) => { + self.0.as_ref().wrapping_trunc_div_scalar(1_000i64) + }, + (Microseconds, Nanoseconds) => self.0.as_ref() * 1_000i64, + (Nanoseconds, Milliseconds) => { + self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64) + }, + (Nanoseconds, Microseconds) => { + self.0.as_ref().wrapping_trunc_div_scalar(1_000i64) + }, + _ => return Ok(self.clone().into_series()), + }; + Ok(out.into_duration(to_unit).into_series()) + }, + dt if dt.is_primitive_numeric() => self.0.cast_with_options(dtype, cast_options), + dt => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dt + ) + }, + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs new file mode 100644 index 000000000000..7894a8c6bf91 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs @@ -0,0 +1,109 @@ +use std::sync::Arc; + +use arrow::array::UInt32Vec; +use arrow::bitmap::MutableBitmap; +use polars_error::{PolarsResult, polars_bail, polars_err}; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; +use polars_utils::pl_str::PlSmallStr; + +use super::{CategoricalChunked, CategoricalOrdering, DataType, Field, RevMapping, UInt32Chunked}; + +pub struct EnumChunkedBuilder { + name: PlSmallStr, + enum_builder: UInt32Vec, + + rev: Arc, + ordering: CategoricalOrdering, + seen: MutableBitmap, + + // Mapping to amortize the costs of lookups. + mapping: PlHashMap, + strict: bool, +} + +impl EnumChunkedBuilder { + pub fn new( + name: PlSmallStr, + capacity: usize, + rev: Arc, + ordering: CategoricalOrdering, + strict: bool, + ) -> Self { + let seen = MutableBitmap::from_len_zeroed(rev.len()); + + Self { + name, + enum_builder: UInt32Vec::with_capacity(capacity), + + rev, + ordering, + seen, + + mapping: PlHashMap::new(), + strict, + } + } + + pub fn append_str(&mut self, v: &str) -> PolarsResult<&mut Self> { + match self.mapping.get(v) { + Some(v) => self.enum_builder.push(Some(*v)), + None => { + let Some(iv) = self.rev.find(v) else { + if self.strict { + polars_bail!(InvalidOperation: "cannot append '{v}' to enum without that variant"); + } else { + self.enum_builder.push(None); + return Ok(self); + } + }; + self.seen.set(iv as usize, true); + self.mapping.insert(v.into(), iv); + self.enum_builder.push(Some(iv)); + }, + } + + Ok(self) + } + + pub fn append_null(&mut self) -> &mut Self { + self.enum_builder.push(None); + self + } + + pub fn append_enum(&mut self, v: u32, rev: &RevMapping) -> PolarsResult<&mut Self> { + if !self.rev.same_src(rev) { + if self.strict { + return Err(polars_err!(ComputeError: "incompatible enum types")); + } else { + self.enum_builder.push(None); + } + } else { + self.seen.set(v as usize, true); + self.enum_builder.push(Some(v)); + } + + Ok(self) + } + + pub fn finish(self) -> CategoricalChunked { + let arr = self.enum_builder.freeze(); + let null_count = arr.validity().map_or(0, |a| a.unset_bits()); + let length = arr.len(); + let ca = unsafe { + UInt32Chunked::new_with_dims( + Arc::new(Field::new(self.name, DataType::UInt32)), + vec![Box::new(arr)], + length, + null_count, + ) + }; + // Fast Unique <=> unique(rev) == unique(ca) + let fast_unique = !ca.has_nulls() && self.seen.unset_bits() == 0; + + // SAFETY: keys and values are in bounds + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering) + .with_fast_unique(fast_unique) + } + } +} diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs new file mode 100644 index 000000000000..e450f1508a81 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -0,0 +1,110 @@ +#[cfg(feature = "dtype-date")] +mod date; +#[cfg(feature = "dtype-date")] +pub use date::*; +#[cfg(feature = "dtype-datetime")] +mod datetime; +#[cfg(feature = "dtype-datetime")] +pub use datetime::*; +#[cfg(feature = "dtype-decimal")] +mod decimal; +#[cfg(feature = "dtype-decimal")] +pub use decimal::*; +#[cfg(feature = "dtype-duration")] +mod duration; +#[cfg(feature = "dtype-duration")] +pub use duration::*; +#[cfg(feature = "dtype-categorical")] +pub mod categorical; +#[cfg(feature = "dtype-categorical")] +pub mod enum_; +#[cfg(feature = "dtype-time")] +mod time; + +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; + +#[cfg(feature = "dtype-categorical")] +pub use categorical::*; +#[cfg(feature = "dtype-time")] +pub use time::*; + +use crate::chunked_array::cast::CastOptions; +use crate::prelude::*; + +/// Maps a logical type to a chunked array implementation of the physical type. +/// This saves a lot of compiler bloat and allows us to reuse functionality. +pub struct Logical( + pub ChunkedArray, + PhantomData, + pub Option, +); + +impl Default for Logical { + fn default() -> Self { + Self(Default::default(), Default::default(), Default::default()) + } +} + +impl Clone for Logical { + fn clone(&self) -> Self { + let mut new = Logical::::new_logical(self.0.clone()); + new.2.clone_from(&self.2); + new + } +} + +impl Deref for Logical { + type Target = ChunkedArray; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Logical { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Logical { + pub fn new_logical(ca: ChunkedArray) -> Logical { + Logical(ca, PhantomData, None) + } +} + +pub trait LogicalType { + /// Get data type of [`ChunkedArray`]. + fn dtype(&self) -> &DataType; + + /// Gets [`AnyValue`] from [`LogicalType`] + fn get_any_value(&self, _i: usize) -> PolarsResult> { + unimplemented!() + } + + /// # Safety + /// Does not do any bound checks. + unsafe fn get_any_value_unchecked(&self, _i: usize) -> AnyValue<'_> { + unimplemented!() + } + + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult; + + fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } +} + +impl Logical +where + Self: LogicalType, +{ + pub fn physical(&self) -> &ChunkedArray { + &self.0 + } + pub fn field(&self) -> Field { + let name = self.0.ref_field().name(); + Field::new(name.clone(), LogicalType::dtype(self).clone()) + } +} diff --git a/crates/polars-core/src/chunked_array/logical/time.rs b/crates/polars-core/src/chunked_array/logical/time.rs new file mode 100644 index 000000000000..ce9b890282d8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/logical/time.rs @@ -0,0 +1,105 @@ +use polars_compute::cast::CastOptionsImpl; + +use super::*; +use crate::prelude::*; + +pub type TimeChunked = Logical; + +impl From for TimeChunked { + fn from(ca: Int64Chunked) -> Self { + TimeChunked::new_logical(ca) + } +} + +impl Int64Chunked { + pub fn into_time(mut self) -> TimeChunked { + let mut null_count = 0; + + // Invalid time values are replaced with `null` during the arrow cast. We utilize the + // validity coming from there to create the new TimeChunked. + let chunks = std::mem::take(&mut self.chunks) + .into_iter() + .map(|chunk| { + // We need to retain the PhysicalType underneath, but we should properly update the + // validity as that might change because Time is not valid for all values of Int64. + let casted = polars_compute::cast::cast( + chunk.as_ref(), + &ArrowDataType::Time64(ArrowTimeUnit::Nanosecond), + CastOptionsImpl::default(), + ) + .unwrap(); + let validity = casted.validity(); + + match validity { + None => chunk, + Some(validity) => { + null_count += validity.unset_bits(); + chunk.with_validity(Some(validity.clone())) + }, + } + }) + .collect::>>(); + + debug_assert!(null_count >= self.null_count); + + // @TODO: We throw away metadata here. That is mostly not needed. + // SAFETY: We calculated the null_count again. And we are taking the rest from the previous + // Int64Chunked. + let int64chunked = + unsafe { Self::new_with_dims(self.field.clone(), chunks, self.length, null_count) }; + + TimeChunked::new_logical(int64chunked) + } +} + +impl LogicalType for TimeChunked { + fn dtype(&self) -> &'static DataType { + &DataType::Time + } + + #[cfg(feature = "dtype-time")] + fn get_any_value(&self, i: usize) -> PolarsResult> { + self.0.get_any_value(i).map(|av| av.as_time()) + } + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + self.0.get_any_value_unchecked(i).as_time() + } + + fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + use DataType::*; + match dtype { + Time => Ok(self.clone().into_series()), + #[cfg(feature = "dtype-duration")] + Duration(tu) => { + let out = self + .0 + .cast_with_options(&DataType::Duration(TimeUnit::Nanoseconds), cast_options); + if !matches!(tu, TimeUnit::Nanoseconds) { + out?.cast_with_options(dtype, cast_options) + } else { + out + } + }, + #[cfg(feature = "dtype-datetime")] + Datetime(_, _) => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported; consider using `dt.combine`", + self.dtype(), dtype + ) + }, + dt if dt.is_primitive_numeric() => self.0.cast_with_options(dtype, cast_options), + _ => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dtype + ) + }, + } + } +} diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs new file mode 100644 index 000000000000..45d4bbebb00f --- /dev/null +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -0,0 +1,1048 @@ +//! The typed heart of every Series column. +#![allow(unsafe_op_in_unsafe_fn)] +use std::iter::Map; +use std::sync::Arc; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::compute::concatenate::concatenate_unchecked; +use polars_compute::filter::filter_with_bitmap; + +use crate::prelude::*; + +pub mod ops; +#[macro_use] +pub mod arithmetic; +pub mod builder; +pub mod cast; +pub mod collect; +pub mod comparison; +pub mod flags; +pub mod float; +pub mod iterator; +#[cfg(feature = "ndarray")] +pub(crate) mod ndarray; + +#[cfg(feature = "dtype-array")] +pub(crate) mod array; +mod binary; +mod bitwise; +#[cfg(feature = "object")] +mod drop; +mod from; +mod from_iterator; +pub mod from_iterator_par; +pub(crate) mod list; +pub(crate) mod logical; +#[cfg(feature = "object")] +pub mod object; +#[cfg(feature = "random")] +mod random; +#[cfg(feature = "dtype-struct")] +mod struct_; +#[cfg(any( + feature = "temporal", + feature = "dtype-datetime", + feature = "dtype-date" +))] +pub mod temporal; +mod to_vec; +mod trusted_len; + +use std::slice::Iter; + +use arrow::legacy::prelude::*; +#[cfg(feature = "dtype-struct")] +pub use struct_::StructChunked; + +use self::flags::{StatisticsFlags, StatisticsFlagsIM}; +use crate::series::IsSorted; +use crate::utils::{first_non_null, last_non_null}; + +#[cfg(not(feature = "dtype-categorical"))] +pub struct RevMapping {} + +pub type ChunkLenIter<'a> = std::iter::Map, fn(&ArrayRef) -> usize>; + +/// # ChunkedArray +/// +/// Every Series contains a [`ChunkedArray`]. Unlike [`Series`], [`ChunkedArray`]s are typed. This allows +/// us to apply closures to the data and collect the results to a [`ChunkedArray`] of the same type `T`. +/// Below we use an apply to use the cosine function to the values of a [`ChunkedArray`]. +/// +/// ```rust +/// # use polars_core::prelude::*; +/// fn apply_cosine_and_cast(ca: &Float32Chunked) -> Float32Chunked { +/// ca.apply_values(|v| v.cos()) +/// } +/// ``` +/// +/// ## Conversion between Series and ChunkedArrays +/// Conversion from a [`Series`] to a [`ChunkedArray`] is effortless. +/// +/// ```rust +/// # use polars_core::prelude::*; +/// fn to_chunked_array(series: &Series) -> PolarsResult<&Int32Chunked>{ +/// series.i32() +/// } +/// +/// fn to_series(ca: Int32Chunked) -> Series { +/// ca.into_series() +/// } +/// ``` +/// +/// # Iterators +/// +/// [`ChunkedArray`]s fully support Rust native [Iterator](https://doc.rust-lang.org/std/iter/trait.Iterator.html) +/// and [DoubleEndedIterator](https://doc.rust-lang.org/std/iter/trait.DoubleEndedIterator.html) traits, thereby +/// giving access to all the excellent methods available for [Iterators](https://doc.rust-lang.org/std/iter/trait.Iterator.html). +/// +/// ```rust +/// # use polars_core::prelude::*; +/// +/// fn iter_forward(ca: &Float32Chunked) { +/// ca.iter() +/// .for_each(|opt_v| println!("{:?}", opt_v)) +/// } +/// +/// fn iter_backward(ca: &Float32Chunked) { +/// ca.iter() +/// .rev() +/// .for_each(|opt_v| println!("{:?}", opt_v)) +/// } +/// ``` +/// +/// # Memory layout +/// +/// [`ChunkedArray`]s use [Apache Arrow](https://github.com/apache/arrow) as backend for the memory layout. +/// Arrows memory is immutable which makes it possible to make multiple zero copy (sub)-views from a single array. +/// +/// To be able to append data, Polars uses chunks to append new memory locations, hence the [`ChunkedArray`] data structure. +/// Appends are cheap, because it will not lead to a full reallocation of the whole array (as could be the case with a Rust Vec). +/// +/// However, multiple chunks in a [`ChunkedArray`] will slow down many operations that need random access because we have an extra indirection +/// and indexes need to be mapped to the proper chunk. Arithmetic may also be slowed down by this. +/// When multiplying two [`ChunkedArray`]s with different chunk sizes they cannot utilize [SIMD](https://en.wikipedia.org/wiki/SIMD) for instance. +/// +/// If you want to have predictable performance +/// (no unexpected re-allocation of memory), it is advised to call the [`ChunkedArray::rechunk`] after +/// multiple append operations. +/// +/// See also [`ChunkedArray::extend`] for appends within a chunk. +/// +/// # Invariants +/// - A [`ChunkedArray`] should always have at least a single [`ArrayRef`]. +/// - The [`PolarsDataType`] `T` should always map to the correct [`ArrowDataType`] in the [`ArrayRef`] +/// chunks. +/// - Nested datatypes such as [`List`] and [`Array`] store the physical types instead of the +/// logical type given by the datatype. +/// +/// [`List`]: crate::datatypes::DataType::List +pub struct ChunkedArray { + pub(crate) field: Arc, + pub(crate) chunks: Vec, + + pub(crate) flags: StatisticsFlagsIM, + + length: usize, + null_count: usize, + _pd: std::marker::PhantomData, +} + +impl ChunkedArray { + fn should_rechunk(&self) -> bool { + self.chunks.len() > 1 && self.chunks.len() > self.len() / 3 + } + + fn optional_rechunk(mut self) -> Self { + // Rechunk if we have many small chunks. + if self.should_rechunk() { + self.rechunk_mut() + } + self + } + + pub(crate) fn as_any(&self) -> &dyn std::any::Any { + self + } + + /// Series to [`ChunkedArray`] + pub fn unpack_series_matching_type<'a>( + &self, + series: &'a Series, + ) -> PolarsResult<&'a ChunkedArray> { + match self.dtype() { + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + let logical = series.decimal()?; + + let ca = logical.physical(); + Ok(ca.as_any().downcast_ref::>().unwrap()) + }, + dt => { + polars_ensure!( + dt == series.dtype(), + SchemaMismatch: "cannot unpack series of type `{}` into `{}`", + series.dtype(), + dt, + ); + + // SAFETY: + // dtype will be correct. + Ok(unsafe { self.unpack_series_matching_physical_type(series) }) + }, + } + } + + /// Create a new [`ChunkedArray`] and compute its `length` and `null_count`. + /// + /// If you want to explicitly the `length` and `null_count`, look at + /// [`ChunkedArray::new_with_dims`] + fn new_with_compute_len(field: Arc, chunks: Vec) -> Self { + unsafe { + let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0); + chunked_arr.compute_len(); + chunked_arr + } + } + + /// Create a new [`ChunkedArray`] and explicitly set its `length` and `null_count`. + /// # Safety + /// The length and null_count must be correct. + pub unsafe fn new_with_dims( + field: Arc, + chunks: Vec, + length: usize, + null_count: usize, + ) -> Self { + Self { + field, + chunks, + flags: StatisticsFlagsIM::empty(), + + _pd: Default::default(), + length, + null_count, + } + } + + pub(crate) fn is_sorted_ascending_flag(&self) -> bool { + self.get_flags().is_sorted_ascending() + } + + pub(crate) fn is_sorted_descending_flag(&self) -> bool { + self.get_flags().is_sorted_descending() + } + + /// Whether `self` is sorted in any direction. + pub(crate) fn is_sorted_any(&self) -> bool { + self.get_flags().is_sorted_any() + } + + pub fn unset_fast_explode_list(&mut self) { + self.set_fast_explode_list(false) + } + + pub fn set_fast_explode_list(&mut self, value: bool) { + let mut flags = self.flags.get_mut(); + flags.set(StatisticsFlags::CAN_FAST_EXPLODE_LIST, value); + self.flags.set_mut(flags); + } + + pub fn get_fast_explode_list(&self) -> bool { + self.get_flags().can_fast_explode_list() + } + + pub fn get_flags(&self) -> StatisticsFlags { + self.flags.get() + } + + /// Set flags for the [`ChunkedArray`] + pub(crate) fn set_flags(&mut self, flags: StatisticsFlags) { + self.flags = StatisticsFlagsIM::new(flags); + } + + pub fn is_sorted_flag(&self) -> IsSorted { + self.get_flags().is_sorted() + } + + pub fn retain_flags_from( + &mut self, + from: &ChunkedArray, + retain_flags: StatisticsFlags, + ) { + let flags = from.flags.get(); + // Try to avoid write contention. + if !flags.is_empty() { + self.set_flags(flags & retain_flags) + } + } + + /// Set the 'sorted' bit meta info. + pub fn set_sorted_flag(&mut self, sorted: IsSorted) { + let mut flags = self.flags.get_mut(); + flags.set_sorted(sorted); + self.flags.set_mut(flags); + } + + /// Set the 'sorted' bit meta info. + pub fn with_sorted_flag(&self, sorted: IsSorted) -> Self { + let mut out = self.clone(); + out.set_sorted_flag(sorted); + out + } + + /// Get the index of the first non null value in this [`ChunkedArray`]. + pub fn first_non_null(&self) -> Option { + if self.null_count() == self.len() { + None + } + // We now know there is at least 1 non-null item in the array, and self.len() > 0 + else if self.null_count() == 0 { + Some(0) + } else if self.is_sorted_any() { + let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } { + // nulls are all at the start + self.null_count() + } else { + // nulls are all at the end + 0 + }; + + debug_assert!( + // If we are lucky this catches something. + unsafe { self.get_unchecked(out) }.is_some(), + "incorrect sorted flag" + ); + + Some(out) + } else { + first_non_null(self.iter_validities()) + } + } + + /// Get the index of the last non null value in this [`ChunkedArray`]. + pub fn last_non_null(&self) -> Option { + if self.null_count() == self.len() { + None + } + // We now know there is at least 1 non-null item in the array, and self.len() > 0 + else if self.null_count() == 0 { + Some(self.len() - 1) + } else if self.is_sorted_any() { + let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } { + // nulls are all at the start + self.len() - 1 + } else { + // nulls are all at the end + self.len() - self.null_count() - 1 + }; + + debug_assert!( + // If we are lucky this catches something. + unsafe { self.get_unchecked(out) }.is_some(), + "incorrect sorted flag" + ); + + Some(out) + } else { + last_non_null(self.iter_validities(), self.len()) + } + } + + pub fn drop_nulls(&self) -> Self { + if self.null_count() == 0 { + self.clone() + } else { + let chunks = self + .downcast_iter() + .map(|arr| { + if arr.null_count() == 0 { + arr.to_boxed() + } else { + filter_with_bitmap(arr, arr.validity().unwrap()) + } + }) + .collect(); + unsafe { + Self::new_with_dims( + self.field.clone(), + chunks, + self.len() - self.null_count(), + 0, + ) + } + } + } + + /// Get the buffer of bits representing null values + #[inline] + #[allow(clippy::type_complexity)] + pub fn iter_validities(&self) -> Map, fn(&ArrayRef) -> Option<&Bitmap>> { + fn to_validity(arr: &ArrayRef) -> Option<&Bitmap> { + arr.validity() + } + self.chunks.iter().map(to_validity) + } + + #[inline] + /// Return if any the chunks in this [`ChunkedArray`] have nulls. + pub fn has_nulls(&self) -> bool { + self.null_count > 0 + } + + /// Shrink the capacity of this array to fit its length. + pub fn shrink_to_fit(&mut self) { + self.chunks = vec![concatenate_unchecked(self.chunks.as_slice()).unwrap()]; + } + + pub fn clear(&self) -> Self { + // SAFETY: we keep the correct dtype + let mut ca = unsafe { + self.copy_with_chunks(vec![new_empty_array( + self.chunks.first().unwrap().dtype().clone(), + )]) + }; + + use StatisticsFlags as F; + ca.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + ca + } + + /// Unpack a [`Series`] to the same physical type. + /// + /// # Safety + /// + /// This is unsafe as the dtype may be incorrect and + /// is assumed to be correct in other safe code. + pub(crate) unsafe fn unpack_series_matching_physical_type<'a>( + &self, + series: &'a Series, + ) -> &'a ChunkedArray { + let series_trait = &**series; + if self.dtype() == series.dtype() { + &*(series_trait as *const dyn SeriesTrait as *const ChunkedArray) + } else { + use DataType::*; + match (self.dtype(), series.dtype()) { + (Int64, Datetime(_, _)) | (Int64, Duration(_)) | (Int32, Date) => { + &*(series_trait as *const dyn SeriesTrait as *const ChunkedArray) + }, + _ => panic!( + "cannot unpack series {:?} into matching type {:?}", + series, + self.dtype() + ), + } + } + } + + /// Returns an iterator over the lengths of the chunks of the array. + pub fn chunk_lengths(&self) -> ChunkLenIter { + self.chunks.iter().map(|chunk| chunk.len()) + } + + /// A reference to the chunks + #[inline] + pub fn chunks(&self) -> &Vec { + &self.chunks + } + + /// A mutable reference to the chunks + /// + /// # Safety + /// The caller must ensure to not change the [`DataType`] or `length` of any of the chunks. + /// And the `null_count` remains correct. + #[inline] + pub unsafe fn chunks_mut(&mut self) -> &mut Vec { + &mut self.chunks + } + + /// Returns true if contains a single chunk and has no null values + pub fn is_optimal_aligned(&self) -> bool { + self.chunks.len() == 1 && self.null_count() == 0 + } + + /// Create a new [`ChunkedArray`] from self, where the chunks are replaced. + /// + /// # Safety + /// The caller must ensure the dtypes of the chunks are correct + unsafe fn copy_with_chunks(&self, chunks: Vec) -> Self { + Self::new_with_compute_len(self.field.clone(), chunks) + } + + /// Get data type of [`ChunkedArray`]. + pub fn dtype(&self) -> &DataType { + self.field.dtype() + } + + pub(crate) unsafe fn set_dtype(&mut self, dtype: DataType) { + self.field = Arc::new(Field::new(self.name().clone(), dtype)) + } + + /// Name of the [`ChunkedArray`]. + pub fn name(&self) -> &PlSmallStr { + self.field.name() + } + + /// Get a reference to the field. + pub fn ref_field(&self) -> &Field { + &self.field + } + + /// Rename this [`ChunkedArray`]. + pub fn rename(&mut self, name: PlSmallStr) { + self.field = Arc::new(Field::new(name, self.field.dtype().clone())); + } + + /// Return this [`ChunkedArray`] with a new name. + pub fn with_name(mut self, name: PlSmallStr) -> Self { + self.rename(name); + self + } +} + +impl ChunkedArray +where + T: PolarsDataType, +{ + /// Get a single value from this [`ChunkedArray`]. If the return values is `None` this + /// indicates a NULL value. + /// + /// # Panics + /// This function will panic if `idx` is out of bounds. + #[inline] + pub fn get(&self, idx: usize) -> Option> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + assert!( + chunk_idx < self.chunks().len(), + "index: {} out of bounds for len: {}", + idx, + self.len() + ); + unsafe { + let arr = self.downcast_get_unchecked(chunk_idx); + assert!( + arr_idx < arr.len(), + "index: {} out of bounds for len: {}", + idx, + self.len() + ); + arr.get_unchecked(arr_idx) + } + } + + /// Get a single value from this [`ChunkedArray`]. If the return values is `None` this + /// indicates a NULL value. + /// + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + pub unsafe fn get_unchecked(&self, idx: usize) -> Option> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + + unsafe { + // SAFETY: up to the caller to make sure the index is valid. + self.downcast_get_unchecked(chunk_idx) + .get_unchecked(arr_idx) + } + } + + /// Get a single value from this [`ChunkedArray`]. Null values are ignored and the returned + /// value could be garbage if it was masked out by NULL. Note that the value always is initialized. + /// + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + pub unsafe fn value_unchecked(&self, idx: usize) -> T::Physical<'_> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + + unsafe { + // SAFETY: up to the caller to make sure the index is valid. + self.downcast_get_unchecked(chunk_idx) + .value_unchecked(arr_idx) + } + } + + #[inline] + pub fn first(&self) -> Option> { + unsafe { + let arr = self.downcast_get_unchecked(0); + arr.get_unchecked(0) + } + } + + #[inline] + pub fn last(&self) -> Option> { + unsafe { + let arr = self.downcast_get_unchecked(self.chunks.len().checked_sub(1)?); + arr.get_unchecked(arr.len().checked_sub(1)?) + } + } +} + +impl ListChunked { + #[inline] + pub fn get_as_series(&self, idx: usize) -> Option { + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![self.get(idx)?], + &self.inner_dtype().to_physical(), + )) + } + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayChunked { + #[inline] + pub fn get_as_series(&self, idx: usize) -> Option { + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![self.get(idx)?], + &self.inner_dtype().to_physical(), + )) + } + } +} + +impl ChunkedArray +where + T: PolarsDataType, +{ + /// Should be used to match the chunk_id of another [`ChunkedArray`]. + /// # Panics + /// It is the callers responsibility to ensure that this [`ChunkedArray`] has a single chunk. + pub(crate) fn match_chunks(&self, chunk_id: I) -> Self + where + I: Iterator, + { + debug_assert!(self.chunks.len() == 1); + // Takes a ChunkedArray containing a single chunk. + let slice = |ca: &Self| { + let array = &ca.chunks[0]; + + let mut offset = 0; + let chunks = chunk_id + .map(|len| { + // SAFETY: within bounds. + debug_assert!((offset + len) <= array.len()); + let out = unsafe { array.sliced_unchecked(offset, len) }; + offset += len; + out + }) + .collect(); + + debug_assert_eq!(offset, array.len()); + + // SAFETY: We just slice the original chunks, their type will not change. + unsafe { + Self::from_chunks_and_dtype(self.name().clone(), chunks, self.dtype().clone()) + } + }; + + if self.chunks.len() != 1 { + let out = self.rechunk(); + slice(&out) + } else { + slice(self) + } + } +} + +impl AsRefDataType for ChunkedArray { + fn as_ref_dtype(&self) -> &DataType { + self.dtype() + } +} + +pub(crate) trait AsSinglePtr: AsRefDataType { + /// Rechunk and return a ptr to the start of the array + fn as_single_ptr(&mut self) -> PolarsResult { + polars_bail!(opq = as_single_ptr, self.as_ref_dtype()); + } +} + +impl AsSinglePtr for ChunkedArray +where + T: PolarsNumericType, +{ + fn as_single_ptr(&mut self) -> PolarsResult { + self.rechunk_mut(); + let a = self.data_views().next().unwrap(); + let ptr = a.as_ptr(); + Ok(ptr as usize) + } +} + +impl AsSinglePtr for BooleanChunked {} +impl AsSinglePtr for ListChunked {} +#[cfg(feature = "dtype-array")] +impl AsSinglePtr for ArrayChunked {} +impl AsSinglePtr for StringChunked {} +impl AsSinglePtr for BinaryChunked {} +#[cfg(feature = "object")] +impl AsSinglePtr for ObjectChunked {} + +pub enum ChunkedArrayLayout<'a, T: PolarsDataType> { + SingleNoNull(&'a T::Array), + Single(&'a T::Array), + MultiNoNull(&'a ChunkedArray), + Multi(&'a ChunkedArray), +} + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn layout(&self) -> ChunkedArrayLayout<'_, T> { + if self.chunks.len() == 1 { + let arr = self.downcast_iter().next().unwrap(); + return if arr.null_count() == 0 { + ChunkedArrayLayout::SingleNoNull(arr) + } else { + ChunkedArrayLayout::Single(arr) + }; + } + + if self.downcast_iter().all(|a| a.null_count() == 0) { + ChunkedArrayLayout::MultiNoNull(self) + } else { + ChunkedArrayLayout::Multi(self) + } + } +} + +impl ChunkedArray +where + T: PolarsNumericType, +{ + /// Returns the values of the array as a contiguous slice. + pub fn cont_slice(&self) -> PolarsResult<&[T::Native]> { + polars_ensure!( + self.chunks.len() == 1 && self.chunks[0].null_count() == 0, + ComputeError: "chunked array is not contiguous" + ); + Ok(self.downcast_iter().next().map(|arr| arr.values()).unwrap()) + } + + /// Returns the values of the array as a contiguous mutable slice. + pub(crate) fn cont_slice_mut(&mut self) -> Option<&mut [T::Native]> { + if self.chunks.len() == 1 && self.chunks[0].null_count() == 0 { + // SAFETY, we will not swap the PrimitiveArray. + let arr = unsafe { self.downcast_iter_mut().next().unwrap() }; + arr.get_mut_values() + } else { + None + } + } + + /// Get slices of the underlying arrow data. + /// NOTE: null values should be taken into account by the user of these slices as they are handled + /// separately + pub fn data_views(&self) -> impl DoubleEndedIterator { + self.downcast_iter().map(|arr| arr.values().as_slice()) + } + + #[allow(clippy::wrong_self_convention)] + pub fn into_no_null_iter( + &self, + ) -> impl '_ + Send + Sync + ExactSizeIterator + DoubleEndedIterator + TrustedLen + { + // .copied was significantly slower in benchmark, next call did not inline? + #[allow(clippy::map_clone)] + // we know the iterators len + unsafe { + self.data_views() + .flatten() + .map(|v| *v) + .trust_my_length(self.len()) + } + } +} + +impl Clone for ChunkedArray { + fn clone(&self) -> Self { + ChunkedArray { + field: self.field.clone(), + chunks: self.chunks.clone(), + flags: self.flags.clone(), + + _pd: Default::default(), + length: self.length, + null_count: self.null_count, + } + } +} + +impl AsRef> for ChunkedArray { + fn as_ref(&self) -> &ChunkedArray { + self + } +} + +impl ValueSize for ListChunked { + fn get_values_size(&self) -> usize { + self.chunks + .iter() + .fold(0usize, |acc, arr| acc + arr.get_values_size()) + } +} + +#[cfg(feature = "dtype-array")] +impl ValueSize for ArrayChunked { + fn get_values_size(&self) -> usize { + self.chunks + .iter() + .fold(0usize, |acc, arr| acc + arr.get_values_size()) + } +} +impl ValueSize for StringChunked { + fn get_values_size(&self) -> usize { + self.chunks + .iter() + .fold(0usize, |acc, arr| acc + arr.get_values_size()) + } +} + +impl ValueSize for BinaryOffsetChunked { + fn get_values_size(&self) -> usize { + self.chunks + .iter() + .fold(0usize, |acc, arr| acc + arr.get_values_size()) + } +} + +pub(crate) fn to_primitive( + values: Vec, + validity: Option, +) -> PrimitiveArray { + PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + values.into(), + validity, + ) +} + +pub(crate) fn to_array( + values: Vec, + validity: Option, +) -> ArrayRef { + Box::new(to_primitive::(values, validity)) +} + +impl Default for ChunkedArray { + fn default() -> Self { + let dtype = T::get_dtype(); + let arrow_dtype = dtype.to_physical().to_arrow(CompatLevel::newest()); + ChunkedArray { + field: Arc::new(Field::new(PlSmallStr::EMPTY, dtype)), + // Invariant: always has 1 chunk. + chunks: vec![new_empty_array(arrow_dtype)], + flags: StatisticsFlagsIM::empty(), + + _pd: Default::default(), + length: 0, + null_count: 0, + } + } +} + +#[cfg(test)] +pub(crate) mod test { + use crate::prelude::*; + + pub(crate) fn get_chunked_array() -> Int32Chunked { + ChunkedArray::new(PlSmallStr::from_static("a"), &[1, 2, 3]) + } + + #[test] + fn test_sort() { + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 9, 3, 2]); + let b = a + .sort(false) + .into_iter() + .map(|opt| opt.unwrap()) + .collect::>(); + assert_eq!(b, [1, 2, 3, 9]); + let a = StringChunked::new(PlSmallStr::from_static("a"), &["b", "a", "c"]); + let a = a.sort(false); + let b = a.into_iter().collect::>(); + assert_eq!(b, [Some("a"), Some("b"), Some("c")]); + assert!(a.is_sorted_ascending_flag()); + } + + #[test] + fn arithmetic() { + let a = &Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 100, 6, 40]); + let b = &Int32Chunked::new(PlSmallStr::from_static("b"), &[-1, 2, 3, 4]); + + // Not really asserting anything here but still making sure the code is exercised + // This (and more) is properly tested from the integration test suite and Python bindings. + println!("{:?}", a + b); + println!("{:?}", a - b); + println!("{:?}", a * b); + println!("{:?}", a / b); + } + + #[test] + fn iter() { + let s1 = get_chunked_array(); + // sum + assert_eq!(s1.into_iter().fold(0, |acc, val| { acc + val.unwrap() }), 6) + } + + #[test] + fn limit() { + let a = get_chunked_array(); + let b = a.limit(2); + println!("{:?}", b); + assert_eq!(b.len(), 2) + } + + #[test] + fn filter() { + let a = get_chunked_array(); + let b = a + .filter(&BooleanChunked::new( + PlSmallStr::from_static("filter"), + &[true, false, false], + )) + .unwrap(); + assert_eq!(b.len(), 1); + assert_eq!(b.into_iter().next(), Some(Some(1))); + } + + #[test] + fn aggregates() { + let a = &Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 100, 10, 9]); + assert_eq!(a.max(), Some(100)); + assert_eq!(a.min(), Some(1)); + assert_eq!(a.sum(), Some(120)) + } + + #[test] + fn take() { + let a = get_chunked_array(); + let new = a.take(&[0 as IdxSize, 1]).unwrap(); + assert_eq!(new.len(), 2) + } + + #[test] + fn cast() { + let a = get_chunked_array(); + let b = a.cast(&DataType::Int64).unwrap(); + assert_eq!(b.dtype(), &DataType::Int64) + } + + fn assert_slice_equal(ca: &ChunkedArray, eq: &[T::Native]) + where + T: PolarsNumericType, + { + assert_eq!(ca.iter().map(|opt| opt.unwrap()).collect::>(), eq) + } + + #[test] + fn slice() { + let mut first = UInt32Chunked::new(PlSmallStr::from_static("first"), &[0, 1, 2]); + let second = UInt32Chunked::new(PlSmallStr::from_static("second"), &[3, 4, 5]); + first.append(&second).unwrap(); + assert_slice_equal(&first.slice(0, 3), &[0, 1, 2]); + assert_slice_equal(&first.slice(0, 4), &[0, 1, 2, 3]); + assert_slice_equal(&first.slice(1, 4), &[1, 2, 3, 4]); + assert_slice_equal(&first.slice(3, 2), &[3, 4]); + assert_slice_equal(&first.slice(3, 3), &[3, 4, 5]); + assert_slice_equal(&first.slice(-3, 3), &[3, 4, 5]); + assert_slice_equal(&first.slice(-6, 6), &[0, 1, 2, 3, 4, 5]); + + assert_eq!(first.slice(-7, 2).len(), 1); + assert_eq!(first.slice(-3, 4).len(), 3); + assert_eq!(first.slice(3, 4).len(), 3); + assert_eq!(first.slice(10, 4).len(), 0); + } + + #[test] + fn sorting() { + let s = UInt32Chunked::new(PlSmallStr::EMPTY, &[9, 2, 4]); + let sorted = s.sort(false); + assert_slice_equal(&sorted, &[2, 4, 9]); + let sorted = s.sort(true); + assert_slice_equal(&sorted, &[9, 4, 2]); + + let s: StringChunked = ["b", "a", "z"].iter().collect(); + let sorted = s.sort(false); + assert_eq!( + sorted.into_iter().collect::>(), + &[Some("a"), Some("b"), Some("z")] + ); + let sorted = s.sort(true); + assert_eq!( + sorted.into_iter().collect::>(), + &[Some("z"), Some("b"), Some("a")] + ); + let s: StringChunked = [Some("b"), None, Some("z")].iter().copied().collect(); + let sorted = s.sort(false); + assert_eq!( + sorted.into_iter().collect::>(), + &[None, Some("b"), Some("z")] + ); + } + + #[test] + fn reverse() { + let s = UInt32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3]); + // path with continuous slice + assert_slice_equal(&s.reverse(), &[3, 2, 1]); + // path with options + let s = UInt32Chunked::new(PlSmallStr::EMPTY, &[Some(1), None, Some(3)]); + assert_eq!(Vec::from(&s.reverse()), &[Some(3), None, Some(1)]); + let s = BooleanChunked::new(PlSmallStr::EMPTY, &[true, false]); + assert_eq!(Vec::from(&s.reverse()), &[Some(false), Some(true)]); + + let s = StringChunked::new(PlSmallStr::EMPTY, &["a", "b", "c"]); + assert_eq!(Vec::from(&s.reverse()), &[Some("c"), Some("b"), Some("a")]); + + let s = StringChunked::new(PlSmallStr::EMPTY, &[Some("a"), None, Some("c")]); + assert_eq!(Vec::from(&s.reverse()), &[Some("c"), None, Some("a")]); + } + + #[test] + #[cfg(feature = "dtype-categorical")] + fn test_iter_categorical() { + use crate::{SINGLE_LOCK, disable_string_cache}; + let _lock = SINGLE_LOCK.lock(); + disable_string_cache(); + let ca = StringChunked::new( + PlSmallStr::EMPTY, + &[Some("foo"), None, Some("bar"), Some("ham")], + ); + let ca = ca + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + let ca = ca.categorical().unwrap(); + let v: Vec<_> = ca.physical().into_iter().collect(); + assert_eq!(v, &[Some(0), None, Some(1), Some(2)]); + } + + #[test] + #[ignore] + fn test_shrink_to_fit() { + let mut builder = StringChunkedBuilder::new(PlSmallStr::from_static("foo"), 2048); + builder.append_value("foo"); + let mut arr = builder.finish(); + let before = arr + .chunks() + .iter() + .map(|arr| arrow::compute::aggregate::estimated_bytes_size(arr.as_ref())) + .sum::(); + arr.shrink_to_fit(); + let after = arr + .chunks() + .iter() + .map(|arr| arrow::compute::aggregate::estimated_bytes_size(arr.as_ref())) + .sum::(); + assert!(before > after); + } +} diff --git a/crates/polars-core/src/chunked_array/ndarray.rs b/crates/polars-core/src/chunked_array/ndarray.rs new file mode 100644 index 000000000000..aa632aedcd0e --- /dev/null +++ b/crates/polars-core/src/chunked_array/ndarray.rs @@ -0,0 +1,250 @@ +use ndarray::prelude::*; +use rayon::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::POOL; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum IndexOrder { + C, + #[default] + Fortran, +} + +impl ChunkedArray +where + T: PolarsNumericType, +{ + /// If data is aligned in a single chunk and has no Null values a zero copy view is returned + /// as an [ndarray] + pub fn to_ndarray(&self) -> PolarsResult> { + let slice = self.cont_slice()?; + Ok(aview1(slice)) + } +} + +impl ListChunked { + /// If all nested [`Series`] have the same length, a 2 dimensional [`ndarray::Array`] is returned. + pub fn to_ndarray(&self) -> PolarsResult> + where + N: PolarsNumericType, + { + polars_ensure!( + self.null_count() == 0, + ComputeError: "creation of ndarray with null values is not supported" + ); + + // first iteration determine the size + let mut iter = self.into_no_null_iter(); + let series = iter + .next() + .ok_or_else(|| polars_err!(NoData: "unable to create ndarray of empty ListChunked"))?; + + let width = series.len(); + let mut row_idx = 0; + let mut ndarray = ndarray::Array::uninit((self.len(), width)); + + let series = series.cast(&N::get_dtype())?; + let ca = series.unpack::()?; + let a = ca.to_ndarray()?; + let mut row = ndarray.slice_mut(s![row_idx, ..]); + a.assign_to(&mut row); + row_idx += 1; + + for series in iter { + polars_ensure!( + series.len() == width, + ShapeMismatch: "unable to create a 2-D array, series have different lengths" + ); + let series = series.cast(&N::get_dtype())?; + let ca = series.unpack::()?; + let a = ca.to_ndarray()?; + let mut row = ndarray.slice_mut(s![row_idx, ..]); + a.assign_to(&mut row); + row_idx += 1; + } + + debug_assert_eq!(row_idx, self.len()); + // SAFETY: + // We have assigned to every row and element of the array + unsafe { Ok(ndarray.assume_init()) } + } +} + +impl DataFrame { + /// Create a 2D [`ndarray::Array`] from this [`DataFrame`]. This requires all columns in the + /// [`DataFrame`] to be non-null and numeric. They will be cast to the same data type + /// (if they aren't already). + /// + /// For floating point data we implicitly convert `None` to `NaN` without failure. + /// + /// ```rust + /// use polars_core::prelude::*; + /// let a = UInt32Chunked::new("a".into(), &[1, 2, 3]).into_column(); + /// let b = Float64Chunked::new("b".into(), &[10., 8., 6.]).into_column(); + /// + /// let df = DataFrame::new(vec![a, b]).unwrap(); + /// let ndarray = df.to_ndarray::(IndexOrder::Fortran).unwrap(); + /// println!("{:?}", ndarray); + /// ``` + /// Outputs: + /// ```text + /// [[1.0, 10.0], + /// [2.0, 8.0], + /// [3.0, 6.0]], shape=[3, 2], strides=[1, 3], layout=Ff (0xa), const ndim=2 + /// ``` + pub fn to_ndarray(&self, ordering: IndexOrder) -> PolarsResult> + where + N: PolarsNumericType, + { + let shape = self.shape(); + let height = self.height(); + let mut membuf = Vec::with_capacity(shape.0 * shape.1); + let ptr = membuf.as_ptr() as usize; + + let columns = self.get_columns(); + POOL.install(|| { + columns.par_iter().enumerate().try_for_each(|(col_idx, s)| { + let s = s.as_materialized_series().cast(&N::get_dtype())?; + let s = match s.dtype() { + DataType::Float32 => { + let ca = s.f32().unwrap(); + ca.none_to_nan().into_series() + }, + DataType::Float64 => { + let ca = s.f64().unwrap(); + ca.none_to_nan().into_series() + }, + _ => s, + }; + polars_ensure!( + s.null_count() == 0, + ComputeError: "creation of ndarray with null values is not supported" + ); + let ca = s.unpack::()?; + + let mut chunk_offset = 0; + for arr in ca.downcast_iter() { + let vals = arr.values(); + + // Depending on the desired order, we add items to the buffer. + // SAFETY: + // We get parallel access to the vector by offsetting index access accordingly. + // For C-order, we only operate on every num-col-th element, starting from the + // column index. For Fortran-order we only operate on n contiguous elements, + // offset by n * the column index. + match ordering { + IndexOrder::C => unsafe { + let num_cols = columns.len(); + let mut offset = + (ptr as *mut N::Native).add(col_idx + chunk_offset * num_cols); + for v in vals.iter() { + *offset = *v; + offset = offset.add(num_cols); + } + }, + IndexOrder::Fortran => unsafe { + let offset_ptr = + (ptr as *mut N::Native).add(col_idx * height + chunk_offset); + // SAFETY: + // this is uninitialized memory, so we must never read from this data + // copy_from_slice does not read + let buf = std::slice::from_raw_parts_mut(offset_ptr, vals.len()); + buf.copy_from_slice(vals) + }, + } + chunk_offset += vals.len(); + } + + Ok(()) + }) + })?; + + // SAFETY: + // we have written all data, so we can now safely set length + unsafe { + membuf.set_len(shape.0 * shape.1); + } + // Depending on the desired order, we can either return the array buffer as-is or reverse + // the axes. + match ordering { + IndexOrder::C => Ok(Array2::from_shape_vec((shape.0, shape.1), membuf).unwrap()), + IndexOrder::Fortran => { + let ndarr = Array2::from_shape_vec((shape.1, shape.0), membuf).unwrap(); + Ok(ndarr.reversed_axes()) + }, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ndarray_from_ca() -> PolarsResult<()> { + let ca = Float64Chunked::new(PlSmallStr::EMPTY, &[1.0, 2.0, 3.0]); + let ndarr = ca.to_ndarray()?; + assert_eq!(ndarr, ArrayView1::from(&[1.0, 2.0, 3.0])); + + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::EMPTY, + 10, + 10, + DataType::Float64, + ); + builder.append_opt_slice(Some(&[1.0, 2.0, 3.0])); + builder.append_opt_slice(Some(&[2.0, 4.0, 5.0])); + builder.append_opt_slice(Some(&[6.0, 7.0, 8.0])); + let list = builder.finish(); + + let ndarr = list.to_ndarray::()?; + let expected = array![[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [6.0, 7.0, 8.0]]; + assert_eq!(ndarr, expected); + + // test list array that is not square + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::EMPTY, + 10, + 10, + DataType::Float64, + ); + builder.append_opt_slice(Some(&[1.0, 2.0, 3.0])); + builder.append_opt_slice(Some(&[2.0])); + builder.append_opt_slice(Some(&[6.0, 7.0, 8.0])); + let list = builder.finish(); + assert!(list.to_ndarray::().is_err()); + Ok(()) + } + + #[test] + fn test_ndarray_from_df_order_fortran() -> PolarsResult<()> { + let df = df!["a"=> [1.0, 2.0, 3.0], + "b" => [2.0, 3.0, 4.0] + ]?; + + let ndarr = df.to_ndarray::(IndexOrder::Fortran)?; + let expected = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]; + assert!(!ndarr.is_standard_layout()); + assert_eq!(ndarr, expected); + + Ok(()) + } + + #[test] + fn test_ndarray_from_df_order_c() -> PolarsResult<()> { + let df = df!["a"=> [1.0, 2.0, 3.0], + "b" => [2.0, 3.0, 4.0] + ]?; + + let ndarr = df.to_ndarray::(IndexOrder::C)?; + let expected = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]; + assert!(ndarr.is_standard_layout()); + assert_eq!(ndarr, expected); + + Ok(()) + } +} diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs new file mode 100644 index 000000000000..05e522c80150 --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -0,0 +1,273 @@ +use arrow::array::builder::{ArrayBuilder, ShareStrategy}; +use arrow::bitmap::BitmapBuilder; +use polars_utils::vec::PushUnchecked; + +use super::*; +use crate::utils::get_iter_capacity; + +pub struct ObjectChunkedBuilder { + field: Field, + bitmask_builder: BitmapBuilder, + values: Vec, +} + +impl ObjectChunkedBuilder +where + T: PolarsObject, +{ + pub fn field(&self) -> &Field { + &self.field + } + pub fn new(name: PlSmallStr, capacity: usize) -> Self { + ObjectChunkedBuilder { + field: Field::new(name, DataType::Object(T::type_name())), + values: Vec::with_capacity(capacity), + bitmask_builder: BitmapBuilder::with_capacity(capacity), + } + } + + /// Appends a value of type `T` into the builder + #[inline] + pub fn append_value(&mut self, v: T) { + self.values.push(v); + self.bitmask_builder.push(true); + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.values.push(T::default()); + self.bitmask_builder.push(false); + } + + #[inline] + pub fn append_value_from_any(&mut self, v: &dyn Any) -> PolarsResult<()> { + let Some(v) = v.downcast_ref::() else { + polars_bail!(SchemaMismatch: "cannot downcast any in ObjectBuilder"); + }; + self.append_value(v.clone()); + Ok(()) + } + + #[inline] + pub fn append_option(&mut self, opt: Option) { + match opt { + Some(s) => self.append_value(s), + None => self.append_null(), + } + } + + pub fn finish(mut self) -> ObjectChunked { + let null_bitmap: Option = self.bitmask_builder.into_opt_validity(); + + let len = self.values.len(); + let null_count = null_bitmap + .as_ref() + .map(|validity| validity.unset_bits()) + .unwrap_or(0); + + let arr = Box::new(ObjectArray { + values: self.values.into(), + validity: null_bitmap, + }); + + self.field.dtype = get_object_type::(); + + unsafe { ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len, null_count) } + } +} + +/// Initialize a polars Object data type. The type has got information needed to +/// construct new objects. +pub(crate) fn get_object_type() -> DataType { + DataType::Object(T::type_name()) +} + +impl Default for ObjectChunkedBuilder +where + T: PolarsObject, +{ + fn default() -> Self { + ObjectChunkedBuilder::new(PlSmallStr::EMPTY, 0) + } +} + +impl NewChunkedArray, T> for ObjectChunked +where + T: PolarsObject, +{ + fn from_slice(name: PlSmallStr, v: &[T]) -> Self { + Self::from_iter_values(name, v.iter().cloned()) + } + + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { + let mut builder = ObjectChunkedBuilder::::new(name, opt_v.len()); + opt_v + .iter() + .cloned() + .for_each(|opt| builder.append_option(opt)); + builder.finish() + } + + fn from_iter_options( + name: PlSmallStr, + it: impl Iterator>, + ) -> ObjectChunked { + let mut builder = ObjectChunkedBuilder::new(name, get_iter_capacity(&it)); + it.for_each(|opt| builder.append_option(opt)); + builder.finish() + } + + /// Create a new ChunkedArray from an iterator. + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> ObjectChunked { + let mut builder = ObjectChunkedBuilder::new(name, get_iter_capacity(&it)); + it.for_each(|v| builder.append_value(v)); + builder.finish() + } +} + +impl ObjectChunked +where + T: PolarsObject, +{ + pub fn new_from_vec(name: PlSmallStr, v: Vec) -> Self { + let field = Arc::new(Field::new(name, DataType::Object(T::type_name()))); + let len = v.len(); + let arr = Box::new(ObjectArray { + values: v.into(), + validity: None, + }); + + unsafe { ObjectChunked::new_with_dims(field, vec![arr], len, 0) } + } + + pub fn new_from_vec_and_validity( + name: PlSmallStr, + v: Vec, + validity: Option, + ) -> Self { + let field = Arc::new(Field::new(name, DataType::Object(T::type_name()))); + let len = v.len(); + let null_count = validity.as_ref().map(|v| v.unset_bits()).unwrap_or(0); + let arr = Box::new(ObjectArray { + values: v.into(), + validity, + }); + + unsafe { ObjectChunked::new_with_dims(field, vec![arr], len, null_count) } + } + + pub fn new_empty(name: PlSmallStr) -> Self { + Self::new_from_vec(name, vec![]) + } +} + +/// Convert a Series of dtype object to an Arrow Array of FixedSizeBinary +pub(crate) fn object_series_to_arrow_array(s: &Series) -> ArrayRef { + // The list builder knows how to create an arrow array + // we simply piggy back on that code. + + // SAFETY: 0..len is in bounds + let list_s = unsafe { + s.agg_list(&GroupsType::Slice { + groups: vec![[0, s.len() as IdxSize]], + rolling: false, + }) + }; + let arr = &list_s.chunks()[0]; + let arr = arr.as_any().downcast_ref::>().unwrap(); + arr.values().to_boxed() +} + +impl ArrayBuilder for ObjectChunkedBuilder { + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::FixedSizeBinary(size_of::()) + } + + fn reserve(&mut self, additional: usize) { + self.bitmask_builder.reserve(additional); + self.values.reserve(additional); + } + + fn freeze(self) -> Box { + Box::new(ObjectArray { + values: self.values.into(), + validity: self.bitmask_builder.into_opt_validity(), + }) + } + + fn freeze_reset(&mut self) -> Box { + Box::new(ObjectArray { + values: core::mem::take(&mut self.values).into(), + validity: core::mem::take(&mut self.bitmask_builder).into_opt_validity(), + }) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn extend_nulls(&mut self, length: usize) { + self.values.resize(self.values.len() + length, T::default()); + self.bitmask_builder.extend_constant(length, false); + } + + fn subslice_extend( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + _share: ShareStrategy, + ) { + let other: &ObjectArray = other.as_any().downcast_ref().unwrap(); + self.values + .extend_from_slice(&other.values[start..start + length]); + self.bitmask_builder + .subslice_extend_from_opt_validity(other.validity(), start, length); + } + + fn subslice_extend_repeated( + &mut self, + other: &dyn Array, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ) { + for _ in 0..repeats { + self.subslice_extend(other, start, length, share) + } + } + + unsafe fn gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], _share: ShareStrategy) { + let other: &ObjectArray = other.as_any().downcast_ref().unwrap(); + let other_values_slice = other.values.as_slice(); + self.values.extend( + idxs.iter() + .map(|idx| other_values_slice.get_unchecked(*idx as usize).clone()), + ); + self.bitmask_builder + .gather_extend_from_opt_validity(other.validity(), idxs, other.len()); + } + + fn opt_gather_extend(&mut self, other: &dyn Array, idxs: &[IdxSize], _share: ShareStrategy) { + let other: &ObjectArray = other.as_any().downcast_ref().unwrap(); + let other_values_slice = other.values.as_slice(); + self.values.reserve(idxs.len()); + unsafe { + for idx in idxs { + let val = if (*idx as usize) < other.len() { + other_values_slice.get_unchecked(*idx as usize).clone() + } else { + T::default() + }; + self.values.push_unchecked(val); + } + } + self.bitmask_builder.opt_gather_extend_from_opt_validity( + other.validity(), + idxs, + other.len(), + ); + } +} diff --git a/crates/polars-core/src/chunked_array/object/extension/drop.rs b/crates/polars-core/src/chunked_array/object/extension/drop.rs new file mode 100644 index 000000000000..6377cb170c38 --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/extension/drop.rs @@ -0,0 +1,47 @@ +use crate::chunked_array::object::extension::PolarsExtension; +use crate::prelude::*; + +/// This will dereference a raw ptr when dropping the [`PolarsExtension`], make sure that it's valid. +pub(crate) unsafe fn drop_list(ca: &ListChunked) { + let mut inner = ca.inner_dtype(); + let mut nested_count = 0; + + while let Some(a) = inner.inner_dtype() { + nested_count += 1; + inner = a; + } + + if matches!(inner, DataType::Object(_)) { + if nested_count != 0 { + panic!("multiple nested objects not yet supported") + } + // if empty the memory is leaked somewhere + assert!(!ca.chunks.is_empty()); + for lst_arr in &ca.chunks { + if let ArrowDataType::LargeList(fld) = lst_arr.dtype() { + let dtype = fld.dtype(); + + assert!(matches!(dtype, ArrowDataType::Extension(_))); + + // recreate the polars extension so that the content is dropped + let arr = lst_arr.as_any().downcast_ref::().unwrap(); + + let values = arr.values(); + drop_object_array(values.as_ref()) + } + } + } +} + +pub(crate) unsafe fn drop_object_array(values: &dyn Array) { + let arr = values + .as_any() + .downcast_ref::() + .unwrap(); + + // If the buf is not shared with anyone but us we can deallocate. + let buf = arr.values(); + if buf.storage_refcount() == 1 && !buf.is_empty() { + PolarsExtension::new(arr.clone()); + }; +} diff --git a/crates/polars-core/src/chunked_array/object/extension/list.rs b/crates/polars-core/src/chunked_array/object/extension/list.rs new file mode 100644 index 000000000000..2d34315c378d --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/extension/list.rs @@ -0,0 +1,89 @@ +use arrow::offset::Offsets; + +use crate::chunked_array::object::builder::ObjectChunkedBuilder; +use crate::chunked_array::object::extension::create_extension; +use crate::prelude::*; + +impl ObjectChunked { + pub(crate) fn get_list_builder( + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box { + Box::new(ExtensionListBuilder::::new( + name, + values_capacity, + list_capacity, + )) + } +} + +pub(crate) struct ExtensionListBuilder { + values_builder: ObjectChunkedBuilder, + offsets: Vec, + fast_explode: bool, +} + +impl ExtensionListBuilder { + pub(crate) fn new(name: PlSmallStr, values_capacity: usize, list_capacity: usize) -> Self { + let mut offsets = Vec::with_capacity(list_capacity + 1); + offsets.push(0); + Self { + values_builder: ObjectChunkedBuilder::new(name, values_capacity), + offsets, + fast_explode: true, + } + } +} + +impl ListBuilderTrait for ExtensionListBuilder { + fn append_series(&mut self, s: &Series) -> PolarsResult<()> { + let arr = s.as_any().downcast_ref::>().unwrap(); + + for v in arr.into_iter() { + self.values_builder.append_option(v.cloned()) + } + if arr.is_empty() { + self.fast_explode = false; + } + let len_so_far = self.offsets[self.offsets.len() - 1]; + self.offsets.push(len_so_far + arr.len() as i64); + Ok(()) + } + + fn append_null(&mut self) { + self.values_builder.append_null(); + let len_so_far = self.offsets[self.offsets.len() - 1]; + self.offsets.push(len_so_far + 1); + } + + fn finish(&mut self) -> ListChunked { + let values_builder = std::mem::take(&mut self.values_builder); + let offsets = std::mem::take(&mut self.offsets); + let ca = values_builder.finish(); + let obj_arr = ca.downcast_chunks().get(0).unwrap().clone(); + + // SAFETY: this is safe because we just created the PolarsExtension + // meaning that the sentinel is heap allocated and the dereference of + // the pointer does not fail. + let mut pe = create_extension(obj_arr.into_iter_cloned()); + unsafe { pe.set_to_series_fn::() }; + let extension_array = Box::new(pe.take_and_forget()) as ArrayRef; + let extension_dtype = extension_array.dtype(); + + let dtype = ListArray::::default_datatype(extension_dtype.clone()); + let arr = ListArray::::new( + dtype, + // SAFETY: offsets are monotonically increasing. + unsafe { Offsets::new_unchecked(offsets).into() }, + extension_array, + None, + ); + + let mut listarr = ListChunked::with_chunk(ca.name().clone(), arr); + if self.fast_explode { + listarr.set_fast_explode() + } + listarr + } +} diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs new file mode 100644 index 000000000000..d27557f748d1 --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -0,0 +1,198 @@ +pub(crate) mod drop; +pub(super) mod list; +pub(crate) mod polars_extension; + +use std::mem; +use std::sync::atomic::{AtomicBool, Ordering}; + +use arrow::array::FixedSizeBinaryArray; +use arrow::bitmap::BitmapBuilder; +use arrow::buffer::Buffer; +use arrow::datatypes::ExtensionType; +use polars_extension::PolarsExtension; +use polars_utils::format_pl_smallstr; + +use crate::PROCESS_ID; +use crate::prelude::*; + +static POLARS_ALLOW_EXTENSION: AtomicBool = AtomicBool::new(false); + +/// Control whether extension types may be created. +/// +/// If the environment variable POLARS_ALLOW_EXTENSION is set, this function has no effect. +pub fn set_polars_allow_extension(toggle: bool) { + POLARS_ALLOW_EXTENSION.store(toggle, Ordering::Relaxed) +} + +/// Invariants +/// `ptr` must point to start a `T` allocation +/// `n_t_vals` must represent the correct number of `T` values in that allocation +unsafe fn create_drop(mut ptr: *const u8, n_t_vals: usize) -> Box { + Box::new(move || { + let t_size = size_of::() as isize; + for _ in 0..n_t_vals { + let _ = std::ptr::read_unaligned(ptr as *const T); + ptr = ptr.offset(t_size) + } + }) +} + +#[allow(clippy::type_complexity)] +struct ExtensionSentinel { + drop_fn: Option>, + // A function on the heap that take a `array: FixedSizeBinary` and a `name: PlSmallStr` + // and returns a `Series` of `ObjectChunked` + pub(crate) to_series_fn: Option Series>>, +} + +impl Drop for ExtensionSentinel { + fn drop(&mut self) { + let mut drop_fn = self.drop_fn.take().unwrap(); + drop_fn() + } +} + +// https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8d +// not entirely sure if padding bytes in T are initialized or not. +unsafe fn any_as_u8_slice(p: &T) -> &[u8] { + std::slice::from_raw_parts((p as *const T) as *const u8, size_of::()) +} + +/// Create an extension Array that can be sent to arrow and (once wrapped in [`PolarsExtension`] will +/// also call drop on `T`, when the array is dropped. +pub(crate) fn create_extension> + TrustedLen, T: Sized + Default>( + iter: I, +) -> PolarsExtension { + let env = "POLARS_ALLOW_EXTENSION"; + if !(POLARS_ALLOW_EXTENSION.load(Ordering::Relaxed) || std::env::var(env).is_ok()) { + panic!("creating extension types not allowed - try setting the environment variable {env}") + } + let t_size = size_of::(); + let t_alignment = align_of::(); + let n_t_vals = iter.size_hint().1.unwrap(); + + let mut buf = Vec::with_capacity(n_t_vals * t_size); + let mut validity = BitmapBuilder::with_capacity(n_t_vals); + + // when we transmute from &[u8] to T, T must be aligned correctly, + // so we pad with bytes until the alignment matches + let n_padding = (buf.as_ptr() as usize) % t_alignment; + buf.extend(std::iter::repeat_n(0, n_padding)); + + // transmute T as bytes and copy in buffer + for opt_t in iter.into_iter() { + match opt_t { + Some(t) => { + unsafe { + buf.extend_from_slice(any_as_u8_slice(&t)); + // SAFETY: we allocated upfront + validity.push_unchecked(true) + } + mem::forget(t); + }, + None => { + unsafe { + buf.extend_from_slice(any_as_u8_slice(&T::default())); + // SAFETY: we allocated upfront + validity.push_unchecked(false) + } + }, + } + } + + // we slice the buffer because we want to ignore the padding bytes from here + // they can be forgotten + let buf: Buffer = buf.into(); + let len = buf.len() - n_padding; + let buf = buf.sliced(n_padding, len); + + // ptr to start of T, not to start of padding + let ptr = buf.as_slice().as_ptr(); + + // SAFETY: + // ptr and t are correct + let drop_fn = unsafe { create_drop::(ptr, n_t_vals) }; + let et = Box::new(ExtensionSentinel { + drop_fn: Some(drop_fn), + to_series_fn: None, + }); + let et_ptr = &*et as *const ExtensionSentinel; + std::mem::forget(et); + + let metadata = format_pl_smallstr!("{};{}", *PROCESS_ID, et_ptr as usize); + + let physical_type = ArrowDataType::FixedSizeBinary(t_size); + let extension_type = ArrowDataType::Extension(Box::new(ExtensionType { + name: PlSmallStr::from_static(EXTENSION_NAME), + inner: physical_type, + metadata: Some(metadata), + })); + + let array = FixedSizeBinaryArray::new(extension_type, buf, validity.into_opt_validity()); + + // SAFETY: we just heap allocated the ExtensionSentinel, so its alive. + unsafe { PolarsExtension::new(array) } +} + +#[cfg(test)] +mod test { + use std::fmt::{Display, Formatter}; + use std::hash::{Hash, Hasher}; + + use polars_utils::total_ord::TotalHash; + + use super::*; + + #[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] + struct Foo { + pub a: i32, + pub b: u8, + pub other_heap: String, + } + + impl TotalEq for Foo { + fn tot_eq(&self, other: &Self) -> bool { + self == other + } + } + + impl TotalHash for Foo { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } + } + + impl Display for Foo { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } + } + + impl PolarsObject for Foo { + fn type_name() -> &'static str { + "object" + } + } + + #[test] + fn test_create_extension() { + set_polars_allow_extension(true); + // Run this under MIRI. + let foo = Foo { + a: 1, + b: 1, + other_heap: "foo".into(), + }; + let foo2 = Foo { + a: 1, + b: 1, + other_heap: "bar".into(), + }; + + let vals = vec![Some(foo), Some(foo2)]; + create_extension(vals.into_iter()); + } +} diff --git a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs new file mode 100644 index 000000000000..ac671ba236c7 --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs @@ -0,0 +1,99 @@ +use std::mem::ManuallyDrop; + +use super::*; +use crate::prelude::*; + +pub struct PolarsExtension { + array: Option, +} + +impl PolarsExtension { + /// This is very expensive + pub(crate) unsafe fn arr_to_av(arr: &FixedSizeBinaryArray, i: usize) -> AnyValue { + let arr = arr.slice_typed_unchecked(i, 1); + let pe = Self::new(arr); + let pe = ManuallyDrop::new(pe); + pe.get_series(&PlSmallStr::EMPTY) + .get(0) + .unwrap() + .into_static() + } + + pub(crate) unsafe fn new(array: FixedSizeBinaryArray) -> Self { + Self { array: Some(array) } + } + + /// Take the Array hold by [`PolarsExtension`] and forget polars extension, + /// so that drop is not called + pub(crate) fn take_and_forget(self) -> FixedSizeBinaryArray { + let mut md = ManuallyDrop::new(self); + md.array.take().unwrap() + } + + /// Apply a function with the sentinel, without the sentinels drop being called + unsafe fn with_sentinel T>(&self, fun: F) -> T { + let mut sentinel = self.get_sentinel(); + let out = fun(&mut sentinel); + std::mem::forget(sentinel); + out + } + + /// Load the sentinel from the heap. + /// be very careful, this dereferences a raw pointer on the heap, + unsafe fn get_sentinel(&self) -> Box { + if let ArrowDataType::Extension(ext) = self.array.as_ref().unwrap().dtype() { + let metadata = ext + .metadata + .as_ref() + .expect("should have metadata in extension type"); + let mut iter = metadata.split(';'); + + let pid = iter.next().unwrap().parse::().unwrap(); + let ptr = iter.next().unwrap().parse::().unwrap(); + if pid == *PROCESS_ID { + Box::from_raw(ptr as *const ExtensionSentinel as *mut ExtensionSentinel) + } else { + panic!("pid did not mach process id") + } + } else { + panic!("should be extension type") + } + } + + /// Calls the heap allocated function in the [`ExtensionSentinel`] that knows + /// how to convert the [`FixedSizeBinaryArray`] to a `Series` of type [`ObjectChunked`] + pub(crate) unsafe fn get_series(&self, name: &PlSmallStr) -> Series { + self.with_sentinel(|sent| { + (sent.to_series_fn.as_ref().unwrap())(self.array.as_ref().unwrap(), name) + }) + } + + // heap allocates a function that converts the binary array to a Series of [`ObjectChunked`] + // the `name` will be the `name` of the output `Series` when this function is called (later). + pub(crate) unsafe fn set_to_series_fn(&mut self) { + let f = Box::new(move |arr: &FixedSizeBinaryArray, name: &PlSmallStr| { + let iter = arr.iter().map(|opt| { + opt.map(|bytes| { + let t = std::ptr::read_unaligned(bytes.as_ptr() as *const T); + + let ret = t.clone(); + std::mem::forget(t); + ret + }) + }); + + let ca = ObjectChunked::::from_iter_options(name.clone(), iter); + ca.into_series() + }); + self.with_sentinel(move |sent| { + sent.to_series_fn = Some(f); + }); + } +} + +impl Drop for PolarsExtension { + fn drop(&mut self) { + // implicitly drop by taking ownership + unsafe { self.get_sentinel() }; + } +} diff --git a/crates/polars-core/src/chunked_array/object/is_valid.rs b/crates/polars-core/src/chunked_array/object/is_valid.rs new file mode 100644 index 000000000000..b75eb13e945b --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/is_valid.rs @@ -0,0 +1,5 @@ +use arrow::legacy::is_valid::ArrowArray; + +use super::{ObjectArray, PolarsObject}; + +impl ArrowArray for ObjectArray {} diff --git a/crates/polars-core/src/chunked_array/object/iterator.rs b/crates/polars-core/src/chunked_array/object/iterator.rs new file mode 100644 index 000000000000..7abb9c46f4ee --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/iterator.rs @@ -0,0 +1,144 @@ +use arrow::array::Array; +use arrow::trusted_len::TrustedLen; + +use crate::chunked_array::object::{ObjectArray, PolarsObject}; + +/// An iterator that returns Some(T) or None, that can be used on any ObjectArray +// Note: This implementation is based on std's [Vec]s' [IntoIter]. +pub struct ObjectIter<'a, T: PolarsObject> { + array: &'a ObjectArray, + current: usize, + current_end: usize, +} + +impl<'a, T: PolarsObject> ObjectIter<'a, T> { + /// create a new iterator + pub fn new(array: &'a ObjectArray) -> Self { + ObjectIter:: { + array, + current: 0, + current_end: array.len(), + } + } +} + +impl<'a, T: PolarsObject> std::iter::Iterator for ObjectIter<'a, T> { + type Item = Option<&'a T>; + + #[inline] + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + // SAFETY: + // Se comment below + } else if unsafe { self.array.is_null_unchecked(self.current) } { + self.current += 1; + Some(None) + } else { + let old = self.current; + self.current += 1; + // SAFETY: + // we just checked bounds in `self.current_end == self.current` + // this is safe on the premise that this struct is initialized with + // current = array.len() + // and that current_end is ever only decremented + unsafe { Some(Some(self.array.value_unchecked(old))) } + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.array.len() - self.current, + Some(self.array.len() - self.current), + ) + } +} + +impl std::iter::DoubleEndedIterator for ObjectIter<'_, T> { + fn next_back(&mut self) -> Option { + if self.current_end == self.current { + None + } else { + self.current_end -= 1; + Some(if self.array.is_null(self.current_end) { + None + } else { + // SAFETY: + // we just checked bounds in `self.current_end == self.current` + // this is safe on the premise that this struct is initialized with + // current = array.len() + // and that current_end is ever only decremented + unsafe { Some(self.array.value_unchecked(self.current_end)) } + }) + } + } +} + +/// all arrays have known size. +impl std::iter::ExactSizeIterator for ObjectIter<'_, T> {} + +impl<'a, T: PolarsObject> IntoIterator for &'a ObjectArray { + type Item = Option<&'a T>; + type IntoIter = ObjectIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + ObjectIter::<'a, T>::new(self) + } +} + +pub struct OwnedObjectIter { + array: ObjectArray, + current: usize, + current_end: usize, +} + +impl OwnedObjectIter { + /// create a new iterator + pub fn new(array: ObjectArray) -> Self { + let current_end = array.len(); + OwnedObjectIter:: { + array, + current: 0, + current_end, + } + } +} + +unsafe impl TrustedLen for OwnedObjectIter {} + +impl ObjectArray { + pub(crate) fn into_iter_cloned(self) -> OwnedObjectIter { + OwnedObjectIter::::new(self) + } +} +impl std::iter::Iterator for OwnedObjectIter { + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + // SAFETY: + // Se comment below + } else if unsafe { self.array.is_null_unchecked(self.current) } { + self.current += 1; + Some(None) + } else { + let old = self.current; + self.current += 1; + // SAFETY: + // we just checked bounds in `self.current_end == self.current` + // this is safe on the premise that this struct is initialized with + // current = array.len() + // and that current_end is ever only decremented + unsafe { Some(Some(self.array.value_unchecked(old).clone())) } + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.array.len() - self.current, + Some(self.array.len() - self.current), + ) + } +} diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs new file mode 100644 index 000000000000..d511c5676fd1 --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -0,0 +1,328 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::any::Any; +use std::fmt::{Debug, Display}; +use std::hash::Hash; + +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::{BitmapIter, ZipValidity}; +use arrow::buffer::Buffer; +use polars_utils::total_ord::TotalHash; + +use crate::prelude::*; + +pub mod builder; +#[cfg(feature = "object")] +pub(crate) mod extension; +mod is_valid; +mod iterator; +pub mod registry; + +pub use extension::set_polars_allow_extension; + +#[derive(Debug, Clone)] +pub struct ObjectArray +where + T: PolarsObject, +{ + values: Buffer, + validity: Option, +} + +/// Trimmed down object safe polars object +pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display { + fn type_name(&self) -> &'static str; + + fn as_any(&self) -> &dyn Any; + + fn to_boxed(&self) -> Box; + + fn equal(&self, other: &dyn PolarsObjectSafe) -> bool; +} + +impl PartialEq for &dyn PolarsObjectSafe { + fn eq(&self, other: &Self) -> bool { + self.equal(*other) + } +} + +/// Values need to implement this so that they can be stored into a Series and DataFrame +pub trait PolarsObject: + Any + Debug + Clone + Send + Sync + Default + Display + Hash + TotalHash + PartialEq + Eq + TotalEq +{ + /// This should be used as type information. Consider this a part of the type system. + fn type_name() -> &'static str; +} + +impl PolarsObjectSafe for T { + fn type_name(&self) -> &'static str { + T::type_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn to_boxed(&self) -> Box { + Box::new(self.clone()) + } + + fn equal(&self, other: &dyn PolarsObjectSafe) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + self == other + } +} + +pub type ObjectValueIter<'a, T> = std::slice::Iter<'a, T>; + +impl ObjectArray +where + T: PolarsObject, +{ + pub fn values_iter(&self) -> ObjectValueIter<'_, T> { + self.values.iter() + } + + /// Returns an iterator of `Option<&T>` over every element of this array. + pub fn iter(&self) -> ZipValidity<&T, ObjectValueIter<'_, T>, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity.as_ref()) + } + + /// Get a value at a certain index location + pub fn value(&self, index: usize) -> &T { + &self.values[index] + } + + pub fn get(&self, index: usize) -> Option<&T> { + if self.is_valid(index) { + Some(unsafe { self.value_unchecked(index) }) + } else { + None + } + } + + /// Get a value at a certain index location + /// + /// # Safety + /// + /// This does not any bound checks. The caller needs to ensure the index is within + /// the size of the array. + pub unsafe fn value_unchecked(&self, index: usize) -> &T { + self.values.get_unchecked(index) + } + + /// Check validity + /// + /// # Safety + /// No bounds checks + #[inline] + pub unsafe fn is_valid_unchecked(&self, i: usize) -> bool { + if let Some(b) = &self.validity { + b.get_bit_unchecked(i) + } else { + true + } + } + + /// Check validity + /// + /// # Safety + /// No bounds checks + #[inline] + pub unsafe fn is_null_unchecked(&self, i: usize) -> bool { + !self.is_valid_unchecked(i) + } + + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `validity.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } +} + +impl Array for ObjectArray +where + T: PolarsObject, +{ + fn as_any(&self) -> &dyn Any { + self + } + + fn dtype(&self) -> &ArrowDataType { + &ArrowDataType::FixedSizeBinary(size_of::()) + } + + fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity = self + .validity + .take() + .map(|bitmap| bitmap.sliced_unchecked(offset, length)) + .filter(|bitmap| bitmap.unset_bits() > 0); + self.values.slice_unchecked(offset, length); + } + + fn split_at_boxed(&self, offset: usize) -> (Box, Box) { + let (lhs, rhs) = Splitable::split_at(self, offset); + (Box::new(lhs), Box::new(rhs)) + } + + unsafe fn split_at_boxed_unchecked(&self, offset: usize) -> (Box, Box) { + let (lhs, rhs) = unsafe { Splitable::split_at_unchecked(self, offset) }; + (Box::new(lhs), Box::new(rhs)) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } + + fn to_boxed(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + unimplemented!() + } + + fn null_count(&self) -> usize { + match &self.validity { + None => 0, + Some(validity) => validity.unset_bits(), + } + } +} + +impl Splitable for ObjectArray { + fn check_bound(&self, offset: usize) -> bool { + offset <= self.len() + } + + unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) { + let (left_values, right_values) = unsafe { self.values.split_at_unchecked(offset) }; + let (left_validity, right_validity) = unsafe { self.validity.split_at_unchecked(offset) }; + ( + Self { + values: left_values, + validity: left_validity, + }, + Self { + values: right_values, + validity: right_validity, + }, + ) + } +} + +impl StaticArray for ObjectArray { + type ValueT<'a> = &'a T; + type ZeroableValueT<'a> = Option<&'a T>; + type ValueIterT<'a> = ObjectValueIter<'a, T>; + + #[inline] + unsafe fn value_unchecked(&self, idx: usize) -> Self::ValueT<'_> { + self.value_unchecked(idx) + } + + fn values_iter(&self) -> Self::ValueIterT<'_> { + self.values_iter() + } + + fn iter(&self) -> ZipValidity, Self::ValueIterT<'_>, BitmapIter> { + self.iter() + } + + fn with_validity_typed(self, validity: Option) -> Self { + self.with_validity(validity) + } + + fn full_null(length: usize, _dtype: ArrowDataType) -> Self { + ObjectArray { + values: vec![T::default(); length].into(), + validity: Some(Bitmap::new_with_value(false, length)), + } + } +} + +impl ParameterFreeDtypeStaticArray for ObjectArray { + fn get_dtype() -> ArrowDataType { + ArrowDataType::FixedSizeBinary(size_of::()) + } +} + +impl ObjectChunked +where + T: PolarsObject, +{ + /// Get a hold to an object that can be formatted or downcasted via the Any trait. + /// + /// # Safety + /// + /// No bounds checks + pub unsafe fn get_object_unchecked(&self, index: usize) -> Option<&dyn PolarsObjectSafe> { + let (chunk_idx, idx) = self.index_to_chunked_index(index); + self.get_object_chunked_unchecked(chunk_idx, idx) + } + + pub(crate) unsafe fn get_object_chunked_unchecked( + &self, + chunk: usize, + index: usize, + ) -> Option<&dyn PolarsObjectSafe> { + let chunks = self.downcast_chunks(); + let arr = chunks.get_unchecked(chunk); + if arr.is_valid_unchecked(index) { + Some(arr.value(index)) + } else { + None + } + } + + /// Get a hold to an object that can be formatted or downcasted via the Any trait. + pub fn get_object(&self, index: usize) -> Option<&dyn PolarsObjectSafe> { + if index < self.len() { + unsafe { self.get_object_unchecked(index) } + } else { + None + } + } +} + +impl From> for ObjectArray { + fn from(values: Vec) -> Self { + Self { + values: values.into(), + validity: None, + } + } +} diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs new file mode 100644 index 000000000000..6a408507c139 --- /dev/null +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -0,0 +1,160 @@ +//! This is a heap allocated utility that can be used to register an object type. +//! +//! That object type will know its own generic type parameter `T` and callers can simply +//! send `&Any` values and don't have to know the generic type themselves. +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::ops::Deref; +use std::sync::{Arc, LazyLock, RwLock}; + +use arrow::array::ArrayRef; +use arrow::array::builder::ArrayBuilder; +use arrow::datatypes::ArrowDataType; +use polars_utils::pl_str::PlSmallStr; + +use crate::chunked_array::object::builder::ObjectChunkedBuilder; +use crate::datatypes::AnyValue; +use crate::prelude::{ListBuilderTrait, ObjectChunked, PolarsObject}; +use crate::series::{IntoSeries, Series}; + +/// Takes a `name` and `capacity` and constructs a new builder. +pub type BuilderConstructor = + Box Box + Send + Sync>; +pub type ObjectConverter = Arc Box + Send + Sync>; +pub type PyObjectConverter = Arc Box + Send + Sync>; + +pub struct ObjectRegistry { + /// A function that creates an object builder + pub builder_constructor: BuilderConstructor, + // A function that converts AnyValue to Box of the object type + object_converter: Option, + // A function that converts AnyValue to Box of the PyObject type + pyobject_converter: Option, + pub physical_dtype: ArrowDataType, +} + +impl Debug for ObjectRegistry { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "object-registry") + } +} + +static GLOBAL_OBJECT_REGISTRY: LazyLock>> = + LazyLock::new(Default::default); + +/// This trait can be registered, after which that global registration +/// can be used to materialize object types +pub trait AnonymousObjectBuilder: ArrayBuilder { + fn as_array_builder(self: Box) -> Box; + + /// # Safety + /// Expect `ObjectArray` arrays. + unsafe fn from_chunks(self: Box, chunks: Vec) -> Series; + + /// Append a `null` value. + fn append_null(&mut self); + + /// Append a `T` of [`ObjectChunked`][ObjectChunked] made generic via the [`Any`] trait. + /// + /// [ObjectChunked]: crate::chunked_array::object::ObjectChunked + fn append_value(&mut self, value: &dyn Any); + + fn append_option(&mut self, value: Option<&dyn Any>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + } + } + + /// Take the current state and materialize as a [`Series`] + /// the builder should not be used after that. + fn to_series(&mut self) -> Series; + + fn get_list_builder( + &self, + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box; +} + +impl AnonymousObjectBuilder for ObjectChunkedBuilder { + /// # Safety + /// Expects `ObjectArray` arrays. + unsafe fn from_chunks(self: Box, chunks: Vec) -> Series { + ObjectChunked::::new_with_compute_len(Arc::new(self.field().clone()), chunks) + .into_series() + } + + fn as_array_builder(self: Box) -> Box { + self + } + + fn append_null(&mut self) { + self.append_null() + } + + fn append_value(&mut self, value: &dyn Any) { + let value = value.downcast_ref::().unwrap(); + self.append_value(value.clone()) + } + + fn to_series(&mut self) -> Series { + let builder = std::mem::take(self); + builder.finish().into_series() + } + fn get_list_builder( + &self, + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box { + Box::new(super::extension::list::ExtensionListBuilder::::new( + name, + values_capacity, + list_capacity, + )) + } +} + +pub fn register_object_builder( + builder_constructor: BuilderConstructor, + object_converter: ObjectConverter, + pyobject_converter: PyObjectConverter, + physical_dtype: ArrowDataType, +) { + let reg = GLOBAL_OBJECT_REGISTRY.deref(); + let mut reg = reg.write().unwrap(); + + *reg = Some(ObjectRegistry { + builder_constructor, + object_converter: Some(object_converter), + pyobject_converter: Some(pyobject_converter), + physical_dtype, + }) +} + +#[cold] +pub fn get_object_physical_type() -> ArrowDataType { + let reg = GLOBAL_OBJECT_REGISTRY.read().unwrap(); + let reg = reg.as_ref().unwrap(); + reg.physical_dtype.clone() +} + +pub fn get_object_builder(name: PlSmallStr, capacity: usize) -> Box { + let reg = GLOBAL_OBJECT_REGISTRY.read().unwrap(); + let reg = reg.as_ref().unwrap(); + (reg.builder_constructor)(name, capacity) +} + +pub fn get_object_converter() -> ObjectConverter { + let reg = GLOBAL_OBJECT_REGISTRY.read().unwrap(); + let reg = reg.as_ref().unwrap(); + reg.object_converter.as_ref().unwrap().clone() +} + +pub fn get_pyobject_converter() -> PyObjectConverter { + let reg = GLOBAL_OBJECT_REGISTRY.read().unwrap(); + let reg = reg.as_ref().unwrap(); + reg.pyobject_converter.as_ref().unwrap().clone() +} diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs new file mode 100644 index 000000000000..42447cf3f823 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -0,0 +1,976 @@ +//! Implementations of the ChunkAgg trait. +mod quantile; +mod var; + +use arrow::types::NativeType; +use num_traits::{Float, One, ToPrimitive, Zero}; +use polars_compute::float_sum; +use polars_compute::min_max::MinMaxKernel; +use polars_compute::rolling::QuantileMethod; +use polars_compute::sum::{WrappingSum, wrapping_sum_arr}; +use polars_utils::min_max::MinMax; +use polars_utils::sync::SyncPtr; +pub use quantile::*; +pub use var::*; + +use super::float_sorted_arg_max::{ + float_arg_max_sorted_ascending, float_arg_max_sorted_descending, +}; +use crate::chunked_array::ChunkedArray; +use crate::datatypes::{BooleanChunked, PolarsNumericType}; +use crate::prelude::*; +use crate::series::IsSorted; + +/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations. +pub trait ChunkAggSeries { + /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn sum_reduce(&self) -> Scalar { + unimplemented!() + } + /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn max_reduce(&self) -> Scalar { + unimplemented!() + } + /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn min_reduce(&self) -> Scalar { + unimplemented!() + } + /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn prod_reduce(&self) -> Scalar { + unimplemented!() + } +} + +fn sum(array: &PrimitiveArray) -> T +where + T: NumericNative + NativeType + WrappingSum, +{ + if array.null_count() == array.len() { + return T::default(); + } + + if T::is_float() { + unsafe { + if T::is_f32() { + let f32_arr = + std::mem::transmute::<&PrimitiveArray, &PrimitiveArray>(array); + let sum = float_sum::sum_arr_as_f32(f32_arr); + std::mem::transmute_copy::(&sum) + } else if T::is_f64() { + let f64_arr = + std::mem::transmute::<&PrimitiveArray, &PrimitiveArray>(array); + let sum = float_sum::sum_arr_as_f64(f64_arr); + std::mem::transmute_copy::(&sum) + } else { + unreachable!("only supported float types are f32 and f64"); + } + } + } else { + wrapping_sum_arr(array) + } +} + +impl ChunkAgg for ChunkedArray +where + T: PolarsNumericType, + T::Native: WrappingSum, + PrimitiveArray: for<'a> MinMaxKernel = T::Native>, +{ + fn sum(&self) -> Option { + Some( + self.downcast_iter() + .map(sum) + .fold(T::Native::zero(), |acc, v| acc + v), + ) + } + + fn _sum_as_f64(&self) -> f64 { + self.downcast_iter().map(float_sum::sum_arr_as_f64).sum() + } + + fn min(&self) -> Option { + if self.null_count() == self.len() { + return None; + } + + // There is at least one non-null value. + + match self.is_sorted_flag() { + IsSorted::Ascending => { + let idx = self.first_non_null().unwrap(); + unsafe { self.get_unchecked(idx) } + }, + IsSorted::Descending => { + let idx = self.last_non_null().unwrap(); + unsafe { self.get_unchecked(idx) } + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::min_ignore_nan_kernel) + .reduce(MinMax::min_ignore_nan), + } + } + + fn max(&self) -> Option { + if self.null_count() == self.len() { + return None; + } + // There is at least one non-null value. + + match self.is_sorted_flag() { + IsSorted::Ascending => { + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_ascending(self) + } else { + self.last_non_null().unwrap() + }; + + unsafe { self.get_unchecked(idx) } + }, + IsSorted::Descending => { + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_descending(self) + } else { + self.first_non_null().unwrap() + }; + + unsafe { self.get_unchecked(idx) } + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::max_ignore_nan_kernel) + .reduce(MinMax::max_ignore_nan), + } + } + + fn min_max(&self) -> Option<(T::Native, T::Native)> { + if self.null_count() == self.len() { + return None; + } + // There is at least one non-null value. + + match self.is_sorted_flag() { + IsSorted::Ascending => { + let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) }; + let max = { + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_ascending(self) + } else { + self.last_non_null().unwrap() + }; + + unsafe { self.get_unchecked(idx) } + }; + min.zip(max) + }, + IsSorted::Descending => { + let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) }; + let max = { + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_descending(self) + } else { + self.first_non_null().unwrap() + }; + + unsafe { self.get_unchecked(idx) } + }; + + min.zip(max) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::min_max_ignore_nan_kernel) + .reduce(|(min1, max1), (min2, max2)| { + ( + MinMax::min_ignore_nan(min1, min2), + MinMax::max_ignore_nan(max1, max2), + ) + }), + } + } + + fn mean(&self) -> Option { + let count = self.len() - self.null_count(); + if count == 0 { + return None; + } + Some(self._sum_as_f64() / count as f64) + } +} + +/// Booleans are cast to 1 or 0. +impl BooleanChunked { + pub fn sum(&self) -> Option { + Some(if self.is_empty() { + 0 + } else { + self.downcast_iter() + .map(|arr| match arr.validity() { + Some(validity) => { + (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize + }, + None => (arr.len() - arr.values().unset_bits()) as IdxSize, + }) + .sum() + }) + } + + pub fn min(&self) -> Option { + let nc = self.null_count(); + let len = self.len(); + if self.is_empty() || nc == len { + return None; + } + if nc == 0 { + if self.all() { Some(true) } else { Some(false) } + } else { + // we can unwrap as we already checked empty and all null above + if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize { + Some(true) + } else { + Some(false) + } + } + } + + pub fn max(&self) -> Option { + if self.is_empty() || self.null_count() == self.len() { + return None; + } + if self.any() { Some(true) } else { Some(false) } + } + pub fn mean(&self) -> Option { + if self.is_empty() || self.null_count() == self.len() { + return None; + } + self.sum() + .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64) + } +} + +// Needs the same trait bounds as the implementation of ChunkedArray of dyn Series. +impl ChunkAggSeries for ChunkedArray +where + T: PolarsNumericType, + T::Native: WrappingSum, + PrimitiveArray: for<'a> MinMaxKernel = T::Native>, + ChunkedArray: IntoSeries, +{ + fn sum_reduce(&self) -> Scalar { + let v: Option = self.sum(); + Scalar::new(T::get_dtype(), v.into()) + } + + fn max_reduce(&self) -> Scalar { + let v = ChunkAgg::max(self); + Scalar::new(T::get_dtype(), v.into()) + } + + fn min_reduce(&self) -> Scalar { + let v = ChunkAgg::min(self); + Scalar::new(T::get_dtype(), v.into()) + } + + fn prod_reduce(&self) -> Scalar { + let mut prod = T::Native::one(); + + for arr in self.downcast_iter() { + for v in arr.into_iter().flatten() { + prod = prod * *v + } + } + Scalar::new(T::get_dtype(), prod.into()) + } +} + +impl VarAggSeries for ChunkedArray +where + T: PolarsIntegerType, + ChunkedArray: ChunkVar, +{ + fn var_reduce(&self, ddof: u8) -> Scalar { + let v = self.var(ddof); + Scalar::new(DataType::Float64, v.into()) + } + + fn std_reduce(&self, ddof: u8) -> Scalar { + let v = self.std(ddof); + Scalar::new(DataType::Float64, v.into()) + } +} + +impl VarAggSeries for Float32Chunked { + fn var_reduce(&self, ddof: u8) -> Scalar { + let v = self.var(ddof).map(|v| v as f32); + Scalar::new(DataType::Float32, v.into()) + } + + fn std_reduce(&self, ddof: u8) -> Scalar { + let v = self.std(ddof).map(|v| v as f32); + Scalar::new(DataType::Float32, v.into()) + } +} + +impl VarAggSeries for Float64Chunked { + fn var_reduce(&self, ddof: u8) -> Scalar { + let v = self.var(ddof); + Scalar::new(DataType::Float64, v.into()) + } + + fn std_reduce(&self, ddof: u8) -> Scalar { + let v = self.std(ddof); + Scalar::new(DataType::Float64, v.into()) + } +} + +impl QuantileAggSeries for ChunkedArray +where + T: PolarsIntegerType, + T::Native: Ord + WrappingSum, +{ + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; + Ok(Scalar::new(DataType::Float64, v.into())) + } + + fn median_reduce(&self) -> Scalar { + let v = self.median(); + Scalar::new(DataType::Float64, v.into()) + } +} + +impl QuantileAggSeries for Float32Chunked { + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; + Ok(Scalar::new(DataType::Float32, v.into())) + } + + fn median_reduce(&self) -> Scalar { + let v = self.median(); + Scalar::new(DataType::Float32, v.into()) + } +} + +impl QuantileAggSeries for Float64Chunked { + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; + Ok(Scalar::new(DataType::Float64, v.into())) + } + + fn median_reduce(&self) -> Scalar { + let v = self.median(); + Scalar::new(DataType::Float64, v.into()) + } +} + +impl ChunkAggSeries for BooleanChunked { + fn sum_reduce(&self) -> Scalar { + let v = self.sum(); + Scalar::new(IDX_DTYPE, v.into()) + } + fn max_reduce(&self) -> Scalar { + let v = self.max(); + Scalar::new(DataType::Boolean, v.into()) + } + fn min_reduce(&self) -> Scalar { + let v = self.min(); + Scalar::new(DataType::Boolean, v.into()) + } +} + +impl StringChunked { + pub(crate) fn max_str(&self) -> Option<&str> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + self.last_non_null().and_then(|idx| { + // SAFETY: last_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Descending => { + self.first_non_null().and_then(|idx| { + // SAFETY: first_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::max_ignore_nan_kernel) + .reduce(MinMax::max_ignore_nan), + } + } + pub(crate) fn min_str(&self) -> Option<&str> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + self.first_non_null().and_then(|idx| { + // SAFETY: first_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Descending => { + self.last_non_null().and_then(|idx| { + // SAFETY: last_non_null returns in bound index + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::min_ignore_nan_kernel) + .reduce(MinMax::min_ignore_nan), + } + } +} + +impl ChunkAggSeries for StringChunked { + fn max_reduce(&self) -> Scalar { + let av: AnyValue = self.max_str().into(); + Scalar::new(DataType::String, av.into_static()) + } + fn min_reduce(&self) -> Scalar { + let av: AnyValue = self.min_str().into(); + Scalar::new(DataType::String, av.into_static()) + } +} + +#[cfg(feature = "dtype-categorical")] +impl CategoricalChunked { + fn min_categorical(&self) -> Option { + if self.is_empty() || self.null_count() == self.len() { + return None; + } + if self.uses_lexical_ordering() { + let rev_map = self.get_rev_map(); + // Fast path where all categories are used + let c = if self._can_fast_unique() { + rev_map.get_categories().min_ignore_nan_kernel() + } else { + // SAFETY: + // Indices are in bounds + self.physical() + .iter() + .flat_map(|opt_el: Option| { + opt_el.map(|el| unsafe { rev_map.get_unchecked(el) }) + }) + .min() + }; + rev_map.find(c.unwrap()) + } else { + self.physical().min() + } + } + + fn max_categorical(&self) -> Option { + if self.is_empty() || self.null_count() == self.len() { + return None; + } + if self.uses_lexical_ordering() { + let rev_map = self.get_rev_map(); + // Fast path where all categories are used + let c = if self._can_fast_unique() { + rev_map.get_categories().max_ignore_nan_kernel() + } else { + // SAFETY: + // Indices are in bounds + self.physical() + .iter() + .flat_map(|opt_el: Option| { + opt_el.map(|el| unsafe { rev_map.get_unchecked(el) }) + }) + .max() + }; + rev_map.find(c.unwrap()) + } else { + self.physical().max() + } + } +} + +#[cfg(feature = "dtype-categorical")] +impl ChunkAggSeries for CategoricalChunked { + fn min_reduce(&self) -> Scalar { + match self.dtype() { + DataType::Enum(r, _) => match self.physical().min() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(r, _) => match self.min_categorical() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let r = r.as_ref().unwrap(); + let arr = match &**r { + RevMapping::Local(arr, _) => arr, + RevMapping::Global(_, arr, _) => arr, + }; + Scalar::new( + self.dtype().clone(), + AnyValue::CategoricalOwned( + v, + r.clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + _ => unreachable!(), + } + } + fn max_reduce(&self) -> Scalar { + match self.dtype() { + DataType::Enum(r, _) => match self.physical().max() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let RevMapping::Local(arr, _) = &**r.as_ref().unwrap() else { + unreachable!() + }; + Scalar::new( + self.dtype().clone(), + AnyValue::EnumOwned( + v, + r.as_ref().unwrap().clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + DataType::Categorical(r, _) => match self.max_categorical() { + None => Scalar::new(self.dtype().clone(), AnyValue::Null), + Some(v) => { + let r = r.as_ref().unwrap(); + let arr = match &**r { + RevMapping::Local(arr, _) => arr, + RevMapping::Global(_, arr, _) => arr, + }; + Scalar::new( + self.dtype().clone(), + AnyValue::CategoricalOwned( + v, + r.clone(), + SyncPtr::from_const(arr as *const _), + ), + ) + }, + }, + _ => unreachable!(), + } + } +} + +impl BinaryChunked { + pub fn max_binary(&self) -> Option<&[u8]> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + self.last_non_null().and_then(|idx| { + // SAFETY: last_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Descending => { + self.first_non_null().and_then(|idx| { + // SAFETY: first_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::max_ignore_nan_kernel) + .reduce(MinMax::max_ignore_nan), + } + } + + pub fn min_binary(&self) -> Option<&[u8]> { + if self.is_empty() { + return None; + } + match self.is_sorted_flag() { + IsSorted::Ascending => { + self.first_non_null().and_then(|idx| { + // SAFETY: first_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Descending => { + self.last_non_null().and_then(|idx| { + // SAFETY: last_non_null returns in bound index. + unsafe { self.get_unchecked(idx) } + }) + }, + IsSorted::Not => self + .downcast_iter() + .filter_map(MinMaxKernel::min_ignore_nan_kernel) + .reduce(MinMax::min_ignore_nan), + } + } +} + +impl ChunkAggSeries for BinaryChunked { + fn sum_reduce(&self) -> Scalar { + unimplemented!() + } + fn max_reduce(&self) -> Scalar { + let av: AnyValue = self.max_binary().into(); + Scalar::new(self.dtype().clone(), av.into_static()) + } + fn min_reduce(&self) -> Scalar { + let av: AnyValue = self.min_binary().into(); + Scalar::new(self.dtype().clone(), av.into_static()) + } +} + +#[cfg(feature = "object")] +impl ChunkAggSeries for ObjectChunked {} + +#[cfg(test)] +mod test { + use polars_compute::rolling::QuantileMethod; + + use crate::prelude::*; + + #[test] + #[cfg(not(miri))] + fn test_var() { + // Validated with numpy. Note that numpy uses ddof as an argument which + // influences results. The default ddof=0, we chose ddof=1, which is + // standard in statistics. + let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]); + let ca2 = Int32Chunked::new( + PlSmallStr::EMPTY, + &[ + Some(5), + None, + Some(8), + Some(9), + None, + Some(5), + Some(0), + None, + ], + ); + for ca in &[ca1, ca2] { + let out = ca.var(1); + assert_eq!(out, Some(12.3)); + let out = ca.std(1).unwrap(); + assert!((3.5071355833500366 - out).abs() < 0.000000001); + } + } + + #[test] + fn test_agg_float() { + let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]); + let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]); + assert_eq!(ca1.min(), ca2.min()); + let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]); + let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]); + assert_eq!(ca1.min(), ca2.min()); + println!("{:?}", (ca1.min(), ca2.min())) + } + + #[test] + fn test_median() { + let ca = UInt32Chunked::new( + PlSmallStr::from_static("a"), + &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)], + ); + assert_eq!(ca.median(), Some(3.0)); + let ca = UInt32Chunked::new( + PlSmallStr::from_static("a"), + &[ + None, + Some(7), + Some(6), + Some(2), + Some(1), + None, + Some(3), + Some(5), + None, + Some(4), + ], + ); + assert_eq!(ca.median(), Some(4.0)); + + let ca = Float32Chunked::from_slice( + PlSmallStr::EMPTY, + &[ + 0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562, + 0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095, + 0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363, + 0.500208, 2.621727, 2.803311, 3.868526, + ], + ); + assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001) + } + + #[test] + fn test_mean() { + let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]); + assert_eq!(ca.mean().unwrap(), 1.5); + assert_eq!( + ca.into_series() + .mean_reduce() + .value() + .extract::() + .unwrap(), + 1.5 + ); + // all null values case + let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3); + assert_eq!(ca.mean(), None); + assert_eq!( + ca.into_series().mean_reduce().value().extract::(), + None + ); + } + + #[test] + fn test_quantile_all_null() { + let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, + ]; + + for method in methods { + assert_eq!(test_f32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_f64.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i64.quantile(0.9, method).unwrap(), None); + } + } + + #[test] + fn test_quantile_single_value() { + let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); + let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); + let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); + let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); + + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, + ]; + + for method in methods { + assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0)); + } + } + + #[test] + fn test_quantile_min_max() { + let test_f32 = Float32Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1f32), Some(5f32), Some(1f32)], + ); + let test_i32 = Int32Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1i32), Some(5i32), Some(1i32)], + ); + let test_f64 = Float64Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1f64), Some(5f64), Some(1f64)], + ); + let test_i64 = Int64Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1i64), Some(5i64), Some(1i64)], + ); + + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, + ]; + + for method in methods { + assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min()); + assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max()); + + assert_eq!( + test_i32.quantile(0.0, method).unwrap().unwrap(), + test_i32.min().unwrap() as f64 + ); + assert_eq!( + test_i32.quantile(1.0, method).unwrap().unwrap(), + test_i32.max().unwrap() as f64 + ); + + assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min()); + assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max()); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median()); + + assert_eq!( + test_i64.quantile(0.0, method).unwrap().unwrap(), + test_i64.min().unwrap() as f64 + ); + assert_eq!( + test_i64.quantile(1.0, method).unwrap().unwrap(), + test_i64.max().unwrap() as f64 + ); + } + } + + #[test] + fn test_quantile() { + let ca = UInt32Chunked::new( + PlSmallStr::from_static("a"), + &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)], + ); + + assert_eq!( + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), + Some(1.0) + ); + assert_eq!( + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), + Some(5.0) + ); + assert_eq!( + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), + Some(3.0) + ); + + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0)); + + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0)); + + assert_eq!( + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), + Some(1.5) + ); + assert_eq!( + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), + Some(4.5) + ); + assert_eq!( + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), + Some(3.5) + ); + + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert!( + (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001 + ); + + assert_eq!( + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) + ); + assert_eq!( + ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) + ); + assert_eq!( + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(3.0) + ); + + let ca = UInt32Chunked::new( + PlSmallStr::from_static("a"), + &[ + None, + Some(7), + Some(6), + Some(2), + Some(1), + None, + Some(3), + Some(5), + None, + Some(4), + ], + ); + + assert_eq!( + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), + Some(2.0) + ); + assert_eq!( + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), + Some(6.0) + ); + assert_eq!( + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), + Some(5.0) + ); + + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0)); + + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0)); + + assert_eq!( + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), + Some(1.5) + ); + assert_eq!( + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), + Some(6.5) + ); + assert_eq!( + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), + Some(4.5) + ); + + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6)); + + assert_eq!( + ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) + ); + assert_eq!( + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) + ); + assert_eq!( + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(5.0) + ); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs new file mode 100644 index 000000000000..5a2ff1739ebe --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -0,0 +1,274 @@ +use polars_compute::rolling::QuantileMethod; + +use super::*; + +pub trait QuantileAggSeries { + /// Get the median of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn median_reduce(&self) -> Scalar; + /// Get the quantile of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult; +} + +/// helper +fn quantile_idx( + quantile: f64, + length: usize, + null_count: usize, + method: QuantileMethod, +) -> (usize, f64, usize) { + let nonnull_count = (length - null_count) as f64; + let float_idx = (nonnull_count - 1.0) * quantile + null_count as f64; + let mut base_idx = match method { + QuantileMethod::Nearest => { + let idx = float_idx.round() as usize; + return (idx, 0.0, idx); + }, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { + float_idx as usize + }, + QuantileMethod::Higher => float_idx.ceil() as usize, + QuantileMethod::Equiprobable => { + let idx = ((nonnull_count * quantile).ceil() - 1.0).max(0.0) as usize + null_count; + return (idx, 0.0, idx); + }, + }; + + base_idx = base_idx.clamp(0, length - 1); + let top_idx = f64::ceil(float_idx) as usize; + + (base_idx, float_idx, top_idx) +} + +/// helper +fn linear_interpol(lower: T, upper: T, idx: usize, float_idx: f64) -> T { + if lower == upper { + lower + } else { + let proportion: T = T::from(float_idx).unwrap() - T::from(idx).unwrap(); + proportion * (upper - lower) + lower + } +} +fn midpoint_interpol(lower: T, upper: T) -> T { + if lower == upper { + lower + } else { + (lower + upper) / (T::one() + T::one()) + } +} + +// Uses quickselect instead of sorting all data +fn quantile_slice( + vals: &mut [T], + quantile: f64, + method: QuantileMethod, +) -> PolarsResult> { + polars_ensure!((0.0..=1.0).contains(&quantile), + ComputeError: "quantile should be between 0.0 and 1.0", + ); + if vals.is_empty() { + return Ok(None); + } + if vals.len() == 1 { + return Ok(vals[0].to_f64()); + } + let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, method); + + let (_lhs, lower, rhs) = vals.select_nth_unstable_by(idx, TotalOrd::tot_cmp); + if idx == top_idx { + Ok(lower.to_f64()) + } else { + match method { + QuantileMethod::Midpoint => { + let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); + Ok(Some(midpoint_interpol( + lower.to_f64().unwrap(), + upper.to_f64().unwrap(), + ))) + }, + QuantileMethod::Linear => { + let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); + Ok(linear_interpol( + lower.to_f64().unwrap(), + upper.to_f64().unwrap(), + idx, + float_idx, + ) + .to_f64()) + }, + _ => Ok(lower.to_f64()), + } + } +} + +fn generic_quantile( + ca: ChunkedArray, + quantile: f64, + method: QuantileMethod, +) -> PolarsResult> +where + T: PolarsNumericType, +{ + polars_ensure!( + (0.0..=1.0).contains(&quantile), + ComputeError: "`quantile` should be between 0.0 and 1.0", + ); + + let null_count = ca.null_count(); + let length = ca.len(); + + if null_count == length { + return Ok(None); + } + + let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, method); + let sorted = ca.sort(false); + let lower = sorted.get(idx).map(|v| v.to_f64().unwrap()); + + let opt = match method { + QuantileMethod::Midpoint => { + if top_idx == idx { + lower + } else { + let upper = sorted.get(idx + 1).map(|v| v.to_f64().unwrap()); + midpoint_interpol(lower.unwrap(), upper.unwrap()).to_f64() + } + }, + QuantileMethod::Linear => { + if top_idx == idx { + lower + } else { + let upper = sorted.get(idx + 1).map(|v| v.to_f64().unwrap()); + + linear_interpol(lower.unwrap(), upper.unwrap(), idx, float_idx).to_f64() + } + }, + _ => lower, + }; + Ok(opt) +} + +impl ChunkQuantile for ChunkedArray +where + T: PolarsIntegerType, + T::Native: TotalOrd, +{ + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + // in case of sorted data, the sort is free, so don't take quickselect route + if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { + let mut owned = slice.to_vec(); + quantile_slice(&mut owned, quantile, method) + } else { + generic_quantile(self.clone(), quantile, method) + } + } + + fn median(&self) -> Option { + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range + } +} + +// Version of quantile/median that don't need a memcpy +impl ChunkedArray +where + T: PolarsIntegerType, + T::Native: TotalOrd, +{ + pub(crate) fn quantile_faster( + mut self, + quantile: f64, + method: QuantileMethod, + ) -> PolarsResult> { + // in case of sorted data, the sort is free, so don't take quickselect route + let is_sorted = self.is_sorted_ascending_flag(); + if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { + quantile_slice(slice, quantile, method) + } else { + self.quantile(quantile, method) + } + } + + pub(crate) fn median_faster(self) -> Option { + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() + } +} + +impl ChunkQuantile for Float32Chunked { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + // in case of sorted data, the sort is free, so don't take quickselect route + let out = if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { + let mut owned = slice.to_vec(); + quantile_slice(&mut owned, quantile, method) + } else { + generic_quantile(self.clone(), quantile, method) + }; + out.map(|v| v.map(|v| v as f32)) + } + + fn median(&self) -> Option { + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range + } +} + +impl ChunkQuantile for Float64Chunked { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + // in case of sorted data, the sort is free, so don't take quickselect route + if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { + let mut owned = slice.to_vec(); + quantile_slice(&mut owned, quantile, method) + } else { + generic_quantile(self.clone(), quantile, method) + } + } + + fn median(&self) -> Option { + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range + } +} + +impl Float64Chunked { + pub(crate) fn quantile_faster( + mut self, + quantile: f64, + method: QuantileMethod, + ) -> PolarsResult> { + // in case of sorted data, the sort is free, so don't take quickselect route + let is_sorted = self.is_sorted_ascending_flag(); + if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { + quantile_slice(slice, quantile, method) + } else { + self.quantile(quantile, method) + } + } + + pub(crate) fn median_faster(self) -> Option { + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() + } +} + +impl Float32Chunked { + pub(crate) fn quantile_faster( + mut self, + quantile: f64, + method: QuantileMethod, + ) -> PolarsResult> { + // in case of sorted data, the sort is free, so don't take quickselect route + let is_sorted = self.is_sorted_ascending_flag(); + if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { + quantile_slice(slice, quantile, method).map(|v| v.map(|v| v as f32)) + } else { + self.quantile(quantile, method) + } + } + + pub(crate) fn median_faster(self) -> Option { + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() + } +} + +impl ChunkQuantile for StringChunked {} +impl ChunkQuantile for ListChunked {} +#[cfg(feature = "dtype-array")] +impl ChunkQuantile for ArrayChunked {} +#[cfg(feature = "object")] +impl ChunkQuantile for ObjectChunked {} +impl ChunkQuantile for BooleanChunked {} diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/var.rs b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs new file mode 100644 index 000000000000..791ad69fb7aa --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/aggregate/var.rs @@ -0,0 +1,36 @@ +use polars_compute::moment::VarState; + +use super::*; + +pub trait VarAggSeries { + /// Get the variance of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn var_reduce(&self, ddof: u8) -> Scalar; + /// Get the standard deviation of the [`ChunkedArray`] as a new [`Series`] of length 1. + fn std_reduce(&self, ddof: u8) -> Scalar; +} + +impl ChunkVar for ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + fn var(&self, ddof: u8) -> Option { + let mut out = VarState::default(); + for arr in self.downcast_iter() { + out.combine(&polars_compute::moment::var(arr)) + } + out.finalize(ddof) + } + + fn std(&self, ddof: u8) -> Option { + self.var(ddof).map(|var| var.sqrt()) + } +} + +impl ChunkVar for StringChunked {} +impl ChunkVar for ListChunked {} +#[cfg(feature = "dtype-array")] +impl ChunkVar for ArrayChunked {} +#[cfg(feature = "object")] +impl ChunkVar for ObjectChunked {} +impl ChunkVar for BooleanChunked {} diff --git a/crates/polars-core/src/chunked_array/ops/any_value.rs b/crates/polars-core/src/chunked_array/ops/any_value.rs new file mode 100644 index 000000000000..e7227c27cd11 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/any_value.rs @@ -0,0 +1,359 @@ +#![allow(unsafe_op_in_unsafe_fn)] +#[cfg(feature = "dtype-categorical")] +use polars_utils::sync::SyncPtr; + +#[cfg(feature = "object")] +use crate::chunked_array::object::extension::polars_extension::PolarsExtension; +use crate::prelude::*; +use crate::series::implementations::null::NullChunked; +use crate::utils::index_to_chunked_index; + +#[inline] +#[allow(unused_variables)] +pub(crate) unsafe fn arr_to_any_value<'a>( + arr: &'a dyn Array, + idx: usize, + dtype: &'a DataType, +) -> AnyValue<'a> { + debug_assert!(idx < arr.len()); + if arr.is_null(idx) { + return AnyValue::Null; + } + + macro_rules! downcast_and_pack { + ($casttype:ident, $variant:ident) => {{ + let arr = &*(arr as *const dyn Array as *const $casttype); + let v = arr.value_unchecked(idx); + AnyValue::$variant(v) + }}; + } + macro_rules! downcast { + ($casttype:ident) => {{ + let arr = &*(arr as *const dyn Array as *const $casttype); + arr.value_unchecked(idx) + }}; + } + match dtype { + DataType::String => downcast_and_pack!(Utf8ViewArray, String), + DataType::Binary => downcast_and_pack!(BinaryViewArray, Binary), + DataType::Boolean => downcast_and_pack!(BooleanArray, Boolean), + DataType::UInt8 => downcast_and_pack!(UInt8Array, UInt8), + DataType::UInt16 => downcast_and_pack!(UInt16Array, UInt16), + DataType::UInt32 => downcast_and_pack!(UInt32Array, UInt32), + DataType::UInt64 => downcast_and_pack!(UInt64Array, UInt64), + DataType::Int8 => downcast_and_pack!(Int8Array, Int8), + DataType::Int16 => downcast_and_pack!(Int16Array, Int16), + DataType::Int32 => downcast_and_pack!(Int32Array, Int32), + DataType::Int64 => downcast_and_pack!(Int64Array, Int64), + DataType::Int128 => downcast_and_pack!(Int128Array, Int128), + DataType::Float32 => downcast_and_pack!(Float32Array, Float32), + DataType::Float64 => downcast_and_pack!(Float64Array, Float64), + DataType::List(dt) => { + let v: ArrayRef = downcast!(LargeListArray); + if dt.is_primitive() { + let s = Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![v], dt); + AnyValue::List(s) + } else { + let s = Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![v], + &dt.to_physical(), + ) + .from_physical_unchecked(dt) + .unwrap(); + AnyValue::List(s) + } + }, + #[cfg(feature = "dtype-array")] + DataType::Array(dt, width) => { + let v: ArrayRef = downcast!(FixedSizeListArray); + if dt.is_primitive() { + let s = Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![v], dt); + AnyValue::Array(s, *width) + } else { + let s = Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![v], + &dt.to_physical(), + ) + .from_physical_unchecked(dt) + .unwrap(); + AnyValue::Array(s, *width) + } + }, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(rev_map, _) => { + let arr = &*(arr as *const dyn Array as *const UInt32Array); + let v = arr.value_unchecked(idx); + AnyValue::Categorical(v, rev_map.as_ref().unwrap().as_ref(), SyncPtr::new_null()) + }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(rev_map, _) => { + let arr = &*(arr as *const dyn Array as *const UInt32Array); + let v = arr.value_unchecked(idx); + AnyValue::Enum(v, rev_map.as_ref().unwrap().as_ref(), SyncPtr::new_null()) + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(flds) => { + let arr = &*(arr as *const dyn Array as *const StructArray); + AnyValue::Struct(idx, arr, flds) + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => { + let arr = &*(arr as *const dyn Array as *const Int64Array); + let v = arr.value_unchecked(idx); + AnyValue::Datetime(v, *tu, tz.as_ref()) + }, + #[cfg(feature = "dtype-date")] + DataType::Date => { + let arr = &*(arr as *const dyn Array as *const Int32Array); + let v = arr.value_unchecked(idx); + AnyValue::Date(v) + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let arr = &*(arr as *const dyn Array as *const Int64Array); + let v = arr.value_unchecked(idx); + AnyValue::Duration(v, *tu) + }, + #[cfg(feature = "dtype-time")] + DataType::Time => { + let arr = &*(arr as *const dyn Array as *const Int64Array); + let v = arr.value_unchecked(idx); + AnyValue::Time(v) + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + let arr = &*(arr as *const dyn Array as *const Int128Array); + let v = arr.value_unchecked(idx); + AnyValue::Decimal(v, scale.unwrap_or_else(|| unreachable!())) + }, + #[cfg(feature = "object")] + DataType::Object(_) => { + // We should almost never hit this. The only known exception is when we put objects in + // structs. Any other hit should be considered a bug. + let arr = arr.as_any().downcast_ref::().unwrap(); + PolarsExtension::arr_to_av(arr, idx) + }, + DataType::Null => AnyValue::Null, + DataType::BinaryOffset => downcast_and_pack!(LargeBinaryArray, Binary), + dt => panic!("not implemented for {dt:?}"), + } +} + +#[cfg(feature = "dtype-struct")] +impl<'a> AnyValue<'a> { + pub fn _iter_struct_av(&self) -> impl Iterator { + match self { + AnyValue::Struct(idx, arr, flds) => { + let idx = *idx; + unsafe { + arr.values().iter().zip(*flds).map(move |(arr, fld)| { + // The dictionary arrays categories don't have to map to the rev-map in the dtype + // so we set the array pointer with values of the dictionary array. + #[cfg(feature = "dtype-categorical")] + { + use arrow::legacy::is_valid::IsValid as _; + if let Some(arr) = arr.as_any().downcast_ref::>() { + let keys = arr.keys(); + let values = arr.values(); + let values = + values.as_any().downcast_ref::().unwrap(); + let arr = &*(keys as *const dyn Array as *const UInt32Array); + + if arr.is_valid_unchecked(idx) { + let v = arr.value_unchecked(idx); + match fld.dtype() { + DataType::Categorical(Some(rev_map), _) => { + AnyValue::Categorical( + v, + rev_map, + SyncPtr::from_const(values), + ) + }, + DataType::Enum(Some(rev_map), _) => { + AnyValue::Enum(v, rev_map, SyncPtr::from_const(values)) + }, + _ => unimplemented!(), + } + } else { + AnyValue::Null + } + } else { + arr_to_any_value(&**arr, idx, fld.dtype()) + } + } + + #[cfg(not(feature = "dtype-categorical"))] + { + arr_to_any_value(&**arr, idx, fld.dtype()) + } + }) + } + }, + _ => unreachable!(), + } + } + + pub fn _materialize_struct_av(&'a self, buf: &mut Vec>) { + let iter = self._iter_struct_av(); + buf.extend(iter) + } +} + +macro_rules! get_any_value_unchecked { + ($self:ident, $index:expr) => {{ + let (chunk_idx, idx) = $self.index_to_chunked_index($index); + debug_assert!(chunk_idx < $self.chunks.len()); + let arr = &**$self.chunks.get_unchecked(chunk_idx); + debug_assert!(idx < arr.len()); + arr_to_any_value(arr, idx, $self.dtype()) + }}; +} + +macro_rules! get_any_value { + ($self:ident, $index:expr) => {{ + if $index >= $self.len() { + polars_bail!(oob = $index, $self.len()); + } + // SAFETY: + // bounds are checked + Ok(unsafe { $self.get_any_value_unchecked($index) }) + }}; +} + +impl ChunkAnyValue for ChunkedArray +where + T: PolarsNumericType, +{ + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +impl ChunkAnyValue for BooleanChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +impl ChunkAnyValue for StringChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +impl ChunkAnyValue for BinaryChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +impl ChunkAnyValue for BinaryOffsetChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +impl ChunkAnyValue for ListChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkAnyValue for ArrayChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + get_any_value_unchecked!(self, index) + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +#[cfg(feature = "object")] +impl ChunkAnyValue for ObjectChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue { + match self.get_object_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Object(v), + } + } + + fn get_any_value(&self, index: usize) -> PolarsResult { + get_any_value!(self, index) + } +} + +impl ChunkAnyValue for NullChunked { + #[inline] + unsafe fn get_any_value_unchecked(&self, _index: usize) -> AnyValue { + AnyValue::Null + } + + fn get_any_value(&self, _index: usize) -> PolarsResult { + Ok(AnyValue::Null) + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkAnyValue for StructChunked { + /// Gets AnyValue from LogicalType + fn get_any_value(&self, i: usize) -> PolarsResult> { + polars_ensure!(i < self.len(), oob = i, self.len()); + unsafe { Ok(self.get_any_value_unchecked(i)) } + } + + unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { + let (chunk_idx, idx) = index_to_chunked_index(self.chunks.iter().map(|c| c.len()), i); + if let DataType::Struct(flds) = self.dtype() { + // SAFETY: we already have a single chunk and we are + // guarded by the type system. + unsafe { + let arr = &**self.chunks.get_unchecked(chunk_idx); + let arr = &*(arr as *const dyn Array as *const StructArray); + + if arr.is_null_unchecked(idx) { + AnyValue::Null + } else { + AnyValue::Struct(idx, arr, flds) + } + } + } else { + unreachable!() + } + } +} diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs new file mode 100644 index 000000000000..70e6bcdbe2a8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -0,0 +1,276 @@ +use polars_error::constants::LENGTH_LIMIT_MSG; + +use crate::prelude::*; +use crate::series::IsSorted; + +pub(crate) fn new_chunks(chunks: &mut Vec, other: &[ArrayRef], len: usize) { + // Replace an empty array. + if chunks.len() == 1 && len == 0 { + other.clone_into(chunks); + } else { + for chunk in other { + if !chunk.is_empty() { + chunks.push(chunk.clone()); + } + } + } +} + +pub(crate) fn new_chunks_owned(chunks: &mut Vec, other: Vec, len: usize) { + // Replace an empty array. + if chunks.len() == 1 && len == 0 { + *chunks = other; + } else { + chunks.reserve(other.len()); + chunks.extend(other.into_iter().filter(|c| !c.is_empty())); + } +} + +pub(super) fn update_sorted_flag_before_append(ca: &mut ChunkedArray, other: &ChunkedArray) +where + T: PolarsDataType, + for<'a> T::Physical<'a>: TotalOrd, +{ + // Note: Do not call (first|last)_non_null on an array here before checking + // it is sorted, otherwise it will lead to quadratic behavior. + let sorted_flag = match ( + ca.null_count() != ca.len(), + other.null_count() != other.len(), + ) { + (false, false) => IsSorted::Ascending, + (false, true) => { + if + // lhs is empty, just take sorted flag from rhs + ca.is_empty() + || ( + // lhs is non-empty and all-null, so rhs must have nulls ordered first + other.is_sorted_any() && 1 + other.last_non_null().unwrap() == other.len() + ) + { + other.is_sorted_flag() + } else { + IsSorted::Not + } + }, + (true, false) => { + if + // rhs is empty, just take sorted flag from lhs + other.is_empty() + || ( + // rhs is non-empty and all-null, so lhs must have nulls ordered last + ca.is_sorted_any() && ca.first_non_null().unwrap() == 0 + ) + { + ca.is_sorted_flag() + } else { + IsSorted::Not + } + }, + (true, true) => { + // both arrays have non-null values. + // for arrays of unit length we can ignore the sorted flag, as it is + // not necessarily set. + if !(ca.is_sorted_any() || ca.len() == 1) + || !(other.is_sorted_any() || other.len() == 1) + || !( + // We will coerce for single values + ca.len() - ca.null_count() == 1 + || other.len() - other.null_count() == 1 + || ca.is_sorted_flag() == other.is_sorted_flag() + ) + { + IsSorted::Not + } else { + let l_idx = ca.last_non_null().unwrap(); + let r_idx = other.first_non_null().unwrap(); + + let null_pos_check = + // check null positions + // lhs does not end in nulls + (1 + l_idx == ca.len()) + // rhs does not start with nulls + && (r_idx == 0) + // if there are nulls, they are all on one end + && !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len()); + + if !null_pos_check { + IsSorted::Not + } else { + #[allow(unused_assignments)] + let mut out = IsSorted::Not; + + // This can be relatively expensive because of chunks, so delay as much as possible. + let l_val = unsafe { ca.value_unchecked(l_idx) }; + let r_val = unsafe { other.value_unchecked(r_idx) }; + + match ( + ca.len() - ca.null_count() == 1, + other.len() - other.null_count() == 1, + ) { + (true, true) => { + out = [IsSorted::Descending, IsSorted::Ascending] + [l_val.tot_le(&r_val) as usize]; + drop(l_val); + drop(r_val); + ca.set_sorted_flag(out); + return; + }, + (true, false) => out = other.is_sorted_flag(), + _ => out = ca.is_sorted_flag(), + } + + debug_assert!(!matches!(out, IsSorted::Not)); + + let check = if matches!(out, IsSorted::Ascending) { + l_val.tot_le(&r_val) + } else { + l_val.tot_ge(&r_val) + }; + + if !check { + out = IsSorted::Not + } + + out + } + } + }, + }; + + ca.set_sorted_flag(sorted_flag); +} + +impl ChunkedArray +where + T: PolarsDataType, + for<'a> T::Physical<'a>: TotalOrd, +{ + /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. + /// + /// See also [`extend`](Self::extend) for appends to the underlying memory + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + update_sorted_flag_before_append::(self, other); + let len = self.len(); + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; + self.null_count += other.null_count; + new_chunks(&mut self.chunks, &other.chunks, len); + Ok(()) + } + + /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. + /// + /// See also [`extend`](Self::extend) for appends to the underlying memory + pub fn append_owned(&mut self, mut other: Self) -> PolarsResult<()> { + update_sorted_flag_before_append::(self, &other); + let len = self.len(); + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; + self.null_count += other.null_count; + new_chunks_owned(&mut self.chunks, std::mem::take(&mut other.chunks), len); + Ok(()) + } +} + +#[doc(hidden)] +impl ListChunked { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + self.append_owned(other.clone()) + } + + pub fn append_owned(&mut self, mut other: Self) -> PolarsResult<()> { + let dtype = merge_dtypes(self.dtype(), other.dtype())?; + self.field = Arc::new(Field::new(self.name().clone(), dtype)); + + let len = self.len(); + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; + self.null_count += other.null_count; + self.set_sorted_flag(IsSorted::Not); + if !other.get_fast_explode_list() { + self.unset_fast_explode_list() + } + + new_chunks_owned(&mut self.chunks, std::mem::take(&mut other.chunks), len); + Ok(()) + } +} + +#[cfg(feature = "dtype-array")] +#[doc(hidden)] +impl ArrayChunked { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + self.append_owned(other.clone()) + } + + pub fn append_owned(&mut self, mut other: Self) -> PolarsResult<()> { + let dtype = merge_dtypes(self.dtype(), other.dtype())?; + self.field = Arc::new(Field::new(self.name().clone(), dtype)); + + let len = self.len(); + + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; + self.null_count += other.null_count; + + self.set_sorted_flag(IsSorted::Not); + + new_chunks_owned(&mut self.chunks, std::mem::take(&mut other.chunks), len); + Ok(()) + } +} + +#[cfg(feature = "dtype-struct")] +#[doc(hidden)] +impl StructChunked { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + self.append_owned(other.clone()) + } + + pub fn append_owned(&mut self, mut other: Self) -> PolarsResult<()> { + let dtype = merge_dtypes(self.dtype(), other.dtype())?; + self.field = Arc::new(Field::new(self.name().clone(), dtype)); + + let len = self.len(); + + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; + self.null_count += other.null_count; + + self.set_sorted_flag(IsSorted::Not); + + new_chunks_owned(&mut self.chunks, std::mem::take(&mut other.chunks), len); + Ok(()) + } +} + +#[cfg(feature = "object")] +#[doc(hidden)] +impl ObjectChunked { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { + self.append_owned(other.clone()) + } + + pub fn append_owned(&mut self, mut other: Self) -> PolarsResult<()> { + let len = self.len(); + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; + self.null_count += other.null_count; + self.set_sorted_flag(IsSorted::Not); + + new_chunks_owned(&mut self.chunks, std::mem::take(&mut other.chunks), len); + Ok(()) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs new file mode 100644 index 000000000000..c8fcae796e87 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -0,0 +1,589 @@ +//! Implementations of the ChunkApply Trait. +#![allow(unsafe_op_in_unsafe_fn)] +use std::borrow::Cow; + +use crate::chunked_array::arity::{unary_elementwise, unary_elementwise_values}; +use crate::chunked_array::cast::CastOptions; +use crate::prelude::*; +use crate::series::IsSorted; + +impl ChunkedArray +where + T: PolarsDataType, +{ + /// Applies a function only to the non-null elements, propagating nulls. + pub fn apply_nonnull_values_generic<'a, U, K, F>( + &'a self, + dtype: DataType, + mut op: F, + ) -> ChunkedArray + where + U: PolarsDataType, + F: FnMut(T::Physical<'a>) -> K, + U::Array: ArrayFromIterDtype + ArrayFromIterDtype>, + { + let iter = self.downcast_iter().map(|arr| { + if arr.null_count() == 0 { + let out: U::Array = arr + .values_iter() + .map(&mut op) + .collect_arr_with_dtype(dtype.to_arrow(CompatLevel::newest())); + out.with_validity_typed(arr.validity().cloned()) + } else { + let out: U::Array = arr + .iter() + .map(|opt| opt.map(&mut op)) + .collect_arr_with_dtype(dtype.to_arrow(CompatLevel::newest())); + out.with_validity_typed(arr.validity().cloned()) + } + }); + + ChunkedArray::from_chunk_iter(self.name().clone(), iter) + } + + /// Applies a function only to the non-null elements, propagating nulls. + pub fn try_apply_nonnull_values_generic<'a, U, K, F, E>( + &'a self, + mut op: F, + ) -> Result, E> + where + U: PolarsDataType, + F: FnMut(T::Physical<'a>) -> Result, + U::Array: ArrayFromIter + ArrayFromIter>, + { + let iter = self.downcast_iter().map(|arr| { + let arr = if arr.null_count() == 0 { + let out: U::Array = arr.values_iter().map(&mut op).try_collect_arr()?; + out.with_validity_typed(arr.validity().cloned()) + } else { + let out: U::Array = arr + .iter() + .map(|opt| opt.map(&mut op).transpose()) + .try_collect_arr()?; + out.with_validity_typed(arr.validity().cloned()) + }; + Ok(arr) + }); + + ChunkedArray::try_from_chunk_iter(self.name().clone(), iter) + } + + pub fn apply_into_string_amortized<'a, F>(&'a self, mut f: F) -> StringChunked + where + F: FnMut(T::Physical<'a>, &mut String), + { + let mut buf = String::new(); + let chunks = self + .downcast_iter() + .map(|arr| { + let mut mutarr = MutablePlString::with_capacity(arr.len()); + arr.iter().for_each(|opt| match opt { + None => mutarr.push_null(), + Some(v) => { + buf.clear(); + f(v, &mut buf); + mutarr.push_value(&buf) + }, + }); + mutarr.freeze() + }) + .collect::>(); + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) + } + + pub fn try_apply_into_string_amortized<'a, F, E>(&'a self, mut f: F) -> Result + where + F: FnMut(T::Physical<'a>, &mut String) -> Result<(), E>, + { + let mut buf = String::new(); + let chunks = self + .downcast_iter() + .map(|arr| { + let mut mutarr = MutablePlString::with_capacity(arr.len()); + for opt in arr.iter() { + match opt { + None => mutarr.push_null(), + Some(v) => { + buf.clear(); + f(v, &mut buf)?; + mutarr.push_value(&buf) + }, + }; + } + Ok(mutarr.freeze()) + }) + .collect::>(); + ChunkedArray::try_from_chunk_iter(self.name().clone(), chunks) + } +} + +fn apply_in_place_impl(name: PlSmallStr, chunks: Vec, f: F) -> ChunkedArray +where + F: Fn(S::Native) -> S::Native + Copy, + S: PolarsNumericType, +{ + use arrow::Either::*; + let chunks = chunks.into_iter().map(|arr| { + let owned_arr = arr + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + // Make sure we have a single ref count coming in. + drop(arr); + + let compute_immutable = |arr: &PrimitiveArray| { + arrow::compute::arity::unary(arr, f, S::get_dtype().to_arrow(CompatLevel::newest())) + }; + + if owned_arr.values().is_sliced() { + compute_immutable(&owned_arr) + } else { + match owned_arr.into_mut() { + Left(immutable) => compute_immutable(&immutable), + Right(mut mutable) => { + let vals = mutable.values_mut_slice(); + vals.iter_mut().for_each(|v| *v = f(*v)); + mutable.into() + }, + } + } + }); + + ChunkedArray::from_chunk_iter(name, chunks) +} + +impl ChunkedArray { + /// Cast a numeric array to another numeric data type and apply a function in place. + /// This saves an allocation. + pub fn cast_and_apply_in_place(&self, f: F) -> ChunkedArray + where + F: Fn(S::Native) -> S::Native + Copy, + S: PolarsNumericType, + { + // if we cast, we create a new arrow buffer + // then we clone the arrays and drop the cast arrays + // this will ensure we have a single ref count + // and we can mutate in place + let chunks = { + let s = self + .cast_with_options(&S::get_dtype(), CastOptions::Overflowing) + .unwrap(); + s.chunks().clone() + }; + apply_in_place_impl(self.name().clone(), chunks, f) + } + + /// Cast a numeric array to another numeric data type and apply a function in place. + /// This saves an allocation. + pub fn apply_in_place(mut self, f: F) -> Self + where + F: Fn(T::Native) -> T::Native + Copy, + { + let chunks = std::mem::take(&mut self.chunks); + apply_in_place_impl(self.name().clone(), chunks, f) + } +} + +impl ChunkedArray { + pub fn apply_mut(&mut self, f: F) + where + F: Fn(T::Native) -> T::Native + Copy, + { + // SAFETY, we do no t change the lengths + unsafe { + self.downcast_iter_mut() + .for_each(|arr| arrow::compute::arity_assign::unary(arr, f)) + }; + // can be in any order now + self.compute_len(); + self.set_sorted_flag(IsSorted::Not); + } +} + +impl<'a, T> ChunkApply<'a, T::Native> for ChunkedArray +where + T: PolarsNumericType, +{ + type FuncRet = T::Native; + + fn apply_values(&'a self, f: F) -> Self + where + F: Fn(T::Native) -> T::Native + Copy, + { + let chunks = self + .data_views() + .zip(self.iter_validities()) + .map(|(slice, validity)| { + let arr: T::Array = slice.iter().copied().map(f).collect_arr(); + arr.with_validity(validity.cloned()) + }); + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) + } + + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option) -> Option + Copy, + { + let chunks = self.downcast_iter().map(|arr| { + let iter = arr.into_iter().map(|opt_v| f(opt_v.copied())); + PrimitiveArray::::from_trusted_len_iter(iter) + }); + Self::from_chunk_iter(self.name().clone(), chunks) + } + + fn apply_to_slice(&'a self, f: F, slice: &mut [V]) + where + F: Fn(Option, &V) -> V, + { + assert!(slice.len() >= self.len()); + + let mut idx = 0; + self.downcast_iter().for_each(|arr| { + arr.into_iter().for_each(|opt_val| { + // SAFETY: + // length asserted above + let item = unsafe { slice.get_unchecked_mut(idx) }; + *item = f(opt_val.copied(), item); + idx += 1; + }) + }); + } +} + +impl<'a> ChunkApply<'a, bool> for BooleanChunked { + type FuncRet = bool; + + fn apply_values(&self, f: F) -> Self + where + F: Fn(bool) -> bool + Copy, + { + // Can just fully deduce behavior from two invocations. + match (f(false), f(true)) { + (false, false) => self.apply_kernel(&|arr| { + Box::new( + BooleanArray::full(arr.len(), false, ArrowDataType::Boolean) + .with_validity(arr.validity().cloned()), + ) + }), + (false, true) => self.clone(), + (true, false) => !self, + (true, true) => self.apply_kernel(&|arr| { + Box::new( + BooleanArray::full(arr.len(), true, ArrowDataType::Boolean) + .with_validity(arr.validity().cloned()), + ) + }), + } + } + + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option) -> Option + Copy, + { + unary_elementwise(self, f) + } + + fn apply_to_slice(&'a self, f: F, slice: &mut [T]) + where + F: Fn(Option, &T) -> T, + { + assert!(slice.len() >= self.len()); + + let mut idx = 0; + self.downcast_iter().for_each(|arr| { + arr.into_iter().for_each(|opt_val| { + // SAFETY: + // length asserted above + let item = unsafe { slice.get_unchecked_mut(idx) }; + *item = f(opt_val, item); + idx += 1; + }) + }); + } +} + +impl StringChunked { + pub fn apply_mut<'a, F>(&'a self, mut f: F) -> Self + where + F: FnMut(&'a str) -> &'a str, + { + let chunks = self.downcast_iter().map(|arr| { + let iter = arr.values_iter().map(&mut f); + let new = Utf8ViewArray::arr_from_iter(iter); + new.with_validity(arr.validity().cloned()) + }); + StringChunked::from_chunk_iter(self.name().clone(), chunks) + } +} + +impl BinaryChunked { + pub fn apply_mut<'a, F>(&'a self, mut f: F) -> Self + where + F: FnMut(&'a [u8]) -> &'a [u8], + { + let chunks = self.downcast_iter().map(|arr| { + let iter = arr.values_iter().map(&mut f); + let new = BinaryViewArray::arr_from_iter(iter); + new.with_validity(arr.validity().cloned()) + }); + BinaryChunked::from_chunk_iter(self.name().clone(), chunks) + } +} + +impl<'a> ChunkApply<'a, &'a str> for StringChunked { + type FuncRet = Cow<'a, str>; + + fn apply_values(&'a self, f: F) -> Self + where + F: Fn(&'a str) -> Cow<'a, str> + Copy, + { + unary_elementwise_values(self, f) + } + + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option<&'a str>) -> Option> + Copy, + { + unary_elementwise(self, f) + } + + fn apply_to_slice(&'a self, f: F, slice: &mut [T]) + where + F: Fn(Option<&'a str>, &T) -> T, + { + assert!(slice.len() >= self.len()); + + let mut idx = 0; + self.downcast_iter().for_each(|arr| { + arr.into_iter().for_each(|opt_val| { + // SAFETY: + // length asserted above + let item = unsafe { slice.get_unchecked_mut(idx) }; + *item = f(opt_val, item); + idx += 1; + }) + }); + } +} + +impl<'a> ChunkApply<'a, &'a [u8]> for BinaryChunked { + type FuncRet = Cow<'a, [u8]>; + + fn apply_values(&'a self, f: F) -> Self + where + F: Fn(&'a [u8]) -> Cow<'a, [u8]> + Copy, + { + unary_elementwise_values(self, f) + } + + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option<&'a [u8]>) -> Option> + Copy, + { + unary_elementwise(self, f) + } + + fn apply_to_slice(&'a self, f: F, slice: &mut [T]) + where + F: Fn(Option<&'a [u8]>, &T) -> T, + { + assert!(slice.len() >= self.len()); + + let mut idx = 0; + self.downcast_iter().for_each(|arr| { + arr.into_iter().for_each(|opt_val| { + // SAFETY: + // length asserted above + let item = unsafe { slice.get_unchecked_mut(idx) }; + *item = f(opt_val, item); + idx += 1; + }) + }); + } +} + +impl ChunkApplyKernel for BooleanChunked { + fn apply_kernel(&self, f: &dyn Fn(&BooleanArray) -> ArrayRef) -> Self { + let chunks = self.downcast_iter().map(f).collect(); + unsafe { Self::from_chunks(self.name().clone(), chunks) } + } + + fn apply_kernel_cast(&self, f: &dyn Fn(&BooleanArray) -> ArrayRef) -> ChunkedArray + where + S: PolarsDataType, + { + let chunks = self.downcast_iter().map(f).collect(); + unsafe { ChunkedArray::::from_chunks(self.name().clone(), chunks) } + } +} + +impl ChunkApplyKernel> for ChunkedArray +where + T: PolarsNumericType, +{ + fn apply_kernel(&self, f: &dyn Fn(&PrimitiveArray) -> ArrayRef) -> Self { + self.apply_kernel_cast(&f) + } + fn apply_kernel_cast( + &self, + f: &dyn Fn(&PrimitiveArray) -> ArrayRef, + ) -> ChunkedArray + where + S: PolarsDataType, + { + let chunks = self.downcast_iter().map(f).collect(); + unsafe { ChunkedArray::from_chunks(self.name().clone(), chunks) } + } +} + +impl ChunkApplyKernel for StringChunked { + fn apply_kernel(&self, f: &dyn Fn(&Utf8ViewArray) -> ArrayRef) -> Self { + self.apply_kernel_cast(&f) + } + + fn apply_kernel_cast(&self, f: &dyn Fn(&Utf8ViewArray) -> ArrayRef) -> ChunkedArray + where + S: PolarsDataType, + { + let chunks = self.downcast_iter().map(f).collect(); + unsafe { ChunkedArray::from_chunks(self.name().clone(), chunks) } + } +} + +impl ChunkApplyKernel for BinaryChunked { + fn apply_kernel(&self, f: &dyn Fn(&BinaryViewArray) -> ArrayRef) -> Self { + self.apply_kernel_cast(&f) + } + + fn apply_kernel_cast(&self, f: &dyn Fn(&BinaryViewArray) -> ArrayRef) -> ChunkedArray + where + S: PolarsDataType, + { + let chunks = self.downcast_iter().map(f).collect(); + unsafe { ChunkedArray::from_chunks(self.name().clone(), chunks) } + } +} + +impl<'a> ChunkApply<'a, Series> for ListChunked { + type FuncRet = Series; + + /// Apply a closure `F` elementwise. + fn apply_values(&'a self, f: F) -> Self + where + F: Fn(Series) -> Series + Copy, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = true; + let mut function = |s: Series| { + let out = f(s); + if out.is_empty() { + fast_explode = false; + } + out + }; + let mut ca: ListChunked = { + if !self.has_nulls() { + self.into_no_null_iter() + .map(&mut function) + .collect_trusted() + } else { + self.into_iter() + .map(|opt_v| opt_v.map(&mut function)) + .collect_trusted() + } + }; + if fast_explode { + ca.set_fast_explode() + } + ca + } + + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option) -> Option + Copy, + { + if self.is_empty() { + return self.clone(); + } + self.into_iter().map(f).collect_trusted() + } + + fn apply_to_slice(&'a self, f: F, slice: &mut [T]) + where + F: Fn(Option, &T) -> T, + { + assert!(slice.len() >= self.len()); + + let mut idx = 0; + self.downcast_iter().for_each(|arr| { + arr.iter().for_each(|opt_val| { + let opt_val = opt_val + .map(|arrayref| Series::try_from((PlSmallStr::EMPTY, arrayref)).unwrap()); + + // SAFETY: + // length asserted above + let item = unsafe { slice.get_unchecked_mut(idx) }; + *item = f(opt_val, item); + idx += 1; + }) + }); + } +} + +#[cfg(feature = "object")] +impl<'a, T> ChunkApply<'a, &'a T> for ObjectChunked +where + T: PolarsObject, +{ + type FuncRet = T; + + fn apply_values(&'a self, f: F) -> Self + where + F: Fn(&'a T) -> T + Copy, + { + let mut ca: ObjectChunked = self.into_iter().map(|opt_v| opt_v.map(f)).collect(); + ca.rename(self.name().clone()); + ca + } + + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option<&'a T>) -> Option + Copy, + { + let mut ca: ObjectChunked = self.into_iter().map(f).collect(); + ca.rename(self.name().clone()); + ca + } + + fn apply_to_slice(&'a self, f: F, slice: &mut [V]) + where + F: Fn(Option<&'a T>, &V) -> V, + { + assert!(slice.len() >= self.len()); + let mut idx = 0; + self.downcast_iter().for_each(|arr| { + arr.into_iter().for_each(|opt_val| { + // SAFETY: + // length asserted above + let item = unsafe { slice.get_unchecked_mut(idx) }; + *item = f(opt_val, item); + idx += 1; + }) + }); + } +} + +impl StringChunked { + /// # Safety + /// Update the views. All invariants of the views apply. + pub unsafe fn apply_views View + Copy>(&self, update_view: F) -> Self { + let mut out = self.clone(); + for arr in out.downcast_iter_mut() { + *arr = arr.apply_views(update_view); + } + out + } +} diff --git a/crates/polars-core/src/chunked_array/ops/approx_n_unique.rs b/crates/polars-core/src/chunked_array/ops/approx_n_unique.rs new file mode 100644 index 000000000000..dc3d4c45a352 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/approx_n_unique.rs @@ -0,0 +1,20 @@ +use std::hash::Hash; + +use polars_compute::hyperloglogplus::HyperLogLog; +use polars_utils::IdxSize; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use super::{ChunkApproxNUnique, ChunkedArray, PolarsDataType}; + +impl ChunkApproxNUnique for ChunkedArray +where + T: PolarsDataType, + for<'a> T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + for<'a> > as ToTotalOrd>::TotalOrdItem: Hash + Eq, +{ + fn approx_n_unique(&self) -> IdxSize { + let mut hllp = HyperLogLog::new(); + self.iter().for_each(|item| hllp.add(&item.to_total_ord())); + hllp.count() as IdxSize + } +} diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs new file mode 100644 index 000000000000..29d33b43dc73 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -0,0 +1,868 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::error::Error; + +use arrow::array::{Array, MutablePlString, StaticArray}; +use arrow::compute::utils::combine_validities_and; +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; + +use crate::chunked_array::flags::StatisticsFlags; +use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter}; +use crate::prelude::{ChunkedArray, CompatLevel, PolarsDataType, Series, StringChunked}; +use crate::utils::{align_chunks_binary, align_chunks_binary_owned, align_chunks_ternary}; + +// We need this helper because for<'a> notation can't yet be applied properly +// on the return type. +pub trait UnaryFnMut: FnMut(A1) -> Self::Ret { + type Ret; +} + +impl R> UnaryFnMut for T { + type Ret = R; +} + +// We need this helper because for<'a> notation can't yet be applied properly +// on the return type. +pub trait TernaryFnMut: FnMut(A1, A2, A3) -> Self::Ret { + type Ret; +} + +impl R> TernaryFnMut for T { + type Ret = R; +} + +// We need this helper because for<'a> notation can't yet be applied properly +// on the return type. +pub trait BinaryFnMut: FnMut(A1, A2) -> Self::Ret { + type Ret; +} + +impl R> BinaryFnMut for T { + type Ret = R; +} + +/// Applies a kernel that produces `Array` types. +#[inline] +pub fn unary_kernel(ca: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array) -> Arr, +{ + let iter = ca.downcast_iter().map(op); + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) +} + +/// Applies a kernel that produces `Array` types. +#[inline] +pub fn unary_kernel_owned(ca: ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(T::Array) -> Arr, +{ + let name = ca.name().clone(); + let iter = ca.downcast_into_iter().map(op); + ChunkedArray::from_chunk_iter(name, iter) +} + +#[inline] +pub fn unary_elementwise<'a, T, V, F>(ca: &'a ChunkedArray, mut op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + F: UnaryFnMut>>, + V::Array: ArrayFromIter<>>>::Ret>, +{ + if ca.has_nulls() { + let iter = ca + .downcast_iter() + .map(|arr| arr.iter().map(&mut op).collect_arr()); + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) + } else { + let iter = ca + .downcast_iter() + .map(|arr| arr.values_iter().map(|x| op(Some(x))).collect_arr()); + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) + } +} + +#[inline] +pub fn try_unary_elementwise<'a, T, V, F, K, E>( + ca: &'a ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + V: PolarsDataType, + F: FnMut(Option>) -> Result, E>, + V::Array: ArrayFromIter>, +{ + let iter = ca + .downcast_iter() + .map(|arr| arr.iter().map(&mut op).try_collect_arr()); + ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter) +} + +#[inline] +pub fn unary_elementwise_values<'a, T, V, F>(ca: &'a ChunkedArray, mut op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + F: UnaryFnMut>, + V::Array: ArrayFromIter<>>::Ret>, +{ + if ca.null_count() == ca.len() { + let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow(CompatLevel::newest())); + return ChunkedArray::with_chunk(ca.name().clone(), arr); + } + + let iter = ca.downcast_iter().map(|arr| { + let validity = arr.validity().cloned(); + let arr: V::Array = arr.values_iter().map(&mut op).collect_arr(); + arr.with_validity_typed(validity) + }); + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) +} + +#[inline] +pub fn try_unary_elementwise_values<'a, T, V, F, K, E>( + ca: &'a ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + V: PolarsDataType, + F: FnMut(T::Physical<'a>) -> Result, + V::Array: ArrayFromIter, +{ + if ca.null_count() == ca.len() { + let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow(CompatLevel::newest())); + return Ok(ChunkedArray::with_chunk(ca.name().clone(), arr)); + } + + let iter = ca.downcast_iter().map(|arr| { + let validity = arr.validity().cloned(); + let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?; + Ok(arr.with_validity_typed(validity)) + }); + ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter) +} + +/// Applies a kernel that produces `Array` types. +/// +/// Intended for kernels that apply on values, this function will apply the +/// validity mask afterwards. +#[inline] +pub fn unary_mut_values(ca: &ChunkedArray, mut op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array + StaticArray, + F: FnMut(&T::Array) -> Arr, +{ + let iter = ca + .downcast_iter() + .map(|arr| op(arr).with_validity_typed(arr.validity().cloned())); + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) +} + +/// Applies a kernel that produces `Array` types. +#[inline] +pub fn unary_mut_with_options(ca: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array + StaticArray, + F: FnMut(&T::Array) -> Arr, +{ + ChunkedArray::from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op)) +} + +#[inline] +pub fn try_unary_mut_with_options( + ca: &ChunkedArray, + op: F, +) -> Result, E> +where + T: PolarsDataType, + V: PolarsDataType, + Arr: Array + StaticArray, + F: FnMut(&T::Array) -> Result, + E: Error, +{ + ChunkedArray::try_from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op)) +} + +#[inline] +pub fn binary_elementwise( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + F: for<'a> BinaryFnMut>, Option>>, + V::Array: for<'a> ArrayFromIter< + >, Option>>>::Ret, + >, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| { + let element_iter = lhs_arr + .iter() + .zip(rhs_arr.iter()) + .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); + element_iter.collect_arr() + }); + ChunkedArray::from_chunk_iter(lhs.name().clone(), iter) +} + +#[inline] +pub fn binary_elementwise_for_each<'a, 'b, T, U, F>( + lhs: &'a ChunkedArray, + rhs: &'b ChunkedArray, + mut op: F, +) where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(Option>, Option>), +{ + let mut lhs_arr_iter = lhs.downcast_iter(); + let mut rhs_arr_iter = rhs.downcast_iter(); + + let lhs_arr = lhs_arr_iter.next().unwrap(); + let rhs_arr = rhs_arr_iter.next().unwrap(); + + let mut lhs_remaining = lhs_arr.len(); + let mut rhs_remaining = rhs_arr.len(); + let mut lhs_iter = lhs_arr.iter(); + let mut rhs_iter = rhs_arr.iter(); + + loop { + let range = std::cmp::min(lhs_remaining, rhs_remaining); + + for _ in 0..range { + // SAFETY: we loop until the smaller iter is exhausted. + let lhs_opt_val = unsafe { lhs_iter.next().unwrap_unchecked() }; + let rhs_opt_val = unsafe { rhs_iter.next().unwrap_unchecked() }; + op(lhs_opt_val, rhs_opt_val) + } + lhs_remaining -= range; + rhs_remaining -= range; + + if lhs_remaining == 0 { + let Some(new_arr) = lhs_arr_iter.next() else { + return; + }; + lhs_remaining = new_arr.len(); + lhs_iter = new_arr.iter(); + } + if rhs_remaining == 0 { + let Some(new_arr) = rhs_arr_iter.next() else { + return; + }; + rhs_remaining = new_arr.len(); + rhs_iter = new_arr.iter(); + } + } +} + +#[inline] +pub fn try_binary_elementwise( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + F: for<'a> FnMut(Option>, Option>) -> Result, E>, + V::Array: ArrayFromIter>, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| { + let element_iter = lhs_arr + .iter() + .zip(rhs_arr.iter()) + .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); + element_iter.try_collect_arr() + }); + ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter) +} + +#[inline] +pub fn binary_elementwise_values( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K, + V::Array: ArrayFromIter, +{ + if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() { + let len = lhs.len().min(rhs.len()); + let arr = V::Array::full_null(len, V::get_dtype().to_arrow(CompatLevel::newest())); + + return ChunkedArray::with_chunk(lhs.name().clone(), arr); + } + + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| { + let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity()); + + let element_iter = lhs_arr + .values_iter() + .zip(rhs_arr.values_iter()) + .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val)); + + let array: V::Array = element_iter.collect_arr(); + array.with_validity_typed(validity) + }); + ChunkedArray::from_chunk_iter(lhs.name().clone(), iter) +} + +/// Apply elementwise binary function which produces string, amortising allocations. +/// +/// Currently unused within Polars itself, but it's a useful utility for plugin authors. +#[inline] +pub fn binary_elementwise_into_string_amortized( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> StringChunked +where + T: PolarsDataType, + U: PolarsDataType, + F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>, &mut String), +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let mut buf = String::new(); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| { + let mut mutarr = MutablePlString::with_capacity(lhs_arr.len()); + lhs_arr + .iter() + .zip(rhs_arr.iter()) + .for_each(|(lhs_opt, rhs_opt)| match (lhs_opt, rhs_opt) { + (None, _) | (_, None) => mutarr.push_null(), + (Some(lhs_val), Some(rhs_val)) => { + buf.clear(); + op(lhs_val, rhs_val, &mut buf); + mutarr.push_value(&buf) + }, + }); + mutarr.freeze() + }); + ChunkedArray::from_chunk_iter(lhs.name().clone(), iter) +} + +/// Applies a kernel that produces `Array` types. +/// +/// Intended for kernels that apply on values, this function will filter out any +/// results which do not have two non-null inputs. +#[inline] +pub fn binary_mut_values( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, + name: PlSmallStr, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + Arr: Array + StaticArray, + F: FnMut(&T::Array, &U::Array) -> Arr, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| { + let ret = op(lhs_arr, rhs_arr); + let inp_val = combine_validities_and(lhs_arr.validity(), rhs_arr.validity()); + let val = combine_validities_and(inp_val.as_ref(), ret.validity()); + ret.with_validity_typed(val) + }); + ChunkedArray::from_chunk_iter(name, iter) +} + +/// Applies a kernel that produces `Array` types. +#[inline] +pub fn binary_mut_with_options( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, + name: PlSmallStr, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Arr, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); + ChunkedArray::from_chunk_iter(name, iter) +} + +#[inline] +pub fn try_binary_mut_with_options( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, + name: PlSmallStr, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Result, + E: Error, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); + ChunkedArray::try_from_chunk_iter(name, iter) +} + +/// Applies a kernel that produces `Array` types. +pub fn binary( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Arr, +{ + binary_mut_with_options(lhs, rhs, op, lhs.name().clone()) +} + +/// Applies a kernel that produces `Array` types. +pub fn binary_owned( + lhs: ChunkedArray, + rhs: ChunkedArray, + mut op: F, +) -> ChunkedArray +where + L: PolarsDataType, + R: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(L::Array, R::Array) -> Arr, +{ + let name = lhs.name().clone(); + let (lhs, rhs) = align_chunks_binary_owned(lhs, rhs); + let iter = lhs + .downcast_into_iter() + .zip(rhs.downcast_into_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); + ChunkedArray::from_chunk_iter(name, iter) +} + +/// Applies a kernel that produces `Array` types. +pub fn try_binary( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + Arr: Array, + F: FnMut(&T::Array, &U::Array) -> Result, + E: Error, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let iter = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); + ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter) +} + +/// Applies a kernel that produces `ArrayRef` of the same type. +/// +/// # Safety +/// Caller must ensure that the returned `ArrayRef` belongs to `T: PolarsDataType`. +#[inline] +pub unsafe fn binary_unchecked_same_type( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, + keep_sorted: bool, + keep_fast_explode: bool, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(&T::Array, &U::Array) -> Box, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)) + .collect(); + + let mut ca = lhs.copy_with_chunks(chunks); + + let mut retain_flags = StatisticsFlags::empty(); + use StatisticsFlags as F; + retain_flags.set(F::IS_SORTED_ANY, keep_sorted); + retain_flags.set(F::CAN_FAST_EXPLODE_LIST, keep_fast_explode); + ca.retain_flags_from(lhs.as_ref(), retain_flags); + + ca +} + +#[inline] +pub fn binary_to_series( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> PolarsResult +where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(&T::Array, &U::Array) -> Box, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)) + .collect::>(); + Series::try_from((lhs.name().clone(), chunks)) +} + +/// Applies a kernel that produces `ArrayRef` of the same type. +/// +/// # Safety +/// Caller must ensure that the returned `ArrayRef` belongs to `T: PolarsDataType`. +#[inline] +pub unsafe fn try_binary_unchecked_same_type( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, + keep_sorted: bool, + keep_fast_explode: bool, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(&T::Array, &U::Array) -> Result, E>, + E: Error, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)) + .collect::, E>>()?; + let mut ca = lhs.copy_with_chunks(chunks); + + let mut retain_flags = StatisticsFlags::empty(); + use StatisticsFlags as F; + retain_flags.set(F::IS_SORTED_ANY, keep_sorted); + retain_flags.set(F::CAN_FAST_EXPLODE_LIST, keep_fast_explode); + ca.retain_flags_from(lhs.as_ref(), retain_flags); + + Ok(ca) +} + +#[inline] +pub fn try_ternary_elementwise( + ca1: &ChunkedArray, + ca2: &ChunkedArray, + ca3: &ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + G: PolarsDataType, + F: for<'a> FnMut( + Option>, + Option>, + Option>, + ) -> Result, E>, + V::Array: ArrayFromIter>, +{ + let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3); + let iter = ca1 + .downcast_iter() + .zip(ca2.downcast_iter()) + .zip(ca3.downcast_iter()) + .map(|((ca1_arr, ca2_arr), ca3_arr)| { + let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map( + |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| { + op(ca1_opt_val, ca2_opt_val, ca3_opt_val) + }, + ); + element_iter.try_collect_arr() + }); + ChunkedArray::try_from_chunk_iter(ca1.name().clone(), iter) +} + +#[inline] +pub fn ternary_elementwise( + ca1: &ChunkedArray, + ca2: &ChunkedArray, + ca3: &ChunkedArray, + mut op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + G: PolarsDataType, + V: PolarsDataType, + F: for<'a> TernaryFnMut< + Option>, + Option>, + Option>, + >, + V::Array: for<'a> ArrayFromIter< + >, + Option>, + Option>, + >>::Ret, + >, +{ + let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3); + let iter = ca1 + .downcast_iter() + .zip(ca2.downcast_iter()) + .zip(ca3.downcast_iter()) + .map(|((ca1_arr, ca2_arr), ca3_arr)| { + let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map( + |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| { + op(ca1_opt_val, ca2_opt_val, ca3_opt_val) + }, + ); + element_iter.collect_arr() + }); + ChunkedArray::from_chunk_iter(ca1.name().clone(), iter) +} + +pub fn broadcast_binary_elementwise( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + F: for<'a> BinaryFnMut>, Option>>, + V::Array: for<'a> ArrayFromIter< + >, Option>>>::Ret, + >, +{ + match (lhs.len(), rhs.len()) { + (1, _) => { + let a = unsafe { lhs.get_unchecked(0) }; + unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone()) + }, + (_, 1) => { + let b = unsafe { rhs.get_unchecked(0) }; + unary_elementwise(lhs, |a| op(a, b.clone())) + }, + _ => binary_elementwise(lhs, rhs, op), + } +} + +pub fn broadcast_try_binary_elementwise( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + F: for<'a> FnMut(Option>, Option>) -> Result, E>, + V::Array: ArrayFromIter>, +{ + match (lhs.len(), rhs.len()) { + (1, _) => { + let a = unsafe { lhs.get_unchecked(0) }; + Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name().clone())) + }, + (_, 1) => { + let b = unsafe { rhs.get_unchecked(0) }; + try_unary_elementwise(lhs, |a| op(a, b.clone())) + }, + _ => try_binary_elementwise(lhs, rhs, op), + } +} + +pub fn broadcast_binary_elementwise_values( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + mut op: F, +) -> ChunkedArray +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K, + V::Array: ArrayFromIter, +{ + if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() { + let min = lhs.len().min(rhs.len()); + let max = lhs.len().max(rhs.len()); + let len = if min == 1 { max } else { min }; + let arr = V::Array::full_null(len, V::get_dtype().to_arrow(CompatLevel::newest())); + + return ChunkedArray::with_chunk(lhs.name().clone(), arr); + } + + match (lhs.len(), rhs.len()) { + (1, _) => { + let a = unsafe { lhs.value_unchecked(0) }; + unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone()) + }, + (_, 1) => { + let b = unsafe { rhs.value_unchecked(0) }; + unary_elementwise_values(lhs, |a| op(a, b.clone())) + }, + _ => binary_elementwise_values(lhs, rhs, op), + } +} + +pub fn apply_binary_kernel_broadcast<'l, 'r, L, R, O, K, LK, RK>( + lhs: &'l ChunkedArray, + rhs: &'r ChunkedArray, + kernel: K, + lhs_broadcast_kernel: LK, + rhs_broadcast_kernel: RK, +) -> ChunkedArray +where + L: PolarsDataType, + R: PolarsDataType, + O: PolarsDataType, + K: Fn(&L::Array, &R::Array) -> O::Array, + LK: Fn(L::Physical<'l>, &R::Array) -> O::Array, + RK: Fn(&L::Array, R::Physical<'r>) -> O::Array, +{ + let name = lhs.name(); + let out = match (lhs.len(), rhs.len()) { + (a, b) if a == b => binary(lhs, rhs, |lhs, rhs| kernel(lhs, rhs)), + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => { + let arr = O::Array::full_null( + lhs.len(), + O::get_dtype().to_arrow(CompatLevel::newest()), + ); + ChunkedArray::::with_chunk(lhs.name().clone(), arr) + }, + Some(rhs) => unary_kernel(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())), + } + }, + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => { + let arr = O::Array::full_null( + rhs.len(), + O::get_dtype().to_arrow(CompatLevel::newest()), + ); + ChunkedArray::::with_chunk(lhs.name().clone(), arr) + }, + Some(lhs) => unary_kernel(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)), + } + }, + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + out.with_name(name.clone()) +} + +pub fn apply_binary_kernel_broadcast_owned( + lhs: ChunkedArray, + rhs: ChunkedArray, + kernel: K, + lhs_broadcast_kernel: LK, + rhs_broadcast_kernel: RK, +) -> ChunkedArray +where + L: PolarsDataType, + R: PolarsDataType, + O: PolarsDataType, + K: Fn(L::Array, R::Array) -> O::Array, + for<'a> LK: Fn(L::Physical<'a>, R::Array) -> O::Array, + for<'a> RK: Fn(L::Array, R::Physical<'a>) -> O::Array, +{ + let name = lhs.name().to_owned(); + let out = match (lhs.len(), rhs.len()) { + (a, b) if a == b => binary_owned(lhs, rhs, kernel), + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => { + let arr = O::Array::full_null( + lhs.len(), + O::get_dtype().to_arrow(CompatLevel::newest()), + ); + ChunkedArray::::with_chunk(lhs.name().clone(), arr) + }, + Some(rhs) => unary_kernel_owned(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())), + } + }, + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => { + let arr = O::Array::full_null( + rhs.len(), + O::get_dtype().to_arrow(CompatLevel::newest()), + ); + ChunkedArray::::with_chunk(lhs.name().clone(), arr) + }, + Some(lhs) => unary_kernel_owned(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)), + } + }, + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + out.with_name(name) +} diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs b/crates/polars-core/src/chunked_array/ops/bit_repr.rs new file mode 100644 index 000000000000..7236fde9993c --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -0,0 +1,282 @@ +use arrow::buffer::Buffer; + +use crate::prelude::*; +use crate::series::BitRepr; + +/// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size +/// and alignment. +fn reinterpret_chunked_array( + ca: &ChunkedArray, +) -> ChunkedArray { + assert!(size_of::() == size_of::()); + assert!(align_of::() == align_of::()); + + let chunks = ca.downcast_iter().map(|array| { + let buf = array.values().clone(); + // SAFETY: we checked that the size and alignment matches. + #[allow(clippy::transmute_undefined_repr)] + let reinterpreted_buf = + unsafe { std::mem::transmute::, Buffer>(buf) }; + PrimitiveArray::from_data_default(reinterpreted_buf, array.validity().cloned()) + }); + + ChunkedArray::from_chunk_iter(ca.name().clone(), chunks) +} + +/// Reinterprets the type of a [`ListChunked`]. T and U must have the same size +/// and alignment. +#[cfg(feature = "reinterpret")] +fn reinterpret_list_chunked( + ca: &ListChunked, +) -> ListChunked { + assert!(size_of::() == size_of::()); + assert!(align_of::() == align_of::()); + + let chunks = ca.downcast_iter().map(|array| { + let inner_arr = array + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + // SAFETY: we checked that the size and alignment matches. + #[allow(clippy::transmute_undefined_repr)] + let reinterpreted_buf = unsafe { + std::mem::transmute::, Buffer>(inner_arr.values().clone()) + }; + let pa = + PrimitiveArray::from_data_default(reinterpreted_buf, inner_arr.validity().cloned()); + LargeListArray::new( + DataType::List(Box::new(U::get_dtype())).to_arrow(CompatLevel::newest()), + array.offsets().clone(), + pa.to_boxed(), + array.validity().cloned(), + ) + }); + + ListChunked::from_chunk_iter(ca.name().clone(), chunks) +} + +#[cfg(all(feature = "reinterpret", feature = "dtype-i16", feature = "dtype-u16"))] +impl Reinterpret for Int16Chunked { + fn reinterpret_signed(&self) -> Series { + self.clone().into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + reinterpret_chunked_array::<_, UInt16Type>(self).into_series() + } +} + +#[cfg(all(feature = "reinterpret", feature = "dtype-u16", feature = "dtype-i16"))] +impl Reinterpret for UInt16Chunked { + fn reinterpret_signed(&self) -> Series { + reinterpret_chunked_array::<_, Int16Type>(self).into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + self.clone().into_series() + } +} + +#[cfg(all(feature = "reinterpret", feature = "dtype-i8", feature = "dtype-u8"))] +impl Reinterpret for Int8Chunked { + fn reinterpret_signed(&self) -> Series { + self.clone().into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + reinterpret_chunked_array::<_, UInt8Type>(self).into_series() + } +} + +#[cfg(all(feature = "reinterpret", feature = "dtype-u8", feature = "dtype-i8"))] +impl Reinterpret for UInt8Chunked { + fn reinterpret_signed(&self) -> Series { + reinterpret_chunked_array::<_, Int8Type>(self).into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + self.clone().into_series() + } +} + +impl ToBitRepr for ChunkedArray +where + T: PolarsNumericType, +{ + fn to_bit_repr(&self) -> BitRepr { + let is_large = size_of::() == 8; + + if is_large { + if matches!(self.dtype(), DataType::UInt64) { + let ca = self.clone(); + // Convince the compiler we are this type. This keeps flags. + return BitRepr::Large(unsafe { + std::mem::transmute::, UInt64Chunked>(ca) + }); + } + + BitRepr::Large(reinterpret_chunked_array(self)) + } else { + BitRepr::Small(if size_of::() == 4 { + if matches!(self.dtype(), DataType::UInt32) { + let ca = self.clone(); + // Convince the compiler we are this type. This preserves flags. + return BitRepr::Small(unsafe { + std::mem::transmute::, UInt32Chunked>(ca) + }); + } + + reinterpret_chunked_array(self) + } else { + // SAFETY: an unchecked cast to uint32 (which has no invariants) is + // always sound. + unsafe { + self.cast_unchecked(&DataType::UInt32) + .unwrap() + .u32() + .unwrap() + .clone() + } + }) + } + } +} + +#[cfg(feature = "reinterpret")] +impl Reinterpret for UInt64Chunked { + fn reinterpret_signed(&self) -> Series { + let signed: Int64Chunked = reinterpret_chunked_array(self); + signed.into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + self.clone().into_series() + } +} +#[cfg(feature = "reinterpret")] +impl Reinterpret for Int64Chunked { + fn reinterpret_signed(&self) -> Series { + self.clone().into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + let BitRepr::Large(b) = self.to_bit_repr() else { + unreachable!() + }; + b.into_series() + } +} + +#[cfg(feature = "reinterpret")] +impl Reinterpret for UInt32Chunked { + fn reinterpret_signed(&self) -> Series { + let signed: Int32Chunked = reinterpret_chunked_array(self); + signed.into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + self.clone().into_series() + } +} + +#[cfg(feature = "reinterpret")] +impl Reinterpret for Int32Chunked { + fn reinterpret_signed(&self) -> Series { + self.clone().into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + let BitRepr::Small(b) = self.to_bit_repr() else { + unreachable!() + }; + b.into_series() + } +} + +#[cfg(feature = "reinterpret")] +impl Reinterpret for Float32Chunked { + fn reinterpret_signed(&self) -> Series { + reinterpret_chunked_array::<_, Int32Type>(self).into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + reinterpret_chunked_array::<_, UInt32Type>(self).into_series() + } +} + +#[cfg(feature = "reinterpret")] +impl Reinterpret for ListChunked { + fn reinterpret_signed(&self) -> Series { + match self.inner_dtype() { + DataType::Float32 => reinterpret_list_chunked::(self), + DataType::Float64 => reinterpret_list_chunked::(self), + _ => unimplemented!(), + } + .into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + match self.inner_dtype() { + DataType::Float32 => reinterpret_list_chunked::(self), + DataType::Float64 => reinterpret_list_chunked::(self), + _ => unimplemented!(), + } + .into_series() + } +} + +#[cfg(feature = "reinterpret")] +impl Reinterpret for Float64Chunked { + fn reinterpret_signed(&self) -> Series { + reinterpret_chunked_array::<_, Int64Type>(self).into_series() + } + + fn reinterpret_unsigned(&self) -> Series { + reinterpret_chunked_array::<_, UInt64Type>(self).into_series() + } +} + +impl UInt64Chunked { + #[doc(hidden)] + pub fn _reinterpret_float(&self) -> Float64Chunked { + reinterpret_chunked_array(self) + } +} +impl UInt32Chunked { + #[doc(hidden)] + pub fn _reinterpret_float(&self) -> Float32Chunked { + reinterpret_chunked_array(self) + } +} + +/// Used to save compilation paths. Use carefully. Although this is safe, +/// if misused it can lead to incorrect results. +impl Float32Chunked { + pub fn apply_as_ints(&self, f: F) -> Series + where + F: Fn(&Series) -> Series, + { + let BitRepr::Small(s) = self.to_bit_repr() else { + unreachable!() + }; + let s = s.into_series(); + let out = f(&s); + let out = out.u32().unwrap(); + out._reinterpret_float().into() + } +} +impl Float64Chunked { + pub fn apply_as_ints(&self, f: F) -> Series + where + F: Fn(&Series) -> Series, + { + let BitRepr::Large(s) = self.to_bit_repr() else { + unreachable!() + }; + let s = s.into_series(); + let out = f(&s); + let out = out.u64().unwrap(); + out._reinterpret_float().into() + } +} diff --git a/crates/polars-core/src/chunked_array/ops/bits.rs b/crates/polars-core/src/chunked_array/ops/bits.rs new file mode 100644 index 000000000000..178df3a73d09 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/bits.rs @@ -0,0 +1,21 @@ +use super::BooleanChunked; + +impl BooleanChunked { + pub fn num_trues(&self) -> usize { + self.downcast_iter() + .map(|arr| match arr.validity() { + None => arr.values().set_bits(), + Some(validity) => arr.values().num_intersections_with(validity), + }) + .sum() + } + + pub fn num_falses(&self) -> usize { + self.downcast_iter() + .map(|arr| match arr.validity() { + None => arr.values().unset_bits(), + Some(validity) => (!arr.values()).num_intersections_with(validity), + }) + .sum() + } +} diff --git a/crates/polars-core/src/chunked_array/ops/bitwise_reduce.rs b/crates/polars-core/src/chunked_array/ops/bitwise_reduce.rs new file mode 100644 index 000000000000..d69dfa2f5d63 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/bitwise_reduce.rs @@ -0,0 +1,54 @@ +use arrow::array::PrimitiveArray; +use arrow::types::NativeType; +use polars_compute::bitwise::BitwiseKernel; + +use super::{BooleanType, ChunkBitwiseReduce, ChunkedArray, PolarsNumericType}; + +impl ChunkBitwiseReduce for ChunkedArray +where + T: PolarsNumericType, + T::Native: NativeType, + PrimitiveArray: BitwiseKernel, +{ + type Physical = T::Native; + + fn and_reduce(&self) -> Option { + self.downcast_iter() + .filter_map(BitwiseKernel::reduce_and) + .reduce( as BitwiseKernel>::bit_and) + } + + fn or_reduce(&self) -> Option { + self.downcast_iter() + .filter_map(BitwiseKernel::reduce_or) + .reduce( as BitwiseKernel>::bit_or) + } + + fn xor_reduce(&self) -> Option { + self.downcast_iter() + .filter_map(BitwiseKernel::reduce_xor) + .reduce( as BitwiseKernel>::bit_xor) + } +} + +impl ChunkBitwiseReduce for ChunkedArray { + type Physical = bool; + + fn and_reduce(&self) -> Option { + self.downcast_iter() + .filter_map(BitwiseKernel::reduce_and) + .reduce(|a, b| a & b) + } + + fn or_reduce(&self) -> Option { + self.downcast_iter() + .filter_map(BitwiseKernel::reduce_or) + .reduce(|a, b| a | b) + } + + fn xor_reduce(&self) -> Option { + self.downcast_iter() + .filter_map(BitwiseKernel::reduce_xor) + .reduce(|a, b| a ^ b) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs new file mode 100644 index 000000000000..7b5c0f81ed78 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -0,0 +1,378 @@ +use std::borrow::Cow; +use std::cell::Cell; + +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::compute::concatenate::concatenate_unchecked; +use polars_error::constants::LENGTH_LIMIT_MSG; + +use super::*; +use crate::chunked_array::flags::StatisticsFlags; +#[cfg(feature = "object")] +use crate::chunked_array::object::builder::ObjectChunkedBuilder; +use crate::utils::slice_offsets; + +pub(crate) fn split_at( + chunks: &[ArrayRef], + offset: i64, + own_length: usize, +) -> (Vec, Vec) { + let mut new_chunks_left = Vec::with_capacity(1); + let mut new_chunks_right = Vec::with_capacity(1); + let (raw_offset, _) = slice_offsets(offset, 0, own_length); + + let mut remaining_offset = raw_offset; + let mut iter = chunks.iter(); + + for chunk in &mut iter { + let chunk_len = chunk.len(); + if remaining_offset > 0 && remaining_offset >= chunk_len { + remaining_offset -= chunk_len; + new_chunks_left.push(chunk.clone()); + continue; + } + + let (l, r) = chunk.split_at_boxed(remaining_offset); + new_chunks_left.push(l); + new_chunks_right.push(r); + break; + } + + for chunk in iter { + new_chunks_right.push(chunk.clone()) + } + if new_chunks_left.is_empty() { + new_chunks_left.push(chunks[0].sliced(0, 0)); + } + if new_chunks_right.is_empty() { + new_chunks_right.push(chunks[0].sliced(0, 0)); + } + (new_chunks_left, new_chunks_right) +} + +pub(crate) fn slice( + chunks: &[ArrayRef], + offset: i64, + slice_length: usize, + own_length: usize, +) -> (Vec, usize) { + let mut new_chunks = Vec::with_capacity(1); + let (raw_offset, slice_len) = slice_offsets(offset, slice_length, own_length); + + let mut remaining_length = slice_len; + let mut remaining_offset = raw_offset; + let mut new_len = 0; + + for chunk in chunks { + let chunk_len = chunk.len(); + if remaining_offset > 0 && remaining_offset >= chunk_len { + remaining_offset -= chunk_len; + continue; + } + let take_len = if remaining_length + remaining_offset > chunk_len { + chunk_len - remaining_offset + } else { + remaining_length + }; + new_len += take_len; + + debug_assert!(remaining_offset + take_len <= chunk.len()); + unsafe { + // SAFETY: + // this function ensures the slices are in bounds + new_chunks.push(chunk.sliced_unchecked(remaining_offset, take_len)); + } + remaining_length -= take_len; + remaining_offset = 0; + if remaining_length == 0 { + break; + } + } + if new_chunks.is_empty() { + new_chunks.push(chunks[0].sliced(0, 0)); + } + (new_chunks, new_len) +} + +// When we deal with arrays and lists we can easily exceed the limit if +// we take the underlying values array as a Series. This call stack +// is hard to follow, so for this one case we make an exception +// and use a thread local. +thread_local!(static CHECK_LENGTH: Cell = const { Cell::new(true) }); + +/// Meant for internal use. In very rare conditions this can be turned off. +/// # Safety +/// The caller must ensure the Series that exceeds the length get's deconstructed +/// into array values or list values before and never is used. +pub unsafe fn _set_check_length(check: bool) { + CHECK_LENGTH.set(check) +} + +impl ChunkedArray { + /// Get the length of the ChunkedArray + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Return the number of null values in the ChunkedArray. + #[inline] + pub fn null_count(&self) -> usize { + self.null_count + } + + /// Set the null count directly. + /// + /// This can be useful after mutably adjusting the validity of the + /// underlying arrays. + /// + /// # Safety + /// The new null count must match the total null count of the underlying + /// arrays. + pub unsafe fn set_null_count(&mut self, null_count: usize) { + self.null_count = null_count; + } + + /// Check if ChunkedArray is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Compute the length + pub(crate) fn compute_len(&mut self) { + fn inner(chunks: &[ArrayRef]) -> usize { + match chunks.len() { + // fast path + 1 => chunks[0].len(), + _ => chunks.iter().fold(0, |acc, arr| acc + arr.len()), + } + } + let len = inner(&self.chunks); + // Length limit is `IdxSize::MAX - 1`. We use `IdxSize::MAX` to indicate `NULL` in indexing. + if len >= (IdxSize::MAX as usize) && CHECK_LENGTH.get() { + panic!("{}", LENGTH_LIMIT_MSG); + } + self.length = len; + self.null_count = self + .chunks + .iter() + .map(|arr| arr.null_count()) + .sum::(); + } + + /// Rechunks this ChunkedArray, returning a new Cow::Owned ChunkedArray if it was + /// rechunked or simply a Cow::Borrowed of itself if it was already a single chunk. + pub fn rechunk(&self) -> Cow<'_, Self> { + match self.dtype() { + #[cfg(feature = "object")] + DataType::Object(_) => { + panic!("implementation error") + }, + _ => { + if self.chunks.len() == 1 { + Cow::Borrowed(self) + } else { + let chunks = vec![concatenate_unchecked(&self.chunks).unwrap()]; + + let mut ca = unsafe { self.copy_with_chunks(chunks) }; + use StatisticsFlags as F; + ca.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + Cow::Owned(ca) + } + }, + } + } + + /// Rechunks this ChunkedArray in-place. + pub fn rechunk_mut(&mut self) { + if self.chunks.len() > 1 { + let rechunked = concatenate_unchecked(&self.chunks).unwrap(); + if self.chunks.capacity() <= 8 { + // Reuse chunk allocation if not excessive. + self.chunks.clear(); + self.chunks.push(rechunked); + } else { + self.chunks = vec![rechunked]; + } + } + } + + pub fn rechunk_validity(&self) -> Option { + if self.chunks.len() == 1 { + return self.chunks[0].validity().cloned(); + } + + if !self.has_nulls() || self.is_empty() { + return None; + } + + let mut bm = BitmapBuilder::with_capacity(self.len()); + for arr in self.downcast_iter() { + if let Some(v) = arr.validity() { + bm.extend_from_bitmap(v); + } else { + bm.extend_constant(arr.len(), true); + } + } + bm.into_opt_validity() + } + + pub fn with_validities(&mut self, validities: &[Option]) { + assert_eq!(validities.len(), self.chunks.len()); + + // SAFETY: + // We don't change the data type of the chunks, nor the length. + for (arr, validity) in unsafe { self.chunks_mut().iter_mut() }.zip(validities.iter()) { + *arr = arr.with_validity(validity.clone()) + } + } + + /// Split the array. The chunks are reallocated the underlying data slices are zero copy. + /// + /// When offset is negative it will be counted from the end of the array. + /// This method will never error, + /// and will slice the best match when offset, or length is out of bounds + pub fn split_at(&self, offset: i64) -> (Self, Self) { + // A normal slice, slice the buffers and thus keep the whole memory allocated. + let (l, r) = split_at(&self.chunks, offset, self.len()); + let mut out_l = unsafe { self.copy_with_chunks(l) }; + let mut out_r = unsafe { self.copy_with_chunks(r) }; + + use StatisticsFlags as F; + out_l.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + out_r.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + + (out_l, out_r) + } + + /// Slice the array. The chunks are reallocated the underlying data slices are zero copy. + /// + /// When offset is negative it will be counted from the end of the array. + /// This method will never error, + /// and will slice the best match when offset, or length is out of bounds + pub fn slice(&self, offset: i64, length: usize) -> Self { + // The len: 0 special cases ensure we release memory. + // A normal slice, slice the buffers and thus keep the whole memory allocated. + let exec = || { + let (chunks, len) = slice(&self.chunks, offset, length, self.len()); + let mut out = unsafe { self.copy_with_chunks(chunks) }; + + use StatisticsFlags as F; + out.retain_flags_from(self, F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST); + out.length = len; + + out + }; + + match length { + 0 => match self.dtype() { + #[cfg(feature = "object")] + DataType::Object(_) => exec(), + _ => self.clear(), + }, + _ => exec(), + } + } + + /// Take a view of top n elements + #[must_use] + pub fn limit(&self, num_elements: usize) -> Self + where + Self: Sized, + { + self.slice(0, num_elements) + } + + /// Get the head of the [`ChunkedArray`] + #[must_use] + pub fn head(&self, length: Option) -> Self + where + Self: Sized, + { + match length { + Some(len) => self.slice(0, std::cmp::min(len, self.len())), + None => self.slice(0, std::cmp::min(10, self.len())), + } + } + + /// Get the tail of the [`ChunkedArray`] + #[must_use] + pub fn tail(&self, length: Option) -> Self + where + Self: Sized, + { + let len = match length { + Some(len) => std::cmp::min(len, self.len()), + None => std::cmp::min(10, self.len()), + }; + self.slice(-(len as i64), len) + } + + /// Remove empty chunks. + pub fn prune_empty_chunks(&mut self) { + let mut count = 0u32; + unsafe { + self.chunks_mut().retain(|arr| { + count += 1; + // Always keep at least one chunk + if count == 1 { + true + } else { + // Remove the empty chunks + !arr.is_empty() + } + }) + } + } +} + +#[cfg(feature = "object")] +impl ObjectChunked { + pub(crate) fn rechunk_object(&self) -> Self { + if self.chunks.len() == 1 { + self.clone() + } else { + let mut builder = ObjectChunkedBuilder::new(self.name().clone(), self.len()); + let chunks = self.downcast_iter(); + + // todo! use iterators once implemented + // no_null path + if !self.has_nulls() { + for arr in chunks { + for idx in 0..arr.len() { + builder.append_value(arr.value(idx).clone()) + } + } + } else { + for arr in chunks { + for idx in 0..arr.len() { + if arr.is_valid(idx) { + builder.append_value(arr.value(idx).clone()) + } else { + builder.append_null() + } + } + } + } + builder.finish() + } + } +} + +#[cfg(test)] +mod test { + #[cfg(feature = "dtype-categorical")] + use crate::prelude::*; + + #[test] + #[cfg(feature = "dtype-categorical")] + fn test_categorical_map_after_rechunk() { + let s = Series::new(PlSmallStr::EMPTY, &["foo", "bar", "spam"]); + let mut a = s + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + + a.append(&a.slice(0, 2)).unwrap(); + let a = a.rechunk(); + assert!(a.categorical().unwrap().get_rev_map().len() > 0); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs new file mode 100644 index 000000000000..ad2155407e0a --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -0,0 +1,228 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! Used to speed up TotalEq and TotalOrd of elements within an array. + +use std::cmp::Ordering; + +use crate::chunked_array::ChunkedArrayLayout; +use crate::prelude::*; +use crate::series::implementations::null::NullChunked; + +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct NonNull(pub T); + +impl TotalEq for NonNull { + fn tot_eq(&self, other: &Self) -> bool { + self.0.tot_eq(&other.0) + } +} + +pub trait NullOrderCmp { + fn null_order_cmp(&self, other: &Self, nulls_last: bool) -> Ordering; +} + +impl NullOrderCmp for Option { + fn null_order_cmp(&self, other: &Self, nulls_last: bool) -> Ordering { + match (self, other) { + (None, None) => Ordering::Equal, + (None, Some(_)) => { + if nulls_last { + Ordering::Greater + } else { + Ordering::Less + } + }, + (Some(_), None) => { + if nulls_last { + Ordering::Less + } else { + Ordering::Greater + } + }, + (Some(l), Some(r)) => l.tot_cmp(r), + } + } +} + +impl NullOrderCmp for NonNull { + fn null_order_cmp(&self, other: &Self, _nulls_last: bool) -> Ordering { + self.0.tot_cmp(&other.0) + } +} + +trait GetInner { + type Item; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item; +} + +impl<'a, T: PolarsDataType> GetInner for &'a ChunkedArray { + type Item = Option>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + ChunkedArray::get_unchecked(self, idx) + } +} + +impl<'a, T: StaticArray> GetInner for &'a T { + type Item = Option>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + ::get_unchecked(self, idx) + } +} + +impl<'a, T: PolarsDataType> GetInner for NonNull<&'a ChunkedArray> { + type Item = NonNull>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + NonNull(self.0.value_unchecked(idx)) + } +} + +impl<'a, T: StaticArray> GetInner for NonNull<&'a T> { + type Item = NonNull>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + NonNull(self.0.value_unchecked(idx)) + } +} + +pub trait TotalEqInner: Send + Sync { + /// # Safety + /// Does not do any bound checks. + unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool; +} + +pub trait TotalOrdInner: Send + Sync { + /// # Safety + /// Does not do any bound checks. + unsafe fn cmp_element_unchecked( + &self, + idx_a: usize, + idx_b: usize, + nulls_last: bool, + ) -> Ordering; +} + +impl TotalEqInner for T +where + T: GetInner + Send + Sync, + T::Item: TotalEq, +{ + #[inline] + unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { + self.get_unchecked(idx_a).tot_eq(&self.get_unchecked(idx_b)) + } +} + +impl TotalEqInner for &NullChunked { + unsafe fn eq_element_unchecked(&self, _idx_a: usize, _idx_b: usize) -> bool { + true + } +} + +/// Create a type that implements TotalEqInner. +pub(crate) trait IntoTotalEqInner<'a> { + /// Create a type that implements `TakeRandom`. + fn into_total_eq_inner(self) -> Box; +} + +impl<'a> IntoTotalEqInner<'a> for &'a NullChunked { + fn into_total_eq_inner(self) -> Box { + Box::new(self) + } +} + +/// We use a trait object because we want to call this from Series and cannot use a typed enum. +impl<'a, T> IntoTotalEqInner<'a> for &'a ChunkedArray +where + T: PolarsDataType, + T::Physical<'a>: TotalEq, +{ + fn into_total_eq_inner(self) -> Box { + match self.layout() { + ChunkedArrayLayout::SingleNoNull(arr) => Box::new(NonNull(arr)), + ChunkedArrayLayout::Single(arr) => Box::new(arr), + ChunkedArrayLayout::MultiNoNull(ca) => Box::new(NonNull(ca)), + ChunkedArrayLayout::Multi(ca) => Box::new(ca), + } + } +} + +impl TotalOrdInner for T +where + T: GetInner + Send + Sync, + T::Item: NullOrderCmp, +{ + #[inline] + unsafe fn cmp_element_unchecked( + &self, + idx_a: usize, + idx_b: usize, + nulls_last: bool, + ) -> Ordering { + let a = self.get_unchecked(idx_a); + let b = self.get_unchecked(idx_b); + a.null_order_cmp(&b, nulls_last) + } +} + +/// Create a type that implements TotalOrdInner. +pub(crate) trait IntoTotalOrdInner<'a> { + /// Create a type that implements `TakeRandom`. + fn into_total_ord_inner(self) -> Box; +} + +impl<'a, T> IntoTotalOrdInner<'a> for &'a ChunkedArray +where + T: PolarsDataType, + T::Physical<'a>: TotalOrd, +{ + fn into_total_ord_inner(self) -> Box { + match self.layout() { + ChunkedArrayLayout::SingleNoNull(arr) => Box::new(NonNull(arr)), + ChunkedArrayLayout::Single(arr) => Box::new(arr), + ChunkedArrayLayout::MultiNoNull(ca) => Box::new(NonNull(ca)), + ChunkedArrayLayout::Multi(ca) => Box::new(ca), + } + } +} + +#[cfg(feature = "dtype-categorical")] +struct LocalCategorical<'a> { + rev_map: &'a Utf8ViewArray, + cats: &'a UInt32Chunked, +} + +#[cfg(feature = "dtype-categorical")] +impl<'a> GetInner for LocalCategorical<'a> { + type Item = Option<&'a str>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + let cat = self.cats.get_unchecked(idx)?; + Some(self.rev_map.value_unchecked(cat as usize)) + } +} + +#[cfg(feature = "dtype-categorical")] +struct GlobalCategorical<'a> { + p1: &'a PlHashMap, + p2: &'a Utf8ViewArray, + cats: &'a UInt32Chunked, +} + +#[cfg(feature = "dtype-categorical")] +impl<'a> GetInner for GlobalCategorical<'a> { + type Item = Option<&'a str>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + let cat = self.cats.get_unchecked(idx)?; + let idx = self.p1.get(&cat).unwrap(); + Some(self.p2.value_unchecked(*idx as usize)) + } +} + +#[cfg(feature = "dtype-categorical")] +impl<'a> IntoTotalOrdInner<'a> for &'a CategoricalChunked { + fn into_total_ord_inner(self) -> Box { + let cats = self.physical(); + match &**self.get_rev_map() { + RevMapping::Global(p1, p2, _) => Box::new(GlobalCategorical { p1, p2, cats }), + RevMapping::Local(rev_map, _) => Box::new(LocalCategorical { rev_map, cats }), + } + } +} diff --git a/crates/polars-core/src/chunked_array/ops/decimal.rs b/crates/polars-core/src/chunked_array/ops/decimal.rs new file mode 100644 index 000000000000..5f242ee37caa --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/decimal.rs @@ -0,0 +1,56 @@ +use crate::chunked_array::cast::CastOptions; +use crate::prelude::*; + +impl StringChunked { + /// Convert an [`StringChunked`] to a [`Series`] of [`DataType::Decimal`]. + /// Scale needed for the decimal type are inferred. Parsing is not strict. + /// Scale inference assumes that all tested strings are well-formed numbers, + /// and may produce unexpected results for scale if this is not the case. + /// + /// If the decimal `precision` and `scale` are already known, consider + /// using the `cast` method. + pub fn to_decimal(&self, infer_length: usize) -> PolarsResult { + let mut scale = 0; + let mut iter = self.into_iter(); + let mut valid_count = 0; + while let Some(Some(v)) = iter.next() { + let scale_value = arrow::compute::decimal::infer_scale(v.as_bytes()); + scale = std::cmp::max(scale, scale_value); + valid_count += 1; + if valid_count == infer_length { + break; + } + } + + self.cast_with_options( + &DataType::Decimal(None, Some(scale as usize)), + CastOptions::NonStrict, + ) + } +} + +#[cfg(test)] +mod test { + #[test] + fn test_inferred_length() { + use super::*; + let vals = [ + "1.0", + "invalid", + "225.0", + "3.00045", + "-4.0", + "5.104", + "5.25251525353", + ]; + let s = StringChunked::from_slice(PlSmallStr::from_str("test"), &vals); + let s = s.to_decimal(6).unwrap(); + assert_eq!(s.dtype(), &DataType::Decimal(None, Some(5))); + assert_eq!(s.len(), 7); + assert_eq!(s.get(0).unwrap(), AnyValue::Decimal(100000, 5)); + assert_eq!(s.get(1).unwrap(), AnyValue::Null); + assert_eq!(s.get(3).unwrap(), AnyValue::Decimal(300045, 5)); + assert_eq!(s.get(4).unwrap(), AnyValue::Decimal(-400000, 5)); + assert_eq!(s.get(6).unwrap(), AnyValue::Decimal(525251, 5)); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/downcast.rs b/crates/polars-core/src/chunked_array/ops/downcast.rs new file mode 100644 index 000000000000..251cb2980570 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/downcast.rs @@ -0,0 +1,164 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::marker::PhantomData; + +use arrow::array::*; +use arrow::compute::utils::combine_validities_and; + +use crate::prelude::*; +use crate::utils::{index_to_chunked_index, index_to_chunked_index_rev}; + +pub struct Chunks<'a, T> { + chunks: &'a [ArrayRef], + phantom: PhantomData, +} + +impl<'a, T> Chunks<'a, T> { + fn new(chunks: &'a [ArrayRef]) -> Self { + Chunks { + chunks, + phantom: PhantomData, + } + } + + #[inline] + pub fn get(&self, index: usize) -> Option<&'a T> { + self.chunks.get(index).map(|arr| { + let arr = &**arr; + unsafe { &*(arr as *const dyn Array as *const T) } + }) + } + + #[inline] + pub unsafe fn get_unchecked(&self, index: usize) -> &'a T { + let arr = self.chunks.get_unchecked(index); + let arr = &**arr; + &*(arr as *const dyn Array as *const T) + } + + pub fn len(&self) -> usize { + self.chunks.len() + } + + #[inline] + pub fn last(&self) -> Option<&'a T> { + self.chunks.last().map(|arr| { + let arr = &**arr; + unsafe { &*(arr as *const dyn Array as *const T) } + }) + } +} + +#[doc(hidden)] +impl ChunkedArray { + #[inline] + pub fn downcast_into_iter(mut self) -> impl DoubleEndedIterator { + let chunks = std::mem::take(&mut self.chunks); + chunks.into_iter().map(|arr| { + // SAFETY: T::Array guarantees this is correct. + let ptr = Box::into_raw(arr).cast::(); + unsafe { *Box::from_raw(ptr) } + }) + } + + #[inline] + pub fn downcast_iter(&self) -> impl DoubleEndedIterator { + self.chunks.iter().map(|arr| { + // SAFETY: T::Array guarantees this is correct. + let arr = &**arr; + unsafe { &*(arr as *const dyn Array as *const T::Array) } + }) + } + + #[inline] + pub fn downcast_slices(&self) -> Option]>> { + if self.null_count() != 0 { + return None; + } + let arr = self.downcast_iter().next().unwrap(); + if arr.as_slice().is_some() { + Some(self.downcast_iter().map(|arr| arr.as_slice().unwrap())) + } else { + None + } + } + + /// # Safety + /// The caller must ensure: + /// * the length remains correct. + /// * the flags (sorted, etc) remain correct. + #[inline] + pub unsafe fn downcast_iter_mut(&mut self) -> impl DoubleEndedIterator { + self.chunks.iter_mut().map(|arr| { + // SAFETY: T::Array guarantees this is correct. + let arr = &mut **arr; + &mut *(arr as *mut dyn Array as *mut T::Array) + }) + } + + #[inline] + pub fn downcast_chunks(&self) -> Chunks<'_, T::Array> { + Chunks::new(&self.chunks) + } + + #[inline] + pub fn downcast_get(&self, idx: usize) -> Option<&T::Array> { + let arr = self.chunks.get(idx)?; + // SAFETY: T::Array guarantees this is correct. + let arr = &**arr; + unsafe { Some(&*(arr as *const dyn Array as *const T::Array)) } + } + + #[inline] + pub fn downcast_as_array(&self) -> &T::Array { + assert_eq!(self.chunks.len(), 1); + self.downcast_get(0).unwrap() + } + + #[inline] + /// # Safety + /// It is up to the caller to ensure the chunk idx is in-bounds + pub unsafe fn downcast_get_unchecked(&self, idx: usize) -> &T::Array { + let arr = self.chunks.get_unchecked(idx); + // SAFETY: T::Array guarantees this is correct. + let arr = &**arr; + unsafe { &*(arr as *const dyn Array as *const T::Array) } + } + + /// Get the index of the chunk and the index of the value in that chunk. + #[inline] + pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) { + // Fast path. + if self.chunks.len() == 1 { + // SAFETY: chunks.len() == 1 guarantees this is correct. + let len = unsafe { self.chunks.get_unchecked(0).len() }; + return if index < len { + (0, index) + } else { + (1, index - len) + }; + } + let chunk_lens = self.chunk_lengths(); + let len = self.len(); + if index <= len / 2 { + // Access from lhs. + index_to_chunked_index(chunk_lens, index) + } else { + // Access from rhs. + let index_from_back = len - index; + index_to_chunked_index_rev(chunk_lens.rev(), index_from_back, self.chunks.len()) + } + } + + /// # Panics + /// Panics if chunks don't align + pub fn merge_validities(&mut self, chunks: &[ArrayRef]) { + assert_eq!(chunks.len(), self.chunks.len()); + unsafe { + for (arr, other) in self.chunks_mut().iter_mut().zip(chunks) { + let validity = combine_validities_and(arr.validity(), other.validity()); + *arr = arr.with_validity(validity); + } + } + self.compute_len(); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs new file mode 100644 index 000000000000..bab6d5550016 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -0,0 +1,446 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::*; +use arrow::bitmap::utils::set_bit_unchecked; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::legacy::prelude::*; + +use crate::prelude::*; +use crate::series::implementations::null::NullChunked; + +pub(crate) trait ExplodeByOffsets { + fn explode_by_offsets(&self, offsets: &[i64]) -> Series; +} + +unsafe fn unset_nulls( + start: usize, + last: usize, + validity_values: &Bitmap, + nulls: &mut Vec, + empty_row_idx: &[usize], + base_offset: usize, +) { + for i in start..last { + if !validity_values.get_bit_unchecked(i) { + nulls.push(i + empty_row_idx.len() - base_offset); + } + } +} + +fn get_capacity(offsets: &[i64]) -> usize { + (offsets[offsets.len() - 1] - offsets[0] + 1) as usize +} + +impl ExplodeByOffsets for ChunkedArray +where + T: PolarsIntegerType, +{ + fn explode_by_offsets(&self, offsets: &[i64]) -> Series { + debug_assert_eq!(self.chunks.len(), 1); + let arr = self.downcast_iter().next().unwrap(); + + // make sure that we don't look beyond the sliced array + let values = &arr.values().as_slice()[..offsets[offsets.len() - 1] as usize]; + + let mut empty_row_idx = vec![]; + let mut nulls = vec![]; + + let mut start = offsets[0] as usize; + let base_offset = start; + let mut last = start; + + let mut new_values = Vec::with_capacity(offsets[offsets.len() - 1] as usize - start + 1); + + // we check all the offsets and in the case a consecutive offset is the same, + // e.g. 0, 1, 4, 4, 6 + // the 4 4, means that that is an empty row. + // the empty row will be replaced with a None value. + // + // below we memcpy as much as possible and for the empty rows we add a default value + // that value will later be masked out by the validity bitmap + + // in the case that the value array has got null values, we need to check every validity + // value and collect the indices. + // because the length of the array is not known, we first collect the null indexes, offset + // with the insertion of empty rows (as None) and later create a validity bitmap + if arr.null_count() > 0 { + let validity_values = arr.validity().unwrap(); + + for &o in &offsets[1..] { + let o = o as usize; + if o == last { + if start != last { + #[cfg(debug_assertions)] + new_values.extend_from_slice(&values[start..last]); + + #[cfg(not(debug_assertions))] + unsafe { + new_values.extend_from_slice(values.get_unchecked(start..last)) + }; + + // SAFETY: + // we are in bounds + unsafe { + unset_nulls( + start, + last, + validity_values, + &mut nulls, + &empty_row_idx, + base_offset, + ) + } + } + + empty_row_idx.push(o + empty_row_idx.len() - base_offset); + new_values.push(T::Native::default()); + start = o; + } + last = o; + } + + // final null check + // SAFETY: + // we are in bounds + unsafe { + unset_nulls( + start, + last, + validity_values, + &mut nulls, + &empty_row_idx, + base_offset, + ) + } + } else { + for &o in &offsets[1..] { + let o = o as usize; + if o == last { + if start != last { + unsafe { new_values.extend_from_slice(values.get_unchecked(start..last)) }; + } + + empty_row_idx.push(o + empty_row_idx.len() - base_offset); + new_values.push(T::Native::default()); + start = o; + } + last = o; + } + } + + // add remaining values + new_values.extend_from_slice(&values[start..]); + + let mut validity = MutableBitmap::with_capacity(new_values.len()); + validity.extend_constant(new_values.len(), true); + let validity_slice = validity.as_mut_slice(); + + for i in empty_row_idx { + unsafe { set_bit_unchecked(validity_slice, i, false) } + } + for i in nulls { + unsafe { set_bit_unchecked(validity_slice, i, false) } + } + let arr = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + new_values.into(), + Some(validity.into()), + ); + Series::try_from((self.name().clone(), Box::new(arr) as ArrayRef)).unwrap() + } +} + +impl ExplodeByOffsets for Float32Chunked { + fn explode_by_offsets(&self, offsets: &[i64]) -> Series { + self.apply_as_ints(|s| { + let ca = s.u32().unwrap(); + ca.explode_by_offsets(offsets) + }) + } +} +impl ExplodeByOffsets for Float64Chunked { + fn explode_by_offsets(&self, offsets: &[i64]) -> Series { + self.apply_as_ints(|s| { + let ca = s.u64().unwrap(); + ca.explode_by_offsets(offsets) + }) + } +} + +impl ExplodeByOffsets for NullChunked { + fn explode_by_offsets(&self, offsets: &[i64]) -> Series { + let mut last_offset = offsets[0]; + + let mut len = 0; + for &offset in &offsets[1..] { + // If offset == last_offset we have an empty list and a new row is inserted, + // therefore we always increase at least 1. + len += std::cmp::max(offset - last_offset, 1) as usize; + last_offset = offset; + } + NullChunked::new(self.name.clone(), len).into_series() + } +} + +impl ExplodeByOffsets for BooleanChunked { + fn explode_by_offsets(&self, offsets: &[i64]) -> Series { + debug_assert_eq!(self.chunks.len(), 1); + let arr = self.downcast_iter().next().unwrap(); + + let cap = get_capacity(offsets); + let mut builder = BooleanChunkedBuilder::new(self.name().clone(), cap); + + let mut start = offsets[0] as usize; + let mut last = start; + for &o in &offsets[1..] { + let o = o as usize; + if o == last { + if start != last { + let vals = arr.slice_typed(start, last - start); + + if vals.null_count() == 0 { + builder + .array_builder + .extend_trusted_len_values(vals.values_iter()) + } else { + builder.array_builder.extend_trusted_len(vals.into_iter()); + } + } + builder.append_null(); + start = o; + } + last = o; + } + let vals = arr.slice_typed(start, last - start); + if vals.null_count() == 0 { + builder + .array_builder + .extend_trusted_len_values(vals.values_iter()) + } else { + builder.array_builder.extend_trusted_len(vals.into_iter()); + } + builder.finish().into() + } +} + +/// Convert Arrow array offsets to indexes of the original list +pub(crate) fn offsets_to_indexes(offsets: &[i64], capacity: usize) -> Vec { + if offsets.is_empty() { + return vec![]; + } + + let mut idx = Vec::with_capacity(capacity); + + let mut last_idx = 0; + for (offset_start, offset_end) in offsets.iter().zip(offsets[1..].iter()) { + if idx.len() >= capacity { + // significant speed-up in edge cases with many offsets, + // no measurable overhead in typical case due to branch prediction + break; + } + + if offset_start == offset_end { + // if the previous offset is equal to the current offset, we have an empty + // list and we duplicate the previous index + idx.push(last_idx); + } else { + let width = (offset_end - offset_start) as usize; + for _ in 0..width { + idx.push(last_idx); + } + } + + last_idx += 1; + } + + // take the remaining values + for _ in 0..capacity.saturating_sub(idx.len()) { + idx.push(last_idx); + } + idx.truncate(capacity); + idx +} + +#[cfg(test)] +mod test { + use super::*; + use crate::chunked_array::builder::get_list_builder; + + #[test] + fn test_explode_list() -> PolarsResult<()> { + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a")); + + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 3])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[2])) + .unwrap(); + + let ca = builder.finish(); + assert!(ca._can_fast_explode()); + + // normal explode + let exploded = ca.explode()?; + let out: Vec<_> = exploded.i32()?.into_no_null_iter().collect(); + assert_eq!(out, &[1, 2, 3, 3, 1, 2]); + + // sliced explode + let exploded = ca.slice(0, 1).explode()?; + let out: Vec<_> = exploded.i32()?.into_no_null_iter().collect(); + assert_eq!(out, &[1, 2, 3, 3]); + + Ok(()) + } + + #[test] + fn test_explode_empty_list_slot() -> PolarsResult<()> { + // primitive + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a")); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1i32, 2])) + .unwrap(); + builder + .append_series(&Int32Chunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[3i32])) + .unwrap(); + + let ca = builder.finish(); + let exploded = ca.explode()?; + assert_eq!( + Vec::from(exploded.i32()?), + &[Some(1), Some(2), None, Some(3)] + ); + + // more primitive + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a")); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1i32])) + .unwrap(); + builder + .append_series(&Int32Chunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[2i32])) + .unwrap(); + builder + .append_series(&Int32Chunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[3, 4i32])) + .unwrap(); + + let ca = builder.finish(); + let exploded = ca.explode()?; + assert_eq!( + Vec::from(exploded.i32()?), + &[Some(1), None, Some(2), None, Some(3), Some(4)] + ); + + // string + let mut builder = get_list_builder(&DataType::String, 5, 5, PlSmallStr::from_static("a")); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &["abc"])) + .unwrap(); + builder + .append_series( + &>::from_slice( + PlSmallStr::EMPTY, + &[], + ) + .into_series(), + ) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &["de"])) + .unwrap(); + builder + .append_series( + &>::from_slice( + PlSmallStr::EMPTY, + &[], + ) + .into_series(), + ) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &["fg"])) + .unwrap(); + builder + .append_series( + &>::from_slice( + PlSmallStr::EMPTY, + &[], + ) + .into_series(), + ) + .unwrap(); + + let ca = builder.finish(); + let exploded = ca.explode()?; + assert_eq!( + Vec::from(exploded.str()?), + &[Some("abc"), None, Some("de"), None, Some("fg"), None] + ); + + // boolean + let mut builder = get_list_builder(&DataType::Boolean, 5, 5, PlSmallStr::from_static("a")); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[true])) + .unwrap(); + builder + .append_series(&BooleanChunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[false])) + .unwrap(); + builder + .append_series(&BooleanChunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[true, true])) + .unwrap(); + + let ca = builder.finish(); + let exploded = ca.explode()?; + assert_eq!( + Vec::from(exploded.bool()?), + &[Some(true), None, Some(false), None, Some(true), Some(true)] + ); + + Ok(()) + } + + #[test] + fn test_row_offsets() { + let offsets = &[0, 1, 2, 2, 3, 4, 4]; + let out = offsets_to_indexes(offsets, 6); + assert_eq!(out, &[0, 1, 2, 3, 4, 5]); + } + + #[test] + fn test_empty_row_offsets() { + let offsets = &[0, 0]; + let out = offsets_to_indexes(offsets, 0); + let expected: Vec = Vec::new(); + assert_eq!(out, expected); + } + + #[test] + fn test_row_offsets_over_capacity() { + let offsets = &[0, 1, 1, 2, 2]; + let out = offsets_to_indexes(offsets, 2); + assert_eq!(out, &[0, 1]); + } + + #[test] + fn test_row_offsets_nonzero_first_offset() { + let offsets = &[3, 6, 8]; + let out = offsets_to_indexes(offsets, 10); + assert_eq!(out, &[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs new file mode 100644 index 000000000000..d6f546e5fef9 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs @@ -0,0 +1,303 @@ +use arrow::offset::OffsetsBuffer; +use polars_compute::gather::take_unchecked; + +use super::*; + +impl ListChunked { + fn specialized( + &self, + values: ArrayRef, + offsets: &[i64], + offsets_buf: OffsetsBuffer, + ) -> (Series, OffsetsBuffer) { + // SAFETY: inner_dtype should be correct + let values = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![values], + &self.inner_dtype().to_physical(), + ) + }; + + use crate::chunked_array::ops::explode::ExplodeByOffsets; + + let mut values = match values.dtype() { + DataType::Boolean => { + let t = values.bool().unwrap(); + ExplodeByOffsets::explode_by_offsets(t, offsets).into_series() + }, + DataType::Null => { + let t = values.null().unwrap(); + ExplodeByOffsets::explode_by_offsets(t, offsets).into_series() + }, + dtype => { + with_match_physical_numeric_polars_type!(dtype, |$T| { + let t: &ChunkedArray<$T> = values.as_ref().as_ref(); + ExplodeByOffsets::explode_by_offsets(t, offsets).into_series() + }) + }, + }; + + // let mut values = values.explode_by_offsets(offsets); + // restore logical type + values = unsafe { values.from_physical_unchecked(self.inner_dtype()) }.unwrap(); + + (values, offsets_buf) + } +} + +impl ChunkExplode for ListChunked { + fn offsets(&self) -> PolarsResult> { + let ca = self.rechunk(); + let listarr: &LargeListArray = ca.downcast_iter().next().unwrap(); + let offsets = listarr.offsets().clone(); + Ok(offsets) + } + + fn explode_and_offsets(&self) -> PolarsResult<(Series, OffsetsBuffer)> { + // A list array's memory layout is actually already 'exploded', so we can just take the + // values array of the list. And we also return a slice of the offsets. This slice can be + // used to find the old list layout or indexes to expand a DataFrame in the same manner as + // the `explode` operation. + let ca = self.rechunk(); + let listarr: &LargeListArray = ca.downcast_iter().next().unwrap(); + let offsets_buf = listarr.offsets().clone(); + let offsets = listarr.offsets().as_slice(); + let mut values = listarr.values().clone(); + + let (mut s, offsets) = if ca._can_fast_explode() { + // ensure that the value array is sliced + // as a list only slices its offsets on a slice operation + + // we only do this in fast-explode as for the other + // branch the offsets must coincide with the values. + if !offsets.is_empty() { + let start = offsets[0] as usize; + let len = offsets[offsets.len() - 1] as usize - start; + // SAFETY: + // we are in bounds + values = unsafe { values.sliced_unchecked(start, len) }; + } + // SAFETY: inner_dtype should be correct + ( + unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![values], + &self.inner_dtype().to_physical(), + ) + }, + offsets_buf, + ) + } else { + // during tests + // test that this code branch is not hit with list arrays that could be fast exploded + #[cfg(test)] + { + let mut last = offsets[0]; + let mut has_empty = false; + for &o in &offsets[1..] { + if o == last { + has_empty = true; + } + last = o; + } + if !has_empty && offsets[0] == 0 { + panic!("could have fast exploded") + } + } + let (indices, new_offsets) = if listarr.null_count() == 0 { + // SPECIALIZED path. + let inner_phys = self.inner_dtype().to_physical(); + if inner_phys.is_primitive_numeric() || inner_phys.is_null() || inner_phys.is_bool() + { + return Ok(self.specialized(values, offsets, offsets_buf)); + } + // Use gather + let mut indices = + MutablePrimitiveArray::::with_capacity(*offsets_buf.last() as usize); + let mut new_offsets = Vec::with_capacity(listarr.len() + 1); + let mut current_offset = 0i64; + let mut iter = offsets.iter(); + if let Some(mut previous) = iter.next().copied() { + new_offsets.push(current_offset); + iter.for_each(|&offset| { + let len = offset - previous; + let start = previous as IdxSize; + let end = offset as IdxSize; + + if len == 0 { + indices.push_null(); + } else { + indices.extend_trusted_len_values(start..end); + } + current_offset += len; + previous = offset; + new_offsets.push(current_offset); + }) + } + (indices, new_offsets) + } else { + // we have already ensure that validity is not none. + let validity = listarr.validity().unwrap(); + + let mut indices = + MutablePrimitiveArray::::with_capacity(*offsets_buf.last() as usize); + let mut new_offsets = Vec::with_capacity(listarr.len() + 1); + let mut current_offset = 0i64; + let mut iter = offsets.iter(); + if let Some(mut previous) = iter.next().copied() { + new_offsets.push(current_offset); + iter.enumerate().for_each(|(i, &offset)| { + let len = offset - previous; + let start = previous as IdxSize; + let end = offset as IdxSize; + // SAFETY: we are within bounds + if unsafe { validity.get_bit_unchecked(i) } { + // explode expects null value if sublist is empty. + if len == 0 { + indices.push_null(); + } else { + indices.extend_trusted_len_values(start..end); + } + current_offset += len; + } else { + indices.push_null(); + } + previous = offset; + new_offsets.push(current_offset); + }) + } + (indices, new_offsets) + }; + + // SAFETY: the indices we generate are in bounds + let chunk = unsafe { take_unchecked(values.as_ref(), &indices.into()) }; + // SAFETY: inner_dtype should be correct + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![chunk], + &self.inner_dtype().to_physical(), + ) + }; + // SAFETY: monotonically increasing + let new_offsets = unsafe { OffsetsBuffer::new_unchecked(new_offsets.into()) }; + (s, new_offsets) + }; + debug_assert_eq!(s.name(), self.name()); + // restore logical type + s = unsafe { s.from_physical_unchecked(self.inner_dtype()) }.unwrap(); + + Ok((s, offsets)) + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkExplode for ArrayChunked { + fn offsets(&self) -> PolarsResult> { + // fast-path for non-null array. + if self.null_count() == 0 { + let width = self.width() as i64; + let offsets = (0..self.len() + 1) + .map(|i| { + let i = i as i64; + i * width + }) + .collect::>(); + // SAFETY: monotonically increasing + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + + return Ok(offsets); + } + + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + // we have already ensure that validity is not none. + let validity = arr.validity().unwrap(); + let width = arr.size(); + + let mut current_offset = 0i64; + let offsets = (0..=arr.len()) + .map(|i| { + if i == 0 { + return current_offset; + } + // SAFETY: we are within bounds + if unsafe { validity.get_bit_unchecked(i - 1) } { + current_offset += width as i64 + } + current_offset + }) + .collect::>(); + // SAFETY: monotonically increasing + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + Ok(offsets) + } + + fn explode_and_offsets(&self) -> PolarsResult<(Series, OffsetsBuffer)> { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + // fast-path for non-null array. + if arr.null_count() == 0 { + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr.values().clone()], + ca.inner_dtype(), + ) + }; + let width = self.width() as i64; + let offsets = (0..self.len() + 1) + .map(|i| { + let i = i as i64; + i * width + }) + .collect::>(); + // SAFETY: monotonically increasing + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + return Ok((s, offsets)); + } + + // we have already ensure that validity is not none. + let validity = arr.validity().unwrap(); + let values = arr.values(); + let width = arr.size(); + + let mut indices = MutablePrimitiveArray::::with_capacity( + values.len() - arr.null_count() * (width - 1), + ); + let mut offsets = Vec::with_capacity(arr.len() + 1); + let mut current_offset = 0i64; + offsets.push(current_offset); + (0..arr.len()).for_each(|i| { + // SAFETY: we are within bounds + if unsafe { validity.get_bit_unchecked(i) } { + let start = (i * width) as IdxSize; + let end = start + width as IdxSize; + indices.extend_trusted_len_values(start..end); + current_offset += width as i64; + } else { + indices.push_null(); + } + offsets.push(current_offset); + }); + + // SAFETY: the indices we generate are in bounds + let chunk = unsafe { take_unchecked(&**values, &indices.into()) }; + // SAFETY: monotonically increasing + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + + Ok(( + // SAFETY: inner_dtype should be correct + unsafe { + Series::from_chunks_and_dtype_unchecked( + ca.name().clone(), + vec![chunk], + ca.inner_dtype(), + ) + }, + offsets, + )) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/extend.rs b/crates/polars-core/src/chunked_array/ops/extend.rs new file mode 100644 index 000000000000..8111b4a764e0 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/extend.rs @@ -0,0 +1,244 @@ +use arrow::Either; +use arrow::compute::concatenate::concatenate; + +use crate::prelude::append::update_sorted_flag_before_append; +use crate::prelude::*; +use crate::series::IsSorted; + +fn extend_immutable(immutable: &dyn Array, chunks: &mut Vec, other_chunks: &[ArrayRef]) { + let out = if chunks.len() == 1 { + concatenate(&[immutable, &*other_chunks[0]]).unwrap() + } else { + let mut arrays = Vec::with_capacity(other_chunks.len() + 1); + arrays.push(immutable); + arrays.extend(other_chunks.iter().map(|a| &**a)); + concatenate(&arrays).unwrap() + }; + + chunks.push(out); +} + +impl ChunkedArray +where + T: PolarsNumericType, +{ + /// Extend the memory backed by this array with the values from `other`. + /// + /// Different from [`ChunkedArray::append`] which adds chunks to this [`ChunkedArray`] `extend` + /// appends the data from `other` to the underlying `PrimitiveArray` and thus may cause a reallocation. + /// + /// However if this does not cause a reallocation, the resulting data structure will not have any extra chunks + /// and thus will yield faster queries. + /// + /// Prefer `extend` over `append` when you want to do a query after a single append. For instance during + /// online operations where you add `n` rows and rerun a query. + /// + /// Prefer `append` over `extend` when you want to append many times before doing a query. For instance + /// when you read in multiple files and when to store them in a single `DataFrame`. + /// In the latter case finish the sequence of `append` operations with a [`rechunk`](Self::rechunk). + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + update_sorted_flag_before_append::(self, other); + // all to a single chunk + if self.chunks.len() > 1 { + self.append(other)?; + self.rechunk_mut(); + return Ok(()); + } + // Depending on the state of the underlying arrow array we + // might be able to get a `MutablePrimitiveArray` + // + // This is only possible if the reference count of the array and its buffers are 1 + // So the logic below is needed to keep the reference count 1 if it is + + // First we must obtain an owned version of the array + let arr = self.downcast_iter().next().unwrap(); + + // increments 1 + let arr = arr.clone(); + + // now we drop our owned ArrayRefs so that + // decrements 1 + { + self.chunks.clear(); + } + + use Either::*; + + if arr.values().is_sliced() { + extend_immutable(&arr, &mut self.chunks, &other.chunks); + } else { + match arr.into_mut() { + Left(immutable) => { + extend_immutable(&immutable, &mut self.chunks, &other.chunks); + }, + Right(mut mutable) => { + for arr in other.downcast_iter() { + match arr.null_count() { + 0 => mutable.extend_from_slice(arr.values()), + _ => mutable.extend_trusted_len(arr.into_iter()), + } + } + let arr: PrimitiveArray = mutable.into(); + self.chunks.push(Box::new(arr) as ArrayRef) + }, + } + } + self.compute_len(); + Ok(()) + } +} + +#[doc(hidden)] +impl StringChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + self.set_sorted_flag(IsSorted::Not); + self.append(other) + } +} + +#[doc(hidden)] +impl BinaryChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + self.set_sorted_flag(IsSorted::Not); + self.append(other) + } +} + +#[doc(hidden)] +impl BinaryOffsetChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + self.set_sorted_flag(IsSorted::Not); + self.append(other) + } +} + +#[doc(hidden)] +impl BooleanChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + update_sorted_flag_before_append::(self, other); + // make sure that we are a single chunk already + if self.chunks.len() > 1 { + self.append(other)?; + self.rechunk_mut(); + return Ok(()); + } + let arr = self.downcast_iter().next().unwrap(); + + // increments 1 + let arr = arr.clone(); + + // now we drop our owned ArrayRefs so that + // decrements 1 + { + self.chunks.clear(); + } + + use Either::*; + + match arr.into_mut() { + Left(immutable) => { + extend_immutable(&immutable, &mut self.chunks, &other.chunks); + }, + Right(mut mutable) => { + for arr in other.downcast_iter() { + mutable.extend_trusted_len(arr.into_iter()) + } + let arr: BooleanArray = mutable.into(); + self.chunks.push(Box::new(arr) as ArrayRef) + }, + } + self.compute_len(); + self.set_sorted_flag(IsSorted::Not); + Ok(()) + } +} + +#[doc(hidden)] +impl ListChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + // TODO! properly implement mutation + // this is harder because we don't know the inner type of the list + self.set_sorted_flag(IsSorted::Not); + self.append(other) + } +} + +#[cfg(feature = "dtype-array")] +#[doc(hidden)] +impl ArrayChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + // TODO! properly implement mutation + // this is harder because we don't know the inner type of the list + self.set_sorted_flag(IsSorted::Not); + self.append(other) + } +} + +#[cfg(feature = "dtype-struct")] +#[doc(hidden)] +impl StructChunked { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { + // TODO! properly implement mutation + // this is harder because we don't know the inner type of the list + self.set_sorted_flag(IsSorted::Not); + self.append(other) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[allow(clippy::redundant_clone)] + fn test_extend_primitive() -> PolarsResult<()> { + // create a vec with overcapacity, so that we do not trigger a realloc + // this allows us to test if the mutation was successful + + let mut values = Vec::with_capacity(32); + values.extend_from_slice(&[1, 2, 3]); + let mut ca = Int32Chunked::from_vec(PlSmallStr::from_static("a"), values); + let location = ca.cont_slice().unwrap().as_ptr() as usize; + let to_append = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]); + + ca.extend(&to_append)?; + let location2 = ca.cont_slice().unwrap().as_ptr() as usize; + assert_eq!(location, location2); + assert_eq!(ca.cont_slice().unwrap(), [1, 2, 3, 4, 5, 6]); + + // now check if it succeeds if we cannot do this with a mutable. + let _temp = ca.chunks.clone(); + ca.extend(&to_append)?; + let location2 = ca.cont_slice().unwrap().as_ptr() as usize; + assert_ne!(location, location2); + assert_eq!(ca.cont_slice().unwrap(), [1, 2, 3, 4, 5, 6, 4, 5, 6]); + + Ok(()) + } + + #[test] + fn test_extend_string() -> PolarsResult<()> { + let mut ca = StringChunked::new(PlSmallStr::from_static("a"), &["a", "b", "c"]); + let to_append = StringChunked::new(PlSmallStr::from_static("a"), &["a", "b", "e"]); + + ca.extend(&to_append)?; + assert_eq!(ca.len(), 6); + let vals = ca.into_no_null_iter().collect::>(); + assert_eq!(vals, ["a", "b", "c", "a", "b", "e"]); + + Ok(()) + } + + #[test] + fn test_extend_bool() -> PolarsResult<()> { + let mut ca = BooleanChunked::new(PlSmallStr::from_static("a"), [true, false]); + let to_append = BooleanChunked::new(PlSmallStr::from_static("a"), &[false, false]); + + ca.extend(&to_append)?; + assert_eq!(ca.len(), 4); + let vals = ca.into_no_null_iter().collect::>(); + assert_eq!(vals, [true, false, false, false]); + + Ok(()) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs new file mode 100644 index 000000000000..3ca3aae9279f --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -0,0 +1,378 @@ +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::legacy::kernels::set::set_at_nulls; +use bytemuck::Zeroable; +use num_traits::{NumCast, One, Zero}; +use polars_utils::itertools::Itertools; + +use crate::prelude::*; + +fn err_fill_null() -> PolarsError { + polars_err!(ComputeError: "could not determine the fill value") +} + +impl Series { + /// Replace None values with one of the following strategies: + /// * Forward fill (replace None with the previous value) + /// * Backward fill (replace None with the next value) + /// * Mean fill (replace None with the mean of the whole array) + /// * Min fill (replace None with the minimum of the whole array) + /// * Max fill (replace None with the maximum of the whole array) + /// * Zero fill (replace None with the value zero) + /// * One fill (replace None with the value one) + /// + /// *NOTE: If you want to fill the Nones with a value use the + /// [`fill_null` operation on `ChunkedArray`](crate::chunked_array::ops::ChunkFillNullValue)*. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example() -> PolarsResult<()> { + /// let s = Column::new("some_missing".into(), &[Some(1), None, Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::Forward(None))?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::Backward(None))?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(2), Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::Min)?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::Max)?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(2), Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::Mean)?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::Zero)?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(0), Some(2)]); + /// + /// let filled = s.fill_null(FillNullStrategy::One)?; + /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); + /// + /// Ok(()) + /// } + /// example(); + /// ``` + pub fn fill_null(&self, strategy: FillNullStrategy) -> PolarsResult { + // Nothing to fill. + let nc = self.null_count(); + if nc == 0 + || (nc == self.len() + && matches!( + strategy, + FillNullStrategy::Forward(_) + | FillNullStrategy::Backward(_) + | FillNullStrategy::Max + | FillNullStrategy::Min + | FillNullStrategy::Mean + )) + { + return Ok(self.clone()); + } + + let physical_type = self.dtype().to_physical(); + + match strategy { + FillNullStrategy::Forward(None) if !physical_type.is_primitive_numeric() => { + fill_forward_gather(self) + }, + FillNullStrategy::Forward(Some(limit)) => fill_forward_gather_limit(self, limit), + FillNullStrategy::Backward(None) if !physical_type.is_primitive_numeric() => { + fill_backward_gather(self) + }, + FillNullStrategy::Backward(Some(limit)) => fill_backward_gather_limit(self, limit), + #[cfg(feature = "dtype-decimal")] + FillNullStrategy::One if self.dtype().is_decimal() => { + let ca = self.decimal().unwrap(); + let precision = ca.precision(); + let scale = ca.scale(); + let fill_value = 10i128.pow(scale as u32); + let phys = ca.as_ref().fill_null_with_values(fill_value)?; + Ok(phys.into_decimal_unchecked(precision, scale).into_series()) + }, + _ => { + let logical_type = self.dtype(); + let s = self.to_physical_repr(); + use DataType::*; + let out = match s.dtype() { + Boolean => fill_null_bool(s.bool().unwrap(), strategy), + String => { + let s = unsafe { s.cast_unchecked(&Binary)? }; + let out = s.fill_null(strategy)?; + return unsafe { out.cast_unchecked(&String) }; + }, + Binary => { + let ca = s.binary().unwrap(); + fill_null_binary(ca, strategy).map(|ca| ca.into_series()) + }, + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + fill_null_numeric(ca, strategy).map(|ca| ca.into_series()) + }) + }, + dt => { + polars_bail!(InvalidOperation: "fill null strategy not yet supported for dtype: {}", dt) + }, + }?; + unsafe { out.from_physical_unchecked(logical_type) } + }, + } + } +} + +fn fill_forward_numeric<'a, T, I>(ca: &'a ChunkedArray) -> ChunkedArray +where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen + Iterator>>, + T::ZeroablePhysical<'a>: Copy, +{ + // Compute values. + let values: Vec> = ca + .into_iter() + .scan(T::ZeroablePhysical::zeroed(), |prev, v| { + *prev = v.map(|v| v.into()).unwrap_or(*prev); + Some(*prev) + }) + .collect_trusted(); + + // Compute bitmask. + let num_start_nulls = ca.first_non_null().unwrap_or(ca.len()); + let mut bm = BitmapBuilder::with_capacity(ca.len()); + bm.extend_constant(num_start_nulls, false); + bm.extend_constant(ca.len() - num_start_nulls, true); + ChunkedArray::from_chunk_iter_like( + ca, + [ + T::Array::from_zeroable_vec(values, ca.dtype().to_arrow(CompatLevel::newest())) + .with_validity_typed(bm.into_opt_validity()), + ], + ) +} + +fn fill_backward_numeric<'a, T, I>(ca: &'a ChunkedArray) -> ChunkedArray +where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen + Iterator>> + DoubleEndedIterator, + T::ZeroablePhysical<'a>: Copy, +{ + // Compute values. + let values: Vec> = ca + .into_iter() + .rev() + .scan(T::ZeroablePhysical::zeroed(), |prev, v| { + *prev = v.map(|v| v.into()).unwrap_or(*prev); + Some(*prev) + }) + .collect_reversed(); + + // Compute bitmask. + let num_end_nulls = ca + .last_non_null() + .map(|i| ca.len() - 1 - i) + .unwrap_or(ca.len()); + let mut bm = BitmapBuilder::with_capacity(ca.len()); + bm.extend_constant(ca.len() - num_end_nulls, true); + bm.extend_constant(num_end_nulls, false); + ChunkedArray::from_chunk_iter_like( + ca, + [ + T::Array::from_zeroable_vec(values, ca.dtype().to_arrow(CompatLevel::newest())) + .with_validity_typed(bm.into_opt_validity()), + ], + ) +} + +fn fill_null_numeric( + ca: &ChunkedArray, + strategy: FillNullStrategy, +) -> PolarsResult> +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + // Nothing to fill. + let mut out = match strategy { + FillNullStrategy::Min => { + ca.fill_null_with_values(ChunkAgg::min(ca).ok_or_else(err_fill_null)?)? + }, + FillNullStrategy::Max => { + ca.fill_null_with_values(ChunkAgg::max(ca).ok_or_else(err_fill_null)?)? + }, + FillNullStrategy::Mean => ca.fill_null_with_values( + ca.mean() + .map(|v| NumCast::from(v).unwrap()) + .ok_or_else(err_fill_null)?, + )?, + FillNullStrategy::One => return ca.fill_null_with_values(One::one()), + FillNullStrategy::Zero => return ca.fill_null_with_values(Zero::zero()), + FillNullStrategy::Forward(None) => fill_forward_numeric(ca), + FillNullStrategy::Backward(None) => fill_backward_numeric(ca), + // Handled earlier + FillNullStrategy::Forward(_) => unreachable!(), + FillNullStrategy::Backward(_) => unreachable!(), + }; + out.rename(ca.name().clone()); + Ok(out) +} + +fn fill_with_gather Vec>( + s: &Series, + bits_to_idx: F, +) -> PolarsResult { + let s = s.rechunk(); + let arr = s.chunks()[0].clone(); + let validity = arr.validity().expect("nulls"); + + let idx = bits_to_idx(validity); + + Ok(unsafe { s.take_slice_unchecked(&idx) }) +} + +fn fill_forward_gather(s: &Series) -> PolarsResult { + fill_with_gather(s, |validity| { + let mut last_valid = 0; + validity + .iter() + .enumerate_idx() + .map(|(i, v)| { + if v { + last_valid = i; + i + } else { + last_valid + } + }) + .collect::>() + }) +} + +fn fill_forward_gather_limit(s: &Series, limit: IdxSize) -> PolarsResult { + fill_with_gather(s, |validity| { + let mut last_valid = 0; + let mut conseq_invalid_count = 0; + validity + .iter() + .enumerate_idx() + .map(|(i, v)| { + if v { + last_valid = i; + conseq_invalid_count = 0; + i + } else if conseq_invalid_count < limit { + conseq_invalid_count += 1; + last_valid + } else { + i + } + }) + .collect::>() + }) +} + +fn fill_backward_gather(s: &Series) -> PolarsResult { + fill_with_gather(s, |validity| { + let last = validity.len() as IdxSize - 1; + let mut last_valid = last; + unsafe { + validity + .iter() + .rev() + .enumerate_idx() + .map(|(i, v)| { + if v { + last_valid = last - i; + last - i + } else { + last_valid + } + }) + .trust_my_length((last + 1) as usize) + .collect_reversed::>() + } + }) +} + +fn fill_backward_gather_limit(s: &Series, limit: IdxSize) -> PolarsResult { + fill_with_gather(s, |validity| { + let last = validity.len() as IdxSize - 1; + let mut last_valid = last; + let mut conseq_invalid_count = 0; + unsafe { + validity + .iter() + .rev() + .enumerate_idx() + .map(|(i, v)| { + if v { + last_valid = last - i; + conseq_invalid_count = 0; + last - i + } else if conseq_invalid_count < limit { + conseq_invalid_count += 1; + last_valid + } else { + last - i + } + }) + .trust_my_length((last + 1) as usize) + .collect_reversed() + } + }) +} + +fn fill_null_bool(ca: &BooleanChunked, strategy: FillNullStrategy) -> PolarsResult { + match strategy { + FillNullStrategy::Min => ca + .fill_null_with_values(ca.min().ok_or_else(err_fill_null)?) + .map(|ca| ca.into_series()), + FillNullStrategy::Max => ca + .fill_null_with_values(ca.max().ok_or_else(err_fill_null)?) + .map(|ca| ca.into_series()), + FillNullStrategy::Mean => polars_bail!(opq = mean, "Boolean"), + FillNullStrategy::One => ca.fill_null_with_values(true).map(|ca| ca.into_series()), + FillNullStrategy::Zero => ca.fill_null_with_values(false).map(|ca| ca.into_series()), + FillNullStrategy::Forward(_) => unreachable!(), + FillNullStrategy::Backward(_) => unreachable!(), + } +} + +fn fill_null_binary(ca: &BinaryChunked, strategy: FillNullStrategy) -> PolarsResult { + match strategy { + FillNullStrategy::Min => { + ca.fill_null_with_values(ca.min_binary().ok_or_else(err_fill_null)?) + }, + FillNullStrategy::Max => { + ca.fill_null_with_values(ca.max_binary().ok_or_else(err_fill_null)?) + }, + FillNullStrategy::Zero => ca.fill_null_with_values(&[]), + FillNullStrategy::Forward(_) => unreachable!(), + FillNullStrategy::Backward(_) => unreachable!(), + strat => polars_bail!(InvalidOperation: "fill-null strategy {:?} is not supported", strat), + } +} + +impl ChunkFillNullValue for ChunkedArray +where + T: PolarsNumericType, +{ + fn fill_null_with_values(&self, value: T::Native) -> PolarsResult { + Ok(self.apply_kernel(&|arr| Box::new(set_at_nulls(arr, value)))) + } +} + +impl ChunkFillNullValue for BooleanChunked { + fn fill_null_with_values(&self, value: bool) -> PolarsResult { + self.set(&self.is_null(), Some(value)) + } +} + +impl ChunkFillNullValue<&[u8]> for BinaryChunked { + fn fill_null_with_values(&self, value: &[u8]) -> PolarsResult { + self.set(&self.is_null(), Some(value)) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs new file mode 100644 index 000000000000..a927f3c6cd99 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -0,0 +1,216 @@ +use polars_compute::filter::filter as filter_fn; + +#[cfg(feature = "object")] +use crate::chunked_array::object::builder::ObjectChunkedBuilder; +use crate::prelude::*; + +macro_rules! check_filter_len { + ($self:expr, $filter:expr) => {{ + polars_ensure!( + $self.len() == $filter.len(), + ShapeMismatch: "filter's length: {} differs from that of the series: {}", + $filter.len(), $self.len() + ) + }}; +} + +impl ChunkFilter for ChunkedArray +where + T: PolarsDataType, +{ + fn filter(&self, filter: &BooleanChunked) -> PolarsResult> { + // Broadcast. + if filter.len() == 1 { + return match filter.get(0) { + Some(true) => Ok(self.clone()), + _ => Ok(self.clear()), + }; + } + check_filter_len!(self, filter); + Ok(unsafe { + arity::binary_unchecked_same_type( + self, + filter, + |left, mask| filter_fn(left, mask), + true, + true, + ) + }) + } +} + +// impl ChunkFilter for BooleanChunked { +// fn filter(&self, filter: &BooleanChunked) -> PolarsResult> { +// // Broadcast. +// if filter.len() == 1 { +// return match filter.get(0) { +// Some(true) => Ok(self.clone()), +// _ => Ok(self.clear()), +// }; +// } +// check_filter_len!(self, filter); +// Ok(unsafe { +// arity::binary_unchecked_same_type( +// self, +// filter, +// |left, mask| filter_fn(left, mask), +// true, +// true, +// ) +// }) +// } +// } + +impl ChunkFilter for StringChunked { + fn filter(&self, filter: &BooleanChunked) -> PolarsResult> { + let out = self.as_binary().filter(filter)?; + unsafe { Ok(out.to_string_unchecked()) } + } +} + +impl ChunkFilter for BinaryChunked { + fn filter(&self, filter: &BooleanChunked) -> PolarsResult> { + // Broadcast. + if filter.len() == 1 { + return match filter.get(0) { + Some(true) => Ok(self.clone()), + _ => Ok(self.clear()), + }; + } + check_filter_len!(self, filter); + Ok(unsafe { + arity::binary_unchecked_same_type( + self, + filter, + |left, mask| filter_fn(left, mask), + true, + true, + ) + }) + } +} + +// impl ChunkFilter for BinaryOffsetChunked { +// fn filter(&self, filter: &BooleanChunked) -> PolarsResult { +// // Broadcast. +// if filter.len() == 1 { +// return match filter.get(0) { +// Some(true) => Ok(self.clone()), +// _ => Ok(self.clear()), +// }; +// } +// check_filter_len!(self, filter); +// Ok(unsafe { +// arity::binary_unchecked_same_type( +// self, +// filter, +// |left, mask| filter_fn(left, mask), +// true, +// true, +// ) +// }) +// } +// } +// +// impl ChunkFilter for ListChunked { +// fn filter(&self, filter: &BooleanChunked) -> PolarsResult { +// // Broadcast. +// if filter.len() == 1 { +// return match filter.get(0) { +// Some(true) => Ok(self.clone()), +// _ => Ok(self.clear()), +// }; +// } +// check_filter_len!(self, filter); +// Ok(unsafe { +// arity::binary_unchecked_same_type( +// self, +// filter, +// |left, mask| filter_fn(left, mask), +// true, +// true, +// ) +// }) +// } +// } +// +// #[cfg(feature = "dtype-struct")] +// impl ChunkFilter for StructChunked { +// fn filter(&self, filter: &BooleanChunked) -> PolarsResult> +// where +// Self: Sized +// { +// if filter.len() == 1 { +// return match filter.get(0) { +// Some(true) => Ok(self.clone()), +// _ => Ok(self.clear()) +// } +// } +// } +// } +// +// #[cfg(feature = "dtype-array")] +// impl ChunkFilter for ArrayChunked { +// fn filter(&self, filter: &BooleanChunked) -> PolarsResult { +// // Broadcast. +// if filter.len() == 1 { +// return match filter.get(0) { +// Some(true) => Ok(self.clone()), +// _ => Ok(ArrayChunked::from_chunk_iter( +// self.name(), +// [FixedSizeListArray::new_empty( +// self.dtype().to_arrow(CompatLevel::newest()), +// )], +// )), +// }; +// } +// check_filter_len!(self, filter); +// Ok(unsafe { +// arity::binary_unchecked_same_type( +// self, +// filter, +// |left, mask| filter_fn(left, mask), +// true, +// true, +// ) +// }) +// } +// } + +#[cfg(feature = "object")] +impl ChunkFilter> for ObjectChunked +where + T: PolarsObject, +{ + fn filter(&self, filter: &BooleanChunked) -> PolarsResult>> + where + Self: Sized, + { + // Broadcast. + if filter.len() == 1 { + return match filter.get(0) { + Some(true) => Ok(self.clone()), + _ => Ok(ObjectChunked::new_empty(self.name().clone())), + }; + } + check_filter_len!(self, filter); + let chunks = self.downcast_iter().collect::>(); + let mut builder = ObjectChunkedBuilder::::new(self.name().clone(), self.len()); + for (idx, mask) in filter.into_iter().enumerate() { + if mask.unwrap_or(false) { + let (chunk_idx, idx) = self.index_to_chunked_index(idx); + unsafe { + let arr = chunks.get_unchecked(chunk_idx); + match arr.is_null(idx) { + true => builder.append_null(), + false => { + let v = arr.value(idx); + builder.append_value(v.clone()) + }, + } + } + } + } + Ok(builder.finish()) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs b/crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs new file mode 100644 index 000000000000..ca4adea9a354 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs @@ -0,0 +1,69 @@ +use num_traits::Float; + +use self::search_sorted::{SearchSortedSide, binary_search_ca}; +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsFloatType, + T::Native: Float, +{ + fn float_arg_max_sorted_ascending(&self) -> usize { + let ca = self; + debug_assert!(ca.is_sorted_ascending_flag()); + + let maybe_max_idx = ca.last_non_null().unwrap(); + let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) }; + if !maybe_max.is_nan() { + return maybe_max_idx; + } + + let search_val = std::iter::once(Some(T::Native::nan())); + let idx = binary_search_ca(ca, search_val, SearchSortedSide::Left, false)[0] as usize; + idx.saturating_sub(1) + } + + fn float_arg_max_sorted_descending(&self) -> usize { + let ca = self; + debug_assert!(ca.is_sorted_descending_flag()); + + let maybe_max_idx = ca.first_non_null().unwrap(); + + let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) }; + if !maybe_max.is_nan() { + return maybe_max_idx; + } + + let search_val = std::iter::once(Some(T::Native::nan())); + let idx = binary_search_ca(ca, search_val, SearchSortedSide::Right, true)[0] as usize; + if idx == ca.len() { idx - 1 } else { idx } + } +} + +/// # Safety +/// `ca` has a float dtype, has at least 1 non-null value and is sorted ascending +pub fn float_arg_max_sorted_ascending(ca: &ChunkedArray) -> usize +where + T: PolarsNumericType, +{ + with_match_physical_float_polars_type!(ca.dtype(), |$T| { + let ca: &ChunkedArray<$T> = unsafe { + &*(ca as *const ChunkedArray as *const ChunkedArray<$T>) + }; + ca.float_arg_max_sorted_ascending() + }) +} + +/// # Safety +/// `ca` has a float dtype, has at least 1 non-null value and is sorted descending +pub fn float_arg_max_sorted_descending(ca: &ChunkedArray) -> usize +where + T: PolarsNumericType, +{ + with_match_physical_float_polars_type!(ca.dtype(), |$T| { + let ca: &ChunkedArray<$T> = unsafe { + &*(ca as *const ChunkedArray as *const ChunkedArray<$T>) + }; + ca.float_arg_max_sorted_descending() + }) +} diff --git a/crates/polars-core/src/chunked_array/ops/for_each.rs b/crates/polars-core/src/chunked_array/ops/for_each.rs new file mode 100644 index 000000000000..42713e0cdff2 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/for_each.rs @@ -0,0 +1,15 @@ +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn for_each<'a, F>(&'a self, mut op: F) + where + F: FnMut(Option>), + { + self.downcast_iter().for_each(|arr| { + arr.iter().for_each(&mut op); + }) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs new file mode 100644 index 000000000000..367c9d000dd7 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -0,0 +1,233 @@ +use arrow::bitmap::Bitmap; + +use crate::chunked_array::builder::get_list_builder; +use crate::prelude::*; +use crate::series::IsSorted; + +impl ChunkFull for ChunkedArray +where + T: PolarsNumericType, +{ + fn full(name: PlSmallStr, value: T::Native, length: usize) -> Self { + let data = vec![value; length]; + let mut out = ChunkedArray::from_vec(name, data); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkFullNull for ChunkedArray +where + T: PolarsNumericType, +{ + fn full_null(name: PlSmallStr, length: usize) -> Self { + let arr = PrimitiveArray::new_null(T::get_dtype().to_arrow(CompatLevel::newest()), length); + ChunkedArray::with_chunk(name, arr) + } +} +impl ChunkFull for BooleanChunked { + fn full(name: PlSmallStr, value: bool, length: usize) -> Self { + let bits = Bitmap::new_with_value(value, length); + let arr = BooleanArray::from_data_default(bits, None); + let mut out = BooleanChunked::with_chunk(name, arr); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkFullNull for BooleanChunked { + fn full_null(name: PlSmallStr, length: usize) -> Self { + let arr = BooleanArray::new_null(ArrowDataType::Boolean, length); + ChunkedArray::with_chunk(name, arr) + } +} + +impl<'a> ChunkFull<&'a str> for StringChunked { + fn full(name: PlSmallStr, value: &'a str, length: usize) -> Self { + let mut builder = StringChunkedBuilder::new(name, length); + builder.chunk_builder.extend_constant(length, Some(value)); + let mut out = builder.finish(); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkFullNull for StringChunked { + fn full_null(name: PlSmallStr, length: usize) -> Self { + let arr = Utf8ViewArray::new_null(DataType::String.to_arrow(CompatLevel::newest()), length); + ChunkedArray::with_chunk(name, arr) + } +} + +impl<'a> ChunkFull<&'a [u8]> for BinaryChunked { + fn full(name: PlSmallStr, value: &'a [u8], length: usize) -> Self { + let mut builder = BinaryChunkedBuilder::new(name, length); + builder.chunk_builder.extend_constant(length, Some(value)); + let mut out = builder.finish(); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkFullNull for BinaryChunked { + fn full_null(name: PlSmallStr, length: usize) -> Self { + let arr = + BinaryViewArray::new_null(DataType::Binary.to_arrow(CompatLevel::newest()), length); + ChunkedArray::with_chunk(name, arr) + } +} + +impl<'a> ChunkFull<&'a [u8]> for BinaryOffsetChunked { + fn full(name: PlSmallStr, value: &'a [u8], length: usize) -> Self { + let mut mutable = MutableBinaryArray::with_capacities(length, length * value.len()); + mutable.extend_values(std::iter::repeat_n(value, length)); + let arr: BinaryArray = mutable.into(); + let mut out = ChunkedArray::with_chunk(name, arr); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkFullNull for BinaryOffsetChunked { + fn full_null(name: PlSmallStr, length: usize) -> Self { + let arr = BinaryArray::::new_null( + DataType::BinaryOffset.to_arrow(CompatLevel::newest()), + length, + ); + ChunkedArray::with_chunk(name, arr) + } +} + +impl ChunkFull<&Series> for ListChunked { + fn full(name: PlSmallStr, value: &Series, length: usize) -> ListChunked { + let mut builder = get_list_builder(value.dtype(), value.len() * length, length, name); + for _ in 0..length { + builder.append_series(value).unwrap(); + } + builder.finish() + } +} + +impl ChunkFullNull for ListChunked { + fn full_null(name: PlSmallStr, length: usize) -> ListChunked { + ListChunked::full_null_with_dtype(name, length, &DataType::Null) + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayChunked { + pub fn full_null_with_dtype( + name: PlSmallStr, + length: usize, + inner_dtype: &DataType, + width: usize, + ) -> ArrayChunked { + let arr = FixedSizeListArray::new_null( + ArrowDataType::FixedSizeList( + Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + inner_dtype.to_physical().to_arrow(CompatLevel::newest()), + true, + )), + width, + ), + length, + ); + // SAFETY: physical type matches the logical. + unsafe { + ChunkedArray::from_chunks_and_dtype( + name, + vec![Box::new(arr)], + DataType::Array(Box::new(inner_dtype.clone()), width), + ) + } + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkFull<&Series> for ArrayChunked { + fn full(name: PlSmallStr, value: &Series, length: usize) -> ArrayChunked { + let width = value.len(); + let dtype = value.dtype(); + let arrow_dtype = ArrowDataType::FixedSizeList( + Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + dtype.to_physical().to_arrow(CompatLevel::newest()), + true, + )), + width, + ); + let value = value.rechunk().chunks()[0].clone(); + let arr = FixedSizeListArray::full(length, value, arrow_dtype); + + // SAFETY: physical type matches the logical. + unsafe { + ChunkedArray::from_chunks_and_dtype( + name, + vec![Box::new(arr)], + DataType::Array(Box::new(dtype.clone()), width), + ) + } + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkFullNull for ArrayChunked { + fn full_null(name: PlSmallStr, length: usize) -> ArrayChunked { + ArrayChunked::full_null_with_dtype(name, length, &DataType::Null, 0) + } +} + +impl ListChunked { + pub fn full_null_with_dtype( + name: PlSmallStr, + length: usize, + inner_dtype: &DataType, + ) -> ListChunked { + let arr: ListArray = ListArray::new_null( + ArrowDataType::LargeList(Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + inner_dtype.to_physical().to_arrow(CompatLevel::newest()), + true, + ))), + length, + ); + // SAFETY: physical type matches the logical. + unsafe { + ChunkedArray::from_chunks_and_dtype( + name, + vec![Box::new(arr)], + DataType::List(Box::new(inner_dtype.clone())), + ) + } + } +} +#[cfg(feature = "dtype-struct")] +impl ChunkFullNull for StructChunked { + fn full_null(name: PlSmallStr, length: usize) -> StructChunked { + StructChunked::from_series(name, length, [].iter()) + .unwrap() + .with_outer_validity(Some(Bitmap::new_zeroed(length))) + } +} + +#[cfg(feature = "object")] +impl ChunkFull for ObjectChunked { + fn full(name: PlSmallStr, value: T, length: usize) -> Self + where + Self: Sized, + { + let mut ca: Self = (0..length).map(|_| Some(value.clone())).collect(); + ca.rename(name); + ca + } +} + +#[cfg(feature = "object")] +impl ChunkFullNull for ObjectChunked { + fn full_null(name: PlSmallStr, length: usize) -> ObjectChunked { + let mut ca: Self = (0..length).map(|_| None).collect(); + ca.rename(name); + ca + } +} diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs new file mode 100644 index 000000000000..735dddc8aaf7 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -0,0 +1,378 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +use polars_compute::gather::take_unchecked; +use polars_error::polars_ensure; +use polars_utils::index::check_bounds; + +use crate::prelude::*; +use crate::series::IsSorted; + +pub fn check_bounds_nulls(idx: &PrimitiveArray, len: IdxSize) -> PolarsResult<()> { + let mask = BitMask::from_bitmap(idx.validity().unwrap()); + + // We iterate in chunks to make the inner loop branch-free. + for (block_idx, block) in idx.values().chunks(32).enumerate() { + let mut in_bounds = 0; + for (i, x) in block.iter().enumerate() { + in_bounds |= ((*x < len) as u32) << i; + } + let m = mask.get_u32(32 * block_idx); + polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds"); + } + Ok(()) +} + +pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> { + let all_valid = indices.downcast_iter().all(|a| { + if a.null_count() == 0 { + check_bounds(a.values(), len).is_ok() + } else { + check_bounds_nulls(a, len).is_ok() + } + }); + polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds"); + Ok(()) +} + +impl + ?Sized> ChunkTake for ChunkedArray +where + ChunkedArray: ChunkTakeUnchecked, +{ + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &I) -> PolarsResult { + check_bounds(indices.as_ref(), self.len() as IdxSize)?; + + // SAFETY: we just checked the indices are valid. + Ok(unsafe { self.take_unchecked(indices) }) + } +} + +impl ChunkTake for ChunkedArray +where + ChunkedArray: ChunkTakeUnchecked, +{ + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &IdxCa) -> PolarsResult { + check_bounds_ca(indices, self.len() as IdxSize)?; + + // SAFETY: we just checked the indices are valid. + Ok(unsafe { self.take_unchecked(indices) }) + } +} + +/// Computes cumulative lengths for efficient branchless binary search +/// lookup. The first element is always 0, and the last length of arrs +/// is always ignored (as we already checked that all indices are +/// in-bounds we don't need to check against the last length). +fn cumulative_lengths(arrs: &[&A]) -> Vec { + let mut ret = Vec::with_capacity(arrs.len()); + let mut cumsum: IdxSize = 0; + for arr in arrs { + ret.push(cumsum); + cumsum = cumsum.checked_add(arr.len().try_into().unwrap()).unwrap(); + } + ret +} + +#[rustfmt::skip] +#[inline] +fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize]) -> (usize, usize) { + let chunk_idx = cumlens.partition_point(|cl| idx >= *cl) - 1; + (chunk_idx, (idx - cumlens[chunk_idx]) as usize) +} + +#[inline] +unsafe fn target_value_unchecked<'a, A: StaticArray>( + targets: &[&'a A], + cumlens: &[IdxSize], + idx: IdxSize, +) -> A::ValueT<'a> { + let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens); + let arr = targets.get_unchecked(chunk_idx); + arr.value_unchecked(arr_idx) +} + +#[inline] +unsafe fn target_get_unchecked<'a, A: StaticArray>( + targets: &[&'a A], + cumlens: &[IdxSize], + idx: IdxSize, +) -> Option> { + let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens); + let arr = targets.get_unchecked(chunk_idx); + arr.get_unchecked(arr_idx) +} + +unsafe fn gather_idx_array_unchecked( + dtype: ArrowDataType, + targets: &[&A], + has_nulls: bool, + indices: &[IdxSize], +) -> A { + let it = indices.iter().copied(); + if targets.len() == 1 { + let target = targets.first().unwrap(); + if has_nulls { + it.map(|i| target.get_unchecked(i as usize)) + .collect_arr_trusted_with_dtype(dtype) + } else if let Some(sl) = target.as_slice() { + // Avoid the Arc overhead from value_unchecked. + it.map(|i| sl.get_unchecked(i as usize).clone()) + .collect_arr_trusted_with_dtype(dtype) + } else { + it.map(|i| target.value_unchecked(i as usize)) + .collect_arr_trusted_with_dtype(dtype) + } + } else { + let cumlens = cumulative_lengths(targets); + if has_nulls { + it.map(|i| target_get_unchecked(targets, &cumlens, i)) + .collect_arr_trusted_with_dtype(dtype) + } else { + it.map(|i| target_value_unchecked(targets, &cumlens, i)) + .collect_arr_trusted_with_dtype(dtype) + } + } +} + +impl + ?Sized> ChunkTakeUnchecked for ChunkedArray +where + T: PolarsDataType, +{ + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let ca = self; + let targets: Vec<_> = ca.downcast_iter().collect(); + let arr = gather_idx_array_unchecked( + ca.dtype().to_arrow(CompatLevel::newest()), + &targets, + ca.null_count() > 0, + indices.as_ref(), + ); + ChunkedArray::from_chunk_iter_like(ca, [arr]) + } +} + +pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted { + use crate::series::IsSorted::*; + match (sorted_arr, sorted_idx) { + (_, Not) => Not, + (Not, _) => Not, + (Ascending, Ascending) => Ascending, + (Ascending, Descending) => Descending, + (Descending, Ascending) => Descending, + (Descending, Descending) => Ascending, + } +} + +impl ChunkTakeUnchecked for ChunkedArray +where + T: PolarsDataType, +{ + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let ca = self; + let targets_have_nulls = ca.null_count() > 0; + let targets: Vec<_> = ca.downcast_iter().collect(); + + let chunks = indices.downcast_iter().map(|idx_arr| { + let dtype = ca.dtype().to_arrow(CompatLevel::newest()); + if idx_arr.null_count() == 0 { + gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values()) + } else if targets.len() == 1 { + let target = targets.first().unwrap(); + if targets_have_nulls { + idx_arr + .iter() + .map(|i| target.get_unchecked(*i? as usize)) + .collect_arr_trusted_with_dtype(dtype) + } else { + idx_arr + .iter() + .map(|i| Some(target.value_unchecked(*i? as usize))) + .collect_arr_trusted_with_dtype(dtype) + } + } else { + let cumlens = cumulative_lengths(&targets); + if targets_have_nulls { + idx_arr + .iter() + .map(|i| target_get_unchecked(&targets, &cumlens, *i?)) + .collect_arr_trusted_with_dtype(dtype) + } else { + idx_arr + .iter() + .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?))) + .collect_arr_trusted_with_dtype(dtype) + } + } + }); + + let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag()); + + out.set_sorted_flag(sorted_flag); + out + } +} + +impl ChunkTakeUnchecked for BinaryChunked { + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let ca = self; + let targets_have_nulls = ca.null_count() > 0; + let targets: Vec<_> = ca.downcast_iter().collect(); + + let chunks = indices.downcast_iter().map(|idx_arr| { + let dtype = ca.dtype().to_arrow(CompatLevel::newest()); + if targets.len() == 1 { + let target = targets.first().unwrap(); + take_unchecked(&**target, idx_arr) + } else { + let cumlens = cumulative_lengths(&targets); + if targets_have_nulls { + let arr: BinaryViewArray = idx_arr + .iter() + .map(|i| target_get_unchecked(&targets, &cumlens, *i?)) + .collect_arr_trusted_with_dtype(dtype); + arr.to_boxed() + } else { + let arr: BinaryViewArray = idx_arr + .iter() + .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?))) + .collect_arr_trusted_with_dtype(dtype); + arr.to_boxed() + } + } + }); + + let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect()); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag()); + out.set_sorted_flag(sorted_flag); + out + } +} + +impl ChunkTakeUnchecked for StringChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let ca = self; + let targets_have_nulls = ca.null_count() > 0; + let targets: Vec<_> = ca.downcast_iter().collect(); + + let chunks = indices.downcast_iter().map(|idx_arr| { + let dtype = ca.dtype().to_arrow(CompatLevel::newest()); + if targets.len() == 1 { + let target = targets.first().unwrap(); + take_unchecked(&**target, idx_arr) + } else { + let cumlens = cumulative_lengths(&targets); + if targets_have_nulls { + let arr: Utf8ViewArray = idx_arr + .iter() + .map(|i| target_get_unchecked(&targets, &cumlens, *i?)) + .collect_arr_trusted_with_dtype(dtype); + arr.to_boxed() + } else { + let arr: Utf8ViewArray = idx_arr + .iter() + .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?))) + .collect_arr_trusted_with_dtype(dtype); + arr.to_boxed() + } + } + }); + + let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect()); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag()); + out.set_sorted_flag(sorted_flag); + out + } +} + +impl + ?Sized> ChunkTakeUnchecked for BinaryChunked { + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&indices) + } +} + +impl + ?Sized> ChunkTakeUnchecked for StringChunked { + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&indices) + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkTakeUnchecked for StructChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let a = self.rechunk(); + let index = indices.rechunk(); + + let chunks = a + .downcast_iter() + .zip(index.downcast_iter()) + .map(|(arr, idx)| take_unchecked(arr, idx)) + .collect::>(); + self.copy_with_chunks(chunks) + } +} + +#[cfg(feature = "dtype-struct")] +impl + ?Sized> ChunkTakeUnchecked for StructChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} + +impl IdxCa { + pub fn with_nullable_idx T>(idx: &[NullableIdxSize], f: F) -> T { + let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted(); + let idx = bytemuck::cast_slice::<_, IdxSize>(idx); + let arr = unsafe { arrow::ffi::mmap::slice(idx) }; + let arr = arr.with_validity_typed(Some(validity)); + let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr); + + f(&ca) + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let chunks = vec![take_unchecked( + self.rechunk().downcast_as_array(), + indices.rechunk().downcast_as_array(), + )]; + self.copy_with_chunks(chunks) + } +} + +#[cfg(feature = "dtype-array")] +impl + ?Sized> ChunkTakeUnchecked for ArrayChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} + +impl ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let chunks = vec![take_unchecked( + self.rechunk().downcast_as_array(), + indices.rechunk().downcast_as_array(), + )]; + self.copy_with_chunks(chunks) + } +} + +impl + ?Sized> ChunkTakeUnchecked for ListChunked { + unsafe fn take_unchecked(&self, indices: &I) -> Self { + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); + self.take_unchecked(&idx) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs new file mode 100644 index 000000000000..7863674dc7c0 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -0,0 +1,660 @@ +//! Traits for miscellaneous operations on ChunkedArray +use arrow::offset::OffsetsBuffer; +use polars_compute::rolling::QuantileMethod; + +use crate::prelude::*; + +pub(crate) mod aggregate; +pub(crate) mod any_value; +pub(crate) mod append; +mod apply; +#[cfg(feature = "approx_unique")] +mod approx_n_unique; +pub mod arity; +mod bit_repr; +mod bits; +#[cfg(feature = "bitwise")] +mod bitwise_reduce; +pub(crate) mod chunkops; +pub(crate) mod compare_inner; +#[cfg(feature = "dtype-decimal")] +mod decimal; +pub(crate) mod downcast; +pub(crate) mod explode; +mod explode_and_offsets; +mod extend; +pub mod fill_null; +mod filter; +pub mod float_sorted_arg_max; +mod for_each; +pub mod full; +pub mod gather; +pub(crate) mod nulls; +mod reverse; +#[cfg(feature = "rolling_window")] +pub(crate) mod rolling_window; +pub mod row_encode; +pub mod search_sorted; +mod set; +mod shift; +pub mod sort; +#[cfg(feature = "algorithm_group_by")] +pub(crate) mod unique; +#[cfg(feature = "zip_with")] +pub mod zip; + +pub use chunkops::_set_check_length; +#[cfg(feature = "serde-lazy")] +use serde::{Deserialize, Serialize}; +pub use sort::options::*; + +use crate::chunked_array::cast::CastOptions; +use crate::series::{BitRepr, IsSorted}; +#[cfg(feature = "reinterpret")] +pub trait Reinterpret { + fn reinterpret_signed(&self) -> Series { + unimplemented!() + } + + fn reinterpret_unsigned(&self) -> Series { + unimplemented!() + } +} + +/// Transmute [`ChunkedArray`] to bit representation. +/// This is useful in hashing context and reduces no. +/// of compiled code paths. +pub(crate) trait ToBitRepr { + fn to_bit_repr(&self) -> BitRepr; +} + +pub trait ChunkAnyValue { + /// Get a single value. Beware this is slow. + /// If you need to use this slightly performant, cast Categorical to UInt32 + /// + /// # Safety + /// Does not do any bounds checking. + unsafe fn get_any_value_unchecked(&self, index: usize) -> AnyValue; + + /// Get a single value. Beware this is slow. + fn get_any_value(&self, index: usize) -> PolarsResult; +} + +/// Explode/flatten a List or String Series +pub trait ChunkExplode { + fn explode(&self) -> PolarsResult { + self.explode_and_offsets().map(|t| t.0) + } + fn offsets(&self) -> PolarsResult>; + fn explode_and_offsets(&self) -> PolarsResult<(Series, OffsetsBuffer)>; +} + +pub trait ChunkBytes { + fn to_byte_slices(&self) -> Vec<&[u8]>; +} + +/// This differs from ChunkWindowCustom and ChunkWindow +/// by not using a fold aggregator, but reusing a `Series` wrapper and calling `Series` aggregators. +/// This likely is a bit slower than ChunkWindow +#[cfg(feature = "rolling_window")] +pub trait ChunkRollApply: AsRefDataType { + fn rolling_map( + &self, + _f: &dyn Fn(&Series) -> Series, + _options: RollingOptionsFixedWindow, + ) -> PolarsResult + where + Self: Sized, + { + polars_bail!(opq = rolling_map, self.as_ref_dtype()); + } +} + +pub trait ChunkTake: ChunkTakeUnchecked { + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &Idx) -> PolarsResult + where + Self: Sized; +} + +pub trait ChunkTakeUnchecked { + /// Gather values from ChunkedArray by index. + /// + /// # Safety + /// The non-null indices must be valid. + unsafe fn take_unchecked(&self, indices: &Idx) -> Self; +} + +/// Create a `ChunkedArray` with new values by index or by boolean mask. +/// +/// Note that these operations clone data. This is however the only way we can modify at mask or +/// index level as the underlying Arrow arrays are immutable. +pub trait ChunkSet<'a, A, B> { + /// Set the values at indexes `idx` to some optional value `Option`. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let ca = UInt32Chunked::new("a".into(), &[1, 2, 3]); + /// let new = ca.scatter_single(vec![0, 1], Some(10)).unwrap(); + /// + /// assert_eq!(Vec::from(&new), &[Some(10), Some(10), Some(3)]); + /// ``` + fn scatter_single>( + &'a self, + idx: I, + opt_value: Option, + ) -> PolarsResult + where + Self: Sized; + + /// Set the values at indexes `idx` by applying a closure to these values. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let ca = Int32Chunked::new("a".into(), &[1, 2, 3]); + /// let new = ca.scatter_with(vec![0, 1], |opt_v| opt_v.map(|v| v - 5)).unwrap(); + /// + /// assert_eq!(Vec::from(&new), &[Some(-4), Some(-3), Some(3)]); + /// ``` + fn scatter_with, F>( + &'a self, + idx: I, + f: F, + ) -> PolarsResult + where + Self: Sized, + F: Fn(Option) -> Option; + /// Set the values where the mask evaluates to `true` to some optional value `Option`. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let ca = Int32Chunked::new("a".into(), &[1, 2, 3]); + /// let mask = BooleanChunked::new("mask".into(), &[false, true, false]); + /// let new = ca.set(&mask, Some(5)).unwrap(); + /// assert_eq!(Vec::from(&new), &[Some(1), Some(5), Some(3)]); + /// ``` + fn set(&'a self, mask: &BooleanChunked, opt_value: Option) -> PolarsResult + where + Self: Sized; +} + +/// Cast `ChunkedArray` to `ChunkedArray` +pub trait ChunkCast { + /// Cast a [`ChunkedArray`] to [`DataType`] + fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } + + /// Cast a [`ChunkedArray`] to [`DataType`] + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult; + + /// Does not check if the cast is a valid one and may over/underflow + /// + /// # Safety + /// - This doesn't do utf8 validation checking when casting from binary + /// - This doesn't do categorical bound checking when casting from UInt32 + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult; +} + +/// Fastest way to do elementwise operations on a [`ChunkedArray`] when the operation is cheaper than +/// branching due to null checking. +pub trait ChunkApply<'a, T> { + type FuncRet; + + /// Apply a closure elementwise. This is fastest when the null check branching is more expensive + /// than the closure application. Often it is. + /// + /// Null values remain null. + /// + /// # Example + /// + /// ``` + /// use polars_core::prelude::*; + /// fn double(ca: &UInt32Chunked) -> UInt32Chunked { + /// ca.apply_values(|v| v * 2) + /// } + /// ``` + #[must_use] + fn apply_values(&'a self, f: F) -> Self + where + F: Fn(T) -> Self::FuncRet + Copy; + + /// Apply a closure elementwise including null values. + #[must_use] + fn apply(&'a self, f: F) -> Self + where + F: Fn(Option) -> Option + Copy; + + /// Apply a closure elementwise and write results to a mutable slice. + fn apply_to_slice(&'a self, f: F, slice: &mut [S]) + // (value of chunkedarray, value of slice) -> value of slice + where + F: Fn(Option, &S) -> S; +} + +/// Aggregation operations. +pub trait ChunkAgg { + /// Aggregate the sum of the ChunkedArray. + /// Returns `None` if not implemented for `T`. + /// If the array is empty, `0` is returned + fn sum(&self) -> Option { + None + } + + fn _sum_as_f64(&self) -> f64; + + fn min(&self) -> Option { + None + } + + /// Returns the maximum value in the array, according to the natural order. + /// Returns `None` if the array is empty or only contains null values. + fn max(&self) -> Option { + None + } + + fn min_max(&self) -> Option<(T, T)> { + Some((self.min()?, self.max()?)) + } + + /// Returns the mean value in the array. + /// Returns `None` if the array is empty or only contains null values. + fn mean(&self) -> Option { + None + } +} + +/// Quantile and median aggregation. +pub trait ChunkQuantile { + /// Returns the mean value in the array. + /// Returns `None` if the array is empty or only contains null values. + fn median(&self) -> Option { + None + } + /// Aggregate a given quantile of the ChunkedArray. + /// Returns `None` if the array is empty or only contains null values. + fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult> { + Ok(None) + } +} + +/// Variance and standard deviation aggregation. +pub trait ChunkVar { + /// Compute the variance of this ChunkedArray/Series. + fn var(&self, _ddof: u8) -> Option { + None + } + + /// Compute the standard deviation of this ChunkedArray/Series. + fn std(&self, _ddof: u8) -> Option { + None + } +} + +/// Bitwise Reduction Operations. +#[cfg(feature = "bitwise")] +pub trait ChunkBitwiseReduce { + type Physical; + + fn and_reduce(&self) -> Option; + fn or_reduce(&self) -> Option; + fn xor_reduce(&self) -> Option; +} + +/// Compare [`Series`] and [`ChunkedArray`]'s and get a `boolean` mask that +/// can be used to filter rows. +/// +/// # Example +/// +/// ``` +/// use polars_core::prelude::*; +/// fn filter_all_ones(df: &DataFrame) -> PolarsResult { +/// let mask = df +/// .column("column_a")? +/// .as_materialized_series() +/// .equal(1)?; +/// +/// df.filter(&mask) +/// } +/// ``` +pub trait ChunkCompareEq { + type Item; + + /// Check for equality. + fn equal(&self, rhs: Rhs) -> Self::Item; + + /// Check for equality where `None == None`. + fn equal_missing(&self, rhs: Rhs) -> Self::Item; + + /// Check for inequality. + fn not_equal(&self, rhs: Rhs) -> Self::Item; + + /// Check for inequality where `None == None`. + fn not_equal_missing(&self, rhs: Rhs) -> Self::Item; +} + +/// Compare [`Series`] and [`ChunkedArray`]'s using inequality operators (`<`, `>=`, etc.) and get +/// a `boolean` mask that can be used to filter rows. +pub trait ChunkCompareIneq { + type Item; + + /// Greater than comparison. + fn gt(&self, rhs: Rhs) -> Self::Item; + + /// Greater than or equal comparison. + fn gt_eq(&self, rhs: Rhs) -> Self::Item; + + /// Less than comparison. + fn lt(&self, rhs: Rhs) -> Self::Item; + + /// Less than or equal comparison + fn lt_eq(&self, rhs: Rhs) -> Self::Item; +} + +/// Get unique values in a `ChunkedArray` +pub trait ChunkUnique { + // We don't return Self to be able to use AutoRef specialization + /// Get unique values of a ChunkedArray + fn unique(&self) -> PolarsResult + where + Self: Sized; + + /// Get first index of the unique values in a `ChunkedArray`. + /// This Vec is sorted. + fn arg_unique(&self) -> PolarsResult; + + /// Number of unique values in the `ChunkedArray` + fn n_unique(&self) -> PolarsResult { + self.arg_unique().map(|v| v.len()) + } +} + +#[cfg(feature = "approx_unique")] +pub trait ChunkApproxNUnique { + fn approx_n_unique(&self) -> IdxSize; +} + +/// Sort operations on `ChunkedArray`. +pub trait ChunkSort { + #[allow(unused_variables)] + fn sort_with(&self, options: SortOptions) -> ChunkedArray; + + /// Returned a sorted `ChunkedArray`. + fn sort(&self, descending: bool) -> ChunkedArray; + + /// Retrieve the indexes needed to sort this array. + fn arg_sort(&self, options: SortOptions) -> IdxCa; + + /// Retrieve the indexes need to sort this and the other arrays. + #[allow(unused_variables)] + fn arg_sort_multiple( + &self, + by: &[Column], + _options: &SortMultipleOptions, + ) -> PolarsResult { + polars_bail!(opq = arg_sort_multiple, T::get_dtype()); + } +} + +pub type FillNullLimit = Option; + +#[derive(Copy, Clone, Debug, PartialEq, Hash)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +pub enum FillNullStrategy { + /// previous value in array + Backward(FillNullLimit), + /// next value in array + Forward(FillNullLimit), + /// mean value of array + Mean, + /// minimal value in array + Min, + /// maximum value in array + Max, + /// replace with the value zero + Zero, + /// replace with the value one + One, +} + +impl FillNullStrategy { + pub fn is_elementwise(&self) -> bool { + matches!(self, Self::One | Self::Zero) + } +} + +/// Replace None values with a value +pub trait ChunkFillNullValue { + /// Replace None values with a give value `T`. + fn fill_null_with_values(&self, value: T) -> PolarsResult + where + Self: Sized; +} + +/// Fill a ChunkedArray with one value. +pub trait ChunkFull { + /// Create a ChunkedArray with a single value. + fn full(name: PlSmallStr, value: T, length: usize) -> Self + where + Self: Sized; +} + +pub trait ChunkFullNull { + fn full_null(_name: PlSmallStr, _length: usize) -> Self + where + Self: Sized; +} + +/// Reverse a [`ChunkedArray`] +pub trait ChunkReverse { + /// Return a reversed version of this array. + fn reverse(&self) -> Self; +} + +/// Filter values by a boolean mask. +pub trait ChunkFilter { + /// Filter values in the ChunkedArray with a boolean mask. + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let array = Int32Chunked::new("array".into(), &[1, 2, 3]); + /// let mask = BooleanChunked::new("mask".into(), &[true, false, true]); + /// + /// let filtered = array.filter(&mask).unwrap(); + /// assert_eq!(Vec::from(&filtered), [Some(1), Some(3)]) + /// ``` + fn filter(&self, filter: &BooleanChunked) -> PolarsResult> + where + Self: Sized; +} + +/// Create a new ChunkedArray filled with values at that index. +pub trait ChunkExpandAtIndex { + /// Create a new ChunkedArray filled with values at that index. + fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray; +} + +macro_rules! impl_chunk_expand { + ($self:ident, $length:ident, $index:ident) => {{ + if $self.is_empty() { + return $self.clone(); + } + let opt_val = $self.get($index); + match opt_val { + Some(val) => ChunkedArray::full($self.name().clone(), val, $length), + None => ChunkedArray::full_null($self.name().clone(), $length), + } + }}; +} + +impl ChunkExpandAtIndex for ChunkedArray +where + ChunkedArray: ChunkFull, +{ + fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray { + let mut out = impl_chunk_expand!(self, length, index); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkExpandAtIndex for BooleanChunked { + fn new_from_index(&self, index: usize, length: usize) -> BooleanChunked { + let mut out = impl_chunk_expand!(self, length, index); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkExpandAtIndex for StringChunked { + fn new_from_index(&self, index: usize, length: usize) -> StringChunked { + let mut out = impl_chunk_expand!(self, length, index); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkExpandAtIndex for BinaryChunked { + fn new_from_index(&self, index: usize, length: usize) -> BinaryChunked { + let mut out = impl_chunk_expand!(self, length, index); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkExpandAtIndex for BinaryOffsetChunked { + fn new_from_index(&self, index: usize, length: usize) -> BinaryOffsetChunked { + let mut out = impl_chunk_expand!(self, length, index); + out.set_sorted_flag(IsSorted::Ascending); + out + } +} + +impl ChunkExpandAtIndex for ListChunked { + fn new_from_index(&self, index: usize, length: usize) -> ListChunked { + let opt_val = self.get_as_series(index); + match opt_val { + Some(val) => { + let mut ca = ListChunked::full(self.name().clone(), &val, length); + unsafe { ca.to_logical(self.inner_dtype().clone()) }; + ca + }, + None => { + ListChunked::full_null_with_dtype(self.name().clone(), length, self.inner_dtype()) + }, + } + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkExpandAtIndex for StructChunked { + fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray { + let (chunk_idx, idx) = self.index_to_chunked_index(index); + let chunk = self.downcast_chunks().get(chunk_idx).unwrap(); + let chunk = if chunk.is_null(idx) { + new_null_array(chunk.dtype().clone(), length) + } else { + let values = chunk + .values() + .iter() + .map(|arr| { + let s = Series::try_from((PlSmallStr::EMPTY, arr.clone())).unwrap(); + let s = s.new_from_index(idx, length); + s.chunks()[0].clone() + }) + .collect::>(); + + StructArray::new(chunk.dtype().clone(), length, values, None).boxed() + }; + + // SAFETY: chunks are from self. + unsafe { self.copy_with_chunks(vec![chunk]) } + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkExpandAtIndex for ArrayChunked { + fn new_from_index(&self, index: usize, length: usize) -> ArrayChunked { + let opt_val = self.get_as_series(index); + match opt_val { + Some(val) => { + let mut ca = ArrayChunked::full(self.name().clone(), &val, length); + unsafe { ca.to_logical(self.inner_dtype().clone()) }; + ca + }, + None => ArrayChunked::full_null_with_dtype( + self.name().clone(), + length, + self.inner_dtype(), + self.width(), + ), + } + } +} + +#[cfg(feature = "object")] +impl ChunkExpandAtIndex> for ObjectChunked { + fn new_from_index(&self, index: usize, length: usize) -> ObjectChunked { + let opt_val = self.get(index); + match opt_val { + Some(val) => ObjectChunked::::full(self.name().clone(), val.clone(), length), + None => ObjectChunked::::full_null(self.name().clone(), length), + } + } +} + +/// Shift the values of a [`ChunkedArray`] by a number of periods. +pub trait ChunkShiftFill { + /// Shift the values by a given period and fill the parts that will be empty due to this operation + /// with `fill_value`. + fn shift_and_fill(&self, periods: i64, fill_value: V) -> ChunkedArray; +} + +pub trait ChunkShift { + fn shift(&self, periods: i64) -> ChunkedArray; +} + +/// Combine two [`ChunkedArray`] based on some predicate. +pub trait ChunkZip { + /// Create a new ChunkedArray with values from self where the mask evaluates `true` and values + /// from `other` where the mask evaluates `false` + fn zip_with( + &self, + mask: &BooleanChunked, + other: &ChunkedArray, + ) -> PolarsResult>; +} + +/// Apply kernels on the arrow array chunks in a ChunkedArray. +pub trait ChunkApplyKernel { + /// Apply kernel and return result as a new ChunkedArray. + #[must_use] + fn apply_kernel(&self, f: &dyn Fn(&A) -> ArrayRef) -> Self; + + /// Apply a kernel that outputs an array of different type. + fn apply_kernel_cast(&self, f: &dyn Fn(&A) -> ArrayRef) -> ChunkedArray + where + S: PolarsDataType; +} + +#[cfg(feature = "is_first_distinct")] +/// Mask the first unique values as `true` +pub trait IsFirstDistinct { + fn is_first_distinct(&self) -> PolarsResult { + polars_bail!(opq = is_first_distinct, T::get_dtype()); + } +} + +#[cfg(feature = "is_last_distinct")] +/// Mask the last unique values as `true` +pub trait IsLastDistinct { + fn is_last_distinct(&self) -> PolarsResult { + polars_bail!(opq = is_last_distinct, T::get_dtype()); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/nulls.rs b/crates/polars-core/src/chunked_array/ops/nulls.rs new file mode 100644 index 000000000000..96a2d2e8cd96 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/nulls.rs @@ -0,0 +1,85 @@ +use arrow::bitmap::Bitmap; + +use super::*; +use crate::chunked_array::flags::StatisticsFlags; + +impl ChunkedArray { + /// Get a mask of the null values. + pub fn is_null(&self) -> BooleanChunked { + if !self.has_nulls() { + return BooleanChunked::full(self.name().clone(), false, self.len()); + } + // dispatch to non-generic function + is_null(self.name().clone(), &self.chunks) + } + + /// Get a mask of the valid values. + pub fn is_not_null(&self) -> BooleanChunked { + if self.null_count() == 0 { + return BooleanChunked::full(self.name().clone(), true, self.len()); + } + // dispatch to non-generic function + is_not_null(self.name().clone(), &self.chunks) + } + + pub(crate) fn coalesce_nulls(&self, other: &[ArrayRef]) -> Self { + let chunks = coalesce_nulls(&self.chunks, other); + let mut ca = unsafe { self.copy_with_chunks(chunks) }; + use StatisticsFlags as F; + ca.retain_flags_from(self, F::IS_SORTED_ANY); + ca + } +} + +pub fn is_not_null(name: PlSmallStr, chunks: &[ArrayRef]) -> BooleanChunked { + let chunks = chunks.iter().map(|arr| { + let bitmap = arr + .validity() + .cloned() + .unwrap_or_else(|| !(&Bitmap::new_zeroed(arr.len()))); + BooleanArray::from_data_default(bitmap, None) + }); + BooleanChunked::from_chunk_iter(name, chunks) +} + +pub fn is_null(name: PlSmallStr, chunks: &[ArrayRef]) -> BooleanChunked { + let chunks = chunks.iter().map(|arr| { + let bitmap = arr + .validity() + .map(|bitmap| !bitmap) + .unwrap_or_else(|| Bitmap::new_zeroed(arr.len())); + BooleanArray::from_data_default(bitmap, None) + }); + BooleanChunked::from_chunk_iter(name, chunks) +} + +pub fn replace_non_null(name: PlSmallStr, chunks: &[ArrayRef], default: bool) -> BooleanChunked { + BooleanChunked::from_chunk_iter( + name, + chunks.iter().map(|el| { + BooleanArray::from_data_default( + Bitmap::new_with_value(default, el.len()), + el.validity().cloned(), + ) + }), + ) +} + +pub(crate) fn coalesce_nulls(chunks: &[ArrayRef], other: &[ArrayRef]) -> Vec { + assert_eq!(chunks.len(), other.len()); + chunks + .iter() + .zip(other) + .map(|(a, b)| { + assert_eq!(a.len(), b.len()); + let validity = match (a.validity(), b.validity()) { + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some(a & b), + (Some(a), None) => Some(a.clone()), + (None, None) => None, + }; + + a.with_validity(validity) + }) + .collect() +} diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs new file mode 100644 index 000000000000..ff8209b3e892 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -0,0 +1,131 @@ +#[cfg(feature = "dtype-array")] +use crate::chunked_array::builder::get_fixed_size_list_builder; +use crate::prelude::*; +use crate::series::IsSorted; +use crate::utils::NoNull; + +impl ChunkReverse for ChunkedArray +where + T: PolarsNumericType, +{ + fn reverse(&self) -> ChunkedArray { + let mut out = if let Ok(slice) = self.cont_slice() { + let ca: NoNull> = slice.iter().rev().copied().collect_trusted(); + ca.into_inner() + } else { + self.into_iter().rev().collect_trusted() + }; + out.rename(self.name().clone()); + + match self.is_sorted_flag() { + IsSorted::Ascending => out.set_sorted_flag(IsSorted::Descending), + IsSorted::Descending => out.set_sorted_flag(IsSorted::Ascending), + _ => {}, + } + + out + } +} + +macro_rules! impl_reverse { + ($arrow_type:ident, $ca_type:ident) => { + impl ChunkReverse for $ca_type { + fn reverse(&self) -> Self { + let mut ca: Self = self.into_iter().rev().collect_trusted(); + ca.rename(self.name().clone()); + ca + } + } + }; +} + +impl_reverse!(BooleanType, BooleanChunked); +impl_reverse!(BinaryOffsetType, BinaryOffsetChunked); +impl_reverse!(ListType, ListChunked); + +impl ChunkReverse for BinaryChunked { + fn reverse(&self) -> Self { + if self.chunks.len() == 1 { + let arr = self.downcast_iter().next().unwrap(); + let views = arr.views().iter().copied().rev().collect::>(); + + unsafe { + let arr = BinaryViewArray::new_unchecked( + arr.dtype().clone(), + views.into(), + arr.data_buffers().clone(), + arr.validity().map(|bitmap| bitmap.iter().rev().collect()), + arr.total_bytes_len(), + arr.total_buffer_len(), + ) + .boxed(); + BinaryChunked::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![arr], + self.dtype().clone(), + ) + } + } else { + let ca = IdxCa::from_vec( + PlSmallStr::EMPTY, + (0..self.len() as IdxSize).rev().collect(), + ); + unsafe { self.take_unchecked(&ca) } + } + } +} + +impl ChunkReverse for StringChunked { + fn reverse(&self) -> Self { + unsafe { self.as_binary().reverse().to_string_unchecked() } + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkReverse for ArrayChunked { + fn reverse(&self) -> Self { + if !self.inner_dtype().is_primitive_numeric() { + todo!("reverse for FixedSizeList with non-numeric dtypes not yet supported") + } + let ca = self.rechunk(); + let arr = ca.downcast_as_array(); + let values = arr.values().as_ref(); + + let mut builder = + get_fixed_size_list_builder(ca.inner_dtype(), ca.len(), ca.width(), ca.name().clone()) + .expect("not yet supported"); + + // SAFETY, we are within bounds + unsafe { + if arr.null_count() == 0 { + for i in (0..arr.len()).rev() { + builder.push_unchecked(values, i) + } + } else { + let validity = arr.validity().unwrap(); + for i in (0..arr.len()).rev() { + if validity.get_bit_unchecked(i) { + builder.push_unchecked(values, i) + } else { + builder.push_null() + } + } + } + } + builder.finish() + } +} + +#[cfg(feature = "object")] +impl ChunkReverse for ObjectChunked { + fn reverse(&self) -> Self { + // SAFETY: we know we don't go out of bounds. + unsafe { + self.take_unchecked( + &(0..self.len() as IdxSize) + .rev() + .collect_ca(PlSmallStr::EMPTY), + ) + } + } +} diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs new file mode 100644 index 000000000000..ea8a69ef553e --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -0,0 +1,283 @@ +use std::hash::{Hash, Hasher}; + +use polars_compute::rolling::RollingFnParams; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "rolling_window", derive(PartialEq))] +pub struct RollingOptionsFixedWindow { + /// The length of the window. + pub window_size: usize, + /// Amount of elements in the window that should be filled before computing a result. + pub min_periods: usize, + /// An optional slice with the same length as the window that will be multiplied + /// elementwise with the values in the window. + pub weights: Option>, + /// Set the labels at the center of the window. + pub center: bool, + /// Optional parameters for the rolling + #[cfg_attr(feature = "serde", serde(default))] + pub fn_params: Option, +} + +impl Hash for RollingOptionsFixedWindow { + fn hash(&self, state: &mut H) { + self.window_size.hash(state); + self.min_periods.hash(state); + self.center.hash(state); + self.weights.is_some().hash(state); + } +} + +impl Default for RollingOptionsFixedWindow { + fn default() -> Self { + RollingOptionsFixedWindow { + window_size: 3, + min_periods: 1, + weights: None, + center: false, + fn_params: None, + } + } +} + +#[cfg(feature = "rolling_window")] +mod inner_mod { + use std::ops::SubAssign; + + use arrow::bitmap::MutableBitmap; + use arrow::bitmap::utils::set_bit_unchecked; + use arrow::legacy::trusted_len::TrustedLenPush; + use num_traits::pow::Pow; + use num_traits::{Float, Zero}; + use polars_utils::float::IsFloat; + + use crate::chunked_array::cast::CastOptions; + use crate::prelude::*; + + /// utility + fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> { + polars_ensure!( + min_periods <= window_size, + ComputeError: "`window_size`: {} should be >= `min_periods`: {}", + window_size, min_periods + ); + Ok(()) + } + + /// utility + fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) { + let (start, end) = if center { + let right_window = window_size.div_ceil(2); + ( + idx.saturating_sub(window_size - right_window), + len.min(idx + right_window), + ) + } else { + (idx.saturating_sub(window_size - 1), idx + 1) + }; + + (start, end - start) + } + + impl ChunkRollApply for ChunkedArray + where + T: PolarsNumericType, + Self: IntoSeries, + { + /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch. + fn rolling_map( + &self, + f: &dyn Fn(&Series) -> Series, + mut options: RollingOptionsFixedWindow, + ) -> PolarsResult { + check_input(options.window_size, options.min_periods)?; + + let ca = self.rechunk(); + if options.weights.is_some() + && !matches!(self.dtype(), DataType::Float64 | DataType::Float32) + { + let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?; + return s.rolling_map(f, options); + } + + options.window_size = std::cmp::min(self.len(), options.window_size); + + let len = self.len(); + let arr = ca.downcast_as_array(); + let mut ca = ChunkedArray::::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]); + let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray; + let mut series_container = ca.into_series(); + + let mut builder = PrimitiveChunkedBuilder::::new(self.name().clone(), self.len()); + + if let Some(weights) = options.weights { + let weights_series = + Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series(); + + let weights_series = weights_series.cast(self.dtype()).unwrap(); + + for idx in 0..len { + let (start, size) = window_edges(idx, len, options.window_size, options.center); + + if size < options.min_periods { + builder.append_null(); + } else { + // SAFETY: + // we are in bounds + let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; + + // ensure we still meet window size criteria after removing null values + if size - arr_window.null_count() < options.min_periods { + builder.append_null(); + continue; + } + + // SAFETY. + // ptr is not dropped as we are in scope + // We are also the only owner of the contents of the Arc + // we do this to reduce heap allocs. + unsafe { + *ptr = arr_window; + } + // reset flags as we reuse this container + series_container.clear_flags(); + // ensure the length is correct + series_container._get_inner_mut().compute_len(); + let s = if size == options.window_size { + f(&series_container.multiply(&weights_series).unwrap()) + } else { + let weights_cutoff: Series = match self.dtype() { + DataType::Float64 => weights_series + .f64() + .unwrap() + .into_iter() + .take(series_container.len()) + .collect(), + _ => weights_series // Float32 case + .f32() + .unwrap() + .into_iter() + .take(series_container.len()) + .collect(), + }; + f(&series_container.multiply(&weights_cutoff).unwrap()) + }; + + let out = self.unpack_series_matching_type(&s)?; + builder.append_option(out.get(0)); + } + } + + Ok(builder.finish().into_series()) + } else { + for idx in 0..len { + let (start, size) = window_edges(idx, len, options.window_size, options.center); + + if size < options.min_periods { + builder.append_null(); + } else { + // SAFETY: + // we are in bounds + let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; + + // ensure we still meet window size criteria after removing null values + if size - arr_window.null_count() < options.min_periods { + builder.append_null(); + continue; + } + + // SAFETY. + // ptr is not dropped as we are in scope + // We are also the only owner of the contents of the Arc + // we do this to reduce heap allocs. + unsafe { + *ptr = arr_window; + } + // reset flags as we reuse this container + series_container.clear_flags(); + // ensure the length is correct + series_container._get_inner_mut().compute_len(); + let s = f(&series_container); + let out = self.unpack_series_matching_type(&s)?; + builder.append_option(out.get(0)); + } + } + + Ok(builder.finish().into_series()) + } + } + } + + impl ChunkedArray + where + ChunkedArray: IntoSeries, + T: PolarsFloatType, + T::Native: Float + IsFloat + SubAssign + Pow, + { + /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch. + pub fn rolling_map_float(&self, window_size: usize, mut f: F) -> PolarsResult + where + F: FnMut(&mut ChunkedArray) -> Option, + { + if window_size > self.len() { + return Ok(Self::full_null(self.name().clone(), self.len())); + } + let ca = self.rechunk(); + let arr = ca.downcast_as_array(); + + // We create a temporary dummy ChunkedArray. This will be a + // container where we swap the window contents every iteration doing + // so will save a lot of heap allocations. + let mut heap_container = + ChunkedArray::::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]); + let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array + as *mut PrimitiveArray; + + let mut validity = MutableBitmap::with_capacity(ca.len()); + validity.extend_constant(window_size - 1, false); + validity.extend_constant(ca.len() - (window_size - 1), true); + let validity_slice = validity.as_mut_slice(); + + let mut values = Vec::with_capacity(ca.len()); + values.extend(std::iter::repeat_n(T::Native::default(), window_size - 1)); + + for offset in 0..self.len() + 1 - window_size { + debug_assert!(offset + window_size <= arr.len()); + let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) }; + // The lengths are cached, so we must update them. + heap_container.length = arr_window.len(); + + // SAFETY: ptr is not dropped as we are in scope. We are also the only + // owner of the contents of the Arc (we do this to reduce heap allocs). + unsafe { + *ptr = arr_window; + } + + let out = f(&mut heap_container); + match out { + Some(v) => { + // SAFETY: we have pre-allocated. + unsafe { values.push_unchecked(v) } + }, + None => { + // SAFETY: we allocated enough for both the `values` vec + // and the `validity_ptr`. + unsafe { + values.push_unchecked(T::Native::default()); + set_bit_unchecked(validity_slice, offset + window_size - 1, false); + } + }, + } + } + let arr = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + values.into(), + Some(validity.into()), + ); + Ok(Self::with_chunk(self.name().clone(), arr)) + } + } +} diff --git a/crates/polars-core/src/chunked_array/ops/row_encode.rs b/crates/polars-core/src/chunked_array/ops/row_encode.rs new file mode 100644 index 000000000000..6d6273959530 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/row_encode.rs @@ -0,0 +1,300 @@ +use arrow::compute::utils::combine_validities_and_many; +use polars_row::{ + RowEncodingCategoricalContext, RowEncodingContext, RowEncodingOptions, RowsEncoded, + convert_columns, +}; +use polars_utils::itertools::Itertools; +use rayon::prelude::*; + +use crate::POOL; +use crate::prelude::*; +use crate::utils::_split_offsets; + +pub fn encode_rows_vertical_par_unordered(by: &[Column]) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + Ok(rows.into_array()) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +// Almost the same but broadcast nulls to the row-encoded array. +pub fn encode_rows_vertical_par_unordered_broadcast_nulls( + by: &[Column], +) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + + let validities = sliced + .iter() + .flat_map(|s| { + let s = s.rechunk(); + #[allow(clippy::unnecessary_to_owned)] + s.as_materialized_series() + .chunks() + .to_vec() + .into_iter() + .map(|arr| arr.validity().cloned()) + }) + .collect::>(); + + let validity = combine_validities_and_many(&validities); + Ok(rows.into_array().with_validity_typed(validity)) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +/// Get the [`RowEncodingContext`] for a certain [`DataType`]. +/// +/// This should be given the logical type in order to communicate Polars datatype information down +/// into the row encoding / decoding. +pub fn get_row_encoding_context(dtype: &DataType, ordered: bool) -> Option { + match dtype { + DataType::Boolean + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Int128 + | DataType::Float32 + | DataType::Float64 + | DataType::String + | DataType::Binary + | DataType::BinaryOffset + | DataType::Null + | DataType::Time + | DataType::Date + | DataType::Datetime(_, _) + | DataType::Duration(_) => None, + + DataType::Unknown(_) => panic!("Unsupported in row encoding"), + + #[cfg(feature = "object")] + DataType::Object(_) => panic!("Unsupported in row encoding"), + + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, _) => { + Some(RowEncodingContext::Decimal(precision.unwrap_or(38))) + }, + + #[cfg(feature = "dtype-array")] + DataType::Array(dtype, _) => get_row_encoding_context(dtype, ordered), + DataType::List(dtype) => get_row_encoding_context(dtype, ordered), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(revmap, ordering) | DataType::Enum(revmap, ordering) => { + let is_enum = dtype.is_enum(); + let ctx = match revmap { + Some(revmap) => { + let (num_known_categories, lexical_sort_idxs) = match revmap.as_ref() { + RevMapping::Global(map, _, _) => { + let num_known_categories = + map.keys().max().copied().map_or(0, |m| m + 1); + + // @TODO: This should probably be cached. + let lexical_sort_idxs = (ordered + && matches!(ordering, CategoricalOrdering::Lexical)) + .then(|| { + let read_map = crate::STRING_CACHE.read_map(); + let payloads = read_map.get_current_payloads(); + assert!(payloads.len() >= num_known_categories as usize); + + let mut idxs = (0..num_known_categories).collect::>(); + idxs.sort_by_key(|&k| payloads[k as usize].as_str()); + let mut sort_idxs = vec![0; num_known_categories as usize]; + for (i, idx) in idxs.into_iter().enumerate_u32() { + sort_idxs[idx as usize] = i; + } + sort_idxs + }); + + (num_known_categories, lexical_sort_idxs) + }, + RevMapping::Local(values, _) => { + // @TODO: This should probably be cached. + let lexical_sort_idxs = (ordered + && matches!(ordering, CategoricalOrdering::Lexical)) + .then(|| { + assert_eq!(values.null_count(), 0); + let values: Vec<&str> = values.values_iter().collect(); + + let mut idxs = (0..values.len() as u32).collect::>(); + idxs.sort_by_key(|&k| values[k as usize]); + let mut sort_idxs = vec![0; values.len()]; + for (i, idx) in idxs.into_iter().enumerate_u32() { + sort_idxs[idx as usize] = i; + } + sort_idxs + }); + + (values.len() as u32, lexical_sort_idxs) + }, + }; + + RowEncodingCategoricalContext { + num_known_categories, + is_enum, + lexical_sort_idxs, + } + }, + None => { + let num_known_categories = u32::MAX; + + if matches!(ordering, CategoricalOrdering::Lexical) && ordered { + panic!("lexical ordering not yet supported if rev-map not given"); + } + RowEncodingCategoricalContext { + num_known_categories, + is_enum, + lexical_sort_idxs: None, + } + }, + }; + + Some(RowEncodingContext::Categorical(ctx)) + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fs) => { + let mut ctxts = Vec::new(); + + for (i, f) in fs.iter().enumerate() { + if let Some(ctxt) = get_row_encoding_context(f.dtype(), ordered) { + ctxts.reserve(fs.len()); + ctxts.extend(std::iter::repeat_n(None, i)); + ctxts.push(Some(ctxt)); + break; + } + } + + if ctxts.is_empty() { + return None; + } + + ctxts.extend( + fs[ctxts.len()..] + .iter() + .map(|f| get_row_encoding_context(f.dtype(), ordered)), + ); + + Some(RowEncodingContext::Struct(ctxts)) + }, + } +} + +pub fn encode_rows_unordered(by: &[Column]) -> PolarsResult { + let rows = _get_rows_encoded_unordered(by)?; + Ok(BinaryOffsetChunked::with_chunk( + PlSmallStr::EMPTY, + rows.into_array(), + )) +} + +pub fn _get_rows_encoded_unordered(by: &[Column]) -> PolarsResult { + let mut cols = Vec::with_capacity(by.len()); + let mut opts = Vec::with_capacity(by.len()); + let mut ctxts = Vec::with_capacity(by.len()); + + // Since ZFS exists, we might not actually have any arrays and need to get the length from the + // columns. + let num_rows = by.first().map_or(0, |c| c.len()); + + for by in by { + debug_assert_eq!(by.len(), num_rows); + + let by = by.as_materialized_series(); + let arr = by.to_physical_repr().rechunk().chunks()[0].to_boxed(); + let opt = RowEncodingOptions::new_unsorted(); + let ctxt = get_row_encoding_context(by.dtype(), false); + + cols.push(arr); + opts.push(opt); + ctxts.push(ctxt); + } + Ok(convert_columns(num_rows, &cols, &opts, &ctxts)) +} + +pub fn _get_rows_encoded( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + debug_assert_eq!(by.len(), descending.len()); + debug_assert_eq!(by.len(), nulls_last.len()); + + let mut cols = Vec::with_capacity(by.len()); + let mut opts = Vec::with_capacity(by.len()); + let mut ctxts = Vec::with_capacity(by.len()); + + // Since ZFS exists, we might not actually have any arrays and need to get the length from the + // columns. + let num_rows = by.first().map_or(0, |c| c.len()); + + for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { + debug_assert_eq!(by.len(), num_rows); + + let by = by.as_materialized_series(); + let arr = by.to_physical_repr().rechunk().chunks()[0].to_boxed(); + let opt = RowEncodingOptions::new_sorted(*desc, *null_last); + let ctxt = get_row_encoding_context(by.dtype(), true); + + cols.push(arr); + opts.push(opt); + ctxts.push(ctxt); + } + Ok(convert_columns(num_rows, &cols, &opts, &ctxts)) +} + +pub fn _get_rows_encoded_ca( + name: PlSmallStr, + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + _get_rows_encoded(by, descending, nulls_last) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} + +pub fn _get_rows_encoded_arr( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult> { + _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) +} + +pub fn _get_rows_encoded_ca_unordered( + name: PlSmallStr, + by: &[Column], +) -> PolarsResult { + _get_rows_encoded_unordered(by) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} diff --git a/crates/polars-core/src/chunked_array/ops/search_sorted.rs b/crates/polars-core/src/chunked_array/ops/search_sorted.rs new file mode 100644 index 000000000000..5e97f0818176 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/search_sorted.rs @@ -0,0 +1,223 @@ +use std::fmt::Debug; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::prelude::*; + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum SearchSortedSide { + #[default] + Any, + Left, + Right, +} + +/// Computes the first point on [lo, hi) where f is true, assuming it is first +/// always false and then always true. It is assumed f(hi) is true. +/// midpoint is a function that returns some lo < i < hi if one exists, else lo. +fn lower_bound(mut lo: I, mut hi: I, midpoint: M, f: F) -> I +where + I: PartialEq + Eq, + M: Fn(&I, &I) -> I, + F: Fn(&I) -> bool, +{ + loop { + let m = midpoint(&lo, &hi); + if m == lo { + return if f(&lo) { lo } else { hi }; + } + + if f(&m) { + hi = m; + } else { + lo = m; + } + } +} + +/// Search through a series of chunks for the first position where f(x) is true, +/// assuming it is first always false and then always true. +/// +/// It repeats this for each value in search_values. If the search value is null null_idx is +/// returned. +/// +/// Assumes the chunks are non-empty. +pub fn lower_bound_chunks<'a, T, F>( + chunks: &[&'a T::Array], + search_values: impl Iterator>>, + null_idx: IdxSize, + f: F, +) -> Vec +where + T: PolarsDataType, + F: Fn(&'a T::Array, usize, &T::Physical<'a>) -> bool, +{ + if chunks.is_empty() { + return search_values.map(|_| 0).collect(); + } + + // Fast-path: only a single chunk. + if chunks.len() == 1 { + let chunk = &chunks[0]; + return search_values + .map(|ov| { + if let Some(v) = ov { + lower_bound(0, chunk.len(), |l, r| (l + r) / 2, |m| f(chunk, *m, &v)) as IdxSize + } else { + null_idx + } + }) + .collect(); + } + + // Multiple chunks, precompute prefix sum of lengths so we can look up + // in O(1) the global position of chunk i. + let mut sz = 0; + let mut chunk_len_prefix_sum = Vec::with_capacity(chunks.len() + 1); + for c in chunks { + chunk_len_prefix_sum.push(sz); + sz += c.len(); + } + chunk_len_prefix_sum.push(sz); + + // For each search value do a binary search on (chunk_idx, idx_in_chunk) pairs. + search_values + .map(|ov| { + let Some(v) = ov else { + return null_idx; + }; + let left = (0, 0); + let right = (chunks.len(), 0); + let midpoint = |l: &(usize, usize), r: &(usize, usize)| { + if l.0 == r.0 { + // Within same chunk. + (l.0, (l.1 + r.1) / 2) + } else if l.0 + 1 == r.0 { + // Two adjacent chunks, might have to be l or r. + let left_len = chunks[l.0].len() - l.1; + + let logical_mid = (left_len + r.1) / 2; + if logical_mid < left_len { + (l.0, l.1 + logical_mid) + } else { + (r.0, logical_mid - left_len) + } + } else { + // Has a chunk in between. + ((l.0 + r.0) / 2, 0) + } + }; + + let bound = lower_bound(left, right, midpoint, |m| { + f(unsafe { chunks.get_unchecked(m.0) }, m.1, &v) + }); + + (chunk_len_prefix_sum[bound.0] + bound.1) as IdxSize + }) + .collect() +} + +#[allow(clippy::collapsible_else_if)] +pub fn binary_search_ca<'a, T>( + ca: &'a ChunkedArray, + search_values: impl Iterator>>, + side: SearchSortedSide, + descending: bool, +) -> Vec +where + T: PolarsDataType, + T::Physical<'a>: TotalOrd + Debug + Copy, +{ + let chunks: Vec<_> = ca.downcast_iter().filter(|c| c.len() > 0).collect(); + let has_nulls = ca.null_count() > 0; + let nulls_last = has_nulls && chunks[0].get(0).is_some(); + let null_idx = if nulls_last { + if side == SearchSortedSide::Right { + ca.len() + } else { + ca.len() - ca.null_count() + } + } else { + if side == SearchSortedSide::Right { + ca.null_count() + } else { + 0 + } + } as IdxSize; + + if !descending { + if !has_nulls { + if side == SearchSortedSide::Right { + lower_bound_chunks::( + &chunks, + search_values, + null_idx, + |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_gt(sv) }, + ) + } else { + lower_bound_chunks::( + &chunks, + search_values, + null_idx, + |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_ge(sv) }, + ) + } + } else { + if side == SearchSortedSide::Right { + lower_bound_chunks::(&chunks, search_values, null_idx, |chunk, i, sv| { + if let Some(v) = unsafe { chunk.get_unchecked(i) } { + v.tot_gt(sv) + } else { + nulls_last + } + }) + } else { + lower_bound_chunks::(&chunks, search_values, null_idx, |chunk, i, sv| { + if let Some(v) = unsafe { chunk.get_unchecked(i) } { + v.tot_ge(sv) + } else { + nulls_last + } + }) + } + } + } else { + if !has_nulls { + if side == SearchSortedSide::Right { + lower_bound_chunks::( + &chunks, + search_values, + null_idx, + |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_lt(sv) }, + ) + } else { + lower_bound_chunks::( + &chunks, + search_values, + null_idx, + |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_le(sv) }, + ) + } + } else { + if side == SearchSortedSide::Right { + lower_bound_chunks::(&chunks, search_values, null_idx, |chunk, i, sv| { + if let Some(v) = unsafe { chunk.get_unchecked(i) } { + v.tot_lt(sv) + } else { + nulls_last + } + }) + } else { + lower_bound_chunks::(&chunks, search_values, null_idx, |chunk, i, sv| { + if let Some(v) = unsafe { chunk.get_unchecked(i) } { + v.tot_le(sv) + } else { + nulls_last + } + }) + } + } + } +} diff --git a/crates/polars-core/src/chunked_array/ops/set.rs b/crates/polars-core/src/chunked_array/ops/set.rs new file mode 100644 index 000000000000..9aa0d04719e6 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/set.rs @@ -0,0 +1,386 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask}; + +use crate::prelude::*; +use crate::utils::align_chunks_binary; + +macro_rules! impl_scatter_with { + ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{ + let mut ca_iter = $self.into_iter().enumerate(); + + for current_idx in $idx.into_iter().map(|i| i as usize) { + polars_ensure!(current_idx < $self.len(), oob = current_idx, $self.len()); + while let Some((cnt_idx, opt_val)) = ca_iter.next() { + if cnt_idx == current_idx { + $builder.append_option($f(opt_val)); + break; + } else { + $builder.append_option(opt_val); + } + } + } + // the last idx is probably not the last value so we finish the iterator + while let Some((_, opt_val)) = ca_iter.next() { + $builder.append_option(opt_val); + } + + let ca = $builder.finish(); + Ok(ca) + }}; +} + +macro_rules! check_bounds { + ($self:ident, $mask:ident) => {{ + polars_ensure!( + $self.len() == $mask.len(), + ShapeMismatch: "invalid mask in `get` operation: shape doesn't match array's shape" + ); + }}; +} + +impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray +where + T: PolarsNumericType, +{ + fn scatter_single>( + &'a self, + idx: I, + value: Option, + ) -> PolarsResult { + if !self.has_nulls() { + if let Some(value) = value { + // Fast path uses kernel. + if self.chunks.len() == 1 { + let arr = scatter_single_non_null( + self.downcast_iter().next().unwrap(), + idx, + value, + T::get_dtype().to_arrow(CompatLevel::newest()), + )?; + return Ok(Self::with_chunk(self.name().clone(), arr)); + } + // Other fast path. Slightly slower as it does not do a memcpy. + else { + let mut av = Vec::with_capacity(self.len()); + for chunk in self.downcast_iter() { + av.extend_from_slice(chunk.values()) + } + let data = av.as_mut_slice(); + + idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| { + let val = data + .get_mut(idx as usize) + .ok_or_else(|| polars_err!(oob = idx as usize, self.len()))?; + *val = value; + Ok(()) + })?; + return Ok(Self::from_vec(self.name().clone(), av)); + } + } + } + self.scatter_with(idx, |_| value) + } + + fn scatter_with, F>( + &'a self, + idx: I, + f: F, + ) -> PolarsResult + where + F: Fn(Option) -> Option, + { + let mut builder = PrimitiveChunkedBuilder::::new(self.name().clone(), self.len()); + impl_scatter_with!(self, builder, idx, f) + } + + fn set(&'a self, mask: &BooleanChunked, value: Option) -> PolarsResult { + check_bounds!(self, mask); + + // Fast path uses the kernel in polars-arrow. + if let (Some(value), false) = (value, mask.has_nulls()) { + let (left, mask) = align_chunks_binary(self, mask); + + // Apply binary kernel. + let chunks = left + .downcast_iter() + .zip(mask.downcast_iter()) + .map(|(arr, mask)| { + set_with_mask( + arr, + mask, + value, + T::get_dtype().to_arrow(CompatLevel::newest()), + ) + }); + Ok(ChunkedArray::from_chunk_iter(self.name().clone(), chunks)) + } else { + let mask = mask.rechunk(); + let mask = mask.downcast_as_array(); + let mask = mask.true_and_valid(); + let iter = mask.true_idx_iter(); + self.scatter_single(iter.map(|v| v as IdxSize), value) + } + } +} + +impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked { + fn scatter_single>( + &'a self, + idx: I, + value: Option, + ) -> PolarsResult { + self.scatter_with(idx, |_| value) + } + + fn scatter_with, F>( + &'a self, + idx: I, + f: F, + ) -> PolarsResult + where + F: Fn(Option) -> Option, + { + let mut values = MutableBitmap::with_capacity(self.len()); + let mut validity = MutableBitmap::with_capacity(self.len()); + + for a in self.downcast_iter() { + values.extend_from_bitmap(a.values()); + if let Some(v) = a.validity() { + validity.extend_from_bitmap(v) + } else { + validity.extend_constant(a.len(), true); + } + } + + for i in idx.into_iter().map(|i| i as usize) { + let input = validity.get(i).then(|| values.get(i)); + + match f(input) { + Some(v) => { + values.set(i, v); + validity.set(i, true); + }, + None => { + validity.set(i, false); + }, + } + } + let validity: Bitmap = validity.into(); + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + let arr = BooleanArray::from_data_default(values.into(), validity); + Ok(BooleanChunked::with_chunk(self.name().clone(), arr)) + } + + fn set(&'a self, mask: &BooleanChunked, value: Option) -> PolarsResult { + let mask = mask.rechunk(); + let mask = mask.downcast_as_array(); + let mask = mask.true_and_valid(); + let iter = mask.true_idx_iter(); + self.scatter_single(iter.map(|v| v as IdxSize), value) + } +} + +impl<'a> ChunkSet<'a, &'a str, String> for StringChunked { + fn scatter_single>( + &'a self, + idx: I, + opt_value: Option<&'a str>, + ) -> PolarsResult + where + Self: Sized, + { + let idx_iter = idx.into_iter(); + let mut ca_iter = self.into_iter().enumerate(); + let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len()); + + for current_idx in idx_iter.into_iter().map(|i| i as usize) { + polars_ensure!(current_idx < self.len(), oob = current_idx, self.len()); + for (cnt_idx, opt_val_self) in &mut ca_iter { + if cnt_idx == current_idx { + builder.append_option(opt_value); + break; + } else { + builder.append_option(opt_val_self); + } + } + } + // the last idx is probably not the last value so we finish the iterator + for (_, opt_val_self) in ca_iter { + builder.append_option(opt_val_self); + } + + let ca = builder.finish(); + Ok(ca) + } + + fn scatter_with, F>( + &'a self, + idx: I, + f: F, + ) -> PolarsResult + where + Self: Sized, + F: Fn(Option<&'a str>) -> Option, + { + let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len()); + impl_scatter_with!(self, builder, idx, f) + } + + fn set(&'a self, mask: &BooleanChunked, value: Option<&'a str>) -> PolarsResult + where + Self: Sized, + { + check_bounds!(self, mask); + let ca = mask + .into_iter() + .zip(self) + .map(|(mask_val, opt_val)| match mask_val { + Some(true) => value, + _ => opt_val, + }) + .collect_trusted::() + .with_name(self.name().clone()); + Ok(ca) + } +} + +impl<'a> ChunkSet<'a, &'a [u8], Vec> for BinaryChunked { + fn scatter_single>( + &'a self, + idx: I, + opt_value: Option<&'a [u8]>, + ) -> PolarsResult + where + Self: Sized, + { + let mut ca_iter = self.into_iter().enumerate(); + let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len()); + + for current_idx in idx.into_iter().map(|i| i as usize) { + polars_ensure!(current_idx < self.len(), oob = current_idx, self.len()); + for (cnt_idx, opt_val_self) in &mut ca_iter { + if cnt_idx == current_idx { + builder.append_option(opt_value); + break; + } else { + builder.append_option(opt_val_self); + } + } + } + // the last idx is probably not the last value so we finish the iterator + for (_, opt_val_self) in ca_iter { + builder.append_option(opt_val_self); + } + + let ca = builder.finish(); + Ok(ca) + } + + fn scatter_with, F>( + &'a self, + idx: I, + f: F, + ) -> PolarsResult + where + Self: Sized, + F: Fn(Option<&'a [u8]>) -> Option>, + { + let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len()); + impl_scatter_with!(self, builder, idx, f) + } + + fn set(&'a self, mask: &BooleanChunked, value: Option<&'a [u8]>) -> PolarsResult + where + Self: Sized, + { + check_bounds!(self, mask); + let ca = mask + .into_iter() + .zip(self) + .map(|(mask_val, opt_val)| match mask_val { + Some(true) => value, + _ => opt_val, + }) + .collect_trusted::() + .with_name(self.name().clone()); + Ok(ca) + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + fn test_set() { + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]); + let ca = ca.set(&mask, Some(5)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(1), Some(5), Some(3)]); + + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, Some(true), None]); + let ca = ca.set(&mask, Some(5)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(1), Some(5), Some(3)]); + + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, None, None]); + let ca = ca.set(&mask, Some(5)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(1), Some(2), Some(3)]); + + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new( + PlSmallStr::from_static("mask"), + &[Some(true), Some(false), None], + ); + let ca = ca.set(&mask, Some(5)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(5), Some(2), Some(3)]); + + let ca = ca.scatter_single(vec![0, 1], Some(10)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(10), Some(10), Some(3)]); + + assert!(ca.scatter_single(vec![0, 10], Some(0)).is_err()); + + // test booleans + let ca = BooleanChunked::new(PlSmallStr::from_static("a"), &[true, true, true]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]); + let ca = ca.set(&mask, None).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(true), None, Some(true)]); + + // test string + let ca = StringChunked::new(PlSmallStr::from_static("a"), &["foo", "foo", "foo"]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]); + let ca = ca.set(&mask, Some("bar")).unwrap(); + assert_eq!(Vec::from(&ca), &[Some("foo"), Some("bar"), Some("foo")]); + } + + #[test] + fn test_set_null_values() { + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[Some(1), None, Some(3)]); + let mask = BooleanChunked::new( + PlSmallStr::from_static("mask"), + &[Some(false), Some(true), None], + ); + let ca = ca.set(&mask, Some(2)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(1), Some(2), Some(3)]); + + let ca = StringChunked::new( + PlSmallStr::from_static("a"), + &[Some("foo"), None, Some("bar")], + ); + let ca = ca.set(&mask, Some("foo")).unwrap(); + assert_eq!(Vec::from(&ca), &[Some("foo"), Some("foo"), Some("bar")]); + + let ca = BooleanChunked::new( + PlSmallStr::from_static("a"), + &[Some(false), None, Some(true)], + ); + let ca = ca.set(&mask, Some(true)).unwrap(); + assert_eq!(Vec::from(&ca), &[Some(false), Some(true), Some(true)]); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/shift.rs b/crates/polars-core/src/chunked_array/ops/shift.rs new file mode 100644 index 000000000000..09f71a5ed5d6 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/shift.rs @@ -0,0 +1,268 @@ +use num_traits::{abs, clamp}; + +use crate::prelude::*; +use crate::series::implementations::null::NullChunked; + +macro_rules! impl_shift_fill { + ($self:ident, $periods:expr, $fill_value:expr) => {{ + let fill_length = abs($periods) as usize; + + if fill_length >= $self.len() { + return match $fill_value { + Some(fill) => Self::full($self.name().clone(), fill, $self.len()), + None => Self::full_null($self.name().clone(), $self.len()), + }; + } + let slice_offset = (-$periods).max(0) as i64; + let length = $self.len() - fill_length; + let mut slice = $self.slice(slice_offset, length); + + let mut fill = match $fill_value { + Some(val) => Self::full($self.name().clone(), val, fill_length), + None => Self::full_null($self.name().clone(), fill_length), + }; + + if $periods < 0 { + slice.append(&fill).unwrap(); + slice + } else { + fill.append(&slice).unwrap(); + fill + } + }}; +} + +impl ChunkShiftFill> for ChunkedArray +where + T: PolarsNumericType, +{ + fn shift_and_fill(&self, periods: i64, fill_value: Option) -> ChunkedArray { + impl_shift_fill!(self, periods, fill_value) + } +} +impl ChunkShift for ChunkedArray +where + T: PolarsNumericType, +{ + fn shift(&self, periods: i64) -> ChunkedArray { + self.shift_and_fill(periods, None) + } +} + +impl ChunkShiftFill> for BooleanChunked { + fn shift_and_fill(&self, periods: i64, fill_value: Option) -> BooleanChunked { + impl_shift_fill!(self, periods, fill_value) + } +} + +impl ChunkShift for BooleanChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +impl ChunkShiftFill> for StringChunked { + fn shift_and_fill(&self, periods: i64, fill_value: Option<&str>) -> StringChunked { + let ca = self.as_binary(); + unsafe { + ca.shift_and_fill(periods, fill_value.map(|v| v.as_bytes())) + .to_string_unchecked() + } + } +} + +impl ChunkShiftFill> for BinaryChunked { + fn shift_and_fill(&self, periods: i64, fill_value: Option<&[u8]>) -> BinaryChunked { + impl_shift_fill!(self, periods, fill_value) + } +} + +impl ChunkShiftFill> for BinaryOffsetChunked { + fn shift_and_fill(&self, periods: i64, fill_value: Option<&[u8]>) -> BinaryOffsetChunked { + impl_shift_fill!(self, periods, fill_value) + } +} + +impl ChunkShift for StringChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +impl ChunkShift for BinaryChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +impl ChunkShift for BinaryOffsetChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +impl ChunkShiftFill> for ListChunked { + fn shift_and_fill(&self, periods: i64, fill_value: Option<&Series>) -> ListChunked { + // This has its own implementation because a ListChunked cannot have a full-null without + // knowing the inner type + let periods = clamp(periods, -(self.len() as i64), self.len() as i64); + let slice_offset = (-periods).max(0); + let length = self.len() - abs(periods) as usize; + let mut slice = self.slice(slice_offset, length); + + let fill_length = abs(periods) as usize; + let mut fill = match fill_value { + Some(val) => Self::full(self.name().clone(), val, fill_length), + None => ListChunked::full_null_with_dtype( + self.name().clone(), + fill_length, + self.inner_dtype(), + ), + }; + + if periods < 0 { + slice.append(&fill).unwrap(); + slice + } else { + fill.append(&slice).unwrap(); + fill + } + } +} + +impl ChunkShift for ListChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkShiftFill> for ArrayChunked { + fn shift_and_fill(&self, periods: i64, fill_value: Option<&Series>) -> ArrayChunked { + // This has its own implementation because a ArrayChunked cannot have a full-null without + // knowing the inner type + let periods = clamp(periods, -(self.len() as i64), self.len() as i64); + let slice_offset = (-periods).max(0); + let length = self.len() - abs(periods) as usize; + let mut slice = self.slice(slice_offset, length); + + let fill_length = abs(periods) as usize; + let mut fill = match fill_value { + Some(val) => Self::full(self.name().clone(), val, fill_length), + None => ArrayChunked::full_null_with_dtype( + self.name().clone(), + fill_length, + self.inner_dtype(), + 0, + ), + }; + + if periods < 0 { + slice.append(&fill).unwrap(); + slice + } else { + fill.append(&slice).unwrap(); + fill + } + } +} + +#[cfg(feature = "dtype-array")] +impl ChunkShift for ArrayChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +#[cfg(feature = "object")] +impl ChunkShiftFill, Option>> for ObjectChunked { + fn shift_and_fill( + &self, + _periods: i64, + _fill_value: Option>, + ) -> ChunkedArray> { + todo!() + } +} +#[cfg(feature = "object")] +impl ChunkShift> for ObjectChunked { + fn shift(&self, periods: i64) -> Self { + self.shift_and_fill(periods, None) + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkShift for StructChunked { + fn shift(&self, periods: i64) -> ChunkedArray { + // This has its own implementation because a ArrayChunked cannot have a full-null without + // knowing the inner type + let periods = clamp(periods, -(self.len() as i64), self.len() as i64); + let slice_offset = (-periods).max(0); + let length = self.len() - abs(periods) as usize; + let mut slice = self.slice(slice_offset, length); + + let fill_length = abs(periods) as usize; + + // Go via null, so the cast creates the proper struct type. + let fill = NullChunked::new(self.name().clone(), fill_length) + .cast(self.dtype(), Default::default()) + .unwrap(); + let mut fill = fill.struct_().unwrap().clone(); + + if periods < 0 { + slice.append(&fill).unwrap(); + slice + } else { + fill.append(&slice).unwrap(); + fill + } + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + fn test_shift() { + let ca = Int32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3]); + + // shift by 0, 1, 2, 3, 4 + let shifted = ca.shift_and_fill(0, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(1), Some(2), Some(3)]); + let shifted = ca.shift_and_fill(1, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(5), Some(1), Some(2)]); + let shifted = ca.shift_and_fill(2, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(5), Some(5), Some(1)]); + let shifted = ca.shift_and_fill(3, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(5), Some(5), Some(5)]); + let shifted = ca.shift_and_fill(4, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(5), Some(5), Some(5)]); + + // shift by -1, -2, -3, -4 + let shifted = ca.shift_and_fill(-1, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(2), Some(3), Some(5)]); + let shifted = ca.shift_and_fill(-2, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(3), Some(5), Some(5)]); + let shifted = ca.shift_and_fill(-3, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(5), Some(5), Some(5)]); + let shifted = ca.shift_and_fill(-4, Some(5)); + assert_eq!(Vec::from(&shifted), &[Some(5), Some(5), Some(5)]); + + // fill with None + let shifted = ca.shift_and_fill(1, None); + assert_eq!(Vec::from(&shifted), &[None, Some(1), Some(2)]); + let shifted = ca.shift_and_fill(10, None); + assert_eq!(Vec::from(&shifted), &[None, None, None]); + let shifted = ca.shift_and_fill(-2, None); + assert_eq!(Vec::from(&shifted), &[Some(3), None, None]); + + // string + let s = Series::new(PlSmallStr::from_static("a"), ["a", "b", "c"]); + let shifted = s.shift(-1); + assert_eq!( + Vec::from(shifted.str().unwrap()), + &[Some("b"), Some("c"), None] + ); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs new file mode 100644 index 000000000000..7787ef28076f --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -0,0 +1,96 @@ +use polars_utils::itertools::Itertools; + +use super::*; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; + +#[derive(Eq)] +struct CompareRow<'a> { + idx: IdxSize, + bytes: &'a [u8], +} + +impl PartialEq for CompareRow<'_> { + fn eq(&self, other: &Self) -> bool { + self.bytes == other.bytes + } +} + +impl Ord for CompareRow<'_> { + fn cmp(&self, other: &Self) -> Ordering { + self.bytes.cmp(other.bytes) + } +} + +impl PartialOrd for CompareRow<'_> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Return the indices of the bottom k elements. +/// +/// Similar to .argsort() then .slice(0, k) but with a more efficient implementation. +pub fn _arg_bottom_k( + k: usize, + by_column: &[Column], + sort_options: &mut SortMultipleOptions, +) -> PolarsResult> { + let from_n_rows = by_column[0].len(); + _broadcast_bools(by_column.len(), &mut sort_options.descending); + _broadcast_bools(by_column.len(), &mut sort_options.nulls_last); + + // Don't go into row encoding. + if by_column.len() == 1 && sort_options.limit.is_some() && !sort_options.maintain_order { + return Ok(NoNull::new(by_column[0].arg_sort((&*sort_options).into()))); + } + + let encoded = _get_rows_encoded( + by_column, + &sort_options.descending, + &sort_options.nulls_last, + )?; + let arr = encoded.into_array(); + let mut rows = arr + .values_iter() + .enumerate_idx() + .map(|(idx, bytes)| CompareRow { idx, bytes }) + .collect::>(); + + let sorted = if k >= from_n_rows { + match (sort_options.multithreaded, sort_options.maintain_order) { + (true, true) => POOL.install(|| { + rows.par_sort(); + }), + (true, false) => POOL.install(|| { + rows.par_sort_unstable(); + }), + (false, true) => rows.sort(), + (false, false) => rows.sort_unstable(), + } + &rows + } else if sort_options.maintain_order { + // todo: maybe there is some more efficient method, comparable to select_nth_unstable + if sort_options.multithreaded { + POOL.install(|| { + rows.par_sort(); + }) + } else { + rows.sort(); + } + &rows[..k] + } else { + // todo: possible multi threaded `select_nth_unstable`? + let (lower, _el, _upper) = rows.select_nth_unstable(k); + if sort_options.multithreaded { + POOL.install(|| { + lower.par_sort_unstable(); + }) + } else { + lower.sort_unstable(); + } + &*lower + }; + + let idx: NoNull = sorted.iter().map(|cmp_row| cmp_row.idx).collect(); + Ok(idx) +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs new file mode 100644 index 000000000000..033c4850ef04 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs @@ -0,0 +1,329 @@ +use polars_utils::itertools::Itertools; + +use self::row_encode::_get_rows_encoded; +use super::*; + +// Reduce monomorphisation. +fn sort_impl(vals: &mut [(IdxSize, T)], options: SortOptions) +where + T: TotalOrd + Send + Sync, +{ + sort_by_branch( + vals, + options.descending, + |a, b| a.1.tot_cmp(&b.1), + options.multithreaded, + ); +} +// Compute the indexes after reversing a sorted array, maintaining +// the order of equal elements, in linear time. Faster than sort_impl +// as we avoid allocating extra memory. +pub(super) fn reverse_stable_no_nulls(iters: I, len: usize) -> Vec +where + I: IntoIterator, + J: IntoIterator, + T: TotalOrd + Send + Sync, +{ + let mut current_start: IdxSize = 0; + let mut current_end: IdxSize = 0; + let mut rev_idx: Vec = Vec::with_capacity(len); + let mut i: IdxSize; + // We traverse the array, comparing consecutive elements. + // We maintain the start and end indice of elements with same value. + // When we see a new element we push the previous indices in reverse order. + // We do a final reverse to get stable reverse index. + // Example - + // 1 2 2 3 3 3 4 + // 0 1 2 3 4 5 6 + // We get start and end position of equal values - + // 0 1-2 3-5 6 + // We insert the indexes of equal elements in reverse + // 0 2 1 5 4 3 6 + // Then do a final reverse + // 6 3 4 5 1 2 0 + let mut previous_element: Option = None; + for arr_iter in iters { + for current_element in arr_iter { + match &previous_element { + None => { + //There is atleast one element + current_end = 1; + }, + Some(prev) => { + if current_element.tot_cmp(prev) == Ordering::Equal { + current_end += 1; + } else { + // Insert in reverse order + i = current_end; + while i > current_start { + i -= 1; + //SAFETY - we allocated enough + unsafe { rev_idx.push_unchecked(i) }; + } + current_start = current_end; + current_end += 1; + } + }, + } + previous_element = Some(current_element); + } + } + // If there are no elements this does nothing + i = current_end; + while i > current_start { + i -= 1; + unsafe { rev_idx.push_unchecked(i) }; + } + // Final reverse + rev_idx.reverse(); + rev_idx +} + +pub(super) fn arg_sort( + name: PlSmallStr, + iters: I, + options: SortOptions, + null_count: usize, + mut len: usize, + is_sorted_flag: IsSorted, + first_element_null: bool, +) -> IdxCa +where + I: IntoIterator, + J: IntoIterator>, + T: TotalOrd + Send + Sync, +{ + let nulls_last = options.nulls_last; + let null_cap = if nulls_last { null_count } else { len }; + + // Fast path + // Only if array is already sorted in the required ordered and + // nulls are also in the correct position + if ((options.descending && is_sorted_flag == IsSorted::Descending) + || (!options.descending && is_sorted_flag == IsSorted::Ascending)) + && ((nulls_last && !first_element_null) || (!nulls_last && first_element_null)) + { + len = options + .limit + .map_or(len, |limit| std::cmp::min(limit.try_into().unwrap(), len)); + return ChunkedArray::with_chunk( + name, + IdxArr::from_data_default( + Buffer::from((0..(len as IdxSize)).collect::>()), + None, + ), + ); + } + + let mut vals = Vec::with_capacity(len - null_count); + let mut nulls_idx = Vec::with_capacity(null_cap); + let mut count: IdxSize = 0; + + for arr_iter in iters { + let iter = arr_iter.into_iter().filter_map(|v| { + let i = count; + count += 1; + match v { + Some(v) => Some((i, v)), + None => { + // SAFETY: we allocated enough. + unsafe { nulls_idx.push_unchecked(i) }; + None + }, + } + }); + vals.extend(iter); + } + + let vals = if let Some(limit) = options.limit { + let limit = limit as usize; + // Overwrite output len. + len = limit; + let out = if limit >= vals.len() { + vals.as_mut_slice() + } else { + let (lower, _el, _upper) = vals + .as_mut_slice() + .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1)); + lower + }; + + sort_impl(out, options); + out + } else { + sort_impl(vals.as_mut_slice(), options); + vals.as_slice() + }; + + let iter = vals.iter().map(|(idx, _v)| idx).copied(); + let idx = if nulls_last { + let mut idx = Vec::with_capacity(len); + idx.extend(iter); + + let nulls_idx = if options.limit.is_some() { + &nulls_idx[..len - idx.len()] + } else { + &nulls_idx + }; + idx.extend_from_slice(nulls_idx); + idx + } else if options.limit.is_some() { + nulls_idx.extend(iter.take(len - nulls_idx.len())); + nulls_idx + } else { + let ptr = nulls_idx.as_ptr() as usize; + nulls_idx.extend(iter); + // We had a realloc. + debug_assert_eq!(nulls_idx.as_ptr() as usize, ptr); + nulls_idx + }; + + ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None)) +} + +pub(super) fn arg_sort_no_nulls( + name: PlSmallStr, + iters: I, + options: SortOptions, + len: usize, + is_sorted_flag: IsSorted, +) -> IdxCa +where + I: IntoIterator, + J: IntoIterator, + T: TotalOrd + Send + Sync, +{ + // Fast path + // 1) If array is already sorted in the required ordered . + // 2) If array is reverse sorted -> we do a stable reverse. + if is_sorted_flag != IsSorted::Not { + let len_final = options + .limit + .map_or(len, |limit| std::cmp::min(limit.try_into().unwrap(), len)); + if (options.descending && is_sorted_flag == IsSorted::Descending) + || (!options.descending && is_sorted_flag == IsSorted::Ascending) + { + return ChunkedArray::with_chunk( + name, + IdxArr::from_data_default( + Buffer::from((0..(len_final as IdxSize)).collect::>()), + None, + ), + ); + } else if (options.descending && is_sorted_flag == IsSorted::Ascending) + || (!options.descending && is_sorted_flag == IsSorted::Descending) + { + let idx = reverse_stable_no_nulls(iters, len); + let idx = Buffer::from(idx).sliced(0, len_final); + return ChunkedArray::with_chunk(name, IdxArr::from_data_default(idx, None)); + } + } + + let mut vals = Vec::with_capacity(len); + + let mut count: IdxSize = 0; + for arr_iter in iters { + vals.extend(arr_iter.into_iter().map(|v| { + let idx = count; + count += 1; + (idx, v) + })); + } + + let vals = if let Some(limit) = options.limit { + let limit = limit as usize; + let out = if limit >= vals.len() { + vals.as_mut_slice() + } else { + let (lower, _el, _upper) = vals + .as_mut_slice() + .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1)); + lower + }; + sort_impl(out, options); + out + } else { + sort_impl(vals.as_mut_slice(), options); + vals.as_slice() + }; + + let iter = vals.iter().map(|(idx, _v)| idx).copied(); + let idx: Vec<_> = iter.collect_trusted(); + + ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None)) +} + +pub(crate) fn arg_sort_row_fmt( + by: &[Column], + descending: bool, + nulls_last: bool, + parallel: bool, +) -> PolarsResult { + let rows_encoded = _get_rows_encoded(by, &[descending], &[nulls_last])?; + let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect(); + + if parallel { + POOL.install(|| items.par_sort_by(|a, b| a.1.cmp(b.1))); + } else { + items.sort_by(|a, b| a.1.cmp(b.1)); + } + + let ca: NoNull = items.into_iter().map(|tpl| tpl.0).collect(); + Ok(ca.into_inner()) +} +#[cfg(test)] +mod test { + use sort::arg_sort::reverse_stable_no_nulls; + + use crate::prelude::*; + + #[test] + fn test_reverse_stable_no_nulls() { + let a = Int32Chunked::new( + PlSmallStr::from_static("a"), + &[ + Some(1), // 0 + Some(2), // 1 + Some(2), // 2 + Some(3), // 3 + Some(3), // 4 + Some(3), // 5 + Some(4), // 6 + ], + ); + let idx = reverse_stable_no_nulls(&a, 7); + let expected = [6, 3, 4, 5, 1, 2, 0]; + assert_eq!(idx, expected); + + let a = Int32Chunked::new( + PlSmallStr::from_static("a"), + &[ + Some(1), // 0 + Some(2), // 1 + Some(3), // 2 + Some(4), // 3 + Some(5), // 4 + Some(6), // 5 + Some(7), // 6 + ], + ); + let idx = reverse_stable_no_nulls(&a, 7); + let expected = [6, 5, 4, 3, 2, 1, 0]; + assert_eq!(idx, expected); + + let a = Int32Chunked::new( + PlSmallStr::from_static("a"), + &[ + Some(1), // 0 + ], + ); + let idx = reverse_stable_no_nulls(&a, 1); + let expected = [0]; + assert_eq!(idx, expected); + + let empty_array: [i32; 0] = []; + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &empty_array); + let idx = reverse_stable_no_nulls(&a, 0); + assert_eq!(idx.len(), 0); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs new file mode 100644 index 000000000000..01d8e3454b0e --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -0,0 +1,107 @@ +use compare_inner::NullOrderCmp; +use polars_utils::itertools::Itertools; + +use super::*; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; + +pub(crate) fn args_validate( + ca: &ChunkedArray, + other: &[Column], + param_value: &[bool], + param_name: &str, +) -> PolarsResult<()> { + for s in other { + assert_eq!(ca.len(), s.len()); + } + polars_ensure!(other.len() == (param_value.len() - 1), + ComputeError: + "the length of `{}` ({}) does not match the number of series ({})", + param_name, param_value.len(), other.len() + 1, + ); + Ok(()) +} + +pub(crate) fn arg_sort_multiple_impl( + mut vals: Vec<(IdxSize, T)>, + by: &[Column], + options: &SortMultipleOptions, +) -> PolarsResult { + let nulls_last = &options.nulls_last; + let descending = &options.descending; + + debug_assert_eq!(descending.len() - 1, by.len()); + debug_assert_eq!(nulls_last.len() - 1, by.len()); + + let compare_inner: Vec<_> = by + .iter() + .map(|c| c.into_total_ord_inner()) + .collect_trusted(); + + let first_descending = descending[0]; + let first_nulls_last = nulls_last[0]; + + let compare = |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering { + match ( + first_descending, + tpl_a + .1 + .null_order_cmp(&tpl_b.1, first_nulls_last ^ first_descending), + ) { + // if ordering is equal, we check the other arrays until we find a non-equal ordering + // if we have exhausted all arrays, we keep the equal ordering. + (_, Ordering::Equal) => { + let idx_a = tpl_a.0 as usize; + let idx_b = tpl_b.0 as usize; + unsafe { + ordering_other_columns( + &compare_inner, + descending.get_unchecked(1..), + nulls_last.get_unchecked(1..), + idx_a, + idx_b, + ) + } + }, + (true, Ordering::Less) => Ordering::Greater, + (true, Ordering::Greater) => Ordering::Less, + (_, ord) => ord, + } + }; + + match (options.multithreaded, options.maintain_order) { + (true, true) => POOL.install(|| { + vals.par_sort_by(compare); + }), + (true, false) => POOL.install(|| { + vals.par_sort_unstable_by(compare); + }), + (false, true) => vals.sort_by(compare), + (false, false) => vals.sort_unstable_by(compare), + } + + let ca: NoNull = vals.into_iter().map(|(idx, _v)| idx).collect_trusted(); + // Don't set to sorted. Argsort indices are not sorted. + Ok(ca.into_inner()) +} + +pub(crate) fn argsort_multiple_row_fmt( + by: &[Column], + mut descending: Vec, + mut nulls_last: Vec, + parallel: bool, +) -> PolarsResult { + _broadcast_bools(by.len(), &mut descending); + _broadcast_bools(by.len(), &mut nulls_last); + + let rows_encoded = _get_rows_encoded(by, &descending, &nulls_last)?; + let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect(); + + if parallel { + POOL.install(|| items.par_sort_by_key(|i| i.1)); + } else { + items.sort_by_key(|i| i.1); + } + + let ca: NoNull = items.into_iter().map(|tpl| tpl.0).collect(); + Ok(ca.into_inner()) +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs new file mode 100644 index 000000000000..865cb0cb3b26 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -0,0 +1,221 @@ +use super::*; + +impl CategoricalChunked { + #[must_use] + pub fn sort_with(&self, options: SortOptions) -> CategoricalChunked { + if self.uses_lexical_ordering() { + let mut vals = self + .physical() + .into_iter() + .zip(self.iter_str()) + .collect_trusted::>(); + + sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1)); + + let mut cats = Vec::with_capacity(self.len()); + let mut validity = + (self.null_count() > 0).then(|| BitmapBuilder::with_capacity(self.len())); + + if self.null_count() > 0 && !options.nulls_last { + cats.resize(self.null_count(), 0); + if let Some(validity) = &mut validity { + validity.extend_constant(self.null_count(), false); + } + } + + let valid_slice = if options.descending { + &vals[..self.len() - self.null_count()] + } else { + &vals[self.null_count()..] + }; + cats.extend(valid_slice.iter().map(|(idx, _v)| idx.unwrap())); + if let Some(validity) = &mut validity { + validity.extend_constant(self.len() - self.null_count(), true); + } + + if self.null_count() > 0 && options.nulls_last { + cats.resize(self.len(), 0); + if let Some(validity) = &mut validity { + validity.extend_constant(self.null_count(), false); + } + } + + let cats = PrimitiveArray::::new( + ArrowDataType::UInt32, + cats.into(), + validity.map(|v| v.freeze()), + ); + let cats = UInt32Chunked::from_chunk_iter(self.name().clone(), Some(cats)); + + // SAFETY: + // we only reordered the indexes so we are still in bounds + return unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + self.get_rev_map().clone(), + self.is_enum(), + self.get_ordering(), + ) + }; + } + let cats = self.physical().sort_with(options); + // SAFETY: + // we only reordered the indexes so we are still in bounds + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + self.get_rev_map().clone(), + self.is_enum(), + self.get_ordering(), + ) + } + } + + /// Returned a sorted `ChunkedArray`. + #[must_use] + pub fn sort(&self, descending: bool) -> CategoricalChunked { + self.sort_with(SortOptions { + nulls_last: false, + descending, + multithreaded: true, + maintain_order: false, + limit: None, + }) + } + + /// Retrieve the indexes needed to sort this array. + pub fn arg_sort(&self, options: SortOptions) -> IdxCa { + if self.uses_lexical_ordering() { + let iters = [self.iter_str()]; + arg_sort::arg_sort( + self.name().clone(), + iters, + options, + self.physical().null_count(), + self.len(), + IsSorted::Not, + false, + ) + } else { + self.physical().arg_sort(options) + } + } + + /// Retrieve the indices needed to sort this and the other arrays. + pub(crate) fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + if self.uses_lexical_ordering() { + args_validate(self.physical(), by, &options.descending, "descending")?; + args_validate(self.physical(), by, &options.nulls_last, "nulls_last")?; + let mut count: IdxSize = 0; + + // we use bytes to save a monomorphisized str impl + // as bytes already is used for binary and string sorting + let vals: Vec<_> = self + .iter_str() + .map(|v| { + let i = count; + count += 1; + (i, v.map(|v| v.as_bytes())) + }) + .collect_trusted(); + + arg_sort_multiple_impl(vals, by, options) + } else { + self.physical().arg_sort_multiple(by, options) + } + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache}; + + fn assert_order(ca: &CategoricalChunked, cmp: &[&str]) { + let s = ca.cast(&DataType::String).unwrap(); + let ca = s.str().unwrap(); + assert_eq!(ca.into_no_null_iter().collect::>(), cmp); + } + + #[test] + fn test_cat_lexical_sort() -> PolarsResult<()> { + let init = &["c", "b", "a", "d"]; + + let _lock = SINGLE_LOCK.lock(); + for use_string_cache in [true, false] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + + let s = Series::new(PlSmallStr::EMPTY, init) + .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?; + let ca = s.categorical()?; + let ca_lexical = ca.clone(); + + let out = ca_lexical.sort(false); + assert_order(&out, &["a", "b", "c", "d"]); + + let s = Series::new(PlSmallStr::EMPTY, init) + .cast(&DataType::Categorical(None, Default::default()))?; + let ca = s.categorical()?; + + let out = ca.sort(false); + assert_order(&out, init); + + let out = ca_lexical.arg_sort(SortOptions { + descending: false, + ..Default::default() + }); + assert_eq!(out.into_no_null_iter().collect::>(), &[2, 1, 0, 3]); + } + + Ok(()) + } + + #[test] + fn test_cat_lexical_sort_multiple() -> PolarsResult<()> { + let init = &["c", "b", "a", "a"]; + + let _lock = SINGLE_LOCK.lock(); + for use_string_cache in [true, false] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + + let s = Series::new(PlSmallStr::EMPTY, init) + .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?; + let ca = s.categorical()?; + let ca_lexical: CategoricalChunked = ca.clone(); + + let series = ca_lexical.into_series(); + + let df = df![ + "cat" => &series, + "vals" => [1, 1, 2, 2] + ]?; + + let out = df.sort( + ["cat", "vals"], + SortMultipleOptions::default().with_order_descending_multi([false, false]), + )?; + let out = out.column("cat")?; + let cat = out.as_materialized_series().categorical()?; + assert_order(cat, &["a", "a", "b", "c"]); + + let out = df.sort( + ["vals", "cat"], + SortMultipleOptions::default().with_order_descending_multi([false, false]), + )?; + let out = out.column("cat")?; + let cat = out.as_materialized_series().categorical()?; + assert_order(cat, &["b", "c", "a", "a"]); + } + Ok(()) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs new file mode 100644 index 000000000000..429190dbbb10 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -0,0 +1,1071 @@ +mod arg_sort; + +pub mod arg_sort_multiple; + +pub mod arg_bottom_k; +pub mod options; + +#[cfg(feature = "dtype-categorical")] +mod categorical; + +use std::cmp::Ordering; + +pub(crate) use arg_sort::arg_sort_row_fmt; +pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::buffer::Buffer; +use arrow::legacy::trusted_len::TrustedLenPush; +use compare_inner::NonNull; +use rayon::prelude::*; +pub use slice::*; + +use super::row_encode::_get_rows_encoded_ca; +use crate::POOL; +use crate::prelude::compare_inner::TotalOrdInner; +use crate::prelude::sort::arg_sort_multiple::*; +use crate::prelude::*; +use crate::series::IsSorted; +use crate::utils::NoNull; + +fn partition_nulls( + values: &mut [T], + mut validity: Option, + options: SortOptions, +) -> (&mut [T], Option) { + let partitioned = if let Some(bitmap) = &validity { + // Partition null last first + let mut out_len = 0; + for idx in bitmap.true_idx_iter() { + unsafe { *values.get_unchecked_mut(out_len) = *values.get_unchecked(idx) }; + out_len += 1; + } + let valid_count = out_len; + let null_count = values.len() - valid_count; + validity = Some(create_validity( + bitmap.len(), + bitmap.unset_bits(), + options.nulls_last, + )); + + // Views are correctly partitioned. + if options.nulls_last { + &mut values[..valid_count] + } + // We need to swap the ends. + else { + // swap nulls with end + let mut end = values.len() - 1; + + for i in 0..null_count { + unsafe { *values.get_unchecked_mut(end) = *values.get_unchecked(i) }; + end = end.saturating_sub(1); + } + &mut values[null_count..] + } + } else { + values + }; + (partitioned, validity) +} + +pub(crate) fn sort_by_branch(slice: &mut [T], descending: bool, cmp: C, parallel: bool) +where + T: Send, + C: Send + Sync + Fn(&T, &T) -> Ordering, +{ + if parallel { + POOL.install(|| match descending { + true => slice.par_sort_by(|a, b| cmp(b, a)), + false => slice.par_sort_by(cmp), + }) + } else { + match descending { + true => slice.sort_by(|a, b| cmp(b, a)), + false => slice.sort_by(cmp), + } + } +} + +fn sort_unstable_by_branch(slice: &mut [T], options: SortOptions, cmp: C) +where + T: Send, + C: Send + Sync + Fn(&T, &T) -> Ordering, +{ + if options.multithreaded { + POOL.install(|| match options.descending { + true => slice.par_sort_unstable_by(|a, b| cmp(b, a)), + false => slice.par_sort_unstable_by(cmp), + }) + } else { + match options.descending { + true => slice.sort_unstable_by(|a, b| cmp(b, a)), + false => slice.sort_unstable_by(cmp), + } + } +} + +// Reduce monomorphisation. +fn sort_impl_unstable(vals: &mut [T], options: SortOptions) +where + T: TotalOrd + Send + Sync, +{ + sort_unstable_by_branch(vals, options, TotalOrd::tot_cmp); +} + +fn create_validity(len: usize, null_count: usize, nulls_last: bool) -> Bitmap { + let mut validity = BitmapBuilder::with_capacity(len); + if nulls_last { + validity.extend_constant(len - null_count, true); + validity.extend_constant(null_count, false); + } else { + validity.extend_constant(null_count, false); + validity.extend_constant(len - null_count, true); + } + validity.freeze() +} + +macro_rules! sort_with_fast_path { + ($ca:ident, $options:expr) => {{ + if $ca.is_empty() { + return $ca.clone(); + } + + // we can clone if we sort in same order + if $options.descending && $ca.is_sorted_descending_flag() || ($ca.is_sorted_ascending_flag() && !$options.descending) { + // there are nulls + if $ca.null_count() > 0 { + // if the nulls are already last we can clone + if $options.nulls_last && $ca.get($ca.len() - 1).is_none() || + // if the nulls are already first we can clone + (!$options.nulls_last && $ca.get(0).is_none()) + { + return $ca.clone(); + } + // nulls are not at the right place + // continue w/ sorting + // TODO: we can optimize here and just put the null at the correct place + } else { + return $ca.clone(); + } + } + // we can reverse if we sort in other order + else if ($options.descending && $ca.is_sorted_ascending_flag() || $ca.is_sorted_descending_flag()) && $ca.null_count() == 0 { + return $ca.reverse() + }; + + + }} +} + +macro_rules! arg_sort_fast_path { + ($ca:ident, $options:expr) => {{ + // if already sorted in required order we can just return 0..len + if $options.limit.is_none() && + ($options.descending && $ca.is_sorted_descending_flag() || ($ca.is_sorted_ascending_flag() && !$options.descending)) { + // there are nulls + if $ca.null_count() > 0 { + // if the nulls are already last we can return 0..len + if ($options.nulls_last && $ca.get($ca.len() - 1).is_none() ) || + // if the nulls are already first we can return 0..len + (! $options.nulls_last && $ca.get(0).is_none()) + { + return ChunkedArray::with_chunk($ca.name().clone(), + IdxArr::from_data_default(Buffer::from((0..($ca.len() as IdxSize)).collect::>()), None)); + } + // nulls are not at the right place + // continue w/ sorting + // TODO: we can optimize here and just put the null at the correct place + } else { + // no nulls + return ChunkedArray::with_chunk($ca.name().clone(), + IdxArr::from_data_default(Buffer::from((0..($ca.len() as IdxSize )).collect::>()), None)); + } + } + }} +} + +fn sort_with_numeric(ca: &ChunkedArray, options: SortOptions) -> ChunkedArray +where + T: PolarsNumericType, +{ + sort_with_fast_path!(ca, options); + if ca.null_count() == 0 { + let mut vals = ca.to_vec_null_aware().left().unwrap(); + + sort_impl_unstable(vals.as_mut_slice(), options); + + let mut ca = ChunkedArray::from_vec(ca.name().clone(), vals); + let s = if options.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(s); + ca + } else { + let null_count = ca.null_count(); + let len = ca.len(); + + let mut vals = Vec::with_capacity(ca.len()); + + if !options.nulls_last { + let iter = std::iter::repeat_n(T::Native::default(), null_count); + vals.extend(iter); + } + + ca.downcast_iter().for_each(|arr| { + let iter = arr.iter().filter_map(|v| v.copied()); + vals.extend(iter); + }); + let mut_slice = if options.nulls_last { + &mut vals[..len - null_count] + } else { + &mut vals[null_count..] + }; + + sort_impl_unstable(mut_slice, options); + + if options.nulls_last { + vals.extend(std::iter::repeat_n(T::Native::default(), ca.null_count())); + } + + let arr = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + vals.into(), + Some(create_validity(len, null_count, options.nulls_last)), + ); + let mut new_ca = ChunkedArray::with_chunk(ca.name().clone(), arr); + let s = if options.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + new_ca.set_sorted_flag(s); + new_ca + } +} + +fn arg_sort_numeric(ca: &ChunkedArray, mut options: SortOptions) -> IdxCa +where + T: PolarsNumericType, +{ + options.multithreaded &= POOL.current_num_threads() > 1; + arg_sort_fast_path!(ca, options); + if ca.null_count() == 0 { + let iter = ca + .downcast_iter() + .map(|arr| arr.values().as_slice().iter().copied()); + arg_sort::arg_sort_no_nulls( + ca.name().clone(), + iter, + options, + ca.len(), + ca.is_sorted_flag(), + ) + } else { + let iter = ca + .downcast_iter() + .map(|arr| arr.iter().map(|opt| opt.copied())); + arg_sort::arg_sort( + ca.name().clone(), + iter, + options, + ca.null_count(), + ca.len(), + ca.is_sorted_flag(), + ca.get(0).is_none(), + ) + } +} + +fn arg_sort_multiple_numeric( + ca: &ChunkedArray, + by: &[Column], + options: &SortMultipleOptions, +) -> PolarsResult { + args_validate(ca, by, &options.descending, "descending")?; + args_validate(ca, by, &options.nulls_last, "nulls_last")?; + let mut count: IdxSize = 0; + + let no_nulls = ca.null_count() == 0; + + if no_nulls { + let mut vals = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + vals.extend_trusted_len(arr.values().as_slice().iter().map(|v| { + let i = count; + count += 1; + (i, NonNull(*v)) + })) + } + arg_sort_multiple_impl(vals, by, options) + } else { + let mut vals = Vec::with_capacity(ca.len()); + for arr in ca.downcast_iter() { + vals.extend_trusted_len(arr.into_iter().map(|v| { + let i = count; + count += 1; + (i, v.copied()) + })); + } + arg_sort_multiple_impl(vals, by, options) + } +} + +impl ChunkSort for ChunkedArray +where + T: PolarsNumericType, +{ + fn sort_with(&self, mut options: SortOptions) -> ChunkedArray { + options.multithreaded &= POOL.current_num_threads() > 1; + sort_with_numeric(self, options) + } + + fn sort(&self, descending: bool) -> ChunkedArray { + self.sort_with(SortOptions { + descending, + ..Default::default() + }) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + arg_sort_numeric(self, options) + } + + /// # Panics + /// + /// This function is very opinionated. + /// We assume that all numeric `Series` are of the same type, if not it will panic + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + arg_sort_multiple_numeric(self, by, options) + } +} + +fn ordering_other_columns<'a>( + compare_inner: &'a [Box], + descending: &[bool], + nulls_last: &[bool], + idx_a: usize, + idx_b: usize, +) -> Ordering { + for ((cmp, descending), null_last) in compare_inner.iter().zip(descending).zip(nulls_last) { + // SAFETY: indices are in bounds + let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, null_last ^ descending) }; + match (ordering, descending) { + (Ordering::Equal, _) => continue, + (_, true) => return ordering.reverse(), + _ => return ordering, + } + } + // all arrays/columns exhausted, ordering equal it is. + Ordering::Equal +} + +impl ChunkSort for StringChunked { + fn sort_with(&self, options: SortOptions) -> ChunkedArray { + unsafe { self.as_binary().sort_with(options).to_string_unchecked() } + } + + fn sort(&self, descending: bool) -> StringChunked { + self.sort_with(SortOptions { + descending, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.as_binary().arg_sort(options) + } + + /// # Panics + /// + /// This function is very opinionated. On the implementation of `ChunkedArray` for numeric types, + /// we assume that all numeric `Series` are of the same type. + /// + /// In this case we assume that all numeric `Series` are `f64` types. The caller needs to + /// uphold this contract. If not, it will panic. + /// + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.as_binary().arg_sort_multiple(by, options) + } +} + +impl ChunkSort for BinaryChunked { + fn sort_with(&self, mut options: SortOptions) -> ChunkedArray { + options.multithreaded &= POOL.current_num_threads() > 1; + sort_with_fast_path!(self, options); + // We will sort by the views and reconstruct with sorted views. We leave the buffers as is. + // We must rechunk to ensure that all views point into the proper buffers. + let ca = self.rechunk(); + let arr = ca.downcast_as_array().clone(); + + let (views, buffers, validity, total_bytes_len, total_buffer_len) = arr.into_inner(); + let mut views = views.make_mut(); + + let (partitioned_part, validity) = partition_nulls(&mut views, validity, options); + + sort_unstable_by_branch(partitioned_part, options, |a, b| unsafe { + a.get_slice_unchecked(&buffers) + .tot_cmp(&b.get_slice_unchecked(&buffers)) + }); + + let array = unsafe { + BinaryViewArray::new_unchecked( + ArrowDataType::BinaryView, + views.into(), + buffers, + validity, + total_bytes_len, + total_buffer_len, + ) + }; + + let mut out = Self::with_chunk_like(self, array); + + let s = if options.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + out.set_sorted_flag(s); + out + } + + fn sort(&self, descending: bool) -> ChunkedArray { + self.sort_with(SortOptions { + descending, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + arg_sort_fast_path!(self, options); + if self.null_count() == 0 { + arg_sort::arg_sort_no_nulls( + self.name().clone(), + self.downcast_iter().map(|arr| arr.values_iter()), + options, + self.len(), + self.is_sorted_flag(), + ) + } else { + arg_sort::arg_sort( + self.name().clone(), + self.downcast_iter().map(|arr| arr.iter()), + options, + self.null_count(), + self.len(), + self.is_sorted_flag(), + self.get(0).is_none(), + ) + } + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + args_validate(self, by, &options.descending, "descending")?; + args_validate(self, by, &options.nulls_last, "nulls_last")?; + let mut count: IdxSize = 0; + + let mut vals = Vec::with_capacity(self.len()); + for arr in self.downcast_iter() { + for v in arr { + let i = count; + count += 1; + vals.push((i, v)) + } + } + + arg_sort_multiple_impl(vals, by, options) + } +} + +impl ChunkSort for BinaryOffsetChunked { + fn sort_with(&self, mut options: SortOptions) -> BinaryOffsetChunked { + options.multithreaded &= POOL.current_num_threads() > 1; + sort_with_fast_path!(self, options); + + let mut v: Vec<&[u8]> = Vec::with_capacity(self.len()); + for arr in self.downcast_iter() { + v.extend(arr.non_null_values_iter()); + } + + sort_impl_unstable(v.as_mut_slice(), options); + + let mut values = Vec::::with_capacity(self.get_values_size()); + let mut offsets = Vec::::with_capacity(self.len() + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let len = self.len(); + let null_count = self.null_count(); + let mut ca: Self = match (null_count, options.nulls_last) { + (0, _) => { + for val in v { + values.extend_from_slice(val); + length_so_far = values.len() as i64; + offsets.push(length_so_far); + } + // SAFETY: offsets are correctly created. + let arr = unsafe { + BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), None) + }; + ChunkedArray::with_chunk(self.name().clone(), arr) + }, + (_, true) => { + for val in v { + values.extend_from_slice(val); + length_so_far = values.len() as i64; + offsets.push(length_so_far); + } + offsets.extend(std::iter::repeat_n(length_so_far, null_count)); + + // SAFETY: offsets are correctly created. + let arr = unsafe { + BinaryArray::from_data_unchecked_default( + offsets.into(), + values.into(), + Some(create_validity(len, null_count, true)), + ) + }; + ChunkedArray::with_chunk(self.name().clone(), arr) + }, + (_, false) => { + offsets.extend(std::iter::repeat_n(length_so_far, null_count)); + + for val in v { + values.extend_from_slice(val); + length_so_far = values.len() as i64; + offsets.push(length_so_far); + } + + // SAFETY: we pass valid UTF-8. + let arr = unsafe { + BinaryArray::from_data_unchecked_default( + offsets.into(), + values.into(), + Some(create_validity(len, null_count, false)), + ) + }; + ChunkedArray::with_chunk(self.name().clone(), arr) + }, + }; + + let s = if options.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(s); + ca + } + + fn sort(&self, descending: bool) -> BinaryOffsetChunked { + self.sort_with(SortOptions { + descending, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + } + + fn arg_sort(&self, mut options: SortOptions) -> IdxCa { + options.multithreaded &= POOL.current_num_threads() > 1; + let ca = self.rechunk(); + let arr = ca.downcast_as_array(); + let mut idx = (0..(arr.len() as IdxSize)).collect::>(); + + let argsort = |args| { + if options.maintain_order { + sort_by_branch( + args, + options.descending, + |a, b| unsafe { + let a = arr.value_unchecked(*a as usize); + let b = arr.value_unchecked(*b as usize); + a.tot_cmp(&b) + }, + options.multithreaded, + ); + } else { + sort_unstable_by_branch(args, options, |a, b| unsafe { + let a = arr.value_unchecked(*a as usize); + let b = arr.value_unchecked(*b as usize); + a.tot_cmp(&b) + }); + } + }; + + if self.null_count() == 0 { + argsort(&mut idx); + IdxCa::from_vec(self.name().clone(), idx) + } else { + // This branch (almost?) never gets called as the row-encoding also encodes nulls. + let (partitioned_part, validity) = + partition_nulls(&mut idx, arr.validity().cloned(), options); + argsort(partitioned_part); + IdxCa::with_chunk( + self.name().clone(), + IdxArr::from_data_default(idx.into(), validity), + ) + } + } + + /// # Panics + /// + /// This function is very opinionated. On the implementation of `ChunkedArray` for numeric types, + /// we assume that all numeric `Series` are of the same type. + /// + /// In this case we assume that all numeric `Series` are `f64` types. The caller needs to + /// uphold this contract. If not, it will panic. + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + args_validate(self, by, &options.descending, "descending")?; + args_validate(self, by, &options.nulls_last, "nulls_last")?; + let mut count: IdxSize = 0; + + let mut vals = Vec::with_capacity(self.len()); + for arr in self.downcast_iter() { + for v in arr { + let i = count; + count += 1; + vals.push((i, v)) + } + } + + arg_sort_multiple_impl(vals, by, options) + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkSort for StructChunked { + fn sort_with(&self, mut options: SortOptions) -> ChunkedArray { + options.multithreaded &= POOL.current_num_threads() > 1; + let idx = self.arg_sort(options); + let mut out = unsafe { self.take_unchecked(&idx) }; + + let s = if options.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + out.set_sorted_flag(s); + out + } + + fn sort(&self, descending: bool) -> ChunkedArray { + self.sort_with(SortOptions::new().with_order_descending(descending)) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + let bin = self.get_row_encoded(options).unwrap(); + bin.arg_sort(Default::default()) + } +} + +impl ChunkSort for ListChunked { + fn sort_with(&self, mut options: SortOptions) -> ListChunked { + options.multithreaded &= POOL.current_num_threads() > 1; + let idx = self.arg_sort(options); + let mut out = unsafe { self.take_unchecked(&idx) }; + + let s = if options.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + out.set_sorted_flag(s); + out + } + + fn sort(&self, descending: bool) -> ListChunked { + self.sort_with(SortOptions::new().with_order_descending(descending)) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + let bin = _get_rows_encoded_ca( + self.name().clone(), + &[self.clone().into_column()], + &[options.descending], + &[options.nulls_last], + ) + .unwrap(); + bin.arg_sort(Default::default()) + } +} + +impl ChunkSort for BooleanChunked { + fn sort_with(&self, mut options: SortOptions) -> ChunkedArray { + options.multithreaded &= POOL.current_num_threads() > 1; + sort_with_fast_path!(self, options); + let mut bitmap = BitmapBuilder::with_capacity(self.len()); + let mut validity = + (self.null_count() > 0).then(|| BitmapBuilder::with_capacity(self.len())); + + if self.null_count() > 0 && !options.nulls_last { + bitmap.extend_constant(self.null_count(), false); + if let Some(validity) = &mut validity { + validity.extend_constant(self.null_count(), false); + } + } + + let n_valid = self.len() - self.null_count(); + let n_set = self.sum().unwrap() as usize; + if options.descending { + bitmap.extend_constant(n_set, true); + bitmap.extend_constant(n_valid - n_set, false); + } else { + bitmap.extend_constant(n_valid - n_set, false); + bitmap.extend_constant(n_set, true); + } + if let Some(validity) = &mut validity { + validity.extend_constant(n_valid, true); + } + + if self.null_count() > 0 && options.nulls_last { + bitmap.extend_constant(self.null_count(), false); + if let Some(validity) = &mut validity { + validity.extend_constant(self.null_count(), false); + } + } + + Self::from_chunk_iter( + self.name().clone(), + Some(BooleanArray::from_data_default( + bitmap.freeze(), + validity.map(|v| v.freeze()), + )), + ) + } + + fn sort(&self, descending: bool) -> BooleanChunked { + self.sort_with(SortOptions { + descending, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + arg_sort_fast_path!(self, options); + if self.null_count() == 0 { + arg_sort::arg_sort_no_nulls( + self.name().clone(), + self.downcast_iter().map(|arr| arr.values_iter()), + options, + self.len(), + self.is_sorted_flag(), + ) + } else { + arg_sort::arg_sort( + self.name().clone(), + self.downcast_iter().map(|arr| arr.iter()), + options, + self.null_count(), + self.len(), + self.is_sorted_flag(), + self.get(0).is_none(), + ) + } + } + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + let mut vals = Vec::with_capacity(self.len()); + let mut count: IdxSize = 0; + for arr in self.downcast_iter() { + vals.extend_trusted_len(arr.into_iter().map(|v| { + let i = count; + count += 1; + (i, v.map(|v| v as u8)) + })); + } + arg_sort_multiple_impl(vals, by, options) + } +} + +pub fn _broadcast_bools(n_cols: usize, values: &mut Vec) { + if n_cols > values.len() && values.len() == 1 { + while n_cols != values.len() { + values.push(values[0]); + } + } +} + +pub(crate) fn prepare_arg_sort( + columns: Vec, + sort_options: &mut SortMultipleOptions, +) -> PolarsResult<(Column, Vec)> { + let n_cols = columns.len(); + + let mut columns = columns; + + _broadcast_bools(n_cols, &mut sort_options.descending); + _broadcast_bools(n_cols, &mut sort_options.nulls_last); + + let first = columns.remove(0); + Ok((first, columns)) +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + #[test] + fn test_arg_sort() { + let a = Int32Chunked::new( + PlSmallStr::from_static("a"), + &[ + Some(1), // 0 + Some(5), // 1 + None, // 2 + Some(1), // 3 + None, // 4 + Some(4), // 5 + Some(3), // 6 + Some(1), // 7 + ], + ); + let idx = a.arg_sort(SortOptions { + descending: false, + ..Default::default() + }); + let idx = idx.cont_slice().unwrap(); + + let expected = [2, 4, 0, 3, 7, 6, 5, 1]; + assert_eq!(idx, expected); + + let idx = a.arg_sort(SortOptions { + descending: true, + ..Default::default() + }); + let idx = idx.cont_slice().unwrap(); + // the duplicates are in reverse order of appearance, so we cannot reverse expected + let expected = [2, 4, 1, 5, 6, 0, 3, 7]; + assert_eq!(idx, expected); + } + + #[test] + fn test_sort() { + let a = Int32Chunked::new( + PlSmallStr::from_static("a"), + &[ + Some(1), + Some(5), + None, + Some(1), + None, + Some(4), + Some(3), + Some(1), + ], + ); + let out = a.sort_with(SortOptions { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }); + assert_eq!( + Vec::from(&out), + &[ + None, + None, + Some(1), + Some(1), + Some(1), + Some(3), + Some(4), + Some(5) + ] + ); + let out = a.sort_with(SortOptions { + descending: false, + nulls_last: true, + multithreaded: true, + maintain_order: false, + limit: None, + }); + assert_eq!( + Vec::from(&out), + &[ + Some(1), + Some(1), + Some(1), + Some(3), + Some(4), + Some(5), + None, + None + ] + ); + let b = BooleanChunked::new( + PlSmallStr::from_static("b"), + &[Some(false), Some(true), Some(false)], + ); + let out = b.sort_with(SortOptions::default().with_order_descending(true)); + assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); + let out = b.sort_with(SortOptions::default().with_order_descending(false)); + assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(true)]); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_arg_sort_multiple() -> PolarsResult<()> { + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 1, 1, 3, 4, 3, 3]); + let b = Int64Chunked::new(PlSmallStr::from_static("b"), &[0, 1, 2, 3, 4, 5, 6, 1]); + let c = StringChunked::new( + PlSmallStr::from_static("c"), + &["a", "b", "c", "d", "e", "f", "g", "h"], + ); + let df = DataFrame::new(vec![ + a.into_series().into(), + b.into_series().into(), + c.into_series().into(), + ])?; + + let out = df.sort(["a", "b", "c"], SortMultipleOptions::default())?; + assert_eq!( + Vec::from(out.column("b")?.as_series().unwrap().i64()?), + &[ + Some(0), + Some(2), + Some(3), + Some(1), + Some(1), + Some(4), + Some(6), + Some(5) + ] + ); + + // now let the first sort be a string + let a = StringChunked::new( + PlSmallStr::from_static("a"), + &["a", "b", "c", "a", "b", "c"], + ) + .into_series(); + let b = Int32Chunked::new(PlSmallStr::from_static("b"), &[5, 4, 2, 3, 4, 5]).into_series(); + let df = DataFrame::new(vec![a.into(), b.into()])?; + + let out = df.sort(["a", "b"], SortMultipleOptions::default())?; + let expected = df!( + "a" => ["a", "a", "b", "b", "c", "c"], + "b" => [3, 5, 4, 4, 2, 5] + )?; + assert!(out.equals(&expected)); + + let df = df!( + "groups" => [1, 2, 3], + "values" => ["a", "a", "b"] + )?; + + let out = df.sort( + ["groups", "values"], + SortMultipleOptions::default().with_order_descending_multi([true, false]), + )?; + let expected = df!( + "groups" => [3, 2, 1], + "values" => ["b", "a", "a"] + )?; + assert!(out.equals(&expected)); + + let out = df.sort( + ["values", "groups"], + SortMultipleOptions::default().with_order_descending_multi([false, true]), + )?; + let expected = df!( + "groups" => [2, 1, 3], + "values" => ["a", "a", "b"] + )?; + assert!(out.equals(&expected)); + + Ok(()) + } + + #[test] + fn test_sort_string() { + let ca = StringChunked::new( + PlSmallStr::from_static("a"), + &[Some("a"), None, Some("c"), None, Some("b")], + ); + let out = ca.sort_with(SortOptions { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }); + let expected = &[None, None, Some("a"), Some("b"), Some("c")]; + assert_eq!(Vec::from(&out), expected); + + let out = ca.sort_with(SortOptions { + descending: true, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }); + + let expected = &[None, None, Some("c"), Some("b"), Some("a")]; + assert_eq!(Vec::from(&out), expected); + + let out = ca.sort_with(SortOptions { + descending: false, + nulls_last: true, + multithreaded: true, + maintain_order: false, + limit: None, + }); + let expected = &[Some("a"), Some("b"), Some("c"), None, None]; + assert_eq!(Vec::from(&out), expected); + + let out = ca.sort_with(SortOptions { + descending: true, + nulls_last: true, + multithreaded: true, + maintain_order: false, + limit: None, + }); + let expected = &[Some("c"), Some("b"), Some("a"), None, None]; + assert_eq!(Vec::from(&out), expected); + + // no nulls + let ca = StringChunked::new( + PlSmallStr::from_static("a"), + &[Some("a"), Some("c"), Some("b")], + ); + let out = ca.sort(false); + let expected = &[Some("a"), Some("b"), Some("c")]; + assert_eq!(Vec::from(&out), expected); + + let out = ca.sort(true); + let expected = &[Some("c"), Some("b"), Some("a")]; + assert_eq!(Vec::from(&out), expected); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/options.rs b/crates/polars-core/src/chunked_array/ops/sort/options.rs new file mode 100644 index 000000000000..15e82e9f8606 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/options.rs @@ -0,0 +1,248 @@ +#[cfg(feature = "serde-lazy")] +use serde::{Deserialize, Serialize}; +pub use slice::*; + +use crate::prelude::*; + +/// Options for single series sorting. +/// +/// Indicating the order of sorting, nulls position, multithreading, and maintaining order. +/// +/// # Example +/// +/// ``` +/// # use polars_core::prelude::*; +/// let s = Series::new("a".into(), [Some(5), Some(2), Some(3), Some(4), None].as_ref()); +/// let sorted = s +/// .sort( +/// SortOptions::default() +/// .with_order_descending(true) +/// .with_nulls_last(true) +/// .with_multithreaded(false), +/// ) +/// .unwrap(); +/// assert_eq!( +/// sorted, +/// Series::new("a".into(), [Some(5), Some(4), Some(3), Some(2), None].as_ref()) +/// ); +/// ``` +#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +pub struct SortOptions { + /// If true sort in descending order. + /// Default `false`. + pub descending: bool, + /// Whether place null values last. + /// Default `false`. + pub nulls_last: bool, + /// If true sort in multiple threads. + /// Default `true`. + pub multithreaded: bool, + /// If true maintain the order of equal elements. + /// Default `false`. + pub maintain_order: bool, + /// Limit a sort output, this is for optimization purposes and might be ignored. + pub limit: Option, +} + +/// Sort options for multi-series sorting. +/// +/// Indicating the order of sorting, nulls position, multithreading, and maintaining order. +/// +/// # Example +/// ``` +/// # use polars_core::prelude::*; +/// +/// # fn main() -> PolarsResult<()> { +/// let df = df! { +/// "a" => [Some(1), Some(2), None, Some(4), None], +/// "b" => [Some(5), None, Some(3), Some(2), Some(1)] +/// }?; +/// +/// let out = df +/// .sort( +/// ["a", "b"], +/// SortMultipleOptions::default() +/// .with_maintain_order(true) +/// .with_multithreaded(false) +/// .with_order_descending_multi([false, true]) +/// .with_nulls_last(true), +/// )?; +/// +/// let expected = df! { +/// "a" => [Some(1), Some(2), Some(4), None, None], +/// "b" => [Some(5), None, Some(2), Some(3), Some(1)] +/// }?; +/// +/// assert_eq!(out, expected); +/// +/// # Ok(()) +/// # } +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +pub struct SortMultipleOptions { + /// Order of the columns. Default all `false``. + /// + /// If only one value is given, it will broadcast to all columns. + /// + /// Use [`SortMultipleOptions::with_order_descending_multi`] + /// or [`SortMultipleOptions::with_order_descending`] to modify. + /// + /// # Safety + /// + /// Len must match the number of columns, or equal 1. + pub descending: Vec, + /// Whether place null values last. Default `false`. + pub nulls_last: Vec, + /// Whether sort in multiple threads. Default `true`. + pub multithreaded: bool, + /// Whether maintain the order of equal elements. Default `false`. + pub maintain_order: bool, + /// Limit a sort output, this is for optimization purposes and might be ignored. + pub limit: Option, +} + +impl Default for SortOptions { + fn default() -> Self { + Self { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + } + } +} + +impl Default for SortMultipleOptions { + fn default() -> Self { + Self { + descending: vec![false], + nulls_last: vec![false], + multithreaded: true, + maintain_order: false, + limit: None, + } + } +} + +impl SortMultipleOptions { + /// Create `SortMultipleOptions` with default values. + pub fn new() -> Self { + Self::default() + } + + /// Specify order for each column. Defaults all `false`. + /// + /// # Safety + /// + /// Len must match the number of columns, or be equal to 1. + pub fn with_order_descending_multi( + mut self, + descending: impl IntoIterator, + ) -> Self { + self.descending = descending.into_iter().collect(); + self + } + + /// Sort order for all columns. Default `false` which is ascending. + pub fn with_order_descending(mut self, descending: bool) -> Self { + self.descending = vec![descending]; + self + } + + /// Specify whether to place nulls last, per-column. Defaults all `false`. + /// + /// # Safety + /// + /// Len must match the number of columns, or be equal to 1. + pub fn with_nulls_last_multi(mut self, nulls_last: impl IntoIterator) -> Self { + self.nulls_last = nulls_last.into_iter().collect(); + self + } + + /// Whether to place null values last. Default `false`. + pub fn with_nulls_last(mut self, enabled: bool) -> Self { + self.nulls_last = vec![enabled]; + self + } + + /// Whether to sort in multiple threads. Default `true`. + pub fn with_multithreaded(mut self, enabled: bool) -> Self { + self.multithreaded = enabled; + self + } + + /// Whether to maintain the order of equal elements. Default `false`. + pub fn with_maintain_order(mut self, enabled: bool) -> Self { + self.maintain_order = enabled; + self + } + + /// Reverse the order of sorting for each column. + pub fn with_order_reversed(mut self) -> Self { + self.descending.iter_mut().for_each(|x| *x = !*x); + self + } +} + +impl SortOptions { + /// Create `SortOptions` with default values. + pub fn new() -> Self { + Self::default() + } + + /// Specify sorting order for the column. Default `false`. + pub fn with_order_descending(mut self, enabled: bool) -> Self { + self.descending = enabled; + self + } + + /// Whether place null values last. Default `false`. + pub fn with_nulls_last(mut self, enabled: bool) -> Self { + self.nulls_last = enabled; + self + } + + /// Whether sort in multiple threads. Default `true`. + pub fn with_multithreaded(mut self, enabled: bool) -> Self { + self.multithreaded = enabled; + self + } + + /// Whether maintain the order of equal elements. Default `false`. + pub fn with_maintain_order(mut self, enabled: bool) -> Self { + self.maintain_order = enabled; + self + } + + /// Reverse the order of sorting. + pub fn with_order_reversed(mut self) -> Self { + self.descending = !self.descending; + self + } +} + +impl From<&SortOptions> for SortMultipleOptions { + fn from(value: &SortOptions) -> Self { + SortMultipleOptions { + descending: vec![value.descending], + nulls_last: vec![value.nulls_last], + multithreaded: value.multithreaded, + maintain_order: value.maintain_order, + limit: value.limit, + } + } +} + +impl From<&SortMultipleOptions> for SortOptions { + fn from(value: &SortMultipleOptions) -> Self { + SortOptions { + descending: value.descending.first().copied().unwrap_or(false), + nulls_last: value.nulls_last.first().copied().unwrap_or(false), + multithreaded: value.multithreaded, + maintain_order: value.maintain_order, + limit: value.limit, + } + } +} diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs new file mode 100644 index 000000000000..11b06af190e6 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -0,0 +1,305 @@ +use std::hash::Hash; +use std::ops::Deref; + +use arrow::bitmap::MutableBitmap; +use polars_compute::unique::BooleanUniqueKernelState; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; + +use crate::hashing::_HASHMAP_INIT_SIZE; +use crate::prelude::*; +use crate::series::IsSorted; + +fn finish_is_unique_helper( + unique_idx: Vec, + len: IdxSize, + setter: bool, + default: bool, +) -> BooleanChunked { + let mut values = MutableBitmap::with_capacity(len as usize); + values.extend_constant(len as usize, default); + + for idx in unique_idx { + unsafe { values.set_unchecked(idx as usize, setter) } + } + let arr = BooleanArray::from_data_default(values.into(), None); + arr.into() +} + +pub(crate) fn is_unique_helper( + groups: &GroupPositions, + len: IdxSize, + unique_val: bool, + duplicated_val: bool, +) -> BooleanChunked { + debug_assert_ne!(unique_val, duplicated_val); + + let idx = match groups.deref() { + GroupsType::Idx(groups) => groups + .iter() + .filter_map(|(first, g)| if g.len() == 1 { Some(first) } else { None }) + .collect::>(), + GroupsType::Slice { groups, .. } => groups + .iter() + .filter_map(|[first, len]| if *len == 1 { Some(*first) } else { None }) + .collect(), + }; + finish_is_unique_helper(idx, len, unique_val, duplicated_val) +} + +#[cfg(feature = "object")] +impl ChunkUnique for ObjectChunked { + fn unique(&self) -> PolarsResult>> { + polars_bail!(opq = unique, self.dtype()); + } + + fn arg_unique(&self) -> PolarsResult { + polars_bail!(opq = arg_unique, self.dtype()); + } +} + +fn arg_unique(a: impl Iterator, capacity: usize) -> Vec +where + T: ToTotalOrd, + ::TotalOrdItem: Hash + Eq, +{ + let mut set = PlHashSet::new(); + let mut unique = Vec::with_capacity(capacity); + a.enumerate().for_each(|(idx, val)| { + if set.insert(val.to_total_ord()) { + unique.push(idx as IdxSize) + } + }); + unique +} + +macro_rules! arg_unique_ca { + ($ca:expr) => {{ + match $ca.has_nulls() { + false => arg_unique($ca.into_no_null_iter(), $ca.len()), + _ => arg_unique($ca.iter(), $ca.len()), + } + }}; +} + +impl ChunkUnique for ChunkedArray +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + Ord, + ChunkedArray: + IntoSeries + for<'a> ChunkCompareEq<&'a ChunkedArray, Item = BooleanChunked>, +{ + fn unique(&self) -> PolarsResult { + // prevent stackoverflow repeated sorted.unique call + if self.is_empty() { + return Ok(self.clone()); + } + match self.is_sorted_flag() { + IsSorted::Ascending | IsSorted::Descending => { + if self.null_count() > 0 { + let mut arr = MutablePrimitiveArray::with_capacity(self.len()); + + if !self.is_empty() { + let mut iter = self.iter(); + let last = iter.next().unwrap(); + arr.push(last); + let mut last = last.to_total_ord(); + + let to_extend = iter.filter(|opt_val| { + let opt_val_tot_ord = opt_val.to_total_ord(); + let out = opt_val_tot_ord != last; + last = opt_val_tot_ord; + out + }); + + arr.extend(to_extend); + } + + let arr: PrimitiveArray = arr.into(); + Ok(ChunkedArray::with_chunk(self.name().clone(), arr)) + } else { + let mask = self.not_equal_missing(&self.shift(1)); + self.filter(&mask) + } + }, + IsSorted::Not => { + let sorted = self.sort(false); + sorted.unique() + }, + } + } + + fn arg_unique(&self) -> PolarsResult { + Ok(IdxCa::from_vec(self.name().clone(), arg_unique_ca!(self))) + } + + fn n_unique(&self) -> PolarsResult { + // prevent stackoverflow repeated sorted.unique call + if self.is_empty() { + return Ok(0); + } + match self.is_sorted_flag() { + IsSorted::Ascending | IsSorted::Descending => { + if self.null_count() > 0 { + let mut count = 0; + + if self.is_empty() { + return Ok(count); + } + + let mut iter = self.iter(); + let mut last = iter.next().unwrap().to_total_ord(); + + count += 1; + + iter.for_each(|opt_val| { + let opt_val = opt_val.to_total_ord(); + if opt_val != last { + last = opt_val; + count += 1; + } + }); + + Ok(count) + } else { + let mask = self.not_equal_missing(&self.shift(1)); + Ok(mask.sum().unwrap() as usize) + } + }, + IsSorted::Not => { + let sorted = self.sort(false); + sorted.n_unique() + }, + } + } +} + +impl ChunkUnique for StringChunked { + fn unique(&self) -> PolarsResult { + let out = self.as_binary().unique()?; + Ok(unsafe { out.to_string_unchecked() }) + } + + fn arg_unique(&self) -> PolarsResult { + self.as_binary().arg_unique() + } + + fn n_unique(&self) -> PolarsResult { + self.as_binary().n_unique() + } +} + +impl ChunkUnique for BinaryChunked { + fn unique(&self) -> PolarsResult { + match self.null_count() { + 0 => { + let mut set = + PlHashSet::with_capacity(std::cmp::min(_HASHMAP_INIT_SIZE, self.len())); + for arr in self.downcast_iter() { + set.extend(arr.values_iter()) + } + Ok(BinaryChunked::from_iter_values( + self.name().clone(), + set.iter().copied(), + )) + }, + _ => { + let mut set = + PlHashSet::with_capacity(std::cmp::min(_HASHMAP_INIT_SIZE, self.len())); + for arr in self.downcast_iter() { + set.extend(arr.iter()) + } + Ok(BinaryChunked::from_iter_options( + self.name().clone(), + set.iter().copied(), + )) + }, + } + } + + fn arg_unique(&self) -> PolarsResult { + Ok(IdxCa::from_vec(self.name().clone(), arg_unique_ca!(self))) + } + + fn n_unique(&self) -> PolarsResult { + let mut set: PlHashSet<&[u8]> = PlHashSet::new(); + if self.null_count() > 0 { + for arr in self.downcast_iter() { + set.extend(arr.into_iter().flatten()) + } + Ok(set.len() + 1) + } else { + for arr in self.downcast_iter() { + set.extend(arr.values_iter()) + } + Ok(set.len()) + } + } +} + +impl ChunkUnique for BooleanChunked { + fn unique(&self) -> PolarsResult { + use polars_compute::unique::RangedUniqueKernel; + + let mut state = BooleanUniqueKernelState::new(); + + for arr in self.downcast_iter() { + state.append(arr); + + if state.has_seen_all() { + break; + } + } + + let unique = state.finalize_unique(); + + Ok(Self::with_chunk(self.name().clone(), unique)) + } + + fn arg_unique(&self) -> PolarsResult { + Ok(IdxCa::from_vec(self.name().clone(), arg_unique_ca!(self))) + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + fn unique() { + let ca = + ChunkedArray::::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3, 2, 1]); + assert_eq!( + ca.unique() + .unwrap() + .sort(false) + .into_iter() + .collect::>(), + vec![Some(1), Some(2), Some(3)] + ); + let ca = BooleanChunked::from_slice(PlSmallStr::from_static("a"), &[true, false, true]); + assert_eq!( + ca.unique().unwrap().into_iter().collect::>(), + vec![Some(false), Some(true)] + ); + + let ca = StringChunked::new( + PlSmallStr::EMPTY, + &[Some("a"), None, Some("a"), Some("b"), None], + ); + assert_eq!( + Vec::from(&ca.unique().unwrap().sort(false)), + &[None, Some("a"), Some("b")] + ); + } + + #[test] + fn arg_unique() { + let ca = + ChunkedArray::::from_slice(PlSmallStr::from_static("a"), &[1, 2, 1, 1, 3]); + assert_eq!( + ca.arg_unique().unwrap().into_iter().collect::>(), + vec![Some(0), Some(1), Some(4)] + ); + } +} diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs new file mode 100644 index 000000000000..3597be035153 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -0,0 +1,535 @@ +use std::borrow::Cow; + +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::compute::utils::{combine_validities_and, combine_validities_and_not}; +use polars_compute::if_then_else::{IfThenElseKernel, if_then_else_validity}; + +#[cfg(feature = "object")] +use crate::chunked_array::object::ObjectArray; +use crate::prelude::*; +use crate::utils::{align_chunks_binary, align_chunks_ternary}; + +const SHAPE_MISMATCH_STR: &str = + "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"; + +fn if_then_else_broadcast_mask( + mask: bool, + if_true: &ChunkedArray, + if_false: &ChunkedArray, +) -> PolarsResult> +where + ChunkedArray: ChunkExpandAtIndex, +{ + let src = if mask { if_true } else { if_false }; + let other = if mask { if_false } else { if_true }; + let ret = match (src.len(), other.len()) { + (a, b) if a == b => src.clone(), + (_, 1) => src.clone(), + (1, other_len) => src.new_from_index(0, other_len), + _ => polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR), + }; + Ok(ret.with_name(if_true.name().clone())) +} + +fn bool_null_to_false(mask: &BooleanArray) -> Bitmap { + if mask.null_count() == 0 { + mask.values().clone() + } else { + mask.values() & mask.validity().unwrap() + } +} + +/// Combines the validities of ca with the bits in mask using the given combiner. +/// +/// If the mask itself has validity, those null bits are converted to false. +fn combine_validities_chunked< + T: PolarsDataType, + F: Fn(Option<&Bitmap>, Option<&Bitmap>) -> Option, +>( + ca: &ChunkedArray, + mask: &BooleanChunked, + combiner: F, +) -> ChunkedArray { + let (ca_al, mask_al) = align_chunks_binary(ca, mask); + let chunks = ca_al + .downcast_iter() + .zip(mask_al.downcast_iter()) + .map(|(a, m)| { + let bm = bool_null_to_false(m); + let validity = combiner(a.validity(), Some(&bm)); + a.clone().with_validity_typed(validity) + }); + ChunkedArray::from_chunk_iter_like(ca, chunks) +} + +impl ChunkZip for ChunkedArray +where + T: PolarsDataType, + T::Array: for<'a> IfThenElseKernel = T::Physical<'a>>, + ChunkedArray: ChunkExpandAtIndex, +{ + fn zip_with( + &self, + mask: &BooleanChunked, + other: &ChunkedArray, + ) -> PolarsResult> { + let if_true = self; + let if_false = other; + + // Broadcast mask. + if mask.len() == 1 { + return if_then_else_broadcast_mask(mask.get(0).unwrap_or(false), if_true, if_false); + } + + // Broadcast both. + let ret = if if_true.len() == 1 && if_false.len() == 1 { + match (if_true.get(0), if_false.get(0)) { + (None, None) => ChunkedArray::full_null_like(if_true, mask.len()), + (None, Some(_)) => combine_validities_chunked( + &if_false.new_from_index(0, mask.len()), + mask, + combine_validities_and_not, + ), + (Some(_), None) => combine_validities_chunked( + &if_true.new_from_index(0, mask.len()), + mask, + combine_validities_and, + ), + (Some(t), Some(f)) => { + let dtype = if_true.downcast_iter().next().unwrap().dtype(); + let chunks = mask.downcast_iter().map(|m| { + let bm = bool_null_to_false(m); + let t = t.clone(); + let f = f.clone(); + IfThenElseKernel::if_then_else_broadcast_both(dtype.clone(), &bm, t, f) + }); + ChunkedArray::from_chunk_iter_like(if_true, chunks) + }, + } + + // Broadcast neither. + } else if if_true.len() == if_false.len() { + polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR); + let (mask_al, if_true_al, if_false_al) = align_chunks_ternary(mask, if_true, if_false); + let chunks = mask_al + .downcast_iter() + .zip(if_true_al.downcast_iter()) + .zip(if_false_al.downcast_iter()) + .map(|((m, t), f)| IfThenElseKernel::if_then_else(&bool_null_to_false(m), t, f)); + ChunkedArray::from_chunk_iter_like(if_true, chunks) + + // Broadcast true value. + } else if if_true.len() == 1 { + polars_ensure!(mask.len() == if_false.len(), ShapeMismatch: SHAPE_MISMATCH_STR); + if let Some(true_scalar) = if_true.get(0) { + let (mask_al, if_false_al) = align_chunks_binary(mask, if_false); + let chunks = mask_al + .downcast_iter() + .zip(if_false_al.downcast_iter()) + .map(|(m, f)| { + let bm = bool_null_to_false(m); + let t = true_scalar.clone(); + IfThenElseKernel::if_then_else_broadcast_true(&bm, t, f) + }); + ChunkedArray::from_chunk_iter_like(if_true, chunks) + } else { + combine_validities_chunked(if_false, mask, combine_validities_and_not) + } + + // Broadcast false value. + } else if if_false.len() == 1 { + polars_ensure!(mask.len() == if_true.len(), ShapeMismatch: SHAPE_MISMATCH_STR); + if let Some(false_scalar) = if_false.get(0) { + let (mask_al, if_true_al) = align_chunks_binary(mask, if_true); + let chunks = + mask_al + .downcast_iter() + .zip(if_true_al.downcast_iter()) + .map(|(m, t)| { + let bm = bool_null_to_false(m); + let f = false_scalar.clone(); + IfThenElseKernel::if_then_else_broadcast_false(&bm, t, f) + }); + ChunkedArray::from_chunk_iter_like(if_false, chunks) + } else { + combine_validities_chunked(if_true, mask, combine_validities_and) + } + } else { + polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR) + }; + + Ok(ret.with_name(if_true.name().clone())) + } +} + +// Basic implementation for ObjectArray. +#[cfg(feature = "object")] +impl IfThenElseKernel for ObjectArray { + type Scalar<'a> = &'a T; + + fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { + mask.iter() + .zip(if_true.iter()) + .zip(if_false.iter()) + .map(|((m, t), f)| if m { t } else { f }) + .collect_arr() + } + + fn if_then_else_broadcast_true( + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: &Self, + ) -> Self { + mask.iter() + .zip(if_false.iter()) + .map(|(m, f)| if m { Some(if_true) } else { f }) + .collect_arr() + } + + fn if_then_else_broadcast_false( + mask: &Bitmap, + if_true: &Self, + if_false: Self::Scalar<'_>, + ) -> Self { + mask.iter() + .zip(if_true.iter()) + .map(|(m, t)| if m { t } else { Some(if_false) }) + .collect_arr() + } + + fn if_then_else_broadcast_both( + _dtype: ArrowDataType, + mask: &Bitmap, + if_true: Self::Scalar<'_>, + if_false: Self::Scalar<'_>, + ) -> Self { + mask.iter() + .map(|m| if m { if_true } else { if_false }) + .collect_arr() + } +} + +#[cfg(feature = "dtype-struct")] +impl ChunkZip for StructChunked { + fn zip_with( + &self, + mask: &BooleanChunked, + other: &ChunkedArray, + ) -> PolarsResult> { + let min_length = self.length.min(mask.length).min(other.length); + let max_length = self.length.max(mask.length).max(other.length); + + let length = if min_length == 0 { 0 } else { max_length }; + + debug_assert!(self.length == 1 || self.length == length); + debug_assert!(mask.length == 1 || mask.length == length); + debug_assert!(other.length == 1 || other.length == length); + + let mut if_true: Cow> = Cow::Borrowed(self); + let mut if_false: Cow> = Cow::Borrowed(other); + + // Special case. In this case, we know what to do. + // @TODO: Optimization. If all mask values are the same, select one of the two. + if mask.length == 1 { + // pl.when(None) <=> pl.when(False) + let is_true = mask.get(0).unwrap_or(false); + return Ok(if is_true && self.length == 1 { + self.new_from_index(0, length) + } else if is_true { + self.clone() + } else if other.length == 1 { + let mut s = other.new_from_index(0, length); + s.rename(self.name().clone()); + s + } else { + let mut s = other.clone(); + s.rename(self.name().clone()); + s + }); + } + + // align_chunks_ternary can only align chunks if: + // - Each chunkedarray only has 1 chunk + // - Each chunkedarray has an equal length (i.e. is broadcasted) + // + // Therefore, we broadcast only those that are necessary to be broadcasted. + let needs_broadcast = + if_true.chunks().len() > 1 || if_false.chunks().len() > 1 || mask.chunks().len() > 1; + if needs_broadcast && length > 1 { + if self.length == 1 { + let broadcasted = self.new_from_index(0, length); + if_true = Cow::Owned(broadcasted); + } + if other.length == 1 { + let broadcasted = other.new_from_index(0, length); + if_false = Cow::Owned(broadcasted); + } + } + + let if_true = if_true.as_ref(); + let if_false = if_false.as_ref(); + + let (if_true, if_false, mask) = align_chunks_ternary(if_true, if_false, mask); + + // Prepare the boolean arrays such that Null maps to false. + // This prevents every field doing that. + // # SAFETY + // We don't modify the length and update the null count. + let mut mask = mask.into_owned(); + unsafe { + for arr in mask.downcast_iter_mut() { + let bm = bool_null_to_false(arr); + *arr = BooleanArray::from_data_default(bm, None); + } + mask.set_null_count(0); + } + + // Zip all the fields. + let fields = if_true + .fields_as_series() + .iter() + .zip(if_false.fields_as_series()) + .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs)) + .collect::>>()?; + + let mut out = StructChunked::from_series(self.name().clone(), length, fields.iter())?; + + fn rechunk_bitmaps( + total_length: usize, + iter: impl Iterator)>, + ) -> Option { + let mut rechunked_length = 0; + let mut rechunked_validity = None; + for (chunk_length, validity) in iter { + if let Some(validity) = validity { + if validity.unset_bits() > 0 { + rechunked_validity + .get_or_insert_with(|| { + let mut bm = BitmapBuilder::with_capacity(total_length); + bm.extend_constant(rechunked_length, true); + bm + }) + .extend_from_bitmap(&validity); + } + } + + rechunked_length += chunk_length; + } + + if let Some(rechunked_validity) = rechunked_validity.as_mut() { + rechunked_validity.extend_constant(total_length - rechunked_validity.len(), true); + } + + rechunked_validity.map(BitmapBuilder::freeze) + } + + // Zip the validities. + // + // We need to take two things into account: + // 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`. + // 2. `l` and `r` might still need to be broadcasted. + if (if_true.null_count + if_false.null_count) > 0 { + // Create one validity mask that spans the entirety of out. + let rechunked_validity = match (if_true.len(), if_false.len()) { + (1, 1) if length != 1 => { + match (if_true.null_count() == 0, if_false.null_count() == 0) { + (true, true) => None, + (false, true) => { + if mask.chunks().len() == 1 { + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + Some(!m) + } else { + rechunk_bitmaps( + length, + mask.downcast_iter() + .map(|m| (m.len(), Some(m.values().clone()))), + ) + } + }, + (true, false) => { + if mask.chunks().len() == 1 { + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + Some(m.clone()) + } else { + rechunk_bitmaps( + length, + mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))), + ) + } + }, + (false, false) => Some(Bitmap::new_zeroed(length)), + } + }, + (1, _) if length != 1 => { + debug_assert!( + if_false + .chunk_lengths() + .zip(mask.chunk_lengths()) + .all(|(r, m)| r == m) + ); + + let combine = if if_true.null_count() == 0 { + |if_false: Option<&Bitmap>, m: &Bitmap| { + if_false.map(|v| arrow::bitmap::or(v, m)) + } + } else { + |if_false: Option<&Bitmap>, m: &Bitmap| { + Some(if_false.map_or_else(|| !m, |v| arrow::bitmap::and_not(v, m))) + } + }; + + if if_false.chunks().len() == 1 { + let if_false = if_false.chunks()[0].validity(); + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + let validity = combine(if_false, m); + validity.filter(|v| v.unset_bits() > 0) + } else { + rechunk_bitmaps( + length, + if_false.chunks().iter().zip(mask.downcast_iter()).map( + |(chunk, mask)| { + (mask.len(), combine(chunk.validity(), mask.values())) + }, + ), + ) + } + }, + (_, 1) if length != 1 => { + debug_assert!( + if_true + .chunk_lengths() + .zip(mask.chunk_lengths()) + .all(|(l, m)| l == m) + ); + + let combine = if if_false.null_count() == 0 { + |if_true: Option<&Bitmap>, m: &Bitmap| { + if_true.map(|v| arrow::bitmap::or_not(v, m)) + } + } else { + |if_true: Option<&Bitmap>, m: &Bitmap| { + Some(if_true.map_or_else(|| m.clone(), |v| arrow::bitmap::and(v, m))) + } + }; + + if if_true.chunks().len() == 1 { + let if_true = if_true.chunks()[0].validity(); + let m = mask.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + let validity = combine(if_true, m); + validity.filter(|v| v.unset_bits() > 0) + } else { + rechunk_bitmaps( + length, + if_true.chunks().iter().zip(mask.downcast_iter()).map( + |(chunk, mask)| { + (mask.len(), combine(chunk.validity(), mask.values())) + }, + ), + ) + } + }, + (_, _) => { + debug_assert!( + if_true + .chunk_lengths() + .zip(if_false.chunk_lengths()) + .all(|(l, r)| l == r) + ); + debug_assert!( + if_true + .chunk_lengths() + .zip(mask.chunk_lengths()) + .all(|(l, r)| l == r) + ); + + let validities = if_true + .chunks() + .iter() + .zip(if_false.chunks()) + .map(|(l, r)| (l.validity(), r.validity())); + + rechunk_bitmaps( + length, + validities + .zip(mask.downcast_iter()) + .map(|((if_true, if_false), mask)| { + ( + mask.len(), + if_then_else_validity(mask.values(), if_true, if_false), + ) + }), + ) + }, + }; + + // Apply the validity spreading over the chunks of out. + if let Some(mut rechunked_validity) = rechunked_validity { + assert_eq!(rechunked_validity.len(), out.len()); + + let num_chunks = out.chunks().len(); + let null_count = rechunked_validity.unset_bits(); + + // SAFETY: We do not change the lengths of the chunks and we update the null_count + // afterwards. + let chunks = unsafe { out.chunks_mut() }; + + if num_chunks == 1 { + chunks[0] = chunks[0].with_validity(Some(rechunked_validity)); + } else { + for chunk in chunks { + let chunk_len = chunk.len(); + let chunk_validity; + + // SAFETY: We know that rechunked_validity.len() == out.len() + (chunk_validity, rechunked_validity) = + unsafe { rechunked_validity.split_at_unchecked(chunk_len) }; + *chunk = chunk.with_validity( + (chunk_validity.unset_bits() > 0).then_some(chunk_validity), + ); + } + } + + out.null_count = null_count; + } else { + // SAFETY: We do not change the lengths of the chunks and we update the null_count + // afterwards. + let chunks = unsafe { out.chunks_mut() }; + + for chunk in chunks { + *chunk = chunk.with_validity(None); + } + + out.null_count = 0; + } + } + + if cfg!(debug_assertions) { + let start_length = out.len(); + let start_null_count = out.null_count(); + + out.compute_len(); + + assert_eq!(start_length, out.len()); + assert_eq!(start_null_count, out.null_count()); + } + Ok(out) + } +} diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs new file mode 100644 index 000000000000..552cbe80b121 --- /dev/null +++ b/crates/polars-core/src/chunked_array/random.rs @@ -0,0 +1,393 @@ +use num_traits::{Float, NumCast}; +use polars_error::to_compute_err; +use rand::distributions::Bernoulli; +use rand::prelude::*; +use rand::seq::index::IndexVec; +use rand_distr::{Normal, Standard, StandardNormal, Uniform}; + +use crate::prelude::DataType::Float64; +use crate::prelude::*; +use crate::random::get_global_random_u64; +use crate::utils::NoNull; + +fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option) -> IdxCa { + if len == 0 { + return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]); + } + let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); + let dist = Uniform::new(0, len as IdxSize); + (0..n as IdxSize) + .map(move |_| dist.sample(&mut rng)) + .collect_trusted::>() + .into_inner() +} + +fn create_rand_index_no_replacement( + n: usize, + len: usize, + seed: Option, + shuffle: bool, +) -> IdxCa { + let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); + let mut buf: Vec; + if n == len { + buf = (0..len as IdxSize).collect(); + if shuffle { + buf.shuffle(&mut rng) + } + } else { + // TODO: avoid extra potential copy by vendoring rand::seq::index::sample, + // or genericize take over slices over any unsigned type. The optimizer + // should get rid of the extra copy already if IdxSize matches the IndexVec + // size returned. + buf = match rand::seq::index::sample(&mut rng, len, n) { + IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(), + IndexVec::USize(v) => v.into_iter().map(|x| x as IdxSize).collect(), + }; + } + IdxCa::new_vec(PlSmallStr::EMPTY, buf) +} + +impl ChunkedArray +where + T: PolarsNumericType, + Standard: Distribution, +{ + pub fn init_rand(size: usize, null_density: f32, seed: Option) -> Self { + let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); + (0..size) + .map(|_| { + if rng.r#gen::() < null_density { + None + } else { + Some(rng.r#gen()) + } + }) + .collect() + } +} + +fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> { + polars_ensure!( + with_replacement || n <= len, + ShapeMismatch: + "cannot take a larger sample than the total population when `with_replacement=false`" + ); + Ok(()) +} + +impl Series { + pub fn sample_n( + &self, + n: usize, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + ensure_shape(n, self.len(), with_replacement)?; + if n == 0 { + return Ok(self.clear()); + } + let len = self.len(); + + match with_replacement { + true => { + let idx = create_rand_index_with_replacement(n, len, seed); + debug_assert_eq!(len, self.len()); + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } + }, + false => { + let idx = create_rand_index_no_replacement(n, len, seed, shuffle); + debug_assert_eq!(len, self.len()); + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } + }, + } + } + + /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`]. + pub fn sample_frac( + &self, + frac: f64, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + let n = (self.len() as f64 * frac) as usize; + self.sample_n(n, with_replacement, shuffle, seed) + } + + pub fn shuffle(&self, seed: Option) -> Self { + let len = self.len(); + let n = len; + let idx = create_rand_index_no_replacement(n, len, seed, true); + debug_assert_eq!(len, self.len()); + // SAFETY: we know that we never go out of bounds. + unsafe { self.take_unchecked(&idx) } + } +} + +impl ChunkedArray +where + T: PolarsDataType, + ChunkedArray: ChunkTake, +{ + /// Sample n datapoints from this [`ChunkedArray`]. + pub fn sample_n( + &self, + n: usize, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + ensure_shape(n, self.len(), with_replacement)?; + let len = self.len(); + + match with_replacement { + true => { + let idx = create_rand_index_with_replacement(n, len, seed); + debug_assert_eq!(len, self.len()); + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } + }, + false => { + let idx = create_rand_index_no_replacement(n, len, seed, shuffle); + debug_assert_eq!(len, self.len()); + // SAFETY: we know that we never go out of bounds. + unsafe { Ok(self.take_unchecked(&idx)) } + }, + } + } + + /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`]. + pub fn sample_frac( + &self, + frac: f64, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + let n = (self.len() as f64 * frac) as usize; + self.sample_n(n, with_replacement, shuffle, seed) + } +} + +impl DataFrame { + /// Sample n datapoints from this [`DataFrame`]. + pub fn sample_n( + &self, + n: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + polars_ensure!( + n.len() == 1, + ComputeError: "Sample size must be a single value." + ); + + let n = n.cast(&IDX_DTYPE)?; + let n = n.idx()?; + + match n.get(0) { + Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), + None => Ok(self.clear()), + } + } + + pub fn sample_n_literal( + &self, + n: usize, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + ensure_shape(n, self.height(), with_replacement)?; + // All columns should used the same indices. So we first create the indices. + let idx = match with_replacement { + true => create_rand_index_with_replacement(n, self.height(), seed), + false => create_rand_index_no_replacement(n, self.height(), seed, shuffle), + }; + // SAFETY: the indices are within bounds. + Ok(unsafe { self.take_unchecked(&idx) }) + } + + /// Sample a fraction between 0.0-1.0 of this [`DataFrame`]. + pub fn sample_frac( + &self, + frac: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + polars_ensure!( + frac.len() == 1, + ComputeError: "Sample fraction must be a single value." + ); + + let frac = frac.cast(&Float64)?; + let frac = frac.f64()?; + + match frac.get(0) { + Some(frac) => { + let n = (self.height() as f64 * frac) as usize; + self.sample_n_literal(n, with_replacement, shuffle, seed) + }, + None => Ok(self.clear()), + } + } +} + +impl ChunkedArray +where + T: PolarsNumericType, + T::Native: Float, +{ + /// Create [`ChunkedArray`] with samples from a Normal distribution. + pub fn rand_normal( + name: PlSmallStr, + length: usize, + mean: f64, + std_dev: f64, + ) -> PolarsResult { + let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?; + let mut builder = PrimitiveChunkedBuilder::::new(name, length); + let mut rng = rand::thread_rng(); + for _ in 0..length { + let smpl = normal.sample(&mut rng); + let smpl = NumCast::from(smpl).unwrap(); + builder.append_value(smpl) + } + Ok(builder.finish()) + } + + /// Create [`ChunkedArray`] with samples from a Standard Normal distribution. + pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self { + let mut builder = PrimitiveChunkedBuilder::::new(name, length); + let mut rng = rand::thread_rng(); + for _ in 0..length { + let smpl: f64 = rng.sample(StandardNormal); + let smpl = NumCast::from(smpl).unwrap(); + builder.append_value(smpl) + } + builder.finish() + } + + /// Create [`ChunkedArray`] with samples from a Uniform distribution. + pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self { + let uniform = Uniform::new(low, high); + let mut builder = PrimitiveChunkedBuilder::::new(name, length); + let mut rng = rand::thread_rng(); + for _ in 0..length { + let smpl = uniform.sample(&mut rng); + let smpl = NumCast::from(smpl).unwrap(); + builder.append_value(smpl) + } + builder.finish() + } +} + +impl BooleanChunked { + /// Create [`ChunkedArray`] with samples from a Bernoulli distribution. + pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult { + let dist = Bernoulli::new(p).map_err(to_compute_err)?; + let mut rng = rand::thread_rng(); + let mut builder = BooleanChunkedBuilder::new(name, length); + for _ in 0..length { + let smpl = dist.sample(&mut rng); + builder.append_value(smpl) + } + Ok(builder.finish()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_sample() { + let df = df![ + "foo" => &[1, 2, 3, 4, 5] + ] + .unwrap(); + + // Default samples are random and don't require seeds. + assert!( + df.sample_n( + &Series::new(PlSmallStr::from_static("s"), &[3]), + false, + false, + None + ) + .is_ok() + ); + assert!( + df.sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[0.4]), + false, + false, + None + ) + .is_ok() + ); + // With seeding. + assert!( + df.sample_n( + &Series::new(PlSmallStr::from_static("s"), &[3]), + false, + false, + Some(0) + ) + .is_ok() + ); + assert!( + df.sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[0.4]), + false, + false, + Some(0) + ) + .is_ok() + ); + // Without replacement can not sample more than 100%. + assert!( + df.sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[2.0]), + false, + false, + Some(0) + ) + .is_err() + ); + assert!( + df.sample_n( + &Series::new(PlSmallStr::from_static("s"), &[3]), + true, + false, + Some(0) + ) + .is_ok() + ); + assert!( + df.sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[0.4]), + true, + false, + Some(0) + ) + .is_ok() + ); + // With replacement can sample more than 100%. + assert!( + df.sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[2.0]), + true, + false, + Some(0) + ) + .is_ok() + ); + } +} diff --git a/crates/polars-core/src/chunked_array/struct_/frame.rs b/crates/polars-core/src/chunked_array/struct_/frame.rs new file mode 100644 index 000000000000..b175b3a04832 --- /dev/null +++ b/crates/polars-core/src/chunked_array/struct_/frame.rs @@ -0,0 +1,10 @@ +use polars_utils::pl_str::PlSmallStr; + +use crate::frame::DataFrame; +use crate::prelude::StructChunked; + +impl DataFrame { + pub fn into_struct(self, name: PlSmallStr) -> StructChunked { + StructChunked::from_columns(name, self.height(), &self.columns).expect("same invariants") + } +} diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs new file mode 100644 index 000000000000..8e8685726223 --- /dev/null +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -0,0 +1,467 @@ +mod frame; + +use std::borrow::Cow; +use std::fmt::Write; + +use arrow::array::StructArray; +use arrow::bitmap::Bitmap; +use arrow::compute::utils::combine_validities_and; +use polars_error::{PolarsResult, polars_ensure}; +use polars_utils::aliases::PlHashMap; +use polars_utils::itertools::Itertools; + +use crate::chunked_array::ChunkedArray; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::{_get_rows_encoded_arr, _get_rows_encoded_ca}; +use crate::prelude::*; +use crate::series::Series; +use crate::utils::Container; + +pub type StructChunked = ChunkedArray; + +fn constructor<'a, I: ExactSizeIterator + Clone>( + name: PlSmallStr, + length: usize, + fields: I, +) -> StructChunked { + if fields.len() == 0 { + let dtype = DataType::Struct(Vec::new()); + let arrow_dtype = dtype.to_physical().to_arrow(CompatLevel::newest()); + let chunks = vec![StructArray::new(arrow_dtype, length, Vec::new(), None).boxed()]; + + // SAFETY: We construct each chunk above to have the `Struct` data type. + return unsafe { StructChunked::from_chunks_and_dtype(name, chunks, dtype) }; + } + + // Different chunk lengths: rechunk and recurse. + if !fields.clone().map(|s| s.n_chunks()).all_equal() { + let fields = fields.map(|s| s.rechunk()).collect::>(); + return constructor(name, length, fields.iter()); + } + + let n_chunks = fields.clone().next().unwrap().n_chunks(); + let dtype = DataType::Struct(fields.clone().map(|s| s.field().into_owned()).collect()); + let arrow_dtype = dtype.to_physical().to_arrow(CompatLevel::newest()); + + let chunks = (0..n_chunks) + .map(|c_i| { + let fields = fields + .clone() + .map(|field| field.chunks()[c_i].clone()) + .collect::>(); + let chunk_length = fields[0].len(); + + if fields[1..].iter().any(|arr| chunk_length != arr.len()) { + return None; + } + + Some(StructArray::new(arrow_dtype.clone(), chunk_length, fields, None).boxed()) + }) + .collect::>>(); + + match chunks { + Some(chunks) => { + // SAFETY: invariants checked above. + unsafe { StructChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype) } + }, + // Different chunks: rechunk and recurse. + None => { + let fields = fields.map(|s| s.rechunk()).collect::>(); + constructor(name, length, fields.iter()) + }, + } +} + +impl StructChunked { + pub fn from_columns(name: PlSmallStr, length: usize, fields: &[Column]) -> PolarsResult { + Self::from_series( + name, + length, + fields.iter().map(|c| c.as_materialized_series()), + ) + } + + pub fn from_series<'a, I: ExactSizeIterator + Clone>( + name: PlSmallStr, + length: usize, + fields: I, + ) -> PolarsResult { + let mut names = PlHashSet::with_capacity(fields.len()); + + let mut needs_to_broadcast = false; + for s in fields.clone() { + let s_len = s.len(); + + if s_len != length && s_len != 1 { + polars_bail!( + ShapeMismatch: "expected struct fields to have given length. given = {length}, field length = {s_len}." + ); + } + + needs_to_broadcast |= length != 1 && s_len == 1; + + polars_ensure!( + names.insert(s.name()), + Duplicate: "multiple fields with name '{}' found", s.name() + ); + + match s.dtype() { + #[cfg(feature = "object")] + DataType::Object(_) => { + polars_bail!(InvalidOperation: "nested objects are not allowed") + }, + _ => {}, + } + } + + if !needs_to_broadcast { + return Ok(constructor(name, length, fields)); + } + + if length == 0 { + // @NOTE: There are columns that are being broadcasted so we need to clear those. + let new_fields = fields.map(|s| s.clear()).collect::>(); + + return Ok(constructor(name, length, new_fields.iter())); + } + + let new_fields = fields + .map(|s| { + if s.len() == length { + s.clone() + } else { + s.new_from_index(0, length) + } + }) + .collect::>(); + Ok(constructor(name, length, new_fields.iter())) + } + + /// Convert a struct to the underlying physical datatype. + pub fn to_physical_repr(&self) -> Cow { + let mut physicals = Vec::new(); + + let field_series = self.fields_as_series(); + for (i, s) in field_series.iter().enumerate() { + if let Cow::Owned(physical) = s.to_physical_repr() { + physicals.reserve(field_series.len()); + physicals.extend(field_series[..i].iter().cloned()); + physicals.push(physical); + break; + } + } + + if physicals.is_empty() { + return Cow::Borrowed(self); + } + + physicals.extend( + field_series[physicals.len()..] + .iter() + .map(|s| s.to_physical_repr().into_owned()), + ); + + let mut ca = constructor(self.name().clone(), self.length, physicals.iter()); + if self.null_count() > 0 { + ca.zip_outer_validity(self); + } + + Cow::Owned(ca) + } + + /// Convert a non-logical [`StructChunked`] back into a logical [`StructChunked`] without casting. + /// + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn from_physical_unchecked( + &self, + to_fields: &[Field], + ) -> PolarsResult { + if cfg!(debug_assertions) { + for f in self.struct_fields() { + assert!(!f.dtype().is_logical()); + } + } + + let length = self.len(); + let fields = self + .fields_as_series() + .iter() + .zip(to_fields) + .map(|(f, to)| unsafe { f.from_physical_unchecked(to.dtype()) }) + .collect::>>()?; + + let mut out = StructChunked::from_series(self.name().clone(), length, fields.iter())?; + out.zip_outer_validity(self); + Ok(out) + } + + pub fn struct_fields(&self) -> &[Field] { + let DataType::Struct(fields) = self.dtype() else { + unreachable!() + }; + fields + } + + pub fn fields_as_series(&self) -> Vec { + self.struct_fields() + .iter() + .enumerate() + .map(|(i, field)| { + let field_chunks = self + .downcast_iter() + .map(|chunk| chunk.values()[i].clone()) + .collect::>(); + + // SAFETY: correct type. + unsafe { + Series::from_chunks_and_dtype_unchecked( + field.name.clone(), + field_chunks, + &field.dtype, + ) + } + }) + .collect() + } + + unsafe fn cast_impl( + &self, + dtype: &DataType, + cast_options: CastOptions, + unchecked: bool, + ) -> PolarsResult { + match dtype { + DataType::Struct(dtype_fields) => { + let fields = self.fields_as_series(); + let map = PlHashMap::from_iter(fields.iter().map(|s| (s.name(), s))); + let struct_len = self.len(); + let new_fields = dtype_fields + .iter() + .map(|new_field| match map.get(new_field.name()) { + Some(s) => { + if unchecked { + s.cast_unchecked(&new_field.dtype) + } else { + s.cast_with_options(&new_field.dtype, cast_options) + } + }, + None => Ok(Series::full_null( + new_field.name().clone(), + struct_len, + &new_field.dtype, + )), + }) + .collect::>>()?; + + let mut out = + Self::from_series(self.name().clone(), struct_len, new_fields.iter())?; + if self.null_count > 0 { + out.zip_outer_validity(self); + } + Ok(out.into_series()) + }, + DataType::String => { + let ca = self.rechunk(); + let fields = ca.fields_as_series(); + let mut iters = fields.iter().map(|s| s.iter()).collect::>(); + let cap = ca.len(); + + let mut builder = MutablePlString::with_capacity(cap); + let mut scratch = String::new(); + + for _ in 0..ca.len() { + let mut row_has_nulls = false; + + write!(scratch, "{{").unwrap(); + for iter in &mut iters { + let av = unsafe { iter.next().unwrap_unchecked() }; + row_has_nulls |= matches!(&av, AnyValue::Null); + write!(scratch, "{},", av).unwrap(); + } + + // replace latest comma with '|' + unsafe { + *scratch.as_bytes_mut().last_mut().unwrap_unchecked() = b'}'; + } + + // TODO: this seem strange to me. We should use outer mutability to determine this. + // Also we should move this whole cast into arrow logic. + if row_has_nulls { + builder.push_null() + } else { + builder.push_value(scratch.as_str()); + } + scratch.clear(); + } + let array = builder.freeze().boxed(); + Series::try_from((ca.name().clone(), array)) + }, + _ => { + let fields = self + .fields_as_series() + .iter() + .map(|s| { + if unchecked { + s.cast_unchecked(dtype) + } else { + s.cast_with_options(dtype, cast_options) + } + }) + .collect::>>()?; + let mut out = Self::from_series(self.name().clone(), self.len(), fields.iter())?; + if self.null_count > 0 { + out.zip_outer_validity(self); + } + Ok(out.into_series()) + }, + } + } + + pub(crate) unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + if dtype == self.dtype() { + return Ok(self.clone().into_series()); + } + self.cast_impl(dtype, CastOptions::Overflowing, true) + } + + // in case of a struct, a cast will coerce the inner types + pub fn cast_with_options( + &self, + dtype: &DataType, + cast_options: CastOptions, + ) -> PolarsResult { + unsafe { self.cast_impl(dtype, cast_options, false) } + } + + pub fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } + + pub fn _apply_fields(&self, mut func: F) -> PolarsResult + where + F: FnMut(&Series) -> Series, + { + self.try_apply_fields(|s| Ok(func(s))) + } + + pub fn try_apply_fields(&self, func: F) -> PolarsResult + where + F: FnMut(&Series) -> PolarsResult, + { + let fields = self + .fields_as_series() + .iter() + .map(func) + .collect::>>()?; + Self::from_series(self.name().clone(), self.len(), fields.iter()).map(|mut ca| { + if self.null_count > 0 { + // SAFETY: we don't change types/ lengths. + unsafe { + for (new, this) in ca.downcast_iter_mut().zip(self.downcast_iter()) { + new.set_validity(this.validity().cloned()) + } + } + } + ca + }) + } + + pub fn get_row_encoded_array(&self, options: SortOptions) -> PolarsResult> { + let c = self.clone().into_column(); + _get_rows_encoded_arr(&[c], &[options.descending], &[options.nulls_last]) + } + + pub fn get_row_encoded(&self, options: SortOptions) -> PolarsResult { + let c = self.clone().into_column(); + _get_rows_encoded_ca( + self.name().clone(), + &[c], + &[options.descending], + &[options.nulls_last], + ) + } + + /// Set the outer nulls into the inner arrays, and clear the outer validity. + pub(crate) fn propagate_nulls(&mut self) { + if self.null_count > 0 { + // SAFETY: + // We keep length and dtypes the same. + unsafe { + for arr in self.downcast_iter_mut() { + *arr = arr.propagate_nulls() + } + } + } + } + + /// Combine the validities of two structs. + pub fn zip_outer_validity(&mut self, other: &StructChunked) { + // This might go wrong for broadcasting behavior. If this is not checked, it leads to a + // segfault because we infinitely recurse. + assert_eq!(self.len(), other.len()); + + if other.null_count() == 0 { + return; + } + + if self.chunks.len() != other.chunks.len() + || self + .chunks + .iter() + .zip(other.chunks()) + .any(|(a, b)| a.len() != b.len()) + { + self.rechunk_mut(); + let other = other.rechunk(); + return self.zip_outer_validity(&other); + } + + // SAFETY: + // We keep length and dtypes the same. + unsafe { + for (a, b) in self.downcast_iter_mut().zip(other.downcast_iter()) { + let new = combine_validities_and(a.validity(), b.validity()); + a.set_validity(new) + } + } + + self.compute_len(); + self.propagate_nulls(); + } + + pub fn unnest(self) -> DataFrame { + // @scalar-opt + let columns = self + .fields_as_series() + .into_iter() + .map(Column::from) + .collect::>(); + + // SAFETY: invariants for struct are the same + unsafe { DataFrame::new_no_checks(self.len(), columns) } + } + + /// Get access to one of this [`StructChunked`]'s fields + pub fn field_by_name(&self, name: &str) -> PolarsResult { + self.fields_as_series() + .into_iter() + .find(|s| s.name().as_str() == name) + .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name)) + } + pub(crate) fn set_outer_validity(&mut self, validity: Option) { + assert_eq!(self.chunks().len(), 1); + unsafe { + let arr = self.chunks_mut().iter_mut().next().unwrap(); + *arr = arr.with_validity(validity); + } + self.compute_len(); + self.propagate_nulls(); + } + + pub fn with_outer_validity(mut self, validity: Option) -> Self { + self.set_outer_validity(validity); + self + } +} diff --git a/crates/polars-core/src/chunked_array/temporal/conversion.rs b/crates/polars-core/src/chunked_array/temporal/conversion.rs new file mode 100644 index 000000000000..7209f4bbdbb3 --- /dev/null +++ b/crates/polars-core/src/chunked_array/temporal/conversion.rs @@ -0,0 +1,90 @@ +use arrow::temporal_conversions::*; +use chrono::*; + +use crate::prelude::*; + +pub(crate) const NS_IN_DAY: i64 = 86_400_000_000_000; +pub(crate) const US_IN_DAY: i64 = 86_400_000_000; +pub(crate) const MS_IN_DAY: i64 = 86_400_000; +pub(crate) const SECONDS_IN_DAY: i64 = 86_400; + +impl From<&AnyValue<'_>> for NaiveDateTime { + fn from(v: &AnyValue) -> Self { + match v { + #[cfg(feature = "dtype-date")] + AnyValue::Date(v) => date32_to_datetime(*v), + #[cfg(feature = "dtype-datetime")] + AnyValue::Datetime(v, tu, _) => match tu { + TimeUnit::Nanoseconds => timestamp_ns_to_datetime(*v), + TimeUnit::Microseconds => timestamp_us_to_datetime(*v), + TimeUnit::Milliseconds => timestamp_ms_to_datetime(*v), + }, + _ => panic!("can only convert date/datetime to NaiveDateTime"), + } + } +} + +impl From<&AnyValue<'_>> for NaiveTime { + fn from(v: &AnyValue) -> Self { + match v { + #[cfg(feature = "dtype-time")] + AnyValue::Time(v) => time64ns_to_time(*v), + _ => panic!("can only convert date/datetime to NaiveTime"), + } + } +} + +// Used by lazy for literal conversion +pub fn datetime_to_timestamp_ns(v: NaiveDateTime) -> i64 { + let us = v.and_utc().timestamp() * 1_000_000_000; //nanos + us + v.and_utc().timestamp_subsec_nanos() as i64 +} + +pub fn datetime_to_timestamp_ms(v: NaiveDateTime) -> i64 { + v.and_utc().timestamp_millis() +} + +pub fn datetime_to_timestamp_us(v: NaiveDateTime) -> i64 { + let us = v.and_utc().timestamp() * 1_000_000; + us + v.and_utc().timestamp_subsec_micros() as i64 +} + +pub(crate) fn naive_datetime_to_date(v: NaiveDateTime) -> i32 { + (datetime_to_timestamp_ms(v) / (MILLISECONDS * SECONDS_IN_DAY)) as i32 +} + +pub fn get_strftime_format(fmt: &str, dtype: &DataType) -> PolarsResult { + if fmt == "polars" && !matches!(dtype, DataType::Duration(_)) { + polars_bail!(InvalidOperation: "'polars' is not a valid `to_string` format for {} dtype expressions", dtype); + } else { + let format_string = if fmt != "iso" && fmt != "iso:strict" { + fmt.to_string() + } else { + let sep = if fmt == "iso" { " " } else { "T" }; + #[allow(unreachable_code)] + match dtype { + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => match (tu, tz.is_some()) { + (TimeUnit::Milliseconds, true) => format!("%F{}%T%.3f%:z", sep), + (TimeUnit::Milliseconds, false) => format!("%F{}%T%.3f", sep), + (TimeUnit::Microseconds, true) => format!("%F{}%T%.6f%:z", sep), + (TimeUnit::Microseconds, false) => format!("%F{}%T%.6f", sep), + (TimeUnit::Nanoseconds, true) => format!("%F{}%T%.9f%:z", sep), + (TimeUnit::Nanoseconds, false) => format!("%F{}%T%.9f", sep), + }, + #[cfg(feature = "dtype-date")] + DataType::Date => "%F".to_string(), + #[cfg(feature = "dtype-time")] + DataType::Time => "%T%.f".to_string(), + _ => { + let err = format!( + "invalid call to `get_strftime_format`; fmt={:?}, dtype={}", + fmt, dtype + ); + unimplemented!("{}", err) + }, + } + }; + Ok(format_string) + } +} diff --git a/crates/polars-core/src/chunked_array/temporal/date.rs b/crates/polars-core/src/chunked_array/temporal/date.rs new file mode 100644 index 000000000000..796bc33ee690 --- /dev/null +++ b/crates/polars-core/src/chunked_array/temporal/date.rs @@ -0,0 +1,65 @@ +use std::fmt::Write; + +use arrow::temporal_conversions::date32_to_date; + +use super::*; +use crate::prelude::*; + +pub(crate) fn naive_date_to_date(nd: NaiveDate) -> i32 { + let nt = NaiveTime::from_hms_opt(0, 0, 0).unwrap(); + let ndt = NaiveDateTime::new(nd, nt); + naive_datetime_to_date(ndt) +} + +impl DateChunked { + pub fn as_date_iter(&self) -> impl TrustedLen> + '_ { + // SAFETY: we know the iterators len + unsafe { + self.downcast_iter() + .flat_map(|iter| { + iter.into_iter() + .map(|opt_v| opt_v.copied().map(date32_to_date)) + }) + .trust_my_length(self.len()) + } + } + + /// Construct a new [`DateChunked`] from an iterator over [`NaiveDate`]. + pub fn from_naive_date>(name: PlSmallStr, v: I) -> Self { + let unit = v.into_iter().map(naive_date_to_date).collect::>(); + Int32Chunked::from_vec(name, unit).into() + } + + /// Convert from Date into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + pub fn to_string(&self, format: &str) -> PolarsResult { + let format = if format == "iso" || format == "iso:strict" { + "%F" + } else { + format + }; + let datefmt_f = |ndt: NaiveDate| ndt.format(format); + self.try_apply_into_string_amortized(|val, buf| { + let ndt = date32_to_date(val); + write!(buf, "{}", datefmt_f(ndt)) + }) + .map_err(|_| polars_err!(ComputeError: "cannot format Date with format '{}'", format)) + } + + /// Convert from Date into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + /// + /// Alias for `to_string`. + pub fn strftime(&self, format: &str) -> PolarsResult { + self.to_string(format) + } + + /// Construct a new [`DateChunked`] from an iterator over optional [`NaiveDate`]. + pub fn from_naive_date_options>>( + name: PlSmallStr, + v: I, + ) -> Self { + let unit = v.into_iter().map(|opt| opt.map(naive_date_to_date)); + Int32Chunked::from_iter_options(name, unit).into() + } +} diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs new file mode 100644 index 000000000000..ce5732885d9c --- /dev/null +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -0,0 +1,221 @@ +use std::fmt::Write; + +use arrow::temporal_conversions::{ + timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, +}; +#[cfg(feature = "timezones")] +use chrono::TimeZone as TimeZoneTrait; + +use super::*; +use crate::prelude::DataType::Datetime; +use crate::prelude::*; + +impl DatetimeChunked { + pub fn as_datetime_iter(&self) -> impl TrustedLen> + '_ { + let func = match self.time_unit() { + TimeUnit::Nanoseconds => timestamp_ns_to_datetime, + TimeUnit::Microseconds => timestamp_us_to_datetime, + TimeUnit::Milliseconds => timestamp_ms_to_datetime, + }; + // we know the iterators len + unsafe { + self.downcast_iter() + .flat_map(move |iter| iter.into_iter().map(move |opt_v| opt_v.copied().map(func))) + .trust_my_length(self.len()) + } + } + + pub fn time_unit(&self) -> TimeUnit { + match self.2.as_ref().unwrap() { + DataType::Datetime(tu, _) => *tu, + _ => unreachable!(), + } + } + + pub fn time_zone(&self) -> &Option { + match self.2.as_ref().unwrap() { + DataType::Datetime(_, tz) => tz, + _ => unreachable!(), + } + } + + /// Convert from Datetime into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + pub fn to_string(&self, format: &str) -> PolarsResult { + let conversion_f = match self.time_unit() { + TimeUnit::Nanoseconds => timestamp_ns_to_datetime, + TimeUnit::Microseconds => timestamp_us_to_datetime, + TimeUnit::Milliseconds => timestamp_ms_to_datetime, + }; + let format = get_strftime_format(format, self.dtype())?; + let mut ca: StringChunked = match self.time_zone() { + #[cfg(feature = "timezones")] + Some(time_zone) => { + let parsed_time_zone = time_zone.parse::().expect("already validated"); + let datefmt_f = |ndt| parsed_time_zone.from_utc_datetime(&ndt).format(&format); + self.try_apply_into_string_amortized(|val, buf| { + let ndt = conversion_f(val); + write!(buf, "{}", datefmt_f(ndt)) + } + ).map_err( + |_| polars_err!(ComputeError: "cannot format timezone-aware Datetime with format '{}'", format), + )? + }, + _ => { + let datefmt_f = |ndt: NaiveDateTime| ndt.format(&format); + self.try_apply_into_string_amortized(|val, buf| { + let ndt = conversion_f(val); + write!(buf, "{}", datefmt_f(ndt)) + } + ).map_err( + |_| polars_err!(ComputeError: "cannot format timezone-naive Datetime with format '{}'", format), + )? + }, + }; + ca.rename(self.name().clone()); + Ok(ca) + } + + /// Convert from Datetime into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + /// + /// Alias for `to_string`. + pub fn strftime(&self, format: &str) -> PolarsResult { + self.to_string(format) + } + + /// Construct a new [`DatetimeChunked`] from an iterator over [`NaiveDateTime`]. + pub fn from_naive_datetime>( + name: PlSmallStr, + v: I, + tu: TimeUnit, + ) -> Self { + let func = match tu { + TimeUnit::Nanoseconds => datetime_to_timestamp_ns, + TimeUnit::Microseconds => datetime_to_timestamp_us, + TimeUnit::Milliseconds => datetime_to_timestamp_ms, + }; + let vals = v.into_iter().map(func).collect::>(); + Int64Chunked::from_vec(name, vals).into_datetime(tu, None) + } + + pub fn from_naive_datetime_options>>( + name: PlSmallStr, + v: I, + tu: TimeUnit, + ) -> Self { + let func = match tu { + TimeUnit::Nanoseconds => datetime_to_timestamp_ns, + TimeUnit::Microseconds => datetime_to_timestamp_us, + TimeUnit::Milliseconds => datetime_to_timestamp_ms, + }; + let vals = v.into_iter().map(|opt_nd| opt_nd.map(func)); + Int64Chunked::from_iter_options(name, vals).into_datetime(tu, None) + } + + /// Change the underlying [`TimeUnit`]. And update the data accordingly. + #[must_use] + pub fn cast_time_unit(&self, tu: TimeUnit) -> Self { + let current_unit = self.time_unit(); + let mut out = self.clone(); + out.set_time_unit(tu); + + use TimeUnit::*; + match (current_unit, tu) { + (Nanoseconds, Microseconds) => { + let ca = (&self.0).wrapping_floor_div_scalar(1_000); + out.0 = ca; + out + }, + (Nanoseconds, Milliseconds) => { + let ca = (&self.0).wrapping_floor_div_scalar(1_000_000); + out.0 = ca; + out + }, + (Microseconds, Nanoseconds) => { + let ca = &self.0 * 1_000; + out.0 = ca; + out + }, + (Microseconds, Milliseconds) => { + let ca = (&self.0).wrapping_floor_div_scalar(1_000); + out.0 = ca; + out + }, + (Milliseconds, Nanoseconds) => { + let ca = &self.0 * 1_000_000; + out.0 = ca; + out + }, + (Milliseconds, Microseconds) => { + let ca = &self.0 * 1_000; + out.0 = ca; + out + }, + (Nanoseconds, Nanoseconds) + | (Microseconds, Microseconds) + | (Milliseconds, Milliseconds) => out, + } + } + + /// Change the underlying [`TimeUnit`]. This does not modify the data. + pub fn set_time_unit(&mut self, time_unit: TimeUnit) { + self.2 = Some(Datetime(time_unit, self.time_zone().clone())) + } + + /// Change the underlying [`TimeZone`]. This does not modify the data. + /// This does not validate the time zone - it's up to the caller to verify that it's + /// already been validated. + #[cfg(feature = "timezones")] + pub fn set_time_zone(&mut self, time_zone: TimeZone) -> PolarsResult<()> { + self.2 = Some(Datetime(self.time_unit(), Some(time_zone))); + Ok(()) + } + + /// Change the underlying [`TimeUnit`] and [`TimeZone`]. This does not modify the data. + /// This does not validate the time zone - it's up to the caller to verify that it's + /// already been validated. + #[cfg(feature = "timezones")] + pub fn set_time_unit_and_time_zone( + &mut self, + time_unit: TimeUnit, + time_zone: TimeZone, + ) -> PolarsResult<()> { + self.2 = Some(Datetime(time_unit, Some(time_zone))); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use chrono::NaiveDateTime; + + use crate::prelude::*; + + #[test] + fn from_datetime() { + let datetimes: Vec<_> = [ + "1988-08-25 00:00:16", + "2015-09-05 23:56:04", + "2012-12-21 00:00:00", + ] + .iter() + .map(|s| NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").unwrap()) + .collect(); + + // NOTE: the values are checked and correct. + let dt = DatetimeChunked::from_naive_datetime( + PlSmallStr::from_static("name"), + datetimes.iter().copied(), + TimeUnit::Nanoseconds, + ); + assert_eq!( + [ + 588_470_416_000_000_000, + 1_441_497_364_000_000_000, + 1_356_048_000_000_000_000 + ], + dt.cont_slice().unwrap() + ); + } +} diff --git a/crates/polars-core/src/chunked_array/temporal/duration.rs b/crates/polars-core/src/chunked_array/temporal/duration.rs new file mode 100644 index 000000000000..a0a9c07205f8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/temporal/duration.rs @@ -0,0 +1,131 @@ +use chrono::Duration as ChronoDuration; + +use crate::fmt::{fmt_duration_string, iso_duration_string}; +use crate::prelude::DataType::Duration; +use crate::prelude::*; + +impl DurationChunked { + pub fn time_unit(&self) -> TimeUnit { + match self.2.as_ref().unwrap() { + DataType::Duration(tu) => *tu, + _ => unreachable!(), + } + } + + /// Change the underlying [`TimeUnit`]. And update the data accordingly. + #[must_use] + pub fn cast_time_unit(&self, tu: TimeUnit) -> Self { + let current_unit = self.time_unit(); + let mut out = self.clone(); + out.set_time_unit(tu); + + use TimeUnit::*; + match (current_unit, tu) { + (Nanoseconds, Microseconds) => { + let ca = (&self.0).wrapping_trunc_div_scalar(1_000); + out.0 = ca; + out + }, + (Nanoseconds, Milliseconds) => { + let ca = (&self.0).wrapping_trunc_div_scalar(1_000_000); + out.0 = ca; + out + }, + (Microseconds, Nanoseconds) => { + let ca = &self.0 * 1_000; + out.0 = ca; + out + }, + (Microseconds, Milliseconds) => { + let ca = (&self.0).wrapping_trunc_div_scalar(1_000); + out.0 = ca; + out + }, + (Milliseconds, Nanoseconds) => { + let ca = &self.0 * 1_000_000; + out.0 = ca; + out + }, + (Milliseconds, Microseconds) => { + let ca = &self.0 * 1_000; + out.0 = ca; + out + }, + (Nanoseconds, Nanoseconds) + | (Microseconds, Microseconds) + | (Milliseconds, Milliseconds) => out, + } + } + + /// Change the underlying [`TimeUnit`]. This does not modify the data. + pub fn set_time_unit(&mut self, tu: TimeUnit) { + self.2 = Some(Duration(tu)) + } + + /// Convert from [`Duration`] to String; note that `strftime` format + /// strings are not supported, only the specifiers 'iso' and 'polars'. + pub fn to_string(&self, format: &str) -> PolarsResult { + // the duration string functions below can reuse this string buffer + let mut s = String::with_capacity(32); + match format { + "iso" | "iso:strict" => { + let out: StringChunked = + self.0 + .apply_nonnull_values_generic(DataType::String, |v: i64| { + s.clear(); + iso_duration_string(&mut s, v, self.time_unit()); + s.clone() + }); + Ok(out) + }, + "polars" => { + let out: StringChunked = + self.0 + .apply_nonnull_values_generic(DataType::String, |v: i64| { + s.clear(); + fmt_duration_string(&mut s, v, self.time_unit()) + .map_err(|e| polars_err!(ComputeError: "{:?}", e)) + .expect("failed to format duration"); + s.clone() + }); + Ok(out) + }, + _ => { + polars_bail!( + InvalidOperation: "format {:?} not supported for Duration type (expected one of 'iso' or 'polars')", + format + ) + }, + } + } + + /// Construct a new [`DurationChunked`] from an iterator over [`ChronoDuration`]. + pub fn from_duration>( + name: PlSmallStr, + v: I, + tu: TimeUnit, + ) -> Self { + let func = match tu { + TimeUnit::Nanoseconds => |v: ChronoDuration| v.num_nanoseconds().unwrap(), + TimeUnit::Microseconds => |v: ChronoDuration| v.num_microseconds().unwrap(), + TimeUnit::Milliseconds => |v: ChronoDuration| v.num_milliseconds(), + }; + let vals = v.into_iter().map(func).collect::>(); + Int64Chunked::from_vec(name, vals).into_duration(tu) + } + + /// Construct a new [`DurationChunked`] from an iterator over optional [`ChronoDuration`]. + pub fn from_duration_options>>( + name: PlSmallStr, + v: I, + tu: TimeUnit, + ) -> Self { + let func = match tu { + TimeUnit::Nanoseconds => |v: ChronoDuration| v.num_nanoseconds().unwrap(), + TimeUnit::Microseconds => |v: ChronoDuration| v.num_microseconds().unwrap(), + TimeUnit::Milliseconds => |v: ChronoDuration| v.num_milliseconds(), + }; + let vals = v.into_iter().map(|opt| opt.map(func)); + Int64Chunked::from_iter_options(name, vals).into_duration(tu) + } +} diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs new file mode 100644 index 000000000000..196609c47f9a --- /dev/null +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -0,0 +1,87 @@ +//! Traits and utilities for temporal data. +pub mod conversion; +#[cfg(feature = "dtype-date")] +mod date; +#[cfg(feature = "dtype-datetime")] +mod datetime; +#[cfg(feature = "dtype-duration")] +mod duration; +#[cfg(feature = "dtype-time")] +mod time; + +#[cfg(feature = "dtype-date")] +use chrono::NaiveDate; +use chrono::NaiveDateTime; +#[cfg(any(feature = "dtype-time", feature = "dtype-date"))] +use chrono::NaiveTime; +#[cfg(feature = "timezones")] +use chrono_tz::Tz; +#[cfg(feature = "timezones")] +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "dtype-time")] +pub use time::time_to_time64ns; + +pub use self::conversion::*; +#[cfg(feature = "timezones")] +use crate::prelude::{PolarsResult, polars_bail}; + +#[cfg(feature = "timezones")] +static FIXED_OFFSET_PATTERN: &str = r#"(?x) + ^ + (?P[-+])? # optional sign + (?P0[0-9]|1[0-4]) # hour (between 0 and 14) + :? # optional separator + 00 # minute + $ + "#; +#[cfg(feature = "timezones")] +polars_utils::regex_cache::cached_regex! { + static FIXED_OFFSET_RE = FIXED_OFFSET_PATTERN; +} + +#[cfg(feature = "timezones")] +pub fn validate_time_zone(tz: &str) -> PolarsResult<()> { + match tz.parse::() { + Ok(_) => Ok(()), + Err(_) => { + polars_bail!(ComputeError: "unable to parse time zone: '{}'. Please check the Time Zone Database for a list of available time zones", tz) + }, + } +} + +#[cfg(feature = "timezones")] +pub fn parse_time_zone(tz: &str) -> PolarsResult { + match tz.parse::() { + Ok(tz) => Ok(tz), + Err(_) => { + polars_bail!(ComputeError: "unable to parse time zone: '{}'. Please check the Time Zone Database for a list of available time zones", tz) + }, + } +} + +/// Convert fixed offset to Etc/GMT one from time zone database +/// +/// E.g. +01:00 -> Etc/GMT-1 +/// +/// Note: the sign appears reversed, but is correct, see : +/// > In order to conform with the POSIX style, those zone names beginning with +/// > "Etc/GMT" have their sign reversed from the standard ISO 8601 convention. +/// > In the "Etc" area, zones west of GMT have a positive sign and those east +/// > have a negative sign in their name (e.g "Etc/GMT-14" is 14 hours ahead of GMT). +#[cfg(feature = "timezones")] +pub fn parse_fixed_offset(tz: &str) -> PolarsResult { + use polars_utils::format_pl_smallstr; + + if let Some(caps) = FIXED_OFFSET_RE.captures(tz) { + let sign = match caps.name("sign").map(|s| s.as_str()) { + Some("-") => "+", + _ => "-", + }; + let hour = caps.name("hour").unwrap().as_str().parse::().unwrap(); + let etc_tz = format_pl_smallstr!("Etc/GMT{}{}", sign, hour); + if etc_tz.parse::().is_ok() { + return Ok(etc_tz); + } + } + polars_bail!(ComputeError: "unable to parse time zone: '{}'. Please check the Time Zone Database for a list of available time zones", tz) +} diff --git a/crates/polars-core/src/chunked_array/temporal/time.rs b/crates/polars-core/src/chunked_array/temporal/time.rs new file mode 100644 index 000000000000..fe0ca0270c1d --- /dev/null +++ b/crates/polars-core/src/chunked_array/temporal/time.rs @@ -0,0 +1,89 @@ +use std::fmt::Write; + +use arrow::temporal_conversions::{NANOSECONDS, time64ns_to_time}; +use chrono::Timelike; + +use super::*; +use crate::prelude::*; + +const SECONDS_IN_MINUTE: i64 = 60; +const SECONDS_IN_HOUR: i64 = 3_600; + +pub fn time_to_time64ns(time: &NaiveTime) -> i64 { + (time.hour() as i64 * SECONDS_IN_HOUR + + time.minute() as i64 * SECONDS_IN_MINUTE + + time.second() as i64) + * NANOSECONDS + + time.nanosecond() as i64 +} + +impl TimeChunked { + /// Convert from Time into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + pub fn to_string(&self, format: &str) -> StringChunked { + let mut ca: StringChunked = self.apply_kernel_cast(&|arr| { + let mut buf = String::new(); + let format = if format == "iso" || format == "iso:strict" { + "%T%.9f" + } else { + format + }; + let mut mutarr = MutablePlString::with_capacity(arr.len()); + + for opt in arr.into_iter() { + match opt { + None => mutarr.push_null(), + Some(v) => { + buf.clear(); + let timefmt = time64ns_to_time(*v).format(format); + write!(buf, "{timefmt}").unwrap(); + mutarr.push_value(&buf) + }, + } + } + + mutarr.freeze().boxed() + }); + + ca.rename(self.name().clone()); + ca + } + + /// Convert from Time into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + /// + /// Alias for `to_string`. + pub fn strftime(&self, format: &str) -> StringChunked { + self.to_string(format) + } + + pub fn as_time_iter(&self) -> impl TrustedLen> + '_ { + // we know the iterators len + unsafe { + self.downcast_iter() + .flat_map(|iter| { + iter.into_iter() + .map(|opt_v| opt_v.copied().map(time64ns_to_time)) + }) + .trust_my_length(self.len()) + } + } + + /// Construct a new [`TimeChunked`] from an iterator over [`NaiveTime`]. + pub fn from_naive_time>(name: PlSmallStr, v: I) -> Self { + let vals = v + .into_iter() + .map(|nt| time_to_time64ns(&nt)) + .collect::>(); + Int64Chunked::from_vec(name, vals).into_time() + } + + /// Construct a new [`TimeChunked`] from an iterator over optional [`NaiveTime`]. + pub fn from_naive_time_options>>( + name: PlSmallStr, + v: I, + ) -> Self { + let vals = v.into_iter().map(|opt| opt.map(|nt| time_to_time64ns(&nt))); + Int64Chunked::from_iter_options(name, vals).into_time() + } +} diff --git a/crates/polars-core/src/chunked_array/to_vec.rs b/crates/polars-core/src/chunked_array/to_vec.rs new file mode 100644 index 000000000000..1a8ed2798e65 --- /dev/null +++ b/crates/polars-core/src/chunked_array/to_vec.rs @@ -0,0 +1,28 @@ +use either::Either; + +use crate::prelude::*; + +impl ChunkedArray { + /// Convert to a [`Vec`] of [`Option`]. + pub fn to_vec(&self) -> Vec> { + let mut buf = Vec::with_capacity(self.len()); + for arr in self.downcast_iter() { + buf.extend(arr.into_iter().map(|v| v.copied())) + } + buf + } + + /// Convert to a [`Vec`] but don't return [`Option`] if there are no null values + pub fn to_vec_null_aware(&self) -> Either, Vec>> { + if self.null_count() == 0 { + let mut buf = Vec::with_capacity(self.len()); + + for arr in self.downcast_iter() { + buf.extend_from_slice(arr.values()) + } + Either::Left(buf) + } else { + Either::Right(self.to_vec()) + } + } +} diff --git a/crates/polars-core/src/chunked_array/trusted_len.rs b/crates/polars-core/src/chunked_array/trusted_len.rs new file mode 100644 index 000000000000..2304d74933b0 --- /dev/null +++ b/crates/polars-core/src/chunked_array/trusted_len.rs @@ -0,0 +1,213 @@ +use std::borrow::Borrow; + +use arrow::legacy::trusted_len::{FromIteratorReversed, TrustedLenPush}; + +use crate::chunked_array::from_iterator::PolarsAsRef; +use crate::prelude::*; +use crate::utils::{FromTrustedLenIterator, NoNull}; + +impl FromTrustedLenIterator> for ChunkedArray +where + T: PolarsNumericType, +{ + fn from_iter_trusted_length>>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + // SAFETY: iter is TrustedLen. + let iter = iter.into_iter(); + let arr = unsafe { + PrimitiveArray::from_trusted_len_iter_unchecked(iter) + .to(T::get_dtype().to_arrow(CompatLevel::newest())) + }; + arr.into() + } +} + +// NoNull is only a wrapper needed for specialization. +impl FromTrustedLenIterator for NoNull> +where + T: PolarsNumericType, +{ + // We use Vec because it is way faster than Arrows builder. We can do this + // because we know we don't have null values. + fn from_iter_trusted_length>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + // SAFETY: iter is TrustedLen. + let iter = iter.into_iter(); + let values = unsafe { Vec::from_trusted_len_iter_unchecked(iter) }.into(); + let arr = PrimitiveArray::new(T::get_dtype().to_arrow(CompatLevel::newest()), values, None); + NoNull::new(arr.into()) + } +} + +impl FromIteratorReversed> for ChunkedArray +where + T: PolarsNumericType, +{ + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let arr: PrimitiveArray = iter.collect_reversed(); + arr.into() + } +} + +impl FromIteratorReversed for NoNull> +where + T: PolarsNumericType, +{ + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let arr: PrimitiveArray = iter.collect_reversed(); + NoNull::new(arr.into()) + } +} + +impl FromIteratorReversed> for BooleanChunked { + fn from_trusted_len_iter_rev>>(iter: I) -> Self { + let arr: BooleanArray = iter.collect_reversed(); + arr.into() + } +} + +impl FromIteratorReversed for NoNull { + fn from_trusted_len_iter_rev>(iter: I) -> Self { + let arr: BooleanArray = iter.collect_reversed(); + NoNull::new(arr.into()) + } +} + +impl FromTrustedLenIterator for ListChunked +where + Ptr: Borrow, +{ + fn from_iter_trusted_length>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +impl FromTrustedLenIterator> for ListChunked { + fn from_iter_trusted_length>>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +impl FromTrustedLenIterator> for ChunkedArray { + fn from_iter_trusted_length>>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + let arr: BooleanArray = iter.collect_trusted(); + arr.into() + } +} + +impl FromTrustedLenIterator for BooleanChunked { + fn from_iter_trusted_length>(iter: I) -> Self + where + I::IntoIter: TrustedLen, + { + let iter = iter.into_iter(); + let arr: BooleanArray = iter.collect_trusted(); + arr.into() + } +} + +impl FromTrustedLenIterator for NoNull { + fn from_iter_trusted_length>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} +impl FromTrustedLenIterator for StringChunked +where + Ptr: PolarsAsRef, +{ + fn from_iter_trusted_length>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +impl FromTrustedLenIterator> for StringChunked +where + Ptr: AsRef, +{ + fn from_iter_trusted_length>>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +impl FromTrustedLenIterator for BinaryChunked +where + Ptr: PolarsAsRef<[u8]>, +{ + fn from_iter_trusted_length>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +impl FromTrustedLenIterator> for BinaryChunked +where + Ptr: AsRef<[u8]>, +{ + fn from_iter_trusted_length>>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +impl FromTrustedLenIterator for BinaryOffsetChunked +where + Ptr: PolarsAsRef<[u8]>, +{ + fn from_iter_trusted_length>(iter: I) -> Self { + let arr = BinaryArray::from_iter_values(iter.into_iter()); + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +impl FromTrustedLenIterator> for BinaryOffsetChunked +where + Ptr: AsRef<[u8]>, +{ + fn from_iter_trusted_length>>(iter: I) -> Self { + let iter = iter.into_iter(); + let arr = BinaryArray::from_iter(iter); + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) + } +} + +#[cfg(feature = "object")] +impl FromTrustedLenIterator> for ObjectChunked { + fn from_iter_trusted_length>>(iter: I) -> Self { + let iter = iter.into_iter(); + iter.collect() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_reverse_collect() { + let ca: NoNull = (0..5).collect_reversed(); + let arr = ca.downcast_iter().next().unwrap(); + let s = arr.values().as_slice(); + assert_eq!(s, &[4, 3, 2, 1, 0]); + + let ca: Int32Chunked = (0..5) + .map(|val| match val % 2 == 0 { + true => Some(val), + false => None, + }) + .collect_reversed(); + assert_eq!(Vec::from(&ca), &[Some(4), None, Some(2), None, Some(0)]); + } +} diff --git a/crates/polars-core/src/config.rs b/crates/polars-core/src/config.rs new file mode 100644 index 000000000000..a1a0defc664d --- /dev/null +++ b/crates/polars-core/src/config.rs @@ -0,0 +1,74 @@ +use crate::POOL; + +// Formatting environment variables (typically referenced/set from the python-side Config object) +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_MAX_COLS: &str = "POLARS_FMT_MAX_COLS"; +pub(crate) const FMT_MAX_ROWS: &str = "POLARS_FMT_MAX_ROWS"; +pub(crate) const FMT_STR_LEN: &str = "POLARS_FMT_STR_LEN"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_CELL_ALIGNMENT: &str = "POLARS_FMT_TABLE_CELL_ALIGNMENT"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_CELL_NUMERIC_ALIGNMENT: &str = "POLARS_FMT_TABLE_CELL_NUMERIC_ALIGNMENT"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_DATAFRAME_SHAPE_BELOW: &str = "POLARS_FMT_TABLE_DATAFRAME_SHAPE_BELOW"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_FORMATTING: &str = "POLARS_FMT_TABLE_FORMATTING"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_HIDE_COLUMN_DATA_TYPES: &str = "POLARS_FMT_TABLE_HIDE_COLUMN_DATA_TYPES"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_HIDE_COLUMN_NAMES: &str = "POLARS_FMT_TABLE_HIDE_COLUMN_NAMES"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_HIDE_COLUMN_SEPARATOR: &str = "POLARS_FMT_TABLE_HIDE_COLUMN_SEPARATOR"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_HIDE_DATAFRAME_SHAPE_INFORMATION: &str = + "POLARS_FMT_TABLE_HIDE_DATAFRAME_SHAPE_INFORMATION"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_INLINE_COLUMN_DATA_TYPE: &str = + "POLARS_FMT_TABLE_INLINE_COLUMN_DATA_TYPE"; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +pub(crate) const FMT_TABLE_ROUNDED_CORNERS: &str = "POLARS_FMT_TABLE_ROUNDED_CORNERS"; +pub(crate) const FMT_TABLE_CELL_LIST_LEN: &str = "POLARS_FMT_TABLE_CELL_LIST_LEN"; + +pub fn verbose() -> bool { + std::env::var("POLARS_VERBOSE").as_deref().unwrap_or("") == "1" +} + +pub fn get_engine_affinity() -> String { + std::env::var("POLARS_ENGINE_AFFINITY").unwrap_or_else(|_| "auto".to_string()) +} + +/// Prints a log message if sensitive verbose logging has been enabled. +pub fn verbose_print_sensitive String>(create_log_message: F) { + fn do_log(create_log_message: &dyn Fn() -> String) { + if std::env::var("POLARS_VERBOSE_SENSITIVE") + .as_deref() + .unwrap_or("") + == "1" + { + // Force the message to be a single line. + let msg = create_log_message().replace('\n', ""); + eprintln!("[SENSITIVE]: {}", msg) + } + } + + do_log(&create_log_message) +} + +pub fn get_file_prefetch_size() -> usize { + std::env::var("POLARS_PREFETCH_SIZE") + .map(|s| s.parse::().expect("integer")) + .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads() * 2, 16)) +} + +pub fn get_rg_prefetch_size() -> usize { + std::env::var("POLARS_ROW_GROUP_PREFETCH_SIZE") + .map(|s| s.parse::().expect("integer")) + // Set it to something big, but not unlimited. + .unwrap_or_else(|_| std::cmp::max(get_file_prefetch_size(), 128)) +} + +pub fn force_async() -> bool { + std::env::var("POLARS_FORCE_ASYNC") + .map(|value| value == "1") + .unwrap_or_default() +} diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs new file mode 100644 index 000000000000..ece27fafa882 --- /dev/null +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -0,0 +1,236 @@ +//! Having `Object<&;static> in [`DataType`] make serde tag the `Deserialize` trait bound 'static +//! even though we skip serializing `Object`. +//! +//! We could use [serde_1712](https://github.com/serde-rs/serde/issues/1712), but that gave problems caused by +//! [rust_96956](https://github.com/rust-lang/rust/issues/96956), so we make a dummy type without static + +#[cfg(feature = "dtype-categorical")] +use serde::de::SeqAccess; +use serde::{Deserialize, Serialize}; + +use super::*; + +impl<'a> Deserialize<'a> for DataType { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + Ok(SerializableDataType::deserialize(deserializer)?.into()) + } +} + +impl Serialize for DataType { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let dt: SerializableDataType = self.into(); + dt.serialize(serializer) + } +} + +#[cfg(feature = "dtype-categorical")] +struct Wrap(T); + +#[cfg(feature = "dtype-categorical")] +impl serde::Serialize for Wrap { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.collect_seq(self.0.values_iter()) + } +} + +#[cfg(feature = "dtype-categorical")] +impl<'de> serde::Deserialize<'de> for Wrap { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Utf8Visitor; + + impl<'de> Visitor<'de> for Utf8Visitor { + type Value = Wrap; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("Utf8Visitor string sequence.") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut utf8array = MutablePlString::with_capacity(seq.size_hint().unwrap_or(10)); + while let Some(key) = seq.next_element()? { + let key: Option = key; + utf8array.push(key) + } + Ok(Wrap(utf8array.into())) + } + } + + deserializer.deserialize_seq(Utf8Visitor) + } +} + +#[derive(Serialize, Deserialize)] +enum SerializableDataType { + Boolean, + UInt8, + UInt16, + UInt32, + UInt64, + Int8, + Int16, + Int32, + Int64, + Int128, + Float32, + Float64, + String, + Binary, + /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in days (32 bits). + Date, + /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in the given ms/us/ns TimeUnit (64 bits). + Datetime(TimeUnit, Option), + // 64-bit integer representing difference between times in milli|micro|nano seconds + Duration(TimeUnit), + /// A 64-bit time representing elapsed time since midnight in the given TimeUnit. + Time, + List(Box), + #[cfg(feature = "dtype-array")] + Array(Box, usize), + Null, + #[cfg(feature = "dtype-struct")] + Struct(Vec), + // some logical types we cannot know statically, e.g. Datetime + Unknown(UnknownKind), + #[cfg(feature = "dtype-categorical")] + Categorical(Option, CategoricalOrdering), + #[cfg(feature = "dtype-decimal")] + Decimal(Option, Option), + #[cfg(feature = "dtype-categorical")] + Enum(Option, CategoricalOrdering), + #[cfg(feature = "object")] + Object(String), +} + +impl From<&DataType> for SerializableDataType { + fn from(dt: &DataType) -> Self { + use DataType::*; + match dt { + Boolean => Self::Boolean, + UInt8 => Self::UInt8, + UInt16 => Self::UInt16, + UInt32 => Self::UInt32, + UInt64 => Self::UInt64, + Int8 => Self::Int8, + Int16 => Self::Int16, + Int32 => Self::Int32, + Int64 => Self::Int64, + Int128 => Self::Int128, + Float32 => Self::Float32, + Float64 => Self::Float64, + String => Self::String, + Binary => Self::Binary, + Date => Self::Date, + Datetime(tu, tz) => Self::Datetime(*tu, tz.clone()), + Duration(tu) => Self::Duration(*tu), + Time => Self::Time, + List(dt) => Self::List(Box::new(dt.as_ref().into())), + #[cfg(feature = "dtype-array")] + Array(dt, width) => Self::Array(Box::new(dt.as_ref().into()), *width), + Null => Self::Null, + Unknown(kind) => Self::Unknown(*kind), + #[cfg(feature = "dtype-struct")] + Struct(flds) => Self::Struct(flds.clone()), + #[cfg(feature = "dtype-categorical")] + Categorical(Some(rev_map), ordering) => Self::Categorical( + Some( + StringChunked::with_chunk(PlSmallStr::EMPTY, rev_map.get_categories().clone()) + .into_series(), + ), + *ordering, + ), + #[cfg(feature = "dtype-categorical")] + Categorical(None, ordering) => Self::Categorical(None, *ordering), + #[cfg(feature = "dtype-categorical")] + Enum(Some(rev_map), ordering) => Self::Enum( + Some( + StringChunked::with_chunk(PlSmallStr::EMPTY, rev_map.get_categories().clone()) + .into_series(), + ), + *ordering, + ), + #[cfg(feature = "dtype-categorical")] + Enum(None, ordering) => Self::Enum(None, *ordering), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => Self::Decimal(*precision, *scale), + #[cfg(feature = "object")] + Object(name) => Self::Object(name.to_string()), + dt => panic!("{dt:?} not supported"), + } + } +} +impl From for DataType { + fn from(dt: SerializableDataType) -> Self { + use SerializableDataType::*; + match dt { + Boolean => Self::Boolean, + UInt8 => Self::UInt8, + UInt16 => Self::UInt16, + UInt32 => Self::UInt32, + UInt64 => Self::UInt64, + Int8 => Self::Int8, + Int16 => Self::Int16, + Int32 => Self::Int32, + Int64 => Self::Int64, + Int128 => Self::Int128, + Float32 => Self::Float32, + Float64 => Self::Float64, + String => Self::String, + Binary => Self::Binary, + Date => Self::Date, + Datetime(tu, tz) => Self::Datetime(tu, tz), + Duration(tu) => Self::Duration(tu), + Time => Self::Time, + List(dt) => Self::List(Box::new((*dt).into())), + #[cfg(feature = "dtype-array")] + Array(dt, width) => Self::Array(Box::new((*dt).into()), width), + Null => Self::Null, + Unknown(kind) => Self::Unknown(kind), + #[cfg(feature = "dtype-struct")] + Struct(flds) => Self::Struct(flds), + #[cfg(feature = "dtype-categorical")] + Categorical(Some(categories), ordering) => Self::Categorical( + Some(Arc::new(RevMapping::build_local( + categories.0.rechunk().chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .clone(), + ))), + ordering, + ), + #[cfg(feature = "dtype-categorical")] + Categorical(None, ordering) => Self::Categorical(None, ordering), + #[cfg(feature = "dtype-categorical")] + Enum(Some(categories), _) => create_enum_dtype( + categories.rechunk().chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .clone(), + ), + #[cfg(feature = "dtype-categorical")] + Enum(None, ordering) => Self::Enum(None, ordering), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => Self::Decimal(precision, scale), + #[cfg(feature = "object")] + Object(_) => Self::Object("unknown"), + } + } +} diff --git a/crates/polars-core/src/datatypes/aliases.rs b/crates/polars-core/src/datatypes/aliases.rs new file mode 100644 index 000000000000..4787b7fcd229 --- /dev/null +++ b/crates/polars-core/src/datatypes/aliases.rs @@ -0,0 +1,47 @@ +pub use arrow::legacy::index::IdxArr; +pub use polars_utils::aliases::{ + InitHashMaps, PlHashMap, PlHashSet, PlIndexMap, PlIndexSet, PlRandomState, +}; + +use super::*; +use crate::hashing::IdBuildHasher; + +#[cfg(not(feature = "bigidx"))] +pub type IdxCa = UInt32Chunked; +#[cfg(feature = "bigidx")] +pub type IdxCa = UInt64Chunked; + +#[cfg(not(feature = "bigidx"))] +pub const IDX_DTYPE: DataType = DataType::UInt32; +#[cfg(feature = "bigidx")] +pub const IDX_DTYPE: DataType = DataType::UInt64; + +#[cfg(not(feature = "bigidx"))] +pub type IdxType = UInt32Type; +#[cfg(feature = "bigidx")] +pub type IdxType = UInt64Type; + +pub use polars_utils::pl_str::PlSmallStr; + +/// This hashmap uses an IdHasher +pub type PlIdHashMap = hashbrown::HashMap; + +pub trait InitHashMaps2 { + type HashMap; + + fn new() -> Self::HashMap; + + fn with_capacity(capacity: usize) -> Self::HashMap; +} + +impl InitHashMaps2 for PlIdHashMap { + type HashMap = Self; + + fn new() -> Self::HashMap { + Self::with_capacity_and_hasher(0, Default::default()) + } + + fn with_capacity(capacity: usize) -> Self { + Self::with_capacity_and_hasher(capacity, Default::default()) + } +} diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs new file mode 100644 index 000000000000..58b1ba15f7a4 --- /dev/null +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -0,0 +1,1816 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::borrow::Cow; + +use arrow::types::PrimitiveType; +use polars_compute::cast::SerPrimitive; +use polars_error::feature_gated; +#[cfg(feature = "dtype-categorical")] +use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::ToTotalOrd; + +use super::*; +use crate::CHEAP_SERIES_HASH_LIMIT; +#[cfg(feature = "dtype-struct")] +use crate::prelude::any_value::arr_to_any_value; + +#[cfg(feature = "object")] +#[derive(Debug)] +pub struct OwnedObject(pub Box); + +#[cfg(feature = "object")] +impl Clone for OwnedObject { + fn clone(&self) -> Self { + Self(self.0.to_boxed()) + } +} + +#[derive(Debug, Clone, Default)] +pub enum AnyValue<'a> { + #[default] + Null, + /// A binary true or false. + Boolean(bool), + /// A UTF8 encoded string type. + String(&'a str), + /// An unsigned 8-bit integer number. + UInt8(u8), + /// An unsigned 16-bit integer number. + UInt16(u16), + /// An unsigned 32-bit integer number. + UInt32(u32), + /// An unsigned 64-bit integer number. + UInt64(u64), + /// An 8-bit integer number. + Int8(i8), + /// A 16-bit integer number. + Int16(i16), + /// A 32-bit integer number. + Int32(i32), + /// A 64-bit integer number. + Int64(i64), + /// A 128-bit integer number. + Int128(i128), + /// A 32-bit floating point number. + Float32(f32), + /// A 64-bit floating point number. + Float64(f64), + /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in days (32 bits). + #[cfg(feature = "dtype-date")] + Date(i32), + /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in nanoseconds (64 bits). + #[cfg(feature = "dtype-datetime")] + Datetime(i64, TimeUnit, Option<&'a TimeZone>), + /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in nanoseconds (64 bits). + #[cfg(feature = "dtype-datetime")] + DatetimeOwned(i64, TimeUnit, Option>), + /// A 64-bit integer representing difference between date-times in [`TimeUnit`] + #[cfg(feature = "dtype-duration")] + Duration(i64, TimeUnit), + /// A 64-bit time representing the elapsed time since midnight in nanoseconds + #[cfg(feature = "dtype-time")] + Time(i64), + // If syncptr is_null the data is in the rev-map + // otherwise it is in the array pointer + #[cfg(feature = "dtype-categorical")] + Categorical(u32, &'a RevMapping, SyncPtr), + // If syncptr is_null the data is in the rev-map + // otherwise it is in the array pointer + #[cfg(feature = "dtype-categorical")] + CategoricalOwned(u32, Arc, SyncPtr), + #[cfg(feature = "dtype-categorical")] + Enum(u32, &'a RevMapping, SyncPtr), + #[cfg(feature = "dtype-categorical")] + EnumOwned(u32, Arc, SyncPtr), + /// Nested type, contains arrays that are filled with one of the datatypes. + List(Series), + #[cfg(feature = "dtype-array")] + Array(Series, usize), + /// Can be used to fmt and implements Any, so can be downcasted to the proper value type. + #[cfg(feature = "object")] + Object(&'a dyn PolarsObjectSafe), + #[cfg(feature = "object")] + ObjectOwned(OwnedObject), + // 3 pointers and thus not larger than string/vec + // - idx in the `&StructArray` + // - The array itself + // - The fields + #[cfg(feature = "dtype-struct")] + Struct(usize, &'a StructArray, &'a [Field]), + #[cfg(feature = "dtype-struct")] + StructOwned(Box<(Vec>, Vec)>), + /// An UTF8 encoded string type. + StringOwned(PlSmallStr), + Binary(&'a [u8]), + BinaryOwned(Vec), + /// A 128-bit fixed point decimal number with a scale. + #[cfg(feature = "dtype-decimal")] + Decimal(i128, usize), +} + +#[cfg(feature = "serde")] +impl Serialize for AnyValue<'_> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let name = "AnyValue"; + match self { + AnyValue::Null => serializer.serialize_unit_variant(name, 0, "Null"), + + AnyValue::Int8(v) => serializer.serialize_newtype_variant(name, 1, "Int8", v), + AnyValue::Int16(v) => serializer.serialize_newtype_variant(name, 2, "Int16", v), + AnyValue::Int32(v) => serializer.serialize_newtype_variant(name, 3, "Int32", v), + AnyValue::Int64(v) => serializer.serialize_newtype_variant(name, 4, "Int64", v), + AnyValue::Int128(v) => serializer.serialize_newtype_variant(name, 5, "Int128", v), + AnyValue::UInt8(v) => serializer.serialize_newtype_variant(name, 6, "UInt8", v), + AnyValue::UInt16(v) => serializer.serialize_newtype_variant(name, 7, "UInt16", v), + AnyValue::UInt32(v) => serializer.serialize_newtype_variant(name, 8, "UInt32", v), + AnyValue::UInt64(v) => serializer.serialize_newtype_variant(name, 9, "UInt64", v), + AnyValue::Float32(v) => serializer.serialize_newtype_variant(name, 10, "Float32", v), + AnyValue::Float64(v) => serializer.serialize_newtype_variant(name, 11, "Float64", v), + AnyValue::List(v) => serializer.serialize_newtype_variant(name, 12, "List", v), + AnyValue::Boolean(v) => serializer.serialize_newtype_variant(name, 13, "Bool", v), + + // both variants same number + AnyValue::String(v) => serializer.serialize_newtype_variant(name, 14, "StringOwned", v), + AnyValue::StringOwned(v) => { + serializer.serialize_newtype_variant(name, 14, "StringOwned", v.as_str()) + }, + + // both variants same number + AnyValue::Binary(v) => serializer.serialize_newtype_variant(name, 15, "BinaryOwned", v), + AnyValue::BinaryOwned(v) => { + serializer.serialize_newtype_variant(name, 15, "BinaryOwned", v) + }, + + #[cfg(feature = "dtype-date")] + AnyValue::Date(v) => serializer.serialize_newtype_variant(name, 16, "Date", v), + // both variants same number + #[cfg(feature = "dtype-datetime")] + AnyValue::Datetime(v, tu, tz) => serializer.serialize_newtype_variant( + name, + 17, + "DatetimeOwned", + &(*v, *tu, tz.map(|v| v.as_str())), + ), + #[cfg(feature = "dtype-datetime")] + AnyValue::DatetimeOwned(v, tu, tz) => serializer.serialize_newtype_variant( + name, + 17, + "DatetimeOwned", + &(*v, *tu, tz.as_deref().map(|v| v.as_str())), + ), + #[cfg(feature = "dtype-duration")] + AnyValue::Duration(v, tu) => { + serializer.serialize_newtype_variant(name, 18, "Duration", &(*v, *tu)) + }, + #[cfg(feature = "dtype-time")] + AnyValue::Time(v) => serializer.serialize_newtype_variant(name, 19, "Time", v), + + // Not 100% sure how to deal with these. + #[cfg(feature = "dtype-categorical")] + AnyValue::Categorical(..) | AnyValue::CategoricalOwned(..) => Err( + serde::ser::Error::custom("Cannot serialize categorical value."), + ), + #[cfg(feature = "dtype-categorical")] + AnyValue::Enum(..) | AnyValue::EnumOwned(..) => { + Err(serde::ser::Error::custom("Cannot serialize enum value.")) + }, + + #[cfg(feature = "dtype-array")] + AnyValue::Array(v, width) => { + serializer.serialize_newtype_variant(name, 22, "Array", &(v, *width)) + }, + #[cfg(feature = "object")] + AnyValue::Object(_) | AnyValue::ObjectOwned(_) => { + Err(serde::ser::Error::custom("Cannot serialize object value.")) + }, + #[cfg(feature = "dtype-struct")] + AnyValue::Struct(_, _, _) | AnyValue::StructOwned(_) => { + Err(serde::ser::Error::custom("Cannot serialize struct value.")) + }, + #[cfg(feature = "dtype-decimal")] + AnyValue::Decimal(v, scale) => { + serializer.serialize_newtype_variant(name, 25, "Decimal", &(*v, *scale)) + }, + } + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for AnyValue<'static> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + macro_rules! define_av_field { + ($($variant:ident,)+) => { + #[derive(Deserialize, Serialize)] + enum AvField { + $($variant,)+ + } + const VARIANTS: &'static [&'static str] = &[ + $(stringify!($variant),)+ + ]; + }; + } + define_av_field! { + Null, + Int8, + Int16, + Int32, + Int64, + Int128, + UInt8, + UInt16, + UInt32, + UInt64, + Float32, + Float64, + List, + Bool, + StringOwned, + BinaryOwned, + Date, + DatetimeOwned, + Duration, + Time, + CategoricalOwned, + EnumOwned, + Array, + Object, + Struct, + Decimal, + }; + + struct OuterVisitor; + + impl<'b> Visitor<'b> for OuterVisitor { + type Value = AnyValue<'static>; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + write!(formatter, "enum AnyValue") + } + + fn visit_enum(self, data: A) -> std::result::Result + where + A: EnumAccess<'b>, + { + let out = match data.variant()? { + (AvField::Null, _variant) => AnyValue::Null, + (AvField::Int8, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Int8(value) + }, + (AvField::Int16, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Int16(value) + }, + (AvField::Int32, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Int32(value) + }, + (AvField::Int64, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Int64(value) + }, + (AvField::Int128, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Int128(value) + }, + (AvField::UInt8, variant) => { + let value = variant.newtype_variant()?; + AnyValue::UInt8(value) + }, + (AvField::UInt16, variant) => { + let value = variant.newtype_variant()?; + AnyValue::UInt16(value) + }, + (AvField::UInt32, variant) => { + let value = variant.newtype_variant()?; + AnyValue::UInt32(value) + }, + (AvField::UInt64, variant) => { + let value = variant.newtype_variant()?; + AnyValue::UInt64(value) + }, + (AvField::Float32, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Float32(value) + }, + (AvField::Float64, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Float64(value) + }, + (AvField::Bool, variant) => { + let value = variant.newtype_variant()?; + AnyValue::Boolean(value) + }, + (AvField::List, variant) => { + let value = variant.newtype_variant()?; + AnyValue::List(value) + }, + (AvField::StringOwned, variant) => { + let value: PlSmallStr = variant.newtype_variant()?; + AnyValue::StringOwned(value) + }, + (AvField::BinaryOwned, variant) => { + let value = variant.newtype_variant()?; + AnyValue::BinaryOwned(value) + }, + (AvField::Date, variant) => feature_gated!("dtype-date", { + let value = variant.newtype_variant()?; + AnyValue::Date(value) + }), + (AvField::DatetimeOwned, variant) => feature_gated!("dtype-datetime", { + let (value, time_unit, time_zone) = variant.newtype_variant()?; + AnyValue::DatetimeOwned(value, time_unit, time_zone) + }), + (AvField::Duration, variant) => feature_gated!("dtype-duration", { + let (value, time_unit) = variant.newtype_variant()?; + AnyValue::Duration(value, time_unit) + }), + (AvField::Time, variant) => feature_gated!("dtype-time", { + let value = variant.newtype_variant()?; + AnyValue::Time(value) + }), + (AvField::CategoricalOwned, _) => feature_gated!("dtype-categorical", { + return Err(serde::de::Error::custom( + "unable to deserialize categorical", + )); + }), + (AvField::EnumOwned, _) => feature_gated!("dtype-categorical", { + return Err(serde::de::Error::custom("unable to deserialize enum")); + }), + (AvField::Array, variant) => feature_gated!("dtype-array", { + let (s, width) = variant.newtype_variant()?; + AnyValue::Array(s, width) + }), + (AvField::Object, _) => feature_gated!("object", { + return Err(serde::de::Error::custom("unable to deserialize object")); + }), + (AvField::Struct, _) => feature_gated!("dtype-struct", { + return Err(serde::de::Error::custom("unable to deserialize struct")); + }), + (AvField::Decimal, variant) => feature_gated!("dtype-decimal", { + let (v, scale) = variant.newtype_variant()?; + AnyValue::Decimal(v, scale) + }), + }; + Ok(out) + } + } + deserializer.deserialize_enum("AnyValue", VARIANTS, OuterVisitor) + } +} + +impl AnyValue<'static> { + pub fn zero_sum(dtype: &DataType) -> Self { + match dtype { + DataType::String => AnyValue::StringOwned(PlSmallStr::EMPTY), + DataType::Binary => AnyValue::BinaryOwned(Vec::new()), + DataType::Boolean => (0 as IdxSize).into(), + // SAFETY: numeric values are static, inform the compiler of this. + d if d.is_primitive_numeric() => unsafe { + std::mem::transmute::, AnyValue<'static>>( + AnyValue::UInt8(0).cast(dtype), + ) + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(unit) => AnyValue::Duration(0, *unit), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_p, s) => { + AnyValue::Decimal(0, s.expect("unknown scale during execution")) + }, + _ => AnyValue::Null, + } + } + + /// Can the [`AnyValue`] exist as having `dtype` as its `DataType`. + pub fn can_have_dtype(&self, dtype: &DataType) -> bool { + matches!(self, AnyValue::Null) || dtype == &self.dtype() + } +} + +impl<'a> AnyValue<'a> { + /// Get the matching [`DataType`] for this [`AnyValue`]`. + /// + /// Note: For `Categorical` and `Enum` values, the exact mapping information + /// is not preserved in the result for performance reasons. + pub fn dtype(&self) -> DataType { + use AnyValue::*; + match self { + Null => DataType::Null, + Boolean(_) => DataType::Boolean, + Int8(_) => DataType::Int8, + Int16(_) => DataType::Int16, + Int32(_) => DataType::Int32, + Int64(_) => DataType::Int64, + Int128(_) => DataType::Int128, + UInt8(_) => DataType::UInt8, + UInt16(_) => DataType::UInt16, + UInt32(_) => DataType::UInt32, + UInt64(_) => DataType::UInt64, + Float32(_) => DataType::Float32, + Float64(_) => DataType::Float64, + String(_) | StringOwned(_) => DataType::String, + Binary(_) | BinaryOwned(_) => DataType::Binary, + #[cfg(feature = "dtype-date")] + Date(_) => DataType::Date, + #[cfg(feature = "dtype-time")] + Time(_) => DataType::Time, + #[cfg(feature = "dtype-datetime")] + Datetime(_, tu, tz) => DataType::Datetime(*tu, (*tz).cloned()), + #[cfg(feature = "dtype-datetime")] + DatetimeOwned(_, tu, tz) => { + DataType::Datetime(*tu, tz.as_ref().map(|v| v.as_ref().clone())) + }, + #[cfg(feature = "dtype-duration")] + Duration(_, tu) => DataType::Duration(*tu), + #[cfg(feature = "dtype-categorical")] + Categorical(_, _, _) | CategoricalOwned(_, _, _) => { + DataType::Categorical(None, Default::default()) + }, + #[cfg(feature = "dtype-categorical")] + Enum(_, _, _) | EnumOwned(_, _, _) => DataType::Enum(None, Default::default()), + List(s) => DataType::List(Box::new(s.dtype().clone())), + #[cfg(feature = "dtype-array")] + Array(s, size) => DataType::Array(Box::new(s.dtype().clone()), *size), + #[cfg(feature = "dtype-struct")] + Struct(_, _, fields) => DataType::Struct(fields.to_vec()), + #[cfg(feature = "dtype-struct")] + StructOwned(payload) => DataType::Struct(payload.1.clone()), + #[cfg(feature = "dtype-decimal")] + Decimal(_, scale) => DataType::Decimal(None, Some(*scale)), + #[cfg(feature = "object")] + Object(o) => DataType::Object(o.type_name()), + #[cfg(feature = "object")] + ObjectOwned(o) => DataType::Object(o.0.type_name()), + } + } + + /// Extract a numerical value from the AnyValue + #[doc(hidden)] + #[inline] + pub fn extract(&self) -> Option { + use AnyValue::*; + match self { + Int8(v) => NumCast::from(*v), + Int16(v) => NumCast::from(*v), + Int32(v) => NumCast::from(*v), + Int64(v) => NumCast::from(*v), + Int128(v) => NumCast::from(*v), + UInt8(v) => NumCast::from(*v), + UInt16(v) => NumCast::from(*v), + UInt32(v) => NumCast::from(*v), + UInt64(v) => NumCast::from(*v), + Float32(v) => NumCast::from(*v), + Float64(v) => NumCast::from(*v), + #[cfg(feature = "dtype-date")] + Date(v) => NumCast::from(*v), + #[cfg(feature = "dtype-datetime")] + Datetime(v, _, _) | DatetimeOwned(v, _, _) => NumCast::from(*v), + #[cfg(feature = "dtype-time")] + Time(v) => NumCast::from(*v), + #[cfg(feature = "dtype-duration")] + Duration(v, _) => NumCast::from(*v), + #[cfg(feature = "dtype-decimal")] + Decimal(v, scale) => { + if *scale == 0 { + NumCast::from(*v) + } else { + let f: Option = NumCast::from(*v); + NumCast::from(f? / 10f64.powi(*scale as _)) + } + }, + Boolean(v) => NumCast::from(if *v { 1 } else { 0 }), + String(v) => { + if let Ok(val) = (*v).parse::() { + NumCast::from(val) + } else { + NumCast::from((*v).parse::().ok()?) + } + }, + StringOwned(v) => String(v.as_str()).extract(), + _ => None, + } + } + + #[inline] + pub fn try_extract(&self) -> PolarsResult { + self.extract().ok_or_else(|| { + polars_err!( + ComputeError: "could not extract number from any-value of dtype: '{:?}'", + self.dtype(), + ) + }) + } + + pub fn is_boolean(&self) -> bool { + matches!(self, AnyValue::Boolean(_)) + } + + pub fn is_primitive_numeric(&self) -> bool { + self.is_integer() || self.is_float() + } + + pub fn is_float(&self) -> bool { + matches!(self, AnyValue::Float32(_) | AnyValue::Float64(_)) + } + + pub fn is_integer(&self) -> bool { + self.is_signed_integer() || self.is_unsigned_integer() + } + + pub fn is_signed_integer(&self) -> bool { + matches!( + self, + AnyValue::Int8(_) + | AnyValue::Int16(_) + | AnyValue::Int32(_) + | AnyValue::Int64(_) + | AnyValue::Int128(_) + ) + } + + pub fn is_unsigned_integer(&self) -> bool { + matches!( + self, + AnyValue::UInt8(_) | AnyValue::UInt16(_) | AnyValue::UInt32(_) | AnyValue::UInt64(_) + ) + } + + pub fn is_nan(&self) -> bool { + match self { + AnyValue::Float32(f) => f.is_nan(), + AnyValue::Float64(f) => f.is_nan(), + _ => false, + } + } + + pub fn is_null(&self) -> bool { + matches!(self, AnyValue::Null) + } + + pub fn is_nested_null(&self) -> bool { + match self { + AnyValue::Null => true, + AnyValue::List(s) => s.null_count() == s.len(), + #[cfg(feature = "dtype-array")] + AnyValue::Array(s, _) => s.null_count() == s.len(), + #[cfg(feature = "dtype-struct")] + AnyValue::Struct(_, _, _) => self._iter_struct_av().all(|av| av.is_nested_null()), + _ => false, + } + } + + /// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`, + /// if possible. + pub fn strict_cast(&self, dtype: &'a DataType) -> Option> { + let new_av = match (self, dtype) { + // to numeric + (av, DataType::UInt8) => AnyValue::UInt8(av.extract::()?), + (av, DataType::UInt16) => AnyValue::UInt16(av.extract::()?), + (av, DataType::UInt32) => AnyValue::UInt32(av.extract::()?), + (av, DataType::UInt64) => AnyValue::UInt64(av.extract::()?), + (av, DataType::Int8) => AnyValue::Int8(av.extract::()?), + (av, DataType::Int16) => AnyValue::Int16(av.extract::()?), + (av, DataType::Int32) => AnyValue::Int32(av.extract::()?), + (av, DataType::Int64) => AnyValue::Int64(av.extract::()?), + (av, DataType::Int128) => AnyValue::Int128(av.extract::()?), + (av, DataType::Float32) => AnyValue::Float32(av.extract::()?), + (av, DataType::Float64) => AnyValue::Float64(av.extract::()?), + + // to boolean + (AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()), + (AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()), + (AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()), + (AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()), + (AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()), + (AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()), + (AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()), + (AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()), + (AnyValue::Int128(v), DataType::Boolean) => AnyValue::Boolean(*v != i128::default()), + (AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()), + (AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()), + + // to string + (AnyValue::String(v), DataType::String) => AnyValue::String(v), + (AnyValue::StringOwned(v), DataType::String) => AnyValue::StringOwned(v.clone()), + + (av, DataType::String) => { + let mut tmp = vec![]; + if av.is_unsigned_integer() { + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); + } else if av.is_float() { + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); + } else { + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); + } + AnyValue::StringOwned(PlSmallStr::from_str(std::str::from_utf8(&tmp).unwrap())) + }, + + // to binary + (AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()), + + // to datetime + #[cfg(feature = "dtype-datetime")] + (av, DataType::Datetime(tu, tz)) if av.is_primitive_numeric() => { + AnyValue::Datetime(av.extract::()?, *tu, tz.as_ref()) + }, + #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] + (AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime( + match tu { + TimeUnit::Nanoseconds => (*v as i64) * NS_IN_DAY, + TimeUnit::Microseconds => (*v as i64) * US_IN_DAY, + TimeUnit::Milliseconds => (*v as i64) * MS_IN_DAY, + }, + *tu, + None, + ), + #[cfg(feature = "dtype-datetime")] + ( + AnyValue::Datetime(v, tu, _) | AnyValue::DatetimeOwned(v, tu, _), + DataType::Datetime(tu_r, tz_r), + ) => AnyValue::Datetime( + match (tu, tu_r) { + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64, + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64, + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64, + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64, + _ => *v, + }, + *tu_r, + tz_r.as_ref(), + ), + + // to date + #[cfg(feature = "dtype-date")] + (av, DataType::Date) if av.is_primitive_numeric() => { + AnyValue::Date(av.extract::()?) + }, + #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))] + (AnyValue::Datetime(v, tu, _) | AnyValue::DatetimeOwned(v, tu, _), DataType::Date) => { + AnyValue::Date(match tu { + TimeUnit::Nanoseconds => *v / NS_IN_DAY, + TimeUnit::Microseconds => *v / US_IN_DAY, + TimeUnit::Milliseconds => *v / MS_IN_DAY, + } as i32) + }, + + // to time + #[cfg(feature = "dtype-time")] + (av, DataType::Time) if av.is_primitive_numeric() => { + AnyValue::Time(av.extract::()?) + }, + #[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))] + (AnyValue::Datetime(v, tu, _) | AnyValue::DatetimeOwned(v, tu, _), DataType::Time) => { + AnyValue::Time(match tu { + TimeUnit::Nanoseconds => *v % NS_IN_DAY, + TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64, + TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64, + }) + }, + + // to duration + #[cfg(feature = "dtype-duration")] + (av, DataType::Duration(tu)) if av.is_primitive_numeric() => { + AnyValue::Duration(av.extract::()?, *tu) + }, + #[cfg(all(feature = "dtype-duration", feature = "dtype-time"))] + (AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration( + match *tu { + TimeUnit::Nanoseconds => *v, + TimeUnit::Microseconds => *v / 1_000i64, + TimeUnit::Milliseconds => *v / 1_000_000i64, + }, + *tu, + ), + #[cfg(feature = "dtype-duration")] + (AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration( + match (tu, tu_r) { + (_, _) if tu == tu_r => *v, + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64, + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64, + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64, + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64, + _ => *v, + }, + *tu_r, + ), + + // to decimal + #[cfg(feature = "dtype-decimal")] + (av, DataType::Decimal(prec, scale)) if av.is_integer() => { + let value = av.try_extract::().unwrap(); + let scale = scale.unwrap_or(0); + let factor = 10_i128.pow(scale as _); // Conversion to u32 is safe, max value is 38. + let converted = value.checked_mul(factor)?; + + // Check if the converted value fits into the specified precision + let prec = prec.unwrap_or(38) as u32; + let num_digits = (converted.abs() as f64).log10().ceil() as u32; + if num_digits > prec { + return None; + } + + AnyValue::Decimal(converted, scale) + }, + #[cfg(feature = "dtype-decimal")] + (AnyValue::Decimal(value, scale_av), DataType::Decimal(_, scale)) => { + let Some(scale) = scale else { + return Some(self.clone()); + }; + // TODO: Allow lossy conversion? + let scale_diff = scale.checked_sub(*scale_av)?; + let factor = 10_i128.pow(scale_diff as _); // Conversion is safe, max value is 38. + let converted = value.checked_mul(factor)?; + AnyValue::Decimal(converted, *scale) + }, + + // to self + (av, dtype) if av.dtype() == *dtype => self.clone(), + + _ => return None, + }; + Some(new_av) + } + + /// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`, + /// if possible. + pub fn try_strict_cast(&self, dtype: &'a DataType) -> PolarsResult> { + self.strict_cast(dtype).ok_or_else( + || polars_err!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", self, dtype), + ) + } + + pub fn cast(&self, dtype: &'a DataType) -> AnyValue<'a> { + match self.strict_cast(dtype) { + Some(av) => av, + None => AnyValue::Null, + } + } + + pub fn idx(&self) -> IdxSize { + match self { + #[cfg(not(feature = "bigidx"))] + Self::UInt32(v) => *v, + #[cfg(feature = "bigidx")] + Self::UInt64(v) => *v, + _ => panic!("expected index type found {self:?}"), + } + } + + pub fn str_value(&self) -> Cow<'a, str> { + match self { + Self::String(s) => Cow::Borrowed(s), + Self::StringOwned(s) => Cow::Owned(s.to_string()), + Self::Null => Cow::Borrowed("null"), + #[cfg(feature = "dtype-categorical")] + Self::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { + if arr.is_null() { + Cow::Borrowed(rev.get(*idx)) + } else { + unsafe { Cow::Borrowed(arr.deref_unchecked().value(*idx as usize)) } + } + }, + #[cfg(feature = "dtype-categorical")] + Self::CategoricalOwned(idx, rev, arr) | AnyValue::EnumOwned(idx, rev, arr) => { + if arr.is_null() { + Cow::Owned(rev.get(*idx).to_string()) + } else { + unsafe { Cow::Borrowed(arr.deref_unchecked().value(*idx as usize)) } + } + }, + av => Cow::Owned(av.to_string()), + } + } +} + +impl From> for DataType { + fn from(value: AnyValue<'_>) -> Self { + value.dtype() + } +} + +impl<'a> From<&AnyValue<'a>> for DataType { + fn from(value: &AnyValue<'a>) -> Self { + value.dtype() + } +} + +impl AnyValue<'_> { + pub fn hash_impl(&self, state: &mut H, cheap: bool) { + use AnyValue::*; + std::mem::discriminant(self).hash(state); + match self { + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + Int128(v) => feature_gated!("dtype-i128", v.hash(state)), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + String(v) => v.hash(state), + StringOwned(v) => v.hash(state), + Float32(v) => v.to_ne_bytes().hash(state), + Float64(v) => v.to_ne_bytes().hash(state), + Binary(v) => v.hash(state), + BinaryOwned(v) => v.hash(state), + Boolean(v) => v.hash(state), + List(v) => { + if !cheap || v.len() < CHEAP_SERIES_HASH_LIMIT { + Hash::hash(&Wrap(v.clone()), state) + } + }, + #[cfg(feature = "dtype-array")] + Array(v, width) => { + if !cheap || v.len() < CHEAP_SERIES_HASH_LIMIT { + Hash::hash(&Wrap(v.clone()), state) + } + width.hash(state) + }, + #[cfg(feature = "dtype-date")] + Date(v) => v.hash(state), + #[cfg(feature = "dtype-datetime")] + Datetime(v, tu, tz) => { + v.hash(state); + tu.hash(state); + tz.hash(state); + }, + #[cfg(feature = "dtype-datetime")] + DatetimeOwned(v, tu, tz) => { + v.hash(state); + tu.hash(state); + tz.hash(state); + }, + #[cfg(feature = "dtype-duration")] + Duration(v, tz) => { + v.hash(state); + tz.hash(state); + }, + #[cfg(feature = "dtype-time")] + Time(v) => v.hash(state), + #[cfg(feature = "dtype-categorical")] + Categorical(v, _, _) + | CategoricalOwned(v, _, _) + | Enum(v, _, _) + | EnumOwned(v, _, _) => v.hash(state), + #[cfg(feature = "object")] + Object(_) => {}, + #[cfg(feature = "object")] + ObjectOwned(_) => {}, + #[cfg(feature = "dtype-struct")] + Struct(_, _, _) => { + if !cheap { + let mut buf = vec![]; + self._materialize_struct_av(&mut buf); + buf.hash(state) + } + }, + #[cfg(feature = "dtype-struct")] + StructOwned(v) => v.0.hash(state), + #[cfg(feature = "dtype-decimal")] + Decimal(v, k) => { + v.hash(state); + k.hash(state); + }, + Null => {}, + } + } +} + +impl Hash for AnyValue<'_> { + fn hash(&self, state: &mut H) { + self.hash_impl(state, false) + } +} + +impl Eq for AnyValue<'_> {} + +impl<'a, T> From> for AnyValue<'a> +where + T: Into>, +{ + #[inline] + fn from(a: Option) -> Self { + match a { + None => AnyValue::Null, + Some(v) => v.into(), + } + } +} + +impl<'a> AnyValue<'a> { + #[cfg(any(feature = "dtype-date", feature = "dtype-datetime"))] + pub(crate) fn as_date(&self) -> AnyValue<'static> { + match self { + #[cfg(feature = "dtype-date")] + AnyValue::Int32(v) => AnyValue::Date(*v), + AnyValue::Null => AnyValue::Null, + dt => panic!("cannot create date from other type. dtype: {dt}"), + } + } + #[cfg(feature = "dtype-datetime")] + pub(crate) fn as_datetime(&self, tu: TimeUnit, tz: Option<&'a TimeZone>) -> AnyValue<'a> { + match self { + AnyValue::Int64(v) => AnyValue::Datetime(*v, tu, tz), + AnyValue::Null => AnyValue::Null, + dt => panic!("cannot create date from other type. dtype: {dt}"), + } + } + + #[cfg(feature = "dtype-duration")] + pub(crate) fn as_duration(&self, tu: TimeUnit) -> AnyValue<'static> { + match self { + AnyValue::Int64(v) => AnyValue::Duration(*v, tu), + AnyValue::Null => AnyValue::Null, + dt => panic!("cannot create date from other type. dtype: {dt}"), + } + } + + #[cfg(feature = "dtype-time")] + pub(crate) fn as_time(&self) -> AnyValue<'static> { + match self { + AnyValue::Int64(v) => AnyValue::Time(*v), + AnyValue::Null => AnyValue::Null, + dt => panic!("cannot create date from other type. dtype: {dt}"), + } + } + + pub(crate) fn to_i128(&self) -> Option { + match self { + AnyValue::UInt8(v) => Some((*v).into()), + AnyValue::UInt16(v) => Some((*v).into()), + AnyValue::UInt32(v) => Some((*v).into()), + AnyValue::UInt64(v) => Some((*v).into()), + AnyValue::Int8(v) => Some((*v).into()), + AnyValue::Int16(v) => Some((*v).into()), + AnyValue::Int32(v) => Some((*v).into()), + AnyValue::Int64(v) => Some((*v).into()), + AnyValue::Int128(v) => Some(*v), + _ => None, + } + } + + pub(crate) fn to_f64(&self) -> Option { + match self { + AnyValue::Float32(v) => Some((*v).into()), + AnyValue::Float64(v) => Some(*v), + _ => None, + } + } + + #[must_use] + pub fn add(&self, rhs: &AnyValue) -> AnyValue<'static> { + use AnyValue::*; + match (self, rhs) { + (Null, r) => r.clone().into_static(), + (l, Null) => l.clone().into_static(), + (Int32(l), Int32(r)) => Int32(l + r), + (Int64(l), Int64(r)) => Int64(l + r), + (UInt32(l), UInt32(r)) => UInt32(l + r), + (UInt64(l), UInt64(r)) => UInt64(l + r), + (Float32(l), Float32(r)) => Float32(l + r), + (Float64(l), Float64(r)) => Float64(l + r), + #[cfg(feature = "dtype-duration")] + (Duration(l, lu), Duration(r, ru)) => { + if lu != ru { + unimplemented!("adding durations with different units is not supported here"); + } + + Duration(l + r, *lu) + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(l, ls), Decimal(r, rs)) => { + if ls != rs { + unimplemented!("adding decimals with different scales is not supported here"); + } + + Decimal(l + r, *ls) + }, + _ => unimplemented!(), + } + } + + #[inline] + pub fn as_borrowed(&self) -> AnyValue<'_> { + match self { + AnyValue::BinaryOwned(data) => AnyValue::Binary(data), + AnyValue::StringOwned(data) => AnyValue::String(data.as_str()), + #[cfg(feature = "dtype-datetime")] + AnyValue::DatetimeOwned(v, tu, tz) => { + AnyValue::Datetime(*v, *tu, tz.as_ref().map(AsRef::as_ref)) + }, + #[cfg(feature = "dtype-categorical")] + AnyValue::CategoricalOwned(v, rev, arr) => { + AnyValue::Categorical(*v, rev.as_ref(), *arr) + }, + #[cfg(feature = "dtype-categorical")] + AnyValue::EnumOwned(v, rev, arr) => AnyValue::Enum(*v, rev.as_ref(), *arr), + av => av.clone(), + } + } + + /// Try to coerce to an AnyValue with static lifetime. + /// This can be done if it does not borrow any values. + #[inline] + pub fn into_static(self) -> AnyValue<'static> { + use AnyValue::*; + match self { + Null => Null, + Int8(v) => Int8(v), + Int16(v) => Int16(v), + Int32(v) => Int32(v), + Int64(v) => Int64(v), + Int128(v) => Int128(v), + UInt8(v) => UInt8(v), + UInt16(v) => UInt16(v), + UInt32(v) => UInt32(v), + UInt64(v) => UInt64(v), + Boolean(v) => Boolean(v), + Float32(v) => Float32(v), + Float64(v) => Float64(v), + #[cfg(feature = "dtype-datetime")] + Datetime(v, tu, tz) => DatetimeOwned(v, tu, tz.map(|v| Arc::new(v.clone()))), + #[cfg(feature = "dtype-datetime")] + DatetimeOwned(v, tu, tz) => DatetimeOwned(v, tu, tz), + #[cfg(feature = "dtype-date")] + Date(v) => Date(v), + #[cfg(feature = "dtype-duration")] + Duration(v, tu) => Duration(v, tu), + #[cfg(feature = "dtype-time")] + Time(v) => Time(v), + List(v) => List(v), + #[cfg(feature = "dtype-array")] + Array(s, size) => Array(s, size), + String(v) => StringOwned(PlSmallStr::from_str(v)), + StringOwned(v) => StringOwned(v), + Binary(v) => BinaryOwned(v.to_vec()), + BinaryOwned(v) => BinaryOwned(v), + #[cfg(feature = "object")] + Object(v) => ObjectOwned(OwnedObject(v.to_boxed())), + #[cfg(feature = "dtype-struct")] + Struct(idx, arr, fields) => { + let avs = struct_to_avs_static(idx, arr, fields); + StructOwned(Box::new((avs, fields.to_vec()))) + }, + #[cfg(feature = "dtype-struct")] + StructOwned(payload) => { + let av = StructOwned(payload); + // SAFETY: owned is already static + unsafe { std::mem::transmute::, AnyValue<'static>>(av) } + }, + #[cfg(feature = "object")] + ObjectOwned(payload) => { + let av = ObjectOwned(payload); + // SAFETY: owned is already static + unsafe { std::mem::transmute::, AnyValue<'static>>(av) } + }, + #[cfg(feature = "dtype-decimal")] + Decimal(val, scale) => Decimal(val, scale), + #[cfg(feature = "dtype-categorical")] + Categorical(v, rev, arr) => CategoricalOwned(v, Arc::new(rev.clone()), arr), + #[cfg(feature = "dtype-categorical")] + CategoricalOwned(v, rev, arr) => CategoricalOwned(v, rev, arr), + #[cfg(feature = "dtype-categorical")] + Enum(v, rev, arr) => EnumOwned(v, Arc::new(rev.clone()), arr), + #[cfg(feature = "dtype-categorical")] + EnumOwned(v, rev, arr) => EnumOwned(v, rev, arr), + } + } + + /// Get a reference to the `&str` contained within [`AnyValue`]. + pub fn get_str(&self) -> Option<&str> { + match self { + AnyValue::String(s) => Some(s), + AnyValue::StringOwned(s) => Some(s.as_str()), + #[cfg(feature = "dtype-categorical")] + AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { + let s = if arr.is_null() { + rev.get(*idx) + } else { + unsafe { arr.deref_unchecked().value(*idx as usize) } + }; + Some(s) + }, + #[cfg(feature = "dtype-categorical")] + AnyValue::CategoricalOwned(idx, rev, arr) | AnyValue::EnumOwned(idx, rev, arr) => { + let s = if arr.is_null() { + rev.get(*idx) + } else { + unsafe { arr.deref_unchecked().value(*idx as usize) } + }; + Some(s) + }, + _ => None, + } + } +} + +impl<'a> From> for Option { + fn from(val: AnyValue<'a>) -> Self { + use AnyValue::*; + match val { + Null => None, + Int32(v) => Some(v as i64), + Int64(v) => Some(v), + UInt32(v) => Some(v as i64), + _ => todo!(), + } + } +} + +impl AnyValue<'_> { + #[inline] + pub fn eq_missing(&self, other: &Self, null_equal: bool) -> bool { + fn struct_owned_value_iter<'a>( + v: &'a (Vec>, Vec), + ) -> impl ExactSizeIterator> { + v.0.iter().map(|v| v.as_borrowed()) + } + fn struct_value_iter( + idx: usize, + arr: &StructArray, + ) -> impl ExactSizeIterator> { + assert!(idx < arr.len()); + + arr.values().iter().map(move |field_arr| unsafe { + // SAFETY: We asserted before that idx is smaller than the array length. Since it + // is an invariant of StructArray that all fields have the same length this is fine + // to do. + field_arr.get_unchecked(idx) + }) + } + + fn struct_eq_missing<'a>( + l: impl ExactSizeIterator>, + r: impl ExactSizeIterator>, + null_equal: bool, + ) -> bool { + if l.len() != r.len() { + return false; + } + + l.zip(r).all(|(lv, rv)| lv.eq_missing(&rv, null_equal)) + } + + use AnyValue::*; + match (self, other) { + // Map to borrowed. + (StringOwned(l), r) => AnyValue::String(l.as_str()) == *r, + (BinaryOwned(l), r) => AnyValue::Binary(l.as_slice()) == *r, + #[cfg(feature = "object")] + (ObjectOwned(l), r) => AnyValue::Object(&*l.0) == *r, + (l, StringOwned(r)) => *l == AnyValue::String(r.as_str()), + (l, BinaryOwned(r)) => *l == AnyValue::Binary(r.as_slice()), + #[cfg(feature = "object")] + (l, ObjectOwned(r)) => *l == AnyValue::Object(&*r.0), + #[cfg(feature = "dtype-datetime")] + (DatetimeOwned(lv, ltu, ltz), r) => { + Datetime(*lv, *ltu, ltz.as_ref().map(|v| v.as_ref())) == *r + }, + #[cfg(feature = "dtype-datetime")] + (l, DatetimeOwned(rv, rtu, rtz)) => { + *l == Datetime(*rv, *rtu, rtz.as_ref().map(|v| v.as_ref())) + }, + #[cfg(feature = "dtype-categorical")] + (CategoricalOwned(lv, lrev, larr), r) => Categorical(*lv, lrev.as_ref(), *larr) == *r, + #[cfg(feature = "dtype-categorical")] + (l, CategoricalOwned(rv, rrev, rarr)) => *l == Categorical(*rv, rrev.as_ref(), *rarr), + #[cfg(feature = "dtype-categorical")] + (EnumOwned(lv, lrev, larr), r) => Enum(*lv, lrev.as_ref(), *larr) == *r, + #[cfg(feature = "dtype-categorical")] + (l, EnumOwned(rv, rrev, rarr)) => *l == Enum(*rv, rrev.as_ref(), *rarr), + + // Comparison with null. + (Null, Null) => null_equal, + (Null, _) => false, + (_, Null) => false, + + // Equality between equal types. + (Boolean(l), Boolean(r)) => *l == *r, + (UInt8(l), UInt8(r)) => *l == *r, + (UInt16(l), UInt16(r)) => *l == *r, + (UInt32(l), UInt32(r)) => *l == *r, + (UInt64(l), UInt64(r)) => *l == *r, + (Int8(l), Int8(r)) => *l == *r, + (Int16(l), Int16(r)) => *l == *r, + (Int32(l), Int32(r)) => *l == *r, + (Int64(l), Int64(r)) => *l == *r, + (Int128(l), Int128(r)) => *l == *r, + (Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(), + (Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(), + (String(l), String(r)) => l == r, + (Binary(l), Binary(r)) => l == r, + #[cfg(feature = "dtype-time")] + (Time(l), Time(r)) => *l == *r, + #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] + (Date(l), Date(r)) => *l == *r, + #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] + (Datetime(l, tul, tzl), Datetime(r, tur, tzr)) => { + *l == *r && *tul == *tur && tzl == tzr + }, + (List(l), List(r)) => l == r, + #[cfg(feature = "dtype-categorical")] + (Categorical(idx_l, rev_l, ptr_l), Categorical(idx_r, rev_r, ptr_r)) => { + if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) { + // We can't support this because our Hash impl directly hashes the index. If you + // add support for this we must change the Hash impl. + unimplemented!( + "comparing categoricals with different revmaps is not supported" + ); + } + + idx_l == idx_r + }, + #[cfg(feature = "dtype-categorical")] + (Enum(idx_l, rev_l, ptr_l), Enum(idx_r, rev_r, ptr_r)) => { + // We can't support this because our Hash impl directly hashes the index. If you + // add support for this we must change the Hash impl. + if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) { + unimplemented!("comparing enums with different revmaps is not supported"); + } + + idx_l == idx_r + }, + #[cfg(feature = "dtype-duration")] + (Duration(l, tu_l), Duration(r, tu_r)) => l == r && tu_l == tu_r, + + #[cfg(feature = "dtype-struct")] + (StructOwned(l), StructOwned(r)) => struct_eq_missing( + struct_owned_value_iter(l.as_ref()), + struct_owned_value_iter(r.as_ref()), + null_equal, + ), + #[cfg(feature = "dtype-struct")] + (StructOwned(l), Struct(idx, arr, _)) => struct_eq_missing( + struct_owned_value_iter(l.as_ref()), + struct_value_iter(*idx, arr), + null_equal, + ), + #[cfg(feature = "dtype-struct")] + (Struct(idx, arr, _), StructOwned(r)) => struct_eq_missing( + struct_value_iter(*idx, arr), + struct_owned_value_iter(r.as_ref()), + null_equal, + ), + #[cfg(feature = "dtype-struct")] + (Struct(l_idx, l_arr, _), Struct(r_idx, r_arr, _)) => struct_eq_missing( + struct_value_iter(*l_idx, l_arr), + struct_value_iter(*r_idx, r_arr), + null_equal, + ), + #[cfg(feature = "dtype-decimal")] + (Decimal(l_v, l_s), Decimal(r_v, r_s)) => { + // l_v / 10**l_s == r_v / 10**r_s + if l_s == r_s && l_v == r_v || *l_v == 0 && *r_v == 0 { + true + } else if l_s < r_s { + // l_v * 10**(r_s - l_s) == r_v + if let Some(lhs) = (|| { + let exp = i128::checked_pow(10, (r_s - l_s).try_into().ok()?)?; + l_v.checked_mul(exp) + })() { + lhs == *r_v + } else { + false + } + } else { + // l_v == r_v * 10**(l_s - r_s) + if let Some(rhs) = (|| { + let exp = i128::checked_pow(10, (l_s - r_s).try_into().ok()?)?; + r_v.checked_mul(exp) + })() { + *l_v == rhs + } else { + false + } + } + }, + #[cfg(feature = "object")] + (Object(l), Object(r)) => l == r, + #[cfg(feature = "dtype-array")] + (Array(l_values, l_size), Array(r_values, r_size)) => { + if l_size != r_size { + return false; + } + + debug_assert_eq!(l_values.len(), *l_size); + debug_assert_eq!(r_values.len(), *r_size); + + let mut is_equal = true; + for i in 0..*l_size { + let l = unsafe { l_values.get_unchecked(i) }; + let r = unsafe { r_values.get_unchecked(i) }; + + is_equal &= l.eq_missing(&r, null_equal); + } + is_equal + }, + + (l, r) if l.to_i128().is_some() && r.to_i128().is_some() => l.to_i128() == r.to_i128(), + (l, r) if l.to_f64().is_some() && r.to_f64().is_some() => { + l.to_f64().unwrap().to_total_ord() == r.to_f64().unwrap().to_total_ord() + }, + + (_, _) => { + unimplemented!( + "scalar eq_missing for mixed dtypes {self:?} and {other:?} is not supported" + ) + }, + } + } +} + +impl PartialEq for AnyValue<'_> { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.eq_missing(other, true) + } +} + +impl PartialOrd for AnyValue<'_> { + /// Only implemented for the same types and physical types! + fn partial_cmp(&self, other: &Self) -> Option { + use AnyValue::*; + match (self, &other) { + // Map to borrowed. + (StringOwned(l), r) => AnyValue::String(l.as_str()).partial_cmp(r), + (BinaryOwned(l), r) => AnyValue::Binary(l.as_slice()).partial_cmp(r), + #[cfg(feature = "object")] + (ObjectOwned(l), r) => AnyValue::Object(&*l.0).partial_cmp(r), + (l, StringOwned(r)) => l.partial_cmp(&AnyValue::String(r.as_str())), + (l, BinaryOwned(r)) => l.partial_cmp(&AnyValue::Binary(r.as_slice())), + #[cfg(feature = "object")] + (l, ObjectOwned(r)) => l.partial_cmp(&AnyValue::Object(&*r.0)), + #[cfg(feature = "dtype-datetime")] + (DatetimeOwned(lv, ltu, ltz), r) => { + Datetime(*lv, *ltu, ltz.as_ref().map(|v| v.as_ref())).partial_cmp(r) + }, + #[cfg(feature = "dtype-datetime")] + (l, DatetimeOwned(rv, rtu, rtz)) => { + l.partial_cmp(&Datetime(*rv, *rtu, rtz.as_ref().map(|v| v.as_ref()))) + }, + #[cfg(feature = "dtype-categorical")] + (CategoricalOwned(lv, lrev, larr), r) => { + Categorical(*lv, lrev.as_ref(), *larr).partial_cmp(r) + }, + #[cfg(feature = "dtype-categorical")] + (l, CategoricalOwned(rv, rrev, rarr)) => { + l.partial_cmp(&Categorical(*rv, rrev.as_ref(), *rarr)) + }, + #[cfg(feature = "dtype-categorical")] + (EnumOwned(lv, lrev, larr), r) => Enum(*lv, lrev.as_ref(), *larr).partial_cmp(r), + #[cfg(feature = "dtype-categorical")] + (l, EnumOwned(rv, rrev, rarr)) => l.partial_cmp(&Enum(*rv, rrev.as_ref(), *rarr)), + + // Comparison with null. + (Null, Null) => Some(Ordering::Equal), + (Null, _) => Some(Ordering::Less), + (_, Null) => Some(Ordering::Greater), + + // Comparison between equal types. + (Boolean(l), Boolean(r)) => l.partial_cmp(r), + (UInt8(l), UInt8(r)) => l.partial_cmp(r), + (UInt16(l), UInt16(r)) => l.partial_cmp(r), + (UInt32(l), UInt32(r)) => l.partial_cmp(r), + (UInt64(l), UInt64(r)) => l.partial_cmp(r), + (Int8(l), Int8(r)) => l.partial_cmp(r), + (Int16(l), Int16(r)) => l.partial_cmp(r), + (Int32(l), Int32(r)) => l.partial_cmp(r), + (Int64(l), Int64(r)) => l.partial_cmp(r), + (Int128(l), Int128(r)) => l.partial_cmp(r), + (Float32(l), Float32(r)) => Some(l.tot_cmp(r)), + (Float64(l), Float64(r)) => Some(l.tot_cmp(r)), + (String(l), String(r)) => l.partial_cmp(r), + (Binary(l), Binary(r)) => l.partial_cmp(r), + #[cfg(feature = "dtype-date")] + (Date(l), Date(r)) => l.partial_cmp(r), + #[cfg(feature = "dtype-datetime")] + (Datetime(lt, lu, lz), Datetime(rt, ru, rz)) => { + if lu != ru || lz != rz { + unimplemented!( + "comparing datetimes with different units or timezones is not supported" + ); + } + + lt.partial_cmp(rt) + }, + #[cfg(feature = "dtype-duration")] + (Duration(lt, lu), Duration(rt, ru)) => { + if lu != ru { + unimplemented!("comparing durations with different units is not supported"); + } + + lt.partial_cmp(rt) + }, + #[cfg(feature = "dtype-time")] + (Time(l), Time(r)) => l.partial_cmp(r), + #[cfg(feature = "dtype-categorical")] + (Categorical(..), Categorical(..)) => { + unimplemented!( + "can't order categoricals as AnyValues, dtype for ordering is needed" + ) + }, + #[cfg(feature = "dtype-categorical")] + (Enum(..), Enum(..)) => { + unimplemented!("can't order enums as AnyValues, dtype for ordering is needed") + }, + (List(_), List(_)) => { + unimplemented!("ordering for List dtype is not supported") + }, + #[cfg(feature = "dtype-array")] + (Array(..), Array(..)) => { + unimplemented!("ordering for Array dtype is not supported") + }, + #[cfg(feature = "object")] + (Object(_), Object(_)) => { + unimplemented!("ordering for Object dtype is not supported") + }, + #[cfg(feature = "dtype-struct")] + (StructOwned(_), StructOwned(_)) + | (StructOwned(_), Struct(..)) + | (Struct(..), StructOwned(_)) + | (Struct(..), Struct(..)) => { + unimplemented!("ordering for Struct dtype is not supported") + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(l_v, l_s), Decimal(r_v, r_s)) => { + // l_v / 10**l_s <=> r_v / 10**r_s + if l_s == r_s && l_v == r_v || *l_v == 0 && *r_v == 0 { + Some(Ordering::Equal) + } else if l_s < r_s { + // l_v * 10**(r_s - l_s) <=> r_v + if let Some(lhs) = (|| { + let exp = i128::checked_pow(10, (r_s - l_s).try_into().ok()?)?; + l_v.checked_mul(exp) + })() { + lhs.partial_cmp(r_v) + } else { + Some(Ordering::Greater) + } + } else { + // l_v <=> r_v * 10**(l_s - r_s) + if let Some(rhs) = (|| { + let exp = i128::checked_pow(10, (l_s - r_s).try_into().ok()?)?; + r_v.checked_mul(exp) + })() { + l_v.partial_cmp(&rhs) + } else { + Some(Ordering::Less) + } + } + }, + + (_, _) => { + unimplemented!( + "scalar ordering for mixed dtypes {self:?} and {other:?} is not supported" + ) + }, + } + } +} + +impl TotalEq for AnyValue<'_> { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self.eq_missing(other, true) + } +} + +#[cfg(feature = "dtype-struct")] +fn struct_to_avs_static(idx: usize, arr: &StructArray, fields: &[Field]) -> Vec> { + assert!(idx < arr.len()); + + let arrs = arr.values(); + + debug_assert_eq!(arrs.len(), fields.len()); + + arrs.iter() + .zip(fields) + .map(|(arr, field)| { + // SAFETY: We asserted above that the length of StructArray is larger than `idx`. Since + // StructArray has the invariant that each array is the same length. This is okay to do + // now. + unsafe { arr_to_any_value(arr.as_ref(), idx, &field.dtype) }.into_static() + }) + .collect() +} + +#[cfg(feature = "dtype-categorical")] +fn same_revmap( + rev_l: &RevMapping, + ptr_l: SyncPtr, + rev_r: &RevMapping, + ptr_r: SyncPtr, +) -> bool { + if ptr_l.is_null() && ptr_r.is_null() { + match (rev_l, rev_r) { + (RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => id_l == id_r, + (RevMapping::Local(_, id_l), RevMapping::Local(_, id_r)) => id_l == id_r, + _ => false, + } + } else { + ptr_l == ptr_r + } +} + +pub trait GetAnyValue { + /// # Safety + /// + /// Get an value without doing bound checks. + unsafe fn get_unchecked(&self, index: usize) -> AnyValue; +} + +impl GetAnyValue for ArrayRef { + // Should only be called with physical types + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + match self.dtype() { + ArrowDataType::Int8 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Int8(v), + } + }, + ArrowDataType::Int16 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Int16(v), + } + }, + ArrowDataType::Int32 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Int32(v), + } + }, + ArrowDataType::Int64 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Int64(v), + } + }, + ArrowDataType::Int128 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Int128(v), + } + }, + ArrowDataType::UInt8 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::UInt8(v), + } + }, + ArrowDataType::UInt16 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::UInt16(v), + } + }, + ArrowDataType::UInt32 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::UInt32(v), + } + }, + ArrowDataType::UInt64 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::UInt64(v), + } + }, + ArrowDataType::Float32 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Float32(v), + } + }, + ArrowDataType::Float64 => { + let arr = self + .as_any() + .downcast_ref::>() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Float64(v), + } + }, + ArrowDataType::Boolean => { + let arr = self + .as_any() + .downcast_ref::() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::Boolean(v), + } + }, + ArrowDataType::LargeUtf8 => { + let arr = self + .as_any() + .downcast_ref::() + .unwrap_unchecked(); + match arr.get_unchecked(index) { + None => AnyValue::Null, + Some(v) => AnyValue::String(v), + } + }, + _ => unimplemented!(), + } + } +} + +impl From for AnyValue<'static> { + fn from(value: K) -> Self { + unsafe { + match K::PRIMITIVE { + PrimitiveType::Int8 => AnyValue::Int8(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::Int16 => AnyValue::Int16(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::Int32 => AnyValue::Int32(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::Int64 => AnyValue::Int64(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::Int128 => AnyValue::Int128(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::UInt8 => AnyValue::UInt8(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::UInt16 => AnyValue::UInt16(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::UInt32 => AnyValue::UInt32(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::UInt64 => AnyValue::UInt64(NumCast::from(value).unwrap_unchecked()), + PrimitiveType::Float32 => { + AnyValue::Float32(NumCast::from(value).unwrap_unchecked()) + }, + PrimitiveType::Float64 => { + AnyValue::Float64(NumCast::from(value).unwrap_unchecked()) + }, + // not supported by polars + _ => unreachable!(), + } + } + } +} + +impl<'a> From<&'a [u8]> for AnyValue<'a> { + fn from(value: &'a [u8]) -> Self { + AnyValue::Binary(value) + } +} + +impl<'a> From<&'a str> for AnyValue<'a> { + fn from(value: &'a str) -> Self { + AnyValue::String(value) + } +} + +impl From for AnyValue<'static> { + fn from(value: bool) -> Self { + AnyValue::Boolean(value) + } +} + +#[cfg(test)] +mod test { + #[cfg(feature = "dtype-categorical")] + use super::*; + + #[test] + #[cfg(feature = "dtype-categorical")] + fn test_arrow_dtypes_to_polars() { + let dtypes = [ + ( + ArrowDataType::Duration(ArrowTimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Nanoseconds), + ), + ( + ArrowDataType::Duration(ArrowTimeUnit::Millisecond), + DataType::Duration(TimeUnit::Milliseconds), + ), + ( + ArrowDataType::Date64, + DataType::Datetime(TimeUnit::Milliseconds, None), + ), + ( + ArrowDataType::Timestamp(ArrowTimeUnit::Nanosecond, None), + DataType::Datetime(TimeUnit::Nanoseconds, None), + ), + ( + ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, None), + DataType::Datetime(TimeUnit::Microseconds, None), + ), + ( + ArrowDataType::Timestamp(ArrowTimeUnit::Millisecond, None), + DataType::Datetime(TimeUnit::Milliseconds, None), + ), + ( + ArrowDataType::Timestamp(ArrowTimeUnit::Second, None), + DataType::Datetime(TimeUnit::Milliseconds, None), + ), + ( + ArrowDataType::Timestamp(ArrowTimeUnit::Second, Some(PlSmallStr::EMPTY)), + DataType::Datetime(TimeUnit::Milliseconds, None), + ), + (ArrowDataType::LargeUtf8, DataType::String), + (ArrowDataType::Utf8, DataType::String), + (ArrowDataType::LargeBinary, DataType::Binary), + (ArrowDataType::Binary, DataType::Binary), + ( + ArrowDataType::Time64(ArrowTimeUnit::Nanosecond), + DataType::Time, + ), + ( + ArrowDataType::Time64(ArrowTimeUnit::Millisecond), + DataType::Time, + ), + ( + ArrowDataType::Time64(ArrowTimeUnit::Microsecond), + DataType::Time, + ), + (ArrowDataType::Time64(ArrowTimeUnit::Second), DataType::Time), + ( + ArrowDataType::Time32(ArrowTimeUnit::Nanosecond), + DataType::Time, + ), + ( + ArrowDataType::Time32(ArrowTimeUnit::Millisecond), + DataType::Time, + ), + ( + ArrowDataType::Time32(ArrowTimeUnit::Microsecond), + DataType::Time, + ), + (ArrowDataType::Time32(ArrowTimeUnit::Second), DataType::Time), + ( + ArrowDataType::List(Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + ArrowDataType::Float64, + true, + ))), + DataType::List(DataType::Float64.into()), + ), + ( + ArrowDataType::LargeList(Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + ArrowDataType::Float64, + true, + ))), + DataType::List(DataType::Float64.into()), + ), + ]; + + for (dt_a, dt_p) in dtypes { + let dt = DataType::from_arrow_dtype(&dt_a); + + assert_eq!(dt_p, dt); + } + } +} diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs new file mode 100644 index 000000000000..5d19efb3bc7b --- /dev/null +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -0,0 +1,1171 @@ +use std::collections::BTreeMap; + +use arrow::datatypes::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Metadata}; +#[cfg(feature = "dtype-array")] +use polars_utils::format_tuple; +use polars_utils::itertools::Itertools; +#[cfg(any(feature = "serde-lazy", feature = "serde"))] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +use super::*; +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_physical_type; +use crate::utils::materialize_dyn_int; + +pub type TimeZone = PlSmallStr; + +static MAINTAIN_PL_TYPE: &str = "maintain_type"; +static PL_KEY: &str = "pl"; + +pub trait MetaDataExt: IntoMetadata { + fn is_enum(&self) -> bool { + let metadata = self.into_metadata_ref(); + metadata.get(DTYPE_ENUM_VALUES).is_some() + } + + fn categorical(&self) -> Option { + let metadata = self.into_metadata_ref(); + match metadata.get(DTYPE_CATEGORICAL)?.as_str() { + "lexical" => Some(CategoricalOrdering::Lexical), + // Default is Physical + _ => Some(CategoricalOrdering::Physical), + } + } + + fn maintain_type(&self) -> bool { + let metadata = self.into_metadata_ref(); + metadata.get(PL_KEY).map(|s| s.as_str()) == Some(MAINTAIN_PL_TYPE) + } +} + +impl MetaDataExt for Metadata {} +pub trait IntoMetadata { + #[allow(clippy::wrong_self_convention)] + fn into_metadata_ref(&self) -> &Metadata; +} + +impl IntoMetadata for Metadata { + fn into_metadata_ref(&self) -> &Metadata { + self + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] +#[cfg_attr( + any(feature = "serde", feature = "serde-lazy"), + derive(Serialize, Deserialize) +)] +pub enum UnknownKind { + // Hold the value to determine the concrete size. + Int(i128), + Float, + // Can be Categorical or String + Str, + #[default] + Any, +} + +impl UnknownKind { + pub fn materialize(&self) -> Option { + let dtype = match self { + UnknownKind::Int(v) => materialize_dyn_int(*v).dtype(), + UnknownKind::Float => DataType::Float64, + UnknownKind::Str => DataType::String, + UnknownKind::Any => return None, + }; + Some(dtype) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Default, IntoStaticStr)] +#[cfg_attr( + any(feature = "serde-lazy", feature = "serde"), + derive(Serialize, Deserialize) +)] +#[strum(serialize_all = "snake_case")] +pub enum CategoricalOrdering { + #[default] + Physical, + Lexical, +} + +#[derive(Clone, Debug)] +pub enum DataType { + Boolean, + UInt8, + UInt16, + UInt32, + UInt64, + Int8, + Int16, + Int32, + Int64, + Int128, + Float32, + Float64, + /// Fixed point decimal type optional precision and non-negative scale. + /// This is backed by a signed 128-bit integer which allows for up to 38 significant digits. + /// Meaning max precision is 38. + #[cfg(feature = "dtype-decimal")] + Decimal(Option, Option), // precision/scale; scale being None means "infer" + /// String data + String, + Binary, + BinaryOffset, + /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in days (32 bits). + Date, + /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) + /// in the given timeunit (64 bits). + Datetime(TimeUnit, Option), + /// 64-bit integer representing difference between times in milliseconds or nanoseconds + Duration(TimeUnit), + /// A 64-bit time representing the elapsed time since midnight in nanoseconds + Time, + /// A nested list with a fixed size in each row + #[cfg(feature = "dtype-array")] + Array(Box, usize), + /// A nested list with a variable size in each row + List(Box), + /// A generic type that can be used in a `Series` + /// &'static str can be used to determine/set inner type + #[cfg(feature = "object")] + Object(&'static str), + Null, + // The RevMapping has the internal state. + // This is ignored with comparisons, hashing etc. + #[cfg(feature = "dtype-categorical")] + Categorical(Option>, CategoricalOrdering), + // It is an Option, so that matching Enum/Categoricals can take the same guards. + #[cfg(feature = "dtype-categorical")] + Enum(Option>, CategoricalOrdering), + #[cfg(feature = "dtype-struct")] + Struct(Vec), + // some logical types we cannot know statically, e.g. Datetime + Unknown(UnknownKind), +} + +impl Default for DataType { + fn default() -> Self { + DataType::Unknown(UnknownKind::Any) + } +} + +pub trait AsRefDataType { + fn as_ref_dtype(&self) -> &DataType; +} + +impl Hash for DataType { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state) + } +} + +impl PartialEq for DataType { + fn eq(&self, other: &Self) -> bool { + use DataType::*; + { + match (self, other) { + #[cfg(feature = "dtype-categorical")] + // Don't include rev maps in comparisons + // TODO: include ordering in comparison + (Categorical(_, _ordering_l), Categorical(_, _ordering_r)) => true, + #[cfg(feature = "dtype-categorical")] + // None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)` + (Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true, + #[cfg(feature = "dtype-categorical")] + (Enum(Some(cat_lhs), _), Enum(Some(cat_rhs), _)) => { + cat_lhs.get_categories() == cat_rhs.get_categories() + }, + (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) => tu_l == tu_r && tz_l == tz_r, + (List(left_inner), List(right_inner)) => left_inner == right_inner, + #[cfg(feature = "dtype-duration")] + (Duration(tu_l), Duration(tu_r)) => tu_l == tu_r, + #[cfg(feature = "dtype-decimal")] + (Decimal(l_prec, l_scale), Decimal(r_prec, r_scale)) => { + let is_prec_eq = l_prec.is_none() || r_prec.is_none() || l_prec == r_prec; + let is_scale_eq = l_scale.is_none() || r_scale.is_none() || l_scale == r_scale; + + is_prec_eq && is_scale_eq + }, + #[cfg(feature = "object")] + (Object(lhs), Object(rhs)) => lhs == rhs, + #[cfg(feature = "dtype-struct")] + (Struct(lhs), Struct(rhs)) => { + std::ptr::eq(Vec::as_ptr(lhs), Vec::as_ptr(rhs)) || lhs == rhs + }, + #[cfg(feature = "dtype-array")] + (Array(left_inner, left_width), Array(right_inner, right_width)) => { + left_width == right_width && left_inner == right_inner + }, + (Unknown(l), Unknown(r)) => match (l, r) { + (UnknownKind::Int(_), UnknownKind::Int(_)) => true, + _ => l == r, + }, + _ => std::mem::discriminant(self) == std::mem::discriminant(other), + } + } + } +} + +impl Eq for DataType {} + +impl DataType { + pub fn new_idxsize() -> Self { + #[cfg(feature = "bigidx")] + { + Self::UInt64 + } + #[cfg(not(feature = "bigidx"))] + { + Self::UInt32 + } + } + + /// Standardize timezones to consistent values. + pub(crate) fn canonical_timezone(tz: &Option) -> Option { + match tz.as_deref() { + Some("") | None => None, + #[cfg(feature = "timezones")] + Some("+00:00") | Some("00:00") | Some("utc") => Some(PlSmallStr::from_static("UTC")), + Some(v) => Some(PlSmallStr::from_str(v)), + } + } + + pub fn value_within_range(&self, other: AnyValue) -> bool { + use DataType::*; + match self { + UInt8 => other.extract::().is_some(), + #[cfg(feature = "dtype-u16")] + UInt16 => other.extract::().is_some(), + UInt32 => other.extract::().is_some(), + UInt64 => other.extract::().is_some(), + #[cfg(feature = "dtype-i8")] + Int8 => other.extract::().is_some(), + #[cfg(feature = "dtype-i16")] + Int16 => other.extract::().is_some(), + Int32 => other.extract::().is_some(), + Int64 => other.extract::().is_some(), + _ => false, + } + } + + /// Check if the whole dtype is known. + pub fn is_known(&self) -> bool { + match self { + DataType::List(inner) => inner.is_known(), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) => inner.is_known(), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => fields.iter().all(|fld| fld.dtype.is_known()), + DataType::Unknown(_) => false, + _ => true, + } + } + + /// Materialize this datatype if it is unknown. All other datatypes + /// are left unchanged. + pub fn materialize_unknown(self, allow_unknown: bool) -> PolarsResult { + match self { + DataType::Unknown(u) => match u.materialize() { + Some(known) => Ok(known), + None => { + if allow_unknown { + Ok(DataType::Unknown(u)) + } else { + polars_bail!(SchemaMismatch: "failed to materialize unknown type") + } + }, + }, + DataType::List(inner) => Ok(DataType::List(Box::new( + inner.materialize_unknown(allow_unknown)?, + ))), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, size) => Ok(DataType::Array( + Box::new(inner.materialize_unknown(allow_unknown)?), + size, + )), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => Ok(DataType::Struct( + fields + .into_iter() + .map(|f| { + PolarsResult::Ok(Field::new( + f.name, + f.dtype.materialize_unknown(allow_unknown)?, + )) + }) + .try_collect_vec()?, + )), + _ => Ok(self), + } + } + + #[cfg(feature = "dtype-array")] + /// Get the full shape of a multidimensional array. + pub fn get_shape(&self) -> Option> { + fn get_shape_impl(dt: &DataType, shape: &mut Vec) { + if let DataType::Array(inner, size) = dt { + shape.push(*size); + get_shape_impl(inner, shape); + } + } + + if let DataType::Array(inner, size) = self { + let mut shape = vec![*size]; + get_shape_impl(inner, &mut shape); + Some(shape) + } else { + None + } + } + + /// Get the inner data type of a nested type. + pub fn inner_dtype(&self) -> Option<&DataType> { + match self { + DataType::List(inner) => Some(inner), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) => Some(inner), + _ => None, + } + } + + /// Get the absolute inner data type of a nested type. + pub fn leaf_dtype(&self) -> &DataType { + let mut prev = self; + while let Some(dtype) = prev.inner_dtype() { + prev = dtype + } + prev + } + + #[cfg(feature = "dtype-array")] + /// Get the inner data type of a multidimensional array. + pub fn array_leaf_dtype(&self) -> Option<&DataType> { + let mut prev = self; + match prev { + DataType::Array(_, _) => { + while let DataType::Array(inner, _) = &prev { + prev = inner; + } + Some(prev) + }, + _ => None, + } + } + + /// Cast the leaf types of Lists/Arrays and keep the nesting. + pub fn cast_leaf(&self, to: DataType) -> DataType { + use DataType::*; + match self { + List(inner) => List(Box::new(inner.cast_leaf(to))), + #[cfg(feature = "dtype-array")] + Array(inner, size) => Array(Box::new(inner.cast_leaf(to)), *size), + _ => to, + } + } + + /// Return whether the cast to `to` makes sense. + /// + /// If it `None`, we are not sure. + pub fn can_cast_to(&self, to: &DataType) -> Option { + if self == to { + return Some(true); + } + if self.is_primitive_numeric() && to.is_primitive_numeric() { + return Some(true); + } + + if self.is_null() { + return Some(true); + } + + use DataType as D; + Some(match (self, to) { + #[cfg(feature = "dtype-categorical")] + (D::Categorical(_, _) | D::Enum(_, _), D::Binary) + | (D::Binary, D::Categorical(_, _) | D::Enum(_, _)) => false, + + #[cfg(feature = "object")] + (D::Object(_), D::Object(_)) => true, + #[cfg(feature = "object")] + (D::Object(_), _) | (_, D::Object(_)) => false, + + (D::Boolean, dt) | (dt, D::Boolean) => match dt { + dt if dt.is_primitive_numeric() => true, + #[cfg(feature = "dtype-decimal")] + D::Decimal(_, _) => true, + D::String | D::Binary => true, + _ => false, + }, + + (D::List(from), D::List(to)) => from.can_cast_to(to)?, + #[cfg(feature = "dtype-array")] + (D::Array(from, l_width), D::Array(to, r_width)) => { + l_width == r_width && from.can_cast_to(to)? + }, + #[cfg(feature = "dtype-struct")] + (D::Struct(l_fields), D::Struct(r_fields)) => { + if l_fields.is_empty() { + return Some(true); + } + + if l_fields.len() != r_fields.len() { + return Some(false); + } + + for (l, r) in l_fields.iter().zip(r_fields) { + if !l.dtype().can_cast_to(r.dtype())? { + return Some(false); + } + } + + true + }, + + // @NOTE: we are being conversative + _ => return None, + }) + } + + pub fn implode(self) -> DataType { + DataType::List(Box::new(self)) + } + + /// Convert to the physical data type + #[must_use] + pub fn to_physical(&self) -> DataType { + use DataType::*; + match self { + Date => Int32, + Datetime(_, _) => Int64, + Duration(_) => Int64, + Time => Int64, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Int128, + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => UInt32, + #[cfg(feature = "dtype-array")] + Array(dt, width) => Array(Box::new(dt.to_physical()), *width), + List(dt) => List(Box::new(dt.to_physical())), + #[cfg(feature = "dtype-struct")] + Struct(fields) => { + let new_fields = fields + .iter() + .map(|s| Field::new(s.name().clone(), s.dtype().to_physical())) + .collect(); + Struct(new_fields) + }, + _ => self.clone(), + } + } + + pub fn is_supported_list_arithmetic_input(&self) -> bool { + self.is_primitive_numeric() || self.is_bool() || self.is_null() + } + + /// Check if this [`DataType`] is a logical type + pub fn is_logical(&self) -> bool { + self != &self.to_physical() + } + + /// Check if this [`DataType`] is a temporal type + pub fn is_temporal(&self) -> bool { + use DataType::*; + matches!(self, Date | Datetime(_, _) | Duration(_) | Time) + } + + /// Check if datatype is a primitive type. By that we mean that + /// it is not a nested or logical type. + pub fn is_primitive(&self) -> bool { + self.is_primitive_numeric() + | matches!( + self, + DataType::Boolean | DataType::String | DataType::Binary + ) + } + + /// Check if this [`DataType`] is a primitive numeric type (excludes Decimal). + pub fn is_primitive_numeric(&self) -> bool { + self.is_float() || self.is_integer() + } + + /// Check if this [`DataType`] is a boolean. + pub fn is_bool(&self) -> bool { + matches!(self, DataType::Boolean) + } + + /// Check if this [`DataType`] is a list. + pub fn is_list(&self) -> bool { + matches!(self, DataType::List(_)) + } + + /// Check if this [`DataType`] is an array. + pub fn is_array(&self) -> bool { + #[cfg(feature = "dtype-array")] + { + matches!(self, DataType::Array(_, _)) + } + #[cfg(not(feature = "dtype-array"))] + { + false + } + } + + pub fn is_nested(&self) -> bool { + self.is_list() || self.is_struct() || self.is_array() + } + + /// Check if this [`DataType`] is a struct + pub fn is_struct(&self) -> bool { + #[cfg(feature = "dtype-struct")] + { + matches!(self, DataType::Struct(_)) + } + #[cfg(not(feature = "dtype-struct"))] + { + false + } + } + + pub fn is_binary(&self) -> bool { + matches!(self, DataType::Binary) + } + + pub fn is_date(&self) -> bool { + matches!(self, DataType::Date) + } + pub fn is_datetime(&self) -> bool { + matches!(self, DataType::Datetime(..)) + } + + pub fn is_object(&self) -> bool { + #[cfg(feature = "object")] + { + matches!(self, DataType::Object(_)) + } + #[cfg(not(feature = "object"))] + { + false + } + } + + pub fn is_null(&self) -> bool { + matches!(self, DataType::Null) + } + + pub fn contains_views(&self) -> bool { + use DataType::*; + match self { + Binary | String => true, + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => true, + List(inner) => inner.contains_views(), + #[cfg(feature = "dtype-array")] + Array(inner, _) => inner.contains_views(), + #[cfg(feature = "dtype-struct")] + Struct(fields) => fields.iter().any(|field| field.dtype.contains_views()), + _ => false, + } + } + + pub fn contains_categoricals(&self) -> bool { + use DataType::*; + match self { + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => true, + List(inner) => inner.contains_categoricals(), + #[cfg(feature = "dtype-array")] + Array(inner, _) => inner.contains_categoricals(), + #[cfg(feature = "dtype-struct")] + Struct(fields) => fields + .iter() + .any(|field| field.dtype.contains_categoricals()), + _ => false, + } + } + + pub fn contains_objects(&self) -> bool { + use DataType::*; + match self { + #[cfg(feature = "object")] + Object(_) => true, + List(inner) => inner.contains_objects(), + #[cfg(feature = "dtype-array")] + Array(inner, _) => inner.contains_objects(), + #[cfg(feature = "dtype-struct")] + Struct(fields) => fields.iter().any(|field| field.dtype.contains_objects()), + _ => false, + } + } + + /// Check if type is sortable + pub fn is_ord(&self) -> bool { + #[cfg(feature = "dtype-categorical")] + let is_cat = matches!(self, DataType::Categorical(_, _) | DataType::Enum(_, _)); + #[cfg(not(feature = "dtype-categorical"))] + let is_cat = false; + + let phys = self.to_physical(); + (phys.is_primitive_numeric() + || self.is_decimal() + || matches!( + phys, + DataType::Binary | DataType::String | DataType::Boolean + )) + && !is_cat + } + + /// Check if this [`DataType`] is a Decimal type (of any scale/precision). + pub fn is_decimal(&self) -> bool { + match self { + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => true, + _ => false, + } + } + + /// Check if this [`DataType`] is a basic floating point type (excludes Decimal). + /// Note, this also includes `Unknown(UnknownKind::Float)`. + pub fn is_float(&self) -> bool { + matches!( + self, + DataType::Float32 | DataType::Float64 | DataType::Unknown(UnknownKind::Float) + ) + } + + /// Check if this [`DataType`] is an integer. Note, this also includes `Unknown(UnknownKind::Int)`. + pub fn is_integer(&self) -> bool { + matches!( + self, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Int128 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Unknown(UnknownKind::Int(_)) + ) + } + + pub fn is_signed_integer(&self) -> bool { + // allow because it cannot be replaced when object feature is activated + matches!( + self, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Int128 + ) + } + + pub fn is_unsigned_integer(&self) -> bool { + matches!( + self, + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64, + ) + } + + pub fn is_string(&self) -> bool { + matches!(self, DataType::String | DataType::Unknown(UnknownKind::Str)) + } + + pub fn is_categorical(&self) -> bool { + #[cfg(feature = "dtype-categorical")] + { + matches!(self, DataType::Categorical(_, _)) + } + #[cfg(not(feature = "dtype-categorical"))] + { + false + } + } + + pub fn is_enum(&self) -> bool { + #[cfg(feature = "dtype-categorical")] + { + matches!(self, DataType::Enum(_, _)) + } + #[cfg(not(feature = "dtype-categorical"))] + { + false + } + } + + /// Convert to an Arrow Field + pub fn to_arrow_field(&self, name: PlSmallStr, compat_level: CompatLevel) -> ArrowField { + let metadata = match self { + #[cfg(feature = "dtype-categorical")] + DataType::Enum(Some(revmap), _) => { + let cats = revmap.get_categories(); + let mut encoded = String::with_capacity(cats.len() * 10); + for cat in cats.values_iter() { + encoded.push_str(itoa::Buffer::new().format(cat.len())); + encoded.push(';'); + encoded.push_str(cat); + } + Some(BTreeMap::from([( + PlSmallStr::from_static(DTYPE_ENUM_VALUES), + PlSmallStr::from_string(encoded), + )])) + }, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, ordering) => Some(BTreeMap::from([( + PlSmallStr::from_static(DTYPE_CATEGORICAL), + PlSmallStr::from_static(ordering.into()), + )])), + DataType::BinaryOffset => Some(BTreeMap::from([( + PlSmallStr::from_static(PL_KEY), + PlSmallStr::from_static(MAINTAIN_PL_TYPE), + )])), + _ => None, + }; + + let field = ArrowField::new(name, self.to_arrow(compat_level), true); + + if let Some(metadata) = metadata { + field.with_metadata(metadata) + } else { + field + } + } + + /// Try to get the maximum value for this datatype. + pub fn max(&self) -> PolarsResult { + use DataType::*; + let v = match self { + Int8 => Scalar::from(i8::MAX), + Int16 => Scalar::from(i16::MAX), + Int32 => Scalar::from(i32::MAX), + Int64 => Scalar::from(i64::MAX), + Int128 => Scalar::from(i128::MAX), + UInt8 => Scalar::from(u8::MAX), + UInt16 => Scalar::from(u16::MAX), + UInt32 => Scalar::from(u32::MAX), + UInt64 => Scalar::from(u64::MAX), + Float32 => Scalar::from(f32::INFINITY), + Float64 => Scalar::from(f64::INFINITY), + #[cfg(feature = "dtype-time")] + Time => Scalar::new(Time, AnyValue::Time(NS_IN_DAY - 1)), + dt => polars_bail!(ComputeError: "cannot determine upper bound for dtype `{}`", dt), + }; + Ok(v) + } + + /// Try to get the minimum value for this datatype. + pub fn min(&self) -> PolarsResult { + use DataType::*; + let v = match self { + Int8 => Scalar::from(i8::MIN), + Int16 => Scalar::from(i16::MIN), + Int32 => Scalar::from(i32::MIN), + Int64 => Scalar::from(i64::MIN), + Int128 => Scalar::from(i128::MIN), + UInt8 => Scalar::from(u8::MIN), + UInt16 => Scalar::from(u16::MIN), + UInt32 => Scalar::from(u32::MIN), + UInt64 => Scalar::from(u64::MIN), + Float32 => Scalar::from(f32::NEG_INFINITY), + Float64 => Scalar::from(f64::NEG_INFINITY), + #[cfg(feature = "dtype-time")] + Time => Scalar::new(Time, AnyValue::Time(0)), + dt => polars_bail!(ComputeError: "cannot determine lower bound for dtype `{}`", dt), + }; + Ok(v) + } + + /// Convert to an Arrow data type. + #[inline] + pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowDataType { + self.try_to_arrow(compat_level).unwrap() + } + + #[inline] + pub fn try_to_arrow(&self, compat_level: CompatLevel) -> PolarsResult { + use DataType::*; + match self { + Boolean => Ok(ArrowDataType::Boolean), + UInt8 => Ok(ArrowDataType::UInt8), + UInt16 => Ok(ArrowDataType::UInt16), + UInt32 => Ok(ArrowDataType::UInt32), + UInt64 => Ok(ArrowDataType::UInt64), + Int8 => Ok(ArrowDataType::Int8), + Int16 => Ok(ArrowDataType::Int16), + Int32 => Ok(ArrowDataType::Int32), + Int64 => Ok(ArrowDataType::Int64), + Int128 => Ok(ArrowDataType::Int128), + Float32 => Ok(ArrowDataType::Float32), + Float64 => Ok(ArrowDataType::Float64), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => { + let precision = (*precision).unwrap_or(38); + polars_ensure!(precision <= 38 && precision > 0, InvalidOperation: "decimal precision should be <= 38 & >= 1"); + + Ok(ArrowDataType::Decimal( + precision, + scale.unwrap_or(0), // and what else can we do here? + )) + }, + String => { + let dt = if compat_level.0 >= 1 { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + Ok(dt) + }, + Binary => { + let dt = if compat_level.0 >= 1 { + ArrowDataType::BinaryView + } else { + ArrowDataType::LargeBinary + }; + Ok(dt) + }, + Date => Ok(ArrowDataType::Date32), + Datetime(unit, tz) => Ok(ArrowDataType::Timestamp(unit.to_arrow(), tz.clone())), + Duration(unit) => Ok(ArrowDataType::Duration(unit.to_arrow())), + Time => Ok(ArrowDataType::Time64(ArrowTimeUnit::Nanosecond)), + #[cfg(feature = "dtype-array")] + Array(dt, size) => Ok(dt + .try_to_arrow(compat_level)? + .to_fixed_size_list(*size, true)), + List(dt) => Ok(ArrowDataType::LargeList(Box::new( + dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level), + ))), + Null => Ok(ArrowDataType::Null), + #[cfg(feature = "object")] + Object(_) => Ok(get_object_physical_type()), + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => { + let values = if compat_level.0 >= 1 { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + Ok(ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(values), + false, + )) + }, + #[cfg(feature = "dtype-struct")] + Struct(fields) => { + let fields = fields + .iter() + .map(|fld| fld.to_arrow(compat_level)) + .collect(); + Ok(ArrowDataType::Struct(fields)) + }, + BinaryOffset => Ok(ArrowDataType::LargeBinary), + Unknown(kind) => { + let dt = match kind { + UnknownKind::Any => ArrowDataType::Unknown, + UnknownKind::Float => ArrowDataType::Float64, + UnknownKind::Str => ArrowDataType::Utf8View, + UnknownKind::Int(v) => { + return materialize_dyn_int(*v).dtype().try_to_arrow(compat_level); + }, + }; + Ok(dt) + }, + } + } + + pub fn is_nested_null(&self) -> bool { + use DataType::*; + match self { + Null => true, + List(field) => field.is_nested_null(), + #[cfg(feature = "dtype-array")] + Array(field, _) => field.is_nested_null(), + #[cfg(feature = "dtype-struct")] + Struct(fields) => fields.iter().all(|fld| fld.dtype.is_nested_null()), + _ => false, + } + } + + /// Answers if this type matches the given type of a schema. + /// + /// Allows (nested) Null types in this type to match any type in the schema, + /// but not vice versa. In such a case Ok(true) is returned, because a cast + /// is necessary. If no cast is necessary Ok(false) is returned, and an + /// error is returned if the types are incompatible. + pub fn matches_schema_type(&self, schema_type: &DataType) -> PolarsResult { + match (self, schema_type) { + (DataType::List(l), DataType::List(r)) => l.matches_schema_type(r), + #[cfg(feature = "dtype-array")] + (DataType::Array(l, sl), DataType::Array(r, sr)) => { + Ok(l.matches_schema_type(r)? && sl == sr) + }, + #[cfg(feature = "dtype-struct")] + (DataType::Struct(l), DataType::Struct(r)) => { + let mut must_cast = false; + for (l, r) in l.iter().zip(r.iter()) { + must_cast |= l.dtype.matches_schema_type(&r.dtype)?; + } + Ok(must_cast) + }, + (DataType::Null, DataType::Null) => Ok(false), + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, s1), DataType::Decimal(_, s2)) => Ok(s1 != s2), + // We don't allow the other way around, only if our current type is + // null and the schema isn't we allow it. + (DataType::Null, _) => Ok(true), + (l, r) if l == r => Ok(false), + (l, r) => { + polars_bail!(SchemaMismatch: "type {:?} is incompatible with expected type {:?}", l, r) + }, + } + } + + #[inline] + pub fn is_unknown(&self) -> bool { + matches!(self, DataType::Unknown(_)) + } + + pub fn nesting_level(&self) -> usize { + let mut level = 0; + let mut slf = self; + while let Some(inner_dtype) = slf.inner_dtype() { + level += 1; + slf = inner_dtype; + } + level + } +} + +impl Display for DataType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + DataType::Null => "null", + DataType::Boolean => "bool", + DataType::UInt8 => "u8", + DataType::UInt16 => "u16", + DataType::UInt32 => "u32", + DataType::UInt64 => "u64", + DataType::Int8 => "i8", + DataType::Int16 => "i16", + DataType::Int32 => "i32", + DataType::Int64 => "i64", + DataType::Int128 => "i128", + DataType::Float32 => "f32", + DataType::Float64 => "f64", + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + return match (precision, scale) { + (Some(precision), Some(scale)) => { + f.write_str(&format!("decimal[{precision},{scale}]")) + }, + (None, Some(scale)) => f.write_str(&format!("decimal[*,{scale}]")), + _ => f.write_str("decimal[?]"), // shouldn't happen + }; + }, + DataType::String => "str", + DataType::Binary => "binary", + DataType::Date => "date", + DataType::Datetime(tu, tz) => { + let s = match tz { + None => format!("datetime[{tu}]"), + Some(tz) => format!("datetime[{tu}, {tz}]"), + }; + return f.write_str(&s); + }, + DataType::Duration(tu) => return write!(f, "duration[{tu}]"), + DataType::Time => "time", + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let tp = self.array_leaf_dtype().unwrap(); + + let dims = self.get_shape().unwrap(); + let shape = if dims.len() == 1 { + format!("{}", dims[0]) + } else { + format_tuple!(dims) + }; + return write!(f, "array[{tp}, {}]", shape); + }, + DataType::List(tp) => return write!(f, "list[{tp}]"), + #[cfg(feature = "object")] + DataType::Object(s) => s, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) => "cat", + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) => "enum", + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => return write!(f, "struct[{}]", fields.len()), + DataType::Unknown(kind) => match kind { + UnknownKind::Any => "unknown", + UnknownKind::Int(_) => "dyn int", + UnknownKind::Float => "dyn float", + UnknownKind::Str => "dyn str", + }, + DataType::BinaryOffset => "binary[offset]", + }; + f.write_str(s) + } +} + +pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult { + use DataType::*; + Ok(match (left, right) { + #[cfg(feature = "dtype-categorical")] + (Categorical(Some(rev_map_l), ordering), Categorical(Some(rev_map_r), _)) => { + match (&**rev_map_l, &**rev_map_r) { + (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { + let mut merger = GlobalRevMapMerger::new(rev_map_l.clone()); + merger.merge_map(rev_map_r)?; + Categorical(Some(merger.finish()), *ordering) + }, + (RevMapping::Local(_, idl), RevMapping::Local(_, idr)) if idl == idr => { + left.clone() + }, + _ => polars_bail!(string_cache_mismatch), + } + }, + #[cfg(feature = "dtype-categorical")] + (Enum(Some(rev_map_l), _), Enum(Some(rev_map_r), _)) => { + match (&**rev_map_l, &**rev_map_r) { + (RevMapping::Local(_, idl), RevMapping::Local(_, idr)) if idl == idr => { + left.clone() + }, + _ => polars_bail!(ComputeError: "can not combine with different categories"), + } + }, + (List(inner_l), List(inner_r)) => { + let merged = merge_dtypes(inner_l, inner_r)?; + List(Box::new(merged)) + }, + #[cfg(feature = "dtype-struct")] + (Struct(inner_l), Struct(inner_r)) => { + polars_ensure!(inner_l.len() == inner_r.len(), ComputeError: "cannot combine structs with differing amounts of fields ({} != {})", inner_l.len(), inner_r.len()); + let fields = inner_l.iter().zip(inner_r.iter()).map(|(l, r)| { + polars_ensure!(l.name() == r.name(), ComputeError: "cannot combine structs with different fields ({} != {})", l.name(), r.name()); + let merged = merge_dtypes(l.dtype(), r.dtype())?; + Ok(Field::new(l.name().clone(), merged)) + }).collect::>>()?; + Struct(fields) + }, + #[cfg(feature = "dtype-array")] + (Array(inner_l, width_l), Array(inner_r, width_r)) => { + polars_ensure!(width_l == width_r, ComputeError: "widths of FixedSizeWidth Series are not equal"); + let merged = merge_dtypes(inner_l, inner_r)?; + Array(Box::new(merged), *width_l) + }, + (left, right) if left == right => left.clone(), + _ => polars_bail!(ComputeError: "unable to merge datatypes"), + }) +} + +fn collect_nested_types( + dtype: &DataType, + result: &mut PlHashSet, + include_compound_types: bool, +) { + match dtype { + DataType::List(inner) => { + if include_compound_types { + result.insert(dtype.clone()); + } + collect_nested_types(inner, result, include_compound_types); + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) => { + if include_compound_types { + result.insert(dtype.clone()); + } + collect_nested_types(inner, result, include_compound_types); + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + if include_compound_types { + result.insert(dtype.clone()); + } + for field in fields { + collect_nested_types(field.dtype(), result, include_compound_types); + } + }, + _ => { + result.insert(dtype.clone()); + }, + } +} + +pub fn unpack_dtypes(dtype: &DataType, include_compound_types: bool) -> PlHashSet { + let mut result = PlHashSet::new(); + collect_nested_types(dtype, &mut result, include_compound_types); + result +} + +#[cfg(feature = "dtype-categorical")] +pub fn create_enum_dtype(categories: Utf8ViewArray) -> DataType { + let rev_map = RevMapping::build_local(categories); + DataType::Enum(Some(Arc::new(rev_map)), Default::default()) +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CompatLevel(pub(crate) u16); + +impl CompatLevel { + pub const fn newest() -> CompatLevel { + CompatLevel(1) + } + + pub const fn oldest() -> CompatLevel { + CompatLevel(0) + } + + // The following methods are only used internally + + #[doc(hidden)] + pub fn with_level(level: u16) -> PolarsResult { + if level > CompatLevel::newest().0 { + polars_bail!(InvalidOperation: "invalid compat level"); + } + Ok(CompatLevel(level)) + } + + #[doc(hidden)] + pub fn get_level(&self) -> u16 { + self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "dtype-array")] + #[test] + fn test_unpack_primitive_dtypes() { + let inner_type = DataType::Float64; + let array_type = DataType::Array(Box::new(inner_type), 10); + let list_type = DataType::List(Box::new(array_type.clone())); + + let result = unpack_dtypes(&list_type, false); + + let mut expected = PlHashSet::new(); + expected.insert(DataType::Float64); + + assert_eq!(result, expected) + } + + #[cfg(feature = "dtype-array")] + #[test] + fn test_unpack_compound_dtypes() { + let inner_type = DataType::Float64; + let array_type = DataType::Array(Box::new(inner_type), 10); + let list_type = DataType::List(Box::new(array_type.clone())); + + let result = unpack_dtypes(&list_type, true); + + let mut expected = PlHashSet::new(); + expected.insert(list_type.clone()); + expected.insert(array_type.clone()); + expected.insert(DataType::Float64); + + assert_eq!(result, expected) + } +} diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs new file mode 100644 index 000000000000..caeb44a341dc --- /dev/null +++ b/crates/polars-core/src/datatypes/field.rs @@ -0,0 +1,265 @@ +use arrow::datatypes::{DTYPE_ENUM_VALUES, Metadata}; +use polars_utils::pl_str::PlSmallStr; + +use super::*; +pub static EXTENSION_NAME: &str = "POLARS_EXTENSION_TYPE"; + +/// Characterizes the name and the [`DataType`] of a column. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr( + any(feature = "serde", feature = "serde-lazy"), + derive(Serialize, Deserialize) +)] +pub struct Field { + pub name: PlSmallStr, + pub dtype: DataType, +} + +impl From for (PlSmallStr, DataType) { + fn from(value: Field) -> Self { + (value.name, value.dtype) + } +} + +pub type FieldRef = Arc; + +impl Field { + /// Creates a new `Field`. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let f1 = Field::new("Fruit name".into(), DataType::String); + /// let f2 = Field::new("Lawful".into(), DataType::Boolean); + /// let f2 = Field::new("Departure".into(), DataType::Time); + /// ``` + #[inline] + pub fn new(name: PlSmallStr, dtype: DataType) -> Self { + Field { name, dtype } + } + + /// Returns a reference to the `Field` name. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let f = Field::new("Year".into(), DataType::Int32); + /// + /// assert_eq!(f.name(), "Year"); + /// ``` + #[inline] + pub fn name(&self) -> &PlSmallStr { + &self.name + } + + /// Returns a reference to the `Field` datatype. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let f = Field::new("Birthday".into(), DataType::Date); + /// + /// assert_eq!(f.dtype(), &DataType::Date); + /// ``` + #[inline] + pub fn dtype(&self) -> &DataType { + &self.dtype + } + + /// Sets the `Field` datatype. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let mut f = Field::new("Temperature".into(), DataType::Int32); + /// f.coerce(DataType::Float32); + /// + /// assert_eq!(f, Field::new("Temperature".into(), DataType::Float32)); + /// ``` + pub fn coerce(&mut self, dtype: DataType) { + self.dtype = dtype; + } + + /// Sets the `Field` name. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let mut f = Field::new("Atomic number".into(), DataType::UInt32); + /// f.set_name("Proton".into()); + /// + /// assert_eq!(f, Field::new("Proton".into(), DataType::UInt32)); + /// ``` + pub fn set_name(&mut self, name: PlSmallStr) { + self.name = name; + } + + /// Returns this `Field`, renamed. + pub fn with_name(mut self, name: PlSmallStr) -> Self { + self.name = name; + self + } + + /// Converts the `Field` to an `arrow::datatypes::Field`. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let f = Field::new("Value".into(), DataType::Int64); + /// let af = arrow::datatypes::Field::new("Value".into(), arrow::datatypes::ArrowDataType::Int64, true); + /// + /// assert_eq!(f.to_arrow(CompatLevel::newest()), af); + /// ``` + pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowField { + self.dtype.to_arrow_field(self.name.clone(), compat_level) + } +} + +impl AsRef for Field { + fn as_ref(&self) -> &DataType { + &self.dtype + } +} + +impl AsRef for DataType { + fn as_ref(&self) -> &DataType { + self + } +} + +impl DataType { + pub fn boxed(self) -> Box { + Box::new(self) + } + + pub fn from_arrow_field(field: &ArrowField) -> DataType { + Self::from_arrow(&field.dtype, true, field.metadata.as_deref()) + } + + pub fn from_arrow_dtype(dt: &ArrowDataType) -> DataType { + Self::from_arrow(dt, true, None) + } + + pub fn from_arrow(dt: &ArrowDataType, bin_to_view: bool, md: Option<&Metadata>) -> DataType { + match dt { + ArrowDataType::Null => DataType::Null, + ArrowDataType::UInt8 => DataType::UInt8, + ArrowDataType::UInt16 => DataType::UInt16, + ArrowDataType::UInt32 => DataType::UInt32, + ArrowDataType::UInt64 => DataType::UInt64, + ArrowDataType::Int8 => DataType::Int8, + ArrowDataType::Int16 => DataType::Int16, + ArrowDataType::Int32 => DataType::Int32, + ArrowDataType::Int64 => DataType::Int64, + #[cfg(feature = "dtype-i128")] + ArrowDataType::Int128 => DataType::Int128, + ArrowDataType::Boolean => DataType::Boolean, + ArrowDataType::Float16 => DataType::Float32, + ArrowDataType::Float32 => DataType::Float32, + ArrowDataType::Float64 => DataType::Float64, + #[cfg(feature = "dtype-array")] + ArrowDataType::FixedSizeList(f, size) => { + DataType::Array(DataType::from_arrow_field(f).boxed(), *size) + }, + ArrowDataType::LargeList(f) | ArrowDataType::List(f) => { + DataType::List(DataType::from_arrow_field(f).boxed()) + }, + ArrowDataType::Date32 => DataType::Date, + ArrowDataType::Timestamp(tu, tz) => { + DataType::Datetime(tu.into(), DataType::canonical_timezone(tz)) + }, + ArrowDataType::Duration(tu) => DataType::Duration(tu.into()), + ArrowDataType::Date64 => DataType::Datetime(TimeUnit::Milliseconds, None), + ArrowDataType::Time64(_) | ArrowDataType::Time32(_) => DataType::Time, + #[cfg(feature = "dtype-categorical")] + ArrowDataType::Dictionary(_, value_type, _) => { + if md.map(|md| md.is_enum()).unwrap_or(false) { + let md = md.unwrap(); + let encoded = md.get(DTYPE_ENUM_VALUES).unwrap(); + let mut encoded = encoded.as_str(); + let mut cats = MutableBinaryViewArray::::new(); + + // Data is encoded as + // We know thus that len is only [0-9] and the first ';' doesn't belong to the + // payload. + while let Some(pos) = encoded.find(';') { + let (len, remainder) = encoded.split_at(pos); + // Split off ';' + encoded = &remainder[1..]; + let len = len.parse::().unwrap(); + + let (value, remainder) = encoded.split_at(len); + cats.push_value(value); + encoded = remainder; + } + DataType::Enum( + Some(Arc::new(RevMapping::build_local(cats.into()))), + Default::default(), + ) + } else if let Some(ordering) = md.and_then(|md| md.categorical()) { + DataType::Categorical(None, ordering) + } else if matches!( + value_type.as_ref(), + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View + ) { + DataType::Categorical(None, Default::default()) + } else { + Self::from_arrow(value_type, bin_to_view, None) + } + }, + #[cfg(feature = "dtype-struct")] + ArrowDataType::Struct(fields) => { + DataType::Struct(fields.iter().map(|fld| fld.into()).collect()) + }, + #[cfg(not(feature = "dtype-struct"))] + ArrowDataType::Struct(_) => { + panic!("activate the 'dtype-struct' feature to handle struct data types") + }, + ArrowDataType::Extension(ext) if ext.name.as_str() == EXTENSION_NAME => { + #[cfg(feature = "object")] + { + DataType::Object("object") + } + #[cfg(not(feature = "object"))] + { + panic!("activate the 'object' feature to be able to load POLARS_EXTENSION_TYPE") + } + }, + #[cfg(feature = "dtype-decimal")] + ArrowDataType::Decimal(precision, scale) => { + DataType::Decimal(Some(*precision), Some(*scale)) + }, + ArrowDataType::Utf8View | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8 => { + DataType::String + }, + ArrowDataType::BinaryView => DataType::Binary, + ArrowDataType::LargeBinary | ArrowDataType::Binary => { + if bin_to_view { + DataType::Binary + } else { + DataType::BinaryOffset + } + }, + ArrowDataType::FixedSizeBinary(_) => DataType::Binary, + ArrowDataType::Map(inner, _is_sorted) => { + DataType::List(Self::from_arrow_field(inner).boxed()) + }, + dt => panic!( + "Arrow datatype {dt:?} not supported by Polars. \ + You probably need to activate that data-type feature." + ), + } + } +} + +impl From<&ArrowField> for Field { + fn from(f: &ArrowField) -> Self { + Field::new(f.name.clone(), DataType::from_arrow_field(f)) + } +} diff --git a/crates/polars-core/src/datatypes/into_scalar.rs b/crates/polars-core/src/datatypes/into_scalar.rs new file mode 100644 index 000000000000..5ab7aadb74e0 --- /dev/null +++ b/crates/polars-core/src/datatypes/into_scalar.rs @@ -0,0 +1,61 @@ +use polars_error::{PolarsResult, polars_bail}; + +use super::{AnyValue, DataType, Scalar}; + +pub trait IntoScalar { + fn into_scalar(self, dtype: DataType) -> PolarsResult; +} + +macro_rules! impl_into_scalar { + ($( + $ty:ty: ($($dt:pat),+ $(,)?), + )+) => { + $( + impl IntoScalar for $ty { + fn into_scalar(self, dtype: DataType) -> PolarsResult { + Ok(match &dtype { + T::Null => Scalar::new(dtype, AnyValue::Null), + $($dt => Scalar::new(dtype, self.into()),)+ + _ => polars_bail!(InvalidOperation: "Cannot cast `{}` to `Scalar` with dtype={dtype}", stringify!($ty)), + }) + } + } + )+ + }; +} + +use DataType as T; +impl_into_scalar! { + bool: (T::Boolean), + u8: (T::UInt8), + u16: (T::UInt16), + u32: (T::UInt32), // T::Categorical, T::Enum + u64: (T::UInt64), + i8: (T::Int8), + i16: (T::Int16), + i32: (T::Int32), // T::Date + i64: (T::Int64), // T::Datetime, T::Duration, T::Time + f32: (T::Float32), + f64: (T::Float64), + // Vec: (T::Binary), + // String: (T::String), + // Series: (T::List, T::Array), + // /// Can be used to fmt and implements Any, so can be downcasted to the proper value type. + // #[cfg(feature = "object")] + // Object(&'a dyn PolarsObjectSafe), + // #[cfg(feature = "object")] + // ObjectOwned(OwnedObject), + // #[cfg(feature = "dtype-struct")] + // // 3 pointers and thus not larger than string/vec + // // - idx in the `&StructArray` + // // - The array itself + // // - The fields + // Struct(usize, &'a StructArray, &'a [Field]), + // #[cfg(feature = "dtype-struct")] + // StructOwned(Box<(Vec>, Vec)>), +} + +#[cfg(feature = "dtype-i128")] +impl_into_scalar! { + i128: (T::Int128), // T::Decimal +} diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs new file mode 100644 index 000000000000..790ea9c22117 --- /dev/null +++ b/crates/polars-core/src/datatypes/mod.rs @@ -0,0 +1,403 @@ +//! # Data types supported by Polars. +//! +//! At the moment Polars doesn't include all data types available by Arrow. The goal is to +//! incrementally support more data types and prioritize these by usability. +//! +//! [See the AnyValue variants](enum.AnyValue.html#variants) for the data types that +//! are currently supported. +//! +#[cfg(feature = "serde")] +mod _serde; +mod aliases; +mod any_value; +mod dtype; +mod field; +mod into_scalar; +#[cfg(feature = "object")] +mod static_array_collect; +mod time_unit; + +use std::cmp::Ordering; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign}; + +mod schema; +pub use aliases::*; +pub use any_value::*; +pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray}; +#[cfg(feature = "dtype-categorical")] +use arrow::datatypes::IntegerType; +pub use arrow::datatypes::reshape::*; +pub use arrow::datatypes::{ArrowDataType, TimeUnit as ArrowTimeUnit}; +use arrow::types::NativeType; +use bytemuck::Zeroable; +pub use dtype::*; +pub use field::*; +pub use into_scalar::*; +use num_traits::{AsPrimitive, Bounded, FromPrimitive, Num, NumCast, One, Zero}; +use polars_compute::arithmetic::HasPrimitiveArithmeticKernel; +use polars_compute::float_sum::FloatSum; +use polars_utils::abs_diff::AbsDiff; +use polars_utils::float::IsFloat; +use polars_utils::min_max::MinMax; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::TotalHash; +pub use schema::SchemaExtPl; +#[cfg(feature = "serde")] +use serde::de::{EnumAccess, VariantAccess, Visitor}; +#[cfg(any(feature = "serde", feature = "serde-lazy"))] +use serde::{Deserialize, Serialize}; +#[cfg(any(feature = "serde", feature = "serde-lazy"))] +use serde::{Deserializer, Serializer}; +pub use time_unit::*; + +pub use crate::chunked_array::logical::*; +#[cfg(feature = "object")] +use crate::chunked_array::object::ObjectArray; +#[cfg(feature = "object")] +use crate::chunked_array::object::PolarsObjectSafe; +use crate::prelude::*; +use crate::utils::Wrap; + +pub struct TrueT; +pub struct FalseT; + +/// # Safety +/// +/// The StaticArray and dtype return must be correct. +pub unsafe trait PolarsDataType: Send + Sync + Sized + 'static { + type Physical<'a>: std::fmt::Debug + Clone; + type OwnedPhysical: std::fmt::Debug + Send + Sync + Clone + PartialEq; + type ZeroablePhysical<'a>: Zeroable + From>; + type Array: for<'a> StaticArray< + ValueT<'a> = Self::Physical<'a>, + ZeroableValueT<'a> = Self::ZeroablePhysical<'a>, + >; + type IsNested; + type HasViews; + type IsStruct; + type IsObject; + type IsLogical; + + fn get_dtype() -> DataType + where + Self: Sized; +} + +pub trait PolarsNumericType: 'static +where + Self: for<'a> PolarsDataType< + OwnedPhysical = Self::Native, + Physical<'a> = Self::Native, + ZeroablePhysical<'a> = Self::Native, + Array = PrimitiveArray, + IsNested = FalseT, + HasViews = FalseT, + IsStruct = FalseT, + IsObject = FalseT, + IsLogical = FalseT, + >, +{ + type Native: NumericNative; +} + +pub trait PolarsIntegerType: PolarsNumericType {} +pub trait PolarsFloatType: PolarsNumericType {} + +macro_rules! impl_polars_num_datatype { + ($trait: ident, $ca:ident, $variant:ident, $physical:ty, $owned_phys:ty) => { + #[derive(Clone, Copy)] + pub struct $ca {} + + unsafe impl PolarsDataType for $ca { + type Physical<'a> = $physical; + type OwnedPhysical = $owned_phys; + type ZeroablePhysical<'a> = $physical; + type Array = PrimitiveArray<$physical>; + type IsNested = FalseT; + type HasViews = FalseT; + type IsStruct = FalseT; + type IsObject = FalseT; + type IsLogical = FalseT; + + #[inline] + fn get_dtype() -> DataType { + DataType::$variant + } + } + + impl PolarsNumericType for $ca { + type Native = $physical; + } + + impl $trait for $ca {} + }; +} + +macro_rules! impl_polars_datatype_pass_dtype { + ($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty, $has_views:ident, $is_logical:ident) => { + #[derive(Clone, Copy)] + pub struct $ca {} + + unsafe impl PolarsDataType for $ca { + type Physical<$lt> = $phys; + type OwnedPhysical = $owned_phys; + type ZeroablePhysical<$lt> = $zerophys; + type Array = $arr; + type IsNested = FalseT; + type HasViews = $has_views; + type IsStruct = FalseT; + type IsObject = FalseT; + type IsLogical = $is_logical; + + #[inline] + fn get_dtype() -> DataType { + $dtype + } + } + }; +} +macro_rules! impl_polars_binview_datatype { + ($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty) => { + impl_polars_datatype_pass_dtype!( + $ca, + DataType::$variant, + $arr, + $lt, + $phys, + $zerophys, + $owned_phys, + TrueT, + FalseT + ); + }; +} + +macro_rules! impl_polars_datatype { + ($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty, $is_logical:ident) => { + impl_polars_datatype_pass_dtype!( + $ca, + DataType::$variant, + $arr, + $lt, + $phys, + $zerophys, + $owned_phys, + FalseT, + $is_logical + ); + }; +} + +impl_polars_num_datatype!(PolarsIntegerType, UInt8Type, UInt8, u8, u8); +impl_polars_num_datatype!(PolarsIntegerType, UInt16Type, UInt16, u16, u16); +impl_polars_num_datatype!(PolarsIntegerType, UInt32Type, UInt32, u32, u32); +impl_polars_num_datatype!(PolarsIntegerType, UInt64Type, UInt64, u64, u64); +impl_polars_num_datatype!(PolarsIntegerType, Int8Type, Int8, i8, i8); +impl_polars_num_datatype!(PolarsIntegerType, Int16Type, Int16, i16, i16); +impl_polars_num_datatype!(PolarsIntegerType, Int32Type, Int32, i32, i32); +impl_polars_num_datatype!(PolarsIntegerType, Int64Type, Int64, i64, i64); + +#[cfg(feature = "dtype-i128")] +impl_polars_num_datatype!(PolarsIntegerType, Int128Type, Int128, i128, i128); +impl_polars_num_datatype!(PolarsFloatType, Float32Type, Float32, f32, f32); +impl_polars_num_datatype!(PolarsFloatType, Float64Type, Float64, f64, f64); +impl_polars_datatype!(DateType, Date, PrimitiveArray, 'a, i32, i32, i32, TrueT); +impl_polars_datatype!(TimeType, Time, PrimitiveArray, 'a, i64, i64, i64, TrueT); +impl_polars_binview_datatype!(StringType, String, Utf8ViewArray, 'a, &'a str, Option<&'a str>, String); +impl_polars_binview_datatype!(BinaryType, Binary, BinaryViewArray, 'a, &'a [u8], Option<&'a [u8]>, Box<[u8]>); +impl_polars_datatype!(BinaryOffsetType, BinaryOffset, BinaryArray, 'a, &'a [u8], Option<&'a [u8]>, Box<[u8]>, FalseT); +impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool, bool, FalseT); + +#[cfg(feature = "dtype-decimal")] +impl_polars_datatype_pass_dtype!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i128, i128, i128, FalseT, TrueT); +impl_polars_datatype_pass_dtype!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i64, i64, i64, FalseT, TrueT); +impl_polars_datatype_pass_dtype!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i64, i64, i64, FalseT, TrueT); +impl_polars_datatype_pass_dtype!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, u32, u32, u32, FalseT, TrueT); + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ListType {} +unsafe impl PolarsDataType for ListType { + type Physical<'a> = Box; + type OwnedPhysical = Box; + type ZeroablePhysical<'a> = Option>; + type Array = ListArray; + type IsNested = TrueT; + type HasViews = FalseT; + type IsStruct = FalseT; + type IsObject = FalseT; + type IsLogical = FalseT; + + fn get_dtype() -> DataType { + // Null as we cannot know anything without self. + DataType::List(Box::new(DataType::Null)) + } +} + +#[cfg(feature = "dtype-struct")] +pub struct StructType {} +#[cfg(feature = "dtype-struct")] +unsafe impl PolarsDataType for StructType { + // The physical types are invalid. + // We don't want these to be used as that would be + // very expensive. We use const asserts to ensure + // traits/methods using the physical types are + // not called for structs. + type Physical<'a> = (); + type OwnedPhysical = (); + type ZeroablePhysical<'a> = (); + type Array = StructArray; + type IsNested = TrueT; + type HasViews = FalseT; + type IsStruct = TrueT; + type IsObject = FalseT; + type IsLogical = FalseT; + + fn get_dtype() -> DataType + where + Self: Sized, + { + DataType::Struct(vec![]) + } +} + +#[cfg(feature = "dtype-array")] +pub struct FixedSizeListType {} +#[cfg(feature = "dtype-array")] +unsafe impl PolarsDataType for FixedSizeListType { + type Physical<'a> = Box; + type OwnedPhysical = Box; + type ZeroablePhysical<'a> = Option>; + type Array = FixedSizeListArray; + type IsNested = TrueT; + type HasViews = FalseT; + type IsStruct = FalseT; + type IsObject = FalseT; + type IsLogical = FalseT; + + fn get_dtype() -> DataType { + // Null as we cannot know anything without self. + DataType::Array(Box::new(DataType::Null), 0) + } +} + +#[cfg(feature = "object")] +pub struct ObjectType(T); +#[cfg(feature = "object")] +unsafe impl PolarsDataType for ObjectType { + type Physical<'a> = &'a T; + type OwnedPhysical = T; + type ZeroablePhysical<'a> = Option<&'a T>; + type Array = ObjectArray; + type IsNested = FalseT; + type HasViews = FalseT; + type IsStruct = FalseT; + type IsObject = TrueT; + type IsLogical = FalseT; + + fn get_dtype() -> DataType { + DataType::Object(T::type_name()) + } +} + +#[cfg(feature = "dtype-array")] +pub type ArrayChunked = ChunkedArray; +pub type ListChunked = ChunkedArray; +pub type BooleanChunked = ChunkedArray; +pub type UInt8Chunked = ChunkedArray; +pub type UInt16Chunked = ChunkedArray; +pub type UInt32Chunked = ChunkedArray; +pub type UInt64Chunked = ChunkedArray; +pub type Int8Chunked = ChunkedArray; +pub type Int16Chunked = ChunkedArray; +pub type Int32Chunked = ChunkedArray; +pub type Int64Chunked = ChunkedArray; +#[cfg(feature = "dtype-i128")] +pub type Int128Chunked = ChunkedArray; +pub type Float32Chunked = ChunkedArray; +pub type Float64Chunked = ChunkedArray; +pub type StringChunked = ChunkedArray; +pub type BinaryChunked = ChunkedArray; +pub type BinaryOffsetChunked = ChunkedArray; +#[cfg(feature = "object")] +pub type ObjectChunked = ChunkedArray>; + +pub trait NumericNative: + TotalOrd + + PartialOrd + + TotalHash + + NativeType + + Num + + NumCast + + Zero + + One + // + Simd + // + Simd8 + + std::iter::Sum + + Add + + Sub + + Mul + + Div + + Rem + + AddAssign + + SubAssign + + AbsDiff + + Bounded + + FromPrimitive + + IsFloat + + HasPrimitiveArithmeticKernel::Native> + + FloatSum + + AsPrimitive + + MinMax + + IsNull +{ + type PolarsType: PolarsNumericType; + type TrueDivPolarsType: PolarsNumericType; +} + +impl NumericNative for i8 { + type PolarsType = Int8Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for i16 { + type PolarsType = Int16Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for i32 { + type PolarsType = Int32Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for i64 { + type PolarsType = Int64Type; + type TrueDivPolarsType = Float64Type; +} +#[cfg(feature = "dtype-i128")] +impl NumericNative for i128 { + type PolarsType = Int128Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for u8 { + type PolarsType = UInt8Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for u16 { + type PolarsType = UInt16Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for u32 { + type PolarsType = UInt32Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for u64 { + type PolarsType = UInt64Type; + type TrueDivPolarsType = Float64Type; +} +impl NumericNative for f32 { + type PolarsType = Float32Type; + type TrueDivPolarsType = Float32Type; +} +impl NumericNative for f64 { + type PolarsType = Float64Type; + type TrueDivPolarsType = Float64Type; +} diff --git a/crates/polars-core/src/datatypes/schema.rs b/crates/polars-core/src/datatypes/schema.rs new file mode 100644 index 000000000000..edc3b38dee7b --- /dev/null +++ b/crates/polars-core/src/datatypes/schema.rs @@ -0,0 +1,22 @@ +use super::*; + +pub trait SchemaExtPl { + // Answers if this schema matches the given schema. + // + // Allows (nested) Null types in this schema to match any type in the schema, + // but not vice versa. In such a case Ok(true) is returned, because a cast + // is necessary. If no cast is necessary Ok(false) is returned, and an + // error is returned if the types are incompatible. + fn matches_schema(&self, other: &Schema) -> PolarsResult; +} + +impl SchemaExtPl for Schema { + fn matches_schema(&self, other: &Schema) -> PolarsResult { + polars_ensure!(self.len() == other.len(), SchemaMismatch: "found different number of fields in schema's\n\nLeft schema: {} fields, right schema: {} fields.", self.len(), other.len()); + let mut cast = false; + for (a, b) in self.iter_values().zip(other.iter_values()) { + cast |= a.matches_schema_type(b)?; + } + Ok(cast) + } +} diff --git a/crates/polars-core/src/datatypes/static_array_collect.rs b/crates/polars-core/src/datatypes/static_array_collect.rs new file mode 100644 index 000000000000..b1529a34ad69 --- /dev/null +++ b/crates/polars-core/src/datatypes/static_array_collect.rs @@ -0,0 +1,44 @@ +use arrow::array::ArrayFromIter; +use arrow::bitmap::BitmapBuilder; + +use crate::chunked_array::object::{ObjectArray, PolarsObject}; + +// TODO: more efficient implementations, I really took the short path here. +impl<'a, T: PolarsObject> ArrayFromIter<&'a T> for ObjectArray { + fn arr_from_iter>(iter: I) -> Self { + Self::try_arr_from_iter(iter.into_iter().map(|o| -> Result<_, ()> { Ok(Some(o)) })).unwrap() + } + + fn try_arr_from_iter>>(iter: I) -> Result { + Self::try_arr_from_iter(iter.into_iter().map(|o| Ok(Some(o?)))) + } +} + +impl<'a, T: PolarsObject> ArrayFromIter> for ObjectArray { + fn arr_from_iter>>(iter: I) -> Self { + Self::try_arr_from_iter(iter.into_iter().map(|o| -> Result<_, ()> { Ok(o) })).unwrap() + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let iter = iter.into_iter(); + let size = iter.size_hint().0; + + let mut null_mask_builder = BitmapBuilder::with_capacity(size); + let values: Vec = iter + .map(|value| match value? { + Some(value) => { + null_mask_builder.push(true); + Ok(value.clone()) + }, + None => { + null_mask_builder.push(false); + Ok(T::default()) + }, + }) + .collect::, E>>()?; + + Ok(ObjectArray::from(values).with_validity(null_mask_builder.into_opt_validity())) + } +} diff --git a/crates/polars-core/src/datatypes/time_unit.rs b/crates/polars-core/src/datatypes/time_unit.rs new file mode 100644 index 000000000000..d3a9a61443fb --- /dev/null +++ b/crates/polars-core/src/datatypes/time_unit.rs @@ -0,0 +1,75 @@ +use super::*; + +#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Hash)] +#[cfg_attr( + any(feature = "serde-lazy", feature = "serde"), + derive(Serialize, Deserialize) +)] +pub enum TimeUnit { + Nanoseconds, + Microseconds, + Milliseconds, +} + +impl From<&ArrowTimeUnit> for TimeUnit { + fn from(tu: &ArrowTimeUnit) -> Self { + match tu { + ArrowTimeUnit::Nanosecond => TimeUnit::Nanoseconds, + ArrowTimeUnit::Microsecond => TimeUnit::Microseconds, + ArrowTimeUnit::Millisecond => TimeUnit::Milliseconds, + // will be cast + ArrowTimeUnit::Second => TimeUnit::Milliseconds, + } + } +} + +impl Display for TimeUnit { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TimeUnit::Nanoseconds => { + write!(f, "ns") + }, + TimeUnit::Microseconds => { + write!(f, "μs") + }, + TimeUnit::Milliseconds => { + write!(f, "ms") + }, + } + } +} + +impl TimeUnit { + pub fn to_ascii(self) -> &'static str { + use TimeUnit::*; + match self { + Nanoseconds => "ns", + Microseconds => "us", + Milliseconds => "ms", + } + } + + pub fn to_arrow(self) -> ArrowTimeUnit { + match self { + TimeUnit::Nanoseconds => ArrowTimeUnit::Nanosecond, + TimeUnit::Microseconds => ArrowTimeUnit::Microsecond, + TimeUnit::Milliseconds => ArrowTimeUnit::Millisecond, + } + } +} + +#[cfg(any(feature = "rows", feature = "object"))] +#[cfg(any(feature = "dtype-datetime", feature = "dtype-duration"))] +#[inline] +pub(crate) fn convert_time_units(v: i64, tu_l: TimeUnit, tu_r: TimeUnit) -> i64 { + use TimeUnit::*; + match (tu_l, tu_r) { + (Nanoseconds, Microseconds) => v / 1_000, + (Nanoseconds, Milliseconds) => v / 1_000_000, + (Microseconds, Nanoseconds) => v * 1_000, + (Microseconds, Milliseconds) => v / 1_000, + (Milliseconds, Microseconds) => v * 1_000, + (Milliseconds, Nanoseconds) => v * 1_000_000, + _ => v, + } +} diff --git a/crates/polars-core/src/error.rs b/crates/polars-core/src/error.rs new file mode 100644 index 000000000000..415c77c3bd2c --- /dev/null +++ b/crates/polars-core/src/error.rs @@ -0,0 +1 @@ +pub use polars_error::*; diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs new file mode 100644 index 000000000000..49de0368eeed --- /dev/null +++ b/crates/polars-core/src/fmt.rs @@ -0,0 +1,1520 @@ +#![allow(unsafe_op_in_unsafe_fn)] +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +use std::borrow::Cow; +use std::fmt::{Debug, Display, Formatter, Write}; +use std::str::FromStr; +use std::sync::RwLock; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::{fmt, str}; + +#[cfg(any( + feature = "dtype-date", + feature = "dtype-datetime", + feature = "dtype-time" +))] +use arrow::temporal_conversions::*; +#[cfg(feature = "dtype-datetime")] +use chrono::NaiveDateTime; +#[cfg(feature = "timezones")] +use chrono::TimeZone; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +use comfy_table::modifiers::*; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +use comfy_table::presets::*; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +use comfy_table::*; +use num_traits::{Num, NumCast}; +use polars_error::feature_gated; + +use crate::config::*; +use crate::prelude::*; + +// Note: see https://github.com/pola-rs/polars/pull/13699 for the rationale +// behind choosing 10 as the default value for default number of rows displayed +const DEFAULT_ROW_LIMIT: usize = 10; +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +const DEFAULT_COL_LIMIT: usize = 8; +const DEFAULT_STR_LEN_LIMIT: usize = 30; +const DEFAULT_LIST_LEN_LIMIT: usize = 3; + +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum FloatFmt { + Mixed, + Full, +} +static FLOAT_PRECISION: RwLock> = RwLock::new(None); +static FLOAT_FMT: AtomicU8 = AtomicU8::new(FloatFmt::Mixed as u8); + +static THOUSANDS_SEPARATOR: AtomicU8 = AtomicU8::new(b'\0'); +static DECIMAL_SEPARATOR: AtomicU8 = AtomicU8::new(b'.'); + +// Numeric formatting getters +pub fn get_float_fmt() -> FloatFmt { + match FLOAT_FMT.load(Ordering::Relaxed) { + 0 => FloatFmt::Mixed, + 1 => FloatFmt::Full, + _ => panic!(), + } +} +pub fn get_float_precision() -> Option { + *FLOAT_PRECISION.read().unwrap() +} +pub fn get_decimal_separator() -> char { + DECIMAL_SEPARATOR.load(Ordering::Relaxed) as char +} +pub fn get_thousands_separator() -> String { + let sep = THOUSANDS_SEPARATOR.load(Ordering::Relaxed) as char; + if sep == '\0' { + "".to_string() + } else { + sep.to_string() + } +} +#[cfg(feature = "dtype-decimal")] +pub fn get_trim_decimal_zeros() -> bool { + arrow::compute::decimal::get_trim_decimal_zeros() +} + +// Numeric formatting setters +pub fn set_float_fmt(fmt: FloatFmt) { + FLOAT_FMT.store(fmt as u8, Ordering::Relaxed) +} +pub fn set_float_precision(precision: Option) { + *FLOAT_PRECISION.write().unwrap() = precision; +} +pub fn set_decimal_separator(dec: Option) { + DECIMAL_SEPARATOR.store(dec.unwrap_or('.') as u8, Ordering::Relaxed) +} +pub fn set_thousands_separator(sep: Option) { + THOUSANDS_SEPARATOR.store(sep.unwrap_or('\0') as u8, Ordering::Relaxed) +} +#[cfg(feature = "dtype-decimal")] +pub fn set_trim_decimal_zeros(trim: Option) { + arrow::compute::decimal::set_trim_decimal_zeros(trim) +} + +/// Parses an environment variable value. +fn parse_env_var(name: &str) -> Option { + std::env::var(name).ok().and_then(|v| v.parse().ok()) +} +/// Parses an environment variable value as a limit or set a default. +/// +/// Negative values (e.g. -1) are parsed as 'no limit' or [`usize::MAX`]. +fn parse_env_var_limit(name: &str, default: usize) -> usize { + parse_env_var(name).map_or( + default, + |n: i64| { + if n < 0 { usize::MAX } else { n as usize } + }, + ) +} + +fn get_row_limit() -> usize { + parse_env_var_limit(FMT_MAX_ROWS, DEFAULT_ROW_LIMIT) +} +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn get_col_limit() -> usize { + parse_env_var_limit(FMT_MAX_COLS, DEFAULT_COL_LIMIT) +} +fn get_str_len_limit() -> usize { + parse_env_var_limit(FMT_STR_LEN, DEFAULT_STR_LEN_LIMIT) +} +fn get_list_len_limit() -> usize { + parse_env_var_limit(FMT_TABLE_CELL_LIST_LEN, DEFAULT_LIST_LEN_LIMIT) +} +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn get_ellipsis() -> &'static str { + match std::env::var(FMT_TABLE_FORMATTING).as_deref().unwrap_or("") { + preset if preset.starts_with("ASCII") => "...", + _ => "…", + } +} +#[cfg(not(any(feature = "fmt", feature = "fmt_no_tty")))] +fn get_ellipsis() -> &'static str { + "…" +} + +fn estimate_string_width(s: &str) -> usize { + // get a slightly more accurate estimate of a string's screen + // width, accounting (very roughly) for multibyte characters + let n_chars = s.chars().count(); + let n_bytes = s.len(); + if n_bytes == n_chars { + n_chars + } else { + let adjust = n_bytes as f64 / n_chars as f64; + std::cmp::min(n_chars * 2, (n_chars as f64 * adjust).ceil() as usize) + } +} + +macro_rules! format_array { + ($f:ident, $a:expr, $dtype:expr, $name:expr, $array_type:expr) => {{ + write!( + $f, + "shape: ({},)\n{}: '{}' [{}]\n[\n", + fmt_int_string_custom(&$a.len().to_string(), 3, "_"), + $array_type, + $name, + $dtype + )?; + + let ellipsis = get_ellipsis(); + let truncate = match $a.dtype() { + DataType::String => true, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => true, + _ => false, + }; + let truncate_len = if truncate { get_str_len_limit() } else { 0 }; + + let write_fn = |v, f: &mut Formatter| -> fmt::Result { + if truncate { + let v = format!("{}", v); + let v_no_quotes = &v[1..v.len() - 1]; + let v_trunc = &v_no_quotes[..v_no_quotes + .char_indices() + .take(truncate_len) + .last() + .map(|(i, c)| i + c.len_utf8()) + .unwrap_or(0)]; + if v_no_quotes == v_trunc { + write!(f, "\t{}\n", v)?; + } else { + write!(f, "\t\"{v_trunc}{ellipsis}\n")?; + } + } else { + write!(f, "\t{v}\n")?; + }; + Ok(()) + }; + + let limit = get_row_limit(); + + if $a.len() > limit { + let half = limit / 2; + let rest = limit % 2; + + for i in 0..(half + rest) { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; + } + write!($f, "\t{ellipsis}\n")?; + for i in ($a.len() - half)..$a.len() { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; + } + } else { + for i in 0..$a.len() { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; + } + } + + write!($f, "]") + }}; +} + +#[cfg(feature = "object")] +fn format_object_array( + f: &mut Formatter<'_>, + object: &Series, + name: &str, + array_type: &str, +) -> fmt::Result { + match object.dtype() { + DataType::Object(inner_type) => { + let limit = std::cmp::min(DEFAULT_ROW_LIMIT, object.len()); + write!( + f, + "shape: ({},)\n{}: '{}' [o][{}]\n[\n", + fmt_int_string_custom(&object.len().to_string(), 3, "_"), + array_type, + name, + inner_type + )?; + for i in 0..limit { + let v = object.str_value(i); + writeln!(f, "\t{}", v.unwrap())?; + } + write!(f, "]") + }, + _ => unreachable!(), + } +} + +impl Debug for ChunkedArray +where + T: PolarsNumericType, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let dt = format!("{}", T::get_dtype()); + format_array!(f, self, dt, self.name(), "ChunkedArray") + } +} + +impl Debug for ChunkedArray { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + format_array!(f, self, "bool", self.name(), "ChunkedArray") + } +} + +impl Debug for StringChunked { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + format_array!(f, self, "str", self.name(), "ChunkedArray") + } +} + +impl Debug for BinaryChunked { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + format_array!(f, self, "binary", self.name(), "ChunkedArray") + } +} + +impl Debug for ListChunked { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + format_array!(f, self, "list", self.name(), "ChunkedArray") + } +} + +#[cfg(feature = "dtype-array")] +impl Debug for ArrayChunked { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + format_array!(f, self, "fixed size list", self.name(), "ChunkedArray") + } +} + +#[cfg(feature = "object")] +impl Debug for ObjectChunked +where + T: PolarsObject, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let limit = std::cmp::min(DEFAULT_ROW_LIMIT, self.len()); + let ellipsis = get_ellipsis(); + let inner_type = T::type_name(); + write!( + f, + "ChunkedArray: '{}' [o][{}]\n[\n", + self.name(), + inner_type + )?; + + if limit < self.len() { + for i in 0..limit / 2 { + match self.get(i) { + None => writeln!(f, "\tnull")?, + Some(val) => writeln!(f, "\t{val}")?, + }; + } + writeln!(f, "\t{ellipsis}")?; + for i in (0..limit / 2).rev() { + match self.get(self.len() - i - 1) { + None => writeln!(f, "\tnull")?, + Some(val) => writeln!(f, "\t{val}")?, + }; + } + } else { + for i in 0..limit { + match self.get(i) { + None => writeln!(f, "\tnull")?, + Some(val) => writeln!(f, "\t{val}")?, + }; + } + } + Ok(()) + } +} + +impl Debug for Series { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.dtype() { + DataType::Boolean => { + format_array!(f, self.bool().unwrap(), "bool", self.name(), "Series") + }, + DataType::String => { + format_array!(f, self.str().unwrap(), "str", self.name(), "Series") + }, + DataType::UInt8 => { + format_array!(f, self.u8().unwrap(), "u8", self.name(), "Series") + }, + DataType::UInt16 => { + format_array!(f, self.u16().unwrap(), "u16", self.name(), "Series") + }, + DataType::UInt32 => { + format_array!(f, self.u32().unwrap(), "u32", self.name(), "Series") + }, + DataType::UInt64 => { + format_array!(f, self.u64().unwrap(), "u64", self.name(), "Series") + }, + DataType::Int8 => { + format_array!(f, self.i8().unwrap(), "i8", self.name(), "Series") + }, + DataType::Int16 => { + format_array!(f, self.i16().unwrap(), "i16", self.name(), "Series") + }, + DataType::Int32 => { + format_array!(f, self.i32().unwrap(), "i32", self.name(), "Series") + }, + DataType::Int64 => { + format_array!(f, self.i64().unwrap(), "i64", self.name(), "Series") + }, + DataType::Int128 => { + feature_gated!( + "dtype-i128", + format_array!(f, self.i128().unwrap(), "i128", self.name(), "Series") + ) + }, + DataType::Float32 => { + format_array!(f, self.f32().unwrap(), "f32", self.name(), "Series") + }, + DataType::Float64 => { + format_array!(f, self.f64().unwrap(), "f64", self.name(), "Series") + }, + #[cfg(feature = "dtype-date")] + DataType::Date => format_array!(f, self.date().unwrap(), "date", self.name(), "Series"), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => { + let dt = format!("{}", self.dtype()); + format_array!(f, self.datetime().unwrap(), &dt, self.name(), "Series") + }, + #[cfg(feature = "dtype-time")] + DataType::Time => format_array!(f, self.time().unwrap(), "time", self.name(), "Series"), + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => { + let dt = format!("{}", self.dtype()); + format_array!(f, self.duration().unwrap(), &dt, self.name(), "Series") + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + let dt = format!("{}", self.dtype()); + format_array!(f, self.decimal().unwrap(), &dt, self.name(), "Series") + }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let dt = format!("{}", self.dtype()); + format_array!(f, self.array().unwrap(), &dt, self.name(), "Series") + }, + DataType::List(_) => { + let dt = format!("{}", self.dtype()); + format_array!(f, self.list().unwrap(), &dt, self.name(), "Series") + }, + #[cfg(feature = "object")] + DataType::Object(_) => format_object_array(f, self, self.name(), "Series"), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) => { + format_array!(f, self.categorical().unwrap(), "cat", self.name(), "Series") + }, + + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) => format_array!( + f, + self.categorical().unwrap(), + "enum", + self.name(), + "Series" + ), + #[cfg(feature = "dtype-struct")] + dt @ DataType::Struct(_) => format_array!( + f, + self.struct_().unwrap(), + format!("{dt}"), + self.name(), + "Series" + ), + DataType::Null => { + format_array!(f, self.null().unwrap(), "null", self.name(), "Series") + }, + DataType::Binary => { + format_array!(f, self.binary().unwrap(), "binary", self.name(), "Series") + }, + DataType::BinaryOffset => { + format_array!( + f, + self.binary_offset().unwrap(), + "binary[offset]", + self.name(), + "Series" + ) + }, + dt => panic!("{dt:?} not impl"), + } + } +} + +impl Display for Series { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(self, f) + } +} + +impl Debug for DataFrame { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(self, f) + } +} +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn make_str_val(v: &str, truncate: usize, ellipsis: &String) -> String { + let v_trunc = &v[..v + .char_indices() + .take(truncate) + .last() + .map(|(i, c)| i + c.len_utf8()) + .unwrap_or(0)]; + if v == v_trunc { + v.to_string() + } else { + format!("{v_trunc}{ellipsis}") + } +} + +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn field_to_str( + f: &Field, + str_truncate: usize, + ellipsis: &String, + padding: usize, +) -> (String, usize) { + let name = make_str_val(f.name(), str_truncate, ellipsis); + let name_length = estimate_string_width(name.as_str()); + let mut column_name = name; + if env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) { + column_name = "".to_string(); + } + let column_dtype = if env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES) { + "".to_string() + } else if env_is_true(FMT_TABLE_INLINE_COLUMN_DATA_TYPE) + | env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) + { + format!("{}", f.dtype()) + } else { + format!("\n{}", f.dtype()) + }; + let mut dtype_length = column_dtype.trim_start().len(); + let mut separator = "\n---"; + if env_is_true(FMT_TABLE_HIDE_COLUMN_SEPARATOR) + | env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) + | env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES) + { + separator = "" + } + let s = if env_is_true(FMT_TABLE_INLINE_COLUMN_DATA_TYPE) + & !env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES) + { + let inline_name_dtype = format!("{column_name} ({column_dtype})"); + dtype_length = inline_name_dtype.len(); + inline_name_dtype + } else { + format!("{column_name}{separator}{column_dtype}") + }; + let mut s_len = std::cmp::max(name_length, dtype_length); + let separator_length = estimate_string_width(separator.trim()); + if s_len < separator_length { + s_len = separator_length; + } + (s, s_len + padding) +} + +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn prepare_row( + row: Vec>, + n_first: usize, + n_last: usize, + str_truncate: usize, + max_elem_lengths: &mut [usize], + ellipsis: &String, + padding: usize, +) -> Vec { + let reduce_columns = n_first + n_last < row.len(); + let n_elems = n_first + n_last + reduce_columns as usize; + let mut row_strings = Vec::with_capacity(n_elems); + + for (idx, v) in row[0..n_first].iter().enumerate() { + let elem_str = make_str_val(v, str_truncate, ellipsis); + let elem_len = estimate_string_width(elem_str.as_str()) + padding; + if max_elem_lengths[idx] < elem_len { + max_elem_lengths[idx] = elem_len; + }; + row_strings.push(elem_str); + } + if reduce_columns { + row_strings.push(ellipsis.to_string()); + max_elem_lengths[n_first] = ellipsis.chars().count() + padding; + } + let elem_offset = n_first + reduce_columns as usize; + for (idx, v) in row[row.len() - n_last..].iter().enumerate() { + let elem_str = make_str_val(v, str_truncate, ellipsis); + let elem_len = estimate_string_width(elem_str.as_str()) + padding; + let elem_idx = elem_offset + idx; + if max_elem_lengths[elem_idx] < elem_len { + max_elem_lengths[elem_idx] = elem_len; + }; + row_strings.push(elem_str); + } + row_strings +} + +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn env_is_true(varname: &str) -> bool { + std::env::var(varname).as_deref().unwrap_or("0") == "1" +} + +#[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] +fn fmt_df_shape((shape0, shape1): &(usize, usize)) -> String { + // e.g. (1_000_000, 4_000) + format!( + "({}, {})", + fmt_int_string_custom(&shape0.to_string(), 3, "_"), + fmt_int_string_custom(&shape1.to_string(), 3, "_") + ) +} + +impl Display for DataFrame { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + #[cfg(any(feature = "fmt", feature = "fmt_no_tty"))] + { + let height = self.height(); + assert!( + self.columns.iter().all(|s| s.len() == height), + "The column lengths in the DataFrame are not equal." + ); + + let table_style = std::env::var(FMT_TABLE_FORMATTING).unwrap_or("DEFAULT".to_string()); + let is_utf8 = !table_style.starts_with("ASCII"); + let preset = match table_style.as_str() { + "ASCII_FULL" => ASCII_FULL, + "ASCII_FULL_CONDENSED" => ASCII_FULL_CONDENSED, + "ASCII_NO_BORDERS" => ASCII_NO_BORDERS, + "ASCII_BORDERS_ONLY" => ASCII_BORDERS_ONLY, + "ASCII_BORDERS_ONLY_CONDENSED" => ASCII_BORDERS_ONLY_CONDENSED, + "ASCII_HORIZONTAL_ONLY" => ASCII_HORIZONTAL_ONLY, + "ASCII_MARKDOWN" | "MARKDOWN" => ASCII_MARKDOWN, + "UTF8_FULL" => UTF8_FULL, + "UTF8_FULL_CONDENSED" => UTF8_FULL_CONDENSED, + "UTF8_NO_BORDERS" => UTF8_NO_BORDERS, + "UTF8_BORDERS_ONLY" => UTF8_BORDERS_ONLY, + "UTF8_HORIZONTAL_ONLY" => UTF8_HORIZONTAL_ONLY, + "NOTHING" => NOTHING, + _ => UTF8_FULL_CONDENSED, + }; + let ellipsis = get_ellipsis().to_string(); + let ellipsis_len = ellipsis.chars().count(); + let max_n_cols = get_col_limit(); + let max_n_rows = get_row_limit(); + let str_truncate = get_str_len_limit(); + let padding = 2; // eg: one char either side of the value + + let (n_first, n_last) = if self.width() > max_n_cols { + (max_n_cols.div_ceil(2), max_n_cols / 2) + } else { + (self.width(), 0) + }; + let reduce_columns = n_first + n_last < self.width(); + let n_tbl_cols = n_first + n_last + reduce_columns as usize; + let mut names = Vec::with_capacity(n_tbl_cols); + let mut name_lengths = Vec::with_capacity(n_tbl_cols); + + let fields = self.fields(); + for field in fields[0..n_first].iter() { + let (s, l) = field_to_str(field, str_truncate, &ellipsis, padding); + names.push(s); + name_lengths.push(l); + } + if reduce_columns { + names.push(ellipsis.clone()); + name_lengths.push(ellipsis_len); + } + for field in fields[self.width() - n_last..].iter() { + let (s, l) = field_to_str(field, str_truncate, &ellipsis, padding); + names.push(s); + name_lengths.push(l); + } + + let mut table = Table::new(); + table + .load_preset(preset) + .set_content_arrangement(ContentArrangement::Dynamic); + + if is_utf8 && env_is_true(FMT_TABLE_ROUNDED_CORNERS) { + table.apply_modifier(UTF8_ROUND_CORNERS); + } + let mut constraints = Vec::with_capacity(n_tbl_cols); + let mut max_elem_lengths: Vec = vec![0; n_tbl_cols]; + + if max_n_rows > 0 { + if height > max_n_rows { + // Truncate the table if we have more rows than the + // configured maximum number of rows + let mut rows = Vec::with_capacity(std::cmp::max(max_n_rows, 2)); + let half = max_n_rows / 2; + let rest = max_n_rows % 2; + + for i in 0..(half + rest) { + let row = self + .get_columns() + .iter() + .map(|c| c.str_value(i).unwrap()) + .collect(); + + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); + rows.push(row_strings); + } + let dots = vec![ellipsis.clone(); rows[0].len()]; + rows.push(dots); + + for i in (height - half)..height { + let row = self + .get_columns() + .iter() + .map(|c| c.str_value(i).unwrap()) + .collect(); + + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); + rows.push(row_strings); + } + table.add_rows(rows); + } else { + for i in 0..height { + if self.width() > 0 { + let row = self + .materialized_column_iter() + .map(|s| s.str_value(i).unwrap()) + .collect(); + + let row_strings = prepare_row( + row, + n_first, + n_last, + str_truncate, + &mut max_elem_lengths, + &ellipsis, + padding, + ); + table.add_row(row_strings); + } else { + break; + } + } + } + } else if height > 0 { + let dots: Vec = vec![ellipsis.clone(); self.columns.len()]; + table.add_row(dots); + } + let tbl_fallback_width = 100; + let tbl_width = std::env::var("POLARS_TABLE_WIDTH") + .map(|s| { + let n = s + .parse::() + .expect("could not parse table width argument"); + let w = if n < 0 { + u16::MAX + } else { + u16::try_from(n).expect("table width argument does not fit in u16") + }; + Some(w) + }) + .unwrap_or(None); + + // column width constraints + let col_width_exact = + |w: usize| ColumnConstraint::Absolute(comfy_table::Width::Fixed(w as u16)); + let col_width_bounds = |l: usize, u: usize| ColumnConstraint::Boundaries { + lower: Width::Fixed(l as u16), + upper: Width::Fixed(u as u16), + }; + let min_col_width = std::cmp::max(5, 3 + padding); + for (idx, elem_len) in max_elem_lengths.iter().enumerate() { + let mx = std::cmp::min( + str_truncate + ellipsis_len + padding, + std::cmp::max(name_lengths[idx], *elem_len), + ); + if (mx <= min_col_width) && !(max_n_rows > 0 && height > max_n_rows) { + // col width is less than min width + table is not truncated + constraints.push(col_width_exact(mx)); + } else if mx <= min_col_width { + // col width is less than min width + table is truncated (w/ ellipsis) + constraints.push(col_width_bounds(mx, min_col_width)); + } else { + constraints.push(col_width_bounds(min_col_width, mx)); + } + } + + // insert a header row, unless both column names and dtypes are hidden + if !(env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) + && env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES)) + { + table.set_header(names).set_constraints(constraints); + } + + // if tbl_width is explicitly set, use it + if let Some(w) = tbl_width { + table.set_width(w); + } else { + // if no tbl_width (it's not tty && width not explicitly set), apply + // a default value; this is needed to support non-tty applications + #[cfg(feature = "fmt")] + if table.width().is_none() && !table.is_tty() { + table.set_width(tbl_fallback_width); + } + #[cfg(feature = "fmt_no_tty")] + if table.width().is_none() { + table.set_width(tbl_fallback_width); + } + } + + // set alignment of cells, if defined + if std::env::var(FMT_TABLE_CELL_ALIGNMENT).is_ok() + | std::env::var(FMT_TABLE_CELL_NUMERIC_ALIGNMENT).is_ok() + { + let str_preset = std::env::var(FMT_TABLE_CELL_ALIGNMENT) + .unwrap_or_else(|_| "DEFAULT".to_string()); + let num_preset = std::env::var(FMT_TABLE_CELL_NUMERIC_ALIGNMENT) + .unwrap_or_else(|_| str_preset.to_string()); + for (column_index, column) in table.column_iter_mut().enumerate() { + let dtype = fields[column_index].dtype(); + let mut preset = str_preset.as_str(); + if dtype.is_primitive_numeric() || dtype.is_decimal() { + preset = num_preset.as_str(); + } + match preset { + "RIGHT" => column.set_cell_alignment(CellAlignment::Right), + "LEFT" => column.set_cell_alignment(CellAlignment::Left), + "CENTER" => column.set_cell_alignment(CellAlignment::Center), + _ => {}, + } + } + } + + // establish 'shape' information (above/below/hidden) + if env_is_true(FMT_TABLE_HIDE_DATAFRAME_SHAPE_INFORMATION) { + write!(f, "{table}")?; + } else { + let shape_str = fmt_df_shape(&self.shape()); + if env_is_true(FMT_TABLE_DATAFRAME_SHAPE_BELOW) { + write!(f, "{table}\nshape: {}", shape_str)?; + } else { + write!(f, "shape: {}\n{}", shape_str, table)?; + } + } + } + #[cfg(not(any(feature = "fmt", feature = "fmt_no_tty")))] + { + write!( + f, + "shape: {:?}\nto see more, compile with the 'fmt' or 'fmt_no_tty' feature", + self.shape() + )?; + } + Ok(()) + } +} + +fn fmt_int_string_custom(num: &str, group_size: u8, group_separator: &str) -> String { + if group_size == 0 || num.len() <= 1 { + num.to_string() + } else { + let mut out = String::new(); + let sign_offset = if num.starts_with('-') || num.starts_with('+') { + out.push(num.chars().next().unwrap()); + 1 + } else { + 0 + }; + let int_body = &num.as_bytes()[sign_offset..] + .rchunks(group_size as usize) + .rev() + .map(str::from_utf8) + .collect::, _>>() + .unwrap() + .join(group_separator); + out.push_str(int_body); + out + } +} + +fn fmt_int_string(num: &str) -> String { + fmt_int_string_custom(num, 3, &get_thousands_separator()) +} + +fn fmt_float_string_custom( + num: &str, + group_size: u8, + group_separator: &str, + decimal: char, +) -> String { + // Quick exit if no formatting would be applied + if num.len() <= 1 || (group_size == 0 && decimal == '.') { + num.to_string() + } else { + // Take existing numeric string and apply digit grouping & separator/decimal chars + // e.g. "1000000" → "1_000_000", "-123456.798" → "-123,456.789", etc + let (idx, has_fractional) = match num.find('.') { + Some(i) => (i, true), + None => (num.len(), false), + }; + let mut out = String::new(); + let integer_part = &num[..idx]; + + out.push_str(&fmt_int_string_custom( + integer_part, + group_size, + group_separator, + )); + if has_fractional { + out.push(decimal); + out.push_str(&num[idx + 1..]); + }; + out + } +} + +fn fmt_float_string(num: &str) -> String { + fmt_float_string_custom(num, 3, &get_thousands_separator(), get_decimal_separator()) +} + +fn fmt_integer( + f: &mut Formatter<'_>, + width: usize, + v: T, +) -> fmt::Result { + write!(f, "{:>width$}", fmt_int_string(&v.to_string())) +} + +const SCIENTIFIC_BOUND: f64 = 999999.0; + +fn fmt_float(f: &mut Formatter<'_>, width: usize, v: T) -> fmt::Result { + let v: f64 = NumCast::from(v).unwrap(); + + let float_precision = get_float_precision(); + + if let Some(precision) = float_precision { + if format!("{v:.precision$}", precision = precision).len() > 19 { + return write!(f, "{v:>width$.precision$e}", precision = precision); + } + let s = format!("{v:>width$.precision$}", precision = precision); + return write!(f, "{}", fmt_float_string(s.as_str())); + } + + if matches!(get_float_fmt(), FloatFmt::Full) { + let s = format!("{v:>width$}"); + return write!(f, "{}", fmt_float_string(s.as_str())); + } + + // show integers as 0.0, 1.0 ... 101.0 + if v.fract() == 0.0 && v.abs() < SCIENTIFIC_BOUND { + let s = format!("{v:>width$.1}"); + write!(f, "{}", fmt_float_string(s.as_str())) + } else if format!("{v}").len() > 9 { + // large and small floats in scientific notation. + // (note: scientific notation does not play well with digit grouping) + if (!(0.000001..=SCIENTIFIC_BOUND).contains(&v.abs()) | (v.abs() > SCIENTIFIC_BOUND)) + && get_thousands_separator().is_empty() + { + let s = format!("{v:>width$.4e}"); + write!(f, "{}", fmt_float_string(s.as_str())) + } else { + // this makes sure we don't write 12.00000 in case of a long flt that is 12.0000000001 + // instead we write 12.0 + let s = format!("{v:>width$.6}"); + + if s.ends_with('0') { + let mut s = s.as_str(); + let mut len = s.len() - 1; + + while s.ends_with('0') { + s = &s[..len]; + len -= 1; + } + let s = if s.ends_with('.') { + format!("{s}0") + } else { + s.to_string() + }; + write!(f, "{}", fmt_float_string(s.as_str())) + } else { + // 12.0934509341243124 + // written as + // 12.09345 + let s = format!("{v:>width$.6}"); + write!(f, "{}", fmt_float_string(s.as_str())) + } + } + } else { + let s = if v.fract() == 0.0 { + format!("{v:>width$e}") + } else { + format!("{v:>width$}") + }; + write!(f, "{}", fmt_float_string(s.as_str())) + } +} + +#[cfg(feature = "dtype-datetime")] +fn fmt_datetime( + f: &mut Formatter<'_>, + v: i64, + tu: TimeUnit, + tz: Option<&self::datatypes::TimeZone>, +) -> fmt::Result { + let ndt = match tu { + TimeUnit::Nanoseconds => timestamp_ns_to_datetime(v), + TimeUnit::Microseconds => timestamp_us_to_datetime(v), + TimeUnit::Milliseconds => timestamp_ms_to_datetime(v), + }; + match tz { + None => std::fmt::Display::fmt(&ndt, f), + Some(tz) => PlTzAware::new(ndt, tz).fmt(f), + } +} + +#[cfg(feature = "dtype-duration")] +const DURATION_PARTS: [&str; 4] = ["d", "h", "m", "s"]; +#[cfg(feature = "dtype-duration")] +const ISO_DURATION_PARTS: [&str; 4] = ["D", "H", "M", "S"]; +#[cfg(feature = "dtype-duration")] +const SIZES_NS: [i64; 4] = [ + 86_400_000_000_000, // per day + 3_600_000_000_000, // per hour + 60_000_000_000, // per minute + 1_000_000_000, // per second +]; +#[cfg(feature = "dtype-duration")] +const SIZES_US: [i64; 4] = [86_400_000_000, 3_600_000_000, 60_000_000, 1_000_000]; +#[cfg(feature = "dtype-duration")] +const SIZES_MS: [i64; 4] = [86_400_000, 3_600_000, 60_000, 1_000]; + +#[cfg(feature = "dtype-duration")] +pub fn fmt_duration_string(f: &mut W, v: i64, unit: TimeUnit) -> fmt::Result { + // take the physical/integer duration value and return a + // friendly/readable duration string, eg: "3d 22m 55s 1ms" + if v == 0 { + return match unit { + TimeUnit::Nanoseconds => f.write_str("0ns"), + TimeUnit::Microseconds => f.write_str("0µs"), + TimeUnit::Milliseconds => f.write_str("0ms"), + }; + }; + // iterate over dtype-specific sizes to appropriately scale + // and extract 'days', 'hours', 'minutes', and 'seconds' parts. + let sizes = match unit { + TimeUnit::Nanoseconds => SIZES_NS.as_slice(), + TimeUnit::Microseconds => SIZES_US.as_slice(), + TimeUnit::Milliseconds => SIZES_MS.as_slice(), + }; + let mut buffer = itoa::Buffer::new(); + for (i, &size) in sizes.iter().enumerate() { + let whole_num = if i == 0 { + v / size + } else { + (v % sizes[i - 1]) / size + }; + if whole_num != 0 { + f.write_str(buffer.format(whole_num))?; + f.write_str(DURATION_PARTS[i])?; + if v % size != 0 { + f.write_char(' ')?; + } + } + } + // write fractional seconds as integer nano/micro/milliseconds. + let (v, units) = match unit { + TimeUnit::Nanoseconds => (v % 1_000_000_000, ["ns", "µs", "ms"]), + TimeUnit::Microseconds => (v % 1_000_000, ["µs", "ms", ""]), + TimeUnit::Milliseconds => (v % 1_000, ["ms", "", ""]), + }; + if v != 0 { + let (value, suffix) = if v % 1_000 != 0 { + (v, units[0]) + } else if v % 1_000_000 != 0 { + (v / 1_000, units[1]) + } else { + (v / 1_000_000, units[2]) + }; + f.write_str(buffer.format(value))?; + f.write_str(suffix)?; + } + Ok(()) +} + +#[cfg(feature = "dtype-duration")] +pub fn iso_duration_string(s: &mut String, mut v: i64, unit: TimeUnit) { + if v == 0 { + s.push_str("PT0S"); + return; + } + let mut buffer = itoa::Buffer::new(); + let mut wrote_part = false; + if v < 0 { + // negative sign before "P" indicates entire ISO duration is negative. + s.push_str("-P"); + v = v.abs(); + } else { + s.push('P'); + } + // iterate over dtype-specific sizes to appropriately scale + // and extract 'days', 'hours', 'minutes', and 'seconds' parts. + let sizes = match unit { + TimeUnit::Nanoseconds => SIZES_NS.as_slice(), + TimeUnit::Microseconds => SIZES_US.as_slice(), + TimeUnit::Milliseconds => SIZES_MS.as_slice(), + }; + for (i, &size) in sizes.iter().enumerate() { + let whole_num = if i == 0 { + v / size + } else { + (v % sizes[i - 1]) / size + }; + if whole_num != 0 || i == 3 { + if i != 3 { + // days, hours, minutes + s.push_str(buffer.format(whole_num)); + s.push_str(ISO_DURATION_PARTS[i]); + } else { + // (index 3 => 'seconds' part): the ISO version writes + // fractional seconds, not integer nano/micro/milliseconds. + // if zero, only write out if no other parts written yet. + let fractional_part = v % size; + if whole_num == 0 && fractional_part == 0 { + if !wrote_part { + s.push_str("0S") + } + } else { + s.push_str(buffer.format(whole_num)); + if fractional_part != 0 { + let secs = match unit { + TimeUnit::Nanoseconds => format!(".{:09}", fractional_part), + TimeUnit::Microseconds => format!(".{:06}", fractional_part), + TimeUnit::Milliseconds => format!(".{:03}", fractional_part), + }; + s.push_str(secs.trim_end_matches('0')); + } + s.push_str(ISO_DURATION_PARTS[i]); + } + } + // (index 0 => 'days' part): after writing days above (if non-zero) + // the ISO duration string requires a `T` before the time part. + if i == 0 { + s.push('T'); + } + wrote_part = true; + } else if i == 0 { + // always need to write the `T` separator for ISO + // durations, even if there is no 'days' part. + s.push('T'); + } + } + // if there was only a 'days' component, no need for time separator. + if s.ends_with('T') { + s.pop(); + } +} + +fn format_blob(f: &mut Formatter<'_>, bytes: &[u8]) -> fmt::Result { + let ellipsis = get_ellipsis(); + let width = get_str_len_limit() * 2; + write!(f, "b\"")?; + + for b in bytes.iter().take(width) { + if b.is_ascii_alphanumeric() || b.is_ascii_punctuation() { + write!(f, "{}", *b as char)?; + } else { + write!(f, "\\x{:02x}", b)?; + } + } + if bytes.len() > width { + write!(f, "\"{ellipsis}")?; + } else { + f.write_str("\"")?; + } + Ok(()) +} + +impl Display for AnyValue<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let width = 0; + match self { + AnyValue::Null => write!(f, "null"), + AnyValue::UInt8(v) => fmt_integer(f, width, *v), + AnyValue::UInt16(v) => fmt_integer(f, width, *v), + AnyValue::UInt32(v) => fmt_integer(f, width, *v), + AnyValue::UInt64(v) => fmt_integer(f, width, *v), + AnyValue::Int8(v) => fmt_integer(f, width, *v), + AnyValue::Int16(v) => fmt_integer(f, width, *v), + AnyValue::Int32(v) => fmt_integer(f, width, *v), + AnyValue::Int64(v) => fmt_integer(f, width, *v), + AnyValue::Int128(v) => feature_gated!("dtype-i128", fmt_integer(f, width, *v)), + AnyValue::Float32(v) => fmt_float(f, width, *v), + AnyValue::Float64(v) => fmt_float(f, width, *v), + AnyValue::Boolean(v) => write!(f, "{}", *v), + AnyValue::String(v) => write!(f, "{}", format_args!("\"{v}\"")), + AnyValue::StringOwned(v) => write!(f, "{}", format_args!("\"{v}\"")), + AnyValue::Binary(d) => format_blob(f, d), + AnyValue::BinaryOwned(d) => format_blob(f, d), + #[cfg(feature = "dtype-date")] + AnyValue::Date(v) => write!(f, "{}", date32_to_date(*v)), + #[cfg(feature = "dtype-datetime")] + AnyValue::Datetime(v, tu, tz) => fmt_datetime(f, *v, *tu, *tz), + #[cfg(feature = "dtype-datetime")] + AnyValue::DatetimeOwned(v, tu, tz) => { + fmt_datetime(f, *v, *tu, tz.as_ref().map(|v| v.as_ref())) + }, + #[cfg(feature = "dtype-duration")] + AnyValue::Duration(v, tu) => fmt_duration_string(f, *v, *tu), + #[cfg(feature = "dtype-time")] + AnyValue::Time(_) => { + let nt: chrono::NaiveTime = self.into(); + write!(f, "{nt}") + }, + #[cfg(feature = "dtype-categorical")] + AnyValue::Categorical(_, _, _) + | AnyValue::CategoricalOwned(_, _, _) + | AnyValue::Enum(_, _, _) + | AnyValue::EnumOwned(_, _, _) => { + let s = self.get_str().unwrap(); + write!(f, "\"{s}\"") + }, + #[cfg(feature = "dtype-array")] + AnyValue::Array(s, _size) => write!(f, "{}", s.fmt_list()), + AnyValue::List(s) => write!(f, "{}", s.fmt_list()), + #[cfg(feature = "object")] + AnyValue::Object(v) => write!(f, "{v}"), + #[cfg(feature = "object")] + AnyValue::ObjectOwned(v) => write!(f, "{}", v.0.as_ref()), + #[cfg(feature = "dtype-struct")] + av @ AnyValue::Struct(_, _, _) => { + let mut avs = vec![]; + av._materialize_struct_av(&mut avs); + fmt_struct(f, &avs) + }, + #[cfg(feature = "dtype-struct")] + AnyValue::StructOwned(payload) => fmt_struct(f, &payload.0), + #[cfg(feature = "dtype-decimal")] + AnyValue::Decimal(v, scale) => fmt_decimal(f, *v, *scale), + } + } +} + +/// Utility struct to format a timezone aware datetime. +#[allow(dead_code)] +#[cfg(feature = "dtype-datetime")] +pub struct PlTzAware<'a> { + ndt: NaiveDateTime, + tz: &'a str, +} +#[cfg(feature = "dtype-datetime")] +impl<'a> PlTzAware<'a> { + pub fn new(ndt: NaiveDateTime, tz: &'a str) -> Self { + Self { ndt, tz } + } +} + +#[cfg(feature = "dtype-datetime")] +impl Display for PlTzAware<'_> { + #[allow(unused_variables)] + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + #[cfg(feature = "timezones")] + match self.tz.parse::() { + Ok(tz) => { + let dt_utc = chrono::Utc.from_local_datetime(&self.ndt).unwrap(); + let dt_tz_aware = dt_utc.with_timezone(&tz); + write!(f, "{dt_tz_aware}") + }, + Err(_) => write!(f, "invalid timezone"), + } + #[cfg(not(feature = "timezones"))] + { + panic!("activate 'timezones' feature") + } + } +} + +#[cfg(feature = "dtype-struct")] +fn fmt_struct(f: &mut Formatter<'_>, vals: &[AnyValue]) -> fmt::Result { + write!(f, "{{")?; + if !vals.is_empty() { + for v in &vals[..vals.len() - 1] { + write!(f, "{v},")?; + } + // last value has no trailing comma + write!(f, "{}", vals[vals.len() - 1])?; + } + write!(f, "}}") +} + +impl Series { + pub fn fmt_list(&self) -> String { + if self.is_empty() { + return "[]".to_owned(); + } + let mut result = "[".to_owned(); + let max_items = get_list_len_limit(); + let ellipsis = get_ellipsis(); + + match max_items { + 0 => write!(result, "{ellipsis}]").unwrap(), + _ if max_items >= self.len() => { + // this will always leave a trailing ", " after the last item + // but for long lists, this is faster than checking against the length each time + for item in self.rechunk().iter() { + write!(result, "{item}, ").unwrap(); + } + // remove trailing ", " and replace with closing brace + result.truncate(result.len() - 2); + result.push(']'); + }, + _ => { + let s = self.slice(0, max_items).rechunk(); + for (i, item) in s.iter().enumerate() { + if i == max_items.saturating_sub(1) { + write!(result, "{ellipsis} {}", self.get(self.len() - 1).unwrap()).unwrap(); + break; + } else { + write!(result, "{item}, ").unwrap(); + } + } + result.push(']'); + }, + }; + result + } +} + +#[inline] +#[cfg(feature = "dtype-decimal")] +fn fmt_decimal(f: &mut Formatter<'_>, v: i128, scale: usize) -> fmt::Result { + let mut fmt_buf = arrow::compute::decimal::DecimalFmtBuffer::new(); + let trim_zeros = get_trim_decimal_zeros(); + f.write_str(fmt_float_string(fmt_buf.format(v, scale, trim_zeros)).as_str()) +} + +#[cfg(all( + test, + feature = "temporal", + feature = "dtype-date", + feature = "dtype-datetime" +))] +#[allow(unsafe_op_in_unsafe_fn)] +mod test { + use crate::prelude::*; + + #[test] + fn test_fmt_list() { + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); + builder.append_opt_slice(Some(&[1, 2, 3, 4, 5, 6])); + builder.append_opt_slice(None); + let list_long = builder.finish().into_series(); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1, 2, … 6] + null +]"#, + format!("{:?}", list_long) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "10") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1, 2, 3, 4, 5, 6] + null +]"#, + format!("{:?}", list_long) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "-1") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1, 2, 3, 4, 5, 6] + null +]"#, + format!("{:?}", list_long) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "0") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + […] + null +]"#, + format!("{:?}", list_long) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "1") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [… 6] + null +]"#, + format!("{:?}", list_long) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "4") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1, 2, 3, … 6] + null +]"#, + format!("{:?}", list_long) + ); + + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); + builder.append_opt_slice(Some(&[1])); + builder.append_opt_slice(None); + let list_short = builder.finish().into_series(); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1] + null +]"#, + format!("{:?}", list_short) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "0") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + […] + null +]"#, + format!("{:?}", list_short) + ); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "-1") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1] + null +]"#, + format!("{:?}", list_short) + ); + + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); + builder.append_opt_slice(Some(&[])); + builder.append_opt_slice(None); + let list_empty = builder.finish().into_series(); + + unsafe { std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "") }; + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [] + null +]"#, + format!("{:?}", list_empty) + ); + } + + #[test] + fn test_fmt_temporal() { + let s = Int32Chunked::new(PlSmallStr::from_static("Date"), &[Some(1), None, Some(3)]) + .into_date(); + assert_eq!( + r#"shape: (3,) +Series: 'Date' [date] +[ + 1970-01-02 + null + 1970-01-04 +]"#, + format!("{:?}", s.into_series()) + ); + + let s = Int64Chunked::new(PlSmallStr::EMPTY, &[Some(1), None, Some(1_000_000_000_000)]) + .into_datetime(TimeUnit::Nanoseconds, None); + assert_eq!( + r#"shape: (3,) +Series: '' [datetime[ns]] +[ + 1970-01-01 00:00:00.000000001 + null + 1970-01-01 00:16:40 +]"#, + format!("{:?}", s.into_series()) + ); + } + + #[test] + fn test_fmt_chunkedarray() { + let ca = Int32Chunked::new(PlSmallStr::from_static("Date"), &[Some(1), None, Some(3)]); + assert_eq!( + r#"shape: (3,) +ChunkedArray: 'Date' [i32] +[ + 1 + null + 3 +]"#, + format!("{:?}", ca) + ); + let ca = StringChunked::new(PlSmallStr::from_static("name"), &["a", "b"]); + assert_eq!( + r#"shape: (2,) +ChunkedArray: 'name' [str] +[ + "a" + "b" +]"#, + format!("{:?}", ca) + ); + } +} diff --git a/crates/polars-core/src/frame/arithmetic.rs b/crates/polars-core/src/frame/arithmetic.rs new file mode 100644 index 000000000000..1bc593a8f87c --- /dev/null +++ b/crates/polars-core/src/frame/arithmetic.rs @@ -0,0 +1,203 @@ +use std::ops::{Add, Div, Mul, Rem, Sub}; + +use rayon::prelude::*; + +use crate::POOL; +use crate::prelude::*; +use crate::utils::try_get_supertype; + +/// Get the supertype that is valid for all columns in the [`DataFrame`]. +/// This reduces casting of the rhs in arithmetic. +fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult { + df.columns.iter().try_fold(rhs.dtype().clone(), |dt, s| { + try_get_supertype(s.dtype(), &dt) + }) +} + +macro_rules! impl_arithmetic { + ($self:expr, $rhs:expr, $operand:expr) => {{ + let st = get_supertype_all($self, $rhs)?; + let rhs = $rhs.cast(&st)?; + let cols = POOL.install(|| { + $self + .par_materialized_column_iter() + .map(|s| $operand(&s.cast(&st)?, &rhs)) + .map(|s| s.map(Column::from)) + .collect::>() + })?; + Ok(unsafe { DataFrame::new_no_checks($self.height(), cols) }) + }}; +} + +impl Add<&Series> for &DataFrame { + type Output = PolarsResult; + + fn add(self, rhs: &Series) -> Self::Output { + impl_arithmetic!(self, rhs, std::ops::Add::add) + } +} + +impl Add<&Series> for DataFrame { + type Output = PolarsResult; + + fn add(self, rhs: &Series) -> Self::Output { + (&self).add(rhs) + } +} + +impl Sub<&Series> for &DataFrame { + type Output = PolarsResult; + + fn sub(self, rhs: &Series) -> Self::Output { + impl_arithmetic!(self, rhs, std::ops::Sub::sub) + } +} + +impl Sub<&Series> for DataFrame { + type Output = PolarsResult; + + fn sub(self, rhs: &Series) -> Self::Output { + (&self).sub(rhs) + } +} + +impl Mul<&Series> for &DataFrame { + type Output = PolarsResult; + + fn mul(self, rhs: &Series) -> Self::Output { + impl_arithmetic!(self, rhs, std::ops::Mul::mul) + } +} + +impl Mul<&Series> for DataFrame { + type Output = PolarsResult; + + fn mul(self, rhs: &Series) -> Self::Output { + (&self).mul(rhs) + } +} + +impl Div<&Series> for &DataFrame { + type Output = PolarsResult; + + fn div(self, rhs: &Series) -> Self::Output { + impl_arithmetic!(self, rhs, std::ops::Div::div) + } +} + +impl Div<&Series> for DataFrame { + type Output = PolarsResult; + + fn div(self, rhs: &Series) -> Self::Output { + (&self).div(rhs) + } +} + +impl Rem<&Series> for &DataFrame { + type Output = PolarsResult; + + fn rem(self, rhs: &Series) -> Self::Output { + impl_arithmetic!(self, rhs, std::ops::Rem::rem) + } +} + +impl Rem<&Series> for DataFrame { + type Output = PolarsResult; + + fn rem(self, rhs: &Series) -> Self::Output { + (&self).rem(rhs) + } +} + +impl DataFrame { + fn binary_aligned( + &self, + other: &DataFrame, + f: &(dyn Fn(&Series, &Series) -> PolarsResult + Sync + Send), + ) -> PolarsResult { + let max_len = std::cmp::max(self.height(), other.height()); + let max_width = std::cmp::max(self.width(), other.width()); + let cols = self + .get_columns() + .par_iter() + .zip(other.get_columns().par_iter()) + .map(|(l, r)| { + let l = l.as_materialized_series(); + let r = r.as_materialized_series(); + + let diff_l = max_len - l.len(); + let diff_r = max_len - r.len(); + + let st = try_get_supertype(l.dtype(), r.dtype())?; + let mut l = l.cast(&st)?; + let mut r = r.cast(&st)?; + + if diff_l > 0 { + l = l.extend_constant(AnyValue::Null, diff_l)?; + }; + if diff_r > 0 { + r = r.extend_constant(AnyValue::Null, diff_r)?; + }; + + f(&l, &r).map(Column::from) + }); + let mut cols = POOL.install(|| cols.collect::>>())?; + + let col_len = cols.len(); + if col_len < max_width { + let df = if col_len < self.width() { self } else { other }; + + for i in col_len..max_len { + let s = &df.get_columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?; + let name = s.name(); + let dtype = s.dtype(); + + // trick to fill a series with nulls + let vals: &[Option] = &[None]; + let s = Series::new(name.clone(), vals).cast(dtype)?; + cols.push(s.new_from_index(0, max_len).into()) + } + } + DataFrame::new(cols) + } +} + +impl Add<&DataFrame> for &DataFrame { + type Output = PolarsResult; + + fn add(self, rhs: &DataFrame) -> Self::Output { + self.binary_aligned(rhs, &|a, b| a + b) + } +} + +impl Sub<&DataFrame> for &DataFrame { + type Output = PolarsResult; + + fn sub(self, rhs: &DataFrame) -> Self::Output { + self.binary_aligned(rhs, &|a, b| a - b) + } +} + +impl Div<&DataFrame> for &DataFrame { + type Output = PolarsResult; + + fn div(self, rhs: &DataFrame) -> Self::Output { + self.binary_aligned(rhs, &|a, b| a / b) + } +} + +impl Mul<&DataFrame> for &DataFrame { + type Output = PolarsResult; + + fn mul(self, rhs: &DataFrame) -> Self::Output { + self.binary_aligned(rhs, &|a, b| a * b) + } +} + +impl Rem<&DataFrame> for &DataFrame { + type Output = PolarsResult; + + fn rem(self, rhs: &DataFrame) -> Self::Output { + self.binary_aligned(rhs, &|a, b| a % b) + } +} diff --git a/crates/polars-core/src/frame/builder.rs b/crates/polars-core/src/frame/builder.rs new file mode 100644 index 000000000000..43c701d74325 --- /dev/null +++ b/crates/polars-core/src/frame/builder.rs @@ -0,0 +1,189 @@ +use std::sync::Arc; + +use arrow::array::builder::ShareStrategy; +use polars_utils::IdxSize; + +use crate::frame::DataFrame; +use crate::prelude::*; +use crate::schema::Schema; +use crate::series::builder::SeriesBuilder; + +pub struct DataFrameBuilder { + schema: Arc, + builders: Vec, + height: usize, +} + +impl DataFrameBuilder { + pub fn new(schema: Arc) -> Self { + let builders = schema + .iter_values() + .map(|dt| SeriesBuilder::new(dt.clone())) + .collect(); + Self { + schema, + builders, + height: 0, + } + } + + pub fn reserve(&mut self, additional: usize) { + for builder in &mut self.builders { + builder.reserve(additional); + } + } + + pub fn freeze(self) -> DataFrame { + let columns = self + .schema + .iter_names() + .zip(self.builders) + .map(|(n, b)| { + let s = b.freeze(n.clone()); + assert!(s.len() == self.height); + Column::from(s) + }) + .collect(); + + // SAFETY: we checked the lengths and the names are unique because they + // come from Schema. + unsafe { DataFrame::new_no_checks(self.height, columns) } + } + + pub fn freeze_reset(&mut self) -> DataFrame { + let columns = self + .schema + .iter_names() + .zip(&mut self.builders) + .map(|(n, b)| { + let s = b.freeze_reset(n.clone()); + assert!(s.len() == self.height); + Column::from(s) + }) + .collect(); + + // SAFETY: we checked the lengths and the names are unique because they + // come from Schema. + let out = unsafe { DataFrame::new_no_checks(self.height, columns) }; + self.height = 0; + out + } + + pub fn len(&self) -> usize { + self.height + } + + pub fn is_empty(&self) -> bool { + self.height == 0 + } + + /// Extends this builder with the contents of the given dataframe. May panic + /// if other does not match the schema of this builder. + pub fn extend(&mut self, other: &DataFrame, share: ShareStrategy) { + self.subslice_extend(other, 0, other.height(), share); + self.height += other.height(); + } + + /// Extends this builder with the contents of the given dataframe subslice. + /// May panic if other does not match the schema of this builder. + pub fn subslice_extend( + &mut self, + other: &DataFrame, + start: usize, + length: usize, + share: ShareStrategy, + ) { + let columns = other.get_columns(); + assert!(self.builders.len() == columns.len()); + for (builder, column) in self.builders.iter_mut().zip(columns) { + match column { + Column::Series(s) => { + builder.subslice_extend(s, start, length, share); + }, + Column::Partitioned(p) => { + // @scalar-opt + builder.subslice_extend(p.as_materialized_series(), start, length, share); + }, + Column::Scalar(sc) => { + let len = sc.len().saturating_sub(start).min(length); + let scalar_as_series = sc.scalar().clone().into_series(PlSmallStr::default()); + builder.subslice_extend_repeated(&scalar_as_series, 0, 1, len, share); + }, + } + } + + self.height += length.min(other.height().saturating_sub(start)); + } + + /// Extends this builder with the contents of the given dataframe at the given + /// indices. That is, `other[idxs[i]]` is appended to this builder in order, + /// for each i=0..idxs.len(). May panic if other does not match the schema + /// of this builder, or if the other dataframe is not rechunked. + /// + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend( + &mut self, + other: &DataFrame, + idxs: &[IdxSize], + share: ShareStrategy, + ) { + let columns = other.get_columns(); + assert!(self.builders.len() == columns.len()); + for (builder, column) in self.builders.iter_mut().zip(columns) { + match column { + Column::Series(s) => { + builder.gather_extend(s, idxs, share); + }, + Column::Partitioned(p) => { + // @scalar-opt + builder.gather_extend(p.as_materialized_series(), idxs, share); + }, + Column::Scalar(sc) => { + let scalar_as_series = sc.scalar().clone().into_series(PlSmallStr::default()); + builder.subslice_extend_repeated(&scalar_as_series, 0, 1, idxs.len(), share); + }, + } + } + + self.height += idxs.len(); + } + + /// Extends this builder with the contents of the given dataframe at the given + /// indices. That is, `other[idxs[i]]` is appended to this builder in order, + /// for each i=0..idxs.len(). Out-of-bounds indices extend with nulls. + /// May panic if other does not match the schema of this builder, or if the + /// other dataframe is not rechunked. + pub fn opt_gather_extend(&mut self, other: &DataFrame, idxs: &[IdxSize], share: ShareStrategy) { + let mut trans_idxs = Vec::new(); + let columns = other.get_columns(); + assert!(self.builders.len() == columns.len()); + for (builder, column) in self.builders.iter_mut().zip(columns) { + match column { + Column::Series(s) => { + builder.opt_gather_extend(s, idxs, share); + }, + Column::Partitioned(p) => { + // @scalar-opt + builder.opt_gather_extend(p.as_materialized_series(), idxs, share); + }, + Column::Scalar(sc) => { + let scalar_as_series = sc.scalar().clone().into_series(PlSmallStr::default()); + // Reduce call overhead by transforming indices to 0/1 and dispatching to + // opt_gather_extend on the scalar as series. + for idx_chunk in idxs.chunks(4096) { + trans_idxs.clear(); + trans_idxs.extend( + idx_chunk + .iter() + .map(|idx| ((*idx as usize) >= sc.len()) as IdxSize), + ); + builder.opt_gather_extend(&scalar_as_series, &trans_idxs, share); + } + }, + } + } + + self.height += idxs.len(); + } +} diff --git a/crates/polars-core/src/frame/chunks.rs b/crates/polars-core/src/frame/chunks.rs new file mode 100644 index 000000000000..b63156cd2ae0 --- /dev/null +++ b/crates/polars-core/src/frame/chunks.rs @@ -0,0 +1,150 @@ +use arrow::record_batch::RecordBatch; +use rayon::prelude::*; + +use crate::POOL; +use crate::prelude::*; +use crate::utils::{_split_offsets, accumulate_dataframes_vertical_unchecked, split_df_as_ref}; + +impl From for DataFrame { + fn from(rb: RecordBatch) -> DataFrame { + let height = rb.height(); + let (schema, arrays) = rb.into_schema_and_arrays(); + + let columns: Vec = arrays + .into_iter() + .zip(schema.iter()) + .map(|(arr, (name, field))| { + // SAFETY: Record Batch has the invariant that the schema datatype matches the + // columns. + unsafe { + Series::_try_from_arrow_unchecked_with_md( + name.clone(), + vec![arr], + field.dtype(), + field.metadata.as_deref(), + ) + } + .unwrap() + .into_column() + }) + .collect(); + + // SAFETY: RecordBatch has the same invariants for names and heights as DataFrame. + unsafe { DataFrame::new_no_checks(height, columns) } + } +} + +impl DataFrame { + pub fn split_chunks(&mut self) -> impl Iterator + '_ { + self.align_chunks_par(); + + let first_series_col_idx = self + .columns + .iter() + .position(|col| col.as_series().is_some()); + let df_height = self.height(); + let mut prev_height = 0; + (0..self.first_col_n_chunks()).map(move |i| unsafe { + // There might still be scalar/partitioned columns after aligning, + // so we follow the size of the chunked column, if any. + let chunk_size = first_series_col_idx + .map(|c| self.get_columns()[c].as_series().unwrap().chunks()[i].len()) + .unwrap_or(df_height); + let columns = self + .get_columns() + .iter() + .map(|col| match col { + Column::Series(s) => Column::from(s.select_chunk(i)), + Column::Partitioned(_) | Column::Scalar(_) => { + col.slice(prev_height as i64, chunk_size) + }, + }) + .collect::>(); + + prev_height += chunk_size; + + DataFrame::new_no_checks(chunk_size, columns) + }) + } + + pub fn split_chunks_by_n(self, n: usize, parallel: bool) -> Vec { + let split = _split_offsets(self.height(), n); + + let split_fn = |(offset, len)| self.slice(offset as i64, len); + + if parallel { + // Parallel so that null_counts run in parallel + POOL.install(|| split.into_par_iter().map(split_fn).collect()) + } else { + split.into_iter().map(split_fn).collect() + } + } +} + +/// Split DataFrame into chunks in preparation for writing. The chunks have a +/// maximum number of rows per chunk to ensure reasonable memory efficiency when +/// reading the resulting file, and a minimum size per chunk to ensure +/// reasonable performance when writing. +pub fn chunk_df_for_writing( + df: &mut DataFrame, + row_group_size: usize, +) -> PolarsResult> { + // ensures all chunks are aligned. + df.align_chunks_par(); + + // Accumulate many small chunks to the row group size. + // See: #16403 + if !df.get_columns().is_empty() + && df.get_columns()[0] + .as_materialized_series() + .chunk_lengths() + .take(5) + .all(|len| len < row_group_size) + { + fn finish(scratch: &mut Vec, new_chunks: &mut Vec) { + let mut new = accumulate_dataframes_vertical_unchecked(scratch.drain(..)); + new.as_single_chunk_par(); + new_chunks.push(new); + } + + let mut new_chunks = Vec::with_capacity(df.first_col_n_chunks()); // upper limit; + let mut scratch = vec![]; + let mut remaining = row_group_size; + + for df in df.split_chunks() { + remaining = remaining.saturating_sub(df.height()); + scratch.push(df); + + if remaining == 0 { + remaining = row_group_size; + finish(&mut scratch, &mut new_chunks); + } + } + if !scratch.is_empty() { + finish(&mut scratch, &mut new_chunks); + } + return Ok(std::borrow::Cow::Owned( + accumulate_dataframes_vertical_unchecked(new_chunks), + )); + } + + let n_splits = df.height() / row_group_size; + let result = if n_splits > 0 { + let mut splits = split_df_as_ref(df, n_splits, false); + + for df in splits.iter_mut() { + // If the chunks are small enough, writing many small chunks + // leads to slow writing performance, so in that case we + // merge them. + let n_chunks = df.first_col_n_chunks(); + if n_chunks > 1 && (df.estimated_size() / n_chunks < 128 * 1024) { + df.as_single_chunk_par(); + } + } + + std::borrow::Cow::Owned(accumulate_dataframes_vertical_unchecked(splits)) + } else { + std::borrow::Cow::Borrowed(df) + }; + Ok(result) +} diff --git a/crates/polars-core/src/frame/column/arithmetic.rs b/crates/polars-core/src/frame/column/arithmetic.rs new file mode 100644 index 000000000000..8018ee4527e6 --- /dev/null +++ b/crates/polars-core/src/frame/column/arithmetic.rs @@ -0,0 +1,93 @@ +use num_traits::{Num, NumCast}; +use polars_error::PolarsResult; + +use super::{Column, ScalarColumn, Series}; + +fn num_op_with_broadcast Series>( + c: &'_ Column, + n: T, + op: F, +) -> Column { + match c { + Column::Series(s) => op(s, n).into(), + // @partition-opt + Column::Partitioned(s) => op(s.as_materialized_series(), n).into(), + Column::Scalar(s) => { + ScalarColumn::from_single_value_series(op(&s.as_single_value_series(), n), s.len()) + .into() + }, + } +} + +macro_rules! broadcastable_ops { + ($(($trait:ident, $op:ident))+) => { + $( + impl std::ops::$trait for Column { + type Output = PolarsResult; + + #[inline] + fn $op(self, rhs: Self) -> Self::Output { + self.try_apply_broadcasting_binary_elementwise(&rhs, |l, r| l.$op(r)) + } + } + + impl std::ops::$trait for &Column { + type Output = PolarsResult; + + #[inline] + fn $op(self, rhs: Self) -> Self::Output { + self.try_apply_broadcasting_binary_elementwise(rhs, |l, r| l.$op(r)) + } + } + )+ + } +} + +macro_rules! broadcastable_num_ops { + ($(($trait:ident, $op:ident))+) => { + $( + impl std::ops::$trait:: for Column + where + T: Num + NumCast, + { + type Output = Self; + + #[inline] + fn $op(self, rhs: T) -> Self::Output { + num_op_with_broadcast(&self, rhs, |l, r| l.$op(r)) + } + } + + impl std::ops::$trait:: for &Column + where + T: Num + NumCast, + { + type Output = Column; + + #[inline] + fn $op(self, rhs: T) -> Self::Output { + num_op_with_broadcast(self, rhs, |l, r| l.$op(r)) + } + } + )+ + }; +} + +broadcastable_ops! { + (Add, add) + (Sub, sub) + (Mul, mul) + (Div, div) + (Rem, rem) + (BitAnd, bitand) + (BitOr, bitor) + (BitXor, bitxor) +} + +broadcastable_num_ops! { + (Add, add) + (Sub, sub) + (Mul, mul) + (Div, div) + (Rem, rem) +} diff --git a/crates/polars-core/src/frame/column/compare.rs b/crates/polars-core/src/frame/column/compare.rs new file mode 100644 index 000000000000..5ea472b66c7f --- /dev/null +++ b/crates/polars-core/src/frame/column/compare.rs @@ -0,0 +1,86 @@ +use polars_error::PolarsResult; + +use super::{BooleanChunked, ChunkCompareEq, ChunkCompareIneq, ChunkExpandAtIndex, Column, Series}; + +macro_rules! column_element_wise_broadcasting { + ($lhs:expr, $rhs:expr, $op:expr) => { + match ($lhs, $rhs) { + (Column::Series(lhs), Column::Scalar(rhs)) => $op(lhs, &rhs.as_single_value_series()), + (Column::Scalar(lhs), Column::Series(rhs)) => $op(&lhs.as_single_value_series(), rhs), + (Column::Scalar(lhs), Column::Scalar(rhs)) => { + $op(&lhs.as_single_value_series(), &rhs.as_single_value_series()).map(|ca| { + if ca.len() == 0 { + ca + } else { + ca.new_from_index(0, lhs.len()) + } + }) + }, + (lhs, rhs) => $op(lhs.as_materialized_series(), rhs.as_materialized_series()), + } + }; +} + +impl ChunkCompareEq<&Column> for Column { + type Item = PolarsResult; + + /// Create a boolean mask by checking for equality. + #[inline] + fn equal(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::equal) + } + + /// Create a boolean mask by checking for equality. + #[inline] + fn equal_missing(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!( + self, + rhs, + >::equal_missing + ) + } + + /// Create a boolean mask by checking for inequality. + #[inline] + fn not_equal(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::not_equal) + } + + /// Create a boolean mask by checking for inequality. + #[inline] + fn not_equal_missing(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!( + self, + rhs, + >::not_equal_missing + ) + } +} + +impl ChunkCompareIneq<&Column> for Column { + type Item = PolarsResult; + + /// Create a boolean mask by checking if self > rhs. + #[inline] + fn gt(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::gt) + } + + /// Create a boolean mask by checking if self >= rhs. + #[inline] + fn gt_eq(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::gt_eq) + } + + /// Create a boolean mask by checking if self < rhs. + #[inline] + fn lt(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::lt) + } + + /// Create a boolean mask by checking if self <= rhs. + #[inline] + fn lt_eq(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::lt_eq) + } +} diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs new file mode 100644 index 000000000000..a788909eb4b9 --- /dev/null +++ b/crates/polars-core/src/frame/column/mod.rs @@ -0,0 +1,1867 @@ +use std::borrow::Cow; + +use arrow::bitmap::BitmapBuilder; +use arrow::trusted_len::TrustMyLength; +use num_traits::{Num, NumCast}; +use polars_compute::rolling::QuantileMethod; +use polars_error::PolarsResult; +use polars_utils::aliases::PlSeedableRandomStateQuality; +use polars_utils::index::check_bounds; +use polars_utils::pl_str::PlSmallStr; +pub use scalar::ScalarColumn; + +use self::compare_inner::{TotalEqInner, TotalOrdInner}; +use self::gather::check_bounds_ca; +use self::partitioned::PartitionedColumn; +use self::series::SeriesColumn; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::flags::StatisticsFlags; +use crate::datatypes::ReshapeDimension; +use crate::prelude::*; +use crate::series::{BitRepr, IsSorted, SeriesPhysIter}; +use crate::utils::{Container, slice_offsets}; +use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; + +mod arithmetic; +mod compare; +mod partitioned; +mod scalar; +mod series; + +/// A column within a [`DataFrame`]. +/// +/// This is lazily initialized to a [`Series`] with methods like +/// [`as_materialized_series`][Column::as_materialized_series] and +/// [`take_materialized_series`][Column::take_materialized_series]. +/// +/// Currently, there are two ways to represent a [`Column`]. +/// 1. A [`Series`] of values +/// 2. A [`ScalarColumn`] that repeats a single [`Scalar`] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +pub enum Column { + Series(SeriesColumn), + Partitioned(PartitionedColumn), + Scalar(ScalarColumn), +} + +/// Convert `Self` into a [`Column`] +pub trait IntoColumn: Sized { + fn into_column(self) -> Column; +} + +impl Column { + #[inline] + #[track_caller] + pub fn new(name: PlSmallStr, values: T) -> Self + where + Phantom: ?Sized, + Series: NamedFrom, + { + Self::Series(SeriesColumn::new(NamedFrom::new(name, values))) + } + + #[inline] + pub fn new_empty(name: PlSmallStr, dtype: &DataType) -> Self { + Self::new_scalar(name, Scalar::new(dtype.clone(), AnyValue::Null), 0) + } + + #[inline] + pub fn new_scalar(name: PlSmallStr, scalar: Scalar, length: usize) -> Self { + Self::Scalar(ScalarColumn::new(name, scalar, length)) + } + + #[inline] + pub fn new_partitioned(name: PlSmallStr, scalar: Scalar, length: usize) -> Self { + Self::Scalar(ScalarColumn::new(name, scalar, length)) + } + + pub fn new_row_index(name: PlSmallStr, offset: IdxSize, length: usize) -> PolarsResult { + let length = IdxSize::try_from(length).unwrap_or(IdxSize::MAX); + + if offset.checked_add(length).is_none() { + polars_bail!( + ComputeError: + "row index with offset {} overflows on dataframe with height {}", + offset, length + ) + } + + let range = offset..offset + length; + + let mut ca = IdxCa::from_vec(name, range.collect()); + ca.set_sorted_flag(IsSorted::Ascending); + let col = ca.into_series().into(); + + Ok(col) + } + + // # Materialize + /// Get a reference to a [`Series`] for this [`Column`] + /// + /// This may need to materialize the [`Series`] on the first invocation for a specific column. + #[inline] + pub fn as_materialized_series(&self) -> &Series { + match self { + Column::Series(s) => s, + Column::Partitioned(s) => s.as_materialized_series(), + Column::Scalar(s) => s.as_materialized_series(), + } + } + + /// If the memory repr of this Column is a scalar, a unit-length Series will + /// be returned. + #[inline] + pub fn as_materialized_series_maintain_scalar(&self) -> Series { + match self { + Column::Scalar(s) => s.as_single_value_series(), + v => v.as_materialized_series().clone(), + } + } + + /// Turn [`Column`] into a [`Column::Series`]. + /// + /// This may need to materialize the [`Series`] on the first invocation for a specific column. + #[inline] + pub fn into_materialized_series(&mut self) -> &mut Series { + match self { + Column::Series(s) => s, + Column::Partitioned(s) => { + let series = std::mem::replace( + s, + PartitionedColumn::new_empty(PlSmallStr::EMPTY, DataType::Null), + ) + .take_materialized_series(); + *self = Column::Series(series.into()); + let Column::Series(s) = self else { + unreachable!(); + }; + s + }, + Column::Scalar(s) => { + let series = std::mem::replace( + s, + ScalarColumn::new_empty(PlSmallStr::EMPTY, DataType::Null), + ) + .take_materialized_series(); + *self = Column::Series(series.into()); + let Column::Series(s) = self else { + unreachable!(); + }; + s + }, + } + } + /// Take [`Series`] from a [`Column`] + /// + /// This may need to materialize the [`Series`] on the first invocation for a specific column. + #[inline] + pub fn take_materialized_series(self) -> Series { + match self { + Column::Series(s) => s.take(), + Column::Partitioned(s) => s.take_materialized_series(), + Column::Scalar(s) => s.take_materialized_series(), + } + } + + #[inline] + pub fn dtype(&self) -> &DataType { + match self { + Column::Series(s) => s.dtype(), + Column::Partitioned(s) => s.dtype(), + Column::Scalar(s) => s.dtype(), + } + } + + #[inline] + pub fn field(&self) -> Cow { + match self { + Column::Series(s) => s.field(), + Column::Partitioned(s) => s.field(), + Column::Scalar(s) => match s.lazy_as_materialized_series() { + None => Cow::Owned(Field::new(s.name().clone(), s.dtype().clone())), + Some(s) => s.field(), + }, + } + } + + #[inline] + pub fn name(&self) -> &PlSmallStr { + match self { + Column::Series(s) => s.name(), + Column::Partitioned(s) => s.name(), + Column::Scalar(s) => s.name(), + } + } + + #[inline] + pub fn len(&self) -> usize { + match self { + Column::Series(s) => s.len(), + Column::Partitioned(s) => s.len(), + Column::Scalar(s) => s.len(), + } + } + + #[inline] + pub fn with_name(mut self, name: PlSmallStr) -> Column { + self.rename(name); + self + } + + #[inline] + pub fn rename(&mut self, name: PlSmallStr) { + match self { + Column::Series(s) => _ = s.rename(name), + Column::Partitioned(s) => _ = s.rename(name), + Column::Scalar(s) => _ = s.rename(name), + } + } + + // # Downcasting + #[inline] + pub fn as_series(&self) -> Option<&Series> { + match self { + Column::Series(s) => Some(s), + _ => None, + } + } + #[inline] + pub fn as_partitioned_column(&self) -> Option<&PartitionedColumn> { + match self { + Column::Partitioned(s) => Some(s), + _ => None, + } + } + #[inline] + pub fn as_scalar_column(&self) -> Option<&ScalarColumn> { + match self { + Column::Scalar(s) => Some(s), + _ => None, + } + } + #[inline] + pub fn as_scalar_column_mut(&mut self) -> Option<&mut ScalarColumn> { + match self { + Column::Scalar(s) => Some(s), + _ => None, + } + } + + // # Try to Chunked Arrays + pub fn try_bool(&self) -> Option<&BooleanChunked> { + self.as_materialized_series().try_bool() + } + pub fn try_i8(&self) -> Option<&Int8Chunked> { + self.as_materialized_series().try_i8() + } + pub fn try_i16(&self) -> Option<&Int16Chunked> { + self.as_materialized_series().try_i16() + } + pub fn try_i32(&self) -> Option<&Int32Chunked> { + self.as_materialized_series().try_i32() + } + pub fn try_i64(&self) -> Option<&Int64Chunked> { + self.as_materialized_series().try_i64() + } + pub fn try_u8(&self) -> Option<&UInt8Chunked> { + self.as_materialized_series().try_u8() + } + pub fn try_u16(&self) -> Option<&UInt16Chunked> { + self.as_materialized_series().try_u16() + } + pub fn try_u32(&self) -> Option<&UInt32Chunked> { + self.as_materialized_series().try_u32() + } + pub fn try_u64(&self) -> Option<&UInt64Chunked> { + self.as_materialized_series().try_u64() + } + pub fn try_f32(&self) -> Option<&Float32Chunked> { + self.as_materialized_series().try_f32() + } + pub fn try_f64(&self) -> Option<&Float64Chunked> { + self.as_materialized_series().try_f64() + } + pub fn try_str(&self) -> Option<&StringChunked> { + self.as_materialized_series().try_str() + } + pub fn try_list(&self) -> Option<&ListChunked> { + self.as_materialized_series().try_list() + } + pub fn try_binary(&self) -> Option<&BinaryChunked> { + self.as_materialized_series().try_binary() + } + pub fn try_idx(&self) -> Option<&IdxCa> { + self.as_materialized_series().try_idx() + } + pub fn try_binary_offset(&self) -> Option<&BinaryOffsetChunked> { + self.as_materialized_series().try_binary_offset() + } + #[cfg(feature = "dtype-datetime")] + pub fn try_datetime(&self) -> Option<&DatetimeChunked> { + self.as_materialized_series().try_datetime() + } + #[cfg(feature = "dtype-struct")] + pub fn try_struct(&self) -> Option<&StructChunked> { + self.as_materialized_series().try_struct() + } + #[cfg(feature = "dtype-decimal")] + pub fn try_decimal(&self) -> Option<&DecimalChunked> { + self.as_materialized_series().try_decimal() + } + #[cfg(feature = "dtype-array")] + pub fn try_array(&self) -> Option<&ArrayChunked> { + self.as_materialized_series().try_array() + } + #[cfg(feature = "dtype-categorical")] + pub fn try_categorical(&self) -> Option<&CategoricalChunked> { + self.as_materialized_series().try_categorical() + } + #[cfg(feature = "dtype-date")] + pub fn try_date(&self) -> Option<&DateChunked> { + self.as_materialized_series().try_date() + } + #[cfg(feature = "dtype-duration")] + pub fn try_duration(&self) -> Option<&DurationChunked> { + self.as_materialized_series().try_duration() + } + + // # To Chunked Arrays + pub fn bool(&self) -> PolarsResult<&BooleanChunked> { + self.as_materialized_series().bool() + } + pub fn i8(&self) -> PolarsResult<&Int8Chunked> { + self.as_materialized_series().i8() + } + pub fn i16(&self) -> PolarsResult<&Int16Chunked> { + self.as_materialized_series().i16() + } + pub fn i32(&self) -> PolarsResult<&Int32Chunked> { + self.as_materialized_series().i32() + } + pub fn i64(&self) -> PolarsResult<&Int64Chunked> { + self.as_materialized_series().i64() + } + #[cfg(feature = "dtype-i128")] + pub fn i128(&self) -> PolarsResult<&Int128Chunked> { + self.as_materialized_series().i128() + } + pub fn u8(&self) -> PolarsResult<&UInt8Chunked> { + self.as_materialized_series().u8() + } + pub fn u16(&self) -> PolarsResult<&UInt16Chunked> { + self.as_materialized_series().u16() + } + pub fn u32(&self) -> PolarsResult<&UInt32Chunked> { + self.as_materialized_series().u32() + } + pub fn u64(&self) -> PolarsResult<&UInt64Chunked> { + self.as_materialized_series().u64() + } + pub fn f32(&self) -> PolarsResult<&Float32Chunked> { + self.as_materialized_series().f32() + } + pub fn f64(&self) -> PolarsResult<&Float64Chunked> { + self.as_materialized_series().f64() + } + pub fn str(&self) -> PolarsResult<&StringChunked> { + self.as_materialized_series().str() + } + pub fn list(&self) -> PolarsResult<&ListChunked> { + self.as_materialized_series().list() + } + pub fn binary(&self) -> PolarsResult<&BinaryChunked> { + self.as_materialized_series().binary() + } + pub fn idx(&self) -> PolarsResult<&IdxCa> { + self.as_materialized_series().idx() + } + pub fn binary_offset(&self) -> PolarsResult<&BinaryOffsetChunked> { + self.as_materialized_series().binary_offset() + } + #[cfg(feature = "dtype-datetime")] + pub fn datetime(&self) -> PolarsResult<&DatetimeChunked> { + self.as_materialized_series().datetime() + } + #[cfg(feature = "dtype-struct")] + pub fn struct_(&self) -> PolarsResult<&StructChunked> { + self.as_materialized_series().struct_() + } + #[cfg(feature = "dtype-decimal")] + pub fn decimal(&self) -> PolarsResult<&DecimalChunked> { + self.as_materialized_series().decimal() + } + #[cfg(feature = "dtype-array")] + pub fn array(&self) -> PolarsResult<&ArrayChunked> { + self.as_materialized_series().array() + } + #[cfg(feature = "dtype-categorical")] + pub fn categorical(&self) -> PolarsResult<&CategoricalChunked> { + self.as_materialized_series().categorical() + } + #[cfg(feature = "dtype-date")] + pub fn date(&self) -> PolarsResult<&DateChunked> { + self.as_materialized_series().date() + } + #[cfg(feature = "dtype-duration")] + pub fn duration(&self) -> PolarsResult<&DurationChunked> { + self.as_materialized_series().duration() + } + + // # Casting + pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match self { + Column::Series(s) => s.cast_with_options(dtype, options).map(Column::from), + Column::Partitioned(s) => s.cast_with_options(dtype, options).map(Column::from), + Column::Scalar(s) => s.cast_with_options(dtype, options).map(Column::from), + } + } + pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { + match self { + Column::Series(s) => s.strict_cast(dtype).map(Column::from), + Column::Partitioned(s) => s.strict_cast(dtype).map(Column::from), + Column::Scalar(s) => s.strict_cast(dtype).map(Column::from), + } + } + pub fn cast(&self, dtype: &DataType) -> PolarsResult { + match self { + Column::Series(s) => s.cast(dtype).map(Column::from), + Column::Partitioned(s) => s.cast(dtype).map(Column::from), + Column::Scalar(s) => s.cast(dtype).map(Column::from), + } + } + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + match self { + Column::Series(s) => unsafe { s.cast_unchecked(dtype) }.map(Column::from), + Column::Partitioned(s) => unsafe { s.cast_unchecked(dtype) }.map(Column::from), + Column::Scalar(s) => unsafe { s.cast_unchecked(dtype) }.map(Column::from), + } + } + + pub fn clear(&self) -> Self { + match self { + Column::Series(s) => s.clear().into(), + Column::Partitioned(s) => s.clear().into(), + Column::Scalar(s) => s.resize(0).into(), + } + } + + #[inline] + pub fn shrink_to_fit(&mut self) { + match self { + Column::Series(s) => s.shrink_to_fit(), + // @partition-opt + Column::Partitioned(_) => {}, + Column::Scalar(_) => {}, + } + } + + #[inline] + pub fn new_from_index(&self, index: usize, length: usize) -> Self { + if index >= self.len() { + return Self::full_null(self.name().clone(), length, self.dtype()); + } + + match self { + Column::Series(s) => { + // SAFETY: Bounds check done before. + let av = unsafe { s.get_unchecked(index) }; + let scalar = Scalar::new(self.dtype().clone(), av.into_static()); + Self::new_scalar(self.name().clone(), scalar, length) + }, + Column::Partitioned(s) => { + // SAFETY: Bounds check done before. + let av = unsafe { s.get_unchecked(index) }; + let scalar = Scalar::new(self.dtype().clone(), av.into_static()); + Self::new_scalar(self.name().clone(), scalar, length) + }, + Column::Scalar(s) => s.resize(length).into(), + } + } + + #[inline] + pub fn has_nulls(&self) -> bool { + match self { + Self::Series(s) => s.has_nulls(), + // @partition-opt + Self::Partitioned(s) => s.as_materialized_series().has_nulls(), + Self::Scalar(s) => s.has_nulls(), + } + } + + #[inline] + pub fn is_null(&self) -> BooleanChunked { + match self { + Self::Series(s) => s.is_null(), + // @partition-opt + Self::Partitioned(s) => s.as_materialized_series().is_null(), + Self::Scalar(s) => { + BooleanChunked::full(s.name().clone(), s.scalar().is_null(), s.len()) + }, + } + } + #[inline] + pub fn is_not_null(&self) -> BooleanChunked { + match self { + Self::Series(s) => s.is_not_null(), + // @partition-opt + Self::Partitioned(s) => s.as_materialized_series().is_not_null(), + Self::Scalar(s) => { + BooleanChunked::full(s.name().clone(), !s.scalar().is_null(), s.len()) + }, + } + } + + pub fn to_physical_repr(&self) -> Column { + // @scalar-opt + self.as_materialized_series() + .to_physical_repr() + .into_owned() + .into() + } + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn from_physical_unchecked(&self, dtype: &DataType) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .from_physical_unchecked(dtype) + .map(Column::from) + } + + pub fn head(&self, length: Option) -> Column { + let len = length.unwrap_or(HEAD_DEFAULT_LENGTH); + let len = usize::min(len, self.len()); + self.slice(0, len) + } + pub fn tail(&self, length: Option) -> Column { + let len = length.unwrap_or(TAIL_DEFAULT_LENGTH); + let len = usize::min(len, self.len()); + debug_assert!(len <= i64::MAX as usize); + self.slice(-(len as i64), len) + } + pub fn slice(&self, offset: i64, length: usize) -> Column { + match self { + Column::Series(s) => s.slice(offset, length).into(), + // @partition-opt + Column::Partitioned(s) => s.as_materialized_series().slice(offset, length).into(), + Column::Scalar(s) => { + let (_, length) = slice_offsets(offset, length, s.len()); + s.resize(length).into() + }, + } + } + + pub fn split_at(&self, offset: i64) -> (Column, Column) { + // @scalar-opt + let (l, r) = self.as_materialized_series().split_at(offset); + (l.into(), r.into()) + } + + #[inline] + pub fn null_count(&self) -> usize { + match self { + Self::Series(s) => s.null_count(), + Self::Partitioned(s) => s.null_count(), + Self::Scalar(s) if s.scalar().is_null() => s.len(), + Self::Scalar(_) => 0, + } + } + + pub fn take(&self, indices: &IdxCa) -> PolarsResult { + check_bounds_ca(indices, self.len() as IdxSize)?; + Ok(unsafe { self.take_unchecked(indices) }) + } + pub fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + check_bounds(indices, self.len() as IdxSize)?; + Ok(unsafe { self.take_slice_unchecked(indices) }) + } + /// # Safety + /// + /// No bounds on the indexes are performed. + pub unsafe fn take_unchecked(&self, indices: &IdxCa) -> Column { + debug_assert!(check_bounds_ca(indices, self.len() as IdxSize).is_ok()); + + match self { + Self::Series(s) => unsafe { s.take_unchecked(indices) }.into(), + Self::Partitioned(s) => { + let s = s.as_materialized_series(); + unsafe { s.take_unchecked(indices) }.into() + }, + Self::Scalar(s) => { + let idxs_length = indices.len(); + let idxs_null_count = indices.null_count(); + + let scalar = ScalarColumn::from_single_value_series( + s.as_single_value_series().take_unchecked(&IdxCa::new( + indices.name().clone(), + &[0][..s.len().min(1)], + )), + idxs_length, + ); + + // We need to make sure that null values in `idx` become null values in the result + if idxs_null_count == 0 || scalar.has_nulls() { + scalar.into_column() + } else if idxs_null_count == idxs_length { + scalar.into_nulls().into_column() + } else { + let validity = indices.rechunk_validity(); + let series = scalar.take_materialized_series(); + let name = series.name().clone(); + let dtype = series.dtype().clone(); + let mut chunks = series.into_chunks(); + assert_eq!(chunks.len(), 1); + chunks[0] = chunks[0].with_validity(validity); + unsafe { Series::from_chunks_and_dtype_unchecked(name, chunks, &dtype) } + .into_column() + } + }, + } + } + /// # Safety + /// + /// No bounds on the indexes are performed. + pub unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Column { + debug_assert!(check_bounds(indices, self.len() as IdxSize).is_ok()); + + match self { + Self::Series(s) => unsafe { s.take_slice_unchecked(indices) }.into(), + Self::Partitioned(s) => { + let s = s.as_materialized_series(); + unsafe { s.take_slice_unchecked(indices) }.into() + }, + Self::Scalar(s) => ScalarColumn::from_single_value_series( + s.as_single_value_series() + .take_slice_unchecked(&[0][..s.len().min(1)]), + indices.len(), + ) + .into(), + } + } + + /// General implementation for aggregation where a non-missing scalar would map to itself. + #[inline(always)] + #[cfg(any(feature = "algorithm_group_by", feature = "bitwise"))] + fn agg_with_unit_scalar( + &self, + groups: &GroupsType, + series_agg: impl Fn(&Series, &GroupsType) -> Series, + ) -> Column { + match self { + Column::Series(s) => series_agg(s, groups).into_column(), + // @partition-opt + Column::Partitioned(s) => series_agg(s.as_materialized_series(), groups).into_column(), + Column::Scalar(s) => { + if s.is_empty() { + return series_agg(s.as_materialized_series(), groups).into_column(); + } + + // We utilize the aggregation on Series to see: + // 1. the output datatype of the aggregation + // 2. whether this aggregation is even defined + let series_aggregation = series_agg( + &s.as_single_value_series(), + &GroupsType::Slice { + // @NOTE: this group is always valid since s is non-empty. + groups: vec![[0, 1]], + rolling: false, + }, + ); + + // If the aggregation is not defined, just return all nulls. + if series_aggregation.has_nulls() { + return Self::new_scalar( + series_aggregation.name().clone(), + Scalar::new(series_aggregation.dtype().clone(), AnyValue::Null), + groups.len(), + ); + } + + let mut scalar_col = s.resize(groups.len()); + // The aggregation might change the type (e.g. mean changes int -> float), so we do + // a cast here to the output type. + if series_aggregation.dtype() != s.dtype() { + scalar_col = scalar_col.cast(series_aggregation.dtype()).unwrap(); + } + + let Some(first_empty_idx) = groups.iter().position(|g| g.is_empty()) else { + // Fast path: no empty groups. keep the scalar intact. + return scalar_col.into_column(); + }; + + // All empty groups produce a *missing* or `null` value. + let mut validity = BitmapBuilder::with_capacity(groups.len()); + validity.extend_constant(first_empty_idx, true); + // SAFETY: We trust the length of this iterator. + let iter = unsafe { + TrustMyLength::new( + groups.iter().skip(first_empty_idx).map(|g| !g.is_empty()), + groups.len() - first_empty_idx, + ) + }; + validity.extend_trusted_len_iter(iter); + + let mut s = scalar_col.take_materialized_series().rechunk(); + // SAFETY: We perform a compute_len afterwards. + let chunks = unsafe { s.chunks_mut() }; + let arr = &mut chunks[0]; + *arr = arr.with_validity(validity.into_opt_validity()); + s.compute_len(); + + s.into_column() + }, + } + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_min(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_min(g) }) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_max(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_max(g) }) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_mean(g) }) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_sum(&self, groups: &GroupsType) -> Self { + // @scalar-opt + unsafe { self.as_materialized_series().agg_sum(groups) }.into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_first(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_first(g) }) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_last(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_last(g) }) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Self { + // @scalar-opt + unsafe { self.as_materialized_series().agg_n_unique(groups) }.into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_quantile( + &self, + groups: &GroupsType, + quantile: f64, + method: QuantileMethod, + ) -> Self { + // @scalar-opt + + unsafe { + self.as_materialized_series() + .agg_quantile(groups, quantile, method) + } + .into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_median(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_median(g) }) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Self { + // @scalar-opt + unsafe { self.as_materialized_series().agg_var(groups, ddof) }.into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Self { + // @scalar-opt + unsafe { self.as_materialized_series().agg_std(groups, ddof) }.into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub unsafe fn agg_list(&self, groups: &GroupsType) -> Self { + // @scalar-opt + unsafe { self.as_materialized_series().agg_list(groups) }.into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + pub fn agg_valid_count(&self, groups: &GroupsType) -> Self { + // @partition-opt + // @scalar-opt + unsafe { self.as_materialized_series().agg_valid_count(groups) }.into() + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "bitwise")] + pub fn agg_and(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_and(g) }) + } + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "bitwise")] + pub fn agg_or(&self, groups: &GroupsType) -> Self { + self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_or(g) }) + } + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "bitwise")] + pub fn agg_xor(&self, groups: &GroupsType) -> Self { + // @partition-opt + // @scalar-opt + unsafe { self.as_materialized_series().agg_xor(groups) }.into() + } + + pub fn full_null(name: PlSmallStr, size: usize, dtype: &DataType) -> Self { + Self::new_scalar(name, Scalar::new(dtype.clone(), AnyValue::Null), size) + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn reverse(&self) -> Column { + match self { + Column::Series(s) => s.reverse().into(), + Column::Partitioned(s) => s.reverse().into(), + Column::Scalar(_) => self.clone(), + } + } + + pub fn equals(&self, other: &Column) -> bool { + // @scalar-opt + self.as_materialized_series() + .equals(other.as_materialized_series()) + } + + pub fn equals_missing(&self, other: &Column) -> bool { + // @scalar-opt + self.as_materialized_series() + .equals_missing(other.as_materialized_series()) + } + + pub fn set_sorted_flag(&mut self, sorted: IsSorted) { + // @scalar-opt + match self { + Column::Series(s) => s.set_sorted_flag(sorted), + Column::Partitioned(s) => s.set_sorted_flag(sorted), + Column::Scalar(_) => {}, + } + } + + pub fn get_flags(&self) -> StatisticsFlags { + match self { + Column::Series(s) => s.get_flags(), + // @partition-opt + Column::Partitioned(_) => StatisticsFlags::empty(), + Column::Scalar(_) => { + StatisticsFlags::IS_SORTED_ASC | StatisticsFlags::CAN_FAST_EXPLODE_LIST + }, + } + } + + /// Returns whether the flags were set + pub fn set_flags(&mut self, flags: StatisticsFlags) -> bool { + match self { + Column::Series(s) => { + s.set_flags(flags); + true + }, + // @partition-opt + Column::Partitioned(_) => false, + Column::Scalar(_) => false, + } + } + + pub fn vec_hash( + &self, + build_hasher: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + // @scalar-opt? + self.as_materialized_series().vec_hash(build_hasher, buf) + } + + pub fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + // @scalar-opt? + self.as_materialized_series() + .vec_hash_combine(build_hasher, hashes) + } + + pub fn append(&mut self, other: &Column) -> PolarsResult<&mut Self> { + // @scalar-opt + self.into_materialized_series() + .append(other.as_materialized_series())?; + Ok(self) + } + pub fn append_owned(&mut self, other: Column) -> PolarsResult<&mut Self> { + self.into_materialized_series() + .append_owned(other.take_materialized_series())?; + Ok(self) + } + + pub fn arg_sort(&self, options: SortOptions) -> IdxCa { + if self.is_empty() { + return IdxCa::from_vec(self.name().clone(), Vec::new()); + } + + if self.null_count() == self.len() { + // We might need to maintain order so just respect the descending parameter. + let values = if options.descending { + (0..self.len() as IdxSize).rev().collect() + } else { + (0..self.len() as IdxSize).collect() + }; + + return IdxCa::from_vec(self.name().clone(), values); + } + + let is_sorted = Some(self.is_sorted_flag()); + let Some(is_sorted) = is_sorted.filter(|v| !matches!(v, IsSorted::Not)) else { + return self.as_materialized_series().arg_sort(options); + }; + + // Fast path: the data is sorted. + let is_sorted_dsc = matches!(is_sorted, IsSorted::Descending); + let invert = options.descending != is_sorted_dsc; + + let mut values = Vec::with_capacity(self.len()); + + #[inline(never)] + fn extend( + start: IdxSize, + end: IdxSize, + slf: &Column, + values: &mut Vec, + is_only_nulls: bool, + invert: bool, + maintain_order: bool, + ) { + debug_assert!(start <= end); + debug_assert!(start as usize <= slf.len()); + debug_assert!(end as usize <= slf.len()); + + if !invert || is_only_nulls { + values.extend(start..end); + return; + } + + // If we don't have to maintain order but we have to invert. Just flip it around. + if !maintain_order { + values.extend((start..end).rev()); + return; + } + + // If we want to maintain order but we also needs to invert, we need to invert + // per group of items. + // + // @NOTE: Since the column is sorted, arg_unique can also take a fast path and + // just do a single traversal. + let arg_unique = slf + .slice(start as i64, (end - start) as usize) + .arg_unique() + .unwrap(); + + assert!(!arg_unique.has_nulls()); + + let num_unique = arg_unique.len(); + + // Fast path: all items are unique. + if num_unique == (end - start) as usize { + values.extend((start..end).rev()); + return; + } + + if num_unique == 1 { + values.extend(start..end); + return; + } + + let mut prev_idx = end - start; + for chunk in arg_unique.downcast_iter() { + for &idx in chunk.values().as_slice().iter().rev() { + values.extend(start + idx..start + prev_idx); + prev_idx = idx; + } + } + } + macro_rules! extend { + ($start:expr, $end:expr) => { + extend!($start, $end, is_only_nulls = false); + }; + ($start:expr, $end:expr, is_only_nulls = $is_only_nulls:expr) => { + extend( + $start, + $end, + self, + &mut values, + $is_only_nulls, + invert, + options.maintain_order, + ); + }; + } + + let length = self.len() as IdxSize; + let null_count = self.null_count() as IdxSize; + + if null_count == 0 { + extend!(0, length); + } else { + let has_nulls_last = self.get(self.len() - 1).unwrap().is_null(); + match (options.nulls_last, has_nulls_last) { + (true, true) => { + // Current: Nulls last, Wanted: Nulls last + extend!(0, length - null_count); + extend!(length - null_count, length, is_only_nulls = true); + }, + (true, false) => { + // Current: Nulls first, Wanted: Nulls last + extend!(null_count, length); + extend!(0, null_count, is_only_nulls = true); + }, + (false, true) => { + // Current: Nulls last, Wanted: Nulls first + extend!(length - null_count, length, is_only_nulls = true); + extend!(0, length - null_count); + }, + (false, false) => { + // Current: Nulls first, Wanted: Nulls first + extend!(0, null_count, is_only_nulls = true); + extend!(null_count, length); + }, + } + } + + // @NOTE: This can theoretically be pushed into the previous operation but it is really + // worth it... probably not... + if let Some(limit) = options.limit { + let limit = limit.min(length); + values.truncate(limit as usize); + } + + IdxCa::from_vec(self.name().clone(), values) + } + + pub fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + // @scalar-opt + self.as_materialized_series().arg_sort_multiple(by, options) + } + + pub fn arg_unique(&self) -> PolarsResult { + match self { + Column::Scalar(s) => Ok(IdxCa::new_vec(s.name().clone(), vec![0])), + _ => self.as_materialized_series().arg_unique(), + } + } + + pub fn bit_repr(&self) -> Option { + // @scalar-opt + self.as_materialized_series().bit_repr() + } + + pub fn into_frame(self) -> DataFrame { + // SAFETY: A single-column dataframe cannot have length mismatches or duplicate names + unsafe { DataFrame::new_no_checks(self.len(), vec![self]) } + } + + pub fn extend(&mut self, other: &Column) -> PolarsResult<&mut Self> { + // @scalar-opt + self.into_materialized_series() + .extend(other.as_materialized_series())?; + Ok(self) + } + + pub fn rechunk(&self) -> Column { + match self { + Column::Series(s) => s.rechunk().into(), + Column::Partitioned(s) => { + if let Some(s) = s.lazy_as_materialized_series() { + // This should always hold for partitioned. + debug_assert_eq!(s.n_chunks(), 1) + } + self.clone() + }, + Column::Scalar(s) => { + if s.lazy_as_materialized_series() + .filter(|x| x.n_chunks() > 1) + .is_some() + { + Column::Scalar(ScalarColumn::new( + s.name().clone(), + s.scalar().clone(), + s.len(), + )) + } else { + self.clone() + } + }, + } + } + + pub fn explode(&self) -> PolarsResult { + self.as_materialized_series().explode().map(Column::from) + } + pub fn implode(&self) -> PolarsResult { + self.as_materialized_series().implode() + } + + pub fn fill_null(&self, strategy: FillNullStrategy) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .fill_null(strategy) + .map(Column::from) + } + + pub fn divide(&self, rhs: &Column) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .divide(rhs.as_materialized_series()) + .map(Column::from) + } + + pub fn shift(&self, periods: i64) -> Column { + // @scalar-opt + self.as_materialized_series().shift(periods).into() + } + + #[cfg(feature = "zip_with")] + pub fn zip_with(&self, mask: &BooleanChunked, other: &Self) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .zip_with(mask, other.as_materialized_series()) + .map(Self::from) + } + + #[cfg(feature = "zip_with")] + pub fn zip_with_same_type( + &self, + mask: &ChunkedArray, + other: &Column, + ) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .zip_with_same_type(mask, other.as_materialized_series()) + .map(Column::from) + } + + pub fn drop_nulls(&self) -> Column { + match self { + Column::Series(s) => s.drop_nulls().into_column(), + // @partition-opt + Column::Partitioned(s) => s.as_materialized_series().drop_nulls().into_column(), + Column::Scalar(s) => s.drop_nulls().into_column(), + } + } + + /// Packs every element into a list. + pub fn as_list(&self) -> ListChunked { + // @scalar-opt + // @partition-opt + self.as_materialized_series().as_list() + } + + pub fn is_sorted_flag(&self) -> IsSorted { + match self { + Column::Series(s) => s.is_sorted_flag(), + Column::Partitioned(s) => s.partitions().is_sorted_flag(), + Column::Scalar(_) => IsSorted::Ascending, + } + } + + pub fn unique(&self) -> PolarsResult { + match self { + Column::Series(s) => s.unique().map(Column::from), + // @partition-opt + Column::Partitioned(s) => s.as_materialized_series().unique().map(Column::from), + Column::Scalar(s) => { + _ = s.as_single_value_series().unique()?; + if s.is_empty() { + return Ok(s.clone().into_column()); + } + + Ok(s.resize(1).into_column()) + }, + } + } + pub fn unique_stable(&self) -> PolarsResult { + match self { + Column::Series(s) => s.unique_stable().map(Column::from), + // @partition-opt + Column::Partitioned(s) => s.as_materialized_series().unique_stable().map(Column::from), + Column::Scalar(s) => { + _ = s.as_single_value_series().unique_stable()?; + if s.is_empty() { + return Ok(s.clone().into_column()); + } + + Ok(s.resize(1).into_column()) + }, + } + } + + pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .reshape_list(dimensions) + .map(Self::from) + } + + #[cfg(feature = "dtype-array")] + pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .reshape_array(dimensions) + .map(Self::from) + } + + pub fn sort(&self, sort_options: SortOptions) -> PolarsResult { + // @scalar-opt + self.as_materialized_series() + .sort(sort_options) + .map(Self::from) + } + + pub fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + match self { + Column::Series(s) => s.filter(filter).map(Column::from), + Column::Partitioned(s) => s.as_materialized_series().filter(filter).map(Column::from), + Column::Scalar(s) => { + if s.is_empty() { + return Ok(s.clone().into_column()); + } + + // Broadcasting + if filter.len() == 1 { + return match filter.get(0) { + Some(true) => Ok(s.clone().into_column()), + _ => Ok(s.resize(0).into_column()), + }; + } + + Ok(s.resize(filter.sum().unwrap() as usize).into_column()) + }, + } + } + + #[cfg(feature = "random")] + pub fn shuffle(&self, seed: Option) -> Self { + // @scalar-opt + self.as_materialized_series().shuffle(seed).into() + } + + #[cfg(feature = "random")] + pub fn sample_frac( + &self, + frac: f64, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + self.as_materialized_series() + .sample_frac(frac, with_replacement, shuffle, seed) + .map(Self::from) + } + + #[cfg(feature = "random")] + pub fn sample_n( + &self, + n: usize, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + self.as_materialized_series() + .sample_n(n, with_replacement, shuffle, seed) + .map(Self::from) + } + + pub fn gather_every(&self, n: usize, offset: usize) -> PolarsResult { + polars_ensure!(n > 0, InvalidOperation: "gather_every(n): n should be positive"); + if self.len().saturating_sub(offset) == 0 { + return Ok(self.clear()); + } + + match self { + Column::Series(s) => Ok(s.gather_every(n, offset)?.into()), + Column::Partitioned(s) => { + Ok(s.as_materialized_series().gather_every(n, offset)?.into()) + }, + Column::Scalar(s) => { + let total = s.len() - offset; + Ok(s.resize(1 + (total - 1) / n).into()) + }, + } + } + + pub fn extend_constant(&self, value: AnyValue, n: usize) -> PolarsResult { + if self.is_empty() { + return Ok(Self::new_scalar( + self.name().clone(), + Scalar::new(self.dtype().clone(), value.into_static()), + n, + )); + } + + match self { + Column::Series(s) => s.extend_constant(value, n).map(Column::from), + Column::Partitioned(s) => s.extend_constant(value, n).map(Column::from), + Column::Scalar(s) => { + if s.scalar().as_any_value() == value { + Ok(s.resize(s.len() + n).into()) + } else { + s.as_materialized_series() + .extend_constant(value, n) + .map(Column::from) + } + }, + } + } + + pub fn is_finite(&self) -> PolarsResult { + self.try_map_unary_elementwise_to_bool(|s| s.is_finite()) + } + pub fn is_infinite(&self) -> PolarsResult { + self.try_map_unary_elementwise_to_bool(|s| s.is_infinite()) + } + pub fn is_nan(&self) -> PolarsResult { + self.try_map_unary_elementwise_to_bool(|s| s.is_nan()) + } + pub fn is_not_nan(&self) -> PolarsResult { + self.try_map_unary_elementwise_to_bool(|s| s.is_not_nan()) + } + + pub fn wrapping_trunc_div_scalar(&self, rhs: T) -> Self + where + T: Num + NumCast, + { + // @scalar-opt + self.as_materialized_series() + .wrapping_trunc_div_scalar(rhs) + .into() + } + + pub fn product(&self) -> PolarsResult { + // @scalar-opt + self.as_materialized_series().product() + } + + pub fn phys_iter(&self) -> SeriesPhysIter<'_> { + // @scalar-opt + self.as_materialized_series().phys_iter() + } + + #[inline] + pub fn get(&self, index: usize) -> PolarsResult { + polars_ensure!(index < self.len(), oob = index, self.len()); + + // SAFETY: Bounds check done just before. + Ok(unsafe { self.get_unchecked(index) }) + } + /// # Safety + /// + /// Does not perform bounds check on `index` + #[inline(always)] + pub unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + debug_assert!(index < self.len()); + + match self { + Column::Series(s) => unsafe { s.get_unchecked(index) }, + Column::Partitioned(s) => unsafe { s.get_unchecked(index) }, + Column::Scalar(s) => s.scalar().as_any_value(), + } + } + + #[cfg(feature = "object")] + pub fn get_object( + &self, + index: usize, + ) -> Option<&dyn crate::chunked_array::object::PolarsObjectSafe> { + self.as_materialized_series().get_object(index) + } + + pub fn bitand(&self, rhs: &Self) -> PolarsResult { + self.try_apply_broadcasting_binary_elementwise(rhs, |l, r| l & r) + } + pub fn bitor(&self, rhs: &Self) -> PolarsResult { + self.try_apply_broadcasting_binary_elementwise(rhs, |l, r| l | r) + } + pub fn bitxor(&self, rhs: &Self) -> PolarsResult { + self.try_apply_broadcasting_binary_elementwise(rhs, |l, r| l ^ r) + } + + pub fn try_add_owned(self, other: Self) -> PolarsResult { + match (self, other) { + (Column::Series(lhs), Column::Series(rhs)) => { + lhs.take().try_add_owned(rhs.take()).map(Column::from) + }, + (lhs, rhs) => lhs + rhs, + } + } + pub fn try_sub_owned(self, other: Self) -> PolarsResult { + match (self, other) { + (Column::Series(lhs), Column::Series(rhs)) => { + lhs.take().try_sub_owned(rhs.take()).map(Column::from) + }, + (lhs, rhs) => lhs - rhs, + } + } + pub fn try_mul_owned(self, other: Self) -> PolarsResult { + match (self, other) { + (Column::Series(lhs), Column::Series(rhs)) => { + lhs.take().try_mul_owned(rhs.take()).map(Column::from) + }, + (lhs, rhs) => lhs * rhs, + } + } + + pub(crate) fn str_value(&self, index: usize) -> PolarsResult> { + Ok(self.get(index)?.str_value()) + } + + pub fn min_reduce(&self) -> PolarsResult { + match self { + Column::Series(s) => s.min_reduce(), + Column::Partitioned(s) => s.min_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().min_reduce() + }, + } + } + pub fn max_reduce(&self) -> PolarsResult { + match self { + Column::Series(s) => s.max_reduce(), + Column::Partitioned(s) => s.max_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().max_reduce() + }, + } + } + pub fn median_reduce(&self) -> PolarsResult { + match self { + Column::Series(s) => s.median_reduce(), + Column::Partitioned(s) => s.as_materialized_series().median_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().median_reduce() + }, + } + } + pub fn mean_reduce(&self) -> Scalar { + match self { + Column::Series(s) => s.mean_reduce(), + Column::Partitioned(s) => s.as_materialized_series().mean_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().mean_reduce() + }, + } + } + pub fn std_reduce(&self, ddof: u8) -> PolarsResult { + match self { + Column::Series(s) => s.std_reduce(ddof), + Column::Partitioned(s) => s.as_materialized_series().std_reduce(ddof), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().std_reduce(ddof) + }, + } + } + pub fn var_reduce(&self, ddof: u8) -> PolarsResult { + match self { + Column::Series(s) => s.var_reduce(ddof), + Column::Partitioned(s) => s.as_materialized_series().var_reduce(ddof), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().var_reduce(ddof) + }, + } + } + pub fn sum_reduce(&self) -> PolarsResult { + // @partition-opt + // @scalar-opt + self.as_materialized_series().sum_reduce() + } + pub fn and_reduce(&self) -> PolarsResult { + match self { + Column::Series(s) => s.and_reduce(), + Column::Partitioned(s) => s.and_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().and_reduce() + }, + } + } + pub fn or_reduce(&self) -> PolarsResult { + match self { + Column::Series(s) => s.or_reduce(), + Column::Partitioned(s) => s.or_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + s.as_single_value_series().or_reduce() + }, + } + } + pub fn xor_reduce(&self) -> PolarsResult { + match self { + Column::Series(s) => s.xor_reduce(), + // @partition-opt + Column::Partitioned(s) => s.as_materialized_series().xor_reduce(), + Column::Scalar(s) => { + // We don't really want to deal with handling the full semantics here so we just + // cast to a single value series. This is a tiny bit wasteful, but probably fine. + // + // We have to deal with the fact that xor is 0 if there is an even number of + // elements and the value if there is an odd number of elements. If there are zero + // elements the result should be `null`. + s.as_n_values_series(2 - s.len() % 2).xor_reduce() + }, + } + } + pub fn n_unique(&self) -> PolarsResult { + match self { + Column::Series(s) => s.n_unique(), + Column::Partitioned(s) => s.partitions().n_unique(), + Column::Scalar(s) => s.as_single_value_series().n_unique(), + } + } + pub fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + self.as_materialized_series() + .quantile_reduce(quantile, method) + } + + pub(crate) fn estimated_size(&self) -> usize { + // @scalar-opt + self.as_materialized_series().estimated_size() + } + + pub fn sort_with(&self, options: SortOptions) -> PolarsResult { + match self { + Column::Series(s) => s.sort_with(options).map(Self::from), + // @partition-opt + Column::Partitioned(s) => s + .as_materialized_series() + .sort_with(options) + .map(Self::from), + Column::Scalar(s) => { + // This makes this function throw the same errors as Series::sort_with + _ = s.as_single_value_series().sort_with(options)?; + + Ok(self.clone()) + }, + } + } + + pub fn map_unary_elementwise_to_bool( + &self, + f: impl Fn(&Series) -> BooleanChunked, + ) -> BooleanChunked { + self.try_map_unary_elementwise_to_bool(|s| Ok(f(s))) + .unwrap() + } + pub fn try_map_unary_elementwise_to_bool( + &self, + f: impl Fn(&Series) -> PolarsResult, + ) -> PolarsResult { + match self { + Column::Series(s) => f(s), + Column::Partitioned(s) => f(s.as_materialized_series()), + Column::Scalar(s) => Ok(f(&s.as_single_value_series())?.new_from_index(0, s.len())), + } + } + + pub fn apply_unary_elementwise(&self, f: impl Fn(&Series) -> Series) -> Column { + self.try_apply_unary_elementwise(|s| Ok(f(s))).unwrap() + } + pub fn try_apply_unary_elementwise( + &self, + f: impl Fn(&Series) -> PolarsResult, + ) -> PolarsResult { + match self { + Column::Series(s) => f(s).map(Column::from), + Column::Partitioned(s) => s.try_apply_unary_elementwise(f).map(Self::from), + Column::Scalar(s) => Ok(ScalarColumn::from_single_value_series( + f(&s.as_single_value_series())?, + s.len(), + ) + .into()), + } + } + + pub fn apply_broadcasting_binary_elementwise( + &self, + other: &Self, + op: impl Fn(&Series, &Series) -> Series, + ) -> PolarsResult { + self.try_apply_broadcasting_binary_elementwise(other, |lhs, rhs| Ok(op(lhs, rhs))) + } + pub fn try_apply_broadcasting_binary_elementwise( + &self, + other: &Self, + op: impl Fn(&Series, &Series) -> PolarsResult, + ) -> PolarsResult { + fn output_length(a: &Column, b: &Column) -> PolarsResult { + match (a.len(), b.len()) { + // broadcasting + (1, o) | (o, 1) => Ok(o), + // equal + (a, b) if a == b => Ok(a), + // unequal + (a, b) => { + polars_bail!(InvalidOperation: "cannot do a binary operation on columns of different lengths: got {} and {}", a, b) + }, + } + } + + // Here we rely on the underlying broadcast operations. + let length = output_length(self, other)?; + match (self, other) { + (Column::Series(lhs), Column::Series(rhs)) => op(lhs, rhs).map(Column::from), + (Column::Series(lhs), Column::Scalar(rhs)) => { + op(lhs, &rhs.as_single_value_series()).map(Column::from) + }, + (Column::Scalar(lhs), Column::Series(rhs)) => { + op(&lhs.as_single_value_series(), rhs).map(Column::from) + }, + (Column::Scalar(lhs), Column::Scalar(rhs)) => { + let lhs = lhs.as_single_value_series(); + let rhs = rhs.as_single_value_series(); + + Ok(ScalarColumn::from_single_value_series(op(&lhs, &rhs)?, length).into_column()) + }, + // @partition-opt + (lhs, rhs) => { + op(lhs.as_materialized_series(), rhs.as_materialized_series()).map(Column::from) + }, + } + } + + pub fn apply_binary_elementwise( + &self, + other: &Self, + f: impl Fn(&Series, &Series) -> Series, + f_lb: impl Fn(&Scalar, &Series) -> Series, + f_rb: impl Fn(&Series, &Scalar) -> Series, + ) -> Column { + self.try_apply_binary_elementwise( + other, + |lhs, rhs| Ok(f(lhs, rhs)), + |lhs, rhs| Ok(f_lb(lhs, rhs)), + |lhs, rhs| Ok(f_rb(lhs, rhs)), + ) + .unwrap() + } + pub fn try_apply_binary_elementwise( + &self, + other: &Self, + f: impl Fn(&Series, &Series) -> PolarsResult, + f_lb: impl Fn(&Scalar, &Series) -> PolarsResult, + f_rb: impl Fn(&Series, &Scalar) -> PolarsResult, + ) -> PolarsResult { + debug_assert_eq!(self.len(), other.len()); + + match (self, other) { + (Column::Series(lhs), Column::Series(rhs)) => f(lhs, rhs).map(Column::from), + (Column::Series(lhs), Column::Scalar(rhs)) => f_rb(lhs, rhs.scalar()).map(Column::from), + (Column::Scalar(lhs), Column::Series(rhs)) => f_lb(lhs.scalar(), rhs).map(Column::from), + (Column::Scalar(lhs), Column::Scalar(rhs)) => { + let lhs = lhs.as_single_value_series(); + let rhs = rhs.as_single_value_series(); + + Ok( + ScalarColumn::from_single_value_series(f(&lhs, &rhs)?, self.len()) + .into_column(), + ) + }, + // @partition-opt + (lhs, rhs) => { + f(lhs.as_materialized_series(), rhs.as_materialized_series()).map(Column::from) + }, + } + } + + #[cfg(feature = "approx_unique")] + pub fn approx_n_unique(&self) -> PolarsResult { + match self { + Column::Series(s) => s.approx_n_unique(), + // @partition-opt + Column::Partitioned(s) => s.as_materialized_series().approx_n_unique(), + Column::Scalar(s) => { + // @NOTE: We do this for the error handling. + s.as_single_value_series().approx_n_unique()?; + Ok(1) + }, + } + } + + pub fn n_chunks(&self) -> usize { + match self { + Column::Series(s) => s.n_chunks(), + Column::Scalar(s) => s.lazy_as_materialized_series().map_or(1, |x| x.n_chunks()), + Column::Partitioned(s) => { + if let Some(s) = s.lazy_as_materialized_series() { + // This should always hold for partitioned. + debug_assert_eq!(s.n_chunks(), 1) + } + 1 + }, + } + } + + #[expect(clippy::wrong_self_convention)] + pub(crate) fn into_total_ord_inner<'a>(&'a self) -> Box { + // @scalar-opt + self.as_materialized_series().into_total_ord_inner() + } + #[expect(unused, clippy::wrong_self_convention)] + pub(crate) fn into_total_eq_inner<'a>(&'a self) -> Box { + // @scalar-opt + self.as_materialized_series().into_total_eq_inner() + } + + pub fn rechunk_to_arrow(self, compat_level: CompatLevel) -> Box { + // Rechunk to one chunk if necessary + let mut series = self.take_materialized_series(); + if series.n_chunks() > 1 { + series = series.rechunk(); + } + series.to_arrow(0, compat_level) + } +} + +impl Default for Column { + fn default() -> Self { + Self::new_scalar( + PlSmallStr::EMPTY, + Scalar::new(DataType::Int64, AnyValue::Null), + 0, + ) + } +} + +impl PartialEq for Column { + fn eq(&self, other: &Self) -> bool { + // @scalar-opt + self.as_materialized_series() + .eq(other.as_materialized_series()) + } +} + +impl From for Column { + #[inline] + fn from(series: Series) -> Self { + // We instantiate a Scalar Column if the Series is length is 1. This makes it possible for + // future operations to be faster. + if series.len() == 1 { + return Self::Scalar(ScalarColumn::unit_scalar_from_series(series)); + } + + Self::Series(SeriesColumn::new(series)) + } +} + +impl IntoColumn for T { + #[inline] + fn into_column(self) -> Column { + self.into_series().into() + } +} + +impl IntoColumn for Column { + #[inline(always)] + fn into_column(self) -> Column { + self + } +} + +/// We don't want to serialize the scalar columns. So this helps pretend that columns are always +/// initialized without implementing From for Series. +/// +/// Those casts should be explicit. +#[derive(Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +#[cfg_attr(feature = "serde", serde(into = "Series"))] +struct _SerdeSeries(Series); + +impl From for _SerdeSeries { + #[inline] + fn from(value: Column) -> Self { + Self(value.take_materialized_series()) + } +} + +impl From<_SerdeSeries> for Series { + #[inline] + fn from(value: _SerdeSeries) -> Self { + value.0 + } +} diff --git a/crates/polars-core/src/frame/column/partitioned.rs b/crates/polars-core/src/frame/column/partitioned.rs new file mode 100644 index 000000000000..aba32c179dac --- /dev/null +++ b/crates/polars-core/src/frame/column/partitioned.rs @@ -0,0 +1,294 @@ +use std::borrow::Cow; +use std::convert::identity; +use std::sync::{Arc, OnceLock}; + +use polars_error::{PolarsResult, polars_ensure}; +use polars_utils::IdxSize; +use polars_utils::pl_str::PlSmallStr; + +use super::{AnyValue, Column, DataType, Field, IntoColumn, Series}; +use crate::chunked_array::cast::CastOptions; +use crate::frame::Scalar; +use crate::series::IsSorted; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct PartitionedColumn { + name: PlSmallStr, + + values: Series, + ends: Arc<[IdxSize]>, + + #[cfg_attr(feature = "serde", serde(skip))] + materialized: OnceLock, +} + +impl IntoColumn for PartitionedColumn { + fn into_column(self) -> Column { + Column::Partitioned(self) + } +} + +impl From for Column { + fn from(value: PartitionedColumn) -> Self { + value.into_column() + } +} + +fn verify_invariants(values: &Series, ends: &[IdxSize]) -> PolarsResult<()> { + polars_ensure!( + values.len() == ends.len(), + ComputeError: "partitioned column `values` length does not match `ends` length ({} != {})", + values.len(), + ends.len() + ); + + for vs in ends.windows(2) { + polars_ensure!( + vs[0] <= vs[1], + ComputeError: "partitioned column `ends` are not monotonely non-decreasing", + ); + } + + Ok(()) +} + +impl PartitionedColumn { + pub fn new(name: PlSmallStr, values: Series, ends: Arc<[IdxSize]>) -> Self { + Self::try_new(name, values, ends).unwrap() + } + + /// # Safety + /// + /// Safe if: + /// - `values.len() == ends.len()` + /// - all values can have `dtype` + /// - `ends` is monotonely non-decreasing + pub unsafe fn new_unchecked(name: PlSmallStr, values: Series, ends: Arc<[IdxSize]>) -> Self { + if cfg!(debug_assertions) { + verify_invariants(&values, ends.as_ref()).unwrap(); + } + + let values = values.rechunk(); + Self { + name, + values, + ends, + materialized: OnceLock::new(), + } + } + + pub fn try_new(name: PlSmallStr, values: Series, ends: Arc<[IdxSize]>) -> PolarsResult { + verify_invariants(&values, ends.as_ref())?; + + // SAFETY: Invariants checked before + Ok(unsafe { Self::new_unchecked(name, values, ends) }) + } + + pub fn new_empty(name: PlSmallStr, dtype: DataType) -> Self { + Self { + name, + values: Series::new_empty(PlSmallStr::EMPTY, &dtype), + ends: Arc::default(), + + materialized: OnceLock::new(), + } + } + + pub fn len(&self) -> usize { + self.ends.last().map_or(0, |last| *last as usize) + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn name(&self) -> &PlSmallStr { + &self.name + } + + pub fn dtype(&self) -> &DataType { + self.values.dtype() + } + + #[inline] + pub fn field(&self) -> Cow { + match self.lazy_as_materialized_series() { + None => Cow::Owned(Field::new(self.name().clone(), self.dtype().clone())), + Some(s) => s.field(), + } + } + + pub fn rename(&mut self, name: PlSmallStr) -> &mut Self { + self.name = name; + self + } + + fn _to_series(name: PlSmallStr, values: &Series, ends: &[IdxSize]) -> Series { + let dtype = values.dtype(); + let mut column = Column::Series(Series::new_empty(name, dtype).into()); + + let mut prev_offset = 0; + for (i, &offset) in ends.iter().enumerate() { + // @TODO: Optimize + let length = offset - prev_offset; + column + .extend(&Column::new_scalar( + PlSmallStr::EMPTY, + Scalar::new(dtype.clone(), values.get(i).unwrap().into_static()), + length as usize, + )) + .unwrap(); + prev_offset = offset; + } + + debug_assert_eq!(column.len(), prev_offset as usize); + + column.take_materialized_series() + } + + /// Materialize the [`PartitionedColumn`] into a [`Series`]. + fn to_series(&self) -> Series { + Self::_to_series(self.name.clone(), &self.values, &self.ends) + } + + /// Get the [`PartitionedColumn`] as [`Series`] if it was already materialized. + pub fn lazy_as_materialized_series(&self) -> Option<&Series> { + self.materialized.get() + } + + /// Get the [`PartitionedColumn`] as [`Series`] + /// + /// This needs to materialize upon the first call. Afterwards, this is cached. + pub fn as_materialized_series(&self) -> &Series { + self.materialized.get_or_init(|| self.to_series()) + } + + /// Take the [`PartitionedColumn`] and materialize as a [`Series`] if not already done. + pub fn take_materialized_series(self) -> Series { + self.materialized + .into_inner() + .unwrap_or_else(|| Self::_to_series(self.name, &self.values, &self.ends)) + } + + pub fn apply_unary_elementwise(&self, f: impl Fn(&Series) -> Series) -> Self { + let result = f(&self.values).rechunk(); + assert_eq!(self.values.len(), result.len()); + unsafe { Self::new_unchecked(self.name.clone(), result, self.ends.clone()) } + } + + pub fn try_apply_unary_elementwise( + &self, + f: impl Fn(&Series) -> PolarsResult, + ) -> PolarsResult { + let result = f(&self.values)?.rechunk(); + assert_eq!(self.values.len(), result.len()); + Ok(unsafe { Self::new_unchecked(self.name.clone(), result, self.ends.clone()) }) + } + + pub fn extend_constant(&self, value: AnyValue, n: usize) -> PolarsResult { + let mut new_ends = self.ends.to_vec(); + // @TODO: IdxSize checks + let new_length = (self.len() + n) as IdxSize; + + let values = if !self.is_empty() && self.values.last().value() == &value { + *new_ends.last_mut().unwrap() = new_length; + self.values.clone() + } else { + new_ends.push(new_length); + self.values.extend_constant(value, 1)? + }; + + Ok(unsafe { Self::new_unchecked(self.name.clone(), values, new_ends.into()) }) + } + + pub unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + debug_assert!(index < self.len()); + + // Common situation get_unchecked(0) + if index < self.ends[0] as usize { + return unsafe { self.get_unchecked(0) }; + } + + let value_idx = self + .ends + .binary_search(&(index as IdxSize)) + .map_or_else(identity, identity); + + self.get_unchecked(value_idx) + } + + pub fn min_reduce(&self) -> PolarsResult { + self.values.min_reduce() + } + pub fn max_reduce(&self) -> Result { + self.values.max_reduce() + } + + pub fn reverse(&self) -> Self { + let values = self.values.reverse(); + let mut ends = Vec::with_capacity(self.ends.len()); + + let mut offset = 0; + ends.extend(self.ends.windows(2).rev().map(|vs| { + offset += vs[1] - vs[0]; + offset + })); + ends.push(self.len() as IdxSize); + + unsafe { Self::new_unchecked(self.name.clone(), values, ends.into()) } + } + + pub fn set_sorted_flag(&mut self, sorted: IsSorted) { + self.values.set_sorted_flag(sorted); + } + + pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + let values = self.values.cast_with_options(dtype, options)?; + Ok(unsafe { Self::new_unchecked(self.name.clone(), values, self.ends.clone()) }) + } + + pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { + let values = self.values.strict_cast(dtype)?; + Ok(unsafe { Self::new_unchecked(self.name.clone(), values, self.ends.clone()) }) + } + + pub fn cast(&self, dtype: &DataType) -> PolarsResult { + let values = self.values.cast(dtype)?; + Ok(unsafe { Self::new_unchecked(self.name.clone(), values, self.ends.clone()) }) + } + + pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + let values = unsafe { self.values.cast_unchecked(dtype) }?; + Ok(unsafe { Self::new_unchecked(self.name.clone(), values, self.ends.clone()) }) + } + + pub fn null_count(&self) -> usize { + match self.lazy_as_materialized_series() { + Some(s) => s.null_count(), + None => { + // @partition-opt + self.as_materialized_series().null_count() + }, + } + } + + pub fn clear(&self) -> Self { + Self::new_empty(self.name.clone(), self.values.dtype().clone()) + } + + pub fn partitions(&self) -> &Series { + &self.values + } + pub fn partition_ends(&self) -> &[IdxSize] { + &self.ends + } + + pub fn or_reduce(&self) -> PolarsResult { + self.values.or_reduce() + } + + pub fn and_reduce(&self) -> PolarsResult { + self.values.and_reduce() + } +} diff --git a/crates/polars-core/src/frame/column/scalar.rs b/crates/polars-core/src/frame/column/scalar.rs new file mode 100644 index 000000000000..724f427d469a --- /dev/null +++ b/crates/polars-core/src/frame/column/scalar.rs @@ -0,0 +1,389 @@ +use std::sync::OnceLock; + +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; + +use super::{AnyValue, Column, DataType, IntoColumn, Scalar, Series}; +use crate::chunked_array::cast::CastOptions; + +/// A [`Column`] that consists of a repeated [`Scalar`] +/// +/// This is lazily materialized into a [`Series`]. +#[derive(Debug, Clone)] +pub struct ScalarColumn { + name: PlSmallStr, + // The value of this scalar may be incoherent when `length == 0`. + scalar: Scalar, + length: usize, + + // invariants: + // materialized.name() == name + // materialized.len() == length + // materialized.dtype() == value.dtype + // materialized[i] == value, for all 0 <= i < length + /// A lazily materialized [`Series`] variant of this [`ScalarColumn`] + materialized: OnceLock, +} + +impl ScalarColumn { + #[inline] + pub fn new(name: PlSmallStr, scalar: Scalar, length: usize) -> Self { + Self { + name, + scalar, + length, + + materialized: OnceLock::new(), + } + } + + #[inline] + pub fn new_empty(name: PlSmallStr, dtype: DataType) -> Self { + Self { + name, + scalar: Scalar::new(dtype, AnyValue::Null), + length: 0, + + materialized: OnceLock::new(), + } + } + + pub fn name(&self) -> &PlSmallStr { + &self.name + } + + pub fn scalar(&self) -> &Scalar { + &self.scalar + } + + pub fn dtype(&self) -> &DataType { + self.scalar.dtype() + } + + pub fn len(&self) -> usize { + self.length + } + + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + fn _to_series(name: PlSmallStr, value: Scalar, length: usize) -> Series { + let series = if length == 0 { + Series::new_empty(name, value.dtype()) + } else { + value.into_series(name).new_from_index(0, length) + }; + + debug_assert_eq!(series.len(), length); + + series + } + + /// Materialize the [`ScalarColumn`] into a [`Series`]. + pub fn to_series(&self) -> Series { + Self::_to_series(self.name.clone(), self.scalar.clone(), self.length) + } + + /// Get the [`ScalarColumn`] as [`Series`] if it was already materialized. + pub fn lazy_as_materialized_series(&self) -> Option<&Series> { + self.materialized.get() + } + + /// Get the [`ScalarColumn`] as [`Series`] + /// + /// This needs to materialize upon the first call. Afterwards, this is cached. + pub fn as_materialized_series(&self) -> &Series { + self.materialized.get_or_init(|| self.to_series()) + } + + /// Take the [`ScalarColumn`] and materialize as a [`Series`] if not already done. + pub fn take_materialized_series(self) -> Series { + self.materialized + .into_inner() + .unwrap_or_else(|| Self::_to_series(self.name, self.scalar, self.length)) + } + + /// Take the [`ScalarColumn`] as a series with a single value. + /// + /// If the [`ScalarColumn`] has `length=0` the resulting `Series` will also have `length=0`. + pub fn as_single_value_series(&self) -> Series { + self.as_n_values_series(1) + } + + /// Take the [`ScalarColumn`] as a series with a `n` values. + /// + /// If the [`ScalarColumn`] has `length=0` the resulting `Series` will also have `length=0`. + pub fn as_n_values_series(&self, n: usize) -> Series { + let length = usize::min(n, self.length); + + match self.materialized.get() { + // Don't take a refcount if we only want length-1 (or empty) - the materialized series + // could be extremely large. + Some(s) if length == self.length || length > 1 => s.head(Some(length)), + _ => Self::_to_series(self.name.clone(), self.scalar.clone(), length), + } + } + + /// Create a new [`ScalarColumn`] from a `length=1` Series and expand it `length`. + /// + /// This will panic if the value cannot be made static or if the series has length `0`. + #[inline] + pub fn unit_scalar_from_series(series: Series) -> Self { + assert_eq!(series.len(), 1); + // SAFETY: We just did the bounds check + let value = unsafe { series.get_unchecked(0) }; + let value = value.into_static(); + let value = Scalar::new(series.dtype().clone(), value); + let mut sc = ScalarColumn::new(series.name().clone(), value, 1); + sc.materialized = OnceLock::from(series); + sc + } + + /// Create a new [`ScalarColumn`] from a `length=1` Series and expand it `length`. + /// + /// This will panic if the value cannot be made static. + pub fn from_single_value_series(series: Series, length: usize) -> Self { + debug_assert!(series.len() <= 1); + + let value = if series.is_empty() { + AnyValue::Null + } else { + unsafe { series.get_unchecked(0) }.into_static() + }; + let value = Scalar::new(series.dtype().clone(), value); + ScalarColumn::new(series.name().clone(), value, length) + } + + /// Resize the [`ScalarColumn`] to new `length`. + /// + /// This reuses the materialized [`Series`], if `length <= self.length`. + pub fn resize(&self, length: usize) -> ScalarColumn { + if self.length == length { + return self.clone(); + } + + // This is violates an invariant if this triggers, the scalar value is undefined if the + // self.length == 0 so therefore we should never resize using that value. + debug_assert!(length == 0 || self.length > 0); + + let mut resized = Self { + name: self.name.clone(), + scalar: self.scalar.clone(), + length, + materialized: OnceLock::new(), + }; + + if length == self.length || (length < self.length && length > 1) { + if let Some(materialized) = self.materialized.get() { + resized.materialized = OnceLock::from(materialized.head(Some(length))); + debug_assert_eq!(resized.materialized.get().unwrap().len(), length); + } + } + + resized + } + + pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + // @NOTE: We expect that when casting the materialized series mostly does not need change + // the physical array. Therefore, we try to cast the entire materialized array if it is + // available. + + match self.materialized.get() { + Some(s) => { + let materialized = s.cast_with_options(dtype, options)?; + assert_eq!(self.length, materialized.len()); + + let mut casted = if materialized.is_empty() { + Self::new_empty(materialized.name().clone(), materialized.dtype().clone()) + } else { + // SAFETY: Just did bounds check + let scalar = unsafe { materialized.get_unchecked(0) }.into_static(); + Self::new( + materialized.name().clone(), + Scalar::new(materialized.dtype().clone(), scalar), + self.length, + ) + }; + casted.materialized = OnceLock::from(materialized); + Ok(casted) + }, + None => { + let s = self + .as_single_value_series() + .cast_with_options(dtype, options)?; + + if self.length == 0 { + Ok(Self::new_empty(s.name().clone(), s.dtype().clone())) + } else { + assert_eq!(1, s.len()); + Ok(Self::from_single_value_series(s, self.length)) + } + }, + } + } + + pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Strict) + } + pub fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + // @NOTE: We expect that when casting the materialized series mostly does not need change + // the physical array. Therefore, we try to cast the entire materialized array if it is + // available. + + match self.materialized.get() { + Some(s) => { + let materialized = s.cast_unchecked(dtype)?; + assert_eq!(self.length, materialized.len()); + + let mut casted = if materialized.is_empty() { + Self::new_empty(materialized.name().clone(), materialized.dtype().clone()) + } else { + // SAFETY: Just did bounds check + let scalar = unsafe { materialized.get_unchecked(0) }.into_static(); + Self::new( + materialized.name().clone(), + Scalar::new(materialized.dtype().clone(), scalar), + self.length, + ) + }; + casted.materialized = OnceLock::from(materialized); + Ok(casted) + }, + None => { + let s = self.as_single_value_series().cast_unchecked(dtype)?; + assert_eq!(1, s.len()); + + if self.length == 0 { + Ok(Self::new_empty(s.name().clone(), s.dtype().clone())) + } else { + Ok(Self::from_single_value_series(s, self.length)) + } + }, + } + } + + pub fn rename(&mut self, name: PlSmallStr) -> &mut Self { + if let Some(series) = self.materialized.get_mut() { + series.rename(name.clone()); + } + + self.name = name; + self + } + + pub fn has_nulls(&self) -> bool { + self.length != 0 && self.scalar.is_null() + } + + pub fn drop_nulls(&self) -> Self { + if self.scalar.is_null() { + self.resize(0) + } else { + self.clone() + } + } + + pub fn into_nulls(mut self) -> Self { + self.scalar.update(AnyValue::Null); + self + } + + pub fn map_scalar(&mut self, map_scalar: impl Fn(Scalar) -> Scalar) { + self.scalar = map_scalar(std::mem::take(&mut self.scalar)); + self.materialized.take(); + } + pub fn with_value(&mut self, value: AnyValue<'static>) -> &mut Self { + self.scalar.update(value); + self.materialized.take(); + self + } +} + +impl IntoColumn for ScalarColumn { + #[inline(always)] + fn into_column(self) -> Column { + self.into() + } +} + +impl From for Column { + #[inline] + fn from(value: ScalarColumn) -> Self { + Self::Scalar(value) + } +} + +#[cfg(feature = "serde")] +mod serde_impl { + use std::sync::OnceLock; + + use polars_error::PolarsError; + use polars_utils::pl_str::PlSmallStr; + + use super::ScalarColumn; + use crate::frame::{Scalar, Series}; + + #[derive(serde::Serialize, serde::Deserialize)] + struct SerializeWrap { + name: PlSmallStr, + /// Unit-length series for dispatching to IPC serialize + unit_series: Series, + length: usize, + } + + impl From<&ScalarColumn> for SerializeWrap { + fn from(value: &ScalarColumn) -> Self { + Self { + name: value.name.clone(), + unit_series: value.scalar.clone().into_series(PlSmallStr::EMPTY), + length: value.length, + } + } + } + + impl TryFrom for ScalarColumn { + type Error = PolarsError; + + fn try_from(value: SerializeWrap) -> Result { + let slf = Self { + name: value.name, + scalar: Scalar::new( + value.unit_series.dtype().clone(), + value.unit_series.get(0)?.into_static(), + ), + length: value.length, + materialized: OnceLock::new(), + }; + + Ok(slf) + } + } + + impl serde::ser::Serialize for ScalarColumn { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + SerializeWrap::from(self).serialize(serializer) + } + } + + impl<'de> serde::de::Deserialize<'de> for ScalarColumn { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + SerializeWrap::deserialize(deserializer) + .and_then(|x| ScalarColumn::try_from(x).map_err(D::Error::custom)) + } + } +} diff --git a/crates/polars-core/src/frame/column/series.rs b/crates/polars-core/src/frame/column/series.rs new file mode 100644 index 000000000000..d7c7e1b5b773 --- /dev/null +++ b/crates/polars-core/src/frame/column/series.rs @@ -0,0 +1,73 @@ +use std::ops::{Deref, DerefMut}; + +use super::Series; + +/// A very thin wrapper around [`Series`] that represents a [`Column`]ized version of [`Series`]. +/// +/// At the moment this just conditionally tracks where it was created so that materialization +/// problems can be tracked down. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct SeriesColumn { + inner: Series, + + #[cfg(debug_assertions)] + #[cfg_attr(feature = "serde", serde(skip))] + materialized_at: Option>, +} + +impl SeriesColumn { + #[track_caller] + pub fn new(series: Series) -> Self { + Self { + inner: series, + + #[cfg(debug_assertions)] + materialized_at: if std::env::var("POLARS_TRACK_SERIES_MATERIALIZATION").as_deref() + == Ok("1") + { + Some(std::sync::Arc::new( + std::backtrace::Backtrace::force_capture(), + )) + } else { + None + }, + } + } + + pub fn materialized_at(&self) -> Option<&std::backtrace::Backtrace> { + #[cfg(debug_assertions)] + { + self.materialized_at.as_ref().map(|v| v.as_ref()) + } + + #[cfg(not(debug_assertions))] + None + } + + pub fn take(self) -> Series { + self.inner + } +} + +impl From for SeriesColumn { + #[track_caller] + #[inline(always)] + fn from(value: Series) -> Self { + Self::new(value) + } +} + +impl Deref for SeriesColumn { + type Target = Series; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for SeriesColumn { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs new file mode 100644 index 000000000000..cc4ec164b112 --- /dev/null +++ b/crates/polars-core/src/frame/explode.rs @@ -0,0 +1,308 @@ +use arrow::offset::OffsetsBuffer; +use polars_utils::pl_str::PlSmallStr; +use rayon::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::POOL; +use crate::chunked_array::ops::explode::offsets_to_indexes; +use crate::prelude::*; +use crate::series::IsSorted; + +fn get_exploded(series: &Series) -> PolarsResult<(Series, OffsetsBuffer)> { + match series.dtype() { + DataType::List(_) => series.list().unwrap().explode_and_offsets(), + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => series.array().unwrap().explode_and_offsets(), + _ => polars_bail!(opq = explode, series.dtype()), + } +} + +/// Arguments for `LazyFrame::unpivot` function +#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnpivotArgsIR { + pub on: Vec, + pub index: Vec, + pub variable_name: Option, + pub value_name: Option, +} + +impl DataFrame { + pub fn explode_impl(&self, mut columns: Vec) -> PolarsResult { + polars_ensure!(!columns.is_empty(), InvalidOperation: "no columns provided in explode"); + let mut df = self.clone(); + if self.is_empty() { + for s in &columns { + df.with_column(s.as_materialized_series().explode()?)?; + } + return Ok(df); + } + columns.sort_by(|sa, sb| { + self.check_name_to_idx(sa.name().as_str()) + .expect("checked above") + .partial_cmp( + &self + .check_name_to_idx(sb.name().as_str()) + .expect("checked above"), + ) + .expect("cmp usize -> Ordering") + }); + + // first remove all the exploded columns + for s in &columns { + df = df.drop(s.name().as_str())?; + } + + let exploded_columns = POOL.install(|| { + columns + .par_iter() + .map(Column::as_materialized_series) + .map(get_exploded) + .map(|s| s.map(|(s, o)| (Column::from(s), o))) + .collect::>>() + })?; + + fn process_column( + original_df: &DataFrame, + df: &mut DataFrame, + exploded: Column, + ) -> PolarsResult<()> { + if exploded.len() == df.height() || df.width() == 0 { + let col_idx = original_df.check_name_to_idx(exploded.name().as_str())?; + df.columns.insert(col_idx, exploded); + } else { + polars_bail!( + ShapeMismatch: "exploded column(s) {:?} doesn't have the same length: {} \ + as the dataframe: {}", exploded.name(), exploded.name(), df.height(), + ); + } + Ok(()) + } + + let check_offsets = || { + let first_offsets = exploded_columns[0].1.as_slice(); + for (_, offsets) in &exploded_columns[1..] { + let offsets = offsets.as_slice(); + + let offset_l = first_offsets[0]; + let offset_r = offsets[0]; + let all_equal_len = first_offsets.len() != offsets.len() || { + first_offsets + .iter() + .zip(offsets.iter()) + .all(|(l, r)| (*l - offset_l) == (*r - offset_r)) + }; + + polars_ensure!(all_equal_len, + ShapeMismatch: "exploded columns must have matching element counts" + ) + } + Ok(()) + }; + let process_first = || { + let (exploded, offsets) = &exploded_columns[0]; + + let row_idx = offsets_to_indexes(offsets.as_slice(), exploded.len()); + let mut row_idx = IdxCa::from_vec(PlSmallStr::EMPTY, row_idx); + row_idx.set_sorted_flag(IsSorted::Ascending); + + // SAFETY: + // We just created indices that are in bounds. + let mut df = unsafe { df.take_unchecked(&row_idx) }; + process_column(self, &mut df, exploded.clone())?; + PolarsResult::Ok(df) + }; + let (df, result) = POOL.join(process_first, check_offsets); + let mut df = df?; + result?; + + for (exploded, _) in exploded_columns.into_iter().skip(1) { + process_column(self, &mut df, exploded)? + } + + Ok(df) + } + /// Explode `DataFrame` to long format by exploding a column with Lists. + /// + /// # Example + /// + /// ```ignore + /// # use polars_core::prelude::*; + /// let s0 = Series::new("a".into(), &[1i64, 2, 3]); + /// let s1 = Series::new("b".into(), &[1i64, 1, 1]); + /// let s2 = Series::new("c".into(), &[2i64, 2, 2]); + /// let list = Series::new("foo", &[s0, s1, s2]); + /// + /// let s0 = Series::new("B".into(), [1, 2, 3]); + /// let s1 = Series::new("C".into(), [1, 1, 1]); + /// let df = DataFrame::new(vec![list, s0, s1])?; + /// let exploded = df.explode(["foo"])?; + /// + /// println!("{:?}", df); + /// println!("{:?}", exploded); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Outputs: + /// + /// ```text + /// +-------------+-----+-----+ + /// | foo | B | C | + /// | --- | --- | --- | + /// | list [i64] | i32 | i32 | + /// +=============+=====+=====+ + /// | "[1, 2, 3]" | 1 | 1 | + /// +-------------+-----+-----+ + /// | "[1, 1, 1]" | 2 | 1 | + /// +-------------+-----+-----+ + /// | "[2, 2, 2]" | 3 | 1 | + /// +-------------+-----+-----+ + /// + /// +-----+-----+-----+ + /// | foo | B | C | + /// | --- | --- | --- | + /// | i64 | i32 | i32 | + /// +=====+=====+=====+ + /// | 1 | 1 | 1 | + /// +-----+-----+-----+ + /// | 2 | 1 | 1 | + /// +-----+-----+-----+ + /// | 3 | 1 | 1 | + /// +-----+-----+-----+ + /// | 1 | 2 | 1 | + /// +-----+-----+-----+ + /// | 1 | 2 | 1 | + /// +-----+-----+-----+ + /// | 1 | 2 | 1 | + /// +-----+-----+-----+ + /// | 2 | 3 | 1 | + /// +-----+-----+-----+ + /// | 2 | 3 | 1 | + /// +-----+-----+-----+ + /// | 2 | 3 | 1 | + /// +-----+-----+-----+ + /// ``` + pub fn explode(&self, columns: I) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + // We need to sort the column by order of original occurrence. Otherwise the insert by index + // below will panic + let columns = self.select_columns(columns)?; + self.explode_impl(columns) + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + #[cfg(feature = "dtype-i8")] + #[cfg_attr(miri, ignore)] + fn test_explode() { + let s0 = Series::new(PlSmallStr::from_static("a"), &[1i8, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("b"), &[1i8, 1, 1]); + let s2 = Series::new(PlSmallStr::from_static("c"), &[2i8, 2, 2]); + let list = Column::new(PlSmallStr::from_static("foo"), &[s0, s1, s2]); + + let s0 = Column::new(PlSmallStr::from_static("B"), [1, 2, 3]); + let s1 = Column::new(PlSmallStr::from_static("C"), [1, 1, 1]); + let df = DataFrame::new(vec![list, s0.clone(), s1.clone()]).unwrap(); + let exploded = df.explode(["foo"]).unwrap(); + assert_eq!(exploded.shape(), (9, 3)); + assert_eq!( + exploded + .column("C") + .unwrap() + .as_materialized_series() + .i32() + .unwrap() + .get(8), + Some(1) + ); + assert_eq!( + exploded + .column("B") + .unwrap() + .as_materialized_series() + .i32() + .unwrap() + .get(8), + Some(3) + ); + assert_eq!( + exploded + .column("foo") + .unwrap() + .as_materialized_series() + .i8() + .unwrap() + .get(8), + Some(2) + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_explode_df_empty_list() -> PolarsResult<()> { + let s0 = Series::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("b"), &[1, 1, 1]); + let list = Column::new( + PlSmallStr::from_static("foo"), + &[s0, s1.clone(), s1.clear()], + ); + let s0 = Column::new(PlSmallStr::from_static("B"), [1, 2, 3]); + let s1 = Column::new(PlSmallStr::from_static("C"), [1, 1, 1]); + let df = DataFrame::new(vec![list, s0.clone(), s1.clone()])?; + + let out = df.explode(["foo"])?; + let expected = df![ + "foo" => [Some(1), Some(2), Some(3), Some(1), Some(1), Some(1), None], + "B" => [1, 1, 1, 2, 2, 2, 3], + "C" => [1, 1, 1, 1, 1, 1, 1], + ]?; + + assert!(out.equals_missing(&expected)); + + let list = Column::new( + PlSmallStr::from_static("foo"), + [ + s0.as_materialized_series().clone(), + s1.as_materialized_series().clear(), + s1.as_materialized_series().clone(), + ], + ); + let df = DataFrame::new(vec![list, s0, s1])?; + let out = df.explode(["foo"])?; + let expected = df![ + "foo" => [Some(1), Some(2), Some(3), None, Some(1), Some(1), Some(1)], + "B" => [1, 1, 1, 2, 3, 3, 3], + "C" => [1, 1, 1, 1, 1, 1, 1], + ]?; + + assert!(out.equals_missing(&expected)); + Ok(()) + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_explode_single_col() -> PolarsResult<()> { + let s0 = Series::new(PlSmallStr::from_static("a"), &[1i32, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("b"), &[1i32, 1, 1]); + let list = Column::new(PlSmallStr::from_static("foo"), &[s0, s1]); + let df = DataFrame::new(vec![list])?; + + let out = df.explode(["foo"])?; + let out = out + .column("foo")? + .as_materialized_series() + .i32()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[1i32, 2, 3, 1, 1, 1]); + + Ok(()) + } +} diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs new file mode 100644 index 000000000000..060f4f4641c7 --- /dev/null +++ b/crates/polars-core/src/frame/from.rs @@ -0,0 +1,31 @@ +use crate::prelude::*; + +impl TryFrom for DataFrame { + type Error = PolarsError; + + fn try_from(arr: StructArray) -> PolarsResult { + let (fld, _length, arrs, nulls) = arr.into_data(); + polars_ensure!( + nulls.is_none(), + ComputeError: "cannot deserialize struct with nulls into a DataFrame" + ); + let columns = fld + .iter() + .zip(arrs) + .map(|(fld, arr)| { + // SAFETY: + // reported data type is correct + unsafe { + Series::_try_from_arrow_unchecked_with_md( + fld.name.clone(), + vec![arr], + fld.dtype(), + fld.metadata.as_deref(), + ) + } + .map(Column::from) + }) + .collect::>>()?; + DataFrame::new(columns) + } +} diff --git a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs new file mode 100644 index 000000000000..3913b05b755e --- /dev/null +++ b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs @@ -0,0 +1,337 @@ +use arrow::offset::Offsets; + +use super::*; +use crate::chunked_array::builder::ListNullChunkedBuilder; +use crate::series::implementations::null::NullChunked; + +pub trait AggList { + /// # Safety + /// + /// groups should be in bounds + unsafe fn agg_list(&self, _groups: &GroupsType) -> Series; +} + +impl AggList for ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + let ca = self.rechunk(); + + match groups { + GroupsType::Idx(groups) => { + let mut can_fast_explode = true; + + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values(); + + let mut offsets = Vec::::with_capacity(groups.len() + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let mut list_values = Vec::::with_capacity(self.len()); + groups.iter().for_each(|(_, idx)| { + let idx_len = idx.len(); + if idx_len == 0 { + can_fast_explode = false; + } + + length_so_far += idx_len as i64; + // SAFETY: + // group tuples are in bounds + { + list_values.extend(idx.iter().map(|idx| { + debug_assert!((*idx as usize) < values.len()); + *values.get_unchecked(*idx as usize) + })); + // SAFETY: + // we know that offsets has allocated enough slots + offsets.push_unchecked(length_so_far); + } + }); + + let validity = if arr.null_count() > 0 { + let old_validity = arr.validity().unwrap(); + let mut validity = MutableBitmap::from_len_set(list_values.len()); + + let mut count = 0; + groups.iter().for_each(|(_, idx)| { + for i in idx.as_slice() { + if !old_validity.get_bit_unchecked(*i as usize) { + validity.set_unchecked(count, false); + } + count += 1; + } + }); + Some(validity.into()) + } else { + None + }; + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + list_values.into(), + validity, + ); + let dtype = ListArray::::default_datatype( + T::get_dtype().to_arrow(CompatLevel::newest()), + ); + // SAFETY: + // offsets are monotonically increasing + let arr = ListArray::::new( + dtype, + Offsets::new_unchecked(offsets).into(), + Box::new(array), + None, + ); + + let mut ca = ListChunked::with_chunk(self.name().clone(), arr); + if can_fast_explode { + ca.set_fast_explode() + } + ca.into() + }, + GroupsType::Slice { groups, .. } => { + let mut can_fast_explode = true; + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values(); + + let mut offsets = Vec::::with_capacity(groups.len() + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let mut list_values = Vec::::with_capacity(self.len()); + groups.iter().for_each(|&[first, len]| { + if len == 0 { + can_fast_explode = false; + } + + length_so_far += len as i64; + list_values.extend_from_slice(&values[first as usize..(first + len) as usize]); + { + // SAFETY: + // we know that offsets has allocated enough slots + offsets.push_unchecked(length_so_far); + } + }); + + let validity = if arr.null_count() > 0 { + let old_validity = arr.validity().unwrap(); + let mut validity = MutableBitmap::from_len_set(list_values.len()); + + let mut count = 0; + groups.iter().for_each(|[first, len]| { + for i in *first..(*first + *len) { + if !old_validity.get_bit_unchecked(i as usize) { + validity.set_unchecked(count, false) + } + count += 1; + } + }); + Some(validity.into()) + } else { + None + }; + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + list_values.into(), + validity, + ); + let dtype = ListArray::::default_datatype( + T::get_dtype().to_arrow(CompatLevel::newest()), + ); + let arr = ListArray::::new( + dtype, + Offsets::new_unchecked(offsets).into(), + Box::new(array), + None, + ); + let mut ca = ListChunked::with_chunk(self.name().clone(), arr); + if can_fast_explode { + ca.set_fast_explode() + } + ca.into() + }, + } + } +} + +impl AggList for NullChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + match groups { + GroupsType::Idx(groups) => { + let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len()); + for idx in groups.all().iter() { + builder.append_with_len(idx.len()); + } + builder.finish().into_series() + }, + GroupsType::Slice { groups, .. } => { + let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len()); + for [_, len] in groups { + builder.append_with_len(*len as usize); + } + builder.finish().into_series() + }, + } + } +} + +impl AggList for BooleanChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + agg_list_by_gather_and_offsets(self, groups) + } +} + +impl AggList for StringChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + agg_list_by_gather_and_offsets(self, groups) + } +} + +impl AggList for BinaryChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + agg_list_by_gather_and_offsets(self, groups) + } +} + +impl AggList for ListChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + agg_list_by_gather_and_offsets(self, groups) + } +} + +#[cfg(feature = "dtype-array")] +impl AggList for ArrayChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + agg_list_by_gather_and_offsets(self, groups) + } +} + +#[cfg(feature = "object")] +impl AggList for ObjectChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + let mut can_fast_explode = true; + let mut offsets = Vec::::with_capacity(groups.len() + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + // we know that iterators length + let iter = { + groups + .iter() + .flat_map(|indicator| { + let (group_vals, len) = match indicator { + GroupsIndicator::Idx((_first, idx)) => { + // SAFETY: + // group tuples always in bounds + let group_vals = self.take_unchecked(idx); + + (group_vals, idx.len() as IdxSize) + }, + GroupsIndicator::Slice([first, len]) => { + let group_vals = _slice_from_offsets(self, first, len); + + (group_vals, len) + }, + }; + + if len == 0 { + can_fast_explode = false; + } + length_so_far += len as i64; + // SAFETY: + // we know that offsets has allocated enough slots + offsets.push_unchecked(length_so_far); + + let arr = group_vals.downcast_iter().next().unwrap().clone(); + arr.into_iter_cloned() + }) + .trust_my_length(self.len()) + }; + + let mut pe = create_extension(iter); + + // SAFETY: This is safe because we just created the PolarsExtension + // meaning that the sentinel is heap allocated and the dereference of + // the pointer does not fail. + pe.set_to_series_fn::(); + let extension_array = Box::new(pe.take_and_forget()) as ArrayRef; + let extension_dtype = extension_array.dtype(); + + let dtype = ListArray::::default_datatype(extension_dtype.clone()); + // SAFETY: offsets are monotonically increasing. + let arr = ListArray::::new( + dtype, + Offsets::new_unchecked(offsets).into(), + extension_array, + None, + ); + let mut listarr = ListChunked::with_chunk(self.name().clone(), arr); + if can_fast_explode { + listarr.set_fast_explode() + } + listarr.into_series() + } +} + +#[cfg(feature = "dtype-struct")] +impl AggList for StructChunked { + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + let ca = self.clone(); + let (gather, offsets, can_fast_explode) = groups.prepare_list_agg(self.len()); + + let gathered = if let Some(gather) = gather { + let out = ca.into_series().take_unchecked(&gather); + out.struct_().unwrap().clone() + } else { + ca.rechunk().into_owned() + }; + + let arr = gathered.chunks()[0].clone(); + let dtype = LargeListArray::default_datatype(arr.dtype().clone()); + + let mut chunk = ListChunked::with_chunk( + self.name().clone(), + LargeListArray::new(dtype, offsets, arr, None), + ); + chunk.set_dtype(DataType::List(Box::new(self.dtype().clone()))); + if can_fast_explode { + chunk.set_fast_explode() + } + + chunk.into_series() + } +} + +unsafe fn agg_list_by_gather_and_offsets( + ca: &ChunkedArray, + groups: &GroupsType, +) -> Series +where + ChunkedArray: ChunkTakeUnchecked, +{ + let (gather, offsets, can_fast_explode) = groups.prepare_list_agg(ca.len()); + + let gathered = if let Some(gather) = gather { + ca.take_unchecked(&gather) + } else { + ca.clone() + }; + + let arr = gathered.chunks()[0].clone(); + let dtype = LargeListArray::default_datatype(arr.dtype().clone()); + + let mut chunk = ListChunked::with_chunk( + ca.name().clone(), + LargeListArray::new(dtype, offsets, arr, None), + ); + chunk.set_dtype(DataType::List(Box::new(ca.dtype().clone()))); + if can_fast_explode { + chunk.set_fast_explode() + } + + chunk.into_series() +} diff --git a/crates/polars-core/src/frame/group_by/aggregations/boolean.rs b/crates/polars-core/src/frame/group_by/aggregations/boolean.rs new file mode 100644 index 000000000000..e1a478210129 --- /dev/null +++ b/crates/polars-core/src/frame/group_by/aggregations/boolean.rs @@ -0,0 +1,166 @@ +use std::borrow::Cow; + +use super::*; +use crate::chunked_array::cast::CastOptions; + +pub fn _agg_helper_idx_bool(groups: &GroupsIdx, f: F) -> Series +where + F: Fn((IdxSize, &IdxVec)) -> Option + Send + Sync, +{ + let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect()); + ca.into_series() +} + +pub fn _agg_helper_slice_bool(groups: &[[IdxSize; 2]], f: F) -> Series +where + F: Fn([IdxSize; 2]) -> Option + Send + Sync, +{ + let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect()); + ca.into_series() +} + +#[cfg(feature = "bitwise")] +unsafe fn bitwise_agg( + ca: &BooleanChunked, + groups: &GroupsType, + f: fn(&BooleanChunked) -> Option, +) -> Series { + // Prevent a rechunk for every individual group. + + let s = if groups.len() > 1 { + ca.rechunk() + } else { + Cow::Borrowed(ca) + }; + + match groups { + GroupsType::Idx(groups) => _agg_helper_idx_bool::<_>(groups, |(_, idx)| { + debug_assert!(idx.len() <= s.len()); + if idx.is_empty() { + None + } else { + let take = s.take_unchecked(idx); + f(&take) + } + }), + GroupsType::Slice { groups, .. } => _agg_helper_slice_bool::<_>(groups, |[first, len]| { + debug_assert!(len <= s.len() as IdxSize); + if len == 0 { + None + } else { + let take = _slice_from_offsets(&s, first, len); + f(&take) + } + }), + } +} + +#[cfg(feature = "bitwise")] +impl BooleanChunked { + pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series { + bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) + } + + pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series { + bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) + } + + pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { + bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) + } +} + +impl BooleanChunked { + pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + // faster paths + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_first(groups); + }, + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_last(groups); + }, + _ => {}, + } + let ca_self = self.rechunk(); + let arr = ca_self.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + match groups { + GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize) + } else if no_nulls { + take_min_bool_iter_unchecked_no_nulls(arr, idx2usize(idx)) + } else { + take_min_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize) + } + }), + GroupsType::Slice { + groups: groups_slice, + .. + } => _agg_helper_slice_bool(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.min() + }, + } + }), + } + } + pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + // faster paths + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_last(groups); + }, + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_first(groups); + }, + _ => {}, + } + + let ca_self = self.rechunk(); + let arr = ca_self.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + match groups { + GroupsType::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + self.get(first as usize) + } else if no_nulls { + take_max_bool_iter_unchecked_no_nulls(arr, idx2usize(idx)) + } else { + take_max_bool_iter_unchecked_nulls(arr, idx2usize(idx), idx.len() as IdxSize) + } + }), + GroupsType::Slice { + groups: groups_slice, + .. + } => _agg_helper_slice_bool(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.max() + }, + } + }), + } + } + pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + self.cast_with_options(&IDX_DTYPE, CastOptions::Overflowing) + .unwrap() + .agg_sum(groups) + } +} diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs new file mode 100644 index 000000000000..336902228e91 --- /dev/null +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -0,0 +1,304 @@ +use super::*; + +// implemented on the series because we don't need types +impl Series { + fn slice_from_offsets(&self, first: IdxSize, len: IdxSize) -> Self { + self.slice(first as i64, len as usize) + } + + fn restore_logical(&self, out: Series) -> Series { + if self.dtype().is_logical() { + out.cast(self.dtype()).unwrap() + } else { + out + } + } + + #[doc(hidden)] + pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 && self.null_count() > 0 { + self.rechunk() + } else { + self.clone() + }; + + match groups { + GroupsType::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= s.len()); + if idx.is_empty() { + None + } else if s.null_count() == 0 { + Some(idx.len() as IdxSize) + } else { + let take = unsafe { s.take_slice_unchecked(idx) }; + Some((take.len() - take.null_count()) as IdxSize) + } + }), + GroupsType::Slice { groups, .. } => { + _agg_helper_slice::(groups, |[first, len]| { + debug_assert!(len <= s.len() as IdxSize); + if len == 0 { + None + } else if s.null_count() == 0 { + Some(len) + } else { + let take = s.slice_from_offsets(first, len); + Some((take.len() - take.null_count()) as IdxSize) + } + }) + }, + } + } + + #[doc(hidden)] + pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + + let mut out = match groups { + GroupsType::Idx(groups) => { + let indices = groups + .iter() + .map( + |(first, idx)| { + if idx.is_empty() { None } else { Some(first) } + }, + ) + .collect_ca(PlSmallStr::EMPTY); + // SAFETY: groups are always in bounds. + s.take_unchecked(&indices) + }, + GroupsType::Slice { groups, .. } => { + let indices = groups + .iter() + .map(|&[first, len]| if len == 0 { None } else { Some(first) }) + .collect_ca(PlSmallStr::EMPTY); + // SAFETY: groups are always in bounds. + s.take_unchecked(&indices) + }, + }; + if groups.is_sorted_flag() { + out.set_sorted_flag(s.is_sorted_flag()) + } + s.restore_logical(out) + } + + #[doc(hidden)] + pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + + match groups { + GroupsType::Idx(groups) => agg_helper_idx_on_all_no_null::(groups, |idx| { + debug_assert!(idx.len() <= s.len()); + if idx.is_empty() { + 0 + } else { + let take = s.take_slice_unchecked(idx); + take.n_unique().unwrap() as IdxSize + } + }), + GroupsType::Slice { groups, .. } => { + _agg_helper_slice_no_null::(groups, |[first, len]| { + debug_assert!(len <= s.len() as IdxSize); + if len == 0 { + 0 + } else { + let take = s.slice_from_offsets(first, len); + take.n_unique().unwrap() as IdxSize + } + }) + }, + } + } + + #[doc(hidden)] + pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + + use DataType::*; + match s.dtype() { + Boolean => s.cast(&Float64).unwrap().agg_mean(groups), + Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups), + Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups), + dt if dt.is_primitive_numeric() => apply_method_physical_integer!(s, agg_mean, groups), + #[cfg(feature = "dtype-datetime")] + dt @ Datetime(_, _) => self + .to_physical_repr() + .agg_mean(groups) + .cast(&Int64) + .unwrap() + .cast(dt) + .unwrap(), + #[cfg(feature = "dtype-duration")] + dt @ Duration(_) => self + .to_physical_repr() + .agg_mean(groups) + .cast(&Int64) + .unwrap() + .cast(dt) + .unwrap(), + #[cfg(feature = "dtype-time")] + Time => self + .to_physical_repr() + .agg_mean(groups) + .cast(&Int64) + .unwrap() + .cast(&Time) + .unwrap(), + #[cfg(feature = "dtype-date")] + Date => (self + .to_physical_repr() + .agg_mean(groups) + .cast(&Float64) + .unwrap() + * (MS_IN_DAY as f64)) + .cast(&Datetime(TimeUnit::Milliseconds, None)) + .unwrap(), + _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()), + } + } + + #[doc(hidden)] + pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + + use DataType::*; + match s.dtype() { + Boolean => s.cast(&Float64).unwrap().agg_median(groups), + Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups), + Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups), + dt if dt.is_primitive_numeric() => { + apply_method_physical_integer!(s, agg_median, groups) + }, + #[cfg(feature = "dtype-datetime")] + dt @ Datetime(_, _) => self + .to_physical_repr() + .agg_median(groups) + .cast(&Int64) + .unwrap() + .cast(dt) + .unwrap(), + #[cfg(feature = "dtype-duration")] + dt @ Duration(_) => self + .to_physical_repr() + .agg_median(groups) + .cast(&Int64) + .unwrap() + .cast(dt) + .unwrap(), + #[cfg(feature = "dtype-time")] + Time => self + .to_physical_repr() + .agg_median(groups) + .cast(&Int64) + .unwrap() + .cast(&Time) + .unwrap(), + #[cfg(feature = "dtype-date")] + Date => (self + .to_physical_repr() + .agg_median(groups) + .cast(&Float64) + .unwrap() + * (MS_IN_DAY as f64)) + .cast(&Datetime(TimeUnit::Milliseconds, None)) + .unwrap(), + _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()), + } + } + + #[doc(hidden)] + pub unsafe fn agg_quantile( + &self, + groups: &GroupsType, + quantile: f64, + method: QuantileMethod, + ) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + + use DataType::*; + match s.dtype() { + Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method), + Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method), + dt if dt.is_primitive_numeric() || dt.is_temporal() => { + let ca = s.to_physical_repr(); + let physical_type = ca.dtype(); + let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, method); + if dt.is_logical() { + // back to physical and then + // back to logical type + s.cast(physical_type).unwrap().cast(dt).unwrap() + } else { + s + } + }, + _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()), + } + } + + #[doc(hidden)] + pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series { + // Prevent a rechunk for every individual group. + let s = if groups.len() > 1 { + self.rechunk() + } else { + self.clone() + }; + + let out = match groups { + GroupsType::Idx(groups) => { + let indices = groups + .all() + .iter() + .map(|idx| { + if idx.is_empty() { + None + } else { + Some(idx[idx.len() - 1]) + } + }) + .collect_ca(PlSmallStr::EMPTY); + s.take_unchecked(&indices) + }, + GroupsType::Slice { groups, .. } => { + let indices = groups + .iter() + .map(|&[first, len]| { + if len == 0 { + None + } else { + Some(first + len - 1) + } + }) + .collect_ca(PlSmallStr::EMPTY); + s.take_unchecked(&indices) + }, + }; + s.restore_logical(out) + } +} diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs new file mode 100644 index 000000000000..c029adf1f5af --- /dev/null +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -0,0 +1,1125 @@ +mod agg_list; +mod boolean; +mod dispatch; +mod string; + +use std::borrow::Cow; + +pub use agg_list::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::legacy::kernels::take_agg::*; +use arrow::legacy::trusted_len::TrustedLenPush; +use arrow::types::NativeType; +use num_traits::pow::Pow; +use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero}; +use polars_compute::rolling::no_nulls::{ + MaxWindow, MeanWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls, + SumWindow, +}; +use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment}; +use polars_compute::rolling::quantile_filter::SealedRolling; +use polars_compute::rolling::{ + self, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams, quantile_filter, +}; +use polars_utils::float::IsFloat; +use polars_utils::idx_vec::IdxVec; +use polars_utils::min_max::MinMax; +use rayon::prelude::*; + +use crate::chunked_array::cast::CastOptions; +#[cfg(feature = "object")] +use crate::chunked_array::object::extension::create_extension; +use crate::frame::group_by::GroupsIdx; +#[cfg(feature = "object")] +use crate::frame::group_by::GroupsIndicator; +use crate::prelude::*; +use crate::series::IsSorted; +use crate::series::implementations::SeriesWrap; +use crate::utils::NoNull; +use crate::{POOL, apply_method_physical_integer}; + +fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator + '_ { + idx.iter().map(|i| *i as usize) +} + +// if the windows overlap, we can use the rolling_ kernels +// they maintain state, which saves a lot of compute by not naively traversing all elements every +// window +// +// if the windows don't overlap, we should not use these kernels as they are single threaded, so +// we miss out on easy parallelization. +pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool { + match groups.len() { + 0 | 1 => false, + _ => { + let [first_offset, first_len] = groups[0]; + let second_offset = groups[1][0]; + + second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices. + // Rolling group-by is expected to have monotonically increasing slices. + && second_offset < (first_offset + first_len) + && chunks.len() == 1 + }, + } +} + +// Use an aggregation window that maintains the state +pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>( + values: &'a [T], + validity: &'a Bitmap, + offsets: O, + params: Option, +) -> PrimitiveArray +where + O: Iterator + TrustedLen, + Agg: RollingAggWindowNulls<'a, T>, + T: IsFloat + NativeType, +{ + if values.is_empty() { + let out: Vec = vec![]; + return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None); + } + + // This iterators length can be trusted + // these represent the number of groups in the group_by operation + let output_len = offsets.size_hint().0; + // start with a dummy index, will be overwritten on first iteration. + // SAFETY: + // we are in bounds + let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params, None) }; + + let mut validity = MutableBitmap::with_capacity(output_len); + validity.extend_constant(output_len, true); + + let out = offsets + .enumerate() + .map(|(idx, (start, len))| { + let end = start + len; + + // SAFETY: + // we are in bounds + + let agg = if start == end { + None + } else { + unsafe { agg_window.update(start as usize, end as usize) } + }; + + match agg { + Some(val) => val, + None => { + // SAFETY: we are in bounds + unsafe { validity.set_unchecked(idx, false) }; + T::default() + }, + } + }) + .collect_trusted::>(); + + PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into())) +} + +// Use an aggregation window that maintains the state. +pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>( + values: &'a [T], + offsets: O, + params: Option, +) -> PrimitiveArray +where + // items (offset, len) -> so offsets are offset, offset + len + Agg: RollingAggWindowNoNulls<'a, T>, + O: Iterator + TrustedLen, + T: IsFloat + NativeType, +{ + if values.is_empty() { + let out: Vec = vec![]; + return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None); + } + // start with a dummy index, will be overwritten on first iteration. + let mut agg_window = Agg::new(values, 0, 0, params, None); + + offsets + .map(|(start, len)| { + let end = start + len; + + if start == end { + None + } else { + // SAFETY: we are in bounds. + unsafe { agg_window.update(start as usize, end as usize) } + } + }) + .collect::>() +} + +pub fn _slice_from_offsets(ca: &ChunkedArray, first: IdxSize, len: IdxSize) -> ChunkedArray +where + T: PolarsDataType, +{ + ca.slice(first as i64, len as usize) +} + +/// Helper that combines the groups into a parallel iterator over `(first, all): (u32, &Vec)`. +pub fn _agg_helper_idx(groups: &GroupsIdx, f: F) -> Series +where + F: Fn((IdxSize, &IdxVec)) -> Option + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: ChunkedArray = POOL.install(|| groups.into_par_iter().map(f).collect()); + ca.into_series() +} + +/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option. +pub fn _agg_helper_idx_no_null(groups: &GroupsIdx, f: F) -> Series +where + F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: NoNull> = POOL.install(|| groups.into_par_iter().map(f).collect()); + ca.into_inner().into_series() +} + +/// Helper that iterates on the `all: Vec` collection, +/// this doesn't have traverse the `first: Vec` memory and is therefore faster. +fn agg_helper_idx_on_all(groups: &GroupsIdx, f: F) -> Series +where + F: Fn(&IdxVec) -> Option + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: ChunkedArray = POOL.install(|| groups.all().into_par_iter().map(f).collect()); + ca.into_series() +} + +/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option. +fn agg_helper_idx_on_all_no_null(groups: &GroupsIdx, f: F) -> Series +where + F: Fn(&IdxVec) -> T::Native + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: NoNull> = + POOL.install(|| groups.all().into_par_iter().map(f).collect()); + ca.into_inner().into_series() +} + +pub fn _agg_helper_slice(groups: &[[IdxSize; 2]], f: F) -> Series +where + F: Fn([IdxSize; 2]) -> Option + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: ChunkedArray = POOL.install(|| groups.par_iter().copied().map(f).collect()); + ca.into_series() +} + +pub fn _agg_helper_slice_no_null(groups: &[[IdxSize; 2]], f: F) -> Series +where + F: Fn([IdxSize; 2]) -> T::Native + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: NoNull> = POOL.install(|| groups.par_iter().copied().map(f).collect()); + ca.into_inner().into_series() +} + +/// Intermediate helper trait so we can have a single generic implementation +/// This trait will ensure the specific dispatch works without complicating +/// the trait bounds. +trait QuantileDispatcher { + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult>; + + fn _median(self) -> Option; +} + +impl QuantileDispatcher for ChunkedArray +where + T: PolarsIntegerType, + T::Native: Ord, + ChunkedArray: IntoSeries, +{ + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) + } + fn _median(self) -> Option { + self.median_faster() + } +} + +impl QuantileDispatcher for Float32Chunked { + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) + } + fn _median(self) -> Option { + self.median_faster() + } +} +impl QuantileDispatcher for Float64Chunked { + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) + } + fn _median(self) -> Option { + self.median_faster() + } +} + +unsafe fn agg_quantile_generic( + ca: &ChunkedArray, + groups: &GroupsType, + quantile: f64, + method: QuantileMethod, +) -> Series +where + T: PolarsNumericType, + ChunkedArray: QuantileDispatcher, + ChunkedArray: IntoSeries, + K: PolarsNumericType, + ::Native: num_traits::Float + quantile_filter::SealedRolling, +{ + let invalid_quantile = !(0.0..=1.0).contains(&quantile); + if invalid_quantile { + return Series::full_null(ca.name().clone(), groups.len(), ca.dtype()); + } + match groups { + GroupsType::Idx(groups) => { + let ca = ca.rechunk(); + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let take = { ca.take_unchecked(idx) }; + // checked with invalid quantile check + take._quantile(quantile, method).unwrap_unchecked() + }) + }, + GroupsType::Slice { groups, .. } => { + if _use_rolling_kernels(groups, ca.chunks()) { + // this cast is a no-op for floats + let s = ca + .cast_with_options(&K::get_dtype(), CastOptions::Overflowing) + .unwrap(); + let ca: &ChunkedArray = s.as_ref().as_ref(); + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: quantile, + method, + })), + ), + Some(validity) => { + _rolling_apply_agg_window_nulls::, _, _>( + values, + validity, + offset_iter, + Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: quantile, + method, + })), + ) + }, + }; + // The rolling kernels works on the dtype, this is not yet the + // float output type we need. + ChunkedArray::from(arr).into_series() + } else { + _agg_helper_slice::(groups, |[first, len]| { + debug_assert!(first + len <= ca.len() as IdxSize); + match len { + 0 => None, + 1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()), + _ => { + let arr_group = _slice_from_offsets(ca, first, len); + // unwrap checked with invalid quantile check + arr_group + ._quantile(quantile, method) + .unwrap_unchecked() + .map(|flt| NumCast::from(flt).unwrap_unchecked()) + }, + } + }) + } + }, + } +} + +unsafe fn agg_median_generic(ca: &ChunkedArray, groups: &GroupsType) -> Series +where + T: PolarsNumericType, + ChunkedArray: QuantileDispatcher, + ChunkedArray: IntoSeries, + K: PolarsNumericType, + ::Native: num_traits::Float + SealedRolling, +{ + match groups { + GroupsType::Idx(groups) => { + let ca = ca.rechunk(); + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let take = { ca.take_unchecked(idx) }; + take._median() + }) + }, + GroupsType::Slice { .. } => { + agg_quantile_generic::(ca, groups, 0.5, QuantileMethod::Linear) + }, + } +} + +/// # Safety +/// +/// No bounds checks on `groups`. +#[cfg(feature = "bitwise")] +unsafe fn bitwise_agg( + ca: &ChunkedArray, + groups: &GroupsType, + f: fn(&ChunkedArray) -> Option, +) -> Series +where + ChunkedArray: + ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce + IntoSeries, +{ + // Prevent a rechunk for every individual group. + + let s = if groups.len() > 1 { + ca.rechunk() + } else { + Cow::Borrowed(ca) + }; + + match groups { + GroupsType::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= s.len()); + if idx.is_empty() { + None + } else { + let take = unsafe { s.take_unchecked(idx) }; + f(&take) + } + }), + GroupsType::Slice { groups, .. } => _agg_helper_slice::(groups, |[first, len]| { + debug_assert!(len <= s.len() as IdxSize); + if len == 0 { + None + } else { + let take = _slice_from_offsets(&s, first, len); + f(&take) + } + }), + } +} + +#[cfg(feature = "bitwise")] +impl ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: + ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce + IntoSeries, +{ + /// # Safety + /// + /// No bounds checks on `groups`. + pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series { + unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) } + } + + /// # Safety + /// + /// No bounds checks on `groups`. + pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series { + unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) } + } + + /// # Safety + /// + /// No bounds checks on `groups`. + pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { + unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) } + } +} + +impl ChunkedArray +where + T: PolarsNumericType + Sync, + T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum, + ChunkedArray: IntoSeries + ChunkAgg, +{ + pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + // faster paths + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_first(groups); + }, + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_last(groups); + }, + _ => {}, + } + match groups { + GroupsType::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= arr.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize) + } else if no_nulls { + take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>( + arr, + idx2usize(idx), + |a, b| a.min_ignore_nan(b), + ) + } else { + take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| { + a.min_ignore_nan(b) + }) + } + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, self.chunks()) { + let arr = self.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + None, + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MinWindow<_>, + _, + _, + >( + values, validity, offset_iter, None + ), + }; + Self::from(arr).into_series() + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + ChunkAgg::min(&arr_group) + }, + } + }) + } + }, + } + } + + pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + // faster paths + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_last(groups); + }, + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_first(groups); + }, + _ => {}, + } + + match groups { + GroupsType::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= arr.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize) + } else if no_nulls { + take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>( + arr, + idx2usize(idx), + |a, b| a.max_ignore_nan(b), + ) + } else { + take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| { + a.max_ignore_nan(b) + }) + } + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, self.chunks()) { + let arr = self.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + None, + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MaxWindow<_>, + _, + _, + >( + values, validity, offset_iter, None + ), + }; + Self::from(arr).into_series() + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + ChunkAgg::max(&arr_group) + }, + } + }) + } + }, + } + } + + pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + match groups { + GroupsType::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + _agg_helper_idx_no_null::(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + T::Native::zero() + } else if idx.len() == 1 { + arr.get(first as usize).unwrap_or(T::Native::zero()) + } else if no_nulls { + take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b) + .unwrap_or(T::Native::zero()) + } else { + take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b) + .unwrap_or(T::Native::zero()) + } + }) + }, + GroupsType::Slice { groups, .. } => { + if _use_rolling_kernels(groups, self.chunks()) { + let arr = self.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + None, + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::SumWindow<_>, + _, + _, + >( + values, validity, offset_iter, None + ), + }; + Self::from(arr).into_series() + } else { + _agg_helper_slice_no_null::(groups, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => T::Native::zero(), + 1 => self.get(first as usize).unwrap_or(T::Native::zero()), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.sum().unwrap_or(T::Native::zero()) + }, + } + }) + } + }, + } + } +} + +impl SeriesWrap> +where + T: PolarsFloatType, + ChunkedArray: IntoSeries + + ChunkVar + + VarAggSeries + + ChunkQuantile + + QuantileAggSeries + + ChunkAgg, + T::Native: Pow, +{ + pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series { + match groups { + GroupsType::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + _agg_helper_idx::(groups, |(first, idx)| { + // this can fail due to a bug in lazy code. + // here users can create filters in aggregations + // and thereby creating shorter columns than the original group tuples. + // the group tuples are modified, but if that's done incorrect there can be out of bounds + // access + debug_assert!(idx.len() <= self.len()); + let out = if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get(first as usize).map(|sum| sum.to_f64().unwrap()) + } else if no_nulls { + take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>( + arr, + idx2usize(idx), + |a, b| a + b, + ) + .unwrap() + .to_f64() + .map(|sum| sum / idx.len() as f64) + } else { + take_agg_primitive_iter_unchecked_count_nulls::( + arr, + idx2usize(idx), + |a, b| a + b, + T::Native::zero(), + idx.len() as IdxSize, + ) + .map(|(sum, null_count)| { + sum.to_f64() + .map(|sum| sum / (idx.len() as f64 - null_count as f64)) + .unwrap() + }) + }; + out.map(|flt| NumCast::from(flt).unwrap()) + }) + }, + GroupsType::Slice { groups, .. } => { + if _use_rolling_kernels(groups, self.chunks()) { + let arr = self.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + None, + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MeanWindow<_>, + _, + _, + >( + values, validity, offset_iter, None + ), + }; + ChunkedArray::from(arr).into_series() + } else { + _agg_helper_slice::(groups, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.mean().map(|flt| NumCast::from(flt).unwrap()) + }, + } + }) + } + }, + } + } + + pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series + where + ::Native: num_traits::Float, + { + let ca = &self.0.rechunk(); + match groups { + GroupsType::Idx(groups) => { + let ca = ca.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let out = if no_nulls { + take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + } else { + take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + }; + out.map(|flt| NumCast::from(flt).unwrap()) + }) + }, + GroupsType::Slice { groups, .. } => { + if _use_rolling_kernels(groups, self.chunks()) { + let arr = self.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::< + MomentWindow<_, VarianceMoment>, + _, + _, + >( + values, + offset_iter, + Some(RollingFnParams::Var(RollingVarParams { ddof })), + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MomentWindow<_, VarianceMoment>, + _, + _, + >( + values, + validity, + offset_iter, + Some(RollingFnParams::Var(RollingVarParams { ddof })), + ), + }; + ChunkedArray::from(arr).into_series() + } else { + _agg_helper_slice::(groups, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap()) + }, + } + }) + } + }, + } + } + pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series + where + ::Native: num_traits::Float, + { + let ca = &self.0.rechunk(); + match groups { + GroupsType::Idx(groups) => { + let arr = ca.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + return None; + } + let out = if no_nulls { + take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + } else { + take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + }; + out.map(|flt| NumCast::from(flt.sqrt()).unwrap()) + }) + }, + GroupsType::Slice { groups, .. } => { + if _use_rolling_kernels(groups, self.chunks()) { + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::< + MomentWindow<_, VarianceMoment>, + _, + _, + >( + values, + offset_iter, + Some(RollingFnParams::Var(RollingVarParams { ddof })), + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>, + _, + _, + >( + values, + validity, + offset_iter, + Some(RollingFnParams::Var(RollingVarParams { ddof })), + ), + }; + + let mut ca = ChunkedArray::::from(arr); + ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap())); + ca.into_series() + } else { + _agg_helper_slice::(groups, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap()) + }, + } + }) + } + }, + } + } +} + +impl Float32Chunked { + pub(crate) unsafe fn agg_quantile( + &self, + groups: &GroupsType, + quantile: f64, + method: QuantileMethod, + ) -> Series { + agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method) + } + pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series { + agg_median_generic::<_, Float32Type>(self, groups) + } +} +impl Float64Chunked { + pub(crate) unsafe fn agg_quantile( + &self, + groups: &GroupsType, + quantile: f64, + method: QuantileMethod, + ) -> Series { + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) + } + pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series { + agg_median_generic::<_, Float64Type>(self, groups) + } +} + +impl ChunkedArray +where + T: PolarsIntegerType, + ChunkedArray: IntoSeries + ChunkAgg + ChunkVar, + T::Native: NumericNative + Ord, +{ + pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series { + match groups { + GroupsType::Idx(groups) => { + let ca = self.rechunk(); + let arr = ca.downcast_get(0).unwrap(); + _agg_helper_idx::(groups, |(first, idx)| { + // this can fail due to a bug in lazy code. + // here users can create filters in aggregations + // and thereby creating shorter columns than the original group tuples. + // the group tuples are modified, but if that's done incorrect there can be out of bounds + // access + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + self.get(first as usize).map(|sum| sum.to_f64().unwrap()) + } else { + match (self.has_nulls(), self.chunks.len()) { + (false, 1) => { + take_agg_no_null_primitive_iter_unchecked::<_, f64, _, _>( + arr, + idx2usize(idx), + |a, b| a + b, + ) + .map(|sum| sum / idx.len() as f64) + }, + (_, 1) => { + { + take_agg_primitive_iter_unchecked_count_nulls::< + T::Native, + f64, + _, + _, + >( + arr, idx2usize(idx), |a, b| a + b, 0.0, idx.len() as IdxSize + ) + } + .map(|(sum, null_count)| { + sum / (idx.len() as f64 - null_count as f64) + }) + }, + _ => { + let take = { self.take_unchecked(idx) }; + take.mean() + }, + } + } + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, self.chunks()) { + let ca = self + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap(); + ca.agg_mean(groups) + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(first + len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.mean() + }, + } + }) + } + }, + } + } + + pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { + match groups { + GroupsType::Idx(groups) => { + let ca_self = self.rechunk(); + let arr = ca_self.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= arr.len()); + if idx.is_empty() { + return None; + } + if no_nulls { + take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + } else { + take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + } + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, self.chunks()) { + let ca = self + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap(); + ca.agg_var(groups, ddof) + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(first + len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.var(ddof) + }, + } + }) + } + }, + } + } + pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { + match groups { + GroupsType::Idx(groups) => { + let ca_self = self.rechunk(); + let arr = ca_self.downcast_iter().next().unwrap(); + let no_nulls = arr.null_count() == 0; + agg_helper_idx_on_all::(groups, |idx| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + return None; + } + let out = if no_nulls { + take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + } else { + take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof) + }; + out.map(|v| v.sqrt()) + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, self.chunks()) { + let ca = self + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap(); + ca.agg_std(groups, ddof) + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(first + len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, + _ => { + let arr_group = _slice_from_offsets(self, first, len); + arr_group.std(ddof) + }, + } + }) + } + }, + } + } + + pub(crate) unsafe fn agg_quantile( + &self, + groups: &GroupsType, + quantile: f64, + method: QuantileMethod, + ) -> Series { + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) + } + pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series { + agg_median_generic::<_, Float64Type>(self, groups) + } +} diff --git a/crates/polars-core/src/frame/group_by/aggregations/string.rs b/crates/polars-core/src/frame/group_by/aggregations/string.rs new file mode 100644 index 000000000000..099eddefe2a9 --- /dev/null +++ b/crates/polars-core/src/frame/group_by/aggregations/string.rs @@ -0,0 +1,157 @@ +use super::*; + +pub fn _agg_helper_idx_bin<'a, F>(groups: &'a GroupsIdx, f: F) -> Series +where + F: Fn((IdxSize, &'a IdxVec)) -> Option<&'a [u8]> + Send + Sync, +{ + let ca: BinaryChunked = POOL.install(|| groups.into_par_iter().map(f).collect()); + ca.into_series() +} + +pub fn _agg_helper_slice_bin<'a, F>(groups: &'a [[IdxSize; 2]], f: F) -> Series +where + F: Fn([IdxSize; 2]) -> Option<&'a [u8]> + Send + Sync, +{ + let ca: BinaryChunked = POOL.install(|| groups.par_iter().copied().map(f).collect()); + ca.into_series() +} + +impl BinaryChunked { + #[allow(clippy::needless_lifetimes)] + pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsType) -> Series { + // faster paths + match (&self.is_sorted_flag(), &self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_first(groups); + }, + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_last(groups); + }, + _ => {}, + } + + match groups { + GroupsType::Idx(groups) => { + let ca_self = self.rechunk(); + let arr = ca_self.downcast_as_array(); + let no_nulls = arr.null_count() == 0; + _agg_helper_idx_bin(groups, |(first, idx)| { + debug_assert!(idx.len() <= ca_self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + arr.get_unchecked(first as usize) + } else if no_nulls { + take_agg_bin_iter_unchecked_no_null( + arr, + indexes_to_usizes(idx), + |acc, v| if acc < v { acc } else { v }, + ) + } else { + take_agg_bin_iter_unchecked( + arr, + indexes_to_usizes(idx), + |acc, v| if acc < v { acc } else { v }, + idx.len() as IdxSize, + ) + } + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => _agg_helper_slice_bin(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + let borrowed = arr_group.min_binary(); + + // SAFETY: + // The borrowed has `arr_group`s lifetime, but it actually points to data + // hold by self. Here we tell the compiler that. + unsafe { std::mem::transmute::, Option<&'a [u8]>>(borrowed) } + }, + } + }), + } + } + + #[allow(clippy::needless_lifetimes)] + pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsType) -> Series { + // faster paths + match (self.is_sorted_flag(), self.null_count()) { + (IsSorted::Ascending, 0) => { + return self.clone().into_series().agg_last(groups); + }, + (IsSorted::Descending, 0) => { + return self.clone().into_series().agg_first(groups); + }, + _ => {}, + } + + match groups { + GroupsType::Idx(groups) => { + let ca_self = self.rechunk(); + let arr = ca_self.downcast_as_array(); + let no_nulls = arr.null_count() == 0; + _agg_helper_idx_bin(groups, |(first, idx)| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + ca_self.get(first as usize) + } else if no_nulls { + take_agg_bin_iter_unchecked_no_null( + arr, + indexes_to_usizes(idx), + |acc, v| if acc > v { acc } else { v }, + ) + } else { + take_agg_bin_iter_unchecked( + arr, + indexes_to_usizes(idx), + |acc, v| if acc > v { acc } else { v }, + idx.len() as IdxSize, + ) + } + }) + }, + GroupsType::Slice { + groups: groups_slice, + .. + } => _agg_helper_slice_bin(groups_slice, |[first, len]| { + debug_assert!(len <= self.len() as IdxSize); + match len { + 0 => None, + 1 => self.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(self, first, len); + let borrowed = arr_group.max_binary(); + + // SAFETY: + // The borrowed has `arr_group`s lifetime, but it actually points to data + // hold by self. Here we tell the compiler that. + unsafe { std::mem::transmute::, Option<&'a [u8]>>(borrowed) } + }, + } + }), + } + } +} + +impl StringChunked { + #[allow(clippy::needless_lifetimes)] + pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsType) -> Series { + let out = self.as_binary().agg_min(groups); + out.binary().unwrap().to_string_unchecked().into_series() + } + + #[allow(clippy::needless_lifetimes)] + pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsType) -> Series { + let out = self.as_binary().agg_max(groups); + out.binary().unwrap().to_string_unchecked().into_series() + } +} diff --git a/crates/polars-core/src/frame/group_by/expr.rs b/crates/polars-core/src/frame/group_by/expr.rs new file mode 100644 index 000000000000..7348fba2f7c2 --- /dev/null +++ b/crates/polars-core/src/frame/group_by/expr.rs @@ -0,0 +1,8 @@ +use crate::prelude::*; + +pub trait PhysicalAggExpr { + #[allow(clippy::ptr_arg)] + fn evaluate(&self, df: &DataFrame, groups: &GroupPositions) -> PolarsResult; + + fn root_name(&self) -> PolarsResult<&PlSmallStr>; +} diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs new file mode 100644 index 000000000000..871d882fa4bb --- /dev/null +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -0,0 +1,231 @@ +use hashbrown::hash_map::Entry; +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use polars_utils::idx_vec::IdxVec; +use polars_utils::itertools::Itertools; +use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalHash, TotalOrdWrap}; +use polars_utils::unitvec; +use rayon::prelude::*; + +use crate::POOL; +use crate::hashing::*; +use crate::prelude::*; +use crate::utils::flatten; + +fn get_init_size() -> usize { + // we check if this is executed from the main thread + // we don't want to pre-allocate this much if executed + // group_tuples in a parallel iterator as that explodes allocation + if POOL.current_thread_index().is_none() { + _HASHMAP_INIT_SIZE + } else { + 0 + } +} + +fn finish_group_order(mut out: Vec>, sorted: bool) -> GroupsType { + if sorted { + // we can just take the first value, no need to flatten + let mut out = if out.len() == 1 { + out.pop().unwrap() + } else { + let (cap, offsets) = flatten::cap_and_offsets(&out); + // we write (first, all) tuple because of sorting + let mut items = Vec::with_capacity(cap); + let items_ptr = unsafe { SyncPtr::new(items.as_mut_ptr()) }; + + POOL.install(|| { + out.into_par_iter() + .zip(offsets) + .for_each(|(mut g, offset)| { + // pre-sort every array + // this will make the final single threaded sort much faster + g.sort_unstable_by_key(|g| g.0); + + unsafe { + let mut items_ptr: *mut (IdxSize, IdxVec) = items_ptr.get(); + items_ptr = items_ptr.add(offset); + + for (i, g) in g.into_iter().enumerate() { + std::ptr::write(items_ptr.add(i), g) + } + } + }); + }); + unsafe { + items.set_len(cap); + } + items + }; + out.sort_unstable_by_key(|g| g.0); + let mut idx = GroupsIdx::from_iter(out); + idx.sorted = true; + GroupsType::Idx(idx) + } else { + // we can just take the first value, no need to flatten + if out.len() == 1 { + GroupsType::Idx(GroupsIdx::from(out.pop().unwrap())) + } else { + // flattens + GroupsType::Idx(GroupsIdx::from(out)) + } + } +} + +pub(crate) fn group_by(keys: impl Iterator, sorted: bool) -> GroupsType +where + K: TotalHash + TotalEq, +{ + let init_size = get_init_size(); + let (mut first, mut groups); + if sorted { + groups = Vec::with_capacity(get_init_size()); + first = Vec::with_capacity(get_init_size()); + let mut hash_tbl = PlHashMap::with_capacity(init_size); + for (idx, k) in keys.enumerate_idx() { + match hash_tbl.entry(TotalOrdWrap(k)) { + Entry::Vacant(entry) => { + let group_idx = groups.len() as IdxSize; + entry.insert(group_idx); + groups.push(unitvec![idx]); + first.push(idx); + }, + Entry::Occupied(entry) => unsafe { + groups.get_unchecked_mut(*entry.get() as usize).push(idx) + }, + } + } + } else { + let mut hash_tbl = PlHashMap::with_capacity(init_size); + for (idx, k) in keys.enumerate_idx() { + match hash_tbl.entry(TotalOrdWrap(k)) { + Entry::Vacant(entry) => { + entry.insert((idx, unitvec![idx])); + }, + Entry::Occupied(mut entry) => entry.get_mut().1.push(idx), + } + } + (first, groups) = hash_tbl.into_values().unzip(); + } + GroupsType::Idx(GroupsIdx::new(first, groups, sorted)) +} + +// giving the slice info to the compiler is much +// faster than the using an iterator, that's why we +// have the code duplication +pub(crate) fn group_by_threaded_slice( + keys: Vec, + n_partitions: usize, + sorted: bool, +) -> GroupsType +where + T: ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + DirtyHash, + IntoSlice: AsRef<[T]> + Send + Sync, +{ + let init_size = get_init_size(); + + // We will create a hashtable in every thread. + // We use the hash to partition the keys to the matching hashtable. + // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. + let out = POOL.install(|| { + (0..n_partitions) + .into_par_iter() + .map(|thread_no| { + let mut hash_tbl = PlHashMap::with_capacity(init_size); + + let mut offset = 0; + for keys in &keys { + let keys = keys.as_ref(); + let len = keys.len() as IdxSize; + + for (key_idx, k) in keys.iter().enumerate_idx() { + let k = k.to_total_ord(); + let idx = key_idx + offset; + + if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) { + match hash_tbl.entry(k) { + Entry::Vacant(entry) => { + entry.insert((idx, unitvec![idx])); + }, + Entry::Occupied(mut entry) => { + entry.get_mut().1.push(idx); + }, + } + } + } + offset += len; + } + hash_tbl + .into_iter() + .map(|(_k, v)| v) + .collect_trusted::>() + }) + .collect::>() + }); + finish_group_order(out, sorted) +} + +pub(crate) fn group_by_threaded_iter( + keys: &[I], + n_partitions: usize, + sorted: bool, +) -> GroupsType +where + I: IntoIterator + Send + Sync + Clone, + I::IntoIter: ExactSizeIterator, + T: ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + DirtyHash, +{ + let init_size = get_init_size(); + + // We will create a hashtable in every thread. + // We use the hash to partition the keys to the matching hashtable. + // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. + let out = POOL.install(|| { + (0..n_partitions) + .into_par_iter() + .map(|thread_no| { + let mut hash_tbl: PlHashMap = + PlHashMap::with_capacity(init_size); + + let mut offset = 0; + for keys in keys { + let keys = keys.clone().into_iter(); + let len = keys.len() as IdxSize; + + for (key_idx, k) in keys.into_iter().enumerate_idx() { + let k = k.to_total_ord(); + let idx = key_idx + offset; + + if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) { + match hash_tbl.entry(k) { + Entry::Vacant(entry) => { + entry.insert(unitvec![idx]); + }, + Entry::Occupied(mut entry) => { + entry.get_mut().push(idx); + }, + } + } + } + offset += len; + } + // iterating the hash tables locally + // was faster than iterating in the materialization phase directly + // the proper end vec. I believe this is because the hash-table + // currently is local to the thread so in hot cache + // So we first collect into a tight vec and then do a second + // materialization run + // this is also faster than the index-map approach where we + // directly locally store to a vec at the cost of an extra + // indirection + hash_tbl + .into_iter() + .map(|(_k, v)| (unsafe { *v.first().unwrap_unchecked() }, v)) + .collect_trusted::>() + }) + .collect::>() + }); + finish_group_order(out, sorted) +} diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs new file mode 100644 index 000000000000..815f777ebd9d --- /dev/null +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -0,0 +1,377 @@ +use arrow::legacy::kernels::sort_partition::{ + create_clean_partitions, partition_to_groups, partition_to_groups_amortized_varsize, +}; +use polars_error::signals::try_raise_keyboard_interrupt; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; + +use super::*; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca_unordered; +use crate::config::verbose; +use crate::series::BitRepr; +use crate::utils::Container; +use crate::utils::flatten::flatten_par; + +/// Used to create the tuples for a group_by operation. +pub trait IntoGroupsType { + /// Create the tuples need for a group_by operation. + /// * The first value in the tuple is the first index of the group. + /// * The second value in the tuple is the indexes of the groups including the first value. + fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + unimplemented!() + } +} + +fn group_multithreaded(ca: &ChunkedArray) -> bool { + // TODO! change to something sensible + ca.len() > 1000 && POOL.current_num_threads() > 1 +} + +fn num_groups_proxy(ca: &ChunkedArray, multithreaded: bool, sorted: bool) -> GroupsType +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash, +{ + if multithreaded && group_multithreaded(ca) { + let n_partitions = _set_partition_size(); + + // use the arrays as iterators + if ca.null_count() == 0 { + let keys = ca + .downcast_iter() + .map(|arr| arr.values().as_slice()) + .collect::>(); + group_by_threaded_slice(keys, n_partitions, sorted) + } else { + let keys = ca + .downcast_iter() + .map(|arr| arr.iter().map(|o| o.copied())) + .collect::>(); + group_by_threaded_iter(&keys, n_partitions, sorted) + } + } else if !ca.has_nulls() { + group_by(ca.into_no_null_iter(), sorted) + } else { + group_by(ca.iter(), sorted) + } +} + +impl ChunkedArray +where + T: PolarsNumericType, + T::Native: NumCast, +{ + fn create_groups_from_sorted(&self, multithreaded: bool) -> GroupsSlice { + if verbose() { + eprintln!("group_by keys are sorted; running sorted key fast path"); + } + let arr = self.downcast_iter().next().unwrap(); + if arr.is_empty() { + return GroupsSlice::default(); + } + let mut values = arr.values().as_slice(); + let null_count = arr.null_count(); + let length = values.len(); + + // all nulls + if null_count == length { + return vec![[0, length as IdxSize]]; + } + + let mut nulls_first = false; + if null_count > 0 { + nulls_first = arr.get(0).is_none() + } + + if nulls_first { + values = &values[null_count..]; + } else { + values = &values[..length - null_count]; + }; + + let n_threads = POOL.current_num_threads(); + if multithreaded && n_threads > 1 { + let parts = + create_clean_partitions(values, n_threads, self.is_sorted_descending_flag()); + let n_parts = parts.len(); + + let first_ptr = &values[0] as *const T::Native as usize; + let groups = parts.par_iter().enumerate().map(|(i, part)| { + // we go via usize as *const is not send + let first_ptr = first_ptr as *const T::Native; + + let part_first_ptr = &part[0] as *const T::Native; + let mut offset = unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize; + + // nulls first: only add the nulls at the first partition + if nulls_first && i == 0 { + partition_to_groups(part, null_count as IdxSize, true, offset) + } + // nulls last: only compute at the last partition + else if !nulls_first && i == n_parts - 1 { + partition_to_groups(part, null_count as IdxSize, false, offset) + } + // other partitions + else { + if nulls_first { + offset += null_count as IdxSize; + }; + + partition_to_groups(part, 0, false, offset) + } + }); + let groups = POOL.install(|| groups.collect::>()); + flatten_par(&groups) + } else { + partition_to_groups(values, null_count as IdxSize, nulls_first, 0) + } + } +} + +#[cfg(all(feature = "dtype-categorical", feature = "performant"))] +impl IntoGroupsType for CategoricalChunked { + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + Ok(self.group_tuples_perfect(multithreaded, sorted)) + } +} + +impl IntoGroupsType for ChunkedArray +where + T: PolarsNumericType, + T::Native: NumCast, +{ + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + // sorted path + if self.is_sorted_ascending_flag() || self.is_sorted_descending_flag() { + // don't have to pass `sorted` arg, GroupSlice is always sorted. + return Ok(GroupsType::Slice { + groups: self.rechunk().create_groups_from_sorted(multithreaded), + rolling: false, + }); + } + + let out = match self.dtype() { + DataType::UInt64 => { + // convince the compiler that we are this type. + let ca: &UInt64Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + DataType::UInt32 => { + // convince the compiler that we are this type. + let ca: &UInt32Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + DataType::Int64 => { + let BitRepr::Large(ca) = self.to_bit_repr() else { + unreachable!() + }; + num_groups_proxy(&ca, multithreaded, sorted) + }, + DataType::Int32 => { + let BitRepr::Small(ca) = self.to_bit_repr() else { + unreachable!() + }; + num_groups_proxy(&ca, multithreaded, sorted) + }, + DataType::Float64 => { + // convince the compiler that we are this type. + let ca: &Float64Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + DataType::Float32 => { + // convince the compiler that we are this type. + let ca: &Float32Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + // convince the compiler that we are this type. + let ca: &Int128Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))] + DataType::Int8 => { + // convince the compiler that we are this type. + let ca: &Int8Chunked = + unsafe { &*(self as *const ChunkedArray as *const ChunkedArray) }; + let s = ca.reinterpret_unsigned(); + return s.group_tuples(multithreaded, sorted); + }, + #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))] + DataType::UInt8 => { + // convince the compiler that we are this type. + let ca: &UInt8Chunked = + unsafe { &*(self as *const ChunkedArray as *const ChunkedArray) }; + num_groups_proxy(ca, multithreaded, sorted) + }, + #[cfg(all(feature = "performant", feature = "dtype-i16", feature = "dtype-u16"))] + DataType::Int16 => { + // convince the compiler that we are this type. + let ca: &Int16Chunked = + unsafe { &*(self as *const ChunkedArray as *const ChunkedArray) }; + let s = ca.reinterpret_unsigned(); + return s.group_tuples(multithreaded, sorted); + }, + #[cfg(all(feature = "performant", feature = "dtype-i16", feature = "dtype-u16"))] + DataType::UInt16 => { + // convince the compiler that we are this type. + let ca: &UInt16Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + _ => { + let ca = unsafe { self.cast_unchecked(&DataType::UInt32).unwrap() }; + let ca = ca.u32().unwrap(); + num_groups_proxy(ca, multithreaded, sorted) + }, + }; + try_raise_keyboard_interrupt(); + Ok(out) + } +} +impl IntoGroupsType for BooleanChunked { + fn group_tuples(&self, mut multithreaded: bool, sorted: bool) -> PolarsResult { + multithreaded &= POOL.current_num_threads() > 1; + + #[cfg(feature = "performant")] + { + let ca = self + .cast_with_options(&DataType::UInt8, CastOptions::Overflowing) + .unwrap(); + let ca = ca.u8().unwrap(); + ca.group_tuples(multithreaded, sorted) + } + #[cfg(not(feature = "performant"))] + { + let ca = self + .cast_with_options(&DataType::UInt32, CastOptions::Overflowing) + .unwrap(); + let ca = ca.u32().unwrap(); + ca.group_tuples(multithreaded, sorted) + } + } +} + +impl IntoGroupsType for StringChunked { + #[allow(clippy::needless_lifetimes)] + fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.as_binary().group_tuples(multithreaded, sorted) + } +} + +impl IntoGroupsType for BinaryChunked { + #[allow(clippy::needless_lifetimes)] + fn group_tuples<'a>( + &'a self, + mut multithreaded: bool, + sorted: bool, + ) -> PolarsResult { + if self.is_sorted_any() && !self.has_nulls() && self.n_chunks() == 1 { + let arr = self.downcast_get(0).unwrap(); + let values = arr.values_iter(); + let mut out = Vec::with_capacity(values.len() / 30); + partition_to_groups_amortized_varsize(values, arr.len() as _, 0, false, 0, &mut out); + return Ok(GroupsType::Slice { + groups: out, + rolling: false, + }); + } + + multithreaded &= POOL.current_num_threads() > 1; + let bh = self.to_bytes_hashes(multithreaded, Default::default()); + + let out = if multithreaded { + let n_partitions = bh.len(); + // Take slices so that the vecs are not cloned. + let bh = bh.iter().map(|v| v.as_slice()).collect::>(); + group_by_threaded_slice(bh, n_partitions, sorted) + } else { + group_by(bh[0].iter(), sorted) + }; + try_raise_keyboard_interrupt(); + Ok(out) + } +} + +impl IntoGroupsType for BinaryOffsetChunked { + #[allow(clippy::needless_lifetimes)] + fn group_tuples<'a>( + &'a self, + mut multithreaded: bool, + sorted: bool, + ) -> PolarsResult { + if self.is_sorted_any() && !self.has_nulls() && self.n_chunks() == 1 { + let arr = self.downcast_get(0).unwrap(); + let values = arr.values_iter(); + let mut out = Vec::with_capacity(values.len() / 30); + partition_to_groups_amortized_varsize(values, arr.len() as _, 0, false, 0, &mut out); + return Ok(GroupsType::Slice { + groups: out, + rolling: false, + }); + } + multithreaded &= POOL.current_num_threads() > 1; + let bh = self.to_bytes_hashes(multithreaded, Default::default()); + + let out = if multithreaded { + let n_partitions = bh.len(); + // Take slices so that the vecs are not cloned. + let bh = bh.iter().map(|v| v.as_slice()).collect::>(); + group_by_threaded_slice(bh, n_partitions, sorted) + } else { + group_by(bh[0].iter(), sorted) + }; + Ok(out) + } +} + +impl IntoGroupsType for ListChunked { + #[allow(clippy::needless_lifetimes)] + #[allow(unused_variables)] + fn group_tuples<'a>( + &'a self, + mut multithreaded: bool, + sorted: bool, + ) -> PolarsResult { + multithreaded &= POOL.current_num_threads() > 1; + let by = &[self.clone().into_column()]; + let ca = if multithreaded { + encode_rows_vertical_par_unordered(by).unwrap() + } else { + _get_rows_encoded_ca_unordered(PlSmallStr::EMPTY, by).unwrap() + }; + + ca.group_tuples(multithreaded, sorted) + } +} + +#[cfg(feature = "dtype-array")] +impl IntoGroupsType for ArrayChunked { + #[allow(clippy::needless_lifetimes)] + #[allow(unused_variables)] + fn group_tuples<'a>(&'a self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + todo!("grouping FixedSizeList not yet supported") + } +} + +#[cfg(feature = "object")] +impl IntoGroupsType for ObjectChunked +where + T: PolarsObject, +{ + fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> PolarsResult { + Ok(group_by(self.into_iter(), sorted)) + } +} diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs new file mode 100644 index 000000000000..945b2e4ddc49 --- /dev/null +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -0,0 +1,1208 @@ +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; + +use num_traits::NumCast; +use polars_compute::rolling::QuantileMethod; +use polars_utils::format_pl_smallstr; +use polars_utils::hashing::DirtyHash; +use rayon::prelude::*; + +use self::hashing::*; +use crate::POOL; +use crate::prelude::*; +use crate::utils::{_set_partition_size, accumulate_dataframes_vertical}; + +pub mod aggregations; +pub mod expr; +pub(crate) mod hashing; +mod into_groups; +mod perfect; +mod position; + +pub use into_groups::*; +pub use position::*; + +use crate::chunked_array::ops::row_encode::{ + encode_rows_unordered, encode_rows_vertical_par_unordered, +}; + +impl DataFrame { + pub fn group_by_with_series( + &self, + mut by: Vec, + multithreaded: bool, + sorted: bool, + ) -> PolarsResult { + polars_ensure!( + !by.is_empty(), + ComputeError: "at least one key is required in a group_by operation" + ); + let minimal_by_len = by.iter().map(|s| s.len()).min().expect("at least 1 key"); + let df_height = self.height(); + + // we only throw this error if self.width > 0 + // so that we can still call this on a dummy dataframe where we provide the keys + if (minimal_by_len != df_height) && (self.width() > 0) { + polars_ensure!( + minimal_by_len == 1, + ShapeMismatch: "series used as keys should have the same length as the DataFrame" + ); + for by_key in by.iter_mut() { + if by_key.len() == minimal_by_len { + *by_key = by_key.new_from_index(0, df_height) + } + } + }; + + let groups = if by.len() == 1 { + let column = &by[0]; + column + .as_materialized_series() + .group_tuples(multithreaded, sorted) + } else if by.iter().any(|s| s.dtype().is_object()) { + #[cfg(feature = "object")] + { + let mut df = DataFrame::new(by.clone()).unwrap(); + let n = df.height(); + let rows = df.to_av_rows(); + let iter = (0..n).map(|i| rows.get(i)); + Ok(group_by(iter, sorted)) + } + #[cfg(not(feature = "object"))] + { + unreachable!() + } + } else { + // Skip null dtype. + let by = by + .iter() + .filter(|s| !s.dtype().is_null()) + .cloned() + .collect::>(); + if by.is_empty() { + let groups = if self.is_empty() { + vec![] + } else { + vec![[0, self.height() as IdxSize]] + }; + Ok(GroupsType::Slice { + groups, + rolling: false, + }) + } else { + let rows = if multithreaded { + encode_rows_vertical_par_unordered(&by) + } else { + encode_rows_unordered(&by) + }? + .into_series(); + rows.group_tuples(multithreaded, sorted) + } + }; + Ok(GroupBy::new(self, by, groups?.into_sliceable(), None)) + } + + /// Group DataFrame using a Series column. + /// + /// # Example + /// + /// ``` + /// use polars_core::prelude::*; + /// fn group_by_sum(df: &DataFrame) -> PolarsResult { + /// df.group_by(["column_name"])? + /// .select(["agg_column_name"]) + /// .sum() + /// } + /// ``` + pub fn group_by(&self, by: I) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + let selected_keys = self.select_columns(by)?; + self.group_by_with_series(selected_keys, true, false) + } + + /// Group DataFrame using a Series column. + /// The groups are ordered by their smallest row index. + pub fn group_by_stable(&self, by: I) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + let selected_keys = self.select_columns(by)?; + self.group_by_with_series(selected_keys, true, true) + } +} + +/// Returned by a group_by operation on a DataFrame. This struct supports +/// several aggregations. +/// +/// Until described otherwise, the examples in this struct are performed on the following DataFrame: +/// +/// ```ignore +/// use polars_core::prelude::*; +/// +/// let dates = &[ +/// "2020-08-21", +/// "2020-08-21", +/// "2020-08-22", +/// "2020-08-23", +/// "2020-08-22", +/// ]; +/// // date format +/// let fmt = "%Y-%m-%d"; +/// // create date series +/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt) +/// .into_series(); +/// // create temperature series +/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]); +/// // create rain series +/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]); +/// // create a new DataFrame +/// let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); +/// println!("{:?}", df); +/// ``` +/// +/// Outputs: +/// +/// ```text +/// +------------+------+------+ +/// | date | temp | rain | +/// | --- | --- | --- | +/// | Date | i32 | f64 | +/// +============+======+======+ +/// | 2020-08-21 | 20 | 0.2 | +/// +------------+------+------+ +/// | 2020-08-21 | 10 | 0.1 | +/// +------------+------+------+ +/// | 2020-08-22 | 7 | 0.3 | +/// +------------+------+------+ +/// | 2020-08-23 | 9 | 0.1 | +/// +------------+------+------+ +/// | 2020-08-22 | 1 | 0.01 | +/// +------------+------+------+ +/// ``` +/// +#[derive(Debug, Clone)] +pub struct GroupBy<'a> { + pub df: &'a DataFrame, + pub(crate) selected_keys: Vec, + // [first idx, [other idx]] + groups: GroupPositions, + // columns selected for aggregation + pub(crate) selected_agg: Option>, +} + +impl<'a> GroupBy<'a> { + pub fn new( + df: &'a DataFrame, + by: Vec, + groups: GroupPositions, + selected_agg: Option>, + ) -> Self { + GroupBy { + df, + selected_keys: by, + groups, + selected_agg, + } + } + + /// Select the column(s) that should be aggregated. + /// You can select a single column or a slice of columns. + /// + /// Note that making a selection with this method is not required. If you + /// skip it all columns (except for the keys) will be selected for aggregation. + #[must_use] + pub fn select, S: Into>(mut self, selection: I) -> Self { + self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect()); + self + } + + /// Get the internal representation of the GroupBy operation. + /// The Vec returned contains: + /// (first_idx, [`Vec`]) + /// Where second value in the tuple is a vector with all matching indexes. + pub fn get_groups(&self) -> &GroupPositions { + &self.groups + } + + /// Get the internal representation of the GroupBy operation. + /// The Vec returned contains: + /// (first_idx, [`Vec`]) + /// Where second value in the tuple is a vector with all matching indexes. + /// + /// # Safety + /// Groups should always be in bounds of the `DataFrame` hold by this [`GroupBy`]. + /// If you mutate it, you must hold that invariant. + pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions { + &mut self.groups + } + + pub fn take_groups(self) -> GroupPositions { + self.groups + } + + pub fn take_groups_mut(&mut self) -> GroupPositions { + std::mem::take(&mut self.groups) + } + + pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec { + #[allow(unused_assignments)] + // needed to keep the lifetimes valid for this scope + let mut groups_owned = None; + + let groups = if let Some((offset, len)) = slice { + groups_owned = Some(self.groups.slice(offset, len)); + groups_owned.as_deref().unwrap() + } else { + &self.groups + }; + POOL.install(|| { + self.selected_keys + .par_iter() + .map(Column::as_materialized_series) + .map(|s| { + match groups { + GroupsType::Idx(groups) => { + // SAFETY: groups are always in bounds. + let mut out = unsafe { s.take_slice_unchecked(groups.first()) }; + if groups.sorted { + out.set_sorted_flag(s.is_sorted_flag()); + }; + out + }, + GroupsType::Slice { groups, rolling } => { + if *rolling && !groups.is_empty() { + // Groups can be sliced. + let offset = groups[0][0]; + let [upper_offset, upper_len] = groups[groups.len() - 1]; + return s.slice( + offset as i64, + ((upper_offset + upper_len) - offset) as usize, + ); + } + + let indices = groups + .iter() + .map(|&[first, _len]| first) + .collect_ca(PlSmallStr::EMPTY); + // SAFETY: groups are always in bounds. + let mut out = unsafe { s.take_unchecked(&indices) }; + // Sliced groups are always in order of discovery. + out.set_sorted_flag(s.is_sorted_flag()); + out + }, + } + }) + .map(Column::from) + .collect() + }) + } + + pub fn keys(&self) -> Vec { + self.keys_sliced(None) + } + + fn prepare_agg(&self) -> PolarsResult<(Vec, Vec)> { + let keys = self.keys(); + + let agg_col = match &self.selected_agg { + Some(selection) => self.df.select_columns_impl(selection.as_slice()), + None => { + let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect(); + let selection = self + .df + .iter() + .map(|s| s.name()) + .filter(|a| !by.contains(a)) + .cloned() + .collect::>(); + + self.df.select_columns_impl(selection.as_slice()) + }, + }?; + + Ok((keys, agg_col)) + } + + /// Aggregate grouped series and compute the mean per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp", "rain"]).mean() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+-----------+-----------+ + /// | date | temp_mean | rain_mean | + /// | --- | --- | --- | + /// | Date | f64 | f64 | + /// +============+===========+===========+ + /// | 2020-08-23 | 9 | 0.1 | + /// +------------+-----------+-----------+ + /// | 2020-08-22 | 4 | 0.155 | + /// +------------+-----------+-----------+ + /// | 2020-08-21 | 15 | 0.15 | + /// +------------+-----------+-----------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn mean(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean); + let mut agg = unsafe { agg_col.agg_mean(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped series and compute the sum per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).sum() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+----------+ + /// | date | temp_sum | + /// | --- | --- | + /// | Date | i32 | + /// +============+==========+ + /// | 2020-08-23 | 9 | + /// +------------+----------+ + /// | 2020-08-22 | 8 | + /// +------------+----------+ + /// | 2020-08-21 | 30 | + /// +------------+----------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn sum(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum); + let mut agg = unsafe { agg_col.agg_sum(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped series and compute the minimal value per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).min() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+----------+ + /// | date | temp_min | + /// | --- | --- | + /// | Date | i32 | + /// +============+==========+ + /// | 2020-08-23 | 9 | + /// +------------+----------+ + /// | 2020-08-22 | 1 | + /// +------------+----------+ + /// | 2020-08-21 | 10 | + /// +------------+----------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn min(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min); + let mut agg = unsafe { agg_col.agg_min(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped series and compute the maximum value per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).max() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+----------+ + /// | date | temp_max | + /// | --- | --- | + /// | Date | i32 | + /// +============+==========+ + /// | 2020-08-23 | 9 | + /// +------------+----------+ + /// | 2020-08-22 | 7 | + /// +------------+----------+ + /// | 2020-08-21 | 20 | + /// +------------+----------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn max(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max); + let mut agg = unsafe { agg_col.agg_max(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped `Series` and find the first value per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).first() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+------------+ + /// | date | temp_first | + /// | --- | --- | + /// | Date | i32 | + /// +============+============+ + /// | 2020-08-23 | 9 | + /// +------------+------------+ + /// | 2020-08-22 | 7 | + /// +------------+------------+ + /// | 2020-08-21 | 20 | + /// +------------+------------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn first(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First); + let mut agg = unsafe { agg_col.agg_first(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped `Series` and return the last value per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).last() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+------------+ + /// | date | temp_last | + /// | --- | --- | + /// | Date | i32 | + /// +============+============+ + /// | 2020-08-23 | 9 | + /// +------------+------------+ + /// | 2020-08-22 | 1 | + /// +------------+------------+ + /// | 2020-08-21 | 10 | + /// +------------+------------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn last(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last); + let mut agg = unsafe { agg_col.agg_last(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped `Series` by counting the number of unique values. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).n_unique() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+---------------+ + /// | date | temp_n_unique | + /// | --- | --- | + /// | Date | u32 | + /// +============+===============+ + /// | 2020-08-23 | 1 | + /// +------------+---------------+ + /// | 2020-08-22 | 2 | + /// +------------+---------------+ + /// | 2020-08-21 | 2 | + /// +------------+---------------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn n_unique(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique); + let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped [`Series`] and determine the quantile per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default()) + /// } + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + polars_ensure!( + (0.0..=1.0).contains(&quantile), + ComputeError: "`quantile` should be within 0.0 and 1.0" + ); + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column( + agg_col.name().as_str(), + GroupByMethod::Quantile(quantile, method), + ); + let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped [`Series`] and determine the median per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).median() + /// } + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn median(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median); + let mut agg = unsafe { agg_col.agg_median(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped [`Series`] and determine the variance per group. + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn var(&self, ddof: u8) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof)); + let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped [`Series`] and determine the standard deviation per group. + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn std(&self, ddof: u8) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof)); + let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + /// Aggregate grouped series and compute the number of values per group. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.select(["temp"]).count() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+------------+ + /// | date | temp_count | + /// | --- | --- | + /// | Date | u32 | + /// +============+============+ + /// | 2020-08-23 | 1 | + /// +------------+------------+ + /// | 2020-08-22 | 2 | + /// +------------+------------+ + /// | 2020-08-21 | 2 | + /// +------------+------------+ + /// ``` + pub fn count(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + + for agg_col in agg_cols { + let new_name = fmt_group_by_column( + agg_col.name().as_str(), + GroupByMethod::Count { + include_nulls: true, + }, + ); + let mut ca = self.groups.group_count(); + ca.rename(new_name); + cols.push(ca.into_column()); + } + DataFrame::new(cols) + } + + /// Get the group_by group indexes. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// df.group_by(["date"])?.groups() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +--------------+------------+ + /// | date | groups | + /// | --- | --- | + /// | Date(days) | list [u32] | + /// +==============+============+ + /// | 2020-08-23 | "[3]" | + /// +--------------+------------+ + /// | 2020-08-22 | "[2, 4]" | + /// +--------------+------------+ + /// | 2020-08-21 | "[0, 1]" | + /// +--------------+------------+ + /// ``` + pub fn groups(&self) -> PolarsResult { + let mut cols = self.keys(); + let mut column = self.groups.as_list_chunked(); + let new_name = fmt_group_by_column("", GroupByMethod::Groups); + column.rename(new_name); + cols.push(column.into_column()); + DataFrame::new(cols) + } + + /// Aggregate the groups of the group_by operation into lists. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example(df: DataFrame) -> PolarsResult { + /// // GroupBy and aggregate to Lists + /// df.group_by(["date"])?.select(["temp"]).agg_list() + /// } + /// ``` + /// Returns: + /// + /// ```text + /// +------------+------------------------+ + /// | date | temp_agg_list | + /// | --- | --- | + /// | Date | list [i32] | + /// +============+========================+ + /// | 2020-08-23 | "[Some(9)]" | + /// +------------+------------------------+ + /// | 2020-08-22 | "[Some(7), Some(1)]" | + /// +------------+------------------------+ + /// | 2020-08-21 | "[Some(20), Some(10)]" | + /// +------------+------------------------+ + /// ``` + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn agg_list(&self) -> PolarsResult { + let (mut cols, agg_cols) = self.prepare_agg()?; + for agg_col in agg_cols { + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode); + let mut agg = unsafe { agg_col.agg_list(&self.groups) }; + agg.rename(new_name); + cols.push(agg); + } + DataFrame::new(cols) + } + + fn prepare_apply(&self) -> PolarsResult { + polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'"); + if let Some(agg) = &self.selected_agg { + if agg.is_empty() { + Ok(self.df.clone()) + } else { + let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len()); + new_cols.extend_from_slice(&self.selected_keys); + let cols = self.df.select_columns_impl(agg.as_slice())?; + new_cols.extend(cols); + Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) }) + } + } else { + Ok(self.df.clone()) + } + } + + /// Apply a closure over the groups as a new [`DataFrame`] in parallel. + #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] + pub fn par_apply(&self, f: F) -> PolarsResult + where + F: Fn(DataFrame) -> PolarsResult + Send + Sync, + { + let df = self.prepare_apply()?; + let dfs = self + .get_groups() + .par_iter() + .map(|g| { + // SAFETY: + // groups are in bounds + let sub_df = unsafe { take_df(&df, g) }; + f(sub_df) + }) + .collect::>>()?; + + let mut df = accumulate_dataframes_vertical(dfs)?; + df.as_single_chunk_par(); + Ok(df) + } + + /// Apply a closure over the groups as a new [`DataFrame`]. + pub fn apply(&self, mut f: F) -> PolarsResult + where + F: FnMut(DataFrame) -> PolarsResult + Send + Sync, + { + let df = self.prepare_apply()?; + let dfs = self + .get_groups() + .iter() + .map(|g| { + // SAFETY: + // groups are in bounds + let sub_df = unsafe { take_df(&df, g) }; + f(sub_df) + }) + .collect::>>()?; + + let mut df = accumulate_dataframes_vertical(dfs)?; + df.as_single_chunk_par(); + Ok(df) + } + + pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self { + match slice { + None => self, + Some((offset, length)) => { + self.groups = (self.groups.slice(offset, length)).clone(); + self.selected_keys = self.keys_sliced(slice); + self + }, + } + } +} + +unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame { + match g { + GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1), + GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize), + } +} + +#[derive(Copy, Clone, Debug)] +pub enum GroupByMethod { + Min, + NanMin, + Max, + NanMax, + Median, + Mean, + First, + Last, + Sum, + Groups, + NUnique, + Quantile(f64, QuantileMethod), + Count { include_nulls: bool }, + Implode, + Std(u8), + Var(u8), +} + +impl Display for GroupByMethod { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use GroupByMethod::*; + let s = match self { + Min => "min", + NanMin => "nan_min", + Max => "max", + NanMax => "nan_max", + Median => "median", + Mean => "mean", + First => "first", + Last => "last", + Sum => "sum", + Groups => "groups", + NUnique => "n_unique", + Quantile(_, _) => "quantile", + Count { .. } => "count", + Implode => "list", + Std(_) => "std", + Var(_) => "var", + }; + write!(f, "{s}") + } +} + +// Formatting functions used in eager and lazy code for renaming grouped columns +pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr { + use GroupByMethod::*; + match method { + Min => format_pl_smallstr!("{name}_min"), + Max => format_pl_smallstr!("{name}_max"), + NanMin => format_pl_smallstr!("{name}_nan_min"), + NanMax => format_pl_smallstr!("{name}_nan_max"), + Median => format_pl_smallstr!("{name}_median"), + Mean => format_pl_smallstr!("{name}_mean"), + First => format_pl_smallstr!("{name}_first"), + Last => format_pl_smallstr!("{name}_last"), + Sum => format_pl_smallstr!("{name}_sum"), + Groups => PlSmallStr::from_static("groups"), + NUnique => format_pl_smallstr!("{name}_n_unique"), + Count { .. } => format_pl_smallstr!("{name}_count"), + Implode => format_pl_smallstr!("{name}_agg_list"), + Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"), + Std(_) => format_pl_smallstr!("{name}_agg_std"), + Var(_) => format_pl_smallstr!("{name}_agg_var"), + } +} + +#[cfg(test)] +mod test { + use num_traits::FloatConst; + + use crate::prelude::*; + + #[test] + #[cfg(feature = "dtype-date")] + #[cfg_attr(miri, ignore)] + fn test_group_by() -> PolarsResult<()> { + let s0 = Column::new( + PlSmallStr::from_static("date"), + &[ + "2020-08-21", + "2020-08-21", + "2020-08-22", + "2020-08-23", + "2020-08-22", + ], + ); + let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]); + let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]); + let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); + + let out = df.group_by_stable(["date"])?.select(["temp"]).count()?; + assert_eq!( + out.column("temp_count")?, + &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1]) + ); + + // Use of deprecated mean() for testing purposes + #[allow(deprecated)] + // Select multiple + let out = df + .group_by_stable(["date"])? + .select(["temp", "rain"]) + .mean()?; + assert_eq!( + out.column("temp_mean")?, + &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0]) + ); + + // Use of deprecated `mean()` for testing purposes + #[allow(deprecated)] + // Group by multiple + let out = df + .group_by_stable(["date", "temp"])? + .select(["rain"]) + .mean()?; + assert!(out.column("rain_mean").is_ok()); + + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?; + assert_eq!( + out.column("temp_sum")?, + &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9]) + ); + + // Use of deprecated `n_unique()` for testing purposes + #[allow(deprecated)] + // implicit select all and only aggregate on methods that support that aggregation + let gb = df.group_by(["date"]).unwrap().n_unique().unwrap(); + // check the group by column is filtered out. + assert_eq!(gb.width(), 3); + Ok(()) + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_static_group_by_by_12_columns() { + // Build GroupBy DataFrame. + let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref()); + let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref()); + let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref()); + let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref()); + let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref()); + let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref()); + let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref()); + let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref()); + let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref()); + let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref()); + let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref()); + let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref()); + let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref()); + + let df = + DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap(); + + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + let adf = df + .group_by([ + "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", + ]) + .unwrap() + .select(["N"]) + .sum() + .unwrap(); + + assert_eq!( + Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)), + &[Some(1), Some(2), Some(2), Some(6)] + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_dynamic_group_by_by_13_columns() { + // The content for every group_by series. + let series_content = ["A", "A", "B", "B", "C"]; + + // The name of every group_by series. + let series_names = [ + "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13", + ]; + + // Vector to contain every series. + let mut columns = Vec::with_capacity(14); + + // Create a series for every group name. + for series_name in series_names { + let group_columns = Column::new(series_name.into(), series_content.as_ref()); + columns.push(group_columns); + } + + // Create a series for the aggregation column. + let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref()); + columns.push(agg_series); + + // Create the dataframe with the computed series. + let df = DataFrame::new(columns).unwrap(); + + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + // Compute the aggregated DataFrame by the 13 columns defined in `series_names`. + let adf = df + .group_by(series_names) + .unwrap() + .select(["N"]) + .sum() + .unwrap(); + + // Check that the results of the group-by are correct. The content of every column + // is equal, then, the grouped columns shall be equal and in the same order. + for series_name in &series_names { + assert_eq!( + Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)), + &[Some("A"), Some("B"), Some("C")] + ); + } + + // Check the aggregated column is the expected one. + assert_eq!( + Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)), + &[Some(3), Some(4), Some(6)] + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_group_by_floats() { + let df = df! {"flt" => [1., 1., 2., 2., 3.], + "val" => [1, 1, 1, 1, 1] + } + .unwrap(); + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + let res = df.group_by(["flt"]).unwrap().sum().unwrap(); + let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap(); + assert_eq!( + Vec::from(res.column("val_sum").unwrap().i32().unwrap()), + &[Some(2), Some(2), Some(1)] + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + #[cfg(feature = "dtype-categorical")] + fn test_group_by_categorical() { + let mut df = df! {"foo" => ["a", "a", "b", "b", "c"], + "ham" => ["a", "a", "b", "b", "c"], + "bar" => [1, 1, 1, 1, 1] + } + .unwrap(); + + df.apply("foo", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + .unwrap() + }) + .unwrap(); + + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + // check multiple keys and categorical + let res = df + .group_by_stable(["foo", "ham"]) + .unwrap() + .select(["bar"]) + .sum() + .unwrap(); + + assert_eq!( + Vec::from( + res.column("bar_sum") + .unwrap() + .as_materialized_series() + .i32() + .unwrap() + ), + &[Some(2), Some(2), Some(1)] + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_group_by_null_handling() -> PolarsResult<()> { + let df = df!( + "a" => ["a", "a", "a", "b", "b"], + "b" => [Some(1), Some(2), None, None, Some(1)] + )?; + // Use of deprecated `mean()` for testing purposes + #[allow(deprecated)] + let out = df.group_by_stable(["a"])?.mean()?; + + assert_eq!( + Vec::from(out.column("b_mean")?.as_materialized_series().f64()?), + &[Some(1.5), Some(1.0)] + ); + Ok(()) + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_group_by_var() -> PolarsResult<()> { + // check variance and proper coercion to f64 + let df = df![ + "g" => ["foo", "foo", "bar"], + "flt" => [1.0, 2.0, 3.0], + "int" => [1, 2, 3] + ]?; + + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?; + + assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5)); + // Use of deprecated `std()` for testing purposes + #[allow(deprecated)] + let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?; + let val = out.column("int_agg_std")?.f64()?.get(0).unwrap(); + let expected = f64::FRAC_1_SQRT_2(); + assert!((val - expected).abs() < 0.000001); + Ok(()) + } + + #[test] + #[cfg_attr(miri, ignore)] + #[cfg(feature = "dtype-categorical")] + fn test_group_by_null_group() -> PolarsResult<()> { + // check if null is own group + let mut df = df![ + "g" => [Some("foo"), Some("foo"), Some("bar"), None, None], + "flt" => [1.0, 2.0, 3.0, 1.0, 1.0], + "int" => [1, 2, 3, 1, 1] + ]?; + + df.try_apply("g", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + })?; + + // Use of deprecated `sum()` for testing purposes + #[allow(deprecated)] + let _ = df.group_by(["g"])?.sum()?; + Ok(()) + } +} diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs new file mode 100644 index 000000000000..04ce4c54970e --- /dev/null +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -0,0 +1,197 @@ +use std::fmt::Debug; +use std::mem::MaybeUninit; + +use num_traits::{FromPrimitive, ToPrimitive}; +use polars_utils::idx_vec::IdxVec; +use polars_utils::sync::SyncPtr; +use rayon::prelude::*; + +use crate::POOL; +#[cfg(all(feature = "dtype-categorical", feature = "performant"))] +use crate::config::verbose; +use crate::datatypes::*; +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsIntegerType, + T::Native: ToPrimitive + FromPrimitive + Debug, +{ + /// Use the indexes as perfect groups. + /// + /// # Safety + /// This ChunkedArray must contain each value in [0..num_groups) at least + /// once, and nothing outside this range. + pub unsafe fn group_tuples_perfect( + &self, + num_groups: usize, + mut multithreaded: bool, + group_capacity: usize, + ) -> GroupsType { + multithreaded &= POOL.current_num_threads() > 1; + // The latest index will be used for the null sentinel. + let len = if self.null_count() > 0 { + // We add one to store the null sentinel group. + num_groups + 1 + } else { + num_groups + }; + let null_idx = len.saturating_sub(1); + + let n_threads = POOL.current_num_threads(); + let chunk_size = len / n_threads; + + let (groups, first) = if multithreaded && chunk_size > 1 { + let mut groups: Vec = Vec::new(); + groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); + let mut first: Vec = Vec::with_capacity(len); + + // Round up offsets to nearest cache line for groups to reduce false sharing. + let groups_start = groups.as_ptr(); + let mut per_thread_offsets = Vec::with_capacity(n_threads + 1); + per_thread_offsets.push(0); + for t in 0..n_threads { + let ideal_offset = (t + 1) * chunk_size; + let cache_aligned_offset = + ideal_offset + groups_start.wrapping_add(ideal_offset).align_offset(128); + if t == n_threads - 1 { + per_thread_offsets.push(len); + } else { + per_thread_offsets.push(std::cmp::min(cache_aligned_offset, len)); + } + } + + let groups_ptr = unsafe { SyncPtr::new(groups.as_mut_ptr()) }; + let first_ptr = unsafe { SyncPtr::new(first.as_mut_ptr()) }; + POOL.install(|| { + (0..n_threads).into_par_iter().for_each(|thread_no| { + // We use raw pointers because the slices would overlap. + // However, each thread has its own range it is responsible for. + let groups = groups_ptr.get(); + let first = first_ptr.get(); + let start = per_thread_offsets[thread_no]; + let start = T::Native::from_usize(start).unwrap(); + let end = per_thread_offsets[thread_no + 1]; + let end = T::Native::from_usize(end).unwrap(); + + if start == end && thread_no != n_threads - 1 { + return; + }; + + let push_to_group = |cat, row_nr| unsafe { + debug_assert!(cat < len); + let buf = &mut *groups.add(cat); + buf.push(row_nr); + if buf.len() == 1 { + *first.add(cat) = row_nr; + } + }; + + let mut row_nr = 0 as IdxSize; + for arr in self.downcast_iter() { + if arr.null_count() == 0 { + for &cat in arr.values().as_slice() { + if cat >= start && cat < end { + push_to_group(cat.to_usize().unwrap(), row_nr); + } + + row_nr += 1; + } + } else { + for opt_cat in arr.iter() { + if let Some(&cat) = opt_cat { + if cat >= start && cat < end { + push_to_group(cat.to_usize().unwrap(), row_nr); + } + } else if thread_no == n_threads - 1 { + // Last thread handles null values. + push_to_group(null_idx, row_nr); + } + + row_nr += 1; + } + } + } + }); + }); + unsafe { + first.set_len(len); + } + (groups, first) + } else { + let mut groups = Vec::with_capacity(len); + let mut first = Vec::with_capacity(len); + let first_out = first.spare_capacity_mut(); + groups.resize_with(len, || IdxVec::with_capacity(group_capacity)); + + let mut push_to_group = |cat, row_nr| unsafe { + let buf: &mut IdxVec = groups.get_unchecked_mut(cat); + buf.push(row_nr); + if buf.len() == 1 { + *first_out.get_unchecked_mut(cat) = MaybeUninit::new(row_nr); + } + }; + + let mut row_nr = 0 as IdxSize; + for arr in self.downcast_iter() { + for opt_cat in arr.iter() { + if let Some(cat) = opt_cat { + push_to_group(cat.to_usize().unwrap(), row_nr); + } else { + push_to_group(null_idx, row_nr); + } + + row_nr += 1; + } + } + unsafe { + first.set_len(len); + } + (groups, first) + }; + + // NOTE! we set sorted here! + // this happens to be true for `fast_unique` categoricals + GroupsType::Idx(GroupsIdx::new(first, groups, true)) + } +} + +#[cfg(all(feature = "dtype-categorical", feature = "performant"))] +// Special implementation so that cats can be processed in a single pass +impl CategoricalChunked { + // Use the indexes as perfect groups + pub fn group_tuples_perfect(&self, multithreaded: bool, sorted: bool) -> GroupsType { + let rev_map = self.get_rev_map(); + if self.is_empty() { + return GroupsType::Idx(GroupsIdx::new(vec![], vec![], true)); + } + let cats = self.physical(); + + let mut out = match &**rev_map { + RevMapping::Local(cached, _) => { + if self._can_fast_unique() { + assert!(cached.len() <= self.len(), "invalid invariant"); + if verbose() { + eprintln!("grouping categoricals, run perfect hash function"); + } + // on relative small tables this isn't much faster than the default strategy + // but on huge tables, this can be > 2x faster + unsafe { cats.group_tuples_perfect(cached.len(), multithreaded, 0) } + } else { + self.physical().group_tuples(multithreaded, sorted).unwrap() + } + }, + RevMapping::Global(_mapping, _cached, _) => { + // TODO! see if we can optimize this + // the problem is that the global categories are not guaranteed packed together + // so we might need to deref them first to local ones, but that might be more + // expensive than just hashing (benchmark first) + self.physical().group_tuples(multithreaded, sorted).unwrap() + }, + }; + if sorted { + out.sort() + } + out + } +} diff --git a/crates/polars-core/src/frame/group_by/position.rs b/crates/polars-core/src/frame/group_by/position.rs new file mode 100644 index 000000000000..cbd6089e98fa --- /dev/null +++ b/crates/polars-core/src/frame/group_by/position.rs @@ -0,0 +1,734 @@ +use std::mem::ManuallyDrop; +use std::ops::{Deref, DerefMut}; + +use arrow::offset::OffsetsBuffer; +use polars_utils::idx_vec::IdxVec; +use rayon::iter::plumbing::UnindexedConsumer; +use rayon::prelude::*; + +use crate::POOL; +use crate::prelude::*; +use crate::utils::{NoNull, flatten, slice_slice}; + +/// Indexes of the groups, the first index is stored separately. +/// this make sorting fast. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct GroupsIdx { + pub(crate) sorted: bool, + first: Vec, + all: Vec, +} + +pub type IdxItem = (IdxSize, IdxVec); +pub type BorrowIdxItem<'a> = (IdxSize, &'a IdxVec); + +impl Drop for GroupsIdx { + fn drop(&mut self) { + let v = std::mem::take(&mut self.all); + // ~65k took approximately 1ms on local machine, so from that point we drop on other thread + // to stop query from being blocked + #[cfg(not(target_family = "wasm"))] + if v.len() > 1 << 16 { + std::thread::spawn(move || drop(v)); + } else { + drop(v); + } + + #[cfg(target_family = "wasm")] + drop(v); + } +} + +impl From> for GroupsIdx { + fn from(v: Vec) -> Self { + v.into_iter().collect() + } +} + +impl From>> for GroupsIdx { + fn from(v: Vec>) -> Self { + // single threaded flatten: 10% faster than `iter().flatten().collect() + // this is the multi-threaded impl of that + let (cap, offsets) = flatten::cap_and_offsets(&v); + let mut first = Vec::with_capacity(cap); + let first_ptr = first.as_ptr() as usize; + let mut all = Vec::with_capacity(cap); + let all_ptr = all.as_ptr() as usize; + + POOL.install(|| { + v.into_par_iter() + .zip(offsets) + .for_each(|(mut inner, offset)| { + unsafe { + let first = (first_ptr as *const IdxSize as *mut IdxSize).add(offset); + let all = (all_ptr as *const IdxVec as *mut IdxVec).add(offset); + + let inner_ptr = inner.as_mut_ptr(); + for i in 0..inner.len() { + let (first_val, vals) = std::ptr::read(inner_ptr.add(i)); + std::ptr::write(first.add(i), first_val); + std::ptr::write(all.add(i), vals); + } + // set len to 0 so that the contents will not get dropped + // they are moved to `first` and `all` + inner.set_len(0); + } + }); + }); + unsafe { + all.set_len(cap); + first.set_len(cap); + } + GroupsIdx { + sorted: false, + first, + all, + } + } +} + +impl GroupsIdx { + pub fn new(first: Vec, all: Vec, sorted: bool) -> Self { + Self { sorted, first, all } + } + + pub fn sort(&mut self) { + if self.sorted { + return; + } + let mut idx = 0; + let first = std::mem::take(&mut self.first); + // store index and values so that we can sort those + let mut idx_vals = first + .into_iter() + .map(|v| { + let out = [idx, v]; + idx += 1; + out + }) + .collect_trusted::>(); + idx_vals.sort_unstable_by_key(|v| v[1]); + + let take_first = || idx_vals.iter().map(|v| v[1]).collect_trusted::>(); + let take_all = || { + idx_vals + .iter() + .map(|v| unsafe { + let idx = v[0] as usize; + std::mem::take(self.all.get_unchecked_mut(idx)) + }) + .collect_trusted::>() + }; + let (first, all) = POOL.install(|| rayon::join(take_first, take_all)); + self.first = first; + self.all = all; + self.sorted = true + } + pub fn is_sorted_flag(&self) -> bool { + self.sorted + } + + pub fn iter( + &self, + ) -> std::iter::Zip>, std::slice::Iter> + { + self.into_iter() + } + + pub fn all(&self) -> &[IdxVec] { + &self.all + } + + pub fn first(&self) -> &[IdxSize] { + &self.first + } + + pub fn first_mut(&mut self) -> &mut Vec { + &mut self.first + } + + pub(crate) fn len(&self) -> usize { + self.first.len() + } + + pub(crate) unsafe fn get_unchecked(&self, index: usize) -> BorrowIdxItem { + let first = *self.first.get_unchecked(index); + let all = self.all.get_unchecked(index); + (first, all) + } +} + +impl FromIterator for GroupsIdx { + fn from_iter>(iter: T) -> Self { + let (first, all) = iter.into_iter().unzip(); + GroupsIdx { + sorted: false, + first, + all, + } + } +} + +impl<'a> IntoIterator for &'a GroupsIdx { + type Item = BorrowIdxItem<'a>; + type IntoIter = std::iter::Zip< + std::iter::Copied>, + std::slice::Iter<'a, IdxVec>, + >; + + fn into_iter(self) -> Self::IntoIter { + self.first.iter().copied().zip(self.all.iter()) + } +} + +impl IntoIterator for GroupsIdx { + type Item = IdxItem; + type IntoIter = std::iter::Zip, std::vec::IntoIter>; + + fn into_iter(mut self) -> Self::IntoIter { + let first = std::mem::take(&mut self.first); + let all = std::mem::take(&mut self.all); + first.into_iter().zip(all) + } +} + +impl FromParallelIterator for GroupsIdx { + fn from_par_iter(par_iter: I) -> Self + where + I: IntoParallelIterator, + { + let (first, all) = par_iter.into_par_iter().unzip(); + GroupsIdx { + sorted: false, + first, + all, + } + } +} + +impl<'a> IntoParallelIterator for &'a GroupsIdx { + type Iter = rayon::iter::Zip< + rayon::iter::Copied>, + rayon::slice::Iter<'a, IdxVec>, + >; + type Item = BorrowIdxItem<'a>; + + fn into_par_iter(self) -> Self::Iter { + self.first.par_iter().copied().zip(self.all.par_iter()) + } +} + +impl IntoParallelIterator for GroupsIdx { + type Iter = rayon::iter::Zip, rayon::vec::IntoIter>; + type Item = IdxItem; + + fn into_par_iter(mut self) -> Self::Iter { + let first = std::mem::take(&mut self.first); + let all = std::mem::take(&mut self.all); + first.into_par_iter().zip(all.into_par_iter()) + } +} + +/// Every group is indicated by an array where the +/// - first value is an index to the start of the group +/// - second value is the length of the group +/// +/// Only used when group values are stored together +/// +/// This type should have the invariant that it is always sorted in ascending order. +pub type GroupsSlice = Vec<[IdxSize; 2]>; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum GroupsType { + Idx(GroupsIdx), + /// Slice is always sorted in ascending order. + Slice { + // the groups slices + groups: GroupsSlice, + // indicates if we do a rolling group_by + rolling: bool, + }, +} + +impl Default for GroupsType { + fn default() -> Self { + GroupsType::Idx(GroupsIdx::default()) + } +} + +impl GroupsType { + pub fn into_idx(self) -> GroupsIdx { + match self { + GroupsType::Idx(groups) => groups, + GroupsType::Slice { groups, .. } => { + polars_warn!( + "Had to reallocate groups, missed an optimization opportunity. Please open an issue." + ); + groups + .iter() + .map(|&[first, len]| (first, (first..first + len).collect::())) + .collect() + }, + } + } + + pub(crate) fn prepare_list_agg( + &self, + total_len: usize, + ) -> (Option, OffsetsBuffer, bool) { + let mut can_fast_explode = true; + match self { + GroupsType::Idx(groups) => { + let mut list_offset = Vec::with_capacity(self.len() + 1); + let mut gather_offsets = Vec::with_capacity(total_len); + + let mut len_so_far = 0i64; + list_offset.push(len_so_far); + + for idx in groups { + let idx = idx.1; + gather_offsets.extend_from_slice(idx); + len_so_far += idx.len() as i64; + list_offset.push(len_so_far); + can_fast_explode &= !idx.is_empty(); + } + unsafe { + ( + Some(IdxCa::from_vec(PlSmallStr::EMPTY, gather_offsets)), + OffsetsBuffer::new_unchecked(list_offset.into()), + can_fast_explode, + ) + } + }, + GroupsType::Slice { groups, .. } => { + let mut list_offset = Vec::with_capacity(self.len() + 1); + let mut gather_offsets = Vec::with_capacity(total_len); + let mut len_so_far = 0i64; + list_offset.push(len_so_far); + + for g in groups { + let len = g[1]; + let offset = g[0]; + gather_offsets.extend(offset..offset + len); + + len_so_far += len as i64; + list_offset.push(len_so_far); + can_fast_explode &= len > 0; + } + + unsafe { + ( + Some(IdxCa::from_vec(PlSmallStr::EMPTY, gather_offsets)), + OffsetsBuffer::new_unchecked(list_offset.into()), + can_fast_explode, + ) + } + }, + } + } + + pub fn iter(&self) -> GroupsTypeIter { + GroupsTypeIter::new(self) + } + + pub fn sort(&mut self) { + match self { + GroupsType::Idx(groups) => { + if !groups.is_sorted_flag() { + groups.sort() + } + }, + GroupsType::Slice { .. } => { + // invariant of the type + }, + } + } + + pub(crate) fn is_sorted_flag(&self) -> bool { + match self { + GroupsType::Idx(groups) => groups.is_sorted_flag(), + GroupsType::Slice { .. } => true, + } + } + + pub fn take_group_firsts(self) -> Vec { + match self { + GroupsType::Idx(mut groups) => std::mem::take(&mut groups.first), + GroupsType::Slice { groups, .. } => { + groups.into_iter().map(|[first, _len]| first).collect() + }, + } + } + + /// # Safety + /// This will not do any bounds checks. The caller must ensure + /// all groups have members. + pub unsafe fn take_group_lasts(self) -> Vec { + match self { + GroupsType::Idx(groups) => groups + .all + .iter() + .map(|idx| *idx.get_unchecked(idx.len() - 1)) + .collect(), + GroupsType::Slice { groups, .. } => groups + .into_iter() + .map(|[first, len]| first + len - 1) + .collect(), + } + } + + pub fn par_iter(&self) -> GroupsTypeParIter { + GroupsTypeParIter::new(self) + } + + /// Get a reference to the `GroupsIdx`. + /// + /// # Panic + /// + /// panics if the groups are a slice. + pub fn unwrap_idx(&self) -> &GroupsIdx { + match self { + GroupsType::Idx(groups) => groups, + GroupsType::Slice { .. } => panic!("groups are slices not index"), + } + } + + /// Get a reference to the `GroupsSlice`. + /// + /// # Panic + /// + /// panics if the groups are an idx. + pub fn unwrap_slice(&self) -> &GroupsSlice { + match self { + GroupsType::Slice { groups, .. } => groups, + GroupsType::Idx(_) => panic!("groups are index not slices"), + } + } + + pub fn get(&self, index: usize) -> GroupsIndicator { + match self { + GroupsType::Idx(groups) => { + let first = groups.first[index]; + let all = &groups.all[index]; + GroupsIndicator::Idx((first, all)) + }, + GroupsType::Slice { groups, .. } => GroupsIndicator::Slice(groups[index]), + } + } + + /// Get a mutable reference to the `GroupsIdx`. + /// + /// # Panic + /// + /// panics if the groups are a slice. + pub fn idx_mut(&mut self) -> &mut GroupsIdx { + match self { + GroupsType::Idx(groups) => groups, + GroupsType::Slice { .. } => panic!("groups are slices not index"), + } + } + + pub fn len(&self) -> usize { + match self { + GroupsType::Idx(groups) => groups.len(), + GroupsType::Slice { groups, .. } => groups.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn group_count(&self) -> IdxCa { + match self { + GroupsType::Idx(groups) => { + let ca: NoNull = groups + .iter() + .map(|(_first, idx)| idx.len() as IdxSize) + .collect_trusted(); + ca.into_inner() + }, + GroupsType::Slice { groups, .. } => { + let ca: NoNull = groups.iter().map(|[_first, len]| *len).collect_trusted(); + ca.into_inner() + }, + } + } + pub fn as_list_chunked(&self) -> ListChunked { + match self { + GroupsType::Idx(groups) => groups + .iter() + .map(|(_first, idx)| { + let ca: NoNull = idx.iter().map(|&v| v as IdxSize).collect(); + ca.into_inner().into_series() + }) + .collect_trusted(), + GroupsType::Slice { groups, .. } => groups + .iter() + .map(|&[first, len]| { + let ca: NoNull = (first..first + len).collect_trusted(); + ca.into_inner().into_series() + }) + .collect_trusted(), + } + } + + pub fn into_sliceable(self) -> GroupPositions { + let len = self.len(); + slice_groups(Arc::new(self), 0, len) + } +} + +impl From for GroupsType { + fn from(groups: GroupsIdx) -> Self { + GroupsType::Idx(groups) + } +} + +pub enum GroupsIndicator<'a> { + Idx(BorrowIdxItem<'a>), + Slice([IdxSize; 2]), +} + +impl GroupsIndicator<'_> { + pub fn len(&self) -> usize { + match self { + GroupsIndicator::Idx(g) => g.1.len(), + GroupsIndicator::Slice([_, len]) => *len as usize, + } + } + pub fn first(&self) -> IdxSize { + match self { + GroupsIndicator::Idx(g) => g.0, + GroupsIndicator::Slice([first, _]) => *first, + } + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +pub struct GroupsTypeIter<'a> { + vals: &'a GroupsType, + len: usize, + idx: usize, +} + +impl<'a> GroupsTypeIter<'a> { + fn new(vals: &'a GroupsType) -> Self { + let len = vals.len(); + let idx = 0; + GroupsTypeIter { vals, len, idx } + } +} + +impl<'a> Iterator for GroupsTypeIter<'a> { + type Item = GroupsIndicator<'a>; + + fn nth(&mut self, n: usize) -> Option { + self.idx = self.idx.saturating_add(n); + self.next() + } + + fn next(&mut self) -> Option { + if self.idx >= self.len { + return None; + } + + let out = unsafe { + match self.vals { + GroupsType::Idx(groups) => { + let item = groups.get_unchecked(self.idx); + Some(GroupsIndicator::Idx(item)) + }, + GroupsType::Slice { groups, .. } => { + Some(GroupsIndicator::Slice(*groups.get_unchecked(self.idx))) + }, + } + }; + self.idx += 1; + out + } +} + +pub struct GroupsTypeParIter<'a> { + vals: &'a GroupsType, + len: usize, +} + +impl<'a> GroupsTypeParIter<'a> { + fn new(vals: &'a GroupsType) -> Self { + let len = vals.len(); + GroupsTypeParIter { vals, len } + } +} + +impl<'a> ParallelIterator for GroupsTypeParIter<'a> { + type Item = GroupsIndicator<'a>; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + (0..self.len) + .into_par_iter() + .map(|i| unsafe { + match self.vals { + GroupsType::Idx(groups) => GroupsIndicator::Idx(groups.get_unchecked(i)), + GroupsType::Slice { groups, .. } => { + GroupsIndicator::Slice(*groups.get_unchecked(i)) + }, + } + }) + .drive_unindexed(consumer) + } +} + +#[derive(Debug)] +pub struct GroupPositions { + // SAFETY: sliced is a shallow clone of original + // It emulates a shared reference, not an exclusive reference + // Its data must not be mutated through direct access + sliced: ManuallyDrop, + // Unsliced buffer + original: Arc, + offset: i64, + len: usize, +} + +impl Clone for GroupPositions { + fn clone(&self) -> Self { + let sliced = slice_groups_inner(&self.original, self.offset, self.len); + + Self { + sliced, + original: self.original.clone(), + offset: self.offset, + len: self.len, + } + } +} + +impl PartialEq for GroupPositions { + fn eq(&self, other: &Self) -> bool { + self.offset == other.offset && self.len == other.len && self.sliced == other.sliced + } +} + +impl AsRef for GroupPositions { + fn as_ref(&self) -> &GroupsType { + self.sliced.deref() + } +} + +impl Deref for GroupPositions { + type Target = GroupsType; + + fn deref(&self) -> &Self::Target { + self.sliced.deref() + } +} + +impl Default for GroupPositions { + fn default() -> Self { + GroupsType::default().into_sliceable() + } +} + +impl GroupPositions { + pub fn slice(&self, offset: i64, len: usize) -> Self { + let offset = self.offset + offset; + slice_groups( + self.original.clone(), + offset, + // invariant that len should be in bounds, so truncate if not + if len > self.len { self.len } else { len }, + ) + } + + pub fn sort(&mut self) { + if !self.as_ref().is_sorted_flag() { + let original = Arc::make_mut(&mut self.original); + original.sort(); + + self.sliced = slice_groups_inner(original, self.offset, self.len); + } + } + + pub fn unroll(mut self) -> GroupPositions { + match self.sliced.deref_mut() { + GroupsType::Idx(_) => self, + GroupsType::Slice { rolling: false, .. } => self, + GroupsType::Slice { groups, .. } => { + // SAFETY: sliced is a shallow partial clone of original. + // A new owning Vec is required per GH issue #21859 + let mut cum_offset = 0 as IdxSize; + let groups: Vec<_> = groups + .iter() + .map(|[_, len]| { + let new = [cum_offset, *len]; + cum_offset += *len; + new + }) + .collect(); + + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable() + }, + } + } +} + +fn slice_groups_inner(g: &GroupsType, offset: i64, len: usize) -> ManuallyDrop { + // SAFETY: + // we create new `Vec`s from the sliced groups. But we wrap them in ManuallyDrop + // so that we never call drop on them. + // These groups lifetimes are bounded to the `g`. This must remain valid + // for the scope of the aggregation. + match g { + GroupsType::Idx(groups) => { + let first = unsafe { + let first = slice_slice(groups.first(), offset, len); + let ptr = first.as_ptr() as *mut _; + Vec::from_raw_parts(ptr, first.len(), first.len()) + }; + + let all = unsafe { + let all = slice_slice(groups.all(), offset, len); + let ptr = all.as_ptr() as *mut _; + Vec::from_raw_parts(ptr, all.len(), all.len()) + }; + ManuallyDrop::new(GroupsType::Idx(GroupsIdx::new( + first, + all, + groups.is_sorted_flag(), + ))) + }, + GroupsType::Slice { groups, rolling } => { + let groups = unsafe { + let groups = slice_slice(groups, offset, len); + let ptr = groups.as_ptr() as *mut _; + Vec::from_raw_parts(ptr, groups.len(), groups.len()) + }; + + ManuallyDrop::new(GroupsType::Slice { + groups, + rolling: *rolling, + }) + }, + } +} + +fn slice_groups(g: Arc, offset: i64, len: usize) -> GroupPositions { + let sliced = slice_groups_inner(g.as_ref(), offset, len); + + GroupPositions { + sliced, + original: g, + offset, + len, + } +} diff --git a/crates/polars-core/src/frame/horizontal.rs b/crates/polars-core/src/frame/horizontal.rs new file mode 100644 index 000000000000..5061890204b4 --- /dev/null +++ b/crates/polars-core/src/frame/horizontal.rs @@ -0,0 +1,156 @@ +use polars_error::{PolarsResult, polars_err}; + +use super::Column; +use crate::datatypes::AnyValue; +use crate::frame::DataFrame; + +impl DataFrame { + /// Add columns horizontally. + /// + /// # Safety + /// The caller must ensure: + /// - the length of all [`Column`] is equal to the height of this [`DataFrame`] + /// - the columns names are unique + /// + /// Note: If `self` is empty, `self.height` will always be overridden by the height of the first + /// column in `columns`. + /// + /// Note that on a debug build this will panic on duplicates / height mismatch. + pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Column]) -> &mut Self { + self.clear_schema(); + self.columns.extend_from_slice(columns); + + if cfg!(debug_assertions) { + if let err @ Err(_) = DataFrame::validate_columns_slice(&self.columns) { + // Reset DataFrame state to before extend. + self.columns.truncate(self.columns.len() - columns.len()); + err.unwrap(); + } + } + + if let Some(c) = self.columns.first() { + unsafe { self.set_height(c.len()) }; + } + + self + } + + /// Add multiple [`Column`] to a [`DataFrame`]. + /// Errors if the resulting DataFrame columns have duplicate names or unequal heights. + /// + /// Note: If `self` is empty, `self.height` will always be overridden by the height of the first + /// column in `columns`. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn stack(df: &mut DataFrame, columns: &[Column]) { + /// df.hstack_mut(columns); + /// } + /// ``` + pub fn hstack_mut(&mut self, columns: &[Column]) -> PolarsResult<&mut Self> { + self.clear_schema(); + self.columns.extend_from_slice(columns); + + if let err @ Err(_) = DataFrame::validate_columns_slice(&self.columns) { + // Reset DataFrame state to before extend. + self.columns.truncate(self.columns.len() - columns.len()); + err?; + } + + if let Some(c) = self.columns.first() { + unsafe { self.set_height(c.len()) }; + } + + Ok(self) + } +} + +/// Concat [`DataFrame`]s horizontally. +/// Concat horizontally and extend with null values if lengths don't match +pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult { + let output_height = dfs + .iter() + .map(|df| df.height()) + .max() + .ok_or_else(|| polars_err!(ComputeError: "cannot concat empty dataframes"))?; + + let owned_df; + + let mut out_width = 0; + + let all_equal_height = dfs.iter().all(|df| { + out_width += df.width(); + df.height() == output_height + }); + + // if not all equal length, extend the DataFrame with nulls + let dfs = if !all_equal_height { + out_width = 0; + + owned_df = dfs + .iter() + .cloned() + .map(|mut df| { + out_width += df.width(); + + if df.height() != output_height { + let diff = output_height - df.height(); + + // SAFETY: We extend each column with nulls to the point of being of length + // `output_height`. Then, we set the height of the resulting dataframe. + unsafe { df.get_columns_mut() }.iter_mut().for_each(|c| { + *c = c.extend_constant(AnyValue::Null, diff).unwrap(); + }); + df.clear_schema(); + unsafe { + df.set_height(output_height); + } + } + df + }) + .collect::>(); + owned_df.as_slice() + } else { + dfs + }; + + let mut acc_cols = Vec::with_capacity(out_width); + + for df in dfs { + acc_cols.extend(df.get_columns().iter().cloned()); + } + + if check_duplicates { + DataFrame::validate_columns_slice(&acc_cols)?; + } + + let df = unsafe { DataFrame::new_no_checks_height_from_first(acc_cols) }; + + Ok(df) +} + +#[cfg(test)] +mod tests { + use polars_error::PolarsError; + + #[test] + fn test_hstack_mut_empty_frame_height_validation() { + use crate::frame::DataFrame; + use crate::prelude::{Column, DataType}; + let mut df = DataFrame::empty(); + let result = df.hstack_mut(&[ + Column::full_null("a".into(), 1, &DataType::Null), + Column::full_null("b".into(), 3, &DataType::Null), + ]); + + assert!( + matches!(result, Err(PolarsError::ShapeMismatch(_))), + "expected shape mismatch error" + ); + + // Ensure the DataFrame is not mutated in the error case. + assert_eq!(df.width(), 0); + } +} diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs new file mode 100644 index 000000000000..106d596202ac --- /dev/null +++ b/crates/polars-core/src/frame/mod.rs @@ -0,0 +1,3637 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! DataFrame module. +use std::sync::OnceLock; +use std::{mem, ops}; + +use arrow::datatypes::ArrowSchemaRef; +use polars_row::ArrayRef; +use polars_schema::schema::ensure_matching_schema_names; +use polars_utils::itertools::Itertools; +use rayon::prelude::*; + +use crate::chunked_array::flags::StatisticsFlags; +#[cfg(feature = "algorithm_group_by")] +use crate::chunked_array::ops::unique::is_unique_helper; +use crate::prelude::*; +#[cfg(feature = "row_hash")] +use crate::utils::split_df; +use crate::utils::{Container, NoNull, slice_offsets, try_get_supertype}; +use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; + +#[cfg(feature = "dataframe_arithmetic")] +mod arithmetic; +pub mod builder; +mod chunks; +pub use chunks::chunk_df_for_writing; +pub mod column; +pub mod explode; +mod from; +#[cfg(feature = "algorithm_group_by")] +pub mod group_by; +pub(crate) mod horizontal; +#[cfg(any(feature = "rows", feature = "object"))] +pub mod row; +mod top_k; +mod upstream_traits; +mod validation; + +use arrow::record_batch::{RecordBatch, RecordBatchT}; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +use crate::POOL; +#[cfg(feature = "row_hash")] +use crate::hashing::_df_rows_to_hashes_threaded_vertical; +use crate::prelude::sort::{argsort_multiple_row_fmt, prepare_arg_sort}; +use crate::series::IsSorted; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum UniqueKeepStrategy { + /// Keep the first unique row. + First, + /// Keep the last unique row. + Last, + /// Keep None of the unique rows. + None, + /// Keep any of the unique rows + /// This allows more optimizations + #[default] + Any, +} + +fn ensure_names_unique(items: &[T], mut get_name: F) -> PolarsResult<()> +where + F: for<'a> FnMut(&'a T) -> &'a str, +{ + // Always unique. + if items.len() <= 1 { + return Ok(()); + } + + if items.len() <= 4 { + // Too small to be worth spawning a hashmap for, this is at most 6 comparisons. + for i in 0..items.len() - 1 { + let name = get_name(&items[i]); + for other in items.iter().skip(i + 1) { + if name == get_name(other) { + polars_bail!(duplicate = name); + } + } + } + } else { + let mut names = PlHashSet::with_capacity(items.len()); + for item in items { + let name = get_name(item); + if !names.insert(name) { + polars_bail!(duplicate = name); + } + } + } + Ok(()) +} + +/// A contiguous growable collection of `Series` that have the same length. +/// +/// ## Use declarations +/// +/// All the common tools can be found in [`crate::prelude`] (or in `polars::prelude`). +/// +/// ```rust +/// use polars_core::prelude::*; // if the crate polars-core is used directly +/// // use polars::prelude::*; if the crate polars is used +/// ``` +/// +/// # Initialization +/// ## Default +/// +/// A `DataFrame` can be initialized empty: +/// +/// ```rust +/// # use polars_core::prelude::*; +/// let df = DataFrame::default(); +/// assert!(df.is_empty()); +/// ``` +/// +/// ## Wrapping a `Vec` +/// +/// A `DataFrame` is built upon a `Vec` where the `Series` have the same length. +/// +/// ```rust +/// # use polars_core::prelude::*; +/// let s1 = Column::new("Fruit".into(), ["Apple", "Apple", "Pear"]); +/// let s2 = Column::new("Color".into(), ["Red", "Yellow", "Green"]); +/// +/// let df: PolarsResult = DataFrame::new(vec![s1, s2]); +/// ``` +/// +/// ## Using a macro +/// +/// The [`df!`] macro is a convenient method: +/// +/// ```rust +/// # use polars_core::prelude::*; +/// let df: PolarsResult = df!("Fruit" => ["Apple", "Apple", "Pear"], +/// "Color" => ["Red", "Yellow", "Green"]); +/// ``` +/// +/// ## Using a CSV file +/// +/// See the `polars_io::csv::CsvReader`. +/// +/// # Indexing +/// ## By a number +/// +/// The `Index` is implemented for the `DataFrame`. +/// +/// ```rust +/// # use polars_core::prelude::*; +/// let df = df!("Fruit" => ["Apple", "Apple", "Pear"], +/// "Color" => ["Red", "Yellow", "Green"])?; +/// +/// assert_eq!(df[0], Column::new("Fruit".into(), &["Apple", "Apple", "Pear"])); +/// assert_eq!(df[1], Column::new("Color".into(), &["Red", "Yellow", "Green"])); +/// # Ok::<(), PolarsError>(()) +/// ``` +/// +/// ## By a `Series` name +/// +/// ```rust +/// # use polars_core::prelude::*; +/// let df = df!("Fruit" => ["Apple", "Apple", "Pear"], +/// "Color" => ["Red", "Yellow", "Green"])?; +/// +/// assert_eq!(df["Fruit"], Column::new("Fruit".into(), &["Apple", "Apple", "Pear"])); +/// assert_eq!(df["Color"], Column::new("Color".into(), &["Red", "Yellow", "Green"])); +/// # Ok::<(), PolarsError>(()) +/// ``` +#[derive(Clone)] +pub struct DataFrame { + height: usize, + // invariant: columns[i].len() == height for each 0 >= i > columns.len() + pub(crate) columns: Vec, + + /// A cached schema. This might not give correct results if the DataFrame was modified in place + /// between schema and reading. + cached_schema: OnceLock, +} + +impl DataFrame { + pub fn clear_schema(&mut self) { + self.cached_schema = OnceLock::new(); + } + + #[inline] + pub fn column_iter(&self) -> impl ExactSizeIterator { + self.columns.iter() + } + + #[inline] + pub fn materialized_column_iter(&self) -> impl ExactSizeIterator { + self.columns.iter().map(Column::as_materialized_series) + } + + #[inline] + pub fn par_materialized_column_iter(&self) -> impl ParallelIterator { + self.columns.par_iter().map(Column::as_materialized_series) + } + + /// Returns an estimation of the total (heap) allocated size of the `DataFrame` in bytes. + /// + /// # Implementation + /// This estimation is the sum of the size of its buffers, validity, including nested arrays. + /// Multiple arrays may share buffers and bitmaps. Therefore, the size of 2 arrays is not the + /// sum of the sizes computed from this function. In particular, [`StructArray`]'s size is an upper bound. + /// + /// When an array is sliced, its allocated size remains constant because the buffer unchanged. + /// However, this function will yield a smaller number. This is because this function returns + /// the visible size of the buffer, not its total capacity. + /// + /// FFI buffers are included in this estimation. + pub fn estimated_size(&self) -> usize { + self.columns.iter().map(Column::estimated_size).sum() + } + + // Reduce monomorphization. + fn try_apply_columns( + &self, + func: &(dyn Fn(&Column) -> PolarsResult + Send + Sync), + ) -> PolarsResult> { + self.columns.iter().map(func).collect() + } + // Reduce monomorphization. + pub fn _apply_columns(&self, func: &(dyn Fn(&Column) -> Column)) -> Vec { + self.columns.iter().map(func).collect() + } + // Reduce monomorphization. + fn try_apply_columns_par( + &self, + func: &(dyn Fn(&Column) -> PolarsResult + Send + Sync), + ) -> PolarsResult> { + POOL.install(|| self.columns.par_iter().map(func).collect()) + } + // Reduce monomorphization. + pub fn _apply_columns_par( + &self, + func: &(dyn Fn(&Column) -> Column + Send + Sync), + ) -> Vec { + POOL.install(|| self.columns.par_iter().map(func).collect()) + } + + /// Get the index of the column. + fn check_name_to_idx(&self, name: &str) -> PolarsResult { + self.get_column_index(name) + .ok_or_else(|| polars_err!(col_not_found = name)) + } + + fn check_already_present(&self, name: &str) -> PolarsResult<()> { + polars_ensure!( + self.columns.iter().all(|s| s.name().as_str() != name), + Duplicate: "column with name {:?} is already present in the DataFrame", name + ); + Ok(()) + } + + /// Reserve additional slots into the chunks of the series. + pub(crate) fn reserve_chunks(&mut self, additional: usize) { + for s in &mut self.columns { + if let Column::Series(s) = s { + // SAFETY: + // do not modify the data, simply resize. + unsafe { s.chunks_mut().reserve(additional) } + } + } + } + + /// Create a DataFrame from a Vector of Series. + /// + /// Errors if a column names are not unique, or if heights are not all equal. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// let s0 = Column::new("days".into(), [0, 1, 2].as_ref()); + /// let s1 = Column::new("temp".into(), [22.1, 19.9, 7.].as_ref()); + /// + /// let df = DataFrame::new(vec![s0, s1])?; + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn new(columns: Vec) -> PolarsResult { + DataFrame::validate_columns_slice(&columns) + .map_err(|e| e.wrap_msg(|e| format!("could not create a new DataFrame: {}", e)))?; + Ok(unsafe { Self::new_no_checks_height_from_first(columns) }) + } + + pub fn new_with_height(height: usize, columns: Vec) -> PolarsResult { + for col in &columns { + polars_ensure!( + col.len() == height, + ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", + columns[0].name(), height, col.name(), col.len() + ); + } + + Ok(DataFrame { + height, + columns, + cached_schema: OnceLock::new(), + }) + } + + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to match the other columns. + pub fn new_with_broadcast(columns: Vec) -> PolarsResult { + // The length of the longest non-unit length column determines the + // broadcast length. If all columns are unit-length the broadcast length + // is one. + let broadcast_len = columns + .iter() + .map(|s| s.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + Self::new_with_broadcast_len(columns, broadcast_len) + } + + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to broadcast_len. + pub fn new_with_broadcast_len( + columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { + ensure_names_unique(&columns, |s| s.name().as_str())?; + unsafe { Self::new_with_broadcast_no_namecheck(columns, broadcast_len) } + } + + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to match the other columns. + /// + /// # Safety + /// Does not check that the column names are unique (which they must be). + pub unsafe fn new_with_broadcast_no_namecheck( + mut columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { + for col in &mut columns { + // Length not equal to the broadcast len, needs broadcast or is an error. + let len = col.len(); + if len != broadcast_len { + if len != 1 { + let name = col.name().to_owned(); + let extra_info = + if let Some(c) = columns.iter().find(|c| c.len() == broadcast_len) { + format!(" (matching column '{}')", c.name()) + } else { + String::new() + }; + polars_bail!( + ShapeMismatch: "could not create a new DataFrame: series {name:?} has length {len} while trying to broadcast to length {broadcast_len}{extra_info}", + ); + } + *col = col.new_from_index(0, broadcast_len); + } + } + + let length = if columns.is_empty() { 0 } else { broadcast_len }; + + Ok(unsafe { DataFrame::new_no_checks(length, columns) }) + } + + /// Creates an empty `DataFrame` usable in a compile time context (such as static initializers). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::DataFrame; + /// static EMPTY: DataFrame = DataFrame::empty(); + /// ``` + pub const fn empty() -> Self { + Self::empty_with_height(0) + } + + /// Creates an empty `DataFrame` with a specific `height`. + pub const fn empty_with_height(height: usize) -> Self { + DataFrame { + height, + columns: vec![], + cached_schema: OnceLock::new(), + } + } + + /// Create an empty `DataFrame` with empty columns as per the `schema`. + pub fn empty_with_schema(schema: &Schema) -> Self { + let cols = schema + .iter() + .map(|(name, dtype)| Column::from(Series::new_empty(name.clone(), dtype))) + .collect(); + unsafe { DataFrame::new_no_checks(0, cols) } + } + + /// Create an empty `DataFrame` with empty columns as per the `schema`. + pub fn empty_with_arrow_schema(schema: &ArrowSchema) -> Self { + let cols = schema + .iter_values() + .map(|fld| { + Column::from(Series::new_empty( + fld.name.clone(), + &(DataType::from_arrow_field(fld)), + )) + }) + .collect(); + unsafe { DataFrame::new_no_checks(0, cols) } + } + + /// Create a new `DataFrame` with the given schema, only containing nulls. + pub fn full_null(schema: &Schema, height: usize) -> Self { + let columns = schema + .iter_fields() + .map(|f| Column::full_null(f.name.clone(), height, f.dtype())) + .collect(); + unsafe { DataFrame::new_no_checks(height, columns) } + } + + /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s1 = Column::new("Ocean".into(), ["Atlantic", "Indian"]); + /// let s2 = Column::new("Area (km²)".into(), [106_460_000, 70_560_000]); + /// let mut df = DataFrame::new(vec![s1.clone(), s2.clone()])?; + /// + /// assert_eq!(df.pop(), Some(s2)); + /// assert_eq!(df.pop(), Some(s1)); + /// assert_eq!(df.pop(), None); + /// assert!(df.is_empty()); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn pop(&mut self) -> Option { + self.clear_schema(); + + self.columns.pop() + } + + /// Add a new column at index 0 that counts the rows. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Name" => ["James", "Mary", "John", "Patricia"])?; + /// assert_eq!(df1.shape(), (4, 1)); + /// + /// let df2: DataFrame = df1.with_row_index("Id".into(), None)?; + /// assert_eq!(df2.shape(), (4, 2)); + /// println!("{}", df2); + /// + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (4, 2) + /// +-----+----------+ + /// | Id | Name | + /// | --- | --- | + /// | u32 | str | + /// +=====+==========+ + /// | 0 | James | + /// +-----+----------+ + /// | 1 | Mary | + /// +-----+----------+ + /// | 2 | John | + /// +-----+----------+ + /// | 3 | Patricia | + /// +-----+----------+ + /// ``` + pub fn with_row_index(&self, name: PlSmallStr, offset: Option) -> PolarsResult { + let mut columns = Vec::with_capacity(self.columns.len() + 1); + let offset = offset.unwrap_or(0); + + let col = Column::new_row_index(name, offset, self.height())?; + columns.push(col); + columns.extend_from_slice(&self.columns); + DataFrame::new(columns) + } + + /// Add a row index column in place. + /// + /// # Safety + /// The caller should ensure the DataFrame does not already contain a column with the given name. + /// + /// # Panics + /// Panics if the resulting column would reach or overflow IdxSize::MAX. + pub unsafe fn with_row_index_mut( + &mut self, + name: PlSmallStr, + offset: Option, + ) -> &mut Self { + // TODO: Make this function unsafe + debug_assert!( + self.columns.iter().all(|c| c.name() != &name), + "with_row_index_mut(): column with name {} already exists", + &name + ); + + let offset = offset.unwrap_or(0); + let col = Column::new_row_index(name, offset, self.height()).unwrap(); + + self.clear_schema(); + self.columns.insert(0, col); + self + } + + /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the + /// `Series`. + /// + /// Calculates the height from the first column or `0` if no columns are given. + /// + /// # Safety + /// + /// It is the callers responsibility to uphold the contract of all `Series` + /// having an equal length and a unique name, if not this may panic down the line. + pub unsafe fn new_no_checks_height_from_first(columns: Vec) -> DataFrame { + let height = columns.first().map_or(0, Column::len); + unsafe { Self::new_no_checks(height, columns) } + } + + /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the + /// `Series`. + /// + /// It is advised to use [DataFrame::new] in favor of this method. + /// + /// # Safety + /// + /// It is the callers responsibility to uphold the contract of all `Series` + /// having an equal length and a unique name, if not this may panic down the line. + pub unsafe fn new_no_checks(height: usize, columns: Vec) -> DataFrame { + if cfg!(debug_assertions) { + DataFrame::validate_columns_slice(&columns).unwrap(); + } + + unsafe { Self::_new_no_checks_impl(height, columns) } + } + + /// This will not panic even in debug mode - there are some (rare) use cases where a DataFrame + /// is temporarily constructed containing duplicates for dispatching to functions. A DataFrame + /// constructed with this method is generally highly unsafe and should not be long-lived. + #[allow(clippy::missing_safety_doc)] + pub const unsafe fn _new_no_checks_impl(height: usize, columns: Vec) -> DataFrame { + DataFrame { + height, + columns, + cached_schema: OnceLock::new(), + } + } + + /// Shrink the capacity of this DataFrame to fit its length. + pub fn shrink_to_fit(&mut self) { + // Don't parallelize this. Memory overhead + for s in &mut self.columns { + s.shrink_to_fit(); + } + } + + /// Aggregate all the chunks in the DataFrame to a single chunk. + pub fn as_single_chunk(&mut self) -> &mut Self { + // Don't parallelize this. Memory overhead + for s in &mut self.columns { + *s = s.rechunk(); + } + self + } + + /// Aggregate all the chunks in the DataFrame to a single chunk in parallel. + /// This may lead to more peak memory consumption. + pub fn as_single_chunk_par(&mut self) -> &mut Self { + if self.columns.iter().any(|c| c.n_chunks() > 1) { + self.columns = self._apply_columns_par(&|s| s.rechunk()); + } + self + } + + /// Rechunks all columns to only have a single chunk. + pub fn rechunk_mut(&mut self) { + // SAFETY: We never adjust the length or names of the columns. + let columns = unsafe { self.get_columns_mut() }; + + for col in columns.iter_mut().filter(|c| c.n_chunks() > 1) { + *col = col.rechunk(); + } + } + + pub fn _deshare_views_mut(&mut self) { + // SAFETY: We never adjust the length or names of the columns. + unsafe { + let columns = self.get_columns_mut(); + for col in columns { + let Column::Series(s) = col else { continue }; + + if let Ok(ca) = s.binary() { + let gc_ca = ca.apply_kernel(&|a| a.deshare().into_boxed()); + *col = Column::from(gc_ca.into_series()); + } else if let Ok(ca) = s.str() { + let gc_ca = ca.apply_kernel(&|a| a.deshare().into_boxed()); + *col = Column::from(gc_ca.into_series()); + } + } + } + } + + /// Rechunks all columns to only have a single chunk and turns it into a [`RecordBatchT`]. + pub fn rechunk_to_record_batch( + self, + compat_level: CompatLevel, + ) -> RecordBatchT> { + let height = self.height(); + + let (schema, arrays) = self + .columns + .into_iter() + .map(|col| { + let mut series = col.take_materialized_series(); + // Rechunk to one chunk if necessary + if series.n_chunks() > 1 { + series = series.rechunk(); + } + ( + series.field().to_arrow(compat_level), + series.to_arrow(0, compat_level), + ) + }) + .collect(); + + RecordBatchT::new(height, Arc::new(schema), arrays) + } + + /// Returns true if the chunks of the columns do not align and re-chunking should be done + pub fn should_rechunk(&self) -> bool { + // Fast check. It is also needed for correctness, as code below doesn't check if the number + // of chunks is equal. + if !self + .get_columns() + .iter() + .filter_map(|c| c.as_series().map(|s| s.n_chunks())) + .all_equal() + { + return true; + } + + // From here we check chunk lengths. + let mut chunk_lengths = self.materialized_column_iter().map(|s| s.chunk_lengths()); + match chunk_lengths.next() { + None => false, + Some(first_column_chunk_lengths) => { + // Fast Path for single Chunk Series + if first_column_chunk_lengths.size_hint().0 == 1 { + return chunk_lengths.any(|cl| cl.size_hint().0 != 1); + } + // Always rechunk if we have more chunks than rows. + // except when we have an empty df containing a single chunk + let height = self.height(); + let n_chunks = first_column_chunk_lengths.size_hint().0; + if n_chunks > height && !(height == 0 && n_chunks == 1) { + return true; + } + // Slow Path for multi Chunk series + let v: Vec<_> = first_column_chunk_lengths.collect(); + for cl in chunk_lengths { + if cl.enumerate().any(|(idx, el)| Some(&el) != v.get(idx)) { + return true; + } + } + false + }, + } + } + + /// Ensure all the chunks in the [`DataFrame`] are aligned. + pub fn align_chunks_par(&mut self) -> &mut Self { + if self.should_rechunk() { + self.as_single_chunk_par() + } else { + self + } + } + + pub fn align_chunks(&mut self) -> &mut Self { + if self.should_rechunk() { + self.as_single_chunk() + } else { + self + } + } + + /// Get the [`DataFrame`] schema. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Thing" => ["Observable universe", "Human stupidity"], + /// "Diameter (m)" => [8.8e26, f64::INFINITY])?; + /// + /// let f1: Field = Field::new("Thing".into(), DataType::String); + /// let f2: Field = Field::new("Diameter (m)".into(), DataType::Float64); + /// let sc: Schema = Schema::from_iter(vec![f1, f2]); + /// + /// assert_eq!(&**df.schema(), &sc); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn schema(&self) -> &SchemaRef { + let out = self.cached_schema.get_or_init(|| { + Arc::new( + self.columns + .iter() + .map(|x| (x.name().clone(), x.dtype().clone())) + .collect(), + ) + }); + + debug_assert_eq!(out.len(), self.width()); + + out + } + + /// Get a reference to the [`DataFrame`] columns. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Name" => ["Adenine", "Cytosine", "Guanine", "Thymine"], + /// "Symbol" => ["A", "C", "G", "T"])?; + /// let columns: &[Column] = df.get_columns(); + /// + /// assert_eq!(columns[0].name(), "Name"); + /// assert_eq!(columns[1].name(), "Symbol"); + /// # Ok::<(), PolarsError>(()) + /// ``` + #[inline] + pub fn get_columns(&self) -> &[Column] { + &self.columns + } + + #[inline] + /// Get mutable access to the underlying columns. + /// + /// # Safety + /// + /// The caller must ensure the length of all [`Series`] remains equal to `height` or + /// [`DataFrame::set_height`] is called afterwards with the appropriate `height`. + /// The caller must ensure that the cached schema is cleared if it modifies the schema by + /// calling [`DataFrame::clear_schema`]. + pub unsafe fn get_columns_mut(&mut self) -> &mut Vec { + &mut self.columns + } + + #[inline] + /// Remove all the columns in the [`DataFrame`] but keep the `height`. + pub fn clear_columns(&mut self) { + unsafe { self.get_columns_mut() }.clear(); + self.clear_schema(); + } + + #[inline] + /// Extend the columns without checking for name collisions or height. + /// + /// # Safety + /// + /// The caller needs to ensure that: + /// - Column names are unique within the resulting [`DataFrame`]. + /// - The length of each appended column matches the height of the [`DataFrame`]. For + /// `DataFrame`]s with no columns (ZCDFs), it is important that the height is set afterwards + /// with [`DataFrame::set_height`]. + pub unsafe fn column_extend_unchecked(&mut self, iter: impl IntoIterator) { + unsafe { self.get_columns_mut() }.extend(iter); + self.clear_schema(); + } + + /// Take ownership of the underlying columns vec. + pub fn take_columns(self) -> Vec { + self.columns + } + + /// Iterator over the columns as [`Series`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s1 = Column::new("Name".into(), ["Pythagoras' theorem", "Shannon entropy"]); + /// let s2 = Column::new("Formula".into(), ["a²+b²=c²", "H=-Σ[P(x)log|P(x)|]"]); + /// let df: DataFrame = DataFrame::new(vec![s1.clone(), s2.clone()])?; + /// + /// let mut iterator = df.iter(); + /// + /// assert_eq!(iterator.next(), Some(s1.as_materialized_series())); + /// assert_eq!(iterator.next(), Some(s2.as_materialized_series())); + /// assert_eq!(iterator.next(), None); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn iter(&self) -> impl ExactSizeIterator { + self.materialized_column_iter() + } + + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Language" => ["Rust", "Python"], + /// "Designer" => ["Graydon Hoare", "Guido van Rossum"])?; + /// + /// assert_eq!(df.get_column_names(), &["Language", "Designer"]); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn get_column_names(&self) -> Vec<&PlSmallStr> { + self.columns.iter().map(|s| s.name()).collect() + } + + /// Get the [`Vec`] representing the column names. + pub fn get_column_names_owned(&self) -> Vec { + self.columns.iter().map(|s| s.name().clone()).collect() + } + + pub fn get_column_names_str(&self) -> Vec<&str> { + self.columns.iter().map(|s| s.name().as_str()).collect() + } + + /// Set the column names. + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let mut df: DataFrame = df!("Mathematical set" => ["ℕ", "ℤ", "𝔻", "ℚ", "ℝ", "ℂ"])?; + /// df.set_column_names(["Set"])?; + /// + /// assert_eq!(df.get_column_names(), &["Set"]); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn set_column_names(&mut self, names: I) -> PolarsResult<()> + where + I: IntoIterator, + S: Into, + { + let names = names.into_iter().map(Into::into).collect::>(); + self._set_column_names_impl(names.as_slice()) + } + + fn _set_column_names_impl(&mut self, names: &[PlSmallStr]) -> PolarsResult<()> { + polars_ensure!( + names.len() == self.width(), + ShapeMismatch: "{} column names provided for a DataFrame of width {}", + names.len(), self.width() + ); + ensure_names_unique(names, |s| s.as_str())?; + + let columns = mem::take(&mut self.columns); + self.columns = columns + .into_iter() + .zip(names) + .map(|(s, name)| { + let mut s = s; + s.rename(name.clone()); + s + }) + .collect(); + self.clear_schema(); + Ok(()) + } + + /// Get the data types of the columns in the [`DataFrame`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let venus_air: DataFrame = df!("Element" => ["Carbon dioxide", "Nitrogen"], + /// "Fraction" => [0.965, 0.035])?; + /// + /// assert_eq!(venus_air.dtypes(), &[DataType::String, DataType::Float64]); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn dtypes(&self) -> Vec { + self.columns.iter().map(|s| s.dtype().clone()).collect() + } + + pub(crate) fn first_series_column(&self) -> Option<&Series> { + self.columns.iter().find_map(|col| col.as_series()) + } + + /// The number of chunks for the first column. + pub fn first_col_n_chunks(&self) -> usize { + match self.first_series_column() { + None if self.columns.is_empty() => 0, + None => 1, + Some(s) => s.n_chunks(), + } + } + + /// The highest number of chunks for any column. + pub fn max_n_chunks(&self) -> usize { + self.columns + .iter() + .map(|s| s.as_series().map(|s| s.n_chunks()).unwrap_or(1)) + .max() + .unwrap_or(0) + } + + /// Get a reference to the schema fields of the [`DataFrame`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let earth: DataFrame = df!("Surface type" => ["Water", "Land"], + /// "Fraction" => [0.708, 0.292])?; + /// + /// let f1: Field = Field::new("Surface type".into(), DataType::String); + /// let f2: Field = Field::new("Fraction".into(), DataType::Float64); + /// + /// assert_eq!(earth.fields(), &[f1, f2]); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn fields(&self) -> Vec { + self.columns + .iter() + .map(|s| s.field().into_owned()) + .collect() + } + + /// Get (height, width) of the [`DataFrame`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df0: DataFrame = DataFrame::default(); + /// let df1: DataFrame = df!("1" => [1, 2, 3, 4, 5])?; + /// let df2: DataFrame = df!("1" => [1, 2, 3, 4, 5], + /// "2" => [1, 2, 3, 4, 5])?; + /// + /// assert_eq!(df0.shape(), (0 ,0)); + /// assert_eq!(df1.shape(), (5, 1)); + /// assert_eq!(df2.shape(), (5, 2)); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn shape(&self) -> (usize, usize) { + (self.height, self.columns.len()) + } + + /// Get the width of the [`DataFrame`] which is the number of columns. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df0: DataFrame = DataFrame::default(); + /// let df1: DataFrame = df!("Series 1" => [0; 0])?; + /// let df2: DataFrame = df!("Series 1" => [0; 0], + /// "Series 2" => [0; 0])?; + /// + /// assert_eq!(df0.width(), 0); + /// assert_eq!(df1.width(), 1); + /// assert_eq!(df2.width(), 2); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn width(&self) -> usize { + self.columns.len() + } + + /// Get the height of the [`DataFrame`] which is the number of rows. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df0: DataFrame = DataFrame::default(); + /// let df1: DataFrame = df!("Currency" => ["€", "$"])?; + /// let df2: DataFrame = df!("Currency" => ["€", "$", "¥", "£", "₿"])?; + /// + /// assert_eq!(df0.height(), 0); + /// assert_eq!(df1.height(), 2); + /// assert_eq!(df2.height(), 5); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn height(&self) -> usize { + self.height + } + + /// Returns the size as number of rows * number of columns + pub fn size(&self) -> usize { + let s = self.shape(); + s.0 * s.1 + } + + /// Returns `true` if the [`DataFrame`] contains no rows. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df1: DataFrame = DataFrame::default(); + /// assert!(df1.is_empty()); + /// + /// let df2: DataFrame = df!("First name" => ["Forever"], + /// "Last name" => ["Alone"])?; + /// assert!(!df2.is_empty()); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn is_empty(&self) -> bool { + matches!(self.shape(), (0, _) | (_, 0)) + } + + /// Set the height (i.e. number of rows) of this [`DataFrame`]. + /// + /// # Safety + /// + /// This needs to be equal to the length of all the columns. + pub unsafe fn set_height(&mut self, height: usize) { + self.height = height; + } + + /// Add multiple [`Series`] to a [`DataFrame`]. + /// The added `Series` are required to have the same length. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Element" => ["Copper", "Silver", "Gold"])?; + /// let s1 = Column::new("Proton".into(), [29, 47, 79]); + /// let s2 = Column::new("Electron".into(), [29, 47, 79]); + /// + /// let df2: DataFrame = df1.hstack(&[s1, s2])?; + /// assert_eq!(df2.shape(), (3, 3)); + /// println!("{}", df2); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (3, 3) + /// +---------+--------+----------+ + /// | Element | Proton | Electron | + /// | --- | --- | --- | + /// | str | i32 | i32 | + /// +=========+========+==========+ + /// | Copper | 29 | 29 | + /// +---------+--------+----------+ + /// | Silver | 47 | 47 | + /// +---------+--------+----------+ + /// | Gold | 79 | 79 | + /// +---------+--------+----------+ + /// ``` + pub fn hstack(&self, columns: &[Column]) -> PolarsResult { + let mut new_cols = self.columns.clone(); + new_cols.extend_from_slice(columns); + DataFrame::new(new_cols) + } + + /// Concatenate a [`DataFrame`] to this [`DataFrame`] and return as newly allocated [`DataFrame`]. + /// + /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks_par`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Element" => ["Copper", "Silver", "Gold"], + /// "Melting Point (K)" => [1357.77, 1234.93, 1337.33])?; + /// let df2: DataFrame = df!("Element" => ["Platinum", "Palladium"], + /// "Melting Point (K)" => [2041.4, 1828.05])?; + /// + /// let df3: DataFrame = df1.vstack(&df2)?; + /// + /// assert_eq!(df3.shape(), (5, 2)); + /// println!("{}", df3); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (5, 2) + /// +-----------+-------------------+ + /// | Element | Melting Point (K) | + /// | --- | --- | + /// | str | f64 | + /// +===========+===================+ + /// | Copper | 1357.77 | + /// +-----------+-------------------+ + /// | Silver | 1234.93 | + /// +-----------+-------------------+ + /// | Gold | 1337.33 | + /// +-----------+-------------------+ + /// | Platinum | 2041.4 | + /// +-----------+-------------------+ + /// | Palladium | 1828.05 | + /// +-----------+-------------------+ + /// ``` + pub fn vstack(&self, other: &DataFrame) -> PolarsResult { + let mut df = self.clone(); + df.vstack_mut(other)?; + Ok(df) + } + + /// Concatenate a [`DataFrame`] to this [`DataFrame`] + /// + /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks_par`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let mut df1: DataFrame = df!("Element" => ["Copper", "Silver", "Gold"], + /// "Melting Point (K)" => [1357.77, 1234.93, 1337.33])?; + /// let df2: DataFrame = df!("Element" => ["Platinum", "Palladium"], + /// "Melting Point (K)" => [2041.4, 1828.05])?; + /// + /// df1.vstack_mut(&df2)?; + /// + /// assert_eq!(df1.shape(), (5, 2)); + /// println!("{}", df1); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (5, 2) + /// +-----------+-------------------+ + /// | Element | Melting Point (K) | + /// | --- | --- | + /// | str | f64 | + /// +===========+===================+ + /// | Copper | 1357.77 | + /// +-----------+-------------------+ + /// | Silver | 1234.93 | + /// +-----------+-------------------+ + /// | Gold | 1337.33 | + /// +-----------+-------------------+ + /// | Platinum | 2041.4 | + /// +-----------+-------------------+ + /// | Palladium | 1828.05 | + /// +-----------+-------------------+ + /// ``` + pub fn vstack_mut(&mut self, other: &DataFrame) -> PolarsResult<&mut Self> { + if self.width() != other.width() { + polars_ensure!( + self.width() == 0, + ShapeMismatch: + "unable to append to a DataFrame of width {} with a DataFrame of width {}", + self.width(), other.width(), + ); + self.columns.clone_from(&other.columns); + self.height = other.height; + return Ok(self); + } + + self.columns + .iter_mut() + .zip(other.columns.iter()) + .try_for_each::<_, PolarsResult<_>>(|(left, right)| { + ensure_can_extend(&*left, right)?; + left.append(right).map_err(|e| { + e.context(format!("failed to vstack column '{}'", right.name()).into()) + })?; + Ok(()) + })?; + self.height += other.height; + Ok(self) + } + + pub fn vstack_mut_owned(&mut self, other: DataFrame) -> PolarsResult<&mut Self> { + if self.width() != other.width() { + polars_ensure!( + self.width() == 0, + ShapeMismatch: + "unable to append to a DataFrame of width {} with a DataFrame of width {}", + self.width(), other.width(), + ); + self.columns = other.columns; + self.height = other.height; + return Ok(self); + } + + self.columns + .iter_mut() + .zip(other.columns.into_iter()) + .try_for_each::<_, PolarsResult<_>>(|(left, right)| { + ensure_can_extend(&*left, &right)?; + let right_name = right.name().clone(); + left.append_owned(right).map_err(|e| { + e.context(format!("failed to vstack column '{right_name}'").into()) + })?; + Ok(()) + })?; + self.height += other.height; + Ok(self) + } + + /// Concatenate a [`DataFrame`] to this [`DataFrame`] + /// + /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks_par`]. + /// + /// # Panics + /// Panics if the schema's don't match. + pub fn vstack_mut_unchecked(&mut self, other: &DataFrame) { + self.columns + .iter_mut() + .zip(other.columns.iter()) + .for_each(|(left, right)| { + left.append(right) + .map_err(|e| { + e.context(format!("failed to vstack column '{}'", right.name()).into()) + }) + .expect("should not fail"); + }); + self.height += other.height; + } + + /// Concatenate a [`DataFrame`] to this [`DataFrame`] + /// + /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks_par`]. + /// + /// # Panics + /// Panics if the schema's don't match. + pub fn vstack_mut_owned_unchecked(&mut self, other: DataFrame) { + self.columns + .iter_mut() + .zip(other.columns) + .for_each(|(left, right)| { + left.append_owned(right).expect("should not fail"); + }); + self.height += other.height; + } + + /// Extend the memory backed by this [`DataFrame`] with the values from `other`. + /// + /// Different from [`vstack`](Self::vstack) which adds the chunks from `other` to the chunks of this [`DataFrame`] + /// `extend` appends the data from `other` to the underlying memory locations and thus may cause a reallocation. + /// + /// If this does not cause a reallocation, the resulting data structure will not have any extra chunks + /// and thus will yield faster queries. + /// + /// Prefer `extend` over `vstack` when you want to do a query after a single append. For instance during + /// online operations where you add `n` rows and rerun a query. + /// + /// Prefer `vstack` over `extend` when you want to append many times before doing a query. For instance + /// when you read in multiple files and when to store them in a single `DataFrame`. In the latter case, finish the sequence + /// of `append` operations with a [`rechunk`](Self::align_chunks_par). + pub fn extend(&mut self, other: &DataFrame) -> PolarsResult<()> { + polars_ensure!( + self.width() == other.width(), + ShapeMismatch: + "unable to extend a DataFrame of width {} with a DataFrame of width {}", + self.width(), other.width(), + ); + + self.columns + .iter_mut() + .zip(other.columns.iter()) + .try_for_each::<_, PolarsResult<_>>(|(left, right)| { + ensure_can_extend(&*left, right)?; + left.extend(right).map_err(|e| { + e.context(format!("failed to extend column '{}'", right.name()).into()) + })?; + Ok(()) + })?; + self.height += other.height; + self.clear_schema(); + Ok(()) + } + + /// Remove a column by name and return the column removed. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let mut df: DataFrame = df!("Animal" => ["Tiger", "Lion", "Great auk"], + /// "IUCN" => ["Endangered", "Vulnerable", "Extinct"])?; + /// + /// let s1: PolarsResult = df.drop_in_place("Average weight"); + /// assert!(s1.is_err()); + /// + /// let s2: Column = df.drop_in_place("Animal")?; + /// assert_eq!(s2, Column::new("Animal".into(), &["Tiger", "Lion", "Great auk"])); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn drop_in_place(&mut self, name: &str) -> PolarsResult { + let idx = self.check_name_to_idx(name)?; + self.clear_schema(); + Ok(self.columns.remove(idx)) + } + + /// Return a new [`DataFrame`] where all null values are dropped. + /// + /// # Example + /// + /// ```no_run + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Country" => ["Malta", "Liechtenstein", "North Korea"], + /// "Tax revenue (% GDP)" => [Some(32.7), None, None])?; + /// assert_eq!(df1.shape(), (3, 2)); + /// + /// let df2: DataFrame = df1.drop_nulls::(None)?; + /// assert_eq!(df2.shape(), (1, 2)); + /// println!("{}", df2); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (1, 2) + /// +---------+---------------------+ + /// | Country | Tax revenue (% GDP) | + /// | --- | --- | + /// | str | f64 | + /// +=========+=====================+ + /// | Malta | 32.7 | + /// +---------+---------------------+ + /// ``` + pub fn drop_nulls(&self, subset: Option<&[S]>) -> PolarsResult + where + for<'a> &'a S: Into, + { + if let Some(v) = subset { + let v = self.select_columns(v)?; + self._drop_nulls_impl(v.as_slice()) + } else { + self._drop_nulls_impl(self.columns.as_slice()) + } + } + + fn _drop_nulls_impl(&self, subset: &[Column]) -> PolarsResult { + // fast path for no nulls in df + if subset.iter().all(|s| !s.has_nulls()) { + return Ok(self.clone()); + } + + let mut iter = subset.iter(); + + let mask = iter + .next() + .ok_or_else(|| polars_err!(NoData: "no data to drop nulls from"))?; + let mut mask = mask.is_not_null(); + + for c in iter { + mask = mask & c.is_not_null(); + } + self.filter(&mask) + } + + /// Drop a column by name. + /// This is a pure method and will return a new [`DataFrame`] instead of modifying + /// the current one in place. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Ray type" => ["α", "β", "X", "γ"])?; + /// let df2: DataFrame = df1.drop("Ray type")?; + /// + /// assert!(df2.is_empty()); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn drop(&self, name: &str) -> PolarsResult { + let idx = self.check_name_to_idx(name)?; + let mut new_cols = Vec::with_capacity(self.columns.len() - 1); + + self.columns.iter().enumerate().for_each(|(i, s)| { + if i != idx { + new_cols.push(s.clone()) + } + }); + + Ok(unsafe { DataFrame::new_no_checks(self.height(), new_cols) }) + } + + /// Drop columns that are in `names`. + pub fn drop_many(&self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + let names: PlHashSet = names.into_iter().map(|s| s.into()).collect(); + self.drop_many_amortized(&names) + } + + /// Drop columns that are in `names` without allocating a [`HashSet`](std::collections::HashSet). + pub fn drop_many_amortized(&self, names: &PlHashSet) -> DataFrame { + if names.is_empty() { + return self.clone(); + } + let mut new_cols = Vec::with_capacity(self.columns.len().saturating_sub(names.len())); + self.columns.iter().for_each(|s| { + if !names.contains(s.name()) { + new_cols.push(s.clone()) + } + }); + + unsafe { DataFrame::new_no_checks(self.height(), new_cols) } + } + + /// Insert a new column at a given index without checking for duplicates. + /// This can leave the [`DataFrame`] at an invalid state + fn insert_column_no_name_check( + &mut self, + index: usize, + column: Column, + ) -> PolarsResult<&mut Self> { + polars_ensure!( + self.width() == 0 || column.len() == self.height(), + ShapeMismatch: "unable to add a column of length {} to a DataFrame of height {}", + column.len(), self.height(), + ); + + if self.width() == 0 { + self.height = column.len(); + } + + self.columns.insert(index, column); + self.clear_schema(); + Ok(self) + } + + /// Insert a new column at a given index. + pub fn insert_column( + &mut self, + index: usize, + column: S, + ) -> PolarsResult<&mut Self> { + let column = column.into_column(); + self.check_already_present(column.name().as_str())?; + self.insert_column_no_name_check(index, column) + } + + fn add_column_by_search(&mut self, column: Column) -> PolarsResult<()> { + if let Some(idx) = self.get_column_index(column.name().as_str()) { + self.replace_column(idx, column)?; + } else { + if self.width() == 0 { + self.height = column.len(); + } + + self.columns.push(column); + self.clear_schema(); + } + Ok(()) + } + + /// Add a new column to this [`DataFrame`] or replace an existing one. + pub fn with_column(&mut self, column: C) -> PolarsResult<&mut Self> { + fn inner(df: &mut DataFrame, mut column: Column) -> PolarsResult<&mut DataFrame> { + let height = df.height(); + if column.len() == 1 && height > 1 { + column = column.new_from_index(0, height); + } + + if column.len() == height || df.get_columns().is_empty() { + df.add_column_by_search(column)?; + Ok(df) + } + // special case for literals + else if height == 0 && column.len() == 1 { + let s = column.clear(); + df.add_column_by_search(s)?; + Ok(df) + } else { + polars_bail!( + ShapeMismatch: "unable to add a column of length {} to a DataFrame of height {}", + column.len(), height, + ); + } + } + let column = column.into_column(); + inner(self, column) + } + + /// Adds a column to the [`DataFrame`] without doing any checks + /// on length or duplicates. + /// + /// # Safety + /// The caller must ensure `self.width() == 0 || column.len() == self.height()` . + pub unsafe fn with_column_unchecked(&mut self, column: Column) -> &mut Self { + debug_assert!(self.width() == 0 || self.height() == column.len()); + debug_assert!(self.get_column_index(column.name().as_str()).is_none()); + + // SAFETY: Invariant of function guarantees for case `width` > 0. We set the height + // properly for `width` == 0. + if self.width() == 0 { + unsafe { self.set_height(column.len()) }; + } + unsafe { self.get_columns_mut() }.push(column); + self.clear_schema(); + + self + } + + // Note: Schema can be both input or output_schema + fn add_column_by_schema(&mut self, c: Column, schema: &Schema) -> PolarsResult<()> { + let name = c.name(); + if let Some((idx, _, _)) = schema.get_full(name.as_str()) { + if self.columns.get(idx).map(|s| s.name()) != Some(name) { + // Given schema is output_schema and we can push. + if idx == self.columns.len() { + if self.width() == 0 { + self.height = c.len(); + } + + self.columns.push(c); + self.clear_schema(); + } + // Schema is incorrect fallback to search + else { + debug_assert!(false); + self.add_column_by_search(c)?; + } + } else { + self.replace_column(idx, c)?; + } + } else { + if self.width() == 0 { + self.height = c.len(); + } + + self.columns.push(c); + self.clear_schema(); + } + + Ok(()) + } + + // Note: Schema can be both input or output_schema + pub fn _add_series(&mut self, series: Vec, schema: &Schema) -> PolarsResult<()> { + for (i, s) in series.into_iter().enumerate() { + // we need to branch here + // because users can add multiple columns with the same name + if i == 0 || schema.get(s.name().as_str()).is_some() { + self.with_column_and_schema(s.into_column(), schema)?; + } else { + self.with_column(s.clone().into_column())?; + } + } + Ok(()) + } + + pub fn _add_columns(&mut self, columns: Vec, schema: &Schema) -> PolarsResult<()> { + for (i, s) in columns.into_iter().enumerate() { + // we need to branch here + // because users can add multiple columns with the same name + if i == 0 || schema.get(s.name().as_str()).is_some() { + self.with_column_and_schema(s, schema)?; + } else { + self.with_column(s.clone())?; + } + } + + Ok(()) + } + + /// Add a new column to this [`DataFrame`] or replace an existing one. + /// Uses an existing schema to amortize lookups. + /// If the schema is incorrect, we will fallback to linear search. + /// + /// Note: Schema can be both input or output_schema + pub fn with_column_and_schema( + &mut self, + column: C, + schema: &Schema, + ) -> PolarsResult<&mut Self> { + let mut column = column.into_column(); + + let height = self.height(); + if column.len() == 1 && height > 1 { + column = column.new_from_index(0, height); + } + + if column.len() == height || self.columns.is_empty() { + self.add_column_by_schema(column, schema)?; + Ok(self) + } + // special case for literals + else if height == 0 && column.len() == 1 { + let s = column.clear(); + self.add_column_by_schema(s, schema)?; + Ok(self) + } else { + polars_bail!( + ShapeMismatch: "unable to add a column of length {} to a DataFrame of height {}", + column.len(), height, + ); + } + } + + /// Get a row in the [`DataFrame`]. Beware this is slow. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// fn example(df: &mut DataFrame, idx: usize) -> Option> { + /// df.get(idx) + /// } + /// ``` + pub fn get(&self, idx: usize) -> Option> { + match self.columns.first() { + Some(s) => { + if s.len() <= idx { + return None; + } + }, + None => return None, + } + // SAFETY: we just checked bounds + unsafe { Some(self.columns.iter().map(|c| c.get_unchecked(idx)).collect()) } + } + + /// Select a [`Series`] by index. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Star" => ["Sun", "Betelgeuse", "Sirius A", "Sirius B"], + /// "Absolute magnitude" => [4.83, -5.85, 1.42, 11.18])?; + /// + /// let s1: Option<&Column> = df.select_at_idx(0); + /// let s2 = Column::new("Star".into(), ["Sun", "Betelgeuse", "Sirius A", "Sirius B"]); + /// + /// assert_eq!(s1, Some(&s2)); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn select_at_idx(&self, idx: usize) -> Option<&Column> { + self.columns.get(idx) + } + + /// Select column(s) from this [`DataFrame`] by range and return a new [`DataFrame`] + /// + /// # Examples + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df = df! { + /// "0" => [0, 0, 0], + /// "1" => [1, 1, 1], + /// "2" => [2, 2, 2] + /// }?; + /// + /// assert!(df.select(["0", "1"])?.equals(&df.select_by_range(0..=1)?)); + /// assert!(df.equals(&df.select_by_range(..)?)); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn select_by_range(&self, range: R) -> PolarsResult + where + R: ops::RangeBounds, + { + // This function is copied from std::slice::range (https://doc.rust-lang.org/std/slice/fn.range.html) + // because it is the nightly feature. We should change here if this function were stable. + fn get_range(range: R, bounds: ops::RangeTo) -> ops::Range + where + R: ops::RangeBounds, + { + let len = bounds.end; + + let start: ops::Bound<&usize> = range.start_bound(); + let start = match start { + ops::Bound::Included(&start) => start, + ops::Bound::Excluded(start) => start.checked_add(1).unwrap_or_else(|| { + panic!("attempted to index slice from after maximum usize"); + }), + ops::Bound::Unbounded => 0, + }; + + let end: ops::Bound<&usize> = range.end_bound(); + let end = match end { + ops::Bound::Included(end) => end.checked_add(1).unwrap_or_else(|| { + panic!("attempted to index slice up to maximum usize"); + }), + ops::Bound::Excluded(&end) => end, + ops::Bound::Unbounded => len, + }; + + if start > end { + panic!("slice index starts at {start} but ends at {end}"); + } + if end > len { + panic!("range end index {end} out of range for slice of length {len}",); + } + + ops::Range { start, end } + } + + let colnames = self.get_column_names_owned(); + let range = get_range(range, ..colnames.len()); + + self._select_impl(&colnames[range]) + } + + /// Get column index of a [`Series`] by name. + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Name" => ["Player 1", "Player 2", "Player 3"], + /// "Health" => [100, 200, 500], + /// "Mana" => [250, 100, 0], + /// "Strength" => [30, 150, 300])?; + /// + /// assert_eq!(df.get_column_index("Name"), Some(0)); + /// assert_eq!(df.get_column_index("Health"), Some(1)); + /// assert_eq!(df.get_column_index("Mana"), Some(2)); + /// assert_eq!(df.get_column_index("Strength"), Some(3)); + /// assert_eq!(df.get_column_index("Haste"), None); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn get_column_index(&self, name: &str) -> Option { + let schema = self.schema(); + if let Some(idx) = schema.index_of(name) { + if self + .get_columns() + .get(idx) + .is_some_and(|c| c.name() == name) + { + return Some(idx); + } + } + + self.columns.iter().position(|s| s.name().as_str() == name) + } + + /// Get column index of a [`Series`] by name. + pub fn try_get_column_index(&self, name: &str) -> PolarsResult { + self.get_column_index(name) + .ok_or_else(|| polars_err!(col_not_found = name)) + } + + /// Select a single column by name. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s1 = Column::new("Password".into(), ["123456", "[]B$u$g$s$B#u#n#n#y[]{}"]); + /// let s2 = Column::new("Robustness".into(), ["Weak", "Strong"]); + /// let df: DataFrame = DataFrame::new(vec![s1.clone(), s2])?; + /// + /// assert_eq!(df.column("Password")?, &s1); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn column(&self, name: &str) -> PolarsResult<&Column> { + let idx = self.try_get_column_index(name)?; + Ok(self.select_at_idx(idx).unwrap()) + } + + /// Selected multiple columns by name. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Latin name" => ["Oncorhynchus kisutch", "Salmo salar"], + /// "Max weight (kg)" => [16.0, 35.89])?; + /// let sv: Vec<&Column> = df.columns(["Latin name", "Max weight (kg)"])?; + /// + /// assert_eq!(&df[0], sv[0]); + /// assert_eq!(&df[1], sv[1]); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn columns(&self, names: I) -> PolarsResult> + where + I: IntoIterator, + S: AsRef, + { + names + .into_iter() + .map(|name| self.column(name.as_ref())) + .collect() + } + + /// Select column(s) from this [`DataFrame`] and return a new [`DataFrame`]. + /// + /// # Examples + /// + /// ``` + /// # use polars_core::prelude::*; + /// fn example(df: &DataFrame) -> PolarsResult { + /// df.select(["foo", "bar"]) + /// } + /// ``` + pub fn select(&self, selection: I) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self._select_impl(cols.as_slice()) + } + + pub fn _select_impl(&self, cols: &[PlSmallStr]) -> PolarsResult { + ensure_names_unique(cols, |s| s.as_str())?; + self._select_impl_unchecked(cols) + } + + pub fn _select_impl_unchecked(&self, cols: &[PlSmallStr]) -> PolarsResult { + let selected = self.select_columns_impl(cols)?; + Ok(unsafe { DataFrame::new_no_checks(self.height(), selected) }) + } + + /// Select with a known schema. The schema names must match the column names of this DataFrame. + pub fn select_with_schema(&self, selection: I, schema: &SchemaRef) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self._select_with_schema_impl(&cols, schema, true) + } + + /// Select with a known schema without checking for duplicates in `selection`. + /// The schema names must match the column names of this DataFrame. + pub fn select_with_schema_unchecked( + &self, + selection: I, + schema: &Schema, + ) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self._select_with_schema_impl(&cols, schema, false) + } + + /// * The schema names must match the column names of this DataFrame. + pub fn _select_with_schema_impl( + &self, + cols: &[PlSmallStr], + schema: &Schema, + check_duplicates: bool, + ) -> PolarsResult { + if check_duplicates { + ensure_names_unique(cols, |s| s.as_str())?; + } + + let selected = self.select_columns_impl_with_schema(cols, schema)?; + Ok(unsafe { DataFrame::new_no_checks(self.height(), selected) }) + } + + /// A non generic implementation to reduce compiler bloat. + fn select_columns_impl_with_schema( + &self, + cols: &[PlSmallStr], + schema: &Schema, + ) -> PolarsResult> { + if cfg!(debug_assertions) { + ensure_matching_schema_names(schema, self.schema())?; + } + + cols.iter() + .map(|name| { + let index = schema.try_get_full(name.as_str())?.0; + Ok(self.columns[index].clone()) + }) + .collect() + } + + pub fn select_physical(&self, selection: I) -> PolarsResult + where + I: IntoIterator, + S: Into, + { + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self.select_physical_impl(&cols) + } + + fn select_physical_impl(&self, cols: &[PlSmallStr]) -> PolarsResult { + ensure_names_unique(cols, |s| s.as_str())?; + let selected = self.select_columns_physical_impl(cols)?; + Ok(unsafe { DataFrame::new_no_checks(self.height(), selected) }) + } + + /// Select column(s) from this [`DataFrame`] and return them into a [`Vec`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Name" => ["Methane", "Ethane", "Propane"], + /// "Carbon" => [1, 2, 3], + /// "Hydrogen" => [4, 6, 8])?; + /// let sv: Vec = df.select_columns(["Carbon", "Hydrogen"])?; + /// + /// assert_eq!(df["Carbon"], sv[0]); + /// assert_eq!(df["Hydrogen"], sv[1]); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn select_columns(&self, selection: impl IntoVec) -> PolarsResult> { + let cols = selection.into_vec(); + self.select_columns_impl(&cols) + } + + fn _names_to_idx_map(&self) -> PlHashMap<&str, usize> { + self.columns + .iter() + .enumerate() + .map(|(i, s)| (s.name().as_str(), i)) + .collect() + } + + /// A non generic implementation to reduce compiler bloat. + fn select_columns_physical_impl(&self, cols: &[PlSmallStr]) -> PolarsResult> { + let selected = if cols.len() > 1 && self.columns.len() > 10 { + let name_to_idx = self._names_to_idx_map(); + cols.iter() + .map(|name| { + let idx = *name_to_idx + .get(name.as_str()) + .ok_or_else(|| polars_err!(col_not_found = name))?; + Ok(self.select_at_idx(idx).unwrap().to_physical_repr()) + }) + .collect::>>()? + } else { + cols.iter() + .map(|c| self.column(c.as_str()).map(|s| s.to_physical_repr())) + .collect::>>()? + }; + + Ok(selected) + } + + /// A non generic implementation to reduce compiler bloat. + fn select_columns_impl(&self, cols: &[PlSmallStr]) -> PolarsResult> { + let selected = if cols.len() > 1 && self.columns.len() > 10 { + // we hash, because there are user that having millions of columns. + // # https://github.com/pola-rs/polars/issues/1023 + let name_to_idx = self._names_to_idx_map(); + + cols.iter() + .map(|name| { + let idx = *name_to_idx + .get(name.as_str()) + .ok_or_else(|| polars_err!(col_not_found = name))?; + Ok(self.select_at_idx(idx).unwrap().clone()) + }) + .collect::>>()? + } else { + cols.iter() + .map(|c| self.column(c.as_str()).cloned()) + .collect::>>()? + }; + + Ok(selected) + } + + fn filter_height(&self, filtered: &[Column], mask: &BooleanChunked) -> usize { + // If there is a filtered column just see how many columns there are left. + if let Some(fst) = filtered.first() { + return fst.len(); + } + + // Otherwise, count the number of values that would be filtered and return that height. + let num_trues = mask.num_trues(); + if mask.len() == self.height() { + num_trues + } else { + // This is for broadcasting masks + debug_assert!(num_trues == 0 || num_trues == 1); + self.height() * num_trues + } + } + + /// Take the [`DataFrame`] rows by a boolean mask. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// fn example(df: &DataFrame) -> PolarsResult { + /// let mask = df.column("sepal_width")?.is_not_null(); + /// df.filter(&mask) + /// } + /// ``` + pub fn filter(&self, mask: &BooleanChunked) -> PolarsResult { + let new_col = self.try_apply_columns_par(&|s| s.filter(mask))?; + let height = self.filter_height(&new_col, mask); + + Ok(unsafe { DataFrame::new_no_checks(height, new_col) }) + } + + /// Same as `filter` but does not parallelize. + pub fn _filter_seq(&self, mask: &BooleanChunked) -> PolarsResult { + let new_col = self.try_apply_columns(&|s| s.filter(mask))?; + let height = self.filter_height(&new_col, mask); + + Ok(unsafe { DataFrame::new_no_checks(height, new_col) }) + } + + /// Take [`DataFrame`] rows by index values. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// fn example(df: &DataFrame) -> PolarsResult { + /// let idx = IdxCa::new("idx".into(), [0, 1, 9]); + /// df.take(&idx) + /// } + /// ``` + pub fn take(&self, indices: &IdxCa) -> PolarsResult { + let new_col = POOL.install(|| self.try_apply_columns_par(&|s| s.take(indices)))?; + + Ok(unsafe { DataFrame::new_no_checks(indices.len(), new_col) }) + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn take_unchecked(&self, idx: &IdxCa) -> Self { + self.take_unchecked_impl(idx, true) + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn take_unchecked_impl(&self, idx: &IdxCa, allow_threads: bool) -> Self { + let cols = if allow_threads { + POOL.install(|| self._apply_columns_par(&|c| c.take_unchecked(idx))) + } else { + self._apply_columns(&|s| s.take_unchecked(idx)) + }; + unsafe { DataFrame::new_no_checks(idx.len(), cols) } + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { + self.take_slice_unchecked_impl(idx, true) + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { + let cols = if allow_threads { + POOL.install(|| self._apply_columns_par(&|s| s.take_slice_unchecked(idx))) + } else { + self._apply_columns(&|s| s.take_slice_unchecked(idx)) + }; + unsafe { DataFrame::new_no_checks(idx.len(), cols) } + } + + /// Rename a column in the [`DataFrame`]. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// fn example(df: &mut DataFrame) -> PolarsResult<&mut DataFrame> { + /// let original_name = "foo"; + /// let new_name = "bar"; + /// df.rename(original_name, new_name.into()) + /// } + /// ``` + pub fn rename(&mut self, column: &str, name: PlSmallStr) -> PolarsResult<&mut Self> { + if column == name.as_str() { + return Ok(self); + } + polars_ensure!( + !self.schema().contains(&name), + Duplicate: "column rename attempted with already existing name \"{name}\"" + ); + + self.get_column_index(column) + .and_then(|idx| self.columns.get_mut(idx)) + .ok_or_else(|| polars_err!(col_not_found = column)) + .map(|c| c.rename(name))?; + Ok(self) + } + + /// Sort [`DataFrame`] in place. + /// + /// See [`DataFrame::sort`] for more instruction. + pub fn sort_in_place( + &mut self, + by: impl IntoVec, + sort_options: SortMultipleOptions, + ) -> PolarsResult<&mut Self> { + let by_column = self.select_columns(by)?; + self.columns = self.sort_impl(by_column, sort_options, None)?.columns; + Ok(self) + } + + #[doc(hidden)] + /// This is the dispatch of Self::sort, and exists to reduce compile bloat by monomorphization. + pub fn sort_impl( + &self, + by_column: Vec, + mut sort_options: SortMultipleOptions, + slice: Option<(i64, usize)>, + ) -> PolarsResult { + if by_column.is_empty() { + // If no columns selected, any order (including original order) is correct. + return if let Some((offset, len)) = slice { + Ok(self.slice(offset, len)) + } else { + Ok(self.clone()) + }; + } + + // note that the by_column argument also contains evaluated expression from + // polars-lazy that may not even be present in this dataframe. therefore + // when we try to set the first columns as sorted, we ignore the error as + // expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i. + let first_descending = sort_options.descending[0]; + let first_by_column = by_column[0].name().to_string(); + + let set_sorted = |df: &mut DataFrame| { + // Mark the first sort column as sorted; if the column does not exist it + // is ok, because we sorted by an expression not present in the dataframe + let _ = df.apply(&first_by_column, |s| { + let mut s = s.clone(); + if first_descending { + s.set_sorted_flag(IsSorted::Descending) + } else { + s.set_sorted_flag(IsSorted::Ascending) + } + s + }); + }; + if self.is_empty() { + let mut out = self.clone(); + set_sorted(&mut out); + return Ok(out); + } + + if let Some((0, k)) = slice { + if k < self.len() { + return self.bottom_k_impl(k, by_column, sort_options); + } + } + // Check if the required column is already sorted; if so we can exit early + // We can do so when there is only one column to sort by, for multiple columns + // it will be complicated to do so + #[cfg(feature = "dtype-categorical")] + let is_not_categorical_enum = + !(matches!(by_column[0].dtype(), DataType::Categorical(_, _)) + || matches!(by_column[0].dtype(), DataType::Enum(_, _))); + + #[cfg(not(feature = "dtype-categorical"))] + #[allow(non_upper_case_globals)] + const is_not_categorical_enum: bool = true; + + if by_column.len() == 1 && is_not_categorical_enum { + let required_sorting = if sort_options.descending[0] { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + // If null count is 0 then nulls_last doesnt matter + // Safe to get value at last position since the dataframe is not empty (taken care above) + let no_sorting_required = (by_column[0].is_sorted_flag() == required_sorting) + && ((by_column[0].null_count() == 0) + || by_column[0].get(by_column[0].len() - 1).unwrap().is_null() + == sort_options.nulls_last[0]); + + if no_sorting_required { + return if let Some((offset, len)) = slice { + Ok(self.slice(offset, len)) + } else { + Ok(self.clone()) + }; + } + } + + let has_nested = by_column.iter().any(|s| s.dtype().is_nested()); + + // a lot of indirection in both sorting and take + let mut df = self.clone(); + let df = df.as_single_chunk_par(); + let mut take = match (by_column.len(), has_nested) { + (1, false) => { + let s = &by_column[0]; + let options = SortOptions { + descending: sort_options.descending[0], + nulls_last: sort_options.nulls_last[0], + multithreaded: sort_options.multithreaded, + maintain_order: sort_options.maintain_order, + limit: sort_options.limit, + }; + // fast path for a frame with a single series + // no need to compute the sort indices and then take by these indices + // simply sort and return as frame + if df.width() == 1 && df.check_name_to_idx(s.name().as_str()).is_ok() { + let mut out = s.sort_with(options)?; + if let Some((offset, len)) = slice { + out = out.slice(offset, len); + } + return Ok(out.into_frame()); + } + s.arg_sort(options) + }, + _ => { + if sort_options.nulls_last.iter().all(|&x| x) + || has_nested + || std::env::var("POLARS_ROW_FMT_SORT").is_ok() + { + argsort_multiple_row_fmt( + &by_column, + sort_options.descending, + sort_options.nulls_last, + sort_options.multithreaded, + )? + } else { + let (first, other) = prepare_arg_sort(by_column, &mut sort_options)?; + first + .as_materialized_series() + .arg_sort_multiple(&other, &sort_options)? + } + }, + }; + + if let Some((offset, len)) = slice { + take = take.slice(offset, len); + } + + // SAFETY: + // the created indices are in bounds + let mut df = unsafe { df.take_unchecked_impl(&take, sort_options.multithreaded) }; + set_sorted(&mut df); + Ok(df) + } + + /// Create a `DataFrame` that has fields for all the known runtime metadata for each column. + /// + /// This dataframe does not necessarily have a specified schema and may be changed at any + /// point. It is primarily used for debugging. + pub fn _to_metadata(&self) -> DataFrame { + let num_columns = self.columns.len(); + + let mut column_names = + StringChunkedBuilder::new(PlSmallStr::from_static("column_name"), num_columns); + let mut repr_ca = StringChunkedBuilder::new(PlSmallStr::from_static("repr"), num_columns); + let mut sorted_asc_ca = + BooleanChunkedBuilder::new(PlSmallStr::from_static("sorted_asc"), num_columns); + let mut sorted_dsc_ca = + BooleanChunkedBuilder::new(PlSmallStr::from_static("sorted_dsc"), num_columns); + let mut fast_explode_list_ca = + BooleanChunkedBuilder::new(PlSmallStr::from_static("fast_explode_list"), num_columns); + let mut materialized_at_ca = + StringChunkedBuilder::new(PlSmallStr::from_static("materialized_at"), num_columns); + + for col in &self.columns { + let flags = col.get_flags(); + + let (repr, materialized_at) = match col { + Column::Series(s) => ("series", s.materialized_at()), + Column::Partitioned(_) => ("partitioned", None), + Column::Scalar(_) => ("scalar", None), + }; + let sorted_asc = flags.contains(StatisticsFlags::IS_SORTED_ASC); + let sorted_dsc = flags.contains(StatisticsFlags::IS_SORTED_DSC); + let fast_explode_list = flags.contains(StatisticsFlags::CAN_FAST_EXPLODE_LIST); + + column_names.append_value(col.name().clone()); + repr_ca.append_value(repr); + sorted_asc_ca.append_value(sorted_asc); + sorted_dsc_ca.append_value(sorted_dsc); + fast_explode_list_ca.append_value(fast_explode_list); + materialized_at_ca.append_option(materialized_at.map(|v| format!("{v:#?}"))); + } + + unsafe { + DataFrame::new_no_checks( + self.width(), + vec![ + column_names.finish().into_column(), + repr_ca.finish().into_column(), + sorted_asc_ca.finish().into_column(), + sorted_dsc_ca.finish().into_column(), + fast_explode_list_ca.finish().into_column(), + materialized_at_ca.finish().into_column(), + ], + ) + } + } + + /// Return a sorted clone of this [`DataFrame`]. + /// + /// In many cases the output chunks will be continuous in memory but this is not guaranteed + /// # Example + /// + /// Sort by a single column with default options: + /// ``` + /// # use polars_core::prelude::*; + /// fn sort_by_sepal_width(df: &DataFrame) -> PolarsResult { + /// df.sort(["sepal_width"], Default::default()) + /// } + /// ``` + /// Sort by a single column with specific order: + /// ``` + /// # use polars_core::prelude::*; + /// fn sort_with_specific_order(df: &DataFrame, descending: bool) -> PolarsResult { + /// df.sort( + /// ["sepal_width"], + /// SortMultipleOptions::new() + /// .with_order_descending(descending) + /// ) + /// } + /// ``` + /// Sort by multiple columns with specifying order for each column: + /// ``` + /// # use polars_core::prelude::*; + /// fn sort_by_multiple_columns_with_specific_order(df: &DataFrame) -> PolarsResult { + /// df.sort( + /// ["sepal_width", "sepal_length"], + /// SortMultipleOptions::new() + /// .with_order_descending_multi([false, true]) + /// ) + /// } + /// ``` + /// See [`SortMultipleOptions`] for more options. + /// + /// Also see [`DataFrame::sort_in_place`]. + pub fn sort( + &self, + by: impl IntoVec, + sort_options: SortMultipleOptions, + ) -> PolarsResult { + let mut df = self.clone(); + df.sort_in_place(by, sort_options)?; + Ok(df) + } + + /// Replace a column with a [`Series`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let mut df: DataFrame = df!("Country" => ["United States", "China"], + /// "Area (km²)" => [9_833_520, 9_596_961])?; + /// let s: Series = Series::new("Country".into(), ["USA", "PRC"]); + /// + /// assert!(df.replace("Nation", s.clone()).is_err()); + /// assert!(df.replace("Country", s).is_ok()); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn replace(&mut self, column: &str, new_col: S) -> PolarsResult<&mut Self> { + self.apply(column, |_| new_col.into_series()) + } + + /// Replace or update a column. The difference between this method and [DataFrame::with_column] + /// is that now the value of `column: &str` determines the name of the column and not the name + /// of the `Series` passed to this method. + pub fn replace_or_add( + &mut self, + column: PlSmallStr, + new_col: S, + ) -> PolarsResult<&mut Self> { + let mut new_col = new_col.into_series(); + new_col.rename(column); + self.with_column(new_col) + } + + /// Replace column at index `idx` with a [`Series`]. + /// + /// # Example + /// + /// ```ignored + /// # use polars_core::prelude::*; + /// let s0 = Series::new("foo".into(), ["ham", "spam", "egg"]); + /// let s1 = Series::new("ascii".into(), [70, 79, 79]); + /// let mut df = DataFrame::new(vec![s0, s1])?; + /// + /// // Add 32 to get lowercase ascii values + /// df.replace_column(1, df.select_at_idx(1).unwrap() + 32); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn replace_column( + &mut self, + index: usize, + new_column: C, + ) -> PolarsResult<&mut Self> { + polars_ensure!( + index < self.width(), + ShapeMismatch: + "unable to replace at index {}, the DataFrame has only {} columns", + index, self.width(), + ); + let mut new_column = new_column.into_column(); + polars_ensure!( + new_column.len() == self.height(), + ShapeMismatch: + "unable to replace a column, series length {} doesn't match the DataFrame height {}", + new_column.len(), self.height(), + ); + let old_col = &mut self.columns[index]; + mem::swap(old_col, &mut new_column); + self.clear_schema(); + Ok(self) + } + + /// Apply a closure to a column. This is the recommended way to do in place modification. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s0 = Column::new("foo".into(), ["ham", "spam", "egg"]); + /// let s1 = Column::new("names".into(), ["Jean", "Claude", "van"]); + /// let mut df = DataFrame::new(vec![s0, s1])?; + /// + /// fn str_to_len(str_val: &Column) -> Column { + /// str_val.str() + /// .unwrap() + /// .into_iter() + /// .map(|opt_name: Option<&str>| { + /// opt_name.map(|name: &str| name.len() as u32) + /// }) + /// .collect::() + /// .into_column() + /// } + /// + /// // Replace the names column by the length of the names. + /// df.apply("names", str_to_len); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Results in: + /// + /// ```text + /// +--------+-------+ + /// | foo | | + /// | --- | names | + /// | str | u32 | + /// +========+=======+ + /// | "ham" | 4 | + /// +--------+-------+ + /// | "spam" | 6 | + /// +--------+-------+ + /// | "egg" | 3 | + /// +--------+-------+ + /// ``` + pub fn apply(&mut self, name: &str, f: F) -> PolarsResult<&mut Self> + where + F: FnOnce(&Column) -> C, + C: IntoColumn, + { + let idx = self.check_name_to_idx(name)?; + self.apply_at_idx(idx, f) + } + + /// Apply a closure to a column at index `idx`. This is the recommended way to do in place + /// modification. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s0 = Column::new("foo".into(), ["ham", "spam", "egg"]); + /// let s1 = Column::new("ascii".into(), [70, 79, 79]); + /// let mut df = DataFrame::new(vec![s0, s1])?; + /// + /// // Add 32 to get lowercase ascii values + /// df.apply_at_idx(1, |s| s + 32); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Results in: + /// + /// ```text + /// +--------+-------+ + /// | foo | ascii | + /// | --- | --- | + /// | str | i32 | + /// +========+=======+ + /// | "ham" | 102 | + /// +--------+-------+ + /// | "spam" | 111 | + /// +--------+-------+ + /// | "egg" | 111 | + /// +--------+-------+ + /// ``` + pub fn apply_at_idx(&mut self, idx: usize, f: F) -> PolarsResult<&mut Self> + where + F: FnOnce(&Column) -> C, + C: IntoColumn, + { + let df_height = self.height(); + let width = self.width(); + let col = self.columns.get_mut(idx).ok_or_else(|| { + polars_err!( + ComputeError: "invalid column index: {} for a DataFrame with {} columns", + idx, width + ) + })?; + let name = col.name().clone(); + let new_col = f(col).into_column(); + match new_col.len() { + 1 => { + let new_col = new_col.new_from_index(0, df_height); + let _ = mem::replace(col, new_col); + }, + len if (len == df_height) => { + let _ = mem::replace(col, new_col); + }, + len => polars_bail!( + ShapeMismatch: + "resulting Series has length {} while the DataFrame has height {}", + len, df_height + ), + } + + // make sure the name remains the same after applying the closure + unsafe { + let col = self.columns.get_unchecked_mut(idx); + col.rename(name); + } + Ok(self) + } + + /// Apply a closure that may fail to a column at index `idx`. This is the recommended way to do in place + /// modification. + /// + /// # Example + /// + /// This is the idiomatic way to replace some values a column of a `DataFrame` given range of indexes. + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s0 = Column::new("foo".into(), ["ham", "spam", "egg", "bacon", "quack"]); + /// let s1 = Column::new("values".into(), [1, 2, 3, 4, 5]); + /// let mut df = DataFrame::new(vec![s0, s1])?; + /// + /// let idx = vec![0, 1, 4]; + /// + /// df.try_apply("foo", |c| { + /// c.str()? + /// .scatter_with(idx, |opt_val| opt_val.map(|string| format!("{}-is-modified", string))) + /// }); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Results in: + /// + /// ```text + /// +---------------------+--------+ + /// | foo | values | + /// | --- | --- | + /// | str | i32 | + /// +=====================+========+ + /// | "ham-is-modified" | 1 | + /// +---------------------+--------+ + /// | "spam-is-modified" | 2 | + /// +---------------------+--------+ + /// | "egg" | 3 | + /// +---------------------+--------+ + /// | "bacon" | 4 | + /// +---------------------+--------+ + /// | "quack-is-modified" | 5 | + /// +---------------------+--------+ + /// ``` + pub fn try_apply_at_idx(&mut self, idx: usize, f: F) -> PolarsResult<&mut Self> + where + F: FnOnce(&Column) -> PolarsResult, + C: IntoColumn, + { + let width = self.width(); + let col = self.columns.get_mut(idx).ok_or_else(|| { + polars_err!( + ComputeError: "invalid column index: {} for a DataFrame with {} columns", + idx, width + ) + })?; + let name = col.name().clone(); + + let _ = mem::replace(col, f(col).map(|c| c.into_column())?); + + // make sure the name remains the same after applying the closure + unsafe { + let col = self.columns.get_unchecked_mut(idx); + col.rename(name); + } + Ok(self) + } + + /// Apply a closure that may fail to a column. This is the recommended way to do in place + /// modification. + /// + /// # Example + /// + /// This is the idiomatic way to replace some values a column of a `DataFrame` given a boolean mask. + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let s0 = Column::new("foo".into(), ["ham", "spam", "egg", "bacon", "quack"]); + /// let s1 = Column::new("values".into(), [1, 2, 3, 4, 5]); + /// let mut df = DataFrame::new(vec![s0, s1])?; + /// + /// // create a mask + /// let values = df.column("values")?.as_materialized_series(); + /// let mask = values.lt_eq(1)? | values.gt_eq(5_i32)?; + /// + /// df.try_apply("foo", |c| { + /// c.str()? + /// .set(&mask, Some("not_within_bounds")) + /// }); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Results in: + /// + /// ```text + /// +---------------------+--------+ + /// | foo | values | + /// | --- | --- | + /// | str | i32 | + /// +=====================+========+ + /// | "not_within_bounds" | 1 | + /// +---------------------+--------+ + /// | "spam" | 2 | + /// +---------------------+--------+ + /// | "egg" | 3 | + /// +---------------------+--------+ + /// | "bacon" | 4 | + /// +---------------------+--------+ + /// | "not_within_bounds" | 5 | + /// +---------------------+--------+ + /// ``` + pub fn try_apply(&mut self, column: &str, f: F) -> PolarsResult<&mut Self> + where + F: FnOnce(&Series) -> PolarsResult, + C: IntoColumn, + { + let idx = self.try_get_column_index(column)?; + self.try_apply_at_idx(idx, |c| f(c.as_materialized_series())) + } + + /// Slice the [`DataFrame`] along the rows. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Fruit" => ["Apple", "Grape", "Grape", "Fig", "Fig"], + /// "Color" => ["Green", "Red", "White", "White", "Red"])?; + /// let sl: DataFrame = df.slice(2, 3); + /// + /// assert_eq!(sl.shape(), (3, 2)); + /// println!("{}", sl); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Output: + /// ```text + /// shape: (3, 2) + /// +-------+-------+ + /// | Fruit | Color | + /// | --- | --- | + /// | str | str | + /// +=======+=======+ + /// | Grape | White | + /// +-------+-------+ + /// | Fig | White | + /// +-------+-------+ + /// | Fig | Red | + /// +-------+-------+ + /// ``` + #[must_use] + pub fn slice(&self, offset: i64, length: usize) -> Self { + if offset == 0 && length == self.height() { + return self.clone(); + } + if length == 0 { + return self.clear(); + } + let col = self + .columns + .iter() + .map(|s| s.slice(offset, length)) + .collect::>(); + + let height = if let Some(fst) = col.first() { + fst.len() + } else { + let (_, length) = slice_offsets(offset, length, self.height()); + length + }; + + unsafe { DataFrame::new_no_checks(height, col) } + } + + /// Split [`DataFrame`] at the given `offset`. + pub fn split_at(&self, offset: i64) -> (Self, Self) { + let (a, b) = self.columns.iter().map(|s| s.split_at(offset)).unzip(); + + let (idx, _) = slice_offsets(offset, 0, self.height()); + + let a = unsafe { DataFrame::new_no_checks(idx, a) }; + let b = unsafe { DataFrame::new_no_checks(self.height() - idx, b) }; + (a, b) + } + + pub fn clear(&self) -> Self { + let col = self.columns.iter().map(|s| s.clear()).collect::>(); + unsafe { DataFrame::new_no_checks(0, col) } + } + + #[must_use] + pub fn slice_par(&self, offset: i64, length: usize) -> Self { + if offset == 0 && length == self.height() { + return self.clone(); + } + let columns = self._apply_columns_par(&|s| s.slice(offset, length)); + unsafe { DataFrame::new_no_checks(length, columns) } + } + + #[must_use] + pub fn _slice_and_realloc(&self, offset: i64, length: usize) -> Self { + if offset == 0 && length == self.height() { + return self.clone(); + } + // @scalar-opt + let columns = self._apply_columns(&|s| { + let mut out = s.slice(offset, length); + out.shrink_to_fit(); + out + }); + unsafe { DataFrame::new_no_checks(length, columns) } + } + + /// Get the head of the [`DataFrame`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let countries: DataFrame = + /// df!("Rank by GDP (2021)" => [1, 2, 3, 4, 5], + /// "Continent" => ["North America", "Asia", "Asia", "Europe", "Europe"], + /// "Country" => ["United States", "China", "Japan", "Germany", "United Kingdom"], + /// "Capital" => ["Washington", "Beijing", "Tokyo", "Berlin", "London"])?; + /// assert_eq!(countries.shape(), (5, 4)); + /// + /// println!("{}", countries.head(Some(3))); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (3, 4) + /// +--------------------+---------------+---------------+------------+ + /// | Rank by GDP (2021) | Continent | Country | Capital | + /// | --- | --- | --- | --- | + /// | i32 | str | str | str | + /// +====================+===============+===============+============+ + /// | 1 | North America | United States | Washington | + /// +--------------------+---------------+---------------+------------+ + /// | 2 | Asia | China | Beijing | + /// +--------------------+---------------+---------------+------------+ + /// | 3 | Asia | Japan | Tokyo | + /// +--------------------+---------------+---------------+------------+ + /// ``` + #[must_use] + pub fn head(&self, length: Option) -> Self { + let col = self + .columns + .iter() + .map(|c| c.head(length)) + .collect::>(); + + let height = length.unwrap_or(HEAD_DEFAULT_LENGTH); + let height = usize::min(height, self.height()); + unsafe { DataFrame::new_no_checks(height, col) } + } + + /// Get the tail of the [`DataFrame`]. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let countries: DataFrame = + /// df!("Rank (2021)" => [105, 106, 107, 108, 109], + /// "Apple Price (€/kg)" => [0.75, 0.70, 0.70, 0.65, 0.52], + /// "Country" => ["Kosovo", "Moldova", "North Macedonia", "Syria", "Turkey"])?; + /// assert_eq!(countries.shape(), (5, 3)); + /// + /// println!("{}", countries.tail(Some(2))); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (2, 3) + /// +-------------+--------------------+---------+ + /// | Rank (2021) | Apple Price (€/kg) | Country | + /// | --- | --- | --- | + /// | i32 | f64 | str | + /// +=============+====================+=========+ + /// | 108 | 0.63 | Syria | + /// +-------------+--------------------+---------+ + /// | 109 | 0.63 | Turkey | + /// +-------------+--------------------+---------+ + /// ``` + #[must_use] + pub fn tail(&self, length: Option) -> Self { + let col = self + .columns + .iter() + .map(|c| c.tail(length)) + .collect::>(); + + let height = length.unwrap_or(TAIL_DEFAULT_LENGTH); + let height = usize::min(height, self.height()); + unsafe { DataFrame::new_no_checks(height, col) } + } + + /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches. + /// + /// # Panics + /// + /// Panics if the [`DataFrame`] that is passed is not rechunked. + /// + /// This responsibility is left to the caller as we don't want to take mutable references here, + /// but we also don't want to rechunk here, as this operation is costly and would benefit the caller + /// as well. + pub fn iter_chunks(&self, compat_level: CompatLevel, parallel: bool) -> RecordBatchIter { + debug_assert!(!self.should_rechunk(), "expected equal chunks"); + // If any of the columns is binview and we don't convert `compat_level` we allow parallelism + // as we must allocate arrow strings/binaries. + let must_convert = compat_level.0 == 0; + let parallel = parallel + && must_convert + && self.columns.len() > 1 + && self + .columns + .iter() + .any(|s| matches!(s.dtype(), DataType::String | DataType::Binary)); + + RecordBatchIter { + columns: &self.columns, + schema: Arc::new( + self.columns + .iter() + .map(|c| c.field().to_arrow(compat_level)) + .collect(), + ), + idx: 0, + n_chunks: self.first_col_n_chunks(), + compat_level, + parallel, + } + } + + /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches as physical values. + /// + /// # Panics + /// + /// Panics if the [`DataFrame`] that is passed is not rechunked. + /// + /// This responsibility is left to the caller as we don't want to take mutable references here, + /// but we also don't want to rechunk here, as this operation is costly and would benefit the caller + /// as well. + pub fn iter_chunks_physical(&self) -> PhysRecordBatchIter<'_> { + PhysRecordBatchIter { + schema: Arc::new( + self.get_columns() + .iter() + .map(|c| c.field().to_arrow(CompatLevel::newest())) + .collect(), + ), + arr_iters: self + .materialized_column_iter() + .map(|s| s.chunks().iter()) + .collect(), + } + } + + /// Get a [`DataFrame`] with all the columns in reversed order. + #[must_use] + pub fn reverse(&self) -> Self { + let col = self.columns.iter().map(|s| s.reverse()).collect::>(); + unsafe { DataFrame::new_no_checks(self.height(), col) } + } + + /// Shift the values by a given period and fill the parts that will be empty due to this operation + /// with `Nones`. + /// + /// See the method on [Series](crate::series::SeriesTrait::shift) for more info on the `shift` operation. + #[must_use] + pub fn shift(&self, periods: i64) -> Self { + let col = self._apply_columns_par(&|s| s.shift(periods)); + unsafe { DataFrame::new_no_checks(self.height(), col) } + } + + /// Replace None values with one of the following strategies: + /// * Forward fill (replace None with the previous value) + /// * Backward fill (replace None with the next value) + /// * Mean fill (replace None with the mean of the whole array) + /// * Min fill (replace None with the minimum of the whole array) + /// * Max fill (replace None with the maximum of the whole array) + /// + /// See the method on [Series](crate::series::Series::fill_null) for more info on the `fill_null` operation. + pub fn fill_null(&self, strategy: FillNullStrategy) -> PolarsResult { + let col = self.try_apply_columns_par(&|s| s.fill_null(strategy))?; + + Ok(unsafe { DataFrame::new_no_checks(self.height(), col) }) + } + + /// Pipe different functions/ closure operations that work on a DataFrame together. + pub fn pipe(self, f: F) -> PolarsResult + where + F: Fn(DataFrame) -> PolarsResult, + { + f(self) + } + + /// Pipe different functions/ closure operations that work on a DataFrame together. + pub fn pipe_mut(&mut self, f: F) -> PolarsResult + where + F: Fn(&mut DataFrame) -> PolarsResult, + { + f(self) + } + + /// Pipe different functions/ closure operations that work on a DataFrame together. + pub fn pipe_with_args(self, f: F, args: Args) -> PolarsResult + where + F: Fn(DataFrame, Args) -> PolarsResult, + { + f(self, args) + } + + /// Drop duplicate rows from a [`DataFrame`]. + /// *This fails when there is a column of type List in DataFrame* + /// + /// Stable means that the order is maintained. This has a higher cost than an unstable distinct. + /// + /// # Example + /// + /// ```no_run + /// # use polars_core::prelude::*; + /// let df = df! { + /// "flt" => [1., 1., 2., 2., 3., 3.], + /// "int" => [1, 1, 2, 2, 3, 3, ], + /// "str" => ["a", "a", "b", "b", "c", "c"] + /// }?; + /// + /// println!("{}", df.unique_stable(None, UniqueKeepStrategy::First, None)?); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Returns + /// + /// ```text + /// +-----+-----+-----+ + /// | flt | int | str | + /// | --- | --- | --- | + /// | f64 | i32 | str | + /// +=====+=====+=====+ + /// | 1 | 1 | "a" | + /// +-----+-----+-----+ + /// | 2 | 2 | "b" | + /// +-----+-----+-----+ + /// | 3 | 3 | "c" | + /// +-----+-----+-----+ + /// ``` + #[cfg(feature = "algorithm_group_by")] + pub fn unique_stable( + &self, + subset: Option<&[String]>, + keep: UniqueKeepStrategy, + slice: Option<(i64, usize)>, + ) -> PolarsResult { + self.unique_impl( + true, + subset.map(|v| v.iter().map(|x| PlSmallStr::from_str(x.as_str())).collect()), + keep, + slice, + ) + } + + /// Unstable distinct. See [`DataFrame::unique_stable`]. + #[cfg(feature = "algorithm_group_by")] + pub fn unique( + &self, + subset: Option<&[String]>, + keep: UniqueKeepStrategy, + slice: Option<(i64, usize)>, + ) -> PolarsResult { + self.unique_impl( + false, + subset.map(|v| v.iter().map(|x| PlSmallStr::from_str(x.as_str())).collect()), + keep, + slice, + ) + } + + #[cfg(feature = "algorithm_group_by")] + pub fn unique_impl( + &self, + maintain_order: bool, + subset: Option>, + keep: UniqueKeepStrategy, + slice: Option<(i64, usize)>, + ) -> PolarsResult { + let names = subset.unwrap_or_else(|| self.get_column_names_owned()); + let mut df = self.clone(); + // take on multiple chunks is terrible + df.as_single_chunk_par(); + + let columns = match (keep, maintain_order) { + (UniqueKeepStrategy::First | UniqueKeepStrategy::Any, true) => { + let gb = df.group_by_stable(names)?; + let groups = gb.get_groups(); + let (offset, len) = slice.unwrap_or((0, groups.len())); + let groups = groups.slice(offset, len); + df._apply_columns_par(&|s| unsafe { s.agg_first(&groups) }) + }, + (UniqueKeepStrategy::Last, true) => { + // maintain order by last values, so the sorted groups are not correct as they + // are sorted by the first value + let gb = df.group_by(names)?; + let groups = gb.get_groups(); + + let func = |g: GroupsIndicator| match g { + GroupsIndicator::Idx((_first, idx)) => idx[idx.len() - 1], + GroupsIndicator::Slice([first, len]) => first + len - 1, + }; + + let last_idx: NoNull = match slice { + None => groups.iter().map(func).collect(), + Some((offset, len)) => { + let (offset, len) = slice_offsets(offset, len, groups.len()); + groups.iter().skip(offset).take(len).map(func).collect() + }, + }; + + let last_idx = last_idx.sort(false); + return Ok(unsafe { df.take_unchecked(&last_idx) }); + }, + (UniqueKeepStrategy::First | UniqueKeepStrategy::Any, false) => { + let gb = df.group_by(names)?; + let groups = gb.get_groups(); + let (offset, len) = slice.unwrap_or((0, groups.len())); + let groups = groups.slice(offset, len); + df._apply_columns_par(&|s| unsafe { s.agg_first(&groups) }) + }, + (UniqueKeepStrategy::Last, false) => { + let gb = df.group_by(names)?; + let groups = gb.get_groups(); + let (offset, len) = slice.unwrap_or((0, groups.len())); + let groups = groups.slice(offset, len); + df._apply_columns_par(&|s| unsafe { s.agg_last(&groups) }) + }, + (UniqueKeepStrategy::None, _) => { + let df_part = df.select(names)?; + let mask = df_part.is_unique()?; + let mask = match slice { + None => mask, + Some((offset, len)) => mask.slice(offset, len), + }; + return df.filter(&mask); + }, + }; + + let height = Self::infer_height(&columns); + Ok(unsafe { DataFrame::new_no_checks(height, columns) }) + } + + /// Get a mask of all the unique rows in the [`DataFrame`]. + /// + /// # Example + /// + /// ```no_run + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Company" => ["Apple", "Microsoft"], + /// "ISIN" => ["US0378331005", "US5949181045"])?; + /// let ca: ChunkedArray = df.is_unique()?; + /// + /// assert!(ca.all()); + /// # Ok::<(), PolarsError>(()) + /// ``` + #[cfg(feature = "algorithm_group_by")] + pub fn is_unique(&self) -> PolarsResult { + let gb = self.group_by(self.get_column_names_owned())?; + let groups = gb.get_groups(); + Ok(is_unique_helper( + groups, + self.height() as IdxSize, + true, + false, + )) + } + + /// Get a mask of all the duplicated rows in the [`DataFrame`]. + /// + /// # Example + /// + /// ```no_run + /// # use polars_core::prelude::*; + /// let df: DataFrame = df!("Company" => ["Alphabet", "Alphabet"], + /// "ISIN" => ["US02079K3059", "US02079K1079"])?; + /// let ca: ChunkedArray = df.is_duplicated()?; + /// + /// assert!(!ca.all()); + /// # Ok::<(), PolarsError>(()) + /// ``` + #[cfg(feature = "algorithm_group_by")] + pub fn is_duplicated(&self) -> PolarsResult { + let gb = self.group_by(self.get_column_names_owned())?; + let groups = gb.get_groups(); + Ok(is_unique_helper( + groups, + self.height() as IdxSize, + false, + true, + )) + } + + /// Create a new [`DataFrame`] that shows the null counts per column. + #[must_use] + pub fn null_count(&self) -> Self { + let cols = self + .columns + .iter() + .map(|c| Column::new(c.name().clone(), [c.null_count() as IdxSize])) + .collect(); + unsafe { Self::new_no_checks(1, cols) } + } + + /// Hash and combine the row values + #[cfg(feature = "row_hash")] + pub fn hash_rows( + &mut self, + hasher_builder: Option, + ) -> PolarsResult { + let dfs = split_df(self, POOL.current_num_threads(), false); + let (cas, _) = _df_rows_to_hashes_threaded_vertical(&dfs, hasher_builder)?; + + let mut iter = cas.into_iter(); + let mut acc_ca = iter.next().unwrap(); + for ca in iter { + acc_ca.append(&ca)?; + } + Ok(acc_ca.rechunk().into_owned()) + } + + /// Get the supertype of the columns in this DataFrame + pub fn get_supertype(&self) -> Option> { + self.columns + .iter() + .map(|s| Ok(s.dtype().clone())) + .reduce(|acc, b| try_get_supertype(&acc?, &b.unwrap())) + } + + /// Take by index values given by the slice `idx`. + /// # Warning + /// Be careful with allowing threads when calling this in a large hot loop + /// every thread split may be on rayon stack and lead to SO + #[doc(hidden)] + pub unsafe fn _take_unchecked_slice(&self, idx: &[IdxSize], allow_threads: bool) -> Self { + self._take_unchecked_slice_sorted(idx, allow_threads, IsSorted::Not) + } + + /// Take by index values given by the slice `idx`. Use this over `_take_unchecked_slice` + /// if the index value in `idx` are sorted. This will maintain sorted flags. + /// + /// # Warning + /// Be careful with allowing threads when calling this in a large hot loop + /// every thread split may be on rayon stack and lead to SO + #[doc(hidden)] + pub unsafe fn _take_unchecked_slice_sorted( + &self, + idx: &[IdxSize], + allow_threads: bool, + sorted: IsSorted, + ) -> Self { + #[cfg(debug_assertions)] + { + if idx.len() > 2 { + match sorted { + IsSorted::Ascending => { + assert!(idx[0] <= idx[idx.len() - 1]); + }, + IsSorted::Descending => { + assert!(idx[0] >= idx[idx.len() - 1]); + }, + _ => {}, + } + } + } + let mut ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, idx); + ca.set_sorted_flag(sorted); + self.take_unchecked_impl(&ca, allow_threads) + } + + #[cfg(all(feature = "partition_by", feature = "algorithm_group_by"))] + #[doc(hidden)] + pub fn _partition_by_impl( + &self, + cols: &[PlSmallStr], + stable: bool, + include_key: bool, + parallel: bool, + ) -> PolarsResult> { + let selected_keys = self.select_columns(cols.iter().cloned())?; + let groups = self.group_by_with_series(selected_keys, parallel, stable)?; + let groups = groups.take_groups(); + + // drop key columns prior to calculation if requested + let df = if include_key { + self.clone() + } else { + self.drop_many(cols.iter().cloned()) + }; + + if parallel { + // don't parallelize this + // there is a lot of parallelization in take and this may easily SO + POOL.install(|| { + match groups.as_ref() { + GroupsType::Idx(idx) => { + // Rechunk as the gather may rechunk for every group #17562. + let mut df = df.clone(); + df.as_single_chunk_par(); + Ok(idx + .into_par_iter() + .map(|(_, group)| { + // groups are in bounds + unsafe { + df._take_unchecked_slice_sorted( + group, + false, + IsSorted::Ascending, + ) + } + }) + .collect()) + }, + GroupsType::Slice { groups, .. } => Ok(groups + .into_par_iter() + .map(|[first, len]| df.slice(*first as i64, *len as usize)) + .collect()), + } + }) + } else { + match groups.as_ref() { + GroupsType::Idx(idx) => { + // Rechunk as the gather may rechunk for every group #17562. + let mut df = df.clone(); + df.as_single_chunk(); + Ok(idx + .into_iter() + .map(|(_, group)| { + // groups are in bounds + unsafe { + df._take_unchecked_slice_sorted(group, false, IsSorted::Ascending) + } + }) + .collect()) + }, + GroupsType::Slice { groups, .. } => Ok(groups + .iter() + .map(|[first, len]| df.slice(*first as i64, *len as usize)) + .collect()), + } + } + } + + /// Split into multiple DataFrames partitioned by groups + #[cfg(feature = "partition_by")] + pub fn partition_by(&self, cols: I, include_key: bool) -> PolarsResult> + where + I: IntoIterator, + S: Into, + { + let cols = cols + .into_iter() + .map(Into::into) + .collect::>(); + self._partition_by_impl(cols.as_slice(), false, include_key, true) + } + + /// Split into multiple DataFrames partitioned by groups + /// Order of the groups are maintained. + #[cfg(feature = "partition_by")] + pub fn partition_by_stable( + &self, + cols: I, + include_key: bool, + ) -> PolarsResult> + where + I: IntoIterator, + S: Into, + { + let cols = cols + .into_iter() + .map(Into::into) + .collect::>(); + self._partition_by_impl(cols.as_slice(), true, include_key, true) + } + + /// Unnest the given `Struct` columns. This means that the fields of the `Struct` type will be + /// inserted as columns. + #[cfg(feature = "dtype-struct")] + pub fn unnest>(&self, cols: I) -> PolarsResult { + let cols = cols.into_vec(); + self.unnest_impl(cols.into_iter().collect()) + } + + #[cfg(feature = "dtype-struct")] + fn unnest_impl(&self, cols: PlHashSet) -> PolarsResult { + let mut new_cols = Vec::with_capacity(std::cmp::min(self.width() * 2, self.width() + 128)); + let mut count = 0; + for s in &self.columns { + if cols.contains(s.name()) { + let ca = s.struct_()?.clone(); + new_cols.extend(ca.fields_as_series().into_iter().map(Column::from)); + count += 1; + } else { + new_cols.push(s.clone()) + } + } + if count != cols.len() { + // one or more columns not found + // the code below will return an error with the missing name + let schema = self.schema(); + for col in cols { + let _ = schema + .get(col.as_str()) + .ok_or_else(|| polars_err!(col_not_found = col))?; + } + } + DataFrame::new(new_cols) + } + + pub(crate) fn infer_height(cols: &[Column]) -> usize { + cols.first().map_or(0, Column::len) + } + + pub fn append_record_batch(&mut self, rb: RecordBatchT) -> PolarsResult<()> { + // @Optimize: this does a lot of unnecessary allocations. We should probably have a + // append_chunk or something like this. It is just quite difficult to make that safe. + let df = DataFrame::from(rb); + polars_ensure!( + self.schema() == df.schema(), + SchemaMismatch: "cannot append record batch with different schema", + ); + self.vstack_mut_owned_unchecked(df); + Ok(()) + } +} + +pub struct RecordBatchIter<'a> { + columns: &'a Vec, + schema: ArrowSchemaRef, + idx: usize, + n_chunks: usize, + compat_level: CompatLevel, + parallel: bool, +} + +impl Iterator for RecordBatchIter<'_> { + type Item = RecordBatch; + + fn next(&mut self) -> Option { + if self.idx >= self.n_chunks { + return None; + } + + // Create a batch of the columns with the same chunk no. + let batch_cols: Vec = if self.parallel { + let iter = self + .columns + .par_iter() + .map(Column::as_materialized_series) + .map(|s| s.to_arrow(self.idx, self.compat_level)); + POOL.install(|| iter.collect()) + } else { + self.columns + .iter() + .map(Column::as_materialized_series) + .map(|s| s.to_arrow(self.idx, self.compat_level)) + .collect() + }; + self.idx += 1; + + let length = batch_cols.first().map_or(0, |arr| arr.len()); + Some(RecordBatch::new(length, self.schema.clone(), batch_cols)) + } + + fn size_hint(&self) -> (usize, Option) { + let n = self.n_chunks - self.idx; + (n, Some(n)) + } +} + +pub struct PhysRecordBatchIter<'a> { + schema: ArrowSchemaRef, + arr_iters: Vec>, +} + +impl Iterator for PhysRecordBatchIter<'_> { + type Item = RecordBatch; + + fn next(&mut self) -> Option { + let arrs = self + .arr_iters + .iter_mut() + .map(|phys_iter| phys_iter.next().cloned()) + .collect::>>()?; + + let length = arrs.first().map_or(0, |arr| arr.len()); + Some(RecordBatch::new(length, self.schema.clone(), arrs)) + } + + fn size_hint(&self) -> (usize, Option) { + if let Some(iter) = self.arr_iters.first() { + iter.size_hint() + } else { + (0, None) + } + } +} + +impl Default for DataFrame { + fn default() -> Self { + DataFrame::empty() + } +} + +impl From for Vec { + fn from(df: DataFrame) -> Self { + df.columns + } +} + +// utility to test if we can vstack/extend the columns +fn ensure_can_extend(left: &Column, right: &Column) -> PolarsResult<()> { + polars_ensure!( + left.name() == right.name(), + ShapeMismatch: "unable to vstack, column names don't match: {:?} and {:?}", + left.name(), right.name(), + ); + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + + fn create_frame() -> DataFrame { + let s0 = Column::new("days".into(), [0, 1, 2].as_ref()); + let s1 = Column::new("temp".into(), [22.1, 19.9, 7.].as_ref()); + DataFrame::new(vec![s0, s1]).unwrap() + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_recordbatch_iterator() { + let df = df!( + "foo" => [1, 2, 3, 4, 5] + ) + .unwrap(); + let mut iter = df.iter_chunks(CompatLevel::newest(), false); + assert_eq!(5, iter.next().unwrap().len()); + assert!(iter.next().is_none()); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_select() { + let df = create_frame(); + assert_eq!( + df.column("days") + .unwrap() + .as_series() + .unwrap() + .equal(1) + .unwrap() + .sum(), + Some(1) + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_filter_broadcast_on_string_col() { + let col_name = "some_col"; + let v = vec!["test".to_string()]; + let s0 = Column::new(PlSmallStr::from_str(col_name), v); + let mut df = DataFrame::new(vec![s0]).unwrap(); + + df = df + .filter( + &df.column(col_name) + .unwrap() + .as_materialized_series() + .equal("") + .unwrap(), + ) + .unwrap(); + assert_eq!( + df.column(col_name) + .unwrap() + .as_materialized_series() + .n_chunks(), + 1 + ); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_filter_broadcast_on_list_col() { + let s1 = Series::new(PlSmallStr::EMPTY, [true, false, true]); + let ll: ListChunked = [&s1].iter().copied().collect(); + + let mask = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false]); + let new = ll.filter(&mask).unwrap(); + + assert_eq!(new.chunks.len(), 1); + assert_eq!(new.len(), 0); + } + + #[test] + fn slice() { + let df = create_frame(); + let sliced_df = df.slice(0, 2); + assert_eq!(sliced_df.shape(), (2, 2)); + } + + #[test] + fn rechunk_false() { + let df = create_frame(); + assert!(!df.should_rechunk()) + } + + #[test] + fn rechunk_true() -> PolarsResult<()> { + let mut base = df!( + "a" => [1, 2, 3], + "b" => [1, 2, 3] + )?; + + // Create a series with multiple chunks + let mut s = Series::new("foo".into(), 0..2); + let s2 = Series::new("bar".into(), 0..1); + s.append(&s2)?; + + // Append series to frame + let out = base.with_column(s)?; + + // Now we should rechunk + assert!(out.should_rechunk()); + Ok(()) + } + + #[test] + fn test_duplicate_column() { + let mut df = df! { + "foo" => [1, 2, 3] + } + .unwrap(); + // check if column is replaced + assert!( + df.with_column(Series::new("foo".into(), &[1, 2, 3])) + .is_ok() + ); + assert!( + df.with_column(Series::new("bar".into(), &[1, 2, 3])) + .is_ok() + ); + assert!(df.column("bar").is_ok()) + } + + #[test] + #[cfg_attr(miri, ignore)] + fn distinct() { + let df = df! { + "flt" => [1., 1., 2., 2., 3., 3.], + "int" => [1, 1, 2, 2, 3, 3, ], + "str" => ["a", "a", "b", "b", "c", "c"] + } + .unwrap(); + let df = df + .unique_stable(None, UniqueKeepStrategy::First, None) + .unwrap() + .sort(["flt"], SortMultipleOptions::default()) + .unwrap(); + let valid = df! { + "flt" => [1., 2., 3.], + "int" => [1, 2, 3], + "str" => ["a", "b", "c"] + } + .unwrap(); + assert!(df.equals(&valid)); + } + + #[test] + fn test_vstack() { + // check that it does not accidentally rechunks + let mut df = df! { + "flt" => [1., 1., 2., 2., 3., 3.], + "int" => [1, 1, 2, 2, 3, 3, ], + "str" => ["a", "a", "b", "b", "c", "c"] + } + .unwrap(); + + df.vstack_mut(&df.slice(0, 3)).unwrap(); + assert_eq!(df.first_col_n_chunks(), 2) + } + + #[test] + fn test_vstack_on_empty_dataframe() { + let mut df = DataFrame::empty(); + + let df_data = df! { + "flt" => [1., 1., 2., 2., 3., 3.], + "int" => [1, 1, 2, 2, 3, 3, ], + "str" => ["a", "a", "b", "b", "c", "c"] + } + .unwrap(); + + df.vstack_mut(&df_data).unwrap(); + assert_eq!(df.height, 6) + } + + #[test] + fn test_replace_or_add() -> PolarsResult<()> { + let mut df = df!( + "a" => [1, 2, 3], + "b" => [1, 2, 3] + )?; + + // check that the new column is "c" and not "bar". + df.replace_or_add("c".into(), Series::new("bar".into(), [1, 2, 3]))?; + + assert_eq!(df.get_column_names(), &["a", "b", "c"]); + Ok(()) + } +} diff --git a/crates/polars-core/src/frame/row/av_buffer.rs b/crates/polars-core/src/frame/row/av_buffer.rs new file mode 100644 index 000000000000..0a86a3f78e87 --- /dev/null +++ b/crates/polars-core/src/frame/row/av_buffer.rs @@ -0,0 +1,749 @@ +use std::hint::unreachable_unchecked; + +use arrow::bitmap::BitmapBuilder; +#[cfg(feature = "dtype-struct")] +use polars_utils::pl_str::PlSmallStr; + +use super::*; +use crate::chunked_array::builder::NullChunkedBuilder; +#[cfg(feature = "dtype-struct")] +use crate::prelude::any_value::arr_to_any_value; + +#[derive(Clone)] +pub enum AnyValueBuffer<'a> { + Boolean(BooleanChunkedBuilder), + #[cfg(feature = "dtype-i8")] + Int8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-i16")] + Int16(PrimitiveChunkedBuilder), + Int32(PrimitiveChunkedBuilder), + Int64(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u8")] + UInt8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u16")] + UInt16(PrimitiveChunkedBuilder), + UInt32(PrimitiveChunkedBuilder), + UInt64(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-date")] + Date(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-datetime")] + Datetime( + PrimitiveChunkedBuilder, + TimeUnit, + Option, + ), + #[cfg(feature = "dtype-duration")] + Duration(PrimitiveChunkedBuilder, TimeUnit), + #[cfg(feature = "dtype-time")] + Time(PrimitiveChunkedBuilder), + Float32(PrimitiveChunkedBuilder), + Float64(PrimitiveChunkedBuilder), + String(StringChunkedBuilder), + Null(NullChunkedBuilder), + All(DataType, Vec>), +} + +impl<'a> AnyValueBuffer<'a> { + #[inline] + pub fn add(&mut self, val: AnyValue<'a>) -> Option<()> { + use AnyValueBuffer::*; + match (self, val) { + (Boolean(builder), AnyValue::Null) => builder.append_null(), + (Boolean(builder), AnyValue::Boolean(v)) => builder.append_value(v), + (Boolean(builder), val) => { + let v = val.extract::()?; + builder.append_value(v == 1) + }, + (Int32(builder), AnyValue::Null) => builder.append_null(), + (Int32(builder), val) => builder.append_value(val.extract()?), + (Int64(builder), AnyValue::Null) => builder.append_null(), + (Int64(builder), val) => builder.append_value(val.extract()?), + (UInt32(builder), AnyValue::Null) => builder.append_null(), + (UInt32(builder), val) => builder.append_value(val.extract()?), + (UInt64(builder), AnyValue::Null) => builder.append_null(), + (UInt64(builder), val) => builder.append_value(val.extract()?), + (Float32(builder), AnyValue::Null) => builder.append_null(), + (Float64(builder), AnyValue::Null) => builder.append_null(), + (Float32(builder), val) => builder.append_value(val.extract()?), + (Float64(builder), val) => builder.append_value(val.extract()?), + (String(builder), AnyValue::String(v)) => builder.append_value(v), + (String(builder), AnyValue::StringOwned(v)) => builder.append_value(v.as_str()), + (String(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-i8")] + (Int8(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-i8")] + (Int8(builder), val) => builder.append_value(val.extract()?), + #[cfg(feature = "dtype-i16")] + (Int16(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-i16")] + (Int16(builder), val) => builder.append_value(val.extract()?), + #[cfg(feature = "dtype-u8")] + (UInt8(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-u8")] + (UInt8(builder), val) => builder.append_value(val.extract()?), + #[cfg(feature = "dtype-u16")] + (UInt16(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-u16")] + (UInt16(builder), val) => builder.append_value(val.extract()?), + #[cfg(feature = "dtype-date")] + (Date(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-date")] + (Date(builder), AnyValue::Date(v)) => builder.append_value(v), + #[cfg(feature = "dtype-date")] + (Date(builder), val) if val.is_primitive_numeric() => { + builder.append_value(val.extract()?) + }, + #[cfg(feature = "dtype-datetime")] + (Datetime(builder, _, _), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-datetime")] + ( + Datetime(builder, tu_l, _), + AnyValue::Datetime(v, tu_r, _) | AnyValue::DatetimeOwned(v, tu_r, _), + ) => { + // we convert right tu to left tu + // so we swap. + let v = convert_time_units(v, tu_r, *tu_l); + builder.append_value(v) + }, + #[cfg(feature = "dtype-datetime")] + (Datetime(builder, _, _), val) if val.is_primitive_numeric() => { + builder.append_value(val.extract()?) + }, + #[cfg(feature = "dtype-duration")] + (Duration(builder, _), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-duration")] + (Duration(builder, tu_l), AnyValue::Duration(v, tu_r)) => { + let v = convert_time_units(v, tu_r, *tu_l); + builder.append_value(v) + }, + #[cfg(feature = "dtype-duration")] + (Duration(builder, _), val) if val.is_primitive_numeric() => { + builder.append_value(val.extract()?) + }, + #[cfg(feature = "dtype-time")] + (Time(builder), AnyValue::Time(v)) => builder.append_value(v), + #[cfg(feature = "dtype-time")] + (Time(builder), AnyValue::Null) => builder.append_null(), + #[cfg(feature = "dtype-time")] + (Time(builder), val) if val.is_primitive_numeric() => { + builder.append_value(val.extract()?) + }, + (Null(builder), AnyValue::Null) => builder.append_null(), + // Struct and List can be recursive so use AnyValues for that + (All(_, vals), v) => vals.push(v.into_static()), + + // dynamic types + (String(builder), av) => match av { + AnyValue::Int64(v) => builder.append_value(format!("{v}")), + AnyValue::Float64(v) => builder.append_value(format!("{v}")), + AnyValue::Boolean(true) => builder.append_value("true"), + AnyValue::Boolean(false) => builder.append_value("false"), + _ => return None, + }, + _ => return None, + }; + Some(()) + } + + pub(crate) fn add_fallible(&mut self, val: &AnyValue<'a>) -> PolarsResult<()> { + self.add(val.clone()).ok_or_else(|| { + polars_err!( + ComputeError: "could not append value: {} of type: {} to the builder; make sure that all rows \ + have the same schema or consider increasing `infer_schema_length`\n\ + \n\ + it might also be that a value overflows the data-type's capacity", val, val.dtype() + ) + }) + } + + pub fn reset(&mut self, capacity: usize, strict: bool) -> PolarsResult { + use AnyValueBuffer::*; + let out = match self { + Boolean(b) => { + let mut new = BooleanChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Int32(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Int64(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + UInt32(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + UInt64(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-date")] + Date(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_date().into_series() + }, + #[cfg(feature = "dtype-datetime")] + Datetime(b, tu, tz) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + let tz = if capacity > 0 { + tz.clone() + } else { + std::mem::take(tz) + }; + new.finish().into_datetime(*tu, tz).into_series() + }, + #[cfg(feature = "dtype-duration")] + Duration(b, tu) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_duration(*tu).into_series() + }, + #[cfg(feature = "dtype-time")] + Time(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_time().into_series() + }, + Float32(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Float64(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + String(b) => { + let mut new = StringChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-i8")] + Int8(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-i16")] + Int16(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-u8")] + UInt8(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-u16")] + UInt16(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Null(b) => { + let mut new = NullChunkedBuilder::new(b.field.name().clone(), 0); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + All(dtype, vals) => { + let out = + Series::from_any_values_and_dtype(PlSmallStr::EMPTY, vals, dtype, strict)?; + let mut new = Vec::with_capacity(capacity); + std::mem::swap(&mut new, vals); + out + }, + }; + Ok(out) + } + + pub fn into_series(mut self) -> Series { + self.reset(0, false).unwrap() + } + + pub fn new(dtype: &DataType, capacity: usize) -> AnyValueBuffer<'a> { + (dtype, capacity).into() + } +} + +// datatype and length +impl From<(&DataType, usize)> for AnyValueBuffer<'_> { + fn from(a: (&DataType, usize)) -> Self { + let (dt, len) = a; + use DataType::*; + match dt { + Boolean => AnyValueBuffer::Boolean(BooleanChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Int32 => AnyValueBuffer::Int32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Int64 => AnyValueBuffer::Int64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + UInt32 => AnyValueBuffer::UInt32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + UInt64 => AnyValueBuffer::UInt64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + #[cfg(feature = "dtype-i8")] + Int8 => AnyValueBuffer::Int8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + #[cfg(feature = "dtype-i16")] + Int16 => AnyValueBuffer::Int16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + #[cfg(feature = "dtype-u8")] + UInt8 => AnyValueBuffer::UInt8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + #[cfg(feature = "dtype-u16")] + UInt16 => AnyValueBuffer::UInt16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + #[cfg(feature = "dtype-date")] + Date => AnyValueBuffer::Date(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + #[cfg(feature = "dtype-datetime")] + Datetime(tu, tz) => AnyValueBuffer::Datetime( + PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len), + *tu, + tz.clone(), + ), + #[cfg(feature = "dtype-duration")] + Duration(tu) => { + AnyValueBuffer::Duration(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len), *tu) + }, + #[cfg(feature = "dtype-time")] + Time => AnyValueBuffer::Time(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Float32 => { + AnyValueBuffer::Float32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Float64 => { + AnyValueBuffer::Float64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + String => AnyValueBuffer::String(StringChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Null => AnyValueBuffer::Null(NullChunkedBuilder::new(PlSmallStr::EMPTY, 0)), + // Struct and List can be recursive so use AnyValues for that + dt => AnyValueBuffer::All(dt.clone(), Vec::with_capacity(len)), + } + } +} + +/// An [`AnyValueBuffer`] that should be used when we trust the builder +#[derive(Clone)] +pub enum AnyValueBufferTrusted<'a> { + Boolean(BooleanChunkedBuilder), + #[cfg(feature = "dtype-i8")] + Int8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-i16")] + Int16(PrimitiveChunkedBuilder), + Int32(PrimitiveChunkedBuilder), + Int64(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u8")] + UInt8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u16")] + UInt16(PrimitiveChunkedBuilder), + UInt32(PrimitiveChunkedBuilder), + UInt64(PrimitiveChunkedBuilder), + Float32(PrimitiveChunkedBuilder), + Float64(PrimitiveChunkedBuilder), + String(StringChunkedBuilder), + #[cfg(feature = "dtype-struct")] + // not the trusted variant! + Struct(BitmapBuilder, Vec<(AnyValueBuffer<'a>, PlSmallStr)>), + Null(NullChunkedBuilder), + All(DataType, Vec>), +} + +impl<'a> AnyValueBufferTrusted<'a> { + pub fn new(dtype: &DataType, len: usize) -> Self { + (dtype, len).into() + } + + #[inline] + unsafe fn add_null(&mut self) { + use AnyValueBufferTrusted::*; + match self { + Boolean(builder) => builder.append_null(), + #[cfg(feature = "dtype-i8")] + Int8(builder) => builder.append_null(), + #[cfg(feature = "dtype-i16")] + Int16(builder) => builder.append_null(), + Int32(builder) => builder.append_null(), + Int64(builder) => builder.append_null(), + #[cfg(feature = "dtype-u8")] + UInt8(builder) => builder.append_null(), + #[cfg(feature = "dtype-u16")] + UInt16(builder) => builder.append_null(), + UInt32(builder) => builder.append_null(), + UInt64(builder) => builder.append_null(), + Float32(builder) => builder.append_null(), + Float64(builder) => builder.append_null(), + String(builder) => builder.append_null(), + #[cfg(feature = "dtype-struct")] + Struct(outer_validity, builders) => { + outer_validity.push(false); + for (b, _) in builders.iter_mut() { + b.add(AnyValue::Null); + } + }, + Null(builder) => builder.append_null(), + All(_, vals) => vals.push(AnyValue::Null), + } + } + + #[inline] + unsafe fn add_physical(&mut self, val: &AnyValue<'_>) { + use AnyValueBufferTrusted::*; + match self { + Boolean(builder) => { + let AnyValue::Boolean(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + #[cfg(feature = "dtype-i8")] + Int8(builder) => { + let AnyValue::Int8(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + #[cfg(feature = "dtype-i16")] + Int16(builder) => { + let AnyValue::Int16(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + Int32(builder) => { + let AnyValue::Int32(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + Int64(builder) => { + let AnyValue::Int64(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + #[cfg(feature = "dtype-u8")] + UInt8(builder) => { + let AnyValue::UInt8(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + #[cfg(feature = "dtype-u16")] + UInt16(builder) => { + let AnyValue::UInt16(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + UInt32(builder) => { + let AnyValue::UInt32(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + UInt64(builder) => { + let AnyValue::UInt64(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + Float32(builder) => { + let AnyValue::Float32(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + Float64(builder) => { + let AnyValue::Float64(v) = val else { + unreachable_unchecked() + }; + builder.append_value(*v) + }, + Null(builder) => { + let AnyValue::Null = val else { + unreachable_unchecked() + }; + builder.append_null() + }, + _ => unreachable_unchecked(), + } + } + + /// Will add the [`AnyValue`] into [`Self`] and unpack as the physical type + /// belonging to [`Self`]. This should only be used with physical buffers + /// + /// If a type is not primitive or String, the AnyValues will be converted to static + /// + /// # Safety + /// The caller must ensure that the [`AnyValue`] type exactly matches the `Buffer` type and is owned. + #[inline] + pub unsafe fn add_unchecked_owned_physical(&mut self, val: &AnyValue<'_>) { + use AnyValueBufferTrusted::*; + match val { + AnyValue::Null => self.add_null(), + _ => { + match self { + String(builder) => { + let AnyValue::StringOwned(v) = val else { + unreachable_unchecked() + }; + builder.append_value(v.as_str()) + }, + #[cfg(feature = "dtype-struct")] + Struct(outer_validity, builders) => { + let AnyValue::StructOwned(payload) = val else { + unreachable_unchecked() + }; + let avs = &*payload.0; + // amortize loop counter + for i in 0..avs.len() { + unsafe { + let (builder, _) = builders.get_unchecked_mut(i); + let av = avs.get_unchecked(i).clone(); + // lifetime is bound to 'a + let av = std::mem::transmute::, AnyValue<'a>>(av); + builder.add(av.clone()); + } + } + outer_validity.push(true); + }, + All(_, vals) => vals.push(val.clone().into_static()), + _ => self.add_physical(val), + } + }, + } + } + + /// # Safety + /// The caller must ensure that the [`AnyValue`] type exactly matches the `Buffer` type and is borrowed. + #[inline] + pub unsafe fn add_unchecked_borrowed_physical(&mut self, val: &AnyValue<'_>) { + use AnyValueBufferTrusted::*; + match val { + AnyValue::Null => self.add_null(), + _ => { + match self { + String(builder) => { + let AnyValue::String(v) = val else { + unreachable_unchecked() + }; + builder.append_value(v) + }, + #[cfg(feature = "dtype-struct")] + Struct(outer_validity, builders) => { + let AnyValue::Struct(idx, arr, fields) = val else { + unreachable_unchecked() + }; + let arrays = arr.values(); + // amortize loop counter + for i in 0..fields.len() { + unsafe { + let array = arrays.get_unchecked(i); + let field = fields.get_unchecked(i); + let (builder, _) = builders.get_unchecked_mut(i); + let av = arr_to_any_value(&**array, *idx, &field.dtype); + // lifetime is bound to 'a + let av = std::mem::transmute::, AnyValue<'a>>(av); + builder.add(av); + } + } + outer_validity.push(true); + }, + All(_, vals) => vals.push(val.clone().into_static()), + _ => self.add_physical(val), + } + }, + } + } + + /// Clear `self` and give `capacity`, returning the old contents as a [`Series`]. + pub fn reset(&mut self, capacity: usize, strict: bool) -> PolarsResult { + use AnyValueBufferTrusted::*; + let out = match self { + Boolean(b) => { + let mut new = BooleanChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Int32(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Int64(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + UInt32(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + UInt64(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Float32(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + Float64(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + String(b) => { + let mut new = StringChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-i8")] + Int8(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-i16")] + Int16(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-u8")] + UInt8(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-u16")] + UInt16(b) => { + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + #[cfg(feature = "dtype-struct")] + Struct(outer_validity, b) => { + // @Q? Maybe we need to add a length parameter here for ZFS's. I am not very happy + // with just setting the length to zero for that case. + if b.is_empty() { + return Ok( + StructChunked::from_series(PlSmallStr::EMPTY, 0, [].iter())?.into_series() + ); + } + + let mut min_len = usize::MAX; + let mut max_len = usize::MIN; + + let v = b + .iter_mut() + .map(|(b, name)| { + let mut s = b.reset(capacity, strict)?; + + min_len = min_len.min(s.len()); + max_len = max_len.max(s.len()); + + s.rename(name.clone()); + Ok(s) + }) + .collect::>>()?; + + let length = if min_len == 0 { 0 } else { max_len }; + + let old_outer_validity = core::mem::take(outer_validity); + outer_validity.reserve(capacity); + + StructChunked::from_series(PlSmallStr::EMPTY, length, v.iter()) + .unwrap() + .with_outer_validity(Some(old_outer_validity.freeze())) + .into_series() + }, + Null(b) => { + let mut new = NullChunkedBuilder::new(b.field.name().clone(), 0); + std::mem::swap(&mut new, b); + new.finish().into_series() + }, + All(dtype, vals) => { + let mut swap_vals = Vec::with_capacity(capacity); + std::mem::swap(vals, &mut swap_vals); + Series::from_any_values_and_dtype(PlSmallStr::EMPTY, &swap_vals, dtype, false) + .unwrap() + }, + }; + + Ok(out) + } + + pub fn into_series(mut self) -> Series { + // unwrap: non-strict does not error. + self.reset(0, false).unwrap() + } +} + +impl From<(&DataType, usize)> for AnyValueBufferTrusted<'_> { + fn from(a: (&DataType, usize)) -> Self { + let (dt, len) = a; + use DataType::*; + match dt { + Boolean => { + AnyValueBufferTrusted::Boolean(BooleanChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Int32 => { + AnyValueBufferTrusted::Int32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Int64 => { + AnyValueBufferTrusted::Int64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + UInt32 => { + AnyValueBufferTrusted::UInt32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + UInt64 => { + AnyValueBufferTrusted::UInt64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + #[cfg(feature = "dtype-i8")] + Int8 => { + AnyValueBufferTrusted::Int8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + #[cfg(feature = "dtype-i16")] + Int16 => { + AnyValueBufferTrusted::Int16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + #[cfg(feature = "dtype-u8")] + UInt8 => { + AnyValueBufferTrusted::UInt8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + #[cfg(feature = "dtype-u16")] + UInt16 => { + AnyValueBufferTrusted::UInt16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Float32 => { + AnyValueBufferTrusted::Float32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Float64 => { + AnyValueBufferTrusted::Float64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + String => { + AnyValueBufferTrusted::String(StringChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + #[cfg(feature = "dtype-struct")] + Struct(fields) => { + let outer_validity = BitmapBuilder::with_capacity(len); + let buffers = fields + .iter() + .map(|field| { + let dtype = field.dtype().to_physical(); + let buffer: AnyValueBuffer = (&dtype, len).into(); + (buffer, field.name.clone()) + }) + .collect::>(); + AnyValueBufferTrusted::Struct(outer_validity, buffers) + }, + // List can be recursive so use AnyValues for that + dt => AnyValueBufferTrusted::All(dt.clone(), Vec::with_capacity(len)), + } + } +} diff --git a/crates/polars-core/src/frame/row/dataframe.rs b/crates/polars-core/src/frame/row/dataframe.rs new file mode 100644 index 000000000000..97891c76d478 --- /dev/null +++ b/crates/polars-core/src/frame/row/dataframe.rs @@ -0,0 +1,150 @@ +use super::*; + +impl DataFrame { + /// Get a row from a [`DataFrame`]. Use of this is discouraged as it will likely be slow. + pub fn get_row(&self, idx: usize) -> PolarsResult { + let values = self + .materialized_column_iter() + .map(|s| s.get(idx)) + .collect::>>()?; + Ok(Row(values)) + } + + /// Amortize allocations by reusing a row. + /// The caller is responsible to make sure that the row has at least the capacity for the number + /// of columns in the [`DataFrame`] + pub fn get_row_amortized<'a>(&'a self, idx: usize, row: &mut Row<'a>) -> PolarsResult<()> { + for (s, any_val) in self.materialized_column_iter().zip(&mut row.0) { + *any_val = s.get(idx)?; + } + Ok(()) + } + + /// Amortize allocations by reusing a row. + /// The caller is responsible to make sure that the row has at least the capacity for the number + /// of columns in the [`DataFrame`] + /// + /// # Safety + /// Does not do any bounds checking. + #[inline] + pub unsafe fn get_row_amortized_unchecked<'a>(&'a self, idx: usize, row: &mut Row<'a>) { + self.materialized_column_iter() + .zip(&mut row.0) + .for_each(|(s, any_val)| { + *any_val = s.get_unchecked(idx); + }); + } + + /// Create a new [`DataFrame`] from rows. + /// + /// This should only be used when you have row wise data, as this is a lot slower + /// than creating the [`Series`] in a columnar fashion + pub fn from_rows_and_schema(rows: &[Row], schema: &Schema) -> PolarsResult { + Self::from_rows_iter_and_schema(rows.iter(), schema) + } + + /// Create a new [`DataFrame`] from an iterator over rows. + /// + /// This should only be used when you have row wise data, as this is a lot slower + /// than creating the [`Series`] in a columnar fashion. + pub fn from_rows_iter_and_schema<'a, I>(mut rows: I, schema: &Schema) -> PolarsResult + where + I: Iterator>, + { + if schema.is_empty() { + let height = rows.count(); + let columns = Vec::new(); + return Ok(unsafe { DataFrame::new_no_checks(height, columns) }); + } + + let capacity = rows.size_hint().0; + + let mut buffers: Vec<_> = schema + .iter_values() + .map(|dtype| { + let buf: AnyValueBuffer = (dtype, capacity).into(); + buf + }) + .collect(); + + let mut expected_len = 0; + rows.try_for_each::<_, PolarsResult<()>>(|row| { + expected_len += 1; + for (value, buf) in row.0.iter().zip(&mut buffers) { + buf.add_fallible(value)? + } + Ok(()) + })?; + let v = buffers + .into_iter() + .zip(schema.iter_names()) + .map(|(b, name)| { + let mut c = b.into_series().into_column(); + // if the schema adds a column not in the rows, we + // fill it with nulls + if c.is_empty() { + Column::full_null(name.clone(), expected_len, c.dtype()) + } else { + c.rename(name.clone()); + c + } + }) + .collect(); + DataFrame::new(v) + } + + /// Create a new [`DataFrame`] from an iterator over rows. This should only be used when you have row wise data, + /// as this is a lot slower than creating the [`Series`] in a columnar fashion + pub fn try_from_rows_iter_and_schema<'a, I>(mut rows: I, schema: &Schema) -> PolarsResult + where + I: Iterator>>, + { + let capacity = rows.size_hint().0; + + let mut buffers: Vec<_> = schema + .iter_values() + .map(|dtype| { + let buf: AnyValueBuffer = (dtype, capacity).into(); + buf + }) + .collect(); + + let mut expected_len = 0; + rows.try_for_each::<_, PolarsResult<()>>(|row| { + expected_len += 1; + for (value, buf) in row?.0.iter().zip(&mut buffers) { + buf.add_fallible(value)? + } + Ok(()) + })?; + let v = buffers + .into_iter() + .zip(schema.iter_names()) + .map(|(b, name)| { + let mut c = b.into_series().into_column(); + // if the schema adds a column not in the rows, we + // fill it with nulls + if c.is_empty() { + Column::full_null(name.clone(), expected_len, c.dtype()) + } else { + c.rename(name.clone()); + c + } + }) + .collect(); + DataFrame::new(v) + } + + /// Create a new [`DataFrame`] from rows. This should only be used when you have row wise data, + /// as this is a lot slower than creating the [`Series`] in a columnar fashion + pub fn from_rows(rows: &[Row]) -> PolarsResult { + let schema = rows_to_schema_first_non_null(rows, Some(50))?; + let has_nulls = schema + .iter_values() + .any(|dtype| matches!(dtype, DataType::Null)); + polars_ensure!( + !has_nulls, ComputeError: "unable to infer row types because of null values" + ); + Self::from_rows_and_schema(rows, &schema) + } +} diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs new file mode 100644 index 000000000000..7165f4224ac9 --- /dev/null +++ b/crates/polars-core/src/frame/row/mod.rs @@ -0,0 +1,251 @@ +mod av_buffer; +mod dataframe; +mod transpose; + +use std::borrow::Borrow; +use std::fmt::Debug; +#[cfg(feature = "object")] +use std::hash::{Hash, Hasher}; + +use arrow::bitmap::Bitmap; +pub use av_buffer::*; +use polars_utils::format_pl_smallstr; +#[cfg(feature = "object")] +use polars_utils::total_ord::TotalHash; +use rayon::prelude::*; + +use crate::POOL; +use crate::prelude::*; +use crate::utils::{dtypes_to_schema, dtypes_to_supertype, try_get_supertype}; + +#[cfg(feature = "object")] +pub(crate) struct AnyValueRows<'a> { + vals: Vec>, + width: usize, +} + +#[cfg(feature = "object")] +pub(crate) struct AnyValueRow<'a>(&'a [AnyValue<'a>]); + +#[cfg(feature = "object")] +impl<'a> AnyValueRows<'a> { + pub(crate) fn get(&'a self, i: usize) -> AnyValueRow<'a> { + let start = i * self.width; + let end = (i + 1) * self.width; + AnyValueRow(&self.vals[start..end]) + } +} + +#[cfg(feature = "object")] +impl TotalEq for AnyValueRow<'_> { + fn tot_eq(&self, other: &Self) -> bool { + let lhs = self.0; + let rhs = other.0; + + // Should only be used in that context. + debug_assert_eq!(lhs.len(), rhs.len()); + lhs.iter().zip(rhs.iter()).all(|(l, r)| l == r) + } +} + +#[cfg(feature = "object")] +impl TotalHash for AnyValueRow<'_> { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.0.iter().for_each(|av| av.hash(state)) + } +} + +impl DataFrame { + #[cfg(feature = "object")] + #[allow(clippy::wrong_self_convention)] + // Create indexable rows in a single allocation. + pub(crate) fn to_av_rows(&mut self) -> AnyValueRows<'_> { + self.as_single_chunk_par(); + let width = self.width(); + let size = width * self.height(); + let mut buf = vec![AnyValue::Null; size]; + for (col_i, s) in self.materialized_column_iter().enumerate() { + match s.dtype() { + #[cfg(feature = "object")] + DataType::Object(_) => { + for row_i in 0..s.len() { + let av = s.get(row_i).unwrap(); + buf[row_i * width + col_i] = av + } + }, + _ => { + for (row_i, av) in s.iter().enumerate() { + buf[row_i * width + col_i] = av + } + }, + } + } + AnyValueRows { vals: buf, width } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Row<'a>(pub Vec>); + +impl<'a> Row<'a> { + pub fn new(values: Vec>) -> Self { + Row(values) + } +} + +type Tracker = PlIndexMap>; + +pub fn infer_schema( + iter: impl Iterator, impl Into)>>, + infer_schema_length: usize, +) -> Schema { + let mut values: Tracker = Tracker::default(); + let len = iter.size_hint().1.unwrap_or(infer_schema_length); + + let max_infer = std::cmp::min(len, infer_schema_length); + for inner in iter.take(max_infer) { + for (key, value) in inner { + add_or_insert(&mut values, key.into(), value.into()); + } + } + Schema::from_iter(resolve_fields(values)) +} + +fn add_or_insert(values: &mut Tracker, key: PlSmallStr, dtype: DataType) { + if values.contains_key(&key) { + let x = values.get_mut(&key).unwrap(); + x.insert(dtype); + } else { + // create hashset and add value type + let mut hs = PlHashSet::new(); + hs.insert(dtype); + values.insert(key, hs); + } +} + +fn resolve_fields(spec: Tracker) -> Vec { + spec.iter() + .map(|(k, hs)| { + let v: Vec<&DataType> = hs.iter().collect(); + Field::new(k.clone(), coerce_dtype(&v)) + }) + .collect() +} + +/// Coerces a slice of datatypes into a single supertype. +pub fn coerce_dtype>(datatypes: &[A]) -> DataType { + use DataType::*; + + let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow()); + + if are_all_equal { + return datatypes[0].borrow().clone(); + } + if datatypes.len() > 2 { + return String; + } + + let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow()); + try_get_supertype(lhs, rhs).unwrap_or(String) +} + +/// Infer the schema of rows by determining the supertype of the values. +/// +/// Field names are set as `column_0`, `column_1`, and so on. +pub fn rows_to_schema_supertypes( + rows: &[Row], + infer_schema_length: Option, +) -> PolarsResult { + let dtypes = rows_to_supertypes(rows, infer_schema_length)?; + let schema = dtypes_to_schema(dtypes); + Ok(schema) +} + +/// Infer the schema data types of rows by determining the supertype of the values. +pub fn rows_to_supertypes( + rows: &[Row], + infer_schema_length: Option, +) -> PolarsResult> { + polars_ensure!(!rows.is_empty(), NoData: "no rows, cannot infer schema"); + + let max_infer = infer_schema_length.unwrap_or(rows.len()); + + let mut dtypes: Vec> = vec![PlIndexSet::new(); rows[0].0.len()]; + for row in rows.iter().take(max_infer) { + for (val, dtypes_set) in row.0.iter().zip(dtypes.iter_mut()) { + dtypes_set.insert(val.into()); + } + } + + dtypes + .into_iter() + .map(|dtypes_set| dtypes_to_supertype(&dtypes_set)) + .collect() +} + +/// Infer schema from rows and set the first no null type as column data type. +pub fn rows_to_schema_first_non_null( + rows: &[Row], + infer_schema_length: Option, +) -> PolarsResult { + polars_ensure!(!rows.is_empty(), NoData: "no rows, cannot infer schema"); + + let max_infer = infer_schema_length.unwrap_or(rows.len()); + let mut schema: Schema = (&rows[0]).into(); + + // the first row that has no nulls will be used to infer the schema. + // if there is a null, we check the next row and see if we can update the schema + + for row in rows.iter().take(max_infer).skip(1) { + // for i in 1..max_infer { + let nulls: Vec<_> = schema + .iter_values() + .enumerate() + .filter_map(|(i, dtype)| { + // double check struct and list types + // nested null values can be wrongly inferred by front ends + match dtype { + DataType::Null | DataType::List(_) => Some(i), + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => Some(i), + _ => None, + } + }) + .collect(); + if nulls.is_empty() { + break; + } else { + for i in nulls { + let val = &row.0[i]; + + if !val.is_nested_null() { + let dtype = val.into(); + schema.set_dtype_at_index(i, dtype).unwrap(); + } + } + } + } + Ok(schema) +} + +impl<'a> From<&AnyValue<'a>> for Field { + fn from(val: &AnyValue<'a>) -> Self { + Field::new(PlSmallStr::EMPTY, val.into()) + } +} + +impl From<&Row<'_>> for Schema { + fn from(row: &Row) -> Self { + row.0 + .iter() + .enumerate() + .map(|(i, av)| { + let dtype = av.into(); + Field::new(format_pl_smallstr!("column_{i}"), dtype) + }) + .collect() + } +} diff --git a/crates/polars-core/src/frame/row/transpose.rs b/crates/polars-core/src/frame/row/transpose.rs new file mode 100644 index 000000000000..3ba41290a575 --- /dev/null +++ b/crates/polars-core/src/frame/row/transpose.rs @@ -0,0 +1,334 @@ +use std::borrow::Cow; + +use either::Either; + +use super::*; + +impl DataFrame { + pub(crate) fn transpose_from_dtype( + &self, + dtype: &DataType, + keep_names_as: Option, + names_out: &[PlSmallStr], + ) -> PolarsResult { + let new_width = self.height(); + let new_height = self.width(); + // Allocate space for the transposed columns, putting the "row names" first if needed + let mut cols_t = match keep_names_as { + None => Vec::::with_capacity(new_width), + Some(name) => { + let mut tmp = Vec::::with_capacity(new_width + 1); + tmp.push( + StringChunked::from_iter_values( + name, + self.get_column_names_owned().into_iter(), + ) + .into_column(), + ); + tmp + }, + }; + + let cols = &self.columns; + match dtype { + #[cfg(feature = "dtype-i8")] + DataType::Int8 => numeric_transpose::(cols, names_out, &mut cols_t), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => numeric_transpose::(cols, names_out, &mut cols_t), + DataType::Int32 => numeric_transpose::(cols, names_out, &mut cols_t), + DataType::Int64 => numeric_transpose::(cols, names_out, &mut cols_t), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => numeric_transpose::(cols, names_out, &mut cols_t), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => numeric_transpose::(cols, names_out, &mut cols_t), + DataType::UInt32 => numeric_transpose::(cols, names_out, &mut cols_t), + DataType::UInt64 => numeric_transpose::(cols, names_out, &mut cols_t), + DataType::Float32 => numeric_transpose::(cols, names_out, &mut cols_t), + DataType::Float64 => numeric_transpose::(cols, names_out, &mut cols_t), + #[cfg(feature = "object")] + DataType::Object(_) => { + // this requires to support `Object` in Series::iter which we don't yet + polars_bail!(InvalidOperation: "Object dtype not supported in 'transpose'") + }, + _ => { + let phys_dtype = dtype.to_physical(); + let mut buffers = (0..new_width) + .map(|_| { + let buf: AnyValueBufferTrusted = (&phys_dtype, new_height).into(); + buf + }) + .collect::>(); + + let columns = self + .materialized_column_iter() + // first cast to supertype before casting to physical to ensure units are correct + .map(|s| s.cast(dtype).unwrap().cast(&phys_dtype).unwrap()) + .collect::>(); + + // this is very expensive. A lot of cache misses here. + // This is the part that is performance critical. + for s in columns { + polars_ensure!(s.dtype() == &phys_dtype, ComputeError: "cannot transpose with supertype: {}", dtype); + s.iter().zip(buffers.iter_mut()).for_each(|(av, buf)| { + // SAFETY: we checked the type and we borrow + unsafe { + buf.add_unchecked_borrowed_physical(&av); + } + }); + } + cols_t.extend(buffers.into_iter().zip(names_out).map(|(buf, name)| { + // SAFETY: we are casting back to the supertype + let mut s = unsafe { buf.into_series().cast_unchecked(dtype).unwrap() }; + s.rename(name.clone()); + s.into() + })); + }, + }; + Ok(unsafe { DataFrame::new_no_checks(new_height, cols_t) }) + } + + pub fn transpose( + &mut self, + keep_names_as: Option<&str>, + new_col_names: Option>>, + ) -> PolarsResult { + let new_col_names = match new_col_names { + None => None, + Some(Either::Left(v)) => Some(Either::Left(v.into())), + Some(Either::Right(v)) => Some(Either::Right( + v.into_iter().map(Into::into).collect::>(), + )), + }; + + self.transpose_impl(keep_names_as, new_col_names) + } + /// Transpose a DataFrame. This is a very expensive operation. + pub fn transpose_impl( + &mut self, + keep_names_as: Option<&str>, + new_col_names: Option>>, + ) -> PolarsResult { + // We must iterate columns as [`AnyValue`], so we must be contiguous. + self.as_single_chunk_par(); + + let mut df = Cow::Borrowed(self); // Can't use self because we might drop a name column + let names_out = match new_col_names { + None => (0..self.height()) + .map(|i| format_pl_smallstr!("column_{i}")) + .collect(), + Some(cn) => match cn { + Either::Left(name) => { + let new_names = self.column(name.as_str()).and_then(|x| x.str())?; + polars_ensure!(new_names.null_count() == 0, ComputeError: "Column with new names can't have null values"); + df = Cow::Owned(self.drop(name.as_str())?); + new_names + .into_no_null_iter() + .map(PlSmallStr::from_str) + .collect() + }, + Either::Right(names) => { + polars_ensure!(names.len() == self.height(), ShapeMismatch: "Length of new column names must be the same as the row count"); + names + }, + }, + }; + if let Some(cn) = keep_names_as { + // Check that the column name we're using for the original column names is unique before + // wasting time transposing + polars_ensure!(names_out.iter().all(|a| a.as_str() != cn), Duplicate: "{} is already in output column names", cn) + } + polars_ensure!( + df.height() != 0 && df.width() != 0, + NoData: "unable to transpose an empty DataFrame" + ); + let dtype = df.get_supertype().unwrap()?; + match dtype { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + let mut valid = true; + let mut rev_map: Option<&Arc> = None; + for s in self.columns.iter() { + if let DataType::Categorical(Some(col_rev_map), _) + | DataType::Enum(Some(col_rev_map), _) = &s.dtype() + { + match rev_map { + Some(rev_map) => valid = valid && rev_map.same_src(col_rev_map), + None => { + rev_map = Some(col_rev_map); + }, + } + } + } + polars_ensure!(valid, string_cache_mismatch); + }, + _ => {}, + } + df.transpose_from_dtype(&dtype, keep_names_as.map(PlSmallStr::from_str), &names_out) + } +} + +#[inline] +unsafe fn add_value( + values_buf_ptr: usize, + col_idx: usize, + row_idx: usize, + value: T, +) { + let column = (*(values_buf_ptr as *mut Vec>)).get_unchecked_mut(col_idx); + let el_ptr = column.as_mut_ptr(); + *el_ptr.add(row_idx) = value; +} + +// This just fills a pre-allocated mutable series vector, which may have a name column. +// Nothing is returned and the actual DataFrame is constructed above. +pub(super) fn numeric_transpose( + cols: &[Column], + names_out: &[PlSmallStr], + cols_t: &mut Vec, +) where + T: PolarsNumericType, + //S: AsRef, + ChunkedArray: IntoSeries, +{ + let new_width = cols[0].len(); + let new_height = cols.len(); + + let has_nulls = cols.iter().any(|s| s.null_count() > 0); + + let mut values_buf: Vec> = (0..new_width) + .map(|_| Vec::with_capacity(new_height)) + .collect(); + let mut validity_buf: Vec<_> = if has_nulls { + // we first use bools instead of bits, because we can access these in parallel without aliasing + (0..new_width).map(|_| vec![true; new_height]).collect() + } else { + (0..new_width).map(|_| vec![]).collect() + }; + + // work with *mut pointers because we it is UB write to &refs. + let values_buf_ptr = &mut values_buf as *mut Vec> as usize; + let validity_buf_ptr = &mut validity_buf as *mut Vec> as usize; + + POOL.install(|| { + cols.iter() + .map(Column::as_materialized_series) + .enumerate() + .for_each(|(row_idx, s)| { + let s = s.cast(&T::get_dtype()).unwrap(); + let ca = s.unpack::().unwrap(); + + // SAFETY: + // we access in parallel, but every access is unique, so we don't break aliasing rules + // we also ensured we allocated enough memory, so we never reallocate and thus + // the pointers remain valid. + if has_nulls { + for (col_idx, opt_v) in ca.iter().enumerate() { + match opt_v { + None => unsafe { + let column = (*(validity_buf_ptr as *mut Vec>)) + .get_unchecked_mut(col_idx); + let el_ptr = column.as_mut_ptr(); + *el_ptr.add(row_idx) = false; + // we must initialize this memory otherwise downstream code + // might access uninitialized memory when the masked out values + // are changed. + add_value(values_buf_ptr, col_idx, row_idx, T::Native::default()); + }, + Some(v) => unsafe { + add_value(values_buf_ptr, col_idx, row_idx, v); + }, + } + } + } else { + for (col_idx, v) in ca.into_no_null_iter().enumerate() { + unsafe { + let column = (*(values_buf_ptr as *mut Vec>)) + .get_unchecked_mut(col_idx); + let el_ptr = column.as_mut_ptr(); + *el_ptr.add(row_idx) = v; + } + } + } + }) + }); + + let par_iter = values_buf + .into_par_iter() + .zip(validity_buf) + .zip(names_out) + .map(|((mut values, validity), name)| { + // SAFETY: + // all values are written we can now set len + unsafe { + values.set_len(new_height); + } + + let validity = if has_nulls { + let validity = Bitmap::from_trusted_len_iter(validity.iter().copied()); + if validity.unset_bits() > 0 { + Some(validity) + } else { + None + } + } else { + None + }; + + let arr = PrimitiveArray::::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + values.into(), + validity, + ); + ChunkedArray::with_chunk(name.clone(), arr).into_column() + }); + POOL.install(|| cols_t.par_extend(par_iter)); +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_transpose() -> PolarsResult<()> { + let mut df = df![ + "a" => [1, 2, 3], + "b" => [10, 20, 30], + ]?; + + let out = df.transpose(None, None)?; + let expected = df![ + "column_0" => [1, 10], + "column_1" => [2, 20], + "column_2" => [3, 30], + + ]?; + assert!(out.equals_missing(&expected)); + + let mut df = df![ + "a" => [Some(1), None, Some(3)], + "b" => [Some(10), Some(20), None], + ]?; + let out = df.transpose(None, None)?; + let expected = df![ + "column_0" => [1, 10], + "column_1" => [None, Some(20)], + "column_2" => [Some(3), None], + + ]?; + assert!(out.equals_missing(&expected)); + + let mut df = df![ + "a" => ["a", "b", "c"], + "b" => [Some(10), Some(20), None], + ]?; + let out = df.transpose(None, None)?; + let expected = df![ + "column_0" => ["a", "10"], + "column_1" => ["b", "20"], + "column_2" => [Some("c"), None], + + ]?; + assert!(out.equals_missing(&expected)); + Ok(()) + } +} diff --git a/crates/polars-core/src/frame/top_k.rs b/crates/polars-core/src/frame/top_k.rs new file mode 100644 index 000000000000..dd610a2383d4 --- /dev/null +++ b/crates/polars-core/src/frame/top_k.rs @@ -0,0 +1,32 @@ +use super::*; +use crate::prelude::sort::arg_bottom_k::_arg_bottom_k; + +impl DataFrame { + pub(crate) fn bottom_k_impl( + &self, + k: usize, + by_column: Vec, + mut sort_options: SortMultipleOptions, + ) -> PolarsResult { + let first_descending = sort_options.descending[0]; + let first_by_column = by_column[0].name().to_string(); + + let idx = _arg_bottom_k(k, &by_column, &mut sort_options)?; + + let mut df = unsafe { self.take_unchecked(&idx.into_inner()) }; + + // Mark the first sort column as sorted + // if the column did not exists it is ok, because we sorted by an expression + // not present in the dataframe + let _ = df.apply(&first_by_column, |s| { + let mut s = s.clone(); + if first_descending { + s.set_sorted_flag(IsSorted::Descending) + } else { + s.set_sorted_flag(IsSorted::Ascending) + } + s + }); + Ok(df) + } +} diff --git a/crates/polars-core/src/frame/upstream_traits.rs b/crates/polars-core/src/frame/upstream_traits.rs new file mode 100644 index 000000000000..1392f87c052f --- /dev/null +++ b/crates/polars-core/src/frame/upstream_traits.rs @@ -0,0 +1,88 @@ +use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; + +use arrow::record_batch::RecordBatchT; + +use crate::prelude::*; + +impl FromIterator for DataFrame { + /// # Panics + /// + /// Panics if Series have different lengths. + fn from_iter>(iter: T) -> Self { + let v = iter.into_iter().map(Column::from).collect(); + DataFrame::new(v).expect("could not create DataFrame from iterator") + } +} + +impl FromIterator for DataFrame { + /// # Panics + /// + /// Panics if Column have different lengths. + fn from_iter>(iter: T) -> Self { + let v = iter.into_iter().collect(); + DataFrame::new(v).expect("could not create DataFrame from iterator") + } +} + +impl TryExtend>> for DataFrame { + fn try_extend>>>( + &mut self, + iter: I, + ) -> PolarsResult<()> { + for record_batch in iter { + self.append_record_batch(record_batch)?; + } + + Ok(()) + } +} + +impl TryExtend>>> for DataFrame { + fn try_extend>>>>( + &mut self, + iter: I, + ) -> PolarsResult<()> { + for record_batch in iter { + self.append_record_batch(record_batch?)?; + } + + Ok(()) + } +} + +impl Index for DataFrame { + type Output = Column; + + fn index(&self, index: usize) -> &Self::Output { + &self.columns[index] + } +} + +macro_rules! impl_ranges { + ($range_type:ty) => { + impl Index<$range_type> for DataFrame { + type Output = [Column]; + + fn index(&self, index: $range_type) -> &Self::Output { + &self.columns[index] + } + } + }; +} + +impl_ranges!(Range); +impl_ranges!(RangeInclusive); +impl_ranges!(RangeFrom); +impl_ranges!(RangeTo); +impl_ranges!(RangeToInclusive); +impl_ranges!(RangeFull); + +// we don't implement Borrow or AsRef as upstream crates may add impl of trait for usize. +impl Index<&str> for DataFrame { + type Output = Column; + + fn index(&self, index: &str) -> &Self::Output { + let idx = self.check_name_to_idx(index).unwrap(); + &self.columns[idx] + } +} diff --git a/crates/polars-core/src/frame/validation.rs b/crates/polars-core/src/frame/validation.rs new file mode 100644 index 000000000000..358924428934 --- /dev/null +++ b/crates/polars-core/src/frame/validation.rs @@ -0,0 +1,67 @@ +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; + +use super::DataFrame; +use super::column::Column; + +impl DataFrame { + /// Ensure all equal height and names are unique. + /// + /// An Ok() result indicates `columns` is a valid state for a DataFrame. + pub fn validate_columns_slice(columns: &[Column]) -> PolarsResult<()> { + if columns.len() <= 1 { + return Ok(()); + } + + if columns.len() <= 4 { + // Too small to be worth spawning a hashmap for, this is at most 6 comparisons. + for i in 0..columns.len() - 1 { + let name = columns[i].name(); + let height = columns[i].len(); + + for other in columns.iter().skip(i + 1) { + if other.name() == name { + polars_bail!(duplicate = name); + } + + if other.len() != height { + polars_bail!( + ShapeMismatch: + "height of column '{}' ({}) does not match height of column '{}' ({})", + other.name(), other.len(), name, height + ) + } + } + } + } else { + let first = &columns[0]; + + let first_len = first.len(); + let first_name = first.name(); + + let mut names = PlHashSet::with_capacity(columns.len()); + names.insert(first_name); + + for col in &columns[1..] { + let col_name = col.name(); + let col_len = col.len(); + + if col_len != first_len { + polars_bail!( + ShapeMismatch: + "height of column '{}' ({}) does not match height of column '{}' ({})", + col_name, col_len, first_name, first_len + ) + } + + if names.contains(col_name) { + polars_bail!(duplicate = col_name) + } + + names.insert(col_name); + } + } + + Ok(()) + } +} diff --git a/crates/polars-core/src/functions.rs b/crates/polars-core/src/functions.rs new file mode 100644 index 000000000000..ebc5006e0493 --- /dev/null +++ b/crates/polars-core/src/functions.rs @@ -0,0 +1,46 @@ +//! # Functions +//! +//! Functions that might be useful. +//! +pub use crate::frame::horizontal::concat_df_horizontal; +#[cfg(feature = "diagonal_concat")] +use crate::prelude::*; +#[cfg(feature = "diagonal_concat")] +use crate::utils::concat_df; + +/// Concat [`DataFrame`]s diagonally. +#[cfg(feature = "diagonal_concat")] +/// Concat diagonally thereby combining different schemas. +pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { + // TODO! replace with lazy only? + let upper_bound_width = dfs.iter().map(|df| df.width()).sum(); + let mut column_names = PlHashSet::with_capacity(upper_bound_width); + let mut schema = Vec::with_capacity(upper_bound_width); + + for df in dfs { + df.get_columns().iter().for_each(|s| { + let name = s.name().clone(); + if column_names.insert(name.clone()) { + schema.push((name, s.dtype())) + } + }); + } + + let dfs = dfs + .iter() + .map(|df| { + let height = df.height(); + let mut columns = Vec::with_capacity(schema.len()); + + for (name, dtype) in &schema { + match df.column(name.as_str()).ok() { + Some(s) => columns.push(s.clone()), + None => columns.push(Column::full_null(name.clone(), height, dtype)), + } + } + unsafe { DataFrame::new_no_checks(height, columns) } + }) + .collect::>(); + + concat_df(&dfs) +} diff --git a/crates/polars-core/src/hashing/identity.rs b/crates/polars-core/src/hashing/identity.rs new file mode 100644 index 000000000000..a1ae697106f9 --- /dev/null +++ b/crates/polars-core/src/hashing/identity.rs @@ -0,0 +1,35 @@ +use super::*; + +#[derive(Default)] +pub struct IdHasher { + hash: u64, +} + +impl Hasher for IdHasher { + fn finish(&self) -> u64 { + self.hash + } + + fn write(&mut self, _bytes: &[u8]) { + unreachable!("IdHasher should only be used for integer keys <= 64 bit precision") + } + + fn write_u32(&mut self, i: u32) { + self.hash = i as u64; + } + + #[inline] + fn write_u64(&mut self, i: u64) { + self.hash = i; + } + + fn write_i32(&mut self, i: i32) { + self.hash = i as u64; + } + + fn write_i64(&mut self, i: i64) { + self.hash = i as u64; + } +} + +pub type IdBuildHasher = BuildHasherDefault; diff --git a/crates/polars-core/src/hashing/mod.rs b/crates/polars-core/src/hashing/mod.rs new file mode 100644 index 000000000000..854b7d29737c --- /dev/null +++ b/crates/polars-core/src/hashing/mod.rs @@ -0,0 +1,11 @@ +mod identity; +pub(crate) mod vector_hasher; + +use std::hash::{BuildHasherDefault, Hash, Hasher}; + +pub use identity::*; +pub use vector_hasher::*; + +// We must strike a balance between cache +// Overallocation seems a lot more expensive than resizing so we start reasonable small. +pub const _HASHMAP_INIT_SIZE: usize = 512; diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs new file mode 100644 index 000000000000..bbefac0fd591 --- /dev/null +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -0,0 +1,515 @@ +use std::hash::BuildHasher; + +use arrow::bitmap::utils::get_bit_unchecked; +use polars_utils::aliases::PlSeedableRandomStateQuality; +use polars_utils::hashing::{_boost_hash_combine, folded_multiply}; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; +use rayon::prelude::*; +use xxhash_rust::xxh3::xxh3_64_with_seed; + +use super::*; +use crate::POOL; +use crate::prelude::*; +use crate::series::implementations::null::NullChunked; + +// See: https://github.com/tkaitchuck/aHash/blob/f9acd508bd89e7c5b2877a9510098100f9018d64/src/operations.rs#L4 +const MULTIPLE: u64 = 6364136223846793005; + +// Read more: +// https://www.cockroachlabs.com/blog/vectorized-hash-joiner/ +// http://myeyesareblind.com/2017/02/06/Combine-hash-values/ + +pub trait VecHash { + /// Compute the hash for all values in the array. + fn vec_hash( + &self, + _random_state: PlSeedableRandomStateQuality, + _buf: &mut Vec, + ) -> PolarsResult<()> { + polars_bail!(un_impl = vec_hash); + } + + fn vec_hash_combine( + &self, + _random_state: PlSeedableRandomStateQuality, + _hashes: &mut [u64], + ) -> PolarsResult<()> { + polars_bail!(un_impl = vec_hash_combine); + } +} + +pub(crate) fn get_null_hash_value(random_state: &PlSeedableRandomStateQuality) -> u64 { + // we just start with a large prime number and hash that twice + // to get a constant hash value for null/None + let first = random_state.hash_one(3188347919usize); + random_state.hash_one(first) +} + +fn insert_null_hash( + chunks: &[ArrayRef], + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, +) { + let null_h = get_null_hash_value(&random_state); + let hashes = buf.as_mut_slice(); + + let mut offset = 0; + chunks.iter().for_each(|arr| { + if arr.null_count() > 0 { + let validity = arr.validity().unwrap(); + let (slice, byte_offset, _) = validity.as_slice(); + (0..validity.len()) + .map(|i| unsafe { get_bit_unchecked(slice, i + byte_offset) }) + .zip(&mut hashes[offset..]) + .for_each(|(valid, h)| { + *h = [null_h, *h][valid as usize]; + }) + } + offset += arr.len(); + }); +} + +fn numeric_vec_hash( + ca: &ChunkedArray, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, +) where + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, +{ + // Note that we don't use the no null branch! This can break in unexpected ways. + // for instance with threading we split an array in n_threads, this may lead to + // splits that have no nulls and splits that have nulls. Then one array is hashed with + // Option and the other array with T. + // Meaning that they cannot be compared. By always hashing on Option the random_state is + // the only deterministic seed. + buf.clear(); + buf.reserve(ca.len()); + + #[allow(unused_unsafe)] + #[allow(clippy::useless_transmute)] + ca.downcast_iter().for_each(|arr| { + buf.extend( + arr.values() + .as_slice() + .iter() + .copied() + .map(|v| random_state.hash_one(v.to_total_ord())), + ); + }); + insert_null_hash(&ca.chunks, random_state, buf) +} + +fn numeric_vec_hash_combine( + ca: &ChunkedArray, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], +) where + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, +{ + let null_h = get_null_hash_value(&random_state); + + let mut offset = 0; + ca.downcast_iter().for_each(|arr| { + match arr.null_count() { + 0 => arr + .values() + .as_slice() + .iter() + .zip(&mut hashes[offset..]) + .for_each(|(v, h)| { + // Inlined from ahash. This ensures we combine with the previous state. + *h = folded_multiply( + // Be careful not to xor the hash directly with the existing hash, + // it would lead to 0-hashes for 2 columns containing equal values. + random_state.hash_one(v.to_total_ord()) ^ folded_multiply(*h, MULTIPLE), + MULTIPLE, + ); + }), + _ => { + let validity = arr.validity().unwrap(); + let (slice, byte_offset, _) = validity.as_slice(); + (0..validity.len()) + .map(|i| unsafe { get_bit_unchecked(slice, i + byte_offset) }) + .zip(&mut hashes[offset..]) + .zip(arr.values().as_slice()) + .for_each(|((valid, h), l)| { + let lh = random_state.hash_one(l.to_total_ord()); + let to_hash = [null_h, lh][valid as usize]; + *h = folded_multiply(to_hash ^ folded_multiply(*h, MULTIPLE), MULTIPLE); + }); + }, + } + offset += arr.len(); + }); +} + +macro_rules! vec_hash_numeric { + ($ca:ident) => { + impl VecHash for $ca { + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + numeric_vec_hash(self, random_state, buf); + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + numeric_vec_hash_combine(self, random_state, hashes); + Ok(()) + } + } + }; +} + +vec_hash_numeric!(Int64Chunked); +vec_hash_numeric!(Int32Chunked); +vec_hash_numeric!(Int16Chunked); +vec_hash_numeric!(Int8Chunked); +vec_hash_numeric!(UInt64Chunked); +vec_hash_numeric!(UInt32Chunked); +vec_hash_numeric!(UInt16Chunked); +vec_hash_numeric!(UInt8Chunked); +vec_hash_numeric!(Float64Chunked); +vec_hash_numeric!(Float32Chunked); +#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))] +vec_hash_numeric!(Int128Chunked); + +impl VecHash for StringChunked { + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.as_binary().vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.as_binary().vec_hash_combine(random_state, hashes)?; + Ok(()) + } +} + +// used in polars-pipe +pub fn _hash_binary_array( + arr: &BinaryArray, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, +) { + let null_h = get_null_hash_value(&random_state); + if arr.null_count() == 0 { + // use the null_hash as seed to get a hash determined by `random_state` that is passed + buf.extend(arr.values_iter().map(|v| xxh3_64_with_seed(v, null_h))) + } else { + buf.extend(arr.into_iter().map(|opt_v| match opt_v { + Some(v) => xxh3_64_with_seed(v, null_h), + None => null_h, + })) + } +} + +fn hash_binview_array( + arr: &BinaryViewArray, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, +) { + let null_h = get_null_hash_value(&random_state); + if arr.null_count() == 0 { + // use the null_hash as seed to get a hash determined by `random_state` that is passed + buf.extend(arr.values_iter().map(|v| xxh3_64_with_seed(v, null_h))) + } else { + buf.extend(arr.into_iter().map(|opt_v| match opt_v { + Some(v) => xxh3_64_with_seed(v, null_h), + None => null_h, + })) + } +} + +impl VecHash for BinaryChunked { + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + buf.clear(); + buf.reserve(self.len()); + self.downcast_iter() + .for_each(|arr| hash_binview_array(arr, random_state, buf)); + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + let null_h = get_null_hash_value(&random_state); + + let mut offset = 0; + self.downcast_iter().for_each(|arr| { + match arr.null_count() { + 0 => arr + .values_iter() + .zip(&mut hashes[offset..]) + .for_each(|(v, h)| { + let l = xxh3_64_with_seed(v, null_h); + *h = _boost_hash_combine(l, *h) + }), + _ => { + let validity = arr.validity().unwrap(); + let (slice, byte_offset, _) = validity.as_slice(); + (0..validity.len()) + .map(|i| unsafe { get_bit_unchecked(slice, i + byte_offset) }) + .zip(&mut hashes[offset..]) + .zip(arr.values_iter()) + .for_each(|((valid, h), l)| { + let l = if valid { + xxh3_64_with_seed(l, null_h) + } else { + null_h + }; + *h = _boost_hash_combine(l, *h) + }); + }, + } + offset += arr.len(); + }); + Ok(()) + } +} + +impl VecHash for BinaryOffsetChunked { + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + buf.clear(); + buf.reserve(self.len()); + self.downcast_iter() + .for_each(|arr| _hash_binary_array(arr, random_state, buf)); + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + let null_h = get_null_hash_value(&random_state); + + let mut offset = 0; + self.downcast_iter().for_each(|arr| { + match arr.null_count() { + 0 => arr + .values_iter() + .zip(&mut hashes[offset..]) + .for_each(|(v, h)| { + let l = xxh3_64_with_seed(v, null_h); + *h = _boost_hash_combine(l, *h) + }), + _ => { + let validity = arr.validity().unwrap(); + let (slice, byte_offset, _) = validity.as_slice(); + (0..validity.len()) + .map(|i| unsafe { get_bit_unchecked(slice, i + byte_offset) }) + .zip(&mut hashes[offset..]) + .zip(arr.values_iter()) + .for_each(|((valid, h), l)| { + let l = if valid { + xxh3_64_with_seed(l, null_h) + } else { + null_h + }; + *h = _boost_hash_combine(l, *h) + }); + }, + } + offset += arr.len(); + }); + Ok(()) + } +} + +impl VecHash for NullChunked { + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + let null_h = get_null_hash_value(&random_state); + buf.clear(); + buf.resize(self.len(), null_h); + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + let null_h = get_null_hash_value(&random_state); + hashes + .iter_mut() + .for_each(|h| *h = _boost_hash_combine(null_h, *h)); + Ok(()) + } +} +impl VecHash for BooleanChunked { + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + buf.clear(); + buf.reserve(self.len()); + let true_h = random_state.hash_one(true); + let false_h = random_state.hash_one(false); + let null_h = get_null_hash_value(&random_state); + self.downcast_iter().for_each(|arr| { + if arr.null_count() == 0 { + buf.extend(arr.values_iter().map(|v| if v { true_h } else { false_h })) + } else { + buf.extend(arr.into_iter().map(|opt_v| match opt_v { + Some(true) => true_h, + Some(false) => false_h, + None => null_h, + })) + } + }); + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + let true_h = random_state.hash_one(true); + let false_h = random_state.hash_one(false); + let null_h = get_null_hash_value(&random_state); + + let mut offset = 0; + self.downcast_iter().for_each(|arr| { + match arr.null_count() { + 0 => arr + .values_iter() + .zip(&mut hashes[offset..]) + .for_each(|(v, h)| { + let l = if v { true_h } else { false_h }; + *h = _boost_hash_combine(l, *h) + }), + _ => { + let validity = arr.validity().unwrap(); + let (slice, byte_offset, _) = validity.as_slice(); + (0..validity.len()) + .map(|i| unsafe { get_bit_unchecked(slice, i + byte_offset) }) + .zip(&mut hashes[offset..]) + .zip(arr.values()) + .for_each(|((valid, h), l)| { + let l = if valid { + if l { true_h } else { false_h } + } else { + null_h + }; + *h = _boost_hash_combine(l, *h) + }); + }, + } + offset += arr.len(); + }); + Ok(()) + } +} + +#[cfg(feature = "object")] +impl VecHash for ObjectChunked +where + T: PolarsObject, +{ + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + // Note that we don't use the no null branch! This can break in unexpected ways. + // for instance with threading we split an array in n_threads, this may lead to + // splits that have no nulls and splits that have nulls. Then one array is hashed with + // Option and the other array with T. + // Meaning that they cannot be compared. By always hashing on Option the random_state is + // the only deterministic seed. + buf.clear(); + buf.reserve(self.len()); + + self.downcast_iter() + .for_each(|arr| buf.extend(arr.into_iter().map(|opt_v| random_state.hash_one(opt_v)))); + + Ok(()) + } + + fn vec_hash_combine( + &self, + random_state: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.apply_to_slice( + |opt_v, h| { + let hashed = random_state.hash_one(opt_v); + _boost_hash_combine(hashed, *h) + }, + hashes, + ); + Ok(()) + } +} + +pub fn _df_rows_to_hashes_threaded_vertical( + keys: &[DataFrame], + build_hasher: Option, +) -> PolarsResult<(Vec, PlSeedableRandomStateQuality)> { + let build_hasher = build_hasher.unwrap_or_default(); + + let hashes = POOL.install(|| { + keys.into_par_iter() + .map(|df| { + let hb = build_hasher; + let mut hashes = vec![]; + columns_to_hashes(df.get_columns(), Some(hb), &mut hashes)?; + Ok(UInt64Chunked::from_vec(PlSmallStr::EMPTY, hashes)) + }) + .collect::>>() + })?; + Ok((hashes, build_hasher)) +} + +pub fn columns_to_hashes( + keys: &[Column], + build_hasher: Option, + hashes: &mut Vec, +) -> PolarsResult { + let build_hasher = build_hasher.unwrap_or_default(); + + let mut iter = keys.iter(); + let first = iter.next().expect("at least one key"); + first.vec_hash(build_hasher, hashes)?; + + for keys in iter { + keys.vec_hash_combine(build_hasher, hashes)?; + } + + Ok(build_hasher) +} diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs new file mode 100644 index 000000000000..3c434a5ee6ed --- /dev/null +++ b/crates/polars-core/src/lib.rs @@ -0,0 +1,85 @@ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![allow(ambiguous_glob_reexports)] +#![cfg_attr(feature = "nightly", allow(clippy::non_canonical_partial_ord_impl))] // remove once stable +extern crate core; + +#[macro_use] +pub mod utils; +pub mod chunked_array; +pub mod config; +pub mod datatypes; +pub mod error; +pub mod fmt; +pub mod frame; +pub mod functions; +pub mod hashing; +mod named_from; +pub mod prelude; +#[cfg(feature = "random")] +pub mod random; +pub mod scalar; +pub mod schema; +#[cfg(feature = "serde")] +pub mod serde; +pub mod series; +pub mod testing; +#[cfg(test)] +mod tests; + +use std::sync::{LazyLock, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub use datatypes::SchemaExtPl; +pub use hashing::IdBuildHasher; +use rayon::{ThreadPool, ThreadPoolBuilder}; + +#[cfg(feature = "dtype-categorical")] +pub use crate::chunked_array::logical::categorical::string_cache::*; + +pub static PROCESS_ID: LazyLock = LazyLock::new(|| { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() +}); + +// this is re-exported in utils for polars child crates +#[cfg(not(target_family = "wasm"))] // only use this on non wasm targets +pub static POOL: LazyLock = LazyLock::new(|| { + let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string()); + ThreadPoolBuilder::new() + .num_threads( + std::env::var("POLARS_MAX_THREADS") + .map(|s| s.parse::().expect("integer")) + .unwrap_or_else(|_| { + std::thread::available_parallelism() + .unwrap_or(std::num::NonZeroUsize::new(1).unwrap()) + .get() + }), + ) + .thread_name(move |i| format!("{}-{}", thread_name, i)) + .build() + .expect("could not spawn threads") +}); + +#[cfg(all(target_os = "emscripten", target_family = "wasm"))] // Use 1 rayon thread on emscripten +pub static POOL: LazyLock = LazyLock::new(|| { + ThreadPoolBuilder::new() + .num_threads(1) + .use_current_thread() + .build() + .expect("could not create pool") +}); + +#[cfg(all(not(target_os = "emscripten"), target_family = "wasm"))] // use this on other wasm targets +pub static POOL: LazyLock = LazyLock::new(|| polars_utils::wasm::Pool); + +// utility for the tests to ensure a single thread can execute +pub static SINGLE_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); + +/// Default length for a `.head()` call +pub(crate) const HEAD_DEFAULT_LENGTH: usize = 10; +/// Default length for a `.tail()` call +pub(crate) const TAIL_DEFAULT_LENGTH: usize = 10; +pub const CHEAP_SERIES_HASH_LIMIT: usize = 1000; diff --git a/crates/polars-core/src/named_from.rs b/crates/polars-core/src/named_from.rs new file mode 100644 index 000000000000..0f0a9b5562a7 --- /dev/null +++ b/crates/polars-core/src/named_from.rs @@ -0,0 +1,473 @@ +use std::borrow::Cow; + +#[cfg(feature = "dtype-duration")] +use chrono::Duration as ChronoDuration; +#[cfg(feature = "dtype-date")] +use chrono::NaiveDate; +#[cfg(feature = "dtype-datetime")] +use chrono::NaiveDateTime; +#[cfg(feature = "dtype-time")] +use chrono::NaiveTime; + +use crate::chunked_array::builder::get_list_builder; +use crate::prelude::*; + +pub trait NamedFrom { + /// Initialize by name and values. + fn new(name: PlSmallStr, _: T) -> Self; +} + +pub trait NamedFromOwned { + /// Initialize by name and values. + fn from_vec(name: PlSmallStr, _: T) -> Self; +} + +macro_rules! impl_named_from_owned { + ($type:ty, $polars_type:ident) => { + impl NamedFromOwned<$type> for Series { + fn from_vec(name: PlSmallStr, v: $type) -> Self { + ChunkedArray::<$polars_type>::from_vec(name, v).into_series() + } + } + }; +} + +#[cfg(feature = "dtype-i8")] +impl_named_from_owned!(Vec, Int8Type); +#[cfg(feature = "dtype-i16")] +impl_named_from_owned!(Vec, Int16Type); +impl_named_from_owned!(Vec, Int32Type); +impl_named_from_owned!(Vec, Int64Type); +#[cfg(feature = "dtype-i128")] +impl_named_from_owned!(Vec, Int128Type); +#[cfg(feature = "dtype-u8")] +impl_named_from_owned!(Vec, UInt8Type); +#[cfg(feature = "dtype-u16")] +impl_named_from_owned!(Vec, UInt16Type); +impl_named_from_owned!(Vec, UInt32Type); +impl_named_from_owned!(Vec, UInt64Type); +impl_named_from_owned!(Vec, Float32Type); +impl_named_from_owned!(Vec, Float64Type); + +macro_rules! impl_named_from { + ($type:ty, $polars_type:ident, $method:ident) => { + impl> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + ChunkedArray::<$polars_type>::$method(name, v.as_ref()).into_series() + } + } + impl> NamedFrom for ChunkedArray<$polars_type> { + fn new(name: PlSmallStr, v: T) -> Self { + ChunkedArray::<$polars_type>::$method(name, v.as_ref()) + } + } + }; +} + +impl_named_from!([String], StringType, from_slice); +impl_named_from!([Vec], BinaryType, from_slice); +impl_named_from!([bool], BooleanType, from_slice); +#[cfg(feature = "dtype-u8")] +impl_named_from!([u8], UInt8Type, from_slice); +#[cfg(feature = "dtype-u16")] +impl_named_from!([u16], UInt16Type, from_slice); +impl_named_from!([u32], UInt32Type, from_slice); +impl_named_from!([u64], UInt64Type, from_slice); +#[cfg(feature = "dtype-i8")] +impl_named_from!([i8], Int8Type, from_slice); +#[cfg(feature = "dtype-i16")] +impl_named_from!([i16], Int16Type, from_slice); +impl_named_from!([i32], Int32Type, from_slice); +impl_named_from!([i64], Int64Type, from_slice); +#[cfg(feature = "dtype-decimal")] +impl_named_from!([i128], Int128Type, from_slice); +impl_named_from!([f32], Float32Type, from_slice); +impl_named_from!([f64], Float64Type, from_slice); +impl_named_from!([Option], StringType, from_slice_options); +impl_named_from!([Option>], BinaryType, from_slice_options); +impl_named_from!([Option], BooleanType, from_slice_options); +#[cfg(feature = "dtype-u8")] +impl_named_from!([Option], UInt8Type, from_slice_options); +#[cfg(feature = "dtype-u16")] +impl_named_from!([Option], UInt16Type, from_slice_options); +impl_named_from!([Option], UInt32Type, from_slice_options); +impl_named_from!([Option], UInt64Type, from_slice_options); +#[cfg(feature = "dtype-i8")] +impl_named_from!([Option], Int8Type, from_slice_options); +#[cfg(feature = "dtype-i16")] +impl_named_from!([Option], Int16Type, from_slice_options); +impl_named_from!([Option], Int32Type, from_slice_options); +impl_named_from!([Option], Int64Type, from_slice_options); +#[cfg(feature = "dtype-decimal")] +impl_named_from!([Option], Int128Type, from_slice_options); +impl_named_from!([Option], Float32Type, from_slice_options); +impl_named_from!([Option], Float64Type, from_slice_options); + +macro_rules! impl_named_from_range { + ($range:ty, $polars_type:ident) => { + impl NamedFrom<$range, $polars_type> for ChunkedArray<$polars_type> { + fn new(name: PlSmallStr, range: $range) -> Self { + let values = range.collect::>(); + ChunkedArray::<$polars_type>::from_vec(name, values) + } + } + + impl NamedFrom<$range, $polars_type> for Series { + fn new(name: PlSmallStr, range: $range) -> Self { + ChunkedArray::new(name, range).into_series() + } + } + }; +} +impl_named_from_range!(std::ops::Range, Int64Type); +impl_named_from_range!(std::ops::Range, Int32Type); +impl_named_from_range!(std::ops::Range, UInt64Type); +impl_named_from_range!(std::ops::Range, UInt32Type); + +impl> NamedFrom for Series { + fn new(name: PlSmallStr, s: T) -> Self { + let series_slice = s.as_ref(); + let list_cap = series_slice.len(); + + if series_slice.is_empty() { + return Series::new_empty(name, &DataType::Null); + } + + let dt = series_slice[0].dtype(); + + let values_cap = series_slice.iter().fold(0, |acc, s| acc + s.len()); + + let mut builder = get_list_builder(dt, values_cap, list_cap, name); + for series in series_slice { + builder.append_series(series).unwrap(); + } + builder.finish().into_series() + } +} + +impl]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, s: T) -> Self { + let series_slice = s.as_ref(); + let values_cap = series_slice.iter().fold(0, |acc, opt_s| { + acc + opt_s.as_ref().map(|s| s.len()).unwrap_or(0) + }); + let dt = match series_slice.iter().filter_map(|opt| opt.as_ref()).next() { + Some(series) => series.dtype(), + None => &DataType::Null, + }; + + let mut builder = get_list_builder(dt, values_cap, series_slice.len(), name); + for series in series_slice { + builder.append_opt_series(series.as_ref()).unwrap(); + } + builder.finish().into_series() + } +} +impl<'a, T: AsRef<[&'a str]>> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_slice(name, v.as_ref()).into_series() + } +} + +impl NamedFrom<&Series, str> for Series { + fn new(name: PlSmallStr, s: &Series) -> Self { + let mut s = s.clone(); + s.rename(name); + s + } +} + +impl<'a, T: AsRef<[&'a str]>> NamedFrom for StringChunked { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_slice(name, v.as_ref()) + } +} + +impl<'a, T: AsRef<[Option<&'a str>]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_slice_options(name, v.as_ref()).into_series() + } +} + +impl<'a, T: AsRef<[Option<&'a str>]>> NamedFrom]> for StringChunked { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_slice_options(name, v.as_ref()) + } +} + +impl<'a, T: AsRef<[Cow<'a, str>]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) + .into_series() + } +} + +impl<'a, T: AsRef<[Cow<'a, str>]>> NamedFrom]> for StringChunked { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) + } +} + +impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::new(name, v).into_series() + } +} + +impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for StringChunked { + fn new(name: PlSmallStr, v: T) -> Self { + StringChunked::from_iter_options( + name, + v.as_ref() + .iter() + .map(|opt| opt.as_ref().map(|value| value.as_ref())), + ) + } +} + +impl<'a, T: AsRef<[&'a [u8]]>> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_slice(name, v.as_ref()).into_series() + } +} + +impl<'a, T: AsRef<[&'a [u8]]>> NamedFrom for BinaryChunked { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_slice(name, v.as_ref()) + } +} + +impl<'a, T: AsRef<[Option<&'a [u8]>]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_slice_options(name, v.as_ref()).into_series() + } +} + +impl<'a, T: AsRef<[Option<&'a [u8]>]>> NamedFrom]> for BinaryChunked { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_slice_options(name, v.as_ref()) + } +} + +impl<'a, T: AsRef<[Cow<'a, [u8]>]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) + .into_series() + } +} + +impl<'a, T: AsRef<[Cow<'a, [u8]>]>> NamedFrom]> for BinaryChunked { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) + } +} + +impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::new(name, v).into_series() + } +} + +impl<'a, T: AsRef<[Option>]>> NamedFrom>]> + for BinaryChunked +{ + fn new(name: PlSmallStr, v: T) -> Self { + BinaryChunked::from_iter_options( + name, + v.as_ref() + .iter() + .map(|opt| opt.as_ref().map(|value| value.as_ref())), + ) + } +} + +#[cfg(feature = "dtype-date")] +impl> NamedFrom for DateChunked { + fn new(name: PlSmallStr, v: T) -> Self { + DateChunked::from_naive_date(name, v.as_ref().iter().copied()) + } +} + +#[cfg(feature = "dtype-date")] +impl> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + DateChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-date")] +impl]>> NamedFrom]> for DateChunked { + fn new(name: PlSmallStr, v: T) -> Self { + DateChunked::from_naive_date_options(name, v.as_ref().iter().copied()) + } +} + +#[cfg(feature = "dtype-date")] +impl]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + DateChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-datetime")] +impl> NamedFrom for DatetimeChunked { + fn new(name: PlSmallStr, v: T) -> Self { + DatetimeChunked::from_naive_datetime( + name, + v.as_ref().iter().copied(), + TimeUnit::Milliseconds, + ) + } +} + +#[cfg(feature = "dtype-datetime")] +impl> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + DatetimeChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-datetime")] +impl]>> NamedFrom]> for DatetimeChunked { + fn new(name: PlSmallStr, v: T) -> Self { + DatetimeChunked::from_naive_datetime_options( + name, + v.as_ref().iter().copied(), + TimeUnit::Milliseconds, + ) + } +} + +#[cfg(feature = "dtype-datetime")] +impl]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + DatetimeChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-duration")] +impl> NamedFrom for DurationChunked { + fn new(name: PlSmallStr, v: T) -> Self { + DurationChunked::from_duration(name, v.as_ref().iter().copied(), TimeUnit::Nanoseconds) + } +} + +#[cfg(feature = "dtype-duration")] +impl> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + DurationChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-duration")] +impl]>> NamedFrom]> + for DurationChunked +{ + fn new(name: PlSmallStr, v: T) -> Self { + DurationChunked::from_duration_options( + name, + v.as_ref().iter().copied(), + TimeUnit::Nanoseconds, + ) + } +} + +#[cfg(feature = "dtype-duration")] +impl]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + DurationChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-time")] +impl> NamedFrom for TimeChunked { + fn new(name: PlSmallStr, v: T) -> Self { + TimeChunked::from_naive_time(name, v.as_ref().iter().copied()) + } +} + +#[cfg(feature = "dtype-time")] +impl> NamedFrom for Series { + fn new(name: PlSmallStr, v: T) -> Self { + TimeChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "dtype-time")] +impl]>> NamedFrom]> for TimeChunked { + fn new(name: PlSmallStr, v: T) -> Self { + TimeChunked::from_naive_time_options(name, v.as_ref().iter().copied()) + } +} + +#[cfg(feature = "dtype-time")] +impl]>> NamedFrom]> for Series { + fn new(name: PlSmallStr, v: T) -> Self { + TimeChunked::new(name, v).into_series() + } +} + +#[cfg(feature = "object")] +impl NamedFrom<&[T], &[T]> for ObjectChunked { + fn new(name: PlSmallStr, v: &[T]) -> Self { + ObjectChunked::from_slice(name, v) + } +} + +#[cfg(feature = "object")] +impl]>> NamedFrom]> for ObjectChunked { + fn new(name: PlSmallStr, v: S) -> Self { + ObjectChunked::from_slice_options(name, v.as_ref()) + } +} + +impl ChunkedArray { + /// Specialization that prevents an allocation + /// prefer this over ChunkedArray::new when you have a `Vec` and no null values. + pub fn new_vec(name: PlSmallStr, v: Vec) -> Self { + ChunkedArray::from_vec(name, v) + } +} + +/// For any [`ChunkedArray`] and [`Series`] +impl NamedFrom for Series { + fn new(name: PlSmallStr, t: T) -> Self { + let mut s = t.into_series(); + s.rename(name); + s + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[cfg(all( + feature = "dtype-datetime", + feature = "dtype-duration", + feature = "dtype-date", + feature = "dtype-time" + ))] + #[test] + fn test_temporal_df_construction() { + // check if we can construct. + let _df = df![ + "date" => [NaiveDate::from_ymd_opt(2021, 1, 1).unwrap()], + "datetime" => [NaiveDate::from_ymd_opt(2021, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap()], + "optional_date" => [Some(NaiveDate::from_ymd_opt(2021, 1, 1).unwrap())], + "optional_datetime" => [Some(NaiveDate::from_ymd_opt(2021, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap())], + "time" => [NaiveTime::from_hms_opt(23, 23, 23).unwrap()], + "optional_time" => [Some(NaiveTime::from_hms_opt(23, 23, 23).unwrap())], + "duration" => [ChronoDuration::from_std(std::time::Duration::from_secs(10)).unwrap()], + "optional_duration" => [Some(ChronoDuration::from_std(std::time::Duration::from_secs(10)).unwrap())], + ].unwrap(); + } + + #[test] + fn build_series_from_empty_series_vec() { + let empty_series = Series::new("test".into(), Vec::::new()); + assert_eq!(empty_series.len(), 0); + assert_eq!(*empty_series.dtype(), DataType::Null); + assert_eq!(empty_series.name().as_str(), "test"); + } +} diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs new file mode 100644 index 000000000000..ced824db3c54 --- /dev/null +++ b/crates/polars-core/src/prelude.rs @@ -0,0 +1,63 @@ +//! Everything you need to get started with Polars. +pub use std::sync::Arc; + +pub use arrow::array::ArrayRef; +pub(crate) use arrow::array::*; +pub use arrow::datatypes::{ArrowSchema, Field as ArrowField}; +pub use arrow::legacy::prelude::*; +pub(crate) use arrow::trusted_len::TrustedLen; +pub use polars_compute::rolling::{QuantileMethod, RollingFnParams, RollingVarParams}; +pub use polars_utils::aliases::*; +pub use polars_utils::index::{ChunkId, IdxSize, NullableIdxSize}; +pub use polars_utils::pl_str::PlSmallStr; +pub(crate) use polars_utils::total_ord::{TotalEq, TotalOrd}; + +pub(crate) use crate::chunked_array::ChunkLenIter; +pub use crate::chunked_array::ChunkedArray; +#[cfg(feature = "dtype-struct")] +pub use crate::chunked_array::StructChunked; +pub use crate::chunked_array::arithmetic::ArithmeticChunked; +pub use crate::chunked_array::builder::{ + BinaryChunkedBuilder, BooleanChunkedBuilder, ChunkedBuilder, ListBinaryChunkedBuilder, + ListBooleanChunkedBuilder, ListBuilderTrait, ListPrimitiveChunkedBuilder, + ListStringChunkedBuilder, NewChunkedArray, PrimitiveChunkedBuilder, StringChunkedBuilder, +}; +pub use crate::chunked_array::collect::{ChunkedCollectInferIterExt, ChunkedCollectIterExt}; +pub use crate::chunked_array::iterator::PolarsIterator; +#[cfg(feature = "dtype-categorical")] +pub use crate::chunked_array::logical::categorical::*; +#[cfg(feature = "ndarray")] +pub use crate::chunked_array::ndarray::IndexOrder; +#[cfg(feature = "object")] +pub use crate::chunked_array::object::PolarsObject; +pub use crate::chunked_array::ops::aggregate::*; +#[cfg(feature = "rolling_window")] +pub use crate::chunked_array::ops::rolling_window::RollingOptionsFixedWindow; +pub use crate::chunked_array::ops::*; +#[cfg(feature = "temporal")] +pub use crate::chunked_array::temporal::conversion::*; +#[cfg(feature = "dtype-categorical")] +pub use crate::datatypes::string_cache::StringCacheHolder; +pub use crate::datatypes::{ArrayCollectIterExt, *}; +pub use crate::error::signals::try_raise_keyboard_interrupt; +pub use crate::error::{ + PolarsError, PolarsResult, polars_bail, polars_ensure, polars_err, polars_warn, +}; +pub use crate::frame::column::{Column, IntoColumn}; +pub use crate::frame::explode::UnpivotArgsIR; +#[cfg(feature = "algorithm_group_by")] +pub(crate) use crate::frame::group_by::aggregations::*; +#[cfg(feature = "algorithm_group_by")] +pub use crate::frame::group_by::*; +pub use crate::frame::{DataFrame, UniqueKeepStrategy}; +pub use crate::hashing::VecHash; +pub use crate::named_from::{NamedFrom, NamedFromOwned}; +pub use crate::scalar::Scalar; +pub use crate::schema::*; +#[cfg(feature = "checked_arithmetic")] +pub use crate::series::arithmetic::checked::NumOpsDispatchChecked; +pub use crate::series::arithmetic::{LhsNumOps, NumOpsDispatch}; +pub use crate::series::{IntoSeries, Series, SeriesTrait}; +pub(crate) use crate::utils::CustomIterTools; +pub use crate::utils::IntoVec; +pub use crate::{datatypes, df}; diff --git a/crates/polars-core/src/random.rs b/crates/polars-core/src/random.rs new file mode 100644 index 000000000000..1c88b05643a7 --- /dev/null +++ b/crates/polars-core/src/random.rs @@ -0,0 +1,14 @@ +use std::sync::{LazyLock, Mutex}; + +use rand::prelude::*; + +static POLARS_GLOBAL_RNG_STATE: LazyLock> = + LazyLock::new(|| Mutex::new(SmallRng::from_entropy())); + +pub(crate) fn get_global_random_u64() -> u64 { + POLARS_GLOBAL_RNG_STATE.lock().unwrap().next_u64() +} + +pub fn set_global_random_seed(seed: u64) { + *POLARS_GLOBAL_RNG_STATE.lock().unwrap() = SmallRng::seed_from_u64(seed); +} diff --git a/crates/polars-core/src/scalar/from.rs b/crates/polars-core/src/scalar/from.rs new file mode 100644 index 000000000000..e2bcd0fcdc25 --- /dev/null +++ b/crates/polars-core/src/scalar/from.rs @@ -0,0 +1,33 @@ +use polars_utils::pl_str::PlSmallStr; + +use super::{AnyValue, DataType, Scalar}; + +macro_rules! impl_from { + ($(($t:ty, $av:ident, $dt:ident))+) => { + $( + impl From<$t> for Scalar { + #[inline] + fn from(v: $t) -> Self { + Self::new(DataType::$dt, AnyValue::$av(v)) + } + } + )+ + } +} + +impl_from! { + (bool, Boolean, Boolean) + (i8, Int8, Int8) + (i16, Int16, Int16) + (i32, Int32, Int32) + (i64, Int64, Int64) + (i128, Int128, Int128) + (u8, UInt8, UInt8) + (u16, UInt16, UInt16) + (u32, UInt32, UInt32) + (u64, UInt64, UInt64) + (f32, Float32, Float32) + (f64, Float64, Float64) + (PlSmallStr, StringOwned, String) + (Vec, BinaryOwned, Binary) +} diff --git a/crates/polars-core/src/scalar/mod.rs b/crates/polars-core/src/scalar/mod.rs new file mode 100644 index 000000000000..2c633bacff33 --- /dev/null +++ b/crates/polars-core/src/scalar/mod.rs @@ -0,0 +1,113 @@ +mod from; +mod new; +pub mod reduce; + +use std::hash::Hash; + +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::chunked_array::cast::CastOptions; +use crate::datatypes::{AnyValue, DataType}; +use crate::prelude::{Column, Series}; + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Scalar { + dtype: DataType, + value: AnyValue<'static>, +} + +impl Hash for Scalar { + fn hash(&self, state: &mut H) { + self.dtype.hash(state); + self.value.hash_impl(state, true); + } +} + +impl Default for Scalar { + fn default() -> Self { + Self { + dtype: DataType::Null, + value: AnyValue::Null, + } + } +} + +impl Scalar { + #[inline(always)] + pub const fn new(dtype: DataType, value: AnyValue<'static>) -> Self { + Self { dtype, value } + } + + pub const fn null(dtype: DataType) -> Self { + Self::new(dtype, AnyValue::Null) + } + + pub fn cast_with_options(self, dtype: &DataType, options: CastOptions) -> PolarsResult { + if self.dtype() == dtype { + return Ok(self); + } + + // @Optimize: If we have fully fleshed out casting semantics, we could just specify the + // cast on AnyValue. + let s = self + .into_series(PlSmallStr::from_static("scalar")) + .cast_with_options(dtype, options)?; + let value = s.get(0).unwrap(); + Ok(Self::new(s.dtype().clone(), value.into_static())) + } + + #[inline(always)] + pub fn is_null(&self) -> bool { + self.value.is_null() + } + + #[inline(always)] + pub fn is_nan(&self) -> bool { + self.value.is_nan() + } + + #[inline(always)] + pub fn into_value(self) -> AnyValue<'static> { + self.value + } + + #[inline(always)] + pub fn value(&self) -> &AnyValue<'static> { + &self.value + } + + pub fn as_any_value(&self) -> AnyValue { + self.value + .strict_cast(&self.dtype) + .unwrap_or_else(|| self.value.clone()) + } + + pub fn into_series(self, name: PlSmallStr) -> Series { + Series::from_any_values_and_dtype(name, &[self.as_any_value()], &self.dtype, true).unwrap() + } + + /// Turn a scalar into a column with `length=1`. + pub fn into_column(self, name: PlSmallStr) -> Column { + Column::new_scalar(name, self, 1) + } + + #[inline(always)] + pub fn dtype(&self) -> &DataType { + &self.dtype + } + + #[inline(always)] + pub fn update(&mut self, value: AnyValue<'static>) { + self.value = value; + } + + #[inline(always)] + pub fn with_value(mut self, value: AnyValue<'static>) -> Self { + self.update(value); + self + } +} diff --git a/crates/polars-core/src/scalar/new.rs b/crates/polars-core/src/scalar/new.rs new file mode 100644 index 000000000000..c4c920e0cedd --- /dev/null +++ b/crates/polars-core/src/scalar/new.rs @@ -0,0 +1,27 @@ +use std::sync::Arc; + +use super::Scalar; +use crate::prelude::{AnyValue, DataType, TimeUnit, TimeZone}; + +impl Scalar { + #[cfg(feature = "dtype-date")] + pub fn new_date(value: i32) -> Self { + Scalar::new(DataType::Date, AnyValue::Date(value)) + } + + #[cfg(feature = "dtype-datetime")] + pub fn new_datetime(value: i64, time_unit: TimeUnit, tz: Option) -> Self { + Scalar::new( + DataType::Datetime(time_unit, tz.clone()), + AnyValue::DatetimeOwned(value, time_unit, tz.map(Arc::new)), + ) + } + + #[cfg(feature = "dtype-duration")] + pub fn new_duration(value: i64, time_unit: TimeUnit) -> Self { + Scalar::new( + DataType::Duration(time_unit), + AnyValue::Duration(value, time_unit), + ) + } +} diff --git a/crates/polars-core/src/scalar/reduce.rs b/crates/polars-core/src/scalar/reduce.rs new file mode 100644 index 000000000000..a5d8f788274a --- /dev/null +++ b/crates/polars-core/src/scalar/reduce.rs @@ -0,0 +1,37 @@ +use crate::datatypes::{AnyValue, TimeUnit}; +#[cfg(feature = "dtype-date")] +use crate::prelude::MS_IN_DAY; +use crate::prelude::{DataType, Scalar}; + +pub fn mean_reduce(value: Option, dtype: DataType) -> Scalar { + match dtype { + DataType::Float32 => { + let val = value.map(|m| m as f32); + Scalar::new(dtype, val.into()) + }, + dt if dt.is_primitive_numeric() || dt.is_decimal() || dt.is_bool() => { + Scalar::new(DataType::Float64, value.into()) + }, + #[cfg(feature = "dtype-date")] + DataType::Date => { + let val = value.map(|v| (v * MS_IN_DAY as f64) as i64); + Scalar::new(DataType::Datetime(TimeUnit::Milliseconds, None), val.into()) + }, + #[cfg(feature = "dtype-datetime")] + dt @ DataType::Datetime(_, _) => { + let val = value.map(|v| v as i64); + Scalar::new(dt, val.into()) + }, + #[cfg(feature = "dtype-duration")] + dt @ DataType::Duration(_) => { + let val = value.map(|v| v as i64); + Scalar::new(dt, val.into()) + }, + #[cfg(feature = "dtype-time")] + dt @ DataType::Time => { + let val = value.map(|v| v as i64); + Scalar::new(dt, val.into()) + }, + dt => Scalar::new(dt, AnyValue::Null), + } +} diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs new file mode 100644 index 000000000000..71514cdd469e --- /dev/null +++ b/crates/polars-core/src/schema.rs @@ -0,0 +1,186 @@ +use std::fmt::Debug; + +use arrow::bitmap::Bitmap; +use polars_utils::pl_str::PlSmallStr; + +use crate::prelude::*; +use crate::utils::try_get_supertype; + +pub type SchemaRef = Arc; +pub type Schema = polars_schema::Schema; + +pub trait SchemaExt { + fn from_arrow_schema(value: &ArrowSchema) -> Self; + + fn get_field(&self, name: &str) -> Option; + + fn try_get_field(&self, name: &str) -> PolarsResult; + + fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema; + + fn iter_fields(&self) -> impl ExactSizeIterator + '_; + + fn to_supertype(&mut self, other: &Schema) -> PolarsResult; + + /// Select fields using a bitmap. + fn project_select(&self, select: &Bitmap) -> Self; +} + +impl SchemaExt for Schema { + fn from_arrow_schema(value: &ArrowSchema) -> Self { + value + .iter_values() + .map(|x| (x.name.clone(), DataType::from_arrow_field(x))) + .collect() + } + + /// Look up the name in the schema and return an owned [`Field`] by cloning the data. + /// + /// Returns `None` if the field does not exist. + /// + /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see + /// [`get`][Self::get] or [`get_full`][Self::get_full]. + fn get_field(&self, name: &str) -> Option { + self.get_full(name) + .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone())) + } + + /// Look up the name in the schema and return an owned [`Field`] by cloning the data. + /// + /// Returns `Err(PolarsErr)` if the field does not exist. + /// + /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see + /// [`get`][Self::get] or [`get_full`][Self::get_full]. + fn try_get_field(&self, name: &str) -> PolarsResult { + self.get_full(name) + .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) + .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone())) + } + + /// Convert self to `ArrowSchema` by cloning the fields. + fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema { + self.iter() + .map(|(name, dtype)| { + ( + name.clone(), + dtype.to_arrow_field(name.clone(), compat_level), + ) + }) + .collect() + } + + /// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair. + /// + /// Note that this clones each name and dtype in order to form an owned [`Field`]. For a clone-free version, use + /// [`iter`][Self::iter], which returns `(&name, &dtype)`. + fn iter_fields(&self) -> impl ExactSizeIterator + '_ { + self.iter() + .map(|(name, dtype)| Field::new(name.clone(), dtype.clone())) + } + + /// Take another [`Schema`] and try to find the supertypes between them. + fn to_supertype(&mut self, other: &Schema) -> PolarsResult { + polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ"); + + let mut changed = false; + for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) { + polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k); + + let st = try_get_supertype(dt, other_dt)?; + changed |= (&st != dt) || (&st != other_dt); + *dt = st + } + Ok(changed) + } + + fn project_select(&self, select: &Bitmap) -> Self { + assert_eq!(self.len(), select.len()); + self.iter() + .zip(select.iter()) + .filter(|(_, select)| *select) + .map(|((n, dt), _)| (n.clone(), dt.clone())) + .collect() + } +} + +pub trait SchemaNamesAndDtypes { + const IS_ARROW: bool; + type DataType: Debug + Clone + Default + PartialEq; + + fn iter_names_and_dtypes( + &self, + ) -> impl ExactSizeIterator; +} + +impl SchemaNamesAndDtypes for ArrowSchema { + const IS_ARROW: bool = true; + type DataType = ArrowDataType; + + fn iter_names_and_dtypes( + &self, + ) -> impl ExactSizeIterator { + self.iter_values().map(|x| (&x.name, &x.dtype)) + } +} + +impl SchemaNamesAndDtypes for Schema { + const IS_ARROW: bool = false; + type DataType = DataType; + + fn iter_names_and_dtypes( + &self, + ) -> impl ExactSizeIterator { + self.iter() + } +} + +pub fn ensure_matching_schema( + lhs: &polars_schema::Schema, + rhs: &polars_schema::Schema, +) -> PolarsResult<()> +where + polars_schema::Schema: SchemaNamesAndDtypes, +{ + let lhs = lhs.iter_names_and_dtypes(); + let rhs = rhs.iter_names_and_dtypes(); + + if lhs.len() != rhs.len() { + polars_bail!( + SchemaMismatch: + "schemas contained differing number of columns: {} != {}", + lhs.len(), rhs.len(), + ); + } + + for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() { + if l_name != r_name { + polars_bail!( + SchemaMismatch: + "schema names differ at index {}: {} != {}", + i, l_name, r_name + ) + } + if l_dtype != r_dtype + && (!polars_schema::Schema::::IS_ARROW + || unsafe { + // For timezone normalization. Easier than writing out the entire PartialEq. + DataType::from_arrow_dtype(std::mem::transmute::< + & as SchemaNamesAndDtypes>::DataType, + &ArrowDataType, + >(l_dtype)) + != DataType::from_arrow_dtype(std::mem::transmute::< + & as SchemaNamesAndDtypes>::DataType, + &ArrowDataType, + >(r_dtype)) + }) + { + polars_bail!( + SchemaMismatch: + "schema dtypes differ at index {} for column {}: {:?} != {:?}", + i, l_name, l_dtype, r_dtype + ) + } + } + + Ok(()) +} diff --git a/crates/polars-core/src/serde/chunked_array.rs b/crates/polars-core/src/serde/chunked_array.rs new file mode 100644 index 000000000000..6b24f1750d22 --- /dev/null +++ b/crates/polars-core/src/serde/chunked_array.rs @@ -0,0 +1,20 @@ +use serde::{Serialize, Serializer}; + +use crate::prelude::*; + +// We don't use this internally (we call Series::serialize instead), but Rust users might need it. +impl Serialize for ChunkedArray +where + T: PolarsDataType, + ChunkedArray: IntoSeries, +{ + fn serialize( + &self, + serializer: S, + ) -> std::result::Result<::Ok, ::Error> + where + S: Serializer, + { + self.clone().into_series().serialize(serializer) + } +} diff --git a/crates/polars-core/src/serde/df.rs b/crates/polars-core/src/serde/df.rs new file mode 100644 index 000000000000..fef36ae345ee --- /dev/null +++ b/crates/polars-core/src/serde/df.rs @@ -0,0 +1,174 @@ +use std::sync::Arc; + +use arrow::datatypes::Metadata; +use arrow::io::ipc::read::{StreamReader, StreamState, read_stream_metadata}; +use arrow::io::ipc::write::WriteOptions; +use polars_error::{PolarsResult, polars_err, to_compute_err}; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_serialize::deserialize_map_bytes; +use polars_utils::pl_str::PlSmallStr; +use serde::de::Error; +use serde::*; + +use crate::chunked_array::flags::StatisticsFlags; +use crate::config; +use crate::frame::chunk_df_for_writing; +use crate::prelude::{CompatLevel, DataFrame, SchemaExt}; +use crate::utils::accumulate_dataframes_vertical_unchecked; + +const FLAGS_KEY: PlSmallStr = PlSmallStr::from_static("_PL_FLAGS"); + +impl DataFrame { + pub fn serialize_into_writer(&mut self, writer: &mut dyn std::io::Write) -> PolarsResult<()> { + let schema = self.schema(); + + if schema.iter_values().any(|x| x.is_object()) { + return Err(polars_err!( + ComputeError: + "serializing data of type Object is not supported", + )); + } + + let mut ipc_writer = + arrow::io::ipc::write::StreamWriter::new(writer, WriteOptions { compression: None }); + + ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from_iter( + self.get_columns().iter().map(|c| { + ( + format_pl_smallstr!("{}{}", FLAGS_KEY, c.name()), + PlSmallStr::from(c.get_flags().bits().to_string()), + ) + }), + ))); + + ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from([( + FLAGS_KEY, + serde_json::to_string( + &self + .iter() + .map(|s| s.get_flags().bits()) + .collect::>(), + ) + .map_err(to_compute_err)? + .into(), + )]))); + + ipc_writer.start(&schema.to_arrow(CompatLevel::newest()), None)?; + + for batch in chunk_df_for_writing(self, 512 * 512)?.iter_chunks(CompatLevel::newest(), true) + { + ipc_writer.write(&batch, None)?; + } + + ipc_writer.finish()?; + + Ok(()) + } + + pub fn serialize_to_bytes(&mut self) -> PolarsResult> { + let mut buf = vec![]; + self.serialize_into_writer(&mut buf)?; + + Ok(buf) + } + + pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult { + let mut md = read_stream_metadata(reader)?; + + let custom_metadata = md.custom_schema_metadata.take(); + + let reader = StreamReader::new(reader, md, None); + let dfs = reader + .into_iter() + .map_while(|batch| match batch { + Ok(StreamState::Some(batch)) => Some(Ok(DataFrame::from(batch))), + Ok(StreamState::Waiting) => None, + Err(e) => Some(Err(e)), + }) + .collect::>>()?; + + if dfs.is_empty() { + return Ok(DataFrame::empty()); + } + let mut df = accumulate_dataframes_vertical_unchecked(dfs); + + // Set custom metadata (fallible) + (|| { + let custom_metadata = custom_metadata?; + let flags = custom_metadata.get(&FLAGS_KEY)?; + + let flags: PolarsResult> = serde_json::from_str(flags).map_err(to_compute_err); + + let verbose = config::verbose(); + + if let Err(e) = &flags { + if verbose { + eprintln!("DataFrame::read_ipc: Error parsing metadata flags: {}", e); + } + } + + let flags = flags.ok()?; + + if flags.len() != df.width() { + if verbose { + eprintln!( + "DataFrame::read_ipc: Metadata flags width mismatch: {} != {}", + flags.len(), + df.width() + ); + } + + return None; + } + + let mut n_set = 0; + + for (c, v) in unsafe { df.get_columns_mut() }.iter_mut().zip(flags) { + if let Some(flags) = StatisticsFlags::from_bits(v) { + n_set += c.set_flags(flags) as usize; + } + } + + if verbose { + eprintln!( + "DataFrame::read_ipc: Loaded metadata for {} / {} columns", + n_set, + df.width() + ); + } + + Some(()) + })(); + + Ok(df) + } +} + +impl Serialize for DataFrame { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::Error; + + let mut bytes = vec![]; + self.clone() + .serialize_into_writer(&mut bytes) + .map_err(S::Error::custom)?; + + serializer.serialize_bytes(bytes.as_slice()) + } +} + +impl<'de> Deserialize<'de> for DataFrame { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize_map_bytes(deserializer, |b| { + let v = &mut b.as_ref(); + Self::deserialize_from_reader(v) + })? + .map_err(D::Error::custom) + } +} diff --git a/crates/polars-core/src/serde/mod.rs b/crates/polars-core/src/serde/mod.rs new file mode 100644 index 000000000000..e7b96c1312d8 --- /dev/null +++ b/crates/polars-core/src/serde/mod.rs @@ -0,0 +1,159 @@ +pub mod chunked_array; +mod df; +pub mod series; + +#[cfg(test)] +mod test { + use crate::chunked_array::flags::StatisticsFlags; + use crate::prelude::*; + use crate::series::IsSorted; + + #[test] + fn test_serde() -> PolarsResult<()> { + let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]); + + let json = serde_json::to_string(&ca.clone().into_series()).unwrap(); + + let out = serde_json::from_str::(&json).unwrap(); + assert!(ca.into_series().equals_missing(&out)); + + let ca = StringChunked::new("foo".into(), &[Some("foo"), None, Some("bar")]); + + let json = serde_json::to_string(&ca.clone().into_series()).unwrap(); + + let out = serde_json::from_str::(&json).unwrap(); // uses `Deserialize<'de>` + assert!(ca.into_series().equals_missing(&out)); + + Ok(()) + } + + /// test using the `DeserializedOwned` trait + #[test] + fn test_serde_owned() { + let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]); + + let json = serde_json::to_string(&ca.clone().into_series()).unwrap(); + + let out = serde_json::from_reader::<_, Series>(json.as_bytes()).unwrap(); // uses `DeserializeOwned` + assert!(ca.into_series().equals_missing(&out)); + } + + fn sample_dataframe() -> DataFrame { + let s1 = Series::new("foo".into(), &[1, 2, 3]); + let s2 = Series::new("bar".into(), &[Some(true), None, Some(false)]); + let s3 = Series::new("string".into(), &["mouse", "elephant", "dog"]); + let s_list = Column::new("list".into(), &[s1.clone(), s1.clone(), s1.clone()]); + + DataFrame::new(vec![s1.into(), s2.into(), s3.into(), s_list]).unwrap() + } + + #[test] + fn test_serde_flags() { + let df = sample_dataframe(); + + for mut column in df.columns { + column.set_sorted_flag(IsSorted::Descending); + let json = serde_json::to_string(&column).unwrap(); + let out = serde_json::from_reader::<_, Column>(json.as_bytes()).unwrap(); + let f = out.get_flags(); + assert_ne!(f, StatisticsFlags::empty()); + assert_eq!(column.get_flags(), out.get_flags()); + } + } + + #[test] + fn test_serde_df_json() { + let df = sample_dataframe(); + let json = serde_json::to_string(&df).unwrap(); + let out = serde_json::from_str::(&json).unwrap(); // uses `Deserialize<'de>` + assert!(df.equals_missing(&out)); + } + + #[test] + fn test_serde_df_bincode() { + let df = sample_dataframe(); + let bytes = bincode::serialize(&df).unwrap(); + let out = bincode::deserialize::(&bytes).unwrap(); // uses `Deserialize<'de>` + assert!(df.equals_missing(&out)); + } + + /// test using the `DeserializedOwned` trait + #[test] + fn test_serde_df_owned_json() { + let df = sample_dataframe(); + let json = serde_json::to_string(&df).unwrap(); + + let out = serde_json::from_reader::<_, DataFrame>(json.as_bytes()).unwrap(); // uses `DeserializeOwned` + assert!(df.equals_missing(&out)); + } + + #[test] + fn test_serde_binary_series_owned_bincode() { + let s1 = Column::new( + "foo".into(), + &[ + vec![1u8, 2u8, 3u8], + vec![4u8, 5u8, 6u8, 7u8], + vec![8u8, 9u8], + ], + ); + let df = DataFrame::new(vec![s1]).unwrap(); + let bytes = bincode::serialize(&df).unwrap(); + let out = bincode::deserialize_from::<_, DataFrame>(bytes.as_slice()).unwrap(); + assert!(df.equals_missing(&out)); + } + + // STRUCT REFACTOR + #[ignore] + #[test] + #[cfg(feature = "dtype-struct")] + fn test_serde_struct_series_owned_json() { + let row_1 = AnyValue::StructOwned(Box::new(( + vec![ + AnyValue::String("1:1"), + AnyValue::Null, + AnyValue::String("1:3"), + ], + vec![ + Field::new("fld_1".into(), DataType::String), + Field::new("fld_2".into(), DataType::String), + Field::new("fld_3".into(), DataType::String), + ], + ))); + let dtype = DataType::Struct(vec![ + Field::new("fld_1".into(), DataType::String), + Field::new("fld_2".into(), DataType::String), + Field::new("fld_3".into(), DataType::String), + ]); + let row_2 = AnyValue::StructOwned(Box::new(( + vec![ + AnyValue::String("2:1"), + AnyValue::String("2:2"), + AnyValue::String("2:3"), + ], + vec![ + Field::new("fld_1".into(), DataType::String), + Field::new("fld_2".into(), DataType::String), + Field::new("fld_3".into(), DataType::String), + ], + ))); + let row_3 = AnyValue::Null; + + let s = + Series::from_any_values_and_dtype("item".into(), &[row_1, row_2, row_3], &dtype, false) + .unwrap(); + let df = DataFrame::new(vec![s.into()]).unwrap(); + + let df_str = serde_json::to_string(&df).unwrap(); + let out = serde_json::from_str::(&df_str).unwrap(); + assert!(df.equals_missing(&out)); + } + /// test using the `DeserializedOwned` trait + #[test] + fn test_serde_df_owned_bincode() { + let df = sample_dataframe(); + let bytes = bincode::serialize(&df).unwrap(); + let out = bincode::deserialize_from::<_, DataFrame>(bytes.as_slice()).unwrap(); // uses `DeserializeOwned` + assert!(df.equals_missing(&out)); + } +} diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs new file mode 100644 index 000000000000..db9080a95737 --- /dev/null +++ b/crates/polars-core/src/serde/series.rs @@ -0,0 +1,66 @@ +use polars_utils::pl_serialize::deserialize_map_bytes; +use serde::de::Error; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::prelude::*; + +impl Series { + pub fn serialize_into_writer(&self, writer: &mut dyn std::io::Write) -> PolarsResult<()> { + let mut df = + unsafe { DataFrame::new_no_checks_height_from_first(vec![self.clone().into_column()]) }; + + df.serialize_into_writer(writer) + } + + pub fn serialize_to_bytes(&self) -> PolarsResult> { + let mut buf = vec![]; + self.serialize_into_writer(&mut buf)?; + + Ok(buf) + } + + pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult { + let df = DataFrame::deserialize_from_reader(reader)?; + + if df.width() != 1 { + polars_bail!( + ShapeMismatch: + "expected only 1 column when deserializing Series from IPC, got columns: {:?}", + df.schema().iter_names().collect::>() + ) + } + + Ok(df.take_columns().swap_remove(0).take_materialized_series()) + } +} + +impl Serialize for Series { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result<::Ok, ::Error> + where + S: Serializer, + { + use serde::ser::Error; + + serializer.serialize_bytes( + self.serialize_to_bytes() + .map_err(S::Error::custom)? + .as_slice(), + ) + } +} + +impl<'de> Deserialize<'de> for Series { + fn deserialize(deserializer: D) -> std::result::Result>::Error> + where + D: Deserializer<'de>, + { + deserialize_map_bytes(deserializer, |b| { + let v = &mut b.as_ref(); + Self::deserialize_from_reader(v) + })? + .map_err(D::Error::custom) + } +} diff --git a/crates/polars-core/src/series/amortized_iter.rs b/crates/polars-core/src/series/amortized_iter.rs new file mode 100644 index 000000000000..09a7cec5de47 --- /dev/null +++ b/crates/polars-core/src/series/amortized_iter.rs @@ -0,0 +1,106 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::ptr::NonNull; +use std::rc::Rc; + +use crate::prelude::*; + +/// A [`Series`] that amortizes a few allocations during iteration. +#[derive(Clone)] +pub struct AmortSeries { + container: Rc, + // the ptr to the inner chunk, this saves some ptr chasing + inner: NonNull, +} + +/// We don't implement Deref so that the caller is aware of converting to Series +impl AsRef for AmortSeries { + fn as_ref(&self) -> &Series { + self.container.as_ref() + } +} + +pub type ArrayBox = Box; + +impl AmortSeries { + pub fn new(series: Rc) -> Self { + debug_assert_eq!(series.chunks().len(), 1); + let inner_chunk = series.array_ref(0) as *const ArrayRef as *mut arrow::array::ArrayRef; + let container = series; + AmortSeries { + container, + inner: NonNull::new(inner_chunk).unwrap(), + } + } + + /// Creates a new [`UnsafeSeries`] + /// + /// # Safety + /// Inner chunks must be from `Series` otherwise the dtype may be incorrect and lead to UB. + #[inline] + pub(crate) unsafe fn new_with_chunk(series: Rc, inner_chunk: &ArrayRef) -> Self { + AmortSeries { + container: series, + inner: NonNull::new(inner_chunk as *const ArrayRef as *mut ArrayRef).unwrap_unchecked(), + } + } + + pub fn deep_clone(&self) -> Series { + unsafe { + let s = &(*self.container); + debug_assert_eq!(s.chunks().len(), 1); + let array_ref = s.chunks().get_unchecked(0).clone(); + let name = s.name().clone(); + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![array_ref], s.dtype()) + } + } + + #[inline] + /// Swaps inner state with the `array`. Prefer `AmortSeries::with_array` as this + /// restores the state. + /// # Safety + /// This swaps an underlying pointer that might be hold by other cloned series. + pub unsafe fn swap(&mut self, array: &mut ArrayRef) { + std::mem::swap(self.inner.as_mut(), array); + + // ensure lengths are correct. + unsafe { + let ptr = Rc::as_ptr(&self.container) as *mut Series; + (*ptr)._get_inner_mut().compute_len() + } + } + + /// Temporary swaps out the array, and restores the original state + /// when application of the function `f` is done. + /// + /// # Safety + /// Array must be from `Series` physical dtype. + #[inline] + pub unsafe fn with_array(&mut self, array: &mut ArrayRef, f: F) -> T + where + F: Fn(&AmortSeries) -> T, + { + unsafe { + self.swap(array); + let out = f(self); + self.swap(array); + out + } + } +} + +// SAFETY: +// type must be matching +pub(crate) unsafe fn unstable_series_container_and_ptr( + name: PlSmallStr, + inner_values: ArrayRef, + iter_dtype: &DataType, +) -> (Series, *mut ArrayRef) { + let series_container = { + let mut s = Series::from_chunks_and_dtype_unchecked(name, vec![inner_values], iter_dtype); + s.clear_flags(); + s + }; + + let ptr = series_container.array_ref(0) as *const ArrayRef as *mut ArrayRef; + (series_container, ptr) +} diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs new file mode 100644 index 000000000000..22036be49e1e --- /dev/null +++ b/crates/polars-core/src/series/any_value.rs @@ -0,0 +1,941 @@ +use std::borrow::Cow; +use std::fmt::Write; + +use arrow::bitmap::MutableBitmap; + +use crate::chunked_array::builder::{AnonymousOwnedListBuilder, get_list_builder}; +use crate::prelude::*; +use crate::utils::any_values_to_supertype; + +impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom]> for Series { + /// Construct a new [`Series`] from a collection of [`AnyValue`]. + /// + /// # Panics + /// + /// Panics if the values do not all share the same data type (with the exception + /// of [`DataType::Null`], which is always allowed). + /// + /// [`AnyValue`]: crate::datatypes::AnyValue + fn new(name: PlSmallStr, values: T) -> Self { + let values = values.as_ref(); + Series::from_any_values(name, values, true).expect("data types of values should match") + } +} + +fn initialize_empty_categorical_revmap_rec(dtype: &DataType) -> Cow { + use DataType as T; + match dtype { + #[cfg(feature = "dtype-categorical")] + T::Categorical(None, o) => { + Cow::Owned(T::Categorical(Some(Arc::new(RevMapping::default())), *o)) + }, + T::List(inner_dtype) => match initialize_empty_categorical_revmap_rec(inner_dtype) { + Cow::Owned(inner_dtype) => Cow::Owned(T::List(Box::new(inner_dtype))), + _ => Cow::Borrowed(dtype), + }, + #[cfg(feature = "dtype-array")] + T::Array(inner_dtype, width) => { + match initialize_empty_categorical_revmap_rec(inner_dtype) { + Cow::Owned(inner_dtype) => Cow::Owned(T::Array(Box::new(inner_dtype), *width)), + _ => Cow::Borrowed(dtype), + } + }, + #[cfg(feature = "dtype-struct")] + T::Struct(fields) => { + for (i, field) in fields.iter().enumerate() { + if let Cow::Owned(field_dtype) = + initialize_empty_categorical_revmap_rec(field.dtype()) + { + let mut new_fields = Vec::with_capacity(fields.len()); + new_fields.extend(fields[..i].iter().cloned()); + new_fields.push(Field::new(field.name().clone(), field_dtype)); + new_fields.extend(fields[i + 1..].iter().map(|field| { + let field_dtype = + initialize_empty_categorical_revmap_rec(field.dtype()).into_owned(); + Field::new(field.name().clone(), field_dtype) + })); + return Cow::Owned(T::Struct(new_fields)); + } + } + + Cow::Borrowed(dtype) + }, + _ => Cow::Borrowed(dtype), + } +} + +impl Series { + /// Construct a new [`Series`] from a slice of AnyValues. + /// + /// The data type of the resulting Series is determined by the `values` + /// and the `strict` parameter: + /// - If `strict` is `true`, the data type is equal to the data type of the + /// first non-null value. If any other non-null values do not match this + /// data type, an error is raised. + /// - If `strict` is `false`, the data type is the supertype of the `values`. + /// An error is returned if no supertype can be determined. + /// **WARNING**: A full pass over the values is required to determine the supertype. + /// - If no values were passed, the resulting data type is `Null`. + pub fn from_any_values( + name: PlSmallStr, + values: &[AnyValue], + strict: bool, + ) -> PolarsResult { + fn get_first_non_null_dtype(values: &[AnyValue]) -> DataType { + let mut all_flat_null = true; + let first_non_null = values.iter().find(|av| { + if !av.is_null() { + all_flat_null = false + }; + !av.is_nested_null() + }); + match first_non_null { + Some(av) => av.dtype(), + None => { + if all_flat_null { + DataType::Null + } else { + // Second pass to check for the nested null value that + // toggled `all_flat_null` to false, e.g. a List(Null). + let first_nested_null = values.iter().find(|av| !av.is_null()).unwrap(); + first_nested_null.dtype() + } + }, + } + } + let dtype = if strict { + get_first_non_null_dtype(values) + } else { + // Currently does not work correctly for Decimal because equality is not implemented. + any_values_to_supertype(values)? + }; + + // TODO: Remove this when Decimal data type equality is implemented. + #[cfg(feature = "dtype-decimal")] + if dtype.is_decimal() { + let dtype = DataType::Decimal(None, None); + return Self::from_any_values_and_dtype(name, values, &dtype, strict); + } + + Self::from_any_values_and_dtype(name, values, &dtype, strict) + } + + /// Construct a new [`Series`] with the given `dtype` from a slice of AnyValues. + /// + /// If `strict` is `true`, an error is returned if the values do not match the given + /// data type. If `strict` is `false`, values that do not match the given data type + /// are cast. If casting is not possible, the values are set to null instead. + pub fn from_any_values_and_dtype( + name: PlSmallStr, + values: &[AnyValue], + dtype: &DataType, + strict: bool, + ) -> PolarsResult { + if values.is_empty() { + return Ok(Self::new_empty( + name, + // This is given categoricals with empty revmaps, but we need to always return + // categoricals with non-empty revmaps. + initialize_empty_categorical_revmap_rec(dtype).as_ref(), + )); + } + + let mut s = match dtype { + #[cfg(feature = "dtype-i8")] + DataType::Int8 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => any_values_to_integer::(values, strict)?.into_series(), + DataType::Int32 => any_values_to_integer::(values, strict)?.into_series(), + DataType::Int64 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => any_values_to_integer::(values, strict)?.into_series(), + DataType::UInt32 => any_values_to_integer::(values, strict)?.into_series(), + DataType::UInt64 => any_values_to_integer::(values, strict)?.into_series(), + DataType::Float32 => any_values_to_f32(values, strict)?.into_series(), + DataType::Float64 => any_values_to_f64(values, strict)?.into_series(), + DataType::Boolean => any_values_to_bool(values, strict)?.into_series(), + DataType::String => any_values_to_string(values, strict)?.into_series(), + DataType::Binary => any_values_to_binary(values, strict)?.into_series(), + #[cfg(feature = "dtype-date")] + DataType::Date => any_values_to_date(values, strict)?.into_series(), + #[cfg(feature = "dtype-time")] + DataType::Time => any_values_to_time(values, strict)?.into_series(), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => { + any_values_to_datetime(values, *tu, (*tz).clone(), strict)?.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => any_values_to_duration(values, *tu, strict)?.into_series(), + #[cfg(feature = "dtype-categorical")] + dt @ DataType::Categorical(_, _) => any_values_to_categorical(values, dt, strict)?, + #[cfg(feature = "dtype-categorical")] + dt @ DataType::Enum(_, _) => any_values_to_enum(values, dt, strict)?, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + any_values_to_decimal(values, *precision, *scale, strict)?.into_series() + }, + DataType::List(inner) => any_values_to_list(values, inner, strict)?.into_series(), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, size) => any_values_to_array(values, inner, strict, *size)? + .into_series() + .cast(&DataType::Array(inner.clone(), *size))?, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => any_values_to_struct(values, fields, strict)?, + #[cfg(feature = "object")] + DataType::Object(_) => any_values_to_object(values)?, + DataType::Null => Series::new_null(PlSmallStr::EMPTY, values.len()), + dt => { + polars_bail!( + InvalidOperation: + "constructing a Series with data type {dt:?} from AnyValues is not supported" + ) + }, + }; + s.rename(name); + Ok(s) + } +} + +fn any_values_to_primitive_nonstrict(values: &[AnyValue]) -> ChunkedArray { + values + .iter() + .map(|av| av.extract::()) + .collect_trusted() +} + +fn any_values_to_integer( + values: &[AnyValue], + strict: bool, +) -> PolarsResult> { + fn any_values_to_integer_strict( + values: &[AnyValue], + ) -> PolarsResult> { + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match &av { + av if av.is_integer() => { + let opt_val = av.extract::(); + let val = match opt_val { + Some(v) => v, + None => return Err(invalid_value_error(&T::get_dtype(), av)), + }; + builder.append_value(val) + }, + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&T::get_dtype(), av)), + } + } + Ok(builder.finish()) + } + + if strict { + any_values_to_integer_strict::(values) + } else { + Ok(any_values_to_primitive_nonstrict::(values)) + } +} + +fn any_values_to_f32(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_f32_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = + PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Float32(i) => builder.append_value(*i), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Float32, av)), + } + } + Ok(builder.finish()) + } + if strict { + any_values_to_f32_strict(values) + } else { + Ok(any_values_to_primitive_nonstrict::(values)) + } +} +fn any_values_to_f64(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_f64_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = + PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Float64(i) => builder.append_value(*i), + AnyValue::Float32(i) => builder.append_value(*i as f64), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Float64, av)), + } + } + Ok(builder.finish()) + } + if strict { + any_values_to_f64_strict(values) + } else { + Ok(any_values_to_primitive_nonstrict::(values)) + } +} + +fn any_values_to_bool(values: &[AnyValue], strict: bool) -> PolarsResult { + let mut builder = BooleanChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Boolean(b) => builder.append_value(*b), + AnyValue::Null => builder.append_null(), + av => { + if strict { + return Err(invalid_value_error(&DataType::Boolean, av)); + } + match av.cast(&DataType::Boolean) { + AnyValue::Boolean(b) => builder.append_value(b), + _ => builder.append_null(), + } + }, + } + } + Ok(builder.finish()) +} + +fn any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_string_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::String, av)), + } + } + Ok(builder.finish()) + } + fn any_values_to_string_nonstrict(values: &[AnyValue]) -> StringChunked { + let mut builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); + let mut owned = String::new(); // Amortize allocations. + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + AnyValue::Null => builder.append_null(), + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) => builder.append_null(), + av => { + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_value(&owned); + }, + } + } + builder.finish() + } + if strict { + any_values_to_string_strict(values) + } else { + Ok(any_values_to_string_nonstrict(values)) + } +} + +fn any_values_to_binary(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_binary_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = BinaryChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Binary(s) => builder.append_value(*s), + AnyValue::BinaryOwned(s) => builder.append_value(&**s), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Binary, av)), + } + } + Ok(builder.finish()) + } + fn any_values_to_binary_nonstrict(values: &[AnyValue]) -> BinaryChunked { + values + .iter() + .map(|av| match av { + AnyValue::Binary(b) => Some(*b), + AnyValue::BinaryOwned(b) => Some(&**b), + AnyValue::String(s) => Some(s.as_bytes()), + AnyValue::StringOwned(s) => Some(s.as_str().as_bytes()), + _ => None, + }) + .collect_trusted() + } + if strict { + any_values_to_binary_strict(values) + } else { + Ok(any_values_to_binary_nonstrict(values)) + } +} + +#[cfg(feature = "dtype-date")] +fn any_values_to_date(values: &[AnyValue], strict: bool) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Date(i) => builder.append_value(*i), + AnyValue::Null => builder.append_null(), + av => { + if strict { + return Err(invalid_value_error(&DataType::Date, av)); + } + match av.cast(&DataType::Date) { + AnyValue::Date(i) => builder.append_value(i), + _ => builder.append_null(), + } + }, + } + } + Ok(builder.finish().into()) +} + +#[cfg(feature = "dtype-time")] +fn any_values_to_time(values: &[AnyValue], strict: bool) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Time(i) => builder.append_value(*i), + AnyValue::Null => builder.append_null(), + av => { + if strict { + return Err(invalid_value_error(&DataType::Time, av)); + } + match av.cast(&DataType::Time) { + AnyValue::Time(i) => builder.append_value(i), + _ => builder.append_null(), + } + }, + } + } + Ok(builder.finish().into()) +} + +#[cfg(feature = "dtype-datetime")] +fn any_values_to_datetime( + values: &[AnyValue], + time_unit: TimeUnit, + time_zone: Option, + strict: bool, +) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + let target_dtype = DataType::Datetime(time_unit, time_zone.clone()); + for av in values { + match av { + AnyValue::Datetime(i, tu, _) if *tu == time_unit => builder.append_value(*i), + AnyValue::DatetimeOwned(i, tu, _) if *tu == time_unit => builder.append_value(*i), + AnyValue::Null => builder.append_null(), + av => { + if strict { + return Err(invalid_value_error(&target_dtype, av)); + } + match av.cast(&target_dtype) { + AnyValue::Datetime(i, _, _) => builder.append_value(i), + AnyValue::DatetimeOwned(i, _, _) => builder.append_value(i), + _ => builder.append_null(), + } + }, + } + } + Ok(builder.finish().into_datetime(time_unit, time_zone)) +} + +#[cfg(feature = "dtype-duration")] +fn any_values_to_duration( + values: &[AnyValue], + time_unit: TimeUnit, + strict: bool, +) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + let target_dtype = DataType::Duration(time_unit); + for av in values { + match av { + AnyValue::Duration(i, tu) if *tu == time_unit => builder.append_value(*i), + AnyValue::Null => builder.append_null(), + av => { + if strict { + return Err(invalid_value_error(&target_dtype, av)); + } + match av.cast(&target_dtype) { + AnyValue::Duration(i, _) => builder.append_value(i), + _ => builder.append_null(), + } + }, + } + } + Ok(builder.finish().into_duration(time_unit)) +} + +#[cfg(feature = "dtype-categorical")] +fn any_values_to_categorical( + values: &[AnyValue], + dtype: &DataType, + strict: bool, +) -> PolarsResult { + let ordering = match dtype { + DataType::Categorical(_, ordering) => ordering, + _ => panic!("any_values_to_categorical with dtype={dtype:?}"), + }; + + let mut builder = CategoricalChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), *ordering); + + let mut owned = String::new(); // Amortize allocations. + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + + AnyValue::Enum(s, rev, _) => builder.append_value(rev.get(*s)), + AnyValue::EnumOwned(s, rev, _) => builder.append_value(rev.get(*s)), + + AnyValue::Categorical(s, rev, _) => builder.append_value(rev.get(*s)), + AnyValue::CategoricalOwned(s, rev, _) => builder.append_value(rev.get(*s)), + + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(), + AnyValue::Null => builder.append_null(), + + av => { + if strict { + return Err(invalid_value_error(&DataType::String, av)); + } + + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_value(&owned); + }, + } + } + + let ca = builder.finish(); + + Ok(ca.into_series()) +} + +#[cfg(feature = "dtype-categorical")] +fn any_values_to_enum(values: &[AnyValue], dtype: &DataType, strict: bool) -> PolarsResult { + use self::enum_::EnumChunkedBuilder; + + let (rev, ordering) = match dtype { + DataType::Enum(rev, ordering) => (rev.clone(), ordering), + _ => panic!("any_values_to_categorical with dtype={dtype:?}"), + }; + + let Some(rev) = rev else { + polars_bail!(nyi = "Not yet possible to create enum series without a rev-map"); + }; + + let mut builder = + EnumChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), rev, *ordering, strict); + + let mut owned = String::new(); // Amortize allocations. + for av in values { + match av { + AnyValue::String(s) => builder.append_str(s)?, + AnyValue::StringOwned(s) => builder.append_str(s)?, + + AnyValue::Enum(s, rev, _) => builder.append_enum(*s, rev)?, + AnyValue::EnumOwned(s, rev, _) => builder.append_enum(*s, rev)?, + + AnyValue::Categorical(s, rev, _) => builder.append_str(rev.get(*s))?, + AnyValue::CategoricalOwned(s, rev, _) => builder.append_str(rev.get(*s))?, + + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(), + AnyValue::Null => builder.append_null(), + + av => { + if strict { + return Err(invalid_value_error(&DataType::String, av)); + } + + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_str(&owned)? + }, + }; + } + + let ca = builder.finish(); + + Ok(ca.into_series()) +} + +#[cfg(feature = "dtype-decimal")] +fn any_values_to_decimal( + values: &[AnyValue], + precision: Option, + scale: Option, // If None, we're inferring the scale. + strict: bool, +) -> PolarsResult { + /// Get the maximum scale among AnyValues + fn infer_scale( + values: &[AnyValue], + precision: Option, + strict: bool, + ) -> PolarsResult { + let mut max_scale = 0; + for av in values { + let av_scale = match av { + AnyValue::Decimal(_, scale) => *scale, + AnyValue::Null => continue, + av => { + if strict { + let target_dtype = DataType::Decimal(precision, None); + return Err(invalid_value_error(&target_dtype, av)); + } + continue; + }, + }; + max_scale = max_scale.max(av_scale); + } + Ok(max_scale) + } + let scale = match scale { + Some(s) => s, + None => infer_scale(values, precision, strict)?, + }; + let target_dtype = DataType::Decimal(precision, Some(scale)); + + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + // Allow equal or less scale. We do want to support different scales even in 'strict' mode. + AnyValue::Decimal(v, s) if *s <= scale => { + if *s == scale { + builder.append_value(*v) + } else { + match av.strict_cast(&target_dtype) { + Some(AnyValue::Decimal(i, _)) => builder.append_value(i), + _ => builder.append_null(), + } + } + }, + AnyValue::Null => builder.append_null(), + av => { + if strict { + return Err(invalid_value_error(&target_dtype, av)); + } + // TODO: Precision check, else set to null + match av.strict_cast(&target_dtype) { + Some(AnyValue::Decimal(i, _)) => builder.append_value(i), + _ => builder.append_null(), + } + }, + }; + } + + // Build the array and do a precision check if needed. + builder.finish().into_decimal(precision, scale) +} + +fn any_values_to_list( + avs: &[AnyValue], + inner_type: &DataType, + strict: bool, +) -> PolarsResult { + // GB: + // Lord forgive for the sins I have committed in this function. The amount of strange + // exceptions that need to happen for this to work are insane and I feel like I am going crazy. + // + // This function is essentially a copy of the `` where it does not + // sample the datatype from the first element and instead we give it explicitly. This allows + // this function to properly assign a datatype if `avs` starts with a `null` value. Previously, + // this was solved by assigning the `dtype` again afterwards, but why? We should not link the + // implementation of these functions. We still need to assign the dtype of the ListArray and + // such, anyways. + // + // Then, `collect_ca_with_dtype` does not possess the necessary exceptions shown in this + // function to use that. I have tried adding the exceptions there and it broke other things. I + // really do feel like this is the simplest solution. + + let mut valid = true; + let capacity = avs.len(); + + let ca = match inner_type { + // AnyValues with empty lists in python can create + // Series of an unknown dtype. + // We use the anonymousbuilder without a dtype + // the empty arrays is then not added (we add an extra offset instead) + // the next non-empty series then must have the correct dtype. + DataType::Null => { + let mut builder = AnonymousOwnedListBuilder::new(PlSmallStr::EMPTY, capacity, None); + for av in avs { + match av { + AnyValue::List(b) => builder.append_series(b)?, + AnyValue::Null => builder.append_null(), + _ => { + valid = false; + builder.append_null(); + }, + } + } + builder.finish() + }, + + #[cfg(feature = "object")] + DataType::Object(_) => polars_bail!(nyi = "Nested object types"), + + _ => { + let list_inner_type = match inner_type { + // Categoricals may not have a revmap yet. We just give them an empty one here and + // the list builder takes care of the rest. + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(None, ordering) => { + DataType::Categorical(Some(Arc::new(RevMapping::default())), *ordering) + }, + + _ => inner_type.clone(), + }; + + let mut builder = + get_list_builder(&list_inner_type, capacity * 5, capacity, PlSmallStr::EMPTY); + + for av in avs { + match av { + AnyValue::List(b) => match b.cast(inner_type) { + Ok(casted) => { + if casted.null_count() != b.null_count() { + valid = !strict; + } + builder.append_series(&casted)?; + }, + Err(_) => { + valid = false; + for _ in 0..b.len() { + builder.append_null(); + } + }, + }, + AnyValue::Null => builder.append_null(), + _ => { + valid = false; + builder.append_null() + }, + } + } + + builder.finish() + }, + }; + + if strict && !valid { + polars_bail!(SchemaMismatch: "unexpected value while building Series of type {:?}", DataType::List(Box::new(inner_type.clone()))); + } + + Ok(ca) +} + +#[cfg(feature = "dtype-array")] +fn any_values_to_array( + avs: &[AnyValue], + inner_type: &DataType, + strict: bool, + width: usize, +) -> PolarsResult { + fn to_arr(s: &Series) -> Option { + if s.chunks().len() > 1 { + let s = s.rechunk(); + Some(s.chunks()[0].clone()) + } else { + Some(s.chunks()[0].clone()) + } + } + + let target_dtype = DataType::Array(Box::new(inner_type.clone()), width); + + // This is handled downstream. The builder will choose the first non null type. + let mut valid = true; + #[allow(unused_mut)] + let mut out: ArrayChunked = if inner_type == &DataType::Null { + avs.iter() + .map(|av| match av { + AnyValue::List(b) | AnyValue::Array(b, _) => to_arr(b), + AnyValue::Null => None, + _ => { + valid = false; + None + }, + }) + .collect_ca_with_dtype(PlSmallStr::EMPTY, target_dtype.clone()) + } + // Make sure that wrongly inferred AnyValues don't deviate from the datatype. + else { + avs.iter() + .map(|av| match av { + AnyValue::List(b) | AnyValue::Array(b, _) => { + if b.dtype() == inner_type { + to_arr(b) + } else { + let s = match b.cast(inner_type) { + Ok(out) => out, + Err(_) => Series::full_null(b.name().clone(), b.len(), inner_type), + }; + to_arr(&s) + } + }, + AnyValue::Null => None, + _ => { + valid = false; + None + }, + }) + .collect_ca_with_dtype(PlSmallStr::EMPTY, target_dtype.clone()) + }; + + if strict && !valid { + polars_bail!(SchemaMismatch: "unexpected value while building Series of type {:?}", target_dtype); + } + polars_ensure!( + out.width() == width, + SchemaMismatch: "got mixed size array widths where width {} was expected", width + ); + + // Ensure the logical type is correct for nested types. + #[cfg(feature = "dtype-struct")] + if !matches!(inner_type, DataType::Null) && out.inner_dtype().is_nested() { + unsafe { + out.set_dtype(target_dtype.clone()); + }; + } + + Ok(out) +} + +#[cfg(feature = "dtype-struct")] +fn _any_values_to_struct<'a>( + av_fields: &[Field], + av_values: &[AnyValue<'a>], + field_index: usize, + field: &Field, + fields: &[Field], + field_avs: &mut Vec>, +) { + // TODO: Optimize. + + let mut append_by_search = || { + // Search for the name. + if let Some(i) = av_fields + .iter() + .position(|av_fld| av_fld.name == field.name) + { + field_avs.push(av_values[i].clone()); + return; + } + field_avs.push(AnyValue::Null) + }; + + // All fields are available in this single value. + // We can use the index to get value. + if fields.len() == av_fields.len() { + if fields.iter().zip(av_fields.iter()).any(|(l, r)| l != r) { + append_by_search() + } else { + let av_val = av_values + .get(field_index) + .cloned() + .unwrap_or(AnyValue::Null); + field_avs.push(av_val) + } + } + // Not all fields are available, we search the proper field. + else { + // Search for the name. + append_by_search() + } +} + +#[cfg(feature = "dtype-struct")] +fn any_values_to_struct( + values: &[AnyValue], + fields: &[Field], + strict: bool, +) -> PolarsResult { + // Fast path for structs with no fields. + if fields.is_empty() { + return Ok( + StructChunked::from_series(PlSmallStr::EMPTY, values.len(), [].iter())?.into_series(), + ); + } + + // The physical series fields of the struct. + let mut series_fields = Vec::with_capacity(fields.len()); + let mut has_outer_validity = false; + let mut field_avs = Vec::with_capacity(values.len()); + for (i, field) in fields.iter().enumerate() { + field_avs.clear(); + + for av in values.iter() { + match av { + AnyValue::StructOwned(payload) => { + let av_fields = &payload.1; + let av_values = &payload.0; + _any_values_to_struct(av_fields, av_values, i, field, fields, &mut field_avs); + }, + AnyValue::Struct(_, _, av_fields) => { + let av_values: Vec<_> = av._iter_struct_av().collect(); + _any_values_to_struct(av_fields, &av_values, i, field, fields, &mut field_avs); + }, + _ => { + has_outer_validity = true; + field_avs.push(AnyValue::Null) + }, + } + } + // If the inferred dtype is null, we let auto inference work. + let s = if matches!(field.dtype, DataType::Null) { + Series::from_any_values(field.name().clone(), &field_avs, strict)? + } else { + Series::from_any_values_and_dtype( + field.name().clone(), + &field_avs, + &field.dtype, + strict, + )? + }; + series_fields.push(s) + } + + let mut out = + StructChunked::from_series(PlSmallStr::EMPTY, values.len(), series_fields.iter())?; + if has_outer_validity { + let mut validity = MutableBitmap::new(); + validity.extend_constant(values.len(), true); + for (i, v) in values.iter().enumerate() { + if matches!(v, AnyValue::Null) { + unsafe { validity.set_unchecked(i, false) } + } + } + out.set_outer_validity(Some(validity.freeze())) + } + Ok(out.into_series()) +} + +#[cfg(feature = "object")] +fn any_values_to_object(values: &[AnyValue]) -> PolarsResult { + use crate::chunked_array::object::registry; + let converter = registry::get_object_converter(); + let mut builder = registry::get_object_builder(PlSmallStr::EMPTY, values.len()); + for av in values { + match av { + AnyValue::Object(val) => builder.append_value(val.as_any()), + AnyValue::Null => builder.append_null(), + _ => { + // This is needed because in Python users can send mixed types. + // This only works if you set a global converter. + let any = converter(av.as_borrowed()); + builder.append_value(&*any) + }, + } + } + + Ok(builder.to_series()) +} + +fn invalid_value_error(dtype: &DataType, value: &AnyValue) -> PolarsError { + polars_err!( + SchemaMismatch: + "unexpected value while building Series of type {:?}; found value of type {:?}: {}", + dtype, + value.dtype(), + value + ) +} diff --git a/crates/polars-core/src/series/arithmetic/bitops.rs b/crates/polars-core/src/series/arithmetic/bitops.rs new file mode 100644 index 000000000000..4d5ccf24ce45 --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/bitops.rs @@ -0,0 +1,67 @@ +use std::borrow::Cow; + +use polars_error::PolarsResult; + +use super::{BooleanChunked, ChunkedArray, DataType, IntoSeries, Series, polars_bail}; + +macro_rules! impl_bitop { + ($(($trait:ident, $f:ident))+) => { + $( + impl std::ops::$trait for &Series { + type Output = PolarsResult; + #[inline(never)] + fn $f(self, rhs: Self) -> Self::Output { + use DataType as DT; + match self.dtype() { + DT::Boolean => { + let lhs: &BooleanChunked = self.as_ref().as_ref().as_ref(); + let rhs = lhs.unpack_series_matching_type(rhs)?; + Ok(lhs.$f(rhs).into_series()) + }, + dt if dt.is_integer() => { + let rhs = if rhs.len() == 1 { + Cow::Owned(rhs.cast(self.dtype())?) + } else { + Cow::Borrowed(rhs) + }; + + with_match_physical_integer_polars_type!(dt, |$T| { + let lhs: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); + let rhs = lhs.unpack_series_matching_type(&rhs)?; + Ok(lhs.$f(&rhs).into_series()) + }) + }, + _ => polars_bail!(opq = $f, self.dtype()), + } + } + } + impl std::ops::$trait for Series { + type Output = PolarsResult; + #[inline(always)] + fn $f(self, rhs: Self) -> Self::Output { + <&Series as std::ops::$trait>::$f(&self, &rhs) + } + } + impl std::ops::$trait<&Series> for Series { + type Output = PolarsResult; + #[inline(always)] + fn $f(self, rhs: &Series) -> Self::Output { + <&Series as std::ops::$trait>::$f(&self, rhs) + } + } + impl std::ops::$trait for &Series { + type Output = PolarsResult; + #[inline(always)] + fn $f(self, rhs: Series) -> Self::Output { + <&Series as std::ops::$trait>::$f(self, &rhs) + } + } + )+ + }; +} + +impl_bitop! { + (BitAnd, bitand) + (BitOr, bitor) + (BitXor, bitxor) +} diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs new file mode 100644 index 000000000000..d7c2ccb9f7bf --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -0,0 +1,994 @@ +use super::*; +use crate::utils::align_chunks_binary; + +pub trait NumOpsDispatchInner: PolarsDataType + Sized { + fn subtract(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_bail!(opq = sub, lhs.dtype(), rhs.dtype()); + } + fn add_to(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_bail!(opq = add, lhs.dtype(), rhs.dtype()); + } + fn multiply(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_bail!(opq = mul, lhs.dtype(), rhs.dtype()); + } + fn divide(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_bail!(opq = div, lhs.dtype(), rhs.dtype()); + } + fn remainder(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_bail!(opq = rem, lhs.dtype(), rhs.dtype()); + } +} + +pub trait NumOpsDispatch { + fn subtract(&self, rhs: &Series) -> PolarsResult; + fn add_to(&self, rhs: &Series) -> PolarsResult; + fn multiply(&self, rhs: &Series) -> PolarsResult; + fn divide(&self, rhs: &Series) -> PolarsResult; + fn remainder(&self, rhs: &Series) -> PolarsResult; +} + +impl NumOpsDispatch for ChunkedArray { + fn subtract(&self, rhs: &Series) -> PolarsResult { + T::subtract(self, rhs) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + T::add_to(self, rhs) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + T::multiply(self, rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + T::divide(self, rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + T::remainder(self, rhs) + } +} + +impl NumOpsDispatchInner for T +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + fn subtract(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_ensure!( + lhs.dtype() == rhs.dtype(), + opq = add, + rhs.dtype(), + rhs.dtype() + ); + + // SAFETY: + // There will be UB if a ChunkedArray is alive with the wrong datatype. + // we now only create the potentially wrong dtype for a short time. + // Note that the physical type correctness is checked! + // The ChunkedArray with the wrong dtype is dropped after this operation + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + let out = lhs - rhs; + Ok(out.into_series()) + } + fn add_to(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_ensure!( + lhs.dtype() == rhs.dtype(), + opq = add, + rhs.dtype(), + rhs.dtype() + ); + + // SAFETY: + // see subtract + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + let out = lhs + rhs; + Ok(out.into_series()) + } + fn multiply(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_ensure!( + lhs.dtype() == rhs.dtype(), + opq = add, + rhs.dtype(), + rhs.dtype() + ); + + // SAFETY: + // see subtract + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + let out = lhs * rhs; + Ok(out.into_series()) + } + fn divide(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_ensure!( + lhs.dtype() == rhs.dtype(), + opq = add, + rhs.dtype(), + rhs.dtype() + ); + + // SAFETY: + // see subtract + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + let out = lhs / rhs; + Ok(out.into_series()) + } + fn remainder(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_ensure!( + lhs.dtype() == rhs.dtype(), + opq = add, + rhs.dtype(), + rhs.dtype() + ); + + // SAFETY: + // see subtract + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + let out = lhs % rhs; + Ok(out.into_series()) + } +} + +impl NumOpsDispatchInner for StringType { + fn add_to(lhs: &StringChunked, rhs: &Series) -> PolarsResult { + let rhs = lhs.unpack_series_matching_type(rhs)?; + let out = lhs + rhs; + Ok(out.into_series()) + } +} + +impl NumOpsDispatchInner for BinaryType { + fn add_to(lhs: &BinaryChunked, rhs: &Series) -> PolarsResult { + let rhs = lhs.unpack_series_matching_type(rhs)?; + let out = lhs + rhs; + Ok(out.into_series()) + } +} + +impl NumOpsDispatchInner for BooleanType { + fn add_to(lhs: &BooleanChunked, rhs: &Series) -> PolarsResult { + let rhs = lhs.unpack_series_matching_type(rhs)?; + let out = lhs + rhs; + Ok(out.into_series()) + } +} + +#[cfg(feature = "checked_arithmetic")] +pub mod checked { + use num_traits::{CheckedDiv, One, ToPrimitive, Zero}; + + use super::*; + + pub trait NumOpsDispatchCheckedInner: PolarsDataType + Sized { + /// Checked integer division. Computes self / rhs, returning None if rhs == 0 or the division results in overflow. + fn checked_div(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + polars_bail!(opq = checked_div, lhs.dtype(), rhs.dtype()); + } + fn checked_div_num( + lhs: &ChunkedArray, + _rhs: T, + ) -> PolarsResult { + polars_bail!(opq = checked_div_num, lhs.dtype(), Self::get_dtype()); + } + } + + pub trait NumOpsDispatchChecked { + /// Checked integer division. Computes self / rhs, returning None if rhs == 0 or the division results in overflow. + fn checked_div(&self, rhs: &Series) -> PolarsResult; + fn checked_div_num(&self, _rhs: T) -> PolarsResult; + } + + impl NumOpsDispatchChecked for ChunkedArray { + fn checked_div(&self, rhs: &Series) -> PolarsResult { + S::checked_div(self, rhs) + } + fn checked_div_num(&self, rhs: T) -> PolarsResult { + S::checked_div_num(self, rhs) + } + } + + impl NumOpsDispatchCheckedInner for T + where + T: PolarsIntegerType, + T::Native: CheckedDiv + CheckedDiv + Zero + One, + ChunkedArray: IntoSeries, + { + fn checked_div(lhs: &ChunkedArray, rhs: &Series) -> PolarsResult { + // SAFETY: + // There will be UB if a ChunkedArray is alive with the wrong datatype. + // we now only create the potentially wrong dtype for a short time. + // Note that the physical type correctness is checked! + // The ChunkedArray with the wrong dtype is dropped after this operation + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + + Ok( + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { + (Some(l), Some(r)) => l.checked_div(&r), + _ => None, + }) + .into_series(), + ) + } + } + + impl NumOpsDispatchCheckedInner for Float32Type { + fn checked_div(lhs: &Float32Chunked, rhs: &Series) -> PolarsResult { + // SAFETY: + // see check_div for chunkedarray + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + + let ca: Float32Chunked = + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { + (Some(l), Some(r)) => { + if r.is_zero() { + None + } else { + Some(l / r) + } + }, + _ => None, + }); + Ok(ca.into_series()) + } + } + + impl NumOpsDispatchCheckedInner for Float64Type { + fn checked_div(lhs: &Float64Chunked, rhs: &Series) -> PolarsResult { + // SAFETY: + // see check_div + let rhs = unsafe { lhs.unpack_series_matching_physical_type(rhs) }; + + let ca: Float64Chunked = + arity::binary_elementwise(lhs, rhs, |opt_l, opt_r| match (opt_l, opt_r) { + (Some(l), Some(r)) => { + if r.is_zero() { + None + } else { + Some(l / r) + } + }, + _ => None, + }); + Ok(ca.into_series()) + } + } + + impl NumOpsDispatchChecked for Series { + fn checked_div(&self, rhs: &Series) -> PolarsResult { + let (lhs, rhs) = coerce_lhs_rhs(self, rhs).expect("cannot coerce datatypes"); + lhs.as_ref().as_ref().checked_div(rhs.as_ref()) + } + + fn checked_div_num(&self, rhs: T) -> PolarsResult { + use DataType::*; + let s = self.to_physical_repr(); + + let out = match s.dtype() { + #[cfg(feature = "dtype-u8")] + UInt8 => s + .u8() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_u8().unwrap()))) + .into_series(), + #[cfg(feature = "dtype-i8")] + Int8 => s + .i8() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_i8().unwrap()))) + .into_series(), + #[cfg(feature = "dtype-i16")] + Int16 => s + .i16() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_i16().unwrap()))) + .into_series(), + #[cfg(feature = "dtype-u16")] + UInt16 => s + .u16() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_u16().unwrap()))) + .into_series(), + UInt32 => s + .u32() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_u32().unwrap()))) + .into_series(), + Int32 => s + .i32() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_i32().unwrap()))) + .into_series(), + UInt64 => s + .u64() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_u64().unwrap()))) + .into_series(), + Int64 => s + .i64() + .unwrap() + .apply(|opt_v| opt_v.and_then(|v| v.checked_div(rhs.to_i64().unwrap()))) + .into_series(), + Float32 => s + .f32() + .unwrap() + .apply(|opt_v| { + opt_v.and_then(|v| { + let res = rhs.to_f32().unwrap(); + if res.is_zero() { None } else { Some(v / res) } + }) + }) + .into_series(), + Float64 => s + .f64() + .unwrap() + .apply(|opt_v| { + opt_v.and_then(|v| { + let res = rhs.to_f64().unwrap(); + if res.is_zero() { None } else { Some(v / res) } + }) + }) + .into_series(), + _ => panic!("dtype not yet supported in checked div"), + }; + out.cast(self.dtype()) + } + } +} + +pub fn coerce_lhs_rhs<'a>( + lhs: &'a Series, + rhs: &'a Series, +) -> PolarsResult<(Cow<'a, Series>, Cow<'a, Series>)> { + if let Some(result) = coerce_time_units(lhs, rhs) { + return Ok(result); + } + let (left_dtype, right_dtype) = (lhs.dtype(), rhs.dtype()); + let leaf_super_dtype = try_get_supertype(left_dtype.leaf_dtype(), right_dtype.leaf_dtype())?; + + let mut new_left_dtype = left_dtype.cast_leaf(leaf_super_dtype.clone()); + let mut new_right_dtype = right_dtype.cast_leaf(leaf_super_dtype); + + // Correct the list and array types + // + // This also casts Lists <-> Array. + if left_dtype.is_list() + || right_dtype.is_list() + || left_dtype.is_array() + || right_dtype.is_array() + { + new_left_dtype = try_get_supertype(&new_left_dtype, &new_right_dtype)?; + new_right_dtype = new_left_dtype.clone(); + } + + let left = if lhs.dtype() == &new_left_dtype { + Cow::Borrowed(lhs) + } else { + Cow::Owned(lhs.cast(&new_left_dtype)?) + }; + let right = if rhs.dtype() == &new_right_dtype { + Cow::Borrowed(rhs) + } else { + Cow::Owned(rhs.cast(&new_right_dtype)?) + }; + Ok((left, right)) +} + +// Handle (Date | Datetime) +/- (Duration) | (Duration) +/- (Date | Datetime) | (Duration) +- +// (Duration) +// Time arithmetic is only implemented on the date / datetime so ensure that's on left + +fn coerce_time_units<'a>( + lhs: &'a Series, + rhs: &'a Series, +) -> Option<(Cow<'a, Series>, Cow<'a, Series>)> { + match (lhs.dtype(), rhs.dtype()) { + (DataType::Datetime(lu, t), DataType::Duration(ru)) => { + let units = get_time_units(lu, ru); + let left = if *lu == units { + Cow::Borrowed(lhs) + } else { + Cow::Owned(lhs.cast(&DataType::Datetime(units, t.clone())).ok()?) + }; + let right = if *ru == units { + Cow::Borrowed(rhs) + } else { + Cow::Owned(rhs.cast(&DataType::Duration(units)).ok()?) + }; + Some((left, right)) + }, + // make sure to return Some here, so we don't cast to supertype. + (DataType::Date, DataType::Duration(_)) => Some((Cow::Borrowed(lhs), Cow::Borrowed(rhs))), + (DataType::Duration(lu), DataType::Duration(ru)) => { + let units = get_time_units(lu, ru); + let left = if *lu == units { + Cow::Borrowed(lhs) + } else { + Cow::Owned(lhs.cast(&DataType::Duration(units)).ok()?) + }; + let right = if *ru == units { + Cow::Borrowed(rhs) + } else { + Cow::Owned(rhs.cast(&DataType::Duration(units)).ok()?) + }; + Some((left, right)) + }, + // swap the order + (DataType::Duration(_), DataType::Datetime(_, _)) + | (DataType::Duration(_), DataType::Date) => { + let (right, left) = coerce_time_units(rhs, lhs)?; + Some((left, right)) + }, + _ => None, + } +} + +#[cfg(feature = "dtype-struct")] +pub fn _struct_arithmetic PolarsResult>( + s: &Series, + rhs: &Series, + mut func: F, +) -> PolarsResult { + let s = s.struct_().unwrap(); + let rhs = rhs.struct_().unwrap(); + + let s_fields = s.fields_as_series(); + let rhs_fields = rhs.fields_as_series(); + + match (s_fields.len(), rhs_fields.len()) { + (_, 1) => { + let rhs = &rhs.fields_as_series()[0]; + Ok(s.try_apply_fields(|s| func(s, rhs))?.into_series()) + }, + (1, _) => { + let s = &s.fields_as_series()[0]; + Ok(rhs.try_apply_fields(|rhs| func(s, rhs))?.into_series()) + }, + _ => { + let mut s = Cow::Borrowed(s); + let mut rhs = Cow::Borrowed(rhs); + + match (s.len(), rhs.len()) { + (l, r) if l == r => {}, + (1, _) => s = Cow::Owned(s.new_from_index(0, rhs.len())), + (_, 1) => rhs = Cow::Owned(rhs.new_from_index(0, s.len())), + (l, r) => { + polars_bail!(ComputeError: "Struct arithmetic between different lengths {l} != {r}") + }, + }; + let (s, rhs) = align_chunks_binary(&s, &rhs); + let mut s = s.into_owned(); + + // Expects lengths to be equal. + s.zip_outer_validity(rhs.as_ref()); + + let mut rhs_iter = rhs.fields_as_series().into_iter(); + + Ok(s.try_apply_fields(|s| match rhs_iter.next() { + Some(rhs) => func(s, &rhs), + None => Ok(s.clone()), + })? + .into_series()) + }, + } +} + +fn check_lengths(a: &Series, b: &Series) -> PolarsResult<()> { + match (a.len(), b.len()) { + // broadcasting + (1, _) | (_, 1) => Ok(()), + // equal + (a, b) if a == b => Ok(()), + // unequal + (a, b) => { + polars_bail!(InvalidOperation: "cannot do arithmetic operation on series of different lengths: got {} and {}", a, b) + }, + } +} + +impl Add for &Series { + type Output = PolarsResult; + + fn add(self, rhs: Self) -> Self::Output { + check_lengths(self, rhs)?; + match (self.dtype(), rhs.dtype()) { + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), DataType::Struct(_)) => { + _struct_arithmetic(self, rhs, |a, b| a.add(b)) + }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list::NumericListOp::add().execute(self, rhs) + }, + #[cfg(feature = "dtype-array")] + (DataType::Array(..), _) | (_, DataType::Array(..)) => { + fixed_size_list::NumericFixedSizeListOp::add().execute(self, rhs) + }, + _ => { + let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; + lhs.add_to(rhs.as_ref()) + }, + } + } +} + +impl Sub for &Series { + type Output = PolarsResult; + + fn sub(self, rhs: Self) -> Self::Output { + check_lengths(self, rhs)?; + match (self.dtype(), rhs.dtype()) { + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), DataType::Struct(_)) => { + _struct_arithmetic(self, rhs, |a, b| a.sub(b)) + }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list::NumericListOp::sub().execute(self, rhs) + }, + #[cfg(feature = "dtype-array")] + (DataType::Array(..), _) | (_, DataType::Array(..)) => { + fixed_size_list::NumericFixedSizeListOp::sub().execute(self, rhs) + }, + _ => { + let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; + lhs.subtract(rhs.as_ref()) + }, + } + } +} + +impl Mul for &Series { + type Output = PolarsResult; + + /// ``` + /// # use polars_core::prelude::*; + /// let s: Series = [1, 2, 3].iter().collect(); + /// let out = (&s * &s).unwrap(); + /// ``` + fn mul(self, rhs: Self) -> Self::Output { + check_lengths(self, rhs)?; + + use DataType::*; + match (self.dtype(), rhs.dtype()) { + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_)) => _struct_arithmetic(self, rhs, |a, b| a.mul(b)), + // temporal lh + (Duration(_), _) | (Date, _) | (Datetime(_, _), _) | (Time, _) => self.multiply(rhs), + // temporal rhs + (_, Date) | (_, Datetime(_, _)) | (_, Time) => { + polars_bail!(opq = mul, self.dtype(), rhs.dtype()) + }, + (_, Duration(_)) => { + // swap order + let out = rhs.multiply(self)?; + Ok(out.with_name(self.name().clone())) + }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list::NumericListOp::mul().execute(self, rhs) + }, + #[cfg(feature = "dtype-array")] + (DataType::Array(..), _) | (_, DataType::Array(..)) => { + fixed_size_list::NumericFixedSizeListOp::mul().execute(self, rhs) + }, + _ => { + let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; + lhs.multiply(rhs.as_ref()) + }, + } + } +} + +impl Div for &Series { + type Output = PolarsResult; + + /// ``` + /// # use polars_core::prelude::*; + /// let s: Series = [1, 2, 3].iter().collect(); + /// let out = (&s / &s).unwrap(); + /// ``` + fn div(self, rhs: Self) -> Self::Output { + check_lengths(self, rhs)?; + use DataType::*; + match (self.dtype(), rhs.dtype()) { + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_)) => _struct_arithmetic(self, rhs, |a, b| a.div(b)), + (Duration(_), _) => self.divide(rhs), + (Date, _) + | (Datetime(_, _), _) + | (Time, _) + | (_, Duration(_)) + | (_, Time) + | (_, Date) + | (_, Datetime(_, _)) => polars_bail!(opq = div, self.dtype(), rhs.dtype()), + (DataType::List(_), _) | (_, DataType::List(_)) => { + list::NumericListOp::div().execute(self, rhs) + }, + #[cfg(feature = "dtype-array")] + (DataType::Array(..), _) | (_, DataType::Array(..)) => { + fixed_size_list::NumericFixedSizeListOp::div().execute(self, rhs) + }, + _ => { + let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; + lhs.divide(rhs.as_ref()) + }, + } + } +} + +impl Rem for &Series { + type Output = PolarsResult; + + /// ``` + /// # use polars_core::prelude::*; + /// let s: Series = [1, 2, 3].iter().collect(); + /// let out = (&s / &s).unwrap(); + /// ``` + fn rem(self, rhs: Self) -> Self::Output { + check_lengths(self, rhs)?; + match (self.dtype(), rhs.dtype()) { + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), DataType::Struct(_)) => { + _struct_arithmetic(self, rhs, |a, b| a.rem(b)) + }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + list::NumericListOp::rem().execute(self, rhs) + }, + #[cfg(feature = "dtype-array")] + (DataType::Array(..), _) | (_, DataType::Array(..)) => { + fixed_size_list::NumericFixedSizeListOp::rem().execute(self, rhs) + }, + _ => { + let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; + lhs.remainder(rhs.as_ref()) + }, + } + } +} + +// Series +-/* numbers instead of Series + +fn finish_cast(inp: &Series, out: Series) -> Series { + match inp.dtype() { + #[cfg(feature = "dtype-date")] + DataType::Date => out.into_date(), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => out.into_datetime(*tu, tz.clone()), + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => out.into_duration(*tu), + #[cfg(feature = "dtype-time")] + DataType::Time => out.into_time(), + _ => out, + } +} + +impl Sub for &Series +where + T: Num + NumCast, +{ + type Output = Series; + + fn sub(self, rhs: T) -> Self::Output { + let s = self.to_physical_repr(); + macro_rules! sub { + ($ca:expr) => {{ $ca.sub(rhs).into_series() }}; + } + + let out = downcast_as_macro_arg_physical!(s, sub); + finish_cast(self, out) + } +} + +impl Sub for Series +where + T: Num + NumCast, +{ + type Output = Self; + + fn sub(self, rhs: T) -> Self::Output { + (&self).sub(rhs) + } +} + +impl Add for &Series +where + T: Num + NumCast, +{ + type Output = Series; + + fn add(self, rhs: T) -> Self::Output { + let s = self.to_physical_repr(); + macro_rules! add { + ($ca:expr) => {{ $ca.add(rhs).into_series() }}; + } + let out = downcast_as_macro_arg_physical!(s, add); + finish_cast(self, out) + } +} + +impl Add for Series +where + T: Num + NumCast, +{ + type Output = Self; + + fn add(self, rhs: T) -> Self::Output { + (&self).add(rhs) + } +} + +impl Div for &Series +where + T: Num + NumCast, +{ + type Output = Series; + + fn div(self, rhs: T) -> Self::Output { + let s = self.to_physical_repr(); + macro_rules! div { + ($ca:expr) => {{ $ca.div(rhs).into_series() }}; + } + + let out = downcast_as_macro_arg_physical!(s, div); + finish_cast(self, out) + } +} + +impl Div for Series +where + T: Num + NumCast, +{ + type Output = Self; + + fn div(self, rhs: T) -> Self::Output { + (&self).div(rhs) + } +} + +// TODO: remove this, temporary band-aid. +impl Series { + pub fn wrapping_trunc_div_scalar(&self, rhs: T) -> Self { + let s = self.to_physical_repr(); + macro_rules! div { + ($ca:expr) => {{ + let rhs = NumCast::from(rhs).unwrap(); + $ca.wrapping_trunc_div_scalar(rhs).into_series() + }}; + } + + let out = downcast_as_macro_arg_physical!(s, div); + finish_cast(self, out) + } +} + +impl Mul for &Series +where + T: Num + NumCast, +{ + type Output = Series; + + fn mul(self, rhs: T) -> Self::Output { + let s = self.to_physical_repr(); + macro_rules! mul { + ($ca:expr) => {{ $ca.mul(rhs).into_series() }}; + } + let out = downcast_as_macro_arg_physical!(s, mul); + finish_cast(self, out) + } +} + +impl Mul for Series +where + T: Num + NumCast, +{ + type Output = Self; + + fn mul(self, rhs: T) -> Self::Output { + (&self).mul(rhs) + } +} + +impl Rem for &Series +where + T: Num + NumCast, +{ + type Output = Series; + + fn rem(self, rhs: T) -> Self::Output { + let s = self.to_physical_repr(); + macro_rules! rem { + ($ca:expr) => {{ $ca.rem(rhs).into_series() }}; + } + let out = downcast_as_macro_arg_physical!(s, rem); + finish_cast(self, out) + } +} + +impl Rem for Series +where + T: Num + NumCast, +{ + type Output = Self; + + fn rem(self, rhs: T) -> Self::Output { + (&self).rem(rhs) + } +} + +/// We cannot override the left hand side behaviour. So we create a trait LhsNumOps. +/// This allows for 1.add(&Series) +/// +impl ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + /// Apply lhs - self + #[must_use] + pub fn lhs_sub(&self, lhs: N) -> Self { + let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); + ArithmeticChunked::wrapping_sub_scalar_lhs(lhs, self) + } + + /// Apply lhs / self + #[must_use] + pub fn lhs_div(&self, lhs: N) -> Self { + let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); + ArithmeticChunked::legacy_div_scalar_lhs(lhs, self) + } + + /// Apply lhs % self + #[must_use] + pub fn lhs_rem(&self, lhs: N) -> Self { + let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); + ArithmeticChunked::wrapping_mod_scalar_lhs(lhs, self) + } +} + +pub trait LhsNumOps { + type Output; + + fn add(self, rhs: &Series) -> Self::Output; + fn sub(self, rhs: &Series) -> Self::Output; + fn div(self, rhs: &Series) -> Self::Output; + fn mul(self, rhs: &Series) -> Self::Output; + fn rem(self, rem: &Series) -> Self::Output; +} + +impl LhsNumOps for T +where + T: Num + NumCast, +{ + type Output = Series; + + fn add(self, rhs: &Series) -> Self::Output { + // order doesn't matter, dispatch to rhs + lhs + rhs + self + } + fn sub(self, rhs: &Series) -> Self::Output { + let s = rhs.to_physical_repr(); + macro_rules! sub { + ($rhs:expr) => {{ $rhs.lhs_sub(self).into_series() }}; + } + let out = downcast_as_macro_arg_physical!(s, sub); + + finish_cast(rhs, out) + } + fn div(self, rhs: &Series) -> Self::Output { + let s = rhs.to_physical_repr(); + macro_rules! div { + ($rhs:expr) => {{ $rhs.lhs_div(self).into_series() }}; + } + let out = downcast_as_macro_arg_physical!(s, div); + + finish_cast(rhs, out) + } + fn mul(self, rhs: &Series) -> Self::Output { + // order doesn't matter, dispatch to rhs * lhs + rhs * self + } + fn rem(self, rhs: &Series) -> Self::Output { + let s = rhs.to_physical_repr(); + macro_rules! rem { + ($rhs:expr) => {{ $rhs.lhs_rem(self).into_series() }}; + } + + let out = downcast_as_macro_arg_physical!(s, rem); + + finish_cast(rhs, out) + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + #[allow(clippy::eq_op)] + fn test_arithmetic_series() -> PolarsResult<()> { + // Series +-/* Series + let s = Series::new("foo".into(), [1, 2, 3]); + assert_eq!( + Vec::from((&s * &s)?.i32().unwrap()), + [Some(1), Some(4), Some(9)] + ); + assert_eq!( + Vec::from((&s / &s)?.i32().unwrap()), + [Some(1), Some(1), Some(1)] + ); + assert_eq!( + Vec::from((&s - &s)?.i32().unwrap()), + [Some(0), Some(0), Some(0)] + ); + assert_eq!( + Vec::from((&s + &s)?.i32().unwrap()), + [Some(2), Some(4), Some(6)] + ); + // Series +-/* Number + assert_eq!( + Vec::from((&s + 1).i32().unwrap()), + [Some(2), Some(3), Some(4)] + ); + assert_eq!( + Vec::from((&s - 1).i32().unwrap()), + [Some(0), Some(1), Some(2)] + ); + assert_eq!( + Vec::from((&s * 2).i32().unwrap()), + [Some(2), Some(4), Some(6)] + ); + assert_eq!( + Vec::from((&s / 2).i32().unwrap()), + [Some(0), Some(1), Some(1)] + ); + + // Lhs operations + assert_eq!( + Vec::from((1.add(&s)).i32().unwrap()), + [Some(2), Some(3), Some(4)] + ); + assert_eq!( + Vec::from((1.sub(&s)).i32().unwrap()), + [Some(0), Some(-1), Some(-2)] + ); + assert_eq!( + Vec::from((1.div(&s)).i32().unwrap()), + [Some(1), Some(0), Some(0)] + ); + assert_eq!( + Vec::from((1.mul(&s)).i32().unwrap()), + [Some(1), Some(2), Some(3)] + ); + assert_eq!( + Vec::from((1.rem(&s)).i32().unwrap()), + [Some(0), Some(1), Some(1)] + ); + + assert_eq!((&s * &s)?.name().as_str(), "foo"); + assert_eq!((&s * 1).name().as_str(), "foo"); + assert_eq!((1.div(&s)).name().as_str(), "foo"); + + Ok(()) + } + + #[test] + #[cfg(feature = "checked_arithmetic")] + fn test_checked_div() { + let s = Series::new("foo".into(), [1i32, 0, 1]); + let out = s.checked_div(&s).unwrap(); + assert_eq!(Vec::from(out.i32().unwrap()), &[Some(1), None, Some(1)]); + let out = s.checked_div_num(0).unwrap(); + assert_eq!(Vec::from(out.i32().unwrap()), &[None, None, None]); + + let s_f32 = Series::new("float32".into(), [1.0f32, 0.0, 1.0]); + let out = s_f32.checked_div(&s_f32).unwrap(); + assert_eq!( + Vec::from(out.f32().unwrap()), + &[Some(1.0f32), None, Some(1.0f32)] + ); + let out = s_f32.checked_div_num(0.0f32).unwrap(); + assert_eq!(Vec::from(out.f32().unwrap()), &[None, None, None]); + + let s_f64 = Series::new("float64".into(), [1.0f64, 0.0, 1.0]); + let out = s_f64.checked_div(&s_f64).unwrap(); + assert_eq!( + Vec::from(out.f64().unwrap()), + &[Some(1.0f64), None, Some(1.0f64)] + ); + let out = s_f64.checked_div_num(0.0f64).unwrap(); + assert_eq!(Vec::from(out.f64().unwrap()), &[None, None, None]); + } +} diff --git a/crates/polars-core/src/series/arithmetic/fixed_size_list.rs b/crates/polars-core/src/series/arithmetic/fixed_size_list.rs new file mode 100644 index 000000000000..70e57fc534b7 --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/fixed_size_list.rs @@ -0,0 +1,825 @@ +use polars_error::{PolarsResult, feature_gated}; + +use super::list_utils::NumericOp; +use super::{ArrayChunked, FixedSizeListType, IntoSeries, NumOpsDispatchInner, Series}; + +impl NumOpsDispatchInner for FixedSizeListType { + fn add_to(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult { + NumericFixedSizeListOp::add().execute(&lhs.clone().into_series(), rhs) + } + + fn subtract(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult { + NumericFixedSizeListOp::sub().execute(&lhs.clone().into_series(), rhs) + } + + fn multiply(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult { + NumericFixedSizeListOp::mul().execute(&lhs.clone().into_series(), rhs) + } + + fn divide(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult { + NumericFixedSizeListOp::div().execute(&lhs.clone().into_series(), rhs) + } + + fn remainder(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult { + NumericFixedSizeListOp::rem().execute(&lhs.clone().into_series(), rhs) + } +} + +#[derive(Clone)] +pub struct NumericFixedSizeListOp(NumericOp); + +impl NumericFixedSizeListOp { + pub fn add() -> Self { + Self(NumericOp::Add) + } + + pub fn sub() -> Self { + Self(NumericOp::Sub) + } + + pub fn mul() -> Self { + Self(NumericOp::Mul) + } + + pub fn div() -> Self { + Self(NumericOp::Div) + } + + pub fn rem() -> Self { + Self(NumericOp::Rem) + } + + pub fn floor_div() -> Self { + Self(NumericOp::FloorDiv) + } +} + +impl NumericFixedSizeListOp { + #[cfg_attr(not(feature = "array_arithmetic"), allow(unused))] + pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult { + feature_gated!("array_arithmetic", { + NumericFixedSizeListOpHelper::execute_op(self.clone(), lhs.rechunk(), rhs.rechunk()) + .map(|x| x.into_series()) + }) + } +} + +#[cfg(feature = "array_arithmetic")] +use inner::NumericFixedSizeListOpHelper; + +#[cfg(feature = "array_arithmetic")] +mod inner { + use arrow::bitmap::{Bitmap, BitmapBuilder}; + use arrow::compute::utils::combine_validities_and; + use fixed_size_list::NumericFixedSizeListOp; + use list_utils::with_match_pl_num_arith; + use num_traits::Zero; + use polars_compute::arithmetic::pl_num::PlNumArithmetic; + use polars_utils::float::IsFloat; + + use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp}; + use super::super::*; + + /// Utility to perform a binary operation between the primitive values of + /// 2 columns, where at least one of the columns is a `ArrayChunked` type. + pub(super) struct NumericFixedSizeListOpHelper { + op: NumericFixedSizeListOp, + output_name: PlSmallStr, + /// We are just re-using the enum used for list arithmetic. + op_apply_type: BinaryOpApplyType, + broadcast: Broadcast, + /// Stride of the leaf array + stride: usize, + /// Widths at every level + output_widths: Vec, + output_dtype: DataType, + output_primitive_dtype: DataType, + /// Length of the outermost level + output_len: usize, + data_lhs: (Series, Vec>), + data_rhs: (Series, Vec>), + swapped: bool, + } + + /// This lets us separate some logic into `new()` to reduce the amount of + /// monomorphized code. + impl NumericFixedSizeListOpHelper { + /// Checks that: + /// * Dtypes are compatible: + /// * list<->primitive | primitive<->list + /// * list<->list both contain primitives (e.g. List) + /// * Primitive dtypes match + /// * Lengths are compatible: + /// * 1<->n | n<->1 + /// * n<->n + /// * Both sides have at least 1 non-NULL outer row. + /// + /// This returns an `Either` which may contain the final result to simplify + /// the implementation. + pub(super) fn execute_op( + op: NumericFixedSizeListOp, + lhs: Series, + rhs: Series, + ) -> PolarsResult { + assert_eq!(lhs.chunks().len(), 1); + assert_eq!(rhs.chunks().len(), 1); + + let dtype_lhs = lhs.dtype(); + let dtype_rhs = rhs.dtype(); + + let prim_dtype_lhs = dtype_lhs.leaf_dtype(); + let prim_dtype_rhs = dtype_rhs.leaf_dtype(); + + // + // Check leaf dtypes + // + + if !(prim_dtype_lhs.is_supported_list_arithmetic_input() + && prim_dtype_rhs.is_supported_list_arithmetic_input()) + { + polars_bail!( + ComputeError: "cannot {} non-numeric inner dtypes: (left: {}, right: {})", + op.0.name(), prim_dtype_lhs, prim_dtype_rhs + ) + } + + let output_primitive_dtype = + op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + fn is_array_type_at_all_levels(dtype: &DataType) -> bool { + match dtype { + DataType::Array(inner, ..) => is_array_type_at_all_levels(inner), + dt if dt.is_supported_list_arithmetic_input() => true, + _ => false, + } + } + + fn array_stride_and_widths(dtype: &DataType, widths: &mut Vec) -> usize { + if let DataType::Array(inner, size_inner) = dtype { + widths.push(*size_inner); + *size_inner * array_stride_and_widths(inner.as_ref(), widths) + } else { + 1 + } + } + + // + // Get broadcasting information and output length + // + + let len_lhs = lhs.len(); + let len_rhs = rhs.len(); + + let (broadcast, output_len) = match (len_lhs, len_rhs) { + (l, r) if l == r => (Broadcast::NoBroadcast, l), + (1, v) => (Broadcast::Left, v), + (v, 1) => (Broadcast::Right, v), + (l, r) => polars_bail!( + ShapeMismatch: + "cannot {} two columns of differing lengths: {} != {}", + op.0.name(), l, r + ), + }; + + // + // Get validities for array levels + // + + fn push_array_validities_recursive(s: &Series, out: &mut Vec>) { + let mut opt_arr = s.array().ok().map(|x| { + assert_eq!(x.chunks().len(), 1); + x.downcast_get(0).unwrap() + }); + + while let Some(arr) = opt_arr { + // Push none if all-valid, this can potentially save some `repeat_bitmap()` + // materializations on broadcasting paths. + out.push(arr.validity().filter(|x| x.unset_bits() > 0).cloned()); + opt_arr = arr.values().as_any().downcast_ref::(); + } + } + + let mut array_validities_lhs = vec![]; + let mut array_validities_rhs = vec![]; + + push_array_validities_recursive(&lhs, &mut array_validities_lhs); + push_array_validities_recursive(&rhs, &mut array_validities_rhs); + + let op_err_msg = |err_reason: &str| { + polars_err!( + InvalidOperation: + "cannot {} columns: {}: (left: {}, right: {})", + op.0.name(), err_reason, dtype_lhs, dtype_rhs, + ) + }; + + let ensure_array_type_at_all_levels = |dtype: &DataType| { + if !is_array_type_at_all_levels(dtype) { + Err(op_err_msg("dtype was not array on all nesting levels")) + } else { + Ok(()) + } + }; + + // + // Check full dtypes and get output widths + // + + let mut output_widths = vec![]; + + let (op_apply_type, stride, output_dtype) = match (dtype_lhs, dtype_rhs) { + (dtype_lhs @ DataType::Array(..), dtype_rhs @ DataType::Array(..)) => { + // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user + // directly adds 2 series together it bypasses the DSL. + // This is currently duplicated code and should be replaced one day with an assert after Series ops get + // checked properly. + + if dtype_lhs.cast_leaf(output_primitive_dtype.clone()) + != dtype_rhs.cast_leaf(output_primitive_dtype.clone()) + { + return Err(op_err_msg("differing dtypes")); + }; + + // We only check dtype_lhs since we already checked dtype_lhs == dtype_rhs + ensure_array_type_at_all_levels(dtype_lhs)?; + + let stride = array_stride_and_widths(dtype_lhs, &mut output_widths); + + // For array<->array without broadcasting we return early here to avoid the rest + // of the setup code and dispatch layers. + if let Broadcast::NoBroadcast = broadcast { + let out = op.0.apply_series( + &lhs.get_leaf_array().cast(&output_primitive_dtype)?, + &rhs.get_leaf_array().cast(&output_primitive_dtype)?, + ); + + return Ok(finish_array_to_array_no_broadcast( + lhs.name().clone(), + &output_widths, + output_len, + &array_validities_lhs, + &array_validities_rhs, + out, + )); + } + + (BinaryOpApplyType::ListToList, stride, dtype_lhs) + }, + (array_dtype @ DataType::Array(..), x) + if x.is_supported_list_arithmetic_input() => + { + ensure_array_type_at_all_levels(array_dtype)?; + + let stride = array_stride_and_widths(array_dtype, &mut output_widths); + (BinaryOpApplyType::ListToPrimitive, stride, array_dtype) + }, + (x, array_dtype @ DataType::Array(..)) + if x.is_supported_list_arithmetic_input() => + { + ensure_array_type_at_all_levels(array_dtype)?; + + let stride = array_stride_and_widths(array_dtype, &mut output_widths); + (BinaryOpApplyType::PrimitiveToList, stride, array_dtype) + }, + (l, r) => polars_bail!( + InvalidOperation: + "cannot {} dtypes: {} != {}", + op.0.name(), l, r, + ), + }; + + let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone()); + + assert!(!output_widths.is_empty()); + + if cfg!(debug_assertions) { + match (array_validities_lhs.len(), array_validities_rhs.len()) { + (l, r) if l == output_widths.len() && l == r && l > 0 => {}, + (v, 0) | (0, v) if v == output_widths.len() => {}, + _ => panic!(), // One side should have been an array. + } + } + + if output_len == 0 + || (matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive + ) && lhs.rechunk_validity().is_some_and(|x| x.set_bits() == 0)) + || (matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList + ) && rhs.rechunk_validity().is_some_and(|x| x.set_bits() == 0)) + { + let DataType::Array(inner_dtype, width) = output_dtype else { + unreachable!() + }; + + Ok(ArrayChunked::full_null_with_dtype( + lhs.name().clone(), + output_len, + inner_dtype.as_ref(), + width, + )) + } else { + Self { + op, + output_name: lhs.name().clone(), + op_apply_type, + broadcast, + stride, + output_widths, + output_dtype, + output_primitive_dtype, + output_len, + data_lhs: (lhs, array_validities_lhs), + data_rhs: (rhs, array_validities_rhs), + swapped: false, + } + .finish() + } + } + + pub(super) fn finish(mut self) -> PolarsResult { + // We have physical codepaths for a subset of the possible combinations of broadcasting and + // column types. The remaining combinations are handled by dispatching to the physical + // codepaths after operand swapping. + // + // # Physical impl table + // Legend + // * | N | // impl "N" + // * | [N] | // dispatches to impl "N" + // + // | L | N | R | // Broadcast (L)eft, (N)oBroadcast, (R)ight + // ListToList | [1] | 0 | 1 | + // ListToPrimitive | 2 | 3 | 4 | + // PrimitiveToList | [4] | [3] | [2] | + + self.swapped = true; + + match (&self.op_apply_type, &self.broadcast) { + // Mostly the same as ListNumericOp, however with fixed size list we also have + // (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) as a physical impl. + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => unreachable!(), // We return earlier for this + (BinaryOpApplyType::ListToList, Broadcast::Right) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToList, Broadcast::Left) => { + self.broadcast = Broadcast::Right; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::Left; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + + (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::Right; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + } + } + + fn _finish_impl_dispatch(&mut self) -> PolarsResult { + let output_dtype = self.output_dtype.clone(); + let output_len = self.output_len; + + let prim_lhs = self + .data_lhs + .0 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + let prim_rhs = self + .data_rhs + .0 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + + debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype()); + let prim_dtype = prim_lhs.dtype(); + debug_assert_eq!(prim_dtype, &self.output_primitive_dtype); + + // Safety: Leaf dtypes have been checked to be numeric by `try_new()` + let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| { + self._finish_impl::<$T>(prim_lhs, prim_rhs) + }); + + debug_assert_eq!(out.dtype(), &output_dtype); + assert_eq!(out.len(), output_len); + + Ok(out) + } + + /// Internal use only - contains physical impls. + fn _finish_impl( + &mut self, + prim_s_lhs: Series, + prim_s_rhs: Series, + ) -> ArrayChunked + where + T::Native: PlNumArithmetic, + PrimitiveArray: + polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + let mut arr_lhs = { + let ca: &ChunkedArray = prim_s_lhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + let mut arr_rhs = { + let ca: &ChunkedArray = prim_s_rhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + self.op.0.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ); + + match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::Right) => { + let mut out_vec: Vec = + Vec::with_capacity(self.output_len * self.stride); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + let stride = self.stride; + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + unsafe { + for outer_idx in 0..self.output_len { + for inner_idx in 0..stride { + let l = arr_lhs.value_unchecked(stride * outer_idx + inner_idx); + let r = arr_rhs.value_unchecked(inner_idx); + + *out_ptr.add(stride * outer_idx + inner_idx) = $OP(l, r); + } + } + } + }); + + unsafe { out_vec.set_len(self.output_len * self.stride) }; + + let leaf_validity = combine_validities_and( + arr_lhs.validity(), + arr_rhs + .validity() + .map(|x| repeat_bitmap(x, self.output_len)) + .as_ref(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (_, validities_lhs) = std::mem::take(&mut self.data_lhs); + let (_, mut validities_rhs) = std::mem::take(&mut self.data_rhs); + + for v in validities_rhs.iter_mut() { + if let Some(v) = v.as_mut() { + *v = repeat_bitmap(v, self.output_len); + } + } + + finish_array_to_array_no_broadcast( + std::mem::take(&mut self.output_name), + &self.output_widths, + self.output_len, + &validities_lhs, + &validities_rhs, + Box::new(arr), + ) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => { + let mut out_vec: Vec = + Vec::with_capacity(self.output_len * self.stride); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + let stride = self.stride; + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + unsafe { + for outer_idx in 0..self.output_len { + let r = arr_rhs.value_unchecked(outer_idx); + + for inner_idx in 0..stride { + let l = arr_lhs.value_unchecked(inner_idx); + + *out_ptr.add(stride * outer_idx + inner_idx) = $OP(l, r); + } + } + } + }); + + unsafe { out_vec.set_len(self.output_len * self.stride) }; + + let leaf_validity = combine_validities_array_to_primitive_no_broadcast( + arr_lhs + .validity() + .map(|x| repeat_bitmap(x, self.output_len)) + .as_ref(), + arr_rhs.validity(), + self.stride, + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (_, mut validities) = std::mem::take(&mut self.data_lhs); + + for v in validities.iter_mut() { + if let Some(v) = v.as_mut() { + *v = repeat_bitmap(v, self.output_len); + } + } + + finish_with_level_validities( + std::mem::take(&mut self.output_name), + &self.output_widths, + self.output_len, + &validities, + Box::new(arr), + ) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { + let mut out_vec: Vec = + Vec::with_capacity(self.output_len * self.stride); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + let stride = self.stride; + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + unsafe { + for outer_idx in 0..self.output_len { + let r = arr_rhs.value_unchecked(outer_idx); + + for inner_idx in 0..stride { + let idx = stride * outer_idx + inner_idx; + let l = arr_lhs.value_unchecked(idx); + + *out_ptr.add(idx) = $OP(l, r); + } + } + } + }); + + unsafe { out_vec.set_len(self.output_len * self.stride) }; + + let leaf_validity = combine_validities_array_to_primitive_no_broadcast( + arr_lhs.validity(), + arr_rhs.validity(), + self.stride, + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (_, validities) = std::mem::take(&mut self.data_lhs); + + finish_with_level_validities( + std::mem::take(&mut self.output_name), + &self.output_widths, + self.output_len, + &validities, + Box::new(arr), + ) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + assert_eq!(arr_rhs.len(), 1); + + let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else { + // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL. + let (_, validities) = std::mem::take(&mut self.data_lhs); + return finish_with_level_validities( + std::mem::take(&mut self.output_name), + &self.output_widths, + self.output_len, + &validities, + Box::new( + arr_lhs.clone().with_validity(Some(Bitmap::new_with_value( + false, + arr_lhs.len(), + ))), + ), + ); + }; + + let arr = self + .op + .0 + .apply_array_to_scalar::(arr_lhs, r, self.swapped); + + let (_, validities) = std::mem::take(&mut self.data_lhs); + + finish_with_level_validities( + std::mem::take(&mut self.output_name), + &self.output_widths, + self.output_len, + &validities, + Box::new(arr), + ) + }, + v @ (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) + | v @ (BinaryOpApplyType::ListToList, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + if cfg!(debug_assertions) { + panic!("operation was not re-written: {:?}", v) + } else { + unreachable!() + } + }, + } + } + } + + /// Build the result of an array<->array operation. + #[inline(never)] + fn finish_array_to_array_no_broadcast( + output_name: PlSmallStr, + widths: &[usize], + outer_len: usize, + validities_lhs: &[Option], + validities_rhs: &[Option], + output_leaf_array: Box, + ) -> ArrayChunked { + assert_eq!( + [widths.len(), validities_lhs.len(), validities_rhs.len()], + [widths.len(); 3] + ); + + let mut builder = FixedSizeListLevelBuilder::new(outer_len, widths); + + let validities_iter = validities_lhs + .iter() + .zip(validities_rhs) + .map(|(l, r)| combine_validities_and(l.as_ref(), r.as_ref())); + // `.rev()` - we build this from the inner level. + let mut iter = widths.iter().zip(validities_iter).rev(); + + let mut out = { + let (width, opt_validity) = iter.next().unwrap(); + builder.build_level(*width, opt_validity, output_leaf_array) + }; + + for (width, opt_validity) in iter { + out = builder.build_level(*width, opt_validity, Box::new(out)) + } + + ArrayChunked::with_chunk(output_name, out) + } + + /// Used when we are operating between array<->primitive, as in that case we only need the + /// validities from the array side. + #[inline(never)] + fn finish_with_level_validities( + output_name: PlSmallStr, + widths: &[usize], + outer_len: usize, + validities: &[Option], + output_leaf_array: Box, + ) -> ArrayChunked { + assert_eq!(widths.len(), validities.len()); + + let mut builder = FixedSizeListLevelBuilder::new(outer_len, widths); + + let validities_iter = validities.iter().cloned(); + // `.rev()` - we build this from the inner level. + let mut iter = widths.iter().zip(validities_iter).rev(); + + let mut out = { + let (width, opt_validity) = iter.next().unwrap(); + builder.build_level(*width, opt_validity, output_leaf_array) + }; + + for (width, opt_validity) in iter { + out = builder.build_level(*width, opt_validity, Box::new(out)) + } + + ArrayChunked::with_chunk(output_name, out) + } + + /// ```text + /// array [x, x, x, x, ..] (stride 2) + /// | / | / + /// |/ |/ + /// primitive [x, x, ..] + /// ``` + #[inline(never)] + fn combine_validities_array_to_primitive_no_broadcast( + array_leaf_validity: Option<&Bitmap>, + primitive_validity: Option<&Bitmap>, + stride: usize, + ) -> Option { + match (array_leaf_validity, primitive_validity) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + // Materialize a full-true validity to re-use the codepath, as we still + // need to spread the bits from the RHS to the correct positions. + (None, Some(v)) => Some((Bitmap::new_with_value(true, stride * v.len()).make_mut(), v)), + (None, None) => None, + } + .map(|(mut validity_out, primitive_validity)| { + assert_eq!(validity_out.len(), stride * primitive_validity.len()); + + unsafe { + for outer_idx in 0..primitive_validity.len() { + let r = primitive_validity.get_bit_unchecked(outer_idx); + + for inner_idx in 0..stride { + let idx = stride * outer_idx + inner_idx; + let l = validity_out.get_unchecked(idx); + + validity_out.set_unchecked(idx, l & r); + } + } + } + + validity_out.freeze() + }) + } + + /// Returns `n_repeats` concatenated copies of the bitmap. + #[inline(never)] + fn repeat_bitmap(bitmap: &Bitmap, n_repeats: usize) -> Bitmap { + let mut out = BitmapBuilder::with_capacity(bitmap.len() * n_repeats); + + for _ in 0..n_repeats { + for bit in bitmap.iter() { + unsafe { out.push_unchecked(bit) } + } + } + + out.freeze() + } + + struct FixedSizeListLevelBuilder { + heights: as IntoIterator>::IntoIter, + } + + impl FixedSizeListLevelBuilder { + fn new(outer_len: usize, widths: &[usize]) -> Self { + let mut current_height = outer_len; + // We need to calculate heights here like this rather than dividing the stride because + // there can be 0-width arrays. + let mut heights = Vec::with_capacity(widths.len()); + + heights.push(current_height); + heights.extend(widths.iter().take(widths.len() - 1).map(|width| { + current_height *= *width; + current_height + })); + + Self { + heights: heights.into_iter(), + } + } + } + + impl FixedSizeListLevelBuilder { + fn build_level( + &mut self, + width: usize, + opt_validity: Option, + inner_array: Box, + ) -> FixedSizeListArray { + let level_height = self.heights.next_back().unwrap(); + assert_eq!(inner_array.len(), level_height * width); + + FixedSizeListArray::new( + ArrowDataType::FixedSizeList( + Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + inner_array.dtype().clone(), + // is_nullable, we always set true otherwise the Eq kernels would panic + // when they assert == on the arrow `Field` + true, + )), + width, + ), + level_height, + inner_array, + opt_validity, + ) + } + } +} diff --git a/crates/polars-core/src/series/arithmetic/list.rs b/crates/polars-core/src/series/arithmetic/list.rs new file mode 100644 index 000000000000..f8608dce2f5f --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/list.rs @@ -0,0 +1,938 @@ +//! Allow arithmetic operations for ListChunked. +//! use polars_error::{feature_gated, PolarsResult}; + +use polars_error::{PolarsResult, feature_gated}; + +use super::list_utils::NumericOp; +use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series}; + +impl NumOpsDispatchInner for ListType { + fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::add().execute(&lhs.clone().into_series(), rhs) + } + + fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::sub().execute(&lhs.clone().into_series(), rhs) + } + + fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::mul().execute(&lhs.clone().into_series(), rhs) + } + + fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::div().execute(&lhs.clone().into_series(), rhs) + } + + fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + NumericListOp::rem().execute(&lhs.clone().into_series(), rhs) + } +} + +#[derive(Clone)] +pub struct NumericListOp(NumericOp); + +impl NumericListOp { + pub fn add() -> Self { + Self(NumericOp::Add) + } + + pub fn sub() -> Self { + Self(NumericOp::Sub) + } + + pub fn mul() -> Self { + Self(NumericOp::Mul) + } + + pub fn div() -> Self { + Self(NumericOp::Div) + } + + pub fn rem() -> Self { + Self(NumericOp::Rem) + } + + pub fn floor_div() -> Self { + Self(NumericOp::FloorDiv) + } +} + +impl NumericListOp { + #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))] + pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult { + feature_gated!("list_arithmetic", { + use either::Either; + + // `trim_to_normalized_offsets` ensures we don't perform excessive + // memory allocation / compute on memory regions that have been + // sliced out. + let lhs = lhs.list_rechunk_and_trim_to_normalized_offsets(); + let rhs = rhs.list_rechunk_and_trim_to_normalized_offsets(); + + let binary_op_exec = match ListNumericOpHelper::try_new( + self.clone(), + lhs.name().clone(), + lhs.dtype(), + rhs.dtype(), + lhs.len(), + rhs.len(), + { + let (a, b) = lhs.list_offsets_and_validities_recursive(); + debug_assert!(a.iter().all(|x| *x.first() as usize == 0)); + (a, b, lhs.clone()) + }, + { + let (a, b) = rhs.list_offsets_and_validities_recursive(); + debug_assert!(a.iter().all(|x| *x.first() as usize == 0)); + (a, b, rhs.clone()) + }, + lhs.rechunk_validity(), + rhs.rechunk_validity(), + )? { + Either::Left(v) => v, + Either::Right(ca) => return Ok(ca.into_series()), + }; + + Ok(binary_op_exec.finish()?.into_series()) + }) + } +} + +#[cfg(feature = "list_arithmetic")] +use inner::ListNumericOpHelper; + +#[cfg(feature = "list_arithmetic")] +mod inner { + use arrow::bitmap::Bitmap; + use arrow::compute::utils::combine_validities_and; + use arrow::offset::OffsetsBuffer; + use either::Either; + use list_utils::with_match_pl_num_arith; + use num_traits::Zero; + use polars_compute::arithmetic::pl_num::PlNumArithmetic; + use polars_utils::float::IsFloat; + + use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp}; + use super::super::*; + + /// Utility to perform a binary operation between the primitive values of + /// 2 columns, where at least one of the columns is a `ListChunked` type. + pub(super) struct ListNumericOpHelper { + op: NumericListOp, + output_name: PlSmallStr, + op_apply_type: BinaryOpApplyType, + broadcast: Broadcast, + output_dtype: DataType, + output_primitive_dtype: DataType, + output_len: usize, + /// Outer validity of the result, we always materialize this to reduce the + /// amount of code paths we need. + outer_validity: Bitmap, + // The series are stored as they are used for list broadcasting. + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + list_to_prim_lhs: Option<(Box, usize)>, + swapped: bool, + } + + /// This lets us separate some logic into `new()` to reduce the amount of + /// monomorphized code. + impl ListNumericOpHelper { + /// Checks that: + /// * Dtypes are compatible: + /// * list<->primitive | primitive<->list + /// * list<->list both contain primitives (e.g. List) + /// * Primitive dtypes match + /// * Lengths are compatible: + /// * 1<->n | n<->1 + /// * n<->n + /// * Both sides have at least 1 non-NULL outer row. + /// + /// Does not check: + /// * Whether the offsets are aligned for list<->list, this will be checked during execution. + /// + /// This returns an `Either` which may contain the final result to simplify + /// the implementation. + #[allow(clippy::too_many_arguments)] + pub(super) fn try_new( + op: NumericListOp, + output_name: PlSmallStr, + dtype_lhs: &DataType, + dtype_rhs: &DataType, + len_lhs: usize, + len_rhs: usize, + data_lhs: (Vec>, Vec>, Series), + data_rhs: (Vec>, Vec>, Series), + validity_lhs: Option, + validity_rhs: Option, + ) -> PolarsResult> { + let prim_dtype_lhs = dtype_lhs.leaf_dtype(); + let prim_dtype_rhs = dtype_rhs.leaf_dtype(); + + let output_primitive_dtype = + op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + fn is_list_type_at_all_levels(dtype: &DataType) -> bool { + match dtype { + DataType::List(inner) => is_list_type_at_all_levels(inner), + dt if dt.is_supported_list_arithmetic_input() => true, + _ => false, + } + } + + let op_err_msg = |err_reason: &str| { + polars_err!( + InvalidOperation: + "cannot {} columns: {}: (left: {}, right: {})", + op.0.name(), err_reason, dtype_lhs, dtype_rhs, + ) + }; + + let ensure_list_type_at_all_levels = |dtype: &DataType| { + if !is_list_type_at_all_levels(dtype) { + Err(op_err_msg("dtype was not list on all nesting levels")) + } else { + Ok(()) + } + }; + + let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) { + (l @ DataType::List(a), r @ DataType::List(b)) => { + // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user + // directly adds 2 series together it bypasses the DSL. + // This is currently duplicated code and should be replaced one day with an assert after Series ops get + // checked properly. + if ![a, b] + .into_iter() + .all(|x| x.is_supported_list_arithmetic_input()) + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op.0.name(), l, r, + ); + } + (BinaryOpApplyType::ListToList, l) + }, + (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => { + ensure_list_type_at_all_levels(list_dtype)?; + (BinaryOpApplyType::ListToPrimitive, list_dtype) + }, + (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => { + ensure_list_type_at_all_levels(list_dtype)?; + (BinaryOpApplyType::PrimitiveToList, list_dtype) + }, + (l, r) => polars_bail!( + InvalidOperation: + "{} operation not supported for dtypes: {} != {}", + op.0.name(), l, r, + ), + }; + + let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone()); + + let (broadcast, output_len) = match (len_lhs, len_rhs) { + (l, r) if l == r => (Broadcast::NoBroadcast, l), + (1, v) => (Broadcast::Left, v), + (v, 1) => (Broadcast::Right, v), + (l, r) => polars_bail!( + ShapeMismatch: + "cannot {} two columns of differing lengths: {} != {}", + op.0.name(), l, r + ), + }; + + let DataType::List(output_inner_dtype) = &output_dtype else { + unreachable!() + }; + + // # NULL semantics + // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]] + // * Essentially as if the NULL primitive was added to every primitive in the row of the list column. + // * NULL (List[Int64]) + 1 (Int64) => NULL + // * NULL (List[Int64]) + [1] (List[Int64]) => NULL + + if output_len == 0 + || (matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive + ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0)) + || (matches!( + &op_apply_type, + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList + ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0)) + { + return Ok(Either::Right(ListChunked::full_null_with_dtype( + output_name.clone(), + output_len, + output_inner_dtype.as_ref(), + ))); + } + + // At this point: + // * All unit length list columns have a valid outer value. + + // The outer validity is just the validity of any non-broadcasting lists. + let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) { + // Both lists with same length, we combine the validity. + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => { + combine_validities_and(l.as_ref(), r.as_ref()) + }, + // Match all other combinations that have non-broadcasting lists. + ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive, + Broadcast::NoBroadcast | Broadcast::Right, + v, + _, + ) + | ( + BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList, + Broadcast::NoBroadcast | Broadcast::Left, + _, + v, + ) => v, + _ => None, + } + .unwrap_or_else(|| Bitmap::new_with_value(true, output_len)); + + Ok(Either::Left(Self { + op, + output_name, + op_apply_type, + broadcast, + output_dtype: output_dtype.clone(), + output_primitive_dtype, + output_len, + outer_validity, + data_lhs, + data_rhs, + list_to_prim_lhs: None, + swapped: false, + })) + } + + pub(super) fn finish(mut self) -> PolarsResult { + // We have physical codepaths for a subset of the possible combinations of broadcasting and + // column types. The remaining combinations are handled by dispatching to the physical + // codepaths after operand swapping and/or materialized broadcasting. + // + // # Physical impl table + // Legend + // * | N | // impl "N" + // * | [N] | // dispatches to impl "N" + // + // | L | N | R | // Broadcast (L)eft, (N)oBroadcast, (R)ight + // ListToList | [1] | 0 | 1 | + // ListToPrimitive | [2] | 2 | 3 | // list broadcasting just materializes and dispatches to NoBroadcast + // PrimitiveToList | [3] | [2] | [2] | + + self.swapped = true; + + match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToList, Broadcast::Right) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToList, Broadcast::Left) => { + self.broadcast = Broadcast::Right; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => { + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_lhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.broadcast = Broadcast::NoBroadcast; + + // This does not swap! We are just dispatching to `NoBroadcast` + // after materializing the broadcasted list array. + self.swapped = false; + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => { + // We materialize the list columns with `new_from_index`, as otherwise we'd have to + // implement logic that broadcasts the offsets and validities across multiple levels + // of nesting. But we will re-use the materialized memory to store the result. + + self.list_to_prim_lhs + .replace(Self::materialize_broadcasted_list( + &mut self.data_rhs, + self.output_len, + &self.output_primitive_dtype, + )); + + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::NoBroadcast; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => { + self.op_apply_type = BinaryOpApplyType::ListToPrimitive; + self.broadcast = Broadcast::Right; + + std::mem::swap(&mut self.data_lhs, &mut self.data_rhs); + self._finish_impl_dispatch() + }, + } + } + + fn _finish_impl_dispatch(&mut self) -> PolarsResult { + let output_dtype = self.output_dtype.clone(); + let output_len = self.output_len; + + let prim_lhs = self + .data_lhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + let prim_rhs = self + .data_rhs + .2 + .get_leaf_array() + .cast(&self.output_primitive_dtype)? + .rechunk(); + + debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype()); + let prim_dtype = prim_lhs.dtype(); + debug_assert_eq!(prim_dtype, &self.output_primitive_dtype); + + // Safety: Leaf dtypes have been checked to be numeric by `try_new()` + let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| { + self._finish_impl::<$T>(prim_lhs, prim_rhs) + })?; + + debug_assert_eq!(out.dtype(), &output_dtype); + assert_eq!(out.len(), output_len); + + Ok(out) + } + + /// Internal use only - contains physical impls. + fn _finish_impl( + &mut self, + prim_s_lhs: Series, + prim_s_rhs: Series, + ) -> PolarsResult + where + T::Native: PlNumArithmetic, + PrimitiveArray: + polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + #[inline(never)] + fn check_mismatch_pos( + mismatch_pos: usize, + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + ) -> PolarsResult<()> { + if mismatch_pos < offsets_lhs.len_proxy() { + // RHS could be broadcasted + let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 { + 0 + } else { + mismatch_pos + }); + polars_bail!( + ShapeMismatch: + "list lengths differed at index {}: {} != {}", + mismatch_pos, + offsets_lhs.length_at(mismatch_pos), len_r + ) + } + Ok(()) + } + + let mut arr_lhs = { + let ca: &ChunkedArray = prim_s_lhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + let mut arr_rhs = { + let ca: &ChunkedArray = prim_s_rhs.as_ref().as_ref(); + assert_eq!(ca.chunks().len(), 1); + ca.downcast_get(0).unwrap().clone() + }; + + match (&self.op_apply_type, &self.broadcast) { + // We skip for this because it dispatches to `ArithmeticKernel`, which handles the + // validities for us. + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {}, + _ if self.list_to_prim_lhs.is_none() => { + self.op.0.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { + // `self.list_to_prim_lhs` is `Some(_)`, this is handled later. + }, + _ => unreachable!(), + } + + // + // General notes + // * Lists can be: + // * Sliced, in which case the primitive/leaf array needs to be indexed starting from an + // offset instead of 0. + // * Masked, in which case the masked rows are permitted to have non-matching widths. + // + + let out = match (&self.op_apply_type, &self.broadcast) { + (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; + + assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy()); + + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + + // Counter that stops being incremented at the first row position with mismatching + // list lengths. + let mut mismatch_pos = 0; + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + .enumerate() + { + if + (mismatch_pos == i) + & ( + (lhs_len == rhs_len) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + ) + { + mismatch_pos += 1; + } + + // Both sides are lists, we restrict the index to the min length to avoid + // OOB memory access. + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); + + unsafe { out_ptr.add(l_idx).write(v) }; + } + } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + /// Reduce monomorphization + #[inline(never)] + fn combine_validities_list_to_list_no_broadcast( + offsets_lhs: &OffsetsBuffer, + offsets_rhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs + .offset_and_length_iter() + .zip(offsets_rhs.offset_and_length_iter()) + { + let len: usize = lhs_len.min(rhs_len); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } + + let leaf_validity = combine_validities_list_to_list_no_broadcast( + offsets_lhs, + offsets_rhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = std::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToList, Broadcast::Right) => { + let offsets_lhs = &self.data_lhs.0[0]; + let offsets_rhs = &self.data_rhs.0[0]; + + // Output primitive (and optional validity) are aligned to the LHS input. + let n_values = arr_lhs.len(); + let mut out_vec: Vec = Vec::with_capacity(n_values); + let out_ptr: *mut T::Native = out_vec.as_mut_ptr(); + + assert_eq!(offsets_rhs.len_proxy(), 1); + let rhs_start = *offsets_rhs.first() as usize; + let width = offsets_rhs.range() as usize; + + let mut mismatch_pos = 0; + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() { + if ((lhs_len == width) & (mismatch_pos == i)) + | unsafe { !self.outer_validity.get_bit_unchecked(i) } + { + mismatch_pos += 1; + } + + let len: usize = lhs_len.min(width); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l = unsafe { arr_lhs.value_unchecked(l_idx) }; + let r = unsafe { arr_rhs.value_unchecked(r_idx) }; + let v = $OP(l, r); + + unsafe { + out_ptr.add(l_idx).write(v); + } + } + } + }); + + check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?; + + unsafe { out_vec.set_len(n_values) }; + + #[inline(never)] + fn combine_validities_list_to_list_broadcast_right( + offsets_lhs: &OffsetsBuffer, + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + width: usize, + rhs_start: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + (None, Some(v)) => { + Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)) + }, + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() { + let len: usize = lhs_len.min(width); + + for i in 0..len { + let l_idx = i + lhs_start; + let r_idx = i + rhs_start; + + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } + + let leaf_validity = combine_validities_list_to_list_broadcast_right( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + width, + rhs_start, + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = std::mem::take(&mut self.data_lhs); + assert_eq!(offsets.len(), 1); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) + if self.list_to_prim_lhs.is_none() => + { + let offsets_lhs = self.data_lhs.0.as_slice(); + + // Notes + // * Primitive indexing starts from 0 + // * Output is aligned to LHS array + + let n_values = arr_lhs.len(); + let mut out_vec = Vec::::with_capacity(n_values); + let out_ptr = out_vec.as_mut_ptr(); + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs.value_unchecked(l_idx); + let v = $OP(l, r); + out_ptr.add(l_idx).write(v); + } + } + } + }); + + unsafe { out_vec.set_len(n_values) } + + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = + PrimitiveArray::::from_vec(out_vec).with_validity(leaf_validity); + + let (offsets, validities, _) = std::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + // If we are dispatched here, it means that the LHS array is a unique allocation created + // after a unit-length list column was broadcasted, so this codepath mutably stores the + // results back into the LHS array to save memory. + (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => { + let offsets_lhs = self.data_lhs.0.as_slice(); + + let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap(); + let arr = arr + .as_any_mut() + .downcast_mut::>() + .unwrap(); + let mut arr_lhs = std::mem::take(arr); + + self.op.0.prepare_numeric_op_side_validities::( + &mut arr_lhs, + &mut arr_rhs, + self.swapped, + ); + + let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap(); + assert_eq!(arr_lhs_mut_slice.len(), n_values); + + with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() + { + let r = unsafe { arr_rhs.value_unchecked(i) }; + for l_idx in l_range { + unsafe { + let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx); + *l = $OP(*l, r); + } + } + } + }); + + let leaf_validity = combine_validities_list_to_primitive_no_broadcast( + offsets_lhs, + arr_lhs.validity(), + arr_rhs.validity(), + arr_lhs.len(), + ); + + let arr = arr_lhs.with_validity(leaf_validity); + + let (offsets, validities, _) = std::mem::take(&mut self.data_lhs); + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => { + assert_eq!(arr_rhs.len(), 1); + + let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else { + // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL. + let (offsets, validities, _) = std::mem::take(&mut self.data_lhs); + return Ok(self.finish_offsets_and_validities( + Box::new( + arr_lhs.clone().with_validity(Some(Bitmap::new_with_value( + false, + arr_lhs.len(), + ))), + ), + offsets, + validities, + )); + }; + + let arr = self + .op + .0 + .apply_array_to_scalar::(arr_lhs, r, self.swapped); + let (offsets, validities, _) = std::mem::take(&mut self.data_lhs); + + self.finish_offsets_and_validities(Box::new(arr), offsets, validities) + }, + v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) + | v @ (BinaryOpApplyType::ListToList, Broadcast::Left) + | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) + | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => { + if cfg!(debug_assertions) { + panic!("operation was not re-written: {:?}", v) + } else { + unreachable!() + } + }, + }; + + Ok(out) + } + + /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every + /// level. + fn finish_offsets_and_validities( + &mut self, + leaf_array: Box, + offsets: Vec>, + validities: Vec>, + ) -> ListChunked { + assert!(!offsets.is_empty()); + assert_eq!(offsets.len(), validities.len()); + let mut results = leaf_array; + + let mut iter = offsets.into_iter().zip(validities).rev(); + + while iter.len() > 1 { + let (offsets, validity) = iter.next().unwrap(); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + results = Box::new(LargeListArray::new(dtype, offsets, results, validity)); + } + + // The combined outer validity is pre-computed during `try_new()` + let (offsets, _) = iter.next().unwrap(); + let validity = std::mem::take(&mut self.outer_validity); + let dtype = LargeListArray::default_datatype(results.dtype().clone()); + let results = LargeListArray::new(dtype, offsets, results, Some(validity)); + + ListChunked::with_chunk(std::mem::take(&mut self.output_name), results) + } + + fn materialize_broadcasted_list( + side_data: &mut (Vec>, Vec>, Series), + output_len: usize, + output_primitive_dtype: &DataType, + ) -> (Box, usize) { + let s = &side_data.2; + assert_eq!(s.len(), 1); + + let expected_n_values = { + let offsets = s.list_offsets_and_validities_recursive().0; + output_len * OffsetsBuffer::::leaf_full_start_end(&offsets).len() + }; + + let ca = s.list().unwrap(); + // Remember to cast the leaf primitives to the supertype. + let ca = ca + .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone())) + .unwrap(); + assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data. + let ca = ca.new_from_index(0, output_len).rechunk(); + + let s = ca.into_series(); + + *side_data = { + let (a, b) = s.list_offsets_and_validities_recursive(); + // `Series::default()`: This field in the tuple is no longer used. + (a, b, Series::default()) + }; + + let n_values = OffsetsBuffer::::leaf_full_start_end(&side_data.0).len(); + assert_eq!(n_values, expected_n_values); + + let mut s = s.get_leaf_array(); + let v = unsafe { s.chunks_mut() }; + + assert_eq!(v.len(), 1); + (v.swap_remove(0), n_values) + } + } + + /// Used in 2 places, so it's outside here. + #[inline(never)] + fn combine_validities_list_to_primitive_no_broadcast( + offsets_lhs: &[OffsetsBuffer], + validity_lhs: Option<&Bitmap>, + validity_rhs: Option<&Bitmap>, + len_lhs: usize, + ) -> Option { + match (validity_lhs, validity_rhs) { + (Some(l), Some(r)) => Some((l.clone().make_mut(), r)), + (Some(v), None) => return Some(v.clone()), + // Materialize a full-true validity to re-use the codepath, as we still + // need to spread the bits from the RHS to the correct positions. + (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)), + (None, None) => None, + } + .map(|(mut validity_out, validity_rhs)| { + for (i, l_range) in OffsetsBuffer::::leaf_ranges_iter(offsets_lhs).enumerate() { + let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) }; + for l_idx in l_range { + let l_valid = unsafe { validity_out.get_unchecked(l_idx) }; + let is_valid = l_valid & r_valid; + + // Size and alignment of validity vec are based on LHS. + unsafe { validity_out.set_unchecked(l_idx, is_valid) }; + } + } + + validity_out.freeze() + }) + } +} diff --git a/crates/polars-core/src/series/arithmetic/list_utils.rs b/crates/polars-core/src/series/arithmetic/list_utils.rs new file mode 100644 index 000000000000..79072de3381b --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/list_utils.rs @@ -0,0 +1,213 @@ +/// Functionality shared between list and array arithmetic implementations. +use arrow::array::{Array, PrimitiveArray}; +use arrow::compute::utils::combine_validities_and; +use num_traits::Zero; +use polars_compute::arithmetic::ArithmeticKernel; +use polars_compute::comparisons::TotalEqKernel; +use polars_error::PolarsResult; +use polars_utils::float::IsFloat; + +use super::*; +use crate::series::ChunkedArray; +use crate::utils::try_get_supertype; + +#[derive(Debug, Clone)] +pub(super) enum NumericOp { + Add, + Sub, + Mul, + Div, + Rem, + FloorDiv, +} + +impl NumericOp { + pub(super) fn name(&self) -> &'static str { + match self { + Self::Add => "add", + Self::Sub => "sub", + Self::Mul => "mul", + Self::Div => "div", + Self::Rem => "rem", + Self::FloorDiv => "floor_div", + } + } + + pub(super) fn try_get_leaf_supertype( + &self, + prim_dtype_lhs: &DataType, + prim_dtype_rhs: &DataType, + ) -> PolarsResult { + let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?; + + Ok(if matches!(self, Self::Div) { + if dtype.is_float() { + dtype + } else { + DataType::Float64 + } + } else { + dtype + }) + } + + /// For operations that perform divisions on integers, sets the validity to NULL on rows where + /// the denominator is 0. + pub(super) fn prepare_numeric_op_side_validities( + &self, + lhs: &mut PrimitiveArray, + rhs: &mut PrimitiveArray, + swapped: bool, + ) where + PrimitiveArray: polars_compute::comparisons::TotalEqKernel, + T::Native: Zero + IsFloat, + { + if !T::Native::is_float() { + match self { + Self::Div | Self::Rem | Self::FloorDiv => { + let target = if swapped { lhs } else { rhs }; + let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero()); + let validity = combine_validities_and(target.validity(), Some(&ne_0)); + target.set_validity(validity); + }, + _ => {}, + } + } + } + + /// # Panics + /// Panics if: + /// * lhs.len() != rhs.len() + /// * dtype is not numeric. + pub(super) fn apply_series(&self, lhs: &Series, rhs: &Series) -> Box { + assert_eq!(lhs.len(), rhs.len()); + debug_assert_eq!(lhs.dtype(), rhs.dtype()); + + let lhs = lhs.rechunk(); + let rhs = rhs.rechunk(); + + with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + + let lhs = lhs.downcast_get(0).unwrap(); + let rhs = rhs.downcast_get(0).unwrap(); + + Box::new(self.apply_arithmetic_kernel::<$T>(lhs.clone(), rhs.clone())) + }) + } + + fn apply_arithmetic_kernel( + &self, + lhs: PrimitiveArray, + rhs: PrimitiveArray, + ) -> PrimitiveArray { + match self { + Self::Add => ArithmeticKernel::wrapping_add(lhs, rhs), + Self::Sub => ArithmeticKernel::wrapping_sub(lhs, rhs), + Self::Mul => ArithmeticKernel::wrapping_mul(lhs, rhs), + Self::Div => ArithmeticKernel::legacy_div(lhs, rhs), + Self::Rem => ArithmeticKernel::wrapping_mod(lhs, rhs), + Self::FloorDiv => ArithmeticKernel::wrapping_floor_div(lhs, rhs), + } + } + + /// For list<->primitive where the primitive is broadcasted, we can dispatch to + /// `ArithmeticKernel`, which can have optimized codepaths for when one side is + /// a scalar. + pub(super) fn apply_array_to_scalar( + &self, + arr_lhs: PrimitiveArray, + r: T::Native, + swapped: bool, + ) -> PrimitiveArray { + match self { + Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r), + Self::Sub => { + if swapped { + ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r) + } + }, + Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r), + Self::Div => { + if swapped { + ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::legacy_div_scalar(arr_lhs, r) + } + }, + Self::Rem => { + if swapped { + ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r) + } + }, + Self::FloorDiv => { + if swapped { + ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs) + } else { + ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r) + } + }, + } + } +} + +macro_rules! with_match_pl_num_arith { + ($op:expr, $swapped:expr, | $_:tt $OP:tt | $($body:tt)* ) => ({ + macro_rules! __with_func__ {( $_ $OP:tt ) => ( $($body)* )} + + match $op { + NumericOp::Add => __with_func__! { (PlNumArithmetic::wrapping_add) }, + NumericOp::Sub => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_sub(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_sub) } + } + }, + NumericOp::Mul => __with_func__! { (PlNumArithmetic::wrapping_mul) }, + NumericOp::Div => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::legacy_div(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::legacy_div) } + } + }, + NumericOp::Rem => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_mod(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_mod) } + } + }, + NumericOp::FloorDiv => { + if $swapped { + __with_func__! { (|b, a| PlNumArithmetic::wrapping_floor_div(a, b)) } + } else { + __with_func__! { (PlNumArithmetic::wrapping_floor_div) } + } + }, + } + }) +} + +pub(super) use with_match_pl_num_arith; + +#[derive(Debug)] +pub(super) enum BinaryOpApplyType { + ListToList, + ListToPrimitive, + PrimitiveToList, +} + +#[derive(Debug)] +pub(super) enum Broadcast { + Left, + Right, + #[allow(clippy::enum_variant_names)] + NoBroadcast, +} diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs new file mode 100644 index 000000000000..8a4d317276c9 --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -0,0 +1,19 @@ +mod bitops; +mod borrowed; +mod list; +mod owned; + +use std::borrow::Cow; +use std::ops::{Add, Div, Mul, Rem, Sub}; + +pub use borrowed::*; +#[cfg(feature = "dtype-array")] +pub use fixed_size_list::NumericFixedSizeListOp; +pub use list::NumericListOp; +use num_traits::{Num, NumCast}; +#[cfg(feature = "dtype-array")] +mod fixed_size_list; +mod list_utils; + +use crate::prelude::*; +use crate::utils::{get_time_units, try_get_supertype}; diff --git a/crates/polars-core/src/series/arithmetic/owned.rs b/crates/polars-core/src/series/arithmetic/owned.rs new file mode 100644 index 000000000000..837f9e03e823 --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/owned.rs @@ -0,0 +1,115 @@ +use super::*; +#[cfg(feature = "performant")] +use crate::utils::align_chunks_binary_owned_series; + +#[cfg(feature = "performant")] +pub fn coerce_lhs_rhs_owned(lhs: Series, rhs: Series) -> PolarsResult<(Series, Series)> { + let dtype = try_get_supertype(lhs.dtype(), rhs.dtype())?; + let left = if lhs.dtype() == &dtype { + lhs + } else { + lhs.cast(&dtype)? + }; + let right = if rhs.dtype() == &dtype { + rhs + } else { + rhs.cast(&dtype)? + }; + Ok((left, right)) +} + +fn is_eligible(lhs: &DataType, rhs: &DataType) -> bool { + !lhs.is_logical() + && lhs.to_physical().is_primitive_numeric() + && rhs.to_physical().is_primitive_numeric() +} + +#[cfg(feature = "performant")] +fn apply_operation_mut(mut lhs: Series, mut rhs: Series, op: F) -> Series +where + T: PolarsNumericType, + F: Fn(ChunkedArray, ChunkedArray) -> ChunkedArray + Copy, + ChunkedArray: IntoSeries, +{ + let lhs_ca: &mut ChunkedArray = lhs._get_inner_mut().as_mut(); + let rhs_ca: &mut ChunkedArray = rhs._get_inner_mut().as_mut(); + + let lhs = std::mem::take(lhs_ca); + let rhs = std::mem::take(rhs_ca); + + op(lhs, rhs).into_series() +} + +macro_rules! impl_operation { + ($operation:ident, $method:ident, $function:expr) => { + impl $operation for Series { + type Output = PolarsResult; + + fn $method(self, rhs: Self) -> Self::Output { + #[cfg(feature = "performant")] + { + // only physical numeric values take the mutable path + if is_eligible(self.dtype(), rhs.dtype()) { + let (lhs, rhs) = coerce_lhs_rhs_owned(self, rhs).unwrap(); + let (lhs, rhs) = align_chunks_binary_owned_series(lhs, rhs); + use DataType::*; + Ok(match lhs.dtype() { + #[cfg(feature = "dtype-i8")] + Int8 => apply_operation_mut::(lhs, rhs, $function), + #[cfg(feature = "dtype-i16")] + Int16 => apply_operation_mut::(lhs, rhs, $function), + Int32 => apply_operation_mut::(lhs, rhs, $function), + Int64 => apply_operation_mut::(lhs, rhs, $function), + #[cfg(feature = "dtype-u8")] + UInt8 => apply_operation_mut::(lhs, rhs, $function), + #[cfg(feature = "dtype-u16")] + UInt16 => apply_operation_mut::(lhs, rhs, $function), + UInt32 => apply_operation_mut::(lhs, rhs, $function), + UInt64 => apply_operation_mut::(lhs, rhs, $function), + Float32 => apply_operation_mut::(lhs, rhs, $function), + Float64 => apply_operation_mut::(lhs, rhs, $function), + _ => unreachable!(), + }) + } else { + (&self).$method(&rhs) + } + } + #[cfg(not(feature = "performant"))] + { + (&self).$method(&rhs) + } + } + } + }; +} + +impl_operation!(Add, add, |a, b| a.add(b)); +impl_operation!(Sub, sub, |a, b| a.sub(b)); +impl_operation!(Mul, mul, |a, b| a.mul(b)); +impl_operation!(Div, div, |a, b| a.div(b)); + +impl Series { + pub fn try_add_owned(self, other: Self) -> PolarsResult { + if is_eligible(self.dtype(), other.dtype()) { + self + other + } else { + std::ops::Add::add(&self, &other) + } + } + + pub fn try_sub_owned(self, other: Self) -> PolarsResult { + if is_eligible(self.dtype(), other.dtype()) { + self - other + } else { + std::ops::Sub::sub(&self, &other) + } + } + + pub fn try_mul_owned(self, other: Self) -> PolarsResult { + if is_eligible(self.dtype(), other.dtype()) { + self * other + } else { + std::ops::Mul::mul(&self, &other) + } + } +} diff --git a/crates/polars-core/src/series/builder.rs b/crates/polars-core/src/series/builder.rs new file mode 100644 index 000000000000..ad53092bf18a --- /dev/null +++ b/crates/polars-core/src/series/builder.rs @@ -0,0 +1,202 @@ +use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder}; +use polars_utils::IdxSize; + +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_builder; +use crate::prelude::*; +use crate::utils::Container; + +#[cfg(feature = "dtype-categorical")] +#[inline(always)] +fn fill_rev_map(dtype: &DataType, rev_map_merger: &mut Option>) { + if let DataType::Categorical(Some(rev_map), _) = dtype { + assert!( + rev_map.is_active_global(), + "{}", + polars_err!(string_cache_mismatch) + ); + if let Some(merger) = rev_map_merger { + merger.merge_map(rev_map).unwrap(); + } else { + *rev_map_merger = Some(Box::new(GlobalRevMapMerger::new(rev_map.clone()))); + } + } +} + +/// A type-erased wrapper around ArrayBuilder. +pub struct SeriesBuilder { + dtype: DataType, + builder: Box, + #[cfg(feature = "dtype-categorical")] + rev_map_merger: Option>, +} + +impl SeriesBuilder { + pub fn new(dtype: DataType) -> Self { + // FIXME: get rid of this hack. + #[cfg(feature = "object")] + if matches!(dtype, DataType::Object(_)) { + let builder = get_object_builder(PlSmallStr::EMPTY, 0).as_array_builder(); + return Self { + dtype, + builder, + #[cfg(feature = "dtype-categorical")] + rev_map_merger: None, + }; + } + + let builder = make_builder(&dtype.to_physical().to_arrow(CompatLevel::newest())); + Self { + dtype, + builder, + #[cfg(feature = "dtype-categorical")] + rev_map_merger: None, + } + } + + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + self.builder.reserve(additional); + } + + fn freeze_dtype(&mut self) -> DataType { + #[cfg(feature = "dtype-categorical")] + if let Some(rev_map_merger) = self.rev_map_merger.take() { + let DataType::Categorical(_, order) = self.dtype else { + unreachable!() + }; + return DataType::Categorical(Some(rev_map_merger.finish()), order); + } + + self.dtype.clone() + } + + pub fn freeze(mut self, name: PlSmallStr) -> Series { + unsafe { + let dtype = self.freeze_dtype(); + Series::from_chunks_and_dtype_unchecked(name, vec![self.builder.freeze()], &dtype) + } + } + + pub fn freeze_reset(&mut self, name: PlSmallStr) -> Series { + unsafe { + Series::from_chunks_and_dtype_unchecked( + name, + vec![self.builder.freeze_reset()], + &self.freeze_dtype(), + ) + } + } + + pub fn len(&self) -> usize { + self.builder.len() + } + + pub fn is_empty(&self) -> bool { + self.builder.len() == 0 + } + + /// Extends this builder with the contents of the given series. May panic if + /// other does not match the dtype of this builder. + #[inline(always)] + pub fn extend(&mut self, other: &Series, share: ShareStrategy) { + #[cfg(feature = "dtype-categorical")] + { + fill_rev_map(other.dtype(), &mut self.rev_map_merger); + } + + self.subslice_extend(other, 0, other.len(), share); + } + + /// Extends this builder with the contents of the given series subslice. + /// May panic if other does not match the dtype of this builder. + pub fn subslice_extend( + &mut self, + other: &Series, + mut start: usize, + mut length: usize, + share: ShareStrategy, + ) { + #[cfg(feature = "dtype-categorical")] + { + fill_rev_map(other.dtype(), &mut self.rev_map_merger); + } + + if length == 0 || other.is_empty() { + return; + } + + for chunk in other.chunks() { + if start < chunk.len() { + let length_in_chunk = length.min(chunk.len() - start); + self.builder + .subslice_extend(&**chunk, start, length_in_chunk, share); + + start = 0; + length -= length_in_chunk; + if length == 0 { + break; + } + } else { + start -= chunk.len(); + } + } + } + + pub fn subslice_extend_repeated( + &mut self, + other: &Series, + start: usize, + length: usize, + repeats: usize, + share: ShareStrategy, + ) { + #[cfg(feature = "dtype-categorical")] + { + fill_rev_map(other.dtype(), &mut self.rev_map_merger); + } + + if length == 0 || other.is_empty() { + return; + } + + let chunks = other.chunks(); + if chunks.len() == 1 { + self.builder + .subslice_extend_repeated(&*chunks[0], start, length, repeats, share); + } else { + for _ in 0..repeats { + self.subslice_extend(other, start, length, share); + } + } + } + + /// Extends this builder with the contents of the given series at the given + /// indices. That is, `other[idxs[i]]` is appended to this builder in order, + /// for each i=0..idxs.len(). May panic if other does not match the dtype + /// of this builder, or if the other series is not rechunked. + /// + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_extend(&mut self, other: &Series, idxs: &[IdxSize], share: ShareStrategy) { + #[cfg(feature = "dtype-categorical")] + { + fill_rev_map(other.dtype(), &mut self.rev_map_merger); + } + + let chunks = other.chunks(); + assert!(chunks.len() == 1); + self.builder.gather_extend(&*chunks[0], idxs, share); + } + + pub fn opt_gather_extend(&mut self, other: &Series, idxs: &[IdxSize], share: ShareStrategy) { + #[cfg(feature = "dtype-categorical")] + { + fill_rev_map(other.dtype(), &mut self.rev_map_merger); + } + + let chunks = other.chunks(); + assert!(chunks.len() == 1); + self.builder.opt_gather_extend(&*chunks[0], idxs, share); + } +} diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs new file mode 100644 index 000000000000..c4ed9b5769f1 --- /dev/null +++ b/crates/polars-core/src/series/comparison.rs @@ -0,0 +1,425 @@ +//! Comparison operations on Series. + +use polars_error::feature_gated; + +use crate::prelude::*; +use crate::series::arithmetic::coerce_lhs_rhs; +use crate::series::nulls::replace_non_null; + +macro_rules! impl_eq_compare { + ($self:expr, $rhs:expr, $method:ident) => {{ + use DataType::*; + let (lhs, rhs) = ($self, $rhs); + validate_types(lhs.dtype(), rhs.dtype())?; + + polars_ensure!( + lhs.len() == rhs.len() || + + // Broadcast + lhs.len() == 1 || + rhs.len() == 1, + ShapeMismatch: "could not compare between two series of different length ({} != {})", + lhs.len(), + rhs.len() + ); + + #[cfg(feature = "dtype-categorical")] + match (lhs.dtype(), rhs.dtype()) { + (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.categorical().unwrap())? + .with_name(lhs.name().clone())); + }, + (Categorical(_, _) | Enum(_, _), String) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + (String, Categorical(_, _) | Enum(_, _)) => { + return Ok(rhs + .categorical() + .unwrap() + .$method(lhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + _ => (), + }; + + let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs) + .map_err(|_| polars_err!( + SchemaMismatch: "could not evaluate comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + lhs.name(), lhs.dtype(), rhs.name(), rhs.dtype() + ))?; + let lhs = lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); + let mut out = match lhs.dtype() { + Null => lhs.null().unwrap().$method(rhs.null().unwrap()), + Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()), + String => lhs.str().unwrap().$method(rhs.str().unwrap()), + Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()), + UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())), + UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())), + UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()), + UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()), + Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())), + Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())), + Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()), + Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()), + Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())), + Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()), + Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()), + List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()), + #[cfg(feature = "dtype-array")] + Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()), + #[cfg(feature = "dtype-struct")] + Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()), + + dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()), + }; + out.rename(lhs.name().clone()); + PolarsResult::Ok(out) + }}; +} + +macro_rules! bail_invalid_ineq { + ($lhs:expr, $rhs:expr, $op:literal) => { + polars_bail!( + InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + $op, + $lhs.name(), $lhs.dtype(), + $rhs.name(), $rhs.dtype(), + ) + }; +} + +macro_rules! impl_ineq_compare { + ($self:expr, $rhs:expr, $method:ident, $op:literal) => {{ + use DataType::*; + let (lhs, rhs) = ($self, $rhs); + validate_types(lhs.dtype(), rhs.dtype())?; + + polars_ensure!( + lhs.len() == rhs.len() || + + // Broadcast + lhs.len() == 1 || + rhs.len() == 1, + ShapeMismatch: + "could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths", + $op, + lhs.name(), lhs.len(), + rhs.name(), rhs.len() + ); + + #[cfg(feature = "dtype-categorical")] + match (lhs.dtype(), rhs.dtype()) { + (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.categorical().unwrap())? + .with_name(lhs.name().clone())); + }, + (Categorical(_, _) | Enum(_, _), String) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + (String, Categorical(_, _) | Enum(_, _)) => { + return Ok(rhs + .categorical() + .unwrap() + .$method(lhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + _ => (), + }; + + let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs).map_err(|_| + polars_err!( + SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + $op, + lhs.name(), lhs.dtype(), + rhs.name(), rhs.dtype() + ) + )?; + let lhs = lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); + let mut out = match lhs.dtype() { + Null => lhs.null().unwrap().$method(rhs.null().unwrap()), + Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()), + String => lhs.str().unwrap().$method(rhs.str().unwrap()), + Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()), + UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())), + UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())), + UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()), + UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()), + Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())), + Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())), + Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()), + Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()), + Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())), + Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()), + Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()), + List(_) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-array")] + Array(_, _) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-struct")] + Struct(_) => bail_invalid_ineq!(lhs, rhs, $op), + + dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()), + }; + out.rename(lhs.name().clone()); + PolarsResult::Ok(out) + }}; +} + +fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> { + use DataType::*; + + match (left, right) { + (String, dt) | (dt, String) if dt.is_primitive_numeric() => { + polars_bail!(ComputeError: "cannot compare string with numeric type ({})", dt) + }, + #[cfg(feature = "dtype-categorical")] + (Categorical(_, _) | Enum(_, _), dt) | (dt, Categorical(_, _) | Enum(_, _)) + if !(dt.is_categorical() | dt.is_string() | dt.is_enum()) => + { + polars_bail!(ComputeError: "cannot compare categorical with {}", dt); + }, + _ => (), + }; + Ok(()) +} + +impl ChunkCompareEq<&Series> for Series { + type Item = PolarsResult; + + /// Create a boolean mask by checking for equality. + fn equal(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, equal) + } + + /// Create a boolean mask by checking for equality. + fn equal_missing(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, equal_missing) + } + + /// Create a boolean mask by checking for inequality. + fn not_equal(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, not_equal) + } + + /// Create a boolean mask by checking for inequality. + fn not_equal_missing(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, not_equal_missing) + } +} + +impl ChunkCompareIneq<&Series> for Series { + type Item = PolarsResult; + + /// Create a boolean mask by checking if self > rhs. + fn gt(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, gt, ">") + } + + /// Create a boolean mask by checking if self >= rhs. + fn gt_eq(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, gt_eq, ">=") + } + + /// Create a boolean mask by checking if self < rhs. + fn lt(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, lt, "<") + } + + /// Create a boolean mask by checking if self <= rhs. + fn lt_eq(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, lt_eq, "<=") + } +} + +impl ChunkCompareEq for Series +where + Rhs: NumericNative, +{ + type Item = PolarsResult; + + fn equal(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, equal, rhs)) + } + + fn equal_missing(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, equal_missing, rhs)) + } + + fn not_equal(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, not_equal, rhs)) + } + + fn not_equal_missing(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs)) + } +} + +impl ChunkCompareIneq for Series +where + Rhs: NumericNative, +{ + type Item = PolarsResult; + + fn gt(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, gt, rhs)) + } + + fn gt_eq(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, gt_eq, rhs)) + } + + fn lt(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, lt, rhs)) + } + + fn lt_eq(&self, rhs: Rhs) -> Self::Item { + validate_types(self.dtype(), &DataType::Int8)?; + let s = self.to_physical_repr(); + Ok(apply_method_physical_numeric!(&s, lt_eq, rhs)) + } +} + +impl ChunkCompareEq<&str> for Series { + type Item = PolarsResult; + + fn equal(&self, rhs: &str) -> PolarsResult { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().equal(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().equal(rhs) + }, + _ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())), + } + } + + fn equal_missing(&self, rhs: &str) -> Self::Item { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().equal_missing(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().equal_missing(rhs) + }, + _ => Ok(replace_non_null( + self.name().clone(), + self.0.chunks(), + false, + )), + } + } + + fn not_equal(&self, rhs: &str) -> PolarsResult { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().not_equal(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().not_equal(rhs) + }, + _ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())), + } + } + + fn not_equal_missing(&self, rhs: &str) -> Self::Item { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().not_equal_missing(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().not_equal_missing(rhs) + }, + _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)), + } + } +} + +impl ChunkCompareIneq<&str> for Series { + type Item = PolarsResult; + + fn gt(&self, rhs: &str) -> Self::Item { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().gt(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().gt(rhs) + }, + _ => polars_bail!( + ComputeError: "cannot compare str value to series of type {}", self.dtype(), + ), + } + } + + fn gt_eq(&self, rhs: &str) -> Self::Item { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().gt_eq(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().gt_eq(rhs) + }, + _ => polars_bail!( + ComputeError: "cannot compare str value to series of type {}", self.dtype(), + ), + } + } + + fn lt(&self, rhs: &str) -> Self::Item { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().lt(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().lt(rhs) + }, + _ => polars_bail!( + ComputeError: "cannot compare str value to series of type {}", self.dtype(), + ), + } + } + + fn lt_eq(&self, rhs: &str) -> Self::Item { + validate_types(self.dtype(), &DataType::String)?; + match self.dtype() { + DataType::String => Ok(self.str().unwrap().lt_eq(rhs)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + self.categorical().unwrap().lt_eq(rhs) + }, + _ => polars_bail!( + ComputeError: "cannot compare str value to series of type {}", self.dtype(), + ), + } + } +} diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs new file mode 100644 index 000000000000..5c1835d7b29c --- /dev/null +++ b/crates/polars-core/src/series/from.rs @@ -0,0 +1,847 @@ +#[cfg(feature = "dtype-categorical")] +use arrow::compute::concatenate::concatenate_unchecked; +use arrow::datatypes::Metadata; +#[cfg(any( + feature = "dtype-date", + feature = "dtype-datetime", + feature = "dtype-time", + feature = "dtype-duration" +))] +use arrow::temporal_conversions::*; +use polars_compute::cast::cast_unchecked as cast; +use polars_error::feature_gated; +use polars_utils::itertools::Itertools; + +use crate::chunked_array::cast::{CastOptions, cast_chunks}; +#[cfg(feature = "object")] +use crate::chunked_array::object::extension::polars_extension::PolarsExtension; +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_builder; +#[cfg(feature = "timezones")] +use crate::chunked_array::temporal::parse_fixed_offset; +#[cfg(feature = "timezones")] +use crate::chunked_array::temporal::validate_time_zone; +use crate::prelude::*; + +impl Series { + pub fn from_chunk_and_dtype( + name: PlSmallStr, + chunk: ArrayRef, + dtype: &DataType, + ) -> PolarsResult { + if &dtype.to_physical().to_arrow(CompatLevel::newest()) != chunk.dtype() { + polars_bail!( + InvalidOperation: "cannot create a series of type '{dtype}' of arrow chunk with type '{:?}'", + chunk.dtype() + ); + } + + // SAFETY: We check that the datatype matches. + let series = unsafe { Self::from_chunks_and_dtype_unchecked(name, vec![chunk], dtype) }; + Ok(series) + } + + /// Takes chunks and a polars datatype and constructs the Series + /// This is faster than creating from chunks and an arrow datatype because there is no + /// casting involved + /// + /// # Safety + /// + /// The caller must ensure that the given `dtype`'s physical type matches all the `ArrayRef` dtypes. + pub unsafe fn from_chunks_and_dtype_unchecked( + name: PlSmallStr, + chunks: Vec, + dtype: &DataType, + ) -> Self { + use DataType::*; + match dtype { + #[cfg(feature = "dtype-i8")] + Int8 => Int8Chunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-i16")] + Int16 => Int16Chunked::from_chunks(name, chunks).into_series(), + Int32 => Int32Chunked::from_chunks(name, chunks).into_series(), + Int64 => Int64Chunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-u8")] + UInt8 => UInt8Chunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-u16")] + UInt16 => UInt16Chunked::from_chunks(name, chunks).into_series(), + UInt32 => UInt32Chunked::from_chunks(name, chunks).into_series(), + UInt64 => UInt64Chunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-i128")] + Int128 => Int128Chunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-date")] + Date => Int32Chunked::from_chunks(name, chunks) + .into_date() + .into_series(), + #[cfg(feature = "dtype-time")] + Time => Int64Chunked::from_chunks(name, chunks) + .into_time() + .into_series(), + #[cfg(feature = "dtype-duration")] + Duration(tu) => Int64Chunked::from_chunks(name, chunks) + .into_duration(*tu) + .into_series(), + #[cfg(feature = "dtype-datetime")] + Datetime(tu, tz) => Int64Chunked::from_chunks(name, chunks) + .into_datetime(*tu, tz.clone()) + .into_series(), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => Int128Chunked::from_chunks(name, chunks) + .into_decimal_unchecked( + *precision, + scale.unwrap_or_else(|| unreachable!("scale should be set")), + ) + .into_series(), + #[cfg(feature = "dtype-array")] + Array(_, _) => { + ArrayChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype.clone()) + .into_series() + }, + List(_) => ListChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype.clone()) + .into_series(), + String => StringChunked::from_chunks(name, chunks).into_series(), + Binary => BinaryChunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-categorical")] + dt @ (Categorical(rev_map, ordering) | Enum(rev_map, ordering)) => { + let cats = UInt32Chunked::from_chunks(name, chunks); + let rev_map = rev_map.clone().unwrap_or_else(|| { + assert!(cats.is_empty()); + Arc::new(RevMapping::default()) + }); + let mut ca = CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + rev_map, + matches!(dt, Enum(_, _)), + *ordering, + ); + ca.set_fast_unique(false); + ca.into_series() + }, + Boolean => BooleanChunked::from_chunks(name, chunks).into_series(), + Float32 => Float32Chunked::from_chunks(name, chunks).into_series(), + Float64 => Float64Chunked::from_chunks(name, chunks).into_series(), + BinaryOffset => BinaryOffsetChunked::from_chunks(name, chunks).into_series(), + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let mut ca = + StructChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype.clone()); + ca.propagate_nulls(); + ca.into_series() + }, + #[cfg(feature = "object")] + Object(_) => { + if let Some(arr) = chunks[0].as_any().downcast_ref::() { + assert_eq!(chunks.len(), 1); + // SAFETY: + // this is highly unsafe. it will dereference a raw ptr on the heap + // make sure the ptr is allocated and from this pid + // (the pid is checked before dereference) + { + let pe = PolarsExtension::new(arr.clone()); + let s = pe.get_series(&name); + pe.take_and_forget(); + s + } + } else { + unsafe { get_object_builder(name, 0).from_chunks(chunks) } + } + }, + Null => new_null(name, &chunks), + Unknown(_) => { + panic!("dtype is unknown; consider supplying data-types for all operations") + }, + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } + + /// # Safety + /// The caller must ensure that the given `dtype` matches all the `ArrayRef` dtypes. + pub unsafe fn _try_from_arrow_unchecked( + name: PlSmallStr, + chunks: Vec, + dtype: &ArrowDataType, + ) -> PolarsResult { + Self::_try_from_arrow_unchecked_with_md(name, chunks, dtype, None) + } + + /// Create a new Series without checking if the inner dtype of the chunks is correct + /// + /// # Safety + /// The caller must ensure that the given `dtype` matches all the `ArrayRef` dtypes. + pub unsafe fn _try_from_arrow_unchecked_with_md( + name: PlSmallStr, + chunks: Vec, + dtype: &ArrowDataType, + md: Option<&Metadata>, + ) -> PolarsResult { + match dtype { + ArrowDataType::Utf8View => Ok(StringChunked::from_chunks(name, chunks).into_series()), + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { + let chunks = + cast_chunks(&chunks, &DataType::String, CastOptions::NonStrict).unwrap(); + Ok(StringChunked::from_chunks(name, chunks).into_series()) + }, + ArrowDataType::BinaryView => Ok(BinaryChunked::from_chunks(name, chunks).into_series()), + ArrowDataType::LargeBinary => { + if let Some(md) = md { + if md.maintain_type() { + return Ok(BinaryOffsetChunked::from_chunks(name, chunks).into_series()); + } + } + let chunks = + cast_chunks(&chunks, &DataType::Binary, CastOptions::NonStrict).unwrap(); + Ok(BinaryChunked::from_chunks(name, chunks).into_series()) + }, + ArrowDataType::Binary => { + let chunks = + cast_chunks(&chunks, &DataType::Binary, CastOptions::NonStrict).unwrap(); + Ok(BinaryChunked::from_chunks(name, chunks).into_series()) + }, + ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { + let (chunks, dtype) = to_physical_and_dtype(chunks, md); + unsafe { + Ok( + ListChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype) + .into_series(), + ) + } + }, + #[cfg(feature = "dtype-array")] + ArrowDataType::FixedSizeList(_, _) => { + let (chunks, dtype) = to_physical_and_dtype(chunks, md); + unsafe { + Ok( + ArrayChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype) + .into_series(), + ) + } + }, + ArrowDataType::Boolean => Ok(BooleanChunked::from_chunks(name, chunks).into_series()), + #[cfg(feature = "dtype-u8")] + ArrowDataType::UInt8 => Ok(UInt8Chunked::from_chunks(name, chunks).into_series()), + #[cfg(feature = "dtype-u16")] + ArrowDataType::UInt16 => Ok(UInt16Chunked::from_chunks(name, chunks).into_series()), + ArrowDataType::UInt32 => Ok(UInt32Chunked::from_chunks(name, chunks).into_series()), + ArrowDataType::UInt64 => Ok(UInt64Chunked::from_chunks(name, chunks).into_series()), + #[cfg(feature = "dtype-i8")] + ArrowDataType::Int8 => Ok(Int8Chunked::from_chunks(name, chunks).into_series()), + #[cfg(feature = "dtype-i16")] + ArrowDataType::Int16 => Ok(Int16Chunked::from_chunks(name, chunks).into_series()), + ArrowDataType::Int32 => Ok(Int32Chunked::from_chunks(name, chunks).into_series()), + ArrowDataType::Int64 => Ok(Int64Chunked::from_chunks(name, chunks).into_series()), + ArrowDataType::Int128 => feature_gated!( + "dtype-i128", + Ok(Int128Chunked::from_chunks(name, chunks).into_series()) + ), + ArrowDataType::Float16 => { + let chunks = + cast_chunks(&chunks, &DataType::Float32, CastOptions::NonStrict).unwrap(); + Ok(Float32Chunked::from_chunks(name, chunks).into_series()) + }, + ArrowDataType::Float32 => Ok(Float32Chunked::from_chunks(name, chunks).into_series()), + ArrowDataType::Float64 => Ok(Float64Chunked::from_chunks(name, chunks).into_series()), + #[cfg(feature = "dtype-date")] + ArrowDataType::Date32 => { + let chunks = + cast_chunks(&chunks, &DataType::Int32, CastOptions::Overflowing).unwrap(); + Ok(Int32Chunked::from_chunks(name, chunks) + .into_date() + .into_series()) + }, + #[cfg(feature = "dtype-datetime")] + ArrowDataType::Date64 => { + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::Overflowing).unwrap(); + let ca = Int64Chunked::from_chunks(name, chunks); + Ok(ca.into_datetime(TimeUnit::Milliseconds, None).into_series()) + }, + #[cfg(feature = "dtype-datetime")] + ArrowDataType::Timestamp(tu, tz) => { + let canonical_tz = DataType::canonical_timezone(tz); + let tz = match canonical_tz.as_deref() { + #[cfg(feature = "timezones")] + Some(tz_str) => match validate_time_zone(tz_str) { + Ok(_) => canonical_tz, + Err(_) => Some(parse_fixed_offset(tz_str)?), + }, + _ => canonical_tz, + }; + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::NonStrict).unwrap(); + let s = Int64Chunked::from_chunks(name, chunks) + .into_datetime(tu.into(), tz) + .into_series(); + Ok(match tu { + ArrowTimeUnit::Second => &s * MILLISECONDS, + ArrowTimeUnit::Millisecond => s, + ArrowTimeUnit::Microsecond => s, + ArrowTimeUnit::Nanosecond => s, + }) + }, + #[cfg(feature = "dtype-duration")] + ArrowDataType::Duration(tu) => { + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::NonStrict).unwrap(); + let s = Int64Chunked::from_chunks(name, chunks) + .into_duration(tu.into()) + .into_series(); + Ok(match tu { + ArrowTimeUnit::Second => &s * MILLISECONDS, + ArrowTimeUnit::Millisecond => s, + ArrowTimeUnit::Microsecond => s, + ArrowTimeUnit::Nanosecond => s, + }) + }, + #[cfg(feature = "dtype-time")] + ArrowDataType::Time64(tu) | ArrowDataType::Time32(tu) => { + let mut chunks = chunks; + if matches!(dtype, ArrowDataType::Time32(_)) { + chunks = + cast_chunks(&chunks, &DataType::Int32, CastOptions::NonStrict).unwrap(); + } + let chunks = + cast_chunks(&chunks, &DataType::Int64, CastOptions::NonStrict).unwrap(); + let s = Int64Chunked::from_chunks(name, chunks) + .into_time() + .into_series(); + Ok(match tu { + ArrowTimeUnit::Second => &s * NANOSECONDS, + ArrowTimeUnit::Millisecond => &s * 1_000_000, + ArrowTimeUnit::Microsecond => &s * 1_000, + ArrowTimeUnit::Nanosecond => s, + }) + }, + ArrowDataType::Decimal(precision, scale) + | ArrowDataType::Decimal256(precision, scale) => { + feature_gated!("dtype-decimal", { + polars_ensure!(*scale <= *precision, InvalidOperation: "invalid decimal precision and scale (prec={precision}, scale={scale})"); + polars_ensure!(*precision <= 38, InvalidOperation: "polars does not support decimals about 38 precision"); + + let mut chunks = chunks; + // @NOTE: We cannot cast here as that will lower the scale. + for chunk in chunks.iter_mut() { + *chunk = std::mem::take( + chunk + .as_any_mut() + .downcast_mut::>() + .unwrap(), + ) + .to(ArrowDataType::Int128) + .to_boxed(); + } + let s = Int128Chunked::from_chunks(name, chunks) + .into_decimal_unchecked(Some(*precision), *scale) + .into_series(); + Ok(s) + }) + }, + ArrowDataType::Null => Ok(new_null(name, &chunks)), + #[cfg(not(feature = "dtype-categorical"))] + ArrowDataType::Dictionary(_, _, _) => { + panic!("activate dtype-categorical to convert dictionary arrays") + }, + #[cfg(feature = "dtype-categorical")] + ArrowDataType::Dictionary(key_type, value_type, _) => { + use arrow::datatypes::IntegerType; + // don't spuriously call this; triggers a read on mmapped data + let arr = if chunks.len() > 1 { + concatenate_unchecked(&chunks)? + } else { + chunks[0].clone() + }; + + // If the value type is a string, they are converted to Categoricals or Enums + if matches!( + value_type.as_ref(), + ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 + | ArrowDataType::Utf8View + | ArrowDataType::Null + ) { + macro_rules! unpack_keys_values { + ($dt:ty) => {{ + let arr = arr.as_any().downcast_ref::>().unwrap(); + let keys = arr.keys(); + let keys = cast(keys, &ArrowDataType::UInt32).unwrap(); + let values = arr.values(); + let values = cast(&**values, &ArrowDataType::Utf8View)?; + (keys, values) + }}; + } + + use IntegerType as I; + let (keys, values) = match key_type { + I::Int8 => unpack_keys_values!(i8), + I::UInt8 => unpack_keys_values!(u8), + I::Int16 => unpack_keys_values!(i16), + I::UInt16 => unpack_keys_values!(u16), + I::Int32 => unpack_keys_values!(i32), + I::UInt32 => unpack_keys_values!(u32), + I::Int64 => unpack_keys_values!(i64), + _ => polars_bail!( + ComputeError: "dictionaries with unsigned 64-bit keys are not supported" + ), + }; + + let keys = keys.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::().unwrap(); + + // Categoricals and Enums expect the RevMap values to not contain any nulls + let (keys, values) = + polars_compute::propagate_dictionary::propagate_dictionary_value_nulls( + keys, values, + ); + + let mut ordering = CategoricalOrdering::default(); + if let Some(metadata) = md { + if metadata.is_enum() { + // SAFETY: + // the invariants of an Arrow Dictionary guarantee the keys are in bounds + return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + UInt32Chunked::with_chunk(name, keys), + Arc::new(RevMapping::build_local(values)), + true, + CategoricalOrdering::Physical, // Enum always uses physical ordering + ) + .into_series()); + } else if let Some(o) = metadata.categorical() { + ordering = o; + } + } + + return Ok(CategoricalChunked::from_keys_and_values( + name, &keys, &values, ordering, + ) + .into_series()); + } + + macro_rules! unpack_keys_values { + ($dt:ty) => {{ + let arr = arr.as_any().downcast_ref::>().unwrap(); + let keys = arr.keys(); + let keys = polars_compute::cast::primitive_as_primitive::< + $dt, + ::Native, + >(keys, &IDX_DTYPE.to_arrow(CompatLevel::newest())); + (arr.values(), keys) + }}; + } + + use IntegerType as I; + let (values, keys) = match key_type { + I::Int8 => unpack_keys_values!(i8), + I::UInt8 => unpack_keys_values!(u8), + I::Int16 => unpack_keys_values!(i16), + I::UInt16 => unpack_keys_values!(u16), + I::Int32 => unpack_keys_values!(i32), + I::UInt32 => unpack_keys_values!(u32), + I::Int64 => unpack_keys_values!(i64), + _ => polars_bail!( + ComputeError: "dictionaries with unsigned 64-bit keys are not supported" + ), + }; + + // Convert the dictionary to a flat array + let values = Series::_try_from_arrow_unchecked_with_md( + name, + vec![values.clone()], + values.dtype(), + None, + )?; + let values = values.take_unchecked(&IdxCa::from_chunks_and_dtype( + PlSmallStr::EMPTY, + vec![keys.to_boxed()], + IDX_DTYPE, + )); + + Ok(values) + }, + #[cfg(feature = "object")] + ArrowDataType::Extension(ext) + if ext.name == EXTENSION_NAME && ext.metadata.is_some() => + { + assert_eq!(chunks.len(), 1); + let arr = chunks[0] + .as_any() + .downcast_ref::() + .unwrap(); + // SAFETY: + // this is highly unsafe. it will dereference a raw ptr on the heap + // make sure the ptr is allocated and from this pid + // (the pid is checked before dereference) + let s = { + let pe = PolarsExtension::new(arr.clone()); + let s = pe.get_series(&name); + pe.take_and_forget(); + s + }; + Ok(s) + }, + #[cfg(feature = "dtype-struct")] + ArrowDataType::Struct(_) => { + let (chunks, dtype) = to_physical_and_dtype(chunks, md); + + unsafe { + let mut ca = + StructChunked::from_chunks_and_dtype_unchecked(name, chunks, dtype); + ca.propagate_nulls(); + Ok(ca.into_series()) + } + }, + ArrowDataType::FixedSizeBinary(_) => { + let chunks = cast_chunks(&chunks, &DataType::Binary, CastOptions::NonStrict)?; + Ok(BinaryChunked::from_chunks(name, chunks).into_series()) + }, + ArrowDataType::Map(_, _) => map_arrays_to_series(name, chunks), + dt => polars_bail!(ComputeError: "cannot create series from {:?}", dt), + } + } +} + +fn map_arrays_to_series(name: PlSmallStr, chunks: Vec) -> PolarsResult { + let chunks = chunks + .iter() + .map(|arr| { + // we convert the map to the logical type: List> + let arr = arr.as_any().downcast_ref::().unwrap(); + let inner = arr.field().clone(); + + // map has i32 offsets + let dtype = ListArray::::default_datatype(inner.dtype().clone()); + Box::new(ListArray::::new( + dtype, + arr.offsets().clone(), + inner, + arr.validity().cloned(), + )) as ArrayRef + }) + .collect::>(); + Series::try_from((name, chunks)) +} + +fn convert ArrayRef>(arr: &[ArrayRef], f: F) -> Vec { + arr.iter().map(|arr| f(&**arr)).collect() +} + +/// Converts to physical types and bubbles up the correct [`DataType`]. +#[allow(clippy::only_used_in_recursion)] +unsafe fn to_physical_and_dtype( + arrays: Vec, + md: Option<&Metadata>, +) -> (Vec, DataType) { + match arrays[0].dtype() { + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { + let chunks = cast_chunks(&arrays, &DataType::String, CastOptions::NonStrict).unwrap(); + (chunks, DataType::String) + }, + ArrowDataType::Binary | ArrowDataType::LargeBinary | ArrowDataType::FixedSizeBinary(_) => { + let chunks = cast_chunks(&arrays, &DataType::Binary, CastOptions::NonStrict).unwrap(); + (chunks, DataType::Binary) + }, + #[allow(unused_variables)] + dt @ ArrowDataType::Dictionary(_, _, _) => { + feature_gated!("dtype-categorical", { + let s = unsafe { + let dt = dt.clone(); + Series::_try_from_arrow_unchecked_with_md(PlSmallStr::EMPTY, arrays, &dt, md) + } + .unwrap(); + (s.chunks().clone(), s.dtype().clone()) + }) + }, + ArrowDataType::List(field) => { + let out = convert(&arrays, |arr| { + cast(arr, &ArrowDataType::LargeList(field.clone())).unwrap() + }); + to_physical_and_dtype(out, md) + }, + #[cfg(feature = "dtype-array")] + ArrowDataType::FixedSizeList(field, size) => { + let values = arrays + .iter() + .map(|arr| { + let arr = arr.as_any().downcast_ref::().unwrap(); + arr.values().clone() + }) + .collect::>(); + + let (converted_values, dtype) = + to_physical_and_dtype(values, field.metadata.as_deref()); + + let arrays = arrays + .iter() + .zip(converted_values) + .map(|(arr, values)| { + let arr = arr.as_any().downcast_ref::().unwrap(); + + let dtype = FixedSizeListArray::default_datatype(values.dtype().clone(), *size); + Box::from(FixedSizeListArray::new( + dtype, + arr.len(), + values, + arr.validity().cloned(), + )) as ArrayRef + }) + .collect(); + (arrays, DataType::Array(Box::new(dtype), *size)) + }, + ArrowDataType::LargeList(field) => { + let values = arrays + .iter() + .map(|arr| { + let arr = arr.as_any().downcast_ref::>().unwrap(); + arr.values().clone() + }) + .collect::>(); + + let (converted_values, dtype) = + to_physical_and_dtype(values, field.metadata.as_deref()); + + let arrays = arrays + .iter() + .zip(converted_values) + .map(|(arr, values)| { + let arr = arr.as_any().downcast_ref::>().unwrap(); + + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Box::from(ListArray::::new( + dtype, + arr.offsets().clone(), + values, + arr.validity().cloned(), + )) as ArrayRef + }) + .collect(); + (arrays, DataType::List(Box::new(dtype))) + }, + ArrowDataType::Struct(_fields) => { + feature_gated!("dtype-struct", { + let mut pl_fields = None; + let arrays = arrays + .iter() + .map(|arr| { + let arr = arr.as_any().downcast_ref::().unwrap(); + let (values, dtypes): (Vec<_>, Vec<_>) = arr + .values() + .iter() + .zip(_fields.iter()) + .map(|(value, field)| { + let mut out = to_physical_and_dtype( + vec![value.clone()], + field.metadata.as_deref(), + ); + (out.0.pop().unwrap(), out.1) + }) + .unzip(); + + let arrow_fields = values + .iter() + .zip(_fields.iter()) + .map(|(arr, field)| { + ArrowField::new(field.name.clone(), arr.dtype().clone(), true) + }) + .collect(); + let arrow_array = Box::new(StructArray::new( + ArrowDataType::Struct(arrow_fields), + arr.len(), + values, + arr.validity().cloned(), + )) as ArrayRef; + + if pl_fields.is_none() { + pl_fields = Some( + _fields + .iter() + .zip(dtypes) + .map(|(field, dtype)| Field::new(field.name.clone(), dtype)) + .collect_vec(), + ) + } + + arrow_array + }) + .collect_vec(); + + (arrays, DataType::Struct(pl_fields.unwrap())) + }) + }, + // Use Series architecture to convert nested logical types to physical. + dt @ (ArrowDataType::Duration(_) + | ArrowDataType::Time32(_) + | ArrowDataType::Time64(_) + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Date32 + | ArrowDataType::Decimal(_, _) + | ArrowDataType::Date64) => { + let dt = dt.clone(); + let mut s = Series::_try_from_arrow_unchecked(PlSmallStr::EMPTY, arrays, &dt).unwrap(); + let dtype = s.dtype().clone(); + (std::mem::take(s.chunks_mut()), dtype) + }, + dt => { + let dtype = DataType::from_arrow(dt, true, md); + (arrays, dtype) + }, + } +} + +fn check_types(chunks: &[ArrayRef]) -> PolarsResult { + let mut chunks_iter = chunks.iter(); + let dtype: ArrowDataType = chunks_iter + .next() + .ok_or_else(|| polars_err!(NoData: "expected at least one array-ref"))? + .dtype() + .clone(); + + for chunk in chunks_iter { + if chunk.dtype() != &dtype { + polars_bail!( + ComputeError: "cannot create series from multiple arrays with different types" + ); + } + } + Ok(dtype) +} + +impl Series { + pub fn try_new( + name: PlSmallStr, + data: T, + ) -> Result>::Error> + where + (PlSmallStr, T): TryInto, + { + // # TODO + // * Remove the TryFrom impls in favor of this + <(PlSmallStr, T) as TryInto>::try_into((name, data)) + } +} + +impl TryFrom<(PlSmallStr, Vec)> for Series { + type Error = PolarsError; + + fn try_from(name_arr: (PlSmallStr, Vec)) -> PolarsResult { + let (name, chunks) = name_arr; + + let dtype = check_types(&chunks)?; + // SAFETY: + // dtype is checked + unsafe { Series::_try_from_arrow_unchecked(name, chunks, &dtype) } + } +} + +impl TryFrom<(PlSmallStr, ArrayRef)> for Series { + type Error = PolarsError; + + fn try_from(name_arr: (PlSmallStr, ArrayRef)) -> PolarsResult { + let (name, arr) = name_arr; + Series::try_from((name, vec![arr])) + } +} + +impl TryFrom<(&ArrowField, Vec)> for Series { + type Error = PolarsError; + + fn try_from(field_arr: (&ArrowField, Vec)) -> PolarsResult { + let (field, chunks) = field_arr; + + let dtype = check_types(&chunks)?; + + // SAFETY: + // dtype is checked + unsafe { + Series::_try_from_arrow_unchecked_with_md( + field.name.clone(), + chunks, + &dtype, + field.metadata.as_deref(), + ) + } + } +} + +impl TryFrom<(&ArrowField, ArrayRef)> for Series { + type Error = PolarsError; + + fn try_from(field_arr: (&ArrowField, ArrayRef)) -> PolarsResult { + let (field, arr) = field_arr; + Series::try_from((field, vec![arr])) + } +} + +/// Used to convert a [`ChunkedArray`], `&dyn SeriesTrait` and [`Series`] +/// into a [`Series`]. +/// # Safety +/// +/// This trait is marked `unsafe` as the `is_series` return is used +/// to transmute to `Series`. This must always return `false` except +/// for `Series` structs. +pub unsafe trait IntoSeries { + fn is_series() -> bool { + false + } + + fn into_series(self) -> Series + where + Self: Sized; +} + +impl From> for Series +where + T: PolarsDataType, + ChunkedArray: IntoSeries, +{ + fn from(ca: ChunkedArray) -> Self { + ca.into_series() + } +} + +#[cfg(feature = "dtype-date")] +impl From for Series { + fn from(a: DateChunked) -> Self { + a.into_series() + } +} + +#[cfg(feature = "dtype-datetime")] +impl From for Series { + fn from(a: DatetimeChunked) -> Self { + a.into_series() + } +} + +#[cfg(feature = "dtype-duration")] +impl From for Series { + fn from(a: DurationChunked) -> Self { + a.into_series() + } +} + +#[cfg(feature = "dtype-time")] +impl From for Series { + fn from(a: TimeChunked) -> Self { + a.into_series() + } +} + +unsafe impl IntoSeries for Arc { + fn into_series(self) -> Series { + Series(self) + } +} + +unsafe impl IntoSeries for Series { + fn is_series() -> bool { + true + } + + fn into_series(self) -> Series { + self + } +} + +fn new_null(name: PlSmallStr, chunks: &[ArrayRef]) -> Series { + let len = chunks.iter().map(|arr| arr.len()).sum(); + Series::new_null(name, len) +} diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs new file mode 100644 index 000000000000..dc25f2f0f7bf --- /dev/null +++ b/crates/polars-core/src/series/implementations/array.rs @@ -0,0 +1,230 @@ +use std::any::Any; +use std::borrow::Cow; + +use self::compare_inner::{TotalEqInner, TotalOrdInner}; +use self::sort::arg_sort_row_fmt; +use super::{StatisticsFlags, private}; +use crate::chunked_array::AsSinglePtr; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; +use crate::series::implementations::SeriesWrap; + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn add_to(&self, rhs: &Series) -> PolarsResult { + self.0.add_to(rhs) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + self.0.subtract(rhs) + } + + fn multiply(&self, rhs: &Series) -> PolarsResult { + self.0.multiply(rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + self.0.divide(rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + self.0.remainder(rhs) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_eq_inner, self) + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_ord_inner, self) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + let slf = (*self).clone(); + let slf = slf.into_column(); + arg_sort_row_fmt( + &[slf], + options.descending, + options.nulls_last, + options.multithreaded, + ) + .unwrap() + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + let idxs = self.arg_sort(options); + Ok(unsafe { self.take_unchecked(&idxs) }) + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let other = other.array()?; + self.0.append(other) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs new file mode 100644 index 000000000000..e9dbc79027af --- /dev/null +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -0,0 +1,270 @@ +use super::*; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::subtract(&self.0, rhs) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::add_to(&self.0, rhs) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::multiply(&self.0, rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::divide(&self.0, rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::remainder(&self.0, rhs) + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + // todo! add object + self.0.append(other.as_ref().as_ref())?; + Ok(()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + ChunkSort::arg_sort(&self.0, options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + ChunkUnique::n_unique(&self.0) + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + ChunkUnique::arg_unique(&self.0) + } + + #[cfg(feature = "approx_unique")] + fn approx_n_unique(&self) -> PolarsResult { + Ok(ChunkApproxNUnique::approx_n_unique(&self.0)) + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs new file mode 100644 index 000000000000..712b1ccb9930 --- /dev/null +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -0,0 +1,212 @@ +use super::*; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; +use crate::series::private::PrivateSeries; + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + // todo! add object + self.0.append(other.as_ref().as_ref())?; + Ok(()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + // Only used by multi-key join validation, doesn't have to be optimal + self.group_tuples(true, false).map(|g| g.len()) + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + ChunkSort::arg_sort(&self.0, options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs new file mode 100644 index 000000000000..4d645e885385 --- /dev/null +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -0,0 +1,377 @@ +use super::*; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + self.0.agg_sum(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_std(&self, groups: &GroupsType, _ddof: u8) -> Series { + self.0 + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap() + .agg_std(groups, _ddof) + } + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_var(&self, groups: &GroupsType, _ddof: u8) -> Series { + self.0 + .cast_with_options(&DataType::Float64, CastOptions::Overflowing) + .unwrap() + .agg_var(groups, _ddof) + } + + #[cfg(feature = "bitwise")] + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { + self.0.agg_and(groups) + } + #[cfg(feature = "bitwise")] + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { + self.0.agg_or(groups) + } + #[cfg(feature = "bitwise")] + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { + self.0.agg_xor(groups) + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::add_to(&self.0, rhs) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.as_ref().as_ref())?; + Ok(()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn _sum_as_f64(&self) -> f64 { + self.0.sum().unwrap() as f64 + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + ChunkSort::arg_sort(&self.0, options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + ChunkUnique::n_unique(&self.0) + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + ChunkUnique::arg_unique(&self.0) + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn sum_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::sum_reduce(&self.0)) + } + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + fn median_reduce(&self) -> PolarsResult { + let ca = self + .0 + .cast_with_options(&DataType::Int8, CastOptions::Overflowing) + .unwrap(); + let sc = ca.median_reduce()?; + let v = sc.value().cast(&DataType::Float64); + Ok(Scalar::new(DataType::Float64, v)) + } + /// Get the variance of the Series as a new Series of length 1. + fn var_reduce(&self, _ddof: u8) -> PolarsResult { + let ca = self + .0 + .cast_with_options(&DataType::Int8, CastOptions::Overflowing) + .unwrap(); + let sc = ca.var_reduce(_ddof)?; + let v = sc.value().cast(&DataType::Float64); + Ok(Scalar::new(DataType::Float64, v)) + } + /// Get the standard deviation of the Series as a new Series of length 1. + fn std_reduce(&self, _ddof: u8) -> PolarsResult { + let ca = self + .0 + .cast_with_options(&DataType::Int8, CastOptions::Overflowing) + .unwrap(); + let sc = ca.std_reduce(_ddof)?; + let v = sc.value().cast(&DataType::Float64); + Ok(Scalar::new(DataType::Float64, v)) + } + fn and_reduce(&self) -> PolarsResult { + let dt = DataType::Boolean; + if self.0.null_count() > 0 { + return Ok(Scalar::new(dt, AnyValue::Null)); + } + + Ok(Scalar::new( + dt, + self.0 + .downcast_iter() + .filter(|arr| !arr.is_empty()) + .map(|arr| polars_compute::bitwise::BitwiseKernel::reduce_and(arr).unwrap()) + .reduce(|a, b| a & b) + .map_or(AnyValue::Null, Into::into), + )) + } + fn or_reduce(&self) -> PolarsResult { + let dt = DataType::Boolean; + if self.0.null_count() > 0 { + return Ok(Scalar::new(dt, AnyValue::Null)); + } + + Ok(Scalar::new( + dt, + self.0 + .downcast_iter() + .filter(|arr| !arr.is_empty()) + .map(|arr| polars_compute::bitwise::BitwiseKernel::reduce_or(arr).unwrap()) + .reduce(|a, b| a | b) + .map_or(AnyValue::Null, Into::into), + )) + } + fn xor_reduce(&self) -> PolarsResult { + let dt = DataType::Boolean; + if self.0.null_count() > 0 { + return Ok(Scalar::new(dt, AnyValue::Null)); + } + + Ok(Scalar::new( + dt, + self.0 + .downcast_iter() + .filter(|arr| !arr.is_empty()) + .map(|arr| polars_compute::bitwise::BitwiseKernel::reduce_xor(arr).unwrap()) + .reduce(|a, b| a ^ b) + .map_or(AnyValue::Null, Into::into), + )) + } + + #[cfg(feature = "approx_unique")] + fn approx_n_unique(&self) -> PolarsResult { + Ok(ChunkApproxNUnique::approx_n_unique(&self.0)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs new file mode 100644 index 000000000000..0ecbc87cb8a1 --- /dev/null +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -0,0 +1,332 @@ +use super::*; +use crate::chunked_array::comparison::*; +use crate::prelude::*; + +unsafe impl IntoSeries for CategoricalChunked { + fn into_series(self) -> Series { + Series(Arc::new(SeriesWrap(self))) + } +} + +impl SeriesWrap { + fn finish_with_state(&self, keep_fast_unique: bool, cats: UInt32Chunked) -> CategoricalChunked { + let mut out = unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + self.0.get_rev_map().clone(), + self.0.is_enum(), + self.0.get_ordering(), + ) + }; + if keep_fast_unique && self.0._can_fast_unique() { + out.set_fast_unique(true) + } + out + } + + fn with_state(&self, keep_fast_unique: bool, apply: F) -> CategoricalChunked + where + F: Fn(&UInt32Chunked) -> UInt32Chunked, + { + let cats = apply(self.0.physical()); + self.finish_with_state(keep_fast_unique, cats) + } + + fn try_with_state<'a, F>( + &'a self, + keep_fast_unique: bool, + apply: F, + ) -> PolarsResult + where + F: for<'b> Fn(&'a UInt32Chunked) -> PolarsResult, + { + let cats = apply(self.0.physical())?; + Ok(self.finish_with_state(keep_fast_unique, cats)) + } +} + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.physical_mut().compute_len() + } + fn _field(&self) -> Cow { + Cow::Owned(self.0.field()) + } + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.physical().equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + self.0 + .zip_with(mask, other.categorical()?) + .map(|ca| ca.into_series()) + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + if self.0.uses_lexical_ordering() { + (&self.0).into_total_ord_inner() + } else { + self.0.physical().into_total_ord_inner() + } + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_eq_inner, self) + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.physical().vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.physical().vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + // we cannot cast and dispatch as the inner type of the list would be incorrect + let list = self.0.physical().agg_list(groups); + let mut list = list.list().unwrap().clone(); + unsafe { list.to_logical(self.dtype().clone()) }; + list.into_series() + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + #[cfg(feature = "performant")] + { + Ok(self.0.group_tuples_perfect(multithreaded, sorted)) + } + #[cfg(not(feature = "performant"))] + { + self.0.physical().group_tuples(multithreaded, sorted) + } + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.physical_mut().rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.physical().chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.physical().name() + } + + fn chunks(&self) -> &Vec { + self.0.physical().chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.physical_mut().chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.physical_mut().shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.with_state(false, |cats| cats.slice(offset, length)) + .into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.physical().split_at(offset); + let a = self.finish_with_state(false, a).into_series(); + let b = self.finish_with_state(false, b).into_series(); + (a, b) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.categorical().unwrap()) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let other = other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::() + .unwrap(); + self.0.append_owned(std::mem::take(other)) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + let other_ca = other.categorical().unwrap(); + // Fast path for globals of the same source + let rev_map_self = self.0.get_rev_map(); + let rev_map_other = other_ca.get_rev_map(); + match (&**rev_map_self, &**rev_map_other) { + (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { + let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_self.clone()); + rev_map_merger.merge_map(rev_map_other)?; + self.0.physical_mut().extend(other_ca.physical())?; + // SAFETY: rev_maps are merged + unsafe { self.0.set_rev_map(rev_map_merger.finish(), false) }; + Ok(()) + }, + _ => self.0.append(other_ca), + } + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + self.try_with_state(false, |cats| cats.filter(filter)) + .map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + self.try_with_state(false, |cats| cats.take(indices)) + .map(|ca| ca.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.with_state(false, |cats| cats.take_unchecked(indices)) + .into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + self.try_with_state(false, |cats| cats.take(indices)) + .map(|ca| ca.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.with_state(false, |cats| cats.take_unchecked(indices)) + .into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.with_state(true, |ca| ca.rechunk().into_owned()) + .into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + self.with_state(false, |cats| cats.new_from_index(index, length)) + .into_series() + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.physical().null_count() + } + + fn has_nulls(&self) -> bool { + self.0.physical().has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + self.0.unique().map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.physical().arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.physical().is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.physical().is_not_null() + } + + fn reverse(&self) -> Series { + self.with_state(true, |cats| cats.reverse()).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.physical_mut().as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + self.with_state(false, |ca| ca.shift(periods)).into_series() + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + Some(BitRepr::Small(self.0.physical().clone())) + } +} diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs new file mode 100644 index 000000000000..82764cd4a3ac --- /dev/null +++ b/crates/polars-core/src/series/implementations/date.rs @@ -0,0 +1,383 @@ +//! This module exists to reduce compilation times. +//! +//! All the data types are backed by a physical type in memory e.g. Date -> i32, Datetime-> i64. +//! +//! Series lead to code implementations of all traits. Whereas there are a lot of duplicates due to +//! data types being backed by the same physical type. In this module we reduce compile times by +//! opting for a little more run time cost. We cast to the physical type -> apply the operation and +//! (depending on the result) cast back to the original type +//! +use super::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +unsafe impl IntoSeries for DateChunked { + fn into_series(self) -> Series { + Series(Arc::new(SeriesWrap(self))) + } +} + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + + fn _field(&self) -> Cow { + Cow::Owned(self.0.field()) + } + + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let other = other.to_physical_repr().into_owned(); + self.0 + .zip_with(mask, other.as_ref().as_ref()) + .map(|ca| ca.into_date().into_series()) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups).into_date().into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups).into_date().into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + // we cannot cast and dispatch as the inner type of the list would be incorrect + self.0 + .agg_list(groups) + .cast(&DataType::List(Box::new(self.dtype().clone()))) + .unwrap() + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + match rhs.dtype() { + DataType::Date => { + let dt = DataType::Datetime(TimeUnit::Milliseconds, None); + let lhs = self.cast(&dt, CastOptions::NonStrict)?; + let rhs = rhs.cast(&dt)?; + lhs.subtract(&rhs) + }, + DataType::Duration(_) => std::ops::Sub::sub( + &self.cast( + &DataType::Datetime(TimeUnit::Milliseconds, None), + CastOptions::NonStrict, + )?, + rhs, + )? + .cast(&DataType::Date), + dtr => polars_bail!(opq = sub, DataType::Date, dtr), + } + } + + fn add_to(&self, rhs: &Series) -> PolarsResult { + match rhs.dtype() { + DataType::Duration(_) => std::ops::Add::add( + &self.cast( + &DataType::Datetime(TimeUnit::Milliseconds, None), + CastOptions::NonStrict, + )?, + rhs, + )? + .cast(&DataType::Date), + dtr => polars_bail!(opq = add, DataType::Date, dtr), + } + } + + fn multiply(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = mul, self.0.dtype(), rhs.dtype()); + } + + fn divide(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = div, self.0.dtype(), rhs.dtype()); + } + + fn remainder(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = rem, self.0.dtype(), rhs.dtype()); + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.group_tuples(multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_date().into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_date().into_series(), b.into_date().into_series()) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn median(&self) -> Option { + self.0.median() + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let mut other = other.to_physical_repr().into_owned(); + self.0 + .append_owned(std::mem::take(other._get_inner_mut().as_mut())) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(std::mem::take( + &mut other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::() + .unwrap() + .0, + )) + } + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + // 3 refs + // ref Cow + // ref SeriesTrait + // ref ChunkedArray + let other = other.to_physical_repr(); + self.0.extend(other.as_ref().as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + self.0.filter(filter).map(|ca| ca.into_date().into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_date().into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_date().into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_date().into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_date().into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_date().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + self.0 + .new_from_index(index, length) + .into_date() + .into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + match dtype { + DataType::String => Ok(self + .0 + .clone() + .into_series() + .date() + .unwrap() + .to_string("%Y-%m-%d")? + .into_series()), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => { + let mut out = self.0.cast_with_options(dtype, CastOptions::NonStrict)?; + out.set_sorted_flag(self.0.is_sorted_flag()); + Ok(out) + }, + _ => self.0.cast_with_options(dtype, cast_options), + } + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_date().into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + self.0.unique().map(|ca| ca.into_date().into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + self.0.reverse().into_date().into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + self.0.shift(periods).into_date().into_series() + } + + fn max_reduce(&self) -> PolarsResult { + let sc = self.0.max_reduce(); + let av = sc.value().cast(self.dtype()).into_static(); + Ok(Scalar::new(self.dtype().clone(), av)) + } + + fn min_reduce(&self) -> PolarsResult { + let sc = self.0.min_reduce(); + let av = sc.value().cast(self.dtype()).into_static(); + Ok(Scalar::new(self.dtype().clone(), av)) + } + + fn median_reduce(&self) -> PolarsResult { + let av: AnyValue = self + .median() + .map(|v| (v * (MS_IN_DAY as f64)) as i64) + .into(); + Ok(Scalar::new( + DataType::Datetime(TimeUnit::Milliseconds, None), + av, + )) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) + } +} diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs new file mode 100644 index 000000000000..70262a5e7390 --- /dev/null +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -0,0 +1,397 @@ +use polars_compute::rolling::QuantileMethod; + +use super::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +unsafe impl IntoSeries for DatetimeChunked { + fn into_series(self) -> Series { + Series(Arc::new(SeriesWrap(self))) + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) + } +} + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Owned(self.0.field()) + } + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let other = other.to_physical_repr().into_owned(); + self.0.zip_with(mask, other.as_ref().as_ref()).map(|ca| { + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + }) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0 + .agg_min(groups) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0 + .agg_max(groups) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + // we cannot cast and dispatch as the inner type of the list would be incorrect + self.0 + .agg_list(groups) + .cast(&DataType::List(Box::new(self.dtype().clone()))) + .unwrap() + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + match (self.dtype(), rhs.dtype()) { + (DataType::Datetime(tu, tz), DataType::Datetime(tur, tzr)) => { + assert_eq!(tu, tur); + assert_eq!(tz, tzr); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs.subtract(&rhs)?.into_duration(*tu).into_series()) + }, + (DataType::Datetime(tu, tz), DataType::Duration(tur)) => { + assert_eq!(tu, tur); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs + .subtract(&rhs)? + .into_datetime(*tu, tz.clone()) + .into_series()) + }, + (dtl, dtr) => polars_bail!(opq = sub, dtl, dtr), + } + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + match (self.dtype(), rhs.dtype()) { + (DataType::Datetime(tu, tz), DataType::Duration(tur)) => { + assert_eq!(tu, tur); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs + .add_to(&rhs)? + .into_datetime(*tu, tz.clone()) + .into_series()) + }, + (dtl, dtr) => polars_bail!(opq = add, dtl, dtr), + } + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = mul, self.dtype(), rhs.dtype()); + } + fn divide(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = div, self.dtype(), rhs.dtype()); + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = rem, self.dtype(), rhs.dtype()); + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.group_tuples(multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0 + .slice(offset, length) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + ( + a.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series(), + b.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series(), + ) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn median(&self) -> Option { + self.0.median() + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let mut other = other.to_physical_repr().into_owned(); + self.0 + .append_owned(std::mem::take(other._get_inner_mut().as_mut())) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(std::mem::take( + &mut other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::() + .unwrap() + .0, + )) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + let other = other.to_physical_repr(); + self.0.extend(other.as_ref().as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + self.0.filter(filter).map(|ca| { + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + }) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + let ca = self.0.take(indices)?; + Ok(ca + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + let ca = self.0.take_unchecked(indices); + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + let ca = self.0.take(indices)?; + Ok(ca + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + let ca = self.0.take_unchecked(indices); + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0 + .rechunk() + .into_owned() + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + self.0 + .new_from_index(index, length) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + match dtype { + DataType::String => Ok(self.0.to_string("iso")?.into_series()), + _ => self.0.cast_with_options(dtype, cast_options), + } + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self + .0 + .sort_with(options) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + self.0.unique().map(|ca| { + ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + }) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + self.0 + .reverse() + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + self.0 + .shift(periods) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + fn max_reduce(&self) -> PolarsResult { + let sc = self.0.max_reduce(); + + Ok(Scalar::new(self.dtype().clone(), sc.value().clone())) + } + + fn min_reduce(&self) -> PolarsResult { + let sc = self.0.min_reduce(); + + Ok(Scalar::new(self.dtype().clone(), sc.value().clone())) + } + + fn median_reduce(&self) -> PolarsResult { + let av: AnyValue = self.median().map(|v| v as i64).into(); + Ok(Scalar::new(self.dtype().clone(), av)) + } + + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { + Ok(Scalar::new(self.dtype().clone(), AnyValue::Null)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs new file mode 100644 index 000000000000..76a11ab8ed00 --- /dev/null +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -0,0 +1,466 @@ +use polars_compute::rolling::QuantileMethod; + +use super::*; +use crate::prelude::*; + +unsafe impl IntoSeries for DecimalChunked { + fn into_series(self) -> Series { + Series(Arc::new(SeriesWrap(self))) + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} + +impl SeriesWrap { + fn apply_physical_to_s Int128Chunked>(&self, f: F) -> Series { + f(&self.0) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series() + } + + fn apply_physical T>(&self, f: F) -> T { + f(&self.0) + } + + fn scale_factor(&self) -> u128 { + 10u128.pow(self.0.scale() as u32) + } + + fn apply_scale(&self, mut scalar: Scalar) -> Scalar { + if scalar.is_null() { + return scalar; + } + + debug_assert_eq!(scalar.dtype(), &DataType::Float64); + let v = scalar + .value() + .try_extract::() + .expect("should be f64 scalar"); + scalar.update((v / self.scale_factor() as f64).into()); + scalar + } + + fn agg_helper Series>(&self, f: F) -> Series { + let agg_s = f(&self.0); + match agg_s.dtype() { + DataType::Int128 => { + let ca = agg_s.i128().unwrap(); + let ca = ca.as_ref().clone(); + let precision = self.0.precision(); + let scale = self.0.scale(); + ca.into_decimal_unchecked(precision, scale).into_series() + }, + DataType::List(dtype) if matches!(dtype.as_ref(), DataType::Int128) => { + let dtype = self.0.dtype(); + let ca = agg_s.list().unwrap(); + let arr = ca.downcast_iter().next().unwrap(); + // SAFETY: dtype is passed correctly + let precision = self.0.precision(); + let scale = self.0.scale(); + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + dtype, + ) + } + .into_decimal(precision, scale) + .unwrap(); + let new_values = s.array_ref(0).clone(); + let dtype = DataType::Int128; + let arrow_dtype = + ListArray::::default_datatype(dtype.to_arrow(CompatLevel::newest())); + let new_arr = ListArray::::new( + arrow_dtype, + arr.offsets().clone(), + new_values, + arr.validity().cloned(), + ); + unsafe { + ListChunked::from_chunks_and_dtype_unchecked( + agg_s.name().clone(), + vec![Box::new(new_arr)], + DataType::List(Box::new(DataType::Decimal(precision, Some(scale)))), + ) + .into_series() + } + }, + _ => unreachable!(), + } + } +} + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + + fn _field(&self) -> Cow { + Cow::Owned(self.0.field()) + } + + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let other = other.decimal()?; + + Ok(self + .0 + .physical() + .zip_with(mask, other.physical())? + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + self.agg_helper(|ca| ca.agg_sum(groups)) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.agg_helper(|ca| ca.agg_min(groups)) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.agg_helper(|ca| ca.agg_max(groups)) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.agg_helper(|ca| ca.agg_list(groups)) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) - rhs).map(|ca| ca.into_series()) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) + rhs).map(|ca| ca.into_series()) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) * rhs).map(|ca| ca.into_series()) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) / rhs).map(|ca| ca.into_series()) + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.group_tuples(multithreaded, sorted) + } + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name) + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.apply_physical_to_s(|ca| ca.slice(offset, length)) + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + let a = a + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series(); + let b = b + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series(); + (a, b) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let mut other = other.to_physical_repr().into_owned(); + self.0 + .append_owned(std::mem::take(other._get_inner_mut().as_mut())) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(std::mem::take( + &mut other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::() + .unwrap() + .0, + )) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + // 3 refs + // ref Cow + // ref SeriesTrait + // ref ChunkedArray + let other = other.to_physical_repr(); + self.0.extend(other.as_ref().as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + Ok(self + .0 + .filter(filter)? + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0 + .take_unchecked(indices) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .take_unchecked(indices) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + let ca = self.0.rechunk().into_owned(); + ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + self.0 + .new_from_index(index, length) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self + .0 + .sort_with(options) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + Ok(self.apply_physical_to_s(|ca| ca.unique().unwrap())) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + self.apply_physical_to_s(|ca| ca.reverse()) + } + + fn shift(&self, periods: i64) -> Series { + self.apply_physical_to_s(|ca| ca.shift(periods)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn sum_reduce(&self) -> PolarsResult { + Ok(self.apply_physical(|ca| { + let sum = ca.sum(); + let DataType::Decimal(_, Some(scale)) = self.dtype() else { + unreachable!() + }; + let av = AnyValue::Decimal(sum.unwrap(), *scale); + Scalar::new(self.dtype().clone(), av) + })) + } + fn min_reduce(&self) -> PolarsResult { + Ok(self.apply_physical(|ca| { + let min = ca.min(); + let DataType::Decimal(_, Some(scale)) = self.dtype() else { + unreachable!() + }; + let av = if let Some(min) = min { + AnyValue::Decimal(min, *scale) + } else { + AnyValue::Null + }; + Scalar::new(self.dtype().clone(), av) + })) + } + fn max_reduce(&self) -> PolarsResult { + Ok(self.apply_physical(|ca| { + let max = ca.max(); + let DataType::Decimal(_, Some(scale)) = self.dtype() else { + unreachable!() + }; + let av = if let Some(m) = max { + AnyValue::Decimal(m, *scale) + } else { + AnyValue::Null + }; + Scalar::new(self.dtype().clone(), av) + })) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() / self.scale_factor() as f64 + } + + fn mean(&self) -> Option { + self.0.mean().map(|v| v / self.scale_factor() as f64) + } + + fn median(&self) -> Option { + self.0.median().map(|v| v / self.scale_factor() as f64) + } + fn median_reduce(&self) -> PolarsResult { + Ok(self.apply_scale(self.0.median_reduce())) + } + + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof).map(|v| v / self.scale_factor() as f64) + } + fn std_reduce(&self, ddof: u8) -> PolarsResult { + Ok(self.apply_scale(self.0.std_reduce(ddof))) + } + + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + self.0 + .quantile_reduce(quantile, method) + .map(|v| self.apply_scale(v)) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs new file mode 100644 index 000000000000..be95c24eaafa --- /dev/null +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -0,0 +1,553 @@ +use std::ops::DerefMut; + +use polars_compute::rolling::QuantileMethod; + +use super::*; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +unsafe impl IntoSeries for DurationChunked { + fn into_series(self) -> Series { + Series(Arc::new(SeriesWrap(self))) + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) + } +} + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Owned(self.0.field()) + } + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.deref_mut().set_flags(flags) + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.deref().get_flags() + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let other = other.to_physical_repr().into_owned(); + self.0 + .zip_with(mask, other.as_ref().as_ref()) + .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0 + .agg_min(groups) + .into_duration(self.0.time_unit()) + .into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0 + .agg_max(groups) + .into_duration(self.0.time_unit()) + .into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + self.0 + .agg_sum(groups) + .into_duration(self.0.time_unit()) + .into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { + self.0 + .agg_std(groups, ddof) + // cast f64 back to physical type + .cast(&DataType::Int64) + .unwrap() + .into_duration(self.0.time_unit()) + .into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { + self.0 + .agg_var(groups, ddof) + // cast f64 back to physical type + .cast(&DataType::Int64) + .unwrap() + .into_duration(self.0.time_unit()) + .into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + // we cannot cast and dispatch as the inner type of the list would be incorrect + self.0 + .agg_list(groups) + .cast(&DataType::List(Box::new(self.dtype().clone()))) + .unwrap() + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + match (self.dtype(), rhs.dtype()) { + (DataType::Duration(tu), DataType::Duration(tur)) => { + polars_ensure!(tu == tur, InvalidOperation: "units are different"); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs.subtract(&rhs)?.into_duration(*tu).into_series()) + }, + (dtl, dtr) => polars_bail!(opq = sub, dtl, dtr), + } + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + match (self.dtype(), rhs.dtype()) { + (DataType::Duration(tu), DataType::Duration(tur)) => { + polars_ensure!(tu == tur, InvalidOperation: "units are different"); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs.add_to(&rhs)?.into_duration(*tu).into_series()) + }, + (DataType::Duration(tu), DataType::Date) => { + let one_day_in_tu: i64 = match tu { + TimeUnit::Milliseconds => 86_400_000, + TimeUnit::Microseconds => 86_400_000_000, + TimeUnit::Nanoseconds => 86_400_000_000_000, + }; + let lhs = + self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap() / one_day_in_tu; + let rhs = rhs + .cast(&DataType::Int32) + .unwrap() + .cast(&DataType::Int64) + .unwrap(); + Ok(lhs + .add_to(&rhs)? + .cast(&DataType::Int32)? + .into_date() + .into_series()) + }, + (DataType::Duration(tu), DataType::Datetime(tur, tz)) => { + polars_ensure!(tu == tur, InvalidOperation: "units are different"); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs + .add_to(&rhs)? + .into_datetime(*tu, tz.clone()) + .into_series()) + }, + (dtl, dtr) => polars_bail!(opq = add, dtl, dtr), + } + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + let tul = self.0.time_unit(); + match rhs.dtype() { + DataType::Int64 => Ok((&self.0.0 * rhs.i64().unwrap()) + .into_duration(tul) + .into_series()), + dt if dt.is_integer() => { + let rhs = rhs.cast(&DataType::Int64)?; + self.multiply(&rhs) + }, + dt if dt.is_float() => { + let phys = &self.0.0; + let phys_float = phys.cast(dt).unwrap(); + let out = std::ops::Mul::mul(&phys_float, rhs)? + .cast(&DataType::Int64) + .unwrap(); + let phys = out.i64().unwrap().clone(); + Ok(phys.into_duration(tul).into_series()) + }, + _ => { + polars_bail!(opq = mul, self.dtype(), rhs.dtype()); + }, + } + } + fn divide(&self, rhs: &Series) -> PolarsResult { + let tul = self.0.time_unit(); + match rhs.dtype() { + DataType::Duration(tur) => { + if tul == *tur { + // Returns a constant as f64. + Ok(std::ops::Div::div( + &self.0.0.cast(&DataType::Float64).unwrap(), + &rhs.duration().unwrap().0.cast(&DataType::Float64).unwrap(), + )? + .into_series()) + } else { + let rhs = rhs.cast(self.dtype())?; + self.divide(&rhs) + } + }, + DataType::Int64 => Ok((&self.0.0 / rhs.i64().unwrap()) + .into_duration(tul) + .into_series()), + dt if dt.is_integer() => { + let rhs = rhs.cast(&DataType::Int64)?; + self.divide(&rhs) + }, + dt if dt.is_float() => { + let phys = &self.0.0; + let phys_float = phys.cast(dt).unwrap(); + let out = std::ops::Div::div(&phys_float, rhs)? + .cast(&DataType::Int64) + .unwrap(); + let phys = out.i64().unwrap().clone(); + Ok(phys.into_duration(tul).into_series()) + }, + _ => { + polars_bail!(opq = div, self.dtype(), rhs.dtype()); + }, + } + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + polars_ensure!(self.dtype() == rhs.dtype(), InvalidOperation: "dtypes and units must be equal in duration arithmetic"); + let lhs = self.cast(&DataType::Int64, CastOptions::NonStrict).unwrap(); + let rhs = rhs.cast(&DataType::Int64).unwrap(); + Ok(lhs + .remainder(&rhs)? + .into_duration(self.0.time_unit()) + .into_series()) + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.group_tuples(multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0 + .slice(offset, length) + .into_duration(self.0.time_unit()) + .into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + let a = a.into_duration(self.0.time_unit()).into_series(); + let b = b.into_duration(self.0.time_unit()).into_series(); + (a, b) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn median(&self) -> Option { + self.0.median() + } + + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof) + } + + fn var(&self, ddof: u8) -> Option { + self.0.var(ddof) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let mut other = other.to_physical_repr().into_owned(); + self.0 + .append_owned(std::mem::take(other._get_inner_mut().as_mut())) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(std::mem::take( + &mut other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::() + .unwrap() + .0, + )) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + let other = other.to_physical_repr(); + self.0.extend(other.as_ref().as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + self.0 + .filter(filter) + .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_duration(self.0.time_unit()) + .into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0 + .take_unchecked(indices) + .into_duration(self.0.time_unit()) + .into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_duration(self.0.time_unit()) + .into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .take_unchecked(indices) + .into_duration(self.0.time_unit()) + .into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0 + .rechunk() + .into_owned() + .into_duration(self.0.time_unit()) + .into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + self.0 + .new_from_index(index, length) + .into_duration(self.0.time_unit()) + .into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self + .0 + .sort_with(options) + .into_duration(self.0.time_unit()) + .into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + self.0 + .unique() + .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + self.0 + .reverse() + .into_duration(self.0.time_unit()) + .into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + self.0 + .shift(periods) + .into_duration(self.0.time_unit()) + .into_series() + } + + fn sum_reduce(&self) -> PolarsResult { + let sc = self.0.sum_reduce(); + let v = sc.value().as_duration(self.0.time_unit()); + Ok(Scalar::new(self.dtype().clone(), v)) + } + + fn max_reduce(&self) -> PolarsResult { + let sc = self.0.max_reduce(); + let v = sc.value().as_duration(self.0.time_unit()); + Ok(Scalar::new(self.dtype().clone(), v)) + } + fn min_reduce(&self) -> PolarsResult { + let sc = self.0.min_reduce(); + let v = sc.value().as_duration(self.0.time_unit()); + Ok(Scalar::new(self.dtype().clone(), v)) + } + fn std_reduce(&self, ddof: u8) -> PolarsResult { + let sc = self.0.std_reduce(ddof); + let to = self.dtype().to_physical(); + let v = sc.value().cast(&to); + Ok(Scalar::new( + self.dtype().clone(), + v.as_duration(self.0.time_unit()), + )) + } + + fn var_reduce(&self, ddof: u8) -> PolarsResult { + // Why do we go via MilliSeconds here? Seems wrong to me. + // I think we should fix/inspect the tests that fail if we remain on the time-unit here. + let sc = self + .0 + .cast_time_unit(TimeUnit::Milliseconds) + .var_reduce(ddof); + let to = self.dtype().to_physical(); + let v = sc.value().cast(&to); + Ok(Scalar::new( + DataType::Duration(TimeUnit::Milliseconds), + v.as_duration(TimeUnit::Milliseconds), + )) + } + fn median_reduce(&self) -> PolarsResult { + let v: AnyValue = self.median().map(|v| v as i64).into(); + let to = self.dtype().to_physical(); + let v = v.cast(&to); + Ok(Scalar::new( + self.dtype().clone(), + v.as_duration(self.0.time_unit()), + )) + } + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.0.quantile_reduce(quantile, method)?; + let to = self.dtype().to_physical(); + let v = v.value().cast(&to); + Ok(Scalar::new( + self.dtype().clone(), + v.as_duration(self.0.time_unit()), + )) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs new file mode 100644 index 000000000000..a331fbb8a2f8 --- /dev/null +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -0,0 +1,392 @@ +use polars_compute::rolling::QuantileMethod; + +use super::*; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +macro_rules! impl_dyn_series { + ($ca: ident, $pdt:ident) => { + impl private::PrivateSeries for SeriesWrap<$ca> { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + unsafe fn equal_element( + &self, + idx_self: usize, + idx_other: usize, + other: &Series, + ) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type( + &self, + mask: &BooleanChunked, + other: &Series, + ) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()) + .map(|ca| ca.into_series()) + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + self.0.agg_sum(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { + self.agg_std(groups, ddof) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { + self.agg_var(groups, ddof) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + #[cfg(feature = "bitwise")] + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { + self.0.agg_and(groups) + } + #[cfg(feature = "bitwise")] + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { + self.0.agg_or(groups) + } + #[cfg(feature = "bitwise")] + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { + self.0.agg_xor(groups) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::subtract(&self.0, rhs) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::add_to(&self.0, rhs) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::multiply(&self.0, rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::divide(&self.0, rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::remainder(&self.0, rhs) + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } + } + + impl SeriesTrait for SeriesWrap<$ca> { + #[cfg(feature = "rolling_window")] + fn rolling_map( + &self, + _f: &dyn Fn(&Series) -> Series, + _options: RollingOptionsFixedWindow, + ) -> PolarsResult { + ChunkRollApply::rolling_map(&self.0, _f, _options).map(|ca| ca.into_series()) + } + + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + return self.0.slice(offset, length).into_series(); + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.as_ref().as_ref())?; + Ok(()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn median(&self) -> Option { + self.0.median().map(|v| v as f64) + } + + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof) + } + + fn var(&self, ddof: u8) -> Option { + self.0.var(ddof) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + ChunkSort::arg_sort(&self.0, options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + ChunkUnique::n_unique(&self.0) + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + ChunkUnique::arg_unique(&self.0) + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn sum_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::sum_reduce(&self.0)) + } + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + fn median_reduce(&self) -> PolarsResult { + Ok(QuantileAggSeries::median_reduce(&self.0)) + } + fn var_reduce(&self, ddof: u8) -> PolarsResult { + Ok(VarAggSeries::var_reduce(&self.0, ddof)) + } + fn std_reduce(&self, ddof: u8) -> PolarsResult { + Ok(VarAggSeries::std_reduce(&self.0, ddof)) + } + fn quantile_reduce( + &self, + quantile: f64, + method: QuantileMethod, + ) -> PolarsResult { + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) + } + #[cfg(feature = "bitwise")] + fn and_reduce(&self) -> PolarsResult { + let dt = <$pdt as PolarsDataType>::get_dtype(); + let av = self.0.and_reduce().map_or(AnyValue::Null, Into::into); + + Ok(Scalar::new(dt, av)) + } + #[cfg(feature = "bitwise")] + fn or_reduce(&self) -> PolarsResult { + let dt = <$pdt as PolarsDataType>::get_dtype(); + let av = self.0.or_reduce().map_or(AnyValue::Null, Into::into); + + Ok(Scalar::new(dt, av)) + } + #[cfg(feature = "bitwise")] + fn xor_reduce(&self) -> PolarsResult { + let dt = <$pdt as PolarsDataType>::get_dtype(); + let av = self.0.xor_reduce().map_or(AnyValue::Null, Into::into); + + Ok(Scalar::new(dt, av)) + } + + #[cfg(feature = "approx_unique")] + fn approx_n_unique(&self) -> PolarsResult { + Ok(ChunkApproxNUnique::approx_n_unique(&self.0)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + #[cfg(feature = "checked_arithmetic")] + fn checked_div(&self, rhs: &Series) -> PolarsResult { + self.0.checked_div(rhs) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } + } + }; +} + +impl_dyn_series!(Float32Chunked, Float32Type); +impl_dyn_series!(Float64Chunked, Float64Type); diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs new file mode 100644 index 000000000000..869a10e9ee56 --- /dev/null +++ b/crates/polars-core/src/series/implementations/list.rs @@ -0,0 +1,264 @@ +use super::*; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_ord_inner, self) + } + + fn add_to(&self, rhs: &Series) -> PolarsResult { + self.0.add_to(rhs) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + self.0.subtract(rhs) + } + + fn multiply(&self, rhs: &Series) -> PolarsResult { + self.0.multiply(rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + self.0.divide(rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + self.0.remainder(rhs) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn sum_reduce(&self) -> PolarsResult { + polars_bail!( + op = "`sum`", + self.dtype(), + hint = "you may mean to call `concat_list`" + ); + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_series()) + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.as_ref().as_ref()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + if !self.inner_dtype().is_primitive_numeric() { + polars_bail!(opq = unique, self.dtype()); + } + // this can be called in aggregation, so this fast path can be worth a lot + if self.len() < 2 { + return Ok(self.0.clone().into_series()); + } + let main_thread = POOL.current_thread_index().is_none(); + let groups = self.group_tuples(main_thread, false); + // SAFETY: + // groups are in bounds + Ok(unsafe { self.0.clone().into_series().agg_first(&groups?) }) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + // this can be called in aggregation, so this fast path can be worth a lot + match self.len() { + 0 => Ok(0), + 1 => Ok(1), + _ => { + let main_thread = POOL.current_thread_index().is_none(); + let groups = self.group_tuples(main_thread, false)?; + Ok(groups.len()) + }, + } + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + if !self.inner_dtype().is_primitive_numeric() { + polars_bail!(opq = arg_unique, self.dtype()); + } + // this can be called in aggregation, so this fast path can be worth a lot + if self.len() == 1 { + return Ok(IdxCa::new_vec(self.name().clone(), vec![0 as IdxSize])); + } + let main_thread = POOL.current_thread_index().is_none(); + // arg_unique requires a stable order + let groups = self.group_tuples(main_thread, true)?; + let first = groups.take_group_firsts(); + Ok(IdxCa::from_vec(self.name().clone(), first)) + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs new file mode 100644 index 000000000000..063c63238d9d --- /dev/null +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -0,0 +1,525 @@ +#![allow(unsafe_op_in_unsafe_fn)] +#[cfg(feature = "dtype-array")] +mod array; +mod binary; +mod binary_offset; +mod boolean; +#[cfg(feature = "dtype-categorical")] +mod categorical; +#[cfg(feature = "dtype-date")] +mod date; +#[cfg(feature = "dtype-datetime")] +mod datetime; +#[cfg(feature = "dtype-decimal")] +mod decimal; +#[cfg(feature = "dtype-duration")] +mod duration; +mod floats; +mod list; +pub(crate) mod null; +#[cfg(feature = "object")] +mod object; +mod string; +#[cfg(feature = "dtype-struct")] +mod struct_; +#[cfg(feature = "dtype-time")] +mod time; + +use std::any::Any; +use std::borrow::Cow; + +use polars_compute::rolling::QuantileMethod; +use polars_utils::aliases::PlSeedableRandomStateQuality; + +use super::*; +use crate::chunked_array::AsSinglePtr; +use crate::chunked_array::comparison::*; +use crate::chunked_array::ops::compare_inner::{ + IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, +}; + +// Utility wrapper struct +pub(crate) struct SeriesWrap(pub T); + +impl From> for SeriesWrap> { + fn from(ca: ChunkedArray) -> Self { + SeriesWrap(ca) + } +} + +impl Deref for SeriesWrap> { + type Target = ChunkedArray; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +unsafe impl IntoSeries for ChunkedArray +where + SeriesWrap>: SeriesTrait, +{ + fn into_series(self) -> Series + where + Self: Sized, + { + Series(Arc::new(SeriesWrap(self))) + } +} + +macro_rules! impl_dyn_series { + ($ca: ident, $pdt:ty) => { + impl private::PrivateSeries for SeriesWrap<$ca> { + fn compute_len(&mut self) { + self.0.compute_len() + } + + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + unsafe fn equal_element( + &self, + idx_self: usize, + idx_other: usize, + other: &Series, + ) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type( + &self, + mask: &BooleanChunked, + other: &Series, + ) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()) + .map(|ca| ca.into_series()) + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + use DataType::*; + match self.dtype() { + Int8 | UInt8 | Int16 | UInt16 => self + .cast(&Int64, CastOptions::Overflowing) + .unwrap() + .agg_sum(groups), + _ => self.0.agg_sum(groups), + } + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series { + self.0.agg_std(groups, ddof) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series { + self.0.agg_var(groups, ddof) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + #[cfg(feature = "bitwise")] + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { + self.0.agg_and(groups) + } + #[cfg(feature = "bitwise")] + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { + self.0.agg_or(groups) + } + #[cfg(feature = "bitwise")] + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { + self.0.agg_xor(groups) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::subtract(&self.0, rhs) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::add_to(&self.0, rhs) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::multiply(&self.0, rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::divide(&self.0, rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::remainder(&self.0, rhs) + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } + } + + impl SeriesTrait for SeriesWrap<$ca> { + #[cfg(feature = "rolling_window")] + fn rolling_map( + &self, + _f: &dyn Fn(&Series) -> Series, + _options: RollingOptionsFixedWindow, + ) -> PolarsResult { + ChunkRollApply::rolling_map(&self.0, _f, _options).map(|ca| ca.into_series()) + } + + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.as_ref().as_ref())?; + Ok(()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn median(&self) -> Option { + self.0.median() + } + + fn std(&self, ddof: u8) -> Option { + self.0.std(ddof) + } + + fn var(&self, ddof: u8) -> Option { + self.0.var(ddof) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + ChunkSort::arg_sort(&self.0, options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + ChunkUnique::n_unique(&self.0) + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + ChunkUnique::arg_unique(&self.0) + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn sum_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::sum_reduce(&self.0)) + } + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + fn median_reduce(&self) -> PolarsResult { + Ok(QuantileAggSeries::median_reduce(&self.0)) + } + fn var_reduce(&self, ddof: u8) -> PolarsResult { + Ok(VarAggSeries::var_reduce(&self.0, ddof)) + } + fn std_reduce(&self, ddof: u8) -> PolarsResult { + Ok(VarAggSeries::std_reduce(&self.0, ddof)) + } + fn quantile_reduce( + &self, + quantile: f64, + method: QuantileMethod, + ) -> PolarsResult { + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) + } + + #[cfg(feature = "bitwise")] + fn and_reduce(&self) -> PolarsResult { + let dt = <$pdt as PolarsDataType>::get_dtype(); + let av = self.0.and_reduce().map_or(AnyValue::Null, Into::into); + + Ok(Scalar::new(dt, av)) + } + + #[cfg(feature = "bitwise")] + fn or_reduce(&self) -> PolarsResult { + let dt = <$pdt as PolarsDataType>::get_dtype(); + let av = self.0.or_reduce().map_or(AnyValue::Null, Into::into); + + Ok(Scalar::new(dt, av)) + } + + #[cfg(feature = "bitwise")] + fn xor_reduce(&self) -> PolarsResult { + let dt = <$pdt as PolarsDataType>::get_dtype(); + let av = self.0.xor_reduce().map_or(AnyValue::Null, Into::into); + + Ok(Scalar::new(dt, av)) + } + + #[cfg(feature = "approx_unique")] + fn approx_n_unique(&self) -> PolarsResult { + Ok(ChunkApproxNUnique::approx_n_unique(&self.0)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + #[cfg(feature = "checked_arithmetic")] + fn checked_div(&self, rhs: &Series) -> PolarsResult { + self.0.checked_div(rhs) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } + } + }; +} + +#[cfg(feature = "dtype-u8")] +impl_dyn_series!(UInt8Chunked, UInt8Type); +#[cfg(feature = "dtype-u16")] +impl_dyn_series!(UInt16Chunked, UInt16Type); +impl_dyn_series!(UInt32Chunked, UInt32Type); +impl_dyn_series!(UInt64Chunked, UInt64Type); +#[cfg(feature = "dtype-i8")] +impl_dyn_series!(Int8Chunked, Int8Type); +#[cfg(feature = "dtype-i16")] +impl_dyn_series!(Int16Chunked, Int16Type); +impl_dyn_series!(Int32Chunked, Int32Type); +impl_dyn_series!(Int64Chunked, Int64Type); +#[cfg(feature = "dtype-i128")] +impl_dyn_series!(Int128Chunked, Int128Type); + +impl private::PrivateSeriesNumeric for SeriesWrap> { + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +#[cfg(feature = "dtype-array")] +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + let repr = self + .0 + .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) + .unwrap() + .u32() + .unwrap() + .clone(); + + Some(BitRepr::Small(repr)) + } +} diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs new file mode 100644 index 000000000000..e2eebad6463a --- /dev/null +++ b/crates/polars-core/src/series/implementations/null.rs @@ -0,0 +1,363 @@ +use std::any::Any; + +use polars_error::constants::LENGTH_LIMIT_MSG; + +use self::compare_inner::TotalOrdInner; +use super::*; +use crate::prelude::compare_inner::{IntoTotalEqInner, TotalEqInner}; +use crate::prelude::*; +use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; +use crate::series::*; + +impl Series { + pub fn new_null(name: PlSmallStr, len: usize) -> Series { + NullChunked::new(name, len).into_series() + } +} + +#[derive(Clone)] +pub struct NullChunked { + pub(crate) name: PlSmallStr, + length: IdxSize, + // we still need chunks as many series consumers expect + // chunks to be there + chunks: Vec, +} + +impl NullChunked { + pub(crate) fn new(name: PlSmallStr, len: usize) -> Self { + Self { + name, + length: len as IdxSize, + chunks: vec![Box::new(arrow::array::NullArray::new( + ArrowDataType::Null, + len, + ))], + } + } +} +impl PrivateSeriesNumeric for NullChunked { + fn bit_repr(&self) -> Option { + Some(BitRepr::Small(UInt32Chunked::full_null( + self.name.clone(), + self.len(), + ))) + } +} + +impl PrivateSeries for NullChunked { + fn compute_len(&mut self) { + fn inner(chunks: &[ArrayRef]) -> usize { + match chunks.len() { + // fast path + 1 => chunks[0].len(), + _ => chunks.iter().fold(0, |acc, arr| acc + arr.len()), + } + } + self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG); + } + fn _field(&self) -> Cow { + Cow::Owned(Field::new(self.name().clone(), DataType::Null)) + } + + #[allow(unused)] + fn _set_flags(&mut self, flags: StatisticsFlags) {} + + fn _dtype(&self) -> &DataType { + &DataType::Null + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let len = match (self.len(), mask.len(), other.len()) { + (a, b, c) if a == b && b == c => a, + (1, a, b) | (a, 1, b) | (a, b, 1) if a == b => a, + (a, 1, 1) | (1, a, 1) | (1, 1, a) => a, + (_, 0, _) => 0, + _ => { + polars_bail!(ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation") + }, + }; + + Ok(Self::new(self.name().clone(), len).into_series()) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + IntoTotalEqInner::into_total_eq_inner(self) + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_ord_inner, self) + } + + fn subtract(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "subtract") + } + + fn add_to(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "add_to") + } + fn multiply(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "multiply") + } + fn divide(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "divide") + } + fn remainder(&self, _rhs: &Series) -> PolarsResult { + null_arithmetic(self, _rhs, "remainder") + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + Ok(if self.is_empty() { + GroupsType::default() + } else { + GroupsType::Slice { + groups: vec![[0, self.length]], + rolling: false, + } + }) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + AggList::agg_list(self, groups) + } + + fn _get_flags(&self) -> StatisticsFlags { + StatisticsFlags::empty() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + VecHash::vec_hash(self, random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + VecHash::vec_hash_combine(self, build_hasher, hashes)?; + Ok(()) + } +} + +fn null_arithmetic(lhs: &NullChunked, rhs: &Series, op: &str) -> PolarsResult { + let output_len = match (lhs.len(), rhs.len()) { + (1, len_r) => len_r, + (len_l, 1) => len_l, + (len_l, len_r) if len_l == len_r => len_l, + _ => polars_bail!(ComputeError: "Cannot {:?} two series of different lengths.", op), + }; + Ok(NullChunked::new(lhs.name().clone(), output_len).into_series()) +} + +impl SeriesTrait for NullChunked { + fn name(&self) -> &PlSmallStr { + &self.name + } + + fn rename(&mut self, name: PlSmallStr) { + self.name = name + } + + fn chunks(&self) -> &Vec { + &self.chunks + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + &mut self.chunks + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.chunks.iter().map(|chunk| chunk.len()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + NullChunked::new(self.name.clone(), indices.len()).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + NullChunked::new(self.name.clone(), indices.len()).into_series() + } + + fn len(&self) -> usize { + self.length as usize + } + + fn has_nulls(&self) -> bool { + self.len() > 0 + } + + fn rechunk(&self) -> Series { + NullChunked::new(self.name.clone(), self.len()).into_series() + } + + fn drop_nulls(&self) -> Series { + NullChunked::new(self.name.clone(), 0).into_series() + } + + fn cast(&self, dtype: &DataType, _cast_options: CastOptions) -> PolarsResult { + Ok(Series::full_null(self.name.clone(), self.len(), dtype)) + } + + fn null_count(&self) -> usize { + self.len() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + let ca = NullChunked::new(self.name.clone(), self.n_unique().unwrap()); + Ok(ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + let n = if self.is_empty() { 0 } else { 1 }; + Ok(n) + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + let idxs: Vec = (0..self.n_unique().unwrap() as IdxSize).collect(); + Ok(IdxCa::new(self.name().clone(), idxs)) + } + + fn new_from_index(&self, _index: usize, length: usize) -> Series { + NullChunked::new(self.name.clone(), length).into_series() + } + + unsafe fn get_unchecked(&self, _index: usize) -> AnyValue { + AnyValue::Null + } + + fn slice(&self, offset: i64, length: usize) -> Series { + let (chunks, len) = chunkops::slice(&self.chunks, offset, length, self.len()); + NullChunked { + name: self.name.clone(), + length: len as IdxSize, + chunks, + } + .into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (l, r) = chunkops::split_at(self.chunks(), offset, self.len()); + ( + NullChunked { + name: self.name.clone(), + length: l.iter().map(|arr| arr.len() as IdxSize).sum(), + chunks: l, + } + .into_series(), + NullChunked { + name: self.name.clone(), + length: r.iter().map(|arr| arr.len() as IdxSize).sum(), + chunks: r, + } + .into_series(), + ) + } + + fn sort_with(&self, _options: SortOptions) -> PolarsResult { + Ok(self.clone().into_series()) + } + + fn arg_sort(&self, _options: SortOptions) -> IdxCa { + IdxCa::from_vec(self.name().clone(), (0..self.len() as IdxSize).collect()) + } + + fn is_null(&self) -> BooleanChunked { + BooleanChunked::full(self.name().clone(), true, self.len()) + } + + fn is_not_null(&self) -> BooleanChunked { + BooleanChunked::full(self.name().clone(), false, self.len()) + } + + fn reverse(&self) -> Series { + self.clone().into_series() + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + let len = if self.is_empty() { + // We still allow a length of `1` because it could be `lit(true)`. + polars_ensure!(filter.len() <= 1, ShapeMismatch: "filter's length: {} differs from that of the series: 0", filter.len()); + 0 + } else if filter.len() == 1 { + return match filter.get(0) { + Some(true) => Ok(self.clone().into_series()), + None | Some(false) => Ok(NullChunked::new(self.name.clone(), 0).into_series()), + }; + } else { + polars_ensure!(filter.len() == self.len(), ShapeMismatch: "filter's length: {} differs from that of the series: {}", filter.len(), self.len()); + filter.sum().unwrap_or(0) as usize + }; + Ok(NullChunked::new(self.name.clone(), len).into_series()) + } + + fn shift(&self, _periods: i64) -> Series { + self.clone().into_series() + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(other.dtype() == &DataType::Null, ComputeError: "expected null dtype"); + // we don't create a new null array to keep probability of aligned chunks higher + self.length += other.len() as IdxSize; + self.chunks.extend(other.chunks().iter().cloned()); + Ok(()) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(other.dtype() == &DataType::Null, ComputeError: "expected null dtype"); + // we don't create a new null array to keep probability of aligned chunks higher + let other: &mut NullChunked = other._get_inner_mut().as_any_mut().downcast_mut().unwrap(); + self.length += other.len() as IdxSize; + self.chunks.extend(std::mem::take(&mut other.chunks)); + Ok(()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + *self = NullChunked::new(self.name.clone(), self.len() + other.len()); + Ok(()) + } + + fn clone_inner(&self) -> Arc { + Arc::new(self.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn as_phys_any(&self) -> &dyn Any { + self + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} + +unsafe impl IntoSeries for NullChunked { + fn into_series(self) -> Series + where + Self: Sized, + { + Series(Arc::new(self)) + } +} diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs new file mode 100644 index 000000000000..b688379402fa --- /dev/null +++ b/crates/polars-core/src/series/implementations/object.rs @@ -0,0 +1,282 @@ +use std::any::Any; +use std::borrow::Cow; + +use self::compare_inner::TotalOrdInner; +use super::{BitRepr, StatisticsFlags, *}; +use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::object::PolarsObjectSafe; +use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; +use crate::prelude::*; +use crate::series::implementations::SeriesWrap; +use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; + +impl PrivateSeriesNumeric for SeriesWrap> { + fn bit_repr(&self) -> Option { + None + } +} + +impl PrivateSeries for SeriesWrap> +where + T: PolarsObject, +{ + fn get_list_builder( + &self, + _name: PlSmallStr, + _values_capacity: usize, + _list_capacity: usize, + ) -> Box { + ObjectChunked::::get_list_builder(_name, _values_capacity, _list_capacity) + } + + fn compute_len(&mut self) { + self.0.compute_len() + } + + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_ord_inner, self) + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + self.0 + .zip_with(mask, other.as_ref().as_ref()) + .map(|ca| ca.into_series()) + } +} +impl SeriesTrait for SeriesWrap> +where + T: PolarsObject, +{ + fn rename(&mut self, name: PlSmallStr) { + ObjectChunked::rename(&mut self.0, name) + } + + fn chunk_lengths(&self) -> ChunkLenIter { + ObjectChunked::chunk_lengths(&self.0) + } + + fn name(&self) -> &PlSmallStr { + ObjectChunked::name(&self.0) + } + + fn dtype(&self) -> &DataType { + ObjectChunked::dtype(&self.0) + } + + fn chunks(&self) -> &Vec { + ObjectChunked::chunks(&self.0) + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + ObjectChunked::slice(&self.0, offset, length).into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = ObjectChunked::split_at(&self.0, offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.dtype() == other.dtype(), append); + ObjectChunked::append(&mut self.0, other.as_ref().as_ref()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.dtype() == other.dtype(), append); + ObjectChunked::append_owned(&mut self.0, other.take_inner()) + } + + fn extend(&mut self, _other: &Series) -> PolarsResult<()> { + polars_bail!(opq = extend, self.dtype()); + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + let ca = self.rechunk_object(); + Ok(ca.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + let ca = self.rechunk_object(); + ca.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + ObjectChunked::len(&self.0) + } + + fn rechunk(&self) -> Series { + // do not call normal rechunk + self.rechunk_object().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, _cast_options: CastOptions) -> PolarsResult { + if matches!(dtype, DataType::Object(_)) { + Ok(self.0.clone().into_series()) + } else { + Err(PolarsError::ComputeError( + "cannot cast 'Object' type".into(), + )) + } + } + + fn get(&self, index: usize) -> PolarsResult { + ObjectChunked::get_any_value(&self.0, index) + } + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + ObjectChunked::get_any_value_unchecked(&self.0, index) + } + fn null_count(&self) -> usize { + ObjectChunked::null_count(&self.0) + } + + fn has_nulls(&self) -> bool { + ObjectChunked::has_nulls(&self.0) + } + + fn unique(&self) -> PolarsResult { + ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) + } + + fn n_unique(&self) -> PolarsResult { + ChunkUnique::n_unique(&self.0) + } + + fn arg_unique(&self) -> PolarsResult { + ChunkUnique::arg_unique(&self.0) + } + + fn is_null(&self) -> BooleanChunked { + ObjectChunked::is_null(&self.0) + } + + fn is_not_null(&self) -> BooleanChunked { + ObjectChunked::is_not_null(&self.0) + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn get_object(&self, index: usize) -> Option<&dyn PolarsObjectSafe> { + ObjectChunked::::get_object(&self.0, index) + } + + unsafe fn get_object_chunked_unchecked( + &self, + chunk: usize, + index: usize, + ) -> Option<&dyn PolarsObjectSafe> { + ObjectChunked::::get_object_chunked_unchecked(&self.0, chunk, index) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_downcast_object() -> PolarsResult<()> { + #[allow(non_local_definitions)] + impl PolarsObject for i32 { + fn type_name() -> &'static str { + "i32" + } + } + + let ca = ObjectChunked::new_from_vec("a".into(), vec![0i32, 1, 2]); + let s = ca.into_series(); + + let ca = s.as_any().downcast_ref::>().unwrap(); + assert_eq!(*ca.get(0).unwrap(), 0); + assert_eq!(*ca.get(1).unwrap(), 1); + assert_eq!(*ca.get(2).unwrap(), 2); + + Ok(()) + } +} diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs new file mode 100644 index 000000000000..c8d84ff9ed6f --- /dev/null +++ b/crates/polars-core/src/series/implementations/string.rs @@ -0,0 +1,282 @@ +use super::*; +use crate::chunked_array::comparison::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + fn _dtype(&self) -> &DataType { + self.0.ref_field().dtype() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + self.0.equal_element(idx_self, idx_other, other) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) + } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::subtract(&self.0, rhs) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::add_to(&self.0, rhs) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::multiply(&self.0, rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::divide(&self.0, rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + NumOpsDispatch::remainder(&self.0, rhs) + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + IntoGroupsType::group_tuples(&self.0, multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_series(), b.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + // todo! add object + self.0.append(other.as_ref().as_ref())?; + Ok(()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + // todo! add object + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!( + self.0.dtype() == other.dtype(), + SchemaMismatch: "cannot extend Series: data types don't match", + ); + self.0.extend(other.as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + ChunkSort::arg_sort(&self.0, options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + ChunkUnique::n_unique(&self.0) + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + ChunkUnique::arg_unique(&self.0) + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + ChunkReverse::reverse(&self.0).into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + ChunkShift::shift(&self.0, periods).into_series() + } + + fn sum_reduce(&self) -> PolarsResult { + Err(polars_err!( + op = "`sum`", + DataType::String, + hint = "you may mean to call `str.join` or `list.join`" + )) + } + fn max_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::max_reduce(&self.0)) + } + fn min_reduce(&self) -> PolarsResult { + Ok(ChunkAggSeries::min_reduce(&self.0)) + } + + #[cfg(feature = "approx_unique")] + fn approx_n_unique(&self) -> PolarsResult { + Ok(ChunkApproxNUnique::approx_n_unique(&self.0)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs new file mode 100644 index 000000000000..b48f293a021a --- /dev/null +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -0,0 +1,288 @@ +use std::ops::Not; + +use arrow::bitmap::Bitmap; + +use super::*; +use crate::chunked_array::StructChunked; +use crate::prelude::*; +use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; + +impl PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + None + } +} + +impl PrivateSeries for SeriesWrap { + fn _field(&self) -> Cow { + Cow::Borrowed(self.0.ref_field()) + } + + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + + fn compute_len(&mut self) { + self.0.compute_len() + } + + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags); + } + + // TODO! remove this. Very slow. Asof join should use row-encoding. + unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { + let other = other.struct_().unwrap(); + self.0 + .fields_as_series() + .iter() + .zip(other.fields_as_series()) + .all(|(s, other)| s.equal_element(idx_self, idx_other, &other)) + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + let ca = self.0.get_row_encoded(Default::default())?; + ca.group_tuples(multithreaded, sorted) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + self.0 + .zip_with(mask, other.struct_()?) + .map(|ca| ca.into_series()) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_eq_inner, self) + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + invalid_operation_panic!(into_total_ord_inner, self) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + self.0.agg_list(groups) + } + + fn vec_hash( + &self, + build_hasher: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + let mut fields = self.0.fields_as_series().into_iter(); + + if let Some(s) = fields.next() { + s.vec_hash(build_hasher, buf)? + }; + for s in fields { + s.vec_hash_combine(build_hasher, buf)? + } + Ok(()) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name) + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + &self.0.chunks + } + + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_series() + } + + fn split_at(&self, offset: i64) -> (Series, Series) { + let (l, r) = self.0.split_at(offset); + (l.into_series(), r.into_series()) + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append(other.as_ref().as_ref()) + } + fn append_owned(&mut self, other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(other.take_inner()) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + self.0.extend(other.as_ref().as_ref()) + } + + fn filter(&self, _filter: &BooleanChunked) -> PolarsResult { + ChunkFilter::filter(&self.0, _filter).map(|ca| ca.into_series()) + } + + fn take(&self, _indices: &IdxCa) -> PolarsResult { + self.0.take(_indices).map(|ca| ca.into_series()) + } + + unsafe fn take_unchecked(&self, _idx: &IdxCa) -> Series { + self.0.take_unchecked(_idx).into_series() + } + + fn take_slice(&self, _indices: &[IdxSize]) -> PolarsResult { + self.0.take(_indices).map(|ca| ca.into_series()) + } + + unsafe fn take_slice_unchecked(&self, _idx: &[IdxSize]) -> Series { + self.0.take_unchecked(_idx).into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_series() + } + + fn new_from_index(&self, _index: usize, _length: usize) -> Series { + self.0.new_from_index(_index, _length).into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) + } + + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + /// Get unique values in the Series. + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + // this can called in aggregation, so this fast path can be worth a lot + if self.len() < 2 { + return Ok(self.0.clone().into_series()); + } + let main_thread = POOL.current_thread_index().is_none(); + let groups = self.group_tuples(main_thread, false); + // SAFETY: + // groups are in bounds + Ok(unsafe { self.0.clone().into_series().agg_first(&groups?) }) + } + + /// Get unique values in the Series. + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + // this can called in aggregation, so this fast path can be worth a lot + match self.len() { + 0 => Ok(0), + 1 => Ok(1), + _ => { + // TODO! try row encoding + let main_thread = POOL.current_thread_index().is_none(); + let groups = self.group_tuples(main_thread, false)?; + Ok(groups.len()) + }, + } + } + + /// Get first indexes of unique values. + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + // this can called in aggregation, so this fast path can be worth a lot + if self.len() == 1 { + return Ok(IdxCa::new_vec(self.name().clone(), vec![0 as IdxSize])); + } + let main_thread = POOL.current_thread_index().is_none(); + let groups = self.group_tuples(main_thread, true)?; + let first = groups.take_group_firsts(); + Ok(IdxCa::from_vec(self.name().clone(), first)) + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + fn is_null(&self) -> BooleanChunked { + let iter = self.downcast_iter().map(|arr| { + let bitmap = match arr.validity() { + Some(valid) => valid.not(), + None => Bitmap::new_with_value(false, arr.len()), + }; + BooleanArray::from_data_default(bitmap, None) + }); + BooleanChunked::from_chunk_iter(self.name().clone(), iter) + } + + fn is_not_null(&self) -> BooleanChunked { + let iter = self.downcast_iter().map(|arr| { + let bitmap = match arr.validity() { + Some(valid) => valid.clone(), + None => Bitmap::new_with_value(true, arr.len()), + }; + BooleanArray::from_data_default(bitmap, None) + }); + BooleanChunked::from_chunk_iter(self.name().clone(), iter) + } + + fn reverse(&self) -> Series { + let validity = self + .rechunk_validity() + .map(|x| x.into_iter().rev().collect::()); + self.0 + ._apply_fields(|s| s.reverse()) + .unwrap() + .with_outer_validity(validity) + .into_series() + } + + fn shift(&self, periods: i64) -> Series { + self.0.shift(periods).into_series() + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + &self.0 + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } +} diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs new file mode 100644 index 000000000000..39b6a9322fe6 --- /dev/null +++ b/crates/polars-core/src/series/implementations/time.rs @@ -0,0 +1,356 @@ +//! This module exists to reduce compilation times. +//! +//! All the data types are backed by a physical type in memory e.g. Date -> i32, Datetime-> i64. +//! +//! Series lead to code implementations of all traits. Whereas there are a lot of duplicates due to +//! data types being backed by the same physical type. In this module we reduce compile times by +//! opting for a little more run time cost. We cast to the physical type -> apply the operation and +//! (depending on the result) cast back to the original type +//! +use super::*; +#[cfg(feature = "algorithm_group_by")] +use crate::frame::group_by::*; +use crate::prelude::*; + +unsafe impl IntoSeries for TimeChunked { + fn into_series(self) -> Series { + Series(Arc::new(SeriesWrap(self))) + } +} + +impl private::PrivateSeries for SeriesWrap { + fn compute_len(&mut self) { + self.0.compute_len() + } + + fn _field(&self) -> Cow { + Cow::Owned(self.0.field()) + } + + fn _dtype(&self) -> &DataType { + self.0.dtype() + } + + fn _get_flags(&self) -> StatisticsFlags { + self.0.get_flags() + } + + fn _set_flags(&mut self, flags: StatisticsFlags) { + self.0.set_flags(flags) + } + + #[cfg(feature = "zip_with")] + fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let other = other.to_physical_repr().into_owned(); + self.0 + .zip_with(mask, other.as_ref().as_ref()) + .map(|ca| ca.into_time().into_series()) + } + + fn into_total_eq_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + self.0.physical().into_total_ord_inner() + } + + fn vec_hash( + &self, + random_state: PlSeedableRandomStateQuality, + buf: &mut Vec, + ) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine( + &self, + build_hasher: PlSeedableRandomStateQuality, + hashes: &mut [u64], + ) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + self.0.agg_min(groups).into_time().into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + self.0.agg_max(groups).into_time().into_series() + } + + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + // we cannot cast and dispatch as the inner type of the list would be incorrect + self.0 + .agg_list(groups) + .cast(&DataType::List(Box::new(self.dtype().clone()))) + .unwrap() + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.time().map_err(|_| polars_err!(InvalidOperation: "cannot subtract a {} dtype with a series of type: {}", self.dtype(), rhs.dtype()))?; + + let phys = self + .0 + .physical() + .subtract(&rhs.physical().clone().into_series())?; + + Ok(phys.into_duration(TimeUnit::Nanoseconds)) + } + + fn add_to(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = add, DataType::Time, rhs.dtype()); + } + + fn multiply(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = mul, self.0.dtype(), rhs.dtype()); + } + + fn divide(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = div, self.0.dtype(), rhs.dtype()); + } + + fn remainder(&self, rhs: &Series) -> PolarsResult { + polars_bail!(opq = rem, self.0.dtype(), rhs.dtype()); + } + + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.group_tuples(multithreaded, sorted) + } + + fn arg_sort_multiple( + &self, + by: &[Column], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) + } +} + +impl SeriesTrait for SeriesWrap { + fn rename(&mut self, name: PlSmallStr) { + self.0.rename(name); + } + + fn chunk_lengths(&self) -> ChunkLenIter { + self.0.chunk_lengths() + } + fn name(&self) -> &PlSmallStr { + self.0.name() + } + + fn chunks(&self) -> &Vec { + self.0.chunks() + } + unsafe fn chunks_mut(&mut self) -> &mut Vec { + self.0.chunks_mut() + } + + fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit() + } + + fn slice(&self, offset: i64, length: usize) -> Series { + self.0.slice(offset, length).into_time().into_series() + } + fn split_at(&self, offset: i64) -> (Series, Series) { + let (a, b) = self.0.split_at(offset); + (a.into_time().into_series(), b.into_time().into_series()) + } + + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + + fn mean(&self) -> Option { + self.0.mean() + } + + fn median(&self) -> Option { + self.0.median() + } + + fn append(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + let mut other = other.to_physical_repr().into_owned(); + self.0 + .append_owned(std::mem::take(other._get_inner_mut().as_mut())) + } + fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), append); + self.0.append_owned(std::mem::take( + &mut other + ._get_inner_mut() + .as_any_mut() + .downcast_mut::() + .unwrap() + .0, + )) + } + + fn extend(&mut self, other: &Series) -> PolarsResult<()> { + polars_ensure!(self.0.dtype() == other.dtype(), extend); + // 3 refs + // ref Cow + // ref SeriesTrait + // ref ChunkedArray + let other = other.to_physical_repr(); + self.0.extend(other.as_ref().as_ref().as_ref())?; + Ok(()) + } + + fn filter(&self, filter: &BooleanChunked) -> PolarsResult { + self.0.filter(filter).map(|ca| ca.into_time().into_series()) + } + + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self.0.take(indices)?.into_time().into_series()) + } + + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_time().into_series() + } + + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_time().into_series()) + } + + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_time().into_series() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn rechunk(&self) -> Series { + self.0.rechunk().into_owned().into_time().into_series() + } + + fn new_from_index(&self, index: usize, length: usize) -> Series { + self.0 + .new_from_index(index, length) + .into_time() + .into_series() + } + + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + match dtype { + DataType::String => Ok(self + .0 + .clone() + .into_series() + .time() + .unwrap() + .to_string("%T") + .into_series()), + _ => self.0.cast_with_options(dtype, cast_options), + } + } + + #[inline] + unsafe fn get_unchecked(&self, index: usize) -> AnyValue { + self.0.get_any_value_unchecked(index) + } + + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_time().into_series()) + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + + fn null_count(&self) -> usize { + self.0.null_count() + } + + fn has_nulls(&self) -> bool { + self.0.has_nulls() + } + + #[cfg(feature = "algorithm_group_by")] + fn unique(&self) -> PolarsResult { + self.0.unique().map(|ca| ca.into_time().into_series()) + } + + #[cfg(feature = "algorithm_group_by")] + fn n_unique(&self) -> PolarsResult { + self.0.n_unique() + } + + #[cfg(feature = "algorithm_group_by")] + fn arg_unique(&self) -> PolarsResult { + self.0.arg_unique() + } + + fn is_null(&self) -> BooleanChunked { + self.0.is_null() + } + + fn is_not_null(&self) -> BooleanChunked { + self.0.is_not_null() + } + + fn reverse(&self) -> Series { + self.0.reverse().into_time().into_series() + } + + fn as_single_ptr(&mut self) -> PolarsResult { + self.0.as_single_ptr() + } + + fn shift(&self, periods: i64) -> Series { + self.0.shift(periods).into_time().into_series() + } + + fn max_reduce(&self) -> PolarsResult { + let sc = self.0.max_reduce(); + let av = sc.value().cast(self.dtype()).into_static(); + Ok(Scalar::new(self.dtype().clone(), av)) + } + + fn min_reduce(&self) -> PolarsResult { + let sc = self.0.min_reduce(); + let av = sc.value().cast(self.dtype()).into_static(); + Ok(Scalar::new(self.dtype().clone(), av)) + } + + fn median_reduce(&self) -> PolarsResult { + let av = AnyValue::from(self.median().map(|v| v as i64)) + .cast(self.dtype()) + .into_static(); + Ok(Scalar::new(self.dtype().clone(), av)) + } + + fn clone_inner(&self) -> Arc { + Arc::new(SeriesWrap(Clone::clone(&self.0))) + } + + fn as_any(&self) -> &dyn Any { + &self.0 + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + + fn as_phys_any(&self) -> &dyn Any { + self.0.physical() + } + + fn as_arc_any(self: Arc) -> Arc { + self as _ + } +} + +impl private::PrivateSeriesNumeric for SeriesWrap { + fn bit_repr(&self) -> Option { + Some(self.0.to_bit_repr()) + } +} diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs new file mode 100644 index 000000000000..c42b54e349f5 --- /dev/null +++ b/crates/polars-core/src/series/into.rs @@ -0,0 +1,202 @@ +#[cfg(any( + feature = "dtype-datetime", + feature = "dtype-date", + feature = "dtype-duration", + feature = "dtype-time" +))] +use polars_compute::cast::cast_default as cast; +use polars_compute::cast::cast_unchecked; + +use crate::prelude::*; + +impl Series { + /// Returns a reference to the Arrow ArrayRef + #[inline] + pub fn array_ref(&self, chunk_idx: usize) -> &ArrayRef { + &self.chunks()[chunk_idx] as &ArrayRef + } + + /// Convert a chunk in the Series to the correct Arrow type. + /// This conversion is needed because polars doesn't use a + /// 1 on 1 mapping for logical/ categoricals, etc. + pub fn to_arrow(&self, chunk_idx: usize, compat_level: CompatLevel) -> ArrayRef { + match self.dtype() { + // make sure that we recursively apply all logical types. + #[cfg(feature = "dtype-struct")] + dt @ DataType::Struct(fields) => { + let ca = self.struct_().unwrap(); + let arr = ca.downcast_chunks().get(chunk_idx).unwrap(); + let values = arr + .values() + .iter() + .zip(fields.iter()) + .map(|(values, field)| { + let dtype = &field.dtype; + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![values.clone()], + &dtype.to_physical(), + ) + .from_physical_unchecked(dtype) + .unwrap() + }; + s.to_arrow(0, compat_level) + }) + .collect::>(); + StructArray::new( + dt.to_arrow(compat_level), + arr.len(), + values, + arr.validity().cloned(), + ) + .boxed() + }, + // special list branch to + // make sure that we recursively apply all logical types. + DataType::List(inner) => { + let ca = self.list().unwrap(); + let arr = ca.chunks[chunk_idx].clone(); + let arr = arr.as_any().downcast_ref::>().unwrap(); + + let new_values = if let DataType::Null = &**inner { + arr.values().clone() + } else { + // We pass physical arrays and cast to logical before we convert to arrow. + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + &inner.to_physical(), + ) + .from_physical_unchecked(inner) + .unwrap() + }; + + s.to_arrow(0, compat_level) + }; + + let dtype = ListArray::::default_datatype(inner.to_arrow(compat_level)); + let arr = ListArray::::new( + dtype, + arr.offsets().clone(), + new_values, + arr.validity().cloned(), + ); + Box::new(arr) + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, width) => { + let ca = self.array().unwrap(); + let arr = ca.chunks[chunk_idx].clone(); + let arr = arr.as_any().downcast_ref::().unwrap(); + + let new_values = if let DataType::Null = &**inner { + arr.values().clone() + } else { + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + &inner.to_physical(), + ) + .from_physical_unchecked(inner) + .unwrap() + }; + + s.to_arrow(0, compat_level) + }; + + let dtype = + FixedSizeListArray::default_datatype(inner.to_arrow(compat_level), *width); + let arr = + FixedSizeListArray::new(dtype, arr.len(), new_values, arr.validity().cloned()); + Box::new(arr) + }, + #[cfg(feature = "dtype-categorical")] + dt @ (DataType::Categorical(_, ordering) | DataType::Enum(_, ordering)) => { + let ca = self.categorical().unwrap(); + let arr = ca.physical().chunks()[chunk_idx].clone(); + // SAFETY: categoricals are always u32's. + let cats = unsafe { UInt32Chunked::from_chunks(PlSmallStr::EMPTY, vec![arr]) }; + + // SAFETY: we only take a single chunk and change nothing about the index/rev_map mapping. + let new = unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + ca.get_rev_map().clone(), + matches!(dt, DataType::Enum(_, _)), + *ordering, + ) + }; + + new.to_arrow(compat_level, false) + }, + #[cfg(feature = "dtype-date")] + DataType::Date => cast( + &*self.chunks()[chunk_idx], + &DataType::Date.to_arrow(compat_level), + ) + .unwrap(), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => cast( + &*self.chunks()[chunk_idx], + &self.dtype().to_arrow(compat_level), + ) + .unwrap(), + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => cast( + &*self.chunks()[chunk_idx], + &self.dtype().to_arrow(compat_level), + ) + .unwrap(), + #[cfg(feature = "dtype-time")] + DataType::Time => cast( + &*self.chunks()[chunk_idx], + &DataType::Time.to_arrow(compat_level), + ) + .unwrap(), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => self.decimal().unwrap().chunks()[chunk_idx] + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .to(self.dtype().to_arrow(CompatLevel::newest())) + .to_boxed(), + #[cfg(feature = "object")] + DataType::Object(_) => { + use crate::chunked_array::object::builder::object_series_to_arrow_array; + if self.chunks().len() == 1 && chunk_idx == 0 { + object_series_to_arrow_array(self) + } else { + // we slice the series to only that chunk + let offset = self.chunks()[..chunk_idx] + .iter() + .map(|arr| arr.len()) + .sum::() as i64; + let len = self.chunks()[chunk_idx].len(); + let s = self.slice(offset, len); + object_series_to_arrow_array(&s) + } + }, + DataType::String => { + if compat_level.0 >= 1 { + self.array_ref(chunk_idx).clone() + } else { + let arr = self.array_ref(chunk_idx); + cast_unchecked(arr.as_ref(), &ArrowDataType::LargeUtf8).unwrap() + } + }, + DataType::Binary => { + if compat_level.0 >= 1 { + self.array_ref(chunk_idx).clone() + } else { + let arr = self.array_ref(chunk_idx); + cast_unchecked(arr.as_ref(), &ArrowDataType::LargeBinary).unwrap() + } + }, + _ => self.array_ref(chunk_idx).clone(), + } + } +} diff --git a/crates/polars-core/src/series/iterator.rs b/crates/polars-core/src/series/iterator.rs new file mode 100644 index 000000000000..bc9212a89fa1 --- /dev/null +++ b/crates/polars-core/src/series/iterator.rs @@ -0,0 +1,226 @@ +use crate::prelude::any_value::arr_to_any_value; +use crate::prelude::*; +use crate::utils::NoNull; + +macro_rules! from_iterator { + ($native:ty, $variant:ident) => { + impl FromIterator> for Series { + fn from_iter>>(iter: I) -> Self { + let ca: ChunkedArray<$variant> = iter.into_iter().collect(); + ca.into_series() + } + } + + impl FromIterator<$native> for Series { + fn from_iter>(iter: I) -> Self { + let ca: NoNull> = iter.into_iter().collect(); + ca.into_inner().into_series() + } + } + + impl<'a> FromIterator<&'a $native> for Series { + fn from_iter>(iter: I) -> Self { + let ca: ChunkedArray<$variant> = iter.into_iter().map(|v| Some(*v)).collect(); + ca.into_series() + } + } + }; +} + +#[cfg(feature = "dtype-u8")] +from_iterator!(u8, UInt8Type); +#[cfg(feature = "dtype-u16")] +from_iterator!(u16, UInt16Type); +from_iterator!(u32, UInt32Type); +from_iterator!(u64, UInt64Type); +#[cfg(feature = "dtype-i8")] +from_iterator!(i8, Int8Type); +#[cfg(feature = "dtype-i16")] +from_iterator!(i16, Int16Type); +from_iterator!(i32, Int32Type); +from_iterator!(i64, Int64Type); +from_iterator!(f32, Float32Type); +from_iterator!(f64, Float64Type); +from_iterator!(bool, BooleanType); + +impl<'a> FromIterator> for Series { + fn from_iter>>(iter: I) -> Self { + let ca: StringChunked = iter.into_iter().collect(); + ca.into_series() + } +} + +impl<'a> FromIterator<&'a str> for Series { + fn from_iter>(iter: I) -> Self { + let ca: StringChunked = iter.into_iter().collect(); + ca.into_series() + } +} + +impl FromIterator> for Series { + fn from_iter>>(iter: T) -> Self { + let ca: StringChunked = iter.into_iter().collect(); + ca.into_series() + } +} + +impl FromIterator for Series { + fn from_iter>(iter: I) -> Self { + let ca: StringChunked = iter.into_iter().collect(); + ca.into_series() + } +} + +pub type SeriesPhysIter<'a> = Box> + 'a>; + +impl Series { + /// iterate over [`Series`] as [`AnyValue`]. + /// + /// # Panics + /// This will panic if the array is not rechunked first. + pub fn iter(&self) -> SeriesIter<'_> { + let dtype = self.dtype(); + #[cfg(feature = "object")] + assert!( + !matches!(dtype, DataType::Object(_)), + "object dtype not supported in Series.iter" + ); + assert_eq!(self.chunks().len(), 1, "impl error"); + let arr = &*self.chunks()[0]; + let len = arr.len(); + SeriesIter { + arr, + dtype, + idx: 0, + len, + } + } + + pub fn phys_iter(&self) -> SeriesPhysIter<'_> { + let dtype = self.dtype(); + let phys_dtype = dtype.to_physical(); + + assert_eq!(dtype, &phys_dtype, "impl error"); + assert_eq!(self.chunks().len(), 1, "impl error"); + #[cfg(feature = "object")] + assert!( + !matches!(dtype, DataType::Object(_)), + "object dtype not supported in Series.iter" + ); + let arr = &*self.chunks()[0]; + + if phys_dtype.is_primitive_numeric() { + if arr.null_count() == 0 { + with_match_physical_numeric_type!(phys_dtype, |$T| { + let arr = arr.as_any().downcast_ref::>().unwrap(); + let values = arr.values().as_slice(); + Box::new(values.iter().map(|&value| AnyValue::from(value))) as Box> + '_> + }) + } else { + with_match_physical_numeric_type!(phys_dtype, |$T| { + let arr = arr.as_any().downcast_ref::>().unwrap(); + Box::new(arr.iter().map(|value| { + + match value { + Some(value) => AnyValue::from(*value), + None => AnyValue::Null + } + + })) as Box> + '_> + }) + } + } else { + match dtype { + DataType::String => { + let arr = arr.as_any().downcast_ref::().unwrap(); + if arr.null_count() == 0 { + Box::new(arr.values_iter().map(AnyValue::String)) + as Box> + '_> + } else { + let zipvalid = arr.iter(); + Box::new(zipvalid.unwrap_optional().map(|v| match v { + Some(value) => AnyValue::String(value), + None => AnyValue::Null, + })) + as Box> + '_> + } + }, + DataType::Boolean => { + let arr = arr.as_any().downcast_ref::().unwrap(); + if arr.null_count() == 0 { + Box::new(arr.values_iter().map(AnyValue::Boolean)) + as Box> + '_> + } else { + let zipvalid = arr.iter(); + Box::new(zipvalid.unwrap_optional().map(|v| match v { + Some(value) => AnyValue::Boolean(value), + None => AnyValue::Null, + })) + as Box> + '_> + } + }, + _ => Box::new(self.iter()), + } + } + } +} + +pub struct SeriesIter<'a> { + arr: &'a dyn Array, + dtype: &'a DataType, + idx: usize, + len: usize, +} + +impl<'a> Iterator for SeriesIter<'a> { + type Item = AnyValue<'a>; + + #[inline] + fn next(&mut self) -> Option { + let idx = self.idx; + + if idx == self.len { + None + } else { + self.idx += 1; + unsafe { Some(arr_to_any_value(self.arr, idx, self.dtype)) } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.len)) + } +} + +impl ExactSizeIterator for SeriesIter<'_> {} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + fn test_iter() { + let a = Series::new("age".into(), [23, 71, 9].as_ref()); + let _b = a + .i32() + .unwrap() + .into_iter() + .map(|opt_v| opt_v.map(|v| v * 2)); + } + + #[test] + fn test_iter_str() { + let data = [Some("John"), Some("Doe"), None]; + let a: Series = data.into_iter().collect(); + let b = Series::new("".into(), data); + assert_eq!(a, b); + } + + #[test] + fn test_iter_string() { + let data = [Some("John".to_string()), Some("Doe".to_string()), None]; + let a: Series = data.clone().into_iter().collect(); + let b = Series::new("".into(), data); + assert_eq!(a, b); + } +} diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs new file mode 100644 index 000000000000..3a5efd2c33ee --- /dev/null +++ b/crates/polars-core/src/series/mod.rs @@ -0,0 +1,1215 @@ +#![allow(unsafe_op_in_unsafe_fn)] +//! Type agnostic columnar data structure. +use crate::chunked_array::flags::StatisticsFlags; +pub use crate::prelude::ChunkCompareEq; +use crate::prelude::*; +use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; + +macro_rules! invalid_operation_panic { + ($op:ident, $s:expr) => { + panic!( + "`{}` operation not supported for dtype `{}`", + stringify!($op), + $s._dtype() + ) + }; +} + +pub mod amortized_iter; +mod any_value; +pub mod arithmetic; +pub mod builder; +mod comparison; +mod from; +pub mod implementations; +mod into; +pub(crate) mod iterator; +pub mod ops; +mod series_trait; + +use std::borrow::Cow; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; + +use arrow::compute::aggregate::estimated_bytes_size; +use arrow::offset::Offsets; +pub use from::*; +pub use iterator::{SeriesIter, SeriesPhysIter}; +use num_traits::NumCast; +use polars_error::feature_gated; +pub use series_trait::{IsSorted, *}; + +use crate::POOL; +use crate::chunked_array::cast::CastOptions; +#[cfg(feature = "zip_with")] +use crate::series::arithmetic::coerce_lhs_rhs; +use crate::utils::{Wrap, handle_casting_failures, materialize_dyn_int}; + +/// # Series +/// The columnar data type for a DataFrame. +/// +/// Most of the available functions are defined in the [SeriesTrait trait](crate::series::SeriesTrait). +/// +/// The `Series` struct consists +/// of typed [ChunkedArray]'s. To quickly cast +/// a `Series` to a `ChunkedArray` you can call the method with the name of the type: +/// +/// ``` +/// # use polars_core::prelude::*; +/// let s: Series = [1, 2, 3].iter().collect(); +/// // Quickly obtain the ChunkedArray wrapped by the Series. +/// let chunked_array = s.i32().unwrap(); +/// ``` +/// +/// ## Arithmetic +/// +/// You can do standard arithmetic on series. +/// ``` +/// # use polars_core::prelude::*; +/// let s = Series::new("a".into(), [1 , 2, 3]); +/// let out_add = &s + &s; +/// let out_sub = &s - &s; +/// let out_div = &s / &s; +/// let out_mul = &s * &s; +/// ``` +/// +/// Or with series and numbers. +/// +/// ``` +/// # use polars_core::prelude::*; +/// let s: Series = (1..3).collect(); +/// let out_add_one = &s + 1; +/// let out_multiply = &s * 10; +/// +/// // Could not overload left hand side operator. +/// let out_divide = 1.div(&s); +/// let out_add = 1.add(&s); +/// let out_subtract = 1.sub(&s); +/// let out_multiply = 1.mul(&s); +/// ``` +/// +/// ## Comparison +/// You can obtain boolean mask by comparing series. +/// +/// ``` +/// # use polars_core::prelude::*; +/// let s = Series::new("dollars".into(), &[1, 2, 3]); +/// let mask = s.equal(1).unwrap(); +/// let valid = [true, false, false].iter(); +/// assert!(mask +/// .into_iter() +/// .map(|opt_bool| opt_bool.unwrap()) // option, because series can be null +/// .zip(valid) +/// .all(|(a, b)| a == *b)) +/// ``` +/// +/// See all the comparison operators in the [ChunkCompareEq trait](crate::chunked_array::ops::ChunkCompareEq) and +/// [ChunkCompareIneq trait](crate::chunked_array::ops::ChunkCompareIneq). +/// +/// ## Iterators +/// The Series variants contain differently typed [ChunkedArray]s. +/// These structs can be turned into iterators, making it possible to use any function/ closure you want +/// on a Series. +/// +/// These iterators return an `Option` because the values of a series may be null. +/// +/// ``` +/// use polars_core::prelude::*; +/// let pi = 3.14; +/// let s = Series::new("angle".into(), [2f32 * pi, pi, 1.5 * pi].as_ref()); +/// let s_cos: Series = s.f32() +/// .expect("series was not an f32 dtype") +/// .into_iter() +/// .map(|opt_angle| opt_angle.map(|angle| angle.cos())) +/// .collect(); +/// ``` +/// +/// ## Creation +/// Series can be create from different data structures. Below we'll show a few ways we can create +/// a Series object. +/// +/// ``` +/// # use polars_core::prelude::*; +/// // Series can be created from Vec's, slices and arrays +/// Series::new("boolean series".into(), &[true, false, true]); +/// Series::new("int series".into(), &[1, 2, 3]); +/// // And can be nullable +/// Series::new("got nulls".into(), &[Some(1), None, Some(2)]); +/// +/// // Series can also be collected from iterators +/// let from_iter: Series = (0..10) +/// .into_iter() +/// .collect(); +/// +/// ``` +#[derive(Clone)] +#[must_use] +pub struct Series(pub Arc); + +impl PartialEq for Wrap { + fn eq(&self, other: &Self) -> bool { + self.0.equals_missing(other) + } +} + +impl Eq for Wrap {} + +impl Hash for Wrap { + fn hash(&self, state: &mut H) { + let rs = PlSeedableRandomStateQuality::fixed(); + let mut h = vec![]; + if self.0.vec_hash(rs, &mut h).is_ok() { + let h = h.into_iter().fold(0, |a: u64, b| a.wrapping_add(b)); + h.hash(state) + } else { + self.len().hash(state); + self.null_count().hash(state); + self.dtype().hash(state); + } + } +} + +impl Series { + /// Create a new empty Series. + pub fn new_empty(name: PlSmallStr, dtype: &DataType) -> Series { + Series::full_null(name, 0, dtype) + } + + pub fn clear(&self) -> Series { + if self.is_empty() { + self.clone() + } else { + match self.dtype() { + #[cfg(feature = "object")] + DataType::Object(_) => self + .take(&ChunkedArray::::new_vec(PlSmallStr::EMPTY, vec![])) + .unwrap(), + dt => Series::new_empty(self.name().clone(), dt), + } + } + } + + #[doc(hidden)] + pub fn _get_inner_mut(&mut self) -> &mut dyn SeriesTrait { + if Arc::weak_count(&self.0) + Arc::strong_count(&self.0) != 1 { + self.0 = self.0.clone_inner(); + } + Arc::get_mut(&mut self.0).expect("implementation error") + } + + /// Take or clone a owned copy of the inner [`ChunkedArray`]. + pub fn take_inner(self) -> ChunkedArray + where + T: 'static + PolarsDataType, + { + let arc_any = self.0.as_arc_any(); + let downcast = arc_any + .downcast::>>() + .unwrap(); + + match Arc::try_unwrap(downcast) { + Ok(ca) => ca.0, + Err(ca) => ca.as_ref().as_ref().clone(), + } + } + + /// # Safety + /// The caller must ensure the length and the data types of `ArrayRef` does not change. + /// And that the null_count is updated (e.g. with a `compute_len()`) + pub unsafe fn chunks_mut(&mut self) -> &mut Vec { + #[allow(unused_mut)] + let mut ca = self._get_inner_mut(); + ca.chunks_mut() + } + + pub fn into_chunks(mut self) -> Vec { + let ca = self._get_inner_mut(); + let chunks = std::mem::take(unsafe { ca.chunks_mut() }); + ca.compute_len(); + chunks + } + + // TODO! this probably can now be removed, now we don't have special case for structs. + pub fn select_chunk(&self, i: usize) -> Self { + let mut new = self.clear(); + let mut flags = self.get_flags(); + + use StatisticsFlags as F; + flags &= F::IS_SORTED_ANY | F::CAN_FAST_EXPLODE_LIST; + + // Assign mut so we go through arc only once. + let mut_new = new._get_inner_mut(); + let chunks = unsafe { mut_new.chunks_mut() }; + let chunk = self.chunks()[i].clone(); + chunks.clear(); + chunks.push(chunk); + mut_new.compute_len(); + mut_new._set_flags(flags); + new + } + + pub fn is_sorted_flag(&self) -> IsSorted { + if self.len() <= 1 { + return IsSorted::Ascending; + } + self.get_flags().is_sorted() + } + + pub fn set_sorted_flag(&mut self, sorted: IsSorted) { + let mut flags = self.get_flags(); + flags.set_sorted(sorted); + self.set_flags(flags); + } + + pub(crate) fn clear_flags(&mut self) { + self.set_flags(StatisticsFlags::empty()); + } + pub fn get_flags(&self) -> StatisticsFlags { + self.0._get_flags() + } + + pub(crate) fn set_flags(&mut self, flags: StatisticsFlags) { + self._get_inner_mut()._set_flags(flags) + } + + pub fn into_frame(self) -> DataFrame { + // SAFETY: A single-column dataframe cannot have length mismatches or duplicate names + unsafe { DataFrame::new_no_checks(self.len(), vec![self.into()]) } + } + + /// Rename series. + pub fn rename(&mut self, name: PlSmallStr) -> &mut Series { + self._get_inner_mut().rename(name); + self + } + + /// Return this Series with a new name. + pub fn with_name(mut self, name: PlSmallStr) -> Series { + self.rename(name); + self + } + + pub fn from_arrow_chunks(name: PlSmallStr, arrays: Vec) -> PolarsResult { + Self::try_from((name, arrays)) + } + + pub fn from_arrow(name: PlSmallStr, array: ArrayRef) -> PolarsResult { + Self::try_from((name, array)) + } + + /// Shrink the capacity of this array to fit its length. + pub fn shrink_to_fit(&mut self) { + self._get_inner_mut().shrink_to_fit() + } + + /// Append in place. This is done by adding the chunks of `other` to this [`Series`]. + /// + /// See [`ChunkedArray::append`] and [`ChunkedArray::extend`]. + pub fn append(&mut self, other: &Series) -> PolarsResult<&mut Self> { + let must_cast = other.dtype().matches_schema_type(self.dtype())?; + if must_cast { + let other = other.cast(self.dtype())?; + self.append_owned(other)?; + } else { + self._get_inner_mut().append(other)?; + } + Ok(self) + } + + /// Append in place. This is done by adding the chunks of `other` to this [`Series`]. + /// + /// See [`ChunkedArray::append_owned`] and [`ChunkedArray::extend`]. + pub fn append_owned(&mut self, other: Series) -> PolarsResult<&mut Self> { + let must_cast = other.dtype().matches_schema_type(self.dtype())?; + if must_cast { + let other = other.cast(self.dtype())?; + self._get_inner_mut().append_owned(other)?; + } else { + self._get_inner_mut().append_owned(other)?; + } + Ok(self) + } + + /// Redo a length and null_count compute + pub fn compute_len(&mut self) { + self._get_inner_mut().compute_len() + } + + /// Extend the memory backed by this array with the values from `other`. + /// + /// See [`ChunkedArray::extend`] and [`ChunkedArray::append`]. + pub fn extend(&mut self, other: &Series) -> PolarsResult<&mut Self> { + let must_cast = other.dtype().matches_schema_type(self.dtype())?; + if must_cast { + let other = other.cast(self.dtype())?; + self._get_inner_mut().extend(&other)?; + } else { + self._get_inner_mut().extend(other)?; + } + Ok(self) + } + + /// Sort the series with specific options. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// # fn main() -> PolarsResult<()> { + /// let s = Series::new("foo".into(), [2, 1, 3]); + /// let sorted = s.sort(SortOptions::default())?; + /// assert_eq!(sorted, Series::new("foo".into(), [1, 2, 3])); + /// # Ok(()) + /// } + /// ``` + /// + /// See [`SortOptions`] for more options. + pub fn sort(&self, sort_options: SortOptions) -> PolarsResult { + self.sort_with(sort_options) + } + + /// Only implemented for numeric types + pub fn as_single_ptr(&mut self) -> PolarsResult { + self._get_inner_mut().as_single_ptr() + } + + pub fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } + + /// Cast [`Series`] to another [`DataType`]. + pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + use DataType as D; + + let do_clone = match dtype { + D::Unknown(UnknownKind::Any) => true, + D::Unknown(UnknownKind::Int(_)) if self.dtype().is_integer() => true, + D::Unknown(UnknownKind::Float) if self.dtype().is_float() => true, + D::Unknown(UnknownKind::Str) + if self.dtype().is_string() | self.dtype().is_categorical() => + { + true + }, + dt if dt.is_primitive() && dt == self.dtype() => true, + #[cfg(feature = "dtype-categorical")] + D::Enum(None, _) => { + polars_bail!(InvalidOperation: "cannot cast / initialize Enum without categories present"); + }, + _ => false, + }; + + if do_clone { + return Ok(self.clone()); + } + + pub fn cast_dtype(dtype: &DataType) -> Option { + match dtype { + D::Unknown(UnknownKind::Int(v)) => Some(materialize_dyn_int(*v).dtype()), + D::Unknown(UnknownKind::Float) => Some(DataType::Float64), + D::Unknown(UnknownKind::Str) => Some(DataType::String), + // Best leave as is. + D::List(inner) => cast_dtype(inner.as_ref()).map(Box::new).map(D::List), + #[cfg(feature = "dtype-struct")] + D::Struct(fields) => { + // @NOTE: We only allocate if we really need to. + + let mut field_iter = fields.iter().enumerate(); + let mut new_fields = loop { + let (i, field) = field_iter.next()?; + + if let Some(dtype) = cast_dtype(&field.dtype) { + let mut new_fields = Vec::with_capacity(fields.len()); + new_fields.extend(fields.iter().take(i).cloned()); + new_fields.push(Field { + name: field.name.clone(), + dtype, + }); + break new_fields; + } + }; + + new_fields.extend(fields.iter().skip(new_fields.len()).cloned().map(|field| { + let dtype = cast_dtype(&field.dtype).unwrap_or(field.dtype); + Field { + name: field.name.clone(), + dtype, + } + })); + + Some(D::Struct(new_fields)) + }, + _ => None, + } + } + + let casted = cast_dtype(dtype); + let dtype = match casted { + None => dtype, + Some(ref dtype) => dtype, + }; + + // Always allow casting all nulls to other all nulls. + let len = self.len(); + if self.null_count() == len { + return Ok(Series::full_null(self.name().clone(), len, dtype)); + } + + let new_options = match options { + // Strictness is handled on this level to improve error messages. + CastOptions::Strict => CastOptions::NonStrict, + opt => opt, + }; + + let ret = self.0.cast(dtype, new_options); + + match options { + CastOptions::NonStrict | CastOptions::Overflowing => ret, + CastOptions::Strict => { + let ret = ret?; + if self.null_count() != ret.null_count() { + handle_casting_failures(self, &ret)?; + } + Ok(ret) + }, + } + } + + /// Cast from physical to logical types without any checks on the validity of the cast. + /// + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + match self.dtype() { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => self.struct_().unwrap().cast_unchecked(dtype), + DataType::List(_) => self.list().unwrap().cast_unchecked(dtype), + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); + ca.cast_unchecked(dtype) + }) + }, + DataType::Binary => self.binary().unwrap().cast_unchecked(dtype), + _ => self.cast_with_options(dtype, CastOptions::Overflowing), + } + } + + /// Convert a non-logical series back into a logical series without casting. + /// + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn from_physical_unchecked(&self, dtype: &DataType) -> PolarsResult { + debug_assert!(!self.dtype().is_logical()); + + if self.dtype() == dtype { + return Ok(self.clone()); + } + + use DataType as D; + match (self.dtype(), dtype) { + #[cfg(feature = "dtype-decimal")] + (D::Int128, D::Decimal(precision, scale)) => { + self.clone().into_decimal(*precision, scale.unwrap()) + }, + + #[cfg(feature = "dtype-categorical")] + (D::UInt32, D::Categorical(revmap, ordering)) => match revmap { + Some(revmap) => Ok(unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + self.u32().unwrap().clone(), + revmap.clone(), + false, + *ordering, + ) + } + .into_series()), + // In the streaming engine this is `None` and the global string cache is turned on + // for the duration of the query. + None => Ok(unsafe { + CategoricalChunked::from_global_indices_unchecked( + self.u32().unwrap().clone(), + *ordering, + ) + .into_series() + }), + }, + #[cfg(feature = "dtype-categorical")] + (D::UInt32, D::Enum(revmap, ordering)) => Ok(unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + self.u32().unwrap().clone(), + revmap.as_ref().unwrap().clone(), + true, + *ordering, + ) + } + .into_series()), + + (D::Int32, D::Date) => feature_gated!("dtype-time", Ok(self.clone().into_date())), + (D::Int64, D::Datetime(tu, tz)) => feature_gated!( + "dtype-datetime", + Ok(self.clone().into_datetime(*tu, tz.clone())) + ), + (D::Int64, D::Duration(tu)) => { + feature_gated!("dtype-duration", Ok(self.clone().into_duration(*tu))) + }, + (D::Int64, D::Time) => feature_gated!("dtype-time", Ok(self.clone().into_time())), + + (D::List(_), D::List(to)) => unsafe { + self.list() + .unwrap() + .from_physical_unchecked(to.as_ref().clone()) + .map(|ca| ca.into_series()) + }, + #[cfg(feature = "dtype-array")] + (D::Array(_, lw), D::Array(to, rw)) if lw == rw => unsafe { + self.array() + .unwrap() + .from_physical_unchecked(to.as_ref().clone()) + .map(|ca| ca.into_series()) + }, + #[cfg(feature = "dtype-struct")] + (D::Struct(_), D::Struct(to)) => unsafe { + self.struct_() + .unwrap() + .from_physical_unchecked(to.as_slice()) + .map(|ca| ca.into_series()) + }, + + _ => panic!("invalid from_physical({dtype:?}) for {:?}", self.dtype()), + } + } + + /// Cast numerical types to f64, and keep floats as is. + pub fn to_float(&self) -> PolarsResult { + match self.dtype() { + DataType::Float32 | DataType::Float64 => Ok(self.clone()), + _ => self.cast_with_options(&DataType::Float64, CastOptions::Overflowing), + } + } + + /// Compute the sum of all values in this Series. + /// Returns `Some(0)` if the array is empty, and `None` if the array only + /// contains null values. + /// + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. + pub fn sum(&self) -> PolarsResult + where + T: NumCast, + { + let sum = self.sum_reduce()?; + let sum = sum.value().extract().unwrap(); + Ok(sum) + } + + /// Returns the minimum value in the array, according to the natural order. + /// Returns an option because the array is nullable. + pub fn min(&self) -> PolarsResult> + where + T: NumCast, + { + let min = self.min_reduce()?; + let min = min.value().extract::(); + Ok(min) + } + + /// Returns the maximum value in the array, according to the natural order. + /// Returns an option because the array is nullable. + pub fn max(&self) -> PolarsResult> + where + T: NumCast, + { + let max = self.max_reduce()?; + let max = max.value().extract::(); + Ok(max) + } + + /// Explode a list Series. This expands every item to a new row.. + pub fn explode(&self) -> PolarsResult { + match self.dtype() { + DataType::List(_) => self.list().unwrap().explode(), + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => self.array().unwrap().explode(), + _ => Ok(self.clone()), + } + } + + /// Check if numeric value is NaN (note this is different than missing/ null) + pub fn is_nan(&self) -> PolarsResult { + match self.dtype() { + DataType::Float32 => Ok(self.f32().unwrap().is_nan()), + DataType::Float64 => Ok(self.f64().unwrap().is_nan()), + DataType::Null => Ok(BooleanChunked::full_null(self.name().clone(), self.len())), + dt if dt.is_primitive_numeric() => { + let arr = BooleanArray::full(self.len(), false, ArrowDataType::Boolean) + .with_validity(self.rechunk_validity()); + Ok(BooleanChunked::with_chunk(self.name().clone(), arr)) + }, + _ => polars_bail!(opq = is_nan, self.dtype()), + } + } + + /// Check if numeric value is NaN (note this is different than missing/null) + pub fn is_not_nan(&self) -> PolarsResult { + match self.dtype() { + DataType::Float32 => Ok(self.f32().unwrap().is_not_nan()), + DataType::Float64 => Ok(self.f64().unwrap().is_not_nan()), + dt if dt.is_primitive_numeric() => { + let arr = BooleanArray::full(self.len(), true, ArrowDataType::Boolean) + .with_validity(self.rechunk_validity()); + Ok(BooleanChunked::with_chunk(self.name().clone(), arr)) + }, + _ => polars_bail!(opq = is_not_nan, self.dtype()), + } + } + + /// Check if numeric value is finite + pub fn is_finite(&self) -> PolarsResult { + match self.dtype() { + DataType::Float32 => Ok(self.f32().unwrap().is_finite()), + DataType::Float64 => Ok(self.f64().unwrap().is_finite()), + DataType::Null => Ok(BooleanChunked::full_null(self.name().clone(), self.len())), + dt if dt.is_primitive_numeric() => { + let arr = BooleanArray::full(self.len(), true, ArrowDataType::Boolean) + .with_validity(self.rechunk_validity()); + Ok(BooleanChunked::with_chunk(self.name().clone(), arr)) + }, + _ => polars_bail!(opq = is_finite, self.dtype()), + } + } + + /// Check if numeric value is infinite + pub fn is_infinite(&self) -> PolarsResult { + match self.dtype() { + DataType::Float32 => Ok(self.f32().unwrap().is_infinite()), + DataType::Float64 => Ok(self.f64().unwrap().is_infinite()), + DataType::Null => Ok(BooleanChunked::full_null(self.name().clone(), self.len())), + dt if dt.is_primitive_numeric() => { + let arr = BooleanArray::full(self.len(), false, ArrowDataType::Boolean) + .with_validity(self.rechunk_validity()); + Ok(BooleanChunked::with_chunk(self.name().clone(), arr)) + }, + _ => polars_bail!(opq = is_infinite, self.dtype()), + } + } + + /// Create a new ChunkedArray with values from self where the mask evaluates `true` and values + /// from `other` where the mask evaluates `false`. This function automatically broadcasts unit + /// length inputs. + #[cfg(feature = "zip_with")] + pub fn zip_with(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { + let (lhs, rhs) = coerce_lhs_rhs(self, other)?; + lhs.zip_with_same_type(mask, rhs.as_ref()) + } + + /// Converts a Series to their physical representation, if they have one, + /// otherwise the series is left unchanged. + /// + /// * Date -> Int32 + /// * Datetime -> Int64 + /// * Duration -> Int64 + /// * Decimal -> Int128 + /// * Time -> Int64 + /// * Categorical -> UInt32 + /// * List(inner) -> List(physical of inner) + /// * Array(inner) -> Array(physical of inner) + /// * Struct -> Struct with physical repr of each struct column + pub fn to_physical_repr(&self) -> Cow { + use DataType::*; + match self.dtype() { + // NOTE: Don't use cast here, as it might rechunk (if all nulls) + // which is not allowed in a phys repr. + #[cfg(feature = "dtype-date")] + Date => Cow::Owned(self.date().unwrap().0.clone().into_series()), + #[cfg(feature = "dtype-datetime")] + Datetime(_, _) => Cow::Owned(self.datetime().unwrap().0.clone().into_series()), + #[cfg(feature = "dtype-duration")] + Duration(_) => Cow::Owned(self.duration().unwrap().0.clone().into_series()), + #[cfg(feature = "dtype-time")] + Time => Cow::Owned(self.time().unwrap().0.clone().into_series()), + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => { + let ca = self.categorical().unwrap(); + Cow::Owned(ca.physical().clone().into_series()) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Cow::Owned(self.decimal().unwrap().0.clone().into_series()), + List(_) => match self.list().unwrap().to_physical_repr() { + Cow::Borrowed(_) => Cow::Borrowed(self), + Cow::Owned(ca) => Cow::Owned(ca.into_series()), + }, + #[cfg(feature = "dtype-array")] + Array(_, _) => match self.array().unwrap().to_physical_repr() { + Cow::Borrowed(_) => Cow::Borrowed(self), + Cow::Owned(ca) => Cow::Owned(ca.into_series()), + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => match self.struct_().unwrap().to_physical_repr() { + Cow::Borrowed(_) => Cow::Borrowed(self), + Cow::Owned(ca) => Cow::Owned(ca.into_series()), + }, + _ => Cow::Borrowed(self), + } + } + + /// Traverse and collect every nth element in a new array. + pub fn gather_every(&self, n: usize, offset: usize) -> PolarsResult { + polars_ensure!(n > 0, ComputeError: "cannot perform gather every for `n=0`"); + let idx = ((offset as IdxSize)..self.len() as IdxSize) + .step_by(n) + .collect_ca(PlSmallStr::EMPTY); + // SAFETY: we stay in-bounds. + Ok(unsafe { self.take_unchecked(&idx) }) + } + + #[cfg(feature = "dot_product")] + pub fn dot(&self, other: &Series) -> PolarsResult { + std::ops::Mul::mul(self, other)?.sum::() + } + + /// Get the sum of the Series as a new Series of length 1. + /// Returns a Series with a single zeroed entry if self is an empty numeric series. + /// + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. + pub fn sum_reduce(&self) -> PolarsResult { + use DataType::*; + match self.dtype() { + Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().sum_reduce(), + _ => self.0.sum_reduce(), + } + } + + /// Get the product of an array. + /// + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. + pub fn product(&self) -> PolarsResult { + #[cfg(feature = "product")] + { + use DataType::*; + match self.dtype() { + Boolean => self.cast(&DataType::Int64).unwrap().product(), + Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => { + let s = self.cast(&Int64).unwrap(); + s.product() + }, + Int64 => Ok(self.i64().unwrap().prod_reduce()), + UInt64 => Ok(self.u64().unwrap().prod_reduce()), + #[cfg(feature = "dtype-i128")] + Int128 => Ok(self.i128().unwrap().prod_reduce()), + Float32 => Ok(self.f32().unwrap().prod_reduce()), + Float64 => Ok(self.f64().unwrap().prod_reduce()), + dt => { + polars_bail!(InvalidOperation: "`product` operation not supported for dtype `{dt}`") + }, + } + } + #[cfg(not(feature = "product"))] + { + panic!("activate 'product' feature") + } + } + + /// Cast throws an error if conversion had overflows + pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Strict) + } + + #[cfg(feature = "dtype-decimal")] + pub(crate) fn into_decimal( + self, + precision: Option, + scale: usize, + ) -> PolarsResult { + match self.dtype() { + DataType::Int128 => Ok(self + .i128() + .unwrap() + .clone() + .into_decimal(precision, scale)? + .into_series()), + DataType::Decimal(cur_prec, cur_scale) + if (cur_prec.is_none() || precision.is_none() || *cur_prec == precision) + && *cur_scale == Some(scale) => + { + Ok(self) + }, + dt => panic!("into_decimal({precision:?}, {scale}) not implemented for {dt:?}"), + } + } + + #[cfg(feature = "dtype-time")] + pub(crate) fn into_time(self) -> Series { + match self.dtype() { + DataType::Int64 => self.i64().unwrap().clone().into_time().into_series(), + DataType::Time => self + .time() + .unwrap() + .as_ref() + .clone() + .into_time() + .into_series(), + dt => panic!("date not implemented for {dt:?}"), + } + } + + pub(crate) fn into_date(self) -> Series { + #[cfg(not(feature = "dtype-date"))] + { + panic!("activate feature dtype-date") + } + #[cfg(feature = "dtype-date")] + match self.dtype() { + DataType::Int32 => self.i32().unwrap().clone().into_date().into_series(), + DataType::Date => self + .date() + .unwrap() + .as_ref() + .clone() + .into_date() + .into_series(), + dt => panic!("date not implemented for {dt:?}"), + } + } + + #[allow(unused_variables)] + pub(crate) fn into_datetime(self, timeunit: TimeUnit, tz: Option) -> Series { + #[cfg(not(feature = "dtype-datetime"))] + { + panic!("activate feature dtype-datetime") + } + + #[cfg(feature = "dtype-datetime")] + match self.dtype() { + DataType::Int64 => self + .i64() + .unwrap() + .clone() + .into_datetime(timeunit, tz) + .into_series(), + DataType::Datetime(_, _) => self + .datetime() + .unwrap() + .as_ref() + .clone() + .into_datetime(timeunit, tz) + .into_series(), + dt => panic!("into_datetime not implemented for {dt:?}"), + } + } + + #[allow(unused_variables)] + pub(crate) fn into_duration(self, timeunit: TimeUnit) -> Series { + #[cfg(not(feature = "dtype-duration"))] + { + panic!("activate feature dtype-duration") + } + #[cfg(feature = "dtype-duration")] + match self.dtype() { + DataType::Int64 => self + .i64() + .unwrap() + .clone() + .into_duration(timeunit) + .into_series(), + DataType::Duration(_) => self + .duration() + .unwrap() + .as_ref() + .clone() + .into_duration(timeunit) + .into_series(), + dt => panic!("into_duration not implemented for {dt:?}"), + } + } + + // used for formatting + pub fn str_value(&self, index: usize) -> PolarsResult> { + Ok(self.0.get(index)?.str_value()) + } + /// Get the head of the Series. + pub fn head(&self, length: Option) -> Series { + let len = length.unwrap_or(HEAD_DEFAULT_LENGTH); + self.slice(0, std::cmp::min(len, self.len())) + } + + /// Get the tail of the Series. + pub fn tail(&self, length: Option) -> Series { + let len = length.unwrap_or(TAIL_DEFAULT_LENGTH); + let len = std::cmp::min(len, self.len()); + self.slice(-(len as i64), len) + } + + pub fn mean_reduce(&self) -> Scalar { + crate::scalar::reduce::mean_reduce(self.mean(), self.dtype().clone()) + } + + /// Compute the unique elements, but maintain order. This requires more work + /// than a naive [`Series::unique`](SeriesTrait::unique). + pub fn unique_stable(&self) -> PolarsResult { + let idx = self.arg_unique()?; + // SAFETY: Indices are in bounds. + unsafe { Ok(self.take_unchecked(&idx)) } + } + + pub fn try_idx(&self) -> Option<&IdxCa> { + #[cfg(feature = "bigidx")] + { + self.try_u64() + } + #[cfg(not(feature = "bigidx"))] + { + self.try_u32() + } + } + + pub fn idx(&self) -> PolarsResult<&IdxCa> { + #[cfg(feature = "bigidx")] + { + self.u64() + } + #[cfg(not(feature = "bigidx"))] + { + self.u32() + } + } + + /// Returns an estimation of the total (heap) allocated size of the `Series` in bytes. + /// + /// # Implementation + /// This estimation is the sum of the size of its buffers, validity, including nested arrays. + /// Multiple arrays may share buffers and bitmaps. Therefore, the size of 2 arrays is not the + /// sum of the sizes computed from this function. In particular, [`StructArray`]'s size is an upper bound. + /// + /// When an array is sliced, its allocated size remains constant because the buffer unchanged. + /// However, this function will yield a smaller number. This is because this function returns + /// the visible size of the buffer, not its total capacity. + /// + /// FFI buffers are included in this estimation. + pub fn estimated_size(&self) -> usize { + let mut size = 0; + match self.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(Some(rv), _) | DataType::Enum(Some(rv), _) => match &**rv { + RevMapping::Local(arr, _) => size += estimated_bytes_size(arr), + RevMapping::Global(map, arr, _) => { + size += map.capacity() * size_of::() * 2 + estimated_bytes_size(arr); + }, + }, + #[cfg(feature = "object")] + DataType::Object(_) => { + let ArrowDataType::FixedSizeBinary(size) = self.chunks()[0].dtype() else { + unreachable!() + }; + // This is only the pointer size in python. So will be a huge underestimation. + return self.len() * *size; + }, + _ => {}, + } + + size += self + .chunks() + .iter() + .map(|arr| estimated_bytes_size(&**arr)) + .sum::(); + + size + } + + /// Packs every element into a list. + pub fn as_list(&self) -> ListChunked { + let s = self.rechunk(); + // don't use `to_arrow` as we need the physical types + let values = s.chunks()[0].clone(); + let offsets = (0i64..(s.len() as i64 + 1)).collect::>(); + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let dtype = LargeListArray::default_datatype( + s.dtype().to_physical().to_arrow(CompatLevel::newest()), + ); + let new_arr = LargeListArray::new(dtype, offsets.into(), values, None); + let mut out = ListChunked::with_chunk(s.name().clone(), new_arr); + out.set_inner_dtype(s.dtype().clone()); + out + } +} + +impl Deref for Series { + type Target = dyn SeriesTrait; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl<'a> AsRef<(dyn SeriesTrait + 'a)> for Series { + fn as_ref(&self) -> &(dyn SeriesTrait + 'a) { + self.0.as_ref() + } +} + +impl Default for Series { + fn default() -> Self { + Int64Chunked::default().into_series() + } +} + +impl AsRef> for dyn SeriesTrait + '_ +where + T: 'static + PolarsDataType, +{ + fn as_ref(&self) -> &ChunkedArray { + // @NOTE: SeriesTrait `as_any` returns a std::any::Any for the underlying ChunkedArray / + // Logical (so not the SeriesWrap). + let Some(ca) = self.as_any().downcast_ref::>() else { + panic!( + "implementation error, cannot get ref {:?} from {:?}", + T::get_dtype(), + self.dtype() + ); + }; + + ca + } +} + +impl AsMut> for dyn SeriesTrait + '_ +where + T: 'static + PolarsDataType, +{ + fn as_mut(&mut self) -> &mut ChunkedArray { + if !self.as_any_mut().is::>() { + panic!( + "implementation error, cannot get ref {:?} from {:?}", + T::get_dtype(), + self.dtype() + ); + } + + // @NOTE: SeriesTrait `as_any` returns a std::any::Any for the underlying ChunkedArray / + // Logical (so not the SeriesWrap). + self.as_any_mut().downcast_mut::>().unwrap() + } +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + use crate::series::*; + + #[test] + fn cast() { + let ar = UInt32Chunked::new("a".into(), &[1, 2]); + let s = ar.into_series(); + let s2 = s.cast(&DataType::Int64).unwrap(); + + assert!(s2.i64().is_ok()); + let s2 = s.cast(&DataType::Float32).unwrap(); + assert!(s2.f32().is_ok()); + } + + #[test] + fn new_series() { + let _ = Series::new("boolean series".into(), &vec![true, false, true]); + let _ = Series::new("int series".into(), &[1, 2, 3]); + let ca = Int32Chunked::new("a".into(), &[1, 2, 3]); + let _ = ca.into_series(); + } + + #[test] + #[cfg(feature = "dtype-date")] + fn roundtrip_list_logical_20311() { + let list = ListChunked::from_chunk_iter( + PlSmallStr::from_static("a"), + [ListArray::new( + ArrowDataType::LargeList(Box::new(ArrowField::new( + PlSmallStr::from_static("item"), + ArrowDataType::Int32, + true, + ))), + unsafe { Offsets::new_unchecked(vec![0, 1]) }.into(), + PrimitiveArray::new(ArrowDataType::Int32, vec![1i32].into(), None).to_boxed(), + None, + )], + ); + let list = unsafe { list.from_physical_unchecked(DataType::Date) }.unwrap(); + assert_eq!(list.dtype(), &DataType::List(Box::new(DataType::Date))); + } + + #[test] + #[cfg(feature = "dtype-struct")] + fn new_series_from_empty_structs() { + let dtype = DataType::Struct(vec![]); + let empties = vec![AnyValue::StructOwned(Box::new((vec![], vec![]))); 3]; + let s = Series::from_any_values_and_dtype("".into(), &empties, &dtype, false).unwrap(); + assert_eq!(s.len(), 3); + } + #[test] + fn new_series_from_arrow_primitive_array() { + let array = UInt32Array::from_slice([1, 2, 3, 4, 5]); + let array_ref: ArrayRef = Box::new(array); + + let _ = Series::try_new("foo".into(), array_ref).unwrap(); + } + + #[test] + fn series_append() { + let mut s1 = Series::new("a".into(), &[1, 2]); + let s2 = Series::new("b".into(), &[3]); + s1.append(&s2).unwrap(); + assert_eq!(s1.len(), 3); + + // add wrong type + let s2 = Series::new("b".into(), &[3.0]); + assert!(s1.append(&s2).is_err()) + } + + #[test] + #[cfg(feature = "dtype-decimal")] + fn series_append_decimal() { + let s1 = Series::new("a".into(), &[1.1, 2.3]) + .cast(&DataType::Decimal(None, Some(2))) + .unwrap(); + let s2 = Series::new("b".into(), &[3]) + .cast(&DataType::Decimal(None, Some(0))) + .unwrap(); + + { + let mut s1 = s1.clone(); + s1.append(&s2).unwrap(); + assert_eq!(s1.len(), 3); + assert_eq!(s1.get(2).unwrap(), AnyValue::Decimal(300, 2)); + } + + { + let mut s2 = s2.clone(); + s2.extend(&s1).unwrap(); + assert_eq!(s2.get(2).unwrap(), AnyValue::Decimal(2, 0)); + } + } + + #[test] + fn series_slice_works() { + let series = Series::new("a".into(), &[1i64, 2, 3, 4, 5]); + + let slice_1 = series.slice(-3, 3); + let slice_2 = series.slice(-5, 5); + let slice_3 = series.slice(0, 5); + + assert_eq!(slice_1.get(0).unwrap(), AnyValue::Int64(3)); + assert_eq!(slice_2.get(0).unwrap(), AnyValue::Int64(1)); + assert_eq!(slice_3.get(0).unwrap(), AnyValue::Int64(1)); + } + + #[test] + fn out_of_range_slice_does_not_panic() { + let series = Series::new("a".into(), &[1i64, 2, 3, 4, 5]); + + let _ = series.slice(-3, 4); + let _ = series.slice(-6, 2); + let _ = series.slice(4, 2); + } +} diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs new file mode 100644 index 000000000000..2a8cd18f824c --- /dev/null +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -0,0 +1,365 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use crate::prelude::*; +use crate::series::implementations::null::NullChunked; + +macro_rules! unpack_chunked_err { + ($series:expr => $name:expr) => { + polars_err!(SchemaMismatch: "invalid series dtype: expected `{}`, got `{}` for series with name `{}`", $name, $series.dtype(), $series.name()) + }; +} + +macro_rules! try_unpack_chunked { + ($series:expr, $expected:pat => $ca:ty) => { + match $series.dtype() { + $expected => { + // Check downcast in debug compiles + #[cfg(debug_assertions)] + { + Some($series.as_ref().as_any().downcast_ref::<$ca>().unwrap()) + } + #[cfg(not(debug_assertions))] + unsafe { + Some(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) + } + }, + _ => None, + } + }; +} + +impl Series { + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int8`] + pub fn try_i8(&self) -> Option<&Int8Chunked> { + try_unpack_chunked!(self, DataType::Int8 => Int8Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int16`] + pub fn try_i16(&self) -> Option<&Int16Chunked> { + try_unpack_chunked!(self, DataType::Int16 => Int16Chunked) + } + + /// Unpack to [`ChunkedArray`] + /// ``` + /// # use polars_core::prelude::*; + /// let s = Series::new("foo".into(), [1i32 ,2, 3]); + /// let s_squared: Series = s.i32() + /// .unwrap() + /// .into_iter() + /// .map(|opt_v| { + /// match opt_v { + /// Some(v) => Some(v * v), + /// None => None, // null value + /// } + /// }).collect(); + /// ``` + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int32`] + pub fn try_i32(&self) -> Option<&Int32Chunked> { + try_unpack_chunked!(self, DataType::Int32 => Int32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int64`] + pub fn try_i64(&self) -> Option<&Int64Chunked> { + try_unpack_chunked!(self, DataType::Int64 => Int64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int128`] + #[cfg(feature = "dtype-i128")] + pub fn try_i128(&self) -> Option<&Int128Chunked> { + try_unpack_chunked!(self, DataType::Int128 => Int128Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Float32`] + pub fn try_f32(&self) -> Option<&Float32Chunked> { + try_unpack_chunked!(self, DataType::Float32 => Float32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Float64`] + pub fn try_f64(&self) -> Option<&Float64Chunked> { + try_unpack_chunked!(self, DataType::Float64 => Float64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt8`] + pub fn try_u8(&self) -> Option<&UInt8Chunked> { + try_unpack_chunked!(self, DataType::UInt8 => UInt8Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt16`] + pub fn try_u16(&self) -> Option<&UInt16Chunked> { + try_unpack_chunked!(self, DataType::UInt16 => UInt16Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt32`] + pub fn try_u32(&self) -> Option<&UInt32Chunked> { + try_unpack_chunked!(self, DataType::UInt32 => UInt32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt64`] + pub fn try_u64(&self) -> Option<&UInt64Chunked> { + try_unpack_chunked!(self, DataType::UInt64 => UInt64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Boolean`] + pub fn try_bool(&self) -> Option<&BooleanChunked> { + try_unpack_chunked!(self, DataType::Boolean => BooleanChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::String`] + pub fn try_str(&self) -> Option<&StringChunked> { + try_unpack_chunked!(self, DataType::String => StringChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Binary`] + pub fn try_binary(&self) -> Option<&BinaryChunked> { + try_unpack_chunked!(self, DataType::Binary => BinaryChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Binary`] + pub fn try_binary_offset(&self) -> Option<&BinaryOffsetChunked> { + try_unpack_chunked!(self, DataType::BinaryOffset => BinaryOffsetChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Time`] + #[cfg(feature = "dtype-time")] + pub fn try_time(&self) -> Option<&TimeChunked> { + try_unpack_chunked!(self, DataType::Time => TimeChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Date`] + #[cfg(feature = "dtype-date")] + pub fn try_date(&self) -> Option<&DateChunked> { + try_unpack_chunked!(self, DataType::Date => DateChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Datetime`] + #[cfg(feature = "dtype-datetime")] + pub fn try_datetime(&self) -> Option<&DatetimeChunked> { + try_unpack_chunked!(self, DataType::Datetime(_, _) => DatetimeChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Duration`] + #[cfg(feature = "dtype-duration")] + pub fn try_duration(&self) -> Option<&DurationChunked> { + try_unpack_chunked!(self, DataType::Duration(_) => DurationChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Decimal`] + #[cfg(feature = "dtype-decimal")] + pub fn try_decimal(&self) -> Option<&DecimalChunked> { + try_unpack_chunked!(self, DataType::Decimal(_, _) => DecimalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype list + pub fn try_list(&self) -> Option<&ListChunked> { + try_unpack_chunked!(self, DataType::List(_) => ListChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Array`] + #[cfg(feature = "dtype-array")] + pub fn try_array(&self) -> Option<&ArrayChunked> { + try_unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] + #[cfg(feature = "dtype-categorical")] + pub fn try_categorical(&self) -> Option<&CategoricalChunked> { + try_unpack_chunked!(self, DataType::Categorical(_, _) | DataType::Enum(_, _) => CategoricalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Struct`] + #[cfg(feature = "dtype-struct")] + pub fn try_struct(&self) -> Option<&StructChunked> { + #[cfg(debug_assertions)] + { + if let DataType::Struct(_) = self.dtype() { + let any = self.as_any(); + assert!(any.is::()); + } + } + try_unpack_chunked!(self, DataType::Struct(_) => StructChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Null`] + pub fn try_null(&self) -> Option<&NullChunked> { + try_unpack_chunked!(self, DataType::Null => NullChunked) + } + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int8`] + pub fn i8(&self) -> PolarsResult<&Int8Chunked> { + self.try_i8() + .ok_or_else(|| unpack_chunked_err!(self => "Int8")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int16`] + pub fn i16(&self) -> PolarsResult<&Int16Chunked> { + self.try_i16() + .ok_or_else(|| unpack_chunked_err!(self => "Int16")) + } + + /// Unpack to [`ChunkedArray`] + /// ``` + /// # use polars_core::prelude::*; + /// let s = Series::new("foo".into(), [1i32 ,2, 3]); + /// let s_squared: Series = s.i32() + /// .unwrap() + /// .into_iter() + /// .map(|opt_v| { + /// match opt_v { + /// Some(v) => Some(v * v), + /// None => None, // null value + /// } + /// }).collect(); + /// ``` + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int32`] + pub fn i32(&self) -> PolarsResult<&Int32Chunked> { + self.try_i32() + .ok_or_else(|| unpack_chunked_err!(self => "Int32")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int64`] + pub fn i64(&self) -> PolarsResult<&Int64Chunked> { + self.try_i64() + .ok_or_else(|| unpack_chunked_err!(self => "Int64")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Int128`] + #[cfg(feature = "dtype-i128")] + pub fn i128(&self) -> PolarsResult<&Int128Chunked> { + self.try_i128() + .ok_or_else(|| unpack_chunked_err!(self => "Int128")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Float32`] + pub fn f32(&self) -> PolarsResult<&Float32Chunked> { + self.try_f32() + .ok_or_else(|| unpack_chunked_err!(self => "Float32")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Float64`] + pub fn f64(&self) -> PolarsResult<&Float64Chunked> { + self.try_f64() + .ok_or_else(|| unpack_chunked_err!(self => "Float64")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt8`] + pub fn u8(&self) -> PolarsResult<&UInt8Chunked> { + self.try_u8() + .ok_or_else(|| unpack_chunked_err!(self => "UInt8")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt16`] + pub fn u16(&self) -> PolarsResult<&UInt16Chunked> { + self.try_u16() + .ok_or_else(|| unpack_chunked_err!(self => "UInt16")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt32`] + pub fn u32(&self) -> PolarsResult<&UInt32Chunked> { + self.try_u32() + .ok_or_else(|| unpack_chunked_err!(self => "UInt32")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::UInt64`] + pub fn u64(&self) -> PolarsResult<&UInt64Chunked> { + self.try_u64() + .ok_or_else(|| unpack_chunked_err!(self => "UInt64")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Boolean`] + pub fn bool(&self) -> PolarsResult<&BooleanChunked> { + self.try_bool() + .ok_or_else(|| unpack_chunked_err!(self => "Boolean")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::String`] + pub fn str(&self) -> PolarsResult<&StringChunked> { + self.try_str() + .ok_or_else(|| unpack_chunked_err!(self => "String")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Binary`] + pub fn binary(&self) -> PolarsResult<&BinaryChunked> { + self.try_binary() + .ok_or_else(|| unpack_chunked_err!(self => "Binary")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Binary`] + pub fn binary_offset(&self) -> PolarsResult<&BinaryOffsetChunked> { + self.try_binary_offset() + .ok_or_else(|| unpack_chunked_err!(self => "BinaryOffset")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Time`] + #[cfg(feature = "dtype-time")] + pub fn time(&self) -> PolarsResult<&TimeChunked> { + self.try_time() + .ok_or_else(|| unpack_chunked_err!(self => "Time")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Date`] + #[cfg(feature = "dtype-date")] + pub fn date(&self) -> PolarsResult<&DateChunked> { + self.try_date() + .ok_or_else(|| unpack_chunked_err!(self => "Date")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Datetime`] + #[cfg(feature = "dtype-datetime")] + pub fn datetime(&self) -> PolarsResult<&DatetimeChunked> { + self.try_datetime() + .ok_or_else(|| unpack_chunked_err!(self => "Datetime")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Duration`] + #[cfg(feature = "dtype-duration")] + pub fn duration(&self) -> PolarsResult<&DurationChunked> { + self.try_duration() + .ok_or_else(|| unpack_chunked_err!(self => "Duration")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Decimal`] + #[cfg(feature = "dtype-decimal")] + pub fn decimal(&self) -> PolarsResult<&DecimalChunked> { + self.try_decimal() + .ok_or_else(|| unpack_chunked_err!(self => "Decimal")) + } + + /// Unpack to [`ChunkedArray`] of dtype list + pub fn list(&self) -> PolarsResult<&ListChunked> { + self.try_list() + .ok_or_else(|| unpack_chunked_err!(self => "List")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Array`] + #[cfg(feature = "dtype-array")] + pub fn array(&self) -> PolarsResult<&ArrayChunked> { + self.try_array() + .ok_or_else(|| unpack_chunked_err!(self => "FixedSizeList")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Categorical`] + #[cfg(feature = "dtype-categorical")] + pub fn categorical(&self) -> PolarsResult<&CategoricalChunked> { + self.try_categorical() + .ok_or_else(|| unpack_chunked_err!(self => "Enum | Categorical")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Struct`] + #[cfg(feature = "dtype-struct")] + pub fn struct_(&self) -> PolarsResult<&StructChunked> { + #[cfg(debug_assertions)] + { + if let DataType::Struct(_) = self.dtype() { + let any = self.as_any(); + assert!(any.is::()); + } + } + + self.try_struct() + .ok_or_else(|| unpack_chunked_err!(self => "Struct")) + } + + /// Unpack to [`ChunkedArray`] of dtype [`DataType::Null`] + pub fn null(&self) -> PolarsResult<&NullChunked> { + self.try_null() + .ok_or_else(|| unpack_chunked_err!(self => "Null")) + } +} diff --git a/crates/polars-core/src/series/ops/extend.rs b/crates/polars-core/src/series/ops/extend.rs new file mode 100644 index 000000000000..8bb72d515d59 --- /dev/null +++ b/crates/polars-core/src/series/ops/extend.rs @@ -0,0 +1,15 @@ +use crate::prelude::*; + +impl Series { + /// Extend with a constant value. + pub fn extend_constant(&self, value: AnyValue, n: usize) -> PolarsResult { + // TODO: Use `from_any_values_and_dtype` here instead of casting afterwards + let s = Series::from_any_values(PlSmallStr::EMPTY, &[value], true).unwrap(); + let s = s.cast(self.dtype())?; + let to_append = s.new_from_index(0, n); + + let mut out = self.clone(); + out.append(&to_append)?; + Ok(out) + } +} diff --git a/crates/polars-core/src/series/ops/mod.rs b/crates/polars-core/src/series/ops/mod.rs new file mode 100644 index 000000000000..ad927e834f9c --- /dev/null +++ b/crates/polars-core/src/series/ops/mod.rs @@ -0,0 +1,16 @@ +mod downcast; +mod extend; +mod null; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +mod reshape; + +#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum NullBehavior { + /// drop nulls + Drop, + /// ignore nulls + #[default] + Ignore, +} diff --git a/crates/polars-core/src/series/ops/null.rs b/crates/polars-core/src/series/ops/null.rs new file mode 100644 index 000000000000..7f17e2120e4d --- /dev/null +++ b/crates/polars-core/src/series/ops/null.rs @@ -0,0 +1,120 @@ +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::offset::OffsetsBuffer; + +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_builder; +use crate::prelude::*; + +impl Series { + pub fn full_null(name: PlSmallStr, size: usize, dtype: &DataType) -> Self { + // match the logical types and create them + match dtype { + DataType::List(inner_dtype) => { + ListChunked::full_null_with_dtype(name, size, inner_dtype).into_series() + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner_dtype, width) => { + ArrayChunked::full_null_with_dtype(name, size, inner_dtype, *width).into_series() + }, + #[cfg(feature = "dtype-categorical")] + dt @ (DataType::Categorical(rev_map, ord) | DataType::Enum(rev_map, ord)) => { + let mut ca = CategoricalChunked::full_null( + name, + matches!(dt, DataType::Enum(_, _)), + size, + *ord, + ); + // ensure we keep the rev-map of a cleared series + if let Some(rev_map) = rev_map { + unsafe { ca.set_rev_map(rev_map.clone(), false) } + } + ca.into_series() + }, + #[cfg(feature = "dtype-date")] + DataType::Date => Int32Chunked::full_null(name, size) + .into_date() + .into_series(), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => Int64Chunked::full_null(name, size) + .into_datetime(*tu, tz.clone()) + .into_series(), + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => Int64Chunked::full_null(name, size) + .into_duration(*tu) + .into_series(), + #[cfg(feature = "dtype-time")] + DataType::Time => Int64Chunked::full_null(name, size) + .into_time() + .into_series(), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => Int128Chunked::full_null(name, size) + .into_decimal_unchecked(*precision, scale.unwrap_or(0)) + .into_series(), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|fld| Series::full_null(fld.name().clone(), size, fld.dtype())) + .collect::>(); + let ca = StructChunked::from_series(name, size, fields.iter()).unwrap(); + + if !fields.is_empty() { + ca.with_outer_validity(Some(Bitmap::new_zeroed(size))) + .into_series() + } else { + ca.into_series() + } + }, + DataType::BinaryOffset => { + let length = size; + + let offsets = vec![0; size + 1]; + let array = BinaryArray::::new( + dtype.to_arrow(CompatLevel::oldest()), + unsafe { OffsetsBuffer::new_unchecked(Buffer::from(offsets)) }, + Buffer::default(), + Some(Bitmap::new_zeroed(size)), + ); + + unsafe { + BinaryOffsetChunked::new_with_dims( + Arc::new(Field::new(name, dtype.clone())), + vec![Box::new(array)], + length, + length, + ) + } + .into_series() + }, + DataType::Null => Series::new_null(name, size), + DataType::Unknown(kind) => { + let dtype = kind.materialize().unwrap_or(DataType::Null); + Series::full_null(name, size, &dtype) + }, + #[cfg(feature = "object")] + DataType::Object(_) => { + let mut builder = get_object_builder(name, size); + for _ in 0..size { + builder.append_null(); + } + builder.to_series() + }, + _ => { + macro_rules! primitive { + ($type:ty) => {{ ChunkedArray::<$type>::full_null(name, size).into_series() }}; + } + macro_rules! bool { + () => {{ ChunkedArray::::full_null(name, size).into_series() }}; + } + macro_rules! string { + () => {{ ChunkedArray::::full_null(name, size).into_series() }}; + } + macro_rules! binary { + () => {{ ChunkedArray::::full_null(name, size).into_series() }}; + } + match_dtype_to_logical_apply_macro!(dtype, primitive, string, binary, bool) + }, + } + } +} diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs new file mode 100644 index 000000000000..72b8be28c843 --- /dev/null +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -0,0 +1,346 @@ +use std::borrow::Cow; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::offset::{Offsets, OffsetsBuffer}; +use polars_compute::gather::sublist::list::array_to_unit_list; +use polars_error::{PolarsResult, polars_bail, polars_ensure}; +use polars_utils::format_tuple; + +use crate::chunked_array::builder::get_list_builder; +use crate::datatypes::{DataType, ListChunked}; +use crate::prelude::{IntoSeries, Series, *}; + +fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series { + let mut ca = ListChunked::from_chunk_iter( + name, + s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())), + ); + + ca.set_inner_dtype(s.dtype().clone()); + ca.set_fast_explode(); + ca.into_series() +} + +impl Series { + /// Recurse nested types until we are at the leaf array. + pub fn get_leaf_array(&self) -> Series { + let s = self; + match s.dtype() { + #[cfg(feature = "dtype-array")] + DataType::Array(dtype, _) => { + let ca = s.array().unwrap(); + let chunks = ca + .downcast_iter() + .map(|arr| arr.values().clone()) + .collect::>(); + // Safety: guarded by the type system + unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) } + .get_leaf_array() + }, + DataType::List(dtype) => { + let ca = s.list().unwrap(); + let chunks = ca + .downcast_iter() + .map(|arr| arr.values().clone()) + .collect::>(); + // Safety: guarded by the type system + unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) } + .get_leaf_array() + }, + _ => s.clone(), + } + } + + /// TODO: Move this somewhere else? + pub fn list_offsets_and_validities_recursive( + &self, + ) -> (Vec>, Vec>) { + let mut offsets = vec![]; + let mut validities = vec![]; + + let mut s = self.rechunk(); + + while let DataType::List(_) = s.dtype() { + let ca = s.list().unwrap(); + offsets.push(ca.offsets().unwrap()); + validities.push(ca.rechunk_validity()); + s = ca.get_inner(); + } + + (offsets, validities) + } + + /// For ListArrays, recursively normalizes the offsets to begin from 0, and + /// slices excess length from the values array. + pub fn list_rechunk_and_trim_to_normalized_offsets(&self) -> Self { + if let Some(ca) = self.try_list() { + ca.rechunk_and_trim_to_normalized_offsets().into_series() + } else { + self.rechunk() + } + } + + /// Convert the values of this Series to a ListChunked with a length of 1, + /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`. + pub fn implode(&self) -> PolarsResult { + let s = self; + let s = s.rechunk(); + let values = s.array_ref(0); + + let offsets = vec![0i64, values.len() as i64]; + let inner_type = s.dtype(); + + let dtype = ListArray::::default_datatype(values.dtype().clone()); + + // SAFETY: offsets are correct. + let arr = unsafe { + ListArray::new( + dtype, + Offsets::new_unchecked(offsets).into(), + values.clone(), + None, + ) + }; + + let mut ca = ListChunked::with_chunk(s.name().clone(), arr); + unsafe { ca.to_logical(inner_type.clone()) }; + ca.set_fast_explode(); + Ok(ca) + } + + #[cfg(feature = "dtype-array")] + pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { + polars_ensure!( + !dimensions.is_empty(), + InvalidOperation: "at least one dimension must be specified" + ); + + let leaf_array = self.get_leaf_array().rechunk(); + let size = leaf_array.len(); + + let mut total_dim_size = 1; + let mut num_infers = 0; + for &dim in dimensions { + match dim { + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, + } + } + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + if size == 0 { + polars_ensure!( + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), + ); + + let mut prev_arrow_dtype = leaf_array + .dtype() + .to_physical() + .to_arrow(CompatLevel::newest()); + let mut prev_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array.chunks()[0].clone(); + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. + for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + leaf_array.name().clone(), + vec![prev_array], + &prev_dtype, + ) + }); + } + + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + + let leaf_array = leaf_array.rechunk(); + let mut prev_arrow_dtype = leaf_array + .dtype() + .to_physical() + .to_arrow(CompatLevel::newest()); + let mut prev_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array.chunks()[0].clone(); + + // We pop the outer dimension as that is the height of the series. + for dim in dimensions[1..].iter().rev() { + // Infer dimension if needed + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); + + prev_array = FixedSizeListArray::new( + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, + prev_array, + None, + ) + .boxed(); + } + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + leaf_array.name().clone(), + vec![prev_array], + &prev_dtype, + ) + }) + } + + pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { + polars_ensure!( + !dimensions.is_empty(), + InvalidOperation: "at least one dimension must be specified" + ); + + let s = self; + let s = if let DataType::List(_) = s.dtype() { + Cow::Owned(s.explode()?) + } else { + Cow::Borrowed(s) + }; + + let s_ref = s.as_ref(); + + // let dimensions = dimensions.to_vec(); + + match dimensions.len() { + 1 => { + polars_ensure!( + dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()), + InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions, + ); + Ok(s_ref.clone()) + }, + 2 => { + let rows = dimensions[0]; + let cols = dimensions[1]; + + if s_ref.is_empty() { + if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 { + let s = reshape_fast_path(s.name().clone(), s_ref); + return Ok(s); + } else { + polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions)) + } + } + + use ReshapeDimension as RD; + // Infer dimension. + + let (rows, cols) = match (rows, cols) { + (RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => { + (s_ref.len() as u64 / cols.get(), cols.get()) + }, + (RD::Specified(rows), RD::Infer) if rows.get() >= 1 => { + (rows.get(), s_ref.len() as u64 / rows.get()) + }, + (RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64), + (RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()), + _ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"), + }; + + // Fast path, we can create a unit list so we only allocate offsets. + if rows as usize == s_ref.len() && cols == 1 { + let s = reshape_fast_path(s.name().clone(), s_ref); + return Ok(s); + } + + polars_ensure!( + (rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1, + InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions, + ); + + let mut builder = + get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone()); + + let mut offset = 0u64; + for _ in 0..rows { + let row = s_ref.slice(offset as i64, cols as usize); + builder.append_series(&row).unwrap(); + offset += cols; + } + Ok(builder.finish().into_series()) + }, + _ => { + polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type."); + }, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::prelude::*; + + #[test] + fn test_to_list() -> PolarsResult<()> { + let s = Series::new("a".into(), &[1, 2, 3]); + + let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone()); + builder.append_series(&s).unwrap(); + let expected = builder.finish(); + + let out = s.implode()?; + assert!(expected.into_series().equals(&out.into_series())); + + Ok(()) + } + + #[test] + fn test_reshape() -> PolarsResult<()> { + let s = Series::new("a".into(), &[1, 2, 3, 4]); + + for (dims, list_len) in [ + (&[-1, 1], 4), + (&[4, 1], 4), + (&[2, 2], 2), + (&[-1, 2], 2), + (&[2, -1], 2), + ] { + let dims = dims + .iter() + .map(|&v| ReshapeDimension::new(v)) + .collect::>(); + let out = s.reshape_list(&dims)?; + assert_eq!(out.len(), list_len); + assert!(matches!(out.dtype(), DataType::List(_))); + assert_eq!(out.explode()?.len(), 4); + } + + Ok(()) + } +} diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs new file mode 100644 index 000000000000..1298216c90e7 --- /dev/null +++ b/crates/polars-core/src/series/series_trait.rs @@ -0,0 +1,612 @@ +use std::any::Any; +use std::borrow::Cow; + +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use polars_compute::rolling::QuantileMethod; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::chunked_array::cast::CastOptions; +#[cfg(feature = "object")] +use crate::chunked_array::object::PolarsObjectSafe; +use crate::prelude::*; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum IsSorted { + Ascending, + Descending, + Not, +} + +impl IsSorted { + pub fn reverse(self) -> Self { + use IsSorted::*; + match self { + Ascending => Descending, + Descending => Ascending, + Not => Not, + } + } +} + +pub enum BitRepr { + Small(UInt32Chunked), + Large(UInt64Chunked), +} + +pub(crate) mod private { + use polars_utils::aliases::PlSeedableRandomStateQuality; + + use super::*; + use crate::chunked_array::flags::StatisticsFlags; + use crate::chunked_array::ops::compare_inner::{TotalEqInner, TotalOrdInner}; + + pub trait PrivateSeriesNumeric { + /// Return a bit representation + /// + /// If there is no available bit representation this returns `None`. + fn bit_repr(&self) -> Option; + } + + pub trait PrivateSeries { + #[cfg(feature = "object")] + fn get_list_builder( + &self, + _name: PlSmallStr, + _values_capacity: usize, + _list_capacity: usize, + ) -> Box { + invalid_operation_panic!(get_list_builder, self) + } + + /// Get field (used in schema) + fn _field(&self) -> Cow; + + fn _dtype(&self) -> &DataType; + + fn compute_len(&mut self); + + fn _get_flags(&self) -> StatisticsFlags; + + fn _set_flags(&mut self, flags: StatisticsFlags); + + unsafe fn equal_element( + &self, + _idx_self: usize, + _idx_other: usize, + _other: &Series, + ) -> bool { + invalid_operation_panic!(equal_element, self) + } + #[expect(clippy::wrong_self_convention)] + fn into_total_eq_inner<'a>(&'a self) -> Box; + #[expect(clippy::wrong_self_convention)] + fn into_total_ord_inner<'a>(&'a self) -> Box; + + fn vec_hash( + &self, + _build_hasher: PlSeedableRandomStateQuality, + _buf: &mut Vec, + ) -> PolarsResult<()> { + polars_bail!(opq = vec_hash, self._dtype()); + } + fn vec_hash_combine( + &self, + _build_hasher: PlSeedableRandomStateQuality, + _hashes: &mut [u64], + ) -> PolarsResult<()> { + polars_bail!(opq = vec_hash_combine, self._dtype()); + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_min(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_max(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_sum(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_std(&self, groups: &GroupsType, _ddof: u8) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_var(&self, groups: &GroupsType, _ddof: u8) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "algorithm_group_by")] + unsafe fn agg_list(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "bitwise")] + unsafe fn agg_and(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "bitwise")] + unsafe fn agg_or(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + + /// # Safety + /// + /// Does no bounds checks, groups must be correct. + #[cfg(feature = "bitwise")] + unsafe fn agg_xor(&self, groups: &GroupsType) -> Series { + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) + } + + fn subtract(&self, _rhs: &Series) -> PolarsResult { + polars_bail!(opq = subtract, self._dtype()); + } + fn add_to(&self, _rhs: &Series) -> PolarsResult { + polars_bail!(opq = add, self._dtype()); + } + fn multiply(&self, _rhs: &Series) -> PolarsResult { + polars_bail!(opq = multiply, self._dtype()); + } + fn divide(&self, _rhs: &Series) -> PolarsResult { + polars_bail!(opq = divide, self._dtype()); + } + fn remainder(&self, _rhs: &Series) -> PolarsResult { + polars_bail!(opq = remainder, self._dtype()); + } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { + polars_bail!(opq = group_tuples, self._dtype()); + } + #[cfg(feature = "zip_with")] + fn zip_with_same_type( + &self, + _mask: &BooleanChunked, + _other: &Series, + ) -> PolarsResult { + polars_bail!(opq = zip_with_same_type, self._dtype()); + } + + #[allow(unused_variables)] + fn arg_sort_multiple( + &self, + by: &[Column], + _options: &SortMultipleOptions, + ) -> PolarsResult { + polars_bail!(opq = arg_sort_multiple, self._dtype()); + } + } +} + +pub trait SeriesTrait: + Send + Sync + private::PrivateSeries + private::PrivateSeriesNumeric +{ + /// Rename the Series. + fn rename(&mut self, name: PlSmallStr); + + /// Get the lengths of the underlying chunks + fn chunk_lengths(&self) -> ChunkLenIter; + + /// Name of series. + fn name(&self) -> &PlSmallStr; + + /// Get field (used in schema) + fn field(&self) -> Cow { + self._field() + } + + /// Get datatype of series. + fn dtype(&self) -> &DataType { + self._dtype() + } + + /// Underlying chunks. + fn chunks(&self) -> &Vec; + + /// Underlying chunks. + /// + /// # Safety + /// The caller must ensure the length and the data types of `ArrayRef` does not change. + unsafe fn chunks_mut(&mut self) -> &mut Vec; + + /// Number of chunks in this Series + fn n_chunks(&self) -> usize { + self.chunks().len() + } + + /// Shrink the capacity of this array to fit its length. + fn shrink_to_fit(&mut self) { + // no-op + } + + /// Take `num_elements` from the top as a zero copy view. + fn limit(&self, num_elements: usize) -> Series { + self.slice(0, num_elements) + } + + /// Get a zero copy view of the data. + /// + /// When offset is negative the offset is counted from the + /// end of the array + fn slice(&self, _offset: i64, _length: usize) -> Series; + + /// Get a zero copy view of the data. + /// + /// When offset is negative the offset is counted from the + /// end of the array + fn split_at(&self, _offset: i64) -> (Series, Series); + + fn append(&mut self, other: &Series) -> PolarsResult<()>; + fn append_owned(&mut self, other: Series) -> PolarsResult<()>; + + #[doc(hidden)] + fn extend(&mut self, _other: &Series) -> PolarsResult<()>; + + /// Filter by boolean mask. This operation clones data. + fn filter(&self, _filter: &BooleanChunked) -> PolarsResult; + + /// Take from `self` at the indexes given by `idx`. + /// + /// Null values in `idx` because null values in the output array. + /// + /// This operation is clone. + fn take(&self, _indices: &IdxCa) -> PolarsResult; + + /// Take from `self` at the indexes given by `idx`. + /// + /// Null values in `idx` because null values in the output array. + /// + /// # Safety + /// This doesn't check any bounds. + unsafe fn take_unchecked(&self, _idx: &IdxCa) -> Series; + + /// Take from `self` at the indexes given by `idx`. + /// + /// This operation is clone. + fn take_slice(&self, _indices: &[IdxSize]) -> PolarsResult; + + /// Take from `self` at the indexes given by `idx`. + /// + /// # Safety + /// This doesn't check any bounds. + unsafe fn take_slice_unchecked(&self, _idx: &[IdxSize]) -> Series; + + /// Get length of series. + fn len(&self) -> usize; + + /// Check if Series is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Aggregate all chunks to a contiguous array of memory. + fn rechunk(&self) -> Series; + + fn rechunk_validity(&self) -> Option { + if self.chunks().len() == 1 { + return self.chunks()[0].validity().cloned(); + } + + if !self.has_nulls() || self.is_empty() { + return None; + } + + let mut bm = BitmapBuilder::with_capacity(self.len()); + for arr in self.chunks() { + if let Some(v) = arr.validity() { + bm.extend_from_bitmap(v); + } else { + bm.extend_constant(arr.len(), true); + } + } + bm.into_opt_validity() + } + + /// Drop all null values and return a new Series. + fn drop_nulls(&self) -> Series { + if self.null_count() == 0 { + Series(self.clone_inner()) + } else { + self.filter(&self.is_not_null()).unwrap() + } + } + + /// Returns the sum of the array as an f64. + fn _sum_as_f64(&self) -> f64 { + invalid_operation_panic!(_sum_as_f64, self) + } + + /// Returns the mean value in the array + /// Returns an option because the array is nullable. + fn mean(&self) -> Option { + None + } + + /// Returns the std value in the array + /// Returns an option because the array is nullable. + fn std(&self, _ddof: u8) -> Option { + None + } + + /// Returns the var value in the array + /// Returns an option because the array is nullable. + fn var(&self, _ddof: u8) -> Option { + None + } + + /// Returns the median value in the array + /// Returns an option because the array is nullable. + fn median(&self) -> Option { + None + } + + /// Create a new Series filled with values from the given index. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// let s = Series::new("a".into(), [0i32, 1, 8]); + /// let s2 = s.new_from_index(2, 4); + /// assert_eq!(Vec::from(s2.i32().unwrap()), &[Some(8), Some(8), Some(8), Some(8)]) + /// ``` + fn new_from_index(&self, _index: usize, _length: usize) -> Series; + + fn cast(&self, _dtype: &DataType, options: CastOptions) -> PolarsResult; + + /// Get a single value by index. Don't use this operation for loops as a runtime cast is + /// needed for every iteration. + fn get(&self, index: usize) -> PolarsResult { + polars_ensure!(index < self.len(), oob = index, self.len()); + // SAFETY: Just did bounds check + let value = unsafe { self.get_unchecked(index) }; + Ok(value) + } + + /// Get a single value by index. Don't use this operation for loops as a runtime cast is + /// needed for every iteration. + /// + /// This may refer to physical types + /// + /// # Safety + /// Does not do any bounds checking + unsafe fn get_unchecked(&self, _index: usize) -> AnyValue; + + fn sort_with(&self, _options: SortOptions) -> PolarsResult { + polars_bail!(opq = sort_with, self._dtype()); + } + + /// Retrieve the indexes needed for a sort. + #[allow(unused)] + fn arg_sort(&self, options: SortOptions) -> IdxCa { + invalid_operation_panic!(arg_sort, self) + } + + /// Count the null values. + fn null_count(&self) -> usize; + + /// Return if any the chunks in this [`ChunkedArray`] have nulls. + fn has_nulls(&self) -> bool; + + /// Get unique values in the Series. + fn unique(&self) -> PolarsResult { + polars_bail!(opq = unique, self._dtype()); + } + + /// Get unique values in the Series. + /// + /// A `null` value also counts as a unique value. + fn n_unique(&self) -> PolarsResult { + polars_bail!(opq = n_unique, self._dtype()); + } + + /// Get first indexes of unique values. + fn arg_unique(&self) -> PolarsResult { + polars_bail!(opq = arg_unique, self._dtype()); + } + + /// Get a mask of the null values. + fn is_null(&self) -> BooleanChunked; + + /// Get a mask of the non-null values. + fn is_not_null(&self) -> BooleanChunked; + + /// return a Series in reversed order + fn reverse(&self) -> Series; + + /// Rechunk and return a pointer to the start of the Series. + /// Only implemented for numeric types + fn as_single_ptr(&mut self) -> PolarsResult { + polars_bail!(opq = as_single_ptr, self._dtype()); + } + + /// Shift the values by a given period and fill the parts that will be empty due to this operation + /// with `Nones`. + /// + /// *NOTE: If you want to fill the Nones with a value use the + /// [`shift` operation on `ChunkedArray`](../chunked_array/ops/trait.ChunkShift.html).* + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// fn example() -> PolarsResult<()> { + /// let s = Series::new("series".into(), &[1, 2, 3]); + /// + /// let shifted = s.shift(1); + /// assert_eq!(Vec::from(shifted.i32()?), &[None, Some(1), Some(2)]); + /// + /// let shifted = s.shift(-1); + /// assert_eq!(Vec::from(shifted.i32()?), &[Some(2), Some(3), None]); + /// + /// let shifted = s.shift(2); + /// assert_eq!(Vec::from(shifted.i32()?), &[None, None, Some(1)]); + /// + /// Ok(()) + /// } + /// example(); + /// ``` + fn shift(&self, _periods: i64) -> Series; + + /// Get the sum of the Series as a new Scalar. + /// + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. + fn sum_reduce(&self) -> PolarsResult { + polars_bail!(opq = sum, self._dtype()); + } + /// Get the max of the Series as a new Series of length 1. + fn max_reduce(&self) -> PolarsResult { + polars_bail!(opq = max, self._dtype()); + } + /// Get the min of the Series as a new Series of length 1. + fn min_reduce(&self) -> PolarsResult { + polars_bail!(opq = min, self._dtype()); + } + /// Get the median of the Series as a new Series of length 1. + fn median_reduce(&self) -> PolarsResult { + polars_bail!(opq = median, self._dtype()); + } + /// Get the variance of the Series as a new Series of length 1. + fn var_reduce(&self, _ddof: u8) -> PolarsResult { + polars_bail!(opq = var, self._dtype()); + } + /// Get the standard deviation of the Series as a new Series of length 1. + fn std_reduce(&self, _ddof: u8) -> PolarsResult { + polars_bail!(opq = std, self._dtype()); + } + /// Get the quantile of the ChunkedArray as a new Series of length 1. + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { + polars_bail!(opq = quantile, self._dtype()); + } + /// Get the bitwise AND of the Series as a new Series of length 1, + fn and_reduce(&self) -> PolarsResult { + polars_bail!(opq = and_reduce, self._dtype()); + } + /// Get the bitwise OR of the Series as a new Series of length 1, + fn or_reduce(&self) -> PolarsResult { + polars_bail!(opq = or_reduce, self._dtype()); + } + /// Get the bitwise XOR of the Series as a new Series of length 1, + fn xor_reduce(&self) -> PolarsResult { + polars_bail!(opq = xor_reduce, self._dtype()); + } + + /// Get the first element of the [`Series`] as a [`Scalar`] + /// + /// If the [`Series`] is empty, a [`Scalar`] with a [`AnyValue::Null`] is returned. + fn first(&self) -> Scalar { + let dt = self.dtype(); + let av = self.get(0).map_or(AnyValue::Null, AnyValue::into_static); + + Scalar::new(dt.clone(), av) + } + + /// Get the last element of the [`Series`] as a [`Scalar`] + /// + /// If the [`Series`] is empty, a [`Scalar`] with a [`AnyValue::Null`] is returned. + fn last(&self) -> Scalar { + let dt = self.dtype(); + let av = if self.len() == 0 { + AnyValue::Null + } else { + // SAFETY: len-1 < len if len != 0 + unsafe { self.get_unchecked(self.len() - 1) }.into_static() + }; + + Scalar::new(dt.clone(), av) + } + + #[cfg(feature = "approx_unique")] + fn approx_n_unique(&self) -> PolarsResult { + polars_bail!(opq = approx_n_unique, self._dtype()); + } + + /// Clone inner ChunkedArray and wrap in a new Arc + fn clone_inner(&self) -> Arc; + + #[cfg(feature = "object")] + /// Get the value at this index as a downcastable Any trait ref. + fn get_object(&self, _index: usize) -> Option<&dyn PolarsObjectSafe> { + invalid_operation_panic!(get_object, self) + } + + #[cfg(feature = "object")] + /// Get the value at this index as a downcastable Any trait ref. + /// + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn get_object_chunked_unchecked( + &self, + _chunk: usize, + _index: usize, + ) -> Option<&dyn PolarsObjectSafe> { + invalid_operation_panic!(get_object_chunked_unchecked, self) + } + + /// Get a hold of the [`ChunkedArray`], [`Logical`] or `NullChunked` as an `Any` trait + /// reference. + fn as_any(&self) -> &dyn Any; + + /// Get a hold of the [`ChunkedArray`], [`Logical`] or `NullChunked` as an `Any` trait mutable + /// reference. + fn as_any_mut(&mut self) -> &mut dyn Any; + + /// Get a hold of the [`ChunkedArray`] or `NullChunked` as an `Any` trait reference. This + /// pierces through `Logical` types to get the underlying physical array. + fn as_phys_any(&self) -> &dyn Any; + + fn as_arc_any(self: Arc) -> Arc; + + #[cfg(feature = "checked_arithmetic")] + fn checked_div(&self, _rhs: &Series) -> PolarsResult { + polars_bail!(opq = checked_div, self._dtype()); + } + + #[cfg(feature = "rolling_window")] + /// Apply a custom function over a rolling/ moving window of the array. + /// This has quite some dynamic dispatch, so prefer rolling_min, max, mean, sum over this. + fn rolling_map( + &self, + _f: &dyn Fn(&Series) -> Series, + _options: RollingOptionsFixedWindow, + ) -> PolarsResult { + polars_bail!(opq = rolling_map, self._dtype()); + } +} + +impl (dyn SeriesTrait + '_) { + pub fn unpack(&self) -> PolarsResult<&ChunkedArray> + where + N: 'static + PolarsDataType, + { + polars_ensure!(&N::get_dtype() == self.dtype(), unpack); + Ok(self.as_ref()) + } +} diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs new file mode 100644 index 000000000000..ed7c3d4fbd3e --- /dev/null +++ b/crates/polars-core/src/testing.rs @@ -0,0 +1,204 @@ +//! Testing utilities. + +use crate::prelude::*; + +impl Series { + /// Check if series are equal. Note that `None == None` evaluates to `false` + pub fn equals(&self, other: &Series) -> bool { + if self.null_count() > 0 || other.null_count() > 0 { + false + } else { + self.equals_missing(other) + } + } + + /// Check if all values in series are equal where `None == None` evaluates to `true`. + pub fn equals_missing(&self, other: &Series) -> bool { + match (self.dtype(), other.dtype()) { + // Two [`Datetime`](DataType::Datetime) series are *not* equal if their timezones + // are different, regardless if they represent the same UTC time or not. + #[cfg(feature = "timezones")] + (DataType::Datetime(_, tz_lhs), DataType::Datetime(_, tz_rhs)) => { + if tz_lhs != tz_rhs { + return false; + } + }, + _ => {}, + } + + // Differs from Partial::eq in that numerical dtype may be different + self.len() == other.len() && self.null_count() == other.null_count() && { + let eq = self.equal_missing(other); + match eq { + Ok(b) => b.all(), + Err(_) => false, + } + } + } +} + +impl PartialEq for Series { + fn eq(&self, other: &Self) -> bool { + self.equals_missing(other) + } +} + +impl DataFrame { + /// Check if [`DataFrame`]' schemas are equal. + pub fn schema_equal(&self, other: &DataFrame) -> PolarsResult<()> { + for (lhs, rhs) in self.iter().zip(other.iter()) { + polars_ensure!( + lhs.name() == rhs.name(), + SchemaMismatch: "column name mismatch: left-hand = '{}', right-hand = '{}'", + lhs.name(), rhs.name() + ); + polars_ensure!( + lhs.dtype() == rhs.dtype(), + SchemaMismatch: "column datatype mismatch: left-hand = '{}', right-hand = '{}'", + lhs.dtype(), rhs.dtype() + ); + } + Ok(()) + } + + /// Check if [`DataFrame`]s are equal. Note that `None == None` evaluates to `false` + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Atomic number" => &[1, 51, 300], + /// "Element" => &[Some("Hydrogen"), Some("Antimony"), None])?; + /// let df2: DataFrame = df!("Atomic number" => &[1, 51, 300], + /// "Element" => &[Some("Hydrogen"), Some("Antimony"), None])?; + /// + /// assert!(!df1.equals(&df2)); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn equals(&self, other: &DataFrame) -> bool { + if self.shape() != other.shape() { + return false; + } + for (left, right) in self.get_columns().iter().zip(other.get_columns()) { + if left.name() != right.name() || !left.equals(right) { + return false; + } + } + true + } + + /// Check if all values in [`DataFrame`]s are equal where `None == None` evaluates to `true`. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// let df1: DataFrame = df!("Atomic number" => &[1, 51, 300], + /// "Element" => &[Some("Hydrogen"), Some("Antimony"), None])?; + /// let df2: DataFrame = df!("Atomic number" => &[1, 51, 300], + /// "Element" => &[Some("Hydrogen"), Some("Antimony"), None])?; + /// + /// assert!(df1.equals_missing(&df2)); + /// # Ok::<(), PolarsError>(()) + /// ``` + pub fn equals_missing(&self, other: &DataFrame) -> bool { + if self.shape() != other.shape() { + return false; + } + for (left, right) in self.get_columns().iter().zip(other.get_columns()) { + if left.name() != right.name() || !left.equals_missing(right) { + return false; + } + } + true + } +} + +impl PartialEq for DataFrame { + fn eq(&self, other: &Self) -> bool { + self.shape() == other.shape() + && self + .columns + .iter() + .zip(other.columns.iter()) + .all(|(s1, s2)| s1.equals_missing(s2)) + } +} + +/// Asserts that two expressions of type [`DataFrame`] are equal according to [`DataFrame::equals`] +/// at runtime. +/// +/// If the expression are not equal, the program will panic with a message that displays +/// both dataframes. +#[macro_export] +macro_rules! assert_df_eq { + ($a:expr, $b:expr $(,)?) => { + let a: &$crate::frame::DataFrame = &$a; + let b: &$crate::frame::DataFrame = &$b; + assert!(a.equals(b), "expected {:?}\nto equal {:?}", a, b); + }; +} + +#[cfg(test)] +mod test { + use crate::prelude::*; + + #[test] + fn test_series_equals() { + let a = Series::new("a".into(), &[1_u32, 2, 3]); + let b = Series::new("a".into(), &[1_u32, 2, 3]); + assert!(a.equals(&b)); + + let s = Series::new("foo".into(), &[None, Some(1i64)]); + assert!(s.equals_missing(&s)); + } + + #[test] + fn test_series_dtype_not_equal() { + let s_i32 = Series::new("a".into(), &[1_i32, 2_i32]); + let s_i64 = Series::new("a".into(), &[1_i64, 2_i64]); + assert!(s_i32.dtype() != s_i64.dtype()); + assert!(s_i32.equals(&s_i64)); + } + + #[test] + fn test_df_equal() { + let a = Column::new("a".into(), [1, 2, 3].as_ref()); + let b = Column::new("b".into(), [1, 2, 3].as_ref()); + + let df1 = DataFrame::new(vec![a, b]).unwrap(); + assert!(df1.equals(&df1)) + } + + #[test] + fn assert_df_eq_passes() { + let df = df!("a" => [1], "b" => [2]).unwrap(); + assert_df_eq!(df, df); + drop(df); // Ensure `assert_df_eq!` does not consume its arguments. + } + + #[test] + #[should_panic(expected = "to equal")] + fn assert_df_eq_panics() { + assert_df_eq!(df!("a" => [1]).unwrap(), df!("a" => [2]).unwrap(),); + } + + #[test] + fn test_df_partialeq() { + let df1 = df!("a" => &[1, 2, 3], + "b" => &[4, 5, 6]) + .unwrap(); + let df2 = df!("b" => &[4, 5, 6], + "a" => &[1, 2, 3]) + .unwrap(); + let df3 = df!("" => &[Some(1), None]).unwrap(); + let df4 = df!("" => &[f32::NAN]).unwrap(); + + assert_eq!(df1, df1); + assert_ne!(df1, df2); + assert_eq!(df2, df2); + assert_ne!(df2, df3); + assert_eq!(df3, df3); + assert_eq!(df4, df4); + } +} diff --git a/crates/polars-core/src/tests.rs b/crates/polars-core/src/tests.rs new file mode 100644 index 000000000000..b1c042e80a4b --- /dev/null +++ b/crates/polars-core/src/tests.rs @@ -0,0 +1,17 @@ +use crate::prelude::*; + +#[test] +fn test_initial_empty_sort() -> PolarsResult<()> { + // https://github.com/pola-rs/polars/issues/1396 + let data = vec![1.3; 42]; + let mut series = Column::new("data".into(), Vec::::new()); + let series2 = Column::new("data2".into(), data.clone()); + let series3 = Column::new("data3".into(), data); + let df = DataFrame::new(vec![series2, series3])?; + + for column in df.get_columns().iter() { + series.append(column)?; + } + series.f64()?.sort(false); + Ok(()) +} diff --git a/crates/polars-core/src/utils/any_value.rs b/crates/polars-core/src/utils/any_value.rs new file mode 100644 index 000000000000..5a9688cc1394 --- /dev/null +++ b/crates/polars-core/src/utils/any_value.rs @@ -0,0 +1,37 @@ +use crate::prelude::*; +use crate::utils::dtypes_to_supertype; + +/// Determine the supertype of a collection of [`AnyValue`]. +/// +/// [`AnyValue`]: crate::datatypes::AnyValue +pub fn any_values_to_supertype<'a, I>(values: I) -> PolarsResult +where + I: IntoIterator>, +{ + let dtypes = any_values_to_dtype_set(values); + dtypes_to_supertype(&dtypes) +} + +/// Determine the supertype and the number of unique data types of a collection of [`AnyValue`]. +/// +/// [`AnyValue`]: crate::datatypes::AnyValue +pub fn any_values_to_supertype_and_n_dtypes<'a, I>(values: I) -> PolarsResult<(DataType, usize)> +where + I: IntoIterator>, +{ + let dtypes = any_values_to_dtype_set(values); + let supertype = dtypes_to_supertype(&dtypes)?; + let n_dtypes = dtypes.len(); + Ok((supertype, n_dtypes)) +} + +/// Extract the ordered set of data types from a collection of AnyValues +/// +/// Retaining the order is important if the set is used to determine a supertype, +/// as this can influence how Struct fields are constructed. +fn any_values_to_dtype_set<'a, I>(values: I) -> PlIndexSet +where + I: IntoIterator>, +{ + values.into_iter().map(|av| av.into()).collect() +} diff --git a/crates/polars-core/src/utils/flatten.rs b/crates/polars-core/src/utils/flatten.rs new file mode 100644 index 000000000000..1c0e36703a7d --- /dev/null +++ b/crates/polars-core/src/utils/flatten.rs @@ -0,0 +1,120 @@ +use arrow::bitmap::MutableBitmap; +use polars_utils::sync::SyncPtr; + +use super::*; + +pub fn flatten_df_iter(df: &DataFrame) -> impl Iterator + '_ { + df.iter_chunks_physical().flat_map(|chunk| { + let columns = df + .iter() + .zip(chunk.into_arrays()) + .map(|(s, arr)| { + // SAFETY: + // datatypes are correct + let mut out = unsafe { + Series::from_chunks_and_dtype_unchecked(s.name().clone(), vec![arr], s.dtype()) + }; + out.set_sorted_flag(s.is_sorted_flag()); + Column::from(out) + }) + .collect::>(); + + let height = DataFrame::infer_height(&columns); + let df = unsafe { DataFrame::new_no_checks(height, columns) }; + if df.is_empty() { None } else { Some(df) } + }) +} + +pub fn flatten_series(s: &Series) -> Vec { + let name = s.name(); + let dtype = s.dtype(); + unsafe { + s.chunks() + .iter() + .map(|arr| { + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![arr.clone()], dtype) + }) + .collect() + } +} + +pub fn cap_and_offsets(v: &[Vec]) -> (usize, Vec) { + let cap = v.iter().map(|v| v.len()).sum::(); + let offsets = v + .iter() + .scan(0_usize, |acc, v| { + let out = *acc; + *acc += v.len(); + Some(out) + }) + .collect::>(); + (cap, offsets) +} + +pub fn flatten_par>(bufs: &[S]) -> Vec { + let mut len = 0; + let mut offsets = Vec::with_capacity(bufs.len()); + let bufs = bufs + .iter() + .map(|s| { + offsets.push(len); + let slice = s.as_ref(); + len += slice.len(); + slice + }) + .collect::>(); + flatten_par_impl(&bufs, len, offsets) +} + +fn flatten_par_impl( + bufs: &[&[T]], + len: usize, + offsets: Vec, +) -> Vec { + let mut out = Vec::with_capacity(len); + let out_ptr = unsafe { SyncPtr::new(out.as_mut_ptr()) }; + + POOL.install(|| { + offsets.into_par_iter().enumerate().for_each(|(i, offset)| { + let buf = bufs[i]; + let ptr: *mut T = out_ptr.get(); + unsafe { + let dst = ptr.add(offset); + let src = buf.as_ptr(); + std::ptr::copy_nonoverlapping(src, dst, buf.len()) + } + }) + }); + unsafe { + out.set_len(len); + } + out +} + +pub fn flatten_nullable + Send + Sync>( + bufs: &[S], +) -> PrimitiveArray { + let a = || flatten_par(bufs); + let b = || { + let cap = bufs.iter().map(|s| s.as_ref().len()).sum::(); + let mut validity = MutableBitmap::with_capacity(cap); + validity.extend_constant(cap, true); + + let mut count = 0usize; + for s in bufs { + let s = s.as_ref(); + + for id in s { + if id.is_null_idx() { + unsafe { validity.set_unchecked(count, false) }; + } + + count += 1; + } + } + validity.freeze() + }; + + let (a, b) = POOL.join(a, b); + PrimitiveArray::from_vec(bytemuck::cast_vec::<_, IdxSize>(a)).with_validity(Some(b)) +} diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs new file mode 100644 index 000000000000..a84957855f69 --- /dev/null +++ b/crates/polars-core/src/utils/mod.rs @@ -0,0 +1,1288 @@ +mod any_value; +use arrow::compute::concatenate::concatenate_validities; +use arrow::compute::utils::combine_validities_and; +pub mod flatten; +pub(crate) mod series; +mod supertype; +use std::borrow::Cow; +use std::ops::{Deref, DerefMut}; +mod schema; + +pub use any_value::*; +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +pub use arrow::legacy::utils::*; +pub use arrow::trusted_len::TrustMyLength; +use flatten::*; +use num_traits::{One, Zero}; +use rayon::prelude::*; +pub use schema::*; +pub use series::*; +pub use supertype::*; +pub use {arrow, rayon}; + +use crate::POOL; +use crate::prelude::*; + +#[repr(transparent)] +pub struct Wrap(pub T); + +impl Deref for Wrap { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[inline(always)] +pub fn _set_partition_size() -> usize { + POOL.current_num_threads() +} + +/// Just a wrapper structure which is useful for certain impl specializations. +/// +/// This is for instance use to implement +/// `impl FromIterator for NoNull>` +/// as `Option` was already implemented: +/// `impl FromIterator> for ChunkedArray` +pub struct NoNull { + inner: T, +} + +impl NoNull { + pub fn new(inner: T) -> Self { + NoNull { inner } + } + + pub fn into_inner(self) -> T { + self.inner + } +} + +impl Deref for NoNull { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for NoNull { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +pub(crate) fn get_iter_capacity>(iter: &I) -> usize { + match iter.size_hint() { + (_lower, Some(upper)) => upper, + (0, None) => 1024, + (lower, None) => lower, + } +} + +// prefer this one over split_ca, as this can push the null_count into the thread pool +// returns an `(offset, length)` tuple +#[doc(hidden)] +pub fn _split_offsets(len: usize, n: usize) -> Vec<(usize, usize)> { + if n == 1 { + vec![(0, len)] + } else { + let chunk_size = len / n; + + (0..n) + .map(|partition| { + let offset = partition * chunk_size; + let len = if partition == (n - 1) { + len - offset + } else { + chunk_size + }; + (partition * chunk_size, len) + }) + .collect_trusted() + } +} + +#[allow(clippy::len_without_is_empty)] +pub trait Container: Clone { + fn slice(&self, offset: i64, len: usize) -> Self; + + fn split_at(&self, offset: i64) -> (Self, Self); + + fn len(&self) -> usize; + + fn iter_chunks(&self) -> impl Iterator; + + fn n_chunks(&self) -> usize; + + fn chunk_lengths(&self) -> impl Iterator; +} + +impl Container for DataFrame { + fn slice(&self, offset: i64, len: usize) -> Self { + DataFrame::slice(self, offset, len) + } + + fn split_at(&self, offset: i64) -> (Self, Self) { + DataFrame::split_at(self, offset) + } + + fn len(&self) -> usize { + self.height() + } + + fn iter_chunks(&self) -> impl Iterator { + flatten_df_iter(self) + } + + fn n_chunks(&self) -> usize { + DataFrame::first_col_n_chunks(self) + } + + fn chunk_lengths(&self) -> impl Iterator { + // @scalar-correctness? + self.columns[0].as_materialized_series().chunk_lengths() + } +} + +impl Container for ChunkedArray { + fn slice(&self, offset: i64, len: usize) -> Self { + ChunkedArray::slice(self, offset, len) + } + + fn split_at(&self, offset: i64) -> (Self, Self) { + ChunkedArray::split_at(self, offset) + } + + fn len(&self) -> usize { + ChunkedArray::len(self) + } + + fn iter_chunks(&self) -> impl Iterator { + self.downcast_iter() + .map(|arr| Self::with_chunk(self.name().clone(), arr.clone())) + } + + fn n_chunks(&self) -> usize { + self.chunks().len() + } + + fn chunk_lengths(&self) -> impl Iterator { + ChunkedArray::chunk_lengths(self) + } +} + +impl Container for Series { + fn slice(&self, offset: i64, len: usize) -> Self { + self.0.slice(offset, len) + } + + fn split_at(&self, offset: i64) -> (Self, Self) { + self.0.split_at(offset) + } + + fn len(&self) -> usize { + self.0.len() + } + + fn iter_chunks(&self) -> impl Iterator { + (0..self.0.n_chunks()).map(|i| self.select_chunk(i)) + } + + fn n_chunks(&self) -> usize { + self.chunks().len() + } + + fn chunk_lengths(&self) -> impl Iterator { + self.0.chunk_lengths() + } +} + +fn split_impl(container: &C, target: usize, chunk_size: usize) -> Vec { + if target == 1 { + return vec![container.clone()]; + } + let mut out = Vec::with_capacity(target); + let chunk_size = chunk_size as i64; + + // First split + let (chunk, mut remainder) = container.split_at(chunk_size); + out.push(chunk); + + // Take the rest of the splits of exactly chunk size, but skip the last remainder as we won't split that. + for _ in 1..target - 1 { + let (a, b) = remainder.split_at(chunk_size); + out.push(a); + remainder = b + } + // This can be slightly larger than `chunk_size`, but is smaller than `2 * chunk_size`. + out.push(remainder); + out +} + +/// Splits, but doesn't flatten chunks. E.g. a container can still have multiple chunks. +pub fn split(container: &C, target: usize) -> Vec { + let total_len = container.len(); + if total_len == 0 { + return vec![container.clone()]; + } + + let chunk_size = std::cmp::max(total_len / target, 1); + + if container.n_chunks() == target + && container + .chunk_lengths() + .all(|len| len.abs_diff(chunk_size) < 100) + { + return container.iter_chunks().collect(); + } + split_impl(container, target, chunk_size) +} + +/// Split a [`Container`] in `target` elements. The target doesn't have to be respected if not +/// Deviation of the target might be done to create more equal size chunks. +pub fn split_and_flatten(container: &C, target: usize) -> Vec { + let total_len = container.len(); + if total_len == 0 { + return vec![container.clone()]; + } + + let chunk_size = std::cmp::max(total_len / target, 1); + + if container.n_chunks() == target + && container + .chunk_lengths() + .all(|len| len.abs_diff(chunk_size) < 100) + { + return container.iter_chunks().collect(); + } + + if container.n_chunks() == 1 { + split_impl(container, target, chunk_size) + } else { + let mut out = Vec::with_capacity(target); + let chunks = container.iter_chunks(); + + 'new_chunk: for mut chunk in chunks { + loop { + let h = chunk.len(); + if h < chunk_size { + // TODO if the chunk is much smaller than chunk size, we should try to merge it with the next one. + out.push(chunk); + continue 'new_chunk; + } + + // If a split leads to the next chunk being smaller than 30% take the whole chunk + if ((h - chunk_size) as f64 / chunk_size as f64) < 0.3 { + out.push(chunk); + continue 'new_chunk; + } + + let (a, b) = chunk.split_at(chunk_size as i64); + out.push(a); + chunk = b; + } + } + out + } +} + +/// Split a [`DataFrame`] in `target` elements. The target doesn't have to be respected if not +/// strict. Deviation of the target might be done to create more equal size chunks. +/// +/// # Panics +/// if chunks are not aligned +pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec { + if strict { + split(df, target) + } else { + split_and_flatten(df, target) + } +} + +#[doc(hidden)] +/// Split a [`DataFrame`] into `n` parts. We take a `&mut` to be able to repartition/align chunks. +/// `strict` in that it respects `n` even if the chunks are suboptimal. +pub fn split_df(df: &mut DataFrame, target: usize, strict: bool) -> Vec { + if target == 0 || df.is_empty() { + return vec![df.clone()]; + } + // make sure that chunks are aligned. + df.align_chunks_par(); + split_df_as_ref(df, target, strict) +} + +pub fn slice_slice(vals: &[T], offset: i64, len: usize) -> &[T] { + let (raw_offset, slice_len) = slice_offsets(offset, len, vals.len()); + &vals[raw_offset..raw_offset + slice_len] +} + +#[inline] +#[doc(hidden)] +pub fn slice_offsets(offset: i64, length: usize, array_len: usize) -> (usize, usize) { + let signed_start_offset = if offset < 0 { + offset.saturating_add_unsigned(array_len as u64) + } else { + offset + }; + let signed_stop_offset = signed_start_offset.saturating_add_unsigned(length as u64); + + let signed_array_len: i64 = array_len + .try_into() + .expect("array length larger than i64::MAX"); + let clamped_start_offset = signed_start_offset.clamp(0, signed_array_len); + let clamped_stop_offset = signed_stop_offset.clamp(0, signed_array_len); + + let slice_start_idx = clamped_start_offset as usize; + let slice_len = (clamped_stop_offset - clamped_start_offset) as usize; + (slice_start_idx, slice_len) +} + +/// Apply a macro on the Series +#[macro_export] +macro_rules! match_dtype_to_physical_apply_macro { + ($obj:expr, $macro:ident, $macro_string:ident, $macro_bool:ident $(, $opt_args:expr)*) => {{ + match $obj { + DataType::String => $macro_string!($($opt_args)*), + DataType::Boolean => $macro_bool!($($opt_args)*), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $macro!(u8 $(, $opt_args)*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $macro!(u16 $(, $opt_args)*), + DataType::UInt32 => $macro!(u32 $(, $opt_args)*), + DataType::UInt64 => $macro!(u64 $(, $opt_args)*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $macro!(i8 $(, $opt_args)*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $macro!(i16 $(, $opt_args)*), + DataType::Int32 => $macro!(i32 $(, $opt_args)*), + DataType::Int64 => $macro!(i64 $(, $opt_args)*), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => $macro!(i128 $(, $opt_args)*), + DataType::Float32 => $macro!(f32 $(, $opt_args)*), + DataType::Float64 => $macro!(f64 $(, $opt_args)*), + dt => panic!("not implemented for dtype {:?}", dt), + } + }}; +} + +/// Apply a macro on the Series +#[macro_export] +macro_rules! match_dtype_to_logical_apply_macro { + ($obj:expr, $macro:ident, $macro_string:ident, $macro_binary:ident, $macro_bool:ident $(, $opt_args:expr)*) => {{ + match $obj { + DataType::String => $macro_string!($($opt_args)*), + DataType::Binary => $macro_binary!($($opt_args)*), + DataType::Boolean => $macro_bool!($($opt_args)*), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $macro!(UInt8Type $(, $opt_args)*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $macro!(UInt16Type $(, $opt_args)*), + DataType::UInt32 => $macro!(UInt32Type $(, $opt_args)*), + DataType::UInt64 => $macro!(UInt64Type $(, $opt_args)*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $macro!(Int8Type $(, $opt_args)*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $macro!(Int16Type $(, $opt_args)*), + DataType::Int32 => $macro!(Int32Type $(, $opt_args)*), + DataType::Int64 => $macro!(Int64Type $(, $opt_args)*), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => $macro!(Int128Type $(, $opt_args)*), + DataType::Float32 => $macro!(Float32Type $(, $opt_args)*), + DataType::Float64 => $macro!(Float64Type $(, $opt_args)*), + dt => panic!("not implemented for dtype {:?}", dt), + } + }}; +} + +/// Apply a macro on the Downcasted ChunkedArrays +#[macro_export] +macro_rules! match_arrow_dtype_apply_macro_ca { + ($self:expr, $macro:ident, $macro_string:ident, $macro_bool:ident $(, $opt_args:expr)*) => {{ + match $self.dtype() { + DataType::String => $macro_string!($self.str().unwrap() $(, $opt_args)*), + DataType::Boolean => $macro_bool!($self.bool().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $macro!($self.u8().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $macro!($self.u16().unwrap() $(, $opt_args)*), + DataType::UInt32 => $macro!($self.u32().unwrap() $(, $opt_args)*), + DataType::UInt64 => $macro!($self.u64().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $macro!($self.i8().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $macro!($self.i16().unwrap() $(, $opt_args)*), + DataType::Int32 => $macro!($self.i32().unwrap() $(, $opt_args)*), + DataType::Int64 => $macro!($self.i64().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => $macro!($self.i128().unwrap() $(, $opt_args)*), + DataType::Float32 => $macro!($self.f32().unwrap() $(, $opt_args)*), + DataType::Float64 => $macro!($self.f64().unwrap() $(, $opt_args)*), + dt => panic!("not implemented for dtype {:?}", dt), + } + }}; +} + +#[macro_export] +macro_rules! with_match_physical_numeric_type {( + $dtype:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $dtype { + #[cfg(feature = "dtype-i8")] + Int8 => __with_ty__! { i8 }, + #[cfg(feature = "dtype-i16")] + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + #[cfg(feature = "dtype-i128")] + Int128 => __with_ty__! { i128 }, + #[cfg(feature = "dtype-u8")] + UInt8 => __with_ty__! { u8 }, + #[cfg(feature = "dtype-u16")] + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +#[macro_export] +macro_rules! with_match_physical_integer_type {( + $dtype:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $dtype { + #[cfg(feature = "dtype-i8")] + Int8 => __with_ty__! { i8 }, + #[cfg(feature = "dtype-i16")] + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + #[cfg(feature = "dtype-i128")] + Int128 => __with_ty__! { i128 }, + #[cfg(feature = "dtype-u8")] + UInt8 => __with_ty__! { u8 }, + #[cfg(feature = "dtype-u16")] + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +#[macro_export] +macro_rules! with_match_physical_float_type {( + $dtype:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $dtype { + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +#[macro_export] +macro_rules! with_match_physical_float_polars_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $key_type { + Float32 => __with_ty__! { Float32Type }, + Float64 => __with_ty__! { Float64Type }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +#[macro_export] +macro_rules! with_match_physical_numeric_polars_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $key_type { + #[cfg(feature = "dtype-i8")] + Int8 => __with_ty__! { Int8Type }, + #[cfg(feature = "dtype-i16")] + Int16 => __with_ty__! { Int16Type }, + Int32 => __with_ty__! { Int32Type }, + Int64 => __with_ty__! { Int64Type }, + #[cfg(feature = "dtype-i128")] + Int128 => __with_ty__! { Int128Type }, + #[cfg(feature = "dtype-u8")] + UInt8 => __with_ty__! { UInt8Type }, + #[cfg(feature = "dtype-u16")] + UInt16 => __with_ty__! { UInt16Type }, + UInt32 => __with_ty__! { UInt32Type }, + UInt64 => __with_ty__! { UInt64Type }, + Float32 => __with_ty__! { Float32Type }, + Float64 => __with_ty__! { Float64Type }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +#[macro_export] +macro_rules! with_match_physical_integer_polars_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + use $crate::datatypes::*; + match $key_type { + #[cfg(feature = "dtype-i8")] + Int8 => __with_ty__! { Int8Type }, + #[cfg(feature = "dtype-i16")] + Int16 => __with_ty__! { Int16Type }, + Int32 => __with_ty__! { Int32Type }, + Int64 => __with_ty__! { Int64Type }, + #[cfg(feature = "dtype-i128")] + Int128 => __with_ty__! { Int128Type }, + #[cfg(feature = "dtype-u8")] + UInt8 => __with_ty__! { UInt8Type }, + #[cfg(feature = "dtype-u16")] + UInt16 => __with_ty__! { UInt16Type }, + UInt32 => __with_ty__! { UInt32Type }, + UInt64 => __with_ty__! { UInt64Type }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +/// Apply a macro on the Downcasted ChunkedArrays of DataTypes that are logical numerics. +/// So no logical. +#[macro_export] +macro_rules! downcast_as_macro_arg_physical { + ($self:expr, $macro:ident $(, $opt_args:expr)*) => {{ + match $self.dtype() { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $macro!($self.u8().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $macro!($self.u16().unwrap() $(, $opt_args)*), + DataType::UInt32 => $macro!($self.u32().unwrap() $(, $opt_args)*), + DataType::UInt64 => $macro!($self.u64().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $macro!($self.i8().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $macro!($self.i16().unwrap() $(, $opt_args)*), + DataType::Int32 => $macro!($self.i32().unwrap() $(, $opt_args)*), + DataType::Int64 => $macro!($self.i64().unwrap() $(, $opt_args)*), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => $macro!($self.i128().unwrap() $(, $opt_args)*), + DataType::Float32 => $macro!($self.f32().unwrap() $(, $opt_args)*), + DataType::Float64 => $macro!($self.f64().unwrap() $(, $opt_args)*), + dt => panic!("not implemented for {:?}", dt), + } + }}; +} + +/// Apply a macro on the Downcasted ChunkedArrays of DataTypes that are logical numerics. +/// So no logical. +#[macro_export] +macro_rules! downcast_as_macro_arg_physical_mut { + ($self:expr, $macro:ident $(, $opt_args:expr)*) => {{ + // clone so that we do not borrow + match $self.dtype().clone() { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => { + let ca: &mut UInt8Chunked = $self.as_mut(); + $macro!(UInt8Type, ca $(, $opt_args)*) + }, + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => { + let ca: &mut UInt16Chunked = $self.as_mut(); + $macro!(UInt16Type, ca $(, $opt_args)*) + }, + DataType::UInt32 => { + let ca: &mut UInt32Chunked = $self.as_mut(); + $macro!(UInt32Type, ca $(, $opt_args)*) + }, + DataType::UInt64 => { + let ca: &mut UInt64Chunked = $self.as_mut(); + $macro!(UInt64Type, ca $(, $opt_args)*) + }, + #[cfg(feature = "dtype-i8")] + DataType::Int8 => { + let ca: &mut Int8Chunked = $self.as_mut(); + $macro!(Int8Type, ca $(, $opt_args)*) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + let ca: &mut Int16Chunked = $self.as_mut(); + $macro!(Int16Type, ca $(, $opt_args)*) + }, + DataType::Int32 => { + let ca: &mut Int32Chunked = $self.as_mut(); + $macro!(Int32Type, ca $(, $opt_args)*) + }, + DataType::Int64 => { + let ca: &mut Int64Chunked = $self.as_mut(); + $macro!(Int64Type, ca $(, $opt_args)*) + }, + #[cfg(feature = "dtype-i128")] + DataType::Int128 => { + let ca: &mut Int128Chunked = $self.as_mut(); + $macro!(Int128Type, ca $(, $opt_args)*) + }, + DataType::Float32 => { + let ca: &mut Float32Chunked = $self.as_mut(); + $macro!(Float32Type, ca $(, $opt_args)*) + }, + DataType::Float64 => { + let ca: &mut Float64Chunked = $self.as_mut(); + $macro!(Float64Type, ca $(, $opt_args)*) + }, + dt => panic!("not implemented for {:?}", dt), + } + }}; +} + +#[macro_export] +macro_rules! apply_method_all_arrow_series { + ($self:expr, $method:ident, $($args:expr),*) => { + match $self.dtype() { + DataType::Boolean => $self.bool().unwrap().$method($($args),*), + DataType::String => $self.str().unwrap().$method($($args),*), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $self.u8().unwrap().$method($($args),*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $self.u16().unwrap().$method($($args),*), + DataType::UInt32 => $self.u32().unwrap().$method($($args),*), + DataType::UInt64 => $self.u64().unwrap().$method($($args),*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $self.i8().unwrap().$method($($args),*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $self.i16().unwrap().$method($($args),*), + DataType::Int32 => $self.i32().unwrap().$method($($args),*), + DataType::Int64 => $self.i64().unwrap().$method($($args),*), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => $self.i128().unwrap().$method($($args),*), + DataType::Float32 => $self.f32().unwrap().$method($($args),*), + DataType::Float64 => $self.f64().unwrap().$method($($args),*), + DataType::Time => $self.time().unwrap().$method($($args),*), + DataType::Date => $self.date().unwrap().$method($($args),*), + DataType::Datetime(_, _) => $self.datetime().unwrap().$method($($args),*), + DataType::List(_) => $self.list().unwrap().$method($($args),*), + DataType::Struct(_) => $self.struct_().unwrap().$method($($args),*), + dt => panic!("dtype {:?} not supported", dt) + } + } +} + +#[macro_export] +macro_rules! apply_method_physical_integer { + ($self:expr, $method:ident, $($args:expr),*) => { + match $self.dtype() { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $self.u8().unwrap().$method($($args),*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $self.u16().unwrap().$method($($args),*), + DataType::UInt32 => $self.u32().unwrap().$method($($args),*), + DataType::UInt64 => $self.u64().unwrap().$method($($args),*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $self.i8().unwrap().$method($($args),*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $self.i16().unwrap().$method($($args),*), + DataType::Int32 => $self.i32().unwrap().$method($($args),*), + DataType::Int64 => $self.i64().unwrap().$method($($args),*), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => $self.i128().unwrap().$method($($args),*), + dt => panic!("not implemented for dtype {:?}", dt), + } + } +} + +// doesn't include Bool and String +#[macro_export] +macro_rules! apply_method_physical_numeric { + ($self:expr, $method:ident, $($args:expr),*) => { + match $self.dtype() { + DataType::Float32 => $self.f32().unwrap().$method($($args),*), + DataType::Float64 => $self.f64().unwrap().$method($($args),*), + _ => apply_method_physical_integer!($self, $method, $($args),*), + } + } +} + +#[macro_export] +macro_rules! df { + ($($col_name:expr => $slice:expr), + $(,)?) => { + $crate::prelude::DataFrame::new(vec![ + $($crate::prelude::Column::from(<$crate::prelude::Series as $crate::prelude::NamedFrom::<_, _>>::new($col_name.into(), $slice)),)+ + ]) + } +} + +pub fn get_time_units(tu_l: &TimeUnit, tu_r: &TimeUnit) -> TimeUnit { + use TimeUnit::*; + match (tu_l, tu_r) { + (Nanoseconds, Microseconds) => Microseconds, + (_, Milliseconds) => Milliseconds, + _ => *tu_l, + } +} + +#[cold] +#[inline(never)] +fn width_mismatch(df1: &DataFrame, df2: &DataFrame) -> PolarsError { + let mut df1_extra = Vec::new(); + let mut df2_extra = Vec::new(); + + let s1 = df1.schema(); + let s2 = df2.schema(); + + s1.field_compare(s2, &mut df1_extra, &mut df2_extra); + + let df1_extra = df1_extra + .into_iter() + .map(|(_, (n, _))| n.as_str()) + .collect::>() + .join(", "); + let df2_extra = df2_extra + .into_iter() + .map(|(_, (n, _))| n.as_str()) + .collect::>() + .join(", "); + + polars_err!( + SchemaMismatch: r#"unable to vstack, dataframes have different widths ({} != {}). +One dataframe has additional columns: [{df1_extra}]. +Other dataframe has additional columns: [{df2_extra}]."#, + df1.width(), + df2.width(), + ) +} + +pub fn accumulate_dataframes_vertical_unchecked_optional(dfs: I) -> Option +where + I: IntoIterator, +{ + let mut iter = dfs.into_iter(); + let additional = iter.size_hint().0; + let mut acc_df = iter.next()?; + acc_df.reserve_chunks(additional); + + for df in iter { + if acc_df.width() != df.width() { + panic!("{}", width_mismatch(&acc_df, &df)); + } + + acc_df.vstack_mut_owned_unchecked(df); + } + Some(acc_df) +} + +/// This takes ownership of the DataFrame so that drop is called earlier. +/// Does not check if schema is correct +pub fn accumulate_dataframes_vertical_unchecked(dfs: I) -> DataFrame +where + I: IntoIterator, +{ + let mut iter = dfs.into_iter(); + let additional = iter.size_hint().0; + let mut acc_df = iter.next().unwrap(); + acc_df.reserve_chunks(additional); + + for df in iter { + if acc_df.width() != df.width() { + panic!("{}", width_mismatch(&acc_df, &df)); + } + + acc_df.vstack_mut_owned_unchecked(df); + } + acc_df +} + +/// This takes ownership of the DataFrame so that drop is called earlier. +/// # Panics +/// Panics if `dfs` is empty. +pub fn accumulate_dataframes_vertical(dfs: I) -> PolarsResult +where + I: IntoIterator, +{ + let mut iter = dfs.into_iter(); + let additional = iter.size_hint().0; + let mut acc_df = iter.next().unwrap(); + acc_df.reserve_chunks(additional); + for df in iter { + if acc_df.width() != df.width() { + return Err(width_mismatch(&acc_df, &df)); + } + + acc_df.vstack_mut_owned(df)?; + } + + Ok(acc_df) +} + +/// Concat the DataFrames to a single DataFrame. +pub fn concat_df<'a, I>(dfs: I) -> PolarsResult +where + I: IntoIterator, +{ + let mut iter = dfs.into_iter(); + let additional = iter.size_hint().0; + let mut acc_df = iter.next().unwrap().clone(); + acc_df.reserve_chunks(additional); + for df in iter { + acc_df.vstack_mut(df)?; + } + Ok(acc_df) +} + +/// Concat the DataFrames to a single DataFrame. +pub fn concat_df_unchecked<'a, I>(dfs: I) -> DataFrame +where + I: IntoIterator, +{ + let mut iter = dfs.into_iter(); + let additional = iter.size_hint().0; + let mut acc_df = iter.next().unwrap().clone(); + acc_df.reserve_chunks(additional); + for df in iter { + acc_df.vstack_mut_unchecked(df); + } + acc_df +} + +pub fn accumulate_dataframes_horizontal(dfs: Vec) -> PolarsResult { + let mut iter = dfs.into_iter(); + let mut acc_df = iter.next().unwrap(); + for df in iter { + acc_df.hstack_mut(df.get_columns())?; + } + Ok(acc_df) +} + +/// Ensure the chunks in both ChunkedArrays have the same length. +/// # Panics +/// This will panic if `left.len() != right.len()` and array is chunked. +pub fn align_chunks_binary<'a, T, B>( + left: &'a ChunkedArray, + right: &'a ChunkedArray, +) -> (Cow<'a, ChunkedArray>, Cow<'a, ChunkedArray>) +where + B: PolarsDataType, + T: PolarsDataType, +{ + let assert = || { + assert_eq!( + left.len(), + right.len(), + "expected arrays of the same length" + ) + }; + match (left.chunks.len(), right.chunks.len()) { + // All chunks are equal length + (1, 1) => (Cow::Borrowed(left), Cow::Borrowed(right)), + // All chunks are equal length + (a, b) + if a == b + && left + .chunk_lengths() + .zip(right.chunk_lengths()) + .all(|(l, r)| l == r) => + { + (Cow::Borrowed(left), Cow::Borrowed(right)) + }, + (_, 1) => { + assert(); + ( + Cow::Borrowed(left), + Cow::Owned(right.match_chunks(left.chunk_lengths())), + ) + }, + (1, _) => { + assert(); + ( + Cow::Owned(left.match_chunks(right.chunk_lengths())), + Cow::Borrowed(right), + ) + }, + (_, _) => { + assert(); + // could optimize to choose to rechunk a primitive and not a string or list type + let left = left.rechunk(); + ( + Cow::Owned(left.match_chunks(right.chunk_lengths())), + Cow::Borrowed(right), + ) + }, + } +} + +#[cfg(feature = "performant")] +pub(crate) fn align_chunks_binary_owned_series(left: Series, right: Series) -> (Series, Series) { + match (left.chunks().len(), right.chunks().len()) { + (1, 1) => (left, right), + // All chunks are equal length + (a, b) + if a == b + && left + .chunk_lengths() + .zip(right.chunk_lengths()) + .all(|(l, r)| l == r) => + { + (left, right) + }, + (_, 1) => (left.rechunk(), right), + (1, _) => (left, right.rechunk()), + (_, _) => (left.rechunk(), right.rechunk()), + } +} + +pub(crate) fn align_chunks_binary_owned( + left: ChunkedArray, + right: ChunkedArray, +) -> (ChunkedArray, ChunkedArray) +where + B: PolarsDataType, + T: PolarsDataType, +{ + match (left.chunks.len(), right.chunks.len()) { + (1, 1) => (left, right), + // All chunks are equal length + (a, b) + if a == b + && left + .chunk_lengths() + .zip(right.chunk_lengths()) + .all(|(l, r)| l == r) => + { + (left, right) + }, + (_, 1) => (left.rechunk().into_owned(), right), + (1, _) => (left, right.rechunk().into_owned()), + (_, _) => (left.rechunk().into_owned(), right.rechunk().into_owned()), + } +} + +/// # Panics +/// This will panic if `a.len() != b.len() || b.len() != c.len()` and array is chunked. +#[allow(clippy::type_complexity)] +pub fn align_chunks_ternary<'a, A, B, C>( + a: &'a ChunkedArray, + b: &'a ChunkedArray, + c: &'a ChunkedArray, +) -> ( + Cow<'a, ChunkedArray>, + Cow<'a, ChunkedArray>, + Cow<'a, ChunkedArray>, +) +where + A: PolarsDataType, + B: PolarsDataType, + C: PolarsDataType, +{ + if a.chunks.len() == 1 && b.chunks.len() == 1 && c.chunks.len() == 1 { + return (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c)); + } + + assert!( + a.len() == b.len() && b.len() == c.len(), + "expected arrays of the same length" + ); + + match (a.chunks.len(), b.chunks.len(), c.chunks.len()) { + (_, 1, 1) => ( + Cow::Borrowed(a), + Cow::Owned(b.match_chunks(a.chunk_lengths())), + Cow::Owned(c.match_chunks(a.chunk_lengths())), + ), + (1, 1, _) => ( + Cow::Owned(a.match_chunks(c.chunk_lengths())), + Cow::Owned(b.match_chunks(c.chunk_lengths())), + Cow::Borrowed(c), + ), + (1, _, 1) => ( + Cow::Owned(a.match_chunks(b.chunk_lengths())), + Cow::Borrowed(b), + Cow::Owned(c.match_chunks(b.chunk_lengths())), + ), + (1, _, _) => { + let b = b.rechunk(); + ( + Cow::Owned(a.match_chunks(c.chunk_lengths())), + Cow::Owned(b.match_chunks(c.chunk_lengths())), + Cow::Borrowed(c), + ) + }, + (_, 1, _) => { + let a = a.rechunk(); + ( + Cow::Owned(a.match_chunks(c.chunk_lengths())), + Cow::Owned(b.match_chunks(c.chunk_lengths())), + Cow::Borrowed(c), + ) + }, + (_, _, 1) => { + let b = b.rechunk(); + ( + Cow::Borrowed(a), + Cow::Owned(b.match_chunks(a.chunk_lengths())), + Cow::Owned(c.match_chunks(a.chunk_lengths())), + ) + }, + (len_a, len_b, len_c) + if len_a == len_b + && len_b == len_c + && a.chunk_lengths() + .zip(b.chunk_lengths()) + .zip(c.chunk_lengths()) + .all(|((a, b), c)| a == b && b == c) => + { + (Cow::Borrowed(a), Cow::Borrowed(b), Cow::Borrowed(c)) + }, + _ => { + // could optimize to choose to rechunk a primitive and not a string or list type + let a = a.rechunk(); + let b = b.rechunk(); + ( + Cow::Owned(a.match_chunks(c.chunk_lengths())), + Cow::Owned(b.match_chunks(c.chunk_lengths())), + Cow::Borrowed(c), + ) + }, + } +} + +pub fn binary_concatenate_validities<'a, T, B>( + left: &'a ChunkedArray, + right: &'a ChunkedArray, +) -> Option +where + B: PolarsDataType, + T: PolarsDataType, +{ + let (left, right) = align_chunks_binary(left, right); + let left_validity = concatenate_validities(left.chunks()); + let right_validity = concatenate_validities(right.chunks()); + combine_validities_and(left_validity.as_ref(), right_validity.as_ref()) +} + +/// Convenience for `x.into_iter().map(Into::into).collect()` using an `into_vec()` function. +pub trait IntoVec { + fn into_vec(self) -> Vec; +} + +impl IntoVec for I +where + I: IntoIterator, + S: Into, +{ + fn into_vec(self) -> Vec { + self.into_iter().map(|s| s.into()).collect() + } +} + +/// This logic is same as the impl on ChunkedArray +/// The difference is that there is less indirection because the caller should preallocate +/// `chunk_lens` once. On the `ChunkedArray` we indirect through an `ArrayRef` which is an indirection +/// and a vtable. +#[inline] +pub(crate) fn index_to_chunked_index< + I: Iterator, + Idx: PartialOrd + std::ops::AddAssign + std::ops::SubAssign + Zero + One, +>( + chunk_lens: I, + index: Idx, +) -> (Idx, Idx) { + let mut index_remainder = index; + let mut current_chunk_idx = Zero::zero(); + + for chunk_len in chunk_lens { + if chunk_len > index_remainder { + break; + } else { + index_remainder -= chunk_len; + current_chunk_idx += One::one(); + } + } + (current_chunk_idx, index_remainder) +} + +pub(crate) fn index_to_chunked_index_rev< + I: Iterator, + Idx: PartialOrd + + std::ops::AddAssign + + std::ops::SubAssign + + std::ops::Sub + + Zero + + One + + Copy + + std::fmt::Debug, +>( + chunk_lens_rev: I, + index_from_back: Idx, + total_chunks: Idx, +) -> (Idx, Idx) { + debug_assert!(index_from_back > Zero::zero(), "at least -1"); + let mut index_remainder = index_from_back; + let mut current_chunk_idx = One::one(); + let mut current_chunk_len = Zero::zero(); + + for chunk_len in chunk_lens_rev { + current_chunk_len = chunk_len; + if chunk_len >= index_remainder { + break; + } else { + index_remainder -= chunk_len; + current_chunk_idx += One::one(); + } + } + ( + total_chunks - current_chunk_idx, + current_chunk_len - index_remainder, + ) +} + +pub(crate) fn first_non_null<'a, I>(iter: I) -> Option +where + I: Iterator>, +{ + let mut offset = 0; + for validity in iter { + if let Some(validity) = validity { + let mask = BitMask::from_bitmap(validity); + if let Some(n) = mask.nth_set_bit_idx(0, 0) { + return Some(offset + n); + } + offset += validity.len() + } else { + return Some(offset); + } + } + None +} + +pub(crate) fn last_non_null<'a, I>(iter: I, len: usize) -> Option +where + I: DoubleEndedIterator>, +{ + if len == 0 { + return None; + } + let mut offset = 0; + for validity in iter.rev() { + if let Some(validity) = validity { + let mask = BitMask::from_bitmap(validity); + if let Some(n) = mask.nth_set_bit_idx_rev(0, mask.len()) { + let mask_start = len - offset - mask.len(); + return Some(mask_start + n); + } + offset += validity.len() + } else { + return Some(len - 1 - offset); + } + } + None +} + +/// ensure that nulls are propagated to both arrays +pub fn coalesce_nulls<'a, T: PolarsDataType>( + a: &'a ChunkedArray, + b: &'a ChunkedArray, +) -> (Cow<'a, ChunkedArray>, Cow<'a, ChunkedArray>) { + if a.null_count() > 0 || b.null_count() > 0 { + let (a, b) = align_chunks_binary(a, b); + let mut b = b.into_owned(); + let a = a.coalesce_nulls(b.chunks()); + + for arr in a.chunks().iter() { + for arr_b in unsafe { b.chunks_mut() } { + *arr_b = arr_b.with_validity(arr.validity().cloned()) + } + } + b.compute_len(); + (Cow::Owned(a), Cow::Owned(b)) + } else { + (Cow::Borrowed(a), Cow::Borrowed(b)) + } +} + +pub fn coalesce_nulls_columns(a: &Column, b: &Column) -> (Column, Column) { + if a.null_count() > 0 || b.null_count() > 0 { + let mut a = a.as_materialized_series().rechunk(); + let mut b = b.as_materialized_series().rechunk(); + for (arr_a, arr_b) in unsafe { a.chunks_mut().iter_mut().zip(b.chunks_mut()) } { + let validity = match (arr_a.validity(), arr_b.validity()) { + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some(a & b), + (Some(a), None) => Some(a.clone()), + (None, None) => None, + }; + *arr_a = arr_a.with_validity(validity.clone()); + *arr_b = arr_b.with_validity(validity); + } + a.compute_len(); + b.compute_len(); + (a.into(), b.into()) + } else { + (a.clone(), b.clone()) + } +} + +pub fn operation_exceeded_idxsize_msg(operation: &str) -> String { + if size_of::() == size_of::() { + format!( + "{} exceeded the maximum supported limit of {} rows. Consider installing 'polars-u64-idx'.", + operation, + IdxSize::MAX, + ) + } else { + format!( + "{} exceeded the maximum supported limit of {} rows.", + operation, + IdxSize::MAX, + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_split() { + let ca: Int32Chunked = (0..10).collect_ca("a".into()); + + let out = split(&ca, 3); + assert_eq!(out[0].len(), 3); + assert_eq!(out[1].len(), 3); + assert_eq!(out[2].len(), 4); + } + + #[test] + fn test_align_chunks() -> PolarsResult<()> { + let a = Int32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3, 4]); + let mut b = Int32Chunked::new(PlSmallStr::EMPTY, &[1]); + let b2 = Int32Chunked::new(PlSmallStr::EMPTY, &[2, 3, 4]); + + b.append(&b2)?; + let (a, b) = align_chunks_binary(&a, &b); + assert_eq!( + a.chunk_lengths().collect::>(), + b.chunk_lengths().collect::>() + ); + + let a = Int32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3, 4]); + let mut b = Int32Chunked::new(PlSmallStr::EMPTY, &[1]); + let b1 = b.clone(); + b.append(&b1)?; + b.append(&b1)?; + b.append(&b1)?; + let (a, b) = align_chunks_binary(&a, &b); + assert_eq!( + a.chunk_lengths().collect::>(), + b.chunk_lengths().collect::>() + ); + + Ok(()) + } +} diff --git a/crates/polars-core/src/utils/schema.rs b/crates/polars-core/src/utils/schema.rs new file mode 100644 index 000000000000..558a0ea8f1b8 --- /dev/null +++ b/crates/polars-core/src/utils/schema.rs @@ -0,0 +1,19 @@ +use polars_utils::format_pl_smallstr; + +use crate::prelude::*; + +/// Convert a collection of [`DataType`] into a schema. +/// +/// Field names are set as `column_0`, `column_1`, and so on. +/// +/// [`DataType`]: crate::datatypes::DataType +pub fn dtypes_to_schema(dtypes: I) -> Schema +where + I: IntoIterator, +{ + dtypes + .into_iter() + .enumerate() + .map(|(i, dtype)| Field::new(format_pl_smallstr!("column_{i}"), dtype)) + .collect() +} diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs new file mode 100644 index 000000000000..21a5639afdf3 --- /dev/null +++ b/crates/polars-core/src/utils/series.rs @@ -0,0 +1,46 @@ +use std::rc::Rc; + +use crate::prelude::*; +use crate::series::amortized_iter::AmortSeries; + +/// A utility that allocates an [`AmortSeries`]. The applied function can then use that +/// series container to save heap allocations and swap arrow arrays. +pub fn with_unstable_series(dtype: &DataType, f: F) -> T +where + F: Fn(&mut AmortSeries) -> T, +{ + let container = Series::full_null(PlSmallStr::EMPTY, 0, dtype); + let mut us = AmortSeries::new(Rc::new(container)); + + f(&mut us) +} + +pub fn handle_casting_failures(input: &Series, output: &Series) -> PolarsResult<()> { + let failure_mask = !input.is_null() & output.is_null(); + let failures = input.filter(&failure_mask)?; + + let additional_info = match (input.dtype(), output.dtype()) { + (DataType::String, DataType::Date | DataType::Datetime(_, _)) => { + "\n\nYou might want to try:\n\ + - setting `strict=False` to set values that cannot be converted to `null`\n\ + - using `str.strptime`, `str.to_date`, or `str.to_datetime` and providing a format string" + }, + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Enum(_, _)) => { + "\n\nEnsure that all values in the input column are present in the categories of the enum datatype." + }, + _ => "", + }; + + polars_bail!( + InvalidOperation: + "conversion from `{}` to `{}` failed in column '{}' for {} out of {} values: {}{}", + input.dtype(), + output.dtype(), + output.name(), + failures.len(), + input.len(), + failures.fmt_list(), + additional_info, + ) +} diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs new file mode 100644 index 000000000000..fb0f45dffc96 --- /dev/null +++ b/crates/polars-core/src/utils/supertype.rs @@ -0,0 +1,653 @@ +use bitflags::bitflags; +use num_traits::Signed; + +use super::*; + +/// Given two data types, determine the data type that both types can safely be cast to. +/// +/// Returns a [`PolarsError::ComputeError`] if no such data type exists. +pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult { + get_supertype(l, r).ok_or_else( + || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r), + ) +} + +pub fn try_get_supertype_with_options( + l: &DataType, + r: &DataType, + options: SuperTypeOptions, +) -> PolarsResult { + get_supertype_with_options(l, r, options).ok_or_else( + || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r), + ) +} + +/// Returns a numeric supertype that `l` and `r` can be safely upcasted to if it exists. +pub fn get_numeric_upcast_supertype_lossless(l: &DataType, r: &DataType) -> Option { + use DataType::*; + + if l == r || matches!(l, Unknown(_)) || matches!(r, Unknown(_)) { + None + } else if l.is_float() && r.is_float() { + match (l, r) { + (Float64, _) | (_, Float64) => Some(Float64), + v => { + // Did we add a new float type? + if cfg!(debug_assertions) { + panic!("{:?}", v) + } else { + None + } + }, + } + } else if l.is_signed_integer() && r.is_signed_integer() { + match (l, r) { + (Int128, _) | (_, Int128) => Some(Int128), + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + v => { + if cfg!(debug_assertions) { + panic!("{:?}", v) + } else { + None + } + }, + } + } else if l.is_unsigned_integer() && r.is_unsigned_integer() { + match (l, r) { + (UInt64, _) | (_, UInt64) => Some(UInt64), + (UInt32, _) | (_, UInt32) => Some(UInt32), + (UInt16, _) | (_, UInt16) => Some(UInt16), + (UInt8, _) | (_, UInt8) => Some(UInt8), + v => { + if cfg!(debug_assertions) { + panic!("{:?}", v) + } else { + None + } + }, + } + } else if l.is_integer() && r.is_integer() { + // One side is signed, the other is unsigned. We just need to upcast the + // unsigned side to a signed integer with the next-largest bit width. + match (l, r) { + (UInt64, _) | (_, UInt64) | (Int128, _) | (_, Int128) => Some(Int128), + (UInt32, _) | (_, UInt32) | (Int64, _) | (_, Int64) => Some(Int64), + (UInt16, _) | (_, UInt16) | (Int32, _) | (_, Int32) => Some(Int32), + (UInt8, _) | (_, UInt8) | (Int16, _) | (_, Int16) => Some(Int16), + v => { + // One side was UInt and we should have already matched against + // all the UInt types + if cfg!(debug_assertions) { + panic!("{:?}", v) + } else { + None + } + }, + } + } else { + None + } +} + +bitflags! { + #[repr(transparent)] + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] + pub struct SuperTypeFlags: u8 { + /// Implode lists to match nesting types. + const ALLOW_IMPLODE_LIST = 1 << 0; + /// Allow casting of primitive types (numeric, bools) to strings + const ALLOW_PRIMITIVE_TO_STRING = 1 << 1; + } +} + +impl Default for SuperTypeFlags { + fn default() -> Self { + SuperTypeFlags::from_bits_truncate(0) | SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Default)] +pub struct SuperTypeOptions { + pub flags: SuperTypeFlags, +} + +impl From for SuperTypeOptions { + fn from(flags: SuperTypeFlags) -> Self { + SuperTypeOptions { flags } + } +} + +impl SuperTypeOptions { + pub fn allow_implode_list(&self) -> bool { + self.flags.contains(SuperTypeFlags::ALLOW_IMPLODE_LIST) + } + + pub fn allow_primitive_to_string(&self) -> bool { + self.flags + .contains(SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING) + } +} + +pub fn get_supertype(l: &DataType, r: &DataType) -> Option { + get_supertype_with_options(l, r, SuperTypeOptions::default()) +} + +/// Given two data types, determine the data type that both types can safely be cast to. +/// +/// Returns [`None`] if no such data type exists. +pub fn get_supertype_with_options( + l: &DataType, + r: &DataType, + options: SuperTypeOptions, +) -> Option { + fn inner(l: &DataType, r: &DataType, options: SuperTypeOptions) -> Option { + use DataType::*; + if l == r { + return Some(l.clone()); + } + match (l, r) { + #[cfg(feature = "dtype-i8")] + (Int8, Boolean) => Some(Int8), + //(Int8, Int8) => Some(Int8), + #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))] + (Int8, Int16) => Some(Int16), + #[cfg(feature = "dtype-i8")] + (Int8, Int32) => Some(Int32), + #[cfg(feature = "dtype-i8")] + (Int8, Int64) => Some(Int64), + #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))] + (Int8, UInt8) => Some(Int16), + #[cfg(all(feature = "dtype-i8", feature = "dtype-u16"))] + (Int8, UInt16) => Some(Int32), + #[cfg(feature = "dtype-i8")] + (Int8, UInt32) => Some(Int64), + #[cfg(feature = "dtype-i8")] + (Int8, UInt64) => Some(Float64), // Follow numpy + #[cfg(feature = "dtype-i8")] + (Int8, Float32) => Some(Float32), + #[cfg(feature = "dtype-i8")] + (Int8, Float64) => Some(Float64), + + #[cfg(feature = "dtype-i16")] + (Int16, Boolean) => Some(Int16), + #[cfg(all(feature = "dtype-i16", feature = "dtype-i8"))] + (Int16, Int8) => Some(Int16), + //(Int16, Int16) => Some(Int16), + #[cfg(feature = "dtype-i16")] + (Int16, Int32) => Some(Int32), + #[cfg(feature = "dtype-i16")] + (Int16, Int64) => Some(Int64), + #[cfg(all(feature = "dtype-i16", feature = "dtype-u8"))] + (Int16, UInt8) => Some(Int16), + #[cfg(all(feature = "dtype-i16", feature = "dtype-u16"))] + (Int16, UInt16) => Some(Int32), + #[cfg(feature = "dtype-i16")] + (Int16, UInt32) => Some(Int64), + #[cfg(feature = "dtype-i16")] + (Int16, UInt64) => Some(Float64), // Follow numpy + #[cfg(feature = "dtype-i16")] + (Int16, Float32) => Some(Float32), + #[cfg(feature = "dtype-i16")] + (Int16, Float64) => Some(Float64), + + + #[cfg(feature = "dtype-i128")] + (a, Int128) if a.is_integer() | a.is_bool() => Some(Int128), + #[cfg(feature = "dtype-i128")] + (a, Int128) if a.is_float() => Some(Float64), + #[cfg(feature = "dtype-i128")] + + + (Int32, Boolean) => Some(Int32), + #[cfg(feature = "dtype-i8")] + (Int32, Int8) => Some(Int32), + #[cfg(feature = "dtype-i16")] + (Int32, Int16) => Some(Int32), + //(Int32, Int32) => Some(Int32), + (Int32, Int64) => Some(Int64), + #[cfg(feature = "dtype-u8")] + (Int32, UInt8) => Some(Int32), + #[cfg(feature = "dtype-u16")] + (Int32, UInt16) => Some(Int32), + (Int32, UInt32) => Some(Int64), + #[cfg(not(feature = "bigidx"))] + (Int32, UInt64) => Some(Float64), // Follow numpy + #[cfg(feature = "bigidx")] + (Int32, UInt64) => Some(Int64), // Needed for bigidx + (Int32, Float32) => Some(Float64), // Follow numpy + (Int32, Float64) => Some(Float64), + + (Int64, Boolean) => Some(Int64), + #[cfg(feature = "dtype-i8")] + (Int64, Int8) => Some(Int64), + #[cfg(feature = "dtype-i16")] + (Int64, Int16) => Some(Int64), + (Int64, Int32) => Some(Int64), + //(Int64, Int64) => Some(Int64), + #[cfg(feature = "dtype-u8")] + (Int64, UInt8) => Some(Int64), + #[cfg(feature = "dtype-u16")] + (Int64, UInt16) => Some(Int64), + (Int64, UInt32) => Some(Int64), + #[cfg(not(feature = "bigidx"))] + (Int64, UInt64) => Some(Float64), // Follow numpy + #[cfg(feature = "bigidx")] + (Int64, UInt64) => Some(Int64), // Needed for bigidx + (Int64, Float32) => Some(Float64), // Follow numpy + (Int64, Float64) => Some(Float64), + + #[cfg(all(feature = "dtype-u16", feature = "dtype-u8"))] + (UInt16, UInt8) => Some(UInt16), + #[cfg(feature = "dtype-u16")] + (UInt16, UInt32) => Some(UInt32), + #[cfg(feature = "dtype-u16")] + (UInt16, UInt64) => Some(UInt64), + + #[cfg(feature = "dtype-u8")] + (UInt8, UInt32) => Some(UInt32), + #[cfg(feature = "dtype-u8")] + (UInt8, UInt64) => Some(UInt64), + + (UInt32, UInt64) => Some(UInt64), + + #[cfg(feature = "dtype-u8")] + (Boolean, UInt8) => Some(UInt8), + #[cfg(feature = "dtype-u16")] + (Boolean, UInt16) => Some(UInt16), + (Boolean, UInt32) => Some(UInt32), + (Boolean, UInt64) => Some(UInt64), + + #[cfg(feature = "dtype-u8")] + (Float32, UInt8) => Some(Float32), + #[cfg(feature = "dtype-u16")] + (Float32, UInt16) => Some(Float32), + (Float32, UInt32) => Some(Float64), + (Float32, UInt64) => Some(Float64), + + #[cfg(feature = "dtype-u8")] + (Float64, UInt8) => Some(Float64), + #[cfg(feature = "dtype-u16")] + (Float64, UInt16) => Some(Float64), + (Float64, UInt32) => Some(Float64), + (Float64, UInt64) => Some(Float64), + + (Float64, Float32) => Some(Float64), + + // Time related dtypes + #[cfg(feature = "dtype-date")] + (Date, UInt32) => Some(Int64), + #[cfg(feature = "dtype-date")] + (Date, UInt64) => Some(Int64), + #[cfg(feature = "dtype-date")] + (Date, Int32) => Some(Int32), + #[cfg(feature = "dtype-date")] + (Date, Int64) => Some(Int64), + #[cfg(feature = "dtype-date")] + (Date, Float32) => Some(Float32), + #[cfg(feature = "dtype-date")] + (Date, Float64) => Some(Float64), + #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))] + (Date, Datetime(tu, tz)) => Some(Datetime(*tu, tz.clone())), + + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), UInt32) => Some(Int64), + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), UInt64) => Some(Int64), + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), Int32) => Some(Int64), + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), Int64) => Some(Int64), + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), Float32) => Some(Float64), + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), Float64) => Some(Float64), + #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] + (Datetime(tu, tz), Date) => Some(Datetime(*tu, tz.clone())), + + (Boolean, Float32) => Some(Float32), + (Boolean, Float64) => Some(Float64), + + #[cfg(feature = "dtype-duration")] + (Duration(_), UInt32) => Some(Int64), + #[cfg(feature = "dtype-duration")] + (Duration(_), UInt64) => Some(Int64), + #[cfg(feature = "dtype-duration")] + (Duration(_), Int32) => Some(Int64), + #[cfg(feature = "dtype-duration")] + (Duration(_), Int64) => Some(Int64), + #[cfg(feature = "dtype-duration")] + (Duration(_), Float32) => Some(Float64), + #[cfg(feature = "dtype-duration")] + (Duration(_), Float64) => Some(Float64), + + #[cfg(feature = "dtype-time")] + (Time, Int32) => Some(Int64), + #[cfg(feature = "dtype-time")] + (Time, Int64) => Some(Int64), + #[cfg(feature = "dtype-time")] + (Time, Float32) => Some(Float64), + #[cfg(feature = "dtype-time")] + (Time, Float64) => Some(Float64), + + // Every known type can be cast to a string except binary + (dt, String) if !matches!(dt, Unknown(UnknownKind::Any)) && dt != &Binary && options.allow_primitive_to_string() || !dt.to_physical().is_primitive() => Some(String), + (String, Binary) => Some(Binary), + (dt, Null) => Some(dt.clone()), + + #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))] + (Duration(lu), Datetime(ru, Some(tz))) | (Datetime(lu, Some(tz)), Duration(ru)) => { + if tz.is_empty() { + Some(Datetime(get_time_units(lu, ru), None)) + } else { + Some(Datetime(get_time_units(lu, ru), Some(tz.clone()))) + } + } + #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))] + (Duration(lu), Datetime(ru, None)) | (Datetime(lu, None), Duration(ru)) => { + Some(Datetime(get_time_units(lu, ru), None)) + } + #[cfg(all(feature = "dtype-duration", feature = "dtype-date"))] + (Duration(_), Date) | (Date, Duration(_)) => Some(Date), + #[cfg(feature = "dtype-duration")] + (Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))), + + // both None or both Some("") timezones + // we cast from more precision to higher precision as that always fits with occasional loss of precision + #[cfg(feature = "dtype-datetime")] + (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) if + // both are none + (tz_l.is_none() && tz_r.is_none()) + // both have the same time zone + || (tz_l.is_some() && (tz_l == tz_r)) => { + let tu = get_time_units(tu_l, tu_r); + Some(Datetime(tu, tz_r.clone())) + } + (List(inner_left), List(inner_right)) => { + let st = get_supertype(inner_left, inner_right)?; + Some(List(Box::new(st))) + } + #[cfg(feature = "dtype-array")] + (List(inner_left), Array(inner_right, _)) | (Array(inner_left, _), List(inner_right)) => { + let st = get_supertype(inner_left, inner_right)?; + Some(List(Box::new(st))) + } + #[cfg(feature = "dtype-array")] + (Array(inner_left, width_left), Array(inner_right, width_right)) if *width_left == *width_right => { + let st = get_supertype(inner_left, inner_right)?; + Some(Array(Box::new(st), *width_left)) + } + (List(inner), other) | (other, List(inner)) if options.allow_implode_list() => { + let st = get_supertype(inner, other)?; + Some(List(Box::new(st))) + } + #[cfg(feature = "dtype-array")] + (Array(inner_left, _), Array(inner_right, _)) => { + let st = get_supertype(inner_left, inner_right)?; + Some(List(Box::new(st))) + } + #[cfg(feature = "dtype-struct")] + (Struct(inner), right @ Unknown(UnknownKind::Float | UnknownKind::Int(_))) => { + match inner.first() { + Some(inner) => get_supertype(&inner.dtype, right), + None => None + } + }, + (dt, Unknown(kind)) => { + match kind { + UnknownKind::Float | UnknownKind::Int(_) if dt.is_string() => { + if options.allow_primitive_to_string() { + Some(dt.clone()) + } else { + None + } + }, + // numeric vs float|str -> always float|str|decimal + UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() | dt.is_decimal() => Some(dt.clone()), + UnknownKind::Float if dt.is_integer() => Some(Unknown(UnknownKind::Float)), + // Materialize float to float or decimal + UnknownKind::Float if dt.is_float() | dt.is_decimal() => Some(dt.clone()), + // Materialize str + UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()), + // Materialize str + #[cfg(feature = "dtype-categorical")] + UnknownKind::Str if dt.is_categorical() => { + let Categorical(_, ord) = dt else { unreachable!()}; + Some(Categorical(None, *ord)) + }, + // Keep unknown + dynam if dt.is_null() => Some(Unknown(*dynam)), + // Find integers sizes + UnknownKind::Int(v) if dt.is_primitive_numeric() => { + // Both dyn int + if let Unknown(UnknownKind::Int(v_other)) = dt { + // Take the maximum value to ensure we bubble up the required minimal size. + Some(Unknown(UnknownKind::Int(std::cmp::max(*v, *v_other)))) + } + // dyn int vs number + else { + let smallest_fitting_dtype = if dt.is_unsigned_integer() && !v.is_negative() { + materialize_dyn_int_pos(*v).dtype() + } else { + materialize_smallest_dyn_int(*v).dtype() + }; + match dt { + UInt64 if smallest_fitting_dtype.is_signed_integer() => { + // Ensure we don't cast to float when dealing with dynamic literals + Some(Int64) + }, + _ => { + get_supertype(dt, &smallest_fitting_dtype) + } + } + } + } + UnknownKind::Int(_) if dt.is_decimal() => Some(dt.clone()), + _ => Some(Unknown(UnknownKind::Any)) + } + }, + #[cfg(feature = "dtype-struct")] + (Struct(fields_a), Struct(fields_b)) => { + super_type_structs(fields_a, fields_b) + } + #[cfg(feature = "dtype-struct")] + (Struct(fields_a), rhs) if rhs.is_primitive_numeric() => { + let mut new_fields = Vec::with_capacity(fields_a.len()); + for a in fields_a { + let st = get_supertype(&a.dtype, rhs)?; + new_fields.push(Field::new(a.name.clone(), st)) + } + Some(Struct(new_fields)) + } + #[cfg(feature = "dtype-decimal")] + (Decimal(p1, s1), Decimal(p2, s2)) => { + Some(Decimal((*p1).zip(*p2).map(|(p1, p2)| p1.max(p2)), (*s1).max(*s2))) + } + #[cfg(feature = "dtype-decimal")] + (Decimal(_, _), f @ (Float32 | Float64)) => Some(f.clone()), + #[cfg(feature = "dtype-decimal")] + (d @ Decimal(_, _), dt) if dt.is_signed_integer() || dt.is_unsigned_integer() => Some(d.clone()), + _ => None, + } + } + + inner(l, r, options).or_else(|| inner(r, l, options)) +} + +/// Given multiple data types, determine the data type that all types can safely be cast to. +/// +/// Returns [`DataType::Null`] if no data types were passed. +pub fn dtypes_to_supertype<'a, I>(dtypes: I) -> PolarsResult +where + I: IntoIterator, +{ + dtypes + .into_iter() + .try_fold(DataType::Null, |supertype, dtype| { + try_get_supertype(&supertype, dtype) + }) +} + +#[cfg(feature = "dtype-struct")] +fn union_struct_fields(fields_a: &[Field], fields_b: &[Field]) -> Option { + let (longest, shortest) = { + // if equal length we also take the lhs + // so that the lhs determines the order of the fields + if fields_a.len() >= fields_b.len() { + (fields_a, fields_b) + } else { + (fields_b, fields_a) + } + }; + let mut longest_map = + PlIndexMap::from_iter(longest.iter().map(|fld| (&fld.name, fld.dtype.clone()))); + for field in shortest { + let dtype_longest = longest_map + .entry(&field.name) + .or_insert_with(|| field.dtype.clone()); + if &field.dtype != dtype_longest { + let st = get_supertype(&field.dtype, dtype_longest)?; + *dtype_longest = st + } + } + let new_fields = longest_map + .into_iter() + .map(|(name, dtype)| Field::new(name.clone(), dtype)) + .collect::>(); + Some(DataType::Struct(new_fields)) +} + +#[cfg(feature = "dtype-struct")] +fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option { + if fields_a.len() != fields_b.len() { + union_struct_fields(fields_a, fields_b) + } else { + let mut new_fields = Vec::with_capacity(fields_a.len()); + for (a, b) in fields_a.iter().zip(fields_b) { + if a.name != b.name { + return union_struct_fields(fields_a, fields_b); + } + let st = get_supertype(&a.dtype, &b.dtype)?; + new_fields.push(Field::new(a.name.clone(), st)) + } + Some(DataType::Struct(new_fields)) + } +} + +pub fn materialize_dyn_int(v: i128) -> AnyValue<'static> { + // Try to get the "smallest" fitting value. + // TODO! next breaking go to true smallest. + if let Ok(v) = i32::try_from(v) { + return AnyValue::Int32(v); + } + if let Ok(v) = i64::try_from(v) { + return AnyValue::Int64(v); + } + if let Ok(v) = u64::try_from(v) { + return AnyValue::UInt64(v); + } + #[cfg(feature = "dtype-i128")] + { + AnyValue::Int128(v) + } + + #[cfg(not(feature = "dtype-i128"))] + AnyValue::Null +} + +fn materialize_dyn_int_pos(v: i128) -> AnyValue<'static> { + // Try to get the "smallest" fitting value. + // TODO! next breaking go to true smallest. + #[cfg(feature = "dtype-u8")] + if let Ok(v) = u8::try_from(v) { + return AnyValue::UInt8(v); + } + #[cfg(feature = "dtype-u16")] + if let Ok(v) = u16::try_from(v) { + return AnyValue::UInt16(v); + } + match u32::try_from(v).ok() { + Some(v) => AnyValue::UInt32(v), + None => match u64::try_from(v).ok() { + Some(v) => AnyValue::UInt64(v), + None => AnyValue::Null, + }, + } +} + +fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> { + #[cfg(feature = "dtype-i8")] + if let Ok(v) = i8::try_from(v) { + return AnyValue::Int8(v); + } + #[cfg(feature = "dtype-i16")] + if let Ok(v) = i16::try_from(v) { + return AnyValue::Int16(v); + } + match i32::try_from(v).ok() { + Some(v) => AnyValue::Int32(v), + None => match i64::try_from(v).ok() { + Some(v) => AnyValue::Int64(v), + None => match u64::try_from(v).ok() { + Some(v) => AnyValue::UInt64(v), + None => AnyValue::Null, + }, + }, + } +} + +pub fn merge_dtypes_many + Clone, D: AsRef>( + into_iter: I, +) -> PolarsResult { + let mut iter = into_iter.clone().into_iter(); + + let mut st = iter + .next() + .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype")) + .map(|d| d.as_ref().clone())?; + + for d in iter { + st = try_get_supertype(d.as_ref(), &st)?; + } + + match st { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(Some(_), ordering) => { + // This merges the global rev maps with linear complexity. + // If we do a binary reduce, it would be quadratic. + let mut iter = into_iter.into_iter(); + let first_dt = iter.next().unwrap(); + let first_dt = first_dt.as_ref(); + let DataType::Categorical(Some(rm), _) = first_dt else { + unreachable!() + }; + polars_ensure!(matches!(rm.as_ref(), RevMapping::Global(_, _, _)), ComputeError: "global string cache must be set to merge categorical columns"); + + let mut merger = GlobalRevMapMerger::new(rm.clone()); + + for d in iter { + if let DataType::Categorical(Some(rm), _) = d.as_ref() { + merger.merge_map(rm)? + } + } + let rev_map = merger.finish(); + + Ok(DataType::Categorical(Some(rev_map), ordering)) + }, + // This would be quadratic if we do this with the binary `merge_dtypes`. + DataType::List(inner) if inner.contains_categoricals() => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) if inner.contains_categoricals() => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) if fields.iter().any(|f| f.dtype().contains_categoricals()) => { + polars_bail!(ComputeError: "merging nested categoricals not yet supported") + }, + _ => Ok(st), + } +} diff --git a/crates/polars-dylib/Cargo.toml b/crates/polars-dylib/Cargo.toml new file mode 100644 index 000000000000..5cc963f2d701 --- /dev/null +++ b/crates/polars-dylib/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "polars-dylib" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true + +[lib] +crate-type = ["dylib", "rlib"] + +[dependencies] +arrow = { workspace = true, optional = true, features = ["io_flight"] } +polars = { workspace = true, features = ["full"] } +polars-core = { workspace = true, optional = true } +polars-expr = { workspace = true, optional = true } +polars-lazy = { workspace = true, optional = true } +polars-mem-engine = { workspace = true, optional = true } +polars-plan = { workspace = true, optional = true } +polars-python = { workspace = true, optional = true, default-features = true } + +[features] +private = ["polars-plan", "arrow", "polars-core", "polars-lazy", "polars-expr", "polars-mem-engine"] +python = ["polars-plan?/python", "polars-python", "polars-lazy?/python"] diff --git a/crates/polars-dylib/README.md b/crates/polars-dylib/README.md new file mode 100644 index 000000000000..3fd4b30de8f7 --- /dev/null +++ b/crates/polars-dylib/README.md @@ -0,0 +1,16 @@ +# Polars dynamic library + +```toml +# Cargo.toml +[workspace.dependencies.polars] +package = "polars-dylib" +``` + +```toml +# .cargo/config.toml +[build] +rustflags = [ + "-C", + "prefer-dynamic", +] +``` diff --git a/crates/polars-dylib/src/lib.rs b/crates/polars-dylib/src/lib.rs new file mode 100644 index 000000000000..907ce175aec8 --- /dev/null +++ b/crates/polars-dylib/src/lib.rs @@ -0,0 +1,15 @@ +#[cfg(feature = "private")] +pub use arrow as _arrow; +pub use polars::*; +#[cfg(feature = "private")] +pub use polars_core as _core; +#[cfg(feature = "private")] +pub use polars_expr as _expr; +#[cfg(feature = "private")] +pub use polars_lazy as _lazy; +#[cfg(feature = "private")] +pub use polars_mem_engine as _mem_engine; +#[cfg(feature = "private")] +pub use polars_plan as _plan; +#[cfg(feature = "python")] +pub use polars_python as _python; diff --git a/crates/polars-error/Cargo.toml b/crates/polars-error/Cargo.toml new file mode 100644 index 000000000000..18852dad97ee --- /dev/null +++ b/crates/polars-error/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "polars-error" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Error definitions for the Polars DataFrame library" + +[dependencies] +arrow-format = { workspace = true, optional = true } +avro-schema = { workspace = true, optional = true } +object_store = { workspace = true, optional = true } +parking_lot = { workspace = true } +regex = { workspace = true, optional = true } +simdutf8 = { workspace = true } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +signal-hook = "0.3" + +[features] +python = [] diff --git a/crates/polars-error/LICENSE b/crates/polars-error/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-error/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-error/README.md b/crates/polars-error/README.md new file mode 100644 index 000000000000..a8af201817db --- /dev/null +++ b/crates/polars-error/README.md @@ -0,0 +1,7 @@ +# polars-error + +`polars-error` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) +library, defining its error types. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-error/src/constants.rs b/crates/polars-error/src/constants.rs new file mode 100644 index 000000000000..b6367e3abb2e --- /dev/null +++ b/crates/polars-error/src/constants.rs @@ -0,0 +1,17 @@ +//! Constant that help with creating error messages dependent on the host language. +#[cfg(feature = "python")] +pub static TRUE: &str = "True"; +#[cfg(feature = "python")] +pub static FALSE: &str = "False"; + +#[cfg(not(feature = "python"))] +pub static TRUE: &str = "true"; +#[cfg(not(feature = "python"))] +pub static FALSE: &str = "false"; + +#[cfg(not(feature = "python"))] +pub static LENGTH_LIMIT_MSG: &str = + "Polars' maximum length reached. Consider compiling with 'bigidx' feature."; +#[cfg(feature = "python")] +pub static LENGTH_LIMIT_MSG: &str = + "Polars' maximum length reached. Consider installing 'polars-u64-idx'."; diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs new file mode 100644 index 000000000000..3531f19c7287 --- /dev/null +++ b/crates/polars-error/src/lib.rs @@ -0,0 +1,474 @@ +pub mod constants; +mod warning; + +use std::borrow::Cow; +use std::collections::TryReserveError; +use std::convert::Infallible; +use std::error::Error; +use std::fmt::{self, Display, Formatter, Write}; +use std::ops::Deref; +use std::sync::{Arc, LazyLock}; +use std::{env, io}; +pub mod signals; + +pub use warning::*; + +enum ErrorStrategy { + Panic, + WithBacktrace, + Normal, +} + +static ERROR_STRATEGY: LazyLock = LazyLock::new(|| { + if env::var("POLARS_PANIC_ON_ERR").as_deref() == Ok("1") { + ErrorStrategy::Panic + } else if env::var("POLARS_BACKTRACE_IN_ERR").as_deref() == Ok("1") { + ErrorStrategy::WithBacktrace + } else { + ErrorStrategy::Normal + } +}); + +#[derive(Debug, Clone)] +pub struct ErrString(Cow<'static, str>); + +impl ErrString { + pub const fn new_static(s: &'static str) -> Self { + Self(Cow::Borrowed(s)) + } +} + +impl From for ErrString +where + T: Into>, +{ + fn from(msg: T) -> Self { + match &*ERROR_STRATEGY { + ErrorStrategy::Panic => panic!("{}", msg.into()), + ErrorStrategy::WithBacktrace => ErrString(Cow::Owned(format!( + "{}\n\nRust backtrace:\n{}", + msg.into(), + std::backtrace::Backtrace::force_capture() + ))), + ErrorStrategy::Normal => ErrString(msg.into()), + } + } +} + +impl AsRef for ErrString { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Deref for ErrString { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Display for ErrString { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone)] +pub enum PolarsError { + AssertionError(ErrString), + ColumnNotFound(ErrString), + ComputeError(ErrString), + Duplicate(ErrString), + InvalidOperation(ErrString), + IO { + error: Arc, + msg: Option, + }, + NoData(ErrString), + OutOfBounds(ErrString), + SchemaFieldNotFound(ErrString), + SchemaMismatch(ErrString), + ShapeMismatch(ErrString), + SQLInterface(ErrString), + SQLSyntax(ErrString), + StringCacheMismatch(ErrString), + StructFieldNotFound(ErrString), + Context { + error: Box, + msg: ErrString, + }, +} + +impl Error for PolarsError {} + +impl Display for PolarsError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + use PolarsError::*; + match self { + ComputeError(msg) + | InvalidOperation(msg) + | OutOfBounds(msg) + | SchemaMismatch(msg) + | SQLInterface(msg) + | SQLSyntax(msg) => write!(f, "{msg}"), + + AssertionError(msg) => write!(f, "assertion failed: {msg}"), + ColumnNotFound(msg) => write!(f, "not found: {msg}"), + Duplicate(msg) => write!(f, "duplicate: {msg}"), + IO { error, msg } => match msg { + Some(m) => write!(f, "{m}"), + None => write!(f, "{error}"), + }, + NoData(msg) => write!(f, "no data: {msg}"), + SchemaFieldNotFound(msg) => write!(f, "field not found: {msg}"), + ShapeMismatch(msg) => write!(f, "lengths don't match: {msg}"), + StringCacheMismatch(msg) => write!(f, "string caches don't match: {msg}"), + StructFieldNotFound(msg) => write!(f, "field not found: {msg}"), + Context { error, msg } => write!(f, "{error}: {msg}"), + } + } +} + +impl From for PolarsError { + fn from(value: io::Error) -> Self { + PolarsError::IO { + error: Arc::new(value), + msg: None, + } + } +} + +#[cfg(feature = "regex")] +impl From for PolarsError { + fn from(err: regex::Error) -> Self { + PolarsError::ComputeError(format!("regex error: {err}").into()) + } +} + +#[cfg(feature = "object_store")] +impl From for PolarsError { + fn from(err: object_store::Error) -> Self { + std::io::Error::other(format!("object-store error: {err:?}")).into() + } +} + +#[cfg(feature = "avro-schema")] +impl From for PolarsError { + fn from(value: avro_schema::error::Error) -> Self { + polars_err!(ComputeError: "avro-error: {}", value) + } +} + +impl From for PolarsError { + fn from(value: simdutf8::basic::Utf8Error) -> Self { + polars_err!(ComputeError: "invalid utf8: {}", value) + } +} +#[cfg(feature = "arrow-format")] +impl From for PolarsError { + fn from(err: arrow_format::ipc::planus::Error) -> Self { + polars_err!(ComputeError: "parquet error: {err:?}") + } +} + +impl From for PolarsError { + fn from(value: TryReserveError) -> Self { + polars_err!(ComputeError: "OOM: {}", value) + } +} + +impl From for PolarsError { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +pub type PolarsResult = Result; + +impl PolarsError { + pub fn context_trace(self) -> Self { + use PolarsError::*; + match self { + Context { error, msg } => { + // If context is 1 level deep, just return error. + if !matches!(&*error, PolarsError::Context { .. }) { + return *error; + } + let mut current_error = &*error; + let material_error = error.get_err(); + + let mut messages = vec![&msg]; + + while let PolarsError::Context { msg, error } = current_error { + current_error = error; + messages.push(msg) + } + + let mut bt = String::new(); + + let mut count = 0; + while let Some(msg) = messages.pop() { + count += 1; + writeln!(&mut bt, "\t[{count}] {}", msg).unwrap(); + } + material_error.wrap_msg(move |msg| { + format!("{msg}\n\nThis error occurred with the following context stack:\n{bt}") + }) + }, + err => err, + } + } + + pub fn wrap_msg String>(&self, func: F) -> Self { + use PolarsError::*; + match self { + AssertionError(msg) => AssertionError(func(msg).into()), + ColumnNotFound(msg) => ColumnNotFound(func(msg).into()), + ComputeError(msg) => ComputeError(func(msg).into()), + Duplicate(msg) => Duplicate(func(msg).into()), + InvalidOperation(msg) => InvalidOperation(func(msg).into()), + IO { error, msg } => { + let msg = match msg { + Some(msg) => func(msg), + None => func(&format!("{}", error)), + }; + IO { + error: error.clone(), + msg: Some(msg.into()), + } + }, + NoData(msg) => NoData(func(msg).into()), + OutOfBounds(msg) => OutOfBounds(func(msg).into()), + SchemaFieldNotFound(msg) => SchemaFieldNotFound(func(msg).into()), + SchemaMismatch(msg) => SchemaMismatch(func(msg).into()), + ShapeMismatch(msg) => ShapeMismatch(func(msg).into()), + StringCacheMismatch(msg) => StringCacheMismatch(func(msg).into()), + StructFieldNotFound(msg) => StructFieldNotFound(func(msg).into()), + SQLInterface(msg) => SQLInterface(func(msg).into()), + SQLSyntax(msg) => SQLSyntax(func(msg).into()), + Context { error, .. } => error.wrap_msg(func), + } + } + + fn get_err(&self) -> &Self { + use PolarsError::*; + match self { + Context { error, .. } => error.get_err(), + err => err, + } + } + + pub fn context(self, msg: ErrString) -> Self { + PolarsError::Context { + msg, + error: Box::new(self), + } + } +} + +pub fn map_err(error: E) -> PolarsError { + PolarsError::ComputeError(format!("{error}").into()) +} + +#[macro_export] +macro_rules! polars_err { + ($variant:ident: $fmt:literal $(, $arg:expr)* $(,)?) => { + $crate::__private::must_use( + $crate::PolarsError::$variant(format!($fmt, $($arg),*).into()) + ) + }; + ($variant:ident: $err:expr $(,)?) => { + $crate::__private::must_use( + $crate::PolarsError::$variant($err.into()) + ) + }; + (expr = $expr:expr, $variant:ident: $err:expr $(,)?) => { + $crate::__private::must_use( + $crate::PolarsError::$variant( + format!("{}\n\nError originated in expression: '{:?}'", $err, $expr).into() + ) + ) + }; + (expr = $expr:expr, $variant:ident: $fmt:literal, $($arg:tt)+) => { + polars_err!(expr = $expr, $variant: format!($fmt, $($arg)+)) + }; + (op = $op:expr, got = $arg:expr, expected = $expected:expr) => { + $crate::polars_err!( + InvalidOperation: "{} operation not supported for dtype `{}` (expected: {})", + $op, $arg, $expected + ) + }; + (opq = $op:ident, got = $arg:expr, expected = $expected:expr) => { + $crate::polars_err!( + op = concat!("`", stringify!($op), "`"), got = $arg, expected = $expected + ) + }; + (un_impl = $op:ident) => { + $crate::polars_err!( + InvalidOperation: "{} operation is not implemented.", concat!("`", stringify!($op), "`") + ) + }; + (op = $op:expr, $arg:expr) => { + $crate::polars_err!( + InvalidOperation: "{} operation not supported for dtype `{}`", $op, $arg + ) + }; + (op = $op:expr, $arg:expr, hint = $hint:literal) => { + $crate::polars_err!( + InvalidOperation: "{} operation not supported for dtype `{}`\n\nHint: {}", $op, $arg, $hint + ) + }; + (op = $op:expr, $lhs:expr, $rhs:expr) => { + $crate::polars_err!( + InvalidOperation: "{} operation not supported for dtypes `{}` and `{}`", $op, $lhs, $rhs + ) + }; + (oos = $($tt:tt)+) => { + $crate::polars_err!(ComputeError: "out-of-spec: {}", $($tt)+) + }; + (nyi = $($tt:tt)+) => { + $crate::polars_err!(ComputeError: "not yet implemented: {}", format!($($tt)+) ) + }; + (opq = $op:ident, $arg:expr) => { + $crate::polars_err!(op = concat!("`", stringify!($op), "`"), $arg) + }; + (opq = $op:ident, $lhs:expr, $rhs:expr) => { + $crate::polars_err!(op = stringify!($op), $lhs, $rhs) + }; + (bigidx, ctx = $ctx:expr, size = $size:expr) => { + $crate::polars_err!(ComputeError: "\ +{} produces {} rows which is more than maximum allowed pow(2, 32) rows; \ +consider compiling with bigidx feature (polars-u64-idx package on python)", + $ctx, + $size, + ) + }; + (append) => { + polars_err!(SchemaMismatch: "cannot append series, data types don't match") + }; + (extend) => { + polars_err!(SchemaMismatch: "cannot extend series, data types don't match") + }; + (unpack) => { + polars_err!(SchemaMismatch: "cannot unpack series, data types don't match") + }; + (not_in_enum,value=$value:expr,categories=$categories:expr) =>{ + polars_err!(ComputeError: "value '{}' is not present in Enum: {:?}",$value,$categories) + }; + (string_cache_mismatch) => { + polars_err!(StringCacheMismatch: r#" +cannot compare categoricals coming from different sources, consider setting a global StringCache. + +Help: if you're using Python, this may look something like: + + with pl.StringCache(): + df1 = pl.DataFrame({'a': ['1', '2']}, schema={'a': pl.Categorical}) + df2 = pl.DataFrame({'a': ['1', '3']}, schema={'a': pl.Categorical}) + pl.concat([df1, df2]) + +Alternatively, if the performance cost is acceptable, you could just set: + + import polars as pl + pl.enable_string_cache() + +on startup."#.trim_start()) + }; + (duplicate = $name:expr) => { + $crate::polars_err!(Duplicate: "column with name '{}' has more than one occurrence", $name) + }; + (col_not_found = $name:expr) => { + $crate::polars_err!(ColumnNotFound: "{:?} not found", $name) + }; + (mismatch, col=$name:expr, expected=$expected:expr, found=$found:expr) => { + $crate::polars_err!( + SchemaMismatch: "data type mismatch for column {}: expected: {}, found: {}", + $name, + $expected, + $found, + ) + }; + (oob = $idx:expr, $len:expr) => { + polars_err!(OutOfBounds: "index {} is out of bounds for sequence of length {}", $idx, $len) + }; + (agg_len = $agg_len:expr, $groups_len:expr) => { + polars_err!( + ComputeError: + "returned aggregation is of different length: {} than the groups length: {}", + $agg_len, $groups_len + ) + }; + (parse_fmt_idk = $dtype:expr) => { + polars_err!( + ComputeError: "could not find an appropriate format to parse {}s, please define a format", + $dtype, + ) + }; + (length_mismatch = $operation:literal, $lhs:expr, $rhs:expr) => { + $crate::polars_err!( + ShapeMismatch: "arguments for `{}` have different lengths ({} != {})", + $operation, $lhs, $rhs + ) + }; + (length_mismatch = $operation:literal, $lhs:expr, $rhs:expr, argument = $argument:expr, argument_idx = $argument_idx:expr) => { + $crate::polars_err!( + ShapeMismatch: "argument {} called '{}' for `{}` have different lengths ({} != {})", + $argument_idx, $argument, $operation, $lhs, $rhs + ) + }; + (assertion_error = $objects:expr, $detail:expr, $lhs:expr, $rhs:expr) => { + $crate::polars_err!( + AssertionError: "{} are different ({})\n[left]: {}\n[right]: {}", + $objects, $detail, $lhs, $rhs + ) + }; +} + +#[macro_export] +macro_rules! polars_bail { + ($($tt:tt)+) => { + return Err($crate::polars_err!($($tt)+)) + }; +} + +#[macro_export] +macro_rules! polars_ensure { + ($cond:expr, $($tt:tt)+) => { + if !$cond { + $crate::polars_bail!($($tt)+); + } + }; +} + +#[inline] +#[cold] +#[must_use] +pub fn to_compute_err(err: impl Display) -> PolarsError { + PolarsError::ComputeError(err.to_string().into()) +} +#[macro_export] +macro_rules! feature_gated { + ($($feature:literal);*, $content:expr) => {{ + #[cfg(all($(feature = $feature),*))] + { + $content + } + #[cfg(not(all($(feature = $feature),*)))] + { + panic!("activate '{}' feature", concat!($($feature, ", "),*)) + } + }}; +} + +// Not public, referenced by macros only. +#[doc(hidden)] +pub mod __private { + #[doc(hidden)] + #[inline] + #[cold] + #[must_use] + pub fn must_use(error: crate::PolarsError) -> crate::PolarsError { + error + } +} diff --git a/crates/polars-error/src/signals.rs b/crates/polars-error/src/signals.rs new file mode 100644 index 000000000000..5a171f4cf585 --- /dev/null +++ b/crates/polars-error/src/signals.rs @@ -0,0 +1,115 @@ +use std::any::Any; +use std::panic::{UnwindSafe, catch_unwind}; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Python hooks SIGINT to instead generate a KeyboardInterrupt exception. +/// So we do the same to try and abort long-running computations and return to +/// Python so that the Python exception can be generated. +pub struct KeyboardInterrupt; + +// We use a unique string so we can detect it in backtraces. +static POLARS_KEYBOARD_INTERRUPT_STRING: &str = "__POLARS_KEYBOARD_INTERRUPT"; + +// Bottom bit: interrupt flag. +// Top 63 bits: number of alive interrupt catchers. +static INTERRUPT_STATE: AtomicU64 = AtomicU64::new(0); + +fn is_keyboard_interrupt(p: &dyn Any) -> bool { + if let Some(s) = p.downcast_ref::<&str>() { + s.contains(POLARS_KEYBOARD_INTERRUPT_STRING) + } else if let Some(s) = p.downcast_ref::() { + s.contains(POLARS_KEYBOARD_INTERRUPT_STRING) + } else { + false + } +} + +pub fn register_polars_keyboard_interrupt_hook() { + let default_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |p| { + // Suppress output if there is an active catcher and the panic message + // contains the keyboard interrupt string. + let num_catchers = INTERRUPT_STATE.load(Ordering::Relaxed) >> 1; + let suppress = num_catchers > 0 && is_keyboard_interrupt(p.payload()); + if !suppress { + default_hook(p); + } + })); + + // WASM doesn't support signals, so we just skip installing the hook there. + #[cfg(not(target_family = "wasm"))] + unsafe { + // SAFETY: we only do an atomic op in the signal handler, which is allowed. + signal_hook::low_level::register(signal_hook::consts::signal::SIGINT, move || { + // Set the interrupt flag, but only if there are active catchers. + INTERRUPT_STATE + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| { + let num_catchers = state >> 1; + if num_catchers > 0 { + Some(state | 1) + } else { + None + } + }) + .ok(); + }) + .unwrap(); + } +} + +/// Checks if the keyboard interrupt flag is set, and if yes panics as a +/// keyboard interrupt. This function is very cheap. +#[inline(always)] +pub fn try_raise_keyboard_interrupt() { + if INTERRUPT_STATE.load(Ordering::Relaxed) & 1 != 0 { + try_raise_keyboard_interrupt_slow() + } +} + +#[inline(never)] +#[cold] +fn try_raise_keyboard_interrupt_slow() { + std::panic::panic_any(POLARS_KEYBOARD_INTERRUPT_STRING); +} + +/// Runs the passed function, catching any KeyboardInterrupts if they occur +/// while running the function. +pub fn catch_keyboard_interrupt R + UnwindSafe>( + try_fn: F, +) -> Result { + // Try to register this catcher (or immediately return if there is an + // uncaught interrupt). + try_register_catcher()?; + let ret = catch_unwind(try_fn); + unregister_catcher(); + ret.map_err(|p| { + if is_keyboard_interrupt(&*p) { + KeyboardInterrupt + } else { + std::panic::resume_unwind(p) + } + }) +} + +fn try_register_catcher() -> Result<(), KeyboardInterrupt> { + let old_state = INTERRUPT_STATE.fetch_add(2, Ordering::Relaxed); + if old_state & 1 != 0 { + unregister_catcher(); + return Err(KeyboardInterrupt); + } + Ok(()) +} + +fn unregister_catcher() { + INTERRUPT_STATE + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| { + let num_catchers = state >> 1; + if num_catchers > 1 { + Some(state - 2) + } else { + // Last catcher, clear interrupt flag. + Some(0) + } + }) + .ok(); +} diff --git a/crates/polars-error/src/warning.rs b/crates/polars-error/src/warning.rs new file mode 100644 index 000000000000..da5d6c56afab --- /dev/null +++ b/crates/polars-error/src/warning.rs @@ -0,0 +1,46 @@ +use parking_lot::RwLock; + +type WarningFunction = fn(&str, PolarsWarning); +static WARNING_FUNCTION: RwLock = RwLock::new(eprintln); + +fn eprintln(fmt: &str, warning: PolarsWarning) { + eprintln!("{:?}: {}", warning, fmt); +} + +/// Set the function that will be called by the `polars_warn!` macro. +/// You can use this to set logging in polars. +pub fn set_warning_function(function: WarningFunction) { + *WARNING_FUNCTION.write() = function; +} + +pub fn get_warning_function() -> WarningFunction { + *WARNING_FUNCTION.read() +} + +#[derive(Debug)] +pub enum PolarsWarning { + Deprecation, + UserWarning, + CategoricalRemappingWarning, + MapWithoutReturnDtypeWarning, +} + +#[macro_export] +macro_rules! polars_warn { + ($variant:ident, $fmt:literal $(, $arg:tt)*) => { + {{ + let func = $crate::get_warning_function(); + let warn = $crate::PolarsWarning::$variant; + func(format!($fmt, $($arg)*).as_ref(), warn) + }} + }; + ($fmt:literal, $($arg:tt)+) => { + {{ + let func = $crate::get_warning_function(); + func(format!($fmt, $($arg)+).as_ref(), $crate::PolarsWarning::UserWarning) + }} + }; + ($($arg:tt)+) => { + polars_warn!("{}", $($arg)+); + }; +} diff --git a/crates/polars-expr/Cargo.toml b/crates/polars-expr/Cargo.toml new file mode 100644 index 000000000000..608771389461 --- /dev/null +++ b/crates/polars-expr/Cargo.toml @@ -0,0 +1,80 @@ +[package] +name = "polars-expr" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +description = "Physical expression implementation of the Polars project." + +[dependencies] +arrow = { workspace = true } +bitflags = { workspace = true } +hashbrown = { workspace = true } +num-traits = { workspace = true } +polars-compute = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } +polars-io = { workspace = true, features = ["lazy"] } +polars-json = { workspace = true, optional = true } +polars-ops = { workspace = true, features = ["chunked_ids"] } +polars-plan = { workspace = true } +polars-row = { workspace = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } +rand = { workspace = true } +rayon = { workspace = true } +recursive = { workspace = true } + +[features] +nightly = ["polars-core/nightly", "polars-plan/nightly"] +streaming = ["polars-plan/streaming", "polars-ops/chunked_ids"] +parquet = ["polars-io/parquet", "polars-plan/parquet"] +temporal = [ + "dtype-datetime", + "dtype-date", + "dtype-time", + "dtype-i8", + "dtype-i16", + "dtype-duration", + "polars-plan/temporal", +] + +dtype-full = [ + "dtype-array", + "dtype-categorical", + "dtype-date", + "dtype-datetime", + "dtype-decimal", + "dtype-duration", + "dtype-i16", + "dtype-i128", + "dtype-i8", + "dtype-struct", + "dtype-time", + "dtype-u16", + "dtype-u8", +] +dtype-array = ["polars-plan/dtype-array", "polars-ops/dtype-array"] +dtype-categorical = ["polars-plan/dtype-categorical"] +dtype-date = ["polars-plan/dtype-date", "polars-time/dtype-date", "temporal"] +dtype-datetime = ["polars-plan/dtype-datetime", "polars-time/dtype-datetime", "temporal"] +dtype-decimal = ["polars-plan/dtype-decimal", "dtype-i128"] +dtype-duration = ["polars-plan/dtype-duration", "polars-time/dtype-duration", "temporal"] +dtype-i16 = ["polars-plan/dtype-i16"] +dtype-i8 = ["polars-plan/dtype-i8"] +dtype-i128 = ["polars-plan/dtype-i128"] +dtype-struct = ["polars-plan/dtype-struct", "polars-ops/dtype-struct"] +dtype-time = ["polars-plan/dtype-time", "polars-time/dtype-time", "temporal"] +dtype-u16 = ["polars-plan/dtype-u16"] +dtype-u8 = ["polars-plan/dtype-u8"] + +# operations +approx_unique = ["polars-plan/approx_unique"] +is_in = ["polars-plan/is_in", "polars-ops/is_in"] + +bitwise = ["polars-core/bitwise", "polars-plan/bitwise"] +round_series = ["polars-plan/round_series", "polars-ops/round_series"] +is_between = ["polars-plan/is_between"] +dynamic_group_by = ["polars-plan/dynamic_group_by", "polars-time", "temporal"] +propagate_nans = ["polars-plan/propagate_nans", "polars-ops/propagate_nans"] diff --git a/crates/polars-expr/LICENSE b/crates/polars-expr/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-expr/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-expr/README.md b/crates/polars-expr/README.md new file mode 100644 index 000000000000..73c6bc1c7c7a --- /dev/null +++ b/crates/polars-expr/README.md @@ -0,0 +1,8 @@ +# polars-expr + +Physical expression implementations. + +`polars-expr` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs new file mode 100644 index 000000000000..7b578cd77eb3 --- /dev/null +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -0,0 +1,807 @@ +use std::borrow::Cow; + +use arrow::array::*; +use arrow::compute::concatenate::concatenate; +use arrow::legacy::utils::CustomIterTools; +use arrow::offset::Offsets; +use polars_compute::rolling::QuantileMethod; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::{_split_offsets, NoNull}; +#[cfg(feature = "propagate_nans")] +use polars_ops::prelude::nan_propagating_aggregate; +use rayon::prelude::*; + +use super::*; +use crate::expressions::AggState::AggregatedScalar; +use crate::expressions::{ + AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups, +}; + +#[derive(Debug, Clone, Copy)] +pub struct AggregationType { + pub(crate) groupby: GroupByMethod, + pub(crate) allow_threading: bool, +} + +pub(crate) struct AggregationExpr { + pub(crate) input: Arc, + pub(crate) agg_type: AggregationType, + field: Option, +} + +impl AggregationExpr { + pub fn new( + expr: Arc, + agg_type: AggregationType, + field: Option, + ) -> Self { + Self { + input: expr, + agg_type, + field, + } + } +} + +impl PhysicalExpr for AggregationExpr { + fn as_expression(&self) -> Option<&Expr> { + None + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let s = self.input.evaluate(df, state)?; + + let AggregationType { + groupby, + allow_threading, + } = self.agg_type; + + let is_float = s.dtype().is_float(); + let group_by = match groupby { + GroupByMethod::NanMin if !is_float => GroupByMethod::Min, + GroupByMethod::NanMax if !is_float => GroupByMethod::Max, + gb => gb, + }; + + match group_by { + GroupByMethod::Min => match s.is_sorted_flag() { + IsSorted::Ascending | IsSorted::Descending => { + s.min_reduce().map(|sc| sc.into_column(s.name().clone())) + }, + IsSorted::Not => parallel_op_columns( + |s| s.min_reduce().map(|sc| sc.into_column(s.name().clone())), + s, + allow_threading, + ), + }, + #[cfg(feature = "propagate_nans")] + GroupByMethod::NanMin => parallel_op_columns( + |s| { + Ok(polars_ops::prelude::nan_propagating_aggregate::nan_min_s( + s.as_materialized_series(), + s.name().clone(), + ) + .into_column()) + }, + s, + allow_threading, + ), + #[cfg(not(feature = "propagate_nans"))] + GroupByMethod::NanMin => { + panic!("activate 'propagate_nans' feature") + }, + GroupByMethod::Max => match s.is_sorted_flag() { + IsSorted::Ascending | IsSorted::Descending => { + s.max_reduce().map(|sc| sc.into_column(s.name().clone())) + }, + IsSorted::Not => parallel_op_columns( + |s| s.max_reduce().map(|sc| sc.into_column(s.name().clone())), + s, + allow_threading, + ), + }, + #[cfg(feature = "propagate_nans")] + GroupByMethod::NanMax => parallel_op_columns( + |s| { + Ok(polars_ops::prelude::nan_propagating_aggregate::nan_max_s( + s.as_materialized_series(), + s.name().clone(), + ) + .into_column()) + }, + s, + allow_threading, + ), + #[cfg(not(feature = "propagate_nans"))] + GroupByMethod::NanMax => { + panic!("activate 'propagate_nans' feature") + }, + GroupByMethod::Median => s.median_reduce().map(|sc| sc.into_column(s.name().clone())), + GroupByMethod::Mean => Ok(s.mean_reduce().into_column(s.name().clone())), + GroupByMethod::First => Ok(if s.is_empty() { + Column::full_null(s.name().clone(), 1, s.dtype()) + } else { + s.head(Some(1)) + }), + GroupByMethod::Last => Ok(if s.is_empty() { + Column::full_null(s.name().clone(), 1, s.dtype()) + } else { + s.tail(Some(1)) + }), + GroupByMethod::Sum => parallel_op_columns( + |s| s.sum_reduce().map(|sc| sc.into_column(s.name().clone())), + s, + allow_threading, + ), + GroupByMethod::Groups => unreachable!(), + GroupByMethod::NUnique => s.n_unique().map(|count| { + IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_column() + }), + GroupByMethod::Count { include_nulls } => { + let count = s.len() - s.null_count() * !include_nulls as usize; + + Ok(IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_column()) + }, + GroupByMethod::Implode => s.implode().map(|ca| ca.into_column()), + GroupByMethod::Std(ddof) => s + .std_reduce(ddof) + .map(|sc| sc.into_column(s.name().clone())), + GroupByMethod::Var(ddof) => s + .var_reduce(ddof) + .map(|sc| sc.into_column(s.name().clone())), + GroupByMethod::Quantile(_, _) => unimplemented!(), + } + } + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac = self.input.evaluate_on_groups(df, groups, state)?; + // don't change names by aggregations as is done in polars-core + let keep_name = ac.get_values().name().clone(); + + // Literals cannot be aggregated except for implode. + polars_ensure!((!matches!(ac.agg_state(), AggState::Literal(_)) || matches!(self.agg_type.groupby, GroupByMethod::Implode)), ComputeError: "cannot aggregate a literal"); + + if let AggregatedScalar(_) = ac.agg_state() { + match self.agg_type.groupby { + GroupByMethod::Implode => {}, + _ => { + polars_bail!(ComputeError: "cannot aggregate as {}, the column is already aggregated", self.agg_type.groupby); + }, + } + } + + // SAFETY: + // groups must always be in bounds. + let out = unsafe { + match self.agg_type.groupby { + GroupByMethod::Min => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_min(&groups); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Max => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_max(&groups); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Median => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_median(&groups); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Mean => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_mean(&groups); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Sum => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_sum(&groups); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Count { include_nulls } => { + if include_nulls || ac.get_values().null_count() == 0 { + // a few fast paths that prevent materializing new groups + match ac.update_groups { + UpdateGroups::WithSeriesLen => { + let list = ac + .get_values() + .list() + .expect("impl error, should be a list at this point"); + + let mut s = match list.chunks().len() { + 1 => { + let arr = list.downcast_iter().next().unwrap(); + let offsets = arr.offsets().as_slice(); + + let mut previous = 0i64; + let counts: NoNull = offsets[1..] + .iter() + .map(|&o| { + let len = (o - previous) as IdxSize; + previous = o; + len + }) + .collect_trusted(); + counts.into_inner() + }, + _ => { + let counts: NoNull = list + .amortized_iter() + .map(|s| { + if let Some(s) = s { + s.as_ref().len() as IdxSize + } else { + 1 + } + }) + .collect_trusted(); + counts.into_inner() + }, + }; + s.rename(keep_name); + AggregatedScalar(s.into_column()) + }, + UpdateGroups::WithGroupsLen => { + // no need to update the groups + // we can just get the attribute, because we only need the length, + // not the correct order + let mut ca = ac.groups.group_count(); + ca.rename(keep_name); + AggregatedScalar(ca.into_column()) + }, + // materialize groups + _ => { + let mut ca = ac.groups().group_count(); + ca.rename(keep_name); + AggregatedScalar(ca.into_column()) + }, + } + } else { + // TODO: optimize this/and write somewhere else. + match ac.agg_state() { + AggState::Literal(s) | AggState::AggregatedScalar(s) => { + AggregatedScalar(Column::new( + keep_name, + [(s.len() as IdxSize - s.null_count() as IdxSize)], + )) + }, + AggState::AggregatedList(s) => { + let ca = s.list()?; + let out: IdxCa = ca + .into_iter() + .map(|opt_s| { + opt_s + .map(|s| s.len() as IdxSize - s.null_count() as IdxSize) + }) + .collect(); + AggregatedScalar(out.into_column().with_name(keep_name)) + }, + AggState::NotAggregated(s) => { + let s = s.clone(); + let groups = ac.groups(); + let out: IdxCa = if matches!(s.dtype(), &DataType::Null) { + IdxCa::full(s.name().clone(), 0, groups.len()) + } else { + match groups.as_ref().as_ref() { + GroupsType::Idx(idx) => { + let s = s.rechunk(); + // @scalar-opt + // @partition-opt + let array = &s.as_materialized_series().chunks()[0]; + let validity = array.validity().unwrap(); + idx.iter() + .map(|(_, g)| { + let mut count = 0 as IdxSize; + // Count valid values + g.iter().for_each(|i| { + count += validity + .get_bit_unchecked(*i as usize) + as IdxSize; + }); + count + }) + .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE) + }, + GroupsType::Slice { groups, .. } => { + // Slice and use computed null count + groups + .iter() + .map(|g| { + let start = g[0]; + let len = g[1]; + len - s + .slice(start as i64, len as usize) + .null_count() + as IdxSize + }) + .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE) + }, + } + }; + AggregatedScalar(out.into_column()) + }, + } + } + }, + GroupByMethod::First => { + let (s, groups) = ac.get_final_aggregation(); + let agg_s = s.agg_first(&groups); + AggregatedScalar(agg_s.with_name(keep_name)) + }, + GroupByMethod::Last => { + let (s, groups) = ac.get_final_aggregation(); + let agg_s = s.agg_last(&groups); + AggregatedScalar(agg_s.with_name(keep_name)) + }, + GroupByMethod::NUnique => { + let (s, groups) = ac.get_final_aggregation(); + let agg_s = s.agg_n_unique(&groups); + AggregatedScalar(agg_s.with_name(keep_name)) + }, + GroupByMethod::Implode => { + // if the aggregation is already + // in an aggregate flat state for instance by + // a mean aggregation, we simply convert to list + // + // if it is not, we traverse the groups and create + // a list per group. + let c = match ac.agg_state() { + // mean agg: + // -> f64 -> list + AggregatedScalar(c) => c + .reshape_list(&[ + ReshapeDimension::Infer, + ReshapeDimension::new_dimension(1), + ]) + .unwrap(), + // Auto-imploded + AggState::NotAggregated(_) | AggState::AggregatedList(_) => { + ac._implode_no_agg(); + return Ok(ac); + }, + _ => { + let agg = ac.aggregated(); + agg.as_list().into_column() + }, + }; + AggState::AggregatedList(c.with_name(keep_name)) + }, + GroupByMethod::Groups => { + let mut column: ListChunked = ac.groups().as_list_chunked(); + column.rename(keep_name); + AggregatedScalar(column.into_column()) + }, + GroupByMethod::Std(ddof) => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_std(&groups, ddof); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Var(ddof) => { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = c.agg_var(&groups, ddof); + AggregatedScalar(agg_c.with_name(keep_name)) + }, + GroupByMethod::Quantile(_, _) => { + // implemented explicitly in AggQuantile struct + unimplemented!() + }, + GroupByMethod::NanMin => { + #[cfg(feature = "propagate_nans")] + { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = if c.dtype().is_float() { + nan_propagating_aggregate::group_agg_nan_min_s( + c.as_materialized_series(), + &groups, + ) + .into_column() + } else { + c.agg_min(&groups) + }; + AggregatedScalar(agg_c.with_name(keep_name)) + } + #[cfg(not(feature = "propagate_nans"))] + { + panic!("activate 'propagate_nans' feature") + } + }, + GroupByMethod::NanMax => { + #[cfg(feature = "propagate_nans")] + { + let (c, groups) = ac.get_final_aggregation(); + let agg_c = if c.dtype().is_float() { + nan_propagating_aggregate::group_agg_nan_max_s( + c.as_materialized_series(), + &groups, + ) + .into_column() + } else { + c.agg_max(&groups) + }; + AggregatedScalar(agg_c.with_name(keep_name)) + } + #[cfg(not(feature = "propagate_nans"))] + { + panic!("activate 'propagate_nans' feature") + } + }, + } + }; + + Ok(AggregationContext::from_agg_state( + out, + Cow::Borrowed(groups), + )) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + if let Some(field) = self.field.as_ref() { + Ok(field.clone()) + } else { + self.input.to_field(input_schema) + } + } + + fn is_scalar(&self) -> bool { + true + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } +} + +impl PartitionedAggregation for AggregationExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let expr = self.input.as_partitioned_aggregator().unwrap(); + let column = expr.evaluate_partitioned(df, groups, state)?; + + // SAFETY: + // groups are in bounds + unsafe { + match self.agg_type.groupby { + #[cfg(feature = "dtype-struct")] + GroupByMethod::Mean => { + let new_name = column.name().clone(); + + // ensure we don't overflow + // the all 8 and 16 bits integers are already upcasted to int16 on `agg_sum` + let mut agg_s = if matches!(column.dtype(), DataType::Int32 | DataType::UInt32) + { + column.cast(&DataType::Int64).unwrap().agg_sum(groups) + } else { + column.agg_sum(groups) + }; + agg_s.rename(new_name.clone()); + + if !agg_s.dtype().is_primitive_numeric() { + Ok(agg_s) + } else { + let agg_s = match agg_s.dtype() { + DataType::Float32 => agg_s, + _ => agg_s.cast(&DataType::Float64).unwrap(), + }; + let mut count_s = column.agg_valid_count(groups); + count_s.rename(PlSmallStr::from_static("__POLARS_COUNT")); + Ok( + StructChunked::from_columns(new_name, agg_s.len(), &[agg_s, count_s]) + .unwrap() + .into_column(), + ) + } + }, + GroupByMethod::Implode => { + let new_name = column.name().clone(); + let mut agg = column.agg_list(groups); + agg.rename(new_name); + Ok(agg) + }, + GroupByMethod::First => { + let mut agg = column.agg_first(groups); + agg.rename(column.name().clone()); + Ok(agg) + }, + GroupByMethod::Last => { + let mut agg = column.agg_last(groups); + agg.rename(column.name().clone()); + Ok(agg) + }, + GroupByMethod::Max => { + let mut agg = column.agg_max(groups); + agg.rename(column.name().clone()); + Ok(agg) + }, + GroupByMethod::Min => { + let mut agg = column.agg_min(groups); + agg.rename(column.name().clone()); + Ok(agg) + }, + GroupByMethod::Sum => { + let mut agg = column.agg_sum(groups); + agg.rename(column.name().clone()); + Ok(agg) + }, + GroupByMethod::Count { + include_nulls: true, + } => { + let mut ca = groups.group_count(); + ca.rename(column.name().clone()); + Ok(ca.into_column()) + }, + _ => { + unimplemented!() + }, + } + } + } + + fn finalize( + &self, + partitioned: Column, + groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + match self.agg_type.groupby { + GroupByMethod::Count { + include_nulls: true, + } + | GroupByMethod::Sum => { + let mut agg = unsafe { partitioned.agg_sum(groups) }; + agg.rename(partitioned.name().clone()); + Ok(agg) + }, + #[cfg(feature = "dtype-struct")] + GroupByMethod::Mean => { + let new_name = partitioned.name().clone(); + match partitioned.dtype() { + DataType::Struct(_) => { + let ca = partitioned.struct_().unwrap(); + let fields = ca.fields_as_series(); + let sum = &fields[0]; + let count = &fields[1]; + let (agg_count, agg_s) = + unsafe { POOL.join(|| count.agg_sum(groups), || sum.agg_sum(groups)) }; + + // Ensure that we don't divide by zero by masking out zeros. + let agg_count = agg_count.idx().unwrap(); + let mask = agg_count.equal(0 as IdxSize); + let agg_count = agg_count.set(&mask, None).unwrap().into_series(); + + let agg_s = &agg_s / &agg_count.cast(agg_s.dtype()).unwrap(); + Ok(agg_s?.with_name(new_name).into_column()) + }, + _ => Ok(Column::full_null( + new_name, + groups.len(), + partitioned.dtype(), + )), + } + }, + GroupByMethod::Implode => { + // the groups are scattered over multiple groups/sub dataframes. + // we now must collect them into a single group + let ca = partitioned.list().unwrap(); + let new_name = partitioned.name().clone(); + + let mut values = Vec::with_capacity(groups.len()); + let mut can_fast_explode = true; + + let mut offsets = Vec::::with_capacity(groups.len() + 1); + let mut length_so_far = 0i64; + offsets.push(length_so_far); + + let mut process_group = |ca: ListChunked| -> PolarsResult<()> { + let s = ca.explode()?; + length_so_far += s.len() as i64; + offsets.push(length_so_far); + values.push(s.chunks()[0].clone()); + + if s.is_empty() { + can_fast_explode = false; + } + Ok(()) + }; + + match groups.as_ref() { + GroupsType::Idx(groups) => { + for (_, idx) in groups { + let ca = unsafe { + // SAFETY: + // The indexes of the group_by operation are never out of bounds + ca.take_unchecked(idx) + }; + process_group(ca)?; + } + }, + GroupsType::Slice { groups, .. } => { + for [first, len] in groups { + let len = *len as usize; + let ca = ca.slice(*first as i64, len); + process_group(ca)?; + } + }, + } + + let vals = values.iter().map(|arr| &**arr).collect::>(); + let values = concatenate(&vals).unwrap(); + + let dtype = ListArray::::default_datatype(values.dtype().clone()); + // SAFETY: offsets are monotonically increasing. + let arr = ListArray::::new( + dtype, + unsafe { Offsets::new_unchecked(offsets).into() }, + values, + None, + ); + let mut ca = ListChunked::with_chunk(new_name, arr); + if can_fast_explode { + ca.set_fast_explode() + } + Ok(ca.into_series().as_list().into_column()) + }, + GroupByMethod::First => { + let mut agg = unsafe { partitioned.agg_first(groups) }; + agg.rename(partitioned.name().clone()); + Ok(agg) + }, + GroupByMethod::Last => { + let mut agg = unsafe { partitioned.agg_last(groups) }; + agg.rename(partitioned.name().clone()); + Ok(agg) + }, + GroupByMethod::Max => { + let mut agg = unsafe { partitioned.agg_max(groups) }; + agg.rename(partitioned.name().clone()); + Ok(agg) + }, + GroupByMethod::Min => { + let mut agg = unsafe { partitioned.agg_min(groups) }; + agg.rename(partitioned.name().clone()); + Ok(agg) + }, + _ => unimplemented!(), + } + } +} + +pub struct AggQuantileExpr { + pub(crate) input: Arc, + pub(crate) quantile: Arc, + pub(crate) method: QuantileMethod, +} + +impl AggQuantileExpr { + pub fn new( + input: Arc, + quantile: Arc, + method: QuantileMethod, + ) -> Self { + Self { + input, + quantile, + method, + } + } + + fn get_quantile(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let quantile = self.quantile.evaluate(df, state)?; + polars_ensure!(quantile.len() <= 1, ComputeError: + "polars only supports computing a single quantile; \ + make sure the 'quantile' expression input produces a single quantile" + ); + quantile.get(0).unwrap().try_extract() + } +} + +impl PhysicalExpr for AggQuantileExpr { + fn as_expression(&self) -> Option<&Expr> { + None + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let input = self.input.evaluate(df, state)?; + let quantile = self.get_quantile(df, state)?; + input + .quantile_reduce(quantile, self.method) + .map(|sc| sc.into_column(input.name().clone())) + } + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac = self.input.evaluate_on_groups(df, groups, state)?; + // don't change names by aggregations as is done in polars-core + let keep_name = ac.get_values().name().clone(); + + let quantile = self.get_quantile(df, state)?; + + // SAFETY: + // groups are in bounds + let mut agg = unsafe { + ac.flat_naive() + .into_owned() + .agg_quantile(ac.groups(), quantile, self.method) + }; + agg.rename(keep_name); + Ok(AggregationContext::from_agg_state( + AggregatedScalar(agg), + Cow::Borrowed(groups), + )) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.input.to_field(input_schema) + } + + fn is_scalar(&self) -> bool { + true + } +} + +/// Simple wrapper to parallelize functions that can be divided over threads aggregated and +/// finally aggregated in the main thread. This can be done for sum, min, max, etc. +fn parallel_op_columns(f: F, s: Column, allow_threading: bool) -> PolarsResult +where + F: Fn(Column) -> PolarsResult + Send + Sync, +{ + // set during debug low so + // we mimic production size data behavior + #[cfg(debug_assertions)] + let thread_boundary = 0; + + #[cfg(not(debug_assertions))] + let thread_boundary = 100_000; + + // Temporary until categorical min/max multithreading implementation is corrected. + #[cfg(feature = "dtype-categorical")] + let is_categorical = matches!(s.dtype(), &DataType::Categorical(_, _)); + #[cfg(not(feature = "dtype-categorical"))] + let is_categorical = false; + // threading overhead/ splitting work stealing is costly.. + + if !allow_threading + || s.len() < thread_boundary + || POOL.current_thread_has_pending_tasks().unwrap_or(false) + || is_categorical + { + return f(s); + } + let n_threads = POOL.current_num_threads(); + let splits = _split_offsets(s.len(), n_threads); + + let chunks = POOL.install(|| { + splits + .into_par_iter() + .map(|(offset, len)| { + let s = s.slice(offset as i64, len); + f(s) + }) + .collect::>>() + })?; + + let mut iter = chunks.into_iter(); + let first = iter.next().unwrap(); + let dtype = first.dtype(); + let out = iter.fold(first.to_physical_repr(), |mut acc, s| { + acc.append(&s.to_physical_repr()).unwrap(); + acc + }); + + unsafe { f(out.from_physical_unchecked(dtype).unwrap()) } +} diff --git a/crates/polars-expr/src/expressions/alias.rs b/crates/polars-expr/src/expressions/alias.rs new file mode 100644 index 000000000000..188ad9d65ada --- /dev/null +++ b/crates/polars-expr/src/expressions/alias.rs @@ -0,0 +1,104 @@ +use polars_core::prelude::*; + +use super::*; +use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; + +pub struct AliasExpr { + pub(crate) physical_expr: Arc, + pub(crate) name: PlSmallStr, + expr: Expr, +} + +impl AliasExpr { + pub fn new(physical_expr: Arc, name: PlSmallStr, expr: Expr) -> Self { + Self { + physical_expr, + name, + expr, + } + } + + fn finish(&self, input: Column) -> Column { + input.with_name(self.name.clone()) + } +} + +impl PhysicalExpr for AliasExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let series = self.physical_expr.evaluate(df, state)?; + Ok(self.finish(series)) + } + + fn evaluate_inline_impl(&self, depth_limit: u8) -> Option { + let depth_limit = depth_limit.checked_sub(1)?; + self.physical_expr + .evaluate_inline_impl(depth_limit) + .map(|s| self.finish(s)) + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?; + let c = ac.take(); + let c = self.finish(c); + + if ac.is_literal() { + ac.with_literal(c); + } else { + ac.with_values(c, ac.is_aggregated(), Some(&self.expr))?; + } + Ok(ac) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + Ok(Field::new( + self.name.clone(), + self.physical_expr.to_field(input_schema)?.dtype().clone(), + )) + } + + fn is_literal(&self) -> bool { + self.physical_expr.is_literal() + } + + fn is_scalar(&self) -> bool { + self.physical_expr.is_scalar() + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } +} + +impl PartitionedAggregation for AliasExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); + let s = agg.evaluate_partitioned(df, groups, state)?; + Ok(s.with_name(self.name.clone())) + } + + fn finalize( + &self, + partitioned: Column, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); + let s = agg.finalize(partitioned, groups, state)?; + Ok(s.with_name(self.name.clone())) + } +} diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs new file mode 100644 index 000000000000..90b98606f648 --- /dev/null +++ b/crates/polars-expr/src/expressions/apply.rs @@ -0,0 +1,526 @@ +use std::borrow::Cow; +use std::sync::OnceLock; + +use polars_core::POOL; +use polars_core::chunked_array::builder::get_list_builder; +use polars_core::chunked_array::from_iterator_par::try_list_from_par_iter; +use polars_core::prelude::*; +use rayon::prelude::*; + +use super::*; +use crate::expressions::{ + AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups, +}; + +#[derive(Clone)] +pub struct ApplyExpr { + inputs: Vec>, + function: SpecialEq>, + expr: Expr, + collect_groups: ApplyOptions, + function_returns_scalar: bool, + function_operates_on_scalar: bool, + allow_rename: bool, + pass_name_to_apply: bool, + input_schema: SchemaRef, + allow_threading: bool, + check_lengths: bool, + allow_group_aware: bool, + output_field: Field, + inlined_eval: OnceLock>, +} + +impl ApplyExpr { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + inputs: Vec>, + function: SpecialEq>, + expr: Expr, + options: FunctionOptions, + allow_threading: bool, + input_schema: SchemaRef, + output_field: Field, + returns_scalar: bool, + ) -> Self { + #[cfg(debug_assertions)] + if matches!(options.collect_groups, ApplyOptions::ElementWise) + && options.flags.contains(FunctionFlags::RETURNS_SCALAR) + { + panic!( + "expr {:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", + expr + ) + } + + Self { + inputs, + function, + expr, + collect_groups: options.collect_groups, + function_returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR), + function_operates_on_scalar: returns_scalar, + allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME), + pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY), + input_schema, + allow_threading, + check_lengths: options.check_lengths(), + allow_group_aware: options.flags.contains(FunctionFlags::ALLOW_GROUP_AWARE), + output_field, + inlined_eval: Default::default(), + } + } + + #[allow(clippy::ptr_arg)] + fn prepare_multiple_inputs<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult>> { + let f = |e: &Arc| e.evaluate_on_groups(df, groups, state); + if self.allow_threading { + POOL.install(|| self.inputs.par_iter().map(f).collect()) + } else { + self.inputs.iter().map(f).collect() + } + } + + fn finish_apply_groups<'a>( + &self, + mut ac: AggregationContext<'a>, + ca: ListChunked, + ) -> PolarsResult> { + let c = if self.function_returns_scalar { + ac.update_groups = UpdateGroups::No; + ca.explode().unwrap().into_column() + } else { + ac.with_update_groups(UpdateGroups::WithSeriesLen); + ca.into_series().into() + }; + + ac.with_values_and_args(c, true, None, false, self.function_returns_scalar)?; + + Ok(ac) + } + + fn get_input_schema(&self, _df: &DataFrame) -> Cow { + Cow::Borrowed(self.input_schema.as_ref()) + } + + /// Evaluates and flattens `Option` to `Column`. + fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult { + if let Some(out) = self.function.call_udf(inputs)? { + Ok(out) + } else { + let field = self.to_field(self.input_schema.as_ref()).unwrap(); + Ok(Column::full_null(field.name().clone(), 1, field.dtype())) + } + } + fn apply_single_group_aware<'a>( + &self, + mut ac: AggregationContext<'a>, + ) -> PolarsResult> { + let s = ac.get_values(); + + #[allow(clippy::nonminimal_bool)] + { + polars_ensure!( + !(matches!(ac.agg_state(), AggState::AggregatedScalar(_)) && !s.dtype().is_list() ) , + expr = self.expr, + ComputeError: "cannot aggregate, the column is already aggregated", + ); + } + + let name = s.name().clone(); + let agg = ac.aggregated(); + // Collection of empty list leads to a null dtype. See: #3687. + if agg.is_empty() { + // Create input for the function to determine the output dtype, see #3946. + let agg = agg.list().unwrap(); + let input_dtype = agg.inner_dtype(); + let input = Column::full_null(PlSmallStr::EMPTY, 0, input_dtype); + + let output = self.eval_and_flatten(&mut [input])?; + let ca = ListChunked::full(name, output.as_materialized_series(), 0); + return self.finish_apply_groups(ac, ca); + } + + let f = |opt_s: Option| match opt_s { + None => Ok(None), + Some(mut s) => { + if self.pass_name_to_apply { + s.rename(name.clone()); + } + Ok(self + .function + .call_udf(&mut [Column::from(s)])? + .map(|c| c.as_materialized_series().clone())) + }, + }; + + let ca: ListChunked = if self.allow_threading { + let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null() + { + Some(self.output_field.dtype.clone()) + } else { + None + }; + + let lst = agg.list().unwrap(); + let iter = lst.par_iter().map(f); + + if let Some(dtype) = dtype { + // TODO! uncomment this line and remove debug_assertion after a while. + // POOL.install(|| { + // iter.collect_ca_with_dtype::>(PlSmallStr::EMPTY, DataType::List(Box::new(dtype))) + // })? + let out: ListChunked = POOL.install(|| iter.collect::>())?; + + if self.function_returns_scalar { + debug_assert_eq!(&DataType::List(Box::new(dtype)), out.dtype()); + } else { + debug_assert_eq!(&dtype, out.dtype()); + } + + out + } else { + POOL.install(|| try_list_from_par_iter(iter, PlSmallStr::EMPTY))? + } + } else { + agg.list() + .unwrap() + .into_iter() + .map(f) + .collect::>()? + }; + + self.finish_apply_groups(ac, ca.with_name(name)) + } + + /// Apply elementwise e.g. ignore the group/list indices. + fn apply_single_elementwise<'a>( + &self, + mut ac: AggregationContext<'a>, + ) -> PolarsResult> { + let (c, aggregated) = match ac.agg_state() { + AggState::AggregatedList(c) => { + let ca = c.list().unwrap(); + let out = ca.apply_to_inner(&|s| { + Ok(self + .eval_and_flatten(&mut [s.into_column()])? + .take_materialized_series()) + })?; + (out.into_column(), true) + }, + AggState::NotAggregated(c) => { + let (out, aggregated) = (self.eval_and_flatten(&mut [c.clone()])?, false); + check_map_output_len(c.len(), out.len(), &self.expr)?; + (out, aggregated) + }, + agg_state => { + ac.with_agg_state(agg_state.try_map(|s| self.eval_and_flatten(&mut [s.clone()]))?); + return Ok(ac); + }, + }; + + ac.with_values_and_args(c, aggregated, Some(&self.expr), true, self.is_scalar())?; + Ok(ac) + } + fn apply_multiple_group_aware<'a>( + &self, + mut acs: Vec>, + df: &DataFrame, + ) -> PolarsResult> { + let mut container = vec![Default::default(); acs.len()]; + let schema = self.get_input_schema(df); + let field = self.to_field(&schema)?; + + // Aggregate representation of the aggregation contexts, + // then unpack the lists and finally create iterators from this list chunked arrays. + let mut iters = acs + .iter_mut() + .map(|ac| ac.iter_groups(self.pass_name_to_apply)) + .collect::>(); + + // Length of the items to iterate over. + let len = iters[0].size_hint().0; + + let ca = if len == 0 { + let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name); + for _ in 0..len { + container.clear(); + for iter in &mut iters { + match iter.next().unwrap() { + None => { + builder.append_null(); + }, + Some(s) => container.push(s.deep_clone().into()), + } + } + let out = self + .function + .call_udf(&mut container) + .map(|r| r.map(|c| c.as_materialized_series().clone()))?; + + builder.append_opt_series(out.as_ref())? + } + builder.finish() + } else { + // We still need this branch to materialize unknown/ data dependent types in eager. :( + (0..len) + .map(|_| { + container.clear(); + for iter in &mut iters { + match iter.next().unwrap() { + None => return Ok(None), + Some(s) => container.push(s.deep_clone().into()), + } + } + self.function + .call_udf(&mut container) + .map(|r| r.map(|c| c.as_materialized_series().clone())) + }) + .collect::>()? + .with_name(field.name.clone()) + }; + #[cfg(debug_assertions)] + { + let inner = ca.dtype().inner_dtype().unwrap(); + if field.dtype.is_known() { + assert_eq!(inner, &field.dtype); + } + } + + drop(iters); + + // Take the first aggregation context that as that is the input series. + let ac = acs.swap_remove(0); + self.finish_apply_groups(ac, ca) + } +} + +fn check_map_output_len(input_len: usize, output_len: usize, expr: &Expr) -> PolarsResult<()> { + polars_ensure!( + input_len == output_len, expr = expr, InvalidOperation: + "output length of `map` ({}) must be equal to the input length ({}); \ + consider using `apply` instead", output_len, input_len + ); + Ok(()) +} + +impl PhysicalExpr for ApplyExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let f = |e: &Arc| e.evaluate(df, state); + let mut inputs = if self.allow_threading && self.inputs.len() > 1 { + POOL.install(|| { + self.inputs + .par_iter() + .map(f) + .collect::>>() + }) + } else { + self.inputs.iter().map(f).collect::>>() + }?; + + if self.allow_rename { + self.eval_and_flatten(&mut inputs) + } else { + let in_name = inputs[0].name().clone(); + Ok(self.eval_and_flatten(&mut inputs)?.with_name(in_name)) + } + } + + fn evaluate_inline_impl(&self, depth_limit: u8) -> Option { + // For predicate evaluation at I/O of: + // `lit("2024-01-01").str.strptime()` + + self.inlined_eval + .get_or_init(|| { + let depth_limit = depth_limit.checked_sub(1)?; + let mut inputs = self + .inputs + .iter() + .map(|x| x.evaluate_inline_impl(depth_limit).filter(|s| s.len() == 1)) + .collect::>>()?; + + self.eval_and_flatten(&mut inputs).ok() + }) + .clone() + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + polars_ensure!( + self.allow_group_aware, + expr = self.expr, + ComputeError: "this expression cannot run in the group_by context", + ); + if self.inputs.len() == 1 { + let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?; + + match self.collect_groups { + ApplyOptions::ApplyList => { + let c = self.eval_and_flatten(&mut [ac.aggregated()])?; + ac.with_values(c, true, Some(&self.expr))?; + Ok(ac) + }, + ApplyOptions::GroupWise => self.apply_single_group_aware(ac), + ApplyOptions::ElementWise => self.apply_single_elementwise(ac), + } + } else { + let mut acs = self.prepare_multiple_inputs(df, groups, state)?; + + match self.collect_groups { + ApplyOptions::ApplyList => { + let mut c = acs.iter_mut().map(|ac| ac.aggregated()).collect::>(); + let c = self.eval_and_flatten(&mut c)?; + // take the first aggregation context that as that is the input series + let mut ac = acs.swap_remove(0); + ac.with_update_groups(UpdateGroups::WithGroupsLen); + ac.with_values(c, true, Some(&self.expr))?; + Ok(ac) + }, + ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df), + ApplyOptions::ElementWise => { + let mut has_agg_list = false; + let mut has_agg_scalar = false; + let mut has_not_agg = false; + for ac in &acs { + match ac.state { + AggState::AggregatedList(_) => has_agg_list = true, + AggState::AggregatedScalar(_) => has_agg_scalar = true, + AggState::NotAggregated(_) => has_not_agg = true, + _ => {}, + } + } + if has_agg_list || (has_agg_scalar && has_not_agg) { + self.apply_multiple_group_aware(acs, df) + } else { + apply_multiple_elementwise( + acs, + self.function.as_ref(), + &self.expr, + self.check_lengths, + self.is_scalar(), + ) + } + }, + } + } + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.expr.to_field(input_schema, Context::Default) + } + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ElementWise) { + Some(self) + } else { + None + } + } + fn is_scalar(&self) -> bool { + self.function_returns_scalar || self.function_operates_on_scalar + } +} + +fn apply_multiple_elementwise<'a>( + mut acs: Vec>, + function: &dyn ColumnsUdf, + expr: &Expr, + check_lengths: bool, + returns_scalar: bool, +) -> PolarsResult> { + match acs.first().unwrap().agg_state() { + // A fast path that doesn't drop groups of the first arg. + // This doesn't require group re-computation. + AggState::AggregatedList(s) => { + let ca = s.list().unwrap(); + + let other = acs[1..] + .iter() + .map(|ac| ac.flat_naive().into_owned()) + .collect::>(); + + let out = ca.apply_to_inner(&|s| { + let mut args = Vec::with_capacity(other.len() + 1); + args.push(s.into()); + args.extend_from_slice(&other); + Ok(function + .call_udf(&mut args)? + .unwrap() + .as_materialized_series() + .clone()) + })?; + let mut ac = acs.swap_remove(0); + ac.with_values(out.into_column(), true, None)?; + Ok(ac) + }, + first_as => { + let check_lengths = check_lengths && !matches!(first_as, AggState::Literal(_)); + let aggregated = acs.iter().all(|ac| ac.is_aggregated() | ac.is_literal()) + && acs.iter().any(|ac| ac.is_aggregated()); + let mut c = acs + .iter_mut() + .enumerate() + .map(|(i, ac)| { + // Make sure the groups are updated because we are about to throw away + // the series length information, only on the first iteration. + if let (0, UpdateGroups::WithSeriesLen) = (i, &ac.update_groups) { + ac.groups(); + } + + ac.flat_naive().into_owned() + }) + .collect::>(); + + let input_len = c[0].len(); + let c = function.call_udf(&mut c)?.unwrap(); + if check_lengths { + check_map_output_len(input_len, c.len(), expr)?; + } + + // Take the first aggregation context that as that is the input series. + let mut ac = acs.swap_remove(0); + ac.with_values_and_args(c, aggregated, None, true, returns_scalar)?; + Ok(ac) + }, + } +} + +impl PartitionedAggregation for ApplyExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let a = self.inputs[0].as_partitioned_aggregator().unwrap(); + let s = a.evaluate_partitioned(df, groups, state)?; + + if self.allow_rename { + self.eval_and_flatten(&mut [s]) + } else { + let in_name = s.name().clone(); + Ok(self.eval_and_flatten(&mut [s])?.with_name(in_name)) + } + } + + fn finalize( + &self, + partitioned: Column, + _groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + Ok(partitioned) + } +} diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs new file mode 100644 index 000000000000..8b2604663711 --- /dev/null +++ b/crates/polars-expr/src/expressions/binary.rs @@ -0,0 +1,306 @@ +use polars_core::POOL; +use polars_core::prelude::*; +#[cfg(feature = "round_series")] +use polars_ops::prelude::floor_div_series; + +use super::*; +use crate::expressions::{ + AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups, +}; + +#[derive(Clone)] +pub struct BinaryExpr { + left: Arc, + op: Operator, + right: Arc, + expr: Expr, + has_literal: bool, + allow_threading: bool, + is_scalar: bool, +} + +impl BinaryExpr { + pub fn new( + left: Arc, + op: Operator, + right: Arc, + expr: Expr, + has_literal: bool, + allow_threading: bool, + is_scalar: bool, + ) -> Self { + Self { + left, + op, + right, + expr, + has_literal, + allow_threading, + is_scalar, + } + } +} + +/// Can partially do operations in place. +fn apply_operator_owned(left: Column, right: Column, op: Operator) -> PolarsResult { + match op { + Operator::Plus => left.try_add_owned(right), + Operator::Minus => left.try_sub_owned(right), + Operator::Multiply + if left.dtype().is_primitive_numeric() && right.dtype().is_primitive_numeric() => + { + left.try_mul_owned(right) + }, + _ => apply_operator(&left, &right, op), + } +} + +pub fn apply_operator(left: &Column, right: &Column, op: Operator) -> PolarsResult { + use DataType::*; + match op { + Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_column()), + Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_column()), + Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_column()), + Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_column()), + Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| ca.into_column()), + Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_column()), + Operator::Plus => left + right, + Operator::Minus => left - right, + Operator::Multiply => left * right, + Operator::Divide => left / right, + Operator::TrueDivide => match left.dtype() { + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => left / right, + Duration(_) | Date | Datetime(_, _) | Float32 | Float64 => left / right, + #[cfg(feature = "dtype-array")] + Array(..) => left / right, + #[cfg(feature = "dtype-array")] + _ if right.dtype().is_array() => left / right, + List(_) => left / right, + _ if right.dtype().is_list() => left / right, + _ => { + if right.dtype().is_temporal() { + return left / right; + } + left.cast(&Float64)? / right.cast(&Float64)? + }, + }, + Operator::FloorDivide => { + #[cfg(feature = "round_series")] + { + floor_div_series( + left.as_materialized_series(), + right.as_materialized_series(), + ) + .map(Column::from) + } + #[cfg(not(feature = "round_series"))] + { + panic!("activate 'round_series' feature") + } + }, + Operator::And => left.bitand(right), + Operator::Or => left.bitor(right), + Operator::LogicalOr => left + .cast(&DataType::Boolean)? + .bitor(&right.cast(&DataType::Boolean)?), + Operator::LogicalAnd => left + .cast(&DataType::Boolean)? + .bitand(&right.cast(&DataType::Boolean)?), + Operator::Xor => left.bitxor(right), + Operator::Modulus => left % right, + Operator::EqValidity => left.equal_missing(right).map(|ca| ca.into_column()), + Operator::NotEqValidity => left.not_equal_missing(right).map(|ca| ca.into_column()), + } +} + +impl BinaryExpr { + fn apply_elementwise<'a>( + &self, + mut ac_l: AggregationContext<'a>, + ac_r: AggregationContext, + aggregated: bool, + ) -> PolarsResult> { + // We want to be able to mutate in place, so we take the lhs to make sure that we drop. + let lhs = ac_l.get_values().clone(); + let rhs = ac_r.get_values().clone(); + + // Drop lhs so that we might operate in place. + drop(ac_l.take()); + + let out = apply_operator_owned(lhs, rhs, self.op)?; + ac_l.with_values(out, aggregated, Some(&self.expr))?; + Ok(ac_l) + } + + fn apply_all_literal<'a>( + &self, + mut ac_l: AggregationContext<'a>, + mut ac_r: AggregationContext<'a>, + ) -> PolarsResult> { + let name = ac_l.get_values().name().clone(); + ac_l.groups(); + ac_r.groups(); + polars_ensure!(ac_l.groups.len() == ac_r.groups.len(), ComputeError: "lhs and rhs should have same group length"); + let left_c = ac_l.get_values().rechunk().into_column(); + let right_c = ac_r.get_values().rechunk().into_column(); + let res_c = apply_operator(&left_c, &right_c, self.op)?; + ac_l.with_update_groups(UpdateGroups::WithSeriesLen); + let res_s = if res_c.len() == 1 { + res_c.new_from_index(0, ac_l.groups.len()) + } else { + ListChunked::full(name, res_c.as_materialized_series(), ac_l.groups.len()).into_column() + }; + ac_l.with_values(res_s, true, Some(&self.expr))?; + Ok(ac_l) + } + + fn apply_group_aware<'a>( + &self, + mut ac_l: AggregationContext<'a>, + mut ac_r: AggregationContext<'a>, + ) -> PolarsResult> { + let name = ac_l.get_values().name().clone(); + let ca = ac_l + .iter_groups(false) + .zip(ac_r.iter_groups(false)) + .map(|(l, r)| { + Some(apply_operator( + &l?.as_ref().clone().into_column(), + &r?.as_ref().clone().into_column(), + self.op, + )) + }) + .map(|opt_res| opt_res.transpose()) + .collect::>()? + .with_name(name); + + ac_l.with_update_groups(UpdateGroups::WithSeriesLen); + ac_l.with_agg_state(AggState::AggregatedList(ca.into_column())); + Ok(ac_l) + } +} + +impl PhysicalExpr for BinaryExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + // Window functions may set a global state that determine their output + // state, so we don't let them run in parallel as they race + // they also saturate the thread pool by themselves, so that's fine. + let has_window = state.has_window(); + + let (lhs, rhs); + if has_window { + let mut state = state.split(); + state.remove_cache_window_flag(); + lhs = self.left.evaluate(df, &state)?; + rhs = self.right.evaluate(df, &state)?; + } else if !self.allow_threading || self.has_literal { + // Literals are free, don't pay par cost. + lhs = self.left.evaluate(df, state)?; + rhs = self.right.evaluate(df, state)?; + } else { + let (opt_lhs, opt_rhs) = POOL.install(|| { + rayon::join( + || self.left.evaluate(df, state), + || self.right.evaluate(df, state), + ) + }); + (lhs, rhs) = (opt_lhs?, opt_rhs?); + }; + polars_ensure!( + lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1, + expr = self.expr, + ShapeMismatch: "cannot evaluate two Series of different lengths ({} and {})", + lhs.len(), rhs.len(), + ); + apply_operator_owned(lhs, rhs, self.op) + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let (result_a, result_b) = POOL.install(|| { + rayon::join( + || self.left.evaluate_on_groups(df, groups, state), + || self.right.evaluate_on_groups(df, groups, state), + ) + }); + let mut ac_l = result_a?; + let ac_r = result_b?; + + match (ac_l.agg_state(), ac_r.agg_state()) { + (AggState::Literal(s), AggState::NotAggregated(_)) + | (AggState::NotAggregated(_), AggState::Literal(s)) => match s.len() { + 1 => self.apply_elementwise(ac_l, ac_r, false), + _ => self.apply_group_aware(ac_l, ac_r), + }, + (AggState::Literal(_), AggState::Literal(_)) => self.apply_all_literal(ac_l, ac_r), + (AggState::NotAggregated(_), AggState::NotAggregated(_)) => { + self.apply_elementwise(ac_l, ac_r, false) + }, + ( + AggState::AggregatedScalar(_) | AggState::Literal(_), + AggState::AggregatedScalar(_) | AggState::Literal(_), + ) => self.apply_elementwise(ac_l, ac_r, true), + (AggState::AggregatedScalar(_), AggState::NotAggregated(_)) + | (AggState::NotAggregated(_), AggState::AggregatedScalar(_)) => { + self.apply_group_aware(ac_l, ac_r) + }, + (AggState::AggregatedList(lhs), AggState::AggregatedList(rhs)) => { + let lhs = lhs.list().unwrap(); + let rhs = rhs.list().unwrap(); + let out = lhs.apply_to_inner(&|lhs| { + apply_operator(&lhs.into_column(), &rhs.get_inner().into_column(), self.op) + .map(|c| c.take_materialized_series()) + })?; + ac_l.with_values(out.into_column(), true, Some(&self.expr))?; + Ok(ac_l) + }, + _ => self.apply_group_aware(ac_l, ac_r), + } + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.expr.to_field(input_schema, Context::Default) + } + + fn is_scalar(&self) -> bool { + self.is_scalar + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } +} + +impl PartitionedAggregation for BinaryExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let left = self.left.as_partitioned_aggregator().unwrap(); + let right = self.right.as_partitioned_aggregator().unwrap(); + let left = left.evaluate_partitioned(df, groups, state)?; + let right = right.evaluate_partitioned(df, groups, state)?; + apply_operator(&left, &right, self.op) + } + + fn finalize( + &self, + partitioned: Column, + _groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + Ok(partitioned) + } +} diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs new file mode 100644 index 000000000000..62305a0ce682 --- /dev/null +++ b/crates/polars-expr/src/expressions/cast.rs @@ -0,0 +1,126 @@ +use std::sync::OnceLock; + +use polars_core::chunked_array::cast::CastOptions; +use polars_core::prelude::*; + +use super::*; +use crate::expressions::{AggState, AggregationContext, PartitionedAggregation, PhysicalExpr}; + +pub struct CastExpr { + pub(crate) input: Arc, + pub(crate) dtype: DataType, + pub(crate) expr: Expr, + pub(crate) options: CastOptions, + pub(crate) inlined_eval: OnceLock>, +} + +impl CastExpr { + fn finish(&self, input: &Column) -> PolarsResult { + input.cast_with_options(&self.dtype, self.options) + } +} + +impl PhysicalExpr for CastExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let column = self.input.evaluate(df, state)?; + self.finish(&column) + } + + fn evaluate_inline_impl(&self, depth_limit: u8) -> Option { + self.inlined_eval + .get_or_init(|| { + let depth_limit = depth_limit.checked_sub(1)?; + self.input + .evaluate_inline_impl(depth_limit) + .filter(|x| x.len() == 1) + .and_then(|x| self.finish(&x).ok()) + }) + .clone() + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac = self.input.evaluate_on_groups(df, groups, state)?; + + match ac.agg_state() { + // this will not explode and potentially increase memory due to overlapping groups + AggState::AggregatedList(s) => { + let ca = s.list().unwrap(); + let casted = ca.apply_to_inner(&|s| { + self.finish(&s.into_column()) + .map(|c| c.take_materialized_series()) + })?; + ac.with_values(casted.into_column(), true, None)?; + }, + AggState::AggregatedScalar(s) => { + let s = self.finish(&s.clone().into_column())?; + if ac.is_literal() { + ac.with_literal(s); + } else { + ac.with_values(s, true, None)?; + } + }, + _ => { + // before we flatten, make sure that groups are updated + ac.groups(); + + let s = ac.flat_naive(); + let s = self.finish(&s.as_ref().clone().into_column())?; + + if ac.is_literal() { + ac.with_literal(s); + } else { + ac.with_values(s, false, None)?; + } + }, + } + + Ok(ac) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.input.to_field(input_schema).map(|mut fld| { + fld.coerce(self.dtype.clone()); + fld + }) + } + + fn is_scalar(&self) -> bool { + self.input.is_scalar() + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } +} + +impl PartitionedAggregation for CastExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let e = self.input.as_partitioned_aggregator().unwrap(); + self.finish(&e.evaluate_partitioned(df, groups, state)?) + } + + fn finalize( + &self, + partitioned: Column, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let agg = self.input.as_partitioned_aggregator().unwrap(); + agg.finalize(partitioned, groups, state) + } +} diff --git a/crates/polars-expr/src/expressions/column.rs b/crates/polars-expr/src/expressions/column.rs new file mode 100644 index 000000000000..f45a036cba9d --- /dev/null +++ b/crates/polars-expr/src/expressions/column.rs @@ -0,0 +1,182 @@ +use std::borrow::Cow; + +use polars_core::prelude::*; +use polars_plan::constants::CSE_REPLACED; + +use super::*; +use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; + +pub struct ColumnExpr { + name: PlSmallStr, + expr: Expr, + schema: SchemaRef, +} + +impl ColumnExpr { + pub fn new(name: PlSmallStr, expr: Expr, schema: SchemaRef) -> Self { + Self { name, expr, schema } + } +} + +impl ColumnExpr { + fn check_external_context( + &self, + out: PolarsResult, + state: &ExecutionState, + ) -> PolarsResult { + match out { + Ok(col) => Ok(col), + Err(e) => { + if state.ext_contexts.is_empty() { + Err(e) + } else { + for df in state.ext_contexts.as_ref() { + let out = df.column(&self.name); + if out.is_ok() { + return out.cloned(); + } + } + Err(e) + } + }, + } + } + + fn process_by_idx( + &self, + out: &Column, + _state: &ExecutionState, + _schema: &Schema, + df: &DataFrame, + check_state_schema: bool, + ) -> PolarsResult { + if out.name() != &*self.name { + if check_state_schema { + if let Some(schema) = _state.get_schema() { + return self.process_from_state_schema(df, _state, &schema); + } + } + + df.column(&self.name).cloned() + } else { + Ok(out.clone()) + } + } + fn process_by_linear_search( + &self, + df: &DataFrame, + _state: &ExecutionState, + _panic_during_test: bool, + ) -> PolarsResult { + df.column(&self.name).cloned() + } + + fn process_from_state_schema( + &self, + df: &DataFrame, + state: &ExecutionState, + schema: &Schema, + ) -> PolarsResult { + match schema.get_full(&self.name) { + None => self.process_by_linear_search(df, state, true), + Some((idx, _, _)) => match df.get_columns().get(idx) { + Some(out) => self.process_by_idx(out, state, schema, df, false), + None => self.process_by_linear_search(df, state, true), + }, + } + } + + fn process_cse(&self, df: &DataFrame, schema: &Schema) -> PolarsResult { + // The CSE columns are added on the rhs. + let offset = schema.len(); + let columns = &df.get_columns()[offset..]; + // Linear search will be relatively cheap as we only search the CSE columns. + Ok(columns + .iter() + .find(|s| s.name() == &self.name) + .unwrap() + .clone()) + } +} + +impl PhysicalExpr for ColumnExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let out = match self.schema.get_full(&self.name) { + Some((idx, _, _)) => { + // check if the schema was correct + // if not do O(n) search + match df.get_columns().get(idx) { + Some(out) => self.process_by_idx(out, state, &self.schema, df, true), + None => { + // partitioned group_by special case + if let Some(schema) = state.get_schema() { + self.process_from_state_schema(df, state, &schema) + } else { + self.process_by_linear_search(df, state, true) + } + }, + } + }, + // in the future we will throw an error here + // now we do a linear search first as the lazy reported schema may still be incorrect + // in debug builds we panic so that it can be fixed when occurring + None => { + if self.name.starts_with(CSE_REPLACED) { + return self.process_cse(df, &self.schema); + } + self.process_by_linear_search(df, state, true) + }, + }; + self.check_external_context(out, state) + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let c = self.evaluate(df, state)?; + Ok(AggregationContext::new(c, Cow::Borrowed(groups), false)) + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + input_schema.get_field(&self.name).ok_or_else(|| { + polars_err!( + ColumnNotFound: "could not find {:?} in schema: {:?}", self.name, &input_schema + ) + }) + } + fn is_scalar(&self) -> bool { + false + } +} + +impl PartitionedAggregation for ColumnExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + _groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + self.evaluate(df, state) + } + + fn finalize( + &self, + partitioned: Column, + _groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + Ok(partitioned) + } +} diff --git a/crates/polars-expr/src/expressions/count.rs b/crates/polars-expr/src/expressions/count.rs new file mode 100644 index 000000000000..c6183f5d4f21 --- /dev/null +++ b/crates/polars-expr/src/expressions/count.rs @@ -0,0 +1,76 @@ +use std::borrow::Cow; + +use polars_core::prelude::*; +use polars_plan::constants::LEN; + +use super::*; +use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; + +pub struct CountExpr { + expr: Expr, +} + +impl CountExpr { + pub(crate) fn new() -> Self { + Self { expr: Expr::Len } + } +} + +impl PhysicalExpr for CountExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult { + Ok(Series::new(PlSmallStr::from_static("len"), [df.height() as IdxSize]).into_column()) + } + + fn evaluate_on_groups<'a>( + &self, + _df: &DataFrame, + groups: &'a GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult> { + let ca = groups.group_count().with_name(PlSmallStr::from_static(LEN)); + let c = ca.into_column(); + Ok(AggregationContext::new(c, Cow::Borrowed(groups), true)) + } + + fn to_field(&self, _input_schema: &Schema) -> PolarsResult { + Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } + + fn is_scalar(&self) -> bool { + true + } +} + +impl PartitionedAggregation for CountExpr { + #[allow(clippy::ptr_arg)] + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + self.evaluate_on_groups(df, groups, state) + .map(|mut ac| ac.aggregated().into_column()) + } + + /// Called to merge all the partitioned results in a final aggregate. + #[allow(clippy::ptr_arg)] + fn finalize( + &self, + partitioned: Column, + groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + // SAFETY: groups are in bounds. + let agg = unsafe { partitioned.agg_sum(groups) }; + Ok(agg.with_name(PlSmallStr::from_static(LEN))) + } +} diff --git a/crates/polars-expr/src/expressions/filter.rs b/crates/polars-expr/src/expressions/filter.rs new file mode 100644 index 000000000000..dd7f6534e74c --- /dev/null +++ b/crates/polars-expr/src/expressions/filter.rs @@ -0,0 +1,160 @@ +use arrow::legacy::is_valid::IsValid; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_utils::idx_vec::IdxVec; +use rayon::prelude::*; + +use super::*; +use crate::expressions::UpdateGroups::WithSeriesLen; +use crate::expressions::{AggregationContext, PhysicalExpr}; + +pub struct FilterExpr { + pub(crate) input: Arc, + pub(crate) by: Arc, + expr: Expr, +} + +impl FilterExpr { + pub fn new(input: Arc, by: Arc, expr: Expr) -> Self { + Self { input, by, expr } + } +} + +impl PhysicalExpr for FilterExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let s_f = || self.input.evaluate(df, state); + let predicate_f = || self.by.evaluate(df, state); + + let (series, predicate) = POOL.install(|| rayon::join(s_f, predicate_f)); + let (series, predicate) = (series?, predicate?); + + series.filter(predicate.bool()?) + } + + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let ac_s_f = || self.input.evaluate_on_groups(df, groups, state); + let ac_predicate_f = || self.by.evaluate_on_groups(df, groups, state); + + let (ac_s, ac_predicate) = POOL.install(|| rayon::join(ac_s_f, ac_predicate_f)); + let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?); + // Check if the groups are still equal, otherwise aggregate. + // TODO! create a special group iters that don't materialize + if !std::ptr::eq( + ac_s.groups.as_ref() as *const _, + ac_predicate.groups.as_ref() as *const _, + ) { + let _ = ac_s.aggregated(); + let _ = ac_predicate.aggregated(); + } + + if ac_predicate.is_aggregated() || ac_s.is_aggregated() { + let preds = ac_predicate.iter_groups(false); + let s = ac_s.aggregated(); + let ca = s.list()?; + let out = if ca.is_empty() { + // return an empty list if ca is empty. + ListChunked::full_null_with_dtype(ca.name().clone(), 0, ca.inner_dtype()) + } else { + { + ca.amortized_iter() + .zip(preds) + .map(|(opt_s, opt_pred)| match (opt_s, opt_pred) { + (Some(s), Some(pred)) => { + s.as_ref().filter(pred.as_ref().bool()?).map(Some) + }, + _ => Ok(None), + }) + .collect::>()? + .with_name(s.name().clone()) + } + }; + ac_s.with_values(out.into_column(), true, Some(&self.expr))?; + ac_s.update_groups = WithSeriesLen; + Ok(ac_s) + } else { + let groups = ac_s.groups(); + let predicate_s = ac_predicate.flat_naive(); + let predicate = predicate_s.bool()?; + + // All values true - don't do anything. + if let Some(true) = predicate.all_kleene() { + return Ok(ac_s); + } + // All values false - create empty groups. + let groups = if !predicate.any() { + let groups = groups.iter().map(|gi| [gi.first(), 0]).collect::>(); + GroupsType::Slice { + groups, + rolling: false, + } + } + // Filter the indexes that are true. + else { + let predicate = predicate.rechunk(); + let predicate = predicate.downcast_as_array(); + POOL.install(|| { + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + let groups = groups + .par_iter() + .map(|(first, idx)| unsafe { + let idx: IdxVec = idx + .iter() + .copied() + .filter(|i| { + // SAFETY: just checked bounds in short circuited lhs. + predicate.value(*i as usize) + && predicate.is_valid_unchecked(*i as usize) + }) + .collect(); + + (*idx.first().unwrap_or(&first), idx) + }) + .collect(); + + GroupsType::Idx(groups) + }, + GroupsType::Slice { groups, .. } => { + let groups = groups + .par_iter() + .map(|&[first, len]| unsafe { + let idx: IdxVec = (first..first + len) + .filter(|&i| { + // SAFETY: just checked bounds in short circuited lhs + predicate.value(i as usize) + && predicate.is_valid_unchecked(i as usize) + }) + .collect(); + + (*idx.first().unwrap_or(&first), idx) + }) + .collect(); + GroupsType::Idx(groups) + }, + } + }) + }; + + ac_s.with_groups(groups.into_sliceable()) + .set_original_len(false); + Ok(ac_s) + } + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.input.to_field(input_schema) + } + + fn is_scalar(&self) -> bool { + false + } +} diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs new file mode 100644 index 000000000000..7941ed5f089d --- /dev/null +++ b/crates/polars-expr/src/expressions/gather.rs @@ -0,0 +1,275 @@ +use arrow::legacy::utils::CustomIterTools; +use polars_core::chunked_array::builder::get_list_builder; +use polars_core::prelude::*; +use polars_core::utils::NoNull; +use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain_col}; + +use super::*; +use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups}; + +pub struct GatherExpr { + pub(crate) phys_expr: Arc, + pub(crate) idx: Arc, + pub(crate) expr: Expr, + pub(crate) returns_scalar: bool, +} + +impl PhysicalExpr for GatherExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let series = self.phys_expr.evaluate(df, state)?; + self.finish(df, state, series) + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac = self.phys_expr.evaluate_on_groups(df, groups, state)?; + let mut idx = self.idx.evaluate_on_groups(df, groups, state)?; + + let c_idx = idx.get_values(); + match c_idx.dtype() { + DataType::List(inner) => { + polars_ensure!(inner.is_integer(), InvalidOperation: "expected numeric dtype as index, got {:?}", inner) + }, + dt if dt.is_integer() => { + // Unsigned integers will fall through and will use faster paths. + if !is_positive_idx_uncertain_col(c_idx) { + return self.process_negative_indices_agg(ac, idx, groups); + } + }, + dt => polars_bail!(InvalidOperation: "expected numeric dtype as index, got {:?}", dt), + } + + let idx = match idx.state { + AggState::AggregatedScalar(s) => { + let idx = s.cast(&IDX_DTYPE)?; + return self.process_positive_indices_agg_scalar(ac, idx.idx().unwrap()); + }, + AggState::AggregatedList(s) => { + polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); + s.list().unwrap().clone() + }, + // Maybe a literal as well, this needs a different path. + AggState::NotAggregated(_) => { + polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); + let s = idx.aggregated(); + s.list().unwrap().clone() + }, + AggState::Literal(s) => { + let idx = s.cast(&IDX_DTYPE)?; + return self.process_positive_indices_agg_literal(ac, idx.idx().unwrap()); + }, + }; + + let s = idx.cast(&DataType::List(Box::new(IDX_DTYPE)))?; + let idx = s.list().unwrap(); + + let taken = { + ac.aggregated() + .list() + .unwrap() + .amortized_iter() + .zip(idx.amortized_iter()) + .map(|(s, idx)| Some(s?.as_ref().take(idx?.as_ref().idx().unwrap()))) + .map(|opt_res| opt_res.transpose()) + .collect::>()? + .with_name(ac.get_values().name().clone()) + }; + + ac.with_values(taken.into_column(), true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithSeriesLen); + Ok(ac) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.phys_expr.to_field(input_schema) + } + + fn is_scalar(&self) -> bool { + self.returns_scalar + } +} + +impl GatherExpr { + fn finish( + &self, + df: &DataFrame, + state: &ExecutionState, + series: Column, + ) -> PolarsResult { + let idx = self.idx.evaluate(df, state)?; + let idx = convert_to_unsigned_index(idx.as_materialized_series(), series.len())?; + series.take(&idx) + } + + fn oob_err(&self) -> PolarsResult<()> { + polars_bail!(expr = self.expr, OutOfBounds: "index out of bounds"); + } + + fn process_positive_indices_agg_scalar<'b>( + &self, + mut ac: AggregationContext<'b>, + idx: &IdxCa, + ) -> PolarsResult> { + if ac.is_not_aggregated() { + // A previous aggregation may have updated the groups. + let groups = ac.groups(); + + // Determine the gather indices. + let idx: IdxCa = match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + if groups.all().iter().zip(idx).any(|(g, idx)| match idx { + None => false, + Some(idx) => idx >= g.len() as IdxSize, + }) { + self.oob_err()?; + } + + idx.into_iter() + .zip(groups.iter()) + .map(|(idx, (_first, groups))| { + idx.map(|idx| { + // SAFETY: + // we checked bounds + unsafe { *groups.get_unchecked(usize::try_from(idx).unwrap()) } + }) + }) + .collect_trusted() + }, + GroupsType::Slice { groups, .. } => { + if groups.iter().zip(idx).any(|(g, idx)| match idx { + None => false, + Some(idx) => idx >= g[1], + }) { + self.oob_err()?; + } + + idx.into_iter() + .zip(groups.iter()) + .map(|(idx, g)| idx.map(|idx| idx + g[0])) + .collect_trusted() + }, + }; + + let taken = ac.flat_naive().take(&idx)?; + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_column() + }; + + ac.with_values(taken, true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithSeriesLen); + Ok(ac) + } else { + self.gather_aggregated_expensive(ac, idx) + } + } + + fn gather_aggregated_expensive<'b>( + &self, + mut ac: AggregationContext<'b>, + idx: &IdxCa, + ) -> PolarsResult> { + let out = ac + .aggregated() + .list() + .unwrap() + .try_apply_amortized(|s| s.as_ref().take(idx))?; + + ac.with_values(out.into_column(), true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithSeriesLen); + Ok(ac) + } + + fn process_positive_indices_agg_literal<'b>( + &self, + mut ac: AggregationContext<'b>, + idx: &IdxCa, + ) -> PolarsResult> { + if idx.len() == 1 { + match idx.get(0) { + None => polars_bail!(ComputeError: "cannot take by a null"), + Some(idx) => { + // Make sure that we look at the updated groups. + let groups = ac.groups(); + + // We offset the groups first by idx. + let idx: NoNull = match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + if groups.all().iter().any(|g| idx >= g.len() as IdxSize) { + self.oob_err()?; + } + + groups + .iter() + .map(|(_, group)| { + // SAFETY: we just bound checked. + unsafe { *group.get_unchecked(idx as usize) } + }) + .collect_trusted() + }, + GroupsType::Slice { groups, .. } => { + if groups.iter().any(|g| idx >= g[1]) { + self.oob_err()?; + } + + groups.iter().map(|g| g[0] + idx).collect_trusted() + }, + }; + let taken = ac.flat_naive().take(&idx.into_inner())?; + + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_column() + }; + + ac.with_values(taken, true, Some(&self.expr))?; + ac.with_update_groups(UpdateGroups::WithSeriesLen); + Ok(ac) + }, + } + } else { + self.gather_aggregated_expensive(ac, idx) + } + } + + fn process_negative_indices_agg<'b>( + &self, + mut ac: AggregationContext<'b>, + mut idx: AggregationContext<'b>, + groups: &'b GroupsType, + ) -> PolarsResult> { + let mut builder = get_list_builder( + &ac.dtype(), + idx.get_values().len(), + groups.len(), + ac.get_values().name().clone(), + ); + + let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); + for (s, idx) in iter { + match (s, idx) { + (Some(s), Some(idx)) => { + let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?; + let out = s.as_ref().take(&idx)?; + builder.append_series(&out)?; + }, + _ => builder.append_null(), + }; + } + let out = builder.finish().into_column(); + ac.with_agg_state(AggState::AggregatedList(out)); + ac.with_update_groups(UpdateGroups::WithSeriesLen); + Ok(ac) + } +} diff --git a/crates/polars-expr/src/expressions/group_iter.rs b/crates/polars-expr/src/expressions/group_iter.rs new file mode 100644 index 000000000000..acfcffcb7b32 --- /dev/null +++ b/crates/polars-expr/src/expressions/group_iter.rs @@ -0,0 +1,188 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::rc::Rc; + +use polars_core::series::amortized_iter::AmortSeries; + +use super::*; + +impl AggregationContext<'_> { + pub(super) fn iter_groups( + &mut self, + keep_names: bool, + ) -> Box> + '_> { + match self.agg_state() { + AggState::Literal(_) => { + self.groups(); + let c = self.get_values().rechunk(); + let name = if keep_names { + c.name().clone() + } else { + PlSmallStr::EMPTY + }; + // SAFETY: dtype is correct + unsafe { + Box::new(LitIter::new( + c.as_materialized_series().array_ref(0).clone(), + self.groups.len(), + c.dtype(), + name, + )) + } + }, + AggState::AggregatedScalar(_) => { + self.groups(); + let c = self.get_values(); + let name = if keep_names { + c.name().clone() + } else { + PlSmallStr::EMPTY + }; + // SAFETY: dtype is correct + unsafe { + Box::new(FlatIter::new( + c.as_materialized_series().chunks(), + self.groups.len(), + c.dtype(), + name, + )) + } + }, + AggState::AggregatedList(_) => { + let c = self.get_values(); + let list = c.list().unwrap(); + let name = if keep_names { + c.name().clone() + } else { + PlSmallStr::EMPTY + }; + Box::new(list.amortized_iter_with_name(name)) + }, + AggState::NotAggregated(_) => { + // we don't take the owned series as we want a reference + let _ = self.aggregated(); + let c = self.get_values(); + let list = c.list().unwrap(); + let name = if keep_names { + c.name().clone() + } else { + PlSmallStr::EMPTY + }; + Box::new(list.amortized_iter_with_name(name)) + }, + } + } +} + +struct LitIter { + len: usize, + offset: usize, + // AmortSeries referenced that series + #[allow(dead_code)] + series_container: Rc, + item: AmortSeries, +} + +impl LitIter { + /// # Safety + /// Caller must ensure the given `logical` dtype belongs to `array`. + unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: PlSmallStr) -> Self { + let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked( + name, + vec![array], + logical, + )); + + Self { + offset: 0, + len, + series_container: series_container.clone(), + // SAFETY: we pinned the series so the location is still valid + item: AmortSeries::new(series_container), + } + } +} + +impl Iterator for LitIter { + type Item = Option; + + fn next(&mut self) -> Option { + if self.len == self.offset { + None + } else { + self.offset += 1; + Some(Some(self.item.clone())) + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.len)) + } +} + +struct FlatIter { + current_array: ArrayRef, + chunks: Vec, + offset: usize, + chunk_offset: usize, + len: usize, + // AmortSeries referenced that series + #[allow(dead_code)] + series_container: Rc, + item: AmortSeries, +} + +impl FlatIter { + /// # Safety + /// Caller must ensure the given `logical` dtype belongs to `array`. + unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: PlSmallStr) -> Self { + let mut stack = Vec::with_capacity(chunks.len()); + for chunk in chunks.iter().rev() { + stack.push(chunk.clone()) + } + let current_array = stack.pop().unwrap(); + let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked( + name, + vec![current_array.clone()], + logical, + )); + Self { + current_array, + chunks: stack, + offset: 0, + chunk_offset: 0, + len, + series_container: series_container.clone(), + item: AmortSeries::new(series_container), + } + } +} + +impl Iterator for FlatIter { + type Item = Option; + + fn next(&mut self) -> Option { + if self.len == self.offset { + None + } else { + if self.chunk_offset < self.current_array.len() { + let mut arr = unsafe { self.current_array.sliced_unchecked(self.chunk_offset, 1) }; + unsafe { self.item.swap(&mut arr) }; + } else { + match self.chunks.pop() { + Some(arr) => { + self.current_array = arr; + self.chunk_offset = 0; + return self.next(); + }, + None => return None, + } + } + self.offset += 1; + self.chunk_offset += 1; + Some(Some(self.item.clone())) + } + } + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.offset)) + } +} diff --git a/crates/polars-expr/src/expressions/literal.rs b/crates/polars-expr/src/expressions/literal.rs new file mode 100644 index 000000000000..1266f7e51418 --- /dev/null +++ b/crates/polars-expr/src/expressions/literal.rs @@ -0,0 +1,150 @@ +use std::borrow::Cow; +use std::ops::Deref; + +use arrow::temporal_conversions::NANOSECONDS_IN_DAY; +use polars_core::prelude::*; +use polars_core::utils::NoNull; +use polars_plan::constants::get_literal_name; + +use super::*; +use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; + +pub struct LiteralExpr(pub LiteralValue, Expr); + +impl LiteralExpr { + pub fn new(value: LiteralValue, expr: Expr) -> Self { + Self(value, expr) + } + + fn as_column(&self) -> PolarsResult { + use LiteralValue as L; + let column = match &self.0 { + L::Scalar(sc) => { + #[cfg(feature = "dtype-time")] + if let AnyValue::Time(v) = sc.value() { + if !(0..NANOSECONDS_IN_DAY).contains(v) { + polars_bail!( + InvalidOperation: "value `{v}` is out-of-range for `time` which can be 0 - {}", + NANOSECONDS_IN_DAY - 1 + ); + } + } + + sc.clone().into_column(get_literal_name().clone()) + }, + L::Series(s) => s.deref().clone().into_column(), + lv @ L::Dyn(_) => polars_core::prelude::Series::from_any_values( + get_literal_name().clone(), + &[lv.to_any_value().unwrap()], + false, + ) + .unwrap() + .into_column(), + L::Range(RangeLiteralValue { low, high, dtype }) => { + let low = *low; + let high = *high; + match dtype { + DataType::Int32 => { + polars_ensure!( + low >= i32::MIN as i128 && high <= i32::MAX as i128, + ComputeError: "range not within bounds of `Int32`: [{}, {}]", low, high + ); + let low = low as i32; + let high = high as i32; + let ca: NoNull = (low..high).collect(); + ca.into_inner().into_column() + }, + DataType::Int64 => { + polars_ensure!( + low >= i64::MIN as i128 && high <= i64::MAX as i128, + ComputeError: "range not within bounds of `Int32`: [{}, {}]", low, high + ); + let low = low as i64; + let high = high as i64; + let ca: NoNull = (low..high).collect(); + ca.into_inner().into_column() + }, + DataType::UInt32 => { + polars_ensure!( + low >= u32::MIN as i128 && high <= u32::MAX as i128, + ComputeError: "range not within bounds of `UInt32`: [{}, {}]", low, high + ); + let low = low as u32; + let high = high as u32; + let ca: NoNull = (low..high).collect(); + ca.into_inner().into_column() + }, + dt => polars_bail!( + InvalidOperation: "datatype `{}` is not supported as range", dt + ), + } + }, + }; + Ok(column) + } +} + +impl PhysicalExpr for LiteralExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.1) + } + + fn evaluate(&self, _df: &DataFrame, _state: &ExecutionState) -> PolarsResult { + self.as_column() + } + + fn evaluate_inline_impl(&self, _depth_limit: u8) -> Option { + use LiteralValue::*; + match &self.0 { + Range { .. } => None, + _ => self.as_column().ok(), + } + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let s = self.evaluate(df, state)?; + Ok(AggregationContext::from_literal(s, Cow::Borrowed(groups))) + } + + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } + + fn to_field(&self, _input_schema: &Schema) -> PolarsResult { + let dtype = self.0.get_datatype(); + Ok(Field::new(PlSmallStr::from_static("literal"), dtype)) + } + fn is_literal(&self) -> bool { + true + } + + fn is_scalar(&self) -> bool { + self.0.is_scalar() + } +} + +impl PartitionedAggregation for LiteralExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + _groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + self.evaluate(df, state) + } + + fn finalize( + &self, + partitioned: Column, + _groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + Ok(partitioned) + } +} diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs new file mode 100644 index 000000000000..40509d69b59c --- /dev/null +++ b/crates/polars-expr/src/expressions/mod.rs @@ -0,0 +1,708 @@ +mod aggregation; +mod alias; +mod apply; +mod binary; +mod cast; +mod column; +mod count; +mod filter; +mod gather; +mod group_iter; +mod literal; +#[cfg(feature = "dynamic_group_by")] +mod rolling; +mod slice; +mod sort; +mod sortby; +mod ternary; +mod window; + +use std::borrow::Cow; +use std::fmt::{Display, Formatter}; + +pub(crate) use aggregation::*; +pub(crate) use alias::*; +pub(crate) use apply::*; +use arrow::array::ArrayRef; +use arrow::legacy::utils::CustomIterTools; +pub(crate) use binary::*; +pub(crate) use cast::*; +pub(crate) use column::*; +pub(crate) use count::*; +pub(crate) use filter::*; +pub(crate) use gather::*; +pub(crate) use literal::*; +use polars_core::prelude::*; +use polars_io::predicates::PhysicalIoExpr; +use polars_plan::prelude::*; +#[cfg(feature = "dynamic_group_by")] +pub(crate) use rolling::RollingExpr; +pub(crate) use slice::*; +pub(crate) use sort::*; +pub(crate) use sortby::*; +pub(crate) use ternary::*; +pub use window::window_function_format_order_by; +pub(crate) use window::*; + +use crate::state::ExecutionState; + +#[derive(Clone, Debug)] +pub enum AggState { + /// Already aggregated: `.agg_list(group_tuples)` is called + /// and produced a `Series` of dtype `List` + AggregatedList(Column), + /// Already aggregated: `.agg` is called on an aggregation + /// that produces a scalar. + /// think of `sum`, `mean`, `variance` like aggregations. + AggregatedScalar(Column), + /// Not yet aggregated: `agg_list` still has to be called. + NotAggregated(Column), + Literal(Column), +} + +impl AggState { + fn try_map(&self, func: F) -> PolarsResult + where + F: FnOnce(&Column) -> PolarsResult, + { + Ok(match self { + AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?), + AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?), + AggState::Literal(c) => AggState::Literal(func(c)?), + AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?), + }) + } + + fn is_scalar(&self) -> bool { + matches!(self, Self::AggregatedScalar(_)) + } +} + +// lazy update strategy +#[cfg_attr(debug_assertions, derive(Debug))] +#[derive(PartialEq, Clone, Copy)] +pub(crate) enum UpdateGroups { + /// don't update groups + No, + /// use the length of the current groups to determine new sorted indexes, preferred + /// for performance + WithGroupsLen, + /// use the series list offsets to determine the new group lengths + /// this one should be used when the length has changed. Note that + /// the series should be aggregated state or else it will panic. + WithSeriesLen, +} + +#[cfg_attr(debug_assertions, derive(Debug))] +pub struct AggregationContext<'a> { + /// Can be in one of two states + /// 1. already aggregated as list + /// 2. flat (still needs the grouptuples to aggregate) + state: AggState, + /// group tuples for AggState + groups: Cow<'a, GroupPositions>, + /// if the group tuples are already used in a level above + /// and the series is exploded, the group tuples are sorted + /// e.g. the exploded Series is grouped per group. + sorted: bool, + /// This is used to determined if we need to update the groups + /// into a sorted groups. We do this lazily, so that this work only is + /// done when the groups are needed + update_groups: UpdateGroups, + /// This is true when the Series and Groups still have all + /// their original values. Not the case when filtered + original_len: bool, +} + +impl<'a> AggregationContext<'a> { + pub(crate) fn dtype(&self) -> DataType { + match &self.state { + AggState::Literal(s) => s.dtype().clone(), + AggState::AggregatedList(s) => s.list().unwrap().inner_dtype().clone(), + AggState::AggregatedScalar(s) => s.dtype().clone(), + AggState::NotAggregated(s) => s.dtype().clone(), + } + } + pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> { + match self.update_groups { + UpdateGroups::No => {}, + UpdateGroups::WithGroupsLen => { + // the groups are unordered + // and the series is aggregated with this groups + // so we need to recreate new grouptuples that + // match the exploded Series + let mut offset = 0 as IdxSize; + + match self.groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + let groups = groups + .iter() + .map(|g| { + let len = g.1.len() as IdxSize; + let new_offset = offset + len; + let out = [offset, len]; + offset = new_offset; + out + }) + .collect(); + self.groups = Cow::Owned( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ) + }, + // sliced groups are already in correct order + GroupsType::Slice { .. } => {}, + } + self.update_groups = UpdateGroups::No; + }, + UpdateGroups::WithSeriesLen => { + let s = self.get_values().clone(); + self.det_groups_from_list(s.as_materialized_series()); + }, + } + &self.groups + } + + pub(crate) fn get_values(&self) -> &Column { + match &self.state { + AggState::NotAggregated(s) + | AggState::AggregatedScalar(s) + | AggState::AggregatedList(s) => s, + AggState::Literal(s) => s, + } + } + + pub fn agg_state(&self) -> &AggState { + &self.state + } + + pub(crate) fn is_not_aggregated(&self) -> bool { + matches!( + &self.state, + AggState::NotAggregated(_) | AggState::Literal(_) + ) + } + + pub(crate) fn is_aggregated(&self) -> bool { + !self.is_not_aggregated() + } + + pub(crate) fn is_literal(&self) -> bool { + matches!(self.state, AggState::Literal(_)) + } + + /// # Arguments + /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its + /// the columns dtype) + fn new( + column: Column, + groups: Cow<'a, GroupPositions>, + aggregated: bool, + ) -> AggregationContext<'a> { + let series = match (aggregated, column.dtype()) { + (true, &DataType::List(_)) => { + assert_eq!(column.len(), groups.len()); + AggState::AggregatedList(column) + }, + (true, _) => { + assert_eq!(column.len(), groups.len()); + AggState::AggregatedScalar(column) + }, + _ => AggState::NotAggregated(column), + }; + + Self { + state: series, + groups, + sorted: false, + update_groups: UpdateGroups::No, + original_len: true, + } + } + + fn with_agg_state(&mut self, agg_state: AggState) { + self.state = agg_state; + } + + fn from_agg_state( + agg_state: AggState, + groups: Cow<'a, GroupPositions>, + ) -> AggregationContext<'a> { + Self { + state: agg_state, + groups, + sorted: false, + update_groups: UpdateGroups::No, + original_len: true, + } + } + + fn from_literal(lit: Column, groups: Cow<'a, GroupPositions>) -> AggregationContext<'a> { + Self { + state: AggState::Literal(lit), + groups, + sorted: false, + update_groups: UpdateGroups::No, + original_len: true, + } + } + + pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self { + self.original_len = original_len; + self + } + + pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self { + self.update_groups = update; + self + } + + fn det_groups_from_list(&mut self, s: &Series) { + let mut offset = 0 as IdxSize; + let list = s + .list() + .expect("impl error, should be a list at this point"); + + match list.chunks().len() { + 1 => { + let arr = list.downcast_iter().next().unwrap(); + let offsets = arr.offsets().as_slice(); + + let mut previous = 0i64; + let groups = offsets[1..] + .iter() + .map(|&o| { + let len = (o - previous) as IdxSize; + // explode will fill empty rows with null, so we must increment the group + // offset accordingly + let new_offset = offset + len + (len == 0) as IdxSize; + + previous = o; + let out = [offset, len]; + offset = new_offset; + out + }) + .collect_trusted(); + self.groups = Cow::Owned( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ); + }, + _ => { + let groups = { + self.get_values() + .list() + .expect("impl error, should be a list at this point") + .amortized_iter() + .map(|s| { + if let Some(s) = s { + let len = s.as_ref().len() as IdxSize; + let new_offset = offset + len; + let out = [offset, len]; + offset = new_offset; + out + } else { + [offset, 0] + } + }) + .collect_trusted() + }; + self.groups = Cow::Owned( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ); + }, + } + self.update_groups = UpdateGroups::No; + } + + /// # Arguments + /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its + /// the columns dtype) + pub(crate) fn with_values( + &mut self, + column: Column, + aggregated: bool, + expr: Option<&Expr>, + ) -> PolarsResult<&mut Self> { + self.with_values_and_args( + column, + aggregated, + expr, + false, + self.agg_state().is_scalar(), + ) + } + + pub(crate) fn with_values_and_args( + &mut self, + column: Column, + aggregated: bool, + expr: Option<&Expr>, + // if the applied function was a `map` instead of an `apply` + // this will keep functions applied over literals as literals: F(lit) = lit + mapped: bool, + returns_scalar: bool, + ) -> PolarsResult<&mut Self> { + self.state = match (aggregated, column.dtype()) { + (true, &DataType::List(_)) if !returns_scalar => { + if column.len() != self.groups.len() { + let fmt_expr = if let Some(e) = expr { + format!("'{e:?}' ") + } else { + String::new() + }; + polars_bail!( + ComputeError: + "aggregation expression '{}' produced a different number of elements: {} \ + than the number of groups: {} (this is likely invalid)", + fmt_expr, column.len(), self.groups.len(), + ); + } + AggState::AggregatedList(column) + }, + (true, _) => AggState::AggregatedScalar(column), + _ => { + match self.state { + // already aggregated to sum, min even this series was flattened it never could + // retrieve the length before grouping, so it stays in this state. + AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column), + // applying a function on a literal, keeps the literal state + AggState::Literal(_) if column.len() == 1 && mapped => { + AggState::Literal(column) + }, + _ => AggState::NotAggregated(column.into_column()), + } + }, + }; + Ok(self) + } + + pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self { + self.state = AggState::Literal(column); + self + } + + /// Update the group tuples + pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self { + if let AggState::AggregatedList(_) = self.agg_state() { + // In case of new groups, a series always needs to be flattened + self.with_values(self.flat_naive().into_owned(), false, None) + .unwrap(); + } + self.groups = Cow::Owned(groups); + // make sure that previous setting is not used + self.update_groups = UpdateGroups::No; + self + } + + pub(crate) fn _implode_no_agg(&mut self) { + match self.state.clone() { + AggState::NotAggregated(_) => { + let _ = self.aggregated(); + let AggState::AggregatedList(s) = self.state.clone() else { + unreachable!() + }; + self.state = AggState::AggregatedScalar(s); + }, + AggState::AggregatedList(s) => { + self.state = AggState::AggregatedScalar(s); + }, + _ => unreachable!("should only be called in non-agg/list-agg state by aggregation.rs"), + } + } + + /// Get the aggregated version of the series. + pub fn aggregated(&mut self) -> Column { + // we clone, because we only want to call `self.groups()` if needed. + // self groups may instantiate new groups and thus can be expensive. + match self.state.clone() { + AggState::NotAggregated(s) => { + // The groups are determined lazily and in case of a flat/non-aggregated + // series we use the groups to aggregate the list + // because this is lazy, we first must to update the groups + // by calling .groups() + self.groups(); + #[cfg(debug_assertions)] + { + if self.groups.len() > s.len() { + polars_warn!( + "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by" + ) + } + } + + // SAFETY: + // groups are in bounds + let out = unsafe { s.agg_list(&self.groups) }; + self.state = AggState::AggregatedList(out.clone()); + + self.sorted = true; + self.update_groups = UpdateGroups::WithGroupsLen; + out + }, + AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(), + AggState::Literal(s) => { + self.groups(); + let rows = self.groups.len(); + let s = s.new_from_index(0, rows); + let out = s + .reshape_list(&[ + ReshapeDimension::new_dimension(rows as u64), + ReshapeDimension::Infer, + ]) + .unwrap(); + self.state = AggState::AggregatedList(out.clone()); + out.into_column() + }, + } + } + + /// Get the final aggregated version of the series. + pub fn finalize(&mut self) -> Column { + // we clone, because we only want to call `self.groups()` if needed. + // self groups may instantiate new groups and thus can be expensive. + match &self.state { + AggState::Literal(c) => { + let c = c.clone(); + self.groups(); + let rows = self.groups.len(); + c.new_from_index(0, rows) + }, + _ => self.aggregated(), + } + } + + // If a binary or ternary function has both of these branches true, it should + // flatten the list + fn arity_should_explode(&self) -> bool { + use AggState::*; + match self.agg_state() { + Literal(s) => s.len() == 1, + AggregatedScalar(_) => true, + _ => false, + } + } + + pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) { + let _ = self.groups(); + let groups = self.groups; + match self.state { + AggState::NotAggregated(c) => (c, groups), + AggState::AggregatedScalar(c) => (c, groups), + AggState::Literal(c) => (c, groups), + AggState::AggregatedList(c) => { + let flattened = c.explode().unwrap(); + let groups = groups.into_owned(); + // unroll the possible flattened state + // say we have groups with overlapping windows: + // + // offset, len + // 0, 1 + // 0, 2 + // 0, 4 + // + // gets aggregation + // + // [0] + // [0, 1], + // [0, 1, 2, 3] + // + // before aggregation the column was + // [0, 1, 2, 3] + // but explode on this list yields + // [0, 0, 1, 0, 1, 2, 3] + // + // so we unroll the groups as + // + // [0, 1] + // [1, 2] + // [3, 4] + let groups = groups.unroll(); + (flattened, Cow::Owned(groups)) + }, + } + } + + /// Get the not-aggregated version of the series. + /// Note that we call it naive, because if a previous expr + /// has filtered or sorted this, this information is in the + /// group tuples not the flattened series. + pub(crate) fn flat_naive(&self) -> Cow<'_, Column> { + match &self.state { + AggState::NotAggregated(c) => Cow::Borrowed(c), + AggState::AggregatedList(c) => { + #[cfg(debug_assertions)] + { + // panic so we find cases where we accidentally explode overlapping groups + // we don't want this as this can create a lot of data + if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() { + panic!( + "implementation error, polars should not hit this branch for overlapping groups" + ) + } + } + + Cow::Owned(c.explode().unwrap()) + }, + AggState::AggregatedScalar(c) => Cow::Borrowed(c), + AggState::Literal(c) => Cow::Borrowed(c), + } + } + + /// Take the series. + pub(crate) fn take(&mut self) -> Column { + let c = match &mut self.state { + AggState::NotAggregated(c) + | AggState::AggregatedScalar(c) + | AggState::AggregatedList(c) => c, + AggState::Literal(c) => c, + }; + std::mem::take(c) + } +} + +/// Take a DataFrame and evaluate the expressions. +/// Implement this for Column, lt, eq, etc +pub trait PhysicalExpr: Send + Sync { + fn as_expression(&self) -> Option<&Expr> { + None + } + + /// Take a DataFrame and evaluate the expression. + fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult; + + /// Attempt to cheaply evaluate this expression in-line without a DataFrame context. + /// This is used by StatsEvaluator when skipping files / row groups using a predicate. + /// TODO: Maybe in the future we can do this evaluation in-line at the optimizer stage? + /// + /// Do not implement this directly - instead implement `evaluate_inline_impl` + fn evaluate_inline(&self) -> Option { + self.evaluate_inline_impl(4) + } + + /// Implementation of `evaluate_inline` + fn evaluate_inline_impl(&self, _depth_limit: u8) -> Option { + None + } + + /// Some expression that are not aggregations can be done per group + /// Think of sort, slice, filter, shift, etc. + /// defaults to ignoring the group + /// + /// This method is called by an aggregation function. + /// + /// In case of a simple expr, like 'column', the groups are ignored and the column is returned. + /// In case of an expr where group behavior makes sense, this method is called. + /// For a filter operation for instance, a Series is created per groups and filtered. + /// + /// An implementation of this method may apply an aggregation on the groups only. For instance + /// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per + /// group. The implementation then has to return the `Series` exploded (because a later aggregation + /// will use the group tuples to aggregate). The group tuples also have to be updated, because + /// aggregation to a list sorts the exploded `Series` by group. + /// + /// This has some gotcha's. An implementation may also change the group tuples instead of + /// the `Series`. + /// + // we allow this because we pass the vec to the Cow + // Note to self: Don't be smart and dispatch to evaluate as default implementation + // this means filters will be incorrect and lead to invalid results down the line + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult>; + + /// Get the output field of this expr + fn to_field(&self, input_schema: &Schema) -> PolarsResult; + + /// Convert to a partitioned aggregator. + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + None + } + + fn is_literal(&self) -> bool { + false + } + fn is_scalar(&self) -> bool; +} + +impl Display for &dyn PhysicalExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.as_expression() { + None => Ok(()), + Some(e) => write!(f, "{e:?}"), + } + } +} + +/// Wrapper struct that allow us to use a PhysicalExpr in polars-io. +/// +/// This is used to filter rows during the scan of file. +pub struct PhysicalIoHelper { + pub expr: Arc, + pub has_window_function: bool, +} + +impl PhysicalIoExpr for PhysicalIoHelper { + fn evaluate_io(&self, df: &DataFrame) -> PolarsResult { + let mut state: ExecutionState = Default::default(); + if self.has_window_function { + state.insert_has_window_function_flag(); + } + self.expr + .evaluate(df, &state) + .map(|c| c.take_materialized_series()) + } +} + +pub fn phys_expr_to_io_expr(expr: Arc) -> Arc { + let has_window_function = if let Some(expr) = expr.as_expression() { + expr.into_iter() + .any(|expr| matches!(expr, Expr::Window { .. })) + } else { + false + }; + Arc::new(PhysicalIoHelper { + expr, + has_window_function, + }) as Arc +} + +pub trait PartitionedAggregation: Send + Sync + PhysicalExpr { + /// This is called in partitioned aggregation. + /// Partitioned results may differ from aggregation results. + /// For instance, for a `mean` operation a partitioned result + /// needs to return the `sum` and the `valid_count` (length - null count). + /// + /// A final aggregation can then take the sum of sums and sum of valid_counts + /// to produce a final mean. + #[allow(clippy::ptr_arg)] + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult; + + /// Called to merge all the partitioned results in a final aggregate. + #[allow(clippy::ptr_arg)] + fn finalize( + &self, + partitioned: Column, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult; +} diff --git a/crates/polars-expr/src/expressions/rolling.rs b/crates/polars-expr/src/expressions/rolling.rs new file mode 100644 index 000000000000..1adbee5f686f --- /dev/null +++ b/crates/polars-expr/src/expressions/rolling.rs @@ -0,0 +1,73 @@ +use polars_time::{PolarsTemporalGroupby, RollingGroupOptions}; + +use super::*; + +pub(crate) struct RollingExpr { + /// the root column that the Function will be applied on. + /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index + /// TODO! support keys? + /// The challenge is that the group_by will reorder the results and the + /// keys, and time index would need to be updated, or the result should be joined back + /// For now, don't support it. + /// + /// A function Expr. i.e. Mean, Median, Max, etc. + pub(crate) function: Expr, + pub(crate) phys_function: Arc, + pub(crate) out_name: Option, + pub(crate) options: RollingGroupOptions, + pub(crate) expr: Expr, +} + +impl PhysicalExpr for RollingExpr { + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let groups_key = format!("{:?}", &self.options); + + let groups = { + // Groups must be set by expression runner. + state.window_cache.get_groups(&groups_key).clone() + }; + + // There can be multiple rolling expressions in a single expr. + // E.g. `min().rolling() + max().rolling()` + // So if we hit that we will compute them here. + let groups = match groups { + Some(groups) => groups, + None => { + let (_time_key, groups) = df.rolling(None, &self.options)?; + state.window_cache.insert_groups(groups_key, groups.clone()); + groups + }, + }; + + let mut out = self + .phys_function + .evaluate_on_groups(df, &groups, state)? + .finalize(); + polars_ensure!(out.len() == groups.len(), agg_len = out.len(), groups.len()); + if let Some(name) = &self.out_name { + out.rename(name.clone()); + } + Ok(out.into_column()) + } + + fn evaluate_on_groups<'a>( + &self, + _df: &DataFrame, + _groups: &'a GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult> { + polars_bail!(InvalidOperation: "rolling expression not allowed in aggregation"); + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.function.to_field(input_schema, Context::Default) + } + + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn is_scalar(&self) -> bool { + false + } +} diff --git a/crates/polars-expr/src/expressions/slice.rs b/crates/polars-expr/src/expressions/slice.rs new file mode 100644 index 000000000000..92042b029a01 --- /dev/null +++ b/crates/polars-expr/src/expressions/slice.rs @@ -0,0 +1,279 @@ +use AnyValue::Null; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::utils::{CustomIterTools, slice_offsets}; +use rayon::prelude::*; + +use super::*; +use crate::expressions::{AggregationContext, PhysicalExpr}; + +pub struct SliceExpr { + pub(crate) input: Arc, + pub(crate) offset: Arc, + pub(crate) length: Arc, + pub(crate) expr: Expr, +} + +fn extract_offset(offset: &Column, expr: &Expr) -> PolarsResult { + polars_ensure!( + offset.len() <= 1, expr = expr, ComputeError: + "invalid argument to slice; expected an offset literal, got series of length {}", + offset.len() + ); + offset.get(0).unwrap().extract().ok_or_else( + || polars_err!(expr = expr, ComputeError: "unable to extract offset from {:?}", offset), + ) +} + +fn extract_length(length: &Column, expr: &Expr) -> PolarsResult { + polars_ensure!( + length.len() <= 1, expr = expr, ComputeError: + "invalid argument to slice; expected a length literal, got series of length {}", + length.len() + ); + match length.get(0).unwrap() { + Null => Ok(usize::MAX), + v => v.extract().ok_or_else( + || polars_err!(expr = expr, ComputeError: "unable to extract length from {:?}", length), + ), + } +} + +fn extract_args(offset: &Column, length: &Column, expr: &Expr) -> PolarsResult<(i64, usize)> { + Ok((extract_offset(offset, expr)?, extract_length(length, expr)?)) +} + +fn check_argument(arg: &Column, groups: &GroupsType, name: &str, expr: &Expr) -> PolarsResult<()> { + polars_ensure!( + !matches!(arg.dtype(), DataType::List(_)), expr = expr, ComputeError: + "invalid slice argument: cannot use an array as {} argument", name, + ); + polars_ensure!( + arg.len() == groups.len(), expr = expr, ComputeError: + "invalid slice argument: the evaluated length expression was \ + of different {} than the number of groups", name + ); + polars_ensure!( + arg.null_count() == 0, expr = expr, ComputeError: + "invalid slice argument: the {} expression has nulls", name + ); + Ok(()) +} + +fn slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem { + let (offset, len) = slice_offsets(offset, length, idx.len()); + + // If slice isn't out of bounds, we replace first. + // If slice is oob, the `idx` vec will be empty and `first` will be ignored + if let Some(f) = idx.get(offset) { + first = *f; + } + // This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day. + (first, idx[offset..offset + len].into()) +} + +fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] { + let (offset, len) = slice_offsets(offset, length, len as usize); + [first + offset as IdxSize, len as IdxSize] +} + +impl PhysicalExpr for SliceExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let results = POOL.install(|| { + [&self.offset, &self.length, &self.input] + .par_iter() + .map(|e| e.evaluate(df, state)) + .collect::>>() + })?; + let offset = &results[0]; + let length = &results[1]; + let series = &results[2]; + let (offset, length) = extract_args(offset, length, &self.expr)?; + + Ok(series.slice(offset, length)) + } + + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut results = POOL.install(|| { + [&self.offset, &self.length, &self.input] + .par_iter() + .map(|e| e.evaluate_on_groups(df, groups, state)) + .collect::>>() + })?; + let mut ac = results.pop().unwrap(); + + if let AggState::AggregatedScalar(_) = ac.agg_state() { + polars_bail!(InvalidOperation: "cannot slice() an aggregated scalar value") + } + + let mut ac_length = results.pop().unwrap(); + let mut ac_offset = results.pop().unwrap(); + + use AggState::*; + let groups = match (&ac_offset.state, &ac_length.state) { + (Literal(offset), Literal(length)) => { + let (offset, length) = extract_args(offset, length, &self.expr)?; + + if let Literal(s) = ac.agg_state() { + let s1 = s.slice(offset, length); + ac.with_literal(s1); + return Ok(ac); + } + let groups = ac.groups(); + + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + let groups = groups + .iter() + .map(|(first, idx)| slice_groups_idx(offset, length, first, idx)) + .collect(); + GroupsType::Idx(groups) + }, + GroupsType::Slice { groups, .. } => { + let groups = groups + .iter() + .map(|&[first, len]| slice_groups_slice(offset, length, first, len)) + .collect_trusted(); + GroupsType::Slice { + groups, + rolling: false, + } + }, + } + }, + (Literal(offset), _) => { + let groups = ac.groups(); + let offset = extract_offset(offset, &self.expr)?; + let length = ac_length.aggregated(); + check_argument(&length, groups, "length", &self.expr)?; + + let length = length.cast(&IDX_DTYPE)?; + let length = length.idx().unwrap(); + + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + let groups = groups + .iter() + .zip(length.into_no_null_iter()) + .map(|((first, idx), length)| { + slice_groups_idx(offset, length as usize, first, idx) + }) + .collect(); + GroupsType::Idx(groups) + }, + GroupsType::Slice { groups, .. } => { + let groups = groups + .iter() + .zip(length.into_no_null_iter()) + .map(|(&[first, len], length)| { + slice_groups_slice(offset, length as usize, first, len) + }) + .collect_trusted(); + GroupsType::Slice { + groups, + rolling: false, + } + }, + } + }, + (_, Literal(length)) => { + let groups = ac.groups(); + let length = extract_length(length, &self.expr)?; + let offset = ac_offset.aggregated(); + check_argument(&offset, groups, "offset", &self.expr)?; + + let offset = offset.cast(&DataType::Int64)?; + let offset = offset.i64().unwrap(); + + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + let groups = groups + .iter() + .zip(offset.into_no_null_iter()) + .map(|((first, idx), offset)| { + slice_groups_idx(offset, length, first, idx) + }) + .collect(); + GroupsType::Idx(groups) + }, + GroupsType::Slice { groups, .. } => { + let groups = groups + .iter() + .zip(offset.into_no_null_iter()) + .map(|(&[first, len], offset)| { + slice_groups_slice(offset, length, first, len) + }) + .collect_trusted(); + GroupsType::Slice { + groups, + rolling: false, + } + }, + } + }, + _ => { + let groups = ac.groups(); + let length = ac_length.aggregated(); + let offset = ac_offset.aggregated(); + check_argument(&length, groups, "length", &self.expr)?; + check_argument(&offset, groups, "offset", &self.expr)?; + + let offset = offset.cast(&DataType::Int64)?; + let offset = offset.i64().unwrap(); + + let length = length.cast(&IDX_DTYPE)?; + let length = length.idx().unwrap(); + + match groups.as_ref().as_ref() { + GroupsType::Idx(groups) => { + let groups = groups + .iter() + .zip(offset.into_no_null_iter()) + .zip(length.into_no_null_iter()) + .map(|(((first, idx), offset), length)| { + slice_groups_idx(offset, length as usize, first, idx) + }) + .collect(); + GroupsType::Idx(groups) + }, + GroupsType::Slice { groups, .. } => { + let groups = groups + .iter() + .zip(offset.into_no_null_iter()) + .zip(length.into_no_null_iter()) + .map(|((&[first, len], offset), length)| { + slice_groups_slice(offset, length as usize, first, len) + }) + .collect_trusted(); + GroupsType::Slice { + groups, + rolling: false, + } + }, + } + }, + }; + + ac.with_groups(groups.into_sliceable()) + .set_original_len(false); + + Ok(ac) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.input.to_field(input_schema) + } + + fn is_scalar(&self) -> bool { + false + } +} diff --git a/crates/polars-expr/src/expressions/sort.rs b/crates/polars-expr/src/expressions/sort.rs new file mode 100644 index 000000000000..56efce4ddb7d --- /dev/null +++ b/crates/polars-expr/src/expressions/sort.rs @@ -0,0 +1,115 @@ +use polars_core::POOL; +use polars_core::prelude::*; +use polars_ops::chunked_array::ListNameSpaceImpl; +use polars_utils::idx_vec::IdxVec; +use rayon::prelude::*; + +use super::*; +use crate::expressions::{AggState, AggregationContext, PhysicalExpr}; + +pub struct SortExpr { + pub(crate) physical_expr: Arc, + pub(crate) options: SortOptions, + expr: Expr, +} + +impl SortExpr { + pub fn new(physical_expr: Arc, options: SortOptions, expr: Expr) -> Self { + Self { + physical_expr, + options, + expr, + } + } +} + +/// Map arg_sort result back to the indices on the `GroupIdx` +pub(crate) fn map_sorted_indices_to_group_idx(sorted_idx: &IdxCa, idx: &[IdxSize]) -> IdxVec { + sorted_idx + .cont_slice() + .unwrap() + .iter() + .map(|&i| unsafe { *idx.get_unchecked(i as usize) }) + .collect() +} + +pub(crate) fn map_sorted_indices_to_group_slice(sorted_idx: &IdxCa, first: IdxSize) -> IdxVec { + sorted_idx + .cont_slice() + .unwrap() + .iter() + .map(|&i| i + first) + .collect() +} + +impl PhysicalExpr for SortExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let series = self.physical_expr.evaluate(df, state)?; + series.sort_with(self.options) + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?; + match ac.agg_state() { + AggState::AggregatedList(s) => { + let ca = s.list().unwrap(); + let out = ca.lst_sort(self.options)?; + ac.with_values(out.into_column(), true, Some(&self.expr))?; + }, + _ => { + let series = ac.flat_naive().into_owned(); + + let mut sort_options = self.options; + sort_options.multithreaded = false; + let groups = POOL.install(|| { + match ac.groups().as_ref().as_ref() { + GroupsType::Idx(groups) => { + groups + .par_iter() + .map(|(first, idx)| { + // SAFETY: group tuples are always in bounds. + let group = unsafe { series.take_slice_unchecked(idx) }; + + let sorted_idx = group.arg_sort(sort_options); + let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx); + (new_idx.first().copied().unwrap_or(first), new_idx) + }) + .collect() + }, + GroupsType::Slice { groups, .. } => groups + .par_iter() + .map(|&[first, len]| { + let group = series.slice(first as i64, len as usize); + let sorted_idx = group.arg_sort(sort_options); + let new_idx = map_sorted_indices_to_group_slice(&sorted_idx, first); + (new_idx.first().copied().unwrap_or(first), new_idx) + }) + .collect(), + } + }); + let groups = GroupsType::Idx(groups); + ac.with_groups(groups.into_sliceable()); + }, + } + + Ok(ac) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.physical_expr.to_field(input_schema) + } + + fn is_scalar(&self) -> bool { + false + } +} diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs new file mode 100644 index 000000000000..9e28a9755dd6 --- /dev/null +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -0,0 +1,410 @@ +use polars_core::POOL; +use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt; +use polars_core::prelude::*; +use polars_utils::idx_vec::IdxVec; +use rayon::prelude::*; + +use super::*; +use crate::expressions::{ + AggregationContext, PhysicalExpr, UpdateGroups, map_sorted_indices_to_group_idx, + map_sorted_indices_to_group_slice, +}; + +pub struct SortByExpr { + pub(crate) input: Arc, + pub(crate) by: Vec>, + pub(crate) expr: Expr, + pub(crate) sort_options: SortMultipleOptions, +} + +impl SortByExpr { + pub fn new( + input: Arc, + by: Vec>, + expr: Expr, + sort_options: SortMultipleOptions, + ) -> Self { + Self { + input, + by, + expr, + sort_options, + } + } +} + +fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec { + match (values.len(), by_len) { + // Equal length. + (n_rvalues, n) if n_rvalues == n => values.to_vec(), + // None given all false. + (0, n) => vec![false; n], + // Broadcast first. + (_, n) => vec![values[0]; n], + } +} + +static ERR_MSG: &str = "expressions in 'sort_by' produced a different number of groups"; + +fn check_groups(a: &GroupsType, b: &GroupsType) -> PolarsResult<()> { + polars_ensure!(a.iter().zip(b.iter()).all(|(a, b)| { + a.len() == b.len() + }), ComputeError: ERR_MSG); + Ok(()) +} + +pub(super) fn update_groups_sort_by( + groups: &GroupsType, + sort_by_s: &Series, + options: &SortOptions, +) -> PolarsResult { + // Will trigger a gather for every group, so rechunk before. + let sort_by_s = sort_by_s.rechunk(); + let groups = POOL.install(|| { + groups + .par_iter() + .map(|indicator| sort_by_groups_single_by(indicator, &sort_by_s, options)) + .collect::>() + })?; + + Ok(GroupsType::Idx(groups)) +} + +fn sort_by_groups_single_by( + indicator: GroupsIndicator, + sort_by_s: &Series, + options: &SortOptions, +) -> PolarsResult<(IdxSize, IdxVec)> { + let options = SortOptions { + descending: options.descending, + nulls_last: options.nulls_last, + // We are already in par iter. + multithreaded: false, + ..Default::default() + }; + let new_idx = match indicator { + GroupsIndicator::Idx((_, idx)) => { + // SAFETY: group tuples are always in bounds. + let group = unsafe { sort_by_s.take_slice_unchecked(idx) }; + + let sorted_idx = group.arg_sort(options); + map_sorted_indices_to_group_idx(&sorted_idx, idx) + }, + GroupsIndicator::Slice([first, len]) => { + let group = sort_by_s.slice(first as i64, len as usize); + let sorted_idx = group.arg_sort(options); + map_sorted_indices_to_group_slice(&sorted_idx, first) + }, + }; + let first = new_idx + .first() + .ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?; + + Ok((*first, new_idx)) +} + +fn sort_by_groups_no_match_single<'a>( + mut ac_in: AggregationContext<'a>, + mut ac_by: AggregationContext<'a>, + descending: bool, + expr: &Expr, +) -> PolarsResult> { + let s_in = ac_in.aggregated(); + let s_by = ac_by.aggregated(); + let mut s_in = s_in.list().unwrap().clone(); + let mut s_by = s_by.list().unwrap().clone(); + + let dtype = s_in.dtype().clone(); + let ca: PolarsResult = POOL.install(|| { + s_in.par_iter_indexed() + .zip(s_by.par_iter_indexed()) + .map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) { + (Some(s), Some(s_sort_by)) => { + polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression"); + let idx = s_sort_by.arg_sort(SortOptions { + descending, + // We are already in par iter. + multithreaded: false, + ..Default::default() + }); + Ok(Some(unsafe { s.take_unchecked(&idx) })) + }, + _ => Ok(None), + }) + .collect_ca_with_dtype(PlSmallStr::EMPTY, dtype) + }); + let c = ca?.with_name(s_in.name().clone()).into_column(); + ac_in.with_values(c, true, Some(expr))?; + Ok(ac_in) +} + +fn sort_by_groups_multiple_by( + indicator: GroupsIndicator, + sort_by_s: &[Series], + descending: &[bool], + nulls_last: &[bool], + multithreaded: bool, + maintain_order: bool, +) -> PolarsResult<(IdxSize, IdxVec)> { + let new_idx = match indicator { + GroupsIndicator::Idx((_first, idx)) => { + // SAFETY: group tuples are always in bounds. + let groups = sort_by_s + .iter() + .map(|s| unsafe { s.take_slice_unchecked(idx) }) + .map(Column::from) + .collect::>(); + + let options = SortMultipleOptions { + descending: descending.to_owned(), + nulls_last: nulls_last.to_owned(), + multithreaded, + maintain_order, + limit: None, + }; + + let sorted_idx = groups[0] + .as_materialized_series() + .arg_sort_multiple(&groups[1..], &options) + .unwrap(); + map_sorted_indices_to_group_idx(&sorted_idx, idx) + }, + GroupsIndicator::Slice([first, len]) => { + let groups = sort_by_s + .iter() + .map(|s| s.slice(first as i64, len as usize)) + .map(Column::from) + .collect::>(); + + let options = SortMultipleOptions { + descending: descending.to_owned(), + nulls_last: nulls_last.to_owned(), + multithreaded, + maintain_order, + limit: None, + }; + let sorted_idx = groups[0] + .as_materialized_series() + .arg_sort_multiple(&groups[1..], &options) + .unwrap(); + map_sorted_indices_to_group_slice(&sorted_idx, first) + }, + }; + let first = new_idx + .first() + .ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?; + + Ok((*first, new_idx)) +} + +impl PhysicalExpr for SortByExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let series_f = || self.input.evaluate(df, state); + if self.by.is_empty() { + // Sorting by 0 columns returns input unchanged. + return series_f(); + } + let (series, sorted_idx) = if self.by.len() == 1 { + let sorted_idx_f = || { + let s_sort_by = self.by[0].evaluate(df, state)?; + Ok(s_sort_by.arg_sort(SortOptions::from(&self.sort_options))) + }; + POOL.install(|| rayon::join(series_f, sorted_idx_f)) + } else { + let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len()); + let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len()); + + let sorted_idx_f = || { + let mut needs_broadcast = false; + let mut broadcast_length = 1; + + let mut s_sort_by = self + .by + .iter() + .enumerate() + .map(|(i, e)| { + let column = e.evaluate(df, state).map(|c| match c.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => c, + _ => c.to_physical_repr(), + })?; + + if column.len() == 1 && broadcast_length != 1 { + polars_ensure!( + e.is_scalar(), + ShapeMismatch: "non-scalar expression produces broadcasting column", + ); + + return Ok(column.new_from_index(0, broadcast_length)); + } + + if broadcast_length != column.len() { + polars_ensure!( + broadcast_length == 1, ShapeMismatch: + "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", + broadcast_length, column.len() + ); + + needs_broadcast |= i > 0; + broadcast_length = column.len(); + } + + Ok(column) + }) + .collect::>>()?; + + if needs_broadcast { + for c in s_sort_by.iter_mut() { + if c.len() != broadcast_length { + *c = c.new_from_index(0, broadcast_length); + } + } + } + + let options = self + .sort_options + .clone() + .with_order_descending_multi(descending) + .with_nulls_last_multi(nulls_last); + + s_sort_by[0] + .as_materialized_series() + .arg_sort_multiple(&s_sort_by[1..], &options) + }; + POOL.install(|| rayon::join(series_f, sorted_idx_f)) + }; + let (sorted_idx, series) = (sorted_idx?, series?); + polars_ensure!( + sorted_idx.len() == series.len(), + expr = self.expr, ShapeMismatch: + "`sort_by` produced different length ({}) than the Series that has to be sorted ({})", + sorted_idx.len(), series.len() + ); + + // SAFETY: sorted index are within bounds. + unsafe { Ok(series.take_unchecked(&sorted_idx)) } + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?; + let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len()); + let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len()); + + let mut ac_sort_by = self + .by + .iter() + .map(|e| e.evaluate_on_groups(df, groups, state)) + .collect::>>()?; + let mut sort_by_s = ac_sort_by + .iter() + .map(|c| { + let c = c.flat_naive(); + match c.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + c.as_materialized_series().clone() + }, + // @scalar-opt + // @partition-opt + _ => c.to_physical_repr().take_materialized_series(), + } + }) + .collect::>(); + + // A check up front to ensure the input expressions have the same number of total elements. + for sort_by_s in &sort_by_s { + polars_ensure!( + sort_by_s.len() == ac_in.flat_naive().len(), expr = self.expr, ComputeError: + "the expression in `sort_by` argument must result in the same length" + ); + } + + let ordered_by_group_operation = matches!( + ac_sort_by[0].update_groups, + UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen + ); + + let groups = if self.by.len() == 1 { + let mut ac_sort_by = ac_sort_by.pop().unwrap(); + + // The groups of the lhs of the expressions do not match the series values, + // we must take the slower path. + if !matches!(ac_in.update_groups, UpdateGroups::No) { + return sort_by_groups_no_match_single( + ac_in, + ac_sort_by, + self.sort_options.descending[0], + &self.expr, + ); + }; + + let sort_by_s = sort_by_s.pop().unwrap(); + let groups = ac_sort_by.groups(); + + let (check, groups) = POOL.join( + || check_groups(groups, ac_in.groups()), + || { + update_groups_sort_by( + groups, + &sort_by_s, + &SortOptions { + descending: descending[0], + nulls_last: nulls_last[0], + ..Default::default() + }, + ) + }, + ); + check?; + + groups? + } else { + let groups = ac_sort_by[0].groups(); + + let groups = POOL.install(|| { + groups + .par_iter() + .map(|indicator| { + sort_by_groups_multiple_by( + indicator, + &sort_by_s, + &descending, + &nulls_last, + self.sort_options.multithreaded, + self.sort_options.maintain_order, + ) + }) + .collect::>() + }); + GroupsType::Idx(groups?) + }; + + // If the rhs is already aggregated once, it is reordered by the + // group_by operation - we must ensure that we are as well. + if ordered_by_group_operation { + let s = ac_in.aggregated(); + ac_in.with_values(s.explode().unwrap(), false, None)?; + } + + ac_in.with_groups(groups.into_sliceable()); + Ok(ac_in) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.input.to_field(input_schema) + } + + fn is_scalar(&self) -> bool { + false + } +} diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs new file mode 100644 index 000000000000..a8b25d4fbdf6 --- /dev/null +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -0,0 +1,365 @@ +use polars_core::POOL; +use polars_core::prelude::*; +use polars_plan::prelude::*; + +use super::*; +use crate::expressions::{AggregationContext, PhysicalExpr}; + +pub struct TernaryExpr { + predicate: Arc, + truthy: Arc, + falsy: Arc, + expr: Expr, + // Can be expensive on small data to run literals in parallel. + run_par: bool, + returns_scalar: bool, +} + +impl TernaryExpr { + pub fn new( + predicate: Arc, + truthy: Arc, + falsy: Arc, + expr: Expr, + run_par: bool, + returns_scalar: bool, + ) -> Self { + Self { + predicate, + truthy, + falsy, + expr, + run_par, + returns_scalar, + } + } +} + +fn finish_as_iters<'a>( + mut ac_truthy: AggregationContext<'a>, + mut ac_falsy: AggregationContext<'a>, + mut ac_mask: AggregationContext<'a>, +) -> PolarsResult> { + let ca = ac_truthy + .iter_groups(false) + .zip(ac_falsy.iter_groups(false)) + .zip(ac_mask.iter_groups(false)) + .map(|((truthy, falsy), mask)| { + match (truthy, falsy, mask) { + (Some(truthy), Some(falsy), Some(mask)) => Some( + truthy + .as_ref() + .zip_with(mask.as_ref().bool()?, falsy.as_ref()), + ), + _ => None, + } + .transpose() + }) + .collect::>()? + .with_name(ac_truthy.get_values().name().clone()); + + // Aggregation leaves only a single chunk. + let arr = ca.downcast_iter().next().unwrap(); + let list_vals_len = arr.values().len(); + + let mut out = ca.into_column(); + if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() && + // Exploded list should be equal to groups length. + list_vals_len == ac_truthy.groups.len() + { + out = out.explode()? + } + + ac_truthy.with_values(out, true, None)?; + Ok(ac_truthy) +} + +impl PhysicalExpr for TernaryExpr { + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } + + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let mut state = state.split(); + // Don't cache window functions as they run in parallel. + state.remove_cache_window_flag(); + let mask_series = self.predicate.evaluate(df, &state)?; + let mask = mask_series.bool()?.clone(); + + let op_truthy = || self.truthy.evaluate(df, &state); + let op_falsy = || self.falsy.evaluate(df, &state); + let (truthy, falsy) = if self.run_par { + POOL.install(|| rayon::join(op_truthy, op_falsy)) + } else { + (op_truthy(), op_falsy()) + }; + let truthy = truthy?; + let falsy = falsy?; + + truthy.zip_with(&mask, &falsy) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.truthy.to_field(input_schema) + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let op_mask = || self.predicate.evaluate_on_groups(df, groups, state); + let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state); + let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state); + let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par { + POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy))) + } else { + (op_mask(), (op_truthy(), op_falsy())) + }; + + let mut ac_mask = ac_mask?; + let mut ac_truthy = ac_truthy?; + let mut ac_falsy = ac_falsy?; + + use AggState::*; + + // Check if there are any: + // - non-unit literals + // - AggregatedScalar or AggregatedList + let mut has_non_unit_literal = false; + let mut has_aggregated = false; + // If the length has changed then we must not apply on the flat values + // as ternary broadcasting is length-sensitive. + let mut non_aggregated_len_modified = false; + + for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() { + match ac.agg_state() { + Literal(s) => { + has_non_unit_literal = s.len() != 1; + + if has_non_unit_literal { + break; + } + }, + NotAggregated(_) => { + non_aggregated_len_modified |= !ac.original_len; + }, + AggregatedScalar(_) | AggregatedList(_) => { + has_aggregated = true; + }, + } + } + + if has_non_unit_literal { + // finish_as_iters for non-unit literals to avoid materializing the + // literal inputs per-group. + if state.verbose() { + eprintln!("ternary agg: finish as iters due to non-unit literal") + } + return finish_as_iters(ac_truthy, ac_falsy, ac_mask); + } + + if !has_aggregated && !non_aggregated_len_modified { + // Everything is flat (either NotAggregated or a unit literal). + if state.verbose() { + eprintln!("ternary agg: finish all not-aggregated or unit literal"); + } + + let out = ac_truthy + .get_values() + .zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?; + + for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() { + if matches!(ac.agg_state(), NotAggregated(_)) { + let ac_target = ac; + + return Ok(AggregationContext { + state: NotAggregated(out), + groups: ac_target.groups.clone(), + sorted: ac_target.sorted, + update_groups: ac_target.update_groups, + original_len: ac_target.original_len, + }); + } + } + + ac_truthy.with_agg_state(Literal(out)); + + return Ok(ac_truthy); + } + + for ac in [&mut ac_mask, &mut ac_truthy, &mut ac_falsy].into_iter() { + if matches!(ac.agg_state(), NotAggregated(_)) { + let _ = ac.aggregated(); + } + } + + // At this point the input agg states are one of the following: + // * `Literal` where `s.len() == 1` + // * `AggregatedList` + // * `AggregatedScalar` + + let mut non_literal_acs = Vec::<&AggregationContext>::with_capacity(3); + + // non_literal_acs will have at least 1 item because has_aggregated was + // true from above. + for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() { + if !matches!(ac.agg_state(), Literal(_)) { + non_literal_acs.push(ac); + } + } + + for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) { + if std::mem::discriminant(ac_l.agg_state()) != std::mem::discriminant(ac_r.agg_state()) + { + // Mix of AggregatedScalar and AggregatedList is done per group, + // as every row of the AggregatedScalar must be broadcasted to a + // list of the same length as the corresponding AggregatedList + // row. + if state.verbose() { + eprintln!( + "ternary agg: finish as iters due to mix of AggregatedScalar and AggregatedList" + ) + } + return finish_as_iters(ac_truthy, ac_falsy, ac_mask); + } + } + + // At this point, the possible combinations are: + // * mix of unit literals and AggregatedScalar + // * `zip_with` can be called directly with the series + // * mix of unit literals and AggregatedList + // * `zip_with` can be called with the flat values after the offsets + // have been checked for alignment + let ac_target = non_literal_acs.first().unwrap(); + + let agg_state_out = match ac_target.agg_state() { + AggregatedList(_) => { + // Ternary can be applied directly on the flattened series, + // given that their offsets have been checked to be equal. + if state.verbose() { + eprintln!("ternary agg: finish AggregatedList") + } + + for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) { + match (ac_l.agg_state(), ac_r.agg_state()) { + (AggregatedList(s_l), AggregatedList(s_r)) => { + let check = s_l.list().unwrap().offsets()?.as_slice() + == s_r.list().unwrap().offsets()?.as_slice(); + + polars_ensure!( + check, + ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation" + ); + }, + _ => unreachable!(), + } + } + + let truthy = if let AggregatedList(s) = ac_truthy.agg_state() { + s.list().unwrap().get_inner().into_column() + } else { + ac_truthy.get_values().clone() + }; + + let falsy = if let AggregatedList(s) = ac_falsy.agg_state() { + s.list().unwrap().get_inner().into_column() + } else { + ac_falsy.get_values().clone() + }; + + let mask = if let AggregatedList(s) = ac_mask.agg_state() { + s.list().unwrap().get_inner().into_column() + } else { + ac_mask.get_values().clone() + }; + + let out = truthy.zip_with(mask.bool()?, &falsy)?; + + // The output series is guaranteed to be aligned with expected + // offsets buffer of the result, so we construct the result + // ListChunked directly from the 2. + let out = out.rechunk(); + // @scalar-opt + // @partition-opt + let values = out.as_materialized_series().array_ref(0); + let offsets = ac_target.get_values().list().unwrap().offsets()?; + let inner_type = out.dtype(); + let dtype = LargeListArray::default_datatype(values.dtype().clone()); + + // SAFETY: offsets are correct. + let out = LargeListArray::new(dtype, offsets, values.clone(), None); + + let mut out = ListChunked::with_chunk(truthy.name().clone(), out); + unsafe { out.to_logical(inner_type.clone()) }; + + if ac_target.get_values().list().unwrap()._can_fast_explode() { + out.set_fast_explode(); + }; + + let out = out.into_column(); + + AggregatedList(out) + }, + AggregatedScalar(_) => { + if state.verbose() { + eprintln!("ternary agg: finish AggregatedScalar") + } + + let out = ac_truthy + .get_values() + .zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?; + AggregatedScalar(out) + }, + _ => { + unreachable!() + }, + }; + + Ok(AggregationContext { + state: agg_state_out, + groups: ac_target.groups.clone(), + sorted: ac_target.sorted, + update_groups: ac_target.update_groups, + original_len: ac_target.original_len, + }) + } + fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { + Some(self) + } + + fn is_scalar(&self) -> bool { + self.returns_scalar + } +} + +impl PartitionedAggregation for TernaryExpr { + fn evaluate_partitioned( + &self, + df: &DataFrame, + groups: &GroupPositions, + state: &ExecutionState, + ) -> PolarsResult { + let truthy = self.truthy.as_partitioned_aggregator().unwrap(); + let falsy = self.falsy.as_partitioned_aggregator().unwrap(); + let mask = self.predicate.as_partitioned_aggregator().unwrap(); + + let truthy = truthy.evaluate_partitioned(df, groups, state)?; + let falsy = falsy.evaluate_partitioned(df, groups, state)?; + let mask = mask.evaluate_partitioned(df, groups, state)?; + let mask = mask.bool()?.clone(); + + truthy.zip_with(&mask, &falsy) + } + + fn finalize( + &self, + partitioned: Column, + _groups: &GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult { + Ok(partitioned) + } +} diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs new file mode 100644 index 000000000000..a3ce20db6184 --- /dev/null +++ b/crates/polars-expr/src/expressions/window.rs @@ -0,0 +1,848 @@ +use std::fmt::Write; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::_split_offsets; +use polars_core::{POOL, downcast_as_macro_arg_physical}; +use polars_ops::frame::SeriesJoin; +use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys}; +use polars_ops::prelude::*; +use polars_plan::prelude::*; +use polars_utils::sort::perfect_sort; +use polars_utils::sync::SyncPtr; +use rayon::prelude::*; + +use super::*; + +pub struct WindowExpr { + /// the root column that the Function will be applied on. + /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index + pub(crate) group_by: Vec>, + pub(crate) order_by: Option<(Arc, SortOptions)>, + pub(crate) apply_columns: Vec, + pub(crate) out_name: Option, + /// A function Expr. i.e. Mean, Median, Max, etc. + pub(crate) function: Expr, + pub(crate) phys_function: Arc, + pub(crate) mapping: WindowMapping, + pub(crate) expr: Expr, +} + +#[cfg_attr(debug_assertions, derive(Debug))] +enum MapStrategy { + // Join by key, this the most expensive + // for reduced aggregations + Join, + // explode now + Explode, + // Use an arg_sort to map the values back + Map, + Nothing, +} + +impl WindowExpr { + fn map_list_agg_by_arg_sort( + &self, + out_column: Column, + flattened: &Column, + mut ac: AggregationContext, + gb: GroupBy, + ) -> PolarsResult { + // idx (new-idx, original-idx) + let mut idx_mapping = Vec::with_capacity(out_column.len()); + + // we already set this buffer so we can reuse the `original_idx` buffer + // that saves an allocation + let mut take_idx = vec![]; + + // groups are not changed, we can map by doing a standard arg_sort. + if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) { + let mut iter = 0..flattened.len() as IdxSize; + match ac.groups().as_ref().as_ref() { + GroupsType::Idx(groups) => { + for g in groups.all() { + idx_mapping.extend(g.iter().copied().zip(&mut iter)); + } + }, + GroupsType::Slice { groups, .. } => { + for &[first, len] in groups { + idx_mapping.extend((first..first + len).zip(&mut iter)); + } + }, + } + } + // groups are changed, we use the new group indexes as arguments of the arg_sort + // and sort by the old indexes + else { + let mut original_idx = Vec::with_capacity(out_column.len()); + match gb.get_groups().as_ref() { + GroupsType::Idx(groups) => { + for g in groups.all() { + original_idx.extend_from_slice(g) + } + }, + GroupsType::Slice { groups, .. } => { + for &[first, len] in groups { + original_idx.extend(first..first + len) + } + }, + }; + + let mut original_idx_iter = original_idx.iter().copied(); + + match ac.groups().as_ref().as_ref() { + GroupsType::Idx(groups) => { + for g in groups.all() { + idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter)); + } + }, + GroupsType::Slice { groups, .. } => { + for &[first, len] in groups { + idx_mapping.extend((first..first + len).zip(&mut original_idx_iter)); + } + }, + } + original_idx.clear(); + take_idx = original_idx; + } + // SAFETY: + // we only have unique indices ranging from 0..len + unsafe { perfect_sort(&POOL, &idx_mapping, &mut take_idx) }; + Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx)) + } + + #[allow(clippy::too_many_arguments)] + fn map_by_arg_sort( + &self, + df: &DataFrame, + out_column: Column, + flattened: &Column, + mut ac: AggregationContext, + group_by_columns: &[Column], + gb: GroupBy, + cache_key: String, + state: &ExecutionState, + ) -> PolarsResult { + // we use an arg_sort to map the values back + + // This is a bit more complicated because the final group tuples may differ from the original + // so we use the original indices as idx values to arg_sort the original column + // + // The example below shows the naive version without group tuple mapping + + // columns + // a b a a + // + // agg list + // [0, 2, 3] + // [1] + // + // flatten + // + // [0, 2, 3, 1] + // + // arg_sort + // + // [0, 3, 1, 2] + // + // take by arg_sorted indexes and voila groups mapped + // [0, 1, 2, 3] + + if flattened.len() != df.height() { + let ca = out_column.list().unwrap(); + let non_matching_group = + ca.into_iter() + .zip(ac.groups().iter()) + .find(|(output, group)| { + if let Some(output) = output { + output.as_ref().len() != group.len() + } else { + false + } + }); + + if let Some((output, group)) = non_matching_group { + let first = group.first(); + let group = group_by_columns + .iter() + .map(|s| format!("{}", s.get(first as usize).unwrap())) + .collect::>(); + polars_bail!( + expr = self.expr, ComputeError: + "the length of the window expression did not match that of the group\ + \n> group: {}\n> group length: {}\n> output: '{:?}'", + comma_delimited(String::new(), &group), group.len(), output.unwrap() + ); + } else { + polars_bail!( + expr = self.expr, ComputeError: + "the length of the window expression did not match that of the group" + ); + }; + } + + let idx = if state.cache_window() { + if let Some(idx) = state.window_cache.get_map(&cache_key) { + idx + } else { + let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?); + state.window_cache.insert_map(cache_key, idx.clone()); + idx + } + } else { + Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?) + }; + + // SAFETY: + // groups should always be in bounds. + unsafe { Ok(flattened.take_unchecked(&idx)) } + } + + fn run_aggregation<'a>( + &self, + df: &DataFrame, + state: &ExecutionState, + gb: &'a GroupBy, + ) -> PolarsResult> { + let ac = self + .phys_function + .evaluate_on_groups(df, gb.get_groups(), state)?; + Ok(ac) + } + + fn is_explicit_list_agg(&self) -> bool { + // col("foo").implode() + // col("foo").implode().alias() + // .. + // col("foo").implode().alias().alias() + // + // but not: + // col("foo").implode().sum().alias() + // .. + // col("foo").min() + let mut explicit_list = false; + for e in &self.expr { + if let Expr::Window { function, .. } = e { + // or list().alias + let mut finishes_list = false; + for e in &**function { + match e { + Expr::Agg(AggExpr::Implode(_)) => { + finishes_list = true; + }, + Expr::Alias(_, _) => {}, + _ => break, + } + } + explicit_list = finishes_list; + } + } + + explicit_list + } + + fn is_simple_column_expr(&self) -> bool { + // col() + // or col().alias() + let mut simple_col = false; + for e in &self.expr { + if let Expr::Window { function, .. } = e { + // or list().alias + for e in &**function { + match e { + Expr::Column(_) => { + simple_col = true; + }, + Expr::Alias(_, _) => {}, + _ => break, + } + } + } + } + simple_col + } + + fn is_aggregation(&self) -> bool { + // col() + // or col().agg() + let mut agg_col = false; + for e in &self.expr { + if let Expr::Window { function, .. } = e { + // or list().alias + for e in &**function { + match e { + Expr::Agg(_) => { + agg_col = true; + }, + Expr::Alias(_, _) => {}, + _ => break, + } + } + } + } + agg_col + } + + /// Check if the branches have an aggregation + /// when(a > sum) + /// then (foo) + /// otherwise(bar - sum) + fn has_different_group_sources(&self) -> bool { + let mut has_arity = false; + let mut agg_col = false; + for e in &self.expr { + if let Expr::Window { function, .. } = e { + // or list().alias + for e in &**function { + match e { + Expr::Ternary { .. } | Expr::BinaryExpr { .. } => { + has_arity = true; + }, + Expr::Alias(_, _) => {}, + Expr::Agg(_) => { + agg_col = true; + }, + Expr::Function { options, .. } + | Expr::AnonymousFunction { options, .. } => { + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) + && matches!(options.collect_groups, ApplyOptions::GroupWise) + { + agg_col = true; + } + }, + _ => {}, + } + } + } + } + has_arity && agg_col + } + + fn determine_map_strategy( + &self, + agg_state: &AggState, + gb: &GroupBy, + ) -> PolarsResult { + match (self.mapping, agg_state) { + // Explode + // `(col("x").sum() * col("y")).list().over("groups").flatten()` + (WindowMapping::Explode, _) => Ok(MapStrategy::Explode), + // // explicit list + // // `(col("x").sum() * col("y")).list().over("groups")` + // (false, false, _) => Ok(MapStrategy::Join), + // aggregations + //`sum("foo").over("groups")` + (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join), + // no explicit aggregations, map over the groups + //`(col("x").sum() * col("y")).over("groups")` + (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join), + // no explicit aggregations, map over the groups + //`(col("x").sum() * col("y")).over("groups")` + (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => { + if let GroupsType::Slice { .. } = gb.get_groups().as_ref() { + // Result can be directly exploded if the input was sorted. + Ok(MapStrategy::Explode) + } else { + Ok(MapStrategy::Map) + } + }, + // no aggregations, just return column + // or an aggregation that has been flattened + // we have to check which one + //`col("foo").over("groups")` + (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => { + // col() + // or col().alias() + if self.is_simple_column_expr() { + Ok(MapStrategy::Nothing) + } else { + Ok(MapStrategy::Map) + } + }, + (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join), + // literals, do nothing and let broadcast + (_, AggState::Literal(_)) => Ok(MapStrategy::Nothing), + } + } +} + +// Utility to create partitions and cache keys +pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) { + write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap(); +} + +impl PhysicalExpr for WindowExpr { + // Note: this was first implemented with expression evaluation but this performed really bad. + // Therefore we choose the group_by -> apply -> self join approach + + // This first cached the group_by and the join tuples, but rayon under a mutex leads to deadlocks: + // https://github.com/rayon-rs/rayon/issues/592 + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + // This method does the following: + // 1. determine group_by tuples based on the group_column + // 2. apply an aggregation function + // 3. join the results back to the original dataframe + // this stores all group values on the original df size + // + // we have several strategies for this + // - 3.1 JOIN + // Use a join for aggregations like + // `sum("foo").over("groups")` + // and explicit `list` aggregations + // `(col("x").sum() * col("y")).list().over("groups")` + // + // - 3.2 EXPLODE + // Explicit list aggregations that are followed by `over().flatten()` + // # the fastest method to do things over groups when the groups are sorted. + // # note that it will require an explicit `list()` call from now on. + // `(col("x").sum() * col("y")).list().over("groups").flatten()` + // + // - 3.3. MAP to original locations + // This will be done for list aggregations that are not explicitly aggregated as list + // `(col("x").sum() * col("y")).over("groups") + // This can be used to reverse, sort, shuffle etc. the values in a group + + // 4. select the final column and return + + if df.is_empty() { + let field = self.phys_function.to_field(df.schema())?; + match self.mapping { + WindowMapping::Join => { + return Ok(Column::full_null( + field.name().clone(), + 0, + &DataType::List(Box::new(field.dtype().clone())), + )); + }, + _ => { + return Ok(Column::full_null(field.name().clone(), 0, field.dtype())); + }, + } + } + + let group_by_columns = self + .group_by + .iter() + .map(|e| e.evaluate(df, state)) + .collect::>>()?; + + // if the keys are sorted + let sorted_keys = group_by_columns.iter().all(|s| { + matches!( + s.is_sorted_flag(), + IsSorted::Ascending | IsSorted::Descending + ) + }); + let explicit_list_agg = self.is_explicit_list_agg(); + + // if we flatten this column we need to make sure the groups are sorted. + let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) || + // if not + // `col().over()` + // and not + // `col().list().over` + // and not + // `col().sum()` + // and keys are sorted + // we may optimize with explode call + (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation()); + + // overwrite sort_groups for some expressions + // TODO: fully understand the rationale is here. + if self.has_different_group_sources() { + sort_groups = true + } + + let create_groups = || { + let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?; + let mut groups = gb.take_groups(); + + if let Some((order_by, options)) = &self.order_by { + let order_by = order_by.evaluate(df, state)?; + polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height()); + groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)? + .into_sliceable() + } + + let out: PolarsResult = Ok(groups); + out + }; + + // Try to get cached grouptuples + let (mut groups, cache_key) = if state.cache_window() { + let mut cache_key = String::with_capacity(32 * group_by_columns.len()); + write!(&mut cache_key, "{}", state.branch_idx).unwrap(); + for s in &group_by_columns { + cache_key.push_str(s.name()); + } + if let Some((e, options)) = &self.order_by { + let e = match e.as_expression() { + Some(e) => e, + None => { + polars_bail!(InvalidOperation: "cannot order by this expression in window function") + }, + }; + window_function_format_order_by(&mut cache_key, e, options) + } + + let groups = match state.window_cache.get_groups(&cache_key) { + Some(groups) => groups, + None => create_groups()?, + }; + (groups, cache_key) + } else { + (create_groups()?, "".to_string()) + }; + + // 2. create GroupBy object and apply aggregation + let apply_columns = self.apply_columns.clone(); + + // some window expressions need sorted groups + // to make sure that the caches align we sort + // the groups, so that the cached groups and join keys + // are consistent among all windows + if sort_groups || state.cache_window() { + groups.sort(); + state + .window_cache + .insert_groups(cache_key.clone(), groups.clone()); + } + let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns)); + + // If the aggregation creates categoricals and `MapStrategy` is `Join`, + // the string cache was needed. So we hold it for that case. + // Worst case is that a categorical is created with indexes from the string + // cache which is fine, as the physical representation is undefined. + #[cfg(feature = "dtype-categorical")] + let _sc = polars_core::StringCacheHolder::hold(); + let mut ac = self.run_aggregation(df, state, &gb)?; + + use MapStrategy::*; + match self.determine_map_strategy(ac.agg_state(), &gb)? { + Nothing => { + let mut out = ac.flat_naive().into_owned(); + + if ac.is_literal() { + out = out.new_from_index(0, df.height()) + } + if let Some(name) = &self.out_name { + out.rename(name.clone()); + } + Ok(out.into_column()) + }, + Explode => { + let mut out = ac.aggregated().explode()?; + if let Some(name) = &self.out_name { + out.rename(name.clone()); + } + Ok(out.into_column()) + }, + Map => { + // TODO! + // investigate if sorted arrays can be return directly + let out_column = ac.aggregated(); + let flattened = out_column.explode()?; + // we extend the lifetime as we must convince the compiler that ac lives + // long enough. We drop `GrouBy` when we are done with `ac`. + let ac = unsafe { + std::mem::transmute::, AggregationContext<'static>>(ac) + }; + self.map_by_arg_sort( + df, + out_column, + &flattened, + ac, + &group_by_columns, + gb, + cache_key, + state, + ) + }, + Join => { + let out_column = ac.aggregated(); + // we try to flatten/extend the array by repeating the aggregated value n times + // where n is the number of members in that group. That way we can try to reuse + // the same map by arg_sort logic as done for listed aggregations + let update_groups = !matches!(&ac.update_groups, UpdateGroups::No); + match ( + &ac.update_groups, + set_by_groups(&out_column, &ac, df.height(), update_groups), + ) { + // for aggregations that reduce like sum, mean, first and are numeric + // we take the group locations to directly map them to the right place + (UpdateGroups::No, Some(out)) => Ok(out.into_column()), + (_, _) => { + let keys = gb.keys(); + + let get_join_tuples = || { + if group_by_columns.len() == 1 { + let mut left = group_by_columns[0].clone(); + // group key from right column + let mut right = keys[0].clone(); + + let (left, right) = if left.dtype().is_nested() { + ( + ChunkedArray::::with_chunk( + "".into(), + row_encode::_get_rows_encoded_unordered(&[ + left.clone() + ])? + .into_array(), + ) + .into_series(), + ChunkedArray::::with_chunk( + "".into(), + row_encode::_get_rows_encoded_unordered(&[ + right.clone() + ])? + .into_array(), + ) + .into_series(), + ) + } else { + ( + left.into_materialized_series().clone(), + right.into_materialized_series().clone(), + ) + }; + + PolarsResult::Ok(Arc::new( + left.hash_join_left(&right, JoinValidation::ManyToMany, true) + .unwrap() + .1, + )) + } else { + let df_right = + unsafe { DataFrame::new_no_checks_height_from_first(keys) }; + let df_left = unsafe { + DataFrame::new_no_checks_height_from_first(group_by_columns) + }; + Ok(Arc::new( + private_left_join_multiple_keys(&df_left, &df_right, true)?.1, + )) + } + }; + + // try to get cached join_tuples + let join_opt_ids = if state.cache_window() { + if let Some(jt) = state.window_cache.get_join(&cache_key) { + jt + } else { + let jt = get_join_tuples()?; + state + .window_cache + .insert_join(cache_key.clone(), jt.clone()); + jt + } + } else { + get_join_tuples()? + }; + + let mut out = materialize_column(&join_opt_ids, &out_column); + + if let Some(name) = &self.out_name { + out.rename(name.clone()); + } + + Ok(out.into_column()) + }, + } + }, + } + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.function.to_field(input_schema, Context::Default) + } + + fn is_scalar(&self) -> bool { + false + } + + #[allow(clippy::ptr_arg)] + fn evaluate_on_groups<'a>( + &self, + _df: &DataFrame, + _groups: &'a GroupPositions, + _state: &ExecutionState, + ) -> PolarsResult> { + polars_bail!(InvalidOperation: "window expression not allowed in aggregation"); + } + + fn as_expression(&self) -> Option<&Expr> { + Some(&self.expr) + } +} + +fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column { + { + use arrow::Either; + use polars_ops::chunked_array::TakeChunked; + + match join_opt_ids { + Either::Left(ids) => unsafe { + IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx)) + }, + Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) }, + } + } +} + +/// Simple reducing aggregation can be set by the groups +fn set_by_groups( + s: &Column, + ac: &AggregationContext, + len: usize, + update_groups: bool, +) -> Option { + if update_groups || !ac.original_len { + return None; + } + if s.dtype().to_physical().is_primitive_numeric() { + let dtype = s.dtype(); + let s = s.to_physical_repr(); + + macro_rules! dispatch { + ($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }}; + } + downcast_as_macro_arg_physical!(&s, dispatch) + .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap()) + .map(Column::from) + } else { + None + } +} + +fn set_numeric(ca: &ChunkedArray, groups: &GroupsType, len: usize) -> Series +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let mut values = Vec::with_capacity(len); + let ptr: *mut T::Native = values.as_mut_ptr(); + // SAFETY: + // we will write from different threads but we will never alias. + let sync_ptr_values = unsafe { SyncPtr::new(ptr) }; + + if ca.null_count() == 0 { + let ca = ca.rechunk(); + match groups { + GroupsType::Idx(groups) => { + let agg_vals = ca.cont_slice().expect("rechunked"); + POOL.install(|| { + agg_vals + .par_iter() + .zip(groups.all().par_iter()) + .for_each(|(v, g)| { + let ptr = sync_ptr_values.get(); + for idx in g.as_slice() { + debug_assert!((*idx as usize) < len); + unsafe { *ptr.add(*idx as usize) = *v } + } + }) + }) + }, + GroupsType::Slice { groups, .. } => { + let agg_vals = ca.cont_slice().expect("rechunked"); + POOL.install(|| { + agg_vals + .par_iter() + .zip(groups.par_iter()) + .for_each(|(v, [start, g_len])| { + let ptr = sync_ptr_values.get(); + let start = *start as usize; + let end = start + *g_len as usize; + for idx in start..end { + debug_assert!(idx < len); + unsafe { *ptr.add(idx) = *v } + } + }) + }); + }, + } + + // SAFETY: we have written all slots + unsafe { values.set_len(len) } + ChunkedArray::new_vec(ca.name().clone(), values).into_series() + } else { + // We don't use a mutable bitmap as bits will have race conditions! + // A single byte might alias if we write from single threads. + let mut validity: Vec = vec![false; len]; + let validity_ptr = validity.as_mut_ptr(); + let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) }; + + let n_threads = POOL.current_num_threads(); + let offsets = _split_offsets(ca.len(), n_threads); + + match groups { + GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| { + let offset = *offset; + let offset_len = *offset_len; + let ca = ca.slice(offset as i64, offset_len); + let groups = &groups.all()[offset..offset + offset_len]; + let values_ptr = sync_ptr_values.get(); + let validity_ptr = sync_ptr_validity.get(); + + ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| { + for idx in g.as_slice() { + let idx = *idx as usize; + debug_assert!(idx < len); + unsafe { + match opt_v { + Some(v) => { + *values_ptr.add(idx) = v; + *validity_ptr.add(idx) = true; + }, + None => { + *values_ptr.add(idx) = T::Native::default(); + *validity_ptr.add(idx) = false; + }, + }; + } + } + }) + }), + GroupsType::Slice { groups, .. } => { + offsets.par_iter().for_each(|(offset, offset_len)| { + let offset = *offset; + let offset_len = *offset_len; + let ca = ca.slice(offset as i64, offset_len); + let groups = &groups[offset..offset + offset_len]; + let values_ptr = sync_ptr_values.get(); + let validity_ptr = sync_ptr_validity.get(); + + for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) { + let start = *start as usize; + let end = start + *g_len as usize; + for idx in start..end { + debug_assert!(idx < len); + unsafe { + match opt_v { + Some(v) => { + *values_ptr.add(idx) = v; + *validity_ptr.add(idx) = true; + }, + None => { + *values_ptr.add(idx) = T::Native::default(); + *validity_ptr.add(idx) = false; + }, + }; + } + } + } + }) + }, + } + // SAFETY: we have written all slots + unsafe { values.set_len(len) } + let validity = Bitmap::from(validity); + let arr = PrimitiveArray::new( + T::get_dtype().to_physical().to_arrow(CompatLevel::newest()), + values.into(), + Some(validity), + ); + Series::try_from((ca.name().clone(), arr.boxed())).unwrap() + } +} diff --git a/crates/polars-expr/src/groups/binview.rs b/crates/polars-expr/src/groups/binview.rs new file mode 100644 index 000000000000..cf5c674c6f30 --- /dev/null +++ b/crates/polars-expr/src/groups/binview.rs @@ -0,0 +1,275 @@ +use arrow::array::{Array, BinaryViewArrayGeneric, View, ViewType}; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::buffer::Buffer; +use polars_compute::binview_index_map::{BinaryViewIndexMap, Entry}; + +use super::*; +use crate::hash_keys::HashKeys; + +#[derive(Default)] +pub struct BinviewHashGrouper { + name: PlSmallStr, + dtype: DataType, + idx_map: BinaryViewIndexMap<()>, + null_idx: IdxSize, +} + +impl BinviewHashGrouper { + pub fn new(name: PlSmallStr, dtype: DataType) -> Self { + Self { + name, + dtype, + idx_map: BinaryViewIndexMap::default(), + null_idx: IdxSize::MAX, + } + } + + /// # Safety + /// The view must be valid for the given buffer set. + #[inline(always)] + unsafe fn insert_key(&mut self, hash: u64, view: View, buffers: &Arc<[Buffer]>) -> IdxSize { + unsafe { + match self.idx_map.entry_view(hash, view, buffers) { + Entry::Occupied(o) => o.index(), + Entry::Vacant(v) => { + let index = v.index(); + v.insert(()); + index + }, + } + } + } + + #[inline(always)] + fn insert_null(&mut self) -> IdxSize { + if self.null_idx == IdxSize::MAX { + self.null_idx = self.idx_map.push_unmapped_empty_entry(()); + } + self.null_idx + } + + /// # Safety + /// The view must be valid for the given buffer set. + #[inline(always)] + unsafe fn contains_key(&self, hash: u64, view: &View, buffers: &Arc<[Buffer]>) -> bool { + unsafe { self.idx_map.get_view(hash, view, buffers).is_some() } + } + + #[inline(always)] + fn contains_null(&self) -> bool { + self.null_idx < IdxSize::MAX + } + + /// # Safety + /// The views must be valid for the given buffers. + unsafe fn finalize_keys( + &self, + views: Buffer, + buffers: Arc<[Buffer]>, + validity: Option, + ) -> DataFrame { + unsafe { + let arrow_dtype = self.dtype.to_arrow(CompatLevel::newest()); + let keys = BinaryViewArrayGeneric::::new_unchecked_unknown_md( + arrow_dtype, + views, + buffers, + validity, + None, + ); + let s = Series::from_chunks_and_dtype_unchecked( + self.name.clone(), + vec![Box::new(keys)], + &self.dtype, + ); + DataFrame::new(vec![Column::from(s)]).unwrap() + } + } +} + +impl Grouper for BinviewHashGrouper { + fn new_empty(&self) -> Box { + Box::new(Self::new(self.name.clone(), self.dtype.clone())) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_groups(&self) -> IdxSize { + self.idx_map.len() + } + + unsafe fn insert_keys_subset( + &mut self, + hash_keys: &HashKeys, + subset: &[IdxSize], + group_idxs: Option<&mut Vec>, + ) { + let HashKeys::Binview(hash_keys) = hash_keys else { + unreachable!() + }; + + unsafe { + let views = hash_keys.keys.views().as_slice(); + let buffers = hash_keys.keys.data_buffers(); + if let Some(validity) = hash_keys.keys.validity() { + if hash_keys.null_is_valid { + let groups = subset.iter().map(|idx| { + if validity.get_bit_unchecked(*idx as usize) { + let hash = hash_keys.hashes.value_unchecked(*idx as usize); + let view = views.get_unchecked(*idx as usize); + self.insert_key(hash, *view, buffers) + } else { + self.insert_null() + } + }); + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + group_idxs.extend(groups); + } else { + groups.for_each(drop); + } + } else { + let groups = subset.iter().filter_map(|idx| { + if validity.get_bit_unchecked(*idx as usize) { + let hash = hash_keys.hashes.value_unchecked(*idx as usize); + let view = views.get_unchecked(*idx as usize); + Some(self.insert_key(hash, *view, buffers)) + } else { + None + } + }); + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + group_idxs.extend(groups); + } else { + groups.for_each(drop); + } + } + } else { + let groups = subset.iter().map(|idx| { + let hash = hash_keys.hashes.value_unchecked(*idx as usize); + let view = views.get_unchecked(*idx as usize); + self.insert_key(hash, *view, buffers) + }); + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + group_idxs.extend(groups); + } else { + groups.for_each(drop); + } + } + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + let buffers: Arc<[_]> = self + .idx_map + .buffers() + .iter() + .map(|b| Buffer::from(b.to_vec())) + .collect(); + let views = self.idx_map.iter_hash_views().map(|(_h, v)| v).collect(); + let validity = if self.null_idx < IdxSize::MAX { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.idx_map.len() as usize, true); + validity.set(self.null_idx as usize, false); + Some(validity.freeze()) + } else { + None + }; + + unsafe { + match &self.dtype { + DataType::Binary => self.finalize_keys::<[u8]>(views, buffers, validity), + DataType::String => self.finalize_keys::(views, buffers, validity), + _ => unreachable!(), + } + } + } + + /// # Safety + /// All groupers must be a BinviewHashGrouper. + unsafe fn probe_partitioned_groupers( + &self, + groupers: &[Box], + hash_keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + probe_matches: &mut Vec, + ) { + let HashKeys::Binview(hash_keys) = hash_keys else { + unreachable!() + }; + assert!(partitioner.num_partitions() == groupers.len()); + + unsafe { + let null_p = partitioner.null_partition(); + let buffers = hash_keys.keys.data_buffers(); + let views = hash_keys.keys.views().as_slice(); + hash_keys.for_each_hash(|idx, opt_h| { + let has_group = if let Some(h) = opt_h { + let p = partitioner.hash_to_partition(h); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const BinviewHashGrouper); + let view = views.get_unchecked(idx as usize); + grouper.contains_key(h, view, buffers) + } else { + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const BinviewHashGrouper); + grouper.contains_null() + }; + + if has_group != invert { + probe_matches.push(idx); + } + }); + } + } + + /// # Safety + /// All groupers must be a BinviewHashGrouper. + unsafe fn contains_key_partitioned_groupers( + &self, + groupers: &[Box], + hash_keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + contains_key: &mut BitmapBuilder, + ) { + let HashKeys::Binview(hash_keys) = hash_keys else { + unreachable!() + }; + assert!(partitioner.num_partitions() == groupers.len()); + + unsafe { + let null_p = partitioner.null_partition(); + let buffers = hash_keys.keys.data_buffers(); + let views = hash_keys.keys.views().as_slice(); + hash_keys.for_each_hash(|idx, opt_h| { + let has_group = if let Some(h) = opt_h { + let p = partitioner.hash_to_partition(h); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const BinviewHashGrouper); + let view = views.get_unchecked(idx as usize); + grouper.contains_key(h, view, buffers) + } else { + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const BinviewHashGrouper); + grouper.contains_null() + }; + + contains_key.push(has_group != invert); + }); + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/groups/mod.rs b/crates/polars-expr/src/groups/mod.rs new file mode 100644 index 000000000000..d818dd8e6022 --- /dev/null +++ b/crates/polars-expr/src/groups/mod.rs @@ -0,0 +1,116 @@ +use std::any::Any; + +use arrow::bitmap::BitmapBuilder; +use polars_core::prelude::*; +use polars_utils::IdxSize; +use polars_utils::hashing::HashPartitioner; + +use crate::hash_keys::HashKeys; + +mod binview; +mod row_encoded; +mod single_key; + +/// A Grouper maps keys to groups, such that duplicate keys map to the same group. +pub trait Grouper: Any + Send + Sync { + /// Creates a new empty Grouper similar to this one. + fn new_empty(&self) -> Box; + + /// Reserves space for the given number additional groups. + fn reserve(&mut self, additional: usize); + + /// Returns the number of groups in this Grouper. + fn num_groups(&self) -> IdxSize; + + /// Inserts the given subset of keys into this Grouper. If groups_idxs is + /// passed it is extended such with the group index of keys[subset[i]]. + /// + /// # Safety + /// The subset indexes must be in-bounds. + unsafe fn insert_keys_subset( + &mut self, + keys: &HashKeys, + subset: &[IdxSize], + group_idxs: Option<&mut Vec>, + ); + + /// Returns the keys in this Grouper in group order, that is the key for + /// group i is returned in row i. + fn get_keys_in_group_order(&self) -> DataFrame; + + /// Returns the (indices of the) keys found in the groupers. If + /// invert is true it instead returns the keys not found in the groupers. + /// # Safety + /// All groupers must have the same schema. + unsafe fn probe_partitioned_groupers( + &self, + groupers: &[Box], + keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + probe_matches: &mut Vec, + ); + + /// Returns for each key if it is found in the groupers. If invert is true + /// it returns true if it isn't found. + /// # Safety + /// All groupers must have the same schema. + unsafe fn contains_key_partitioned_groupers( + &self, + groupers: &[Box], + keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + contains_key: &mut BitmapBuilder, + ); + + fn as_any(&self) -> &dyn Any; +} + +pub fn new_hash_grouper(key_schema: Arc) -> Box { + if key_schema.len() > 1 { + Box::new(row_encoded::RowEncodedHashGrouper::new(key_schema)) + } else { + use single_key::SingleKeyHashGrouper as SK; + let (name, dt) = key_schema.get_at_index(0).unwrap(); + let (name, dt) = (name.clone(), dt.clone()); + match dt { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => Box::new(SK::::new(name, dt)), + DataType::UInt32 => Box::new(SK::::new(name, dt)), + DataType::UInt64 => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => Box::new(SK::::new(name, dt)), + DataType::Int32 => Box::new(SK::::new(name, dt)), + DataType::Int64 => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => Box::new(SK::::new(name, dt)), + DataType::Float32 => Box::new(SK::::new(name, dt)), + DataType::Float64 => Box::new(SK::::new(name, dt)), + + #[cfg(feature = "dtype-date")] + DataType::Date => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-time")] + DataType::Time => Box::new(SK::::new(name, dt)), + + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => Box::new(SK::::new(name, dt)), + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) => Box::new(SK::::new(name, dt)), + + DataType::String | DataType::Binary => { + Box::new(binview::BinviewHashGrouper::new(name, dt)) + }, + + _ => Box::new(row_encoded::RowEncodedHashGrouper::new(key_schema)), + } + } +} diff --git a/crates/polars-expr/src/groups/row_encoded.rs b/crates/polars-expr/src/groups/row_encoded.rs new file mode 100644 index 000000000000..0b8d59ae06a5 --- /dev/null +++ b/crates/polars-expr/src/groups/row_encoded.rs @@ -0,0 +1,219 @@ +use arrow::array::Array; +use polars_row::RowEncodingOptions; +use polars_utils::idx_map::bytes_idx_map::{BytesIndexMap, Entry}; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; + +use self::row_encode::get_row_encoding_context; +use super::*; +use crate::hash_keys::HashKeys; + +#[derive(Default)] +pub struct RowEncodedHashGrouper { + key_schema: Arc, + idx_map: BytesIndexMap<()>, +} + +impl RowEncodedHashGrouper { + pub fn new(key_schema: Arc) -> Self { + Self { + key_schema, + idx_map: BytesIndexMap::new(), + } + } + + fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize { + match self.idx_map.entry(hash, key) { + Entry::Occupied(o) => o.index(), + Entry::Vacant(v) => { + let index = v.index(); + v.insert(()); + index + }, + } + } + + fn contains_key(&self, hash: u64, key: &[u8]) -> bool { + self.idx_map.contains_key(hash, key) + } + + fn finalize_keys(&self, mut key_rows: Vec<&[u8]>) -> DataFrame { + let key_dtypes = self + .key_schema + .iter() + .map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest())) + .collect::>(); + let ctxts = self + .key_schema + .iter() + .map(|(_, dt)| get_row_encoding_context(dt, false)) + .collect::>(); + let fields = vec![RowEncodingOptions::new_unsorted(); key_dtypes.len()]; + let key_columns = + unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &ctxts, &key_dtypes) }; + + let cols = self + .key_schema + .iter() + .zip(key_columns) + .map(|((name, dt), col)| { + let s = Series::try_from((name.clone(), col)).unwrap(); + unsafe { s.from_physical_unchecked(dt) } + .unwrap() + .into_column() + }) + .collect(); + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +impl Grouper for RowEncodedHashGrouper { + fn new_empty(&self) -> Box { + Box::new(Self::new(self.key_schema.clone())) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_groups(&self) -> IdxSize { + self.idx_map.len() + } + + unsafe fn insert_keys_subset( + &mut self, + keys: &HashKeys, + subset: &[IdxSize], + group_idxs: Option<&mut Vec>, + ) { + let HashKeys::RowEncoded(keys) = keys else { + unreachable!() + }; + + unsafe { + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + keys.for_each_hash_subset(subset, |idx, opt_hash| { + if let Some(hash) = opt_hash { + let key = keys.keys.value_unchecked(idx as usize); + group_idxs.push_unchecked(self.insert_key(hash, key)); + } + }); + } else { + keys.for_each_hash_subset(subset, |idx, opt_hash| { + if let Some(hash) = opt_hash { + let key = keys.keys.value_unchecked(idx as usize); + self.insert_key(hash, key); + } + }); + } + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + unsafe { + let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.idx_map.len() as usize); + for (_, key) in self.idx_map.iter_hash_keys() { + key_rows.push_unchecked(key); + } + self.finalize_keys(key_rows) + } + } + + /// # Safety + /// All groupers must be a RowEncodedHashGrouper. + unsafe fn probe_partitioned_groupers( + &self, + groupers: &[Box], + keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + probe_matches: &mut Vec, + ) { + let HashKeys::RowEncoded(keys) = keys else { + unreachable!() + }; + assert!(partitioner.num_partitions() == groupers.len()); + + unsafe { + if keys.keys.has_nulls() { + for (idx, hash) in keys.hashes.values_iter().enumerate_idx() { + let has_group = if let Some(key) = keys.keys.get_unchecked(idx as usize) { + let p = partitioner.hash_to_partition(*hash); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper); + grouper.contains_key(*hash, key) + } else { + false + }; + + if has_group != invert { + probe_matches.push(idx); + } + } + } else { + for (idx, (hash, key)) in keys + .hashes + .values_iter() + .zip(keys.keys.values_iter()) + .enumerate_idx() + { + let p = partitioner.hash_to_partition(*hash); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper); + if grouper.contains_key(*hash, key) != invert { + probe_matches.push(idx); + } + } + } + } + } + + /// # Safety + /// All groupers must be a RowEncodedHashGrouper. + unsafe fn contains_key_partitioned_groupers( + &self, + groupers: &[Box], + keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + contains_key: &mut BitmapBuilder, + ) { + let HashKeys::RowEncoded(keys) = keys else { + unreachable!() + }; + assert!(partitioner.num_partitions() == groupers.len()); + + unsafe { + if keys.keys.has_nulls() { + for (idx, hash) in keys.hashes.values_iter().enumerate_idx() { + let has_group = if let Some(key) = keys.keys.get_unchecked(idx as usize) { + let p = partitioner.hash_to_partition(*hash); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper); + grouper.contains_key(*hash, key) + } else { + false + }; + + contains_key.push(has_group != invert); + } + } else { + for (hash, key) in keys.hashes.values_iter().zip(keys.keys.values_iter()) { + let p = partitioner.hash_to_partition(*hash); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const RowEncodedHashGrouper); + contains_key.push(grouper.contains_key(*hash, key) != invert); + } + } + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/groups/single_key.rs b/crates/polars-expr/src/groups/single_key.rs new file mode 100644 index 000000000000..26e9fac45b4e --- /dev/null +++ b/crates/polars-expr/src/groups/single_key.rs @@ -0,0 +1,250 @@ +use arrow::array::Array; +use arrow::bitmap::MutableBitmap; +use polars_utils::idx_map::total_idx_map::{Entry, TotalIndexMap}; +use polars_utils::total_ord::{TotalEq, TotalHash}; +use polars_utils::vec::PushUnchecked; + +use super::*; +use crate::hash_keys::{HashKeys, for_each_hash_single}; + +#[derive(Default)] +pub struct SingleKeyHashGrouper { + name: PlSmallStr, + dtype: DataType, + idx_map: TotalIndexMap, ()>, + null_idx: IdxSize, +} + +impl SingleKeyHashGrouper +where + for<'a> T: PolarsDataType = K>, + K: Default + TotalHash + TotalEq, +{ + pub fn new(name: PlSmallStr, dtype: DataType) -> Self { + Self { + name, + dtype, + idx_map: TotalIndexMap::default(), + null_idx: IdxSize::MAX, + } + } + + #[inline(always)] + fn insert_key(&mut self, key: T::Physical<'static>) -> IdxSize { + match self.idx_map.entry(key) { + Entry::Occupied(o) => o.index(), + Entry::Vacant(v) => { + let index = v.index(); + v.insert(()); + index + }, + } + } + + #[inline(always)] + fn insert_null(&mut self) -> IdxSize { + if self.null_idx == IdxSize::MAX { + self.null_idx = self.idx_map.push_unmapped_entry(T::Physical::default(), ()); + } + self.null_idx + } + + #[inline(always)] + fn contains_key(&self, key: &T::Physical<'static>) -> bool { + self.idx_map.get(key).is_some() + } + + #[inline(always)] + fn contains_null(&self) -> bool { + self.null_idx < IdxSize::MAX + } + + fn finalize_keys(&self, keys: Vec>) -> DataFrame { + let mut keys = T::Array::from_vec( + keys, + self.dtype.to_physical().to_arrow(CompatLevel::newest()), + ); + if self.null_idx < IdxSize::MAX { + let mut validity = MutableBitmap::new(); + validity.extend_constant(keys.len(), true); + validity.set(self.null_idx as usize, false); + keys = keys.with_validity_typed(Some(validity.freeze())); + } + unsafe { + let s = Series::from_chunks_and_dtype_unchecked( + self.name.clone(), + vec![Box::new(keys)], + &self.dtype, + ); + DataFrame::new(vec![Column::from(s)]).unwrap() + } + } +} + +impl Grouper for SingleKeyHashGrouper +where + for<'a> T: PolarsDataType = K>, + K: Default + TotalHash + TotalEq + Clone + Send + Sync + 'static, +{ + fn new_empty(&self) -> Box { + Box::new(Self::new(self.name.clone(), self.dtype.clone())) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_groups(&self) -> IdxSize { + self.idx_map.len() + } + + unsafe fn insert_keys_subset( + &mut self, + hash_keys: &HashKeys, + subset: &[IdxSize], + group_idxs: Option<&mut Vec>, + ) { + let HashKeys::Single(hash_keys) = hash_keys else { + unreachable!() + }; + let ca: &ChunkedArray = hash_keys.keys.as_phys_any().downcast_ref().unwrap(); + let arr = ca.downcast_as_array(); + + unsafe { + if arr.has_nulls() { + if hash_keys.null_is_valid { + let groups = subset.iter().map(|idx| { + let opt_k = arr.get_unchecked(*idx as usize); + if let Some(k) = opt_k { + self.insert_key(k) + } else { + self.insert_null() + } + }); + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + group_idxs.extend(groups); + } else { + groups.for_each(drop); + } + } else { + let groups = subset.iter().filter_map(|idx| { + let opt_k = arr.get_unchecked(*idx as usize); + opt_k.map(|k| self.insert_key(k)) + }); + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + group_idxs.extend(groups); + } else { + groups.for_each(drop); + } + } + } else { + let groups = subset.iter().map(|idx| { + let k = arr.value_unchecked(*idx as usize); + self.insert_key(k) + }); + if let Some(group_idxs) = group_idxs { + group_idxs.reserve(subset.len()); + group_idxs.extend(groups); + } else { + groups.for_each(drop); + } + } + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + unsafe { + let mut key_rows = Vec::with_capacity(self.idx_map.len() as usize); + for key in self.idx_map.iter_keys() { + key_rows.push_unchecked(key.clone()); + } + self.finalize_keys(key_rows) + } + } + + /// # Safety + /// All groupers must be a SingleKeyHashGrouper. + unsafe fn probe_partitioned_groupers( + &self, + groupers: &[Box], + hash_keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + probe_matches: &mut Vec, + ) { + let HashKeys::Single(hash_keys) = hash_keys else { + unreachable!() + }; + let ca: &ChunkedArray = hash_keys.keys.as_phys_any().downcast_ref().unwrap(); + let arr = ca.downcast_as_array(); + assert!(partitioner.num_partitions() == groupers.len()); + + unsafe { + let null_p = partitioner.null_partition(); + for_each_hash_single(ca, &hash_keys.random_state, |idx, opt_h| { + let has_group = if let Some(h) = opt_h { + let p = partitioner.hash_to_partition(h); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper); + let key = arr.value_unchecked(idx as usize); + grouper.contains_key(&key) + } else { + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper); + grouper.contains_null() + }; + + if has_group != invert { + probe_matches.push(idx); + } + }); + } + } + + /// # Safety + /// All groupers must be a SingleKeyHashGrouper. + unsafe fn contains_key_partitioned_groupers( + &self, + groupers: &[Box], + hash_keys: &HashKeys, + partitioner: &HashPartitioner, + invert: bool, + contains_key: &mut BitmapBuilder, + ) { + let HashKeys::Single(hash_keys) = hash_keys else { + unreachable!() + }; + let ca: &ChunkedArray = hash_keys.keys.as_phys_any().downcast_ref().unwrap(); + let arr = ca.downcast_as_array(); + assert!(partitioner.num_partitions() == groupers.len()); + + unsafe { + let null_p = partitioner.null_partition(); + for_each_hash_single(ca, &hash_keys.random_state, |idx, opt_h| { + let has_group = if let Some(h) = opt_h { + let p = partitioner.hash_to_partition(h); + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper); + let key = arr.value_unchecked(idx as usize); + grouper.contains_key(&key) + } else { + let dyn_grouper: &dyn Grouper = &**groupers.get_unchecked(null_p); + let grouper = + &*(dyn_grouper as *const dyn Grouper as *const SingleKeyHashGrouper); + grouper.contains_null() + }; + + contains_key.push(has_group != invert); + }); + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/hash_keys.rs b/crates/polars-expr/src/hash_keys.rs new file mode 100644 index 000000000000..1b52357f5021 --- /dev/null +++ b/crates/polars-expr/src/hash_keys.rs @@ -0,0 +1,550 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::hash::BuildHasher; + +use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array}; +use arrow::bitmap::Bitmap; +use arrow::compute::utils::combine_validities_and_many; +use polars_core::error::polars_err; +use polars_core::frame::DataFrame; +use polars_core::prelude::row_encode::_get_rows_encoded_unordered; +use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType}; +use polars_core::series::Series; +use polars_utils::IdxSize; +use polars_utils::cardinality_sketch::CardinalitySketch; +use polars_utils::hashing::HashPartitioner; +use polars_utils::itertools::Itertools; +use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash}; +use polars_utils::vec::PushUnchecked; + +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub enum HashKeysVariant { + RowEncoded, + Single, + Binview, +} + +pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant { + match dt { + dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single, + + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => HashKeysVariant::Single, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) => HashKeysVariant::Single, + + DataType::String | DataType::Binary => HashKeysVariant::Binview, + + // TODO: more efficient encoding for these. + DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded, + + _ => HashKeysVariant::RowEncoded, + } +} + +macro_rules! downcast_single_key_ca { + ( + $self:expr, | $ca:ident | $($body:tt)* + ) => {{ + #[allow(unused_imports)] + use polars_core::datatypes::DataType::*; + match $self.dtype() { + #[cfg(feature = "dtype-i8")] + DataType::Int8 => { let $ca = $self.i8().unwrap(); $($body)* }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { let $ca = $self.i16().unwrap(); $($body)* }, + DataType::Int32 => { let $ca = $self.i32().unwrap(); $($body)* }, + DataType::Int64 => { let $ca = $self.i64().unwrap(); $($body)* }, + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => { let $ca = $self.u8().unwrap(); $($body)* }, + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => { let $ca = $self.u16().unwrap(); $($body)* }, + DataType::UInt32 => { let $ca = $self.u32().unwrap(); $($body)* }, + DataType::UInt64 => { let $ca = $self.u64().unwrap(); $($body)* }, + #[cfg(feature = "dtype-i128")] + DataType::Int128 => { let $ca = $self.i128().unwrap(); $($body)* }, + DataType::Float32 => { let $ca = $self.f32().unwrap(); $($body)* }, + DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* }, + + #[cfg(feature = "dtype-date")] + DataType::Date => { let $ca = $self.date().unwrap(); $($body)* }, + #[cfg(feature = "dtype-time")] + DataType::Time => { let $ca = $self.time().unwrap(); $($body)* }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(..) => { let $ca = $self.datetime().unwrap(); $($body)* }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(..) => { let $ca = $self.duration().unwrap(); $($body)* }, + + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(..) => { let $ca = $self.decimal().unwrap(); $($body)* }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(..) => { let $ca = $self.categorical().unwrap().physical(); $($body)* }, + + _ => unreachable!(), + } + }} +} +pub(crate) use downcast_single_key_ca; + +/// Represents a DataFrame plus a hash per row, intended for keys in grouping +/// or joining. The hashes may or may not actually be physically pre-computed, +/// this depends per type. +#[derive(Clone, Debug)] +pub enum HashKeys { + RowEncoded(RowEncodedKeys), + Binview(BinviewKeys), + Single(SingleKeys), +} + +impl HashKeys { + pub fn from_df( + df: &DataFrame, + random_state: PlRandomState, + null_is_valid: bool, + force_row_encoding: bool, + ) -> Self { + let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype()); + let use_row_encoding = force_row_encoding + || df.width() > 1 + || first_col_variant == HashKeysVariant::RowEncoded; + if use_row_encoding { + let keys = df.get_columns(); + #[cfg(feature = "dtype-categorical")] + for key in keys { + if let DataType::Categorical(Some(rev_map), _) = key.dtype() { + assert!( + rev_map.is_active_global(), + "{}", + polars_err!(string_cache_mismatch) + ); + } + } + let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array(); + + if !null_is_valid { + let validities = keys + .iter() + .map(|c| c.as_materialized_series().rechunk_validity()) + .collect_vec(); + let combined = combine_validities_and_many(&validities); + keys_encoded.set_validity(combined); + } + + // TODO: use vechash? Not supported yet for lists. + // let mut hashes = Vec::with_capacity(df.height()); + // columns_to_hashes(df.get_columns(), Some(random_state), &mut hashes).unwrap(); + + let hashes = keys_encoded + .values_iter() + .map(|k| random_state.hash_one(k)) + .collect(); + Self::RowEncoded(RowEncodedKeys { + hashes: PrimitiveArray::from_vec(hashes), + keys: keys_encoded, + }) + } else if first_col_variant == HashKeysVariant::Binview { + let keys = if let Ok(ca_str) = df[0].str() { + ca_str.as_binary() + } else { + df[0].binary().unwrap().clone() + }; + let keys = keys.rechunk().downcast_as_array().clone(); + + let hashes = if keys.has_nulls() { + keys.iter() + .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0)) + .collect() + } else { + keys.values_iter() + .map(|k| random_state.hash_one(k)) + .collect() + }; + + Self::Binview(BinviewKeys { + hashes: PrimitiveArray::from_vec(hashes), + keys, + null_is_valid, + }) + } else { + Self::Single(SingleKeys { + random_state, + keys: df[0].as_materialized_series().rechunk(), + null_is_valid, + }) + } + } + + pub fn len(&self) -> usize { + match self { + HashKeys::RowEncoded(s) => s.keys.len(), + HashKeys::Single(s) => s.keys.len(), + HashKeys::Binview(s) => s.keys.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn validity(&self) -> Option<&Bitmap> { + match self { + HashKeys::RowEncoded(s) => s.keys.validity(), + HashKeys::Single(s) => s.keys.chunks()[0].validity(), + HashKeys::Binview(s) => s.keys.validity(), + } + } + + pub fn null_is_valid(&self) -> bool { + match self { + HashKeys::RowEncoded(_) => false, + HashKeys::Single(s) => s.null_is_valid, + HashKeys::Binview(s) => s.null_is_valid, + } + } + + /// Calls f with the index of and hash of each element in this HashKeys. + /// + /// If the element is null and null_is_valid is false the respective hash + /// will be None. + pub fn for_each_hash)>(&self, f: F) { + match self { + HashKeys::RowEncoded(s) => s.for_each_hash(f), + HashKeys::Single(s) => s.for_each_hash(f), + HashKeys::Binview(s) => s.for_each_hash(f), + } + } + + /// Calls f with the index of and hash of each element in the given + /// subset of indices of the HashKeys. + /// + /// If the element is null and null_is_valid is false the respective hash + /// will be None. + /// + /// # Safety + /// The indices in the subset must be in-bounds. + pub unsafe fn for_each_hash_subset)>( + &self, + subset: &[IdxSize], + f: F, + ) { + match self { + HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f), + HashKeys::Single(s) => s.for_each_hash_subset(subset, f), + HashKeys::Binview(s) => s.for_each_hash_subset(subset, f), + } + } + + /// After this call partitions will be extended with the partition for each + /// hash. Nulls are assigned IdxSize::MAX or a specific partition depending + /// on whether partition_nulls is true. + pub fn gen_partitions( + &self, + partitioner: &HashPartitioner, + partitions: &mut Vec, + partition_nulls: bool, + ) { + unsafe { + let null_p = if partition_nulls | self.null_is_valid() { + partitioner.null_partition() as IdxSize + } else { + IdxSize::MAX + }; + partitions.reserve(self.len()); + self.for_each_hash(|_idx, opt_h| { + partitions.push_unchecked( + opt_h + .map(|h| partitioner.hash_to_partition(h) as IdxSize) + .unwrap_or(null_p), + ); + }); + } + } + + /// After this call partition_idxs[p] will be extended with the indices of + /// hashes that belong to partition p, and the cardinality sketches are + /// updated accordingly. + pub fn gen_idxs_per_partition( + &self, + partitioner: &HashPartitioner, + partition_idxs: &mut [Vec], + sketches: &mut [CardinalitySketch], + partition_nulls: bool, + ) { + if sketches.is_empty() { + self.gen_idxs_per_partition_impl::( + partitioner, + partition_idxs, + sketches, + partition_nulls | self.null_is_valid(), + ); + } else { + self.gen_idxs_per_partition_impl::( + partitioner, + partition_idxs, + sketches, + partition_nulls | self.null_is_valid(), + ); + } + } + + fn gen_idxs_per_partition_impl( + &self, + partitioner: &HashPartitioner, + partition_idxs: &mut [Vec], + sketches: &mut [CardinalitySketch], + partition_nulls: bool, + ) { + assert!(partition_idxs.len() == partitioner.num_partitions()); + assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions()); + + let null_p = partitioner.null_partition(); + self.for_each_hash(|idx, opt_h| { + if let Some(h) = opt_h { + unsafe { + // SAFETY: we assured the number of partitions matches. + let p = partitioner.hash_to_partition(h); + partition_idxs.get_unchecked_mut(p).push(idx); + if BUILD_SKETCHES { + sketches.get_unchecked_mut(p).insert(h); + } + } + } else if partition_nulls { + unsafe { + partition_idxs.get_unchecked_mut(null_p).push(idx); + } + } + }); + } + + pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) { + self.for_each_hash(|_idx, opt_h| { + sketch.insert(opt_h.unwrap_or(0)); + }) + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self { + match self { + HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)), + HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)), + HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)), + } + } +} + +#[derive(Clone, Debug)] +pub struct RowEncodedKeys { + pub hashes: UInt64Array, // Always non-null, we use the validity of keys. + pub keys: BinaryArray, +} + +impl RowEncodedKeys { + pub fn for_each_hash)>(&self, f: F) { + for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f); + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn for_each_hash_subset)>( + &self, + subset: &[IdxSize], + f: F, + ) { + for_each_hash_subset_prehashed( + self.hashes.values().as_slice(), + self.keys.validity(), + subset, + f, + ); + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self { + let idx_arr = arrow::ffi::mmap::slice(idxs); + Self { + hashes: polars_compute::gather::primitive::take_primitive_unchecked( + &self.hashes, + &idx_arr, + ), + keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr), + } + } +} + +/// Single keys without prehashing. +#[derive(Clone, Debug)] +pub struct SingleKeys { + pub random_state: PlRandomState, + pub keys: Series, + pub null_is_valid: bool, +} + +impl SingleKeys { + pub fn for_each_hash)>(&self, f: F) { + downcast_single_key_ca!(self.keys, |keys| { + for_each_hash_single(keys, &self.random_state, f); + }) + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn for_each_hash_subset)>( + &self, + subset: &[IdxSize], + f: F, + ) { + downcast_single_key_ca!(self.keys, |keys| { + for_each_hash_subset_single(keys, subset, &self.random_state, f); + }) + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self { + Self { + random_state: self.random_state, + keys: self.keys.take_slice_unchecked(idxs), + null_is_valid: self.null_is_valid, + } + } +} + +/// Pre-hashed binary view keys with prehashing. +#[derive(Clone, Debug)] +pub struct BinviewKeys { + pub hashes: UInt64Array, + pub keys: BinaryViewArray, + pub null_is_valid: bool, +} + +impl BinviewKeys { + pub fn for_each_hash)>(&self, f: F) { + for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f); + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn for_each_hash_subset)>( + &self, + subset: &[IdxSize], + f: F, + ) { + for_each_hash_subset_prehashed( + self.hashes.values().as_slice(), + self.keys.validity(), + subset, + f, + ); + } + + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self { + let idx_arr = arrow::ffi::mmap::slice(idxs); + Self { + hashes: polars_compute::gather::primitive::take_primitive_unchecked( + &self.hashes, + &idx_arr, + ), + keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr), + null_is_valid: self.null_is_valid, + } + } +} + +fn for_each_hash_prehashed)>( + hashes: &[u64], + opt_v: Option<&Bitmap>, + mut f: F, +) { + if let Some(validity) = opt_v { + for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() { + if is_v { + f(idx, Some(*hash)) + } else { + f(idx, None) + } + } + } else { + for (idx, h) in hashes.iter().enumerate_idx() { + f(idx, Some(*h)); + } + } +} + +/// # Safety +/// The indices must be in-bounds. +unsafe fn for_each_hash_subset_prehashed)>( + hashes: &[u64], + opt_v: Option<&Bitmap>, + subset: &[IdxSize], + mut f: F, +) { + if let Some(validity) = opt_v { + for idx in subset { + let hash = *hashes.get_unchecked(*idx as usize); + let is_v = validity.get_bit_unchecked(*idx as usize); + if is_v { + f(*idx, Some(hash)) + } else { + f(*idx, None) + } + } + } else { + for idx in subset { + f(*idx, Some(*hashes.get_unchecked(*idx as usize))); + } + } +} + +pub fn for_each_hash_single(keys: &ChunkedArray, random_state: &PlRandomState, mut f: F) +where + T: PolarsDataType, + for<'a> ::Physical<'a>: TotalHash, + F: FnMut(IdxSize, Option), +{ + let mut idx = 0; + if keys.has_nulls() { + for arr in keys.downcast_iter() { + for opt_k in arr.iter() { + f(idx, opt_k.map(|k| random_state.tot_hash_one(k))); + idx += 1; + } + } + } else { + for arr in keys.downcast_iter() { + for k in arr.values_iter() { + f(idx, Some(random_state.tot_hash_one(k))); + idx += 1; + } + } + } +} + +/// # Safety +/// The indices must be in-bounds. +unsafe fn for_each_hash_subset_single( + keys: &ChunkedArray, + subset: &[IdxSize], + random_state: &PlRandomState, + mut f: F, +) where + T: PolarsDataType, + for<'a> ::Physical<'a>: TotalHash, + F: FnMut(IdxSize, Option), +{ + let keys_arr = keys.downcast_as_array(); + + if keys_arr.has_nulls() { + for idx in subset { + let opt_k = keys_arr.get_unchecked(*idx as usize); + f(*idx, opt_k.map(|k| random_state.tot_hash_one(k))); + } + } else { + for idx in subset { + let k = keys_arr.value_unchecked(*idx as usize); + f(*idx, Some(random_state.tot_hash_one(k))); + } + } +} diff --git a/crates/polars-expr/src/hot_groups/binview.rs b/crates/polars-expr/src/hot_groups/binview.rs new file mode 100644 index 000000000000..82570c4fd8af --- /dev/null +++ b/crates/polars-expr/src/hot_groups/binview.rs @@ -0,0 +1,200 @@ +use arrow::array::builder::StaticArrayBuilder; +use arrow::array::{BinaryViewArrayGenericBuilder, PrimitiveArray, View}; +use arrow::bitmap::MutableBitmap; +use arrow::buffer::Buffer; +use polars_utils::vec::PushUnchecked; + +use super::*; +use crate::hash_keys::BinviewKeys; +use crate::hot_groups::fixed_index_table::FixedIndexTable; + +pub struct BinviewHashHotGrouper { + // The views in this table when not inline are stored in the vec. + table: FixedIndexTable<(u64, View, Vec)>, + evicted_key_hashes: Vec, + evicted_keys: BinaryViewArrayGenericBuilder<[u8]>, + null_idx: IdxSize, +} + +impl BinviewHashHotGrouper { + pub fn new(max_groups: usize) -> Self { + Self { + table: FixedIndexTable::new(max_groups.try_into().unwrap()), + evicted_key_hashes: Vec::new(), + evicted_keys: BinaryViewArrayGenericBuilder::new(ArrowDataType::BinaryView), + null_idx: IdxSize::MAX, + } + } + + /// # Safety + /// The view must be valid for the given buffer set. + #[inline(always)] + unsafe fn insert_key( + &mut self, + hash: u64, + view: View, + buffers: &Arc<[Buffer]>, + ) -> Option { + unsafe { + let mut evict = |ev_h: &u64, ev_view: &View, ev_buffer: &Vec| { + self.evicted_key_hashes.push(*ev_h); + if ev_view.is_inline() { + self.evicted_keys.push_inline_view_ignore_validity(*ev_view); + } else { + self.evicted_keys + .push_value_ignore_validity(ev_buffer.as_slice()); + } + }; + if view.is_inline() { + self.table.insert_key( + hash, + (), + |_, b| view == b.1, + |_| (hash, view, Vec::new()), + |_, ev_k| { + let (ev_h, ev_view, ev_buffer) = ev_k; + evict(ev_h, ev_view, ev_buffer); + *ev_h = hash; + *ev_view = view; + ev_buffer.clear(); + }, + ) + } else { + let bytes = view.get_external_slice_unchecked(buffers); + self.table.insert_key( + hash, + (), + |_, b| { + // We only reach here if the hash matched, so jump straight to full comparison. + bytes == b.2 + }, + |_| (hash, view, bytes.to_vec()), + |_, ev_k| { + let (ev_h, ev_view, ev_buffer) = ev_k; + evict(ev_h, ev_view, ev_buffer); + *ev_h = hash; + *ev_view = view; + ev_buffer.clear(); + ev_buffer.extend_from_slice(bytes); + }, + ) + } + } + } + + #[inline(always)] + fn insert_null(&mut self) -> Option { + if self.null_idx == IdxSize::MAX { + self.null_idx = self + .table + .push_unmapped_key((0, View::default(), Vec::new())); + } + Some(EvictIdx::new(self.null_idx, false)) + } +} + +impl HotGrouper for BinviewHashHotGrouper { + fn new_empty(&self, max_groups: usize) -> Box { + Box::new(Self::new(max_groups)) + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys( + &mut self, + hash_keys: &HashKeys, + hot_idxs: &mut Vec, + hot_group_idxs: &mut Vec, + cold_idxs: &mut Vec, + ) { + let HashKeys::Binview(hash_keys) = hash_keys else { + unreachable!() + }; + + hot_idxs.reserve(hash_keys.keys.len()); + hot_group_idxs.reserve(hash_keys.keys.len()); + cold_idxs.reserve(hash_keys.keys.len()); + + let mut push_g = |idx: usize, opt_g: Option| unsafe { + if let Some(g) = opt_g { + hot_idxs.push_unchecked(idx as IdxSize); + hot_group_idxs.push_unchecked(g); + } else { + cold_idxs.push_unchecked(idx as IdxSize); + } + }; + + unsafe { + let views = hash_keys.keys.views().as_slice(); + let buffers = hash_keys.keys.data_buffers(); + if hash_keys.null_is_valid { + hash_keys.for_each_hash(|idx, opt_h| { + if let Some(h) = opt_h { + let view = views.get_unchecked(idx as usize); + push_g(idx as usize, self.insert_key(h, *view, buffers)); + } else { + push_g(idx as usize, self.insert_null()); + } + }); + } else { + hash_keys.for_each_hash(|idx, opt_h| { + if let Some(h) = opt_h { + let view = views.get_unchecked(idx as usize); + push_g(idx as usize, self.insert_key(h, *view, buffers)); + } + }); + } + } + } + + fn keys(&self) -> HashKeys { + unsafe { + let mut hashes = Vec::with_capacity(self.table.len()); + let mut keys_builder = BinaryViewArrayGenericBuilder::new(ArrowDataType::BinaryView); + keys_builder.reserve(self.table.len()); + for (h, view, buf) in self.table.keys() { + hashes.push_unchecked(*h); + if view.is_inline() { + keys_builder.push_inline_view_ignore_validity(*view); + } else { + keys_builder.push_value_ignore_validity(buf.as_slice()); + } + } + + let hashes = PrimitiveArray::from_vec(hashes); + let mut keys = keys_builder.freeze(); + let null_is_valid = self.null_idx < IdxSize::MAX; + if null_is_valid { + let mut validity = MutableBitmap::new(); + validity.extend_constant(keys.len(), true); + validity.set(self.null_idx as usize, false); + keys = keys.with_validity_typed(Some(validity.freeze())); + } + HashKeys::Binview(BinviewKeys { + hashes, + keys, + null_is_valid, + }) + } + } + + fn num_evictions(&self) -> usize { + self.evicted_keys.len() + } + + fn take_evicted_keys(&mut self) -> HashKeys { + let hashes = core::mem::take(&mut self.evicted_key_hashes); + let keys = self.evicted_keys.freeze_reset(); + HashKeys::Binview(BinviewKeys { + hashes: PrimitiveArray::from_vec(hashes), + keys, + null_is_valid: false, + }) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/hot_groups/fixed_index_table.rs b/crates/polars-expr/src/hot_groups/fixed_index_table.rs new file mode 100644 index 000000000000..39720791d84e --- /dev/null +++ b/crates/polars-expr/src/hot_groups/fixed_index_table.rs @@ -0,0 +1,158 @@ +use polars_utils::IdxSize; +use polars_utils::select::select_unpredictable; +use polars_utils::vec::PushUnchecked; + +use crate::EvictIdx; + +const H2_MULT: u64 = 0xf1357aea2e62a9c5; + +#[derive(Clone)] +struct Slot { + tag: u32, + last_access_tag: u32, + key_index: IdxSize, +} + +/// A fixed-size hash table which maps keys to indices. +/// +/// Instead of growing indefinitely this table will evict keys instead. +pub struct FixedIndexTable { + slots: Vec, + keys: Vec, + shift: u8, + prng: u64, +} + +impl FixedIndexTable { + pub fn new(num_slots: IdxSize) -> Self { + assert!(num_slots.is_power_of_two()); + let empty_slot = Slot { + tag: u32::MAX, + last_access_tag: u32::MAX, + key_index: IdxSize::MAX, + }; + Self { + slots: vec![empty_slot; num_slots as usize], + shift: 64 - num_slots.trailing_zeros() as u8, + // We add one to the capacity for the null key. + keys: Vec::with_capacity(1 + num_slots as usize), + prng: 0, + } + } + + pub fn len(&self) -> usize { + self.keys.len() + } + + /// Insert a key which will never be mapped to nor evicted. + /// + /// This is useful for permanent entries which are handled externally. + /// Returns the key index this would have taken up. + pub fn push_unmapped_key(&mut self, key: K) -> IdxSize { + let idx = self.keys.len(); + self.keys.push(key); + idx as IdxSize + } + + /// Tries to insert a key with a given hash. + /// + /// Returns Some((index, evict_old)) if successful, None otherwise. + pub fn insert_key( + &mut self, + hash: u64, + key: Q, + mut eq: E, + mut insert: I, + mut evict_insert: V, + ) -> Option + where + E: FnMut(&Q, &K) -> bool, + I: FnMut(Q) -> K, + V: FnMut(Q, &mut K), + { + let tag = hash as u32; + let h1 = (hash >> self.shift) as usize; + let h2 = (hash.wrapping_mul(H2_MULT) >> self.shift) as usize; + + unsafe { + // We only want a single branch for the hot hit/miss check. This is + // why we check both slots at once. + let s1 = self.slots.get_unchecked(h1); + let s2 = self.slots.get_unchecked(h2); + let s1_delta = s1.tag ^ tag; + let s2_delta = s2.tag ^ tag; + // This check can have false positives (the binary AND of the deltas + // happens to be zero by accident), but this is very unlikely (~1/10k) + // and harmless if it does. False negatives are impossible. If this + // branch succeeds we almost surely have a hit, if it fails + // we're certain we have a miss. + if s1_delta & s2_delta == 0 { + // We want to branchlessly select the most likely candidate + // first to ensure no further branch mispredicts in the vast + // majority of cases. + let ha = select_unpredictable(s1_delta == 0, h1, h2); + let sa = self.slots.get_unchecked_mut(ha); + if let Some(sak) = self.keys.get(sa.key_index as usize) { + if eq(&key, sak) { + sa.last_access_tag = tag; + return Some(EvictIdx::new(sa.key_index, false)); + } + } + + // If both hashes matched we have to check the second slot too. + if s1_delta == s2_delta { + let hb = h1 ^ h2 ^ ha; + let sb = self.slots.get_unchecked_mut(hb); + if let Some(sbk) = self.keys.get(sb.key_index as usize) { + if eq(&key, sbk) { + sb.last_access_tag = tag; + return Some(EvictIdx::new(sb.key_index, false)); + } + } + } + } + + // Check if we can insert into an empty slot. + let num_keys = self.keys.len() as IdxSize; + if (num_keys as usize) < self.slots.len() { + // Check the first slot. + let s1 = self.slots.get_unchecked_mut(h1); + if s1.key_index >= num_keys { + s1.tag = tag; + s1.last_access_tag = tag; + s1.key_index = num_keys; + self.keys.push_unchecked(insert(key)); + return Some(EvictIdx::new(s1.key_index, false)); + } + + // Check the second slot. + let s2 = self.slots.get_unchecked_mut(h2); + if s2.key_index >= num_keys { + s2.tag = tag; + s2.last_access_tag = tag; + s2.key_index = num_keys; + self.keys.push_unchecked(insert(key)); + return Some(EvictIdx::new(s2.key_index, false)); + } + } + + // Randomly try to evict one of the two slots. + let hr = select_unpredictable(self.prng >> 63 != 0, h1, h2); + self.prng = self.prng.wrapping_add(hash); + let slot = self.slots.get_unchecked_mut(hr); + if slot.last_access_tag == tag { + slot.tag = tag; + let evict_key = self.keys.get_unchecked_mut(slot.key_index as usize); + evict_insert(key, evict_key); + Some(EvictIdx::new(slot.key_index, true)) + } else { + slot.last_access_tag = tag; + None + } + } + } + + pub fn keys(&self) -> &[K] { + &self.keys + } +} diff --git a/crates/polars-expr/src/hot_groups/mod.rs b/crates/polars-expr/src/hot_groups/mod.rs new file mode 100644 index 000000000000..256772cbd5f9 --- /dev/null +++ b/crates/polars-expr/src/hot_groups/mod.rs @@ -0,0 +1,96 @@ +use std::any::Any; + +use polars_core::prelude::*; +use polars_utils::IdxSize; + +use crate::EvictIdx; +use crate::hash_keys::HashKeys; + +mod binview; +mod fixed_index_table; +mod row_encoded; +mod single_key; + +/// A HotGrouper maps keys to groups, such that duplicate keys map to the same +/// group. Unlike a Grouper it has a fixed size and will cause evictions rather +/// than growing. +pub trait HotGrouper: Any + Send + Sync { + /// Creates a new empty HotGrouper similar to this one, with the given size. + fn new_empty(&self, groups: usize) -> Box; + + /// Returns the number of groups in this HotGrouper. + fn num_groups(&self) -> IdxSize; + + /// Inserts the given keys into this Grouper, extending groups_idxs with + /// the group index of keys[i]. + fn insert_keys( + &mut self, + keys: &HashKeys, + hot_idxs: &mut Vec, + hot_group_idxs: &mut Vec, + cold_idxs: &mut Vec, + ); + + /// Get all the current hot keys, in group order. + fn keys(&self) -> HashKeys; + + /// Get the number of evicted keys stored. + fn num_evictions(&self) -> usize; + + /// Consume all the evicted keys from this HotGrouper. + fn take_evicted_keys(&mut self) -> HashKeys; + + fn as_any(&self) -> &dyn Any; +} + +pub fn new_hash_hot_grouper(key_schema: Arc, num_groups: usize) -> Box { + if key_schema.len() > 1 { + Box::new(row_encoded::RowEncodedHashHotGrouper::new( + key_schema, num_groups, + )) + } else { + use single_key::SingleKeyHashHotGrouper as SK; + let dt = key_schema.get_at_index(0).unwrap().1.clone(); + let ng = num_groups; + match dt { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => Box::new(SK::::new(dt, ng)), + DataType::UInt32 => Box::new(SK::::new(dt, ng)), + DataType::UInt64 => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => Box::new(SK::::new(dt, ng)), + DataType::Int32 => Box::new(SK::::new(dt, ng)), + DataType::Int64 => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => Box::new(SK::::new(dt, ng)), + DataType::Float32 => Box::new(SK::::new(dt, ng)), + DataType::Float64 => Box::new(SK::::new(dt, ng)), + + #[cfg(feature = "dtype-date")] + DataType::Date => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-time")] + DataType::Time => Box::new(SK::::new(dt, ng)), + + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => Box::new(SK::::new(dt, ng)), + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) => Box::new(SK::::new(dt, ng)), + + DataType::String | DataType::Binary => { + Box::new(binview::BinviewHashHotGrouper::new(ng)) + }, + + _ => Box::new(row_encoded::RowEncodedHashHotGrouper::new( + key_schema, num_groups, + )), + } + } +} diff --git a/crates/polars-expr/src/hot_groups/row_encoded.rs b/crates/polars-expr/src/hot_groups/row_encoded.rs new file mode 100644 index 000000000000..f1e87ada63c7 --- /dev/null +++ b/crates/polars-expr/src/hot_groups/row_encoded.rs @@ -0,0 +1,112 @@ +use arrow::array::{BinaryArray, PrimitiveArray}; +use arrow::buffer::Buffer; +use arrow::offset::{Offsets, OffsetsBuffer}; +use polars_utils::vec::PushUnchecked; + +use super::*; +use crate::hash_keys::RowEncodedKeys; +use crate::hot_groups::fixed_index_table::FixedIndexTable; + +pub struct RowEncodedHashHotGrouper { + key_schema: Arc, + table: FixedIndexTable<(u64, Vec)>, + evicted_key_hashes: Vec, + evicted_key_data: Vec, + evicted_key_offsets: Offsets, +} + +impl RowEncodedHashHotGrouper { + pub fn new(key_schema: Arc, max_groups: usize) -> Self { + Self { + key_schema, + table: FixedIndexTable::new(max_groups.try_into().unwrap()), + evicted_key_hashes: Vec::new(), + evicted_key_data: Vec::new(), + evicted_key_offsets: Offsets::new(), + } + } +} + +impl HotGrouper for RowEncodedHashHotGrouper { + fn new_empty(&self, max_groups: usize) -> Box { + Box::new(Self::new(self.key_schema.clone(), max_groups)) + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys( + &mut self, + keys: &HashKeys, + hot_idxs: &mut Vec, + hot_group_idxs: &mut Vec, + cold_idxs: &mut Vec, + ) { + let HashKeys::RowEncoded(keys) = keys else { + unreachable!() + }; + + hot_idxs.reserve(keys.hashes.len()); + hot_group_idxs.reserve(keys.hashes.len()); + cold_idxs.reserve(keys.hashes.len()); + + unsafe { + keys.for_each_hash(|idx, opt_h| { + if let Some(h) = opt_h { + let key = keys.keys.value_unchecked(idx as usize); + let opt_g = self.table.insert_key( + h, + key, + |a, b| *a == b.1, + |k| (h, k.to_owned()), + |k, ev_k| { + self.evicted_key_hashes.push(ev_k.0); + self.evicted_key_offsets.try_push(ev_k.1.len()).unwrap(); + self.evicted_key_data.extend_from_slice(&ev_k.1); + ev_k.0 = h; + ev_k.1.clear(); + ev_k.1.extend_from_slice(k); + }, + ); + if let Some(g) = opt_g { + hot_idxs.push_unchecked(idx as IdxSize); + hot_group_idxs.push_unchecked(g); + } else { + cold_idxs.push_unchecked(idx as IdxSize); + } + } + }); + } + } + + fn keys(&self) -> HashKeys { + unsafe { + let mut hashes = Vec::with_capacity(self.table.len()); + let keys = LargeBinaryArray::from_trusted_len_values_iter( + self.table.keys().iter().map(|(h, k)| { + hashes.push_unchecked(*h); + k + }), + ); + let hashes = PrimitiveArray::from_vec(hashes); + HashKeys::RowEncoded(RowEncodedKeys { hashes, keys }) + } + } + + fn num_evictions(&self) -> usize { + self.evicted_key_offsets.len_proxy() + } + + fn take_evicted_keys(&mut self) -> HashKeys { + let hashes = PrimitiveArray::from_vec(core::mem::take(&mut self.evicted_key_hashes)); + let values = Buffer::from(core::mem::take(&mut self.evicted_key_data)); + let offsets = OffsetsBuffer::from(core::mem::take(&mut self.evicted_key_offsets)); + let keys = BinaryArray::new(ArrowDataType::LargeBinary, offsets, values, None); + HashKeys::RowEncoded(RowEncodedKeys { hashes, keys }) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/hot_groups/single_key.rs b/crates/polars-expr/src/hot_groups/single_key.rs new file mode 100644 index 000000000000..ca38085d5f84 --- /dev/null +++ b/crates/polars-expr/src/hot_groups/single_key.rs @@ -0,0 +1,177 @@ +use std::hash::BuildHasher; + +use arrow::array::Array; +use arrow::bitmap::MutableBitmap; +use polars_utils::total_ord::{BuildHasherTotalExt, TotalEq, TotalHash}; +use polars_utils::vec::PushUnchecked; + +use super::*; +use crate::hash_keys::SingleKeys; +use crate::hot_groups::fixed_index_table::FixedIndexTable; + +pub struct SingleKeyHashHotGrouper { + dtype: DataType, + table: FixedIndexTable>, + evicted_keys: Vec>, + null_idx: IdxSize, + random_state: PlRandomState, +} + +impl SingleKeyHashHotGrouper +where + ChunkedArray: IntoSeries, + for<'a> T: PolarsDataType = K>, + K: Default + TotalHash + TotalEq + Send + Sync + 'static, +{ + pub fn new(dtype: DataType, max_groups: usize) -> Self { + Self { + dtype, + table: FixedIndexTable::new(max_groups.try_into().unwrap()), + evicted_keys: Vec::new(), + null_idx: IdxSize::MAX, + random_state: PlRandomState::default(), + } + } + + #[inline(always)] + fn insert_key( + &mut self, + k: T::Physical<'static>, + random_state: &R, + ) -> Option { + let h = random_state.tot_hash_one(&k); + self.table.insert_key( + h, + k, + |a, b| a.tot_eq(b), + |k| k, + |k, ev_k| self.evicted_keys.push(core::mem::replace(ev_k, k)), + ) + } + + #[inline(always)] + fn insert_null(&mut self) -> Option { + if self.null_idx == IdxSize::MAX { + self.null_idx = self.table.push_unmapped_key(T::Physical::default()); + } + Some(EvictIdx::new(self.null_idx, false)) + } + + fn finalize_keys(&self, keys: Vec>, add_mask: bool) -> HashKeys { + let mut keys = T::Array::from_vec( + keys, + self.dtype.to_physical().to_arrow(CompatLevel::newest()), + ); + if add_mask && self.null_idx < IdxSize::MAX { + let mut validity = MutableBitmap::new(); + validity.extend_constant(keys.len(), true); + validity.set(self.null_idx as usize, false); + keys = keys.with_validity_typed(Some(validity.freeze())); + } + + unsafe { + let s = Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![Box::new(keys)], + &self.dtype, + ); + HashKeys::Single(SingleKeys { + keys: s, + null_is_valid: self.null_idx < IdxSize::MAX, + random_state: self.random_state, + }) + } + } +} + +impl HotGrouper for SingleKeyHashHotGrouper +where + ChunkedArray: IntoSeries, + for<'a> T: PolarsDataType = K>, + K: Default + TotalHash + TotalEq + Clone + Send + Sync + 'static, +{ + fn new_empty(&self, max_groups: usize) -> Box { + Box::new(Self::new(self.dtype.clone(), max_groups)) + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys( + &mut self, + hash_keys: &HashKeys, + hot_idxs: &mut Vec, + hot_group_idxs: &mut Vec, + cold_idxs: &mut Vec, + ) { + let HashKeys::Single(hash_keys) = hash_keys else { + unreachable!() + }; + + // Preserve random state if non-empty. + if !hash_keys.keys.is_empty() { + self.random_state = hash_keys.random_state; + } + + let keys: &ChunkedArray = hash_keys.keys.as_phys_any().downcast_ref().unwrap(); + hot_idxs.reserve(keys.len()); + hot_group_idxs.reserve(keys.len()); + cold_idxs.reserve(keys.len()); + + let mut push_g = |idx: usize, opt_g: Option| unsafe { + if let Some(g) = opt_g { + hot_idxs.push_unchecked(idx as IdxSize); + hot_group_idxs.push_unchecked(g); + } else { + cold_idxs.push_unchecked(idx as IdxSize); + } + }; + + let mut idx = 0; + for arr in keys.downcast_iter() { + if arr.has_nulls() { + if hash_keys.null_is_valid { + for opt_k in arr.iter() { + if let Some(k) = opt_k { + push_g(idx, self.insert_key(k, &hash_keys.random_state)); + } else { + push_g(idx, self.insert_null()); + } + idx += 1; + } + } else { + for opt_k in arr.iter() { + if let Some(k) = opt_k { + push_g(idx, self.insert_key(k, &hash_keys.random_state)); + } + idx += 1; + } + } + } else { + for k in arr.values_iter() { + let g = self.insert_key(k, &hash_keys.random_state); + push_g(idx, g); + idx += 1; + } + } + } + } + + fn keys(&self) -> HashKeys { + self.finalize_keys(self.table.keys().to_vec(), true) + } + + fn num_evictions(&self) -> usize { + self.evicted_keys.len() + } + + fn take_evicted_keys(&mut self) -> HashKeys { + let keys = core::mem::take(&mut self.evicted_keys); + self.finalize_keys(keys, false) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/idx_table/binview.rs b/crates/polars-expr/src/idx_table/binview.rs new file mode 100644 index 000000000000..47c5ff808f68 --- /dev/null +++ b/crates/polars-expr/src/idx_table/binview.rs @@ -0,0 +1,374 @@ +#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different. +#![allow(unsafe_op_in_unsafe_fn)] + +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; + +use arrow::array::{Array, View}; +use arrow::buffer::Buffer; +use polars_compute::binview_index_map::{BinaryViewIndexMap, Entry}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::itertools::Itertools; +use polars_utils::unitvec; + +use super::*; +use crate::hash_keys::HashKeys; + +pub struct BinviewKeyIdxTable { + // These AtomicU64s actually are IdxSizes, but we use the top bit of the + // first index in each to mark keys during probing. + idx_map: BinaryViewIndexMap>, + idx_offset: IdxSize, + null_keys: Vec, + nulls_emitted: AtomicBool, +} + +impl BinviewKeyIdxTable { + pub fn new() -> Self { + Self { + idx_map: BinaryViewIndexMap::default(), + idx_offset: 0, + null_keys: Vec::new(), + nulls_emitted: AtomicBool::new(false), + } + } + + /// # Safety + /// The view must be valid for the buffers. + #[inline(always)] + unsafe fn probe_one( + &self, + key_idx: IdxSize, + hash: u64, + key: &View, + buffers: &[Buffer], + table_match: &mut Vec, + probe_match: &mut Vec, + ) -> bool { + if let Some(idxs) = unsafe { self.idx_map.get_view(hash, key, buffers) } { + for idx in &idxs[..] { + // Create matches, making sure to clear top bit. + table_match.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); + probe_match.push(key_idx); + } + + // Mark if necessary. This action is idempotent so doesn't need + // atomic fetch_or to do it atomically. + if MARK_MATCHES { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Relaxed); + if first_idx_val >> 63 == 0 { + first_idx.store(first_idx_val | (1 << 63), Ordering::Relaxed); + } + } + true + } else { + false + } + } + + /// # Safety + /// The views must be valid for the buffers. + unsafe fn probe_impl< + 'a, + const MARK_MATCHES: bool, + const EMIT_UNMATCHED: bool, + const NULL_IS_VALID: bool, + >( + &self, + keys: impl Iterator)>, + buffers: &[Buffer], + table_match: &mut Vec, + probe_match: &mut Vec, + limit: IdxSize, + ) -> IdxSize { + let mut keys_processed = 0; + for (key_idx, hash, key) in keys { + let found_match = if let Some(key) = key { + self.probe_one::( + key_idx, + hash, + key, + buffers, + table_match, + probe_match, + ) + } else if NULL_IS_VALID { + for idx in &self.null_keys { + table_match.push(*idx); + probe_match.push(key_idx); + } + if MARK_MATCHES && !self.nulls_emitted.load(Ordering::Relaxed) { + self.nulls_emitted.store(true, Ordering::Relaxed); + } + !self.null_keys.is_empty() + } else { + false + }; + + if EMIT_UNMATCHED && !found_match { + table_match.push(IdxSize::MAX); + probe_match.push(key_idx); + } + + keys_processed += 1; + if table_match.len() >= limit as usize { + break; + } + } + keys_processed + } + + /// # Safety + /// The views must be valid for the buffers. + #[allow(clippy::too_many_arguments)] + unsafe fn probe_dispatch<'a>( + &self, + keys: impl Iterator)>, + buffers: &[Buffer], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + null_is_valid: bool, + limit: IdxSize, + ) -> IdxSize { + match (mark_matches, emit_unmatched, null_is_valid) { + (false, false, false) => self.probe_impl::( + keys, + buffers, + table_match, + probe_match, + limit, + ), + (false, false, true) => self.probe_impl::( + keys, + buffers, + table_match, + probe_match, + limit, + ), + (false, true, false) => self.probe_impl::( + keys, + buffers, + table_match, + probe_match, + limit, + ), + (false, true, true) => { + self.probe_impl::(keys, buffers, table_match, probe_match, limit) + }, + (true, false, false) => self.probe_impl::( + keys, + buffers, + table_match, + probe_match, + limit, + ), + (true, false, true) => { + self.probe_impl::(keys, buffers, table_match, probe_match, limit) + }, + (true, true, false) => { + self.probe_impl::(keys, buffers, table_match, probe_match, limit) + }, + (true, true, true) => { + self.probe_impl::(keys, buffers, table_match, probe_match, limit) + }, + } + } +} + +impl IdxTable for BinviewKeyIdxTable { + fn new_empty(&self) -> Box { + Box::new(Self::new()) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_keys(&self) -> IdxSize { + self.idx_map.len() + } + + fn insert_keys(&mut self, _hash_keys: &HashKeys, _track_unmatchable: bool) { + // Isn't needed anymore, but also don't want to remove the code from the other implementations. + unimplemented!() + } + + unsafe fn insert_keys_subset( + &mut self, + hash_keys: &HashKeys, + subset: &[IdxSize], + track_unmatchable: bool, + ) { + let HashKeys::Binview(hash_keys) = hash_keys else { + unreachable!() + }; + let new_idx_offset = (self.idx_offset as usize) + .checked_add(subset.len()) + .unwrap(); + assert!( + new_idx_offset < IdxSize::MAX as usize, + "overly large index in BinviewKeyIdxTable" + ); + + unsafe { + let buffers = hash_keys.keys.data_buffers(); + let views = hash_keys.keys.views(); + if let Some(validity) = hash_keys.keys.validity() { + for (i, subset_idx) in subset.iter().enumerate_idx() { + let hash = hash_keys.hashes.value_unchecked(*subset_idx as usize); + let key = views.get_unchecked(*subset_idx as usize); + let idx = self.idx_offset + i; + if validity.get_bit_unchecked(*subset_idx as usize) { + match self.idx_map.entry_view(hash, *key, buffers) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx as u64)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx as u64)]); + }, + } + } else if track_unmatchable | hash_keys.null_is_valid { + self.null_keys.push(idx); + } + } + } else { + for (i, subset_idx) in subset.iter().enumerate_idx() { + let hash = hash_keys.hashes.value_unchecked(*subset_idx as usize); + let key = views.get_unchecked(*subset_idx as usize); + let idx = self.idx_offset + i; + match self.idx_map.entry_view(hash, *key, buffers) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx as u64)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx as u64)]); + }, + } + } + } + } + + self.idx_offset = new_idx_offset as IdxSize; + } + + fn probe( + &self, + _hash_keys: &HashKeys, + _table_match: &mut Vec, + _probe_match: &mut Vec, + _mark_matches: bool, + _emit_unmatched: bool, + _limit: IdxSize, + ) -> IdxSize { + // Isn't needed anymore, but also don't want to remove the code from the other implementations. + unimplemented!() + } + + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::Binview(hash_keys) = hash_keys else { + unreachable!() + }; + + unsafe { + let buffers = hash_keys.keys.data_buffers(); + let views = hash_keys.keys.views(); + if let Some(validity) = hash_keys.keys.validity() { + let iter = subset.iter().map(|i| { + ( + *i, + hash_keys.hashes.value_unchecked(*i as usize), + if validity.get_bit_unchecked(*i as usize) { + Some(views.get_unchecked(*i as usize)) + } else { + None + }, + ) + }); + self.probe_dispatch( + iter, + buffers, + table_match, + probe_match, + mark_matches, + emit_unmatched, + hash_keys.null_is_valid, + limit, + ) + } else { + let iter = subset.iter().map(|i| { + ( + *i, + hash_keys.hashes.value_unchecked(*i as usize), + Some(views.get_unchecked(*i as usize)), + ) + }); + self.probe_dispatch( + iter, + buffers, + table_match, + probe_match, + mark_matches, + emit_unmatched, + false, // Whether or not nulls are valid doesn't matter. + limit, + ) + } + } + } + + fn unmarked_keys( + &self, + out: &mut Vec, + mut offset: IdxSize, + limit: IdxSize, + ) -> IdxSize { + out.clear(); + + let mut keys_processed = 0; + if !self.nulls_emitted.load(Ordering::Relaxed) { + if (offset as usize) < self.null_keys.len() { + out.extend( + self.null_keys[offset as usize..] + .iter() + .copied() + .take(limit as usize), + ); + keys_processed += out.len() as IdxSize; + offset += out.len() as IdxSize; + if out.len() >= limit as usize { + return keys_processed; + } + } + offset -= self.null_keys.len() as IdxSize; + } + + while let Some((_, _, idxs)) = self.idx_map.get_index(offset) { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Relaxed); + if first_idx_val >> 63 == 0 { + for idx in &idxs[..] { + out.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); + } + } + + keys_processed += 1; + offset += 1; + if out.len() >= limit as usize { + break; + } + } + + keys_processed + } +} diff --git a/crates/polars-expr/src/idx_table/mod.rs b/crates/polars-expr/src/idx_table/mod.rs new file mode 100644 index 000000000000..f29a89ba1ca0 --- /dev/null +++ b/crates/polars-expr/src/idx_table/mod.rs @@ -0,0 +1,114 @@ +use std::any::Any; + +use polars_core::prelude::*; +use polars_utils::IdxSize; + +use crate::hash_keys::HashKeys; + +mod binview; +mod row_encoded; +mod single_key; + +pub trait IdxTable: Any + Send + Sync { + /// Creates a new empty IdxTable similar to this one. + fn new_empty(&self) -> Box; + + /// Reserves space for the given number additional keys. + fn reserve(&mut self, additional: usize); + + /// Returns the number of unique keys in this IdxTable. + fn num_keys(&self) -> IdxSize; + + /// Inserts the given keys into this IdxTable. + fn insert_keys(&mut self, keys: &HashKeys, track_unmatchable: bool); + + /// Inserts a subset of the given keys into this IdxTable. + /// # Safety + /// The provided subset indices must be in-bounds. + unsafe fn insert_keys_subset( + &mut self, + keys: &HashKeys, + subset: &[IdxSize], + track_unmatchable: bool, + ); + + /// Probe the table, adding an entry to table_match and probe_match for each + /// match. Will stop processing new keys once limit matches have been + /// generated, returning the number of keys processed. + /// + /// If mark_matches is true, matches are marked in the table as such. + /// + /// If emit_unmatched is true, for keys that do not have a match we emit a + /// match with ChunkId::null() on the table match. + fn probe( + &self, + hash_keys: &HashKeys, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize; + + /// The same as probe, except it will only apply to the specified subset of keys. + /// # Safety + /// The provided subset indices must be in-bounds. + #[allow(clippy::too_many_arguments)] + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize; + + /// Get the ChunkIds for each key which was never marked during probing. + fn unmarked_keys(&self, out: &mut Vec, offset: IdxSize, limit: IdxSize) -> IdxSize; +} + +pub fn new_idx_table(key_schema: Arc) -> Box { + if key_schema.len() > 1 { + Box::new(row_encoded::RowEncodedIdxTable::new()) + } else { + use single_key::SingleKeyIdxTable as SKIT; + match key_schema.get_at_index(0).unwrap().1 { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => Box::new(SKIT::::new()), + DataType::UInt32 => Box::new(SKIT::::new()), + DataType::UInt64 => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => Box::new(SKIT::::new()), + DataType::Int32 => Box::new(SKIT::::new()), + DataType::Int64 => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => Box::new(SKIT::::new()), + DataType::Float32 => Box::new(SKIT::::new()), + DataType::Float64 => Box::new(SKIT::::new()), + + #[cfg(feature = "dtype-date")] + DataType::Date => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-time")] + DataType::Time => Box::new(SKIT::::new()), + + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => Box::new(SKIT::::new()), + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) => Box::new(SKIT::::new()), + + DataType::String | DataType::Binary => Box::new(binview::BinviewKeyIdxTable::new()), + + _ => Box::new(row_encoded::RowEncodedIdxTable::new()), + } + } +} diff --git a/crates/polars-expr/src/idx_table/row_encoded.rs b/crates/polars-expr/src/idx_table/row_encoded.rs new file mode 100644 index 000000000000..2b8b5f54b5e6 --- /dev/null +++ b/crates/polars-expr/src/idx_table/row_encoded.rs @@ -0,0 +1,348 @@ +#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different. +#![allow(unsafe_op_in_unsafe_fn)] + +use std::sync::atomic::{AtomicU64, Ordering}; + +use arrow::array::Array; +use polars_compute::binview_index_map::{BinaryViewIndexMap, Entry}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::itertools::Itertools; +use polars_utils::unitvec; + +use super::*; +use crate::hash_keys::HashKeys; + +#[derive(Default)] +pub struct RowEncodedIdxTable { + // These AtomicU64s actually are IdxSizes, but we use the top bit of the + // first index in each to mark keys during probing. + idx_map: BinaryViewIndexMap>, + idx_offset: IdxSize, + null_keys: Vec, +} + +impl RowEncodedIdxTable { + pub fn new() -> Self { + Self { + idx_map: BinaryViewIndexMap::new(), + idx_offset: 0, + null_keys: Vec::new(), + } + } +} + +impl RowEncodedIdxTable { + #[inline(always)] + fn probe_one( + &self, + key_idx: IdxSize, + hash: u64, + key: &[u8], + table_match: &mut Vec, + probe_match: &mut Vec, + ) -> bool { + if let Some(idxs) = self.idx_map.get(hash, key) { + for idx in &idxs[..] { + // Create matches, making sure to clear top bit. + table_match.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); + probe_match.push(key_idx); + } + + // Mark if necessary. This action is idempotent so doesn't + // need any synchronization on the load, nor does it need a + // fetch_or to do it atomically. + if MARK_MATCHES { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Relaxed); + if first_idx_val >> 63 == 0 { + first_idx.store(first_idx_val | (1 << 63), Ordering::Release); + } + } + true + } else { + false + } + } + + fn probe_impl<'a, const MARK_MATCHES: bool, const EMIT_UNMATCHED: bool>( + &self, + hash_keys: impl Iterator)>, + table_match: &mut Vec, + probe_match: &mut Vec, + limit: IdxSize, + ) -> IdxSize { + let mut keys_processed = 0; + for (key_idx, hash, key) in hash_keys { + let found_match = if let Some(key) = key { + self.probe_one::(key_idx, hash, key, table_match, probe_match) + } else { + false + }; + + if EMIT_UNMATCHED && !found_match { + table_match.push(IdxSize::MAX); + probe_match.push(key_idx); + } + + keys_processed += 1; + if table_match.len() >= limit as usize { + break; + } + } + keys_processed + } + + fn probe_dispatch<'a>( + &self, + hash_keys: impl Iterator)>, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + match (mark_matches, emit_unmatched) { + (false, false) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (false, true) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (true, false) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + (true, true) => { + self.probe_impl::(hash_keys, table_match, probe_match, limit) + }, + } + } +} + +impl IdxTable for RowEncodedIdxTable { + fn new_empty(&self) -> Box { + Box::new(Self::new()) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_keys(&self) -> IdxSize { + self.idx_map.len() + } + + fn insert_keys(&mut self, hash_keys: &HashKeys, track_unmatchable: bool) { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + let new_idx_offset = (self.idx_offset as usize) + .checked_add(hash_keys.keys.len()) + .unwrap(); + assert!( + new_idx_offset < IdxSize::MAX as usize, + "overly large index in RowEncodedIdxTable" + ); + + for (i, (hash, key)) in hash_keys + .hashes + .values_iter() + .zip(hash_keys.keys.iter()) + .enumerate_idx() + { + let idx = self.idx_offset + i; + if let Some(key) = key { + match self.idx_map.entry(*hash, key) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx as u64)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx as u64)]); + }, + } + } else if track_unmatchable { + self.null_keys.push(idx); + } + } + + self.idx_offset = new_idx_offset as IdxSize; + } + + unsafe fn insert_keys_subset( + &mut self, + hash_keys: &HashKeys, + subset: &[IdxSize], + track_unmatchable: bool, + ) { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + let new_idx_offset = (self.idx_offset as usize) + .checked_add(subset.len()) + .unwrap(); + assert!( + new_idx_offset < IdxSize::MAX as usize, + "overly large index in RowEncodedIdxTable" + ); + + for (i, subset_idx) in subset.iter().enumerate_idx() { + let hash = unsafe { hash_keys.hashes.value_unchecked(*subset_idx as usize) }; + let key = unsafe { hash_keys.keys.get_unchecked(*subset_idx as usize) }; + let idx = self.idx_offset + i; + if let Some(key) = key { + match self.idx_map.entry(hash, key) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx as u64)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx as u64)]); + }, + } + } else if track_unmatchable { + self.null_keys.push(idx); + } + } + + self.idx_offset = new_idx_offset as IdxSize; + } + + fn probe( + &self, + hash_keys: &HashKeys, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + + if hash_keys.keys.has_nulls() { + let iter = hash_keys + .hashes + .values_iter() + .copied() + .zip(hash_keys.keys.iter()) + .enumerate_idx() + .map(|(i, (h, k))| (i, h, k)); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } else { + let iter = hash_keys + .hashes + .values_iter() + .copied() + .zip(hash_keys.keys.values_iter().map(Some)) + .enumerate_idx() + .map(|(i, (h, k))| (i, h, k)); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } + } + + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::RowEncoded(hash_keys) = hash_keys else { + unreachable!() + }; + + if hash_keys.keys.has_nulls() { + let iter = subset.iter().map(|i| { + ( + *i, + hash_keys.hashes.value_unchecked(*i as usize), + hash_keys.keys.get_unchecked(*i as usize), + ) + }); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } else { + let iter = subset.iter().map(|i| { + ( + *i, + hash_keys.hashes.value_unchecked(*i as usize), + Some(hash_keys.keys.value_unchecked(*i as usize)), + ) + }); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + limit, + ) + } + } + + fn unmarked_keys( + &self, + out: &mut Vec, + mut offset: IdxSize, + limit: IdxSize, + ) -> IdxSize { + out.clear(); + + let mut keys_processed = 0; + if (offset as usize) < self.null_keys.len() { + out.extend( + self.null_keys[offset as usize..] + .iter() + .copied() + .take(limit as usize), + ); + keys_processed += out.len() as IdxSize; + offset += out.len() as IdxSize; + if out.len() >= limit as usize { + return keys_processed; + } + } + + offset -= self.null_keys.len() as IdxSize; + + while let Some((_, _, idxs)) = self.idx_map.get_index(offset) { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Acquire); + if first_idx_val >> 63 == 0 { + for idx in &idxs[..] { + out.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); + } + } + + keys_processed += 1; + offset += 1; + if out.len() >= limit as usize { + break; + } + } + + keys_processed + } +} diff --git a/crates/polars-expr/src/idx_table/single_key.rs b/crates/polars-expr/src/idx_table/single_key.rs new file mode 100644 index 000000000000..3c983131cc89 --- /dev/null +++ b/crates/polars-expr/src/idx_table/single_key.rs @@ -0,0 +1,310 @@ +#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different. +#![allow(unsafe_op_in_unsafe_fn)] + +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; + +use polars_utils::idx_map::total_idx_map::{Entry, TotalIndexMap}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::itertools::Itertools; +use polars_utils::total_ord::{TotalEq, TotalHash}; +use polars_utils::unitvec; + +use super::*; +use crate::hash_keys::HashKeys; + +pub struct SingleKeyIdxTable { + // These AtomicU64s actually are IdxSizes, but we use the top bit of the + // first index in each to mark keys during probing. + idx_map: TotalIndexMap, UnitVec>, + idx_offset: IdxSize, + null_keys: Vec, + nulls_emitted: AtomicBool, +} + +impl SingleKeyIdxTable { + pub fn new() -> Self { + Self { + idx_map: TotalIndexMap::default(), + idx_offset: 0, + null_keys: Vec::new(), + nulls_emitted: AtomicBool::new(false), + } + } +} + +impl SingleKeyIdxTable +where + for<'a> T: PolarsDataType = K>, + K: TotalHash + TotalEq + Send + Sync + 'static, +{ + #[inline(always)] + fn probe_one( + &self, + key_idx: IdxSize, + key: &K, + table_match: &mut Vec, + probe_match: &mut Vec, + ) -> bool { + if let Some(idxs) = self.idx_map.get(key) { + for idx in &idxs[..] { + // Create matches, making sure to clear top bit. + table_match.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); + probe_match.push(key_idx); + } + + // Mark if necessary. This action is idempotent so doesn't need + // atomic fetch_or to do it atomically. + if MARK_MATCHES { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Relaxed); + if first_idx_val >> 63 == 0 { + first_idx.store(first_idx_val | (1 << 63), Ordering::Relaxed); + } + } + true + } else { + false + } + } + + fn probe_impl< + const MARK_MATCHES: bool, + const EMIT_UNMATCHED: bool, + const NULL_IS_VALID: bool, + >( + &self, + keys: impl Iterator)>, + table_match: &mut Vec, + probe_match: &mut Vec, + limit: IdxSize, + ) -> IdxSize { + let mut keys_processed = 0; + for (key_idx, key) in keys { + let found_match = if let Some(key) = key { + self.probe_one::(key_idx, &key, table_match, probe_match) + } else if NULL_IS_VALID { + for idx in &self.null_keys { + table_match.push(*idx); + probe_match.push(key_idx); + } + if MARK_MATCHES && !self.nulls_emitted.load(Ordering::Relaxed) { + self.nulls_emitted.store(true, Ordering::Relaxed); + } + !self.null_keys.is_empty() + } else { + false + }; + + if EMIT_UNMATCHED && !found_match { + table_match.push(IdxSize::MAX); + probe_match.push(key_idx); + } + + keys_processed += 1; + if table_match.len() >= limit as usize { + break; + } + } + keys_processed + } + + #[allow(clippy::too_many_arguments)] + fn probe_dispatch( + &self, + keys: impl Iterator)>, + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + null_is_valid: bool, + limit: IdxSize, + ) -> IdxSize { + match (mark_matches, emit_unmatched, null_is_valid) { + (false, false, false) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (false, false, true) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (false, true, false) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (false, true, true) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (true, false, false) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (true, false, true) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (true, true, false) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + (true, true, true) => { + self.probe_impl::(keys, table_match, probe_match, limit) + }, + } + } +} + +impl IdxTable for SingleKeyIdxTable +where + for<'a> T: PolarsDataType = K>, + K: TotalHash + TotalEq + Send + Sync + 'static, +{ + fn new_empty(&self) -> Box { + Box::new(Self::new()) + } + + fn reserve(&mut self, additional: usize) { + self.idx_map.reserve(additional); + } + + fn num_keys(&self) -> IdxSize { + self.idx_map.len() + } + + fn insert_keys(&mut self, _hash_keys: &HashKeys, _track_unmatchable: bool) { + // Isn't needed anymore, but also don't want to remove the code from the other implementations. + unimplemented!() + } + + unsafe fn insert_keys_subset( + &mut self, + hash_keys: &HashKeys, + subset: &[IdxSize], + track_unmatchable: bool, + ) { + let HashKeys::Single(hash_keys) = hash_keys else { + unreachable!() + }; + let new_idx_offset = (self.idx_offset as usize) + .checked_add(subset.len()) + .unwrap(); + assert!( + new_idx_offset < IdxSize::MAX as usize, + "overly large index in SingleKeyIdxTable" + ); + + let keys: &ChunkedArray = hash_keys.keys.as_phys_any().downcast_ref().unwrap(); + for (i, subset_idx) in subset.iter().enumerate_idx() { + let key = unsafe { keys.get_unchecked(*subset_idx as usize) }; + let idx = self.idx_offset + i; + if let Some(key) = key { + match self.idx_map.entry(key) { + Entry::Occupied(o) => { + o.into_mut().push(AtomicU64::new(idx as u64)); + }, + Entry::Vacant(v) => { + v.insert(unitvec![AtomicU64::new(idx as u64)]); + }, + } + } else if track_unmatchable | hash_keys.null_is_valid { + self.null_keys.push(idx); + } + } + + self.idx_offset = new_idx_offset as IdxSize; + } + + fn probe( + &self, + _hash_keys: &HashKeys, + _table_match: &mut Vec, + _probe_match: &mut Vec, + _mark_matches: bool, + _emit_unmatched: bool, + _limit: IdxSize, + ) -> IdxSize { + // Isn't needed anymore, but also don't want to remove the code from the other implementations. + unimplemented!() + } + + unsafe fn probe_subset( + &self, + hash_keys: &HashKeys, + subset: &[IdxSize], + table_match: &mut Vec, + probe_match: &mut Vec, + mark_matches: bool, + emit_unmatched: bool, + limit: IdxSize, + ) -> IdxSize { + let HashKeys::Single(hash_keys) = hash_keys else { + unreachable!() + }; + + let keys: &ChunkedArray = hash_keys.keys.as_phys_any().downcast_ref().unwrap(); + if keys.has_nulls() { + let iter = subset.iter().map(|i| (*i, keys.get_unchecked(*i as usize))); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + hash_keys.null_is_valid, + limit, + ) + } else { + let iter = subset + .iter() + .map(|i| (*i, Some(keys.value_unchecked(*i as usize)))); + self.probe_dispatch( + iter, + table_match, + probe_match, + mark_matches, + emit_unmatched, + false, // Whether or not nulls are valid doesn't matter. + limit, + ) + } + } + + fn unmarked_keys( + &self, + out: &mut Vec, + mut offset: IdxSize, + limit: IdxSize, + ) -> IdxSize { + out.clear(); + + let mut keys_processed = 0; + if !self.nulls_emitted.load(Ordering::Relaxed) { + if (offset as usize) < self.null_keys.len() { + out.extend( + self.null_keys[offset as usize..] + .iter() + .copied() + .take(limit as usize), + ); + keys_processed += out.len() as IdxSize; + offset += out.len() as IdxSize; + if out.len() >= limit as usize { + return keys_processed; + } + } + offset -= self.null_keys.len() as IdxSize; + } + + while let Some((_, idxs)) = self.idx_map.get_index(offset) { + let first_idx = unsafe { idxs.get_unchecked(0) }; + let first_idx_val = first_idx.load(Ordering::Relaxed); + if first_idx_val >> 63 == 0 { + for idx in &idxs[..] { + out.push((idx.load(Ordering::Relaxed) & !(1 << 63)) as IdxSize); + } + } + + keys_processed += 1; + offset += 1; + if out.len() >= limit as usize { + break; + } + } + + keys_processed + } +} diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs new file mode 100644 index 000000000000..bd06b465dd9e --- /dev/null +++ b/crates/polars-expr/src/lib.rs @@ -0,0 +1,34 @@ +mod expressions; +pub mod groups; +pub mod hash_keys; +pub mod hot_groups; +pub mod idx_table; +pub mod planner; +pub mod prelude; +pub mod reduce; +pub mod state; + +use polars_utils::IdxSize; + +pub use crate::planner::{ExpressionConversionState, create_physical_expr}; + +/// An index where the top bit indicates whether a value should be evicted. +pub struct EvictIdx(IdxSize); + +impl EvictIdx { + #[inline(always)] + pub fn new(idx: IdxSize, should_evict: bool) -> Self { + debug_assert!(idx >> (IdxSize::BITS - 1) == 0); + Self(idx | ((should_evict as IdxSize) << (IdxSize::BITS - 1))) + } + + #[inline(always)] + pub fn idx(&self) -> usize { + (self.0 & ((1 << (IdxSize::BITS - 1)) - 1)) as usize + } + + #[inline(always)] + pub fn should_evict(&self) -> bool { + (self.0 >> (IdxSize::BITS - 1)) != 0 + } +} diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs new file mode 100644 index 000000000000..6bef2a6c49f1 --- /dev/null +++ b/crates/polars-expr/src/planner.rs @@ -0,0 +1,538 @@ +use polars_core::prelude::*; +use polars_plan::prelude::expr_ir::ExprIR; +use polars_plan::prelude::*; +use recursive::recursive; + +use crate::expressions as phys_expr; +use crate::expressions::*; + +pub fn get_expr_depth_limit() -> PolarsResult { + let depth = if let Ok(d) = std::env::var("POLARS_MAX_EXPR_DEPTH") { + let v = d + .parse::() + .map_err(|_| polars_err!(ComputeError: "could not parse 'max_expr_depth': {}", d))?; + u16::try_from(v).unwrap_or(0) + } else { + 512 + }; + Ok(depth) +} + +fn ok_checker(_i: usize, _state: &ExpressionConversionState) -> PolarsResult<()> { + Ok(()) +} + +pub fn create_physical_expressions_from_irs( + exprs: &[ExprIR], + context: Context, + expr_arena: &Arena, + schema: &SchemaRef, + state: &mut ExpressionConversionState, +) -> PolarsResult>> { + create_physical_expressions_check_state(exprs, context, expr_arena, schema, state, ok_checker) +} + +pub(crate) fn create_physical_expressions_check_state( + exprs: &[ExprIR], + context: Context, + expr_arena: &Arena, + schema: &SchemaRef, + state: &mut ExpressionConversionState, + checker: F, +) -> PolarsResult>> +where + F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>, +{ + exprs + .iter() + .enumerate() + .map(|(i, e)| { + state.reset(); + let out = create_physical_expr(e, context, expr_arena, schema, state); + checker(i, state)?; + out + }) + .collect() +} + +pub(crate) fn create_physical_expressions_from_nodes( + exprs: &[Node], + context: Context, + expr_arena: &Arena, + schema: &SchemaRef, + state: &mut ExpressionConversionState, +) -> PolarsResult>> { + create_physical_expressions_from_nodes_check_state( + exprs, context, expr_arena, schema, state, ok_checker, + ) +} + +pub(crate) fn create_physical_expressions_from_nodes_check_state( + exprs: &[Node], + context: Context, + expr_arena: &Arena, + schema: &SchemaRef, + state: &mut ExpressionConversionState, + checker: F, +) -> PolarsResult>> +where + F: Fn(usize, &ExpressionConversionState) -> PolarsResult<()>, +{ + exprs + .iter() + .enumerate() + .map(|(i, e)| { + state.reset(); + let out = create_physical_expr_inner(*e, context, expr_arena, schema, state); + checker(i, state)?; + out + }) + .collect() +} + +#[derive(Copy, Clone)] +pub struct ExpressionConversionState { + // settings per context + // they remain activate between + // expressions + pub allow_threading: bool, + pub has_windows: bool, + // settings per expression + // those are reset every expression + local: LocalConversionState, +} + +#[derive(Copy, Clone, Default)] +struct LocalConversionState { + has_implode: bool, + has_window: bool, + has_lit: bool, +} + +impl ExpressionConversionState { + pub fn new(allow_threading: bool) -> Self { + Self { + allow_threading, + has_windows: false, + local: LocalConversionState { + ..Default::default() + }, + } + } + + fn reset(&mut self) { + self.local = LocalConversionState::default(); + } + + fn has_implode(&self) -> bool { + self.local.has_implode + } + + fn set_window(&mut self) { + self.has_windows = true; + self.local.has_window = true; + } +} + +pub fn create_physical_expr( + expr_ir: &ExprIR, + ctxt: Context, + expr_arena: &Arena, + schema: &SchemaRef, + state: &mut ExpressionConversionState, +) -> PolarsResult> { + let phys_expr = create_physical_expr_inner(expr_ir.node(), ctxt, expr_arena, schema, state)?; + + if let Some(name) = expr_ir.get_alias() { + Ok(Arc::new(AliasExpr::new( + phys_expr, + name.clone(), + node_to_expr(expr_ir.node(), expr_arena), + ))) + } else { + Ok(phys_expr) + } +} + +#[recursive] +fn create_physical_expr_inner( + expression: Node, + ctxt: Context, + expr_arena: &Arena, + schema: &SchemaRef, + state: &mut ExpressionConversionState, +) -> PolarsResult> { + use AExpr::*; + + match expr_arena.get(expression) { + Len => Ok(Arc::new(phys_expr::CountExpr::new())), + Window { + function, + partition_by, + order_by, + options, + } => { + let mut function = *function; + state.set_window(); + let phys_function = create_physical_expr_inner( + function, + Context::Aggregation, + expr_arena, + schema, + state, + )?; + + let order_by = order_by + .map(|(node, options)| { + PolarsResult::Ok(( + create_physical_expr_inner( + node, + Context::Aggregation, + expr_arena, + schema, + state, + )?, + options, + )) + }) + .transpose()?; + + let mut out_name = None; + if let Alias(expr, name) = expr_arena.get(function) { + function = *expr; + out_name = Some(name.clone()); + }; + let function_expr = node_to_expr(function, expr_arena); + let expr = node_to_expr(expression, expr_arena); + + // set again as the state can be reset + state.set_window(); + match options { + WindowType::Over(mapping) => { + // TODO! Order by + let group_by = create_physical_expressions_from_nodes( + partition_by, + Context::Aggregation, + expr_arena, + schema, + state, + )?; + let mut apply_columns = aexpr_to_leaf_names(function, expr_arena); + // sort and then dedup removes consecutive duplicates == all duplicates + apply_columns.sort(); + apply_columns.dedup(); + + if apply_columns.is_empty() { + if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) { + apply_columns.push(PlSmallStr::from_static("literal")) + } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) { + apply_columns.push(PlSmallStr::from_static("len")) + } else { + let e = node_to_expr(function, expr_arena); + polars_bail!( + ComputeError: + "cannot apply a window function, did not find a root column; \ + this is likely due to a syntax error in this expression: {:?}", e + ); + } + } + + Ok(Arc::new(WindowExpr { + group_by, + order_by, + apply_columns, + out_name, + function: function_expr, + phys_function, + mapping: *mapping, + expr, + })) + }, + #[cfg(feature = "dynamic_group_by")] + WindowType::Rolling(options) => Ok(Arc::new(RollingExpr { + function: function_expr, + phys_function, + out_name, + options: options.clone(), + expr, + })), + } + }, + Literal(value) => { + state.local.has_lit = true; + Ok(Arc::new(LiteralExpr::new( + value.clone(), + node_to_expr(expression, expr_arena), + ))) + }, + BinaryExpr { left, op, right } => { + let is_scalar = is_scalar_ae(expression, expr_arena); + let lhs = create_physical_expr_inner(*left, ctxt, expr_arena, schema, state)?; + let rhs = create_physical_expr_inner(*right, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(phys_expr::BinaryExpr::new( + lhs, + *op, + rhs, + node_to_expr(expression, expr_arena), + state.local.has_lit, + state.allow_threading, + is_scalar, + ))) + }, + Column(column) => Ok(Arc::new(ColumnExpr::new( + column.clone(), + node_to_expr(expression, expr_arena), + schema.clone(), + ))), + Sort { expr, options } => { + let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(SortExpr::new( + phys_expr, + *options, + node_to_expr(expression, expr_arena), + ))) + }, + Gather { + expr, + idx, + returns_scalar, + } => { + let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; + let phys_idx = create_physical_expr_inner(*idx, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(GatherExpr { + phys_expr, + idx: phys_idx, + expr: node_to_expr(expression, expr_arena), + returns_scalar: *returns_scalar, + })) + }, + SortBy { + expr, + by, + sort_options, + } => { + let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; + let phys_by = + create_physical_expressions_from_nodes(by, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(SortByExpr::new( + phys_expr, + phys_by, + node_to_expr(expression, expr_arena), + sort_options.clone(), + ))) + }, + Filter { input, by } => { + let phys_input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?; + let phys_by = create_physical_expr_inner(*by, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(FilterExpr::new( + phys_input, + phys_by, + node_to_expr(expression, expr_arena), + ))) + }, + Agg(agg) => { + let expr = agg.get_input().first(); + let input = create_physical_expr_inner(expr, ctxt, expr_arena, schema, state)?; + polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by an aggregation is not allowed"); + state.local.has_implode |= matches!(agg, IRAggExpr::Implode(_)); + let allow_threading = state.allow_threading; + + match ctxt { + Context::Default if !matches!(agg, IRAggExpr::Quantile { .. }) => { + use {GroupByMethod as GBM, IRAggExpr as I}; + + let groupby = match agg { + I::Min { propagate_nans, .. } if *propagate_nans => GBM::NanMin, + I::Min { .. } => GBM::Min, + I::Max { propagate_nans, .. } if *propagate_nans => GBM::NanMax, + I::Max { .. } => GBM::Max, + I::Median(_) => GBM::Median, + I::NUnique(_) => GBM::NUnique, + I::First(_) => GBM::First, + I::Last(_) => GBM::Last, + I::Mean(_) => GBM::Mean, + I::Implode(_) => GBM::Implode, + I::Quantile { .. } => unreachable!(), + I::Sum(_) => GBM::Sum, + I::Count(_, include_nulls) => GBM::Count { + include_nulls: *include_nulls, + }, + I::Std(_, ddof) => GBM::Std(*ddof), + I::Var(_, ddof) => GBM::Var(*ddof), + I::AggGroups(_) => { + polars_bail!(InvalidOperation: "agg groups expression only supported in aggregation context") + }, + }; + + let agg_type = AggregationType { + groupby, + allow_threading, + }; + + Ok(Arc::new(AggregationExpr::new(input, agg_type, None))) + }, + _ => { + if let IRAggExpr::Quantile { + quantile, + method: interpol, + .. + } = agg + { + let quantile = + create_physical_expr_inner(*quantile, ctxt, expr_arena, schema, state)?; + return Ok(Arc::new(AggQuantileExpr::new(input, quantile, *interpol))); + } + + let field = expr_arena.get(expression).to_field( + schema, + Context::Aggregation, + expr_arena, + )?; + + let groupby = GroupByMethod::from(agg.clone()); + let agg_type = AggregationType { + groupby, + allow_threading: false, + }; + Ok(Arc::new(AggregationExpr::new(input, agg_type, Some(field)))) + }, + } + }, + Cast { + expr, + dtype, + options, + } => { + let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(CastExpr { + input: phys_expr, + dtype: dtype.clone(), + expr: node_to_expr(expression, expr_arena), + options: *options, + inlined_eval: Default::default(), + })) + }, + Ternary { + predicate, + truthy, + falsy, + } => { + let is_scalar = is_scalar_ae(expression, expr_arena); + let mut lit_count = 0u8; + state.reset(); + let predicate = + create_physical_expr_inner(*predicate, ctxt, expr_arena, schema, state)?; + lit_count += state.local.has_lit as u8; + state.reset(); + let truthy = create_physical_expr_inner(*truthy, ctxt, expr_arena, schema, state)?; + lit_count += state.local.has_lit as u8; + state.reset(); + let falsy = create_physical_expr_inner(*falsy, ctxt, expr_arena, schema, state)?; + lit_count += state.local.has_lit as u8; + Ok(Arc::new(TernaryExpr::new( + predicate, + truthy, + falsy, + node_to_expr(expression, expr_arena), + state.allow_threading && lit_count < 2, + is_scalar, + ))) + }, + AnonymousFunction { + input, + function, + output_type: _, + options, + } => { + let is_scalar = is_scalar_ae(expression, expr_arena); + let output_field = expr_arena + .get(expression) + .to_field(schema, ctxt, expr_arena)?; + + let input = + create_physical_expressions_from_irs(input, ctxt, expr_arena, schema, state)?; + + Ok(Arc::new(ApplyExpr::new( + input, + function.clone().materialize()?, + node_to_expr(expression, expr_arena), + *options, + state.allow_threading, + schema.clone(), + output_field, + is_scalar, + ))) + }, + Function { + input, + function, + options, + } => { + let is_scalar = is_scalar_ae(expression, expr_arena); + let output_field = expr_arena + .get(expression) + .to_field(schema, ctxt, expr_arena)?; + let input = + create_physical_expressions_from_irs(input, ctxt, expr_arena, schema, state)?; + + Ok(Arc::new(ApplyExpr::new( + input, + function.clone().into(), + node_to_expr(expression, expr_arena), + *options, + state.allow_threading, + schema.clone(), + output_field, + is_scalar, + ))) + }, + Slice { + input, + offset, + length, + } => { + let input = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?; + let offset = create_physical_expr_inner(*offset, ctxt, expr_arena, schema, state)?; + let length = create_physical_expr_inner(*length, ctxt, expr_arena, schema, state)?; + polars_ensure!(!(state.has_implode() && matches!(ctxt, Context::Aggregation)), InvalidOperation: "'implode' followed by a slice during aggregation is not allowed"); + Ok(Arc::new(SliceExpr { + input, + offset, + length, + expr: node_to_expr(expression, expr_arena), + })) + }, + Explode(expr) => { + let input = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; + let function = SpecialEq::new(Arc::new( + move |c: &mut [polars_core::frame::column::Column]| c[0].explode().map(Some), + ) as Arc); + + let field = expr_arena + .get(expression) + .to_field(schema, ctxt, expr_arena)?; + + Ok(Arc::new(ApplyExpr::new( + vec![input], + function, + node_to_expr(expression, expr_arena), + FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + ..Default::default() + }, + state.allow_threading, + schema.clone(), + field, + false, + ))) + }, + Alias(input, name) => { + let phys_expr = create_physical_expr_inner(*input, ctxt, expr_arena, schema, state)?; + Ok(Arc::new(AliasExpr::new( + phys_expr, + name.clone(), + node_to_expr(*input, expr_arena), + ))) + }, + } +} diff --git a/crates/polars-expr/src/prelude.rs b/crates/polars-expr/src/prelude.rs new file mode 100644 index 000000000000..36c336c857bb --- /dev/null +++ b/crates/polars-expr/src/prelude.rs @@ -0,0 +1,2 @@ +pub use crate::expressions::*; +pub use crate::state::*; diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs new file mode 100644 index 000000000000..700adb4abf56 --- /dev/null +++ b/crates/polars-expr/src/reduce/convert.rs @@ -0,0 +1,78 @@ +// use polars_core::error::feature_gated; +use polars_plan::prelude::*; +use polars_utils::arena::{Arena, Node}; + +use super::*; +use crate::reduce::count::CountReduce; +use crate::reduce::first_last::{new_first_reduction, new_last_reduction}; +use crate::reduce::len::LenReduce; +use crate::reduce::mean::new_mean_reduction; +use crate::reduce::min_max::{new_max_reduction, new_min_reduction}; +use crate::reduce::sum::new_sum_reduction; +use crate::reduce::var_std::new_var_std_reduction; + +/// Converts a node into a reduction + its associated selector expression. +pub fn into_reduction( + node: Node, + expr_arena: &mut Arena, + schema: &Schema, +) -> PolarsResult<(Box, Node)> { + let get_dt = |node| { + expr_arena + .get(node) + .to_dtype(schema, Context::Default, expr_arena)? + .materialize_unknown(false) + }; + let out = match expr_arena.get(node) { + AExpr::Agg(agg) => match agg { + IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?), *input), + IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?), *input), + IRAggExpr::Min { + propagate_nans, + input, + } => (new_min_reduction(get_dt(*input)?, *propagate_nans), *input), + IRAggExpr::Max { + propagate_nans, + input, + } => (new_max_reduction(get_dt(*input)?, *propagate_nans), *input), + IRAggExpr::Var(input, ddof) => { + (new_var_std_reduction(get_dt(*input)?, false, *ddof), *input) + }, + IRAggExpr::Std(input, ddof) => { + (new_var_std_reduction(get_dt(*input)?, true, *ddof), *input) + }, + IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input), + IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input), + IRAggExpr::Count(input, include_nulls) => { + let count = Box::new(CountReduce::new(*include_nulls)) as Box<_>; + (count, *input) + }, + IRAggExpr::Quantile { .. } => todo!(), + IRAggExpr::Median(_) => todo!(), + IRAggExpr::NUnique(_) => todo!(), + IRAggExpr::Implode(_) => todo!(), + IRAggExpr::AggGroups(_) => todo!(), + }, + AExpr::Len => { + if let Some(first_column) = schema.iter_names().next() { + let out: Box = Box::new(LenReduce::default()); + let expr = expr_arena.add(AExpr::Column(first_column.as_str().into())); + + (out, expr) + } else { + // Support len aggregation on 0-width morsels. + // Notes: + // * We do this instead of projecting a scalar, because scalar literals don't + // project to the height of the DataFrame (in the PhysicalExpr impl). + // * This approach is not sound for `update_groups()`, but currently that case is + // not hit (it would need group-by -> len on empty morsels). + let out: Box = new_sum_reduction(DataType::new_idxsize()); + let expr = expr_arena.add(AExpr::Len); + + (out, expr) + } + }, + _ => unreachable!(), + }; + Ok(out) +} diff --git a/crates/polars-expr/src/reduce/count.rs b/crates/polars-expr/src/reduce/count.rs new file mode 100644 index 000000000000..50268beafadd --- /dev/null +++ b/crates/polars-expr/src/reduce/count.rs @@ -0,0 +1,168 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use polars_core::error::constants::LENGTH_LIMIT_MSG; + +use super::*; + +pub struct CountReduce { + counts: Vec, + evicted_counts: Vec, + include_nulls: bool, +} + +impl CountReduce { + pub fn new(include_nulls: bool) -> Self { + Self { + counts: Vec::new(), + evicted_counts: Vec::new(), + include_nulls, + } + } +} + +impl GroupedReduction for CountReduce { + fn new_empty(&self) -> Box { + Box::new(Self::new(self.include_nulls)) + } + + fn reserve(&mut self, additional: usize) { + self.counts.reserve(additional); + } + + fn resize(&mut self, num_groups: IdxSize) { + self.counts.resize(num_groups as usize, 0); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + _seq_id: u64, + ) -> PolarsResult<()> { + let mut count = values.len(); + if !self.include_nulls { + count -= values.null_count(); + } + self.counts[group_idx as usize] += count as u64; + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + _seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.len() == group_idxs.len()); + let values = values.as_materialized_series(); // @scalar-opt + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + let mut offset = 0; + for chunk in values.chunks() { + let gs = &group_idxs[offset..offset + chunk.len()]; + offset += chunk.len(); + + if chunk.has_nulls() && !self.include_nulls { + let validity = chunk.validity().unwrap(); + for (g, v) in gs.iter().zip(validity.iter()) { + *self.counts.get_unchecked_mut(*g as usize) += v as u64; + } + } else { + for g in gs { + *self.counts.get_unchecked_mut(*g as usize) += 1; + } + } + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + _seq_id: u64, + ) -> PolarsResult<()> { + assert!(subset.len() == group_idxs.len()); + let values = values.as_materialized_series(); // @scalar-opt + let chunks = values.chunks(); + assert!(chunks.len() == 1); + let arr = &*chunks[0]; + if arr.has_nulls() && !self.include_nulls { + let valid = arr.validity().unwrap(); + for (i, g) in subset.iter().zip(group_idxs) { + let grp = self.counts.get_unchecked_mut(g.idx()); + if g.should_evict() { + self.evicted_counts.push(*grp); + *grp = 0; + } + *grp += valid.get_bit_unchecked(*i as usize) as u64; + } + } else { + for (_, g) in subset.iter().zip(group_idxs) { + let grp = self.counts.get_unchecked_mut(g.idx()); + if g.should_evict() { + self.evicted_counts.push(*grp); + *grp = 0; + } + *grp += 1; + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(other.counts.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(other.counts.iter()) { + *self.counts.get_unchecked_mut(*g as usize) += v; + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(subset.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + *self.counts.get_unchecked_mut(*g as usize) += + *other.counts.get_unchecked(*i as usize); + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + counts: core::mem::take(&mut self.evicted_counts), + evicted_counts: Vec::new(), + include_nulls: self.include_nulls, + }) + } + + fn finalize(&mut self) -> PolarsResult { + let ca: IdxCa = self + .counts + .drain(..) + .map(|l| IdxSize::try_from(l).expect(LENGTH_LIMIT_MSG)) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/first_last.rs b/crates/polars-expr/src/reduce/first_last.rs new file mode 100644 index 000000000000..fade0375bc59 --- /dev/null +++ b/crates/polars-expr/src/reduce/first_last.rs @@ -0,0 +1,409 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::marker::PhantomData; + +use polars_core::frame::row::AnyValueBufferTrusted; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub fn new_first_reduction(dtype: DataType) -> Box { + new_reduction_with_policy::(dtype) +} + +pub fn new_last_reduction(dtype: DataType) -> Box { + new_reduction_with_policy::(dtype) +} + +fn new_reduction_with_policy(dtype: DataType) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VecGroupedReduction::new( + dtype, + BoolFirstLastReducer::

(PhantomData), + )), + _ if dtype.is_primitive_numeric() || dtype.is_temporal() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, NumFirstLastReducer::(PhantomData))) + }) + }, + String | Binary => Box::new(VecGroupedReduction::new( + dtype, + BinaryFirstLastReducer::

(PhantomData), + )), + _ => Box::new(GenericFirstLastGroupedReduction::

::new(dtype)), + } +} + +trait Policy: Send + Sync + 'static { + fn index(len: usize) -> usize; + fn should_replace(new: u64, old: u64) -> bool; +} + +struct First; +impl Policy for First { + fn index(_len: usize) -> usize { + 0 + } + + fn should_replace(new: u64, old: u64) -> bool { + // Subtracting 1 with wrapping leaves all order unchanged, except it + // makes 0 (no value) the largest possible. + new.wrapping_sub(1) < old.wrapping_sub(1) + } +} + +struct Last; +impl Policy for Last { + fn index(len: usize) -> usize { + len - 1 + } + + fn should_replace(new: u64, old: u64) -> bool { + new >= old + } +} + +#[expect(dead_code)] +struct Arbitrary; +impl Policy for Arbitrary { + fn index(_len: usize) -> usize { + 0 + } + + fn should_replace(_new: u64, old: u64) -> bool { + old == 0 + } +} + +struct NumFirstLastReducer(PhantomData<(P, T)>); + +impl Clone for NumFirstLastReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Reducer for NumFirstLastReducer +where + P: Policy, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + type Dtype = T; + type Value = (Option, u64); + + fn init(&self) -> Self::Value { + (None, 0) + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + s.to_physical_repr() + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + if P::should_replace(b.1, a.1) { + *a = *b; + } + } + + fn reduce_one(&self, a: &mut Self::Value, b: Option, seq_id: u64) { + if P::should_replace(seq_id, a.1) { + *a = (b, seq_id); + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64) { + if !ca.is_empty() && P::should_replace(seq_id, v.1) { + let val = ca.get(P::index(ca.len())); + *v = (val, seq_id); + } + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: ChunkedArray = v.into_iter().map(|(x, _s)| x).collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype) + } +} + +struct BinaryFirstLastReducer

(PhantomData

); + +impl

Clone for BinaryFirstLastReducer

{ + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +fn replace_opt_bytes(l: &mut Option>, r: Option<&[u8]>) { + match (l, r) { + (Some(l), Some(r)) => { + l.clear(); + l.extend_from_slice(r); + }, + (l, r) => *l = r.map(|s| s.to_owned()), + } +} + +impl

Reducer for BinaryFirstLastReducer

+where + P: Policy, +{ + type Dtype = BinaryType; + type Value = (Option>, u64); + + fn init(&self) -> Self::Value { + (None, 0) + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Owned(s.cast(&DataType::Binary).unwrap()) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + if P::should_replace(b.1, a.1) { + a.0.clone_from(&b.0); + a.1 = b.1; + } + } + + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>, seq_id: u64) { + if P::should_replace(seq_id, a.1) { + replace_opt_bytes(&mut a.0, b); + a.1 = seq_id; + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64) { + if !ca.is_empty() && P::should_replace(seq_id, v.1) { + replace_opt_bytes(&mut v.0, ca.get(P::index(ca.len()))); + v.1 = seq_id; + } + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: BinaryChunked = v.into_iter().map(|(x, _s)| x).collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype) + } +} + +struct BoolFirstLastReducer

(PhantomData

); + +impl

Clone for BoolFirstLastReducer

{ + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl

Reducer for BoolFirstLastReducer

+where + P: Policy, +{ + type Dtype = BooleanType; + type Value = (Option, u64); + + fn init(&self) -> Self::Value { + (None, 0) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + if P::should_replace(b.1, a.1) { + *a = *b; + } + } + + fn reduce_one(&self, a: &mut Self::Value, b: Option, seq_id: u64) { + if P::should_replace(seq_id, a.1) { + a.0 = b; + a.1 = seq_id; + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64) { + if !ca.is_empty() && P::should_replace(seq_id, v.1) { + v.0 = ca.get(P::index(ca.len())); + v.1 = seq_id; + } + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: BooleanChunked = v.into_iter().map(|(x, _s)| x).collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} + +pub struct GenericFirstLastGroupedReduction

{ + in_dtype: DataType, + values: Vec>, + seqs: Vec, + evicted_values: Vec>, + evicted_seqs: Vec, + policy: PhantomData P>, +} + +impl

GenericFirstLastGroupedReduction

{ + fn new(in_dtype: DataType) -> Self { + Self { + in_dtype, + values: Vec::new(), + seqs: Vec::new(), + evicted_values: Vec::new(), + evicted_seqs: Vec::new(), + policy: PhantomData, + } + } +} + +impl GroupedReduction for GenericFirstLastGroupedReduction

{ + fn new_empty(&self) -> Box { + Box::new(Self::new(self.in_dtype.clone())) + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + self.seqs.reserve(additional); + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, AnyValue::Null); + self.seqs.resize(num_groups as usize, 0); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + seq_id: u64, + ) -> PolarsResult<()> { + if !values.is_empty() { + let seq_id = seq_id + 1; // We use 0 for 'no value'. + if P::should_replace(seq_id, self.seqs[group_idx as usize]) { + self.values[group_idx as usize] = values.get(P::index(values.len()))?.into_static(); + self.seqs[group_idx as usize] = seq_id; + } + } + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + seq_id: u64, + ) -> PolarsResult<()> { + let seq_id = seq_id + 1; // We use 0 for 'no value'. + for (i, g) in group_idxs.iter().enumerate() { + if P::should_replace(seq_id, *self.seqs.get_unchecked(*g as usize)) { + *self.values.get_unchecked_mut(*g as usize) = values.get_unchecked(i).into_static(); + *self.seqs.get_unchecked_mut(*g as usize) = seq_id; + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + seq_id: u64, + ) -> PolarsResult<()> { + let seq_id = seq_id + 1; // We use 0 for 'no value'. + for (i, g) in subset.iter().zip(group_idxs) { + let grp_val = self.values.get_unchecked_mut(g.idx()); + let grp_seq = self.seqs.get_unchecked_mut(g.idx()); + if g.should_evict() { + self.evicted_values + .push(core::mem::replace(grp_val, AnyValue::Null)); + self.evicted_seqs.push(core::mem::replace(grp_seq, 0)); + } + if P::should_replace(seq_id, *grp_seq) { + *grp_val = values.get_unchecked(*i as usize).into_static(); + *grp_seq = seq_id; + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + for (i, g) in group_idxs.iter().enumerate() { + if P::should_replace( + *other.seqs.get_unchecked(i), + *self.seqs.get_unchecked(*g as usize), + ) { + *self.values.get_unchecked_mut(*g as usize) = other.values.get_unchecked(i).clone(); + *self.seqs.get_unchecked_mut(*g as usize) = *other.seqs.get_unchecked(i); + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + for (i, g) in group_idxs.iter().enumerate() { + let si = *subset.get_unchecked(i) as usize; + if P::should_replace( + *other.seqs.get_unchecked(si), + *self.seqs.get_unchecked(*g as usize), + ) { + *self.values.get_unchecked_mut(*g as usize) = + other.values.get_unchecked(si).clone(); + *self.seqs.get_unchecked_mut(*g as usize) = *other.seqs.get_unchecked(si); + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + in_dtype: self.in_dtype.clone(), + values: core::mem::take(&mut self.evicted_values), + seqs: core::mem::take(&mut self.evicted_seqs), + evicted_values: Vec::new(), + evicted_seqs: Vec::new(), + policy: PhantomData, + }) + } + + fn finalize(&mut self) -> PolarsResult { + self.seqs.clear(); + unsafe { + let mut buf = AnyValueBufferTrusted::new(&self.in_dtype, self.values.len()); + for v in core::mem::take(&mut self.values) { + buf.add_unchecked_owned_physical(&v); + } + Ok(buf.into_series()) + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/len.rs b/crates/polars-expr/src/reduce/len.rs new file mode 100644 index 000000000000..89c0f15d2310 --- /dev/null +++ b/crates/polars-expr/src/reduce/len.rs @@ -0,0 +1,125 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use polars_core::error::constants::LENGTH_LIMIT_MSG; + +use super::*; + +#[derive(Default)] +pub struct LenReduce { + groups: Vec, + evictions: Vec, +} + +impl GroupedReduction for LenReduce { + fn new_empty(&self) -> Box { + Box::new(Self::default()) + } + + fn reserve(&mut self, additional: usize) { + self.groups.reserve(additional); + } + + fn resize(&mut self, num_groups: IdxSize) { + self.groups.resize(num_groups as usize, 0); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + _seq_id: u64, + ) -> PolarsResult<()> { + self.groups[group_idx as usize] += values.len() as u64; + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + _seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for g in group_idxs.iter() { + *self.groups.get_unchecked_mut(*g as usize) += 1; + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + _values: &Column, + _subset: &[IdxSize], + group_idxs: &[EvictIdx], + _seq_id: u64, + ) -> PolarsResult<()> { + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for g in group_idxs.iter() { + let grp = self.groups.get_unchecked_mut(g.idx()); + if g.should_evict() { + self.evictions.push(*grp); + *grp = 0; + } + *grp += 1; + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(other.groups.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(other.groups.iter()) { + *self.groups.get_unchecked_mut(*g as usize) += v; + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(subset.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + *self.groups.get_unchecked_mut(*g as usize) += + *other.groups.get_unchecked(*i as usize); + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + groups: core::mem::take(&mut self.evictions), + evictions: Vec::new(), + }) + } + + fn finalize(&mut self) -> PolarsResult { + let ca: IdxCa = self + .groups + .drain(..) + .map(|l| IdxSize::try_from(l).expect(LENGTH_LIMIT_MSG)) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs new file mode 100644 index 000000000000..a5c666b9dd7a --- /dev/null +++ b/crates/polars-expr/src/reduce/mean.rs @@ -0,0 +1,168 @@ +use std::marker::PhantomData; + +use num_traits::{AsPrimitive, Zero}; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub fn new_mean_reduction(dtype: DataType) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VGR::new(dtype, BoolMeanReducer)), + _ if dtype.is_primitive_numeric() || dtype.is_temporal() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, NumMeanReducer::<$T>(PhantomData))) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VGR::new(dtype, NumMeanReducer::(PhantomData))), + + // For compatibility with the current engine, should probably be an error. + String | Binary => Box::new(super::NullGroupedReduction::new(dtype)), + + _ => unimplemented!("{dtype:?} is not supported by mean reduction"), + } +} + +fn finish_output(values: Vec<(f64, usize)>, dtype: &DataType) -> Series { + match dtype { + DataType::Float32 => { + let ca: Float32Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| (s / c as f64) as f32)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series() + }, + dt if dt.is_primitive_numeric() => { + let ca: Float64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| s / c as f64)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series() + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_prec, scale) => { + let inv_scale_factor = 1.0 / 10u128.pow(scale.unwrap() as u32) as f64; + let ca: Float64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| s / c as f64 * inv_scale_factor)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series() + }, + #[cfg(feature = "dtype-datetime")] + DataType::Date => { + const MS_IN_DAY: i64 = 86_400_000; + let ca: Int64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| (s / c as f64 * MS_IN_DAY as f64) as i64)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_datetime(TimeUnit::Milliseconds, None).into_series() + }, + DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time => { + let ca: Int64Chunked = values + .into_iter() + .map(|(s, c)| (c != 0).then(|| (s / c as f64) as i64)) + .collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype).unwrap() + }, + _ => unimplemented!(), + } +} + +struct NumMeanReducer(PhantomData); +impl Clone for NumMeanReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Reducer for NumMeanReducer +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg + IntoSeries, +{ + type Dtype = T; + type Value = (f64, usize); + + #[inline(always)] + fn init(&self) -> Self::Value { + (0.0, 0) + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + s.to_physical_repr() + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option, _seq_id: u64) { + a.0 += b.unwrap_or(T::Native::zero()).as_(); + a.1 += b.is_some() as usize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + v.0 += ChunkAgg::_sum_as_f64(ca); + v.1 += ca.len() - ca.null_count(); + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + Ok(finish_output(v, dtype)) + } +} + +#[derive(Clone)] +struct BoolMeanReducer; + +impl Reducer for BoolMeanReducer { + type Dtype = BooleanType; + type Value = (usize, usize); + + #[inline(always)] + fn init(&self) -> Self::Value { + (0, 0) + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option, _seq_id: u64) { + a.0 += b.unwrap_or(false) as usize; + a.1 += b.is_some() as usize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + v.0 += ca.sum().unwrap_or(0) as usize; + v.1 += ca.len() - ca.null_count(); + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + assert!(dtype == &DataType::Boolean); + let ca: Float64Chunked = v + .into_iter() + .map(|(s, c)| (c != 0).then(|| s as f64 / c as f64)) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} diff --git a/crates/polars-expr/src/reduce/min_max.rs b/crates/polars-expr/src/reduce/min_max.rs new file mode 100644 index 000000000000..af3e3fa4eefd --- /dev/null +++ b/crates/polars-expr/src/reduce/min_max.rs @@ -0,0 +1,603 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::borrow::Cow; +use std::marker::PhantomData; + +use arrow::array::BooleanArray; +use arrow::bitmap::Bitmap; +use num_traits::Bounded; +use polars_core::with_match_physical_integer_polars_type; +#[cfg(feature = "propagate_nans")] +use polars_ops::prelude::nan_propagating_aggregate::ca_nan_agg; +use polars_utils::float::IsFloat; +use polars_utils::min_max::MinMax; + +use super::*; + +pub fn new_min_reduction(dtype: DataType, propagate_nans: bool) -> Box { + use DataType::*; + use VecMaskGroupedReduction as VMGR; + match dtype { + Boolean => Box::new(BoolMinGroupedReduction::default()), + #[cfg(feature = "propagate_nans")] + Float32 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + #[cfg(feature = "propagate_nans")] + Float64 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + Float32 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + Float64 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + String | Binary => Box::new(VecGroupedReduction::new(dtype, BinaryMinReducer)), + _ if dtype.is_integer() || dtype.is_temporal() => { + with_match_physical_integer_polars_type!(dtype.to_physical(), |$T| { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VMGR::new(dtype, NumReducer::>::new())), + _ => unimplemented!(), + } +} + +pub fn new_max_reduction(dtype: DataType, propagate_nans: bool) -> Box { + use DataType::*; + use VecMaskGroupedReduction as VMGR; + match dtype { + Boolean => Box::new(BoolMaxGroupedReduction::default()), + #[cfg(feature = "propagate_nans")] + Float32 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + #[cfg(feature = "propagate_nans")] + Float64 if propagate_nans => { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }, + Float32 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + Float64 => Box::new(VMGR::new(dtype, NumReducer::>::new())), + String | Binary => Box::new(VecGroupedReduction::new(dtype, BinaryMaxReducer)), + _ if dtype.is_integer() || dtype.is_temporal() => { + with_match_physical_integer_polars_type!(dtype.to_physical(), |$T| { + Box::new(VMGR::new(dtype, NumReducer::>::new())) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VMGR::new(dtype, NumReducer::>::new())), + _ => unimplemented!(), + } +} + +// These two variants ignore nans. +struct Min(PhantomData); +struct Max(PhantomData); + +// These two variants propagate nans. +#[cfg(feature = "propagate_nans")] +struct NanMin(PhantomData); +#[cfg(feature = "propagate_nans")] +struct NanMax(PhantomData); + +impl NumericReduction for Min +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + if T::Native::is_float() { + T::Native::nan_value() + } else { + T::Native::max_value() + } + } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::min_ignore_nan(a, b) + } + + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ChunkAgg::min(ca) + } +} + +impl NumericReduction for Max +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + if T::Native::is_float() { + T::Native::nan_value() + } else { + T::Native::min_value() + } + } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::max_ignore_nan(a, b) + } + + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ChunkAgg::max(ca) + } +} + +#[cfg(feature = "propagate_nans")] +impl NumericReduction for NanMin { + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + T::Native::max_value() + } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::min_propagate_nan(a, b) + } + + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ca_nan_agg(ca, MinMax::min_propagate_nan) + } +} + +#[cfg(feature = "propagate_nans")] +impl NumericReduction for NanMax { + type Dtype = T; + + #[inline(always)] + fn init() -> T::Native { + T::Native::min_value() + } + + #[inline(always)] + fn combine(a: T::Native, b: T::Native) -> T::Native { + MinMax::max_propagate_nan(a, b) + } + + #[inline(always)] + fn reduce_ca(ca: &ChunkedArray) -> Option { + ca_nan_agg(ca, MinMax::max_propagate_nan) + } +} + +#[derive(Clone)] +struct BinaryMinReducer; +#[derive(Clone)] +struct BinaryMaxReducer; + +impl Reducer for BinaryMinReducer { + type Dtype = BinaryType; + type Value = Option>; // TODO: evaluate SmallVec. + + fn init(&self) -> Self::Value { + None + } + + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Owned(s.cast(&DataType::Binary).unwrap()) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + self.reduce_one(a, b.as_deref(), 0) + } + + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>, _seq_id: u64) { + match (a, b) { + (_, None) => {}, + (l @ None, Some(r)) => *l = Some(r.to_owned()), + (Some(l), Some(r)) => { + if l.as_slice() > r { + l.clear(); + l.extend_from_slice(r); + } + }, + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &BinaryChunked, _seq_id: u64) { + self.reduce_one(v, ca.min_binary(), 0) + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: BinaryChunked = v.into_iter().collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype) + } +} + +impl Reducer for BinaryMaxReducer { + type Dtype = BinaryType; + type Value = Option>; // TODO: evaluate SmallVec. + + #[inline(always)] + fn init(&self) -> Self::Value { + None + } + + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Owned(s.cast(&DataType::Binary).unwrap()) + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + self.reduce_one(a, b.as_deref(), 0) + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>, _seq_id: u64) { + match (a, b) { + (_, None) => {}, + (l @ None, Some(r)) => *l = Some(r.to_owned()), + (Some(l), Some(r)) => { + if l.as_slice() < r { + l.clear(); + l.extend_from_slice(r); + } + }, + } + } + + #[inline(always)] + fn reduce_ca(&self, v: &mut Self::Value, ca: &BinaryChunked, _seq_id: u64) { + self.reduce_one(v, ca.max_binary(), 0) + } + + #[inline(always)] + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); // This should only be used with VecGroupedReduction. + let ca: BinaryChunked = v.into_iter().collect_ca(PlSmallStr::EMPTY); + ca.into_series().cast(dtype) + } +} + +#[derive(Default)] +pub struct BoolMinGroupedReduction { + values: MutableBitmap, + mask: MutableBitmap, + evicted_values: BitmapBuilder, + evicted_mask: BitmapBuilder, +} + +impl GroupedReduction for BoolMinGroupedReduction { + fn new_empty(&self) -> Box { + Box::new(Self::default()) + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + self.mask.reserve(additional) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, true); + self.mask.resize(num_groups as usize, false); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + _seq_id: u64, + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + let values = values.as_materialized_series_maintain_scalar(); + let ca: &BooleanChunked = values.as_ref().as_ref(); + if !ca.all() { + self.values.set(group_idx as usize, false); + } + if ca.len() != ca.null_count() { + self.mask.set(group_idx as usize, true); + } + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + _seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &DataType::Boolean); + assert!(values.len() == group_idxs.len()); + let values = values.as_materialized_series(); // @scalar-opt + let ca: &BooleanChunked = values.as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + self.values + .and_pos_unchecked(*g as usize, ov.unwrap_or(true)); + self.mask.or_pos_unchecked(*g as usize, ov.is_some()); + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + _seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &DataType::Boolean); + assert!(subset.len() == group_idxs.len()); + let values = values.as_materialized_series(); // @scalar-opt + let ca: &BooleanChunked = values.as_ref().as_ref(); + let arr = ca.downcast_as_array(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + let ov = arr.get_unchecked(*i as usize); + if g.should_evict() { + self.evicted_values.push(self.values.get_unchecked(g.idx())); + self.evicted_mask.push(self.mask.get_unchecked(g.idx())); + self.values.set_unchecked(g.idx(), ov.unwrap_or(true)); + self.mask.set_unchecked(g.idx(), ov.is_some()); + } else { + self.values.and_pos_unchecked(g.idx(), ov.unwrap_or(true)); + self.mask.or_pos_unchecked(g.idx(), ov.is_some()); + } + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.values.len() == other.values.len()); + assert!(self.mask.len() == other.mask.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, (v, o)) in group_idxs + .iter() + .zip(other.values.iter().zip(other.mask.iter())) + { + self.values.and_pos_unchecked(*g as usize, v); + self.mask.or_pos_unchecked(*g as usize, o); + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(subset.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + self.values + .and_pos_unchecked(*g as usize, other.values.get_unchecked(*i as usize)); + self.mask + .or_pos_unchecked(*g as usize, other.mask.get_unchecked(*i as usize)); + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + values: core::mem::take(&mut self.evicted_values).into_mut(), + mask: core::mem::take(&mut self.evicted_mask).into_mut(), + evicted_values: BitmapBuilder::new(), + evicted_mask: BitmapBuilder::new(), + }) + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + let m = core::mem::take(&mut self.mask); + let arr = BooleanArray::from(v.freeze()) + .with_validity(Some(m.freeze())) + .boxed(); + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + &DataType::Boolean, + ) + }) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[derive(Default)] +pub struct BoolMaxGroupedReduction { + values: MutableBitmap, + mask: MutableBitmap, + evicted_values: BitmapBuilder, + evicted_mask: BitmapBuilder, +} + +impl GroupedReduction for BoolMaxGroupedReduction { + fn new_empty(&self) -> Box { + Box::new(Self::default()) + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + self.mask.reserve(additional) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, false); + self.mask.resize(num_groups as usize, false); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + _seq_id: u64, + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + let values = values.as_materialized_series_maintain_scalar(); + let ca: &BooleanChunked = values.as_ref().as_ref(); + if ca.any() { + self.values.set(group_idx as usize, true); + } + if ca.len() != ca.null_count() { + self.mask.set(group_idx as usize, true); + } + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + _seq_id: u64, + ) -> PolarsResult<()> { + // TODO: we should really implement a sum-as-other-type operation instead + // of doing this materialized cast. + assert!(values.dtype() == &DataType::Boolean); + assert!(values.len() == group_idxs.len()); + let values = values.as_materialized_series(); // @scalar-opt + let ca: &BooleanChunked = values.as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + self.values + .or_pos_unchecked(*g as usize, ov.unwrap_or(false)); + self.mask.or_pos_unchecked(*g as usize, ov.is_some()); + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + _seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &DataType::Boolean); + assert!(subset.len() == group_idxs.len()); + let values = values.as_materialized_series(); // @scalar-opt + let ca: &BooleanChunked = values.as_ref().as_ref(); + let arr = ca.downcast_as_array(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + let ov = arr.get_unchecked(*i as usize); + if g.should_evict() { + self.evicted_values.push(self.values.get_unchecked(g.idx())); + self.evicted_mask.push(self.mask.get_unchecked(g.idx())); + self.values.set_unchecked(g.idx(), ov.unwrap_or(false)); + self.mask.set_unchecked(g.idx(), ov.is_some()); + } else { + self.values.or_pos_unchecked(g.idx(), ov.unwrap_or(false)); + self.mask.or_pos_unchecked(g.idx(), ov.is_some()); + } + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(other.values.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, (v, o)) in group_idxs + .iter() + .zip(other.values.iter().zip(other.mask.iter())) + { + self.values.or_pos_unchecked(*g as usize, v); + self.mask.or_pos_unchecked(*g as usize, o); + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(subset.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + self.values + .or_pos_unchecked(*g as usize, other.values.get_unchecked(*i as usize)); + self.mask + .or_pos_unchecked(*g as usize, other.mask.get_unchecked(*i as usize)); + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + values: core::mem::take(&mut self.evicted_values).into_mut(), + mask: core::mem::take(&mut self.evicted_mask).into_mut(), + evicted_values: BitmapBuilder::new(), + evicted_mask: BitmapBuilder::new(), + }) + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + let m = core::mem::take(&mut self.mask); + let arr = BooleanArray::from(v.freeze()) + .with_validity(Some(m.freeze())) + .boxed(); + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + &DataType::Boolean, + ) + }) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs new file mode 100644 index 000000000000..b091f2866a8c --- /dev/null +++ b/crates/polars-expr/src/reduce/mod.rs @@ -0,0 +1,702 @@ +#![allow(unsafe_op_in_unsafe_fn)] +mod convert; +mod count; +mod first_last; +mod len; +mod mean; +mod min_max; +mod sum; +mod var_std; + +use std::any::Any; +use std::borrow::Cow; +use std::marker::PhantomData; + +use arrow::array::{Array, PrimitiveArray, StaticArray}; +use arrow::bitmap::{Bitmap, BitmapBuilder, MutableBitmap}; +pub use convert::into_reduction; +use polars_core::prelude::*; + +use crate::EvictIdx; + +/// A reduction with groups. +/// +/// Each group has its own reduction state that values can be aggregated into. +pub trait GroupedReduction: Any + Send + Sync { + /// Returns a new empty reduction. + fn new_empty(&self) -> Box; + + /// Reserves space in this GroupedReduction for an additional number of groups. + fn reserve(&mut self, additional: usize); + + /// Resizes this GroupedReduction to the given number of groups. + /// + /// While not an actual member of the trait, the safety preconditions below + /// refer to self.num_groups() as given by the last call of this function. + fn resize(&mut self, num_groups: IdxSize); + + /// Updates the specified group with the given values. + /// + /// For order-sensitive grouped reductions, seq_id can be used to resolve + /// order between calls/multiple reductions. + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + seq_id: u64, + ) -> PolarsResult<()>; + + /// Updates this GroupedReduction with new values. values[i] should + /// be added to reduction self[group_idxs[i]]. For order-sensitive grouped + /// reductions, seq_id can be used to resolve order between calls/multiple + /// reductions. + /// + /// # Safety + /// group_idxs[i] < self.num_groups() for all i. + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + seq_id: u64, + ) -> PolarsResult<()>; + + /// Updates this GroupedReduction with new values. values[subset[i]] should + /// be added to reduction self[group_idxs[i]]. For order-sensitive grouped + /// reductions, seq_id can be used to resolve order between calls/multiple + /// reductions. + /// + /// # Safety + /// The subset and group_idxs are in-bounds. + unsafe fn update_groups_subset( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[IdxSize], + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.len() < (1 << (IdxSize::BITS - 1))); + let evict_group_idxs = core::mem::transmute::<&[IdxSize], &[EvictIdx]>(group_idxs); + self.update_groups_while_evicting(values, subset, evict_group_idxs, seq_id) + } + + /// Updates this GroupedReduction with new values. values[subset[i]] should + /// be added to reduction self[group_idxs[i]]. For order-sensitive grouped + /// reductions, seq_id can be used to resolve order between calls/multiple + /// reductions. If the group_idxs[i] has its evict bit set the current value + /// in the group should be evicted and reset before updating. + /// + /// # Safety + /// The subset and group_idxs are in-bounds. + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + seq_id: u64, + ) -> PolarsResult<()>; + + /// Combines this GroupedReduction with another. Group other[i] + /// should be combined into group self[group_idxs[i]]. + /// + /// # Safety + /// group_idxs[i] < self.num_groups() for all i. + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()>; + + /// Combines this GroupedReduction with another. Group other[subset[i]] + /// should be combined into group self[group_idxs[i]]. + /// + /// # Safety + /// subset[i] < other.num_groups() for all i. + /// group_idxs[i] < self.num_groups() for all i. + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()>; + + /// Take the accumulated evicted groups. + fn take_evictions(&mut self) -> Box; + + /// Returns the finalized value per group as a Series. + /// + /// After this operation the number of groups is reset to 0. + fn finalize(&mut self) -> PolarsResult; + + /// Returns this GroupedReduction as a dyn Any. + fn as_any(&self) -> &dyn Any; +} + +// Helper traits used in the VecGroupedReduction and VecMaskGroupedReduction to +// reduce code duplication. +pub trait Reducer: Send + Sync + Clone + 'static { + type Dtype: PolarsDataType; + type Value: Clone + Send + Sync + 'static; + fn init(&self) -> Self::Value; + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + Cow::Borrowed(s) + } + fn combine(&self, a: &mut Self::Value, b: &Self::Value); + fn reduce_one( + &self, + a: &mut Self::Value, + b: Option<::Physical<'_>>, + seq_id: u64, + ); + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64); + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult; +} + +pub trait NumericReduction: Send + Sync + 'static { + type Dtype: PolarsNumericType; + fn init() -> ::Native; + fn combine( + a: ::Native, + b: ::Native, + ) -> ::Native; + fn reduce_ca( + ca: &ChunkedArray, + ) -> Option<::Native>; +} + +struct NumReducer(PhantomData); +impl NumReducer { + fn new() -> Self { + Self(PhantomData) + } +} +impl Clone for NumReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Reducer for NumReducer { + type Dtype = ::Dtype; + type Value = <::Dtype as PolarsNumericType>::Native; + + #[inline(always)] + fn init(&self) -> Self::Value { + ::init() + } + + #[inline(always)] + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + s.to_physical_repr() + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + *a = ::combine(*a, *b); + } + + #[inline(always)] + fn reduce_one( + &self, + a: &mut Self::Value, + b: Option<::Physical<'_>>, + _seq_id: u64, + ) { + if let Some(b) = b { + *a = ::combine(*a, b); + } + } + + #[inline(always)] + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + if let Some(r) = ::reduce_ca(ca) { + *v = ::combine(*v, r); + } + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + let arr = Box::new(PrimitiveArray::::from_vec(v).with_validity(m)); + Ok(unsafe { Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![arr], dtype) }) + } +} + +pub struct VecGroupedReduction { + values: Vec, + evicted_values: Vec, + in_dtype: DataType, + reducer: R, +} + +impl VecGroupedReduction { + fn new(in_dtype: DataType, reducer: R) -> Self { + Self { + values: Vec::new(), + evicted_values: Vec::new(), + in_dtype, + reducer, + } + } +} + +impl GroupedReduction for VecGroupedReduction +where + R: Reducer, +{ + fn new_empty(&self) -> Box { + Box::new(Self { + values: Vec::new(), + evicted_values: Vec::new(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, self.reducer.init()); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + let seq_id = seq_id + 1; // So we can use 0 for 'none yet'. + let values = values.as_materialized_series(); // @scalar-opt + let values = self.reducer.cast_series(values); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + self.reducer + .reduce_ca(&mut self.values[group_idx as usize], ca, seq_id); + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + assert!(values.len() == group_idxs.len()); + let seq_id = seq_id + 1; // So we can use 0 for 'none yet'. + let values = values.as_materialized_series(); // @scalar-opt + let values = self.reducer.cast_series(values); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + if values.has_nulls() { + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, ov, seq_id); + } + } else { + let mut offset = 0; + for arr in ca.downcast_iter() { + let subgroup = &group_idxs[offset..offset + arr.len()]; + for (g, v) in subgroup.iter().zip(arr.values_iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, Some(v), seq_id); + } + offset += arr.len(); + } + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + assert!(subset.len() == group_idxs.len()); + let seq_id = seq_id + 1; // So we can use 0 for 'none yet'. + let values = values.as_materialized_series(); // @scalar-opt + let values = self.reducer.cast_series(values); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + let arr = ca.downcast_as_array(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + if values.has_nulls() { + for (i, g) in subset.iter().zip(group_idxs) { + let ov = arr.get_unchecked(*i as usize); + let grp = self.values.get_unchecked_mut(g.idx()); + if g.should_evict() { + let old = core::mem::replace(grp, self.reducer.init()); + self.evicted_values.push(old); + } + self.reducer.reduce_one(grp, ov, seq_id); + } + } else { + for (i, g) in subset.iter().zip(group_idxs) { + let v = arr.value_unchecked(*i as usize); + let grp = self.values.get_unchecked_mut(g.idx()); + if g.should_evict() { + let old = core::mem::replace(grp, self.reducer.init()); + self.evicted_values.push(old); + } + self.reducer.reduce_one(grp, Some(v), seq_id); + } + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.in_dtype == other.in_dtype); + assert!(group_idxs.len() == other.values.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, v) in group_idxs.iter().zip(other.values.iter()) { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.combine(grp, v); + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.in_dtype == other.in_dtype); + assert!(subset.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + let v = other.values.get_unchecked(*i as usize); + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.combine(grp, v); + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + values: core::mem::take(&mut self.evicted_values), + evicted_values: Vec::new(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + self.reducer.finish(v, None, &self.in_dtype) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +pub struct VecMaskGroupedReduction { + values: Vec, + mask: MutableBitmap, + evicted_values: Vec, + evicted_mask: BitmapBuilder, + in_dtype: DataType, + reducer: R, +} + +impl VecMaskGroupedReduction { + fn new(in_dtype: DataType, reducer: R) -> Self { + Self { + values: Vec::new(), + mask: MutableBitmap::new(), + evicted_values: Vec::new(), + evicted_mask: BitmapBuilder::new(), + in_dtype, + reducer, + } + } +} + +impl GroupedReduction for VecMaskGroupedReduction +where + R: Reducer, +{ + fn new_empty(&self) -> Box { + Box::new(Self::new(self.in_dtype.clone(), self.reducer.clone())) + } + + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + self.mask.reserve(additional) + } + + fn resize(&mut self, num_groups: IdxSize) { + self.values.resize(num_groups as usize, self.reducer.init()); + self.mask.resize(num_groups as usize, false); + } + + fn update_group( + &mut self, + values: &Column, + group_idx: IdxSize, + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + let seq_id = seq_id + 1; // So we can use 0 for 'none yet'. + let values = values.as_materialized_series(); // @scalar-opt + let values = values.to_physical_repr(); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + self.reducer + .reduce_ca(&mut self.values[group_idx as usize], ca, seq_id); + if ca.len() != ca.null_count() { + self.mask.set(group_idx as usize, true); + } + Ok(()) + } + + unsafe fn update_groups( + &mut self, + values: &Column, + group_idxs: &[IdxSize], + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + assert!(values.len() == group_idxs.len()); + let seq_id = seq_id + 1; // So we can use 0 for 'none yet'. + let values = values.as_materialized_series(); // @scalar-opt + let values = values.to_physical_repr(); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, ov) in group_idxs.iter().zip(ca.iter()) { + if let Some(v) = ov { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.reduce_one(grp, Some(v), seq_id); + self.mask.set_unchecked(*g as usize, true); + } + } + } + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + values: &Column, + subset: &[IdxSize], + group_idxs: &[EvictIdx], + seq_id: u64, + ) -> PolarsResult<()> { + assert!(values.dtype() == &self.in_dtype); + assert!(subset.len() == group_idxs.len()); + let seq_id = seq_id + 1; // So we can use 0 for 'none yet'. + let values = values.as_materialized_series(); // @scalar-opt + let values = values.to_physical_repr(); + let ca: &ChunkedArray = values.as_ref().as_ref().as_ref(); + let arr = ca.downcast_as_array(); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + let ov = arr.get_unchecked(*i as usize); + let grp = self.values.get_unchecked_mut(g.idx()); + if g.should_evict() { + self.evicted_values + .push(core::mem::replace(grp, self.reducer.init())); + self.evicted_mask.push(self.mask.get_unchecked(g.idx())); + self.mask.set_unchecked(g.idx(), false); + } + if let Some(v) = ov { + self.reducer.reduce_one(grp, Some(v), seq_id); + self.mask.set_unchecked(g.idx(), true); + } + } + } + Ok(()) + } + + unsafe fn combine( + &mut self, + other: &dyn GroupedReduction, + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.in_dtype == other.in_dtype); + assert!(group_idxs.len() == other.values.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (g, (v, o)) in group_idxs + .iter() + .zip(other.values.iter().zip(other.mask.iter())) + { + if o { + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.combine(grp, v); + self.mask.set_unchecked(*g as usize, true); + } + } + } + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + other: &dyn GroupedReduction, + subset: &[IdxSize], + group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + assert!(self.in_dtype == other.in_dtype); + assert!(subset.len() == group_idxs.len()); + unsafe { + // SAFETY: indices are in-bounds guaranteed by trait. + for (i, g) in subset.iter().zip(group_idxs) { + let o = other.mask.get_unchecked(*i as usize); + if o { + let v = other.values.get_unchecked(*i as usize); + let grp = self.values.get_unchecked_mut(*g as usize); + self.reducer.combine(grp, v); + self.mask.set_unchecked(*g as usize, true); + } + } + } + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + values: core::mem::take(&mut self.evicted_values), + mask: core::mem::take(&mut self.evicted_mask).into_mut(), + evicted_values: Vec::new(), + evicted_mask: BitmapBuilder::new(), + in_dtype: self.in_dtype.clone(), + reducer: self.reducer.clone(), + }) + } + + fn finalize(&mut self) -> PolarsResult { + let v = core::mem::take(&mut self.values); + let m = core::mem::take(&mut self.mask); + self.reducer.finish(v, Some(m.freeze()), &self.in_dtype) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +struct NullGroupedReduction { + num_groups: IdxSize, + num_evictions: IdxSize, + dtype: DataType, +} + +impl NullGroupedReduction { + fn new(dtype: DataType) -> Self { + Self { + num_groups: 0, + num_evictions: 0, + dtype, + } + } +} + +impl GroupedReduction for NullGroupedReduction { + fn new_empty(&self) -> Box { + Box::new(Self::new(self.dtype.clone())) + } + + fn reserve(&mut self, _additional: usize) {} + + fn resize(&mut self, num_groups: IdxSize) { + self.num_groups = num_groups; + } + + fn update_group( + &mut self, + _values: &Column, + _group_idx: IdxSize, + _seq_id: u64, + ) -> PolarsResult<()> { + Ok(()) + } + + unsafe fn update_groups( + &mut self, + _values: &Column, + _group_idxs: &[IdxSize], + _seq_id: u64, + ) -> PolarsResult<()> { + Ok(()) + } + + unsafe fn update_groups_while_evicting( + &mut self, + _values: &Column, + _subset: &[IdxSize], + group_idxs: &[EvictIdx], + _seq_id: u64, + ) -> PolarsResult<()> { + for g in group_idxs { + self.num_evictions += g.should_evict() as IdxSize; + } + Ok(()) + } + + unsafe fn combine( + &mut self, + _other: &dyn GroupedReduction, + _group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + Ok(()) + } + + unsafe fn gather_combine( + &mut self, + _other: &dyn GroupedReduction, + _subset: &[IdxSize], + _group_idxs: &[IdxSize], + ) -> PolarsResult<()> { + Ok(()) + } + + fn take_evictions(&mut self) -> Box { + Box::new(Self { + num_groups: core::mem::replace(&mut self.num_evictions, 0), + num_evictions: 0, + dtype: self.dtype.clone(), + }) + } + + fn finalize(&mut self) -> PolarsResult { + Ok(Series::full_null( + PlSmallStr::EMPTY, + core::mem::replace(&mut self.num_groups, 0) as usize, + &self.dtype, + )) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/sum.rs b/crates/polars-expr/src/reduce/sum.rs new file mode 100644 index 000000000000..0adbfd0aab5f --- /dev/null +++ b/crates/polars-expr/src/reduce/sum.rs @@ -0,0 +1,165 @@ +use std::borrow::Cow; + +use arrow::array::PrimitiveArray; +use num_traits::Zero; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::float::IsFloat; + +use super::*; + +pub trait SumCast: Sized { + type Sum: NumericNative + From; +} + +macro_rules! impl_sum_cast { + ($($x:ty),*) => { + $(impl SumCast for $x { type Sum = $x; })* + }; + ($($from:ty as $to:ty),*) => { + $(impl SumCast for $from { type Sum = $to; })* + }; +} + +impl_sum_cast!( + bool as IdxSize, + u8 as i64, + u16 as i64, + i8 as i64, + i16 as i64 +); +impl_sum_cast!(u32, u64, i32, i64, f32, f64); +#[cfg(feature = "dtype-i128")] +impl_sum_cast!(i128); + +fn out_dtype(in_dtype: &DataType) -> DataType { + use DataType::*; + match in_dtype { + Boolean => IDX_DTYPE, + Int8 | UInt8 | Int16 | UInt16 => Int64, + dt => dt.clone(), + } +} + +pub fn new_sum_reduction(dtype: DataType) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VGR::new(dtype, BoolSumReducer)), + _ if dtype.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, NumSumReducer::<$T>(PhantomData))) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VGR::new(dtype, NumSumReducer::(PhantomData))), + Duration(_) => Box::new(VGR::new(dtype, NumSumReducer::(PhantomData))), + // For compatibility with the current engine, should probably be an error. + String | Binary => Box::new(super::NullGroupedReduction::new(dtype)), + _ => unimplemented!("{dtype:?} is not supported by sum reduction"), + } +} + +struct NumSumReducer(PhantomData); +impl Clone for NumSumReducer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Reducer for NumSumReducer +where + T: PolarsNumericType, + ::Native: SumCast, + ChunkedArray: ChunkAgg + IntoSeries, +{ + type Dtype = T; + type Value = ::Sum; + + #[inline(always)] + fn init(&self) -> Self::Value { + Zero::zero() + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + s.to_physical_repr() + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + *a += *b; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option, _seq_id: u64) { + *a += b.map(Into::into).unwrap_or(Zero::zero()); + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + if T::Native::is_float() { + *v += ChunkAgg::sum(ca).map(Into::into).unwrap_or(Zero::zero()); + } else { + for arr in ca.downcast_iter() { + if arr.has_nulls() { + for x in arr.iter() { + *v += x.copied().map(Into::into).unwrap_or(Zero::zero()); + } + } else { + for x in arr.values_iter().copied() { + *v += x.into(); + } + } + } + } + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let arr = Box::new(PrimitiveArray::from_vec(v)); + Ok(unsafe { + Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![arr], &out_dtype(dtype)) + }) + } +} + +#[derive(Clone)] +struct BoolSumReducer; + +impl Reducer for BoolSumReducer { + type Dtype = BooleanType; + type Value = IdxSize; + + #[inline(always)] + fn init(&self) -> Self::Value { + 0 + } + + #[inline(always)] + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + *a += *b; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option, _seq_id: u64) { + *a += b.unwrap_or(false) as IdxSize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + *v += ca.sum().unwrap_or(0) as IdxSize; + } + + fn finish( + &self, + v: Vec, + m: Option, + dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + assert!(dtype == &DataType::Boolean); + Ok(IdxCa::from_vec(PlSmallStr::EMPTY, v).into_series()) + } +} diff --git a/crates/polars-expr/src/reduce/var_std.rs b/crates/polars-expr/src/reduce/var_std.rs new file mode 100644 index 000000000000..f48508e2f86b --- /dev/null +++ b/crates/polars-expr/src/reduce/var_std.rs @@ -0,0 +1,164 @@ +use std::marker::PhantomData; + +use num_traits::AsPrimitive; +use polars_compute::moment::VarState; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub fn new_var_std_reduction(dtype: DataType, is_std: bool, ddof: u8) -> Box { + use DataType::*; + use VecGroupedReduction as VGR; + match dtype { + Boolean => Box::new(VGR::new(dtype, BoolVarStdReducer { is_std, ddof })), + _ if dtype.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| { + Box::new(VGR::new(dtype, VarStdReducer::<$T> { + is_std, + ddof, + needs_cast: false, + _phantom: PhantomData, + })) + }) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Box::new(VGR::new( + dtype, + VarStdReducer:: { + is_std, + ddof, + needs_cast: true, + _phantom: PhantomData, + }, + )), + Duration(..) => todo!(), + _ => unimplemented!(), + } +} + +struct VarStdReducer { + is_std: bool, + ddof: u8, + needs_cast: bool, + _phantom: PhantomData, +} + +impl Clone for VarStdReducer { + fn clone(&self) -> Self { + Self { + is_std: self.is_std, + ddof: self.ddof, + needs_cast: self.needs_cast, + _phantom: PhantomData, + } + } +} + +impl Reducer for VarStdReducer { + type Dtype = T; + type Value = VarState; + + fn init(&self) -> Self::Value { + VarState::default() + } + + fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { + if self.needs_cast { + Cow::Owned(s.cast(&DataType::Float64).unwrap()) + } else { + Cow::Borrowed(s) + } + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.combine(b) + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option, _seq_id: u64) { + if let Some(x) = b { + a.insert_one(x.as_()); + } + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + for arr in ca.downcast_iter() { + v.combine(&polars_compute::moment::var(arr)) + } + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let ca: Float64Chunked = v + .into_iter() + .map(|s| { + let var = s.finalize(self.ddof); + if self.is_std { var.map(f64::sqrt) } else { var } + }) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} + +#[derive(Clone)] +struct BoolVarStdReducer { + is_std: bool, + ddof: u8, +} + +impl Reducer for BoolVarStdReducer { + type Dtype = BooleanType; + type Value = (usize, usize); + + fn init(&self) -> Self::Value { + (0, 0) + } + + fn combine(&self, a: &mut Self::Value, b: &Self::Value) { + a.0 += b.0; + a.1 += b.1; + } + + #[inline(always)] + fn reduce_one(&self, a: &mut Self::Value, b: Option, _seq_id: u64) { + a.0 += b.unwrap_or(false) as usize; + a.1 += b.is_some() as usize; + } + + fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, _seq_id: u64) { + v.0 += ca.sum().unwrap_or(0) as usize; + v.1 += ca.len() - ca.null_count(); + } + + fn finish( + &self, + v: Vec, + m: Option, + _dtype: &DataType, + ) -> PolarsResult { + assert!(m.is_none()); + let ca: Float64Chunked = v + .into_iter() + .map(|v| { + if v.1 <= self.ddof as usize { + return None; + } + + let sum = v.0 as f64; // Both the sum and sum-of-squares, letting us simplify. + let n = v.1; + let var = sum * (1.0 - sum / n as f64) / ((n - self.ddof as usize) as f64); + if self.is_std { + Some(var.sqrt()) + } else { + Some(var) + } + }) + .collect_ca(PlSmallStr::EMPTY); + Ok(ca.into_series()) + } +} diff --git a/crates/polars-expr/src/state/execution_state.rs b/crates/polars-expr/src/state/execution_state.rs new file mode 100644 index 000000000000..d52234a99461 --- /dev/null +++ b/crates/polars-expr/src/state/execution_state.rs @@ -0,0 +1,314 @@ +use std::borrow::Cow; +use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering}; +use std::sync::{Mutex, OnceLock, RwLock}; +use std::time::Duration; + +use bitflags::bitflags; +use polars_core::config::verbose; +use polars_core::prelude::*; +use polars_ops::prelude::ChunkJoinOptIds; + +use super::NodeTimer; + +pub type JoinTuplesCache = Arc>>; + +#[derive(Default)] +pub struct WindowCache { + groups: RwLock>, + join_tuples: RwLock>>, + map_idx: RwLock>>, +} + +impl WindowCache { + pub(crate) fn clear(&self) { + let mut g = self.groups.write().unwrap(); + g.clear(); + let mut g = self.join_tuples.write().unwrap(); + g.clear(); + } + + pub fn get_groups(&self, key: &str) -> Option { + let g = self.groups.read().unwrap(); + g.get(key).cloned() + } + + pub fn insert_groups(&self, key: String, groups: GroupPositions) { + let mut g = self.groups.write().unwrap(); + g.insert(key, groups); + } + + pub fn get_join(&self, key: &str) -> Option> { + let g = self.join_tuples.read().unwrap(); + g.get(key).cloned() + } + + pub fn insert_join(&self, key: String, join_tuples: Arc) { + let mut g = self.join_tuples.write().unwrap(); + g.insert(key, join_tuples); + } + + pub fn get_map(&self, key: &str) -> Option> { + let g = self.map_idx.read().unwrap(); + g.get(key).cloned() + } + + pub fn insert_map(&self, key: String, idx: Arc) { + let mut g = self.map_idx.write().unwrap(); + g.insert(key, idx); + } +} + +bitflags! { + #[repr(transparent)] + #[derive(Copy, Clone)] + pub(super) struct StateFlags: u8 { + /// More verbose logging + const VERBOSE = 0x01; + /// Indicates that window expression's [`GroupTuples`] may be cached. + const CACHE_WINDOW_EXPR = 0x02; + /// Indicates the expression has a window function + const HAS_WINDOW = 0x04; + /// If set, the expression is evaluated in the + /// streaming engine. + const IN_STREAMING = 0x08; + } +} + +impl Default for StateFlags { + fn default() -> Self { + StateFlags::CACHE_WINDOW_EXPR + } +} + +impl StateFlags { + fn init() -> Self { + let verbose = verbose(); + let mut flags: StateFlags = Default::default(); + if verbose { + flags |= StateFlags::VERBOSE; + } + flags + } + fn as_u8(self) -> u8 { + unsafe { std::mem::transmute(self) } + } +} + +impl From for StateFlags { + fn from(value: u8) -> Self { + unsafe { std::mem::transmute(value) } + } +} + +type CachedValue = Arc<(AtomicI64, OnceLock)>; + +/// State/ cache that is maintained during the Execution of the physical plan. +pub struct ExecutionState { + // cached by a `.cache` call and kept in memory for the duration of the plan. + df_cache: Arc>>, + pub schema_cache: RwLock>, + /// Used by Window Expressions to cache intermediate state + pub window_cache: Arc, + // every join/union split gets an increment to distinguish between schema state + pub branch_idx: usize, + pub flags: AtomicU8, + pub ext_contexts: Arc>, + node_timer: Option, + stop: Arc, +} + +impl ExecutionState { + pub fn new() -> Self { + let mut flags: StateFlags = Default::default(); + if verbose() { + flags |= StateFlags::VERBOSE; + } + Self { + df_cache: Default::default(), + schema_cache: Default::default(), + window_cache: Default::default(), + branch_idx: 0, + flags: AtomicU8::new(StateFlags::init().as_u8()), + ext_contexts: Default::default(), + node_timer: None, + stop: Arc::new(AtomicBool::new(false)), + } + } + + /// Toggle this to measure execution times. + pub fn time_nodes(&mut self, start: std::time::Instant) { + self.node_timer = Some(NodeTimer::new(start)) + } + pub fn has_node_timer(&self) -> bool { + self.node_timer.is_some() + } + + pub fn finish_timer(self) -> PolarsResult { + self.node_timer.unwrap().finish() + } + + // Timings should be a list of (start, end, name) where the start + // and end are raw durations since the query start as nanoseconds. + pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) { + for &(start, end, ref name) in timings { + self.node_timer.as_ref().unwrap().store_duration( + Duration::from_nanos(start), + Duration::from_nanos(end), + name.to_string(), + ); + } + } + + // This is wrong when the U64 overflows which will never happen. + pub fn should_stop(&self) -> PolarsResult<()> { + try_raise_keyboard_interrupt(); + polars_ensure!(!self.stop.load(Ordering::Relaxed), ComputeError: "query interrupted"); + Ok(()) + } + + pub fn cancel_token(&self) -> Arc { + self.stop.clone() + } + + pub fn record T>(&self, func: F, name: Cow<'static, str>) -> T { + match &self.node_timer { + None => func(), + Some(timer) => { + let start = std::time::Instant::now(); + let out = func(); + let end = std::time::Instant::now(); + + timer.store(start, end, name.as_ref().to_string()); + out + }, + } + } + + /// Partially clones and partially clears state + /// This should be used when splitting a node, like a join or union + pub fn split(&self) -> Self { + Self { + df_cache: self.df_cache.clone(), + schema_cache: Default::default(), + window_cache: Default::default(), + branch_idx: self.branch_idx, + flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)), + ext_contexts: self.ext_contexts.clone(), + node_timer: self.node_timer.clone(), + stop: self.stop.clone(), + } + } + + pub fn set_schema(&self, schema: SchemaRef) { + let mut lock = self.schema_cache.write().unwrap(); + *lock = Some(schema); + } + + /// Clear the schema. Typically at the end of a projection. + pub fn clear_schema_cache(&self) { + let mut lock = self.schema_cache.write().unwrap(); + *lock = None; + } + + /// Get the schema. + pub fn get_schema(&self) -> Option { + let lock = self.schema_cache.read().unwrap(); + lock.clone() + } + + pub fn get_df_cache(&self, key: usize, cache_hits: u32) -> CachedValue { + let guard = self.df_cache.read().unwrap(); + + match guard.get(&key) { + Some(v) => v.clone(), + None => { + drop(guard); + let mut guard = self.df_cache.write().unwrap(); + + guard + .entry(key) + .or_insert_with(|| { + Arc::new((AtomicI64::new(cache_hits as i64), OnceLock::new())) + }) + .clone() + }, + } + } + + pub fn remove_df_cache(&self, key: usize) { + let mut guard = self.df_cache.write().unwrap(); + let _ = guard.remove(&key).unwrap(); + } + + /// Clear the cache used by the Window expressions + pub fn clear_window_expr_cache(&self) { + self.window_cache.clear(); + } + + fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) { + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + let flags = f(flags); + self.flags.store(flags.as_u8(), Ordering::Relaxed); + } + + /// Indicates that window expression's [`GroupTuples`] may be cached. + pub fn cache_window(&self) -> bool { + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + flags.contains(StateFlags::CACHE_WINDOW_EXPR) + } + + /// Indicates that window expression's [`GroupTuples`] may be cached. + pub fn has_window(&self) -> bool { + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + flags.contains(StateFlags::HAS_WINDOW) + } + + /// More verbose logging + pub fn verbose(&self) -> bool { + let flags: StateFlags = self.flags.load(Ordering::Relaxed).into(); + flags.contains(StateFlags::VERBOSE) + } + + pub fn remove_cache_window_flag(&mut self) { + self.set_flags(&|mut flags| { + flags.remove(StateFlags::CACHE_WINDOW_EXPR); + flags + }); + } + + pub fn insert_cache_window_flag(&mut self) { + self.set_flags(&|mut flags| { + flags.insert(StateFlags::CACHE_WINDOW_EXPR); + flags + }); + } + // this will trigger some conservative + pub fn insert_has_window_function_flag(&mut self) { + self.set_flags(&|mut flags| { + flags.insert(StateFlags::HAS_WINDOW); + flags + }); + } +} + +impl Default for ExecutionState { + fn default() -> Self { + ExecutionState::new() + } +} + +impl Clone for ExecutionState { + /// clones, but clears no state. + fn clone(&self) -> Self { + Self { + df_cache: self.df_cache.clone(), + schema_cache: self.schema_cache.read().unwrap().clone().into(), + window_cache: self.window_cache.clone(), + branch_idx: self.branch_idx, + flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)), + ext_contexts: self.ext_contexts.clone(), + node_timer: self.node_timer.clone(), + stop: self.stop.clone(), + } + } +} diff --git a/crates/polars-expr/src/state/mod.rs b/crates/polars-expr/src/state/mod.rs new file mode 100644 index 000000000000..d8f5ca5b8ca0 --- /dev/null +++ b/crates/polars-expr/src/state/mod.rs @@ -0,0 +1,5 @@ +mod execution_state; +mod node_timer; + +pub use execution_state::*; +use node_timer::*; diff --git a/crates/polars-expr/src/state/node_timer.rs b/crates/polars-expr/src/state/node_timer.rs new file mode 100644 index 000000000000..cdaebe528c04 --- /dev/null +++ b/crates/polars-expr/src/state/node_timer.rs @@ -0,0 +1,73 @@ +use std::sync::Mutex; +use std::time::{Duration, Instant}; + +use polars_core::prelude::*; +use polars_core::utils::NoNull; + +type StartInstant = Instant; +type EndInstant = Instant; + +type Nodes = Vec; +type Ticks = Vec<(Duration, Duration)>; + +#[derive(Clone)] +pub(super) struct NodeTimer { + query_start: Instant, + data: Arc>, +} + +impl NodeTimer { + pub(super) fn new(query_start: Instant) -> Self { + Self { + query_start, + data: Arc::new(Mutex::new((Vec::with_capacity(16), Vec::with_capacity(16)))), + } + } + + pub(super) fn store(&self, start: StartInstant, end: EndInstant, name: String) { + self.store_duration( + start.duration_since(self.query_start), + end.duration_since(self.query_start), + name, + ) + } + + pub(super) fn store_duration(&self, start: Duration, end: Duration, name: String) { + let mut data = self.data.lock().unwrap(); + let nodes = &mut data.0; + nodes.push(name); + let ticks = &mut data.1; + ticks.push((start, end)) + } + + pub(super) fn finish(self) -> PolarsResult { + let mut data = self.data.lock().unwrap(); + let mut nodes = std::mem::take(&mut data.0); + nodes.push("optimization".to_string()); + + let mut ticks = std::mem::take(&mut data.1); + // first value is end of optimization + polars_ensure!(!ticks.is_empty(), ComputeError: "no data to time"); + let start = ticks[0].0; + ticks.push((Duration::from_nanos(0), start)); + let nodes_s = Column::new(PlSmallStr::from_static("node"), nodes); + let start: NoNull = ticks + .iter() + .map(|(start, _)| start.as_micros() as u64) + .collect(); + let mut start = start.into_inner(); + start.rename(PlSmallStr::from_static("start")); + + let end: NoNull = ticks + .iter() + .map(|(_, end)| end.as_micros() as u64) + .collect(); + let mut end = end.into_inner(); + end.rename(PlSmallStr::from_static("end")); + + let height = nodes_s.len(); + let columns = vec![nodes_s, start.into_column(), end.into_column()]; + let df = unsafe { DataFrame::new_no_checks(height, columns) }; + df.sort(vec!["start"], SortMultipleOptions::default()) + } +} diff --git a/crates/polars-ffi/Cargo.toml b/crates/polars-ffi/Cargo.toml new file mode 100644 index 000000000000..2414c80aae70 --- /dev/null +++ b/crates/polars-ffi/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "polars-ffi" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "FFI utils for the Polars project." + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +polars-core = { workspace = true } diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs new file mode 100644 index 000000000000..969e6069d69e --- /dev/null +++ b/crates/polars-ffi/src/lib.rs @@ -0,0 +1,35 @@ +#![allow(unsafe_op_in_unsafe_fn)] +pub mod version_0; + +use std::mem::ManuallyDrop; + +use arrow::array::ArrayRef; +use arrow::ffi; +use arrow::ffi::{ArrowArray, ArrowSchema}; +use polars_core::error::PolarsResult; +use polars_core::prelude::{ArrowField, Series}; + +pub const MAJOR: u16 = 0; +pub const MINOR: u16 = 1; + +pub const fn get_version() -> (u16, u16) { + (MAJOR, MINOR) +} + +// A utility that helps releasing/owning memory. +#[allow(dead_code)] +struct PrivateData { + schema: Box, + arrays: Box<[*mut ArrowArray]>, +} + +/// # Safety +/// `ArrowArray` and `ArrowSchema` must be valid +unsafe fn import_array( + array: ffi::ArrowArray, + schema: &ffi::ArrowSchema, +) -> PolarsResult { + let field = ffi::import_field_from_c(schema)?; + let out = ffi::import_array_from_c(array, field.dtype)?; + Ok(out) +} diff --git a/crates/polars-ffi/src/version_0.rs b/crates/polars-ffi/src/version_0.rs new file mode 100644 index 000000000000..103d3a60816f --- /dev/null +++ b/crates/polars-ffi/src/version_0.rs @@ -0,0 +1,175 @@ +use polars_core::frame::DataFrame; +use polars_core::prelude::{Column, CompatLevel}; + +use super::*; + +/// An FFI exported `Series`. +#[repr(C)] +pub struct SeriesExport { + field: *mut ArrowSchema, + // A double ptr, so we can easily release the buffer + // without dropping the arrays. + arrays: *mut *mut ArrowArray, + len: usize, + release: Option, + private_data: *mut std::os::raw::c_void, +} + +impl SeriesExport { + pub fn empty() -> Self { + Self { + field: std::ptr::null_mut(), + arrays: std::ptr::null_mut(), + len: 0, + release: None, + private_data: std::ptr::null_mut(), + } + } + + pub fn is_null(&self) -> bool { + self.private_data.is_null() + } +} + +impl Drop for SeriesExport { + fn drop(&mut self) { + if let Some(release) = self.release { + unsafe { release(self) } + } + } +} + +// callback used to drop [SeriesExport] when it is exported. +unsafe extern "C" fn c_release_series_export(e: *mut SeriesExport) { + if e.is_null() { + return; + } + let e = &mut *e; + let private = Box::from_raw(e.private_data as *mut PrivateData); + for ptr in private.arrays.iter() { + // drop the box, not the array + let _ = Box::from_raw(*ptr as *mut ManuallyDrop); + } + + e.release = None; +} + +pub fn export_column(c: &Column) -> SeriesExport { + export_series(c.as_materialized_series()) +} + +pub fn export_series(s: &Series) -> SeriesExport { + let field = ArrowField::new( + s.name().clone(), + s.dtype().to_arrow(CompatLevel::newest()), + true, + ); + let schema = Box::new(ffi::export_field_to_c(&field)); + + let mut arrays = (0..s.chunks().len()) + .map(|i| { + // Make sure we export the logical type. + let arr = s.to_arrow(i, CompatLevel::newest()); + Box::into_raw(Box::new(ffi::export_array_to_c(arr.clone()))) + }) + .collect::>(); + + let len = arrays.len(); + let ptr = arrays.as_mut_ptr(); + SeriesExport { + field: schema.as_ref() as *const ArrowSchema as *mut ArrowSchema, + arrays: ptr, + len, + release: Some(c_release_series_export), + private_data: Box::into_raw(Box::new(PrivateData { arrays, schema })) + as *mut std::os::raw::c_void, + } +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_series(e: SeriesExport) -> PolarsResult { + let field = ffi::import_field_from_c(&(*e.field))?; + + let pointers = std::slice::from_raw_parts_mut(e.arrays, e.len); + let chunks = pointers + .iter() + .map(|ptr| { + let arr = std::ptr::read(*ptr); + import_array(arr, &(*e.field)) + }) + .collect::>>()?; + + Series::try_from((field.name.clone(), chunks)) +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_series_buffer(e: *mut SeriesExport, len: usize) -> PolarsResult> { + let mut out = Vec::with_capacity(len); + for i in 0..len { + let e = std::ptr::read(e.add(i)); + out.push(import_series(e)?) + } + Ok(out) +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_df(e: *mut SeriesExport, len: usize) -> PolarsResult { + let mut out = Vec::with_capacity(len); + for i in 0..len { + let e = std::ptr::read(e.add(i)); + let s = import_series(e)?; + out.push(s.into()) + } + Ok(DataFrame::new_no_checks_height_from_first(out)) +} + +/// Passed to an expression. +/// This contains information for the implementer of the expression on what it is allowed to do. +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct CallerContext { + // bit + // 1: PARALLEL + bitflags: u64, +} + +impl CallerContext { + const fn kth_bit_set(&self, k: u64) -> bool { + (self.bitflags & (1 << k)) > 0 + } + + fn set_kth_bit(&mut self, k: u64) { + self.bitflags |= 1 << k + } + + /// Parallelism is done by polars' main engine, the plugin should not run its own parallelism. + /// If this is `false`, the plugin could use parallelism without (much) contention with polars + /// parallelism strategies. + pub fn parallel(&self) -> bool { + self.kth_bit_set(0) + } + + pub fn _set_parallel(&mut self) { + self.set_kth_bit(0) + } +} + +#[cfg(test)] +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_ffi() { + let s = Series::new("a".into(), [1, 2]); + let e = export_series(&s); + + unsafe { + assert_eq!(import_series(e).unwrap(), s); + }; + } +} diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml new file mode 100644 index 000000000000..9fd4eb504116 --- /dev/null +++ b/crates/polars-io/Cargo.toml @@ -0,0 +1,140 @@ +[package] +name = "polars-io" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "IO related logic for the Polars DataFrame library" + +[dependencies] +polars-core = { workspace = true } +polars-error = { workspace = true } +polars-json = { workspace = true, optional = true } +polars-parquet = { workspace = true, optional = true } +polars-schema = { workspace = true } +polars-time = { workspace = true, features = [], optional = true } +polars-utils = { workspace = true, features = ['mmap'] } + +arrow = { workspace = true } +async-trait = { workspace = true, optional = true } +atoi_simd = { workspace = true, optional = true } +blake3 = { version = "1.6.1", optional = true } +bytes = { workspace = true } +chrono = { workspace = true, optional = true } +chrono-tz = { workspace = true, optional = true } +fast-float2 = { workspace = true, optional = true } +flate2 = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +glob = { version = "0.3" } +hashbrown = { workspace = true } +itoa = { workspace = true, optional = true } +memchr = { workspace = true } +memmap = { workspace = true } +num-traits = { workspace = true } +object_store = { workspace = true, optional = true } +percent-encoding = { workspace = true } +pyo3 = { workspace = true, optional = true } +rayon = { workspace = true } +regex = { workspace = true } +reqwest = { workspace = true, optional = true } +ryu = { workspace = true, optional = true } +serde = { workspace = true, features = ["rc"], optional = true } +serde_json = { version = "1", optional = true } +simd-json = { workspace = true, optional = true } +simdutf8 = { workspace = true, optional = true } +strum = { workspace = true, optional = true } +strum_macros = { workspace = true, optional = true } +tokio = { workspace = true, features = ["fs", "net", "rt-multi-thread", "time", "sync"], optional = true } +tokio-util = { workspace = true, features = ["io", "io-util"], optional = true } +url = { workspace = true, optional = true } +zstd = { workspace = true, optional = true } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +fs4 = { version = "0.13", features = ["sync"], optional = true } +home = "0.5.4" + +[dev-dependencies] +tempfile = "3" + +[features] +catalog = ["cloud", "serde", "reqwest", "futures", "strum", "strum_macros", "chrono"] +default = ["decompress"] +# support for arrows json parsing +json = [ + "polars-json", + "simd-json", + "atoi_simd", + "dtype-struct", + "csv", +] +serde = ["dep:serde", "polars-core/serde-lazy", "polars-parquet/serde", "polars-utils/serde"] +# support for arrows ipc file parsing +ipc = ["arrow/io_ipc", "arrow/io_ipc_compression"] +# support for arrows streaming ipc file parsing +ipc_streaming = ["arrow/io_ipc", "arrow/io_ipc_compression"] +# support for arrow avro parsing +avro = ["arrow/io_avro", "arrow/io_avro_compression"] +csv = ["atoi_simd", "polars-core/rows", "itoa", "ryu", "fast-float2", "simdutf8"] +decompress = ["flate2/zlib-rs", "zstd"] +dtype-u8 = ["polars-core/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16"] +dtype-i8 = ["polars-core/dtype-i8"] +dtype-i16 = ["polars-core/dtype-i16"] +dtype-i128 = ["polars-core/dtype-i128"] +dtype-categorical = ["polars-core/dtype-categorical"] +dtype-date = ["polars-core/dtype-date", "polars-time/dtype-date"] +object = ["polars-core/object"] +dtype-datetime = [ + "polars-core/dtype-datetime", + "polars-core/temporal", + "polars-time/dtype-datetime", + "chrono", +] +timezones = [ + "chrono-tz", + "dtype-datetime", + "arrow/timezones", + "polars-json?/chrono-tz", + "polars-json?/timezones", +] +dtype-time = ["polars-core/dtype-time", "polars-core/temporal", "polars-time/dtype-time"] +dtype-duration = ["polars-core/dtype-duration", "polars-time/dtype-duration"] +dtype-struct = ["polars-core/dtype-struct"] +dtype-decimal = ["polars-core/dtype-decimal", "polars-json?/dtype-decimal"] +fmt = ["polars-core/fmt"] +lazy = [] +parquet = ["polars-parquet", "polars-parquet/compression", "polars-core/partition_by"] +async = [ + "async-trait", + "futures", + "tokio", + "tokio-util", + "polars-error/regex", + "polars-parquet?/async", +] +cloud = [ + "object_store", + "async", + "polars-error/object_store", + "url", + "serde_json", + "serde", + "file_cache", + "reqwest", + "http", +] +file_cache = ["async", "dep:blake3", "dep:fs4", "serde_json", "cloud"] +aws = ["object_store/aws", "cloud", "reqwest"] +azure = ["object_store/azure", "cloud"] +gcp = ["object_store/gcp", "cloud"] +http = ["object_store/http", "cloud"] +temporal = ["dtype-datetime", "dtype-date", "dtype-time"] +simd = [] +python = ["pyo3", "polars-error/python", "polars-utils/python"] + +[package.metadata.docs.rs] +all-features = true +# defines the configuration attribute `docsrs` +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars-io/LICENSE b/crates/polars-io/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-io/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-io/README.md b/crates/polars-io/README.md new file mode 100644 index 000000000000..8258a7ef8601 --- /dev/null +++ b/crates/polars-io/README.md @@ -0,0 +1,7 @@ +# polars-io + +`polars-io` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, +that provides IO functionality for the Polars dataframe library. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-io/src/avro/mod.rs b/crates/polars-io/src/avro/mod.rs new file mode 100644 index 000000000000..c8b54eba1dfa --- /dev/null +++ b/crates/polars-io/src/avro/mod.rs @@ -0,0 +1,5 @@ +mod read; +mod write; + +pub use read::*; +pub use write::*; diff --git a/crates/polars-io/src/avro/read.rs b/crates/polars-io/src/avro/read.rs new file mode 100644 index 000000000000..eea01c559470 --- /dev/null +++ b/crates/polars-io/src/avro/read.rs @@ -0,0 +1,132 @@ +use std::io::{Read, Seek}; + +use arrow::io::avro::{self, read}; +use arrow::record_batch::RecordBatch; +use polars_core::error::to_compute_err; +use polars_core::prelude::*; + +use crate::prelude::*; +use crate::shared::{ArrowReader, finish_reader}; + +/// Read [Apache Avro] format into a [`DataFrame`] +/// +/// [Apache Avro]: https://avro.apache.org +/// +/// # Example +/// ``` +/// use std::fs::File; +/// use polars_core::prelude::*; +/// use polars_io::avro::AvroReader; +/// use polars_io::SerReader; +/// +/// fn example() -> PolarsResult { +/// let file = File::open("file.avro").expect("file not found"); +/// +/// AvroReader::new(file) +/// .finish() +/// } +/// ``` +#[must_use] +pub struct AvroReader { + reader: R, + rechunk: bool, + n_rows: Option, + columns: Option>, + projection: Option>, +} + +impl AvroReader { + /// Get schema of the Avro File + pub fn schema(&mut self) -> PolarsResult { + let schema = self.arrow_schema()?; + Ok(Schema::from_arrow_schema(&schema)) + } + + /// Get arrow schema of the avro File, this is faster than a polars schema. + pub fn arrow_schema(&mut self) -> PolarsResult { + let metadata = + avro::avro_schema::read::read_metadata(&mut self.reader).map_err(to_compute_err)?; + let schema = read::infer_schema(&metadata.record)?; + Ok(schema) + } + + /// Stop reading when `n` rows are read. + pub fn with_n_rows(mut self, num_rows: Option) -> Self { + self.n_rows = num_rows; + self + } + + /// Set the reader's column projection. This counts from 0, meaning that + /// `vec![0, 4]` would select the 1st and 5th column. + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + /// Columns to select/ project + pub fn with_columns(mut self, columns: Option>) -> Self { + self.columns = columns; + self + } +} + +impl ArrowReader for read::Reader +where + R: Read + Seek, +{ + fn next_record_batch(&mut self) -> PolarsResult> { + self.next().map_or(Ok(None), |v| v.map(Some)) + } +} + +impl SerReader for AvroReader +where + R: Read + Seek, +{ + fn new(reader: R) -> Self { + AvroReader { + reader, + rechunk: true, + n_rows: None, + columns: None, + projection: None, + } + } + + fn set_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + fn finish(mut self) -> PolarsResult { + let rechunk = self.rechunk; + let metadata = + avro::avro_schema::read::read_metadata(&mut self.reader).map_err(to_compute_err)?; + let schema = read::infer_schema(&metadata.record)?; + + if let Some(columns) = &self.columns { + self.projection = Some(columns_to_projection(columns, &schema)?); + } + + let (projection, projected_schema) = if let Some(projection) = self.projection { + let mut prj = vec![false; schema.len()]; + for &index in projection.iter() { + prj[index] = true; + } + (Some(prj), apply_projection(&schema, &projection)) + } else { + (None, schema.clone()) + }; + + let avro_reader = avro::read::Reader::new(&mut self.reader, metadata, schema, projection); + + finish_reader( + avro_reader, + rechunk, + self.n_rows, + None, + &projected_schema, + None, + ) + } +} diff --git a/crates/polars-io/src/avro/write.rs b/crates/polars-io/src/avro/write.rs new file mode 100644 index 000000000000..552cde3dc0ed --- /dev/null +++ b/crates/polars-io/src/avro/write.rs @@ -0,0 +1,102 @@ +use std::io::Write; + +pub use Compression as AvroCompression; +pub use arrow::io::avro::avro_schema::file::Compression; +use arrow::io::avro::avro_schema::{self}; +use arrow::io::avro::write; +use polars_core::error::to_compute_err; +use polars_core::prelude::*; + +use crate::shared::{SerWriter, schema_to_arrow_checked}; + +/// Write a [`DataFrame`] to [Apache Avro] format +/// +/// [Apache Avro]: https://avro.apache.org +/// +/// # Example +/// +/// ``` +/// use polars_core::prelude::*; +/// use polars_io::avro::AvroWriter; +/// use std::fs::File; +/// use polars_io::SerWriter; +/// +/// fn example(df: &mut DataFrame) -> PolarsResult<()> { +/// let mut file = File::create("file.avro").expect("could not create file"); +/// +/// AvroWriter::new(&mut file) +/// .finish(df) +/// } +/// ``` +#[must_use] +pub struct AvroWriter { + writer: W, + compression: Option, + name: String, +} + +impl AvroWriter +where + W: Write, +{ + /// Set the compression used. Defaults to None. + pub fn with_compression(mut self, compression: Option) -> Self { + self.compression = compression; + self + } + + pub fn with_name(mut self, name: String) -> Self { + self.name = name; + self + } +} + +impl SerWriter for AvroWriter +where + W: Write, +{ + fn new(writer: W) -> Self { + Self { + writer, + compression: None, + name: "".to_string(), + } + } + + fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { + let schema = schema_to_arrow_checked(df.schema(), CompatLevel::oldest(), "avro")?; + let record = write::to_record(&schema, self.name.clone())?; + + let mut data = vec![]; + let mut compressed_block = avro_schema::file::CompressedBlock::default(); + for chunk in df.iter_chunks(CompatLevel::oldest(), true) { + let mut serializers = chunk + .iter() + .zip(record.fields.iter()) + .map(|(array, field)| write::new_serializer(array.as_ref(), &field.schema)) + .collect::>(); + + let mut block = + avro_schema::file::Block::new(chunk.arrays()[0].len(), std::mem::take(&mut data)); + write::serialize(&mut serializers, &mut block); + let _was_compressed = + avro_schema::write::compress(&mut block, &mut compressed_block, self.compression) + .map_err(to_compute_err)?; + + avro_schema::write::write_metadata(&mut self.writer, record.clone(), self.compression) + .map_err(to_compute_err)?; + + avro_schema::write::write_block(&mut self.writer, &compressed_block) + .map_err(to_compute_err)?; + // reuse block for next iteration. + data = block.data; + data.clear(); + + // reuse block for next iteration + compressed_block.data.clear(); + compressed_block.number_of_rows = 0 + } + + Ok(()) + } +} diff --git a/crates/polars-io/src/catalog/mod.rs b/crates/polars-io/src/catalog/mod.rs new file mode 100644 index 000000000000..1816b7c87322 --- /dev/null +++ b/crates/polars-io/src/catalog/mod.rs @@ -0,0 +1 @@ +pub mod unity; diff --git a/crates/polars-io/src/catalog/unity/client.rs b/crates/polars-io/src/catalog/unity/client.rs new file mode 100644 index 000000000000..26769dce43e7 --- /dev/null +++ b/crates/polars-io/src/catalog/unity/client.rs @@ -0,0 +1,360 @@ +use polars_core::prelude::PlHashMap; +use polars_core::schema::Schema; +use polars_error::{PolarsResult, polars_bail, to_compute_err}; + +use super::models::{CatalogInfo, NamespaceInfo, TableCredentials, TableInfo}; +use super::schema::schema_to_column_info_list; +use super::utils::{PageWalker, do_request}; +use crate::catalog::unity::models::{ColumnInfo, DataSourceFormat, TableType}; +use crate::impl_page_walk; +use crate::utils::decode_json_response; + +/// Unity catalog client. +pub struct CatalogClient { + workspace_url: String, + http_client: reqwest::Client, +} + +impl CatalogClient { + pub async fn list_catalogs(&self) -> PolarsResult> { + ListCatalogs(PageWalker::new(self.http_client.get(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/catalogs" + )))) + .read_all_pages() + .await + } + + pub async fn list_namespaces(&self, catalog_name: &str) -> PolarsResult> { + ListSchemas(PageWalker::new( + self.http_client + .get(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/schemas" + )) + .query(&[("catalog_name", catalog_name)]), + )) + .read_all_pages() + .await + } + + pub async fn list_tables( + &self, + catalog_name: &str, + namespace: &str, + ) -> PolarsResult> { + ListTables(PageWalker::new( + self.http_client + .get(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/tables" + )) + .query(&[("catalog_name", catalog_name), ("schema_name", namespace)]), + )) + .read_all_pages() + .await + } + + pub async fn get_table_info( + &self, + catalog_name: &str, + namespace: &str, + table_name: &str, + ) -> PolarsResult { + let full_table_name = format!( + "{}.{}.{}", + catalog_name.replace('/', "%2F"), + namespace.replace('/', "%2F"), + table_name.replace('/', "%2F") + ); + + let bytes = do_request( + self.http_client + .get(format!( + "{}{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/tables/", full_table_name + )) + .query(&[("full_name", full_table_name)]), + ) + .await?; + + let out: TableInfo = decode_json_response(&bytes)?; + + Ok(out) + } + + pub async fn get_table_credentials( + &self, + table_id: &str, + write: bool, + ) -> PolarsResult { + let bytes = do_request( + self.http_client + .post(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/temporary-table-credentials" + )) + .query(&[ + ("table_id", table_id), + ("operation", if write { "READ_WRITE" } else { "READ" }), + ]), + ) + .await?; + + let out: TableCredentials = decode_json_response(&bytes)?; + + Ok(out) + } + + pub async fn create_catalog( + &self, + catalog_name: &str, + comment: Option<&str>, + storage_root: Option<&str>, + ) -> PolarsResult { + let resp = do_request( + self.http_client + .post(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/catalogs" + )) + .json(&Body { + name: catalog_name, + comment, + storage_root, + }), + ) + .await?; + + return decode_json_response(&resp); + + #[derive(serde::Serialize)] + struct Body<'a> { + name: &'a str, + comment: Option<&'a str>, + storage_root: Option<&'a str>, + } + } + + pub async fn delete_catalog(&self, catalog_name: &str, force: bool) -> PolarsResult<()> { + let catalog_name = catalog_name.replace('/', "%2F"); + + do_request( + self.http_client + .delete(format!( + "{}{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/catalogs/", catalog_name + )) + .query(&[("force", force)]), + ) + .await?; + + Ok(()) + } + + pub async fn create_namespace( + &self, + catalog_name: &str, + namespace: &str, + comment: Option<&str>, + storage_root: Option<&str>, + ) -> PolarsResult { + let resp = do_request( + self.http_client + .post(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/schemas" + )) + .json(&Body { + name: namespace, + catalog_name, + comment, + storage_root, + }), + ) + .await?; + + return decode_json_response(&resp); + + #[derive(serde::Serialize)] + struct Body<'a> { + name: &'a str, + catalog_name: &'a str, + comment: Option<&'a str>, + storage_root: Option<&'a str>, + } + } + + pub async fn delete_namespace( + &self, + catalog_name: &str, + namespace: &str, + force: bool, + ) -> PolarsResult<()> { + let full_name = format!( + "{}.{}", + catalog_name.replace('/', "%2F"), + namespace.replace('/', "%2F"), + ); + + do_request( + self.http_client + .delete(format!( + "{}{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/schemas/", full_name + )) + .query(&[("force", force)]), + ) + .await?; + + Ok(()) + } + + /// Note, `data_source_format` can be None for some `table_type`s. + #[allow(clippy::too_many_arguments)] + pub async fn create_table( + &self, + catalog_name: &str, + namespace: &str, + table_name: &str, + schema: Option<&Schema>, + table_type: &TableType, + data_source_format: Option<&DataSourceFormat>, + comment: Option<&str>, + storage_location: Option<&str>, + properties: &mut (dyn Iterator + Send + Sync), + ) -> PolarsResult { + let columns = schema.map(schema_to_column_info_list).transpose()?; + let columns = columns.as_deref(); + + let resp = do_request( + self.http_client + .post(format!( + "{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/tables" + )) + .json(&Body { + name: table_name, + catalog_name, + schema_name: namespace, + table_type, + data_source_format, + comment, + columns, + storage_location, + properties: properties.collect(), + }), + ) + .await?; + + return decode_json_response(&resp); + + #[derive(serde::Serialize)] + struct Body<'a> { + name: &'a str, + catalog_name: &'a str, + schema_name: &'a str, + comment: Option<&'a str>, + table_type: &'a TableType, + #[serde(skip_serializing_if = "Option::is_none")] + data_source_format: Option<&'a DataSourceFormat>, + columns: Option<&'a [ColumnInfo]>, + storage_location: Option<&'a str>, + properties: PlHashMap<&'a str, &'a str>, + } + } + + pub async fn delete_table( + &self, + catalog_name: &str, + namespace: &str, + table_name: &str, + ) -> PolarsResult<()> { + let full_name = format!( + "{}.{}.{}", + catalog_name.replace('/', "%2F"), + namespace.replace('/', "%2F"), + table_name.replace('/', "%2F"), + ); + + do_request(self.http_client.delete(format!( + "{}{}{}", + &self.workspace_url, "/api/2.1/unity-catalog/tables/", full_name + ))) + .await?; + + Ok(()) + } +} + +pub struct CatalogClientBuilder { + workspace_url: Option, + bearer_token: Option, +} + +#[allow(clippy::derivable_impls)] +impl Default for CatalogClientBuilder { + fn default() -> Self { + Self { + workspace_url: None, + bearer_token: None, + } + } +} + +impl CatalogClientBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn with_workspace_url(mut self, workspace_url: impl Into) -> Self { + self.workspace_url = Some(workspace_url.into()); + self + } + + pub fn with_bearer_token(mut self, bearer_token: impl Into) -> Self { + self.bearer_token = Some(bearer_token.into()); + self + } + + pub fn build(self) -> PolarsResult { + let Some(workspace_url) = self.workspace_url else { + polars_bail!(ComputeError: "expected Some(_) for workspace_url") + }; + + Ok(CatalogClient { + workspace_url, + http_client: { + let builder = reqwest::ClientBuilder::new().user_agent("polars"); + + let builder = if let Some(bearer_token) = self.bearer_token { + use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT}; + + let mut headers = HeaderMap::new(); + + let mut auth_value = + HeaderValue::from_str(format!("Bearer {}", bearer_token).as_str()).unwrap(); + auth_value.set_sensitive(true); + + headers.insert(AUTHORIZATION, auth_value); + headers.insert(USER_AGENT, "polars".try_into().unwrap()); + + builder.default_headers(headers) + } else { + builder + }; + + builder.build().map_err(to_compute_err)? + }, + }) + } +} + +pub struct ListCatalogs(pub(crate) PageWalker); +impl_page_walk!(ListCatalogs, CatalogInfo, key_name = catalogs); + +pub struct ListSchemas(pub(crate) PageWalker); +impl_page_walk!(ListSchemas, NamespaceInfo, key_name = schemas); + +pub struct ListTables(pub(crate) PageWalker); +impl_page_walk!(ListTables, TableInfo, key_name = tables); diff --git a/crates/polars-io/src/catalog/unity/mod.rs b/crates/polars-io/src/catalog/unity/mod.rs new file mode 100644 index 000000000000..03fb920460a8 --- /dev/null +++ b/crates/polars-io/src/catalog/unity/mod.rs @@ -0,0 +1,4 @@ +pub mod client; +pub mod models; +pub mod schema; +pub(crate) mod utils; diff --git a/crates/polars-io/src/catalog/unity/models.rs b/crates/polars-io/src/catalog/unity/models.rs new file mode 100644 index 000000000000..a3101dc823e7 --- /dev/null +++ b/crates/polars-io/src/catalog/unity/models.rs @@ -0,0 +1,316 @@ +use polars_core::prelude::PlHashMap; +use polars_utils::pl_str::PlSmallStr; + +#[derive(Debug, serde::Deserialize)] +pub struct CatalogInfo { + pub name: String, + + pub comment: Option, + + #[serde(default)] + pub storage_location: Option, + + #[serde(default, deserialize_with = "null_to_default")] + pub properties: PlHashMap, + + #[serde(default, deserialize_with = "null_to_default")] + pub options: PlHashMap, + + #[serde(with = "chrono::serde::ts_milliseconds_option")] + pub created_at: Option>, + + pub created_by: Option, + + #[serde(with = "chrono::serde::ts_milliseconds_option")] + pub updated_at: Option>, + + pub updated_by: Option, +} + +#[derive(Debug, serde::Deserialize)] +pub struct NamespaceInfo { + pub name: String, + pub comment: Option, + + #[serde(default, deserialize_with = "null_to_default")] + pub properties: PlHashMap, + + #[serde(default)] + pub storage_location: Option, + + #[serde(with = "chrono::serde::ts_milliseconds_option")] + pub created_at: Option>, + + pub created_by: Option, + + #[serde(with = "chrono::serde::ts_milliseconds_option")] + pub updated_at: Option>, + + pub updated_by: Option, +} + +#[derive(Debug, serde::Deserialize)] +pub struct TableInfo { + pub name: String, + pub table_id: String, + pub table_type: TableType, + + #[serde(default)] + pub comment: Option, + + #[serde(default)] + pub storage_location: Option, + + #[serde(default)] + pub data_source_format: Option, + + #[serde(default)] + pub columns: Option>, + + #[serde(default, deserialize_with = "null_to_default")] + pub properties: PlHashMap, + + #[serde(with = "chrono::serde::ts_milliseconds_option")] + pub created_at: Option>, + + pub created_by: Option, + + #[serde(with = "chrono::serde::ts_milliseconds_option")] + pub updated_at: Option>, + + pub updated_by: Option, +} + +#[derive( + Debug, strum_macros::Display, strum_macros::EnumString, serde::Serialize, serde::Deserialize, +)] +#[strum(serialize_all = "SCREAMING_SNAKE_CASE")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum TableType { + Managed, + External, + View, + MaterializedView, + StreamingTable, + ManagedShallowClone, + Foreign, + ExternalShallowClone, +} + +#[derive( + Debug, strum_macros::Display, strum_macros::EnumString, serde::Serialize, serde::Deserialize, +)] +#[strum(serialize_all = "SCREAMING_SNAKE_CASE")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum DataSourceFormat { + Delta, + Csv, + Json, + Avro, + Parquet, + Orc, + Text, + + // Databricks-specific + UnityCatalog, + Deltasharing, + DatabricksFormat, + MysqlFormat, + PostgresqlFormat, + RedshiftFormat, + SnowflakeFormat, + SqldwFormat, + SqlserverFormat, + SalesforceFormat, + BigqueryFormat, + NetsuiteFormat, + WorkdayRaasFormat, + HiveSerde, + HiveCustom, + VectorIndexFormat, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub struct ColumnInfo { + pub name: PlSmallStr, + pub type_name: PlSmallStr, + pub type_text: PlSmallStr, + pub type_json: String, + pub position: Option, + pub comment: Option, + pub partition_index: Option, +} + +/// Note: This struct contains all the field names for a few different possible type / field presence +/// combinations. We use serde(default) and skip_serializing_if to get the desired serialization +/// output. +/// +/// E.g.: +/// +/// ```text +/// { +/// "name": "List", +/// "type": {"type": "array", "elementType": "long", "containsNull": True}, +/// "nullable": True, +/// "metadata": {}, +/// } +/// { +/// "name": "Struct", +/// "type": { +/// "type": "struct", +/// "fields": [{"name": "x", "type": "long", "nullable": True, "metadata": {}}], +/// }, +/// "nullable": True, +/// "metadata": {}, +/// } +/// { +/// "name": "ListStruct", +/// "type": { +/// "type": "array", +/// "elementType": { +/// "type": "struct", +/// "fields": [{"name": "x", "type": "long", "nullable": True, "metadata": {}}], +/// }, +/// "containsNull": True, +/// }, +/// "nullable": True, +/// "metadata": {}, +/// } +/// { +/// "name": "Map", +/// "type": { +/// "type": "map", +/// "keyType": "string", +/// "valueType": "string", +/// "valueContainsNull": True, +/// }, +/// "nullable": True, +/// "metadata": {}, +/// } +/// ``` +#[derive(Debug, Default, serde::Serialize, serde::Deserialize)] +pub struct ColumnTypeJson { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub name: Option, + + #[serde(rename = "type")] + pub type_: ColumnTypeJsonType, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub nullable: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + + // Used for List types + #[serde( + default, + rename = "elementType", + skip_serializing_if = "Option::is_none" + )] + pub element_type: Option, + + #[serde( + default, + rename = "containsNull", + skip_serializing_if = "Option::is_none" + )] + pub contains_null: Option, + + // Used for Struct types + #[serde(default, skip_serializing_if = "Option::is_none")] + pub fields: Option>, + + // Used for Map types + #[serde(default, rename = "keyType", skip_serializing_if = "Option::is_none")] + pub key_type: Option, + + #[serde(default, rename = "valueType", skip_serializing_if = "Option::is_none")] + pub value_type: Option, + + #[serde( + default, + rename = "valueContainsNull", + skip_serializing_if = "Option::is_none" + )] + pub value_contains_null: Option, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub enum ColumnTypeJsonType { + /// * `{"type": "name", ..}`` + TypeName(PlSmallStr), + /// * `{"type": {"type": "name", ..}}` + TypeJson(Box), +} + +impl Default for ColumnTypeJsonType { + fn default() -> Self { + Self::TypeName(PlSmallStr::EMPTY) + } +} + +impl ColumnTypeJsonType { + pub const fn from_static_type_name(type_name: &'static str) -> Self { + Self::TypeName(PlSmallStr::from_static(type_name)) + } +} + +#[derive(Debug, serde::Deserialize)] +pub struct TableCredentials { + pub aws_temp_credentials: Option, + pub azure_user_delegation_sas: Option, + pub gcp_oauth_token: Option, + pub expiration_time: i64, +} + +impl TableCredentials { + pub fn into_enum(self) -> Option { + if let v @ Some(_) = self.aws_temp_credentials { + v.map(TableCredentialsVariants::Aws) + } else if let v @ Some(_) = self.azure_user_delegation_sas { + v.map(TableCredentialsVariants::Azure) + } else if let v @ Some(_) = self.gcp_oauth_token { + v.map(TableCredentialsVariants::Gcp) + } else { + None + } + } +} + +pub enum TableCredentialsVariants { + Aws(TableCredentialsAws), + Azure(TableCredentialsAzure), + Gcp(TableCredentialsGcp), +} + +#[derive(Debug, serde::Deserialize)] +pub struct TableCredentialsAws { + pub access_key_id: String, + pub secret_access_key: String, + pub session_token: Option, + + #[serde(default)] + pub access_point: Option, +} + +#[derive(Debug, serde::Deserialize)] +pub struct TableCredentialsAzure { + pub sas_token: String, +} + +#[derive(Debug, serde::Deserialize)] +pub struct TableCredentialsGcp { + pub oauth_token: String, +} + +fn null_to_default<'de, T, D>(d: D) -> Result +where + T: Default + serde::de::Deserialize<'de>, + D: serde::de::Deserializer<'de>, +{ + use serde::Deserialize; + let opt_val = Option::::deserialize(d)?; + Ok(opt_val.unwrap_or_default()) +} diff --git a/crates/polars-io/src/catalog/unity/schema.rs b/crates/polars-io/src/catalog/unity/schema.rs new file mode 100644 index 000000000000..8dba964a7d09 --- /dev/null +++ b/crates/polars-io/src/catalog/unity/schema.rs @@ -0,0 +1,512 @@ +use polars_core::prelude::{DataType, Field}; +use polars_core::schema::{Schema, SchemaRef}; +use polars_error::{PolarsResult, polars_bail, polars_err, to_compute_err}; +use polars_utils::error::TruncateErrorDetail; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; + +use super::models::{ColumnInfo, ColumnTypeJson, ColumnTypeJsonType, TableInfo}; +use crate::utils::decode_json_response; + +/// Returns `(schema, hive_schema)` +pub fn table_info_to_schemas( + table_info: &TableInfo, +) -> PolarsResult<(Option, Option)> { + let Some(columns) = table_info.columns.as_deref() else { + return Ok((None, None)); + }; + + let mut schema = Schema::default(); + let mut hive_schema = Schema::default(); + + for (i, col) in columns.iter().enumerate() { + if let Some(position) = col.position { + if usize::try_from(position).unwrap() != i { + polars_bail!( + ComputeError: + "not yet supported: position was not ordered" + ) + } + } + + let field = column_info_to_field(col)?; + + if let Some(i) = col.partition_index { + if usize::try_from(i).unwrap() != hive_schema.len() { + polars_bail!( + ComputeError: + "not yet supported: partition_index was not ordered" + ) + } + + hive_schema.extend([field]); + } else { + schema.extend([field]) + } + } + + Ok(( + Some(schema.into()), + Some(hive_schema) + .filter(|x| !x.is_empty()) + .map(|x| x.into()), + )) +} + +pub fn column_info_to_field(column_info: &ColumnInfo) -> PolarsResult { + Ok(Field::new( + column_info.name.clone(), + parse_type_json_str(&column_info.type_json)?, + )) +} + +/// e.g. +/// ```json +/// {"name":"Int64","type":"long","nullable":true} +/// {"name":"List","type":{"type":"array","elementType":"long","containsNull":true},"nullable":true} +/// ``` +pub fn parse_type_json_str(type_json: &str) -> PolarsResult { + let decoded: ColumnTypeJson = decode_json_response(type_json.as_bytes())?; + + parse_type_json(&decoded).map_err(|e| { + e.wrap_msg(|e| { + format!( + "error parsing type response: {}, type_json: {}", + e, + TruncateErrorDetail(type_json) + ) + }) + }) +} + +/// We prefer this as `type_text` cannot be trusted for consistency (e.g. we may expect `decimal(int,int)` +/// but instead get `decimal`, or `struct<...>` and instead get `struct`). +pub fn parse_type_json(type_json: &ColumnTypeJson) -> PolarsResult { + use ColumnTypeJsonType::*; + + let out = match &type_json.type_ { + TypeName(name) => match name.as_str() { + "array" => { + let inner_json: &ColumnTypeJsonType = + type_json.element_type.as_ref().ok_or_else(|| { + polars_err!( + ComputeError: + "missing elementType in response for array type" + ) + })?; + + let inner_dtype = parse_type_json_type(inner_json)?; + + DataType::List(Box::new(inner_dtype)) + }, + + "struct" => { + let fields_json: &[ColumnTypeJson] = + type_json.fields.as_deref().ok_or_else(|| { + polars_err!( + ComputeError: + "missing elementType in response for array type" + ) + })?; + + let fields = fields_json + .iter() + .map(|x| { + let name = x.name.clone().ok_or_else(|| { + polars_err!( + ComputeError: + "missing name in fields response for struct type" + ) + })?; + let dtype = parse_type_json(x)?; + + Ok(Field::new(name, dtype)) + }) + .collect::>>()?; + + DataType::Struct(fields) + }, + + "map" => { + let key_type = type_json.key_type.as_ref().ok_or_else(|| { + polars_err!( + ComputeError: + "missing keyType in response for map type" + ) + })?; + + let value_type = type_json.value_type.as_ref().ok_or_else(|| { + polars_err!( + ComputeError: + "missing valueType in response for map type" + ) + })?; + + DataType::List(Box::new(DataType::Struct(vec![ + Field::new( + PlSmallStr::from_static("key"), + parse_type_json_type(key_type)?, + ), + Field::new( + PlSmallStr::from_static("value"), + parse_type_json_type(value_type)?, + ), + ]))) + }, + + name => parse_type_text(name)?, + }, + + TypeJson(type_json) => parse_type_json(type_json.as_ref())?, + }; + + Ok(out) +} + +fn parse_type_json_type(type_json_type: &ColumnTypeJsonType) -> PolarsResult { + use ColumnTypeJsonType::*; + + match type_json_type { + TypeName(name) => parse_type_text(name), + TypeJson(type_json) => parse_type_json(type_json.as_ref()), + } +} + +/// Parses the string variant of the `type` field within a `type_json`. This can be understood as +/// the leaf / non-nested datatypes of the field. +/// +/// References: +/// * https://spark.apache.org/docs/latest/sql-ref-datatypes.html +/// * https://docs.databricks.com/api/workspace/tables/get +/// * https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html +/// +/// Notes: +/// * `type_precision` and `type_scale` in the API response are defined as supplementary data to +/// the `type_text`, but from testing they aren't actually used - e.g. a decimal type would have a +/// `type_text` of `decimal(18, 2)` +fn parse_type_text(type_text: &str) -> PolarsResult { + use DataType::*; + use polars_core::prelude::TimeUnit; + + let dtype = match type_text { + "boolean" => Boolean, + + "tinyint" | "byte" => Int8, + "smallint" | "short" => Int16, + "int" | "integer" => Int32, + "bigint" | "long" => Int64, + + "float" | "real" => Float32, + "double" => Float64, + + "date" => Date, + "timestamp" | "timestamp_ntz" | "timestamp_ltz" => Datetime(TimeUnit::Microseconds, None), + + "string" => String, + "binary" => Binary, + + "null" | "void" => Null, + + v => { + if v.starts_with("decimal") { + // e.g. decimal(38,18) + (|| { + let (precision, scale) = v + .get(7..)? + .strip_prefix('(')? + .strip_suffix(')')? + .split_once(',')?; + let precision: usize = precision.parse().ok()?; + let scale: usize = scale.parse().ok()?; + + Some(DataType::Decimal(Some(precision), Some(scale))) + })() + .ok_or_else(|| { + polars_err!( + ComputeError: + "type format did not match decimal(int,int): {}", + v + ) + })? + } else { + polars_bail!( + ComputeError: + "parse_type_text unknown type name: {}", + v + ) + } + }, + }; + + Ok(dtype) +} + +// Conversion functions to API format. Mainly used for constructing the request to create tables. + +pub fn schema_to_column_info_list(schema: &Schema) -> PolarsResult> { + schema + .iter() + .enumerate() + .map(|(i, (name, dtype))| { + let name = name.clone(); + let type_text = dtype_to_type_text(dtype)?; + let type_name = dtype_to_type_name(dtype)?; + let type_json = serde_json::to_string(&field_to_type_json(name.clone(), dtype)?) + .map_err(to_compute_err)?; + + Ok(ColumnInfo { + name, + type_name, + type_text, + type_json, + position: Some(i.try_into().unwrap()), + comment: None, + partition_index: None, + }) + }) + .collect::>() +} + +/// Creates the `type_text` field of the API. Opposite of [`parse_type_text`] +fn dtype_to_type_text(dtype: &DataType) -> PolarsResult { + use DataType::*; + use polars_core::prelude::TimeUnit; + + macro_rules! S { + ($e:expr) => { + PlSmallStr::from_static($e) + }; + } + + let out = match dtype { + Boolean => S!("boolean"), + + Int8 => S!("tinyint"), + Int16 => S!("smallint"), + Int32 => S!("int"), + Int64 => S!("bigint"), + + Float32 => S!("float"), + Float64 => S!("double"), + + Date => S!("date"), + Datetime(TimeUnit::Microseconds, None) => S!("timestamp_ntz"), + + String => S!("string"), + Binary => S!("binary"), + + Null => S!("null"), + + Decimal(precision, scale) => { + let precision = precision.unwrap_or(38); + let scale = scale.unwrap_or(0); + + format_pl_smallstr!("decimal({},{})", precision, scale) + }, + + List(inner) => { + if let Some((key_type, value_type)) = get_list_map_type(inner) { + format_pl_smallstr!( + "map<{},{}>", + dtype_to_type_text(key_type)?, + dtype_to_type_text(value_type)? + ) + } else { + format_pl_smallstr!("array<{}>", dtype_to_type_text(inner)?) + } + }, + + Struct(fields) => { + // Yes, it's possible to construct column names containing the brackets. This won't + // affect us as we parse using `type_json` rather than this field. + let mut out = std::string::String::from("struct<"); + + for Field { name, dtype } in fields { + out.push_str(name); + out.push(':'); + out.push_str(&dtype_to_type_text(dtype)?); + out.push(','); + } + + if out.ends_with(',') { + out.truncate(out.len() - 1); + } + + out.push('>'); + + out.into() + }, + + v => polars_bail!( + ComputeError: + "dtype_to_type_text unsupported type: {}", + v + ), + }; + + Ok(out) +} + +/// Creates the `type_name` field, from testing this wasn't exactly the same as the `type_text` field. +fn dtype_to_type_name(dtype: &DataType) -> PolarsResult { + use DataType::*; + use polars_core::prelude::TimeUnit; + + macro_rules! S { + ($e:expr) => { + PlSmallStr::from_static($e) + }; + } + + let out = match dtype { + Boolean => S!("BOOLEAN"), + + Int8 => S!("BYTE"), + Int16 => S!("SHORT"), + Int32 => S!("INT"), + Int64 => S!("LONG"), + + Float32 => S!("FLOAT"), + Float64 => S!("DOUBLE"), + + Date => S!("DATE"), + Datetime(TimeUnit::Microseconds, None) => S!("TIMESTAMP_NTZ"), + String => S!("STRING"), + Binary => S!("BINARY"), + + Null => S!("NULL"), + + Decimal(..) => S!("DECIMAL"), + + List(inner) => { + if get_list_map_type(inner).is_some() { + S!("MAP") + } else { + S!("ARRAY") + } + }, + + Struct(..) => S!("STRUCT"), + + v => polars_bail!( + ComputeError: + "dtype_to_type_text unsupported type: {}", + v + ), + }; + + Ok(out) +} + +/// Creates the `type_json` field. +fn field_to_type_json(name: PlSmallStr, dtype: &DataType) -> PolarsResult { + Ok(ColumnTypeJson { + name: Some(name), + type_: dtype_to_type_json(dtype)?, + nullable: Some(true), + // We set this to Some(_) so that the output matches the one generated by Databricks. + metadata: Some(Default::default()), + + ..Default::default() + }) +} + +fn dtype_to_type_json(dtype: &DataType) -> PolarsResult { + use DataType::*; + use polars_core::prelude::TimeUnit; + + macro_rules! S { + ($e:expr) => { + ColumnTypeJsonType::from_static_type_name($e) + }; + } + + let out = match dtype { + Boolean => S!("boolean"), + + Int8 => S!("byte"), + Int16 => S!("short"), + Int32 => S!("integer"), + Int64 => S!("long"), + + Float32 => S!("float"), + Float64 => S!("double"), + + Date => S!("date"), + Datetime(TimeUnit::Microseconds, None) => S!("timestamp_ntz"), + + String => S!("string"), + Binary => S!("binary"), + + Null => S!("null"), + + Decimal(..) => ColumnTypeJsonType::TypeName(dtype_to_type_text(dtype)?), + + List(inner) => { + let out = if let Some((key_type, value_type)) = get_list_map_type(inner) { + ColumnTypeJson { + type_: ColumnTypeJsonType::from_static_type_name("map"), + key_type: Some(dtype_to_type_json(key_type)?), + value_type: Some(dtype_to_type_json(value_type)?), + value_contains_null: Some(true), + + ..Default::default() + } + } else { + ColumnTypeJson { + type_: ColumnTypeJsonType::from_static_type_name("array"), + element_type: Some(dtype_to_type_json(inner)?), + contains_null: Some(true), + + ..Default::default() + } + }; + + ColumnTypeJsonType::TypeJson(Box::new(out)) + }, + + Struct(fields) => { + let out = ColumnTypeJson { + type_: ColumnTypeJsonType::from_static_type_name("struct"), + fields: Some( + fields + .iter() + .map(|Field { name, dtype }| field_to_type_json(name.clone(), dtype)) + .collect::>()?, + ), + + ..Default::default() + }; + + ColumnTypeJsonType::TypeJson(Box::new(out)) + }, + + v => polars_bail!( + ComputeError: + "dtype_to_type_text unsupported type: {}", + v + ), + }; + + Ok(out) +} + +/// Tries to interpret the List type as a `map` field, which is essentially +/// List(Struct(("key", ), ("value", ))). +/// +/// Returns `Option<(key_type, value_type)>` +fn get_list_map_type(list_inner_dtype: &DataType) -> Option<(&DataType, &DataType)> { + let DataType::Struct(fields) = list_inner_dtype else { + return None; + }; + + let [fld1, fld2] = fields.as_slice() else { + return None; + }; + + if !(fld1.name == "key" && fld2.name == "value") { + return None; + } + + Some((fld1.dtype(), fld2.dtype())) +} diff --git a/crates/polars-io/src/catalog/unity/utils.rs b/crates/polars-io/src/catalog/unity/utils.rs new file mode 100644 index 000000000000..c76d8fa3b844 --- /dev/null +++ b/crates/polars-io/src/catalog/unity/utils.rs @@ -0,0 +1,121 @@ +use bytes::Bytes; +use polars_error::{PolarsResult, to_compute_err}; +use polars_utils::error::TruncateErrorDetail; +use reqwest::RequestBuilder; + +/// Performs the request and attaches the response body to any error messages. +pub(super) async fn do_request(request: reqwest::RequestBuilder) -> PolarsResult { + let resp = request.send().await.map_err(to_compute_err)?; + let opt_err = resp.error_for_status_ref().map(|_| ()); + let resp_bytes = resp.bytes().await.map_err(to_compute_err)?; + + opt_err.map_err(|e| { + to_compute_err(e).wrap_msg(|e| { + let body = String::from_utf8_lossy(&resp_bytes); + + format!( + "error: {}, response body: {}", + e, + TruncateErrorDetail(&body) + ) + }) + })?; + + Ok(resp_bytes) +} + +/// Support for traversing paginated response values that look like: +/// ```text +/// { +/// $key_name: [$T, $T, ...], +/// next_page_token: "token" or null, +/// } +/// ``` +#[macro_export] +macro_rules! impl_page_walk { + ($S:ty, $T:ty, key_name = $key_name:tt) => { + impl $S { + pub async fn next(&mut self) -> PolarsResult>> { + return self + .0 + .next(|bytes| { + let Response { + $key_name: out, + next_page_token, + } = decode_json_response(bytes)?; + + Ok((out, next_page_token)) + }) + .await; + + #[derive(serde::Deserialize)] + struct Response { + #[serde(default = "Vec::new")] + $key_name: Vec<$T>, + #[serde(default)] + next_page_token: Option, + } + } + + pub async fn read_all_pages(mut self) -> PolarsResult> { + let Some(mut out) = self.next().await? else { + return Ok(vec![]); + }; + + while let Some(v) = self.next().await? { + out.extend(v); + } + + Ok(out) + } + } + }; +} + +pub(crate) struct PageWalker { + request: RequestBuilder, + next_page_token: Option, + has_run: bool, +} + +impl PageWalker { + pub(crate) fn new(request: RequestBuilder) -> Self { + Self { + request, + next_page_token: None, + has_run: false, + } + } + + pub(crate) async fn next(&mut self, deserializer: F) -> PolarsResult> + where + F: Fn(&[u8]) -> PolarsResult<(T, Option)>, + { + let Some(resp_bytes) = self.next_bytes().await? else { + return Ok(None); + }; + + let (value, next_page_token) = deserializer(&resp_bytes)?; + self.next_page_token = next_page_token; + + Ok(Some(value)) + } + + pub(crate) async fn next_bytes(&mut self) -> PolarsResult> { + if self.has_run && self.next_page_token.is_none() { + return Ok(None); + } + + self.has_run = true; + + let request = self.request.try_clone().unwrap(); + + let request = if let Some(page_token) = self.next_page_token.take() { + request.query(&[("page_token", page_token)]) + } else { + request + }; + + do_request(request).await.map(Some) + } +} diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs new file mode 100644 index 000000000000..13b2a4d1e21d --- /dev/null +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -0,0 +1,221 @@ +//! Interface with the object_store crate and define AsyncSeek, AsyncRead. + +use std::sync::Arc; + +use object_store::ObjectStore; +use object_store::buffered::BufWriter; +use object_store::path::Path; +use polars_error::PolarsResult; +use polars_utils::file::WriteClose; +use tokio::io::AsyncWriteExt; + +use super::{CloudOptions, object_path_from_str}; +use crate::pl_async::{get_runtime, get_upload_chunk_size}; + +fn clone_io_err(e: &std::io::Error) -> std::io::Error { + std::io::Error::new(e.kind(), e.to_string()) +} + +/// Adaptor which wraps the interface of [ObjectStore::BufWriter] exposing a synchronous interface +/// which implements `std::io::Write`. +/// +/// This allows it to be used in sync code which would otherwise write to a simple File or byte stream, +/// such as with `polars::prelude::CsvWriter`. +/// +/// [ObjectStore::BufWriter]: https://docs.rs/object_store/latest/object_store/buffered/struct.BufWriter.html +pub struct BlockingCloudWriter { + state: std::io::Result, +} + +impl BlockingCloudWriter { + /// Construct a new BlockingCloudWriter, re-using the given `object_store` + /// + /// Creates a new (current-thread) Tokio runtime + /// which bridges the sync writing process with the async ObjectStore multipart uploading. + /// TODO: Naming? + pub fn new_with_object_store( + object_store: Arc, + path: Path, + ) -> PolarsResult { + let writer = BufWriter::with_capacity(object_store, path, get_upload_chunk_size()); + Ok(BlockingCloudWriter { state: Ok(writer) }) + } + + /// Constructs a new BlockingCloudWriter from a path and an optional set of CloudOptions. + /// + /// Wrapper around `BlockingCloudWriter::new_with_object_store` that is useful if you only have a single write task. + /// TODO: Naming? + pub async fn new(uri: &str, cloud_options: Option<&CloudOptions>) -> PolarsResult { + if let Some(local_path) = uri.strip_prefix("file://") { + // Local paths must be created first, otherwise object store will not write anything. + if !matches!(std::fs::exists(local_path), Ok(true)) { + panic!( + "[BlockingCloudWriter] Expected local file to be created: {}", + local_path + ); + } + } + + let (cloud_location, object_store) = + crate::cloud::build_object_store(uri, cloud_options, false).await?; + Self::new_with_object_store( + object_store.to_dyn_object_store().await, + object_path_from_str(&cloud_location.prefix)?, + ) + } + + /// Returns the underlying [`object_store::buffered::BufWriter`] + pub fn try_into_inner(mut self) -> std::io::Result { + // We can't just return self.state: + // * cannot move out of type `adaptors::BlockingCloudWriter`, which implements the `Drop` trait + std::mem::replace(&mut self.state, Err(std::io::Error::other(""))) + } + + /// Closes the writer, or returns the existing error if it exists. After this function is called + /// the writer is guaranteed to be in an error state. + pub fn close(&mut self) -> std::io::Result<()> { + match self.try_with_writer(|writer| get_runtime().block_in_place_on(writer.shutdown())) { + Ok(_) => { + self.state = Err(std::io::Error::other("closed")); + Ok(()) + }, + Err(e) => Err(e), + } + } + + fn try_with_writer(&mut self, func: F) -> std::io::Result + where + F: Fn(&mut BufWriter) -> std::io::Result, + { + let writer: &mut BufWriter = self.state.as_mut().map_err(|e| clone_io_err(e))?; + match func(writer) { + Ok(v) => Ok(v), + Err(e) => { + self.state = Err(clone_io_err(&e)); + Err(e) + }, + } + } +} + +impl std::io::Write for BlockingCloudWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + // SAFETY: + // We extend the lifetime for the duration of this function. This is safe as we block the + // async runtime here + let buf = unsafe { std::mem::transmute::<&[u8], &'static [u8]>(buf) }; + + self.try_with_writer(|writer| { + get_runtime() + .block_in_place_on(async { writer.write_all(buf).await.map(|_t| buf.len()) }) + }) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.try_with_writer(|writer| get_runtime().block_in_place_on(writer.flush())) + } +} + +impl WriteClose for BlockingCloudWriter { + fn close(mut self: Box) -> std::io::Result<()> { + BlockingCloudWriter::close(self.as_mut()) + } +} + +impl Drop for BlockingCloudWriter { + fn drop(&mut self) { + if self.state.is_err() { + return; + } + + // Note: We should not hit here - the writer should instead be explicitly closed. + // But we still have this here as a safety measure to prevent silently dropping errors. + match self.close() { + Ok(()) => {}, + e @ Err(_) => { + if std::thread::panicking() { + eprintln!("ERROR: CloudWriter errored on close: {:?}", e) + } else { + e.unwrap() + } + }, + } + } +} + +#[cfg(test)] +mod tests { + + use polars_core::df; + use polars_core::prelude::DataFrame; + + fn example_dataframe() -> DataFrame { + df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap() + } + + #[test] + #[cfg(feature = "csv")] + fn csv_to_local_objectstore_cloudwriter() { + use super::*; + use crate::csv::write::CsvWriter; + use crate::prelude::SerWriter; + + let mut df = example_dataframe(); + + let object_store: Arc = Arc::new( + object_store::local::LocalFileSystem::new_with_prefix(std::env::temp_dir()) + .expect("Could not initialize connection"), + ); + + let path: object_store::path::Path = "cloud_writer_example.csv".into(); + + let mut cloud_writer = + BlockingCloudWriter::new_with_object_store(object_store, path).unwrap(); + CsvWriter::new(&mut cloud_writer) + .finish(&mut df) + .expect("Could not write DataFrame as CSV to remote location"); + } + + // Skip this tests on Windows since it does not have a convenient /tmp/ location. + #[cfg_attr(target_os = "windows", ignore)] + #[cfg(feature = "csv")] + #[test] + fn cloudwriter_from_cloudlocation_test() { + use super::*; + use crate::SerReader; + use crate::csv::write::CsvWriter; + use crate::prelude::{CsvReadOptions, SerWriter}; + + let mut df = example_dataframe(); + + let path = "/tmp/cloud_writer_example2.csv"; + + std::fs::File::create(path).unwrap(); + + let mut cloud_writer = get_runtime() + .block_on(BlockingCloudWriter::new( + format!("file://{}", path).as_str(), + None, + )) + .unwrap(); + + CsvWriter::new(&mut cloud_writer) + .finish(&mut df) + .expect("Could not write DataFrame as CSV to remote location"); + + cloud_writer.close().unwrap(); + + assert_eq!( + CsvReadOptions::default() + .try_into_reader_with_file_path(Some(path.into())) + .unwrap() + .finish() + .unwrap(), + df + ); + } +} diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs new file mode 100644 index 000000000000..8f1dd531fe9d --- /dev/null +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -0,0 +1,798 @@ +use std::fmt::Debug; +use std::future::Future; +use std::hash::Hash; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use async_trait::async_trait; +#[cfg(feature = "aws")] +pub use object_store::aws::AwsCredential; +#[cfg(feature = "azure")] +pub use object_store::azure::AzureCredential; +#[cfg(feature = "gcp")] +pub use object_store::gcp::GcpCredential; +use polars_core::config; +use polars_error::{PolarsResult, polars_bail}; +#[cfg(feature = "python")] +use polars_utils::python_function::PythonObject; +#[cfg(feature = "python")] +use python_impl::PythonCredentialProvider; + +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +pub enum PlCredentialProvider { + /// Prefer using [`PlCredentialProvider::from_func`] instead of constructing this directly + Function(CredentialProviderFunction), + #[cfg(feature = "python")] + Python(python_impl::PythonCredentialProvider), +} + +impl PlCredentialProvider { + /// Accepts a function that returns (credential, expiry time as seconds since UNIX_EPOCH) + /// + /// This functionality is unstable. + pub fn from_func( + // Internal notes + // * This function is exposed as the Rust API for `PlCredentialProvider` + func: impl Fn() -> Pin< + Box> + Send + Sync>, + > + Send + + Sync + + 'static, + ) -> Self { + Self::Function(CredentialProviderFunction(Arc::new(func))) + } + + /// Intended to be called with an internal `CredentialProviderBuilder` from + /// py-polars. + #[cfg(feature = "python")] + pub fn from_python_builder(func: pyo3::PyObject) -> Self { + Self::Python(python_impl::PythonCredentialProvider::Builder(Arc::new( + PythonObject(func), + ))) + } + + pub(super) fn func_addr(&self) -> usize { + match self { + Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize, + #[cfg(feature = "python")] + Self::Python(v) => v.func_addr(), + } + } + + /// Python passes a `CredentialProviderBuilder`, this calls the builder to build the final + /// credential provider. + /// + /// This returns `Option` as the auto-initialization case is fallible and falls back to None. + pub(crate) fn try_into_initialized(self) -> PolarsResult> { + match self { + Self::Function(_) => Ok(Some(self)), + #[cfg(feature = "python")] + Self::Python(v) => Ok(v.try_into_initialized()?.map(Self::Python)), + } + } +} + +pub enum ObjectStoreCredential { + #[cfg(feature = "aws")] + Aws(Arc), + #[cfg(feature = "azure")] + Azure(Arc), + #[cfg(feature = "gcp")] + Gcp(Arc), + /// For testing purposes + None, +} + +impl ObjectStoreCredential { + fn variant_name(&self) -> &'static str { + match self { + #[cfg(feature = "aws")] + Self::Aws(_) => "Aws", + #[cfg(feature = "azure")] + Self::Azure(_) => "Azure", + #[cfg(feature = "gcp")] + Self::Gcp(_) => "Gcp", + Self::None => "None", + } + } + + fn panic_type_mismatch(&self, expected: &str) { + panic!( + "impl error: credential type mismatch: expected {}, got {} instead", + expected, + self.variant_name() + ) + } + + #[cfg(feature = "aws")] + fn unwrap_aws(self) -> Arc { + let Self::Aws(v) = self else { + self.panic_type_mismatch("aws"); + unreachable!() + }; + v + } + + #[cfg(feature = "azure")] + fn unwrap_azure(self) -> Arc { + let Self::Azure(v) = self else { + self.panic_type_mismatch("azure"); + unreachable!() + }; + v + } + + #[cfg(feature = "gcp")] + fn unwrap_gcp(self) -> Arc { + let Self::Gcp(v) = self else { + self.panic_type_mismatch("gcp"); + unreachable!() + }; + v + } +} + +pub trait IntoCredentialProvider: Sized { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + unimplemented!() + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + unimplemented!() + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + unimplemented!() + } +} + +impl IntoCredentialProvider for PlCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + match self { + Self::Function(v) => v.into_aws_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_aws_provider(), + } + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + match self { + Self::Function(v) => v.into_azure_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_azure_provider(), + } + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + match self { + Self::Function(v) => v.into_gcp_provider(), + #[cfg(feature = "python")] + Self::Python(v) => v.into_gcp_provider(), + } + } +} + +type CredentialProviderFunctionImpl = Arc< + dyn Fn() -> Pin< + Box> + Send + Sync>, + > + Send + + Sync, +>; + +/// Wrapper that implements [`IntoCredentialProvider`], [`Debug`], [`PartialEq`], [`Hash`] etc. +#[derive(Clone)] +pub struct CredentialProviderFunction(CredentialProviderFunctionImpl); + +macro_rules! build_to_object_store_err { + ($s:expr) => {{ + fn to_object_store_err( + e: impl std::error::Error + Send + Sync + 'static, + ) -> object_store::Error { + object_store::Error::Generic { + store: $s, + source: Box::new(e), + } + } + + to_object_store_err + }}; +} + +impl IntoCredentialProvider for CredentialProviderFunction { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::aws::AwsCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0.0().await?; + PolarsResult::Ok((creds.unwrap_aws(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-aws")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + })), + )) + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::azure::AzureCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0.0().await?; + PolarsResult::Ok((creds.unwrap_azure(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-azure")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))), + )) + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + #[derive(Debug)] + struct S( + CredentialProviderFunction, + FetchedCredentialsCache>, + ); + + #[async_trait] + impl object_store::CredentialProvider for S { + type Credential = object_store::gcp::GcpCredential; + + async fn get_credential(&self) -> object_store::Result> { + self.1 + .get_maybe_update(async { + let (creds, expiry) = self.0.0().await?; + PolarsResult::Ok((creds.unwrap_gcp(), expiry)) + }) + .await + .map_err(build_to_object_store_err!("credential-provider-gcp")) + } + } + + Arc::new(S( + self, + FetchedCredentialsCache::new(Arc::new(GcpCredential { + bearer: String::new(), + })), + )) + } +} + +impl Debug for CredentialProviderFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "credential provider function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for CredentialProviderFunction {} + +impl PartialEq for CredentialProviderFunction { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Hash for CredentialProviderFunction { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for PlCredentialProvider { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cfg(feature = "python")] + { + Ok(Self::Python(PythonCredentialProvider::deserialize( + _deserializer, + )?)) + } + #[cfg(not(feature = "python"))] + { + use serde::de::Error; + Err(D::Error::custom("cannot deserialize PlCredentialProvider")) + } + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PlCredentialProvider { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + + #[cfg(feature = "python")] + if let PlCredentialProvider::Python(v) = self { + return v.serialize(_serializer); + } + + Err(S::Error::custom(format!("cannot serialize {:?}", self))) + } +} + +/// Avoids calling the credential provider function if we have not yet passed the expiry time. +#[derive(Debug)] +struct FetchedCredentialsCache(tokio::sync::Mutex<(C, u64)>); + +impl FetchedCredentialsCache { + fn new(init_creds: C) -> Self { + Self(tokio::sync::Mutex::new((init_creds, 0))) + } + + async fn get_maybe_update( + &self, + // Taking an `impl Future` here allows us to potentially avoid a `Box::pin` allocation from + // a `Fn() -> Pin>` by having it wrapped in an `async { f() }` block. We + // will not poll that block if the credentials have not yet expired. + update_func: impl Future>, + ) -> PolarsResult { + let verbose = config::verbose(); + + fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String { + if last_fetched_expiry == u64::MAX { + "expiry = (never expires)".into() + } else { + format!( + "expiry = {} (in {} seconds)", + last_fetched_expiry, + last_fetched_expiry.saturating_sub(now) + ) + } + } + + let mut inner = self.0.lock().await; + let (last_fetched_credentials, last_fetched_expiry) = &mut *inner; + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Ensure the credential is valid for at least this many seconds to + // accommodate for latency. + const REQUEST_TIME_BUFFER: u64 = 7; + + if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER { + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Call update_func: current_time = {}\ + , last_fetched_expiry = {}", + current_time, *last_fetched_expiry + ) + } + let (credentials, expiry) = update_func.await?; + + *last_fetched_credentials = credentials; + *last_fetched_expiry = expiry; + + if expiry < current_time && expiry != 0 { + polars_bail!( + ComputeError: + "credential expiry time {} is older than system time {} \ + by {} seconds", + expiry, + current_time, + current_time - expiry + ) + } + + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Finish update_func: new {}", + expiry_msg( + *last_fetched_expiry, + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + ) + ) + } + } else if verbose { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + eprintln!( + "[FetchedCredentialsCache]: Using cached credentials: \ + current_time = {}, {}", + now, + expiry_msg(*last_fetched_expiry, now) + ) + } + + Ok(last_fetched_credentials.clone()) + } +} + +#[cfg(feature = "python")] +mod python_impl { + use std::hash::Hash; + use std::sync::Arc; + + use polars_error::{PolarsError, PolarsResult, to_compute_err}; + use polars_utils::python_function::PythonObject; + use pyo3::Python; + use pyo3::exceptions::PyValueError; + use pyo3::pybacked::PyBackedStr; + use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods}; + + use super::IntoCredentialProvider; + + #[derive(Clone, Debug)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + pub enum PythonCredentialProvider { + #[cfg_attr( + feature = "serde", + serde( + serialize_with = "PythonObject::serialize_with_pyversion", + deserialize_with = "PythonObject::deserialize_with_pyversion" + ) + )] + /// Indicates `py_object` is a `CredentialProviderBuilder`. + Builder(Arc), + #[cfg_attr( + feature = "serde", + serde( + serialize_with = "PythonObject::serialize_with_pyversion", + deserialize_with = "PythonObject::deserialize_with_pyversion" + ) + )] + /// Indicates `py_object` is an instantiated credential provider + Provider(Arc), + } + + impl PythonCredentialProvider { + /// Performs initialization if necessary. + /// + /// This exists as a separate step that must be called beforehand. This approach is easier + /// as the alternative is to refactor the `IntoCredentialProvider` trait to return + /// `PolarsResult>` for every single function. + pub(super) fn try_into_initialized(self) -> PolarsResult> { + match self { + Self::Builder(py_object) => { + let opt_initialized_py_object = Python::with_gil(|py| { + let build_fn = py_object.getattr(py, "build_credential_provider")?; + + let v = build_fn.call0(py)?; + let v = (!v.is_none(py)).then_some(v); + + pyo3::PyResult::Ok(v) + }) + .map_err(to_compute_err)?; + + Ok(opt_initialized_py_object + .map(PythonObject) + .map(Arc::new) + .map(Self::Provider)) + }, + Self::Provider(_) => { + // Note: We don't expect to hit here. + Ok(Some(self)) + }, + } + } + + fn unwrap_as_provider(self) -> Arc { + match self { + Self::Builder(_) => panic!(), + Self::Provider(v) => v, + } + } + + pub(super) fn func_addr(&self) -> usize { + (match self { + Self::Builder(v) => Arc::as_ptr(v), + Self::Provider(v) => Arc::as_ptr(v), + }) as *const () as usize + } + } + + impl IntoCredentialProvider for PythonCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + use polars_error::{PolarsResult, to_compute_err}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + let func = self.unwrap_as_provider(); + + CredentialProviderFunction(Arc::new(move || { + let func = func.clone(); + Box::pin(async move { + let mut credentials = object_store::aws::AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + }; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::>()?; + + match k.as_ref() { + "aws_access_key_id" => { + credentials.key_id = v.ok_or_else(|| { + PyValueError::new_err("aws_access_key_id was None") + })?; + }, + "aws_secret_access_key" => { + credentials.secret_key = v.ok_or_else(|| { + PyValueError::new_err("aws_secret_access_key was None") + })? + }, + "aws_session_token" => credentials.token = v, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for aws: {}, \ + valid configuration keys are: \ + {}, {}, {}", + v, + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token" + ))); + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + if credentials.key_id.is_empty() { + return Err(PolarsError::ComputeError( + "aws_access_key_id was empty or not given".into(), + )); + } + + if credentials.secret_key.is_empty() { + return Err(PolarsError::ComputeError( + "aws_secret_access_key was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry)) + }) + })) + .into_aws_provider() + } + + #[cfg(feature = "azure")] + fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider { + use object_store::azure::AzureAccessKey; + use polars_error::{PolarsResult, to_compute_err}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + let func = self.unwrap_as_provider(); + + CredentialProviderFunction(Arc::new(move || { + let func = func.clone(); + Box::pin(async move { + let mut credentials = None; + + static VALID_KEYS_MSG: &str = + "valid configuration keys are: account_key, bearer_token"; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::()?; + + match k.as_ref() { + "account_key" => { + credentials = + Some(object_store::azure::AzureCredential::AccessKey( + AzureAccessKey::try_new(v.as_str()).map_err(|e| { + PyValueError::new_err(e.to_string()) + })?, + )) + }, + "bearer_token" => { + credentials = + Some(object_store::azure::AzureCredential::BearerToken(v)) + }, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for azure: {}, {}", + v, VALID_KEYS_MSG + ))); + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + let Some(credentials) = credentials else { + return Err(PolarsError::ComputeError( + format!( + "did not find a valid configuration key for azure, {}", + VALID_KEYS_MSG + ) + .into(), + )); + }; + + PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry)) + }) + })) + .into_azure_provider() + } + + #[cfg(feature = "gcp")] + fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider { + use polars_error::{PolarsResult, to_compute_err}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + let func = self.unwrap_as_provider(); + + CredentialProviderFunction(Arc::new(move || { + let func = func.clone(); + Box::pin(async move { + let mut credentials = object_store::gcp::GcpCredential { + bearer: String::new(), + }; + + let expiry = Python::with_gil(|py| { + let v = func.0.call0(py)?.into_bound(py); + let (storage_options, expiry) = + v.extract::<(pyo3::Bound<'_, PyDict>, Option)>()?; + + for (k, v) in storage_options.iter() { + let k = k.extract::()?; + let v = v.extract::()?; + + match k.as_ref() { + "bearer_token" => credentials.bearer = v, + v => { + return pyo3::PyResult::Err(PyValueError::new_err(format!( + "unknown configuration key for gcp: {}, \ + valid configuration keys are: {}", + v, "bearer_token", + ))); + }, + } + } + + pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX)) + }) + .map_err(to_compute_err)?; + + if credentials.bearer.is_empty() { + return Err(PolarsError::ComputeError( + "bearer was empty or not given".into(), + )); + } + + PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry)) + }) + })) + .into_gcp_provider() + } + } + + // Note: We don't consider `is_builder` for hash/eq - we don't expect the same Arc + // to be referenced as both true and false from the `is_builder` field. + + impl Eq for PythonCredentialProvider {} + + impl PartialEq for PythonCredentialProvider { + fn eq(&self, other: &Self) -> bool { + self.func_addr() == other.func_addr() + } + } + + impl Hash for PythonCredentialProvider { + fn hash(&self, state: &mut H) { + // # Safety + // * Inner is an `Arc` + // * Visibility is limited to super + // * No code in `mod python_impl` or `super` mutates the Arc inner. + state.write_usize(self.func_addr()) + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "serde")] + #[allow(clippy::redundant_pattern_matching)] + #[test] + fn test_serde() { + use super::*; + + assert!(matches!( + serde_json::to_string(&Some(PlCredentialProvider::from_func(|| { + Box::pin(core::future::ready(PolarsResult::Ok(( + ObjectStoreCredential::None, + 0, + )))) + }))), + Err(_) + )); + + assert!(matches!( + serde_json::to_string(&Option::::None), + Ok(String { .. }) + )); + + assert!(matches!( + serde_json::from_str::>( + serde_json::to_string(&Option::::None) + .unwrap() + .as_str() + ), + Ok(None) + )); + } +} diff --git a/crates/polars-io/src/cloud/glob.rs b/crates/polars-io/src/cloud/glob.rs new file mode 100644 index 000000000000..6a0127d64244 --- /dev/null +++ b/crates/polars-io/src/cloud/glob.rs @@ -0,0 +1,418 @@ +use std::borrow::Cow; + +use futures::TryStreamExt; +use object_store::path::Path; +use polars_core::error::to_compute_err; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; +use regex::Regex; +use url::Url; + +use super::{CloudOptions, parse_url}; + +const DELIMITER: char = '/'; + +/// Converts a glob to regex form. +/// +/// # Returns +/// 1. the prefix part (all path components until the first one with '*') +/// 2. a regular expression representation of the rest. +pub(crate) fn extract_prefix_expansion(url: &str) -> PolarsResult<(Cow, Option)> { + let url = url.strip_prefix('/').unwrap_or(url); + // (offset, len, replacement) + let mut replacements: Vec<(usize, usize, &[u8])> = vec![]; + + // The position after the last slash before glob characters begin. + // `a/b/c*/` + // ^ + let mut pos: usize = if let Some(after_last_slash) = memchr::memchr2(b'*', b'[', url.as_bytes()) + .map(|i| { + url.as_bytes()[..i] + .iter() + .rposition(|x| *x == b'/') + .map_or(0, |x| 1 + x) + }) { + // First value is used as the starting point later. + replacements.push((after_last_slash, 0, &[])); + after_last_slash + } else { + usize::MAX + }; + + while pos < url.len() { + match memchr::memchr2(b'*', b'.', &url.as_bytes()[pos..]) { + None => break, + Some(i) => pos += i, + } + + let (len, replace): (usize, &[u8]) = match &url[pos..] { + // Accept: + // - `**/` + // - `**` only if it is the end of the url + v if v.starts_with("**") && (v.len() == 2 || v.as_bytes()[2] == b'/') => { + // Wrapping in a capture group ensures we also match non-nested paths. + (3, b"(.*/)?" as _) + }, + v if v.starts_with("**") => { + polars_bail!(ComputeError: "invalid ** glob pattern") + }, + v if v.starts_with('*') => (1, b"[^/]*" as _), + // Dots need to be escaped in regex. + v if v.starts_with('.') => (1, b"\\." as _), + _ => { + pos += 1; + continue; + }, + }; + + replacements.push((pos, len, replace)); + pos += len; + } + + if replacements.is_empty() { + return Ok((Cow::Borrowed(url), None)); + } + + let prefix = Cow::Borrowed(&url[..replacements[0].0]); + + let mut pos = replacements[0].0; + let mut expansion = Vec::with_capacity(url.len() - pos); + expansion.push(b'^'); + + for (offset, len, replace) in replacements { + expansion.extend_from_slice(&url.as_bytes()[pos..offset]); + expansion.extend_from_slice(replace); + pos = offset + len; + } + + if pos < url.len() { + expansion.extend_from_slice(&url.as_bytes()[pos..]); + } + + expansion.push(b'$'); + + Ok((prefix, Some(String::from_utf8(expansion).unwrap()))) +} + +/// A location on cloud storage, may have wildcards. +#[derive(PartialEq, Debug, Default)] +pub struct CloudLocation { + /// The scheme (s3, ...). + pub scheme: PlSmallStr, + /// The bucket name. + pub bucket: PlSmallStr, + /// The prefix inside the bucket, this will be the full key when wildcards are not used. + pub prefix: String, + /// The path components that need to be expanded. + pub expansion: Option, +} + +impl CloudLocation { + pub fn from_url(parsed: &Url, glob: bool) -> PolarsResult { + let is_local = parsed.scheme() == "file"; + let (bucket, key) = if is_local { + ("".into(), parsed.path()) + } else { + if parsed.scheme().starts_with("http") { + return Ok(CloudLocation { + scheme: parsed.scheme().into(), + ..Default::default() + }); + } + + let key = parsed.path(); + + let bucket = format_pl_smallstr!( + "{}", + &parsed[url::Position::BeforeUsername..url::Position::AfterPort] + ); + + if bucket.is_empty() { + polars_bail!(ComputeError: "CloudLocation::from_url(): empty bucket: {}", parsed); + } + + (bucket, key) + }; + + let key = percent_encoding::percent_decode_str(key) + .decode_utf8() + .map_err(to_compute_err)?; + let (prefix, expansion) = if glob { + let (prefix, expansion) = extract_prefix_expansion(&key)?; + let mut prefix = prefix.into_owned(); + if is_local && key.starts_with(DELIMITER) && !prefix.starts_with(DELIMITER) { + prefix.insert(0, DELIMITER); + } + (prefix, expansion.map(|x| x.into())) + } else { + (key.as_ref().into(), None) + }; + + Ok(CloudLocation { + scheme: parsed.scheme().into(), + bucket, + prefix, + expansion, + }) + } + + /// Parse a CloudLocation from an url. + pub fn new(url: &str, glob: bool) -> PolarsResult { + let parsed = parse_url(url).map_err(to_compute_err)?; + Self::from_url(&parsed, glob) + } +} + +/// Return a full url from a key relative to the given location. +fn full_url(scheme: &str, bucket: &str, key: Path) -> String { + format!("{scheme}://{bucket}/{key}") +} + +/// A simple matcher, if more is required consider depending on https://crates.io/crates/globset. +/// The Cloud list api returns a list of all the file names under a prefix, there is no additional cost of `readdir`. +pub(crate) struct Matcher { + prefix: String, + re: Option, +} + +impl Matcher { + /// Build a Matcher for the given prefix and expansion. + pub(crate) fn new(prefix: String, expansion: Option<&str>) -> PolarsResult { + // Cloud APIs accept a prefix without any expansion, extract it. + let re = expansion + .map(polars_utils::regex_cache::compile_regex) + .transpose()?; + Ok(Matcher { prefix, re }) + } + + pub(crate) fn is_matching(&self, key: &str) -> bool { + if !key.starts_with(self.prefix.as_str()) { + // Prefix does not match, should not happen. + return false; + } + if self.re.is_none() { + return true; + } + let last = &key[self.prefix.len()..]; + self.re.as_ref().unwrap().is_match(last.as_ref()) + } +} + +/// List files with a prefix derived from the pattern. +pub async fn glob(url: &str, cloud_options: Option<&CloudOptions>) -> PolarsResult> { + // Find the fixed prefix, up to the first '*'. + + let ( + CloudLocation { + scheme, + bucket, + prefix, + expansion, + }, + store, + ) = super::build_object_store(url, cloud_options, true).await?; + let matcher = &Matcher::new( + if scheme == "file" { + // For local paths the returned location has the leading slash stripped. + prefix[1..].into() + } else { + prefix.clone() + }, + expansion.as_deref(), + )?; + + let path = Path::from(prefix.as_str()); + let path = Some(&path); + + let mut locations = store + .try_exec_rebuild_on_err(|store| { + let st = store.clone(); + + async { + let store = st; + store + .list(path) + .try_filter_map(|x| async move { + let out = (x.size > 0 && matcher.is_matching(x.location.as_ref())) + .then_some(x.location); + Ok(out) + }) + .try_collect::>() + .await + .map_err(to_compute_err) + } + }) + .await?; + + locations.sort_unstable(); + Ok(locations + .into_iter() + .map(|l| full_url(&scheme, &bucket, l)) + .collect::>()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_cloud_location() { + assert_eq!( + CloudLocation::new("s3://a/b", true).unwrap(), + CloudLocation { + scheme: "s3".into(), + bucket: "a".into(), + prefix: "b".into(), + expansion: None, + } + ); + assert_eq!( + CloudLocation::new("s3://a/b/*.c", true).unwrap(), + CloudLocation { + scheme: "s3".into(), + bucket: "a".into(), + prefix: "b/".into(), + expansion: Some("^[^/]*\\.c$".into()), + } + ); + assert_eq!( + CloudLocation::new("file:///a/b", true).unwrap(), + CloudLocation { + scheme: "file".into(), + bucket: "".into(), + prefix: "/a/b".into(), + expansion: None, + } + ); + } + + #[test] + fn test_extract_prefix_expansion() { + assert!(extract_prefix_expansion("**url").is_err()); + assert_eq!( + extract_prefix_expansion("a/b.c").unwrap(), + ("a/b.c".into(), None) + ); + assert_eq!( + extract_prefix_expansion("a/**").unwrap(), + ("a/".into(), Some("^(.*/)?$".into())) + ); + assert_eq!( + extract_prefix_expansion("a/**/b").unwrap(), + ("a/".into(), Some("^(.*/)?b$".into())) + ); + assert_eq!( + extract_prefix_expansion("a/**/*b").unwrap(), + ("a/".into(), Some("^(.*/)?[^/]*b$".into())) + ); + assert_eq!( + extract_prefix_expansion("a/**/data/*b").unwrap(), + ("a/".into(), Some("^(.*/)?data/[^/]*b$".into())) + ); + assert_eq!( + extract_prefix_expansion("a/*b").unwrap(), + ("a/".into(), Some("^[^/]*b$".into())) + ); + } + + #[test] + fn test_matcher_file_name() { + let cloud_location = CloudLocation::new("s3://bucket/folder/*.parquet", true).unwrap(); + let a = Matcher::new(cloud_location.prefix, cloud_location.expansion.as_deref()).unwrap(); + // Regular match. + assert!(a.is_matching(Path::from("folder/1.parquet").as_ref())); + // Require . in the file name. + assert!(!a.is_matching(Path::from("folder/1parquet").as_ref())); + // Intermediary folders are not allowed. + assert!(!a.is_matching(Path::from("folder/other/1.parquet").as_ref())); + } + + #[test] + fn test_matcher_folders() { + let cloud_location = CloudLocation::new("s3://bucket/folder/**/*.parquet", true).unwrap(); + + let a = Matcher::new(cloud_location.prefix, cloud_location.expansion.as_deref()).unwrap(); + // Intermediary folders are optional. + assert!(a.is_matching(Path::from("folder/1.parquet").as_ref())); + // Intermediary folders are allowed. + assert!(a.is_matching(Path::from("folder/other/1.parquet").as_ref())); + + let cloud_location = + CloudLocation::new("s3://bucket/folder/**/data/*.parquet", true).unwrap(); + let a = Matcher::new(cloud_location.prefix, cloud_location.expansion.as_deref()).unwrap(); + + // Required folder `data` is missing. + assert!(!a.is_matching(Path::from("folder/1.parquet").as_ref())); + // Required folder is present. + assert!(a.is_matching(Path::from("folder/data/1.parquet").as_ref())); + // Required folder is present and additional folders are allowed. + assert!(a.is_matching(Path::from("folder/other/data/1.parquet").as_ref())); + } + + #[test] + fn test_cloud_location_no_glob() { + let cloud_location = CloudLocation::new("s3://bucket/[*", false).unwrap(); + assert_eq!( + cloud_location, + CloudLocation { + scheme: "s3".into(), + bucket: "bucket".into(), + prefix: "/[*".into(), + expansion: None, + }, + ) + } + + #[test] + fn test_cloud_location_percentages() { + use super::CloudLocation; + + let path = "s3://bucket/%25"; + let cloud_location = CloudLocation::new(path, true).unwrap(); + + assert_eq!( + cloud_location, + CloudLocation { + scheme: "s3".into(), + bucket: "bucket".into(), + prefix: "%25".into(), + expansion: None, + } + ); + + let path = "https://pola.rs/%25"; + let cloud_location = CloudLocation::new(path, true).unwrap(); + + assert_eq!( + cloud_location, + CloudLocation { + scheme: "https".into(), + bucket: "".into(), + prefix: "".into(), + expansion: None, + } + ); + } + + #[test] + fn test_glob_wildcard_21736() { + let url = "s3://bucket/folder/**/data.parquet"; + let cloud_location = CloudLocation::new(url, true).unwrap(); + + let a = Matcher::new(cloud_location.prefix, cloud_location.expansion.as_deref()).unwrap(); + + assert!(!a.is_matching("folder/_data.parquet")); + + assert!(a.is_matching("folder/data.parquet")); + assert!(a.is_matching("folder/abc/data.parquet")); + assert!(a.is_matching("folder/abc/def/data.parquet")); + + let url = "s3://bucket/folder/data_*.parquet"; + let cloud_location = CloudLocation::new(url, true).unwrap(); + + let a = Matcher::new(cloud_location.prefix, cloud_location.expansion.as_deref()).unwrap(); + + assert!(!a.is_matching("folder/data_1.ipc")) + } +} diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs new file mode 100644 index 000000000000..7ae2d99444a7 --- /dev/null +++ b/crates/polars-io/src/cloud/mod.rs @@ -0,0 +1,24 @@ +//! Interface with cloud storage through the object_store crate. + +#[cfg(feature = "cloud")] +mod adaptors; +#[cfg(feature = "cloud")] +mod glob; +#[cfg(feature = "cloud")] +mod object_store_setup; +pub mod options; +#[cfg(feature = "cloud")] +mod polars_object_store; + +#[cfg(feature = "cloud")] +pub use adaptors::*; +#[cfg(feature = "cloud")] +pub use glob::*; +#[cfg(feature = "cloud")] +pub use object_store_setup::*; +pub use options::*; +#[cfg(feature = "cloud")] +pub use polars_object_store::*; + +#[cfg(feature = "cloud")] +pub mod credential_provider; diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs new file mode 100644 index 000000000000..e486b1af0ee8 --- /dev/null +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -0,0 +1,254 @@ +use std::sync::{Arc, LazyLock}; + +use object_store::ObjectStore; +use object_store::local::LocalFileSystem; +use polars_core::config::{self, verbose_print_sensitive}; +use polars_error::{PolarsError, PolarsResult, polars_bail, to_compute_err}; +use polars_utils::aliases::PlHashMap; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::{format_pl_smallstr, pl_serialize}; +use tokio::sync::RwLock; +use url::Url; + +use super::{CloudLocation, CloudOptions, CloudType, PolarsObjectStore, parse_url}; +use crate::cloud::CloudConfig; + +/// Object stores must be cached. Every object-store will do DNS lookups and +/// get rate limited when querying the DNS (can take up to 5s). +/// Other reasons are connection pools that must be shared between as much as possible. +#[allow(clippy::type_complexity)] +static OBJECT_STORE_CACHE: LazyLock, PolarsObjectStore>>> = + LazyLock::new(Default::default); + +#[allow(dead_code)] +fn err_missing_feature(feature: &str, scheme: &str) -> PolarsResult> { + polars_bail!( + ComputeError: + "feature '{}' must be enabled in order to use '{}' cloud urls", feature, scheme, + ); +} + +/// Get the key of a url for object store registration. +fn url_and_creds_to_key(url: &Url, options: Option<&CloudOptions>) -> Vec { + // We include credentials as they can expire, so users will send new credentials for the same url. + let cloud_options = options.map( + |CloudOptions { + // Destructure to ensure this breaks if anything changes. + max_retries, + #[cfg(feature = "file_cache")] + file_cache_ttl, + config, + #[cfg(feature = "cloud")] + credential_provider, + }| { + CloudOptions2 { + max_retries: *max_retries, + #[cfg(feature = "file_cache")] + file_cache_ttl: *file_cache_ttl, + config: config.clone(), + #[cfg(feature = "cloud")] + credential_provider: credential_provider.as_ref().map_or(0, |x| x.func_addr()), + } + }, + ); + + let cache_key = CacheKey { + url_base: format_pl_smallstr!( + "{}", + &url[url::Position::BeforeScheme..url::Position::AfterPort] + ), + cloud_options, + }; + + verbose_print_sensitive(|| format!("object store cache key: {} {:?}", url, &cache_key)); + + return pl_serialize::serialize_to_bytes::<_, false>(&cache_key).unwrap(); + + #[derive(Clone, Debug, PartialEq, Hash, Eq)] + #[cfg_attr(feature = "serde", derive(serde::Serialize))] + struct CacheKey { + url_base: PlSmallStr, + cloud_options: Option, + } + + /// Variant of CloudOptions for serializing to a cache key. The credential + /// provider is replaced by the function address. + #[derive(Clone, Debug, PartialEq, Hash, Eq)] + #[cfg_attr(feature = "serde", derive(serde::Serialize))] + struct CloudOptions2 { + max_retries: usize, + #[cfg(feature = "file_cache")] + file_cache_ttl: u64, + config: Option, + #[cfg(feature = "cloud")] + credential_provider: usize, + } +} + +/// Construct an object_store `Path` from a string without any encoding/decoding. +pub fn object_path_from_str(path: &str) -> PolarsResult { + object_store::path::Path::parse(path).map_err(to_compute_err) +} + +#[derive(Debug, Clone)] +pub(crate) struct PolarsObjectStoreBuilder { + url: PlSmallStr, + parsed_url: Url, + #[allow(unused)] + scheme: PlSmallStr, + cloud_type: CloudType, + options: Option, +} + +impl PolarsObjectStoreBuilder { + pub(super) async fn build_impl(&self) -> PolarsResult> { + let options = self + .options + .as_ref() + .unwrap_or_else(|| CloudOptions::default_static_ref()); + + let store = match self.cloud_type { + CloudType::Aws => { + #[cfg(feature = "aws")] + { + let store = options.build_aws(&self.url).await?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + } + #[cfg(not(feature = "aws"))] + return err_missing_feature("aws", &self.scheme); + }, + CloudType::Gcp => { + #[cfg(feature = "gcp")] + { + let store = options.build_gcp(&self.url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + } + #[cfg(not(feature = "gcp"))] + return err_missing_feature("gcp", &self.scheme); + }, + CloudType::Azure => { + { + #[cfg(feature = "azure")] + { + let store = options.build_azure(&self.url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + } + } + #[cfg(not(feature = "azure"))] + return err_missing_feature("azure", &self.scheme); + }, + CloudType::File => { + let local = LocalFileSystem::new(); + Ok::<_, PolarsError>(Arc::new(local) as Arc) + }, + CloudType::Http => { + { + #[cfg(feature = "http")] + { + let store = options.build_http(&self.url)?; + PolarsResult::Ok(Arc::new(store) as Arc) + } + } + #[cfg(not(feature = "http"))] + return err_missing_feature("http", &cloud_location.scheme); + }, + CloudType::Hf => panic!("impl error: unresolved hf:// path"), + }?; + + Ok(store) + } + + /// Note: Use `build_impl` for a non-caching version. + pub(super) async fn build(self) -> PolarsResult { + let opt_cache_key = match &self.cloud_type { + CloudType::Aws | CloudType::Gcp | CloudType::Azure => Some(url_and_creds_to_key( + &self.parsed_url, + self.options.as_ref(), + )), + CloudType::File | CloudType::Http | CloudType::Hf => None, + }; + + let opt_cache_write_guard = if let Some(cache_key) = opt_cache_key.as_deref() { + let cache = OBJECT_STORE_CACHE.read().await; + + if let Some(store) = cache.get(cache_key) { + return Ok(store.clone()); + } + + drop(cache); + + let cache = OBJECT_STORE_CACHE.write().await; + + if let Some(store) = cache.get(cache_key) { + return Ok(store.clone()); + } + + Some(cache) + } else { + None + }; + + let store = self.build_impl().await?; + let store = PolarsObjectStore::new_from_inner(store, self); + + if let Some(mut cache) = opt_cache_write_guard { + // Clear the cache if we surpass a certain amount of buckets. + if cache.len() >= 8 { + if config::verbose() { + eprintln!( + "build_object_store: clearing store cache (cache.len(): {})", + cache.len() + ); + } + cache.clear() + } + + cache.insert(opt_cache_key.unwrap(), store.clone()); + } + + Ok(store) + } + + pub(crate) fn is_azure(&self) -> bool { + matches!(&self.cloud_type, CloudType::Azure) + } +} + +/// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. +pub async fn build_object_store( + url: &str, + #[cfg_attr( + not(any(feature = "aws", feature = "gcp", feature = "azure")), + allow(unused_variables) + )] + options: Option<&CloudOptions>, + glob: bool, +) -> PolarsResult<(CloudLocation, PolarsObjectStore)> { + let parsed = parse_url(url).map_err(to_compute_err)?; + let cloud_location = CloudLocation::from_url(&parsed, glob)?; + let cloud_type = CloudType::from_url(&parsed)?; + + let store = PolarsObjectStoreBuilder { + url: url.into(), + parsed_url: parsed, + scheme: cloud_location.scheme.as_str().into(), + cloud_type, + options: options.cloned(), + } + .build() + .await?; + + Ok((cloud_location, store)) +} + +mod test { + #[test] + fn test_object_path_from_str() { + use super::object_path_from_str; + + let path = "%25"; + let out = object_path_from_str(path).unwrap(); + + assert_eq!(out.as_ref(), path); + } +} diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs new file mode 100644 index 000000000000..aa3712d88ff0 --- /dev/null +++ b/crates/polars-io/src/cloud/options.rs @@ -0,0 +1,750 @@ +#[cfg(feature = "aws")] +use std::io::Read; +#[cfg(feature = "aws")] +use std::path::Path; +use std::str::FromStr; +use std::sync::LazyLock; + +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))] +use object_store::ClientOptions; +#[cfg(feature = "aws")] +use object_store::aws::AmazonS3Builder; +#[cfg(feature = "aws")] +pub use object_store::aws::AmazonS3ConfigKey; +#[cfg(feature = "azure")] +pub use object_store::azure::AzureConfigKey; +#[cfg(feature = "azure")] +use object_store::azure::MicrosoftAzureBuilder; +#[cfg(feature = "gcp")] +use object_store::gcp::GoogleCloudStorageBuilder; +#[cfg(feature = "gcp")] +pub use object_store::gcp::GoogleConfigKey; +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +use object_store::{BackoffConfig, RetryConfig}; +use polars_error::*; +#[cfg(feature = "aws")] +use polars_utils::cache::LruCache; +#[cfg(feature = "http")] +use reqwest::header::HeaderMap; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +#[cfg(feature = "cloud")] +use url::Url; + +#[cfg(feature = "cloud")] +use super::credential_provider::PlCredentialProvider; +#[cfg(feature = "file_cache")] +use crate::file_cache::get_env_file_cache_ttl; +#[cfg(feature = "aws")] +use crate::pl_async::with_concurrency_budget; + +#[cfg(feature = "aws")] +static BUCKET_REGION: LazyLock< + std::sync::Mutex>, +> = LazyLock::new(|| std::sync::Mutex::new(LruCache::with_capacity(32))); + +/// The type of the config keys must satisfy the following requirements: +/// 1. must be easily collected into a HashMap, the type required by the object_crate API. +/// 2. be Serializable, required when the serde-lazy feature is defined. +/// 3. not actually use HashMap since that type is disallowed in Polars for performance reasons. +/// +/// Currently this type is a vector of pairs config key - config value. +#[allow(dead_code)] +type Configs = Vec<(T, String)>; + +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub(crate) enum CloudConfig { + #[cfg(feature = "aws")] + Aws(Configs), + #[cfg(feature = "azure")] + Azure(Configs), + #[cfg(feature = "gcp")] + Gcp(Configs), + #[cfg(feature = "http")] + Http { headers: Vec<(String, String)> }, +} + +#[derive(Clone, Debug, PartialEq, Hash, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +/// Options to connect to various cloud providers. +pub struct CloudOptions { + pub max_retries: usize, + #[cfg(feature = "file_cache")] + pub file_cache_ttl: u64, + pub(crate) config: Option, + #[cfg(feature = "cloud")] + /// Note: In most cases you will want to access this via [`CloudOptions::initialized_credential_provider`] + /// rather than directly. + pub(crate) credential_provider: Option, +} + +impl Default for CloudOptions { + fn default() -> Self { + Self::default_static_ref().clone() + } +} + +impl CloudOptions { + pub fn default_static_ref() -> &'static Self { + static DEFAULT: LazyLock = LazyLock::new(|| CloudOptions { + max_retries: 2, + #[cfg(feature = "file_cache")] + file_cache_ttl: get_env_file_cache_ttl(), + config: None, + #[cfg(feature = "cloud")] + credential_provider: None, + }); + + &DEFAULT + } +} + +#[cfg(feature = "http")] +pub(crate) fn try_build_http_header_map_from_items_slice>( + headers: &[(S, S)], +) -> PolarsResult { + use reqwest::header::{HeaderName, HeaderValue}; + + let mut map = HeaderMap::with_capacity(headers.len()); + for (k, v) in headers { + let (k, v) = (k.as_ref(), v.as_ref()); + map.insert( + HeaderName::from_str(k).map_err(to_compute_err)?, + HeaderValue::from_str(v).map_err(to_compute_err)?, + ); + } + + Ok(map) +} + +#[allow(dead_code)] +/// Parse an untype configuration hashmap to a typed configuration for the given configuration key type. +fn parsed_untyped_config, impl Into)>>( + config: I, +) -> PolarsResult> +where + T: FromStr + Eq + std::hash::Hash, +{ + Ok(config + .into_iter() + // Silently ignores custom upstream storage_options + .filter_map(|(key, val)| { + T::from_str(key.as_ref().to_ascii_lowercase().as_str()) + .ok() + .map(|typed_key| (typed_key, val.into())) + }) + .collect::>()) +} + +#[derive(Debug, Clone, PartialEq)] +pub enum CloudType { + Aws, + Azure, + File, + Gcp, + Http, + Hf, +} + +impl CloudType { + #[cfg(feature = "cloud")] + pub(crate) fn from_url(parsed: &Url) -> PolarsResult { + Ok(match parsed.scheme() { + "s3" | "s3a" => Self::Aws, + "az" | "azure" | "adl" | "abfs" | "abfss" => Self::Azure, + "gs" | "gcp" | "gcs" => Self::Gcp, + "file" => Self::File, + "http" | "https" => Self::Http, + "hf" => Self::Hf, + _ => polars_bail!(ComputeError: "unknown url scheme"), + }) + } +} + +#[cfg(feature = "cloud")] +pub(crate) fn parse_url(input: &str) -> std::result::Result { + Ok(if input.contains("://") { + if input.starts_with("http://") || input.starts_with("https://") { + url::Url::parse(input) + } else { + url::Url::parse(&input.replace("%", "%25")) + }? + } else { + let path = std::path::Path::new(input); + let mut tmp; + url::Url::from_file_path(if path.is_relative() { + tmp = std::env::current_dir().unwrap(); + tmp.push(path); + tmp.as_path() + } else { + path + }) + .unwrap() + }) +} + +impl FromStr for CloudType { + type Err = PolarsError; + + #[cfg(feature = "cloud")] + fn from_str(url: &str) -> Result { + let parsed = parse_url(url).map_err(to_compute_err)?; + Self::from_url(&parsed) + } + + #[cfg(not(feature = "cloud"))] + fn from_str(_s: &str) -> Result { + polars_bail!(ComputeError: "at least one of the cloud features must be enabled"); + } +} +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +fn get_retry_config(max_retries: usize) -> RetryConfig { + RetryConfig { + backoff: BackoffConfig::default(), + max_retries, + retry_timeout: std::time::Duration::from_secs(10), + } +} + +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))] +pub(super) fn get_client_options() -> ClientOptions { + ClientOptions::new() + // We set request timeout super high as the timeout isn't reset at ACK, + // but starts from the moment we start downloading a body. + // https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.timeout + .with_timeout_disabled() + // Concurrency can increase connection latency, so set to None, similar to default. + .with_connect_timeout_disabled() + .with_allow_http(true) +} + +#[cfg(feature = "aws")] +fn read_config( + builder: &mut AmazonS3Builder, + items: &[(&Path, &[(&str, AmazonS3ConfigKey)])], +) -> Option<()> { + use crate::path_utils::resolve_homedir; + + for (path, keys) in items { + if keys + .iter() + .all(|(_, key)| builder.get_config_value(key).is_some()) + { + continue; + } + + let mut config = std::fs::File::open(resolve_homedir(path)).ok()?; + let mut buf = vec![]; + config.read_to_end(&mut buf).ok()?; + let content = std::str::from_utf8(buf.as_ref()).ok()?; + + for (pattern, key) in keys.iter() { + if builder.get_config_value(key).is_none() { + let reg = polars_utils::regex_cache::compile_regex(pattern).unwrap(); + let cap = reg.captures(content)?; + let m = cap.get(1)?; + let parsed = m.as_str(); + *builder = std::mem::take(builder).with_config(*key, parsed); + } + } + } + Some(()) +} + +impl CloudOptions { + /// Set the maximum number of retries. + pub fn with_max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + #[cfg(feature = "cloud")] + pub fn with_credential_provider( + mut self, + credential_provider: Option, + ) -> Self { + self.credential_provider = credential_provider; + self + } + + /// Set the configuration for AWS connections. This is the preferred API from rust. + #[cfg(feature = "aws")] + pub fn with_aws)>>( + mut self, + configs: I, + ) -> Self { + self.config = Some(CloudConfig::Aws( + configs.into_iter().map(|(k, v)| (k, v.into())).collect(), + )); + self + } + + /// Build the [`object_store::ObjectStore`] implementation for AWS. + #[cfg(feature = "aws")] + pub async fn build_aws(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + + let mut builder = AmazonS3Builder::from_env() + .with_client_options(get_client_options()) + .with_url(url); + + read_config( + &mut builder, + &[( + Path::new("~/.aws/config"), + &[("region\\s*=\\s*([^\r\n]*)", AmazonS3ConfigKey::Region)], + )], + ); + + read_config( + &mut builder, + &[( + Path::new("~/.aws/credentials"), + &[ + ( + "aws_access_key_id\\s*=\\s*([^\\r\\n]*)", + AmazonS3ConfigKey::AccessKeyId, + ), + ( + "aws_secret_access_key\\s*=\\s*([^\\r\\n]*)", + AmazonS3ConfigKey::SecretAccessKey, + ), + ( + "aws_session_token\\s*=\\s*([^\\r\\n]*)", + AmazonS3ConfigKey::Token, + ), + ], + )], + ); + + if let Some(options) = &self.config { + let CloudConfig::Aws(options) = options else { + panic!("impl error: cloud type mismatch") + }; + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } + } + + if builder + .get_config_value(&AmazonS3ConfigKey::DefaultRegion) + .is_none() + && builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_none() + { + let bucket = crate::cloud::CloudLocation::new(url, false)?.bucket; + let region = { + let mut bucket_region = BUCKET_REGION.lock().unwrap(); + bucket_region.get(bucket.as_str()).cloned() + }; + + match region { + Some(region) => { + builder = builder.with_config(AmazonS3ConfigKey::Region, region.as_str()) + }, + None => { + if builder + .get_config_value(&AmazonS3ConfigKey::Endpoint) + .is_some() + { + // Set a default value if the endpoint is not aws. + // See: #13042 + builder = builder.with_config(AmazonS3ConfigKey::Region, "us-east-1"); + } else { + polars_warn!( + "'(default_)region' not set; polars will try to get it from bucket\n\nSet the region manually to silence this warning." + ); + let result = with_concurrency_budget(1, || async { + reqwest::Client::builder() + .build() + .unwrap() + .head(format!("https://{bucket}.s3.amazonaws.com")) + .send() + .await + .map_err(to_compute_err) + }) + .await?; + if let Some(region) = result.headers().get("x-amz-bucket-region") { + let region = + std::str::from_utf8(region.as_bytes()).map_err(to_compute_err)?; + let mut bucket_region = BUCKET_REGION.lock().unwrap(); + bucket_region.insert(bucket, region.into()); + builder = builder.with_config(AmazonS3ConfigKey::Region, region) + } + } + }, + }; + }; + + let builder = builder.with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.initialized_credential_provider()? { + builder.with_credentials(v.into_aws_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) + } + + /// Set the configuration for Azure connections. This is the preferred API from rust. + #[cfg(feature = "azure")] + pub fn with_azure)>>( + mut self, + configs: I, + ) -> Self { + self.config = Some(CloudConfig::Azure( + configs.into_iter().map(|(k, v)| (k, v.into())).collect(), + )); + self + } + + /// Build the [`object_store::ObjectStore`] implementation for Azure. + #[cfg(feature = "azure")] + pub fn build_azure(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + + let verbose = polars_core::config::verbose(); + + // The credential provider `self.credentials` is prioritized if it is set. We also need + // `from_env()` as it may source environment configured storage account name. + let mut builder = + MicrosoftAzureBuilder::from_env().with_client_options(get_client_options()); + + if let Some(options) = &self.config { + let CloudConfig::Azure(options) = options else { + panic!("impl error: cloud type mismatch") + }; + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } + } + + let builder = builder + .with_url(url) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.initialized_credential_provider()? { + if verbose { + eprintln!( + "[CloudOptions::build_azure]: Using credential provider {:?}", + &v + ); + } + builder.with_credentials(v.into_azure_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) + } + + /// Set the configuration for GCP connections. This is the preferred API from rust. + #[cfg(feature = "gcp")] + pub fn with_gcp)>>( + mut self, + configs: I, + ) -> Self { + self.config = Some(CloudConfig::Gcp( + configs.into_iter().map(|(k, v)| (k, v.into())).collect(), + )); + self + } + + /// Build the [`object_store::ObjectStore`] implementation for GCP. + #[cfg(feature = "gcp")] + pub fn build_gcp(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + + let credential_provider = self.initialized_credential_provider()?; + + let builder = if credential_provider.is_none() { + GoogleCloudStorageBuilder::from_env() + } else { + GoogleCloudStorageBuilder::new() + }; + + let mut builder = builder.with_client_options(get_client_options()); + + if let Some(options) = &self.config { + let CloudConfig::Gcp(options) = options else { + panic!("impl error: cloud type mismatch") + }; + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } + } + + let builder = builder + .with_url(url) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = credential_provider.clone() { + builder.with_credentials(v.into_gcp_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) + } + + #[cfg(feature = "http")] + pub fn build_http(&self, url: &str) -> PolarsResult { + object_store::http::HttpBuilder::new() + .with_url(url) + .with_client_options({ + let mut opts = super::get_client_options(); + if let Some(CloudConfig::Http { headers }) = &self.config { + opts = opts.with_default_headers(try_build_http_header_map_from_items_slice( + headers.as_slice(), + )?); + } + opts + }) + .build() + .map_err(to_compute_err) + } + + /// Parse a configuration from a Hashmap. This is the interface from Python. + #[allow(unused_variables)] + pub fn from_untyped_config, impl Into)>>( + url: &str, + config: I, + ) -> PolarsResult { + match CloudType::from_str(url)? { + CloudType::Aws => { + #[cfg(feature = "aws")] + { + parsed_untyped_config::(config) + .map(|aws| Self::default().with_aws(aws)) + } + #[cfg(not(feature = "aws"))] + { + polars_bail!(ComputeError: "'aws' feature is not enabled"); + } + }, + CloudType::Azure => { + #[cfg(feature = "azure")] + { + parsed_untyped_config::(config) + .map(|azure| Self::default().with_azure(azure)) + } + #[cfg(not(feature = "azure"))] + { + polars_bail!(ComputeError: "'azure' feature is not enabled"); + } + }, + CloudType::File => Ok(Self::default()), + CloudType::Http => Ok(Self::default()), + CloudType::Gcp => { + #[cfg(feature = "gcp")] + { + parsed_untyped_config::(config) + .map(|gcp| Self::default().with_gcp(gcp)) + } + #[cfg(not(feature = "gcp"))] + { + polars_bail!(ComputeError: "'gcp' feature is not enabled"); + } + }, + CloudType::Hf => { + #[cfg(feature = "http")] + { + use polars_core::config; + + use crate::path_utils::resolve_homedir; + + let mut this = Self::default(); + let mut token = None; + let verbose = config::verbose(); + + for (i, (k, v)) in config.into_iter().enumerate() { + let (k, v) = (k.as_ref(), v.into()); + + if i == 0 && k == "token" { + if verbose { + eprintln!("HF token sourced from storage_options"); + } + token = Some(v); + } else { + polars_bail!(ComputeError: "unknown configuration key for HF: {}", k) + } + } + + token = token + .or_else(|| { + let v = std::env::var("HF_TOKEN").ok(); + if v.is_some() && verbose { + eprintln!("HF token sourced from HF_TOKEN env var"); + } + v + }) + .or_else(|| { + let hf_home = std::env::var("HF_HOME"); + let hf_home = hf_home.as_deref(); + let hf_home = hf_home.unwrap_or("~/.cache/huggingface"); + let hf_home = resolve_homedir(&hf_home); + let cached_token_path = hf_home.join("token"); + + let v = std::string::String::from_utf8( + std::fs::read(&cached_token_path).ok()?, + ) + .ok() + .filter(|x| !x.is_empty()); + + if v.is_some() && verbose { + eprintln!( + "HF token sourced from {}", + cached_token_path.to_str().unwrap() + ); + } + + v + }); + + if let Some(v) = token { + this.config = Some(CloudConfig::Http { + headers: vec![("Authorization".into(), format!("Bearer {}", v))], + }) + } + + Ok(this) + } + #[cfg(not(feature = "http"))] + { + polars_bail!(ComputeError: "'http' feature is not enabled"); + } + }, + } + } + + /// Python passes a credential provider builder that needs to be called to get the actual credential + /// provider. + #[cfg(feature = "cloud")] + fn initialized_credential_provider(&self) -> PolarsResult> { + if let Some(v) = self.credential_provider.clone() { + v.try_into_initialized() + } else { + Ok(None) + } + } +} + +#[cfg(feature = "cloud")] +#[cfg(test)] +mod tests { + use hashbrown::HashMap; + + use super::{parse_url, parsed_untyped_config}; + + #[test] + fn test_parse_url() { + assert_eq!( + parse_url(r"http://Users/Jane Doe/data.csv") + .unwrap() + .as_str(), + "http://users/Jane%20Doe/data.csv" + ); + assert_eq!( + parse_url(r"http://Users/Jane Doe/data.csv") + .unwrap() + .as_str(), + "http://users/Jane%20Doe/data.csv" + ); + #[cfg(target_os = "windows")] + { + assert_eq!( + parse_url(r"file:///c:/Users/Jane Doe/data.csv") + .unwrap() + .as_str(), + "file:///c:/Users/Jane%20Doe/data.csv" + ); + assert_eq!( + parse_url(r"file://\c:\Users\Jane Doe\data.csv") + .unwrap() + .as_str(), + "file:///c:/Users/Jane%20Doe/data.csv" + ); + assert_eq!( + parse_url(r"c:\Users\Jane Doe\data.csv").unwrap().as_str(), + "file:///C:/Users/Jane%20Doe/data.csv" + ); + assert_eq!( + parse_url(r"data.csv").unwrap().as_str(), + url::Url::from_file_path( + [ + std::env::current_dir().unwrap().as_path(), + std::path::Path::new("data.csv") + ] + .into_iter() + .collect::() + ) + .unwrap() + .as_str() + ); + } + #[cfg(not(target_os = "windows"))] + { + assert_eq!( + parse_url(r"file:///home/Jane Doe/data.csv") + .unwrap() + .as_str(), + "file:///home/Jane%20Doe/data.csv" + ); + assert_eq!( + parse_url(r"/home/Jane Doe/data.csv").unwrap().as_str(), + "file:///home/Jane%20Doe/data.csv" + ); + assert_eq!( + parse_url(r"data.csv").unwrap().as_str(), + url::Url::from_file_path( + [ + std::env::current_dir().unwrap().as_path(), + std::path::Path::new("data.csv") + ] + .into_iter() + .collect::() + ) + .unwrap() + .as_str() + ); + } + } + #[cfg(feature = "aws")] + #[test] + fn test_parse_untyped_config() { + use object_store::aws::AmazonS3ConfigKey; + + let aws_config = [ + ("aws_secret_access_key", "a_key"), + ("aws_s3_allow_unsafe_rename", "true"), + ] + .into_iter() + .collect::>(); + let aws_keys = parsed_untyped_config::(aws_config) + .expect("Parsing keys shouldn't have thrown an error"); + + assert_eq!( + aws_keys.first().unwrap().0, + AmazonS3ConfigKey::SecretAccessKey + ); + assert_eq!(aws_keys.len(), 1); + + let aws_config = [ + ("AWS_SECRET_ACCESS_KEY", "a_key"), + ("aws_s3_allow_unsafe_rename", "true"), + ] + .into_iter() + .collect::>(); + let aws_keys = parsed_untyped_config::(aws_config) + .expect("Parsing keys shouldn't have thrown an error"); + + assert_eq!( + aws_keys.first().unwrap().0, + AmazonS3ConfigKey::SecretAccessKey + ); + assert_eq!(aws_keys.len(), 1); + } +} diff --git a/crates/polars-io/src/cloud/polars_object_store.rs b/crates/polars-io/src/cloud/polars_object_store.rs new file mode 100644 index 000000000000..43082b1c5f73 --- /dev/null +++ b/crates/polars-io/src/cloud/polars_object_store.rs @@ -0,0 +1,627 @@ +use std::ops::Range; + +use bytes::Bytes; +use futures::{StreamExt, TryStreamExt}; +use object_store::path::Path; +use object_store::{ObjectMeta, ObjectStore}; +use polars_core::prelude::{InitHashMaps, PlHashMap}; +use polars_error::{PolarsError, PolarsResult, to_compute_err}; +use tokio::io::{AsyncSeekExt, AsyncWriteExt}; + +use crate::pl_async::{ + self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size, + tune_with_concurrency_budget, with_concurrency_budget, +}; + +mod inner { + use std::future::Future; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + + use object_store::ObjectStore; + use polars_core::config; + use polars_error::PolarsResult; + + use crate::cloud::PolarsObjectStoreBuilder; + + #[derive(Debug)] + struct Inner { + store: tokio::sync::Mutex>, + builder: PolarsObjectStoreBuilder, + } + + /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable. + #[derive(Debug)] + pub struct PolarsObjectStore { + inner: Arc, + /// Avoid contending the Mutex `lock()` until the first re-build. + initial_store: std::sync::Arc, + /// Used for interior mutability. Doesn't need to be shared with other threads so it's not + /// inside `Arc<>`. + rebuilt: AtomicBool, + } + + impl Clone for PolarsObjectStore { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + initial_store: self.initial_store.clone(), + rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)), + } + } + } + + impl PolarsObjectStore { + pub(crate) fn new_from_inner( + store: Arc, + builder: PolarsObjectStoreBuilder, + ) -> Self { + let initial_store = store.clone(); + Self { + inner: Arc::new(Inner { + store: tokio::sync::Mutex::new(store), + builder, + }), + initial_store, + rebuilt: AtomicBool::new(false), + } + } + + /// Gets the underlying [`ObjectStore`] implementation. + pub async fn to_dyn_object_store(&self) -> Arc { + if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) { + self.initial_store.clone() + } else { + self.inner.store.lock().await.clone() + } + } + + pub async fn rebuild_inner( + &self, + from_version: &Arc, + ) -> PolarsResult> { + let mut current_store = self.inner.store.lock().await; + + self.rebuilt + .store(true, std::sync::atomic::Ordering::Relaxed); + + // If this does not eq, then `inner` was already re-built by another thread. + if Arc::ptr_eq(&*current_store, from_version) { + *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| { + e.wrap_msg(|e| format!("attempt to rebuild object store failed: {}", e)) + })?; + } + + Ok((*current_store).clone()) + } + + pub async fn try_exec_rebuild_on_err(&self, mut func: Fn) -> PolarsResult + where + Fn: FnMut(&Arc) -> Fut, + Fut: Future>, + { + let store = self.to_dyn_object_store().await; + + let out = func(&store).await; + + let orig_err = match out { + Ok(v) => return Ok(v), + Err(e) => e, + }; + + if config::verbose() { + eprintln!( + "[PolarsObjectStore]: got error: {}, will attempt re-build", + &orig_err + ); + } + + let store = self + .rebuild_inner(&store) + .await + .map_err(|e| e.wrap_msg(|e| format!("{}; original error: {}", e, orig_err)))?; + + func(&store).await.map_err(|e| { + if self.inner.builder.is_azure() + && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref() + != Ok("1") + { + // Note: This error is intended for Python audiences. The logic for retrieving + // these keys exist only on the Python side. + e.wrap_msg(|e| { + format!( + "{}; note: if you are using Python, consider setting \ +POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \ +and use the storage account keys from Azure CLI to authenticate", + e + ) + }) + } else { + e + } + }) + } + } +} + +pub use inner::PolarsObjectStore; + +pub type ObjectStorePath = object_store::path::Path; + +impl PolarsObjectStore { + /// Returns a buffered stream that downloads concurrently up to the concurrency limit. + fn get_buffered_ranges_stream<'a, T: Iterator>>( + store: &'a dyn ObjectStore, + path: &'a Path, + ranges: T, + ) -> impl StreamExt> + + TryStreamExt> + + use<'a, T> { + futures::stream::iter( + ranges + .map(|range| async { store.get_range(path, range).await.map_err(to_compute_err) }), + ) + // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`. + .buffered(get_concurrency_limit() as usize) + } + + pub async fn get_range(&self, path: &Path, range: Range) -> PolarsResult { + self.try_exec_rebuild_on_err(move |store| { + let range = range.clone(); + let st = store.clone(); + + async { + let store = st; + let parts = split_range(range.clone()); + + if parts.len() == 1 { + tune_with_concurrency_budget(1, || async { store.get_range(path, range).await }) + .await + .map_err(to_compute_err) + } else { + let parts = tune_with_concurrency_budget( + parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32, + || { + Self::get_buffered_ranges_stream(&store, path, parts) + .try_collect::>() + }, + ) + .await?; + + let mut combined = Vec::with_capacity(range.len()); + + for part in parts { + combined.extend_from_slice(&part) + } + + assert_eq!(combined.len(), range.len()); + + PolarsResult::Ok(Bytes::from(combined)) + } + } + }) + .await + } + + /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the + /// `ranges` slice for coalescing. + /// + /// # Panics + /// Panics if the same range start is used by more than 1 range. + pub async fn get_ranges_sort< + K: TryFrom + std::hash::Hash + Eq, + T: From, + >( + &self, + path: &Path, + ranges: &mut [Range], + ) -> PolarsResult> { + if ranges.is_empty() { + return Ok(Default::default()); + } + + ranges.sort_unstable_by_key(|x| x.start); + + let ranges_len = ranges.len(); + let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip(); + + self.try_exec_rebuild_on_err(|store| { + let st = store.clone(); + + async { + let store = st; + let mut out = PlHashMap::with_capacity(ranges_len); + + let mut stream = + Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned()); + + tune_with_concurrency_budget( + merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32, + || async { + let mut len = 0; + let mut current_offset = 0; + let mut ends_iter = merged_ends.iter(); + + let mut splitted_parts = vec![]; + + while let Some(bytes) = stream.try_next().await? { + len += bytes.len(); + let end = *ends_iter.next().unwrap(); + + if end == 0 { + splitted_parts.push(bytes); + continue; + } + + let full_range = ranges[current_offset..end] + .iter() + .cloned() + .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end)) + .unwrap(); + + let bytes = if splitted_parts.is_empty() { + bytes + } else { + let mut out = Vec::with_capacity(full_range.len()); + + for x in splitted_parts.drain(..) { + out.extend_from_slice(&x); + } + + out.extend_from_slice(&bytes); + Bytes::from(out) + }; + + assert_eq!(bytes.len(), full_range.len()); + + for range in &ranges[current_offset..end] { + let v = out.insert( + K::try_from(range.start).unwrap(), + T::from(bytes.slice( + range.start - full_range.start + ..range.end - full_range.start, + )), + ); + + assert!(v.is_none()); // duplicate range start + } + + current_offset = end; + } + + assert!(splitted_parts.is_empty()); + + PolarsResult::Ok(pl_async::Size::from(len as u64)) + }, + ) + .await?; + + Ok(out) + } + }) + .await + } + + pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> { + let opt_size = self.head(path).await.ok().map(|x| x.size); + + let initial_pos = file.stream_position().await?; + + self.try_exec_rebuild_on_err(|store| { + let st = store.clone(); + + // Workaround for "can't move captured variable". + let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) }; + + async { + file.set_len(initial_pos).await?; // Reset if this function was called again. + + let store = st; + let parts = opt_size.map(|x| split_range(0..x)).filter(|x| x.len() > 1); + + if let Some(parts) = parts { + tune_with_concurrency_budget( + parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32, + || async { + let mut stream = Self::get_buffered_ranges_stream(&store, path, parts); + let mut len = 0; + while let Some(bytes) = stream.try_next().await? { + len += bytes.len(); + file.write_all(&bytes).await.map_err(to_compute_err)?; + } + + assert_eq!(len, opt_size.unwrap()); + + PolarsResult::Ok(pl_async::Size::from(len as u64)) + }, + ) + .await? + } else { + tune_with_concurrency_budget(1, || async { + let mut stream = + store.get(path).await.map_err(to_compute_err)?.into_stream(); + + let mut len = 0; + while let Some(bytes) = stream.try_next().await? { + len += bytes.len(); + file.write_all(&bytes).await.map_err(to_compute_err)?; + } + + PolarsResult::Ok(pl_async::Size::from(len as u64)) + }) + .await? + }; + + // Dropping is delayed for tokio async files so we need to explicitly + // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451). + file.sync_all().await.map_err(PolarsError::from)?; + + Ok(()) + } + }) + .await + } + + /// Fetch the metadata of the parquet file, do not memoize it. + pub async fn head(&self, path: &Path) -> PolarsResult { + self.try_exec_rebuild_on_err(|store| { + let st = store.clone(); + + async { + with_concurrency_budget(1, || async { + let store = st; + let head_result = store.head(path).await; + + if head_result.is_err() { + // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header + // information with a range 0-0 request. + let get_range_0_0_result = store + .get_opts( + path, + object_store::GetOptions { + range: Some((0..1).into()), + ..Default::default() + }, + ) + .await; + + if let Ok(v) = get_range_0_0_result { + return Ok(v.meta); + } + } + + head_result + }) + .await + .map_err(to_compute_err) + } + }) + .await + } +} + +/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for +/// much higher throughput. +fn split_range(range: Range) -> impl ExactSizeIterator> { + let chunk_size = get_download_chunk_size(); + + // Calculate n_parts such that we are as close as possible to the `chunk_size`. + let n_parts = [ + (range.len().div_ceil(chunk_size)).max(1), + (range.len() / chunk_size).max(1), + ] + .into_iter() + .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size)) + .unwrap(); + + let chunk_size = (range.len() / n_parts).max(1); + + assert_eq!(n_parts, (range.len() / chunk_size).max(1)); + let bytes_rem = range.len() % chunk_size; + + (0..n_parts).map(move |part_no| { + let (start, end) = if part_no == 0 { + // Download remainder length in the first chunk since it starts downloading first. + let end = range.start + chunk_size + bytes_rem; + let end = if end > range.end { range.end } else { end }; + (range.start, end) + } else { + let start = bytes_rem + range.start + part_no * chunk_size; + (start, start + chunk_size) + }; + + start..end + }) +} + +/// Note: For optimal performance, `ranges` should be sorted. More generally, +/// ranges placed next to each other should also be close in range value. +/// +/// # Returns +/// `[(range1, end1), (range2, end2)]`, where: +/// * `range1` contains bytes for the ranges from `ranges[0..end1]` +/// * `range2` contains bytes for the ranges from `ranges[end1..end2]` +/// * etc.. +/// +/// Note that if an end value is 0, it means the range is a splitted part and should be combined. +fn merge_ranges(ranges: &[Range]) -> impl Iterator, usize)> + '_ { + let chunk_size = get_download_chunk_size(); + + let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone); + // Number of fetched bytes excluding excess. + let mut current_n_bytes = current_merged_range.len(); + + (0..ranges.len()) + .filter_map(move |current_idx| { + let current_idx = 1 + current_idx; + + if current_idx == ranges.len() { + // No more items - flush current state. + Some((current_merged_range.clone(), current_idx)) + } else { + let range = ranges[current_idx].clone(); + + let new_merged = current_merged_range.start.min(range.start) + ..current_merged_range.end.max(range.end); + + // E.g.: + // |--------| + // oo // range1 + // oo // range2 + // ^^^ // distance = 3, is_overlapping = false + // E.g.: + // |--------| + // ooooo // range1 + // ooooo // range2 + // ^^ // distance = 2, is_overlapping = true + let (distance, is_overlapping) = { + let l = current_merged_range.end.min(range.end); + let r = current_merged_range.start.max(range.start); + + (r.abs_diff(l), r < l) + }; + + let should_merge = is_overlapping || { + let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size) + <= current_merged_range.len().abs_diff(chunk_size); + let gap_tolerance = + (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024); + + leq_current_len_dist_to_chunk_size && distance <= gap_tolerance + }; + + if should_merge { + // Merge to existing range + current_merged_range = new_merged; + current_n_bytes += if is_overlapping { + range.len() - distance + } else { + range.len() + }; + None + } else { + let out = (current_merged_range.clone(), current_idx); + current_merged_range = range; + current_n_bytes = current_merged_range.len(); + Some(out) + } + } + }) + .flat_map(|x| { + // Split large individual ranges within the list of ranges. + let (range, end) = x; + let split = split_range(range.clone()); + let len = split.len(); + + split + .enumerate() + .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 })) + }) +} + +#[cfg(test)] +mod tests { + + #[test] + fn test_split_range() { + use super::{get_download_chunk_size, split_range}; + + let chunk_size = get_download_chunk_size(); + + assert_eq!(chunk_size, 64 * 1024 * 1024); + + #[allow(clippy::single_range_in_vec_init)] + { + // Round-trip empty ranges. + assert_eq!(split_range(0..0).collect::>(), [0..0]); + assert_eq!(split_range(3..3).collect::>(), [3..3]); + } + + // Threshold to start splitting to 2 ranges + // + // n - chunk_size == chunk_size - n / 2 + // n + n / 2 == 2 * chunk_size + // 3 * n == 4 * chunk_size + // n = 4 * chunk_size / 3 + let n = 4 * chunk_size / 3; + + #[allow(clippy::single_range_in_vec_init)] + { + assert_eq!(split_range(0..n).collect::>(), [0..89478485]); + } + + assert_eq!( + split_range(0..n + 1).collect::>(), + [0..44739243, 44739243..89478486] + ); + + // Threshold to start splitting to 3 ranges + // + // n / 2 - chunk_size == chunk_size - n / 3 + // n / 2 + n / 3 == 2 * chunk_size + // 5 * n == 12 * chunk_size + // n == 12 * chunk_size / 5 + let n = 12 * chunk_size / 5; + + assert_eq!( + split_range(0..n).collect::>(), + [0..80530637, 80530637..161061273] + ); + + assert_eq!( + split_range(0..n + 1).collect::>(), + [0..53687092, 53687092..107374183, 107374183..161061274] + ); + } + + #[test] + fn test_merge_ranges() { + use super::{get_download_chunk_size, merge_ranges}; + + let chunk_size = get_download_chunk_size(); + + assert_eq!(chunk_size, 64 * 1024 * 1024); + + // Round-trip empty slice + assert_eq!(merge_ranges(&[]).collect::>(), []); + + // We have 1 tiny request followed by 1 huge request. They are combined as it reduces the + // `abs_diff()` to the `chunk_size`, but afterwards they are split to 2 evenly sized + // requests. + assert_eq!( + merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::>(), + [(0..66584576, 0), (66584576..133169152, 2)] + ); + + // <= 1MiB gap, merge + assert_eq!( + merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::>(), + [(0..1048578, 2)] + ); + + // > 1MiB gap, do not merge + assert_eq!( + merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::>(), + [(0..1, 1), (1048578..1048579, 2)] + ); + + // <= 12.5% gap, merge + assert_eq!( + merge_ranges(&[0..8, 10..11]).collect::>(), + [(0..11, 2)] + ); + + // <= 12.5% gap relative to RHS, merge + assert_eq!( + merge_ranges(&[0..1, 3..11]).collect::>(), + [(0..11, 2)] + ); + + // Overlapping range, merge + assert_eq!( + merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024]) + .collect::>(), + [(0..80 * 1024 * 1024, 2)] + ); + } +} diff --git a/crates/polars-io/src/csv/mod.rs b/crates/polars-io/src/csv/mod.rs new file mode 100644 index 000000000000..47a277f861c9 --- /dev/null +++ b/crates/polars-io/src/csv/mod.rs @@ -0,0 +1,4 @@ +//! Functionality for reading and writing CSV files. + +pub mod read; +pub mod write; diff --git a/crates/polars-io/src/csv/read/buffer.rs b/crates/polars-io/src/csv/read/buffer.rs new file mode 100644 index 000000000000..4db5e4e8f395 --- /dev/null +++ b/crates/polars-io/src/csv/read/buffer.rs @@ -0,0 +1,987 @@ +use arrow::array::MutableBinaryViewArray; +use polars_core::prelude::*; +use polars_error::to_compute_err; +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +use polars_time::chunkedarray::string::Pattern; +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +use polars_time::prelude::string::infer::{ + DatetimeInfer, StrpTimeParser, TryFromWithUnit, infer_pattern_single, +}; +use polars_utils::vec::PushUnchecked; + +use super::options::CsvEncoding; +use super::parser::{is_whitespace, skip_whitespace}; +use super::utils::escape_field; + +pub(crate) trait PrimitiveParser: PolarsNumericType { + fn parse(bytes: &[u8]) -> Option; +} + +impl PrimitiveParser for Float32Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + fast_float2::parse(bytes).ok() + } +} +impl PrimitiveParser for Float64Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + fast_float2::parse(bytes).ok() + } +} + +#[cfg(feature = "dtype-u8")] +impl PrimitiveParser for UInt8Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +#[cfg(feature = "dtype-u16")] +impl PrimitiveParser for UInt16Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +impl PrimitiveParser for UInt32Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +impl PrimitiveParser for UInt64Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +#[cfg(feature = "dtype-i8")] +impl PrimitiveParser for Int8Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +#[cfg(feature = "dtype-i16")] +impl PrimitiveParser for Int16Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +impl PrimitiveParser for Int32Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +impl PrimitiveParser for Int64Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} +#[cfg(feature = "dtype-i128")] +impl PrimitiveParser for Int128Type { + #[inline] + fn parse(bytes: &[u8]) -> Option { + atoi_simd::parse_skipped(bytes).ok() + } +} + +trait ParsedBuffer { + fn parse_bytes( + &mut self, + bytes: &[u8], + ignore_errors: bool, + _needs_escaping: bool, + _missing_is_null: bool, + _time_unit: Option, + ) -> PolarsResult<()>; +} + +impl ParsedBuffer for PrimitiveChunkedBuilder +where + T: PolarsNumericType + PrimitiveParser, +{ + #[inline] + fn parse_bytes( + &mut self, + bytes: &[u8], + ignore_errors: bool, + needs_escaping: bool, + _missing_is_null: bool, + _time_unit: Option, + ) -> PolarsResult<()> { + if bytes.is_empty() { + self.append_null() + } else { + let bytes = if needs_escaping { + &bytes[1..bytes.len() - 1] + } else { + bytes + }; + + // legacy comment (remember this if you decide to use Results again): + // its faster to work on options. + // if we need to throw an error, we parse again to be able to throw the error + + match T::parse(bytes) { + Some(value) => self.append_value(value), + None => { + // try again without whitespace + if !bytes.is_empty() && is_whitespace(bytes[0]) { + let bytes = skip_whitespace(bytes); + return self.parse_bytes( + bytes, + ignore_errors, + false, // escaping was already done + _missing_is_null, + None, + ); + } + polars_ensure!( + bytes.is_empty() || ignore_errors, + ComputeError: "remaining bytes non-empty", + ); + self.append_null() + }, + }; + } + Ok(()) + } +} + +pub struct Utf8Field { + name: PlSmallStr, + mutable: MutableBinaryViewArray<[u8]>, + scratch: Vec, + quote_char: u8, + encoding: CsvEncoding, +} + +impl Utf8Field { + fn new( + name: PlSmallStr, + capacity: usize, + quote_char: Option, + encoding: CsvEncoding, + ) -> Self { + Self { + name, + mutable: MutableBinaryViewArray::with_capacity(capacity), + scratch: vec![], + quote_char: quote_char.unwrap_or(b'"'), + encoding, + } + } +} + +#[inline] +pub fn validate_utf8(bytes: &[u8]) -> bool { + simdutf8::basic::from_utf8(bytes).is_ok() +} + +impl ParsedBuffer for Utf8Field { + #[inline] + fn parse_bytes( + &mut self, + bytes: &[u8], + ignore_errors: bool, + needs_escaping: bool, + missing_is_null: bool, + _time_unit: Option, + ) -> PolarsResult<()> { + if bytes.is_empty() { + if missing_is_null { + self.mutable.push_null() + } else { + self.mutable.push(Some([])) + } + return Ok(()); + } + + // note that one branch writes without updating the length, so we must do that later. + let escaped_bytes = if needs_escaping { + self.scratch.clear(); + self.scratch.reserve(bytes.len()); + polars_ensure!(bytes.len() > 1 && bytes.last() == Some(&self.quote_char), ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); + + // SAFETY: + // we just allocated enough capacity and data_len is correct. + unsafe { + let n_written = + escape_field(bytes, self.quote_char, self.scratch.spare_capacity_mut()); + self.scratch.set_len(n_written); + } + + self.scratch.as_slice() + } else { + bytes + }; + + if matches!(self.encoding, CsvEncoding::LossyUtf8) | ignore_errors { + // It is important that this happens after escaping, as invalid escaped string can produce + // invalid utf8. + let parse_result = validate_utf8(escaped_bytes); + + match parse_result { + true => { + let value = escaped_bytes; + self.mutable.push_value(value) + }, + false => { + if matches!(self.encoding, CsvEncoding::LossyUtf8) { + // TODO! do this without allocating + let s = String::from_utf8_lossy(escaped_bytes); + self.mutable.push_value(s.as_ref().as_bytes()) + } else if ignore_errors { + self.mutable.push_null() + } else { + // If field before escaping is valid utf8, the escaping is incorrect. + if needs_escaping && validate_utf8(bytes) { + polars_bail!(ComputeError: "string field is not properly escaped"); + } else { + polars_bail!(ComputeError: "invalid utf-8 sequence"); + } + } + }, + } + } else { + self.mutable.push_value(escaped_bytes) + } + + Ok(()) + } +} + +#[cfg(not(feature = "dtype-categorical"))] +pub struct CategoricalField { + phantom: std::marker::PhantomData, +} + +#[cfg(feature = "dtype-categorical")] +pub struct CategoricalField { + escape_scratch: Vec, + quote_char: u8, + builder: CategoricalChunkedBuilder, + is_enum: bool, +} + +#[cfg(feature = "dtype-categorical")] +impl CategoricalField { + fn new( + name: PlSmallStr, + capacity: usize, + quote_char: Option, + ordering: CategoricalOrdering, + ) -> Self { + let builder = CategoricalChunkedBuilder::new(name, capacity, ordering); + + Self { + escape_scratch: vec![], + quote_char: quote_char.unwrap_or(b'"'), + builder, + is_enum: false, + } + } + + fn new_enum(quote_char: Option, builder: CategoricalChunkedBuilder) -> Self { + Self { + escape_scratch: vec![], + quote_char: quote_char.unwrap_or(b'"'), + builder, + is_enum: true, + } + } + + #[inline] + fn parse_bytes( + &mut self, + bytes: &[u8], + ignore_errors: bool, + needs_escaping: bool, + _missing_is_null: bool, + _time_unit: Option, + ) -> PolarsResult<()> { + if bytes.is_empty() { + self.builder.append_null(); + return Ok(()); + } + if validate_utf8(bytes) { + if needs_escaping { + polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); + self.escape_scratch.clear(); + self.escape_scratch.reserve(bytes.len()); + // SAFETY: + // we just allocated enough capacity and data_len is correct. + unsafe { + let n_written = escape_field( + bytes, + self.quote_char, + self.escape_scratch.spare_capacity_mut(), + ); + self.escape_scratch.set_len(n_written); + } + + // SAFETY: + // just did utf8 check + let key = unsafe { std::str::from_utf8_unchecked(&self.escape_scratch) }; + if self.is_enum { + self.builder.try_append_value(key)?; + } else { + self.builder.append_value(key); + } + } else { + // SAFETY: + // just did utf8 check + let key = unsafe { std::str::from_utf8_unchecked(bytes) }; + if self.is_enum { + self.builder.try_append_value(key)? + } else { + self.builder.append_value(key) + } + } + } else if ignore_errors { + self.builder.append_null() + } else { + polars_bail!(ComputeError: "invalid utf-8 sequence"); + } + Ok(()) + } +} + +impl ParsedBuffer for BooleanChunkedBuilder { + #[inline] + fn parse_bytes( + &mut self, + bytes: &[u8], + ignore_errors: bool, + needs_escaping: bool, + _missing_is_null: bool, + _time_unit: Option, + ) -> PolarsResult<()> { + let bytes = if needs_escaping { + &bytes[1..bytes.len() - 1] + } else { + bytes + }; + if bytes.eq_ignore_ascii_case(b"false") { + self.append_value(false); + } else if bytes.eq_ignore_ascii_case(b"true") { + self.append_value(true); + } else if ignore_errors || bytes.is_empty() { + self.append_null(); + } else { + polars_bail!( + ComputeError: "error while parsing value {} as boolean", + String::from_utf8_lossy(bytes), + ); + } + Ok(()) + } +} + +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +pub struct DatetimeField { + compiled: Option>, + builder: PrimitiveChunkedBuilder, +} + +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +impl DatetimeField { + fn new(name: PlSmallStr, capacity: usize) -> Self { + let builder = PrimitiveChunkedBuilder::::new(name, capacity); + Self { + compiled: None, + builder, + } + } +} + +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +fn slow_datetime_parser( + buf: &mut DatetimeField, + bytes: &[u8], + time_unit: Option, + ignore_errors: bool, +) -> PolarsResult<()> +where + T: PolarsNumericType, + DatetimeInfer: TryFromWithUnit, +{ + let val = if bytes.is_ascii() { + // SAFETY: + // we just checked it is ascii + unsafe { std::str::from_utf8_unchecked(bytes) } + } else { + match std::str::from_utf8(bytes) { + Ok(val) => val, + Err(_) => { + if ignore_errors { + buf.builder.append_null(); + return Ok(()); + } else { + polars_bail!(ComputeError: "invalid utf-8 sequence"); + } + }, + } + }; + + let pattern = match &buf.compiled { + Some(compiled) => compiled.pattern, + None => match infer_pattern_single(val) { + Some(pattern) => pattern, + None => { + if ignore_errors { + buf.builder.append_null(); + return Ok(()); + } else { + polars_bail!(ComputeError: "could not find a 'date/datetime' pattern for '{}'", val) + } + }, + }, + }; + match DatetimeInfer::try_from_with_unit(pattern, time_unit) { + Ok(mut infer) => { + let parsed = infer.parse(val); + let Some(parsed) = parsed else { + if ignore_errors { + buf.builder.append_null(); + return Ok(()); + } else { + polars_bail!(ComputeError: "could not parse '{}' with pattern '{:?}'", val, pattern) + } + }; + + buf.compiled = Some(infer); + buf.builder.append_value(parsed); + Ok(()) + }, + Err(err) => { + if ignore_errors { + buf.builder.append_null(); + Ok(()) + } else { + Err(err) + } + }, + } +} + +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +impl ParsedBuffer for DatetimeField +where + T: PolarsNumericType, + DatetimeInfer: TryFromWithUnit + StrpTimeParser, +{ + #[inline] + fn parse_bytes( + &mut self, + mut bytes: &[u8], + ignore_errors: bool, + needs_escaping: bool, + _missing_is_null: bool, + time_unit: Option, + ) -> PolarsResult<()> { + if needs_escaping && bytes.len() >= 2 { + bytes = &bytes[1..bytes.len() - 1] + } + + if bytes.is_empty() { + // for types other than string `_missing_is_null` is irrelevant; we always append null + self.builder.append_null(); + return Ok(()); + } + + match &mut self.compiled { + None => slow_datetime_parser(self, bytes, time_unit, ignore_errors), + Some(compiled) => { + match compiled.parse_bytes(bytes, time_unit) { + Some(parsed) => { + self.builder.append_value(parsed); + Ok(()) + }, + // fall back on chrono parser + // this is a lot slower, we need to do utf8 checking and use + // the slower parser + None => slow_datetime_parser(self, bytes, time_unit, ignore_errors), + } + }, + } + } +} + +pub fn init_buffers( + projection: &[usize], + capacity: usize, + schema: &Schema, + quote_char: Option, + encoding: CsvEncoding, + decimal_comma: bool, +) -> PolarsResult> { + projection + .iter() + .map(|&i| { + let (name, dtype) = schema.get_at_index(i).unwrap(); + let name = name.clone(); + let builder = match dtype { + &DataType::Boolean => Buffer::Boolean(BooleanChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-i8")] + &DataType::Int8 => Buffer::Int8(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-i16")] + &DataType::Int16 => Buffer::Int16(PrimitiveChunkedBuilder::new(name, capacity)), + &DataType::Int32 => Buffer::Int32(PrimitiveChunkedBuilder::new(name, capacity)), + &DataType::Int64 => Buffer::Int64(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-i128")] + &DataType::Int128 => Buffer::Int128(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-u8")] + &DataType::UInt8 => Buffer::UInt8(PrimitiveChunkedBuilder::new(name, capacity)), + #[cfg(feature = "dtype-u16")] + &DataType::UInt16 => Buffer::UInt16(PrimitiveChunkedBuilder::new(name, capacity)), + &DataType::UInt32 => Buffer::UInt32(PrimitiveChunkedBuilder::new(name, capacity)), + &DataType::UInt64 => Buffer::UInt64(PrimitiveChunkedBuilder::new(name, capacity)), + &DataType::Float32 => { + if decimal_comma { + Buffer::DecimalFloat32( + PrimitiveChunkedBuilder::new(name, capacity), + Default::default(), + ) + } else { + Buffer::Float32(PrimitiveChunkedBuilder::new(name, capacity)) + } + }, + &DataType::Float64 => { + if decimal_comma { + Buffer::DecimalFloat64( + PrimitiveChunkedBuilder::new(name, capacity), + Default::default(), + ) + } else { + Buffer::Float64(PrimitiveChunkedBuilder::new(name, capacity)) + } + }, + &DataType::String => { + Buffer::Utf8(Utf8Field::new(name, capacity, quote_char, encoding)) + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(time_unit, time_zone) => Buffer::Datetime { + buf: DatetimeField::new(name, capacity), + time_unit: *time_unit, + time_zone: time_zone.clone(), + }, + #[cfg(feature = "dtype-date")] + &DataType::Date => Buffer::Date(DatetimeField::new(name, capacity)), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, ordering) => Buffer::Categorical(CategoricalField::new( + name, capacity, quote_char, *ordering, + )), + #[cfg(feature = "dtype-categorical")] + DataType::Enum(rev_map, _) => { + let Some(rev_map) = rev_map else { + polars_bail!(ComputeError: "enum categories must be set") + }; + let cats = rev_map.get_categories(); + let mut builder = + CategoricalChunkedBuilder::new(name, capacity, Default::default()); + for cat in cats.values_iter() { + builder.register_value(cat); + } + Buffer::Categorical(CategoricalField::new_enum(quote_char, builder)) + }, + dt => polars_bail!( + ComputeError: "unsupported data type when reading CSV: {} when reading CSV", dt, + ), + }; + Ok(builder) + }) + .collect() +} + +#[allow(clippy::large_enum_variant)] +pub enum Buffer { + Boolean(BooleanChunkedBuilder), + #[cfg(feature = "dtype-i8")] + Int8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-i16")] + Int16(PrimitiveChunkedBuilder), + Int32(PrimitiveChunkedBuilder), + Int64(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-i128")] + Int128(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u8")] + UInt8(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-u16")] + UInt16(PrimitiveChunkedBuilder), + UInt32(PrimitiveChunkedBuilder), + UInt64(PrimitiveChunkedBuilder), + Float32(PrimitiveChunkedBuilder), + Float64(PrimitiveChunkedBuilder), + /// Stores the Utf8 fields and the total string length seen for that column + Utf8(Utf8Field), + #[cfg(feature = "dtype-datetime")] + Datetime { + buf: DatetimeField, + time_unit: TimeUnit, + time_zone: Option, + }, + #[cfg(feature = "dtype-date")] + Date(DatetimeField), + #[allow(dead_code)] + Categorical(CategoricalField), + DecimalFloat32(PrimitiveChunkedBuilder, Vec), + DecimalFloat64(PrimitiveChunkedBuilder, Vec), +} + +impl Buffer { + pub fn into_series(self) -> PolarsResult { + let s = match self { + Buffer::Boolean(v) => v.finish().into_series(), + #[cfg(feature = "dtype-i8")] + Buffer::Int8(v) => v.finish().into_series(), + #[cfg(feature = "dtype-i16")] + Buffer::Int16(v) => v.finish().into_series(), + Buffer::Int32(v) => v.finish().into_series(), + Buffer::Int64(v) => v.finish().into_series(), + #[cfg(feature = "dtype-i128")] + Buffer::Int128(v) => v.finish().into_series(), + #[cfg(feature = "dtype-u8")] + Buffer::UInt8(v) => v.finish().into_series(), + #[cfg(feature = "dtype-u16")] + Buffer::UInt16(v) => v.finish().into_series(), + Buffer::UInt32(v) => v.finish().into_series(), + Buffer::UInt64(v) => v.finish().into_series(), + Buffer::Float32(v) => v.finish().into_series(), + Buffer::Float64(v) => v.finish().into_series(), + Buffer::DecimalFloat32(v, _) => v.finish().into_series(), + Buffer::DecimalFloat64(v, _) => v.finish().into_series(), + #[cfg(feature = "dtype-datetime")] + Buffer::Datetime { + buf, + time_unit, + time_zone, + } => buf + .builder + .finish() + .into_series() + .cast(&DataType::Datetime(time_unit, time_zone)) + .unwrap(), + #[cfg(feature = "dtype-date")] + Buffer::Date(v) => v + .builder + .finish() + .into_series() + .cast(&DataType::Date) + .unwrap(), + + Buffer::Utf8(v) => { + let arr = v.mutable.freeze(); + StringChunked::with_chunk(v.name.clone(), unsafe { arr.to_utf8view_unchecked() }) + .into_series() + }, + #[allow(unused_variables)] + Buffer::Categorical(buf) => { + #[cfg(feature = "dtype-categorical")] + { + let ca = buf.builder.finish(); + + if buf.is_enum { + let DataType::Categorical(Some(rev_map), _) = ca.dtype() else { + unreachable!() + }; + let idx = ca.physical().clone(); + let dtype = DataType::Enum(Some(rev_map.clone()), Default::default()); + + unsafe { + CategoricalChunked::from_cats_and_dtype_unchecked(idx, dtype) + .into_series() + } + } else { + ca.into_series() + } + } + #[cfg(not(feature = "dtype-categorical"))] + { + panic!("activate 'dtype-categorical' feature") + } + }, + }; + Ok(s) + } + + pub fn add_null(&mut self, valid: bool) { + match self { + Buffer::Boolean(v) => v.append_null(), + #[cfg(feature = "dtype-i8")] + Buffer::Int8(v) => v.append_null(), + #[cfg(feature = "dtype-i16")] + Buffer::Int16(v) => v.append_null(), + Buffer::Int32(v) => v.append_null(), + Buffer::Int64(v) => v.append_null(), + #[cfg(feature = "dtype-i128")] + Buffer::Int128(v) => v.append_null(), + #[cfg(feature = "dtype-u8")] + Buffer::UInt8(v) => v.append_null(), + #[cfg(feature = "dtype-u16")] + Buffer::UInt16(v) => v.append_null(), + Buffer::UInt32(v) => v.append_null(), + Buffer::UInt64(v) => v.append_null(), + Buffer::Float32(v) => v.append_null(), + Buffer::Float64(v) => v.append_null(), + Buffer::DecimalFloat32(v, _) => v.append_null(), + Buffer::DecimalFloat64(v, _) => v.append_null(), + Buffer::Utf8(v) => { + if valid { + v.mutable.push_value("") + } else { + v.mutable.push_null() + } + }, + #[cfg(feature = "dtype-datetime")] + Buffer::Datetime { buf, .. } => buf.builder.append_null(), + #[cfg(feature = "dtype-date")] + Buffer::Date(v) => v.builder.append_null(), + #[allow(unused_variables)] + Buffer::Categorical(cat_builder) => { + #[cfg(feature = "dtype-categorical")] + { + cat_builder.builder.append_null() + } + #[cfg(not(feature = "dtype-categorical"))] + { + panic!("activate 'dtype-categorical' feature") + } + }, + }; + } + + pub fn dtype(&self) -> DataType { + match self { + Buffer::Boolean(_) => DataType::Boolean, + #[cfg(feature = "dtype-i8")] + Buffer::Int8(_) => DataType::Int8, + #[cfg(feature = "dtype-i16")] + Buffer::Int16(_) => DataType::Int16, + Buffer::Int32(_) => DataType::Int32, + Buffer::Int64(_) => DataType::Int64, + #[cfg(feature = "dtype-i128")] + Buffer::Int128(_) => DataType::Int128, + #[cfg(feature = "dtype-u8")] + Buffer::UInt8(_) => DataType::UInt8, + #[cfg(feature = "dtype-u16")] + Buffer::UInt16(_) => DataType::UInt16, + Buffer::UInt32(_) => DataType::UInt32, + Buffer::UInt64(_) => DataType::UInt64, + Buffer::Float32(_) | Buffer::DecimalFloat32(_, _) => DataType::Float32, + Buffer::Float64(_) | Buffer::DecimalFloat64(_, _) => DataType::Float64, + Buffer::Utf8(_) => DataType::String, + #[cfg(feature = "dtype-datetime")] + Buffer::Datetime { time_unit, .. } => DataType::Datetime(*time_unit, None), + #[cfg(feature = "dtype-date")] + Buffer::Date(_) => DataType::Date, + Buffer::Categorical(_) => { + #[cfg(feature = "dtype-categorical")] + { + DataType::Categorical(None, Default::default()) + } + + #[cfg(not(feature = "dtype-categorical"))] + { + panic!("activate 'dtype-categorical' feature") + } + }, + } + } + + #[inline] + pub fn add( + &mut self, + bytes: &[u8], + ignore_errors: bool, + needs_escaping: bool, + missing_is_null: bool, + ) -> PolarsResult<()> { + use Buffer::*; + match self { + Boolean(buf) => ::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-i8")] + Int8(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-i16")] + Int16(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + Int32(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + Int64(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-i128")] + Int128(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-u8")] + UInt8(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-u16")] + UInt16(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + UInt32(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + UInt64(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + Float32(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + Float64(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + DecimalFloat32(buf, scratch) => { + prepare_decimal_comma(bytes, scratch); + as ParsedBuffer>::parse_bytes( + buf, + scratch, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ) + }, + DecimalFloat64(buf, scratch) => { + prepare_decimal_comma(bytes, scratch); + as ParsedBuffer>::parse_bytes( + buf, + scratch, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ) + }, + Utf8(buf) => ::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[cfg(feature = "dtype-datetime")] + Datetime { buf, time_unit, .. } => { + as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + Some(*time_unit), + ) + }, + #[cfg(feature = "dtype-date")] + Date(buf) => as ParsedBuffer>::parse_bytes( + buf, + bytes, + ignore_errors, + needs_escaping, + missing_is_null, + None, + ), + #[allow(unused_variables)] + Categorical(buf) => { + #[cfg(feature = "dtype-categorical")] + { + buf.parse_bytes(bytes, ignore_errors, needs_escaping, missing_is_null, None) + } + + #[cfg(not(feature = "dtype-categorical"))] + { + panic!("activate 'dtype-categorical' feature") + } + }, + } + } +} + +#[inline] +fn prepare_decimal_comma(bytes: &[u8], scratch: &mut Vec) { + scratch.clear(); + scratch.reserve(bytes.len()); + + // SAFETY: we pre-allocated. + for &byte in bytes { + if byte == b',' { + unsafe { scratch.push_unchecked(b'.') } + } else { + unsafe { scratch.push_unchecked(byte) } + } + } +} diff --git a/crates/polars-io/src/csv/read/mod.rs b/crates/polars-io/src/csv/read/mod.rs new file mode 100644 index 000000000000..a707abafb0e6 --- /dev/null +++ b/crates/polars-io/src/csv/read/mod.rs @@ -0,0 +1,40 @@ +//! Functionality for reading CSV files. +//! +//! # Examples +//! +//! ``` +//! use polars_core::prelude::*; +//! use polars_io::prelude::*; +//! use std::fs::File; +//! +//! fn example() -> PolarsResult { +//! // Prefer `from_path` over `new` as it is faster. +//! CsvReadOptions::default() +//! .with_has_header(true) +//! .try_into_reader_with_file_path(Some("example.csv".into()))? +//! .finish() +//! } +//! ``` + +pub mod buffer; +mod options; +mod parser; +mod read_impl; +mod reader; +pub mod schema_inference; +mod splitfields; +mod utils; + +pub use options::{CommentPrefix, CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues}; +pub use parser::{count_rows, count_rows_from_slice, count_rows_from_slice_par}; +pub use read_impl::batched::{BatchedCsvReader, OwnedBatchedCsvReader}; +pub use reader::CsvReader; +pub use schema_inference::infer_file_schema; + +pub mod _csv_read_internal { + pub use super::buffer::validate_utf8; + pub use super::options::NullValuesCompiled; + pub use super::parser::CountLines; + pub use super::read_impl::{cast_columns, find_starting_point, read_chunk}; + pub use super::reader::prepare_csv_schema; +} diff --git a/crates/polars-io/src/csv/read/options.rs b/crates/polars-io/src/csv/read/options.rs new file mode 100644 index 000000000000..d5a849ea2679 --- /dev/null +++ b/crates/polars-io/src/csv/read/options.rs @@ -0,0 +1,432 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::path::PathBuf; +use std::sync::Arc; + +use polars_core::datatypes::{DataType, Field}; +use polars_core::schema::{Schema, SchemaRef}; +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::RowIndex; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CsvReadOptions { + pub path: Option, + // Performance related options + pub rechunk: bool, + pub n_threads: Option, + pub low_memory: bool, + // Row-wise options + pub n_rows: Option, + pub row_index: Option, + // Column-wise options + pub columns: Option>, + pub projection: Option>>, + pub schema: Option, + pub schema_overwrite: Option, + pub dtype_overwrite: Option>>, + // CSV-specific options + pub parse_options: Arc, + pub has_header: bool, + pub chunk_size: usize, + /// Skip rows according to the CSV spec. + pub skip_rows: usize, + /// Skip lines according to newline char (e.g. escaping will be ignored) + pub skip_lines: usize, + pub skip_rows_after_header: usize, + pub infer_schema_length: Option, + pub raise_if_empty: bool, + pub ignore_errors: bool, + pub fields_to_cast: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CsvParseOptions { + pub separator: u8, + pub quote_char: Option, + pub eol_char: u8, + pub encoding: CsvEncoding, + pub null_values: Option, + pub missing_is_null: bool, + pub truncate_ragged_lines: bool, + pub comment_prefix: Option, + pub try_parse_dates: bool, + pub decimal_comma: bool, +} + +impl Default for CsvReadOptions { + fn default() -> Self { + Self { + path: None, + + rechunk: false, + n_threads: None, + low_memory: false, + + n_rows: None, + row_index: None, + + columns: None, + projection: None, + schema: None, + schema_overwrite: None, + dtype_overwrite: None, + + parse_options: Default::default(), + has_header: true, + chunk_size: 1 << 18, + skip_rows: 0, + skip_lines: 0, + skip_rows_after_header: 0, + infer_schema_length: Some(100), + raise_if_empty: true, + ignore_errors: false, + fields_to_cast: vec![], + } + } +} + +/// Options related to parsing the CSV format. +impl Default for CsvParseOptions { + fn default() -> Self { + Self { + separator: b',', + quote_char: Some(b'"'), + eol_char: b'\n', + encoding: Default::default(), + null_values: None, + missing_is_null: true, + truncate_ragged_lines: false, + comment_prefix: None, + try_parse_dates: false, + decimal_comma: false, + } + } +} + +impl CsvReadOptions { + pub fn get_parse_options(&self) -> Arc { + self.parse_options.clone() + } + + pub fn with_path>(mut self, path: Option

) -> Self { + self.path = path.map(|p| p.into()); + self + } + + /// Whether to makes the columns contiguous in memory. + pub fn with_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + /// Number of threads to use for reading. Defaults to the size of the polars + /// thread pool. + pub fn with_n_threads(mut self, n_threads: Option) -> Self { + self.n_threads = n_threads; + self + } + + /// Reduce memory consumption at the expense of performance + pub fn with_low_memory(mut self, low_memory: bool) -> Self { + self.low_memory = low_memory; + self + } + + /// Limits the number of rows to read. + pub fn with_n_rows(mut self, n_rows: Option) -> Self { + self.n_rows = n_rows; + self + } + + /// Adds a row index column. + pub fn with_row_index(mut self, row_index: Option) -> Self { + self.row_index = row_index; + self + } + + /// Which columns to select. + pub fn with_columns(mut self, columns: Option>) -> Self { + self.columns = columns; + self + } + + /// Which columns to select denoted by their index. The index starts from 0 + /// (i.e. [0, 4] would select the 1st and 5th column). + pub fn with_projection(mut self, projection: Option>>) -> Self { + self.projection = projection; + self + } + + /// Set the schema to use for CSV file. The length of the schema must match + /// the number of columns in the file. If this is [None], the schema is + /// inferred from the file. + pub fn with_schema(mut self, schema: Option) -> Self { + self.schema = schema; + self + } + + /// Overwrites the data types in the schema by column name. + pub fn with_schema_overwrite(mut self, schema_overwrite: Option) -> Self { + self.schema_overwrite = schema_overwrite; + self + } + + /// Overwrite the dtypes in the schema in the order of the slice that's given. + /// This is useful if you don't know the column names beforehand + pub fn with_dtype_overwrite(mut self, dtype_overwrite: Option>>) -> Self { + self.dtype_overwrite = dtype_overwrite; + self + } + + /// Sets the CSV parsing options. See [map_parse_options][Self::map_parse_options] + /// for an easier way to mutate them in-place. + pub fn with_parse_options(mut self, parse_options: CsvParseOptions) -> Self { + self.parse_options = Arc::new(parse_options); + self + } + + /// Sets whether the CSV file has a header row. + pub fn with_has_header(mut self, has_header: bool) -> Self { + self.has_header = has_header; + self + } + + /// Sets the chunk size used by the parser. This influences performance. + pub fn with_chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = chunk_size; + self + } + + /// Start reading after ``skip_rows`` rows. The header will be parsed at this + /// offset. Note that we respect CSV escaping/comments when skipping rows. + /// If you want to skip by newline char only, use `skip_lines`. + pub fn with_skip_rows(mut self, skip_rows: usize) -> Self { + self.skip_rows = skip_rows; + self + } + + /// Start reading after `skip_lines` lines. The header will be parsed at this + /// offset. Note that CSV escaping will not be respected when skipping lines. + /// If you want to skip valid CSV rows, use ``skip_rows``. + pub fn with_skip_lines(mut self, skip_lines: usize) -> Self { + self.skip_lines = skip_lines; + self + } + + /// Number of rows to skip after the header row. + pub fn with_skip_rows_after_header(mut self, skip_rows_after_header: usize) -> Self { + self.skip_rows_after_header = skip_rows_after_header; + self + } + + /// Set the number of rows to use when inferring the csv schema. + /// The default is 100 rows. + /// Setting to [None] will do a full table scan, which is very slow. + pub fn with_infer_schema_length(mut self, infer_schema_length: Option) -> Self { + self.infer_schema_length = infer_schema_length; + self + } + + /// Whether to raise an error if the frame is empty. By default an empty + /// DataFrame is returned. + pub fn with_raise_if_empty(mut self, raise_if_empty: bool) -> Self { + self.raise_if_empty = raise_if_empty; + self + } + + /// Continue with next batch when a ParserError is encountered. + pub fn with_ignore_errors(mut self, ignore_errors: bool) -> Self { + self.ignore_errors = ignore_errors; + self + } + + /// Apply a function to the parse options. + pub fn map_parse_options CsvParseOptions>( + mut self, + map_func: F, + ) -> Self { + let parse_options = Arc::unwrap_or_clone(self.parse_options); + self.parse_options = Arc::new(map_func(parse_options)); + self + } +} + +impl CsvParseOptions { + /// The character used to separate fields in the CSV file. This + /// is most often a comma ','. + pub fn with_separator(mut self, separator: u8) -> Self { + self.separator = separator; + self + } + + /// Set the character used for field quoting. This is most often double + /// quotes '"'. Set this to [None] to disable quote parsing. + pub fn with_quote_char(mut self, quote_char: Option) -> Self { + self.quote_char = quote_char; + self + } + + /// Set the character used to indicate an end-of-line (eol). + pub fn with_eol_char(mut self, eol_char: u8) -> Self { + self.eol_char = eol_char; + self + } + + /// Set the encoding used by the file. + pub fn with_encoding(mut self, encoding: CsvEncoding) -> Self { + self.encoding = encoding; + self + } + + /// Set values that will be interpreted as missing/null. + /// + /// Note: These values are matched before quote-parsing, so if the null values + /// are quoted then those quotes also need to be included here. + pub fn with_null_values(mut self, null_values: Option) -> Self { + self.null_values = null_values; + self + } + + /// Treat missing fields as null. + pub fn with_missing_is_null(mut self, missing_is_null: bool) -> Self { + self.missing_is_null = missing_is_null; + self + } + + /// Truncate lines that are longer than the schema. + pub fn with_truncate_ragged_lines(mut self, truncate_ragged_lines: bool) -> Self { + self.truncate_ragged_lines = truncate_ragged_lines; + self + } + + /// Sets the comment prefix for this instance. Lines starting with this + /// prefix will be ignored. + pub fn with_comment_prefix>( + mut self, + comment_prefix: Option, + ) -> Self { + self.comment_prefix = comment_prefix.map(Into::into); + self + } + + /// Automatically try to parse dates/datetimes and time. If parsing fails, + /// columns remain of dtype [`DataType::String`]. + pub fn with_try_parse_dates(mut self, try_parse_dates: bool) -> Self { + self.try_parse_dates = try_parse_dates; + self + } + + /// Parse floats with a comma as decimal separator. + pub fn with_decimal_comma(mut self, decimal_comma: bool) -> Self { + self.decimal_comma = decimal_comma; + self + } +} + +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum CsvEncoding { + /// Utf8 encoding. + #[default] + Utf8, + /// Utf8 encoding and unknown bytes are replaced with �. + LossyUtf8, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum CommentPrefix { + /// A single byte character that indicates the start of a comment line. + Single(u8), + /// A string that indicates the start of a comment line. + /// This allows for multiple characters to be used as a comment identifier. + Multi(PlSmallStr), +} + +impl CommentPrefix { + /// Creates a new `CommentPrefix` for the `Single` variant. + pub fn new_single(prefix: u8) -> Self { + CommentPrefix::Single(prefix) + } + + /// Creates a new `CommentPrefix` for the `Multi` variant. + pub fn new_multi(prefix: PlSmallStr) -> Self { + CommentPrefix::Multi(prefix) + } + + /// Creates a new `CommentPrefix` from a `&str`. + pub fn new_from_str(prefix: &str) -> Self { + if prefix.len() == 1 && prefix.chars().next().unwrap().is_ascii() { + let c = prefix.as_bytes()[0]; + CommentPrefix::Single(c) + } else { + CommentPrefix::Multi(PlSmallStr::from_str(prefix)) + } + } +} + +impl From<&str> for CommentPrefix { + fn from(value: &str) -> Self { + Self::new_from_str(value) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum NullValues { + /// A single value that's used for all columns + AllColumnsSingle(PlSmallStr), + /// Multiple values that are used for all columns + AllColumns(Vec), + /// Tuples that map column names to null value of that column + Named(Vec<(PlSmallStr, PlSmallStr)>), +} + +impl NullValues { + pub fn compile(self, schema: &Schema) -> PolarsResult { + Ok(match self { + NullValues::AllColumnsSingle(v) => NullValuesCompiled::AllColumnsSingle(v), + NullValues::AllColumns(v) => NullValuesCompiled::AllColumns(v), + NullValues::Named(v) => { + let mut null_values = vec![PlSmallStr::from_static(""); schema.len()]; + for (name, null_value) in v { + let i = schema.try_index_of(&name)?; + null_values[i] = null_value; + } + NullValuesCompiled::Columns(null_values) + }, + }) + } +} + +#[derive(Debug, Clone)] +pub enum NullValuesCompiled { + /// A single value that's used for all columns + AllColumnsSingle(PlSmallStr), + // Multiple null values that are null for all columns + AllColumns(Vec), + /// A different null value per column, computed from `NullValues::Named` + Columns(Vec), +} + +impl NullValuesCompiled { + /// # Safety + /// + /// The caller must ensure that `index` is in bounds + pub(super) unsafe fn is_null(&self, field: &[u8], index: usize) -> bool { + use NullValuesCompiled::*; + match self { + AllColumnsSingle(v) => v.as_bytes() == field, + AllColumns(v) => v.iter().any(|v| v.as_bytes() == field), + Columns(v) => { + debug_assert!(index < v.len()); + v.get_unchecked(index).as_bytes() == field + }, + } + } +} diff --git a/crates/polars-io/src/csv/read/parser.rs b/crates/polars-io/src/csv/read/parser.rs new file mode 100644 index 000000000000..29e2e8430ada --- /dev/null +++ b/crates/polars-io/src/csv/read/parser.rs @@ -0,0 +1,1110 @@ +use std::path::Path; + +use memchr::memchr2_iter; +use num_traits::Pow; +use polars_core::prelude::*; +use polars_core::{POOL, config}; +use polars_error::feature_gated; +use polars_utils::mmap::MMapSemaphore; +use polars_utils::select::select_unpredictable; +use rayon::prelude::*; + +use super::CsvParseOptions; +use super::buffer::Buffer; +use super::options::{CommentPrefix, NullValuesCompiled}; +use super::splitfields::SplitFields; +use super::utils::get_file_chunks; +use crate::path_utils::is_cloud_url; +use crate::utils::compression::maybe_decompress_bytes; + +/// Read the number of rows without parsing columns +/// useful for count(*) queries +pub fn count_rows( + path: &Path, + separator: u8, + quote_char: Option, + comment_prefix: Option<&CommentPrefix>, + eol_char: u8, + has_header: bool, +) -> PolarsResult { + let file = if is_cloud_url(path) || config::force_async() { + feature_gated!("cloud", { + crate::file_cache::FILE_CACHE + .get_entry(path.to_str().unwrap()) + // Safety: This was initialized by schema inference. + .unwrap() + .try_open_assume_latest()? + }) + } else { + polars_utils::open_file(path)? + }; + + let mmap = MMapSemaphore::new_from_file(&file).unwrap(); + let owned = &mut vec![]; + let reader_bytes = maybe_decompress_bytes(mmap.as_ref(), owned)?; + + count_rows_from_slice_par( + reader_bytes, + separator, + quote_char, + comment_prefix, + eol_char, + has_header, + ) +} + +/// Read the number of rows without parsing columns +/// useful for count(*) queries +pub fn count_rows_from_slice_par( + mut bytes: &[u8], + separator: u8, + quote_char: Option, + comment_prefix: Option<&CommentPrefix>, + eol_char: u8, + has_header: bool, +) -> PolarsResult { + for _ in 0..bytes.len() { + if bytes[0] != eol_char { + break; + } + + bytes = &bytes[1..]; + } + + const MIN_ROWS_PER_THREAD: usize = 1024; + let max_threads = POOL.current_num_threads(); + + // Determine if parallelism is beneficial and how many threads + let n_threads = get_line_stats( + bytes, + MIN_ROWS_PER_THREAD, + eol_char, + None, + separator, + quote_char, + ) + .map(|(mean, std)| { + let n_rows = (bytes.len() as f32 / (mean - 0.01 * std)) as usize; + (n_rows / MIN_ROWS_PER_THREAD).clamp(1, max_threads) + }) + .unwrap_or(1); + + if n_threads == 1 { + return count_rows_from_slice(bytes, quote_char, comment_prefix, eol_char, has_header); + } + + let file_chunks: Vec<(usize, usize)> = + get_file_chunks(bytes, n_threads, None, separator, quote_char, eol_char); + + let iter = file_chunks.into_par_iter().map(|(start, stop)| { + let bytes = &bytes[start..stop]; + + if comment_prefix.is_some() { + SplitLines::new(bytes, quote_char, eol_char, comment_prefix) + .filter(|line| !is_comment_line(line, comment_prefix)) + .count() + } else { + CountLines::new(quote_char, eol_char).count(bytes).0 + } + }); + + let n: usize = POOL.install(|| iter.sum()); + + Ok(n - (has_header as usize)) +} + +/// Read the number of rows without parsing columns +pub fn count_rows_from_slice( + mut bytes: &[u8], + quote_char: Option, + comment_prefix: Option<&CommentPrefix>, + eol_char: u8, + has_header: bool, +) -> PolarsResult { + for _ in 0..bytes.len() { + if bytes[0] != eol_char { + break; + } + + bytes = &bytes[1..]; + } + + let n = if comment_prefix.is_some() { + SplitLines::new(bytes, quote_char, eol_char, comment_prefix) + .filter(|line| !is_comment_line(line, comment_prefix)) + .count() + } else { + CountLines::new(quote_char, eol_char).count(bytes).0 + }; + + Ok(n - (has_header as usize)) +} + +/// Skip the utf-8 Byte Order Mark. +/// credits to csv-core +pub(super) fn skip_bom(input: &[u8]) -> &[u8] { + if input.len() >= 3 && &input[0..3] == b"\xef\xbb\xbf" { + &input[3..] + } else { + input + } +} + +/// Checks if a line in a CSV file is a comment based on the given comment prefix configuration. +/// +/// This function is used during CSV parsing to determine whether a line should be ignored based on its starting characters. +#[inline] +pub(super) fn is_comment_line(line: &[u8], comment_prefix: Option<&CommentPrefix>) -> bool { + match comment_prefix { + Some(CommentPrefix::Single(c)) => line.first() == Some(c), + Some(CommentPrefix::Multi(s)) => line.starts_with(s.as_bytes()), + None => false, + } +} + +/// Find the nearest next line position. +/// Does not check for new line characters embedded in String fields. +pub(super) fn next_line_position_naive(input: &[u8], eol_char: u8) -> Option { + let pos = memchr::memchr(eol_char, input)? + 1; + if input.len() - pos == 0 { + return None; + } + Some(pos) +} + +pub(super) fn skip_lines_naive(mut input: &[u8], eol_char: u8, skip: usize) -> &[u8] { + for _ in 0..skip { + if let Some(pos) = next_line_position_naive(input, eol_char) { + input = &input[pos..]; + } else { + return input; + } + } + input +} + +/// Find the nearest next line position that is not embedded in a String field. +pub(super) fn next_line_position( + mut input: &[u8], + mut expected_fields: Option, + separator: u8, + quote_char: Option, + eol_char: u8, +) -> Option { + fn accept_line( + line: &[u8], + expected_fields: usize, + separator: u8, + eol_char: u8, + quote_char: Option, + ) -> bool { + let mut count = 0usize; + for (field, _) in SplitFields::new(line, separator, quote_char, eol_char) { + if memchr2_iter(separator, eol_char, field).count() >= expected_fields { + return false; + } + count += 1; + } + + // if the latest field is missing + // e.g.: + // a,b,c + // vala,valb, + // SplitFields returns a count that is 1 less + // There fore we accept: + // expected == count + // and + // expected == count - 1 + expected_fields.wrapping_sub(count) <= 1 + } + + // we check 3 subsequent lines for `accept_line` before we accept + // if 3 groups are rejected we reject completely + let mut rejected_line_groups = 0u8; + + let mut total_pos = 0; + if input.is_empty() { + return None; + } + let mut lines_checked = 0u8; + loop { + if rejected_line_groups >= 3 { + return None; + } + lines_checked = lines_checked.wrapping_add(1); + // headers might have an extra value + // So if we have churned through enough lines + // we try one field less. + if lines_checked == u8::MAX { + if let Some(ef) = expected_fields { + expected_fields = Some(ef.saturating_sub(1)) + } + }; + let pos = memchr::memchr(eol_char, input)? + 1; + if input.len() - pos == 0 { + return None; + } + debug_assert!(pos <= input.len()); + let new_input = unsafe { input.get_unchecked(pos..) }; + let mut lines = SplitLines::new(new_input, quote_char, eol_char, None); + let line = lines.next(); + + match (line, expected_fields) { + // count the fields, and determine if they are equal to what we expect from the schema + (Some(line), Some(expected_fields)) => { + if accept_line(line, expected_fields, separator, eol_char, quote_char) { + let mut valid = true; + for line in lines.take(2) { + if !accept_line(line, expected_fields, separator, eol_char, quote_char) { + valid = false; + break; + } + } + if valid { + return Some(total_pos + pos); + } else { + rejected_line_groups += 1; + } + } else { + debug_assert!(pos < input.len()); + unsafe { + input = input.get_unchecked(pos + 1..); + } + total_pos += pos + 1; + } + }, + // don't count the fields + (Some(_), None) => return Some(total_pos + pos), + // // no new line found, check latest line (without eol) for number of fields + _ => return None, + } + } +} + +pub(super) fn is_line_ending(b: u8, eol_char: u8) -> bool { + b == eol_char || b == b'\r' +} + +pub(super) fn is_whitespace(b: u8) -> bool { + b == b' ' || b == b'\t' +} + +#[inline] +fn skip_condition(input: &[u8], f: F) -> &[u8] +where + F: Fn(u8) -> bool, +{ + if input.is_empty() { + return input; + } + + let read = input.iter().position(|b| !f(*b)).unwrap_or(input.len()); + &input[read..] +} + +/// Remove whitespace from the start of buffer. +/// Makes sure that the bytes stream starts with +/// 'field_1,field_2' +/// and not with +/// '\nfield_1,field_1' +#[inline] +pub(super) fn skip_whitespace(input: &[u8]) -> &[u8] { + skip_condition(input, is_whitespace) +} + +#[inline] +pub(super) fn skip_line_ending(input: &[u8], eol_char: u8) -> &[u8] { + skip_condition(input, |b| is_line_ending(b, eol_char)) +} + +/// Get the mean and standard deviation of length of lines in bytes +pub(super) fn get_line_stats( + bytes: &[u8], + n_lines: usize, + eol_char: u8, + expected_fields: Option, + separator: u8, + quote_char: Option, +) -> Option<(f32, f32)> { + let mut lengths = Vec::with_capacity(n_lines); + + let mut bytes_trunc; + let n_lines_per_iter = n_lines / 2; + + let mut n_read = 0; + + // sample from start and 75% in the file + for offset in [0, (bytes.len() as f32 * 0.75) as usize] { + bytes_trunc = &bytes[offset..]; + let pos = next_line_position( + bytes_trunc, + expected_fields, + separator, + quote_char, + eol_char, + )?; + bytes_trunc = &bytes_trunc[pos + 1..]; + + for _ in offset..(offset + n_lines_per_iter) { + let pos = next_line_position_naive(bytes_trunc, eol_char)? + 1; + n_read += pos; + lengths.push(pos); + bytes_trunc = &bytes_trunc[pos..]; + } + } + + let n_samples = lengths.len(); + + let mean = (n_read as f32) / (n_samples as f32); + let mut std = 0.0; + for &len in lengths.iter() { + std += (len as f32 - mean).pow(2.0) + } + std = (std / n_samples as f32).sqrt(); + Some((mean, std)) +} + +/// An adapted version of std::iter::Split. +/// This exists solely because we cannot split the file in lines naively as +/// +/// ```text +/// for line in bytes.split(b'\n') { +/// ``` +/// +/// This will fail when strings fields are have embedded end line characters. +/// For instance: "This is a valid field\nI have multiples lines" is a valid string field, that contains multiple lines. +pub(super) struct SplitLines<'a> { + v: &'a [u8], + quote_char: u8, + eol_char: u8, + #[cfg(feature = "simd")] + simd_eol_char: SimdVec, + #[cfg(feature = "simd")] + simd_quote_char: SimdVec, + #[cfg(feature = "simd")] + previous_valid_eols: u64, + total_index: usize, + quoting: bool, + comment_prefix: Option<&'a CommentPrefix>, +} + +#[cfg(feature = "simd")] +const SIMD_SIZE: usize = 64; +#[cfg(feature = "simd")] +use std::simd::prelude::*; + +#[cfg(feature = "simd")] +use polars_utils::clmul::prefix_xorsum_inclusive; + +#[cfg(feature = "simd")] +type SimdVec = u8x64; + +impl<'a> SplitLines<'a> { + pub(super) fn new( + slice: &'a [u8], + quote_char: Option, + eol_char: u8, + comment_prefix: Option<&'a CommentPrefix>, + ) -> Self { + let quoting = quote_char.is_some(); + let quote_char = quote_char.unwrap_or(b'\"'); + #[cfg(feature = "simd")] + let simd_eol_char = SimdVec::splat(eol_char); + #[cfg(feature = "simd")] + let simd_quote_char = SimdVec::splat(quote_char); + Self { + v: slice, + quote_char, + eol_char, + #[cfg(feature = "simd")] + simd_eol_char, + #[cfg(feature = "simd")] + simd_quote_char, + #[cfg(feature = "simd")] + previous_valid_eols: 0, + total_index: 0, + quoting, + comment_prefix, + } + } +} + +impl<'a> SplitLines<'a> { + // scalar as in non-simd + fn next_scalar(&mut self) -> Option<&'a [u8]> { + if self.v.is_empty() { + return None; + } + if is_comment_line(self.v, self.comment_prefix) { + return self.next_comment_line(); + } + { + let mut pos = 0u32; + let mut iter = self.v.iter(); + let mut in_field = false; + loop { + match iter.next() { + Some(&c) => { + pos += 1; + + if self.quoting && c == self.quote_char { + // toggle between string field enclosure + // if we encounter a starting '"' -> in_field = true; + // if we encounter a closing '"' -> in_field = false; + in_field = !in_field; + } + // if we are not in a string and we encounter '\n' we can stop at this position. + else if c == self.eol_char && !in_field { + break; + } + }, + None => { + let remainder = self.v; + self.v = &[]; + return Some(remainder); + }, + } + } + + unsafe { + debug_assert!((pos as usize) <= self.v.len()); + + // return line up to this position + let ret = Some( + self.v + .get_unchecked(..(self.total_index + pos as usize - 1)), + ); + // skip the '\n' token and update slice. + self.v = self.v.get_unchecked(self.total_index + pos as usize..); + ret + } + } + } + fn next_comment_line(&mut self) -> Option<&'a [u8]> { + if let Some(pos) = next_line_position_naive(self.v, self.eol_char) { + unsafe { + // return line up to this position + let ret = Some(self.v.get_unchecked(..(pos - 1))); + // skip the '\n' token and update slice. + self.v = self.v.get_unchecked(pos..); + ret + } + } else { + let remainder = self.v; + self.v = &[]; + Some(remainder) + } + } +} + +impl<'a> Iterator for SplitLines<'a> { + type Item = &'a [u8]; + + #[inline] + #[cfg(not(feature = "simd"))] + fn next(&mut self) -> Option<&'a [u8]> { + self.next_scalar() + } + + #[inline] + #[cfg(feature = "simd")] + fn next(&mut self) -> Option<&'a [u8]> { + // First check cached value + if self.previous_valid_eols != 0 { + let pos = self.previous_valid_eols.trailing_zeros() as usize; + self.previous_valid_eols >>= (pos + 1) as u64; + + unsafe { + debug_assert!((pos) <= self.v.len()); + + // return line up to this position + let ret = Some(self.v.get_unchecked(..pos)); + // skip the '\n' token and update slice. + self.v = self.v.get_unchecked(pos + 1..); + return ret; + } + } + if self.v.is_empty() { + return None; + } + if self.comment_prefix.is_some() { + return self.next_scalar(); + } + + self.total_index = 0; + let mut not_in_field_previous_iter = true; + + loop { + let bytes = unsafe { self.v.get_unchecked(self.total_index..) }; + if bytes.len() > SIMD_SIZE { + let lane: [u8; SIMD_SIZE] = unsafe { + bytes + .get_unchecked(0..SIMD_SIZE) + .try_into() + .unwrap_unchecked() + }; + let simd_bytes = SimdVec::from(lane); + let eol_mask = simd_bytes.simd_eq(self.simd_eol_char).to_bitmask(); + + let valid_eols = if self.quoting { + let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask(); + let mut not_in_quote_field = prefix_xorsum_inclusive(quote_mask); + + if not_in_field_previous_iter { + not_in_quote_field = !not_in_quote_field; + } + not_in_field_previous_iter = (not_in_quote_field & (1 << (SIMD_SIZE - 1))) > 0; + eol_mask & not_in_quote_field + } else { + eol_mask + }; + + if valid_eols != 0 { + let pos = valid_eols.trailing_zeros() as usize; + if pos == SIMD_SIZE - 1 { + self.previous_valid_eols = 0; + } else { + self.previous_valid_eols = valid_eols >> (pos + 1) as u64; + } + + unsafe { + let pos = self.total_index + pos; + debug_assert!((pos) <= self.v.len()); + + // return line up to this position + let ret = Some(self.v.get_unchecked(..pos)); + // skip the '\n' token and update slice. + self.v = self.v.get_unchecked(pos + 1..); + return ret; + } + } else { + self.total_index += SIMD_SIZE; + } + } else { + // Denotes if we are in a string field, started with a quote + let mut in_field = !not_in_field_previous_iter; + let mut pos = 0u32; + let mut iter = bytes.iter(); + loop { + match iter.next() { + Some(&c) => { + pos += 1; + + if self.quoting && c == self.quote_char { + // toggle between string field enclosure + // if we encounter a starting '"' -> in_field = true; + // if we encounter a closing '"' -> in_field = false; + in_field = !in_field; + } + // if we are not in a string and we encounter '\n' we can stop at this position. + else if c == self.eol_char && !in_field { + break; + } + }, + None => { + let remainder = self.v; + self.v = &[]; + return Some(remainder); + }, + } + } + + unsafe { + debug_assert!((pos as usize) <= self.v.len()); + + // return line up to this position + let ret = Some( + self.v + .get_unchecked(..(self.total_index + pos as usize - 1)), + ); + // skip the '\n' token and update slice. + self.v = self.v.get_unchecked(self.total_index + pos as usize..); + return ret; + } + } + } + } +} + +pub struct CountLines { + quote_char: u8, + eol_char: u8, + #[cfg(feature = "simd")] + simd_eol_char: SimdVec, + #[cfg(feature = "simd")] + simd_quote_char: SimdVec, + quoting: bool, +} + +#[derive(Copy, Clone, Debug)] +pub struct LineStats { + newline_count: usize, + last_newline_offset: usize, + end_inside_string: bool, +} + +impl CountLines { + pub fn new(quote_char: Option, eol_char: u8) -> Self { + let quoting = quote_char.is_some(); + let quote_char = quote_char.unwrap_or(b'\"'); + #[cfg(feature = "simd")] + let simd_eol_char = SimdVec::splat(eol_char); + #[cfg(feature = "simd")] + let simd_quote_char = SimdVec::splat(quote_char); + Self { + quote_char, + eol_char, + #[cfg(feature = "simd")] + simd_eol_char, + #[cfg(feature = "simd")] + simd_quote_char, + quoting, + } + } + + /// Analyzes a chunk of CSV data. + /// + /// Returns (newline_count, last_newline_offset, end_inside_string) twice, + /// the first is assuming the start of the chunk is *not* inside a string, + /// the second assuming the start is inside a string. + pub fn analyze_chunk(&self, bytes: &[u8]) -> [LineStats; 2] { + let mut scan_offset = 0; + let mut states = [ + LineStats { + newline_count: 0, + last_newline_offset: 0, + end_inside_string: false, + }, + LineStats { + newline_count: 0, + last_newline_offset: 0, + end_inside_string: false, + }, + ]; + + // false if even number of quotes seen so far, true otherwise. + #[allow(unused_assignments)] + let mut global_quote_parity = false; + + #[cfg(feature = "simd")] + { + // 0 if even number of quotes seen so far, u64::MAX otherwise. + let mut global_quote_parity_mask = 0; + while scan_offset + 64 <= bytes.len() { + let block: [u8; 64] = unsafe { + bytes + .get_unchecked(scan_offset..scan_offset + 64) + .try_into() + .unwrap_unchecked() + }; + let simd_bytes = SimdVec::from(block); + let eol_mask = simd_bytes.simd_eq(self.simd_eol_char).to_bitmask(); + if self.quoting { + let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask(); + let quote_parity = + prefix_xorsum_inclusive(quote_mask) ^ global_quote_parity_mask; + global_quote_parity_mask = ((quote_parity as i64) >> 63) as u64; + + let start_outside_string_eol_mask = eol_mask & !quote_parity; + states[0].newline_count += start_outside_string_eol_mask.count_ones() as usize; + states[0].last_newline_offset = select_unpredictable( + start_outside_string_eol_mask != 0, + (scan_offset + 63) + .wrapping_sub(start_outside_string_eol_mask.leading_zeros() as usize), + states[0].last_newline_offset, + ); + + let start_inside_string_eol_mask = eol_mask & quote_parity; + states[1].newline_count += start_inside_string_eol_mask.count_ones() as usize; + states[1].last_newline_offset = select_unpredictable( + start_inside_string_eol_mask != 0, + (scan_offset + 63) + .wrapping_sub(start_inside_string_eol_mask.leading_zeros() as usize), + states[1].last_newline_offset, + ); + } else { + states[0].newline_count += eol_mask.count_ones() as usize; + states[0].last_newline_offset = select_unpredictable( + eol_mask != 0, + (scan_offset + 63).wrapping_sub(eol_mask.leading_zeros() as usize), + states[0].last_newline_offset, + ); + } + + scan_offset += 64; + } + + global_quote_parity = global_quote_parity_mask > 0; + } + + while scan_offset < bytes.len() { + let c = unsafe { *bytes.get_unchecked(scan_offset) }; + global_quote_parity ^= (c == self.quote_char) & self.quoting; + + let state = &mut states[global_quote_parity as usize]; + state.newline_count += (c == self.eol_char) as usize; + state.last_newline_offset = + select_unpredictable(c == self.eol_char, scan_offset, state.last_newline_offset); + + scan_offset += 1; + } + + states[0].end_inside_string = global_quote_parity; + states[1].end_inside_string = !global_quote_parity; + states + } + + pub fn find_next(&self, bytes: &[u8], chunk_size: &mut usize) -> (usize, usize) { + loop { + let b = unsafe { bytes.get_unchecked(..(*chunk_size).min(bytes.len())) }; + + let (count, offset) = self.count(b); + + if count > 0 || b.len() == bytes.len() { + return (count, offset); + } + + *chunk_size *= 2; + } + } + + /// Returns count and offset to split for remainder in slice. + #[cfg(feature = "simd")] + pub fn count(&self, bytes: &[u8]) -> (usize, usize) { + let mut total_idx = 0; + let original_bytes = bytes; + let mut count = 0; + let mut position = 0; + let mut not_in_field_previous_iter = true; + + loop { + let bytes = unsafe { original_bytes.get_unchecked(total_idx..) }; + + if bytes.len() > SIMD_SIZE { + let lane: [u8; SIMD_SIZE] = unsafe { + bytes + .get_unchecked(0..SIMD_SIZE) + .try_into() + .unwrap_unchecked() + }; + let simd_bytes = SimdVec::from(lane); + let eol_mask = simd_bytes.simd_eq(self.simd_eol_char).to_bitmask(); + + let valid_eols = if self.quoting { + let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask(); + let mut not_in_quote_field = prefix_xorsum_inclusive(quote_mask); + + if not_in_field_previous_iter { + not_in_quote_field = !not_in_quote_field; + } + not_in_field_previous_iter = (not_in_quote_field & (1 << (SIMD_SIZE - 1))) > 0; + eol_mask & not_in_quote_field + } else { + eol_mask + }; + + if valid_eols != 0 { + count += valid_eols.count_ones() as usize; + position = total_idx + 63 - valid_eols.leading_zeros() as usize; + debug_assert_eq!(original_bytes[position], self.eol_char) + } + total_idx += SIMD_SIZE; + } else if bytes.is_empty() { + debug_assert!(count == 0 || original_bytes[position] == self.eol_char); + return (count, position); + } else { + let (c, o) = self.count_no_simd(bytes, !not_in_field_previous_iter); + + let (count, position) = if c > 0 { + (count + c, total_idx + o) + } else { + (count, position) + }; + debug_assert!(count == 0 || original_bytes[position] == self.eol_char); + + return (count, position); + } + } + } + + #[cfg(not(feature = "simd"))] + pub fn count(&self, bytes: &[u8]) -> (usize, usize) { + self.count_no_simd(bytes, false) + } + + fn count_no_simd(&self, bytes: &[u8], in_field: bool) -> (usize, usize) { + let iter = bytes.iter(); + let mut in_field = in_field; + let mut count = 0; + let mut position = 0; + + for b in iter { + let c = *b; + if self.quoting && c == self.quote_char { + // toggle between string field enclosure + // if we encounter a starting '"' -> in_field = true; + // if we encounter a closing '"' -> in_field = false; + in_field = !in_field; + } + // If we are not in a string and we encounter '\n' we can stop at this position. + else if c == self.eol_char && !in_field { + position = (b as *const _ as usize) - (bytes.as_ptr() as usize); + count += 1; + } + } + debug_assert!(count == 0 || bytes[position] == self.eol_char); + + (count, position) + } +} + +#[inline] +fn find_quoted(bytes: &[u8], quote_char: u8, needle: u8) -> Option { + let mut in_field = false; + + let mut idx = 0u32; + // micro optimizations + #[allow(clippy::explicit_counter_loop)] + for &c in bytes.iter() { + if c == quote_char { + // toggle between string field enclosure + // if we encounter a starting '"' -> in_field = true; + // if we encounter a closing '"' -> in_field = false; + in_field = !in_field; + } + + if !in_field && c == needle { + return Some(idx as usize); + } + idx += 1; + } + None +} + +#[inline] +pub(super) fn skip_this_line(bytes: &[u8], quote: Option, eol_char: u8) -> &[u8] { + let pos = match quote { + Some(quote) => find_quoted(bytes, quote, eol_char), + None => bytes.iter().position(|x| *x == eol_char), + }; + match pos { + None => &[], + Some(pos) => &bytes[pos + 1..], + } +} + +#[inline] +pub(super) fn skip_this_line_naive(input: &[u8], eol_char: u8) -> &[u8] { + if let Some(pos) = next_line_position_naive(input, eol_char) { + unsafe { input.get_unchecked(pos..) } + } else { + &[] + } +} + +/// Parse CSV. +/// +/// # Arguments +/// * `bytes` - input to parse +/// * `offset` - offset in bytes in total input. This is 0 if single threaded. If multi-threaded every +/// thread has a different offset. +/// * `projection` - Indices of the columns to project. +/// * `buffers` - Parsed output will be written to these buffers. Except for UTF8 data. The offsets of the +/// fields are written to the buffers. The UTF8 data will be parsed later. +#[allow(clippy::too_many_arguments)] +pub(super) fn parse_lines( + mut bytes: &[u8], + parse_options: &CsvParseOptions, + offset: usize, + ignore_errors: bool, + null_values: Option<&NullValuesCompiled>, + projection: &[usize], + buffers: &mut [Buffer], + n_lines: usize, + // length of original schema + schema_len: usize, + schema: &Schema, +) -> PolarsResult { + assert!( + !projection.is_empty(), + "at least one column should be projected" + ); + let mut truncate_ragged_lines = parse_options.truncate_ragged_lines; + // During projection pushdown we are not checking other csv fields. + // This would be very expensive and we don't care as we only want + // the projected columns. + if projection.len() != schema_len { + truncate_ragged_lines = true + } + + // we use the pointers to track the no of bytes read. + let start = bytes.as_ptr() as usize; + let original_bytes_len = bytes.len(); + let n_lines = n_lines as u32; + + let mut line_count = 0u32; + loop { + if line_count > n_lines { + let end = bytes.as_ptr() as usize; + return Ok(end - start); + } + + if bytes.is_empty() { + return Ok(original_bytes_len); + } else if is_comment_line(bytes, parse_options.comment_prefix.as_ref()) { + // deal with comments + let bytes_rem = skip_this_line_naive(bytes, parse_options.eol_char); + bytes = bytes_rem; + continue; + } + + // Every line we only need to parse the columns that are projected. + // Therefore we check if the idx of the field is in our projected columns. + // If it is not, we skip the field. + let mut projection_iter = projection.iter().copied(); + let mut next_projected = unsafe { projection_iter.next().unwrap_unchecked() }; + let mut processed_fields = 0; + + let mut iter = SplitFields::new( + bytes, + parse_options.separator, + parse_options.quote_char, + parse_options.eol_char, + ); + let mut idx = 0u32; + let mut read_sol = 0; + loop { + match iter.next() { + // end of line + None => { + bytes = unsafe { bytes.get_unchecked(std::cmp::min(read_sol, bytes.len())..) }; + break; + }, + Some((mut field, needs_escaping)) => { + let field_len = field.len(); + + // +1 is the split character that is consumed by the iterator. + read_sol += field_len + 1; + + if idx == next_projected as u32 { + // the iterator is finished when it encounters a `\n` + // this could be preceded by a '\r' + unsafe { + if field_len > 0 && *field.get_unchecked(field_len - 1) == b'\r' { + field = field.get_unchecked(..field_len - 1); + } + } + + debug_assert!(processed_fields < buffers.len()); + let buf = unsafe { + // SAFETY: processed fields index can never exceed the projection indices. + buffers.get_unchecked_mut(processed_fields) + }; + let mut add_null = false; + + // if we have null values argument, check if this field equal null value + if let Some(null_values) = null_values { + let field = if needs_escaping && !field.is_empty() { + unsafe { field.get_unchecked(1..field.len() - 1) } + } else { + field + }; + + // SAFETY: + // process fields is in bounds + add_null = unsafe { null_values.is_null(field, idx as usize) } + } + if add_null { + buf.add_null(!parse_options.missing_is_null && field.is_empty()) + } else { + buf.add(field, ignore_errors, needs_escaping, parse_options.missing_is_null) + .map_err(|e| { + let bytes_offset = offset + field.as_ptr() as usize - start; + let unparsable = String::from_utf8_lossy(field); + let column_name = schema.get_at_index(idx as usize).unwrap().0; + polars_err!( + ComputeError: + "could not parse `{}` as dtype `{}` at column '{}' (column number {})\n\n\ + The current offset in the file is {} bytes.\n\ + \n\ + You might want to try:\n\ + - increasing `infer_schema_length` (e.g. `infer_schema_length=10000`),\n\ + - specifying correct dtype with the `schema_overrides` argument\n\ + - setting `ignore_errors` to `True`,\n\ + - adding `{}` to the `null_values` list.\n\n\ + Original error: ```{}```", + &unparsable, + buf.dtype(), + column_name, + idx + 1, + bytes_offset, + &unparsable, + e + ) + })?; + } + processed_fields += 1; + + // if we have all projected columns we are done with this line + match projection_iter.next() { + Some(p) => next_projected = p, + None => { + if bytes.get(read_sol - 1) == Some(&parse_options.eol_char) { + bytes = &bytes[read_sol..]; + } else { + if !truncate_ragged_lines && read_sol < bytes.len() { + polars_bail!(ComputeError: r#"found more fields than defined in 'Schema' + +Consider setting 'truncate_ragged_lines={}'."#, polars_error::constants::TRUE) + } + let bytes_rem = skip_this_line( + unsafe { bytes.get_unchecked(read_sol - 1..) }, + parse_options.quote_char, + parse_options.eol_char, + ); + bytes = bytes_rem; + } + break; + }, + } + } + idx += 1; + }, + } + } + + // there can be lines that miss fields (also the comma values) + // this means the splitter won't process them. + // We traverse them to read them as null values. + while processed_fields < projection.len() { + debug_assert!(processed_fields < buffers.len()); + let buf = unsafe { + // SAFETY: processed fields index can never exceed the projection indices. + buffers.get_unchecked_mut(processed_fields) + }; + buf.add_null(!parse_options.missing_is_null); + processed_fields += 1; + } + line_count += 1; + } +} + +#[cfg(test)] +mod test { + use super::SplitLines; + + #[test] + fn test_splitlines() { + let input = "1,\"foo\n\"\n2,\"foo\n\"\n"; + let mut lines = SplitLines::new(input.as_bytes(), Some(b'"'), b'\n', None); + assert_eq!(lines.next(), Some("1,\"foo\n\"".as_bytes())); + assert_eq!(lines.next(), Some("2,\"foo\n\"".as_bytes())); + assert_eq!(lines.next(), None); + + let input2 = "1,'foo\n'\n2,'foo\n'\n"; + let mut lines2 = SplitLines::new(input2.as_bytes(), Some(b'\''), b'\n', None); + assert_eq!(lines2.next(), Some("1,'foo\n'".as_bytes())); + assert_eq!(lines2.next(), Some("2,'foo\n'".as_bytes())); + assert_eq!(lines2.next(), None); + } +} diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs new file mode 100644 index 000000000000..a7e9af6b9927 --- /dev/null +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -0,0 +1,664 @@ +pub(super) mod batched; + +use std::fmt; +use std::sync::Mutex; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::utils::{accumulate_dataframes_vertical, handle_casting_failures}; +#[cfg(feature = "polars-time")] +use polars_time::prelude::*; +use rayon::prelude::*; + +use super::CsvParseOptions; +use super::buffer::init_buffers; +use super::options::{CommentPrefix, CsvEncoding, NullValuesCompiled}; +use super::parser::{ + CountLines, SplitLines, is_comment_line, parse_lines, skip_bom, skip_line_ending, + skip_lines_naive, skip_this_line, +}; +use super::reader::prepare_csv_schema; +use super::schema_inference::{check_decimal_comma, infer_file_schema}; +#[cfg(feature = "decompress")] +use super::utils::decompress; +use crate::RowIndex; +use crate::csv::read::parser::skip_this_line_naive; +use crate::mmap::ReaderBytes; +use crate::predicates::PhysicalIoExpr; +use crate::utils::compression::SupportedCompression; +use crate::utils::update_row_counts2; + +pub fn cast_columns( + df: &mut DataFrame, + to_cast: &[Field], + parallel: bool, + ignore_errors: bool, +) -> PolarsResult<()> { + let cast_fn = |c: &Column, fld: &Field| { + let out = match (c.dtype(), fld.dtype()) { + #[cfg(feature = "temporal")] + (DataType::String, DataType::Date) => c + .str() + .unwrap() + .as_date(None, false) + .map(|ca| ca.into_column()), + #[cfg(feature = "temporal")] + (DataType::String, DataType::Time) => c + .str() + .unwrap() + .as_time(None, false) + .map(|ca| ca.into_column()), + #[cfg(feature = "temporal")] + (DataType::String, DataType::Datetime(tu, _)) => c + .str() + .unwrap() + .as_datetime( + None, + *tu, + false, + false, + None, + &StringChunked::from_iter(std::iter::once("raise")), + ) + .map(|ca| ca.into_column()), + (_, dt) => c.cast(dt), + }?; + if !ignore_errors && c.null_count() != out.null_count() { + handle_casting_failures(c.as_materialized_series(), out.as_materialized_series())?; + } + Ok(out) + }; + + if parallel { + let cols = POOL.install(|| { + df.get_columns() + .into_par_iter() + .map(|s| { + if let Some(fld) = to_cast.iter().find(|fld| fld.name() == s.name()) { + cast_fn(s, fld) + } else { + Ok(s.clone()) + } + }) + .collect::>>() + })?; + *df = unsafe { DataFrame::new_no_checks(df.height(), cols) } + } else { + // cast to the original dtypes in the schema + for fld in to_cast { + // field may not be projected + if let Some(idx) = df.get_column_index(fld.name()) { + df.try_apply_at_idx(idx, |s| cast_fn(s, fld))?; + } + } + + df.clear_schema(); + } + Ok(()) +} + +/// CSV file reader +pub(crate) struct CoreReader<'a> { + reader_bytes: Option>, + /// Explicit schema for the CSV file + schema: SchemaRef, + parse_options: CsvParseOptions, + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, + /// Current line number, used in error reporting + current_line: usize, + ignore_errors: bool, + skip_lines: usize, + skip_rows_before_header: usize, + // after the header, we need to take embedded lines into account + skip_rows_after_header: usize, + n_rows: Option, + n_threads: Option, + has_header: bool, + chunk_size: usize, + null_values: Option, + predicate: Option>, + to_cast: Vec, + row_index: Option, + #[cfg_attr(not(feature = "dtype-categorical"), allow(unused))] + has_categorical: bool, +} + +impl fmt::Debug for CoreReader<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Reader") + .field("schema", &self.schema) + .field("projection", &self.projection) + .field("current_line", &self.current_line) + .finish() + } +} + +impl<'a> CoreReader<'a> { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + reader_bytes: ReaderBytes<'a>, + parse_options: Arc, + n_rows: Option, + skip_rows: usize, + skip_lines: usize, + mut projection: Option>, + max_records: Option, + has_header: bool, + ignore_errors: bool, + schema: Option, + columns: Option>, + n_threads: Option, + schema_overwrite: Option, + dtype_overwrite: Option>>, + chunk_size: usize, + predicate: Option>, + mut to_cast: Vec, + skip_rows_after_header: usize, + row_index: Option, + raise_if_empty: bool, + ) -> PolarsResult> { + let separator = parse_options.separator; + + check_decimal_comma(parse_options.decimal_comma, separator)?; + #[cfg(feature = "decompress")] + let mut reader_bytes = reader_bytes; + + if !cfg!(feature = "decompress") && SupportedCompression::check(&reader_bytes).is_some() { + polars_bail!( + ComputeError: "cannot read compressed CSV file; \ + compile with feature 'decompress'" + ); + } + // We keep track of the inferred schema bool + // In case the file is compressed this schema inference is wrong and has to be done + // again after decompression. + #[cfg(feature = "decompress")] + { + let total_n_rows = + n_rows.map(|n| skip_rows + (has_header as usize) + skip_rows_after_header + n); + if let Some(b) = decompress( + &reader_bytes, + total_n_rows, + separator, + parse_options.quote_char, + parse_options.eol_char, + ) { + reader_bytes = ReaderBytes::Owned(b.into()); + } + } + + let mut schema = match schema { + Some(schema) => schema, + None => { + let (inferred_schema, _, _) = infer_file_schema( + &reader_bytes, + &parse_options, + max_records, + has_header, + schema_overwrite.as_deref(), + skip_rows, + skip_lines, + skip_rows_after_header, + raise_if_empty, + )?; + Arc::new(inferred_schema) + }, + }; + if let Some(dtypes) = dtype_overwrite { + polars_ensure!( + dtypes.len() <= schema.len(), + InvalidOperation: "The number of schema overrides must be less than or equal to the number of fields" + ); + let s = Arc::make_mut(&mut schema); + for (index, dt) in dtypes.iter().enumerate() { + s.set_dtype_at_index(index, dt.clone()).unwrap(); + } + } + + let has_categorical = prepare_csv_schema(&mut schema, &mut to_cast)?; + + // Create a null value for every column + let null_values = parse_options + .null_values + .as_ref() + .map(|nv| nv.clone().compile(&schema)) + .transpose()?; + + if let Some(cols) = columns { + let mut prj = Vec::with_capacity(cols.len()); + for col in cols.as_ref() { + let i = schema.try_index_of(col)?; + prj.push(i); + } + projection = Some(prj); + } + + Ok(CoreReader { + reader_bytes: Some(reader_bytes), + parse_options: (*parse_options).clone(), + schema, + projection, + current_line: usize::from(has_header), + ignore_errors, + skip_lines, + skip_rows_before_header: skip_rows, + skip_rows_after_header, + n_rows, + n_threads, + has_header, + chunk_size, + null_values, + predicate, + to_cast, + row_index, + has_categorical, + }) + } + + fn find_starting_point<'b>( + &self, + bytes: &'b [u8], + quote_char: Option, + eol_char: u8, + ) -> PolarsResult<(&'b [u8], Option)> { + let i = find_starting_point( + bytes, + quote_char, + eol_char, + self.schema.len(), + self.skip_lines, + self.skip_rows_before_header, + self.skip_rows_after_header, + self.parse_options.comment_prefix.as_ref(), + self.has_header, + )?; + + Ok((&bytes[i..], (i <= bytes.len()).then_some(i))) + } + + fn get_projection(&mut self) -> PolarsResult> { + // we also need to sort the projection to have predictable output. + // the `parse_lines` function expects this. + self.projection + .take() + .map(|mut v| { + v.sort_unstable(); + if let Some(idx) = v.last() { + polars_ensure!(*idx < self.schema.len(), OutOfBounds: "projection index: {} is out of bounds for csv schema with length: {}", idx, self.schema.len()) + } + Ok(v) + }) + .unwrap_or_else(|| Ok((0..self.schema.len()).collect())) + } + + fn read_chunk( + &self, + bytes: &[u8], + projection: &[usize], + bytes_offset: usize, + capacity: usize, + starting_point_offset: Option, + stop_at_nbytes: usize, + ) -> PolarsResult { + let mut df = read_chunk( + bytes, + &self.parse_options, + self.schema.as_ref(), + self.ignore_errors, + projection, + bytes_offset, + capacity, + self.null_values.as_ref(), + usize::MAX, + stop_at_nbytes, + starting_point_offset, + )?; + + cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; + Ok(df) + } + + fn parse_csv(&mut self, bytes: &[u8]) -> PolarsResult { + let (bytes, _) = self.find_starting_point( + bytes, + self.parse_options.quote_char, + self.parse_options.eol_char, + )?; + + let projection = self.get_projection()?; + + // An empty file with a schema should return an empty DataFrame with that schema + if bytes.is_empty() { + let mut df = if projection.len() == self.schema.len() { + DataFrame::empty_with_schema(self.schema.as_ref()) + } else { + DataFrame::empty_with_schema( + &projection + .iter() + .map(|&i| self.schema.get_at_index(i).unwrap()) + .map(|(name, dtype)| Field { + name: name.clone(), + dtype: dtype.clone(), + }) + .collect::(), + ) + }; + if let Some(ref row_index) = self.row_index { + df.insert_column(0, Series::new_empty(row_index.name.clone(), &IDX_DTYPE))?; + } + return Ok(df); + } + + let n_threads = self.n_threads.unwrap_or_else(|| POOL.current_num_threads()); + + // This is chosen by benchmarking on ny city trip csv dataset. + // We want small enough chunks such that threads start working as soon as possible + // But we also want them large enough, so that we have less chunks related overhead, but + // We minimize chunks to 16 MB to still fit L3 cache. + let n_parts_hint = n_threads * 16; + let chunk_size = std::cmp::min(bytes.len() / n_parts_hint, 16 * 1024 * 1024); + + // Use a small min chunk size to catch failures in tests. + #[cfg(debug_assertions)] + let min_chunk_size = 64; + #[cfg(not(debug_assertions))] + let min_chunk_size = 1024 * 4; + + let mut chunk_size = std::cmp::max(chunk_size, min_chunk_size); + let mut total_bytes_offset = 0; + + let results = Arc::new(Mutex::new(vec![])); + // We have to do this after parsing as there can be comments. + let total_line_count = &AtomicUsize::new(0); + + #[cfg(not(target_family = "wasm"))] + let pool; + #[cfg(not(target_family = "wasm"))] + let pool = if n_threads == POOL.current_num_threads() { + &POOL + } else { + pool = rayon::ThreadPoolBuilder::new() + .num_threads(n_threads) + .build() + .map_err(|_| polars_err!(ComputeError: "could not spawn threads"))?; + &pool + }; + #[cfg(target_family = "wasm")] + let pool = &POOL; + + let counter = CountLines::new(self.parse_options.quote_char, self.parse_options.eol_char); + let mut total_offset = 0; + let check_utf8 = matches!(self.parse_options.encoding, CsvEncoding::Utf8) + && self.schema.iter_fields().any(|f| f.dtype().is_string()); + + pool.scope(|s| { + loop { + let b = unsafe { bytes.get_unchecked(total_offset..) }; + if b.is_empty() { + break; + } + debug_assert!( + total_offset == 0 || bytes[total_offset - 1] == self.parse_options.eol_char + ); + let (count, position) = counter.find_next(b, &mut chunk_size); + debug_assert!(count == 0 || b[position] == self.parse_options.eol_char); + + let (b, count) = if count == 0 + && unsafe { + std::ptr::eq(b.as_ptr().add(b.len()), bytes.as_ptr().add(bytes.len())) + } { + total_offset = bytes.len(); + (b, 1) + } else { + if count == 0 { + chunk_size *= 2; + continue; + } + + let end = total_offset + position + 1; + let b = unsafe { bytes.get_unchecked(total_offset..end) }; + + total_offset = end; + (b, count) + }; + + if !b.is_empty() { + let results = results.clone(); + let projection = projection.as_ref(); + let slf = &(*self); + s.spawn(move |_| { + if check_utf8 && !super::buffer::validate_utf8(b) { + let mut results = results.lock().unwrap(); + results.push(( + b.as_ptr() as usize, + Err(polars_err!(ComputeError: "invalid utf-8 sequence")), + )); + return; + } + + let result = slf + .read_chunk(b, projection, 0, count, Some(0), b.len()) + .and_then(|mut df| { + debug_assert!(df.height() <= count); + + if slf.n_rows.is_some() { + total_line_count.fetch_add(df.height(), Ordering::Relaxed); + } + + // We cannot use the line count as there can be comments in the lines so we must correct line counts later. + if let Some(rc) = &slf.row_index { + // is first chunk + let offset = if std::ptr::eq(b.as_ptr(), bytes.as_ptr()) { + Some(rc.offset) + } else { + None + }; + + unsafe { df.with_row_index_mut(rc.name.clone(), offset) }; + }; + + if let Some(predicate) = slf.predicate.as_ref() { + let s = predicate.evaluate_io(&df)?; + let mask = s.bool()?; + df = df.filter(mask)?; + } + Ok(df) + }); + + results.lock().unwrap().push((b.as_ptr() as usize, result)); + }); + + // Check just after we spawned a chunk. That mean we processed all data up until + // row count. + if self.n_rows.is_some() + && total_line_count.load(Ordering::Relaxed) > self.n_rows.unwrap() + { + break; + } + } + total_bytes_offset += b.len(); + } + }); + let mut results = std::mem::take(&mut *results.lock().unwrap()); + results.sort_unstable_by_key(|k| k.0); + let mut dfs = results + .into_iter() + .map(|k| k.1) + .collect::>>()?; + + if let Some(rc) = &self.row_index { + update_row_counts2(&mut dfs, rc.offset) + }; + accumulate_dataframes_vertical(dfs) + } + + /// Read the csv into a DataFrame. The predicate can come from a lazy physical plan. + pub fn finish(mut self) -> PolarsResult { + #[cfg(feature = "dtype-categorical")] + let mut _cat_lock = if self.has_categorical { + Some(polars_core::StringCacheHolder::hold()) + } else { + None + }; + + let reader_bytes = self.reader_bytes.take().unwrap(); + + let mut df = self.parse_csv(&reader_bytes)?; + + // if multi-threaded the n_rows was probabilistically determined. + // Let's slice to correct number of rows if possible. + if let Some(n_rows) = self.n_rows { + if n_rows < df.height() { + df = df.slice(0, n_rows) + } + } + Ok(df) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn read_chunk( + bytes: &[u8], + parse_options: &CsvParseOptions, + schema: &Schema, + ignore_errors: bool, + projection: &[usize], + bytes_offset_thread: usize, + capacity: usize, + null_values: Option<&NullValuesCompiled>, + chunk_size: usize, + stop_at_nbytes: usize, + starting_point_offset: Option, +) -> PolarsResult { + let mut read = bytes_offset_thread; + // There's an off-by-one error somewhere in the reading code, where it reads + // one more item than the requested capacity. Given the batch sizes are + // approximate (sometimes they're smaller), this isn't broken, but it does + // mean a bunch of extra allocation and copying. So we allocate a + // larger-by-one buffer so the size is more likely to be accurate. + let mut buffers = init_buffers( + projection, + capacity + 1, + schema, + parse_options.quote_char, + parse_options.encoding, + parse_options.decimal_comma, + )?; + + debug_assert!(projection.is_sorted()); + + let mut last_read = usize::MAX; + loop { + if read >= stop_at_nbytes || read == last_read { + break; + } + let local_bytes = &bytes[read..stop_at_nbytes]; + + last_read = read; + let offset = read + starting_point_offset.unwrap(); + read += parse_lines( + local_bytes, + parse_options, + offset, + ignore_errors, + null_values, + projection, + &mut buffers, + chunk_size, + schema.len(), + schema, + )?; + } + + let columns = buffers + .into_iter() + .map(|buf| buf.into_series().map(Column::from)) + .collect::>>()?; + Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) }) +} + +#[allow(clippy::too_many_arguments)] +pub fn find_starting_point( + mut bytes: &[u8], + quote_char: Option, + eol_char: u8, + schema_len: usize, + skip_lines: usize, + skip_rows_before_header: usize, + skip_rows_after_header: usize, + comment_prefix: Option<&CommentPrefix>, + has_header: bool, +) -> PolarsResult { + let full_len = bytes.len(); + let starting_point_offset = bytes.as_ptr() as usize; + + bytes = if skip_lines > 0 { + polars_ensure!(skip_rows_before_header == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set"); + skip_lines_naive(bytes, eol_char, skip_lines) + } else { + // Skip utf8 byte-order-mark (BOM) + bytes = skip_bom(bytes); + + // \n\n can be a empty string row of a single column + // in other cases we skip it. + if schema_len > 1 { + bytes = skip_line_ending(bytes, eol_char) + } + bytes + }; + + // skip 'n' leading rows + if skip_rows_before_header > 0 { + let mut split_lines = SplitLines::new(bytes, quote_char, eol_char, comment_prefix); + let mut current_line = &bytes[..0]; + + for _ in 0..skip_rows_before_header { + current_line = split_lines + .next() + .ok_or_else(|| polars_err!(NoData: "not enough lines to skip"))?; + } + + current_line = split_lines + .next() + .unwrap_or(¤t_line[current_line.len()..]); + bytes = &bytes[current_line.as_ptr() as usize - bytes.as_ptr() as usize..]; + } + + // skip lines that are comments + while is_comment_line(bytes, comment_prefix) { + bytes = skip_this_line_naive(bytes, eol_char); + } + + // skip header row + if has_header { + bytes = skip_this_line(bytes, quote_char, eol_char); + } + // skip 'n' rows following the header + if skip_rows_after_header > 0 { + let mut split_lines = SplitLines::new(bytes, quote_char, eol_char, comment_prefix); + let mut current_line = &bytes[..0]; + + for _ in 0..skip_rows_after_header { + current_line = split_lines + .next() + .ok_or_else(|| polars_err!(NoData: "not enough lines to skip"))?; + } + + current_line = split_lines + .next() + .unwrap_or(¤t_line[current_line.len()..]); + bytes = &bytes[current_line.as_ptr() as usize - bytes.as_ptr() as usize..]; + } + + Ok( + // Some of the functions we call may return `&'static []` instead of + // slices of `&bytes[..]`. + if bytes.is_empty() { + full_len + } else { + bytes.as_ptr() as usize - starting_point_offset + }, + ) +} diff --git a/crates/polars-io/src/csv/read/read_impl/batched.rs b/crates/polars-io/src/csv/read/read_impl/batched.rs new file mode 100644 index 000000000000..fa6021efa94d --- /dev/null +++ b/crates/polars-io/src/csv/read/read_impl/batched.rs @@ -0,0 +1,312 @@ +use std::collections::VecDeque; +use std::ops::Deref; + +use polars_core::POOL; +use polars_core::datatypes::Field; +use polars_core::frame::DataFrame; +use polars_core::schema::SchemaRef; +use polars_error::PolarsResult; +use polars_utils::IdxSize; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +use super::{CoreReader, CountLines, cast_columns, read_chunk}; +use crate::RowIndex; +use crate::csv::read::CsvReader; +use crate::csv::read::options::NullValuesCompiled; +use crate::mmap::{MmapBytesReader, ReaderBytes}; +use crate::prelude::{CsvParseOptions, update_row_counts2}; + +#[allow(clippy::too_many_arguments)] +pub(crate) fn get_file_chunks_iterator( + offsets: &mut VecDeque<(usize, usize)>, + last_pos: &mut usize, + n_chunks: usize, + chunk_size: &mut usize, + bytes: &[u8], + quote_char: Option, + eol_char: u8, +) { + let cl = CountLines::new(quote_char, eol_char); + + for _ in 0..n_chunks { + let bytes = &bytes[*last_pos..]; + + if bytes.is_empty() { + break; + } + + let position; + + loop { + let b = &bytes[..(*chunk_size).min(bytes.len())]; + let (count, position_) = cl.count(b); + + let (count, position_) = if b.len() == bytes.len() { + (if count != 0 { count } else { 1 }, b.len()) + } else { + ( + count, + if position_ < b.len() { + // 1+ for the '\n' + 1 + position_ + } else { + position_ + }, + ) + }; + + if count == 0 { + *chunk_size *= 2; + continue; + } + + position = position_; + break; + } + + offsets.push_back((*last_pos, *last_pos + position)); + *last_pos += position; + } +} + +struct ChunkOffsetIter<'a> { + bytes: &'a [u8], + offsets: VecDeque<(usize, usize)>, + last_offset: usize, + n_chunks: usize, + chunk_size: usize, + // not a promise, but something we want + #[allow(unused)] + rows_per_batch: usize, + quote_char: Option, + eol_char: u8, +} + +impl Iterator for ChunkOffsetIter<'_> { + type Item = (usize, usize); + + fn next(&mut self) -> Option { + match self.offsets.pop_front() { + Some(offsets) => Some(offsets), + None => { + if self.last_offset == self.bytes.len() { + return None; + } + get_file_chunks_iterator( + &mut self.offsets, + &mut self.last_offset, + self.n_chunks, + &mut self.chunk_size, + self.bytes, + self.quote_char, + self.eol_char, + ); + match self.offsets.pop_front() { + Some(offsets) => Some(offsets), + // We depleted the iterator. Ensure we deplete the slice as well + None => { + let out = Some((self.last_offset, self.bytes.len())); + self.last_offset = self.bytes.len(); + out + }, + } + }, + } + } +} + +impl<'a> CoreReader<'a> { + /// Create a batched csv reader that uses mmap to load data. + pub fn batched(mut self) -> PolarsResult> { + let reader_bytes = self.reader_bytes.take().unwrap(); + let bytes = reader_bytes.as_ref(); + let (bytes, starting_point_offset) = self.find_starting_point( + bytes, + self.parse_options.quote_char, + self.parse_options.eol_char, + )?; + + let n_threads = self.n_threads.unwrap_or_else(|| POOL.current_num_threads()); + + // Copied from [`Self::parse_csv`] + let n_parts_hint = n_threads * 16; + let chunk_size = std::cmp::min(bytes.len() / n_parts_hint, 16 * 1024 * 1024); + + // Use a small min chunk size to catch failures in tests. + #[cfg(debug_assertions)] + let min_chunk_size = 64; + #[cfg(not(debug_assertions))] + let min_chunk_size = 1024 * 4; + + let chunk_size = std::cmp::max(chunk_size, min_chunk_size); + + // this is arbitrarily chosen. + // we don't want this to depend on the thread pool size + // otherwise the chunks are not deterministic + let offset_batch_size = 16; + // extend lifetime. It is bound to `readerbytes` and we keep track of that + // lifetime so this is sound. + let bytes = unsafe { std::mem::transmute::<&[u8], &'static [u8]>(bytes) }; + let file_chunks = ChunkOffsetIter { + bytes, + offsets: VecDeque::with_capacity(offset_batch_size), + last_offset: 0, + n_chunks: offset_batch_size, + chunk_size, + rows_per_batch: self.chunk_size, + quote_char: self.parse_options.quote_char, + eol_char: self.parse_options.eol_char, + }; + + let projection = self.get_projection()?; + + // RAII structure that will ensure we maintain a global stringcache + #[cfg(feature = "dtype-categorical")] + let _cat_lock = if self.has_categorical { + Some(polars_core::StringCacheHolder::hold()) + } else { + None + }; + + #[cfg(not(feature = "dtype-categorical"))] + let _cat_lock = None; + + Ok(BatchedCsvReader { + reader_bytes, + parse_options: self.parse_options, + chunk_size: self.chunk_size, + file_chunks_iter: file_chunks, + file_chunks: vec![], + projection, + starting_point_offset, + row_index: self.row_index, + null_values: self.null_values, + to_cast: self.to_cast, + ignore_errors: self.ignore_errors, + remaining: self.n_rows.unwrap_or(usize::MAX), + schema: self.schema, + rows_read: 0, + _cat_lock, + }) + } +} + +pub struct BatchedCsvReader<'a> { + reader_bytes: ReaderBytes<'a>, + parse_options: CsvParseOptions, + chunk_size: usize, + file_chunks_iter: ChunkOffsetIter<'a>, + file_chunks: Vec<(usize, usize)>, + projection: Vec, + starting_point_offset: Option, + row_index: Option, + null_values: Option, + to_cast: Vec, + ignore_errors: bool, + remaining: usize, + schema: SchemaRef, + rows_read: IdxSize, + #[cfg(feature = "dtype-categorical")] + _cat_lock: Option, + #[cfg(not(feature = "dtype-categorical"))] + _cat_lock: Option, +} + +impl BatchedCsvReader<'_> { + pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { + if n == 0 || self.remaining == 0 { + return Ok(None); + } + + // get next `n` offset positions. + let file_chunks_iter = (&mut self.file_chunks_iter).take(n); + self.file_chunks.extend(file_chunks_iter); + // depleted the offsets iterator, we are done as well. + if self.file_chunks.is_empty() { + return Ok(None); + } + let chunks = &self.file_chunks; + + let mut bytes = self.reader_bytes.deref(); + if let Some(pos) = self.starting_point_offset { + bytes = &bytes[pos..]; + } + + let mut chunks = POOL.install(|| { + chunks + .into_par_iter() + .copied() + .map(|(bytes_offset_thread, stop_at_nbytes)| { + let mut df = read_chunk( + bytes, + &self.parse_options, + self.schema.as_ref(), + self.ignore_errors, + &self.projection, + bytes_offset_thread, + self.chunk_size, + self.null_values.as_ref(), + usize::MAX, + stop_at_nbytes, + self.starting_point_offset, + )?; + + cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; + + if let Some(rc) = &self.row_index { + unsafe { df.with_row_index_mut(rc.name.clone(), Some(rc.offset)) }; + } + Ok(df) + }) + .collect::>>() + })?; + self.file_chunks.clear(); + + if self.row_index.is_some() { + update_row_counts2(&mut chunks, self.rows_read) + } + for df in &mut chunks { + let h = df.height(); + + if self.remaining < h { + *df = df.slice(0, self.remaining) + }; + self.remaining = self.remaining.saturating_sub(h); + + self.rows_read += h as IdxSize; + } + Ok(Some(chunks)) + } +} + +pub struct OwnedBatchedCsvReader { + #[allow(dead_code)] + // this exist because we need to keep ownership + schema: SchemaRef, + batched_reader: BatchedCsvReader<'static>, + // keep ownership + _reader: CsvReader>, +} + +impl OwnedBatchedCsvReader { + pub fn next_batches(&mut self, n: usize) -> PolarsResult>> { + self.batched_reader.next_batches(n) + } +} + +pub fn to_batched_owned( + mut reader: CsvReader>, +) -> PolarsResult { + let batched_reader = reader.batched_borrowed()?; + let schema = batched_reader.schema.clone(); + // If you put a drop(reader) here, rust will complain that reader is borrowed, + // so we presumably have to keep ownership of it to maintain the safety of the + // 'static transmute. + let batched_reader: BatchedCsvReader<'static> = unsafe { std::mem::transmute(batched_reader) }; + + Ok(OwnedBatchedCsvReader { + schema, + batched_reader, + _reader: reader, + }) +} diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs new file mode 100644 index 000000000000..980fcf5b4f5e --- /dev/null +++ b/crates/polars-io/src/csv/read/reader.rs @@ -0,0 +1,259 @@ +use std::fs::File; +use std::path::PathBuf; + +use polars_core::prelude::*; + +use super::options::CsvReadOptions; +use super::read_impl::CoreReader; +use super::read_impl::batched::to_batched_owned; +use super::{BatchedCsvReader, OwnedBatchedCsvReader}; +use crate::mmap::MmapBytesReader; +use crate::path_utils::resolve_homedir; +use crate::predicates::PhysicalIoExpr; +use crate::shared::SerReader; +use crate::utils::get_reader_bytes; + +/// Create a new DataFrame by reading a csv file. +/// +/// # Example +/// +/// ``` +/// use polars_core::prelude::*; +/// use polars_io::prelude::*; +/// use std::fs::File; +/// +/// fn example() -> PolarsResult { +/// CsvReadOptions::default() +/// .with_has_header(true) +/// .try_into_reader_with_file_path(Some("iris.csv".into()))? +/// .finish() +/// } +/// ``` +#[must_use] +pub struct CsvReader +where + R: MmapBytesReader, +{ + /// File or Stream object. + reader: R, + /// Options for the CSV reader. + options: CsvReadOptions, + predicate: Option>, +} + +impl CsvReader +where + R: MmapBytesReader, +{ + pub fn _with_predicate(mut self, predicate: Option>) -> Self { + self.predicate = predicate; + self + } + + // TODO: Investigate if we can remove this + pub(crate) fn with_schema(mut self, schema: SchemaRef) -> Self { + self.options.schema = Some(schema); + self + } +} + +impl CsvReadOptions { + /// Creates a CSV reader using a file path. + /// + /// # Panics + /// If both self.path and the path parameter are non-null. Only one of them is + /// to be non-null. + pub fn try_into_reader_with_file_path( + mut self, + path: Option, + ) -> PolarsResult> { + if self.path.is_some() { + assert!( + path.is_none(), + "impl error: only 1 of self.path or the path parameter is to be non-null" + ); + } else { + self.path = path; + }; + + assert!( + self.path.is_some(), + "impl error: either one of self.path or the path parameter is to be non-null" + ); + + let path = resolve_homedir(self.path.as_ref().unwrap()); + let reader = polars_utils::open_file(&path)?; + let options = self; + + Ok(CsvReader { + reader, + options, + predicate: None, + }) + } + + /// Creates a CSV reader using a file handle. + pub fn into_reader_with_file_handle(self, reader: R) -> CsvReader { + let options = self; + + CsvReader { + reader, + options, + predicate: Default::default(), + } + } +} + +impl CsvReader { + fn core_reader(&mut self) -> PolarsResult { + let reader_bytes = get_reader_bytes(&mut self.reader)?; + + let parse_options = self.options.get_parse_options(); + + CoreReader::new( + reader_bytes, + parse_options, + self.options.n_rows, + self.options.skip_rows, + self.options.skip_lines, + self.options.projection.clone().map(|x| x.as_ref().clone()), + self.options.infer_schema_length, + self.options.has_header, + self.options.ignore_errors, + self.options.schema.clone(), + self.options.columns.clone(), + self.options.n_threads, + self.options.schema_overwrite.clone(), + self.options.dtype_overwrite.clone(), + self.options.chunk_size, + self.predicate.clone(), + self.options.fields_to_cast.clone(), + self.options.skip_rows_after_header, + self.options.row_index.clone(), + self.options.raise_if_empty, + ) + } + + pub fn batched_borrowed(&mut self) -> PolarsResult { + let csv_reader = self.core_reader()?; + csv_reader.batched() + } +} + +impl CsvReader> { + pub fn batched(mut self, schema: Option) -> PolarsResult { + if let Some(schema) = schema { + self = self.with_schema(schema); + } + + to_batched_owned(self) + } +} + +impl SerReader for CsvReader +where + R: MmapBytesReader, +{ + /// Create a new CsvReader from a file/stream using default read options. To + /// use non-default read options, first construct [CsvReadOptions] and then use + /// any of the `(try)_into_` methods. + fn new(reader: R) -> Self { + CsvReader { + reader, + options: Default::default(), + predicate: None, + } + } + + /// Read the file and create the DataFrame. + fn finish(mut self) -> PolarsResult { + let rechunk = self.options.rechunk; + let low_memory = self.options.low_memory; + + let csv_reader = self.core_reader()?; + let mut df = csv_reader.finish()?; + + // Important that this rechunk is never done in parallel. + // As that leads to great memory overhead. + if rechunk && df.first_col_n_chunks() > 1 { + if low_memory { + df.as_single_chunk(); + } else { + df.as_single_chunk_par(); + } + } + + Ok(df) + } +} + +impl CsvReader { + /// Sets custom CSV read options. + pub fn with_options(mut self, options: CsvReadOptions) -> Self { + self.options = options; + self + } +} + +/// Splits datatypes that cannot be natively read into a `fields_to_cast` for +/// post-read casting. +/// +/// # Returns +/// `has_categorical` +pub fn prepare_csv_schema( + schema: &mut SchemaRef, + fields_to_cast: &mut Vec, +) -> PolarsResult { + // This branch we check if there are dtypes we cannot parse. + // We only support a few dtypes in the parser and later cast to the required dtype + let mut _has_categorical = false; + + let mut changed = false; + + let new_schema = schema + .iter_fields() + .map(|mut fld| { + use DataType::*; + + let mut matched = true; + + let out = match fld.dtype() { + Time => { + fields_to_cast.push(fld.clone()); + fld.coerce(String); + PolarsResult::Ok(fld) + }, + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) => { + _has_categorical = true; + PolarsResult::Ok(fld) + }, + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => match (precision, scale) { + (_, Some(_)) => { + fields_to_cast.push(fld.clone()); + fld.coerce(String); + PolarsResult::Ok(fld) + }, + _ => Err(PolarsError::ComputeError( + "'scale' must be set when reading csv column as Decimal".into(), + )), + }, + _ => { + matched = false; + PolarsResult::Ok(fld) + }, + }?; + + changed |= matched; + + PolarsResult::Ok(out) + }) + .collect::>()?; + + if changed { + *schema = Arc::new(new_schema); + } + + Ok(_has_categorical) +} diff --git a/crates/polars-io/src/csv/read/schema_inference.rs b/crates/polars-io/src/csv/read/schema_inference.rs new file mode 100644 index 000000000000..ea21ef5926b7 --- /dev/null +++ b/crates/polars-io/src/csv/read/schema_inference.rs @@ -0,0 +1,571 @@ +use std::borrow::Cow; + +use polars_core::prelude::*; +#[cfg(feature = "polars-time")] +use polars_time::chunkedarray::string::infer as date_infer; +#[cfg(feature = "polars-time")] +use polars_time::prelude::string::Pattern; +use polars_utils::format_pl_smallstr; + +use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending}; +use super::splitfields::SplitFields; +use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues}; +use crate::csv::read::parser::skip_lines_naive; +use crate::mmap::ReaderBytes; +use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE}; + +#[derive(Clone, Debug, Default)] +pub struct SchemaInferenceResult { + inferred_schema: SchemaRef, + rows_read: usize, + bytes_read: usize, + bytes_total: usize, + n_threads: Option, +} + +impl SchemaInferenceResult { + pub fn try_from_reader_bytes_and_options( + reader_bytes: &ReaderBytes, + options: &CsvReadOptions, + ) -> PolarsResult { + let parse_options = options.get_parse_options(); + + let infer_schema_length = options.infer_schema_length; + let has_header = options.has_header; + let schema_overwrite_arc = options.schema_overwrite.clone(); + let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref()); + let skip_rows = options.skip_rows; + let skip_lines = options.skip_lines; + let skip_rows_after_header = options.skip_rows_after_header; + let raise_if_empty = options.raise_if_empty; + let n_threads = options.n_threads; + + let bytes_total = reader_bytes.len(); + + let (inferred_schema, rows_read, bytes_read) = infer_file_schema( + reader_bytes, + &parse_options, + infer_schema_length, + has_header, + schema_overwrite, + skip_rows, + skip_lines, + skip_rows_after_header, + raise_if_empty, + )?; + + let this = Self { + inferred_schema: Arc::new(inferred_schema), + rows_read, + bytes_read, + bytes_total, + n_threads, + }; + + Ok(this) + } + + pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self { + self.inferred_schema = inferred_schema; + self + } + + pub fn get_inferred_schema(&self) -> SchemaRef { + self.inferred_schema.clone() + } + + pub fn get_estimated_n_rows(&self) -> usize { + (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize + } +} + +impl CsvReadOptions { + /// Note: This does not update the schema from the inference result. + pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) { + self.n_threads = si_result.n_threads; + } +} + +pub fn finish_infer_field_schema(possibilities: &PlHashSet) -> DataType { + // determine data type based on possible types + // if there are incompatible types, use DataType::String + match possibilities.len() { + 1 => possibilities.iter().next().unwrap().clone(), + 2 if possibilities.contains(&DataType::Int64) + && possibilities.contains(&DataType::Float64) => + { + // we have an integer and double, fall down to double + DataType::Float64 + }, + // default to String for conflicting datatypes (e.g bool and int) + _ => DataType::String, + } +} + +/// Infer the data type of a record +pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType { + // when quoting is enabled in the reader, these quotes aren't escaped, we default to + // String for them + let bytes = string.as_bytes(); + if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' { + if try_parse_dates { + #[cfg(feature = "polars-time")] + { + match date_infer::infer_pattern_single(&string[1..string.len() - 1]) { + Some(pattern_with_offset) => match pattern_with_offset { + Pattern::DatetimeYMD | Pattern::DatetimeDMY => { + DataType::Datetime(TimeUnit::Microseconds, None) + }, + Pattern::DateYMD | Pattern::DateDMY => DataType::Date, + Pattern::DatetimeYMDZ => DataType::Datetime( + TimeUnit::Microseconds, + Some(PlSmallStr::from_static("UTC")), + ), + Pattern::Time => DataType::Time, + }, + None => DataType::String, + } + } + #[cfg(not(feature = "polars-time"))] + { + panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features") + } + } else { + DataType::String + } + } + // match regex in a particular order + else if BOOLEAN_RE.is_match(string) { + DataType::Boolean + } else if !decimal_comma && FLOAT_RE.is_match(string) + || decimal_comma && FLOAT_RE_DECIMAL.is_match(string) + { + DataType::Float64 + } else if INTEGER_RE.is_match(string) { + DataType::Int64 + } else if try_parse_dates { + #[cfg(feature = "polars-time")] + { + match date_infer::infer_pattern_single(string) { + Some(pattern_with_offset) => match pattern_with_offset { + Pattern::DatetimeYMD | Pattern::DatetimeDMY => { + DataType::Datetime(TimeUnit::Microseconds, None) + }, + Pattern::DateYMD | Pattern::DateDMY => DataType::Date, + Pattern::DatetimeYMDZ => DataType::Datetime( + TimeUnit::Microseconds, + Some(PlSmallStr::from_static("UTC")), + ), + Pattern::Time => DataType::Time, + }, + None => DataType::String, + } + } + #[cfg(not(feature = "polars-time"))] + { + panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features") + } + } else { + DataType::String + } +} + +#[inline] +fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult> { + Ok(match encoding { + CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes) + .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))? + .into(), + CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes), + }) +} + +fn column_name(i: usize) -> PlSmallStr { + format_pl_smallstr!("column_{}", i + 1) +} + +#[allow(clippy::too_many_arguments)] +fn infer_file_schema_inner( + reader_bytes: &ReaderBytes, + parse_options: &CsvParseOptions, + max_read_rows: Option, + has_header: bool, + schema_overwrite: Option<&Schema>, + // we take &mut because we maybe need to skip more rows dependent + // on the schema inference + mut skip_rows: usize, + skip_rows_after_header: usize, + recursion_count: u8, + raise_if_empty: bool, +) -> PolarsResult<(Schema, usize, usize)> { + // keep track so that we can determine the amount of bytes read + let start_ptr = reader_bytes.as_ptr() as usize; + + // We use lossy utf8 here because we don't want the schema inference to fail on utf8. + // It may later. + let encoding = CsvEncoding::LossyUtf8; + + let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char); + if raise_if_empty { + polars_ensure!(!bytes.is_empty(), NoData: "empty CSV"); + }; + let mut lines = SplitLines::new( + bytes, + parse_options.quote_char, + parse_options.eol_char, + parse_options.comment_prefix.as_ref(), + ) + .skip(skip_rows); + + // get or create header names + // when has_header is false, creates default column names with column_ prefix + + // skip lines that are comments + let mut first_line = None; + + for (i, line) in (&mut lines).enumerate() { + if !is_comment_line(line, parse_options.comment_prefix.as_ref()) { + first_line = Some(line); + skip_rows += i; + break; + } + } + + if first_line.is_none() { + first_line = lines.next(); + } + + // now that we've found the first non-comment line we parse the headers, or we create a header + let mut headers: Vec = if let Some(mut header_line) = first_line { + let len = header_line.len(); + if len > 1 { + // remove carriage return + let trailing_byte = header_line[len - 1]; + if trailing_byte == b'\r' { + header_line = &header_line[..len - 1]; + } + } + + let byterecord = SplitFields::new( + header_line, + parse_options.separator, + parse_options.quote_char, + parse_options.eol_char, + ); + if has_header { + let headers = byterecord + .map(|(slice, needs_escaping)| { + let slice_escaped = if needs_escaping && (slice.len() >= 2) { + &slice[1..(slice.len() - 1)] + } else { + slice + }; + let s = parse_bytes_with_encoding(slice_escaped, encoding)?; + Ok(s) + }) + .collect::>>()?; + + let mut final_headers = Vec::with_capacity(headers.len()); + + let mut header_names = PlHashMap::with_capacity(headers.len()); + + for name in &headers { + let count = header_names.entry(name.as_ref()).or_insert(0usize); + if *count != 0 { + final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1)) + } else { + final_headers.push(PlSmallStr::from_str(name)) + } + *count += 1; + } + final_headers + } else { + byterecord + .enumerate() + .map(|(i, _s)| column_name(i)) + .collect::>() + } + } else if has_header && !bytes.is_empty() && recursion_count == 0 { + // there was no new line char. So we copy the whole buf and add one + // this is likely to be cheap as there are no rows. + let mut buf = Vec::with_capacity(bytes.len() + 2); + buf.extend_from_slice(bytes); + buf.push(parse_options.eol_char); + + return infer_file_schema_inner( + &ReaderBytes::Owned(buf.into()), + parse_options, + max_read_rows, + has_header, + schema_overwrite, + skip_rows, + skip_rows_after_header, + recursion_count + 1, + raise_if_empty, + ); + } else if !raise_if_empty { + return Ok((Schema::default(), 0, 0)); + } else { + polars_bail!(NoData: "empty CSV"); + }; + if !has_header { + // re-init lines so that the header is included in type inference. + lines = SplitLines::new( + bytes, + parse_options.quote_char, + parse_options.eol_char, + parse_options.comment_prefix.as_ref(), + ) + .skip(skip_rows); + } + + // keep track of inferred field types + let mut column_types: Vec> = + vec![PlHashSet::with_capacity(4); headers.len()]; + // keep track of columns with nulls + let mut nulls: Vec = vec![false; headers.len()]; + + let mut rows_count = 0; + let mut fields = Vec::with_capacity(headers.len()); + + // needed to prevent ownership going into the iterator loop + let records_ref = &mut lines; + + let mut end_ptr = start_ptr; + for mut line in records_ref + .take(match max_read_rows { + Some(max_read_rows) => { + if max_read_rows <= (usize::MAX - skip_rows_after_header) { + // read skip_rows_after_header more rows for inferring + // the correct schema as the first skip_rows_after_header + // rows will be skipped + max_read_rows + skip_rows_after_header + } else { + max_read_rows + } + }, + None => usize::MAX, + }) + .skip(skip_rows_after_header) + { + rows_count += 1; + // keep track so that we can determine the amount of bytes read + end_ptr = line.as_ptr() as usize + line.len(); + + if line.is_empty() { + continue; + } + + // line is a comment -> skip + if is_comment_line(line, parse_options.comment_prefix.as_ref()) { + continue; + } + + let len = line.len(); + if len > 1 { + // remove carriage return + let trailing_byte = line[len - 1]; + if trailing_byte == b'\r' { + line = &line[..len - 1]; + } + } + + let record = SplitFields::new( + line, + parse_options.separator, + parse_options.quote_char, + parse_options.eol_char, + ); + + for (i, (slice, needs_escaping)) in record.enumerate() { + // When `has_header = False` and `` + // Increase the schema if the first line didn't have all columns. + if i >= headers.len() { + if !has_header { + headers.push(column_name(i)); + column_types.push(Default::default()); + nulls.push(false); + } else { + break; + } + } + + if slice.is_empty() { + unsafe { *nulls.get_unchecked_mut(i) = true }; + } else { + let slice_escaped = if needs_escaping && (slice.len() >= 2) { + &slice[1..(slice.len() - 1)] + } else { + slice + }; + let s = parse_bytes_with_encoding(slice_escaped, encoding)?; + let dtype = match &parse_options.null_values { + None => Some(infer_field_schema( + &s, + parse_options.try_parse_dates, + parse_options.decimal_comma, + )), + Some(NullValues::AllColumns(names)) => { + if !names.iter().any(|nv| nv == s.as_ref()) { + Some(infer_field_schema( + &s, + parse_options.try_parse_dates, + parse_options.decimal_comma, + )) + } else { + None + } + }, + Some(NullValues::AllColumnsSingle(name)) => { + if s.as_ref() != name.as_str() { + Some(infer_field_schema( + &s, + parse_options.try_parse_dates, + parse_options.decimal_comma, + )) + } else { + None + } + }, + Some(NullValues::Named(names)) => { + // SAFETY: + // we iterate over headers length. + let current_name = unsafe { headers.get_unchecked(i) }; + let null_name = &names.iter().find(|name| name.0 == current_name); + + if let Some(null_name) = null_name { + if null_name.1.as_str() != s.as_ref() { + Some(infer_field_schema( + &s, + parse_options.try_parse_dates, + parse_options.decimal_comma, + )) + } else { + None + } + } else { + Some(infer_field_schema( + &s, + parse_options.try_parse_dates, + parse_options.decimal_comma, + )) + } + }, + }; + if let Some(dtype) = dtype { + unsafe { column_types.get_unchecked_mut(i).insert(dtype) }; + } + } + } + } + + // build schema from inference results + for i in 0..headers.len() { + let field_name = &headers[i]; + + if let Some(schema_overwrite) = schema_overwrite { + if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) { + fields.push(Field::new(name.clone(), dtype.clone())); + continue; + } + + // column might have been renamed + // execute only if schema is complete + if schema_overwrite.len() == headers.len() { + if let Some((name, dtype)) = schema_overwrite.get_at_index(i) { + fields.push(Field::new(name.clone(), dtype.clone())); + continue; + } + } + } + + let possibilities = &column_types[i]; + let dtype = finish_infer_field_schema(possibilities); + fields.push(Field::new(field_name.clone(), dtype)); + } + // if there is a single line after the header without an eol + // we copy the bytes add an eol and rerun this function + // so that the inference is consistent with and without eol char + if rows_count == 0 + && !reader_bytes.is_empty() + && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char + && recursion_count == 0 + { + let mut rb = Vec::with_capacity(reader_bytes.len() + 1); + rb.extend_from_slice(reader_bytes); + rb.push(parse_options.eol_char); + return infer_file_schema_inner( + &ReaderBytes::Owned(rb.into()), + parse_options, + max_read_rows, + has_header, + schema_overwrite, + skip_rows, + skip_rows_after_header, + recursion_count + 1, + raise_if_empty, + ); + } + + Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr)) +} + +pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> { + if decimal_comma { + polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator") + } + Ok(()) +} + +/// Infer the schema of a CSV file by reading through the first n rows of the file, +/// with `max_read_rows` controlling the maximum number of rows to read. +/// +/// If `max_read_rows` is not set, the whole file is read to infer its schema. +/// +/// Returns +/// - inferred schema +/// - number of rows used for inference. +/// - bytes read +#[allow(clippy::too_many_arguments)] +pub fn infer_file_schema( + reader_bytes: &ReaderBytes, + parse_options: &CsvParseOptions, + max_read_rows: Option, + has_header: bool, + schema_overwrite: Option<&Schema>, + skip_rows: usize, + skip_lines: usize, + skip_rows_after_header: usize, + raise_if_empty: bool, +) -> PolarsResult<(Schema, usize, usize)> { + check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?; + + if skip_lines > 0 { + polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set"); + let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines); + let reader_bytes = ReaderBytes::Borrowed(bytes); + infer_file_schema_inner( + &reader_bytes, + parse_options, + max_read_rows, + has_header, + schema_overwrite, + skip_rows, + skip_rows_after_header, + 0, + raise_if_empty, + ) + } else { + infer_file_schema_inner( + reader_bytes, + parse_options, + max_read_rows, + has_header, + schema_overwrite, + skip_rows, + skip_rows_after_header, + 0, + raise_if_empty, + ) + } +} diff --git a/crates/polars-io/src/csv/read/splitfields.rs b/crates/polars-io/src/csv/read/splitfields.rs new file mode 100644 index 000000000000..002230eaf458 --- /dev/null +++ b/crates/polars-io/src/csv/read/splitfields.rs @@ -0,0 +1,432 @@ +#![allow(unsafe_op_in_unsafe_fn)] +#[cfg(not(feature = "simd"))] +mod inner { + /// An adapted version of std::iter::Split. + /// This exists solely because we cannot split the lines naively as + pub(crate) struct SplitFields<'a> { + v: &'a [u8], + separator: u8, + finished: bool, + quote_char: u8, + quoting: bool, + eol_char: u8, + } + + impl<'a> SplitFields<'a> { + pub(crate) fn new( + slice: &'a [u8], + separator: u8, + quote_char: Option, + eol_char: u8, + ) -> Self { + Self { + v: slice, + separator, + finished: false, + quote_char: quote_char.unwrap_or(b'"'), + quoting: quote_char.is_some(), + eol_char, + } + } + + unsafe fn finish_eol( + &mut self, + need_escaping: bool, + idx: usize, + ) -> Option<(&'a [u8], bool)> { + self.finished = true; + debug_assert!(idx <= self.v.len()); + Some((self.v.get_unchecked(..idx), need_escaping)) + } + + fn finish(&mut self, need_escaping: bool) -> Option<(&'a [u8], bool)> { + self.finished = true; + Some((self.v, need_escaping)) + } + + fn eof_oel(&self, current_ch: u8) -> bool { + current_ch == self.separator || current_ch == self.eol_char + } + } + + impl<'a> Iterator for SplitFields<'a> { + // the bool is used to indicate that it requires escaping + type Item = (&'a [u8], bool); + + #[inline] + fn next(&mut self) -> Option<(&'a [u8], bool)> { + if self.finished { + return None; + } else if self.v.is_empty() { + return self.finish(false); + } + + let mut needs_escaping = false; + // There can be strings with separators: + // "Street, City", + + // SAFETY: + // we have checked bounds + let pos = if self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char { + needs_escaping = true; + // There can be pair of double-quotes within string. + // Each of the embedded double-quote characters must be represented + // by a pair of double-quote characters: + // e.g. 1997,Ford,E350,"Super, ""luxurious"" truck",20020 + + // denotes if we are in a string field, started with a quote + let mut in_field = false; + + let mut idx = 0u32; + let mut current_idx = 0u32; + // micro optimizations + #[allow(clippy::explicit_counter_loop)] + for &c in self.v.iter() { + if c == self.quote_char { + // toggle between string field enclosure + // if we encounter a starting '"' -> in_field = true; + // if we encounter a closing '"' -> in_field = false; + in_field = !in_field; + } + + if !in_field && self.eof_oel(c) { + if c == self.eol_char { + // SAFETY: + // we are in bounds + return unsafe { + self.finish_eol(needs_escaping, current_idx as usize) + }; + } + idx = current_idx; + break; + } + current_idx += 1; + } + + if idx == 0 { + return self.finish(needs_escaping); + } + + idx as usize + } else { + match self.v.iter().position(|&c| self.eof_oel(c)) { + None => return self.finish(needs_escaping), + Some(idx) => unsafe { + // SAFETY: + // idx was just found + if *self.v.get_unchecked(idx) == self.eol_char { + return self.finish_eol(needs_escaping, idx); + } else { + idx + } + }, + } + }; + + unsafe { + debug_assert!(pos <= self.v.len()); + // SAFETY: + // we are in bounds + let ret = Some((self.v.get_unchecked(..pos), needs_escaping)); + self.v = self.v.get_unchecked(pos + 1..); + ret + } + } + } +} + +#[cfg(feature = "simd")] +mod inner { + use std::simd::prelude::*; + + use polars_utils::clmul::prefix_xorsum_inclusive; + + const SIMD_SIZE: usize = 64; + type SimdVec = u8x64; + + /// An adapted version of std::iter::Split. + /// This exists solely because we cannot split the lines naively as + pub(crate) struct SplitFields<'a> { + pub v: &'a [u8], + separator: u8, + pub finished: bool, + quote_char: u8, + quoting: bool, + eol_char: u8, + simd_separator: SimdVec, + simd_eol_char: SimdVec, + simd_quote_char: SimdVec, + previous_valid_ends: u64, + } + + impl<'a> SplitFields<'a> { + pub(crate) fn new( + slice: &'a [u8], + separator: u8, + quote_char: Option, + eol_char: u8, + ) -> Self { + let simd_separator = SimdVec::splat(separator); + let simd_eol_char = SimdVec::splat(eol_char); + let quoting = quote_char.is_some(); + let quote_char = quote_char.unwrap_or(b'"'); + let simd_quote_char = SimdVec::splat(quote_char); + + Self { + v: slice, + separator, + finished: false, + quote_char, + quoting, + eol_char, + simd_separator, + simd_eol_char, + simd_quote_char, + previous_valid_ends: 0, + } + } + + unsafe fn finish_eol( + &mut self, + need_escaping: bool, + pos: usize, + ) -> Option<(&'a [u8], bool)> { + self.finished = true; + debug_assert!(pos <= self.v.len()); + Some((self.v.get_unchecked(..pos), need_escaping)) + } + + #[inline] + fn finish(&mut self, need_escaping: bool) -> Option<(&'a [u8], bool)> { + self.finished = true; + Some((self.v, need_escaping)) + } + + fn eof_oel(&self, current_ch: u8) -> bool { + current_ch == self.separator || current_ch == self.eol_char + } + } + + impl<'a> Iterator for SplitFields<'a> { + // the bool is used to indicate that it requires escaping + type Item = (&'a [u8], bool); + + #[inline] + fn next(&mut self) -> Option<(&'a [u8], bool)> { + // This must be before we check the cached value + if self.finished { + return None; + } + // Then check cached value as this is hot. + if self.previous_valid_ends != 0 { + let pos = self.previous_valid_ends.trailing_zeros() as usize; + self.previous_valid_ends >>= (pos + 1) as u64; + + unsafe { + debug_assert!(pos < self.v.len()); + // SAFETY: + // we are in bounds + let needs_escaping = self + .v + .first() + .map(|c| *c == self.quote_char && self.quoting) + .unwrap_or(false); + + if *self.v.get_unchecked(pos) == self.eol_char { + return self.finish_eol(needs_escaping, pos); + } + + let bytes = self.v.get_unchecked(..pos); + + self.v = self.v.get_unchecked(pos + 1..); + let ret = Some((bytes, needs_escaping)); + + return ret; + } + } + if self.v.is_empty() { + return self.finish(false); + } + + let mut needs_escaping = false; + // There can be strings with separators: + // "Street, City", + + // SAFETY: + // we have checked bounds + let pos = if self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char { + let mut total_idx = 0; + needs_escaping = true; + let mut not_in_field_previous_iter = true; + + loop { + let bytes = unsafe { self.v.get_unchecked(total_idx..) }; + + if bytes.len() > SIMD_SIZE { + let lane: [u8; SIMD_SIZE] = unsafe { + bytes + .get_unchecked(0..SIMD_SIZE) + .try_into() + .unwrap_unchecked() + }; + let simd_bytes = SimdVec::from(lane); + let has_eol = simd_bytes.simd_eq(self.simd_eol_char); + let has_sep = simd_bytes.simd_eq(self.simd_separator); + let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask(); + let mut end_mask = (has_sep | has_eol).to_bitmask(); + + let mut not_in_quote_field = prefix_xorsum_inclusive(quote_mask); + + if not_in_field_previous_iter { + not_in_quote_field = !not_in_quote_field; + } + not_in_field_previous_iter = + (not_in_quote_field & (1 << (SIMD_SIZE - 1))) > 0; + end_mask &= not_in_quote_field; + + if end_mask != 0 { + let pos = end_mask.trailing_zeros() as usize; + total_idx += pos; + debug_assert!( + self.v[total_idx] == self.eol_char + || self.v[total_idx] == self.separator + ); + + if pos == SIMD_SIZE - 1 { + self.previous_valid_ends = 0; + } else { + self.previous_valid_ends = end_mask >> (pos + 1) as u64; + } + + break; + } else { + total_idx += SIMD_SIZE; + } + } else { + // There can be a pair of double-quotes within a string. + // Each of the embedded double-quote characters must be represented + // by a pair of double-quote characters: + // e.g. 1997,Ford,E350,"Super, ""luxurious"" truck",20020 + + // denotes if we are in a string field, started with a quote + let mut in_field = !not_in_field_previous_iter; + + // usize::MAX is unset. + let mut idx = usize::MAX; + let mut current_idx = 0; + // micro optimizations + #[allow(clippy::explicit_counter_loop)] + for &c in bytes.iter() { + if c == self.quote_char { + // toggle between string field enclosure + // if we encounter a starting '"' -> in_field = true; + // if we encounter a closing '"' -> in_field = false; + in_field = !in_field; + } + + if !in_field && self.eof_oel(c) { + if c == self.eol_char { + // SAFETY: + // we are in bounds + return unsafe { + self.finish_eol(needs_escaping, current_idx + total_idx) + }; + } + idx = current_idx; + break; + } + current_idx += 1; + } + + if idx == usize::MAX { + return self.finish(needs_escaping); + } + + total_idx += idx; + debug_assert!( + self.v[total_idx] == self.eol_char + || self.v[total_idx] == self.separator + ); + break; + } + } + total_idx + } else { + let mut total_idx = 0; + + loop { + let bytes = unsafe { self.v.get_unchecked(total_idx..) }; + + if bytes.len() > SIMD_SIZE { + let lane: [u8; SIMD_SIZE] = unsafe { + bytes + .get_unchecked(0..SIMD_SIZE) + .try_into() + .unwrap_unchecked() + }; + let simd_bytes = SimdVec::from(lane); + let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char); + let has_separator = simd_bytes.simd_eq(self.simd_separator); + let has_any_mask = (has_separator | has_eol_char).to_bitmask(); + + if has_any_mask != 0 { + total_idx += has_any_mask.trailing_zeros() as usize; + break; + } else { + total_idx += SIMD_SIZE; + } + } else { + match bytes.iter().position(|&c| self.eof_oel(c)) { + None => return self.finish(needs_escaping), + Some(idx) => { + total_idx += idx; + break; + }, + } + } + } + unsafe { + if *self.v.get_unchecked(total_idx) == self.eol_char { + return self.finish_eol(needs_escaping, total_idx); + } else { + total_idx + } + } + }; + + unsafe { + debug_assert!(pos < self.v.len()); + // SAFETY: + // we are in bounds + let ret = Some((self.v.get_unchecked(..pos), needs_escaping)); + self.v = self.v.get_unchecked(pos + 1..); + ret + } + } + } +} + +pub(crate) use inner::SplitFields; + +#[cfg(test)] +mod test { + use super::SplitFields; + + #[test] + fn test_splitfields() { + let input = "\"foo\",\"bar\""; + let mut fields = SplitFields::new(input.as_bytes(), b',', Some(b'"'), b'\n'); + + assert_eq!(fields.next(), Some(("\"foo\"".as_bytes(), true))); + assert_eq!(fields.next(), Some(("\"bar\"".as_bytes(), true))); + assert_eq!(fields.next(), None); + + let input2 = "\"foo\n bar\";\"baz\";12345"; + let mut fields2 = SplitFields::new(input2.as_bytes(), b';', Some(b'"'), b'\n'); + + assert_eq!(fields2.next(), Some(("\"foo\n bar\"".as_bytes(), true))); + assert_eq!(fields2.next(), Some(("\"baz\"".as_bytes(), true))); + assert_eq!(fields2.next(), Some(("12345".as_bytes(), false))); + assert_eq!(fields2.next(), None); + } +} diff --git a/crates/polars-io/src/csv/read/utils.rs b/crates/polars-io/src/csv/read/utils.rs new file mode 100644 index 000000000000..5f66ab77f825 --- /dev/null +++ b/crates/polars-io/src/csv/read/utils.rs @@ -0,0 +1,208 @@ +#![allow(unsafe_op_in_unsafe_fn)] +#[cfg(feature = "decompress")] +use std::io::Read; +use std::mem::MaybeUninit; + +use super::parser::next_line_position; +#[cfg(feature = "decompress")] +use super::parser::next_line_position_naive; +use super::splitfields::SplitFields; + +/// TODO: Remove this in favor of parallel CountLines::analyze_chunk +/// +/// (see https://github.com/pola-rs/polars/issues/19078) +pub(crate) fn get_file_chunks( + bytes: &[u8], + n_chunks: usize, + expected_fields: Option, + separator: u8, + quote_char: Option, + eol_char: u8, +) -> Vec<(usize, usize)> { + let mut last_pos = 0; + let total_len = bytes.len(); + let chunk_size = total_len / n_chunks; + let mut offsets = Vec::with_capacity(n_chunks); + for _ in 0..n_chunks { + let search_pos = last_pos + chunk_size; + + if search_pos >= bytes.len() { + break; + } + + let end_pos = match next_line_position( + &bytes[search_pos..], + expected_fields, + separator, + quote_char, + eol_char, + ) { + Some(pos) => search_pos + pos, + None => { + break; + }, + }; + offsets.push((last_pos, end_pos)); + last_pos = end_pos; + } + offsets.push((last_pos, total_len)); + offsets +} + +#[cfg(feature = "decompress")] +fn decompress_impl( + decoder: &mut R, + n_rows: Option, + separator: u8, + quote_char: Option, + eol_char: u8, +) -> Option> { + let chunk_size = 4096; + Some(match n_rows { + None => { + // decompression in a preallocated buffer does not work with zlib-ng + // and will put the original compressed data in the buffer. + let mut out = Vec::new(); + decoder.read_to_end(&mut out).ok()?; + out + }, + Some(n_rows) => { + // we take the first rows first '\n\ + let mut out = vec![]; + let mut expected_fields = 0; + // make sure that we have enough bytes to decode the header (even if it has embedded new line chars) + // those extra bytes in the buffer don't matter, we don't need to track them + loop { + let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?; + if read == 0 { + break; + } + if next_line_position_naive(&out, eol_char).is_some() { + // an extra shot + let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?; + if read == 0 { + break; + } + // now that we have enough, we compute the number of fields (also takes embedding into account) + expected_fields = + SplitFields::new(&out, separator, quote_char, eol_char).count(); + break; + } + } + + let mut line_count = 0; + let mut buf_pos = 0; + // keep decoding bytes and count lines + // keep track of the n_rows we read + while line_count < n_rows { + match next_line_position( + &out[buf_pos + 1..], + Some(expected_fields), + separator, + quote_char, + eol_char, + ) { + Some(pos) => { + line_count += 1; + buf_pos += pos; + }, + None => { + // take more bytes so that we might find a new line the next iteration + let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?; + // we depleted the reader + if read == 0 { + break; + } + continue; + }, + }; + } + if line_count == n_rows { + out.truncate(buf_pos); // retain only first n_rows in out + } + out + }, + }) +} + +#[cfg(feature = "decompress")] +pub(crate) fn decompress( + bytes: &[u8], + n_rows: Option, + separator: u8, + quote_char: Option, + eol_char: u8, +) -> Option> { + use crate::utils::compression::SupportedCompression; + + if let Some(algo) = SupportedCompression::check(bytes) { + match algo { + SupportedCompression::GZIP => { + let mut decoder = flate2::read::MultiGzDecoder::new(bytes); + decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) + }, + SupportedCompression::ZLIB => { + let mut decoder = flate2::read::ZlibDecoder::new(bytes); + decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) + }, + SupportedCompression::ZSTD => { + let mut decoder = zstd::Decoder::with_buffer(bytes).ok()?; + decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char) + }, + } + } else { + None + } +} + +/// replace double quotes by single ones +/// +/// This function assumes that bytes is wrapped in the quoting character. +/// +/// # Safety +/// +/// The caller must ensure that: +/// - Output buffer must have enough capacity to hold `bytes.len()` +/// - bytes ends with the quote character e.g.: `"` +/// - bytes length > 1. +pub(super) unsafe fn escape_field(bytes: &[u8], quote: u8, buf: &mut [MaybeUninit]) -> usize { + debug_assert!(bytes.len() > 1); + let mut prev_quote = false; + + let mut count = 0; + for c in bytes.get_unchecked(1..bytes.len() - 1) { + if *c == quote { + if prev_quote { + prev_quote = false; + buf.get_unchecked_mut(count).write(*c); + count += 1; + } else { + prev_quote = true; + } + } else { + prev_quote = false; + buf.get_unchecked_mut(count).write(*c); + count += 1; + } + } + count +} + +#[cfg(test)] +mod test { + use super::get_file_chunks; + + #[test] + fn test_get_file_chunks() { + let path = "../../examples/datasets/foods1.csv"; + let s = std::fs::read_to_string(path).unwrap(); + let bytes = s.as_bytes(); + // can be within -1 / +1 bounds. + assert!( + (get_file_chunks(bytes, 10, Some(4), b',', None, b'\n').len() as i32 - 10).abs() <= 1 + ); + assert!( + (get_file_chunks(bytes, 8, Some(4), b',', None, b'\n').len() as i32 - 8).abs() <= 1 + ); + } +} diff --git a/crates/polars-io/src/csv/write/mod.rs b/crates/polars-io/src/csv/write/mod.rs new file mode 100644 index 000000000000..53111bba4c95 --- /dev/null +++ b/crates/polars-io/src/csv/write/mod.rs @@ -0,0 +1,25 @@ +//! Functionality for writing CSV files. +//! +//! # Examples +//! +//! ``` +//! use polars_core::prelude::*; +//! use polars_io::prelude::*; +//! use std::fs::File; +//! +//! fn example(df: &mut DataFrame) -> PolarsResult<()> { +//! let mut file = File::create("example.csv").expect("could not create file"); +//! +//! CsvWriter::new(&mut file) +//! .include_header(true) +//! .with_separator(b',') +//! .finish(df) +//! } +//! ``` + +mod options; +mod write_impl; +mod writer; + +pub use options::{CsvWriterOptions, QuoteStyle, SerializeOptions}; +pub use writer::{BatchedWriter, CsvWriter}; diff --git a/crates/polars-io/src/csv/write/options.rs b/crates/polars-io/src/csv/write/options.rs new file mode 100644 index 000000000000..e49595dbafde --- /dev/null +++ b/crates/polars-io/src/csv/write/options.rs @@ -0,0 +1,93 @@ +use std::num::NonZeroUsize; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Options for writing CSV files. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CsvWriterOptions { + pub include_bom: bool, + pub include_header: bool, + pub batch_size: NonZeroUsize, + pub serialize_options: SerializeOptions, +} + +impl Default for CsvWriterOptions { + fn default() -> Self { + Self { + include_bom: false, + include_header: true, + batch_size: NonZeroUsize::new(1024).unwrap(), + serialize_options: SerializeOptions::default(), + } + } +} + +/// Options to serialize logical types to CSV. +/// +/// The default is to format times and dates as `chrono` crate formats them. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct SerializeOptions { + /// Used for [`DataType::Date`](polars_core::datatypes::DataType::Date). + pub date_format: Option, + /// Used for [`DataType::Time`](polars_core::datatypes::DataType::Time). + pub time_format: Option, + /// Used for [`DataType::Datetime`](polars_core::datatypes::DataType::Datetime). + pub datetime_format: Option, + /// Used for [`DataType::Float64`](polars_core::datatypes::DataType::Float64) + /// and [`DataType::Float32`](polars_core::datatypes::DataType::Float32). + pub float_scientific: Option, + pub float_precision: Option, + /// Used as separator. + pub separator: u8, + /// Quoting character. + pub quote_char: u8, + /// Null value representation. + pub null: String, + /// String appended after every row. + pub line_terminator: String, + /// When to insert quotes. + pub quote_style: QuoteStyle, +} + +impl Default for SerializeOptions { + fn default() -> Self { + Self { + date_format: None, + time_format: None, + datetime_format: None, + float_scientific: None, + float_precision: None, + separator: b',', + quote_char: b'"', + null: String::new(), + line_terminator: "\n".into(), + quote_style: Default::default(), + } + } +} + +/// Quote style indicating when to insert quotes around a field. +#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum QuoteStyle { + /// Quote fields only when necessary. + /// + /// Quotes are necessary when fields contain a quote, separator or record terminator. + /// Quotes are also necessary when writing an empty record (which is indistinguishable + /// from arecord with one empty field). + /// This is the default. + #[default] + Necessary, + /// Quote every field. Always. + Always, + /// Quote non-numeric fields. + /// + /// When writing a field that does not parse as a valid float or integer, + /// quotes will be used even if they aren't strictly necessary. + NonNumeric, + /// Never quote any fields, even if it would produce invalid CSV data. + Never, +} diff --git a/crates/polars-io/src/csv/write/write_impl.rs b/crates/polars-io/src/csv/write/write_impl.rs new file mode 100644 index 000000000000..1315b7f58072 --- /dev/null +++ b/crates/polars-io/src/csv/write/write_impl.rs @@ -0,0 +1,237 @@ +mod serializer; + +use std::io::Write; + +use arrow::array::NullArray; +use arrow::legacy::time_zone::Tz; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_error::polars_ensure; +use rayon::prelude::*; +use serializer::{serializer_for, string_serializer}; + +use crate::csv::write::SerializeOptions; + +pub(crate) fn write( + writer: &mut W, + df: &DataFrame, + chunk_size: usize, + options: &SerializeOptions, + n_threads: usize, +) -> PolarsResult<()> { + for s in df.get_columns() { + let nested = match s.dtype() { + DataType::List(_) => true, + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => true, + #[cfg(feature = "object")] + DataType::Object(_) => { + return Err(PolarsError::ComputeError( + "csv writer does not support object dtype".into(), + )); + }, + _ => false, + }; + polars_ensure!( + !nested, + ComputeError: "CSV format does not support nested data", + ); + } + + // Check that the double quote is valid UTF-8. + polars_ensure!( + std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(), + ComputeError: "quote char results in invalid utf-8", + ); + + let (datetime_formats, time_zones): (Vec<&str>, Vec>) = df + .get_columns() + .iter() + .map(|column| match column.dtype() { + DataType::Datetime(TimeUnit::Milliseconds, tz) => { + let (format, tz_parsed) = match tz { + #[cfg(feature = "timezones")] + Some(tz) => ( + options + .datetime_format + .as_deref() + .unwrap_or("%FT%H:%M:%S.%3f%z"), + tz.parse::().ok(), + ), + _ => ( + options + .datetime_format + .as_deref() + .unwrap_or("%FT%H:%M:%S.%3f"), + None, + ), + }; + (format, tz_parsed) + }, + DataType::Datetime(TimeUnit::Microseconds, tz) => { + let (format, tz_parsed) = match tz { + #[cfg(feature = "timezones")] + Some(tz) => ( + options + .datetime_format + .as_deref() + .unwrap_or("%FT%H:%M:%S.%6f%z"), + tz.parse::().ok(), + ), + _ => ( + options + .datetime_format + .as_deref() + .unwrap_or("%FT%H:%M:%S.%6f"), + None, + ), + }; + (format, tz_parsed) + }, + DataType::Datetime(TimeUnit::Nanoseconds, tz) => { + let (format, tz_parsed) = match tz { + #[cfg(feature = "timezones")] + Some(tz) => ( + options + .datetime_format + .as_deref() + .unwrap_or("%FT%H:%M:%S.%9f%z"), + tz.parse::().ok(), + ), + _ => ( + options + .datetime_format + .as_deref() + .unwrap_or("%FT%H:%M:%S.%9f"), + None, + ), + }; + (format, tz_parsed) + }, + _ => ("", None), + }) + .unzip(); + + let len = df.height(); + let total_rows_per_pool_iter = n_threads * chunk_size; + + let mut n_rows_finished = 0; + + let mut buffers: Vec<_> = (0..n_threads).map(|_| (Vec::new(), Vec::new())).collect(); + while n_rows_finished < len { + let buf_writer = |thread_no, write_buffer: &mut Vec<_>, serializers_vec: &mut Vec<_>| { + let thread_offset = thread_no * chunk_size; + let total_offset = n_rows_finished + thread_offset; + let mut df = df.slice(total_offset as i64, chunk_size); + // the `series.iter` needs rechunked series. + // we don't do this on the whole as this probably needs much less rechunking + // so will be faster. + // and allows writing `pl.concat([df] * 100, rechunk=False).write_csv()` as the rechunk + // would go OOM + df.as_single_chunk(); + let cols = df.get_columns(); + + // SAFETY: + // the bck thinks the lifetime is bounded to write_buffer_pool, but at the time we return + // the vectors the buffer pool, the series have already been removed from the buffers + // in other words, the lifetime does not leave this scope + let cols = unsafe { std::mem::transmute::<&[Column], &[Column]>(cols) }; + + if df.is_empty() { + return Ok(()); + } + + if serializers_vec.is_empty() { + *serializers_vec = cols + .iter() + .enumerate() + .map(|(i, col)| { + serializer_for( + &*col.as_materialized_series().chunks()[0], + options, + col.dtype(), + datetime_formats[i], + time_zones[i], + ) + }) + .collect::>()?; + } else { + debug_assert_eq!(serializers_vec.len(), cols.len()); + for (col_iter, col) in std::iter::zip(serializers_vec.iter_mut(), cols) { + col_iter.update_array(&*col.as_materialized_series().chunks()[0]); + } + } + + let serializers = serializers_vec.as_mut_slice(); + + let len = std::cmp::min(cols[0].len(), chunk_size); + + for _ in 0..len { + serializers[0].serialize(write_buffer, options); + for serializer in &mut serializers[1..] { + write_buffer.push(options.separator); + serializer.serialize(write_buffer, options); + } + + write_buffer.extend_from_slice(options.line_terminator.as_bytes()); + } + + Ok(()) + }; + + if n_threads > 1 { + POOL.install(|| { + buffers + .par_iter_mut() + .enumerate() + .map(|(i, (w, s))| buf_writer(i, w, s)) + .collect::>() + })?; + } else { + let (w, s) = &mut buffers[0]; + buf_writer(0, w, s)?; + } + + for (write_buffer, _) in &mut buffers { + writer.write_all(write_buffer)?; + write_buffer.clear(); + } + + n_rows_finished += total_rows_per_pool_iter; + } + Ok(()) +} + +/// Writes a CSV header to `writer`. +pub(crate) fn write_header( + writer: &mut W, + names: &[&str], + options: &SerializeOptions, +) -> PolarsResult<()> { + let mut header = Vec::new(); + + // A hack, but it works for this case. + let fake_arr = NullArray::new(ArrowDataType::Null, 0); + let mut names_serializer = string_serializer( + |iter: &mut std::slice::Iter<&str>| iter.next().copied(), + options, + |_| names.iter(), + &fake_arr, + ); + for i in 0..names.len() { + names_serializer.serialize(&mut header, options); + if i != names.len() - 1 { + header.push(options.separator); + } + } + header.extend_from_slice(options.line_terminator.as_bytes()); + writer.write_all(&header)?; + Ok(()) +} + +/// Writes a UTF-8 BOM to `writer`. +pub(crate) fn write_bom(writer: &mut W) -> PolarsResult<()> { + const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF]; + writer.write_all(&BOM)?; + Ok(()) +} diff --git a/crates/polars-io/src/csv/write/write_impl/serializer.rs b/crates/polars-io/src/csv/write/write_impl/serializer.rs new file mode 100644 index 000000000000..03dc36b94223 --- /dev/null +++ b/crates/polars-io/src/csv/write/write_impl/serializer.rs @@ -0,0 +1,777 @@ +//! This file is complicated because we have complicated escape handling. We want to avoid having +//! to write down each combination of type & escaping, but we also want the compiler to optimize them +//! to efficient machine code - so no dynamic dispatch. That means a lot of generics and macros. +//! +//! We need to differentiate between several kinds of types, and several kinds of escaping we support: +//! +//! - The simplest escaping mechanism are [`QuoteStyle::Always`] and [`QuoteStyle::Never`]. +//! For `Never` we just never quote. For `Always` we pass any serializer that never quotes +//! to [`quote_serializer()`] then it becomes quoted properly. +//! - [`QuoteStyle::Necessary`] (the default) is only relevant for strings, as it is the only type that +//! can have newlines (row separators), commas (column separators) or quotes. String +//! escaping is complicated anyway, and it is all inside [`string_serializer()`]. +//! - The real complication is [`QuoteStyle::NonNumeric`], that doesn't quote numbers and nulls, +//! and quotes any other thing. The problem is that nulls can be within any type, so we need to handle +//! two possibilities of quoting everywhere. +//! +//! So in case the chosen style is anything but `NonNumeric`, we statically know for each column except strings +//! whether it should be quoted (and for strings too when not `Necessary`). There we use `quote_serializer()` +//! or nothing. +//! +//! But to help with `NonNumeric`, each serializer carry the potential to distinguish between nulls and non-nulls, +//! and quote the later and not the former. But in order to not have the branch when we statically know the answer, +//! we have an option to statically disable it with a const generic flag `QUOTE_NON_NULL`. Numbers (that should never +//! be quoted with `NonNumeric`) just always disable this flag. +//! +//! So we have three possibilities: +//! +//! 1. A serializer that never quotes. This is a bare serializer with `QUOTE_NON_NULL = false`. +//! 2. A serializer that always quotes. This is a serializer wrapped with `quote_serializer()`, +//! but also with `QUOTE_NON_NULL = false`. +//! 3. A serializer that quotes only non-nulls. This is a bare serializer with `QUOTE_NON_NULL = true`. + +use std::fmt::LowerExp; +use std::io::Write; + +use arrow::array::{Array, BooleanArray, NullArray, PrimitiveArray, Utf8ViewArray}; +use arrow::legacy::time_zone::Tz; +use arrow::types::NativeType; +#[cfg(feature = "timezones")] +use chrono::TimeZone; +use memchr::{memchr_iter, memchr3}; +use num_traits::NumCast; +use polars_core::prelude::*; + +use crate::csv::write::{QuoteStyle, SerializeOptions}; + +const TOO_MANY_MSG: &str = "too many items requested from CSV serializer"; +const ARRAY_MISMATCH_MSG: &str = "wrong array type"; + +#[allow(dead_code)] +struct IgnoreFmt; +impl std::fmt::Write for IgnoreFmt { + fn write_str(&mut self, _s: &str) -> std::fmt::Result { + Ok(()) + } +} + +pub(super) trait Serializer<'a> { + fn serialize(&mut self, buf: &mut Vec, options: &SerializeOptions); + // Updates the array without changing the configuration. + fn update_array(&mut self, array: &'a dyn Array); +} + +fn make_serializer<'a, T, I: Iterator>, const QUOTE_NON_NULL: bool>( + f: impl FnMut(T, &mut Vec, &SerializeOptions), + iter: I, + update_array: impl FnMut(&'a dyn Array) -> I, +) -> impl Serializer<'a> { + struct SerializerImpl { + f: F, + iter: I, + update_array: Update, + } + + impl<'a, T, F, I, Update, const QUOTE_NON_NULL: bool> Serializer<'a> + for SerializerImpl + where + F: FnMut(T, &mut Vec, &SerializeOptions), + I: Iterator>, + Update: FnMut(&'a dyn Array) -> I, + { + fn serialize(&mut self, buf: &mut Vec, options: &SerializeOptions) { + let item = self.iter.next().expect(TOO_MANY_MSG); + match item { + Some(item) => { + if QUOTE_NON_NULL { + buf.push(options.quote_char); + } + (self.f)(item, buf, options); + if QUOTE_NON_NULL { + buf.push(options.quote_char); + } + }, + None => buf.extend_from_slice(options.null.as_bytes()), + } + } + + fn update_array(&mut self, array: &'a dyn Array) { + self.iter = (self.update_array)(array); + } + } + + SerializerImpl::<_, _, _, QUOTE_NON_NULL> { + f, + iter, + update_array, + } +} + +fn integer_serializer(array: &PrimitiveArray) -> impl Serializer { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(item); + buf.extend_from_slice(value.as_bytes()); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +fn float_serializer_no_precision_autoformat( + array: &PrimitiveArray, +) -> impl Serializer { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(item); + buf.extend_from_slice(value.as_bytes()); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +fn float_serializer_no_precision_scientific( + array: &PrimitiveArray, +) -> impl Serializer { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + // Float writing into a buffer of `Vec` cannot fail. + let _ = write!(buf, "{item:.e}"); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +fn float_serializer_no_precision_positional( + array: &PrimitiveArray, +) -> impl Serializer { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + let v: f64 = NumCast::from(item).unwrap(); + let value = v.to_string(); + buf.extend_from_slice(value.as_bytes()); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +fn float_serializer_with_precision_scientific( + array: &PrimitiveArray, + precision: usize, +) -> impl Serializer { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + // Float writing into a buffer of `Vec` cannot fail. + let _ = write!(buf, "{item:.precision$e}"); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +fn float_serializer_with_precision_positional( + array: &PrimitiveArray, + precision: usize, +) -> impl Serializer { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + // Float writing into a buffer of `Vec` cannot fail. + let _ = write!(buf, "{item:.precision$}"); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +fn null_serializer(_array: &NullArray) -> impl Serializer { + struct NullSerializer; + impl<'a> Serializer<'a> for NullSerializer { + fn serialize(&mut self, buf: &mut Vec, options: &SerializeOptions) { + buf.extend_from_slice(options.null.as_bytes()); + } + fn update_array(&mut self, _array: &'a dyn Array) {} + } + NullSerializer +} + +fn bool_serializer(array: &BooleanArray) -> impl Serializer { + let f = move |item, buf: &mut Vec, _options: &SerializeOptions| { + let s = if item { "true" } else { "false" }; + buf.extend_from_slice(s.as_bytes()); + }; + + make_serializer::<_, _, QUOTE_NON_NULL>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +#[cfg(feature = "dtype-decimal")] +fn decimal_serializer(array: &PrimitiveArray, scale: usize) -> impl Serializer { + let trim_zeros = arrow::compute::decimal::get_trim_decimal_zeros(); + + let mut fmt_buf = arrow::compute::decimal::DecimalFmtBuffer::new(); + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + buf.extend_from_slice(fmt_buf.format(item, scale, trim_zeros).as_bytes()); + }; + + make_serializer::<_, _, false>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +#[cfg(any( + feature = "dtype-date", + feature = "dtype-time", + feature = "dtype-datetime" +))] +fn callback_serializer<'a, T: NativeType, const QUOTE_NON_NULL: bool>( + array: &'a PrimitiveArray, + mut callback: impl FnMut(T, &mut Vec) + 'a, +) -> impl Serializer<'a> { + let f = move |&item, buf: &mut Vec, _options: &SerializeOptions| { + callback(item, buf); + }; + + make_serializer::<_, _, QUOTE_NON_NULL>(f, array.iter(), |array| { + array + .as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }) +} + +#[cfg(any(feature = "dtype-date", feature = "dtype-time"))] +type ChronoFormatIter<'a, 'b> = std::slice::Iter<'a, chrono::format::Item<'b>>; + +#[cfg(any(feature = "dtype-date", feature = "dtype-time"))] +fn date_and_time_serializer<'a, Underlying: NativeType, T: std::fmt::Display>( + format_str: &'a Option, + description: &str, + array: &'a dyn Array, + sample_value: T, + mut convert: impl FnMut(Underlying) -> T + Send + 'a, + mut format_fn: impl for<'b> FnMut( + &T, + ChronoFormatIter<'b, 'a>, + ) -> chrono::format::DelayedFormat> + + Send + + 'a, + options: &SerializeOptions, +) -> PolarsResult + Send + 'a>> { + let array = array.as_any().downcast_ref().unwrap(); + let serializer = match format_str { + Some(format_str) => { + let format = chrono::format::StrftimeItems::new(format_str).parse().map_err( + |_| polars_err!(ComputeError: "cannot format {description} with format '{format_str}'"), + )?; + use std::fmt::Write; + // Fail fast for invalid format. This return error faster to the user, and allows us to not return + // `Result` from `serialize()`. + write!(IgnoreFmt, "{}", format_fn(&sample_value, format.iter())).map_err( + |_| polars_err!(ComputeError: "cannot format {description} with format '{format_str}'"), + )?; + let callback = move |item, buf: &mut Vec| { + let item = convert(item); + // We checked the format is valid above. + let _ = write!(buf, "{}", format_fn(&item, format.iter())); + }; + date_and_time_final_serializer(array, callback, options) + }, + None => { + let callback = move |item, buf: &mut Vec| { + let item = convert(item); + // Formatting dates into `Vec` cannot fail. + let _ = write!(buf, "{item}"); + }; + date_and_time_final_serializer(array, callback, options) + }, + }; + Ok(serializer) +} + +#[cfg(any( + feature = "dtype-date", + feature = "dtype-time", + feature = "dtype-datetime" +))] +fn date_and_time_final_serializer<'a, T: NativeType>( + array: &'a PrimitiveArray, + callback: impl FnMut(T, &mut Vec) + Send + 'a, + options: &SerializeOptions, +) -> Box + Send + 'a> { + match options.quote_style { + QuoteStyle::Always => Box::new(quote_serializer(callback_serializer::( + array, callback, + ))) as Box, + QuoteStyle::NonNumeric => Box::new(callback_serializer::(array, callback)), + _ => Box::new(callback_serializer::(array, callback)), + } +} + +pub(super) fn string_serializer<'a, Iter: Send + 'a>( + mut f: impl FnMut(&mut Iter) -> Option<&str> + Send + 'a, + options: &SerializeOptions, + mut update: impl FnMut(&'a dyn Array) -> Iter + Send + 'a, + array: &'a dyn Array, +) -> Box + 'a + Send> { + const LF: u8 = b'\n'; + const CR: u8 = b'\r'; + + struct StringSerializer { + serialize: F, + update: Update, + iter: Iter, + } + + impl<'a, F, Iter, Update> Serializer<'a> for StringSerializer + where + F: FnMut(&mut Iter, &mut Vec, &SerializeOptions), + Update: FnMut(&'a dyn Array) -> Iter, + { + fn serialize(&mut self, buf: &mut Vec, options: &SerializeOptions) { + (self.serialize)(&mut self.iter, buf, options); + } + + fn update_array(&mut self, array: &'a dyn Array) { + self.iter = (self.update)(array); + } + } + + fn serialize_str_escaped(buf: &mut Vec, s: &[u8], quote_char: u8, quoted: bool) { + let mut iter = memchr_iter(quote_char, s); + let first_quote = iter.next(); + match first_quote { + None => buf.extend_from_slice(s), + Some(mut quote_pos) => { + if !quoted { + buf.push(quote_char); + } + let mut start_pos = 0; + loop { + buf.extend_from_slice(&s[start_pos..quote_pos]); + buf.extend_from_slice(&[quote_char, quote_char]); + match iter.next() { + Some(quote) => { + start_pos = quote_pos + 1; + quote_pos = quote; + }, + None => { + buf.extend_from_slice(&s[quote_pos + 1..]); + break; + }, + } + } + if !quoted { + buf.push(quote_char); + } + }, + } + } + + let iter = update(array); + match options.quote_style { + QuoteStyle::Always => { + let serialize = + move |iter: &mut Iter, buf: &mut Vec, options: &SerializeOptions| { + let quote_char = options.quote_char; + buf.push(quote_char); + let Some(s) = f(iter) else { + buf.extend_from_slice(options.null.as_bytes()); + buf.push(quote_char); + return; + }; + serialize_str_escaped(buf, s.as_bytes(), quote_char, true); + buf.push(quote_char); + }; + Box::new(StringSerializer { + serialize, + update, + iter, + }) + }, + QuoteStyle::NonNumeric => { + let serialize = + move |iter: &mut Iter, buf: &mut Vec, options: &SerializeOptions| { + let Some(s) = f(iter) else { + buf.extend_from_slice(options.null.as_bytes()); + return; + }; + let quote_char = options.quote_char; + buf.push(quote_char); + serialize_str_escaped(buf, s.as_bytes(), quote_char, true); + buf.push(quote_char); + }; + Box::new(StringSerializer { + serialize, + update, + iter, + }) + }, + QuoteStyle::Necessary => { + let serialize = + move |iter: &mut Iter, buf: &mut Vec, options: &SerializeOptions| { + let Some(s) = f(iter) else { + buf.extend_from_slice(options.null.as_bytes()); + return; + }; + let quote_char = options.quote_char; + // An empty string conflicts with null, so it is necessary to quote. + if s.is_empty() { + buf.extend_from_slice(&[quote_char, quote_char]); + return; + } + let needs_quote = memchr3(options.separator, LF, CR, s.as_bytes()).is_some(); + if needs_quote { + buf.push(quote_char); + } + serialize_str_escaped(buf, s.as_bytes(), quote_char, needs_quote); + if needs_quote { + buf.push(quote_char); + } + }; + Box::new(StringSerializer { + serialize, + update, + iter, + }) + }, + QuoteStyle::Never => { + let serialize = + move |iter: &mut Iter, buf: &mut Vec, options: &SerializeOptions| { + let Some(s) = f(iter) else { + buf.extend_from_slice(options.null.as_bytes()); + return; + }; + buf.extend_from_slice(s.as_bytes()); + }; + Box::new(StringSerializer { + serialize, + update, + iter, + }) + }, + } +} + +fn quote_serializer<'a>(serializer: impl Serializer<'a>) -> impl Serializer<'a> { + struct QuoteSerializer(S); + impl<'a, S: Serializer<'a>> Serializer<'a> for QuoteSerializer { + fn serialize(&mut self, buf: &mut Vec, options: &SerializeOptions) { + buf.push(options.quote_char); + self.0.serialize(buf, options); + buf.push(options.quote_char); + } + + fn update_array(&mut self, array: &'a dyn Array) { + self.0.update_array(array); + } + } + QuoteSerializer(serializer) +} + +pub(super) fn serializer_for<'a>( + array: &'a dyn Array, + options: &'a SerializeOptions, + dtype: &'a DataType, + _datetime_format: &'a str, + _time_zone: Option, +) -> PolarsResult + Send + 'a>> { + macro_rules! quote_if_always { + ($make_serializer:path, $($arg:tt)*) => {{ + let serializer = $make_serializer(array.as_any().downcast_ref().unwrap(), $($arg)*); + if let QuoteStyle::Always = options.quote_style { + Box::new(quote_serializer(serializer)) as Box + } else { + Box::new(serializer) + } + }}; + ($make_serializer:path) => { quote_if_always!($make_serializer,) }; + } + + let serializer = match dtype { + DataType::Int8 => quote_if_always!(integer_serializer::), + DataType::UInt8 => quote_if_always!(integer_serializer::), + DataType::Int16 => quote_if_always!(integer_serializer::), + DataType::UInt16 => quote_if_always!(integer_serializer::), + DataType::Int32 => quote_if_always!(integer_serializer::), + DataType::UInt32 => quote_if_always!(integer_serializer::), + DataType::Int64 => quote_if_always!(integer_serializer::), + DataType::UInt64 => quote_if_always!(integer_serializer::), + DataType::Int128 => quote_if_always!(integer_serializer::), + DataType::Float32 => match options.float_precision { + Some(precision) => match options.float_scientific { + Some(true) => { + quote_if_always!(float_serializer_with_precision_scientific::, precision) + }, + _ => quote_if_always!(float_serializer_with_precision_positional::, precision), + }, + None => match options.float_scientific { + Some(true) => quote_if_always!(float_serializer_no_precision_scientific::), + Some(false) => quote_if_always!(float_serializer_no_precision_positional::), + None => quote_if_always!(float_serializer_no_precision_autoformat::), + }, + }, + DataType::Float64 => match options.float_precision { + Some(precision) => match options.float_scientific { + Some(true) => { + quote_if_always!(float_serializer_with_precision_scientific::, precision) + }, + _ => quote_if_always!(float_serializer_with_precision_positional::, precision), + }, + None => match options.float_scientific { + Some(true) => quote_if_always!(float_serializer_no_precision_scientific::), + Some(false) => quote_if_always!(float_serializer_no_precision_positional::), + None => quote_if_always!(float_serializer_no_precision_autoformat::), + }, + }, + DataType::Null => quote_if_always!(null_serializer), + DataType::Boolean => { + let array = array.as_any().downcast_ref().unwrap(); + match options.quote_style { + QuoteStyle::Always => Box::new(quote_serializer(bool_serializer::(array))) + as Box, + QuoteStyle::NonNumeric => Box::new(bool_serializer::(array)), + _ => Box::new(bool_serializer::(array)), + } + }, + #[cfg(feature = "dtype-date")] + DataType::Date => date_and_time_serializer( + &options.date_format, + "NaiveDate", + array, + chrono::NaiveDate::MAX, + arrow::temporal_conversions::date32_to_date, + |date, items| date.format_with_items(items), + options, + )?, + #[cfg(feature = "dtype-time")] + DataType::Time => date_and_time_serializer( + &options.time_format, + "NaiveTime", + array, + chrono::NaiveTime::MIN, + arrow::temporal_conversions::time64ns_to_time, + |time, items| time.format_with_items(items), + options, + )?, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(time_unit, _) => { + let format = chrono::format::StrftimeItems::new(_datetime_format) + .parse() + .map_err(|_| { + polars_err!( + ComputeError: "cannot format {} with format '{_datetime_format}'", + if _time_zone.is_some() { "DateTime" } else { "NaiveDateTime" }, + ) + })?; + use std::fmt::Write; + let sample_datetime = match _time_zone { + #[cfg(feature = "timezones")] + Some(time_zone) => time_zone + .from_utc_datetime(&chrono::NaiveDateTime::MAX) + .format_with_items(format.iter()), + #[cfg(not(feature = "timezones"))] + Some(_) => panic!("activate 'timezones' feature"), + None => chrono::NaiveDateTime::MAX.format_with_items(format.iter()), + }; + // Fail fast for invalid format. This return error faster to the user, and allows us to not return + // `Result` from `serialize()`. + write!(IgnoreFmt, "{sample_datetime}").map_err(|_| { + polars_err!( + ComputeError: "cannot format {} with format '{_datetime_format}'", + if _time_zone.is_some() { "DateTime" } else { "NaiveDateTime" }, + ) + })?; + + let array = array.as_any().downcast_ref().unwrap(); + + macro_rules! time_unit_serializer { + ($convert:ident) => { + match _time_zone { + #[cfg(feature = "timezones")] + Some(time_zone) => { + let callback = move |item, buf: &mut Vec| { + let item = arrow::temporal_conversions::$convert(item); + let item = time_zone.from_utc_datetime(&item); + // We checked the format is valid above. + let _ = write!(buf, "{}", item.format_with_items(format.iter())); + }; + date_and_time_final_serializer(array, callback, options) + }, + #[cfg(not(feature = "timezones"))] + Some(_) => panic!("activate 'timezones' feature"), + None => { + let callback = move |item, buf: &mut Vec| { + let item = arrow::temporal_conversions::$convert(item); + // We checked the format is valid above. + let _ = write!(buf, "{}", item.format_with_items(format.iter())); + }; + date_and_time_final_serializer(array, callback, options) + }, + } + }; + } + + match time_unit { + TimeUnit::Nanoseconds => time_unit_serializer!(timestamp_ns_to_datetime), + TimeUnit::Microseconds => time_unit_serializer!(timestamp_us_to_datetime), + TimeUnit::Milliseconds => time_unit_serializer!(timestamp_ms_to_datetime), + } + }, + DataType::String => string_serializer( + |iter| Iterator::next(iter).expect(TOO_MANY_MSG), + options, + |arr| { + arr.as_any() + .downcast_ref::() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }, + array, + ), + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(rev_map, _) | DataType::Enum(rev_map, _) => { + let rev_map = rev_map.as_deref().unwrap(); + string_serializer( + |iter| { + let &idx: &u32 = Iterator::next(iter).expect(TOO_MANY_MSG)?; + Some(rev_map.get(idx)) + }, + options, + |arr| { + arr.as_any() + .downcast_ref::>() + .expect(ARRAY_MISMATCH_MSG) + .iter() + }, + array, + ) + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, scale) => { + quote_if_always!(decimal_serializer, scale.unwrap_or(0)) + }, + _ => { + polars_bail!(ComputeError: "datatype {dtype} cannot be written to CSV\n\nConsider using JSON or a binary format.") + }, + }; + Ok(serializer) +} + +#[cfg(test)] +mod test { + use arrow::array::NullArray; + use polars_core::prelude::ArrowDataType; + + use super::string_serializer; + use crate::csv::write::options::{QuoteStyle, SerializeOptions}; + + // It is the most complex serializer with most edge cases, it definitely needs a comprehensive test. + #[test] + fn test_string_serializer() { + #[track_caller] + fn check_string_serialization(options: &SerializeOptions, s: Option<&str>, expected: &str) { + let fake_array = NullArray::new(ArrowDataType::Null, 0); + let mut serializer = string_serializer(|s| *s, options, |_| s, &fake_array); + let mut buf = Vec::new(); + serializer.serialize(&mut buf, options); + let serialized = std::str::from_utf8(&buf).unwrap(); + // Don't use `assert_eq!()` because it prints debug format and it's hard to read with all the escapes. + if serialized != expected { + panic!( + "CSV string {s:?} wasn't serialized correctly: expected: `{expected}`, got: `{serialized}`" + ); + } + } + + let always_quote = SerializeOptions { + quote_style: QuoteStyle::Always, + ..SerializeOptions::default() + }; + check_string_serialization(&always_quote, None, r#""""#); + check_string_serialization(&always_quote, Some(""), r#""""#); + check_string_serialization(&always_quote, Some("a"), r#""a""#); + check_string_serialization(&always_quote, Some("\""), r#""""""#); + check_string_serialization(&always_quote, Some("a\"\"b"), r#""a""""b""#); + + let necessary_quote = SerializeOptions { + quote_style: QuoteStyle::Necessary, + ..SerializeOptions::default() + }; + check_string_serialization(&necessary_quote, None, r#""#); + check_string_serialization(&necessary_quote, Some(""), r#""""#); + check_string_serialization(&necessary_quote, Some("a"), r#"a"#); + check_string_serialization(&necessary_quote, Some("\""), r#""""""#); + check_string_serialization(&necessary_quote, Some("a\"\"b"), r#""a""""b""#); + check_string_serialization(&necessary_quote, Some("a b"), r#"a b"#); + check_string_serialization(&necessary_quote, Some("a,b"), r#""a,b""#); + check_string_serialization(&necessary_quote, Some("a\nb"), "\"a\nb\""); + check_string_serialization(&necessary_quote, Some("a\rb"), "\"a\rb\""); + + let never_quote = SerializeOptions { + quote_style: QuoteStyle::Never, + ..SerializeOptions::default() + }; + check_string_serialization(&never_quote, None, ""); + check_string_serialization(&never_quote, Some(""), ""); + check_string_serialization(&never_quote, Some("a"), "a"); + check_string_serialization(&never_quote, Some("\""), "\""); + check_string_serialization(&never_quote, Some("a\"\"b"), "a\"\"b"); + check_string_serialization(&never_quote, Some("a b"), "a b"); + check_string_serialization(&never_quote, Some("a,b"), "a,b"); + check_string_serialization(&never_quote, Some("a\nb"), "a\nb"); + check_string_serialization(&never_quote, Some("a\rb"), "a\rb"); + + let non_numeric_quote = SerializeOptions { + quote_style: QuoteStyle::NonNumeric, + ..SerializeOptions::default() + }; + check_string_serialization(&non_numeric_quote, None, ""); + check_string_serialization(&non_numeric_quote, Some(""), r#""""#); + check_string_serialization(&non_numeric_quote, Some("a"), r#""a""#); + check_string_serialization(&non_numeric_quote, Some("\""), r#""""""#); + check_string_serialization(&non_numeric_quote, Some("a\"\"b"), r#""a""""b""#); + check_string_serialization(&non_numeric_quote, Some("a b"), r#""a b""#); + check_string_serialization(&non_numeric_quote, Some("a,b"), r#""a,b""#); + check_string_serialization(&non_numeric_quote, Some("a\nb"), "\"a\nb\""); + check_string_serialization(&non_numeric_quote, Some("a\rb"), "\"a\rb\""); + } +} diff --git a/crates/polars-io/src/csv/write/writer.rs b/crates/polars-io/src/csv/write/writer.rs new file mode 100644 index 000000000000..f8ceade73cb1 --- /dev/null +++ b/crates/polars-io/src/csv/write/writer.rs @@ -0,0 +1,241 @@ +use std::io::Write; +use std::num::NonZeroUsize; + +use polars_core::POOL; +use polars_core::frame::DataFrame; +use polars_core::schema::Schema; +use polars_error::PolarsResult; + +use super::write_impl::{write, write_bom, write_header}; +use super::{QuoteStyle, SerializeOptions}; +use crate::shared::SerWriter; + +/// Write a DataFrame to csv. +/// +/// Don't use a `Buffered` writer, the `CsvWriter` internally already buffers writes. +#[must_use] +pub struct CsvWriter { + /// File or Stream handler + buffer: W, + options: SerializeOptions, + header: bool, + bom: bool, + batch_size: NonZeroUsize, + n_threads: usize, +} + +impl SerWriter for CsvWriter +where + W: Write, +{ + fn new(buffer: W) -> Self { + // 9f: all nanoseconds + let options = SerializeOptions { + time_format: Some("%T%.9f".to_string()), + ..Default::default() + }; + + CsvWriter { + buffer, + options, + header: true, + bom: false, + batch_size: NonZeroUsize::new(1024).unwrap(), + n_threads: POOL.current_num_threads(), + } + } + + fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { + if self.bom { + write_bom(&mut self.buffer)?; + } + let names = df + .get_column_names() + .into_iter() + .map(|x| x.as_str()) + .collect::>(); + if self.header { + write_header(&mut self.buffer, names.as_slice(), &self.options)?; + } + write( + &mut self.buffer, + df, + self.batch_size.into(), + &self.options, + self.n_threads, + ) + } +} + +impl CsvWriter +where + W: Write, +{ + /// Set whether to write UTF-8 BOM. + pub fn include_bom(mut self, include_bom: bool) -> Self { + self.bom = include_bom; + self + } + + /// Set whether to write headers. + pub fn include_header(mut self, include_header: bool) -> Self { + self.header = include_header; + self + } + + /// Set the CSV file's column separator as a byte character. + pub fn with_separator(mut self, separator: u8) -> Self { + self.options.separator = separator; + self + } + + /// Set the batch size to use while writing the CSV. + pub fn with_batch_size(mut self, batch_size: NonZeroUsize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the CSV file's date format. + pub fn with_date_format(mut self, format: Option) -> Self { + if format.is_some() { + self.options.date_format = format; + } + self + } + + /// Set the CSV file's time format. + pub fn with_time_format(mut self, format: Option) -> Self { + if format.is_some() { + self.options.time_format = format; + } + self + } + + /// Set the CSV file's datetime format. + pub fn with_datetime_format(mut self, format: Option) -> Self { + if format.is_some() { + self.options.datetime_format = format; + } + self + } + + /// Set the CSV file's forced scientific notation for floats. + pub fn with_float_scientific(mut self, scientific: Option) -> Self { + if scientific.is_some() { + self.options.float_scientific = scientific; + } + self + } + + /// Set the CSV file's float precision. + pub fn with_float_precision(mut self, precision: Option) -> Self { + if precision.is_some() { + self.options.float_precision = precision; + } + self + } + + /// Set the single byte character used for quoting. + pub fn with_quote_char(mut self, char: u8) -> Self { + self.options.quote_char = char; + self + } + + /// Set the CSV file's null value representation. + pub fn with_null_value(mut self, null_value: String) -> Self { + self.options.null = null_value; + self + } + + /// Set the CSV file's line terminator. + pub fn with_line_terminator(mut self, line_terminator: String) -> Self { + self.options.line_terminator = line_terminator; + self + } + + /// Set the CSV file's quoting behavior. + /// See more on [`QuoteStyle`]. + pub fn with_quote_style(mut self, quote_style: QuoteStyle) -> Self { + self.options.quote_style = quote_style; + self + } + + pub fn n_threads(mut self, n_threads: usize) -> Self { + self.n_threads = n_threads; + self + } + + pub fn batched(self, schema: &Schema) -> PolarsResult> { + let expects_bom = self.bom; + let expects_header = self.header; + Ok(BatchedWriter { + writer: self, + has_written_bom: !expects_bom, + has_written_header: !expects_header, + schema: schema.clone(), + }) + } +} + +pub struct BatchedWriter { + writer: CsvWriter, + has_written_bom: bool, + has_written_header: bool, + schema: Schema, +} + +impl BatchedWriter { + /// Write a batch to the csv writer. + /// + /// # Panics + /// The caller must ensure the chunks in the given [`DataFrame`] are aligned. + pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + if !self.has_written_bom { + self.has_written_bom = true; + write_bom(&mut self.writer.buffer)?; + } + + if !self.has_written_header { + self.has_written_header = true; + let names = df + .get_column_names() + .into_iter() + .map(|x| x.as_str()) + .collect::>(); + write_header( + &mut self.writer.buffer, + names.as_slice(), + &self.writer.options, + )?; + } + + write( + &mut self.writer.buffer, + df, + self.writer.batch_size.into(), + &self.writer.options, + self.writer.n_threads, + )?; + Ok(()) + } + + /// Writes the header of the csv file if not done already. Returns the total size of the file. + pub fn finish(&mut self) -> PolarsResult<()> { + if !self.has_written_bom { + self.has_written_bom = true; + write_bom(&mut self.writer.buffer)?; + } + + if !self.has_written_header { + self.has_written_header = true; + let names = self + .schema + .iter_names() + .map(|x| x.as_str()) + .collect::>(); + write_header(&mut self.writer.buffer, &names, &self.writer.options)?; + }; + + Ok(()) + } +} diff --git a/crates/polars-io/src/file_cache/cache.rs b/crates/polars-io/src/file_cache/cache.rs new file mode 100644 index 000000000000..efcba3253d49 --- /dev/null +++ b/crates/polars-io/src/file_cache/cache.rs @@ -0,0 +1,187 @@ +use std::path::Path; +use std::sync::atomic::AtomicU64; +use std::sync::{Arc, LazyLock, RwLock}; + +use polars_core::config; +use polars_error::PolarsResult; +use polars_utils::aliases::PlHashMap; + +use super::entry::{DATA_PREFIX, FileCacheEntry, METADATA_PREFIX}; +use super::eviction::EvictionManager; +use super::file_fetcher::FileFetcher; +use super::utils::FILE_CACHE_PREFIX; +use crate::path_utils::{ensure_directory_init, is_cloud_url}; + +pub static FILE_CACHE: LazyLock = LazyLock::new(|| { + let prefix = FILE_CACHE_PREFIX.as_ref(); + let prefix = Arc::::from(prefix); + + if config::verbose() { + eprintln!("file cache prefix: {}", prefix.to_str().unwrap()); + } + + let min_ttl = Arc::new(AtomicU64::from(get_env_file_cache_ttl())); + let notify_ttl_updated = Arc::new(tokio::sync::Notify::new()); + + let metadata_dir = prefix + .as_ref() + .join(std::str::from_utf8(&[METADATA_PREFIX]).unwrap()) + .into_boxed_path(); + if let Err(err) = ensure_directory_init(&metadata_dir) { + panic!( + "failed to create file cache metadata directory: path = {}, err = {}", + metadata_dir.to_str().unwrap(), + err + ) + } + + let data_dir = prefix + .as_ref() + .join(std::str::from_utf8(&[DATA_PREFIX]).unwrap()) + .into_boxed_path(); + + if let Err(err) = ensure_directory_init(&data_dir) { + panic!( + "failed to create file cache data directory: path = {}, err = {}", + data_dir.to_str().unwrap(), + err + ) + } + + EvictionManager { + data_dir, + metadata_dir, + files_to_remove: None, + min_ttl: min_ttl.clone(), + notify_ttl_updated: notify_ttl_updated.clone(), + } + .run_in_background(); + + // Safety: We have created the data and metadata directories. + unsafe { FileCache::new_unchecked(prefix, min_ttl, notify_ttl_updated) } +}); + +pub struct FileCache { + prefix: Arc, + entries: Arc, Arc>>>, + min_ttl: Arc, + notify_ttl_updated: Arc, +} + +impl FileCache { + /// # Safety + /// The following directories exist: + /// * `{prefix}/{METADATA_PREFIX}/` + /// * `{prefix}/{DATA_PREFIX}/` + unsafe fn new_unchecked( + prefix: Arc, + min_ttl: Arc, + notify_ttl_updated: Arc, + ) -> Self { + Self { + prefix, + entries: Default::default(), + min_ttl, + notify_ttl_updated, + } + } + + /// If `uri` is a local path, it must be an absolute path. This is not exposed + /// for now - initialize entries using `init_entries_from_uri_list` instead. + pub(super) fn init_entry PolarsResult>>( + &self, + uri: Arc, + get_file_fetcher: F, + ttl: u64, + ) -> PolarsResult> { + let verbose = config::verbose(); + + #[cfg(debug_assertions)] + { + // Local paths must be absolute or else the cache would be wrong. + if !crate::path_utils::is_cloud_url(uri.as_ref()) { + let path = Path::new(uri.as_ref()); + assert_eq!(path, std::fs::canonicalize(path).unwrap().as_path()); + } + } + + if self + .min_ttl + .fetch_min(ttl, std::sync::atomic::Ordering::Relaxed) + < ttl + { + self.notify_ttl_updated.notify_one(); + } + + { + let entries = self.entries.read().unwrap(); + + if let Some(entry) = entries.get(uri.as_ref()) { + if verbose { + eprintln!( + "[file_cache] init_entry: return existing entry for uri = {}", + uri.clone() + ); + } + entry.update_ttl(ttl); + return Ok(entry.clone()); + } + } + + let uri_hash = blake3::hash(uri.as_bytes()).to_hex()[..32].to_string(); + + { + let mut entries = self.entries.write().unwrap(); + + // May have been raced + if let Some(entry) = entries.get(uri.as_ref()) { + if verbose { + eprintln!( + "[file_cache] init_entry: return existing entry for uri = {} (lost init race)", + uri.clone() + ); + } + entry.update_ttl(ttl); + return Ok(entry.clone()); + } + + if verbose { + eprintln!( + "[file_cache] init_entry: creating new entry for uri = {}, hash = {}", + uri.clone(), + uri_hash.clone() + ); + } + + let entry = Arc::new(FileCacheEntry::new( + uri.clone(), + uri_hash, + self.prefix.clone(), + get_file_fetcher()?, + ttl, + )); + entries.insert(uri, entry.clone()); + Ok(entry.clone()) + } + } + + /// This function can accept relative local paths. + pub fn get_entry(&self, uri: &str) -> Option> { + if is_cloud_url(uri) { + self.entries.read().unwrap().get(uri).map(Arc::clone) + } else { + let uri = std::fs::canonicalize(uri).unwrap(); + self.entries + .read() + .unwrap() + .get(uri.to_str().unwrap()) + .map(Arc::clone) + } + } +} + +pub fn get_env_file_cache_ttl() -> u64 { + std::env::var("POLARS_FILE_CACHE_TTL") + .map(|x| x.parse::().expect("integer")) + .unwrap_or(60 * 60) +} diff --git a/crates/polars-io/src/file_cache/cache_lock.rs b/crates/polars-io/src/file_cache/cache_lock.rs new file mode 100644 index 000000000000..6cb320bbefdb --- /dev/null +++ b/crates/polars-io/src/file_cache/cache_lock.rs @@ -0,0 +1,222 @@ +use std::sync::atomic::AtomicBool; +use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::time::Duration; + +use fs4::fs_std::FileExt; + +use super::utils::FILE_CACHE_PREFIX; +use crate::pl_async; + +pub(super) static GLOBAL_FILE_CACHE_LOCK: LazyLock = LazyLock::new(|| { + let path = FILE_CACHE_PREFIX.join(".process-lock"); + + let file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(false) + .open(path) + .map_err(|err| { + panic!("failed to open/create global file cache lockfile: {}", err); + }) + .unwrap(); + + let at_bool = Arc::new(AtomicBool::new(false)); + // Holding this access tracker prevents the background task from + // unlocking the lock. + let access_tracker = AccessTracker(at_bool.clone()); + let notify_lock_acquired = Arc::new(tokio::sync::Notify::new()); + let notify_lock_acquired_2 = notify_lock_acquired.clone(); + + pl_async::get_runtime().spawn(async move { + let access_tracker = at_bool; + let notify_lock_acquired = notify_lock_acquired_2; + let verbose = false; + + loop { + if verbose { + eprintln!("file cache background unlock: waiting for acquisition notification"); + } + + notify_lock_acquired.notified().await; + + if verbose { + eprintln!("file cache background unlock: got acquisition notification"); + } + + loop { + if !access_tracker.swap(false, std::sync::atomic::Ordering::Relaxed) { + if let Some(unlocked_by_this_call) = GLOBAL_FILE_CACHE_LOCK.try_unlock() { + if unlocked_by_this_call && verbose { + eprintln!( + "file cache background unlock: unlocked global file cache lockfile" + ); + } + break; + } + } + tokio::time::sleep(Duration::from_secs(3)).await; + } + } + }); + + GlobalLock { + inner: RwLock::new(GlobalLockData { file, state: None }), + access_tracker, + notify_lock_acquired, + } +}); + +pub(super) enum LockedState { + /// Shared between threads and other processes. + Shared, + #[allow(dead_code)] + /// Locked exclusively by the eviction task of this process. + Eviction, +} + +#[allow(dead_code)] +pub(super) type GlobalFileCacheGuardAny<'a> = RwLockReadGuard<'a, GlobalLockData>; +pub(super) type GlobalFileCacheGuardExclusive<'a> = RwLockWriteGuard<'a, GlobalLockData>; + +pub(super) struct GlobalLockData { + file: std::fs::File, + state: Option, +} + +pub(super) struct GlobalLock { + inner: RwLock, + access_tracker: AccessTracker, + notify_lock_acquired: Arc, +} + +/// Tracks access to the global lock: +/// * The inner `bool` is used to delay the background unlock task from unlocking +/// the global lock until 3 seconds after the last lock attempt. +/// * The `Arc` ref-count is used as a semaphore that allows us to block exclusive +/// lock attempts while temporarily releasing the `RwLock`. +#[derive(Clone)] +struct AccessTracker(Arc); + +impl Drop for AccessTracker { + fn drop(&mut self) { + self.0.store(true, std::sync::atomic::Ordering::Relaxed); + } +} + +struct NotifyOnDrop(Arc); + +impl Drop for NotifyOnDrop { + fn drop(&mut self) { + self.0.notify_one(); + } +} + +impl GlobalLock { + fn get_access_tracker(&self) -> AccessTracker { + let at = self.access_tracker.clone(); + at.0.store(true, std::sync::atomic::Ordering::Relaxed); + at + } + + /// Returns + /// * `None` - Could be locked (ambiguous) + /// * `Some(true)` - Unlocked (by this function call) + /// * `Some(false)` - Unlocked (was not locked) + fn try_unlock(&self) -> Option { + if let Ok(mut this) = self.inner.try_write() { + if Arc::strong_count(&self.access_tracker.0) <= 2 { + return if this.state.take().is_some() { + FileExt::unlock(&this.file).unwrap(); + Some(true) + } else { + Some(false) + }; + } + } + None + } + + /// Acquire a shared lock. + pub(super) fn lock_shared(&self) -> GlobalFileCacheGuardAny { + let access_tracker = self.get_access_tracker(); + let _notify_on_drop = NotifyOnDrop(self.notify_lock_acquired.clone()); + + { + let this = self.inner.read().unwrap(); + + if let Some(LockedState::Shared) = this.state { + return this; + } + } + + { + let mut this = self.inner.write().unwrap(); + + if let Some(LockedState::Eviction) = this.state { + FileExt::unlock(&this.file).unwrap(); + this.state = None; + } + + if this.state.is_none() { + FileExt::lock_shared(&this.file).unwrap(); + this.state = Some(LockedState::Shared); + } + } + + // Safety: Holding the access tracker guard maintains an Arc refcount + // > 2, which prevents automatic unlock. + debug_assert!(Arc::strong_count(&access_tracker.0) > 2); + + { + let this = self.inner.read().unwrap(); + + if let Some(LockedState::Eviction) = this.state { + // Try again + drop(this); + return self.lock_shared(); + } + + assert!( + this.state.is_some(), + "impl error: global file cache lock was unlocked" + ); + this + } + } + + /// Acquire an exclusive lock on the cache directory. Holding this lock freezes + /// all cache operations except for reading from already-opened data files. + #[allow(dead_code)] + pub(super) fn try_lock_eviction(&self) -> Option { + let access_tracker = self.get_access_tracker(); + + if let Ok(mut this) = self.inner.try_write() { + if + // 3: + // * the Lazy + // * the global unlock background task + // * this function + Arc::strong_count(&access_tracker.0) > 3 { + return None; + } + + let _notify_on_drop = NotifyOnDrop(self.notify_lock_acquired.clone()); + + if let Some(ref state) = this.state { + if matches!(state, LockedState::Eviction) { + return Some(this); + } + } + + if this.state.take().is_some() { + FileExt::unlock(&this.file).unwrap(); + } + + if this.file.try_lock_exclusive().is_ok() { + this.state = Some(LockedState::Eviction); + return Some(this); + } + } + None + } +} diff --git a/crates/polars-io/src/file_cache/entry.rs b/crates/polars-io/src/file_cache/entry.rs new file mode 100644 index 000000000000..029882a1ba8d --- /dev/null +++ b/crates/polars-io/src/file_cache/entry.rs @@ -0,0 +1,432 @@ +use std::io::{Seek, SeekFrom}; +use std::path::{Path, PathBuf}; +use std::sync::atomic::AtomicU64; +use std::sync::{Arc, LazyLock, Mutex}; + +use fs4::fs_std::FileExt; +use polars_core::config; +use polars_error::{PolarsError, PolarsResult, polars_bail, to_compute_err}; + +use super::cache_lock::{self, GLOBAL_FILE_CACHE_LOCK}; +use super::file_fetcher::{FileFetcher, RemoteMetadata}; +use super::file_lock::{FileLock, FileLockAnyGuard}; +use super::metadata::{EntryMetadata, FileVersion}; +use super::utils::update_last_accessed; + +pub(super) const DATA_PREFIX: u8 = b'd'; +pub(super) const METADATA_PREFIX: u8 = b'm'; + +struct CachedData { + last_modified: u64, + metadata: Arc, + data_file_path: PathBuf, +} + +struct Inner { + uri: Arc, + uri_hash: String, + path_prefix: Arc, + metadata: FileLock, + cached_data: Option, + ttl: Arc, + file_fetcher: Arc, +} + +struct EntryData { + uri: Arc, + inner: Mutex, + ttl: Arc, +} + +pub struct FileCacheEntry(EntryData); + +impl EntryMetadata { + fn matches_remote_metadata(&self, remote_metadata: &RemoteMetadata) -> bool { + self.remote_version == remote_metadata.version && self.local_size == remote_metadata.size + } +} + +impl Inner { + fn try_open_assume_latest(&mut self) -> PolarsResult { + let verbose = config::verbose(); + + { + let cache_guard = GLOBAL_FILE_CACHE_LOCK.lock_shared(); + // We want to use an exclusive lock here to avoid an API call in the case where only the + // local TTL was updated. + let metadata_file = &mut self.metadata.acquire_exclusive().unwrap(); + update_last_accessed(metadata_file); + + if let Ok(metadata) = self.try_get_metadata(metadata_file, &cache_guard) { + let data_file_path = self.get_cached_data_file_path(); + + if metadata.compare_local_state(data_file_path).is_ok() { + if verbose { + eprintln!( + "[file_cache::entry] try_open_assume_latest: opening already fetched file for uri = {}", + self.uri.clone() + ); + } + return Ok(finish_open(data_file_path, metadata_file)); + } + } + } + + if verbose { + eprintln!( + "[file_cache::entry] try_open_assume_latest: did not find cached file for uri = {}", + self.uri.clone() + ); + } + + self.try_open_check_latest() + } + + fn try_open_check_latest(&mut self) -> PolarsResult { + let verbose = config::verbose(); + let remote_metadata = &self.file_fetcher.fetch_metadata()?; + let cache_guard = GLOBAL_FILE_CACHE_LOCK.lock_shared(); + + { + let metadata_file = &mut self.metadata.acquire_shared().unwrap(); + update_last_accessed(metadata_file); + + if let Ok(metadata) = self.try_get_metadata(metadata_file, &cache_guard) { + if metadata.matches_remote_metadata(remote_metadata) { + let data_file_path = self.get_cached_data_file_path(); + + if metadata.compare_local_state(data_file_path).is_ok() { + if verbose { + eprintln!( + "[file_cache::entry] try_open_check_latest: opening already fetched file for uri = {}", + self.uri.clone() + ); + } + return Ok(finish_open(data_file_path, metadata_file)); + } + } + } + } + + let metadata_file = &mut self.metadata.acquire_exclusive().unwrap(); + let metadata = self + .try_get_metadata(metadata_file, &cache_guard) + // Safety: `metadata_file` is an exclusive guard. + .unwrap_or_else(|_| { + Arc::new(EntryMetadata::new( + self.uri.clone(), + self.ttl.load(std::sync::atomic::Ordering::Relaxed), + )) + }); + + if metadata.matches_remote_metadata(remote_metadata) { + let data_file_path = self.get_cached_data_file_path(); + + if metadata.compare_local_state(data_file_path).is_ok() { + if verbose { + eprintln!( + "[file_cache::entry] try_open_check_latest: opening already fetched file (lost race) for uri = {}", + self.uri.clone() + ); + } + return Ok(finish_open(data_file_path, metadata_file)); + } + } + + if verbose { + eprintln!( + "[file_cache::entry] try_open_check_latest: fetching new data file for uri = {}, remote_version = {:?}, remote_size = {}", + self.uri.clone(), + remote_metadata.version, + remote_metadata.size + ); + } + + let data_file_path = &get_data_file_path( + self.path_prefix.to_str().unwrap().as_bytes(), + self.uri_hash.as_bytes(), + &remote_metadata.version, + ); + // Remove the file if it exists, since it doesn't match the metadata. + // This could be left from an aborted process. + let _ = std::fs::remove_file(data_file_path); + if !self.file_fetcher.fetches_as_symlink() { + let file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(data_file_path) + .map_err(PolarsError::from)?; + + // * Some(true) => always raise + // * Some(false) => never raise + // * None => do not raise if fallocate() is not permitted, otherwise raise. + static RAISE_ALLOC_ERROR: LazyLock> = LazyLock::new(|| { + let v = match std::env::var("POLARS_IGNORE_FILE_CACHE_ALLOCATE_ERROR").as_deref() { + Ok("1") => Some(false), + Ok("0") => Some(true), + Err(_) => None, + Ok(v) => panic!( + "invalid value {} for POLARS_IGNORE_FILE_CACHE_ALLOCATE_ERROR", + v + ), + }; + if config::verbose() { + eprintln!("[file_cache]: RAISE_ALLOC_ERROR: {:?}", v); + } + v + }); + + // Initialize it to get the verbose print + let raise_alloc_err = *RAISE_ALLOC_ERROR; + + file.lock_exclusive().unwrap(); + if let Err(e) = file.allocate(remote_metadata.size) { + let msg = format!( + "failed to reserve {} bytes on disk to download uri = {}: {:?}", + remote_metadata.size, + self.uri.as_ref(), + e + ); + + if raise_alloc_err == Some(true) + || (raise_alloc_err.is_none() && file.allocate(1).is_ok()) + { + polars_bail!(ComputeError: msg) + } else if config::verbose() { + eprintln!("[file_cache]: warning: {}", msg) + } + } + } + self.file_fetcher.fetch(data_file_path)?; + + // Don't do this on windows as it will break setting last accessed times. + #[cfg(target_family = "unix")] + if !self.file_fetcher.fetches_as_symlink() { + let mut perms = std::fs::metadata(data_file_path.clone()) + .unwrap() + .permissions(); + perms.set_readonly(true); + std::fs::set_permissions(data_file_path, perms).unwrap(); + } + + let data_file_metadata = std::fs::metadata(data_file_path).unwrap(); + let local_last_modified = super::utils::last_modified_u64(&data_file_metadata); + let local_size = data_file_metadata.len(); + + if local_size != remote_metadata.size { + polars_bail!(ComputeError: "downloaded file size ({}) does not match expected size ({})", local_size, remote_metadata.size); + } + + let mut metadata = metadata; + let metadata = Arc::make_mut(&mut metadata); + metadata.local_last_modified = local_last_modified; + metadata.local_size = local_size; + metadata.remote_version = remote_metadata.version.clone(); + + if let Err(e) = metadata.compare_local_state(data_file_path) { + panic!("metadata mismatch after file fetch: {}", e); + } + + let data_file = finish_open(data_file_path, metadata_file); + + metadata_file.set_len(0).unwrap(); + metadata_file.seek(SeekFrom::Start(0)).unwrap(); + metadata + .try_write(&mut **metadata_file) + .map_err(to_compute_err)?; + + Ok(data_file) + } + + /// Try to read the metadata from disk. If `F` is an exclusive guard, this + /// will update the TTL stored in the metadata file if it does not match. + fn try_get_metadata( + &mut self, + metadata_file: &mut F, + _cache_guard: &cache_lock::GlobalFileCacheGuardAny, + ) -> PolarsResult> { + let last_modified = super::utils::last_modified_u64(&metadata_file.metadata().unwrap()); + let ttl = self.ttl.load(std::sync::atomic::Ordering::Relaxed); + + for _ in 0..2 { + if let Some(ref cached) = self.cached_data { + if cached.last_modified == last_modified { + if cached.metadata.ttl != ttl { + polars_bail!(ComputeError: "TTL mismatch"); + } + + if cached.metadata.uri != self.uri { + unimplemented!( + "hash collision: uri1 = {}, uri2 = {}, hash = {}", + cached.metadata.uri, + self.uri, + self.uri_hash, + ); + } + + return Ok(cached.metadata.clone()); + } + } + + // Ensure cache is unset if read fails + self.cached_data = None; + + let mut metadata = + EntryMetadata::try_from_reader(&mut **metadata_file).map_err(to_compute_err)?; + + // Note this means if multiple processes on the same system set a + // different TTL for the same path, the metadata file will constantly + // get overwritten. + if metadata.ttl != ttl { + if F::IS_EXCLUSIVE { + metadata.ttl = ttl; + metadata_file.set_len(0).unwrap(); + metadata_file.seek(SeekFrom::Start(0)).unwrap(); + metadata + .try_write(&mut **metadata_file) + .map_err(to_compute_err)?; + } else { + polars_bail!(ComputeError: "TTL mismatch"); + } + } + + let metadata = Arc::new(metadata); + let data_file_path = get_data_file_path( + self.path_prefix.to_str().unwrap().as_bytes(), + self.uri_hash.as_bytes(), + &metadata.remote_version, + ); + self.cached_data = Some(CachedData { + last_modified, + metadata, + data_file_path, + }); + } + + unreachable!(); + } + + /// # Panics + /// Panics if `self.cached_data` is `None`. + fn get_cached_data_file_path(&self) -> &Path { + &self.cached_data.as_ref().unwrap().data_file_path + } +} + +impl FileCacheEntry { + pub(crate) fn new( + uri: Arc, + uri_hash: String, + path_prefix: Arc, + file_fetcher: Arc, + file_cache_ttl: u64, + ) -> Self { + let metadata = FileLock::from(get_metadata_file_path( + path_prefix.to_str().unwrap().as_bytes(), + uri_hash.as_bytes(), + )); + + debug_assert!( + Arc::ptr_eq(&uri, file_fetcher.get_uri()), + "impl error: entry uri != file_fetcher uri" + ); + + let ttl = Arc::new(AtomicU64::from(file_cache_ttl)); + + Self(EntryData { + uri: uri.clone(), + inner: Mutex::new(Inner { + uri, + uri_hash, + path_prefix, + metadata, + cached_data: None, + ttl: ttl.clone(), + file_fetcher, + }), + ttl, + }) + } + + pub fn uri(&self) -> &Arc { + &self.0.uri + } + + /// Directly returns the cached file if it finds one without checking if + /// there is a newer version on the remote. This does not make any API calls + /// if it finds a cached file, otherwise it simply downloads the file. + pub fn try_open_assume_latest(&self) -> PolarsResult { + self.0.inner.lock().unwrap().try_open_assume_latest() + } + + /// Returns the cached file after ensuring it is up to date against the remote + /// This will always perform at least 1 API call for fetching metadata. + pub fn try_open_check_latest(&self) -> PolarsResult { + self.0.inner.lock().unwrap().try_open_check_latest() + } + + pub fn update_ttl(&self, ttl: u64) { + self.0.ttl.store(ttl, std::sync::atomic::Ordering::Relaxed); + } +} + +fn finish_open(data_file_path: &Path, _metadata_guard: &F) -> std::fs::File { + let file = { + #[cfg(not(target_family = "windows"))] + { + std::fs::OpenOptions::new() + .read(true) + .open(data_file_path) + .unwrap() + } + // windows requires write access to update the last accessed time + #[cfg(target_family = "windows")] + { + std::fs::OpenOptions::new() + .read(true) + .write(true) + .open(data_file_path) + .unwrap() + } + }; + update_last_accessed(&file); + if FileExt::try_lock_shared(&file).is_err() { + panic!( + "finish_open: could not acquire shared lock on data file at {}", + data_file_path.to_str().unwrap() + ); + } + file +} + +/// `[prefix]/d/[uri hash][last modified]` +fn get_data_file_path( + path_prefix: &[u8], + uri_hash: &[u8], + remote_version: &FileVersion, +) -> PathBuf { + let owned; + let path = [ + path_prefix, + &[b'/', DATA_PREFIX, b'/'], + uri_hash, + match remote_version { + FileVersion::Timestamp(v) => { + owned = Some(format!("{:013x}", v)); + owned.as_deref().unwrap() + }, + FileVersion::ETag(v) => v.as_str(), + FileVersion::Uninitialized => panic!("impl error: version not initialized"), + } + .as_bytes(), + ] + .concat(); + PathBuf::from(String::from_utf8(path).unwrap()) +} + +/// `[prefix]/m/[uri hash]` +fn get_metadata_file_path(path_prefix: &[u8], uri_hash: &[u8]) -> PathBuf { + let bytes = [path_prefix, &[b'/', METADATA_PREFIX, b'/'], uri_hash].concat(); + PathBuf::from(String::from_utf8(bytes).unwrap()) +} diff --git a/crates/polars-io/src/file_cache/eviction.rs b/crates/polars-io/src/file_cache/eviction.rs new file mode 100644 index 000000000000..b69af7225f91 --- /dev/null +++ b/crates/polars-io/src/file_cache/eviction.rs @@ -0,0 +1,335 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use fs4::fs_std::FileExt; +use polars_error::{PolarsError, PolarsResult}; + +use super::cache_lock::{GLOBAL_FILE_CACHE_LOCK, GlobalFileCacheGuardExclusive}; +use super::metadata::EntryMetadata; +use crate::pl_async; + +#[derive(Debug, Clone)] +pub(super) struct EvictionCandidate { + path: PathBuf, + metadata_path: PathBuf, + metadata_last_modified: SystemTime, + ttl: u64, +} + +pub(super) struct EvictionManager { + pub(super) data_dir: Box, + pub(super) metadata_dir: Box, + pub(super) files_to_remove: Option>, + pub(super) min_ttl: Arc, + pub(super) notify_ttl_updated: Arc, +} + +impl EvictionCandidate { + fn update_ttl(&mut self) { + let Ok(metadata_last_modified) = + std::fs::metadata(&self.metadata_path).map(|md| md.modified().unwrap()) + else { + self.ttl = 0; + return; + }; + + if self.metadata_last_modified == metadata_last_modified { + return; + } + + let Ok(ref mut file) = std::fs::OpenOptions::new() + .read(true) + .open(&self.metadata_path) + else { + self.ttl = 0; + return; + }; + + let ttl = EntryMetadata::try_from_reader(file) + .map(|x| x.ttl) + .unwrap_or(0); + + self.metadata_last_modified = metadata_last_modified; + self.ttl = ttl; + } + + fn should_remove(&self, now: &SystemTime) -> bool { + let Ok(metadata) = std::fs::metadata(&self.path) else { + return false; + }; + + if let Ok(duration) = now.duration_since( + metadata + .accessed() + .unwrap_or_else(|_| metadata.modified().unwrap()), + ) { + duration.as_secs() >= self.ttl + } else { + false + } + } + + fn try_evict( + &mut self, + now: &SystemTime, + verbose: bool, + _guard: &GlobalFileCacheGuardExclusive, + ) { + self.update_ttl(); + let path = &self.path; + + if !path.exists() { + if verbose { + eprintln!( + "[EvictionManager] evict_files: skipping {} (path no longer exists)", + path.to_str().unwrap() + ); + } + return; + } + + let metadata = std::fs::metadata(path).unwrap(); + + let since_last_accessed = match now.duration_since( + metadata + .accessed() + .unwrap_or_else(|_| metadata.modified().unwrap()), + ) { + Ok(v) => v.as_secs(), + Err(_) => { + if verbose { + eprintln!( + "[EvictionManager] evict_files: skipping {} (last accessed time was updated)", + path.to_str().unwrap() + ); + } + return; + }, + }; + + if since_last_accessed < self.ttl { + if verbose { + eprintln!( + "[EvictionManager] evict_files: skipping {} (last accessed time was updated)", + path.to_str().unwrap() + ); + } + return; + } + + { + let file = std::fs::OpenOptions::new().read(true).open(path).unwrap(); + + if file.try_lock_exclusive().is_err() { + if verbose { + eprintln!( + "[EvictionManager] evict_files: skipping {} (file is locked)", + self.path.to_str().unwrap() + ); + } + return; + } + } + + if let Err(err) = std::fs::remove_file(path) { + if verbose { + eprintln!( + "[EvictionManager] evict_files: error removing file: {} ({})", + path.to_str().unwrap(), + err + ); + } + } else if verbose { + eprintln!( + "[EvictionManager] evict_files: removed file at {}", + path.to_str().unwrap() + ); + } + } +} + +impl EvictionManager { + /// # Safety + /// The following directories exist: + /// * `self.data_dir` + /// * `self.metadata_dir` + pub(super) fn run_in_background(mut self) { + let verbose = false; + + if verbose { + eprintln!( + "[EvictionManager] creating cache eviction background task, self.min_ttl = {}", + self.min_ttl.load(std::sync::atomic::Ordering::Relaxed) + ); + } + + pl_async::get_runtime().spawn(async move { + // Give some time at startup for other code to run. + tokio::time::sleep(Duration::from_secs(3)).await; + let mut last_eviction_time; + + loop { + let this: &'static mut Self = unsafe { std::mem::transmute(&mut self) }; + + let result = tokio::task::spawn_blocking(|| this.update_file_list()) + .await + .unwrap(); + + last_eviction_time = Instant::now(); + + match result { + Ok(_) if self.files_to_remove.as_ref().unwrap().is_empty() => {}, + Ok(_) => loop { + if let Some(guard) = GLOBAL_FILE_CACHE_LOCK.try_lock_eviction() { + if verbose { + eprintln!( + "[EvictionManager] got exclusive cache lock, evicting {} files", + self.files_to_remove.as_ref().unwrap().len() + ); + } + + tokio::task::block_in_place(|| self.evict_files(&guard)); + break; + } + tokio::time::sleep(Duration::from_secs(7)).await; + }, + Err(err) => { + if verbose { + eprintln!("[EvictionManager] error updating file list: {}", err); + } + }, + } + + loop { + let min_ttl = self.min_ttl.load(std::sync::atomic::Ordering::Relaxed); + let sleep_interval = std::cmp::max(min_ttl / 4, { + #[cfg(debug_assertions)] + { + 3 + } + #[cfg(not(debug_assertions))] + { + 60 + } + }); + + let since_last_eviction = + Instant::now().duration_since(last_eviction_time).as_secs(); + let sleep_interval = sleep_interval.saturating_sub(since_last_eviction); + let sleep_interval = Duration::from_secs(sleep_interval); + + tokio::select! { + _ = self.notify_ttl_updated.notified() => { + continue; + } + _ = tokio::time::sleep(sleep_interval) => { + break; + } + } + } + } + }); + } + + fn update_file_list(&mut self) -> PolarsResult<()> { + let data_files_iter = match std::fs::read_dir(self.data_dir.as_ref()) { + Ok(v) => v, + Err(e) => { + let msg = format!("failed to read data directory: {}", e); + + return Err(PolarsError::IO { + error: e.into(), + msg: Some(msg.into()), + }); + }, + }; + + let metadata_files_iter = match std::fs::read_dir(self.metadata_dir.as_ref()) { + Ok(v) => v, + Err(e) => { + let msg = format!("failed to read metadata directory: {}", e); + + return Err(PolarsError::IO { + error: e.into(), + msg: Some(msg.into()), + }); + }, + }; + + let mut files_to_remove = Vec::with_capacity( + data_files_iter + .size_hint() + .1 + .unwrap_or(data_files_iter.size_hint().0) + + metadata_files_iter + .size_hint() + .1 + .unwrap_or(metadata_files_iter.size_hint().0), + ); + + let now = SystemTime::now(); + + for file in data_files_iter { + let file = file?; + let path = file.path(); + + let hash = path + .file_name() + .unwrap() + .to_str() + .unwrap() + .get(..32) + .unwrap(); + let metadata_path = self.metadata_dir.join(hash); + + let mut eviction_candidate = EvictionCandidate { + path, + metadata_path, + metadata_last_modified: UNIX_EPOCH, + ttl: 0, + }; + eviction_candidate.update_ttl(); + + if eviction_candidate.should_remove(&now) { + files_to_remove.push(eviction_candidate); + } + } + + for file in metadata_files_iter { + let file = file?; + let path = file.path(); + let metadata_path = path.clone(); + + let mut eviction_candidate = EvictionCandidate { + path, + metadata_path, + metadata_last_modified: UNIX_EPOCH, + ttl: 0, + }; + + eviction_candidate.update_ttl(); + + if eviction_candidate.should_remove(&now) { + files_to_remove.push(eviction_candidate); + } + } + + self.files_to_remove = Some(files_to_remove); + + Ok(()) + } + + /// # Panics + /// Panics if `self.files_to_remove` is `None`. + fn evict_files(&mut self, _guard: &GlobalFileCacheGuardExclusive) { + let verbose = false; + let mut files_to_remove = self.files_to_remove.take().unwrap(); + let now = &SystemTime::now(); + + for eviction_candidate in files_to_remove.iter_mut() { + eviction_candidate.try_evict(now, verbose, _guard); + } + } +} diff --git a/crates/polars-io/src/file_cache/file_fetcher.rs b/crates/polars-io/src/file_cache/file_fetcher.rs new file mode 100644 index 000000000000..ca326ef2f374 --- /dev/null +++ b/crates/polars-io/src/file_cache/file_fetcher.rs @@ -0,0 +1,122 @@ +use std::sync::Arc; + +use polars_error::{PolarsError, PolarsResult}; + +use super::metadata::FileVersion; +use super::utils::last_modified_u64; +use crate::cloud::PolarsObjectStore; +use crate::pl_async; + +pub trait FileFetcher: Send + Sync { + fn get_uri(&self) -> &Arc; + fn fetch_metadata(&self) -> PolarsResult; + /// Fetches the object to a `local_path`. + fn fetch(&self, local_path: &std::path::Path) -> PolarsResult<()>; + fn fetches_as_symlink(&self) -> bool; +} + +pub struct RemoteMetadata { + pub size: u64, + pub(super) version: FileVersion, +} + +/// A struct that fetches data from local disk and stores it into the `cache`. +/// Mostly used for debugging, it only ever gets called if `POLARS_FORCE_ASYNC` is set. +pub(super) struct LocalFileFetcher { + uri: Arc, + path: Box, +} + +impl LocalFileFetcher { + pub(super) fn from_uri(uri: Arc) -> Self { + let path = std::path::PathBuf::from(uri.as_ref()).into_boxed_path(); + debug_assert_eq!( + path, + std::fs::canonicalize(&path).unwrap().into_boxed_path() + ); + + Self { uri, path } + } +} + +impl FileFetcher for LocalFileFetcher { + fn get_uri(&self) -> &Arc { + &self.uri + } + + fn fetches_as_symlink(&self) -> bool { + #[cfg(target_family = "unix")] + { + true + } + #[cfg(not(target_family = "unix"))] + { + false + } + } + + fn fetch_metadata(&self) -> PolarsResult { + let metadata = std::fs::metadata(&self.path).map_err(PolarsError::from)?; + + Ok(RemoteMetadata { + size: metadata.len(), + version: FileVersion::Timestamp(last_modified_u64(&metadata)), + }) + } + + fn fetch(&self, local_path: &std::path::Path) -> PolarsResult<()> { + #[cfg(target_family = "unix")] + { + std::os::unix::fs::symlink(&self.path, local_path).map_err(PolarsError::from) + } + #[cfg(not(target_family = "unix"))] + { + std::fs::copy(&self.path, local_path).map_err(PolarsError::from)?; + Ok(()) + } + } +} + +pub(super) struct CloudFileFetcher { + pub(super) uri: Arc, + pub(super) cloud_path: object_store::path::Path, + pub(super) object_store: PolarsObjectStore, +} + +impl FileFetcher for CloudFileFetcher { + fn get_uri(&self) -> &Arc { + &self.uri + } + + fn fetches_as_symlink(&self) -> bool { + false + } + + fn fetch_metadata(&self) -> PolarsResult { + let metadata = + pl_async::get_runtime().block_in_place_on(self.object_store.head(&self.cloud_path))?; + + Ok(RemoteMetadata { + size: metadata.size as u64, + version: metadata + .e_tag + .map(|x| FileVersion::ETag(blake3::hash(x.as_bytes()).to_hex()[..32].to_string())) + .unwrap_or_else(|| { + FileVersion::Timestamp(metadata.last_modified.timestamp_millis() as u64) + }), + }) + } + + fn fetch(&self, local_path: &std::path::Path) -> PolarsResult<()> { + pl_async::get_runtime().block_in_place_on(async { + let file = &mut tokio::fs::OpenOptions::new() + .write(true) + .truncate(true) + .open(local_path) + .await + .map_err(PolarsError::from)?; + + self.object_store.download(&self.cloud_path, file).await + }) + } +} diff --git a/crates/polars-io/src/file_cache/file_lock.rs b/crates/polars-io/src/file_cache/file_lock.rs new file mode 100644 index 000000000000..e52b559c2b88 --- /dev/null +++ b/crates/polars-io/src/file_cache/file_lock.rs @@ -0,0 +1,91 @@ +use std::fs::{File, OpenOptions}; +use std::path::Path; + +use fs4::fs_std::FileExt; + +/// Note: this creates the file if it does not exist when acquiring locks. +pub(super) struct FileLock>(T); +pub(super) struct FileLockSharedGuard(File); +pub(super) struct FileLockExclusiveGuard(File); + +/// Trait to specify a file is lock-guarded without needing a particular type of +/// guard (i.e. shared/exclusive). +pub(super) trait FileLockAnyGuard: + std::ops::Deref + std::ops::DerefMut +{ + const IS_EXCLUSIVE: bool; +} +impl FileLockAnyGuard for FileLockSharedGuard { + const IS_EXCLUSIVE: bool = false; +} +impl FileLockAnyGuard for FileLockExclusiveGuard { + const IS_EXCLUSIVE: bool = true; +} + +impl> From for FileLock { + fn from(path: T) -> Self { + Self(path) + } +} + +impl> FileLock { + pub(super) fn acquire_shared(&self) -> Result { + let file = OpenOptions::new() + .create(true) + .truncate(false) + .read(true) + .write(true) + .open(self.0.as_ref())?; + FileExt::lock_shared(&file).map(|_| FileLockSharedGuard(file)) + } + + pub(super) fn acquire_exclusive(&self) -> Result { + let file = OpenOptions::new() + .create(true) + .truncate(false) + .read(true) + .write(true) + .open(self.0.as_ref())?; + file.lock_exclusive().map(|_| FileLockExclusiveGuard(file)) + } +} + +impl std::ops::Deref for FileLockSharedGuard { + type Target = File; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for FileLockSharedGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Drop for FileLockSharedGuard { + fn drop(&mut self) { + FileExt::unlock(&self.0).unwrap(); + } +} + +impl std::ops::Deref for FileLockExclusiveGuard { + type Target = File; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for FileLockExclusiveGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Drop for FileLockExclusiveGuard { + fn drop(&mut self) { + FileExt::unlock(&self.0).unwrap(); + } +} diff --git a/crates/polars-io/src/file_cache/metadata.rs b/crates/polars-io/src/file_cache/metadata.rs new file mode 100644 index 000000000000..571b68940478 --- /dev/null +++ b/crates/polars-io/src/file_cache/metadata.rs @@ -0,0 +1,93 @@ +use std::path::Path; +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub(super) enum FileVersion { + Timestamp(u64), + ETag(String), + Uninitialized, +} + +#[derive(Debug)] +pub enum LocalCompareError { + LastModifiedMismatch { expected: u64, actual: u64 }, + SizeMismatch { expected: u64, actual: u64 }, + DataFileReadError(std::io::Error), +} + +pub type LocalCompareResult = Result<(), LocalCompareError>; + +/// Metadata written to a file used to track state / synchronize across processes. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub(super) struct EntryMetadata { + pub(super) uri: Arc, + pub(super) local_last_modified: u64, + pub(super) local_size: u64, + pub(super) remote_version: FileVersion, + /// TTL since last access, in seconds. + pub(super) ttl: u64, +} + +impl std::fmt::Display for LocalCompareError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::LastModifiedMismatch { expected, actual } => write!( + f, + "last modified time mismatch: expected {}, found {}", + expected, actual + ), + Self::SizeMismatch { expected, actual } => { + write!(f, "size mismatch: expected {}, found {}", expected, actual) + }, + Self::DataFileReadError(err) => { + write!(f, "failed to read local file metadata: {}", err) + }, + } + } +} + +impl EntryMetadata { + pub(super) fn new(uri: Arc, ttl: u64) -> Self { + Self { + uri, + local_last_modified: 0, + local_size: 0, + remote_version: FileVersion::Uninitialized, + ttl, + } + } + + pub(super) fn compare_local_state(&self, data_file_path: &Path) -> LocalCompareResult { + let metadata = match std::fs::metadata(data_file_path) { + Ok(v) => v, + Err(e) => return Err(LocalCompareError::DataFileReadError(e)), + }; + + let local_last_modified = super::utils::last_modified_u64(&metadata); + let local_size = metadata.len(); + + if local_last_modified != self.local_last_modified { + Err(LocalCompareError::LastModifiedMismatch { + expected: self.local_last_modified, + actual: local_last_modified, + }) + } else if local_size != self.local_size { + Err(LocalCompareError::SizeMismatch { + expected: self.local_size, + actual: local_size, + }) + } else { + Ok(()) + } + } + + pub(super) fn try_write(&self, writer: &mut W) -> serde_json::Result<()> { + serde_json::to_writer(writer, self) + } + + pub(super) fn try_from_reader(reader: &mut R) -> serde_json::Result { + serde_json::from_reader(reader) + } +} diff --git a/crates/polars-io/src/file_cache/mod.rs b/crates/polars-io/src/file_cache/mod.rs new file mode 100644 index 000000000000..0445b48f6db0 --- /dev/null +++ b/crates/polars-io/src/file_cache/mod.rs @@ -0,0 +1,11 @@ +mod cache; +mod cache_lock; +mod entry; +mod eviction; +mod file_fetcher; +mod file_lock; +mod metadata; +mod utils; +pub use cache::{FILE_CACHE, get_env_file_cache_ttl}; +pub use entry::FileCacheEntry; +pub use utils::{FILE_CACHE_PREFIX, init_entries_from_uri_list}; diff --git a/crates/polars-io/src/file_cache/utils.rs b/crates/polars-io/src/file_cache/utils.rs new file mode 100644 index 000000000000..48d45af9f0bc --- /dev/null +++ b/crates/polars-io/src/file_cache/utils.rs @@ -0,0 +1,129 @@ +use std::path::Path; +use std::sync::{Arc, LazyLock}; +use std::time::UNIX_EPOCH; + +use polars_error::{PolarsError, PolarsResult}; + +use super::cache::{FILE_CACHE, get_env_file_cache_ttl}; +use super::entry::FileCacheEntry; +use super::file_fetcher::{CloudFileFetcher, LocalFileFetcher}; +use crate::cloud::{CloudLocation, CloudOptions, build_object_store, object_path_from_str}; +use crate::path_utils::{POLARS_TEMP_DIR_BASE_PATH, ensure_directory_init, is_cloud_url}; +use crate::pl_async; + +pub static FILE_CACHE_PREFIX: LazyLock> = LazyLock::new(|| { + let path = POLARS_TEMP_DIR_BASE_PATH + .join("file-cache/") + .into_boxed_path(); + + if let Err(err) = ensure_directory_init(path.as_ref()) { + panic!( + "failed to create file cache directory: path = {}, err = {}", + path.to_str().unwrap(), + err + ); + } + + path +}); + +pub(super) fn last_modified_u64(metadata: &std::fs::Metadata) -> u64 { + metadata + .modified() + .unwrap() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as u64 +} + +pub(super) fn update_last_accessed(file: &std::fs::File) { + let file_metadata = file.metadata().unwrap(); + + if let Err(e) = file.set_times( + std::fs::FileTimes::new() + .set_modified(file_metadata.modified().unwrap()) + .set_accessed(std::time::SystemTime::now()), + ) { + panic!("failed to update file last accessed time: {}", e); + } +} + +pub fn init_entries_from_uri_list( + uri_list: &[Arc], + cloud_options: Option<&CloudOptions>, +) -> PolarsResult>> { + if uri_list.is_empty() { + return Ok(Default::default()); + } + + let first_uri = uri_list.first().unwrap().as_ref(); + + let file_cache_ttl = cloud_options + .map(|x| x.file_cache_ttl) + .unwrap_or_else(get_env_file_cache_ttl); + + if is_cloud_url(first_uri) { + let object_stores = pl_async::get_runtime().block_in_place_on(async { + futures::future::try_join_all( + (0..if first_uri.starts_with("http") { + // Object stores for http are tied to the path. + uri_list.len() + } else { + 1 + }) + .map(|i| async move { + let (_, object_store) = + build_object_store(&uri_list[i], cloud_options, false).await?; + PolarsResult::Ok(object_store) + }), + ) + .await + })?; + + uri_list + .iter() + .enumerate() + .map(|(i, uri)| { + FILE_CACHE.init_entry( + uri.clone(), + || { + let CloudLocation { prefix, .. } = + CloudLocation::new(uri.as_ref(), false).unwrap(); + let cloud_path = object_path_from_str(&prefix)?; + + let object_store = + object_stores[std::cmp::min(i, object_stores.len() - 1)].clone(); + let uri = uri.clone(); + + Ok(Arc::new(CloudFileFetcher { + uri, + object_store, + cloud_path, + })) + }, + file_cache_ttl, + ) + }) + .collect::>>() + } else { + uri_list + .iter() + .map(|uri| { + let uri = std::fs::canonicalize(uri.as_ref()).map_err(|err| { + let msg = Some(format!("{}: {}", err, uri.as_ref()).into()); + PolarsError::IO { + error: err.into(), + msg, + } + })?; + let uri = Arc::::from(uri.to_str().unwrap()); + + FILE_CACHE.init_entry( + uri.clone(), + || Ok(Arc::new(LocalFileFetcher::from_uri(uri.clone()))), + file_cache_ttl, + ) + }) + .collect::>>() + } +} diff --git a/crates/polars-io/src/hive.rs b/crates/polars-io/src/hive.rs new file mode 100644 index 000000000000..0efb16f6e926 --- /dev/null +++ b/crates/polars-io/src/hive.rs @@ -0,0 +1,134 @@ +use polars_core::frame::DataFrame; +use polars_core::frame::column::ScalarColumn; +use polars_core::prelude::Column; +use polars_core::series::Series; + +/// Materializes hive partitions. +/// We have a special num_rows arg, as df can be empty when a projection contains +/// only hive partition columns. +/// +/// The `hive_partition_columns` must be ordered by their position in the `reader_schema`. The +/// columns will be materialized by their positions in the file schema if they exist, or otherwise +/// at the end. +/// +/// # Safety +/// +/// num_rows equals the height of the df when the df height is non-zero. +pub(crate) fn materialize_hive_partitions( + df: &mut DataFrame, + reader_schema: &polars_schema::Schema, + hive_partition_columns: Option<&[Series]>, +) { + let num_rows = df.height(); + + if let Some(hive_columns) = hive_partition_columns { + // Insert these hive columns in the order they are stored in the file. + if hive_columns.is_empty() { + return; + } + + let hive_columns = hive_columns + .iter() + .map(|s| ScalarColumn::new(s.name().clone(), s.first(), num_rows).into()) + .collect::>(); + + if reader_schema.index_of(hive_columns[0].name()).is_none() || df.width() == 0 { + // Fast-path - all hive columns are at the end + if df.width() == 0 { + unsafe { df.set_height(num_rows) }; + } + unsafe { df.hstack_mut_unchecked(&hive_columns) }; + return; + } + + let mut merged = Vec::with_capacity(df.width() + hive_columns.len()); + + // `hive_partitions_from_paths()` guarantees `hive_columns` is sorted by their appearance in `reader_schema`. + merge_sorted_to_schema_order( + &mut unsafe { df.get_columns_mut().drain(..) }, + &mut hive_columns.into_iter(), + reader_schema, + &mut merged, + ); + + *df = unsafe { DataFrame::new_no_checks(num_rows, merged) }; + } +} + +/// Merge 2 lists of columns into one, where each list contains columns ordered such that their indices +/// in the `schema` are in ascending order. +/// +/// Layouts: +/// * `cols_lhs`: `[row_index?, ..schema_columns?, ..other_left?]` +/// * If the first item in `cols_lhs` is not found in the schema, it will be assumed to be a +/// `row_index` column and placed first into the result. +/// * `cols_rhs`: `[..schema_columns? ..other_right?]` +/// +/// Output: +/// * `[..schema_columns?, ..other_left?, ..other_right?]` +/// +/// Note: The `row_index` column should be handled before calling this function. +/// +/// # Panics +/// Panics if either `cols_lhs` or `cols_rhs` is empty. +pub fn merge_sorted_to_schema_order<'a, D>( + cols_lhs: &'a mut dyn Iterator, + cols_rhs: &'a mut dyn Iterator, + schema: &polars_schema::Schema, + output: &'a mut Vec, +) { + merge_sorted_to_schema_order_impl(cols_lhs, cols_rhs, output, &|v| schema.index_of(v.name())) +} + +pub fn merge_sorted_to_schema_order_impl<'a, T, O>( + cols_lhs: &'a mut dyn Iterator, + cols_rhs: &'a mut dyn Iterator, + output: &mut O, + get_opt_index: &dyn for<'b> Fn(&'b T) -> Option, +) where + O: Extend, +{ + let mut series_arr = [cols_lhs.peekable(), cols_rhs.peekable()]; + + (|| { + let (Some(a), Some(b)) = ( + series_arr[0] + .peek() + .and_then(|x| get_opt_index(x).or(Some(0))), + series_arr[1].peek().and_then(get_opt_index), + ) else { + return; + }; + + let mut schema_idx_arr = [a, b]; + + loop { + // Take from the side whose next column appears earlier in the `schema`. + let arg_min = if schema_idx_arr[1] < schema_idx_arr[0] { + 1 + } else { + 0 + }; + + output.extend([series_arr[arg_min].next().unwrap()]); + + let Some(v) = series_arr[arg_min].peek() else { + return; + }; + + let Some(i) = get_opt_index(v) else { + // All columns in `cols_lhs` should be present in `schema` except for a row_index column. + // We assume that if a row_index column exists it is always the first column and handle that at + // initialization. + debug_assert_eq!(arg_min, 1); + break; + }; + + schema_idx_arr[arg_min] = i; + } + })(); + + let [a, b] = series_arr; + output.extend(a); + output.extend(b); +} diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs new file mode 100644 index 000000000000..ce95aadba8af --- /dev/null +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -0,0 +1,319 @@ +//! # (De)serializing Arrows IPC format. +//! +//! Arrow IPC is a [binary format](https://arrow.apache.org/docs/python/ipc.html). +//! It is the recommended way to serialize and deserialize Polars DataFrames as this is most true +//! to the data schema. +//! +//! ## Example +//! +//! ```rust +//! use polars_core::prelude::*; +//! use polars_io::prelude::*; +//! use std::io::Cursor; +//! +//! +//! let s0 = Column::new("days".into(), &[0, 1, 2, 3, 4]); +//! let s1 = Column::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); +//! let mut df = DataFrame::new(vec![s0, s1]).unwrap(); +//! +//! // Create an in memory file handler. +//! // Vec: Read + Write +//! // Cursor: Seek +//! +//! let mut buf: Cursor> = Cursor::new(Vec::new()); +//! +//! // write to the in memory buffer +//! IpcWriter::new(&mut buf).finish(&mut df).expect("ipc writer"); +//! +//! // reset the buffers index after writing to the beginning of the buffer +//! buf.set_position(0); +//! +//! // read the buffer into a DataFrame +//! let df_read = IpcReader::new(buf).finish().unwrap(); +//! assert!(df.equals(&df_read)); +//! ``` +use std::io::{Read, Seek}; +use std::path::PathBuf; + +use arrow::datatypes::{ArrowSchemaRef, Metadata}; +use arrow::io::ipc::read::{self, get_row_count}; +use arrow::record_batch::RecordBatch; +use polars_core::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::RowIndex; +use crate::hive::materialize_hive_partitions; +use crate::mmap::MmapBytesReader; +use crate::predicates::PhysicalIoExpr; +use crate::prelude::*; +use crate::shared::{ArrowReader, finish_reader}; + +#[derive(Clone, Debug, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct IpcScanOptions; + +/// Read Arrows IPC format into a DataFrame +/// +/// # Example +/// ``` +/// use polars_core::prelude::*; +/// use std::fs::File; +/// use polars_io::ipc::IpcReader; +/// use polars_io::SerReader; +/// +/// fn example() -> PolarsResult { +/// let file = File::open("file.ipc").expect("file not found"); +/// +/// IpcReader::new(file) +/// .finish() +/// } +/// ``` +#[must_use] +pub struct IpcReader { + /// File or Stream object + pub(super) reader: R, + /// Aggregates chunks afterwards to a single chunk. + rechunk: bool, + pub(super) n_rows: Option, + pub(super) projection: Option>, + pub(crate) columns: Option>, + hive_partition_columns: Option>, + include_file_path: Option<(PlSmallStr, Arc)>, + pub(super) row_index: Option, + // Stores the as key semaphore to make sure we don't write to the memory mapped file. + pub(super) memory_map: Option, + metadata: Option, + schema: Option, +} + +fn check_mmap_err(err: PolarsError) -> PolarsResult<()> { + if let PolarsError::ComputeError(s) = &err { + if s.as_ref() == "memory_map can only be done on uncompressed IPC files" { + eprintln!( + "Could not memory_map compressed IPC file, defaulting to normal read. \ + Toggle off 'memory_map' to silence this warning." + ); + return Ok(()); + } + } + Err(err) +} + +impl IpcReader { + fn get_metadata(&mut self) -> PolarsResult<&read::FileMetadata> { + if self.metadata.is_none() { + let metadata = read::read_file_metadata(&mut self.reader)?; + self.schema = Some(metadata.schema.clone()); + self.metadata = Some(metadata); + } + Ok(self.metadata.as_ref().unwrap()) + } + + /// Get arrow schema of the Ipc File. + pub fn schema(&mut self) -> PolarsResult { + self.get_metadata()?; + Ok(self.schema.as_ref().unwrap().clone()) + } + + /// Get schema-level custom metadata of the Ipc file + pub fn custom_metadata(&mut self) -> PolarsResult>> { + self.get_metadata()?; + Ok(self + .metadata + .as_ref() + .and_then(|meta| meta.custom_schema_metadata.clone())) + } + + /// Stop reading when `n` rows are read. + pub fn with_n_rows(mut self, num_rows: Option) -> Self { + self.n_rows = num_rows; + self + } + + /// Columns to select/ project + pub fn with_columns(mut self, columns: Option>) -> Self { + self.columns = columns; + self + } + + pub fn with_hive_partition_columns(mut self, columns: Option>) -> Self { + self.hive_partition_columns = columns; + self + } + + pub fn with_include_file_path( + mut self, + include_file_path: Option<(PlSmallStr, Arc)>, + ) -> Self { + self.include_file_path = include_file_path; + self + } + + /// Add a row index column. + pub fn with_row_index(mut self, row_index: Option) -> Self { + self.row_index = row_index; + self + } + + /// Set the reader's column projection. This counts from 0, meaning that + /// `vec![0, 4]` would select the 1st and 5th column. + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + /// Set if the file is to be memory_mapped. Only works with uncompressed files. + /// The file name must be passed to register the memory mapped file. + pub fn memory_mapped(mut self, path_buf: Option) -> Self { + self.memory_map = path_buf; + self + } + + // todo! hoist to lazy crate + #[cfg(feature = "lazy")] + pub fn finish_with_scan_ops( + mut self, + predicate: Option>, + verbose: bool, + ) -> PolarsResult { + if self.memory_map.is_some() && self.reader.to_file().is_some() { + if verbose { + eprintln!("memory map ipc file") + } + match self.finish_memmapped(predicate.clone()) { + Ok(df) => return Ok(df), + Err(err) => check_mmap_err(err)?, + } + } + let rechunk = self.rechunk; + let metadata = read::read_file_metadata(&mut self.reader)?; + + // NOTE: For some code paths this already happened. See + // https://github.com/pola-rs/polars/pull/14984#discussion_r1520125000 + // where this was introduced. + if let Some(columns) = &self.columns { + self.projection = Some(columns_to_projection(columns, &metadata.schema)?); + } + + let schema = if let Some(projection) = &self.projection { + Arc::new(apply_projection(&metadata.schema, projection)) + } else { + metadata.schema.clone() + }; + + let reader = read::FileReader::new(self.reader, metadata, self.projection, self.n_rows); + + finish_reader(reader, rechunk, None, predicate, &schema, self.row_index) + } +} + +impl ArrowReader for read::FileReader +where + R: Read + Seek, +{ + fn next_record_batch(&mut self) -> PolarsResult> { + self.next().map_or(Ok(None), |v| v.map(Some)) + } +} + +impl SerReader for IpcReader { + fn new(reader: R) -> Self { + IpcReader { + reader, + rechunk: true, + n_rows: None, + columns: None, + hive_partition_columns: None, + include_file_path: None, + projection: None, + row_index: None, + memory_map: None, + metadata: None, + schema: None, + } + } + + fn set_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + fn finish(mut self) -> PolarsResult { + let reader_schema = if let Some(ref schema) = self.schema { + schema.clone() + } else { + self.get_metadata()?.schema.clone() + }; + let reader_schema = reader_schema.as_ref(); + + let hive_partition_columns = self.hive_partition_columns.take(); + let include_file_path = self.include_file_path.take(); + + // In case only hive columns are projected, the df would be empty, but we need the row count + // of the file in order to project the correct number of rows for the hive columns. + let mut df = (|| { + if self.projection.as_ref().is_some_and(|x| x.is_empty()) { + let row_count = if let Some(v) = self.n_rows { + v + } else { + get_row_count(&mut self.reader)? as usize + }; + let mut df = DataFrame::empty_with_height(row_count); + + if let Some(ri) = &self.row_index { + unsafe { df.with_row_index_mut(ri.name.clone(), Some(ri.offset)) }; + } + return PolarsResult::Ok(df); + } + + if self.memory_map.is_some() && self.reader.to_file().is_some() { + match self.finish_memmapped(None) { + Ok(df) => { + return Ok(df); + }, + Err(err) => check_mmap_err(err)?, + } + } + let rechunk = self.rechunk; + let schema = self.get_metadata()?.schema.clone(); + + if let Some(columns) = &self.columns { + let prj = columns_to_projection(columns, schema.as_ref())?; + self.projection = Some(prj); + } + + let schema = if let Some(projection) = &self.projection { + Arc::new(apply_projection(schema.as_ref(), projection)) + } else { + schema + }; + + let metadata = self.get_metadata()?.clone(); + + let ipc_reader = + read::FileReader::new(self.reader, metadata, self.projection, self.n_rows); + let df = finish_reader(ipc_reader, rechunk, None, None, &schema, self.row_index)?; + Ok(df) + })()?; + + if let Some(hive_cols) = hive_partition_columns { + materialize_hive_partitions(&mut df, reader_schema, Some(hive_cols.as_slice())); + }; + + if let Some((col, value)) = include_file_path { + unsafe { + df.with_column_unchecked(Column::new_scalar( + col, + Scalar::new( + DataType::String, + AnyValue::StringOwned(value.as_ref().into()), + ), + df.height(), + )) + }; + } + + Ok(df) + } +} diff --git a/crates/polars-io/src/ipc/ipc_reader_async.rs b/crates/polars-io/src/ipc/ipc_reader_async.rs new file mode 100644 index 000000000000..94d85d4c4766 --- /dev/null +++ b/crates/polars-io/src/ipc/ipc_reader_async.rs @@ -0,0 +1,209 @@ +use std::sync::Arc; + +use arrow::io::ipc::read::{FileMetadata, OutOfSpecKind, get_row_count}; +use object_store::ObjectMeta; +use object_store::path::Path; +use polars_core::datatypes::IDX_DTYPE; +use polars_core::frame::DataFrame; +use polars_core::schema::{Schema, SchemaExt}; +use polars_error::{PolarsResult, polars_bail, polars_err, to_compute_err}; +use polars_utils::mmap::MMapSemaphore; +use polars_utils::pl_str::PlSmallStr; + +use crate::RowIndex; +use crate::cloud::{ + CloudLocation, CloudOptions, PolarsObjectStore, build_object_store, object_path_from_str, +}; +use crate::file_cache::{FileCacheEntry, init_entries_from_uri_list}; +use crate::predicates::PhysicalIoExpr; +use crate::prelude::{IpcReader, materialize_projection}; +use crate::shared::SerReader; + +/// An Arrow IPC reader implemented on top of PolarsObjectStore. +pub struct IpcReaderAsync { + store: PolarsObjectStore, + cache_entry: Arc, + path: Path, +} + +#[derive(Default, Clone)] +pub struct IpcReadOptions { + // Names of the columns to include in the output. + projection: Option>, + + // The maximum number of rows to include in the output. + row_limit: Option, + + // Include a column with the row number under the provided name starting at the provided index. + row_index: Option, + + // Only include rows that pass this predicate. + predicate: Option>, +} + +impl IpcReadOptions { + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + pub fn with_row_limit(mut self, row_limit: impl Into>) -> Self { + self.row_limit = row_limit.into(); + self + } + + pub fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.row_index = row_index.into(); + self + } + + pub fn with_predicate(mut self, predicate: impl Into>>) -> Self { + self.predicate = predicate.into(); + self + } +} + +impl IpcReaderAsync { + pub async fn from_uri( + uri: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let cache_entry = init_entries_from_uri_list(&[Arc::from(uri)], cloud_options)?[0].clone(); + let (CloudLocation { prefix, .. }, store) = + build_object_store(uri, cloud_options, false).await?; + + let path = object_path_from_str(&prefix)?; + + Ok(Self { + store, + cache_entry, + path, + }) + } + + async fn object_metadata(&self) -> PolarsResult { + self.store.head(&self.path).await + } + + async fn file_size(&self) -> PolarsResult { + Ok(self.object_metadata().await?.size) + } + + pub async fn metadata(&self) -> PolarsResult { + let file_size = self.file_size().await?; + + // TODO: Do a larger request and hope that the entire footer is contained within it to save one round-trip. + let footer_metadata = + self.store + .get_range( + &self.path, + file_size.checked_sub(FOOTER_METADATA_SIZE).ok_or_else(|| { + to_compute_err("ipc file size is smaller than the minimum") + })?..file_size, + ) + .await?; + + let footer_size = deserialize_footer_metadata( + footer_metadata + .as_ref() + .try_into() + .map_err(to_compute_err)?, + )?; + + let footer = self + .store + .get_range( + &self.path, + file_size + .checked_sub(FOOTER_METADATA_SIZE + footer_size) + .ok_or_else(|| { + to_compute_err("invalid ipc footer metadata: footer size too large") + })?..file_size, + ) + .await?; + + arrow::io::ipc::read::deserialize_footer( + footer.as_ref(), + footer_size.try_into().map_err(to_compute_err)?, + ) + } + + pub async fn data( + &self, + metadata: Option<&FileMetadata>, + options: IpcReadOptions, + verbose: bool, + ) -> PolarsResult { + // TODO: Only download what is needed rather than the entire file by + // making use of the projection, row limit, predicate and such. + let file = tokio::task::block_in_place(|| self.cache_entry.try_open_check_latest())?; + let bytes = MMapSemaphore::new_from_file(&file).unwrap(); + + let projection = match options.projection.as_deref() { + Some(projection) => { + fn prepare_schema(mut schema: Schema, row_index: Option<&RowIndex>) -> Schema { + if let Some(rc) = row_index { + let _ = schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE); + } + schema + } + + // Retrieve the metadata for the schema so we can map column names to indices. + let fetched_metadata; + let metadata = if let Some(metadata) = metadata { + metadata + } else { + // This branch is happens when _metadata is None, which can happen if we Deserialize the execution plan. + fetched_metadata = self.metadata().await?; + &fetched_metadata + }; + + let schema = prepare_schema( + Schema::from_arrow_schema(metadata.schema.as_ref()), + options.row_index.as_ref(), + ); + + let hive_partitions = None; + + materialize_projection( + Some(projection), + &schema, + hive_partitions, + options.row_index.is_some(), + ) + }, + None => None, + }; + + let reader = as SerReader<_>>::new(std::io::Cursor::new(bytes.as_ref())) + .with_row_index(options.row_index) + .with_n_rows(options.row_limit) + .with_projection(projection); + reader.finish_with_scan_ops(options.predicate, verbose) + } + + pub async fn count_rows(&self, _metadata: Option<&FileMetadata>) -> PolarsResult { + // TODO: Only download what is needed rather than the entire file by + // making use of the projection, row limit, predicate and such. + let file = tokio::task::block_in_place(|| self.cache_entry.try_open_check_latest())?; + let bytes = MMapSemaphore::new_from_file(&file).unwrap(); + get_row_count(&mut std::io::Cursor::new(bytes.as_ref())) + } +} + +const FOOTER_METADATA_SIZE: usize = 10; + +// TODO: Move to polars-arrow and deduplicate parsing of footer metadata in +// sync and async readers. +fn deserialize_footer_metadata(bytes: [u8; FOOTER_METADATA_SIZE]) -> PolarsResult { + let footer_size: usize = + i32::from_le_bytes(bytes[0..4].try_into().unwrap_or_else(|_| unreachable!())) + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + if &bytes[4..] != b"ARROW1" { + polars_bail!(oos = OutOfSpecKind::InvalidFooter); + } + + Ok(footer_size) +} diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs new file mode 100644 index 000000000000..1c2e39143172 --- /dev/null +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -0,0 +1,321 @@ +//! # (De)serializing Arrows Streaming IPC format. +//! +//! Arrow Streaming IPC is a [binary format](https://arrow.apache.org/docs/python/ipc.html). +//! It used for sending an arbitrary length sequence of record batches. +//! The format must be processed from start to end, and does not support random access. +//! It is different than IPC, if you can't deserialize a file with `IpcReader::new`, it's probably an IPC Stream File. +//! +//! ## Example +//! +//! ```rust +//! use polars_core::prelude::*; +//! use polars_io::prelude::*; +//! use std::io::Cursor; +//! +//! +//! let c0 = Column::new("days".into(), &[0, 1, 2, 3, 4]); +//! let c1 = Column::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); +//! let mut df = DataFrame::new(vec![c0, c1]).unwrap(); +//! +//! // Create an in memory file handler. +//! // Vec: Read + Write +//! // Cursor: Seek +//! +//! let mut buf: Cursor> = Cursor::new(Vec::new()); +//! +//! // write to the in memory buffer +//! IpcStreamWriter::new(&mut buf).finish(&mut df).expect("ipc writer"); +//! +//! // reset the buffers index after writing to the beginning of the buffer +//! buf.set_position(0); +//! +//! // read the buffer into a DataFrame +//! let df_read = IpcStreamReader::new(buf).finish().unwrap(); +//! assert!(df.equals(&df_read)); +//! ``` +use std::io::{Read, Write}; +use std::path::PathBuf; + +use arrow::datatypes::Metadata; +use arrow::io::ipc::read::{StreamMetadata, StreamState}; +use arrow::io::ipc::write::WriteOptions; +use arrow::io::ipc::{read, write}; +use polars_core::frame::chunk_df_for_writing; +use polars_core::prelude::*; + +use crate::prelude::*; +use crate::shared::{ArrowReader, finish_reader}; + +/// Read Arrows Stream IPC format into a DataFrame +/// +/// # Example +/// ``` +/// use polars_core::prelude::*; +/// use std::fs::File; +/// use polars_io::ipc::IpcStreamReader; +/// use polars_io::SerReader; +/// +/// fn example() -> PolarsResult { +/// let file = File::open("file.ipc").expect("file not found"); +/// +/// IpcStreamReader::new(file) +/// .finish() +/// } +/// ``` +#[must_use] +pub struct IpcStreamReader { + /// File or Stream object + reader: R, + /// Aggregates chunks afterwards to a single chunk. + rechunk: bool, + n_rows: Option, + projection: Option>, + columns: Option>, + row_index: Option, + metadata: Option, +} + +impl IpcStreamReader { + /// Get schema of the Ipc Stream File + pub fn schema(&mut self) -> PolarsResult { + Ok(Schema::from_arrow_schema(&self.metadata()?.schema)) + } + + /// Get arrow schema of the Ipc Stream File, this is faster than creating a polars schema. + pub fn arrow_schema(&mut self) -> PolarsResult { + Ok(self.metadata()?.schema) + } + + /// Get schema-level custom metadata of the Ipc Stream file + pub fn custom_metadata(&mut self) -> PolarsResult>> { + Ok(self.metadata()?.custom_schema_metadata.map(Arc::new)) + } + + /// Stop reading when `n` rows are read. + pub fn with_n_rows(mut self, num_rows: Option) -> Self { + self.n_rows = num_rows; + self + } + + /// Columns to select/ project + pub fn with_columns(mut self, columns: Option>) -> Self { + self.columns = columns; + self + } + + /// Add a row index column. + pub fn with_row_index(mut self, row_index: Option) -> Self { + self.row_index = row_index; + self + } + + /// Set the reader's column projection. This counts from 0, meaning that + /// `vec![0, 4]` would select the 1st and 5th column. + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + fn metadata(&mut self) -> PolarsResult { + match &self.metadata { + None => { + let metadata = read::read_stream_metadata(&mut self.reader)?; + self.metadata = Option::from(metadata.clone()); + Ok(metadata) + }, + Some(md) => Ok(md.clone()), + } + } +} + +impl ArrowReader for read::StreamReader +where + R: Read, +{ + fn next_record_batch(&mut self) -> PolarsResult> { + self.next().map_or(Ok(None), |v| match v { + Ok(stream_state) => match stream_state { + StreamState::Waiting => Ok(None), + StreamState::Some(chunk) => Ok(Some(chunk)), + }, + Err(err) => Err(err), + }) + } +} + +impl SerReader for IpcStreamReader +where + R: Read, +{ + fn new(reader: R) -> Self { + IpcStreamReader { + reader, + rechunk: true, + n_rows: None, + columns: None, + projection: None, + row_index: None, + metadata: None, + } + } + + fn set_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + fn finish(mut self) -> PolarsResult { + let rechunk = self.rechunk; + let metadata = self.metadata()?; + let schema = &metadata.schema; + + if let Some(columns) = self.columns { + let prj = columns_to_projection(&columns, schema)?; + self.projection = Some(prj); + } + + let schema = if let Some(projection) = &self.projection { + apply_projection(&metadata.schema, projection) + } else { + metadata.schema.clone() + }; + + let ipc_reader = + read::StreamReader::new(&mut self.reader, metadata.clone(), self.projection); + finish_reader( + ipc_reader, + rechunk, + self.n_rows, + None, + &schema, + self.row_index, + ) + } +} + +/// Write a DataFrame to Arrow's Streaming IPC format +/// +/// # Example +/// +/// ``` +/// use polars_core::prelude::*; +/// use polars_io::ipc::IpcStreamWriter; +/// use std::fs::File; +/// use polars_io::SerWriter; +/// +/// fn example(df: &mut DataFrame) -> PolarsResult<()> { +/// let mut file = File::create("file.ipc").expect("could not create file"); +/// +/// let mut writer = IpcStreamWriter::new(&mut file); +/// +/// let custom_metadata = [ +/// ("first_name".into(), "John".into()), +/// ("last_name".into(), "Doe".into()), +/// ] +/// .into_iter() +/// .collect(); +/// writer.set_custom_schema_metadata(Arc::new(custom_metadata)); +/// +/// writer.finish(df) +/// } +/// +/// ``` +#[must_use] +pub struct IpcStreamWriter { + writer: W, + compression: Option, + compat_level: CompatLevel, + /// Custom schema-level metadata + custom_schema_metadata: Option>, +} + +use arrow::record_batch::RecordBatch; + +use crate::RowIndex; + +impl IpcStreamWriter { + /// Set the compression used. Defaults to None. + pub fn with_compression(mut self, compression: Option) -> Self { + self.compression = compression; + self + } + + pub fn with_compat_level(mut self, compat_level: CompatLevel) -> Self { + self.compat_level = compat_level; + self + } + + /// Sets custom schema metadata. Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } +} + +impl SerWriter for IpcStreamWriter +where + W: Write, +{ + fn new(writer: W) -> Self { + IpcStreamWriter { + writer, + compression: None, + compat_level: CompatLevel::oldest(), + custom_schema_metadata: None, + } + } + + fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { + let mut ipc_stream_writer = write::StreamWriter::new( + &mut self.writer, + WriteOptions { + compression: self.compression.map(|c| c.into()), + }, + ); + + if let Some(custom_metadata) = &self.custom_schema_metadata { + ipc_stream_writer.set_custom_schema_metadata(Arc::clone(custom_metadata)); + } + + ipc_stream_writer.start(&df.schema().to_arrow(self.compat_level), None)?; + let df = chunk_df_for_writing(df, 512 * 512)?; + let iter = df.iter_chunks(self.compat_level, true); + + for batch in iter { + ipc_stream_writer.write(&batch, None)? + } + ipc_stream_writer.finish()?; + Ok(()) + } +} + +pub struct IpcStreamWriterOption { + compression: Option, + extension: PathBuf, +} + +impl IpcStreamWriterOption { + pub fn new() -> Self { + Self { + compression: None, + extension: PathBuf::from(".ipc"), + } + } + + /// Set the compression used. Defaults to None. + pub fn with_compression(mut self, compression: Option) -> Self { + self.compression = compression; + self + } + + /// Set the extension. Defaults to ".ipc". + pub fn with_extension(mut self, extension: PathBuf) -> Self { + self.extension = extension; + self + } +} + +impl Default for IpcStreamWriterOption { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs new file mode 100644 index 000000000000..30e20703eed9 --- /dev/null +++ b/crates/polars-io/src/ipc/mmap.rs @@ -0,0 +1,111 @@ +use arrow::io::ipc::read; +use arrow::io::ipc::read::{Dictionaries, FileMetadata}; +use arrow::mmap::{mmap_dictionaries_unchecked, mmap_unchecked}; +use arrow::record_batch::RecordBatch; +use polars_core::prelude::*; +use polars_utils::mmap::MMapSemaphore; + +use super::ipc_file::IpcReader; +use crate::mmap::MmapBytesReader; +use crate::predicates::PhysicalIoExpr; +use crate::shared::{ArrowReader, finish_reader}; +use crate::utils::{apply_projection, columns_to_projection}; + +impl IpcReader { + pub(super) fn finish_memmapped( + &mut self, + predicate: Option>, + ) -> PolarsResult { + match self.reader.to_file() { + Some(file) => { + let semaphore = MMapSemaphore::new_from_file(file)?; + let metadata = + read::read_file_metadata(&mut std::io::Cursor::new(semaphore.as_ref()))?; + + if let Some(columns) = &self.columns { + let schema = &metadata.schema; + let prj = columns_to_projection(columns, schema)?; + self.projection = Some(prj); + } + + let schema = if let Some(projection) = &self.projection { + Arc::new(apply_projection(&metadata.schema, projection)) + } else { + metadata.schema.clone() + }; + + let reader = MMapChunkIter::new(Arc::new(semaphore), metadata, &self.projection)?; + + finish_reader( + reader, + // don't rechunk, that would trigger a read. + false, + self.n_rows, + predicate, + &schema, + self.row_index.clone(), + ) + }, + None => polars_bail!(ComputeError: "cannot memory-map, you must provide a file"), + } + } +} + +struct MMapChunkIter<'a> { + dictionaries: Dictionaries, + metadata: FileMetadata, + mmap: Arc, + idx: usize, + end: usize, + projection: &'a Option>, +} + +impl<'a> MMapChunkIter<'a> { + fn new( + mmap: Arc, + metadata: FileMetadata, + projection: &'a Option>, + ) -> PolarsResult { + let end = metadata.blocks.len(); + // mmap the dictionaries + let dictionaries = unsafe { mmap_dictionaries_unchecked(&metadata, mmap.clone())? }; + + Ok(Self { + dictionaries, + metadata, + mmap, + idx: 0, + end, + projection, + }) + } +} + +impl ArrowReader for MMapChunkIter<'_> { + fn next_record_batch(&mut self) -> PolarsResult> { + if self.idx < self.end { + let chunk = unsafe { + mmap_unchecked( + &self.metadata, + &self.dictionaries, + self.mmap.clone(), + self.idx, + ) + }?; + self.idx += 1; + let chunk = match &self.projection { + None => chunk, + Some(proj) => { + let length = chunk.len(); + let (schema, cols) = chunk.into_schema_and_arrays(); + let schema = schema.try_project_indices(proj).unwrap(); + let arrays = proj.iter().map(|i| cols[*i].clone()).collect(); + RecordBatch::new(length, Arc::new(schema), arrays) + }, + }; + Ok(Some(chunk)) + } else { + Ok(None) + } + } +} diff --git a/crates/polars-io/src/ipc/mod.rs b/crates/polars-io/src/ipc/mod.rs new file mode 100644 index 000000000000..1e341de98c56 --- /dev/null +++ b/crates/polars-io/src/ipc/mod.rs @@ -0,0 +1,16 @@ +#[cfg(feature = "ipc")] +mod ipc_file; +#[cfg(feature = "cloud")] +mod ipc_reader_async; +#[cfg(feature = "ipc_streaming")] +mod ipc_stream; +#[cfg(feature = "ipc")] +mod mmap; +mod write; +#[cfg(feature = "ipc")] +pub use ipc_file::{IpcReader, IpcScanOptions}; +#[cfg(feature = "cloud")] +pub use ipc_reader_async::*; +#[cfg(feature = "ipc_streaming")] +pub use ipc_stream::*; +pub use write::{BatchedWriter, IpcCompression, IpcWriter, IpcWriterOptions}; diff --git a/crates/polars-io/src/ipc/write.rs b/crates/polars-io/src/ipc/write.rs new file mode 100644 index 000000000000..cdebaa7285d8 --- /dev/null +++ b/crates/polars-io/src/ipc/write.rs @@ -0,0 +1,215 @@ +use std::io::Write; + +use arrow::datatypes::Metadata; +use arrow::io::ipc::write::{self, EncodedData, WriteOptions}; +use polars_core::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::prelude::*; +use crate::shared::schema_to_arrow_checked; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct IpcWriterOptions { + /// Data page compression + pub compression: Option, + /// Compatibility level + pub compat_level: CompatLevel, + /// Size of each written chunk. + pub chunk_size: IdxSize, +} + +impl Default for IpcWriterOptions { + fn default() -> Self { + Self { + compression: None, + compat_level: CompatLevel::newest(), + chunk_size: 1 << 18, + } + } +} + +impl IpcWriterOptions { + pub fn to_writer(&self, writer: W) -> IpcWriter { + IpcWriter::new(writer).with_compression(self.compression) + } +} + +/// Write a DataFrame to Arrow's IPC format +/// +/// # Example +/// +/// ``` +/// use polars_core::prelude::*; +/// use polars_io::ipc::IpcWriter; +/// use std::fs::File; +/// use polars_io::SerWriter; +/// +/// fn example(df: &mut DataFrame) -> PolarsResult<()> { +/// let mut file = File::create("file.ipc").expect("could not create file"); +/// +/// let mut writer = IpcWriter::new(&mut file); +/// +/// let custom_metadata = [ +/// ("first_name".into(), "John".into()), +/// ("last_name".into(), "Doe".into()), +/// ] +/// .into_iter() +/// .collect(); +/// writer.set_custom_schema_metadata(Arc::new(custom_metadata)); +/// writer.finish(df) +/// } +/// +/// ``` +#[must_use] +pub struct IpcWriter { + pub(super) writer: W, + pub(super) compression: Option, + /// Polars' flavor of arrow. This might be temporary. + pub(super) compat_level: CompatLevel, + pub(super) parallel: bool, + pub(super) custom_schema_metadata: Option>, +} + +impl IpcWriter { + /// Set the compression used. Defaults to None. + pub fn with_compression(mut self, compression: Option) -> Self { + self.compression = compression; + self + } + + pub fn with_compat_level(mut self, compat_level: CompatLevel) -> Self { + self.compat_level = compat_level; + self + } + + pub fn with_parallel(mut self, parallel: bool) -> Self { + self.parallel = parallel; + self + } + + pub fn batched(self, schema: &Schema) -> PolarsResult> { + let schema = schema_to_arrow_checked(schema, self.compat_level, "ipc")?; + let mut writer = write::FileWriter::new( + self.writer, + Arc::new(schema), + None, + WriteOptions { + compression: self.compression.map(|c| c.into()), + }, + ); + writer.start()?; + + Ok(BatchedWriter { + writer, + compat_level: self.compat_level, + }) + } + + /// Sets custom schema metadata. Must be called before `start` is called + pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc) { + self.custom_schema_metadata = Some(custom_metadata); + } +} + +impl SerWriter for IpcWriter +where + W: Write, +{ + fn new(writer: W) -> Self { + IpcWriter { + writer, + compression: None, + compat_level: CompatLevel::newest(), + parallel: true, + custom_schema_metadata: None, + } + } + + fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { + let schema = schema_to_arrow_checked(df.schema(), self.compat_level, "ipc")?; + let mut ipc_writer = write::FileWriter::try_new( + &mut self.writer, + Arc::new(schema), + None, + WriteOptions { + compression: self.compression.map(|c| c.into()), + }, + )?; + if let Some(custom_metadata) = &self.custom_schema_metadata { + ipc_writer.set_custom_schema_metadata(Arc::clone(custom_metadata)); + } + + if self.parallel { + df.align_chunks_par(); + } else { + df.align_chunks(); + } + let iter = df.iter_chunks(self.compat_level, true); + + for batch in iter { + ipc_writer.write(&batch, None)? + } + ipc_writer.finish()?; + Ok(()) + } +} + +pub struct BatchedWriter { + writer: write::FileWriter, + compat_level: CompatLevel, +} + +impl BatchedWriter { + /// Write a batch to the ipc writer. + /// + /// # Panics + /// The caller must ensure the chunks in the given [`DataFrame`] are aligned. + pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + let iter = df.iter_chunks(self.compat_level, true); + for batch in iter { + self.writer.write(&batch, None)? + } + Ok(()) + } + + /// Write a encoded data to the ipc writer. + /// + /// # Panics + /// The caller must ensure the chunks in the given [`DataFrame`] are aligned. + pub fn write_encoded( + &mut self, + dictionaries: &[EncodedData], + message: &EncodedData, + ) -> PolarsResult<()> { + self.writer.write_encoded(dictionaries, message)?; + Ok(()) + } + + /// Writes the footer of the IPC file. + pub fn finish(&mut self) -> PolarsResult<()> { + self.writer.finish()?; + Ok(()) + } +} + +/// Compression codec +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum IpcCompression { + /// LZ4 (framed) + LZ4, + /// ZSTD + #[default] + ZSTD, +} + +impl From for write::Compression { + fn from(value: IpcCompression) -> Self { + match value { + IpcCompression::LZ4 => write::Compression::LZ4, + IpcCompression::ZSTD => write::Compression::ZSTD, + } + } +} diff --git a/crates/polars-io/src/json/infer.rs b/crates/polars-io/src/json/infer.rs new file mode 100644 index 000000000000..fd804ea18c81 --- /dev/null +++ b/crates/polars-io/src/json/infer.rs @@ -0,0 +1,36 @@ +use std::num::NonZeroUsize; + +use polars_core::prelude::DataType; +use polars_core::utils::try_get_supertype; +use polars_error::{PolarsResult, polars_bail}; +use simd_json::BorrowedValue; + +pub(crate) fn json_values_to_supertype( + values: &[BorrowedValue], + infer_schema_len: NonZeroUsize, +) -> PolarsResult { + // struct types may have missing fields so find supertype + values + .iter() + .take(infer_schema_len.into()) + .map(|value| polars_json::json::infer(value).map(|dt| DataType::from_arrow_dtype(&dt))) + .reduce(|l, r| { + let l = l?; + let r = r?; + try_get_supertype(&l, &r) + }) + .unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type")) +} + +pub(crate) fn dtypes_to_supertype>( + datatypes: I, +) -> PolarsResult { + datatypes + .map(Ok) + .reduce(|l, r| { + let l = l?; + let r = r?; + try_get_supertype(&l, &r) + }) + .unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type")) +} diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs new file mode 100644 index 000000000000..9a466ef579be --- /dev/null +++ b/crates/polars-io/src/json/mod.rs @@ -0,0 +1,440 @@ +//! # (De)serialize JSON files. +//! +//! ## Read JSON to a DataFrame +//! +//! ## Example +//! +//! ``` +//! use polars_core::prelude::*; +//! use polars_io::prelude::*; +//! use std::io::Cursor; +//! use std::num::NonZeroUsize; +//! +//! let basic_json = r#"{"a":1, "b":2.0, "c":false, "d":"4"} +//! {"a":-10, "b":-3.5, "c":true, "d":"4"} +//! {"a":2, "b":0.6, "c":false, "d":"text"} +//! {"a":1, "b":2.0, "c":false, "d":"4"} +//! {"a":7, "b":-3.5, "c":true, "d":"4"} +//! {"a":1, "b":0.6, "c":false, "d":"text"} +//! {"a":1, "b":2.0, "c":false, "d":"4"} +//! {"a":5, "b":-3.5, "c":true, "d":"4"} +//! {"a":1, "b":0.6, "c":false, "d":"text"} +//! {"a":1, "b":2.0, "c":false, "d":"4"} +//! {"a":1, "b":-3.5, "c":true, "d":"4"} +//! {"a":1, "b":0.6, "c":false, "d":"text"}"#; +//! let file = Cursor::new(basic_json); +//! let df = JsonReader::new(file) +//! .with_json_format(JsonFormat::JsonLines) +//! .infer_schema_len(NonZeroUsize::new(3)) +//! .with_batch_size(NonZeroUsize::new(3).unwrap()) +//! .finish() +//! .unwrap(); +//! +//! println!("{:?}", df); +//! ``` +//! >>> Outputs: +//! +//! ```text +//! +-----+--------+-------+--------+ +//! | a | b | c | d | +//! | --- | --- | --- | --- | +//! | i64 | f64 | bool | str | +//! +=====+========+=======+========+ +//! | 1 | 2 | false | "4" | +//! +-----+--------+-------+--------+ +//! | -10 | -3.5e0 | true | "4" | +//! +-----+--------+-------+--------+ +//! | 2 | 0.6 | false | "text" | +//! +-----+--------+-------+--------+ +//! | 1 | 2 | false | "4" | +//! +-----+--------+-------+--------+ +//! | 7 | -3.5e0 | true | "4" | +//! +-----+--------+-------+--------+ +//! | 1 | 0.6 | false | "text" | +//! +-----+--------+-------+--------+ +//! | 1 | 2 | false | "4" | +//! +-----+--------+-------+--------+ +//! | 5 | -3.5e0 | true | "4" | +//! +-----+--------+-------+--------+ +//! | 1 | 0.6 | false | "text" | +//! +-----+--------+-------+--------+ +//! | 1 | 2 | false | "4" | +//! +-----+--------+-------+--------+ +//! ``` +//! +pub(crate) mod infer; + +use std::io::Write; +use std::num::NonZeroUsize; +use std::ops::Deref; + +use arrow::legacy::conversion::chunk_to_struct; +use polars_core::error::to_compute_err; +use polars_core::prelude::*; +use polars_error::{PolarsResult, polars_bail}; +use polars_json::json::write::FallibleStreamingIterator; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use simd_json::BorrowedValue; + +use crate::mmap::{MmapBytesReader, ReaderBytes}; +use crate::prelude::*; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct JsonWriterOptions {} + +/// The format to use to write the DataFrame to JSON: `Json` (a JSON array) +/// or `JsonLines` (each row output on a separate line). +/// +/// In either case, each row is serialized as a JSON object whose keys are the column names and +/// whose values are the row's corresponding values. +pub enum JsonFormat { + /// A single JSON array containing each DataFrame row as an object. The length of the array is the number of rows in + /// the DataFrame. + /// + /// Use this to create valid JSON that can be deserialized back into an array in one fell swoop. + Json, + /// Each DataFrame row is serialized as a JSON object on a separate line. The number of lines in the output is the + /// number of rows in the DataFrame. + /// + /// The [JSON Lines](https://jsonlines.org) format makes it easy to read records in a streaming fashion, one (line) + /// at a time. But the output in its entirety is not valid JSON; only the individual lines are. + /// + /// It is recommended to use the file extension `.jsonl` when saving as JSON Lines. + JsonLines, +} + +/// Writes a DataFrame to JSON. +/// +/// Under the hood, this uses [`arrow2::io::json`](https://docs.rs/arrow2/latest/arrow2/io/json/write/fn.write.html). +/// `arrow2` generally serializes types that are not JSON primitives, such as Date and DateTime, as their +/// `Display`-formatted versions. For instance, a (naive) DateTime column is formatted as the String `"yyyy-mm-dd +/// HH:MM:SS"`. To control how non-primitive columns are serialized, convert them to String or another primitive type +/// before serializing. +#[must_use] +pub struct JsonWriter { + /// File or Stream handler + buffer: W, + json_format: JsonFormat, +} + +impl JsonWriter { + pub fn with_json_format(mut self, format: JsonFormat) -> Self { + self.json_format = format; + self + } +} + +impl SerWriter for JsonWriter +where + W: Write, +{ + /// Create a new `JsonWriter` writing to `buffer` with format `JsonFormat::JsonLines`. To specify a different + /// format, use e.g., [`JsonWriter::new(buffer).with_json_format(JsonFormat::Json)`](JsonWriter::with_json_format). + fn new(buffer: W) -> Self { + JsonWriter { + buffer, + json_format: JsonFormat::JsonLines, + } + } + + fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { + df.align_chunks_par(); + let fields = df + .iter() + .map(|s| { + #[cfg(feature = "object")] + polars_ensure!(!matches!(s.dtype(), DataType::Object(_)), ComputeError: "cannot write 'Object' datatype to json"); + Ok(s.field().to_arrow(CompatLevel::newest())) + }) + .collect::>>()?; + let batches = df + .iter_chunks(CompatLevel::newest(), false) + .map(|chunk| Ok(Box::new(chunk_to_struct(chunk, fields.clone())) as ArrayRef)); + + match self.json_format { + JsonFormat::JsonLines => { + let serializer = polars_json::ndjson::write::Serializer::new(batches, vec![]); + let writer = + polars_json::ndjson::write::FileWriter::new(&mut self.buffer, serializer); + writer.collect::>()?; + }, + JsonFormat::Json => { + let serializer = polars_json::json::write::Serializer::new(batches, vec![]); + polars_json::json::write::write(&mut self.buffer, serializer)?; + }, + } + + Ok(()) + } +} + +pub struct BatchedWriter { + writer: W, +} + +impl BatchedWriter +where + W: Write, +{ + pub fn new(writer: W) -> Self { + BatchedWriter { writer } + } + /// Write a batch to the json writer. + /// + /// # Panics + /// The caller must ensure the chunks in the given [`DataFrame`] are aligned. + pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + let fields = df + .iter() + .map(|s| { + #[cfg(feature = "object")] + polars_ensure!(!matches!(s.dtype(), DataType::Object(_)), ComputeError: "cannot write 'Object' datatype to json"); + Ok(s.field().to_arrow(CompatLevel::newest())) + }) + .collect::>>()?; + let chunks = df.iter_chunks(CompatLevel::newest(), false); + let batches = + chunks.map(|chunk| Ok(Box::new(chunk_to_struct(chunk, fields.clone())) as ArrayRef)); + let mut serializer = polars_json::ndjson::write::Serializer::new(batches, vec![]); + while let Some(block) = serializer.next()? { + self.writer.write_all(block)?; + } + Ok(()) + } +} + +/// Reads JSON in one of the formats in [`JsonFormat`] into a DataFrame. +#[must_use] +pub struct JsonReader<'a, R> +where + R: MmapBytesReader, +{ + reader: R, + rechunk: bool, + ignore_errors: bool, + infer_schema_len: Option, + batch_size: NonZeroUsize, + projection: Option>, + schema: Option, + schema_overwrite: Option<&'a Schema>, + json_format: JsonFormat, +} + +pub fn remove_bom(bytes: &[u8]) -> PolarsResult<&[u8]> { + if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) { + // UTF-8 BOM + Ok(&bytes[3..]) + } else if bytes.starts_with(&[0xFE, 0xFF]) || bytes.starts_with(&[0xFF, 0xFE]) { + // UTF-16 BOM + polars_bail!(ComputeError: "utf-16 not supported") + } else { + Ok(bytes) + } +} +impl SerReader for JsonReader<'_, R> +where + R: MmapBytesReader, +{ + fn new(reader: R) -> Self { + JsonReader { + reader, + rechunk: true, + ignore_errors: false, + infer_schema_len: Some(NonZeroUsize::new(100).unwrap()), + batch_size: NonZeroUsize::new(8192).unwrap(), + projection: None, + schema: None, + schema_overwrite: None, + json_format: JsonFormat::Json, + } + } + + fn set_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + /// Take the SerReader and return a parsed DataFrame. + /// + /// Because JSON values specify their types (number, string, etc), no upcasting or conversion is performed between + /// incompatible types in the input. In the event that a column contains mixed dtypes, is it unspecified whether an + /// error is returned or whether elements of incompatible dtypes are replaced with `null`. + fn finish(mut self) -> PolarsResult { + let pre_rb: ReaderBytes = (&mut self.reader).into(); + let bytes = remove_bom(pre_rb.deref())?; + let rb = ReaderBytes::Borrowed(bytes); + let out = match self.json_format { + JsonFormat::Json => { + polars_ensure!(!self.ignore_errors, InvalidOperation: "'ignore_errors' only supported in ndjson"); + let mut bytes = rb.deref().to_vec(); + let owned = &mut vec![]; + compression::maybe_decompress_bytes(&bytes, owned)?; + // the easiest way to avoid ownership issues is by implicitly figuring out if + // decompression happened (owned is only populated on decompress), then pick which bytes to parse + let json_value = if owned.is_empty() { + simd_json::to_borrowed_value(&mut bytes).map_err(to_compute_err)? + } else { + simd_json::to_borrowed_value(owned).map_err(to_compute_err)? + }; + if let BorrowedValue::Array(array) = &json_value { + if array.is_empty() & self.schema.is_none() & self.schema_overwrite.is_none() { + return Ok(DataFrame::empty()); + } + } + + let allow_extra_fields_in_struct = self.schema.is_some(); + + // struct type + let dtype = if let Some(mut schema) = self.schema { + if let Some(overwrite) = self.schema_overwrite { + let mut_schema = Arc::make_mut(&mut schema); + overwrite_schema(mut_schema, overwrite)?; + } + + DataType::Struct(schema.iter_fields().collect()).to_arrow(CompatLevel::newest()) + } else { + // infer + let inner_dtype = if let BorrowedValue::Array(values) = &json_value { + infer::json_values_to_supertype( + values, + self.infer_schema_len + .unwrap_or(NonZeroUsize::new(usize::MAX).unwrap()), + )? + .to_arrow(CompatLevel::newest()) + } else { + polars_json::json::infer(&json_value)? + }; + + if let Some(overwrite) = self.schema_overwrite { + let ArrowDataType::Struct(fields) = inner_dtype else { + polars_bail!(ComputeError: "can only deserialize json objects") + }; + + let mut schema = Schema::from_iter(fields.iter().map(Into::::into)); + overwrite_schema(&mut schema, overwrite)?; + + DataType::Struct( + schema + .into_iter() + .map(|(name, dt)| Field::new(name, dt)) + .collect(), + ) + .to_arrow(CompatLevel::newest()) + } else { + inner_dtype + } + }; + + let dtype = if let BorrowedValue::Array(_) = &json_value { + ArrowDataType::LargeList(Box::new(arrow::datatypes::Field::new( + PlSmallStr::from_static("item"), + dtype, + true, + ))) + } else { + dtype + }; + + let arr = polars_json::json::deserialize( + &json_value, + dtype, + allow_extra_fields_in_struct, + )?; + let arr = arr.as_any().downcast_ref::().ok_or_else( + || polars_err!(ComputeError: "can only deserialize json objects"), + )?; + DataFrame::try_from(arr.clone()) + }, + JsonFormat::JsonLines => { + let mut json_reader = CoreJsonReader::new( + rb, + None, + self.schema, + self.schema_overwrite, + None, + 1024, // sample size + NonZeroUsize::new(1 << 18).unwrap(), + false, + self.infer_schema_len, + self.ignore_errors, + None, + None, + None, + )?; + let mut df: DataFrame = json_reader.as_df()?; + if self.rechunk { + df.as_single_chunk_par(); + } + Ok(df) + }, + }?; + + // TODO! Ensure we don't materialize the columns we don't need + if let Some(proj) = self.projection.as_deref() { + out.select(proj.iter().cloned()) + } else { + Ok(out) + } + } +} + +impl<'a, R> JsonReader<'a, R> +where + R: MmapBytesReader, +{ + /// Set the JSON file's schema + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + /// Overwrite parts of the inferred schema. + pub fn with_schema_overwrite(mut self, schema: &'a Schema) -> Self { + self.schema_overwrite = Some(schema); + self + } + + /// Set the JSON reader to infer the schema of the file. Currently, this is only used when reading from + /// [`JsonFormat::JsonLines`], as [`JsonFormat::Json`] reads in the entire array anyway. + /// + /// When using [`JsonFormat::JsonLines`], `max_records = None` will read the entire buffer in order to infer the + /// schema, `Some(1)` would look only at the first record, `Some(2)` the first two records, etc. + /// + /// It is an error to pass `max_records = Some(0)`, as a schema cannot be inferred from 0 records when deserializing + /// from JSON (unlike CSVs, there is no header row to inspect for column names). + pub fn infer_schema_len(mut self, max_records: Option) -> Self { + self.infer_schema_len = max_records; + self + } + + /// Set the batch size (number of records to load at one time) + /// + /// This heavily influences loading time. + pub fn with_batch_size(mut self, batch_size: NonZeroUsize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the reader's column projection: the names of the columns to keep after deserialization. If `None`, all + /// columns are kept. + /// + /// Setting `projection` to the columns you want to keep is more efficient than deserializing all of the columns and + /// then dropping the ones you don't want. + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + pub fn with_json_format(mut self, format: JsonFormat) -> Self { + self.json_format = format; + self + } + + /// Return a `null` if an error occurs during parsing. + pub fn with_ignore_errors(mut self, ignore: bool) -> Self { + self.ignore_errors = ignore; + self + } +} diff --git a/crates/polars-io/src/lib.rs b/crates/polars-io/src/lib.rs new file mode 100644 index 000000000000..20c57f69b17e --- /dev/null +++ b/crates/polars-io/src/lib.rs @@ -0,0 +1,41 @@ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![allow(ambiguous_glob_reexports)] +extern crate core; + +#[cfg(feature = "avro")] +pub mod avro; +#[cfg(feature = "catalog")] +pub mod catalog; +pub mod cloud; +#[cfg(any(feature = "csv", feature = "json"))] +pub mod csv; +#[cfg(feature = "file_cache")] +pub mod file_cache; +#[cfg(any(feature = "ipc", feature = "ipc_streaming"))] +pub mod ipc; +#[cfg(feature = "json")] +pub mod json; +pub mod mmap; +#[cfg(feature = "json")] +pub mod ndjson; +mod options; +#[cfg(feature = "parquet")] +pub mod parquet; +#[cfg(feature = "parquet")] +pub mod partition; +pub mod path_utils; +#[cfg(feature = "async")] +pub mod pl_async; +pub mod predicates; +pub mod prelude; +mod shared; +pub mod utils; + +#[cfg(feature = "cloud")] +pub use cloud::glob as async_glob; +pub use options::*; +pub use path_utils::*; +pub use shared::*; + +pub mod hive; diff --git a/crates/polars-io/src/mmap.rs b/crates/polars-io/src/mmap.rs new file mode 100644 index 000000000000..f7674941b8dc --- /dev/null +++ b/crates/polars-io/src/mmap.rs @@ -0,0 +1,126 @@ +use std::fs::File; +use std::io::{BufReader, Cursor, Read, Seek}; + +use polars_core::config::verbose; +use polars_utils::file::ClosableFile; +use polars_utils::mmap::MemSlice; + +/// Trait used to get a hold to file handler or to the underlying bytes +/// without performing a Read. +pub trait MmapBytesReader: Read + Seek + Send + Sync { + fn to_file(&self) -> Option<&File> { + None + } + + fn to_bytes(&self) -> Option<&[u8]> { + None + } +} + +impl MmapBytesReader for File { + fn to_file(&self) -> Option<&File> { + Some(self) + } +} + +impl MmapBytesReader for ClosableFile { + fn to_file(&self) -> Option<&File> { + Some(self.as_ref()) + } +} + +impl MmapBytesReader for BufReader { + fn to_file(&self) -> Option<&File> { + Some(self.get_ref()) + } +} + +impl MmapBytesReader for BufReader<&File> { + fn to_file(&self) -> Option<&File> { + Some(self.get_ref()) + } +} + +impl MmapBytesReader for Cursor +where + T: AsRef<[u8]> + Send + Sync, +{ + fn to_bytes(&self) -> Option<&[u8]> { + Some(self.get_ref().as_ref()) + } +} + +impl MmapBytesReader for Box { + fn to_file(&self) -> Option<&File> { + T::to_file(self) + } + + fn to_bytes(&self) -> Option<&[u8]> { + T::to_bytes(self) + } +} + +impl MmapBytesReader for &mut T { + fn to_file(&self) -> Option<&File> { + T::to_file(self) + } + + fn to_bytes(&self) -> Option<&[u8]> { + T::to_bytes(self) + } +} + +// Handle various forms of input bytes +pub enum ReaderBytes<'a> { + Borrowed(&'a [u8]), + Owned(MemSlice), +} + +impl std::ops::Deref for ReaderBytes<'_> { + type Target = [u8]; + fn deref(&self) -> &[u8] { + match self { + Self::Borrowed(ref_bytes) => ref_bytes, + Self::Owned(vec) => vec, + } + } +} + +/// There are some places that perform manual lifetime management after transmuting `ReaderBytes` +/// to have a `'static` inner lifetime. The advantage to doing this is that it lets you construct a +/// `MemSlice` from the `ReaderBytes` in a zero-copy manner regardless of the underlying enum +/// variant. +impl ReaderBytes<'static> { + /// Construct a `MemSlice` in a zero-copy manner from the underlying bytes, with the assumption + /// that the underlying bytes have a `'static` lifetime. + pub fn to_memslice(&self) -> MemSlice { + match self { + ReaderBytes::Borrowed(v) => MemSlice::from_static(v), + ReaderBytes::Owned(v) => v.clone(), + } + } +} + +impl<'a, T: 'a + MmapBytesReader> From<&'a mut T> for ReaderBytes<'a> { + fn from(m: &'a mut T) -> Self { + match m.to_bytes() { + // , but somehow bchk doesn't see that lifetime is 'a. + Some(s) => { + let s = unsafe { std::mem::transmute::<&[u8], &'a [u8]>(s) }; + ReaderBytes::Borrowed(s) + }, + None => { + if let Some(f) = m.to_file() { + ReaderBytes::Owned(MemSlice::from_file(f).unwrap()) + } else { + if verbose() { + eprintln!("could not memory map file; read to buffer.") + } + let mut buf = vec![]; + m.read_to_end(&mut buf).expect("could not read"); + ReaderBytes::Owned(MemSlice::from_vec(buf)) + } + }, + } + } +} diff --git a/crates/polars-io/src/ndjson/buffer.rs b/crates/polars-io/src/ndjson/buffer.rs new file mode 100644 index 000000000000..5c1a949d5944 --- /dev/null +++ b/crates/polars-io/src/ndjson/buffer.rs @@ -0,0 +1,397 @@ +use std::hash::{Hash, Hasher}; + +use arrow::types::NativeType; +use num_traits::NumCast; +use polars_core::frame::row::AnyValueBuffer; +use polars_core::prelude::*; +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +use polars_time::prelude::string::Pattern; +#[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] +use polars_time::prelude::string::infer::{DatetimeInfer, TryFromWithUnit, infer_pattern_single}; +use polars_utils::format_pl_smallstr; +use simd_json::{BorrowedValue as Value, KnownKey, StaticNode}; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct BufferKey<'a>(pub(crate) KnownKey<'a>); +impl Eq for BufferKey<'_> {} + +impl Hash for BufferKey<'_> { + fn hash(&self, state: &mut H) { + self.0.key().hash(state) + } +} + +pub(crate) struct Buffer<'a> { + name: &'a str, + ignore_errors: bool, + buf: AnyValueBuffer<'a>, +} + +impl Buffer<'_> { + pub fn into_series(self) -> PolarsResult { + let mut buf = self.buf; + let mut s = buf.reset(0, !self.ignore_errors)?; + s.rename(PlSmallStr::from_str(self.name)); + Ok(s) + } + + #[inline] + pub(crate) fn add(&mut self, value: &Value) -> PolarsResult<()> { + use AnyValueBuffer::*; + match &mut self.buf { + Boolean(buf) => { + match value { + Value::Static(StaticNode::Bool(b)) => buf.append_value(*b), + Value::Static(StaticNode::Null) => buf.append_null(), + _ if self.ignore_errors => buf.append_null(), + v => polars_bail!(ComputeError: "cannot parse '{}' as Boolean", v), + } + Ok(()) + }, + Int32(buf) => { + let n = deserialize_number::(value, self.ignore_errors)?; + match n { + Some(v) => buf.append_value(v), + None => buf.append_null(), + } + Ok(()) + }, + Int64(buf) => { + let n = deserialize_number::(value, self.ignore_errors)?; + match n { + Some(v) => buf.append_value(v), + None => buf.append_null(), + } + Ok(()) + }, + UInt64(buf) => { + let n = deserialize_number::(value, self.ignore_errors)?; + match n { + Some(v) => buf.append_value(v), + None => buf.append_null(), + } + Ok(()) + }, + UInt32(buf) => { + let n = deserialize_number::(value, self.ignore_errors)?; + match n { + Some(v) => buf.append_value(v), + None => buf.append_null(), + } + Ok(()) + }, + Float32(buf) => { + let n = deserialize_number::(value, self.ignore_errors)?; + match n { + Some(v) => buf.append_value(v), + None => buf.append_null(), + } + Ok(()) + }, + Float64(buf) => { + let n = deserialize_number::(value, self.ignore_errors)?; + match n { + Some(v) => buf.append_value(v), + None => buf.append_null(), + } + Ok(()) + }, + + String(buf) => { + match value { + Value::String(v) => buf.append_value(v), + Value::Static(StaticNode::Null) => buf.append_null(), + // Forcibly convert to String using the Display impl. + v => buf.append_value(format_pl_smallstr!("{}", ValueDisplay(v))), + } + Ok(()) + }, + #[cfg(feature = "dtype-datetime")] + Datetime(buf, tu, _) => { + let v = + deserialize_datetime::(value, "Datetime", self.ignore_errors, *tu)?; + buf.append_option(v); + Ok(()) + }, + #[cfg(feature = "dtype-date")] + Date(buf) => { + let v = deserialize_datetime::( + value, + "Date", + self.ignore_errors, + TimeUnit::Microseconds, // ignored + )?; + buf.append_option(v); + Ok(()) + }, + All(dtype, buf) => { + let av = deserialize_all(value, dtype, self.ignore_errors)?; + buf.push(av); + Ok(()) + }, + Null(builder) => { + if !(matches!(value, Value::Static(StaticNode::Null)) || self.ignore_errors) { + polars_bail!(ComputeError: "got non-null value for NULL-typed column: {}", value) + }; + + builder.append_null(); + Ok(()) + }, + _ => panic!("unexpected dtype when deserializing ndjson"), + } + } + + pub fn add_null(&mut self) { + self.buf.add(AnyValue::Null).expect("should not fail"); + } +} +pub(crate) fn init_buffers( + schema: &Schema, + capacity: usize, + ignore_errors: bool, +) -> PolarsResult> { + schema + .iter() + .map(|(name, dtype)| { + let av_buf = (dtype, capacity).into(); + let key = KnownKey::from(name.as_str()); + Ok(( + BufferKey(key), + Buffer { + name, + buf: av_buf, + ignore_errors, + }, + )) + }) + .collect() +} + +fn deserialize_number( + value: &Value, + ignore_errors: bool, +) -> PolarsResult> { + let to_result = |x: Option| { + let out = if ignore_errors { + x + } else { + Some(x.ok_or_else(|| { + polars_err!(ComputeError: "cannot parse '{}' as {:?}", value, T::PRIMITIVE + ) + })?) + }; + + Ok(out) + }; + + match value { + Value::Static(StaticNode::F64(f)) => to_result(num_traits::cast(*f)), + Value::Static(StaticNode::I64(i)) => to_result(num_traits::cast(*i)), + Value::Static(StaticNode::U64(u)) => to_result(num_traits::cast(*u)), + Value::Static(StaticNode::Bool(b)) => to_result(num_traits::cast(*b as i32)), + Value::Static(StaticNode::Null) => Ok(None), + _ => to_result(None), + } +} + +#[cfg(feature = "dtype-datetime")] +fn deserialize_datetime( + value: &Value, + type_name: &str, + ignore_errors: bool, + tu: TimeUnit, +) -> PolarsResult> +where + T: PolarsNumericType, + DatetimeInfer: TryFromWithUnit, +{ + match value { + Value::String(val) => { + if let Some(pattern) = infer_pattern_single(val) { + if let Ok(mut infer) = DatetimeInfer::try_from_with_unit(pattern, Some(tu)) { + if let Some(v) = infer.parse(val) { + return Ok(Some(v)); + } + } + } + }, + Value::Static(StaticNode::Null) => return Ok(None), + _ => {}, + }; + + if ignore_errors { + return Ok(None); + } + + polars_bail!(ComputeError: "cannot parse '{}' as {}", value, type_name) +} + +fn deserialize_all<'a>( + json: &Value, + dtype: &DataType, + ignore_errors: bool, +) -> PolarsResult> { + if let Value::Static(StaticNode::Null) = json { + return Ok(AnyValue::Null); + } + match dtype { + #[cfg(feature = "dtype-datetime")] + DataType::Date => { + let value = deserialize_datetime::( + json, + "Date", + ignore_errors, + TimeUnit::Microseconds, // ignored + )?; + return Ok(if let Some(value) = value { + AnyValue::Date(value) + } else { + AnyValue::Null + }); + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => { + let value = deserialize_datetime::(json, "Datetime", ignore_errors, *tu)?; + return Ok(if let Some(value) = value { + AnyValue::DatetimeOwned(value, *tu, tz.as_ref().map(|s| Arc::from(s.clone()))) + } else { + AnyValue::Null + }); + }, + DataType::Float32 => { + return Ok( + if let Some(v) = deserialize_number::(json, ignore_errors)? { + AnyValue::Float32(v) + } else { + AnyValue::Null + }, + ); + }, + DataType::Float64 => { + return Ok( + if let Some(v) = deserialize_number::(json, ignore_errors)? { + AnyValue::Float64(v) + } else { + AnyValue::Null + }, + ); + }, + DataType::String => { + return Ok(match json { + Value::String(s) => AnyValue::StringOwned(s.as_ref().into()), + v => AnyValue::StringOwned(format_pl_smallstr!("{}", ValueDisplay(v))), + }); + }, + dt if dt.is_primitive_numeric() => { + return Ok( + if let Some(v) = deserialize_number::(json, ignore_errors)? { + AnyValue::Int128(v).cast(dt).into_static() + } else { + AnyValue::Null + }, + ); + }, + _ => {}, + } + + let out = match json { + Value::Static(StaticNode::Bool(b)) => AnyValue::Boolean(*b), + Value::Static(StaticNode::I64(i)) => AnyValue::Int64(*i), + Value::Static(StaticNode::U64(u)) => AnyValue::UInt64(*u), + Value::Static(StaticNode::F64(f)) => AnyValue::Float64(*f), + Value::String(s) => AnyValue::StringOwned(s.as_ref().into()), + Value::Array(arr) => { + let Some(inner_dtype) = dtype.inner_dtype() else { + if ignore_errors { + return Ok(AnyValue::Null); + } + polars_bail!(ComputeError: "expected dtype '{}' in JSON value, got dtype: Array\n\nEncountered value: {}", dtype, json); + }; + let vals: Vec = arr + .iter() + .map(|val| deserialize_all(val, inner_dtype, ignore_errors)) + .collect::>()?; + let strict = !ignore_errors; + let s = + Series::from_any_values_and_dtype(PlSmallStr::EMPTY, &vals, inner_dtype, strict)?; + AnyValue::List(s) + }, + #[cfg(feature = "dtype-struct")] + Value::Object(doc) => { + if let DataType::Struct(fields) = dtype { + let document = &**doc; + + let vals = fields + .iter() + .map(|field| { + if let Some(value) = document.get(field.name.as_str()) { + deserialize_all(value, &field.dtype, ignore_errors) + } else { + Ok(AnyValue::Null) + } + }) + .collect::>>()?; + AnyValue::StructOwned(Box::new((vals, fields.clone()))) + } else { + if ignore_errors { + return Ok(AnyValue::Null); + } + polars_bail!( + ComputeError: "expected {} in json value, got object", dtype, + ); + } + }, + val => AnyValue::StringOwned(format!("{:#?}", val).into()), + }; + Ok(out) +} + +/// Wrapper for serde_json's `Value` with a human-friendly Display impl for nested types: +/// +/// * Default: `{"x": Static(U64(1))}` +/// * ValueDisplay: `{x: 1}` +/// +/// This intended for reading in arbitrary `Value` types into a String type. Note that the output +/// is not guaranteed to be valid JSON as we don't do any escaping of e.g. quote/newline values. +struct ValueDisplay<'a>(&'a Value<'a>); + +impl std::fmt::Display for ValueDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use Value::*; + + match self.0 { + Static(s) => write!(f, "{s}"), + String(s) => write!(f, r#""{s}""#), + Array(a) => { + write!(f, "[")?; + + let mut iter = a.iter(); + + for v in (&mut iter).take(1) { + write!(f, "{}", ValueDisplay(v))?; + } + + for v in iter { + write!(f, ", {}", ValueDisplay(v))?; + } + + write!(f, "]") + }, + Object(o) => { + write!(f, "{{")?; + + let mut iter = o.iter(); + + for (k, v) in (&mut iter).take(1) { + write!(f, r#""{}": {}"#, k, ValueDisplay(v))?; + } + + for (k, v) in iter { + write!(f, r#", "{}": {}"#, k, ValueDisplay(v))?; + } + + write!(f, "}}") + }, + } + } +} diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs new file mode 100644 index 000000000000..f651a93f4d11 --- /dev/null +++ b/crates/polars-io/src/ndjson/core.rs @@ -0,0 +1,549 @@ +use std::fs::File; +use std::io::Cursor; +use std::num::NonZeroUsize; +use std::path::PathBuf; + +pub use arrow::array::StructArray; +use num_traits::pow::Pow; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::utils::accumulate_dataframes_vertical; +use rayon::prelude::*; + +use crate::mmap::{MmapBytesReader, ReaderBytes}; +use crate::ndjson::buffer::*; +use crate::predicates::PhysicalIoExpr; +use crate::prelude::*; +use crate::{RowIndex, SerReader}; +const NEWLINE: u8 = b'\n'; +const CLOSING_BRACKET: u8 = b'}'; + +#[must_use] +pub struct JsonLineReader<'a, R> +where + R: MmapBytesReader, +{ + reader: R, + rechunk: bool, + n_rows: Option, + n_threads: Option, + infer_schema_len: Option, + chunk_size: NonZeroUsize, + schema: Option, + schema_overwrite: Option<&'a Schema>, + path: Option, + low_memory: bool, + ignore_errors: bool, + row_index: Option<&'a mut RowIndex>, + predicate: Option>, + projection: Option>, +} + +impl<'a, R> JsonLineReader<'a, R> +where + R: 'a + MmapBytesReader, +{ + pub fn with_n_rows(mut self, num_rows: Option) -> Self { + self.n_rows = num_rows; + self + } + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn with_schema_overwrite(mut self, schema: &'a Schema) -> Self { + self.schema_overwrite = Some(schema); + self + } + + pub fn with_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + pub fn with_predicate(mut self, predicate: Option>) -> Self { + self.predicate = predicate; + self + } + + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + pub fn with_row_index(mut self, row_index: Option<&'a mut RowIndex>) -> Self { + self.row_index = row_index; + self + } + + pub fn infer_schema_len(mut self, infer_schema_len: Option) -> Self { + self.infer_schema_len = infer_schema_len; + self + } + + pub fn with_n_threads(mut self, n: Option) -> Self { + self.n_threads = n; + self + } + + pub fn with_path>(mut self, path: Option

) -> Self { + self.path = path.map(|p| p.into()); + self + } + /// Sets the chunk size used by the parser. This influences performance + pub fn with_chunk_size(mut self, chunk_size: Option) -> Self { + if let Some(chunk_size) = chunk_size { + self.chunk_size = chunk_size; + }; + + self + } + /// Reduce memory consumption at the expense of performance + pub fn low_memory(mut self, toggle: bool) -> Self { + self.low_memory = toggle; + self + } + + /// Set values as `Null` if parsing fails because of schema mismatches. + pub fn with_ignore_errors(mut self, ignore_errors: bool) -> Self { + self.ignore_errors = ignore_errors; + self + } + + pub fn count(mut self) -> PolarsResult { + let reader_bytes = get_reader_bytes(&mut self.reader)?; + let json_reader = CoreJsonReader::new( + reader_bytes, + self.n_rows, + self.schema, + self.schema_overwrite, + self.n_threads, + 1024, // sample size + self.chunk_size, + self.low_memory, + self.infer_schema_len, + self.ignore_errors, + self.row_index, + self.predicate, + self.projection, + )?; + + json_reader.count() + } +} + +impl JsonLineReader<'_, File> { + /// This is the recommended way to create a json reader as this allows for fastest parsing. + pub fn from_path>(path: P) -> PolarsResult { + let path = crate::resolve_homedir(&path.into()); + let f = polars_utils::open_file(&path)?; + Ok(Self::new(f).with_path(Some(path))) + } +} +impl SerReader for JsonLineReader<'_, R> +where + R: MmapBytesReader, +{ + /// Create a new JsonLineReader from a file/ stream + fn new(reader: R) -> Self { + JsonLineReader { + reader, + rechunk: true, + n_rows: None, + n_threads: None, + infer_schema_len: Some(NonZeroUsize::new(100).unwrap()), + schema: None, + schema_overwrite: None, + path: None, + chunk_size: NonZeroUsize::new(1 << 18).unwrap(), + low_memory: false, + ignore_errors: false, + row_index: None, + predicate: None, + projection: None, + } + } + fn finish(mut self) -> PolarsResult { + let rechunk = self.rechunk; + let reader_bytes = get_reader_bytes(&mut self.reader)?; + let mut json_reader = CoreJsonReader::new( + reader_bytes, + self.n_rows, + self.schema, + self.schema_overwrite, + self.n_threads, + 1024, // sample size + self.chunk_size, + self.low_memory, + self.infer_schema_len, + self.ignore_errors, + self.row_index, + self.predicate, + self.projection, + )?; + + let mut df: DataFrame = json_reader.as_df()?; + if rechunk && df.first_col_n_chunks() > 1 { + df.as_single_chunk_par(); + } + Ok(df) + } +} + +pub(crate) struct CoreJsonReader<'a> { + reader_bytes: Option>, + n_rows: Option, + schema: SchemaRef, + n_threads: Option, + sample_size: usize, + chunk_size: NonZeroUsize, + low_memory: bool, + ignore_errors: bool, + row_index: Option<&'a mut RowIndex>, + predicate: Option>, + projection: Option>, +} +impl<'a> CoreJsonReader<'a> { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + reader_bytes: ReaderBytes<'a>, + n_rows: Option, + schema: Option, + schema_overwrite: Option<&Schema>, + n_threads: Option, + sample_size: usize, + chunk_size: NonZeroUsize, + low_memory: bool, + infer_schema_len: Option, + ignore_errors: bool, + row_index: Option<&'a mut RowIndex>, + predicate: Option>, + projection: Option>, + ) -> PolarsResult> { + let reader_bytes = reader_bytes; + + let mut schema = match schema { + Some(schema) => schema, + None => { + let bytes: &[u8] = &reader_bytes; + let mut cursor = Cursor::new(bytes); + Arc::new(crate::ndjson::infer_schema(&mut cursor, infer_schema_len)?) + }, + }; + if let Some(overwriting_schema) = schema_overwrite { + let schema = Arc::make_mut(&mut schema); + overwrite_schema(schema, overwriting_schema)?; + } + + Ok(CoreJsonReader { + reader_bytes: Some(reader_bytes), + schema, + sample_size, + n_rows, + n_threads, + chunk_size, + low_memory, + ignore_errors, + row_index, + predicate, + projection, + }) + } + + fn count(mut self) -> PolarsResult { + let bytes = self.reader_bytes.take().unwrap(); + Ok(super::count_rows_par(&bytes, self.n_threads)) + } + + fn parse_json(&mut self, mut n_threads: usize, bytes: &[u8]) -> PolarsResult { + let mut bytes = bytes; + let mut total_rows = 128; + + if let Some((mean, std)) = get_line_stats_json(bytes, self.sample_size) { + let line_length_upper_bound = mean + 1.1 * std; + + total_rows = (bytes.len() as f32 / (mean - 0.01 * std)) as usize; + if let Some(n_rows) = self.n_rows { + total_rows = std::cmp::min(n_rows, total_rows); + // the guessed upper bound of the no. of bytes in the file + let n_bytes = (line_length_upper_bound * (n_rows as f32)) as usize; + + if n_bytes < bytes.len() { + if let Some(pos) = next_line_position_naive_json(&bytes[n_bytes..]) { + bytes = &bytes[..n_bytes + pos] + } + } + } + } + + if total_rows <= 128 { + n_threads = 1; + } + + let rows_per_thread = total_rows / n_threads; + + let max_proxy = bytes.len() / n_threads / 2; + let capacity = if self.low_memory { + usize::from(self.chunk_size) + } else { + std::cmp::min(rows_per_thread, max_proxy) + }; + let file_chunks = get_file_chunks_json(bytes, n_threads); + + let row_index = self.row_index.as_ref().map(|ri| ri as &RowIndex); + let (mut dfs, prepredicate_heights) = POOL.install(|| { + file_chunks + .into_par_iter() + .map(|(start_pos, stop_at_nbytes)| { + let mut local_df = parse_ndjson( + &bytes[start_pos..stop_at_nbytes], + Some(capacity), + &self.schema, + self.ignore_errors, + )?; + + let prepredicate_height = local_df.height() as IdxSize; + if let Some(projection) = self.projection.as_deref() { + local_df = local_df.select(projection.iter().cloned())?; + } + + if let Some(row_index) = row_index { + local_df = local_df + .with_row_index(row_index.name.clone(), Some(row_index.offset))?; + } + + if let Some(predicate) = &self.predicate { + let s = predicate.evaluate_io(&local_df)?; + let mask = s.bool()?; + local_df = local_df.filter(mask)?; + } + + Ok((local_df, prepredicate_height)) + }) + .collect::, Vec<_>)>>() + })?; + + if let Some(ref mut row_index) = self.row_index { + update_row_counts3(&mut dfs, &prepredicate_heights, 0); + row_index.offset += prepredicate_heights.iter().copied().sum::(); + } + + accumulate_dataframes_vertical(dfs) + } + + pub fn as_df(&mut self) -> PolarsResult { + let n_threads = self.n_threads.unwrap_or_else(|| POOL.current_num_threads()); + + let reader_bytes = self.reader_bytes.take().unwrap(); + + let mut df = self.parse_json(n_threads, &reader_bytes)?; + + // if multi-threaded the n_rows was probabilistically determined. + // Let's slice to correct number of rows if possible. + if let Some(n_rows) = self.n_rows { + if n_rows < df.height() { + df = df.slice(0, n_rows) + } + } + Ok(df) + } +} + +#[inline(always)] +fn parse_impl( + bytes: &[u8], + buffers: &mut PlIndexMap, + scratch: &mut Scratch, +) -> PolarsResult { + scratch.json.clear(); + scratch.json.extend_from_slice(bytes); + let n = scratch.json.len(); + let value = simd_json::to_borrowed_value_with_buffers(&mut scratch.json, &mut scratch.buffers) + .map_err(|e| polars_err!(ComputeError: "error parsing line: {}", e))?; + match value { + simd_json::BorrowedValue::Object(value) => { + buffers.iter_mut().try_for_each(|(s, inner)| { + match s.0.map_lookup(&value) { + Some(v) => inner.add(v)?, + None => inner.add_null(), + } + PolarsResult::Ok(()) + })?; + }, + _ => { + buffers.iter_mut().for_each(|(_, inner)| inner.add_null()); + }, + }; + Ok(n) +} + +#[derive(Default)] +struct Scratch { + json: Vec, + buffers: simd_json::Buffers, +} + +pub fn json_lines(bytes: &[u8]) -> impl Iterator { + // This previously used `serde_json`'s `RawValue` to deserialize chunks without really deserializing them. + // However, this convenience comes at a cost. serde_json allocates and parses and does UTF-8 validation, all + // things we don't need since we use simd_json for them. Also, `serde_json::StreamDeserializer` has a more + // ambitious goal: it wants to parse potentially *non-delimited* sequences of JSON values, while we know + // our values are line-delimited. Turns out, custom splitting is very easy, and gives a very nice performance boost. + bytes.split(|&byte| byte == b'\n').filter(|&bytes| { + bytes + .iter() + .any(|&byte| !matches!(byte, b' ' | b'\t' | b'\r')) + }) +} + +fn parse_lines(bytes: &[u8], buffers: &mut PlIndexMap) -> PolarsResult<()> { + let mut scratch = Scratch::default(); + + let iter = json_lines(bytes); + for bytes in iter { + parse_impl(bytes, buffers, &mut scratch)?; + } + Ok(()) +} + +pub fn parse_ndjson( + bytes: &[u8], + n_rows_hint: Option, + schema: &Schema, + ignore_errors: bool, +) -> PolarsResult { + let capacity = n_rows_hint.unwrap_or_else(|| estimate_n_lines_in_chunk(bytes)); + + let mut buffers = init_buffers(schema, capacity, ignore_errors)?; + parse_lines(bytes, &mut buffers)?; + + DataFrame::new( + buffers + .into_values() + .map(|buf| Ok(buf.into_series()?.into_column())) + .collect::>() + .map_err(|e| match e { + // Nested types raise SchemaMismatch instead of ComputeError, we map it back here to + // be consistent. + PolarsError::ComputeError(..) => e, + PolarsError::SchemaMismatch(e) => PolarsError::ComputeError(e), + e => e, + })?, + ) +} + +pub fn estimate_n_lines_in_file(file_bytes: &[u8], sample_size: usize) -> usize { + if let Some((mean, std)) = get_line_stats_json(file_bytes, sample_size) { + (file_bytes.len() as f32 / (mean - 0.01 * std)) as usize + } else { + estimate_n_lines_in_chunk(file_bytes) + } +} + +/// Total len divided by max len of first and last non-empty lines. This is intended to be cheaper +/// than `estimate_n_lines_in_file`. +pub fn estimate_n_lines_in_chunk(chunk: &[u8]) -> usize { + chunk + .split(|&c| c == b'\n') + .find(|x| !x.is_empty()) + .map_or(1, |x| { + chunk.len().div_ceil( + x.len().max( + chunk + .rsplit(|&c| c == b'\n') + .find(|x| !x.is_empty()) + .unwrap() + .len(), + ), + ) + }) +} + +/// Find the nearest next line position. +/// Does not check for new line characters embedded in String fields. +/// This just looks for `}\n` +pub(crate) fn next_line_position_naive_json(input: &[u8]) -> Option { + let pos = memchr::memchr(NEWLINE, input)?; + if pos == 0 { + return Some(1); + } + + let is_closing_bracket = input.get(pos - 1) == Some(&CLOSING_BRACKET); + if is_closing_bracket { + Some(pos + 1) + } else { + None + } +} + +/// Get the mean and standard deviation of length of lines in bytes +pub(crate) fn get_line_stats_json(bytes: &[u8], n_lines: usize) -> Option<(f32, f32)> { + let mut lengths = Vec::with_capacity(n_lines); + + let mut bytes_trunc; + let n_lines_per_iter = n_lines / 2; + + let mut n_read = 0; + + let bytes_len = bytes.len(); + + // sample from start and 75% in the file + for offset in [0, (bytes_len as f32 * 0.75) as usize] { + bytes_trunc = &bytes[offset..]; + let pos = next_line_position_naive_json(bytes_trunc)?; + if pos >= bytes_len { + return None; + } + bytes_trunc = &bytes_trunc[pos + 1..]; + + for _ in offset..(offset + n_lines_per_iter) { + let pos = next_line_position_naive_json(bytes_trunc); + if let Some(pos) = pos { + lengths.push(pos); + let next_bytes = &bytes_trunc[pos..]; + if next_bytes.is_empty() { + return None; + } + bytes_trunc = next_bytes; + n_read += pos; + } else { + break; + } + } + } + + let n_samples = lengths.len(); + let mean = (n_read as f32) / (n_samples as f32); + let mut std = 0.0; + for &len in lengths.iter() { + std += (len as f32 - mean).pow(2.0) + } + std = (std / n_samples as f32).sqrt(); + Some((mean, std)) +} + +pub(crate) fn get_file_chunks_json(bytes: &[u8], n_threads: usize) -> Vec<(usize, usize)> { + let mut last_pos = 0; + let total_len = bytes.len(); + let chunk_size = total_len / n_threads; + let mut offsets = Vec::with_capacity(n_threads); + for _ in 0..n_threads { + let search_pos = last_pos + chunk_size; + + if search_pos >= bytes.len() { + break; + } + + let end_pos = match next_line_position_naive_json(&bytes[search_pos..]) { + Some(pos) => search_pos + pos, + None => { + break; + }, + }; + offsets.push((last_pos, end_pos)); + last_pos = end_pos; + } + offsets.push((last_pos, total_len)); + offsets +} diff --git a/crates/polars-io/src/ndjson/mod.rs b/crates/polars-io/src/ndjson/mod.rs new file mode 100644 index 000000000000..332fe10623c0 --- /dev/null +++ b/crates/polars-io/src/ndjson/mod.rs @@ -0,0 +1,49 @@ +use core::{get_file_chunks_json, json_lines}; +use std::num::NonZeroUsize; + +use arrow::array::StructArray; +use polars_core::POOL; +use polars_core::prelude::*; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +pub(crate) mod buffer; +pub mod core; + +pub fn infer_schema( + reader: &mut R, + infer_schema_len: Option, +) -> PolarsResult { + let dtypes = polars_json::ndjson::iter_unique_dtypes(reader, infer_schema_len)?; + let dtype = + crate::json::infer::dtypes_to_supertype(dtypes.map(|dt| DataType::from_arrow_dtype(&dt)))?; + let schema = StructArray::get_fields(&dtype.to_arrow(CompatLevel::newest())) + .iter() + .map(Into::::into) + .collect(); + Ok(schema) +} + +/// Count the number of rows. The slice passed must represent the entire file. This will +/// potentially parallelize using rayon. +/// +/// This does not check if the lines are valid NDJSON - it assumes that is the case. +pub fn count_rows_par(full_bytes: &[u8], n_threads: Option) -> usize { + let n_threads = n_threads.unwrap_or(POOL.current_num_threads()); + let file_chunks = get_file_chunks_json(full_bytes, n_threads); + + if file_chunks.len() == 1 { + count_rows(full_bytes) + } else { + let iter = file_chunks + .into_par_iter() + .map(|(start_pos, stop_at_nbytes)| count_rows(&full_bytes[start_pos..stop_at_nbytes])); + + POOL.install(|| iter.sum()) + } +} + +/// Count the number of rows. The slice passed must represent the entire file. +/// This does not check if the lines are valid NDJSON - it assumes that is the case. +pub fn count_rows(full_bytes: &[u8]) -> usize { + json_lines(full_bytes).count() +} diff --git a/crates/polars-io/src/options.rs b/crates/polars-io/src/options.rs new file mode 100644 index 000000000000..4b9c653cac02 --- /dev/null +++ b/crates/polars-io/src/options.rs @@ -0,0 +1,51 @@ +use polars_core::schema::SchemaRef; +use polars_utils::IdxSize; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct RowIndex { + pub name: PlSmallStr, + pub offset: IdxSize, +} + +/// Options for Hive partitioning. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct HiveOptions { + /// This can be `None` to automatically enable for single directory scans + /// and disable otherwise. However it should be initialized if it is inside + /// a DSL / IR plan. + pub enabled: Option, + pub hive_start_idx: usize, + pub schema: Option, + pub try_parse_dates: bool, +} + +impl HiveOptions { + pub fn new_enabled() -> Self { + Self { + enabled: Some(true), + hive_start_idx: 0, + schema: None, + try_parse_dates: true, + } + } + + pub fn new_disabled() -> Self { + Self { + enabled: Some(false), + hive_start_idx: 0, + schema: None, + try_parse_dates: false, + } + } +} + +impl Default for HiveOptions { + fn default() -> Self { + Self::new_enabled() + } +} diff --git a/crates/polars-io/src/parquet/metadata.rs b/crates/polars-io/src/parquet/metadata.rs new file mode 100644 index 000000000000..0ffbb03940e0 --- /dev/null +++ b/crates/polars-io/src/parquet/metadata.rs @@ -0,0 +1,8 @@ +//! Apache Parquet file metadata. + +use std::sync::Arc; + +pub use polars_parquet::parquet::metadata::FileMetadata; +pub use polars_parquet::read::statistics::{Statistics as ParquetStatistics, deserialize}; + +pub type FileMetadataRef = Arc; diff --git a/crates/polars-io/src/parquet/mod.rs b/crates/polars-io/src/parquet/mod.rs new file mode 100644 index 000000000000..5107f317e187 --- /dev/null +++ b/crates/polars-io/src/parquet/mod.rs @@ -0,0 +1,5 @@ +//! Functionality for reading and writing Apache Parquet files. + +pub mod metadata; +pub mod read; +pub mod write; diff --git a/crates/polars-io/src/parquet/read/async_impl.rs b/crates/polars-io/src/parquet/read/async_impl.rs new file mode 100644 index 000000000000..72a0a10d9b22 --- /dev/null +++ b/crates/polars-io/src/parquet/read/async_impl.rs @@ -0,0 +1,152 @@ +//! Read parquet files in parallel from the Object Store without a third party crate. + +use arrow::datatypes::ArrowSchemaRef; +use object_store::path::Path as ObjectPath; +use polars_core::prelude::*; +use polars_parquet::write::FileMetadata; + +use crate::cloud::{ + CloudLocation, CloudOptions, PolarsObjectStore, build_object_store, object_path_from_str, +}; +use crate::parquet::metadata::FileMetadataRef; + +pub struct ParquetObjectStore { + store: PolarsObjectStore, + path: ObjectPath, + length: Option, + metadata: Option, + schema: Option, +} + +impl ParquetObjectStore { + pub async fn from_uri( + uri: &str, + options: Option<&CloudOptions>, + metadata: Option, + ) -> PolarsResult { + let (CloudLocation { prefix, .. }, store) = build_object_store(uri, options, false).await?; + let path = object_path_from_str(&prefix)?; + + Ok(ParquetObjectStore { + store, + path, + length: None, + metadata, + schema: None, + }) + } + + /// Initialize the length property of the object, unless it has already been fetched. + async fn length(&mut self) -> PolarsResult { + if self.length.is_none() { + self.length = Some(self.store.head(&self.path).await?.size); + } + Ok(self.length.unwrap()) + } + + /// Number of rows in the parquet file. + pub async fn num_rows(&mut self) -> PolarsResult { + let metadata = self.get_metadata().await?; + Ok(metadata.num_rows) + } + + /// Fetch the metadata of the parquet file, do not memoize it. + async fn fetch_metadata(&mut self) -> PolarsResult { + let length = self.length().await?; + fetch_metadata(&self.store, &self.path, length).await + } + + /// Fetch and memoize the metadata of the parquet file. + pub async fn get_metadata(&mut self) -> PolarsResult<&FileMetadataRef> { + if self.metadata.is_none() { + self.metadata = Some(Arc::new(self.fetch_metadata().await?)); + } + Ok(self.metadata.as_ref().unwrap()) + } + + pub async fn schema(&mut self) -> PolarsResult { + self.schema = Some(match self.schema.as_ref() { + Some(schema) => Arc::clone(schema), + None => { + let metadata = self.get_metadata().await?; + let arrow_schema = polars_parquet::arrow::read::infer_schema(metadata)?; + Arc::new(arrow_schema) + }, + }); + + Ok(self.schema.clone().unwrap()) + } +} + +fn read_n(reader: &mut &[u8]) -> Option<[u8; N]> { + if N <= reader.len() { + let (head, tail) = reader.split_at(N); + *reader = tail; + Some(head.try_into().unwrap()) + } else { + None + } +} + +fn read_i32le(reader: &mut &[u8]) -> Option { + read_n(reader).map(i32::from_le_bytes) +} + +/// Asynchronously reads the files' metadata +pub async fn fetch_metadata( + store: &PolarsObjectStore, + path: &ObjectPath, + file_byte_length: usize, +) -> PolarsResult { + let footer_header_bytes = store + .get_range( + path, + file_byte_length + .checked_sub(polars_parquet::parquet::FOOTER_SIZE as usize) + .ok_or_else(|| { + polars_parquet::parquet::error::ParquetError::OutOfSpec( + "not enough bytes to contain parquet footer".to_string(), + ) + })?..file_byte_length, + ) + .await?; + + let footer_byte_length: usize = { + let reader = &mut footer_header_bytes.as_ref(); + let footer_byte_size = read_i32le(reader).unwrap(); + let magic = read_n(reader).unwrap(); + debug_assert!(reader.is_empty()); + if magic != polars_parquet::parquet::PARQUET_MAGIC { + return Err(polars_parquet::parquet::error::ParquetError::OutOfSpec( + "incorrect magic in parquet footer".to_string(), + ) + .into()); + } + footer_byte_size.try_into().map_err(|_| { + polars_parquet::parquet::error::ParquetError::OutOfSpec( + "negative footer byte length".to_string(), + ) + })? + }; + + let footer_bytes = store + .get_range( + path, + file_byte_length + .checked_sub(polars_parquet::parquet::FOOTER_SIZE as usize + footer_byte_length) + .ok_or_else(|| { + polars_parquet::parquet::error::ParquetError::OutOfSpec( + "not enough bytes to contain parquet footer".to_string(), + ) + })?..file_byte_length, + ) + .await?; + + Ok(polars_parquet::parquet::read::deserialize_metadata( + std::io::Cursor::new(footer_bytes.as_ref()), + // TODO: Describe why this makes sense. Taken from the previous + // implementation which said "a highly nested but sparse struct could + // result in many allocations". + footer_bytes.as_ref().len() * 2 + 1024, + )?) +} diff --git a/crates/polars-io/src/parquet/read/mmap.rs b/crates/polars-io/src/parquet/read/mmap.rs new file mode 100644 index 000000000000..cd656938a50d --- /dev/null +++ b/crates/polars-io/src/parquet/read/mmap.rs @@ -0,0 +1,72 @@ +use arrow::array::Array; +use arrow::bitmap::Bitmap; +use arrow::datatypes::Field; +use polars_error::PolarsResult; +use polars_parquet::read::{ + BasicDecompressor, ColumnChunkMetadata, Filter, PageReader, column_iter_to_arrays, +}; +use polars_utils::mmap::{MemReader, MemSlice}; + +/// Store columns data in two scenarios: +/// 1. a local memory mapped file +/// 2. data fetched from cloud storage on demand, in this case +/// a. the key in the hashmap is the start in the file +/// b. the value in the hashmap is the actual data. +/// +/// For the fetched case we use a two phase approach: +/// a. identify all the needed columns +/// b. asynchronously fetch them in parallel, for example using object_store +/// c. store the data in this data structure +/// d. when all the data is available deserialize on multiple threads, for example using rayon +pub enum ColumnStore { + Local(MemSlice), +} + +/// For local files memory maps all columns that are part of the parquet field `field_name`. +/// For cloud files the relevant memory regions should have been prefetched. +pub(super) fn mmap_columns<'a>( + store: &'a ColumnStore, + field_columns: &'a [&ColumnChunkMetadata], +) -> Vec<(&'a ColumnChunkMetadata, MemSlice)> { + field_columns + .iter() + .map(|meta| _mmap_single_column(store, meta)) + .collect() +} + +fn _mmap_single_column<'a>( + store: &'a ColumnStore, + meta: &'a ColumnChunkMetadata, +) -> (&'a ColumnChunkMetadata, MemSlice) { + let byte_range = meta.byte_range(); + let chunk = match store { + ColumnStore::Local(mem_slice) => { + mem_slice.slice(byte_range.start as usize..byte_range.end as usize) + }, + }; + (meta, chunk) +} + +// similar to arrow2 serializer, except this accepts a slice instead of a vec. +// this allows us to memory map +pub fn to_deserializer( + columns: Vec<(&ColumnChunkMetadata, MemSlice)>, + field: Field, + filter: Option, +) -> PolarsResult<(Box, Bitmap)> { + let (columns, types): (Vec<_>, Vec<_>) = columns + .into_iter() + .map(|(column_meta, chunk)| { + // Advise fetching the data for the column chunk + chunk.prefetch(); + + let pages = PageReader::new(MemReader::new(chunk), column_meta, vec![], usize::MAX); + ( + BasicDecompressor::new(pages, vec![]), + &column_meta.descriptor().descriptor.primitive_type, + ) + }) + .unzip(); + + column_iter_to_arrays(columns, types, field, filter) +} diff --git a/crates/polars-io/src/parquet/read/mod.rs b/crates/polars-io/src/parquet/read/mod.rs new file mode 100644 index 000000000000..6ddc6c755645 --- /dev/null +++ b/crates/polars-io/src/parquet/read/mod.rs @@ -0,0 +1,48 @@ +//! Functionality for reading Apache Parquet files. +//! +//! # Examples +//! +//! ``` +//! use polars_core::prelude::*; +//! use polars_io::prelude::*; +//! use std::fs::File; +//! +//! fn example() -> PolarsResult { +//! let r = File::open("example.parquet").unwrap(); +//! let reader = ParquetReader::new(r); +//! reader.finish() +//! } +//! ``` + +#[cfg(feature = "cloud")] +mod async_impl; +mod mmap; +mod options; +mod predicates; +mod read_impl; +mod reader; +mod utils; + +const ROW_COUNT_OVERFLOW_ERR: PolarsError = PolarsError::ComputeError(ErrString::new_static( + "\ +Parquet file produces more than pow(2, 32) rows; \ +consider compiling with polars-bigidx feature (polars-u64-idx package on python), \ +or set 'streaming'", +)); + +#[cfg(feature = "cloud")] +pub use async_impl::ParquetObjectStore; +pub use options::{ParallelStrategy, ParquetOptions}; +use polars_error::{ErrString, PolarsError}; +pub use polars_parquet::arrow::read::infer_schema; +pub use polars_parquet::read::FileMetadata; +pub use read_impl::{create_sorting_map, try_set_sorted_flag}; +pub use reader::ParquetReader; +pub use utils::materialize_empty_df; + +pub mod _internal { + pub use super::mmap::to_deserializer; + pub use super::predicates::collect_statistics_with_live_columns; + pub use super::read_impl::{PrefilterMaskSetting, calc_prefilter_cost}; + pub use super::utils::ensure_matching_dtypes_if_found; +} diff --git a/crates/polars-io/src/parquet/read/options.rs b/crates/polars-io/src/parquet/read/options.rs new file mode 100644 index 000000000000..c1bc5ace852a --- /dev/null +++ b/crates/polars-io/src/parquet/read/options.rs @@ -0,0 +1,35 @@ +use polars_core::schema::SchemaRef; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ParquetOptions { + pub schema: Option, + pub parallel: ParallelStrategy, + pub low_memory: bool, + pub use_statistics: bool, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ParallelStrategy { + /// Don't parallelize + None, + /// Parallelize over the columns + Columns, + /// Parallelize over the row groups + RowGroups, + /// First evaluates the pushed-down predicates in parallel and determines a mask of which rows + /// to read. Then, it parallelizes over both the columns and the row groups while filtering out + /// rows that do not need to be read. This can provide significant speedups for large files + /// (i.e. many row-groups) with a predicate that filters clustered rows or filters heavily. In + /// other cases, this may slow down the scan compared other strategies. + /// + /// If no predicate is given, this falls back to back to [`ParallelStrategy::Auto`]. + Prefiltered, + /// Automatically determine over which unit to parallelize + /// This will choose the most occurring unit. + #[default] + Auto, +} diff --git a/crates/polars-io/src/parquet/read/predicates.rs b/crates/polars-io/src/parquet/read/predicates.rs new file mode 100644 index 000000000000..b6ca3e75d8ea --- /dev/null +++ b/crates/polars-io/src/parquet/read/predicates.rs @@ -0,0 +1,38 @@ +use polars_core::prelude::*; +use polars_parquet::read::RowGroupMetadata; +use polars_parquet::read::statistics::{ArrowColumnStatisticsArrays, deserialize_all}; + +/// Collect the statistics in a row-group +pub fn collect_statistics_with_live_columns( + row_groups: &[RowGroupMetadata], + schema: &ArrowSchema, + live_columns: &PlIndexSet, +) -> PolarsResult>> { + if row_groups.is_empty() { + return Ok((0..live_columns.len()).map(|_| None).collect()); + } + + let md = &row_groups[0]; + live_columns + .iter() + .map(|c| { + let field = schema.get(c).unwrap(); + + // This can be None in the allow_missing_columns case. + let Some(idxs) = md.columns_idxs_under_root_iter(&field.name) else { + return Ok(None); + }; + + // 0 is possible for possible for empty structs. + // + // 2+ is for structs. We don't support reading nested statistics for now. It does not + // really make any sense at the moment with how we structure statistics. + if idxs.is_empty() || idxs.len() > 1 { + return Ok(None); + } + + let idx = idxs[0]; + Ok(deserialize_all(field, row_groups, idx)?) + }) + .collect::>>() +} diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs new file mode 100644 index 000000000000..f3ac07e2eef1 --- /dev/null +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -0,0 +1,552 @@ +use std::borrow::Cow; + +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowSchemaRef; +use polars_core::chunked_array::builder::NullChunkedBuilder; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::accumulate_dataframes_vertical; +use polars_core::{POOL, config}; +use polars_parquet::read::{self, ColumnChunkMetadata, FileMetadata, Filter, RowGroupMetadata}; +use rayon::prelude::*; + +use super::mmap::mmap_columns; +use super::utils::materialize_empty_df; +use super::{ParallelStrategy, mmap}; +use crate::RowIndex; +use crate::hive::materialize_hive_partitions; +use crate::mmap::{MmapBytesReader, ReaderBytes}; +use crate::parquet::metadata::FileMetadataRef; +use crate::parquet::read::ROW_COUNT_OVERFLOW_ERR; +use crate::utils::slice::split_slice_at_file; + +#[cfg(debug_assertions)] +// Ensure we get the proper polars types from schema inference +// This saves unneeded casts. +fn assert_dtypes(dtype: &ArrowDataType) { + use ArrowDataType as D; + + match dtype { + // These should all be cast to the BinaryView / Utf8View variants + D::Utf8 | D::Binary | D::LargeUtf8 | D::LargeBinary => unreachable!(), + + // These should be cast to Float32 + D::Float16 => unreachable!(), + + // This should have been converted to a LargeList + D::List(_) => unreachable!(), + + // This should have been converted to a LargeList(Struct(_)) + D::Map(_, _) => unreachable!(), + + // Recursive checks + D::Dictionary(_, dtype, _) => assert_dtypes(dtype), + D::Extension(ext) => assert_dtypes(&ext.inner), + D::LargeList(inner) => assert_dtypes(&inner.dtype), + D::FixedSizeList(inner, _) => assert_dtypes(&inner.dtype), + D::Struct(fields) => fields.iter().for_each(|f| assert_dtypes(f.dtype())), + + _ => {}, + } +} + +fn should_copy_sortedness(dtype: &DataType) -> bool { + // @NOTE: For now, we are a bit conservative with this. + use DataType as D; + + matches!( + dtype, + D::Int8 | D::Int16 | D::Int32 | D::Int64 | D::UInt8 | D::UInt16 | D::UInt32 | D::UInt64 + ) +} + +pub fn try_set_sorted_flag( + series: &mut Series, + col_idx: usize, + sorting_map: &PlHashMap, +) { + if let Some(is_sorted) = sorting_map.get(&col_idx) { + if should_copy_sortedness(series.dtype()) { + if config::verbose() { + eprintln!( + "Parquet conserved SortingColumn for column chunk of '{}' to {is_sorted:?}", + series.name() + ); + } + + series.set_sorted_flag(*is_sorted); + } + } +} + +pub fn create_sorting_map(md: &RowGroupMetadata) -> PlHashMap { + let capacity = md.sorting_columns().map_or(0, |s| s.len()); + let mut sorting_map = PlHashMap::with_capacity(capacity); + + if let Some(sorting_columns) = md.sorting_columns() { + for sorting in sorting_columns { + let prev_value = sorting_map.insert( + sorting.column_idx as usize, + if sorting.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }, + ); + + debug_assert!(prev_value.is_none()); + } + } + + sorting_map +} + +fn column_idx_to_series( + column_i: usize, + // The metadata belonging to this column + field_md: &[&ColumnChunkMetadata], + filter: Option, + file_schema: &ArrowSchema, + store: &mmap::ColumnStore, +) -> PolarsResult<(Series, Bitmap)> { + let field = file_schema.get_at_index(column_i).unwrap().1; + + #[cfg(debug_assertions)] + { + assert_dtypes(field.dtype()) + } + let columns = mmap_columns(store, field_md); + let (array, pred_true_mask) = mmap::to_deserializer(columns, field.clone(), filter)?; + let series = Series::try_from((field, array))?; + + Ok((series, pred_true_mask)) +} + +#[allow(clippy::too_many_arguments)] +fn rg_to_dfs( + store: &mmap::ColumnStore, + previous_row_count: &mut IdxSize, + row_group_start: usize, + row_group_end: usize, + pre_slice: (usize, usize), + file_metadata: &FileMetadata, + schema: &ArrowSchemaRef, + row_index: Option, + parallel: ParallelStrategy, + projection: &[usize], + hive_partition_columns: Option<&[Series]>, +) -> PolarsResult> { + if config::verbose() { + eprintln!("parquet scan with parallel = {parallel:?}"); + } + + // If we are only interested in the row_index, we take a little special path here. + if projection.is_empty() { + if let Some(row_index) = row_index { + let placeholder = + NullChunkedBuilder::new(PlSmallStr::from_static("__PL_TMP"), pre_slice.1).finish(); + return Ok(vec![ + DataFrame::new(vec![placeholder.into_series().into_column()])? + .with_row_index( + row_index.name.clone(), + Some(row_index.offset + IdxSize::try_from(pre_slice.0).unwrap()), + )? + .select(std::iter::once(row_index.name))?, + ]); + } + } + + use ParallelStrategy as S; + + match parallel { + S::Columns | S::None => rg_to_dfs_optionally_par_over_columns( + store, + previous_row_count, + row_group_start, + row_group_end, + pre_slice, + file_metadata, + schema, + row_index, + parallel, + projection, + hive_partition_columns, + ), + _ => rg_to_dfs_par_over_rg( + store, + row_group_start, + row_group_end, + previous_row_count, + pre_slice, + file_metadata, + schema, + row_index, + projection, + hive_partition_columns, + ), + } +} + +#[allow(clippy::too_many_arguments)] +// might parallelize over columns +fn rg_to_dfs_optionally_par_over_columns( + store: &mmap::ColumnStore, + previous_row_count: &mut IdxSize, + row_group_start: usize, + row_group_end: usize, + slice: (usize, usize), + file_metadata: &FileMetadata, + schema: &ArrowSchemaRef, + row_index: Option, + parallel: ParallelStrategy, + projection: &[usize], + hive_partition_columns: Option<&[Series]>, +) -> PolarsResult> { + let mut dfs = Vec::with_capacity(row_group_end - row_group_start); + + let mut n_rows_processed: usize = (0..row_group_start) + .map(|i| file_metadata.row_groups[i].num_rows()) + .sum(); + let slice_end = slice.0 + slice.1; + + for rg_idx in row_group_start..row_group_end { + let md = &file_metadata.row_groups[rg_idx]; + + let rg_slice = + split_slice_at_file(&mut n_rows_processed, md.num_rows(), slice.0, slice_end); + let current_row_count = md.num_rows() as IdxSize; + + let sorting_map = create_sorting_map(md); + + let f = |column_i: &usize| { + let (name, field) = schema.get_at_index(*column_i).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( + name.clone(), + rg_slice.1, + &DataType::from_arrow_field(field), + )); + }; + + let part = iter.collect::>(); + + let (mut series, _) = column_idx_to_series( + *column_i, + part.as_slice(), + Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), + schema, + store, + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + Ok(series.into_column()) + }; + + let columns = if let ParallelStrategy::Columns = parallel { + POOL.install(|| { + projection + .par_iter() + .map(f) + .collect::>>() + })? + } else { + projection.iter().map(f).collect::>>()? + }; + + let mut df = unsafe { DataFrame::new_no_checks(rg_slice.1, columns) }; + if let Some(rc) = &row_index { + unsafe { + df.with_row_index_mut( + rc.name.clone(), + Some(*previous_row_count + rc.offset + rg_slice.0 as IdxSize), + ) + }; + } + + materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns); + + *previous_row_count = previous_row_count.checked_add(current_row_count).ok_or_else(|| + polars_err!( + ComputeError: "Parquet file produces more than pow(2, 32) rows; \ + consider compiling with polars-bigidx feature (polars-u64-idx package on python), \ + or set 'streaming'" + ), + )?; + dfs.push(df); + + if *previous_row_count as usize >= slice_end { + break; + } + } + + Ok(dfs) +} + +#[allow(clippy::too_many_arguments)] +// parallelizes over row groups +fn rg_to_dfs_par_over_rg( + store: &mmap::ColumnStore, + row_group_start: usize, + row_group_end: usize, + rows_read: &mut IdxSize, + slice: (usize, usize), + file_metadata: &FileMetadata, + schema: &ArrowSchemaRef, + row_index: Option, + projection: &[usize], + hive_partition_columns: Option<&[Series]>, +) -> PolarsResult> { + // compute the limits per row group and the row count offsets + let mut row_groups = Vec::with_capacity(row_group_end - row_group_start); + + let mut n_rows_processed: usize = (0..row_group_start) + .map(|i| file_metadata.row_groups[i].num_rows()) + .sum(); + let slice_end = slice.0 + slice.1; + + // rows_scanned is the number of rows that have been scanned so far when checking for overlap with the slice. + // rows_read is the number of rows found to overlap with the slice, and thus the number of rows that will be + // read into a dataframe. + let mut rows_scanned: IdxSize; + + if row_group_start > 0 { + // In the case of async reads, we need to account for the fact that row_group_start may be greater than + // zero due to earlier processing. + // For details, see: https://github.com/pola-rs/polars/pull/20508#discussion_r1900165649 + rows_scanned = (0..row_group_start) + .map(|i| file_metadata.row_groups[i].num_rows() as IdxSize) + .sum(); + } else { + rows_scanned = 0; + } + + for i in row_group_start..row_group_end { + let row_count_start = rows_scanned; + let rg_md = &file_metadata.row_groups[i]; + let n_rows_this_file = rg_md.num_rows(); + let rg_slice = + split_slice_at_file(&mut n_rows_processed, n_rows_this_file, slice.0, slice_end); + rows_scanned = rows_scanned + .checked_add(n_rows_this_file as IdxSize) + .ok_or(ROW_COUNT_OVERFLOW_ERR)?; + + *rows_read += rg_slice.1 as IdxSize; + + if rg_slice.1 == 0 { + continue; + } + + row_groups.push((rg_md, rg_slice, row_count_start)); + } + + let dfs = POOL.install(|| { + // Set partitioned fields to prevent quadratic behavior. + // Ensure all row groups are partitioned. + row_groups + .into_par_iter() + .map(|(md, slice, row_count_start)| { + if slice.1 == 0 { + return Ok(None); + } + // test we don't read the parquet file if this env var is set + #[cfg(debug_assertions)] + { + assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) + } + + let sorting_map = create_sorting_map(md); + + let columns = projection + .iter() + .map(|column_i| { + let (name, field) = schema.get_at_index(*column_i).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( + name.clone(), + md.num_rows(), + &DataType::from_arrow_field(field), + )); + }; + + let part = iter.collect::>(); + + let (mut series, _) = column_idx_to_series( + *column_i, + part.as_slice(), + Some(Filter::new_ranged(slice.0, slice.0 + slice.1)), + schema, + store, + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + Ok(series.into_column()) + }) + .collect::>>()?; + + let mut df = unsafe { DataFrame::new_no_checks(slice.1, columns) }; + + if let Some(rc) = &row_index { + unsafe { + df.with_row_index_mut( + rc.name.clone(), + Some(row_count_start as IdxSize + rc.offset + slice.0 as IdxSize), + ) + }; + } + + materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns); + + Ok(Some(df)) + }) + .collect::>>() + })?; + Ok(dfs.into_iter().flatten().collect()) +} + +#[allow(clippy::too_many_arguments)] +pub fn read_parquet( + mut reader: R, + pre_slice: (usize, usize), + projection: Option<&[usize]>, + reader_schema: &ArrowSchemaRef, + metadata: Option, + mut parallel: ParallelStrategy, + row_index: Option, + hive_partition_columns: Option<&[Series]>, +) -> PolarsResult { + // Fast path. + if pre_slice.1 == 0 { + return Ok(materialize_empty_df( + projection, + reader_schema, + hive_partition_columns, + row_index.as_ref(), + )); + } + + let file_metadata = metadata + .map(Ok) + .unwrap_or_else(|| read::read_metadata(&mut reader).map(Arc::new))?; + let n_row_groups = file_metadata.row_groups.len(); + + // if there are multiple row groups and categorical data + // we need a string cache + // we keep it alive until the end of the function + let _sc = if n_row_groups > 1 { + #[cfg(feature = "dtype-categorical")] + { + Some(polars_core::StringCacheHolder::hold()) + } + #[cfg(not(feature = "dtype-categorical"))] + { + Some(0u8) + } + } else { + None + }; + + let materialized_projection = projection + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned((0usize..reader_schema.len()).collect::>())); + + if ParallelStrategy::Auto == parallel { + if n_row_groups > materialized_projection.len() || n_row_groups > POOL.current_num_threads() + { + parallel = ParallelStrategy::RowGroups; + } else { + parallel = ParallelStrategy::Columns; + } + } + + if let (ParallelStrategy::Columns, true) = (parallel, materialized_projection.len() == 1) { + parallel = ParallelStrategy::None; + } + + let reader = ReaderBytes::from(&mut reader); + let store = mmap::ColumnStore::Local(unsafe { + std::mem::transmute::, ReaderBytes<'static>>(reader).to_memslice() + }); + + let dfs = rg_to_dfs( + &store, + &mut 0, + 0, + n_row_groups, + pre_slice, + &file_metadata, + reader_schema, + row_index.clone(), + parallel, + &materialized_projection, + hive_partition_columns, + )?; + + if dfs.is_empty() { + Ok(materialize_empty_df( + projection, + reader_schema, + hive_partition_columns, + row_index.as_ref(), + )) + } else { + accumulate_dataframes_vertical(dfs) + } +} + +pub fn calc_prefilter_cost(mask: &arrow::bitmap::Bitmap) -> f64 { + let num_edges = mask.num_edges() as f64; + let rg_len = mask.len() as f64; + + // @GB: I did quite some analysis on this. + // + // Pre-filtered and Post-filtered can both be faster in certain scenarios. + // + // - Pre-filtered is faster when there is some amount of clustering or + // sorting involved or if the number of values selected is small. + // - Post-filtering is faster when the predicate selects a somewhat random + // elements throughout the row group. + // + // The following is a heuristic value to try and estimate which one is + // faster. Essentially, it sees how many times it needs to switch between + // skipping items and collecting items and compares it against the number + // of values that it will collect. + // + // Closer to 0: pre-filtering is probably better. + // Closer to 1: post-filtering is probably better. + (num_edges / rg_len).clamp(0.0, 1.0) +} + +#[derive(Clone, Copy)] +pub enum PrefilterMaskSetting { + Auto, + Pre, + Post, +} + +impl PrefilterMaskSetting { + pub fn init_from_env() -> Self { + std::env::var("POLARS_PQ_PREFILTERED_MASK").map_or(Self::Auto, |v| match &v[..] { + "auto" => Self::Auto, + "pre" => Self::Pre, + "post" => Self::Post, + _ => panic!("Invalid `POLARS_PQ_PREFILTERED_MASK` value '{v}'."), + }) + } + + pub fn should_prefilter(&self, prefilter_cost: f64, dtype: &ArrowDataType) -> bool { + match self { + Self::Auto => { + // Prefiltering is only expensive for nested types so we make the cut-off quite + // high. + let is_nested = dtype.is_nested(); + + // We empirically selected these numbers. + !is_nested && prefilter_cost <= 0.01 + }, + Self::Pre => true, + Self::Post => false, + } + } +} diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs new file mode 100644 index 000000000000..5b9552382a04 --- /dev/null +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -0,0 +1,242 @@ +use std::io::{Read, Seek}; +use std::sync::Arc; + +use arrow::datatypes::ArrowSchemaRef; +use polars_core::prelude::*; +use polars_parquet::read; + +use super::read_impl::read_parquet; +use super::utils::{ensure_matching_dtypes_if_found, projected_arrow_schema_to_projection_indices}; +use crate::RowIndex; +use crate::mmap::MmapBytesReader; +use crate::parquet::metadata::FileMetadataRef; +use crate::prelude::*; + +/// Read Apache parquet format into a DataFrame. +#[must_use] +pub struct ParquetReader { + reader: R, + rechunk: bool, + slice: (usize, usize), + columns: Option>, + projection: Option>, + parallel: ParallelStrategy, + schema: Option, + row_index: Option, + low_memory: bool, + metadata: Option, + hive_partition_columns: Option>, + include_file_path: Option<(PlSmallStr, Arc)>, +} + +impl ParquetReader { + /// Try to reduce memory pressure at the expense of performance. If setting this does not reduce memory + /// enough, turn off parallelization. + pub fn set_low_memory(mut self, low_memory: bool) -> Self { + self.low_memory = low_memory; + self + } + + /// Read the parquet file in parallel (default). The single threaded reader consumes less memory. + pub fn read_parallel(mut self, parallel: ParallelStrategy) -> Self { + self.parallel = parallel; + self + } + + pub fn with_slice(mut self, slice: Option<(usize, usize)>) -> Self { + self.slice = slice.unwrap_or((0, usize::MAX)); + self + } + + /// Columns to select/ project + pub fn with_columns(mut self, columns: Option>) -> Self { + self.columns = columns; + self + } + + /// Set the reader's column projection. This counts from 0, meaning that + /// `vec![0, 4]` would select the 1st and 5th column. + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + /// Add a row index column. + pub fn with_row_index(mut self, row_index: Option) -> Self { + self.row_index = row_index; + self + } + + /// Checks that the file contains all the columns in `projected_arrow_schema` with the same + /// dtype, and sets the projection indices. + pub fn with_arrow_schema_projection( + mut self, + first_schema: &Arc, + projected_arrow_schema: Option<&ArrowSchema>, + allow_missing_columns: bool, + ) -> PolarsResult { + let slf_schema = self.schema()?; + let slf_schema_width = slf_schema.len(); + + if allow_missing_columns { + // Must check the dtypes + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema()?.as_ref(), + )?; + self.schema = Some(Arc::new( + first_schema + .iter() + .map(|(name, field)| { + (name.clone(), slf_schema.get(name).unwrap_or(field).clone()) + }) + .collect(), + )); + } + + let schema = self.schema()?; + + (|| { + if let Some(projected_arrow_schema) = projected_arrow_schema { + self.projection = projected_arrow_schema_to_projection_indices( + schema.as_ref(), + projected_arrow_schema, + )?; + } else { + if slf_schema_width > first_schema.len() { + polars_bail!( + SchemaMismatch: + "parquet file contained extra columns and no selection was given" + ) + } + + self.projection = + projected_arrow_schema_to_projection_indices(schema.as_ref(), first_schema)?; + }; + Ok(()) + })() + .map_err(|e| { + if !allow_missing_columns && matches!(e, PolarsError::ColumnNotFound(_)) { + e.wrap_msg(|s| { + format!( + "error with column selection, \ + consider enabling `allow_missing_columns`: {}", + s + ) + }) + } else { + e + } + })?; + + Ok(self) + } + + /// [`Schema`] of the file. + pub fn schema(&mut self) -> PolarsResult { + self.schema = Some(match &self.schema { + Some(schema) => schema.clone(), + None => { + let metadata = self.get_metadata()?; + Arc::new(read::infer_schema(metadata)?) + }, + }); + + Ok(self.schema.clone().unwrap()) + } + + /// Number of rows in the parquet file. + pub fn num_rows(&mut self) -> PolarsResult { + let metadata = self.get_metadata()?; + Ok(metadata.num_rows) + } + + pub fn with_hive_partition_columns(mut self, columns: Option>) -> Self { + self.hive_partition_columns = columns; + self + } + + pub fn with_include_file_path( + mut self, + include_file_path: Option<(PlSmallStr, Arc)>, + ) -> Self { + self.include_file_path = include_file_path; + self + } + + pub fn set_metadata(&mut self, metadata: FileMetadataRef) { + self.metadata = Some(metadata); + } + + pub fn get_metadata(&mut self) -> PolarsResult<&FileMetadataRef> { + if self.metadata.is_none() { + self.metadata = Some(Arc::new(read::read_metadata(&mut self.reader)?)); + } + Ok(self.metadata.as_ref().unwrap()) + } +} + +impl SerReader for ParquetReader { + /// Create a new [`ParquetReader`] from an existing `Reader`. + fn new(reader: R) -> Self { + ParquetReader { + reader, + rechunk: false, + slice: (0, usize::MAX), + columns: None, + projection: None, + parallel: Default::default(), + row_index: None, + low_memory: false, + metadata: None, + schema: None, + hive_partition_columns: None, + include_file_path: None, + } + } + + fn set_rechunk(mut self, rechunk: bool) -> Self { + self.rechunk = rechunk; + self + } + + fn finish(mut self) -> PolarsResult { + let schema = self.schema()?; + let metadata = self.get_metadata()?.clone(); + let n_rows = metadata.num_rows.min(self.slice.0 + self.slice.1); + + if let Some(cols) = &self.columns { + self.projection = Some(columns_to_projection(cols, schema.as_ref())?); + } + + let mut df = read_parquet( + self.reader, + self.slice, + self.projection.as_deref(), + &schema, + Some(metadata), + self.parallel, + self.row_index, + self.hive_partition_columns.as_deref(), + )?; + + if self.rechunk { + df.as_single_chunk_par(); + }; + + if let Some((col, value)) = &self.include_file_path { + unsafe { + df.with_column_unchecked(Column::new_scalar( + col.clone(), + Scalar::new( + DataType::String, + AnyValue::StringOwned(value.as_ref().into()), + ), + if df.width() > 0 { df.height() } else { n_rows }, + )) + }; + } + + Ok(df) + } +} diff --git a/crates/polars-io/src/parquet/read/utils.rs b/crates/polars-io/src/parquet/read/utils.rs new file mode 100644 index 000000000000..64290e91166a --- /dev/null +++ b/crates/polars-io/src/parquet/read/utils.rs @@ -0,0 +1,94 @@ +use std::borrow::Cow; + +use polars_core::prelude::{ArrowSchema, DataFrame, DataType, IDX_DTYPE, Series}; +use polars_core::schema::SchemaNamesAndDtypes; +use polars_error::{PolarsResult, polars_bail}; + +use crate::RowIndex; +use crate::hive::materialize_hive_partitions; +use crate::utils::apply_projection; + +pub fn materialize_empty_df( + projection: Option<&[usize]>, + reader_schema: &ArrowSchema, + hive_partition_columns: Option<&[Series]>, + row_index: Option<&RowIndex>, +) -> DataFrame { + let schema = if let Some(projection) = projection { + Cow::Owned(apply_projection(reader_schema, projection)) + } else { + Cow::Borrowed(reader_schema) + }; + let mut df = DataFrame::empty_with_arrow_schema(&schema); + + if let Some(row_index) = row_index { + df.insert_column(0, Series::new_empty(row_index.name.clone(), &IDX_DTYPE)) + .unwrap(); + } + + materialize_hive_partitions(&mut df, reader_schema, hive_partition_columns); + + df +} + +pub(super) fn projected_arrow_schema_to_projection_indices( + schema: &ArrowSchema, + projected_arrow_schema: &ArrowSchema, +) -> PolarsResult>> { + let mut projection_indices = Vec::with_capacity(projected_arrow_schema.len()); + let mut is_full_ordered_projection = projected_arrow_schema.len() == schema.len(); + + for (i, field) in projected_arrow_schema.iter_values().enumerate() { + let dtype = { + let Some((idx, _, field)) = schema.get_full(&field.name) else { + polars_bail!(ColumnNotFound: "did not find column in file: {}", field.name) + }; + + projection_indices.push(idx); + is_full_ordered_projection &= idx == i; + + DataType::from_arrow_field(field) + }; + let expected_dtype = DataType::from_arrow_field(field); + + if dtype.clone() != expected_dtype { + polars_bail!( + mismatch, + col = &field.name, + expected = expected_dtype, + found = dtype + ); + } + } + + Ok((!is_full_ordered_projection).then_some(projection_indices)) +} + +/// Utility to ensure the dtype of the column in `current_schema` matches the dtype in `schema` if +/// that column exists in `schema`. +pub fn ensure_matching_dtypes_if_found( + schema: &ArrowSchema, + current_schema: &ArrowSchema, +) -> PolarsResult<()> { + current_schema + .iter_names_and_dtypes() + .try_for_each(|(name, dtype)| { + if let Some(field) = schema.get(name) { + if dtype != &field.dtype { + // Check again with timezone normalization + // TODO: Add an ArrowDtype eq wrapper? + let lhs = DataType::from_arrow_dtype(dtype); + let rhs = DataType::from_arrow_field(field); + + if lhs != rhs { + polars_bail!( + SchemaMismatch: + "dtypes differ for column {}: {:?} != {:?}" + , name, dtype, &field.dtype + ); + } + } + } + Ok(()) + }) +} diff --git a/crates/polars-io/src/parquet/write/batched_writer.rs b/crates/polars-io/src/parquet/write/batched_writer.rs new file mode 100644 index 000000000000..db69ed5b3441 --- /dev/null +++ b/crates/polars-io/src/parquet/write/batched_writer.rs @@ -0,0 +1,239 @@ +use std::io::Write; +use std::sync::Mutex; + +use arrow::record_batch::RecordBatch; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_parquet::read::{ParquetError, fallible_streaming_iterator}; +use polars_parquet::write::{ + CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator, + FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions, + array_to_columns, +}; +use rayon::prelude::*; + +pub struct BatchedWriter { + // A mutex so that streaming engine can get concurrent read access to + // compress pages. + // + // @TODO: Remove mutex when old streaming engine is removed + pub(super) writer: Mutex>, + // @TODO: Remove when old streaming engine is removed + pub(super) parquet_schema: SchemaDescriptor, + pub(super) encodings: Vec>, + pub(super) options: WriteOptions, + pub(super) parallel: bool, +} + +impl BatchedWriter { + pub fn new( + writer: Mutex>, + encodings: Vec>, + options: WriteOptions, + parallel: bool, + ) -> Self { + Self { + writer, + parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]), + encodings, + options, + parallel, + } + } + + pub fn encode_and_compress<'a>( + &'a self, + df: &'a DataFrame, + ) -> impl Iterator>> + 'a { + let rb_iter = df.iter_chunks(CompatLevel::newest(), false); + rb_iter.filter_map(move |batch| match batch.len() { + 0 => None, + _ => { + let row_group = create_eager_serializer( + batch, + self.parquet_schema.fields(), + self.encodings.as_ref(), + self.options, + ); + + Some(row_group) + }, + }) + } + + /// Write a batch to the parquet writer. + /// + /// # Panics + /// The caller must ensure the chunks in the given [`DataFrame`] are aligned. + pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + let row_group_iter = prepare_rg_iter( + df, + &self.parquet_schema, + &self.encodings, + self.options, + self.parallel, + ); + // Lock before looping so that order is maintained under contention. + let mut writer = self.writer.lock().unwrap(); + for group in row_group_iter { + writer.write(group?)?; + } + Ok(()) + } + + pub fn parquet_schema(&mut self) -> &SchemaDescriptor { + let writer = self.writer.get_mut().unwrap(); + writer.parquet_schema() + } + + pub fn write_row_group(&mut self, rg: &[Vec]) -> PolarsResult<()> { + let writer = self.writer.get_mut().unwrap(); + let rg = DynIter::new(rg.iter().map(|col_pages| { + Ok(DynStreamingIterator::new( + fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)), + )) + })); + writer.write(rg)?; + Ok(()) + } + + pub fn get_writer(&self) -> &Mutex> { + &self.writer + } + + pub fn write_row_groups( + &self, + rgs: Vec>, + ) -> PolarsResult<()> { + // Lock before looping so that order is maintained. + let mut writer = self.writer.lock().unwrap(); + for group in rgs { + writer.write(group)?; + } + Ok(()) + } + + /// Writes the footer of the parquet file. Returns the total size of the file. + pub fn finish(&self) -> PolarsResult { + let mut writer = self.writer.lock().unwrap(); + let size = writer.end(None)?; + Ok(size) + } +} + +// Note that the df should be rechunked +fn prepare_rg_iter<'a>( + df: &'a DataFrame, + parquet_schema: &'a SchemaDescriptor, + encodings: &'a [Vec], + options: WriteOptions, + parallel: bool, +) -> impl Iterator>> + 'a { + let rb_iter = df.iter_chunks(CompatLevel::newest(), false); + rb_iter.filter_map(move |batch| match batch.len() { + 0 => None, + _ => { + let row_group = + create_serializer(batch, parquet_schema.fields(), encodings, options, parallel); + + Some(row_group) + }, + }) +} + +fn pages_iter_to_compressor( + encoded_columns: Vec>>, + options: WriteOptions, +) -> Vec>> { + encoded_columns + .into_iter() + .map(|encoded_pages| { + // iterator over pages + let pages = DynStreamingIterator::new( + Compressor::new_from_vec( + encoded_pages.map(|result| { + result.map_err(|e| { + ParquetError::FeatureNotSupported(format!("reraised in polars: {e}",)) + }) + }), + options.compression, + vec![], + ) + .map_err(PolarsError::from), + ); + + Ok(pages) + }) + .collect::>() +} + +fn array_to_pages_iter( + array: &ArrayRef, + type_: &ParquetType, + encoding: &[Encoding], + options: WriteOptions, +) -> Vec>> { + let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap(); + pages_iter_to_compressor(encoded_columns, options) +} + +fn create_serializer( + batch: RecordBatch, + fields: &[ParquetType], + encodings: &[Vec], + options: WriteOptions, + parallel: bool, +) -> PolarsResult> { + let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec)| { + array_to_pages_iter(array, type_, encoding, options) + }; + + let columns = if parallel { + POOL.install(|| { + batch + .columns() + .par_iter() + .zip(fields) + .zip(encodings) + .flat_map(func) + .collect::>() + }) + } else { + batch + .columns() + .iter() + .zip(fields) + .zip(encodings) + .flat_map(func) + .collect::>() + }; + + let row_group = DynIter::new(columns.into_iter()); + + Ok(row_group) +} + +/// This serializer encodes and compresses all eagerly in memory. +/// Used for separating compute from IO. +fn create_eager_serializer( + batch: RecordBatch, + fields: &[ParquetType], + encodings: &[Vec], + options: WriteOptions, +) -> PolarsResult> { + let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec)| { + array_to_pages_iter(array, type_, encoding, options) + }; + + let columns = batch + .columns() + .iter() + .zip(fields) + .zip(encodings) + .flat_map(func) + .collect::>(); + + let row_group = DynIter::new(columns.into_iter()); + + Ok(row_group) +} diff --git a/crates/polars-io/src/parquet/write/mod.rs b/crates/polars-io/src/parquet/write/mod.rs new file mode 100644 index 000000000000..8548ac3e0570 --- /dev/null +++ b/crates/polars-io/src/parquet/write/mod.rs @@ -0,0 +1,10 @@ +//! Functionality for reading and writing Apache Parquet files. + +mod batched_writer; +mod options; +mod writer; + +pub use batched_writer::BatchedWriter; +pub use options::{BrotliLevel, GzipLevel, ParquetCompression, ParquetWriteOptions, ZstdLevel}; +pub use polars_parquet::write::{RowGroupIterColumns, StatisticsOptions}; +pub use writer::{ParquetWriter, get_encodings}; diff --git a/crates/polars-io/src/parquet/write/options.rs b/crates/polars-io/src/parquet/write/options.rs new file mode 100644 index 000000000000..2b733ad17e80 --- /dev/null +++ b/crates/polars-io/src/parquet/write/options.rs @@ -0,0 +1,96 @@ +use polars_error::PolarsResult; +use polars_parquet::write::{ + BrotliLevel as BrotliLevelParquet, CompressionOptions, GzipLevel as GzipLevelParquet, + StatisticsOptions, ZstdLevel as ZstdLevelParquet, +}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ParquetWriteOptions { + /// Data page compression + pub compression: ParquetCompression, + /// Compute and write column statistics. + pub statistics: StatisticsOptions, + /// If `None` will be all written to a single row group. + pub row_group_size: Option, + /// if `None` will be 1024^2 bytes + pub data_page_size: Option, +} + +/// The compression strategy to use for writing Parquet files. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ParquetCompression { + Uncompressed, + Snappy, + Gzip(Option), + Lzo, + Brotli(Option), + Zstd(Option), + Lz4Raw, +} + +impl Default for ParquetCompression { + fn default() -> Self { + Self::Zstd(None) + } +} + +/// A valid Gzip compression level. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct GzipLevel(u8); + +impl GzipLevel { + pub fn try_new(level: u8) -> PolarsResult { + GzipLevelParquet::try_new(level)?; + Ok(GzipLevel(level)) + } +} + +/// A valid Brotli compression level. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct BrotliLevel(u32); + +impl BrotliLevel { + pub fn try_new(level: u32) -> PolarsResult { + BrotliLevelParquet::try_new(level)?; + Ok(BrotliLevel(level)) + } +} + +/// A valid Zstandard compression level. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ZstdLevel(i32); + +impl ZstdLevel { + pub fn try_new(level: i32) -> PolarsResult { + ZstdLevelParquet::try_new(level)?; + Ok(ZstdLevel(level)) + } +} + +impl From for CompressionOptions { + fn from(value: ParquetCompression) -> Self { + use ParquetCompression::*; + match value { + Uncompressed => CompressionOptions::Uncompressed, + Snappy => CompressionOptions::Snappy, + Gzip(level) => { + CompressionOptions::Gzip(level.map(|v| GzipLevelParquet::try_new(v.0).unwrap())) + }, + Lzo => CompressionOptions::Lzo, + Brotli(level) => { + CompressionOptions::Brotli(level.map(|v| BrotliLevelParquet::try_new(v.0).unwrap())) + }, + Lz4Raw => CompressionOptions::Lz4Raw, + Zstd(level) => { + CompressionOptions::Zstd(level.map(|v| ZstdLevelParquet::try_new(v.0).unwrap())) + }, + } + } +} diff --git a/crates/polars-io/src/parquet/write/writer.rs b/crates/polars-io/src/parquet/write/writer.rs new file mode 100644 index 000000000000..68c72a6d3afe --- /dev/null +++ b/crates/polars-io/src/parquet/write/writer.rs @@ -0,0 +1,158 @@ +use std::io::Write; +use std::sync::Mutex; + +use arrow::datatypes::PhysicalType; +use polars_core::frame::chunk_df_for_writing; +use polars_core::prelude::*; +use polars_parquet::write::{ + CompressionOptions, Encoding, FileWriter, StatisticsOptions, Version, WriteOptions, + to_parquet_schema, transverse, +}; + +use super::ParquetWriteOptions; +use super::batched_writer::BatchedWriter; +use super::options::ParquetCompression; +use crate::shared::schema_to_arrow_checked; + +impl ParquetWriteOptions { + pub fn to_writer(&self, f: F) -> ParquetWriter + where + F: Write, + { + ParquetWriter::new(f) + .with_compression(self.compression) + .with_statistics(self.statistics) + .with_row_group_size(self.row_group_size) + .with_data_page_size(self.data_page_size) + } +} + +/// Write a DataFrame to Parquet format. +#[must_use] +pub struct ParquetWriter { + writer: W, + /// Data page compression + compression: CompressionOptions, + /// Compute and write column statistics. + statistics: StatisticsOptions, + /// if `None` will be 512^2 rows + row_group_size: Option, + /// if `None` will be 1024^2 bytes + data_page_size: Option, + /// Serialize columns in parallel + parallel: bool, +} + +impl ParquetWriter +where + W: Write, +{ + /// Create a new writer + pub fn new(writer: W) -> Self + where + W: Write, + { + ParquetWriter { + writer, + compression: ParquetCompression::default().into(), + statistics: StatisticsOptions::default(), + row_group_size: None, + data_page_size: None, + parallel: true, + } + } + + /// Set the compression used. Defaults to `Zstd`. + /// + /// The default compression `Zstd` has very good performance, but may not yet been supported + /// by older readers. If you want more compatibility guarantees, consider using `Snappy`. + pub fn with_compression(mut self, compression: ParquetCompression) -> Self { + self.compression = compression.into(); + self + } + + /// Compute and write statistic + pub fn with_statistics(mut self, statistics: StatisticsOptions) -> Self { + self.statistics = statistics; + self + } + + /// Set the row group size (in number of rows) during writing. This can reduce memory pressure and improve + /// writing performance. + pub fn with_row_group_size(mut self, size: Option) -> Self { + self.row_group_size = size; + self + } + + /// Sets the maximum bytes size of a data page. If `None` will be 1024^2 bytes. + pub fn with_data_page_size(mut self, limit: Option) -> Self { + self.data_page_size = limit; + self + } + + /// Serialize columns in parallel + pub fn set_parallel(mut self, parallel: bool) -> Self { + self.parallel = parallel; + self + } + + pub fn batched(self, schema: &Schema) -> PolarsResult> { + let schema = schema_to_arrow_checked(schema, CompatLevel::newest(), "parquet")?; + let parquet_schema = to_parquet_schema(&schema)?; + let encodings = get_encodings(&schema); + let options = self.materialize_options(); + let writer = Mutex::new(FileWriter::try_new(self.writer, schema, options)?); + + Ok(BatchedWriter { + writer, + parquet_schema, + encodings, + options, + parallel: self.parallel, + }) + } + + fn materialize_options(&self) -> WriteOptions { + WriteOptions { + statistics: self.statistics, + compression: self.compression, + version: Version::V1, + data_page_size: self.data_page_size, + } + } + + /// Write the given DataFrame in the writer `W`. Returns the total size of the file. + pub fn finish(self, df: &mut DataFrame) -> PolarsResult { + let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?; + let mut batched = self.batched(chunked_df.schema())?; + batched.write_batch(&chunked_df)?; + batched.finish() + } +} + +pub fn get_encodings(schema: &ArrowSchema) -> Vec> { + schema + .iter_values() + .map(|f| transverse(&f.dtype, encoding_map)) + .collect() +} + +/// Declare encodings +fn encoding_map(dtype: &ArrowDataType) -> Encoding { + match dtype.to_physical_type() { + PhysicalType::Dictionary(_) + | PhysicalType::LargeBinary + | PhysicalType::LargeUtf8 + | PhysicalType::Utf8View + | PhysicalType::BinaryView => Encoding::RleDictionary, + PhysicalType::Primitive(dt) => { + use arrow::types::PrimitiveType::*; + match dt { + Float32 | Float64 | Float16 => Encoding::Plain, + _ => Encoding::RleDictionary, + } + }, + // remaining is plain + _ => Encoding::Plain, + } +} diff --git a/crates/polars-io/src/partition.rs b/crates/polars-io/src/partition.rs new file mode 100644 index 000000000000..08812f4d8749 --- /dev/null +++ b/crates/polars-io/src/partition.rs @@ -0,0 +1,200 @@ +//! Functionality for writing a DataFrame partitioned into multiple files. + +use std::path::Path; + +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use rayon::prelude::*; + +use crate::cloud::CloudOptions; +use crate::parquet::write::ParquetWriteOptions; +#[cfg(feature = "ipc")] +use crate::prelude::IpcWriterOptions; +use crate::prelude::URL_ENCODE_CHAR_SET; +use crate::utils::file::try_get_writeable; +use crate::{SerWriter, WriteDataFrameToFile, is_cloud_url}; + +impl WriteDataFrameToFile for ParquetWriteOptions { + fn write_df_to_file( + &self, + df: &mut DataFrame, + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult<()> { + let f = try_get_writeable(path, cloud_options)?; + self.to_writer(f).finish(df)?; + Ok(()) + } +} + +#[cfg(feature = "ipc")] +impl WriteDataFrameToFile for IpcWriterOptions { + fn write_df_to_file( + &self, + df: &mut DataFrame, + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult<()> { + let f = try_get_writeable(path, cloud_options)?; + self.to_writer(f).finish(df)?; + Ok(()) + } +} + +/// Write a partitioned parquet dataset. This functionality is unstable. +pub fn write_partitioned_dataset( + df: &mut DataFrame, + path: &Path, + partition_by: Vec, + file_write_options: &(dyn WriteDataFrameToFile + Send + Sync), + cloud_options: Option<&CloudOptions>, + chunk_size: usize, +) -> PolarsResult<()> { + // Ensure we have a single chunk as the gather will otherwise rechunk per group. + df.as_single_chunk_par(); + + // Note: When adding support for formats other than Parquet, avoid writing the partitioned + // columns into the file. We write them for parquet because they are encoded efficiently with + // RLE and also gives us a way to get the hive schema from the parquet file for free. + let get_hive_path_part = { + let schema = &df.schema(); + + let partition_by_col_idx = partition_by + .iter() + .map(|x| { + let Some(i) = schema.index_of(x.as_str()) else { + polars_bail!(col_not_found = x) + }; + Ok(i) + }) + .collect::>>()?; + + move |df: &DataFrame| { + let cols = df.get_columns(); + + partition_by_col_idx + .iter() + .map(|&i| { + let s = &cols[i].slice(0, 1).cast(&DataType::String).unwrap(); + + format!( + "{}={}", + s.name(), + percent_encoding::percent_encode( + s.str() + .unwrap() + .get(0) + .unwrap_or("__HIVE_DEFAULT_PARTITION__") + .as_bytes(), + URL_ENCODE_CHAR_SET + ) + ) + }) + .collect::>() + .join("/") + } + }; + + let base_path = path; + let is_cloud = is_cloud_url(base_path); + let groups = df.group_by(partition_by)?.take_groups(); + + let init_part_base_dir = |part_df: &DataFrame| { + let path_part = get_hive_path_part(part_df); + let dir = base_path.join(path_part); + + if !is_cloud { + std::fs::create_dir_all(&dir)?; + } + + PolarsResult::Ok(dir) + }; + + fn get_path_for_index(i: usize) -> String { + // Use a fixed-width file name so that it sorts properly. + format!("{:08x}.parquet", i) + } + + let get_n_files_and_rows_per_file = |part_df: &DataFrame| { + let n_files = (part_df.estimated_size() / chunk_size).clamp(1, 0xffff_ffff); + let rows_per_file = (df.height() / n_files).saturating_add(1); + (n_files, rows_per_file) + }; + + let write_part = |mut df: DataFrame, path: &Path| { + file_write_options.write_df_to_file(&mut df, path.to_str().unwrap(), cloud_options)?; + PolarsResult::Ok(()) + }; + + // This is sqrt(N) of the actual limit - we chunk the input both at the groups + // proxy level and within every group. + const MAX_OPEN_FILES: usize = 8; + + let finish_part_df = |df: DataFrame| { + let dir_path = init_part_base_dir(&df)?; + let (n_files, rows_per_file) = get_n_files_and_rows_per_file(&df); + + if n_files == 1 { + write_part(df.clone(), &dir_path.join(get_path_for_index(0))) + } else { + (0..df.height()) + .step_by(rows_per_file) + .enumerate() + .collect::>() + .chunks(MAX_OPEN_FILES) + .map(|chunk| { + chunk + .into_par_iter() + .map(|&(idx, slice_start)| { + let df = df.slice(slice_start as i64, rows_per_file); + write_part(df.clone(), &dir_path.join(get_path_for_index(idx))) + }) + .reduce( + || PolarsResult::Ok(()), + |a, b| if a.is_err() { a } else { b }, + ) + }) + .collect::>>()?; + Ok(()) + } + }; + + POOL.install(|| match groups.as_ref() { + GroupsType::Idx(idx) => idx + .all() + .chunks(MAX_OPEN_FILES) + .map(|chunk| { + chunk + .par_iter() + .map(|group| { + let df = unsafe { + df._take_unchecked_slice_sorted(group, true, IsSorted::Ascending) + }; + finish_part_df(df) + }) + .reduce( + || PolarsResult::Ok(()), + |a, b| if a.is_err() { a } else { b }, + ) + }) + .collect::>>(), + GroupsType::Slice { groups, .. } => groups + .chunks(MAX_OPEN_FILES) + .map(|chunk| { + chunk + .into_par_iter() + .map(|&[offset, len]| { + let df = df.slice(offset as i64, len as usize); + finish_part_df(df) + }) + .reduce( + || PolarsResult::Ok(()), + |a, b| if a.is_err() { a } else { b }, + ) + }) + .collect::>>(), + })?; + + Ok(()) +} diff --git a/crates/polars-io/src/path_utils/hugging_face.rs b/crates/polars-io/src/path_utils/hugging_face.rs new file mode 100644 index 000000000000..d67559b23ae3 --- /dev/null +++ b/crates/polars-io/src/path_utils/hugging_face.rs @@ -0,0 +1,411 @@ +// Hugging Face path resolution support + +use std::borrow::Cow; +use std::collections::VecDeque; +use std::path::PathBuf; + +use polars_error::{PolarsResult, polars_bail, to_compute_err}; + +use crate::cloud::{ + CloudConfig, CloudOptions, Matcher, extract_prefix_expansion, + try_build_http_header_map_from_items_slice, +}; +use crate::path_utils::HiveIdxTracker; +use crate::pl_async::with_concurrency_budget; +use crate::prelude::URL_ENCODE_CHAR_SET; +use crate::utils::decode_json_response; + +#[derive(Debug, PartialEq)] +struct HFPathParts { + bucket: String, + repository: String, + revision: String, + /// Path relative to the repository root. + path: String, +} + +struct HFRepoLocation { + api_base_path: String, + download_base_path: String, +} + +impl HFRepoLocation { + fn new(bucket: &str, repository: &str, revision: &str) -> Self { + let bucket = percent_encode(bucket.as_bytes()); + let repository = percent_encode(repository.as_bytes()); + + // "https://huggingface.co/api/ [datasets | spaces] / {username} / {reponame} / tree / {revision} / {path from root}" + let api_base_path = format!( + "{}{}{}{}{}{}{}", + "https://huggingface.co/api/", bucket, "/", repository, "/tree/", revision, "/" + ); + let download_base_path = format!( + "{}{}{}{}{}{}{}", + "https://huggingface.co/", bucket, "/", repository, "/resolve/", revision, "/" + ); + + Self { + api_base_path, + download_base_path, + } + } + + fn get_file_uri(&self, rel_path: &str) -> String { + format!( + "{}{}", + self.download_base_path, + percent_encode(rel_path.as_bytes()) + ) + } + + fn get_api_uri(&self, rel_path: &str) -> String { + format!( + "{}{}", + self.api_base_path, + percent_encode(rel_path.as_bytes()) + ) + } +} + +impl HFPathParts { + /// Extracts path components from a hugging face path: + /// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}` + fn try_from_uri(uri: &str) -> PolarsResult { + let Some(this) = (|| { + // hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root} + // !> + if !uri.starts_with("hf://") { + return None; + } + let uri = &uri[5..]; + + // [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root} + // ^-----------------^ !> + let i = memchr::memchr(b'/', uri.as_bytes())?; + let bucket = uri.get(..i)?.to_string(); + let uri = uri.get(1 + i..)?; + + // {username} / {reponame} @ {revision} / {path from root} + // ^----------------------------------^ !> + let i = memchr::memchr(b'/', uri.as_bytes())?; + let i = { + // Also handle if they just give the repository, i.e.: + // hf:// [datasets | spaces] / {username} / {reponame} @ {revision} + let uri = uri.get(1 + i..)?; + if uri.is_empty() { + return None; + } + 1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len()) + }; + let repository = uri.get(..i)?; + let uri = uri.get(1 + i..).unwrap_or(""); + + let (repository, revision) = + if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) { + (repository[..i].to_string(), repository[1 + i..].to_string()) + } else { + // No @revision in uri, default to `main` + (repository.to_string(), "main".to_string()) + }; + + // {path from root} + // ^--------------^ + let path = uri.to_string(); + + Some(HFPathParts { + bucket, + repository, + revision, + path, + }) + })() else { + polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri); + }; + + const BUCKETS: [&str; 2] = ["datasets", "spaces"]; + if !BUCKETS.contains(&this.bucket.as_str()) { + polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket); + } + + Ok(this) + } +} + +#[derive(Debug, serde::Deserialize)] +struct HFAPIResponse { + #[serde(rename = "type")] + type_: String, + path: String, + size: u64, +} + +impl HFAPIResponse { + fn is_file(&self) -> bool { + self.type_ == "file" + } + + fn is_directory(&self) -> bool { + self.type_ == "directory" + } +} + +/// API response is paginated with a `link` header. +/// * https://huggingface.co/docs/hub/en/api#get-apidatasets +/// * https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api?apiVersion=2022-11-28#using-link-headers +struct GetPages<'a> { + client: &'a reqwest::Client, + uri: Option, +} + +impl GetPages<'_> { + async fn next(&mut self) -> Option> { + let uri = self.uri.take()?; + + Some( + async { + let resp = with_concurrency_budget(1, || async { + self.client.get(uri).send().await.map_err(to_compute_err) + }) + .await?; + + self.uri = resp + .headers() + .get("link") + .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes())) + .transpose()?; + + let resp_bytes = resp.bytes().await.map_err(to_compute_err)?; + + Ok(resp_bytes) + } + .await, + ) + } + + fn find_link(mut link: &[u8], rel: &[u8]) -> Option> { + // "; rel=\"next\", ; rel=\"last\"" + while !link.is_empty() { + let i = memchr::memchr(b'<', link)?; + link = link.get(1 + i..)?; + let i = memchr::memchr(b'>', link)?; + let uri = &link[..i]; + link = link.get(1 + i..)?; + + while !link.starts_with("rel=\"".as_bytes()) { + link = link.get(1..)? + } + + // rel="next" + link = link.get(5..)?; + let i = memchr::memchr(b'"', link)?; + + if &link[..i] == rel { + return Some( + std::str::from_utf8(uri) + .map_err(to_compute_err) + .map(ToString::to_string), + ); + } + } + + None + } +} + +pub(super) async fn expand_paths_hf( + paths: &[PathBuf], + check_directory_level: bool, + cloud_options: Option<&CloudOptions>, + glob: bool, +) -> PolarsResult<(usize, Vec)> { + assert!(!paths.is_empty()); + + let client = reqwest::ClientBuilder::new().http1_only().https_only(true); + + let client = if let Some(CloudOptions { + config: Some(CloudConfig::Http { headers }), + .. + }) = cloud_options + { + client.default_headers(try_build_http_header_map_from_items_slice( + headers.as_slice(), + )?) + } else { + client + }; + + let client = &client.build().unwrap(); + + let mut out_paths = vec![]; + let mut stack = VecDeque::new(); + let mut entries = vec![]; + let mut hive_idx_tracker = HiveIdxTracker { + idx: usize::MAX, + paths, + check_directory_level, + }; + + for (path_idx, path) in paths.iter().enumerate() { + let path_parts = &HFPathParts::try_from_uri(path.to_str().unwrap())?; + let repo_location = &HFRepoLocation::new( + &path_parts.bucket, + &path_parts.repository, + &path_parts.revision, + ); + let rel_path = path_parts.path.as_str(); + + let (prefix, expansion) = if glob { + extract_prefix_expansion(rel_path)? + } else { + (Cow::Owned(path_parts.path.clone()), None) + }; + let expansion_matcher = &if expansion.is_some() { + Some(Matcher::new(prefix.to_string(), expansion.as_deref())?) + } else { + None + }; + + if !path_parts.path.ends_with("/") && expansion.is_none() { + hive_idx_tracker.update(0, path_idx)?; + let file_uri = repo_location.get_file_uri(rel_path); + let file_uri = file_uri.as_str(); + + if with_concurrency_budget(1, || async { + client.head(file_uri).send().await.map_err(to_compute_err) + }) + .await? + .status() + == 200 + { + out_paths.push(PathBuf::from(file_uri)); + continue; + } + } + + hive_idx_tracker.update(repo_location.get_file_uri(rel_path).len(), path_idx)?; + + assert!(stack.is_empty()); + stack.push_back(prefix.into_owned()); + + while let Some(rel_path) = stack.pop_front() { + assert!(entries.is_empty()); + + let uri = repo_location.get_api_uri(rel_path.as_str()); + let mut gp = GetPages { + uri: Some(uri), + client, + }; + + if let Some(matcher) = expansion_matcher { + while let Some(bytes) = gp.next().await { + let bytes = bytes?; + let bytes = bytes.as_ref(); + let response: Vec = decode_json_response(bytes)?; + entries.extend(response.into_iter().filter(|x| { + !x.is_file() || (x.size > 0 && matcher.is_matching(x.path.as_str())) + })); + } + } else { + while let Some(bytes) = gp.next().await { + let bytes = bytes?; + let bytes = bytes.as_ref(); + let response: Vec = decode_json_response(bytes)?; + entries.extend(response.into_iter().filter(|x| !x.is_file() || x.size > 0)); + } + } + + entries.sort_unstable_by(|a, b| a.path.as_str().partial_cmp(b.path.as_str()).unwrap()); + + for e in entries.drain(..) { + if e.is_file() { + out_paths.push(PathBuf::from(repo_location.get_file_uri(&e.path))); + } else if e.is_directory() { + stack.push_back(e.path); + } + } + } + } + + Ok((hive_idx_tracker.idx, out_paths)) +} + +fn percent_encode(bytes: &[u8]) -> percent_encoding::PercentEncode { + percent_encoding::percent_encode(bytes, URL_ENCODE_CHAR_SET) +} + +mod tests { + + #[test] + fn test_hf_path_from_uri() { + use super::HFPathParts; + + let uri = "hf://datasets/pola-rs/polars/README.md"; + let expect = HFPathParts { + bucket: "datasets".into(), + repository: "pola-rs/polars".into(), + revision: "main".into(), + path: "README.md".into(), + }; + + assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect); + + let uri = "hf://spaces/pola-rs/polars@~parquet/"; + let expect = HFPathParts { + bucket: "spaces".into(), + repository: "pola-rs/polars".into(), + revision: "~parquet".into(), + path: "".into(), + }; + + assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect); + + let uri = "hf://spaces/pola-rs/polars@~parquet"; + let expect = HFPathParts { + bucket: "spaces".into(), + repository: "pola-rs/polars".into(), + revision: "~parquet".into(), + path: "".into(), + }; + + assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect); + + for uri in [ + "://", + "s3://", + "https://", + "hf://", + "hf:///", + "hf:////", + "hf://datasets/a", + "hf://datasets/a/", + "hf://bucket/a/b/c", // Invalid bucket name + ] { + let out = HFPathParts::try_from_uri(uri); + if out.is_err() { + continue; + } + panic!("expected err result for uri {} instead of {:?}", uri, out); + } + } + + #[test] + fn test_get_pages_find_next_link() { + use super::GetPages; + let link = r#"; rel="next", ; rel="last""#.as_bytes(); + + assert_eq!( + GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap), + Some("https://api.github.com/repositories/263727855/issues?page=3".into()), + ); + + assert_eq!( + GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap), + Some("https://api.github.com/repositories/263727855/issues?page=7".into()), + ); + + assert_eq!( + GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap), + None, + ); + } +} diff --git a/crates/polars-io/src/path_utils/mod.rs b/crates/polars-io/src/path_utils/mod.rs new file mode 100644 index 000000000000..78ab2ec4d700 --- /dev/null +++ b/crates/polars-io/src/path_utils/mod.rs @@ -0,0 +1,577 @@ +use std::collections::VecDeque; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock}; + +use polars_core::config; +use polars_core::error::{PolarsError, PolarsResult, polars_bail, to_compute_err}; +use polars_utils::pl_str::PlSmallStr; + +#[cfg(feature = "cloud")] +mod hugging_face; + +use crate::cloud::CloudOptions; + +pub static POLARS_TEMP_DIR_BASE_PATH: LazyLock> = LazyLock::new(|| { + (|| { + let verbose = config::verbose(); + + let path = if let Ok(v) = std::env::var("POLARS_TEMP_DIR").map(PathBuf::from) { + if verbose { + eprintln!("init_temp_dir: sourced from POLARS_TEMP_DIR") + } + v + } else if cfg!(target_family = "unix") { + let id = std::env::var("USER") + .inspect(|_| { + if verbose { + eprintln!("init_temp_dir: sourced $USER") + } + }) + .or_else(|_e| { + // We shouldn't hit here, but we can fallback to hashing $HOME if blake3 is + // available (it is available when file_cache is activated). + #[cfg(feature = "file_cache")] + { + std::env::var("HOME") + .inspect(|_| { + if verbose { + eprintln!("init_temp_dir: sourced $HOME") + } + }) + .map(|x| blake3::hash(x.as_bytes()).to_hex()[..32].to_string()) + } + #[cfg(not(feature = "file_cache"))] + { + Err(_e) + } + }); + + if let Ok(v) = id { + std::env::temp_dir().join(format!("polars-{}/", v)) + } else { + return Err(std::io::Error::other( + "could not load $USER or $HOME environment variables", + )); + } + } else if cfg!(target_family = "windows") { + // Setting permissions on Windows is not as easy compared to Unix, but fortunately + // the default temporary directory location is underneath the user profile, so we + // shouldn't need to do anything. + std::env::temp_dir().join("polars/") + } else { + std::env::temp_dir().join("polars/") + } + .into_boxed_path(); + + if let Err(err) = std::fs::create_dir_all(path.as_ref()) { + if !path.is_dir() { + panic!( + "failed to create temporary directory: {} (path = {:?})", + err, + path.as_ref() + ); + } + } + + #[cfg(target_family = "unix")] + { + use std::os::unix::fs::PermissionsExt; + + let result = (|| { + std::fs::set_permissions(path.as_ref(), std::fs::Permissions::from_mode(0o700))?; + let perms = std::fs::metadata(path.as_ref())?.permissions(); + + if (perms.mode() % 0o1000) != 0o700 { + std::io::Result::Err(std::io::Error::other(format!( + "permission mismatch: {:?}", + perms + ))) + } else { + std::io::Result::Ok(()) + } + })() + .map_err(|e| { + std::io::Error::new( + e.kind(), + format!( + "error setting temporary directory permissions: {} (path = {:?})", + e, + path.as_ref() + ), + ) + }); + + if std::env::var("POLARS_ALLOW_UNSECURED_TEMP_DIR").as_deref() != Ok("1") { + result?; + } + } + + std::io::Result::Ok(path) + })() + .map_err(|e| { + std::io::Error::new( + e.kind(), + format!( + "error initializing temporary directory: {} \ + consider explicitly setting POLARS_TEMP_DIR", + e + ), + ) + }) + .unwrap() +}); + +/// Replaces a "~" in the Path with the home directory. +pub fn resolve_homedir(path: &dyn AsRef) -> PathBuf { + let path = path.as_ref(); + + if path.starts_with("~") { + // home crate does not compile on wasm https://github.com/rust-lang/cargo/issues/12297 + #[cfg(not(target_family = "wasm"))] + if let Some(homedir) = home::home_dir() { + return homedir.join(path.strip_prefix("~").unwrap()); + } + } + + path.into() +} + +polars_utils::regex_cache::cached_regex! { + static CLOUD_URL = r"^(s3a?|gs|gcs|file|abfss?|azure|az|adl|https?|hf)://"; +} + +/// Check if the path is a cloud url. +pub fn is_cloud_url>(p: P) -> bool { + match p.as_ref().as_os_str().to_str() { + Some(s) => CLOUD_URL.is_match(s), + _ => false, + } +} + +/// Get the index of the first occurrence of a glob symbol. +pub fn get_glob_start_idx(path: &[u8]) -> Option { + memchr::memchr3(b'*', b'?', b'[', path) +} + +/// Returns `true` if `expanded_paths` were expanded from a single directory +pub fn expanded_from_single_directory>( + paths: &[P], + expanded_paths: &[P], +) -> bool { + // Single input that isn't a glob + paths.len() == 1 && get_glob_start_idx(paths[0].as_ref().to_str().unwrap().as_bytes()).is_none() + // And isn't a file + && { + ( + // For local paths, we can just use `is_dir` + !is_cloud_url(paths[0].as_ref()) && paths[0].as_ref().is_dir() + ) + || ( + // For cloud paths, we determine that the input path isn't a file by checking that the + // output path differs. + expanded_paths.is_empty() || (paths[0].as_ref() != expanded_paths[0].as_ref()) + ) + } +} + +/// Recursively traverses directories and expands globs if `glob` is `true`. +pub fn expand_paths( + paths: &[PathBuf], + glob: bool, + #[allow(unused_variables)] cloud_options: Option<&CloudOptions>, +) -> PolarsResult> { + expand_paths_hive(paths, glob, cloud_options, false).map(|x| x.0) +} + +struct HiveIdxTracker<'a> { + idx: usize, + paths: &'a [PathBuf], + check_directory_level: bool, +} + +impl HiveIdxTracker<'_> { + fn update(&mut self, i: usize, path_idx: usize) -> PolarsResult<()> { + let check_directory_level = self.check_directory_level; + let paths = self.paths; + + if check_directory_level + && ![usize::MAX, i].contains(&self.idx) + // They could still be the same directory level, just with different name length + && (paths[path_idx].parent() != paths[path_idx - 1].parent()) + { + polars_bail!( + InvalidOperation: + "attempted to read from different directory levels with hive partitioning enabled: first path: {}, second path: {}", + paths[path_idx - 1].to_str().unwrap(), + paths[path_idx].to_str().unwrap(), + ) + } else { + self.idx = std::cmp::min(self.idx, i); + Ok(()) + } + } +} + +/// Recursively traverses directories and expands globs if `glob` is `true`. +/// Returns the expanded paths and the index at which to start parsing hive +/// partitions from the path. +pub fn expand_paths_hive( + paths: &[PathBuf], + glob: bool, + #[allow(unused_variables)] cloud_options: Option<&CloudOptions>, + check_directory_level: bool, +) -> PolarsResult<(Arc<[PathBuf]>, usize)> { + let Some(first_path) = paths.first() else { + return Ok((vec![].into(), 0)); + }; + + let is_cloud = is_cloud_url(first_path); + + /// Wrapper around `Vec` that also tracks file extensions, so that + /// we don't have to traverse the entire list again to validate extensions. + struct OutPaths { + paths: Vec, + exts: [Option<(PlSmallStr, usize)>; 2], + current_idx: usize, + } + + impl OutPaths { + fn update_ext_status( + current_idx: &mut usize, + exts: &mut [Option<(PlSmallStr, usize)>; 2], + value: &Path, + ) { + let ext = value + .extension() + .map(|x| PlSmallStr::from(x.to_str().unwrap())) + .unwrap_or(PlSmallStr::EMPTY); + + if exts[0].is_none() { + exts[0] = Some((ext, *current_idx)); + } else if exts[1].is_none() && ext != exts[0].as_ref().unwrap().0 { + exts[1] = Some((ext, *current_idx)); + } + + *current_idx += 1; + } + + fn push(&mut self, value: PathBuf) { + { + let current_idx = &mut self.current_idx; + let exts = &mut self.exts; + Self::update_ext_status(current_idx, exts, &value); + } + self.paths.push(value) + } + + fn extend(&mut self, values: impl IntoIterator) { + let current_idx = &mut self.current_idx; + let exts = &mut self.exts; + + self.paths.extend(values.into_iter().inspect(|x| { + Self::update_ext_status(current_idx, exts, x); + })) + } + + fn extend_from_slice(&mut self, values: &[PathBuf]) { + self.extend(values.iter().cloned()) + } + } + + let mut out_paths = OutPaths { + paths: vec![], + exts: [None, None], + current_idx: 0, + }; + + let mut hive_idx_tracker = HiveIdxTracker { + idx: usize::MAX, + paths, + check_directory_level, + }; + + if is_cloud || { cfg!(not(target_family = "windows")) && config::force_async() } { + #[cfg(feature = "cloud")] + { + use polars_utils::_limit_path_len_io_err; + + use crate::cloud::object_path_from_str; + + if first_path.starts_with("hf://") { + let (expand_start_idx, paths) = crate::pl_async::get_runtime().block_in_place_on( + hugging_face::expand_paths_hf( + paths, + check_directory_level, + cloud_options, + glob, + ), + )?; + + return Ok((Arc::from(paths), expand_start_idx)); + } + + let format_path = |scheme: &str, bucket: &str, location: &str| { + if is_cloud { + format!("{}://{}/{}", scheme, bucket, location) + } else { + format!("/{}", location) + } + }; + + let expand_path_cloud = |path: &str, + cloud_options: Option<&CloudOptions>| + -> PolarsResult<(usize, Vec)> { + crate::pl_async::get_runtime().block_in_place_on(async { + let (cloud_location, store) = + crate::cloud::build_object_store(path, cloud_options, glob).await?; + let prefix = object_path_from_str(&cloud_location.prefix)?; + + let out = if !path.ends_with("/") + && (!glob || cloud_location.expansion.is_none()) + && { + // We need to check if it is a directory for local paths (we can be here due + // to FORCE_ASYNC). For cloud paths the convention is that the user must add + // a trailing slash `/` to scan directories. We don't infer it as that would + // mean sending one network request per path serially (very slow). + is_cloud || PathBuf::from(path).is_file() + } { + ( + 0, + vec![PathBuf::from(format_path( + &cloud_location.scheme, + &cloud_location.bucket, + prefix.as_ref(), + ))], + ) + } else { + use futures::TryStreamExt; + + if !is_cloud { + // FORCE_ASYNC in the test suite wants us to raise a proper error message + // for non-existent file paths. Note we can't do this for cloud paths as + // there is no concept of a "directory" - a non-existent path is + // indistinguishable from an empty directory. + let path = PathBuf::from(path); + if !path.is_dir() { + path.metadata() + .map_err(|err| _limit_path_len_io_err(&path, err))?; + } + } + + let cloud_location = &cloud_location; + + let mut paths = store + .try_exec_rebuild_on_err(|store| { + let st = store.clone(); + + async { + let store = st; + store + .list(Some(&prefix)) + .try_filter_map(|x| async move { + let out = (x.size > 0).then(|| { + PathBuf::from({ + format_path( + &cloud_location.scheme, + &cloud_location.bucket, + x.location.as_ref(), + ) + }) + }); + Ok(out) + }) + .try_collect::>() + .await + .map_err(to_compute_err) + } + }) + .await?; + + paths.sort_unstable(); + ( + format_path( + &cloud_location.scheme, + &cloud_location.bucket, + &cloud_location.prefix, + ) + .len(), + paths, + ) + }; + + PolarsResult::Ok(out) + }) + }; + + for (path_idx, path) in paths.iter().enumerate() { + if path.to_str().unwrap().starts_with("http") { + out_paths.push(path.clone()); + hive_idx_tracker.update(0, path_idx)?; + continue; + } + + let glob_start_idx = get_glob_start_idx(path.to_str().unwrap().as_bytes()); + + let path = if glob && glob_start_idx.is_some() { + path.clone() + } else { + let (expand_start_idx, paths) = + expand_path_cloud(path.to_str().unwrap(), cloud_options)?; + out_paths.extend_from_slice(&paths); + hive_idx_tracker.update(expand_start_idx, path_idx)?; + continue; + }; + + hive_idx_tracker.update(0, path_idx)?; + + let iter = crate::pl_async::get_runtime() + .block_in_place_on(crate::async_glob(path.to_str().unwrap(), cloud_options))?; + + if is_cloud { + out_paths.extend(iter.into_iter().map(PathBuf::from)); + } else { + // FORCE_ASYNC, remove leading file:// as not all readers support it. + out_paths.extend(iter.iter().map(|x| &x[7..]).map(PathBuf::from)) + } + } + } + #[cfg(not(feature = "cloud"))] + panic!("Feature `cloud` must be enabled to use globbing patterns with cloud urls.") + } else { + let mut stack = VecDeque::new(); + + for path_idx in 0..paths.len() { + let path = &paths[path_idx]; + stack.clear(); + + if path.is_dir() { + let i = path.to_str().unwrap().len(); + + hive_idx_tracker.update(i, path_idx)?; + + stack.push_back(path.clone()); + + while let Some(dir) = stack.pop_front() { + let mut paths = std::fs::read_dir(dir) + .map_err(PolarsError::from)? + .map(|x| x.map(|x| x.path())) + .collect::>>() + .map_err(PolarsError::from)?; + paths.sort_unstable(); + + for path in paths { + if path.is_dir() { + stack.push_back(path); + } else if path.metadata()?.len() > 0 { + out_paths.push(path); + } + } + } + + continue; + } + + let i = get_glob_start_idx(path.to_str().unwrap().as_bytes()); + + if glob && i.is_some() { + hive_idx_tracker.update(0, path_idx)?; + + let Ok(paths) = glob::glob(path.to_str().unwrap()) else { + polars_bail!(ComputeError: "invalid glob pattern given") + }; + + for path in paths { + let path = path.map_err(to_compute_err)?; + if !path.is_dir() && path.metadata()?.len() > 0 { + out_paths.push(path); + } + } + } else { + hive_idx_tracker.update(0, path_idx)?; + out_paths.push(path.clone()); + } + } + } + + assert_eq!(out_paths.current_idx, out_paths.paths.len()); + + if expanded_from_single_directory(paths, out_paths.paths.as_slice()) { + if let [Some((_, i1)), Some((_, i2))] = out_paths.exts { + polars_bail!( + InvalidOperation: r#"directory contained paths with different file extensions: \ + first path: {}, second path: {}. Please use a glob pattern to explicitly specify \ + which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#, + &out_paths.paths[i1].to_string_lossy(), &out_paths.paths[i2].to_string_lossy() + ) + } + } + + Ok((out_paths.paths.into(), hive_idx_tracker.idx)) +} + +/// Ignores errors from `std::fs::create_dir_all` if the directory exists. +#[cfg(feature = "file_cache")] +pub(crate) fn ensure_directory_init(path: &Path) -> std::io::Result<()> { + let result = std::fs::create_dir_all(path); + + if path.is_dir() { Ok(()) } else { result } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::resolve_homedir; + + #[cfg(not(target_os = "windows"))] + #[test] + fn test_resolve_homedir() { + let paths: Vec = vec![ + "~/dir1/dir2/test.csv".into(), + "/abs/path/test.csv".into(), + "rel/path/test.csv".into(), + "/".into(), + "~".into(), + ]; + + let resolved: Vec = paths.iter().map(|x| resolve_homedir(x)).collect(); + + assert_eq!(resolved[0].file_name(), paths[0].file_name()); + assert!(resolved[0].is_absolute()); + assert_eq!(resolved[1], paths[1]); + assert_eq!(resolved[2], paths[2]); + assert_eq!(resolved[3], paths[3]); + assert!(resolved[4].is_absolute()); + } + + #[cfg(target_os = "windows")] + #[test] + fn test_resolve_homedir_windows() { + let paths: Vec = vec![ + r#"c:\Users\user1\test.csv"#.into(), + r#"~\user1\test.csv"#.into(), + "~".into(), + ]; + + let resolved: Vec = paths.iter().map(|x| resolve_homedir(x)).collect(); + + assert_eq!(resolved[0], paths[0]); + assert_eq!(resolved[1].file_name(), paths[1].file_name()); + assert!(resolved[1].is_absolute()); + assert!(resolved[2].is_absolute()); + } + + #[test] + fn test_http_path_with_query_parameters_is_not_expanded_as_glob() { + // Don't confuse HTTP URL's with query parameters for globs. + // See https://github.com/pola-rs/polars/pull/17774 + use std::path::PathBuf; + + use super::expand_paths; + + let path = "https://pola.rs/test.csv?token=bear"; + let paths = &[PathBuf::from(path)]; + let out = expand_paths(paths, true, None).unwrap(); + assert_eq!(out.as_ref(), paths); + } +} diff --git a/crates/polars-io/src/pl_async.rs b/crates/polars-io/src/pl_async.rs new file mode 100644 index 000000000000..641d8eb7b35d --- /dev/null +++ b/crates/polars-io/src/pl_async.rs @@ -0,0 +1,370 @@ +use std::error::Error; +use std::future::Future; +use std::ops::Deref; +use std::sync::LazyLock; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; + +use polars_core::POOL; +use polars_core::config::{self, verbose}; +use tokio::runtime::{Builder, Runtime}; +use tokio::sync::Semaphore; + +static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new(); +pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10; + +/// Used to determine chunks when splitting large ranges, or combining small +/// ranges. +static DOWNLOAD_CHUNK_SIZE: LazyLock = LazyLock::new(|| { + let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE") + .as_deref() + .map(|x| x.parse().expect("integer")) + .unwrap_or(64 * 1024 * 1024); + + if config::verbose() { + eprintln!("async download_chunk_size: {}", v) + } + + v +}); + +pub(super) fn get_download_chunk_size() -> usize { + *DOWNLOAD_CHUNK_SIZE +} + +static UPLOAD_CHUNK_SIZE: LazyLock = LazyLock::new(|| { + let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE") + .as_deref() + .map(|x| x.parse().expect("integer")) + .unwrap_or(64 * 1024 * 1024); + + if config::verbose() { + eprintln!("async upload_chunk_size: {}", v) + } + + v +}); + +pub(super) fn get_upload_chunk_size() -> usize { + *UPLOAD_CHUNK_SIZE +} + +pub trait GetSize { + fn size(&self) -> u64; +} + +impl GetSize for bytes::Bytes { + fn size(&self) -> u64 { + self.len() as u64 + } +} + +impl GetSize for Vec { + fn size(&self) -> u64 { + self.iter().map(|v| v.size()).sum() + } +} + +impl GetSize for Result { + fn size(&self) -> u64 { + match self { + Ok(v) => v.size(), + Err(_) => 0, + } + } +} + +#[cfg(feature = "cloud")] +pub(crate) struct Size(u64); + +#[cfg(feature = "cloud")] +impl GetSize for Size { + fn size(&self) -> u64 { + self.0 + } +} +#[cfg(feature = "cloud")] +impl From for Size { + fn from(value: u64) -> Self { + Self(value) + } +} + +enum Optimization { + Step, + Accept, + Finished, +} + +struct SemaphoreTuner { + previous_download_speed: u64, + last_tune: std::time::Instant, + downloaded: AtomicU64, + download_time: AtomicU64, + opt_state: Optimization, + increments: u32, +} + +impl SemaphoreTuner { + fn new() -> Self { + Self { + previous_download_speed: 0, + last_tune: std::time::Instant::now(), + downloaded: AtomicU64::new(0), + download_time: AtomicU64::new(0), + opt_state: Optimization::Step, + increments: 0, + } + } + fn should_tune(&self) -> bool { + match self.opt_state { + Optimization::Finished => false, + _ => self.last_tune.elapsed().as_millis() > 350, + } + } + + fn add_stats(&self, downloaded_bytes: u64, download_time: u64) { + self.downloaded + .fetch_add(downloaded_bytes, Ordering::Relaxed); + self.download_time + .fetch_add(download_time, Ordering::Relaxed); + } + + fn increment(&mut self, semaphore: &Semaphore) { + semaphore.add_permits(1); + self.increments += 1; + } + + fn tune(&mut self, semaphore: &'static Semaphore) -> bool { + let bytes_downloaded = self.downloaded.fetch_add(0, Ordering::Relaxed); + let time_elapsed = self.download_time.fetch_add(0, Ordering::Relaxed); + let download_speed = bytes_downloaded + .checked_div(time_elapsed) + .unwrap_or_default(); + + let increased = download_speed > self.previous_download_speed; + self.previous_download_speed = download_speed; + match self.opt_state { + Optimization::Step => { + self.increment(semaphore); + self.opt_state = Optimization::Accept + }, + Optimization::Accept => { + // Accept the step + if increased { + // Set new step + self.increment(semaphore); + // Keep accept state to check next iteration + } + // Decline the step + else { + self.opt_state = Optimization::Finished; + FINISHED_TUNING.store(true, Ordering::Relaxed); + if verbose() { + eprintln!( + "concurrency tuner finished after adding {} steps", + self.increments + ) + } + // Finished. + return true; + } + }, + Optimization::Finished => {}, + } + self.last_tune = std::time::Instant::now(); + // Not finished. + false + } +} +static INCR: AtomicU8 = AtomicU8::new(0); +static FINISHED_TUNING: AtomicBool = AtomicBool::new(false); +static PERMIT_STORE: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +fn get_semaphore() -> &'static (Semaphore, u32) { + CONCURRENCY_BUDGET.get_or_init(|| { + let permits = std::env::var("POLARS_CONCURRENCY_BUDGET") + .map(|s| { + let budget = s.parse::().expect("integer"); + FINISHED_TUNING.store(true, Ordering::Relaxed); + budget + }) + .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST)); + (Semaphore::new(permits), permits as u32) + }) +} + +pub(crate) fn get_concurrency_limit() -> u32 { + get_semaphore().1 +} + +pub async fn tune_with_concurrency_budget(requested_budget: u32, callable: F) -> Fut::Output +where + F: FnOnce() -> Fut, + Fut: Future, + Fut::Output: GetSize, +{ + let (semaphore, initial_budget) = get_semaphore(); + + // This would never finish otherwise. + assert!(requested_budget <= *initial_budget); + + // Keep permit around. + // On drop it is returned to the semaphore. + let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap(); + + let now = std::time::Instant::now(); + let res = callable().await; + + if FINISHED_TUNING.load(Ordering::Relaxed) || res.size() == 0 { + return res; + } + + let duration = now.elapsed().as_millis() as u64; + let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new())); + + let Ok(tuner) = permit_store.try_read() else { + return res; + }; + // Keep track of download speed + tuner.add_stats(res.size(), duration); + + // We only tune every n ms + if !tuner.should_tune() { + return res; + } + // Drop the read tuner before trying to acquire a writer + drop(tuner); + + // Reduce locking by letting only 1 in 5 tasks lock the tuner + if (INCR.fetch_add(1, Ordering::Relaxed) % 5) != 0 { + return res; + } + // Never lock as we will deadlock. This can run under rayon + let Ok(mut tuner) = permit_store.try_write() else { + return res; + }; + let finished = tuner.tune(semaphore); + if finished { + drop(_permit_acq); + // Undo the last step + let undo = semaphore.acquire().await.unwrap(); + std::mem::forget(undo) + } + res +} + +pub async fn with_concurrency_budget(requested_budget: u32, callable: F) -> Fut::Output +where + F: FnOnce() -> Fut, + Fut: Future, +{ + let (semaphore, initial_budget) = get_semaphore(); + + // This would never finish otherwise. + assert!(requested_budget <= *initial_budget); + + // Keep permit around. + // On drop it is returned to the semaphore. + let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap(); + + callable().await +} + +pub struct RuntimeManager { + rt: Runtime, +} + +impl RuntimeManager { + fn new() -> Self { + let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT") + .map(|x| x.parse::().expect("integer")) + .unwrap_or(POOL.current_num_threads().clamp(1, 4)); + + if polars_core::config::verbose() { + eprintln!("async thread count: {}", n_threads); + } + + let rt = Builder::new_multi_thread() + .worker_threads(n_threads) + .enable_io() + .enable_time() + .build() + .unwrap(); + + Self { rt } + } + + /// Forcibly blocks this thread to evaluate the given future. This can be + /// dangerous and lead to deadlocks if called re-entrantly on an async + /// worker thread as the entire thread pool can end up blocking, leading to + /// a deadlock. If you want to prevent this use block_on, which will panic + /// if called from an async thread. + pub fn block_in_place_on(&self, future: F) -> F::Output + where + F: Future, + { + tokio::task::block_in_place(|| self.rt.block_on(future)) + } + + /// Blocks this thread to evaluate the given future. Panics if the current + /// thread is an async runtime worker thread. + pub fn block_on(&self, future: F) -> F::Output + where + F: Future, + { + self.rt.block_on(future) + } + + /// Spawns a future onto the Tokio runtime (see [`tokio::runtime::Runtime::spawn`]). + pub fn spawn(&self, future: F) -> tokio::task::JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.rt.spawn(future) + } + + // See [`tokio::runtime::Runtime::spawn_blocking`]. + pub fn spawn_blocking(&self, f: F) -> tokio::task::JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.rt.spawn_blocking(f) + } + + /// Run a task on the rayon threadpool. To avoid deadlocks, if the current thread is already a + /// rayon thread, the task is executed on the current thread after tokio's `block_in_place` is + /// used to spawn another thread to poll futures. + pub async fn spawn_rayon(&self, func: F) -> O + where + F: FnOnce() -> O + Send + Sync + 'static, + O: Send + Sync + 'static, + { + if POOL.current_thread_index().is_some() { + // We are a rayon thread, so we can't use POOL.spawn as it would mean we spawn a task and block until + // another rayon thread executes it - we would deadlock if all rayon threads did this. + // Safety: The tokio runtime flavor is multi-threaded. + tokio::task::block_in_place(func) + } else { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let func = move || { + let out = func(); + // Don't unwrap send attempt - async task could be cancelled. + let _ = tx.send(out); + }; + + POOL.spawn(func); + + rx.await.unwrap() + } + } +} + +static RUNTIME: LazyLock = LazyLock::new(RuntimeManager::new); + +pub fn get_runtime() -> &'static RuntimeManager { + RUNTIME.deref() +} diff --git a/crates/polars-io/src/predicates.rs b/crates/polars-io/src/predicates.rs new file mode 100644 index 000000000000..00dc0fbc2cac --- /dev/null +++ b/crates/polars-io/src/predicates.rs @@ -0,0 +1,510 @@ +use std::fmt; + +use arrow::array::Array; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use polars_core::prelude::*; +#[cfg(feature = "parquet")] +use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange}; +use polars_utils::format_pl_smallstr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +pub trait PhysicalIoExpr: Send + Sync { + /// Take a [`DataFrame`] and produces a boolean [`Series`] that serves + /// as a predicate mask + fn evaluate_io(&self, df: &DataFrame) -> PolarsResult; +} + +#[derive(Debug, Clone)] +pub enum SpecializedColumnPredicateExpr { + Eq(Scalar), + EqMissing(Scalar), +} + +#[derive(Clone)] +pub struct ColumnPredicateExpr { + column_name: PlSmallStr, + dtype: DataType, + specialized: Option, + expr: Arc, +} + +impl ColumnPredicateExpr { + pub fn new( + column_name: PlSmallStr, + dtype: DataType, + expr: Arc, + specialized: Option, + ) -> Self { + Self { + column_name, + dtype, + specialized, + expr, + } + } + + pub fn is_eq_scalar(&self) -> bool { + self.to_eq_scalar().is_some() + } + pub fn to_eq_scalar(&self) -> Option<&Scalar> { + match &self.specialized { + Some(SpecializedColumnPredicateExpr::Eq(sc)) if !sc.is_null() => Some(sc), + Some(SpecializedColumnPredicateExpr::EqMissing(sc)) => Some(sc), + _ => None, + } + } +} + +#[cfg(feature = "parquet")] +impl ParquetColumnExpr for ColumnPredicateExpr { + fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) { + // We should never evaluate nulls with this. + assert!(values.validity().is_none_or(|v| v.set_bits() == 0)); + + // @TODO: Probably these unwraps should be removed. + let series = + Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype) + .unwrap(); + let column = series.into_column(); + let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) }; + + // @TODO: Probably these unwraps should be removed. + let true_mask = self.expr.evaluate_io(&df).unwrap(); + let true_mask = true_mask.bool().unwrap(); + + bm.reserve(true_mask.len()); + for chunk in true_mask.downcast_iter() { + match chunk.validity() { + None => bm.extend_from_bitmap(chunk.values()), + Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)), + } + } + } + fn evaluate_null(&self) -> bool { + let column = Column::full_null(self.column_name.clone(), 1, &self.dtype); + let df = unsafe { DataFrame::new_no_checks(1, vec![column]) }; + + // @TODO: Probably these unwraps should be removed. + let true_mask = self.expr.evaluate_io(&df).unwrap(); + let true_mask = true_mask.bool().unwrap(); + + true_mask.get(0).unwrap_or(false) + } + + fn to_equals_scalar(&self) -> Option { + self.to_eq_scalar() + .and_then(|s| cast_to_parquet_scalar(s.clone())) + } + + fn to_range_scalar(&self) -> Option { + None + } +} + +#[cfg(feature = "parquet")] +fn cast_to_parquet_scalar(scalar: Scalar) -> Option { + use {AnyValue as A, ParquetScalar as P}; + + Some(match scalar.into_value() { + A::Null => P::Null, + A::Boolean(v) => P::Boolean(v), + + A::UInt8(v) => P::UInt8(v), + A::UInt16(v) => P::UInt16(v), + A::UInt32(v) => P::UInt32(v), + A::UInt64(v) => P::UInt64(v), + + A::Int8(v) => P::Int8(v), + A::Int16(v) => P::Int16(v), + A::Int32(v) => P::Int32(v), + A::Int64(v) => P::Int64(v), + + #[cfg(feature = "dtype-time")] + A::Date(v) => P::Int32(v), + #[cfg(feature = "dtype-datetime")] + A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v), + #[cfg(feature = "dtype-duration")] + A::Duration(v, _) => P::Int64(v), + #[cfg(feature = "dtype-time")] + A::Time(v) => P::Int64(v), + + A::Float32(v) => P::Float32(v), + A::Float64(v) => P::Float64(v), + + // @TODO: Cast to string + #[cfg(feature = "dtype-categorical")] + A::Categorical(_, _, _) + | A::CategoricalOwned(_, _, _) + | A::Enum(_, _, _) + | A::EnumOwned(_, _, _) => return None, + + A::String(v) => P::String(v.into()), + A::StringOwned(v) => P::String(v.as_str().into()), + A::Binary(v) => P::Binary(v.into()), + A::BinaryOwned(v) => P::Binary(v.into()), + _ => return None, + }) +} + +#[cfg(any(feature = "parquet", feature = "ipc"))] +pub fn apply_predicate( + df: &mut DataFrame, + predicate: Option<&dyn PhysicalIoExpr>, + parallel: bool, +) -> PolarsResult<()> { + if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) { + let s = predicate.evaluate_io(df)?; + let mask = s.bool().expect("filter predicates was not of type boolean"); + + if parallel { + *df = df.filter(mask)?; + } else { + *df = df._filter_seq(mask)?; + } + } + Ok(()) +} + +/// Statistics of the values in a column. +/// +/// The following statistics are tracked for each row group: +/// - Null count +/// - Minimum value +/// - Maximum value +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ColumnStats { + field: Field, + // Each Series contains the stats for each row group. + null_count: Option, + min_value: Option, + max_value: Option, +} + +impl ColumnStats { + /// Constructs a new [`ColumnStats`]. + pub fn new( + field: Field, + null_count: Option, + min_value: Option, + max_value: Option, + ) -> Self { + Self { + field, + null_count, + min_value, + max_value, + } + } + + /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics. + pub fn from_field(field: Field) -> Self { + Self { + field, + null_count: None, + min_value: None, + max_value: None, + } + } + + /// Constructs a new [`ColumnStats`] from a single-value Series. + pub fn from_column_literal(s: Series) -> Self { + debug_assert_eq!(s.len(), 1); + Self { + field: s.field().into_owned(), + null_count: None, + min_value: Some(s.clone()), + max_value: Some(s), + } + } + + pub fn field_name(&self) -> &PlSmallStr { + self.field.name() + } + + /// Returns the [`DataType`] of the column. + pub fn dtype(&self) -> &DataType { + self.field.dtype() + } + + /// Returns the null count of each row group of the column. + pub fn get_null_count_state(&self) -> Option<&Series> { + self.null_count.as_ref() + } + + /// Returns the minimum value of each row group of the column. + pub fn get_min_state(&self) -> Option<&Series> { + self.min_value.as_ref() + } + + /// Returns the maximum value of each row group of the column. + pub fn get_max_state(&self) -> Option<&Series> { + self.max_value.as_ref() + } + + /// Returns the null count of the column. + pub fn null_count(&self) -> Option { + match self.dtype() { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => None, + _ => { + let s = self.get_null_count_state()?; + // if all null, there are no statistics. + if s.null_count() != s.len() { + s.sum().ok() + } else { + None + } + }, + } + } + + /// Returns the minimum and maximum values of the column as a single [`Series`]. + pub fn to_min_max(&self) -> Option { + let min_val = self.get_min_state()?; + let max_val = self.get_max_state()?; + let dtype = self.dtype(); + + if !use_min_max(dtype) { + return None; + } + + let mut min_max_values = min_val.clone(); + min_max_values.append(max_val).unwrap(); + if min_max_values.null_count() > 0 { + None + } else { + Some(min_max_values) + } + } + + /// Returns the minimum value of the column as a single-value [`Series`]. + /// + /// Returns `None` if no maximum value is available. + pub fn to_min(&self) -> Option<&Series> { + // @scalar-opt + let min_val = self.min_value.as_ref()?; + let dtype = min_val.dtype(); + + if !use_min_max(dtype) || min_val.len() != 1 { + return None; + } + + if min_val.null_count() > 0 { + None + } else { + Some(min_val) + } + } + + /// Returns the maximum value of the column as a single-value [`Series`]. + /// + /// Returns `None` if no maximum value is available. + pub fn to_max(&self) -> Option<&Series> { + // @scalar-opt + let max_val = self.max_value.as_ref()?; + let dtype = max_val.dtype(); + + if !use_min_max(dtype) || max_val.len() != 1 { + return None; + } + + if max_val.null_count() > 0 { + None + } else { + Some(max_val) + } + } +} + +/// Returns whether the [`DataType`] supports minimum/maximum operations. +fn use_min_max(dtype: &DataType) -> bool { + dtype.is_primitive_numeric() + || dtype.is_temporal() + || matches!( + dtype, + DataType::String | DataType::Binary | DataType::Boolean + ) +} + +pub struct ColumnStatistics { + pub dtype: DataType, + pub min: AnyValue<'static>, + pub max: AnyValue<'static>, + pub null_count: Option, +} + +pub trait SkipBatchPredicate: Send + Sync { + fn schema(&self) -> &SchemaRef; + + fn can_skip_batch( + &self, + batch_size: IdxSize, + live_columns: &PlIndexSet, + mut statistics: PlIndexMap, + ) -> PolarsResult { + let mut columns = Vec::with_capacity(1 + live_columns.len() * 3); + + columns.push(Column::new_scalar( + PlSmallStr::from_static("len"), + Scalar::new(IDX_DTYPE, batch_size.into()), + 1, + )); + + for col in live_columns.iter() { + let dtype = self.schema().get(col).unwrap(); + let (min, max, nc) = match statistics.swap_remove(col) { + None => ( + Scalar::null(dtype.clone()), + Scalar::null(dtype.clone()), + Scalar::null(IDX_DTYPE), + ), + Some(stat) => ( + Scalar::new(dtype.clone(), stat.min), + Scalar::new(dtype.clone(), stat.max), + Scalar::new( + IDX_DTYPE, + stat.null_count.map_or(AnyValue::Null, |nc| nc.into()), + ), + ), + }; + columns.extend([ + Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1), + Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1), + Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1), + ]); + } + + // SAFETY: + // * Each column is length = 1 + // * We have an IndexSet, so each column name is unique + let df = unsafe { DataFrame::new_no_checks(1, columns) }; + Ok(self.evaluate_with_stat_df(&df)?.get_bit(0)) + } + fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult; +} + +#[derive(Clone)] +pub struct ColumnPredicates { + pub predicates: PlHashMap< + PlSmallStr, + ( + Arc, + Option, + ), + >, + pub is_sumwise_complete: bool, +} + +// I want to be explicit here. +#[allow(clippy::derivable_impls)] +impl Default for ColumnPredicates { + fn default() -> Self { + Self { + predicates: PlHashMap::default(), + is_sumwise_complete: false, + } + } +} + +pub struct PhysicalExprWithConstCols { + constants: Vec<(PlSmallStr, Scalar)>, + child: T, +} + +impl SkipBatchPredicate for PhysicalExprWithConstCols> { + fn schema(&self) -> &SchemaRef { + self.child.schema() + } + + fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult { + let mut df = df.clone(); + for (name, scalar) in self.constants.iter() { + df.with_column(Column::new_scalar( + name.clone(), + scalar.clone(), + df.height(), + ))?; + } + self.child.evaluate_with_stat_df(&df) + } +} + +impl PhysicalIoExpr for PhysicalExprWithConstCols> { + fn evaluate_io(&self, df: &DataFrame) -> PolarsResult { + let mut df = df.clone(); + for (name, scalar) in self.constants.iter() { + df.with_column(Column::new_scalar( + name.clone(), + scalar.clone(), + df.height(), + ))?; + } + + self.child.evaluate_io(&df) + } +} + +#[derive(Clone)] +pub struct ScanIOPredicate { + pub predicate: Arc, + + /// Column names that are used in the predicate. + pub live_columns: Arc>, + + /// A predicate that gets given statistics and evaluates whether a batch can be skipped. + pub skip_batch_predicate: Option>, + + /// A predicate that gets given statistics and evaluates whether a batch can be skipped. + pub column_predicates: Arc, +} +impl ScanIOPredicate { + pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) { + if constant_columns.is_empty() { + return; + } + + let mut live_columns = self.live_columns.as_ref().clone(); + for (c, _) in constant_columns.iter() { + live_columns.swap_remove(c); + } + self.live_columns = Arc::new(live_columns); + + if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() { + let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3); + for (c, v) in constant_columns.iter() { + sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone())); + sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone())); + let nc = if v.is_null() { + AnyValue::Null + } else { + (0 as IdxSize).into() + }; + sbp_constant_columns + .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc))); + } + self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols { + constants: sbp_constant_columns, + child: skip_batch_predicate, + })); + } + + let mut column_predicates = self.column_predicates.as_ref().clone(); + for (c, _) in constant_columns.iter() { + column_predicates.predicates.remove(c); + } + self.column_predicates = Arc::new(column_predicates); + + self.predicate = Arc::new(PhysicalExprWithConstCols { + constants: constant_columns, + child: self.predicate.clone(), + }); + } +} + +impl fmt::Debug for ScanIOPredicate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("scan_io_predicate") + } +} diff --git a/crates/polars-io/src/prelude.rs b/crates/polars-io/src/prelude.rs new file mode 100644 index 000000000000..9d4dad35b341 --- /dev/null +++ b/crates/polars-io/src/prelude.rs @@ -0,0 +1,16 @@ +pub use crate::cloud; +#[cfg(feature = "csv")] +pub use crate::csv::{read::*, write::*}; +#[cfg(any(feature = "ipc", feature = "ipc_streaming"))] +pub use crate::ipc::*; +#[cfg(feature = "json")] +pub use crate::json::*; +#[cfg(feature = "json")] +pub use crate::ndjson::core::*; +#[cfg(feature = "parquet")] +pub use crate::parquet::{metadata::*, read::*, write::*}; +#[cfg(feature = "parquet")] +pub use crate::partition::write_partitioned_dataset; +pub use crate::path_utils::*; +pub use crate::shared::{SerReader, SerWriter}; +pub use crate::utils::*; diff --git a/crates/polars-io/src/shared.rs b/crates/polars-io/src/shared.rs new file mode 100644 index 000000000000..d744aef57b50 --- /dev/null +++ b/crates/polars-io/src/shared.rs @@ -0,0 +1,151 @@ +use std::io::{Read, Write}; +use std::sync::Arc; + +use arrow::array::new_empty_array; +use arrow::record_batch::RecordBatch; +use polars_core::prelude::*; + +use crate::cloud::CloudOptions; +use crate::options::RowIndex; +#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming",))] +use crate::predicates::PhysicalIoExpr; + +pub trait SerReader +where + R: Read, +{ + /// Create a new instance of the [`SerReader`] + fn new(reader: R) -> Self; + + /// Make sure that all columns are contiguous in memory by + /// aggregating the chunks into a single array. + #[must_use] + fn set_rechunk(self, _rechunk: bool) -> Self + where + Self: Sized, + { + self + } + + /// Take the SerReader and return a parsed DataFrame. + fn finish(self) -> PolarsResult; +} + +pub trait SerWriter +where + W: Write, +{ + fn new(writer: W) -> Self + where + Self: Sized; + fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()>; +} + +pub trait WriteDataFrameToFile { + fn write_df_to_file( + &self, + df: &mut DataFrame, + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult<()>; +} + +pub trait ArrowReader { + fn next_record_batch(&mut self) -> PolarsResult>; +} + +#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming",))] +pub(crate) fn finish_reader( + mut reader: R, + rechunk: bool, + n_rows: Option, + predicate: Option>, + arrow_schema: &ArrowSchema, + row_index: Option, +) -> PolarsResult { + use polars_core::utils::accumulate_dataframes_vertical_unchecked; + + let mut num_rows = 0; + let mut parsed_dfs = Vec::with_capacity(1024); + + while let Some(batch) = reader.next_record_batch()? { + let current_num_rows = num_rows as IdxSize; + num_rows += batch.len(); + let mut df = DataFrame::from(batch); + + if let Some(rc) = &row_index { + unsafe { df.with_row_index_mut(rc.name.clone(), Some(current_num_rows + rc.offset)) }; + } + + if let Some(predicate) = &predicate { + let s = predicate.evaluate_io(&df)?; + let mask = s.bool().expect("filter predicates was not of type boolean"); + df = df.filter(mask)?; + } + + if let Some(n) = n_rows { + if num_rows >= n { + let len = n - parsed_dfs + .iter() + .map(|df: &DataFrame| df.height()) + .sum::(); + if polars_core::config::verbose() { + eprintln!( + "sliced off {} rows of the 'DataFrame'. These lines were read because they were in a single chunk.", + df.height().saturating_sub(n) + ) + } + parsed_dfs.push(df.slice(0, len)); + break; + } + } + parsed_dfs.push(df); + } + + let mut df = { + if parsed_dfs.is_empty() { + // Create an empty dataframe with the correct data types + let empty_cols = arrow_schema + .iter_values() + .map(|fld| { + Series::try_from((fld.name.clone(), new_empty_array(fld.dtype.clone()))) + .map(Column::from) + }) + .collect::>()?; + DataFrame::new(empty_cols)? + } else { + // If there are any rows, accumulate them into a df + accumulate_dataframes_vertical_unchecked(parsed_dfs) + } + }; + + if rechunk { + df.as_single_chunk_par(); + } + Ok(df) +} + +pub fn schema_to_arrow_checked( + schema: &Schema, + compat_level: CompatLevel, + _file_name: &str, +) -> PolarsResult { + schema + .iter_fields() + .map(|field| { + #[cfg(feature = "object")] + { + polars_ensure!( + !matches!(field.dtype(), DataType::Object(_)), + ComputeError: "cannot write 'Object' datatype to {}", + _file_name + ); + } + + let field = field + .dtype() + .to_arrow_field(field.name().clone(), compat_level); + Ok((field.name.clone(), field)) + }) + .collect::>() +} diff --git a/crates/polars-io/src/utils/byte_source.rs b/crates/polars-io/src/utils/byte_source.rs new file mode 100644 index 000000000000..0bb5b7ca4b67 --- /dev/null +++ b/crates/polars-io/src/utils/byte_source.rs @@ -0,0 +1,193 @@ +use std::ops::Range; +use std::sync::Arc; + +use polars_core::prelude::PlHashMap; +use polars_error::PolarsResult; +use polars_utils::_limit_path_len_io_err; +use polars_utils::mmap::MemSlice; + +use crate::cloud::{ + CloudLocation, CloudOptions, ObjectStorePath, PolarsObjectStore, build_object_store, + object_path_from_str, +}; + +#[allow(async_fn_in_trait)] +pub trait ByteSource: Send + Sync { + async fn get_size(&self) -> PolarsResult; + /// # Panics + /// Panics if `range` is not in bounds. + async fn get_range(&self, range: Range) -> PolarsResult; + /// Note: This will mutably sort ranges for coalescing. + async fn get_ranges( + &self, + ranges: &mut [Range], + ) -> PolarsResult>; +} + +/// Byte source backed by a `MemSlice`, which can potentially be memory-mapped. +pub struct MemSliceByteSource(pub MemSlice); + +impl MemSliceByteSource { + async fn try_new_mmap_from_path( + path: &str, + _cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let file = Arc::new( + tokio::fs::File::open(path) + .await + .map_err(|err| _limit_path_len_io_err(path.as_ref(), err))? + .into_std() + .await, + ); + + Ok(Self(MemSlice::from_file(file.as_ref())?)) + } +} + +impl ByteSource for MemSliceByteSource { + async fn get_size(&self) -> PolarsResult { + Ok(self.0.as_ref().len()) + } + + async fn get_range(&self, range: Range) -> PolarsResult { + let out = self.0.slice(range); + Ok(out) + } + + async fn get_ranges( + &self, + ranges: &mut [Range], + ) -> PolarsResult> { + Ok(ranges + .iter() + .map(|x| (x.start, self.0.slice(x.clone()))) + .collect()) + } +} + +pub struct ObjectStoreByteSource { + store: PolarsObjectStore, + path: ObjectStorePath, +} + +impl ObjectStoreByteSource { + async fn try_new_from_path( + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let (CloudLocation { prefix, .. }, store) = + build_object_store(path, cloud_options, false).await?; + let path = object_path_from_str(&prefix)?; + + Ok(Self { store, path }) + } +} + +impl ByteSource for ObjectStoreByteSource { + async fn get_size(&self) -> PolarsResult { + Ok(self.store.head(&self.path).await?.size) + } + + async fn get_range(&self, range: Range) -> PolarsResult { + let bytes = self.store.get_range(&self.path, range).await?; + let mem_slice = MemSlice::from_bytes(bytes); + + Ok(mem_slice) + } + + async fn get_ranges( + &self, + ranges: &mut [Range], + ) -> PolarsResult> { + self.store.get_ranges_sort(&self.path, ranges).await + } +} + +/// Dynamic dispatch to async functions. +pub enum DynByteSource { + MemSlice(MemSliceByteSource), + Cloud(ObjectStoreByteSource), +} + +impl DynByteSource { + pub fn variant_name(&self) -> &str { + match self { + Self::MemSlice(_) => "MemSlice", + Self::Cloud(_) => "Cloud", + } + } +} + +impl Default for DynByteSource { + fn default() -> Self { + Self::MemSlice(MemSliceByteSource(MemSlice::default())) + } +} + +impl ByteSource for DynByteSource { + async fn get_size(&self) -> PolarsResult { + match self { + Self::MemSlice(v) => v.get_size().await, + Self::Cloud(v) => v.get_size().await, + } + } + + async fn get_range(&self, range: Range) -> PolarsResult { + match self { + Self::MemSlice(v) => v.get_range(range).await, + Self::Cloud(v) => v.get_range(range).await, + } + } + + async fn get_ranges( + &self, + ranges: &mut [Range], + ) -> PolarsResult> { + match self { + Self::MemSlice(v) => v.get_ranges(ranges).await, + Self::Cloud(v) => v.get_ranges(ranges).await, + } + } +} + +impl From for DynByteSource { + fn from(value: MemSliceByteSource) -> Self { + Self::MemSlice(value) + } +} + +impl From for DynByteSource { + fn from(value: ObjectStoreByteSource) -> Self { + Self::Cloud(value) + } +} + +impl From for DynByteSource { + fn from(value: MemSlice) -> Self { + Self::MemSlice(MemSliceByteSource(value)) + } +} + +#[derive(Clone, Debug)] +pub enum DynByteSourceBuilder { + Mmap, + /// Supports both cloud and local files. + ObjectStore, +} + +impl DynByteSourceBuilder { + pub async fn try_build_from_path( + &self, + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + Ok(match self { + Self::Mmap => MemSliceByteSource::try_new_mmap_from_path(path, cloud_options) + .await? + .into(), + Self::ObjectStore => ObjectStoreByteSource::try_new_from_path(path, cloud_options) + .await? + .into(), + }) + } +} diff --git a/crates/polars-io/src/utils/compression.rs b/crates/polars-io/src/utils/compression.rs new file mode 100644 index 000000000000..1875c8e75e7c --- /dev/null +++ b/crates/polars-io/src/utils/compression.rs @@ -0,0 +1,61 @@ +use std::io::Read; + +use polars_core::prelude::*; +use polars_error::{feature_gated, to_compute_err}; + +/// Represents the compression algorithms that we have decoders for +pub enum SupportedCompression { + GZIP, + ZLIB, + ZSTD, +} + +impl SupportedCompression { + /// If the given byte slice starts with the "magic" bytes for a supported compression family, return + /// that family, for unsupported/uncompressed slices, return None + pub fn check(bytes: &[u8]) -> Option { + if bytes.len() < 4 { + // not enough bytes to perform prefix checks + return None; + } + match bytes[..4] { + [31, 139, _, _] => Some(Self::GZIP), + [0x78, 0x01, _, _] | // ZLIB0 + [0x78, 0x9C, _, _] | // ZLIB1 + [0x78, 0xDA, _, _] // ZLIB2 + => Some(Self::ZLIB), + [0x28, 0xB5, 0x2F, 0xFD] => Some(Self::ZSTD), + _ => None, + } + } +} + +/// Decompress `bytes` if compression is detected, otherwise simply return it. +/// An `out` vec must be given for ownership of the decompressed data. +pub fn maybe_decompress_bytes<'a>(bytes: &'a [u8], out: &'a mut Vec) -> PolarsResult<&'a [u8]> { + assert!(out.is_empty()); + + if let Some(algo) = SupportedCompression::check(bytes) { + feature_gated!("decompress", { + match algo { + SupportedCompression::GZIP => { + flate2::read::MultiGzDecoder::new(bytes) + .read_to_end(out) + .map_err(to_compute_err)?; + }, + SupportedCompression::ZLIB => { + flate2::read::ZlibDecoder::new(bytes) + .read_to_end(out) + .map_err(to_compute_err)?; + }, + SupportedCompression::ZSTD => { + zstd::Decoder::with_buffer(bytes)?.read_to_end(out)?; + }, + } + + Ok(out) + }) + } else { + Ok(bytes) + } +} diff --git a/crates/polars-io/src/utils/file.rs b/crates/polars-io/src/utils/file.rs new file mode 100644 index 000000000000..6678ce456085 --- /dev/null +++ b/crates/polars-io/src/utils/file.rs @@ -0,0 +1,323 @@ +use std::io; +use std::ops::{Deref, DerefMut}; +use std::path::Path; + +#[cfg(feature = "cloud")] +pub use async_writeable::AsyncWriteable; +use polars_core::config; +use polars_error::{PolarsError, PolarsResult, feature_gated}; +use polars_utils::create_file; +use polars_utils::file::{ClosableFile, WriteClose}; +use polars_utils::mmap::ensure_not_mapped; + +use super::sync_on_close::SyncOnCloseType; +use crate::cloud::CloudOptions; +use crate::{is_cloud_url, resolve_homedir}; + +pub trait DynWriteable: io::Write + Send { + // Needed because trait upcasting is only stable in 1.86. + fn as_dyn_write(&self) -> &(dyn io::Write + Send + 'static); + fn as_mut_dyn_write(&mut self) -> &mut (dyn io::Write + Send + 'static); + + fn close(self: Box) -> io::Result<()>; + fn sync_on_close(&mut self, sync_on_close: SyncOnCloseType) -> io::Result<()>; +} + +impl DynWriteable for ClosableFile { + fn as_dyn_write(&self) -> &(dyn io::Write + Send + 'static) { + self as _ + } + fn as_mut_dyn_write(&mut self) -> &mut (dyn io::Write + Send + 'static) { + self as _ + } + fn close(self: Box) -> io::Result<()> { + ClosableFile::close(*self) + } + fn sync_on_close(&mut self, sync_on_close: SyncOnCloseType) -> io::Result<()> { + super::sync_on_close::sync_on_close(sync_on_close, self.as_mut()) + } +} + +/// Holds a non-async writeable file, abstracted over local files or cloud files. +/// +/// This implements `DerefMut` to a trait object implementing [`std::io::Write`]. +/// +/// Also see: `Writeable::try_into_async_writeable` and `AsyncWriteable`. +#[allow(clippy::large_enum_variant)] // It will be boxed +pub enum Writeable { + /// An abstract implementation for writable. + /// + /// This is used to implement writing to in-memory and arbitrary file descriptors. + Dyn(Box), + Local(std::fs::File), + #[cfg(feature = "cloud")] + Cloud(crate::cloud::BlockingCloudWriter), +} + +impl Writeable { + pub fn try_new( + path: &str, + #[cfg_attr(not(feature = "cloud"), allow(unused))] cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let is_cloud = is_cloud_url(path); + let verbose = config::verbose(); + + if is_cloud { + feature_gated!("cloud", { + use crate::cloud::BlockingCloudWriter; + + if verbose { + eprintln!("Writeable: try_new: cloud: {}", path) + } + + if path.starts_with("file://") { + create_file(Path::new(&path[const { "file://".len() }..]))?; + } + + let writer = crate::pl_async::get_runtime() + .block_in_place_on(BlockingCloudWriter::new(path, cloud_options))?; + Ok(Self::Cloud(writer)) + }) + } else if config::force_async() { + feature_gated!("cloud", { + use crate::cloud::BlockingCloudWriter; + + let path = resolve_homedir(&path); + + if verbose { + eprintln!( + "Writeable: try_new: forced async: {}", + path.to_str().unwrap() + ) + } + + create_file(&path)?; + let path = std::fs::canonicalize(&path)?; + + ensure_not_mapped(&path.metadata()?)?; + + let path = format!( + "file://{}", + if cfg!(target_family = "windows") { + path.to_str().unwrap().strip_prefix(r#"\\?\"#).unwrap() + } else { + path.to_str().unwrap() + } + ); + + if verbose { + eprintln!("Writeable: try_new: forced async converted path: {}", path) + } + + let writer = crate::pl_async::get_runtime() + .block_in_place_on(BlockingCloudWriter::new(&path, cloud_options))?; + Ok(Self::Cloud(writer)) + }) + } else { + let path = resolve_homedir(&path); + create_file(&path)?; + + // Note: `canonicalize` does not work on some systems. + + if verbose { + eprintln!( + "Writeable: try_new: local: {} (canonicalize: {:?})", + path.to_str().unwrap(), + std::fs::canonicalize(&path) + ) + } + + Ok(Self::Local(polars_utils::open_file_write(&path)?)) + } + } + + /// This returns `Result<>` - if a write was performed before calling this, + /// `CloudWriter` can be in an Err(_) state. + #[cfg(feature = "cloud")] + pub fn try_into_async_writeable(self) -> PolarsResult { + use self::async_writeable::AsyncDynWriteable; + + match self { + Self::Dyn(v) => Ok(AsyncWriteable::Dyn(AsyncDynWriteable(v))), + Self::Local(v) => Ok(AsyncWriteable::Local(tokio::fs::File::from_std(v))), + // Moves the `BufWriter` out of the `BlockingCloudWriter` wrapper, as + // `BlockingCloudWriter` has a `Drop` impl that we don't want. + Self::Cloud(v) => v + .try_into_inner() + .map(AsyncWriteable::Cloud) + .map_err(PolarsError::from), + } + } + + pub fn sync_on_close(&mut self, sync_on_close: SyncOnCloseType) -> std::io::Result<()> { + match self { + Writeable::Dyn(d) => d.sync_on_close(sync_on_close), + Writeable::Local(file) => { + crate::utils::sync_on_close::sync_on_close(sync_on_close, file) + }, + #[cfg(feature = "cloud")] + Writeable::Cloud(_) => Ok(()), + } + } + + pub fn close(self) -> std::io::Result<()> { + match self { + Self::Dyn(v) => v.close(), + Self::Local(v) => ClosableFile::from(v).close(), + #[cfg(feature = "cloud")] + Self::Cloud(mut v) => v.close(), + } + } +} + +impl Deref for Writeable { + type Target = dyn io::Write + Send; + + fn deref(&self) -> &Self::Target { + match self { + Self::Dyn(v) => v.as_dyn_write(), + Self::Local(v) => v, + #[cfg(feature = "cloud")] + Self::Cloud(v) => v, + } + } +} + +impl DerefMut for Writeable { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Dyn(v) => v.as_mut_dyn_write(), + Self::Local(v) => v, + #[cfg(feature = "cloud")] + Self::Cloud(v) => v, + } + } +} + +/// Note: Prefer using [`Writeable`] / [`Writeable::try_new`] where possible. +/// +/// Open a path for writing. Supports cloud paths. +pub fn try_get_writeable( + path: &str, + cloud_options: Option<&CloudOptions>, +) -> PolarsResult> { + Writeable::try_new(path, cloud_options).map(|x| match x { + Writeable::Dyn(_) => unreachable!(), + Writeable::Local(v) => Box::new(ClosableFile::from(v)) as Box, + #[cfg(feature = "cloud")] + Writeable::Cloud(v) => Box::new(v) as Box, + }) +} + +#[cfg(feature = "cloud")] +mod async_writeable { + use std::io; + use std::ops::{Deref, DerefMut}; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use polars_error::{PolarsError, PolarsResult}; + use polars_utils::file::ClosableFile; + use tokio::io::AsyncWriteExt; + use tokio::task; + + use super::{DynWriteable, Writeable}; + use crate::cloud::CloudOptions; + use crate::utils::sync_on_close::SyncOnCloseType; + + /// Turn an abstract io::Write into an abstract tokio::io::AsyncWrite. + pub struct AsyncDynWriteable(pub Box); + + impl tokio::io::AsyncWrite for AsyncDynWriteable { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let result = task::block_in_place(|| self.get_mut().0.write(buf)); + Poll::Ready(result) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let result = task::block_in_place(|| self.get_mut().0.flush()); + Poll::Ready(result) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + } + + /// Holds an async writeable file, abstracted over local files or cloud files. + /// + /// This implements `DerefMut` to a trait object implementing [`tokio::io::AsyncWrite`]. + /// + /// Note: It is important that you do not call `shutdown()` on the deref'ed `AsyncWrite` object. + /// You should instead call the [`AsyncWriteable::close`] at the end. + pub enum AsyncWriteable { + Dyn(AsyncDynWriteable), + Local(tokio::fs::File), + Cloud(object_store::buffered::BufWriter), + } + + impl AsyncWriteable { + pub async fn try_new( + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + // TODO: Native async impl + Writeable::try_new(path, cloud_options).and_then(|x| x.try_into_async_writeable()) + } + + pub async fn sync_on_close( + &mut self, + sync_on_close: SyncOnCloseType, + ) -> std::io::Result<()> { + match self { + Self::Dyn(d) => task::block_in_place(|| d.0.sync_on_close(sync_on_close)), + Self::Local(file) => { + crate::utils::sync_on_close::tokio_sync_on_close(sync_on_close, file).await + }, + Self::Cloud(_) => Ok(()), + } + } + + pub async fn close(self) -> PolarsResult<()> { + match self { + Self::Dyn(mut v) => { + v.shutdown().await.map_err(PolarsError::from)?; + Ok(task::block_in_place(|| v.0.close())?) + }, + Self::Local(v) => async { + let f = v.into_std().await; + ClosableFile::from(f).close() + } + .await + .map_err(PolarsError::from), + Self::Cloud(mut v) => v.shutdown().await.map_err(PolarsError::from), + } + } + } + + impl Deref for AsyncWriteable { + type Target = dyn tokio::io::AsyncWrite + Send + Unpin; + + fn deref(&self) -> &Self::Target { + match self { + Self::Dyn(v) => v, + Self::Local(v) => v, + Self::Cloud(v) => v, + } + } + } + + impl DerefMut for AsyncWriteable { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Dyn(v) => v, + Self::Local(v) => v, + Self::Cloud(v) => v, + } + } + } +} diff --git a/crates/polars-io/src/utils/mkdir.rs b/crates/polars-io/src/utils/mkdir.rs new file mode 100644 index 000000000000..7947c40713b2 --- /dev/null +++ b/crates/polars-io/src/utils/mkdir.rs @@ -0,0 +1,20 @@ +use std::io; +use std::path::Path; + +pub fn mkdir_recursive(path: &Path) -> io::Result<()> { + std::fs::DirBuilder::new().recursive(true).create( + path.parent() + .ok_or(io::Error::other("path is not a file"))?, + ) +} + +#[cfg(feature = "tokio")] +pub async fn tokio_mkdir_recursive(path: &Path) -> io::Result<()> { + tokio::fs::DirBuilder::new() + .recursive(true) + .create( + path.parent() + .ok_or(io::Error::other("path is not a file"))?, + ) + .await +} diff --git a/crates/polars-io/src/utils/mod.rs b/crates/polars-io/src/utils/mod.rs new file mode 100644 index 000000000000..3b211ef460cf --- /dev/null +++ b/crates/polars-io/src/utils/mod.rs @@ -0,0 +1,17 @@ +pub mod compression; +mod other; + +pub use other::*; +#[cfg(feature = "cloud")] +pub mod byte_source; +pub mod file; +pub mod mkdir; +pub mod slice; +pub mod sync_on_close; + +pub const URL_ENCODE_CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS + .add(b'/') + .add(b'=') + .add(b':') + .add(b' ') + .add(b'%'); diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs new file mode 100644 index 000000000000..956d1b256108 --- /dev/null +++ b/crates/polars-io/src/utils/other.rs @@ -0,0 +1,240 @@ +use std::io::Read; +#[cfg(target_os = "emscripten")] +use std::io::{Seek, SeekFrom}; + +use polars_core::prelude::*; +use polars_utils::mmap::{MMapSemaphore, MemSlice}; + +use crate::mmap::{MmapBytesReader, ReaderBytes}; + +pub fn get_reader_bytes( + reader: &mut R, +) -> PolarsResult> { + // we have a file so we can mmap + // only seekable files are mmap-able + if let Some((file, offset)) = reader + .stream_position() + .ok() + .and_then(|offset| Some((reader.to_file()?, offset))) + { + let mut options = memmap::MmapOptions::new(); + options.offset(offset); + + // Set mmap size based on seek to end when running under Emscripten + #[cfg(target_os = "emscripten")] + { + let mut file = file; + let size = file.seek(SeekFrom::End(0)).unwrap(); + options.len((size - offset) as usize); + } + + let mmap = MMapSemaphore::new_from_file_with_options(file, options)?; + Ok(ReaderBytes::Owned(MemSlice::from_mmap(Arc::new(mmap)))) + } else { + // we can get the bytes for free + if reader.to_bytes().is_some() { + // duplicate .to_bytes() is necessary to satisfy the borrow checker + Ok(ReaderBytes::Borrowed((*reader).to_bytes().unwrap())) + } else { + // we have to read to an owned buffer to get the bytes. + let mut bytes = Vec::with_capacity(1024 * 128); + reader.read_to_end(&mut bytes)?; + Ok(ReaderBytes::Owned(bytes.into())) + } + } +} + +#[cfg(any( + feature = "ipc", + feature = "ipc_streaming", + feature = "parquet", + feature = "avro" +))] +pub fn apply_projection(schema: &ArrowSchema, projection: &[usize]) -> ArrowSchema { + projection + .iter() + .map(|idx| schema.get_at_index(*idx).unwrap()) + .map(|(k, v)| (k.clone(), v.clone())) + .collect() +} + +#[cfg(any( + feature = "ipc", + feature = "ipc_streaming", + feature = "avro", + feature = "parquet" +))] +pub fn columns_to_projection>( + columns: &[T], + schema: &ArrowSchema, +) -> PolarsResult> { + let mut prj = Vec::with_capacity(columns.len()); + + for column in columns { + let i = schema.try_index_of(column.as_ref())?; + prj.push(i); + } + + Ok(prj) +} + +#[cfg(debug_assertions)] +fn check_offsets(dfs: &[DataFrame]) { + dfs.windows(2).for_each(|s| { + let a = &s[0].get_columns()[0]; + let b = &s[1].get_columns()[0]; + + let prev = a.get(a.len() - 1).unwrap().extract::().unwrap(); + let next = b.get(0).unwrap().extract::().unwrap(); + assert_eq!(prev + 1, next); + }) +} + +/// Because of threading every row starts from `0` or from `offset`. +/// We must correct that so that they are monotonically increasing. +#[cfg(any(feature = "csv", feature = "json"))] +pub(crate) fn update_row_counts2(dfs: &mut [DataFrame], offset: IdxSize) { + if !dfs.is_empty() { + let mut previous = offset; + for df in &mut *dfs { + if df.is_empty() { + continue; + } + let n_read = df.height() as IdxSize; + if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { + if let Ok(v) = s.get(0) { + if v.extract::().unwrap() != previous as usize { + *s = &*s + previous; + } + } + } + previous += n_read; + } + } + #[cfg(debug_assertions)] + { + check_offsets(dfs) + } +} + +/// Because of threading every row starts from `0` or from `offset`. +/// We must correct that so that they are monotonically increasing. +#[cfg(feature = "json")] +pub(crate) fn update_row_counts3(dfs: &mut [DataFrame], heights: &[IdxSize], offset: IdxSize) { + assert_eq!(dfs.len(), heights.len()); + if !dfs.is_empty() { + let mut previous = offset; + for i in 0..dfs.len() { + let df = &mut dfs[i]; + if df.is_empty() { + continue; + } + + if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { + if let Ok(v) = s.get(0) { + if v.extract::().unwrap() != previous as usize { + *s = &*s + previous; + } + } + } + let n_read = heights[i]; + previous += n_read; + } + } +} + +#[cfg(feature = "json")] +pub fn overwrite_schema(schema: &mut Schema, overwriting_schema: &Schema) -> PolarsResult<()> { + for (k, value) in overwriting_schema.iter() { + *schema.try_get_mut(k)? = value.clone(); + } + Ok(()) +} + +polars_utils::regex_cache::cached_regex! { + pub static FLOAT_RE = r"^[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$"; + pub static FLOAT_RE_DECIMAL = r"^[-+]?((\d*,\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+,)$"; + pub static INTEGER_RE = r"^-?(\d+)$"; + pub static BOOLEAN_RE = r"^(?i:true|false)$"; +} + +pub fn materialize_projection( + with_columns: Option<&[PlSmallStr]>, + schema: &Schema, + hive_partitions: Option<&[Series]>, + has_row_index: bool, +) -> Option> { + match hive_partitions { + None => with_columns.map(|with_columns| { + with_columns + .iter() + .map(|name| schema.index_of(name).unwrap() - has_row_index as usize) + .collect() + }), + Some(part_cols) => { + with_columns.map(|with_columns| { + with_columns + .iter() + .flat_map(|name| { + // the hive partitions are added at the end of the schema, but we don't want to project + // them from the file + if part_cols.iter().any(|s| s.name() == name.as_str()) { + None + } else { + Some(schema.index_of(name).unwrap() - has_row_index as usize) + } + }) + .collect() + }) + }, + } +} + +/// Utility for decoding JSON that adds the response value to the error message if decoding fails. +/// This makes it much easier to debug errors from parsing network responses. +#[cfg(feature = "cloud")] +pub fn decode_json_response(bytes: &[u8]) -> PolarsResult +where + T: for<'de> serde::de::Deserialize<'de>, +{ + use polars_error::to_compute_err; + use polars_utils::error::TruncateErrorDetail; + + serde_json::from_slice(bytes) + .map_err(to_compute_err) + .map_err(|e| { + e.wrap_msg(|e| { + format!( + "error decoding response: {}, response value: {}", + e, + TruncateErrorDetail(&String::from_utf8_lossy(bytes)) + ) + }) + }) +} + +#[cfg(test)] +mod tests { + use super::FLOAT_RE; + + #[test] + fn test_float_parse() { + assert!(FLOAT_RE.is_match("0.1")); + assert!(FLOAT_RE.is_match("3.0")); + assert!(FLOAT_RE.is_match("3.00001")); + assert!(FLOAT_RE.is_match("-9.9990e-003")); + assert!(FLOAT_RE.is_match("9.9990e+003")); + assert!(FLOAT_RE.is_match("9.9990E+003")); + assert!(FLOAT_RE.is_match("9.9990E+003")); + assert!(FLOAT_RE.is_match(".5")); + assert!(FLOAT_RE.is_match("2.5E-10")); + assert!(FLOAT_RE.is_match("2.5e10")); + assert!(FLOAT_RE.is_match("NaN")); + assert!(FLOAT_RE.is_match("-NaN")); + assert!(FLOAT_RE.is_match("-inf")); + assert!(FLOAT_RE.is_match("inf")); + assert!(FLOAT_RE.is_match("-7e-05")); + assert!(FLOAT_RE.is_match("7e-05")); + assert!(FLOAT_RE.is_match("+7e+05")); + } +} diff --git a/crates/polars-io/src/utils/slice.rs b/crates/polars-io/src/utils/slice.rs new file mode 100644 index 000000000000..24a3b7dc1ab8 --- /dev/null +++ b/crates/polars-io/src/utils/slice.rs @@ -0,0 +1,58 @@ +/// Given a `slice` that is relative to the start of a list of files, calculate the slice to apply +/// at a file with a row offset of `current_row_offset`. +pub fn split_slice_at_file( + current_row_offset_ref: &mut usize, + n_rows_this_file: usize, + global_slice_start: usize, + global_slice_end: usize, +) -> (usize, usize) { + let current_row_offset = *current_row_offset_ref; + *current_row_offset_ref += n_rows_this_file; + match SplitSlicePosition::split_slice_at_file( + current_row_offset, + n_rows_this_file, + global_slice_start..global_slice_end, + ) { + SplitSlicePosition::Overlapping(offset, len) => (offset, len), + SplitSlicePosition::Before | SplitSlicePosition::After => (0, 0), + } +} + +#[derive(Debug)] +pub enum SplitSlicePosition { + Before, + Overlapping(usize, usize), + After, +} + +impl SplitSlicePosition { + pub fn split_slice_at_file( + current_row_offset: usize, + n_rows_this_file: usize, + global_slice: std::ops::Range, + ) -> Self { + // e.g. + // slice: (start: 1, end: 2) + // files: + // 0: (1 row): current_offset: 0, next_file_offset: 1 + // 1: (1 row): current_offset: 1, next_file_offset: 2 + // 2: (1 row): current_offset: 2, next_file_offset: 3 + // in this example we want to include only file 1. + + let next_row_offset = current_row_offset + n_rows_this_file; + + if next_row_offset <= global_slice.start { + Self::Before + } else if current_row_offset >= global_slice.end { + Self::After + } else { + let n_rows_to_skip = global_slice.start.saturating_sub(current_row_offset); + let n_excess_rows = next_row_offset.saturating_sub(global_slice.end); + + Self::Overlapping( + n_rows_to_skip, + n_rows_this_file - n_rows_to_skip - n_excess_rows, + ) + } + } +} diff --git a/crates/polars-io/src/utils/sync_on_close.rs b/crates/polars-io/src/utils/sync_on_close.rs new file mode 100644 index 000000000000..d18ca4d24b2a --- /dev/null +++ b/crates/polars-io/src/utils/sync_on_close.rs @@ -0,0 +1,34 @@ +use std::{fs, io}; + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum SyncOnCloseType { + /// Don't call sync on close. + #[default] + None, + + /// Sync only the file contents. + Data, + /// Synce the file contents and the metadata. + All, +} + +pub fn sync_on_close(sync_on_close: SyncOnCloseType, file: &mut fs::File) -> io::Result<()> { + match sync_on_close { + SyncOnCloseType::None => Ok(()), + SyncOnCloseType::Data => file.sync_data(), + SyncOnCloseType::All => file.sync_all(), + } +} + +#[cfg(feature = "tokio")] +pub async fn tokio_sync_on_close( + sync_on_close: SyncOnCloseType, + file: &mut tokio::fs::File, +) -> io::Result<()> { + match sync_on_close { + SyncOnCloseType::None => Ok(()), + SyncOnCloseType::Data => file.sync_data().await, + SyncOnCloseType::All => file.sync_all().await, + } +} diff --git a/crates/polars-json/Cargo.toml b/crates/polars-json/Cargo.toml new file mode 100644 index 000000000000..de5414ab3eb7 --- /dev/null +++ b/crates/polars-json/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "polars-json" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "JSON related logic for the Polars DataFrame library" + +[dependencies] +polars-compute = { workspace = true, features = ["cast"] } +polars-error = { workspace = true } +polars-utils = { workspace = true } + +arrow = { workspace = true } +chrono = { workspace = true } +chrono-tz = { workspace = true, optional = true } +fallible-streaming-iterator = { version = "0.1" } +hashbrown = { workspace = true } +indexmap = { workspace = true } +itoa = { workspace = true } +num-traits = { workspace = true } +ryu = { workspace = true } +simd-json = { workspace = true } +streaming-iterator = { workspace = true } + +[features] +chrono-tz = ["dep:chrono-tz", "arrow/chrono-tz"] +dtype-decimal = ["arrow/dtype-decimal"] +timezones = ["arrow/chrono-tz"] diff --git a/crates/polars-json/LICENSE b/crates/polars-json/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-json/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-json/README.md b/crates/polars-json/README.md new file mode 100644 index 000000000000..f056af4c72b8 --- /dev/null +++ b/crates/polars-json/README.md @@ -0,0 +1,7 @@ +# polars-json + +`polars-json` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, +provides functionalities to handle JSON objects. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-json/src/json/deserialize.rs b/crates/polars-json/src/json/deserialize.rs new file mode 100644 index 000000000000..9630754a85aa --- /dev/null +++ b/crates/polars-json/src/json/deserialize.rs @@ -0,0 +1,467 @@ +use std::borrow::Borrow; +use std::fmt::Write; + +use arrow::array::*; +use arrow::bitmap::BitmapBuilder; +use arrow::datatypes::{ArrowDataType, IntervalUnit}; +use arrow::offset::{Offset, Offsets}; +use arrow::temporal_conversions; +use arrow::types::NativeType; +use num_traits::NumCast; +use simd_json::{BorrowedValue, StaticNode}; + +use super::*; + +const JSON_NULL_VALUE: BorrowedValue = BorrowedValue::Static(StaticNode::Null); + +fn deserialize_boolean_into<'a, A: Borrow>>( + target: &mut MutableBooleanArray, + rows: &[A], +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { + BorrowedValue::Static(StaticNode::Bool(v)) => Some(v), + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, + }); + target.extend_trusted_len(iter); + check_err_idx(rows, err_idx, "boolean") +} + +fn deserialize_primitive_into<'a, T: NativeType + NumCast, A: Borrow>>( + target: &mut MutablePrimitiveArray, + rows: &[A], +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { + BorrowedValue::Static(StaticNode::I64(v)) => T::from(*v), + BorrowedValue::Static(StaticNode::U64(v)) => T::from(*v), + BorrowedValue::Static(StaticNode::F64(v)) => T::from(*v), + BorrowedValue::Static(StaticNode::Bool(v)) => T::from(*v as u8), + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, + }); + target.extend_trusted_len(iter); + check_err_idx(rows, err_idx, "numeric") +} + +fn deserialize_binary<'a, A: Borrow>>( + rows: &[A], +) -> PolarsResult> { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { + BorrowedValue::String(v) => Some(v.as_bytes()), + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, + }); + let out = BinaryArray::from_trusted_len_iter(iter); + check_err_idx(rows, err_idx, "binary")?; + Ok(out) +} + +fn deserialize_utf8_into<'a, O: Offset, A: Borrow>>( + target: &mut MutableUtf8Array, + rows: &[A], +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let mut scratch = String::new(); + for (i, row) in rows.iter().enumerate() { + match row.borrow() { + BorrowedValue::String(v) => target.push(Some(v.as_ref())), + BorrowedValue::Static(StaticNode::Bool(v)) => { + target.push(Some(if *v { "true" } else { "false" })) + }, + BorrowedValue::Static(StaticNode::Null) => target.push_null(), + BorrowedValue::Static(node) => { + write!(scratch, "{node}").unwrap(); + target.push(Some(scratch.as_str())); + scratch.clear(); + }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, + } + } + check_err_idx(rows, err_idx, "string") +} + +fn deserialize_utf8view_into<'a, A: Borrow>>( + target: &mut MutableBinaryViewArray, + rows: &[A], +) -> PolarsResult<()> { + let mut err_idx = rows.len(); + let mut scratch = String::new(); + for (i, row) in rows.iter().enumerate() { + match row.borrow() { + BorrowedValue::String(v) => target.push_value(v.as_ref()), + BorrowedValue::Static(StaticNode::Bool(v)) => { + target.push_value(if *v { "true" } else { "false" }) + }, + BorrowedValue::Static(StaticNode::Null) => target.push_null(), + BorrowedValue::Static(node) => { + write!(scratch, "{node}").unwrap(); + target.push_value(scratch.as_str()); + scratch.clear(); + }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, + } + } + check_err_idx(rows, err_idx, "string") +} + +fn deserialize_list<'a, A: Borrow>>( + rows: &[A], + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { + let mut err_idx = rows.len(); + let child = ListArray::::get_child_type(&dtype); + + let mut validity = BitmapBuilder::with_capacity(rows.len()); + let mut offsets = Offsets::::with_capacity(rows.len()); + let mut inner = vec![]; + rows.iter() + .enumerate() + .for_each(|(i, row)| match row.borrow() { + BorrowedValue::Array(value) => { + inner.extend(value.iter()); + validity.push(true); + offsets + .try_push(value.len()) + .expect("List offset is too large :/"); + }, + BorrowedValue::Static(StaticNode::Null) => { + validity.push(false); + offsets.extend_constant(1) + }, + value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { + inner.push(value); + validity.push(true); + offsets.try_push(1).expect("List offset is too large :/"); + }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, + }); + + check_err_idx(rows, err_idx, "list")?; + + let values = _deserialize(&inner, child.clone(), allow_extra_fields_in_struct)?; + + Ok(ListArray::::new( + dtype, + offsets.into(), + values, + validity.into_opt_validity(), + )) +} + +fn deserialize_struct<'a, A: Borrow>>( + rows: &[A], + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, +) -> PolarsResult { + let mut err_idx = rows.len(); + let fields = StructArray::get_fields(&dtype); + + let mut out_values = fields + .iter() + .map(|f| (f.name.as_str(), (f.dtype(), vec![]))) + .collect::>(); + + let mut validity = BitmapBuilder::with_capacity(rows.len()); + // Custom error tracker + let mut extra_field = None; + + rows.iter().enumerate().for_each(|(i, row)| { + match row.borrow() { + BorrowedValue::Object(values) => { + let mut n_matched = 0usize; + for (&key, &mut (_, ref mut inner)) in out_values.iter_mut() { + if let Some(v) = values.get(key) { + n_matched += 1; + inner.push(v) + } else { + inner.push(&JSON_NULL_VALUE) + } + } + + validity.push(true); + + if n_matched < values.len() && extra_field.is_none() { + for k in values.keys() { + if !out_values.contains_key(k.as_ref()) { + extra_field = Some(k.as_ref()) + } + } + } + }, + BorrowedValue::Static(StaticNode::Null) => { + out_values + .iter_mut() + .for_each(|(_, (_, inner))| inner.push(&JSON_NULL_VALUE)); + validity.push(false); + }, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + }, + }; + }); + + if let Some(v) = extra_field { + if !allow_extra_fields_in_struct { + polars_bail!( + ComputeError: + "extra field in struct data: {}, consider increasing infer_schema_length, or \ + manually specifying the full schema to ignore extra fields", + v + ) + } + } + + check_err_idx(rows, err_idx, "struct")?; + + // ensure we collect in the proper order + let values = fields + .iter() + .map(|fld| { + let (dtype, vals) = out_values.get(fld.name.as_str()).unwrap(); + _deserialize(vals, (*dtype).clone(), allow_extra_fields_in_struct) + }) + .collect::>>()?; + + Ok(StructArray::new( + dtype.clone(), + rows.len(), + values, + validity.into_opt_validity(), + )) +} + +fn fill_array_from( + f: fn(&mut MutablePrimitiveArray, &[B]) -> PolarsResult<()>, + dtype: ArrowDataType, + rows: &[B], +) -> PolarsResult> +where + T: NativeType, + A: From> + Array, +{ + let mut array = MutablePrimitiveArray::::with_capacity(rows.len()).to(dtype); + f(&mut array, rows)?; + Ok(Box::new(A::from(array))) +} + +/// A trait describing an array with a backing store that can be preallocated to +/// a given size. +pub(crate) trait Container { + /// Create this array with a given capacity. + fn with_capacity(capacity: usize) -> Self + where + Self: Sized; +} + +impl Container for MutableBinaryArray { + fn with_capacity(capacity: usize) -> Self { + MutableBinaryArray::with_capacity(capacity) + } +} + +impl Container for MutableBooleanArray { + fn with_capacity(capacity: usize) -> Self { + MutableBooleanArray::with_capacity(capacity) + } +} + +impl Container for MutableFixedSizeBinaryArray { + fn with_capacity(capacity: usize) -> Self { + MutableFixedSizeBinaryArray::with_capacity(capacity, 0) + } +} + +impl Container for MutableBinaryViewArray { + fn with_capacity(capacity: usize) -> Self + where + Self: Sized, + { + MutableBinaryViewArray::with_capacity(capacity) + } +} + +impl Container for MutableListArray { + fn with_capacity(capacity: usize) -> Self { + MutableListArray::with_capacity(capacity) + } +} + +impl Container for MutablePrimitiveArray { + fn with_capacity(capacity: usize) -> Self { + MutablePrimitiveArray::with_capacity(capacity) + } +} + +impl Container for MutableUtf8Array { + fn with_capacity(capacity: usize) -> Self { + MutableUtf8Array::with_capacity(capacity) + } +} + +fn fill_generic_array_from( + f: fn(&mut M, &[B]) -> PolarsResult<()>, + rows: &[B], +) -> PolarsResult> +where + M: Container, + A: From + Array, +{ + let mut array = M::with_capacity(rows.len()); + f(&mut array, rows)?; + Ok(Box::new(A::from(array))) +} + +pub(crate) fn _deserialize<'a, A: Borrow>>( + rows: &[A], + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { + match &dtype { + ArrowDataType::Null => { + if let Some(err_idx) = (0..rows.len()) + .find(|i| !matches!(rows[*i].borrow(), BorrowedValue::Static(StaticNode::Null))) + { + check_err_idx(rows, err_idx, "null")?; + } + + Ok(Box::new(NullArray::new(dtype, rows.len()))) + }, + ArrowDataType::Boolean => { + fill_generic_array_from::<_, _, BooleanArray>(deserialize_boolean_into, rows) + }, + ArrowDataType::Int8 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Int16 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Int32 + | ArrowDataType::Date32 + | ArrowDataType::Time32(_) + | ArrowDataType::Interval(IntervalUnit::YearMonth) => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Interval(IntervalUnit::DayTime) => { + unimplemented!("There is no natural representation of DayTime in JSON.") + }, + ArrowDataType::Int64 + | ArrowDataType::Date64 + | ArrowDataType::Time64(_) + | ArrowDataType::Duration(_) => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Timestamp(tu, tz) => { + let mut err_idx = rows.len(); + let iter = rows.iter().enumerate().map(|(i, row)| match row.borrow() { + BorrowedValue::Static(StaticNode::I64(v)) => Some(*v), + BorrowedValue::String(v) => match (tu, tz) { + (_, None) => { + polars_compute::cast::temporal::utf8_to_naive_timestamp_scalar(v, "%+", tu) + }, + (_, Some(tz)) => { + let tz = temporal_conversions::parse_offset(tz.as_str()).unwrap(); + temporal_conversions::utf8_to_timestamp_scalar(v, "%+", &tz, tu) + }, + }, + BorrowedValue::Static(StaticNode::Null) => None, + _ => { + err_idx = if err_idx == rows.len() { i } else { err_idx }; + None + }, + }); + let out = Box::new(Int64Array::from_iter(iter).to(dtype)); + check_err_idx(rows, err_idx, "timestamp")?; + Ok(out) + }, + ArrowDataType::UInt8 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::UInt16 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::UInt32 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::UInt64 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Float16 => unreachable!(), + ArrowDataType::Float32 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Float64 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::LargeUtf8 => { + fill_generic_array_from::<_, _, Utf8Array>(deserialize_utf8_into, rows) + }, + ArrowDataType::Utf8View => { + fill_generic_array_from::<_, _, Utf8ViewArray>(deserialize_utf8view_into, rows) + }, + ArrowDataType::LargeList(_) => Ok(Box::new(deserialize_list( + rows, + dtype, + allow_extra_fields_in_struct, + )?)), + ArrowDataType::LargeBinary => Ok(Box::new(deserialize_binary(rows)?)), + ArrowDataType::Struct(_) => Ok(Box::new(deserialize_struct( + rows, + dtype, + allow_extra_fields_in_struct, + )?)), + _ => todo!(), + } +} + +pub fn deserialize( + json: &BorrowedValue, + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, +) -> PolarsResult> { + match json { + BorrowedValue::Array(rows) => match dtype { + ArrowDataType::LargeList(inner) => { + _deserialize(rows, inner.dtype, allow_extra_fields_in_struct) + }, + _ => todo!("read an Array from a non-Array data type"), + }, + _ => _deserialize(&[json], dtype, allow_extra_fields_in_struct), + } +} + +fn check_err_idx<'a>( + rows: &[impl Borrow>], + err_idx: usize, + type_name: &'static str, +) -> PolarsResult<()> { + if err_idx != rows.len() { + polars_bail!( + ComputeError: + r#"error deserializing value "{:?}" as {}. \ + Try increasing `infer_schema_length` or specifying a schema. + "#, + rows[err_idx].borrow(), type_name, + ) + } + + Ok(()) +} diff --git a/crates/polars-json/src/json/infer_schema.rs b/crates/polars-json/src/json/infer_schema.rs new file mode 100644 index 000000000000..f00ed6e26eb8 --- /dev/null +++ b/crates/polars-json/src/json/infer_schema.rs @@ -0,0 +1,179 @@ +use std::borrow::Borrow; + +use arrow::datatypes::{ArrowDataType, Field}; +use indexmap::map::Entry; +use polars_utils::pl_str::PlSmallStr; +use simd_json::borrowed::Object; +use simd_json::{BorrowedValue, StaticNode}; + +use super::*; + +const ITEM_NAME: &str = "item"; + +/// Infers [`ArrowDataType`] from [`Value`][Value]. +/// +/// [Value]: simd_json::value::Value +pub fn infer(json: &BorrowedValue) -> PolarsResult { + Ok(match json { + BorrowedValue::Static(StaticNode::Bool(_)) => ArrowDataType::Boolean, + BorrowedValue::Static(StaticNode::U64(_) | StaticNode::I64(_)) => ArrowDataType::Int64, + BorrowedValue::Static(StaticNode::F64(_)) => ArrowDataType::Float64, + BorrowedValue::Static(StaticNode::Null) => ArrowDataType::Null, + BorrowedValue::Array(array) => infer_array(array)?, + BorrowedValue::String(_) => ArrowDataType::LargeUtf8, + BorrowedValue::Object(inner) => infer_object(inner)?, + }) +} + +fn infer_object(inner: &Object) -> PolarsResult { + let fields = inner + .iter() + .map(|(key, value)| infer(value).map(|dt| (key, dt))) + .map(|maybe_dt| { + let (key, dt) = maybe_dt?; + Ok(Field::new(key.as_ref().into(), dt, true)) + }) + .collect::>>()?; + Ok(ArrowDataType::Struct(fields)) +} + +fn infer_array(values: &[BorrowedValue]) -> PolarsResult { + let types = values + .iter() + .map(infer) + // deduplicate entries + .collect::>>()?; + + let dt = if !types.is_empty() { + let types = types.into_iter().collect::>(); + coerce_dtype(&types) + } else { + ArrowDataType::Null + }; + + Ok(ArrowDataType::LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + dt, + true, + )))) +} + +/// Coerce an heterogeneous set of [`ArrowDataType`] into a single one. Rules: +/// * The empty set is coerced to `Null` +/// * `Int64` and `Float64` are `Float64` +/// * Lists and scalars are coerced to a list of a compatible scalar +/// * Structs contain the union of all fields +/// * All other types are coerced to `Utf8` +pub(crate) fn coerce_dtype>(datatypes: &[A]) -> ArrowDataType { + use ArrowDataType::*; + + if datatypes.is_empty() { + return Null; + } + + let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow()); + + if are_all_equal { + return datatypes[0].borrow().clone(); + } + let mut are_all_structs = true; + let mut are_all_lists = true; + for dt in datatypes { + are_all_structs &= matches!(dt.borrow(), Struct(_)); + are_all_lists &= matches!(dt.borrow(), LargeList(_)); + } + + if are_all_structs { + // all are structs => union of all fields (that may have equal names) + let fields = datatypes.iter().fold(vec![], |mut acc, dt| { + if let Struct(new_fields) = dt.borrow() { + acc.extend(new_fields); + }; + acc + }); + // group fields by unique + let fields = fields.iter().fold( + PlIndexMap::<&str, PlHashSet<&ArrowDataType>>::default(), + |mut acc, field| { + match acc.entry(field.name.as_str()) { + Entry::Occupied(mut v) => { + v.get_mut().insert(&field.dtype); + }, + Entry::Vacant(v) => { + let mut a = PlHashSet::default(); + a.insert(&field.dtype); + v.insert(a); + }, + } + acc + }, + ); + // and finally, coerce each of the fields within the same name + let fields = fields + .into_iter() + .map(|(name, dts)| { + let dts = dts.into_iter().collect::>(); + Field::new(name.into(), coerce_dtype(&dts), true) + }) + .collect(); + return Struct(fields); + } else if are_all_lists { + let inner_types: Vec<&ArrowDataType> = datatypes + .iter() + .map(|dt| { + if let LargeList(inner) = dt.borrow() { + inner.dtype() + } else { + unreachable!(); + } + }) + .collect(); + return LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + coerce_dtype(inner_types.as_slice()), + true, + ))); + } else if datatypes.len() > 2 { + return datatypes + .iter() + .map(|t| t.borrow().clone()) + .reduce(|a, b| coerce_dtype(&[a, b])) + .expect("not empty"); + } + let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow()); + + match (lhs, rhs) { + (lhs, rhs) if lhs == rhs => lhs.clone(), + (LargeList(lhs), LargeList(rhs)) => { + let inner = coerce_dtype(&[lhs.dtype(), rhs.dtype()]); + LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + inner, + true, + ))) + }, + (scalar, LargeList(list)) => { + let inner = coerce_dtype(&[scalar, list.dtype()]); + LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + inner, + true, + ))) + }, + (LargeList(list), scalar) => { + let inner = coerce_dtype(&[scalar, list.dtype()]); + LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + inner, + true, + ))) + }, + (Float64, Int64) => Float64, + (Int64, Float64) => Float64, + (Int64, Boolean) => Int64, + (Boolean, Int64) => Int64, + (Null, rhs) => rhs.clone(), + (lhs, Null) => lhs.clone(), + (_, _) => LargeUtf8, + } +} diff --git a/crates/polars-json/src/json/mod.rs b/crates/polars-json/src/json/mod.rs new file mode 100644 index 000000000000..7a653bcda93a --- /dev/null +++ b/crates/polars-json/src/json/mod.rs @@ -0,0 +1,8 @@ +pub mod deserialize; +pub(crate) mod infer_schema; + +pub use deserialize::deserialize; +pub use infer_schema::infer; +use polars_error::*; +use polars_utils::aliases::*; +pub mod write; diff --git a/crates/polars-json/src/json/write/mod.rs b/crates/polars-json/src/json/write/mod.rs new file mode 100644 index 000000000000..6796ef7436bb --- /dev/null +++ b/crates/polars-json/src/json/write/mod.rs @@ -0,0 +1,158 @@ +//! APIs to write to JSON +mod serialize; +mod utf8; + +use std::io::Write; + +use arrow::array::Array; +use arrow::datatypes::ArrowSchema; +use arrow::io::iterator::StreamingIterator; +use arrow::record_batch::RecordBatchT; +pub use fallible_streaming_iterator::*; +use polars_error::{PolarsError, PolarsResult}; +pub(crate) use serialize::new_serializer; +use serialize::serialize; +pub use utf8::serialize_to_utf8; + +/// [`FallibleStreamingIterator`] that serializes an [`Array`] to bytes of valid JSON +/// # Implementation +/// Advancing this iterator CPU-bounded +#[derive(Debug, Clone)] +pub struct Serializer +where + A: AsRef, + I: Iterator>, +{ + arrays: I, + buffer: Vec, +} + +impl Serializer +where + A: AsRef, + I: Iterator>, +{ + /// Creates a new [`Serializer`]. + pub fn new(arrays: I, buffer: Vec) -> Self { + Self { arrays, buffer } + } +} + +impl FallibleStreamingIterator for Serializer +where + A: AsRef, + I: Iterator>, +{ + type Item = [u8]; + + type Error = PolarsError; + + fn advance(&mut self) -> PolarsResult<()> { + self.buffer.clear(); + self.arrays + .next() + .map(|maybe_array| maybe_array.map(|array| serialize(array.as_ref(), &mut self.buffer))) + .transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// [`FallibleStreamingIterator`] that serializes a [`RecordBatchT`] into bytes of JSON +/// in a (pandas-compatible) record-oriented format. +/// +/// # Implementation +/// Advancing this iterator is CPU-bounded. +pub struct RecordSerializer<'a> { + schema: ArrowSchema, + index: usize, + end: usize, + iterators: Vec + Send + Sync + 'a>>, + buffer: Vec, +} + +impl<'a> RecordSerializer<'a> { + /// Creates a new [`RecordSerializer`]. + pub fn new(schema: ArrowSchema, chunk: &'a RecordBatchT, buffer: Vec) -> Self + where + A: AsRef, + { + let end = chunk.len(); + let iterators = chunk + .arrays() + .iter() + .map(|arr| new_serializer(arr.as_ref(), 0, usize::MAX)) + .collect(); + + Self { + schema, + index: 0, + end, + iterators, + buffer, + } + } +} + +impl FallibleStreamingIterator for RecordSerializer<'_> { + type Item = [u8]; + + type Error = PolarsError; + + fn advance(&mut self) -> PolarsResult<()> { + self.buffer.clear(); + if self.index == self.end { + return Ok(()); + } + + let mut is_first_row = true; + write!(&mut self.buffer, "{{")?; + for (f, ref mut it) in self.schema.iter_values().zip(self.iterators.iter_mut()) { + if !is_first_row { + write!(&mut self.buffer, ",")?; + } + write!(&mut self.buffer, "\"{}\":", f.name)?; + + self.buffer.extend_from_slice(it.next().unwrap()); + is_first_row = false; + } + write!(&mut self.buffer, "}}")?; + + self.index += 1; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// Writes valid JSON from an iterator of (assumed JSON-encoded) bytes to `writer` +pub fn write(writer: &mut W, mut blocks: I) -> PolarsResult<()> +where + W: std::io::Write, + I: FallibleStreamingIterator, +{ + writer.write_all(b"[")?; + let mut is_first_row = true; + while let Some(block) = blocks.next()? { + if !is_first_row { + writer.write_all(b",")?; + } + is_first_row = false; + writer.write_all(block)?; + } + writer.write_all(b"]")?; + Ok(()) +} diff --git a/crates/polars-json/src/json/write/serialize.rs b/crates/polars-json/src/json/write/serialize.rs new file mode 100644 index 000000000000..9d2165744d09 --- /dev/null +++ b/crates/polars-json/src/json/write/serialize.rs @@ -0,0 +1,598 @@ +use std::io::Write; + +use arrow::array::*; +use arrow::bitmap::utils::ZipValidity; +#[cfg(feature = "dtype-decimal")] +use arrow::compute::decimal::get_trim_decimal_zeros; +use arrow::datatypes::{ArrowDataType, IntegerType, TimeUnit}; +use arrow::io::iterator::BufStreamingIterator; +use arrow::offset::Offset; +#[cfg(feature = "timezones")] +use arrow::temporal_conversions::parse_offset_tz; +use arrow::temporal_conversions::{ + date32_to_date, duration_ms_to_duration, duration_ns_to_duration, duration_us_to_duration, + parse_offset, time64ns_to_time, timestamp_ms_to_datetime, timestamp_ns_to_datetime, + timestamp_to_datetime, timestamp_us_to_datetime, +}; +use arrow::types::NativeType; +use chrono::{Duration, NaiveDate, NaiveDateTime, NaiveTime}; +use streaming_iterator::StreamingIterator; + +use super::utf8; + +fn write_integer(buf: &mut Vec, val: I) { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(val); + buf.extend_from_slice(value.as_bytes()) +} + +fn write_float(f: &mut Vec, val: I) { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) +} + +fn materialize_serializer<'a, I, F, T>( + f: F, + iterator: I, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: 'a, + I: Iterator + Send + Sync + 'a, + F: FnMut(T, &mut Vec) + Send + Sync + 'a, +{ + if offset > 0 || take < usize::MAX { + Box::new(BufStreamingIterator::new( + iterator.skip(offset).take(take), + f, + vec![], + )) + } else { + Box::new(BufStreamingIterator::new(iterator, f, vec![])) + } +} + +fn boolean_serializer<'a>( + array: &'a BooleanArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option, buf: &mut Vec| match x { + Some(true) => buf.extend_from_slice(b"true"), + Some(false) => buf.extend_from_slice(b"false"), + None => buf.extend_from_slice(b"null"), + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn null_serializer( + len: usize, + offset: usize, + take: usize, +) -> Box + Send + Sync> { + let f = |_x: (), buf: &mut Vec| buf.extend_from_slice(b"null"); + materialize_serializer(f, std::iter::repeat_n((), len), offset, take) +} + +fn primitive_serializer<'a, T: NativeType + itoa::Integer>( + array: &'a PrimitiveArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + write_integer(buf, *x) + } else { + buf.extend(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn float_serializer<'a, T>( + array: &'a PrimitiveArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: num_traits::Float + NativeType + ryu::Float, +{ + let f = |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + if T::is_nan(*x) || T::is_infinite(*x) { + buf.extend(b"null") + } else { + write_float(buf, *x) + } + } else { + buf.extend(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +#[cfg(feature = "dtype-decimal")] +fn decimal_serializer<'a>( + array: &'a PrimitiveArray, + scale: usize, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let trim_zeros = get_trim_decimal_zeros(); + let mut fmt_buf = arrow::compute::decimal::DecimalFmtBuffer::new(); + let f = move |x: Option<&i128>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, fmt_buf.format(*x, scale, trim_zeros)).unwrap() + } else { + buf.extend(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn dictionary_utf8view_serializer<'a, K: DictionaryKey>( + array: &'a DictionaryArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let iter = array.iter_typed::().unwrap().skip(offset); + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, iter, offset, take) +} + +fn utf8_serializer<'a, O: Offset>( + array: &'a Utf8Array, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn utf8view_serializer<'a>( + array: &'a Utf8ViewArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn struct_serializer<'a>( + array: &'a StructArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + // {"a": [1, 2, 3], "b": [a, b, c], "c": {"a": [1, 2, 3]}} + // [ + // {"a": 1, "b": a, "c": {"a": 1}}, + // {"a": 2, "b": b, "c": {"a": 2}}, + // {"a": 3, "b": c, "c": {"a": 3}}, + // ] + // + let mut serializers = array + .values() + .iter() + .map(|x| x.as_ref()) + .map(|arr| new_serializer(arr, offset, take)) + .collect::>(); + + Box::new(BufStreamingIterator::new( + ZipValidity::new_with_validity(0..array.len(), array.validity()), + move |maybe, buf| { + if maybe.is_some() { + let names = array.fields().iter().map(|f| f.name.as_str()); + serialize_item( + buf, + names.zip( + serializers + .iter_mut() + .map(|serializer| serializer.next().unwrap()), + ), + true, + ); + } else { + serializers.iter_mut().for_each(|iter| { + let _ = iter.next(); + }); + buf.extend(b"null"); + } + }, + vec![], + )) +} + +fn list_serializer<'a, O: Offset>( + array: &'a ListArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + // [[1, 2], [3]] + // [ + // [1, 2], + // [3] + // ] + // + let offsets = array.offsets().as_slice(); + let start = offsets[0].to_usize(); + let end = offsets.last().unwrap().to_usize(); + let mut serializer = new_serializer(array.values().as_ref(), start, end - start); + + let mut prev_offset = start; + let f = move |offset: Option<&[O]>, buf: &mut Vec| { + if let Some(offset) = offset { + if offset[0].to_usize() > prev_offset { + for _ in 0..(offset[0].to_usize() - prev_offset) { + serializer.next().unwrap(); + } + } + + let length = (offset[1] - offset[0]).to_usize(); + buf.push(b'['); + let mut is_first_row = true; + for _ in 0..length { + if !is_first_row { + buf.push(b','); + } + is_first_row = false; + buf.extend(serializer.next().unwrap()); + } + buf.push(b']'); + prev_offset = offset[1].to_usize(); + } else { + buf.extend(b"null"); + } + }; + + let iter = + ZipValidity::new_with_validity(array.offsets().buffer().windows(2), array.validity()); + materialize_serializer(f, iter, offset, take) +} + +fn fixed_size_list_serializer<'a>( + array: &'a FixedSizeListArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let mut serializer = new_serializer(array.values().as_ref(), offset, take); + + Box::new(BufStreamingIterator::new( + ZipValidity::new(0..array.len(), array.validity().map(|x| x.iter())), + move |ix, buf| { + if ix.is_some() { + let length = array.size(); + buf.push(b'['); + let mut is_first_row = true; + for _ in 0..length { + if !is_first_row { + buf.push(b','); + } + is_first_row = false; + buf.extend(serializer.next().unwrap()); + } + buf.push(b']'); + } else { + buf.extend(b"null"); + } + }, + vec![], + )) +} + +fn date_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> NaiveDate + 'static + Send + Sync, +{ + let f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let nd = convert(*x); + write!(buf, "\"{nd}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn duration_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> Duration + 'static + Send + Sync, +{ + let f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let duration = convert(*x); + write!(buf, "\"{duration}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn time_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> NaiveTime + 'static + Send + Sync, +{ + let f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let time = convert(*x); + write!(buf, "\"{time}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn timestamp_serializer<'a, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + F: Fn(i64) -> NaiveDateTime + 'static + Send + Sync, +{ + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let ndt = convert(*x); + write!(buf, "\"{ndt}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn timestamp_tz_serializer<'a>( + array: &'a PrimitiveArray, + time_unit: TimeUnit, + tz: &str, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + match parse_offset(tz) { + Ok(parsed_tz) => { + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let dt_str = timestamp_to_datetime(*x, time_unit, &parsed_tz).to_rfc3339(); + write!(buf, "\"{dt_str}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) + }, + #[cfg(feature = "timezones")] + _ => match parse_offset_tz(tz) { + Ok(parsed_tz) => { + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let dt_str = timestamp_to_datetime(*x, time_unit, &parsed_tz).to_rfc3339(); + write!(buf, "\"{dt_str}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) + }, + _ => { + panic!("Timezone {} is invalid or not supported", tz); + }, + }, + #[cfg(not(feature = "timezones"))] + _ => { + panic!("Invalid Offset format (must be [-]00:00) or timezones feature not active"); + }, + } +} + +pub(crate) fn new_serializer<'a>( + array: &'a dyn Array, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + match array.dtype().to_logical_type() { + ArrowDataType::Boolean => { + boolean_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Int8 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Int16 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Int32 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Int64 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::UInt8 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::UInt16 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::UInt32 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::UInt64 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Float32 => { + float_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Float64 => { + float_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + #[cfg(feature = "dtype-decimal")] + ArrowDataType::Decimal(_, scale) => { + decimal_serializer(array.as_any().downcast_ref().unwrap(), *scale, offset, take) + }, + ArrowDataType::LargeUtf8 => { + utf8_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Utf8View => { + utf8view_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Struct(_) => { + struct_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::FixedSizeList(_, _) => { + fixed_size_list_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::LargeList(_) => { + list_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + ArrowDataType::Dictionary(k, v, _) => match (k, &**v) { + (IntegerType::UInt32, ArrowDataType::Utf8View) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + dictionary_utf8view_serializer::(array, offset, take) + }, + _ => { + // Not produced by polars + unreachable!() + }, + }, + ArrowDataType::Date32 => date_serializer( + array.as_any().downcast_ref().unwrap(), + date32_to_date, + offset, + take, + ), + ArrowDataType::Timestamp(tu, None) => { + let convert = match tu { + TimeUnit::Nanosecond => timestamp_ns_to_datetime, + TimeUnit::Microsecond => timestamp_us_to_datetime, + TimeUnit::Millisecond => timestamp_ms_to_datetime, + tu => panic!("Invalid time unit '{:?}' for Datetime.", tu), + }; + timestamp_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + ArrowDataType::Timestamp(time_unit, Some(tz)) => timestamp_tz_serializer( + array.as_any().downcast_ref().unwrap(), + *time_unit, + tz, + offset, + take, + ), + ArrowDataType::Duration(tu) => { + let convert = match tu { + TimeUnit::Nanosecond => duration_ns_to_duration, + TimeUnit::Microsecond => duration_us_to_duration, + TimeUnit::Millisecond => duration_ms_to_duration, + tu => panic!("Invalid time unit '{:?}' for Duration.", tu), + }; + duration_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + ArrowDataType::Time64(tu) => { + let convert = match tu { + TimeUnit::Nanosecond => time64ns_to_time, + tu => panic!("Invalid time unit '{:?}' for Time.", tu), + }; + time_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + ArrowDataType::Null => null_serializer(array.len(), offset, take), + other => todo!("Writing {:?} to JSON", other), + } +} + +fn serialize_item<'a>( + buffer: &mut Vec, + record: impl Iterator, + is_first_row: bool, +) { + if !is_first_row { + buffer.push(b','); + } + buffer.push(b'{'); + let mut first_item = true; + for (key, value) in record { + if !first_item { + buffer.push(b','); + } + first_item = false; + utf8::write_str(buffer, key).unwrap(); + buffer.push(b':'); + buffer.extend(value); + } + buffer.push(b'}'); +} + +/// Serializes `array` to a valid JSON to `buffer` +/// # Implementation +/// This operation is CPU-bounded +pub(crate) fn serialize(array: &dyn Array, buffer: &mut Vec) { + let mut serializer = new_serializer(array, 0, usize::MAX); + + (0..array.len()).for_each(|i| { + if i != 0 { + buffer.push(b','); + } + buffer.extend_from_slice(serializer.next().unwrap()); + }); +} diff --git a/crates/polars-json/src/json/write/utf8.rs b/crates/polars-json/src/json/write/utf8.rs new file mode 100644 index 000000000000..f967853bc1e1 --- /dev/null +++ b/crates/polars-json/src/json/write/utf8.rs @@ -0,0 +1,152 @@ +// Adapted from https://github.com/serde-rs/json/blob/f901012df66811354cb1d490ad59480d8fdf77b5/src/ser.rs +use std::io; + +use arrow::array::{Array, MutableBinaryViewArray, Utf8ViewArray}; + +use crate::json::write::new_serializer; + +pub fn write_str(writer: &mut W, value: &str) -> io::Result<()> +where + W: io::Write, +{ + writer.write_all(b"\"")?; + let bytes = value.as_bytes(); + + let mut start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + let escape = ESCAPE[byte as usize]; + if escape == 0 { + continue; + } + + if start < i { + writer.write_all(&bytes[start..i])?; + } + + let char_escape = CharEscape::from_escape_table(escape, byte); + write_char_escape(writer, char_escape)?; + + start = i + 1; + } + + if start != bytes.len() { + writer.write_all(&bytes[start..])?; + } + writer.write_all(b"\"") +} + +const BB: u8 = b'b'; // \x08 +const TT: u8 = b't'; // \x09 +const NN: u8 = b'n'; // \x0A +const FF: u8 = b'f'; // \x0C +const RR: u8 = b'r'; // \x0D +const QU: u8 = b'"'; // \x22 +const BS: u8 = b'\\'; // \x5C +const UU: u8 = b'u'; // \x00...\x1F except the ones above +const __: u8 = 0; + +// Lookup table of escape sequences. A value of b'x' at index i means that byte +// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped. +static ESCAPE: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0 + UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1 + __, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F +]; + +/// Represents a character escape code in a type-safe manner. +pub enum CharEscape { + /// An escaped quote `"` + Quote, + /// An escaped reverse solidus `\` + ReverseSolidus, + // An escaped solidus `/` + //Solidus, + /// An escaped backspace character (usually escaped as `\b`) + Backspace, + /// An escaped form feed character (usually escaped as `\f`) + FormFeed, + /// An escaped line feed character (usually escaped as `\n`) + LineFeed, + /// An escaped carriage return character (usually escaped as `\r`) + CarriageReturn, + /// An escaped tab character (usually escaped as `\t`) + Tab, + /// An escaped ASCII plane control character (usually escaped as + /// `\u00XX` where `XX` are two hex characters) + AsciiControl(u8), +} + +impl CharEscape { + #[inline] + fn from_escape_table(escape: u8, byte: u8) -> CharEscape { + match escape { + self::BB => CharEscape::Backspace, + self::TT => CharEscape::Tab, + self::NN => CharEscape::LineFeed, + self::FF => CharEscape::FormFeed, + self::RR => CharEscape::CarriageReturn, + self::QU => CharEscape::Quote, + self::BS => CharEscape::ReverseSolidus, + self::UU => CharEscape::AsciiControl(byte), + _ => unreachable!(), + } + } +} + +#[inline] +fn write_char_escape(writer: &mut W, char_escape: CharEscape) -> io::Result<()> +where + W: io::Write, +{ + use self::CharEscape::*; + + let s = match char_escape { + Quote => b"\\\"", + ReverseSolidus => b"\\\\", + //Solidus => b"\\/", + Backspace => b"\\b", + FormFeed => b"\\f", + LineFeed => b"\\n", + CarriageReturn => b"\\r", + Tab => b"\\t", + AsciiControl(byte) => { + static HEX_DIGITS: [u8; 16] = *b"0123456789abcdef"; + let bytes = &[ + b'\\', + b'u', + b'0', + b'0', + HEX_DIGITS[(byte >> 4) as usize], + HEX_DIGITS[(byte & 0xF) as usize], + ]; + return writer.write_all(bytes); + }, + }; + + writer.write_all(s) +} + +pub fn serialize_to_utf8(array: &dyn Array) -> Utf8ViewArray { + let mut values = MutableBinaryViewArray::with_capacity(array.len()); + let mut serializer = new_serializer(array, 0, usize::MAX); + + while let Some(v) = serializer.next() { + unsafe { values.push_value(std::str::from_utf8_unchecked(v)) } + } + values.into() +} diff --git a/crates/polars-json/src/lib.rs b/crates/polars-json/src/lib.rs new file mode 100644 index 000000000000..8c0d31873183 --- /dev/null +++ b/crates/polars-json/src/lib.rs @@ -0,0 +1,2 @@ +pub mod json; +pub mod ndjson; diff --git a/crates/polars-json/src/ndjson/deserialize.rs b/crates/polars-json/src/ndjson/deserialize.rs new file mode 100644 index 000000000000..aaf7ce3691dd --- /dev/null +++ b/crates/polars-json/src/ndjson/deserialize.rs @@ -0,0 +1,78 @@ +use arrow::array::Array; +use arrow::compute::concatenate::concatenate_unchecked; +use simd_json::BorrowedValue; + +use super::*; + +/// Deserializes an iterator of rows into an [`Array`][Array] of [`DataType`]. +/// +/// [Array]: arrow::array::Array +/// +/// # Implementation +/// This function is CPU-bounded. +/// This function is guaranteed to return an array of length equal to the length +/// # Errors +/// This function errors iff any of the rows is not a valid JSON (i.e. the format is not valid NDJSON). +pub fn deserialize_iter<'a>( + rows: impl Iterator, + dtype: ArrowDataType, + buf_size: usize, + count: usize, + allow_extra_fields_in_struct: bool, +) -> PolarsResult { + let mut arr: Vec> = Vec::new(); + let mut buf = Vec::with_capacity(std::cmp::min(buf_size + count + 2, u32::MAX as usize)); + buf.push(b'['); + + fn _deserializer( + s: &mut [u8], + dtype: ArrowDataType, + allow_extra_fields_in_struct: bool, + ) -> PolarsResult> { + let out = simd_json::to_borrowed_value(s) + .map_err(|e| PolarsError::ComputeError(format!("json parsing error: '{e}'").into()))?; + if let BorrowedValue::Array(rows) = out { + super::super::json::deserialize::_deserialize( + &rows, + dtype.clone(), + allow_extra_fields_in_struct, + ) + } else { + unreachable!() + } + } + let mut row_iter = rows.peekable(); + + while let Some(row) = row_iter.next() { + buf.extend_from_slice(row.as_bytes()); + buf.push(b','); + + let next_row_length = row_iter.peek().map(|row| row.len()).unwrap_or(0); + if buf.len() + next_row_length >= u32::MAX as usize { + let _ = buf.pop(); + buf.push(b']'); + arr.push(_deserializer( + &mut buf, + dtype.clone(), + allow_extra_fields_in_struct, + )?); + buf.clear(); + buf.push(b'['); + } + } + if buf.len() > 1 { + let _ = buf.pop(); + } + buf.push(b']'); + + if arr.is_empty() { + _deserializer(&mut buf, dtype.clone(), allow_extra_fields_in_struct) + } else { + arr.push(_deserializer( + &mut buf, + dtype.clone(), + allow_extra_fields_in_struct, + )?); + concatenate_unchecked(&arr) + } +} diff --git a/crates/polars-json/src/ndjson/file.rs b/crates/polars-json/src/ndjson/file.rs new file mode 100644 index 000000000000..e0a166f934e8 --- /dev/null +++ b/crates/polars-json/src/ndjson/file.rs @@ -0,0 +1,146 @@ +use std::io::BufRead; +use std::num::NonZeroUsize; + +use arrow::datatypes::ArrowDataType; +use fallible_streaming_iterator::FallibleStreamingIterator; +use indexmap::IndexSet; +use polars_error::*; +use polars_utils::aliases::{PlIndexSet, PlRandomState}; +use simd_json::BorrowedValue; + +/// Reads up to a number of lines from `reader` into `rows` bounded by `limit`. +fn read_rows(reader: &mut R, rows: &mut [String], limit: usize) -> PolarsResult { + if limit == 0 { + return Ok(0); + } + let mut row_number = 0; + for row in rows.iter_mut() { + loop { + row.clear(); + let _ = reader.read_line(row).map_err(|e| { + PolarsError::ComputeError(format!("{e} at line {row_number}").into()) + })?; + if row.is_empty() { + break; + } + if !row.trim().is_empty() { + break; + } + } + if row.is_empty() { + break; + } + row_number += 1; + if row_number == limit { + break; + } + } + Ok(row_number) +} + +/// A [`FallibleStreamingIterator`] of NDJSON rows. +/// +/// This iterator is used to read chunks of an NDJSON in batches. +/// This iterator is guaranteed to yield at least one row. +/// # Implementation +/// Advancing this iterator is IO-bounded, but does require parsing each byte to find end of lines. +/// # Error +/// Advancing this iterator errors iff the reader errors. +pub struct FileReader { + reader: R, + rows: Vec, + number_of_rows: usize, + remaining: usize, +} + +impl FileReader { + /// Creates a new [`FileReader`] from a reader and `rows`. + /// + /// The number of items in `rows` denotes the batch size. + pub fn new(reader: R, rows: Vec, limit: Option) -> Self { + Self { + reader, + rows, + remaining: limit.unwrap_or(usize::MAX), + number_of_rows: 0, + } + } +} + +impl FallibleStreamingIterator for FileReader { + type Error = PolarsError; + type Item = [String]; + + fn advance(&mut self) -> PolarsResult<()> { + self.number_of_rows = read_rows(&mut self.reader, &mut self.rows, self.remaining)?; + self.remaining -= self.number_of_rows; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if self.number_of_rows > 0 { + Some(&self.rows[..self.number_of_rows]) + } else { + None + } + } +} + +fn parse_value<'a>(scratch: &'a mut Vec, val: &[u8]) -> PolarsResult> { + scratch.clear(); + scratch.extend_from_slice(val); + // 0 because it is row by row + + simd_json::to_borrowed_value(scratch) + .map_err(|e| PolarsError::ComputeError(format!("{e}").into())) +} + +/// Infers the [`ArrowDataType`] from an NDJSON file, optionally only using `number_of_rows` rows. +/// +/// # Implementation +/// This implementation reads the file line by line and infers the type of each line. +/// It performs both `O(N)` IO and CPU-bounded operations where `N` is the number of rows. +pub fn iter_unique_dtypes( + reader: &mut R, + number_of_rows: Option, +) -> PolarsResult> { + if reader.fill_buf().map(|b| b.is_empty())? { + return Err(PolarsError::ComputeError( + "Cannot infer NDJSON types on empty reader because empty string is not a valid JSON value".into(), + )); + } + + let rows = vec!["".to_string(); 1]; // 1 <=> read row by row + let mut reader = FileReader::new(reader, rows, number_of_rows.map(|v| v.into())); + + let mut dtypes = PlIndexSet::default(); + let mut buf = vec![]; + while let Some(rows) = reader.next()? { + // 0 because it is row by row + let value = parse_value(&mut buf, rows[0].as_bytes())?; + let dtype = crate::json::infer(&value)?; + dtypes.insert(dtype); + } + Ok(dtypes.into_iter()) +} + +/// Infers the [`ArrowDataType`] from an iterator of JSON strings. A limited number of +/// rows can be used by passing `rows.take(number_of_rows)` as an input. +/// +/// # Implementation +/// This implementation infers each row by going through the entire iterator. +pub fn infer_iter>(rows: impl Iterator) -> PolarsResult { + let mut dtypes = IndexSet::<_, PlRandomState>::default(); + + let mut buf = vec![]; + for row in rows { + let v = parse_value(&mut buf, row.as_ref().as_bytes())?; + let dtype = crate::json::infer(&v)?; + if dtype != ArrowDataType::Null { + dtypes.insert(dtype); + } + } + + let v: Vec<&ArrowDataType> = dtypes.iter().collect(); + Ok(crate::json::infer_schema::coerce_dtype(&v)) +} diff --git a/crates/polars-json/src/ndjson/mod.rs b/crates/polars-json/src/ndjson/mod.rs new file mode 100644 index 000000000000..4b345362e4ac --- /dev/null +++ b/crates/polars-json/src/ndjson/mod.rs @@ -0,0 +1,7 @@ +use arrow::array::ArrayRef; +use arrow::datatypes::*; +use polars_error::*; +pub mod deserialize; +mod file; +pub mod write; +pub use file::{infer_iter, iter_unique_dtypes}; diff --git a/crates/polars-json/src/ndjson/write.rs b/crates/polars-json/src/ndjson/write.rs new file mode 100644 index 000000000000..90f202b02360 --- /dev/null +++ b/crates/polars-json/src/ndjson/write.rs @@ -0,0 +1,118 @@ +//! APIs to serialize and write to [NDJSON](http://ndjson.org/). +use std::io::Write; + +use arrow::array::Array; +pub use fallible_streaming_iterator::FallibleStreamingIterator; +use polars_error::{PolarsError, PolarsResult}; + +use super::super::json::write::new_serializer; + +fn serialize(array: &dyn Array, buffer: &mut Vec) { + let mut serializer = new_serializer(array, 0, usize::MAX); + (0..array.len()).for_each(|_| { + buffer.extend_from_slice(serializer.next().unwrap()); + buffer.push(b'\n'); + }); +} + +/// [`FallibleStreamingIterator`] that serializes an [`Array`] to bytes of valid NDJSON +/// where every line is an element of the array. +/// # Implementation +/// Advancing this iterator CPU-bounded +#[derive(Debug, Clone)] +pub struct Serializer +where + A: AsRef, + I: Iterator>, +{ + arrays: I, + buffer: Vec, +} + +impl Serializer +where + A: AsRef, + I: Iterator>, +{ + /// Creates a new [`Serializer`]. + pub fn new(arrays: I, buffer: Vec) -> Self { + Self { arrays, buffer } + } +} + +impl FallibleStreamingIterator for Serializer +where + A: AsRef, + I: Iterator>, +{ + type Item = [u8]; + + type Error = PolarsError; + + fn advance(&mut self) -> PolarsResult<()> { + self.buffer.clear(); + self.arrays + .next() + .map(|maybe_array| maybe_array.map(|array| serialize(array.as_ref(), &mut self.buffer))) + .transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// An iterator adapter that receives an implementer of [`Write`] and +/// an implementer of [`FallibleStreamingIterator`] (such as [`Serializer`]) +/// and writes a valid NDJSON +/// # Implementation +/// Advancing this iterator mixes CPU-bounded (serializing arrays) tasks and IO-bounded (write to the writer). +pub struct FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + writer: W, + iterator: I, +} + +impl FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + /// Creates a new [`FileWriter`]. + pub fn new(writer: W, iterator: I) -> Self { + Self { writer, iterator } + } + + /// Returns the inner content of this iterator + /// + /// There are two use-cases for this function: + /// * to continue writing to its writer + /// * to reuse an internal buffer of its iterator + pub fn into_inner(self) -> (W, I) { + (self.writer, self.iterator) + } +} + +impl Iterator for FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + type Item = PolarsResult<()>; + + fn next(&mut self) -> Option { + let item = self.iterator.next().transpose()?; + Some(item.and_then(|x| { + self.writer.write_all(x)?; + Ok(()) + })) + } +} diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml new file mode 100644 index 000000000000..f733c4c3947b --- /dev/null +++ b/crates/polars-lazy/Cargo.toml @@ -0,0 +1,431 @@ +[package] +name = "polars-lazy" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Lazy query engine for the Polars DataFrame library" + +[dependencies] +arrow = { workspace = true } +chrono = { workspace = true } +futures = { workspace = true, optional = true } +polars-compute = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } +polars-expr = { workspace = true } +polars-io = { workspace = true, features = ["lazy"] } +polars-json = { workspace = true, optional = true } +polars-mem-engine = { workspace = true } +polars-ops = { workspace = true, features = ["chunked_ids"] } +polars-pipe = { workspace = true, optional = true } +polars-plan = { workspace = true } +polars-stream = { workspace = true, optional = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } + +bitflags = { workspace = true } +either = { workspace = true } +memchr = { workspace = true } +pyo3 = { workspace = true, optional = true } +rayon = { workspace = true } +tokio = { workspace = true, optional = true } + +[dev-dependencies] +serde_json = { workspace = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +catalog = ["polars-io/catalog"] +nightly = ["polars-core/nightly", "polars-pipe?/nightly", "polars-plan/nightly"] +streaming = ["polars-pipe", "polars-plan/streaming", "polars-ops/chunked_ids", "polars-expr/streaming"] +new_streaming = ["polars-stream"] +parquet = [ + "polars-io/parquet", + "polars-plan/parquet", + "polars-pipe?/parquet", + "polars-expr/parquet", + "polars-mem-engine/parquet", + "polars-stream?/parquet", +] +async = [ + "polars-plan/async", + "polars-io/cloud", + "polars-pipe?/async", + "polars-mem-engine/async", +] +cloud = [ + "async", + "polars-pipe?/cloud", + "polars-plan/cloud", + "tokio", + "futures", + "polars-mem-engine/cloud", + "polars-stream?/cloud", +] +ipc = ["polars-io/ipc", "polars-plan/ipc", "polars-pipe?/ipc", "polars-mem-engine/ipc", "polars-stream?/ipc"] +json = [ + "polars-io/json", + "polars-plan/json", + "polars-json", + "polars-pipe?/json", + "polars-mem-engine/json", + "polars-stream?/json", +] +csv = ["polars-io/csv", "polars-plan/csv", "polars-pipe?/csv", "polars-mem-engine/csv", "polars-stream?/csv"] +temporal = [ + "dtype-datetime", + "dtype-date", + "dtype-time", + "dtype-i8", + "dtype-i16", + "dtype-duration", + "polars-plan/temporal", + "polars-expr/temporal", +] +# debugging purposes +fmt = ["polars-core/fmt", "polars-plan/fmt"] +strings = ["polars-plan/strings", "polars-stream?/strings"] +future = [] + +dtype-full = [ + "dtype-array", + "dtype-categorical", + "dtype-date", + "dtype-datetime", + "dtype-decimal", + "dtype-duration", + "dtype-i16", + "dtype-i128", + "dtype-i8", + "dtype-struct", + "dtype-time", + "dtype-u16", + "dtype-u8", +] +dtype-array = [ + "polars-plan/dtype-array", + "polars-pipe?/dtype-array", + "polars-ops/dtype-array", + "polars-expr/dtype-array", +] +dtype-categorical = [ + "polars-plan/dtype-categorical", + "polars-pipe?/dtype-categorical", + "polars-stream?/dtype-categorical", + "polars-expr/dtype-categorical", + "polars-mem-engine/dtype-categorical", +] +dtype-date = [ + "polars-plan/dtype-date", + "polars-time/dtype-date", + "temporal", + "polars-expr/dtype-date", + "polars-mem-engine/dtype-date", +] +dtype-datetime = [ + "polars-plan/dtype-datetime", + "polars-time/dtype-datetime", + "temporal", + "polars-expr/dtype-datetime", + "polars-mem-engine/dtype-datetime", +] +dtype-decimal = [ + "polars-plan/dtype-decimal", + "polars-pipe?/dtype-decimal", + "polars-expr/dtype-decimal", + "polars-mem-engine/dtype-decimal", +] +dtype-duration = [ + "polars-plan/dtype-duration", + "polars-time/dtype-duration", + "temporal", + "polars-expr/dtype-duration", + "polars-mem-engine/dtype-duration", +] +dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe?/dtype-i16", "polars-expr/dtype-i16", "polars-mem-engine/dtype-i16"] +dtype-i128 = ["polars-plan/dtype-i128", "polars-pipe?/dtype-i128", "polars-expr/dtype-i128"] +dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe?/dtype-i8", "polars-expr/dtype-i8", "polars-mem-engine/dtype-i8"] +dtype-struct = [ + "polars-plan/dtype-struct", + "polars-ops/dtype-struct", + "polars-expr/dtype-struct", + "polars-mem-engine/dtype-struct", +] +dtype-time = [ + "polars-plan/dtype-time", + "polars-time/dtype-time", + "temporal", + "polars-expr/dtype-time", + "polars-mem-engine/dtype-time", +] +dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe?/dtype-u16", "polars-expr/dtype-u16", "polars-mem-engine/dtype-u16"] +dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe?/dtype-u8", "polars-expr/dtype-u8", "polars-mem-engine/dtype-u8"] + +object = ["polars-plan/object", "polars-mem-engine/object", "polars-stream?/object"] +month_start = ["polars-plan/month_start"] +month_end = ["polars-plan/month_end"] +offset_by = ["polars-plan/offset_by"] +trigonometry = ["polars-plan/trigonometry"] +sign = ["polars-plan/sign"] +timezones = ["polars-plan/timezones"] +list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"] +list_count = ["polars-ops/list_count", "polars-plan/list_count"] +array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"] +true_div = ["polars-plan/true_div"] +extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"] + +# operations +bitwise = [ + "polars-plan/bitwise", + "polars-expr/bitwise", + "polars-core/bitwise", + "polars-stream?/bitwise", + "polars-ops/bitwise", +] +approx_unique = ["polars-plan/approx_unique"] +is_in = ["polars-plan/is_in", "polars-ops/is_in", "polars-expr/is_in"] +repeat_by = ["polars-plan/repeat_by"] +round_series = ["polars-plan/round_series", "polars-ops/round_series", "polars-expr/round_series"] +is_first_distinct = ["polars-plan/is_first_distinct"] +is_last_distinct = ["polars-plan/is_last_distinct"] +is_between = ["polars-plan/is_between", "polars-expr/is_between"] +is_unique = ["polars-plan/is_unique"] +cross_join = ["polars-plan/cross_join", "polars-pipe?/cross_join", "polars-ops/cross_join"] +asof_join = ["polars-plan/asof_join", "polars-time", "polars-ops/asof_join", "polars-mem-engine/asof_join"] +iejoin = ["polars-plan/iejoin"] +business = ["polars-plan/business"] +concat_str = ["polars-plan/concat_str"] +range = ["polars-plan/range"] +mode = ["polars-plan/mode"] +cum_agg = ["polars-plan/cum_agg"] +interpolate = ["polars-plan/interpolate"] +interpolate_by = ["polars-plan/interpolate_by"] +rolling_window = [ + "polars-plan/rolling_window", +] +rolling_window_by = [ + "polars-plan/rolling_window_by", + "polars-time/rolling_window_by", +] +rank = ["polars-plan/rank"] +diff = ["polars-plan/diff", "polars-plan/diff"] +pct_change = ["polars-plan/pct_change"] +moment = ["polars-plan/moment", "polars-ops/moment"] +abs = ["polars-plan/abs"] +random = ["polars-plan/random"] +dynamic_group_by = [ + "polars-plan/dynamic_group_by", + "polars-time", + "temporal", + "polars-expr/dynamic_group_by", + "polars-mem-engine/dynamic_group_by", + "polars-stream?/dynamic_group_by", +] +ewma = ["polars-plan/ewma"] +ewma_by = ["polars-plan/ewma_by"] +dot_diagram = ["polars-plan/dot_diagram"] +diagonal_concat = [] +unique_counts = ["polars-plan/unique_counts"] +log = ["polars-plan/log"] +list_eval = [] +cumulative_eval = [] +list_to_struct = ["polars-plan/list_to_struct"] +array_to_struct = ["polars-plan/array_to_struct"] +python = [ + "pyo3", + "polars-plan/python", + "polars-core/python", + "polars-io/python", + "polars-mem-engine/python", + "polars-stream?/python", +] +row_hash = ["polars-plan/row_hash"] +reinterpret = ["polars-plan/reinterpret", "polars-ops/reinterpret"] +string_pad = ["polars-plan/string_pad"] +string_normalize = ["polars-plan/string_normalize"] +string_reverse = ["polars-plan/string_reverse"] +string_to_integer = ["polars-plan/string_to_integer"] +arg_where = ["polars-plan/arg_where"] +index_of = ["polars-plan/index_of"] +search_sorted = ["polars-plan/search_sorted"] +merge_sorted = ["polars-plan/merge_sorted", "polars-stream?/merge_sorted", "polars-mem-engine/merge_sorted"] +meta = ["polars-plan/meta"] +pivot = ["polars-core/rows", "polars-ops/pivot", "polars-plan/pivot"] +top_k = ["polars-plan/top_k"] +semi_anti_join = ["polars-plan/semi_anti_join"] +cse = ["polars-plan/cse"] +propagate_nans = ["polars-plan/propagate_nans", "polars-expr/propagate_nans"] +coalesce = ["polars-plan/coalesce"] +regex = ["polars-plan/regex"] +serde = [ + "polars-plan/serde", + "arrow/serde", + "polars-core/serde-lazy", + "polars-time?/serde", + "polars-io/serde", + "polars-ops/serde", + "polars-utils/serde", +] +fused = ["polars-plan/fused", "polars-ops/fused"] +list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] +list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"] +array_any_all = ["polars-ops/array_any_all", "polars-plan/array_any_all", "dtype-array"] +list_drop_nulls = ["polars-ops/list_drop_nulls", "polars-plan/list_drop_nulls"] +list_sample = ["polars-ops/list_sample", "polars-plan/list_sample"] +cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"] +rle = ["polars-plan/rle", "polars-ops/rle"] +extract_groups = ["polars-plan/extract_groups"] +peaks = ["polars-plan/peaks"] +cov = ["polars-ops/cov", "polars-plan/cov"] +hist = ["polars-plan/hist"] +replace = ["polars-plan/replace"] + +binary_encoding = ["polars-plan/binary_encoding"] +string_encoding = ["polars-plan/string_encoding"] + +bigidx = ["polars-plan/bigidx", "polars-utils/bigidx"] +polars_cloud = ["polars-plan/polars_cloud"] + +test = [ + "polars-plan/debugging", + "rolling_window", + "rank", + "round_series", + "csv", + "dtype-categorical", + "cum_agg", + "regex", + "polars-core/fmt", + "diff", + "abs", + "parquet", + "ipc", + "dtype-date", +] + +test_all = [ + "test", + "strings", + "regex", + "ipc", + "row_hash", + "string_pad", + "string_to_integer", + "index_of", + "search_sorted", + "top_k", + "pivot", + "semi_anti_join", + "cse", + "dtype-struct", + "peaks", + "cov", + "hist", + "extract_groups", + "rle", + "cutqcut", + "replace", + "list_sample", +] + +[package.metadata.docs.rs] +features = [ + "abs", + "approx_unique", + "arg_where", + "asof_join", + "async", + "bigidx", + "binary_encoding", + "cloud", + "coalesce", + "concat_str", + "cov", + "cross_join", + "cse", + "csv", + "cum_agg", + "cumulative_eval", + "cutqcut", + "diagonal_concat", + "diff", + "dot_diagram", + "dtype-full", + "dynamic_group_by", + "ewma", + "extract_groups", + "fmt", + "fused", + "futures", + "hist", + "index_of", + "interpolate", + "interpolate_by", + "ipc", + "is_first_distinct", + "is_in", + "is_last_distinct", + "is_unique", + "json", + "list_any_all", + "list_count", + "list_drop_nulls", + "list_eval", + "list_gather", + "list_sample", + "list_sets", + "list_to_struct", + "log", + "merge_sorted", + "meta", + "mode", + "moment", + "month_start", + "month_end", + "nightly", + "object", + "offset_by", + "panic_on_schema", + "parquet", + "pct_change", + "peaks", + "pivot", + "polars-json", + "polars-time", + "propagate_nans", + "random", + "range", + "rank", + "regex", + "repeat_by", + "replace", + "rle", + "rolling_window", + "rolling_window_by", + "round_series", + "row_hash", + "search_sorted", + "semi_anti_join", + "serde", + "sign", + "streaming", + "string_encoding", + "string_normalize", + "string_pad", + "string_reverse", + "string_to_integer", + "strings", + "temporal", + "timezones", + "tokio", + "top_k", + "trigonometry", + "true_div", + "unique_counts", +] +# defines the configuration attribute `docsrs` +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars-lazy/LICENSE b/crates/polars-lazy/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-lazy/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-lazy/README.md b/crates/polars-lazy/README.md new file mode 100644 index 000000000000..b12552cb4a5a --- /dev/null +++ b/crates/polars-lazy/README.md @@ -0,0 +1,9 @@ +# polars-lazy + +`polars-lazy` serves as the lazy query engine for the [Polars](https://crates.io/crates/polars) +DataFrame library. It allows you to perform operations on DataFrames in a lazy manner, only +executing them when necessary. This can lead to significant performance improvements for large +datasets. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-lazy/build.rs b/crates/polars-lazy/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-lazy/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-lazy/src/dot.rs b/crates/polars-lazy/src/dot.rs new file mode 100644 index 000000000000..f8facf074838 --- /dev/null +++ b/crates/polars-lazy/src/dot.rs @@ -0,0 +1,16 @@ +use polars_core::prelude::*; + +use crate::prelude::*; + +impl LazyFrame { + /// Get a dot language representation of the LogicalPlan. + pub fn to_dot(&self, optimized: bool) -> PolarsResult { + let lp = if optimized { + self.clone().to_alp_optimized() + } else { + self.clone().to_alp() + }?; + + Ok(lp.display_dot().to_string()) + } +} diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs new file mode 100644 index 000000000000..1cd75b0a5110 --- /dev/null +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -0,0 +1,131 @@ +use polars_core::POOL; +use polars_core::prelude::*; +use polars_expr::{ExpressionConversionState, create_physical_expr}; +use rayon::prelude::*; + +use super::*; +use crate::prelude::*; + +pub(crate) fn eval_field_to_dtype(f: &Field, expr: &Expr, list: bool) -> Field { + // Dummy df to determine output dtype. + let dtype = f + .dtype() + .inner_dtype() + .cloned() + .unwrap_or_else(|| f.dtype().clone()); + + let df = Series::new_empty(PlSmallStr::EMPTY, &dtype).into_frame(); + + #[cfg(feature = "python")] + let out = { + use pyo3::Python; + Python::with_gil(|py| py.allow_threads(|| df.lazy().select([expr.clone()]).collect())) + }; + #[cfg(not(feature = "python"))] + let out = { df.lazy().select([expr.clone()]).collect() }; + + match out { + Ok(out) => { + let dtype = out.get_columns()[0].dtype(); + if list { + Field::new(f.name().clone(), DataType::List(Box::new(dtype.clone()))) + } else { + Field::new(f.name().clone(), dtype.clone()) + } + }, + Err(_) => Field::new(f.name().clone(), DataType::Null), + } +} + +pub trait ExprEvalExtension: IntoExpr + Sized { + /// Run an expression over a sliding window that increases `1` slot every iteration. + /// + /// # Warning + /// This can be really slow as it can have `O(n^2)` complexity. Don't use this for operations + /// that visit all elements. + fn cumulative_eval(self, expr: Expr, min_periods: usize, parallel: bool) -> Expr { + let this = self.into_expr(); + let expr2 = expr.clone(); + let func = move |mut c: Column| { + let name = c.name().clone(); + c.rename(PlSmallStr::EMPTY); + + // Ensure we get the new schema. + let output_field = eval_field_to_dtype(c.field().as_ref(), &expr, false); + let schema = Arc::new(Schema::from_iter(std::iter::once(output_field.clone()))); + + let expr = expr.clone(); + let mut arena = Arena::with_capacity(10); + let aexpr = to_expr_ir(expr, &mut arena)?; + let phys_expr = create_physical_expr( + &aexpr, + Context::Default, + &arena, + &schema, + &mut ExpressionConversionState::new(true), + )?; + + let state = ExecutionState::new(); + + let finish = |out: Column| { + polars_ensure!( + out.len() <= 1, + ComputeError: + "expected single value, got a result with length {}, {:?}", + out.len(), out, + ); + Ok(out.get(0).unwrap().into_static()) + }; + + let avs = if parallel { + POOL.install(|| { + (1..c.len() + 1) + .into_par_iter() + .map(|len| { + let c = c.slice(0, len); + if (len - c.null_count()) >= min_periods { + let df = c.clone().into_frame(); + let out = phys_expr.evaluate(&df, &state)?.into_column(); + finish(out) + } else { + Ok(AnyValue::Null) + } + }) + .collect::>>() + })? + } else { + let mut df_container = DataFrame::empty(); + (1..c.len() + 1) + .map(|len| { + let c = c.slice(0, len); + if (len - c.null_count()) >= min_periods { + unsafe { + df_container.with_column_unchecked(c.into_column()); + let out = phys_expr.evaluate(&df_container, &state)?.into_column(); + df_container.clear_columns(); + finish(out) + } + } else { + Ok(AnyValue::Null) + } + }) + .collect::>>()? + }; + let c = Column::new(name, avs); + + if c.dtype() != output_field.dtype() { + c.cast(output_field.dtype()).map(Some) + } else { + Ok(Some(c)) + } + }; + + this.apply( + func, + GetOutput::map_field(move |f| Ok(eval_field_to_dtype(f, &expr2, false))), + ) + .with_fmt("expanding_eval") + } +} + +impl ExprEvalExtension for Expr {} diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs new file mode 100644 index 000000000000..ef7e25f23796 --- /dev/null +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -0,0 +1,135 @@ +//! # Functions +//! +//! Function on multiple expressions. +//! + +use polars_core::prelude::*; +pub use polars_plan::dsl::functions::*; +use polars_plan::prelude::UnionArgs; +use rayon::prelude::*; + +use crate::prelude::*; + +pub(crate) fn concat_impl>( + inputs: L, + args: UnionArgs, +) -> PolarsResult { + let mut inputs = inputs.as_ref().to_vec(); + + let lf = std::mem::take( + inputs + .get_mut(0) + .ok_or_else(|| polars_err!(NoData: "empty container given"))?, + ); + + let opt_state = lf.opt_state; + let cached_arenas = lf.cached_arena.clone(); + + let mut lps = Vec::with_capacity(inputs.len()); + lps.push(lf.logical_plan); + + for lf in &mut inputs[1..] { + let lp = std::mem::take(&mut lf.logical_plan); + lps.push(lp) + } + + let lp = DslPlan::Union { inputs: lps, args }; + Ok(LazyFrame::from_inner(lp, opt_state, cached_arenas)) +} + +#[cfg(feature = "diagonal_concat")] +/// Concat [LazyFrame]s diagonally. +/// Calls [`concat`][concat()] internally. +pub fn concat_lf_diagonal>( + inputs: L, + mut args: UnionArgs, +) -> PolarsResult { + args.diagonal = true; + concat_impl(inputs, args) +} + +/// Concat [LazyFrame]s horizontally. +pub fn concat_lf_horizontal>( + inputs: L, + args: UnionArgs, +) -> PolarsResult { + let lfs = inputs.as_ref(); + let (opt_state, cached_arena) = lfs + .first() + .map(|lf| (lf.opt_state, lf.cached_arena.clone())) + .ok_or_else( + || polars_err!(NoData: "Require at least one LazyFrame for horizontal concatenation"), + )?; + + let options = HConcatOptions { + parallel: args.parallel, + }; + let lp = DslPlan::HConcat { + inputs: lfs.iter().map(|lf| lf.logical_plan.clone()).collect(), + options, + }; + Ok(LazyFrame::from_inner(lp, opt_state, cached_arena)) +} + +/// Concat multiple [`LazyFrame`]s vertically. +pub fn concat>(inputs: L, args: UnionArgs) -> PolarsResult { + concat_impl(inputs, args) +} + +/// Collect all [`LazyFrame`] computations. +pub fn collect_all(lfs: I) -> PolarsResult> +where + I: IntoParallelIterator, +{ + let iter = lfs.into_par_iter(); + + polars_core::POOL.install(|| iter.map(|lf| lf.collect()).collect()) +} + +#[cfg(test)] +mod test { + // used only if feature="diagonal_concat" + #[allow(unused_imports)] + use super::*; + + #[test] + #[cfg(feature = "diagonal_concat")] + fn test_diag_concat_lf() -> PolarsResult<()> { + let a = df![ + "a" => [1, 2], + "b" => ["a", "b"] + ]?; + + let b = df![ + "b" => ["a", "b"], + "c" => [1, 2] + ]?; + + let c = df![ + "a" => [5, 7], + "c" => [1, 2], + "d" => [1, 2] + ]?; + + let out = concat_lf_diagonal( + &[a.lazy(), b.lazy(), c.lazy()], + UnionArgs { + rechunk: false, + parallel: false, + ..Default::default() + }, + )? + .collect()?; + + let expected = df![ + "a" => [Some(1), Some(2), None, None, Some(5), Some(7)], + "b" => [Some("a"), Some("b"), Some("a"), Some("b"), None, None], + "c" => [None, None, Some(1), Some(2), Some(1), Some(2)], + "d" => [None, None, None, None, Some(1), Some(2)] + ]?; + + assert!(out.equals_missing(&expected)); + + Ok(()) + } +} diff --git a/crates/polars-lazy/src/dsl/into.rs b/crates/polars-lazy/src/dsl/into.rs new file mode 100644 index 000000000000..26d3cea20b94 --- /dev/null +++ b/crates/polars-lazy/src/dsl/into.rs @@ -0,0 +1,11 @@ +use polars_plan::dsl::Expr; + +pub trait IntoExpr { + fn into_expr(self) -> Expr; +} + +impl IntoExpr for Expr { + fn into_expr(self) -> Expr { + self + } +} diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs new file mode 100644 index 000000000000..660a6f7f85fd --- /dev/null +++ b/crates/polars-lazy/src/dsl/list.rs @@ -0,0 +1,217 @@ +use std::sync::Mutex; + +use arrow::array::ValueSize; +use arrow::legacy::utils::CustomIterTools; +use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt; +use polars_core::prelude::*; +use polars_plan::constants::MAP_LIST_NAME; +use polars_plan::dsl::*; +use rayon::prelude::*; + +use crate::physical_plan::exotic::prepare_expression_for_context; +use crate::prelude::*; + +pub trait IntoListNameSpace { + fn into_list_name_space(self) -> ListNameSpace; +} + +impl IntoListNameSpace for ListNameSpace { + fn into_list_name_space(self) -> ListNameSpace { + self + } +} + +fn offsets_to_groups(offsets: &[i64]) -> Option { + let mut start = offsets[0]; + let end = *offsets.last().unwrap(); + if IdxSize::try_from(end - start).is_err() { + return None; + } + let groups = offsets + .iter() + .skip(1) + .map(|end| { + let offset = start as IdxSize; + let len = (*end - start) as IdxSize; + start = *end; + [offset, len] + }) + .collect(); + Some( + GroupsType::Slice { + groups, + rolling: false, + } + .into_sliceable(), + ) +} + +fn run_per_sublist( + s: Column, + lst: &ListChunked, + expr: &Expr, + parallel: bool, + output_field: Field, +) -> PolarsResult> { + let phys_expr = prepare_expression_for_context( + PlSmallStr::EMPTY, + expr, + lst.inner_dtype(), + Context::Default, + )?; + + let state = ExecutionState::new(); + + let mut err = None; + let mut ca: ListChunked = if parallel { + let m_err = Mutex::new(None); + let ca: ListChunked = lst + .par_iter() + .map(|opt_s| { + opt_s.and_then(|s| { + let df = s.into_frame(); + let out = phys_expr.evaluate(&df, &state); + match out { + Ok(s) => Some(s.take_materialized_series()), + Err(e) => { + *m_err.lock().unwrap() = Some(e); + None + }, + } + }) + }) + .collect_ca_with_dtype(PlSmallStr::EMPTY, output_field.dtype.clone()); + err = m_err.into_inner().unwrap(); + ca + } else { + let mut df_container = DataFrame::empty(); + + lst.into_iter() + .map(|s| { + s.and_then(|s| unsafe { + df_container.with_column_unchecked(s.into_column()); + let out = phys_expr.evaluate(&df_container, &state); + df_container.clear_columns(); + match out { + Ok(s) => Some(s.take_materialized_series()), + Err(e) => { + err = Some(e); + None + }, + } + }) + }) + .collect_trusted() + }; + if let Some(err) = err { + return Err(err); + } + + ca.rename(s.name().clone()); + + if ca.dtype() != output_field.dtype() { + ca.cast(output_field.dtype()).map(Column::from).map(Some) + } else { + Ok(Some(ca.into_column())) + } +} + +fn run_on_group_by_engine( + name: PlSmallStr, + lst: &ListChunked, + expr: &Expr, +) -> PolarsResult> { + let lst = lst.rechunk(); + let arr = lst.downcast_as_array(); + let groups = offsets_to_groups(arr.offsets()).unwrap(); + + // List elements in a series. + let values = Series::try_from((PlSmallStr::EMPTY, arr.values().clone())).unwrap(); + let inner_dtype = lst.inner_dtype(); + // SAFETY: + // Invariant in List means values physicals can be cast to inner dtype + let values = unsafe { values.from_physical_unchecked(inner_dtype).unwrap() }; + + let df_context = values.into_frame(); + let phys_expr = + prepare_expression_for_context(PlSmallStr::EMPTY, expr, inner_dtype, Context::Aggregation)?; + + let state = ExecutionState::new(); + let mut ac = phys_expr.evaluate_on_groups(&df_context, &groups, &state)?; + let out = match ac.agg_state() { + AggState::AggregatedScalar(_) => { + let out = ac.aggregated(); + out.as_list().into_column() + }, + _ => ac.aggregated(), + }; + Ok(Some(out.with_name(name).into_column())) +} + +pub trait ListNameSpaceExtension: IntoListNameSpace + Sized { + /// Run any [`Expr`] on these lists elements + fn eval(self, expr: Expr, parallel: bool) -> Expr { + let this = self.into_list_name_space(); + + let expr2 = expr.clone(); + let func = move |c: Column| { + for e in expr.into_iter() { + match e { + #[cfg(feature = "dtype-categorical")] + Expr::Cast { + dtype: DataType::Categorical(_, _) | DataType::Enum(_, _), + .. + } => { + polars_bail!( + ComputeError: "casting to categorical not allowed in `list.eval`" + ) + }, + Expr::Column(name) => { + polars_ensure!( + name.is_empty(), + ComputeError: + "named columns are not allowed in `list.eval`; consider using `element` or `col(\"\")`" + ); + }, + _ => {}, + } + } + let lst = c.list()?.clone(); + + // # fast returns + // ensure we get the new schema + let output_field = eval_field_to_dtype(lst.ref_field(), &expr, true); + if lst.is_empty() { + return Ok(Some(Column::new_empty( + c.name().clone(), + output_field.dtype(), + ))); + } + if lst.null_count() == lst.len() { + return Ok(Some(c.cast(output_field.dtype())?.into_column())); + } + + let fits_idx_size = lst.get_values_size() <= (IdxSize::MAX as usize); + // If a users passes a return type to `apply`, e.g. `return_dtype=pl.Int64`, + // this fails as the list builder expects `List`, so let's skip that for now. + let is_user_apply = || { + expr.into_iter().any(|e| matches!(e, Expr::AnonymousFunction { options, .. } if options.fmt_str == MAP_LIST_NAME)) + }; + + if fits_idx_size && c.null_count() == 0 && !is_user_apply() { + run_on_group_by_engine(c.name().clone(), &lst, &expr) + } else { + run_per_sublist(c, &lst, &expr, parallel, output_field) + } + }; + + this.0 + .map( + func, + GetOutput::map_field(move |f| Ok(eval_field_to_dtype(f, &expr2, true))), + ) + .with_fmt("eval") + } +} + +impl ListNameSpaceExtension for ListNameSpace {} diff --git a/crates/polars-lazy/src/dsl/mod.rs b/crates/polars-lazy/src/dsl/mod.rs new file mode 100644 index 000000000000..5c157e4565ae --- /dev/null +++ b/crates/polars-lazy/src/dsl/mod.rs @@ -0,0 +1,50 @@ +//! Domain specific language for the Lazy API. +//! +//! This DSL revolves around the [`Expr`] type, which represents an abstract +//! operation on a DataFrame, such as mapping over a column, filtering, group_by, or aggregation. +//! In general, functions on [`LazyFrame`]s consume the [`LazyFrame`] and produce a new [`LazyFrame`] representing +//! the result of applying the function and passed expressions to the consumed LazyFrame. +//! At runtime, when [`LazyFrame::collect`](crate::frame::LazyFrame::collect) is called, the expressions that comprise +//! the [`LazyFrame`]'s logical plan are materialized on the actual underlying Series. +//! For instance, `let expr = col("x").pow(lit(2)).alias("x2");` would produce an expression representing the abstract +//! operation of squaring the column `"x"` and naming the resulting column `"x2"`, and to apply this operation to a +//! [`LazyFrame`], you'd use `let lazy_df = lazy_df.with_column(expr);`. +//! (Of course, a column named `"x"` must either exist in the original DataFrame or be produced by one of the preceding +//! operations on the [`LazyFrame`].) +//! +//! [`LazyFrame`]: crate::frame::LazyFrame +//! +//! There are many, many free functions that this module exports that produce an [`Expr`] from scratch; [`col`] and +//! [`lit`] are two examples. +//! Expressions also have several methods, such as [`pow`](`Expr::pow`) and [`alias`](`Expr::alias`), that consume them +//! and produce a new expression. +//! +//! Several expressions are only available when the necessary feature is enabled. +//! Examples of features that unlock specialized expression include `string`, `temporal`, and `dtype-categorical`. +//! These specialized expressions provide implementations of functions that you'd otherwise have to implement by hand. +//! +//! Because of how abstract and flexible the [`Expr`] type is, care must be take to ensure you only attempt to perform +//! sensible operations with them. +//! For instance, as mentioned above, you have to make sure any columns you reference already exist in the LazyFrame. +//! Furthermore, there is nothing stopping you from calling, for example, [`any`](`Expr::any`) with an expression +//! that will yield an `f64` column (instead of `bool`), or `col("string") - col("f64")`, which would attempt +//! to subtract an `f64` Series from a `string` Series. +//! These kinds of invalid operations will only yield an error at runtime, when +//! [`collect`](crate::frame::LazyFrame::collect) is called on the [`LazyFrame`]. + +#[cfg(any(feature = "cumulative_eval", feature = "list_eval"))] +mod eval; +pub mod functions; +mod into; +#[cfg(feature = "list_eval")] +mod list; + +#[cfg(any(feature = "cumulative_eval", feature = "list_eval"))] +pub use eval::*; +pub use functions::*; +#[cfg(any(feature = "cumulative_eval", feature = "list_eval"))] +use into::IntoExpr; +#[cfg(feature = "list_eval")] +pub use list::*; +pub use polars_plan::dsl::*; +pub use polars_plan::plans::UdfSchema; diff --git a/crates/polars-lazy/src/frame/cached_arenas.rs b/crates/polars-lazy/src/frame/cached_arenas.rs new file mode 100644 index 000000000000..3f985b8b1298 --- /dev/null +++ b/crates/polars-lazy/src/frame/cached_arenas.rs @@ -0,0 +1,117 @@ +use super::*; + +pub(crate) struct CachedArena { + lp_arena: Arena, + expr_arena: Arena, +} + +impl LazyFrame { + pub fn set_cached_arena(&self, lp_arena: Arena, expr_arena: Arena) { + let mut cached = self.cached_arena.lock().unwrap(); + *cached = Some(CachedArena { + lp_arena, + expr_arena, + }); + } + + pub fn schema_with_arenas( + &mut self, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let node = to_alp( + self.logical_plan.clone(), + expr_arena, + lp_arena, + &mut OptFlags::schema_only(), + )?; + + let schema = lp_arena.get(node).schema(lp_arena).into_owned(); + // Cache the logical plan so that next schema call is cheap. + self.logical_plan = DslPlan::IR { + node: Some(node), + dsl: Arc::new(self.logical_plan.clone()), + version: lp_arena.version(), + }; + Ok(schema) + } + + /// Get a handle to the schema — a map from column names to data types — of the current + /// `LazyFrame` computation. + /// + /// Returns an `Err` if the logical plan has already encountered an error (i.e., if + /// `self.collect()` would fail), `Ok` otherwise. + pub fn collect_schema(&mut self) -> PolarsResult { + let mut cached_arenas = self.cached_arena.lock().unwrap(); + + match &mut *cached_arenas { + None => { + let mut lp_arena = Default::default(); + let mut expr_arena = Default::default(); + // Code duplication because of bchk. :( + let node = to_alp( + self.logical_plan.clone(), + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::schema_only(), + )?; + + let schema = lp_arena.get(node).schema(&lp_arena).into_owned(); + // Cache the logical plan so that next schema call is cheap. + self.logical_plan = DslPlan::IR { + node: Some(node), + dsl: Arc::new(self.logical_plan.clone()), + version: lp_arena.version(), + }; + *cached_arenas = Some(CachedArena { + lp_arena, + expr_arena, + }); + + Ok(schema) + }, + Some(arenas) => { + match self.logical_plan { + // We have got arenas and don't need to convert the DSL. + DslPlan::IR { + node: Some(node), .. + } => Ok(arenas + .lp_arena + .get(node) + .schema(&arenas.lp_arena) + .into_owned()), + _ => { + // We have got arenas, but still need to convert (parts) of the DSL. + // Code duplication because of bchk. :( + let node = to_alp( + self.logical_plan.clone(), + &mut arenas.expr_arena, + &mut arenas.lp_arena, + &mut OptFlags::schema_only(), + )?; + + let schema = arenas + .lp_arena + .get(node) + .schema(&arenas.lp_arena) + .into_owned(); + // Cache the logical plan so that next schema call is cheap. + self.logical_plan = DslPlan::IR { + node: Some(node), + dsl: Arc::new(self.logical_plan.clone()), + version: arenas.lp_arena.version(), + }; + Ok(schema) + }, + } + }, + } + } + + pub(super) fn get_arenas(&mut self) -> (Arena, Arena) { + match self.cached_arena.lock().unwrap().as_mut() { + Some(arenas) => (arenas.lp_arena.clone(), arenas.expr_arena.clone()), + None => (Arena::with_capacity(16), Arena::with_capacity(16)), + } + } +} diff --git a/crates/polars-lazy/src/frame/err.rs b/crates/polars-lazy/src/frame/err.rs new file mode 100644 index 000000000000..af6d4a5cc86a --- /dev/null +++ b/crates/polars-lazy/src/frame/err.rs @@ -0,0 +1,17 @@ +/// Helper to delay a failing method until the query plan is collected +#[macro_export] +macro_rules! fallible { + ($e:expr, $lf:expr) => {{ + use $crate::prelude::*; + match $e { + Ok(e) => e, + Err(err) => { + let lf: LazyFrame = LogicalPlanBuilder::from($lf.clone().logical_plan) + .add_err(err) + .0 + .into(); + return lf; + }, + } + }}; +} diff --git a/crates/polars-lazy/src/frame/exitable.rs b/crates/polars-lazy/src/frame/exitable.rs new file mode 100644 index 000000000000..d18c142c8ca9 --- /dev/null +++ b/crates/polars-lazy/src/frame/exitable.rs @@ -0,0 +1,59 @@ +use std::sync::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::{Receiver, channel}; + +use polars_core::POOL; + +use super::*; + +impl LazyFrame { + pub fn collect_concurrently(self) -> PolarsResult { + let (mut state, mut physical_plan, _) = self.prepare_collect(false, None)?; + + let (tx, rx) = channel(); + let token = state.cancel_token(); + POOL.spawn_fifo(move || { + let result = physical_plan.execute(&mut state); + tx.send(result).unwrap(); + }); + + Ok(InProcessQuery { + rx: Arc::new(Mutex::new(rx)), + token, + }) + } +} + +#[derive(Clone)] +pub struct InProcessQuery { + rx: Arc>>>, + token: Arc, +} + +impl InProcessQuery { + /// Cancel the query at earliest convenience. + pub fn cancel(&self) { + self.token.store(true, Ordering::Relaxed) + } + + /// Fetch the result. + /// + /// If it is ready, a materialized DataFrame is returned. + /// If it is not ready it will return `None`. + pub fn fetch(&self) -> Option> { + let rx = self.rx.lock().unwrap(); + rx.try_recv().ok() + } + + /// Await the result synchronously. + pub fn fetch_blocking(&self) -> PolarsResult { + let rx = self.rx.lock().unwrap(); + rx.recv().unwrap() + } +} + +impl Drop for InProcessQuery { + fn drop(&mut self) { + self.token.store(true, Ordering::Relaxed); + } +} diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs new file mode 100644 index 000000000000..9e74f884795f --- /dev/null +++ b/crates/polars-lazy/src/frame/mod.rs @@ -0,0 +1,2650 @@ +//! Lazy variant of a [DataFrame]. +#[cfg(feature = "python")] +mod python; + +mod cached_arenas; +mod err; +#[cfg(not(target_arch = "wasm32"))] +mod exitable; +#[cfg(feature = "pivot")] +pub mod pivot; + +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + +pub use anonymous_scan::*; +#[cfg(feature = "csv")] +pub use csv::*; +#[cfg(not(target_arch = "wasm32"))] +pub use exitable::*; +pub use file_list_reader::*; +#[cfg(feature = "ipc")] +pub use ipc::*; +#[cfg(feature = "json")] +pub use ndjson::*; +#[cfg(feature = "parquet")] +pub use parquet::*; +use polars_compute::rolling::QuantileMethod; +use polars_core::POOL; +use polars_core::error::feature_gated; +use polars_core::prelude::*; +use polars_expr::{ExpressionConversionState, create_physical_expr}; +use polars_io::RowIndex; +use polars_mem_engine::{Executor, create_multiple_physical_plans, create_physical_plan}; +use polars_ops::frame::{JoinCoalesce, MaintainOrderJoin}; +#[cfg(feature = "is_between")] +use polars_ops::prelude::ClosedInterval; +pub use polars_plan::frame::{AllowedOptimizations, OptFlags}; +use polars_plan::global::FETCH_ROWS; +use polars_utils::pl_str::PlSmallStr; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + +use crate::frame::cached_arenas::CachedArena; +#[cfg(feature = "streaming")] +use crate::physical_plan::streaming::insert_streaming_nodes; +use crate::prelude::*; + +pub trait IntoLazy { + fn lazy(self) -> LazyFrame; +} + +impl IntoLazy for DataFrame { + /// Convert the `DataFrame` into a `LazyFrame` + fn lazy(self) -> LazyFrame { + let lp = DslBuilder::from_existing_df(self).build(); + LazyFrame { + logical_plan: lp, + opt_state: Default::default(), + cached_arena: Default::default(), + } + } +} + +impl IntoLazy for LazyFrame { + fn lazy(self) -> LazyFrame { + self + } +} + +/// Lazy abstraction over an eager `DataFrame`. +/// +/// It really is an abstraction over a logical plan. The methods of this struct will incrementally +/// modify a logical plan until output is requested (via [`collect`](crate::frame::LazyFrame::collect)). +#[derive(Clone, Default)] +#[must_use] +pub struct LazyFrame { + pub logical_plan: DslPlan, + pub(crate) opt_state: OptFlags, + pub(crate) cached_arena: Arc>>, +} + +impl From for LazyFrame { + fn from(plan: DslPlan) -> Self { + Self { + logical_plan: plan, + opt_state: OptFlags::default(), + cached_arena: Default::default(), + } + } +} + +impl LazyFrame { + pub(crate) fn from_inner( + logical_plan: DslPlan, + opt_state: OptFlags, + cached_arena: Arc>>, + ) -> Self { + Self { + logical_plan, + opt_state, + cached_arena, + } + } + + pub(crate) fn get_plan_builder(self) -> DslBuilder { + DslBuilder::from(self.logical_plan) + } + + fn get_opt_state(&self) -> OptFlags { + self.opt_state + } + + fn from_logical_plan(logical_plan: DslPlan, opt_state: OptFlags) -> Self { + LazyFrame { + logical_plan, + opt_state, + cached_arena: Default::default(), + } + } + + /// Get current optimizations. + pub fn get_current_optimizations(&self) -> OptFlags { + self.opt_state + } + + /// Set allowed optimizations. + pub fn with_optimizations(mut self, opt_state: OptFlags) -> Self { + self.opt_state = opt_state; + self + } + + /// Turn off all optimizations. + pub fn without_optimizations(self) -> Self { + self.with_optimizations(OptFlags::from_bits_truncate(0) | OptFlags::TYPE_COERCION) + } + + /// Toggle projection pushdown optimization. + pub fn with_projection_pushdown(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::PROJECTION_PUSHDOWN, toggle); + self + } + + /// Toggle cluster with columns optimization. + pub fn with_cluster_with_columns(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::CLUSTER_WITH_COLUMNS, toggle); + self + } + + /// Toggle collapse joins optimization. + pub fn with_collapse_joins(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::COLLAPSE_JOINS, toggle); + self + } + + /// Check if operations are order dependent and unset maintaining_order if + /// the order would not be observed. + pub fn with_check_order(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::CHECK_ORDER_OBSERVE, toggle); + self + } + + /// Toggle predicate pushdown optimization. + pub fn with_predicate_pushdown(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::PREDICATE_PUSHDOWN, toggle); + self + } + + /// Toggle type coercion optimization. + pub fn with_type_coercion(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::TYPE_COERCION, toggle); + self + } + + /// Toggle type check optimization. + pub fn with_type_check(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::TYPE_CHECK, toggle); + self + } + + /// Toggle expression simplification optimization on or off. + pub fn with_simplify_expr(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::SIMPLIFY_EXPR, toggle); + self + } + + /// Toggle common subplan elimination optimization on or off + #[cfg(feature = "cse")] + pub fn with_comm_subplan_elim(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::COMM_SUBPLAN_ELIM, toggle); + self + } + + /// Toggle common subexpression elimination optimization on or off + #[cfg(feature = "cse")] + pub fn with_comm_subexpr_elim(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::COMM_SUBEXPR_ELIM, toggle); + self + } + + /// Toggle slice pushdown optimization. + pub fn with_slice_pushdown(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::SLICE_PUSHDOWN, toggle); + self + } + + /// Run nodes that are capably of doing so on the streaming engine. + #[cfg(feature = "streaming")] + pub fn with_streaming(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::STREAMING, toggle); + self + } + + #[cfg(feature = "new_streaming")] + pub fn with_new_streaming(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::NEW_STREAMING, toggle); + self + } + + /// Try to estimate the number of rows so that joins can determine which side to keep in memory. + pub fn with_row_estimate(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::ROW_ESTIMATE, toggle); + self + } + + /// Run every node eagerly. This turns off multi-node optimizations. + pub fn _with_eager(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::EAGER, toggle); + self + } + + /// Return a String describing the naive (un-optimized) logical plan. + pub fn describe_plan(&self) -> PolarsResult { + Ok(self.clone().to_alp()?.describe()) + } + + /// Return a String describing the naive (un-optimized) logical plan in tree format. + pub fn describe_plan_tree(&self) -> PolarsResult { + Ok(self.clone().to_alp()?.describe_tree_format()) + } + + // @NOTE: this is used because we want to set the `enable_fmt` flag of `optimize_with_scratch` + // to `true` for describe. + fn _describe_to_alp_optimized(mut self) -> PolarsResult { + let (mut lp_arena, mut expr_arena) = self.get_arenas(); + let node = self.optimize_with_scratch(&mut lp_arena, &mut expr_arena, &mut vec![], true)?; + + Ok(IRPlan::new(node, lp_arena, expr_arena)) + } + + /// Return a String describing the optimized logical plan. + /// + /// Returns `Err` if optimizing the logical plan fails. + pub fn describe_optimized_plan(&self) -> PolarsResult { + Ok(self.clone()._describe_to_alp_optimized()?.describe()) + } + + /// Return a String describing the optimized logical plan in tree format. + /// + /// Returns `Err` if optimizing the logical plan fails. + pub fn describe_optimized_plan_tree(&self) -> PolarsResult { + Ok(self + .clone() + ._describe_to_alp_optimized()? + .describe_tree_format()) + } + + /// Return a String describing the logical plan. + /// + /// If `optimized` is `true`, explains the optimized plan. If `optimized` is `false`, + /// explains the naive, un-optimized plan. + pub fn explain(&self, optimized: bool) -> PolarsResult { + if optimized { + self.describe_optimized_plan() + } else { + self.describe_plan() + } + } + + /// Add a sort operation to the logical plan. + /// + /// Sorts the LazyFrame by the column name specified using the provided options. + /// + /// # Example + /// + /// Sort DataFrame by 'sepal_width' column: + /// ```rust + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// fn sort_by_a(df: DataFrame) -> LazyFrame { + /// df.lazy().sort(["sepal_width"], Default::default()) + /// } + /// ``` + /// Sort by a single column with specific order: + /// ``` + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// fn sort_with_specific_order(df: DataFrame, descending: bool) -> LazyFrame { + /// df.lazy().sort( + /// ["sepal_width"], + /// SortMultipleOptions::new() + /// .with_order_descending(descending) + /// ) + /// } + /// ``` + /// Sort by multiple columns with specifying order for each column: + /// ``` + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// fn sort_by_multiple_columns_with_specific_order(df: DataFrame) -> LazyFrame { + /// df.lazy().sort( + /// ["sepal_width", "sepal_length"], + /// SortMultipleOptions::new() + /// .with_order_descending_multi([false, true]) + /// ) + /// } + /// ``` + /// See [`SortMultipleOptions`] for more options. + pub fn sort(self, by: impl IntoVec, sort_options: SortMultipleOptions) -> Self { + let opt_state = self.get_opt_state(); + let lp = self + .get_plan_builder() + .sort(by.into_vec().into_iter().map(col).collect(), sort_options) + .build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Add a sort operation to the logical plan. + /// + /// Sorts the LazyFrame by the provided list of expressions, which will be turned into + /// concrete columns before sorting. + /// + /// See [`SortMultipleOptions`] for more options. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// /// Sort DataFrame by 'sepal_width' column + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .sort_by_exprs(vec![col("sepal_width")], Default::default()) + /// } + /// ``` + pub fn sort_by_exprs>( + self, + by_exprs: E, + sort_options: SortMultipleOptions, + ) -> Self { + let by_exprs = by_exprs.as_ref().to_vec(); + if by_exprs.is_empty() { + self + } else { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().sort(by_exprs, sort_options).build(); + Self::from_logical_plan(lp, opt_state) + } + } + + pub fn top_k>( + self, + k: IdxSize, + by_exprs: E, + sort_options: SortMultipleOptions, + ) -> Self { + // this will optimize to top-k + self.sort_by_exprs( + by_exprs, + sort_options.with_order_reversed().with_nulls_last(true), + ) + .slice(0, k) + } + + pub fn bottom_k>( + self, + k: IdxSize, + by_exprs: E, + sort_options: SortMultipleOptions, + ) -> Self { + // this will optimize to bottom-k + self.sort_by_exprs(by_exprs, sort_options.with_nulls_last(true)) + .slice(0, k) + } + + /// Reverse the `DataFrame` from top to bottom. + /// + /// Row `i` becomes row `number_of_rows - i - 1`. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .reverse() + /// } + /// ``` + pub fn reverse(self) -> Self { + self.select(vec![col(PlSmallStr::from_static("*")).reverse()]) + } + + /// Rename columns in the DataFrame. + /// + /// `existing` and `new` are iterables of the same length containing the old and + /// corresponding new column names. Renaming happens to all `existing` columns + /// simultaneously, not iteratively. If `strict` is true, all columns in `existing` + /// must be present in the `LazyFrame` when `rename` is called; otherwise, only + /// those columns that are actually found will be renamed (others will be ignored). + pub fn rename(self, existing: I, new: J, strict: bool) -> Self + where + I: IntoIterator, + J: IntoIterator, + T: AsRef, + S: AsRef, + { + let iter = existing.into_iter(); + let cap = iter.size_hint().0; + let mut existing_vec: Vec = Vec::with_capacity(cap); + let mut new_vec: Vec = Vec::with_capacity(cap); + + // TODO! should this error if `existing` and `new` have different lengths? + // Currently, the longer of the two is truncated. + for (existing, new) in iter.zip(new) { + let existing = existing.as_ref(); + let new = new.as_ref(); + if new != existing { + existing_vec.push(existing.into()); + new_vec.push(new.into()); + } + } + + self.map_private(DslFunction::Rename { + existing: existing_vec.into(), + new: new_vec.into(), + strict, + }) + } + + /// Removes columns from the DataFrame. + /// Note that it's better to only select the columns you need + /// and let the projection pushdown optimize away the unneeded columns. + /// + /// If `strict` is `true`, then any given columns that are not in the schema will + /// give a [`PolarsError::ColumnNotFound`] error while materializing the [`LazyFrame`]. + fn _drop(self, columns: I, strict: bool) -> Self + where + I: IntoIterator, + T: Into, + { + let to_drop = columns.into_iter().map(|c| c.into()).collect(); + + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().drop(to_drop, strict).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Removes columns from the DataFrame. + /// Note that it's better to only select the columns you need + /// and let the projection pushdown optimize away the unneeded columns. + /// + /// Any given columns that are not in the schema will give a [`PolarsError::ColumnNotFound`] + /// error while materializing the [`LazyFrame`]. + pub fn drop(self, columns: I) -> Self + where + I: IntoIterator, + T: Into, + { + self._drop(columns, true) + } + + /// Removes columns from the DataFrame. + /// Note that it's better to only select the columns you need + /// and let the projection pushdown optimize away the unneeded columns. + /// + /// If a column name does not exist in the schema, it will quietly be ignored. + pub fn drop_no_validate(self, columns: I) -> Self + where + I: IntoIterator, + T: Into, + { + self._drop(columns, false) + } + + /// Shift the values by a given period and fill the parts that will be empty due to this operation + /// with `Nones`. + /// + /// See the method on [Series](polars_core::series::SeriesTrait::shift) for more info on the `shift` operation. + pub fn shift>(self, n: E) -> Self { + self.select(vec![col(PlSmallStr::from_static("*")).shift(n.into())]) + } + + /// Shift the values by a given period and fill the parts that will be empty due to this operation + /// with the result of the `fill_value` expression. + /// + /// See the method on [Series](polars_core::series::SeriesTrait::shift) for more info on the `shift` operation. + pub fn shift_and_fill, IE: Into>(self, n: E, fill_value: IE) -> Self { + self.select(vec![ + col(PlSmallStr::from_static("*")).shift_and_fill(n.into(), fill_value.into()), + ]) + } + + /// Fill None values in the DataFrame with an expression. + pub fn fill_null>(self, fill_value: E) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().fill_null(fill_value.into()).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Fill NaN values in the DataFrame with an expression. + pub fn fill_nan>(self, fill_value: E) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().fill_nan(fill_value.into()).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Caches the result into a new LazyFrame. + /// + /// This should be used to prevent computations running multiple times. + pub fn cache(self) -> Self { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().cache().build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Cast named frame columns, resulting in a new LazyFrame with updated dtypes + pub fn cast(self, dtypes: PlHashMap<&str, DataType>, strict: bool) -> Self { + let cast_cols: Vec = dtypes + .into_iter() + .map(|(name, dt)| { + let name = PlSmallStr::from_str(name); + + if strict { + col(name).strict_cast(dt) + } else { + col(name).cast(dt) + } + }) + .collect(); + + if cast_cols.is_empty() { + self.clone() + } else { + self.with_columns(cast_cols) + } + } + + /// Cast all frame columns to the given dtype, resulting in a new LazyFrame + pub fn cast_all(self, dtype: DataType, strict: bool) -> Self { + self.with_columns(vec![if strict { + col(PlSmallStr::from_static("*")).strict_cast(dtype) + } else { + col(PlSmallStr::from_static("*")).cast(dtype) + }]) + } + + /// Fetch is like a collect operation, but it overwrites the number of rows read by every scan + /// operation. This is a utility that helps debug a query on a smaller number of rows. + /// + /// Note that the fetch does not guarantee the final number of rows in the DataFrame. + /// Filter, join operations and a lower number of rows available in the scanned file influence + /// the final number of rows. + pub fn fetch(self, n_rows: usize) -> PolarsResult { + FETCH_ROWS.with(|fetch_rows| fetch_rows.set(Some(n_rows))); + let res = self.collect(); + FETCH_ROWS.with(|fetch_rows| fetch_rows.set(None)); + res + } + + pub fn optimize( + self, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + self.optimize_with_scratch(lp_arena, expr_arena, &mut vec![], false) + } + + pub fn to_alp_optimized(mut self) -> PolarsResult { + let (mut lp_arena, mut expr_arena) = self.get_arenas(); + let node = + self.optimize_with_scratch(&mut lp_arena, &mut expr_arena, &mut vec![], false)?; + + Ok(IRPlan::new(node, lp_arena, expr_arena)) + } + + pub fn to_alp(mut self) -> PolarsResult { + let (mut lp_arena, mut expr_arena) = self.get_arenas(); + let node = to_alp( + self.logical_plan, + &mut expr_arena, + &mut lp_arena, + &mut self.opt_state, + )?; + let plan = IRPlan::new(node, lp_arena, expr_arena); + Ok(plan) + } + + pub(crate) fn optimize_with_scratch( + self, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + scratch: &mut Vec, + enable_fmt: bool, + ) -> PolarsResult { + #[allow(unused_mut)] + let mut opt_state = self.opt_state; + let streaming = self.opt_state.contains(OptFlags::STREAMING); + let new_streaming = self.opt_state.contains(OptFlags::NEW_STREAMING); + #[cfg(feature = "cse")] + if streaming && !new_streaming { + opt_state &= !OptFlags::COMM_SUBPLAN_ELIM; + } + + #[cfg(feature = "cse")] + if new_streaming { + // The new streaming engine can't deal with the way the common + // subexpression elimination adds length-incorrect with_columns. + opt_state &= !OptFlags::COMM_SUBEXPR_ELIM; + } + + let lp_top = optimize( + self.logical_plan, + opt_state, + lp_arena, + expr_arena, + scratch, + Some(&|expr, expr_arena, schema| { + let phys_expr = create_physical_expr( + expr, + Context::Default, + expr_arena, + schema, + &mut ExpressionConversionState::new(true), + ) + .ok()?; + let io_expr = phys_expr_to_io_expr(phys_expr); + Some(io_expr) + }), + )?; + + if streaming { + #[cfg(feature = "streaming")] + { + insert_streaming_nodes( + lp_top, + lp_arena, + expr_arena, + scratch, + enable_fmt, + true, + opt_state.contains(OptFlags::ROW_ESTIMATE), + )?; + } + #[cfg(not(feature = "streaming"))] + { + _ = enable_fmt; + panic!("activate feature 'streaming'") + } + } + + Ok(lp_top) + } + + fn prepare_collect_post_opt

( + mut self, + check_sink: bool, + query_start: Option, + post_opt: P, + ) -> PolarsResult<(ExecutionState, Box, bool)> + where + P: FnOnce( + Node, + &mut Arena, + &mut Arena, + Option, + ) -> PolarsResult<()>, + { + let (mut lp_arena, mut expr_arena) = self.get_arenas(); + + let mut scratch = vec![]; + let lp_top = + self.optimize_with_scratch(&mut lp_arena, &mut expr_arena, &mut scratch, false)?; + + post_opt( + lp_top, + &mut lp_arena, + &mut expr_arena, + // Post optimization callback gets the time since the + // query was started as its "base" timepoint. + query_start.map(|s| s.elapsed()), + )?; + + // sink should be replaced + let no_file_sink = if check_sink { + !matches!( + lp_arena.get(lp_top), + IR::Sink { + payload: SinkTypeIR::File { .. } | SinkTypeIR::Partition { .. }, + .. + } + ) + } else { + true + }; + let physical_plan = create_physical_plan( + lp_top, + &mut lp_arena, + &mut expr_arena, + BUILD_STREAMING_EXECUTOR, + )?; + + let state = ExecutionState::new(); + Ok((state, physical_plan, no_file_sink)) + } + + // post_opt: A function that is called after optimization. This can be used to modify the IR jit. + pub fn _collect_post_opt

(self, post_opt: P) -> PolarsResult + where + P: FnOnce( + Node, + &mut Arena, + &mut Arena, + Option, + ) -> PolarsResult<()>, + { + let (mut state, mut physical_plan, _) = + self.prepare_collect_post_opt(false, None, post_opt)?; + physical_plan.execute(&mut state) + } + + #[allow(unused_mut)] + fn prepare_collect( + self, + check_sink: bool, + query_start: Option, + ) -> PolarsResult<(ExecutionState, Box, bool)> { + self.prepare_collect_post_opt(check_sink, query_start, |_, _, _, _| Ok(())) + } + + /// Execute all the lazy operations and collect them into a [`DataFrame`] using a specified + /// `engine`. + /// + /// The query is optimized prior to execution. + pub fn collect_with_engine(mut self, mut engine: Engine) -> PolarsResult { + let payload = if let DslPlan::Sink { payload, .. } = &self.logical_plan { + payload.clone() + } else { + self.logical_plan = DslPlan::Sink { + input: Arc::new(self.logical_plan), + payload: SinkType::Memory, + }; + SinkType::Memory + }; + + // Default engine for collect is InMemory, sink_* is Streaming + if engine == Engine::Auto { + engine = match payload { + #[cfg(feature = "new_streaming")] + SinkType::File { .. } | SinkType::Partition { .. } => Engine::Streaming, + _ => Engine::InMemory, + }; + } + // Gpu uses some hacks to dispatch. + if engine == Engine::Gpu { + engine = Engine::InMemory; + } + + #[cfg(feature = "new_streaming")] + { + if let Some(result) = self.try_new_streaming_if_requested() { + return result.map(|v| v.unwrap_single()); + } + } + + match engine { + Engine::Auto => unreachable!(), + Engine::Streaming => { + feature_gated!("new_streaming", self = self.with_new_streaming(true)) + }, + Engine::OldStreaming => feature_gated!("streaming", self = self.with_streaming(true)), + _ => {}, + } + let mut alp_plan = self.clone().to_alp_optimized()?; + + match engine { + Engine::Auto | Engine::Streaming => feature_gated!("new_streaming", { + let string_cache_hold = StringCacheHolder::hold(); + let result = polars_stream::run_query( + alp_plan.lp_top, + &mut alp_plan.lp_arena, + &mut alp_plan.expr_arena, + ); + drop(string_cache_hold); + result.map(|v| v.unwrap_single()) + }), + _ if matches!(payload, SinkType::Partition { .. }) => Err(polars_err!( + InvalidOperation: "partition sinks are not supported on for the '{}' engine", + engine.into_static_str() + )), + Engine::Gpu => { + Err(polars_err!(InvalidOperation: "sink is not supported for the gpu engine")) + }, + Engine::InMemory => { + let mut physical_plan = create_physical_plan( + alp_plan.lp_top, + &mut alp_plan.lp_arena, + &mut alp_plan.expr_arena, + BUILD_STREAMING_EXECUTOR, + )?; + let mut state = ExecutionState::new(); + physical_plan.execute(&mut state) + }, + Engine::OldStreaming => { + self.opt_state |= OptFlags::STREAMING; + let (mut state, mut physical_plan, is_streaming) = + self.prepare_collect(true, None)?; + polars_ensure!( + is_streaming, + ComputeError: format!("cannot run the whole query in a streaming order") + ); + physical_plan.execute(&mut state) + }, + } + } + + pub fn explain_all(plans: Vec, opt_state: OptFlags) -> PolarsResult { + let sink_multiple = LazyFrame { + logical_plan: DslPlan::SinkMultiple { inputs: plans }, + opt_state, + cached_arena: Default::default(), + }; + sink_multiple.explain(true) + } + + pub fn collect_all_with_engine( + plans: Vec, + mut engine: Engine, + opt_state: OptFlags, + ) -> PolarsResult> { + if plans.is_empty() { + return Ok(Vec::new()); + } + + // Default engine for collect_all is InMemory + if engine == Engine::Auto { + engine = Engine::InMemory; + } + // Gpu uses some hacks to dispatch. + if engine == Engine::Gpu { + engine = Engine::InMemory; + } + + let mut sink_multiple = LazyFrame { + logical_plan: DslPlan::SinkMultiple { inputs: plans }, + opt_state, + cached_arena: Default::default(), + }; + + #[cfg(feature = "new_streaming")] + { + if let Some(result) = sink_multiple.try_new_streaming_if_requested() { + return result.map(|v| v.unwrap_multiple()); + } + } + + match engine { + Engine::Auto => unreachable!(), + Engine::Streaming => { + feature_gated!( + "new_streaming", + sink_multiple = sink_multiple.with_new_streaming(true) + ) + }, + Engine::OldStreaming => feature_gated!( + "streaming", + sink_multiple = sink_multiple.with_streaming(true) + ), + _ => {}, + } + let mut alp_plan = sink_multiple.to_alp_optimized()?; + + if engine == Engine::Streaming { + feature_gated!("new_streaming", { + let string_cache_hold = StringCacheHolder::hold(); + let result = polars_stream::run_query( + alp_plan.lp_top, + &mut alp_plan.lp_arena, + &mut alp_plan.expr_arena, + ); + drop(string_cache_hold); + return result.map(|v| v.unwrap_multiple()); + }); + } + + let IR::SinkMultiple { inputs } = alp_plan.root() else { + unreachable!() + }; + + let mut multiplan = create_multiple_physical_plans( + inputs.clone().as_slice(), + &mut alp_plan.lp_arena, + &mut alp_plan.expr_arena, + BUILD_STREAMING_EXECUTOR, + )?; + + match engine { + Engine::Gpu => polars_bail!( + InvalidOperation: "collect_all is not supported for the gpu engine" + ), + Engine::InMemory => { + // We don't use par_iter directly because the LP may also start threads for every LP (for instance scan_csv) + // this might then lead to a rayon SO. So we take a multitude of the threads to keep work stealing + // within bounds + let mut state = ExecutionState::new(); + if let Some(mut cache_prefiller) = multiplan.cache_prefiller { + cache_prefiller.execute(&mut state)?; + } + let out = POOL.install(|| { + multiplan + .physical_plans + .chunks_mut(POOL.current_num_threads() * 3) + .map(|chunk| { + chunk + .into_par_iter() + .enumerate() + .map(|(idx, input)| { + let mut input = std::mem::take(input); + let mut state = state.split(); + state.branch_idx += idx; + + let df = input.execute(&mut state)?; + Ok(df) + }) + .collect::>>() + }) + .collect::>>() + }); + Ok(out?.into_iter().flatten().collect()) + }, + Engine::OldStreaming => panic!("This is no longer supported"), + _ => unreachable!(), + } + } + + /// Execute all the lazy operations and collect them into a [`DataFrame`]. + /// + /// The query is optimized prior to execution. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(df: DataFrame) -> PolarsResult { + /// df.lazy() + /// .group_by([col("foo")]) + /// .agg([col("bar").sum(), col("ham").mean().alias("avg_ham")]) + /// .collect() + /// } + /// ``` + pub fn collect(self) -> PolarsResult { + self.collect_with_engine(Engine::InMemory) + } + + // post_opt: A function that is called after optimization. This can be used to modify the IR jit. + // This version does profiling of the node execution. + pub fn _profile_post_opt

(self, post_opt: P) -> PolarsResult<(DataFrame, DataFrame)> + where + P: FnOnce( + Node, + &mut Arena, + &mut Arena, + Option, + ) -> PolarsResult<()>, + { + let query_start = std::time::Instant::now(); + let (mut state, mut physical_plan, _) = + self.prepare_collect_post_opt(false, Some(query_start), post_opt)?; + state.time_nodes(query_start); + let out = physical_plan.execute(&mut state)?; + let timer_df = state.finish_timer()?; + Ok((out, timer_df)) + } + + /// Profile a LazyFrame. + /// + /// This will run the query and return a tuple + /// containing the materialized DataFrame and a DataFrame that contains profiling information + /// of each node that is executed. + /// + /// The units of the timings are microseconds. + pub fn profile(self) -> PolarsResult<(DataFrame, DataFrame)> { + self._profile_post_opt(|_, _, _, _| Ok(())) + } + + /// Stream a query result into a parquet file. This is useful if the final result doesn't fit + /// into memory. This methods will return an error if the query cannot be completely done in a + /// streaming fashion. + #[cfg(feature = "parquet")] + pub fn sink_parquet( + self, + target: SinkTarget, + options: ParquetWriteOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::File(FileSinkType { + target, + sink_options, + file_type: FileType::Parquet(options), + cloud_options, + })) + } + + /// Stream a query result into an ipc/arrow file. This is useful if the final result doesn't fit + /// into memory. This methods will return an error if the query cannot be completely done in a + /// streaming fashion. + #[cfg(feature = "ipc")] + pub fn sink_ipc( + self, + target: SinkTarget, + options: IpcWriterOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::File(FileSinkType { + target, + sink_options, + file_type: FileType::Ipc(options), + cloud_options, + })) + } + + /// Stream a query result into an csv file. This is useful if the final result doesn't fit + /// into memory. This methods will return an error if the query cannot be completely done in a + /// streaming fashion. + #[cfg(feature = "csv")] + pub fn sink_csv( + self, + target: SinkTarget, + options: CsvWriterOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::File(FileSinkType { + target, + sink_options, + file_type: FileType::Csv(options), + cloud_options, + })) + } + + /// Stream a query result into a JSON file. This is useful if the final result doesn't fit + /// into memory. This methods will return an error if the query cannot be completely done in a + /// streaming fashion. + #[cfg(feature = "json")] + pub fn sink_json( + self, + target: SinkTarget, + options: JsonWriterOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::File(FileSinkType { + target, + sink_options, + file_type: FileType::Json(options), + cloud_options, + })) + } + + /// Stream a query result into a parquet file in a partitioned manner. This is useful if the + /// final result doesn't fit into memory. This methods will return an error if the query cannot + /// be completely done in a streaming fashion. + #[cfg(feature = "parquet")] + pub fn sink_parquet_partitioned( + self, + base_path: Arc, + file_path_cb: Option, + variant: PartitionVariant, + options: ParquetWriteOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::Partition(PartitionSinkType { + base_path, + file_path_cb, + sink_options, + variant, + file_type: FileType::Parquet(options), + cloud_options, + })) + } + + /// Stream a query result into an ipc/arrow file in a partitioned manner. This is useful if the + /// final result doesn't fit into memory. This methods will return an error if the query cannot + /// be completely done in a streaming fashion. + #[cfg(feature = "ipc")] + pub fn sink_ipc_partitioned( + self, + base_path: Arc, + file_path_cb: Option, + variant: PartitionVariant, + options: IpcWriterOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::Partition(PartitionSinkType { + base_path, + file_path_cb, + sink_options, + variant, + file_type: FileType::Ipc(options), + cloud_options, + })) + } + + /// Stream a query result into an csv file in a partitioned manner. This is useful if the final + /// result doesn't fit into memory. This methods will return an error if the query cannot be + /// completely done in a streaming fashion. + #[cfg(feature = "csv")] + pub fn sink_csv_partitioned( + self, + base_path: Arc, + file_path_cb: Option, + variant: PartitionVariant, + options: CsvWriterOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::Partition(PartitionSinkType { + base_path, + file_path_cb, + sink_options, + variant, + file_type: FileType::Csv(options), + cloud_options, + })) + } + + /// Stream a query result into a JSON file in a partitioned manner. This is useful if the final + /// result doesn't fit into memory. This methods will return an error if the query cannot be + /// completely done in a streaming fashion. + #[cfg(feature = "json")] + pub fn sink_json_partitioned( + self, + base_path: Arc, + file_path_cb: Option, + variant: PartitionVariant, + options: JsonWriterOptions, + cloud_options: Option, + sink_options: SinkOptions, + ) -> PolarsResult { + self.sink(SinkType::Partition(PartitionSinkType { + base_path, + file_path_cb, + sink_options, + variant, + file_type: FileType::Json(options), + cloud_options, + })) + } + + #[cfg(feature = "new_streaming")] + pub fn try_new_streaming_if_requested( + &mut self, + ) -> Option> { + let auto_new_streaming = std::env::var("POLARS_AUTO_NEW_STREAMING").as_deref() == Ok("1"); + let force_new_streaming = std::env::var("POLARS_FORCE_NEW_STREAMING").as_deref() == Ok("1"); + + if auto_new_streaming || force_new_streaming { + // Try to run using the new streaming engine, falling back + // if it fails in a todo!() error if auto_new_streaming is set. + let mut new_stream_lazy = self.clone(); + new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING; + new_stream_lazy.opt_state &= !OptFlags::STREAMING; + let mut alp_plan = match new_stream_lazy.to_alp_optimized() { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }; + + let _hold = StringCacheHolder::hold(); + let f = || { + polars_stream::run_query( + alp_plan.lp_top, + &mut alp_plan.lp_arena, + &mut alp_plan.expr_arena, + ) + }; + + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { + Ok(v) => return Some(v), + Err(e) => { + // Fallback to normal engine if error is due to not being implemented + // and auto_new_streaming is set, otherwise propagate error. + if !force_new_streaming + && auto_new_streaming + && e.downcast_ref::<&str>() + .map(|s| s.starts_with("not yet implemented")) + .unwrap_or(false) + { + if polars_core::config::verbose() { + eprintln!( + "caught unimplemented error in new streaming engine, falling back to normal engine" + ); + } + } else { + std::panic::resume_unwind(e); + } + }, + } + } + + None + } + + fn sink(mut self, payload: SinkType) -> Result { + polars_ensure!( + !matches!(self.logical_plan, DslPlan::Sink { .. }), + InvalidOperation: "cannot create a sink on top of another sink" + ); + self.logical_plan = DslPlan::Sink { + input: Arc::new(self.logical_plan), + payload: payload.clone(), + }; + Ok(self) + } + + /// Filter frame rows that match a predicate expression. + /// + /// The expression must yield boolean values (note that rows where the + /// predicate resolves to `null` are *not* included in the resulting frame). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .filter(col("sepal_width").is_not_null()) + /// .select([col("sepal_width"), col("sepal_length")]) + /// } + /// ``` + pub fn filter(self, predicate: Expr) -> Self { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().filter(predicate).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Remove frame rows that match a predicate expression. + /// + /// The expression must yield boolean values (note that rows where the + /// predicate resolves to `null` are *not* removed from the resulting frame). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .remove(col("sepal_width").is_null()) + /// .select([col("sepal_width"), col("sepal_length")]) + /// } + /// ``` + pub fn remove(self, predicate: Expr) -> Self { + self.filter(predicate.neq_missing(lit(true))) + } + + /// Select (and optionally rename, with [`alias`](crate::dsl::Expr::alias)) columns from the query. + /// + /// Columns can be selected with [`col`]; + /// If you want to select all columns use `col(PlSmallStr::from_static("*"))`. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// /// This function selects column "foo" and column "bar". + /// /// Column "bar" is renamed to "ham". + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .select([col("foo"), + /// col("bar").alias("ham")]) + /// } + /// + /// /// This function selects all columns except "foo" + /// fn exclude_a_column(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .select([col(PlSmallStr::from_static("*")).exclude(["foo"])]) + /// } + /// ``` + pub fn select>(self, exprs: E) -> Self { + let exprs = exprs.as_ref().to_vec(); + self.select_impl( + exprs, + ProjectionOptions { + run_parallel: true, + duplicate_check: true, + should_broadcast: true, + }, + ) + } + + pub fn select_seq>(self, exprs: E) -> Self { + let exprs = exprs.as_ref().to_vec(); + self.select_impl( + exprs, + ProjectionOptions { + run_parallel: false, + duplicate_check: true, + should_broadcast: true, + }, + ) + } + + fn select_impl(self, exprs: Vec, options: ProjectionOptions) -> Self { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().project(exprs, options).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Performs a "group-by" on a `LazyFrame`, producing a [`LazyGroupBy`], which can subsequently be aggregated. + /// + /// Takes a list of expressions to group on. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .group_by([col("date")]) + /// .agg([ + /// col("rain").min().alias("min_rain"), + /// col("rain").sum().alias("sum_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), + /// ]) + /// } + /// ``` + pub fn group_by, IE: Into + Clone>(self, by: E) -> LazyGroupBy { + let keys = by + .as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>(); + let opt_state = self.get_opt_state(); + + #[cfg(feature = "dynamic_group_by")] + { + LazyGroupBy { + logical_plan: self.logical_plan, + opt_state, + keys, + maintain_order: false, + dynamic_options: None, + rolling_options: None, + } + } + + #[cfg(not(feature = "dynamic_group_by"))] + { + LazyGroupBy { + logical_plan: self.logical_plan, + opt_state, + keys, + maintain_order: false, + } + } + } + + /// Create rolling groups based on a time column. + /// + /// Also works for index values of type UInt32, UInt64, Int32, or Int64. + /// + /// Different from a [`group_by_dynamic`][`Self::group_by_dynamic`], the windows are now determined by the + /// individual values and are not of constant intervals. For constant intervals use + /// *group_by_dynamic* + #[cfg(feature = "dynamic_group_by")] + pub fn rolling>( + mut self, + index_column: Expr, + group_by: E, + mut options: RollingGroupOptions, + ) -> LazyGroupBy { + if let Expr::Column(name) = index_column { + options.index_column = name; + } else { + let output_field = index_column + .to_field(&self.collect_schema().unwrap(), Context::Default) + .unwrap(); + return self.with_column(index_column).rolling( + Expr::Column(output_field.name().clone()), + group_by, + options, + ); + } + let opt_state = self.get_opt_state(); + LazyGroupBy { + logical_plan: self.logical_plan, + opt_state, + keys: group_by.as_ref().to_vec(), + maintain_order: true, + dynamic_options: None, + rolling_options: Some(options), + } + } + + /// Group based on a time value (or index value of type Int32, Int64). + /// + /// Time windows are calculated and rows are assigned to windows. Different from a + /// normal group_by is that a row can be member of multiple groups. The time/index + /// window could be seen as a rolling window, with a window size determined by + /// dates/times/values instead of slots in the DataFrame. + /// + /// A window is defined by: + /// + /// - every: interval of the window + /// - period: length of the window + /// - offset: offset of the window + /// + /// The `group_by` argument should be empty `[]` if you don't want to combine this + /// with a ordinary group_by on these keys. + #[cfg(feature = "dynamic_group_by")] + pub fn group_by_dynamic>( + mut self, + index_column: Expr, + group_by: E, + mut options: DynamicGroupOptions, + ) -> LazyGroupBy { + if let Expr::Column(name) = index_column { + options.index_column = name; + } else { + let output_field = index_column + .to_field(&self.collect_schema().unwrap(), Context::Default) + .unwrap(); + return self.with_column(index_column).group_by_dynamic( + Expr::Column(output_field.name().clone()), + group_by, + options, + ); + } + let opt_state = self.get_opt_state(); + LazyGroupBy { + logical_plan: self.logical_plan, + opt_state, + keys: group_by.as_ref().to_vec(), + maintain_order: true, + dynamic_options: Some(options), + rolling_options: None, + } + } + + /// Similar to [`group_by`][`Self::group_by`], but order of the DataFrame is maintained. + pub fn group_by_stable, IE: Into + Clone>(self, by: E) -> LazyGroupBy { + let keys = by + .as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>(); + let opt_state = self.get_opt_state(); + + #[cfg(feature = "dynamic_group_by")] + { + LazyGroupBy { + logical_plan: self.logical_plan, + opt_state, + keys, + maintain_order: true, + dynamic_options: None, + rolling_options: None, + } + } + + #[cfg(not(feature = "dynamic_group_by"))] + { + LazyGroupBy { + logical_plan: self.logical_plan, + opt_state, + keys, + maintain_order: true, + } + } + } + + /// Left anti join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn anti_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .anti_join(other, col("foo"), col("bar").cast(DataType::String)) + /// } + /// ``` + #[cfg(feature = "semi_anti_join")] + pub fn anti_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Anti), + ) + } + + /// Creates the Cartesian product from both frames, preserving the order of the left keys. + #[cfg(feature = "cross_join")] + pub fn cross_join(self, other: LazyFrame, suffix: Option) -> LazyFrame { + self.join( + other, + vec![], + vec![], + JoinArgs::new(JoinType::Cross).with_suffix(suffix), + ) + } + + /// Left outer join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn left_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .left_join(other, col("foo"), col("bar")) + /// } + /// ``` + pub fn left_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Left), + ) + } + + /// Inner join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn inner_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .inner_join(other, col("foo"), col("bar").cast(DataType::String)) + /// } + /// ``` + pub fn inner_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Inner), + ) + } + + /// Full outer join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn full_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .full_join(other, col("foo"), col("bar")) + /// } + /// ``` + pub fn full_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Full), + ) + } + + /// Left semi join this query with another lazy query. + /// + /// Matches on the values of the expressions `left_on` and `right_on`. For more + /// flexible join logic, see [`join`](LazyFrame::join) or + /// [`join_builder`](LazyFrame::join_builder). + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn semi_join_dataframes(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .semi_join(other, col("foo"), col("bar").cast(DataType::String)) + /// } + /// ``` + #[cfg(feature = "semi_anti_join")] + pub fn semi_join>(self, other: LazyFrame, left_on: E, right_on: E) -> LazyFrame { + self.join( + other, + [left_on.into()], + [right_on.into()], + JoinArgs::new(JoinType::Semi), + ) + } + + /// Generic function to join two LazyFrames. + /// + /// `join` can join on multiple columns, given as two list of expressions, and with a + /// [`JoinType`] specified by `how`. Non-joined column names in the right DataFrame + /// that already exist in this DataFrame are suffixed with `"_right"`. For control + /// over how columns are renamed and parallelization options, use + /// [`join_builder`](LazyFrame::join_builder). + /// + /// Any provided `args.slice` parameter is not considered, but set by the internal optimizer. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(ldf: LazyFrame, other: LazyFrame) -> LazyFrame { + /// ldf + /// .join(other, [col("foo"), col("bar")], [col("foo"), col("bar")], JoinArgs::new(JoinType::Inner)) + /// } + /// ``` + pub fn join>( + self, + other: LazyFrame, + left_on: E, + right_on: E, + args: JoinArgs, + ) -> LazyFrame { + let left_on = left_on.as_ref().to_vec(); + let right_on = right_on.as_ref().to_vec(); + + self._join_impl(other, left_on, right_on, args) + } + + fn _join_impl( + self, + other: LazyFrame, + left_on: Vec, + right_on: Vec, + args: JoinArgs, + ) -> LazyFrame { + let JoinArgs { + how, + validation, + suffix, + slice, + nulls_equal, + coalesce, + maintain_order, + } = args; + + if slice.is_some() { + panic!("impl error: slice is not handled") + } + + let mut builder = self + .join_builder() + .with(other) + .left_on(left_on) + .right_on(right_on) + .how(how) + .validate(validation) + .join_nulls(nulls_equal) + .coalesce(coalesce) + .maintain_order(maintain_order); + + if let Some(suffix) = suffix { + builder = builder.suffix(suffix); + } + + // Note: args.slice is set by the optimizer + builder.finish() + } + + /// Consume `self` and return a [`JoinBuilder`] to customize a join on this LazyFrame. + /// + /// After the `JoinBuilder` has been created and set up, calling + /// [`finish()`](JoinBuilder::finish) on it will give back the `LazyFrame` + /// representing the `join` operation. + pub fn join_builder(self) -> JoinBuilder { + JoinBuilder::new(self) + } + + /// Add or replace a column, given as an expression, to a DataFrame. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn add_column(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .with_column( + /// when(col("sepal_length").lt(lit(5.0))) + /// .then(lit(10)) + /// .otherwise(lit(1)) + /// .alias("new_column_name"), + /// ) + /// } + /// ``` + pub fn with_column(self, expr: Expr) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self + .get_plan_builder() + .with_columns( + vec![expr], + ProjectionOptions { + run_parallel: false, + duplicate_check: true, + should_broadcast: true, + }, + ) + .build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Add or replace multiple columns, given as expressions, to a DataFrame. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// fn add_columns(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .with_columns( + /// vec![lit(10).alias("foo"), lit(100).alias("bar")] + /// ) + /// } + /// ``` + pub fn with_columns>(self, exprs: E) -> LazyFrame { + let exprs = exprs.as_ref().to_vec(); + self.with_columns_impl( + exprs, + ProjectionOptions { + run_parallel: true, + duplicate_check: true, + should_broadcast: true, + }, + ) + } + + /// Add or replace multiple columns to a DataFrame, but evaluate them sequentially. + pub fn with_columns_seq>(self, exprs: E) -> LazyFrame { + let exprs = exprs.as_ref().to_vec(); + self.with_columns_impl( + exprs, + ProjectionOptions { + run_parallel: false, + duplicate_check: true, + should_broadcast: true, + }, + ) + } + + fn with_columns_impl(self, exprs: Vec, options: ProjectionOptions) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().with_columns(exprs, options).build(); + Self::from_logical_plan(lp, opt_state) + } + + pub fn with_context>(self, contexts: C) -> LazyFrame { + let contexts = contexts + .as_ref() + .iter() + .map(|lf| lf.logical_plan.clone()) + .collect(); + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().with_context(contexts).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Aggregate all the columns as their maximum values. + /// + /// Aggregated columns will have the same names as the original columns. + pub fn max(self) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Max)) + } + + /// Aggregate all the columns as their minimum values. + /// + /// Aggregated columns will have the same names as the original columns. + pub fn min(self) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Min)) + } + + /// Aggregate all the columns as their sum values. + /// + /// Aggregated columns will have the same names as the original columns. + /// + /// - Boolean columns will sum to a `u32` containing the number of `true`s. + /// - For integer columns, the ordinary checks for overflow are performed: + /// if running in `debug` mode, overflows will panic, whereas in `release` mode overflows will + /// silently wrap. + /// - String columns will sum to None. + pub fn sum(self) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Sum)) + } + + /// Aggregate all the columns as their mean values. + /// + /// - Boolean and integer columns are converted to `f64` before computing the mean. + /// - String columns will have a mean of None. + pub fn mean(self) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Mean)) + } + + /// Aggregate all the columns as their median values. + /// + /// - Boolean and integer results are converted to `f64`. However, they are still + /// susceptible to overflow before this conversion occurs. + /// - String columns will sum to None. + pub fn median(self) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Median)) + } + + /// Aggregate all the columns as their quantile values. + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Quantile { + quantile, + method, + })) + } + + /// Aggregate all the columns as their standard deviation values. + /// + /// `ddof` is the "Delta Degrees of Freedom"; `N - ddof` will be the denominator when + /// computing the variance, where `N` is the number of rows. + /// > In standard statistical practice, `ddof=1` provides an unbiased estimator of the + /// > variance of a hypothetical infinite population. `ddof=0` provides a maximum + /// > likelihood estimate of the variance for normally distributed variables. The + /// > standard deviation computed in this function is the square root of the estimated + /// > variance, so even with `ddof=1`, it will not be an unbiased estimate of the + /// > standard deviation per se. + /// + /// Source: [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.std.html#) + pub fn std(self, ddof: u8) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Std { ddof })) + } + + /// Aggregate all the columns as their variance values. + /// + /// `ddof` is the "Delta Degrees of Freedom"; `N - ddof` will be the denominator when + /// computing the variance, where `N` is the number of rows. + /// > In standard statistical practice, `ddof=1` provides an unbiased estimator of the + /// > variance of a hypothetical infinite population. `ddof=0` provides a maximum + /// > likelihood estimate of the variance for normally distributed variables. + /// + /// Source: [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.var.html#) + pub fn var(self, ddof: u8) -> Self { + self.map_private(DslFunction::Stats(StatsFunction::Var { ddof })) + } + + /// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode). + pub fn explode, IE: Into + Clone>(self, columns: E) -> LazyFrame { + self.explode_impl(columns, false) + } + + /// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode). + fn explode_impl, IE: Into + Clone>( + self, + columns: E, + allow_empty: bool, + ) -> LazyFrame { + let columns = columns + .as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>(); + let opt_state = self.get_opt_state(); + let lp = self + .get_plan_builder() + .explode(columns, allow_empty) + .build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Aggregate all the columns as the sum of their null value count. + pub fn null_count(self) -> LazyFrame { + self.select(vec![col(PlSmallStr::from_static("*")).null_count()]) + } + + /// Drop non-unique rows and maintain the order of kept rows. + /// + /// `subset` is an optional `Vec` of column names to consider for uniqueness; if + /// `None`, all columns are considered. + pub fn unique_stable( + self, + subset: Option>, + keep_strategy: UniqueKeepStrategy, + ) -> LazyFrame { + self.unique_stable_generic(subset, keep_strategy) + } + + pub fn unique_stable_generic( + self, + subset: Option, + keep_strategy: UniqueKeepStrategy, + ) -> LazyFrame + where + E: AsRef<[IE]>, + IE: Into + Clone, + { + let subset = subset.map(|s| { + s.as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>() + }); + + let opt_state = self.get_opt_state(); + let options = DistinctOptionsDSL { + subset, + maintain_order: true, + keep_strategy, + }; + let lp = self.get_plan_builder().distinct(options).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Drop non-unique rows without maintaining the order of kept rows. + /// + /// The order of the kept rows may change; to maintain the original row order, use + /// [`unique_stable`](LazyFrame::unique_stable). + /// + /// `subset` is an optional `Vec` of column names to consider for uniqueness; if None, + /// all columns are considered. + pub fn unique( + self, + subset: Option>, + keep_strategy: UniqueKeepStrategy, + ) -> LazyFrame { + self.unique_generic(subset, keep_strategy) + } + + pub fn unique_generic, IE: Into + Clone>( + self, + subset: Option, + keep_strategy: UniqueKeepStrategy, + ) -> LazyFrame { + let subset = subset.map(|s| { + s.as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>() + }); + let opt_state = self.get_opt_state(); + let options = DistinctOptionsDSL { + subset, + maintain_order: false, + keep_strategy, + }; + let lp = self.get_plan_builder().distinct(options).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Drop rows containing one or more NaN values. + /// + /// `subset` is an optional `Vec` of column names to consider for NaNs; if None, all + /// floating point columns are considered. + pub fn drop_nans(self, subset: Option>) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().drop_nans(subset).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Drop rows containing one or more None values. + /// + /// `subset` is an optional `Vec` of column names to consider for nulls; if None, all + /// columns are considered. + pub fn drop_nulls(self, subset: Option>) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().drop_nulls(subset).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Slice the DataFrame using an offset (starting row) and a length. + /// + /// If `offset` is negative, it is counted from the end of the DataFrame. For + /// instance, `lf.slice(-5, 3)` gets three rows, starting at the row fifth from the + /// end. + /// + /// If `offset` and `len` are such that the slice extends beyond the end of the + /// DataFrame, the portion between `offset` and the end will be returned. In this + /// case, the number of rows in the returned DataFrame will be less than `len`. + pub fn slice(self, offset: i64, len: IdxSize) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().slice(offset, len).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Get the first row. + /// + /// Equivalent to `self.slice(0, 1)`. + pub fn first(self) -> LazyFrame { + self.slice(0, 1) + } + + /// Get the last row. + /// + /// Equivalent to `self.slice(-1, 1)`. + pub fn last(self) -> LazyFrame { + self.slice(-1, 1) + } + + /// Get the last `n` rows. + /// + /// Equivalent to `self.slice(-(n as i64), n)`. + pub fn tail(self, n: IdxSize) -> LazyFrame { + let neg_tail = -(n as i64); + self.slice(neg_tail, n) + } + + /// Unpivot the DataFrame from wide to long format. + /// + /// See [`UnpivotArgsIR`] for information on how to unpivot a DataFrame. + #[cfg(feature = "pivot")] + pub fn unpivot(self, args: UnpivotArgsDSL) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().unpivot(args).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Limit the DataFrame to the first `n` rows. + /// + /// Note if you don't want the rows to be scanned, use [`fetch`](LazyFrame::fetch). + pub fn limit(self, n: IdxSize) -> LazyFrame { + self.slice(0, n) + } + + /// Apply a function/closure once the logical plan get executed. + /// + /// The function has access to the whole materialized DataFrame at the time it is + /// called. + /// + /// To apply specific functions to specific columns, use [`Expr::map`] in conjunction + /// with `LazyFrame::with_column` or `with_columns`. + /// + /// ## Warning + /// This can blow up in your face if the schema is changed due to the operation. The + /// optimizer relies on a correct schema. + /// + /// You can toggle certain optimizations off. + pub fn map( + self, + function: F, + optimizations: AllowedOptimizations, + schema: Option>, + name: Option<&'static str>, + ) -> LazyFrame + where + F: 'static + Fn(DataFrame) -> PolarsResult + Send + Sync, + { + let opt_state = self.get_opt_state(); + let lp = self + .get_plan_builder() + .map( + function, + optimizations, + schema, + PlSmallStr::from_static(name.unwrap_or("ANONYMOUS UDF")), + ) + .build(); + Self::from_logical_plan(lp, opt_state) + } + + #[cfg(feature = "python")] + pub fn map_python( + self, + function: polars_utils::python_function::PythonFunction, + optimizations: AllowedOptimizations, + schema: Option, + validate_output: bool, + ) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self + .get_plan_builder() + .map_python(function, optimizations, schema, validate_output) + .build(); + Self::from_logical_plan(lp, opt_state) + } + + pub(crate) fn map_private(self, function: DslFunction) -> LazyFrame { + let opt_state = self.get_opt_state(); + let lp = self.get_plan_builder().map_private(function).build(); + Self::from_logical_plan(lp, opt_state) + } + + /// Add a new column at index 0 that counts the rows. + /// + /// `name` is the name of the new column. `offset` is where to start counting from; if + /// `None`, it is set to `0`. + /// + /// # Warning + /// This can have a negative effect on query performance. This may for instance block + /// predicate pushdown optimization. + pub fn with_row_index(self, name: S, offset: Option) -> LazyFrame + where + S: Into, + { + let name = name.into(); + + match &self.logical_plan { + v @ DslPlan::Scan { scan_type, .. } + if !matches!(&**scan_type, FileScan::Anonymous { .. }) => + { + let DslPlan::Scan { + sources, + mut unified_scan_args, + scan_type, + file_info, + cached_ir: _, + } = v.clone() + else { + unreachable!() + }; + + unified_scan_args.row_index = Some(RowIndex { + name, + offset: offset.unwrap_or(0), + }); + + DslPlan::Scan { + sources, + unified_scan_args, + scan_type, + file_info, + cached_ir: Default::default(), + } + .into() + }, + _ => self.map_private(DslFunction::RowIndex { name, offset }), + } + } + + /// Return the number of non-null elements for each column. + pub fn count(self) -> LazyFrame { + self.select(vec![col(PlSmallStr::from_static("*")).count()]) + } + + /// Unnest the given `Struct` columns: the fields of the `Struct` type will be + /// inserted as columns. + #[cfg(feature = "dtype-struct")] + pub fn unnest(self, cols: E) -> Self + where + E: AsRef<[IE]>, + IE: Into + Clone, + { + let cols = cols + .as_ref() + .iter() + .map(|ie| ie.clone().into()) + .collect::>(); + self.map_private(DslFunction::Unnest(cols)) + } + + #[cfg(feature = "merge_sorted")] + pub fn merge_sorted(self, other: LazyFrame, key: S) -> PolarsResult + where + S: Into, + { + let key = key.into(); + + let lp = DslPlan::MergeSorted { + input_left: Arc::new(self.logical_plan), + input_right: Arc::new(other.logical_plan), + key, + }; + Ok(LazyFrame::from_logical_plan(lp, self.opt_state)) + } +} + +/// Utility struct for lazy group_by operation. +#[derive(Clone)] +pub struct LazyGroupBy { + pub logical_plan: DslPlan, + opt_state: OptFlags, + keys: Vec, + maintain_order: bool, + #[cfg(feature = "dynamic_group_by")] + dynamic_options: Option, + #[cfg(feature = "dynamic_group_by")] + rolling_options: Option, +} + +impl From for LazyFrame { + fn from(lgb: LazyGroupBy) -> Self { + Self { + logical_plan: lgb.logical_plan, + opt_state: lgb.opt_state, + cached_arena: Default::default(), + } + } +} + +impl LazyGroupBy { + /// Group by and aggregate. + /// + /// Select a column with [col] and choose an aggregation. + /// If you want to aggregate all columns use `col(PlSmallStr::from_static("*"))`. + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .group_by_stable([col("date")]) + /// .agg([ + /// col("rain").min().alias("min_rain"), + /// col("rain").sum().alias("sum_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), + /// ]) + /// } + /// ``` + pub fn agg>(self, aggs: E) -> LazyFrame { + #[cfg(feature = "dynamic_group_by")] + let lp = DslBuilder::from(self.logical_plan) + .group_by( + self.keys, + aggs, + None, + self.maintain_order, + self.dynamic_options, + self.rolling_options, + ) + .build(); + + #[cfg(not(feature = "dynamic_group_by"))] + let lp = DslBuilder::from(self.logical_plan) + .group_by(self.keys, aggs, None, self.maintain_order) + .build(); + LazyFrame::from_logical_plan(lp, self.opt_state) + } + + /// Return first n rows of each group + pub fn head(self, n: Option) -> LazyFrame { + let keys = self + .keys + .iter() + .filter_map(|expr| expr_output_name(expr).ok()) + .collect::>(); + + self.agg([col(PlSmallStr::from_static("*")) + .exclude(keys.iter().cloned()) + .head(n)]) + .explode_impl( + [col(PlSmallStr::from_static("*")).exclude(keys.iter().cloned())], + true, + ) + } + + /// Return last n rows of each group + pub fn tail(self, n: Option) -> LazyFrame { + let keys = self + .keys + .iter() + .filter_map(|expr| expr_output_name(expr).ok()) + .collect::>(); + + self.agg([col(PlSmallStr::from_static("*")) + .exclude(keys.iter().cloned()) + .tail(n)]) + .explode_impl( + [col(PlSmallStr::from_static("*")).exclude(keys.iter().cloned())], + true, + ) + } + + /// Apply a function over the groups as a new DataFrame. + /// + /// **It is not recommended that you use this as materializing the DataFrame is very + /// expensive.** + pub fn apply(self, f: F, schema: SchemaRef) -> LazyFrame + where + F: 'static + Fn(DataFrame) -> PolarsResult + Send + Sync, + { + #[cfg(feature = "dynamic_group_by")] + let options = GroupbyOptions { + dynamic: self.dynamic_options, + rolling: self.rolling_options, + slice: None, + }; + + #[cfg(not(feature = "dynamic_group_by"))] + let options = GroupbyOptions { slice: None }; + + let lp = DslPlan::GroupBy { + input: Arc::new(self.logical_plan), + keys: self.keys, + aggs: vec![], + apply: Some((Arc::new(f), schema)), + maintain_order: self.maintain_order, + options: Arc::new(options), + }; + LazyFrame::from_logical_plan(lp, self.opt_state) + } +} + +#[must_use] +pub struct JoinBuilder { + lf: LazyFrame, + how: JoinType, + other: Option, + left_on: Vec, + right_on: Vec, + allow_parallel: bool, + force_parallel: bool, + suffix: Option, + validation: JoinValidation, + nulls_equal: bool, + coalesce: JoinCoalesce, + maintain_order: MaintainOrderJoin, +} +impl JoinBuilder { + /// Create the `JoinBuilder` with the provided `LazyFrame` as the left table. + pub fn new(lf: LazyFrame) -> Self { + Self { + lf, + other: None, + how: JoinType::Inner, + left_on: vec![], + right_on: vec![], + allow_parallel: true, + force_parallel: false, + suffix: None, + validation: Default::default(), + nulls_equal: false, + coalesce: Default::default(), + maintain_order: Default::default(), + } + } + + /// The right table in the join. + pub fn with(mut self, other: LazyFrame) -> Self { + self.other = Some(other); + self + } + + /// Select the join type. + pub fn how(mut self, how: JoinType) -> Self { + self.how = how; + self + } + + pub fn validate(mut self, validation: JoinValidation) -> Self { + self.validation = validation; + self + } + + /// The expressions you want to join both tables on. + /// + /// The passed expressions must be valid in both `LazyFrame`s in the join. + pub fn on>(mut self, on: E) -> Self { + let on = on.as_ref().to_vec(); + self.left_on.clone_from(&on); + self.right_on = on; + self + } + + /// The expressions you want to join the left table on. + /// + /// The passed expressions must be valid in the left table. + pub fn left_on>(mut self, on: E) -> Self { + self.left_on = on.as_ref().to_vec(); + self + } + + /// The expressions you want to join the right table on. + /// + /// The passed expressions must be valid in the right table. + pub fn right_on>(mut self, on: E) -> Self { + self.right_on = on.as_ref().to_vec(); + self + } + + /// Allow parallel table evaluation. + pub fn allow_parallel(mut self, allow: bool) -> Self { + self.allow_parallel = allow; + self + } + + /// Force parallel table evaluation. + pub fn force_parallel(mut self, force: bool) -> Self { + self.force_parallel = force; + self + } + + /// Join on null values. By default null values will never produce matches. + pub fn join_nulls(mut self, nulls_equal: bool) -> Self { + self.nulls_equal = nulls_equal; + self + } + + /// Suffix to add duplicate column names in join. + /// Defaults to `"_right"` if this method is never called. + pub fn suffix(mut self, suffix: S) -> Self + where + S: Into, + { + self.suffix = Some(suffix.into()); + self + } + + /// Whether to coalesce join columns. + pub fn coalesce(mut self, coalesce: JoinCoalesce) -> Self { + self.coalesce = coalesce; + self + } + + /// Whether to preserve the row order. + pub fn maintain_order(mut self, maintain_order: MaintainOrderJoin) -> Self { + self.maintain_order = maintain_order; + self + } + + /// Finish builder + pub fn finish(self) -> LazyFrame { + let opt_state = self.lf.opt_state; + let other = self.other.expect("'with' not set in join builder"); + + let args = JoinArgs { + how: self.how, + validation: self.validation, + suffix: self.suffix, + slice: None, + nulls_equal: self.nulls_equal, + coalesce: self.coalesce, + maintain_order: self.maintain_order, + }; + + let lp = self + .lf + .get_plan_builder() + .join( + other.logical_plan, + self.left_on, + self.right_on, + JoinOptions { + allow_parallel: self.allow_parallel, + force_parallel: self.force_parallel, + args, + ..Default::default() + } + .into(), + ) + .build(); + LazyFrame::from_logical_plan(lp, opt_state) + } + + // Finish with join predicates + pub fn join_where(self, predicates: Vec) -> LazyFrame { + let opt_state = self.lf.opt_state; + let other = self.other.expect("with not set"); + + // Decompose `And` conjunctions into their component expressions + fn decompose_and(predicate: Expr, expanded_predicates: &mut Vec) { + if let Expr::BinaryExpr { + op: Operator::And, + left, + right, + } = predicate + { + decompose_and((*left).clone(), expanded_predicates); + decompose_and((*right).clone(), expanded_predicates); + } else { + expanded_predicates.push(predicate); + } + } + let mut expanded_predicates = Vec::with_capacity(predicates.len() * 2); + for predicate in predicates { + decompose_and(predicate, &mut expanded_predicates); + } + let predicates: Vec = expanded_predicates; + + // Decompose `is_between` predicates to allow for cleaner expression of range joins + #[cfg(feature = "is_between")] + let predicates: Vec = { + let mut expanded_predicates = Vec::with_capacity(predicates.len() * 2); + for predicate in predicates { + if let Expr::Function { + function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }), + input, + .. + } = &predicate + { + if let [expr, lower, upper] = input.as_slice() { + match closed { + ClosedInterval::Both => { + expanded_predicates.push(expr.clone().gt_eq(lower.clone())); + expanded_predicates.push(expr.clone().lt_eq(upper.clone())); + }, + ClosedInterval::Right => { + expanded_predicates.push(expr.clone().gt(lower.clone())); + expanded_predicates.push(expr.clone().lt_eq(upper.clone())); + }, + ClosedInterval::Left => { + expanded_predicates.push(expr.clone().gt_eq(lower.clone())); + expanded_predicates.push(expr.clone().lt(upper.clone())); + }, + ClosedInterval::None => { + expanded_predicates.push(expr.clone().gt(lower.clone())); + expanded_predicates.push(expr.clone().lt(upper.clone())); + }, + } + continue; + } + } + expanded_predicates.push(predicate); + } + expanded_predicates + }; + + let args = JoinArgs { + how: self.how, + validation: self.validation, + suffix: self.suffix, + slice: None, + nulls_equal: self.nulls_equal, + coalesce: self.coalesce, + maintain_order: self.maintain_order, + }; + let options = JoinOptions { + allow_parallel: self.allow_parallel, + force_parallel: self.force_parallel, + args, + ..Default::default() + }; + + let lp = DslPlan::Join { + input_left: Arc::new(self.lf.logical_plan), + input_right: Arc::new(other.logical_plan), + left_on: Default::default(), + right_on: Default::default(), + predicates, + options: Arc::from(options), + }; + + LazyFrame::from_logical_plan(lp, opt_state) + } +} + +pub const BUILD_STREAMING_EXECUTOR: Option = { + #[cfg(not(feature = "new_streaming"))] + { + None + } + #[cfg(feature = "new_streaming")] + { + Some(streaming_dispatch::build_streaming_query_executor) + } +}; +#[cfg(feature = "new_streaming")] +pub use streaming_dispatch::build_streaming_query_executor; + +#[cfg(feature = "new_streaming")] +mod streaming_dispatch { + use std::sync::{Arc, Mutex}; + + use polars_core::error::PolarsResult; + use polars_core::frame::DataFrame; + use polars_expr::state::ExecutionState; + use polars_mem_engine::Executor; + use polars_plan::dsl::SinkTypeIR; + use polars_plan::plans::{AExpr, IR}; + use polars_utils::arena::{Arena, Node}; + + pub fn build_streaming_query_executor( + node: Node, + ir_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult> { + let rechunk = match ir_arena.get(node) { + IR::Scan { + unified_scan_args, .. + } => unified_scan_args.rechunk, + _ => false, + }; + + let node = ir_arena.add(IR::Sink { + input: node, + payload: SinkTypeIR::Memory, + }); + + polars_stream::StreamingQuery::build(node, ir_arena, expr_arena) + .map(Some) + .map(Mutex::new) + .map(Arc::new) + .map(|x| StreamingQueryExecutor { + executor: x, + rechunk, + }) + .map(|x| Box::new(x) as Box) + } + + // Note: Arc/Mutex is because Executor requires Sync, but SlotMap is not Sync. + struct StreamingQueryExecutor { + executor: Arc>>, + rechunk: bool, + } + + impl Executor for StreamingQueryExecutor { + fn execute(&mut self, _cache: &mut ExecutionState) -> PolarsResult { + let mut df = { self.executor.try_lock().unwrap().take() } + .expect("unhandled: execute() more than once") + .execute() + .map(|x| x.unwrap_single())?; + + if self.rechunk { + df.as_single_chunk_par(); + } + + Ok(df) + } + } +} diff --git a/crates/polars-lazy/src/frame/pivot.rs b/crates/polars-lazy/src/frame/pivot.rs new file mode 100644 index 000000000000..6225c1667bbb --- /dev/null +++ b/crates/polars-lazy/src/frame/pivot.rs @@ -0,0 +1,90 @@ +//! Module containing implementation of the pivot operation. +//! +//! Polars lazy does not implement a pivot because it is impossible to know the schema without +//! materializing the whole dataset. This makes a pivot quite a terrible operation for performant +//! workflows. An optimization can never be pushed down passed a pivot. +//! +//! We can do a pivot on an eager `DataFrame` as that is already materialized. The code for the +//! pivot is here, because we want to be able to pass expressions to the pivot operation. +//! + +use polars_core::frame::group_by::expr::PhysicalAggExpr; +use polars_core::prelude::*; +use polars_ops::pivot::PivotAgg; + +use crate::physical_plan::exotic::{prepare_eval_expr, prepare_expression_for_context}; +use crate::prelude::*; + +struct PivotExpr(Expr); + +impl PhysicalAggExpr for PivotExpr { + fn evaluate(&self, df: &DataFrame, groups: &GroupPositions) -> PolarsResult { + let state = ExecutionState::new(); + let dtype = df.get_columns()[0].dtype(); + let phys_expr = prepare_expression_for_context( + PlSmallStr::EMPTY, + &self.0, + dtype, + Context::Aggregation, + )?; + phys_expr + .evaluate_on_groups(df, groups, &state) + .map(|mut ac| ac.aggregated().take_materialized_series()) + } + + fn root_name(&self) -> PolarsResult<&PlSmallStr> { + Ok(PlSmallStr::EMPTY_REF) + } +} + +pub fn pivot( + df: &DataFrame, + on: I0, + index: Option, + values: Option, + sort_columns: bool, + agg_expr: Option, + // used as separator/delimiter in generated column names. + separator: Option<&str>, +) -> PolarsResult +where + I0: IntoIterator, + I1: IntoIterator, + I2: IntoIterator, + S0: Into, + S1: Into, + S2: Into, +{ + // make sure that the root column is replaced + let agg_expr = agg_expr.map(|agg_expr| { + let expr = prepare_eval_expr(agg_expr); + PivotAgg::Expr(Arc::new(PivotExpr(expr))) + }); + polars_ops::pivot::pivot(df, on, index, values, sort_columns, agg_expr, separator) +} + +pub fn pivot_stable( + df: &DataFrame, + on: I0, + index: Option, + values: Option, + sort_columns: bool, + agg_expr: Option, + // used as separator/delimiter in generated column names. + separator: Option<&str>, +) -> PolarsResult +where + I0: IntoIterator, + I1: IntoIterator, + I2: IntoIterator, + S0: Into, + S1: Into, + S2: Into, +{ + // make sure that the root column is replaced + let agg_expr = agg_expr.map(|agg_expr| { + let expr = prepare_eval_expr(agg_expr); + PivotAgg::Expr(Arc::new(PivotExpr(expr))) + }); + polars_ops::pivot::pivot_stable(df, on, index, values, sort_columns, agg_expr, separator) +} diff --git a/crates/polars-lazy/src/frame/python.rs b/crates/polars-lazy/src/frame/python.rs new file mode 100644 index 000000000000..1e4409e37c92 --- /dev/null +++ b/crates/polars-lazy/src/frame/python.rs @@ -0,0 +1,33 @@ +use std::sync::Arc; + +use either::Either; +use polars_core::schema::SchemaRef; +use pyo3::PyObject; + +use self::python_dsl::{PythonOptionsDsl, PythonScanSource}; +use crate::prelude::*; + +impl LazyFrame { + pub fn scan_from_python_function( + schema: Either, + scan_fn: PyObject, + pyarrow: bool, + // Validate that the source gives the proper schema + validate_schema: bool, + ) -> Self { + DslPlan::PythonScan { + options: PythonOptionsDsl { + // Should be a python function that returns a generator + scan_fn: Some(scan_fn.into()), + schema_fn: Some(SpecialEq::new(Arc::new(schema.map_left(|obj| obj.into())))), + python_source: if pyarrow { + PythonScanSource::Pyarrow + } else { + PythonScanSource::IOPlugin + }, + validate_schema, + }, + } + .into() + } +} diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs new file mode 100644 index 000000000000..08da7ff13891 --- /dev/null +++ b/crates/polars-lazy/src/lib.rs @@ -0,0 +1,211 @@ +//! Lazy API of Polars +//! +//! The lazy API of Polars supports a subset of the eager API. Apart from the distributed compute, +//! it is very similar to [Apache Spark](https://spark.apache.org/). You write queries in a +//! domain specific language. These queries translate to a logical plan, which represent your query steps. +//! Before execution this logical plan is optimized and may change the order of operations if this will increase performance. +//! Or implicit type casts may be added such that execution of the query won't lead to a type error (if it can be resolved). +//! +//! # Lazy DSL +//! +//! The lazy API of polars replaces the eager [`DataFrame`] with the [`LazyFrame`], through which +//! the lazy API is exposed. +//! The [`LazyFrame`] represents a logical execution plan: a sequence of operations to perform on a concrete data source. +//! These operations are not executed until we call [`collect`]. +//! This allows polars to optimize/reorder the query which may lead to faster queries or fewer type errors. +//! +//! [`DataFrame`]: polars_core::frame::DataFrame +//! [`LazyFrame`]: crate::frame::LazyFrame +//! [`collect`]: crate::frame::LazyFrame::collect +//! +//! In general, a [`LazyFrame`] requires a concrete data source — a [`DataFrame`], a file on disk, etc. — which polars-lazy +//! then applies the user-specified sequence of operations to. +//! To obtain a [`LazyFrame`] from an existing [`DataFrame`], we call the [`lazy`](crate::frame::IntoLazy::lazy) method on +//! the [`DataFrame`]. +//! A [`LazyFrame`] can also be obtained through the lazy versions of file readers, such as [`LazyCsvReader`](crate::frame::LazyCsvReader). +//! +//! The other major component of the polars lazy API is [`Expr`](crate::dsl::Expr), which represents an operation to be +//! performed on a [`LazyFrame`], such as mapping over a column, filtering, or groupby-aggregation. +//! [`Expr`] and the functions that produce them can be found in the [dsl module](crate::dsl). +//! +//! [`Expr`]: crate::dsl::Expr +//! +//! Most operations on a [`LazyFrame`] consume the [`LazyFrame`] and return a new [`LazyFrame`] with the updated plan. +//! If you need to use the same [`LazyFrame`] multiple times, you should [`clone`](crate::frame::LazyFrame::clone) it, and optionally +//! [`cache`](crate::frame::LazyFrame::cache) it beforehand. +//! +//! ## Examples +//! +//! #### Adding a new column to a lazy DataFrame +//! +//!```rust +//! #[macro_use] extern crate polars_core; +//! use polars_core::prelude::*; +//! use polars_lazy::prelude::*; +//! +//! let df = df! { +//! "column_a" => &[1, 2, 3, 4, 5], +//! "column_b" => &["a", "b", "c", "d", "e"] +//! }.unwrap(); +//! +//! let new = df.lazy() +//! // Note the reverse here!! +//! .reverse() +//! .with_column( +//! // always rename a new column +//! (col("column_a") * lit(10)).alias("new_column") +//! ) +//! .collect() +//! .unwrap(); +//! +//! assert!(new.column("new_column") +//! .unwrap() +//! .equals( +//! &Column::new("new_column".into(), &[50, 40, 30, 20, 10]) +//! ) +//! ); +//! ``` +//! #### Modifying a column based on some predicate +//! +//!```rust +//! #[macro_use] extern crate polars_core; +//! use polars_core::prelude::*; +//! use polars_lazy::prelude::*; +//! +//! let df = df! { +//! "column_a" => &[1, 2, 3, 4, 5], +//! "column_b" => &["a", "b", "c", "d", "e"] +//! }.unwrap(); +//! +//! let new = df.lazy() +//! .with_column( +//! // value = 100 if x < 3 else x +//! when( +//! col("column_a").lt(lit(3)) +//! ).then( +//! lit(100) +//! ).otherwise( +//! col("column_a") +//! ).alias("new_column") +//! ) +//! .collect() +//! .unwrap(); +//! +//! assert!(new.column("new_column") +//! .unwrap() +//! .equals( +//! &Column::new("new_column".into(), &[100, 100, 3, 4, 5]) +//! ) +//! ); +//! ``` +//! #### Groupby + Aggregations +//! +//!```rust +//! use polars_core::prelude::*; +//! use polars_core::df; +//! use polars_lazy::prelude::*; +//! +//! fn example() -> PolarsResult { +//! let df = df!( +//! "date" => ["2020-08-21", "2020-08-21", "2020-08-22", "2020-08-23", "2020-08-22"], +//! "temp" => [20, 10, 7, 9, 1], +//! "rain" => [0.2, 0.1, 0.3, 0.1, 0.01] +//! )?; +//! +//! df.lazy() +//! .group_by([col("date")]) +//! .agg([ +//! col("rain").min().alias("min_rain"), +//! col("rain").sum().alias("sum_rain"), +//! col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), +//! ]) +//! .sort(["date"], Default::default()) +//! .collect() +//! } +//! ``` +//! +//! #### Calling any function +//! +//! Below we lazily call a custom closure of type `Series => Result`. Because the closure +//! changes the type/variant of the Series we also define the return type. This is important because +//! due to the laziness the types should be known beforehand. Note that by applying these custom +//! functions you have access to the whole **eager API** of the Series/ChunkedArrays. +//! +//!```rust +//! #[macro_use] extern crate polars_core; +//! use polars_core::prelude::*; +//! use polars_lazy::prelude::*; +//! +//! let df = df! { +//! "column_a" => &[1, 2, 3, 4, 5], +//! "column_b" => &["a", "b", "c", "d", "e"] +//! }.unwrap(); +//! +//! let new = df.lazy() +//! .with_column( +//! col("column_a") +//! // apply a custom closure Series => Result +//! .map(|_s| { +//! Ok(Some(Column::new("".into(), &[6.0f32, 6.0, 6.0, 6.0, 6.0]))) +//! }, +//! // return type of the closure +//! GetOutput::from_type(DataType::Float64)).alias("new_column") +//! ) +//! .collect() +//! .unwrap(); +//! ``` +//! +//! #### Joins, filters and projections +//! +//! In the query below we do a lazy join and afterwards we filter rows based on the predicate `a < 2`. +//! And last we select the columns `"b"` and `"c_first"`. In an eager API this query would be very +//! suboptimal because we join on DataFrames with more columns and rows than needed. In this case +//! the query optimizer will do the selection of the columns (projection) and the filtering of the +//! rows (selection) before the join, thereby reducing the amount of work done by the query. +//! +//! ```rust +//! # use polars_core::prelude::*; +//! # use polars_lazy::prelude::*; +//! +//! fn example(df_a: DataFrame, df_b: DataFrame) -> LazyFrame { +//! df_a.lazy() +//! .left_join(df_b.lazy(), col("b_left"), col("b_right")) +//! .filter( +//! col("a").lt(lit(2)) +//! ) +//! .group_by([col("b")]) +//! .agg( +//! vec![col("b").first().alias("first_b"), col("c").first().alias("first_c")] +//! ) +//! .select(&[col("b"), col("c_first")]) +//! } +//! ``` +//! +//! If we want to do an aggregation on all columns we can use the wildcard operator `*` to achieve this. +//! +//! ```rust +//! # use polars_core::prelude::*; +//! # use polars_lazy::prelude::*; +//! +//! fn aggregate_all_columns(df_a: DataFrame) -> LazyFrame { +//! df_a.lazy() +//! .group_by([col("b")]) +//! .agg( +//! vec![col("*").first()] +//! ) +//! } +//! ``` +#![allow(ambiguous_glob_reexports)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +extern crate core; + +#[cfg(feature = "dot_diagram")] +mod dot; +pub mod dsl; +pub mod frame; +pub mod physical_plan; +pub mod prelude; + +mod scan; +#[cfg(test)] +mod tests; diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs new file mode 100644 index 000000000000..df218b345532 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -0,0 +1,48 @@ +use polars_core::prelude::*; +use polars_expr::{ExpressionConversionState, create_physical_expr}; + +use crate::prelude::*; + +#[cfg(feature = "pivot")] +pub(crate) fn prepare_eval_expr(expr: Expr) -> Expr { + expr.map_expr(|e| match e { + Expr::Column(_) => Expr::Column(PlSmallStr::EMPTY), + Expr::Nth(_) => Expr::Column(PlSmallStr::EMPTY), + e => e, + }) +} + +pub(crate) fn prepare_expression_for_context( + name: PlSmallStr, + expr: &Expr, + dtype: &DataType, + ctxt: Context, +) -> PolarsResult> { + let mut lp_arena = Arena::with_capacity(8); + let mut expr_arena = Arena::with_capacity(10); + + // create a dummy lazyframe and run a very simple optimization run so that + // type coercion and simplify expression optimizations run. + let column = Series::full_null(name, 0, dtype); + let df = column.into_frame(); + let input_schema = df.schema().clone(); + let lf = df + .lazy() + .without_optimizations() + .with_simplify_expr(true) + .select([expr.clone()]); + let optimized = lf.optimize(&mut lp_arena, &mut expr_arena)?; + let lp = lp_arena.get(optimized); + let aexpr = lp + .get_exprs() + .pop() + .ok_or_else(|| polars_err!(ComputeError: "expected expressions in the context"))?; + + create_physical_expr( + &aexpr, + ctxt, + &expr_arena, + &input_schema, + &mut ExpressionConversionState::new(true), + ) +} diff --git a/crates/polars-lazy/src/physical_plan/mod.rs b/crates/polars-lazy/src/physical_plan/mod.rs new file mode 100644 index 000000000000..4f0895cc38f9 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/mod.rs @@ -0,0 +1,4 @@ +#[cfg(any(feature = "list_eval", feature = "pivot"))] +pub(crate) mod exotic; +#[cfg(feature = "streaming")] +pub(crate) mod streaming; diff --git a/crates/polars-lazy/src/physical_plan/streaming/checks.rs b/crates/polars-lazy/src/physical_plan/streaming/checks.rs new file mode 100644 index 000000000000..0f5810bf7588 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/streaming/checks.rs @@ -0,0 +1,43 @@ +use polars_core::chunked_array::ops::SortMultipleOptions; +use polars_ops::prelude::*; +use polars_plan::plans::expr_ir::ExprIR; +use polars_plan::prelude::*; + +pub(super) fn is_streamable_sort( + slice: &Option<(i64, usize)>, + sort_options: &SortMultipleOptions, +) -> bool { + // check if slice is positive or maintain order is true + if sort_options.maintain_order { + false + } else if let Some((offset, _)) = slice { + *offset >= 0 + } else { + true + } +} + +/// check if all expressions are a simple column projection +pub(super) fn all_column(exprs: &[ExprIR], expr_arena: &Arena) -> bool { + exprs + .iter() + .all(|e| matches!(expr_arena.get(e.node()), AExpr::Column(_))) +} + +pub(super) fn streamable_join(args: &JoinArgs) -> bool { + let supported = match args.how { + #[cfg(feature = "cross_join")] + JoinType::Cross => true, + JoinType::Left => true, + JoinType::Inner => { + // no-coalescing not yet supported in streaming + matches!( + args.coalesce, + JoinCoalesce::JoinSpecific | JoinCoalesce::CoalesceColumns + ) + }, + JoinType::Full => true, + _ => false, + }; + supported && !args.validation.needs_checks() +} diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs new file mode 100644 index 000000000000..3c94231b0ac4 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -0,0 +1,261 @@ +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Mutex; + +use polars_core::config::verbose; +use polars_core::prelude::*; +use polars_expr::{ExpressionConversionState, create_physical_expr}; +use polars_io::predicates::PhysicalIoExpr; +use polars_pipe::expressions::PhysicalPipedExpr; +use polars_pipe::operators::chunks::DataChunk; +use polars_pipe::pipeline::{ + CallBacks, PipeLine, create_pipeline, execute_pipeline, get_dummy_operator, get_operator, +}; +use polars_plan::prelude::expr_ir::ExprIR; + +use crate::physical_plan::streaming::tree::{PipelineNode, Tree}; +use crate::prelude::*; + +pub struct Wrap(Arc); + +impl PhysicalIoExpr for Wrap { + fn evaluate_io(&self, df: &DataFrame) -> PolarsResult { + let h = PhysicalIoHelper { + expr: self.0.clone(), + has_window_function: false, + }; + h.evaluate_io(df) + } +} +impl PhysicalPipedExpr for Wrap { + fn evaluate(&self, chunk: &DataChunk, state: &ExecutionState) -> PolarsResult { + self.0 + .evaluate(&chunk.data, state) + .map(|c| c.take_materialized_series()) + } + fn field(&self, input_schema: &Schema) -> PolarsResult { + self.0.to_field(input_schema) + } + + fn expression(&self) -> Expr { + self.0.as_expression().unwrap().clone() + } +} + +fn to_physical_piped_expr( + expr: &ExprIR, + expr_arena: &Arena, + schema: &SchemaRef, +) -> PolarsResult> { + // this is a double Arc explore if we can create a single of it. + create_physical_expr( + expr, + Context::Default, + expr_arena, + schema, + &mut ExpressionConversionState::new(false), + ) + .map(|e| Arc::new(Wrap(e)) as Arc) +} + +fn jit_insert_slice( + node: Node, + lp_arena: &mut Arena, + sink_nodes: &mut Vec<(usize, Node, Rc>)>, + operator_offset: usize, +) { + // if the join has a slice, we add a new slice node + // note that we take the offset + 1, because we want to + // slice AFTER the join has happened and the join will be an + // operator + // NOTE: Don't do this for union, that doesn't work. + // TODO! Deal with this in the optimizer. + use IR::*; + let (offset, len) = match lp_arena.get(node) { + Join { options, .. } if options.args.slice.is_some() => { + let Some((offset, len)) = options.args.slice else { + unreachable!() + }; + (offset, len) + }, + _ => return, + }; + + let slice_node = lp_arena.add(Slice { + input: node, + offset, + len: len as IdxSize, + }); + sink_nodes.push((operator_offset + 1, slice_node, Rc::new(RefCell::new(1)))); +} + +pub(super) fn construct( + tree: Tree, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + fmt: bool, +) -> PolarsResult> { + use IR::*; + + let mut pipelines = Vec::with_capacity(tree.len()); + let mut callbacks = CallBacks::new(); + + let is_verbose = verbose(); + + // First traverse the branches and nodes to determine how often a sink is + // shared. + // This shared count will be used in the pipeline to determine + // when the sink can be finalized. + let mut sink_share_count = PlHashMap::new(); + let n_branches = tree.len(); + if n_branches > 1 { + for branch in &tree { + for op in branch.operators_sinks.iter() { + match op { + PipelineNode::Sink(sink) => { + let count = sink_share_count + .entry(sink.0) + .or_insert(Rc::new(RefCell::new(0u32))); + *count.borrow_mut() += 1; + }, + PipelineNode::RhsJoin(node) => { + let _ = callbacks.insert(*node, get_dummy_operator()); + }, + _ => {}, + } + } + } + } + + // Shared sinks are stored in a cache, so that they share state. + // If the shared sink is already in cache, that one is used. + let mut sink_cache = PlHashMap::new(); + let mut final_sink = None; + + for branch in tree { + // The file sink is always to the top of the tree + // not every branch has a final sink. For instance rhs join branches + if let Some(node) = branch.get_final_sink() { + if matches!(lp_arena.get(node), IR::Sink { .. }) { + final_sink = Some(node) + } + } + // should be reset for every branch + let mut sink_nodes = vec![]; + + let mut operators = Vec::with_capacity(branch.operators_sinks.len()); + let mut operator_nodes = Vec::with_capacity(branch.operators_sinks.len()); + + // iterate from leaves upwards + let mut iter = branch.operators_sinks.into_iter().rev(); + + for pipeline_node in &mut iter { + let operator_offset = operators.len(); + match pipeline_node { + PipelineNode::Sink(node) => { + let shared_count = if n_branches > 1 { + // should be here + sink_share_count.get(&node.0).unwrap().clone() + } else { + Rc::new(RefCell::new(1)) + }; + sink_nodes.push((operator_offset, node, shared_count)) + }, + PipelineNode::Operator(node) => { + operator_nodes.push(node); + let op = get_operator(node, lp_arena, expr_arena, &to_physical_piped_expr)?; + operators.push(op); + }, + PipelineNode::Union(node) => { + operator_nodes.push(node); + let op = get_operator(node, lp_arena, expr_arena, &to_physical_piped_expr)?; + operators.push(op); + }, + PipelineNode::RhsJoin(node) => { + operator_nodes.push(node); + jit_insert_slice(node, lp_arena, &mut sink_nodes, operator_offset); + let op = callbacks.get(&node).unwrap().clone(); + operators.push(Box::new(op)) + }, + } + } + + let pipeline = create_pipeline( + &branch.sources, + operators, + sink_nodes, + lp_arena, + expr_arena, + to_physical_piped_expr, + is_verbose, + &mut sink_cache, + &mut callbacks, + )?; + pipelines.push(pipeline); + } + + let Some(final_sink) = final_sink else { + return Ok(None); + }; + let insertion_location = match lp_arena.get(final_sink) { + // this was inserted only during conversion and does not exist + // in the original tree, so we take the input, as that's where + // we connect into the original tree. + Sink { + input, + payload: SinkTypeIR::Memory, + } => *input, + // Other sinks were not inserted during conversion, + // so they are returned as-is + Sink { .. } => final_sink, + _ => unreachable!(), + }; + // keep the original around for formatting purposes + let original_lp = if fmt { + let original_lp = IRPlan::new(insertion_location, lp_arena.clone(), expr_arena.clone()); + Some(original_lp) + } else { + None + }; + + // Replace the part of the logical plan with a `MapFunction` that will execute the pipeline. + let schema = lp_arena + .get(insertion_location) + .schema(lp_arena) + .into_owned(); + let pipeline_node = get_pipeline_node(lp_arena, pipelines, schema, original_lp); + lp_arena.replace(insertion_location, pipeline_node); + + Ok(Some(final_sink)) +} + +fn get_pipeline_node( + lp_arena: &mut Arena, + mut pipelines: Vec, + schema: SchemaRef, + original_lp: Option, +) -> IR { + // create a dummy input as the map function will call the input + // so we just create a scan that returns an empty df + let dummy = lp_arena.add(IR::DataFrameScan { + df: Arc::new(DataFrame::empty()), + schema: Arc::new(Schema::default()), + output_schema: None, + }); + + IR::MapFunction { + function: FunctionIR::Pipeline { + function: Arc::new(Mutex::new(move |_df: DataFrame| { + let state = ExecutionState::new(); + if state.verbose() { + eprintln!("RUN STREAMING PIPELINE"); + eprintln!("{:?}", &pipelines) + } + execute_pipeline(state, std::mem::take(&mut pipelines)) + })), + schema, + original: original_lp.map(Arc::new), + }, + input: dummy, + } +} diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs new file mode 100644 index 000000000000..d3c330432f95 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -0,0 +1,446 @@ +use polars_core::prelude::*; +use polars_pipe::pipeline::swap_join_order; +use polars_plan::prelude::*; + +use super::checks::*; +use crate::physical_plan::streaming::tree::*; + +// The index of the pipeline tree we are building at this moment +// if we have a node we cannot do streaming, we have finished that pipeline tree +// and start a new one. +type CurrentIdx = usize; + +// Frame in the stack of logical plans to process while inserting streaming nodes +struct StackFrame { + node: Node, // LogicalPlan node + state: Branch, + current_idx: CurrentIdx, + insert_sink: bool, +} + +impl StackFrame { + fn root(node: Node) -> StackFrame { + StackFrame { + node, + state: Branch::default(), + current_idx: 0, + insert_sink: false, + } + } + + fn new(node: Node, state: Branch, current_idx: CurrentIdx) -> StackFrame { + StackFrame { + node, + state, + current_idx, + insert_sink: false, + } + } + + // Create a new streaming subtree below a non-streaming node + fn new_subtree(node: Node, current_idx: CurrentIdx) -> StackFrame { + StackFrame { + node, + state: Branch::default(), + current_idx, + insert_sink: true, + } + } +} + +fn process_non_streamable_node( + current_idx: &mut CurrentIdx, + state: &mut Branch, + stack: &mut Vec, + scratch: &mut Vec, + pipeline_trees: &mut Vec>, + lp: &IR, +) { + lp.copy_inputs(scratch); + while let Some(input) = scratch.pop() { + if state.streamable { + *current_idx += 1; + // create a completely new streaming pipeline + // maybe we can stream a subsection of the plan + pipeline_trees.push(vec![]); + } + stack.push(StackFrame::new_subtree(input, *current_idx)); + } + state.streamable = false; +} + +fn insert_file_sink(mut root: Node, lp_arena: &mut Arena) -> Node { + // The pipelines need a final sink, we insert that here. + // this allows us to split at joins/unions and share a sink + if !matches!(lp_arena.get(root), IR::Sink { .. }) { + root = lp_arena.add(IR::Sink { + input: root, + payload: SinkTypeIR::Memory, + }) + } + root +} + +pub(crate) fn insert_streaming_nodes( + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + scratch: &mut Vec, + fmt: bool, + // whether the full plan needs to be translated + // to streaming + allow_partial: bool, + row_estimate: bool, +) -> PolarsResult { + scratch.clear(); + + // This is needed to determine which side of the joins should be + // traversed first. As we want to keep the smallest table in the build phase as that keeps most + // data in memory. + if row_estimate { + set_estimated_row_counts(root, lp_arena, expr_arena, 0, scratch); + } + + scratch.clear(); + + // The pipelines always need to end in a SINK, we insert that here. + // this allows us to split at joins/unions and share a sink + let root = insert_file_sink(root, lp_arena); + + // We use a bool flag in the stack to communicate when we need to insert a file sink. + // This happens for instance when we + // + // ________*non-streamable part of query + // /\ + // ________*streamable below this line so we must insert + // /\ a file sink here so the pipeline can be built + // /\ + + let mut stack = Vec::with_capacity(16); + + stack.push(StackFrame::root(root)); + + // A state holds a full pipeline until the breaker + // 1/\ + // 2/\ + // 3\ + // + // so 1 and 2 are short pipelines and 3 goes all the way to the root. + // but 3 can only run if 1 and 2 have finished and set the join as operator in 3 + // and state are filled with pipeline 1, 2, 3 in that order + // + // / \ + // /\ 3/\ + // 1 2 4\ + // or in this case 1, 2, 3, 4 + // every inner vec contains a branch/pipeline of a complete pipeline tree + // the outer vec contains whole pipeline trees + // + // # Execution order + // Trees can have arbitrary splits via joins and unions + // the branches we have accumulated are flattened into a single Vec + // this therefore has lost the information of the tree. To know in which + // order the branches need to be executed. For this reason we keep track of + // an `execution_id` which will be incremented on every stack operation. + // This way we know in which order the stack/tree was traversed and can + // use that info to determine the execution order of the single branch/pipelines + let mut pipeline_trees: Vec = vec![vec![]]; + // keep the counter global so that the order will match traversal order + let mut execution_id = 0; + + use IR::*; + while let Some(StackFrame { + node: mut root, + mut state, + mut current_idx, + insert_sink, + }) = stack.pop() + { + if insert_sink { + root = insert_file_sink(root, lp_arena); + } + state.execution_id = execution_id; + execution_id += 1; + match lp_arena.get(root) { + Filter { input, predicate } if is_elementwise_rec(predicate.node(), expr_arena) => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Operator(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + HStack { input, exprs, .. } if all_elementwise(exprs, expr_arena) => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Operator(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + Slice { input, offset, .. } if *offset >= 0 => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Sink(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + Sink { input, .. } => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Sink(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + Sort { + input, + by_column, + slice, + sort_options, + } if is_streamable_sort(slice, sort_options) && all_column(by_column, expr_arena) => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Sink(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + Select { input, expr, .. } if all_elementwise(expr, expr_arena) => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Operator(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + SimpleProjection { input, .. } => { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Operator(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + // Rechunks are ignored + MapFunction { + input, + function: FunctionIR::Rechunk, + } => { + state.streamable = true; + stack.push(StackFrame::new(*input, state, current_idx)) + }, + // Streamable functions will be converted + lp @ MapFunction { input, function } => { + if function.is_streamable() { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Operator(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + } else { + process_non_streamable_node( + &mut current_idx, + &mut state, + &mut stack, + scratch, + &mut pipeline_trees, + lp, + ) + } + }, + Scan { + scan_type, + unified_scan_args, + .. + } if scan_type.streamable() + && matches!( + &unified_scan_args.pre_slice, + None | Some(polars_utils::slice_enum::Slice::Positive { .. }) + ) => + { + if state.streamable { + state.sources.push(root); + pipeline_trees[current_idx].push(state) + } + }, + DataFrameScan { .. } => { + if state.streamable { + state.sources.push(root); + pipeline_trees[current_idx].push(state) + } + }, + Join { + input_left, + input_right, + options, + .. + } if streamable_join(&options.args) => { + let input_left = *input_left; + let input_right = *input_right; + state.streamable = true; + state.join_count += 1; + + // We swap so that the build phase contains the smallest table + // and then we stream the larger table + // *except* for a left join. In a left join we use the right + // table as build table and we stream the left table. This way + // we maintain order in the left join. + let (input_left, input_right) = if swap_join_order(options) { + (input_right, input_left) + } else { + (input_left, input_right) + }; + let mut state_left = state.split(); + + // Rhs is second, so that is first on the stack. + let mut state_right = state; + state_right.join_count = 0; + state_right + .operators_sinks + .push(PipelineNode::RhsJoin(root)); + + // We want to traverse lhs last, so push it first on the stack + // rhs is a new pipeline. + state_left.operators_sinks.push(PipelineNode::Sink(root)); + stack.push(StackFrame::new(input_left, state_left, current_idx)); + stack.push(StackFrame::new(input_right, state_right, current_idx)); + }, + // add globbing patterns + #[cfg(any(feature = "csv", feature = "parquet"))] + Union { inputs, options } + if options.slice.is_none() + && inputs.iter().all(|node| match lp_arena.get(*node) { + Scan { .. } => true, + MapFunction { + input, + function: FunctionIR::Rechunk, + } => matches!(lp_arena.get(*input), Scan { .. }), + _ => false, + }) => + { + state.sources.push(root); + pipeline_trees[current_idx].push(state); + }, + Union { inputs, .. } => { + { + state.streamable = true; + for (i, input) in inputs.iter().enumerate() { + let mut state = if i == 0 { + // note the clone! + let mut state = state.clone(); + state.join_count += inputs.len() as u32 - 1; + state + } else { + let mut state = state.split_from_sink(); + state.join_count = 0; + state + }; + state.operators_sinks.push(PipelineNode::Union(root)); + stack.push(StackFrame::new(*input, state, current_idx)); + } + } + }, + Distinct { input, options } + if !options.maintain_order + && !matches!(options.keep_strategy, UniqueKeepStrategy::None) => + { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Sink(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + }, + #[allow(unused_variables)] + lp @ GroupBy { + input, + keys, + aggs, + maintain_order: false, + apply: None, + schema: output_schema, + options, + .. + } => { + #[cfg(feature = "dtype-categorical")] + let string_cache = polars_core::using_string_cache(); + #[cfg(not(feature = "dtype-categorical"))] + let string_cache = true; + + #[allow(unused_variables)] + fn allowed_dtype(dt: &DataType, string_cache: bool) -> bool { + match dt { + #[cfg(feature = "object")] + DataType::Object(_) => false, + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) => string_cache, + DataType::List(inner) => allowed_dtype(inner, string_cache), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => fields + .iter() + .all(|fld| allowed_dtype(fld.dtype(), string_cache)), + // We need to be able to sink to disk or produce the aggregate return dtype. + DataType::Unknown(_) => false, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => false, + DataType::Int128 => false, + _ => true, + } + } + let input_schema = lp_arena.get(*input).schema(lp_arena); + #[allow(unused_mut)] + let mut can_stream = true; + + #[cfg(feature = "dynamic_group_by")] + { + if options.rolling.is_some() || options.dynamic.is_some() { + can_stream = false + } + } + + let valid_agg = || { + aggs.iter().all(|e| { + polars_pipe::pipeline::can_convert_to_hash_agg( + e.node(), + expr_arena, + &input_schema, + ) + }) + }; + + let valid_key = || { + keys.iter().all(|e| { + output_schema + .get(e.output_name()) + .map(|dt| !matches!(dt, DataType::List(_))) + .unwrap_or(false) + }) + }; + + let valid_types = || { + output_schema + .iter_values() + .all(|dt| allowed_dtype(dt, string_cache)) + }; + + if can_stream && valid_agg() && valid_key() && valid_types() { + state.streamable = true; + state.operators_sinks.push(PipelineNode::Sink(root)); + stack.push(StackFrame::new(*input, state, current_idx)) + } else if allow_partial { + process_non_streamable_node( + &mut current_idx, + &mut state, + &mut stack, + scratch, + &mut pipeline_trees, + lp, + ) + } else { + return Ok(false); + } + }, + lp => { + if allow_partial { + process_non_streamable_node( + &mut current_idx, + &mut state, + &mut stack, + scratch, + &mut pipeline_trees, + lp, + ) + } else { + return Ok(false); + } + }, + } + } + + let mut inserted = false; + for tree in pipeline_trees { + if is_valid_tree(&tree) + && super::construct_pipeline::construct(tree, lp_arena, expr_arena, fmt)?.is_some() + { + inserted = true; + } + } + + Ok(inserted) +} diff --git a/crates/polars-lazy/src/physical_plan/streaming/mod.rs b/crates/polars-lazy/src/physical_plan/streaming/mod.rs new file mode 100644 index 000000000000..893ff7461f3c --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/streaming/mod.rs @@ -0,0 +1,6 @@ +mod checks; +mod construct_pipeline; +mod convert_alp; +mod tree; + +pub(crate) use convert_alp::insert_streaming_nodes; diff --git a/crates/polars-lazy/src/physical_plan/streaming/tree.rs b/crates/polars-lazy/src/physical_plan/streaming/tree.rs new file mode 100644 index 000000000000..db25b429bfe3 --- /dev/null +++ b/crates/polars-lazy/src/physical_plan/streaming/tree.rs @@ -0,0 +1,195 @@ +use std::collections::BTreeSet; +use std::fmt::Debug; + +use polars_plan::prelude::*; + +#[derive(Copy, Clone, Debug)] +pub(super) enum PipelineNode { + Sink(Node), + Operator(Node), + RhsJoin(Node), + Union(Node), +} + +impl PipelineNode { + pub(super) fn node(self) -> Node { + match self { + Self::Sink(node) => node, + Self::Operator(node) => node, + Self::RhsJoin(node) => node, + Self::Union(node) => node, + } + } +} + +/// Represents a pipeline/ branch in a subquery tree +#[derive(Default, Debug, Clone)] +pub(super) struct Branch { + // During traversal of ALP + // we determine the execution order + // as traversal order == execution order + // we can increment this counter + // the individual branches are then flattened + // sorted and executed in reversed order + // (to traverse from leaves to root) + pub(super) execution_id: u32, + pub(super) streamable: bool, + pub(super) sources: Vec, + // joins seen in whole branch (we count a union as joins with multiple counts) + pub(super) join_count: u32, + // node is operator/sink + pub(super) operators_sinks: Vec, +} + +fn sink_node(pl_node: &PipelineNode) -> Option { + match pl_node { + PipelineNode::Sink(node) => Some(*node), + _ => None, + } +} + +impl Branch { + pub(super) fn get_final_sink(&self) -> Option { + // this is still in the order of discovery + // so the first sink is the final one. + self.operators_sinks.iter().find_map(sink_node) + } + pub(super) fn split(&self) -> Self { + Self { + execution_id: self.execution_id, + streamable: self.streamable, + join_count: self.join_count, + ..Default::default() + } + } + + /// this will share the sink + pub(super) fn split_from_sink(&self) -> Self { + match self + .operators_sinks + .iter() + .rposition(|pl_node| sink_node(pl_node).is_some()) + { + None => self.split(), + Some(pos) => Self { + execution_id: self.execution_id, + streamable: self.streamable, + join_count: self.join_count, + operators_sinks: self.operators_sinks[pos..].to_vec(), + ..Default::default() + }, + } + } +} + +/// Represents a subquery tree of pipelines. +type TreeRef<'a> = &'a [Branch]; +pub(super) type Tree = Vec; + +/// We validate a tree in order to check if it is eligible for streaming. +/// It could be that a join branch wasn't added during collection of branches +/// (because it contained a non-streamable node). This function checks if every join +/// node has a match. +pub(super) fn is_valid_tree(tree: TreeRef) -> bool { + if tree.is_empty() { + return false; + }; + let joins_in_tree = tree.iter().map(|branch| branch.join_count).sum::(); + let branches_in_tree = tree.len() as u32; + + // all join branches should be added, if not we skip the tree, as it is invalid + if (branches_in_tree - 1) != joins_in_tree { + return false; + } + + // rhs joins will initially be placeholders + let mut left_joins = BTreeSet::new(); + for branch in tree { + for pl_node in &branch.operators_sinks { + if !matches!(pl_node, PipelineNode::RhsJoin(_)) { + left_joins.insert(pl_node.node().0); + } + } + } + for branch in tree { + for pl_node in &branch.operators_sinks { + // check if every rhs join has a lhs join node + if matches!(pl_node, PipelineNode::RhsJoin(_)) + && !left_joins.contains(&pl_node.node().0) + { + return false; + } + } + } + true +} + +#[cfg(debug_assertions)] +#[allow(unused)] +pub(super) fn dbg_branch(b: &Branch, lp_arena: &Arena) { + // streamable: bool, + // sources: Vec, + // // joins seen in whole branch (we count a union as joins with multiple counts) + // join_count: IdxSize, + // // node is operator/sink + // operators_sinks: Vec<(IsSink, IsRhsJoin, Node)>, + + if b.streamable { + print!("streamable: ") + } else { + print!("non-streamable: ") + } + for src in &b.sources { + let lp = lp_arena.get(*src); + print!("{}, ", lp.name()); + } + print!("=> "); + + for pl_node in &b.operators_sinks { + let lp = lp_arena.get(pl_node.node()); + if matches!(pl_node, PipelineNode::RhsJoin(_)) { + print!("rhs_join_placeholder -> "); + } else { + print!("{} -> ", lp.name()); + } + } + println!(); +} + +#[cfg(debug_assertions)] +#[allow(unused)] +pub(super) fn dbg_tree(tree: Tree, lp_arena: &Arena, expr_arena: &Arena) { + if tree.is_empty() { + println!("EMPTY TREE"); + return; + } + let root = tree + .iter() + .map(|branch| { + let pl_node = branch.operators_sinks.last().unwrap(); + pl_node.node() + }) + .max_by_key(|root| { + // count the children of this root + // the branch with the most children is the root of the whole tree + lp_arena.iter(*root).count() + }) + .unwrap(); + + println!("SUBPLAN ELIGIBLE FOR STREAMING:"); + println!( + "{}\n", + IRPlanRef { + lp_top: root, + lp_arena, + expr_arena + } + .display() + ); + + println!("PIPELINE TREE:"); + for (i, branch) in tree.iter().enumerate() { + print!("{i}: "); + dbg_branch(branch, lp_arena); + } +} diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs new file mode 100644 index 000000000000..7de15a0f6622 --- /dev/null +++ b/crates/polars-lazy/src/prelude.rs @@ -0,0 +1,27 @@ +pub(crate) use polars_expr::prelude::*; +#[cfg(feature = "csv")] +pub use polars_io::csv::write::CsvWriterOptions; +#[cfg(feature = "ipc")] +pub use polars_io::ipc::IpcWriterOptions; +#[cfg(feature = "json")] +pub use polars_io::json::JsonWriterOptions; +#[cfg(feature = "parquet")] +pub use polars_io::parquet::write::ParquetWriteOptions; +pub use polars_ops::prelude::{JoinArgs, JoinType, JoinValidation}; +#[cfg(feature = "rank")] +pub use polars_ops::prelude::{RankMethod, RankOptions}; +#[cfg(feature = "polars_cloud")] +pub use polars_plan::client::prepare_cloud_plan; +pub use polars_plan::dsl::AnonymousScanOptions; +pub use polars_plan::plans::{AnonymousScan, AnonymousScanArgs, Literal, LiteralValue, NULL, Null}; +pub use polars_plan::prelude::UnionArgs; +pub(crate) use polars_plan::prelude::*; +#[cfg(feature = "rolling_window_by")] +pub use polars_time::Duration; +#[cfg(feature = "dynamic_group_by")] +pub use polars_time::{DynamicGroupOptions, PolarsTemporalGroupby, RollingGroupOptions}; +pub(crate) use polars_utils::arena::{Arena, Node}; + +pub use crate::dsl::*; +pub use crate::frame::*; +pub(crate) use crate::scan::*; diff --git a/crates/polars-lazy/src/scan/anonymous_scan.rs b/crates/polars-lazy/src/scan/anonymous_scan.rs new file mode 100644 index 000000000000..273c8fb1bbb2 --- /dev/null +++ b/crates/polars-lazy/src/scan/anonymous_scan.rs @@ -0,0 +1,69 @@ +use polars_core::prelude::*; +use polars_io::{HiveOptions, RowIndex}; +use polars_utils::slice_enum::Slice; + +use crate::prelude::*; + +#[derive(Clone)] +pub struct ScanArgsAnonymous { + pub infer_schema_length: Option, + pub schema: Option, + pub skip_rows: Option, + pub n_rows: Option, + pub row_index: Option, + pub name: &'static str, +} + +impl Default for ScanArgsAnonymous { + fn default() -> Self { + Self { + infer_schema_length: None, + skip_rows: None, + n_rows: None, + schema: None, + row_index: None, + name: "ANONYMOUS SCAN", + } + } +} +impl LazyFrame { + pub fn anonymous_scan( + function: Arc, + args: ScanArgsAnonymous, + ) -> PolarsResult { + let schema = match args.schema { + Some(s) => s, + None => function.schema(args.infer_schema_length)?, + }; + + let mut lf: LazyFrame = DslBuilder::anonymous_scan( + function, + AnonymousScanOptions { + skip_rows: args.skip_rows, + fmt_str: args.name, + }, + UnifiedScanArgs { + schema: Some(schema), + cloud_options: None, + hive_options: HiveOptions::new_disabled(), + rechunk: false, + cache: false, + glob: false, + projection: None, + row_index: None, + pre_slice: args.n_rows.map(|len| Slice::Positive { offset: 0, len }), + cast_columns_policy: CastColumnsPolicy::ErrorOnMismatch, + missing_columns_policy: MissingColumnsPolicy::Raise, + include_file_paths: None, + }, + )? + .build() + .into(); + + if let Some(rc) = args.row_index { + lf = lf.with_row_index(rc.name.clone(), Some(rc.offset)) + }; + + Ok(lf) + } +} diff --git a/crates/polars-lazy/src/scan/catalog.rs b/crates/polars-lazy/src/scan/catalog.rs new file mode 100644 index 000000000000..94ae9a1f7624 --- /dev/null +++ b/crates/polars-lazy/src/scan/catalog.rs @@ -0,0 +1,55 @@ +use polars_core::error::{PolarsResult, feature_gated, polars_bail}; +use polars_io::catalog::unity::models::{DataSourceFormat, TableInfo}; +use polars_io::catalog::unity::schema::table_info_to_schemas; +use polars_io::cloud::CloudOptions; + +use crate::frame::LazyFrame; + +impl LazyFrame { + pub fn scan_catalog_table( + table_info: &TableInfo, + cloud_options: Option, + ) -> PolarsResult { + let Some(data_source_format) = &table_info.data_source_format else { + polars_bail!(ComputeError: "scan_catalog_table requires Some(_) for data_source_format") + }; + + let Some(storage_location) = table_info.storage_location.as_deref() else { + polars_bail!(ComputeError: "scan_catalog_table requires Some(_) for storage_location") + }; + + match data_source_format { + DataSourceFormat::Parquet => feature_gated!("parquet", { + use polars_io::HiveOptions; + + use crate::frame::ScanArgsParquet; + let (schema, hive_schema) = table_info_to_schemas(table_info)?; + + let args = ScanArgsParquet { + schema, + cloud_options, + hive_options: HiveOptions { + schema: hive_schema, + ..Default::default() + }, + ..Default::default() + }; + + Self::scan_parquet(storage_location, args) + }), + DataSourceFormat::Csv => feature_gated!("csv", { + use crate::frame::{LazyCsvReader, LazyFileListReader}; + let (schema, _) = table_info_to_schemas(table_info)?; + + LazyCsvReader::new(storage_location) + .with_schema(schema) + .finish() + }), + v => polars_bail!( + ComputeError: + "not yet supported data_source_format: {:?}", + v + ), + } + } +} diff --git a/crates/polars-lazy/src/scan/csv.rs b/crates/polars-lazy/src/scan/csv.rs new file mode 100644 index 000000000000..17f23f39c245 --- /dev/null +++ b/crates/polars-lazy/src/scan/csv.rs @@ -0,0 +1,416 @@ +use std::path::{Path, PathBuf}; + +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_io::csv::read::{ + CommentPrefix, CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues, infer_file_schema, +}; +use polars_io::path_utils::expand_paths; +use polars_io::utils::compression::maybe_decompress_bytes; +use polars_io::utils::get_reader_bytes; +use polars_io::{HiveOptions, RowIndex}; +use polars_utils::mmap::MemSlice; +use polars_utils::slice_enum::Slice; + +use crate::prelude::*; + +#[derive(Clone)] +#[cfg(feature = "csv")] +pub struct LazyCsvReader { + sources: ScanSources, + glob: bool, + cache: bool, + read_options: CsvReadOptions, + cloud_options: Option, + include_file_paths: Option, +} + +#[cfg(feature = "csv")] +impl LazyCsvReader { + /// Re-export to shorten code. + pub fn map_parse_options CsvParseOptions>( + mut self, + map_func: F, + ) -> Self { + self.read_options = self.read_options.map_parse_options(map_func); + self + } + + pub fn new_paths(paths: Arc<[PathBuf]>) -> Self { + Self::new_with_sources(ScanSources::Paths(paths)) + } + + pub fn new_with_sources(sources: ScanSources) -> Self { + LazyCsvReader { + sources, + glob: true, + cache: true, + read_options: Default::default(), + cloud_options: Default::default(), + include_file_paths: None, + } + } + + pub fn new(path: impl AsRef) -> Self { + Self::new_with_sources(ScanSources::Paths([path.as_ref().to_path_buf()].into())) + } + + /// Skip this number of rows after the header location. + #[must_use] + pub fn with_skip_rows_after_header(mut self, offset: usize) -> Self { + self.read_options.skip_rows_after_header = offset; + self + } + + /// Add a row index column. + #[must_use] + pub fn with_row_index(mut self, row_index: Option) -> Self { + self.read_options.row_index = row_index; + self + } + + /// Try to stop parsing when `n` rows are parsed. During multithreaded parsing the upper bound `n` cannot + /// be guaranteed. + #[must_use] + pub fn with_n_rows(mut self, num_rows: Option) -> Self { + self.read_options.n_rows = num_rows; + self + } + + /// Set the number of rows to use when inferring the csv schema. + /// The default is 100 rows. + /// Setting to [None] will do a full table scan, which is very slow. + #[must_use] + pub fn with_infer_schema_length(mut self, num_rows: Option) -> Self { + self.read_options.infer_schema_length = num_rows; + self + } + + /// Continue with next batch when a ParserError is encountered. + #[must_use] + pub fn with_ignore_errors(mut self, ignore: bool) -> Self { + self.read_options.ignore_errors = ignore; + self + } + + /// Set the CSV file's schema + #[must_use] + pub fn with_schema(mut self, schema: Option) -> Self { + self.read_options.schema = schema; + self + } + + /// Skip the first `n` rows during parsing. The header will be parsed at row `n`. + /// Note that by row we mean valid CSV, encoding and comments are respected. + #[must_use] + pub fn with_skip_rows(mut self, skip_rows: usize) -> Self { + self.read_options.skip_rows = skip_rows; + self + } + + /// Skip the first `n` lines during parsing. The header will be parsed at line `n`. + /// We don't respect CSV escaping when skipping lines. + #[must_use] + pub fn with_skip_lines(mut self, skip_lines: usize) -> Self { + self.read_options.skip_lines = skip_lines; + self + } + + /// Overwrite the schema with the dtypes in this given Schema. The given schema may be a subset + /// of the total schema. + #[must_use] + pub fn with_dtype_overwrite(mut self, schema: Option) -> Self { + self.read_options.schema_overwrite = schema; + self + } + + /// Set whether the CSV file has headers + #[must_use] + pub fn with_has_header(mut self, has_header: bool) -> Self { + self.read_options.has_header = has_header; + self + } + + /// Sets the chunk size used by the parser. This influences performance. + /// This can be used as a way to reduce memory usage during the parsing at the cost of performance. + pub fn with_chunk_size(mut self, chunk_size: usize) -> Self { + self.read_options.chunk_size = chunk_size; + self + } + + /// Set the CSV file's column separator as a byte character + #[must_use] + pub fn with_separator(self, separator: u8) -> Self { + self.map_parse_options(|opts| opts.with_separator(separator)) + } + + /// Set the comment prefix for this instance. Lines starting with this prefix will be ignored. + #[must_use] + pub fn with_comment_prefix(self, comment_prefix: Option) -> Self { + self.map_parse_options(|opts| { + opts.with_comment_prefix(comment_prefix.clone().map(|s| { + if s.len() == 1 && s.chars().next().unwrap().is_ascii() { + CommentPrefix::Single(s.as_bytes()[0]) + } else { + CommentPrefix::Multi(s) + } + })) + }) + } + + /// Set the `char` used as quote char. The default is `b'"'`. If set to [`None`] quoting is disabled. + #[must_use] + pub fn with_quote_char(self, quote_char: Option) -> Self { + self.map_parse_options(|opts| opts.with_quote_char(quote_char)) + } + + /// Set the `char` used as end of line. The default is `b'\n'`. + #[must_use] + pub fn with_eol_char(self, eol_char: u8) -> Self { + self.map_parse_options(|opts| opts.with_eol_char(eol_char)) + } + + /// Set values that will be interpreted as missing/ null. + #[must_use] + pub fn with_null_values(self, null_values: Option) -> Self { + self.map_parse_options(|opts| opts.with_null_values(null_values.clone())) + } + + /// Treat missing fields as null. + pub fn with_missing_is_null(self, missing_is_null: bool) -> Self { + self.map_parse_options(|opts| opts.with_missing_is_null(missing_is_null)) + } + + /// Cache the DataFrame after reading. + #[must_use] + pub fn with_cache(mut self, cache: bool) -> Self { + self.cache = cache; + self + } + + /// Reduce memory usage at the expense of performance + #[must_use] + pub fn with_low_memory(mut self, low_memory: bool) -> Self { + self.read_options.low_memory = low_memory; + self + } + + /// Set [`CsvEncoding`] + #[must_use] + pub fn with_encoding(self, encoding: CsvEncoding) -> Self { + self.map_parse_options(|opts| opts.with_encoding(encoding)) + } + + /// Automatically try to parse dates/datetimes and time. + /// If parsing fails, columns remain of dtype [`DataType::String`]. + #[cfg(feature = "temporal")] + pub fn with_try_parse_dates(self, try_parse_dates: bool) -> Self { + self.map_parse_options(|opts| opts.with_try_parse_dates(try_parse_dates)) + } + + /// Raise an error if CSV is empty (otherwise return an empty frame) + #[must_use] + pub fn with_raise_if_empty(mut self, raise_if_empty: bool) -> Self { + self.read_options.raise_if_empty = raise_if_empty; + self + } + + /// Truncate lines that are longer than the schema. + #[must_use] + pub fn with_truncate_ragged_lines(self, truncate_ragged_lines: bool) -> Self { + self.map_parse_options(|opts| opts.with_truncate_ragged_lines(truncate_ragged_lines)) + } + + #[must_use] + pub fn with_decimal_comma(self, decimal_comma: bool) -> Self { + self.map_parse_options(|opts| opts.with_decimal_comma(decimal_comma)) + } + + #[must_use] + /// Expand path given via globbing rules. + pub fn with_glob(mut self, toggle: bool) -> Self { + self.glob = toggle; + self + } + + pub fn with_cloud_options(mut self, cloud_options: Option) -> Self { + self.cloud_options = cloud_options; + self + } + + /// Modify a schema before we run the lazy scanning. + /// + /// Important! Run this function latest in the builder! + pub fn with_schema_modify(mut self, f: F) -> PolarsResult + where + F: Fn(Schema) -> PolarsResult, + { + let n_threads = self.read_options.n_threads; + + let infer_schema = |bytes: MemSlice| { + let skip_rows = self.read_options.skip_rows; + let skip_lines = self.read_options.skip_lines; + let parse_options = self.read_options.get_parse_options(); + + let mut owned = vec![]; + let bytes = maybe_decompress_bytes(bytes.as_ref(), &mut owned)?; + + PolarsResult::Ok( + infer_file_schema( + &get_reader_bytes(&mut std::io::Cursor::new(bytes))?, + &parse_options, + self.read_options.infer_schema_length, + self.read_options.has_header, + // we set it to None and modify them after the schema is updated + None, + skip_rows, + skip_lines, + self.read_options.skip_rows_after_header, + self.read_options.raise_if_empty, + )? + .0, + ) + }; + + let schema = match self.sources.clone() { + ScanSources::Paths(paths) => { + // TODO: Path expansion should happen when converting to the IR + // https://github.com/pola-rs/polars/issues/17634 + let paths = expand_paths(&paths[..], self.glob(), self.cloud_options())?; + + let Some(path) = paths.first() else { + polars_bail!(ComputeError: "no paths specified for this reader"); + }; + + infer_schema(MemSlice::from_file(&polars_utils::open_file(path)?)?)? + }, + ScanSources::Files(files) => { + let Some(file) = files.first() else { + polars_bail!(ComputeError: "no buffers specified for this reader"); + }; + + infer_schema(MemSlice::from_file(file)?)? + }, + ScanSources::Buffers(buffers) => { + let Some(buffer) = buffers.first() else { + polars_bail!(ComputeError: "no buffers specified for this reader"); + }; + + infer_schema(buffer.clone())? + }, + }; + + self.read_options.n_threads = n_threads; + let mut schema = f(schema)?; + + // the dtypes set may be for the new names, so update again + if let Some(overwrite_schema) = &self.read_options.schema_overwrite { + for (name, dtype) in overwrite_schema.iter() { + schema.with_column(name.clone(), dtype.clone()); + } + } + + Ok(self.with_schema(Some(Arc::new(schema)))) + } + + pub fn with_include_file_paths(mut self, include_file_paths: Option) -> Self { + self.include_file_paths = include_file_paths; + self + } +} + +impl LazyFileListReader for LazyCsvReader { + /// Get the final [LazyFrame]. + fn finish(self) -> PolarsResult { + let rechunk = self.rechunk(); + let row_index = self.row_index().cloned(); + let pre_slice = self.n_rows().map(|len| Slice::Positive { offset: 0, len }); + + let lf: LazyFrame = DslBuilder::scan_csv( + self.sources, + self.read_options, + UnifiedScanArgs { + schema: None, + cloud_options: self.cloud_options, + hive_options: HiveOptions::new_disabled(), + rechunk, + cache: self.cache, + glob: self.glob, + projection: None, + row_index, + pre_slice, + cast_columns_policy: CastColumnsPolicy::ErrorOnMismatch, + missing_columns_policy: MissingColumnsPolicy::Raise, + include_file_paths: self.include_file_paths, + }, + )? + .build() + .into(); + Ok(lf) + } + + fn finish_no_glob(self) -> PolarsResult { + unreachable!(); + } + + fn glob(&self) -> bool { + self.glob + } + + fn sources(&self) -> &ScanSources { + &self.sources + } + + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; + self + } + + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.read_options.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.read_options.row_index = row_index.into(); + self + } + + fn rechunk(&self) -> bool { + self.read_options.rechunk + } + + /// Rechunk the memory to contiguous chunks when parsing is done. + fn with_rechunk(mut self, rechunk: bool) -> Self { + self.read_options.rechunk = rechunk; + self + } + + /// Try to stop parsing when `n` rows are parsed. During multithreaded parsing the upper bound `n` cannot + /// be guaranteed. + fn n_rows(&self) -> Option { + self.read_options.n_rows + } + + /// Return the row index settings. + fn row_index(&self) -> Option<&RowIndex> { + self.read_options.row_index.as_ref() + } + + fn concat_impl(&self, lfs: Vec) -> PolarsResult { + // set to false, as the csv parser has full thread utilization + let args = UnionArgs { + rechunk: self.rechunk(), + parallel: false, + to_supertypes: false, + from_partitioned_ds: true, + ..Default::default() + }; + concat_impl(&lfs, args) + } + + /// [CloudOptions] used to list files. + fn cloud_options(&self) -> Option<&CloudOptions> { + self.cloud_options.as_ref() + } +} diff --git a/crates/polars-lazy/src/scan/file_list_reader.rs b/crates/polars-lazy/src/scan/file_list_reader.rs new file mode 100644 index 000000000000..4fe5293d0991 --- /dev/null +++ b/crates/polars-lazy/src/scan/file_list_reader.rs @@ -0,0 +1,123 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use polars_core::prelude::*; +use polars_io::RowIndex; +use polars_io::cloud::CloudOptions; +use polars_plan::prelude::UnionArgs; + +use crate::prelude::*; + +/// Reads [LazyFrame] from a filesystem or a cloud storage. +/// Supports glob patterns. +/// +/// Use [LazyFileListReader::finish] to get the final [LazyFrame]. +pub trait LazyFileListReader: Clone { + /// Get the final [LazyFrame]. + fn finish(self) -> PolarsResult { + if !self.glob() { + return self.finish_no_glob(); + } + + let ScanSources::Paths(paths) = self.sources() else { + unreachable!("opened-files or in-memory buffers should never be globbed"); + }; + + let lfs = paths + .iter() + .map(|path| { + self.clone() + // Each individual reader should not apply a row limit. + .with_n_rows(None) + // Each individual reader should not apply a row index. + .with_row_index(None) + .with_paths([path.clone()].into()) + .with_rechunk(false) + .finish_no_glob() + .map_err(|e| { + polars_err!( + ComputeError: "error while reading {}: {}", path.display(), e + ) + }) + }) + .collect::>>()?; + + polars_ensure!( + !lfs.is_empty(), + ComputeError: "no matching files found in {:?}", paths.iter().map(|x| x.to_str().unwrap()).collect::>() + ); + + let mut lf = self.concat_impl(lfs)?; + if let Some(n_rows) = self.n_rows() { + lf = lf.slice(0, n_rows as IdxSize) + }; + if let Some(rc) = self.row_index() { + lf = lf.with_row_index(rc.name.clone(), Some(rc.offset)) + }; + + Ok(lf) + } + + /// Recommended concatenation of [LazyFrame]s from many input files. + /// + /// This method should not take into consideration [LazyFileListReader::n_rows] + /// nor [LazyFileListReader::row_index]. + fn concat_impl(&self, lfs: Vec) -> PolarsResult { + let args = UnionArgs { + rechunk: self.rechunk(), + parallel: true, + to_supertypes: false, + from_partitioned_ds: true, + ..Default::default() + }; + concat_impl(&lfs, args) + } + + /// Get the final [LazyFrame]. + /// This method assumes, that path is *not* a glob. + /// + /// It is recommended to always use [LazyFileListReader::finish] method. + fn finish_no_glob(self) -> PolarsResult; + + fn glob(&self) -> bool { + true + } + + /// Get the sources for this reader. + fn sources(&self) -> &ScanSources; + + /// Set sources of the scanned files. + #[must_use] + fn with_sources(self, source: ScanSources) -> Self; + + /// Set paths of the scanned files. + #[must_use] + fn with_paths(self, paths: Arc<[PathBuf]>) -> Self { + self.with_sources(ScanSources::Paths(paths)) + } + + /// Configure the row limit. + fn with_n_rows(self, n_rows: impl Into>) -> Self; + + /// Configure the row index. + fn with_row_index(self, row_index: impl Into>) -> Self; + + /// Rechunk the memory to contiguous chunks when parsing is done. + fn rechunk(&self) -> bool; + + /// Rechunk the memory to contiguous chunks when parsing is done. + #[must_use] + fn with_rechunk(self, toggle: bool) -> Self; + + /// Try to stop parsing when `n` rows are parsed. During multithreaded parsing the upper bound `n` cannot + /// be guaranteed. + fn n_rows(&self) -> Option; + + /// Add a row index column. + fn row_index(&self) -> Option<&RowIndex>; + + /// [CloudOptions] used to list files. + fn cloud_options(&self) -> Option<&CloudOptions> { + None + } +} diff --git a/crates/polars-lazy/src/scan/ipc.rs b/crates/polars-lazy/src/scan/ipc.rs new file mode 100644 index 000000000000..bd675e1e10f8 --- /dev/null +++ b/crates/polars-lazy/src/scan/ipc.rs @@ -0,0 +1,151 @@ +use std::path::{Path, PathBuf}; + +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_io::ipc::IpcScanOptions; +use polars_io::{HiveOptions, RowIndex}; +use polars_utils::slice_enum::Slice; + +use crate::prelude::*; + +#[derive(Clone)] +pub struct ScanArgsIpc { + pub n_rows: Option, + pub cache: bool, + pub rechunk: bool, + pub row_index: Option, + pub cloud_options: Option, + pub hive_options: HiveOptions, + pub include_file_paths: Option, +} + +impl Default for ScanArgsIpc { + fn default() -> Self { + Self { + n_rows: None, + cache: true, + rechunk: false, + row_index: None, + cloud_options: Default::default(), + hive_options: Default::default(), + include_file_paths: None, + } + } +} + +#[derive(Clone)] +struct LazyIpcReader { + args: ScanArgsIpc, + sources: ScanSources, +} + +impl LazyIpcReader { + fn new(args: ScanArgsIpc) -> Self { + Self { + args, + sources: ScanSources::default(), + } + } +} + +impl LazyFileListReader for LazyIpcReader { + fn finish(self) -> PolarsResult { + let args = self.args; + + let options = IpcScanOptions {}; + let pre_slice = args.n_rows.map(|len| Slice::Positive { offset: 0, len }); + + let cloud_options = args.cloud_options; + let hive_options = args.hive_options; + let rechunk = args.rechunk; + let cache = args.cache; + let row_index = args.row_index; + let include_file_paths = args.include_file_paths; + + let lf: LazyFrame = DslBuilder::scan_ipc( + self.sources, + options, + UnifiedScanArgs { + schema: None, + cloud_options, + hive_options, + rechunk, + cache, + glob: true, + projection: None, + row_index, + pre_slice, + cast_columns_policy: CastColumnsPolicy::ErrorOnMismatch, + missing_columns_policy: MissingColumnsPolicy::Raise, + include_file_paths, + }, + )? + .build() + .into(); + + Ok(lf) + } + + fn finish_no_glob(self) -> PolarsResult { + unreachable!() + } + + fn sources(&self) -> &ScanSources { + &self.sources + } + + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; + self + } + + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.args.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.args.row_index = row_index.into(); + self + } + + fn rechunk(&self) -> bool { + self.args.rechunk + } + + fn with_rechunk(mut self, toggle: bool) -> Self { + self.args.rechunk = toggle; + self + } + + fn n_rows(&self) -> Option { + self.args.n_rows + } + + fn row_index(&self) -> Option<&RowIndex> { + self.args.row_index.as_ref() + } + + /// [CloudOptions] used to list files. + fn cloud_options(&self) -> Option<&CloudOptions> { + self.args.cloud_options.as_ref() + } +} + +impl LazyFrame { + /// Create a LazyFrame directly from a ipc scan. + pub fn scan_ipc(path: impl AsRef, args: ScanArgsIpc) -> PolarsResult { + Self::scan_ipc_sources( + ScanSources::Paths([path.as_ref().to_path_buf()].into()), + args, + ) + } + + pub fn scan_ipc_files(paths: Arc<[PathBuf]>, args: ScanArgsIpc) -> PolarsResult { + Self::scan_ipc_sources(ScanSources::Paths(paths), args) + } + + pub fn scan_ipc_sources(sources: ScanSources, args: ScanArgsIpc) -> PolarsResult { + LazyIpcReader::new(args).with_sources(sources).finish() + } +} diff --git a/crates/polars-lazy/src/scan/mod.rs b/crates/polars-lazy/src/scan/mod.rs new file mode 100644 index 000000000000..86792bd8b5c9 --- /dev/null +++ b/crates/polars-lazy/src/scan/mod.rs @@ -0,0 +1,13 @@ +pub(super) mod anonymous_scan; +#[cfg(feature = "csv")] +pub(super) mod csv; +pub(super) mod file_list_reader; +#[cfg(feature = "ipc")] +pub(super) mod ipc; +#[cfg(feature = "json")] +pub(super) mod ndjson; +#[cfg(feature = "parquet")] +pub(super) mod parquet; + +#[cfg(feature = "catalog")] +mod catalog; diff --git a/crates/polars-lazy/src/scan/ndjson.rs b/crates/polars-lazy/src/scan/ndjson.rs new file mode 100644 index 000000000000..9e0a92d6926d --- /dev/null +++ b/crates/polars-lazy/src/scan/ndjson.rs @@ -0,0 +1,210 @@ +use std::num::NonZeroUsize; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_io::{HiveOptions, RowIndex}; +use polars_plan::dsl::{CastColumnsPolicy, DslPlan, FileScan, MissingColumnsPolicy, ScanSources}; +use polars_plan::prelude::{NDJsonReadOptions, UnifiedScanArgs}; +use polars_utils::slice_enum::Slice; + +use crate::prelude::LazyFrame; +use crate::scan::file_list_reader::LazyFileListReader; + +#[derive(Clone)] +pub struct LazyJsonLineReader { + pub(crate) sources: ScanSources, + pub(crate) batch_size: Option, + pub(crate) low_memory: bool, + pub(crate) rechunk: bool, + pub(crate) schema: Option, + pub(crate) schema_overwrite: Option, + pub(crate) row_index: Option, + pub(crate) infer_schema_length: Option, + pub(crate) n_rows: Option, + pub(crate) ignore_errors: bool, + pub(crate) include_file_paths: Option, + pub(crate) cloud_options: Option, +} + +impl LazyJsonLineReader { + pub fn new_paths(paths: Arc<[PathBuf]>) -> Self { + Self::new_with_sources(ScanSources::Paths(paths)) + } + + pub fn new_with_sources(sources: ScanSources) -> Self { + LazyJsonLineReader { + sources, + batch_size: None, + low_memory: false, + rechunk: false, + schema: None, + schema_overwrite: None, + row_index: None, + infer_schema_length: NonZeroUsize::new(100), + ignore_errors: false, + n_rows: None, + include_file_paths: None, + cloud_options: None, + } + } + + pub fn new(path: impl AsRef) -> Self { + Self::new_with_sources(ScanSources::Paths([path.as_ref().to_path_buf()].into())) + } + + /// Add a row index column. + #[must_use] + pub fn with_row_index(mut self, row_index: Option) -> Self { + self.row_index = row_index; + self + } + + /// Set values as `Null` if parsing fails because of schema mismatches. + #[must_use] + pub fn with_ignore_errors(mut self, ignore_errors: bool) -> Self { + self.ignore_errors = ignore_errors; + self + } + /// Try to stop parsing when `n` rows are parsed. During multithreaded parsing the upper bound `n` cannot + /// be guaranteed. + #[must_use] + pub fn with_n_rows(mut self, num_rows: Option) -> Self { + self.n_rows = num_rows; + self + } + /// Set the number of rows to use when inferring the json schema. + /// the default is 100 rows. + /// Ignored when the schema is specified explicitly using [`Self::with_schema`]. + /// Setting to `None` will do a full table scan, very slow. + #[must_use] + pub fn with_infer_schema_length(mut self, num_rows: Option) -> Self { + self.infer_schema_length = num_rows; + self + } + /// Set the JSON file's schema + #[must_use] + pub fn with_schema(mut self, schema: Option) -> Self { + self.schema = schema; + self + } + + /// Set the JSON file's schema + #[must_use] + pub fn with_schema_overwrite(mut self, schema_overwrite: Option) -> Self { + self.schema_overwrite = schema_overwrite; + self + } + + /// Reduce memory usage at the expense of performance + #[must_use] + pub fn low_memory(mut self, toggle: bool) -> Self { + self.low_memory = toggle; + self + } + + #[must_use] + pub fn with_batch_size(mut self, batch_size: Option) -> Self { + self.batch_size = batch_size; + self + } + + pub fn with_cloud_options(mut self, cloud_options: Option) -> Self { + self.cloud_options = cloud_options; + self + } + + pub fn with_include_file_paths(mut self, include_file_paths: Option) -> Self { + self.include_file_paths = include_file_paths; + self + } +} + +impl LazyFileListReader for LazyJsonLineReader { + fn finish(self) -> PolarsResult { + let unified_scan_args = UnifiedScanArgs { + schema: None, + cloud_options: self.cloud_options, + hive_options: HiveOptions::new_disabled(), + rechunk: self.rechunk, + cache: false, + glob: true, + projection: None, + row_index: self.row_index, + pre_slice: self.n_rows.map(|len| Slice::Positive { offset: 0, len }), + cast_columns_policy: CastColumnsPolicy::ErrorOnMismatch, + missing_columns_policy: MissingColumnsPolicy::Raise, + include_file_paths: self.include_file_paths, + }; + + let options = NDJsonReadOptions { + n_threads: None, + infer_schema_length: self.infer_schema_length, + chunk_size: NonZeroUsize::new(1 << 18).unwrap(), + low_memory: self.low_memory, + ignore_errors: self.ignore_errors, + schema: self.schema, + schema_overwrite: self.schema_overwrite, + }; + + let scan_type = Box::new(FileScan::NDJson { options }); + + Ok(LazyFrame::from(DslPlan::Scan { + sources: self.sources, + file_info: None, + unified_scan_args: Box::new(unified_scan_args), + scan_type, + cached_ir: Default::default(), + })) + } + + fn finish_no_glob(self) -> PolarsResult { + unreachable!(); + } + + fn sources(&self) -> &ScanSources { + &self.sources + } + + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; + self + } + + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.row_index = row_index.into(); + self + } + + fn rechunk(&self) -> bool { + self.rechunk + } + + /// Rechunk the memory to contiguous chunks when parsing is done. + fn with_rechunk(mut self, toggle: bool) -> Self { + self.rechunk = toggle; + self + } + + /// Try to stop parsing when `n` rows are parsed. During multithreaded parsing the upper bound `n` cannot + /// be guaranteed. + fn n_rows(&self) -> Option { + self.n_rows + } + + /// Add a row index column. + fn row_index(&self) -> Option<&RowIndex> { + self.row_index.as_ref() + } + + /// [CloudOptions] used to list files. + fn cloud_options(&self) -> Option<&CloudOptions> { + self.cloud_options.as_ref() + } +} diff --git a/crates/polars-lazy/src/scan/parquet.rs b/crates/polars-lazy/src/scan/parquet.rs new file mode 100644 index 000000000000..78dd680e6a97 --- /dev/null +++ b/crates/polars-lazy/src/scan/parquet.rs @@ -0,0 +1,178 @@ +use std::path::{Path, PathBuf}; + +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_io::parquet::read::ParallelStrategy; +use polars_io::prelude::ParquetOptions; +use polars_io::{HiveOptions, RowIndex}; +use polars_utils::slice_enum::Slice; + +use crate::prelude::*; + +#[derive(Clone)] +pub struct ScanArgsParquet { + pub n_rows: Option, + pub parallel: ParallelStrategy, + pub row_index: Option, + pub cloud_options: Option, + pub hive_options: HiveOptions, + pub use_statistics: bool, + pub schema: Option, + pub low_memory: bool, + pub rechunk: bool, + pub cache: bool, + /// Expand path given via globbing rules. + pub glob: bool, + pub include_file_paths: Option, + pub allow_missing_columns: bool, +} + +impl Default for ScanArgsParquet { + fn default() -> Self { + Self { + n_rows: None, + parallel: Default::default(), + row_index: None, + cloud_options: None, + hive_options: Default::default(), + use_statistics: true, + schema: None, + rechunk: false, + low_memory: false, + cache: true, + glob: true, + include_file_paths: None, + allow_missing_columns: false, + } + } +} + +#[derive(Clone)] +struct LazyParquetReader { + args: ScanArgsParquet, + sources: ScanSources, +} + +impl LazyParquetReader { + fn new(args: ScanArgsParquet) -> Self { + Self { + args, + sources: ScanSources::default(), + } + } +} + +impl LazyFileListReader for LazyParquetReader { + /// Get the final [LazyFrame]. + fn finish(self) -> PolarsResult { + let parquet_options = ParquetOptions { + schema: self.args.schema, + parallel: self.args.parallel, + low_memory: self.args.low_memory, + use_statistics: self.args.use_statistics, + }; + + let unified_scan_args = UnifiedScanArgs { + schema: None, + cloud_options: self.args.cloud_options, + hive_options: self.args.hive_options, + rechunk: self.args.rechunk, + cache: self.args.cache, + glob: self.args.glob, + projection: None, + // Note: We call `with_row_index()` on the LazyFrame below + row_index: None, + pre_slice: self + .args + .n_rows + .map(|len| Slice::Positive { offset: 0, len }), + cast_columns_policy: CastColumnsPolicy::ErrorOnMismatch, + missing_columns_policy: if self.args.allow_missing_columns { + MissingColumnsPolicy::Insert + } else { + MissingColumnsPolicy::Raise + }, + include_file_paths: self.args.include_file_paths, + }; + + let mut lf: LazyFrame = + DslBuilder::scan_parquet(self.sources, parquet_options, unified_scan_args)? + .build() + .into(); + + // It's a bit hacky, but this row_index function updates the schema. + if let Some(row_index) = self.args.row_index { + lf = lf.with_row_index(row_index.name, Some(row_index.offset)) + } + + Ok(lf) + } + + fn glob(&self) -> bool { + self.args.glob + } + + fn finish_no_glob(self) -> PolarsResult { + unreachable!(); + } + + fn sources(&self) -> &ScanSources { + &self.sources + } + + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; + self + } + + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.args.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.args.row_index = row_index.into(); + self + } + + fn rechunk(&self) -> bool { + self.args.rechunk + } + + fn with_rechunk(mut self, toggle: bool) -> Self { + self.args.rechunk = toggle; + self + } + + fn cloud_options(&self) -> Option<&CloudOptions> { + self.args.cloud_options.as_ref() + } + + fn n_rows(&self) -> Option { + self.args.n_rows + } + + fn row_index(&self) -> Option<&RowIndex> { + self.args.row_index.as_ref() + } +} + +impl LazyFrame { + /// Create a LazyFrame directly from a parquet scan. + pub fn scan_parquet(path: impl AsRef, args: ScanArgsParquet) -> PolarsResult { + Self::scan_parquet_sources( + ScanSources::Paths([path.as_ref().to_path_buf()].into()), + args, + ) + } + + /// Create a LazyFrame directly from a parquet scan. + pub fn scan_parquet_sources(sources: ScanSources, args: ScanArgsParquet) -> PolarsResult { + LazyParquetReader::new(args).with_sources(sources).finish() + } + + /// Create a LazyFrame directly from a parquet scan. + pub fn scan_parquet_files(paths: Arc<[PathBuf]>, args: ScanArgsParquet) -> PolarsResult { + Self::scan_parquet_sources(ScanSources::Paths(paths), args) + } +} diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs new file mode 100644 index 000000000000..03ad8476bc43 --- /dev/null +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -0,0 +1,617 @@ +use polars_ops::prelude::ListNameSpaceImpl; +use polars_utils::unitvec; + +use super::*; + +#[test] +#[cfg(feature = "dtype-datetime")] +fn test_agg_list_type() -> PolarsResult<()> { + let s = Series::new("foo".into(), &[1, 2, 3]); + let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; + + let l = unsafe { s.agg_list(&GroupsType::Idx(vec![(0, unitvec![0, 1, 2])].into())) }; + + let result = match l.dtype() { + DataType::List(inner) => { + matches!(&**inner, DataType::Datetime(TimeUnit::Nanoseconds, None)) + }, + _ => false, + }; + assert!(result); + + Ok(()) +} + +#[test] +fn test_agg_exprs() -> PolarsResult<()> { + let df = fruits_cars(); + + // a binary expression followed by a function and an aggregation. See if it runs + let out = df + .lazy() + .group_by_stable([col("cars")]) + .agg([(lit(1) - col("A")) + .map(|s| Ok(Some(&s * 2)), GetOutput::same_type()) + .alias("foo")]) + .collect()?; + let ca = out.column("foo")?.list()?; + let out = ca.lst_lengths(); + + assert_eq!(Vec::from(&out), &[Some(4), Some(1)]); + Ok(()) +} + +#[test] +fn test_agg_unique_first() -> PolarsResult<()> { + let df = df![ + "g"=> [1, 1, 2, 2, 3, 4, 1], + "v"=> [1, 2, 2, 2, 3, 4, 1], + ]?; + + let out = df + .lazy() + .group_by_stable([col("g")]) + .agg([ + col("v").unique().first().alias("v_first"), + col("v") + .unique() + .sort(Default::default()) + .first() + .alias("true_first"), + col("v").unique().implode(), + ]) + .collect()?; + + let a = out.column("v_first").unwrap(); + let a = a.as_materialized_series().sum::().unwrap(); + // can be both because unique does not guarantee order + assert!(a == 10 || a == 11); + + let a = out.column("true_first").unwrap(); + let a = a.as_materialized_series().sum::().unwrap(); + // can be both because unique does not guarantee order + assert_eq!(a, 10); + + Ok(()) +} + +#[test] +#[cfg(feature = "cum_agg")] +fn test_cum_sum_agg_as_key() -> PolarsResult<()> { + let df = df![ + "depth" => &[0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "soil" => &["peat", "peat", "peat", "silt", "silt", "silt", "sand", "sand", "peat", "peat"] + ]?; + // this checks if the grouper can work with the complex query as a key + + let out = df + .lazy() + .group_by([col("soil") + .neq(col("soil").shift_and_fill(lit(1), col("soil").first())) + .cum_sum(false) + .alias("key")]) + .agg([col("depth").max().name().keep()]) + .sort(["depth"], Default::default()) + .collect()?; + + assert_eq!( + Vec::from(out.column("key")?.u32()?), + &[Some(0), Some(1), Some(2), Some(3)] + ); + assert_eq!( + Vec::from(out.column("depth")?.i32()?), + &[Some(2), Some(5), Some(7), Some(9)] + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "moment")] +fn test_auto_skew_kurtosis_agg() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .group_by([col("fruits")]) + .agg([ + col("B").skew(false).alias("bskew"), + col("B").kurtosis(false, false).alias("bkurt"), + ]) + .collect()?; + + assert!(matches!(out.column("bskew")?.dtype(), DataType::Float64)); + assert!(matches!(out.column("bkurt")?.dtype(), DataType::Float64)); + + Ok(()) +} + +#[test] +fn test_auto_list_agg() -> PolarsResult<()> { + let df = fruits_cars(); + + // test if alias executor adds a list after shift and fill + let out = df + .clone() + .lazy() + .group_by([col("fruits")]) + .agg([col("B").shift_and_fill(lit(-1), lit(-1)).alias("foo")]) + .collect()?; + + assert!(matches!(out.column("foo")?.dtype(), DataType::List(_))); + + // test if it runs and group_by executor thus implements a list after shift_and_fill + let _out = df + .clone() + .lazy() + .group_by([col("fruits")]) + .agg([col("B").shift_and_fill(lit(-1), lit(-1))]) + .collect()?; + + // test if window expr executor adds list + let _out = df + .clone() + .lazy() + .select([col("B").shift_and_fill(lit(-1), lit(-1)).alias("foo")]) + .collect()?; + + let _out = df + .lazy() + .select([col("B").shift_and_fill(lit(-1), lit(-1))]) + .collect()?; + Ok(()) +} +#[test] +#[cfg(feature = "rolling_window")] +fn test_power_in_agg_list1() -> PolarsResult<()> { + let df = fruits_cars(); + + // this test if the group tuples are correctly updated after + // a flat apply on a final aggregation + let out = df + .lazy() + .group_by([col("fruits")]) + .agg([ + col("A") + .rolling_min(RollingOptionsFixedWindow { + window_size: 1, + ..Default::default() + }) + .alias("input"), + col("A") + .rolling_min(RollingOptionsFixedWindow { + window_size: 1, + ..Default::default() + }) + .pow(2.0) + .alias("foo"), + ]) + .sort( + ["fruits"], + SortMultipleOptions::default().with_order_descending(true), + ) + .collect()?; + + let agg = out.column("foo")?.list()?; + let first = agg.get_as_series(0).unwrap(); + let vals = first.f64()?; + assert_eq!(Vec::from(vals), &[Some(1.0), Some(4.0), Some(25.0)]); + + Ok(()) +} + +#[test] +#[cfg(feature = "rolling_window")] +fn test_power_in_agg_list2() -> PolarsResult<()> { + let df = fruits_cars(); + + // this test if the group tuples are correctly updated after + // a flat apply on evaluate_on_groups + let out = df + .lazy() + .group_by([col("fruits")]) + .agg([col("A") + .rolling_min(RollingOptionsFixedWindow { + window_size: 2, + min_periods: 2, + ..Default::default() + }) + .pow(2.0) + .sum() + .alias("foo")]) + .sort( + ["fruits"], + SortMultipleOptions::default().with_order_descending(true), + ) + .collect()?; + + let agg = out.column("foo")?.f64()?; + assert_eq!(Vec::from(agg), &[Some(5.0), Some(9.0)]); + + Ok(()) +} +#[test] +fn test_binary_agg_context_0() -> PolarsResult<()> { + let df = df![ + "groups" => [1, 1, 2, 2, 3, 3], + "vals" => [1, 2, 3, 4, 5, 6] + ] + .unwrap(); + + let out = df + .lazy() + .group_by_stable([col("groups")]) + .agg([when(col("vals").first().neq(lit(1))) + .then(repeat(lit("a"), len())) + .otherwise(repeat(lit("b"), len())) + .alias("foo")]) + .collect() + .unwrap(); + + let out = out.column("foo")?; + let out = out.explode()?; + let out = out.str()?; + assert_eq!( + Vec::from(out), + &[ + Some("b"), + Some("b"), + Some("a"), + Some("a"), + Some("a"), + Some("a") + ] + ); + Ok(()) +} + +// just like binary expression, this must be changed. This can work +#[test] +fn test_binary_agg_context_1() -> PolarsResult<()> { + let df = df![ + "groups" => [1, 1, 2, 2, 3, 3], + "vals" => [1, 13, 3, 87, 1, 6] + ]?; + + // groups + // 1 => [1, 13] + // 2 => [3, 87] + // 3 => [1, 6] + + let out = df + .clone() + .lazy() + .group_by_stable([col("groups")]) + .agg([when(col("vals").eq(lit(1))) + .then(col("vals").sum()) + .otherwise(lit(90)) + .alias("vals")]) + .collect()?; + + // if vals == 1 then sum(vals) else vals + // [14, 90] + // [90, 90] + // [7, 90] + let out = out.column("vals")?; + let out = out.explode()?; + let out = out.i32()?; + assert_eq!( + Vec::from(out), + &[Some(14), Some(90), Some(90), Some(90), Some(7), Some(90)] + ); + + let out = df + .lazy() + .group_by_stable([col("groups")]) + .agg([when(col("vals").eq(lit(1))) + .then(lit(90)) + .otherwise(col("vals").sum()) + .alias("vals")]) + .collect()?; + + // if vals == 1 then 90 else sum(vals) + // [90, 14] + // [90, 90] + // [90, 7] + let out = out.column("vals")?; + let out = out.explode()?; + let out = out.i32()?; + assert_eq!( + Vec::from(out), + &[Some(90), Some(14), Some(90), Some(90), Some(90), Some(7)] + ); + + Ok(()) +} + +#[test] +fn test_binary_agg_context_2() -> PolarsResult<()> { + let df = df![ + "groups" => [1, 1, 2, 2, 3, 3], + "vals" => [1, 2, 3, 4, 5, 6] + ]?; + + // this is complex because we first aggregate one expression of the binary operation. + + let out = df + .clone() + .lazy() + .group_by_stable([col("groups")]) + .agg([(col("vals").first() - col("vals")).alias("vals")]) + .collect()?; + + // 0 - [1, 2] = [0, -1] + // 3 - [3, 4] = [0, -1] + // 5 - [5, 6] = [0, -1] + let out = out.column("vals")?; + let out = out.explode()?; + let out = out.i32()?; + assert_eq!( + Vec::from(out), + &[Some(0), Some(-1), Some(0), Some(-1), Some(0), Some(-1)] + ); + + // Same, but now we reverse the lhs / rhs. + let out = df + .lazy() + .group_by_stable([col("groups")]) + .agg([((col("vals")) - col("vals").first()).alias("vals")]) + .collect()?; + + // [1, 2] - 1 = [0, 1] + // [3, 4] - 3 = [0, 1] + // [5, 6] - 5 = [0, 1] + let out = out.column("vals")?; + let out = out.explode()?; + let out = out.i32()?; + assert_eq!( + Vec::from(out), + &[Some(0), Some(1), Some(0), Some(1), Some(0), Some(1)] + ); + + Ok(()) +} + +#[test] +fn test_binary_agg_context_3() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .group_by_stable([col("cars")]) + .agg([(col("A") - col("A").first()).last().alias("last")]) + .collect()?; + + let out = out.column("last")?; + assert_eq!(out.get(0)?, AnyValue::Int32(4)); + assert_eq!(out.get(1)?, AnyValue::Int32(0)); + + Ok(()) +} + +#[test] +fn test_shift_elementwise_issue_2509() -> PolarsResult<()> { + let df = df![ + "x"=> [0, 0, 0, 1, 1, 1, 2, 2, 2], + "y"=> [0, 10, 20, 0, 10, 20, 0, 10, 20] + ]?; + let out = df + .lazy() + // Don't use maintain order here! That hides the bug + .group_by([col("x")]) + .agg(&[(col("y").shift(lit(-1)) + col("x")).alias("sum")]) + .sort(["x"], Default::default()) + .collect()?; + + let out = out.explode(["sum"])?; + let out = out.column("sum")?; + assert_eq!(out.get(0)?, AnyValue::Int32(10)); + assert_eq!(out.get(1)?, AnyValue::Int32(20)); + assert_eq!(out.get(2)?, AnyValue::Null); + assert_eq!(out.get(3)?, AnyValue::Int32(11)); + assert_eq!(out.get(4)?, AnyValue::Int32(21)); + assert_eq!(out.get(5)?, AnyValue::Null); + + Ok(()) +} + +#[test] +fn take_aggregations() -> PolarsResult<()> { + let df = df![ + "user" => ["lucy", "bob", "bob", "lucy", "tim"], + "book" => ["c", "b", "a", "a", "a"], + "count" => [3, 1, 2, 1, 1] + ]?; + + let out = df + .clone() + .lazy() + .group_by([col("user")]) + .agg([col("book").get(col("count").arg_max()).alias("fav_book")]) + .sort(["user"], Default::default()) + .collect()?; + + let s = out.column("fav_book")?; + assert_eq!(s.get(0)?, AnyValue::String("a")); + assert_eq!(s.get(1)?, AnyValue::String("c")); + assert_eq!(s.get(2)?, AnyValue::String("a")); + + let out = df + .clone() + .lazy() + .group_by([col("user")]) + .agg([ + // keep the head as it test slice correctness + col("book") + .gather( + col("count") + .arg_sort(SortOptions { + descending: true, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + .head(Some(2)), + ) + .alias("ordered"), + ]) + .sort(["user"], Default::default()) + .collect()?; + let s = out.column("ordered")?; + let flat = s.explode()?; + let flat = flat.str()?; + let vals = flat.into_no_null_iter().collect::>(); + assert_eq!(vals, ["a", "b", "c", "a", "a"]); + + let out = df + .lazy() + .group_by([col("user")]) + .agg([col("book").get(lit(0)).alias("take_lit")]) + .sort(["user"], Default::default()) + .collect()?; + + let taken = out.column("take_lit")?; + let taken = taken.str()?; + let vals = taken.into_no_null_iter().collect::>(); + assert_eq!(vals, ["b", "c", "a"]); + + Ok(()) +} +#[test] +fn test_take_consistency() -> PolarsResult<()> { + let df = fruits_cars(); + let out = df + .clone() + .lazy() + .select([col("A") + .arg_sort(SortOptions { + descending: true, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + .get(lit(0))]) + .collect()?; + + let a = out.column("A")?; + let a = a.idx()?; + assert_eq!(a.get(0), Some(4)); + + let out = df + .clone() + .lazy() + .group_by_stable([col("cars")]) + .agg([col("A") + .arg_sort(SortOptions { + descending: true, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + .get(lit(0))]) + .collect()?; + + let out = out.column("A")?; + let out = out.idx()?; + assert_eq!(Vec::from(out), &[Some(3), Some(0)]); + + let out_df = df + .lazy() + .group_by_stable([col("cars")]) + .agg([ + col("A"), + col("A") + .arg_sort(SortOptions { + descending: true, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + .get(lit(0)) + .alias("1"), + col("A") + .get( + col("A") + .arg_sort(SortOptions { + descending: true, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + .get(lit(0)), + ) + .alias("2"), + ]) + .collect()?; + + let out = out_df.column("2")?; + let out = out.i32()?; + assert_eq!(Vec::from(out), &[Some(5), Some(2)]); + + let out = out_df.column("1")?; + let out = out.idx()?; + assert_eq!(Vec::from(out), &[Some(3), Some(0)]); + + Ok(()) +} + +#[test] +fn test_take_in_groups() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .sort(["fruits"], Default::default()) + .select([col("B").get(lit(0u32)).over([col("fruits")]).alias("taken")]) + .collect()?; + + assert_eq!( + Vec::from(out.column("taken")?.i32()?), + &[Some(3), Some(3), Some(5), Some(5), Some(5)] + ); + Ok(()) +} + +#[test] +fn test_anonymous_function_returns_scalar_all_null_20679() { + use std::sync::Arc; + + fn reduction_function(column: Column) -> PolarsResult> { + let val = column.get(0)?.into_static(); + let col = Column::new_scalar("".into(), Scalar::new(column.dtype().clone(), val), 1); + Ok(Some(col)) + } + + let a = Column::new("a".into(), &[0, 0, 1]); + let dtype = DataType::Null; + let b = Column::new_scalar("b".into(), Scalar::new(dtype, AnyValue::Null), 3); + let df = DataFrame::new(vec![a, b]).unwrap(); + + let f = move |c: &mut [Column]| reduction_function(std::mem::take(&mut c[0])); + + let expr = Expr::AnonymousFunction { + input: vec![col("b")], + function: LazySerde::Deserialized(SpecialEq::new(Arc::new(f))), + output_type: Default::default(), + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "", + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + ..Default::default() + }, + }; + + let grouped_df = df + .lazy() + .group_by([col("a")]) + .agg([expr]) + .collect() + .unwrap(); + + assert_eq!(grouped_df.get_columns()[1].dtype(), &DataType::Null); +} diff --git a/crates/polars-lazy/src/tests/arity.rs b/crates/polars-lazy/src/tests/arity.rs new file mode 100644 index 000000000000..73450259c667 --- /dev/null +++ b/crates/polars-lazy/src/tests/arity.rs @@ -0,0 +1,83 @@ +use super::*; + +#[test] +#[cfg(feature = "cov")] +fn test_pearson_corr() -> PolarsResult<()> { + let df = df! { + "uid" => [0, 0, 0, 1, 1, 1], + "day" => [1, 2, 4, 1, 2, 3], + "cumcases" => [10, 12, 15, 25, 30, 41] + } + .unwrap(); + + let out = df + .clone() + .lazy() + .group_by_stable([col("uid")]) + // a double aggregation expression. + .agg([pearson_corr(col("day"), col("cumcases")).alias("pearson_corr")]) + .collect()?; + let s = out.column("pearson_corr")?.f64()?; + assert!((s.get(0).unwrap() - 0.997176).abs() < 0.000001); + assert!((s.get(1).unwrap() - 0.977356).abs() < 0.000001); + + let out = df + .lazy() + .group_by_stable([col("uid")]) + // a double aggregation expression. + .agg([pearson_corr(col("day"), col("cumcases")) + .pow(2.0) + .alias("pearson_corr")]) + .collect() + .unwrap(); + let s = out.column("pearson_corr")?.f64()?; + assert!((s.get(0).unwrap() - 0.994360902255639).abs() < 0.000001); + assert!((s.get(1).unwrap() - 0.9552238805970149).abs() < 0.000001); + Ok(()) +} + +// TODO! fix this we must get a token that prevents resetting the string cache until the plan has +// finished running. We cannot store a mutexguard in the executionstate because they don't implement +// send. +// #[test] +// fn test_single_thread_when_then_otherwise_categorical() -> PolarsResult<()> { +// let df = df!["col1"=> ["a", "b", "a", "b"], +// "col2"=> ["a", "a", "b", "b"], +// "col3"=> ["same", "same", "same", "same"] +// ]?; + +// let out = df +// .lazy() +// .with_column(col("*").cast(DataType::Categorical)) +// .select([when(col("col1").eq(col("col2"))) +// .then(col("col3")) +// .otherwise(col("col1"))]) +// .collect()?; +// let col = out.column("col3")?; +// assert_eq!(col.dtype(), &DataType::Categorical); +// let s = format!("{}", col); +// assert!(s.contains("same")); +// Ok(()) +// } + +#[test] +fn test_lazy_ternary() { + let df = get_df() + .lazy() + .with_column( + when(col("sepal_length").lt(lit(5.0))) + .then(lit(10)) + .otherwise(lit(1)) + .alias("new"), + ) + .collect() + .unwrap(); + assert_eq!( + 43, + df.column("new") + .unwrap() + .as_materialized_series() + .sum::() + .unwrap() + ); +} diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs new file mode 100644 index 000000000000..678a8fe095b0 --- /dev/null +++ b/crates/polars-lazy/src/tests/cse.rs @@ -0,0 +1,357 @@ +use std::collections::BTreeSet; + +use super::*; + +fn cached_before_root(q: LazyFrame) { + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + for input in lp_arena.get(lp).get_inputs_vec() { + assert!(matches!(lp_arena.get(input), IR::Cache { .. })); + } +} + +fn count_caches(q: LazyFrame) -> usize { + let IRPlan { + lp_top, lp_arena, .. + } = q.to_alp_optimized().unwrap(); + (&lp_arena) + .iter(lp_top) + .filter(|(_node, lp)| matches!(lp, IR::Cache { .. })) + .count() +} + +#[test] +fn test_cse_self_joins() -> PolarsResult<()> { + let lf = scan_foods_ipc(); + + let lf = lf.with_column(col("category").str().to_uppercase()); + + let lf = lf + .clone() + .left_join(lf, col("fats_g"), col("fats_g")) + .with_comm_subplan_elim(true); + + cached_before_root(lf); + + Ok(()) +} + +#[test] +fn test_cse_unions() -> PolarsResult<()> { + let lf = scan_foods_ipc(); + + let lf1 = lf.clone().with_column(col("category").str().to_uppercase()); + + let lf = concat( + &[lf1.clone(), lf, lf1], + UnionArgs { + rechunk: false, + parallel: false, + ..Default::default() + }, + )? + .select([col("category"), col("fats_g")]) + .with_comm_subplan_elim(true); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = lf.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap(); + let mut cache_count = 0; + assert!((&lp_arena).iter(lp).all(|(_, lp)| { + use IR::*; + match lp { + Cache { .. } => { + cache_count += 1; + true + }, + Scan { + unified_scan_args, .. + } => { + if let Some(columns) = &unified_scan_args.projection { + columns.len() == 2 + } else { + false + } + }, + _ => true, + } + })); + assert_eq!(cache_count, 2); + let out = lf.collect()?; + assert_eq!(out.get_column_names(), &["category", "fats_g"]); + + Ok(()) +} + +#[test] +fn test_cse_cache_union_projection_pd() -> PolarsResult<()> { + let q = df![ + "a" => [1], + "b" => [2], + "c" => [3], + ]? + .lazy(); + + let q1 = q.clone().filter(col("a").eq(lit(1))).select([col("a")]); + let q2 = q.filter(col("a").eq(lit(1))).select([col("a"), col("b")]); + let q = q1 + .left_join(q2, col("a"), col("a")) + .with_comm_subplan_elim(true); + + // check that the projection of a is not done before the cache + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + let mut cache_count = 0; + assert!((&lp_arena).iter(lp).all(|(_, lp)| { + use IR::*; + match lp { + Cache { .. } => { + cache_count += 1; + true + }, + DataFrameScan { + output_schema: Some(projection), + .. + } => projection.as_ref().len() <= 2, + DataFrameScan { .. } => false, + _ => true, + } + })); + assert_eq!(cache_count, 2); + + Ok(()) +} + +#[test] +fn test_cse_union2_4925() -> PolarsResult<()> { + let lf1 = df![ + "ts" => [1], + "sym" => ["a"], + "c" => [true], + ]? + .lazy(); + + let lf2 = df![ + "ts" => [1], + "d" => [3], + ]? + .lazy(); + + let args = UnionArgs { + parallel: false, + rechunk: false, + ..Default::default() + }; + let lf1 = concat(&[lf1.clone(), lf1], args)?; + let lf2 = concat(&[lf2.clone(), lf2], args)?; + + let q = lf1.inner_join(lf2, col("ts"), col("ts")).select([ + col("ts"), + col("sym"), + col("d") / col("c"), + ]); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + // ensure we get two different caches + // and ensure that every cache only has 1 hit. + let cache_ids = (&lp_arena) + .iter(lp) + .flat_map(|(_, lp)| { + use IR::*; + match lp { + Cache { id, cache_hits, .. } => { + assert_eq!(*cache_hits, 1); + Some(*id) + }, + _ => None, + } + }) + .collect::>(); + + assert_eq!(cache_ids.len(), 2); + + Ok(()) +} + +#[test] +fn test_cse_joins_4954() -> PolarsResult<()> { + let x = df![ + "a"=> [1], + "b"=> [1], + "c"=> [1], + ]? + .lazy(); + + let y = df![ + "a"=> [1], + "b"=> [1], + ]? + .lazy(); + + let z = df![ + "a"=> [1], + ]? + .lazy(); + + let a = x.left_join(z.clone(), col("a"), col("a")); + let b = y.left_join(z, col("a"), col("a")); + let c = a.join( + b, + &[col("a"), col("b")], + &[col("a"), col("b")], + JoinType::Left.into(), + ); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = c.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + // Ensure we get only one cache and it is not above the join + // and ensure that every cache only has 1 hit. + let cache_ids = (&lp_arena) + .iter(lp) + .flat_map(|(_, lp)| { + use IR::*; + match lp { + Cache { + id, + cache_hits, + input, + .. + } => { + assert_eq!(*cache_hits, 1); + assert!(matches!(lp_arena.get(*input), IR::SimpleProjection { .. })); + + Some(*id) + }, + _ => None, + } + }) + .collect::>(); + + assert_eq!(cache_ids.len(), 1); + + Ok(()) +} +#[test] +#[cfg(feature = "semi_anti_join")] +fn test_cache_with_partial_projection() -> PolarsResult<()> { + let lf1 = df![ + "id" => ["a"], + "x" => [1], + "freq" => [2] + ]? + .lazy(); + + let lf2 = df![ + "id" => ["a"] + ]? + .lazy(); + + let q = lf2 + .join( + lf1.clone().select([col("id"), col("freq")]), + [col("id")], + [col("id")], + JoinType::Semi.into(), + ) + .join( + lf1.clone().filter(col("x").neq(lit(8))), + [col("id")], + [col("id")], + JoinType::Semi.into(), + ) + .join( + lf1.filter(col("x").neq(lit(8))), + [col("id")], + [col("id")], + JoinType::Semi.into(), + ); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + // EDIT: #15264 this originally + // tested 2 caches, but we cannot do that after #15264 due to projection pushdown + // running first and the cache semantics changing, so now we test 1. Maybe we can improve later. + + // ensure we get two different caches + // and ensure that every cache only has 1 hit. + let cache_ids = (&lp_arena) + .iter(lp) + .flat_map(|(_, lp)| { + use IR::*; + match lp { + Cache { id, .. } => Some(*id), + _ => None, + } + }) + .collect::>(); + assert_eq!(cache_ids.len(), 1); + + Ok(()) +} + +#[test] +#[cfg(feature = "cross_join")] +fn test_cse_columns_projections() -> PolarsResult<()> { + let right = df![ + "A" => [1, 2], + "B" => [3, 4], + "D" => [5, 6] + ]? + .lazy(); + + let left = df![ + "C" => [3, 4], + ]? + .lazy(); + + let left = left.cross_join(right.clone().select([col("A")]), None); + let q = left.join( + right.rename(["B"], ["C"], true), + [col("A"), col("C")], + [col("A"), col("C")], + JoinType::Left.into(), + ); + + let out = q.collect()?; + + assert_eq!(out.get_column_names(), &["C", "A", "D"]); + + Ok(()) +} + +#[test] +fn test_cse_prune_scan_filter_difference() -> PolarsResult<()> { + let lf = scan_foods_ipc(); + let lf = lf.with_column(col("category").str().to_uppercase()); + + let pred = col("fats_g").gt(2.0); + + // If filter are the same, we can cache + let q = lf + .clone() + .filter(pred.clone()) + .left_join(lf.clone().filter(pred), col("fats_g"), col("fats_g")) + .with_comm_subplan_elim(true); + cached_before_root(q); + + // If the filters are different the caches are removed. + let q = lf + .clone() + .filter(col("fats_g").gt(2.0)) + .clone() + .left_join( + lf.filter(col("fats_g").gt(1.0)), + col("fats_g"), + col("fats_g"), + ) + .with_comm_subplan_elim(true); + + // Check that the caches are removed and that both predicates have been pushed down instead. + assert_eq!(count_caches(q.clone()), 0); + assert!(predicate_at_scan(q)); + + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs new file mode 100644 index 000000000000..13287f2f8287 --- /dev/null +++ b/crates/polars-lazy/src/tests/io.rs @@ -0,0 +1,765 @@ +use polars_io::RowIndex; +#[cfg(feature = "is_between")] +use polars_ops::prelude::ClosedInterval; +use polars_utils::slice_enum::Slice; + +use super::*; +use crate::dsl; + +#[test] +#[cfg(feature = "parquet")] +fn test_parquet_exec() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + // filter + for par in [true, false] { + let out = scan_foods_parquet(par) + .filter(col("category").eq(lit("seafood"))) + .collect()?; + assert_eq!(out.shape(), (8, 4)); + } + + // project + for par in [true, false] { + let out = scan_foods_parquet(par) + .select([col("category"), col("sugars_g")]) + .collect()?; + assert_eq!(out.shape(), (27, 2)); + } + + // project + filter + for par in [true, false] { + let out = scan_foods_parquet(par) + .select([col("category"), col("sugars_g")]) + .filter(col("category").eq(lit("seafood"))) + .collect()?; + assert_eq!(out.shape(), (8, 2)); + } + + Ok(()) +} + +#[test] +#[cfg(all(feature = "parquet", feature = "is_between"))] +fn test_parquet_statistics_no_skip() { + let _guard = SINGLE_LOCK.lock().unwrap(); + init_files(); + let par = true; + let out = scan_foods_parquet(par) + .filter(col("calories").gt(lit(0i32))) + .collect() + .unwrap(); + assert_eq!(out.shape(), (27, 4)); + + let out = scan_foods_parquet(par) + .filter(col("calories").lt(lit(1000i32))) + .collect() + .unwrap(); + assert_eq!(out.shape(), (27, 4)); + + let out = scan_foods_parquet(par) + .filter(lit(0i32).lt(col("calories"))) + .collect() + .unwrap(); + assert_eq!(out.shape(), (27, 4)); + + let out = scan_foods_parquet(par) + .filter(lit(1000i32).gt(col("calories"))) + .collect() + .unwrap(); + assert_eq!(out.shape(), (27, 4)); + + // statistics and `is_between` + // normal case + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(40, 300, ClosedInterval::Both)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (19, 4)); + // normal case + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(10, 50, ClosedInterval::Both)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (11, 4)); + // edge case: 20 = min(calories) but the right end is closed + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 20, ClosedInterval::Right)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (1, 4)); + // edge case: 200 = max(calories) but the left end is closed + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 250, ClosedInterval::Left)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (3, 4)); + // edge case: left == right but both ends are closed + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 200, ClosedInterval::Both)) + .collect() + .unwrap(); + assert_eq!(out.shape(), (3, 4)); + + // Or operation + let out = scan_foods_parquet(par) + .filter( + col("sugars_g") + .lt(lit(0i32)) + .or(col("fats_g").lt(lit(1000.0))), + ) + .collect() + .unwrap(); + assert_eq!(out.shape(), (27, 4)); +} + +#[test] +#[cfg(all(feature = "parquet", feature = "is_between"))] +fn test_parquet_statistics() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + init_files(); + unsafe { std::env::set_var("POLARS_PANIC_IF_PARQUET_PARSED", "1") }; + let par = true; + + // Test single predicates + let out = scan_foods_parquet(par) + .filter(col("calories").lt(lit(0i32))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + let out = scan_foods_parquet(par) + .filter(col("calories").gt(lit(1000))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + let out = scan_foods_parquet(par) + .filter(lit(0i32).gt(col("calories"))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // issue: 13427 + let out = scan_foods_parquet(par) + .filter(col("calories").is_in(lit(Series::new("".into(), [0, 500])), false)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // statistics and `is_between` + // 15 < min(calories)=20 + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 15, ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 300 > max(calories)=200 + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(300, 500, ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 20 == min(calories) but right end is open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 20, ClosedInterval::Left)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 20 == min(calories) but both ends are open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(5, 20, ClosedInterval::None)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 200 == max(calories) but left end is open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 250, ClosedInterval::Right)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // 200 == max(calories) but both ends are open + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(200, 250, ClosedInterval::None)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // between(100, 40) is impossible + let out = scan_foods_parquet(par) + .filter(col("calories").is_between(100, 40, ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // with strings + let out = scan_foods_parquet(par) + .filter(col("category").is_between(lit("yams"), lit("zest"), ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // with strings + let out = scan_foods_parquet(par) + .filter(col("category").is_between(lit("dairy"), lit("eggs"), ClosedInterval::Both)) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + let out = scan_foods_parquet(par) + .filter(lit(1000i32).lt(col("calories"))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a > b) => a <= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt(5))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a >= b) => a < b + // note that min(calories)=20 + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt_eq(20))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a < b) => a >= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").lt(250))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a <= b) => a > b + // note that max(calories)=200 + let out = scan_foods_parquet(par) + .filter(not(col("calories").lt_eq(200))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(a == b) => a != b + // note that proteins_g=10 for all rows + let out = scan_nutri_score_null_column_parquet(par) + .filter(not(col("proteins_g").eq(10))) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // not(a != b) => a == b + // note that proteins_g=10 for all rows + let out = scan_nutri_score_null_column_parquet(par) + .filter(not(col("proteins_g").neq(5))) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // not(col(c) is between [a, b]) => col(c) < a or col(c) > b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 20, + 200, + ClosedInterval::Both, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between [a, b[) => col(c) < a or col(c) >= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 20, + 201, + ClosedInterval::Left, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between ]a, b]) => col(c) <= a or col(c) > b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 19, + 200, + ClosedInterval::Right, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between ]a, b]) => col(c) <= a or col(c) > b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 19, + 200, + ClosedInterval::Right, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not(col(c) is between ]a, b[) => col(c) <= a or col(c) >= b + let out = scan_foods_parquet(par) + .filter(not(col("calories").is_between( + 19, + 201, + ClosedInterval::None, + ))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not (a or b) => not(a) and not(b) + // note that not(fats_g <= 9) is possible; not(calories > 5) should allow us skip the rg + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt(5).or(col("fats_g").lt_eq(9)))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // not (a and b) => not(a) or not(b) + let out = scan_foods_parquet(par) + .filter(not(col("calories").gt(5).and(col("fats_g").lt_eq(12)))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // is_not_null + let out = scan_nutri_score_null_column_parquet(par) + .filter(col("nutri_score").is_not_null()) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // not(is_null) (~pl.col('nutri_score').is_null()) + let out = scan_nutri_score_null_column_parquet(par) + .filter(not(col("nutri_score").is_null())) + .collect()?; + assert_eq!(out.shape(), (0, 6)); + + // Test multiple predicates + + // And operation + let out = scan_foods_parquet(par) + .filter(col("calories").lt(lit(0i32))) + .filter(col("calories").gt(lit(1000))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + let out = scan_foods_parquet(par) + .filter(col("calories").lt(lit(0i32))) + .filter(col("calories").gt(lit(1000))) + .filter(col("calories").lt(lit(50i32))) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + let out = scan_foods_parquet(par) + .filter( + col("calories") + .lt(lit(0i32)) + .and(col("fats_g").lt(lit(0.0))), + ) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + // Or operation + let out = scan_foods_parquet(par) + .filter( + col("sugars_g") + .lt(lit(0i32)) + .or(col("fats_g").gt(lit(1000.0))), + ) + .collect()?; + assert_eq!(out.shape(), (0, 4)); + + unsafe { std::env::remove_var("POLARS_PANIC_IF_PARQUET_PARSED") }; + + Ok(()) +} + +#[test] +#[cfg(not(target_os = "windows"))] +fn test_parquet_globbing() -> PolarsResult<()> { + // for side effects + init_files(); + let _guard = SINGLE_LOCK.lock().unwrap(); + let glob = "../../examples/datasets/foods*.parquet"; + let df = LazyFrame::scan_parquet( + glob, + ScanArgsParquet { + n_rows: None, + cache: true, + parallel: Default::default(), + ..Default::default() + }, + )? + .collect()?; + assert_eq!(df.shape(), (54, 4)); + let cal = df.column("calories")?; + assert_eq!(cal.get(0)?, AnyValue::Int64(45)); + assert_eq!(cal.get(53)?, AnyValue::Int64(194)); + + Ok(()) +} + +#[test] +fn test_scan_parquet_limit_9001() { + init_files(); + let path = GLOB_PARQUET; + let args = ScanArgsParquet { + n_rows: Some(10000), + cache: false, + rechunk: true, + ..Default::default() + }; + let q = LazyFrame::scan_parquet(path, args).unwrap().limit(3); + let IRPlan { + lp_top, lp_arena, .. + } = q.to_alp_optimized().unwrap(); + (&lp_arena).iter(lp_top).all(|(_, lp)| match lp { + IR::Union { options, .. } => { + let sliced = options.slice.unwrap(); + sliced.1 == 3 + }, + IR::Scan { + unified_scan_args, .. + } => unified_scan_args.pre_slice == Some(Slice::Positive { offset: 0, len: 3 }), + _ => true, + }); +} + +#[test] +#[cfg(not(target_os = "windows"))] +fn test_ipc_globbing() -> PolarsResult<()> { + // for side effects + init_files(); + let glob = "../../examples/datasets/foods*.ipc"; + let df = LazyFrame::scan_ipc( + glob, + ScanArgsIpc { + n_rows: None, + cache: true, + rechunk: false, + row_index: None, + cloud_options: None, + hive_options: Default::default(), + include_file_paths: None, + }, + )? + .collect()?; + assert_eq!(df.shape(), (54, 4)); + let cal = df.column("calories")?; + assert_eq!(cal.get(0)?, AnyValue::Int64(45)); + assert_eq!(cal.get(53)?, AnyValue::Int64(194)); + + Ok(()) +} + +fn slice_at_union(lp_arena: &Arena, lp: Node) -> bool { + (&lp_arena).iter(lp).all(|(_, lp)| { + if let IR::Union { options, .. } = lp { + options.slice.is_some() + } else { + true + } + }) +} + +#[test] +fn test_csv_globbing() -> PolarsResult<()> { + let glob = "../../examples/datasets/foods*.csv"; + let full_df = LazyCsvReader::new(glob).finish()?.collect()?; + + // all 5 files * 27 rows + assert_eq!(full_df.shape(), (135, 4)); + let cal = full_df.column("calories")?; + assert_eq!(cal.get(0)?, AnyValue::Int64(45)); + assert_eq!(cal.get(53)?, AnyValue::Int64(194)); + + let glob = "../../examples/datasets/foods*.csv"; + let lf = LazyCsvReader::new(glob).finish()?.slice(0, 100); + + let df = lf.clone().collect()?; + assert_eq!(df, full_df.slice(0, 100)); + let df = LazyCsvReader::new(glob).finish()?.slice(20, 60).collect()?; + assert_eq!(df, full_df.slice(20, 60)); + + let mut expr_arena = Arena::with_capacity(16); + let mut lp_arena = Arena::with_capacity(8); + let node = lf.optimize(&mut lp_arena, &mut expr_arena)?; + assert!(slice_at_union(&lp_arena, node)); + + let lf = LazyCsvReader::new(glob) + .finish()? + .filter(col("sugars_g").lt(lit(1i32))) + .slice(0, 100); + let node = lf.optimize(&mut lp_arena, &mut expr_arena)?; + assert!(slice_at_union(&lp_arena, node)); + + Ok(()) +} + +#[test] +#[cfg(feature = "json")] +fn test_ndjson_globbing() -> PolarsResult<()> { + // for side effects + init_files(); + let glob = "../../examples/datasets/foods*.ndjson"; + let df = LazyJsonLineReader::new(glob).finish()?.collect()?; + assert_eq!(df.shape(), (54, 4)); + let cal = df.column("calories")?; + assert_eq!(cal.get(0)?, AnyValue::Int64(45)); + assert_eq!(cal.get(53)?, AnyValue::Int64(194)); + + Ok(()) +} + +#[test] +pub fn test_simple_slice() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + let out = scan_foods_parquet(false).limit(3).collect()?; + assert_eq!(out.height(), 3); + + Ok(()) +} +#[test] +fn test_union_and_agg_projections() -> PolarsResult<()> { + init_files(); + let _guard = SINGLE_LOCK.lock().unwrap(); + // a union vstacks columns and aggscan optimization determines columns to aggregate in a + // hashmap, if that doesn't set them sorted the vstack will panic. + let lf1 = LazyFrame::scan_parquet(GLOB_PARQUET, Default::default())?; + let lf2 = LazyFrame::scan_ipc(GLOB_IPC, Default::default())?; + let lf3 = LazyCsvReader::new(GLOB_CSV).finish()?; + + for lf in [lf1, lf2, lf3] { + let lf = lf.filter(col("category").eq(lit("vegetables"))).select([ + col("fats_g").sum().alias("sum"), + col("fats_g").cast(DataType::Float64).mean().alias("mean"), + col("fats_g").min().alias("min"), + ]); + + let out = lf.collect()?; + assert_eq!(out.shape(), (1, 3)); + } + + Ok(()) +} + +#[test] +#[cfg(all(feature = "ipc", feature = "csv"))] +fn test_slice_filter() -> PolarsResult<()> { + init_files(); + let _guard = SINGLE_LOCK.lock().unwrap(); + + // make sure that the slices are not applied before the predicates. + let len = 5; + let offset = 3; + + let df1 = scan_foods_csv() + .filter(col("category").eq(lit("fruit"))) + .slice(offset, len) + .collect()?; + let df2 = scan_foods_parquet(false) + .filter(col("category").eq(lit("fruit"))) + .slice(offset, len) + .collect()?; + let df3 = scan_foods_ipc() + .filter(col("category").eq(lit("fruit"))) + .slice(offset, len) + .collect()?; + + let df1_ = scan_foods_csv() + .collect()? + .lazy() + .filter(col("category").eq(lit("fruit"))) + .slice(offset, len) + .collect()?; + let df2_ = scan_foods_parquet(false) + .collect()? + .lazy() + .filter(col("category").eq(lit("fruit"))) + .slice(offset, len) + .collect()?; + let df3_ = scan_foods_ipc() + .collect()? + .lazy() + .filter(col("category").eq(lit("fruit"))) + .slice(offset, len) + .collect()?; + + assert_eq!(df1.shape(), df1_.shape()); + assert_eq!(df2.shape(), df2_.shape()); + assert_eq!(df3.shape(), df3_.shape()); + + Ok(()) +} + +#[test] +fn skip_rows_and_slice() -> PolarsResult<()> { + let out = LazyCsvReader::new(FOODS_CSV) + .with_skip_rows(4) + .finish()? + .limit(1) + .collect()?; + assert_eq!(out.column("fruit")?.get(0)?, AnyValue::String("seafood")); + assert_eq!(out.shape(), (1, 4)); + Ok(()) +} + +#[test] +fn test_row_index_on_files() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + for offset in [0 as IdxSize, 10] { + let lf = LazyCsvReader::new(FOODS_CSV) + .with_row_index(Some(RowIndex { + name: PlSmallStr::from_static("index"), + offset, + })) + .finish()?; + + assert!(row_index_at_scan(lf.clone())); + let df = lf.collect()?; + let idx = df.column("index")?; + assert_eq!( + idx.idx()?.into_no_null_iter().collect::>(), + (offset..27 + offset).collect::>() + ); + + let lf = LazyFrame::scan_parquet(FOODS_PARQUET, Default::default())? + .with_row_index("index", Some(offset)); + assert!(row_index_at_scan(lf.clone())); + let df = lf.collect()?; + let idx = df.column("index")?; + assert_eq!( + idx.idx()?.into_no_null_iter().collect::>(), + (offset..27 + offset).collect::>() + ); + + let lf = LazyFrame::scan_ipc(FOODS_IPC, Default::default())? + .with_row_index("index", Some(offset)); + + assert!(row_index_at_scan(lf.clone())); + let df = lf.clone().collect()?; + let idx = df.column("index")?; + assert_eq!( + idx.idx()?.into_no_null_iter().collect::>(), + (offset..27 + offset).collect::>() + ); + + let out = lf + .filter(col("index").gt(lit(-1))) + .select([col("calories")]) + .collect()?; + assert!(out.column("calories").is_ok()); + assert_eq!(out.shape(), (27, 1)); + } + + Ok(()) +} + +#[test] +fn scan_predicate_on_set_null_values() -> PolarsResult<()> { + let df = LazyCsvReader::new(FOODS_CSV) + .with_null_values(Some(NullValues::Named(vec![("fats_g".into(), "0".into())]))) + .with_infer_schema_length(Some(0)) + .finish()? + .select([col("category"), col("fats_g")]) + .filter(col("fats_g").is_null()) + .collect()?; + + assert_eq!(df.shape(), (12, 2)); + Ok(()) +} + +#[test] +fn scan_anonymous_fn_with_options() -> PolarsResult<()> { + struct MyScan {} + + impl AnonymousScan for MyScan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn allows_projection_pushdown(&self) -> bool { + true + } + + fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult { + assert_eq!(scan_opts.with_columns.clone().unwrap().len(), 2); + assert_eq!(scan_opts.n_rows, Some(3)); + let out = fruits_cars().select(scan_opts.with_columns.unwrap().iter().cloned())?; + Ok(out.slice(0, scan_opts.n_rows.unwrap())) + } + } + + let function = Arc::new(MyScan {}); + + let args = ScanArgsAnonymous { + schema: Some(fruits_cars().schema().clone()), + ..ScanArgsAnonymous::default() + }; + + let q = LazyFrame::anonymous_scan(function, args)? + .with_column((col("A") * lit(2)).alias("A2")) + .select([col("A2"), col("fruits")]) + .limit(3); + + let df = q.collect()?; + + assert_eq!(df.shape(), (3, 2)); + Ok(()) +} + +#[test] +fn scan_anonymous_fn_count() -> PolarsResult<()> { + struct MyScan {} + + impl AnonymousScan for MyScan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn allows_projection_pushdown(&self) -> bool { + true + } + + fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult { + assert_eq!(scan_opts.with_columns.as_deref(), Some(&["A".into()][..])); + + Ok(fruits_cars() + .select(scan_opts.with_columns.unwrap().iter().cloned()) + .unwrap()) + } + } + + let function = Arc::new(MyScan {}); + + let args = ScanArgsAnonymous { + schema: Some(fruits_cars().schema().clone()), + ..ScanArgsAnonymous::default() + }; + + let df = LazyFrame::anonymous_scan(function, args)? + .select(&[dsl::len()]) + .collect() + .unwrap(); + + assert_eq!(df.get_columns().len(), 1); + assert_eq!(df.get_columns()[0].len(), 1); + assert_eq!( + df.get_columns()[0] + .cast(&DataType::UInt32) + .unwrap() + .as_materialized_series() + .first(), + Scalar::new(DataType::UInt32, AnyValue::UInt32(5)) + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-full")] +fn scan_small_dtypes() -> PolarsResult<()> { + let small_dt = vec![ + DataType::Int8, + DataType::UInt8, + DataType::Int16, + DataType::UInt16, + ]; + for dt in small_dt { + let df = LazyCsvReader::new(FOODS_CSV) + .with_has_header(true) + .with_dtype_overwrite(Some(Arc::new(Schema::from_iter([Field::new( + "sugars_g".into(), + dt.clone(), + )])))) + .finish()? + .select(&[col("sugars_g")]) + .collect()?; + + assert_eq!(df.dtypes(), &[dt]); + } + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/logical.rs b/crates/polars-lazy/src/tests/logical.rs new file mode 100644 index 000000000000..ca9906d55fd7 --- /dev/null +++ b/crates/polars-lazy/src/tests/logical.rs @@ -0,0 +1,160 @@ +use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY; + +use super::*; + +#[test] +#[cfg(all(feature = "strings", feature = "temporal", feature = "dtype-duration"))] +fn test_duration() -> PolarsResult<()> { + let df = df![ + "date" => ["2021-01-01", "2021-01-02", "2021-01-03"], + "groups" => [1, 1, 1] + ]?; + + let out = df + .lazy() + .with_columns(&[col("date").str().to_date(StrptimeOptions { + ..Default::default() + })]) + .with_column( + col("date") + .cast(DataType::Datetime(TimeUnit::Milliseconds, None)) + .alias("datetime"), + ) + .group_by([col("groups")]) + .agg([ + (col("date") - col("date").first()).alias("date"), + (col("datetime") - col("datetime").first()).alias("datetime"), + ]) + .explode([col("date"), col("datetime")]) + .collect()?; + + for c in ["date", "datetime"] { + let column = out.column(c)?; + assert!(matches!( + column.dtype(), + DataType::Duration(TimeUnit::Milliseconds) + )); + + assert_eq!( + column.get(0)?, + AnyValue::Duration(0, TimeUnit::Milliseconds) + ); + assert_eq!( + column.get(1)?, + AnyValue::Duration(MILLISECONDS_IN_DAY, TimeUnit::Milliseconds) + ); + assert_eq!( + column.get(2)?, + AnyValue::Duration(2 * MILLISECONDS_IN_DAY, TimeUnit::Milliseconds) + ); + } + Ok(()) +} + +fn print_plans(lf: &LazyFrame) { + println!("LOGICAL PLAN\n\n{}\n", lf.describe_plan().unwrap()); + println!( + "OPTIMIZED LOGICAL PLAN\n\n{}\n", + lf.describe_optimized_plan().unwrap() + ); +} + +#[test] +fn test_lazy_arithmetic() { + let df = get_df(); + let lf = df + .lazy() + .select(&[((col("sepal_width") * lit(100)).alias("super_wide"))]) + .sort(["super_wide"], SortMultipleOptions::default()); + + print_plans(&lf); + + let new = lf.collect().unwrap(); + println!("{:?}", new); + assert_eq!(new.height(), 7); + assert_eq!( + new.column("super_wide").unwrap().f64().unwrap().get(0), + Some(300.0) + ); +} + +#[test] +fn test_lazy_logical_plan_filter_and_alias_combined() { + let df = get_df(); + let lf = df + .lazy() + .filter(col("sepal_width").lt(lit(3.5))) + .select(&[col("variety").alias("foo")]); + + print_plans(&lf); + let df = lf.collect().unwrap(); + println!("{:?}", df); +} + +#[test] +fn test_lazy_logical_plan_schema() { + let df = get_df(); + let lp = df + .clone() + .lazy() + .select(&[col("variety").alias("foo")]) + .logical_plan; + + assert!(lp.compute_schema().unwrap().get("foo").is_some()); + + let lp = df + .lazy() + .group_by([col("variety")]) + .agg([col("sepal_width").min()]) + .logical_plan; + assert!(lp.compute_schema().unwrap().get("sepal_width").is_some()); +} + +#[test] +fn test_lazy_logical_plan_join() { + let left = df!("days" => &[0, 1, 2, 3, 4], + "temp" => [22.1, 19.9, 7., 2., 3.], + "rain" => &[0.1, 0.2, 0.3, 0.4, 0.5] + ) + .unwrap(); + + let right = df!( + "days" => &[1, 2], + "rain" => &[0.1, 0.2] + ) + .unwrap(); + + // check if optimizations succeeds without selection + { + let lf = left + .clone() + .lazy() + .left_join(right.clone().lazy(), col("days"), col("days")); + + print_plans(&lf); + // implicitly checks logical plan == optimized logical plan + let _df = lf.collect().unwrap(); + } + + // check if optimization succeeds with selection + { + let lf = left + .clone() + .lazy() + .left_join(right.clone().lazy(), col("days"), col("days")) + .select(&[col("temp")]); + + let _df = lf.collect().unwrap(); + } + + // check if optimization succeeds with selection of a renamed column due to the join + { + let lf = left + .lazy() + .left_join(right.lazy(), col("days"), col("days")) + .select(&[col("temp"), col("rain_right")]); + + print_plans(&lf); + let _df = lf.collect().unwrap(); + } +} diff --git a/crates/polars-lazy/src/tests/mod.rs b/crates/polars-lazy/src/tests/mod.rs new file mode 100644 index 000000000000..a45bd06ed68d --- /dev/null +++ b/crates/polars-lazy/src/tests/mod.rs @@ -0,0 +1,204 @@ +mod aggregations; +mod arity; +#[cfg(all(feature = "strings", feature = "cse"))] +mod cse; +#[cfg(feature = "parquet")] +mod io; +mod logical; +mod optimization_checks; +#[cfg(all(feature = "strings", feature = "cse"))] +mod pdsh; +mod predicate_queries; +mod projection_queries; +mod queries; +mod schema; +#[cfg(feature = "streaming")] +mod streaming; + +fn get_arenas() -> (Arena, Arena) { + let expr_arena = Arena::with_capacity(16); + let lp_arena = Arena::with_capacity(8); + (expr_arena, lp_arena) +} + +fn load_df() -> DataFrame { + df!("a" => &[1, 2, 3, 4, 5], + "b" => &["a", "a", "b", "c", "c"], + "c" => &[1, 2, 3, 4, 5] + ) + .unwrap() +} + +use std::io::Cursor; + +#[cfg(feature = "temporal")] +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use optimization_checks::*; +#[cfg(feature = "parquet")] +pub(crate) use polars_core::SINGLE_LOCK; +use polars_core::chunked_array::builder::get_list_builder; +use polars_core::df; +use polars_core::prelude::*; +use polars_io::prelude::*; + +#[cfg(feature = "cov")] +use crate::dsl::pearson_corr; +use crate::prelude::*; + +#[cfg(feature = "parquet")] +static GLOB_PARQUET: &str = "../../examples/datasets/*.parquet"; +#[cfg(feature = "csv")] +static GLOB_CSV: &str = "../../examples/datasets/foods*.csv"; +#[cfg(feature = "ipc")] +static GLOB_IPC: &str = "../../examples/datasets/*.ipc"; +#[cfg(feature = "parquet")] +static FOODS_PARQUET: &str = "../../examples/datasets/foods1.parquet"; +#[cfg(feature = "parquet")] +static NUTRI_SCORE_NULL_COLUMN_PARQUET: &str = "../../examples/datasets/null_nutriscore.parquet"; +#[cfg(feature = "csv")] +static FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; +#[cfg(feature = "ipc")] +static FOODS_IPC: &str = "../../examples/datasets/foods1.ipc"; + +#[cfg(feature = "csv")] +fn scan_foods_csv() -> LazyFrame { + LazyCsvReader::new(FOODS_CSV).finish().unwrap() +} + +#[cfg(feature = "ipc")] +fn scan_foods_ipc() -> LazyFrame { + init_files(); + LazyFrame::scan_ipc(FOODS_IPC, Default::default()).unwrap() +} + +#[cfg(any(feature = "ipc", feature = "parquet"))] +fn init_files() { + if std::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open("../../examples/datasets/busy") + .is_err() + { + while !std::fs::exists("../../examples/datasets/finished").unwrap() {} + return; + } + + for path in &[ + "../../examples/datasets/foods1.csv", + "../../examples/datasets/foods2.csv", + "../../examples/datasets/null_nutriscore.csv", + ] { + for ext in [".parquet", ".ipc", ".ndjson"] { + let out_path = path.replace(".csv", ext); + + if std::fs::metadata(&out_path).is_err() { + let mut df = CsvReadOptions::default() + .try_into_reader_with_file_path(Some(path.into())) + .unwrap() + .finish() + .unwrap(); + let f = std::fs::File::create(&out_path).unwrap(); + + match ext { + ".parquet" => { + #[cfg(feature = "parquet")] + { + ParquetWriter::new(f) + .with_statistics(StatisticsOptions::full()) + .finish(&mut df) + .unwrap(); + } + }, + ".ipc" => { + IpcWriter::new(f).finish(&mut df).unwrap(); + }, + ".ndjson" => { + #[cfg(feature = "json")] + { + JsonWriter::new(f).finish(&mut df).unwrap() + } + }, + _ => panic!(), + } + } + } + } + + std::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open("../../examples/datasets/finished") + .unwrap(); +} + +#[cfg(feature = "parquet")] +fn scan_foods_parquet(parallel: bool) -> LazyFrame { + init_files(); + let out_path = FOODS_PARQUET; + let parallel = if parallel { + ParallelStrategy::Auto + } else { + ParallelStrategy::None + }; + + let args = ScanArgsParquet { + n_rows: None, + cache: false, + parallel, + rechunk: true, + ..Default::default() + }; + LazyFrame::scan_parquet(out_path, args).unwrap() +} + +#[cfg(feature = "parquet")] +fn scan_nutri_score_null_column_parquet(parallel: bool) -> LazyFrame { + init_files(); + let out_path = NUTRI_SCORE_NULL_COLUMN_PARQUET; + let parallel = if parallel { + ParallelStrategy::Auto + } else { + ParallelStrategy::None + }; + + let args = ScanArgsParquet { + n_rows: None, + cache: false, + parallel, + rechunk: true, + ..Default::default() + }; + LazyFrame::scan_parquet(out_path, args).unwrap() +} + +pub(crate) fn fruits_cars() -> DataFrame { + df!( + "A"=> [1, 2, 3, 4, 5], + "fruits"=> ["banana", "banana", "apple", "apple", "banana"], + "B"=> [5, 4, 3, 2, 1], + "cars"=> ["beetle", "audi", "beetle", "beetle", "beetle"] + ) + .unwrap() +} + +pub(crate) fn get_df() -> DataFrame { + let s = r#" +"sepal_length","sepal_width","petal_length","petal_width","variety" +5.1,3.5,1.4,.2,"Setosa" +4.9,3,1.4,.2,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,3.1,1.5,.2,"Setosa" +5,3.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4,"Setosa" +4.6,3.4,1.4,.3,"Setosa" +"#; + + let file = Cursor::new(s); + + CsvReadOptions::default() + .with_infer_schema_length(Some(3)) + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap() +} diff --git a/crates/polars-lazy/src/tests/optimization_checks.rs b/crates/polars-lazy/src/tests/optimization_checks.rs new file mode 100644 index 000000000000..d679ed89fd50 --- /dev/null +++ b/crates/polars-lazy/src/tests/optimization_checks.rs @@ -0,0 +1,654 @@ +use super::*; + +#[cfg(feature = "parquet")] +pub(crate) fn row_index_at_scan(q: LazyFrame) -> bool { + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + (&lp_arena).iter(lp).any(|(_, lp)| { + if let IR::Scan { + unified_scan_args, .. + } = lp + { + unified_scan_args.row_index.is_some() + } else { + false + } + }) +} + +pub(crate) fn predicate_at_scan(q: LazyFrame) -> bool { + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + (&lp_arena).iter(lp).any(|(_, lp)| match lp { + IR::Filter { input, .. } => { + matches!(lp_arena.get(*input), IR::DataFrameScan { .. }) + }, + IR::Scan { + predicate: Some(_), .. + } => true, + _ => false, + }) +} + +pub(crate) fn predicate_at_all_scans(q: LazyFrame) -> bool { + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + (&lp_arena).iter(lp).all(|(_, lp)| match lp { + IR::Filter { input, .. } => { + matches!(lp_arena.get(*input), IR::DataFrameScan { .. }) + }, + IR::Scan { + predicate: Some(_), .. + } => true, + _ => false, + }) +} + +#[cfg(any(feature = "parquet", feature = "csv"))] +fn slice_at_scan(q: LazyFrame) -> bool { + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + (&lp_arena).iter(lp).any(|(_, lp)| { + use IR::*; + match lp { + Scan { + unified_scan_args, .. + } => unified_scan_args.pre_slice.is_some(), + _ => false, + } + }) +} + +#[test] +fn test_pred_pd_1() -> PolarsResult<()> { + let df = fruits_cars(); + + let q = df + .clone() + .lazy() + .select([col("A"), col("B")]) + .filter(col("A").gt(lit(1))); + + assert!(predicate_at_scan(q)); + + // Check if we understand that we can unwrap the alias. + let q = df + .clone() + .lazy() + .select([col("A").alias("C"), col("B")]) + .filter(col("C").gt(lit(1))); + + assert!(predicate_at_scan(q)); + + // Check if we pass hstack. + let q = df + .clone() + .lazy() + .with_columns([col("A").alias("C"), col("B")]) + .filter(col("B").gt(lit(1))); + + assert!(predicate_at_scan(q)); + + Ok(()) +} + +#[test] +fn test_no_left_join_pass() -> PolarsResult<()> { + let df1 = df![ + "foo" => ["abc", "def", "ghi"], + "idx1" => [0, 0, 1], + ]?; + let df2 = df![ + "bar" => [5, 6], + "idx2" => [0, 1], + ]?; + + let out = df1 + .lazy() + .join( + df2.lazy(), + [col("idx1")], + [col("idx2")], + JoinType::Left.into(), + ) + .filter(col("bar").eq(lit(5i32))) + .collect()?; + + let expected = df![ + "foo" => ["abc", "def"], + "idx1" => [0, 0], + "bar" => [5, 5], + ]?; + + assert!(out.equals(&expected)); + Ok(()) +} + +#[test] +#[cfg(feature = "parquet")] +pub fn test_simple_slice() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + let q = scan_foods_parquet(false).limit(3); + + assert!(slice_at_scan(q.clone())); + let out = q.collect()?; + assert_eq!(out.height(), 3); + + let q = scan_foods_parquet(false) + .select([col("category"), col("calories").alias("bar")]) + .limit(3); + assert!(slice_at_scan(q.clone())); + let out = q.collect()?; + assert_eq!(out.height(), 3); + + Ok(()) +} + +#[test] +#[cfg(feature = "parquet")] +#[cfg(feature = "cse")] +pub fn test_slice_pushdown_join() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + let q1 = scan_foods_parquet(false).limit(3); + let q2 = scan_foods_parquet(false); + + let q = q1 + .join( + q2, + [col("category")], + [col("category")], + JoinType::Left.into(), + ) + .slice(1, 3) + // this inserts a cache and blocks slice pushdown + .with_comm_subplan_elim(false); + // test if optimization continued beyond the join node + assert!(slice_at_scan(q.clone())); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap(); + assert!((&lp_arena).iter(lp).all(|(_, lp)| { + use IR::*; + match lp { + Join { options, .. } => options.args.slice == Some((1, 3)), + Slice { .. } => false, + _ => true, + } + })); + let out = q.collect()?; + assert_eq!(out.shape(), (3, 7)); + + Ok(()) +} + +#[test] +#[cfg(feature = "parquet")] +pub fn test_slice_pushdown_group_by() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + let q = scan_foods_parquet(false).limit(100); + + let q = q + .group_by([col("category")]) + .agg([col("calories").sum()]) + .slice(1, 3); + + // test if optimization continued beyond the group_by node + assert!(slice_at_scan(q.clone())); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap(); + assert!((&lp_arena).iter(lp).all(|(_, lp)| { + use IR::*; + match lp { + GroupBy { options, .. } => options.slice == Some((1, 3)), + Slice { .. } => false, + _ => true, + } + })); + let out = q.collect()?; + assert_eq!(out.shape(), (3, 2)); + + Ok(()) +} + +#[test] +#[cfg(feature = "parquet")] +pub fn test_slice_pushdown_sort() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock().unwrap(); + let q = scan_foods_parquet(false).limit(100); + + let q = q + .sort(["category"], SortMultipleOptions::default()) + .slice(1, 3); + + // test if optimization continued beyond the sort node + assert!(slice_at_scan(q.clone())); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap(); + assert!((&lp_arena).iter(lp).all(|(_, lp)| { + use IR::*; + match lp { + Sort { slice, .. } => *slice == Some((1, 3)), + Slice { .. } => false, + _ => true, + } + })); + let out = q.collect()?; + assert_eq!(out.shape(), (3, 4)); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-i16")] +pub fn test_predicate_block_cast() -> PolarsResult<()> { + let df = df![ + "value" => [10, 20, 30, 40] + ]?; + + let lf1 = df + .clone() + .lazy() + .with_column(col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32)) + .filter(col("value").lt(lit(2.5f32))); + + let lf2 = df + .lazy() + .select([col("value").cast(DataType::Int16) * lit(0.1).cast(DataType::Float32)]) + .filter(col("value").lt(lit(2.5f32))); + + for lf in [lf1, lf2] { + assert!(!predicate_at_scan(lf.clone())); + + let out = lf.collect()?; + let s = out.column("value").unwrap(); + assert_eq!( + s, + &Column::new(PlSmallStr::from_static("value"), [1.0f32, 2.0]) + ); + } + + Ok(()) +} + +#[test] +fn test_lazy_filter_and_rename() { + let df = load_df(); + let lf = df + .clone() + .lazy() + .rename(["a"], ["x"], true) + .filter(col("x").map( + |s: Column| Ok(Some(s.as_materialized_series().gt(3)?.into_column())), + GetOutput::from_type(DataType::Boolean), + )) + .select([col("x")]); + + let correct = df! { + "x" => &[4, 5] + } + .unwrap(); + assert!(lf.collect().unwrap().equals(&correct)); + + // now we check if the column is rename or added when we don't select + let lf = df.lazy().rename(["a"], ["x"], true).filter(col("x").map( + |s: Column| Ok(Some(s.as_materialized_series().gt(3)?.into_column())), + GetOutput::from_type(DataType::Boolean), + )); + // the rename function should not interfere with the predicate pushdown + assert!(predicate_at_scan(lf.clone())); + + assert_eq!(lf.collect().unwrap().get_column_names(), &["x", "b", "c"]); +} + +#[test] +fn test_with_row_index_opts() -> PolarsResult<()> { + let df = df![ + "a" => [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ]?; + + let out = df + .clone() + .lazy() + .with_row_index("index", None) + .tail(5) + .collect()?; + let expected = df![ + "index" => [5 as IdxSize, 6, 7, 8, 9], + "a" => [5, 6, 7, 8, 9], + ]?; + + assert!(out.equals(&expected)); + let out = df + .clone() + .lazy() + .with_row_index("index", None) + .slice(1, 2) + .collect()?; + assert_eq!( + out.column("index")? + .idx()? + .into_no_null_iter() + .collect::>(), + &[1, 2] + ); + + let out = df + .clone() + .lazy() + .with_row_index("index", None) + .filter(col("a").eq(lit(3i32))) + .collect()?; + assert_eq!( + out.column("index")? + .idx()? + .into_no_null_iter() + .collect::>(), + &[3] + ); + + let out = df + .clone() + .lazy() + .slice(1, 2) + .with_row_index("index", None) + .collect()?; + assert_eq!( + out.column("index")? + .idx()? + .into_no_null_iter() + .collect::>(), + &[0, 1] + ); + + let out = df + .lazy() + .filter(col("a").eq(lit(3i32))) + .with_row_index("index", None) + .collect()?; + assert_eq!( + out.column("index")? + .idx()? + .into_no_null_iter() + .collect::>(), + &[0] + ); + + Ok(()) +} + +#[cfg(all(feature = "concat_str", feature = "strings"))] +#[test] +fn test_string_addition_to_concat_str() -> PolarsResult<()> { + let df = df![ + "a"=> ["a"], + "b"=> ["b"], + ]?; + + let q = df + .lazy() + .select([lit("foo") + col("a") + col("b") + lit("bar")]); + + let (mut expr_arena, mut lp_arena) = get_arenas(); + let root = q.clone().optimize(&mut lp_arena, &mut expr_arena)?; + let lp = lp_arena.get(root); + let mut exprs = lp.get_exprs(); + let e = exprs.pop().unwrap(); + if let AExpr::Function { input, .. } = expr_arena.get(e.node()) { + // the concat_str has the 4 expressions as input + assert_eq!(input.len(), 4); + } else { + panic!() + } + + let out = q.collect()?; + let s = out.column("literal")?; + assert_eq!(s.get(0)?, AnyValue::String("fooabbar")); + + Ok(()) +} +#[test] +fn test_with_column_prune() -> PolarsResult<()> { + // don't + let df = df![ + "c0" => [0], + "c1" => [0], + "c2" => [0], + ]?; + let (mut expr_arena, mut lp_arena) = get_arenas(); + + // only a single expression pruned and only one column selection + let q = df + .clone() + .lazy() + .with_columns([col("c0"), col("c1").alias("c4")]) + .select([col("c1"), col("c4")]); + let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + (&lp_arena).iter(lp).for_each(|(_, lp)| { + use IR::*; + match lp { + DataFrameScan { output_schema, .. } => { + let projection = output_schema.as_ref().unwrap(); + assert_eq!(projection.len(), 1); + let name = projection.get_at_index(0).unwrap().0; + assert_eq!(name, "c1"); + }, + HStack { exprs, .. } => { + assert_eq!(exprs.len(), 1); + }, + _ => {}, + }; + }); + + // whole `with_columns` pruned + let mut q = df.lazy().with_column(col("c0")).select([col("c1")]); + + let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap(); + + // check if with_column is pruned + assert!((&lp_arena).iter(lp).all(|(_, lp)| { + use IR::*; + + matches!(lp, SimpleProjection { .. } | DataFrameScan { .. }) + })); + assert_eq!( + q.collect_schema().unwrap().as_ref(), + &Schema::from_iter([Field::new(PlSmallStr::from_static("c1"), DataType::Int32)]) + ); + Ok(()) +} + +#[test] +#[cfg(feature = "csv")] +fn test_slice_at_scan_group_by() -> PolarsResult<()> { + let ldf = scan_foods_csv(); + + // this tests if slice pushdown restarts aggregation nodes (it did not) + let q = ldf + .slice(0, 5) + .filter(col("calories").lt(lit(10))) + .group_by([col("calories")]) + .agg([col("fats_g").first()]) + .select([col("fats_g")]); + + assert!(slice_at_scan(q)); + Ok(()) +} + +#[test] +fn test_flatten_unions() -> PolarsResult<()> { + let (mut expr_arena, mut lp_arena) = get_arenas(); + + let lf = df! { + "a" => [1,2,3,4,5], + } + .unwrap() + .lazy(); + + let args = UnionArgs { + rechunk: false, + parallel: true, + ..Default::default() + }; + let lf2 = concat(&[lf.clone(), lf.clone()], args).unwrap(); + let lf3 = concat(&[lf.clone(), lf.clone(), lf], args).unwrap(); + let lf4 = concat(&[lf2, lf3], args).unwrap(); + let root = lf4.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + let lp = lp_arena.get(root); + match lp { + IR::Union { inputs, .. } => { + // we make sure that the nested unions are flattened into a single union + assert_eq!(inputs.len(), 5); + }, + _ => panic!(), + } + Ok(()) +} + +fn num_occurrences(s: &str, needle: &str) -> usize { + let mut i = 0; + let mut num = 0; + + while let Some(n) = s[i..].find(needle) { + i += n + 1; + num += 1; + } + + num +} + +#[test] +fn test_cluster_with_columns() -> Result<(), Box> { + use polars_core::prelude::*; + + let df = df!("foo" => &[0.5, 1.7, 3.2], + "bar" => &[4.1, 1.5, 9.2])?; + + let df = df + .lazy() + .without_optimizations() + .with_cluster_with_columns(true) + .with_columns([col("foo") * lit(2.0)]) + .with_columns([col("bar") / lit(1.5)]); + + let unoptimized = df.clone().to_alp().unwrap(); + let optimized = df.clone().to_alp_optimized().unwrap(); + + let unoptimized = unoptimized.describe(); + let optimized = optimized.describe(); + + println!("\n---\n"); + + println!("Unoptimized:\n{unoptimized}",); + println!("\n---\n"); + println!("Optimized:\n{optimized}"); + + assert_eq!(num_occurrences(&unoptimized, "WITH_COLUMNS"), 2); + assert_eq!(num_occurrences(&optimized, "WITH_COLUMNS"), 1); + + Ok(()) +} + +#[test] +fn test_cluster_with_columns_dependency() -> Result<(), Box> { + use polars_core::prelude::*; + + let df = df!("foo" => &[0.5, 1.7, 3.2], + "bar" => &[4.1, 1.5, 9.2])?; + + let df = df + .lazy() + .without_optimizations() + .with_cluster_with_columns(true) + .with_columns([col("foo").alias("buzz")]) + .with_columns([col("buzz")]); + + let unoptimized = df.clone().to_alp().unwrap(); + let optimized = df.clone().to_alp_optimized().unwrap(); + + let unoptimized = unoptimized.describe(); + let optimized = optimized.describe(); + + println!("\n---\n"); + + println!("Unoptimized:\n{unoptimized}",); + println!("\n---\n"); + println!("Optimized:\n{optimized}"); + + assert_eq!(num_occurrences(&unoptimized, "WITH_COLUMNS"), 2); + assert_eq!(num_occurrences(&optimized, "WITH_COLUMNS"), 2); + + Ok(()) +} + +#[test] +fn test_cluster_with_columns_partial() -> Result<(), Box> { + use polars_core::prelude::*; + + let df = df!("foo" => &[0.5, 1.7, 3.2], + "bar" => &[4.1, 1.5, 9.2])?; + + let df = df + .lazy() + .without_optimizations() + .with_cluster_with_columns(true) + .with_columns([col("foo").alias("buzz")]) + .with_columns([col("buzz"), col("foo") * lit(2.0)]); + + let unoptimized = df.clone().to_alp().unwrap(); + let optimized = df.clone().to_alp_optimized().unwrap(); + + let unoptimized = unoptimized.describe(); + let optimized = optimized.describe(); + + println!("\n---\n"); + + println!("Unoptimized:\n{unoptimized}",); + println!("\n---\n"); + println!("Optimized:\n{optimized}"); + + assert!(unoptimized.contains(r#"[col("buzz"), [(col("foo")) * (2.0)]]"#)); + assert!(unoptimized.contains(r#"[col("foo").alias("buzz")]"#)); + assert!(optimized.contains(r#"[col("buzz")]"#)); + assert!(optimized.contains(r#"[col("foo").alias("buzz"), [(col("foo")) * (2.0)]]"#)); + + Ok(()) +} + +#[test] +fn test_cluster_with_columns_chain() -> Result<(), Box> { + use polars_core::prelude::*; + + let df = df!("foo" => &[0.5, 1.7, 3.2], + "bar" => &[4.1, 1.5, 9.2])?; + + let df = df + .lazy() + .without_optimizations() + .with_cluster_with_columns(true) + .with_columns([col("foo").alias("foo1")]) + .with_columns([col("foo").alias("foo2")]) + .with_columns([col("foo").alias("foo3")]) + .with_columns([col("foo").alias("foo4")]); + + let unoptimized = df.clone().to_alp().unwrap(); + let optimized = df.clone().to_alp_optimized().unwrap(); + + let unoptimized = unoptimized.describe(); + let optimized = optimized.describe(); + + println!("\n---\n"); + + println!("Unoptimized:\n{unoptimized}",); + println!("\n---\n"); + println!("Optimized:\n{optimized}"); + + assert_eq!(num_occurrences(&unoptimized, "WITH_COLUMNS"), 4); + assert_eq!(num_occurrences(&optimized, "WITH_COLUMNS"), 1); + + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/pdsh.rs b/crates/polars-lazy/src/tests/pdsh.rs new file mode 100644 index 000000000000..f0de0b641446 --- /dev/null +++ b/crates/polars-lazy/src/tests/pdsh.rs @@ -0,0 +1,113 @@ +//! The PDSH files only got ten rows, so after all the joins filters there is not data +//! Still we can use this to test the schema, operation correctness on empty data, and optimizations +//! taken. +use super::*; + +const fn base_path() -> &'static str { + "../../examples/datasets/pds_heads" +} + +fn region() -> LazyFrame { + let base_path = base_path(); + LazyFrame::scan_ipc( + format!("{base_path}/region.feather"), + ScanArgsIpc::default(), + ) + .unwrap() +} +fn nation() -> LazyFrame { + let base_path = base_path(); + LazyFrame::scan_ipc( + format!("{base_path}/nation.feather"), + ScanArgsIpc::default(), + ) + .unwrap() +} + +fn supplier() -> LazyFrame { + let base_path = base_path(); + LazyFrame::scan_ipc( + format!("{base_path}/supplier.feather"), + ScanArgsIpc::default(), + ) + .unwrap() +} + +fn part() -> LazyFrame { + let base_path = base_path(); + LazyFrame::scan_ipc(format!("{base_path}/part.feather"), ScanArgsIpc::default()).unwrap() +} + +fn partsupp() -> LazyFrame { + let base_path = base_path(); + LazyFrame::scan_ipc( + format!("{base_path}/partsupp.feather"), + ScanArgsIpc::default(), + ) + .unwrap() +} + +#[test] +fn test_q2() -> PolarsResult<()> { + let q1 = part() + .inner_join(partsupp(), "p_partkey", "ps_partkey") + .inner_join(supplier(), "ps_suppkey", "s_suppkey") + .inner_join(nation(), "s_nationkey", "n_nationkey") + .inner_join(region(), "n_regionkey", "r_regionkey") + .filter(col("p_size").eq(15)) + .filter(col("p_type").str().ends_with(lit("BRASS".to_string()))); + let q = q1 + .clone() + .group_by([col("p_partkey")]) + .agg([col("ps_supplycost").min()]) + .join( + q1, + [col("p_partkey"), col("ps_supplycost")], + [col("p_partkey"), col("ps_supplycost")], + JoinType::Inner.into(), + ) + .select([cols([ + "s_acctbal", + "s_name", + "n_name", + "p_partkey", + "p_mfgr", + "s_address", + "s_phone", + "s_comment", + ])]) + .sort_by_exprs( + [cols(["s_acctbal", "n_name", "s_name", "p_partkey"])], + SortMultipleOptions::default() + .with_order_descending_multi([true, false, false, false]) + .with_maintain_order(true), + ) + .limit(100) + .with_comm_subplan_elim(true); + + let IRPlan { + lp_top, lp_arena, .. + } = q.clone().to_alp_optimized().unwrap(); + assert_eq!( + (&lp_arena) + .iter(lp_top) + .filter(|(_, alp)| matches!(alp, IR::Cache { .. })) + .count(), + 2 + ); + + let out = q.collect()?; + let schema = Schema::from_iter([ + Field::new("s_acctbal".into(), DataType::Float64), + Field::new("s_name".into(), DataType::String), + Field::new("n_name".into(), DataType::String), + Field::new("p_partkey".into(), DataType::Int64), + Field::new("p_mfgr".into(), DataType::String), + Field::new("s_address".into(), DataType::String), + Field::new("s_phone".into(), DataType::String), + Field::new("s_comment".into(), DataType::String), + ]); + assert_eq!(&**out.schema(), &schema); + + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/predicate_queries.rs b/crates/polars-lazy/src/tests/predicate_queries.rs new file mode 100644 index 000000000000..c5c2fcc8798d --- /dev/null +++ b/crates/polars-lazy/src/tests/predicate_queries.rs @@ -0,0 +1,334 @@ +use super::*; + +#[test] +#[cfg(feature = "parquet")] +fn test_multiple_roots() -> PolarsResult<()> { + let mut expr_arena = Arena::with_capacity(16); + let mut lp_arena = Arena::with_capacity(8); + + let lf = scan_foods_parquet(false).select([col("calories").alias("bar")]); + + // this produces a predicate with two root columns, this test if we can + // deal with multiple roots + let lf = lf.filter(col("bar").gt(lit(45i32))); + let lf = lf.filter(col("bar").lt(lit(110i32))); + + // also check if all predicates are combined and pushed down + let root = lf.clone().optimize(&mut lp_arena, &mut expr_arena)?; + assert!(predicate_at_scan(lf)); + // and that we don't have any filter node + assert!( + !(&lp_arena) + .iter(root) + .any(|(_, lp)| matches!(lp, IR::Filter { .. })) + ); + + Ok(()) +} + +#[test] +#[cfg(all(feature = "is_in", feature = "strings", feature = "dtype-categorical"))] +fn test_issue_2472() -> PolarsResult<()> { + let df = df![ + "group" => ["54360-2001-0-20020312-4-1" + ,"39444-2020-0-20210418-4-1" + ,"68398-2020-0-20201216-4-1" + ,"30910-2020-0-20210223-4-1" + ,"71060-2020-0-20210315-4-1" + ,"47959-2020-0-20210305-4-1" + ,"63212-2018-0-20181007-2-2" + ,"61465-2018-0-20181018-2-2" + ] + ]?; + let base = df + .lazy() + .with_column(col("group").cast(DataType::Categorical(None, Default::default()))); + + let extract = col("group") + .cast(DataType::String) + .str() + .extract(lit(r"(\d+-){4}(\w+)-"), 2) + .cast(DataType::Int32) + .alias("age"); + let predicate = col("age").is_in(lit(Series::new("".into(), [2i32])), false); + + let out = base + .clone() + .with_column(extract.clone()) + .filter(predicate.clone()) + .collect()?; + + assert_eq!(out.shape(), (2, 2)); + + let out = base.select([extract]).filter(predicate).collect()?; + assert_eq!(out.shape(), (2, 1)); + + Ok(()) +} + +#[test] +fn test_pass_unrelated_apply() -> PolarsResult<()> { + // maps should not influence a predicate of a different column as maps should not depend on previous values + let df = fruits_cars(); + + let q = df + .lazy() + .with_column(col("A").map( + |s| Ok(Some(s.is_null().into_column())), + GetOutput::from_type(DataType::Boolean), + )) + .filter(col("B").gt(lit(10i32))); + + assert!(predicate_at_scan(q)); + + Ok(()) +} + +#[test] +fn filter_added_column_issue_2470() -> PolarsResult<()> { + let df = fruits_cars(); + + // the binary expression in the predicate lead to an incorrect pushdown because the rhs + // was not checked on the schema. + let out = df + .lazy() + .select([col("A"), lit(NULL).alias("foo")]) + .filter(col("A").gt(lit(2i32)).and(col("foo").is_null())) + .collect()?; + assert_eq!(out.shape(), (3, 2)); + + Ok(()) +} + +#[test] +fn filter_blocked_by_map() -> PolarsResult<()> { + let df = fruits_cars(); + + let allowed = OptFlags::default() & !OptFlags::PREDICATE_PUSHDOWN; + let q = df + .lazy() + .map(Ok, allowed, None, None) + .filter(col("A").gt(lit(2i32))); + + assert!(!predicate_at_scan(q.clone())); + let out = q.collect()?; + assert_eq!(out.shape(), (3, 4)); + + Ok(()) +} + +#[test] +#[cfg(all(feature = "temporal", feature = "strings"))] +fn test_strptime_block_predicate() -> PolarsResult<()> { + let df = df![ + "date" => ["2021-01-01", "2021-01-02"] + ]?; + + let q = df + .lazy() + .with_column(col("date").str().to_date(StrptimeOptions { + ..Default::default() + })) + .filter( + col("date").gt(NaiveDate::from_ymd_opt(2021, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .lit()), + ); + + assert!(!predicate_at_scan(q.clone())); + let df = q.collect()?; + assert_eq!(df.shape(), (1, 1)); + + Ok(()) +} + +#[test] +fn test_strict_cast_predicate_pushdown() -> PolarsResult<()> { + let df = df![ + "a" => ["a", "b", "c"] + ]?; + + let lf = df + .lazy() + .with_column(col("a").cast(DataType::Int32)) + .filter(col("a").is_null()); + + assert!(!predicate_at_scan(lf.clone())); + let out = lf.collect()?; + assert_eq!(out.shape(), (3, 1)); + Ok(()) +} + +#[test] +fn test_filter_nulls_created_by_join() -> PolarsResult<()> { + // #2602 + let a = df![ + "key" => ["foo", "bar"], + "bar" => [1, 2] + ]?; + + let b = df![ + "key"=> ["bar"] + ]? + .lazy() + .with_column(lit(true).alias("flag")); + + let out = a + .clone() + .lazy() + .join(b.clone(), [col("key")], [col("key")], JoinType::Left.into()) + .filter(col("flag").is_null()) + .collect()?; + let expected = df![ + "key" => ["foo"], + "bar" => [1], + "flag" => &[None, Some(true)][0..1] + ]?; + assert!(out.equals_missing(&expected)); + + let out = a + .lazy() + .join(b, [col("key")], [col("key")], JoinType::Left.into()) + .filter(col("flag").is_null()) + .with_predicate_pushdown(false) + .collect()?; + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_filter_null_creation_by_cast() -> PolarsResult<()> { + let df = df![ + "int" => [1, 2, 3], + "empty" => ["", "", ""] + ]?; + + let out = df + .lazy() + .with_column(col("empty").cast(DataType::Int32).alias("empty")) + .filter(col("empty").is_null().and(col("int").eq(lit(3i32)))) + .collect()?; + + let expected = df![ + "int" => [3], + "empty" => &[None, Some(1i32)][..1] + ]?; + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_predicate_pd_apply() -> PolarsResult<()> { + let q = df![ + "a" => [1, 2, 3], + ]? + .lazy() + .select([ + // map_list is use in python `col().apply` + col("a"), + col("a") + .map_list(|s| Ok(Some(s)), GetOutput::same_type()) + .alias("a_applied"), + ]) + .filter(col("a").lt(lit(3))); + + assert!(predicate_at_scan(q)); + Ok(()) +} +#[test] +#[cfg(feature = "cse")] +fn test_predicate_on_join_suffix_4788() -> PolarsResult<()> { + let lf = df![ + "x" => [1, 2], + "y" => [1, 1], + ]? + .lazy(); + + let q = (lf.clone().join_builder().with(lf)) + .left_on([col("y")]) + .right_on([col("y")]) + .suffix("_") + .finish() + .filter(col("x").eq(1)) + .with_comm_subplan_elim(false); + + // the left hand side should have a predicate + assert!(predicate_at_scan(q.clone())); + + let expected = df![ + "x" => [1, 1], + "y" => [1, 1], + "x_" => [1, 2], + ]?; + assert_eq!(q.collect()?, expected); + + Ok(()) +} + +#[test] +fn test_push_join_col_predicates_to_both_sides_7247() -> PolarsResult<()> { + let df1 = df! { + "a" => ["a1", "a2"], + "b" => ["b1", "b2"], + }?; + let df2 = df! { + "a" => ["a1", "a1", "a2"], + "b2" => ["b1", "b1", "b2"], + "c" => ["a1", "c", "a2"] + }?; + let df = df1.lazy().join( + df2.lazy(), + [col("a"), col("b")], + [col("a"), col("b2")], + JoinArgs::new(JoinType::Inner), + ); + let q = df + .filter(col("a").eq(lit("a1"))) + .filter(col("a").eq(col("c"))); + + predicate_at_all_scans(q.clone()); + + let out = q.collect()?; + let expected = df![ + "a" => ["a1"], + "b" => ["b1"], + "c" => ["a1"], + ]?; + assert_eq!(out, expected); + Ok(()) +} + +#[test] +#[cfg(feature = "semi_anti_join")] +fn test_push_join_col_predicates_to_both_sides_semi_12565() -> PolarsResult<()> { + let df1 = df! { + "a" => ["a1", "a2"], + "b" => ["b1", "b2"], + }?; + let df2 = df! { + "a" => ["a1", "a1", "a2"], + "b2" => ["b1", "b1", "b2"], + "c" => ["a1", "c", "a2"] + }?; + let df = df1.lazy().join( + df2.lazy(), + [col("a"), col("b")], + [col("a"), col("b2")], + JoinArgs::new(JoinType::Semi), + ); + let q = df.filter(col("a").eq(lit("a1"))); + + predicate_at_all_scans(q.clone()); + + let out = q.collect()?; + let expected = df![ + "a" => ["a1"], + "b" => ["b1"], + ]?; + assert_eq!(out, expected); + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/projection_queries.rs b/crates/polars-lazy/src/tests/projection_queries.rs new file mode 100644 index 000000000000..d1594a461a86 --- /dev/null +++ b/crates/polars-lazy/src/tests/projection_queries.rs @@ -0,0 +1,175 @@ +use polars_ops::frame::JoinCoalesce; + +use super::*; + +#[test] +fn test_join_suffix_and_drop() -> PolarsResult<()> { + let weight = df![ + "id" => [1, 2, 3, 4, 5, 0], + "wgt" => [4.32, 5.23, 2.33, 23.399, 392.2, 0.0] + ]? + .lazy(); + + let ped = df![ + "id"=> [1, 2, 3, 4, 5], + "sireid"=> [0, 0, 1, 3, 3] + ]? + .lazy(); + + let sumry = weight + .clone() + .filter(col("id").eq(lit(2i32))) + .inner_join(ped, "id", "id"); + + let out = sumry + .join_builder() + .with(weight) + .left_on([col("sireid")]) + .right_on([col("id")]) + .suffix("_sire") + .finish() + .drop(["sireid"]) + .collect()?; + + assert_eq!(out.shape(), (1, 3)); + + Ok(()) +} + +#[test] +#[cfg(feature = "cross_join")] +fn test_cross_join_pd() -> PolarsResult<()> { + let food = df![ + "name"=> ["Omelette", "Fried Egg"], + "price" => [8, 5] + ]?; + + let drink = df![ + "name" => ["Orange Juice", "Tea"], + "price" => [5, 4] + ]?; + + let q = food.lazy().cross_join(drink.lazy(), None).select([ + col("name").alias("food"), + col("name_right").alias("beverage"), + (col("price") + col("price_right")).alias("total"), + ]); + + let out = q.collect()?; + let expected = df![ + "food" => ["Omelette", "Omelette", "Fried Egg", "Fried Egg"], + "beverage" => ["Orange Juice", "Tea", "Orange Juice", "Tea"], + "total" => [13, 12, 10, 9] + ]?; + + assert!(out.equals(&expected)); + Ok(()) +} + +#[test] +fn test_row_number_pd() -> PolarsResult<()> { + let df = df![ + "x" => [1, 2, 3], + "y" => [3, 2, 1], + ]?; + + let df = df + .lazy() + .with_row_index("index", None) + .select([col("index"), col("x") * lit(3i32)]) + .collect()?; + + let expected = df![ + "index" => [0 as IdxSize, 1, 2], + "x" => [3i32, 6, 9] + ]?; + + assert!(df.equals(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "cse")] +fn scan_join_same_file() -> PolarsResult<()> { + let lf = LazyCsvReader::new(FOODS_CSV).finish()?; + + for cse in [true, false] { + let partial = lf.clone().select([col("category")]).limit(5); + let q = lf + .clone() + .join( + partial, + [col("category")], + [col("category")], + JoinType::Inner.into(), + ) + .with_comm_subplan_elim(cse); + let out = q.collect()?; + assert_eq!( + out.get_column_names(), + &["category", "calories", "fats_g", "sugars_g"] + ); + } + Ok(()) +} + +#[test] +#[cfg(all(feature = "regex", feature = "concat_str"))] +fn concat_str_regex_expansion() -> PolarsResult<()> { + let df = df![ + "a"=> [1, 1, 1], + "b_a_1"=> ["a--", "", ""], + "b_a_2"=> ["", "b--", ""], + "b_a_3"=> ["", "", "c--"] + ]? + .lazy(); + let out = df + .select([concat_str([col(r"^b_a_\d$")], ";", false).alias("concatenated")]) + .collect()?; + let s = out.column("concatenated")?; + assert_eq!( + s, + &Column::new("concatenated".into(), ["a--;;", ";b--;", ";;c--"]) + ); + + Ok(()) +} + +#[test] +fn test_coalesce_toggle_projection_pushdown() -> PolarsResult<()> { + // Test that the optimizer toggle coalesce to true if the non-coalesced column isn't used. + let q1 = df!["a" => [1], + "b" => [2] + ]? + .lazy(); + + let q2 = df!["a" => [1], + "c" => [2] + ]? + .lazy(); + + let plan = q1 + .join( + q2, + [col("a")], + [col("a")], + JoinArgs { + how: JoinType::Left, + coalesce: JoinCoalesce::KeepColumns, + ..Default::default() + }, + ) + .select([col("a"), col("b")]) + .to_alp_optimized()?; + + let node = plan.lp_top; + let lp_arena = plan.lp_arena; + + assert!((&lp_arena).iter(node).all(|(_, plan)| match plan { + IR::Join { options, .. } => options.args.should_coalesce(), + _ => true, + })); + + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs new file mode 100644 index 000000000000..c6002b462463 --- /dev/null +++ b/crates/polars-lazy/src/tests/queries.rs @@ -0,0 +1,1976 @@ +#[cfg(feature = "diff")] +use polars_core::series::ops::NullBehavior; + +use super::*; + +#[test] +fn test_lazy_with_column() { + let df = get_df() + .lazy() + .with_column(lit(10).alias("foo")) + .collect() + .unwrap(); + assert_eq!(df.width(), 6); + assert!(df.column("foo").is_ok()); +} + +#[test] +fn test_lazy_exec() { + let df = get_df(); + let _new = df + .clone() + .lazy() + .select([col("sepal_width"), col("variety")]) + .sort(["sepal_width"], Default::default()) + .collect(); + + let new = df + .lazy() + .filter(not(col("sepal_width").lt(lit(3.5)))) + .collect() + .unwrap(); + + let check = new.column("sepal_width").unwrap().f64().unwrap().gt(3.4); + assert!(check.all()) +} + +#[test] +fn test_lazy_alias() { + let df = get_df(); + let new = df + .lazy() + .select([col("sepal_width").alias("petals"), col("sepal_width")]) + .collect() + .unwrap(); + assert_eq!(new.get_column_names(), &["petals", "sepal_width"]); +} + +#[test] +#[cfg(feature = "pivot")] +fn test_lazy_unpivot() { + let df = get_df(); + + let args = UnpivotArgsDSL { + on: vec!["sepal_length".into(), "sepal_width".into()], + index: vec!["petal_width".into(), "petal_length".into()], + ..Default::default() + }; + + let out = df + .lazy() + .unpivot(args) + .filter(col("variable").eq(lit("sepal_length"))) + .select([col("variable"), col("petal_width"), col("value")]) + .collect() + .unwrap(); + assert_eq!(out.shape(), (7, 3)); +} + +#[test] +fn test_lazy_drop_nulls() { + let df = df! { + "foo" => &[Some(1), None, Some(3)], + "bar" => &[Some(1), Some(2), None] + } + .unwrap(); + + let new = df.lazy().drop_nulls(None).collect().unwrap(); + let out = df! { + "foo" => &[Some(1)], + "bar" => &[Some(1)] + } + .unwrap(); + assert!(new.equals(&out)); +} + +#[test] +fn test_lazy_udf() { + let df = get_df(); + let new = df + .lazy() + .select([col("sepal_width").map(|s| Ok(Some(s * 200.0)), GetOutput::same_type())]) + .collect() + .unwrap(); + assert_eq!( + new.column("sepal_width").unwrap().f64().unwrap().get(0), + Some(700.0) + ); +} + +#[test] +fn test_lazy_is_null() { + let df = get_df(); + let new = df + .clone() + .lazy() + .filter(col("sepal_width").is_null()) + .collect() + .unwrap(); + + assert_eq!(new.height(), 0); + + let new = df + .clone() + .lazy() + .filter(col("sepal_width").is_not_null()) + .collect() + .unwrap(); + assert_eq!(new.height(), df.height()); + + let new = df + .lazy() + .group_by([col("variety")]) + .agg([col("sepal_width").min()]) + .collect() + .unwrap(); + + assert_eq!(new.shape(), (1, 2)); +} + +#[test] +fn test_lazy_pushdown_through_agg() { + // An aggregation changes the schema names, check if the pushdown succeeds. + let df = get_df(); + let new = df + .lazy() + .group_by([col("variety")]) + .agg([ + col("sepal_length").min(), + col("petal_length").min().alias("foo"), + ]) + .select([col("foo")]) + // second selection is to test if optimizer can handle that + .select([col("foo").alias("bar")]) + .collect() + .unwrap(); + + assert_eq!(new.shape(), (1, 1)); + let bar = new.column("bar").unwrap(); + assert_eq!(bar.get(0).unwrap(), AnyValue::Float64(1.3)); +} + +#[test] +fn test_lazy_shift() { + let df = get_df(); + let new = df + .lazy() + .select([col("sepal_width").alias("foo").shift(lit(2))]) + .collect() + .unwrap(); + assert_eq!(new.column("foo").unwrap().f64().unwrap().get(0), None); +} + +#[test] +fn test_shift_and_fill() -> PolarsResult<()> { + let out = df![ + "a" => [1, 2, 3] + ]? + .lazy() + .select([col("a").shift_and_fill(lit(-1), lit(5))]) + .collect()?; + + let out = out.column("a")?; + assert_eq!(Vec::from(out.i32()?), &[Some(2), Some(3), Some(5)]); + Ok(()) +} + +#[test] +fn test_shift_and_fill_non_numeric() -> PolarsResult<()> { + let out = df![ + "bool" => [true, false, true], + ]? + .lazy() + .select([col("bool").shift_and_fill(1, true)]) + .collect()?; + + let out = out.column("bool")?; + assert_eq!( + Vec::from(out.bool()?), + &[Some(true), Some(true), Some(false)] + ); + Ok(()) +} + +#[test] +fn test_lazy_ternary_and_predicates() { + let df = get_df(); + // test if this runs. This failed because is_not_null changes the schema name, so we + // really need to check the root column + let ldf = df + .clone() + .lazy() + .with_column(lit(3).alias("foo")) + .filter(col("foo").is_not_null()); + let _new = ldf.collect().unwrap(); + + let ldf = df + .lazy() + .with_column( + when(col("sepal_length").lt(lit(5.0))) + .then( + lit(3), // is another type on purpose to check type coercion + ) + .otherwise(col("sepal_width")) + .alias("foo"), + ) + .filter(col("foo").gt(lit(3.0))); + + let new = ldf.collect().unwrap(); + let length = new.column("sepal_length").unwrap(); + assert_eq!( + length, + &Column::new("sepal_length".into(), &[5.1f64, 5.0, 5.4]) + ); + assert_eq!(new.shape(), (3, 6)); +} + +#[test] +fn test_lazy_binary_ops() { + let df = df!("a" => &[1, 2, 3, 4, 5, ]).unwrap(); + let new = df + .lazy() + .select([col("a").eq(lit(2)).alias("foo")]) + .collect() + .unwrap(); + assert_eq!( + new.column("foo") + .unwrap() + .as_materialized_series() + .sum::() + .unwrap(), + 1 + ); +} + +#[test] +fn test_lazy_query_2() { + let df = load_df(); + let ldf = df + .lazy() + .with_column(col("a").map(|s| Ok(Some(s * 2)), GetOutput::same_type())) + .filter(col("a").lt(lit(2))) + .select([col("b"), col("a")]); + + let new = ldf.collect().unwrap(); + assert_eq!(new.shape(), (0, 2)); +} + +#[test] +#[cfg(feature = "csv")] +fn test_lazy_query_3() { + // query checks if schema of scanning is not changed by aggregation + let _ = scan_foods_csv() + .group_by([col("calories")]) + .agg([col("fats_g").max()]) + .collect() + .unwrap(); +} + +#[test] +fn test_lazy_query_4() -> PolarsResult<()> { + let df = df! { + "uid" => [0, 0, 0, 1, 1, 1], + "day" => [1, 2, 3, 1, 2, 3], + "cumcases" => [10, 12, 15, 25, 30, 41] + } + .unwrap(); + + let base_df = df.lazy(); + + let out = base_df + .clone() + .group_by([col("uid")]) + .agg([ + col("day").alias("day"), + col("cumcases") + .apply( + |s: Column| (&s - &(s.shift(1))).map(Some), + GetOutput::same_type(), + ) + .alias("diff_cases"), + ]) + .explode([col("day"), col("diff_cases")]) + .join( + base_df, + [col("uid"), col("day")], + [col("uid"), col("day")], + JoinType::Inner.into(), + ) + .collect() + .unwrap(); + assert_eq!( + Vec::from(out.column("diff_cases").unwrap().i32().unwrap()), + &[None, Some(2), Some(3), None, Some(5), Some(11)] + ); + + Ok(()) +} + +#[test] +fn test_lazy_query_5() { + // if this one fails, the list builder probably does not handle offsets + let df = df! { + "uid" => [0, 0, 0, 1, 1, 1], + "day" => [1, 2, 4, 1, 2, 3], + "cumcases" => [10, 12, 15, 25, 30, 41] + } + .unwrap(); + + let out = df + .lazy() + .group_by([col("uid")]) + .agg([col("day").head(Some(2))]) + .collect() + .unwrap(); + let s = out + .select_at_idx(1) + .unwrap() + .list() + .unwrap() + .get_as_series(0) + .unwrap(); + assert_eq!(s.len(), 2); + let s = out + .select_at_idx(1) + .unwrap() + .list() + .unwrap() + .get_as_series(0) + .unwrap(); + assert_eq!(s.len(), 2); +} + +#[test] +#[cfg(feature = "is_in")] +fn test_lazy_query_8() -> PolarsResult<()> { + // https://github.com/pola-rs/polars/issues/842 + let df = df![ + "A" => [1, 2, 3], + "B" => [1, 2, 3], + "C" => [1, 2, 3], + "D" => [1, 2, 3], + "E" => [1, 2, 3] + ]?; + + let mut selection = vec![]; + + for &c in &["A", "B", "C", "D", "E"] { + let e = when(col(c).is_in(col("E"), false)) + .then(col("A")) + .otherwise(Null {}.lit()) + .alias(c); + selection.push(e); + } + + let out = df + .lazy() + .select(selection) + .filter(col("D").gt(lit(1))) + .collect()?; + assert_eq!(out.shape(), (2, 5)); + Ok(()) +} + +#[test] +fn test_lazy_query_9() -> PolarsResult<()> { + // https://github.com/pola-rs/polars/issues/958 + let cities = df![ + "Cities.City"=> ["Moscow", "Berlin", "Paris","Hamburg", "Lyon", "Novosibirsk"], + "Cities.Population"=> [11.92, 3.645, 2.161, 1.841, 0.513, 1.511], + "Cities.Country"=> ["Russia", "Germany", "France", "Germany", "France", "Russia"] + ]?; + + let sales = df![ + "Sales.City"=> ["Moscow", "Berlin", "Paris", "Moscow", "Berlin", "Paris", "Moscow", "Berlin", "Paris"], + "Sales.Item"=> ["Item A", "Item A","Item A", + "Item B", "Item B","Item B", + "Item C", "Item C","Item C"], + "Sales.Amount"=> [200, 180, 100, + 3, 30, 20, + 90, 130, 125] + ]?; + + let out = sales + .lazy() + .join( + cities.lazy(), + [col("Sales.City")], + [col("Cities.City")], + JoinType::Inner.into(), + ) + .group_by([col("Cities.Country")]) + .agg([col("Sales.Amount").sum().alias("sum")]) + .sort(["sum"], Default::default()) + .collect()?; + let vals = out + .column("sum")? + .i32()? + .into_no_null_iter() + .collect::>(); + assert_eq!(vals, &[245, 293, 340]); + Ok(()) +} + +#[test] +#[cfg(all( + feature = "temporal", + feature = "dtype-datetime", + feature = "dtype-date", + feature = "dtype-duration" +))] +fn test_lazy_query_10() { + use chrono::Duration as ChronoDuration; + let date = NaiveDate::from_ymd_opt(2021, 3, 5).unwrap(); + let x = DatetimeChunked::from_naive_datetime( + "x".into(), + [ + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(13, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(14, 0, 0).unwrap()), + ], + TimeUnit::Nanoseconds, + ) + .into_column(); + let y = DatetimeChunked::from_naive_datetime( + "y".into(), + [ + NaiveDateTime::new(date, NaiveTime::from_hms_opt(11, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(11, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(11, 0, 0).unwrap()), + ], + TimeUnit::Nanoseconds, + ) + .into_column(); + let df = DataFrame::new(vec![x, y]).unwrap(); + let out = df + .lazy() + .select(&[(col("x") - col("y")).alias("z")]) + .collect() + .unwrap(); + let z = DurationChunked::from_duration( + "z".into(), + [ + ChronoDuration::try_hours(1).unwrap(), + ChronoDuration::try_hours(2).unwrap(), + ChronoDuration::try_hours(3).unwrap(), + ], + TimeUnit::Nanoseconds, + ) + .into_column(); + assert!(out.column("z").unwrap().equals(&z)); + let x = DatetimeChunked::from_naive_datetime( + "x".into(), + [ + NaiveDateTime::new(date, NaiveTime::from_hms_opt(2, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(3, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(4, 0, 0).unwrap()), + ], + TimeUnit::Milliseconds, + ) + .into_column(); + let y = DatetimeChunked::from_naive_datetime( + "y".into(), + [ + NaiveDateTime::new(date, NaiveTime::from_hms_opt(1, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(1, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(1, 0, 0).unwrap()), + ], + TimeUnit::Nanoseconds, + ) + .into_column(); + let df = DataFrame::new(vec![x, y]).unwrap(); + let out = df + .lazy() + .select(&[(col("x") - col("y")).alias("z")]) + .collect() + .unwrap(); + assert!( + out.column("z") + .unwrap() + .equals(&z.cast(&DataType::Duration(TimeUnit::Milliseconds)).unwrap()) + ); +} + +#[test] +#[cfg(all( + feature = "temporal", + feature = "dtype-date", + feature = "dtype-datetime" +))] +fn test_lazy_query_7() { + let date = NaiveDate::from_ymd_opt(2021, 3, 5).unwrap(); + let dates = [ + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 0, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 1, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 2, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 3, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 4, 0).unwrap()), + NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 5, 0).unwrap()), + ]; + let data = vec![Some(1.), Some(2.), Some(3.), Some(4.), None, None]; + let df = DataFrame::new(vec![ + DatetimeChunked::from_naive_datetime("date".into(), dates, TimeUnit::Nanoseconds) + .into_column(), + Column::new("data".into(), data), + ]) + .unwrap(); + // this tests if predicate pushdown not interferes with the shift data. + let out = df + .lazy() + .with_column(col("data").shift(lit(-1)).alias("output")) + .with_column(col("output").shift(lit(2)).alias("shifted")) + .filter(col("date").gt(lit(NaiveDateTime::new( + date, + NaiveTime::from_hms_opt(12, 2, 0).unwrap(), + )))) + .collect() + .unwrap(); + let a = out + .column("shifted") + .unwrap() + .as_materialized_series() + .sum::() + .unwrap() + - 7.0; + assert!(a < 0.01 && a > -0.01); +} + +#[test] +fn test_lazy_shift_and_fill_all() { + let data = &[1, 2, 3]; + let df = DataFrame::new(vec![Column::new("data".into(), data)]).unwrap(); + let out = df + .lazy() + .with_column(col("data").shift(lit(1)).fill_null(lit(0)).alias("output")) + .collect() + .unwrap(); + assert_eq!( + Vec::from(out.column("output").unwrap().i32().unwrap()), + vec![Some(0), Some(1), Some(2)] + ); +} + +#[test] +fn test_lazy_shift_operation_no_filter() { + // check if predicate pushdown optimization does not fail + let df = df! { + "a" => &[1, 2, 3], + "b" => &[1, 2, 3] + } + .unwrap(); + df.lazy() + .with_column(col("b").shift(lit(1)).alias("output")) + .collect() + .unwrap(); +} + +#[test] +fn test_simplify_expr() { + // Test if expression containing literals is simplified + let df = get_df(); + + let plan = df + .lazy() + .select(&[lit(1.0) + lit(1.0) + col("sepal_width")]) + .logical_plan; + + let mut expr_arena = Arena::new(); + let mut lp_arena = Arena::new(); + + #[allow(const_item_mutation)] + let lp_top = to_alp( + plan, + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::SIMPLIFY_EXPR, + ) + .unwrap(); + let plan = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + assert!( + matches!(plan, DslPlan::Select{ expr, ..} if matches!(&expr[0], Expr::BinaryExpr{left, ..} if **left == Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(2.0))))) + ); +} + +#[test] +fn test_lazy_wildcard() { + let df = load_df(); + let new = df.clone().lazy().select([col("*")]).collect().unwrap(); + assert_eq!(new.shape(), (5, 3)); + + let new = df + .lazy() + .group_by([col("b")]) + .agg([ + col("*").sum().name().suffix(""), + col("*").first().name().suffix("_first"), + ]) + .collect() + .unwrap(); + assert_eq!(new.shape(), (3, 5)); // Should exclude b from wildcard aggregations. +} + +#[test] +fn test_lazy_reverse() { + let df = load_df(); + assert!( + df.clone() + .lazy() + .reverse() + .collect() + .unwrap() + .equals_missing(&df.reverse()) + ) +} + +#[test] +fn test_lazy_fill_null() { + let df = df! { + "a" => &[None, Some(2.0)], + "b" => &[Some(1.0), None] + } + .unwrap(); + let out = df.lazy().fill_null(lit(10.0)).collect().unwrap(); + let correct = df! { + "a" => &[Some(10.0), Some(2.0)], + "b" => &[Some(1.0), Some(10.0)] + } + .unwrap(); + assert!(out.equals(&correct)); + assert_eq!(out.get_column_names(), vec!["a", "b"]) +} + +#[test] +fn test_lazy_double_projection() { + let df = df! { + "foo" => &[1, 2, 3] + } + .unwrap(); + df.lazy() + .select([col("foo").alias("bar")]) + .select([col("bar")]) + .collect() + .unwrap(); +} + +#[test] +fn test_type_coercion() { + let df = df! { + "foo" => &[1, 2, 3], + "bar" => &[1.0, 2.0, 3.0] + } + .unwrap(); + + let lp = df.lazy().select([col("foo") * col("bar")]).logical_plan; + + let mut expr_arena = Arena::new(); + let mut lp_arena = Arena::new(); + let lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena, &mut OptFlags::default()).unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + if let DslPlan::Select { expr, .. } = lp { + if let Expr::BinaryExpr { left, right, .. } = &expr[0] { + assert!(matches!(&**left, Expr::Cast { .. })); + // bar is already float, does not have to be coerced + assert!(matches!(&**right, Expr::Column { .. })); + } else { + panic!() + } + }; +} + +#[test] +#[cfg(feature = "csv")] +fn test_lazy_partition_agg() { + let df = df! { + "foo" => &[1, 1, 2, 2, 3], + "bar" => &[1.0, 1.0, 2.0, 2.0, 3.0] + } + .unwrap(); + + let out = df + .lazy() + .group_by([col("foo")]) + .agg([col("bar").mean()]) + .sort(["foo"], Default::default()) + .collect() + .unwrap(); + + assert_eq!( + Vec::from(out.column("bar").unwrap().f64().unwrap()), + &[Some(1.0), Some(2.0), Some(3.0)] + ); + + let out = scan_foods_csv() + .group_by([col("category")]) + .agg([col("calories")]) + .sort(["category"], Default::default()) + .collect() + .unwrap(); + let cat_agg_list = out.select_at_idx(1).unwrap(); + let fruit_series = cat_agg_list.list().unwrap().get_as_series(0).unwrap(); + let fruit_list = fruit_series.i64().unwrap(); + assert_eq!( + Vec::from(fruit_list), + &[ + Some(60), + Some(30), + Some(50), + Some(30), + Some(60), + Some(130), + Some(50), + ] + ) +} + +#[test] +fn test_lazy_group_by_apply() { + let df = fruits_cars(); + + df.lazy() + .group_by([col("fruits")]) + .agg([col("cars").apply( + |s: Column| Ok(Some(Column::new("".into(), &[s.len() as u32]))), + GetOutput::from_type(DataType::UInt32), + )]) + .collect() + .unwrap(); +} + +#[test] +fn test_lazy_shift_and_fill() { + let df = df! { + "A" => &[1, 2, 3, 4, 5], + "B" => &[5, 4, 3, 2, 1] + } + .unwrap(); + let out = df + .clone() + .lazy() + .with_column(col("A").shift_and_fill(lit(2), col("B").mean())) + .collect() + .unwrap(); + assert_eq!(out.column("A").unwrap().null_count(), 0); + + // shift from the other side + let out = df + .clone() + .lazy() + .with_column(col("A").shift_and_fill(lit(-2), col("B").mean())) + .collect() + .unwrap(); + assert_eq!(out.column("A").unwrap().null_count(), 0); + + let out = df + .lazy() + .shift_and_fill(lit(-1), col("B").std(1)) + .collect() + .unwrap(); + assert_eq!(out.column("A").unwrap().null_count(), 0); +} + +#[test] +fn test_lazy_group_by() { + let df = df! { + "a" => &[Some(1.0), None, Some(3.0), Some(4.0), Some(5.0)], + "groups" => &["a", "a", "b", "c", "c"] + } + .unwrap(); + + let out = df + .lazy() + .group_by([col("groups")]) + .agg([col("a").mean()]) + .sort(["a"], Default::default()) + .collect() + .unwrap(); + + assert_eq!(out.column("a").unwrap().f64().unwrap().get(0), Some(1.0)); +} + +#[test] +fn test_lazy_tail() { + let df = df! { + "A" => &[1, 2, 3, 4, 5], + "B" => &[5, 4, 3, 2, 1] + } + .unwrap(); + + let _out = df.lazy().tail(3).collect().unwrap(); +} + +#[test] +fn test_lazy_group_by_sort() { + let df = df! { + "a" => ["a", "b", "a", "b", "b", "c"], + "b" => [1, 2, 3, 4, 5, 6] + } + .unwrap(); + + let out = df + .clone() + .lazy() + .group_by([col("a")]) + .agg([col("b").sort(Default::default()).first()]) + .collect() + .unwrap() + .sort(["a"], Default::default()) + .unwrap(); + + assert_eq!( + Vec::from(out.column("b").unwrap().i32().unwrap()), + [Some(1), Some(2), Some(6)] + ); + + let out = df + .lazy() + .group_by([col("a")]) + .agg([col("b").sort(Default::default()).last()]) + .collect() + .unwrap() + .sort(["a"], Default::default()) + .unwrap(); + + assert_eq!( + Vec::from(out.column("b").unwrap().i32().unwrap()), + [Some(3), Some(5), Some(6)] + ); +} + +#[test] +fn test_lazy_group_by_sort_by() { + let df = df! { + "a" => ["a", "a", "a", "b", "b", "c"], + "b" => [1, 2, 3, 4, 5, 6], + "c" => [6, 1, 4, 3, 2, 1] + } + .unwrap(); + + let out = df + .lazy() + .group_by([col("a")]) + .agg([col("b") + .sort_by( + [col("c")], + SortMultipleOptions::default().with_order_descending(true), + ) + .first()]) + .collect() + .unwrap() + .sort(["a"], Default::default()) + .unwrap(); + + assert_eq!( + Vec::from(out.column("b").unwrap().i32().unwrap()), + [Some(1), Some(4), Some(6)] + ); +} + +#[test] +#[cfg(feature = "dtype-datetime")] +fn test_lazy_group_by_cast() { + let df = df! { + "a" => ["a", "a", "a", "b", "b", "c"], + "b" => [1, 2, 3, 4, 5, 6] + } + .unwrap(); + + // test if it runs in group_by context + let _out = df + .lazy() + .group_by([col("a")]) + .agg([col("b") + .mean() + .cast(DataType::Datetime(TimeUnit::Nanoseconds, None))]) + .collect() + .unwrap(); +} + +#[test] +fn test_lazy_group_by_binary_expr() { + let df = df! { + "a" => ["a", "a", "a", "b", "b", "c"], + "b" => [1, 2, 3, 4, 5, 6] + } + .unwrap(); + + // test if it runs in group_by context + let out = df + .lazy() + .group_by([col("a")]) + .agg([col("b").mean() * lit(2)]) + .sort(["a"], Default::default()) + .collect() + .unwrap(); + assert_eq!( + Vec::from(out.column("b").unwrap().f64().unwrap()), + [Some(4.0), Some(9.0), Some(12.0)] + ); +} + +#[test] +fn test_lazy_group_by_filter() -> PolarsResult<()> { + let df = df! { + "a" => ["a", "a", "a", "b", "b", "c"], + "b" => [1, 2, 3, 4, 5, 6] + }?; + + // We test if the filters work in the group_by context + // and that the aggregations can deal with empty sets + + let out = df + .lazy() + .group_by([col("a")]) + .agg([ + col("b").filter(col("a").eq(lit("a"))).sum().alias("b_sum"), + col("b") + .filter(col("a").eq(lit("a"))) + .first() + .alias("b_first"), + col("b") + .filter(col("a").eq(lit("e"))) + .mean() + .alias("b_mean"), + col("b") + .filter(col("a").eq(lit("a"))) + .last() + .alias("b_last"), + ]) + .sort(["a"], SortMultipleOptions::default()) + .collect()?; + + assert_eq!( + Vec::from(out.column("b_sum").unwrap().i32().unwrap()), + [Some(6), Some(0), Some(0)] + ); + assert_eq!( + Vec::from(out.column("b_first").unwrap().i32().unwrap()), + [Some(1), None, None] + ); + assert_eq!( + Vec::from(out.column("b_mean").unwrap().f64().unwrap()), + [None, None, None] + ); + assert_eq!( + Vec::from(out.column("b_last").unwrap().i32().unwrap()), + [Some(3), None, None] + ); + + Ok(()) +} + +#[test] +fn test_group_by_projection_pd_same_column() -> PolarsResult<()> { + // this query failed when projection pushdown was enabled + + let a = || { + let df = df![ + "col1" => ["a", "ab", "abc"], + "col2" => [1, 2, 3] + ] + .unwrap(); + + df.lazy() + .select([col("col1").alias("foo"), col("col2").alias("bar")]) + }; + + let out = a() + .left_join(a(), col("foo"), col("foo")) + .select([col("bar")]) + .collect()?; + + let a = out.column("bar")?.i32()?; + assert_eq!(Vec::from(a), &[Some(1), Some(2), Some(3)]); + + Ok(()) +} + +#[test] +fn test_group_by_sort_slice() -> PolarsResult<()> { + let df = df![ + "groups" => [1, 2, 2, 3, 3, 3], + "vals" => [1, 5, 6, 3, 9, 8] + ]?; + // get largest two values per groups + + // expected: + // group values + // 1 1 + // 2 6, 5 + // 3 9, 8 + + let out1 = df + .clone() + .lazy() + .sort( + ["vals"], + SortMultipleOptions::default().with_order_descending(true), + ) + .group_by([col("groups")]) + .agg([col("vals").head(Some(2)).alias("foo")]) + .sort(["groups"], Default::default()) + .collect()?; + + let out2 = df + .lazy() + .group_by([col("groups")]) + .agg([col("vals") + .sort(SortOptions::default().with_order_descending(true)) + .head(Some(2)) + .alias("foo")]) + .sort(["groups"], Default::default()) + .collect()?; + + assert!(out1.column("foo")?.equals(out2.column("foo")?)); + Ok(()) +} + +#[test] +#[cfg(feature = "cum_agg")] +fn test_group_by_cum_sum() -> PolarsResult<()> { + let df = df![ + "groups" => [1, 2, 2, 3, 3, 3], + "vals" => [1, 5, 6, 3, 9, 8] + ]?; + + let out = df + .lazy() + .group_by([col("groups")]) + .agg([col("vals").cum_sum(false)]) + .sort(["groups"], Default::default()) + .collect()?; + + assert_eq!( + Vec::from(out.column("vals")?.explode()?.i32()?), + [1, 5, 11, 3, 12, 20] + .iter() + .copied() + .map(Some) + .collect::>() + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "range")] +fn test_arg_sort_multiple() -> PolarsResult<()> { + let df = df![ + "int" => [1, 2, 3, 1, 2], + "flt" => [3.0, 2.0, 1.0, 2.0, 1.0], + "str" => ["a", "a", "a", "b", "b"] + ]?; + + let out = df + .clone() + .lazy() + .select([arg_sort_by( + [col("int"), col("flt")], + SortMultipleOptions::default().with_order_descending_multi([true, false]), + )]) + .collect()?; + + assert_eq!( + Vec::from(out.column("int")?.idx()?), + [2, 4, 1, 3, 0] + .iter() + .copied() + .map(Some) + .collect::>() + ); + + // check if this runs + let _out = df + .lazy() + .select([arg_sort_by( + [col("str"), col("flt")], + SortMultipleOptions::default().with_order_descending_multi([true, false]), + )]) + .collect()?; + Ok(()) +} + +#[test] +fn test_multiple_explode() -> PolarsResult<()> { + let df = df![ + "a" => [0, 1, 2, 0, 2], + "b" => [5, 4, 3, 2, 1], + "c" => [2, 3, 4, 1, 5] + ]?; + + let out = df + .lazy() + .group_by([col("a")]) + .agg([col("b").alias("b_list"), col("c").alias("c_list")]) + .explode([col("c_list"), col("b_list")]) + .collect()?; + assert_eq!(out.shape(), (5, 3)); + + Ok(()) +} + +#[test] +fn test_filter_and_alias() -> PolarsResult<()> { + let df = df![ + "a" => [0, 1, 2, 0, 2] + ]?; + + let out = df + .lazy() + .with_column(col("a").pow(2.0).alias("a_squared")) + .filter(col("a_squared").gt(lit(1)).and(col("a").gt(lit(1)))) + .collect()?; + + let expected = df![ + "a" => [2, 2], + "a_squared" => [4.0, 4.0] + ]?; + println!("{:?}", out); + println!("{:?}", expected); + assert!(out.equals(&expected)); + Ok(()) +} + +#[test] +fn test_filter_lit() { + // see https://github.com/pola-rs/polars/issues/790 + // failed due to broadcasting filters and splitting threads. + let iter = (0..100).map(|i| ('A'..='Z').nth(i % 26).unwrap().to_string()); + let a = Series::from_iter(iter).into_column(); + let df = DataFrame::new([a].into()).unwrap(); + + let out = df.lazy().filter(lit(true)).collect().unwrap(); + assert_eq!(out.shape(), (100, 1)); +} + +#[test] +fn test_ternary_null() -> PolarsResult<()> { + let df = df![ + "a" => ["a", "b", "c"] + ]?; + + let out = df + .lazy() + .select([when(col("a").eq(lit("c"))) + .then(Null {}.lit()) + .otherwise(col("a")) + .alias("foo")]) + .collect()?; + + assert_eq!( + out.column("foo")?.is_null().into_iter().collect::>(), + &[Some(false), Some(false), Some(true)] + ); + Ok(()) +} + +#[test] +fn test_fill_forward() -> PolarsResult<()> { + let df = df![ + "a" => ["a", "b", "a"], + "b" => [Some(1), None, None] + ]?; + + let out = df + .lazy() + .select([col("b") + .fill_null_with_strategy(FillNullStrategy::Forward(FillNullLimit::None)) + .over_with_options([col("a")], None, WindowMapping::Join)]) + .collect()?; + let agg = out.column("b")?.list()?; + + let a: Series = agg.get_as_series(0).unwrap(); + assert!(a.equals(&Series::new("b".into(), &[1, 1]))); + let a: Series = agg.get_as_series(2).unwrap(); + assert!(a.equals(&Series::new("b".into(), &[1, 1]))); + let a: Series = agg.get_as_series(1).unwrap(); + assert_eq!(a.null_count(), 1); + Ok(()) +} + +#[cfg(feature = "cross_join")] +#[test] +fn test_cross_join() -> PolarsResult<()> { + let df1 = df![ + "a" => ["a", "b", "a"], + "b" => [Some(1), None, None] + ]?; + + let df2 = df![ + "a" => [1, 2], + "b" => [None, Some(12)] + ]?; + + let out = df1.lazy().cross_join(df2.lazy(), None).collect()?; + assert_eq!(out.shape(), (6, 4)); + Ok(()) +} + +#[test] +fn test_select_empty_df() -> PolarsResult<()> { + // https://github.com/pola-rs/polars/issues/1056 + let df1 = df![ + "a" => [1, 2, 3], + "b" => [1, 2, 3] + ]?; + + let out = df1 + .lazy() + .filter(col("a").eq(lit(0))) // this will lead to an empty frame + .select([col("a"), lit(1).alias("c")]) + .collect()?; + + assert_eq!(out.column("a")?.len(), 0); + assert_eq!(out.column("c")?.len(), 0); + + Ok(()) +} + +#[test] +fn test_keep_name() -> PolarsResult<()> { + let df = df![ + "a" => [1, 2, 3], + "b" => [1, 2, 3] + ]?; + + let out = df + .lazy() + .select([ + col("a").alias("bar").name().keep(), + col("b").alias("bar").name().keep(), + ]) + .collect()?; + + assert_eq!(out.get_column_names(), &["a", "b"]); + Ok(()) +} + +#[test] +fn test_exclude() -> PolarsResult<()> { + let df = df![ + "a" => [1, 2, 3], + "b" => [1, 2, 3], + "c" => [1, 2, 3] + ]?; + + let out = df.lazy().select([col("*").exclude(["b"])]).collect()?; + + assert_eq!(out.get_column_names(), &["a", "c"]); + Ok(()) +} + +#[test] +#[cfg(feature = "regex")] +fn test_regex_selection() -> PolarsResult<()> { + let df = df![ + "anton" => [1, 2, 3], + "arnold schwars" => [1, 2, 3], + "annie" => [1, 2, 3] + ]?; + + let out = df.lazy().select([col("^a.*o.*$")]).collect()?; + + assert_eq!(out.get_column_names(), &["anton", "arnold schwars"]); + Ok(()) +} + +#[test] +fn test_sort_by() -> PolarsResult<()> { + let df = df![ + "a" => [1, 2, 3, 4, 5], + "b" => [1, 1, 1, 2, 2], + "c" => [2, 3, 1, 2, 1] + ]?; + + // evaluate + let out = df + .clone() + .lazy() + .select([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) + .collect()?; + + let a = out.column("a")?; + assert_eq!( + Vec::from(a.i32().unwrap()), + &[Some(3), Some(1), Some(2), Some(5), Some(4)] + ); + + // aggregate + let out = df + .clone() + .lazy() + .group_by_stable([col("b")]) + .agg([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) + .collect()?; + let a = out.column("a")?.explode()?; + assert_eq!( + Vec::from(a.i32().unwrap()), + &[Some(3), Some(1), Some(2), Some(5), Some(4)] + ); + + // evaluate_on_groups + let out = df + .lazy() + .group_by_stable([col("b")]) + .agg([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) + .collect()?; + + let a = out.column("a")?.explode()?; + assert_eq!( + Vec::from(a.i32().unwrap()), + &[Some(3), Some(1), Some(2), Some(5), Some(4)] + ); + + Ok(()) +} + +#[test] +fn test_filter_after_shift_in_groups() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .select([ + col("fruits"), + col("B") + .shift(lit(1)) + .filter(col("B").shift(lit(1)).gt(lit(4))) + .over_with_options([col("fruits")], None, WindowMapping::Join) + .alias("filtered"), + ]) + .collect()?; + + assert_eq!( + out.column("filtered")? + .list()? + .get_as_series(0) + .unwrap() + .i32()? + .get(0) + .unwrap(), + 5 + ); + assert_eq!( + out.column("filtered")? + .list()? + .get_as_series(1) + .unwrap() + .i32()? + .get(0) + .unwrap(), + 5 + ); + assert_eq!( + out.column("filtered")? + .list()? + .get_as_series(2) + .unwrap() + .len(), + 0 + ); + + Ok(()) +} + +#[test] +fn test_lazy_ternary_predicate_pushdown() -> PolarsResult<()> { + let df = df![ + "a" => &[10, 1, 2, 3] + ]?; + + let out = df + .lazy() + .select([when(col("a").eq(lit(10))) + .then(Null {}.lit()) + .otherwise(col("a"))]) + .drop_nulls(None) + .collect()?; + + assert_eq!( + Vec::from(out.get_columns()[0].i32()?), + &[Some(1), Some(2), Some(3)] + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-categorical")] +fn test_categorical_addition() -> PolarsResult<()> { + let df = fruits_cars(); + + // test if we can do that arithmetic operation with String and Categorical + let out = df + .lazy() + .select([ + col("fruits").cast(DataType::Categorical(None, Default::default())), + col("cars").cast(DataType::Categorical(None, Default::default())), + ]) + .select([(col("fruits") + lit(" ") + col("cars")).alias("foo")]) + .collect()?; + + assert_eq!(out.column("foo")?.str()?.get(0).unwrap(), "banana beetle"); + + Ok(()) +} + +#[test] +fn test_error_duplicate_names() { + let df = fruits_cars(); + assert!(df.lazy().select([col("*"), col("*")]).collect().is_err()); +} + +#[test] +fn test_filter_count() -> PolarsResult<()> { + let df = fruits_cars(); + let out = df + .lazy() + .select([col("fruits") + .filter(col("fruits").eq(lit("banana"))) + .count()]) + .collect()?; + assert_eq!(out.column("fruits")?.idx()?.get(0), Some(3)); + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-i16")] +fn test_group_by_small_ints() -> PolarsResult<()> { + let df = df![ + "id_32" => [1i32, 2], + "id_16" => [1i16, 2] + ]?; + + // https://github.com/pola-rs/polars/issues/1255 + let out = df + .lazy() + .group_by([col("id_16"), col("id_32")]) + .agg([col("id_16").sum().alias("foo")]) + .sort( + ["foo"], + SortMultipleOptions::default().with_order_descending(true), + ) + .collect()?; + + assert_eq!(Vec::from(out.column("foo")?.i64()?), &[Some(2), Some(1)]); + Ok(()) +} + +#[test] +fn test_when_then_schema() -> PolarsResult<()> { + let df = fruits_cars(); + + let schema = df + .lazy() + .select([when(col("A").gt(lit(1))) + .then(Null {}.lit()) + .otherwise(col("A"))]) + .collect_schema(); + assert_ne!(schema?.get_at_index(0).unwrap().1, &DataType::Null); + + Ok(()) +} + +#[test] +fn test_singleton_broadcast() -> PolarsResult<()> { + let df = fruits_cars(); + let out = df + .lazy() + .select([col("fruits"), lit(1).alias("foo")]) + .collect()?; + + assert!(out.column("foo")?.len() > 1); + Ok(()) +} + +#[test] +fn test_list_in_select_context() -> PolarsResult<()> { + let s = Column::new("a".into(), &[1, 2, 3]); + let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone()); + builder.append_series(s.as_materialized_series()).unwrap(); + let expected = builder.finish().into_column(); + + let df = DataFrame::new(vec![s])?; + + let out = df.lazy().select([col("a").implode()]).collect()?; + + let s = out.column("a")?; + assert!(s.equals(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "round_series")] +fn test_round_after_agg() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .group_by([col("fruits")]) + .agg([col("A") + .cast(DataType::Float32) + .mean() + .round(2) + .alias("foo")]) + .collect()?; + + assert!(out.column("foo")?.f32().is_ok()); + + let df = df![ + "groups" => ["pigeon", + "rabbit", + "rabbit", + "Chris", + "pigeon", + "fast", + "fast", + "pigeon", + "rabbit", + "Chris"], + "b" => [5409, 4848, 4864, 3540, 8103, 3083, 8575, 9963, 8809, 5425], + "c" => [0.4517241160719615, + 0.2551467646274673, + 0.8682045191407308, + 0.9925316385786037, + 0.5392027792928116, + 0.7633847828107002, + 0.7967295231651537, + 0.01444779067224733, + 0.23807484087472652, + 0.10985868798350984] + ]?; + + let out = df + .lazy() + .group_by_stable([col("groups")]) + .agg([((col("b") * col("c")).sum() / col("b").sum()) + .round(2) + .alias("foo")]) + .collect()?; + + let out = out.column("foo")?; + let out = out.f64()?; + + assert_eq!( + Vec::from(out), + &[Some(0.3), Some(0.41), Some(0.46), Some(0.79)] + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-date")] +fn test_fill_nan() -> PolarsResult<()> { + let s0 = Column::new("date".into(), &[1, 2, 3]).cast(&DataType::Date)?; + let s1 = Column::new("float".into(), &[Some(1.0), Some(f32::NAN), Some(3.0)]); + + let df = DataFrame::new(vec![s0, s1])?; + let out = df.lazy().fill_nan(Null {}.lit()).collect()?; + let out = out.column("float")?; + assert_eq!(Vec::from(out.f32()?), &[Some(1.0), None, Some(3.0)]); + + Ok(()) +} + +#[test] +#[cfg(feature = "regex")] +fn test_exclude_regex() -> PolarsResult<()> { + let df = fruits_cars(); + let out = df + .lazy() + .select([col("*").exclude(["^(fruits|cars)$"])]) + .collect()?; + + assert_eq!(out.get_column_names(), &["A", "B"]); + Ok(()) +} + +#[test] +#[cfg(feature = "rank")] +fn test_group_by_rank() -> PolarsResult<()> { + let df = fruits_cars(); + let out = df + .lazy() + .group_by_stable([col("cars")]) + .agg([col("B").rank( + RankOptions { + method: RankMethod::Dense, + ..Default::default() + }, + None, + )]) + .collect()?; + + let out = out.column("B")?; + let out = out.list()?.get_as_series(1).unwrap(); + let out = out.idx()?; + + assert_eq!(Vec::from(out), &[Some(1)]); + Ok(()) +} + +#[test] +pub fn test_select_by_dtypes() -> PolarsResult<()> { + let df = df![ + "bools" => [true, false, true], + "ints" => [1, 2, 3], + "strings" => ["a", "b", "c"], + "floats" => [1.0, 2.0, 3.0f32] + ]?; + let out = df + .lazy() + .select([dtype_cols([DataType::Float32, DataType::String])]) + .collect()?; + assert_eq!(out.dtypes(), &[DataType::String, DataType::Float32]); + + Ok(()) +} + +#[test] +fn test_binary_expr() -> PolarsResult<()> { + // test panic in schema names + let df = fruits_cars(); + let _ = df.lazy().select([col("A").neq(lit(1))]).collect()?; + + // test type coercion + // https://github.com/pola-rs/polars/issues/1649 + let df = df!( + "nrs"=> [Some(1i64), Some(2), Some(3), None, Some(5)], + "random"=> [0.1f64, 0.6, 0.2, 0.6, 0.3] + )?; + + let other = when(col("random").gt(lit(0.5))) + .then(lit(2)) + .otherwise(col("random")) + .alias("other"); + let out = df.lazy().select([other * col("nrs").sum()]).collect()?; + assert_eq!(out.dtypes(), &[DataType::Float64]); + Ok(()) +} + +#[test] +fn test_single_group_result() -> PolarsResult<()> { + // the arg_sort should not auto explode + let df = df![ + "a" => [1, 2], + "b" => [1, 1] + ]?; + + let out = df + .lazy() + .select([col("a") + .arg_sort(SortOptions { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }) + .over([col("a")])]) + .collect()?; + + let a = out.column("a")?.idx()?; + assert_eq!(Vec::from(a), &[Some(0), Some(0)]); + + Ok(()) +} + +#[test] +#[cfg(feature = "rank")] +fn test_single_ranked_group() -> PolarsResult<()> { + // tests type consistency of rank algorithm + let df = df!["group" => [1, 2, 2], + "value"=> [100, 50, 10] + ]?; + + let out = df + .lazy() + .with_columns([col("value") + .rank( + RankOptions { + method: RankMethod::Average, + ..Default::default() + }, + None, + ) + .over_with_options([col("group")], None, WindowMapping::Join)]) + .collect()?; + + let out = out.column("value")?.explode()?; + let out = out.f64()?; + assert_eq!( + Vec::from(out), + &[Some(1.0), Some(2.0), Some(1.0), Some(2.0), Some(1.0)] + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "diff")] +fn empty_df() -> PolarsResult<()> { + let df = fruits_cars(); + let df = df.filter(&BooleanChunked::full("".into(), false, df.height()))?; + + df.lazy() + .select([ + col("A").shift(lit(1)).alias("1"), + col("A").shift_and_fill(lit(1), lit(1)).alias("2"), + col("A").shift_and_fill(lit(-1), lit(1)).alias("3"), + col("A").fill_null(lit(1)).alias("4"), + col("A").cum_count(false).alias("5"), + col("A").diff(lit(1), NullBehavior::Ignore).alias("6"), + col("A").cum_max(false).alias("7"), + col("A").cum_min(false).alias("8"), + ]) + .collect()?; + + Ok(()) +} + +#[test] +#[cfg(feature = "abs")] +fn test_apply_flatten() -> PolarsResult<()> { + let df = df![ + "A"=> [1.1435, 2.223456, 3.44732, -1.5234, -2.1238, -3.2923], + "B"=> ["a", "b", "a", "b", "a", "b"] + ]?; + + let out = df + .lazy() + .group_by_stable([col("B")]) + .agg([col("A").abs().sum().alias("A_sum")]) + .collect()?; + + let out = out.column("A_sum")?; + assert_eq!(out.get(0)?, AnyValue::Float64(6.71462)); + assert_eq!(out.get(1)?, AnyValue::Float64(7.039156)); + + Ok(()) +} + +#[test] +#[cfg(feature = "is_in")] +fn test_is_in() -> PolarsResult<()> { + let df = fruits_cars(); + + // // this will be executed by apply + let out = df + .clone() + .lazy() + .group_by_stable([col("fruits")]) + .agg([col("cars").is_in( + col("cars").filter(col("cars").eq(lit("beetle"))).implode(), + false, + )]) + .collect()?; + let out = out.column("cars").unwrap(); + let out = out.explode()?; + let out = out.bool().unwrap(); + assert_eq!( + Vec::from(out), + &[Some(true), Some(false), Some(true), Some(true), Some(true)] + ); + + // this will be executed by map + let out = df + .lazy() + .group_by_stable([col("fruits")]) + .agg([col("cars").is_in( + lit(Series::new("a".into(), ["beetle", "vw"])).implode(), + false, + )]) + .collect()?; + + let out = out.column("cars").unwrap(); + let out = out.explode()?; + let out = out.bool().unwrap(); + assert_eq!( + Vec::from(out), + &[Some(true), Some(false), Some(true), Some(true), Some(true)] + ); + + Ok(()) +} + +#[test] +fn test_partitioned_gb_1() -> PolarsResult<()> { + // don't move these to integration tests + // keep these dtypes + let out = df![ + "keys" => [1, 1, 1, 1, 2], + "vals" => ["a", "b", "c", "a", "a"] + ]? + .lazy() + .group_by([col("keys")]) + .agg([ + (col("vals").eq(lit("a"))).sum().alias("eq_a"), + (col("vals").eq(lit("b"))).sum().alias("eq_b"), + ]) + .sort(["keys"], Default::default()) + .collect()?; + + assert!(out.equals(&df![ + "keys" => [1, 2], + "eq_a" => [2 as IdxSize, 1], + "eq_b" => [1 as IdxSize, 0], + ]?)); + + Ok(()) +} + +#[test] +fn test_partitioned_gb_count() -> PolarsResult<()> { + // don't move these to integration tests + let out = df![ + "col" => (0..100).map(|_| Some(0)).collect::().into_series(), + ]? + .lazy() + .group_by([col("col")]) + .agg([ + // we make sure to alias with a different name + len().alias("counted"), + col("col").count().alias("count2"), + ]) + .collect()?; + + assert!(out.equals(&df![ + "col" => [0], + "counted" => [100 as IdxSize], + "count2" => [100 as IdxSize], + ]?)); + + Ok(()) +} + +#[test] +fn test_partitioned_gb_mean() -> PolarsResult<()> { + // don't move these to integration tests + let out = df![ + "key" => (0..100).map(|_| Some(0)).collect::().into_series(), + ]? + .lazy() + .with_columns([lit("a").alias("str"), lit(1).alias("int")]) + .group_by([col("key")]) + .agg([ + col("str").mean().alias("mean_str"), + col("int").mean().alias("mean_int"), + ]) + .collect()?; + + assert_eq!(out.shape(), (1, 3)); + let str_col = out.column("mean_str")?; + assert_eq!(str_col.get(0)?, AnyValue::Null); + let int_col = out.column("mean_int")?; + assert_eq!(int_col.get(0)?, AnyValue::Float64(1.0)); + + Ok(()) +} + +#[test] +fn test_partitioned_gb_binary() -> PolarsResult<()> { + // don't move these to integration tests + let df = df![ + "col" => (0..20).map(|_| Some(0)).collect::().into_series(), + ]?; + + let out = df + .clone() + .lazy() + .group_by([col("col")]) + .agg([(col("col") + lit(10)).sum().alias("sum")]) + .collect()?; + + assert!(out.equals(&df![ + "col" => [0], + "sum" => [200], + ]?)); + + let out = df + .lazy() + .group_by([col("col")]) + .agg([(col("col").cast(DataType::Float32) + lit(10.0)) + .sum() + .alias("sum")]) + .collect()?; + + assert!(out.equals(&df![ + "col" => [0], + "sum" => [200.0_f32], + ]?)); + + Ok(()) +} + +#[test] +fn test_partitioned_gb_ternary() -> PolarsResult<()> { + // don't move these to integration tests + let df = df![ + "col" => (0..20).map(|_| Some(0)).collect::().into_series(), + "val" => (0..20).map(Some).collect::().into_series(), + ]?; + + let out = df + .lazy() + .group_by([col("col")]) + .agg([when(col("val").gt(lit(10))) + .then(lit(1)) + .otherwise(lit(0)) + .sum() + .alias("sum")]) + .collect()?; + + assert!(out.equals(&df![ + "col" => [0], + "sum" => [9], + ]?)); + + Ok(()) +} + +#[test] +fn test_sort_maintain_order_true() -> PolarsResult<()> { + let q = df![ + "A" => [1, 1, 1, 1], + "B" => ["A", "B", "C", "D"], + ]? + .lazy(); + + let res = q + .sort_by_exprs( + [col("A")], + SortMultipleOptions::default() + .with_maintain_order(true) + .with_nulls_last(true), + ) + .slice(0, 3) + .collect()?; + println!("{:?}", res); + assert!(res.equals(&df![ + "A" => [1, 1, 1], + "B" => ["A", "B", "C"], + ]?)); + Ok(()) +} + +#[test] +fn test_over_with_options_empty_join() -> PolarsResult<()> { + let empty_df = DataFrame::new(vec![ + Series::new_empty("a".into(), &DataType::Int32).into(), + Series::new_empty("b".into(), &DataType::Int32).into(), + ])?; + + let empty_df_out = empty_df + .lazy() + .select([col("b").over_with_options([col("a")], Option::None, WindowMapping::Join)]) + .collect()?; + + let f1: Field = Field::new("b".into(), DataType::List(Box::new(DataType::Int32))); + let sc: Schema = Schema::from_iter(vec![f1]); + + assert_eq!(&**empty_df_out.schema(), &sc); + + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/schema.rs b/crates/polars-lazy/src/tests/schema.rs new file mode 100644 index 000000000000..e4166d33c94a --- /dev/null +++ b/crates/polars-lazy/src/tests/schema.rs @@ -0,0 +1,42 @@ +use super::*; + +#[test] +fn test_schema_update_after_projection_pd() -> PolarsResult<()> { + let df = df![ + "a" => [1], + "b" => [1], + "c" => [1], + ]?; + + let q = df + .lazy() + .with_column(col("a").implode()) + .explode([col("a")]) + .select([cols(["a", "b"])]); + + // run optimizations + // Get the explode node + let IRPlan { + lp_top, + lp_arena, + expr_arena: _, + } = q.to_alp_optimized()?; + + // assert the schema has been corrected with the projection pushdown run + let lp = lp_arena.get(lp_top); + assert!(matches!( + lp, + IR::MapFunction { + function: FunctionIR::Explode { .. }, + .. + } + )); + + let schema = lp.schema(&lp_arena).into_owned(); + let mut expected = Schema::default(); + expected.with_column("a".into(), DataType::Int32); + expected.with_column("b".into(), DataType::Int32); + assert_eq!(schema.as_ref(), &expected); + + Ok(()) +} diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs new file mode 100644 index 000000000000..46b7732a0c31 --- /dev/null +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -0,0 +1,419 @@ +use polars_ops::frame::JoinCoalesce; + +use super::*; + +fn get_csv_file() -> LazyFrame { + let file = "../../examples/datasets/foods1.csv"; + LazyCsvReader::new(file).finish().unwrap() +} + +fn get_parquet_file() -> LazyFrame { + let file = "../../examples/datasets/foods1.parquet"; + LazyFrame::scan_parquet(file, Default::default()).unwrap() +} + +fn get_csv_glob() -> LazyFrame { + let file = "../../examples/datasets/foods*.csv"; + LazyCsvReader::new(file).finish().unwrap() +} + +fn assert_streaming_with_default(q: LazyFrame, check_shape_only: bool) { + let q_streaming = q.clone().with_new_streaming(true); + let q_expected = q.with_new_streaming(false); + let out = q_streaming.collect().unwrap(); + let expected = q_expected.collect().unwrap(); + if check_shape_only { + assert_eq!(out.shape(), expected.shape()) + } else { + assert_eq!(out, expected); + } +} + +#[test] +fn test_streaming_parquet() -> PolarsResult<()> { + let q = get_parquet_file(); + + let q = q + .group_by([col("sugars_g")]) + .agg([((lit(1) - col("fats_g")) + col("calories")).sum()]) + .sort(["sugars_g"], Default::default()); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_streaming_csv() -> PolarsResult<()> { + let q = get_csv_file(); + + let q = q + .select([col("sugars_g"), col("calories")]) + .group_by([col("sugars_g")]) + .agg([col("calories").sum()]) + .sort(["sugars_g"], Default::default()); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_streaming_glob() -> PolarsResult<()> { + assert_streaming_with_default(get_csv_glob(), false); + Ok(()) +} + +#[test] +fn test_streaming_union_order() -> PolarsResult<()> { + let q = get_csv_glob(); + let q = concat([q.clone(), q], Default::default())?; + let q = q.select([col("sugars_g"), col("calories")]); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +#[cfg(feature = "cross_join")] +fn test_streaming_union_join() -> PolarsResult<()> { + let q = get_csv_glob(); + let q = q.select([col("sugars_g"), col("calories")]); + let q = q.clone().cross_join(q, None); + + assert_streaming_with_default(q, true); + Ok(()) +} + +#[test] +fn test_streaming_multiple_keys_aggregate() -> PolarsResult<()> { + let q = get_csv_glob(); + + let q = q + .filter(col("sugars_g").gt(lit(10))) + .group_by([col("sugars_g"), col("calories")]) + .agg([ + (col("fats_g") * lit(10)).sum(), + col("calories").mean().alias("cal_mean"), + ]) + .sort_by_exprs( + [col("sugars_g"), col("calories")], + SortMultipleOptions::default().with_order_descending_multi([false, false]), + ); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_streaming_first_sum() -> PolarsResult<()> { + let q = get_csv_file(); + + let q = q + .select([col("sugars_g"), col("calories")]) + .group_by([col("sugars_g")]) + .agg([ + col("calories").sum(), + col("calories").first().alias("calories_first"), + ]) + .sort(["sugars_g"], Default::default()); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_streaming_unique() -> PolarsResult<()> { + let q = get_csv_file(); + + let q = q + .select([col("sugars_g"), col("calories")]) + .unique(None, Default::default()) + .sort_by_exprs( + [cols(["sugars_g", "calories"])], + SortMultipleOptions::default(), + ); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_streaming_aggregate_slice() -> PolarsResult<()> { + let q = get_parquet_file(); + + let q = q + .group_by([col("sugars_g")]) + .agg([((lit(1) - col("fats_g")) + col("calories")).sum()]) + .slice(3, 3); + + let out_streaming = q.collect_with_engine(Engine::Streaming)?; + assert_eq!(out_streaming.shape(), (3, 2)); + Ok(()) +} + +#[test] +#[cfg(feature = "cross_join")] +fn test_streaming_cross_join() -> PolarsResult<()> { + let df = df![ + "a" => [1 ,2, 3] + ]?; + let q = df.lazy(); + let out = q + .clone() + .cross_join(q, None) + .collect_with_engine(Engine::Streaming)?; + assert_eq!(out.shape(), (9, 2)); + + let q = get_parquet_file().with_projection_pushdown(false); + let q1 = q + .clone() + .select([col("calories")]) + .cross_join(q.clone(), None) + .filter(col("calories").gt(col("calories_right"))); + let q2 = q1 + .select([all().name().suffix("_second")]) + .cross_join(q, None) + .filter(col("calories_right_second").lt(col("calories"))) + .select([ + col("calories"), + col("calories_right_second").alias("calories_right"), + ]); + + let out_streaming = q2.collect_with_engine(Engine::Streaming)?; + + assert_eq!( + out_streaming.get_column_names(), + &["calories", "calories_right"] + ); + assert_eq!(out_streaming.shape(), (5753, 2)); + Ok(()) +} + +#[test] +fn test_streaming_inner_join3() -> PolarsResult<()> { + let lf_left = df![ + "col1" => [1, 1, 1], + "col2" => ["a", "a", "b"], + "int_col" => [1, 2, 3] + ]? + .lazy(); + + let lf_right = df![ + "col1" => [1, 1, 1, 1, 1, 2], + "col2" => ["a", "a", "a", "a", "a", "c"], + "floats" => [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + ]? + .lazy(); + + let q = lf_left.inner_join(lf_right, col("col1"), col("col1")); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_streaming_inner_join2() -> PolarsResult<()> { + let lf_left = df![ + "a"=> [0, 0, 0, 3, 0, 1, 3, 3, 3, 1, 4, 4, 2, 1, 1, 3, 1, 4, 2, 2], + "b"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + ]? + .lazy(); + + let lf_right = df![ + "a"=> [10, 18, 13, 9, 1, 13, 14, 12, 15, 11], + "b"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ]? + .lazy(); + + let q = lf_left.inner_join(lf_right, col("a"), col("a")); + + assert_streaming_with_default(q, false); + Ok(()) +} +#[test] +fn test_streaming_left_join() -> PolarsResult<()> { + let lf_left = df![ + "a"=> [0, 0, 0, 3, 0, 1, 3, 3, 3, 1, 4, 4, 2, 1, 1, 3, 1, 4, 2, 2], + "b"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + ]? + .lazy(); + + let lf_right = df![ + "a"=> [10, 18, 13, 9, 1, 13, 14, 12, 15, 11], + "b"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ]? + .lazy(); + + let q = lf_left.left_join(lf_right, col("a"), col("a")); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +#[cfg(feature = "cross_join")] +fn test_streaming_slice() -> PolarsResult<()> { + let vals = (0..100).collect::>(); + let s = Series::new("".into(), vals); + let lf_a = df![ + "a" => s + ]? + .lazy(); + + let q = lf_a.clone().cross_join(lf_a, None).slice(10, 20); + let a = q.with_streaming(true).collect().unwrap(); + assert_eq!(a.shape(), (20, 2)); + + Ok(()) +} + +#[test] +fn test_streaming_partial() -> PolarsResult<()> { + let lf_left = df![ + "a"=> [0], + "b"=> [0], + ]? + .lazy(); + + let lf_right = df![ + "a"=> [0], + "b"=> [0], + ]? + .lazy(); + + let q = lf_left.clone().left_join(lf_right, col("a"), col("a")); + + // we add a join that is not supported streaming (for now) + // so we can test if the partial query is executed without panics + let q = q + .join_builder() + .with(lf_left.clone()) + .left_on([col("a")]) + .right_on([col("a")]) + .suffix("_foo") + .how(JoinType::Full) + .coalesce(JoinCoalesce::CoalesceColumns) + .finish(); + + let q = q.left_join( + lf_left.select([all().name().suffix("_foo")]), + col("a"), + col("a_foo"), + ); + assert_streaming_with_default(q, true); + + Ok(()) +} + +#[test] +fn test_streaming_aggregate_join() -> PolarsResult<()> { + let q = get_parquet_file(); + + let q = q + .group_by([col("sugars_g")]) + .agg([((lit(1) - col("fats_g")) + col("calories")).sum()]) + .slice(0, 3); + + let q = q.clone().left_join(q, col("sugars_g"), col("sugars_g")); + let q1 = q.with_new_streaming(true); + let out_streaming = q1.collect()?; + assert_eq!(out_streaming.shape(), (3, 3)); + Ok(()) +} + +#[test] +fn test_streaming_double_left_join() -> PolarsResult<()> { + // A left join swaps the tables, so that checks the swapping of the branches + let q1 = df![ + "id" => ["1a"], + "p_id" => ["1b"], + "m_id" => ["1c"], + ]? + .lazy(); + + let q2 = df![ + "p_id2" => ["2a"], + "p_code" => ["2b"], + ]? + .lazy(); + + let q3 = df![ + "m_id3" => ["3a"], + "m_code" => ["3b"], + ]? + .lazy(); + + let q = q1 + .clone() + .left_join(q2.clone(), col("p_id"), col("p_id2")) + .left_join(q3.clone(), col("m_id"), col("m_id3")); + + assert_streaming_with_default(q.clone(), false); + + // more joins + let q = q.left_join(q1.clone(), col("p_id"), col("p_id")).left_join( + q3.clone(), + col("m_id"), + col("m_id3"), + ); + + assert_streaming_with_default(q, false); + // empty tables + let q = q1 + .slice(0, 0) + .left_join(q2.slice(0, 0), col("p_id"), col("p_id2")) + .left_join(q3.slice(0, 0), col("m_id"), col("m_id3")); + + assert_streaming_with_default(q, false); + Ok(()) +} + +#[test] +fn test_sort_maintain_order_streaming() -> PolarsResult<()> { + let q = df![ + "A" => [1, 1, 1, 1], + "B" => ["A", "B", "C", "D"], + ]? + .lazy(); + + let res = q + .sort_by_exprs( + [col("A")], + SortMultipleOptions::default() + .with_nulls_last(true) + .with_maintain_order(true), + ) + .slice(0, 3) + .with_streaming(true) + .collect()?; + assert!(res.equals(&df![ + "A" => [1, 1, 1], + "B" => ["A", "B", "C"], + ]?)); + Ok(()) +} + +#[test] +fn test_streaming_full_outer_join() -> PolarsResult<()> { + let lf_left = df![ + "a"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + "b"=> [0, 0, 0, 3, 0, 1, 3, 3, 3, 1, 4, 4, 2, 1, 1, 3, 1, 4, 2, 2], + ]? + .lazy(); + + let lf_right = df![ + "a"=> [10, 18, 13, 9, 1, 13, 14, 12, 15, 11], + "b"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ]? + .lazy(); + + let q = lf_left + .full_join(lf_right, col("a"), col("a")) + .sort_by_exprs([all()], SortMultipleOptions::default()); + + // Toggle so that the join order is swapped. + for toggle in [true, true] { + assert_streaming_with_default(q.clone().with_streaming(toggle), false); + } + + Ok(()) +} diff --git a/crates/polars-mem-engine/Cargo.toml b/crates/polars-mem-engine/Cargo.toml new file mode 100644 index 000000000000..562918bf0763 --- /dev/null +++ b/crates/polars-mem-engine/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "polars-mem-engine" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +description = "In memory engine of the Polars project." + +[dependencies] +arrow = { workspace = true } +futures = { workspace = true, optional = true } +memmap = { workspace = true } +polars-core = { workspace = true, features = ["lazy"] } +polars-error = { workspace = true } +polars-expr = { workspace = true } +polars-io = { workspace = true, features = ["lazy"] } +polars-json = { workspace = true, optional = true } +polars-ops = { workspace = true, features = ["chunked_ids"] } +polars-plan = { workspace = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } +pyo3 = { workspace = true, optional = true } +rayon = { workspace = true } +recursive = { workspace = true } +tokio = { workspace = true, optional = true } + +[features] +async = [ + "polars-plan/async", + "polars-io/cloud", +] +python = ["pyo3", "polars-plan/python", "polars-core/python", "polars-io/python"] +ipc = ["polars-io/ipc", "polars-plan/ipc"] +json = ["polars-io/json", "polars-plan/json", "polars-json"] +csv = ["polars-io/csv", "polars-plan/csv"] +cloud = ["async", "polars-plan/cloud", "tokio", "futures"] +parquet = ["polars-io/parquet", "polars-plan/parquet"] +dtype-categorical = ["polars-plan/dtype-categorical"] +dtype-date = ["polars-plan/dtype-date", "polars-time/dtype-date"] +dtype-datetime = ["polars-plan/dtype-datetime", "polars-time/dtype-datetime"] +dtype-decimal = ["polars-plan/dtype-decimal"] +dtype-duration = ["polars-plan/dtype-duration", "polars-time/dtype-duration"] +dtype-i16 = ["polars-plan/dtype-i16"] +dtype-i8 = ["polars-plan/dtype-i8"] +dtype-struct = ["polars-plan/dtype-struct", "polars-ops/dtype-struct"] +dtype-time = ["polars-plan/dtype-time", "polars-time/dtype-time"] +dtype-u16 = ["polars-plan/dtype-u16"] +dtype-u8 = ["polars-plan/dtype-u8"] +object = ["polars-core/object"] +dynamic_group_by = ["polars-plan/dynamic_group_by", "polars-time", "polars-expr/dynamic_group_by"] +asof_join = ["polars-plan/asof_join", "polars-time", "polars-ops/asof_join"] +merge_sorted = ["polars-plan/merge_sorted", "polars-ops/merge_sorted"] diff --git a/crates/polars-mem-engine/LICENSE b/crates/polars-mem-engine/LICENSE new file mode 100644 index 000000000000..cc6d55aa7571 --- /dev/null +++ b/crates/polars-mem-engine/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2025 Ritchie Vink +Some portions Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/polars-mem-engine/README.md b/crates/polars-mem-engine/README.md new file mode 100644 index 000000000000..a07c36a16dfd --- /dev/null +++ b/crates/polars-mem-engine/README.md @@ -0,0 +1,7 @@ +# polars-mem-engine + +`polars-mem-engine` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) +library. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-mem-engine/src/executors/cache.rs b/crates/polars-mem-engine/src/executors/cache.rs new file mode 100644 index 000000000000..ad4220cc686f --- /dev/null +++ b/crates/polars-mem-engine/src/executors/cache.rs @@ -0,0 +1,64 @@ +use std::sync::atomic::Ordering; + +use super::*; + +pub struct CacheExec { + pub input: Option>, + pub id: usize, + pub count: u32, +} + +impl Executor for CacheExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + match &mut self.input { + // Cache node + None => { + if state.verbose() { + eprintln!("CACHE HIT: cache id: {:x}", self.id); + } + let cache = state.get_df_cache(self.id, self.count); + let out = cache.1.get().expect("prefilled").clone(); + let previous = cache.0.fetch_sub(1, Ordering::Relaxed); + if previous == 0 { + state.remove_df_cache(self.id); + } + + Ok(out) + }, + // Cache Prefill node + Some(input) => { + if state.verbose() { + eprintln!("CACHE SET: cache id: {:x}", self.id); + } + let df = input.execute(state)?; + let cache = state.get_df_cache(self.id, self.count); + cache.1.set(df).expect("should be empty"); + Ok(DataFrame::empty()) + }, + } + } +} + +pub struct CachePrefiller { + pub caches: PlIndexMap>, + pub phys_plan: Box, +} + +impl Executor for CachePrefiller { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + if state.verbose() { + eprintln!("PREFILL CACHES") + } + // Ensure we traverse in discovery order. This will ensure that caches aren't dependent on each + // other. + for cache in self.caches.values_mut() { + let mut state = state.split(); + state.branch_idx += 1; + let _df = cache.execute(&mut state)?; + } + if state.verbose() { + eprintln!("EXECUTE PHYS PLAN") + } + self.phys_plan.execute(state) + } +} diff --git a/crates/polars-mem-engine/src/executors/executor.rs b/crates/polars-mem-engine/src/executors/executor.rs new file mode 100644 index 000000000000..03806362778d --- /dev/null +++ b/crates/polars-mem-engine/src/executors/executor.rs @@ -0,0 +1,57 @@ +use super::*; + +// Executor are the executors of the physical plan and produce DataFrames. They +// combine physical expressions, which produce Series. + +/// Executors will evaluate physical expressions and collect them in a DataFrame. +/// +/// Executors have other executors as input. By having a tree of executors we can execute the +/// physical plan until the last executor is evaluated. +pub trait Executor: Send + Sync { + fn execute(&mut self, cache: &mut ExecutionState) -> PolarsResult; +} + +type SinkFn = + Box PolarsResult> + Send + Sync>; +pub struct SinkExecutor { + pub name: String, + pub input: Box, + pub f: SinkFn, +} + +impl Executor for SinkExecutor { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run sink_{}", self.name) + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + Cow::Owned(format!(".sink_{}()", &self.name)) + } else { + Cow::Borrowed("") + }; + + state.clone().record( + || (self.f)(df, state).map(|df| df.unwrap_or_else(DataFrame::empty)), + profile_name, + ) + } +} + +pub struct Dummy {} +impl Executor for Dummy { + fn execute(&mut self, _cache: &mut ExecutionState) -> PolarsResult { + panic!("should not get here"); + } +} + +impl Default for Box { + fn default() -> Self { + Box::new(Dummy {}) + } +} diff --git a/crates/polars-mem-engine/src/executors/ext_context.rs b/crates/polars-mem-engine/src/executors/ext_context.rs new file mode 100644 index 000000000000..89fee49cc10b --- /dev/null +++ b/crates/polars-mem-engine/src/executors/ext_context.rs @@ -0,0 +1,27 @@ +use super::*; + +pub struct ExternalContext { + pub input: Box, + pub contexts: Vec>, +} + +impl Executor for ExternalContext { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run ExternalContext") + } + } + // we evaluate contexts first as input may has pushed exprs. + let contexts = self + .contexts + .iter_mut() + .map(|e| e.execute(state)) + .collect::>>()?; + state.ext_contexts = Arc::new(contexts); + let df = self.input.execute(state)?; + + Ok(df) + } +} diff --git a/crates/polars-mem-engine/src/executors/filter.rs b/crates/polars-mem-engine/src/executors/filter.rs new file mode 100644 index 000000000000..a47e9b6f5ed9 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/filter.rs @@ -0,0 +1,121 @@ +use polars_core::utils::accumulate_dataframes_vertical_unchecked; + +use super::*; + +pub struct FilterExec { + pub(crate) predicate: Arc, + pub(crate) input: Box, + // if the predicate contains a window function + has_window: bool, + streamable: bool, +} + +fn column_to_mask(c: &Column) -> PolarsResult<&BooleanChunked> { + c.bool().map_err(|_| { + polars_err!( + ComputeError: "filter predicate must be of type `Boolean`, got `{}`", c.dtype() + ) + }) +} + +impl FilterExec { + pub fn new( + predicate: Arc, + input: Box, + has_window: bool, + streamable: bool, + ) -> Self { + Self { + predicate, + input, + has_window, + streamable, + } + } + + fn execute_hor( + &mut self, + df: DataFrame, + state: &mut ExecutionState, + ) -> PolarsResult { + if self.has_window { + state.insert_has_window_function_flag() + } + let c = self.predicate.evaluate(&df, state)?; + if self.has_window { + state.clear_window_expr_cache() + } + + // @scalar-opt + // @partition-opt + df.filter(column_to_mask(&c)?) + } + + fn execute_chunks( + &mut self, + chunks: Vec, + state: &ExecutionState, + ) -> PolarsResult { + let iter = chunks.into_par_iter().map(|df| { + let c = self.predicate.evaluate(&df, state)?; + + // @scalar-opt + // @partition-opt + df.filter(column_to_mask(&c)?) + }); + let df = POOL.install(|| iter.collect::>>())?; + Ok(accumulate_dataframes_vertical_unchecked(df)) + } + + fn execute_impl( + &mut self, + mut df: DataFrame, + state: &mut ExecutionState, + ) -> PolarsResult { + let n_partitions = POOL.current_num_threads(); + // Vertical parallelism. + if self.streamable && df.height() > 0 { + if df.first_col_n_chunks() > 1 { + let chunks = df.split_chunks().collect::>(); + self.execute_chunks(chunks, state) + } else if df.width() < n_partitions { + self.execute_hor(df, state) + } else { + let chunks = df.split_chunks_by_n(n_partitions, true); + self.execute_chunks(chunks, state) + } + } else { + self.execute_hor(df, state) + } + } +} + +impl Executor for FilterExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run FilterExec") + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + Cow::Owned(format!(".filter({})", &self.predicate.as_ref())) + } else { + Cow::Borrowed("") + }; + + state.clone().record( + || { + let df = self.execute_impl(df, state); + if state.verbose() { + eprintln!("dataframe filtered"); + } + df + }, + profile_name, + ) + } +} diff --git a/crates/polars-mem-engine/src/executors/group_by.rs b/crates/polars-mem-engine/src/executors/group_by.rs new file mode 100644 index 000000000000..fec04ea5f3b9 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/group_by.rs @@ -0,0 +1,148 @@ +use rayon::prelude::*; + +use super::*; + +pub(super) fn evaluate_aggs( + df: &DataFrame, + aggs: &[Arc], + groups: &GroupPositions, + state: &ExecutionState, +) -> PolarsResult> { + POOL.install(|| { + aggs.par_iter() + .map(|expr| { + let agg = expr.evaluate_on_groups(df, groups, state)?.finalize(); + polars_ensure!(agg.len() == groups.len(), agg_len = agg.len(), groups.len()); + Ok(agg) + }) + .collect::>>() + }) +} + +/// Take an input Executor and a multiple expressions +pub struct GroupByExec { + input: Box, + keys: Vec>, + aggs: Vec>, + apply: Option>, + maintain_order: bool, + input_schema: SchemaRef, + slice: Option<(i64, usize)>, +} + +impl GroupByExec { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + input: Box, + keys: Vec>, + aggs: Vec>, + apply: Option>, + maintain_order: bool, + input_schema: SchemaRef, + slice: Option<(i64, usize)>, + ) -> Self { + Self { + input, + keys, + aggs, + apply, + maintain_order, + input_schema, + slice, + } + } +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn group_by_helper( + mut df: DataFrame, + keys: Vec, + aggs: &[Arc], + apply: Option>, + state: &ExecutionState, + maintain_order: bool, + slice: Option<(i64, usize)>, +) -> PolarsResult { + df.as_single_chunk_par(); + let gb = df.group_by_with_series(keys, true, maintain_order)?; + + if let Some(f) = apply { + return gb.sliced(slice).apply(move |df| f.call_udf(df)); + } + + let mut groups = gb.get_groups(); + + #[allow(unused_assignments)] + // it is unused because we only use it to keep the lifetime of sliced_group valid + let mut sliced_groups = None; + + if let Some((offset, len)) = slice { + sliced_groups = Some(groups.slice(offset, len)); + groups = sliced_groups.as_ref().unwrap(); + } + + let (mut columns, agg_columns) = POOL.install(|| { + let get_columns = || gb.keys_sliced(slice); + + let get_agg = || evaluate_aggs(&df, aggs, groups, state); + + rayon::join(get_columns, get_agg) + }); + + columns.extend(agg_columns?); + DataFrame::new(columns) +} + +impl GroupByExec { + fn execute_impl(&mut self, state: &ExecutionState, df: DataFrame) -> PolarsResult { + let keys = self + .keys + .iter() + .map(|e| e.evaluate(&df, state)) + .collect::>()?; + group_by_helper( + df, + keys, + &self.aggs, + self.apply.take(), + state, + self.maintain_order, + self.slice, + ) + } +} + +impl Executor for GroupByExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run GroupbyExec") + } + } + if state.verbose() { + eprintln!("keys/aggregates are not partitionable: running default HASH AGGREGATION") + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + let by = self + .keys + .iter() + .map(|s| Ok(s.to_field(&self.input_schema)?.name)) + .collect::>>()?; + let name = comma_delimited("group_by".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, df), profile_name) + } else { + self.execute_impl(state, df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/group_by_dynamic.rs b/crates/polars-mem-engine/src/executors/group_by_dynamic.rs new file mode 100644 index 000000000000..569152aef0d0 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/group_by_dynamic.rs @@ -0,0 +1,123 @@ +use super::*; + +#[cfg_attr(not(feature = "dynamic_group_by"), allow(dead_code))] +pub(crate) struct GroupByDynamicExec { + pub(crate) input: Box, + // we will use this later + #[allow(dead_code)] + pub(crate) keys: Vec>, + pub(crate) aggs: Vec>, + #[cfg(feature = "dynamic_group_by")] + pub(crate) options: DynamicGroupOptions, + pub(crate) input_schema: SchemaRef, + pub(crate) slice: Option<(i64, usize)>, + pub(crate) apply: Option>, +} + +impl GroupByDynamicExec { + #[cfg(feature = "dynamic_group_by")] + fn execute_impl( + &mut self, + state: &ExecutionState, + mut df: DataFrame, + ) -> PolarsResult { + use crate::executors::group_by_rolling::sort_and_groups; + + df.as_single_chunk_par(); + + let mut keys = self + .keys + .iter() + .map(|e| e.evaluate(&df, state)) + .collect::>>()?; + + let group_by = if !self.keys.is_empty() { + Some(sort_and_groups(&mut df, &mut keys)?) + } else { + None + }; + + let (mut time_key, bounds, groups) = df.group_by_dynamic(group_by, &self.options)?; + POOL.install(|| { + keys.iter_mut().for_each(|key| { + unsafe { *key = key.agg_first(&groups) }; + }) + }); + keys.extend(bounds); + + if let Some(f) = &self.apply { + let gb = GroupBy::new(&df, vec![], groups, None); + let out = gb.apply(move |df| f.call_udf(df))?; + return Ok(if let Some((offset, len)) = self.slice { + out.slice(offset, len) + } else { + out + }); + } + + let mut groups = &groups; + #[allow(unused_assignments)] + // it is unused because we only use it to keep the lifetime of sliced_group valid + let mut sliced_groups = None; + + if let Some((offset, len)) = self.slice { + sliced_groups = Some(groups.slice(offset, len)); + groups = sliced_groups.as_ref().unwrap(); + + time_key = time_key.slice(offset, len); + + // todo! optimize this, we can prevent an agg_first aggregation upstream + // the ordering has changed due to the group_by + for key in keys.iter_mut() { + *key = key.slice(offset, len) + } + } + + let agg_columns = evaluate_aggs(&df, &self.aggs, groups, state)?; + + let mut columns = Vec::with_capacity(agg_columns.len() + 1 + keys.len()); + columns.extend_from_slice(&keys); + columns.push(time_key); + columns.extend(agg_columns); + + DataFrame::new(columns) + } +} + +impl Executor for GroupByDynamicExec { + #[cfg(not(feature = "dynamic_group_by"))] + fn execute(&mut self, _state: &mut ExecutionState) -> PolarsResult { + panic!("activate feature dynamic_group_by") + } + + #[cfg(feature = "dynamic_group_by")] + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run GroupbyDynamicExec") + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + let by = self + .keys + .iter() + .map(|s| Ok(s.to_field(&self.input_schema)?.name)) + .collect::>>()?; + let name = comma_delimited("group_by_dynamic".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, df), profile_name) + } else { + self.execute_impl(state, df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs new file mode 100644 index 000000000000..8dbde323d291 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs @@ -0,0 +1,381 @@ +use polars_core::series::IsSorted; +use polars_core::utils::{accumulate_dataframes_vertical, split_df}; +use rayon::prelude::*; + +use super::*; + +/// Take an input Executor and a multiple expressions +pub struct PartitionGroupByExec { + input: Box, + phys_keys: Vec>, + phys_aggs: Vec>, + maintain_order: bool, + slice: Option<(i64, usize)>, + input_schema: SchemaRef, + output_schema: SchemaRef, + from_partitioned_ds: bool, + #[allow(dead_code)] + keys: Vec, + #[allow(dead_code)] + aggs: Vec, +} + +impl PartitionGroupByExec { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + input: Box, + phys_keys: Vec>, + phys_aggs: Vec>, + maintain_order: bool, + slice: Option<(i64, usize)>, + input_schema: SchemaRef, + output_schema: SchemaRef, + from_partitioned_ds: bool, + keys: Vec, + aggs: Vec, + ) -> Self { + Self { + input, + phys_keys, + phys_aggs, + maintain_order, + slice, + input_schema, + output_schema, + from_partitioned_ds, + keys, + aggs, + } + } + + fn keys(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult> { + compute_keys(&self.phys_keys, df, state) + } +} + +fn compute_keys( + keys: &[Arc], + df: &DataFrame, + state: &ExecutionState, +) -> PolarsResult> { + let evaluated = keys + .iter() + .map(|s| s.evaluate(df, state)) + .collect::>()?; + let df = check_expand_literals(df, keys, evaluated, false, Default::default())?; + Ok(df.take_columns()) +} + +fn run_partitions( + df: &mut DataFrame, + exec: &PartitionGroupByExec, + state: &ExecutionState, + n_threads: usize, + maintain_order: bool, +) -> PolarsResult<(Vec, Vec>)> { + // We do a partitioned group_by. + // Meaning that we first do the group_by operation arbitrarily + // split on several threads. Than the final result we apply the same group_by again. + let dfs = split_df(df, n_threads, true); + + let phys_aggs = &exec.phys_aggs; + let keys = &exec.phys_keys; + + let mut keys = DataFrame::from_iter(compute_keys(keys, df, state)?); + let splitted_keys = split_df(&mut keys, n_threads, true); + + POOL.install(|| { + dfs.into_par_iter() + .zip(splitted_keys) + .map(|(df, keys)| { + let gb = df.group_by_with_series(keys.into(), false, maintain_order)?; + let groups = gb.get_groups(); + + let mut columns = gb.keys(); + // don't naively call par_iter here, it will segfault in rayon + // if you do, throw it on the POOL threadpool. + let agg_columns = phys_aggs + .iter() + .map(|expr| { + let agg_expr = expr.as_partitioned_aggregator().unwrap(); + let agg = agg_expr.evaluate_partitioned(&df, groups, state)?; + Ok(if agg.len() != groups.len() { + polars_ensure!(agg.len() == 1, agg_len = agg.len(), groups.len()); + match groups.len() { + 0 => agg.clear(), + len => agg.new_from_index(0, len), + } + } else { + agg + } + .into_column()) + }) + .collect::>>()?; + + columns.extend_from_slice(&agg_columns); + + let df = DataFrame::new(columns)?; + Ok((df, gb.keys())) + }) + .collect() + }) +} + +fn estimate_unique_count(keys: &[Column], mut sample_size: usize) -> PolarsResult { + // https://stats.stackexchange.com/a/19090/147321 + // estimated unique size + // u + ui / m (s - m) + // s: set_size + // m: sample_size + // u: total unique groups counted in sample + // ui: groups with single unique value counted in sample + let set_size = keys[0].len(); + if set_size < sample_size { + sample_size = set_size; + } + + let finish = |groups: &GroupsType| { + let u = groups.len() as f64; + let ui = if groups.len() == sample_size { + u + } else { + groups.iter().filter(|g| g.len() == 1).count() as f64 + }; + + (u + (ui / sample_size as f64) * (set_size - sample_size) as f64) as usize + }; + + if keys.len() == 1 { + // we sample as that will work also with sorted data. + // not that sampling without replacement is *very* expensive. don't do that. + let s = keys[0].sample_n(sample_size, true, false, None).unwrap(); + // fast multi-threaded way to get unique. + let groups = s.as_materialized_series().group_tuples(true, false)?; + Ok(finish(&groups)) + } else { + let offset = (keys[0].len() / 2) as i64; + let df = unsafe { DataFrame::new_no_checks_height_from_first(keys.to_vec()) }; + let df = df.slice(offset, sample_size); + let names = df.get_column_names().into_iter().cloned(); + let gb = df.group_by(names).unwrap(); + Ok(finish(gb.get_groups())) + } +} + +// Lower this at debug builds so that we hit this in the test suite. +#[cfg(debug_assertions)] +const PARTITION_LIMIT: usize = 15; +#[cfg(not(debug_assertions))] +const PARTITION_LIMIT: usize = 1000; + +// Checks if we should run normal or default aggregation +// by sampling data. +fn can_run_partitioned( + keys: &[Column], + original_df: &DataFrame, + state: &ExecutionState, + from_partitioned_ds: bool, +) -> PolarsResult { + if !keys + .iter() + .take(1) + .all(|s| matches!(s.is_sorted_flag(), IsSorted::Not)) + { + if state.verbose() { + eprintln!("FOUND SORTED KEY: running default HASH AGGREGATION") + } + Ok(false) + } else if std::env::var("POLARS_NO_PARTITION").is_ok() { + if state.verbose() { + eprintln!("POLARS_NO_PARTITION set: running default HASH AGGREGATION") + } + Ok(false) + } else if std::env::var("POLARS_FORCE_PARTITION").is_ok() { + if state.verbose() { + eprintln!("POLARS_FORCE_PARTITION set: running partitioned HASH AGGREGATION") + } + Ok(true) + } else if original_df.height() < PARTITION_LIMIT && !cfg!(test) { + if state.verbose() { + eprintln!("DATAFRAME < {PARTITION_LIMIT} rows: running default HASH AGGREGATION") + } + Ok(false) + } else { + // below this boundary we assume the partitioned group_by will be faster + let unique_count_boundary = std::env::var("POLARS_PARTITION_UNIQUE_COUNT") + .map(|s| s.parse::().unwrap()) + .unwrap_or(1000); + + let (unique_estimate, sampled_method) = match (keys.len(), keys[0].dtype()) { + #[cfg(feature = "dtype-categorical")] + (1, DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _)) => { + (rev_map.len(), "known") + }, + _ => { + // sqrt(N) is a good sample size as it remains low on large numbers + // it is better than taking a fraction as it saturates + let sample_size = (original_df.height() as f64).powf(0.5) as usize; + + // we never sample less than 100 data points. + let sample_size = std::cmp::max(100, sample_size); + (estimate_unique_count(keys, sample_size)?, "estimated") + }, + }; + if state.verbose() { + eprintln!("{sampled_method} unique values: {unique_estimate}"); + } + + if from_partitioned_ds { + let estimated_cardinality = unique_estimate as f32 / original_df.height() as f32; + if estimated_cardinality < 0.4 { + if state.verbose() { + eprintln!("PARTITIONED DS"); + } + Ok(true) + } else { + if state.verbose() { + eprintln!( + "PARTITIONED DS: estimated cardinality: {estimated_cardinality} exceeded the boundary: 0.4, running default HASH AGGREGATION" + ); + } + Ok(false) + } + } else if unique_estimate > unique_count_boundary { + if state.verbose() { + eprintln!( + "estimated unique count: {unique_estimate} exceeded the boundary: {unique_count_boundary}, running default HASH AGGREGATION" + ) + } + Ok(false) + } else { + Ok(true) + } + } +} + +impl PartitionGroupByExec { + fn execute_impl( + &mut self, + state: &mut ExecutionState, + mut original_df: DataFrame, + ) -> PolarsResult { + let (splitted_dfs, splitted_keys) = { + // already get the keys. This is the very last minute decision which group_by method we choose. + // If the column is a categorical, we know the number of groups we have and can decide to continue + // partitioned or go for the standard group_by. The partitioned is likely to be faster on a small number + // of groups. + let keys = self.keys(&original_df, state)?; + + if !can_run_partitioned(&keys, &original_df, state, self.from_partitioned_ds)? { + return group_by_helper( + original_df, + keys, + &self.phys_aggs, + None, + state, + self.maintain_order, + self.slice, + ); + } + + if state.verbose() { + eprintln!("run PARTITIONED HASH AGGREGATION") + } + + // Run the partitioned aggregations + let n_threads = POOL.current_num_threads(); + + run_partitions( + &mut original_df, + self, + state, + n_threads, + self.maintain_order, + )? + }; + + // MERGE phase + + let df = accumulate_dataframes_vertical(splitted_dfs)?; + let keys = splitted_keys + .into_iter() + .reduce(|mut acc, e| { + acc.iter_mut().zip(e).for_each(|(acc, e)| { + let _ = acc.append(&e); + }); + acc + }) + .unwrap(); + + // the partitioned group_by has added columns so we must update the schema. + state.set_schema(self.output_schema.clone()); + + // merge and hash aggregate again + + // first get mutable access and optionally sort + let gb = df.group_by_with_series(keys, true, self.maintain_order)?; + let mut groups = gb.get_groups(); + + #[allow(unused_assignments)] + // it is unused because we only use it to keep the lifetime of sliced_group valid + let mut sliced_groups = None; + + if let Some((offset, len)) = self.slice { + sliced_groups = Some(groups.slice(offset, len)); + groups = sliced_groups.as_ref().unwrap(); + } + + let get_columns = || gb.keys_sliced(self.slice); + let get_agg = || { + let out: PolarsResult> = self + .phys_aggs + .par_iter() + // we slice the keys off and finalize every aggregation + .zip(&df.get_columns()[self.phys_keys.len()..]) + .map(|(expr, partitioned_s)| { + let agg_expr = expr.as_partitioned_aggregator().unwrap(); + agg_expr.finalize(partitioned_s.clone(), groups, state) + }) + .collect(); + + out + }; + let (mut columns, agg_columns): (Vec<_>, _) = POOL.join(get_columns, get_agg); + + columns.extend(agg_columns?); + state.clear_schema_cache(); + + Ok(DataFrame::new(columns).unwrap()) + } +} + +impl Executor for PartitionGroupByExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run PartitionGroupbyExec") + } + } + let original_df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + let by = self + .phys_keys + .iter() + .map(|s| Ok(s.to_field(&self.input_schema)?.name)) + .collect::>>()?; + let name = comma_delimited("group_by_partitioned".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, original_df), profile_name) + } else { + self.execute_impl(state, original_df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/group_by_rolling.rs b/crates/polars-mem-engine/src/executors/group_by_rolling.rs new file mode 100644 index 000000000000..6d45f0b2fd33 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/group_by_rolling.rs @@ -0,0 +1,155 @@ +use polars_utils::unique_column_name; + +use super::*; + +#[cfg_attr(not(feature = "dynamic_group_by"), allow(dead_code))] +pub(crate) struct GroupByRollingExec { + pub(crate) input: Box, + pub(crate) keys: Vec>, + pub(crate) aggs: Vec>, + #[cfg(feature = "dynamic_group_by")] + pub(crate) options: RollingGroupOptions, + pub(crate) input_schema: SchemaRef, + pub(crate) slice: Option<(i64, usize)>, + pub(crate) apply: Option>, +} + +pub(super) fn sort_and_groups( + df: &mut DataFrame, + keys: &mut Vec, +) -> PolarsResult> { + let encoded = row_encode::encode_rows_vertical_par_unordered(keys)?; + let encoded = encoded.rechunk().into_owned(); + let encoded = encoded.with_name(unique_column_name()); + let idx = encoded.arg_sort(SortOptions { + maintain_order: true, + ..Default::default() + }); + + let encoded = unsafe { + df.with_column_unchecked(encoded.into_series().into()); + + // If not sorted on keys, sort. + let idx_s = idx.clone().into_series(); + if !idx_s.is_sorted(Default::default()).unwrap() { + let (df_ordered, keys_ordered) = POOL.join( + || df.take_unchecked(&idx), + || { + keys.iter() + .map(|c| c.take_unchecked(&idx)) + .collect::>() + }, + ); + *df = df_ordered; + *keys = keys_ordered; + } + + df.get_columns_mut().pop().unwrap() + }; + let encoded = encoded.as_series().unwrap(); + let encoded = encoded.binary_offset().unwrap(); + let encoded = encoded.with_sorted_flag(polars_core::series::IsSorted::Ascending); + let groups = encoded.group_tuples(true, false).unwrap(); + + let GroupsType::Slice { groups, .. } = groups else { + // memory would explode + unreachable!(); + }; + Ok(groups) +} + +impl GroupByRollingExec { + #[cfg(feature = "dynamic_group_by")] + fn execute_impl( + &mut self, + state: &ExecutionState, + mut df: DataFrame, + ) -> PolarsResult { + df.as_single_chunk_par(); + + let mut keys = self + .keys + .iter() + .map(|e| e.evaluate(&df, state)) + .collect::>>()?; + + let group_by = if !self.keys.is_empty() { + Some(sort_and_groups(&mut df, &mut keys)?) + } else { + None + }; + + let (mut time_key, groups) = df.rolling(group_by, &self.options)?; + + if let Some(f) = &self.apply { + let gb = GroupBy::new(&df, vec![], groups, None); + let out = gb.apply(move |df| f.call_udf(df))?; + return Ok(if let Some((offset, len)) = self.slice { + out.slice(offset, len) + } else { + out + }); + } + + let mut groups = &groups; + #[allow(unused_assignments)] + // it is unused because we only use it to keep the lifetime of sliced_group valid + let mut sliced_groups = None; + + if let Some((offset, len)) = self.slice { + sliced_groups = Some(groups.slice(offset, len)); + groups = sliced_groups.as_ref().unwrap(); + + time_key = time_key.slice(offset, len); + for k in &mut keys { + *k = k.slice(offset, len); + } + } + + let agg_columns = evaluate_aggs(&df, &self.aggs, groups, state)?; + + let mut columns = Vec::with_capacity(agg_columns.len() + 1 + keys.len()); + columns.extend_from_slice(&keys); + columns.push(time_key); + columns.extend(agg_columns); + + DataFrame::new(columns) + } +} + +impl Executor for GroupByRollingExec { + #[cfg(not(feature = "dynamic_group_by"))] + fn execute(&mut self, _state: &mut ExecutionState) -> PolarsResult { + panic!("activate feature dynamic_group_by") + } + + #[cfg(feature = "dynamic_group_by")] + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run GroupbyRollingExec") + } + } + let df = self.input.execute(state)?; + let profile_name = if state.has_node_timer() { + let by = self + .keys + .iter() + .map(|s| Ok(s.to_field(&self.input_schema)?.name)) + .collect::>>()?; + let name = comma_delimited("group_by_rolling".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, df), profile_name) + } else { + self.execute_impl(state, df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/hconcat.rs b/crates/polars-mem-engine/src/executors/hconcat.rs new file mode 100644 index 000000000000..2d9c543dd3c5 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/hconcat.rs @@ -0,0 +1,64 @@ +use polars_core::functions::concat_df_horizontal; + +use super::*; + +pub(crate) struct HConcatExec { + pub(crate) inputs: Vec>, + pub(crate) options: HConcatOptions, +} + +impl Executor for HConcatExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run HConcatExec") + } + } + let mut inputs = std::mem::take(&mut self.inputs); + + let dfs = if !self.options.parallel { + if state.verbose() { + eprintln!("HCONCAT: `parallel=false` hconcat is run sequentially") + } + let mut dfs = Vec::with_capacity(inputs.len()); + for (idx, mut input) in inputs.into_iter().enumerate() { + let mut state = state.split(); + state.branch_idx += idx; + + let df = input.execute(&mut state)?; + + dfs.push(df); + } + dfs + } else { + if state.verbose() { + eprintln!("HCONCAT: hconcat is run in parallel") + } + // We don't use par_iter directly because the LP may also start threads for every LP (for instance scan_csv) + // this might then lead to a rayon SO. So we take a multitude of the threads to keep work stealing + // within bounds + let out = POOL.install(|| { + inputs + .chunks_mut(POOL.current_num_threads() * 3) + .map(|chunk| { + chunk + .into_par_iter() + .enumerate() + .map(|(idx, input)| { + let mut input = std::mem::take(input); + let mut state = state.split(); + state.branch_idx += idx; + input.execute(&mut state) + }) + .collect::>>() + }) + .collect::>>() + }); + out?.into_iter().flatten().collect() + }; + + // Invariant of IR. Schema is already checked to contain no duplicates. + concat_df_horizontal(&dfs, false) + } +} diff --git a/crates/polars-mem-engine/src/executors/join.rs b/crates/polars-mem-engine/src/executors/join.rs new file mode 100644 index 000000000000..f305a548dc56 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/join.rs @@ -0,0 +1,160 @@ +use polars_ops::frame::DataFrameJoinOps; + +use super::*; + +pub struct JoinExec { + input_left: Option>, + input_right: Option>, + left_on: Vec>, + right_on: Vec>, + parallel: bool, + args: JoinArgs, + options: Option, +} + +impl JoinExec { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + input_left: Box, + input_right: Box, + left_on: Vec>, + right_on: Vec>, + parallel: bool, + args: JoinArgs, + options: Option, + ) -> Self { + JoinExec { + input_left: Some(input_left), + input_right: Some(input_right), + left_on, + right_on, + parallel, + args, + options, + } + } +} + +impl Executor for JoinExec { + fn execute<'a>(&'a mut self, state: &'a mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run JoinExec") + } + } + if state.verbose() { + eprintln!("join parallel: {}", self.parallel); + }; + let mut input_left = self.input_left.take().unwrap(); + let mut input_right = self.input_right.take().unwrap(); + + let (df_left, df_right) = if self.parallel { + let mut state_right = state.split(); + let mut state_left = state.split(); + state_right.branch_idx += 1; + // propagate the fetch_rows static value to the spawning threads. + let fetch_rows = FETCH_ROWS.with(|fetch_rows| fetch_rows.get()); + + POOL.join( + move || { + FETCH_ROWS.with(|fr| fr.set(fetch_rows)); + input_left.execute(&mut state_left) + }, + move || { + FETCH_ROWS.with(|fr| fr.set(fetch_rows)); + input_right.execute(&mut state_right) + }, + ) + } else { + (input_left.execute(state), input_right.execute(state)) + }; + + let df_left = df_left?; + let df_right = df_right?; + + let profile_name = if state.has_node_timer() { + let by = self + .left_on + .iter() + .map(|s| Ok(s.to_field(df_left.schema())?.name)) + .collect::>>()?; + let name = comma_delimited("join".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + state.record(|| { + + let left_on_series = self + .left_on + .iter() + .map(|e| e.evaluate(&df_left, state)) + .collect::>>()?; + + let right_on_series = self + .right_on + .iter() + .map(|e| e.evaluate(&df_right, state)) + .collect::>>()?; + + // prepare the tolerance + // we must ensure that we use the right units + #[cfg(feature = "asof_join")] + { + if let JoinType::AsOf(options) = &mut self.args.how { + use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY; + if let Some(tol) = &options.tolerance_str { + let duration = polars_time::Duration::try_parse(tol)?; + polars_ensure!( + duration.months() == 0, + ComputeError: "cannot use month offset in timedelta of an asof join; \ + consider using 4 weeks" + ); + let left_asof = df_left.column(left_on_series[0].name())?; + use DataType::*; + match left_asof.dtype() { + Datetime(tu, _) | Duration(tu) => { + let tolerance = match tu { + TimeUnit::Nanoseconds => duration.duration_ns(), + TimeUnit::Microseconds => duration.duration_us(), + TimeUnit::Milliseconds => duration.duration_ms(), + }; + options.tolerance = Some(AnyValue::from(tolerance)) + } + Date => { + let days = (duration.duration_ms() / MILLISECONDS_IN_DAY) as i32; + options.tolerance = Some(AnyValue::from(days)) + } + Time => { + let tolerance = duration.duration_ns(); + options.tolerance = Some(AnyValue::from(tolerance)) + } + _ => { + panic!("can only use timedelta string language with Date/Datetime/Duration/Time dtypes") + } + } + } + } + } + + let df = df_left._join_impl( + &df_right, + left_on_series.into_iter().map(|c| c.take_materialized_series()).collect(), + right_on_series.into_iter().map(|c| c.take_materialized_series()).collect(), + self.args.clone(), + self.options.clone(), + true, + state.verbose(), + ); + + if state.verbose() { + eprintln!("{:?} join dataframes finished", self.args.how); + }; + df + + }, profile_name) + } +} diff --git a/crates/polars-mem-engine/src/executors/merge_sorted.rs b/crates/polars-mem-engine/src/executors/merge_sorted.rs new file mode 100644 index 000000000000..9d3a2d16a469 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/merge_sorted.rs @@ -0,0 +1,47 @@ +use polars_ops::prelude::*; + +use super::*; + +pub(crate) struct MergeSorted { + pub(crate) input_left: Box, + pub(crate) input_right: Box, + pub(crate) key: PlSmallStr, +} + +impl Executor for MergeSorted { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run MergeSorted") + } + } + let (left, right) = { + let mut state2 = state.split(); + state2.branch_idx += 1; + let (left, right) = POOL.join( + || self.input_left.execute(state), + || self.input_right.execute(&mut state2), + ); + (left?, right?) + }; + + let profile_name = Cow::Borrowed("Merge Sorted"); + state.record( + || { + let lhs = left.column(self.key.as_str())?; + let rhs = right.column(self.key.as_str())?; + + _merge_sorted_dfs( + &left, + &right, + lhs.as_materialized_series(), + rhs.as_materialized_series(), + true, + ) + }, + profile_name, + ) + } +} diff --git a/crates/polars-mem-engine/src/executors/mod.rs b/crates/polars-mem-engine/src/executors/mod.rs new file mode 100644 index 000000000000..6199f249e4ab --- /dev/null +++ b/crates/polars-mem-engine/src/executors/mod.rs @@ -0,0 +1,55 @@ +mod cache; +mod executor; +mod ext_context; +mod filter; +mod group_by; +mod group_by_dynamic; +mod group_by_partitioned; +pub(super) mod group_by_rolling; +mod hconcat; +mod join; +#[cfg(feature = "merge_sorted")] +mod merge_sorted; +mod projection; +mod projection_simple; +mod projection_utils; +mod scan; +mod slice; +mod sort; +mod stack; +mod udf; +mod union; +mod unique; + +use std::borrow::Cow; + +pub use executor::*; +use polars_core::POOL; +use polars_plan::global::FETCH_ROWS; +use polars_plan::utils::*; +use projection_utils::*; +use rayon::prelude::*; + +pub(super) use self::cache::*; +pub(super) use self::ext_context::*; +pub(super) use self::filter::*; +pub(super) use self::group_by::*; +#[cfg(feature = "dynamic_group_by")] +pub(super) use self::group_by_dynamic::*; +pub(super) use self::group_by_partitioned::*; +#[cfg(feature = "dynamic_group_by")] +pub(super) use self::group_by_rolling::GroupByRollingExec; +pub(super) use self::hconcat::*; +pub(super) use self::join::*; +#[cfg(feature = "merge_sorted")] +pub(super) use self::merge_sorted::*; +pub(super) use self::projection::*; +pub(super) use self::projection_simple::*; +pub(super) use self::scan::*; +pub(super) use self::slice::*; +pub(super) use self::sort::*; +pub(super) use self::stack::*; +pub(super) use self::udf::*; +pub(super) use self::union::*; +pub(super) use self::unique::*; +use crate::prelude::*; diff --git a/crates/polars-mem-engine/src/executors/projection.rs b/crates/polars-mem-engine/src/executors/projection.rs new file mode 100644 index 000000000000..bbab42740c5e --- /dev/null +++ b/crates/polars-mem-engine/src/executors/projection.rs @@ -0,0 +1,103 @@ +use polars_core::utils::accumulate_dataframes_vertical_unchecked; + +use super::*; + +/// Take an input Executor (creates the input DataFrame) +/// and a multiple PhysicalExpressions (create the output Series) +pub struct ProjectionExec { + pub(crate) input: Box, + pub(crate) expr: Vec>, + pub(crate) has_windows: bool, + pub(crate) input_schema: SchemaRef, + #[cfg(test)] + pub(crate) schema: SchemaRef, + pub(crate) options: ProjectionOptions, + // Can run all operations elementwise + pub(crate) allow_vertical_parallelism: bool, +} + +impl ProjectionExec { + fn execute_impl( + &mut self, + state: &ExecutionState, + mut df: DataFrame, + ) -> PolarsResult { + // Vertical and horizontal parallelism. + let df = if self.allow_vertical_parallelism + && df.first_col_n_chunks() > 1 + && df.height() > POOL.current_num_threads() * 2 + && self.options.run_parallel + { + let chunks = df.split_chunks().collect::>(); + let iter = chunks.into_par_iter().map(|mut df| { + let selected_cols = evaluate_physical_expressions( + &mut df, + &self.expr, + state, + self.has_windows, + self.options.run_parallel, + )?; + check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options) + }); + + let df = POOL.install(|| iter.collect::>>())?; + accumulate_dataframes_vertical_unchecked(df) + } + // Only horizontal parallelism. + else { + #[allow(clippy::let_and_return)] + let selected_cols = evaluate_physical_expressions( + &mut df, + &self.expr, + state, + self.has_windows, + self.options.run_parallel, + )?; + check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)? + }; + + // this only runs during testing and check if the runtime type matches the predicted schema + #[cfg(test)] + #[allow(unused_must_use)] + { + // TODO: also check the types. + for (l, r) in df.iter().zip(self.schema.iter_names()) { + assert_eq!(l.name(), r); + } + } + + Ok(df) + } +} + +impl Executor for ProjectionExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run ProjectionExec"); + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + let by = self + .expr + .iter() + .map(|s| profile_name(s.as_ref(), self.input_schema.as_ref())) + .collect::>>()?; + let name = comma_delimited("select".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, df), profile_name) + } else { + self.execute_impl(state, df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/projection_simple.rs b/crates/polars-mem-engine/src/executors/projection_simple.rs new file mode 100644 index 000000000000..c3102d3b7222 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/projection_simple.rs @@ -0,0 +1,34 @@ +use super::*; + +pub struct ProjectionSimple { + pub(crate) input: Box, + pub(crate) columns: SchemaRef, +} + +impl ProjectionSimple { + fn execute_impl(&mut self, df: DataFrame, columns: &[PlSmallStr]) -> PolarsResult { + // No duplicate check as that an invariant of this node. + df._select_impl_unchecked(columns.as_ref()) + } +} + +impl Executor for ProjectionSimple { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + let columns = self.columns.iter_names_cloned().collect::>(); + + let profile_name = if state.has_node_timer() { + let name = comma_delimited("simple-projection".to_string(), columns.as_slice()); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + let df = self.input.execute(state)?; + + if state.has_node_timer() { + state.record(|| self.execute_impl(df, columns.as_slice()), profile_name) + } else { + self.execute_impl(df, columns.as_slice()) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs new file mode 100644 index 000000000000..0eac04eaf7f4 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -0,0 +1,387 @@ +use polars_plan::constants::CSE_REPLACED; +use polars_utils::itertools::Itertools; + +use super::*; + +pub(super) fn profile_name( + s: &dyn PhysicalExpr, + input_schema: &Schema, +) -> PolarsResult { + match s.to_field(input_schema) { + Err(e) => Err(e), + Ok(fld) => Ok(fld.name), + } +} + +type IdAndExpression = (u32, Arc); + +#[cfg(feature = "dynamic_group_by")] +fn rolling_evaluate( + df: &DataFrame, + state: &ExecutionState, + rolling: PlHashMap<&RollingGroupOptions, Vec>, +) -> PolarsResult>> { + POOL.install(|| { + rolling + .par_iter() + .map(|(options, partition)| { + // clear the cache for every partitioned group + let state = state.split(); + + let (_time_key, groups) = df.rolling(None, options)?; + + let groups_key = format!("{:?}", options); + // Set the groups so all expressions in partition can use it. + // Create a separate scope, so the lock is dropped, otherwise we deadlock when the + // rolling expression try to get read access. + state.window_cache.insert_groups(groups_key, groups); + partition + .par_iter() + .map(|(idx, expr)| expr.evaluate(df, &state).map(|s| (*idx, s))) + .collect::>>() + }) + .collect() + }) +} + +fn window_evaluate( + df: &DataFrame, + state: &ExecutionState, + window: PlHashMap>, +) -> PolarsResult>> { + if window.is_empty() { + return Ok(vec![]); + } + let n_threads = POOL.current_num_threads(); + + let max_hor = window.values().map(|v| v.len()).max().unwrap_or(0); + let vert = window.len(); + + // We don't want to cache and parallel horizontally and vertically as that keeps many cache + // states alive. + let (cache, par_vertical, par_horizontal) = if max_hor >= n_threads || max_hor >= vert { + (true, false, true) + } else { + (false, true, true) + }; + + let apply = |partition: &[(u32, Arc)]| { + // clear the cache for every partitioned group + let mut state = state.split(); + // inform the expression it has window functions. + state.insert_has_window_function_flag(); + + // caching more than one window expression is a complicated topic for another day + // see issue #2523 + let cache = cache + && partition.len() > 1 + && partition.iter().all(|(_, e)| { + e.as_expression() + .unwrap() + .into_iter() + .filter(|e| matches!(e, Expr::Window { .. })) + .count() + == 1 + }); + let mut first_result = None; + // First run 1 to fill the cache. Condvars and such don't work as + // rayon threads should not be blocked. + if cache { + let first = &partition[0]; + let c = first.1.evaluate(df, &state)?; + first_result = Some((first.0, c)); + state.insert_cache_window_flag(); + } else { + state.remove_cache_window_flag(); + } + + let apply = + |index: &u32, e: &Arc| e.evaluate(df, &state).map(|c| (*index, c)); + + let slice = &partition[first_result.is_some() as usize..]; + let mut results = if par_horizontal { + slice + .par_iter() + .map(|(index, e)| apply(index, e)) + .collect::>>()? + } else { + slice + .iter() + .map(|(index, e)| apply(index, e)) + .collect::>>()? + }; + + if let Some(item) = first_result { + results.push(item) + } + + Ok(results) + }; + + if par_vertical { + POOL.install(|| window.par_iter().map(|t| apply(t.1)).collect()) + } else { + window.iter().map(|t| apply(t.1)).collect() + } +} + +fn execute_projection_cached_window_fns( + df: &DataFrame, + exprs: &[Arc], + state: &ExecutionState, +) -> PolarsResult> { + // We partition by normal expression and window expression + // - the normal expressions can run in parallel + // - the window expression take more memory and often use the same group_by keys and join tuples + // so they are cached and run sequential + + // the partitioning messes with column order, so we also store the idx + // and use those to restore the original projection order + #[allow(clippy::type_complexity)] + // String: partition_name, + // u32: index, + let mut windows: PlHashMap> = PlHashMap::default(); + #[cfg(feature = "dynamic_group_by")] + let mut rolling: PlHashMap<&RollingGroupOptions, Vec> = PlHashMap::default(); + let mut other = Vec::with_capacity(exprs.len()); + + // first we partition the window function by the values they group over. + // the group_by values should be cached + exprs.iter().enumerate_u32().for_each(|(index, phys)| { + let mut is_window = false; + if let Some(e) = phys.as_expression() { + for e in e.into_iter() { + if let Expr::Window { + partition_by, + options, + order_by, + .. + } = e + { + let entry = match options { + WindowType::Over(g) => { + let g: &str = g.into(); + let mut key = format!("{:?}_{}", partition_by.as_slice(), g); + if let Some((e, k)) = order_by { + polars_expr::prelude::window_function_format_order_by( + &mut key, + e.as_ref(), + k, + ) + } + windows.entry(key).or_insert_with(Vec::new) + }, + #[cfg(feature = "dynamic_group_by")] + WindowType::Rolling(options) => { + rolling.entry(options).or_insert_with(Vec::new) + }, + }; + entry.push((index, phys.clone())); + is_window = true; + break; + } + } + } else { + // Window physical expressions always have the `Expr`. + is_window = false; + } + if !is_window { + other.push((index, phys.as_ref())) + } + }); + + let mut selected_columns = POOL.install(|| { + other + .par_iter() + .map(|(idx, expr)| expr.evaluate(df, state).map(|s| (*idx, s))) + .collect::>>() + })?; + + // Run partitioned rolling expressions. + // Per partition we run in parallel. We compute the groups before and store them once per partition. + // The rolling expression knows how to fetch the groups. + #[cfg(feature = "dynamic_group_by")] + { + let (a, b) = POOL.join( + || rolling_evaluate(df, state, rolling), + || window_evaluate(df, state, windows), + ); + + let partitions = a?; + for part in partitions { + selected_columns.extend_from_slice(&part) + } + let partitions = b?; + for part in partitions { + selected_columns.extend_from_slice(&part) + } + } + #[cfg(not(feature = "dynamic_group_by"))] + { + let partitions = window_evaluate(df, state, windows)?; + for part in partitions { + selected_columns.extend_from_slice(&part) + } + } + + selected_columns.sort_unstable_by_key(|tpl| tpl.0); + let selected_columns = selected_columns.into_iter().map(|tpl| tpl.1).collect(); + Ok(selected_columns) +} + +fn run_exprs_par( + df: &DataFrame, + exprs: &[Arc], + state: &ExecutionState, +) -> PolarsResult> { + POOL.install(|| { + exprs + .par_iter() + .map(|expr| expr.evaluate(df, state)) + .collect() + }) +} + +fn run_exprs_seq( + df: &DataFrame, + exprs: &[Arc], + state: &ExecutionState, +) -> PolarsResult> { + exprs.iter().map(|expr| expr.evaluate(df, state)).collect() +} + +pub(super) fn evaluate_physical_expressions( + df: &mut DataFrame, + exprs: &[Arc], + state: &ExecutionState, + has_windows: bool, + run_parallel: bool, +) -> PolarsResult> { + let expr_runner = if has_windows { + execute_projection_cached_window_fns + } else if run_parallel && exprs.len() > 1 { + run_exprs_par + } else { + run_exprs_seq + }; + + let selected_columns = expr_runner(df, exprs, state)?; + + if has_windows { + state.clear_window_expr_cache(); + } + + Ok(selected_columns) +} + +pub(super) fn check_expand_literals( + df: &DataFrame, + phys_expr: &[Arc], + mut selected_columns: Vec, + zero_length: bool, + options: ProjectionOptions, +) -> PolarsResult { + let Some(first_len) = selected_columns.first().map(|s| s.len()) else { + return Ok(DataFrame::empty()); + }; + let duplicate_check = options.duplicate_check; + let should_broadcast = options.should_broadcast; + + // When we have CSE we cannot verify scalars yet. + let verify_scalar = if !df.get_columns().is_empty() { + !df.get_columns()[df.width() - 1] + .name() + .starts_with(CSE_REPLACED) + } else { + true + }; + + let mut df_height = 0; + let mut has_empty = false; + let mut all_equal_len = true; + { + let mut names = PlHashSet::with_capacity(selected_columns.len()); + for s in &selected_columns { + let len = s.len(); + has_empty |= len == 0; + df_height = std::cmp::max(df_height, len); + if len != first_len { + all_equal_len = false; + } + let name = s.name(); + + if duplicate_check && !names.insert(name) { + let msg = format!( + "the name '{}' is duplicate\n\n\ + It's possible that multiple expressions are returning the same default column \ + name. If this is the case, try renaming the columns with \ + `.alias(\"new_name\")` to avoid duplicate column names.", + name + ); + return Err(PolarsError::Duplicate(msg.into())); + } + } + } + + // If all series are the same length it is ok. If not we can broadcast Series of length one. + if !all_equal_len && should_broadcast { + selected_columns = selected_columns + .into_iter() + .zip(phys_expr) + .map(|(series, phys)| { + Ok(match series.len() { + 0 if df_height == 1 => series, + 1 => { + if !has_empty && df_height == 1 { + series + } else { + if has_empty { + polars_ensure!(df_height == 1, + ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}", + series.len(), df_height + ); + + } + + if verify_scalar && !phys.is_scalar() && std::env::var("POLARS_ALLOW_NON_SCALAR_EXP").as_deref() != Ok("1") { + let identifier = match phys.as_expression() { + Some(e) => format!("expression: {}", e), + None => "this Series".to_string(), + }; + polars_bail!(ShapeMismatch: "Series {}, length {} doesn't match the DataFrame height of {}\n\n\ + If you want {} to be broadcasted, ensure it is a scalar (for instance by adding '.first()').", + series.name(), series.len(), df_height *(!has_empty as usize), identifier + ); + } + series.new_from_index(0, df_height * (!has_empty as usize) ) + } + }, + len if len == df_height => { + series + }, + _ => { + polars_bail!( + ShapeMismatch: "Series length {} doesn't match the DataFrame height of {}", + series.len(), df_height + ) + } + }) + }) + .collect::>()? + } + + // @scalar-opt + let selected_columns = selected_columns.into_iter().collect::>(); + + let df = unsafe { DataFrame::new_no_checks_height_from_first(selected_columns) }; + + // a literal could be projected to a zero length dataframe. + // This prevents a panic. + let df = if zero_length { + let min = df.get_columns().iter().map(|s| s.len()).min(); + if min.is_some() { df.head(min) } else { df } + } else { + df + }; + Ok(df) +} diff --git a/crates/polars-mem-engine/src/executors/scan/mod.rs b/crates/polars-mem-engine/src/executors/scan/mod.rs new file mode 100644 index 000000000000..e8149286f864 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/scan/mod.rs @@ -0,0 +1,92 @@ +#[cfg(feature = "python")] +mod python_scan; + +use std::mem; + +use polars_plan::global::_set_n_rows_for_scan; +use polars_utils::slice_enum::Slice; + +#[cfg(feature = "python")] +pub(crate) use self::python_scan::*; +use super::*; +use crate::ScanPredicate; +use crate::prelude::*; + +/// Producer of an in memory DataFrame +pub struct DataFrameExec { + pub(crate) df: Arc, + pub(crate) projection: Option>, +} + +impl Executor for DataFrameExec { + fn execute(&mut self, _state: &mut ExecutionState) -> PolarsResult { + let df = mem::take(&mut self.df); + let mut df = Arc::try_unwrap(df).unwrap_or_else(|df| (*df).clone()); + + // projection should be before selection as those are free + // TODO: this is only the case if we don't create new columns + if let Some(projection) = &self.projection { + df = df.select(projection.iter().cloned())?; + } + + Ok(match _set_n_rows_for_scan(None) { + Some(limit) => df.head(Some(limit)), + None => df, + }) + } +} + +pub(crate) struct AnonymousScanExec { + pub(crate) function: Arc, + pub(crate) unified_scan_args: Box, + pub(crate) file_info: FileInfo, + pub(crate) predicate: Option, + pub(crate) output_schema: Option, + pub(crate) predicate_has_windows: bool, +} + +impl Executor for AnonymousScanExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + let mut args = AnonymousScanArgs { + n_rows: self.unified_scan_args.pre_slice.clone().map(|x| { + assert!(matches!(x, Slice::Positive { offset: 0, .. })); + + x.len() + }), + with_columns: self.unified_scan_args.projection.clone(), + schema: self.file_info.schema.clone(), + output_schema: self.output_schema.clone(), + predicate: None, + }; + if self.predicate.is_some() { + state.insert_has_window_function_flag() + } + + match (self.function.allows_predicate_pushdown(), &self.predicate) { + (true, Some(predicate)) => state.record( + || { + args.predicate = predicate.predicate.as_expression().cloned(); + self.function.scan(args) + }, + "anonymous_scan".into(), + ), + (false, Some(predicate)) => state.record( + || { + let mut df = self.function.scan(args)?; + let s = predicate.predicate.evaluate(&df, state)?; + if self.predicate_has_windows { + state.clear_window_expr_cache() + } + let mask = s.bool().map_err( + |_| polars_err!(ComputeError: "filter predicate was not of type boolean"), + )?; + df = df.filter(mask)?; + + Ok(df) + }, + "anonymous_scan".into(), + ), + _ => state.record(|| self.function.scan(args), "anonymous_scan".into()), + } + } +} diff --git a/crates/polars-mem-engine/src/executors/scan/python_scan.rs b/crates/polars-mem-engine/src/executors/scan/python_scan.rs new file mode 100644 index 000000000000..4beefc4bf482 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/scan/python_scan.rs @@ -0,0 +1,186 @@ +use polars_core::error::to_compute_err; +use polars_core::utils::accumulate_dataframes_vertical; +use pyo3::exceptions::PyStopIteration; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyNone}; +use pyo3::{IntoPyObjectExt, PyTypeInfo, intern}; + +use self::python_dsl::PythonScanSource; +use super::*; + +pub(crate) struct PythonScanExec { + pub(crate) options: PythonOptions, + pub(crate) predicate: Option>, + pub(crate) predicate_serialized: Option>, +} + +impl PythonScanExec { + /// Get the output schema. E.g. the schema the plugins produce, not consume. + fn get_schema(&self) -> &SchemaRef { + self.options + .output_schema + .as_ref() + .unwrap_or(&self.options.schema) + } + + fn check_schema(&self, df: &DataFrame) -> PolarsResult<()> { + if self.options.validate_schema { + let output_schema = self.get_schema(); + polars_ensure!(df.schema() == output_schema, SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}", output_schema, df.schema()); + } + Ok(()) + } + + fn finish_df( + &self, + py: Python, + df: Bound<'_, PyAny>, + state: &mut ExecutionState, + ) -> PolarsResult { + let df = python_df_to_rust(py, df)?; + + self.check_schema(&df)?; + + if let Some(pred) = &self.predicate { + let mask = pred.evaluate(&df, state)?; + df.filter(mask.bool()?) + } else { + Ok(df) + } + } +} + +impl Executor for PythonScanExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run PythonScanExec") + } + } + let with_columns = self.options.with_columns.take(); + let n_rows = self.options.n_rows.take(); + Python::with_gil(|py| { + let pl = PyModule::import(py, intern!(py, "polars")).unwrap(); + let utils = pl.getattr(intern!(py, "_utils")).unwrap(); + let callable = utils.getattr(intern!(py, "_execute_from_rust")).unwrap(); + + let python_scan_function = self.options.scan_fn.take().unwrap().0; + + let with_columns = with_columns.map(|cols| cols.iter().cloned().collect::>()); + let mut could_serialize_predicate = true; + + let predicate = match &self.options.predicate { + PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(), + PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(), + PythonPredicate::Polars(_) => { + assert!(self.predicate.is_some(), "should be set"); + + match &self.predicate_serialized { + None => { + could_serialize_predicate = false; + PyNone::get(py).to_owned().into_any() + }, + Some(buf) => PyBytes::new(py, buf).into_any(), + } + }, + }; + + match self.options.python_source { + PythonScanSource::Cuda => { + let args = ( + python_scan_function, + with_columns + .map(|x| x.into_iter().map(|x| x.to_string()).collect::>()), + predicate, + n_rows, + // If this boolean is true, callback should return + // a dataframe and list of timings [(start, end, + // name)] + state.has_node_timer(), + ); + let result = callable.call1(args).map_err(to_compute_err)?; + let df = if state.has_node_timer() { + let df = result.get_item(0).map_err(to_compute_err); + let timing_info: Vec<(u64, u64, String)> = result + .get_item(1) + .map_err(to_compute_err)? + .extract() + .map_err(to_compute_err)?; + state.record_raw_timings(&timing_info); + df? + } else { + result + }; + self.finish_df(py, df, state) + }, + PythonScanSource::Pyarrow => { + let args = ( + python_scan_function, + with_columns + .map(|x| x.into_iter().map(|x| x.to_string()).collect::>()), + predicate, + n_rows, + ); + let df = callable.call1(args).map_err(to_compute_err)?; + self.finish_df(py, df, state) + }, + PythonScanSource::IOPlugin => { + // If there are filters, take smaller chunks to ensure we can keep memory + // pressure low. + let batch_size = if self.predicate.is_some() { + Some(100_000usize) + } else { + None + }; + let args = ( + python_scan_function, + with_columns + .map(|x| x.into_iter().map(|x| x.to_string()).collect::>()), + predicate, + n_rows, + batch_size, + ); + + let generator_init = callable.call1(args).map_err(to_compute_err)?; + let generator = generator_init.get_item(0).map_err( + |_| polars_err!(ComputeError: "expected tuple got {}", generator_init), + )?; + let can_parse_predicate = generator_init.get_item(1).map_err( + |_| polars_err!(ComputeError: "expected tuple got {}", generator), + )?; + let can_parse_predicate = can_parse_predicate.extract::().map_err( + |_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate), + )? && could_serialize_predicate; + + let mut chunks = vec![]; + loop { + match generator.call_method0(intern!(py, "__next__")) { + Ok(out) => { + let mut df = python_df_to_rust(py, out)?; + if let (Some(pred), false) = (&self.predicate, can_parse_predicate) + { + let mask = pred.evaluate(&df, state)?; + df = df.filter(mask.bool()?)?; + } + chunks.push(df) + }, + Err(err) if err.matches(py, PyStopIteration::type_object(py))? => break, + Err(err) => { + polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err) + }, + } + } + if chunks.is_empty() { + return Ok(DataFrame::empty_with_schema(self.get_schema().as_ref())); + } + let df = accumulate_dataframes_vertical(chunks)?; + + self.check_schema(&df)?; + Ok(df) + }, + } + }) + } +} diff --git a/crates/polars-mem-engine/src/executors/slice.rs b/crates/polars-mem-engine/src/executors/slice.rs new file mode 100644 index 000000000000..c1ff0a9667a5 --- /dev/null +++ b/crates/polars-mem-engine/src/executors/slice.rs @@ -0,0 +1,24 @@ +use super::*; + +pub struct SliceExec { + pub input: Box, + pub offset: i64, + pub len: IdxSize, +} + +impl Executor for SliceExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run SliceExec") + } + } + let df = self.input.execute(state)?; + + state.record( + || Ok(df.slice(self.offset, self.len as usize)), + "slice".into(), + ) + } +} diff --git a/crates/polars-mem-engine/src/executors/sort.rs b/crates/polars-mem-engine/src/executors/sort.rs new file mode 100644 index 000000000000..ec1e0aad276c --- /dev/null +++ b/crates/polars-mem-engine/src/executors/sort.rs @@ -0,0 +1,79 @@ +use polars_utils::format_pl_smallstr; + +use super::*; + +pub(crate) struct SortExec { + pub(crate) input: Box, + pub(crate) by_column: Vec>, + pub(crate) slice: Option<(i64, usize)>, + pub(crate) sort_options: SortMultipleOptions, +} + +impl SortExec { + fn execute_impl( + &mut self, + state: &ExecutionState, + mut df: DataFrame, + ) -> PolarsResult { + state.should_stop()?; + df.as_single_chunk_par(); + + let height = df.height(); + + let by_columns = self + .by_column + .iter() + .enumerate() + .map(|(i, e)| { + let mut s = e.evaluate(&df, state)?.into_column(); + // Polars core will try to set the sorted columns as sorted. + // This should only be done with simple col("foo") expressions, + // therefore we rename more complex expressions so that + // polars core does not match these. + if !matches!(e.as_expression(), Some(&Expr::Column(_))) { + s.rename(format_pl_smallstr!("_POLARS_SORT_BY_{i}")); + } + polars_ensure!( + s.len() == height, + ShapeMismatch: "sort expressions must have same \ + length as DataFrame, got DataFrame height: {} and Series length: {}", + height, s.len() + ); + Ok(s) + }) + .collect::>>()?; + + df.sort_impl(by_columns, self.sort_options.clone(), self.slice) + } +} + +impl Executor for SortExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run SortExec") + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + let by = self + .by_column + .iter() + .map(|s| Ok(s.to_field(df.schema())?.name)) + .collect::>>()?; + let name = comma_delimited("sort".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, df), profile_name) + } else { + self.execute_impl(state, df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/stack.rs b/crates/polars-mem-engine/src/executors/stack.rs new file mode 100644 index 000000000000..62e1ffb7fdfe --- /dev/null +++ b/crates/polars-mem-engine/src/executors/stack.rs @@ -0,0 +1,139 @@ +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_plan::constants::CSE_REPLACED; + +use super::*; + +pub struct StackExec { + pub(crate) input: Box, + pub(crate) has_windows: bool, + pub(crate) exprs: Vec>, + pub(crate) input_schema: SchemaRef, + pub(crate) output_schema: SchemaRef, + pub(crate) options: ProjectionOptions, + // Can run all operations elementwise + pub(crate) allow_vertical_parallelism: bool, +} + +impl StackExec { + fn execute_impl( + &mut self, + state: &ExecutionState, + mut df: DataFrame, + ) -> PolarsResult { + let schema = &*self.output_schema; + + // Vertical and horizontal parallelism. + let df = if self.allow_vertical_parallelism + && df.first_col_n_chunks() > 1 + && df.height() > 0 + && self.options.run_parallel + { + let chunks = df.split_chunks().collect::>(); + let iter = chunks.into_par_iter().map(|mut df| { + let res = evaluate_physical_expressions( + &mut df, + &self.exprs, + state, + self.has_windows, + self.options.run_parallel, + )?; + // We don't have to do a broadcast check as cse is not allowed to hit this. + df._add_columns(res.into_iter().collect(), schema)?; + Ok(df) + }); + + let df = POOL.install(|| iter.collect::>>())?; + accumulate_dataframes_vertical_unchecked(df) + } + // Only horizontal parallelism + else { + let res = evaluate_physical_expressions( + &mut df, + &self.exprs, + state, + self.has_windows, + self.options.run_parallel, + )?; + if !self.options.should_broadcast { + debug_assert!( + res.iter() + .all(|column| column.name().starts_with("__POLARS_CSER_0x")), + "non-broadcasting hstack should only be used for CSE columns" + ); + // Safety: this case only appears as a result of + // CSE optimization, and the usage there produces + // new, unique column names. It is immediately + // followed by a projection which pulls out the + // possibly mismatching column lengths. + unsafe { df.column_extend_unchecked(res) }; + } else { + let (df_height, df_width) = df.shape(); + + // When we have CSE we cannot verify scalars yet. + let verify_scalar = if !df.get_columns().is_empty() { + !df.get_columns()[df.width() - 1] + .name() + .starts_with(CSE_REPLACED) + } else { + true + }; + for (i, c) in res.iter().enumerate() { + let len = c.len(); + if verify_scalar && len != df_height && len == 1 && df_width > 0 { + #[allow(clippy::collapsible_if)] + if !self.exprs[i].is_scalar() + && std::env::var("POLARS_ALLOW_NON_SCALAR_EXP").as_deref() != Ok("1") + { + let identifier = match self.exprs[i].as_expression() { + Some(e) => format!("expression: {}", e), + None => "this Series".to_string(), + }; + polars_bail!(InvalidOperation: "Series {}, length {} doesn't match the DataFrame height of {}\n\n\ + If you want {} to be broadcasted, ensure it is a scalar (for instance by adding '.first()').", + c.name(), len, df_height, identifier + ); + } + } + } + df._add_columns(res.into_iter().collect(), schema)?; + } + df + }; + + state.clear_window_expr_cache(); + + Ok(df) + } +} + +impl Executor for StackExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run StackExec"); + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + let by = self + .exprs + .iter() + .map(|s| profile_name(s.as_ref(), self.input_schema.as_ref())) + .collect::>>()?; + let name = comma_delimited("with_column".to_string(), &by); + Cow::Owned(name) + } else { + Cow::Borrowed("") + }; + + if state.has_node_timer() { + let new_state = state.clone(); + new_state.record(|| self.execute_impl(state, df), profile_name) + } else { + self.execute_impl(state, df) + } + } +} diff --git a/crates/polars-mem-engine/src/executors/udf.rs b/crates/polars-mem-engine/src/executors/udf.rs new file mode 100644 index 000000000000..f5bea3a8cfee --- /dev/null +++ b/crates/polars-mem-engine/src/executors/udf.rs @@ -0,0 +1,26 @@ +use super::*; + +pub(crate) struct UdfExec { + pub(crate) input: Box, + pub(crate) function: FunctionIR, +} + +impl Executor for UdfExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run UdfExec") + } + } + let df = self.input.execute(state)?; + + let profile_name = if state.has_node_timer() { + Cow::Owned(format!("{}", self.function)) + } else { + Cow::Borrowed("") + }; + state.record(|| self.function.evaluate(df), profile_name) + } +} diff --git a/crates/polars-mem-engine/src/executors/union.rs b/crates/polars-mem-engine/src/executors/union.rs new file mode 100644 index 000000000000..54be89a6047a --- /dev/null +++ b/crates/polars-mem-engine/src/executors/union.rs @@ -0,0 +1,122 @@ +use polars_core::utils::concat_df; +use polars_plan::global::_is_fetch_query; + +use super::*; + +pub(crate) struct UnionExec { + pub(crate) inputs: Vec>, + pub(crate) options: UnionOptions, +} + +impl Executor for UnionExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run UnionExec") + } + } + // keep scans thread local if 'fetch' is used. + if _is_fetch_query() { + self.options.parallel = false; + } + let mut inputs = std::mem::take(&mut self.inputs); + + let sliced_path = if let Some((offset, _)) = self.options.slice { + offset >= 0 + } else { + false + }; + + if !self.options.parallel || sliced_path { + if state.verbose() { + if !self.options.parallel { + eprintln!("UNION: `parallel=false` union is run sequentially") + } else { + eprintln!("UNION: `slice is set` union is run sequentially") + } + } + + let (slice_offset, mut slice_len) = self.options.slice.unwrap_or((0, usize::MAX)); + let mut slice_offset = slice_offset as usize; + let mut dfs = Vec::with_capacity(inputs.len()); + + for (idx, mut input) in inputs.into_iter().enumerate() { + let mut state = state.split(); + state.branch_idx += idx; + + let df = input.execute(&mut state)?; + + if !sliced_path { + dfs.push(df); + continue; + } + + let height = df.height(); + // this part can be skipped as we haven't reached the offset yet + // TODO!: don't read the file yet! + if slice_offset > height { + slice_offset -= height; + } + // applying the slice + // continue iteration + else if slice_offset + slice_len > height { + slice_len -= height - slice_offset; + if slice_offset == 0 { + dfs.push(df); + } else { + dfs.push(df.slice(slice_offset as i64, usize::MAX)); + slice_offset = 0; + } + } + // we finished the slice + else { + dfs.push(df.slice(slice_offset as i64, slice_len)); + break; + } + } + + concat_df(&dfs) + } else { + if state.verbose() { + eprintln!("UNION: union is run in parallel") + } + + // we don't use par_iter directly because the LP may also start threads for every LP (for instance scan_csv) + // this might then lead to a rayon SO. So we take a multitude of the threads to keep work stealing + // within bounds + let out = POOL.install(|| { + inputs + .chunks_mut(POOL.current_num_threads() * 3) + .map(|chunk| { + chunk + .into_par_iter() + .enumerate() + .map(|(idx, input)| { + let mut input = std::mem::take(input); + let mut state = state.split(); + state.branch_idx += idx; + input.execute(&mut state) + }) + .collect::>>() + }) + .collect::>>() + }); + + concat_df(out?.iter().flat_map(|dfs| dfs.iter())).map(|df| { + if let Some((offset, len)) = self.options.slice { + df.slice(offset, len) + } else { + df + } + }) + } + .map(|mut df| { + if self.options.rechunk { + df.as_single_chunk_par(); + } + df + }) + } +} diff --git a/crates/polars-mem-engine/src/executors/unique.rs b/crates/polars-mem-engine/src/executors/unique.rs new file mode 100644 index 000000000000..69c7b19c528a --- /dev/null +++ b/crates/polars-mem-engine/src/executors/unique.rs @@ -0,0 +1,41 @@ +use super::*; + +pub(crate) struct UniqueExec { + pub(crate) input: Box, + pub(crate) options: DistinctOptionsIR, +} + +impl Executor for UniqueExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run UniqueExec") + } + } + let df = self.input.execute(state)?; + let subset = self + .options + .subset + .as_ref() + .map(|v| v.iter().cloned().collect::>()); + let keep = self.options.keep_strategy; + + state.record( + || { + if df.is_empty() { + return Ok(df); + } + + df.unique_impl( + self.options.maintain_order, + subset, + keep, + self.options.slice, + ) + }, + Cow::Borrowed("unique()"), + ) + } +} diff --git a/crates/polars-mem-engine/src/lib.rs b/crates/polars-mem-engine/src/lib.rs new file mode 100644 index 000000000000..d0d0ba7e1cbe --- /dev/null +++ b/crates/polars-mem-engine/src/lib.rs @@ -0,0 +1,13 @@ +mod executors; +mod planner; +mod predicate; +mod prelude; + +pub use executors::Executor; +#[cfg(feature = "python")] +pub use planner::python_scan_predicate; +pub use planner::{ + StreamingExecutorBuilder, create_multiple_physical_plans, create_physical_plan, + create_scan_predicate, +}; +pub use predicate::ScanPredicate; diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs new file mode 100644 index 000000000000..ef18d03b320d --- /dev/null +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -0,0 +1,956 @@ +use polars_core::POOL; +use polars_core::prelude::*; +use polars_expr::state::ExecutionState; +use polars_plan::global::_set_n_rows_for_scan; +use polars_plan::plans::expr_ir::ExprIR; +use polars_utils::format_pl_smallstr; +use recursive::recursive; + +use self::expr_ir::OutputName; +use self::predicates::{aexpr_to_column_predicates, aexpr_to_skip_batch_predicate}; +#[cfg(feature = "python")] +use self::python_dsl::PythonScanSource; +use super::super::executors::{self, Executor}; +use super::*; +use crate::ScanPredicate; +use crate::executors::{CachePrefiller, SinkExecutor}; +use crate::predicate::PhysicalColumnPredicates; + +pub type StreamingExecutorBuilder = + fn(Node, &mut Arena, &mut Arena) -> PolarsResult>; + +fn partitionable_gb( + keys: &[ExprIR], + aggs: &[ExprIR], + input_schema: &Schema, + expr_arena: &Arena, + apply: &Option>, +) -> bool { + // checks: + // 1. complex expressions in the group_by itself are also not partitionable + // in this case anything more than col("foo") + // 2. a custom function cannot be partitioned + // 3. we don't bother with more than 2 keys, as the cardinality likely explodes + // by the combinations + if !keys.is_empty() && keys.len() < 3 && apply.is_none() { + // complex expressions in the group_by itself are also not partitionable + // in this case anything more than col("foo") + for key in keys { + if (expr_arena).iter(key.node()).count() > 1 + || has_aexpr(key.node(), expr_arena, |ae| match ae { + AExpr::Literal(lv) => !lv.is_scalar(), + _ => false, + }) + { + return false; + } + } + + can_pre_agg_exprs(aggs, expr_arena, input_schema) + } else { + false + } +} + +#[derive(Clone)] +struct ConversionState { + has_cache_child: bool, + has_cache_parent: bool, +} + +impl ConversionState { + fn new() -> PolarsResult { + Ok(ConversionState { + has_cache_child: false, + has_cache_parent: false, + }) + } + + fn with_new_branch K>(&mut self, func: F) -> K { + let mut new_state = self.clone(); + new_state.has_cache_child = false; + let out = func(&mut new_state); + self.has_cache_child = new_state.has_cache_child; + out + } +} + +pub fn create_physical_plan( + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + build_streaming_executor: Option, +) -> PolarsResult> { + let mut state = ConversionState::new()?; + let mut cache_nodes = Default::default(); + let plan = create_physical_plan_impl( + root, + lp_arena, + expr_arena, + &mut state, + &mut cache_nodes, + build_streaming_executor, + )?; + + if cache_nodes.is_empty() { + Ok(plan) + } else { + Ok(Box::new(CachePrefiller { + caches: cache_nodes, + phys_plan: plan, + })) + } +} + +pub struct MultiplePhysicalPlans { + pub cache_prefiller: Option>, + pub physical_plans: Vec>, +} +pub fn create_multiple_physical_plans( + roots: &[Node], + lp_arena: &mut Arena, + expr_arena: &mut Arena, + build_streaming_executor: Option, +) -> PolarsResult { + let mut state = ConversionState::new()?; + let mut cache_nodes = Default::default(); + let plans = state.with_new_branch(|new_state| { + roots + .iter() + .map(|&node| { + create_physical_plan_impl( + node, + lp_arena, + expr_arena, + new_state, + &mut cache_nodes, + build_streaming_executor, + ) + }) + .collect::>>() + })?; + + let cache_prefiller = (!cache_nodes.is_empty()).then(|| { + struct Empty; + impl Executor for Empty { + fn execute(&mut self, _cache: &mut ExecutionState) -> PolarsResult { + Ok(DataFrame::empty()) + } + } + Box::new(CachePrefiller { + caches: cache_nodes, + phys_plan: Box::new(Empty), + }) as _ + }); + + Ok(MultiplePhysicalPlans { + cache_prefiller, + physical_plans: plans, + }) +} + +#[cfg(feature = "python")] +#[allow(clippy::type_complexity)] +pub fn python_scan_predicate( + options: &mut PythonOptions, + expr_arena: &Arena, + state: &mut ExpressionConversionState, +) -> PolarsResult<( + Option>, + Option>, +)> { + let mut predicate_serialized = None; + let predicate = if let PythonPredicate::Polars(e) = &options.predicate { + // Convert to a pyarrow eval string. + if matches!(options.python_source, PythonScanSource::Pyarrow) { + if let Some(eval_str) = polars_plan::plans::python::pyarrow::predicate_to_pa( + e.node(), + expr_arena, + Default::default(), + ) { + options.predicate = PythonPredicate::PyArrow(eval_str); + // We don't have to use a physical expression as pyarrow deals with the filter. + None + } else { + Some(create_physical_expr( + e, + Context::Default, + expr_arena, + &options.schema, + state, + )?) + } + } + // Convert to physical expression for the case the reader cannot consume the predicate. + else { + let dsl_expr = e.to_expr(expr_arena); + predicate_serialized = polars_plan::plans::python::predicate::serialize(&dsl_expr)?; + + Some(create_physical_expr( + e, + Context::Default, + expr_arena, + &options.schema, + state, + )?) + } + } else { + None + }; + + Ok((predicate, predicate_serialized)) +} + +#[recursive] +fn create_physical_plan_impl( + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + state: &mut ConversionState, + // Cache nodes in order of discovery + cache_nodes: &mut PlIndexMap>, + build_streaming_executor: Option, +) -> PolarsResult> { + use IR::*; + + macro_rules! recurse { + ($node:expr, $state: expr) => { + create_physical_plan_impl( + $node, + lp_arena, + expr_arena, + $state, + cache_nodes, + build_streaming_executor, + ) + }; + } + + let logical_plan = if state.has_cache_parent || matches!(lp_arena.get(root), IR::Scan { .. }) { + lp_arena.get(root).clone() + } else { + lp_arena.take(root) + }; + + match logical_plan { + #[cfg(feature = "python")] + PythonScan { mut options } => { + let mut expr_conv_state = ExpressionConversionState::new(true); + let (predicate, predicate_serialized) = + python_scan_predicate(&mut options, expr_arena, &mut expr_conv_state)?; + Ok(Box::new(executors::PythonScanExec { + options, + predicate, + predicate_serialized, + })) + }, + Sink { input, payload } => { + let input = recurse!(input, state)?; + match payload { + SinkTypeIR::Memory => Ok(Box::new(SinkExecutor { + input, + name: "mem".to_string(), + f: Box::new(move |df, _state| Ok(Some(df))), + })), + SinkTypeIR::File(FileSinkType { + file_type, + target, + sink_options, + cloud_options, + }) => { + let name: &'static str = match &file_type { + #[cfg(feature = "parquet")] + FileType::Parquet(_) => "parquet", + #[cfg(feature = "ipc")] + FileType::Ipc(_) => "ipc", + #[cfg(feature = "csv")] + FileType::Csv(_) => "csv", + #[cfg(feature = "json")] + FileType::Json(_) => "json", + #[allow(unreachable_patterns)] + _ => panic!("enable filetype feature"), + }; + + Ok(Box::new(SinkExecutor { + input, + name: name.to_string(), + f: Box::new(move |mut df, _state| { + let mut file = target + .open_into_writeable(&sink_options, cloud_options.as_ref())?; + let writer = &mut *file; + + use std::io::BufWriter; + match &file_type { + #[cfg(feature = "parquet")] + FileType::Parquet(options) => { + use polars_io::parquet::write::ParquetWriter; + ParquetWriter::new(BufWriter::new(writer)) + .with_compression(options.compression) + .with_statistics(options.statistics) + .with_row_group_size(options.row_group_size) + .with_data_page_size(options.data_page_size) + .finish(&mut df)?; + }, + #[cfg(feature = "ipc")] + FileType::Ipc(options) => { + use polars_io::SerWriter; + use polars_io::ipc::IpcWriter; + IpcWriter::new(BufWriter::new(writer)) + .with_compression(options.compression) + .with_compat_level(options.compat_level) + .finish(&mut df)?; + }, + #[cfg(feature = "csv")] + FileType::Csv(options) => { + use polars_io::SerWriter; + use polars_io::csv::write::CsvWriter; + CsvWriter::new(BufWriter::new(writer)) + .include_bom(options.include_bom) + .include_header(options.include_header) + .with_separator(options.serialize_options.separator) + .with_line_terminator( + options.serialize_options.line_terminator.clone(), + ) + .with_quote_char(options.serialize_options.quote_char) + .with_batch_size(options.batch_size) + .with_datetime_format( + options.serialize_options.datetime_format.clone(), + ) + .with_date_format( + options.serialize_options.date_format.clone(), + ) + .with_time_format( + options.serialize_options.time_format.clone(), + ) + .with_float_scientific( + options.serialize_options.float_scientific, + ) + .with_float_precision( + options.serialize_options.float_precision, + ) + .with_null_value(options.serialize_options.null.clone()) + .with_quote_style(options.serialize_options.quote_style) + .finish(&mut df)?; + }, + #[cfg(feature = "json")] + FileType::Json(_options) => { + use polars_io::SerWriter; + use polars_io::json::{JsonFormat, JsonWriter}; + + JsonWriter::new(BufWriter::new(writer)) + .with_json_format(JsonFormat::JsonLines) + .finish(&mut df)?; + }, + #[allow(unreachable_patterns)] + _ => panic!("enable filetype feature"), + } + + file.sync_on_close(sink_options.sync_on_close)?; + file.close()?; + + Ok(None) + }), + })) + }, + + SinkTypeIR::Partition { .. } => { + polars_bail!(InvalidOperation: + "partition sinks not yet supported in standard engine." + ) + }, + } + }, + SinkMultiple { .. } => { + unreachable!("should be handled with create_multiple_physical_plans") + }, + Union { inputs, options } => { + let inputs = state.with_new_branch(|new_state| { + inputs + .into_iter() + .map(|node| recurse!(node, new_state)) + .collect::>>() + }); + let inputs = inputs?; + Ok(Box::new(executors::UnionExec { inputs, options })) + }, + HConcat { + inputs, options, .. + } => { + let inputs = state.with_new_branch(|new_state| { + inputs + .into_iter() + .map(|node| recurse!(node, new_state)) + .collect::>>() + }); + + let inputs = inputs?; + + Ok(Box::new(executors::HConcatExec { inputs, options })) + }, + Slice { input, offset, len } => { + let input = recurse!(input, state)?; + Ok(Box::new(executors::SliceExec { input, offset, len })) + }, + Filter { input, predicate } => { + let mut streamable = + is_elementwise_rec_no_cat_cast(expr_arena.get(predicate.node()), expr_arena); + let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); + if streamable { + // This can cause problems with string caches + streamable = !input_schema + .iter_values() + .any(|dt| dt.contains_categoricals()) + || { + #[cfg(feature = "dtype-categorical")] + { + polars_core::using_string_cache() + } + + #[cfg(not(feature = "dtype-categorical"))] + { + false + } + } + } + let input = recurse!(input, state)?; + let mut state = ExpressionConversionState::new(true); + let predicate = create_physical_expr( + &predicate, + Context::Default, + expr_arena, + &input_schema, + &mut state, + )?; + Ok(Box::new(executors::FilterExec::new( + predicate, + input, + state.has_windows, + streamable, + ))) + }, + #[allow(unused_variables)] + Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + predicate, + mut unified_scan_args, + } => { + unified_scan_args.pre_slice = if let Some(mut slice) = unified_scan_args.pre_slice { + *slice.len_mut() = _set_n_rows_for_scan(Some(slice.len())).unwrap(); + Some(slice) + } else { + _set_n_rows_for_scan(None) + .map(|len| polars_utils::slice_enum::Slice::Positive { offset: 0, len }) + }; + + let mut state = ExpressionConversionState::new(true); + + let mut create_skip_batch_predicate = false; + #[cfg(feature = "parquet")] + { + create_skip_batch_predicate |= matches!( + &*scan_type, + FileScan::Parquet { + options: polars_io::prelude::ParquetOptions { + use_statistics: true, + .. + }, + .. + } + ); + } + + let predicate = predicate + .map(|predicate| { + create_scan_predicate( + &predicate, + expr_arena, + output_schema.as_ref().unwrap_or(&file_info.schema), + &mut state, + create_skip_batch_predicate, + false, + ) + }) + .transpose()?; + + match *scan_type { + FileScan::Anonymous { function, .. } => { + Ok(Box::new(executors::AnonymousScanExec { + function, + predicate, + unified_scan_args, + file_info, + output_schema, + predicate_has_windows: state.has_windows, + })) + }, + #[allow(unreachable_patterns)] + _ => { + let build_func = build_streaming_executor + .expect("invalid build. Missing feature new-streaming"); + return build_func(root, lp_arena, expr_arena); + }, + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + }, + + Select { + expr, + input, + schema: _schema, + options, + .. + } => { + let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); + let input = recurse!(input, state)?; + let mut state = ExpressionConversionState::new(POOL.current_num_threads() > expr.len()); + let phys_expr = create_physical_expressions_from_irs( + &expr, + Context::Default, + expr_arena, + &input_schema, + &mut state, + )?; + + let allow_vertical_parallelism = options.should_broadcast && expr.iter().all(|e| is_elementwise_rec_no_cat_cast(expr_arena.get(e.node()), expr_arena)) + // If all columns are literal we would get a 1 row per thread. + && !phys_expr.iter().all(|p| { + p.is_literal() + }); + + Ok(Box::new(executors::ProjectionExec { + input, + expr: phys_expr, + has_windows: state.has_windows, + input_schema, + #[cfg(test)] + schema: _schema, + options, + allow_vertical_parallelism, + })) + }, + DataFrameScan { + df, output_schema, .. + } => Ok(Box::new(executors::DataFrameExec { + df, + projection: output_schema.map(|s| s.iter_names_cloned().collect()), + })), + Sort { + input, + by_column, + slice, + sort_options, + } => { + let input_schema = lp_arena.get(input).schema(lp_arena); + let by_column = create_physical_expressions_from_irs( + &by_column, + Context::Default, + expr_arena, + input_schema.as_ref(), + &mut ExpressionConversionState::new(true), + )?; + let input = recurse!(input, state)?; + Ok(Box::new(executors::SortExec { + input, + by_column, + slice, + sort_options, + })) + }, + Cache { + input, + id, + cache_hits, + } => { + state.has_cache_parent = true; + state.has_cache_child = true; + + if !cache_nodes.contains_key(&id) { + let input = recurse!(input, state)?; + + let cache = Box::new(executors::CacheExec { + id, + input: Some(input), + count: cache_hits, + }); + + cache_nodes.insert(id, cache); + } + + Ok(Box::new(executors::CacheExec { + id, + input: None, + count: cache_hits, + })) + }, + Distinct { input, options } => { + let input = recurse!(input, state)?; + Ok(Box::new(executors::UniqueExec { input, options })) + }, + GroupBy { + input, + keys, + aggs, + apply, + schema, + maintain_order, + options, + } => { + let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); + let options = Arc::try_unwrap(options).unwrap_or_else(|options| (*options).clone()); + let phys_keys = create_physical_expressions_from_irs( + &keys, + Context::Default, + expr_arena, + &input_schema, + &mut ExpressionConversionState::new(true), + )?; + let phys_aggs = create_physical_expressions_from_irs( + &aggs, + Context::Aggregation, + expr_arena, + &input_schema, + &mut ExpressionConversionState::new(true), + )?; + + let _slice = options.slice; + #[cfg(feature = "dynamic_group_by")] + if let Some(options) = options.dynamic { + let input = recurse!(input, state)?; + return Ok(Box::new(executors::GroupByDynamicExec { + input, + keys: phys_keys, + aggs: phys_aggs, + options, + input_schema, + slice: _slice, + apply, + })); + } + + #[cfg(feature = "dynamic_group_by")] + if let Some(options) = options.rolling { + let input = recurse!(input, state)?; + return Ok(Box::new(executors::GroupByRollingExec { + input, + keys: phys_keys, + aggs: phys_aggs, + options, + input_schema, + slice: _slice, + apply, + })); + } + + // We first check if we can partition the group_by on the latest moment. + let partitionable = partitionable_gb(&keys, &aggs, &input_schema, expr_arena, &apply); + if partitionable { + let from_partitioned_ds = (&*lp_arena).iter(input).any(|(_, lp)| { + if let Union { options, .. } = lp { + options.from_partitioned_ds + } else { + false + } + }); + let input = recurse!(input, state)?; + let keys = keys + .iter() + .map(|e| e.to_expr(expr_arena)) + .collect::>(); + let aggs = aggs + .iter() + .map(|e| e.to_expr(expr_arena)) + .collect::>(); + Ok(Box::new(executors::PartitionGroupByExec::new( + input, + phys_keys, + phys_aggs, + maintain_order, + options.slice, + input_schema, + schema, + from_partitioned_ds, + keys, + aggs, + ))) + } else { + let input = recurse!(input, state)?; + Ok(Box::new(executors::GroupByExec::new( + input, + phys_keys, + phys_aggs, + apply, + maintain_order, + input_schema, + options.slice, + ))) + } + }, + Join { + input_left, + input_right, + left_on, + right_on, + options, + schema, + .. + } => { + let schema_left = lp_arena.get(input_left).schema(lp_arena).into_owned(); + let schema_right = lp_arena.get(input_right).schema(lp_arena).into_owned(); + + let (input_left, input_right) = state.with_new_branch(|new_state| { + ( + recurse!(input_left, new_state), + recurse!(input_right, new_state), + ) + }); + let input_left = input_left?; + let input_right = input_right?; + + // Todo! remove the force option. It can deadlock. + let parallel = if options.force_parallel { + true + } else { + options.allow_parallel + }; + + let left_on = create_physical_expressions_from_irs( + &left_on, + Context::Default, + expr_arena, + &schema_left, + &mut ExpressionConversionState::new(true), + )?; + let right_on = create_physical_expressions_from_irs( + &right_on, + Context::Default, + expr_arena, + &schema_right, + &mut ExpressionConversionState::new(true), + )?; + let options = Arc::try_unwrap(options).unwrap_or_else(|options| (*options).clone()); + + // Convert the join options, to the physical join options. This requires the physical + // planner, so we do this last minute. + let join_type_options = options + .options + .map(|o| { + o.compile(|e| { + let phys_expr = create_physical_expr( + e, + Context::Default, + expr_arena, + &schema, + &mut ExpressionConversionState::new(false), + )?; + + let execution_state = ExecutionState::default(); + + Ok(Arc::new(move |df: DataFrame| { + let mask = phys_expr.evaluate(&df, &execution_state)?; + let mask = mask.as_materialized_series(); + let mask = mask.bool()?; + df._filter_seq(mask) + })) + }) + }) + .transpose()?; + + Ok(Box::new(executors::JoinExec::new( + input_left, + input_right, + left_on, + right_on, + parallel, + options.args, + join_type_options, + ))) + }, + HStack { + input, + exprs, + schema: output_schema, + options, + } => { + let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); + let input = recurse!(input, state)?; + + let allow_vertical_parallelism = options.should_broadcast + && exprs + .iter() + .all(|e| is_elementwise_rec_no_cat_cast(expr_arena.get(e.node()), expr_arena)); + + let mut state = + ExpressionConversionState::new(POOL.current_num_threads() > exprs.len()); + + let phys_exprs = create_physical_expressions_from_irs( + &exprs, + Context::Default, + expr_arena, + &input_schema, + &mut state, + )?; + Ok(Box::new(executors::StackExec { + input, + has_windows: state.has_windows, + exprs: phys_exprs, + input_schema, + output_schema, + options, + allow_vertical_parallelism, + })) + }, + MapFunction { + input, function, .. + } => { + let input = recurse!(input, state)?; + Ok(Box::new(executors::UdfExec { input, function })) + }, + ExtContext { + input, contexts, .. + } => { + let input = recurse!(input, state)?; + let contexts = contexts + .into_iter() + .map(|node| recurse!(node, state)) + .collect::>()?; + Ok(Box::new(executors::ExternalContext { input, contexts })) + }, + SimpleProjection { input, columns } => { + let input = recurse!(input, state)?; + let exec = executors::ProjectionSimple { input, columns }; + Ok(Box::new(exec)) + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + key, + } => { + let (input_left, input_right) = state.with_new_branch(|new_state| { + ( + recurse!(input_left, new_state), + recurse!(input_right, new_state), + ) + }); + let input_left = input_left?; + let input_right = input_right?; + + let exec = executors::MergeSorted { + input_left, + input_right, + key, + }; + Ok(Box::new(exec)) + }, + Invalid => unreachable!(), + } +} + +pub fn create_scan_predicate( + predicate: &ExprIR, + expr_arena: &mut Arena, + schema: &Arc, + state: &mut ExpressionConversionState, + create_skip_batch_predicate: bool, + create_column_predicates: bool, +) -> PolarsResult { + let phys_predicate = + create_physical_expr(predicate, Context::Default, expr_arena, schema, state)?; + let live_columns = Arc::new(PlIndexSet::from_iter(aexpr_to_leaf_names_iter( + predicate.node(), + expr_arena, + ))); + + let mut skip_batch_predicate = None; + + if create_skip_batch_predicate { + if let Some(node) = aexpr_to_skip_batch_predicate(predicate.node(), expr_arena, schema) { + let expr = ExprIR::new(node, predicate.output_name_inner().clone()); + + if std::env::var("POLARS_OUTPUT_SKIP_BATCH_PRED").as_deref() == Ok("1") { + eprintln!("predicate: {}", predicate.display(expr_arena)); + eprintln!("skip_batch_predicate: {}", expr.display(expr_arena)); + } + + let mut skip_batch_schema = Schema::with_capacity(1 + live_columns.len()); + + skip_batch_schema.insert(PlSmallStr::from_static("len"), IDX_DTYPE); + for (col, dtype) in schema.iter() { + if !live_columns.contains(col) { + continue; + } + + skip_batch_schema.insert(format_pl_smallstr!("{col}_min"), dtype.clone()); + skip_batch_schema.insert(format_pl_smallstr!("{col}_max"), dtype.clone()); + skip_batch_schema.insert(format_pl_smallstr!("{col}_nc"), IDX_DTYPE); + } + + skip_batch_predicate = Some(create_physical_expr( + &expr, + Context::Default, + expr_arena, + &Arc::new(skip_batch_schema), + state, + )?); + } + } + + let column_predicates = if create_column_predicates { + let column_predicates = aexpr_to_column_predicates(predicate.node(), expr_arena, schema); + if std::env::var("POLARS_OUTPUT_COLUMN_PREDS").as_deref() == Ok("1") { + eprintln!("column_predicates: {{"); + eprintln!(" ["); + for (pred, spec) in column_predicates.predicates.values() { + eprintln!( + " {} ({spec:?}),", + ExprIRDisplay::display_node(*pred, expr_arena) + ); + } + eprintln!(" ],"); + eprintln!( + " is_sumwise_complete: {}", + column_predicates.is_sumwise_complete + ); + eprintln!("}}"); + } + PhysicalColumnPredicates { + predicates: column_predicates + .predicates + .into_iter() + .map(|(n, (p, s))| { + PolarsResult::Ok(( + n, + ( + create_physical_expr( + &ExprIR::new(p, OutputName::Alias(PlSmallStr::EMPTY)), + Context::Default, + expr_arena, + schema, + state, + )?, + s, + ), + )) + }) + .collect::>>()?, + is_sumwise_complete: column_predicates.is_sumwise_complete, + } + } else { + PhysicalColumnPredicates { + predicates: PlHashMap::default(), + is_sumwise_complete: false, + } + }; + + PolarsResult::Ok(ScanPredicate { + predicate: phys_predicate, + live_columns, + skip_batch_predicate, + column_predicates, + }) +} diff --git a/crates/polars-mem-engine/src/planner/mod.rs b/crates/polars-mem-engine/src/planner/mod.rs new file mode 100644 index 000000000000..90364b90f852 --- /dev/null +++ b/crates/polars-mem-engine/src/planner/mod.rs @@ -0,0 +1,4 @@ +mod lp; +pub use lp::*; +pub(crate) use polars_expr::planner::*; +use polars_plan::prelude::*; diff --git a/crates/polars-mem-engine/src/predicate.rs b/crates/polars-mem-engine/src/predicate.rs new file mode 100644 index 000000000000..36aa0edda8dd --- /dev/null +++ b/crates/polars-mem-engine/src/predicate.rs @@ -0,0 +1,223 @@ +use core::fmt; +use std::sync::Arc; + +use arrow::bitmap::Bitmap; +use polars_core::frame::DataFrame; +use polars_core::prelude::{AnyValue, Column, Field, GroupPositions, PlHashMap, PlIndexSet}; +use polars_core::scalar::Scalar; +use polars_core::schema::{Schema, SchemaRef}; +use polars_error::PolarsResult; +use polars_expr::prelude::{AggregationContext, PhysicalExpr, phys_expr_to_io_expr}; +use polars_expr::state::ExecutionState; +use polars_io::predicates::{ + ColumnPredicates, ScanIOPredicate, SkipBatchPredicate, SpecializedColumnPredicateExpr, +}; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::{IdxSize, format_pl_smallstr}; + +/// All the expressions and metadata used to filter out rows using predicates. +#[derive(Clone)] +pub struct ScanPredicate { + pub predicate: Arc, + + /// Column names that are used in the predicate. + pub live_columns: Arc>, + + /// A predicate expression used to skip record batches based on its statistics. + /// + /// This expression will be given a batch size along with a `min`, `max` and `null count` for + /// each live column (set to `null` when it is not known) and the expression evaluates to + /// `true` if the whole batch can for sure be skipped. This may be conservative and evaluate to + /// `false` even when the batch could theoretically be skipped. + pub skip_batch_predicate: Option>, + + /// Partial predicates for each column for filter when loading columnar formats. + pub column_predicates: PhysicalColumnPredicates, +} + +impl fmt::Debug for ScanPredicate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("scan_predicate") + } +} + +#[derive(Clone)] +pub struct PhysicalColumnPredicates { + pub predicates: PlHashMap< + PlSmallStr, + ( + Arc, + Option, + ), + >, + pub is_sumwise_complete: bool, +} + +/// Helper to implement [`SkipBatchPredicate`]. +struct SkipBatchPredicateHelper { + skip_batch_predicate: Arc, + schema: SchemaRef, +} + +/// Helper for the [`PhysicalExpr`] trait to include constant columns. +pub struct PhysicalExprWithConstCols { + constants: Vec<(PlSmallStr, Scalar)>, + child: Arc, +} + +impl PhysicalExpr for PhysicalExprWithConstCols { + fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { + let mut df = df.clone(); + for (name, scalar) in &self.constants { + df.with_column(Column::new_scalar( + name.clone(), + scalar.clone(), + df.height(), + ))?; + } + + self.child.evaluate(&df, state) + } + + fn evaluate_on_groups<'a>( + &self, + df: &DataFrame, + groups: &'a GroupPositions, + state: &ExecutionState, + ) -> PolarsResult> { + let mut df = df.clone(); + for (name, scalar) in &self.constants { + df.with_column(Column::new_scalar( + name.clone(), + scalar.clone(), + df.height(), + ))?; + } + + self.child.evaluate_on_groups(&df, groups, state) + } + + fn to_field(&self, input_schema: &Schema) -> PolarsResult { + self.child.to_field(input_schema) + } + fn is_scalar(&self) -> bool { + self.child.is_scalar() + } +} + +impl ScanPredicate { + pub fn with_constant_columns( + &self, + constant_columns: impl IntoIterator, + ) -> Self { + let constant_columns = constant_columns.into_iter(); + + let mut live_columns = self.live_columns.as_ref().clone(); + let mut skip_batch_predicate_constants = Vec::with_capacity( + self.skip_batch_predicate + .is_some() + .then_some(1 + constant_columns.size_hint().0 * 3) + .unwrap_or_default(), + ); + + let predicate_constants = constant_columns + .filter_map(|(name, scalar): (PlSmallStr, Scalar)| { + if !live_columns.swap_remove(&name) { + return None; + } + + if self.skip_batch_predicate.is_some() { + let mut null_count: Scalar = (0 as IdxSize).into(); + + // If the constant value is Null, we don't know how many nulls there are + // because the length of the batch may vary. + if scalar.is_null() { + null_count.update(AnyValue::Null); + } + + skip_batch_predicate_constants.extend([ + (format_pl_smallstr!("{name}_min"), scalar.clone()), + (format_pl_smallstr!("{name}_max"), scalar.clone()), + (format_pl_smallstr!("{name}_nc"), null_count), + ]); + } + + Some((name, scalar)) + }) + .collect(); + + let predicate = Arc::new(PhysicalExprWithConstCols { + constants: predicate_constants, + child: self.predicate.clone(), + }); + let skip_batch_predicate = self.skip_batch_predicate.as_ref().map(|skp| { + Arc::new(PhysicalExprWithConstCols { + constants: skip_batch_predicate_constants, + child: skp.clone(), + }) as _ + }); + + Self { + predicate, + live_columns: Arc::new(live_columns), + skip_batch_predicate, + column_predicates: self.column_predicates.clone(), // Q? Maybe this should cull + // predicates. + } + } + + /// Create a predicate to skip batches using statistics. + pub(crate) fn to_dyn_skip_batch_predicate( + &self, + schema: SchemaRef, + ) -> Option> { + let skip_batch_predicate = self.skip_batch_predicate.as_ref()?.clone(); + Some(Arc::new(SkipBatchPredicateHelper { + skip_batch_predicate, + schema, + })) + } + + pub fn to_io( + &self, + skip_batch_predicate: Option<&Arc>, + schema: SchemaRef, + ) -> ScanIOPredicate { + ScanIOPredicate { + predicate: phys_expr_to_io_expr(self.predicate.clone()), + live_columns: self.live_columns.clone(), + skip_batch_predicate: skip_batch_predicate + .cloned() + .or_else(|| self.to_dyn_skip_batch_predicate(schema)), + column_predicates: Arc::new(ColumnPredicates { + predicates: self + .column_predicates + .predicates + .iter() + .map(|(n, (p, s))| (n.clone(), (phys_expr_to_io_expr(p.clone()), s.clone()))) + .collect(), + is_sumwise_complete: self.column_predicates.is_sumwise_complete, + }), + } + } +} + +impl SkipBatchPredicate for SkipBatchPredicateHelper { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult { + let array = self + .skip_batch_predicate + .evaluate(df, &Default::default())?; + let array = array.bool()?; + let array = array.downcast_as_array(); + + if let Some(validity) = array.validity() { + Ok(array.values() & validity) + } else { + Ok(array.values().clone()) + } + } +} diff --git a/crates/polars-mem-engine/src/prelude.rs b/crates/polars-mem-engine/src/prelude.rs new file mode 100644 index 000000000000..b344f7c642fc --- /dev/null +++ b/crates/polars-mem-engine/src/prelude.rs @@ -0,0 +1,6 @@ +pub(crate) use polars_core::prelude::*; +pub(crate) use polars_expr::prelude::*; +pub(crate) use polars_ops::prelude::*; +pub(crate) use polars_plan::prelude::*; +#[cfg(feature = "dynamic_group_by")] +pub(crate) use polars_time::prelude::*; diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml new file mode 100644 index 000000000000..bf4652a385ef --- /dev/null +++ b/crates/polars-ops/Cargo.toml @@ -0,0 +1,147 @@ +[package] +name = "polars-ops" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "More operations on Polars data structures" + +[dependencies] +polars-compute = { workspace = true } +polars-core = { workspace = true, features = ["algorithm_group_by", "zip_with"] } +polars-error = { workspace = true } +polars-json = { workspace = true, optional = true } +polars-schema = { workspace = true } +polars-utils = { workspace = true } + +aho-corasick = { workspace = true, optional = true } +argminmax = { version = "0.6.2", default-features = false, features = ["float"] } +arrow = { workspace = true } +base64 = { workspace = true, optional = true } +bytemuck = { workspace = true } +chrono = { workspace = true, optional = true } +chrono-tz = { workspace = true, optional = true } +either = { workspace = true } +hashbrown = { workspace = true } +hex = { workspace = true, optional = true } +indexmap = { workspace = true } +libm = { workspace = true } +memchr = { workspace = true } +num-traits = { workspace = true } +rand = { workspace = true, optional = true, features = ["small_rng", "std"] } +rand_distr = { workspace = true, optional = true } +rayon = { workspace = true } +regex = { workspace = true } +regex-syntax = { workspace = true } +serde = { workspace = true, optional = true } +serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } +unicode-normalization = { workspace = true, optional = true } +unicode-reverse = { workspace = true, optional = true } + +[dependencies.jsonpath_lib] +package = "jsonpath_lib_polars_vendor" +optional = true +version = "0.0.1" + +[dev-dependencies] +rand = { workspace = true, features = ["small_rng"] } + +[build-dependencies] +version_check = { workspace = true } + +[features] +simd = ["argminmax/nightly_simd"] +nightly = ["polars-utils/nightly"] +dtype-categorical = ["polars-core/dtype-categorical"] +dtype-date = ["polars-core/dtype-date", "polars-core/temporal"] +dtype-datetime = ["polars-core/dtype-datetime", "polars-core/temporal"] +dtype-time = ["polars-core/dtype-time", "polars-core/temporal"] +dtype-duration = ["polars-core/dtype-duration", "polars-core/temporal"] +dtype-struct = ["polars-core/dtype-struct", "polars-core/temporal"] +dtype-u8 = ["polars-core/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16"] +dtype-i8 = ["polars-core/dtype-i8"] +dtype-i128 = ["polars-core/dtype-i128"] +dtype-i16 = ["polars-core/dtype-i16"] +dtype-array = ["polars-core/dtype-array"] +dtype-decimal = ["polars-core/dtype-decimal", "dtype-i128"] +object = ["polars-core/object"] +propagate_nans = [] +performant = ["polars-core/performant", "fused"] +big_idx = ["polars-core/bigidx"] +round_series = [] +is_first_distinct = [] +is_last_distinct = [] +is_unique = [] +unique_counts = [] +is_between = [] +approx_unique = [] +business = ["dtype-date", "chrono"] +fused = [] +cutqcut = ["dtype-categorical", "dtype-struct"] +rle = ["dtype-struct"] +timezones = ["chrono", "chrono-tz", "polars-core/temporal", "polars-core/timezones", "polars-core/dtype-datetime"] +random = ["rand", "rand_distr"] +rank = ["rand"] +find_many = ["aho-corasick"] +serde = ["dep:serde", "polars-core/serde", "polars-utils/serde", "polars-schema/serde"] + +# extra utilities for BinaryChunked +binary_encoding = ["base64", "hex"] +string_encoding = ["base64", "hex"] + +# ops +bitwise = ["polars-core/bitwise"] +to_dummies = [] +interpolate = [] +interpolate_by = [] +list_to_struct = ["polars-core/dtype-struct"] +array_to_struct = ["polars-core/dtype-array", "polars-core/dtype-struct"] +list_count = [] +diff = [] +pct_change = ["diff"] +strings = ["polars-core/strings"] +string_pad = ["polars-core/strings"] +string_normalize = ["polars-core/strings", "unicode-normalization"] +string_reverse = ["polars-core/strings", "unicode-reverse"] +string_to_integer = ["polars-core/strings"] +extract_jsonpath = ["serde_json", "jsonpath_lib", "polars-json"] +log = [] +hash = [] +reinterpret = ["polars-core/reinterpret"] +rolling_window = ["polars-core/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by"] +moment = [] +mode = [] +index_of = [] +search_sorted = [] +merge_sorted = [] +top_k = [] +pivot = ["polars-core/reinterpret", "polars-core/dtype-struct"] +cross_join = [] +chunked_ids = [] +asof_join = [] +iejoin = [] +semi_anti_join = [] +array_any_all = ["dtype-array"] +array_count = ["dtype-array"] +list_gather = [] +list_sets = [] +list_any_all = [] +list_drop_nulls = [] +list_sample = ["polars-core/random"] +extract_groups = ["dtype-struct", "polars-core/regex"] +is_in = ["polars-core/reinterpret"] +hist = ["dtype-categorical", "dtype-struct"] +repeat_by = [] +peaks = [] +cum_agg = [] +ewma = [] +ewma_by = [] +abs = [] +cov = [] +gather = [] +replace = ["is_in"] diff --git a/crates/polars-ops/LICENSE b/crates/polars-ops/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-ops/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-ops/README.md b/crates/polars-ops/README.md new file mode 100644 index 000000000000..6801a4d5736e --- /dev/null +++ b/crates/polars-ops/README.md @@ -0,0 +1,7 @@ +# polars-ops + +`polars-ops` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, +providing extended operations on Polars data structures. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-ops/build.rs b/crates/polars-ops/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-ops/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-ops/src/chunked_array/array/any_all.rs b/crates/polars-ops/src/chunked_array/array/any_all.rs new file mode 100644 index 000000000000..904231903c57 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/any_all.rs @@ -0,0 +1,53 @@ +use arrow::array::{BooleanArray, FixedSizeListArray}; +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::CustomIterTools; + +use super::*; + +fn array_all_any(arr: &FixedSizeListArray, op: F, is_all: bool) -> PolarsResult +where + F: Fn(&BooleanArray) -> bool, +{ + let values = arr.values(); + + polars_ensure!(values.dtype() == &ArrowDataType::Boolean, ComputeError: "expected boolean elements in array"); + + let values = values.as_any().downcast_ref::().unwrap(); + let validity = arr.validity().cloned(); + + // Fast path where all values set (all is free). + if is_all { + let all_set = arrow::compute::boolean::all(values); + if all_set { + let bits = Bitmap::new_with_value(true, arr.len()); + return Ok(BooleanArray::from_data_default(bits, None).with_validity(validity)); + } + } + + let len = arr.size(); + let iter = (0..values.len()).step_by(len).map(|start| { + // SAFETY: start + len is in bound guarded by invariant of FixedSizeListArray + let val = unsafe { values.clone().sliced_unchecked(start, len) }; + op(&val) + }); + + Ok(BooleanArray::from_trusted_len_values_iter( + // SAFETY: we evaluate for every sub-array, the length is equals to arr.len(). + unsafe { iter.trust_my_length(arr.len()) }, + ) + .with_validity(validity)) +} + +pub(super) fn array_all(ca: &ArrayChunked) -> PolarsResult { + let chunks = ca + .downcast_iter() + .map(|arr| array_all_any(arr, arrow::compute::boolean::all, true)); + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) +} + +pub(super) fn array_any(ca: &ArrayChunked) -> PolarsResult { + let chunks = ca + .downcast_iter() + .map(|arr| array_all_any(arr, arrow::compute::boolean::any, false)); + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) +} diff --git a/crates/polars-ops/src/chunked_array/array/count.rs b/crates/polars-ops/src/chunked_array/array/count.rs new file mode 100644 index 000000000000..571061123658 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/count.rs @@ -0,0 +1,46 @@ +use arrow::array::{Array, BooleanArray}; +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::count_zeros; +use arrow::legacy::utils::CustomIterTools; +use polars_core::prelude::arity::unary_mut_with_options; + +use super::*; + +#[cfg(feature = "array_count")] +pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult { + let value = Series::new(PlSmallStr::EMPTY, [value]); + + let ca = ca.apply_to_inner(&|s| { + ChunkCompareEq::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) + })?; + let out = count_boolean_bits(&ca); + Ok(out.into_series()) +} + +pub(super) fn count_boolean_bits(ca: &ArrayChunked) -> IdxCa { + unary_mut_with_options(ca, |arr| { + let inner_arr = arr.values(); + let mask = inner_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(mask.null_count(), 0); + let out = count_bits_set(mask.values(), arr.len(), arr.size()); + IdxArr::from_data_default(out.into(), arr.validity().cloned()) + }) +} + +fn count_bits_set(values: &Bitmap, len: usize, width: usize) -> Vec { + // Fast path where all bits are either set or unset. + if values.unset_bits() == values.len() { + return vec![0 as IdxSize; len]; + } else if values.unset_bits() == 0 { + return vec![width as IdxSize; len]; + } + + let (bits, bitmap_offset, _) = values.as_slice(); + + (0..len) + .map(|i| { + let set_ones = width - count_zeros(bits, bitmap_offset + i * width, width); + set_ones as IdxSize + }) + .collect_trusted() +} diff --git a/crates/polars-ops/src/chunked_array/array/dispersion.rs b/crates/polars-ops/src/chunked_array/array/dispersion.rs new file mode 100644 index 000000000000..17924d7c38bb --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/dispersion.rs @@ -0,0 +1,95 @@ +use super::*; + +pub(super) fn median_with_nulls(ca: &ArrayChunked) -> PolarsResult { + let mut out = match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(*tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median())) + .with_name(ca.name().clone()); + out.into_series() + }, + }; + out.rename(ca.name().clone()); + Ok(out) +} + +pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { + let mut out = match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(*tu).into_series() + }, + _ => { + let out: Float64Chunked = { + ca.amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().std(ddof))) + .collect() + }; + out.into_series() + }, + }; + out.rename(ca.name().clone()); + Ok(out) +} + +pub(super) fn var_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { + let mut out = match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Milliseconds) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Microseconds | TimeUnit::Nanoseconds) => { + let out: Int64Chunked = ca + .cast(&DataType::Array( + Box::new(DataType::Duration(TimeUnit::Milliseconds)), + ca.width(), + )) + .unwrap() + .array() + .unwrap() + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof))) + .with_name(ca.name().clone()); + out.into_series() + }, + }; + out.rename(ca.name().clone()); + Ok(out) +} diff --git a/crates/polars-ops/src/chunked_array/array/get.rs b/crates/polars-ops/src/chunked_array/array/get.rs new file mode 100644 index 000000000000..f7a534055708 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/get.rs @@ -0,0 +1,73 @@ +use arrow::array::Array; +use polars_compute::gather::sublist::fixed_size_list::{ + sub_fixed_size_list_get, sub_fixed_size_list_get_literal, +}; +use polars_core::utils::align_chunks_binary; + +use super::*; + +fn array_get_literal(ca: &ArrayChunked, idx: i64, null_on_oob: bool) -> PolarsResult { + let chunks = ca + .downcast_iter() + .map(|arr| sub_fixed_size_list_get_literal(arr, idx, null_on_oob)) + .collect::>>()?; + Series::try_from((ca.name().clone(), chunks)) + .unwrap() + .cast(ca.inner_dtype()) +} + +/// Get the value by literal index in the array. +/// So index `0` would return the first item of every sub-array +/// and index `-1` would return the last item of every sub-array +/// if an index is out of bounds, it will return a `None`. +pub fn array_get( + ca: &ArrayChunked, + index: &Int64Chunked, + null_on_oob: bool, +) -> PolarsResult { + match index.len() { + 1 => { + let index = index.get(0); + if let Some(index) = index { + array_get_literal(ca, index, null_on_oob) + } else { + Ok(Series::full_null( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + )) + } + }, + len if len == ca.len() => { + let out = binary_to_series_arr_get(ca, index, null_on_oob, |arr, idx, nob| { + sub_fixed_size_list_get(arr, idx, nob) + }); + out?.cast(ca.inner_dtype()) + }, + len => polars_bail!( + ComputeError: + "`arr.get` expression got an index array of length {} while the array has {} elements", + len, ca.len() + ), + } +} + +pub fn binary_to_series_arr_get( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + null_on_oob: bool, + mut op: F, +) -> PolarsResult +where + T: PolarsDataType, + U: PolarsDataType, + F: FnMut(&T::Array, &U::Array, bool) -> PolarsResult>, +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr, null_on_oob)) + .collect::>>()?; + Series::try_from((lhs.name().clone(), chunks)) +} diff --git a/crates/polars-ops/src/chunked_array/array/join.rs b/crates/polars-ops/src/chunked_array/array/join.rs new file mode 100644 index 000000000000..f450dea0119f --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/join.rs @@ -0,0 +1,104 @@ +use std::fmt::Write; + +use super::*; + +fn join_literal( + ca: &ArrayChunked, + separator: &str, + ignore_nulls: bool, +) -> PolarsResult { + let DataType::Array(_, _) = ca.dtype() else { + unreachable!() + }; + + let mut buf = String::with_capacity(128); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + + ca.for_each_amortized(|opt_s| { + let opt_val = opt_s.and_then(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().str().unwrap(); + + if ca.null_count() != 0 && !ignore_nulls { + return None; + } + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + } + + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + Some(&buf[..buf.len().saturating_sub(separator.len())]) + }); + builder.append_option(opt_val) + }); + Ok(builder.finish()) +} + +fn join_many( + ca: &ArrayChunked, + separator: &StringChunked, + ignore_nulls: bool, +) -> PolarsResult { + polars_ensure!( + ca.len() == separator.len(), + length_mismatch = "arr.join", + ca.len(), + separator.len() + ); + + let mut buf = String::new(); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + + { ca.amortized_iter() } + .zip(separator) + .for_each(|(opt_s, opt_sep)| match opt_sep { + Some(separator) => { + let opt_val = opt_s.and_then(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().str().unwrap(); + + if ca.null_count() != 0 && !ignore_nulls { + return None; + } + + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + } + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + Some(&buf[..buf.len().saturating_sub(separator.len())]) + }); + builder.append_option(opt_val) + }, + _ => builder.append_null(), + }); + Ok(builder.finish()) +} + +/// In case the inner dtype [`DataType::String`], the individual items will be joined into a +/// single string separated by `separator`. +pub fn array_join( + ca: &ArrayChunked, + separator: &StringChunked, + ignore_nulls: bool, +) -> PolarsResult { + match ca.inner_dtype() { + DataType::String => match separator.len() { + 1 => match separator.get(0) { + Some(separator) => join_literal(ca, separator, ignore_nulls), + _ => Ok(StringChunked::full_null(ca.name().clone(), ca.len())), + }, + _ => join_many(ca, separator, ignore_nulls), + }, + dt => polars_bail!(op = "`array.join`", got = dt, expected = "String"), + } +} diff --git a/crates/polars-ops/src/chunked_array/array/min_max.rs b/crates/polars-ops/src/chunked_array/array/min_max.rs new file mode 100644 index 000000000000..a82de2436291 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/min_max.rs @@ -0,0 +1,86 @@ +use arrow::array::{Array, PrimitiveArray}; +use polars_compute::min_max::MinMaxKernel; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +fn array_agg( + values: &PrimitiveArray, + width: usize, + slice_agg: F1, + arr_agg: F2, +) -> PrimitiveArray +where + T: NumericNative, + S: NumericNative, + F1: Fn(&[T]) -> Option, + F2: Fn(&PrimitiveArray) -> Option, +{ + if values.null_count() == 0 { + let values = values.values().as_slice(); + values + .chunks_exact(width) + .map(|sl| slice_agg(sl).unwrap()) + .collect_arr() + } else { + (0..values.len()) + .step_by(width) + .map(|start| { + // SAFETY: This value array from a FixedSizeListArray, + // we can ensure that `start + width` will not out out range + let sliced = unsafe { values.clone().sliced_unchecked(start, width) }; + arr_agg(&sliced) + }) + .collect_arr() + } +} + +pub(super) enum AggType { + Min, + Max, +} + +fn agg_min(values: &PrimitiveArray, width: usize) -> PrimitiveArray +where + T: NumericNative, + PrimitiveArray: for<'a> MinMaxKernel = T>, + [T]: for<'a> MinMaxKernel = T>, +{ + array_agg( + values, + width, + MinMaxKernel::min_ignore_nan_kernel, + MinMaxKernel::min_ignore_nan_kernel, + ) +} + +fn agg_max(values: &PrimitiveArray, width: usize) -> PrimitiveArray +where + T: NumericNative, + PrimitiveArray: for<'a> MinMaxKernel = T>, + [T]: for<'a> MinMaxKernel = T>, +{ + array_agg( + values, + width, + MinMaxKernel::max_ignore_nan_kernel, + MinMaxKernel::max_ignore_nan_kernel, + ) +} + +pub(super) fn array_dispatch( + name: PlSmallStr, + values: &Series, + width: usize, + agg_type: AggType, +) -> Series { + let chunks: Vec = with_match_physical_numeric_polars_type!(values.dtype(), |$T| { + let ca: &ChunkedArray<$T> = values.as_ref().as_ref().as_ref(); + ca.downcast_iter().map(|arr| { + match agg_type { + AggType::Min => Box::new(agg_min(arr, width)) as ArrayRef, + AggType::Max => Box::new(agg_max(arr, width)) as ArrayRef, + } + }).collect() + }); + Series::try_from((name, chunks)).unwrap() +} diff --git a/crates/polars-ops/src/chunked_array/array/mod.rs b/crates/polars-ops/src/chunked_array/array/mod.rs new file mode 100644 index 000000000000..efe4dcbf339c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/mod.rs @@ -0,0 +1,26 @@ +#[cfg(feature = "array_any_all")] +mod any_all; +mod count; +mod dispersion; +mod get; +mod join; +mod min_max; +mod namespace; +mod sum_mean; +#[cfg(feature = "array_to_struct")] +mod to_struct; + +pub use namespace::ArrayNameSpace; +use polars_core::prelude::*; +#[cfg(feature = "array_to_struct")] +pub use to_struct::*; + +pub trait AsArray { + fn as_array(&self) -> &ArrayChunked; +} + +impl AsArray for ArrayChunked { + fn as_array(&self) -> &ArrayChunked { + self + } +} diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs new file mode 100644 index 000000000000..5bcd22af084e --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -0,0 +1,198 @@ +use super::min_max::AggType; +use super::*; +#[cfg(feature = "array_count")] +use crate::chunked_array::array::count::array_count_matches; +use crate::chunked_array::array::count::count_boolean_bits; +use crate::chunked_array::array::sum_mean::sum_with_nulls; +#[cfg(feature = "array_any_all")] +use crate::prelude::array::any_all::{array_all, array_any}; +use crate::prelude::array::get::array_get; +use crate::prelude::array::join::array_join; +use crate::prelude::array::sum_mean::sum_array_numerical; +use crate::series::ArgAgg; + +pub fn has_inner_nulls(ca: &ArrayChunked) -> bool { + for arr in ca.downcast_iter() { + if arr.values().null_count() > 0 { + return true; + } + } + false +} + +fn get_agg(ca: &ArrayChunked, agg_type: AggType) -> Series { + let values = ca.get_inner(); + let width = ca.width(); + min_max::array_dispatch(ca.name().clone(), &values, width, agg_type) +} + +pub trait ArrayNameSpace: AsArray { + fn array_max(&self) -> Series { + let ca = self.as_array(); + get_agg(ca, AggType::Max) + } + + fn array_min(&self) -> Series { + let ca = self.as_array(); + get_agg(ca, AggType::Min) + } + + fn array_sum(&self) -> PolarsResult { + let ca = self.as_array(); + + if has_inner_nulls(ca) { + return sum_with_nulls(ca, ca.inner_dtype()); + }; + + match ca.inner_dtype() { + DataType::Boolean => Ok(count_boolean_bits(ca).into_series()), + dt if dt.is_primitive_numeric() => Ok(sum_array_numerical(ca, dt)), + dt => sum_with_nulls(ca, dt), + } + } + + fn array_median(&self) -> PolarsResult { + let ca = self.as_array(); + dispersion::median_with_nulls(ca) + } + + fn array_std(&self, ddof: u8) -> PolarsResult { + let ca = self.as_array(); + dispersion::std_with_nulls(ca, ddof) + } + + fn array_var(&self, ddof: u8) -> PolarsResult { + let ca = self.as_array(); + dispersion::var_with_nulls(ca, ddof) + } + + fn array_unique(&self) -> PolarsResult { + let ca = self.as_array(); + ca.try_apply_amortized_to_list(|s| s.as_ref().unique()) + } + + fn array_unique_stable(&self) -> PolarsResult { + let ca = self.as_array(); + ca.try_apply_amortized_to_list(|s| s.as_ref().unique_stable()) + } + + fn array_n_unique(&self) -> PolarsResult { + let ca = self.as_array(); + ca.try_apply_amortized_generic(|opt_s| { + let opt_v = opt_s.map(|s| s.as_ref().n_unique()).transpose()?; + Ok(opt_v.map(|idx| idx as IdxSize)) + }) + } + + #[cfg(feature = "array_any_all")] + fn array_any(&self) -> PolarsResult { + let ca = self.as_array(); + array_any(ca) + } + + #[cfg(feature = "array_any_all")] + fn array_all(&self) -> PolarsResult { + let ca = self.as_array(); + array_all(ca) + } + + fn array_sort(&self, options: SortOptions) -> PolarsResult { + let ca = self.as_array(); + // SAFETY: Sort only changes the order of the elements in each subarray. + unsafe { ca.try_apply_amortized_same_type(|s| s.as_ref().sort_with(options)) } + } + + fn array_reverse(&self) -> ArrayChunked { + let ca = self.as_array(); + // SAFETY: Reverse only changes the order of the elements in each subarray + unsafe { ca.apply_amortized_same_type(|s| s.as_ref().reverse()) } + } + + fn array_arg_min(&self) -> IdxCa { + let ca = self.as_array(); + ca.apply_amortized_generic(|opt_s| { + opt_s.and_then(|s| s.as_ref().arg_min().map(|idx| idx as IdxSize)) + }) + } + + fn array_arg_max(&self) -> IdxCa { + let ca = self.as_array(); + ca.apply_amortized_generic(|opt_s| { + opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize)) + }) + } + + fn array_get(&self, index: &Int64Chunked, null_on_oob: bool) -> PolarsResult { + let ca = self.as_array(); + array_get(ca, index, null_on_oob) + } + + fn array_join(&self, separator: &StringChunked, ignore_nulls: bool) -> PolarsResult { + let ca = self.as_array(); + array_join(ca, separator, ignore_nulls).map(|ok| ok.into_series()) + } + + #[cfg(feature = "array_count")] + fn array_count_matches(&self, element: AnyValue) -> PolarsResult { + let ca = self.as_array(); + array_count_matches(ca, element) + } + + fn array_shift(&self, n: &Series) -> PolarsResult { + let ca = self.as_array(); + let n_s = n.cast(&DataType::Int64)?; + let n = n_s.i64()?; + let out = match (ca.len(), n.len()) { + (a, b) if a == b => { + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { + ca.zip_and_apply_amortized_same_type(n, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(n)) => Some(s.as_ref().shift(n)), + _ => None, + } + }) + } + }, + (_, 1) => { + if let Some(n) = n.get(0) { + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { ca.apply_amortized_same_type(|s| s.as_ref().shift(n)) } + } else { + ArrayChunked::full_null_with_dtype( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + ca.width(), + ) + } + }, + (1, _) => { + if ca.get(0).is_some() { + // Optimize: This does not need to broadcast first. + let ca = ca.new_from_index(0, n.len()); + // SAFETY: Shift does not change the dtype and number of elements of sub-array. + unsafe { + ca.zip_and_apply_amortized_same_type(n, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(n)) => Some(s.as_ref().shift(n)), + _ => None, + } + }) + } + } else { + ArrayChunked::full_null_with_dtype( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + ca.width(), + ) + } + }, + _ => polars_bail!(length_mismatch = "arr.shift", ca.len(), n.len()), + }; + Ok(out.into_series()) + } +} + +impl ArrayNameSpace for ArrayChunked {} diff --git a/crates/polars-ops/src/chunked_array/array/sum_mean.rs b/crates/polars-ops/src/chunked_array/array/sum_mean.rs new file mode 100644 index 000000000000..1dae9aa83d9c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/sum_mean.rs @@ -0,0 +1,129 @@ +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::NativeType; +use num_traits::{NumCast, ToPrimitive}; +use polars_core::prelude::*; + +use crate::chunked_array::sum::sum_slice; + +fn dispatch_sum(arr: &dyn Array, width: usize, validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + ToPrimitive, + S: NativeType + NumCast + std::iter::Sum, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + + let summed: Vec<_> = (0..values.len()) + .step_by(width) + .map(|start| { + let slice = unsafe { values.get_unchecked(start..start + width) }; + sum_slice::(slice) + }) + .collect_trusted(); + + Box::new(PrimitiveArray::from_data_default( + summed.into(), + validity.cloned(), + )) as ArrayRef +} + +pub(super) fn sum_array_numerical(ca: &ArrayChunked, inner_type: &DataType) -> Series { + let width = ca.width(); + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_sum::(values, width, arr.validity()), + Int16 => dispatch_sum::(values, width, arr.validity()), + Int32 => dispatch_sum::(values, width, arr.validity()), + Int64 => dispatch_sum::(values, width, arr.validity()), + Int128 => dispatch_sum::(values, width, arr.validity()), + UInt8 => dispatch_sum::(values, width, arr.validity()), + UInt16 => dispatch_sum::(values, width, arr.validity()), + UInt32 => dispatch_sum::(values, width, arr.validity()), + UInt64 => dispatch_sum::(values, width, arr.validity()), + Float32 => dispatch_sum::(values, width, arr.validity()), + Float64 => dispatch_sum::(values, width, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name().clone(), chunks)).unwrap() +} + +pub(super) fn sum_with_nulls(ca: &ArrayChunked, inner_dtype: &DataType) -> PolarsResult { + use DataType::*; + // TODO: add fast path for smaller ints? + let mut out = { + match inner_dtype { + Boolean => { + let out: IdxCa = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + UInt32 => { + let out: UInt32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + UInt64 => { + let out: UInt64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + Int32 => { + let out: Int32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + Int64 => { + let out: Int64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + #[cfg(feature = "dtype-i128")] + Int128 => { + let out: Int128Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + Float32 => { + let out: Float32Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + Float64 => { + let out: Float64Chunked = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().sum().ok())) + .collect(); + out.into_series() + }, + _ => { + polars_bail!(ComputeError: "summing array with dtype: {} not yet supported", ca.dtype()) + }, + } + }; + out.rename(ca.name().clone()); + Ok(out) +} diff --git a/crates/polars-ops/src/chunked_array/array/to_struct.rs b/crates/polars-ops/src/chunked_array/array/to_struct.rs new file mode 100644 index 000000000000..d316e82888c0 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/to_struct.rs @@ -0,0 +1,46 @@ +use polars_core::POOL; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; +use rayon::prelude::*; + +use super::*; + +pub type ArrToStructNameGenerator = Arc PlSmallStr + Send + Sync>; + +pub fn arr_default_struct_name_gen(idx: usize) -> PlSmallStr { + format_pl_smallstr!("field_{idx}") +} + +pub trait ToStruct: AsArray { + fn to_struct( + &self, + name_generator: Option, + ) -> PolarsResult { + let ca = self.as_array(); + let n_fields = ca.width(); + + let name_generator = name_generator + .as_deref() + .unwrap_or(&arr_default_struct_name_gen); + + let fields = POOL.install(|| { + (0..n_fields) + .into_par_iter() + .map(|i| { + ca.array_get( + &Int64Chunked::from_slice(PlSmallStr::EMPTY, &[i as i64]), + true, + ) + .map(|mut s| { + s.rename(name_generator(i).clone()); + s + }) + }) + .collect::>>() + })?; + + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) + } +} + +impl ToStruct for ArrayChunked {} diff --git a/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs b/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs new file mode 100644 index 000000000000..d3f76f6b8263 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/binary/cast_binary_to_numerical.rs @@ -0,0 +1,80 @@ +use arrow::array::{Array, BinaryViewArray, PrimitiveArray}; +use arrow::datatypes::ArrowDataType; +use arrow::types::NativeType; +use polars_error::PolarsResult; + +/// Trait for casting bytes to a primitive type +pub trait Cast { + fn cast_le(val: &[u8]) -> Option + where + Self: Sized; + fn cast_be(val: &[u8]) -> Option + where + Self: Sized; +} +macro_rules! impl_cast { + ($primitive_type:ident) => { + impl Cast for $primitive_type { + fn cast_le(val: &[u8]) -> Option { + Some($primitive_type::from_le_bytes(val.try_into().ok()?)) + } + + fn cast_be(val: &[u8]) -> Option { + Some($primitive_type::from_be_bytes(val.try_into().ok()?)) + } + } + }; +} + +impl_cast!(i8); +impl_cast!(i16); +impl_cast!(i32); +impl_cast!(i64); +impl_cast!(i128); +impl_cast!(u8); +impl_cast!(u16); +impl_cast!(u32); +impl_cast!(u64); +impl_cast!(u128); +impl_cast!(f32); +impl_cast!(f64); + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub(super) fn cast_binview_to_primitive( + from: &BinaryViewArray, + to: &ArrowDataType, + is_little_endian: bool, +) -> PrimitiveArray +where + T: Cast + NativeType, +{ + let iter = from.iter().map(|x| { + x.and_then::(|x| { + if is_little_endian { + T::cast_le(x) + } else { + T::cast_be(x) + } + }) + }); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub(super) fn cast_binview_to_primitive_dyn( + from: &dyn Array, + to: &ArrowDataType, + is_little_endian: bool, +) -> PolarsResult> +where + T: Cast + NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + + Ok(Box::new(cast_binview_to_primitive::( + from, + to, + is_little_endian, + ))) +} diff --git a/crates/polars-ops/src/chunked_array/binary/mod.rs b/crates/polars-ops/src/chunked_array/binary/mod.rs new file mode 100644 index 000000000000..df8847a1bc0c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/binary/mod.rs @@ -0,0 +1,15 @@ +mod cast_binary_to_numerical; +mod namespace; + +pub use namespace::*; +use polars_core::prelude::*; + +pub trait AsBinary { + fn as_binary(&self) -> &BinaryChunked; +} + +impl AsBinary for BinaryChunked { + fn as_binary(&self) -> &BinaryChunked { + self + } +} diff --git a/crates/polars-ops/src/chunked_array/binary/namespace.rs b/crates/polars-ops/src/chunked_array/binary/namespace.rs new file mode 100644 index 000000000000..310d79d36e0d --- /dev/null +++ b/crates/polars-ops/src/chunked_array/binary/namespace.rs @@ -0,0 +1,188 @@ +#[cfg(feature = "binary_encoding")] +use std::borrow::Cow; + +use arrow::with_match_primitive_type; +#[cfg(feature = "binary_encoding")] +use base64::Engine as _; +#[cfg(feature = "binary_encoding")] +use base64::engine::general_purpose; +use memchr::memmem::find; +use polars_compute::size::binary_size_bytes; +use polars_core::prelude::arity::{broadcast_binary_elementwise_values, unary_elementwise_values}; + +use super::cast_binary_to_numerical::cast_binview_to_primitive_dyn; +use super::*; + +pub trait BinaryNameSpaceImpl: AsBinary { + /// Check if binary contains given literal + fn contains(&self, lit: &[u8]) -> BooleanChunked { + let ca = self.as_binary(); + let f = |s: &[u8]| find(s, lit).is_some(); + unary_elementwise_values(ca, f) + } + + fn contains_chunked(&self, lit: &BinaryChunked) -> PolarsResult { + let ca = self.as_binary(); + Ok(match lit.len() { + 1 => match lit.get(0) { + Some(lit) => ca.contains(lit), + None => BooleanChunked::full_null(ca.name().clone(), ca.len()), + }, + _ => { + polars_ensure!( + ca.len() == lit.len() || ca.len() == 1, + length_mismatch = "bin.contains", + ca.len(), + lit.len() + ); + broadcast_binary_elementwise_values(ca, lit, |src, lit| find(src, lit).is_some()) + }, + }) + } + + /// Check if strings ends with a substring + fn ends_with(&self, sub: &[u8]) -> BooleanChunked { + let ca = self.as_binary(); + let f = |s: &[u8]| s.ends_with(sub); + ca.apply_nonnull_values_generic(DataType::Boolean, f) + } + + /// Check if strings starts with a substring + fn starts_with(&self, sub: &[u8]) -> BooleanChunked { + let ca = self.as_binary(); + let f = |s: &[u8]| s.starts_with(sub); + ca.apply_nonnull_values_generic(DataType::Boolean, f) + } + + fn starts_with_chunked(&self, prefix: &BinaryChunked) -> PolarsResult { + let ca = self.as_binary(); + Ok(match prefix.len() { + 1 => match prefix.get(0) { + Some(s) => self.starts_with(s), + None => BooleanChunked::full_null(ca.name().clone(), ca.len()), + }, + _ => { + polars_ensure!( + ca.len() == prefix.len() || ca.len() == 1, + length_mismatch = "bin.starts_with", + ca.len(), + prefix.len() + ); + broadcast_binary_elementwise_values(ca, prefix, |s, sub| s.starts_with(sub)) + }, + }) + } + + fn ends_with_chunked(&self, suffix: &BinaryChunked) -> PolarsResult { + let ca = self.as_binary(); + Ok(match suffix.len() { + 1 => match suffix.get(0) { + Some(s) => self.ends_with(s), + None => BooleanChunked::full_null(ca.name().clone(), ca.len()), + }, + _ => { + polars_ensure!( + ca.len() == suffix.len() || ca.len() == 1, + length_mismatch = "bin.ends_with", + ca.len(), + suffix.len() + ); + broadcast_binary_elementwise_values(ca, suffix, |s, sub| s.ends_with(sub)) + }, + }) + } + + /// Get the size of the binary values in bytes. + fn size_bytes(&self) -> UInt32Chunked { + let ca = self.as_binary(); + ca.apply_kernel_cast(&binary_size_bytes) + } + + #[cfg(feature = "binary_encoding")] + fn hex_decode(&self, strict: bool) -> PolarsResult { + let ca = self.as_binary(); + if strict { + ca.try_apply_nonnull_values_generic(|s| { + hex::decode(s).map_err(|_| { + polars_err!( + ComputeError: + "invalid `hex` encoding found; try setting `strict=false` to ignore" + ) + }) + }) + } else { + Ok(ca.apply(|opt_s| opt_s.and_then(|s| hex::decode(s).ok().map(Cow::Owned)))) + } + } + + #[cfg(feature = "binary_encoding")] + fn hex_encode(&self) -> Series { + let ca = self.as_binary(); + unsafe { + ca.apply_values(|s| hex::encode(s).into_bytes().into()) + .cast_unchecked(&DataType::String) + .unwrap() + } + } + + #[cfg(feature = "binary_encoding")] + fn base64_decode(&self, strict: bool) -> PolarsResult { + let ca = self.as_binary(); + if strict { + ca.try_apply_nonnull_values_generic(|s| { + general_purpose::STANDARD.decode(s).map_err(|_e| { + polars_err!( + ComputeError: + "invalid `base64` encoding found; try setting `strict=false` to ignore" + ) + }) + }) + } else { + Ok(ca.apply(|opt_s| { + opt_s.and_then(|s| general_purpose::STANDARD.decode(s).ok().map(Cow::Owned)) + })) + } + } + + #[cfg(feature = "binary_encoding")] + fn base64_encode(&self) -> Series { + let ca = self.as_binary(); + unsafe { + ca.apply_values(|s| general_purpose::STANDARD.encode(s).into_bytes().into()) + .cast_unchecked(&DataType::String) + .unwrap() + } + } + + #[cfg(feature = "binary_encoding")] + #[allow(clippy::wrong_self_convention)] + fn from_buffer(&self, dtype: &DataType, is_little_endian: bool) -> PolarsResult { + let ca = self.as_binary(); + let arrow_type = dtype.to_arrow(CompatLevel::newest()); + + match arrow_type.to_physical_type() { + arrow::datatypes::PhysicalType::Primitive(ty) => { + with_match_primitive_type!(ty, |$T| { + unsafe { + Ok(Series::from_chunks_and_dtype_unchecked( + ca.name().clone(), + ca.chunks().iter().map(|chunk| { + cast_binview_to_primitive_dyn::<$T>( + &**chunk, + &arrow_type, + is_little_endian, + ) + }).collect::>>()?, + dtype + )) + } + }) + }, + _ => Err( + polars_err!(InvalidOperation:"unsupported data type in from_buffer. Only numerical types are allowed."), + ), + } + } +} + +impl BinaryNameSpaceImpl for BinaryChunked {} diff --git a/crates/polars-ops/src/chunked_array/cov.rs b/crates/polars-ops/src/chunked_array/cov.rs new file mode 100644 index 000000000000..d6d21975d57d --- /dev/null +++ b/crates/polars-ops/src/chunked_array/cov.rs @@ -0,0 +1,34 @@ +use num_traits::AsPrimitive; +use polars_compute::moment::{CovState, PearsonState}; +use polars_core::prelude::*; +use polars_core::utils::align_chunks_binary; + +/// Compute the covariance between two columns. +pub fn cov(a: &ChunkedArray, b: &ChunkedArray, ddof: u8) -> Option +where + T: PolarsNumericType, + T::Native: AsPrimitive, + ChunkedArray: ChunkVar, +{ + let (a, b) = align_chunks_binary(a, b); + let mut out = CovState::default(); + for (a, b) in a.downcast_iter().zip(b.downcast_iter()) { + out.combine(&polars_compute::moment::cov(a, b)) + } + out.finalize(ddof) +} + +/// Compute the pearson correlation between two columns. +pub fn pearson_corr(a: &ChunkedArray, b: &ChunkedArray) -> Option +where + T: PolarsNumericType, + T::Native: AsPrimitive, + ChunkedArray: ChunkVar, +{ + let (a, b) = align_chunks_binary(a, b); + let mut out = PearsonState::default(); + for (a, b) in a.downcast_iter().zip(b.downcast_iter()) { + out.combine(&polars_compute::moment::pearson_corr(a, b)) + } + Some(out.finalize()) +} diff --git a/crates/polars-ops/src/chunked_array/datetime/mod.rs b/crates/polars-ops/src/chunked_array/datetime/mod.rs new file mode 100644 index 000000000000..a84031ff486d --- /dev/null +++ b/crates/polars-ops/src/chunked_array/datetime/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "timezones")] +mod replace_time_zone; +#[cfg(feature = "timezones")] +pub use replace_time_zone::*; diff --git a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs new file mode 100644 index 000000000000..1637dd392707 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs @@ -0,0 +1,153 @@ +use std::str::FromStr; + +use arrow::legacy::kernels::convert_to_naive_local; +use arrow::temporal_conversions::{ + timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, +}; +use chrono::NaiveDateTime; +use chrono_tz::UTC; +use polars_core::chunked_array::ops::arity::try_binary_elementwise; +use polars_core::chunked_array::temporal::parse_time_zone; +use polars_core::prelude::*; + +pub fn replace_time_zone( + datetime: &Logical, + time_zone: Option<&str>, + ambiguous: &StringChunked, + non_existent: NonExistent, +) -> PolarsResult { + let from_time_zone = datetime.time_zone().as_deref().unwrap_or("UTC"); + let from_tz = parse_time_zone(from_time_zone)?; + let to_tz = parse_time_zone(time_zone.unwrap_or("UTC"))?; + if (from_tz == to_tz) + & ((from_tz == UTC) | ((ambiguous.len() == 1) & (ambiguous.get(0) == Some("raise")))) + { + let mut out = datetime + .0 + .clone() + .into_datetime(datetime.time_unit(), time_zone.map(PlSmallStr::from_str)); + out.set_sorted_flag(datetime.is_sorted_flag()); + return Ok(out); + } + let timestamp_to_datetime: fn(i64) -> NaiveDateTime = match datetime.time_unit() { + TimeUnit::Milliseconds => timestamp_ms_to_datetime, + TimeUnit::Microseconds => timestamp_us_to_datetime, + TimeUnit::Nanoseconds => timestamp_ns_to_datetime, + }; + let datetime_to_timestamp: fn(NaiveDateTime) -> i64 = match datetime.time_unit() { + TimeUnit::Milliseconds => datetime_to_timestamp_ms, + TimeUnit::Microseconds => datetime_to_timestamp_us, + TimeUnit::Nanoseconds => datetime_to_timestamp_ns, + }; + + let out = if ambiguous.len() == 1 + && ambiguous.get(0) != Some("null") + && non_existent == NonExistent::Raise + { + impl_replace_time_zone_fast( + datetime, + ambiguous.get(0), + timestamp_to_datetime, + datetime_to_timestamp, + &from_tz, + &to_tz, + ) + } else { + impl_replace_time_zone( + datetime, + ambiguous, + non_existent, + timestamp_to_datetime, + datetime_to_timestamp, + &from_tz, + &to_tz, + ) + }; + + let mut out = out?.into_datetime(datetime.time_unit(), time_zone.map(PlSmallStr::from_str)); + if from_time_zone == "UTC" && ambiguous.len() == 1 && ambiguous.get(0) == Some("raise") { + // In general, the sortedness flag can't be preserved. + // To be safe, we only do so in the simplest case when we know for sure that there is no "daylight savings weirdness" going on, i.e.: + // - `from_tz` is guaranteed to not observe daylight savings time; + // - user is just passing 'raise' to 'ambiguous'. + // Both conditions above need to be satisfied. + out.set_sorted_flag(datetime.is_sorted_flag()); + } + Ok(out) +} + +/// If `ambiguous` is length-1 and not equal to "null", we can take a slightly faster path. +pub fn impl_replace_time_zone_fast( + datetime: &Logical, + ambiguous: Option<&str>, + timestamp_to_datetime: fn(i64) -> NaiveDateTime, + datetime_to_timestamp: fn(NaiveDateTime) -> i64, + from_tz: &chrono_tz::Tz, + to_tz: &chrono_tz::Tz, +) -> PolarsResult { + match ambiguous { + Some(ambiguous) => datetime.0.try_apply_nonnull_values_generic(|timestamp| { + let ndt = timestamp_to_datetime(timestamp); + Ok(datetime_to_timestamp( + convert_to_naive_local( + from_tz, + to_tz, + ndt, + Ambiguous::from_str(ambiguous)?, + NonExistent::Raise, + )? + .expect("we didn't use Ambiguous::Null or NonExistent::Null"), + )) + }), + _ => Ok(datetime.0.apply(|_| None)), + } +} + +pub fn impl_replace_time_zone( + datetime: &Logical, + ambiguous: &StringChunked, + non_existent: NonExistent, + timestamp_to_datetime: fn(i64) -> NaiveDateTime, + datetime_to_timestamp: fn(NaiveDateTime) -> i64, + from_tz: &chrono_tz::Tz, + to_tz: &chrono_tz::Tz, +) -> PolarsResult { + match ambiguous.len() { + 1 => { + let iter = datetime.0.downcast_iter().map(|arr| { + let element_iter = arr.iter().map(|timestamp_opt| match timestamp_opt { + Some(timestamp) => { + let ndt = timestamp_to_datetime(*timestamp); + let res = convert_to_naive_local( + from_tz, + to_tz, + ndt, + Ambiguous::from_str(ambiguous.get(0).unwrap())?, + non_existent, + )?; + Ok::<_, PolarsError>(res.map(datetime_to_timestamp)) + }, + None => Ok(None), + }); + element_iter.try_collect_arr() + }); + ChunkedArray::try_from_chunk_iter(datetime.0.name().clone(), iter) + }, + _ => try_binary_elementwise(datetime, ambiguous, |timestamp_opt, ambiguous_opt| { + match (timestamp_opt, ambiguous_opt) { + (Some(timestamp), Some(ambiguous)) => { + let ndt = timestamp_to_datetime(timestamp); + Ok(convert_to_naive_local( + from_tz, + to_tz, + ndt, + Ambiguous::from_str(ambiguous)?, + non_existent, + )? + .map(datetime_to_timestamp)) + }, + _ => Ok(None), + } + }), + } +} diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs new file mode 100644 index 000000000000..cee1ce46d111 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -0,0 +1,1027 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::fmt::Debug; + +use arrow::array::{Array, BinaryViewArrayGeneric, View, ViewType}; +use arrow::bitmap::BitmapBuilder; +use arrow::buffer::Buffer; +use arrow::legacy::trusted_len::TrustedLenPush; +use hashbrown::hash_map::Entry; +use polars_core::prelude::gather::_update_gather_sorted_flag; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::Container; +use polars_core::with_match_physical_numeric_polars_type; + +use crate::frame::IntoDf; + +/// Gather by [`ChunkId`] +pub trait TakeChunked { + /// Gathers elements from a ChunkedArray, specifying for each element a + /// chunk index and index within that chunk through ChunkId. If + /// avoid_sharing is true the returned data should not share references + /// with the original array (like shared buffers in views). + /// + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + avoid_sharing: bool, + ) -> Self; + + /// # Safety + /// This function doesn't do any bound checks. + unsafe fn take_opt_chunked_unchecked( + &self, + by: &[ChunkId], + avoid_sharing: bool, + ) -> Self; +} + +impl TakeChunked for DataFrame { + /// Take elements by a slice of [`ChunkId`]s. + /// + /// # Safety + /// Does not do any bound checks. + /// `sorted` indicates if the chunks are sorted. + unsafe fn take_chunked_unchecked( + &self, + idx: &[ChunkId], + sorted: IsSorted, + avoid_sharing: bool, + ) -> DataFrame { + let cols = self + .to_df() + ._apply_columns(&|s| s.take_chunked_unchecked(idx, sorted, avoid_sharing)); + + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } + + /// Take elements by a slice of optional [`ChunkId`]s. + /// + /// # Safety + /// Does not do any bound checks. + unsafe fn take_opt_chunked_unchecked( + &self, + idx: &[ChunkId], + avoid_sharing: bool, + ) -> DataFrame { + let cols = self + .to_df() + ._apply_columns(&|s| s.take_opt_chunked_unchecked(idx, avoid_sharing)); + + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +pub trait TakeChunkedHorPar: IntoDf { + /// # Safety + /// Doesn't perform any bound checks + unsafe fn _take_chunked_unchecked_hor_par( + &self, + idx: &[ChunkId], + sorted: IsSorted, + ) -> DataFrame { + let cols = self + .to_df() + ._apply_columns_par(&|s| s.take_chunked_unchecked(idx, sorted, false)); + + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } + + /// # Safety + /// Doesn't perform any bound checks + /// + /// Check for null state in `ChunkId`. + unsafe fn _take_opt_chunked_unchecked_hor_par( + &self, + idx: &[ChunkId], + ) -> DataFrame { + let cols = self + .to_df() + ._apply_columns_par(&|s| s.take_opt_chunked_unchecked(idx, false)); + + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +impl TakeChunkedHorPar for DataFrame {} + +impl TakeChunked for Column { + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + avoid_sharing: bool, + ) -> Self { + // @scalar-opt + let s = self.as_materialized_series(); + let s = unsafe { s.take_chunked_unchecked(by, sorted, avoid_sharing) }; + s.into_column() + } + + unsafe fn take_opt_chunked_unchecked( + &self, + by: &[ChunkId], + avoid_sharing: bool, + ) -> Self { + // @scalar-opt + let s = self.as_materialized_series(); + let s = unsafe { s.take_opt_chunked_unchecked(by, avoid_sharing) }; + s.into_column() + } +} + +impl TakeChunked for Series { + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + avoid_sharing: bool, + ) -> Self { + use DataType::*; + match self.dtype() { + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(self.dtype(), |$T| { + let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); + ca.take_chunked_unchecked(by, sorted, avoid_sharing).into_series() + }) + }, + Boolean => { + let ca = self.bool().unwrap(); + ca.take_chunked_unchecked(by, sorted, avoid_sharing) + .into_series() + }, + Binary => { + let ca = self.binary().unwrap(); + take_chunked_unchecked_binview(ca, by, sorted, avoid_sharing).into_series() + }, + String => { + let ca = self.str().unwrap(); + take_chunked_unchecked_binview(ca, by, sorted, avoid_sharing).into_series() + }, + List(_) => { + let ca = self.list().unwrap(); + ca.take_chunked_unchecked(by, sorted, avoid_sharing) + .into_series() + }, + #[cfg(feature = "dtype-array")] + Array(_, _) => { + let ca = self.array().unwrap(); + ca.take_chunked_unchecked(by, sorted, avoid_sharing) + .into_series() + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = self.struct_().unwrap(); + take_chunked_unchecked_struct(ca, by, sorted, avoid_sharing).into_series() + }, + #[cfg(feature = "object")] + Object(_) => take_unchecked_object(self, by, sorted), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => { + let ca = self.decimal().unwrap(); + let out = ca.0.take_chunked_unchecked(by, sorted, avoid_sharing); + out.into_decimal_unchecked(ca.precision(), ca.scale()) + .into_series() + }, + #[cfg(feature = "dtype-date")] + Date => { + let ca = self.date().unwrap(); + ca.physical() + .take_chunked_unchecked(by, sorted, avoid_sharing) + .into_date() + .into_series() + }, + #[cfg(feature = "dtype-datetime")] + Datetime(u, z) => { + let ca = self.datetime().unwrap(); + ca.physical() + .take_chunked_unchecked(by, sorted, avoid_sharing) + .into_datetime(*u, z.clone()) + .into_series() + }, + #[cfg(feature = "dtype-duration")] + Duration(u) => { + let ca = self.duration().unwrap(); + ca.physical() + .take_chunked_unchecked(by, sorted, avoid_sharing) + .into_duration(*u) + .into_series() + }, + #[cfg(feature = "dtype-time")] + Time => { + let ca = self.time().unwrap(); + ca.physical() + .take_chunked_unchecked(by, sorted, avoid_sharing) + .into_time() + .into_series() + }, + #[cfg(feature = "dtype-categorical")] + Categorical(revmap, ord) | Enum(revmap, ord) => { + let ca = self.categorical().unwrap(); + let t = ca + .physical() + .take_chunked_unchecked(by, sorted, avoid_sharing); + CategoricalChunked::from_cats_and_rev_map_unchecked( + t, + revmap.as_ref().unwrap().clone(), + matches!(self.dtype(), Enum(..)), + *ord, + ) + .into_series() + }, + Null => Series::new_null(self.name().clone(), by.len()), + _ => unreachable!(), + } + } + + /// Take function that checks of null state in `ChunkIdx`. + unsafe fn take_opt_chunked_unchecked( + &self, + by: &[ChunkId], + avoid_sharing: bool, + ) -> Self { + use DataType::*; + match self.dtype() { + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(self.dtype(), |$T| { + let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); + ca.take_opt_chunked_unchecked(by, avoid_sharing).into_series() + }) + }, + Boolean => { + let ca = self.bool().unwrap(); + ca.take_opt_chunked_unchecked(by, avoid_sharing) + .into_series() + }, + Binary => { + let ca = self.binary().unwrap(); + take_opt_chunked_unchecked_binview(ca, by, avoid_sharing).into_series() + }, + String => { + let ca = self.str().unwrap(); + take_opt_chunked_unchecked_binview(ca, by, avoid_sharing).into_series() + }, + List(_) => { + let ca = self.list().unwrap(); + ca.take_opt_chunked_unchecked(by, avoid_sharing) + .into_series() + }, + #[cfg(feature = "dtype-array")] + Array(_, _) => { + let ca = self.array().unwrap(); + ca.take_opt_chunked_unchecked(by, avoid_sharing) + .into_series() + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = self.struct_().unwrap(); + take_opt_chunked_unchecked_struct(ca, by, avoid_sharing).into_series() + }, + #[cfg(feature = "object")] + Object(_) => take_opt_unchecked_object(self, by, avoid_sharing), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => { + let ca = self.decimal().unwrap(); + let out = ca.0.take_opt_chunked_unchecked(by, avoid_sharing); + out.into_decimal_unchecked(ca.precision(), ca.scale()) + .into_series() + }, + #[cfg(feature = "dtype-date")] + Date => { + let ca = self.date().unwrap(); + ca.physical() + .take_opt_chunked_unchecked(by, avoid_sharing) + .into_date() + .into_series() + }, + #[cfg(feature = "dtype-datetime")] + Datetime(u, z) => { + let ca = self.datetime().unwrap(); + ca.physical() + .take_opt_chunked_unchecked(by, avoid_sharing) + .into_datetime(*u, z.clone()) + .into_series() + }, + #[cfg(feature = "dtype-duration")] + Duration(u) => { + let ca = self.duration().unwrap(); + ca.physical() + .take_opt_chunked_unchecked(by, avoid_sharing) + .into_duration(*u) + .into_series() + }, + #[cfg(feature = "dtype-time")] + Time => { + let ca = self.time().unwrap(); + ca.physical() + .take_opt_chunked_unchecked(by, avoid_sharing) + .into_time() + .into_series() + }, + #[cfg(feature = "dtype-categorical")] + Categorical(revmap, ord) | Enum(revmap, ord) => { + let ca = self.categorical().unwrap(); + let ret = ca.physical().take_opt_chunked_unchecked(by, avoid_sharing); + CategoricalChunked::from_cats_and_rev_map_unchecked( + ret, + revmap.as_ref().unwrap().clone(), + matches!(self.dtype(), Enum(..)), + *ord, + ) + .into_series() + }, + Null => Series::new_null(self.name().clone(), by.len()), + _ => unreachable!(), + } + } +} + +impl TakeChunked for ChunkedArray +where + T: PolarsDataType, + T::Array: Debug, +{ + unsafe fn take_chunked_unchecked( + &self, + by: &[ChunkId], + sorted: IsSorted, + _allow_sharing: bool, + ) -> Self { + let arrow_dtype = self.dtype().to_arrow(CompatLevel::newest()); + + let mut out = if !self.has_nulls() { + let iter = by.iter().map(|chunk_id| { + debug_assert!( + !chunk_id.is_null(), + "null chunks should not hit this branch" + ); + let (chunk_idx, array_idx) = chunk_id.extract(); + let arr = self.downcast_get_unchecked(chunk_idx as usize); + arr.value_unchecked(array_idx as usize).clone() + }); + + let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); + ChunkedArray::with_chunk(self.name().clone(), arr) + } else { + let iter = by.iter().map(|chunk_id| { + debug_assert!( + !chunk_id.is_null(), + "null chunks should not hit this branch" + ); + let (chunk_idx, array_idx) = chunk_id.extract(); + let arr = self.downcast_get_unchecked(chunk_idx as usize); + arr.get_unchecked(array_idx as usize) + }); + + let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); + ChunkedArray::with_chunk(self.name().clone(), arr) + }; + let sorted_flag = _update_gather_sorted_flag(self.is_sorted_flag(), sorted); + out.set_sorted_flag(sorted_flag); + out + } + + // Take function that checks of null state in `ChunkIdx`. + unsafe fn take_opt_chunked_unchecked( + &self, + by: &[ChunkId], + _allow_sharing: bool, + ) -> Self { + let arrow_dtype = self.dtype().to_arrow(CompatLevel::newest()); + + if !self.has_nulls() { + let arr = by + .iter() + .map(|chunk_id| { + if chunk_id.is_null() { + None + } else { + let (chunk_idx, array_idx) = chunk_id.extract(); + let arr = self.downcast_get_unchecked(chunk_idx as usize); + Some(arr.value_unchecked(array_idx as usize).clone()) + } + }) + .collect_arr_trusted_with_dtype(arrow_dtype); + + ChunkedArray::with_chunk(self.name().clone(), arr) + } else { + let arr = by + .iter() + .map(|chunk_id| { + if chunk_id.is_null() { + None + } else { + let (chunk_idx, array_idx) = chunk_id.extract(); + let arr = self.downcast_get_unchecked(chunk_idx as usize); + arr.get_unchecked(array_idx as usize) + } + }) + .collect_arr_trusted_with_dtype(arrow_dtype); + + ChunkedArray::with_chunk(self.name().clone(), arr) + } + } +} + +#[cfg(feature = "object")] +unsafe fn take_unchecked_object( + s: &Series, + by: &[ChunkId], + _sorted: IsSorted, +) -> Series { + use polars_core::chunked_array::object::registry::get_object_builder; + + let mut builder = get_object_builder(s.name().clone(), by.len()); + + by.iter().for_each(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let object = s.get_object_chunked_unchecked(chunk_idx as usize, array_idx as usize); + builder.append_option(object.map(|v| v.as_any())) + }); + builder.to_series() +} + +#[cfg(feature = "object")] +unsafe fn take_opt_unchecked_object( + s: &Series, + by: &[ChunkId], + _allow_sharing: bool, +) -> Series { + use polars_core::chunked_array::object::registry::get_object_builder; + + let mut builder = get_object_builder(s.name().clone(), by.len()); + + by.iter().for_each(|chunk_id| { + if chunk_id.is_null() { + builder.append_null() + } else { + let (chunk_idx, array_idx) = chunk_id.extract(); + let object = s.get_object_chunked_unchecked(chunk_idx as usize, array_idx as usize); + builder.append_option(object.map(|v| v.as_any())) + } + }); + builder.to_series() +} + +unsafe fn take_chunked_unchecked_binview( + ca: &ChunkedArray, + by: &[ChunkId], + sorted: IsSorted, + avoid_sharing: bool, +) -> ChunkedArray +where + T: PolarsDataType>, + T::Array: Debug, + V: ViewType + ?Sized, +{ + if avoid_sharing { + return ca.take_chunked_unchecked(by, sorted, avoid_sharing); + } + + let mut views = Vec::with_capacity(by.len()); + let (validity, arc_data_buffers); + + // If we can cheaply clone the list of buffers from the ChunkedArray we will, + // otherwise we will only clone those buffers we need. + if ca.n_chunks() == 1 { + let arr = ca.downcast_iter().next().unwrap(); + let arr_views = arr.views(); + + validity = if arr.has_nulls() { + let mut validity = BitmapBuilder::with_capacity(by.len()); + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + debug_assert!(chunk_idx == 0); + if arr.is_null_unchecked(array_idx as usize) { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + views.push_unchecked(*arr_views.get_unchecked(array_idx as usize)); + validity.push_unchecked(true); + } + } + Some(validity.freeze()) + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + debug_assert!(chunk_idx == 0); + views.push_unchecked(*arr_views.get_unchecked(array_idx as usize)); + } + None + }; + + arc_data_buffers = arr.data_buffers().clone(); + } + // Dedup the buffers while creating the views. + else if by.len() < ca.n_chunks() { + let mut buffer_idxs = PlHashMap::with_capacity(8); + let mut buffers = Vec::with_capacity(8); + + validity = if ca.has_nulls() { + let mut validity = BitmapBuilder::with_capacity(by.len()); + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + if arr.is_null_unchecked(array_idx as usize) { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let view = *arr.views().get_unchecked(array_idx as usize); + views.push_unchecked(update_view_and_dedup( + view, + arr.data_buffers(), + &mut buffer_idxs, + &mut buffers, + )); + validity.push_unchecked(true); + } + } + Some(validity.freeze()) + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + let view = *arr.views().get_unchecked(array_idx as usize); + views.push_unchecked(update_view_and_dedup( + view, + arr.data_buffers(), + &mut buffer_idxs, + &mut buffers, + )); + } + None + }; + + arc_data_buffers = buffers.into(); + } + // Dedup the buffers up front + else { + let (buffers, buffer_offsets) = dedup_buffers_by_arc(ca); + + validity = if ca.has_nulls() { + let mut validity = BitmapBuilder::with_capacity(by.len()); + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + if arr.is_null_unchecked(array_idx as usize) { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let view = *arr.views().get_unchecked(array_idx as usize); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); + validity.push_unchecked(true); + } + } + Some(validity.freeze()) + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + let view = *arr.views().get_unchecked(array_idx as usize); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); + } + None + }; + + arc_data_buffers = buffers.into(); + }; + + let arr = BinaryViewArrayGeneric::::new_unchecked_unknown_md( + V::DATA_TYPE, + views.into(), + arc_data_buffers, + validity, + None, + ); + + let mut out = ChunkedArray::with_chunk(ca.name().clone(), arr.maybe_gc()); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), sorted); + out.set_sorted_flag(sorted_flag); + out +} + +#[allow(clippy::unnecessary_cast)] +#[inline(always)] +unsafe fn rewrite_view(mut view: View, chunk_idx: IdxSize, buffer_offsets: &[u32]) -> View { + if view.length > 12 { + let base_offset = *buffer_offsets.get_unchecked(chunk_idx as usize); + view.buffer_idx += base_offset; + } + view +} + +unsafe fn update_view_and_dedup( + mut view: View, + orig_buffers: &[Buffer], + buffer_idxs: &mut PlHashMap<(*const u8, usize), u32>, + buffers: &mut Vec>, +) -> View { + if view.length > 12 { + // Dedup on pointer + length. + let orig_buffer = orig_buffers.get_unchecked(view.buffer_idx as usize); + view.buffer_idx = + match buffer_idxs.entry((orig_buffer.as_slice().as_ptr(), orig_buffer.len())) { + Entry::Occupied(o) => *o.get(), + Entry::Vacant(v) => { + let buffer_idx = buffers.len() as u32; + buffers.push(orig_buffer.clone()); + v.insert(buffer_idx); + buffer_idx + }, + }; + } + view +} + +fn dedup_buffers_by_arc(ca: &ChunkedArray) -> (Vec>, Vec) +where + T: PolarsDataType>, + V: ViewType + ?Sized, +{ + // Dedup buffers up front. Note: don't do this during view update, as this is often is much + // more costly. + let mut buffers = Vec::with_capacity(ca.chunks().len()); + // Dont need to include the length, as we look at the arc pointers, which are immutable. + let mut buffers_dedup = PlHashMap::with_capacity(ca.chunks().len()); + let mut buffer_offsets = Vec::with_capacity(ca.chunks().len() + 1); + + for arr in ca.downcast_iter() { + let data_buffers = arr.data_buffers(); + let arc_ptr = data_buffers.as_ptr(); + let offset = match buffers_dedup.entry(arc_ptr) { + Entry::Occupied(o) => *o.get(), + Entry::Vacant(v) => { + let offset = buffers.len() as u32; + buffers.extend(data_buffers.iter().cloned()); + v.insert(offset); + offset + }, + }; + buffer_offsets.push(offset); + } + (buffers, buffer_offsets) +} + +unsafe fn take_opt_chunked_unchecked_binview( + ca: &ChunkedArray, + by: &[ChunkId], + avoid_sharing: bool, +) -> ChunkedArray +where + T: PolarsDataType>, + T::Array: Debug, + V: ViewType + ?Sized, +{ + if avoid_sharing { + return ca.take_opt_chunked_unchecked(by, avoid_sharing); + } + + let mut views = Vec::with_capacity(by.len()); + let mut validity = BitmapBuilder::with_capacity(by.len()); + + // If we can cheaply clone the list of buffers from the ChunkedArray we will, + // otherwise we will only clone those buffers we need. + let arc_data_buffers = if ca.n_chunks() == 1 { + let arr = ca.downcast_iter().next().unwrap(); + let arr_views = arr.views(); + + if arr.has_nulls() { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + debug_assert!(id.is_null() || chunk_idx == 0); + if id.is_null() || arr.is_null_unchecked(array_idx as usize) { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + views.push_unchecked(*arr_views.get_unchecked(array_idx as usize)); + validity.push_unchecked(true); + } + } + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + debug_assert!(id.is_null() || chunk_idx == 0); + if id.is_null() { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + views.push_unchecked(*arr_views.get_unchecked(array_idx as usize)); + validity.push_unchecked(true); + } + } + } + + arr.data_buffers().clone() + } + // Dedup the buffers while creating the views. + else if by.len() < ca.n_chunks() { + let mut buffer_idxs = PlHashMap::with_capacity(8); + let mut buffers = Vec::with_capacity(8); + + if ca.has_nulls() { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + if id.is_null() { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + if arr.is_null_unchecked(array_idx as usize) { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let view = *arr.views().get_unchecked(array_idx as usize); + views.push_unchecked(update_view_and_dedup( + view, + arr.data_buffers(), + &mut buffer_idxs, + &mut buffers, + )); + validity.push_unchecked(true); + } + } + } + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + if id.is_null() { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + let view = *arr.views().get_unchecked(array_idx as usize); + views.push_unchecked(update_view_and_dedup( + view, + arr.data_buffers(), + &mut buffer_idxs, + &mut buffers, + )); + validity.push_unchecked(true); + } + } + }; + + buffers.into() + } + // Dedup the buffers up front + else { + let (buffers, buffer_offsets) = dedup_buffers_by_arc(ca); + + if ca.has_nulls() { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + if id.is_null() { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + if arr.is_null_unchecked(array_idx as usize) { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let view = *arr.views().get_unchecked(array_idx as usize); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); + validity.push_unchecked(true); + } + } + } + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + + if id.is_null() { + views.push_unchecked(View::default()); + validity.push_unchecked(false); + } else { + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + let view = *arr.views().get_unchecked(array_idx as usize); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); + validity.push_unchecked(true); + } + } + }; + + buffers.into() + }; + + let arr = BinaryViewArrayGeneric::::new_unchecked_unknown_md( + V::DATA_TYPE, + views.into(), + arc_data_buffers, + Some(validity.freeze()), + None, + ); + + ChunkedArray::with_chunk(ca.name().clone(), arr.maybe_gc()) +} + +#[cfg(feature = "dtype-struct")] +unsafe fn take_chunked_unchecked_struct( + ca: &StructChunked, + by: &[ChunkId], + sorted: IsSorted, + avoid_sharing: bool, +) -> StructChunked { + let fields = ca + .fields_as_series() + .iter() + .map(|s| s.take_chunked_unchecked(by, sorted, avoid_sharing)) + .collect::>(); + let mut out = StructChunked::from_series(ca.name().clone(), by.len(), fields.iter()).unwrap(); + + if !ca.has_nulls() { + return out; + } + + let mut validity = BitmapBuilder::with_capacity(by.len()); + if ca.n_chunks() == 1 { + let arr = ca.downcast_as_array(); + let bitmap = arr.validity().unwrap(); + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + debug_assert!(chunk_idx == 0); + validity.push_unchecked(bitmap.get_bit_unchecked(array_idx as usize)); + } + } else { + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + if let Some(bitmap) = arr.validity() { + validity.push_unchecked(bitmap.get_bit_unchecked(array_idx as usize)); + } else { + validity.push_unchecked(true); + } + } + } + + out.rechunk_mut(); // Should be a no-op. + out.downcast_iter_mut() + .next() + .unwrap() + .set_validity(validity.into_opt_validity()); + out +} + +#[cfg(feature = "dtype-struct")] +unsafe fn take_opt_chunked_unchecked_struct( + ca: &StructChunked, + by: &[ChunkId], + avoid_sharing: bool, +) -> StructChunked { + let fields = ca + .fields_as_series() + .iter() + .map(|s| s.take_opt_chunked_unchecked(by, avoid_sharing)) + .collect::>(); + let mut out = StructChunked::from_series(ca.name().clone(), by.len(), fields.iter()).unwrap(); + + let mut validity = BitmapBuilder::with_capacity(by.len()); + if ca.n_chunks() == 1 { + let arr = ca.downcast_as_array(); + if let Some(bitmap) = arr.validity() { + for id in by.iter() { + if id.is_null() { + validity.push_unchecked(false); + } else { + let (chunk_idx, array_idx) = id.extract(); + debug_assert!(chunk_idx == 0); + validity.push_unchecked(bitmap.get_bit_unchecked(array_idx as usize)); + } + } + } else { + for id in by.iter() { + validity.push_unchecked(!id.is_null()); + } + } + } else { + for id in by.iter() { + if id.is_null() { + validity.push_unchecked(false); + } else { + let (chunk_idx, array_idx) = id.extract(); + let arr = ca.downcast_get_unchecked(chunk_idx as usize); + if let Some(bitmap) = arr.validity() { + validity.push_unchecked(bitmap.get_bit_unchecked(array_idx as usize)); + } else { + validity.push_unchecked(true); + } + } + } + } + + out.rechunk_mut(); // Should be a no-op. + out.downcast_iter_mut() + .next() + .unwrap() + .set_validity(validity.into_opt_validity()); + out +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_binview_chunked_gather() { + unsafe { + // # Series without nulls; + let mut s_1 = Series::new( + "a".into(), + &["1 loooooooooooong string", "2 loooooooooooong string"], + ); + let s_2 = Series::new( + "a".into(), + &["11 loooooooooooong string", "22 loooooooooooong string"], + ); + let s_3 = Series::new( + "a".into(), + &[ + "111 loooooooooooong string", + "222 loooooooooooong string", + "small", // this tests we don't mess with the inlined view + ], + ); + s_1.append(&s_2).unwrap(); + s_1.append(&s_3).unwrap(); + + assert_eq!(s_1.n_chunks(), 3); + + // ## Ids without nulls; + let by: [ChunkId<24>; 7] = [ + ChunkId::store(0, 0), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ChunkId::store(2, 0), + ChunkId::store(2, 1), + ChunkId::store(2, 2), + ]; + + let out = s_1.take_chunked_unchecked(&by, IsSorted::Not, true); + let idx = IdxCa::new("".into(), [0, 1, 3, 2, 4, 5, 6]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals(&expected)); + + // ## Ids with nulls; + let by: [ChunkId<24>; 4] = [ + ChunkId::null(), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + let out = s_1.take_opt_chunked_unchecked(&by, true); + + let idx = IdxCa::new("".into(), [None, Some(1), Some(3), Some(2)]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals_missing(&expected)); + + // # Series with nulls; + let mut s_1 = Series::new( + "a".into(), + &["1 loooooooooooong string 1", "2 loooooooooooong string 2"], + ); + let s_2 = Series::new("a".into(), &[Some("11 loooooooooooong string 11"), None]); + s_1.append(&s_2).unwrap(); + + // ## Ids without nulls; + let by: [ChunkId<24>; 4] = [ + ChunkId::store(0, 0), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + + let out = s_1.take_chunked_unchecked(&by, IsSorted::Not, true); + let idx = IdxCa::new("".into(), [0, 1, 3, 2]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals_missing(&expected)); + + // ## Ids with nulls; + let by: [ChunkId<24>; 4] = [ + ChunkId::null(), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + let out = s_1.take_opt_chunked_unchecked(&by, true); + + let idx = IdxCa::new("".into(), [None, Some(1), Some(3), Some(2)]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals_missing(&expected)); + } + } +} diff --git a/crates/polars-ops/src/chunked_array/gather/mod.rs b/crates/polars-ops/src/chunked_array/gather/mod.rs new file mode 100644 index 000000000000..fe4a565d63bb --- /dev/null +++ b/crates/polars-ops/src/chunked_array/gather/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "chunked_ids")] +pub(crate) mod chunked; +#[cfg(feature = "chunked_ids")] +pub use chunked::*; diff --git a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs new file mode 100644 index 000000000000..75f11c5f5e0c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs @@ -0,0 +1,226 @@ +use arrow::array::Array; +use arrow::bitmap::bitmask::BitMask; +use arrow::compute::concatenate::concatenate_validities; +use bytemuck::allocation::zeroed_vec; +use polars_core::prelude::gather::check_bounds_ca; +use polars_core::prelude::*; +use polars_utils::index::check_bounds; + +/// # Safety +/// For each index pair, pair.0 < len && pair.1 < ca.null_count() must hold. +unsafe fn gather_skip_nulls_idx_pairs_unchecked<'a, T: PolarsDataType>( + ca: &'a ChunkedArray, + mut index_pairs: Vec<(IdxSize, IdxSize)>, + len: usize, +) -> Vec> { + if index_pairs.is_empty() { + return zeroed_vec(len); + } + + // We sort by gather index so we can do the null scan in one pass. + index_pairs.sort_unstable_by_key(|t| t.1); + let mut pair_iter = index_pairs.iter().copied(); + let (mut out_idx, mut nonnull_idx); + (out_idx, nonnull_idx) = pair_iter.next().unwrap(); + + let mut out: Vec> = zeroed_vec(len); + let mut nonnull_prev_arrays = 0; + 'outer: for arr in ca.downcast_iter() { + let arr_nonnull_len = arr.len() - arr.null_count(); + let mut arr_scan_offset = 0; + let mut nonnull_before_offset = 0; + let mask = arr.validity().map(BitMask::from_bitmap).unwrap_or_default(); + + // Is our next nonnull_idx in this array? + while nonnull_idx as usize - nonnull_prev_arrays < arr_nonnull_len { + let nonnull_idx_in_arr = nonnull_idx as usize - nonnull_prev_arrays; + + let phys_idx_in_arr = if arr.null_count() == 0 { + // Happy fast path for full non-null array. + nonnull_idx_in_arr + } else { + mask.nth_set_bit_idx(nonnull_idx_in_arr - nonnull_before_offset, arr_scan_offset) + .unwrap() + }; + + unsafe { + let val = arr.value_unchecked(phys_idx_in_arr); + *out.get_unchecked_mut(out_idx as usize) = val.into(); + } + + arr_scan_offset = phys_idx_in_arr; + nonnull_before_offset = nonnull_idx_in_arr; + + let Some(next_pair) = pair_iter.next() else { + break 'outer; + }; + (out_idx, nonnull_idx) = next_pair; + } + + nonnull_prev_arrays += arr_nonnull_len; + } + + out +} + +pub trait ChunkGatherSkipNulls: Sized { + fn gather_skip_nulls(&self, indices: &I) -> PolarsResult; +} + +impl ChunkGatherSkipNulls<[IdxSize]> for ChunkedArray +where + ChunkedArray: ChunkFilter + ChunkTake<[IdxSize]>, +{ + fn gather_skip_nulls(&self, indices: &[IdxSize]) -> PolarsResult { + if self.null_count() == 0 { + return self.take(indices); + } + + // If we want many indices it's probably better to do a normal gather on + // a dense array. + if indices.len() >= self.len() / 4 { + return ChunkFilter::filter(self, &self.is_not_null()) + .unwrap() + .take(indices); + } + + let bound = self.len() - self.null_count(); + check_bounds(indices, bound as IdxSize)?; + + let index_pairs: Vec<_> = indices + .iter() + .enumerate() + .map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx)) + .collect(); + let gathered = + unsafe { gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.len()) }; + let arr = + T::Array::from_zeroable_vec(gathered, self.dtype().to_arrow(CompatLevel::newest())); + Ok(ChunkedArray::from_chunk_iter_like(self, [arr])) + } +} + +impl ChunkGatherSkipNulls for ChunkedArray +where + ChunkedArray: ChunkFilter + ChunkTake, +{ + fn gather_skip_nulls(&self, indices: &IdxCa) -> PolarsResult { + if self.null_count() == 0 { + return self.take(indices); + } + + // If we want many indices it's probably better to do a normal gather on + // a dense array. + if indices.len() >= self.len() / 4 { + return ChunkFilter::filter(self, &self.is_not_null()) + .unwrap() + .take(indices); + } + + let bound = self.len() - self.null_count(); + check_bounds_ca(indices, bound as IdxSize)?; + + let index_pairs: Vec<_> = if indices.null_count() == 0 { + indices + .downcast_iter() + .flat_map(|arr| arr.values_iter()) + .enumerate() + .map(|(out_idx, nonnull_idx)| (out_idx as IdxSize, *nonnull_idx)) + .collect() + } else { + // Filter *after* the enumerate so we place the non-null gather + // requests at the right places. + indices + .downcast_iter() + .flat_map(|arr| arr.iter()) + .enumerate() + .filter_map(|(out_idx, nonnull_idx)| Some((out_idx as IdxSize, *nonnull_idx?))) + .collect() + }; + let gathered = unsafe { + gather_skip_nulls_idx_pairs_unchecked(self, index_pairs, indices.as_ref().len()) + }; + + let mut arr = + T::Array::from_zeroable_vec(gathered, self.dtype().to_arrow(CompatLevel::newest())); + if indices.null_count() > 0 { + arr = arr.with_validity_typed(concatenate_validities(indices.chunks())); + } + Ok(ChunkedArray::from_chunk_iter_like(self, [arr])) + } +} + +#[cfg(test)] +mod test { + use std::ops::Range; + + use rand::distributions::uniform::SampleUniform; + use rand::prelude::*; + + use super::*; + + fn random_vec( + rng: &mut R, + val: Range, + len_range: Range, + ) -> Vec { + let n = rng.gen_range(len_range); + (0..n).map(|_| rng.gen_range(val.clone())).collect() + } + + fn random_filter(rng: &mut R, v: &[T], pr: Range) -> Vec> { + let p = rng.gen_range(pr); + let rand_filter = |x| Some(x).filter(|_| rng.r#gen::() < p); + v.iter().cloned().map(rand_filter).collect() + } + + fn ref_gather_nulls(v: Vec>, idx: Vec>) -> Option>> { + let v: Vec = v.into_iter().flatten().collect(); + if idx.iter().any(|oi| oi.map(|i| i >= v.len()) == Some(true)) { + return None; + } + Some(idx.into_iter().map(|i| Some(v[i?])).collect()) + } + + fn test_equal_ref(ca: &UInt32Chunked, idx_ca: &IdxCa) { + let ref_ca: Vec> = ca.iter().collect(); + let ref_idx_ca: Vec> = idx_ca.iter().map(|i| Some(i? as usize)).collect(); + let gather = ca.gather_skip_nulls(idx_ca).ok(); + let ref_gather = ref_gather_nulls(ref_ca, ref_idx_ca); + assert_eq!(gather.map(|ca| ca.iter().collect()), ref_gather); + } + + fn gather_skip_nulls_check(ca: &UInt32Chunked, idx_ca: &IdxCa) { + test_equal_ref(ca, idx_ca); + test_equal_ref(&ca.rechunk(), idx_ca); + test_equal_ref(ca, &idx_ca.rechunk()); + test_equal_ref(&ca.rechunk(), &idx_ca.rechunk()); + } + + #[rustfmt::skip] + #[test] + fn test_gather_skip_nulls() { + let mut rng = SmallRng::seed_from_u64(0xdeadbeef); + + for _test in 0..20 { + let num_elem_chunks = rng.gen_range(1..10); + let elem_chunks: Vec<_> = (0..num_elem_chunks).map(|_| random_vec(&mut rng, 0..u32::MAX, 0..100)).collect(); + let null_elem_chunks: Vec<_> = elem_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect(); + let num_nonnull_elems: usize = null_elem_chunks.iter().map(|c| c.iter().filter(|x| x.is_some()).count()).sum(); + + let num_idx_chunks = rng.gen_range(1..10); + let idx_chunks: Vec<_> = (0..num_idx_chunks).map(|_| random_vec(&mut rng, 0..num_nonnull_elems as IdxSize, 0..200)).collect(); + let null_idx_chunks: Vec<_> = idx_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect(); + + let nonnull_ca = UInt32Chunked::from_chunk_iter("".into(), elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let ca = UInt32Chunked::from_chunk_iter("".into(), null_elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let nonnull_idx_ca = IdxCa::from_chunk_iter("".into(), idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let idx_ca = IdxCa::from_chunk_iter("".into(), null_idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + + gather_skip_nulls_check(&ca, &idx_ca); + gather_skip_nulls_check(&ca, &nonnull_idx_ca); + gather_skip_nulls_check(&nonnull_ca, &idx_ca); + gather_skip_nulls_check(&nonnull_ca, &nonnull_idx_ca); + } + } +} diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs new file mode 100644 index 000000000000..127452ff26e4 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -0,0 +1,261 @@ +use std::cmp; +use std::fmt::Write; + +use num_traits::ToPrimitive; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +const DEFAULT_BIN_COUNT: usize = 10; + +fn get_breaks( + ca: &ChunkedArray, + bin_count: Option, + bins: Option<&[f64]>, +) -> PolarsResult<(Vec, bool)> +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + let (bins, uniform) = match (bin_count, bins) { + (Some(_), Some(_)) => { + return Err(PolarsError::ComputeError( + "can only provide one of `bin_count` or `bins`".into(), + )); + }, + (None, Some(bins)) => { + // User-supplied bins. Note these are actually bin edges. Check for monotonicity. + // If we only have one edge, we have no bins. + let bin_len = bins.len(); + if bin_len > 1 { + for i in 1..bin_len { + if (bins[i] - bins[i - 1]) <= 0.0 { + return Err(PolarsError::ComputeError( + "bins must increase monotonically".into(), + )); + } + } + (bins.to_vec(), false) + } else { + (Vec::::new(), false) + } + }, + (bin_count, None) => { + // User-supplied bin count, or 10 by default. Compute edges from the data. + let bin_count = bin_count.unwrap_or(DEFAULT_BIN_COUNT); + let n = ca.len() - ca.null_count(); + let (offset, width, upper_limit) = if n == 0 { + // No non-null items; supply unit interval. + (0.0, 1.0 / bin_count as f64, 1.0) + } else if n == 1 { + // Unit interval around single point + let idx = ca.first_non_null().unwrap(); + // SAFETY: idx is guaranteed to contain an element. + let center = unsafe { ca.get_unchecked(idx) }.unwrap().to_f64().unwrap(); + (center - 0.5, 1.0 / bin_count as f64, center + 0.5) + } else { + // Determine outer bin edges from the data itself + let min_value = ca.min().unwrap().to_f64().unwrap(); + let max_value = ca.max().unwrap().to_f64().unwrap(); + + // All data points are identical--use unit interval. + if min_value == max_value { + (min_value - 0.5, 1.0 / bin_count as f64, max_value + 0.5) + } else { + ( + min_value, + (max_value - min_value) / bin_count as f64, + max_value, + ) + } + }; + // Manually set the final value to the maximum value to ensure the final value isn't + // missed due to floating-point precision. + let out = (0..bin_count) + .map(|x| (x as f64 * width) + offset) + .chain(std::iter::once(upper_limit)) + .collect::>(); + (out, true) + }, + }; + Ok((bins, uniform)) +} + +// O(n) implementation when buckets are fixed-size. +// We deposit items directly into their buckets. +fn uniform_hist_count(breaks: &[f64], ca: &ChunkedArray) -> Vec +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + let num_bins = breaks.len() - 1; + let mut count: Vec = vec![0; num_bins]; + let min_break: f64 = breaks[0]; + let max_break: f64 = breaks[num_bins]; + let scale = num_bins as f64 / (max_break - min_break); + let max_idx = num_bins - 1; + + for chunk in ca.downcast_iter() { + for item in chunk.non_null_values_iter() { + let item = item.to_f64().unwrap(); + if item > min_break && item <= max_break { + // idx > (num_bins - 1) may happen due to floating point representation imprecision + let mut idx = cmp::min((scale * (item - min_break)) as usize, max_idx); + + // Adjust for float imprecision providing idx > 1 ULP of the breaks + if item <= breaks[idx] { + idx -= 1; + } else if item > breaks[idx + 1] { + idx += 1; + } + + count[idx] += 1; + } else if item == min_break { + count[0] += 1; + } + } + } + count +} + +// Variable-width bucketing. We sort the items and then move linearly through buckets. +fn hist_count(breaks: &[f64], ca: &ChunkedArray) -> Vec +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + let num_bins = breaks.len() - 1; + let mut breaks_iter = breaks.iter().skip(1); // Skip the first lower bound + let (min_break, max_break) = (breaks[0], breaks[breaks.len() - 1]); + let mut upper_bound = *breaks_iter.next().unwrap(); + let mut sorted = ca.sort(false); + sorted.rechunk_mut(); + let mut current_count: IdxSize = 0; + let chunk = sorted.downcast_as_array(); + let mut count: Vec = Vec::with_capacity(num_bins); + + 'item: for item in chunk.non_null_values_iter() { + let item = item.to_f64().unwrap(); + + // Cycle through items until we hit the first bucket. + if item.is_nan() || item < min_break { + continue; + } + + while item > upper_bound { + if item > max_break { + // No more items will fit in any buckets + break 'item; + } + + // Finished with prior bucket; push, reset, and move to next. + count.push(current_count); + current_count = 0; + upper_bound = *breaks_iter.next().unwrap(); + } + + // Item is in bound. + current_count += 1; + } + count.push(current_count); + count.resize(num_bins, 0); // If we left early, fill remainder with 0. + count +} + +fn compute_hist( + ca: &ChunkedArray, + bin_count: Option, + bins: Option<&[f64]>, + include_category: bool, + include_breakpoint: bool, +) -> PolarsResult +where + T: PolarsNumericType, + ChunkedArray: ChunkAgg, +{ + let (breaks, uniform) = get_breaks(ca, bin_count, bins)?; + let num_bins = std::cmp::max(breaks.len(), 1) - 1; + let count = if num_bins > 0 && ca.len() > ca.null_count() { + if uniform { + uniform_hist_count(&breaks, ca) + } else { + hist_count(&breaks, ca) + } + } else { + vec![0; num_bins] + }; + + // Generate output: breakpoint (optional), breaks (optional), count + let mut fields = Vec::with_capacity(3); + + if include_breakpoint { + let breakpoints = if num_bins > 0 { + Series::new(PlSmallStr::from_static("breakpoint"), &breaks[1..]) + } else { + let empty: &[f64; 0] = &[]; + Series::new(PlSmallStr::from_static("breakpoint"), empty) + }; + fields.push(breakpoints) + } + + if include_category { + let mut categories = + StringChunkedBuilder::new(PlSmallStr::from_static("category"), breaks.len()); + if num_bins > 0 { + let mut lower = AnyValue::Float64(breaks[0]); + let mut buf = String::new(); + let mut open_bracket = "["; + for br in &breaks[1..] { + let br = AnyValue::Float64(*br); + buf.clear(); + write!(buf, "{open_bracket}{lower}, {br}]").unwrap(); + open_bracket = "("; + categories.append_value(buf.as_str()); + lower = br; + } + } + let categories = categories + .finish() + .cast(&DataType::Categorical(None, Default::default())) + .unwrap(); + fields.push(categories); + }; + + let count = Series::new(PlSmallStr::from_static("count"), count); + fields.push(count); + + Ok(if fields.len() == 1 { + fields.pop().unwrap().with_name(ca.name().clone()) + } else { + StructChunked::from_series(ca.name().clone(), fields[0].len(), fields.iter()) + .unwrap() + .into_series() + }) +} + +pub fn hist_series( + s: &Series, + bin_count: Option, + bins: Option, + include_category: bool, + include_breakpoint: bool, +) -> PolarsResult { + let mut bins_arg = None; + + let owned_bins; + if let Some(bins) = bins { + polars_ensure!(bins.null_count() == 0, InvalidOperation: "nulls not supported in 'bins' argument"); + let bins = bins.cast(&DataType::Float64)?; + let bins_s = bins.rechunk(); + owned_bins = bins_s; + let bins = owned_bins.f64().unwrap(); + let bins = bins.cont_slice().unwrap(); + bins_arg = Some(bins); + }; + polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "'hist' is only supported for numeric data"); + + let out = with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + compute_hist(ca, bin_count, bins_arg, include_category, include_breakpoint)? + }); + Ok(out) +} diff --git a/crates/polars-ops/src/chunked_array/list/any_all.rs b/crates/polars-ops/src/chunked_array/list/any_all.rs new file mode 100644 index 000000000000..2ba0d2aeb691 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/any_all.rs @@ -0,0 +1,51 @@ +use arrow::array::{BooleanArray, ListArray}; +use arrow::bitmap::Bitmap; + +use super::*; + +fn list_all_any(arr: &ListArray, op: F, is_all: bool) -> PolarsResult +where + F: Fn(&BooleanArray) -> bool, +{ + let offsets = arr.offsets().as_slice(); + let values = arr.values(); + + polars_ensure!(values.dtype() == &ArrowDataType::Boolean, ComputeError: "expected boolean elements in list"); + + let values = values.as_any().downcast_ref::().unwrap(); + let validity = arr.validity().cloned(); + + // Fast path where all values set (all is free). + if is_all { + let all_set = arrow::compute::boolean::all(values); + if all_set { + let bits = Bitmap::new_with_value(true, arr.len()); + return Ok(BooleanArray::from_data_default(bits, None).with_validity(validity)); + } + } + + let mut start = offsets[0] as usize; + let iter = offsets[1..].iter().map(|&end| { + let end = end as usize; + let len = end - start; + let val = unsafe { values.clone().sliced_unchecked(start, len) }; + start = end; + op(&val) + }); + + Ok(BooleanArray::from_trusted_len_values_iter(iter).with_validity(validity)) +} + +pub(super) fn list_all(ca: &ListChunked) -> PolarsResult { + let chunks = ca + .downcast_iter() + .map(|arr| list_all_any(arr, arrow::compute::boolean::all, true)); + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) +} + +pub(super) fn list_any(ca: &ListChunked) -> PolarsResult { + let chunks = ca + .downcast_iter() + .map(|arr| list_all_any(arr, arrow::compute::boolean::any, false)); + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) +} diff --git a/crates/polars-ops/src/chunked_array/list/count.rs b/crates/polars-ops/src/chunked_array/list/count.rs new file mode 100644 index 000000000000..7ff951c842ca --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/count.rs @@ -0,0 +1,63 @@ +use arrow::array::{Array, BooleanArray}; +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::count_zeros; +use arrow::legacy::utils::CustomIterTools; + +use super::*; + +fn count_bits_set_by_offsets(values: &Bitmap, offset: &[i64]) -> Vec { + // Fast path where all bits are either set or unset. + if values.unset_bits() == values.len() { + return vec![0 as IdxSize; offset.len() - 1]; + } else if values.unset_bits() == 0 { + let mut start = offset[0]; + let v = (offset[1..]) + .iter() + .map(|end| { + let current_offset = start; + start = *end; + (end - current_offset) as IdxSize + }) + .collect_trusted(); + return v; + } + + let (bits, bitmap_offset, _) = values.as_slice(); + + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + + let len = (end - current_offset) as usize; + + let set_ones = len - count_zeros(bits, bitmap_offset + current_offset as usize, len); + set_ones as IdxSize + }) + .collect_trusted() +} + +#[cfg(feature = "list_count")] +pub fn list_count_matches(ca: &ListChunked, value: AnyValue) -> PolarsResult { + let value = Series::new(PlSmallStr::EMPTY, [value]); + + let ca = ca.apply_to_inner(&|s| { + ChunkCompareEq::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) + })?; + let out = count_boolean_bits(&ca); + Ok(out.into_series()) +} + +pub(super) fn count_boolean_bits(ca: &ListChunked) -> IdxCa { + let chunks = ca.downcast_iter().map(|arr| { + let inner_arr = arr.values(); + let mask = inner_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(mask.null_count(), 0); + let out = count_bits_set_by_offsets(mask.values(), arr.offsets().as_slice()); + IdxArr::from_data_default(out.into(), arr.validity().cloned()) + }); + IdxCa::from_chunk_iter(ca.name().clone(), chunks) +} diff --git a/crates/polars-ops/src/chunked_array/list/dispersion.rs b/crates/polars-ops/src/chunked_array/list/dispersion.rs new file mode 100644 index 000000000000..e9ff069031e2 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/dispersion.rs @@ -0,0 +1,96 @@ +use super::*; + +pub(super) fn median_with_nulls(ca: &ListChunked) -> Series { + match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-datetime")] + DataType::Date => { + const MS_IN_DAY: i64 = 86_400_000; + let out: Int64Chunked = ca + .apply_amortized_generic(|s| { + s.and_then(|s| s.as_ref().median().map(|v| (v * (MS_IN_DAY as f64)) as i64)) + }) + .with_name(ca.name().clone()); + out.into_datetime(TimeUnit::Milliseconds, None) + .into_series() + }, + dt if dt.is_temporal() => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) + .with_name(ca.name().clone()); + out.cast(dt).unwrap() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median())) + .with_name(ca.name().clone()); + out.into_series() + }, + } +} + +pub(super) fn std_with_nulls(ca: &ListChunked, ddof: u8) -> Series { + match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(*tu).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof))) + .with_name(ca.name().clone()); + out.into_series() + }, + } +} + +pub(super) fn var_with_nulls(ca: &ListChunked, ddof: u8) -> Series { + match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Milliseconds) => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(TimeUnit::Microseconds | TimeUnit::Nanoseconds) => { + let out: Int64Chunked = ca + .cast(&DataType::List(Box::new(DataType::Duration( + TimeUnit::Milliseconds, + )))) + .unwrap() + .list() + .unwrap() + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) + .with_name(ca.name().clone()); + out.into_duration(TimeUnit::Milliseconds).into_series() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof))) + .with_name(ca.name().clone()); + out.into_series() + }, + } +} diff --git a/crates/polars-ops/src/chunked_array/list/hash.rs b/crates/polars-ops/src/chunked_array/list/hash.rs new file mode 100644 index 000000000000..3cf45b5246ec --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/hash.rs @@ -0,0 +1,89 @@ +use std::hash::{BuildHasher, Hash}; + +use polars_core::series::BitRepr; +use polars_core::utils::NoNull; +use polars_core::{POOL, with_match_physical_float_polars_type}; +use polars_utils::aliases::PlSeedableRandomStateQuality; +use polars_utils::hashing::_boost_hash_combine; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; +use rayon::prelude::*; + +use super::*; + +fn hash_agg(ca: &ChunkedArray, random_state: &PlSeedableRandomStateQuality) -> u64 +where + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, +{ + // Note that we don't use the no null branch! This can break in unexpected ways. + // for instance with threading we split an array in n_threads, this may lead to + // splits that have no nulls and splits that have nulls. Then one array is hashed with + // Option and the other array with T. + // Meaning that they cannot be compared. By always hashing on Option the random_state is + // the only deterministic seed. + + // just some large prime + let mut hash_agg = 9069731903u64; + + // just some large prime + let null_hash = 2413670057; + + ca.downcast_iter().for_each(|arr| { + for opt_v in arr.iter() { + match opt_v { + Some(v) => { + let r = random_state.hash_one(v.to_total_ord()); + hash_agg = _boost_hash_combine(hash_agg, r); + }, + None => { + hash_agg = _boost_hash_combine(hash_agg, null_hash); + }, + } + } + }); + hash_agg +} + +pub(crate) fn hash( + ca: &mut ListChunked, + build_hasher: PlSeedableRandomStateQuality, +) -> UInt64Chunked { + if !ca.inner_dtype().to_physical().is_primitive_numeric() { + panic!( + "Hashing a list with a non-numeric inner type not supported. Got dtype: {:?}", + ca.dtype() + ); + } + + // just some large prime + let null_hash = 1969099309u64; + + ca.set_inner_dtype(ca.inner_dtype().to_physical()); + + let out: NoNull = POOL.install(|| { + ca.par_iter() + .map(|opt_s: Option| match opt_s { + None => null_hash, + Some(s) => { + if s.dtype().is_float() { + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + hash_agg(ca, &build_hasher) + }) + } else { + match s.bit_repr() { + None => unimplemented!("Hash for lists without bit representation"), + Some(BitRepr::Small(ca)) => hash_agg(&ca, &build_hasher), + Some(BitRepr::Large(ca)) => hash_agg(&ca, &build_hasher), + } + } + }, + }) + .collect() + }); + + let mut out = out.into_inner(); + out.rename(ca.name().clone()); + out +} diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs new file mode 100644 index 000000000000..5d795ecaca7e --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -0,0 +1,224 @@ +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::compute::utils::combine_validities_and; +use arrow::types::NativeType; +use polars_compute::min_max::MinMaxKernel; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +use crate::chunked_array::list::namespace::has_inner_nulls; + +fn min_between_offsets(values: &[T], offset: &[i64]) -> PrimitiveArray +where + T: NativeType, + [T]: for<'a> MinMaxKernel = T>, +{ + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + if current_offset == *end { + return None; + } + + let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; + slice.min_ignore_nan_kernel() + }) + .collect() +} + +fn dispatch_min(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType, + [T]: for<'a> MinMaxKernel = T>, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + let out = min_between_offsets(values, offsets); + let new_validity = combine_validities_and(out.validity(), validity); + out.with_validity(new_validity).to_boxed() +} + +fn min_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_min::(values, offsets, arr.validity()), + Int16 => dispatch_min::(values, offsets, arr.validity()), + Int32 => dispatch_min::(values, offsets, arr.validity()), + Int64 => dispatch_min::(values, offsets, arr.validity()), + Int128 => dispatch_min::(values, offsets, arr.validity()), + UInt8 => dispatch_min::(values, offsets, arr.validity()), + UInt16 => dispatch_min::(values, offsets, arr.validity()), + UInt32 => dispatch_min::(values, offsets, arr.validity()), + UInt64 => dispatch_min::(values, offsets, arr.validity()), + Float32 => dispatch_min::(values, offsets, arr.validity()), + Float64 => dispatch_min::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name().clone(), chunks)).unwrap() +} + +pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { + fn inner(ca: &ListChunked) -> PolarsResult { + match ca.inner_dtype() { + DataType::Boolean => { + let out: BooleanChunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().bool().unwrap().min())); + Ok(out.into_series()) + }, + dt if dt.to_physical().is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dt.to_physical(), |$T| { + let out: ChunkedArray<$T> = ca.to_physical_repr().apply_amortized_generic(|opt_s| { + let s = opt_s?; + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + ca.min() + }); + // restore logical type + unsafe { out.into_series().from_physical_unchecked(dt) } + }) + }, + dt => ca + .try_apply_amortized(|s| { + let s = s.as_ref(); + let sc = s.min_reduce()?; + Ok(sc.into_series(s.name().clone())) + })? + .explode() + .unwrap() + .into_series() + .cast(dt), + } + } + + if has_inner_nulls(ca) { + return inner(ca); + }; + + match ca.inner_dtype() { + dt if dt.is_primitive_numeric() => Ok(min_list_numerical(ca, dt)), + _ => inner(ca), + } +} + +fn max_between_offsets(values: &[T], offset: &[i64]) -> PrimitiveArray +where + T: NativeType, + [T]: for<'a> MinMaxKernel = T>, +{ + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + if current_offset == *end { + return None; + } + + let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; + slice.max_ignore_nan_kernel() + }) + .collect() +} + +fn dispatch_max(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType, + [T]: for<'a> MinMaxKernel = T>, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + let mut out = max_between_offsets(values, offsets); + + if let Some(validity) = validity { + if out.null_count() > 0 { + out.apply_validity(|other_validity| validity & &other_validity) + } else { + out = out.with_validity(Some(validity.clone())); + } + } + Box::new(out) +} + +fn max_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_max::(values, offsets, arr.validity()), + Int16 => dispatch_max::(values, offsets, arr.validity()), + Int32 => dispatch_max::(values, offsets, arr.validity()), + Int64 => dispatch_max::(values, offsets, arr.validity()), + Int128 => dispatch_max::(values, offsets, arr.validity()), + UInt8 => dispatch_max::(values, offsets, arr.validity()), + UInt16 => dispatch_max::(values, offsets, arr.validity()), + UInt32 => dispatch_max::(values, offsets, arr.validity()), + UInt64 => dispatch_max::(values, offsets, arr.validity()), + Float32 => dispatch_max::(values, offsets, arr.validity()), + Float64 => dispatch_max::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name().clone(), chunks)).unwrap() +} + +pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { + fn inner(ca: &ListChunked) -> PolarsResult { + match ca.inner_dtype() { + DataType::Boolean => { + let out: BooleanChunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().bool().unwrap().max())); + Ok(out.into_series()) + }, + dt if dt.to_physical().is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dt.to_physical(), |$T| { + let out: ChunkedArray<$T> = ca.to_physical_repr().apply_amortized_generic(|opt_s| { + let s = opt_s?; + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + ca.max() + }); + // restore logical type + unsafe { out.into_series().from_physical_unchecked(dt) } + }) + }, + dt => ca + .try_apply_amortized(|s| { + let s = s.as_ref(); + let sc = s.max_reduce()?; + Ok(sc.into_series(s.name().clone())) + })? + .explode() + .unwrap() + .into_series() + .cast(dt), + } + } + + if has_inner_nulls(ca) { + return inner(ca); + }; + + match ca.inner_dtype() { + dt if dt.is_primitive_numeric() => Ok(max_list_numerical(ca, dt)), + _ => inner(ca), + } +} diff --git a/crates/polars-ops/src/chunked_array/list/mod.rs b/crates/polars-ops/src/chunked_array/list/mod.rs new file mode 100644 index 000000000000..a93b1ed7e2b3 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/mod.rs @@ -0,0 +1,35 @@ +use polars_core::prelude::*; + +#[cfg(feature = "list_any_all")] +mod any_all; +mod count; +mod dispersion; +#[cfg(feature = "hash")] +pub(crate) mod hash; +mod min_max; +mod namespace; +#[cfg(feature = "list_sets")] +mod sets; +mod sum_mean; +#[cfg(feature = "list_to_struct")] +mod to_struct; + +#[cfg(feature = "list_count")] +pub use count::*; +#[cfg(not(feature = "list_count"))] +use count::*; +pub use namespace::*; +#[cfg(feature = "list_sets")] +pub use sets::*; +#[cfg(feature = "list_to_struct")] +pub use to_struct::*; + +pub trait AsList { + fn as_list(&self) -> &ListChunked; +} + +impl AsList for ListChunked { + fn as_list(&self) -> &ListChunked { + self + } +} diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs new file mode 100644 index 000000000000..d1b5761ff934 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -0,0 +1,912 @@ +use std::borrow::Cow; +use std::fmt::Write; + +use arrow::array::ValueSize; +#[cfg(feature = "list_gather")] +use num_traits::ToPrimitive; +#[cfg(feature = "list_gather")] +use num_traits::{NumCast, Signed, Zero}; +use polars_compute::gather::sublist::list::{index_is_oob, sublist_get}; +use polars_core::chunked_array::builder::get_list_builder; +#[cfg(feature = "diff")] +use polars_core::series::ops::NullBehavior; +use polars_core::utils::try_get_supertype; + +use super::*; +#[cfg(feature = "list_any_all")] +use crate::chunked_array::list::any_all::*; +use crate::chunked_array::list::min_max::{list_max_function, list_min_function}; +use crate::chunked_array::list::sum_mean::sum_with_nulls; +#[cfg(feature = "diff")] +use crate::prelude::diff; +use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical}; +use crate::series::ArgAgg; + +pub(super) fn has_inner_nulls(ca: &ListChunked) -> bool { + for arr in ca.downcast_iter() { + if arr.values().null_count() > 0 { + return true; + } + } + false +} + +fn cast_rhs( + other: &mut [Column], + inner_type: &DataType, + dtype: &DataType, + length: usize, + allow_broadcast: bool, +) -> PolarsResult<()> { + for s in other.iter_mut() { + // make sure that inner types match before we coerce into list + if !matches!(s.dtype(), DataType::List(_)) { + *s = s.cast(inner_type)? + } + if !matches!(s.dtype(), DataType::List(_)) && s.dtype() == inner_type { + // coerce to list JIT + *s = s + .reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)]) + .unwrap(); + } + if s.dtype() != dtype { + *s = s.cast(dtype).map_err(|e| { + polars_err!( + SchemaMismatch: + "cannot concat `{}` into a list of `{}`: {}", + s.dtype(), + dtype, + e + ) + })?; + } + + if s.len() != length { + polars_ensure!( + s.len() == 1, + ShapeMismatch: "series length {} does not match expected length of {}", + s.len(), length + ); + if allow_broadcast { + // broadcast JIT + *s = s.new_from_index(0, length) + } + // else do nothing + } + } + Ok(()) +} + +pub trait ListNameSpaceImpl: AsList { + /// In case the inner dtype [`DataType::String`], the individual items will be joined into a + /// single string separated by `separator`. + fn lst_join( + &self, + separator: &StringChunked, + ignore_nulls: bool, + ) -> PolarsResult { + let ca = self.as_list(); + match ca.inner_dtype() { + DataType::String => match separator.len() { + 1 => match separator.get(0) { + Some(separator) => self.join_literal(separator, ignore_nulls), + _ => Ok(StringChunked::full_null(ca.name().clone(), ca.len())), + }, + _ => self.join_many(separator, ignore_nulls), + }, + dt => polars_bail!(op = "`lst.join`", got = dt, expected = "String"), + } + } + + fn join_literal(&self, separator: &str, ignore_nulls: bool) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + + ca.for_each_amortized(|opt_s| { + let opt_val = opt_s.and_then(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().str().unwrap(); + + if ca.null_count() != 0 && !ignore_nulls { + return None; + } + + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + } + + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + Some(&buf[..buf.len().saturating_sub(separator.len())]) + }); + builder.append_option(opt_val) + }); + Ok(builder.finish()) + } + + fn join_many( + &self, + separator: &StringChunked, + ignore_nulls: bool, + ) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + { + ca.amortized_iter() + .zip(separator) + .for_each(|(opt_s, opt_sep)| match opt_sep { + Some(separator) => { + let opt_val = opt_s.and_then(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().str().unwrap(); + + if ca.null_count() != 0 && !ignore_nulls { + return None; + } + + for arr in ca.downcast_iter() { + for val in arr.non_null_values_iter() { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + } + + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + Some(&buf[..buf.len().saturating_sub(separator.len())]) + }); + builder.append_option(opt_val) + }, + _ => builder.append_null(), + }) + } + Ok(builder.finish()) + } + + fn lst_max(&self) -> PolarsResult { + list_max_function(self.as_list()) + } + + #[cfg(feature = "list_any_all")] + fn lst_all(&self) -> PolarsResult { + let ca = self.as_list(); + list_all(ca) + } + + #[cfg(feature = "list_any_all")] + fn lst_any(&self) -> PolarsResult { + let ca = self.as_list(); + list_any(ca) + } + + fn lst_min(&self) -> PolarsResult { + list_min_function(self.as_list()) + } + + fn lst_sum(&self) -> PolarsResult { + let ca = self.as_list(); + + if has_inner_nulls(ca) { + return sum_with_nulls(ca, ca.inner_dtype()); + }; + + match ca.inner_dtype() { + DataType::Boolean => Ok(count_boolean_bits(ca).into_series()), + dt if dt.is_primitive_numeric() => Ok(sum_list_numerical(ca, dt)), + dt => sum_with_nulls(ca, dt), + } + } + + fn lst_mean(&self) -> Series { + let ca = self.as_list(); + + if has_inner_nulls(ca) { + return sum_mean::mean_with_nulls(ca); + }; + + match ca.inner_dtype() { + dt if dt.is_primitive_numeric() => mean_list_numerical(ca, dt), + _ => sum_mean::mean_with_nulls(ca), + } + } + + fn lst_median(&self) -> Series { + let ca = self.as_list(); + dispersion::median_with_nulls(ca) + } + + fn lst_std(&self, ddof: u8) -> Series { + let ca = self.as_list(); + dispersion::std_with_nulls(ca, ddof) + } + + fn lst_var(&self, ddof: u8) -> Series { + let ca = self.as_list(); + dispersion::var_with_nulls(ca, ddof) + } + + fn same_type(&self, out: ListChunked) -> ListChunked { + let ca = self.as_list(); + let dtype = ca.dtype(); + if out.dtype() != dtype { + out.cast(ca.dtype()).unwrap().list().unwrap().clone() + } else { + out + } + } + + fn lst_sort(&self, options: SortOptions) -> PolarsResult { + let ca = self.as_list(); + let out = ca.try_apply_amortized(|s| s.as_ref().sort_with(options))?; + Ok(self.same_type(out)) + } + + #[must_use] + fn lst_reverse(&self) -> ListChunked { + let ca = self.as_list(); + let out = ca.apply_amortized(|s| s.as_ref().reverse()); + self.same_type(out) + } + + fn lst_n_unique(&self) -> PolarsResult { + let ca = self.as_list(); + ca.try_apply_amortized_generic(|s| { + let opt_v = s.map(|s| s.as_ref().n_unique()).transpose()?; + Ok(opt_v.map(|idx| idx as IdxSize)) + }) + } + + fn lst_unique(&self) -> PolarsResult { + let ca = self.as_list(); + let out = ca.try_apply_amortized(|s| s.as_ref().unique())?; + Ok(self.same_type(out)) + } + + fn lst_unique_stable(&self) -> PolarsResult { + let ca = self.as_list(); + let out = ca.try_apply_amortized(|s| s.as_ref().unique_stable())?; + Ok(self.same_type(out)) + } + + fn lst_arg_min(&self) -> IdxCa { + let ca = self.as_list(); + ca.apply_amortized_generic(|opt_s| { + opt_s.and_then(|s| s.as_ref().arg_min().map(|idx| idx as IdxSize)) + }) + } + + fn lst_arg_max(&self) -> IdxCa { + let ca = self.as_list(); + ca.apply_amortized_generic(|opt_s| { + opt_s.and_then(|s| s.as_ref().arg_max().map(|idx| idx as IdxSize)) + }) + } + + #[cfg(feature = "diff")] + fn lst_diff(&self, n: i64, null_behavior: NullBehavior) -> PolarsResult { + let ca = self.as_list(); + ca.try_apply_amortized(|s| diff(s.as_ref(), n, null_behavior)) + } + + fn lst_shift(&self, periods: &Column) -> PolarsResult { + let ca = self.as_list(); + let periods_s = periods.cast(&DataType::Int64)?; + let periods = periods_s.i64()?; + + polars_ensure!( + ca.len() == periods.len() || ca.len() == 1 || periods.len() == 1, + length_mismatch = "list.shift", + ca.len(), + periods.len() + ); + + // Broadcast `self` + let mut ca = Cow::Borrowed(ca); + if ca.len() == 1 && periods.len() != 1 { + // Optimize: Don't broadcast and instead have a special path. + ca = Cow::Owned(ca.new_from_index(0, periods.len())); + } + let ca = ca.as_ref(); + + let out = match periods.len() { + 1 => { + if let Some(periods) = periods.get(0) { + ca.apply_amortized(|s| s.as_ref().shift(periods)) + } else { + ListChunked::full_null_with_dtype(ca.name().clone(), ca.len(), ca.inner_dtype()) + } + }, + _ => ca.zip_and_apply_amortized(periods, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(periods)) => Some(s.as_ref().shift(periods)), + _ => None, + } + }), + }; + Ok(self.same_type(out)) + } + + fn lst_slice(&self, offset: i64, length: usize) -> ListChunked { + let ca = self.as_list(); + let out = ca.apply_amortized(|s| s.as_ref().slice(offset, length)); + self.same_type(out) + } + + fn lst_lengths(&self) -> IdxCa { + let ca = self.as_list(); + + let ca_validity = ca.rechunk_validity(); + + if ca_validity.as_ref().is_some_and(|x| x.set_bits() == 0) { + return IdxCa::full_null(ca.name().clone(), ca.len()); + } + + let mut lengths = Vec::with_capacity(ca.len()); + ca.downcast_iter().for_each(|arr| { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + lengths.push((*o - last) as IdxSize); + last = *o; + } + }); + + let arr = IdxArr::from_vec(lengths).with_validity(ca_validity); + IdxCa::with_chunk(ca.name().clone(), arr) + } + + /// Get the value by index in the sublists. + /// So index `0` would return the first item of every sublist + /// and index `-1` would return the last item of every sublist + /// if an index is out of bounds, it will return a `None`. + fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult { + let ca = self.as_list(); + if !null_on_oob && ca.downcast_iter().any(|arr| index_is_oob(arr, idx)) { + polars_bail!(ComputeError: "get index is out of bounds"); + } + + let chunks = ca + .downcast_iter() + .map(|arr| sublist_get(arr, idx)) + .collect::>(); + + let s = Series::try_from((ca.name().clone(), chunks)).unwrap(); + // SAFETY: every element in list has dtype equal to its inner type + unsafe { s.from_physical_unchecked(ca.inner_dtype()) } + } + + #[cfg(feature = "list_gather")] + fn lst_gather_every(&self, n: &IdxCa, offset: &IdxCa) -> PolarsResult { + let list_ca = self.as_list(); + let out = match (n.len(), offset.len()) { + (1, 1) => match (n.get(0), offset.get(0)) { + (Some(n), Some(offset)) => list_ca.try_apply_amortized(|s| { + s.as_ref().gather_every(n as usize, offset as usize) + })?, + _ => ListChunked::full_null_with_dtype( + list_ca.name().clone(), + list_ca.len(), + list_ca.inner_dtype(), + ), + }, + (1, len_offset) if len_offset == list_ca.len() => { + if let Some(n) = n.get(0) { + list_ca.try_zip_and_apply_amortized(offset, |opt_s, opt_offset| { + match (opt_s, opt_offset) { + (Some(s), Some(offset)) => { + Ok(Some(s.as_ref().gather_every(n as usize, offset as usize)?)) + }, + _ => Ok(None), + } + })? + } else { + ListChunked::full_null_with_dtype( + list_ca.name().clone(), + list_ca.len(), + list_ca.inner_dtype(), + ) + } + }, + (len_n, 1) if len_n == list_ca.len() => { + if let Some(offset) = offset.get(0) { + list_ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) { + (Some(s), Some(n)) => { + Ok(Some(s.as_ref().gather_every(n as usize, offset as usize)?)) + }, + _ => Ok(None), + })? + } else { + ListChunked::full_null_with_dtype( + list_ca.name().clone(), + list_ca.len(), + list_ca.inner_dtype(), + ) + } + }, + (len_n, len_offset) if len_n == len_offset && len_n == list_ca.len() => list_ca + .try_binary_zip_and_apply_amortized( + n, + offset, + |opt_s, opt_n, opt_offset| match (opt_s, opt_n, opt_offset) { + (Some(s), Some(n), Some(offset)) => { + Ok(Some(s.as_ref().gather_every(n as usize, offset as usize)?)) + }, + _ => Ok(None), + }, + )?, + _ => { + polars_bail!(ComputeError: "The lengths of `n` and `offset` should be 1 or equal to the length of list.") + }, + }; + Ok(out.into_series()) + } + + #[cfg(feature = "list_gather")] + fn lst_gather(&self, idx: &Series, null_on_oob: bool) -> PolarsResult { + let list_ca = self.as_list(); + + let index_typed_index = |idx: &Series| { + let idx = idx.cast(&IDX_DTYPE).unwrap(); + { + list_ca + .amortized_iter() + .map(|s| { + s.map(|s| { + let s = s.as_ref(); + take_series(s, idx.clone(), null_on_oob) + }) + .transpose() + }) + .collect::>() + .map(|mut ca| { + ca.rename(list_ca.name().clone()); + ca.into_series() + }) + } + }; + + use DataType::*; + match idx.dtype() { + List(boxed_dt) if boxed_dt.is_integer() => { + let idx_ca = idx.list().unwrap(); + let mut out = { + list_ca + .amortized_iter() + .zip(idx_ca) + .map(|(opt_s, opt_idx)| { + { + match (opt_s, opt_idx) { + (Some(s), Some(idx)) => { + Some(take_series(s.as_ref(), idx, null_on_oob)) + }, + _ => None, + } + } + .transpose() + }) + .collect::>()? + }; + out.rename(list_ca.name().clone()); + + Ok(out.into_series()) + }, + UInt32 | UInt64 => index_typed_index(idx), + dt if dt.is_signed_integer() => { + if let Some(min) = idx.min::().unwrap() { + if min >= 0 { + index_typed_index(idx) + } else { + let mut out = { + list_ca + .amortized_iter() + .map(|opt_s| { + opt_s + .map(|s| take_series(s.as_ref(), idx.clone(), null_on_oob)) + .transpose() + }) + .collect::>()? + }; + out.rename(list_ca.name().clone()); + Ok(out.into_series()) + } + } else { + polars_bail!(ComputeError: "all indices are null"); + } + }, + dt => polars_bail!(ComputeError: "cannot use dtype `{}` as an index", dt), + } + } + + #[cfg(feature = "list_drop_nulls")] + fn lst_drop_nulls(&self) -> ListChunked { + let list_ca = self.as_list(); + + list_ca.apply_amortized(|s| s.as_ref().drop_nulls()) + } + + #[cfg(feature = "list_sample")] + fn lst_sample_n( + &self, + n: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + use std::borrow::Cow; + + let ca = self.as_list(); + + let n_s = n.cast(&IDX_DTYPE)?; + let n = n_s.idx()?; + + polars_ensure!( + ca.len() == n.len() || ca.len() == 1 || n.len() == 1, + length_mismatch = "list.sample(n)", + ca.len(), + n.len() + ); + + // Broadcast `self` + let mut ca = Cow::Borrowed(ca); + if ca.len() == 1 && n.len() != 1 { + // Optimize: Don't broadcast and instead have a special path. + ca = Cow::Owned(ca.new_from_index(0, n.len())); + } + let ca = ca.as_ref(); + + let out = match n.len() { + 1 => { + if let Some(n) = n.get(0) { + ca.try_apply_amortized(|s| { + s.as_ref() + .sample_n(n as usize, with_replacement, shuffle, seed) + }) + } else { + Ok(ListChunked::full_null_with_dtype( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + )) + } + }, + _ => ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) { + (Some(s), Some(n)) => s + .as_ref() + .sample_n(n as usize, with_replacement, shuffle, seed) + .map(Some), + _ => Ok(None), + }), + }; + out.map(|ok| self.same_type(ok)) + } + + #[cfg(feature = "list_sample")] + fn lst_sample_fraction( + &self, + fraction: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + use std::borrow::Cow; + + let ca = self.as_list(); + + let fraction_s = fraction.cast(&DataType::Float64)?; + let fraction = fraction_s.f64()?; + + polars_ensure!( + ca.len() == fraction.len() || ca.len() == 1 || fraction.len() == 1, + length_mismatch = "list.sample(fraction)", + ca.len(), + fraction.len() + ); + + // Broadcast `self` + let mut ca = Cow::Borrowed(ca); + if ca.len() == 1 && fraction.len() != 1 { + // Optimize: Don't broadcast and instead have a special path. + ca = Cow::Owned(ca.new_from_index(0, fraction.len())); + } + let ca = ca.as_ref(); + + let out = match fraction.len() { + 1 => { + if let Some(fraction) = fraction.get(0) { + ca.try_apply_amortized(|s| { + let n = (s.as_ref().len() as f64 * fraction) as usize; + s.as_ref().sample_n(n, with_replacement, shuffle, seed) + }) + } else { + Ok(ListChunked::full_null_with_dtype( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + )) + } + }, + _ => ca.try_zip_and_apply_amortized(fraction, |opt_s, opt_n| match (opt_s, opt_n) { + (Some(s), Some(fraction)) => { + let n = (s.as_ref().len() as f64 * fraction) as usize; + s.as_ref() + .sample_n(n, with_replacement, shuffle, seed) + .map(Some) + }, + _ => Ok(None), + }), + }; + out.map(|ok| self.same_type(ok)) + } + + fn lst_concat(&self, other: &[Column]) -> PolarsResult { + let ca = self.as_list(); + let other_len = other.len(); + let length = ca.len(); + let mut other = other.to_vec(); + let mut inner_super_type = ca.inner_dtype().clone(); + + for s in &other { + match s.dtype() { + DataType::List(inner_type) => { + inner_super_type = try_get_supertype(&inner_super_type, inner_type)?; + #[cfg(feature = "dtype-categorical")] + if matches!( + &inner_super_type, + DataType::Categorical(_, _) | DataType::Enum(_, _) + ) { + inner_super_type = merge_dtypes(&inner_super_type, inner_type)?; + } + }, + dt => { + inner_super_type = try_get_supertype(&inner_super_type, dt)?; + #[cfg(feature = "dtype-categorical")] + if matches!( + &inner_super_type, + DataType::Categorical(_, _) | DataType::Enum(_, _) + ) { + inner_super_type = merge_dtypes(&inner_super_type, dt)?; + } + }, + } + } + + // cast lhs + let dtype = &DataType::List(Box::new(inner_super_type.clone())); + let ca = ca.cast(dtype)?; + let ca = ca.list().unwrap(); + + // broadcasting path in case all unit length + // this path will not expand the series, so saves memory + let out = if other.iter().all(|s| s.len() == 1) && ca.len() != 1 { + cast_rhs(&mut other, &inner_super_type, dtype, length, false)?; + let to_append = other + .iter() + .flat_map(|s| { + let lst = s.list().unwrap(); + lst.get_as_series(0) + }) + .collect::>(); + // there was a None, so all values will be None + if to_append.len() != other_len { + return Ok(ListChunked::full_null_with_dtype( + ca.name().clone(), + length, + &inner_super_type, + )); + } + + let vals_size_other = other + .iter() + .map(|s| s.list().unwrap().get_values_size()) + .sum::(); + + let mut builder = get_list_builder( + &inner_super_type, + ca.get_values_size() + vals_size_other + 1, + length, + ca.name().clone(), + ); + ca.into_iter().for_each(|opt_s| { + let opt_s = opt_s.map(|mut s| { + for append in &to_append { + s.append(append).unwrap(); + } + match inner_super_type { + // structs don't have chunks, so we must first rechunk the underlying series + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => s = s.rechunk(), + // nothing + _ => {}, + } + s + }); + builder.append_opt_series(opt_s.as_ref()).unwrap(); + }); + builder.finish() + } else { + // normal path which may contain same length list or unit length lists + cast_rhs(&mut other, &inner_super_type, dtype, length, true)?; + + let vals_size_other = other + .iter() + .map(|s| s.list().unwrap().get_values_size()) + .sum::(); + let mut iters = Vec::with_capacity(other_len + 1); + + for s in other.iter_mut() { + iters.push(s.list()?.amortized_iter()) + } + let mut first_iter: Box>> = ca.into_iter(); + let mut builder = get_list_builder( + &inner_super_type, + ca.get_values_size() + vals_size_other + 1, + length, + ca.name().clone(), + ); + + for _ in 0..ca.len() { + let mut acc = match first_iter.next().unwrap() { + Some(s) => s, + None => { + builder.append_null(); + // make sure that the iterators advance before we continue + for it in &mut iters { + it.next().unwrap(); + } + continue; + }, + }; + + let mut has_nulls = false; + for it in &mut iters { + match it.next().unwrap() { + Some(s) => { + if !has_nulls { + acc.append(s.as_ref())?; + } + }, + None => { + has_nulls = true; + }, + } + } + if has_nulls { + builder.append_null(); + continue; + } + + match inner_super_type { + // structs don't have chunks, so we must first rechunk the underlying series + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => acc = acc.rechunk(), + // nothing + _ => {}, + } + builder.append_series(&acc).unwrap(); + } + builder.finish() + }; + Ok(out) + } +} + +impl ListNameSpaceImpl for ListChunked {} + +#[cfg(feature = "list_gather")] +fn take_series(s: &Series, idx: Series, null_on_oob: bool) -> PolarsResult { + let len = s.len(); + let idx = cast_index(idx, len, null_on_oob)?; + let idx = idx.idx().unwrap(); + s.take(idx) +} + +#[cfg(feature = "list_gather")] +fn cast_signed_index_ca(idx: &ChunkedArray, len: usize) -> Series +where + T::Native: Copy + PartialOrd + PartialEq + NumCast + Signed + Zero, +{ + idx.iter() + .map(|opt_idx| opt_idx.and_then(|idx| idx.negative_to_usize(len).map(|idx| idx as IdxSize))) + .collect::() + .into_series() +} + +#[cfg(feature = "list_gather")] +fn cast_unsigned_index_ca(idx: &ChunkedArray, len: usize) -> Series +where + T::Native: Copy + PartialOrd + ToPrimitive, +{ + idx.iter() + .map(|opt_idx| { + opt_idx.and_then(|idx| { + let idx = idx.to_usize().unwrap(); + if idx >= len { + None + } else { + Some(idx as IdxSize) + } + }) + }) + .collect::() + .into_series() +} + +#[cfg(feature = "list_gather")] +fn cast_index(idx: Series, len: usize, null_on_oob: bool) -> PolarsResult { + let idx_null_count = idx.null_count(); + use DataType::*; + let out = match idx.dtype() { + #[cfg(feature = "big_idx")] + UInt32 => { + if null_on_oob { + let a = idx.u32().unwrap(); + cast_unsigned_index_ca(a, len) + } else { + idx.cast(&IDX_DTYPE).unwrap() + } + }, + #[cfg(feature = "big_idx")] + UInt64 => { + if null_on_oob { + let a = idx.u64().unwrap(); + cast_unsigned_index_ca(a, len) + } else { + idx + } + }, + #[cfg(not(feature = "big_idx"))] + UInt64 => { + if null_on_oob { + let a = idx.u64().unwrap(); + cast_unsigned_index_ca(a, len) + } else { + idx.cast(&IDX_DTYPE).unwrap() + } + }, + #[cfg(not(feature = "big_idx"))] + UInt32 => { + if null_on_oob { + let a = idx.u32().unwrap(); + cast_unsigned_index_ca(a, len) + } else { + idx + } + }, + dt if dt.is_unsigned_integer() => idx.cast(&IDX_DTYPE).unwrap(), + Int8 => { + let a = idx.i8().unwrap(); + cast_signed_index_ca(a, len) + }, + Int16 => { + let a = idx.i16().unwrap(); + cast_signed_index_ca(a, len) + }, + Int32 => { + let a = idx.i32().unwrap(); + cast_signed_index_ca(a, len) + }, + Int64 => { + let a = idx.i64().unwrap(); + cast_signed_index_ca(a, len) + }, + _ => { + unreachable!() + }, + }; + polars_ensure!( + out.null_count() == idx_null_count || null_on_oob, + OutOfBounds: "gather indices are out of bounds" + ); + Ok(out) +} + +// TODO: implement the above for ArrayChunked as well? diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs new file mode 100644 index 000000000000..1f2b42052234 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -0,0 +1,448 @@ +use std::fmt::{Display, Formatter}; +use std::hash::Hash; + +use arrow::array::{ + Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray, + PrimitiveArray, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::compute::utils::combine_validities_and; +use arrow::offset::OffsetsBuffer; +use arrow::types::NativeType; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +trait MaterializeValues { + // extends the iterator to the values and returns the current offset + fn extend_buf>(&mut self, values: I) -> usize; +} + +impl MaterializeValues> for MutablePrimitiveArray +where + T: NativeType, +{ + fn extend_buf>>(&mut self, values: I) -> usize { + self.extend(values); + self.len() + } +} + +impl MaterializeValues>> for MutablePrimitiveArray +where + T: NativeType, +{ + fn extend_buf>>>(&mut self, values: I) -> usize { + self.extend(values.map(|x| x.0)); + self.len() + } +} + +impl<'a> MaterializeValues> for MutablePlBinary { + fn extend_buf>>(&mut self, values: I) -> usize { + self.extend(values); + self.len() + } +} + +fn set_operation( + set: &mut PlIndexSet, + set2: &mut PlIndexSet, + a: I, + b: J, + out: &mut R, + set_op: SetOperation, + broadcast_rhs: bool, +) -> usize +where + K: Eq + Hash + Copy, + I: IntoIterator, + J: IntoIterator, + R: MaterializeValues, +{ + set.clear(); + let a = a.into_iter(); + let b = b.into_iter(); + + match set_op { + SetOperation::Intersection => { + set.extend(a); + // If broadcast `set2` should already be filled. + if !broadcast_rhs { + set2.clear(); + set2.extend(b); + } + out.extend_buf(set.intersection(set2).copied()) + }, + SetOperation::Union => { + set.extend(a); + set.extend(b); + out.extend_buf(set.drain(..)) + }, + SetOperation::Difference => { + set.extend(a); + for v in b { + set.swap_remove(&v); + } + out.extend_buf(set.drain(..)) + }, + SetOperation::SymmetricDifference => { + // If broadcast `set2` should already be filled. + if !broadcast_rhs { + set2.clear(); + set2.extend(b); + } + // We could speed this up, but implementing ourselves, but we need to have a cloneable + // iterator as we need 2 passes + set.extend(a); + out.extend_buf(set.symmetric_difference(set2).copied()) + }, + } +} + +fn copied_wrapper_opt( + v: Option<&T>, +) -> as ToTotalOrd>::TotalOrdItem { + v.copied().to_total_ord() +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum SetOperation { + Intersection, + Union, + Difference, + SymmetricDifference, +} + +impl Display for SetOperation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + SetOperation::Intersection => "intersection", + SetOperation::Union => "union", + SetOperation::Difference => "difference", + SetOperation::SymmetricDifference => "symmetric_difference", + }; + write!(f, "{s}") + } +} + +fn primitive( + a: &PrimitiveArray, + b: &PrimitiveArray, + offsets_a: &[i64], + offsets_b: &[i64], + set_op: SetOperation, + validity: Option, +) -> PolarsResult> +where + T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, +{ + let broadcast_lhs = offsets_a.len() == 2; + let broadcast_rhs = offsets_b.len() == 2; + + let mut set = Default::default(); + let mut set2: PlIndexSet< as ToTotalOrd>::TotalOrdItem> = Default::default(); + + let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max( + *offsets_a.last().unwrap(), + *offsets_b.last().unwrap(), + ) as usize); + let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len())); + offsets.push(0i64); + + let offsets_slice = if offsets_a.len() > offsets_b.len() { + offsets_a + } else { + offsets_b + }; + let first_a = offsets_a[0]; + let second_a = offsets_a[1]; + let first_b = offsets_b[0]; + let second_b = offsets_b[1]; + if broadcast_rhs { + set2.extend( + b.into_iter() + .skip(first_b as usize) + .take(second_b as usize - first_b as usize) + .map(copied_wrapper_opt), + ); + } + for i in 1..offsets_slice.len() { + // If we go OOB we take the first element as we are then broadcasting. + let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize; + let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize; + + let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize; + let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize; + + // The branches are the same every loop. + // We rely on branch prediction here. + let offset = if broadcast_rhs { + // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount + let a_iter = a + .into_iter() + .skip(start_a) + .take(end_a - start_a) + .map(copied_wrapper_opt); + let b_iter = b + .into_iter() + .skip(first_b as usize) + .take(second_b as usize - first_b as usize) + .map(copied_wrapper_opt); + set_operation( + &mut set, + &mut set2, + a_iter, + b_iter, + &mut values_out, + set_op, + true, + ) + } else if broadcast_lhs { + let a_iter = a + .into_iter() + .skip(first_a as usize) + .take(second_a as usize - first_a as usize) + .map(copied_wrapper_opt); + + let b_iter = b + .into_iter() + .skip(start_b) + .take(end_b - start_b) + .map(copied_wrapper_opt); + + set_operation( + &mut set, + &mut set2, + a_iter, + b_iter, + &mut values_out, + set_op, + false, + ) + } else { + // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount + let a_iter = a + .into_iter() + .skip(start_a) + .take(end_a - start_a) + .map(copied_wrapper_opt); + + let b_iter = b + .into_iter() + .skip(start_b) + .take(end_b - start_b) + .map(copied_wrapper_opt); + set_operation( + &mut set, + &mut set2, + a_iter, + b_iter, + &mut values_out, + set_op, + false, + ) + }; + + offsets.push(offset as i64); + } + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + let dtype = ListArray::::default_datatype(values_out.dtype().clone()); + + let values: PrimitiveArray = values_out.into(); + Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) +} + +fn binary( + a: &BinaryViewArray, + b: &BinaryViewArray, + offsets_a: &[i64], + offsets_b: &[i64], + set_op: SetOperation, + validity: Option, + as_utf8: bool, +) -> PolarsResult> { + let broadcast_lhs = offsets_a.len() == 2; + let broadcast_rhs = offsets_b.len() == 2; + let mut set = Default::default(); + let mut set2: PlIndexSet> = Default::default(); + + let mut values_out = MutablePlBinary::with_capacity(std::cmp::max( + *offsets_a.last().unwrap(), + *offsets_b.last().unwrap(), + ) as usize); + let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len())); + offsets.push(0i64); + + if broadcast_rhs { + set2.extend(b); + } + let offsets_slice = if offsets_a.len() > offsets_b.len() { + offsets_a + } else { + offsets_b + }; + let first_a = offsets_a[0]; + let second_a = offsets_a[1]; + let first_b = offsets_b[0]; + let second_b = offsets_b[1]; + for i in 1..offsets_slice.len() { + // If we go OOB we take the first element as we are then broadcasting. + let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize; + let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize; + + let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize; + let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize; + + // The branches are the same every loop. + // We rely on branch prediction here. + let offset = if broadcast_rhs { + // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount + let a_iter = a.into_iter().skip(start_a).take(end_a - start_a); + let b_iter = b + .into_iter() + .skip(first_b as usize) + .take(second_b as usize - first_b as usize); + set_operation( + &mut set, + &mut set2, + a_iter, + b_iter, + &mut values_out, + set_op, + true, + ) + } else if broadcast_lhs { + let a_iter = a + .into_iter() + .skip(first_a as usize) + .take(second_a as usize - first_a as usize); + let b_iter = b.into_iter().skip(start_b).take(end_b - start_b); + set_operation( + &mut set, + &mut set2, + a_iter, + b_iter, + &mut values_out, + set_op, + false, + ) + } else { + // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount + let a_iter = a.into_iter().skip(start_a).take(end_a - start_a); + let b_iter = b.into_iter().skip(start_b).take(end_b - start_b); + set_operation( + &mut set, + &mut set2, + a_iter, + b_iter, + &mut values_out, + set_op, + false, + ) + }; + offsets.push(offset as i64); + } + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + let values = values_out.freeze(); + + if as_utf8 { + let values = unsafe { values.to_utf8view_unchecked() }; + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) + } else { + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) + } +} + +fn array_set_operation( + a: &ListArray, + b: &ListArray, + set_op: SetOperation, +) -> PolarsResult> { + let offsets_a = a.offsets().as_slice(); + let offsets_b = b.offsets().as_slice(); + + let values_a = a.values(); + let values_b = b.values(); + assert_eq!(values_a.dtype(), values_b.dtype()); + + let dtype = values_b.dtype(); + let validity = combine_validities_and(a.validity(), b.validity()); + + match dtype { + ArrowDataType::Utf8View => { + let a = values_a + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(); + let b = values_b + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(); + + binary(&a, &b, offsets_a, offsets_b, set_op, validity, true) + }, + ArrowDataType::BinaryView => { + let a = values_a.as_any().downcast_ref::().unwrap(); + let b = values_b.as_any().downcast_ref::().unwrap(); + binary(a, b, offsets_a, offsets_b, set_op, validity, false) + }, + ArrowDataType::Boolean => { + polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations") + }, + _ => { + with_match_physical_numeric_type!(DataType::from_arrow_dtype(dtype), |$T| { + let a = values_a.as_any().downcast_ref::>().unwrap(); + let b = values_b.as_any().downcast_ref::>().unwrap(); + + primitive(&a, &b, offsets_a, offsets_b, set_op, validity) + }) + }, + } +} + +pub fn list_set_operation( + a: &ListChunked, + b: &ListChunked, + set_op: SetOperation, +) -> PolarsResult { + polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match"); + polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype()); + let mut a = a.clone(); + let mut b = b.clone(); + if a.len() != b.len() { + a.rechunk_mut(); + b.rechunk_mut(); + } + + // We will OOB in the kernel otherwise. + a.prune_empty_chunks(); + b.prune_empty_chunks(); + + // Make categoricals compatible + #[cfg(feature = "dtype-categorical")] + if let (DataType::Categorical(_, _), DataType::Categorical(_, _)) = + (&a.inner_dtype(), &b.inner_dtype()) + { + (a, b) = make_rhs_list_categoricals_compatible(a, b)?; + } + + // we use the unsafe variant because we want to keep the nested logical types type. + unsafe { + arity::try_binary_unchecked_same_type( + &a, + &b, + |a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()), + false, + false, + ) + } +} diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs new file mode 100644 index 000000000000..8d4f2fa025c1 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -0,0 +1,212 @@ +use std::ops::Div; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::compute::utils::combine_validities_and; +use arrow::types::NativeType; +use num_traits::{NumCast, ToPrimitive}; + +use super::*; +use crate::chunked_array::sum::sum_slice; + +fn sum_between_offsets(values: &[T], offset: &[i64]) -> Vec +where + T: NativeType + ToPrimitive, + S: NumCast + std::iter::Sum, +{ + offset + .windows(2) + .map(|w| { + values + .get(w[0] as usize..w[1] as usize) + .map(sum_slice) + .unwrap_or(S::from(0).unwrap()) + }) + .collect() +} + +fn dispatch_sum(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + ToPrimitive, + S: NativeType + NumCast + std::iter::Sum, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + Box::new(PrimitiveArray::from_data_default( + sum_between_offsets::<_, S>(values, offsets).into(), + validity.cloned(), + )) as ArrayRef +} + +pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_sum::(values, offsets, arr.validity()), + Int16 => dispatch_sum::(values, offsets, arr.validity()), + Int32 => dispatch_sum::(values, offsets, arr.validity()), + Int64 => dispatch_sum::(values, offsets, arr.validity()), + Int128 => dispatch_sum::(values, offsets, arr.validity()), + UInt8 => dispatch_sum::(values, offsets, arr.validity()), + UInt16 => dispatch_sum::(values, offsets, arr.validity()), + UInt32 => dispatch_sum::(values, offsets, arr.validity()), + UInt64 => dispatch_sum::(values, offsets, arr.validity()), + Float32 => dispatch_sum::(values, offsets, arr.validity()), + Float64 => dispatch_sum::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name().clone(), chunks)).unwrap() +} + +pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> PolarsResult { + use DataType::*; + // TODO: add fast path for smaller ints? + let mut out = match inner_dtype { + Boolean => { + let out: IdxCa = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + UInt32 => { + let out: UInt32Chunked = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + UInt64 => { + let out: UInt64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + Int32 => { + let out: Int32Chunked = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + Int64 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + Float32 => { + let out: Float32Chunked = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + Float64 => { + let out: Float64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::().unwrap())); + out.into_series() + }, + // slowest sum_as_series path + dt => ca + .try_apply_amortized(|s| { + s.as_ref() + .sum_reduce() + .map(|sc| sc.into_series(PlSmallStr::EMPTY)) + })? + .explode() + .unwrap() + .into_series() + .cast(dt)?, + }; + out.rename(ca.name().clone()); + Ok(out) +} + +fn mean_between_offsets(values: &[T], offset: &[i64]) -> PrimitiveArray +where + T: NativeType + ToPrimitive, + S: NativeType + NumCast + std::iter::Sum + Div, +{ + offset + .windows(2) + .map(|w| { + values + .get(w[0] as usize..w[1] as usize) + .filter(|sl| !sl.is_empty()) + .map(|sl| sum_slice::<_, S>(sl) / NumCast::from(sl.len()).unwrap()) + }) + .collect() +} + +fn dispatch_mean(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + ToPrimitive, + S: NativeType + NumCast + std::iter::Sum + Div, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + let out = mean_between_offsets::<_, S>(values, offsets); + let new_validity = combine_validities_and(out.validity(), validity); + out.with_validity(new_validity).to_boxed() +} + +pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_mean::(values, offsets, arr.validity()), + Int16 => dispatch_mean::(values, offsets, arr.validity()), + Int32 => dispatch_mean::(values, offsets, arr.validity()), + Int64 => dispatch_mean::(values, offsets, arr.validity()), + Int128 => dispatch_mean::(values, offsets, arr.validity()), + UInt8 => dispatch_mean::(values, offsets, arr.validity()), + UInt16 => dispatch_mean::(values, offsets, arr.validity()), + UInt32 => dispatch_mean::(values, offsets, arr.validity()), + UInt64 => dispatch_mean::(values, offsets, arr.validity()), + Float32 => dispatch_mean::(values, offsets, arr.validity()), + Float64 => dispatch_mean::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name().clone(), chunks)).unwrap() +} + +pub(super) fn mean_with_nulls(ca: &ListChunked) -> Series { + match ca.inner_dtype() { + DataType::Float32 => { + let out: Float32Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32))) + .with_name(ca.name().clone()); + out.into_series() + }, + #[cfg(feature = "dtype-datetime")] + DataType::Date => { + const MS_IN_DAY: i64 = 86_400_000; + let out: Int64Chunked = ca + .apply_amortized_generic(|s| { + s.and_then(|s| s.as_ref().mean().map(|v| (v * (MS_IN_DAY as f64)) as i64)) + }) + .with_name(ca.name().clone()); + out.into_datetime(TimeUnit::Milliseconds, None) + .into_series() + }, + dt if dt.is_temporal() => { + let out: Int64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as i64))) + .with_name(ca.name().clone()); + out.cast(dt).unwrap() + }, + _ => { + let out: Float64Chunked = ca + .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean())) + .with_name(ca.name().clone()); + out.into_series() + }, + } +} diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs new file mode 100644 index 000000000000..4b0ccde25a45 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -0,0 +1,224 @@ +use polars_core::POOL; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; +use rayon::prelude::*; + +use super::*; + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ListToStructArgs { + FixedWidth(Arc<[PlSmallStr]>), + InferWidth { + infer_field_strategy: ListToStructWidthStrategy, + get_index_name: Option, + /// If this is 0, it means unbounded. + max_fields: usize, + }, +} + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ListToStructWidthStrategy { + FirstNonNull, + MaxWidth, +} + +impl ListToStructArgs { + pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult { + let DataType::List(inner_dtype) = input_dtype else { + polars_bail!( + InvalidOperation: + "attempted list to_struct on non-list dtype: {}", + input_dtype + ); + }; + let inner_dtype = inner_dtype.as_ref(); + + match self { + Self::FixedWidth(names) => Ok(DataType::Struct( + names + .iter() + .map(|x| Field::new(x.clone(), inner_dtype.clone())) + .collect::>(), + )), + Self::InferWidth { + get_index_name, + max_fields, + .. + } if *max_fields > 0 => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + Ok(DataType::Struct( + (0..*max_fields) + .map(|i| Field::new(get_index_name_func(i), inner_dtype.clone())) + .collect::>(), + )) + }, + Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)), + } + } + + fn det_n_fields(&self, ca: &ListChunked) -> usize { + match self { + Self::FixedWidth(v) => v.len(), + Self::InferWidth { + infer_field_strategy, + max_fields, + .. + } => { + let inferred = match infer_field_strategy { + ListToStructWidthStrategy::MaxWidth => { + let mut max = 0; + + ca.downcast_iter().for_each(|arr| { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + let len = (*o - last) as usize; + max = std::cmp::max(max, len); + last = *o; + } + }); + max + }, + ListToStructWidthStrategy::FirstNonNull => { + let mut len = 0; + for arr in ca.downcast_iter() { + let offsets = arr.offsets().as_slice(); + let mut last = offsets[0]; + for o in &offsets[1..] { + len = (*o - last) as usize; + if len > 0 { + break; + } + last = *o; + } + if len > 0 { + break; + } + } + len + }, + }; + + if *max_fields > 0 { + inferred.min(*max_fields) + } else { + inferred + } + }, + } + } + + fn set_output_names(&self, columns: &mut [Series]) { + match self { + Self::FixedWidth(v) => { + assert_eq!(columns.len(), v.len()); + + for (c, name) in columns.iter_mut().zip(v.iter()) { + c.rename(name.clone()); + } + }, + Self::InferWidth { get_index_name, .. } => { + let get_index_name_func = get_index_name.as_ref().map_or( + &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, + |x| x.0.as_ref(), + ); + + for (i, c) in columns.iter_mut().enumerate() { + c.rename(get_index_name_func(i)); + } + }, + } + } +} + +#[derive(Clone)] +pub struct NameGenerator(pub Arc PlSmallStr + Send + Sync>); + +impl NameGenerator { + pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self { + Self(Arc::new(func)) + } +} + +impl std::fmt::Debug for NameGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "list::to_struct::NameGenerator function at 0x{:016x}", + self.0.as_ref() as *const _ as *const () as usize + ) + } +} + +impl Eq for NameGenerator {} + +impl PartialEq for NameGenerator { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl std::hash::Hash for NameGenerator { + fn hash(&self, state: &mut H) { + state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) + } +} + +pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr { + format_pl_smallstr!("field_{idx}") +} + +pub trait ToStruct: AsList { + fn to_struct(&self, args: &ListToStructArgs) -> PolarsResult { + let ca = self.as_list(); + let n_fields = args.det_n_fields(ca); + + let mut fields = POOL.install(|| { + (0..n_fields) + .into_par_iter() + .map(|i| ca.lst_get(i as i64, true)) + .collect::>>() + })?; + + args.set_output_names(&mut fields); + + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) + } +} + +impl ToStruct for ListChunked {} + +#[cfg(feature = "serde")] +mod _serde_impl { + use super::*; + + impl serde::Serialize for NameGenerator { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + Err(S::Error::custom( + "cannot serialize name generator function for to_struct, \ + consider passing a list of field names instead.", + )) + } + } + + impl<'de> serde::Deserialize<'de> for NameGenerator { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + Err(D::Error::custom( + "invalid data: attempted to deserialize list::to_struct::NameGenerator", + )) + } + } +} diff --git a/crates/polars-ops/src/chunked_array/mod.rs b/crates/polars-ops/src/chunked_array/mod.rs new file mode 100644 index 000000000000..c0f1941a90dc --- /dev/null +++ b/crates/polars-ops/src/chunked_array/mod.rs @@ -0,0 +1,48 @@ +#[cfg(feature = "dtype-array")] +pub mod array; +mod binary; +#[cfg(feature = "timezones")] +pub mod datetime; +pub mod list; +#[cfg(feature = "propagate_nans")] +pub mod nan_propagating_aggregate; +#[cfg(feature = "peaks")] +pub mod peaks; +mod scatter; +pub mod strings; +mod sum; +#[cfg(feature = "top_k")] +mod top_k; + +#[cfg(feature = "mode")] +pub mod mode; + +#[cfg(feature = "cov")] +pub mod cov; +pub(crate) mod gather; +#[cfg(feature = "gather")] +pub mod gather_skip_nulls; +#[cfg(feature = "hist")] +mod hist; +#[cfg(feature = "repeat_by")] +mod repeat_by; + +pub use binary::*; +#[cfg(feature = "timezones")] +pub use datetime::*; +#[cfg(feature = "chunked_ids")] +pub use gather::*; +#[cfg(feature = "hist")] +pub use hist::*; +pub use list::*; +#[allow(unused_imports)] +use polars_core::prelude::*; +#[cfg(feature = "repeat_by")] +pub use repeat_by::*; +pub use scatter::ChunkedSet; +pub use strings::*; +#[cfg(feature = "top_k")] +pub use top_k::*; + +#[allow(unused_imports)] +use crate::prelude::*; diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs new file mode 100644 index 000000000000..80aedcefc1e1 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -0,0 +1,37 @@ +use polars_core::POOL; +use polars_core::prelude::*; + +fn mode_indices(groups: GroupsType) -> Vec { + match groups { + GroupsType::Idx(groups) => { + let Some(max_len) = groups.iter().map(|g| g.1.len()).max() else { + return Vec::new(); + }; + groups + .into_iter() + .filter(|g| g.1.len() == max_len) + .map(|g| g.0) + .collect() + }, + GroupsType::Slice { groups, .. } => { + let Some(max_len) = groups.iter().map(|g| g[1]).max() else { + return Vec::new(); + }; + groups + .into_iter() + .filter(|g| g[1] == max_len) + .map(|g| g[0]) + .collect() + }, + } +} + +pub fn mode(s: &Series) -> PolarsResult { + let parallel = !POOL.current_thread_has_pending_tasks().unwrap_or(false); + let groups = s.group_tuples(parallel, false).unwrap(); + let idx = mode_indices(groups); + let idx = IdxCa::from_vec("".into(), idx); + // SAFETY: + // group indices are in bounds + Ok(unsafe { s.take_unchecked(&idx) }) +} diff --git a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs new file mode 100644 index 000000000000..55eaf3913e5b --- /dev/null +++ b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs @@ -0,0 +1,230 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::Array; +use arrow::legacy::kernels::take_agg::{ + take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked, +}; +use polars_compute::rolling; +use polars_compute::rolling::no_nulls::{MaxWindow, MinWindow}; +use polars_core::frame::group_by::aggregations::{ + _agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls, + _rolling_apply_agg_window_nulls, _slice_from_offsets, _use_rolling_kernels, +}; +use polars_core::prelude::*; +use polars_utils::min_max::MinMax; + +pub fn ca_nan_agg(ca: &ChunkedArray, min_or_max_fn: Agg) -> Option +where + T: PolarsFloatType, + Agg: Fn(T::Native, T::Native) -> T::Native + Copy, +{ + ca.downcast_iter() + .filter_map(|arr| { + if arr.null_count() == 0 { + arr.values().iter().copied().reduce(min_or_max_fn) + } else { + arr.iter() + .unwrap_optional() + .filter_map(|opt| opt.copied()) + .reduce(min_or_max_fn) + } + }) + .reduce(min_or_max_fn) +} + +pub fn nan_min_s(s: &Series, name: PlSmallStr) -> Series { + match s.dtype() { + DataType::Float32 => { + let ca = s.f32().unwrap(); + Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)]) + }, + DataType::Float64 => { + let ca = s.f64().unwrap(); + Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)]) + }, + _ => panic!("expected float"), + } +} + +pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series { + match s.dtype() { + DataType::Float32 => { + let ca = s.f32().unwrap(); + Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)]) + }, + DataType::Float64 => { + let ca = s.f64().unwrap(); + Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)]) + }, + _ => panic!("expected float"), + } +} + +unsafe fn group_nan_max(ca: &ChunkedArray, groups: &GroupsType) -> Series +where + T: PolarsFloatType, + ChunkedArray: IntoSeries, +{ + match groups { + GroupsType::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + ca.get(first as usize) + } else { + match (ca.has_nulls(), ca.chunks().len()) { + (false, 1) => take_agg_no_null_primitive_iter_unchecked( + ca.downcast_iter().next().unwrap(), + idx.iter().map(|i| *i as usize), + MinMax::max_propagate_nan, + ), + (_, 1) => take_agg_primitive_iter_unchecked( + ca.downcast_iter().next().unwrap(), + idx.iter().map(|i| *i as usize), + MinMax::max_propagate_nan, + ), + _ => { + let take = { ca.take_unchecked(idx) }; + ca_nan_agg(&take, MinMax::max_propagate_nan) + }, + } + } + }), + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, ca.chunks()) { + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + None, + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MaxWindow<_>, + _, + _, + >(values, validity, offset_iter, None), + }; + ChunkedArray::from(arr).into_series() + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(len <= ca.len() as IdxSize); + match len { + 0 => None, + 1 => ca.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(ca, first, len); + ca_nan_agg(&arr_group, MinMax::max_propagate_nan) + }, + } + }) + } + }, + } +} + +unsafe fn group_nan_min(ca: &ChunkedArray, groups: &GroupsType) -> Series +where + T: PolarsFloatType, + ChunkedArray: IntoSeries, +{ + match groups { + GroupsType::Idx(groups) => _agg_helper_idx::(groups, |(first, idx)| { + debug_assert!(idx.len() <= ca.len()); + if idx.is_empty() { + None + } else if idx.len() == 1 { + ca.get(first as usize) + } else { + match (ca.has_nulls(), ca.chunks().len()) { + (false, 1) => take_agg_no_null_primitive_iter_unchecked( + ca.downcast_iter().next().unwrap(), + idx.iter().map(|i| *i as usize), + MinMax::min_propagate_nan, + ), + (_, 1) => take_agg_primitive_iter_unchecked( + ca.downcast_iter().next().unwrap(), + idx.iter().map(|i| *i as usize), + MinMax::min_propagate_nan, + ), + _ => { + let take = { ca.take_unchecked(idx) }; + ca_nan_agg(&take, MinMax::min_propagate_nan) + }, + } + } + }), + GroupsType::Slice { + groups: groups_slice, + .. + } => { + if _use_rolling_kernels(groups_slice, ca.chunks()) { + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len)); + let arr = match arr.validity() { + None => _rolling_apply_agg_window_no_nulls::, _, _>( + values, + offset_iter, + None, + ), + Some(validity) => _rolling_apply_agg_window_nulls::< + rolling::nulls::MinWindow<_>, + _, + _, + >(values, validity, offset_iter, None), + }; + ChunkedArray::from(arr).into_series() + } else { + _agg_helper_slice::(groups_slice, |[first, len]| { + debug_assert!(len <= ca.len() as IdxSize); + match len { + 0 => None, + 1 => ca.get(first as usize), + _ => { + let arr_group = _slice_from_offsets(ca, first, len); + ca_nan_agg(&arr_group, MinMax::min_propagate_nan) + }, + } + }) + } + }, + } +} + +/// # Safety +/// `groups` must be in bounds. +pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsType) -> Series { + match s.dtype() { + DataType::Float32 => { + let ca = s.f32().unwrap(); + group_nan_min(ca, groups) + }, + DataType::Float64 => { + let ca = s.f64().unwrap(); + group_nan_min(ca, groups) + }, + _ => panic!("expected float"), + } +} + +/// # Safety +/// `groups` must be in bounds. +pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsType) -> Series { + match s.dtype() { + DataType::Float32 => { + let ca = s.f32().unwrap(); + group_nan_max(ca, groups) + }, + DataType::Float64 => { + let ca = s.f64().unwrap(); + group_nan_max(ca, groups) + }, + _ => panic!("expected float"), + } +} diff --git a/crates/polars-ops/src/chunked_array/peaks.rs b/crates/polars-ops/src/chunked_array/peaks.rs new file mode 100644 index 000000000000..7631a07ac141 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/peaks.rs @@ -0,0 +1,22 @@ +use num_traits::Zero; +use polars_core::prelude::*; + +/// Get a boolean mask of the local maximum peaks. +pub fn peak_max(ca: &ChunkedArray) -> BooleanChunked +where + ChunkedArray: for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, +{ + let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); + let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); + ChunkedArray::lt(&shift_left, ca) & ChunkedArray::lt(&shift_right, ca) +} + +/// Get a boolean mask of the local minimum peaks. +pub fn peak_min(ca: &ChunkedArray) -> BooleanChunked +where + ChunkedArray: for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, +{ + let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); + let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); + ChunkedArray::gt(&shift_left, ca) & ChunkedArray::gt(&shift_right, ca) +} diff --git a/crates/polars-ops/src/chunked_array/repeat_by.rs b/crates/polars-ops/src/chunked_array/repeat_by.rs new file mode 100644 index 000000000000..c16142e674bc --- /dev/null +++ b/crates/polars-ops/src/chunked_array/repeat_by.rs @@ -0,0 +1,215 @@ +use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder}; +use arrow::array::{Array, ListArray}; +use arrow::bitmap::BitmapBuilder; +use arrow::offset::Offsets; +use arrow::pushable::Pushable; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +type LargeListArray = ListArray; + +fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> { + polars_ensure!( + (length_srs == length_by) | (length_by == 1) | (length_srs == 1), + ShapeMismatch: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}", + length_srs, length_by + ); + Ok(()) +} + +fn new_by(by: &IdxCa, len: usize) -> IdxCa { + if let Some(x) = by.get(0) { + let values = std::iter::repeat_n(x, len).collect::>(); + IdxCa::new(PlSmallStr::EMPTY, values) + } else { + IdxCa::full_null(PlSmallStr::EMPTY, len) + } +} + +fn repeat_by_primitive(ca: &ChunkedArray, by: &IdxCa) -> PolarsResult +where + T: PolarsNumericType, +{ + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(ca, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat_n(opt_v.copied(), *by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { + LargeListArray::from_iter_primitive_trusted_len( + iter, + T::get_dtype().to_arrow(CompatLevel::newest()), + ) + } + })) + }, + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_primitive(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_primitive(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +fn repeat_by_bool(ca: &BooleanChunked, by: &IdxCa) -> PolarsResult { + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(ca, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat_n(opt_v, *by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { LargeListArray::from_iter_bool_trusted_len(iter) } + })) + }, + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_bool(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_bool(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +fn repeat_by_binary(ca: &BinaryChunked, by: &IdxCa) -> PolarsResult { + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => { + Ok(arity::binary(ca, by, |arr, by| { + let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| { + opt_by.map(|by| std::iter::repeat_n(opt_v, *by as usize)) + }); + + // SAFETY: length of iter is trusted. + unsafe { LargeListArray::from_iter_binary_trusted_len(iter, ca.len()) } + })) + }, + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_binary(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_binary(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +fn repeat_by_list(ca: &ListChunked, by: &IdxCa) -> PolarsResult { + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => Ok(repeat_by_generic_inner(ca, by)), + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_list(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_list(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +#[cfg(feature = "dtype-struct")] +fn repeat_by_struct(ca: &StructChunked, by: &IdxCa) -> PolarsResult { + check_lengths(ca.len(), by.len())?; + + match (ca.len(), by.len()) { + (left_len, right_len) if left_len == right_len => Ok(repeat_by_generic_inner(ca, by)), + (_, 1) => { + let by = new_by(by, ca.len()); + repeat_by_struct(ca, &by) + }, + (1, _) => { + let new_array = ca.new_from_index(0, by.len()); + repeat_by_struct(&new_array, by) + }, + // we have already checked the length + _ => unreachable!(), + } +} + +fn repeat_by_generic_inner(ca: &ChunkedArray, by: &IdxCa) -> ListChunked { + let mut builder = make_builder(&ca.dtype().to_arrow(CompatLevel::newest())); + arity::binary(ca, by, |arr, by| { + let arr_length = by.iter().flatten().map(|x| *x as usize).sum(); + builder.reserve(arr_length); + + let mut validity = BitmapBuilder::with_capacity(by.len()); + let mut offsets = Offsets::::with_capacity(by.len()); + for (idx, n_repeat) in by.iter().enumerate() { + validity.push(n_repeat.is_some()); + if let Some(repeats) = n_repeat { + offsets.push(*repeats as usize); + builder.subslice_extend_repeated( + arr, + idx, + 1, + *repeats as usize, + ShareStrategy::Always, + ); + } else { + offsets.push_null(); + } + } + + let repeated_values = builder.freeze_reset(); + LargeListArray::new( + ListArray::::default_datatype(arr.dtype().clone()), + offsets.into(), + repeated_values, + validity.into_opt_validity(), + ) + }) +} + +pub fn repeat_by(s: &Series, by: &IdxCa) -> PolarsResult { + let s_phys = s.to_physical_repr(); + use DataType::*; + let out = match s_phys.dtype() { + Boolean => repeat_by_bool(s_phys.bool().unwrap(), by), + String => { + let ca = s_phys.str().unwrap(); + repeat_by_binary(&ca.as_binary(), by) + .and_then(|ca| ca.apply_to_inner(&|s| unsafe { s.cast_unchecked(&String) })) + }, + Binary => repeat_by_binary(s_phys.binary().unwrap(), by), + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s_phys.as_ref().as_ref().as_ref(); + repeat_by_primitive(ca, by) + }) + }, + List(_) => repeat_by_list(s_phys.list().unwrap(), by), + #[cfg(feature = "dtype-struct")] + Struct(_) => repeat_by_struct(s_phys.struct_().unwrap(), by), + _ => polars_bail!(opq = repeat_by, s.dtype()), + }; + out.and_then(|ca| { + let logical_type = s.dtype(); + ca.apply_to_inner(&|s| unsafe { s.from_physical_unchecked(logical_type) }) + }) +} diff --git a/crates/polars-ops/src/chunked_array/scatter.rs b/crates/polars-ops/src/chunked_array/scatter.rs new file mode 100644 index 000000000000..06f45eca8b6c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/scatter.rs @@ -0,0 +1,199 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::{Array, PrimitiveArray}; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::arrow::bitmap::MutableBitmap; +use polars_core::utils::arrow::types::NativeType; +use polars_utils::index::check_bounds; + +pub trait ChunkedSet { + /// Invariant for implementations: if the scatter() fails, typically because + /// of bad indexes, then self should remain unmodified. + fn scatter(self, idx: &[IdxSize], values: V) -> PolarsResult + where + V: IntoIterator>; +} +fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> { + if idx.is_empty() { + return Ok(()); + } + let mut sorted = true; + let mut previous = idx[0]; + for &i in &idx[1..] { + if i < previous { + // we will not break here as that prevents SIMD + sorted = false; + } + previous = i; + } + polars_ensure!(sorted, ComputeError: "set indices must be sorted"); + Ok(()) +} + +trait PolarsOpsNumericType: PolarsNumericType {} + +impl PolarsOpsNumericType for UInt8Type {} +impl PolarsOpsNumericType for UInt16Type {} +impl PolarsOpsNumericType for UInt32Type {} +impl PolarsOpsNumericType for UInt64Type {} +impl PolarsOpsNumericType for Int8Type {} +impl PolarsOpsNumericType for Int16Type {} +impl PolarsOpsNumericType for Int32Type {} +impl PolarsOpsNumericType for Int64Type {} +#[cfg(feature = "dtype-i128")] +impl PolarsOpsNumericType for Int128Type {} +impl PolarsOpsNumericType for Float32Type {} +impl PolarsOpsNumericType for Float64Type {} + +unsafe fn scatter_impl( + new_values_slice: &mut [T], + set_values: V, + arr: &mut PrimitiveArray, + idx: &[IdxSize], + len: usize, +) where + V: IntoIterator>, +{ + let mut values_iter = set_values.into_iter(); + + if arr.null_count() > 0 { + arr.apply_validity(|v| { + let mut mut_validity = v.make_mut(); + + for (idx, val) in idx.iter().zip(&mut values_iter) { + match val { + Some(value) => { + mut_validity.set_unchecked(*idx as usize, true); + *new_values_slice.get_unchecked_mut(*idx as usize) = value + }, + None => mut_validity.set_unchecked(*idx as usize, false), + } + } + mut_validity.into() + }) + } else { + let mut null_idx = vec![]; + for (idx, val) in idx.iter().zip(values_iter) { + match val { + Some(value) => *new_values_slice.get_unchecked_mut(*idx as usize) = value, + None => { + null_idx.push(*idx); + }, + } + } + // only make a validity bitmap when null values are set + if !null_idx.is_empty() { + let mut validity = MutableBitmap::with_capacity(len); + validity.extend_constant(len, true); + for idx in null_idx { + validity.set_unchecked(idx as usize, false) + } + arr.set_validity(Some(validity.into())) + } + } +} + +impl ChunkedSet for &mut ChunkedArray +where + ChunkedArray: IntoSeries, +{ + fn scatter(self, idx: &[IdxSize], values: V) -> PolarsResult + where + V: IntoIterator>, + { + check_bounds(idx, self.len() as IdxSize)?; + let mut ca = std::mem::take(self); + ca.rechunk_mut(); + + // SAFETY: + // we will not modify the length + // and we unset the sorted flag. + ca.set_sorted_flag(IsSorted::Not); + let arr = unsafe { ca.downcast_iter_mut() }.next().unwrap(); + let len = arr.len(); + + match arr.get_mut_values() { + Some(current_values) => { + let ptr = current_values.as_mut_ptr(); + + // reborrow because the bck does not allow it + let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) }; + // SAFETY: + // we checked bounds + unsafe { scatter_impl(current_values, values, arr, idx, len) }; + }, + None => { + let mut new_values = arr.values().as_slice().to_vec(); + // SAFETY: + // we checked bounds + unsafe { scatter_impl(&mut new_values, values, arr, idx, len) }; + arr.set_values(new_values.into()); + }, + }; + + // The null count may have changed - make sure to update the ChunkedArray + let new_null_count = arr.null_count(); + unsafe { ca.set_null_count(new_null_count) }; + + Ok(ca.into_series()) + } +} + +impl<'a> ChunkedSet<&'a str> for &'a StringChunked { + fn scatter(self, idx: &[IdxSize], values: V) -> PolarsResult + where + V: IntoIterator>, + { + check_bounds(idx, self.len() as IdxSize)?; + check_sorted(idx)?; + let mut ca_iter = self.into_iter().enumerate(); + let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len()); + + for (current_idx, current_value) in idx.iter().zip(values) { + for (cnt_idx, opt_val_self) in &mut ca_iter { + if cnt_idx == *current_idx as usize { + builder.append_option(current_value); + break; + } else { + builder.append_option(opt_val_self); + } + } + } + // the last idx is probably not the last value so we finish the iterator + for (_, opt_val_self) in ca_iter { + builder.append_option(opt_val_self); + } + + let ca = builder.finish(); + Ok(ca.into_series()) + } +} +impl ChunkedSet for &BooleanChunked { + fn scatter(self, idx: &[IdxSize], values: V) -> PolarsResult + where + V: IntoIterator>, + { + check_bounds(idx, self.len() as IdxSize)?; + check_sorted(idx)?; + let mut ca_iter = self.into_iter().enumerate(); + let mut builder = BooleanChunkedBuilder::new(self.name().clone(), self.len()); + + for (current_idx, current_value) in idx.iter().zip(values) { + for (cnt_idx, opt_val_self) in &mut ca_iter { + if cnt_idx == *current_idx as usize { + builder.append_option(current_value); + break; + } else { + builder.append_option(opt_val_self); + } + } + } + // the last idx is probably not the last value so we finish the iterator + for (_, opt_val_self) in ca_iter { + builder.append_option(opt_val_self); + } + + let ca = builder.finish(); + Ok(ca.into_series()) + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/case.rs b/crates/polars-ops/src/chunked_array/strings/case.rs new file mode 100644 index 000000000000..4004f122143f --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/case.rs @@ -0,0 +1,168 @@ +use polars_core::prelude::StringChunked; + +// Inlined from std. +fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8, out: &mut Vec) { + out.clear(); + out.reserve(b.len()); + + const USIZE_SIZE: usize = size_of::(); + const MAGIC_UNROLL: usize = 2; + const N: usize = USIZE_SIZE * MAGIC_UNROLL; + const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]); + + let mut i = 0; + unsafe { + while i + N <= b.len() { + // SAFETY: we have checks the sizes `b` and `out`. + let in_chunk = b.get_unchecked(i..i + N); + let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N); + + let mut bits = 0; + for j in 0..MAGIC_UNROLL { + // Read the bytes 1 usize at a time (unaligned since we haven't checked the alignment). + // SAFETY: in_chunk is valid bytes in the range. + bits |= in_chunk.as_ptr().cast::().add(j).read_unaligned(); + } + // If our chunks aren't ascii, then return only the prior bytes as init. + if bits & NONASCII_MASK != 0 { + break; + } + + // Perform the case conversions on N bytes (gets heavily autovec'd). + for j in 0..N { + // SAFETY: in_chunk and out_chunk are valid bytes in the range. + let out = out_chunk.get_unchecked_mut(j); + out.write(convert(in_chunk.get_unchecked(j))); + } + + // Mark these bytes as initialised. + i += N; + } + out.set_len(i); + } +} + +fn to_lowercase_helper(s: &str, buf: &mut Vec) { + convert_while_ascii(s.as_bytes(), u8::to_ascii_lowercase, buf); + + // SAFETY: we know this is a valid char boundary since + // out.len() is only progressed if ASCII bytes are found. + let rest = unsafe { s.get_unchecked(buf.len()..) }; + + // SAFETY: We have written only valid ASCII to our vec. + let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(buf)) }; + + for (i, c) in rest[..].char_indices() { + if c == 'Σ' { + // Σ maps to σ, except at the end of a word where it maps to ς. + // This is the only conditional (contextual) but language-independent mapping + // in `SpecialCasing.txt`, + // so hard-code it rather than have a generic "condition" mechanism. + // See https://github.com/rust-lang/rust/issues/26035 + map_uppercase_sigma(rest, i, &mut s) + } else { + s.extend(c.to_lowercase()); + } + } + + fn map_uppercase_sigma(from: &str, i: usize, to: &mut String) { + // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G33992 + // for the definition of `Final_Sigma`. + debug_assert!('Σ'.len_utf8() == 2); + let is_word_final = case_ignoreable_then_cased(from[..i].chars().rev()) + && !case_ignoreable_then_cased(from[i + 2..].chars()); + to.push_str(if is_word_final { "ς" } else { "σ" }); + } + + fn case_ignoreable_then_cased>(iter: I) -> bool { + #[cfg(feature = "nightly")] + use core::unicode::{Case_Ignorable, Cased}; + + #[cfg(not(feature = "nightly"))] + use super::unicode_internals::{Case_Ignorable, Cased}; + #[allow(clippy::skip_while_next)] + match iter.skip_while(|&c| Case_Ignorable(c)).next() { + Some(c) => Cased(c), + None => false, + } + } + + // Put buf back for next iteration. + *buf = s.into_bytes(); +} + +pub(super) fn to_lowercase<'a>(ca: &'a StringChunked) -> StringChunked { + // Amortize allocation. + let mut buf = Vec::new(); + let f = |s: &'a str| -> &'a str { + to_lowercase_helper(s, &mut buf); + // SAFETY: apply_mut will copy value from buf before next iteration. + let slice = unsafe { std::str::from_utf8_unchecked(&buf) }; + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + }; + ca.apply_mut(f) +} + +// Inlined from std. +pub(super) fn to_uppercase<'a>(ca: &'a StringChunked) -> StringChunked { + // Amortize allocation. + let mut buf = Vec::new(); + let f = |s: &'a str| -> &'a str { + convert_while_ascii(s.as_bytes(), u8::to_ascii_uppercase, &mut buf); + + // SAFETY: we know this is a valid char boundary since + // out.len() is only progressed if ascii bytes are found. + let rest = unsafe { s.get_unchecked(buf.len()..) }; + + // SAFETY: We have written only valid ASCII to our vec. + let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(&mut buf)) }; + + for c in rest.chars() { + s.extend(c.to_uppercase()); + } + + // Put buf back for next iteration. + buf = s.into_bytes(); + + // SAFETY: apply_mut will copy value from buf before next iteration. + let slice = unsafe { std::str::from_utf8_unchecked(&buf) }; + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + }; + ca.apply_mut(f) +} + +#[cfg(feature = "nightly")] +pub(super) fn to_titlecase<'a>(ca: &'a StringChunked) -> StringChunked { + // Amortize allocation. + let mut buf = Vec::new(); + + // Temporary scratch space. + // We have a double copy as we first convert to lowercase and then copy to `buf`. + let mut scratch = Vec::new(); + let f = |s: &'a str| -> &'a str { + to_lowercase_helper(s, &mut scratch); + let lowercased = unsafe { std::str::from_utf8_unchecked(&scratch) }; + + // SAFETY: the buffer is clear, empty string is valid UTF-8. + buf.clear(); + let mut s = unsafe { String::from_utf8_unchecked(std::mem::take(&mut buf)) }; + + let mut next_is_upper = true; + for c in lowercased.chars() { + if next_is_upper { + s.extend(c.to_uppercase()); + } else { + s.push(c); + } + next_is_upper = !c.is_alphanumeric(); + } + + // Put buf back for next iteration. + buf = s.into_bytes(); + + // SAFETY: apply_mut will copy value from buf before next iteration. + let slice = unsafe { std::str::from_utf8_unchecked(&buf) }; + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + }; + ca.apply_mut(f) +} diff --git a/crates/polars-ops/src/chunked_array/strings/concat.rs b/crates/polars-ops/src/chunked_array/strings/concat.rs new file mode 100644 index 000000000000..1814877323fa --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/concat.rs @@ -0,0 +1,167 @@ +use arrow::array::{Utf8Array, ValueSize}; +use polars_compute::cast::utf8_to_utf8view; +use polars_core::prelude::arity::unary_elementwise; +use polars_core::prelude::*; + +// Vertically concatenate all strings in a StringChunked. +pub fn str_join(ca: &StringChunked, delimiter: &str, ignore_nulls: bool) -> StringChunked { + if ca.is_empty() { + return StringChunked::new(ca.name().clone(), &[""]); + } + + // Propagate null value. + if !ignore_nulls && ca.null_count() != 0 { + return StringChunked::full_null(ca.name().clone(), 1); + } + + // Fast path for all nulls. + if ignore_nulls && ca.null_count() == ca.len() { + return StringChunked::new(ca.name().clone(), &[""]); + } + + if ca.len() == 1 { + return ca.clone(); + } + + // Calculate capacity. + let capacity = ca.get_values_size() + delimiter.len() * (ca.len() - 1); + + let mut buf = String::with_capacity(capacity); + let mut first = true; + ca.for_each(|val| { + if let Some(val) = val { + if !first { + buf.push_str(delimiter); + } + buf.push_str(val); + first = false; + } + }); + + let buf = buf.into_bytes(); + assert!(capacity >= buf.len()); + let offsets = vec![0, buf.len() as i64]; + let arr = unsafe { Utf8Array::from_data_unchecked_default(offsets.into(), buf.into(), None) }; + // conversion is cheap with one value. + let arr = utf8_to_utf8view(&arr); + StringChunked::with_chunk(ca.name().clone(), arr) +} + +enum ColumnIter { + Iter(I), + Broadcast(T), +} + +/// Horizontally concatenate all strings. +/// +/// Each array should have length 1 or a length equal to the maximum length. +pub fn hor_str_concat( + cas: &[&StringChunked], + delimiter: &str, + ignore_nulls: bool, +) -> PolarsResult { + if cas.is_empty() { + return Ok(StringChunked::full_null(PlSmallStr::EMPTY, 0)); + } + if cas.len() == 1 { + let ca = cas[0]; + return if !ignore_nulls || ca.null_count() == 0 { + Ok(ca.clone()) + } else { + Ok(unary_elementwise(ca, |val| Some(val.unwrap_or("")))) + }; + } + + // Calculate the post-broadcast length and ensure everything is consistent. + let len = cas + .iter() + .map(|ca| ca.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + polars_ensure!( + cas.iter().all(|ca| ca.len() == 1 || ca.len() == len), + ShapeMismatch: "all series in `hor_str_concat` should have equal or unit length" + ); + + let mut builder = StringChunkedBuilder::new(cas[0].name().clone(), len); + + // Broadcast if appropriate. + let mut cols: Vec<_> = cas + .iter() + .map(|ca| match ca.len() { + 0 => ColumnIter::Broadcast(None), + 1 => ColumnIter::Broadcast(ca.get(0)), + _ => ColumnIter::Iter(ca.iter()), + }) + .collect(); + + // Build concatenated string. + let mut buf = String::with_capacity(1024); + for _row in 0..len { + let mut has_null = false; + let mut found_not_null_value = false; + for col in cols.iter_mut() { + let val = match col { + ColumnIter::Iter(i) => i.next().unwrap(), + ColumnIter::Broadcast(s) => *s, + }; + + if has_null && !ignore_nulls { + // We know that the result must be null, but we can't just break out of the loop, + // because all cols iterator has to be moved correctly. + continue; + } + + if let Some(s) = val { + if found_not_null_value { + buf.push_str(delimiter); + } + buf.push_str(s); + found_not_null_value = true; + } else { + has_null = true; + } + } + + if !ignore_nulls && has_null { + builder.append_null(); + } else { + builder.append_value(&buf) + } + buf.clear(); + } + + Ok(builder.finish()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_str_concat() { + let ca = Int32Chunked::new("foo".into(), &[Some(1), None, Some(3)]); + let ca_str = ca.cast(&DataType::String).unwrap(); + let out = str_join(ca_str.str().unwrap(), "-", true); + + let out = out.get(0); + assert_eq!(out, Some("1-3")); + } + + #[test] + fn test_hor_str_concat() { + let a = StringChunked::new("a".into(), &["foo", "bar"]); + let b = StringChunked::new("b".into(), &["spam", "ham"]); + + let out = hor_str_concat(&[&a, &b], "_", true).unwrap(); + assert_eq!(Vec::from(&out), &[Some("foo_spam"), Some("bar_ham")]); + + let c = StringChunked::new("b".into(), &["literal"]); + let out = hor_str_concat(&[&a, &b, &c], "_", true).unwrap(); + assert_eq!( + Vec::from(&out), + &[Some("foo_spam_literal"), Some("bar_ham_literal")] + ); + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/escape_regex.rs b/crates/polars-ops/src/chunked_array/strings/escape_regex.rs new file mode 100644 index 000000000000..1edb9146e9f4 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/escape_regex.rs @@ -0,0 +1,21 @@ +use polars_core::prelude::{StringChunked, StringChunkedBuilder}; + +#[inline] +pub fn escape_regex_str(s: &str) -> String { + regex_syntax::escape(s) +} + +pub fn escape_regex(ca: &StringChunked) -> StringChunked { + let mut buffer = String::new(); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + for opt_s in ca.iter() { + if let Some(s) = opt_s { + buffer.clear(); + regex_syntax::escape_into(s, &mut buffer); + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } + builder.finish() +} diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs new file mode 100644 index 000000000000..0f05d58faa35 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -0,0 +1,180 @@ +use std::iter::zip; + +#[cfg(feature = "extract_groups")] +use arrow::array::{Array, StructArray}; +use arrow::array::{MutablePlString, Utf8ViewArray}; +use polars_core::prelude::arity::{try_binary_mut_with_options, try_unary_mut_with_options}; +use regex::Regex; + +use super::*; + +#[cfg(feature = "extract_groups")] +fn extract_groups_array( + arr: &Utf8ViewArray, + reg: &Regex, + names: &[&str], + dtype: ArrowDataType, +) -> PolarsResult { + let mut builders = (0..names.len()) + .map(|_| MutablePlString::with_capacity(arr.len())) + .collect::>(); + + let mut locs = reg.capture_locations(); + for opt_v in arr { + if let Some(s) = opt_v { + if reg.captures_read(&mut locs, s).is_some() { + for (i, builder) in builders.iter_mut().enumerate() { + builder.push(locs.get(i + 1).map(|(start, stop)| &s[start..stop])); + } + continue; + } + } + + // Push nulls if either the string is null or there was no match. We + // distinguish later between the two by copying arr's validity mask. + builders.iter_mut().for_each(|arr| arr.push_null()); + } + + let values = builders.into_iter().map(|a| a.freeze().boxed()).collect(); + Ok(StructArray::new(dtype.clone(), arr.len(), values, arr.validity().cloned()).boxed()) +} + +#[cfg(feature = "extract_groups")] +pub(super) fn extract_groups( + ca: &StringChunked, + pat: &str, + dtype: &DataType, +) -> PolarsResult { + let reg = polars_utils::regex_cache::compile_regex(pat)?; + let n_fields = reg.captures_len(); + if n_fields == 1 { + return StructChunked::from_series( + ca.name().clone(), + ca.len(), + [Series::new_null(ca.name().clone(), ca.len())].iter(), + ) + .map(|ca| ca.into_series()); + } + + let arrow_dtype = dtype.try_to_arrow(CompatLevel::newest())?; + let DataType::Struct(fields) = dtype else { + unreachable!() // Implementation error if it isn't a struct. + }; + let names = fields + .iter() + .map(|fld| fld.name.as_str()) + .collect::>(); + + let chunks = ca + .downcast_iter() + .map(|array| extract_groups_array(array, ®, &names, arrow_dtype.clone())) + .collect::>>()?; + + Series::try_from((ca.name().clone(), chunks)) +} + +fn extract_group_reg_lit( + arr: &Utf8ViewArray, + reg: &Regex, + group_index: usize, +) -> PolarsResult { + let mut builder = MutablePlString::with_capacity(arr.len()); + + let mut locs = reg.capture_locations(); + for opt_v in arr { + if let Some(s) = opt_v { + if reg.captures_read(&mut locs, s).is_some() { + builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop])); + continue; + } + } + + // Push null if either the string is null or there was no match. + builder.push_null(); + } + + Ok(builder.into()) +} + +fn extract_group_array_lit( + s: &str, + pat: &Utf8ViewArray, + group_index: usize, +) -> PolarsResult { + let mut builder = MutablePlString::with_capacity(pat.len()); + + for opt_pat in pat { + if let Some(pat) = opt_pat { + let reg = polars_utils::regex_cache::compile_regex(pat)?; + let mut locs = reg.capture_locations(); + if reg.captures_read(&mut locs, s).is_some() { + builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop])); + continue; + } + } + + // Push null if either the pat is null or there was no match. + builder.push_null(); + } + + Ok(builder.into()) +} + +fn extract_group_binary( + arr: &Utf8ViewArray, + pat: &Utf8ViewArray, + group_index: usize, +) -> PolarsResult { + let mut builder = MutablePlString::with_capacity(arr.len()); + + for (opt_s, opt_pat) in zip(arr, pat) { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + let reg = polars_utils::regex_cache::compile_regex(pat)?; + let mut locs = reg.capture_locations(); + if reg.captures_read(&mut locs, s).is_some() { + builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop])); + continue; + } + // Push null if there was no match. + builder.push_null() + }, + _ => builder.push_null(), + } + } + + Ok(builder.into()) +} + +pub(super) fn extract_group( + ca: &StringChunked, + pat: &StringChunked, + group_index: usize, +) -> PolarsResult { + match (ca.len(), pat.len()) { + (_, 1) => { + if let Some(pat) = pat.get(0) { + let reg = polars_utils::regex_cache::compile_regex(pat)?; + try_unary_mut_with_options(ca, |arr| extract_group_reg_lit(arr, ®, group_index)) + } else { + Ok(StringChunked::full_null(ca.name().clone(), ca.len())) + } + }, + (1, _) => { + if let Some(s) = ca.get(0) { + try_unary_mut_with_options(pat, |pat| extract_group_array_lit(s, pat, group_index)) + } else { + Ok(StringChunked::full_null(ca.name().clone(), pat.len())) + } + }, + (len_ca, len_pat) if len_ca == len_pat => try_binary_mut_with_options( + ca, + pat, + |ca, pat| extract_group_binary(ca, pat, group_index), + ca.name().clone(), + ), + _ => { + polars_bail!(ComputeError: "ca(len: {}) and pat(len: {}) should either broadcast or have the same length", ca.len(), pat.len()) + }, + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/find_many.rs b/crates/polars-ops/src/chunked_array/strings/find_many.rs new file mode 100644 index 000000000000..478f1b05b777 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/find_many.rs @@ -0,0 +1,190 @@ +use aho_corasick::{AhoCorasick, AhoCorasickBuilder}; +use arrow::array::Utf8ViewArray; +use polars_core::prelude::arity::unary_elementwise; +use polars_core::prelude::*; +use polars_core::utils::align_chunks_binary; + +fn build_ac(patterns: &StringChunked, ascii_case_insensitive: bool) -> PolarsResult { + AhoCorasickBuilder::new() + .ascii_case_insensitive(ascii_case_insensitive) + .build(patterns.downcast_iter().flatten().flatten()) + .map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e)) +} + +fn build_ac_arr( + patterns: &Utf8ViewArray, + ascii_case_insensitive: bool, +) -> PolarsResult { + AhoCorasickBuilder::new() + .ascii_case_insensitive(ascii_case_insensitive) + .build(patterns.into_iter().flatten()) + .map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e)) +} + +pub fn contains_any( + ca: &StringChunked, + patterns: &StringChunked, + ascii_case_insensitive: bool, +) -> PolarsResult { + let ac = build_ac(patterns, ascii_case_insensitive)?; + + Ok(unary_elementwise(ca, |opt_val| { + opt_val.map(|val| ac.find(val).is_some()) + })) +} + +pub fn replace_all( + ca: &StringChunked, + patterns: &StringChunked, + replace_with: &StringChunked, + ascii_case_insensitive: bool, +) -> PolarsResult { + let replace_with = if replace_with.len() == 1 && patterns.len() > 1 { + replace_with.new_from_index(0, patterns.len()) + } else { + replace_with.clone() + }; + + polars_ensure!(patterns.len() == replace_with.len(), InvalidOperation: "expected the same amount of patterns as replacement strings"); + polars_ensure!(patterns.null_count() == 0 && replace_with.null_count() == 0, InvalidOperation: "'patterns'/'replace_with' should not have nulls"); + let replace_with = replace_with + .downcast_iter() + .flatten() + .flatten() + .collect::>(); + + let ac = build_ac(patterns, ascii_case_insensitive)?; + + Ok(unary_elementwise(ca, |opt_val| { + opt_val.map(|val| ac.replace_all(val, replace_with.as_slice())) + })) +} + +fn push_str( + val: &str, + builder: &mut ListStringChunkedBuilder, + ac: &AhoCorasick, + overlapping: bool, +) { + if overlapping { + let iter = ac.find_overlapping_iter(val); + let iter = iter.map(|m| &val[m.start()..m.end()]); + builder.append_values_iter(iter); + } else { + let iter = ac.find_iter(val); + let iter = iter.map(|m| &val[m.start()..m.end()]); + builder.append_values_iter(iter); + } +} + +pub fn extract_many( + ca: &StringChunked, + patterns: &Series, + ascii_case_insensitive: bool, + overlapping: bool, +) -> PolarsResult { + match patterns.dtype() { + DataType::List(inner) if inner.is_string() => { + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2); + let patterns = patterns.list().unwrap(); + let (ca, patterns) = align_chunks_binary(ca, patterns); + + for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) { + for z in arr.into_iter().zip(pat_arr.into_iter()) { + match z { + (None, _) | (_, None) => builder.append_null(), + (Some(val), Some(pat)) => { + let pat = pat.as_any().downcast_ref::().unwrap(); + let ac = build_ac_arr(pat, ascii_case_insensitive)?; + push_str(val, &mut builder, &ac, overlapping); + }, + } + } + } + Ok(builder.finish()) + }, + DataType::String => { + let patterns = patterns.str().unwrap(); + let ac = build_ac(patterns, ascii_case_insensitive)?; + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2); + + for arr in ca.downcast_iter() { + for opt_val in arr.into_iter() { + if let Some(val) = opt_val { + push_str(val, &mut builder, &ac, overlapping); + } else { + builder.append_null(); + } + } + } + Ok(builder.finish()) + }, + _ => { + polars_bail!(InvalidOperation: "expected 'String/List' datatype for 'patterns' argument") + }, + } +} + +type B = ListPrimitiveChunkedBuilder; +fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) { + if overlapping { + let iter = ac.find_overlapping_iter(val); + let iter = iter.map(|m| m.start() as u32); + builder.append_values_iter(iter); + } else { + let iter = ac.find_iter(val); + let iter = iter.map(|m| m.start() as u32); + builder.append_values_iter(iter); + } +} + +pub fn find_many( + ca: &StringChunked, + patterns: &Series, + ascii_case_insensitive: bool, + overlapping: bool, +) -> PolarsResult { + type B = ListPrimitiveChunkedBuilder; + match patterns.dtype() { + DataType::List(inner) if inner.is_string() => { + let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32); + let patterns = patterns.list().unwrap(); + let (ca, patterns) = align_chunks_binary(ca, patterns); + + for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) { + for z in arr.into_iter().zip(pat_arr.into_iter()) { + match z { + (None, _) | (_, None) => builder.append_null(), + (Some(val), Some(pat)) => { + let pat = pat.as_any().downcast_ref::().unwrap(); + let ac = build_ac_arr(pat, ascii_case_insensitive)?; + push_idx(val, &mut builder, &ac, overlapping); + }, + } + } + } + Ok(builder.finish()) + }, + DataType::String => { + let patterns = patterns.str().unwrap(); + let ac = build_ac(patterns, ascii_case_insensitive)?; + let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32); + + for arr in ca.downcast_iter() { + for opt_val in arr.into_iter() { + if let Some(val) = opt_val { + push_idx(val, &mut builder, &ac, overlapping); + } else { + builder.append_null(); + } + } + } + Ok(builder.finish()) + }, + _ => { + polars_bail!(InvalidOperation: "expected 'String/List' datatype for 'patterns' argument") + }, + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/json_path.rs b/crates/polars-ops/src/chunked_array/strings/json_path.rs new file mode 100644 index 000000000000..973658256d1c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/json_path.rs @@ -0,0 +1,314 @@ +use std::borrow::Cow; + +use arrow::array::ValueSize; +use jsonpath_lib::PathCompiled; +use polars_core::prelude::arity::{broadcast_try_binary_elementwise, unary_elementwise}; +use serde_json::Value; + +use super::*; + +pub fn extract_json(expr: &PathCompiled, json_str: &str) -> Option { + serde_json::from_str(json_str).ok().and_then(|value| { + // TODO: a lot of heap allocations here. Improve json path by adding a take? + let result = expr.select(&value).ok()?; + let first = *result.first()?; + + match first { + Value::String(s) => Some(s.clone()), + Value::Null => None, + v => Some(v.to_string()), + } + }) +} + +/// Returns a string of the most specific value given the compiled JSON path expression. +/// This avoids creating a list to represent individual elements so that they can be +/// selected directly. +pub fn select_json<'a>(expr: &PathCompiled, json_str: &'a str) -> Option> { + serde_json::from_str(json_str).ok().and_then(|value| { + // TODO: a lot of heap allocations here. Improve json path by adding a take? + let result = expr.select(&value).ok()?; + + let result_str = match result.len() { + 0 => None, + 1 => serde_json::to_string(&result[0]).ok(), + _ => serde_json::to_string(&result).ok(), + }; + + result_str.map(Cow::Owned) + }) +} + +pub trait Utf8JsonPathImpl: AsString { + /// Extract json path, first match + /// Refer to + fn json_path_match(&self, json_path: &StringChunked) -> PolarsResult { + let ca = self.as_string(); + match (ca.len(), json_path.len()) { + (_, 1) => { + // SAFETY: `json_path` was verified to have exactly 1 element. + let opt_path = unsafe { json_path.get_unchecked(0) }; + let out = if let Some(path) = opt_path { + let pat = PathCompiled::compile(path).map_err( + |e| polars_err!(ComputeError: "error compiling JSON path expression {}", e), + )?; + unary_elementwise(ca, |opt_s| opt_s.and_then(|s| extract_json(&pat, s))) + } else { + StringChunked::full_null(ca.name().clone(), ca.len()) + }; + Ok(out) + }, + (len_ca, len_path) if len_ca == 1 || len_ca == len_path => { + broadcast_try_binary_elementwise(ca, json_path, |opt_str, opt_path| { + match (opt_str, opt_path) { + (Some(str_val), Some(path)) => { + PathCompiled::compile(path) + .map_err(|e| polars_err!(ComputeError: "error compiling JSON path expression {}", e)) + .map(|path| extract_json(&path, str_val)) + }, + _ => Ok(None), + } + }) + }, + (len_ca, len_path) => { + polars_bail!(ComputeError: "The length of `ca` and `json_path` should either 1 or the same, but `{}`, `{}` founded", len_ca, len_path) + }, + } + } + + /// Returns the inferred DataType for JSON values for each row + /// in the StringChunked, with an optional number of rows to inspect. + /// When None is passed for the number of rows, all rows are inspected. + fn json_infer(&self, number_of_rows: Option) -> PolarsResult { + let ca = self.as_string(); + let values_iter = ca + .iter() + .map(|x| x.unwrap_or("null")) + .take(number_of_rows.unwrap_or(ca.len())); + + polars_json::ndjson::infer_iter(values_iter) + .map(|d| DataType::from_arrow_dtype(&d)) + .map_err(|e| polars_err!(ComputeError: "error inferring JSON: {}", e)) + } + + /// Extracts a typed-JSON value for each row in the StringChunked + fn json_decode( + &self, + dtype: Option, + infer_schema_len: Option, + ) -> PolarsResult { + let ca = self.as_string(); + // Ignore extra fields instead of erroring if the dtype was explicitly given. + let allow_extra_fields_in_struct = dtype.is_some(); + let dtype = match dtype { + Some(dt) => dt, + None => ca.json_infer(infer_schema_len)?, + }; + let buf_size = ca.get_values_size() + ca.null_count() * "null".len(); + let iter = ca.iter().map(|x| x.unwrap_or("null")); + + let array = polars_json::ndjson::deserialize::deserialize_iter( + iter, + dtype.to_arrow(CompatLevel::newest()), + buf_size, + ca.len(), + allow_extra_fields_in_struct, + ) + .map_err(|e| polars_err!(ComputeError: "error deserializing JSON: {}", e))?; + Series::try_from((PlSmallStr::EMPTY, array)) + } + + fn json_path_select(&self, json_path: &str) -> PolarsResult { + let pat = PathCompiled::compile(json_path) + .map_err(|e| polars_err!(ComputeError: "error compiling JSONpath expression: {}", e))?; + Ok(self + .as_string() + .apply(|opt_s| opt_s.and_then(|s| select_json(&pat, s)))) + } + + fn json_path_extract( + &self, + json_path: &str, + dtype: Option, + infer_schema_len: Option, + ) -> PolarsResult { + let selected_json = self.as_string().json_path_select(json_path)?; + selected_json.json_decode(dtype, infer_schema_len) + } +} + +impl Utf8JsonPathImpl for StringChunked {} + +#[cfg(test)] +mod tests { + use arrow::bitmap::Bitmap; + + use super::*; + + #[test] + fn test_json_select() { + let json_str = r#"{"a":1,"b":{"c":"hello"},"d":[{"e":0},{"e":2},{"e":null}]}"#; + + let compile = |s| PathCompiled::compile(s).unwrap(); + let some_cow = |s: &str| Some(Cow::Owned(s.to_string())); + + assert_eq!(select_json(&compile("$"), json_str), some_cow(json_str)); + assert_eq!(select_json(&compile("$.a"), json_str), some_cow("1")); + assert_eq!( + select_json(&compile("$.b.c"), json_str), + some_cow(r#""hello""#) + ); + assert_eq!(select_json(&compile("$.d[0].e"), json_str), some_cow("0")); + assert_eq!( + select_json(&compile("$.d[2].e"), json_str), + some_cow("null") + ); + assert_eq!( + select_json(&compile("$.d[:].e"), json_str), + some_cow("[0,2,null]") + ); + } + + #[test] + fn test_json_infer() { + let s = Series::new( + "json".into(), + [ + None, + Some(r#"{"a": 1, "b": [{"c": 0}, {"c": 1}]}"#), + Some(r#"{"a": 2, "b": [{"c": 2}, {"c": 5}]}"#), + None, + ], + ); + let ca = s.str().unwrap(); + + let inner_dtype = DataType::Struct(vec![Field::new("c".into(), DataType::Int64)]); + let expected_dtype = DataType::Struct(vec![ + Field::new("a".into(), DataType::Int64), + Field::new("b".into(), DataType::List(Box::new(inner_dtype))), + ]); + + assert_eq!(ca.json_infer(None).unwrap(), expected_dtype); + // Infereing with the first row will only see None + assert_eq!(ca.json_infer(Some(1)).unwrap(), DataType::Null); + assert_eq!(ca.json_infer(Some(2)).unwrap(), expected_dtype); + } + + #[test] + fn test_json_decode() { + let s = Series::new( + "json".into(), + [ + None, + Some(r#"{"a": 1, "b": "hello"}"#), + Some(r#"{"a": 2, "b": "goodbye"}"#), + None, + ], + ); + let ca = s.str().unwrap(); + + let expected_series = StructChunked::from_series( + "".into(), + 4, + [ + Series::new("a".into(), &[None, Some(1), Some(2), None]), + Series::new("b".into(), &[None, Some("hello"), Some("goodbye"), None]), + ] + .iter(), + ) + .unwrap() + .with_outer_validity(Some(Bitmap::from_iter([false, true, true, false]))) + .into_series(); + let expected_dtype = expected_series.dtype().clone(); + + assert!( + ca.json_decode(None, None) + .unwrap() + .equals_missing(&expected_series) + ); + assert!( + ca.json_decode(Some(expected_dtype), None) + .unwrap() + .equals_missing(&expected_series) + ); + } + + #[test] + fn test_json_path_select() { + let s = Series::new( + "json".into(), + [ + None, + Some(r#"{"a":1,"b":[{"c":0},{"c":1}]}"#), + Some(r#"{"a":2,"b":[{"c":2},{"c":5}]}"#), + None, + ], + ); + let ca = s.str().unwrap(); + + assert!( + ca.json_path_select("$") + .unwrap() + .into_series() + .equals_missing(&s) + ); + + let b_series = Series::new( + "json".into(), + [ + None, + Some(r#"[{"c":0},{"c":1}]"#), + Some(r#"[{"c":2},{"c":5}]"#), + None, + ], + ); + assert!( + ca.json_path_select("$.b") + .unwrap() + .into_series() + .equals_missing(&b_series) + ); + + let c_series = Series::new( + "json".into(), + [None, Some(r#"[0,1]"#), Some(r#"[2,5]"#), None], + ); + assert!( + ca.json_path_select("$.b[:].c") + .unwrap() + .into_series() + .equals_missing(&c_series) + ); + } + + #[test] + fn test_json_path_extract() { + let s = Series::new( + "json".into(), + [ + None, + Some(r#"{"a":1,"b":[{"c":0},{"c":1}]}"#), + Some(r#"{"a":2,"b":[{"c":2},{"c":5}]}"#), + None, + ], + ); + let ca = s.str().unwrap(); + + let c_series = Series::new( + "".into(), + [ + None, + Some(Series::new("".into(), &[0, 1])), + Some(Series::new("".into(), &[2, 5])), + None, + ], + ); + + assert!( + ca.json_path_extract("$.b[:].c", None, None) + .unwrap() + .into_series() + .equals_missing(&c_series) + ); + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs new file mode 100644 index 000000000000..b1c50dcd37a6 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -0,0 +1,58 @@ +#[cfg(feature = "strings")] +mod case; +#[cfg(feature = "strings")] +mod concat; +#[cfg(feature = "strings")] +mod escape_regex; +#[cfg(feature = "strings")] +mod extract; +#[cfg(feature = "find_many")] +mod find_many; +#[cfg(feature = "extract_jsonpath")] +mod json_path; +#[cfg(feature = "strings")] +mod namespace; +#[cfg(feature = "string_normalize")] +mod normalize; +#[cfg(feature = "string_pad")] +mod pad; +#[cfg(feature = "string_reverse")] +mod reverse; +#[cfg(feature = "strings")] +mod split; +#[cfg(feature = "strings")] +mod strip; +#[cfg(feature = "strings")] +mod substring; +#[cfg(all(not(feature = "nightly"), feature = "strings"))] +mod unicode_internals; + +#[cfg(feature = "strings")] +pub use concat::*; +#[cfg(feature = "strings")] +pub use escape_regex::*; +#[cfg(feature = "find_many")] +pub use find_many::*; +#[cfg(feature = "extract_jsonpath")] +pub use json_path::*; +#[cfg(feature = "strings")] +pub use namespace::*; +#[cfg(feature = "string_normalize")] +pub use normalize::*; +use polars_core::prelude::*; +#[cfg(feature = "strings")] +pub use split::*; +#[cfg(feature = "strings")] +pub use strip::*; +#[cfg(feature = "strings")] +pub use substring::{substring_ternary_offsets_value, update_view}; + +pub trait AsString { + fn as_string(&self) -> &StringChunked; +} + +impl AsString for StringChunked { + fn as_string(&self) -> &StringChunked { + self + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs new file mode 100644 index 000000000000..8a5c8ac9a3a0 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -0,0 +1,678 @@ +use arrow::array::ValueSize; +use arrow::legacy::kernels::string::*; +#[cfg(feature = "string_encoding")] +use base64::Engine as _; +#[cfg(feature = "string_encoding")] +use base64::engine::general_purpose; +#[cfg(feature = "string_to_integer")] +use num_traits::Num; +use polars_core::prelude::arity::*; +use polars_utils::regex_cache::{compile_regex, with_regex_cache}; + +use super::*; +#[cfg(feature = "binary_encoding")] +use crate::chunked_array::binary::BinaryNameSpaceImpl; +#[cfg(feature = "string_normalize")] +use crate::prelude::strings::normalize::UnicodeForm; + +// We need this to infer the right lifetimes for the match closure. +#[inline(always)] +fn infer_re_match(f: F) -> F +where + F: for<'a, 'b> FnMut(Option<&'a str>, Option<&'b str>) -> Option, +{ + f +} + +pub trait StringNameSpaceImpl: AsString { + #[cfg(not(feature = "binary_encoding"))] + fn hex_decode(&self) -> PolarsResult { + panic!("activate 'binary_encoding' feature") + } + + #[cfg(feature = "binary_encoding")] + fn hex_decode(&self, strict: bool) -> PolarsResult { + let ca = self.as_string(); + ca.as_binary().hex_decode(strict) + } + + #[must_use] + #[cfg(feature = "string_encoding")] + fn hex_encode(&self) -> StringChunked { + let ca = self.as_string(); + ca.apply_values(|s| hex::encode(s).into()) + } + + #[cfg(not(feature = "binary_encoding"))] + fn base64_decode(&self) -> PolarsResult { + panic!("activate 'binary_encoding' feature") + } + + #[cfg(feature = "binary_encoding")] + fn base64_decode(&self, strict: bool) -> PolarsResult { + let ca = self.as_string(); + ca.as_binary().base64_decode(strict) + } + + #[must_use] + #[cfg(feature = "string_encoding")] + fn base64_encode(&self) -> StringChunked { + let ca = self.as_string(); + ca.apply_values(|s| general_purpose::STANDARD.encode(s).into()) + } + + #[cfg(feature = "string_to_integer")] + // Parse a string number with base _radix_ into a decimal (i64) + fn to_integer(&self, base: &UInt32Chunked, strict: bool) -> PolarsResult { + let ca = self.as_string(); + + polars_ensure!( + ca.len() == base.len() || ca.len() == 1 || base.len() == 1, + length_mismatch = "str.to_integer", + ca.len(), + base.len() + ); + + let f = |opt_s: Option<&str>, opt_base: Option| -> PolarsResult> { + let (Some(s), Some(base)) = (opt_s, opt_base) else { + return Ok(None); + }; + + if !(2..=36).contains(&base) { + polars_bail!(ComputeError: "`to_integer` called with invalid base '{base}'"); + } + + Ok(::from_str_radix(s, base).ok()) + }; + let out = broadcast_try_binary_elementwise(ca, base, f)?; + if strict && ca.null_count() != out.null_count() { + let failure_mask = ca.is_not_null() & out.is_null() & base.is_not_null(); + let n_failures = failure_mask.num_trues(); + if n_failures == 0 { + return Ok(out); + } + + let some_failures = if ca.len() == 1 { + ca.clone() + } else { + let all_failures = ca.filter(&failure_mask)?; + // `.unique()` does not necessarily preserve the original order. + let unique_failures_args = all_failures.arg_unique()?; + all_failures.take(&unique_failures_args.slice(0, 10))? + }; + let some_error_msg = match base.len() { + 1 => { + // we can ensure that base is not null. + let base = base.get(0).unwrap(); + some_failures + .get(0) + .and_then(|s| ::from_str_radix(s, base).err()) + .map_or_else( + || unreachable!("failed to extract ParseIntError"), + |e| format!("{}", e), + ) + }, + _ => { + let base_failures = base.filter(&failure_mask)?; + some_failures + .get(0) + .zip(base_failures.get(0)) + .and_then(|(s, base)| ::from_str_radix(s, base).err()) + .map_or_else( + || unreachable!("failed to extract ParseIntError"), + |e| format!("{}", e), + ) + }, + }; + polars_bail!( + ComputeError: + "strict integer parsing failed for {} value(s): {}; error message for the \ + first shown value: '{}' (consider non-strict parsing)", + n_failures, + some_failures.into_series().fmt_list(), + some_error_msg + ); + }; + + Ok(out) + } + + fn contains_chunked( + &self, + pat: &StringChunked, + literal: bool, + strict: bool, + ) -> PolarsResult { + let ca = self.as_string(); + match (ca.len(), pat.len()) { + (_, 1) => match pat.get(0) { + Some(pat) => { + if literal { + ca.contains_literal(pat) + } else { + ca.contains(pat, strict) + } + }, + None => Ok(BooleanChunked::full_null(ca.name().clone(), ca.len())), + }, + (1, _) if ca.null_count() == 1 => Ok(BooleanChunked::full_null( + ca.name().clone(), + ca.len().max(pat.len()), + )), + _ => { + if literal { + Ok(broadcast_binary_elementwise_values(ca, pat, |src, pat| { + src.contains(pat) + })) + } else if strict { + with_regex_cache(|reg_cache| { + broadcast_try_binary_elementwise(ca, pat, |opt_src, opt_pat| { + match (opt_src, opt_pat) { + (Some(src), Some(pat)) => { + let reg = reg_cache.compile(pat)?; + Ok(Some(reg.is_match(src))) + }, + _ => Ok(None), + } + }) + }) + } else { + with_regex_cache(|reg_cache| { + Ok(broadcast_binary_elementwise( + ca, + pat, + infer_re_match(|src, pat| { + let reg = reg_cache.compile(pat?).ok()?; + Some(reg.is_match(src?)) + }), + )) + }) + } + }, + } + } + + fn find_chunked( + &self, + pat: &StringChunked, + literal: bool, + strict: bool, + ) -> PolarsResult { + let ca = self.as_string(); + if pat.len() == 1 { + return if let Some(pat) = pat.get(0) { + if literal { + ca.find_literal(pat) + } else { + ca.find(pat, strict) + } + } else { + Ok(UInt32Chunked::full_null(ca.name().clone(), ca.len())) + }; + } else if ca.len() == 1 && ca.null_count() == 1 { + return Ok(UInt32Chunked::full_null( + ca.name().clone(), + ca.len().max(pat.len()), + )); + } + if literal { + Ok(broadcast_binary_elementwise( + ca, + pat, + |src: Option<&str>, pat: Option<&str>| src?.find(pat?).map(|idx| idx as u32), + )) + } else { + with_regex_cache(|reg_cache| { + let matcher = |src: Option<&str>, pat: Option<&str>| -> PolarsResult> { + if let (Some(src), Some(pat)) = (src, pat) { + let re = reg_cache.compile(pat)?; + return Ok(re.find(src).map(|m| m.start() as u32)); + } + Ok(None) + }; + broadcast_try_binary_elementwise(ca, pat, matcher) + }) + } + } + + /// Get the length of the string values as number of chars. + fn str_len_chars(&self) -> UInt32Chunked { + let ca = self.as_string(); + ca.apply_kernel_cast(&string_len_chars) + } + + /// Get the length of the string values as number of bytes. + fn str_len_bytes(&self) -> UInt32Chunked { + let ca = self.as_string(); + ca.apply_kernel_cast(&utf8view_len_bytes) + } + + /// Pad the start of the string until it reaches the given length. + /// + /// Padding is done using the specified `fill_char`. + /// Strings with length equal to or greater than the given length are + /// returned as-is. + #[cfg(feature = "string_pad")] + fn pad_start(&self, length: usize, fill_char: char) -> StringChunked { + let ca = self.as_string(); + pad::pad_start(ca, length, fill_char) + } + + /// Pad the end of the string until it reaches the given length. + /// + /// Padding is done using the specified `fill_char`. + /// Strings with length equal to or greater than the given length are + /// returned as-is. + #[cfg(feature = "string_pad")] + fn pad_end(&self, length: usize, fill_char: char) -> StringChunked { + let ca = self.as_string(); + pad::pad_end(ca, length, fill_char) + } + + /// Pad the start of the string with zeros until it reaches the given length. + /// + /// A sign prefix (`-`) is handled by inserting the padding after the sign + /// character rather than before. + /// Strings with length equal to or greater than the given length are + /// returned as-is. + #[cfg(feature = "string_pad")] + fn zfill(&self, length: &UInt64Chunked) -> StringChunked { + let ca = self.as_string(); + pad::zfill(ca, length) + } + + /// Check if strings contain a regex pattern. + fn contains(&self, pat: &str, strict: bool) -> PolarsResult { + let ca = self.as_string(); + let res_reg = polars_utils::regex_cache::compile_regex(pat); + let opt_reg = if strict { Some(res_reg?) } else { res_reg.ok() }; + let out: BooleanChunked = if let Some(reg) = opt_reg { + unary_elementwise_values(ca, |s| reg.is_match(s)) + } else { + BooleanChunked::full_null(ca.name().clone(), ca.len()) + }; + Ok(out) + } + + /// Check if strings contain a given literal + fn contains_literal(&self, lit: &str) -> PolarsResult { + // note: benchmarking shows that the regex engine is actually + // faster at finding literal matches than str::contains. + // ref: https://github.com/pola-rs/polars/pull/6811 + self.contains(regex::escape(lit).as_str(), true) + } + + /// Return the index position of a literal substring in the target string. + fn find_literal(&self, lit: &str) -> PolarsResult { + self.find(regex::escape(lit).as_str(), true) + } + + /// Return the index position of a regular expression substring in the target string. + fn find(&self, pat: &str, strict: bool) -> PolarsResult { + let ca = self.as_string(); + match polars_utils::regex_cache::compile_regex(pat) { + Ok(rx) => Ok(unary_elementwise(ca, |opt_s| { + opt_s.and_then(|s| rx.find(s)).map(|m| m.start() as u32) + })), + Err(_) if !strict => Ok(UInt32Chunked::full_null(ca.name().clone(), ca.len())), + Err(e) => Err(PolarsError::ComputeError( + format!("Invalid regular expression: {}", e).into(), + )), + } + } + + /// Replace the leftmost regex-matched (sub)string with another string + fn replace<'a>(&'a self, pat: &str, val: &str) -> PolarsResult { + let reg = polars_utils::regex_cache::compile_regex(pat)?; + let f = |s: &'a str| reg.replace(s, val); + let ca = self.as_string(); + Ok(ca.apply_values(f)) + } + + /// Replace the leftmost literal (sub)string with another string + fn replace_literal<'a>( + &'a self, + pat: &str, + val: &str, + n: usize, + ) -> PolarsResult { + let ca = self.as_string(); + if ca.is_empty() { + return Ok(ca.clone()); + } + + // amortize allocation + let mut buf = String::new(); + + let f = move |s: &'a str| { + buf.clear(); + let mut changed = false; + + // See: str.replacen + let mut last_end = 0; + for (start, part) in s.match_indices(pat).take(n) { + changed = true; + buf.push_str(unsafe { s.get_unchecked(last_end..start) }); + buf.push_str(val); + last_end = start + part.len(); + } + buf.push_str(unsafe { s.get_unchecked(last_end..s.len()) }); + + if changed { + // extend lifetime + // lifetime is bound to 'a + let slice = buf.as_str(); + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + } else { + s + } + }; + Ok(ca.apply_mut(f)) + } + + /// Replace all regex-matched (sub)strings with another string + fn replace_all(&self, pat: &str, val: &str) -> PolarsResult { + let ca = self.as_string(); + let reg = polars_utils::regex_cache::compile_regex(pat)?; + Ok(ca.apply_values(|s| reg.replace_all(s, val))) + } + + /// Replace all matching literal (sub)strings with another string + fn replace_literal_all<'a>(&'a self, pat: &str, val: &str) -> PolarsResult { + let ca = self.as_string(); + if ca.is_empty() { + return Ok(ca.clone()); + } + + // Amortize allocation. + let mut buf = String::new(); + + let f = move |s: &'a str| { + buf.clear(); + let mut changed = false; + + // See: str.replace. + let mut last_end = 0; + for (start, part) in s.match_indices(pat) { + changed = true; + buf.push_str(unsafe { s.get_unchecked(last_end..start) }); + buf.push_str(val); + last_end = start + part.len(); + } + buf.push_str(unsafe { s.get_unchecked(last_end..s.len()) }); + + if changed { + // Extend lifetime, lifetime is bound to 'a. + let slice = buf.as_str(); + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + } else { + s + } + }; + + Ok(ca.apply_mut(f)) + } + + /// Extract the nth capture group from pattern. + fn extract(&self, pat: &StringChunked, group_index: usize) -> PolarsResult { + let ca = self.as_string(); + super::extract::extract_group(ca, pat, group_index) + } + + /// Extract each successive non-overlapping regex match in an individual string as an array. + fn extract_all(&self, pat: &str) -> PolarsResult { + let ca = self.as_string(); + let reg = polars_utils::regex_cache::compile_regex(pat)?; + + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); + for arr in ca.downcast_iter() { + for opt_s in arr { + match opt_s { + None => builder.append_null(), + Some(s) => builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())), + } + } + } + Ok(builder.finish()) + } + + fn strip_chars(&self, pat: &Column) -> PolarsResult { + let ca = self.as_string(); + if pat.dtype() == &DataType::Null { + Ok(unary_elementwise(ca, |opt_s| opt_s.map(|s| s.trim()))) + } else { + Ok(strip_chars(ca, pat.str()?)) + } + } + + fn strip_chars_start(&self, pat: &Column) -> PolarsResult { + let ca = self.as_string(); + if pat.dtype() == &DataType::Null { + Ok(unary_elementwise(ca, |opt_s| opt_s.map(|s| s.trim_start()))) + } else { + Ok(strip_chars_start(ca, pat.str()?)) + } + } + + fn strip_chars_end(&self, pat: &Column) -> PolarsResult { + let ca = self.as_string(); + if pat.dtype() == &DataType::Null { + Ok(unary_elementwise(ca, |opt_s| opt_s.map(|s| s.trim_end()))) + } else { + Ok(strip_chars_end(ca, pat.str()?)) + } + } + + fn strip_prefix(&self, prefix: &StringChunked) -> StringChunked { + let ca = self.as_string(); + strip_prefix(ca, prefix) + } + + fn strip_suffix(&self, suffix: &StringChunked) -> StringChunked { + let ca = self.as_string(); + strip_suffix(ca, suffix) + } + + #[cfg(feature = "dtype-struct")] + fn split_exact(&self, by: &StringChunked, n: usize) -> PolarsResult { + let ca = self.as_string(); + + split_to_struct(ca, by, n + 1, str::split, false) + } + + #[cfg(feature = "dtype-struct")] + fn split_exact_inclusive(&self, by: &StringChunked, n: usize) -> PolarsResult { + let ca = self.as_string(); + + split_to_struct(ca, by, n + 1, str::split_inclusive, false) + } + + #[cfg(feature = "dtype-struct")] + fn splitn(&self, by: &StringChunked, n: usize) -> PolarsResult { + let ca = self.as_string(); + + split_to_struct(ca, by, n, |s, by| s.splitn(n, by), true) + } + + fn split(&self, by: &StringChunked) -> PolarsResult { + let ca = self.as_string(); + split_helper(ca, by, str::split) + } + + fn split_inclusive(&self, by: &StringChunked) -> PolarsResult { + let ca = self.as_string(); + split_helper(ca, by, str::split_inclusive) + } + + /// Extract each successive non-overlapping regex match in an individual string as an array. + fn extract_all_many(&self, pat: &StringChunked) -> PolarsResult { + let ca = self.as_string(); + polars_ensure!( + ca.len() == pat.len(), + ComputeError: "pattern's length: {} does not match that of the argument series: {}", + pat.len(), ca.len(), + ); + + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); + with_regex_cache(|re_cache| { + binary_elementwise_for_each(ca, pat, |opt_s, opt_pat| match (opt_s, opt_pat) { + (_, None) | (None, _) => builder.append_null(), + (Some(s), Some(pat)) => { + let re = re_cache.compile(pat).unwrap(); + builder.append_values_iter(re.find_iter(s).map(|m| m.as_str())); + }, + }); + }); + Ok(builder.finish()) + } + + #[cfg(feature = "extract_groups")] + /// Extract all capture groups from pattern and return as a struct. + fn extract_groups(&self, pat: &str, dtype: &DataType) -> PolarsResult { + let ca = self.as_string(); + super::extract::extract_groups(ca, pat, dtype) + } + + /// Count all successive non-overlapping regex matches. + fn count_matches(&self, pat: &str, literal: bool) -> PolarsResult { + let ca = self.as_string(); + if literal { + Ok(unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.matches(pat).count() as u32) + })) + } else { + let re = compile_regex(pat)?; + Ok(unary_elementwise(ca, |opt_s| { + opt_s.map(|s| re.find_iter(s).count() as u32) + })) + } + } + + /// Count all successive non-overlapping regex matches. + fn count_matches_many( + &self, + pat: &StringChunked, + literal: bool, + ) -> PolarsResult { + let ca = self.as_string(); + polars_ensure!( + ca.len() == pat.len(), + ComputeError: "pattern's length: {} does not match that of the argument series: {}", + pat.len(), ca.len(), + ); + + let out: UInt32Chunked = if literal { + broadcast_binary_elementwise(ca, pat, |s: Option<&str>, p: Option<&str>| { + Some(s?.matches(p?).count() as u32) + }) + } else { + with_regex_cache(|re_cache| { + let op = move |opt_s: Option<&str>, + opt_pat: Option<&str>| + -> PolarsResult> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + let reg = re_cache.compile(pat)?; + Ok(Some(reg.find_iter(s).count() as u32)) + }, + _ => Ok(None), + } + }; + broadcast_try_binary_elementwise(ca, pat, op) + })? + }; + + Ok(out.with_name(ca.name().clone())) + } + + /// Modify the strings to their lowercase equivalent. + #[must_use] + fn to_lowercase(&self) -> StringChunked { + let ca = self.as_string(); + case::to_lowercase(ca) + } + + /// Modify the strings to their uppercase equivalent. + #[must_use] + fn to_uppercase(&self) -> StringChunked { + let ca = self.as_string(); + case::to_uppercase(ca) + } + + /// Modify the strings to their titlecase equivalent. + #[must_use] + #[cfg(feature = "nightly")] + fn to_titlecase(&self) -> StringChunked { + let ca = self.as_string(); + case::to_titlecase(ca) + } + + /// Concat with the values from a second StringChunked. + #[must_use] + fn concat(&self, other: &StringChunked) -> StringChunked { + let ca = self.as_string(); + ca + other + } + + /// Normalizes the string values + #[must_use] + #[cfg(feature = "string_normalize")] + fn str_normalize(&self, form: UnicodeForm) -> StringChunked { + let ca = self.as_string(); + normalize::normalize(ca, form) + } + + /// Reverses the string values + #[must_use] + #[cfg(feature = "string_reverse")] + fn str_reverse(&self) -> StringChunked { + let ca = self.as_string(); + reverse::reverse(ca) + } + + /// Slice the string values. + /// + /// Determines a substring starting from `offset` and with length `length` of each of the elements in `array`. + /// `offset` can be negative, in which case the start counts from the end of the string. + fn str_slice(&self, offset: &Column, length: &Column) -> PolarsResult { + let ca = self.as_string(); + let offset = offset.cast(&DataType::Int64)?; + // We strict cast, otherwise negative value will be treated as a valid length. + let length = length.strict_cast(&DataType::UInt64)?; + + Ok(substring::substring(ca, offset.i64()?, length.u64()?)) + } + + /// Slice the first `n` values of the string. + /// + /// Determines a substring starting at the beginning of the string up to offset `n` of each + /// element in `array`. `n` can be negative, in which case the slice ends `n` characters from + /// the end of the string. + fn str_head(&self, n: &Column) -> PolarsResult { + let ca = self.as_string(); + let n = n.strict_cast(&DataType::Int64)?; + + substring::head(ca, n.i64()?) + } + + /// Slice the last `n` values of the string. + /// + /// Determines a substring starting at offset `n` of each element in `array`. `n` can be + /// negative, in which case the slice begins `n` characters from the start of the string. + fn str_tail(&self, n: &Column) -> PolarsResult { + let ca = self.as_string(); + let n = n.strict_cast(&DataType::Int64)?; + + substring::tail(ca, n.i64()?) + } + #[cfg(feature = "strings")] + /// Escapes all regular expression meta characters in the string. + fn str_escape_regex(&self) -> StringChunked { + let ca = self.as_string(); + escape_regex::escape_regex(ca) + } +} + +impl StringNameSpaceImpl for StringChunked {} diff --git a/crates/polars-ops/src/chunked_array/strings/normalize.rs b/crates/polars-ops/src/chunked_array/strings/normalize.rs new file mode 100644 index 000000000000..29c4939d5679 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/normalize.rs @@ -0,0 +1,38 @@ +use polars_core::prelude::{StringChunked, StringChunkedBuilder}; +use unicode_normalization::UnicodeNormalization; + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum UnicodeForm { + NFC, + NFKC, + NFD, + NFKD, +} + +pub fn normalize_with( + ca: &StringChunked, + normalizer: F, +) -> StringChunked { + let mut buffer = String::new(); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); + for opt_s in ca.iter() { + if let Some(s) = opt_s { + buffer.clear(); + normalizer(s, &mut buffer); + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } + builder.finish() +} + +pub fn normalize(ca: &StringChunked, form: UnicodeForm) -> StringChunked { + match form { + UnicodeForm::NFC => normalize_with(ca, |s, b| b.extend(s.nfc())), + UnicodeForm::NFKC => normalize_with(ca, |s, b| b.extend(s.nfkc())), + UnicodeForm::NFD => normalize_with(ca, |s, b| b.extend(s.nfd())), + UnicodeForm::NFKD => normalize_with(ca, |s, b| b.extend(s.nfkd())), + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/pad.rs b/crates/polars-ops/src/chunked_array/strings/pad.rs new file mode 100644 index 000000000000..8e1bbe4a1dba --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/pad.rs @@ -0,0 +1,101 @@ +use std::fmt::Write; + +use polars_core::prelude::arity::broadcast_binary_elementwise; +use polars_core::prelude::{StringChunked, UInt64Chunked}; + +pub(super) fn pad_end<'a>(ca: &'a StringChunked, length: usize, fill_char: char) -> StringChunked { + // amortize allocation + let mut buf = String::new(); + let f = |s: &'a str| { + let padding = length.saturating_sub(s.chars().count()); + if padding == 0 { + s + } else { + buf.clear(); + buf.push_str(s); + for _ in 0..padding { + buf.push(fill_char) + } + // extend lifetime + // lifetime is bound to 'a + let slice = buf.as_str(); + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + } + }; + ca.apply_mut(f) +} + +pub(super) fn pad_start<'a>( + ca: &'a StringChunked, + length: usize, + fill_char: char, +) -> StringChunked { + // amortize allocation + let mut buf = String::new(); + let f = |s: &'a str| { + let padding = length.saturating_sub(s.chars().count()); + if padding == 0 { + s + } else { + buf.clear(); + for _ in 0..padding { + buf.push(fill_char) + } + buf.push_str(s); + // extend lifetime + // lifetime is bound to 'a + let slice = buf.as_str(); + unsafe { std::mem::transmute::<&str, &'a str>(slice) } + } + }; + ca.apply_mut(f) +} + +fn zfill_fn<'a>(s: Option<&'a str>, len: Option, buf: &mut String) -> Option<&'a str> { + match (s, len) { + (Some(s), Some(length)) => { + let length = length.saturating_sub(s.len() as u64); + if length == 0 { + return Some(s); + } + buf.clear(); + if let Some(stripped) = s.strip_prefix('-') { + write!( + buf, + "-{:0length$}{value}", + 0, + length = length as usize, + value = stripped + ) + .unwrap(); + } else { + write!( + buf, + "{:0length$}{value}", + 0, + length = length as usize, + value = s + ) + .unwrap(); + }; + // extend lifetime + // lifetime is bound to 'a + let slice = buf.as_str(); + Some(unsafe { std::mem::transmute::<&str, &'a str>(slice) }) + }, + _ => None, + } +} + +pub(super) fn zfill<'a>(ca: &'a StringChunked, length: &'a UInt64Chunked) -> StringChunked { + // amortize allocation + let mut buf = String::new(); + fn infer FnMut(Option<&'a str>, Option) -> Option<&'a str>>(f: F) -> F where { + f + } + broadcast_binary_elementwise( + ca, + length, + infer(|opt_s, opt_len| zfill_fn(opt_s, opt_len, &mut buf)), + ) +} diff --git a/crates/polars-ops/src/chunked_array/strings/reverse.rs b/crates/polars-ops/src/chunked_array/strings/reverse.rs new file mode 100644 index 000000000000..960adcbd3762 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/reverse.rs @@ -0,0 +1,15 @@ +use polars_core::prelude::StringChunked; +use polars_core::prelude::arity::unary_elementwise; +use unicode_reverse::reverse_grapheme_clusters_in_place; + +fn to_reverse_helper(s: Option<&str>) -> Option { + s.map(|v| { + let mut text = v.to_string(); + reverse_grapheme_clusters_in_place(&mut text); + text + }) +} + +pub fn reverse(ca: &StringChunked) -> StringChunked { + unary_elementwise(ca, to_reverse_helper) +} diff --git a/crates/polars-ops/src/chunked_array/strings/split.rs b/crates/polars-ops/src/chunked_array/strings/split.rs new file mode 100644 index 000000000000..0f4b8e974c05 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/split.rs @@ -0,0 +1,225 @@ +use arrow::array::ValueSize; +#[cfg(feature = "dtype-struct")] +use arrow::array::{MutableArray, MutableUtf8Array}; +use polars_core::chunked_array::ops::arity::binary_elementwise_for_each; + +use super::*; + +pub struct SplitNChars<'a> { + s: &'a str, + n: usize, + keep_remainder: bool, +} + +impl<'a> Iterator for SplitNChars<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option { + let single_char_limit = if self.keep_remainder { 2 } else { 1 }; + if self.n >= single_char_limit { + self.n -= 1; + let ch = self.s.chars().next()?; + let first; + (first, self.s) = self.s.split_at(ch.len_utf8()); + Some(first) + } else if self.n == 1 && !self.s.is_empty() { + self.n -= 1; + Some(self.s) + } else { + None + } + } +} + +/// Splits a string into substrings consisting of single characters. +/// +/// Returns at most n strings, where the last string is the entire remainder +/// of the string if keep_remainder is True, and just the nth character otherwise. +#[cfg(feature = "dtype-struct")] +fn splitn_chars(s: &str, n: usize, keep_remainder: bool) -> SplitNChars<'_> { + SplitNChars { + s, + n, + keep_remainder, + } +} + +/// Splits a string into substrings consisting of single characters. +fn split_chars(s: &str) -> SplitNChars<'_> { + SplitNChars { + s, + n: usize::MAX, + keep_remainder: false, + } +} + +#[cfg(feature = "dtype-struct")] +pub fn split_to_struct<'a, F, I>( + ca: &'a StringChunked, + by: &'a StringChunked, + n: usize, + op: F, + keep_remainder: bool, +) -> PolarsResult +where + F: Fn(&'a str, &'a str) -> I, + I: Iterator, +{ + use polars_utils::format_pl_smallstr; + + let mut arrs = (0..n) + .map(|_| MutableUtf8Array::::with_capacity(ca.len())) + .collect::>(); + + if by.len() == 1 { + if let Some(by) = by.get(0) { + if by.is_empty() { + ca.for_each(|opt_s| match opt_s { + None => { + for arr in &mut arrs { + arr.push_null() + } + }, + Some(s) => { + let mut arr_iter = arrs.iter_mut(); + splitn_chars(s, n, keep_remainder) + .zip(&mut arr_iter) + .for_each(|(splitted, arr)| arr.push(Some(splitted))); + // fill the remaining with null + for arr in arr_iter { + arr.push_null() + } + }, + }); + } else { + ca.for_each(|opt_s| match opt_s { + None => { + for arr in &mut arrs { + arr.push_null() + } + }, + Some(s) => { + let mut arr_iter = arrs.iter_mut(); + op(s, by) + .zip(&mut arr_iter) + .for_each(|(splitted, arr)| arr.push(Some(splitted))); + // fill the remaining with null + for arr in arr_iter { + arr.push_null() + } + }, + }); + } + } else { + for arr in &mut arrs { + arr.push_null() + } + } + } else { + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let mut arr_iter = arrs.iter_mut(); + if by.is_empty() { + splitn_chars(s, n, keep_remainder) + .zip(&mut arr_iter) + .for_each(|(splitted, arr)| arr.push(Some(splitted))); + } else { + op(s, by) + .zip(&mut arr_iter) + .for_each(|(splitted, arr)| arr.push(Some(splitted))); + }; + // fill the remaining with null + for arr in arr_iter { + arr.push_null() + } + }, + _ => { + for arr in &mut arrs { + arr.push_null() + } + }, + }) + } + + let fields = arrs + .into_iter() + .enumerate() + .map(|(i, mut arr)| { + Series::try_from((format_pl_smallstr!("field_{i}"), arr.as_box())).unwrap() + }) + .collect::>(); + + StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) +} + +pub fn split_helper<'a, F, I>( + ca: &'a StringChunked, + by: &'a StringChunked, + op: F, +) -> PolarsResult +where + F: Fn(&'a str, &'a str) -> I, + I: Iterator, +{ + Ok(match (ca.len(), by.len()) { + (a, b) if a == b => { + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); + + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + if by.is_empty() { + builder.append_values_iter(split_chars(s)) + } else { + builder.append_values_iter(op(s, by)) + } + }, + _ => builder.append_null(), + }); + + builder.finish() + }, + (1, _) => { + if let Some(s) = ca.get(0) { + let mut builder = ListStringChunkedBuilder::new( + by.name().clone(), + by.len(), + by.get_values_size(), + ); + + by.for_each(|opt_by| match opt_by { + Some(by) => builder.append_values_iter(op(s, by)), + _ => builder.append_null(), + }); + builder.finish() + } else { + ListChunked::full_null_with_dtype(ca.name().clone(), ca.len(), &DataType::String) + } + }, + (_, 1) => { + if let Some(by) = by.get(0) { + let mut builder = ListStringChunkedBuilder::new( + ca.name().clone(), + ca.len(), + ca.get_values_size(), + ); + + if by.is_empty() { + ca.for_each(|opt_s| match opt_s { + Some(s) => builder.append_values_iter(split_chars(s)), + _ => builder.append_null(), + }); + } else { + ca.for_each(|opt_s| match opt_s { + Some(s) => builder.append_values_iter(op(s, by)), + _ => builder.append_null(), + }); + } + builder.finish() + } else { + ListChunked::full_null_with_dtype(ca.name().clone(), ca.len(), &DataType::String) + } + }, + _ => polars_bail!(length_mismatch = "str.split", ca.len(), by.len()), + }) +} diff --git a/crates/polars-ops/src/chunked_array/strings/strip.rs b/crates/polars-ops/src/chunked_array/strings/strip.rs new file mode 100644 index 000000000000..cd92704d6bfe --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/strip.rs @@ -0,0 +1,143 @@ +use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise}; + +use super::*; + +fn strip_chars_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + if pat.chars().count() == 1 { + Some(s.trim_matches(pat.chars().next().unwrap())) + } else { + Some(s.trim_matches(|c| pat.contains(c))) + } + }, + (Some(s), _) => Some(s.trim()), + _ => None, + } +} + +fn strip_chars_start_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + if pat.chars().count() == 1 { + Some(s.trim_start_matches(pat.chars().next().unwrap())) + } else { + Some(s.trim_start_matches(|c| pat.contains(c))) + } + }, + (Some(s), _) => Some(s.trim_start()), + _ => None, + } +} + +fn strip_chars_end_binary<'a>(opt_s: Option<&'a str>, opt_pat: Option<&str>) -> Option<&'a str> { + match (opt_s, opt_pat) { + (Some(s), Some(pat)) => { + if pat.chars().count() == 1 { + Some(s.trim_end_matches(pat.chars().next().unwrap())) + } else { + Some(s.trim_end_matches(|c| pat.contains(c))) + } + }, + (Some(s), _) => Some(s.trim_end()), + _ => None, + } +} + +fn strip_prefix_binary<'a>(s: Option<&'a str>, prefix: Option<&str>) -> Option<&'a str> { + Some(s?.strip_prefix(prefix?).unwrap_or(s?)) +} + +fn strip_suffix_binary<'a>(s: Option<&'a str>, suffix: Option<&str>) -> Option<&'a str> { + Some(s?.strip_suffix(suffix?).unwrap_or(s?)) +} + +pub fn strip_chars(ca: &StringChunked, pat: &StringChunked) -> StringChunked { + match pat.len() { + 1 => { + if let Some(pat) = pat.get(0) { + if pat.chars().count() == 1 { + // Fast path for when a single character is passed + unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.trim_matches(pat.chars().next().unwrap())) + }) + } else { + unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.trim_matches(|c| pat.contains(c))) + }) + } + } else { + unary_elementwise(ca, |opt_s| opt_s.map(|s| s.trim())) + } + }, + _ => broadcast_binary_elementwise(ca, pat, strip_chars_binary), + } +} + +pub fn strip_chars_start(ca: &StringChunked, pat: &StringChunked) -> StringChunked { + match pat.len() { + 1 => { + if let Some(pat) = pat.get(0) { + if pat.chars().count() == 1 { + // Fast path for when a single character is passed + unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.trim_start_matches(pat.chars().next().unwrap())) + }) + } else { + unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.trim_start_matches(|c| pat.contains(c))) + }) + } + } else { + unary_elementwise(ca, |opt_s| opt_s.map(|s| s.trim_start())) + } + }, + _ => broadcast_binary_elementwise(ca, pat, strip_chars_start_binary), + } +} + +pub fn strip_chars_end(ca: &StringChunked, pat: &StringChunked) -> StringChunked { + match pat.len() { + 1 => { + if let Some(pat) = pat.get(0) { + if pat.chars().count() == 1 { + // Fast path for when a single character is passed + unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.trim_end_matches(pat.chars().next().unwrap())) + }) + } else { + unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.trim_end_matches(|c| pat.contains(c))) + }) + } + } else { + unary_elementwise(ca, |opt_s| opt_s.map(|s| s.trim_end())) + } + }, + _ => broadcast_binary_elementwise(ca, pat, strip_chars_end_binary), + } +} + +pub fn strip_prefix(ca: &StringChunked, prefix: &StringChunked) -> StringChunked { + match prefix.len() { + 1 => match prefix.get(0) { + Some(prefix) => unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s)) + }), + _ => StringChunked::full_null(ca.name().clone(), ca.len()), + }, + _ => broadcast_binary_elementwise(ca, prefix, strip_prefix_binary), + } +} + +pub fn strip_suffix(ca: &StringChunked, suffix: &StringChunked) -> StringChunked { + match suffix.len() { + 1 => match suffix.get(0) { + Some(suffix) => unary_elementwise(ca, |opt_s| { + opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s)) + }), + _ => StringChunked::full_null(ca.name().clone(), ca.len()), + }, + _ => broadcast_binary_elementwise(ca, suffix, strip_suffix_binary), + } +} diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs new file mode 100644 index 000000000000..3631b966e7c9 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -0,0 +1,275 @@ +use std::cmp::Ordering; + +use arrow::array::View; +use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise}; +use polars_core::prelude::{ChunkFullNull, Int64Chunked, StringChunked, UInt64Chunked}; +use polars_error::{PolarsResult, polars_ensure}; + +fn head_binary(opt_str_val: Option<&str>, opt_n: Option) -> Option<&str> { + if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) { + let end_idx = head_binary_values(str_val, n); + Some(unsafe { str_val.get_unchecked(..end_idx) }) + } else { + None + } +} + +fn head_binary_values(str_val: &str, n: i64) -> usize { + match n.cmp(&0) { + Ordering::Equal => 0, + Ordering::Greater => { + if n as usize >= str_val.len() { + return str_val.len(); + } + // End after the nth codepoint. + str_val + .char_indices() + .nth(n as usize) + .map(|(idx, _)| idx) + .unwrap_or(str_val.len()) + }, + _ => { + // End after the nth codepoint from the end. + str_val + .char_indices() + .rev() + .nth((-n - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) + }, + } +} + +fn tail_binary(opt_str_val: Option<&str>, opt_n: Option) -> Option<&str> { + if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) { + let start_idx = tail_binary_values(str_val, n); + Some(unsafe { str_val.get_unchecked(start_idx..) }) + } else { + None + } +} + +fn tail_binary_values(str_val: &str, n: i64) -> usize { + // `max_len` is guaranteed to be at least the total number of characters. + let max_len = str_val.len(); + + match n.cmp(&0) { + Ordering::Equal => max_len, + Ordering::Greater => { + if n as usize >= max_len { + return 0; + } + // Start from nth codepoint from the end + str_val + .char_indices() + .rev() + .nth((n - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) + }, + _ => { + // Start after the nth codepoint + str_val + .char_indices() + .nth((-n) as usize) + .map(|(idx, _)| idx) + .unwrap_or(max_len) + }, + } +} + +fn substring_ternary_offsets( + opt_str_val: Option<&str>, + opt_offset: Option, + opt_length: Option, +) -> Option<(usize, usize)> { + let str_val = opt_str_val?; + let offset = opt_offset?; + Some(substring_ternary_offsets_value( + str_val, + offset, + opt_length.unwrap_or(u64::MAX), + )) +} + +pub fn substring_ternary_offsets_value(str_val: &str, offset: i64, length: u64) -> (usize, usize) { + // Fast-path: always empty string. + if length == 0 || offset >= str_val.len() as i64 { + return (0, 0); + } + + let mut indices = str_val.char_indices().map(|(o, _)| o); + let mut length_reduction = 0; + let start_byte_offset = if offset >= 0 { + indices.nth(offset as usize).unwrap_or(str_val.len()) + } else { + // If `offset` is negative, it counts from the end of the string. + let mut chars_skipped = 0; + let found = indices + .inspect(|_| chars_skipped += 1) + .nth_back((-offset - 1) as usize); + + // If we didn't find our char that means our offset was so negative it + // is before the start of our string. This means our length must be + // reduced, assuming it is finite. + if let Some(off) = found { + off + } else { + length_reduction = (-offset) as usize - chars_skipped; + 0 + } + }; + + let str_val = &str_val[start_byte_offset..]; + let mut indices = str_val.char_indices().map(|(o, _)| o); + let stop_byte_offset = indices + .nth((length as usize).saturating_sub(length_reduction)) + .unwrap_or(str_val.len()); + (start_byte_offset, stop_byte_offset + start_byte_offset) +} + +fn substring_ternary( + opt_str_val: Option<&str>, + opt_offset: Option, + opt_length: Option, +) -> Option<&str> { + let (start, end) = substring_ternary_offsets(opt_str_val, opt_offset, opt_length)?; + unsafe { opt_str_val.map(|str_val| str_val.get_unchecked(start..end)) } +} + +pub fn update_view(mut view: View, start: usize, end: usize, val: &str) -> View { + let length = (end - start) as u32; + view.length = length; + + // SAFETY: we just compute the start /end. + let subval = unsafe { val.get_unchecked(start..end).as_bytes() }; + + if length <= 12 { + View::new_inline(subval) + } else { + view.offset += start as u32; + view.length = length; + view.prefix = u32::from_le_bytes(subval[0..4].try_into().unwrap()); + view + } +} + +pub(super) fn substring( + ca: &StringChunked, + offset: &Int64Chunked, + length: &UInt64Chunked, +) -> StringChunked { + match (ca.len(), offset.len(), length.len()) { + (1, 1, _) => { + let str_val = ca.get(0); + let offset = offset.get(0); + unary_elementwise(length, |length| substring_ternary(str_val, offset, length)) + .with_name(ca.name().clone()) + }, + (_, 1, 1) => { + let offset = offset.get(0); + let length = length.get(0).unwrap_or(u64::MAX); + + let Some(offset) = offset else { + return StringChunked::full_null(ca.name().clone(), ca.len()); + }; + + unsafe { + ca.apply_views(|view, val| { + let (start, end) = substring_ternary_offsets_value(val, offset, length); + update_view(view, start, end, val) + }) + } + }, + (1, _, 1) => { + let str_val = ca.get(0); + let length = length.get(0); + unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length)) + .with_name(ca.name().clone()) + }, + (1, len_b, len_c) if len_b == len_c => { + let str_val = ca.get(0); + binary_elementwise(offset, length, |offset, length| { + substring_ternary(str_val, offset, length) + }) + }, + (len_a, 1, len_c) if len_a == len_c => { + fn infer FnMut(Option<&'a str>, Option) -> Option<&'a str>>(f: F) -> F where + { + f + } + let offset = offset.get(0); + binary_elementwise( + ca, + length, + infer(|str_val, length| substring_ternary(str_val, offset, length)), + ) + }, + (len_a, len_b, 1) if len_a == len_b => { + fn infer FnMut(Option<&'a str>, Option) -> Option<&'a str>>(f: F) -> F where + { + f + } + let length = length.get(0); + binary_elementwise( + ca, + offset, + infer(|str_val, offset| substring_ternary(str_val, offset, length)), + ) + }, + _ => ternary_elementwise(ca, offset, length, substring_ternary), + } +} + +pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult { + match (ca.len(), n.len()) { + (len, 1) => { + let n = n.get(0); + let Some(n) = n else { + return Ok(StringChunked::full_null(ca.name().clone(), len)); + }; + + Ok(unsafe { + ca.apply_views(|view, val| { + let end = head_binary_values(val, n); + update_view(view, 0, end, val) + }) + }) + }, + // TODO! below should also work on only views + (1, _) => { + let str_val = ca.get(0); + Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name().clone())) + }, + (a, b) => { + polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b); + Ok(binary_elementwise(ca, n, head_binary)) + }, + } +} + +pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult { + Ok(match (ca.len(), n.len()) { + (len, 1) => { + let n = n.get(0); + let Some(n) = n else { + return Ok(StringChunked::full_null(ca.name().clone(), len)); + }; + unsafe { + ca.apply_views(|view, val| { + let start = tail_binary_values(val, n); + update_view(view, start, val.len(), val) + }) + } + }, + // TODO! below should also work on only views + (1, _) => { + let str_val = ca.get(0); + unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name().clone()) + }, + (a, b) => { + polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.tail' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b); + binary_elementwise(ca, n, tail_binary) + }, + }) +} diff --git a/crates/polars-ops/src/chunked_array/strings/unicode_internals/mod.rs b/crates/polars-ops/src/chunked_array/strings/unicode_internals/mod.rs new file mode 100644 index 000000000000..768399942418 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/unicode_internals/mod.rs @@ -0,0 +1,5 @@ +mod unicode_data; + +// For use in alloc, not re-exported in std. +pub(super) use unicode_data::case_ignorable::lookup as Case_Ignorable; +pub(super) use unicode_data::cased::lookup as Cased; diff --git a/crates/polars-ops/src/chunked_array/strings/unicode_internals/unicode_data.rs b/crates/polars-ops/src/chunked_array/strings/unicode_internals/unicode_data.rs new file mode 100644 index 000000000000..c7ac78e16ce7 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/unicode_internals/unicode_data.rs @@ -0,0 +1,134 @@ +//! This file is generated by src/tools/unicode-table-generator; do not edit manually! + +fn decode_prefix_sum(short_offset_run_header: u32) -> u32 { + short_offset_run_header & ((1 << 21) - 1) +} + +fn decode_length(short_offset_run_header: u32) -> usize { + (short_offset_run_header >> 21) as usize +} + +#[inline(always)] +fn skip_search( + needle: u32, + short_offset_runs: &[u32; SOR], + offsets: &[u8; OFFSETS], +) -> bool { + // Note that this *cannot* be past the end of the array, as the last + // element is greater than std::char::MAX (the largest possible needle). + // + // So, we cannot have found it (i.e. Ok(idx) + 1 != length) and the correct + // location cannot be past it, so Err(idx) != length either. + // + // This means that we can avoid bounds checking for the accesses below, too. + let last_idx = + match short_offset_runs.binary_search_by_key(&(needle << 11), |header| header << 11) { + Ok(idx) => idx + 1, + Err(idx) => idx, + }; + + let mut offset_idx = decode_length(short_offset_runs[last_idx]); + let length = if let Some(next) = short_offset_runs.get(last_idx + 1) { + decode_length(*next) - offset_idx + } else { + offsets.len() - offset_idx + }; + let prev = last_idx + .checked_sub(1) + .map(|prev| decode_prefix_sum(short_offset_runs[prev])) + .unwrap_or(0); + + let total = needle - prev; + let mut prefix_sum = 0; + for _ in 0..(length - 1) { + let offset = offsets[offset_idx]; + prefix_sum += offset as u32; + if prefix_sum > total { + break; + } + offset_idx += 1; + } + offset_idx % 2 == 1 +} + +#[rustfmt::skip] +pub mod case_ignorable { + static SHORT_OFFSET_RUNS: [u32; 35] = [ + 688, 44045149, 572528402, 576724925, 807414908, 878718981, 903913493, 929080568, 933275148, + 937491230, 1138818560, 1147208189, 1210124160, 1222707713, 1235291428, 1260457643, + 1264654383, 1499535675, 1507925040, 1566646003, 1629566000, 1650551536, 1658941263, + 1671540720, 1688321181, 1700908800, 1709298023, 1717688832, 1738661888, 1763828398, + 1797383403, 1805773008, 1809970171, 1819148289, 1824457200, + ]; + static OFFSETS: [u8; 875] = [ + 39, 1, 6, 1, 11, 1, 35, 1, 1, 1, 71, 1, 4, 1, 1, 1, 4, 1, 2, 2, 0, 192, 4, 2, 4, 1, 9, 2, + 1, 1, 251, 7, 207, 1, 5, 1, 49, 45, 1, 1, 1, 2, 1, 2, 1, 1, 44, 1, 11, 6, 10, 11, 1, 1, 35, + 1, 10, 21, 16, 1, 101, 8, 1, 10, 1, 4, 33, 1, 1, 1, 30, 27, 91, 11, 58, 11, 4, 1, 2, 1, 24, + 24, 43, 3, 44, 1, 7, 2, 6, 8, 41, 58, 55, 1, 1, 1, 4, 8, 4, 1, 3, 7, 10, 2, 13, 1, 15, 1, + 58, 1, 4, 4, 8, 1, 20, 2, 26, 1, 2, 2, 57, 1, 4, 2, 4, 2, 2, 3, 3, 1, 30, 2, 3, 1, 11, 2, + 57, 1, 4, 5, 1, 2, 4, 1, 20, 2, 22, 6, 1, 1, 58, 1, 2, 1, 1, 4, 8, 1, 7, 2, 11, 2, 30, 1, + 61, 1, 12, 1, 50, 1, 3, 1, 55, 1, 1, 3, 5, 3, 1, 4, 7, 2, 11, 2, 29, 1, 58, 1, 2, 1, 6, 1, + 5, 2, 20, 2, 28, 2, 57, 2, 4, 4, 8, 1, 20, 2, 29, 1, 72, 1, 7, 3, 1, 1, 90, 1, 2, 7, 11, 9, + 98, 1, 2, 9, 9, 1, 1, 7, 73, 2, 27, 1, 1, 1, 1, 1, 55, 14, 1, 5, 1, 2, 5, 11, 1, 36, 9, 1, + 102, 4, 1, 6, 1, 2, 2, 2, 25, 2, 4, 3, 16, 4, 13, 1, 2, 2, 6, 1, 15, 1, 94, 1, 0, 3, 0, 3, + 29, 2, 30, 2, 30, 2, 64, 2, 1, 7, 8, 1, 2, 11, 3, 1, 5, 1, 45, 5, 51, 1, 65, 2, 34, 1, 118, + 3, 4, 2, 9, 1, 6, 3, 219, 2, 2, 1, 58, 1, 1, 7, 1, 1, 1, 1, 2, 8, 6, 10, 2, 1, 39, 1, 8, 31, + 49, 4, 48, 1, 1, 5, 1, 1, 5, 1, 40, 9, 12, 2, 32, 4, 2, 2, 1, 3, 56, 1, 1, 2, 3, 1, 1, 3, + 58, 8, 2, 2, 64, 6, 82, 3, 1, 13, 1, 7, 4, 1, 6, 1, 3, 2, 50, 63, 13, 1, 34, 101, 0, 1, 1, + 3, 11, 3, 13, 3, 13, 3, 13, 2, 12, 5, 8, 2, 10, 1, 2, 1, 2, 5, 49, 5, 1, 10, 1, 1, 13, 1, + 16, 13, 51, 33, 0, 2, 113, 3, 125, 1, 15, 1, 96, 32, 47, 1, 0, 1, 36, 4, 3, 5, 5, 1, 93, 6, + 93, 3, 0, 1, 0, 6, 0, 1, 98, 4, 1, 10, 1, 1, 28, 4, 80, 2, 14, 34, 78, 1, 23, 3, 103, 3, 3, + 2, 8, 1, 3, 1, 4, 1, 25, 2, 5, 1, 151, 2, 26, 18, 13, 1, 38, 8, 25, 11, 46, 3, 48, 1, 2, 4, + 2, 2, 17, 1, 21, 2, 66, 6, 2, 2, 2, 2, 12, 1, 8, 1, 35, 1, 11, 1, 51, 1, 1, 3, 2, 2, 5, 2, + 1, 1, 27, 1, 14, 2, 5, 2, 1, 1, 100, 5, 9, 3, 121, 1, 2, 1, 4, 1, 0, 1, 147, 17, 0, 16, 3, + 1, 12, 16, 34, 1, 2, 1, 169, 1, 7, 1, 6, 1, 11, 1, 35, 1, 1, 1, 47, 1, 45, 2, 67, 1, 21, 3, + 0, 1, 226, 1, 149, 5, 0, 6, 1, 42, 1, 9, 0, 3, 1, 2, 5, 4, 40, 3, 4, 1, 165, 2, 0, 4, 0, 2, + 80, 3, 70, 11, 49, 4, 123, 1, 54, 15, 41, 1, 2, 2, 10, 3, 49, 4, 2, 2, 2, 1, 4, 1, 10, 1, + 50, 3, 36, 5, 1, 8, 62, 1, 12, 2, 52, 9, 10, 4, 2, 1, 95, 3, 2, 1, 1, 2, 6, 1, 2, 1, 157, 1, + 3, 8, 21, 2, 57, 2, 3, 1, 37, 7, 3, 5, 195, 8, 2, 3, 1, 1, 23, 1, 84, 6, 1, 1, 4, 2, 1, 2, + 238, 4, 6, 2, 1, 2, 27, 2, 85, 8, 2, 1, 1, 2, 106, 1, 1, 1, 2, 6, 1, 1, 101, 3, 2, 4, 1, 5, + 0, 9, 1, 2, 0, 2, 1, 1, 4, 1, 144, 4, 2, 2, 4, 1, 32, 10, 40, 6, 2, 4, 8, 1, 9, 6, 2, 3, 46, + 13, 1, 2, 0, 7, 1, 6, 1, 1, 82, 22, 2, 7, 1, 2, 1, 2, 122, 6, 3, 1, 1, 2, 1, 7, 1, 1, 72, 2, + 3, 1, 1, 1, 0, 2, 11, 2, 52, 5, 5, 1, 1, 1, 0, 17, 6, 15, 0, 5, 59, 7, 9, 4, 0, 1, 63, 17, + 64, 2, 1, 2, 0, 4, 1, 7, 1, 2, 0, 2, 1, 4, 0, 46, 2, 23, 0, 3, 9, 16, 2, 7, 30, 4, 148, 3, + 0, 55, 4, 50, 8, 1, 14, 1, 22, 5, 1, 15, 0, 7, 1, 17, 2, 7, 1, 2, 1, 5, 5, 62, 33, 1, 160, + 14, 0, 1, 61, 4, 0, 5, 0, 7, 109, 8, 0, 5, 0, 1, 30, 96, 128, 240, 0, + ]; + pub fn lookup(c: char) -> bool { + super::skip_search( + c as u32, + &SHORT_OFFSET_RUNS, + &OFFSETS, + ) + } +} + +#[rustfmt::skip] +pub mod cased { + static SHORT_OFFSET_RUNS: [u32; 22] = [ + 4256, 115348384, 136322176, 144711446, 163587254, 320875520, 325101120, 350268208, + 392231680, 404815649, 413205504, 421595008, 467733632, 484513952, 492924480, 497144832, + 501339814, 578936576, 627171376, 639756544, 643952944, 649261450, + ]; + static OFFSETS: [u8; 315] = [ + 65, 26, 6, 26, 47, 1, 10, 1, 4, 1, 5, 23, 1, 31, 1, 195, 1, 4, 4, 208, 1, 36, 7, 2, 30, 5, + 96, 1, 42, 4, 2, 2, 2, 4, 1, 1, 6, 1, 1, 3, 1, 1, 1, 20, 1, 83, 1, 139, 8, 166, 1, 38, 9, + 41, 0, 38, 1, 1, 5, 1, 2, 43, 1, 4, 0, 86, 2, 6, 0, 9, 7, 43, 2, 3, 64, 192, 64, 0, 2, 6, 2, + 38, 2, 6, 2, 8, 1, 1, 1, 1, 1, 1, 1, 31, 2, 53, 1, 7, 1, 1, 3, 3, 1, 7, 3, 4, 2, 6, 4, 13, + 5, 3, 1, 7, 116, 1, 13, 1, 16, 13, 101, 1, 4, 1, 2, 10, 1, 1, 3, 5, 6, 1, 1, 1, 1, 1, 1, 4, + 1, 6, 4, 1, 2, 4, 5, 5, 4, 1, 17, 32, 3, 2, 0, 52, 0, 229, 6, 4, 3, 2, 12, 38, 1, 1, 5, 1, + 0, 46, 18, 30, 132, 102, 3, 4, 1, 59, 5, 2, 1, 1, 1, 5, 24, 5, 1, 3, 0, 43, 1, 14, 6, 80, 0, + 7, 12, 5, 0, 26, 6, 26, 0, 80, 96, 36, 4, 36, 116, 11, 1, 15, 1, 7, 1, 2, 1, 11, 1, 15, 1, + 7, 1, 2, 0, 1, 2, 3, 1, 42, 1, 9, 0, 51, 13, 51, 0, 64, 0, 64, 0, 85, 1, 71, 1, 2, 2, 1, 2, + 2, 2, 4, 1, 12, 1, 1, 1, 7, 1, 65, 1, 4, 2, 8, 1, 7, 1, 28, 1, 4, 1, 5, 1, 1, 3, 7, 1, 0, 2, + 25, 1, 25, 1, 31, 1, 25, 1, 31, 1, 25, 1, 31, 1, 25, 1, 31, 1, 25, 1, 8, 0, 10, 1, 20, 6, 6, + 0, 62, 0, 68, 0, 26, 6, 26, 6, 26, 0, + ]; + pub fn lookup(c: char) -> bool { + super::skip_search( + c as u32, + &SHORT_OFFSET_RUNS, + &OFFSETS, + ) + } +} diff --git a/crates/polars-ops/src/chunked_array/sum.rs b/crates/polars-ops/src/chunked_array/sum.rs new file mode 100644 index 000000000000..f07ce2b4071c --- /dev/null +++ b/crates/polars-ops/src/chunked_array/sum.rs @@ -0,0 +1,17 @@ +use arrow::types::NativeType; +use num_traits::{NumCast, ToPrimitive}; + +pub(super) fn sum_slice(values: &[T]) -> S +where + T: NativeType + ToPrimitive, + S: NumCast + std::iter::Sum, +{ + values + .iter() + .copied() + .map(|t| unsafe { + let s: S = NumCast::from(t).unwrap_unchecked(); + s + }) + .sum() +} diff --git a/crates/polars-ops/src/chunked_array/top_k.rs b/crates/polars-ops/src/chunked_array/top_k.rs new file mode 100644 index 000000000000..3088a6ae0ca1 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/top_k.rs @@ -0,0 +1,290 @@ +use arrow::array::{BinaryViewArray, BooleanArray, PrimitiveArray, StaticArray, View}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use polars_core::chunked_array::ops::sort::arg_bottom_k::_arg_bottom_k; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::{POOL, downcast_as_macro_arg_physical}; +use polars_utils::total_ord::TotalOrd; + +fn first_n_valid_mask(num_valid: usize, out_len: usize) -> Option { + if num_valid < out_len { + let mut bm = BitmapBuilder::with_capacity(out_len); + bm.extend_constant(num_valid, true); + bm.extend_constant(out_len - num_valid, false); + Some(bm.freeze()) + } else { + None + } +} + +fn top_k_bool_impl( + ca: &ChunkedArray, + k: usize, + descending: bool, +) -> ChunkedArray { + if k >= ca.len() && ca.null_count() == 0 { + return ca.clone(); + } + + let null_count = ca.null_count(); + let non_null_count = ca.len() - ca.null_count(); + let true_count = ca.sum().unwrap() as usize; + let false_count = non_null_count - true_count; + let mut out_len = k.min(ca.len()); + let validity = first_n_valid_mask(non_null_count, out_len); + + // Logical sequence of physical bits. + let sequence = if descending { + [ + (false_count, false), + (true_count, true), + (null_count, false), + ] + } else { + [ + (true_count, true), + (false_count, false), + (null_count, false), + ] + }; + + let mut bm = BitmapBuilder::with_capacity(out_len); + for (n, value) in sequence { + if out_len == 0 { + break; + } + let extra = out_len.min(n); + bm.extend_constant(extra, value); + out_len -= extra; + } + + let arr = BooleanArray::from_data_default(bm.freeze(), validity); + ChunkedArray::with_chunk_like(ca, arr) +} + +fn top_k_num_impl(ca: &ChunkedArray, k: usize, descending: bool) -> ChunkedArray +where + T: PolarsNumericType, +{ + if k >= ca.len() && ca.null_count() == 0 { + return ca.clone(); + } + + // Get rid of all the nulls and transform into Vec. + let mut nnca = ca.drop_nulls(); + nnca.rechunk_mut(); + let chunk = nnca.downcast_into_iter().next().unwrap(); + let (_, buffer, _) = chunk.into_inner(); + let mut vec = buffer.make_mut(); + + // Partition. + if k < vec.len() { + if descending { + vec.select_nth_unstable_by(k, TotalOrd::tot_cmp); + } else { + vec.select_nth_unstable_by(k, |a, b| TotalOrd::tot_cmp(b, a)); + } + } + + // Reconstruct output (with nulls at the end). + let out_len = k.min(ca.len()); + let non_null_count = ca.len() - ca.null_count(); + vec.resize(out_len, T::Native::default()); + let validity = first_n_valid_mask(non_null_count, out_len); + + let arr = PrimitiveArray::from_vec(vec).with_validity_typed(validity); + ChunkedArray::with_chunk_like(ca, arr) +} + +fn top_k_binary_impl( + ca: &ChunkedArray, + k: usize, + descending: bool, +) -> ChunkedArray { + if k >= ca.len() && ca.null_count() == 0 { + return ca.clone(); + } + + // Get rid of all the nulls and transform into mutable views. + let mut nnca = ca.drop_nulls(); + nnca.rechunk_mut(); + let chunk = nnca.downcast_into_iter().next().unwrap(); + let buffers = chunk.data_buffers().clone(); + let mut views = chunk.into_views(); + + // Partition. + if k < views.len() { + if descending { + views.select_nth_unstable_by(k, |a, b| unsafe { + let a_sl = a.get_slice_unchecked(&buffers); + let b_sl = b.get_slice_unchecked(&buffers); + a_sl.cmp(b_sl) + }); + } else { + views.select_nth_unstable_by(k, |a, b| unsafe { + let a_sl = a.get_slice_unchecked(&buffers); + let b_sl = b.get_slice_unchecked(&buffers); + b_sl.cmp(a_sl) + }); + } + } + + // Reconstruct output (with nulls at the end). + let out_len = k.min(ca.len()); + let non_null_count = ca.len() - ca.null_count(); + views.resize(out_len, View::default()); + let validity = first_n_valid_mask(non_null_count, out_len); + + let arr = unsafe { + BinaryViewArray::new_unchecked_unknown_md( + ArrowDataType::BinaryView, + views.into(), + buffers, + validity, + None, + ) + }; + ChunkedArray::with_chunk_like(ca, arr) +} + +pub fn top_k(s: &[Column], descending: bool) -> PolarsResult { + fn extract_target_and_k(s: &[Column]) -> PolarsResult<(usize, &Column)> { + let k_s = &s[1]; + polars_ensure!( + k_s.len() == 1, + ComputeError: "`k` must be a single value for `top_k`." + ); + + let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else { + polars_bail!(ComputeError: "`k` must be set for `top_k`") + }; + + let src = &s[0]; + Ok((k as usize, src)) + } + + let (k, src) = extract_target_and_k(s)?; + + if src.is_empty() { + return Ok(src.clone()); + } + + let sorted_flag = src.is_sorted_flag(); + let is_sorted = match src.is_sorted_flag() { + IsSorted::Ascending => true, + IsSorted::Descending => true, + IsSorted::Not => false, + }; + if is_sorted { + let out_len = k.min(src.len()); + let ignored_len = src.len() - out_len; + let slice_at_start = (sorted_flag == IsSorted::Ascending) == descending; + let nulls_at_start = src.get(0).unwrap() == AnyValue::Null; + let offset = if nulls_at_start == slice_at_start { + src.null_count().min(ignored_len) + } else { + 0 + }; + + return if slice_at_start { + Ok(src.slice(offset as i64, out_len)) + } else { + Ok(src.slice(-(offset as i64) - (out_len as i64), out_len)) + }; + } + + let origin_dtype = src.dtype(); + + let s = src.to_physical_repr(); + + match s.dtype() { + DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, descending).into_column()), + DataType::String => { + let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, descending); + let ca = unsafe { ca.to_string_unchecked() }; + Ok(ca.into_column()) + }, + DataType::Binary => Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_column()), + DataType::Null => Ok(src.slice(0, k)), + dt if dt.is_primitive_numeric() => { + macro_rules! dispatch { + ($ca:expr) => {{ top_k_num_impl($ca, k, descending).into_column() }}; + } + unsafe { + downcast_as_macro_arg_physical!(&s, dispatch).from_physical_unchecked(origin_dtype) + } + }, + _ => { + // Fallback to more generic impl. + top_k_by_impl(k, src, &[src.clone()], vec![descending]) + }, + } +} + +pub fn top_k_by(s: &[Column], descending: Vec) -> PolarsResult { + /// Return (k, src, by) + fn extract_parameters(s: &[Column]) -> PolarsResult<(usize, &Column, &[Column])> { + let k_s = &s[1]; + + polars_ensure!( + k_s.len() == 1, + ComputeError: "`k` must be a single value for `top_k`." + ); + + let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else { + polars_bail!(ComputeError: "`k` must be set for `top_k`") + }; + + let src = &s[0]; + + let by = &s[2..]; + + Ok((k as usize, src, by)) + } + + let (k, src, by) = extract_parameters(s)?; + + if src.is_empty() { + return Ok(src.clone()); + } + + if by.first().map(|x| x.is_empty()).unwrap_or(false) { + return Ok(src.clone()); + } + + for s in by { + if s.len() != src.len() { + polars_bail!(ComputeError: "`by` column's ({}) length ({}) should have the same length as the source column length ({}) in `top_k`", s.name(), s.len(), src.len()) + } + } + + top_k_by_impl(k, src, by, descending) +} + +fn top_k_by_impl( + k: usize, + src: &Column, + by: &[Column], + descending: Vec, +) -> PolarsResult { + if src.is_empty() { + return Ok(src.clone()); + } + + let multithreaded = k >= 10000 && POOL.current_num_threads() > 1; + let mut sort_options = SortMultipleOptions { + descending: descending.into_iter().map(|x| !x).collect(), + nulls_last: vec![true; by.len()], + multithreaded, + maintain_order: false, + limit: None, + }; + + let idx = _arg_bottom_k(k, by, &mut sort_options)?; + + let result = unsafe { + src.as_materialized_series() + .take_unchecked(&idx.into_inner()) + }; + Ok(result.into()) +} diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs new file mode 100644 index 000000000000..fda5397cd375 --- /dev/null +++ b/crates/polars-ops/src/frame/join/args.rs @@ -0,0 +1,391 @@ +use super::*; + +pub(super) type JoinIds = Vec; +pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds); +pub type InnerJoinIds = (JoinIds, JoinIds); + +#[cfg(feature = "chunked_ids")] +pub(super) type ChunkJoinIds = Either, Vec>; +#[cfg(feature = "chunked_ids")] +pub type ChunkJoinOptIds = Either, Vec>; + +#[cfg(not(feature = "chunked_ids"))] +pub type ChunkJoinOptIds = Vec; + +#[cfg(not(feature = "chunked_ids"))] +pub type ChunkJoinIds = Vec; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct JoinArgs { + pub how: JoinType, + pub validation: JoinValidation, + pub suffix: Option, + pub slice: Option<(i64, usize)>, + pub nulls_equal: bool, + pub coalesce: JoinCoalesce, + pub maintain_order: MaintainOrderJoin, +} + +impl JoinArgs { + pub fn should_coalesce(&self) -> bool { + self.coalesce.coalesce(&self.how) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Default, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum JoinType { + #[default] + Inner, + Left, + Right, + Full, + #[cfg(feature = "asof_join")] + AsOf(AsOfOptions), + #[cfg(feature = "semi_anti_join")] + Semi, + #[cfg(feature = "semi_anti_join")] + Anti, + #[cfg(feature = "iejoin")] + // Options are set by optimizer/planner in Options + IEJoin, + // Options are set by optimizer/planner in Options + Cross, +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum JoinCoalesce { + #[default] + JoinSpecific, + CoalesceColumns, + KeepColumns, +} + +impl JoinCoalesce { + pub fn coalesce(&self, join_type: &JoinType) -> bool { + use JoinCoalesce::*; + use JoinType::*; + match join_type { + Left | Inner | Right => { + matches!(self, JoinSpecific | CoalesceColumns) + }, + Full => { + matches!(self, CoalesceColumns) + }, + #[cfg(feature = "asof_join")] + AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns), + #[cfg(feature = "iejoin")] + IEJoin => false, + Cross => false, + #[cfg(feature = "semi_anti_join")] + Semi | Anti => false, + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum MaintainOrderJoin { + #[default] + None, + Left, + Right, + LeftRight, + RightLeft, +} + +impl MaintainOrderJoin { + pub(super) fn flip(&self) -> Self { + match self { + MaintainOrderJoin::None => MaintainOrderJoin::None, + MaintainOrderJoin::Left => MaintainOrderJoin::Right, + MaintainOrderJoin::Right => MaintainOrderJoin::Left, + MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft, + MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight, + } + } +} + +impl JoinArgs { + pub fn new(how: JoinType) -> Self { + Self { + how, + validation: Default::default(), + suffix: None, + slice: None, + nulls_equal: false, + coalesce: Default::default(), + maintain_order: Default::default(), + } + } + + pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self { + self.coalesce = coalesce; + self + } + + pub fn with_suffix(mut self, suffix: Option) -> Self { + self.suffix = suffix; + self + } + + pub fn suffix(&self) -> &PlSmallStr { + const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right"); + self.suffix.as_ref().unwrap_or(DEFAULT) + } +} + +impl From for JoinArgs { + fn from(value: JoinType) -> Self { + JoinArgs::new(value) + } +} + +pub trait CrossJoinFilter: Send + Sync { + fn apply(&self, df: DataFrame) -> PolarsResult; +} + +impl CrossJoinFilter for T +where + T: Fn(DataFrame) -> PolarsResult + Send + Sync, +{ + fn apply(&self, df: DataFrame) -> PolarsResult { + self(df) + } +} + +#[derive(Clone)] +pub struct CrossJoinOptions { + pub predicate: Arc, +} + +impl CrossJoinOptions { + fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter { + Arc::as_ptr(&self.predicate) + } +} + +impl Eq for CrossJoinOptions {} + +impl PartialEq for CrossJoinOptions { + fn eq(&self, other: &Self) -> bool { + std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref()) + } +} + +impl Hash for CrossJoinOptions { + fn hash(&self, state: &mut H) { + self.as_ptr_ref().hash(state); + } +} + +impl Debug for CrossJoinOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CrossJoinOptions",) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum JoinTypeOptions { + #[cfg(feature = "iejoin")] + IEJoin(IEJoinOptions), + #[cfg_attr(feature = "serde", serde(skip))] + Cross(CrossJoinOptions), +} + +impl Display for JoinType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use JoinType::*; + let val = match self { + Left => "LEFT", + Right => "RIGHT", + Inner => "INNER", + Full => "FULL", + #[cfg(feature = "asof_join")] + AsOf(_) => "ASOF", + #[cfg(feature = "iejoin")] + IEJoin => "IEJOIN", + Cross => "CROSS", + #[cfg(feature = "semi_anti_join")] + Semi => "SEMI", + #[cfg(feature = "semi_anti_join")] + Anti => "ANTI", + }; + write!(f, "{val}") + } +} + +impl Debug for JoinType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +impl JoinType { + pub fn is_equi(&self) -> bool { + matches!( + self, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) + } + + pub fn is_semi_anti(&self) -> bool { + #[cfg(feature = "semi_anti_join")] + { + matches!(self, JoinType::Semi | JoinType::Anti) + } + #[cfg(not(feature = "semi_anti_join"))] + { + false + } + } + + pub fn is_asof(&self) -> bool { + #[cfg(feature = "asof_join")] + { + matches!(self, JoinType::AsOf(_)) + } + #[cfg(not(feature = "asof_join"))] + { + false + } + } + + pub fn is_cross(&self) -> bool { + matches!(self, JoinType::Cross) + } + + pub fn is_ie(&self) -> bool { + #[cfg(feature = "iejoin")] + { + matches!(self, JoinType::IEJoin) + } + #[cfg(not(feature = "iejoin"))] + { + false + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum JoinValidation { + /// No unique checks + #[default] + ManyToMany, + /// Check if join keys are unique in right dataset. + ManyToOne, + /// Check if join keys are unique in left dataset. + OneToMany, + /// Check if join keys are unique in both left and right datasets + OneToOne, +} + +impl JoinValidation { + pub fn needs_checks(&self) -> bool { + !matches!(self, JoinValidation::ManyToMany) + } + + fn swap(self, swap: bool) -> Self { + use JoinValidation::*; + if swap { + match self { + ManyToMany => ManyToMany, + ManyToOne => OneToMany, + OneToMany => ManyToOne, + OneToOne => OneToOne, + } + } else { + self + } + } + + pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> { + if !self.needs_checks() { + return Ok(()); + } + polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left), + ComputeError: "{self} validation on a {join_type} join is not supported"); + Ok(()) + } + + pub(super) fn validate_probe( + &self, + s_left: &Series, + s_right: &Series, + build_shortest_table: bool, + nulls_equal: bool, + ) -> PolarsResult<()> { + // In default, probe is the left series. + // + // In inner join and outer join, the shortest relation will be used to create a hash table. + // In left join, always use the right side to create. + // + // If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe. + // If left == right, swap too. (apply the same logic as `det_hash_prone_order`) + let should_swap = build_shortest_table && s_left.len() <= s_right.len(); + let probe = if should_swap { s_right } else { s_left }; + + use JoinValidation::*; + let valid = match self.swap(should_swap) { + // Only check the `build` side. + // The other side use `validate_build` to check + ManyToMany | ManyToOne => true, + OneToMany | OneToOne => { + if !nulls_equal && probe.null_count() > 0 { + probe.n_unique()? - 1 == probe.len() - probe.null_count() + } else { + probe.n_unique()? == probe.len() + } + }, + }; + polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self); + Ok(()) + } + + pub(super) fn validate_build( + &self, + build_size: usize, + expected_size: usize, + swapped: bool, + ) -> PolarsResult<()> { + use JoinValidation::*; + + // In default, build is in rhs. + let valid = match self.swap(swapped) { + // Only check the `build` side. + // The other side use `validate_prone` to check + ManyToMany | OneToMany => true, + ManyToOne | OneToOne => build_size == expected_size, + }; + polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self); + Ok(()) + } +} + +impl Display for JoinValidation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + JoinValidation::ManyToMany => "m:m", + JoinValidation::ManyToOne => "m:1", + JoinValidation::OneToMany => "1:m", + JoinValidation::OneToOne => "1:1", + }; + write!(f, "{s}") + } +} + +impl Debug for JoinValidation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "JoinValidation: {self}") + } +} diff --git a/crates/polars-ops/src/frame/join/asof/default.rs b/crates/polars-ops/src/frame/join/asof/default.rs new file mode 100644 index 000000000000..ba2e455b51f8 --- /dev/null +++ b/crates/polars-ops/src/frame/join/asof/default.rs @@ -0,0 +1,237 @@ +use arrow::array::Array; +use arrow::bitmap::Bitmap; +use num_traits::Zero; +use polars_core::prelude::*; +use polars_utils::abs_diff::AbsDiff; + +use super::{ + AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy, +}; + +fn join_asof_impl<'a, T, S, F>( + left: &'a T::Array, + right: &'a T::Array, + mut filter: F, + allow_eq: bool, +) -> IdxCa +where + T: PolarsDataType, + S: AsofJoinState>, + F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + if left.len() == left.null_count() || right.len() == right.null_count() { + return IdxCa::full_null(PlSmallStr::EMPTY, left.len()); + } + + let mut out = vec![0; left.len()]; + let mut mask = vec![0; left.len().div_ceil(8)]; + let mut state = S::new(allow_eq); + + if left.null_count() == 0 && right.null_count() == 0 { + for (i, val_l) in left.values_iter().enumerate() { + if let Some(r_idx) = state.next( + &val_l, + // SAFETY: next() only calls with indices < right.len(). + |j| Some(unsafe { right.value_unchecked(j as usize) }), + right.len() as IdxSize, + ) { + // SAFETY: r_idx is non-null and valid. + unsafe { + let val_r = right.value_unchecked(r_idx as usize); + *out.get_unchecked_mut(i) = r_idx; + *mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8); + } + } + } + } else { + for (i, opt_val_l) in left.iter().enumerate() { + if let Some(val_l) = opt_val_l { + if let Some(r_idx) = state.next( + &val_l, + // SAFETY: next() only calls with indices < right.len(). + |j| unsafe { right.get_unchecked(j as usize) }, + right.len() as IdxSize, + ) { + // SAFETY: r_idx is non-null and valid. + unsafe { + let val_r = right.value_unchecked(r_idx as usize); + *out.get_unchecked_mut(i) = r_idx; + *mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8); + } + } + } + } + } + + let bitmap = Bitmap::try_new(mask, out.len()).unwrap(); + IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap)) +} + +fn join_asof_forward<'a, T, F>( + left: &'a T::Array, + right: &'a T::Array, + filter: F, + allow_eq: bool, +) -> IdxCa +where + T: PolarsDataType, + T::Physical<'a>: PartialOrd, + F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq) +} + +fn join_asof_backward<'a, T, F>( + left: &'a T::Array, + right: &'a T::Array, + filter: F, + allow_eq: bool, +) -> IdxCa +where + T: PolarsDataType, + T::Physical<'a>: PartialOrd, + F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq) +} + +fn join_asof_nearest<'a, T, F>( + left: &'a T::Array, + right: &'a T::Array, + filter: F, + allow_eq: bool, +) -> IdxCa +where + T: PolarsDataType, + T::Physical<'a>: NumericNative, + F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq) +} + +pub(crate) fn join_asof_numeric( + input_ca: &ChunkedArray, + other: &Series, + strategy: AsofStrategy, + tolerance: Option>, + allow_eq: bool, +) -> PolarsResult { + let other = input_ca.unpack_series_matching_type(other)?; + + let ca = input_ca.rechunk(); + let other = other.rechunk(); + let left = ca.downcast_as_array(); + let right = other.downcast_as_array(); + + let out = if let Some(t) = tolerance { + let native_tolerance = t.try_extract::()?; + let abs_tolerance = native_tolerance.abs_diff(T::Native::zero()); + let filter = |l: T::Native, r: T::Native| l.abs_diff(r) <= abs_tolerance; + match strategy { + AsofStrategy::Forward => join_asof_forward::(left, right, filter, allow_eq), + AsofStrategy::Backward => join_asof_backward::(left, right, filter, allow_eq), + AsofStrategy::Nearest => join_asof_nearest::(left, right, filter, allow_eq), + } + } else { + let filter = |_l: T::Native, _r: T::Native| true; + match strategy { + AsofStrategy::Forward => join_asof_forward::(left, right, filter, allow_eq), + AsofStrategy::Backward => join_asof_backward::(left, right, filter, allow_eq), + AsofStrategy::Nearest => join_asof_nearest::(left, right, filter, allow_eq), + } + }; + Ok(out) +} + +pub(crate) fn join_asof( + input_ca: &ChunkedArray, + other: &Series, + strategy: AsofStrategy, + allow_eq: bool, +) -> PolarsResult +where + T: PolarsDataType, + for<'a> T::Physical<'a>: PartialOrd, +{ + let other = input_ca.unpack_series_matching_type(other)?; + + let ca = input_ca.rechunk(); + let other = other.rechunk(); + let left = ca.downcast_iter().next().unwrap(); + let right = other.downcast_iter().next().unwrap(); + + let filter = |_l: T::Physical<'_>, _r: T::Physical<'_>| true; + Ok(match strategy { + AsofStrategy::Forward => { + join_asof_impl::(left, right, filter, allow_eq) + }, + AsofStrategy::Backward => { + join_asof_impl::(left, right, filter, allow_eq) + }, + AsofStrategy::Nearest => unimplemented!(), + }) +} + +#[cfg(test)] +mod test { + use arrow::array::PrimitiveArray; + + use super::*; + + #[test] + fn test_asof_backward() { + let a = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]); + let b = PrimitiveArray::from_slice([1, 2, 3, 3]); + + let tuples = join_asof_backward::(&a, &b, |_, _| true, true); + assert_eq!(tuples.len(), a.len()); + assert_eq!( + tuples.to_vec(), + &[None, Some(1), Some(3), Some(3), Some(3), Some(3)] + ); + + let b = PrimitiveArray::from_slice([1, 2, 4, 5]); + let tuples = join_asof_backward::(&a, &b, |_, _| true, true); + assert_eq!( + tuples.to_vec(), + &[None, Some(1), Some(1), Some(1), Some(1), Some(2)] + ); + + let a = PrimitiveArray::from_slice([2, 4, 4, 4]); + let b = PrimitiveArray::from_slice([1, 2, 3, 3]); + let tuples = join_asof_backward::(&a, &b, |_, _| true, true); + assert_eq!(tuples.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]); + } + + #[test] + fn test_asof_backward_tolerance() { + let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]); + let b = PrimitiveArray::from_slice([10, 20, 30, 30]); + let tuples = join_asof_backward::(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true); + assert_eq!( + tuples.to_vec(), + &[None, Some(1), None, Some(3), Some(3), None] + ); + } + + #[test] + fn test_asof_forward_tolerance() { + let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]); + let b = PrimitiveArray::from_slice([10, 20, 33, 55]); + let tuples = join_asof_forward::(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true); + assert_eq!( + tuples.to_vec(), + &[None, Some(1), None, Some(2), Some(2), None, Some(3)] + ); + } + + #[test] + fn test_asof_forward() { + let a = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]); + let b = PrimitiveArray::from_slice([1, 2, 4, 5]); + + let tuples = join_asof_forward::(&a, &b, |_, _| true, true); + assert_eq!(tuples.len(), a.len()); + assert_eq!(tuples.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]); + } +} diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs new file mode 100644 index 000000000000..9d2a3a7ab4f2 --- /dev/null +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -0,0 +1,855 @@ +use std::hash::Hash; + +use num_traits::Zero; +use polars_core::hashing::_HASHMAP_INIT_SIZE; +use polars_core::prelude::*; +use polars_core::series::BitRepr; +use polars_core::utils::flatten::flatten_nullable; +use polars_core::utils::split_and_flatten; +use polars_core::{POOL, with_match_physical_float_polars_type}; +use polars_utils::abs_diff::AbsDiff; +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; +use rayon::prelude::*; + +use super::*; +use crate::frame::join::{prepare_binary, prepare_keys_multiple}; + +fn compute_len_offsets>(iter: I) -> Vec { + let mut cumlen = 0; + iter.into_iter() + .map(|l| { + let offset = cumlen; + cumlen += l; + offset + }) + .collect() +} + +#[inline(always)] +fn materialize_nullable(idx: Option) -> NullableIdxSize { + match idx { + Some(t) => NullableIdxSize::from(t), + None => NullableIdxSize::null(), + } +} + +fn asof_in_group<'a, T, A, F>( + left_val: T::Physical<'a>, + right_val_arr: &'a T::Array, + right_grp_idxs: &[IdxSize], + group_states: &mut PlHashMap, + filter: F, + allow_eq: bool, +) -> Option +where + T: PolarsDataType, + A: AsofJoinState>, + F: Fn(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + // We use the index of the first element in a group as an identifier to + // associate with the group state. + let id = right_grp_idxs.first()?; + let grp_state = group_states.entry(*id).or_insert_with(|| A::new(allow_eq)); + + unsafe { + let r_grp_idx = grp_state.next( + &left_val, + |i| { + // SAFETY: the group indices are valid, and next() only calls with + // i < right_grp_idxs.len(). + right_val_arr.get_unchecked(*right_grp_idxs.get_unchecked(i as usize) as usize) + }, + right_grp_idxs.len() as IdxSize, + )?; + + // SAFETY: r_grp_idx is valid, as is r_idx (which must be non-null) if + // we get here. + let r_idx = *right_grp_idxs.get_unchecked(r_grp_idx as usize); + let right_val = right_val_arr.value_unchecked(r_idx as usize); + filter(left_val, right_val).then_some(r_idx) + } +} + +fn asof_join_by_numeric( + by_left: &ChunkedArray, + by_right: &ChunkedArray, + left_asof: &ChunkedArray, + right_asof: &ChunkedArray, + filter: F, + allow_eq: bool, +) -> PolarsResult +where + T: PolarsDataType, + S: PolarsNumericType, + S::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, + A: for<'a> AsofJoinState>, + F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk()); + let left_val_arr = left_asof.downcast_as_array(); + let right_val_arr = right_asof.downcast_as_array(); + + let n_threads = POOL.current_num_threads(); + // `strict` is false so that we always flatten. Even if there are more chunks than threads. + let split_by_left = split_and_flatten(by_left, n_threads); + let split_by_right = split_and_flatten(by_right, n_threads); + let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len())); + + // TODO: handle nulls more efficiently. Right now we just join on the value + // ignoring the validity mask, and ignore the nulls later. + let right_slices = split_by_right + .iter() + .map(|ca| { + assert_eq!(ca.chunks().len(), 1); + ca.downcast_iter().next().unwrap().values_iter().copied() + }) + .collect(); + let hash_tbls = build_tables(right_slices, false); + let n_tables = hash_tbls.len(); + + // Now we probe the right hand side for each left hand side. + let out = split_by_left + .into_par_iter() + .zip(offsets) + .map(|(by_left, offset)| { + let mut results = Vec::with_capacity(by_left.len()); + let mut group_states: PlHashMap = + PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); + + assert_eq!(by_left.chunks().len(), 1); + let by_left_chunk = by_left.downcast_iter().next().unwrap(); + for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() { + let Some(by_left_k) = opt_by_left_k else { + results.push(NullableIdxSize::null()); + continue; + }; + let by_left_k = by_left_k.to_total_ord(); + let idx_left = (rel_idx_left + offset) as IdxSize; + let Some(left_val) = left_val_arr.get(idx_left as usize) else { + results.push(NullableIdxSize::null()); + continue; + }; + + let group_probe_table = unsafe { + hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) + }; + let Some(right_grp_idxs) = group_probe_table.get(&by_left_k) else { + results.push(NullableIdxSize::null()); + continue; + }; + let id = asof_in_group::( + left_val, + right_val_arr, + right_grp_idxs.as_slice(), + &mut group_states, + &filter, + allow_eq, + ); + results.push(materialize_nullable(id)); + } + results + }); + + let bufs = POOL.install(|| out.collect::>()); + Ok(flatten_nullable(&bufs)) +} + +fn asof_join_by_binary( + by_left: &ChunkedArray, + by_right: &ChunkedArray, + left_asof: &ChunkedArray, + right_asof: &ChunkedArray, + filter: F, + allow_eq: bool, +) -> IdxArr +where + B: PolarsDataType, + for<'b> ::ValueT<'b>: AsRef<[u8]>, + T: PolarsDataType, + A: for<'a> AsofJoinState>, + F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk()); + let left_val_arr = left_asof.downcast_as_array(); + let right_val_arr = right_asof.downcast_as_array(); + + let (prep_by_left, prep_by_right, _, _) = prepare_binary::(by_left, by_right, false); + let offsets = compute_len_offsets(prep_by_left.iter().map(|s| s.len())); + let hash_tbls = build_tables(prep_by_right, false); + let n_tables = hash_tbls.len(); + + // Now we probe the right hand side for each left hand side. + let iter = prep_by_left + .into_par_iter() + .zip(offsets) + .map(|(by_left, offset)| { + let mut results = Vec::with_capacity(by_left.len()); + let mut group_states: PlHashMap<_, A> = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); + + for (rel_idx_left, by_left_k) in by_left.iter().enumerate() { + let idx_left = (rel_idx_left + offset) as IdxSize; + let Some(left_val) = left_val_arr.get(idx_left as usize) else { + results.push(NullableIdxSize::null()); + continue; + }; + + let group_probe_table = unsafe { + hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) + }; + let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else { + results.push(NullableIdxSize::null()); + continue; + }; + let id = asof_in_group::( + left_val, + right_val_arr, + right_grp_idxs.as_slice(), + &mut group_states, + &filter, + allow_eq, + ); + + results.push(materialize_nullable(id)); + } + results + }); + let bufs = POOL.install(|| iter.collect::>()); + flatten_nullable(&bufs) +} + +#[allow(clippy::too_many_arguments)] +fn dispatch_join_by_type( + left_asof: &ChunkedArray, + right_asof: &ChunkedArray, + left_by: &mut DataFrame, + right_by: &mut DataFrame, + filter: F, + allow_eq: bool, +) -> PolarsResult +where + T: PolarsDataType, + A: for<'a> AsofJoinState>, + F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, +{ + let out = if left_by.width() == 1 { + let left_by_s = left_by.get_columns()[0].to_physical_repr(); + let right_by_s = right_by.get_columns()[0].to_physical_repr(); + let left_dtype = left_by_s.dtype(); + let right_dtype = right_by_s.dtype(); + polars_ensure!(left_dtype == right_dtype, + ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{left_dtype}` and `{right_dtype}`", + ); + match left_dtype { + DataType::String => { + let left_by = &left_by_s.str().unwrap().as_binary(); + let right_by = right_by_s.str().unwrap().as_binary(); + asof_join_by_binary::( + left_by, &right_by, left_asof, right_asof, filter, allow_eq, + ) + }, + DataType::Binary => { + let left_by = &left_by_s.binary().unwrap(); + let right_by = right_by_s.binary().unwrap(); + asof_join_by_binary::( + left_by, right_by, left_asof, right_asof, filter, allow_eq, + ) + }, + x if x.is_float() => { + with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| { + let left_by: &ChunkedArray<$T> = left_by_s.as_materialized_series().as_ref().as_ref().as_ref(); + let right_by: &ChunkedArray<$T> = right_by_s.as_materialized_series().as_ref().as_ref().as_ref(); + asof_join_by_numeric::( + left_by, right_by, left_asof, right_asof, filter, allow_eq + )? + }) + }, + _ => { + let left_by = left_by_s.bit_repr(); + let right_by = right_by_s.bit_repr(); + + let (Some(left_by), Some(right_by)) = (left_by, right_by) else { + polars_bail!(nyi = "Dispatch join for {left_dtype} and {right_dtype}"); + }; + + use BitRepr as B; + match (left_by, right_by) { + (B::Small(left_by), B::Small(right_by)) => { + asof_join_by_numeric::( + &left_by, &right_by, left_asof, right_asof, filter, allow_eq, + )? + }, + (B::Large(left_by), B::Large(right_by)) => { + asof_join_by_numeric::( + &left_by, &right_by, left_asof, right_asof, filter, allow_eq, + )? + }, + // We have already asserted that the datatypes are the same. + _ => unreachable!(), + } + }, + } + } else { + for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) { + polars_ensure!(lhs.dtype() == rhs.dtype(), + ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype() + ); + #[cfg(feature = "dtype-categorical")] + _check_categorical_src(lhs.dtype(), rhs.dtype())?; + } + + // TODO: @scalar-opt. + let left_by_series: Vec<_> = left_by.materialized_column_iter().cloned().collect(); + let right_by_series: Vec<_> = right_by.materialized_column_iter().cloned().collect(); + let lhs_keys = prepare_keys_multiple(&left_by_series, false)?; + let rhs_keys = prepare_keys_multiple(&right_by_series, false)?; + asof_join_by_binary::( + &lhs_keys, &rhs_keys, left_asof, right_asof, filter, allow_eq, + ) + }; + Ok(out) +} + +#[allow(clippy::too_many_arguments)] +fn dispatch_join_strategy( + left_asof: &ChunkedArray, + right_asof: &Series, + left_by: &mut DataFrame, + right_by: &mut DataFrame, + strategy: AsofStrategy, + allow_eq: bool, +) -> PolarsResult +where + for<'a> T::Physical<'a>: PartialOrd, +{ + let right_asof = left_asof.unpack_series_matching_type(right_asof)?; + + let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true; + match strategy { + AsofStrategy::Backward => dispatch_join_by_type::( + left_asof, right_asof, left_by, right_by, filter, allow_eq, + ), + AsofStrategy::Forward => dispatch_join_by_type::( + left_asof, right_asof, left_by, right_by, filter, allow_eq, + ), + AsofStrategy::Nearest => unimplemented!(), + } +} + +#[allow(clippy::too_many_arguments)] +fn dispatch_join_strategy_numeric( + left_asof: &ChunkedArray, + right_asof: &Series, + left_by: &mut DataFrame, + right_by: &mut DataFrame, + strategy: AsofStrategy, + tolerance: Option>, + allow_eq: bool, +) -> PolarsResult { + let right_ca = left_asof.unpack_series_matching_type(right_asof)?; + + if let Some(tol) = tolerance { + let native_tolerance: T::Native = tol.try_extract()?; + let abs_tolerance = native_tolerance.abs_diff(T::Native::zero()); + let filter = |a: T::Native, b: T::Native| a.abs_diff(b) <= abs_tolerance; + match strategy { + AsofStrategy::Backward => dispatch_join_by_type::( + left_asof, right_ca, left_by, right_by, filter, allow_eq, + ), + AsofStrategy::Forward => dispatch_join_by_type::( + left_asof, right_ca, left_by, right_by, filter, allow_eq, + ), + AsofStrategy::Nearest => dispatch_join_by_type::( + left_asof, right_ca, left_by, right_by, filter, allow_eq, + ), + } + } else { + let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true; + match strategy { + AsofStrategy::Backward => dispatch_join_by_type::( + left_asof, right_ca, left_by, right_by, filter, allow_eq, + ), + AsofStrategy::Forward => dispatch_join_by_type::( + left_asof, right_ca, left_by, right_by, filter, allow_eq, + ), + AsofStrategy::Nearest => dispatch_join_by_type::( + left_asof, right_ca, left_by, right_by, filter, allow_eq, + ), + } + } +} + +#[allow(clippy::too_many_arguments)] +fn dispatch_join_type( + left_asof: &Series, + right_asof: &Series, + left_by: &mut DataFrame, + right_by: &mut DataFrame, + strategy: AsofStrategy, + tolerance: Option>, + allow_eq: bool, +) -> PolarsResult { + match left_asof.dtype() { + DataType::Int64 => { + let ca = left_asof.i64().unwrap(); + dispatch_join_strategy_numeric( + ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, + ) + }, + DataType::Int32 => { + let ca = left_asof.i32().unwrap(); + dispatch_join_strategy_numeric( + ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, + ) + }, + DataType::UInt64 => { + let ca = left_asof.u64().unwrap(); + dispatch_join_strategy_numeric( + ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, + ) + }, + DataType::UInt32 => { + let ca = left_asof.u32().unwrap(); + dispatch_join_strategy_numeric( + ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, + ) + }, + DataType::Float32 => { + let ca = left_asof.f32().unwrap(); + dispatch_join_strategy_numeric( + ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, + ) + }, + DataType::Float64 => { + let ca = left_asof.f64().unwrap(); + dispatch_join_strategy_numeric( + ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq, + ) + }, + DataType::Boolean => { + let ca = left_asof.bool().unwrap(); + dispatch_join_strategy::( + ca, right_asof, left_by, right_by, strategy, allow_eq, + ) + }, + DataType::Binary => { + let ca = left_asof.binary().unwrap(); + dispatch_join_strategy::( + ca, right_asof, left_by, right_by, strategy, allow_eq, + ) + }, + DataType::String => { + let ca = left_asof.str().unwrap(); + let right_binary = right_asof.cast(&DataType::Binary).unwrap(); + dispatch_join_strategy::( + &ca.as_binary(), + &right_binary, + left_by, + right_by, + strategy, + allow_eq, + ) + }, + _ => { + let left_asof = left_asof.cast(&DataType::Int32).unwrap(); + let right_asof = right_asof.cast(&DataType::Int32).unwrap(); + let ca = left_asof.i32().unwrap(); + dispatch_join_strategy_numeric( + ca, + &right_asof, + left_by, + right_by, + strategy, + tolerance, + allow_eq, + ) + }, + } +} + +pub trait AsofJoinBy: IntoDf { + #[allow(clippy::too_many_arguments)] + #[doc(hidden)] + fn _join_asof_by( + &self, + other: &DataFrame, + left_on: &Series, + right_on: &Series, + left_by: Vec, + right_by: Vec, + strategy: AsofStrategy, + tolerance: Option>, + suffix: Option, + slice: Option<(i64, usize)>, + coalesce: bool, + allow_eq: bool, + check_sortedness: bool, + ) -> PolarsResult { + let (self_sliced_slot, other_sliced_slot, left_slice_s, right_slice_s); // Keeps temporaries alive. + let (self_df, other_df, left_key, right_key); + if let Some((offset, len)) = slice { + self_sliced_slot = self.to_df().slice(offset, len); + other_sliced_slot = other.slice(offset, len); + left_slice_s = left_on.slice(offset, len); + right_slice_s = right_on.slice(offset, len); + left_key = &left_slice_s; + right_key = &right_slice_s; + self_df = &self_sliced_slot; + other_df = &other_sliced_slot; + } else { + self_df = self.to_df(); + other_df = other; + left_key = left_on; + right_key = right_on; + } + + let left_asof = left_key.to_physical_repr(); + let right_asof = right_key.to_physical_repr(); + let right_asof_name = right_asof.name(); + let left_asof_name = left_asof.name(); + check_asof_columns( + &left_asof, + &right_asof, + tolerance.is_some(), + check_sortedness, + !(left_by.is_empty() && right_by.is_empty()), + )?; + + let mut left_by = self_df.select(left_by)?; + let mut right_by = other_df.select(right_by)?; + + unsafe { + for (l, r) in left_by + .get_columns_mut() + .iter_mut() + .zip(right_by.get_columns_mut().iter_mut()) + { + #[cfg(feature = "dtype-categorical")] + _check_categorical_src(l.dtype(), r.dtype())?; + *l = l.to_physical_repr(); + *r = r.to_physical_repr(); + } + } + + let right_join_tuples = dispatch_join_type( + &left_asof, + &right_asof, + &mut left_by, + &mut right_by, + strategy, + tolerance, + allow_eq, + )?; + + let mut drop_these = right_by.get_column_names(); + if coalesce && left_asof_name == right_asof_name { + drop_these.push(right_asof_name); + } + + let cols = other_df + .get_columns() + .iter() + .filter(|s| !drop_these.contains(&s.name())) + .cloned() + .collect(); + let proj_other_df = unsafe { DataFrame::new_no_checks(other_df.height(), cols) }; + + let left = self_df.clone(); + + // SAFETY: join tuples are in bounds. + let right_df = unsafe { + proj_other_df.take_unchecked(&IdxCa::with_chunk(PlSmallStr::EMPTY, right_join_tuples)) + }; + + _finish_join(left, right_df, suffix) + } + + /// This is similar to a left-join except that we match on nearest key + /// rather than equal keys. The keys must be sorted to perform an asof join. + /// This is a special implementation of an asof join that searches for the + /// nearest keys within a subgroup set by `by`. + #[allow(clippy::too_many_arguments)] + fn join_asof_by( + &self, + other: &DataFrame, + left_on: &str, + right_on: &str, + left_by: I, + right_by: I, + strategy: AsofStrategy, + tolerance: Option>, + allow_eq: bool, + check_sortedness: bool, + ) -> PolarsResult + where + I: IntoIterator, + S: AsRef, + { + let self_df = self.to_df(); + let left_by = left_by.into_iter().map(|s| s.as_ref().into()).collect(); + let right_by = right_by.into_iter().map(|s| s.as_ref().into()).collect(); + let left_key = self_df.column(left_on)?.as_materialized_series(); + let right_key = other.column(right_on)?.as_materialized_series(); + self_df._join_asof_by( + other, + left_key, + right_key, + left_by, + right_by, + strategy, + tolerance, + None, + None, + true, + allow_eq, + check_sortedness, + ) + } +} + +impl AsofJoinBy for DataFrame {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_asof_by() -> PolarsResult<()> { + let a = df![ + "a" => [-1, 2, 3, 3, 3, 4], + "b" => ["a", "b", "c", "d", "e", "f"] + ]?; + + let b = df![ + "a" => [1, 2, 3, 3], + "b" => ["a", "b", "c", "d"], + "right_vals" => [1, 2, 3, 4] + ]?; + + let out = a.join_asof_by( + &b, + "a", + "a", + ["b"], + ["b"], + AsofStrategy::Backward, + None, + true, + true, + )?; + assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]); + let out = out.column("right_vals").unwrap(); + let out = out.i32().unwrap(); + assert_eq!( + Vec::from(out), + &[None, Some(2), Some(3), Some(4), None, None] + ); + Ok(()) + } + + #[test] + fn test_asof_by2() -> PolarsResult<()> { + let trades = df![ + "time" => [23i64, 38, 48, 48, 48], + "ticker" => ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "groups_numeric" => [1, 1, 2, 2, 3], + "bid" => [51.95, 51.95, 720.77, 720.92, 98.0] + ]?; + + let quotes = df![ + "time" => [23i64, + 23, + 30, + 41, + 48, + 49, + 72, + 75], + "ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], + "groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1], + "bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01] + + ]?; + + let out = trades.join_asof_by( + "es, + "time", + "time", + ["ticker"], + ["ticker"], + AsofStrategy::Backward, + None, + true, + true, + )?; + let a = out.column("bid_right").unwrap(); + let a = a.f64().unwrap(); + let expected = &[Some(51.95), Some(51.97), Some(720.5), Some(720.5), None]; + + assert_eq!(Vec::from(a), expected); + + let out = trades.join_asof_by( + "es, + "time", + "time", + ["groups_numeric"], + ["groups_numeric"], + AsofStrategy::Backward, + None, + true, + true, + )?; + let a = out.column("bid_right").unwrap(); + let a = a.f64().unwrap(); + + assert_eq!(Vec::from(a), expected); + + Ok(()) + } + + #[test] + fn test_asof_by3() -> PolarsResult<()> { + let a = df![ + "a" => [ -1, 2, 2, 3, 3, 3, 4], + "b" => ["a", "a", "b", "c", "d", "e", "f"] + ]?; + + let b = df![ + "a" => [ 1, 3, 2, 3, 2], + "b" => ["a", "a", "b", "c", "d"], + "right_vals" => [ 1, 3, 2, 3, 4] + ]?; + + let out = a.join_asof_by( + &b, + "a", + "a", + ["b"], + ["b"], + AsofStrategy::Forward, + None, + true, + true, + )?; + assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]); + let out = out.column("right_vals").unwrap(); + let out = out.i32().unwrap(); + assert_eq!( + Vec::from(out), + &[Some(1), Some(3), Some(2), Some(3), None, None, None] + ); + + let out = a.join_asof_by( + &b, + "a", + "a", + ["b"], + ["b"], + AsofStrategy::Forward, + Some(AnyValue::Int32(1)), + true, + true, + )?; + assert_eq!(out.get_column_names(), &["a", "b", "right_vals"]); + let out = out.column("right_vals").unwrap(); + let out = out.i32().unwrap(); + assert_eq!( + Vec::from(out), + &[None, Some(3), Some(2), Some(3), None, None, None] + ); + + Ok(()) + } + + #[test] + fn test_asof_by4() -> PolarsResult<()> { + let trades = df![ + "time" => [23i64, 38, 48, 48, 48], + "ticker" => ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "groups_numeric" => [1, 1, 2, 2, 3], + "bid" => [51.95, 51.95, 720.77, 720.92, 98.0] + ]?; + + let quotes = df![ + "time" => [23i64, 23, 30, 41, 48, 49, 72, 75], + "ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], + "bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01], + "groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1], + + ]?; + /* + trades: + shape: (5, 4) + ┌──────┬────────┬────────────────┬────────┐ + │ time ┆ ticker ┆ groups_numeric ┆ bid │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ i32 ┆ f64 │ + ╞══════╪════════╪════════════════╪════════╡ + │ 23 ┆ MSFT ┆ 1 ┆ 51.95 │ + │ 38 ┆ MSFT ┆ 1 ┆ 51.95 │ + │ 48 ┆ GOOG ┆ 2 ┆ 720.77 │ + │ 48 ┆ GOOG ┆ 2 ┆ 720.92 │ + │ 48 ┆ AAPL ┆ 3 ┆ 98.0 │ + └──────┴────────┴────────────────┴────────┘ + quotes: + shape: (8, 4) + ┌──────┬────────┬───────┬────────────────┐ + │ time ┆ ticker ┆ bid ┆ groups_numeric │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ f64 ┆ i32 │ + ╞══════╪════════╪═══════╪════════════════╡ + │ 23 ┆ GOOG ┆ 720.5 ┆ 2 │ + │ 23 ┆ MSFT ┆ 51.95 ┆ 1 │ + │ 30 ┆ MSFT ┆ 51.97 ┆ 1 │ + │ 41 ┆ MSFT ┆ 51.99 ┆ 1 │ + │ 48 ┆ GOOG ┆ 720.5 ┆ 2 │ + │ 49 ┆ AAPL ┆ 97.99 ┆ 3 │ + │ 72 ┆ GOOG ┆ 720.5 ┆ 2 │ + │ 75 ┆ MSFT ┆ 52.01 ┆ 1 │ + └──────┴────────┴───────┴────────────────┘ + */ + + let out = trades.join_asof_by( + "es, + "time", + "time", + ["ticker"], + ["ticker"], + AsofStrategy::Forward, + None, + true, + true, + )?; + let a = out.column("bid_right").unwrap(); + let a = a.f64().unwrap(); + let expected = &[ + Some(51.95), + Some(51.99), + Some(720.5), + Some(720.5), + Some(97.99), + ]; + + assert_eq!(Vec::from(a), expected); + + let out = trades.join_asof_by( + "es, + "time", + "time", + ["groups_numeric"], + ["groups_numeric"], + AsofStrategy::Forward, + None, + true, + true, + )?; + let a = out.column("bid_right").unwrap(); + let a = a.f64().unwrap(); + + assert_eq!(Vec::from(a), expected); + + Ok(()) + } +} diff --git a/crates/polars-ops/src/frame/join/asof/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs new file mode 100644 index 000000000000..1f433b7b10dc --- /dev/null +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -0,0 +1,355 @@ +mod default; +mod groups; +use std::borrow::Cow; +use std::cmp::Ordering; + +use default::*; +pub use groups::AsofJoinBy; +use polars_core::prelude::*; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "dtype-categorical")] +use super::_check_categorical_src; +use super::{_finish_join, build_tables}; +use crate::frame::IntoDf; +use crate::series::SeriesMethods; + +#[inline] +fn ge_allow_eq(l: &T, r: &T, allow_eq: bool) -> bool { + match l.partial_cmp(r) { + Some(Ordering::Equal) => allow_eq, + Some(Ordering::Greater) => true, + _ => false, + } +} + +#[inline] +fn lt_allow_eq(l: &T, r: &T, allow_eq: bool) -> bool { + match l.partial_cmp(r) { + Some(Ordering::Equal) => allow_eq, + Some(Ordering::Less) => true, + _ => false, + } +} + +trait AsofJoinState { + fn next Option>( + &mut self, + left_val: &T, + right: F, + n_right: IdxSize, + ) -> Option; + + fn new(allow_eq: bool) -> Self; +} + +struct AsofJoinForwardState { + scan_offset: IdxSize, + allow_eq: bool, +} + +impl AsofJoinState for AsofJoinForwardState { + fn new(allow_eq: bool) -> Self { + AsofJoinForwardState { + scan_offset: Default::default(), + allow_eq, + } + } + #[inline] + fn next Option>( + &mut self, + left_val: &T, + mut right: F, + n_right: IdxSize, + ) -> Option { + while (self.scan_offset) < n_right { + if let Some(right_val) = right(self.scan_offset) { + if ge_allow_eq(&right_val, left_val, self.allow_eq) { + return Some(self.scan_offset); + } + } + self.scan_offset += 1; + } + None + } +} + +struct AsofJoinBackwardState { + // best_bound is the greatest right index <= left_val. + best_bound: Option, + scan_offset: IdxSize, + allow_eq: bool, +} + +impl AsofJoinState for AsofJoinBackwardState { + fn new(allow_eq: bool) -> Self { + AsofJoinBackwardState { + scan_offset: Default::default(), + best_bound: Default::default(), + allow_eq, + } + } + #[inline] + fn next Option>( + &mut self, + left_val: &T, + mut right: F, + n_right: IdxSize, + ) -> Option { + while self.scan_offset < n_right { + if let Some(right_val) = right(self.scan_offset) { + if lt_allow_eq(&right_val, left_val, self.allow_eq) { + self.best_bound = Some(self.scan_offset); + } else { + break; + } + } + self.scan_offset += 1; + } + self.best_bound + } +} + +#[derive(Default)] +struct AsofJoinNearestState { + // best_bound is the nearest value to left_val, with ties broken towards the last element. + best_bound: Option, + scan_offset: IdxSize, + allow_eq: bool, +} + +impl AsofJoinState for AsofJoinNearestState { + fn new(allow_eq: bool) -> Self { + AsofJoinNearestState { + scan_offset: Default::default(), + best_bound: Default::default(), + allow_eq, + } + } + #[inline] + fn next Option>( + &mut self, + left_val: &T, + mut right: F, + n_right: IdxSize, + ) -> Option { + // Skipping ahead to the first value greater than left_val. This is + // cheaper than computing differences. + while self.scan_offset < n_right { + if let Some(scan_right_val) = right(self.scan_offset) { + if lt_allow_eq(&scan_right_val, left_val, self.allow_eq) { + self.best_bound = Some(self.scan_offset); + } else { + // Now we must compute a difference to see if scan_right_val + // is closer than our current best bound. + let scan_is_better = if let Some(best_idx) = self.best_bound { + let best_right_val = unsafe { right(best_idx).unwrap_unchecked() }; + let best_diff = left_val.abs_diff(best_right_val); + let scan_diff = left_val.abs_diff(scan_right_val); + + lt_allow_eq(&scan_diff, &best_diff, self.allow_eq) + } else { + true + }; + + if scan_is_better { + self.best_bound = Some(self.scan_offset); + self.scan_offset += 1; + + // It is possible there are later elements equal to our + // scan, so keep going on. + while self.scan_offset < n_right { + if let Some(next_right_val) = right(self.scan_offset) { + if next_right_val == scan_right_val && self.allow_eq { + self.best_bound = Some(self.scan_offset); + } else { + break; + } + } + + self.scan_offset += 1; + } + } + + break; + } + } + + self.scan_offset += 1; + } + + self.best_bound + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct AsOfOptions { + pub strategy: AsofStrategy, + /// A tolerance in the same unit as the asof column + pub tolerance: Option>, + /// A time duration specified as a string, for example: + /// - "5m" + /// - "2h15m" + /// - "1d6h" + pub tolerance_str: Option, + pub left_by: Option>, + pub right_by: Option>, + /// Allow equal matches + pub allow_eq: bool, + pub check_sortedness: bool, +} + +fn check_asof_columns( + a: &Series, + b: &Series, + has_tolerance: bool, + check_sortedness: bool, + by_groups_present: bool, +) -> PolarsResult<()> { + let dtype_a = a.dtype(); + let dtype_b = b.dtype(); + if has_tolerance { + polars_ensure!( + dtype_a.to_physical().is_primitive_numeric() && dtype_b.to_physical().is_primitive_numeric(), + InvalidOperation: + "asof join with tolerance is only supported on numeric/temporal keys" + ); + } else { + polars_ensure!( + dtype_a.to_physical().is_primitive() && dtype_b.to_physical().is_primitive(), + InvalidOperation: + "asof join is only supported on primitive key types" + ); + } + polars_ensure!( + dtype_a == dtype_b, + ComputeError: "mismatching key dtypes in asof-join: `{}` and `{}`", + a.dtype(), b.dtype() + ); + if check_sortedness { + if by_groups_present { + polars_warn!("Sortedness of columns cannot be checked when 'by' groups provided"); + } else { + a.ensure_sorted_arg("asof_join")?; + b.ensure_sorted_arg("asof_join")?; + } + } + Ok(()) +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum AsofStrategy { + /// selects the last row in the right DataFrame whose ‘on’ key is less than or equal to the left’s key + #[default] + Backward, + /// selects the first row in the right DataFrame whose ‘on’ key is greater than or equal to the left’s key. + Forward, + /// selects the right in the right DataFrame whose 'on' key is nearest to the left's key. + Nearest, +} + +pub trait AsofJoin: IntoDf { + #[doc(hidden)] + #[allow(clippy::too_many_arguments)] + fn _join_asof( + &self, + other: &DataFrame, + left_key: &Series, + right_key: &Series, + strategy: AsofStrategy, + tolerance: Option>, + suffix: Option, + slice: Option<(i64, usize)>, + coalesce: bool, + allow_eq: bool, + check_sortedness: bool, + ) -> PolarsResult { + let self_df = self.to_df(); + + check_asof_columns( + left_key, + right_key, + tolerance.is_some(), + check_sortedness, + false, + )?; + let left_key = left_key.to_physical_repr(); + let right_key = right_key.to_physical_repr(); + + let mut take_idx = match left_key.dtype() { + DataType::Int64 => { + let ca = left_key.i64().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + DataType::Int32 => { + let ca = left_key.i32().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + #[cfg(feature = "dtype-i128")] + DataType::Int128 => { + let ca = left_key.i128().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + DataType::UInt64 => { + let ca = left_key.u64().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + DataType::UInt32 => { + let ca = left_key.u32().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + DataType::Float32 => { + let ca = left_key.f32().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + DataType::Float64 => { + let ca = left_key.f64().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + DataType::Boolean => { + let ca = left_key.bool().unwrap(); + join_asof::(ca, &right_key, strategy, allow_eq) + }, + DataType::Binary => { + let ca = left_key.binary().unwrap(); + join_asof::(ca, &right_key, strategy, allow_eq) + }, + DataType::String => { + let ca = left_key.str().unwrap(); + let right_binary = right_key.cast(&DataType::Binary).unwrap(); + join_asof::(&ca.as_binary(), &right_binary, strategy, allow_eq) + }, + _ => { + let left_key = left_key.cast(&DataType::Int32).unwrap(); + let right_key = right_key.cast(&DataType::Int32).unwrap(); + let ca = left_key.i32().unwrap(); + join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq) + }, + }?; + try_raise_keyboard_interrupt(); + + // Drop right join column. + let other = if coalesce && left_key.name() == right_key.name() { + Cow::Owned(other.drop(right_key.name())?) + } else { + Cow::Borrowed(other) + }; + + let mut left = self_df.clone(); + if let Some((offset, len)) = slice { + left = left.slice(offset, len); + take_idx = take_idx.slice(offset, len); + } + + // SAFETY: join tuples are in bounds. + let right_df = unsafe { other.take_unchecked(&take_idx) }; + + _finish_join(left, right_df, suffix) + } +} + +impl AsofJoin for DataFrame {} diff --git a/crates/polars-ops/src/frame/join/checks.rs b/crates/polars-ops/src/frame/join/checks.rs new file mode 100644 index 000000000000..0fa179afba7b --- /dev/null +++ b/crates/polars-ops/src/frame/join/checks.rs @@ -0,0 +1,18 @@ +use super::*; + +/// If Categorical types are created without a global string cache or under +/// a different global string cache the mapping will be incorrect. +pub(crate) fn _check_categorical_src(l: &DataType, r: &DataType) -> PolarsResult<()> { + match (l, r) { + (DataType::Categorical(Some(l), _), DataType::Categorical(Some(r), _)) + | (DataType::Enum(Some(l), _), DataType::Enum(Some(r), _)) => { + polars_ensure!(l.same_src(r), string_cache_mismatch); + }, + (DataType::Categorical(_, _), DataType::Enum(_, _)) + | (DataType::Enum(_, _), DataType::Categorical(_, _)) => { + polars_bail!(ComputeError: "enum and categorical are not from the same source") + }, + _ => (), + }; + Ok(()) +} diff --git a/crates/polars-ops/src/frame/join/cross_join.rs b/crates/polars-ops/src/frame/join/cross_join.rs new file mode 100644 index 000000000000..38b819f8e313 --- /dev/null +++ b/crates/polars-ops/src/frame/join/cross_join.rs @@ -0,0 +1,184 @@ +use polars_core::utils::{ + _set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked, + concat_df_unchecked, split, +}; +use polars_utils::pl_str::PlSmallStr; + +use super::*; + +fn slice_take( + total_rows: IdxSize, + n_rows_right: IdxSize, + slice: Option<(i64, usize)>, + inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa, +) -> IdxCa { + match slice { + None => inner(0, total_rows, n_rows_right), + Some((offset, len)) => { + let (offset, len) = slice_offsets(offset, len, total_rows as usize); + inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right) + }, + } +} + +fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa { + fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa { + let mut take: NoNull = (offset..total_rows) + .map(|i| i / n_rows_right) + .collect_trusted(); + take.set_sorted_flag(IsSorted::Ascending); + take.into_inner() + } + slice_take(total_rows, n_rows_right, slice, inner) +} + +fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa { + fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa { + let take: NoNull = (offset..total_rows) + .map(|i| i % n_rows_right) + .collect_trusted(); + take.into_inner() + } + slice_take(total_rows, n_rows_right, slice, inner) +} + +pub trait CrossJoin: IntoDf { + #[doc(hidden)] + /// used by streaming + fn _cross_join_with_names( + &self, + other: &DataFrame, + names: &[PlSmallStr], + ) -> PolarsResult { + let (mut l_df, r_df) = cross_join_dfs(self.to_df(), other, None, false)?; + l_df.clear_schema(); + + unsafe { + l_df.get_columns_mut().extend_from_slice(r_df.get_columns()); + + l_df.get_columns_mut() + .iter_mut() + .zip(names) + .for_each(|(s, name)| { + if s.name() != name { + s.rename(name.clone()); + } + }); + } + Ok(l_df) + } + + /// Creates the Cartesian product from both frames, preserves the order of the left keys. + fn cross_join( + &self, + other: &DataFrame, + suffix: Option, + slice: Option<(i64, usize)>, + ) -> PolarsResult { + let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true)?; + + _finish_join(l_df, r_df, suffix) + } +} + +impl CrossJoin for DataFrame {} + +fn cross_join_dfs( + df_self: &DataFrame, + other: &DataFrame, + slice: Option<(i64, usize)>, + parallel: bool, +) -> PolarsResult<(DataFrame, DataFrame)> { + let n_rows_left = df_self.height() as IdxSize; + let n_rows_right = other.height() as IdxSize; + let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else { + polars_bail!( + ComputeError: "cross joins would produce more rows than fits into 2^32; \ + consider compiling with polars-big-idx feature, or set 'streaming'" + ); + }; + if n_rows_left == 0 || n_rows_right == 0 { + return Ok((df_self.clear(), other.clear())); + } + + // the left side has the Nth row combined with every row from right. + // So let's say we have the following no. of rows + // left: 3 + // right: 4 + // + // left take idx: 000011112222 + // right take idx: 012301230123 + + let create_left_df = || { + // SAFETY: + // take left is in bounds + unsafe { + df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel) + } + }; + + let create_right_df = || { + // concatenation of dataframes is very expensive if we need to make the series mutable + // many times, these are atomic operations + // so we choose a different strategy at > 100 rows (arbitrarily small number) + if n_rows_left > 100 || slice.is_some() { + // SAFETY: + // take right is in bounds + unsafe { + other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel) + } + } else { + let iter = (0..n_rows_left).map(|_| other); + concat_df_unchecked(iter) + } + }; + let (l_df, r_df) = if parallel { + try_raise_keyboard_interrupt(); + POOL.install(|| rayon::join(create_left_df, create_right_df)) + } else { + (create_left_df(), create_right_df()) + }; + Ok((l_df, r_df)) +} + +pub(super) fn fused_cross_filter( + left: &DataFrame, + right: &DataFrame, + suffix: Option, + cross_join_options: &CrossJoinOptions, +) -> PolarsResult { + // Because we do a cartesian product, the number of partitions is squared. + // We take the sqrt, but we don't expect every partition to produce results and work can be + // imbalanced, so we multiply the number of partitions by 2; + let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2; + let splitted_a = split(left, n_partitions); + let splitted_b = split(right, n_partitions); + + let cartesian_prod = splitted_a + .iter() + .flat_map(|l| splitted_b.iter().map(move |r| (l, r))) + .collect::>(); + + let names = _finish_join(left.clear(), right.clear(), suffix)?; + let rename_names = names.get_column_names(); + let rename_names = &rename_names[left.width()..]; + + let dfs = POOL + .install(|| { + cartesian_prod.par_iter().map(|(left, right)| { + let (mut left, right) = cross_join_dfs(left, right, None, false)?; + let mut right_columns = right.take_columns(); + + for (c, name) in right_columns.iter_mut().zip(rename_names) { + c.rename((*name).clone()); + } + + unsafe { left.hstack_mut_unchecked(&right_columns) }; + + cross_join_options.predicate.apply(left) + }) + }) + .collect::>>()?; + + Ok(accumulate_dataframes_vertical_unchecked(dfs)) +} diff --git a/crates/polars-ops/src/frame/join/dispatch_left_right.rs b/crates/polars-ops/src/frame/join/dispatch_left_right.rs new file mode 100644 index 000000000000..08289b3e86bc --- /dev/null +++ b/crates/polars-ops/src/frame/join/dispatch_left_right.rs @@ -0,0 +1,266 @@ +use super::*; +use crate::prelude::*; + +pub(super) fn left_join_from_series( + left: DataFrame, + right: &DataFrame, + s_left: &Series, + s_right: &Series, + args: JoinArgs, + verbose: bool, + drop_names: Option>, +) -> PolarsResult { + let (df_left, df_right) = materialize_left_join_from_series( + left, right, s_left, s_right, &args, verbose, drop_names, + )?; + _finish_join(df_left, df_right, args.suffix) +} + +pub(super) fn right_join_from_series( + left: &DataFrame, + right: DataFrame, + s_left: &Series, + s_right: &Series, + mut args: JoinArgs, + verbose: bool, + drop_names: Option>, +) -> PolarsResult { + // Swap the order of tables to do a right join. + args.maintain_order = args.maintain_order.flip(); + let (df_right, df_left) = materialize_left_join_from_series( + right, left, s_right, s_left, &args, verbose, drop_names, + )?; + _finish_join(df_left, df_right, args.suffix) +} + +pub fn materialize_left_join_from_series( + mut left: DataFrame, + right_: &DataFrame, + s_left: &Series, + s_right: &Series, + args: &JoinArgs, + verbose: bool, + drop_names: Option>, +) -> PolarsResult<(DataFrame, DataFrame)> { + #[cfg(feature = "dtype-categorical")] + _check_categorical_src(s_left.dtype(), s_right.dtype())?; + + let mut s_left = s_left.clone(); + // Eagerly limit left if possible. + if let Some((offset, len)) = args.slice { + if offset == 0 { + left = left.slice(0, len); + s_left = s_left.slice(0, len); + } + } + + // Ensure that the chunks are aligned otherwise we go OOB. + let mut right = Cow::Borrowed(right_); + let mut s_right = s_right.clone(); + if left.should_rechunk() { + left.as_single_chunk_par(); + s_left = s_left.rechunk(); + } + if right.should_rechunk() { + let mut other = right_.clone(); + other.as_single_chunk_par(); + right = Cow::Owned(other); + s_right = s_right.rechunk(); + } + + // The current sort_or_hash_left implementation preserves the Left DataFrame order so skip left for now. + let requires_ordering = matches!( + args.maintain_order, + MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft + ); + if requires_ordering { + // When ordering we rechunk the series so we don't get ChunkIds as output + s_left = s_left.rechunk(); + s_right = s_right.rechunk(); + } + + let (left_idx, right_idx) = sort_or_hash_left( + &s_left, + &s_right, + verbose, + args.validation, + args.nulls_equal, + )?; + + let right = if let Some(drop_names) = drop_names { + right.drop_many(drop_names) + } else { + right.drop(s_right.name()).unwrap() + }; + try_raise_keyboard_interrupt(); + + #[cfg(feature = "chunked_ids")] + match (left_idx, right_idx) { + (ChunkJoinIds::Left(left_idx), ChunkJoinOptIds::Left(right_idx)) => { + if requires_ordering { + Ok(maintain_order_idx( + &left, + &right, + left_idx.as_slice(), + right_idx.as_slice(), + args, + )) + } else { + Ok(POOL.join( + || materialize_left_join_idx_left(&left, left_idx.as_slice(), args), + || materialize_left_join_idx_right(&right, right_idx.as_slice(), args), + )) + } + }, + (ChunkJoinIds::Left(left_idx), ChunkJoinOptIds::Right(right_idx)) => Ok(POOL.join( + || materialize_left_join_idx_left(&left, left_idx.as_slice(), args), + || materialize_left_join_chunked_right(&right, right_idx.as_slice(), args), + )), + (ChunkJoinIds::Right(left_idx), ChunkJoinOptIds::Right(right_idx)) => Ok(POOL.join( + || materialize_left_join_chunked_left(&left, left_idx.as_slice(), args), + || materialize_left_join_chunked_right(&right, right_idx.as_slice(), args), + )), + (ChunkJoinIds::Right(left_idx), ChunkJoinOptIds::Left(right_idx)) => Ok(POOL.join( + || materialize_left_join_chunked_left(&left, left_idx.as_slice(), args), + || materialize_left_join_idx_right(&right, right_idx.as_slice(), args), + )), + } + + #[cfg(not(feature = "chunked_ids"))] + if requires_ordering { + Ok(maintain_order_idx( + &left, + &right, + left_idx.as_slice(), + right_idx.as_slice(), + args, + )) + } else { + Ok(POOL.join( + || materialize_left_join_idx_left(&left, left_idx.as_slice(), args), + || materialize_left_join_idx_right(&right, right_idx.as_slice(), args), + )) + } +} + +fn maintain_order_idx( + left: &DataFrame, + other: &DataFrame, + left_idx: &[IdxSize], + right_idx: &[NullableIdxSize], + args: &JoinArgs, +) -> (DataFrame, DataFrame) { + let mut df = { + // SAFETY: left_idx and right_idx are continuous memory that outlive the memory mapped slices + let left = unsafe { IdxCa::mmap_slice("a".into(), left_idx) }; + let right = unsafe { IdxCa::mmap_slice("b".into(), bytemuck::cast_slice(right_idx)) }; + DataFrame::new(vec![left.into_series().into(), right.into_series().into()]).unwrap() + }; + + let options = SortMultipleOptions::new() + .with_order_descending(false) + .with_maintain_order(true); + + let columns = match args.maintain_order { + // If the left order is preserved then there are no unsorted right rows + // So Left and LeftRight are equal + MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => vec!["a"], + MaintainOrderJoin::Right => vec!["b"], + MaintainOrderJoin::RightLeft => vec!["b", "a"], + _ => unreachable!(), + }; + + df.sort_in_place(columns, options).unwrap(); + df.rechunk_mut(); + + let join_tuples_left = df + .column("a") + .unwrap() + .as_materialized_series() + .idx() + .unwrap() + .cont_slice() + .unwrap(); + + let join_tuples_right = df + .column("b") + .unwrap() + .as_materialized_series() + .idx() + .unwrap() + .cont_slice() + .unwrap(); + + POOL.join( + || materialize_left_join_idx_left(left, join_tuples_left, args), + || materialize_left_join_idx_right(other, bytemuck::cast_slice(join_tuples_right), args), + ) +} + +fn materialize_left_join_idx_left( + left: &DataFrame, + left_idx: &[IdxSize], + args: &JoinArgs, +) -> DataFrame { + let left_idx = if let Some((offset, len)) = args.slice { + slice_slice(left_idx, offset, len) + } else { + left_idx + }; + + unsafe { + left._create_left_df_from_slice( + left_idx, + true, + args.slice.is_some(), + matches!( + args.maintain_order, + MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight + ) || args.how == JoinType::Left + && !matches!( + args.maintain_order, + MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft, + ), + ) + } +} + +fn materialize_left_join_idx_right( + right: &DataFrame, + right_idx: &[NullableIdxSize], + args: &JoinArgs, +) -> DataFrame { + let right_idx = if let Some((offset, len)) = args.slice { + slice_slice(right_idx, offset, len) + } else { + right_idx + }; + unsafe { IdxCa::with_nullable_idx(right_idx, |idx| right.take_unchecked(idx)) } +} +#[cfg(feature = "chunked_ids")] +fn materialize_left_join_chunked_left( + left: &DataFrame, + left_idx: &[ChunkId], + args: &JoinArgs, +) -> DataFrame { + let left_idx = if let Some((offset, len)) = args.slice { + slice_slice(left_idx, offset, len) + } else { + left_idx + }; + unsafe { left.create_left_df_chunked(left_idx, true, args.slice.is_some()) } +} + +#[cfg(feature = "chunked_ids")] +fn materialize_left_join_chunked_right( + right: &DataFrame, + right_idx: &[ChunkId], + args: &JoinArgs, +) -> DataFrame { + let right_idx = if let Some((offset, len)) = args.slice { + slice_slice(right_idx, offset, len) + } else { + right_idx + }; + unsafe { right._take_opt_chunked_unchecked_hor_par(right_idx) } +} diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs new file mode 100644 index 000000000000..f5328bf73d6b --- /dev/null +++ b/crates/polars-ops/src/frame/join/general.rs @@ -0,0 +1,109 @@ +use polars_utils::format_pl_smallstr; + +use super::*; +use crate::series::coalesce_columns; + +pub fn _join_suffix_name(name: &str, suffix: &str) -> PlSmallStr { + format_pl_smallstr!("{name}{suffix}") +} + +fn get_suffix(suffix: Option) -> PlSmallStr { + suffix.unwrap_or_else(|| PlSmallStr::from_static("_right")) +} + +/// Renames the columns on the right to not clash with the left using a specified or otherwise default suffix +/// and then merges the right dataframe into the left +#[doc(hidden)] +pub fn _finish_join( + mut df_left: DataFrame, + mut df_right: DataFrame, + suffix: Option, +) -> PolarsResult { + let mut left_names = PlHashSet::with_capacity(df_left.width()); + + df_left.get_columns().iter().for_each(|series| { + left_names.insert(series.name()); + }); + + let mut rename_strs = Vec::with_capacity(df_right.width()); + let right_names = df_right.schema(); + + for name in right_names.iter_names() { + if left_names.contains(name) { + rename_strs.push(name.clone()) + } + } + + let suffix = get_suffix(suffix); + + for name in rename_strs { + let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); + // Safety: IR resolving should guarantee this passes + df_right.rename(&name, new_name.clone()).unwrap(); + } + + drop(left_names); + // Safety: IR resolving should guarantee this passes + unsafe { df_left.hstack_mut_unchecked(df_right.get_columns()) }; + Ok(df_left) +} + +pub fn _coalesce_full_join( + mut df: DataFrame, + keys_left: &[PlSmallStr], + keys_right: &[PlSmallStr], + suffix: Option, + df_left: &DataFrame, +) -> DataFrame { + // No need to allocate the schema because we already + // know for certain that the column name for left is `name` + // and for right is `name + suffix` + let schema_left = if keys_left == keys_right { + Arc::new(Schema::default()) + } else { + df_left.schema().clone() + }; + + let schema = df.schema().clone(); + let mut to_remove = Vec::with_capacity(keys_right.len()); + + // SAFETY: we maintain invariants. + let columns = unsafe { df.get_columns_mut() }; + let suffix = get_suffix(suffix); + for (l, r) in keys_left.iter().zip(keys_right.iter()) { + let pos_l = schema.get_full(l.as_str()).unwrap().0; + + let r = if l == r || schema_left.contains(r.as_str()) { + _join_suffix_name(r.as_str(), suffix.as_str()) + } else { + r.clone() + }; + let pos_r = schema.get_full(&r).unwrap().0; + + let l = columns[pos_l].clone(); + let r = columns[pos_r].clone(); + + columns[pos_l] = coalesce_columns(&[l, r]).unwrap(); + to_remove.push(pos_r); + } + // sort in reverse order, so the indexes remain correct if we remove. + to_remove.sort_by(|a, b| b.cmp(a)); + for pos in to_remove { + let _ = columns.remove(pos); + } + df.clear_schema(); + df +} + +#[cfg(feature = "chunked_ids")] +pub(crate) fn create_chunked_index_mapping(chunks: &[ArrayRef], len: usize) -> Vec { + let mut vals = Vec::with_capacity(len); + + for (chunk_i, chunk) in chunks.iter().enumerate() { + vals.extend( + (0..chunk.len()).map(|array_i| ChunkId::store(chunk_i as IdxSize, array_i as IdxSize)), + ) + } + + vals +} diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs new file mode 100644 index 000000000000..9653e8593609 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -0,0 +1,218 @@ +#![allow(unsafe_op_in_unsafe_fn)] +pub(super) mod single_keys; +mod single_keys_dispatch; +mod single_keys_inner; +mod single_keys_left; +mod single_keys_outer; +#[cfg(feature = "semi_anti_join")] +mod single_keys_semi_anti; +pub(super) mod sort_merge; +use arrow::array::ArrayRef; +use polars_core::POOL; +use polars_core::utils::_set_partition_size; +use polars_utils::index::ChunkId; +pub(super) use single_keys::*; +pub use single_keys_dispatch::SeriesJoin; +#[cfg(feature = "asof_join")] +pub(super) use single_keys_dispatch::prepare_binary; +use single_keys_inner::*; +use single_keys_left::*; +use single_keys_outer::*; +#[cfg(feature = "semi_anti_join")] +use single_keys_semi_anti::*; +pub(crate) use sort_merge::*; + +pub use super::*; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::gather::chunked::TakeChunkedHorPar; + +pub fn default_join_ids() -> ChunkJoinOptIds { + #[cfg(feature = "chunked_ids")] + { + Either::Left(vec![]) + } + #[cfg(not(feature = "chunked_ids"))] + { + vec![] + } +} + +macro_rules! det_hash_prone_order { + ($self:expr, $other:expr) => {{ + // The shortest relation will be used to create a hash table. + if $self.len() > $other.len() { + ($self, $other, false) + } else { + ($other, $self, true) + } + }}; +} + +#[cfg(feature = "performant")] +use arrow::legacy::conversion::primitive_to_vec; +pub(super) use det_hash_prone_order; + +pub trait JoinDispatch: IntoDf { + /// # Safety + /// Join tuples must be in bounds + #[cfg(feature = "chunked_ids")] + unsafe fn create_left_df_chunked( + &self, + chunk_ids: &[ChunkId], + left_join: bool, + was_sliced: bool, + ) -> DataFrame { + let df_self = self.to_df(); + + let left_join_no_duplicate_matches = + left_join && !was_sliced && chunk_ids.len() == df_self.height(); + + if left_join_no_duplicate_matches { + df_self.clone() + } else { + // left join keys are in ascending order + let sorted = if left_join { + IsSorted::Ascending + } else { + IsSorted::Not + }; + df_self._take_chunked_unchecked_hor_par(chunk_ids, sorted) + } + } + + /// # Safety + /// Join tuples must be in bounds + unsafe fn _create_left_df_from_slice( + &self, + join_tuples: &[IdxSize], + left_join: bool, + was_sliced: bool, + sorted_tuple_idx: bool, + ) -> DataFrame { + let df_self = self.to_df(); + + let left_join_no_duplicate_matches = + sorted_tuple_idx && left_join && !was_sliced && join_tuples.len() == df_self.height(); + + if left_join_no_duplicate_matches { + df_self.clone() + } else { + let sorted = if sorted_tuple_idx { + IsSorted::Ascending + } else { + IsSorted::Not + }; + + df_self._take_unchecked_slice_sorted(join_tuples, true, sorted) + } + } + + #[cfg(feature = "semi_anti_join")] + /// # Safety + /// `idx` must be in bounds + unsafe fn _finish_anti_semi_join( + &self, + mut idx: &[IdxSize], + slice: Option<(i64, usize)>, + ) -> DataFrame { + let ca_self = self.to_df(); + if let Some((offset, len)) = slice { + idx = slice_slice(idx, offset, len); + } + // idx from anti-semi join should always be sorted + ca_self._take_unchecked_slice_sorted(idx, true, IsSorted::Ascending) + } + + #[cfg(feature = "semi_anti_join")] + fn _semi_anti_join_from_series( + &self, + s_left: &Series, + s_right: &Series, + slice: Option<(i64, usize)>, + anti: bool, + nulls_equal: bool, + ) -> PolarsResult { + let ca_self = self.to_df(); + #[cfg(feature = "dtype-categorical")] + _check_categorical_src(s_left.dtype(), s_right.dtype())?; + + let idx = s_left.hash_join_semi_anti(s_right, anti, nulls_equal)?; + // SAFETY: + // indices are in bounds + Ok(unsafe { ca_self._finish_anti_semi_join(&idx, slice) }) + } + fn _full_join_from_series( + &self, + other: &DataFrame, + s_left: &Series, + s_right: &Series, + args: JoinArgs, + ) -> PolarsResult { + let df_self = self.to_df(); + #[cfg(feature = "dtype-categorical")] + _check_categorical_src(s_left.dtype(), s_right.dtype())?; + + // Get the indexes of the joined relations + let (mut join_idx_l, mut join_idx_r) = + s_left.hash_join_outer(s_right, args.validation, args.nulls_equal)?; + + try_raise_keyboard_interrupt(); + if let Some((offset, len)) = args.slice { + let (offset, len) = slice_offsets(offset, len, join_idx_l.len()); + join_idx_l.slice(offset, len); + join_idx_r.slice(offset, len); + } + let idx_ca_l = IdxCa::with_chunk("a".into(), join_idx_l); + let idx_ca_r = IdxCa::with_chunk("b".into(), join_idx_r); + + let (df_left, df_right) = if args.maintain_order != MaintainOrderJoin::None { + let mut df = DataFrame::new(vec![ + idx_ca_l.into_series().into(), + idx_ca_r.into_series().into(), + ])?; + + let options = SortMultipleOptions::new() + .with_order_descending(false) + .with_maintain_order(true) + .with_nulls_last(true); + + let columns = match args.maintain_order { + MaintainOrderJoin::Left => vec!["a"], + MaintainOrderJoin::LeftRight => vec!["a", "b"], + MaintainOrderJoin::Right => vec!["b"], + MaintainOrderJoin::RightLeft => vec!["b", "a"], + _ => unreachable!(), + }; + + df.sort_in_place(columns, options)?; + + let join_tuples_left = df.column("a").unwrap().idx().unwrap(); + let join_tuples_right = df.column("b").unwrap().idx().unwrap(); + POOL.join( + || unsafe { df_self.take_unchecked(join_tuples_left) }, + || unsafe { other.take_unchecked(join_tuples_right) }, + ) + } else { + POOL.join( + || unsafe { df_self.take_unchecked(&idx_ca_l) }, + || unsafe { other.take_unchecked(&idx_ca_r) }, + ) + }; + + let coalesce = args.coalesce.coalesce(&JoinType::Full); + let out = _finish_join(df_left, df_right, args.suffix.clone()); + if coalesce { + Ok(_coalesce_full_join( + out?, + &[s_left.name().clone()], + &[s_right.name().clone()], + args.suffix.clone(), + df_self, + )) + } else { + out + } + } +} + +impl JoinDispatch for DataFrame {} diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs new file mode 100644 index 000000000000..7792a0b5b4d6 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs @@ -0,0 +1,182 @@ +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use polars_utils::idx_vec::IdxVec; +use polars_utils::nulls::IsNull; +use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; +use polars_utils::unitvec; + +use super::*; + +// FIXME: we should compute the number of threads / partition size we'll use. +// let avail_threads = POOL.current_num_threads(); +// let n_threads = (num_keys / MIN_ELEMS_PER_THREAD).clamp(1, avail_threads); +// Use a small element per thread threshold for debugging/testing purposes. +const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 }; + +pub(crate) fn build_tables( + keys: Vec, + nulls_equal: bool, +) -> Vec::TotalOrdItem, IdxVec>> +where + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, + I: IntoIterator + Send + Sync + Clone, +{ + // FIXME: change interface to split the input here, instead of taking + // pre-split input iterators. + let n_partitions = keys.len(); + let n_threads = n_partitions; + let num_keys_est: usize = keys + .iter() + .map(|k| k.clone().into_iter().size_hint().0) + .sum(); + + // Don't bother parallelizing anything for small inputs. + if num_keys_est < 2 * MIN_ELEMS_PER_THREAD { + let mut hm: PlHashMap = PlHashMap::new(); + let mut offset = 0; + for it in keys { + for k in it { + let k = k.to_total_ord(); + if !k.is_null() || nulls_equal { + hm.entry(k).or_default().push(offset); + } + offset += 1; + } + } + return vec![hm]; + } + + POOL.install(|| { + // Compute the number of elements in each partition for each portion. + let per_thread_partition_sizes: Vec> = keys + .par_iter() + .with_max_len(1) + .map(|key_portion| { + let mut partition_sizes = vec![0; n_partitions]; + for key in key_portion.clone() { + let key = key.to_total_ord(); + let p = hash_to_partition(key.dirty_hash(), n_partitions); + unsafe { + *partition_sizes.get_unchecked_mut(p) += 1; + } + } + partition_sizes + }) + .collect(); + + // Compute output offsets with a cumulative sum. + let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1]; + let mut partition_offsets = vec![0; n_partitions + 1]; + let mut cum_offset = 0; + for p in 0..n_partitions { + partition_offsets[p] = cum_offset; + for t in 0..n_threads { + per_thread_partition_offsets[t * n_partitions + p] = cum_offset; + cum_offset += per_thread_partition_sizes[t][p]; + } + } + let num_keys = cum_offset; + per_thread_partition_offsets[n_threads * n_partitions] = num_keys; + partition_offsets[n_partitions] = num_keys; + + // FIXME: we wouldn't need this if we changed our interface to split the + // input in this function, instead of taking a vec of iterators. + let mut per_thread_input_offsets = vec![0; n_partitions]; + cum_offset = 0; + for t in 0..n_threads { + per_thread_input_offsets[t] = cum_offset; + for p in 0..n_partitions { + cum_offset += per_thread_partition_sizes[t][p]; + } + } + + // Scatter values into partitions. + let mut scatter_keys: Vec = Vec::with_capacity(num_keys); + let mut scatter_idxs: Vec = Vec::with_capacity(num_keys); + let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) }; + let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) }; + keys.into_par_iter() + .with_max_len(1) + .enumerate() + .for_each(|(t, key_portion)| { + let mut partition_offsets = + per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec(); + for (i, key) in key_portion.into_iter().enumerate() { + let key = key.to_total_ord(); + unsafe { + let p = hash_to_partition(key.dirty_hash(), n_partitions); + let off = partition_offsets.get_unchecked_mut(p); + *scatter_keys_ptr.get().add(*off) = key; + *scatter_idxs_ptr.get().add(*off) = + (per_thread_input_offsets[t] + i) as IdxSize; + *off += 1; + } + } + }); + unsafe { + scatter_keys.set_len(num_keys); + scatter_idxs.set_len(num_keys); + } + + // Build tables. + (0..n_partitions) + .into_par_iter() + .with_max_len(1) + .map(|p| { + // Resizing the hash map is very, very expensive. That's why we + // adopt a hybrid strategy: we assume an initially small hash + // map, which would satisfy a highly skewed relation. If this + // fills up we immediately reserve enough for a full cardinality + // data set. + let partition_range = partition_offsets[p]..partition_offsets[p + 1]; + let full_size = partition_range.len(); + let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64); + let mut hm: PlHashMap = + PlHashMap::with_capacity(conservative_size); + + unsafe { + for i in partition_range { + if hm.len() == conservative_size { + hm.reserve(full_size - conservative_size); + conservative_size = 0; // Hack to ensure we never hit this branch again. + } + + let key = *scatter_keys.get_unchecked(i); + + if !key.is_null() || nulls_equal { + let idx = *scatter_idxs.get_unchecked(i); + match hm.entry(key) { + Entry::Occupied(mut o) => { + o.get_mut().push(idx as IdxSize); + }, + Entry::Vacant(v) => { + let iv = unitvec![idx as IdxSize]; + v.insert(iv); + }, + }; + } + } + } + + hm + }) + .collect() + }) +} + +// we determine the offset so that we later know which index to store in the join tuples +pub(super) fn probe_to_offsets(probe: &[I]) -> Vec +where + I: IntoIterator + Clone, +{ + probe + .iter() + .map(|ph| ph.clone().into_iter().size_hint().1.unwrap()) + .scan(0, |state, val| { + let out = *state; + *state += val; + Some(out) + }) + .collect() +} diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs new file mode 100644 index 000000000000..e51f87f98f4d --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -0,0 +1,732 @@ +use arrow::array::PrimitiveArray; +use polars_core::chunked_array::ops::row_encode::encode_rows_unordered; +use polars_core::series::BitRepr; +use polars_core::utils::split; +use polars_core::with_match_physical_float_polars_type; +use polars_utils::aliases::PlRandomState; +use polars_utils::hashing::DirtyHash; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use super::*; +use crate::series::SeriesSealed; + +pub trait SeriesJoin: SeriesSealed + Sized { + #[doc(hidden)] + fn hash_join_left( + &self, + other: &Series, + validate: JoinValidation, + nulls_equal: bool, + ) -> PolarsResult { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, false, nulls_equal)?; + + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); + + use DataType as T; + match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); + let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); + let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); + let build_null_count = other.null_count(); + hash_join_tuples_left( + lhs, + rhs, + None, + None, + validate, + nulls_equal, + build_null_count, + ) + }, + T::BinaryOffset => { + let lhs = lhs.binary_offset().unwrap(); + let rhs = rhs.binary_offset().unwrap(); + let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + let build_null_count = other.null_count(); + hash_join_tuples_left( + lhs, + rhs, + None, + None, + validate, + nulls_equal, + build_null_count, + ) + }, + T::List(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_left(rhs, validate, nulls_equal) + }, + #[cfg(feature = "dtype-array")] + T::Array(_, _) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_left(rhs, validate, nulls_equal) + }, + #[cfg(feature = "dtype-struct")] + T::Struct(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_left(rhs, validate, nulls_equal) + }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_left(lhs, rhs, validate, nulls_equal) + }) + }, + _ => { + let lhs = s_self.bit_repr(); + let rhs = other.bit_repr(); + + let (Some(lhs), Some(rhs)) = (lhs, rhs) else { + polars_bail!(nyi = "Hash Left Join between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + // Turbofish: see #17137. + num_group_join_left::(&lhs, &rhs, validate, nulls_equal) + }, + (B::Large(lhs), B::Large(rhs)) => { + // Turbofish: see #17137. + num_group_join_left::(&lhs, &rhs, validate, nulls_equal) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Left Join between {lhs_dtype} and {rhs_dtype}", + ); + }, + } + }, + } + } + + #[cfg(feature = "semi_anti_join")] + fn hash_join_semi_anti( + &self, + other: &Series, + anti: bool, + nulls_equal: bool, + ) -> PolarsResult> { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); + + use DataType as T; + Ok(match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + if anti { + hash_join_tuples_left_anti(lhs, rhs, nulls_equal) + } else { + hash_join_tuples_left_semi(lhs, rhs, nulls_equal) + } + }, + T::BinaryOffset => { + let lhs = lhs.binary_offset().unwrap(); + let rhs = rhs.binary_offset().unwrap(); + let (lhs, rhs, _, _) = prepare_binary::(lhs, rhs, false); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + if anti { + hash_join_tuples_left_anti(lhs, rhs, nulls_equal) + } else { + hash_join_tuples_left_semi(lhs, rhs, nulls_equal) + } + }, + T::List(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_semi_anti(rhs, anti, nulls_equal)? + }, + #[cfg(feature = "dtype-array")] + T::Array(_, _) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_semi_anti(rhs, anti, nulls_equal)? + }, + #[cfg(feature = "dtype-struct")] + T::Struct(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_semi_anti(rhs, anti, nulls_equal)? + }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_anti_semi(lhs, rhs, anti, nulls_equal) + }) + }, + _ => { + let lhs = s_self.bit_repr(); + let rhs = other.bit_repr(); + + let (Some(lhs), Some(rhs)) = (lhs, rhs) else { + polars_bail!(nyi = "Hash Semi-Anti Join between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + // Turbofish: see #17137. + num_group_join_anti_semi::(&lhs, &rhs, anti, nulls_equal) + }, + (B::Large(lhs), B::Large(rhs)) => { + // Turbofish: see #17137. + num_group_join_anti_semi::(&lhs, &rhs, anti, nulls_equal) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Semi-Anti Join between {lhs_dtype} and {rhs_dtype}", + ); + }, + } + }, + }) + } + + // returns the join tuples and whether or not the lhs tuples are sorted + fn hash_join_inner( + &self, + other: &Series, + validate: JoinValidation, + nulls_equal: bool, + ) -> PolarsResult<(InnerJoinIds, bool)> { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, true, nulls_equal)?; + + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); + + use DataType as T; + match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + let build_null_count = if swapped { + s_self.null_count() + } else { + other.null_count() + }; + Ok(( + hash_join_tuples_inner( + lhs, + rhs, + swapped, + validate, + nulls_equal, + build_null_count, + )?, + !swapped, + )) + }, + T::BinaryOffset => { + let lhs = lhs.binary_offset().unwrap(); + let rhs = rhs.binary_offset()?; + let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + let build_null_count = if swapped { + s_self.null_count() + } else { + other.null_count() + }; + Ok(( + hash_join_tuples_inner( + lhs, + rhs, + swapped, + validate, + nulls_equal, + build_null_count, + )?, + !swapped, + )) + }, + T::List(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_inner(rhs, validate, nulls_equal) + }, + #[cfg(feature = "dtype-array")] + T::Array(_, _) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_inner(rhs, validate, nulls_equal) + }, + #[cfg(feature = "dtype-struct")] + T::Struct(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_inner(rhs, validate, nulls_equal) + }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + group_join_inner::<$T>(lhs, rhs, validate, nulls_equal) + }) + }, + _ => { + let lhs = s_self.bit_repr(); + let rhs = other.bit_repr(); + + let (Some(lhs), Some(rhs)) = (lhs, rhs) else { + polars_bail!(nyi = "Hash Inner Join between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + // Turbofish: see #17137. + group_join_inner::(&lhs, &rhs, validate, nulls_equal) + }, + (B::Large(lhs), BitRepr::Large(rhs)) => { + // Turbofish: see #17137. + group_join_inner::(&lhs, &rhs, validate, nulls_equal) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Inner Join between {lhs_dtype} and {rhs_dtype}" + ); + }, + } + }, + } + } + + fn hash_join_outer( + &self, + other: &Series, + validate: JoinValidation, + nulls_equal: bool, + ) -> PolarsResult<(PrimitiveArray, PrimitiveArray)> { + let s_self = self.as_series(); + let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr()); + validate.validate_probe(&lhs, &rhs, true, nulls_equal)?; + + let lhs_dtype = lhs.dtype(); + let rhs_dtype = rhs.dtype(); + + use DataType as T; + match lhs_dtype { + T::String | T::Binary => { + let lhs = lhs.cast(&T::Binary).unwrap(); + let rhs = rhs.cast(&T::Binary).unwrap(); + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + hash_join_tuples_outer(lhs, rhs, swapped, validate, nulls_equal) + }, + T::BinaryOffset => { + let lhs = lhs.binary_offset().unwrap(); + let rhs = rhs.binary_offset()?; + let (lhs, rhs, swapped, _) = prepare_binary::(lhs, rhs, true); + // Take slices so that vecs are not copied + let lhs = lhs.iter().map(|k| k.as_slice()).collect::>(); + let rhs = rhs.iter().map(|k| k.as_slice()).collect::>(); + hash_join_tuples_outer(lhs, rhs, swapped, validate, nulls_equal) + }, + T::List(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_outer(rhs, validate, nulls_equal) + }, + #[cfg(feature = "dtype-array")] + T::Array(_, _) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_outer(rhs, validate, nulls_equal) + }, + #[cfg(feature = "dtype-struct")] + T::Struct(_) => { + let lhs = &encode_rows_unordered(&[lhs.into_owned().into()])?.into_series(); + let rhs = &encode_rows_unordered(&[rhs.into_owned().into()])?.into_series(); + lhs.hash_join_outer(rhs, validate, nulls_equal) + }, + x if x.is_float() => { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + hash_join_outer(lhs, rhs, validate, nulls_equal) + }) + }, + _ => { + let (Some(lhs), Some(rhs)) = (s_self.bit_repr(), other.bit_repr()) else { + polars_bail!(nyi = "Hash Join Outer between {lhs_dtype} and {rhs_dtype}"); + }; + + use BitRepr as B; + match (lhs, rhs) { + (B::Small(lhs), B::Small(rhs)) => { + // Turbofish: see #17137. + hash_join_outer::(&lhs, &rhs, validate, nulls_equal) + }, + (B::Large(lhs), B::Large(rhs)) => { + // Turbofish: see #17137. + hash_join_outer::(&lhs, &rhs, validate, nulls_equal) + }, + _ => { + polars_bail!( + nyi = "Mismatch bit repr Hash Join Outer between {lhs_dtype} and {rhs_dtype}" + ); + }, + } + }, + } + } +} + +impl SeriesJoin for Series {} + +fn chunks_as_slices(splitted: &[ChunkedArray]) -> Vec<&[T::Native]> +where + T: PolarsNumericType, +{ + splitted + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) + .collect() +} + +fn get_arrays(cas: &[ChunkedArray]) -> Vec<&T::Array> { + cas.iter().flat_map(|arr| arr.downcast_iter()).collect() +} + +fn group_join_inner( + left: &ChunkedArray, + right: &ChunkedArray, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult<(InnerJoinIds, bool)> +where + T: PolarsDataType, + for<'a> &'a T::Array: IntoIterator>>, + for<'a> T::Physical<'a>: + Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + for<'a> as ToTotalOrd>::TotalOrdItem: + Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, +{ + let n_threads = POOL.current_num_threads(); + let (a, b, swapped) = det_hash_prone_order!(left, right); + let splitted_a = split(a, n_threads); + let splitted_b = split(b, n_threads); + let splitted_a = get_arrays(&splitted_a); + let splitted_b = get_arrays(&splitted_b); + + match (left.null_count(), right.null_count()) { + (0, 0) => { + let first = &splitted_a[0]; + if first.as_slice().is_some() { + let splitted_a = splitted_a + .iter() + .map(|arr| arr.as_slice().unwrap()) + .collect::>(); + let splitted_b = splitted_b + .iter() + .map(|arr| arr.as_slice().unwrap()) + .collect::>(); + Ok(( + hash_join_tuples_inner( + splitted_a, + splitted_b, + swapped, + validate, + nulls_equal, + 0, + )?, + !swapped, + )) + } else { + Ok(( + hash_join_tuples_inner( + splitted_a, + splitted_b, + swapped, + validate, + nulls_equal, + 0, + )?, + !swapped, + )) + } + }, + _ => { + let build_null_count = if swapped { + left.null_count() + } else { + right.null_count() + }; + Ok(( + hash_join_tuples_inner( + splitted_a, + splitted_b, + swapped, + validate, + nulls_equal, + build_null_count, + )?, + !swapped, + )) + }, + } +} + +#[cfg(feature = "chunked_ids")] +fn create_mappings( + chunks_left: &[ArrayRef], + chunks_right: &[ArrayRef], + left_len: usize, + right_len: usize, +) -> (Option>, Option>) { + let mapping_left = || { + if chunks_left.len() > 1 { + Some(create_chunked_index_mapping(chunks_left, left_len)) + } else { + None + } + }; + + let mapping_right = || { + if chunks_right.len() > 1 { + Some(create_chunked_index_mapping(chunks_right, right_len)) + } else { + None + } + }; + + POOL.join(mapping_left, mapping_right) +} + +#[cfg(not(feature = "chunked_ids"))] +fn create_mappings( + _chunks_left: &[ArrayRef], + _chunks_right: &[ArrayRef], + _left_len: usize, + _right_len: usize, +) -> (Option>, Option>) { + (None, None) +} + +fn num_group_join_left( + left: &ChunkedArray, + right: &ChunkedArray, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, + T::Native: DirtyHash + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Send + Sync + DirtyHash, +{ + let n_threads = POOL.current_num_threads(); + let splitted_a = split(left, n_threads); + let splitted_b = split(right, n_threads); + match ( + left.null_count(), + right.null_count(), + left.chunks().len(), + right.chunks().len(), + ) { + (0, 0, 1, 1) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + hash_join_tuples_left(keys_a, keys_b, None, None, validate, nulls_equal, 0) + }, + (0, 0, _, _) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + + let (mapping_left, mapping_right) = + create_mappings(left.chunks(), right.chunks(), left.len(), right.len()); + hash_join_tuples_left( + keys_a, + keys_b, + mapping_left.as_deref(), + mapping_right.as_deref(), + validate, + nulls_equal, + 0, + ) + }, + _ => { + let keys_a = get_arrays(&splitted_a); + let keys_b = get_arrays(&splitted_b); + let (mapping_left, mapping_right) = + create_mappings(left.chunks(), right.chunks(), left.len(), right.len()); + let build_null_count = right.null_count(); + hash_join_tuples_left( + keys_a, + keys_b, + mapping_left.as_deref(), + mapping_right.as_deref(), + validate, + nulls_equal, + build_null_count, + ) + }, + } +} + +fn hash_join_outer( + ca_in: &ChunkedArray, + other: &ChunkedArray, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult<(PrimitiveArray, PrimitiveArray)> +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + IsNull, +{ + let (a, b, swapped) = det_hash_prone_order!(ca_in, other); + + let n_partitions = _set_partition_size(); + let splitted_a = split(a, n_partitions); + let splitted_b = split(b, n_partitions); + + match (a.null_count(), b.null_count()) { + (0, 0) => { + let iters_a = splitted_a + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) + .collect::>(); + let iters_b = splitted_b + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.values().as_slice())) + .collect::>(); + hash_join_tuples_outer(iters_a, iters_b, swapped, validate, nulls_equal) + }, + _ => { + let iters_a = splitted_a + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.iter())) + .collect::>(); + let iters_b = splitted_b + .iter() + .flat_map(|ca| ca.downcast_iter().map(|arr| arr.iter())) + .collect::>(); + hash_join_tuples_outer(iters_a, iters_b, swapped, validate, nulls_equal) + }, + } +} + +pub(crate) fn prepare_binary<'a, T>( + ca: &'a ChunkedArray, + other: &'a ChunkedArray, + // In inner join and outer join, the shortest relation will be used to create a hash table. + // In left join, always use the right side to create. + build_shortest_table: bool, +) -> ( + Vec>>, + Vec>>, + bool, + PlRandomState, +) +where + T: PolarsDataType, + for<'b> ::ValueT<'b>: AsRef<[u8]>, +{ + let (a, b, swapped) = if build_shortest_table { + det_hash_prone_order!(ca, other) + } else { + (ca, other, false) + }; + let hb = PlRandomState::default(); + let bh_a = a.to_bytes_hashes(true, hb); + let bh_b = b.to_bytes_hashes(true, hb); + + (bh_a, bh_b, swapped, hb) +} + +#[cfg(feature = "semi_anti_join")] +fn num_group_join_anti_semi( + left: &ChunkedArray, + right: &ChunkedArray, + anti: bool, + nulls_equal: bool, +) -> Vec +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, + as ToTotalOrd>::TotalOrdItem: Send + Sync + DirtyHash + IsNull, +{ + let n_threads = POOL.current_num_threads(); + let splitted_a = split(left, n_threads); + let splitted_b = split(right, n_threads); + match ( + left.null_count(), + right.null_count(), + left.chunks().len(), + right.chunks().len(), + ) { + (0, 0, 1, 1) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + if anti { + hash_join_tuples_left_anti(keys_a, keys_b, nulls_equal) + } else { + hash_join_tuples_left_semi(keys_a, keys_b, nulls_equal) + } + }, + (0, 0, _, _) => { + let keys_a = chunks_as_slices(&splitted_a); + let keys_b = chunks_as_slices(&splitted_b); + if anti { + hash_join_tuples_left_anti(keys_a, keys_b, nulls_equal) + } else { + hash_join_tuples_left_semi(keys_a, keys_b, nulls_equal) + } + }, + _ => { + let keys_a = get_arrays(&splitted_a); + let keys_b = get_arrays(&splitted_b); + if anti { + hash_join_tuples_left_anti(keys_a, keys_b, nulls_equal) + } else { + hash_join_tuples_left_semi(keys_a, keys_b, nulls_equal) + } + }, + } +} diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs new file mode 100644 index 000000000000..ce7d06bf14cc --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs @@ -0,0 +1,149 @@ +use polars_core::utils::flatten; +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use polars_utils::idx_vec::IdxVec; +use polars_utils::itertools::Itertools; +use polars_utils::nulls::IsNull; +use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use super::*; + +pub(super) fn probe_inner( + probe: I, + hash_tbls: &[PlHashMap<::TotalOrdItem, IdxVec>], + results: &mut Vec<(IdxSize, IdxSize)>, + local_offset: IdxSize, + n_tables: usize, + swap_fn: F, +) where + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + DirtyHash, + I: IntoIterator, + F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize), +{ + probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); + let idx_a = idx_a + local_offset; + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_tbls.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) }; + + let value = current_probe_table.get(&k); + + if let Some(indexes_b) = value { + let tuples = indexes_b.iter().map(|&idx_b| swap_fn(idx_a, idx_b)); + results.extend(tuples); + } + }); +} + +pub(super) fn hash_join_tuples_inner( + probe: Vec, + build: Vec, + // Because b should be the shorter relation we could need to swap to keep left left and right right. + swapped: bool, + validate: JoinValidation, + nulls_equal: bool, + // Null count is required for join validation + build_null_count: usize, +) -> PolarsResult<(Vec, Vec)> +where + I: IntoIterator + Send + Sync + Clone, + T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, +{ + // NOTE: see the left join for more elaborate comments + // first we hash one relation + let hash_tbls = if validate.needs_checks() { + let mut expected_size = build + .iter() + .map(|v| v.clone().into_iter().size_hint().1.unwrap()) + .sum(); + if !nulls_equal { + expected_size -= build_null_count; + } + let hash_tbls = build_tables(build, nulls_equal); + let build_size = hash_tbls.iter().map(|m| m.len()).sum(); + validate.validate_build(build_size, expected_size, swapped)?; + hash_tbls + } else { + build_tables(build, nulls_equal) + }; + try_raise_keyboard_interrupt(); + + let n_tables = hash_tbls.len(); + let offsets = probe_to_offsets(&probe); + // next we probe the other relation + // code duplication is because we want to only do the swap check once + let out = POOL.install(|| { + let tuples = probe + .into_par_iter() + .zip(offsets) + .map(|(probe, offset)| { + let probe = probe.into_iter(); + // local reference + let hash_tbls = &hash_tbls; + let mut results = Vec::with_capacity(probe.size_hint().1.unwrap()); + let local_offset = offset as IdxSize; + + // branch is to hoist swap out of the inner loop. + if swapped { + probe_inner( + probe, + hash_tbls, + &mut results, + local_offset, + n_tables, + |idx_a, idx_b| (idx_b, idx_a), + ) + } else { + probe_inner( + probe, + hash_tbls, + &mut results, + local_offset, + n_tables, + |idx_a, idx_b| (idx_a, idx_b), + ) + } + + results + }) + .collect::>(); + + // parallel materialization + let (cap, offsets) = flatten::cap_and_offsets(&tuples); + let mut left = Vec::with_capacity(cap); + let mut right = Vec::with_capacity(cap); + + let left_ptr = unsafe { SyncPtr::new(left.as_mut_ptr()) }; + let right_ptr = unsafe { SyncPtr::new(right.as_mut_ptr()) }; + + tuples + .into_par_iter() + .zip(offsets) + .for_each(|(tuples, offset)| unsafe { + let left_ptr: *mut IdxSize = left_ptr.get(); + let left_ptr = left_ptr.add(offset); + let right_ptr: *mut IdxSize = right_ptr.get(); + let right_ptr = right_ptr.add(offset); + + // amortize loop counter + for i in 0..tuples.len() { + let tuple = tuples.get_unchecked(i); + let left_row_idx = tuple.0; + let right_row_idx = tuple.1; + + std::ptr::write(left_ptr.add(i), left_row_idx); + std::ptr::write(right_ptr.add(i), right_row_idx); + } + }); + unsafe { + left.set_len(cap); + right.set_len(cap); + } + + (left, right) + }); + Ok(out) +} diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs new file mode 100644 index 000000000000..b3bfb8d4685b --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs @@ -0,0 +1,195 @@ +use polars_core::utils::flatten::flatten_par; +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use super::*; + +#[cfg(feature = "chunked_ids")] +unsafe fn apply_mapping(idx: Vec, chunk_mapping: &[ChunkId]) -> Vec { + idx.iter() + .map(|idx| *chunk_mapping.get_unchecked(*idx as usize)) + .collect() +} + +#[cfg(feature = "chunked_ids")] +unsafe fn apply_opt_mapping(idx: Vec, chunk_mapping: &[ChunkId]) -> Vec { + idx.iter() + .map(|opt_idx| { + if opt_idx.is_null_idx() { + ChunkId::null() + } else { + *chunk_mapping.get_unchecked(opt_idx.idx() as usize) + } + }) + .collect() +} + +#[cfg(feature = "chunked_ids")] +pub(super) fn finish_left_join_mappings( + result_idx_left: Vec, + result_idx_right: Vec, + chunk_mapping_left: Option<&[ChunkId]>, + chunk_mapping_right: Option<&[ChunkId]>, +) -> LeftJoinIds { + let left = match chunk_mapping_left { + None => ChunkJoinIds::Left(result_idx_left), + Some(mapping) => ChunkJoinIds::Right(unsafe { apply_mapping(result_idx_left, mapping) }), + }; + + let right = match chunk_mapping_right { + None => ChunkJoinOptIds::Left(result_idx_right), + Some(mapping) => { + ChunkJoinOptIds::Right(unsafe { apply_opt_mapping(result_idx_right, mapping) }) + }, + }; + (left, right) +} + +#[cfg(not(feature = "chunked_ids"))] +pub(super) fn finish_left_join_mappings( + _result_idx_left: Vec, + _result_idx_right: Vec, + _chunk_mapping_left: Option<&[ChunkId]>, + _chunk_mapping_right: Option<&[ChunkId]>, +) -> LeftJoinIds { + (_result_idx_left, _result_idx_right) +} + +pub(super) fn flatten_left_join_ids(result: Vec) -> LeftJoinIds { + #[cfg(feature = "chunked_ids")] + { + let left = if result[0].0.is_left() { + let lefts = result + .iter() + .map(|join_id| join_id.0.as_ref().left().unwrap()) + .collect::>(); + let lefts = flatten_par(&lefts); + ChunkJoinIds::Left(lefts) + } else { + let lefts = result + .iter() + .map(|join_id| join_id.0.as_ref().right().unwrap()) + .collect::>(); + let lefts = flatten_par(&lefts); + ChunkJoinIds::Right(lefts) + }; + + let right = if result[0].1.is_left() { + let rights = result + .iter() + .map(|join_id| join_id.1.as_ref().left().unwrap()) + .collect::>(); + let rights = flatten_par(&rights); + ChunkJoinOptIds::Left(rights) + } else { + let rights = result + .iter() + .map(|join_id| join_id.1.as_ref().right().unwrap()) + .collect::>(); + let rights = flatten_par(&rights); + ChunkJoinOptIds::Right(rights) + }; + + (left, right) + } + #[cfg(not(feature = "chunked_ids"))] + { + let lefts = result.iter().map(|join_id| &join_id.0).collect::>(); + let rights = result.iter().map(|join_id| &join_id.1).collect::>(); + let lefts = flatten_par(&lefts); + let rights = flatten_par(&rights); + (lefts, rights) + } +} + +pub(super) fn hash_join_tuples_left( + probe: Vec, + build: Vec, + // map the global indices to [chunk_idx, array_idx] + // only needed if we have non contiguous memory + chunk_mapping_left: Option<&[ChunkId]>, + chunk_mapping_right: Option<&[ChunkId]>, + validate: JoinValidation, + nulls_equal: bool, + // We should know the number of nulls to avoid extra calculation + build_null_count: usize, +) -> PolarsResult +where + I: IntoIterator, + ::IntoIter: Send + Sync + Clone, + T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, +{ + let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); + let build = build.into_iter().map(|i| i.into_iter()).collect::>(); + // first we hash one relation + let hash_tbls = if validate.needs_checks() { + let mut expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum(); + if !nulls_equal { + expected_size -= build_null_count; + } + let hash_tbls = build_tables(build, nulls_equal); + let build_size = hash_tbls.iter().map(|m| m.len()).sum(); + validate.validate_build(build_size, expected_size, false)?; + hash_tbls + } else { + build_tables(build, nulls_equal) + }; + try_raise_keyboard_interrupt(); + let n_tables = hash_tbls.len(); + + // we determine the offset so that we later know which index to store in the join tuples + let offsets = probe_to_offsets(&probe); + + // next we probe the other relation + let result: Vec = POOL.install(move || { + probe + .into_par_iter() + .zip(offsets) + // probes_hashes: Vec processed by this thread + // offset: offset index + .map(move |(probe, offset)| { + // local reference + let hash_tbls = &hash_tbls; + + // assume the result tuples equal length of the no. of hashes processed by this thread. + let mut result_idx_left = Vec::with_capacity(probe.size_hint().1.unwrap()); + let mut result_idx_right = Vec::with_capacity(probe.size_hint().1.unwrap()); + + probe.enumerate().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); + let idx_a = (idx_a + offset) as IdxSize; + // probe table that contains the hashed value + let current_probe_table = unsafe { + hash_tbls.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) + }; + + // we already hashed, so we don't have to hash again. + let value = current_probe_table.get(&k); + + match value { + // left and right matches + Some(indexes_b) => { + result_idx_left.extend(std::iter::repeat_n(idx_a, indexes_b.len())); + result_idx_right.extend_from_slice(bytemuck::cast_slice(indexes_b)); + }, + // only left values, right = null + None => { + result_idx_left.push(idx_a); + result_idx_right.push(NullableIdxSize::null()); + }, + } + }); + finish_left_join_mappings( + result_idx_left, + result_idx_right, + chunk_mapping_left, + chunk_mapping_right, + ) + }) + .collect() + }); + + Ok(flatten_left_join_ids(result)) +} diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs new file mode 100644 index 000000000000..db5ba9e28686 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs @@ -0,0 +1,266 @@ +use std::hash::BuildHasher; + +use arrow::array::{MutablePrimitiveArray, PrimitiveArray}; +use arrow::legacy::utils::CustomIterTools; +use polars_utils::hashing::hash_to_partition; +use polars_utils::idx_vec::IdxVec; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; +use polars_utils::unitvec; + +use super::*; + +pub(crate) fn create_hash_and_keys_threaded_vectorized( + iters: Vec, + build_hasher: Option, +) -> (Vec>, PlRandomState) +where + I: IntoIterator + Send, + I::IntoIter: TrustedLen, + T: TotalHash + TotalEq + Send + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, +{ + let build_hasher = build_hasher.unwrap_or_default(); + let hashes = POOL.install(|| { + iters + .into_par_iter() + .map(|iter| { + // create hashes and keys + #[allow(clippy::needless_borrows_for_generic_args)] + iter.into_iter() + .map(|val| (build_hasher.hash_one(&val.to_total_ord()), val)) + .collect_trusted::>() + }) + .collect() + }); + (hashes, build_hasher) +} + +pub(crate) fn prepare_hashed_relation_threaded( + iters: Vec, +) -> Vec::TotalOrdItem, (bool, IdxVec)>> +where + I: Iterator + Send + TrustedLen, + T: Send + Sync + TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq, +{ + let n_partitions = _set_partition_size(); + let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None); + + // We will create a hashtable in every thread. + // We use the hash to partition the keys to the matching hashtable. + // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. + POOL.install(|| { + (0..n_partitions) + .into_par_iter() + .map(|partition_no| { + let hashes_and_keys = &hashes_and_keys; + let mut hash_tbl: PlHashMap = + PlHashMap::with_hasher(build_hasher); + + let mut offset = 0; + for hashes_and_keys in hashes_and_keys { + let len = hashes_and_keys.len(); + hashes_and_keys + .iter() + .enumerate() + .for_each(|(idx, (h, k))| { + let k = k.to_total_ord(); + let idx = idx as IdxSize; + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if partition_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + let entry = hash_tbl + .raw_entry_mut() + // uses the key to check equality to find and entry + .from_key_hashed_nocheck(*h, &k); + + match entry { + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck(*h, k, (false, unitvec![idx])); + }, + RawEntryMut::Occupied(mut entry) => { + let (_k, v) = entry.get_key_value_mut(); + v.1.push(idx); + }, + } + } + }); + + offset += len as IdxSize; + } + hash_tbl + }) + .collect() + }) +} + +/// Probe the build table and add tuples to the results. +#[allow(clippy::too_many_arguments)] +fn probe_outer( + probe_hashes: &[Vec<(u64, T)>], + hash_tbls: &mut [PlHashMap<::TotalOrdItem, (bool, IdxVec)>], + results: &mut ( + MutablePrimitiveArray, + MutablePrimitiveArray, + ), + n_tables: usize, + // Function that get index_a, index_b when there is a match and pushes to result + swap_fn_match: F, + // Function that get index_a when there is no match and pushes to result + swap_fn_no_match: G, + // Function that get index_b from the build table that did not match any in A and pushes to result + swap_fn_drain: H, + nulls_equal: bool, +) where + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + IsNull, + // idx_a, idx_b -> ... + F: Fn(IdxSize, IdxSize) -> (Option, Option), + // idx_a -> ... + G: Fn(IdxSize) -> (Option, Option), + // idx_b -> ... + H: Fn(IdxSize) -> (Option, Option), +{ + // needed for the partition shift instead of modulo to make sense + let mut idx_a = 0; + for probe_hashes in probe_hashes { + for (h, key) in probe_hashes { + let key = key.to_total_ord(); + let h = *h; + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_tbls.get_unchecked_mut(hash_to_partition(h, n_tables)) }; + + let entry = current_probe_table + .raw_entry_mut() + .from_key_hashed_nocheck(h, &key); + + match entry { + // match and remove + RawEntryMut::Occupied(mut occupied) => { + if key.is_null() && !nulls_equal { + let (l, r) = swap_fn_no_match(idx_a); + results.0.push(l); + results.1.push(r); + } else { + let (tracker, indexes_b) = occupied.get_mut(); + *tracker = true; + for (l, r) in indexes_b.iter().map(|&idx_b| swap_fn_match(idx_a, idx_b)) { + results.0.push(l); + results.1.push(r); + } + } + }, + // no match + RawEntryMut::Vacant(_) => { + let (l, r) = swap_fn_no_match(idx_a); + results.0.push(l); + results.1.push(r); + }, + } + idx_a += 1; + } + } + + for hash_tbl in hash_tbls { + hash_tbl.iter().for_each(|(_k, (tracker, indexes_b))| { + // remaining joined values from the right table + if !*tracker { + for (l, r) in indexes_b.iter().map(|&idx_b| swap_fn_drain(idx_b)) { + results.0.push(l); + results.1.push(r); + } + } + }); + } +} + +/// Hash join outer. Both left and right can have no match so Options +pub(super) fn hash_join_tuples_outer( + probe: Vec, + build: Vec, + swapped: bool, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult<(PrimitiveArray, PrimitiveArray)> +where + I: IntoIterator, + J: IntoIterator, + ::IntoIter: TrustedLen + Send, + ::IntoIter: TrustedLen + Send, + T: Send + Sync + TotalHash + TotalEq + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + IsNull, +{ + let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); + let build = build.into_iter().map(|i| i.into_iter()).collect::>(); + // This function is partially multi-threaded. + // Parts that are done in parallel: + // - creation of the probe tables + // - creation of the hashes + + // during the probe phase values are removed from the tables, that's done single threaded to + // keep it lock free. + + let size = probe + .iter() + .map(|a| a.size_hint().1.unwrap()) + .sum::() + + build + .iter() + .map(|b| b.size_hint().1.unwrap()) + .sum::(); + let mut results = ( + MutablePrimitiveArray::with_capacity(size), + MutablePrimitiveArray::with_capacity(size), + ); + + // prepare hash table + let mut hash_tbls = if validate.needs_checks() { + let expected_size = build.iter().map(|i| i.size_hint().0).sum(); + let hash_tbls = prepare_hashed_relation_threaded(build); + let build_size = hash_tbls.iter().map(|m| m.len()).sum(); + validate.validate_build(build_size, expected_size, swapped)?; + hash_tbls + } else { + prepare_hashed_relation_threaded(build) + }; + let random_state = hash_tbls[0].hasher(); + + // we pre hash the probing values + let (probe_hashes, _) = create_hash_and_keys_threaded_vectorized(probe, Some(*random_state)); + + let n_tables = hash_tbls.len(); + try_raise_keyboard_interrupt(); + + // probe the hash table. + // Note: indexes from b that are not matched will be None, Some(idx_b) + // Therefore we remove the matches and the remaining will be joined from the right + + // branch is because we want to only do the swap check once + if swapped { + probe_outer( + &probe_hashes, + &mut hash_tbls, + &mut results, + n_tables, + |idx_a, idx_b| (Some(idx_b), Some(idx_a)), + |idx_a| (None, Some(idx_a)), + |idx_b| (Some(idx_b), None), + nulls_equal, + ) + } else { + probe_outer( + &probe_hashes, + &mut hash_tbls, + &mut results, + n_tables, + |idx_a, idx_b| (Some(idx_a), Some(idx_b)), + |idx_a| (Some(idx_a), None), + |idx_b| (None, Some(idx_b)), + nulls_equal, + ) + } + Ok((results.0.into(), results.1.into())) +} diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs new file mode 100644 index 000000000000..ecce25d36bb9 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs @@ -0,0 +1,125 @@ +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use super::*; + +/// Only keeps track of membership in right table +pub(super) fn build_table_semi_anti( + keys: Vec, + nulls_equal: bool, +) -> Vec::TotalOrdItem>> +where + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash + IsNull, + I: IntoIterator + Copy + Send + Sync, +{ + let n_partitions = _set_partition_size(); + + // We will create a hashtable in every thread. + // We use the hash to partition the keys to the matching hashtable. + // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. + let par_iter = (0..n_partitions).into_par_iter().map(|partition_no| { + let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); + for keys in &keys { + keys.into_iter().for_each(|k| { + let k = k.to_total_ord(); + if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) + && (!k.is_null() || nulls_equal) + { + hash_tbl.insert(k); + } + }); + } + hash_tbl + }); + POOL.install(|| par_iter.collect()) +} + +/// Construct a ParallelIterator, but doesn't iterate it. This means the caller +/// context (or wherever it gets iterated) should be in POOL.install. +fn semi_anti_impl( + probe: Vec, + build: Vec, + nulls_equal: bool, +) -> impl ParallelIterator +where + I: IntoIterator + Copy + Send + Sync, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash + IsNull, +{ + // first we hash one relation + let hash_sets = build_table_semi_anti(build, nulls_equal); + + // we determine the offset so that we later know which index to store in the join tuples + let offsets = probe_to_offsets(&probe); + + let n_tables = hash_sets.len(); + + // next we probe the other relation + // This is not wrapped in POOL.install because it is not being iterated here + probe + .into_par_iter() + .zip(offsets) + // probes_hashes: Vec processed by this thread + // offset: offset index + .flat_map(move |(probe, offset)| { + // local reference + let hash_sets = &hash_sets; + let probe_iter = probe.into_iter(); + + // assume the result tuples equal length of the no. of hashes processed by this thread. + let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); + + probe_iter.enumerate().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); + let idx_a = (idx_a + offset) as IdxSize; + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) }; + + // we already hashed, so we don't have to hash again. + let value = current_probe_table.get(&k); + + match value { + // left and right matches + Some(_) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), + } + }); + results + }) +} + +pub(super) fn hash_join_tuples_left_anti( + probe: Vec, + build: Vec, + nulls_equal: bool, +) -> Vec +where + I: IntoIterator + Copy + Send + Sync, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash + IsNull, +{ + let par_iter = semi_anti_impl(probe, build, nulls_equal) + .filter(|tpls| !tpls.1) + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) +} + +pub(super) fn hash_join_tuples_left_semi( + probe: Vec, + build: Vec, + nulls_equal: bool, +) -> Vec +where + I: IntoIterator + Copy + Send + Sync, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash + IsNull, +{ + let par_iter = semi_anti_impl(probe, build, nulls_equal) + .filter(|tpls| tpls.1) + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) +} diff --git a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs new file mode 100644 index 000000000000..3f5587975e62 --- /dev/null +++ b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs @@ -0,0 +1,361 @@ +#[cfg(feature = "performant")] +use arrow::legacy::kernels::sorted_join; +#[cfg(feature = "performant")] +use polars_core::utils::_split_offsets; +#[cfg(feature = "performant")] +use polars_core::utils::flatten::flatten_par; + +use super::*; + +#[cfg(feature = "performant")] +fn par_sorted_merge_left_impl( + s_left: &ChunkedArray, + s_right: &ChunkedArray, +) -> (Vec, Vec) +where + T: PolarsNumericType, +{ + let offsets = _split_offsets(s_left.len(), POOL.current_num_threads()); + let s_left = s_left.rechunk(); + let s_right = s_right.rechunk(); + + // we can unwrap because we should not have nulls + let slice_left = s_left.cont_slice().unwrap(); + let slice_right = s_right.cont_slice().unwrap(); + + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::left::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + + let lefts = indexes.iter().map(|t| &t.0).collect::>(); + let rights = indexes.iter().map(|t| &t.1).collect::>(); + + (flatten_par(&lefts), flatten_par(&rights)) +} + +#[cfg(feature = "performant")] +pub(super) fn par_sorted_merge_left( + s_left: &Series, + s_right: &Series, +) -> (Vec, Vec) { + // Don't use bit_repr here. It messes up sortedness. + debug_assert_eq!(s_left.dtype(), s_right.dtype()); + let s_left = s_left.to_physical_repr(); + let s_right = s_right.to_physical_repr(); + + match s_left.dtype() { + #[cfg(feature = "dtype-i8")] + DataType::Int8 => par_sorted_merge_left_impl(s_left.i8().unwrap(), s_right.i8().unwrap()), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => par_sorted_merge_left_impl(s_left.u8().unwrap(), s_right.u8().unwrap()), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => { + par_sorted_merge_left_impl(s_left.u16().unwrap(), s_right.u16().unwrap()) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + par_sorted_merge_left_impl(s_left.i16().unwrap(), s_right.i16().unwrap()) + }, + DataType::UInt32 => { + par_sorted_merge_left_impl(s_left.u32().unwrap(), s_right.u32().unwrap()) + }, + DataType::Int32 => { + par_sorted_merge_left_impl(s_left.i32().unwrap(), s_right.i32().unwrap()) + }, + DataType::UInt64 => { + par_sorted_merge_left_impl(s_left.u64().unwrap(), s_right.u64().unwrap()) + }, + DataType::Int64 => { + par_sorted_merge_left_impl(s_left.i64().unwrap(), s_right.i64().unwrap()) + }, + #[cfg(feature = "dtype-i128")] + DataType::Int128 => { + par_sorted_merge_left_impl(s_left.i128().unwrap(), s_right.i128().unwrap()) + }, + DataType::Float32 => { + par_sorted_merge_left_impl(s_left.f32().unwrap(), s_right.f32().unwrap()) + }, + DataType::Float64 => { + par_sorted_merge_left_impl(s_left.f64().unwrap(), s_right.f64().unwrap()) + }, + dt => panic!("{:?}", dt), + } +} +#[cfg(feature = "performant")] +fn par_sorted_merge_inner_impl( + s_left: &ChunkedArray, + s_right: &ChunkedArray, +) -> (Vec, Vec) +where + T: PolarsNumericType, +{ + let offsets = _split_offsets(s_left.len(), POOL.current_num_threads()); + let s_left = s_left.rechunk(); + let s_right = s_right.rechunk(); + + // we can unwrap because we should not have nulls + let slice_left = s_left.cont_slice().unwrap(); + let slice_right = s_right.cont_slice().unwrap(); + + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + + let lefts = indexes.iter().map(|t| &t.0).collect::>(); + let rights = indexes.iter().map(|t| &t.1).collect::>(); + + (flatten_par(&lefts), flatten_par(&rights)) +} + +#[cfg(feature = "performant")] +pub(super) fn par_sorted_merge_inner_no_nulls( + s_left: &Series, + s_right: &Series, +) -> (Vec, Vec) { + // Don't use bit_repr here. It messes up sortedness. + debug_assert_eq!(s_left.dtype(), s_right.dtype()); + let s_left = s_left.to_physical_repr(); + let s_right = s_right.to_physical_repr(); + + match s_left.dtype() { + #[cfg(feature = "dtype-i8")] + DataType::Int8 => par_sorted_merge_inner_impl(s_left.i8().unwrap(), s_right.i8().unwrap()), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => par_sorted_merge_inner_impl(s_left.u8().unwrap(), s_right.u8().unwrap()), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => { + par_sorted_merge_inner_impl(s_left.u16().unwrap(), s_right.u16().unwrap()) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + par_sorted_merge_inner_impl(s_left.i16().unwrap(), s_right.i16().unwrap()) + }, + DataType::UInt32 => { + par_sorted_merge_inner_impl(s_left.u32().unwrap(), s_right.u32().unwrap()) + }, + DataType::Int32 => { + par_sorted_merge_inner_impl(s_left.i32().unwrap(), s_right.i32().unwrap()) + }, + DataType::UInt64 => { + par_sorted_merge_inner_impl(s_left.u64().unwrap(), s_right.u64().unwrap()) + }, + DataType::Int64 => { + par_sorted_merge_inner_impl(s_left.i64().unwrap(), s_right.i64().unwrap()) + }, + #[cfg(feature = "dtype-i128")] + DataType::Int128 => { + par_sorted_merge_inner_impl(s_left.i128().unwrap(), s_right.i128().unwrap()) + }, + DataType::Float32 => { + par_sorted_merge_inner_impl(s_left.f32().unwrap(), s_right.f32().unwrap()) + }, + DataType::Float64 => { + par_sorted_merge_inner_impl(s_left.f64().unwrap(), s_right.f64().unwrap()) + }, + _ => unreachable!(), + } +} + +pub(crate) fn to_left_join_ids( + left_idx: Vec, + right_idx: Vec, +) -> LeftJoinIds { + #[cfg(feature = "chunked_ids")] + { + (Either::Left(left_idx), Either::Left(right_idx)) + } + + #[cfg(not(feature = "chunked_ids"))] + { + (left_idx, right_idx) + } +} + +#[cfg(feature = "performant")] +fn create_reverse_map_from_arg_sort(mut arg_sort: IdxCa) -> Vec { + let arr = unsafe { arg_sort.chunks_mut() }.pop().unwrap(); + primitive_to_vec::(arr).unwrap() +} + +#[cfg(not(feature = "performant"))] +pub(crate) fn _sort_or_hash_inner( + s_left: &Series, + s_right: &Series, + _verbose: bool, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult<(InnerJoinIds, bool)> { + s_left.hash_join_inner(s_right, validate, nulls_equal) +} + +#[cfg(feature = "performant")] +pub(crate) fn _sort_or_hash_inner( + s_left: &Series, + s_right: &Series, + verbose: bool, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult<(InnerJoinIds, bool)> { + // We check if keys are sorted. + // - If they are we can do a sorted merge join + // If one of the keys is not, it can still be faster to sort that key and use + // the `arg_sort` indices to revert the sort once the join keys are determined. + let size_factor_rhs = s_right.len() as f32 / s_left.len() as f32; + let size_factor_lhs = s_left.len() as f32 / s_right.len() as f32; + let size_factor_acceptable = std::env::var("POLARS_JOIN_SORT_FACTOR") + .map(|s| s.parse::().unwrap()) + .unwrap_or(1.0); + let is_numeric = s_left.dtype().to_physical().is_primitive_numeric(); + + if validate.needs_checks() { + return s_left.hash_join_inner(s_right, validate, nulls_equal); + } + + let no_nulls = s_left.null_count() == 0 && s_right.null_count() == 0; + match (s_left.is_sorted_flag(), s_right.is_sorted_flag(), no_nulls) { + (IsSorted::Ascending, IsSorted::Ascending, true) if is_numeric => { + if verbose { + eprintln!("inner join: keys are sorted: use sorted merge join"); + } + Ok((par_sorted_merge_inner_no_nulls(s_left, s_right), true)) + }, + (IsSorted::Ascending, _, true) + if is_numeric && size_factor_rhs < size_factor_acceptable => + { + if verbose { + eprintln!("right key will be descending sorted in inner join operation.") + } + + let sort_idx = s_right.arg_sort(SortOptions { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }); + let s_right = unsafe { s_right.take_unchecked(&sort_idx) }; + let ids = par_sorted_merge_inner_no_nulls(s_left, &s_right); + let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); + + let (left, mut right) = ids; + + POOL.install(|| { + right.par_iter_mut().for_each(|idx| { + *idx = unsafe { *reverse_idx_map.get_unchecked(*idx as usize) }; + }); + }); + + Ok(((left, right), true)) + }, + (_, IsSorted::Ascending, true) + if is_numeric && size_factor_lhs < size_factor_acceptable => + { + if verbose { + eprintln!("left key will be descending sorted in inner join operation.") + } + + let sort_idx = s_left.arg_sort(SortOptions { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }); + let s_left = unsafe { s_left.take_unchecked(&sort_idx) }; + let ids = par_sorted_merge_inner_no_nulls(&s_left, s_right); + let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); + + let (mut left, right) = ids; + + POOL.install(|| { + left.par_iter_mut().for_each(|idx| { + *idx = unsafe { *reverse_idx_map.get_unchecked(*idx as usize) }; + }); + }); + + // set sorted to `false` as we descending sorted the left key. + Ok(((left, right), false)) + }, + _ => s_left.hash_join_inner(s_right, validate, nulls_equal), + } +} + +#[cfg(not(feature = "performant"))] +pub(crate) fn sort_or_hash_left( + s_left: &Series, + s_right: &Series, + _verbose: bool, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult { + s_left.hash_join_left(s_right, validate, nulls_equal) +} + +#[cfg(feature = "performant")] +pub(crate) fn sort_or_hash_left( + s_left: &Series, + s_right: &Series, + verbose: bool, + validate: JoinValidation, + nulls_equal: bool, +) -> PolarsResult { + if validate.needs_checks() { + return s_left.hash_join_left(s_right, validate, nulls_equal); + } + + let size_factor_rhs = s_right.len() as f32 / s_left.len() as f32; + let size_factor_acceptable = std::env::var("POLARS_JOIN_SORT_FACTOR") + .map(|s| s.parse::().unwrap()) + .unwrap_or(1.0); + let is_numeric = s_left.dtype().to_physical().is_primitive_numeric(); + + let no_nulls = s_left.null_count() == 0 && s_right.null_count() == 0; + + match (s_left.is_sorted_flag(), s_right.is_sorted_flag(), no_nulls) { + (IsSorted::Ascending, IsSorted::Ascending, true) if is_numeric => { + if verbose { + eprintln!("left join: keys are sorted: use sorted merge join"); + } + let (left_idx, right_idx) = par_sorted_merge_left(s_left, s_right); + Ok(to_left_join_ids(left_idx, right_idx)) + }, + (IsSorted::Ascending, _, true) + if is_numeric && size_factor_rhs < size_factor_acceptable => + { + if verbose { + eprintln!("right key will be reverse sorted in left join operation.") + } + + let sort_idx = s_right.arg_sort(SortOptions { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + limit: None, + }); + let s_right = unsafe { s_right.take_unchecked(&sort_idx) }; + + let ids = par_sorted_merge_left(s_left, &s_right); + let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); + let (left, mut right) = ids; + + POOL.install(|| { + right.par_iter_mut().for_each(|opt_idx| { + if !opt_idx.is_null_idx() { + *opt_idx = + unsafe { *reverse_idx_map.get_unchecked(opt_idx.idx() as usize) } + .into(); + } + }); + }); + + Ok(to_left_join_ids(left, right)) + }, + // don't reverse sort a left join key yet. Have to figure out how to set sorted flag + _ => s_left.hash_join_left(s_right, validate, nulls_equal), + } +} diff --git a/crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs b/crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs new file mode 100644 index 000000000000..2c741a797f11 --- /dev/null +++ b/crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs @@ -0,0 +1,49 @@ +use std::cmp::min; + +use arrow::bitmap::MutableBitmap; + +/// Bit array with a filter to speed up searching for set bits when sparse, +/// based on section 4.1 from Khayyat et al. 2015, +/// "Lightning Fast and Space Efficient Inequality Joins" +pub struct FilteredBitArray { + bit_array: MutableBitmap, + filter: MutableBitmap, +} + +impl FilteredBitArray { + const CHUNK_SIZE: usize = 1024; + + pub fn from_len_zeroed(len: usize) -> Self { + Self { + bit_array: MutableBitmap::from_len_zeroed(len), + filter: MutableBitmap::from_len_zeroed(len.div_ceil(Self::CHUNK_SIZE)), + } + } + + pub unsafe fn set_bit_unchecked(&mut self, index: usize) { + self.bit_array.set_unchecked(index, true); + self.filter.set_unchecked(index / Self::CHUNK_SIZE, true); + } + + pub unsafe fn on_set_bits_from(&self, start: usize, mut action: F) + where + F: FnMut(usize), + { + let start_chunk = start / Self::CHUNK_SIZE; + let mut chunk_offset = start % Self::CHUNK_SIZE; + for chunk_idx in start_chunk..self.filter.len() { + if self.filter.get_unchecked(chunk_idx) { + // There are some set bits in this chunk + let start = chunk_idx * Self::CHUNK_SIZE + chunk_offset; + let end = min((chunk_idx + 1) * Self::CHUNK_SIZE, self.bit_array.len()); + for bit_idx in start..end { + // SAFETY: `bit_idx` is always less than `self.bit_array.len()` + if self.bit_array.get_unchecked(bit_idx) { + action(bit_idx); + } + } + } + chunk_offset = 0; + } + } +} diff --git a/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs b/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs new file mode 100644 index 000000000000..89e108eb4673 --- /dev/null +++ b/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs @@ -0,0 +1,263 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use polars_core::chunked_array::ChunkedArray; +use polars_core::datatypes::{IdxCa, PolarsNumericType}; +use polars_core::prelude::Series; +use polars_core::with_match_physical_numeric_polars_type; +use polars_error::PolarsResult; +use polars_utils::IdxSize; +use polars_utils::total_ord::TotalOrd; + +use super::*; + +/// Create a vector of L1 items from the array of LHS x values concatenated with RHS x values +/// and their ordering. +pub(super) fn build_l1_array( + ca: &ChunkedArray, + order: &IdxCa, + right_df_offset: IdxSize, +) -> PolarsResult>> +where + T: PolarsNumericType, +{ + assert_eq!(order.null_count(), 0); + assert_eq!(ca.chunks().len(), 1); + let arr = ca.downcast_get(0).unwrap(); + // Even if there are nulls, they will not be selected by order. + let values = arr.values().as_slice(); + + let mut array: Vec> = Vec::with_capacity(ca.len()); + + for order_arr in order.downcast_iter() { + for index in order_arr.values().as_slice().iter().copied() { + debug_assert!(arr.get(index as usize).is_some()); + let value = unsafe { *values.get_unchecked(index as usize) }; + let row_index = if index < right_df_offset { + // Row from LHS + index as i64 + 1 + } else { + // Row from RHS + -((index - right_df_offset) as i64) - 1 + }; + array.push(L1Item { row_index, value }); + } + } + + Ok(array) +} + +pub(super) fn build_l2_array(s: &Series, order: &[IdxSize]) -> PolarsResult> { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + build_l2_array_impl::<$T>(s.as_ref().as_ref(), order) + }) +} + +/// Create a vector of L2 items from the array of y values ordered according to the L1 order, +/// and their ordering. We don't need to store actual y values but only track whether we're at +/// the end of a run of equal values. +fn build_l2_array_impl(ca: &ChunkedArray, order: &[IdxSize]) -> PolarsResult> +where + T: PolarsNumericType, + T::Native: TotalOrd, +{ + assert_eq!(ca.chunks().len(), 1); + + let mut array = Vec::with_capacity(ca.len()); + let mut prev_index = 0; + let mut prev_value = T::Native::default(); + + let arr = ca.downcast_get(0).unwrap(); + // Even if there are nulls, they will not be selected by order. + let values = arr.values().as_slice(); + + for (i, l1_index) in order.iter().copied().enumerate() { + debug_assert!(arr.get(l1_index as usize).is_some()); + let value = unsafe { *values.get_unchecked(l1_index as usize) }; + if i > 0 { + array.push(L2Item { + l1_index: prev_index, + run_end: value.tot_ne(&prev_value), + }); + } + prev_index = l1_index; + prev_value = value; + } + if !order.is_empty() { + array.push(L2Item { + l1_index: prev_index, + run_end: true, + }); + } + Ok(array) +} + +/// Item in L1 array used in the IEJoin algorithm +#[derive(Clone, Copy, Debug)] +pub(super) struct L1Item { + /// 1 based index for entries from the LHS df, or -1 based index for entries from the RHS + pub(super) row_index: i64, + /// X value + pub(super) value: T, +} + +/// Item in L2 array used in the IEJoin algorithm +#[derive(Clone, Copy, Debug)] +pub(super) struct L2Item { + /// Corresponding index into the L1 array of + pub(super) l1_index: IdxSize, + /// Whether this is the end of a run of equal y values + pub(super) run_end: bool, +} + +pub(super) trait L1Array { + unsafe fn process_entry( + &self, + l1_index: usize, + bit_array: &mut FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64; + + unsafe fn process_lhs_entry( + &self, + l1_index: usize, + bit_array: &FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64; + + unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray); +} + +/// Find the position in the L1 array where we should begin checking for matches, +/// given the index in L1 corresponding to the current position in L2. +unsafe fn find_search_start_index( + l1_array: &[L1Item], + index: usize, + operator: InequalityOperator, +) -> usize +where + T: NumericNative, + T: TotalOrd, +{ + let sub_l1 = l1_array.get_unchecked(index..); + let value = l1_array.get_unchecked(index).value; + + match operator { + InequalityOperator::Gt => { + sub_l1.partition_point_exponential(|a| a.value.tot_ge(&value)) + index + }, + InequalityOperator::Lt => { + sub_l1.partition_point_exponential(|a| a.value.tot_le(&value)) + index + }, + InequalityOperator::GtEq => { + sub_l1.partition_point_exponential(|a| value.tot_lt(&a.value)) + index + }, + InequalityOperator::LtEq => { + sub_l1.partition_point_exponential(|a| value.tot_gt(&a.value)) + index + }, + } +} + +fn find_matches_in_l1( + l1_array: &[L1Item], + l1_index: usize, + row_index: i64, + bit_array: &FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, +) -> i64 +where + T: NumericNative, + T: TotalOrd, +{ + debug_assert!(row_index > 0); + let mut match_count = 0; + + // This entry comes from the left hand side DataFrame. + // Find all following entries in L1 (meaning they satisfy the first operator) + // that have already been visited (so satisfy the second operator). + // Because we use a stable sort for l2, we know that we won't find any + // matches for duplicate y values when traversing forwards in l1. + let start_index = unsafe { find_search_start_index(l1_array, l1_index, op1) }; + unsafe { + bit_array.on_set_bits_from(start_index, |set_bit: usize| { + // SAFETY + // set bit is within bounds. + let right_row_index = l1_array.get_unchecked(set_bit).row_index; + debug_assert!(right_row_index < 0); + left_row_ids.push((row_index - 1) as IdxSize); + right_row_ids.push((-right_row_index) as IdxSize - 1); + match_count += 1; + }) + }; + + match_count +} + +impl L1Array for Vec> +where + T: NumericNative, +{ + unsafe fn process_entry( + &self, + l1_index: usize, + bit_array: &mut FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64 { + let row_index = self.get_unchecked(l1_index).row_index; + let from_lhs = row_index > 0; + if from_lhs { + find_matches_in_l1( + self, + l1_index, + row_index, + bit_array, + op1, + left_row_ids, + right_row_ids, + ) + } else { + bit_array.set_bit_unchecked(l1_index); + 0 + } + } + + unsafe fn process_lhs_entry( + &self, + l1_index: usize, + bit_array: &FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64 { + let row_index = self.get_unchecked(l1_index).row_index; + let from_lhs = row_index > 0; + if from_lhs { + find_matches_in_l1( + self, + l1_index, + row_index, + bit_array, + op1, + left_row_ids, + right_row_ids, + ) + } else { + 0 + } + } + + unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray) { + let from_lhs = self.get_unchecked(index).row_index > 0; + // We only mark RHS entries as visited, + // so that we don't try to match LHS entries with other LHS entries. + if !from_lhs { + bit_array.set_bit_unchecked(index); + } + } +} diff --git a/crates/polars-ops/src/frame/join/iejoin/mod.rs b/crates/polars-ops/src/frame/join/iejoin/mod.rs new file mode 100644 index 000000000000..70c667651cfb --- /dev/null +++ b/crates/polars-ops/src/frame/join/iejoin/mod.rs @@ -0,0 +1,615 @@ +#![allow(unsafe_op_in_unsafe_fn)] +mod filtered_bit_array; +mod l1_l2; + +use std::cmp::min; + +use filtered_bit_array::FilteredBitArray; +use l1_l2::*; +use polars_core::chunked_array::ChunkedArray; +use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType}; +use polars_core::frame::DataFrame; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::{_set_partition_size, split}; +use polars_core::{POOL, with_match_physical_numeric_polars_type}; +use polars_error::{PolarsResult, polars_err}; +use polars_utils::IdxSize; +use polars_utils::binary_search::ExponentialSearch; +use polars_utils::itertools::Itertools; +use polars_utils::total_ord::{TotalEq, TotalOrd}; +use rayon::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::frame::_finish_join; + +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum InequalityOperator { + #[default] + Lt, + LtEq, + Gt, + GtEq, +} + +impl InequalityOperator { + fn is_strict(&self) -> bool { + matches!(self, InequalityOperator::Gt | InequalityOperator::Lt) + } +} +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct IEJoinOptions { + pub operator1: InequalityOperator, + pub operator2: Option, +} + +#[allow(clippy::too_many_arguments)] +fn ie_join_impl_t( + slice: Option<(i64, usize)>, + l1_order: IdxCa, + l2_order: &[IdxSize], + op1: InequalityOperator, + op2: InequalityOperator, + x: Series, + y_ordered_by_x: Series, + left_height: usize, +) -> PolarsResult<(Vec, Vec)> { + // Create a bit array with order corresponding to L1, + // denoting which entries have been visited while traversing L2. + let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len()); + + let mut left_row_idx: Vec = vec![]; + let mut right_row_idx: Vec = vec![]; + + let slice_end = slice_end_index(slice); + let mut match_count = 0; + + let ca: &ChunkedArray = x.as_ref().as_ref(); + let l1_array = build_l1_array(ca, &l1_order, left_height as IdxSize)?; + + if op2.is_strict() { + // For strict inequalities, we rely on using a stable sort of l2 so that + // p values only increase as we traverse a run of equal y values. + // To handle inclusive comparisons in x and duplicate x values we also need the + // sort of l1 to be stable, so that the left hand side entries come before the right + // hand side entries (as we mark visited entries from the right hand side). + for &p in l2_order { + match_count += unsafe { + l1_array.process_entry( + p as usize, + &mut bit_array, + op1, + &mut left_row_idx, + &mut right_row_idx, + ) + }; + + if slice_end.is_some_and(|end| match_count >= end) { + break; + } + } + } else { + let l2_array = build_l2_array(&y_ordered_by_x, l2_order)?; + + // For non-strict inequalities in l2, we need to track runs of equal y values and only + // check for matches after we reach the end of the run and have marked all rhs entries + // in the run as visited. + let mut run_start = 0; + + for i in 0..l2_array.len() { + // Elide bound checks + unsafe { + let item = l2_array.get_unchecked(i); + let p = item.l1_index; + l1_array.mark_visited(p as usize, &mut bit_array); + + if item.run_end { + for l2_item in l2_array.get_unchecked(run_start..i + 1) { + let p = l2_item.l1_index; + match_count += l1_array.process_lhs_entry( + p as usize, + &bit_array, + op1, + &mut left_row_idx, + &mut right_row_idx, + ); + } + + run_start = i + 1; + + if slice_end.is_some_and(|end| match_count >= end) { + break; + } + } + } + } + } + Ok((left_row_idx, right_row_idx)) +} + +fn piecewise_merge_join_impl_t( + slice: Option<(i64, usize)>, + left_order: Option<&[IdxSize]>, + right_order: Option<&[IdxSize]>, + left_ordered: Series, + right_ordered: Series, + mut pred: P, +) -> PolarsResult<(Vec, Vec)> +where + T: PolarsNumericType, + P: FnMut(&T::Native, &T::Native) -> bool, +{ + let slice_end = slice_end_index(slice); + + let mut left_row_idx: Vec = vec![]; + let mut right_row_idx: Vec = vec![]; + + let left_ca: &ChunkedArray = left_ordered.as_ref().as_ref(); + let right_ca: &ChunkedArray = right_ordered.as_ref().as_ref(); + + debug_assert!(left_order.is_none_or(|order| order.len() == left_ca.len())); + debug_assert!(right_order.is_none_or(|order| order.len() == right_ca.len())); + + let mut left_idx = 0; + let mut right_idx = 0; + let mut match_count = 0; + + while left_idx < left_ca.len() { + debug_assert!(left_ca.get(left_idx).is_some()); + let left_val = unsafe { left_ca.value_unchecked(left_idx) }; + while right_idx < right_ca.len() { + debug_assert!(right_ca.get(right_idx).is_some()); + let right_val = unsafe { right_ca.value_unchecked(right_idx) }; + if pred(&left_val, &right_val) { + // If the predicate is true, then it will also be true for all + // remaining rows from the right side. + let left_row = match left_order { + None => left_idx as IdxSize, + Some(order) => order[left_idx], + }; + let right_end_idx = match slice_end { + None => right_ca.len(), + Some(end) => min(right_ca.len(), (end as usize) - match_count + right_idx), + }; + for included_right_row_idx in right_idx..right_end_idx { + let right_row = match right_order { + None => included_right_row_idx as IdxSize, + Some(order) => order[included_right_row_idx], + }; + left_row_idx.push(left_row); + right_row_idx.push(right_row); + } + match_count += right_end_idx - right_idx; + break; + } else { + right_idx += 1; + } + } + if right_idx == right_ca.len() { + // We've reached the end of the right side + // so there can be no more matches for LHS rows + break; + } + if slice_end.is_some_and(|end| match_count >= end as usize) { + break; + } + left_idx += 1; + } + + Ok((left_row_idx, right_row_idx)) +} + +pub(super) fn iejoin_par( + left: &DataFrame, + right: &DataFrame, + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + suffix: Option, + slice: Option<(i64, usize)>, +) -> PolarsResult { + let l1_descending = matches!( + options.operator1, + InequalityOperator::Gt | InequalityOperator::GtEq + ); + + let l1_sort_options = SortOptions::default() + .with_maintain_order(true) + .with_nulls_last(false) + .with_order_descending(l1_descending); + + let sl = &selected_left[0]; + let l1_s_l = sl + .arg_sort(l1_sort_options) + .slice(sl.null_count() as i64, sl.len() - sl.null_count()); + + let sr = &selected_right[0]; + let l1_s_r = sr + .arg_sort(l1_sort_options) + .slice(sr.null_count() as i64, sr.len() - sr.null_count()); + + // Because we do a cartesian product, the number of partitions is squared. + // We take the sqrt, but we don't expect every partition to produce results and work can be + // imbalanced, so we multiply the number of partitions by 2, which leads to 2^2= 4 + let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2; + let splitted_a = split(&l1_s_l, n_partitions); + let splitted_b = split(&l1_s_r, n_partitions); + + let cartesian_prod = splitted_a + .iter() + .flat_map(|l| splitted_b.iter().map(move |r| (l, r))) + .collect::>(); + + let iter = cartesian_prod.par_iter().map(|(l_l1_idx, r_l1_idx)| { + if l_l1_idx.is_empty() || r_l1_idx.is_empty() { + return Ok(None); + } + fn get_extrema<'a>( + l1_idx: &'a IdxCa, + s: &'a Series, + ) -> Option<(AnyValue<'a>, AnyValue<'a>)> { + let first = l1_idx.first()?; + let last = l1_idx.last()?; + + let start = s.get(first as usize).unwrap(); + let end = s.get(last as usize).unwrap(); + + Some(if start < end { + (start, end) + } else { + (end, start) + }) + } + let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else { + return Ok(None); + }; + let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else { + return Ok(None); + }; + + let include_block = match options.operator1 { + InequalityOperator::Lt => min_l < max_r, + InequalityOperator::LtEq => min_l <= max_r, + InequalityOperator::Gt => max_l > min_r, + InequalityOperator::GtEq => max_l >= min_r, + }; + + if include_block { + let (mut l, mut r) = unsafe { + ( + selected_left + .iter() + .map(|s| s.take_unchecked(l_l1_idx)) + .collect_vec(), + selected_right + .iter() + .map(|s| s.take_unchecked(r_l1_idx)) + .collect_vec(), + ) + }; + let sorted_flag = if l1_descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + // We sorted using the first series + l[0].set_sorted_flag(sorted_flag); + r[0].set_sorted_flag(sorted_flag); + + // Compute the row indexes + let (idx_l, idx_r) = if options.operator2.is_some() { + iejoin_tuples(l, r, options, None) + } else { + piecewise_merge_join_tuples(l, r, options, None) + }?; + + if idx_l.is_empty() { + return Ok(None); + } + + // These are row indexes in the slices we have given, so we use those to gather in the + // original l1 offset arrays. This gives us indexes in the original tables. + unsafe { + Ok(Some(( + l_l1_idx.take_unchecked(&idx_l), + r_l1_idx.take_unchecked(&idx_r), + ))) + } + } else { + Ok(None) + } + }); + + let row_indices = POOL.install(|| iter.collect::>>())?; + + let mut left_idx = IdxCa::default(); + let mut right_idx = IdxCa::default(); + for (l, r) in row_indices.into_iter().flatten() { + left_idx.append(&l)?; + right_idx.append(&r)?; + } + if let Some((offset, end)) = slice { + left_idx = left_idx.slice(offset, end); + right_idx = right_idx.slice(offset, end); + } + + unsafe { materialize_join(left, right, &left_idx, &right_idx, suffix) } +} + +pub(super) fn iejoin( + left: &DataFrame, + right: &DataFrame, + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + suffix: Option, + slice: Option<(i64, usize)>, +) -> PolarsResult { + let (left_row_idx, right_row_idx) = if options.operator2.is_some() { + iejoin_tuples(selected_left, selected_right, options, slice) + } else { + piecewise_merge_join_tuples(selected_left, selected_right, options, slice) + }?; + unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) } +} + +unsafe fn materialize_join( + left: &DataFrame, + right: &DataFrame, + left_row_idx: &IdxCa, + right_row_idx: &IdxCa, + suffix: Option, +) -> PolarsResult { + try_raise_keyboard_interrupt(); + let (join_left, join_right) = { + POOL.join( + || left.take_unchecked(left_row_idx), + || right.take_unchecked(right_row_idx), + ) + }; + + _finish_join(join_left, join_right, suffix) +} + +/// Inequality join. Matches rows between two DataFrames using two inequality operators +/// (one of [<, <=, >, >=]). +/// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins" +/// and extended to work with duplicate values. +fn iejoin_tuples( + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + slice: Option<(i64, usize)>, +) -> PolarsResult<(IdxCa, IdxCa)> { + if selected_left.len() != 2 { + return Err( + polars_err!(ComputeError: "IEJoin requires exactly two expressions from the left DataFrame"), + ); + }; + if selected_right.len() != 2 { + return Err( + polars_err!(ComputeError: "IEJoin requires exactly two expressions from the right DataFrame"), + ); + }; + + let op1 = options.operator1; + let op2 = match options.operator2 { + None => { + return Err(polars_err!(ComputeError: "IEJoin requires two inequality operators")); + }, + Some(op2) => op2, + }; + + // Determine the sort order based on the comparison operators used. + // We want to sort L1 so that "x[i] op1 x[j]" is true for j > i, + // and L2 so that "y[i] op2 y[j]" is true for j < i + // (except in the case of duplicates and strict inequalities). + // Note that the algorithms published in Khayyat et al. have incorrect logic for + // determining whether to sort descending. + let l1_descending = matches!(op1, InequalityOperator::Gt | InequalityOperator::GtEq); + let l2_descending = matches!(op2, InequalityOperator::Lt | InequalityOperator::LtEq); + + let mut x = selected_left[0].to_physical_repr().into_owned(); + let left_height = x.len(); + + x.extend(&selected_right[0].to_physical_repr())?; + // Rechunk because we will gather. + let x = x.rechunk(); + + let mut y = selected_left[1].to_physical_repr().into_owned(); + y.extend(&selected_right[1].to_physical_repr())?; + // Rechunk because we will gather. + let y = y.rechunk(); + + let l1_sort_options = SortOptions::default() + .with_maintain_order(true) + .with_nulls_last(false) + .with_order_descending(l1_descending); + // Get ordering of x, skipping any null entries as these cannot be matches + let l1_order = x + .arg_sort(l1_sort_options) + .slice(x.null_count() as i64, x.len() - x.null_count()); + + let y_ordered_by_x = unsafe { y.take_unchecked(&l1_order) }; + let l2_sort_options = SortOptions::default() + .with_maintain_order(true) + .with_nulls_last(false) + .with_order_descending(l2_descending); + // Get the indexes into l1, ordered by y values. + // l2_order is the same as "p" from Khayyat et al. + let l2_order = y_ordered_by_x.arg_sort(l2_sort_options).slice( + y_ordered_by_x.null_count() as i64, + y_ordered_by_x.len() - y_ordered_by_x.null_count(), + ); + let l2_order = l2_order.rechunk(); + let l2_order = l2_order.downcast_as_array().values().as_slice(); + + let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(x.dtype(), |$T| { + ie_join_impl_t::<$T>( + slice, + l1_order, + l2_order, + op1, + op2, + x, + y_ordered_by_x, + left_height + ) + })?; + + debug_assert_eq!(left_row_idx.len(), right_row_idx.len()); + let left_row_idx = IdxCa::from_vec("".into(), left_row_idx); + let right_row_idx = IdxCa::from_vec("".into(), right_row_idx); + let (left_row_idx, right_row_idx) = match slice { + None => (left_row_idx, right_row_idx), + Some((offset, len)) => ( + left_row_idx.slice(offset, len), + right_row_idx.slice(offset, len), + ), + }; + Ok((left_row_idx, right_row_idx)) +} + +/// Piecewise merge join, for joins with only a single inequality. +fn piecewise_merge_join_tuples( + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + slice: Option<(i64, usize)>, +) -> PolarsResult<(IdxCa, IdxCa)> { + if selected_left.len() != 1 { + return Err( + polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the left DataFrame"), + ); + }; + if selected_right.len() != 1 { + return Err( + polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the right DataFrame"), + ); + }; + if options.operator2.is_some() { + return Err( + polars_err!(ComputeError: "Piecewise merge join expects only one inequality operator"), + ); + } + + let op = options.operator1; + // The left side is sorted such that if the condition is false, it will also + // be false for the same RHS row and all following LHS rows. + // The right side is sorted such that if the condition is true then it is also + // true for the same LHS row and all following RHS rows. + // The desired sort order should match the l1 order used in iejoin_par + // so we don't need to re-sort slices when doing a parallel join. + let descending = matches!(op, InequalityOperator::Gt | InequalityOperator::GtEq); + + let left = selected_left[0].to_physical_repr().into_owned(); + let mut right = selected_right[0].to_physical_repr().into_owned(); + let must_cast = right.dtype().matches_schema_type(left.dtype())?; + if must_cast { + right = right.cast(left.dtype())?; + } + + fn get_sorted(series: Series, descending: bool) -> (Series, Option) { + let expected_flag = if descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + if (series.is_sorted_flag() == expected_flag || series.len() <= 1) && !series.has_nulls() { + // Fast path, no need to re-sort + (series, None) + } else { + let sort_options = SortOptions::default() + .with_nulls_last(false) + .with_order_descending(descending); + + // Get order and slice to ignore any null values, which cannot be match results + let mut order = series.arg_sort(sort_options).slice( + series.null_count() as i64, + series.len() - series.null_count(), + ); + order.rechunk_mut(); + let ordered = unsafe { series.take_unchecked(&order) }; + (ordered, Some(order)) + } + } + + let (left_ordered, left_order) = get_sorted(left, descending); + debug_assert!( + left_order + .as_ref() + .is_none_or(|order| order.chunks().len() == 1) + ); + let left_order = left_order + .as_ref() + .map(|order| order.downcast_get(0).unwrap().values().as_slice()); + + let (right_ordered, right_order) = get_sorted(right, descending); + debug_assert!( + right_order + .as_ref() + .is_none_or(|order| order.chunks().len() == 1) + ); + let right_order = right_order + .as_ref() + .map(|order| order.downcast_get(0).unwrap().values().as_slice()); + + let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(left_ordered.dtype(), |$T| { + match op { + InequalityOperator::Lt => piecewise_merge_join_impl_t::<$T, _>( + slice, + left_order, + right_order, + left_ordered, + right_ordered, + |l, r| l.tot_lt(r), + ), + InequalityOperator::LtEq => piecewise_merge_join_impl_t::<$T, _>( + slice, + left_order, + right_order, + left_ordered, + right_ordered, + |l, r| l.tot_le(r), + ), + InequalityOperator::Gt => piecewise_merge_join_impl_t::<$T, _>( + slice, + left_order, + right_order, + left_ordered, + right_ordered, + |l, r| l.tot_gt(r), + ), + InequalityOperator::GtEq => piecewise_merge_join_impl_t::<$T, _>( + slice, + left_order, + right_order, + left_ordered, + right_ordered, + |l, r| l.tot_ge(r), + ), + } + })?; + + debug_assert_eq!(left_row_idx.len(), right_row_idx.len()); + let left_row_idx = IdxCa::from_vec("".into(), left_row_idx); + let right_row_idx = IdxCa::from_vec("".into(), right_row_idx); + let (left_row_idx, right_row_idx) = match slice { + None => (left_row_idx, right_row_idx), + Some((offset, len)) => ( + left_row_idx.slice(offset, len), + right_row_idx.slice(offset, len), + ), + }; + Ok((left_row_idx, right_row_idx)) +} + +fn slice_end_index(slice: Option<(i64, usize)>) -> Option { + match slice { + Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)), + _ => None, + } +} diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs new file mode 100644 index 000000000000..23ec1b5392c1 --- /dev/null +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -0,0 +1,333 @@ +use arrow::legacy::utils::{CustomIterTools, FromTrustedLenIterator}; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +fn check_and_union_revmaps( + lhs_revmap: &Option>, + rhs_revmap: &Option>, +) -> PolarsResult>> { + // Ensure we are operating on either identical locals, or compatible globals. + let lhs_revmap = lhs_revmap.as_ref().unwrap(); + let rhs_revmap = rhs_revmap.as_ref().unwrap(); + match (&**lhs_revmap, &**rhs_revmap) { + (RevMapping::Local(_, l_hash), RevMapping::Local(_, r_hash)) => { + // Same local categoricals, we return immediately + polars_ensure!(l_hash == r_hash, ComputeError: "cannot merge-sort incompatible categoricals"); + Ok(None) + }, + // Return revmap that is the union of the two revmaps. + (RevMapping::Global(_, _, l), RevMapping::Global(_, _, r)) => { + polars_ensure!(l == r, ComputeError: "cannot merge-sort incompatible categoricals"); + let mut rev_map_merger = GlobalRevMapMerger::new(lhs_revmap.clone()); + rev_map_merger.merge_map(rhs_revmap)?; + let new_map = rev_map_merger.finish(); + Ok(Some(new_map)) + }, + _ => unreachable!(), + } +} + +pub fn _merge_sorted_dfs( + left: &DataFrame, + right: &DataFrame, + left_s: &Series, + right_s: &Series, + check_schema: bool, +) -> PolarsResult { + if check_schema { + left.schema_equal(right)?; + } + let dtype_lhs = left_s.dtype(); + let dtype_rhs = right_s.dtype(); + + polars_ensure!( + dtype_lhs == dtype_rhs, + ComputeError: "merge-sort datatype mismatch: {} != {}", dtype_lhs, dtype_rhs + ); + + if dtype_lhs.is_categorical() { + let rev_map_lhs = left_s.categorical().unwrap().get_rev_map(); + let rev_map_rhs = right_s.categorical().unwrap().get_rev_map(); + polars_ensure!( + rev_map_lhs.same_src(rev_map_rhs), + ComputeError: "can only merge-sort categoricals with the same categories" + ); + } + + // If one frame is empty, we can return the other immediately. + if right_s.is_empty() { + return Ok(left.clone()); + } else if left_s.is_empty() { + return Ok(right.clone()); + } + + let merge_indicator = series_to_merge_indicator(left_s, right_s)?; + let new_columns = left + .get_columns() + .iter() + .zip(right.get_columns()) + .map(|(lhs, rhs)| { + let lhs_phys = lhs.to_physical_repr(); + let rhs_phys = rhs.to_physical_repr(); + + let out = Column::from(merge_series( + lhs_phys.as_materialized_series(), + rhs_phys.as_materialized_series(), + &merge_indicator, + )?); + + let lhs_dt = lhs.dtype(); + let dtype_out = match (lhs_dt, rhs.dtype()) { + // Global categorical revmaps must be merged for the output. + (DataType::Categorical(lhs_revmap, ord), DataType::Categorical(rhs_revmap, _)) => { + if let Some(new_revmap) = check_and_union_revmaps(lhs_revmap, rhs_revmap)? { + &DataType::Categorical(Some(new_revmap), *ord) + } else { + lhs_dt + } + }, + _ => lhs_dt, + }; + + let mut out = unsafe { out.from_physical_unchecked(dtype_out) }.unwrap(); + out.rename(lhs.name().clone()); + Ok(out) + }) + .collect::>()?; + + Ok(unsafe { DataFrame::new_no_checks(left.height() + right.height(), new_columns) }) +} + +fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> PolarsResult { + use DataType::*; + let out = match lhs.dtype() { + Boolean => { + let lhs = lhs.bool().unwrap(); + let rhs = rhs.bool().unwrap(); + + merge_ca(lhs, rhs, merge_indicator).into_series() + }, + String => { + // dispatch via binary + let lhs = lhs.str().unwrap().as_binary(); + let rhs = rhs.str().unwrap().as_binary(); + let out = merge_ca(&lhs, &rhs, merge_indicator); + unsafe { out.to_string_unchecked() }.into_series() + }, + Binary => { + let lhs = lhs.binary().unwrap(); + let rhs = rhs.binary().unwrap(); + merge_ca(lhs, rhs, merge_indicator).into_series() + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let lhs = lhs.struct_().unwrap(); + let rhs = rhs.struct_().unwrap(); + polars_ensure!(lhs.null_count() + rhs.null_count() == 0, InvalidOperation: "merge sorted with structs with outer nulls not yet supported"); + + let new_fields = lhs + .fields_as_series() + .iter() + .zip(rhs.fields_as_series()) + .map(|(lhs, rhs)| { + merge_series(lhs, &rhs, merge_indicator) + .map(|merged| merged.with_name(lhs.name().clone())) + }) + .collect::>>()?; + StructChunked::from_series(PlSmallStr::EMPTY, new_fields[0].len(), new_fields.iter()) + .unwrap() + .into_series() + }, + List(_) => { + let lhs = lhs.list().unwrap(); + let rhs = rhs.list().unwrap(); + merge_ca(lhs, rhs, merge_indicator).into_series() + }, + dt => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + merge_ca(lhs, rhs, merge_indicator).into_series() + }) + }, + }; + Ok(out) +} + +fn merge_ca<'a, T>( + a: &'a ChunkedArray, + b: &'a ChunkedArray, + merge_indicator: &[bool], +) -> ChunkedArray +where + T: PolarsDataType + 'static, + &'a ChunkedArray: IntoIterator, + ChunkedArray: + FromTrustedLenIterator<<<&'a ChunkedArray as IntoIterator>::IntoIter as Iterator>::Item>, +{ + let total_len = a.len() + b.len(); + let mut a = a.into_iter(); + let mut b = b.into_iter(); + + let iter = merge_indicator.iter().map(|a_indicator| { + if *a_indicator { + a.next().unwrap() + } else { + b.next().unwrap() + } + }); + + // SAFETY: length is correct + unsafe { iter.trust_my_length(total_len).collect_trusted() } +} + +fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> PolarsResult> { + if lhs.dtype().is_categorical() { + let lhs_ca = lhs.categorical().unwrap(); + if lhs_ca.uses_lexical_ordering() { + let rhs_ca = rhs.categorical().unwrap(); + let out = get_merge_indicator(lhs_ca.iter_str(), rhs_ca.iter_str()); + return Ok(out); + } + } + + let lhs_s = lhs.to_physical_repr().into_owned(); + let rhs_s = rhs.to_physical_repr().into_owned(); + + let out = match lhs_s.dtype() { + DataType::Boolean => { + let lhs = lhs_s.bool().unwrap(); + let rhs = rhs_s.bool().unwrap(); + get_merge_indicator(lhs.into_iter(), rhs.into_iter()) + }, + DataType::String => { + let lhs = lhs.str().unwrap().as_binary(); + let rhs = rhs.str().unwrap().as_binary(); + get_merge_indicator(lhs.into_iter(), rhs.into_iter()) + }, + DataType::Binary => { + let lhs = lhs_s.binary().unwrap(); + let rhs = rhs_s.binary().unwrap(); + get_merge_indicator(lhs.into_iter(), rhs.into_iter()) + }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => { + let options = SortOptions::default(); + let lhs = lhs_s.struct_().unwrap().get_row_encoded(options)?; + let rhs = rhs_s.struct_().unwrap().get_row_encoded(options)?; + get_merge_indicator(lhs.into_iter(), rhs.into_iter()) + }, + _ => { + with_match_physical_numeric_polars_type!(lhs_s.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs_s.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs_s.as_ref().as_ref().as_ref(); + + get_merge_indicator(lhs.into_iter(), rhs.into_iter()) + + }) + }, + }; + Ok(out) +} + +// get a boolean values, left: true, right: false +// that indicate from which side we should take a value +fn get_merge_indicator( + mut a_iter: impl ExactSizeIterator, + mut b_iter: impl ExactSizeIterator, +) -> Vec +where + T: PartialOrd + Default + Copy, +{ + const A_INDICATOR: bool = true; + const B_INDICATOR: bool = false; + + let a_len = a_iter.size_hint().0; + let b_len = b_iter.size_hint().0; + if a_len == 0 { + return vec![true; b_len]; + }; + if b_len == 0 { + return vec![false; a_len]; + } + + let mut current_a = T::default(); + let cap = a_len + b_len; + let mut out = Vec::with_capacity(cap); + + let mut current_b = b_iter.next().unwrap(); + + for a in &mut a_iter { + current_a = a; + if a <= current_b { + out.push(A_INDICATOR); + continue; + } + out.push(B_INDICATOR); + + loop { + if let Some(b) = b_iter.next() { + current_b = b; + if b >= a { + out.push(A_INDICATOR); + break; + } + out.push(B_INDICATOR); + continue; + } + // b is depleted fill with a indicator + let remaining = cap - out.len(); + out.extend(std::iter::repeat_n(A_INDICATOR, remaining)); + return out; + } + } + if current_a < current_b { + out.push(B_INDICATOR); + } + // check if current value already is added + if *out.last().unwrap() == A_INDICATOR { + out.push(B_INDICATOR); + } + // take remaining + out.extend(b_iter.map(|_| B_INDICATOR)); + assert_eq!(out.len(), b_len + a_len); + + out +} + +#[test] +fn test_merge_sorted() { + fn get_merge_indicator_sliced(a: &[T], b: &[T]) -> Vec { + get_merge_indicator(a.iter().copied(), b.iter().copied()) + } + + let a = [1, 2, 4, 6, 9]; + let b = [2, 3, 4, 5, 10]; + + let out = get_merge_indicator_sliced(&a, &b); + let expected = [ + true, true, false, false, true, false, false, true, true, false, + ]; + // 1 2 2 3 4 4 5 6 9 10 + assert_eq!(out, expected); + + // swap + // it is not the inverse because left is preferred when both are equal + let out = get_merge_indicator_sliced(&b, &a); + let expected = [ + false, true, false, true, true, false, true, false, false, true, + ]; + assert_eq!(out, expected); + + let a = [5, 6, 7, 10]; + let b = [1, 2, 5]; + let out = get_merge_indicator_sliced(&a, &b); + let expected = [false, false, true, false, true, true, true]; + assert_eq!(out, expected); + + // swap + // it is not the inverse because left is preferred when both are equal + let out = get_merge_indicator_sliced(&b, &a); + let expected = [true, true, true, false, false, false, false]; + assert_eq!(out, expected); +} diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs new file mode 100644 index 000000000000..c64501760ce4 --- /dev/null +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -0,0 +1,677 @@ +mod args; +#[cfg(feature = "asof_join")] +mod asof; +#[cfg(feature = "dtype-categorical")] +mod checks; +mod cross_join; +mod dispatch_left_right; +mod general; +mod hash_join; +#[cfg(feature = "iejoin")] +mod iejoin; +#[cfg(feature = "merge_sorted")] +mod merge_sorted; + +use std::borrow::Cow; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; + +pub use args::*; +use arrow::trusted_len::TrustedLen; +#[cfg(feature = "asof_join")] +pub use asof::{AsOfOptions, AsofJoin, AsofJoinBy, AsofStrategy}; +#[cfg(feature = "dtype-categorical")] +pub(crate) use checks::*; +pub use cross_join::CrossJoin; +#[cfg(feature = "chunked_ids")] +use either::Either; +#[cfg(feature = "chunked_ids")] +use general::create_chunked_index_mapping; +pub use general::{_coalesce_full_join, _finish_join, _join_suffix_name}; +pub use hash_join::*; +use hashbrown::hash_map::{Entry, RawEntryMut}; +#[cfg(feature = "iejoin")] +pub use iejoin::{IEJoinOptions, InequalityOperator}; +#[cfg(feature = "merge_sorted")] +pub use merge_sorted::_merge_sorted_dfs; +use polars_core::POOL; +#[allow(unused_imports)] +use polars_core::chunked_array::ops::row_encode::{ + encode_rows_vertical_par_unordered, encode_rows_vertical_par_unordered_broadcast_nulls, +}; +use polars_core::hashing::_HASHMAP_INIT_SIZE; +use polars_core::prelude::*; +pub(super) use polars_core::series::IsSorted; +use polars_core::utils::slice_offsets; +#[allow(unused_imports)] +use polars_core::utils::slice_slice; +use polars_utils::hashing::BytesHash; +use rayon::prelude::*; + +use self::cross_join::fused_cross_filter; +use super::IntoDf; + +pub trait DataFrameJoinOps: IntoDf { + /// Generic join method. Can be used to join on multiple columns. + /// + /// # Example + /// + /// ```no_run + /// # use polars_core::prelude::*; + /// # use polars_ops::prelude::*; + /// let df1: DataFrame = df!("Fruit" => &["Apple", "Banana", "Pear"], + /// "Phosphorus (mg/100g)" => &[11, 22, 12])?; + /// let df2: DataFrame = df!("Name" => &["Apple", "Banana", "Pear"], + /// "Potassium (mg/100g)" => &[107, 358, 115])?; + /// + /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinArgs::new(JoinType::Inner), + /// None)?; + /// assert_eq!(df3.shape(), (3, 3)); + /// println!("{}", df3); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (3, 3) + /// +--------+----------------------+---------------------+ + /// | Fruit | Phosphorus (mg/100g) | Potassium (mg/100g) | + /// | --- | --- | --- | + /// | str | i32 | i32 | + /// +========+======================+=====================+ + /// | Apple | 11 | 107 | + /// +--------+----------------------+---------------------+ + /// | Banana | 22 | 358 | + /// +--------+----------------------+---------------------+ + /// | Pear | 12 | 115 | + /// +--------+----------------------+---------------------+ + /// ``` + fn join( + &self, + other: &DataFrame, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + args: JoinArgs, + options: Option, + ) -> PolarsResult { + let df_left = self.to_df(); + let selected_left = df_left.select_columns(left_on)?; + let selected_right = other.select_columns(right_on)?; + + let selected_left = selected_left + .into_iter() + .map(Column::take_materialized_series) + .collect::>(); + let selected_right = selected_right + .into_iter() + .map(Column::take_materialized_series) + .collect::>(); + + self._join_impl( + other, + selected_left, + selected_right, + args, + options, + true, + false, + ) + } + + #[doc(hidden)] + #[allow(clippy::too_many_arguments)] + #[allow(unused_mut)] + fn _join_impl( + &self, + other: &DataFrame, + mut selected_left: Vec, + mut selected_right: Vec, + mut args: JoinArgs, + options: Option, + _check_rechunk: bool, + _verbose: bool, + ) -> PolarsResult { + let left_df = self.to_df(); + + #[cfg(feature = "cross_join")] + if let JoinType::Cross = args.how { + if let Some(JoinTypeOptions::Cross(cross_options)) = &options { + assert!(args.slice.is_none()); + return fused_cross_filter(left_df, other, args.suffix.clone(), cross_options); + } + return left_df.cross_join(other, args.suffix.clone(), args.slice); + } + + // Clear literals if a frame is empty. Otherwise we could get an oob + fn clear(s: &mut [Series]) { + for s in s.iter_mut() { + if s.len() == 1 { + *s = s.clear() + } + } + } + if left_df.is_empty() { + clear(&mut selected_left); + } + if other.is_empty() { + clear(&mut selected_right); + } + + let should_coalesce = args.should_coalesce(); + assert_eq!(selected_left.len(), selected_right.len()); + + #[cfg(feature = "chunked_ids")] + { + // a left join create chunked-ids + // the others not yet. + // TODO! change this to other join types once they support chunked-id joins + if _check_rechunk + && !(matches!(args.how, JoinType::Left) + || std::env::var("POLARS_NO_CHUNKED_JOIN").is_ok()) + { + let mut left = Cow::Borrowed(left_df); + let mut right = Cow::Borrowed(other); + if left_df.should_rechunk() { + if _verbose { + eprintln!( + "{:?} join triggered a rechunk of the left DataFrame: {} columns are affected", + args.how, + left_df.width() + ); + } + + let mut tmp_left = left_df.clone(); + tmp_left.as_single_chunk_par(); + left = Cow::Owned(tmp_left); + } + if other.should_rechunk() { + if _verbose { + eprintln!( + "{:?} join triggered a rechunk of the right DataFrame: {} columns are affected", + args.how, + other.width() + ); + } + let mut tmp_right = other.clone(); + tmp_right.as_single_chunk_par(); + right = Cow::Owned(tmp_right); + } + return left._join_impl( + &right, + selected_left, + selected_right, + args, + options, + false, + _verbose, + ); + } + } + + if let Some((l, r)) = selected_left + .iter() + .zip(&selected_right) + .find(|(l, r)| l.dtype() != r.dtype()) + { + polars_bail!( + ComputeError: + format!( + "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right", + l.name(), l.dtype(), r.name(), r.dtype() + ) + ); + }; + + #[cfg(feature = "dtype-categorical")] + for (l, r) in selected_left.iter_mut().zip(selected_right.iter_mut()) { + match _check_categorical_src(l.dtype(), r.dtype()) { + Ok(_) => {}, + Err(_) => { + let (ca_left, ca_right) = + make_rhs_categoricals_compatible(l.categorical()?, r.categorical()?)?; + *l = ca_left.into_series().with_name(l.name().clone()); + *r = ca_right.into_series().with_name(r.name().clone()); + }, + } + } + + #[cfg(feature = "iejoin")] + if let JoinType::IEJoin = args.how { + let Some(JoinTypeOptions::IEJoin(options)) = options else { + unreachable!() + }; + let func = if POOL.current_num_threads() > 1 && !left_df.is_empty() && !other.is_empty() + { + iejoin::iejoin_par + } else { + iejoin::iejoin + }; + return func( + left_df, + other, + selected_left, + selected_right, + &options, + args.suffix, + args.slice, + ); + } + + // Single keys. + if selected_left.len() == 1 { + let s_left = &selected_left[0]; + let s_right = &selected_right[0]; + let drop_names: Option> = + if should_coalesce { None } else { Some(vec![]) }; + return match args.how { + JoinType::Inner => left_df + ._inner_join_from_series(other, s_left, s_right, args, _verbose, drop_names), + JoinType::Left => dispatch_left_right::left_join_from_series( + self.to_df().clone(), + other, + s_left, + s_right, + args, + _verbose, + drop_names, + ), + JoinType::Right => dispatch_left_right::right_join_from_series( + self.to_df(), + other.clone(), + s_left, + s_right, + args, + _verbose, + drop_names, + ), + JoinType::Full => left_df._full_join_from_series(other, s_left, s_right, args), + #[cfg(feature = "semi_anti_join")] + JoinType::Anti => left_df._semi_anti_join_from_series( + s_left, + s_right, + args.slice, + true, + args.nulls_equal, + ), + #[cfg(feature = "semi_anti_join")] + JoinType::Semi => left_df._semi_anti_join_from_series( + s_left, + s_right, + args.slice, + false, + args.nulls_equal, + ), + #[cfg(feature = "asof_join")] + JoinType::AsOf(options) => match (options.left_by, options.right_by) { + (Some(left_by), Some(right_by)) => left_df._join_asof_by( + other, + s_left, + s_right, + left_by, + right_by, + options.strategy, + options.tolerance, + args.suffix.clone(), + args.slice, + should_coalesce, + options.allow_eq, + options.check_sortedness, + ), + (None, None) => left_df._join_asof( + other, + s_left, + s_right, + options.strategy, + options.tolerance, + args.suffix, + args.slice, + should_coalesce, + options.allow_eq, + options.check_sortedness, + ), + _ => { + panic!("expected by arguments on both sides") + }, + }, + #[cfg(feature = "iejoin")] + JoinType::IEJoin => { + unreachable!() + }, + JoinType::Cross => { + unreachable!() + }, + }; + } + let (lhs_keys, rhs_keys) = + if (left_df.is_empty() || other.is_empty()) && matches!(&args.how, JoinType::Inner) { + // Fast path for empty inner joins. + // Return 2 dummies so that we don't row-encode. + let a = Series::full_null("".into(), 0, &DataType::Null); + (a.clone(), a) + } else { + // Row encode the keys. + ( + prepare_keys_multiple(&selected_left, args.nulls_equal)?.into_series(), + prepare_keys_multiple(&selected_right, args.nulls_equal)?.into_series(), + ) + }; + + let drop_names = if should_coalesce { + if args.how == JoinType::Right { + selected_left + .iter() + .map(|s| s.name().clone()) + .collect::>() + } else { + selected_right + .iter() + .map(|s| s.name().clone()) + .collect::>() + } + } else { + vec![] + }; + + // Multiple keys. + match args.how { + #[cfg(feature = "asof_join")] + JoinType::AsOf(_) => polars_bail!( + ComputeError: "asof join not supported for join on multiple keys" + ), + #[cfg(feature = "iejoin")] + JoinType::IEJoin => { + unreachable!() + }, + JoinType::Cross => { + unreachable!() + }, + JoinType::Full => { + let names_left = selected_left + .iter() + .map(|s| s.name().clone()) + .collect::>(); + args.coalesce = JoinCoalesce::KeepColumns; + let suffix = args.suffix.clone(); + let out = left_df._full_join_from_series(other, &lhs_keys, &rhs_keys, args); + + if should_coalesce { + Ok(_coalesce_full_join( + out?, + names_left.as_slice(), + drop_names.as_slice(), + suffix.clone(), + left_df, + )) + } else { + out + } + }, + JoinType::Inner => left_df._inner_join_from_series( + other, + &lhs_keys, + &rhs_keys, + args, + _verbose, + Some(drop_names), + ), + JoinType::Left => dispatch_left_right::left_join_from_series( + left_df.clone(), + other, + &lhs_keys, + &rhs_keys, + args, + _verbose, + Some(drop_names), + ), + JoinType::Right => dispatch_left_right::right_join_from_series( + left_df, + other.clone(), + &lhs_keys, + &rhs_keys, + args, + _verbose, + Some(drop_names), + ), + #[cfg(feature = "semi_anti_join")] + JoinType::Anti | JoinType::Semi => self._join_impl( + other, + vec![lhs_keys], + vec![rhs_keys], + args, + options, + _check_rechunk, + _verbose, + ), + } + } + + /// Perform an inner join on two DataFrames. + /// + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// # use polars_ops::prelude::*; + /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> PolarsResult { + /// left.inner_join(right, ["join_column_left"], ["join_column_right"]) + /// } + /// ``` + fn inner_join( + &self, + other: &DataFrame, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + ) -> PolarsResult { + self.join( + other, + left_on, + right_on, + JoinArgs::new(JoinType::Inner), + None, + ) + } + + /// Perform a left outer join on two DataFrames + /// # Example + /// + /// ```no_run + /// # use polars_core::prelude::*; + /// # use polars_ops::prelude::*; + /// let df1: DataFrame = df!("Wavelength (nm)" => &[480.0, 650.0, 577.0, 1201.0, 100.0])?; + /// let df2: DataFrame = df!("Color" => &["Blue", "Yellow", "Red"], + /// "Wavelength nm" => &[480.0, 577.0, 650.0])?; + /// + /// let df3: DataFrame = df1.left_join(&df2, ["Wavelength (nm)"], ["Wavelength nm"])?; + /// println!("{:?}", df3); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// + /// Output: + /// + /// ```text + /// shape: (5, 2) + /// +-----------------+--------+ + /// | Wavelength (nm) | Color | + /// | --- | --- | + /// | f64 | str | + /// +=================+========+ + /// | 480 | Blue | + /// +-----------------+--------+ + /// | 650 | Red | + /// +-----------------+--------+ + /// | 577 | Yellow | + /// +-----------------+--------+ + /// | 1201 | null | + /// +-----------------+--------+ + /// | 100 | null | + /// +-----------------+--------+ + /// ``` + fn left_join( + &self, + other: &DataFrame, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + ) -> PolarsResult { + self.join( + other, + left_on, + right_on, + JoinArgs::new(JoinType::Left), + None, + ) + } + + /// Perform a full outer join on two DataFrames + /// # Example + /// + /// ``` + /// # use polars_core::prelude::*; + /// # use polars_ops::prelude::*; + /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> PolarsResult { + /// left.full_join(right, ["join_column_left"], ["join_column_right"]) + /// } + /// ``` + fn full_join( + &self, + other: &DataFrame, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + ) -> PolarsResult { + self.join( + other, + left_on, + right_on, + JoinArgs::new(JoinType::Full), + None, + ) + } +} + +trait DataFrameJoinOpsPrivate: IntoDf { + fn _inner_join_from_series( + &self, + other: &DataFrame, + s_left: &Series, + s_right: &Series, + args: JoinArgs, + verbose: bool, + drop_names: Option>, + ) -> PolarsResult { + let left_df = self.to_df(); + #[cfg(feature = "dtype-categorical")] + _check_categorical_src(s_left.dtype(), s_right.dtype())?; + let ((join_tuples_left, join_tuples_right), sorted) = + _sort_or_hash_inner(s_left, s_right, verbose, args.validation, args.nulls_equal)?; + + let mut join_tuples_left = &*join_tuples_left; + let mut join_tuples_right = &*join_tuples_right; + + if let Some((offset, len)) = args.slice { + join_tuples_left = slice_slice(join_tuples_left, offset, len); + join_tuples_right = slice_slice(join_tuples_right, offset, len); + } + + let other = if let Some(drop_names) = drop_names { + other.drop_many(drop_names) + } else { + other.drop(s_right.name()).unwrap() + }; + + let mut left = unsafe { IdxCa::mmap_slice("a".into(), join_tuples_left) }; + if sorted { + left.set_sorted_flag(IsSorted::Ascending); + } + let right = unsafe { IdxCa::mmap_slice("b".into(), join_tuples_right) }; + + let already_left_sorted = sorted + && matches!( + args.maintain_order, + MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight + ); + try_raise_keyboard_interrupt(); + let (df_left, df_right) = + if args.maintain_order != MaintainOrderJoin::None && !already_left_sorted { + let mut df = + DataFrame::new(vec![left.into_series().into(), right.into_series().into()])?; + + let columns = match args.maintain_order { + MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => vec!["a"], + MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => vec!["b"], + _ => unreachable!(), + }; + + let options = SortMultipleOptions::new() + .with_order_descending(false) + .with_maintain_order(true); + + df.sort_in_place(columns, options)?; + + let [mut a, b]: [Column; 2] = df.take_columns().try_into().unwrap(); + if matches!( + args.maintain_order, + MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight + ) { + a.set_sorted_flag(IsSorted::Ascending); + } + + POOL.join( + // SAFETY: join indices are known to be in bounds + || unsafe { left_df.take_unchecked(a.idx().unwrap()) }, + || unsafe { other.take_unchecked(b.idx().unwrap()) }, + ) + } else { + POOL.join( + // SAFETY: join indices are known to be in bounds + || unsafe { left_df.take_unchecked(left.into_series().idx().unwrap()) }, + || unsafe { other.take_unchecked(right.into_series().idx().unwrap()) }, + ) + }; + + _finish_join(df_left, df_right, args.suffix.clone()) + } +} + +impl DataFrameJoinOps for DataFrame {} +impl DataFrameJoinOpsPrivate for DataFrame {} + +fn prepare_keys_multiple(s: &[Series], nulls_equal: bool) -> PolarsResult { + let keys = s + .iter() + .map(|s| { + let phys = s.to_physical_repr(); + match phys.dtype() { + DataType::Float32 => phys.f32().unwrap().to_canonical().into_column(), + DataType::Float64 => phys.f64().unwrap().to_canonical().into_column(), + _ => phys.into_owned().into_column(), + } + }) + .collect::>(); + + if nulls_equal { + encode_rows_vertical_par_unordered(&keys) + } else { + encode_rows_vertical_par_unordered_broadcast_nulls(&keys) + } +} +pub fn private_left_join_multiple_keys( + a: &DataFrame, + b: &DataFrame, + nulls_equal: bool, +) -> PolarsResult { + // @scalar-opt + let a_cols = a + .get_columns() + .iter() + .map(|c| c.as_materialized_series().clone()) + .collect::>(); + let b_cols = b + .get_columns() + .iter() + .map(|c| c.as_materialized_series().clone()) + .collect::>(); + + let a = prepare_keys_multiple(&a_cols, nulls_equal)?.into_series(); + let b = prepare_keys_multiple(&b_cols, nulls_equal)?.into_series(); + sort_or_hash_left(&a, &b, false, JoinValidation::ManyToMany, nulls_equal) +} diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs new file mode 100644 index 000000000000..f57fa4fb9ce2 --- /dev/null +++ b/crates/polars-ops/src/frame/mod.rs @@ -0,0 +1,116 @@ +pub mod join; +#[cfg(feature = "pivot")] +pub mod pivot; + +pub use join::*; +#[cfg(feature = "to_dummies")] +use polars_core::POOL; +use polars_core::prelude::*; +#[cfg(feature = "to_dummies")] +use polars_core::utils::accumulate_dataframes_horizontal; +#[cfg(feature = "to_dummies")] +use rayon::prelude::*; + +pub trait IntoDf { + fn to_df(&self) -> &DataFrame; +} + +impl IntoDf for DataFrame { + fn to_df(&self) -> &DataFrame { + self + } +} + +impl DataFrameOps for T {} + +pub trait DataFrameOps: IntoDf { + /// Create dummy variables. + /// + /// # Example + /// + /// ```ignore + /// + /// # #[macro_use] extern crate polars_core; + /// # fn main() { + /// + /// use polars_core::prelude::*; + /// + /// let df = df! { + /// "id" => &[1, 2, 3, 1, 2, 3, 1, 1], + /// "type" => &["A", "B", "B", "B", "C", "C", "C", "B"], + /// "code" => &["X1", "X2", "X3", "X3", "X2", "X2", "X1", "X1"] + /// }.unwrap(); + /// + /// let dummies = df.to_dummies(None, false).unwrap(); + /// println!("{}", dummies); + /// # } + /// ``` + /// Outputs: + /// ```text + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | id_1 | id_3 | id_2 | type_A | type_B | type_C | code_X1 | code_X2 | code_X3 | + /// | --- | --- | --- | --- | --- | --- | --- | --- | --- | + /// | u8 | u8 | u8 | u8 | u8 | u8 | u8 | u8 | u8 | + /// +======+======+======+========+========+========+=========+=========+=========+ + /// | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 0 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | + /// +------+------+------+--------+--------+--------+---------+---------+---------+ + /// ``` + #[cfg(feature = "to_dummies")] + fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult { + self._to_dummies(None, separator, drop_first) + } + + #[cfg(feature = "to_dummies")] + fn columns_to_dummies( + &self, + columns: Vec<&str>, + separator: Option<&str>, + drop_first: bool, + ) -> PolarsResult { + self._to_dummies(Some(columns), separator, drop_first) + } + + #[cfg(feature = "to_dummies")] + fn _to_dummies( + &self, + columns: Option>, + separator: Option<&str>, + drop_first: bool, + ) -> PolarsResult { + use crate::series::ToDummies; + + let df = self.to_df(); + + let set: PlHashSet<&str> = if let Some(columns) = columns { + PlHashSet::from_iter(columns) + } else { + PlHashSet::from_iter(df.iter().map(|s| s.name().as_str())) + }; + + let cols = POOL.install(|| { + df.get_columns() + .par_iter() + .map(|s| match set.contains(s.name().as_str()) { + true => s.as_materialized_series().to_dummies(separator, drop_first), + false => Ok(s.clone().into_frame()), + }) + .collect::>>() + })?; + + accumulate_dataframes_horizontal(cols) + } +} diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs new file mode 100644 index 000000000000..9096f5c6e35a --- /dev/null +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -0,0 +1,387 @@ +mod positioning; +mod unpivot; + +use std::borrow::Cow; + +use polars_core::frame::group_by::expr::PhysicalAggExpr; +use polars_core::prelude::*; +use polars_core::utils::_split_offsets; +use polars_core::{POOL, downcast_as_macro_arg_physical}; +use polars_utils::format_pl_smallstr; +use rayon::prelude::*; +pub use unpivot::UnpivotDF; + +const HASHMAP_INIT_SIZE: usize = 512; + +#[derive(Clone)] +pub enum PivotAgg { + First, + Sum, + Min, + Max, + Mean, + Median, + Count, + Last, + Expr(Arc), +} + +fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { + // restore logical type + match (logical_type, s.dtype()) { + #[cfg(feature = "dtype-categorical")] + (dt @ DataType::Categorical(Some(rev_map), ordering), _) + | (dt @ DataType::Enum(Some(rev_map), ordering), _) => { + let cats = s.u32().unwrap().clone(); + // SAFETY: + // the rev-map comes from these categoricals + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + rev_map.clone(), + matches!(dt, DataType::Enum(_, _)), + *ordering, + ) + .into_series() + } + }, + (DataType::Float32, DataType::UInt32) => { + let ca = s.u32().unwrap(); + ca._reinterpret_float().into_series() + }, + (DataType::Float64, DataType::UInt64) => { + let ca = s.u64().unwrap(); + ca._reinterpret_float().into_series() + }, + (DataType::Int32, DataType::UInt32) => { + let ca = s.u32().unwrap(); + ca.reinterpret_signed() + }, + (DataType::Int64, DataType::UInt64) => { + let ca = s.u64().unwrap(); + ca.reinterpret_signed() + }, + #[cfg(feature = "dtype-duration")] + (DataType::Duration(_), DataType::UInt64) => { + let ca = s.u64().unwrap(); + ca.reinterpret_signed().cast(logical_type).unwrap() + }, + #[cfg(feature = "dtype-datetime")] + (DataType::Datetime(_, _), DataType::UInt64) => { + let ca = s.u64().unwrap(); + ca.reinterpret_signed().cast(logical_type).unwrap() + }, + #[cfg(feature = "dtype-date")] + (DataType::Date, DataType::UInt32) => { + let ca = s.u32().unwrap(); + ca.reinterpret_signed().cast(logical_type).unwrap() + }, + (dt, DataType::Null) => { + let ca = Series::full_null(s.name().clone(), s.len(), dt); + ca.into_series() + }, + _ => unsafe { s.from_physical_unchecked(logical_type).unwrap() }, + } +} + +/// Do a pivot operation based on the group key, a pivot column and an aggregation function on the values column. +/// +/// # Note +/// Polars'/arrow memory is not ideal for transposing operations like pivots. +/// If you have a relatively large table, consider using a group_by over a pivot. +pub fn pivot( + pivot_df: &DataFrame, + on: I0, + index: Option, + values: Option, + sort_columns: bool, + agg_fn: Option, + separator: Option<&str>, +) -> PolarsResult +where + I0: IntoIterator, + I1: IntoIterator, + I2: IntoIterator, + S0: Into, + S1: Into, + S2: Into, +{ + let on = on.into_iter().map(Into::into).collect::>(); + let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?; + pivot_impl( + pivot_df, + &on, + &index, + &values, + agg_fn, + sort_columns, + false, + separator, + ) +} + +/// Do a pivot operation based on the group key, a pivot column and an aggregation function on the values column. +/// +/// # Note +/// Polars'/arrow memory is not ideal for transposing operations like pivots. +/// If you have a relatively large table, consider using a group_by over a pivot. +pub fn pivot_stable( + pivot_df: &DataFrame, + on: I0, + index: Option, + values: Option, + sort_columns: bool, + agg_fn: Option, + separator: Option<&str>, +) -> PolarsResult +where + I0: IntoIterator, + I1: IntoIterator, + I2: IntoIterator, + S0: Into, + S1: Into, + S2: Into, +{ + let on = on.into_iter().map(Into::into).collect::>(); + let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?; + pivot_impl( + pivot_df, + on.as_slice(), + index.as_slice(), + values.as_slice(), + agg_fn, + sort_columns, + true, + separator, + ) +} + +/// Ensure both `index` and `values` are populated with `Vec`. +/// +/// - If `index` is None, assign columns not in `on` and `values` to it. +/// - If `values` is None, assign columns not in `on` and `index` to it. +/// - At least one of `index` and `values` must be non-null. +fn assign_remaining_columns( + df: &DataFrame, + on: &[PlSmallStr], + index: Option, + values: Option, +) -> PolarsResult<(Vec, Vec)> +where + I1: IntoIterator, + I2: IntoIterator, + S1: Into, + S2: Into, +{ + match (index, values) { + (Some(index), Some(values)) => { + let index = index.into_iter().map(Into::into).collect(); + let values = values.into_iter().map(Into::into).collect(); + Ok((index, values)) + }, + (Some(index), None) => { + let index: Vec = index.into_iter().map(Into::into).collect(); + let values = df + .get_column_names() + .into_iter() + .filter(|c| !(index.contains(c) | on.contains(c))) + .cloned() + .collect(); + Ok((index, values)) + }, + (None, Some(values)) => { + let values: Vec = values.into_iter().map(Into::into).collect(); + let index = df + .get_column_names() + .into_iter() + .filter(|c| !(values.contains(c) | on.contains(c))) + .cloned() + .collect(); + Ok((index, values)) + }, + (None, None) => { + polars_bail!(InvalidOperation: "`index` and `values` cannot both be None in `pivot` operation") + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn pivot_impl( + pivot_df: &DataFrame, + // keys of the first group_by operation + on: &[PlSmallStr], + // these columns will be aggregated in the nested group_by + index: &[PlSmallStr], + // these columns will be used for a nested group_by + // the rows of this nested group_by will be pivoted as header column values + values: &[PlSmallStr], + // aggregation function + agg_fn: Option, + sort_columns: bool, + stable: bool, + // used as separator/delimiter in generated column names. + separator: Option<&str>, +) -> PolarsResult { + polars_ensure!(!index.is_empty(), ComputeError: "index cannot be zero length"); + polars_ensure!(!on.is_empty(), ComputeError: "`on` cannot be zero length"); + if !stable { + println!("unstable pivot not yet supported, using stable pivot"); + }; + if on.len() > 1 { + let schema = Arc::new(pivot_df.schema()); + let binding = pivot_df.select_with_schema(on.iter().cloned(), &schema)?; + let fields = binding.get_columns(); + let column = format_pl_smallstr!("{{\"{}\"}}", on.join("\",\"")); + if schema.contains(column.as_str()) { + polars_bail!(ComputeError: "cannot use column name {column} that \ + already exists in the DataFrame. Please rename it prior to calling `pivot`.") + } + // @scalar-opt + let columns_struct = StructChunked::from_columns(column.clone(), fields[0].len(), fields) + .unwrap() + .into_column(); + let mut binding = pivot_df.clone(); + let pivot_df = unsafe { binding.with_column_unchecked(columns_struct) }; + pivot_impl_single_column( + pivot_df, + index, + &column, + values, + agg_fn, + sort_columns, + separator, + ) + } else { + pivot_impl_single_column( + pivot_df, + index, + unsafe { on.get_unchecked(0) }, + values, + agg_fn, + sort_columns, + separator, + ) + } +} + +fn pivot_impl_single_column( + pivot_df: &DataFrame, + index: &[PlSmallStr], + column: &PlSmallStr, + values: &[PlSmallStr], + agg_fn: Option, + sort_columns: bool, + separator: Option<&str>, +) -> PolarsResult { + let sep = separator.unwrap_or("_"); + let mut final_cols = vec![]; + let mut count = 0; + let out: PolarsResult<()> = POOL.install(|| { + let mut group_by = index.to_vec(); + group_by.push(column.clone()); + + let groups = pivot_df.group_by_stable(group_by)?.take_groups(); + + let (col, row) = POOL.join( + || positioning::compute_col_idx(pivot_df, column, &groups), + || positioning::compute_row_idx(pivot_df, index, &groups, count), + ); + let (col_locations, column_agg) = col?; + let (row_locations, n_rows, mut row_index) = row?; + + for value_col_name in values { + let value_col = pivot_df.column(value_col_name)?; + + use PivotAgg::*; + let value_agg = unsafe { + match &agg_fn { + None => match value_col.len() > groups.len() { + true => polars_bail!( + ComputeError: + "found multiple elements in the same group, \ + please specify an aggregation function" + ), + false => value_col.agg_first(&groups), + }, + Some(agg_fn) => match agg_fn { + Sum => value_col.agg_sum(&groups), + Min => value_col.agg_min(&groups), + Max => value_col.agg_max(&groups), + Last => value_col.agg_last(&groups), + First => value_col.agg_first(&groups), + Mean => value_col.agg_mean(&groups), + Median => value_col.agg_median(&groups), + Count => groups.group_count().into_column(), + Expr(expr) => { + let name = expr.root_name()?.clone(); + let mut value_col = value_col.clone(); + value_col.rename(name); + let tmp_df = value_col.into_frame(); + let mut aggregated = Column::from(expr.evaluate(&tmp_df, &groups)?); + aggregated.rename(value_col_name.clone()); + aggregated + }, + }, + } + }; + + let headers = column_agg.unique_stable()?.cast(&DataType::String)?; + let mut headers = headers.str().unwrap().clone(); + if values.len() > 1 { + headers = headers.apply_values(|v| Cow::from(format!("{value_col_name}{sep}{v}"))) + } + + let n_cols = headers.len(); + let value_agg_phys = value_agg.to_physical_repr(); + let logical_type = value_agg.dtype(); + + debug_assert_eq!(row_locations.len(), col_locations.len()); + debug_assert_eq!(value_agg_phys.len(), row_locations.len()); + + let mut cols = if value_agg_phys.dtype().is_primitive_numeric() { + macro_rules! dispatch { + ($ca:expr) => {{ + positioning::position_aggregates_numeric( + n_rows, + n_cols, + &row_locations, + &col_locations, + $ca, + logical_type, + &headers, + ) + }}; + } + downcast_as_macro_arg_physical!(value_agg_phys, dispatch) + } else { + positioning::position_aggregates( + n_rows, + n_cols, + &row_locations, + &col_locations, + value_agg_phys.as_materialized_series(), + logical_type, + &headers, + ) + }; + + if sort_columns { + cols.sort_unstable_by(|a, b| a.name().partial_cmp(b.name()).unwrap()); + } + + let cols = if count == 0 { + let mut final_cols = row_index.take().unwrap(); + final_cols.extend(cols); + final_cols + } else { + cols + }; + count += 1; + final_cols.extend_from_slice(&cols); + } + Ok(()) + }); + out?; + + DataFrame::new(final_cols) +} diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs new file mode 100644 index 000000000000..78a0aaec8f50 --- /dev/null +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -0,0 +1,511 @@ +use std::hash::Hash; + +use arrow::legacy::trusted_len::TrustedLenPush; +use polars_core::prelude::*; +use polars_core::series::BitRepr; +use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use super::*; + +pub(super) fn position_aggregates( + n_rows: usize, + n_cols: usize, + row_locations: &[IdxSize], + col_locations: &[IdxSize], + value_agg_phys: &Series, + logical_type: &DataType, + headers: &StringChunked, +) -> Vec { + let mut buf = vec![AnyValue::Null; n_rows * n_cols]; + let start_ptr = buf.as_mut_ptr() as usize; + + let n_threads = POOL.current_num_threads(); + let split = _split_offsets(row_locations.len(), n_threads); + + // ensure the slice series are not dropped + // so the AnyValues are referencing correct data, if they reference arrays (struct) + let n_splits = split.len(); + let mut arrays: Vec = Vec::with_capacity(n_splits); + + // every thread will only write to their partition + let array_ptr = unsafe { SyncPtr::new(arrays.as_mut_ptr()) }; + + POOL.install(|| { + split + .into_par_iter() + .enumerate() + .for_each(|(i, (offset, len))| { + let start_ptr = start_ptr as *mut AnyValue; + let row_locations = &row_locations[offset..offset + len]; + let col_locations = &col_locations[offset..offset + len]; + let value_agg_phys = value_agg_phys.slice(offset as i64, len); + + for ((row_idx, col_idx), val) in row_locations + .iter() + .zip(col_locations) + .zip(value_agg_phys.phys_iter()) + { + // SAFETY: + // in bounds + unsafe { + let idx = *row_idx as usize + *col_idx as usize * n_rows; + debug_assert!(idx < buf.len()); + let pos = start_ptr.add(idx); + std::ptr::write(pos, val) + } + } + // ensure the `values_agg_phys` stays alive + let array_ptr = array_ptr.clone().get(); + unsafe { std::ptr::write(array_ptr.add(i), value_agg_phys) } + }); + // ensure the content of the arrays are dropped + unsafe { + arrays.set_len(n_splits); + } + + let headers_iter = headers.par_iter_indexed(); + let phys_type = logical_type.to_physical(); + + (0..n_cols) + .into_par_iter() + .zip(headers_iter) + .map(|(i, opt_name)| { + let offset = i * n_rows; + let avs = &buf[offset..offset + n_rows]; + let name = opt_name + .map(PlSmallStr::from_str) + .unwrap_or_else(|| PlSmallStr::from_static("null")); + let out = match &phys_type { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => { + // we know we can trust this data, so we use the explicit builder + use polars_core::frame::row::AnyValueBufferTrusted; + let mut buf = AnyValueBufferTrusted::new(&phys_type, avs.len()); + for av in avs { + unsafe { + buf.add_unchecked_borrowed_physical(av); + } + } + let mut out = buf.into_series(); + out.rename(name); + out + }, + _ => Series::from_any_values_and_dtype(name, avs, &phys_type, false).unwrap(), + }; + unsafe { out.from_physical_unchecked(logical_type).unwrap() }.into() + }) + .collect::>() + }) +} + +pub(super) fn position_aggregates_numeric( + n_rows: usize, + n_cols: usize, + row_locations: &[IdxSize], + col_locations: &[IdxSize], + value_agg_phys: &ChunkedArray, + logical_type: &DataType, + headers: &StringChunked, +) -> Vec +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let mut buf = vec![None; n_rows * n_cols]; + let start_ptr = buf.as_mut_ptr() as usize; + + let n_threads = POOL.current_num_threads(); + + let split = _split_offsets(row_locations.len(), n_threads); + let n_splits = split.len(); + // ensure the arrays are not dropped + // so the AnyValues are referencing correct data, if they reference arrays (struct) + let mut arrays: Vec> = Vec::with_capacity(n_splits); + + // every thread will only write to their partition + let array_ptr = unsafe { SyncPtr::new(arrays.as_mut_ptr()) }; + + POOL.install(|| { + split + .into_par_iter() + .enumerate() + .for_each(|(i, (offset, len))| { + let start_ptr = start_ptr as *mut Option; + let row_locations = &row_locations[offset..offset + len]; + let col_locations = &col_locations[offset..offset + len]; + let value_agg_phys = value_agg_phys.slice(offset as i64, len); + + // todo! remove lint silencing + #[allow(clippy::useless_conversion)] + for ((row_idx, col_idx), val) in row_locations + .iter() + .zip(col_locations) + .zip(value_agg_phys.into_iter()) + { + // SAFETY: + // in bounds + unsafe { + let idx = *row_idx as usize + *col_idx as usize * n_rows; + debug_assert!(idx < buf.len()); + let pos = start_ptr.add(idx); + std::ptr::write(pos, val) + } + } + // ensure the `values_agg_phys` stays alive + let array_ptr = array_ptr.clone().get(); + unsafe { std::ptr::write(array_ptr.add(i), value_agg_phys) } + }); + // ensure the content of the arrays are dropped + unsafe { + arrays.set_len(n_splits); + } + let headers_iter = headers.par_iter_indexed(); + + (0..n_cols) + .into_par_iter() + .zip(headers_iter) + .map(|(i, opt_name)| { + let offset = i * n_rows; + let opt_values = &buf[offset..offset + n_rows]; + let name = opt_name + .map(PlSmallStr::from_str) + .unwrap_or_else(|| PlSmallStr::from_static("null")); + let out = ChunkedArray::::from_slice_options(name, opt_values).into_series(); + unsafe { out.from_physical_unchecked(logical_type).unwrap() }.into() + }) + .collect::>() + }) +} + +fn compute_col_idx_numeric(column_agg_physical: &ChunkedArray) -> Vec +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, +{ + let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut idx = 0 as IdxSize; + let mut out = Vec::with_capacity(column_agg_physical.len()); + + for opt_v in column_agg_physical.iter() { + let opt_v = opt_v.to_total_ord(); + let idx = *col_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); + // SAFETY: + // we pre-allocated + unsafe { out.push_unchecked(idx) }; + } + out +} + +fn compute_col_idx_gen<'a, T>(column_agg_physical: &'a ChunkedArray) -> Vec +where + T: PolarsDataType, + &'a T::Array: IntoIterator>>, + T::Physical<'a>: Hash + Eq, +{ + let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut idx = 0 as IdxSize; + let mut out = Vec::with_capacity(column_agg_physical.len()); + + for arr in column_agg_physical.downcast_iter() { + for opt_v in arr.into_iter() { + let idx = *col_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); + // SAFETY: + // we pre-allocated + unsafe { out.push_unchecked(idx) }; + } + } + out +} + +pub(super) fn compute_col_idx( + pivot_df: &DataFrame, + column: &str, + groups: &GroupsType, +) -> PolarsResult<(Vec, Column)> { + let column_s = pivot_df.column(column)?; + let column_agg = unsafe { column_s.agg_first(groups) }; + let column_agg_physical = column_agg.to_physical_repr(); + + use DataType as T; + let col_locations = match column_agg_physical.dtype() { + T::Int32 | T::UInt32 => { + let Some(BitRepr::Small(ca)) = column_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 32-bit representation to be available; this should never happen"); + }; + compute_col_idx_numeric(&ca) + }, + T::Int64 | T::UInt64 => { + let Some(BitRepr::Large(ca)) = column_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 64-bit representation to be available; this should never happen"); + }; + compute_col_idx_numeric(&ca) + }, + T::Float64 => { + let ca: &ChunkedArray = column_agg_physical + .as_materialized_series() + .as_ref() + .as_ref(); + compute_col_idx_numeric(ca) + }, + T::Float32 => { + let ca: &ChunkedArray = column_agg_physical + .as_materialized_series() + .as_ref() + .as_ref(); + compute_col_idx_numeric(ca) + }, + T::Struct(_) => { + let ca = column_agg_physical.struct_().unwrap(); + let ca = ca.get_row_encoded(Default::default())?; + compute_col_idx_gen(&ca) + }, + T::String => { + let ca = column_agg_physical.str().unwrap(); + let ca = ca.as_binary(); + compute_col_idx_gen(&ca) + }, + T::Binary => { + let ca = column_agg_physical.binary().unwrap(); + compute_col_idx_gen(ca) + }, + T::Boolean => { + let ca = column_agg_physical.bool().unwrap(); + compute_col_idx_gen(ca) + }, + _ => { + let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); + let mut idx = 0 as IdxSize; + column_agg_physical + .as_materialized_series() + .phys_iter() + .map(|v| { + *col_to_idx.entry(v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }) + }) + .collect() + }, + }; + + Ok((col_locations, column_agg)) +} + +fn compute_row_index<'a, T>( + index: &[PlSmallStr], + index_agg_physical: &'a ChunkedArray, + count: usize, + logical_type: &DataType, +) -> (Vec, usize, Option>) +where + T: PolarsDataType, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, + ChunkedArray: FromIterator>>, + ChunkedArray: IntoSeries, +{ + let mut row_to_idx = + PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + let mut idx = 0 as IdxSize; + + let mut row_locations = Vec::with_capacity(index_agg_physical.len()); + for opt_v in index_agg_physical.iter() { + let opt_v = opt_v.to_total_ord(); + let idx = *row_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); + + // SAFETY: + // we pre-allocated + unsafe { + row_locations.push_unchecked(idx); + } + } + let row_index = match count { + 0 => { + let mut s = row_to_idx + .into_iter() + .map(|(k, _)| Option::>::peel_total_ord(k)) + .collect::>() + .into_series(); + s.rename(index[0].clone()); + let s = restore_logical_type(&s, logical_type); + Some(vec![s.into()]) + }, + _ => None, + }; + + (row_locations, idx as usize, row_index) +} + +fn compute_row_index_struct( + index: &[PlSmallStr], + index_agg: &Series, + index_agg_physical: &BinaryOffsetChunked, + count: usize, +) -> (Vec, usize, Option>) { + let mut row_to_idx = + PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + let mut idx = 0 as IdxSize; + + let mut row_locations = Vec::with_capacity(index_agg_physical.len()); + let mut unique_indices = Vec::with_capacity(index_agg_physical.len()); + let mut row_number: IdxSize = 0; + for arr in index_agg_physical.downcast_iter() { + for opt_v in arr.iter() { + let idx = *row_to_idx.entry(opt_v).or_insert_with(|| { + // SAFETY: we pre-allocated + unsafe { unique_indices.push_unchecked(row_number) }; + let old_idx = idx; + idx += 1; + old_idx + }); + row_number += 1; + + // SAFETY: + // we pre-allocated + unsafe { + row_locations.push_unchecked(idx); + } + } + } + let row_index = match count { + 0 => { + // SAFETY: `unique_indices` is filled with elements between + // 0 and `index_agg.len() - 1`. + let mut s = unsafe { index_agg.take_slice_unchecked(&unique_indices) }; + s.rename(index[0].clone()); + Some(vec![s.into()]) + }, + _ => None, + }; + + (row_locations, idx as usize, row_index) +} + +// TODO! Also create a specialized version for numerics. +pub(super) fn compute_row_idx( + pivot_df: &DataFrame, + index: &[PlSmallStr], + groups: &GroupsType, + count: usize, +) -> PolarsResult<(Vec, usize, Option>)> { + let (row_locations, n_rows, row_index) = if index.len() == 1 { + let index_s = pivot_df.column(&index[0])?; + let index_agg = unsafe { index_s.agg_first(groups) }; + let index_agg_physical = index_agg.to_physical_repr(); + + use DataType as T; + match index_agg_physical.dtype() { + T::Int32 | T::UInt32 => { + let Some(BitRepr::Small(ca)) = index_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 32-bit representation to be available; this should never happen"); + }; + compute_row_index(index, &ca, count, index_s.dtype()) + }, + T::Int64 | T::UInt64 => { + let Some(BitRepr::Large(ca)) = index_agg_physical.bit_repr() else { + polars_bail!(ComputeError: "Expected 64-bit representation to be available; this should never happen"); + }; + compute_row_index(index, &ca, count, index_s.dtype()) + }, + T::Float64 => { + let ca: &ChunkedArray = index_agg_physical + .as_materialized_series() + .as_ref() + .as_ref(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + T::Float32 => { + let ca: &ChunkedArray = index_agg_physical + .as_materialized_series() + .as_ref() + .as_ref(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + T::Boolean => { + let ca = index_agg_physical.bool().unwrap(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + T::Struct(_) => { + let ca = index_agg_physical.struct_().unwrap(); + let ca = ca.get_row_encoded(Default::default())?; + compute_row_index_struct(index, index_agg.as_materialized_series(), &ca, count) + }, + T::String => { + let ca = index_agg_physical.str().unwrap(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + _ => { + let mut row_to_idx = + PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default()); + let mut idx = 0 as IdxSize; + let row_locations = index_agg_physical + .as_materialized_series() + .phys_iter() + .map(|v| { + *row_to_idx.entry(v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }) + }) + .collect::>(); + + let row_index = match count { + 0 => { + let s = Series::new( + index[0].clone(), + row_to_idx.into_iter().map(|(k, _)| k).collect::>(), + ); + let s = restore_logical_type(&s, index_s.dtype()); + Some(vec![Column::from(s)]) + }, + _ => None, + }; + + (row_locations, idx as usize, row_index) + }, + } + } else { + let binding = pivot_df.select(index.iter().cloned())?; + let fields = binding.get_columns(); + let index_struct_series = StructChunked::from_columns( + PlSmallStr::from_static("placeholder"), + fields[0].len(), + fields, + )? + .into_series(); + let index_agg = unsafe { index_struct_series.agg_first(groups) }; + let index_agg_physical = index_agg.to_physical_repr(); + let ca = index_agg_physical.struct_()?; + let ca = ca.get_row_encoded(Default::default())?; + let (row_locations, n_rows, row_index) = + compute_row_index_struct(index, &index_agg, &ca, count); + let row_index = row_index.map(|x| { + let ca = x.first().unwrap() + .struct_().unwrap(); + + polars_ensure!(ca.null_count() == 0, InvalidOperation: "outer nullability in struct pivot not yet supported"); + + // @scalar-opt + Ok(ca.fields_as_series().into_iter().map(Column::from).collect()) + }).transpose()?; + (row_locations, n_rows, row_index) + }; + + Ok((row_locations, n_rows, row_index)) +} diff --git a/crates/polars-ops/src/frame/pivot/unpivot.rs b/crates/polars-ops/src/frame/pivot/unpivot.rs new file mode 100644 index 000000000000..0bcbded64633 --- /dev/null +++ b/crates/polars-ops/src/frame/pivot/unpivot.rs @@ -0,0 +1,294 @@ +use arrow::array::{MutableArray, MutablePlString}; +use arrow::compute::concatenate::concatenate_unchecked; +use polars_core::datatypes::{DataType, PlSmallStr}; +use polars_core::frame::DataFrame; +use polars_core::frame::column::Column; +use polars_core::prelude::{IntoVec, Series, UnpivotArgsIR}; +use polars_core::utils::merge_dtypes_many; +use polars_error::{PolarsResult, polars_err}; +use polars_utils::aliases::PlHashSet; + +use crate::frame::IntoDf; + +pub trait UnpivotDF: IntoDf { + /// Unpivot a `DataFrame` from wide to long format. + /// + /// # Example + /// + /// # Arguments + /// + /// * `on` - String slice that represent the columns to use as value variables. + /// * `index` - String slice that represent the columns to use as id variables. + /// + /// If `on` is empty all columns that are not in `index` will be used. + /// + /// ```ignore + /// # use polars_core::prelude::*; + /// let df = df!("A" => &["a", "b", "a"], + /// "B" => &[1, 3, 5], + /// "C" => &[10, 11, 12], + /// "D" => &[2, 4, 6] + /// )?; + /// + /// let unpivoted = df.unpivot(&["A", "B"], &["C", "D"])?; + /// println!("{:?}", df); + /// println!("{:?}", unpivoted); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Outputs: + /// ```text + /// +-----+-----+-----+-----+ + /// | A | B | C | D | + /// | --- | --- | --- | --- | + /// | str | i32 | i32 | i32 | + /// +=====+=====+=====+=====+ + /// | "a" | 1 | 10 | 2 | + /// +-----+-----+-----+-----+ + /// | "b" | 3 | 11 | 4 | + /// +-----+-----+-----+-----+ + /// | "a" | 5 | 12 | 6 | + /// +-----+-----+-----+-----+ + /// + /// +-----+-----+----------+-------+ + /// | A | B | variable | value | + /// | --- | --- | --- | --- | + /// | str | i32 | str | i32 | + /// +=====+=====+==========+=======+ + /// | "a" | 1 | "C" | 10 | + /// +-----+-----+----------+-------+ + /// | "b" | 3 | "C" | 11 | + /// +-----+-----+----------+-------+ + /// | "a" | 5 | "C" | 12 | + /// +-----+-----+----------+-------+ + /// | "a" | 1 | "D" | 2 | + /// +-----+-----+----------+-------+ + /// | "b" | 3 | "D" | 4 | + /// +-----+-----+----------+-------+ + /// | "a" | 5 | "D" | 6 | + /// +-----+-----+----------+-------+ + /// ``` + fn unpivot(&self, on: I, index: J) -> PolarsResult + where + I: IntoVec, + J: IntoVec, + { + let index = index.into_vec(); + let on = on.into_vec(); + self.unpivot2(UnpivotArgsIR { + on, + index, + ..Default::default() + }) + } + + /// Similar to unpivot, but without generics. This may be easier if you want to pass + /// an empty `index` or empty `on`. + fn unpivot2(&self, args: UnpivotArgsIR) -> PolarsResult { + let self_ = self.to_df(); + let index = args.index; + let mut on = args.on; + + let variable_name = args + .variable_name + .unwrap_or_else(|| PlSmallStr::from_static("variable")); + let value_name = args + .value_name + .unwrap_or_else(|| PlSmallStr::from_static("value")); + + if self_.get_columns().is_empty() { + return DataFrame::new(vec![ + Column::new_empty(variable_name, &DataType::String), + Column::new_empty(value_name, &DataType::Null), + ]); + } + + let len = self_.height(); + + // If value vars is empty we take all columns that are not in id_vars. + if on.is_empty() { + // Return empty frame if there are no columns available to use as value vars. + if index.len() == self_.width() { + let variable_col = Column::new_empty(variable_name, &DataType::String); + let value_col = Column::new_empty(value_name, &DataType::Null); + + let mut out = self_.select(index).unwrap().clear().take_columns(); + + out.push(variable_col); + out.push(value_col); + + return Ok(unsafe { DataFrame::new_no_checks(0, out) }); + } + + let index_set = PlHashSet::from_iter(index.iter().cloned()); + on = self_ + .get_columns() + .iter() + .filter_map(|s| { + if index_set.contains(s.name()) { + None + } else { + Some(s.name().clone()) + } + }) + .collect(); + } + + // Values will all be placed in single column, so we must find their supertype + let schema = self_.schema(); + let dtypes = on + .iter() + .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))) + .collect::>>()?; + + let st = merge_dtypes_many(dtypes.iter())?; + + // The column name of the variable that is unpivoted + let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1); + // prepare ids + let ids_ = self_.select_with_schema_unchecked(index, schema)?; + let mut ids = ids_.clone(); + if ids.width() > 0 { + for _ in 0..on.len() - 1 { + ids.vstack_mut_unchecked(&ids_) + } + } + ids.as_single_chunk_par(); + drop(ids_); + + let mut values = Vec::with_capacity(on.len()); + let columns = self_.get_columns(); + + for value_column_name in &on { + variable_col.extend_constant(len, Some(value_column_name.as_str())); + // ensure we go via the schema so we are O(1) + // self.column() is linear + // together with this loop that would make it O^2 over `on` + let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?; + let col = &columns[pos]; + let value_col = col.cast(&st).map_err( + |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}\n\nConsider casting to String.", col.dtype()), + )?; + values.extend_from_slice(value_col.as_materialized_series().chunks()) + } + let values_arr = concatenate_unchecked(&values)?; + // SAFETY: + // The give dtype is correct + let values = + unsafe { Series::from_chunks_and_dtype_unchecked(value_name, vec![values_arr], &st) } + .into(); + + let variable_col = variable_col.as_box(); + // SAFETY: + // The given dtype is correct + let variables = unsafe { + Series::from_chunks_and_dtype_unchecked( + variable_name, + vec![variable_col], + &DataType::String, + ) + } + .into(); + + ids.hstack_mut(&[variables, values])?; + + Ok(ids) + } +} + +impl UnpivotDF for DataFrame {} + +#[cfg(test)] +mod test { + use polars_core::df; + use polars_core::utils::Container; + + use super::*; + + #[test] + fn test_unpivot() -> PolarsResult<()> { + let df = df!("A" => &["a", "b", "a"], + "B" => &[1, 3, 5], + "C" => &[10, 11, 12], + "D" => &[2, 4, 6] + ) + .unwrap(); + + // Specify on and index + let unpivoted = df.unpivot(["C", "D"], ["A", "B"])?; + assert_eq!( + unpivoted.get_column_names(), + &["A", "B", "variable", "value"] + ); + assert_eq!( + Vec::from(unpivoted.column("value")?.i32()?), + &[Some(10), Some(11), Some(12), Some(2), Some(4), Some(6)] + ); + + // Specify custom column names + let args = UnpivotArgsIR { + on: vec!["C".into(), "D".into()], + index: vec!["A".into(), "B".into()], + variable_name: Some("custom_variable".into()), + value_name: Some("custom_value".into()), + }; + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!( + unpivoted.get_column_names(), + &["A", "B", "custom_variable", "custom_value"] + ); + + // Specify neither on nor index + let args = UnpivotArgsIR { + on: vec![], + index: vec![], + ..Default::default() + }; + + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!(unpivoted.get_column_names(), &["variable", "value"]); + let value = unpivoted.column("value")?; + // String because of supertype + let value = value.str()?; + let value = value.into_no_null_iter().collect::>(); + assert_eq!( + value, + &[ + "a", "b", "a", "1", "3", "5", "10", "11", "12", "2", "4", "6" + ] + ); + + // Specify index but not on + let args = UnpivotArgsIR { + on: vec![], + index: vec!["A".into()], + ..Default::default() + }; + + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!(unpivoted.get_column_names(), &["A", "variable", "value"]); + let value = unpivoted.column("value")?; + let value = value.i32()?; + let value = value.into_no_null_iter().collect::>(); + assert_eq!(value, &[1, 3, 5, 10, 11, 12, 2, 4, 6]); + let variable = unpivoted.column("variable")?; + let variable = variable.str()?; + let variable = variable.into_no_null_iter().collect::>(); + assert_eq!(variable, &["B", "B", "B", "C", "C", "C", "D", "D", "D"]); + assert!(unpivoted.column("A").is_ok()); + + // Specify all columns in index + let args = UnpivotArgsIR { + on: vec![], + index: vec!["A".into(), "B".into(), "C".into(), "D".into()], + ..Default::default() + }; + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!( + unpivoted.get_column_names(), + &["A", "B", "C", "D", "variable", "value"] + ); + assert_eq!(unpivoted.len(), 0); + + Ok(()) + } +} diff --git a/crates/polars-ops/src/lib.rs b/crates/polars-ops/src/lib.rs new file mode 100644 index 000000000000..5889f915ef3d --- /dev/null +++ b/crates/polars-ops/src/lib.rs @@ -0,0 +1,10 @@ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(feature = "nightly", feature(unicode_internals))] +#![cfg_attr(feature = "nightly", allow(internal_features))] + +pub mod chunked_array; +#[cfg(feature = "pivot")] +pub use frame::pivot; +pub mod frame; +pub mod prelude; +pub mod series; diff --git a/crates/polars-ops/src/prelude.rs b/crates/polars-ops/src/prelude.rs new file mode 100644 index 000000000000..f4586f417de9 --- /dev/null +++ b/crates/polars-ops/src/prelude.rs @@ -0,0 +1,11 @@ +#[allow(unused_imports)] +pub(crate) use {crate::series::*, rayon::prelude::*}; + +pub use crate::chunked_array::*; +#[cfg(feature = "merge_sorted")] +pub use crate::frame::_merge_sorted_dfs; +pub use crate::frame::join::*; +#[cfg(feature = "pivot")] +pub use crate::frame::pivot::UnpivotDF; +pub use crate::frame::{DataFrameJoinOps, DataFrameOps}; +pub use crate::series::*; diff --git a/crates/polars-ops/src/series/mod.rs b/crates/polars-ops/src/series/mod.rs new file mode 100644 index 000000000000..e940100af65a --- /dev/null +++ b/crates/polars-ops/src/series/mod.rs @@ -0,0 +1,2 @@ +mod ops; +pub use ops::*; diff --git a/crates/polars-ops/src/series/ops/abs.rs b/crates/polars-ops/src/series/ops/abs.rs new file mode 100644 index 000000000000..e93e3c13c60d --- /dev/null +++ b/crates/polars-ops/src/series/ops/abs.rs @@ -0,0 +1,37 @@ +use polars_core::prelude::*; + +/// Convert numerical values to their absolute value. +pub fn abs(s: &Series) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + #[cfg(feature = "dtype-i8")] + Int8 => s.i8().unwrap().wrapping_abs().into_series(), + #[cfg(feature = "dtype-i16")] + Int16 => s.i16().unwrap().wrapping_abs().into_series(), + Int32 => s.i32().unwrap().wrapping_abs().into_series(), + Int64 => s.i64().unwrap().wrapping_abs().into_series(), + #[cfg(feature = "dtype-i128")] + Int128 => s.i128().unwrap().wrapping_abs().into_series(), + Float32 => s.f32().unwrap().wrapping_abs().into_series(), + Float64 => s.f64().unwrap().wrapping_abs().into_series(), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => { + let ca = s.decimal().unwrap(); + let precision = ca.precision(); + let scale = ca.scale(); + + let out = ca.as_ref().wrapping_abs(); + out.into_decimal_unchecked(precision, scale).into_series() + }, + #[cfg(feature = "dtype-duration")] + Duration(_) => { + let physical = s.to_physical_repr(); + let ca = physical.i64().unwrap(); + let out = ca.wrapping_abs().into_series(); + out.cast(s.dtype())? + }, + dt if dt.is_unsigned_integer() => s.clone(), + dt => polars_bail!(opq = abs, dt), + }; + Ok(out) +} diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs new file mode 100644 index 000000000000..faba98eb93aa --- /dev/null +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -0,0 +1,378 @@ +use argminmax::ArgMinMax; +use arrow::array::Array; +use arrow::legacy::bit_util::*; +use polars_core::chunked_array::ops::float_sorted_arg_max::{ + float_arg_max_sorted_ascending, float_arg_max_sorted_descending, +}; +use polars_core::series::IsSorted; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +/// Argmin/ Argmax +pub trait ArgAgg { + /// Get the index of the minimal value + fn arg_min(&self) -> Option; + /// Get the index of the maximal value + fn arg_max(&self) -> Option; +} + +macro_rules! with_match_physical_numeric_polars_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use DataType::*; + match $key_type { + #[cfg(feature = "dtype-i8")] + Int8 => __with_ty__! { Int8Type }, + #[cfg(feature = "dtype-i16")] + Int16 => __with_ty__! { Int16Type }, + Int32 => __with_ty__! { Int32Type }, + Int64 => __with_ty__! { Int64Type }, + #[cfg(feature = "dtype-u8")] + UInt8 => __with_ty__! { UInt8Type }, + #[cfg(feature = "dtype-u16")] + UInt16 => __with_ty__! { UInt16Type }, + UInt32 => __with_ty__! { UInt32Type }, + UInt64 => __with_ty__! { UInt64Type }, + Float32 => __with_ty__! { Float32Type }, + Float64 => __with_ty__! { Float64Type }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + +impl ArgAgg for Series { + fn arg_min(&self) -> Option { + use DataType::*; + let s = self.to_physical_repr(); + match self.dtype() { + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) => { + let ca = self.categorical().unwrap(); + if ca.null_count() == ca.len() { + return None; + } + if ca.uses_lexical_ordering() { + ca.iter_str() + .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, val))) + .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) + .map(|tpl| tpl.0) + } else { + let ca = s.u32().unwrap(); + arg_min_numeric_dispatch(ca) + } + }, + String => { + let ca = self.str().unwrap(); + arg_min_str(ca) + }, + Boolean => { + let ca = self.bool().unwrap(); + arg_min_bool(ca) + }, + Date => { + let ca = s.i32().unwrap(); + arg_min_numeric_dispatch(ca) + }, + Datetime(_, _) | Duration(_) | Time => { + let ca = s.i64().unwrap(); + arg_min_numeric_dispatch(ca) + }, + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + arg_min_numeric_dispatch(ca) + }) + }, + _ => None, + } + } + + fn arg_max(&self) -> Option { + use DataType::*; + let s = self.to_physical_repr(); + match self.dtype() { + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) => { + let ca = self.categorical().unwrap(); + if ca.null_count() == ca.len() { + return None; + } + if ca.uses_lexical_ordering() { + ca.iter_str() + .enumerate() + .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) + .map(|tpl| tpl.0) + } else { + let ca_phys = s.u32().unwrap(); + arg_max_numeric_dispatch(ca_phys) + } + }, + String => { + let ca = self.str().unwrap(); + arg_max_str(ca) + }, + Boolean => { + let ca = self.bool().unwrap(); + arg_max_bool(ca) + }, + Date => { + let ca = s.i32().unwrap(); + arg_max_numeric_dispatch(ca) + }, + Datetime(_, _) | Duration(_) | Time => { + let ca = s.i64().unwrap(); + arg_max_numeric_dispatch(ca) + }, + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + arg_max_numeric_dispatch(ca) + }) + }, + _ => None, + } + } +} + +fn arg_max_numeric_dispatch(ca: &ChunkedArray) -> Option +where + T: PolarsNumericType, + for<'b> &'b [T::Native]: ArgMinMax, +{ + if ca.null_count() == ca.len() { + None + } else if T::get_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) { + arg_max_float_sorted(ca) + } else if let Ok(vals) = ca.cont_slice() { + arg_max_numeric_slice(vals, ca.is_sorted_flag()) + } else { + arg_max_numeric(ca) + } +} + +fn arg_min_numeric_dispatch(ca: &ChunkedArray) -> Option +where + T: PolarsNumericType, + for<'b> &'b [T::Native]: ArgMinMax, +{ + if ca.null_count() == ca.len() { + None + } else if let Ok(vals) = ca.cont_slice() { + arg_min_numeric_slice(vals, ca.is_sorted_flag()) + } else { + arg_min_numeric(ca) + } +} + +pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { + if ca.null_count() == ca.len() { + None + } + // don't check for any, that on itself is already an argmax search + else if ca.null_count() == 0 && ca.chunks().len() == 1 { + let arr = ca.downcast_iter().next().unwrap(); + let mask = arr.values(); + Some(first_set_bit(mask)) + } else { + let mut first_false_idx: Option = None; + ca.iter() + .enumerate() + .find_map(|(idx, val)| match val { + Some(true) => Some(idx), + Some(false) if first_false_idx.is_none() => { + first_false_idx = Some(idx); + None + }, + _ => None, + }) + .or(first_false_idx) + } +} + +/// # Safety +/// `ca` has a float dtype, has at least one non-null value and is sorted. +fn arg_max_float_sorted(ca: &ChunkedArray) -> Option +where + T: PolarsNumericType, +{ + let out = match ca.is_sorted_flag() { + IsSorted::Ascending => float_arg_max_sorted_ascending(ca), + IsSorted::Descending => float_arg_max_sorted_descending(ca), + _ => unreachable!(), + }; + + Some(out) +} + +fn arg_min_bool(ca: &BooleanChunked) -> Option { + if ca.null_count() == ca.len() { + None + } else if ca.null_count() == 0 && ca.chunks().len() == 1 { + let arr = ca.downcast_iter().next().unwrap(); + let mask = arr.values(); + Some(first_unset_bit(mask)) + } else { + let mut first_true_idx: Option = None; + ca.iter() + .enumerate() + .find_map(|(idx, val)| match val { + Some(false) => Some(idx), + Some(true) if first_true_idx.is_none() => { + first_true_idx = Some(idx); + None + }, + _ => None, + }) + .or(first_true_idx) + } +} + +fn arg_min_str(ca: &StringChunked) -> Option { + if ca.null_count() == ca.len() { + return None; + } + match ca.is_sorted_flag() { + IsSorted::Ascending => ca.first_non_null(), + IsSorted::Descending => ca.last_non_null(), + IsSorted::Not => ca + .iter() + .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, val))) + .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) + .map(|tpl| tpl.0), + } +} + +fn arg_max_str(ca: &StringChunked) -> Option { + if ca.null_count() == ca.len() { + return None; + } + match ca.is_sorted_flag() { + IsSorted::Ascending => ca.last_non_null(), + IsSorted::Descending => ca.first_non_null(), + IsSorted::Not => ca + .iter() + .enumerate() + .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) + .map(|tpl| tpl.0), + } +} + +fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray) -> Option +where + T: PolarsNumericType, + for<'b> &'b [T::Native]: ArgMinMax, +{ + match ca.is_sorted_flag() { + IsSorted::Ascending => ca.first_non_null(), + IsSorted::Descending => ca.last_non_null(), + IsSorted::Not => { + ca.downcast_iter() + .fold((None, None, 0), |acc, arr| { + if arr.len() == 0 { + return acc; + } + let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 { + arr.into_iter() + .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, *val))) + .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) + } else { + // When no nulls & array not empty => we can use fast argminmax + let min_idx: usize = arr.values().as_slice().argmin(); + Some((min_idx, arr.value(min_idx))) + }; + + let new_offset: usize = acc.2 + arr.len(); + match acc { + (Some(_), Some(acc_v), offset) => match chunk_min { + Some((idx, val)) if val < acc_v => { + (Some(idx + offset), Some(val), new_offset) + }, + _ => (acc.0, acc.1, new_offset), + }, + (None, None, offset) => match chunk_min { + Some((idx, val)) => (Some(idx + offset), Some(val), new_offset), + None => (None, None, new_offset), + }, + _ => unreachable!(), + } + }) + .0 + }, + } +} + +fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray) -> Option +where + T: PolarsNumericType, + for<'b> &'b [T::Native]: ArgMinMax, +{ + match ca.is_sorted_flag() { + IsSorted::Ascending => ca.last_non_null(), + IsSorted::Descending => ca.first_non_null(), + IsSorted::Not => { + ca.downcast_iter() + .fold((None, None, 0), |acc, arr| { + if arr.len() == 0 { + return acc; + } + let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 { + // When there are nulls, we should compare Option + arr.into_iter() + .enumerate() + .flat_map(|(idx, val)| val.map(|val| (idx, *val))) + .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) + } else { + // When no nulls & array not empty => we can use fast argminmax + let max_idx: usize = arr.values().as_slice().argmax(); + Some((max_idx, arr.value(max_idx))) + }; + + let new_offset: usize = acc.2 + arr.len(); + match acc { + (Some(_), Some(acc_v), offset) => match chunk_max { + Some((idx, val)) if acc_v < val => { + (Some(idx + offset), Some(val), new_offset) + }, + _ => (acc.0, acc.1, new_offset), + }, + (None, None, offset) => match chunk_max { + Some((idx, val)) => (Some(idx + offset), Some(val), new_offset), + None => (None, None, new_offset), + }, + _ => unreachable!(), + } + }) + .0 + }, + } +} + +fn arg_min_numeric_slice(vals: &[T], is_sorted: IsSorted) -> Option +where + for<'a> &'a [T]: ArgMinMax, +{ + match is_sorted { + // all vals are not null guarded by cont_slice + IsSorted::Ascending => Some(0), + // all vals are not null guarded by cont_slice + IsSorted::Descending => Some(vals.len() - 1), + IsSorted::Not => Some(vals.argmin()), // assumes not empty + } +} + +fn arg_max_numeric_slice(vals: &[T], is_sorted: IsSorted) -> Option +where + for<'a> &'a [T]: ArgMinMax, +{ + match is_sorted { + // all vals are not null guarded by cont_slice + IsSorted::Ascending => Some(vals.len() - 1), + // all vals are not null guarded by cont_slice + IsSorted::Descending => Some(0), + IsSorted::Not => Some(vals.argmax()), // assumes not empty + } +} diff --git a/crates/polars-ops/src/series/ops/bitwise.rs b/crates/polars-ops/src/series/ops/bitwise.rs new file mode 100644 index 000000000000..e83f3c53ef34 --- /dev/null +++ b/crates/polars-ops/src/series/ops/bitwise.rs @@ -0,0 +1,57 @@ +use polars_core::chunked_array::ChunkedArray; +use polars_core::chunked_array::ops::arity::unary_mut_values; +use polars_core::prelude::DataType; +use polars_core::series::Series; +use polars_core::{with_match_physical_float_polars_type, with_match_physical_integer_polars_type}; +use polars_error::{PolarsResult, polars_bail}; + +use super::*; + +macro_rules! apply_bitwise_op { + ($($op:ident),+ $(,)?) => { + $( + pub fn $op(s: &Series) -> PolarsResult { + match s.dtype() { + DataType::Boolean => { + let ca: &ChunkedArray = s.as_any().downcast_ref().unwrap(); + Ok(unary_mut_values::( + ca, + |a| polars_compute::bitwise::BitwiseKernel::$op(a), + ).into_series()) + }, + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_any().downcast_ref().unwrap(); + Ok(unary_mut_values::<$T, UInt32Type, _, _>( + ca, + |a| polars_compute::bitwise::BitwiseKernel::$op(a), + ).into_series()) + }) + }, + dt if dt.is_float() => { + with_match_physical_float_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_any().downcast_ref().unwrap(); + Ok(unary_mut_values::<$T, UInt32Type, _, _>( + ca, + |a| polars_compute::bitwise::BitwiseKernel::$op(a), + ).into_series()) + }) + }, + dt => { + polars_bail!(InvalidOperation: "dtype {:?} not supported in '{}' operation", dt, stringify!($op)) + }, + } + } + )+ + + }; +} + +apply_bitwise_op! { + count_ones, + count_zeros, + leading_ones, + leading_zeros, + trailing_ones, + trailing_zeros, +} diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs new file mode 100644 index 000000000000..272da63d0b1f --- /dev/null +++ b/crates/polars-ops/src/series/ops/business.rs @@ -0,0 +1,451 @@ +#[cfg(feature = "dtype-date")] +use chrono::DateTime; +use polars_core::prelude::arity::{binary_elementwise_values, try_binary_elementwise}; +use polars_core::prelude::*; +#[cfg(feature = "dtype-date")] +use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; +use polars_utils::binary_search::{find_first_ge_index, find_first_gt_index}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "timezones")] +use crate::prelude::replace_time_zone; + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Roll { + Forward, + Backward, + Raise, +} + +/// Count the number of business days between `start` and `end`, excluding `end`. +/// +/// # Arguments +/// - `start`: Series holding start dates. +/// - `end`: Series holding end dates. +/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day. +/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of +/// days since the UNIX epoch. +pub fn business_day_count( + start: &Series, + end: &Series, + week_mask: [bool; 7], + holidays: &[i32], +) -> PolarsResult { + if !week_mask.iter().any(|&x| x) { + polars_bail!(ComputeError:"`week_mask` must have at least one business day"); + } + + // Sort now so we can use `binary_search` in the hot for-loop. + let holidays = normalise_holidays(holidays, &week_mask); + let start_dates = start.date()?; + let end_dates = end.date()?; + let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32; + + let out = match (start_dates.len(), end_dates.len()) { + (_, 1) => { + if let Some(end_date) = end_dates.get(0) { + start_dates.apply_values(|start_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ) + }) + } else { + Int32Chunked::full_null(start_dates.name().clone(), start_dates.len()) + } + }, + (1, _) => { + if let Some(start_date) = start_dates.get(0) { + end_dates.apply_values(|end_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ) + }) + } else { + Int32Chunked::full_null(start_dates.name().clone(), end_dates.len()) + } + }, + _ => { + polars_ensure!( + start_dates.len() == end_dates.len(), + length_mismatch = "business_day_count", + start_dates.len(), + end_dates.len() + ); + binary_elementwise_values(start_dates, end_dates, |start_date, end_date| { + business_day_count_impl( + start_date, + end_date, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ) + }) + }, + }; + Ok(out.into_series()) +} + +/// Ported from: +/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L355-L433 +fn business_day_count_impl( + mut start_date: i32, + mut end_date: i32, + week_mask: &[bool; 7], + n_business_days_in_week_mask: i32, + holidays: &[i32], // Caller's responsibility to ensure it's sorted. +) -> i32 { + let swapped = start_date > end_date; + if swapped { + (start_date, end_date) = (end_date, start_date); + start_date += 1; + end_date += 1; + } + + let holidays_begin = find_first_ge_index(holidays, start_date); + let holidays_end = find_first_ge_index(&holidays[holidays_begin..], end_date) + holidays_begin; + let mut start_day_of_week = get_day_of_week(start_date); + let diff = end_date - start_date; + let whole_weeks = diff / 7; + let mut count = -((holidays_end - holidays_begin) as i32); + count += whole_weeks * n_business_days_in_week_mask; + start_date += whole_weeks * 7; + while start_date < end_date { + // SAFETY: week_mask is length 7, start_day_of_week is between 0 and 6 + if unsafe { *week_mask.get_unchecked(start_day_of_week) } { + count += 1; + } + start_date += 1; + start_day_of_week = increment_day_of_week(start_day_of_week); + } + if swapped { -count } else { count } +} + +/// Add a given number of business days. +/// +/// # Arguments +/// - `start`: Series holding start dates. +/// - `n`: Number of business days to add. +/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day. +/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of +/// days since the UNIX epoch. +/// - `roll`: what to do when the start date doesn't land on a business day: +/// - `Roll::Forward`: roll forward to the next business day. +/// - `Roll::Backward`: roll backward to the previous business day. +/// - `Roll::Raise`: raise an error. +pub fn add_business_days( + start: &Series, + n: &Series, + week_mask: [bool; 7], + holidays: &[i32], + roll: Roll, +) -> PolarsResult { + if !week_mask.iter().any(|&x| x) { + polars_bail!(ComputeError:"`week_mask` must have at least one business day"); + } + + match start.dtype() { + DataType::Date => {}, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(time_unit, None) => { + let result_date = + add_business_days(&start.cast(&DataType::Date)?, n, week_mask, holidays, roll)?; + let start_time = start + .cast(&DataType::Time)? + .cast(&DataType::Duration(*time_unit))?; + return std::ops::Add::add( + result_date.cast(&DataType::Datetime(*time_unit, None))?, + start_time, + ); + }, + #[cfg(feature = "timezones")] + DataType::Datetime(time_unit, Some(time_zone)) => { + let start_naive = replace_time_zone( + start.datetime().unwrap(), + None, + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )?; + let result_date = add_business_days( + &start_naive.cast(&DataType::Date)?, + n, + week_mask, + holidays, + roll, + )?; + let start_time = start_naive + .cast(&DataType::Time)? + .cast(&DataType::Duration(*time_unit))?; + let result_naive = std::ops::Add::add( + result_date.cast(&DataType::Datetime(*time_unit, None))?, + start_time, + )?; + let result_tz_aware = replace_time_zone( + result_naive.datetime().unwrap(), + Some(time_zone), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )?; + return Ok(result_tz_aware.into_series()); + }, + _ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", start.dtype()), + } + + // Sort now so we can use `binary_search` in the hot for-loop. + let holidays = normalise_holidays(holidays, &week_mask); + let start_dates = start.date()?; + let n = match &n.dtype() { + DataType::Int64 | DataType::UInt64 | DataType::UInt32 => n.cast(&DataType::Int32)?, + DataType::Int32 => n.clone(), + _ => { + polars_bail!(InvalidOperation: "expected Int64, Int32, UInt64, or UInt32, got {}", n.dtype()) + }, + }; + let n = n.i32()?; + let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32; + + let out: Int32Chunked = match (start_dates.len(), n.len()) { + (_, 1) => { + if let Some(n) = n.get(0) { + start_dates.try_apply_nonnull_values_generic(|start_date| { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + Ok::(add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + )) + })? + } else { + Int32Chunked::full_null(start_dates.name().clone(), start_dates.len()) + } + }, + (1, _) => { + if let Some(start_date) = start_dates.get(0) { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + n.apply_values(|n| { + add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ) + }) + } else { + Int32Chunked::full_null(start_dates.name().clone(), n.len()) + } + }, + _ => { + polars_ensure!( + start_dates.len() == n.len(), + length_mismatch = "dt.add_business_days", + start_dates.len(), + n.len() + ); + try_binary_elementwise(start_dates, n, |opt_start_date, opt_n| { + match (opt_start_date, opt_n) { + (Some(start_date), Some(n)) => { + let (start_date, day_of_week) = + roll_start_date(start_date, roll, &week_mask, &holidays)?; + Ok::, PolarsError>(Some(add_business_days_impl( + start_date, + day_of_week, + n, + &week_mask, + n_business_days_in_week_mask, + &holidays, + ))) + }, + _ => Ok(None), + } + })? + }, + }; + Ok(out.into_date().into_series()) +} + +/// Ported from: +/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L265-L353 +fn add_business_days_impl( + mut date: i32, + mut day_of_week: usize, + mut n: i32, + week_mask: &[bool; 7], + n_business_days_in_week_mask: i32, + holidays: &[i32], // Caller's responsibility to ensure it's sorted. +) -> i32 { + if n > 0 { + let holidays_begin = find_first_ge_index(holidays, date); + date += (n / n_business_days_in_week_mask) * 7; + n %= n_business_days_in_week_mask; + let holidays_temp = find_first_gt_index(&holidays[holidays_begin..], date) + holidays_begin; + n += (holidays_temp - holidays_begin) as i32; + let holidays_begin = holidays_temp; + while n > 0 { + date += 1; + day_of_week = increment_day_of_week(day_of_week); + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + if unsafe { + (*week_mask.get_unchecked(day_of_week)) + && (holidays[holidays_begin..].binary_search(&date).is_err()) + } { + n -= 1; + } + } + date + } else { + let holidays_end = find_first_gt_index(holidays, date); + date += (n / n_business_days_in_week_mask) * 7; + n %= n_business_days_in_week_mask; + let holidays_temp = find_first_ge_index(&holidays[..holidays_end], date); + n -= (holidays_end - holidays_temp) as i32; + let holidays_end = holidays_temp; + while n < 0 { + date -= 1; + day_of_week = decrement_day_of_week(day_of_week); + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + if unsafe { + (*week_mask.get_unchecked(day_of_week)) + && (holidays[..holidays_end].binary_search(&date).is_err()) + } { + n += 1; + } + } + date + } +} + +/// Determine if a day lands on a business day. +/// +/// # Arguments +/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day. +/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of +/// days since the UNIX epoch. +pub fn is_business_day( + dates: &Series, + week_mask: [bool; 7], + holidays: &[i32], +) -> PolarsResult { + if !week_mask.iter().any(|&x| x) { + polars_bail!(ComputeError:"`week_mask` must have at least one business day"); + } + + match dates.dtype() { + DataType::Date => {}, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, None) => { + return is_business_day(&dates.cast(&DataType::Date)?, week_mask, holidays); + }, + #[cfg(feature = "timezones")] + DataType::Datetime(_, Some(_)) => { + let dates_local = replace_time_zone( + dates.datetime().unwrap(), + None, + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )?; + return is_business_day(&dates_local.cast(&DataType::Date)?, week_mask, holidays); + }, + _ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", dates.dtype()), + } + + // Sort now so we can use `binary_search` in the hot for-loop. + let holidays = normalise_holidays(holidays, &week_mask); + let dates = dates.date()?; + let out: BooleanChunked = dates.apply_nonnull_values_generic(DataType::Boolean, |date| { + let day_of_week = get_day_of_week(date); + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + unsafe { (*week_mask.get_unchecked(day_of_week)) && holidays.binary_search(&date).is_err() } + }); + Ok(out.into_series()) +} + +fn roll_start_date( + mut date: i32, + roll: Roll, + week_mask: &[bool; 7], + holidays: &[i32], // Caller's responsibility to ensure it's sorted. +) -> PolarsResult<(i32, usize)> { + let mut day_of_week = get_day_of_week(date); + match roll { + Roll::Raise => { + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + if holidays.binary_search(&date).is_ok() + | unsafe { !*week_mask.get_unchecked(day_of_week) } + { + let date = DateTime::from_timestamp(date as i64 * SECONDS_IN_DAY, 0) + .unwrap() + .format("%Y-%m-%d"); + polars_bail!(ComputeError: + "date {} is not a business date; use `roll` to roll forwards (or backwards) to the next (or previous) valid date.", date + ) + }; + }, + Roll::Forward => { + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + while holidays.binary_search(&date).is_ok() + | unsafe { !*week_mask.get_unchecked(day_of_week) } + { + date += 1; + day_of_week = increment_day_of_week(day_of_week); + } + }, + Roll::Backward => { + // SAFETY: week_mask is length 7, day_of_week is between 0 and 6 + while holidays.binary_search(&date).is_ok() + | unsafe { !*week_mask.get_unchecked(day_of_week) } + { + date -= 1; + day_of_week = decrement_day_of_week(day_of_week); + } + }, + } + Ok((date, day_of_week)) +} + +/// Sort and deduplicate holidays and remove holidays that are not business days. +fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec { + let mut holidays: Vec = holidays.to_vec(); + holidays.sort_unstable(); + let mut previous_holiday: Option = None; + holidays.retain(|&x| { + // SAFETY: week_mask is length 7, get_day_of_week result is between 0 and 6 + if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(get_day_of_week(x)) } + { + return false; + } + previous_holiday = Some(x); + true + }); + holidays +} + +fn get_day_of_week(x: i32) -> usize { + // the first modulo might return a negative number, so we add 7 and take + // the modulo again so we're sure we have something between 0 (Monday) + // and 6 (Sunday) + (((x - 4) % 7 + 7) % 7) as usize +} + +fn increment_day_of_week(x: usize) -> usize { + if x == 6 { 0 } else { x + 1 } +} + +fn decrement_day_of_week(x: usize) -> usize { + if x == 0 { 6 } else { x - 1 } +} diff --git a/crates/polars-ops/src/series/ops/clip.rs b/crates/polars-ops/src/series/ops/clip.rs new file mode 100644 index 000000000000..2d7b21a05c93 --- /dev/null +++ b/crates/polars-ops/src/series/ops/clip.rs @@ -0,0 +1,210 @@ +use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise}; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +/// Set values outside the given boundaries to the boundary value. +pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult { + polars_ensure!( + s.dtype().to_physical().is_primitive_numeric(), + InvalidOperation: "`clip` only supports physical numeric types" + ); + let n = [s.len(), min.len(), max.len()] + .into_iter() + .find(|l| *l != 1) + .unwrap_or(1); + + for (i, (name, length)) in [("self", s.len()), ("min", min.len()), ("max", max.len())] + .into_iter() + .enumerate() + { + polars_ensure!( + length == n || length == 1, + length_mismatch = "clip", + length, + n, + argument = name, + argument_idx = i + ); + } + + let original_type = s.dtype(); + let (min, max) = (min.strict_cast(s.dtype())?, max.strict_cast(s.dtype())?); + + let (s, min, max) = ( + s.to_physical_repr(), + min.to_physical_repr(), + max.to_physical_repr(), + ); + + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); + let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); + let out = clip_helper_both_bounds(ca, min, max).into_series(); + match original_type { + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + let phys = out.i128()?.as_ref().clone(); + Ok(phys.into_decimal_unchecked(*precision, scale.unwrap()).into_series()) + }, + dt if dt.is_logical() => out.cast(original_type), + _ => Ok(out) + } + }) +} + +/// Set values above the given maximum to the maximum value. +pub fn clip_max(s: &Series, max: &Series) -> PolarsResult { + polars_ensure!( + s.dtype().to_physical().is_primitive_numeric(), + InvalidOperation: "`clip` only supports physical numeric types" + ); + polars_ensure!( + s.len() == max.len() || s.len() == 1 || max.len() == 1, + length_mismatch = "clip(max)", + s.len(), + max.len() + ); + + let original_type = s.dtype(); + let max = max.strict_cast(s.dtype())?; + + let (s, max) = (s.to_physical_repr(), max.to_physical_repr()); + + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref(); + let out = clip_helper_single_bound(ca, max, num_traits::clamp_max).into_series(); + match original_type { + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + let phys = out.i128()?.as_ref().clone(); + Ok(phys.into_decimal_unchecked(*precision, scale.unwrap()).into_series()) + }, + dt if dt.is_logical() => out.cast(original_type), + _ => Ok(out) + } + }) +} + +/// Set values below the given minimum to the minimum value. +pub fn clip_min(s: &Series, min: &Series) -> PolarsResult { + polars_ensure!( + s.dtype().to_physical().is_primitive_numeric(), + InvalidOperation: "`clip` only supports physical numeric types" + ); + polars_ensure!( + s.len() == min.len() || s.len() == 1 || min.len() == 1, + length_mismatch = "clip(min)", + s.len(), + min.len() + ); + + let original_type = s.dtype(); + let min = min.strict_cast(s.dtype())?; + + let (s, min) = (s.to_physical_repr(), min.to_physical_repr()); + + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref(); + let out = clip_helper_single_bound(ca, min, num_traits::clamp_min).into_series(); + match original_type { + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + let phys = out.i128()?.as_ref().clone(); + Ok(phys.into_decimal_unchecked(*precision, scale.unwrap()).into_series()) + }, + dt if dt.is_logical() => out.cast(original_type), + _ => Ok(out) + } + }) +} + +fn clip_helper_both_bounds( + ca: &ChunkedArray, + min: &ChunkedArray, + max: &ChunkedArray, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + match (min.len(), max.len()) { + (1, 1) => match (min.get(0), max.get(0)) { + (Some(min), Some(max)) => clip_unary(ca, |v| num_traits::clamp(v, min, max)), + (Some(min), None) => clip_unary(ca, |v| num_traits::clamp_min(v, min)), + (None, Some(max)) => clip_unary(ca, |v| num_traits::clamp_max(v, max)), + (None, None) => ca.clone(), + }, + (1, _) => match min.get(0) { + Some(min) => clip_binary(ca, max, |v, b| num_traits::clamp(v, min, b)), + None => clip_binary(ca, max, num_traits::clamp_max), + }, + (_, 1) => match max.get(0) { + Some(max) => clip_binary(ca, min, |v, b| num_traits::clamp(v, b, max)), + None => clip_binary(ca, min, num_traits::clamp_min), + }, + _ => clip_ternary(ca, min, max), + } +} + +fn clip_helper_single_bound( + ca: &ChunkedArray, + bound: &ChunkedArray, + op: F, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, + F: Fn(T::Native, T::Native) -> T::Native, +{ + match bound.len() { + 1 => match bound.get(0) { + Some(bound) => clip_unary(ca, |v| op(v, bound)), + None => ca.clone(), + }, + _ => clip_binary(ca, bound, op), + } +} + +fn clip_unary(ca: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsNumericType, + F: Fn(T::Native) -> T::Native + Copy, +{ + unary_elementwise(ca, |v| v.map(op)) +} + +fn clip_binary(ca: &ChunkedArray, bound: &ChunkedArray, op: F) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, + F: Fn(T::Native, T::Native) -> T::Native, +{ + binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) { + (Some(s), Some(bound)) => Some(op(s, bound)), + (Some(s), None) => Some(s), + (None, _) => None, + }) +} + +fn clip_ternary( + ca: &ChunkedArray, + min: &ChunkedArray, + max: &ChunkedArray, +) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + ternary_elementwise(ca, min, max, |opt_v, opt_min, opt_max| { + match (opt_v, opt_min, opt_max) { + (Some(v), Some(min), Some(max)) => Some(num_traits::clamp(v, min, max)), + (Some(v), Some(min), None) => Some(num_traits::clamp_min(v, min)), + (Some(v), None, Some(max)) => Some(num_traits::clamp_max(v, max)), + (Some(v), None, None) => Some(v), + (None, _, _) => None, + } + }) +} diff --git a/crates/polars-ops/src/series/ops/concat_arr.rs b/crates/polars-ops/src/series/ops/concat_arr.rs new file mode 100644 index 000000000000..dd0d0b00b824 --- /dev/null +++ b/crates/polars-ops/src/series/ops/concat_arr.rs @@ -0,0 +1,142 @@ +use arrow::array::FixedSizeListArray; +use arrow::compute::utils::combine_validities_and; +use polars_compute::horizontal_flatten::horizontal_flatten_unchecked; +use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn}; +use polars_core::series::Series; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::pl_str::PlSmallStr; + +/// Note: The caller must ensure all columns in `args` have the same type. +/// +/// # Panics +/// Panics if +/// * `args` is empty +/// * `dtype` is not a `DataType::Array` +pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult { + let DataType::Array(inner_dtype, width) = dtype else { + panic!("{}", dtype); + }; + + let inner_dtype = inner_dtype.as_ref(); + let width = *width; + + let mut output_height = args[0].len(); + let mut calculated_width = 0; + let mut mismatch_height = (&PlSmallStr::EMPTY, output_height); + // If there is a `Array` column with a single NULL, the output will be entirely NULL. + let mut return_all_null = false; + // Indicates whether all `arrays` have unit length (excluding zero-width arrays) + let mut all_unit_len = true; + let mut validities = Vec::with_capacity(args.len()); + + let (arrays, widths): (Vec<_>, Vec<_>) = args + .iter() + .map(|c| { + let len = c.len(); + + // Handle broadcasting + if output_height == 1 { + output_height = len; + mismatch_height.1 = len; + } + + if len != output_height && len != 1 && mismatch_height.1 == output_height { + mismatch_height = (c.name(), len); + } + + // Don't expand scalars to height, this is handled by the `horizontal_flatten` kernel. + let s = c.as_materialized_series_maintain_scalar(); + + match s.dtype() { + DataType::Array(inner, width) => { + debug_assert_eq!(inner.as_ref(), inner_dtype); + + let arr = s.array().unwrap().rechunk(); + let validity = arr.rechunk_validity(); + + return_all_null |= len == 1 && validity.as_ref().is_some_and(|x| !x.get_bit(0)); + + // Ignore unit-length validities. If they are non-valid then `return_all_null` will + // cause an early return. + if let Some(v) = validity.filter(|_| len > 1) { + validities.push(v) + } + + (arr.downcast_as_array().values().clone(), *width) + }, + dtype => { + debug_assert_eq!(dtype, inner_dtype); + // Note: We ignore the validity of non-array input columns, their outer is always valid after + // being reshaped to (-1, 1). + (s.rechunk().into_chunks()[0].clone(), 1) + }, + } + }) + // Filter out zero-width + .filter(|x| x.1 > 0) + .inspect(|x| { + calculated_width += x.1; + all_unit_len &= x.0.len() == 1; + }) + .unzip(); + + assert_eq!(calculated_width, width); + + if mismatch_height.1 != output_height { + polars_bail!( + ShapeMismatch: + "concat_arr: length of column '{}' (len={}) did not match length of \ + first column '{}' (len={})", + mismatch_height.0, mismatch_height.1, args[0].name(), output_height, + ) + } + + if return_all_null || output_height == 0 { + let arr = + FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height); + return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()); + } + + // Combine validities + let outer_validity = validities.into_iter().fold(None, |a, b| { + debug_assert_eq!(b.len(), output_height); + combine_validities_and(a.as_ref(), Some(&b)) + }); + + // At this point the output height and all arrays should have non-zero length + let out = if all_unit_len && width > 0 { + // Fast-path for all scalars + let inner_arr = unsafe { horizontal_flatten_unchecked(&arrays, &widths, 1) }; + + let arr = FixedSizeListArray::new( + dtype.to_arrow(CompatLevel::newest()), + 1, + inner_arr, + outer_validity, + ); + + return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr) + .into_column() + .new_from_index(0, output_height)); + } else { + let inner_arr = if width == 0 { + Series::new_empty(PlSmallStr::EMPTY, inner_dtype) + .into_chunks() + .into_iter() + .next() + .unwrap() + } else { + unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) } + }; + + let arr = FixedSizeListArray::new( + dtype.to_arrow(CompatLevel::newest()), + output_height, + inner_arr, + outer_validity, + ); + ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column() + }; + + Ok(out) +} diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs new file mode 100644 index 000000000000..ad6089e1d63d --- /dev/null +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -0,0 +1,341 @@ +use std::ops::{Add, AddAssign, Mul}; + +use arity::unary_elementwise_values; +use arrow::array::BooleanArray; +use arrow::bitmap::BitmapBuilder; +use num_traits::{Bounded, One, Zero}; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::{CustomIterTools, NoNull}; +use polars_core::with_match_physical_numeric_polars_type; + +fn det_max(state: &mut T, v: Option) -> Option> +where + T: Copy + PartialOrd + AddAssign + Add, +{ + match v { + Some(v) => { + if v > *state { + *state = v + } + Some(Some(*state)) + }, + None => Some(None), + } +} + +fn det_min(state: &mut T, v: Option) -> Option> +where + T: Copy + PartialOrd + AddAssign + Add, +{ + match v { + Some(v) => { + if v < *state { + *state = v + } + Some(Some(*state)) + }, + None => Some(None), + } +} + +fn det_sum(state: &mut T, v: Option) -> Option> +where + T: Copy + PartialOrd + AddAssign + Add, +{ + match v { + Some(v) => { + *state += v; + Some(Some(*state)) + }, + None => Some(None), + } +} + +fn det_prod(state: &mut T, v: Option) -> Option> +where + T: Copy + PartialOrd + Mul, +{ + match v { + Some(v) => { + *state = *state * v; + Some(Some(*state)) + }, + None => Some(None), + } +} + +fn cum_max_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = Bounded::min_value(); + + let out: ChunkedArray = match reverse { + false => ca.iter().scan(init, det_max).collect_trusted(), + true => ca.iter().rev().scan(init, det_max).collect_reversed(), + }; + out.with_name(ca.name().clone()) +} + +fn cum_min_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = Bounded::max_value(); + let out: ChunkedArray = match reverse { + false => ca.iter().scan(init, det_min).collect_trusted(), + true => ca.iter().rev().scan(init, det_min).collect_reversed(), + }; + out.with_name(ca.name().clone()) +} + +fn cum_max_bool(ca: &BooleanChunked, reverse: bool) -> BooleanChunked { + if ca.len() == ca.null_count() { + return ca.clone(); + } + + let mut out; + if !reverse { + // TODO: efficient bitscan. + let Some(first_true_idx) = ca.iter().position(|x| x == Some(true)) else { + return ca.clone(); + }; + out = BitmapBuilder::with_capacity(ca.len()); + out.extend_constant(first_true_idx, false); + out.extend_constant(ca.len() - first_true_idx, true); + } else { + // TODO: efficient bitscan. + let Some(last_true_idx) = ca.iter().rposition(|x| x == Some(true)) else { + return ca.clone(); + }; + out = BitmapBuilder::with_capacity(ca.len()); + out.extend_constant(last_true_idx + 1, true); + out.extend_constant(ca.len() - 1 - last_true_idx, false); + } + + let arr: BooleanArray = out.freeze().into(); + BooleanChunked::with_chunk_like(ca, arr.with_validity(ca.rechunk_validity())) +} + +fn cum_min_bool(ca: &BooleanChunked, reverse: bool) -> BooleanChunked { + if ca.len() == ca.null_count() { + return ca.clone(); + } + + let mut out; + if !reverse { + // TODO: efficient bitscan. + let Some(first_false_idx) = ca.iter().position(|x| x == Some(false)) else { + return ca.clone(); + }; + out = BitmapBuilder::with_capacity(ca.len()); + out.extend_constant(first_false_idx, true); + out.extend_constant(ca.len() - first_false_idx, false); + } else { + // TODO: efficient bitscan. + let Some(last_false_idx) = ca.iter().rposition(|x| x == Some(false)) else { + return ca.clone(); + }; + out = BitmapBuilder::with_capacity(ca.len()); + out.extend_constant(last_false_idx + 1, false); + out.extend_constant(ca.len() - 1 - last_false_idx, true); + } + + let arr: BooleanArray = out.freeze().into(); + BooleanChunked::with_chunk_like(ca, arr.with_validity(ca.rechunk_validity())) +} + +fn cum_sum_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = T::Native::zero(); + let out: ChunkedArray = match reverse { + false => ca.iter().scan(init, det_sum).collect_trusted(), + true => ca.iter().rev().scan(init, det_sum).collect_reversed(), + }; + out.with_name(ca.name().clone()) +} + +fn cum_prod_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = T::Native::one(); + let out: ChunkedArray = match reverse { + false => ca.iter().scan(init, det_prod).collect_trusted(), + true => ca.iter().rev().scan(init, det_prod).collect_reversed(), + }; + out.with_name(ca.name().clone()) +} + +/// Get an array with the cumulative product computed at every element. +/// +/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is +/// first cast to `Int64` to prevent overflow issues. +pub fn cum_prod(s: &Series, reverse: bool) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => { + let s = s.cast(&Int64)?; + cum_prod_numeric(s.i64()?, reverse).into_series() + }, + Int64 => cum_prod_numeric(s.i64()?, reverse).into_series(), + UInt64 => cum_prod_numeric(s.u64()?, reverse).into_series(), + #[cfg(feature = "dtype-i128")] + Int128 => cum_prod_numeric(s.i128()?, reverse).into_series(), + Float32 => cum_prod_numeric(s.f32()?, reverse).into_series(), + Float64 => cum_prod_numeric(s.f64()?, reverse).into_series(), + dt => polars_bail!(opq = cum_prod, dt), + }; + Ok(out) +} + +/// Get an array with the cumulative sum computed at every element +/// +/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is +/// first cast to `Int64` to prevent overflow issues. +pub fn cum_sum(s: &Series, reverse: bool) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + Boolean => { + let s = s.cast(&UInt32)?; + cum_sum_numeric(s.u32()?, reverse).into_series() + }, + Int8 | UInt8 | Int16 | UInt16 => { + let s = s.cast(&Int64)?; + cum_sum_numeric(s.i64()?, reverse).into_series() + }, + Int32 => cum_sum_numeric(s.i32()?, reverse).into_series(), + UInt32 => cum_sum_numeric(s.u32()?, reverse).into_series(), + Int64 => cum_sum_numeric(s.i64()?, reverse).into_series(), + UInt64 => cum_sum_numeric(s.u64()?, reverse).into_series(), + #[cfg(feature = "dtype-i128")] + Int128 => cum_sum_numeric(s.i128()?, reverse).into_series(), + Float32 => cum_sum_numeric(s.f32()?, reverse).into_series(), + Float64 => cum_sum_numeric(s.f64()?, reverse).into_series(), + #[cfg(feature = "dtype-decimal")] + Decimal(precision, scale) => { + let ca = s.decimal().unwrap().as_ref(); + cum_sum_numeric(ca, reverse) + .into_decimal_unchecked(*precision, scale.unwrap()) + .into_series() + }, + #[cfg(feature = "dtype-duration")] + Duration(tu) => { + let s = s.to_physical_repr(); + let ca = s.i64()?; + cum_sum_numeric(ca, reverse).cast(&Duration(*tu))? + }, + dt => polars_bail!(opq = cum_sum, dt), + }; + Ok(out) +} + +/// Get an array with the cumulative min computed at every element. +pub fn cum_min(s: &Series, reverse: bool) -> PolarsResult { + match s.dtype() { + DataType::Boolean => Ok(cum_min_bool(s.bool()?, reverse).into_series()), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + let ca = s.decimal().unwrap().as_ref(); + let out = cum_min_numeric(ca, reverse) + .into_decimal_unchecked(*precision, scale.unwrap()) + .into_series(); + Ok(out) + }, + dt if dt.to_physical().is_primitive_numeric() => { + let s = s.to_physical_repr(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let out = cum_min_numeric(ca, reverse).into_series(); + if dt.is_logical() { + out.cast(dt) + } else { + Ok(out) + } + }) + }, + dt => polars_bail!(opq = cum_min, dt), + } +} + +/// Get an array with the cumulative max computed at every element. +pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult { + match s.dtype() { + DataType::Boolean => Ok(cum_max_bool(s.bool()?, reverse).into_series()), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + let ca = s.decimal().unwrap().as_ref(); + let out = cum_max_numeric(ca, reverse) + .into_decimal_unchecked(*precision, scale.unwrap()) + .into_series(); + Ok(out) + }, + dt if dt.to_physical().is_primitive_numeric() => { + let s = s.to_physical_repr(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let out = cum_max_numeric(ca, reverse).into_series(); + if dt.is_logical() { + out.cast(dt) + } else { + Ok(out) + } + }) + }, + dt => polars_bail!(opq = cum_max, dt), + } +} + +pub fn cum_count(s: &Series, reverse: bool) -> PolarsResult { + let mut out = if s.null_count() == 0 { + // Fast paths for no nulls + cum_count_no_nulls(s.name().clone(), s.len(), reverse) + } else { + let ca = s.is_not_null(); + let out: IdxCa = if reverse { + let mut count = (s.len() - s.null_count()) as IdxSize; + let mut prev = false; + unary_elementwise_values(&ca, |v: bool| { + if prev { + count -= 1; + } + prev = v; + count + }) + } else { + let mut count = 0 as IdxSize; + unary_elementwise_values(&ca, |v: bool| { + if v { + count += 1; + } + count + }) + }; + + out.into() + }; + + out.set_sorted_flag([IsSorted::Ascending, IsSorted::Descending][reverse as usize]); + + Ok(out) +} + +fn cum_count_no_nulls(name: PlSmallStr, len: usize, reverse: bool) -> Series { + let start = 1 as IdxSize; + let end = len as IdxSize + 1; + let ca: NoNull = if reverse { + (start..end).rev().collect() + } else { + (start..end).collect() + }; + let mut ca = ca.into_inner(); + ca.rename(name); + ca.into_series() +} diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs new file mode 100644 index 000000000000..7016fd08c1f1 --- /dev/null +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -0,0 +1,194 @@ +use polars_compute::rolling::QuantileMethod; +use polars_core::prelude::*; +use polars_utils::format_pl_smallstr; + +fn map_cats( + s: &Series, + labels: &[PlSmallStr], + sorted_breaks: &[f64], + left_closed: bool, + include_breaks: bool, +) -> PolarsResult { + let out_name = PlSmallStr::from_static("category"); + + // Create new categorical and pre-register labels for consistent categorical indexes. + let mut bld = CategoricalChunkedBuilder::new(out_name.clone(), s.len(), Default::default()); + for label in labels { + bld.register_value(label); + } + + let s2 = s.cast(&DataType::Float64)?; + // It would be nice to parallelize this + let s_iter = s2.f64()?.into_iter(); + + let op = if left_closed { + PartialOrd::ge + } else { + PartialOrd::gt + }; + + // Ensure fast unique is only set if all labels were seen. + let mut label_has_value = vec![false; 1 + sorted_breaks.len()]; + + if include_breaks { + // This is to replicate the behavior of the old buggy version that only worked on series and + // returned a dataframe. That included a column of the right endpoint of the interval. So we + // return a struct series instead which can be turned into a dataframe later. + let right_ends = [sorted_breaks, &[f64::INFINITY]].concat(); + let mut brk_vals = PrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("breakpoint"), + s.len(), + ); + s_iter + .map(|opt| { + opt.filter(|x| !x.is_nan()).map(|x| { + let pt = sorted_breaks.partition_point(|v| op(&x, v)); + unsafe { *label_has_value.get_unchecked_mut(pt) = true }; + pt + }) + }) + .for_each(|idx| match idx { + None => { + bld.append_null(); + brk_vals.append_null(); + }, + Some(idx) => unsafe { + bld.append_value(labels.get_unchecked(idx)); + brk_vals.append_value(*right_ends.get_unchecked(idx)); + }, + }); + + let outvals = [brk_vals.finish().into_series(), unsafe { + bld.finish() + ._with_fast_unique(label_has_value.iter().all(bool::clone)) + .into_series() + }]; + Ok(StructChunked::from_series(out_name, outvals[0].len(), outvals.iter())?.into_series()) + } else { + Ok(unsafe { + bld.drain_iter_and_finish(s_iter.map(|opt| { + opt.filter(|x| !x.is_nan()).map(|x| { + let pt = sorted_breaks.partition_point(|v| op(&x, v)); + *label_has_value.get_unchecked_mut(pt) = true; + labels.get_unchecked(pt).as_str() + }) + })) + ._with_fast_unique(label_has_value.iter().all(bool::clone)) + } + .into_series()) + } +} + +pub fn compute_labels(breaks: &[f64], left_closed: bool) -> PolarsResult> { + let lo = std::iter::once(&f64::NEG_INFINITY).chain(breaks.iter()); + let hi = breaks.iter().chain(std::iter::once(&f64::INFINITY)); + + let ret = lo + .zip(hi) + .map(|(l, h)| { + if left_closed { + format_pl_smallstr!("[{}, {})", l, h) + } else { + format_pl_smallstr!("({}, {}]", l, h) + } + }) + .collect(); + Ok(ret) +} + +pub fn cut( + s: &Series, + mut breaks: Vec, + labels: Option>, + left_closed: bool, + include_breaks: bool, +) -> PolarsResult { + // Breaks must be sorted to cut inputs properly. + polars_ensure!(!breaks.iter().any(|x| x.is_nan()), ComputeError: "breaks cannot be NaN"); + breaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + polars_ensure!(breaks.windows(2).all(|x| x[0] != x[1]), Duplicate: "breaks are not unique"); + if !breaks.is_empty() { + polars_ensure!(breaks[0] > f64::NEG_INFINITY, ComputeError: "don't include -inf in breaks"); + polars_ensure!(breaks[breaks.len() - 1] < f64::INFINITY, ComputeError: "don't include inf in breaks"); + } + + let cut_labels = if let Some(l) = labels { + polars_ensure!(l.len() == breaks.len() + 1, ShapeMismatch: "provide len(quantiles) + 1 labels"); + l + } else { + compute_labels(&breaks, left_closed)? + }; + map_cats(s, &cut_labels, &breaks, left_closed, include_breaks) +} + +pub fn qcut( + s: &Series, + probs: Vec, + labels: Option>, + left_closed: bool, + allow_duplicates: bool, + include_breaks: bool, +) -> PolarsResult { + polars_ensure!(!probs.iter().any(|x| x.is_nan()), ComputeError: "quantiles cannot be NaN"); + + if s.null_count() == s.len() { + // If we only have nulls we don't have any breakpoints. + return Ok(Series::full_null( + s.name().clone(), + s.len(), + &DataType::Categorical(None, Default::default()), + )); + } + + let s = s.cast(&DataType::Float64)?; + let s2 = s.sort(SortOptions::default())?; + let ca = s2.f64()?; + + let f = |&p| ca.quantile(p, QuantileMethod::Linear).unwrap().unwrap(); + let mut qbreaks: Vec<_> = probs.iter().map(f).collect(); + qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + + if !allow_duplicates { + polars_ensure!(qbreaks.windows(2).all(|x| x[0] != x[1]), Duplicate: "quantiles are not unique while allow_duplicates=False"); + } + + let cut_labels = if let Some(l) = labels { + polars_ensure!(l.len() == qbreaks.len() + 1, ShapeMismatch: "provide len(quantiles) + 1 labels"); + l + } else { + compute_labels(&qbreaks, left_closed)? + }; + + map_cats(&s, &cut_labels, &qbreaks, left_closed, include_breaks) +} + +mod test { + // This need metadata in fields + #[ignore] + #[test] + fn test_map_cats_fast_unique() { + // This test is here to check the fast unique flag is set when it can be + // as it is not visible to Python. + use polars_core::prelude::*; + + use super::map_cats; + + let s = Series::new("x".into(), &[1, 2, 3, 4, 5]); + + let labels = &["a", "b", "c"].map(PlSmallStr::from_static); + let breaks = &[2.0, 4.0]; + let left_closed = false; + + let include_breaks = false; + let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap(); + let out = out.categorical().unwrap(); + assert!(out._can_fast_unique()); + + let include_breaks = true; + let out = map_cats(&s, labels, breaks, left_closed, include_breaks).unwrap(); + let out = out.struct_().unwrap().fields_as_series()[1].clone(); + let out = out.categorical().unwrap(); + assert!(out._can_fast_unique()); + } +} diff --git a/crates/polars-ops/src/series/ops/diff.rs b/crates/polars-ops/src/series/ops/diff.rs new file mode 100644 index 000000000000..49c20553d10a --- /dev/null +++ b/crates/polars-ops/src/series/ops/diff.rs @@ -0,0 +1,22 @@ +use polars_core::prelude::*; +use polars_core::series::ops::NullBehavior; + +pub fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult { + use DataType::*; + let s = match s.dtype() { + UInt8 => s.cast(&Int16)?, + UInt16 => s.cast(&Int32)?, + UInt32 | UInt64 => s.cast(&Int64)?, + _ => s.clone(), + }; + + match null_behavior { + NullBehavior::Ignore => &s - &s.shift(n), + NullBehavior::Drop => { + polars_ensure!(n > 0, InvalidOperation: "only positive integer allowed if nulls are dropped in 'diff' operation"); + let n = n as usize; + let len = s.len() - n; + &s.slice(n as i64, len) - &s.slice(0, len) + }, + } +} diff --git a/crates/polars-ops/src/series/ops/duration.rs b/crates/polars-ops/src/series/ops/duration.rs new file mode 100644 index 000000000000..497a806366bb --- /dev/null +++ b/crates/polars-ops/src/series/ops/duration.rs @@ -0,0 +1,91 @@ +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS, SECONDS_IN_DAY}; +use polars_core::datatypes::{AnyValue, DataType, TimeUnit}; +use polars_core::prelude::Column; +use polars_error::PolarsResult; + +pub fn impl_duration(s: &[Column], time_unit: TimeUnit) -> PolarsResult { + if s.iter().any(|s| s.is_empty()) { + return Ok(Column::new_empty( + s[0].name().clone(), + &DataType::Duration(time_unit), + )); + } + + // TODO: Handle overflow for UInt64 + let weeks = &s[0]; + let days = &s[1]; + let hours = &s[2]; + let minutes = &s[3]; + let seconds = &s[4]; + let mut milliseconds = s[5].clone(); + let mut microseconds = s[6].clone(); + let mut nanoseconds = s[7].clone(); + + let is_scalar = |s: &Column| s.len() == 1; + let is_zero_scalar = |s: &Column| is_scalar(s) && s.get(0).unwrap() == AnyValue::Int64(0); + + // Process subseconds + let max_len = s.iter().map(|s| s.len()).max().unwrap(); + let mut duration = match time_unit { + TimeUnit::Microseconds => { + if is_scalar(µseconds) { + microseconds = microseconds.new_from_index(0, max_len); + } + if !is_zero_scalar(&nanoseconds) { + microseconds = (microseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000)))?; + } + if !is_zero_scalar(&milliseconds) { + microseconds = (microseconds + milliseconds * 1_000)?; + } + microseconds + }, + TimeUnit::Nanoseconds => { + if is_scalar(&nanoseconds) { + nanoseconds = nanoseconds.new_from_index(0, max_len); + } + if !is_zero_scalar(µseconds) { + nanoseconds = (nanoseconds + microseconds * 1_000)?; + } + if !is_zero_scalar(&milliseconds) { + nanoseconds = (nanoseconds + milliseconds * 1_000_000)?; + } + nanoseconds + }, + TimeUnit::Milliseconds => { + if is_scalar(&milliseconds) { + milliseconds = milliseconds.new_from_index(0, max_len); + } + if !is_zero_scalar(&nanoseconds) { + milliseconds = (milliseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000_000)))?; + } + if !is_zero_scalar(µseconds) { + milliseconds = (milliseconds + (microseconds.wrapping_trunc_div_scalar(1_000)))?; + } + milliseconds + }, + }; + + // Process other duration specifiers + let multiplier = match time_unit { + TimeUnit::Nanoseconds => NANOSECONDS, + TimeUnit::Microseconds => MICROSECONDS, + TimeUnit::Milliseconds => MILLISECONDS, + }; + if !is_zero_scalar(seconds) { + duration = (duration + seconds * multiplier)?; + } + if !is_zero_scalar(minutes) { + duration = (duration + minutes * multiplier * 60)?; + } + if !is_zero_scalar(hours) { + duration = (duration + hours * multiplier * 60 * 60)?; + } + if !is_zero_scalar(days) { + duration = (duration + days * multiplier * SECONDS_IN_DAY)?; + } + if !is_zero_scalar(weeks) { + duration = (duration + weeks * multiplier * SECONDS_IN_DAY * 7)?; + } + + duration.cast(&DataType::Duration(time_unit)) +} diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs new file mode 100644 index 000000000000..d6fa9c31a044 --- /dev/null +++ b/crates/polars-ops/src/series/ops/ewm.rs @@ -0,0 +1,101 @@ +pub use arrow::legacy::kernels::ewm::EWMOptions; +use arrow::legacy::kernels::ewm::{ + ewm_mean as kernel_ewm_mean, ewm_std as kernel_ewm_std, ewm_var as kernel_ewm_var, +}; +use polars_core::prelude::*; + +fn check_alpha(alpha: f64) -> PolarsResult<()> { + polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]"); + Ok(()) +} + +pub fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_mean( + xs, + options.alpha as f32, + options.adjust, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_mean( + xs, + options.alpha, + options.adjust, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) + }, + _ => ewm_mean(&s.cast(&DataType::Float64)?, options), + } +} + +pub fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_std( + xs, + options.alpha as f32, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_std( + xs, + options.alpha, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) + }, + _ => ewm_std(&s.cast(&DataType::Float64)?, options), + } +} + +pub fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_var( + xs, + options.alpha as f32, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_var( + xs, + options.alpha, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) + }, + _ => ewm_var(&s.cast(&DataType::Float64)?, options), + } +} diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs new file mode 100644 index 000000000000..b1b6607e4747 --- /dev/null +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -0,0 +1,220 @@ +use bytemuck::allocation::zeroed_vec; +use num_traits::{Float, FromPrimitive, One, Zero}; +use polars_core::prelude::*; +use polars_core::utils::binary_concatenate_validities; + +pub fn ewm_mean_by( + s: &Series, + times: &Series, + half_life: i64, + times_is_sorted: bool, +) -> PolarsResult { + fn func( + values: &ChunkedArray, + times: &Int64Chunked, + half_life: i64, + times_is_sorted: bool, + ) -> PolarsResult + where + T: PolarsFloatType, + T::Native: Float + Zero + One, + ChunkedArray: IntoSeries, + { + if times_is_sorted { + Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series()) + } else { + Ok(ewm_mean_by_impl(values, times, half_life).into_series()) + } + } + + polars_ensure!( + s.len() == times.len(), + length_mismatch = "ewm_mean_by", + s.len(), + times.len() + ); + + match (s.dtype(), times.dtype()) { + (DataType::Float64, DataType::Int64) => func( + s.f64().unwrap(), + times.i64().unwrap(), + half_life, + times_is_sorted, + ), + (DataType::Float32, DataType::Int64) => func( + s.f32().unwrap(), + times.i64().unwrap(), + half_life, + times_is_sorted, + ), + #[cfg(feature = "dtype-datetime")] + (_, DataType::Datetime(time_unit, _)) => { + let half_life = adjust_half_life_to_time_unit(half_life, time_unit); + ewm_mean_by( + s, + ×.cast(&DataType::Int64)?, + half_life, + times_is_sorted, + ) + }, + #[cfg(feature = "dtype-date")] + (_, DataType::Date) => ewm_mean_by( + s, + ×.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, + half_life, + times_is_sorted, + ), + (_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by( + s, + ×.cast(&DataType::Int64)?, + half_life, + times_is_sorted, + ), + (DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => { + ewm_mean_by( + &s.cast(&DataType::Float64)?, + times, + half_life, + times_is_sorted, + ) + }, + _ => { + polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \ + Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \ + UInt64, or UInt32") + }, + } +} + +/// Sort on behalf of user +fn ewm_mean_by_impl( + values: &ChunkedArray, + times: &Int64Chunked, + half_life: i64, +) -> ChunkedArray +where + T: PolarsFloatType, + T::Native: Float + Zero + One, + ChunkedArray: ChunkTakeUnchecked, +{ + let sorting_indices = times.arg_sort(Default::default()); + let sorted_values = unsafe { values.take_unchecked(&sorting_indices) }; + let sorted_times = unsafe { times.take_unchecked(&sorting_indices) }; + let sorting_indices = sorting_indices + .cont_slice() + .expect("`arg_sort` should have returned a single chunk"); + + let mut out: Vec<_> = zeroed_vec(sorted_times.len()); + + let mut skip_rows: usize = 0; + let mut prev_time: i64 = 0; + let mut prev_result = T::Native::zero(); + for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() { + if let (Some(time), Some(value)) = (time, value) { + prev_time = time; + prev_result = value; + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = prev_result; + } + skip_rows = idx + 1; + break; + }; + } + sorted_values + .iter() + .zip(sorted_times.iter()) + .enumerate() + .skip(skip_rows) + .for_each(|(idx, (value, time))| { + if let (Some(time), Some(value)) = (time, value) { + let result = update(value, prev_result, time, prev_time, half_life); + prev_time = time; + prev_result = result; + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = result; + } + }; + }); + let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest())); + if (times.null_count() > 0) || (values.null_count() > 0) { + let validity = binary_concatenate_validities(times, values); + arr = arr.with_validity_typed(validity); + } + ChunkedArray::with_chunk(values.name().clone(), arr) +} + +/// Fastpath if `times` is known to already be sorted. +fn ewm_mean_by_impl_sorted( + values: &ChunkedArray, + times: &Int64Chunked, + half_life: i64, +) -> ChunkedArray +where + T: PolarsFloatType, + T::Native: Float + Zero + One, +{ + let mut out: Vec<_> = zeroed_vec(times.len()); + + let mut skip_rows: usize = 0; + let mut prev_time: i64 = 0; + let mut prev_result = T::Native::zero(); + for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() { + if let (Some(time), Some(value)) = (time, value) { + prev_time = time; + prev_result = value; + unsafe { + *out.get_unchecked_mut(idx) = prev_result; + } + skip_rows = idx + 1; + break; + } + } + values + .iter() + .zip(times.iter()) + .enumerate() + .skip(skip_rows) + .for_each(|(idx, (value, time))| { + if let (Some(time), Some(value)) = (time, value) { + let result = update(value, prev_result, time, prev_time, half_life); + prev_time = time; + prev_result = result; + unsafe { + *out.get_unchecked_mut(idx) = result; + } + }; + }); + let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest())); + if (times.null_count() > 0) || (values.null_count() > 0) { + let validity = binary_concatenate_validities(times, values); + arr = arr.with_validity_typed(validity); + } + ChunkedArray::with_chunk(values.name().clone(), arr) +} + +fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 { + match time_unit { + TimeUnit::Milliseconds => half_life / 1_000_000, + TimeUnit::Microseconds => half_life / 1_000, + TimeUnit::Nanoseconds => half_life, + } +} + +fn update(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T +where + T: Float + Zero + One + FromPrimitive, +{ + if value != prev_result { + let delta_time = time - prev_time; + // equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life) + let one_minus_alpha = T::from_f64(0.5) + .unwrap() + .powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap()); + let alpha = T::one() - one_minus_alpha; + alpha * value + one_minus_alpha * prev_result + } else { + value + } +} diff --git a/crates/polars-ops/src/series/ops/floor_divide.rs b/crates/polars-ops/src/series/ops/floor_divide.rs new file mode 100644 index 000000000000..68d66652af5a --- /dev/null +++ b/crates/polars-ops/src/series/ops/floor_divide.rs @@ -0,0 +1,56 @@ +use polars_compute::arithmetic::ArithmeticKernel; +use polars_core::chunked_array::ops::arity::apply_binary_kernel_broadcast; +use polars_core::prelude::*; +#[cfg(feature = "dtype-struct")] +use polars_core::series::arithmetic::_struct_arithmetic; +use polars_core::series::arithmetic::NumericListOp; +use polars_core::with_match_physical_numeric_polars_type; + +fn floor_div_ca( + lhs: &ChunkedArray, + rhs: &ChunkedArray, +) -> ChunkedArray { + apply_binary_kernel_broadcast( + lhs, + rhs, + |l, r| ArithmeticKernel::wrapping_floor_div(l.clone(), r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar_lhs(l, r.clone()), + |l, r| ArithmeticKernel::wrapping_floor_div_scalar(l.clone(), r), + ) +} + +pub fn floor_div_series(a: &Series, b: &Series) -> PolarsResult { + match (a.dtype(), b.dtype()) { + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), DataType::Struct(_)) => { + return _struct_arithmetic(a, b, floor_div_series); + }, + (DataType::List(_), _) | (_, DataType::List(_)) => { + return NumericListOp::floor_div().execute(a, b); + }, + #[cfg(feature = "dtype-array")] + (DataType::Array(..), _) | (_, DataType::Array(..)) => { + return polars_core::series::arithmetic::NumericFixedSizeListOp::floor_div() + .execute(a, b); + }, + _ => {}, + } + + if !a.dtype().is_primitive_numeric() { + polars_bail!(op = "floor_div", a.dtype()); + } + + let logical_type = a.dtype(); + + let a = a.to_physical_repr(); + let b = b.to_physical_repr(); + + let out = with_match_physical_numeric_polars_type!(a.dtype(), |$T| { + let a: &ChunkedArray<$T> = a.as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = b.as_ref().as_ref().as_ref(); + + floor_div_ca(a, b).into_series() + }); + + unsafe { out.from_physical_unchecked(logical_type) } +} diff --git a/crates/polars-ops/src/series/ops/fused.rs b/crates/polars-ops/src/series/ops/fused.rs new file mode 100644 index 000000000000..8132eda7c22a --- /dev/null +++ b/crates/polars-ops/src/series/ops/fused.rs @@ -0,0 +1,166 @@ +use arrow::array::PrimitiveArray; +use arrow::compute::utils::combine_validities_and3; +use polars_core::prelude::*; +use polars_core::utils::align_chunks_ternary; +use polars_core::with_match_physical_numeric_polars_type; + +// a + (b * c) +fn fma_arr( + a: &PrimitiveArray, + b: &PrimitiveArray, + c: &PrimitiveArray, +) -> PrimitiveArray { + assert_eq!(a.len(), b.len()); + let validity = combine_validities_and3(a.validity(), b.validity(), c.validity()); + let a = a.values().as_slice(); + let b = b.values().as_slice(); + let c = c.values().as_slice(); + + assert_eq!(a.len(), b.len()); + assert_eq!(b.len(), c.len()); + let out = a + .iter() + .zip(b.iter()) + .zip(c.iter()) + .map(|((a, b), c)| *a + (*b * *c)) + .collect::>(); + PrimitiveArray::from_data_default(out.into(), validity) +} + +fn fma_ca( + a: &ChunkedArray, + b: &ChunkedArray, + c: &ChunkedArray, +) -> ChunkedArray { + let (a, b, c) = align_chunks_ternary(a, b, c); + let chunks = a + .downcast_iter() + .zip(b.downcast_iter()) + .zip(c.downcast_iter()) + .map(|((a, b), c)| fma_arr(a, b, c)); + ChunkedArray::from_chunk_iter(a.name().clone(), chunks) +} + +pub fn fma_columns(a: &Column, b: &Column, c: &Column) -> Column { + if a.len() == b.len() && a.len() == c.len() { + with_match_physical_numeric_polars_type!(a.dtype(), |$T| { + let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref(); + let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref(); + + fma_ca(a, b, c).into_column() + }) + } else { + (a.as_materialized_series() + + &(b.as_materialized_series() * c.as_materialized_series()).unwrap()) + .unwrap() + .into() + } +} + +// a - (b * c) +fn fsm_arr( + a: &PrimitiveArray, + b: &PrimitiveArray, + c: &PrimitiveArray, +) -> PrimitiveArray { + assert_eq!(a.len(), b.len()); + let validity = combine_validities_and3(a.validity(), b.validity(), c.validity()); + let a = a.values().as_slice(); + let b = b.values().as_slice(); + let c = c.values().as_slice(); + + assert_eq!(a.len(), b.len()); + assert_eq!(b.len(), c.len()); + let out = a + .iter() + .zip(b.iter()) + .zip(c.iter()) + .map(|((a, b), c)| *a - (*b * *c)) + .collect::>(); + PrimitiveArray::from_data_default(out.into(), validity) +} + +fn fsm_ca( + a: &ChunkedArray, + b: &ChunkedArray, + c: &ChunkedArray, +) -> ChunkedArray { + let (a, b, c) = align_chunks_ternary(a, b, c); + let chunks = a + .downcast_iter() + .zip(b.downcast_iter()) + .zip(c.downcast_iter()) + .map(|((a, b), c)| fsm_arr(a, b, c)); + ChunkedArray::from_chunk_iter(a.name().clone(), chunks) +} + +pub fn fsm_columns(a: &Column, b: &Column, c: &Column) -> Column { + if a.len() == b.len() && a.len() == c.len() { + with_match_physical_numeric_polars_type!(a.dtype(), |$T| { + let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref(); + let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref(); + + fsm_ca(a, b, c).into_column() + }) + } else { + (a.as_materialized_series() + - &(b.as_materialized_series() * c.as_materialized_series()).unwrap()) + .unwrap() + .into() + } +} + +fn fms_arr( + a: &PrimitiveArray, + b: &PrimitiveArray, + c: &PrimitiveArray, +) -> PrimitiveArray { + assert_eq!(a.len(), b.len()); + let validity = combine_validities_and3(a.validity(), b.validity(), c.validity()); + let a = a.values().as_slice(); + let b = b.values().as_slice(); + let c = c.values().as_slice(); + + assert_eq!(a.len(), b.len()); + assert_eq!(b.len(), c.len()); + let out = a + .iter() + .zip(b.iter()) + .zip(c.iter()) + .map(|((a, b), c)| (*a * *b) - *c) + .collect::>(); + PrimitiveArray::from_data_default(out.into(), validity) +} + +fn fms_ca( + a: &ChunkedArray, + b: &ChunkedArray, + c: &ChunkedArray, +) -> ChunkedArray { + let (a, b, c) = align_chunks_ternary(a, b, c); + let chunks = a + .downcast_iter() + .zip(b.downcast_iter()) + .zip(c.downcast_iter()) + .map(|((a, b), c)| fms_arr(a, b, c)); + ChunkedArray::from_chunk_iter(a.name().clone(), chunks) +} + +pub fn fms_columns(a: &Column, b: &Column, c: &Column) -> Column { + if a.len() == b.len() && a.len() == c.len() { + with_match_physical_numeric_polars_type!(a.dtype(), |$T| { + let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref(); + let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref(); + + fms_ca(a, b, c).into_column() + }) + } else { + (&(a.as_materialized_series() * b.as_materialized_series()).unwrap() + - c.as_materialized_series()) + .unwrap() + .into() + } +} diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs new file mode 100644 index 000000000000..3394092953eb --- /dev/null +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -0,0 +1,410 @@ +use std::borrow::Cow; + +use polars_core::chunked_array::cast::CastOptions; +use polars_core::prelude::*; +use polars_core::series::arithmetic::coerce_lhs_rhs; +use polars_core::utils::dtypes_to_supertype; +use polars_core::{POOL, with_match_physical_numeric_polars_type}; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> { + let mut length = 1; + for c in cs { + let len = c.len(); + if len != 1 && len != length { + if length == 1 { + length = len; + } else { + polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})"); + } + } + } + Ok(()) +} + +pub trait MinMaxHorizontal { + /// Aggregate the column horizontally to their min values. + fn min_horizontal(&self) -> PolarsResult>; + /// Aggregate the column horizontally to their max values. + fn max_horizontal(&self) -> PolarsResult>; +} + +impl MinMaxHorizontal for DataFrame { + fn min_horizontal(&self) -> PolarsResult> { + min_horizontal(self.get_columns()) + } + fn max_horizontal(&self) -> PolarsResult> { + max_horizontal(self.get_columns()) + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum NullStrategy { + Ignore, + Propagate, +} + +pub trait SumMeanHorizontal { + /// Sum all values horizontally across columns. + fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult>; + + /// Compute the mean of all numeric values horizontally across columns. + fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult>; +} + +impl SumMeanHorizontal for DataFrame { + fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { + sum_horizontal(self.get_columns(), null_strategy) + } + fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { + mean_horizontal(self.get_columns(), null_strategy) + } +} + +fn min_binary(left: &ChunkedArray, right: &ChunkedArray) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + let op = |l: T::Native, r: T::Native| { + if l < r { l } else { r } + }; + arity::binary_elementwise_values(left, right, op) +} + +fn max_binary(left: &ChunkedArray, right: &ChunkedArray) -> ChunkedArray +where + T: PolarsNumericType, + T::Native: PartialOrd, +{ + let op = |l: T::Native, r: T::Native| { + if l > r { l } else { r } + }; + arity::binary_elementwise_values(left, right, op) +} + +fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult { + if left.dtype().to_physical().is_primitive_numeric() + && right.dtype().to_physical().is_primitive_numeric() + && left.null_count() == 0 + && right.null_count() == 0 + && left.len() == right.len() + { + match (left, right) { + (Column::Series(left), Column::Series(right)) => { + let (lhs, rhs) = coerce_lhs_rhs(left, right)?; + let logical = lhs.dtype(); + let lhs = lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); + + with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| { + let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + + unsafe { + if min { + min_binary(a, b).into_series().from_physical_unchecked(logical) + } else { + max_binary(a, b).into_series().from_physical_unchecked(logical) + } + } + }) + .map(Column::from) + }, + _ => { + let mask = if min { + left.lt(right)? + } else { + left.gt(right)? + }; + + left.zip_with(&mask, right) + }, + } + } else { + let mask = if min { + left.lt(right)? & left.is_not_null() | right.is_null() + } else { + left.gt(right)? & left.is_not_null() | right.is_null() + }; + left.zip_with(&mask, right) + } +} + +pub fn max_horizontal(columns: &[Column]) -> PolarsResult> { + validate_column_lengths(columns)?; + + let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false); + + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(columns[0].clone())), + 2 => max_fn(&columns[0], &columns[1]).map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + POOL.install(|| { + columns + .par_iter() + .map(|s| Ok(Cow::Borrowed(s))) + .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned)) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 3 columns + .unwrap() + .map(|cow| Some(cow.into_owned())) + }) + }, + } +} + +pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { + validate_column_lengths(columns)?; + + let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true); + + match columns.len() { + 0 => Ok(None), + 1 => Ok(Some(columns[0].clone())), + 2 => min_fn(&columns[0], &columns[1]).map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + POOL.install(|| { + columns + .par_iter() + .map(|s| Ok(Cow::Borrowed(s))) + .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned)) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 3 columns + .unwrap() + .map(|cow| Some(cow.into_owned())) + }) + }, + } +} + +pub fn sum_horizontal( + columns: &[Column], + null_strategy: NullStrategy, +) -> PolarsResult> { + validate_column_lengths(columns)?; + let ignore_nulls = null_strategy == NullStrategy::Ignore; + + let apply_null_strategy = |s: Series| -> PolarsResult { + if ignore_nulls && s.null_count() > 0 { + s.fill_null(FillNullStrategy::Zero) + } else { + Ok(s) + } + }; + + let sum_fn = |acc: Series, s: Series| -> PolarsResult { + let acc: Series = apply_null_strategy(acc)?; + let s = apply_null_strategy(s)?; + // This will do owned arithmetic and can be mutable + std::ops::Add::add(acc, s) + }; + + // @scalar-opt + let non_null_cols = columns + .iter() + .filter(|x| x.dtype() != &DataType::Null) + .map(|c| c.as_materialized_series()) + .collect::>(); + + // If we have any null columns and null strategy is not `Ignore`, we can return immediately. + if !ignore_nulls && non_null_cols.len() < columns.len() { + // We must determine the correct return dtype. + let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? { + DataType::Boolean => IDX_DTYPE, + dt => dt, + }; + return Ok(Some(Column::full_null( + columns[0].name().clone(), + columns[0].len(), + &return_dtype, + ))); + } + + match non_null_cols.len() { + 0 => { + if columns.is_empty() { + Ok(None) + } else { + // all columns are null dtype, so result is null dtype + Ok(Some(columns[0].clone())) + } + }, + 1 => Ok(Some( + apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean { + non_null_cols[0].cast(&IDX_DTYPE)? + } else { + non_null_cols[0].clone() + })? + .into(), + )), + 2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone()) + .map(Column::from) + .map(Some), + _ => { + // the try_reduce_with is a bit slower in parallelism, + // but I don't think it matters here as we parallelize over columns, not over elements + let out = POOL.install(|| { + non_null_cols + .into_par_iter() + .cloned() + .map(Ok) + .try_reduce_with(sum_fn) + // We can unwrap because we started with at least 3 columns, so we always get a Some + .unwrap() + }); + out.map(Column::from).map(Some) + }, + } +} + +pub fn mean_horizontal( + columns: &[Column], + null_strategy: NullStrategy, +) -> PolarsResult> { + validate_column_lengths(columns)?; + + let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| { + let dtype = s.dtype(); + dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null() + }); + + if !non_numeric_columns.is_empty() { + let col = non_numeric_columns.first().cloned(); + polars_bail!( + InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})", + col.unwrap().name(), + col.unwrap().dtype(), + ); + } + let columns = numeric_columns.into_iter().cloned().collect::>(); + let num_rows = columns.len(); + match num_rows { + 0 => Ok(None), + 1 => Ok(Some(match columns[0].dtype() { + dt if dt != &DataType::Float32 && !dt.is_decimal() => { + columns[0].cast(&DataType::Float64)? + }, + _ => columns[0].clone(), + })), + _ => { + let sum = || sum_horizontal(columns.as_slice(), null_strategy); + let null_count = || { + columns + .par_iter() + .map(|c| { + c.is_null() + .into_column() + .cast_with_options(&DataType::UInt32, CastOptions::NonStrict) + }) + .reduce_with(|l, r| { + let l = l?; + let r = r?; + let result = std::ops::Add::add(&l, &r)?; + PolarsResult::Ok(result) + }) + // we can unwrap the option, because we are certain there is a column + // we started this operation on 2 columns + .unwrap() + }; + + let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count)); + let sum = sum?; + let null_count = null_count?; + + // value lengths: len - null_count + let value_length: UInt32Chunked = (Column::new_scalar( + PlSmallStr::EMPTY, + Scalar::from(num_rows as u32), + null_count.len(), + ) - null_count)? + .u32() + .unwrap() + .clone(); + + // make sure that we do not divide by zero + // by replacing with None + let dt = if sum + .as_ref() + .is_some_and(|s| s.dtype() == &DataType::Float32) + { + &DataType::Float32 + } else { + &DataType::Float64 + }; + let value_length = value_length + .set(&value_length.equal(0), None)? + .into_column() + .cast(dt)?; + + sum.map(|sum| std::ops::Div::div(&sum, &value_length)) + .transpose() + }, + } +} + +pub fn coalesce_columns(s: &[Column]) -> PolarsResult { + // TODO! this can be faster if we have more than two inputs. + polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list"); + let mut out = s[0].clone(); + for s in s { + if !out.null_count() == 0 { + return Ok(out); + } else { + let mask = out.is_not_null(); + out = out + .as_materialized_series() + .zip_with_same_type(&mask, s.as_materialized_series())? + .into(); + } + } + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(miri, ignore)] + fn test_horizontal_agg() { + let a = Column::new("a".into(), [1, 2, 6]); + let b = Column::new("b".into(), [Some(1), None, None]); + let c = Column::new("c".into(), [Some(4), None, Some(3)]); + + let df = DataFrame::new(vec![a, b, c]).unwrap(); + assert_eq!( + Vec::from( + df.mean_horizontal(NullStrategy::Ignore) + .unwrap() + .unwrap() + .f64() + .unwrap() + ), + &[Some(2.0), Some(2.0), Some(4.5)] + ); + assert_eq!( + Vec::from( + df.sum_horizontal(NullStrategy::Ignore) + .unwrap() + .unwrap() + .i32() + .unwrap() + ), + &[Some(6), Some(2), Some(9)] + ); + assert_eq!( + Vec::from(df.min_horizontal().unwrap().unwrap().i32().unwrap()), + &[Some(1), Some(2), Some(3)] + ); + assert_eq!( + Vec::from(df.max_horizontal().unwrap().unwrap().i32().unwrap()), + &[Some(4), Some(2), Some(6)] + ); + } +} diff --git a/crates/polars-ops/src/series/ops/index.rs b/crates/polars-ops/src/series/ops/index.rs new file mode 100644 index 000000000000..20746614037f --- /dev/null +++ b/crates/polars-ops/src/series/ops/index.rs @@ -0,0 +1,108 @@ +use num_traits::{Signed, Zero}; +use polars_core::error::{PolarsResult, polars_ensure}; +use polars_core::prelude::arity::unary_elementwise_values; +use polars_core::prelude::{ + ChunkedArray, Column, DataType, IDX_DTYPE, IdxCa, PolarsIntegerType, Series, +}; +use polars_utils::index::ToIdx; + +fn convert(ca: &ChunkedArray, target_len: usize) -> PolarsResult +where + T: PolarsIntegerType, + T::Native: ToIdx, +{ + let target_len = target_len as u64; + Ok(unary_elementwise_values(ca, |v| v.to_idx(target_len))) +} + +pub fn convert_to_unsigned_index(s: &Series, target_len: usize) -> PolarsResult { + let dtype = s.dtype(); + polars_ensure!(dtype.is_integer(), InvalidOperation: "expected integers as index"); + if dtype.is_unsigned_integer() { + let nulls_before_cast = s.null_count(); + let out = s.cast(&IDX_DTYPE).unwrap(); + polars_ensure!(out.null_count() == nulls_before_cast, OutOfBounds: "some integers did not fit polars' index size"); + return Ok(out.idx().unwrap().clone()); + } + match dtype { + DataType::Int64 => { + let ca = s.i64().unwrap(); + convert(ca, target_len) + }, + DataType::Int32 => { + let ca = s.i32().unwrap(); + convert(ca, target_len) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + let ca = s.i16().unwrap(); + convert(ca, target_len) + }, + #[cfg(feature = "dtype-i8")] + DataType::Int8 => { + let ca = s.i8().unwrap(); + convert(ca, target_len) + }, + _ => unreachable!(), + } +} + +/// May give false negatives because it ignores the null values. +fn is_positive_idx_uncertain_impl(ca: &ChunkedArray) -> bool +where + T: PolarsIntegerType, + T::Native: Signed, +{ + ca.downcast_iter().all(|v| { + let values = v.values(); + let mut all_positive = true; + + // process chunks to autovec but still have early return + for chunk in values.chunks(1024) { + for v in chunk.iter() { + all_positive &= v.is_positive() | v.is_zero() + } + if !all_positive { + return all_positive; + } + } + all_positive + }) +} + +/// May give false negatives because it ignores the null values. +pub fn is_positive_idx_uncertain(s: &Series) -> bool { + let dtype = s.dtype(); + debug_assert!(dtype.is_integer(), "expected integers as index"); + if dtype.is_unsigned_integer() { + return true; + } + match dtype { + DataType::Int64 => { + let ca = s.i64().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + DataType::Int32 => { + let ca = s.i32().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + #[cfg(feature = "dtype-i16")] + DataType::Int16 => { + let ca = s.i16().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + #[cfg(feature = "dtype-i8")] + DataType::Int8 => { + let ca = s.i8().unwrap(); + is_positive_idx_uncertain_impl(ca) + }, + _ => unreachable!(), + } +} + +/// May give false negatives because it ignores the null values. +pub fn is_positive_idx_uncertain_col(c: &Column) -> bool { + // @scalar-opt + // @partition-opt + is_positive_idx_uncertain(c.as_materialized_series()) +} diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs new file mode 100644 index 000000000000..ee7f74d3dd70 --- /dev/null +++ b/crates/polars-ops/src/series/ops/index_of.rs @@ -0,0 +1,121 @@ +use arrow::array::{BinaryArray, PrimitiveArray}; +use polars_core::downcast_as_macro_arg_physical; +use polars_core::prelude::*; +use polars_utils::total_ord::TotalEq; +use row_encode::encode_rows_unordered; + +/// Find the index of the value, or ``None`` if it can't be found. +fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray

, value: AR::ValueT<'a>) -> Option +where + DT: PolarsDataType, + AR: StaticArray, + AR::ValueT<'a>: TotalEq, +{ + let req_value = &value; + let mut index = 0; + for chunk in ca.chunks() { + let chunk = chunk.as_any().downcast_ref::().unwrap(); + if chunk.validity().is_some() { + for maybe_value in chunk.iter() { + if maybe_value.map(|v| v.tot_eq(req_value)) == Some(true) { + return Some(index); + } else { + index += 1; + } + } + } else { + // A lack of a validity bitmap means there are no nulls, so we + // can simplify our logic and use a faster code path: + for value in chunk.values_iter() { + if value.tot_eq(req_value) { + return Some(index); + } else { + index += 1; + } + } + } + } + None +} + +fn index_of_numeric_value(ca: &ChunkedArray, value: T::Native) -> Option +where + T: PolarsNumericType, +{ + index_of_value::<_, PrimitiveArray>(ca, value) +} + +/// Try casting the value to the correct type, then call +/// index_of_numeric_value(). +macro_rules! try_index_of_numeric_ca { + ($ca:expr, $value:expr) => {{ + let ca = $ca; + let value = $value; + // extract() returns None if casting failed, so consider an extract() + // failure as not finding the value. Nulls should have been handled + // earlier. + let value = value.value().extract().unwrap(); + index_of_numeric_value(ca, value) + }}; +} + +/// Find the index of a given value (the first and only entry in `value_series`) +/// within the series. +pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult> { + polars_ensure!( + series.dtype() == needle.dtype(), + InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}", + series.dtype(), + needle.dtype(), + ); + + // Series is null: + if series.dtype().is_null() { + if needle.is_null() { + return Ok((!series.is_empty()).then_some(0)); + } else { + return Ok(None); + } + } + + // Series is not null, and the value is null: + if needle.is_null() { + let mut index = 0; + for chunk in series.chunks() { + let length = chunk.len(); + if let Some(bitmap) = chunk.validity() { + let leading_ones = bitmap.leading_ones(); + if leading_ones < length { + return Ok(Some(index + leading_ones)); + } + } else { + index += length; + } + } + return Ok(None); + } + + if series.dtype().is_primitive_numeric() { + return Ok(downcast_as_macro_arg_physical!( + series, + try_index_of_numeric_ca, + needle + )); + } + + if series.dtype().is_categorical() { + // See https://github.com/pola-rs/polars/issues/20318 + polars_bail!(InvalidOperation: "index_of() on Categoricals is not supported"); + } + + // For non-numeric dtypes, we convert to row-encoding, which essentially has + // us searching the physical representation of the data as a series of + // bytes. + let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1); + let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?; + let value = value_as_row_encoded_ca + .first() + .expect("Shouldn't have nulls in a row-encoded result"); + let ca = encode_rows_unordered(&[series.clone().into()])?; + Ok(index_of_value::<_, BinaryArray>(&ca, value)) +} diff --git a/crates/polars-ops/src/series/ops/int_range.rs b/crates/polars-ops/src/series/ops/int_range.rs new file mode 100644 index 000000000000..5e5a3d419acb --- /dev/null +++ b/crates/polars-ops/src/series/ops/int_range.rs @@ -0,0 +1,35 @@ +use polars_core::prelude::*; +use polars_core::series::IsSorted; + +pub fn new_int_range( + start: T::Native, + end: T::Native, + step: i64, + name: PlSmallStr, +) -> PolarsResult +where + T: PolarsIntegerType, + ChunkedArray: IntoSeries, + std::ops::Range: DoubleEndedIterator, +{ + let mut ca = match step { + 0 => polars_bail!(InvalidOperation: "step must not be zero"), + 1 => ChunkedArray::::from_iter_values(name, start..end), + 2.. => ChunkedArray::::from_iter_values(name, (start..end).step_by(step as usize)), + _ => ChunkedArray::::from_iter_values( + name, + (end..start) + .step_by(step.unsigned_abs() as usize) + .map(|x| start - (x - end)), + ), + }; + + let is_sorted = if end < start { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(is_sorted); + + Ok(ca.into_series()) +} diff --git a/crates/polars-ops/src/series/ops/interpolation/interpolate.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate.rs new file mode 100644 index 000000000000..c8f4fdc06e8e --- /dev/null +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate.rs @@ -0,0 +1,312 @@ +use std::ops::{Add, Div, Mul, Sub}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::MutableBitmap; +use num_traits::{NumCast, Zero}; +use polars_core::downcast_as_macro_arg_physical; +use polars_core::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::{linear_itp, nearest_itp}; + +fn near_interp(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec) +where + T: Sub + + Mul + + Add + + Div + + NumCast + + Copy + + PartialOrd, +{ + let diff = high - low; + for step_i in 1..steps { + let step_i: T = NumCast::from(step_i).unwrap(); + let v = nearest_itp(low, step_i, diff, steps_n); + out.push(v) + } +} + +#[inline] +fn signed_interp(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec) +where + T: Sub + Mul + Add + Div + NumCast + Copy, +{ + let slope = (high - low) / steps_n; + for step_i in 1..steps { + let step_i: T = NumCast::from(step_i).unwrap(); + let v = linear_itp(low, step_i, slope); + out.push(v) + } +} + +fn interpolate_impl(chunked_arr: &ChunkedArray, interpolation_branch: I) -> ChunkedArray +where + T: PolarsNumericType, + I: Fn(T::Native, T::Native, IdxSize, T::Native, &mut Vec), +{ + // This implementation differs from pandas as that boundary None's are not removed. + // This prevents a lot of errors due to expressions leading to different lengths. + if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() { + return chunked_arr.clone(); + } + + // We first find the first and last so that we can set the null buffer. + let first = chunked_arr.first_non_null().unwrap(); + let last = chunked_arr.last_non_null().unwrap() + 1; + + // Fill out with `first` nulls. + let mut out = Vec::with_capacity(chunked_arr.len()); + let mut iter = chunked_arr.iter().skip(first); + for _ in 0..first { + out.push(Zero::zero()); + } + + // The next element of `iter` is definitely `Some(Some(v))`, because we skipped the first + // elements `first` and if all values were missing we'd have done an early return. + let mut low = iter.next().unwrap().unwrap(); + out.push(low); + while let Some(next) = iter.next() { + if let Some(v) = next { + out.push(v); + low = v; + } else { + let mut steps = 1 as IdxSize; + for next in iter.by_ref() { + steps += 1; + if let Some(high) = next { + let steps_n: T::Native = NumCast::from(steps).unwrap(); + interpolation_branch(low, high, steps, steps_n, &mut out); + out.push(high); + low = high; + break; + } + } + } + } + if first != 0 || last != chunked_arr.len() { + let mut validity = MutableBitmap::with_capacity(chunked_arr.len()); + validity.extend_constant(chunked_arr.len(), true); + + for i in 0..first { + unsafe { validity.set_unchecked(i, false) }; + } + + for i in last..chunked_arr.len() { + unsafe { validity.set_unchecked(i, false) }; + out.push(Zero::zero()) + } + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + out.into(), + Some(validity.into()), + ); + ChunkedArray::with_chunk(chunked_arr.name().clone(), array) + } else { + ChunkedArray::from_vec(chunked_arr.name().clone(), out) + } +} + +fn interpolate_nearest(s: &Series) -> Series { + match s.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(), + DataType::Binary => s.clone(), + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => s.clone(), + DataType::List(_) => s.clone(), + _ => { + let logical = s.dtype(); + let s = s.to_physical_repr(); + + macro_rules! dispatch { + ($ca:expr) => {{ interpolate_impl($ca, near_interp).into_series() }}; + } + let out = downcast_as_macro_arg_physical!(s, dispatch); + out.cast(logical).unwrap() + }, + } +} + +fn interpolate_linear(s: &Series) -> Series { + match s.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => s.clone(), + DataType::Binary => s.clone(), + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => s.clone(), + DataType::List(_) => s.clone(), + _ => { + let logical = s.dtype(); + + let s = s.to_physical_repr(); + + let out = if matches!( + logical, + DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time + ) { + match s.dtype() { + // Datetime, Time, or Duration + DataType::Int64 => linear_interp_signed(s.i64().unwrap()), + // Date + DataType::Int32 => linear_interp_signed(s.i32().unwrap()), + _ => unreachable!(), + } + } else { + match s.dtype() { + DataType::Float32 => linear_interp_signed(s.f32().unwrap()), + DataType::Float64 => linear_interp_signed(s.f64().unwrap()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Int128 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap()) + }, + _ => s.as_ref().clone(), + } + }; + match logical { + DataType::Date + | DataType::Datetime(_, _) + | DataType::Duration(_) + | DataType::Time => out.cast(logical).unwrap(), + _ => out, + } + }, + } +} + +fn linear_interp_signed(ca: &ChunkedArray) -> Series +where + ChunkedArray: IntoSeries, +{ + interpolate_impl(ca, signed_interp::).into_series() +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum InterpolationMethod { + Linear, + Nearest, +} + +pub fn interpolate(s: &Series, method: InterpolationMethod) -> Series { + match method { + InterpolationMethod::Linear => interpolate_linear(s), + InterpolationMethod::Nearest => interpolate_nearest(s), + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_interpolate() { + let ca = UInt32Chunked::new("".into(), &[Some(1), None, None, Some(4), Some(5)]); + let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); + let out = out.f64().unwrap(); + assert_eq!( + Vec::from(out), + &[Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] + ); + + let ca = UInt32Chunked::new("".into(), &[None, Some(1), None, None, Some(4), Some(5)]); + let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); + let out = out.f64().unwrap(); + assert_eq!( + Vec::from(out), + &[None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] + ); + + let ca = UInt32Chunked::new( + "".into(), + &[None, Some(1), None, None, Some(4), Some(5), None], + ); + let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); + let out = out.f64().unwrap(); + assert_eq!( + Vec::from(out), + &[ + None, + Some(1.0), + Some(2.0), + Some(3.0), + Some(4.0), + Some(5.0), + None + ] + ); + let ca = UInt32Chunked::new( + "".into(), + &[None, Some(1), None, None, Some(4), Some(5), None], + ); + let out = interpolate(&ca.into_series(), InterpolationMethod::Nearest); + let out = out.u32().unwrap(); + assert_eq!( + Vec::from(out), + &[None, Some(1), Some(1), Some(4), Some(4), Some(5), None] + ); + } + + #[test] + fn test_interpolate_decreasing_unsigned() { + let ca = UInt32Chunked::new("".into(), &[Some(4), None, None, Some(1)]); + let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); + let out = out.f64().unwrap(); + assert_eq!( + Vec::from(out), + &[Some(4.0), Some(3.0), Some(2.0), Some(1.0)] + ) + } + + #[test] + fn test_interpolate2() { + let ca = Float32Chunked::new( + "".into(), + &[ + Some(4653f32), + None, + None, + None, + Some(4657f32), + None, + None, + Some(4657f32), + None, + Some(4657f32), + None, + None, + Some(4660f32), + ], + ); + let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); + let out = out.f32().unwrap(); + + assert_eq!( + Vec::from(out), + &[ + Some(4653.0), + Some(4654.0), + Some(4655.0), + Some(4656.0), + Some(4657.0), + Some(4657.0), + Some(4657.0), + Some(4657.0), + Some(4657.0), + Some(4657.0), + Some(4658.0), + Some(4659.0), + Some(4660.0) + ] + ); + } +} diff --git a/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs new file mode 100644 index 000000000000..986215664bd2 --- /dev/null +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs @@ -0,0 +1,344 @@ +use std::ops::{Add, Div, Mul, Sub}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::MutableBitmap; +use bytemuck::allocation::zeroed_vec; +use num_traits::{NumCast, Zero}; +use polars_core::prelude::*; +use polars_utils::slice::SliceAble; + +use super::linear_itp; + +/// # Safety +/// - `x` must be non-empty. +#[inline] +unsafe fn signed_interp_by_sorted(y_start: T, y_end: T, x: &[F], out: &mut Vec) +where + T: Sub + + Mul + + Add + + Div + + NumCast + + Copy + + Zero, + F: Sub + NumCast + Copy, +{ + let range_y = y_end - y_start; + let x_start; + let range_x; + let iter; + unsafe { + x_start = x.get_unchecked(0); + range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap(); + iter = x.slice_unchecked(1..x.len() - 1).iter(); + } + let slope = range_y / range_x; + for x_i in iter { + let x_delta = NumCast::from(*x_i - *x_start).unwrap(); + let v = linear_itp(y_start, x_delta, slope); + out.push(v) + } +} + +/// # Safety +/// - `x` must be non-empty. +/// - `sorting_indices` must be the same size as `x` +#[inline] +unsafe fn signed_interp_by( + y_start: T, + y_end: T, + x: &[F], + out: &mut [T], + sorting_indices: &[IdxSize], +) where + T: Sub + + Mul + + Add + + Div + + NumCast + + Copy + + Zero, + F: Sub + NumCast + Copy, +{ + let range_y = y_end - y_start; + let x_start; + let range_x; + let iter; + unsafe { + x_start = x.get_unchecked(0); + range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap(); + iter = x.slice_unchecked(1..x.len() - 1).iter(); + } + let slope = range_y / range_x; + for (idx, x_i) in iter.enumerate() { + let x_delta = NumCast::from(*x_i - *x_start).unwrap(); + let v = linear_itp(y_start, x_delta, slope); + unsafe { + let out_idx = sorting_indices.get_unchecked(idx + 1); + *out.get_unchecked_mut(*out_idx as usize) = v; + } + } +} + +fn interpolate_impl_by_sorted( + chunked_arr: &ChunkedArray, + by: &ChunkedArray, + interpolation_branch: I, +) -> PolarsResult> +where + T: PolarsNumericType, + F: PolarsNumericType, + I: Fn(T::Native, T::Native, &[F::Native], &mut Vec), +{ + // This implementation differs from pandas as that boundary None's are not removed. + // This prevents a lot of errors due to expressions leading to different lengths. + if !chunked_arr.has_nulls() || chunked_arr.null_count() == chunked_arr.len() { + return Ok(chunked_arr.clone()); + } + + polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression"); + let by = by.rechunk(); + let by_values = by.cont_slice().unwrap(); + + // We first find the first and last so that we can set the null buffer. + let first = chunked_arr.first_non_null().unwrap(); + let last = chunked_arr.last_non_null().unwrap() + 1; + + // Fill out with `first` nulls. + let mut out = Vec::with_capacity(chunked_arr.len()); + let mut iter = chunked_arr.iter().enumerate().skip(first); + for _ in 0..first { + out.push(Zero::zero()); + } + + // The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first + // `first` elements and if all values were missing we'd have done an early return. + let (mut low_idx, opt_low) = iter.next().unwrap(); + let mut low = opt_low.unwrap(); + out.push(low); + while let Some((idx, next)) = iter.next() { + if let Some(v) = next { + out.push(v); + low = v; + low_idx = idx; + } else { + for (high_idx, next) in iter.by_ref() { + if let Some(high) = next { + // SAFETY: we are in bounds, and `x` is non-empty. + unsafe { + let x = &by_values.slice_unchecked(low_idx..high_idx + 1); + interpolation_branch(low, high, x, &mut out); + } + out.push(high); + low = high; + low_idx = high_idx; + break; + } + } + } + } + if first != 0 || last != chunked_arr.len() { + let mut validity = MutableBitmap::with_capacity(chunked_arr.len()); + validity.extend_constant(chunked_arr.len(), true); + + for i in 0..first { + unsafe { validity.set_unchecked(i, false) }; + } + + for i in last..chunked_arr.len() { + unsafe { validity.set_unchecked(i, false) } + out.push(Zero::zero()); + } + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + out.into(), + Some(validity.into()), + ); + Ok(ChunkedArray::with_chunk(chunked_arr.name().clone(), array)) + } else { + Ok(ChunkedArray::from_vec(chunked_arr.name().clone(), out)) + } +} + +// Sort on behalf of user +fn interpolate_impl_by( + ca: &ChunkedArray, + by: &ChunkedArray, + interpolation_branch: I, +) -> PolarsResult> +where + T: PolarsNumericType, + F: PolarsNumericType, + I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]), +{ + // This implementation differs from pandas as that boundary None's are not removed. + // This prevents a lot of errors due to expressions leading to different lengths. + if !ca.has_nulls() || ca.null_count() == ca.len() { + return Ok(ca.clone()); + } + + polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression"); + let sorting_indices = by.arg_sort(Default::default()); + let sorting_indices = sorting_indices + .cont_slice() + .expect("arg sort produces single chunk"); + let by_sorted = unsafe { by.take_unchecked(sorting_indices) }; + let ca_sorted = unsafe { ca.take_unchecked(sorting_indices) }; + let by_sorted_values = by_sorted + .cont_slice() + .expect("We already checked for nulls, and `take_unchecked` produces single chunk"); + + // We first find the first and last so that we can set the null buffer. + let first = ca_sorted.first_non_null().unwrap(); + let last = ca_sorted.last_non_null().unwrap() + 1; + + let mut out = zeroed_vec(ca_sorted.len()); + let mut iter = ca_sorted.iter().enumerate().skip(first); + + // The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first + // `first` elements and if all values were missing we'd have done an early return. + let (mut low_idx, opt_low) = iter.next().unwrap(); + let mut low = opt_low.unwrap(); + unsafe { + let out_idx = sorting_indices.get_unchecked(low_idx); + *out.get_unchecked_mut(*out_idx as usize) = low; + } + while let Some((idx, next)) = iter.next() { + if let Some(v) = next { + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = v; + } + low = v; + low_idx = idx; + } else { + for (high_idx, next) in iter.by_ref() { + if let Some(high) = next { + // SAFETY: we are in bounds, and the slices are the same length (and non-empty). + unsafe { + interpolation_branch( + low, + high, + by_sorted_values.slice_unchecked(low_idx..high_idx + 1), + &mut out, + sorting_indices.slice_unchecked(low_idx..high_idx + 1), + ); + let out_idx = sorting_indices.get_unchecked(high_idx); + *out.get_unchecked_mut(*out_idx as usize) = high; + } + low = high; + low_idx = high_idx; + break; + } + } + } + } + if first != 0 || last != ca_sorted.len() { + let mut validity = MutableBitmap::with_capacity(ca_sorted.len()); + validity.extend_constant(ca_sorted.len(), true); + + for i in 0..first { + unsafe { + let out_idx = sorting_indices.get_unchecked(i); + validity.set_unchecked(*out_idx as usize, false); + } + } + + for i in last..ca_sorted.len() { + unsafe { + let out_idx = sorting_indices.get_unchecked(i); + validity.set_unchecked(*out_idx as usize, false); + } + } + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(CompatLevel::newest()), + out.into(), + Some(validity.into()), + ); + Ok(ChunkedArray::with_chunk(ca_sorted.name().clone(), array)) + } else { + Ok(ChunkedArray::from_vec(ca_sorted.name().clone(), out)) + } +} + +pub fn interpolate_by(s: &Column, by: &Column, by_is_sorted: bool) -> PolarsResult { + polars_ensure!(s.len() == by.len(), InvalidOperation: "`by` column must be the same length as Series ({}), got {}", s.len(), by.len()); + + fn func( + ca: &ChunkedArray, + by: &ChunkedArray, + is_sorted: bool, + ) -> PolarsResult + where + T: PolarsNumericType, + F: PolarsNumericType, + ChunkedArray: IntoColumn, + { + if is_sorted { + interpolate_impl_by_sorted(ca, by, |y_start, y_end, x, out| unsafe { + signed_interp_by_sorted(y_start, y_end, x, out) + }) + .map(|x| x.into_column()) + } else { + interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe { + signed_interp_by(y_start, y_end, x, out, sorting_indices) + }) + .map(|x| x.into_column()) + } + } + + match (s.dtype(), by.dtype()) { + (DataType::Float64, DataType::Float64) => { + func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::Float32) => { + func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Float64) => { + func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Float32) => { + func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::Int64) => { + func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::Int32) => { + func(s.f64().unwrap(), by.i32().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::UInt64) => { + func(s.f64().unwrap(), by.u64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::UInt32) => { + func(s.f64().unwrap(), by.u32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Int64) => { + func(s.f32().unwrap(), by.i64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Int32) => { + func(s.f32().unwrap(), by.i32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::UInt64) => { + func(s.f32().unwrap(), by.u64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::UInt32) => { + func(s.f32().unwrap(), by.u32().unwrap(), by_is_sorted) + }, + #[cfg(feature = "dtype-date")] + (_, DataType::Date) => interpolate_by(s, &by.cast(&DataType::Int32).unwrap(), by_is_sorted), + #[cfg(feature = "dtype-datetime")] + (_, DataType::Datetime(_, _)) => { + interpolate_by(s, &by.cast(&DataType::Int64).unwrap(), by_is_sorted) + }, + (DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => { + interpolate_by(&s.cast(&DataType::Float64).unwrap(), by, by_is_sorted) + }, + _ => { + polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \ + Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \ + UInt64, UInt32, Float32 or Float64") + }, + } +} diff --git a/crates/polars-ops/src/series/ops/interpolation/mod.rs b/crates/polars-ops/src/series/ops/interpolation/mod.rs new file mode 100644 index 000000000000..44511ff35b4b --- /dev/null +++ b/crates/polars-ops/src/series/ops/interpolation/mod.rs @@ -0,0 +1,26 @@ +use std::ops::{Add, Div, Mul, Sub}; +#[cfg(feature = "interpolate")] +pub mod interpolate; +#[cfg(feature = "interpolate_by")] +pub mod interpolate_by; + +fn linear_itp(low: T, step: T, slope: T) -> T +where + T: Sub + Mul + Add + Div, +{ + low + step * slope +} + +fn nearest_itp(low: T, step: T, diff: T, steps_n: T) -> T +where + T: Sub + Mul + Add + Div + PartialOrd + Copy, +{ + // 5 - 1 = 5 -> low + // 5 - 2 = 3 -> low + // 5 - 3 = 2 -> high + if (steps_n - step) > step { + low + } else { + low + diff + } +} diff --git a/crates/polars-ops/src/series/ops/is_between.rs b/crates/polars-ops/src/series/ops/is_between.rs new file mode 100644 index 000000000000..f61de819b5b8 --- /dev/null +++ b/crates/polars-ops/src/series/ops/is_between.rs @@ -0,0 +1,24 @@ +use std::ops::BitAnd; + +use polars_core::prelude::*; + +use crate::series::ClosedInterval; + +pub fn is_between( + s: &Series, + lower: &Series, + upper: &Series, + closed: ClosedInterval, +) -> PolarsResult { + let left_cmp_op = match closed { + ClosedInterval::None | ClosedInterval::Right => Series::gt, + ClosedInterval::Both | ClosedInterval::Left => Series::gt_eq, + }; + let right_cmp_op = match closed { + ClosedInterval::None | ClosedInterval::Left => Series::lt, + ClosedInterval::Both | ClosedInterval::Right => Series::lt_eq, + }; + let left = left_cmp_op(s, lower)?; + let right = right_cmp_op(s, upper)?; + Ok(left.bitand(right)) +} diff --git a/crates/polars-ops/src/series/ops/is_first_distinct.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs new file mode 100644 index 000000000000..8a4950bfec6a --- /dev/null +++ b/crates/polars-ops/src/series/ops/is_first_distinct.rs @@ -0,0 +1,149 @@ +use std::hash::Hash; + +use arrow::array::BooleanArray; +use arrow::bitmap::MutableBitmap; +use arrow::legacy::bit_util::*; +use arrow::legacy::utils::CustomIterTools; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; +fn is_first_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, +{ + let mut unique = PlHashSet::new(); + let chunks = ca.downcast_iter().map(|arr| -> BooleanArray { + arr.into_iter() + .map(|opt_v| unique.insert(opt_v.to_total_ord())) + .collect_trusted() + }); + + BooleanChunked::from_chunk_iter(ca.name().clone(), chunks) +} + +fn is_first_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { + let mut unique = PlHashSet::new(); + let chunks = ca.downcast_iter().map(|arr| -> BooleanArray { + arr.into_iter() + .map(|opt_v| unique.insert(opt_v)) + .collect_trusted() + }); + + BooleanChunked::from_chunk_iter(ca.name().clone(), chunks) +} + +fn is_first_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { + let mut out = MutableBitmap::with_capacity(ca.len()); + out.extend_constant(ca.len(), false); + + if ca.null_count() == ca.len() { + out.set(0, true); + } else { + let ca = ca.rechunk(); + let arr = ca.downcast_as_array(); + if ca.null_count() == 0 { + let (true_index, false_index) = + find_first_true_false_no_null(arr.values().chunks::()); + if let Some(idx) = true_index { + out.set(idx, true) + } + if let Some(idx) = false_index { + out.set(idx, true) + } + } else { + let (true_index, false_index, null_index) = find_first_true_false_null( + arr.values().chunks::(), + arr.validity().unwrap().chunks::(), + ); + if let Some(idx) = true_index { + out.set(idx, true) + } + if let Some(idx) = false_index { + out.set(idx, true) + } + if let Some(idx) = null_index { + out.set(idx, true) + } + } + } + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + BooleanChunked::with_chunk(ca.name().clone(), arr) +} + +#[cfg(feature = "dtype-struct")] +fn is_first_distinct_struct(s: &Series) -> PolarsResult { + let groups = s.group_tuples(true, false)?; + let first = groups.take_group_firsts(); + let mut out = MutableBitmap::with_capacity(s.len()); + out.extend_constant(s.len(), false); + + for idx in first { + // Group tuples are always in bounds + unsafe { out.set_unchecked(idx as usize, true) } + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + Ok(BooleanChunked::with_chunk(s.name().clone(), arr)) +} + +fn is_first_distinct_list(ca: &ListChunked) -> PolarsResult { + let groups = ca.group_tuples(true, false)?; + let first = groups.take_group_firsts(); + let mut out = MutableBitmap::with_capacity(ca.len()); + out.extend_constant(ca.len(), false); + + for idx in first { + // Group tuples are always in bounds + unsafe { out.set_unchecked(idx as usize, true) } + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + Ok(BooleanChunked::with_chunk(ca.name().clone(), arr)) +} + +pub fn is_first_distinct(s: &Series) -> PolarsResult { + // fast path. + if s.is_empty() { + return Ok(BooleanChunked::full_null(s.name().clone(), 0)); + } else if s.len() == 1 { + return Ok(BooleanChunked::new(s.name().clone(), &[true])); + } + + let s = s.to_physical_repr(); + + use DataType::*; + let out = match s.dtype() { + Boolean => { + let ca = s.bool().unwrap(); + is_first_distinct_boolean(ca) + }, + Binary => { + let ca = s.binary().unwrap(); + is_first_distinct_bin(ca) + }, + String => { + let s = s.cast(&Binary).unwrap(); + return is_first_distinct(&s); + }, + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + is_first_distinct_numeric(ca) + }) + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => return is_first_distinct_struct(&s), + List(inner) => { + polars_ensure!( + !inner.is_nested(), + InvalidOperation: "`is_first_distinct` on list type is only allowed if the inner type is not nested." + ); + let ca = s.list().unwrap(); + return is_first_distinct_list(ca); + }, + dt => polars_bail!(opq = is_first_distinct, dt), + }; + Ok(out) +} diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs new file mode 100644 index 000000000000..75e40e4afe45 --- /dev/null +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -0,0 +1,763 @@ +use std::hash::Hash; + +use arrow::array::BooleanArray; +use arrow::bitmap::BitmapBuilder; +use polars_core::prelude::arity::{unary_elementwise, unary_elementwise_values}; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +use self::row_encode::_get_rows_encoded_ca_unordered; + +fn is_in_helper_ca<'a, T>( + ca: &'a ChunkedArray, + other: &'a ChunkedArray, + nulls_equal: bool, +) -> PolarsResult +where + T: PolarsDataType, + T::Physical<'a>: TotalHash + TotalEq + ToTotalOrd + Copy, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, +{ + let mut set = PlHashSet::with_capacity(other.len()); + other.downcast_iter().for_each(|iter| { + iter.iter().for_each(|opt_val| { + if let Some(v) = opt_val { + set.insert(v.to_total_ord()); + } + }) + }); + + if nulls_equal { + if other.has_nulls() { + // If the rhs has nulls, then nulls in the left set evaluates to true. + Ok(unary_elementwise(ca, |val| { + val.is_none_or(|v| set.contains(&v.to_total_ord())) + })) + } else { + // The rhs has no nulls; nulls in the left evaluates to false. + Ok(unary_elementwise(ca, |val| { + val.is_some_and(|v| set.contains(&v.to_total_ord())) + })) + } + } else { + Ok( + unary_elementwise_values(ca, |v| set.contains(&v.to_total_ord())) + .with_name(ca.name().clone()), + ) + } +} + +fn is_in_helper_list_ca<'a, T>( + ca_in: &'a ChunkedArray, + other: &'a ListChunked, + nulls_equal: bool, +) -> PolarsResult +where + T: PolarsDataType, + for<'b> T::Physical<'b>: TotalHash + TotalEq + ToTotalOrd + Copy, + for<'b> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, +{ + let offsets = other.offsets()?; + let inner = other.get_inner(); + let inner: &ChunkedArray = inner.as_ref().as_ref(); + let validity = other.rechunk_validity(); + + let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { + let value = ca_in.get(0); + + match value { + None if !nulls_equal => BooleanChunked::full_null(PlSmallStr::EMPTY, other.len()), + value => { + let mut builder = BitmapBuilder::with_capacity(other.len()); + + for (start, length) in offsets.offset_and_length_iter() { + let mut is_in = false; + for i in 0..length { + is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord(); + } + builder.push(is_in); + } + + let values = builder.freeze(); + + let result = BooleanArray::new(ArrowDataType::Boolean, values, validity); + BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result]) + }, + } + } else { + assert_eq!(ca_in.len(), offsets.len_proxy()); + { + if nulls_equal { + let mut builder = BitmapBuilder::with_capacity(ca_in.len()); + + for (value, (start, length)) in ca_in.iter().zip(offsets.offset_and_length_iter()) { + let mut is_in = false; + for i in 0..length { + is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord(); + } + builder.push(is_in); + } + + let values = builder.freeze(); + + let result = BooleanArray::new(ArrowDataType::Boolean, values, validity); + BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result]) + } else { + let mut builder = BitmapBuilder::with_capacity(ca_in.len()); + + for (value, (start, length)) in ca_in.iter().zip(offsets.offset_and_length_iter()) { + let mut is_in = false; + if value.is_some() { + for i in 0..length { + is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord(); + } + } + builder.push(is_in); + } + + let values = builder.freeze(); + + let validity = match (validity, ca_in.rechunk_validity()) { + (None, None) => None, + (Some(v), None) | (None, Some(v)) => Some(v), + (Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)), + }; + + let result = BooleanArray::new(ArrowDataType::Boolean, values, validity); + BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result]) + } + } + }; + ca.rename(ca_in.name().clone()); + Ok(ca) +} + +#[cfg(feature = "dtype-array")] +fn is_in_helper_array_ca<'a, T>( + ca_in: &'a ChunkedArray, + other: &'a ArrayChunked, + nulls_equal: bool, +) -> PolarsResult +where + T: PolarsDataType, + for<'b> T::Physical<'b>: TotalHash + TotalEq + ToTotalOrd + Copy, + for<'b> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, +{ + let width = other.width(); + let inner = other.get_inner(); + let inner: &ChunkedArray = inner.as_ref().as_ref(); + let validity = other.rechunk_validity(); + + let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { + let value = ca_in.get(0); + + match value { + None if !nulls_equal => BooleanChunked::full_null(PlSmallStr::EMPTY, other.len()), + value => { + let mut builder = BitmapBuilder::with_capacity(other.len()); + + for i in 0..other.len() { + let mut is_in = false; + for j in 0..width { + is_in |= value.to_total_ord() == inner.get(i * width + j).to_total_ord(); + } + builder.push(is_in); + } + + let values = builder.freeze(); + + let result = BooleanArray::new(ArrowDataType::Boolean, values, validity); + BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result]) + }, + } + } else { + assert_eq!(ca_in.len(), other.len()); + { + if nulls_equal { + let mut builder = BitmapBuilder::with_capacity(ca_in.len()); + + for (i, value) in ca_in.iter().enumerate() { + let mut is_in = false; + for j in 0..width { + is_in |= value.to_total_ord() == inner.get(i * width + j).to_total_ord(); + } + builder.push(is_in); + } + + let values = builder.freeze(); + + let result = BooleanArray::new(ArrowDataType::Boolean, values, validity); + BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result]) + } else { + let mut builder = BitmapBuilder::with_capacity(ca_in.len()); + + for (i, value) in ca_in.iter().enumerate() { + let mut is_in = false; + if value.is_some() { + for j in 0..width { + is_in |= + value.to_total_ord() == inner.get(i * width + j).to_total_ord(); + } + } + builder.push(is_in); + } + + let values = builder.freeze(); + + let validity = match (validity, ca_in.rechunk_validity()) { + (None, None) => None, + (Some(v), None) | (None, Some(v)) => Some(v), + (Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)), + }; + + let result = BooleanArray::new(ArrowDataType::Boolean, values, validity); + BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result]) + } + } + }; + ca.rename(ca_in.name().clone()); + Ok(ca) +} + +fn is_in_numeric( + ca_in: &ChunkedArray, + other: &Series, + nulls_equal: bool, +) -> PolarsResult +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + Copy, +{ + match other.dtype() { + DataType::List(..) => { + let other = other.list()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.as_ref().as_ref(); + is_in_helper_ca(ca_in, other, nulls_equal) + } else { + is_in_helper_list_ca(ca_in, other, nulls_equal) + } + }, + #[cfg(feature = "dtype-array")] + DataType::Array(..) => { + let other = other.array()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.as_ref().as_ref(); + is_in_helper_ca(ca_in, other, nulls_equal) + } else { + is_in_helper_array_ca(ca_in, other, nulls_equal) + } + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + } +} + +fn is_in_string( + ca_in: &StringChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { + let other = match other.dtype() { + DataType::List(dt) if dt.is_string() || dt.is_enum() || dt.is_categorical() => { + let other = other.list()?; + other + .apply_to_inner(&|mut s| { + if dt.is_enum() || dt.is_categorical() { + s = s.cast(&DataType::String)?; + } + let s = s.str()?; + Ok(s.as_binary().into_series()) + })? + .into_series() + }, + #[cfg(feature = "dtype-array")] + DataType::Array(dt, _) if dt.is_string() || dt.is_enum() || dt.is_categorical() => { + let other = other.array()?; + other + .apply_to_inner(&|mut s| { + if dt.is_enum() || dt.is_categorical() { + s = s.cast(&DataType::String)?; + } + Ok(s.str()?.as_binary().into_series()) + })? + .into_series() + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + }; + is_in_binary(&ca_in.as_binary(), &other, nulls_equal) +} + +fn is_in_binary( + ca_in: &BinaryChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { + match other.dtype() { + DataType::List(dt) if DataType::Binary == **dt => { + let other = other.list()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.binary()?; + is_in_helper_ca(ca_in, other, nulls_equal) + } else { + is_in_helper_list_ca(ca_in, other, nulls_equal) + } + }, + #[cfg(feature = "dtype-array")] + DataType::Array(dt, _) if DataType::Binary == **dt => { + let other = other.array()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.binary()?; + is_in_helper_ca(ca_in, other, nulls_equal) + } else { + is_in_helper_array_ca(ca_in, other, nulls_equal) + } + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + } +} + +fn is_in_boolean( + ca_in: &BooleanChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { + fn is_in_boolean_broadcast( + ca_in: &BooleanChunked, + other: &BooleanChunked, + nulls_equal: bool, + ) -> PolarsResult { + let has_true = other.any(); + let nc = other.null_count(); + + let has_false = if nc == 0 { + !other.all() + } else { + (other.sum().unwrap() as usize + nc) != other.len() + }; + let value_map = |v| if v { has_true } else { has_false }; + if nulls_equal { + if other.has_nulls() { + // If the rhs has nulls, then nulls in the left set evaluates to true. + Ok(ca_in.apply(|opt_v| Some(opt_v.is_none_or(value_map)))) + } else { + // The rhs has no nulls; nulls in the left evaluates to false. + Ok(ca_in.apply(|opt_v| Some(opt_v.is_some_and(value_map)))) + } + } else { + Ok(ca_in + .apply_values(value_map) + .with_name(ca_in.name().clone())) + } + } + + match other.dtype() { + DataType::List(dt) if ca_in.dtype() == &**dt => { + let other = other.list()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.bool()?; + is_in_boolean_broadcast(ca_in, other, nulls_equal) + } else { + is_in_helper_list_ca(ca_in, other, nulls_equal) + } + }, + #[cfg(feature = "dtype-array")] + DataType::Array(dt, _) if ca_in.dtype() == &**dt => { + let other = other.array()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.bool()?; + is_in_boolean_broadcast(ca_in, other, nulls_equal) + } else { + is_in_helper_array_ca(ca_in, other, nulls_equal) + } + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + } +} + +#[cfg(feature = "dtype-categorical")] +fn is_in_cat_and_enum( + ca_in: &CategoricalChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { + use std::borrow::Cow; + + use arrow::array::{Array, FixedSizeListArray, IntoBoxedArray, ListArray}; + + let mut needs_remap = false; + let to_categories = match (ca_in.dtype(), other.dtype().inner_dtype().unwrap()) { + (DataType::Enum(revmap, ordering), DataType::String) => { + let categories = revmap.as_deref().unwrap().get_categories(); + (&|s: Series| { + let ca = s.str()?; + let ca = CategoricalChunked::from_string_to_enum(ca, categories, *ordering)?; + let ca = ca.into_physical(); + Ok(ca.into_series()) + }) as _ + }, + (DataType::Categorical(revmap, ordering), DataType::String) => { + (&|s: Series| { + let categories = revmap.as_deref().unwrap().get_categories(); + let ca = s.str()?; + let ca = + if ca_in.get_rev_map().is_local() { + assert!(categories.len() < u32::MAX as usize); + let cats = PlIndexSet::from_iter(categories.values_iter()); + UInt32Chunked::from_iter(ca.iter().map(|v| { + v.map(|v| cats.get_index_of(v).map_or(u32::MAX, |n| n as u32)) + })) + } else { + let cat = ca.cast(&DataType::Categorical(None, *ordering))?; + cat.categorical()?.physical().clone() + }; + Ok(ca.into_series()) + }) as _ + }, + (DataType::Categorical(revmap, _), DataType::Categorical(other_revmap, _)) => { + let (Some(revmap), Some(other_revmap)) = (revmap, other_revmap) else { + polars_bail!(ComputeError: "expected revmap to be set at this point"); + }; + needs_remap = !revmap.same_src(other_revmap); + (&|s: Series| { + let ca = s.categorical()?; + let ca = ca.physical().clone(); + Ok(ca.into_series()) + }) as _ + }, + (DataType::Enum(revmap, _), DataType::Enum(other_revmap, _)) => { + let (Some(revmap), Some(other_revmap)) = (revmap, other_revmap) else { + polars_bail!(ComputeError: "expected revmap to be set at this point"); + }; + polars_ensure!( + revmap.same_src(other_revmap), + opq = is_in, + ca_in.dtype(), + other.dtype() + ); + (&|s: Series| { + let ca = s.categorical()?; + let ca = ca.physical().clone(); + Ok(ca.into_series()) + }) as _ + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + }; + + let mut ca_in = Cow::Borrowed(ca_in); + let other = match other.dtype() { + DataType::List(_) => { + let mut other = Cow::Borrowed(other.list()?); + if needs_remap { + let other_rechunked = other.rechunk(); + let other_arr = other_rechunked.downcast_as_array(); + + let other_inner = other.get_inner(); + let other_offsets = other_arr.offsets().clone(); + let other_inner = other_inner.categorical()?; + let (ca_in_remapped, other_inner) = + make_rhs_categoricals_compatible(&ca_in, other_inner)?; + + let other_inner_phys = other_inner.physical().rechunk(); + let other_inner_phys = other_inner_phys.downcast_as_array(); + + let other_phys = ListArray::try_new( + other_arr.dtype().clone(), + other_offsets, + other_inner_phys.clone().into_boxed(), + other_arr.validity().cloned(), + )?; + + other = Cow::Owned(unsafe { + ListChunked::from_chunks_and_dtype( + other.name().clone(), + vec![other_phys.into_boxed()], + DataType::List(Box::new(other_inner.dtype().clone())), + ) + }); + ca_in = Cow::Owned(ca_in_remapped); + } + let other = other.apply_to_inner(to_categories)?; + other.into_series() + }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let mut other = Cow::Borrowed(other.array()?); + if needs_remap { + let other_rechunked = other.rechunk(); + let other_arr = other_rechunked.downcast_as_array(); + + let other_inner = other.get_inner(); + let other_inner = other_inner.categorical()?; + let (ca_in_remapped, other_inner) = + make_rhs_categoricals_compatible(&ca_in, other_inner)?; + + let other_inner_phys = other_inner.physical().rechunk(); + let other_inner_phys = other_inner_phys.downcast_as_array(); + + let other_phys = FixedSizeListArray::try_new( + other_arr.dtype().clone(), + other.len(), + other_inner_phys.clone().into_boxed(), + other_arr.validity().cloned(), + )?; + + other = Cow::Owned(unsafe { + ArrayChunked::from_chunks_and_dtype( + other.name().clone(), + vec![other_phys.into_boxed()], + DataType::Array(Box::new(other_inner.dtype().clone()), other.width()), + ) + }); + ca_in = Cow::Owned(ca_in_remapped); + } + let other = other.apply_to_inner(to_categories)?; + other.into_series() + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + }; + + is_in_numeric(ca_in.physical(), &other, nulls_equal) +} + +fn is_in_null(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult { + if nulls_equal { + let ca_in = s.null()?; + Ok(match other.dtype() { + DataType::List(_) => { + let other = other.list()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + BooleanChunked::from_iter_values( + ca_in.name().clone(), + std::iter::repeat_n(other.has_nulls(), ca_in.len()), + ) + } else { + other.apply_amortized_generic(|opt_s| { + Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true)) + }) + } + }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let other = other.array()?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + BooleanChunked::from_iter_values( + ca_in.name().clone(), + std::iter::repeat_n(other.has_nulls(), ca_in.len()), + ) + } else { + other.apply_amortized_generic(|opt_s| { + Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true)) + }) + } + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), + }) + } else { + let out = s.cast(&DataType::Boolean)?; + let ca_bool = out.bool()?.clone(); + Ok(ca_bool) + } +} + +#[cfg(feature = "dtype-decimal")] +fn is_in_decimal( + ca_in: &DecimalChunked, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { + let Some(DataType::Decimal(_, other_scale)) = other.dtype().inner_dtype() else { + polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()); + }; + let other_scale = other_scale.unwrap(); + let scale = ca_in.scale().max(other_scale); + let ca_in = ca_in.to_scale(scale)?; + + match other.dtype() { + DataType::List(_) => { + let other = other.list()?; + let other = other.apply_to_inner(&|s| { + let s = s.decimal()?; + let s = s.to_scale(scale)?; + let s = s.physical(); + Ok(s.to_owned().into_series()) + })?; + let other = other.into_series(); + is_in_numeric(ca_in.physical(), &other, nulls_equal) + }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let other = other.array()?; + let other = other.apply_to_inner(&|s| { + let s = s.decimal()?; + let s = s.to_scale(scale)?; + let s = s.physical(); + Ok(s.to_owned().into_series()) + })?; + let other = other.into_series(); + is_in_numeric(ca_in.physical(), &other, nulls_equal) + }, + _ => unreachable!(), + } +} + +fn is_in_row_encoded( + s: &Series, + other: &Series, + nulls_equal: bool, +) -> PolarsResult { + let ca_in = _get_rows_encoded_ca_unordered(s.name().clone(), &[s.clone().into_column()])?; + let mut mask = match other.dtype() { + DataType::List(_) => { + let other = other.list()?; + let other = other.apply_to_inner(&|s| { + Ok( + _get_rows_encoded_ca_unordered(s.name().clone(), &[s.into_column()])? + .into_series(), + ) + })?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.binary_offset()?; + is_in_helper_ca(&ca_in, other, nulls_equal) + } else { + is_in_helper_list_ca(&ca_in, &other, nulls_equal) + } + }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, _) => { + let other = other.array()?; + let other = other.apply_to_inner(&|s| { + Ok( + _get_rows_encoded_ca_unordered(s.name().clone(), &[s.into_column()])? + .into_series(), + ) + })?; + if other.len() == 1 { + if other.has_nulls() { + return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); + } + + let other = other.explode()?; + let other = other.binary_offset()?; + is_in_helper_ca(&ca_in, other, nulls_equal) + } else { + is_in_helper_array_ca(&ca_in, &other, nulls_equal) + } + }, + _ => unreachable!(), + }?; + + let mut validity = other.rechunk_validity(); + if !nulls_equal { + validity = match (validity, s.rechunk_validity()) { + (None, None) => None, + (Some(v), None) | (None, Some(v)) => Some(v), + (Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)), + }; + } + + assert_eq!(mask.null_count(), 0); + mask.with_validities(&[validity]); + + Ok(mask) +} + +pub fn is_in(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult { + polars_ensure!( + s.len() == other.len() || s.len() == 1 || other.len() == 1, + length_mismatch = "is_in", + s.len(), + other.len() + ); + + #[allow(unused_mut)] + let mut other_is_valid_type = matches!(other.dtype(), DataType::List(_)); + #[cfg(feature = "dtype-array")] + { + other_is_valid_type |= matches!(other.dtype(), DataType::Array(..)) + } + polars_ensure!(other_is_valid_type, opq = is_in, s.dtype(), other.dtype()); + + match s.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + let ca = s.categorical().unwrap(); + is_in_cat_and_enum(ca, other, nulls_equal) + }, + DataType::String => { + let ca = s.str().unwrap(); + is_in_string(ca, other, nulls_equal) + }, + DataType::Binary => { + let ca = s.binary().unwrap(); + is_in_binary(ca, other, nulls_equal) + }, + DataType::Boolean => { + let ca = s.bool().unwrap(); + is_in_boolean(ca, other, nulls_equal) + }, + DataType::Null => is_in_null(s, other, nulls_equal), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + let ca_in = s.decimal()?; + is_in_decimal(ca_in, other, nulls_equal) + }, + dt if dt.is_nested() => is_in_row_encoded(s, other, nulls_equal), + dt if dt.to_physical().is_primitive_numeric() => { + let s = s.to_physical_repr(); + let other = other.to_physical_repr(); + let other = other.as_ref(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + is_in_numeric(ca, other, nulls_equal) + }) + }, + dt => polars_bail!(opq = is_in, dt), + } +} diff --git a/crates/polars-ops/src/series/ops/is_last_distinct.rs b/crates/polars-ops/src/series/ops/is_last_distinct.rs new file mode 100644 index 000000000000..2a3528525307 --- /dev/null +++ b/crates/polars-ops/src/series/ops/is_last_distinct.rs @@ -0,0 +1,161 @@ +use std::hash::Hash; + +use arrow::array::BooleanArray; +use arrow::bitmap::MutableBitmap; +use arrow::legacy::utils::CustomIterTools; +use polars_core::prelude::*; +use polars_core::utils::NoNull; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +pub fn is_last_distinct(s: &Series) -> PolarsResult { + // fast path. + if s.is_empty() { + return Ok(BooleanChunked::full_null(s.name().clone(), 0)); + } else if s.len() == 1 { + return Ok(BooleanChunked::new(s.name().clone(), &[true])); + } + + let s = s.to_physical_repr(); + + use DataType::*; + let out = match s.dtype() { + Boolean => { + let ca = s.bool().unwrap(); + is_last_distinct_boolean(ca) + }, + Binary => { + let ca = s.binary().unwrap(); + is_last_distinct_bin(ca) + }, + String => { + let s = s.cast(&Binary).unwrap(); + return is_last_distinct(&s); + }, + dt if dt.is_primitive_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + is_last_distinct_numeric(ca) + }) + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => return is_last_distinct_struct(&s), + List(inner) => { + polars_ensure!( + !inner.is_nested(), + InvalidOperation: "`is_last_distinct` on list type is only allowed if the inner type is not nested." + ); + let ca = s.list().unwrap(); + return is_last_distinct_list(ca); + }, + dt => polars_bail!(opq = is_last_distinct, dt), + }; + Ok(out) +} + +fn is_last_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { + let mut out = MutableBitmap::with_capacity(ca.len()); + out.extend_constant(ca.len(), false); + + if ca.null_count() == ca.len() { + out.set(ca.len() - 1, true); + } + // TODO supports fast path. + else { + let mut first_true_found = false; + let mut first_false_found = false; + let mut first_null_found = false; + let mut all_found = false; + let ca = ca.rechunk(); + ca.downcast_as_array() + .iter() + .enumerate() + .rev() + .find_map(|(idx, val)| match val { + Some(true) if !first_true_found => { + first_true_found = true; + all_found &= first_true_found; + out.set(idx, true); + if all_found { Some(()) } else { None } + }, + Some(false) if !first_false_found => { + first_false_found = true; + all_found &= first_false_found; + out.set(idx, true); + if all_found { Some(()) } else { None } + }, + None if !first_null_found => { + first_null_found = true; + all_found &= first_null_found; + out.set(idx, true); + if all_found { Some(()) } else { None } + }, + _ => None, + }); + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + BooleanChunked::with_chunk(ca.name().clone(), arr) +} + +fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { + let tmp = ca.rechunk(); + let arr = tmp.downcast_as_array(); + let mut unique = PlHashSet::new(); + arr.iter() + .rev() + .map(|opt_v| unique.insert(opt_v)) + .collect_reversed::>() + .into_inner() + .with_name(ca.name().clone()) +} + +fn is_last_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked +where + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, +{ + let tmp = ca.rechunk(); + let arr = tmp.downcast_as_array(); + let mut unique = PlHashSet::new(); + arr.iter() + .rev() + .map(|opt_v| unique.insert(opt_v.to_total_ord())) + .collect_reversed::>() + .into_inner() + .with_name(ca.name().clone()) +} + +#[cfg(feature = "dtype-struct")] +fn is_last_distinct_struct(s: &Series) -> PolarsResult { + let groups = s.group_tuples(true, false)?; + // SAFETY: all groups have at least a single member + let last = unsafe { groups.take_group_lasts() }; + let mut out = MutableBitmap::with_capacity(s.len()); + out.extend_constant(s.len(), false); + + for idx in last { + // Group tuples are always in bounds + unsafe { out.set_unchecked(idx as usize, true) } + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + Ok(BooleanChunked::with_chunk(s.name().clone(), arr)) +} + +fn is_last_distinct_list(ca: &ListChunked) -> PolarsResult { + let groups = ca.group_tuples(true, false)?; + // SAFETY: all groups have at least a single member + let last = unsafe { groups.take_group_lasts() }; + let mut out = MutableBitmap::with_capacity(ca.len()); + out.extend_constant(ca.len(), false); + + for idx in last { + // Group tuples are always in bounds + unsafe { out.set_unchecked(idx as usize, true) } + } + + let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); + Ok(BooleanChunked::with_chunk(ca.name().clone(), arr)) +} diff --git a/crates/polars-ops/src/series/ops/is_unique.rs b/crates/polars-ops/src/series/ops/is_unique.rs new file mode 100644 index 000000000000..4cafa1ca435b --- /dev/null +++ b/crates/polars-ops/src/series/ops/is_unique.rs @@ -0,0 +1,99 @@ +use std::hash::Hash; + +use arrow::array::BooleanArray; +use arrow::bitmap::MutableBitmap; +use polars_core::prelude::*; +use polars_core::with_match_physical_integer_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +// If invert is true then this is an `is_duplicated`. +fn is_unique_ca<'a, T>(ca: &'a ChunkedArray, invert: bool) -> BooleanChunked +where + T: PolarsDataType, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, +{ + let len = ca.len(); + let mut idx_key = PlHashMap::new(); + + // Instead of group_tuples, which allocates a full Vec per group, we now + // just toggle a boolean that's false if a group has multiple entries. + ca.iter().enumerate().for_each(|(idx, key)| { + idx_key + .entry(key.to_total_ord()) + .and_modify(|v: &mut (IdxSize, bool)| v.1 = false) + .or_insert((idx as IdxSize, true)); + }); + + let unique_idx = idx_key + .into_iter() + .filter_map(|(_k, v)| if v.1 { Some(v.0) } else { None }); + + let (default, setter) = if invert { (true, false) } else { (false, true) }; + let mut values = MutableBitmap::with_capacity(len); + values.extend_constant(len, default); + for idx in unique_idx { + unsafe { values.set_unchecked(idx as usize, setter) } + } + let arr = BooleanArray::from_data_default(values.into(), None); + BooleanChunked::with_chunk(ca.name().clone(), arr) +} + +fn dispatcher(s: &Series, invert: bool) -> PolarsResult { + let s = s.to_physical_repr(); + use DataType::*; + let out = match s.dtype() { + Boolean => { + let ca = s.bool().unwrap(); + is_unique_ca(ca, invert) + }, + Binary => { + let ca = s.binary().unwrap(); + is_unique_ca(ca, invert) + }, + String => { + let s = s.cast(&Binary).unwrap(); + let ca = s.binary().unwrap(); + is_unique_ca(ca, invert) + }, + Float32 => { + let ca = s.f32().unwrap(); + is_unique_ca(ca, invert) + }, + Float64 => { + let ca = s.f64().unwrap(); + is_unique_ca(ca, invert) + }, + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = s.struct_().unwrap().clone(); + let df = ca.unnest(); + return if invert { + df.is_duplicated() + } else { + df.is_unique() + }; + }, + Null => match s.len() { + 0 => BooleanChunked::new(s.name().clone(), [] as [bool; 0]), + 1 => BooleanChunked::new(s.name().clone(), [!invert]), + len => BooleanChunked::full(s.name().clone(), invert, len), + }, + dt if dt.is_primitive_numeric() => { + with_match_physical_integer_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + is_unique_ca(ca, invert) + }) + }, + dt => polars_bail!(opq = is_unique, dt), + }; + Ok(out) +} + +pub fn is_unique(s: &Series) -> PolarsResult { + dispatcher(s, false) +} + +pub fn is_duplicated(s: &Series) -> PolarsResult { + dispatcher(s, true) +} diff --git a/crates/polars-ops/src/series/ops/linear_space.rs b/crates/polars-ops/src/series/ops/linear_space.rs new file mode 100644 index 000000000000..71cff210c6a0 --- /dev/null +++ b/crates/polars-ops/src/series/ops/linear_space.rs @@ -0,0 +1,102 @@ +use polars_core::prelude::*; +use polars_core::series::IsSorted; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum ClosedInterval { + #[default] + Both, + Left, + Right, + None, +} + +pub fn new_linear_space_f32( + start: f32, + end: f32, + n: u64, + closed: ClosedInterval, + name: PlSmallStr, +) -> PolarsResult { + let mut ca = match n { + 0 => Float32Chunked::full_null(name, 0), + 1 => match closed { + ClosedInterval::None => Float32Chunked::from_slice(name, &[(end + start) * 0.5]), + ClosedInterval::Left | ClosedInterval::Both => { + Float32Chunked::from_slice(name, &[start]) + }, + ClosedInterval::Right => Float32Chunked::from_slice(name, &[end]), + }, + _ => Float32Chunked::from_iter_values(name, { + let span = end - start; + + let (start, d, end) = match closed { + ClosedInterval::None => { + let d = span / (n + 1) as f32; + (start + d, d, end - d) + }, + ClosedInterval::Left => (start, span / n as f32, end - span / n as f32), + ClosedInterval::Right => (start + span / n as f32, span / n as f32, end), + ClosedInterval::Both => (start, span / (n - 1) as f32, end), + }; + (0..n - 1) + .map(move |v| (v as f32 * d) + start) + .chain(std::iter::once(end)) // ensures floating point accuracy of final value + }), + }; + + let is_sorted = if end < start { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(is_sorted); + Ok(ca) +} + +pub fn new_linear_space_f64( + start: f64, + end: f64, + n: u64, + closed: ClosedInterval, + name: PlSmallStr, +) -> PolarsResult { + let mut ca = match n { + 0 => Float64Chunked::full_null(name, 0), + 1 => match closed { + ClosedInterval::None => Float64Chunked::from_slice(name, &[(end + start) * 0.5]), + ClosedInterval::Left | ClosedInterval::Both => { + Float64Chunked::from_slice(name, &[start]) + }, + ClosedInterval::Right => Float64Chunked::from_slice(name, &[end]), + }, + _ => Float64Chunked::from_iter_values(name, { + let span = end - start; + + let (start, d, end) = match closed { + ClosedInterval::None => { + let d = span / (n + 1) as f64; + (start + d, d, end - d) + }, + ClosedInterval::Left => (start, span / n as f64, end - span / n as f64), + ClosedInterval::Right => (start + span / n as f64, span / n as f64, end), + ClosedInterval::Both => (start, span / (n - 1) as f64, end), + }; + (0..n - 1) + .map(move |v| (v as f64 * d) + start) + .chain(std::iter::once(end)) // ensures floating point accuracy of final value + }), + }; + + let is_sorted = if end < start { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(is_sorted); + Ok(ca) +} diff --git a/crates/polars-ops/src/series/ops/log.rs b/crates/polars-ops/src/series/ops/log.rs new file mode 100644 index 000000000000..452935f9f6a0 --- /dev/null +++ b/crates/polars-ops/src/series/ops/log.rs @@ -0,0 +1,131 @@ +use polars_core::prelude::*; +use polars_core::with_match_physical_integer_polars_type; + +use crate::series::ops::SeriesSealed; + +fn log(ca: &ChunkedArray, base: f64) -> Float64Chunked { + ca.cast_and_apply_in_place(|v: f64| v.log(base)) +} + +fn log1p(ca: &ChunkedArray) -> Float64Chunked { + ca.cast_and_apply_in_place(|v: f64| v.ln_1p()) +} + +fn exp(ca: &ChunkedArray) -> Float64Chunked { + ca.cast_and_apply_in_place(|v: f64| v.exp()) +} + +pub trait LogSeries: SeriesSealed { + /// Compute the logarithm to a given base + fn log(&self, base: f64) -> Series { + let s = self.as_series(); + if s.dtype().is_decimal() { + return s.cast(&DataType::Float64).unwrap().log(base); + } + + let s = s.to_physical_repr(); + let s = s.as_ref(); + + use DataType::*; + match s.dtype() { + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + log(ca, base).into_series() + }) + }, + Float32 => s + .f32() + .unwrap() + .apply_values(|v| v.log(base as f32)) + .into_series(), + Float64 => s.f64().unwrap().apply_values(|v| v.log(base)).into_series(), + _ => s.cast(&DataType::Float64).unwrap().log(base), + } + } + + /// Compute the natural logarithm of all elements plus one in the input array + fn log1p(&self) -> Series { + let s = self.as_series(); + if s.dtype().is_decimal() { + return s.cast(&DataType::Float64).unwrap().log1p(); + } + + let s = s.to_physical_repr(); + let s = s.as_ref(); + + use DataType::*; + match s.dtype() { + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + log1p(ca).into_series() + }) + }, + Float32 => s.f32().unwrap().apply_values(|v| v.ln_1p()).into_series(), + Float64 => s.f64().unwrap().apply_values(|v| v.ln_1p()).into_series(), + _ => s.cast(&DataType::Float64).unwrap().log1p(), + } + } + + /// Calculate the exponential of all elements in the input array. + fn exp(&self) -> Series { + let s = self.as_series(); + if s.dtype().is_decimal() { + return s.cast(&DataType::Float64).unwrap().exp(); + } + + let s = s.to_physical_repr(); + let s = s.as_ref(); + + use DataType::*; + match s.dtype() { + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + exp(ca).into_series() + }) + }, + Float32 => s.f32().unwrap().apply_values(|v| v.exp()).into_series(), + Float64 => s.f64().unwrap().apply_values(|v| v.exp()).into_series(), + _ => s.cast(&DataType::Float64).unwrap().exp(), + } + } + + /// Compute the entropy as `-sum(pk * log(pk)`. + /// where `pk` are discrete probabilities. + fn entropy(&self, base: f64, normalize: bool) -> PolarsResult { + let s = self.as_series().to_physical_repr(); + polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "expected numerical input for 'entropy'"); + // if there is only one value in the series, return 0.0 to prevent the + // function from returning -0.0 + if s.len() == 1 { + return Ok(0.0); + } + match s.dtype() { + DataType::Float32 | DataType::Float64 => { + let pk = s.as_ref(); + + let pk = if normalize { + let sum = pk.sum_reduce().unwrap().into_series(PlSmallStr::EMPTY); + + if sum.get(0).unwrap().extract::().unwrap() != 1.0 { + (pk / &sum)? + } else { + pk.clone() + } + } else { + pk.clone() + }; + + let log_pk = pk.log(base); + (&pk * &log_pk)?.sum::().map(|v| -v) + }, + _ => s + .cast(&DataType::Float64) + .map(|s| s.entropy(base, normalize))?, + } + } +} + +impl LogSeries for Series {} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs new file mode 100644 index 000000000000..cf119ce1359d --- /dev/null +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -0,0 +1,156 @@ +#[cfg(feature = "abs")] +mod abs; +mod arg_min_max; +mod bitwise; +#[cfg(feature = "business")] +mod business; +mod clip; +#[cfg(feature = "cum_agg")] +mod cum_agg; +#[cfg(feature = "cutqcut")] +mod cut; +#[cfg(feature = "diff")] +mod diff; +#[cfg(feature = "ewma")] +mod ewm; +#[cfg(feature = "ewma_by")] +mod ewm_by; +#[cfg(feature = "round_series")] +mod floor_divide; +#[cfg(feature = "fused")] +mod fused; +mod horizontal; +mod index; +#[cfg(feature = "index_of")] +mod index_of; +mod int_range; +#[cfg(any(feature = "interpolate_by", feature = "interpolate"))] +mod interpolation; +#[cfg(feature = "is_between")] +mod is_between; +#[cfg(feature = "is_first_distinct")] +mod is_first_distinct; +#[cfg(feature = "is_in")] +mod is_in; +#[cfg(feature = "is_last_distinct")] +mod is_last_distinct; +#[cfg(feature = "is_unique")] +mod is_unique; +mod linear_space; +#[cfg(feature = "log")] +mod log; +#[cfg(feature = "moment")] +mod moment; +mod negate; +#[cfg(feature = "pct_change")] +mod pct_change; +#[cfg(feature = "rank")] +mod rank; +#[cfg(feature = "reinterpret")] +mod reinterpret; +#[cfg(feature = "replace")] +mod replace; +#[cfg(feature = "rle")] +mod rle; +#[cfg(feature = "rolling_window")] +mod rolling; +#[cfg(feature = "round_series")] +mod round; +#[cfg(feature = "search_sorted")] +mod search_sorted; +#[cfg(feature = "to_dummies")] +mod to_dummies; +#[cfg(feature = "unique_counts")] +mod unique; +mod various; + +#[cfg(feature = "abs")] +pub use abs::*; +pub use arg_min_max::ArgAgg; +pub use bitwise::*; +#[cfg(feature = "business")] +pub use business::*; +pub use clip::*; +#[cfg(feature = "cum_agg")] +pub use cum_agg::*; +#[cfg(feature = "cutqcut")] +pub use cut::*; +#[cfg(feature = "diff")] +pub use diff::*; +#[cfg(feature = "ewma")] +pub use ewm::*; +#[cfg(feature = "ewma_by")] +pub use ewm_by::*; +#[cfg(feature = "round_series")] +pub use floor_divide::*; +#[cfg(feature = "fused")] +pub use fused::*; +pub use horizontal::*; +pub use index::*; +#[cfg(feature = "index_of")] +pub use index_of::*; +pub use int_range::*; +#[cfg(feature = "interpolate")] +pub use interpolation::interpolate::*; +#[cfg(feature = "interpolate_by")] +pub use interpolation::interpolate_by::*; +#[cfg(any(feature = "interpolate", feature = "interpolate_by"))] +pub use interpolation::*; +#[cfg(feature = "is_between")] +pub use is_between::*; +#[cfg(feature = "is_first_distinct")] +pub use is_first_distinct::*; +#[cfg(feature = "is_in")] +pub use is_in::*; +#[cfg(feature = "is_last_distinct")] +pub use is_last_distinct::*; +#[cfg(feature = "is_unique")] +pub use is_unique::*; +pub use linear_space::*; +#[cfg(feature = "log")] +pub use log::*; +#[cfg(feature = "moment")] +pub use moment::*; +pub use negate::*; +#[cfg(feature = "pct_change")] +pub use pct_change::*; +pub use polars_core::chunked_array::ops::search_sorted::SearchSortedSide; +use polars_core::prelude::*; +#[cfg(feature = "rank")] +pub use rank::*; +#[cfg(feature = "reinterpret")] +pub use reinterpret::*; +#[cfg(feature = "replace")] +pub use replace::*; +#[cfg(feature = "rle")] +pub use rle::*; +#[cfg(feature = "rolling_window")] +pub use rolling::*; +#[cfg(feature = "round_series")] +pub use round::*; +#[cfg(feature = "search_sorted")] +pub use search_sorted::*; +#[cfg(feature = "to_dummies")] +pub use to_dummies::*; +#[cfg(feature = "unique_counts")] +pub use unique::*; +pub use various::*; +mod not; + +#[cfg(feature = "dtype-array")] +pub mod concat_arr; +#[cfg(feature = "dtype-duration")] +pub(crate) mod duration; +#[cfg(feature = "dtype-duration")] +pub use duration::*; +pub use not::*; + +pub trait SeriesSealed { + fn as_series(&self) -> &Series; +} + +impl SeriesSealed for Series { + fn as_series(&self) -> &Series { + self + } +} diff --git a/crates/polars-ops/src/series/ops/moment.rs b/crates/polars-ops/src/series/ops/moment.rs new file mode 100644 index 000000000000..8bdc114e95c1 --- /dev/null +++ b/crates/polars-ops/src/series/ops/moment.rs @@ -0,0 +1,93 @@ +use polars_compute::moment::{KurtosisState, SkewState, kurtosis, skew}; +use polars_core::prelude::*; + +use crate::prelude::SeriesSealed; + +pub trait MomentSeries: SeriesSealed { + /// Compute the sample skewness of a data set. + /// + /// For normally distributed data, the skewness should be about zero. For + /// uni-modal continuous distributions, a skewness value greater than zero means + /// that there is more weight in the right tail of the distribution. The + /// function `skewtest` can be used to determine if the skewness value + /// is close enough to zero, statistically speaking. + /// + /// see: [scipy](https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/stats/stats.py#L1024) + fn skew(&self, bias: bool) -> PolarsResult> { + let s = self.as_series(); + let s = s.cast(&DataType::Float64)?; + let ca = s.f64().unwrap(); + + let mut state = SkewState::default(); + for arr in ca.downcast_iter() { + state.combine(&skew(arr)); + } + Ok(state.finalize(bias)) + } + + /// Compute the kurtosis (Fisher or Pearson) of a dataset. + /// + /// Kurtosis is the fourth central moment divided by the square of the + /// variance. If Fisher's definition is used, then 3.0 is subtracted from + /// the result to give 0.0 for a normal distribution. + /// If bias is `false` then the kurtosis is calculated using k statistics to + /// eliminate bias coming from biased moment estimators + /// + /// see: [scipy](https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/stats/stats.py#L1027) + fn kurtosis(&self, fisher: bool, bias: bool) -> PolarsResult> { + let s = self.as_series(); + let s = s.cast(&DataType::Float64)?; + let ca = s.f64().unwrap(); + + let mut state = KurtosisState::default(); + for arr in ca.downcast_iter() { + state.combine(&kurtosis(arr)); + } + Ok(state.finalize(fisher, bias)) + } +} + +impl MomentSeries for Series {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_skew() -> PolarsResult<()> { + let s = Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 4, 5, 23]); + let s2 = Series::new( + PlSmallStr::EMPTY, + &[Some(1), Some(2), Some(3), None, Some(1)], + ); + + assert!((s.skew(false)?.unwrap() - 2.2905330058490514).abs() < 0.0001); + assert!((s.skew(true)?.unwrap() - 1.6727687946848508).abs() < 0.0001); + + assert!((s2.skew(false)?.unwrap() - 0.8545630383279711).abs() < 0.0001); + assert!((s2.skew(true)?.unwrap() - 0.49338220021815865).abs() < 0.0001); + + Ok(()) + } + + #[test] + fn test_kurtosis() -> PolarsResult<()> { + let s = Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 4, 5, 23]); + + assert!((s.kurtosis(true, true)?.unwrap() - 0.9945668771797536).abs() < 0.0001); + assert!((s.kurtosis(true, false)?.unwrap() - 5.400820058440946).abs() < 0.0001); + assert!((s.kurtosis(false, true)?.unwrap() - 3.994566877179754).abs() < 0.0001); + assert!((s.kurtosis(false, false)?.unwrap() - 8.400820058440946).abs() < 0.0001); + + let s2 = Series::new( + PlSmallStr::EMPTY, + &[Some(1), Some(2), Some(3), None, Some(1), Some(2), Some(3)], + ); + assert!((s2.kurtosis(true, true)?.unwrap() - (-1.5)).abs() < 0.0001); + assert!((s2.kurtosis(true, false)?.unwrap() - (-1.875)).abs() < 0.0001); + assert!((s2.kurtosis(false, true)?.unwrap() - 1.5).abs() < 0.0001); + assert!((s2.kurtosis(false, false)?.unwrap() - 1.125).abs() < 0.0001); + + Ok(()) + } +} diff --git a/crates/polars-ops/src/series/ops/negate.rs b/crates/polars-ops/src/series/ops/negate.rs new file mode 100644 index 000000000000..7af246d2810a --- /dev/null +++ b/crates/polars-ops/src/series/ops/negate.rs @@ -0,0 +1,33 @@ +use polars_core::prelude::*; + +pub fn negate(s: &Series) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + #[cfg(feature = "dtype-i8")] + Int8 => s.i8().unwrap().wrapping_neg().into_series(), + #[cfg(feature = "dtype-i16")] + Int16 => s.i16().unwrap().wrapping_neg().into_series(), + Int32 => s.i32().unwrap().wrapping_neg().into_series(), + Int64 => s.i64().unwrap().wrapping_neg().into_series(), + Float32 => s.f32().unwrap().wrapping_neg().into_series(), + Float64 => s.f64().unwrap().wrapping_neg().into_series(), + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => { + let ca = s.decimal().unwrap(); + let precision = ca.precision(); + let scale = ca.scale(); + + let out = ca.as_ref().wrapping_neg(); + out.into_decimal_unchecked(precision, scale).into_series() + }, + #[cfg(feature = "dtype-duration")] + Duration(_) => { + let physical = s.to_physical_repr(); + let ca = physical.i64().unwrap(); + let out = ca.wrapping_neg().into_series(); + out.cast(s.dtype())? + }, + dt => polars_bail!(opq = neg, dt), + }; + Ok(out) +} diff --git a/crates/polars-ops/src/series/ops/not.rs b/crates/polars-ops/src/series/ops/not.rs new file mode 100644 index 000000000000..b6abf1559dce --- /dev/null +++ b/crates/polars-ops/src/series/ops/not.rs @@ -0,0 +1,18 @@ +use std::ops::Not; + +use polars_core::with_match_physical_integer_polars_type; + +use super::*; + +pub fn negate_bitwise(s: &Series) -> PolarsResult { + match s.dtype() { + DataType::Boolean => Ok(s.bool().unwrap().not().into_series()), + dt if dt.is_integer() => { + with_match_physical_integer_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_any().downcast_ref().unwrap(); + Ok(ca.apply_values(|v| !v).into_series()) + }) + }, + dt => polars_bail!(InvalidOperation: "dtype {:?} not supported in 'not' operation", dt), + } +} diff --git a/crates/polars-ops/src/series/ops/pct_change.rs b/crates/polars-ops/src/series/ops/pct_change.rs new file mode 100644 index 000000000000..9cb45dac1d6f --- /dev/null +++ b/crates/polars-ops/src/series/ops/pct_change.rs @@ -0,0 +1,25 @@ +use polars_core::prelude::*; +use polars_core::series::ops::NullBehavior; + +use crate::prelude::diff; + +pub fn pct_change(s: &Series, n: &Series) -> PolarsResult { + polars_ensure!( + n.len() == 1, + ComputeError: "n must be a single value." + ); + + match s.dtype() { + DataType::Float64 | DataType::Float32 => {}, + _ => return pct_change(&s.cast(&DataType::Float64)?, n), + } + + let fill_null_s = s.fill_null(FillNullStrategy::Forward(None))?; + + let n_s = n.cast(&DataType::Int64)?; + if let Some(n) = n_s.i64()?.get(0) { + diff(&fill_null_s, n, NullBehavior::Ignore)?.divide(&fill_null_s.shift(n)) + } else { + Ok(Series::full_null(s.name().clone(), s.len(), s.dtype())) + } +} diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs new file mode 100644 index 000000000000..0122ca8eed82 --- /dev/null +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -0,0 +1,347 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use arrow::array::BooleanArray; +use arrow::compute::concatenate::concatenate_validities; +use polars_core::prelude::*; +use rand::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::prelude::SeriesSealed; + +#[derive(Copy, Clone, Debug, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum RankMethod { + Average, + Min, + Max, + Dense, + Ordinal, + #[cfg(feature = "random")] + Random, +} + +// We might want to add a `nulls_last` or `null_behavior` field. +#[derive(Copy, Clone, Debug, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct RankOptions { + pub method: RankMethod, + pub descending: bool, +} + +impl Default for RankOptions { + fn default() -> Self { + Self { + method: RankMethod::Dense, + descending: false, + } + } +} + +#[cfg(feature = "random")] +fn get_random_seed() -> u64 { + let mut rng = SmallRng::from_entropy(); + + rng.next_u64() +} + +unsafe fn rank_impl(idxs: &IdxCa, neq: &BooleanArray, mut flush_ties: F) { + let mut ties_indices = Vec::with_capacity(128); + let mut idx_it = idxs.downcast_iter().flat_map(|arr| arr.values_iter()); + let Some(first_idx) = idx_it.next() else { + return; + }; + ties_indices.push(*first_idx); + + for (eq_idx, idx) in idx_it.enumerate() { + if neq.value_unchecked(eq_idx) { + flush_ties(&mut ties_indices); + ties_indices.clear() + } + + ties_indices.push(*idx); + } + flush_ties(&mut ties_indices); +} + +fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> Series { + let len = s.len(); + let null_count = s.null_count(); + + if null_count == len { + let dt = match method { + Average => DataType::Float64, + _ => IDX_DTYPE, + }; + return Series::full_null(s.name().clone(), s.len(), &dt); + } + + match len { + 1 => { + return match method { + Average => Series::new(s.name().clone(), &[1.0f64]), + _ => Series::new(s.name().clone(), &[1 as IdxSize]), + }; + }, + 0 => { + return match method { + Average => Float64Chunked::from_slice(s.name().clone(), &[]).into_series(), + _ => IdxCa::from_slice(s.name().clone(), &[]).into_series(), + }; + }, + _ => {}, + } + + if null_count == len { + return match method { + Average => Float64Chunked::full_null(s.name().clone(), len).into_series(), + _ => IdxCa::full_null(s.name().clone(), len).into_series(), + }; + } + + let sort_idx_ca = s + .arg_sort(SortOptions { + descending, + nulls_last: true, + ..Default::default() + }) + .slice(0, len - null_count); + + let validity = concatenate_validities(s.chunks()); + + use RankMethod::*; + if let Ordinal = method { + let mut out = vec![0 as IdxSize; s.len()]; + let mut rank = 0; + for arr in sort_idx_ca.downcast_iter() { + for i in arr.values_iter() { + out[*i as usize] = rank + 1; + rank += 1; + } + } + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() + } else { + let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) }; + let not_consecutive_same = sorted_values + .slice(1, sorted_values.len() - 1) + .not_equal(&sorted_values.slice(0, sorted_values.len() - 1)) + .unwrap(); + let neq = not_consecutive_same.rechunk(); + let neq = neq.downcast_as_array(); + + let mut rank = 1; + match method { + #[cfg(feature = "random")] + Random => unsafe { + let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed)); + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + ties.shuffle(&mut rng); + for i in ties { + *out.get_unchecked_mut(*i as usize) = rank; + rank += 1; + } + }); + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() + }, + Average => unsafe { + let mut out = vec![0.0; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + let first = rank; + rank += ties.len() as IdxSize; + let last = rank - 1; + let avg = 0.5 * (first as f64 + last as f64); + for i in ties { + *out.get_unchecked_mut(*i as usize) = avg; + } + }); + Float64Chunked::from_vec_validity(s.name().clone(), out, validity).into_series() + }, + Min => unsafe { + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + for i in ties.iter() { + *out.get_unchecked_mut(*i as usize) = rank; + } + rank += ties.len() as IdxSize; + }); + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() + }, + Max => unsafe { + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + rank += ties.len() as IdxSize; + for i in ties { + *out.get_unchecked_mut(*i as usize) = rank - 1; + } + }); + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() + }, + Dense => unsafe { + let mut out = vec![0 as IdxSize; s.len()]; + rank_impl(&sort_idx_ca, neq, |ties| { + for i in ties { + *out.get_unchecked_mut(*i as usize) = rank; + } + rank += 1; + }); + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() + }, + Ordinal => unreachable!(), + } + } +} + +pub trait SeriesRank: SeriesSealed { + fn rank(&self, options: RankOptions, seed: Option) -> Series { + rank(self.as_series(), options.method, options.descending, seed) + } +} + +impl SeriesRank for Series {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_rank() -> PolarsResult<()> { + let s = Series::new("a".into(), &[1, 2, 3, 2, 2, 3, 0]); + + let out = rank(&s, RankMethod::Ordinal, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]); + + #[cfg(feature = "random")] + { + let out = rank(&s, RankMethod::Random, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out[0], 2); + assert_eq!(out[6], 1); + assert_eq!(out[1] + out[3] + out[4], 12); + assert_eq!(out[2] + out[5], 13); + assert_ne!(out[1], out[3]); + assert_ne!(out[1], out[4]); + assert_ne!(out[3], out[4]); + } + + let out = rank(&s, RankMethod::Dense, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]); + + let out = rank(&s, RankMethod::Max, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]); + + let out = rank(&s, RankMethod::Min, false, None) + .idx()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]); + + let out = rank(&s, RankMethod::Average, false, None) + .f64()? + .into_no_null_iter() + .collect::>(); + assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]); + + let s = Series::new( + "a".into(), + &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)], + ); + + let out = rank(&s, RankMethod::Average, false, None) + .f64()? + .into_iter() + .collect::>(); + + assert_eq!( + out, + &[ + Some(2.0f64), + Some(3.5), + Some(5.0), + Some(3.5), + None, + None, + Some(1.0) + ] + ); + let s = Series::new( + "a".into(), + &[ + Some(5), + Some(6), + Some(4), + None, + Some(78), + Some(4), + Some(2), + Some(8), + ], + ); + let out = rank(&s, RankMethod::Max, false, None) + .idx()? + .into_iter() + .collect::>(); + assert_eq!( + out, + &[ + Some(4), + Some(5), + Some(3), + None, + Some(7), + Some(3), + Some(1), + Some(6) + ] + ); + + Ok(()) + } + + #[test] + fn test_rank_all_null() -> PolarsResult<()> { + let s = UInt32Chunked::new("".into(), &[None, None, None]).into_series(); + let out = rank(&s, RankMethod::Average, false, None) + .f64()? + .into_iter() + .collect::>(); + assert_eq!(out, &[None, None, None]); + let out = rank(&s, RankMethod::Dense, false, None) + .idx()? + .into_iter() + .collect::>(); + assert_eq!(out, &[None, None, None]); + Ok(()) + } + + #[test] + fn test_rank_empty() { + let s = UInt32Chunked::from_slice("".into(), &[]).into_series(); + let out = rank(&s, RankMethod::Average, false, None); + assert_eq!(out.dtype(), &DataType::Float64); + let out = rank(&s, RankMethod::Max, false, None); + assert_eq!(out.dtype(), &IDX_DTYPE); + } + + #[test] + fn test_rank_reverse() -> PolarsResult<()> { + let s = Series::new("".into(), &[None, Some(1), Some(1), Some(5), None]); + let out = rank(&s, RankMethod::Dense, true, None) + .idx()? + .into_iter() + .collect::>(); + assert_eq!(out, &[None, Some(2 as IdxSize), Some(2), Some(1), None]); + + Ok(()) + } +} diff --git a/crates/polars-ops/src/series/ops/reinterpret.rs b/crates/polars-ops/src/series/ops/reinterpret.rs new file mode 100644 index 000000000000..7a271ed0c0a5 --- /dev/null +++ b/crates/polars-ops/src/series/ops/reinterpret.rs @@ -0,0 +1,18 @@ +use polars_core::prelude::*; + +pub fn reinterpret(s: &Series, signed: bool) -> PolarsResult { + Ok(match (s.dtype(), signed) { + (DataType::UInt64, true) => s.u64().unwrap().reinterpret_signed().into_series(), + (DataType::UInt64, false) => s.clone(), + (DataType::Int64, false) => s.i64().unwrap().reinterpret_unsigned().into_series(), + (DataType::Int64, true) => s.clone(), + (DataType::UInt32, true) => s.u32().unwrap().reinterpret_signed().into_series(), + (DataType::UInt32, false) => s.clone(), + (DataType::Int32, false) => s.i32().unwrap().reinterpret_unsigned().into_series(), + (DataType::Int32, true) => s.clone(), + _ => polars_bail!( + ComputeError: + "reinterpret is only allowed for 64-bit/32-bit integers types, use cast otherwise" + ), + }) +} diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs new file mode 100644 index 000000000000..24168cc6baf6 --- /dev/null +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -0,0 +1,288 @@ +use polars_core::prelude::*; +use polars_core::utils::try_get_supertype; +use polars_error::polars_ensure; + +use crate::frame::join::*; +use crate::prelude::*; + +/// Replace values by different values of the same data type. +pub fn replace(s: &Series, old: &Series, new: &Series) -> PolarsResult { + if old.is_empty() { + return Ok(s.clone()); + } + validate_old(old)?; + + let dtype = s.dtype(); + let old = cast_old_to_series_dtype(old, dtype)?; + let new = new.strict_cast(dtype)?; + + if new.len() == 1 { + replace_by_single(s, &old, &new, s) + } else { + replace_by_multiple(s, old, new, s) + } +} + +/// Replace all values by different values. +/// +/// Unmatched values are replaced by a default value. +pub fn replace_or_default( + s: &Series, + old: &Series, + new: &Series, + default: &Series, + return_dtype: Option, +) -> PolarsResult { + polars_ensure!( + default.len() == s.len() || default.len() == 1, + InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1" + ); + validate_old(old)?; + + let return_dtype = match return_dtype { + Some(dtype) => dtype, + None => try_get_supertype(new.dtype(), default.dtype())?, + }; + let default = default.cast(&return_dtype)?; + + if old.is_empty() { + let out = if default.len() == 1 && s.len() != 1 { + default.new_from_index(0, s.len()) + } else { + default + }; + return Ok(out); + } + + let old = cast_old_to_series_dtype(old, s.dtype())?; + let new = new.cast(&return_dtype)?; + + if new.len() == 1 { + replace_by_single(s, &old, &new, &default) + } else { + replace_by_multiple(s, old, new, &default) + } +} + +/// Replace all values by different values. +/// +/// Raises an error if not all values were replaced. +pub fn replace_strict( + s: &Series, + old: &Series, + new: &Series, + return_dtype: Option, +) -> PolarsResult { + if old.is_empty() { + polars_ensure!( + s.len() == s.null_count(), + InvalidOperation: "must specify which values to replace" + ); + return Ok(s.clone()); + } + validate_old(old)?; + + let old = cast_old_to_series_dtype(old, s.dtype())?; + let new = match return_dtype { + Some(dtype) => new.strict_cast(&dtype)?, + None => new.clone(), + }; + + if new.len() == 1 { + replace_by_single_strict(s, &old, &new) + } else { + replace_by_multiple_strict(s, old, new) + } +} + +/// Validate the `old` input. +fn validate_old(old: &Series) -> PolarsResult<()> { + polars_ensure!( + old.n_unique()? == old.len(), + InvalidOperation: "`old` input for `replace` must not contain duplicates" + ); + Ok(()) +} + +/// Cast `old` input while enabling String to Categorical casts. +fn cast_old_to_series_dtype(old: &Series, dtype: &DataType) -> PolarsResult { + match (old.dtype(), dtype) { + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, ord)) => { + let empty_categorical_dtype = DataType::Categorical(None, *ord); + old.strict_cast(&empty_categorical_dtype) + }, + _ => old.strict_cast(dtype), + } +} + +// Fast path for replacing by a single value +fn replace_by_single( + s: &Series, + old: &Series, + new: &Series, + default: &Series, +) -> PolarsResult { + let mut mask = get_replacement_mask(s, old)?; + if old.null_count() > 0 { + mask = mask.fill_null_with_values(true)?; + } + new.zip_with(&mask, default) +} +/// Fast path for replacing by a single value in strict mode +fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult { + let mask = get_replacement_mask(s, old)?; + ensure_all_replaced(&mask, s, old.null_count() > 0, true)?; + + let mut out = new.new_from_index(0, s.len()); + + // Transfer validity from `mask` to `out`. + if mask.null_count() > 0 { + out = out.zip_with(&mask, &Series::new_null(PlSmallStr::EMPTY, s.len()))? + } + Ok(out) +} +/// Get a boolean mask of which values in the original Series will be replaced. +/// +/// Null values are propagated to the mask. +fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult { + if old.null_count() == old.len() { + // Fast path for when users are using `replace(None, ...)` instead of `fill_null`. + Ok(s.is_null()) + } else { + let old = old.implode()?; + is_in(s, &old.into_series(), false) + } +} + +/// General case for replacing by multiple values +fn replace_by_multiple( + s: &Series, + old: Series, + new: Series, + default: &Series, +) -> PolarsResult { + validate_new(&new, &old)?; + + let df = s.clone().into_frame(); + let add_replacer_mask = new.null_count() > 0; + let replacer = create_replacer(old, new, add_replacer_mask)?; + + let joined = df.join( + &replacer, + [s.name().as_str()], + ["__POLARS_REPLACE_OLD"], + JoinArgs { + how: JoinType::Left, + coalesce: JoinCoalesce::CoalesceColumns, + nulls_equal: true, + ..Default::default() + }, + None, + )?; + + let replaced = joined + .column("__POLARS_REPLACE_NEW") + .unwrap() + .as_materialized_series(); + + if replaced.null_count() == 0 { + return Ok(replaced.clone()); + } + + match joined.column("__POLARS_REPLACE_MASK") { + Ok(col) => { + let mask = col.bool().unwrap(); + replaced.zip_with(mask, default) + }, + Err(_) => { + let mask = &replaced.is_not_null(); + replaced.zip_with(mask, default) + }, + } +} + +/// General case for replacing by multiple values in strict mode +fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult { + validate_new(&new, &old)?; + + let df = s.clone().into_frame(); + let old_has_null = old.null_count() > 0; + let replacer = create_replacer(old, new, true)?; + + let joined = df.join( + &replacer, + [s.name().as_str()], + ["__POLARS_REPLACE_OLD"], + JoinArgs { + how: JoinType::Left, + coalesce: JoinCoalesce::CoalesceColumns, + nulls_equal: true, + ..Default::default() + }, + None, + )?; + + let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap(); + + let mask = joined + .column("__POLARS_REPLACE_MASK") + .unwrap() + .bool() + .unwrap(); + ensure_all_replaced(mask, s, old_has_null, false)?; + + Ok(replaced.as_materialized_series().clone()) +} + +// Build replacer dataframe. +fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult { + old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD")); + new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW")); + + let len = old.len(); + let cols = if add_mask { + let mask = Column::new_scalar( + PlSmallStr::from_static("__POLARS_REPLACE_MASK"), + true.into(), + new.len(), + ); + vec![old.into(), new.into(), mask] + } else { + vec![old.into(), new.into()] + }; + let out = unsafe { DataFrame::new_no_checks(len, cols) }; + Ok(out) +} + +/// Validate the `new` input. +fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> { + polars_ensure!( + new.len() == old.len(), + InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1" + ); + Ok(()) +} + +/// Ensure that all values were replaced. +fn ensure_all_replaced( + mask: &BooleanChunked, + s: &Series, + old_has_null: bool, + check_all: bool, +) -> PolarsResult<()> { + let nulls_check = if old_has_null { + mask.null_count() == 0 + } else { + mask.null_count() == s.null_count() + }; + // Checking booleans is only relevant for the 'replace_by_single' path. + let bools_check = !check_all || mask.all(); + + let all_replaced = bools_check && nulls_check; + polars_ensure!( + all_replaced, + InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values." + ); + Ok(()) +} diff --git a/crates/polars-ops/src/series/ops/rle.rs b/crates/polars-ops/src/series/ops/rle.rs new file mode 100644 index 000000000000..809989f7afd9 --- /dev/null +++ b/crates/polars-ops/src/series/ops/rle.rs @@ -0,0 +1,90 @@ +use polars_core::prelude::*; +use polars_core::series::IsSorted; + +/// Get the run-Lengths of values. +pub fn rle_lengths(s: &Column, lengths: &mut Vec) -> PolarsResult<()> { + lengths.clear(); + if s.is_empty() { + return Ok(()); + } + + if let Some(sc) = s.as_scalar_column() { + lengths.push(sc.len() as IdxSize); + return Ok(()); + } + + let (s1, s2) = (s.slice(0, s.len() - 1), s.slice(1, s.len())); + let s_neq = s1 + .as_materialized_series() + .not_equal_missing(s2.as_materialized_series())?; + let n_runs = s_neq.sum().unwrap() + 1; + + lengths.reserve(n_runs as usize); + lengths.push(1); + + assert!(!s_neq.has_nulls()); + for arr in s_neq.downcast_iter() { + let mut values = arr.values().clone(); + while !values.is_empty() { + // @NOTE: This `as IdxSize` is safe because it is less than or equal to the a ChunkedArray + // length. + *lengths.last_mut().unwrap() += values.take_leading_zeros() as IdxSize; + + if !values.is_empty() { + lengths.push(1); + values.slice(1, values.len() - 1); + } + } + } + Ok(()) +} + +/// Get the lengths of runs of identical values. +pub fn rle(s: &Column) -> PolarsResult { + let mut lengths = Vec::new(); + rle_lengths(s, &mut lengths)?; + + let mut idxs = Vec::with_capacity(lengths.len()); + if !lengths.is_empty() { + idxs.push(0); + for length in &lengths[..lengths.len() - 1] { + idxs.push(*idxs.last().unwrap() + length); + } + } + + let vals = s + .take_slice(&idxs) + .unwrap() + .with_name(PlSmallStr::from_static("value")); + let outvals = vec![ + Series::from_vec(PlSmallStr::from_static("len"), lengths).into(), + vals, + ]; + Ok(StructChunked::from_columns(s.name().clone(), idxs.len(), &outvals)?.into_column()) +} + +/// Similar to `rle`, but maps values to run IDs. +pub fn rle_id(s: &Column) -> PolarsResult { + if s.is_empty() { + return Ok(Column::new_empty(s.name().clone(), &IDX_DTYPE)); + } + + let (s1, s2) = (s.slice(0, s.len() - 1), s.slice(1, s.len())); + let s_neq = s1 + .as_materialized_series() + .not_equal_missing(s2.as_materialized_series())?; + + let mut out = Vec::::with_capacity(s.len()); + let mut last = 0; + out.push(last); // Run numbers start at zero + assert_eq!(s_neq.null_count(), 0); + for a in s_neq.downcast_iter() { + for aa in a.values_iter() { + last += aa as IdxSize; + out.push(last); + } + } + Ok(IdxCa::from_vec(s.name().clone(), out) + .with_sorted_flag(IsSorted::Ascending) + .into_column()) +} diff --git a/crates/polars-ops/src/series/ops/rolling.rs b/crates/polars-ops/src/series/ops/rolling.rs new file mode 100644 index 000000000000..57d96de279be --- /dev/null +++ b/crates/polars-ops/src/series/ops/rolling.rs @@ -0,0 +1,126 @@ +use polars_core::prelude::*; +#[cfg(feature = "moment")] +use { + num_traits::{Float, pow::Pow}, + std::ops::SubAssign, +}; + +#[cfg(feature = "moment")] +fn rolling_skew_ca( + ca: &ChunkedArray, + window_size: usize, + min_periods: usize, + center: bool, + params: Option, +) -> PolarsResult> +where + ChunkedArray: IntoSeries, + T: PolarsFloatType, + T::Native: Float + SubAssign + Pow, +{ + use arrow::array::Array; + + let ca = ca.rechunk(); + let arr = ca.downcast_get(0).unwrap(); + let arr = if arr.has_nulls() { + polars_compute::rolling::nulls::rolling_skew(arr, window_size, min_periods, center, params) + } else { + let values = arr.values(); + polars_compute::rolling::no_nulls::rolling_skew( + values, + window_size, + min_periods, + center, + params, + )? + }; + Ok(unsafe { ca.with_chunks(vec![arr]) }) +} + +#[cfg(feature = "moment")] +pub fn rolling_skew(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + let window_size = options.window_size; + let min_periods = options.min_periods; + let center = options.center; + let params = options.fn_params; + + match s.dtype() { + DataType::Float64 => { + let ca = s.f64().unwrap(); + rolling_skew_ca(ca, window_size, min_periods, center, params).map(|ca| ca.into_series()) + }, + DataType::Float32 => { + let ca = s.f32().unwrap(); + rolling_skew_ca(ca, window_size, min_periods, center, params).map(|ca| ca.into_series()) + }, + dt if dt.is_primitive_numeric() => { + let s = s.cast(&DataType::Float64).unwrap(); + rolling_skew(&s, options) + }, + dt => polars_bail!(opq = rolling_skew, dt), + } +} + +#[cfg(feature = "moment")] +fn rolling_kurtosis_ca( + ca: &ChunkedArray, + window_size: usize, + params: Option, + min_periods: usize, + center: bool, +) -> PolarsResult> +where + ChunkedArray: IntoSeries, + T: PolarsFloatType, + T::Native: Float + SubAssign + Pow, +{ + use arrow::array::Array; + + let ca = ca.rechunk(); + let arr = ca.downcast_get(0).unwrap(); + let arr = if arr.has_nulls() { + polars_compute::rolling::nulls::rolling_kurtosis( + arr, + window_size, + min_periods, + center, + params, + ) + } else { + let values = arr.values(); + polars_compute::rolling::no_nulls::rolling_kurtosis( + values, + window_size, + min_periods, + center, + params, + )? + }; + Ok(unsafe { ca.with_chunks(vec![arr]) }) +} + +#[cfg(feature = "moment")] +pub fn rolling_kurtosis(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + let window_size = options.window_size; + let min_periods = options.min_periods; + let center = options.center; + let params = options.fn_params; + + match s.dtype() { + DataType::Float64 => { + let ca = s.f64().unwrap(); + rolling_kurtosis_ca(ca, window_size, params, min_periods, center) + .map(|ca| ca.into_series()) + }, + DataType::Float32 => { + let ca = s.f32().unwrap(); + rolling_kurtosis_ca(ca, window_size, params, min_periods, center) + .map(|ca| ca.into_series()) + }, + dt if dt.is_primitive_numeric() => { + let s = s.cast(&DataType::Float64).unwrap(); + rolling_kurtosis(&s, options) + }, + dt => polars_bail!(opq = rolling_kurtosis, dt), + } +} diff --git a/crates/polars-ops/src/series/ops/round.rs b/crates/polars-ops/src/series/ops/round.rs new file mode 100644 index 000000000000..e151cc2b184a --- /dev/null +++ b/crates/polars-ops/src/series/ops/round.rs @@ -0,0 +1,273 @@ +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +use crate::series::ops::SeriesSealed; + +pub trait RoundSeries: SeriesSealed { + /// Round underlying floating point array to given decimal. + fn round(&self, decimals: u32) -> PolarsResult { + let s = self.as_series(); + + if let Ok(ca) = s.f32() { + return if decimals == 0 { + let s = ca.apply_values(|val| val.round()).into_series(); + Ok(s) + } else if decimals >= 326 { + // More precise than smallest denormal. + Ok(s.clone()) + } else { + // Note we do the computation on f64 floats to not lose precision + // when the computation is done, we cast to f32 + let multiplier = 10.0_f64.powi(decimals as i32); + let s = ca + .apply_values(|val| { + let ret = ((val as f64 * multiplier).round() / multiplier) as f32; + if ret.is_finite() { + ret + } else { + // We return the original value which is correct both for overflows and non-finite inputs. + val + } + }) + .into_series(); + Ok(s) + }; + } + if let Ok(ca) = s.f64() { + return if decimals == 0 { + let s = ca.apply_values(|val| val.round()).into_series(); + Ok(s) + } else if decimals >= 326 { + // More precise than smallest denormal. + Ok(s.clone()) + } else if decimals >= 300 { + // We're getting into unrepresentable territory for the multiplier + // here, split up the 10^n multiplier into 2^n and 5^n. + let mul2 = libm::scalbn(1.0, decimals as i32); + let invmul2 = 1.0 / mul2; // Still exact for any valid value of decimals. + let mul5 = 5.0_f64.powi(decimals as i32); + let s = ca + .apply_values(|val| { + let ret = (val * mul2 * mul5).round() / mul5 * invmul2; + if ret.is_finite() { + ret + } else { + // We return the original value which is correct both for overflows and non-finite inputs. + val + } + }) + .into_series(); + Ok(s) + } else { + let multiplier = 10.0_f64.powi(decimals as i32); + let s = ca + .apply_values(|val| { + let ret = (val * multiplier).round() / multiplier; + if ret.is_finite() { + ret + } else { + // We return the original value which is correct both for overflows and non-finite inputs. + val + } + }) + .into_series(); + Ok(s) + }; + } + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + let precision = ca.precision(); + let scale = ca.scale() as u32; + + if scale <= decimals { + return Ok(ca.clone().into_series()); + } + + let decimal_delta = scale - decimals; + let multiplier = 10i128.pow(decimal_delta); + let threshold = multiplier / 2; + + let ca = ca + .apply_values(|v| { + // We use rounding=ROUND_HALF_EVEN + let rem = v % multiplier; + let is_v_floor_even = ((v - rem) / multiplier) % 2 == 0; + let threshold = threshold + i128::from(is_v_floor_even); + let round_offset = if rem.abs() >= threshold { + multiplier + } else { + 0 + }; + let round_offset = if v < 0 { -round_offset } else { round_offset }; + v - rem + round_offset + }) + .into_decimal_unchecked(precision, scale as usize); + + return Ok(ca.into_series()); + } + + polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "round can only be used on numeric types" ); + Ok(s.clone()) + } + + fn round_sig_figs(&self, digits: i32) -> PolarsResult { + let s = self.as_series(); + polars_ensure!(digits >= 1, InvalidOperation: "digits must be an integer >= 1"); + + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + let precision = ca.precision(); + let scale = ca.scale() as u32; + + let s = ca + .apply_values(|v| { + if v == 0 { + return 0; + } + + let mut magnitude = v.abs().ilog10(); + let magnitude_mult = 10i128.pow(magnitude); // @Q? It might be better to do this with a + // LUT. + if v.abs() > magnitude_mult { + magnitude += 1; + } + let decimals = magnitude.saturating_sub(digits as u32); + let multiplier = 10i128.pow(decimals); // @Q? It might be better to do this with a + // LUT. + let threshold = multiplier / 2; + + // We use rounding=ROUND_HALF_EVEN + let rem = v % multiplier; + let is_v_floor_even = decimals <= scale && ((v - rem) / multiplier) % 2 == 0; + let threshold = threshold + i128::from(is_v_floor_even); + let round_offset = if rem.abs() >= threshold { + multiplier + } else { + 0 + }; + let round_offset = if v < 0 { -round_offset } else { round_offset }; + v - rem + round_offset + }) + .into_decimal_unchecked(precision, scale as usize) + .into_series(); + + return Ok(s); + } + + polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "round_sig_figs can only be used on numeric types" ); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let s = ca.apply_values(|value| { + let value = value as f64; + if value == 0.0 { + return value as <$T as PolarsNumericType>::Native; + } + // To deal with very large/small numbers we split up 10^n in 5^n and 2^n. + // The scaling by 2^n is almost always lossless. + let exp = digits - 1 - value.abs().log10().floor() as i32; + let pow5 = 5.0_f64.powi(exp); + let scaled = libm::scalbn(value, exp) * pow5; + let descaled = libm::scalbn(scaled.round() / pow5, -exp); + if descaled.is_finite() { + descaled as <$T as PolarsNumericType>::Native + } else { + value as <$T as PolarsNumericType>::Native + } + }).into_series(); + return Ok(s); + }); + } + + /// Floor underlying floating point array to the lowest integers smaller or equal to the float value. + fn floor(&self) -> PolarsResult { + let s = self.as_series(); + + if let Ok(ca) = s.f32() { + let s = ca.apply_values(|val| val.floor()).into_series(); + return Ok(s); + } + if let Ok(ca) = s.f64() { + let s = ca.apply_values(|val| val.floor()).into_series(); + return Ok(s); + } + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + let precision = ca.precision(); + let scale = ca.scale() as u32; + if scale == 0 { + return Ok(ca.clone().into_series()); + } + + let decimal_delta = scale; + let multiplier = 10i128.pow(decimal_delta); + + let ca = ca + .apply_values(|v| { + let rem = v % multiplier; + let round_offset = if v < 0 { multiplier + rem } else { rem }; + let round_offset = if rem == 0 { 0 } else { round_offset }; + v - round_offset + }) + .into_decimal_unchecked(precision, scale as usize); + + return Ok(ca.into_series()); + } + + polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "floor can only be used on numeric types" ); + Ok(s.clone()) + } + + /// Ceil underlying floating point array to the highest integers smaller or equal to the float value. + fn ceil(&self) -> PolarsResult { + let s = self.as_series(); + + if let Ok(ca) = s.f32() { + let s = ca.apply_values(|val| val.ceil()).into_series(); + return Ok(s); + } + if let Ok(ca) = s.f64() { + let s = ca.apply_values(|val| val.ceil()).into_series(); + return Ok(s); + } + #[cfg(feature = "dtype-decimal")] + if let Some(ca) = s.try_decimal() { + let precision = ca.precision(); + let scale = ca.scale() as u32; + if scale == 0 { + return Ok(ca.clone().into_series()); + } + + let decimal_delta = scale; + let multiplier = 10i128.pow(decimal_delta); + + let ca = ca + .apply_values(|v| { + let rem = v % multiplier; + let round_offset = if v < 0 { -rem } else { multiplier - rem }; + let round_offset = if rem == 0 { 0 } else { round_offset }; + v + round_offset + }) + .into_decimal_unchecked(precision, scale as usize); + + return Ok(ca.into_series()); + } + + polars_ensure!(s.dtype().is_primitive_numeric(), InvalidOperation: "ceil can only be used on numeric types" ); + Ok(s.clone()) + } +} + +impl RoundSeries for Series {} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_round_series() { + let series = Series::new("a".into(), &[1.003, 2.23222, 3.4352]); + let out = series.round(2).unwrap(); + let ca = out.f64().unwrap(); + assert_eq!(ca.get(0), Some(1.0)); + } +} diff --git a/crates/polars-ops/src/series/ops/search_sorted.rs b/crates/polars-ops/src/series/ops/search_sorted.rs new file mode 100644 index 000000000000..ed4e5dc9db85 --- /dev/null +++ b/crates/polars-ops/src/series/ops/search_sorted.rs @@ -0,0 +1,98 @@ +use polars_core::chunked_array::ops::search_sorted::{SearchSortedSide, binary_search_ca}; +use polars_core::prelude::row_encode::_get_rows_encoded_ca; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +pub fn search_sorted( + s: &Series, + search_values: &Series, + side: SearchSortedSide, + descending: bool, +) -> PolarsResult { + let original_dtype = s.dtype(); + + if s.dtype().is_categorical() { + // See https://github.com/pola-rs/polars/issues/20171 + polars_bail!(InvalidOperation: "'search_sorted' is not supported on dtype: {}", s.dtype()) + } + + let s = s.to_physical_repr(); + let phys_dtype = s.dtype(); + + match phys_dtype { + DataType::String => { + let ca = s.str().unwrap(); + let ca = ca.as_binary(); + let search_values = search_values.str()?; + let search_values = search_values.as_binary(); + let idx = binary_search_ca(&ca, search_values.iter(), side, descending); + Ok(IdxCa::new_vec(s.name().clone(), idx)) + }, + DataType::Boolean => { + let ca = s.bool().unwrap(); + let search_values = search_values.bool()?; + + let mut none_pos = None; + let mut false_pos = None; + let mut true_pos = None; + let idxs = search_values + .iter() + .map(|v| { + let cache = match v { + None => &mut none_pos, + Some(false) => &mut false_pos, + Some(true) => &mut true_pos, + }; + *cache.get_or_insert_with(|| { + binary_search_ca(ca, [v].into_iter(), side, descending)[0] + }) + }) + .collect(); + Ok(IdxCa::new_vec(s.name().clone(), idxs)) + }, + DataType::Binary => { + let ca = s.binary().unwrap(); + + let idx = match search_values.dtype() { + DataType::BinaryOffset => { + let search_values = search_values.binary_offset().unwrap(); + binary_search_ca(ca, search_values.iter(), side, descending) + }, + DataType::Binary => { + let search_values = search_values.binary().unwrap(); + binary_search_ca(ca, search_values.iter(), side, descending) + }, + _ => unreachable!(), + }; + + Ok(IdxCa::new_vec(s.name().clone(), idx)) + }, + dt if dt.is_primitive_numeric() => { + let search_values = search_values.to_physical_repr(); + + let idx = with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let search_values: &ChunkedArray<$T> = search_values.as_ref().as_ref().as_ref(); + binary_search_ca(ca, search_values.iter(), side, descending) + }); + Ok(IdxCa::new_vec(s.name().clone(), idx)) + }, + dt if dt.is_nested() => { + let ca = _get_rows_encoded_ca( + "".into(), + &[s.as_ref().clone().into_column()], + &[descending], + &[false], + )?; + let search_values = _get_rows_encoded_ca( + "".into(), + &[search_values.clone().into_column()], + &[descending], + &[false], + )?; + let idx = binary_search_ca(&ca, search_values.iter(), side, false); + Ok(IdxCa::new_vec(s.name().clone(), idx)) + }, + _ => polars_bail!(opq = search_sorted, original_dtype), + } +} diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs new file mode 100644 index 000000000000..5737d35b0eae --- /dev/null +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -0,0 +1,83 @@ +use polars_utils::format_pl_smallstr; + +use super::*; + +#[cfg(feature = "dtype-u8")] +type DummyType = u8; +#[cfg(feature = "dtype-u8")] +type DummyCa = UInt8Chunked; + +#[cfg(not(feature = "dtype-u8"))] +type DummyType = i32; +#[cfg(not(feature = "dtype-u8"))] +type DummyCa = Int32Chunked; + +pub trait ToDummies { + fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult; +} + +impl ToDummies for Series { + fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> PolarsResult { + let sep = separator.unwrap_or("_"); + let col_name = self.name(); + let groups = self.group_tuples(true, drop_first)?; + + // SAFETY: groups are in bounds + let columns = unsafe { self.agg_first(&groups) }; + let columns = columns.iter().zip(groups.iter()).skip(drop_first as usize); + let columns = columns + .map(|(av, group)| { + // strings are formatted with extra \" \" in polars, so we + // extract the string + let name = if let Some(s) = av.get_str() { + format_pl_smallstr!("{col_name}{sep}{s}") + } else { + // other types don't have this formatting issue + format_pl_smallstr!("{col_name}{sep}{av}") + }; + + let ca = match group { + GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name), + GroupsIndicator::Slice([offset, len]) => { + dummies_helper_slice(offset, len, self.len(), name) + }, + }; + ca.into_column() + }) + .collect::>(); + + DataFrame::new(sort_columns(columns)) + } +} + +fn dummies_helper_idx(groups: &[IdxSize], len: usize, name: PlSmallStr) -> DummyCa { + let mut av = vec![0 as DummyType; len]; + + for &idx in groups { + let elem = unsafe { av.get_unchecked_mut(idx as usize) }; + *elem = 1; + } + + ChunkedArray::from_vec(name, av) +} + +fn dummies_helper_slice( + group_offset: IdxSize, + group_len: IdxSize, + len: usize, + name: PlSmallStr, +) -> DummyCa { + let mut av = vec![0 as DummyType; len]; + + for idx in group_offset..(group_offset + group_len) { + let elem = unsafe { av.get_unchecked_mut(idx as usize) }; + *elem = 1; + } + + ChunkedArray::from_vec(name, av) +} + +fn sort_columns(mut columns: Vec) -> Vec { + columns.sort_by(|a, b| a.name().partial_cmp(b.name()).unwrap()); + columns +} diff --git a/crates/polars-ops/src/series/ops/unique.rs b/crates/polars-ops/src/series/ops/unique.rs new file mode 100644 index 000000000000..906aa304d186 --- /dev/null +++ b/crates/polars-ops/src/series/ops/unique.rs @@ -0,0 +1,55 @@ +use std::hash::Hash; + +use polars_core::hashing::_HASHMAP_INIT_SIZE; +use polars_core::prelude::*; +use polars_core::utils::NoNull; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; + +fn unique_counts_helper(items: I) -> IdxCa +where + I: Iterator, + J: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, +{ + let mut map = PlIndexMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + for item in items { + let item = item.to_total_ord(); + map.entry(item) + .and_modify(|cnt| { + *cnt += 1; + }) + .or_insert(1 as IdxSize); + } + let out: NoNull = map.into_values().collect(); + out.into_inner() +} + +/// Returns a count of the unique values in the order of appearance. +pub fn unique_counts(s: &Series) -> PolarsResult { + if s.dtype().to_physical().is_primitive_numeric() { + let s_physical = s.to_physical_repr(); + + with_match_physical_numeric_polars_type!(s_physical.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s_physical.as_ref().as_ref().as_ref(); + Ok(unique_counts_helper(ca.iter()).into_series()) + }) + } else { + match s.dtype() { + DataType::String => { + Ok(unique_counts_helper(s.str().unwrap().into_iter()).into_series()) + }, + DataType::Null => { + let ca = if s.is_empty() { + IdxCa::new(s.name().clone(), [] as [IdxSize; 0]) + } else { + IdxCa::new(s.name().clone(), [s.len() as IdxSize]) + }; + Ok(ca.into_series()) + }, + dt => { + polars_bail!(opq = unique_counts, dt) + }, + } + } +} diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs new file mode 100644 index 000000000000..9ea9bd737495 --- /dev/null +++ b/crates/polars-ops/src/series/ops/various.rs @@ -0,0 +1,214 @@ +use num_traits::Bounded; +#[cfg(feature = "dtype-struct")] +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_ca; +use polars_core::prelude::arity::unary_elementwise_values; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::aliases::PlSeedableRandomStateQuality; +use polars_utils::total_ord::TotalOrd; + +use crate::series::ops::SeriesSealed; + +pub trait SeriesMethods: SeriesSealed { + /// Create a [`DataFrame`] with the unique `values` of this [`Series`] and a column `"counts"` + /// with dtype [`IdxType`] + fn value_counts( + &self, + sort: bool, + parallel: bool, + name: PlSmallStr, + normalize: bool, + ) -> PolarsResult { + let s = self.as_series(); + polars_ensure!( + s.name() != &name, + Duplicate: "using `value_counts` on a column/series named '{}' would lead to duplicate \ + column names; change `name` to fix", name, + ); + // we need to sort here as well in case of `maintain_order` because duplicates behavior is undefined + let groups = s.group_tuples(parallel, sort)?; + let values = unsafe { s.agg_first(&groups) }.into(); + let counts = groups.group_count().with_name(name.clone()); + + let counts = if normalize { + let len = s.len() as f64; + let counts: Float64Chunked = + unary_elementwise_values(&counts, |count| count as f64 / len); + counts.into_column() + } else { + counts.into_column() + }; + + let height = counts.len(); + let cols = vec![values, counts]; + let df = unsafe { DataFrame::new_no_checks(height, cols) }; + if sort { + df.sort( + [name], + SortMultipleOptions::default() + .with_order_descending(true) + .with_multithreaded(parallel), + ) + } else { + Ok(df) + } + } + + #[cfg(feature = "hash")] + fn hash(&self, build_hasher: PlSeedableRandomStateQuality) -> UInt64Chunked { + let s = self.as_series().to_physical_repr(); + match s.dtype() { + DataType::List(_) => { + let mut ca = s.list().unwrap().clone(); + crate::chunked_array::hash::hash(&mut ca, build_hasher) + }, + _ => { + let mut h = vec![]; + s.0.vec_hash(build_hasher, &mut h).unwrap(); + UInt64Chunked::from_vec(s.name().clone(), h) + }, + } + } + + fn ensure_sorted_arg(&self, operation: &str) -> PolarsResult<()> { + polars_ensure!(self.is_sorted(Default::default())?, InvalidOperation: "argument in operation '{}' is not sorted, please sort the 'expr/series/column' first", operation); + Ok(()) + } + + /// Checks if a [`Series`] is sorted. Tries to fail fast. + fn is_sorted(&self, options: SortOptions) -> PolarsResult { + let s = self.as_series(); + let null_count = s.null_count(); + + // fast paths + if (options.descending + && (options.nulls_last || null_count == 0) + && matches!(s.is_sorted_flag(), IsSorted::Descending)) + || (!options.descending + && (!options.nulls_last || null_count == 0) + && matches!(s.is_sorted_flag(), IsSorted::Ascending)) + { + return Ok(true); + } + + // for struct types we row-encode and recurse + #[cfg(feature = "dtype-struct")] + if matches!(s.dtype(), DataType::Struct(_)) { + let encoded = _get_rows_encoded_ca( + PlSmallStr::EMPTY, + &[s.clone().into()], + &[options.descending], + &[options.nulls_last], + )?; + return encoded.into_series().is_sorted(options); + } + + let s_len = s.len(); + if null_count == s_len { + // All nulls is all equal + return Ok(true); + } + // Check if nulls are in the right location. + if null_count > 0 { + // The slice triggers a fast null count + if options.nulls_last { + if s.slice((s_len - null_count) as i64, null_count) + .null_count() + != null_count + { + return Ok(false); + } + } else if s.slice(0, null_count).null_count() != null_count { + return Ok(false); + } + } + + if s.dtype().is_primitive_numeric() { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + return Ok(is_sorted_ca_num::<$T>(ca, options)) + }) + } + + let cmp_len = s_len - null_count - 1; // Number of comparisons we might have to do + // TODO! Change this, allocation of a full boolean series is too expensive and doesn't fail fast. + // Compare adjacent elements with no-copy slices that don't include any nulls + let offset = !options.nulls_last as i64 * null_count as i64; + let (s1, s2) = (s.slice(offset, cmp_len), s.slice(offset + 1, cmp_len)); + let cmp_op = if options.descending { + Series::gt_eq + } else { + Series::lt_eq + }; + Ok(cmp_op(&s1, &s2)?.all()) + } +} + +fn check_cmp bool>( + vals: &[T], + f: Cmp, + previous: &mut T, +) -> bool { + let mut sorted = true; + + // Outer loop so we can fail fast + // Inner loop will auto vectorize + for c in vals.chunks(1024) { + // don't early stop or branch + // so it autovectorizes + for v in c { + sorted &= f(previous, v); + *previous = *v; + } + if !sorted { + return false; + } + } + sorted +} + +// Assumes nulls last/first is already checked. +fn is_sorted_ca_num(ca: &ChunkedArray, options: SortOptions) -> bool { + if let Ok(vals) = ca.cont_slice() { + let mut previous = vals[0]; + return if options.descending { + check_cmp(vals, |prev, c| prev.tot_ge(c), &mut previous) + } else { + check_cmp(vals, |prev, c| prev.tot_le(c), &mut previous) + }; + }; + + if ca.null_count() == 0 { + let mut previous = if options.descending { + T::Native::max_value() + } else { + T::Native::min_value() + }; + for arr in ca.downcast_iter() { + let vals = arr.values(); + + let sorted = if options.descending { + check_cmp(vals, |prev, c| prev.tot_ge(c), &mut previous) + } else { + check_cmp(vals, |prev, c| prev.tot_le(c), &mut previous) + }; + if !sorted { + return false; + } + } + return true; + }; + + // Slice off nulls and recurse. + let null_count = ca.null_count(); + if options.nulls_last { + let ca = ca.slice(0, ca.len() - null_count); + is_sorted_ca_num(&ca, options) + } else { + let ca = ca.slice(null_count as i64, ca.len() - null_count); + is_sorted_ca_num(&ca, options) + } +} + +impl SeriesMethods for Series {} diff --git a/crates/polars-parquet/Cargo.toml b/crates/polars-parquet/Cargo.toml new file mode 100644 index 000000000000..0414db12e8e3 --- /dev/null +++ b/crates/polars-parquet/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "polars-parquet" +version = { workspace = true } +authors = [ + "Jorge C. Leitao ", + "Apache Arrow ", + "Ritchie Vink ", +] +edition = { workspace = true } +homepage = { workspace = true } +license = "MIT AND Apache-2.0" +repository = { workspace = true } +description = "Apache Parquet I/O operations for Polars" + +[dependencies] +arrow = { workspace = true, features = ["io_ipc"] } +base64 = { workspace = true } +bytemuck = { workspace = true } +ethnum = { workspace = true } +fallible-streaming-iterator = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +hashbrown = { workspace = true } +num-traits = { workspace = true } +polars-compute = { workspace = true, features = ["approx_unique", "cast"] } +polars-error = { workspace = true } +polars-parquet-format = "0.1" +polars-utils = { workspace = true, features = ["mmap"] } +simdutf8 = { workspace = true } + +streaming-decompression = "0.1" + +async-stream = { version = "0.3.3", optional = true } + +brotli = { version = "^7.0", optional = true } +flate2 = { workspace = true, optional = true } +lz4 = { version = "1.24", optional = true } +lz4_flex = { version = "0.11", optional = true } +serde = { workspace = true, optional = true } +snap = { version = "^1.1", optional = true } +zstd = { workspace = true, optional = true } + +xxhash-rust = { version = "0.8", optional = true, features = ["xxh64"] } + +[dev-dependencies] +rand = "0.8" + +[features] +compression = [ + "brotli", + "gzip", + "lz4", + "snappy", + "zstd", +] + +# compression backends +snappy = ["snap"] +gzip = ["flate2/zlib-rs"] +lz4 = ["dep:lz4"] +lz4_flex = ["dep:lz4_flex"] + +async = ["async-stream", "futures", "polars-parquet-format/async"] +bloom_filter = ["xxhash-rust"] +serde_types = ["serde"] +simd = ["polars-compute/simd"] diff --git a/crates/polars-parquet/LICENSE b/crates/polars-parquet/LICENSE new file mode 100644 index 000000000000..7fd76611dd29 --- /dev/null +++ b/crates/polars-parquet/LICENSE @@ -0,0 +1,196 @@ +Some of the code in this crate is subject to the Apache 2 license below, as it +was taken from the arrow2 and parquet2 Rust crate in October 2023. Later changes are subject +to the MIT license in ../../LICENSE. + + + + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2020-2022 Jorge C. Leitão + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/crates/polars-parquet/src/arrow/mod.rs b/crates/polars-parquet/src/arrow/mod.rs new file mode 100644 index 000000000000..402dffe7ace6 --- /dev/null +++ b/crates/polars-parquet/src/arrow/mod.rs @@ -0,0 +1,8 @@ +pub mod read; +pub mod write; + +#[cfg(feature = "bloom_filter")] +#[cfg_attr(docsrs, doc(cfg(feature = "bloom_filter")))] +pub use crate::parquet::bloom_filter; + +const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; diff --git a/crates/polars-parquet/src/arrow/read/README.md b/crates/polars-parquet/src/arrow/read/README.md new file mode 100644 index 000000000000..0d44ffca233a --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/README.md @@ -0,0 +1,35 @@ +## Observations + +### LSB equivalence between definition levels and bitmaps + +When the maximum repetition level is 0 and the maximum definition level is 1, the RLE-encoded +definition levels correspond exactly to Arrow's bitmap and can be memcopied without further +transformations. + +## Nested parquet groups are deserialized recursively + +Reading a parquet nested field is done by reading each primitive column sequentially, and build the +nested struct recursively. + +Rows of nested parquet groups are encoded in the repetition and definition levels. In arrow, they +correspond to: + +- list's offsets and validity +- struct's validity + +The implementation in this module leverages this observation: + +Nested parquet fields are initially recursed over to gather whether the type is a Struct or List, +and whether it is required or optional, which we store in `nested_info: Vec>`. +`Nested` is a trait object that receives definition and repetition levels depending on the type and +nullability of the nested item. We process the definition and repetition levels into `nested_info`. + +When we finish a field, we recursively pop from `nested_info` as we build the `StructArray` or +`ListArray`. + +With this approach, the only difference vs flat is: + +1. we do not leverage the bitmap optimization, and instead need to deserialize the repetition and + definition levels to `i32`. +2. we deserialize definition levels twice, once to extend the values/nullability and one to extend + `nested_info`. diff --git a/crates/polars-parquet/src/arrow/read/deserialize/README.md b/crates/polars-parquet/src/arrow/read/deserialize/README.md new file mode 100644 index 000000000000..aab9c280112d --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/README.md @@ -0,0 +1,73 @@ +# Design + +## Non-nested types + +Let's start with the design used for non-nested arrays. The (private) entry point of this module for +non-nested arrays is `simple::page_iter_to_arrays`. + +This function expects + +- a (fallible) streaming iterator of decompressed and encoded pages, `Pages` +- the source (parquet) column type, including its logical information +- the target (arrow) `DataType` +- the chunk size + +and returns an iterator of `Array`, `ArrayIter`. + +This design is shared among _all_ `(parquet, arrow)` implemented tuples. Their main difference is +how they are deserialized, which depends on the source and target types. + +When the array iterator is pulled the first time, the following happens: + +- a page from `Pages` is pulled +- a `PageState<'a>` is built from the page +- the `PageState` is consumed into a mutable array: + - if `chunk_size` is larger than the number of rows in the page, the mutable array state is + preserved and a new page is pulled and the process repeated until we fill a chunk. + - if `chunk_size` is smaller than the number of rows in the page, the mutable array state is + returned and the remaining of the page is consumed into multiple mutable arrays of length + `chunk_size` into a FIFO queue. + +Subsequent pulls of arrays will first try to pull from the FIFO queue. Once the queue is empty, the +a new page is pulled. + +### `PageState` + +As mentioned above, the iterator leverages the idea that we attach a state to a page. Recall that a +page is essentially `[header][data]`. The `data` part contains encoded +`[rep levels][def levels][non-null values]`. Some pages have an associated dictionary page, in which +case the `non-null values` represent the indices. + +Irrespectively of the physical type, the main idea is to split the page in two iterators: + +- An iterator over `def levels` +- An iterator over `non-null values` + +and progress the iterators as needed. In particular, for non-nested types, `def levels` is a bitmap +with the same representation as Arrow, in which case the validity is extended directly. + +The `non-null values` are "expanded" by filling null values with the default value of each physical +type. + +## Nested types + +For nested type with N+1 levels (1 is the primitive), we need to build the nest information of each +N levels + the non-nested Arrow array. + +This is done by first transversing the parquet types and using it to initialize, per chunk, the N +levels. + +The per-chunk execution is then similar but `chunk_size` only drives the number of retrieved rows +from the outermost parquet group (the field). Each of these pulls knows how many items need to be +pulled from the inner groups, all the way to the primitive type. This works because in parquet a row +cannot be split between two pages and thus each page is guaranteed to contain a full row. + +The `PageState` of nested types is composed by 4 iterators: + +- A (zipped) iterator over `rep levels` and `def levels` +- An iterator over `def levels` +- An iterator over `non-null values` + +The idea is that an iterator of `rep, def` contain all the information to decode the nesting +structure of an arrow array. The other two iterators are equivalent to the non-nested types with the +exception that `def levels` are no equivalent to arrow bitmaps. diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs new file mode 100644 index 000000000000..84bd6fe34d94 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs @@ -0,0 +1,686 @@ +use arrow::array::{Array, BinaryViewArray, MutableBinaryViewArray, Utf8ViewArray, View}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::{ArrowDataType, PhysicalType}; + +use super::dictionary_encoded::{append_validity, constrain_page_validity}; +use super::utils::{ + dict_indices_decoder, filter_from_range, freeze_validity, unspecialized_decode, +}; +use super::{Filter, PredicateFilter, dictionary_encoded}; +use crate::parquet::encoding::{Encoding, delta_byte_array, delta_length_byte_array, hybrid_rle}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::page::{DataPage, DictPage, split_buffer}; +use crate::read::deserialize::utils::{self}; + +mod optional; +mod optional_masked; +mod predicate; +mod required; +mod required_masked; + +type DecodedStateTuple = (MutableBinaryViewArray<[u8]>, BitmapBuilder); + +impl<'a> utils::StateTranslation<'a, BinViewDecoder> for StateTranslation<'a> { + type PlainDecoder = BinaryIter<'a>; + + fn new( + _decoder: &BinViewDecoder, + page: &'a DataPage, + dict: Option<&'a ::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult { + match (page.encoding(), dict) { + (Encoding::Plain, _) => { + let values = split_buffer(page)?.values; + let values = BinaryIter::new(values, page.num_values()); + + Ok(Self::Plain(values)) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = + dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))?; + Ok(Self::Dictionary(values)) + }, + (Encoding::DeltaLengthByteArray, _) => { + let values = split_buffer(page)?.values; + Ok(Self::DeltaLengthByteArray( + delta_length_byte_array::Decoder::try_new(values)?, + Vec::new(), + )) + }, + (Encoding::DeltaByteArray, _) => { + let values = split_buffer(page)?.values; + Ok(Self::DeltaBytes(delta_byte_array::Decoder::try_new( + values, + )?)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn num_rows(&self) -> usize { + match self { + StateTranslation::Plain(i) => i.max_num_values, + StateTranslation::Dictionary(i) => i.len(), + StateTranslation::DeltaLengthByteArray(i, _) => i.len(), + StateTranslation::DeltaBytes(i) => i.len(), + } + } +} + +pub(crate) struct BinViewDecoder { + pub is_string: bool, +} + +impl BinViewDecoder { + pub fn new_string() -> Self { + Self { is_string: true } + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum StateTranslation<'a> { + Plain(BinaryIter<'a>), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), + DeltaLengthByteArray(delta_length_byte_array::Decoder<'a>, Vec), + DeltaBytes(delta_byte_array::Decoder<'a>), +} + +impl utils::Decoded for DecodedStateTuple { + fn len(&self) -> usize { + self.0.len() + } + + fn extend_nulls(&mut self, n: usize) { + self.0.extend_constant(n, Some(&[])); + self.1.extend_constant(n, false); + } +} + +#[allow(clippy::too_many_arguments)] +pub fn decode_plain( + values: &[u8], + max_num_values: usize, + target: &mut MutableBinaryViewArray<[u8]>, + + is_optional: bool, + validity: &mut BitmapBuilder, + + page_validity: Option<&Bitmap>, + filter: Option, + + pred_true_mask: &mut BitmapBuilder, + + verify_utf8: bool, +) -> ParquetResult<()> { + if is_optional { + append_validity(page_validity, filter.as_ref(), validity, max_num_values); + } + let page_validity = constrain_page_validity(max_num_values, page_validity, filter.as_ref()); + + match (filter, page_validity) { + (None, None) => required::decode(max_num_values, values, None, target, verify_utf8), + (Some(Filter::Range(rng)), None) if rng.start == 0 => { + required::decode(max_num_values, values, Some(rng.end), target, verify_utf8) + }, + (None, Some(page_validity)) => optional::decode( + page_validity.set_bits(), + values, + target, + &page_validity, + verify_utf8, + ), + (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => optional::decode( + page_validity.set_bits(), + values, + target, + &page_validity, + verify_utf8, + ), + (Some(Filter::Mask(mask)), None) => { + required_masked::decode(max_num_values, values, target, &mask, verify_utf8) + }, + (Some(Filter::Mask(mask)), Some(page_validity)) => optional_masked::decode( + page_validity.set_bits(), + values, + target, + &page_validity, + &mask, + verify_utf8, + ), + (Some(Filter::Range(rng)), None) => required_masked::decode( + max_num_values, + values, + target, + &filter_from_range(rng.clone()), + verify_utf8, + ), + (Some(Filter::Range(rng)), Some(page_validity)) => optional_masked::decode( + page_validity.set_bits(), + values, + target, + &page_validity, + &filter_from_range(rng.clone()), + verify_utf8, + ), + (Some(Filter::Predicate(p)), page_validity) => { + let Some(needle) = p.predicate.to_equals_scalar() else { + unreachable!(); + }; + + if needle.is_null() || page_validity.is_some() { + todo!(); + } + + let needle = if verify_utf8 { + needle.as_str().unwrap().as_bytes() + } else { + needle.as_binary().unwrap() + }; + + let start_pred_true_num = pred_true_mask.set_bits(); + predicate::decode_equals(max_num_values, values, needle, pred_true_mask)?; + + if p.include_values { + let pred_true_num = pred_true_mask.set_bits() - start_pred_true_num; + + if pred_true_num > 0 { + let new_target_len = target.len() + pred_true_num; + let new_total_bytes_len = + target.total_bytes_len() + pred_true_num * needle.len(); + + target.push_value(needle); + let view = *target.views().last().unwrap(); + + // SAFETY: We know that the view is valid since we added it safely and we + // update the total_bytes_len afterwards. The total_buffer_len is not affected. + unsafe { + target.views_mut().resize(new_target_len, view); + target.set_total_bytes_len(new_total_bytes_len); + } + } + } + + Ok(()) + }, + }?; + + Ok(()) +} + +#[cold] +fn invalid_input_err() -> ParquetError { + ParquetError::oos("String data does not match given length") +} + +#[cold] +fn invalid_utf8_err() -> ParquetError { + ParquetError::oos("String data contained invalid UTF-8") +} + +pub fn decode_plain_generic( + values: &[u8], + target: &mut MutableBinaryViewArray<[u8]>, + + num_rows: usize, + mut next: impl FnMut() -> Option<(bool, bool)>, + + verify_utf8: bool, +) -> ParquetResult<()> { + // Since the offset in the buffer is decided by the interleaved lengths, every value has to be + // walked no matter what. This makes decoding rather inefficient in general. + // + // There are three cases: + // 1. All inlinable values + // - Most time is spend in decoding + // - No additional buffer has to be formed + // - Possible UTF-8 verification is fast because the len_below_128 trick + // 2. All non-inlinable values + // - Little time is spend in decoding + // - Most time is spend in buffer memcopying (we remove the interleaved lengths) + // - Possible UTF-8 verification is fast because the continuation byte trick + // 3. Mixed inlinable and non-inlinable values + // - Time shared between decoding and buffer forming + // - UTF-8 verification might still use len_below_128 trick, but might need to fall back to + // slow path. + + target.finish_in_progress(); + unsafe { target.views_mut() }.reserve(num_rows); + + let start_target_length = target.len(); + + let buffer_idx = target.completed_buffers().len() as u32; + let mut buffer = Vec::with_capacity(values.len() + 1); + let mut none_starting_with_continuation_byte = true; // Whether the transition from between strings is valid + // UTF-8 + let mut all_len_below_128 = true; // Whether all the lengths of the values are below 128, this + // allows us to make UTF-8 verification a lot faster. + + let mut total_bytes_len = 0; + let mut num_seen = 0; + let mut num_inlined = 0; + + let mut mvalues = values; + while let Some((is_valid, is_selected)) = next() { + if !is_valid { + if is_selected { + unsafe { target.views_mut() }.push(unsafe { View::new_inline_unchecked(&[]) }); + } + continue; + } + + if mvalues.len() < 4 { + return Err(invalid_input_err()); + } + + let length; + (length, mvalues) = mvalues.split_at(4); + let length: &[u8; 4] = unsafe { length.try_into().unwrap_unchecked() }; + let length = u32::from_le_bytes(*length); + + if mvalues.len() < length as usize { + return Err(invalid_input_err()); + } + + let value; + (value, mvalues) = mvalues.split_at(length as usize); + + num_seen += 1; + all_len_below_128 &= value.len() < 128; + // Everything starting with 10.. .... is a continuation byte. + none_starting_with_continuation_byte &= + value.is_empty() || value[0] & 0b1100_0000 != 0b1000_0000; + + if !is_selected { + continue; + } + + let offset = buffer.len() as u32; + + if value.len() <= View::MAX_INLINE_SIZE as usize { + unsafe { target.views_mut() }.push(unsafe { View::new_inline_unchecked(value) }); + num_inlined += 1; + } else { + buffer.extend_from_slice(value); + unsafe { target.views_mut() } + .push(unsafe { View::new_noninline_unchecked(value, buffer_idx, offset) }); + } + + total_bytes_len += value.len(); + } + + unsafe { + target.set_total_bytes_len(target.total_bytes_len() + total_bytes_len); + } + + if verify_utf8 { + // This is a trick that allows us to check the resulting buffer which allows to batch the + // UTF-8 verification. + // + // This is allowed if none of the strings start with a UTF-8 continuation byte, so we keep + // track of that during the decoding. + if num_inlined == 0 { + if !none_starting_with_continuation_byte || simdutf8::basic::from_utf8(&buffer).is_err() + { + return Err(invalid_utf8_err()); + } + + // This is a small trick that allows us to check the Parquet buffer instead of the view + // buffer. Batching the UTF-8 verification is more performant. For this to be allowed, + // all the interleaved lengths need to be valid UTF-8. + // + // Every strings prepended by 4 bytes (L, 0, 0, 0), since we check here L < 128. L is + // only a valid first byte of a UTF-8 code-point and (L, 0, 0, 0) is valid UTF-8. + // Consequently, it is valid to just check the whole buffer. + } else if all_len_below_128 { + if simdutf8::basic::from_utf8(&values[..values.len() - mvalues.len()]).is_err() { + return Err(invalid_utf8_err()); + } + } else { + // We check all the non-inlined values here. + if !none_starting_with_continuation_byte || simdutf8::basic::from_utf8(&buffer).is_err() + { + return Err(invalid_utf8_err()); + } + + let mut all_inlined_are_ascii = true; + + // @NOTE: This is only valid because we initialize our inline View's to be zeroes on + // non-included bytes. + for view in &target.views()[start_target_length..] { + all_inlined_are_ascii &= (view.length > View::MAX_INLINE_SIZE) + | (view.as_u128() & 0x0000_0000_8080_8080_8080_8080_8080_8080 == 0); + } + + // This is the very slow path. + if !all_inlined_are_ascii { + let mut is_valid = true; + for view in &target.views()[target.len() - num_seen..] { + if view.length <= View::MAX_INLINE_SIZE { + is_valid &= + std::str::from_utf8(unsafe { view.get_inlined_slice_unchecked() }) + .is_ok(); + } + } + + if !is_valid { + return Err(invalid_utf8_err()); + } + } + } + } + + target.push_buffer(buffer.into()); + + Ok(()) +} + +impl utils::Decoder for BinViewDecoder { + type Translation<'a> = StateTranslation<'a>; + type Dict = BinaryViewArray; + type DecodedState = DecodedStateTuple; + type Output = Box; + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + MutableBinaryViewArray::with_capacity(capacity), + BitmapBuilder::with_capacity(capacity), + ) + } + + fn evaluate_dict_predicate( + &self, + dict: &Self::Dict, + predicate: &PredicateFilter, + ) -> ParquetResult { + let utf8_array; + let mut dict_arr = dict as &dyn Array; + + if self.is_string { + utf8_array = unsafe { dict.to_utf8view_unchecked() }; + dict_arr = &utf8_array + } + + Ok(predicate.predicate.evaluate(dict_arr)) + } + + fn has_predicate_specialization( + &self, + state: &utils::State<'_, Self>, + predicate: &PredicateFilter, + ) -> ParquetResult { + let mut has_predicate_specialization = false; + + has_predicate_specialization |= + matches!(state.translation, StateTranslation::Dictionary(_)); + has_predicate_specialization |= matches!(state.translation, StateTranslation::Plain(_)) + && predicate.predicate.to_equals_scalar().is_some(); + + // @TODO: This should be implemented + has_predicate_specialization &= state.page_validity.is_none(); + + Ok(has_predicate_specialization) + } + + fn apply_dictionary( + &mut self, + (values, _): &mut Self::DecodedState, + dict: &Self::Dict, + ) -> ParquetResult<()> { + if values.completed_buffers().len() < dict.data_buffers().len() { + for buffer in dict.data_buffers().as_ref() { + values.push_buffer(buffer.clone()); + } + } + + assert!(values.completed_buffers().len() == dict.data_buffers().len()); + + Ok(()) + } + + fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult { + let values = &page.buffer; + let num_values = page.num_values; + + let mut arr = MutableBinaryViewArray::new(); + required::decode(num_values, values, None, &mut arr, self.is_string)?; + + Ok(arr.freeze()) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn Array, + is_optional: bool, + ) -> ParquetResult<()> { + let is_utf8 = self.is_string; + if is_utf8 { + let array = additional.as_any().downcast_ref::().unwrap(); + let mut array = array.to_binview(); + + if let Some(validity) = array.take_validity() { + decoded.0.extend_from_array(&array); + decoded.1.extend_from_bitmap(&validity); + } else { + decoded.0.extend_from_array(&array); + if is_optional { + decoded.1.extend_constant(array.len(), true); + } + } + } else { + let array = additional + .as_any() + .downcast_ref::() + .unwrap(); + let mut array = array.clone(); + + if let Some(validity) = array.take_validity() { + decoded.0.extend_from_array(&array); + decoded.1.extend_from_bitmap(&validity); + } else { + decoded.0.extend_from_array(&array); + if is_optional { + decoded.1.extend_constant(array.len(), true); + } + } + } + + Ok(()) + } + + fn extend_filtered_with_state( + &mut self, + mut state: utils::State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + match state.translation { + StateTranslation::Plain(iter) => decode_plain( + iter.values, + iter.max_num_values, + &mut decoded.0, + state.is_optional, + &mut decoded.1, + state.page_validity.as_ref(), + filter, + pred_true_mask, + self.is_string, + ), + StateTranslation::Dictionary(ref mut indexes) => { + let dict = state.dict.unwrap(); + + let start_length = decoded.0.views().len(); + + dictionary_encoded::decode_dict( + indexes.clone(), + dict.views().as_slice(), + state.dict_mask, + state.is_optional, + state.page_validity.as_ref(), + filter, + &mut decoded.1, + unsafe { decoded.0.views_mut() }, + pred_true_mask, + )?; + + let total_length: usize = decoded + .0 + .views() + .iter() + .skip(start_length) + .map(|view| view.length as usize) + .sum(); + unsafe { + decoded + .0 + .set_total_bytes_len(decoded.0.total_bytes_len() + total_length); + } + + Ok(()) + }, + StateTranslation::DeltaLengthByteArray(decoder, _vec) => { + let values = decoder.values; + let lengths = decoder.lengths.collect::>()?; + + if self.is_string { + let mut none_starting_with_continuation_byte = true; + let mut offset = 0; + for length in &lengths { + none_starting_with_continuation_byte &= + *length == 0 || values[offset] & 0xC0 != 0x80; + offset += *length as usize; + } + + if !none_starting_with_continuation_byte { + return Err(invalid_utf8_err()); + } + + if simdutf8::basic::from_utf8(&values[..offset]).is_err() { + return Err(invalid_utf8_err()); + } + } + + let mut i = 0; + let mut offset = 0; + unspecialized_decode( + lengths.len(), + || { + let length = lengths[i] as usize; + + let value = &values[offset..offset + length]; + + i += 1; + offset += length; + + Ok(value) + }, + filter, + state.page_validity, + state.is_optional, + &mut decoded.1, + &mut decoded.0, + ) + }, + StateTranslation::DeltaBytes(mut decoder) => { + let check_utf8 = self.is_string; + + unspecialized_decode( + decoder.len(), + || { + let value = decoder.next().unwrap()?; + + if check_utf8 && simdutf8::basic::from_utf8(&value[..]).is_err() { + return Err(invalid_utf8_err()); + } + + Ok(value) + }, + filter, + state.page_validity, + state.is_optional, + &mut decoded.1, + &mut decoded.0, + ) + }, + } + } + + fn finalize( + &self, + dtype: ArrowDataType, + _dict: Option, + (values, validity): Self::DecodedState, + ) -> ParquetResult> { + let mut array: BinaryViewArray = values.freeze(); + + let validity = freeze_validity(validity); + array = array.with_validity(validity); + + match dtype.to_physical_type() { + PhysicalType::BinaryView => Ok(array.boxed()), + PhysicalType::Utf8View => { + // SAFETY: we already checked utf8 + unsafe { + Ok(Utf8ViewArray::new_unchecked( + dtype, + array.views().clone(), + array.data_buffers().clone(), + array.validity().cloned(), + array.total_bytes_len(), + array.total_buffer_len(), + ) + .boxed()) + } + }, + _ => unreachable!(), + } + } +} + +#[derive(Debug)] +pub struct BinaryIter<'a> { + values: &'a [u8], + + /// A maximum number of items that this [`BinaryIter`] may produce. + /// + /// This equal the length of the iterator i.f.f. the data encoded by the [`BinaryIter`] is not + /// nullable. + max_num_values: usize, +} + +impl<'a> BinaryIter<'a> { + pub fn new(values: &'a [u8], max_num_values: usize) -> Self { + Self { + values, + max_num_values, + } + } +} + +impl<'a> Iterator for BinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + if self.max_num_values == 0 { + assert!(self.values.is_empty()); + return None; + } + + let (length, remaining) = self.values.split_at(4); + let length: [u8; 4] = unsafe { length.try_into().unwrap_unchecked() }; + let length = u32::from_le_bytes(length) as usize; + let (result, remaining) = remaining.split_at(length); + self.max_num_values -= 1; + self.values = remaining; + Some(result) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.max_num_values)) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/optional.rs new file mode 100644 index 000000000000..b51e0e799f46 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/optional.rs @@ -0,0 +1,33 @@ +use arrow::array::MutableBinaryViewArray; +use arrow::bitmap::Bitmap; + +use super::decode_plain_generic; +use crate::parquet::error::ParquetResult; + +pub fn decode( + num_expected_values: usize, + values: &[u8], + target: &mut MutableBinaryViewArray<[u8]>, + page_validity: &Bitmap, + + verify_utf8: bool, +) -> ParquetResult<()> { + if page_validity.unset_bits() == 0 { + return super::required::decode( + num_expected_values, + values, + Some(page_validity.len()), + target, + verify_utf8, + ); + } + + let mut validity_iter = page_validity.iter(); + decode_plain_generic( + values, + target, + page_validity.len(), + || Some((validity_iter.next()?, true)), + verify_utf8, + ) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/optional_masked.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/optional_masked.rs new file mode 100644 index 000000000000..2b1df91559a1 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/optional_masked.rs @@ -0,0 +1,47 @@ +use arrow::array::MutableBinaryViewArray; +use arrow::bitmap::Bitmap; + +use crate::parquet::error::ParquetResult; +use crate::read::deserialize::binview::decode_plain_generic; + +pub fn decode( + num_expected_values: usize, + values: &[u8], + target: &mut MutableBinaryViewArray<[u8]>, + + page_validity: &Bitmap, + mask: &Bitmap, + + verify_utf8: bool, +) -> ParquetResult<()> { + assert_eq!(page_validity.len(), mask.len()); + + if mask.unset_bits() == 0 { + return super::optional::decode( + num_expected_values, + values, + target, + page_validity, + verify_utf8, + ); + } + if page_validity.unset_bits() == 0 { + return super::required_masked::decode( + num_expected_values, + values, + target, + mask, + verify_utf8, + ); + } + + let mut validity_iter = page_validity.iter(); + let mut mask_iter = mask.iter(); + decode_plain_generic( + values, + target, + mask.set_bits(), + || Some((validity_iter.next()?, mask_iter.next()?)), + verify_utf8, + ) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/predicate.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/predicate.rs new file mode 100644 index 000000000000..aec5bbd11ffb --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/predicate.rs @@ -0,0 +1,88 @@ +//! Specialized kernels to do predicate evaluation directly on the `BinView` Parquet data. + +use arrow::array::View; +use arrow::bitmap::BitmapBuilder; + +use crate::parquet::error::ParquetResult; + +/// Create a mask for when a value is equal to the `needle`. +pub fn decode_equals( + num_expected_values: usize, + values: &[u8], + needle: &[u8], + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + if needle.len() <= View::MAX_INLINE_SIZE as usize { + decode_equals_inlinable(num_expected_values, values, needle, pred_true_mask) + } else { + decode_equals_non_inlineable(num_expected_values, values, needle, pred_true_mask) + } +} + +/// Equality kernel for when the `needle` is inlineable into the `View`. +fn decode_equals_inlinable( + num_expected_values: usize, + mut values: &[u8], + needle: &[u8], + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + let needle = View::new_inline(needle); + + pred_true_mask.reserve(num_expected_values); + for _ in 0..num_expected_values { + if values.len() < 4 { + return Err(super::invalid_input_err()); + } + + let length; + (length, values) = values.split_at(4); + let length: &[u8; 4] = unsafe { length.try_into().unwrap_unchecked() }; + let length = u32::from_le_bytes(*length); + + if values.len() < length as usize { + return Err(super::invalid_input_err()); + } + + let value; + (value, values) = values.split_at(length as usize); + let view = View::new_from_bytes(value, 0, 0); + + // SAFETY: We reserved enough just before the loop. + unsafe { pred_true_mask.push_unchecked(needle == view) }; + } + + Ok(()) +} + +/// Equality kernel for when the `needle` is not-inlineable into the `View`. +fn decode_equals_non_inlineable( + num_expected_values: usize, + mut values: &[u8], + needle: &[u8], + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + pred_true_mask.reserve(num_expected_values); + for _ in 0..num_expected_values { + if values.len() < 4 { + return Err(super::invalid_input_err()); + } + + let length; + (length, values) = values.split_at(4); + let length: &[u8; 4] = unsafe { length.try_into().unwrap_unchecked() }; + let length = u32::from_le_bytes(*length); + + if values.len() < length as usize { + return Err(super::invalid_input_err()); + } + + let value; + (value, values) = values.split_at(length as usize); + + let is_pred_true = length as usize == needle.len() && value == needle; + // SAFETY: We reserved enough just before the loop. + unsafe { pred_true_mask.push_unchecked(is_pred_true) }; + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/required.rs new file mode 100644 index 000000000000..ca55ced9e351 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/required.rs @@ -0,0 +1,32 @@ +use arrow::array::MutableBinaryViewArray; + +use super::decode_plain_generic; +use crate::parquet::error::ParquetResult; + +pub fn decode( + num_expected_values: usize, + values: &[u8], + limit: Option, + target: &mut MutableBinaryViewArray<[u8]>, + + verify_utf8: bool, +) -> ParquetResult<()> { + let limit = limit.unwrap_or(num_expected_values); + + let mut idx = 0; + decode_plain_generic( + values, + target, + limit, + || { + if idx >= limit { + return None; + } + + idx += 1; + + Some((true, true)) + }, + verify_utf8, + ) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/required_masked.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/required_masked.rs new file mode 100644 index 000000000000..94613ea4578a --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/required_masked.rs @@ -0,0 +1,34 @@ +use arrow::array::MutableBinaryViewArray; +use arrow::bitmap::Bitmap; + +use super::decode_plain_generic; +use crate::parquet::error::ParquetResult; + +pub fn decode( + num_expected_values: usize, + values: &[u8], + target: &mut MutableBinaryViewArray<[u8]>, + + mask: &Bitmap, + + verify_utf8: bool, +) -> ParquetResult<()> { + if mask.unset_bits() == 0 { + return super::required::decode( + num_expected_values, + values, + Some(mask.len()), + target, + verify_utf8, + ); + } + + let mut mask_iter = mask.iter(); + decode_plain_generic( + values, + target, + mask.set_bits(), + || Some((true, mask_iter.next()?)), + verify_utf8, + ) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs new file mode 100644 index 000000000000..9e01c72bacff --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs @@ -0,0 +1,456 @@ +use arrow::array::{BooleanArray, Splitable}; +use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::utils::BitmapIter; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::ArrowDataType; +use polars_compute::filter::filter_boolean_kernel; + +use super::dictionary_encoded::{append_validity, constrain_page_validity}; +use super::utils::{ + self, Decoded, Decoder, decode_hybrid_rle_into_bitmap, filter_from_range, freeze_validity, +}; +use super::{Filter, PredicateFilter}; +use crate::parquet::encoding::Encoding; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; +use crate::parquet::page::{DataPage, DictPage, split_buffer}; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum StateTranslation<'a> { + Plain(BitMask<'a>), + Rle(HybridRleDecoder<'a>), +} + +impl<'a> utils::StateTranslation<'a, BooleanDecoder> for StateTranslation<'a> { + type PlainDecoder = BitmapIter<'a>; + + fn new( + _decoder: &BooleanDecoder, + page: &'a DataPage, + _dict: Option<&'a ::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult { + let values = split_buffer(page)?.values; + + match page.encoding() { + Encoding::Plain => { + let max_num_values = values.len() * u8::BITS as usize; + let num_values = if page_validity.is_some() { + // @NOTE: We overestimate the amount of values here, but in the V1 + // specification we don't really have a way to know the number of valid items. + // Without traversing the list. + max_num_values + } else { + // @NOTE: We cannot really trust the value from this as it might relate to the + // number of top-level nested values. Therefore, we do a `min` with the maximum + // number of possible values. + usize::min(page.num_values(), max_num_values) + }; + + Ok(Self::Plain(BitMask::new(values, 0, num_values))) + }, + Encoding::Rle => { + // @NOTE: For a nullable list, we might very well overestimate the amount of + // values, but we never collect those items. We don't really have a way to know the + // number of valid items in the V1 specification. + + // For RLE boolean values the length in bytes is pre-pended. + // https://github.com/apache/parquet-format/blob/e517ac4dbe08d518eb5c2e58576d4c711973db94/Encodings.md#run-length-encoding--bit-packing-hybrid-rle--3 + let (_len_in_bytes, values) = values.split_at(4); + Ok(Self::Rle(HybridRleDecoder::new( + values, + 1, + page.num_values(), + ))) + }, + _ => Err(utils::not_implemented(page)), + } + } + fn num_rows(&self) -> usize { + match self { + Self::Plain(m) => m.len(), + Self::Rle(m) => m.len(), + } + } +} + +fn decode_required_rle( + values: HybridRleDecoder<'_>, + limit: Option, + target: &mut BitmapBuilder, +) -> ParquetResult<()> { + decode_hybrid_rle_into_bitmap(values, limit, target)?; + Ok(()) +} + +fn decode_optional_rle( + values: HybridRleDecoder<'_>, + target: &mut BitmapBuilder, + page_validity: &Bitmap, +) -> ParquetResult<()> { + debug_assert!(page_validity.set_bits() <= values.len()); + + if page_validity.unset_bits() == 0 { + return decode_required_rle(values, Some(page_validity.len()), target); + } + + target.reserve(page_validity.len()); + + let mut validity_mask = BitMask::from_bitmap(page_validity); + + for chunk in values.into_chunk_iter() { + let chunk = chunk?; + + match chunk { + HybridRleChunk::Rle(value, size) => { + let offset = validity_mask + .nth_set_bit_idx(size, 0) + .unwrap_or(validity_mask.len()); + + let t; + (t, validity_mask) = validity_mask.split_at(offset); + + target.extend_constant(t.len(), value != 0); + }, + HybridRleChunk::Bitpacked(decoder) => { + let decoder_slice = decoder.as_slice(); + let offset = validity_mask + .nth_set_bit_idx(decoder.len(), 0) + .unwrap_or(validity_mask.len()); + + let decoder_validity; + (decoder_validity, validity_mask) = validity_mask.split_at(offset); + + let mut offset = 0; + let mut validity_iter = decoder_validity.iter(); + while validity_iter.num_remaining() > 0 { + let num_valid = validity_iter.take_leading_ones(); + target.extend_from_slice(decoder_slice, offset, num_valid); + offset += num_valid; + + let num_invalid = validity_iter.take_leading_zeros(); + target.extend_constant(num_invalid, false); + } + }, + } + } + + if cfg!(debug_assertions) { + assert_eq!(validity_mask.set_bits(), 0); + } + target.extend_constant(validity_mask.len(), false); + + Ok(()) +} + +fn decode_masked_required_rle( + values: HybridRleDecoder<'_>, + target: &mut BitmapBuilder, + mask: &Bitmap, +) -> ParquetResult<()> { + debug_assert!(mask.len() <= values.len()); + + if mask.unset_bits() == 0 { + return decode_required_rle(values, Some(mask.len()), target); + } + + let mut im_target = BitmapBuilder::new(); + decode_required_rle(values, Some(mask.len()), &mut im_target)?; + + target.extend_from_bitmap(&filter_boolean_kernel(&im_target.freeze(), mask)); + + Ok(()) +} + +fn decode_masked_optional_rle( + values: HybridRleDecoder<'_>, + target: &mut BitmapBuilder, + page_validity: &Bitmap, + mask: &Bitmap, +) -> ParquetResult<()> { + debug_assert_eq!(page_validity.len(), mask.len()); + debug_assert!(mask.len() <= values.len()); + + if mask.unset_bits() == 0 { + return decode_optional_rle(values, target, page_validity); + } + + if page_validity.unset_bits() == 0 { + return decode_masked_required_rle(values, target, mask); + } + + let mut im_target = BitmapBuilder::new(); + decode_optional_rle(values, &mut im_target, page_validity)?; + + target.extend_from_bitmap(&filter_boolean_kernel(&im_target.freeze(), mask)); + + Ok(()) +} + +fn decode_required_plain(values: BitMask<'_>, target: &mut BitmapBuilder) -> ParquetResult<()> { + target.extend_from_bitmask(values); + Ok(()) +} + +fn decode_optional_plain( + mut values: BitMask<'_>, + target: &mut BitmapBuilder, + mut page_validity: Bitmap, +) -> ParquetResult<()> { + debug_assert!(page_validity.set_bits() <= values.len()); + + if page_validity.unset_bits() == 0 { + return decode_required_plain(values.sliced(0, page_validity.len()), target); + } + + target.reserve(page_validity.len()); + + while !page_validity.is_empty() { + let num_valid = page_validity.take_leading_ones(); + let iv; + (iv, values) = values.split_at(num_valid); + target.extend_from_bitmask(iv); + + let num_invalid = page_validity.take_leading_zeros(); + target.extend_constant(num_invalid, false); + } + + Ok(()) +} + +fn decode_masked_required_plain( + mut values: BitMask, + target: &mut BitmapBuilder, + mut mask: Bitmap, +) -> ParquetResult<()> { + debug_assert!(mask.len() <= values.len()); + + let leading_zeros = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + + values = values.sliced(leading_zeros, mask.len()); + + if mask.unset_bits() == 0 { + return decode_required_plain(values, target); + } + + let mut im_target = BitmapBuilder::new(); + decode_required_plain(values, &mut im_target)?; + + target.extend_from_bitmap(&filter_boolean_kernel(&im_target.freeze(), &mask)); + + Ok(()) +} + +fn decode_masked_optional_plain( + mut values: BitMask<'_>, + target: &mut BitmapBuilder, + mut page_validity: Bitmap, + mut mask: Bitmap, +) -> ParquetResult<()> { + debug_assert_eq!(page_validity.len(), mask.len()); + debug_assert!(page_validity.set_bits() <= values.len()); + + let leading_zeros = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + + let (skipped, truncated); + (skipped, page_validity) = page_validity.split_at(leading_zeros); + (page_validity, truncated) = page_validity.split_at(mask.len()); + + let skipped_values = skipped.set_bits(); + let truncated_values = truncated.set_bits(); + values = values.sliced( + skipped_values, + values.len() - skipped_values - truncated_values, + ); + + if mask.unset_bits() == 0 { + return decode_optional_plain(values, target, page_validity); + } + + if page_validity.unset_bits() == 0 { + return decode_masked_required_plain(values, target, mask); + } + + let mut im_target = BitmapBuilder::new(); + decode_optional_plain(values, &mut im_target, page_validity)?; + + target.extend_from_bitmap(&filter_boolean_kernel(&im_target.freeze(), &mask)); + + Ok(()) +} + +impl Decoded for (BitmapBuilder, BitmapBuilder) { + fn len(&self) -> usize { + self.0.len() + } + fn extend_nulls(&mut self, n: usize) { + self.0.extend_constant(n, false); + self.1.extend_constant(n, false); + } +} + +pub(crate) struct BooleanDecoder; + +impl Decoder for BooleanDecoder { + type Translation<'a> = StateTranslation<'a>; + type Dict = BooleanArray; + type DecodedState = (BitmapBuilder, BitmapBuilder); + type Output = BooleanArray; + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + BitmapBuilder::with_capacity(capacity), + BitmapBuilder::with_capacity(capacity), + ) + } + + fn deserialize_dict(&mut self, _: DictPage) -> ParquetResult { + Ok(BooleanArray::new_empty(ArrowDataType::Boolean)) + } + + fn finalize( + &self, + dtype: ArrowDataType, + _dict: Option, + (values, validity): Self::DecodedState, + ) -> ParquetResult { + let validity = freeze_validity(validity); + Ok(BooleanArray::new(dtype, values.freeze(), validity)) + } + + fn has_predicate_specialization( + &self, + _state: &utils::State<'_, Self>, + _predicate: &PredicateFilter, + ) -> ParquetResult { + // @TODO: This can be enabled for the fast paths + Ok(false) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn arrow::array::Array, + is_optional: bool, + ) -> ParquetResult<()> { + let additional = additional.as_any().downcast_ref::().unwrap(); + decoded.0.extend_from_bitmap(additional.values()); + match additional.validity() { + Some(v) => decoded.1.extend_from_bitmap(v), + None if is_optional => decoded.1.extend_constant(additional.len(), true), + None => {}, + } + + Ok(()) + } + + fn extend_filtered_with_state( + &mut self, + state: utils::State<'_, Self>, + (target, validity): &mut Self::DecodedState, + _pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + match state.translation { + StateTranslation::Plain(values) => { + if state.is_optional { + append_validity( + state.page_validity.as_ref(), + filter.as_ref(), + validity, + values.len(), + ); + } + + let page_validity = constrain_page_validity( + values.len(), + state.page_validity.as_ref(), + filter.as_ref(), + ); + + match (filter, page_validity) { + (None, None) => decode_required_plain(values, target), + (Some(Filter::Range(rng)), None) => { + decode_required_plain(values.sliced(rng.start, rng.len()), target) + }, + (None, Some(page_validity)) => { + decode_optional_plain(values, target, page_validity) + }, + (Some(Filter::Range(rng)), Some(mut page_validity)) => { + let (skipped, truncated); + (skipped, page_validity) = page_validity.split_at(rng.start); + (page_validity, truncated) = page_validity.split_at(rng.len()); + + let skipped_values = skipped.set_bits(); + let truncated_values = truncated.set_bits(); + let values = values.sliced( + skipped_values, + values.len() - skipped_values - truncated_values, + ); + + decode_optional_plain(values, target, page_validity) + }, + (Some(Filter::Mask(mask)), None) => { + decode_masked_required_plain(values, target, mask) + }, + (Some(Filter::Mask(mask)), Some(page_validity)) => { + decode_masked_optional_plain(values, target, page_validity, mask) + }, + (Some(Filter::Predicate(_)), _) => todo!(), + }?; + + Ok(()) + }, + StateTranslation::Rle(values) => { + if state.is_optional { + append_validity( + state.page_validity.as_ref(), + filter.as_ref(), + validity, + values.len(), + ); + } + + let page_validity = constrain_page_validity( + values.len(), + state.page_validity.as_ref(), + filter.as_ref(), + ); + + match (filter, page_validity) { + (None, None) => decode_required_rle(values, None, target), + (Some(Filter::Range(rng)), None) if rng.start == 0 => { + decode_required_rle(values, Some(rng.end), target) + }, + (None, Some(page_validity)) => { + decode_optional_rle(values, target, &page_validity) + }, + (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { + decode_optional_rle(values, target, &page_validity) + }, + (Some(Filter::Mask(filter)), None) => { + decode_masked_required_rle(values, target, &filter) + }, + (Some(Filter::Mask(filter)), Some(page_validity)) => { + decode_masked_optional_rle(values, target, &page_validity, &filter) + }, + (Some(Filter::Range(rng)), None) => { + decode_masked_required_rle(values, target, &filter_from_range(rng.clone())) + }, + (Some(Filter::Range(rng)), Some(page_validity)) => decode_masked_optional_rle( + values, + target, + &page_validity, + &filter_from_range(rng.clone()), + ), + (Some(Filter::Predicate(_)), _) => todo!(), + }?; + + Ok(()) + }, + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs b/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs new file mode 100644 index 000000000000..9c3cf6b77c18 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs @@ -0,0 +1,149 @@ +use arrow::array::{DictionaryArray, MutableBinaryViewArray, PrimitiveArray}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::ArrowDataType; +use arrow::types::{AlignedBytes, NativeType}; + +use super::PredicateFilter; +use super::binview::BinViewDecoder; +use super::utils::{self, Decoder, StateTranslation, dict_indices_decoder, freeze_validity}; +use crate::parquet::encoding::Encoding; +use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; +use crate::parquet::error::ParquetResult; +use crate::parquet::page::{DataPage, DictPage}; + +impl<'a> StateTranslation<'a, CategoricalDecoder> for HybridRleDecoder<'a> { + type PlainDecoder = HybridRleDecoder<'a>; + + fn new( + _decoder: &CategoricalDecoder, + page: &'a DataPage, + _dict: Option<&'a ::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult { + if !matches!( + page.encoding(), + Encoding::PlainDictionary | Encoding::RleDictionary + ) { + return Err(utils::not_implemented(page)); + } + + dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits())) + } + fn num_rows(&self) -> usize { + self.len() + } +} + +/// Special decoder for Polars Enum and Categorical's. +/// +/// These are marked as special in the Arrow Field Metadata and they have the properly that for a +/// given row group all the values are in the dictionary page and all data pages are dictionary +/// encoded. This makes the job of decoding them extremely simple and fast. +pub struct CategoricalDecoder { + dict_size: usize, + decoder: BinViewDecoder, +} + +impl CategoricalDecoder { + pub fn new() -> Self { + Self { + dict_size: usize::MAX, + decoder: BinViewDecoder::new_string(), + } + } +} + +impl utils::Decoder for CategoricalDecoder { + type Translation<'a> = HybridRleDecoder<'a>; + type Dict = ::Dict; + type DecodedState = (Vec, BitmapBuilder); + type Output = DictionaryArray; + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + BitmapBuilder::with_capacity(capacity), + ) + } + + fn has_predicate_specialization( + &self, + state: &utils::State<'_, Self>, + _predicate: &PredicateFilter, + ) -> ParquetResult { + Ok(state.page_validity.is_none()) + } + + fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult { + let dict = self.decoder.deserialize_dict(page)?; + self.dict_size = dict.len(); + Ok(dict) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn arrow::array::Array, + is_optional: bool, + ) -> ParquetResult<()> { + let additional = additional + .as_any() + .downcast_ref::>() + .unwrap(); + decoded.0.extend(additional.keys().values().iter().copied()); + match additional.validity() { + Some(v) => decoded.1.extend_from_bitmap(v), + None if is_optional => decoded.1.extend_constant(additional.len(), true), + None => {}, + } + + Ok(()) + } + + fn finalize( + &self, + dtype: ArrowDataType, + dict: Option, + (values, validity): Self::DecodedState, + ) -> ParquetResult> { + let validity = freeze_validity(validity); + let dict = dict.unwrap(); + let keys = PrimitiveArray::new(ArrowDataType::UInt32, values.into(), validity); + + let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len()); + let (views, buffers, _, _, _) = dict.into_inner(); + + for buffer in buffers.iter() { + view_dict.push_buffer(buffer.clone()); + } + unsafe { view_dict.views_mut().extend(views.iter()) }; + unsafe { view_dict.set_total_bytes_len(views.iter().map(|v| v.length as usize).sum()) }; + let view_dict = view_dict.freeze(); + + // SAFETY: This was checked during construction of the dictionary + let dict = unsafe { view_dict.to_utf8view_unchecked() }.boxed(); + + // SAFETY: This was checked during decoding + Ok(unsafe { DictionaryArray::try_new_unchecked(dtype, keys, dict) }.unwrap()) + } + + fn extend_filtered_with_state( + &mut self, + state: utils::State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + super::dictionary_encoded::decode_dict_dispatch( + state.translation, + self.dict_size, + state.dict_mask, + state.is_optional, + state.page_validity.as_ref(), + filter, + &mut decoded.1, + <::AlignedBytes as AlignedBytes>::cast_vec_ref_mut(&mut decoded.0), + pred_true_mask, + ) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs new file mode 100644 index 000000000000..6b20bd25fe99 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/mod.rs @@ -0,0 +1,263 @@ +use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::types::{AlignedBytes, Bytes4Alignment4, NativeType}; +use polars_compute::filter::filter_boolean_kernel; + +use super::ParquetError; +use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; +use crate::parquet::error::ParquetResult; +use crate::read::Filter; + +mod optional; +mod optional_masked_dense; +mod predicate; +mod required; +mod required_masked_dense; + +/// A mapping from a `u32` to a value. This is used in to map dictionary encoding to a value. +pub trait IndexMapping { + type Output: Copy; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn len(&self) -> usize; + fn get(&self, idx: u32) -> Option { + ((idx as usize) < self.len()).then(|| unsafe { self.get_unchecked(idx) }) + } + unsafe fn get_unchecked(&self, idx: u32) -> Self::Output; +} + +// Base mapping used for everything except the CategoricalDecoder. +impl IndexMapping for &[T] { + type Output = T; + + #[inline(always)] + fn len(&self) -> usize { + <[T]>::len(self) + } + #[inline(always)] + unsafe fn get_unchecked(&self, idx: u32) -> Self::Output { + *unsafe { <[T]>::get_unchecked(self, idx as usize) } + } +} + +// Unit mapping used in the CategoricalDecoder. +impl IndexMapping for usize { + type Output = Bytes4Alignment4; + + #[inline(always)] + fn len(&self) -> usize { + *self + } + #[inline(always)] + unsafe fn get_unchecked(&self, idx: u32) -> Self::Output { + bytemuck::must_cast(idx) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn decode_dict( + values: HybridRleDecoder<'_>, + dict: &[T], + dict_mask: Option<&Bitmap>, + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut BitmapBuilder, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + decode_dict_dispatch( + values, + bytemuck::cast_slice(dict), + dict_mask, + is_optional, + page_validity, + filter, + validity, + ::cast_vec_ref_mut(target), + pred_true_mask, + ) +} + +#[inline(never)] +#[allow(clippy::too_many_arguments)] +pub fn decode_dict_dispatch>( + mut values: HybridRleDecoder<'_>, + dict: D, + dict_mask: Option<&Bitmap>, + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut BitmapBuilder, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + if is_optional { + append_validity(page_validity, filter.as_ref(), validity, values.len()); + } + + let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); + + match (filter, page_validity) { + (None, None) => required::decode(values, dict, target, 0), + (Some(Filter::Range(rng)), None) => { + values.limit_to(rng.end); + required::decode(values, dict, target, rng.start) + }, + (None, Some(page_validity)) => optional::decode(values, dict, page_validity, target, 0), + (Some(Filter::Range(rng)), Some(page_validity)) => { + optional::decode(values, dict, page_validity, target, rng.start) + }, + (Some(Filter::Mask(filter)), None) => { + required_masked_dense::decode(values, dict, filter, target) + }, + (Some(Filter::Mask(filter)), Some(page_validity)) => { + optional_masked_dense::decode(values, dict, filter, page_validity, target) + }, + (Some(Filter::Predicate(p)), None) => { + predicate::decode(values, dict, dict_mask.unwrap(), &p, target, pred_true_mask) + }, + (Some(Filter::Predicate(_)), Some(_)) => todo!(), + }?; + + Ok(()) +} + +pub(crate) fn append_validity( + page_validity: Option<&Bitmap>, + filter: Option<&Filter>, + validity: &mut BitmapBuilder, + values_len: usize, +) { + match (page_validity, filter) { + (None, None) => validity.extend_constant(values_len, true), + (None, Some(f)) => validity.extend_constant(f.num_rows(values_len), true), + (Some(page_validity), None) => validity.extend_from_bitmap(page_validity), + (Some(page_validity), Some(Filter::Range(rng))) => { + let page_validity = page_validity.clone(); + validity.extend_from_bitmap(&page_validity.clone().sliced(rng.start, rng.len())) + }, + (Some(page_validity), Some(Filter::Mask(mask))) => { + validity.extend_from_bitmap(&filter_boolean_kernel(page_validity, mask)) + }, + (_, Some(Filter::Predicate(_))) => todo!(), + } +} + +pub(crate) fn constrain_page_validity( + values_len: usize, + page_validity: Option<&Bitmap>, + filter: Option<&Filter>, +) -> Option { + let num_unfiltered_rows = match (filter.as_ref(), page_validity) { + (None, None) => values_len, + (None, Some(pv)) => pv.len(), + (Some(f), Some(pv)) => { + debug_assert!(pv.len() >= f.max_offset(pv.len())); + f.max_offset(pv.len()) + }, + (Some(f), None) => f.max_offset(values_len), + }; + + page_validity.map(|pv| { + if pv.len() > num_unfiltered_rows { + pv.clone().sliced(0, num_unfiltered_rows) + } else { + pv.clone() + } + }) +} + +#[cold] +fn oob_dict_idx() -> ParquetError { + ParquetError::oos("Dictionary Index is out-of-bounds") +} + +#[cold] +fn no_more_bitpacked_values() -> ParquetError { + ParquetError::oos("Bitpacked Hybrid-RLE ran out before all values were served") +} + +#[inline(always)] +fn verify_dict_indices(indices: &[u32], dict_size: usize) -> ParquetResult<()> { + debug_assert!(dict_size <= u32::MAX as usize); + let dict_size = dict_size as u32; + + let mut is_valid = true; + for &idx in indices { + is_valid &= idx < dict_size; + } + + if is_valid { + Ok(()) + } else { + Err(oob_dict_idx()) + } +} + +/// Skip over entire chunks in a [`HybridRleDecoder`] as long as all skipped chunks do not include +/// more than `num_values_to_skip` values. +#[inline(always)] +fn required_skip_whole_chunks( + values: &mut HybridRleDecoder<'_>, + num_values_to_skip: &mut usize, +) -> ParquetResult<()> { + if *num_values_to_skip == 0 { + return Ok(()); + } + + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + if *num_values_to_skip < chunk_len { + break; + } + *values = values_clone; + *num_values_to_skip -= chunk_len; + } + + Ok(()) +} + +/// Skip over entire chunks in a [`HybridRleDecoder`] as long as all skipped chunks do not include +/// more than `num_values_to_skip` values. +#[inline(always)] +fn optional_skip_whole_chunks( + values: &mut HybridRleDecoder<'_>, + validity: &mut BitMask<'_>, + num_rows_to_skip: &mut usize, + num_values_to_skip: &mut usize, +) -> ParquetResult<()> { + if *num_values_to_skip == 0 { + return Ok(()); + } + + let mut total_num_skipped_values = 0; + + loop { + let mut values_clone = values.clone(); + let Some(chunk_len) = values_clone.next_chunk_length()? else { + break; + }; + if *num_values_to_skip < chunk_len { + break; + } + *values = values_clone; + *num_values_to_skip -= chunk_len; + total_num_skipped_values += chunk_len; + } + + if total_num_skipped_values > 0 { + let offset = validity + .nth_set_bit_idx(total_num_skipped_values - 1, 0) + .map_or(validity.len(), |v| v + 1); + *num_rows_to_skip -= offset; + validity.advance_by(offset); + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs new file mode 100644 index 000000000000..bb0d77d9d989 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs @@ -0,0 +1,201 @@ +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +use arrow::types::AlignedBytes; + +use super::{ + IndexMapping, no_more_bitpacked_values, oob_dict_idx, optional_skip_whole_chunks, + verify_dict_indices, +}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +/// Decoding kernel for optional dictionary encoded. +#[inline(never)] +pub fn decode>( + mut values: HybridRleDecoder<'_>, + dict: D, + mut validity: Bitmap, + target: &mut Vec, + mut num_rows_to_skip: usize, +) -> ParquetResult<()> { + debug_assert!(num_rows_to_skip <= validity.len()); + + let num_rows = validity.len() - num_rows_to_skip; + let end_length = target.len() + num_rows; + + target.reserve(num_rows); + + // Remove any leading and trailing nulls. This has two benefits: + // 1. It increases the chance of dispatching to the faster kernel (e.g. for sorted data) + // 2. It reduces the amount of iterations in the main loop and replaces it with `memset`s + let leading_nulls = validity.take_leading_zeros(); + let trailing_nulls = validity.take_trailing_zeros(); + + // Special case: all values are skipped, just add the trailing null. + if num_rows_to_skip >= leading_nulls + validity.len() { + target.resize(end_length, B::zeroed()); + return Ok(()); + } + + values.limit_to(validity.set_bits()); + + // Add the leading nulls + if num_rows_to_skip < leading_nulls { + target.resize(target.len() + leading_nulls - num_rows_to_skip, B::zeroed()); + num_rows_to_skip = 0; + } else { + num_rows_to_skip -= leading_nulls; + } + + if validity.set_bits() == validity.len() { + // Dispatch to the required kernel if all rows are valid anyway. + super::required::decode(values, dict, target, num_rows_to_skip)?; + } else { + if dict.is_empty() { + return Err(oob_dict_idx()); + } + + let mut num_values_to_skip = 0; + if num_rows_to_skip > 0 { + num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits(); + } + + let mut validity = BitMask::from_bitmap(&validity); + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; + + // Skip over any whole HybridRleChunks + optional_skip_whole_chunks( + &mut values, + &mut validity, + &mut num_rows_to_skip, + &mut num_values_to_skip, + )?; + + while let Some(chunk) = values.next_chunk()? { + debug_assert!(num_values_to_skip < chunk.len() || chunk.len() == 0); + + match chunk { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + // If we know that we have `size` times `value` that we can append, but there + // might be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. + // This is done with `nth_set_bit_idx` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + + let Some(value) = dict.get(value) else { + return Err(oob_dict_idx()); + }; + + let num_chunk_rows = + validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + validity.advance_by(num_chunk_rows); + + target.resize(target.len() + num_chunk_rows - num_rows_to_skip, value); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let num_rows_for_decoder = validity + .nth_set_bit_idx(decoder.len(), 0) + .unwrap_or(validity.len()); + + let mut chunked = decoder.chunked(); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + + let mut decoder_validity; + (decoder_validity, validity) = validity.split_at(num_rows_for_decoder); + + // Skip over any remaining values. + if num_rows_to_skip > 0 { + decoder_validity.advance_by(num_rows_to_skip); + + chunked.decoder.skip_chunks(num_values_to_skip / 32); + num_values_to_skip %= 32; + + if num_values_to_skip > 0 { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let Some(num_added) = chunked.next_into(buffer_part) else { + return Err(no_more_bitpacked_values()); + }; + + debug_assert!(num_values_to_skip <= num_added); + verify_dict_indices(buffer_part, dict.len())?; + + values_offset += num_values_to_skip; + num_buffered += num_added - num_values_to_skip; + buffer_part_idx += 1; + } + } + + let mut iter = |v: u64, n: usize| { + while num_buffered < v.count_ones() as usize { + buffer_part_idx %= 4; + + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let Some(num_added) = chunked.next_into(buffer_part) else { + return Err(no_more_bitpacked_values()); + }; + + verify_dict_indices(buffer_part, dict.len())?; + + num_buffered += num_added; + + buffer_part_idx += 1; + } + + let mut num_read = 0; + + target.extend((0..n).map(|i| { + let idx = values_buffer[(values_offset + num_read) % 128]; + num_read += ((v >> i) & 1) as usize; + + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + unsafe { dict.get_unchecked(idx) } + })); + + values_offset += num_read; + values_offset %= 128; + num_buffered -= num_read; + + ParquetResult::Ok(()) + }; + + let mut v_iter = decoder_validity.fast_iter_u56(); + for v in v_iter.by_ref() { + iter(v, 56)?; + } + + let (v, vl) = v_iter.remainder(); + iter(v, vl)?; + }, + } + + num_rows_to_skip = 0; + num_values_to_skip = 0; + } + } + + // Add back the trailing nulls + debug_assert_eq!(target.len(), end_length - trailing_nulls); + target.resize(end_length, B::zeroed()); + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs new file mode 100644 index 000000000000..28b58d960111 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional_masked_dense.rs @@ -0,0 +1,218 @@ +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +use arrow::types::AlignedBytes; + +use super::{IndexMapping, oob_dict_idx, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode>( + mut values: HybridRleDecoder<'_>, + dict: D, + mut filter: Bitmap, + mut validity: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + // @NOTE: We don't skip leading filtered values, because it is a bit more involved than the + // other kernels. We could probably do it anyway after having tried to dispatch to faster + // kernels, but we lose quite a bit of the potency with that. + filter.take_trailing_zeros(); + validity = validity.sliced(0, filter.len()); + + let num_rows = filter.set_bits(); + let num_valid_values = validity.set_bits(); + + assert_eq!(filter.len(), validity.len()); + assert!(num_valid_values <= values.len()); + + // Dispatch to the non-filter kernel if all rows are needed anyway. + if num_rows == filter.len() { + return super::optional::decode(values, dict, validity, target, 0); + } + + // Dispatch to the required kernel if all rows are valid anyway. + if num_valid_values == validity.len() { + return super::required_masked_dense::decode(values, dict, filter, target); + } + + if dict.is_empty() && num_valid_values > 0 { + return Err(oob_dict_idx()); + } + + target.reserve(num_rows); + + let end_length = target.len() + num_rows; + + let mut filter = BitMask::from_bitmap(&filter); + let mut validity = BitMask::from_bitmap(&validity); + + values.limit_to(num_valid_values); + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; + + let mut num_rows_left = num_rows; + + for chunk in values.into_chunk_iter() { + // Early stop if we have no more rows to load. + if num_rows_left == 0 { + break; + } + + match chunk? { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + // If we know that we have `size` times `value` that we can append, but there might + // be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is + // done with `num_bits_before_nth_one` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + + let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + + let current_filter; + (current_filter, filter) = filter.split_at(num_chunk_values); + validity.advance_by(num_chunk_values); + + let num_chunk_rows = current_filter.set_bits(); + + let Some(value) = dict.get(value) else { + return Err(oob_dict_idx()); + }; + + target.resize(target.len() + num_chunk_rows, value); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + // For bitpacked we do the following: + // 1. See how many rows are encoded by this `decoder`. + // 2. Go through the filter and validity 56 bits at a time and: + // 0. If filter bits are 0, skip the chunk entirely. + // 1. Buffer enough values so that we can branchlessly decode with the filter + // and validity. + // 2. Decode with filter and validity. + // 3. Decode remainder. + + let size = decoder.len(); + let mut chunked = decoder.chunked(); + + let num_chunk_values = validity.nth_set_bit_idx(size, 0).unwrap_or(validity.len()); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + let mut skip_values = 0; + + let current_filter; + let current_validity; + + (current_filter, filter) = unsafe { filter.split_at_unchecked(num_chunk_values) }; + (current_validity, validity) = + unsafe { validity.split_at_unchecked(num_chunk_values) }; + + let mut iter = |mut f: u64, mut v: u64| { + // Skip chunk if we don't any values from here. + if f == 0 { + skip_values += v.count_ones() as usize; + return ParquetResult::Ok(()); + } + + // Skip over already buffered items. + let num_buffered_skipped = skip_values.min(num_buffered); + values_offset += num_buffered_skipped; + num_buffered -= num_buffered_skipped; + skip_values -= num_buffered_skipped; + + // If we skipped plenty already, just skip decoding those chunks instead of + // decoding them and throwing them away. + chunked.decoder.skip_chunks(skip_values / 32); + // The leftovers we have to decode but we can also just skip. + skip_values %= 32; + + while num_buffered < v.count_ones() as usize { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let num_added = chunked.next_into(buffer_part).unwrap(); + + verify_dict_indices(buffer_part, dict.len())?; + + let skip_chunk_values = skip_values.min(num_added); + + values_offset += skip_chunk_values; + num_buffered += num_added - skip_chunk_values; + skip_values -= skip_chunk_values; + + buffer_part_idx += 1; + buffer_part_idx %= 4; + } + + let mut num_read = 0; + let mut num_written = 0; + let target_ptr = unsafe { target.as_mut_ptr().add(target.len()) }; + + while f != 0 { + let offset = f.trailing_zeros(); + + num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; + v >>= offset; + + let idx = values_buffer[(values_offset + num_read) % 128]; + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = unsafe { dict.get_unchecked(idx) }; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += (v & 1) as usize; + + f >>= offset + 1; // Clear least significant bit. + v >>= 1; + } + + num_read += v.count_ones() as usize; + + values_offset += num_read; + values_offset %= 128; + num_buffered -= num_read; + unsafe { + target.set_len(target.len() + num_written); + } + num_rows_left -= num_written; + + ParquetResult::Ok(()) + }; + + let mut f_iter = current_filter.fast_iter_u56(); + let mut v_iter = current_validity.fast_iter_u56(); + + for (f, v) in f_iter.by_ref().zip(v_iter.by_ref()) { + iter(f, v)?; + } + + let (f, fl) = f_iter.remainder(); + let (v, vl) = v_iter.remainder(); + + assert_eq!(fl, vl); + + iter(f, v)?; + }, + } + } + + if cfg!(debug_assertions) { + assert_eq!(validity.set_bits(), 0); + } + + target.resize(end_length, B::zeroed()); + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/predicate.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/predicate.rs new file mode 100644 index 000000000000..b63f830c180f --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/predicate.rs @@ -0,0 +1,260 @@ +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::types::AlignedBytes; + +use super::{IndexMapping, oob_dict_idx, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; +use crate::read::PredicateFilter; + +#[inline(never)] +pub fn decode>( + values: HybridRleDecoder<'_>, + dict: D, + dict_mask: &Bitmap, + predicate: &PredicateFilter, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + let num_filtered_dict_values = dict_mask.set_bits(); + + let expected_pred_true_mask_len = pred_true_mask.len() + values.len(); + + // @NOTE: this has to be changed when there are nulls null + if num_filtered_dict_values == 0 { + pred_true_mask.extend_constant(values.len(), false); + } else if num_filtered_dict_values == 1 { + let needle = dict_mask.leading_zeros(); + let start_num_values = pred_true_mask.set_bits(); + + decode_single_no_values(values, needle as u32, pred_true_mask)?; + + if predicate.include_values { + let num_values = pred_true_mask.set_bits() - start_num_values; + target.resize(target.len() + num_values, dict.get(needle as u32).unwrap()); + } + } else if predicate.include_values { + decode_multiple_values(values, dict, dict_mask, target, pred_true_mask)?; + } else { + decode_multiple_no_values(values, dict_mask, pred_true_mask)?; + } + + assert_eq!(expected_pred_true_mask_len, pred_true_mask.len()); + + Ok(()) +} + +#[inline(never)] +pub fn decode_single_no_values( + mut values: HybridRleDecoder<'_>, + needle: u32, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + pred_true_mask.reserve(values.len()); + + let mut unpacked = [0u32; 32]; + while let Some(chunk) = values.next_chunk()? { + match chunk { + HybridRleChunk::Rle(value, size) => { + pred_true_mask.extend_constant(size, value == needle); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let size = decoder.len(); + let mut chunked = decoder.chunked(); + + for _ in 0..size / 32 { + let n = chunked.next_into(&mut unpacked).unwrap(); + debug_assert_eq!(n, 32); + + let mut is_equal_mask = 0u64; + for (i, &v) in unpacked.iter().enumerate() { + is_equal_mask |= u64::from(v == needle) << i; + } + + // SAFETY: We reserved enough in the beginning of the function. + unsafe { pred_true_mask.push_word_with_len_unchecked(is_equal_mask, 32) }; + } + + if let Some(n) = chunked.next_into(&mut unpacked) { + debug_assert_eq!(n, size % 32); + + let mut is_equal_mask = 0u64; + for (i, &v) in unpacked[..n].iter().enumerate() { + is_equal_mask |= u64::from(v == needle) << i; + } + + // SAFETY: We reserved enough in the beginning of the function. + unsafe { pred_true_mask.push_word_with_len_unchecked(is_equal_mask, n) }; + } + }, + } + } + + Ok(()) +} + +#[inline(never)] +pub fn decode_multiple_no_values( + mut values: HybridRleDecoder<'_>, + dict_mask: &Bitmap, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + pred_true_mask.reserve(values.len()); + + let mut unpacked = [0u32; 32]; + while let Some(chunk) = values.next_chunk()? { + match chunk { + HybridRleChunk::Rle(value, size) => { + let is_pred_true = dict_mask.get(value as usize).ok_or_else(oob_dict_idx)?; + pred_true_mask.extend_constant(size, is_pred_true); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let size = decoder.len(); + let mut chunked = decoder.chunked(); + + for _ in 0..size / 32 { + let n = chunked.next_into(&mut unpacked).unwrap(); + debug_assert_eq!(n, 32); + + verify_dict_indices(&unpacked, dict_mask.len())?; + let mut is_pred_true_mask = 0u64; + for (i, &v) in unpacked.iter().enumerate() { + // SAFETY: We just verified the dictionary indices + let is_pred_true = unsafe { dict_mask.get_bit_unchecked(v as usize) }; + is_pred_true_mask |= u64::from(is_pred_true) << i; + } + + // SAFETY: We reserved enough in the beginning of the function. + unsafe { pred_true_mask.push_word_with_len_unchecked(is_pred_true_mask, 32) }; + } + + if let Some(n) = chunked.next_into(&mut unpacked) { + debug_assert_eq!(n, size % 32); + + verify_dict_indices(&unpacked[..n], dict_mask.len())?; + let mut is_pred_true_mask = 0u64; + for (i, &v) in unpacked[..n].iter().enumerate() { + // SAFETY: We just verified the dictionary indices + let is_pred_true = unsafe { dict_mask.get_bit_unchecked(v as usize) }; + is_pred_true_mask |= u64::from(is_pred_true) << i; + } + + // SAFETY: We reserved enough in the beginning of the function. + unsafe { pred_true_mask.push_word_with_len_unchecked(is_pred_true_mask, n) }; + } + }, + } + } + + Ok(()) +} + +#[inline(never)] +pub fn decode_multiple_values>( + mut values: HybridRleDecoder<'_>, + dict: D, + dict_mask: &Bitmap, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + pred_true_mask.reserve(values.len()); + target.reserve(values.len()); + + let mut unpacked = [0u32; 32]; + while let Some(chunk) = values.next_chunk()? { + match chunk { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + let is_pred_true = dict_mask.get(value as usize).ok_or_else(oob_dict_idx)?; + pred_true_mask.extend_constant(size, is_pred_true); + if is_pred_true { + let value = dict.get(value).unwrap(); + target.resize(target.len() + size, value); + } + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let size = decoder.len(); + let mut chunked = decoder.chunked(); + + for _ in 0..size / 32 { + let n = chunked.next_into(&mut unpacked).unwrap(); + debug_assert_eq!(n, 32); + + verify_dict_indices(&unpacked, dict_mask.len())?; + let mut is_pred_true_mask = 0u64; + for (i, &v) in unpacked.iter().enumerate() { + // SAFETY: We just verified the dictionary indices + let is_pred_true = unsafe { dict_mask.get_bit_unchecked(v as usize) }; + is_pred_true_mask |= u64::from(is_pred_true) << i; + } + // SAFETY: We reserved enough in the beginning of the function. + unsafe { pred_true_mask.push_word_with_len_unchecked(is_pred_true_mask, 32) }; + + let num_pred_true_values = is_pred_true_mask.count_ones() as usize; + + if num_pred_true_values == 32 { + target.extend( + unpacked + .iter() + // SAFETY: We just verified the dictionary indices + .map(|&v| unsafe { dict.get_unchecked(v) }), + ); + } else if num_pred_true_values > 0 { + let mut write_ptr = unsafe { target.as_mut_ptr().add(target.len()) }; + for v in unpacked { + unsafe { + write_ptr.write(dict.get_unchecked(v)); + let select = dict_mask.get_bit_unchecked(v as usize); + write_ptr = write_ptr.add(usize::from(select)); + } + } + + let new_len = target.len() + num_pred_true_values; + unsafe { target.set_len(new_len) }; + } + } + + if let Some(n) = chunked.next_into(&mut unpacked) { + debug_assert_eq!(n, size % 32); + + verify_dict_indices(&unpacked[..n], dict_mask.len())?; + let mut is_pred_true_mask = 0u64; + for (i, &v) in unpacked[..n].iter().enumerate() { + // SAFETY: We just verified the dictionary indices + let is_pred_true = unsafe { dict_mask.get_bit_unchecked(v as usize) }; + is_pred_true_mask |= u64::from(is_pred_true) << i; + } + // SAFETY: We reserved enough in the beginning of the function. + unsafe { pred_true_mask.push_word_with_len_unchecked(is_pred_true_mask, n) }; + + let num_pred_true_values = is_pred_true_mask.count_ones() as usize; + + if num_pred_true_values == n { + target.extend( + unpacked[..n] + .iter() + // SAFETY: We just verified the dictionary indices + .map(|&v| unsafe { dict.get_unchecked(v) }), + ); + } else if num_pred_true_values > 0 { + let mut write_ptr = unsafe { target.as_mut_ptr().add(target.len()) }; + for &v in &unpacked[..n] { + unsafe { + write_ptr.write(dict.get_unchecked(v)); + let select = dict_mask.get_bit_unchecked(v as usize); + write_ptr = write_ptr.add(usize::from(select)); + } + } + + let new_len = target.len() + num_pred_true_values; + unsafe { target.set_len(new_len) }; + } + } + }, + } + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs new file mode 100644 index 000000000000..20c0df9725d6 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required.rs @@ -0,0 +1,88 @@ +use arrow::types::AlignedBytes; + +use super::{IndexMapping, oob_dict_idx, required_skip_whole_chunks, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +/// Decoding kernel for required dictionary encoded. +#[inline(never)] +pub fn decode>( + mut values: HybridRleDecoder<'_>, + dict: D, + target: &mut Vec, + mut num_rows_to_skip: usize, +) -> ParquetResult<()> { + debug_assert!(num_rows_to_skip <= values.len()); + + let num_rows = values.len() - num_rows_to_skip; + let end_length = target.len() + num_rows; + + if num_rows == 0 { + return Ok(()); + } + + target.reserve(num_rows); + + if dict.is_empty() { + return Err(oob_dict_idx()); + } + + // Skip over whole HybridRleChunks + required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?; + + while let Some(chunk) = values.next_chunk()? { + debug_assert!(num_rows_to_skip < chunk.len() || chunk.len() == 0); + + match chunk { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + let Some(value) = dict.get(value) else { + return Err(oob_dict_idx()); + }; + + target.resize(target.len() + size - num_rows_to_skip, value); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + if num_rows_to_skip > 0 { + decoder.skip_chunks(num_rows_to_skip / 32); + num_rows_to_skip %= 32; + + if let Some((chunk, chunk_size)) = decoder.chunked().next_inexact() { + let chunk = &chunk[num_rows_to_skip..chunk_size]; + verify_dict_indices(chunk, dict.len())?; + target.extend(chunk.iter().map(|&idx| { + // SAFETY: The dict indices were verified before. + unsafe { dict.get_unchecked(idx) } + })); + } + } + + let mut chunked = decoder.chunked(); + for chunk in chunked.by_ref() { + verify_dict_indices(&chunk, dict.len())?; + target.extend(chunk.iter().map(|&idx| { + // SAFETY: The dict indices were verified before. + unsafe { dict.get_unchecked(idx) } + })); + } + + if let Some((chunk, chunk_size)) = chunked.remainder() { + verify_dict_indices(&chunk[..chunk_size], dict.len())?; + target.extend(chunk[..chunk_size].iter().map(|&idx| { + // SAFETY: The dict indices were verified before. + unsafe { dict.get_unchecked(idx) } + })); + } + }, + } + + num_rows_to_skip = 0; + } + + debug_assert_eq!(target.len(), end_length); + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs new file mode 100644 index 000000000000..11308f4b6a22 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/required_masked_dense.rs @@ -0,0 +1,176 @@ +use arrow::bitmap::Bitmap; +use arrow::bitmap::bitmask::BitMask; +use arrow::types::AlignedBytes; + +use super::{IndexMapping, oob_dict_idx, required_skip_whole_chunks, verify_dict_indices}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode>( + mut values: HybridRleDecoder<'_>, + dict: D, + mut filter: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + assert!(values.len() >= filter.len()); + + let mut num_rows_to_skip = filter.take_leading_zeros(); + filter.take_trailing_zeros(); + + let num_rows = filter.set_bits(); + + values.limit_to(num_rows_to_skip + filter.len()); + + // Dispatch to the non-filter kernel if all rows are needed anyway. + if num_rows == filter.len() { + return super::required::decode(values, dict, target, num_rows_to_skip); + } + + if dict.is_empty() && !filter.is_empty() { + return Err(oob_dict_idx()); + } + + target.reserve(num_rows); + + let mut filter = BitMask::from_bitmap(&filter); + + let mut values_buffer = [0u32; 128]; + let values_buffer = &mut values_buffer; + + // Skip over whole HybridRleChunks + required_skip_whole_chunks(&mut values, &mut num_rows_to_skip)?; + + while let Some(chunk) = values.next_chunk()? { + debug_assert!(num_rows_to_skip < chunk.len() || chunk.len() == 0); + + match chunk { + HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + + // If we know that we have `size` times `value` that we can append, but there might + // be nulls in between those values. + // + // 1. See how many `num_rows = valid + invalid` values `size` would entail. This is + // done with `num_bits_before_nth_one` on the validity mask. + // 2. Fill `num_rows` values into the target buffer. + // 3. Advance the validity mask by `num_rows` values. + + let current_filter; + + (current_filter, filter) = filter.split_at(size - num_rows_to_skip); + let num_chunk_rows = current_filter.set_bits(); + + if num_chunk_rows > 0 { + let Some(value) = dict.get(value) else { + return Err(oob_dict_idx()); + }; + + target.resize(target.len() + num_chunk_rows, value); + } + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let size = decoder.len(); + let mut chunked = decoder.chunked(); + + let mut buffer_part_idx = 0; + let mut values_offset = 0; + let mut num_buffered: usize = 0; + let mut skip_values = num_rows_to_skip; + + let current_filter; + + (current_filter, filter) = filter.split_at(size - num_rows_to_skip); + + let mut iter = |mut f: u64, len: usize| { + debug_assert!(len <= 64); + + // Skip chunk if we don't any values from here. + if f == 0 { + skip_values += len; + return ParquetResult::Ok(()); + } + + // Skip over already buffered items. + let num_buffered_skipped = skip_values.min(num_buffered); + values_offset += num_buffered_skipped; + num_buffered -= num_buffered_skipped; + skip_values -= num_buffered_skipped; + + // If we skipped plenty already, just skip decoding those chunks instead of + // decoding them and throwing them away. + chunked.decoder.skip_chunks(skip_values / 32); + // The leftovers we have to decode but we can also just skip. + skip_values %= 32; + + while num_buffered < len { + let buffer_part = <&mut [u32; 32]>::try_from( + &mut values_buffer[buffer_part_idx * 32..][..32], + ) + .unwrap(); + let num_added = chunked.next_into(buffer_part).unwrap(); + + verify_dict_indices(buffer_part, dict.len())?; + + let skip_chunk_values = skip_values.min(num_added); + + values_offset += skip_chunk_values; + num_buffered += num_added - skip_chunk_values; + skip_values -= skip_chunk_values; + + buffer_part_idx += 1; + buffer_part_idx %= 4; + } + + let mut num_read = 0; + let mut num_written = 0; + let target_ptr = unsafe { target.as_mut_ptr().add(target.len()) }; + + while f != 0 { + let offset = f.trailing_zeros() as usize; + + num_read += offset; + + let idx = values_buffer[(values_offset + num_read) % 128]; + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = unsafe { dict.get_unchecked(idx) }; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += 1; + + f >>= offset + 1; // Clear least significant bit. + } + + values_offset += len; + values_offset %= 128; + num_buffered -= len; + unsafe { + target.set_len(target.len() + num_written); + } + + ParquetResult::Ok(()) + }; + + let mut f_iter = current_filter.fast_iter_u56(); + + for f in f_iter.by_ref() { + iter(f, 56)?; + } + + let (f, fl) = f_iter.remainder(); + iter(f, fl)?; + }, + } + + num_rows_to_skip = 0; + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs new file mode 100644 index 000000000000..b64850c11420 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -0,0 +1,614 @@ +use arrow::array::{FixedSizeBinaryArray, Splitable}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::storage::SharedStorage; +use arrow::types::{ + Bytes1Alignment1, Bytes2Alignment2, Bytes4Alignment4, Bytes8Alignment8, Bytes12Alignment4, + Bytes16Alignment16, Bytes32Alignment16, +}; +use bytemuck::Zeroable; + +use super::dictionary_encoded::append_validity; +use super::utils::array_chunks::ArrayChunks; +use super::utils::{Decoder, dict_indices_decoder, freeze_validity}; +use super::{Filter, PredicateFilter}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::encoding::{Encoding, hybrid_rle}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::page::{DataPage, DictPage, split_buffer}; +use crate::read::deserialize::dictionary_encoded::constrain_page_validity; +use crate::read::deserialize::utils; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum StateTranslation<'a> { + Plain(&'a [u8], usize), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), +} + +impl<'a> utils::StateTranslation<'a, BinaryDecoder> for StateTranslation<'a> { + type PlainDecoder = &'a [u8]; + + fn new( + decoder: &BinaryDecoder, + page: &'a DataPage, + dict: Option<&'a ::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult { + match (page.encoding(), dict) { + (Encoding::Plain, _) => { + let values = split_buffer(page)?.values; + if values.len() % decoder.size != 0 { + return Err(ParquetError::oos(format!( + "Fixed size binary data length {} is not divisible by size {}", + values.len(), + decoder.size + ))); + } + Ok(Self::Plain(values, decoder.size)) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = + dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))?; + Ok(Self::Dictionary(values)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn num_rows(&self) -> usize { + match self { + StateTranslation::Plain(v, n) => v.len() / n, + StateTranslation::Dictionary(i) => i.len(), + } + } +} + +pub(crate) struct BinaryDecoder { + pub(crate) size: usize, +} + +pub(crate) enum FSBVec { + Size1(Vec), + Size2(Vec), + Size4(Vec), + Size8(Vec), + Size12(Vec), + Size16(Vec), + Size32(Vec), + Other(Vec, usize), +} + +impl FSBVec { + pub fn new(size: usize) -> FSBVec { + match size { + 1 => Self::Size1(Vec::new()), + 2 => Self::Size2(Vec::new()), + 4 => Self::Size4(Vec::new()), + 8 => Self::Size8(Vec::new()), + 12 => Self::Size12(Vec::new()), + 16 => Self::Size16(Vec::new()), + 32 => Self::Size32(Vec::new()), + _ => Self::Other(Vec::new(), size), + } + } + + fn size(&self) -> usize { + match self { + FSBVec::Size1(_) => 1, + FSBVec::Size2(_) => 2, + FSBVec::Size4(_) => 4, + FSBVec::Size8(_) => 8, + FSBVec::Size12(_) => 12, + FSBVec::Size16(_) => 16, + FSBVec::Size32(_) => 32, + FSBVec::Other(_, size) => *size, + } + } + + pub fn into_bytes_buffer(self) -> Buffer { + Buffer::from_storage(match self { + FSBVec::Size1(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Size2(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Size4(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Size8(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Size12(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Size16(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Size32(vec) => SharedStorage::bytes_from_pod_vec(vec), + FSBVec::Other(vec, _) => SharedStorage::from_vec(vec), + }) + } + + pub fn extend_from_byte_slice(&mut self, slice: &[u8]) { + let size = self.size(); + if size == 0 { + assert_eq!(slice.len(), 0); + return; + } + + assert_eq!(slice.len() % size, 0); + + macro_rules! extend_from_slice { + ($v:expr) => {{ + $v.reserve(slice.len() / size); + unsafe { + std::ptr::copy_nonoverlapping( + slice.as_ptr(), + $v.as_mut_ptr().add($v.len()) as *mut _, + slice.len(), + ) + } + let new_len = $v.len() + slice.len() / size; + unsafe { $v.set_len(new_len) }; + }}; + } + + match self { + FSBVec::Size1(v) => extend_from_slice!(v), + FSBVec::Size2(v) => extend_from_slice!(v), + FSBVec::Size4(v) => extend_from_slice!(v), + FSBVec::Size8(v) => extend_from_slice!(v), + FSBVec::Size12(v) => extend_from_slice!(v), + FSBVec::Size16(v) => extend_from_slice!(v), + FSBVec::Size32(v) => extend_from_slice!(v), + FSBVec::Other(v, _) => v.extend_from_slice(slice), + } + } + + fn len(&self) -> usize { + match self { + FSBVec::Size1(vec) => vec.len(), + FSBVec::Size2(vec) => vec.len(), + FSBVec::Size4(vec) => vec.len(), + FSBVec::Size8(vec) => vec.len(), + FSBVec::Size12(vec) => vec.len(), + FSBVec::Size16(vec) => vec.len(), + FSBVec::Size32(vec) => vec.len(), + FSBVec::Other(vec, size) => vec.len() / size, + } + } +} + +impl utils::Decoded for (FSBVec, BitmapBuilder) { + fn len(&self) -> usize { + self.0.len() + } + + fn extend_nulls(&mut self, n: usize) { + match &mut self.0 { + FSBVec::Size1(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Size2(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Size4(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Size8(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Size12(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Size16(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Size32(v) => v.resize(v.len() + n, Zeroable::zeroed()), + FSBVec::Other(v, size) => v.resize(v.len() + n * *size, Zeroable::zeroed()), + } + self.1.extend_constant(n, false); + } +} + +#[allow(clippy::too_many_arguments)] +fn decode_fsb_plain( + size: usize, + values: &[u8], + target: &mut FSBVec, + pred_true_mask: &mut BitmapBuilder, + validity: &mut BitmapBuilder, + is_optional: bool, + filter: Option, + page_validity: Option<&Bitmap>, +) -> ParquetResult<()> { + assert_ne!(size, 0); + assert_eq!(values.len() % size, 0); + + macro_rules! decode_static_size { + ($target:ident) => {{ + let values = ArrayChunks::new(values).ok_or_else(|| { + ParquetError::oos("Page content does not align with expected element size") + })?; + super::primitive::plain::decode_aligned_bytes_dispatch( + values, + is_optional, + page_validity, + filter, + validity, + $target, + pred_true_mask, + ) + }}; + } + + use FSBVec as T; + match target { + T::Size1(target) => decode_static_size!(target), + T::Size2(target) => decode_static_size!(target), + T::Size4(target) => decode_static_size!(target), + T::Size8(target) => decode_static_size!(target), + T::Size12(target) => decode_static_size!(target), + T::Size16(target) => decode_static_size!(target), + T::Size32(target) => decode_static_size!(target), + T::Other(target, _) => { + // @NOTE: All these kernels are quite slow, but they should be very uncommon and the + // general case requires arbitrary length memcopies anyway. + + if is_optional { + append_validity( + page_validity, + filter.as_ref(), + validity, + values.len() / size, + ); + } + + let page_validity = + constrain_page_validity(values.len() / size, page_validity, filter.as_ref()); + + match (page_validity, filter.as_ref()) { + (None, None) => target.extend_from_slice(values), + (None, Some(filter)) => match filter { + Filter::Range(range) => { + target.extend_from_slice(&values[range.start * size..range.end * size]) + }, + Filter::Mask(bitmap) => { + let mut iter = bitmap.iter(); + let mut offset = 0; + + while iter.num_remaining() > 0 { + let num_selected = iter.take_leading_ones(); + target + .extend_from_slice(&values[offset * size..][..num_selected * size]); + offset += num_selected; + + let num_filtered = iter.take_leading_zeros(); + offset += num_filtered; + } + }, + Filter::Predicate(_) => todo!(), + }, + (Some(validity), None) => { + let mut iter = validity.iter(); + let mut offset = 0; + + while iter.num_remaining() > 0 { + let num_valid = iter.take_leading_ones(); + target.extend_from_slice(&values[offset * size..][..num_valid * size]); + offset += num_valid; + + let num_filtered = iter.take_leading_zeros(); + target.resize(target.len() + num_filtered * size, 0); + } + }, + (Some(validity), Some(filter)) => match filter { + Filter::Range(range) => { + let (skipped, active) = validity.split_at(range.start); + + let active = active.sliced(0, range.len()); + + let mut iter = active.iter(); + let mut offset = skipped.set_bits(); + + while iter.num_remaining() > 0 { + let num_valid = iter.take_leading_ones(); + target.extend_from_slice(&values[offset * size..][..num_valid * size]); + offset += num_valid; + + let num_filtered = iter.take_leading_zeros(); + target.resize(target.len() + num_filtered * size, 0); + } + }, + Filter::Mask(filter) => { + let mut offset = 0; + for (is_selected, is_valid) in filter.iter().zip(validity.iter()) { + if is_selected { + if is_valid { + target.extend_from_slice(&values[offset * size..][..size]); + } else { + target.resize(target.len() + size, 0); + } + } + + offset += usize::from(is_valid); + } + }, + Filter::Predicate(_) => todo!(), + }, + } + + Ok(()) + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn decode_fsb_dict( + size: usize, + values: HybridRleDecoder<'_>, + dict: &FixedSizeBinaryArray, + dict_mask: Option<&Bitmap>, + target: &mut FSBVec, + pred_true_mask: &mut BitmapBuilder, + validity: &mut BitmapBuilder, + is_optional: bool, + filter: Option, + page_validity: Option<&Bitmap>, +) -> ParquetResult<()> { + assert_ne!(size, 0); + + macro_rules! decode_static_size { + ($dict:ident, $target:ident) => {{ + let dict = $dict.values().as_slice(); + // @NOTE: We initialize the dict with the right alignment for this to work. + let dict = bytemuck::cast_slice(dict); + + super::dictionary_encoded::decode_dict_dispatch( + values, + dict, + dict_mask, + is_optional, + page_validity, + filter, + validity, + $target, + pred_true_mask, + ) + }}; + } + + use FSBVec as T; + match target { + T::Size1(target) => decode_static_size!(dict, target), + T::Size2(target) => decode_static_size!(dict, target), + T::Size4(target) => decode_static_size!(dict, target), + T::Size8(target) => decode_static_size!(dict, target), + T::Size12(target) => decode_static_size!(dict, target), + T::Size16(target) => decode_static_size!(dict, target), + T::Size32(target) => decode_static_size!(dict, target), + T::Other(target, _) => { + // @NOTE: All these kernels are quite slow, but they should be very uncommon and the + // general case requires arbitrary length memcopies anyway. + + let dict = dict.values().as_slice(); + + if is_optional { + append_validity( + page_validity, + filter.as_ref(), + validity, + values.len() / size, + ); + } + + let page_validity = + constrain_page_validity(values.len() / size, page_validity, filter.as_ref()); + + let mut indexes = Vec::with_capacity(values.len()); + + for chunk in values.into_chunk_iter() { + match chunk? { + HybridRleChunk::Rle(value, repeats) => { + indexes.resize(indexes.len() + repeats, value) + }, + HybridRleChunk::Bitpacked(decoder) => decoder.collect_into(&mut indexes), + } + } + + match (page_validity, filter.as_ref()) { + (None, None) => target.extend( + indexes + .into_iter() + .flat_map(|v| &dict[(v as usize) * size..][..size]), + ), + (None, Some(filter)) => match filter { + Filter::Range(range) => target.extend( + indexes[range.start..range.end] + .iter() + .flat_map(|v| &dict[(*v as usize) * size..][..size]), + ), + Filter::Mask(bitmap) => { + let mut iter = bitmap.iter(); + let mut offset = 0; + + while iter.num_remaining() > 0 { + let num_selected = iter.take_leading_ones(); + target.extend( + indexes[offset..][..num_selected] + .iter() + .flat_map(|v| &dict[(*v as usize) * size..][..size]), + ); + offset += num_selected; + + let num_filtered = iter.take_leading_zeros(); + offset += num_filtered; + } + }, + Filter::Predicate(_) => todo!(), + }, + (Some(validity), None) => { + let mut iter = validity.iter(); + let mut offset = 0; + + while iter.num_remaining() > 0 { + let num_valid = iter.take_leading_ones(); + target.extend( + indexes[offset..][..num_valid] + .iter() + .flat_map(|v| &dict[(*v as usize) * size..][..size]), + ); + offset += num_valid; + + let num_filtered = iter.take_leading_zeros(); + target.resize(target.len() + num_filtered * size, 0); + } + }, + (Some(validity), Some(filter)) => match filter { + Filter::Range(range) => { + let (skipped, active) = validity.split_at(range.start); + + let active = active.sliced(0, range.len()); + + let mut iter = active.iter(); + let mut offset = skipped.set_bits(); + + while iter.num_remaining() > 0 { + let num_valid = iter.take_leading_ones(); + target.extend( + indexes[offset..][..num_valid] + .iter() + .flat_map(|v| &dict[(*v as usize) * size..][..size]), + ); + offset += num_valid; + + let num_filtered = iter.take_leading_zeros(); + target.resize(target.len() + num_filtered * size, 0); + } + }, + Filter::Mask(filter) => { + let mut offset = 0; + for (is_selected, is_valid) in filter.iter().zip(validity.iter()) { + if is_selected { + if is_valid { + target.extend_from_slice( + &dict[(indexes[offset] as usize) * size..][..size], + ); + } else { + target.resize(target.len() + size, 0); + } + } + + offset += usize::from(is_valid); + } + }, + Filter::Predicate(_) => todo!(), + }, + } + + Ok(()) + }, + } +} + +impl Decoder for BinaryDecoder { + type Translation<'a> = StateTranslation<'a>; + type Dict = FixedSizeBinaryArray; + type DecodedState = (FSBVec, BitmapBuilder); + type Output = FixedSizeBinaryArray; + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + let size = self.size; + + let values = match size { + 1 => FSBVec::Size1(Vec::with_capacity(capacity)), + 2 => FSBVec::Size2(Vec::with_capacity(capacity)), + 4 => FSBVec::Size4(Vec::with_capacity(capacity)), + 8 => FSBVec::Size8(Vec::with_capacity(capacity)), + 12 => FSBVec::Size12(Vec::with_capacity(capacity)), + 16 => FSBVec::Size16(Vec::with_capacity(capacity)), + 32 => FSBVec::Size32(Vec::with_capacity(capacity)), + _ => FSBVec::Other(Vec::with_capacity(capacity * size), size), + }; + + (values, BitmapBuilder::with_capacity(capacity)) + } + + fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult { + let mut target = FSBVec::new(self.size); + decode_fsb_plain( + self.size, + page.buffer.as_ref(), + &mut target, + &mut BitmapBuilder::new(), + &mut BitmapBuilder::new(), + false, + None, + None, + )?; + + Ok(FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(self.size), + target.into_bytes_buffer(), + None, + )) + } + + fn has_predicate_specialization( + &self, + _state: &utils::State<'_, Self>, + _predicate: &PredicateFilter, + ) -> ParquetResult { + // @TODO: This can be enabled for the fast paths + Ok(false) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn arrow::array::Array, + is_optional: bool, + ) -> ParquetResult<()> { + let additional = additional + .as_any() + .downcast_ref::() + .unwrap(); + decoded + .0 + .extend_from_byte_slice(additional.values().as_slice()); + match additional.validity() { + Some(v) => decoded.1.extend_from_bitmap(v), + None if is_optional => decoded.1.extend_constant(additional.len(), true), + None => {}, + } + + Ok(()) + } + + fn finalize( + &self, + dtype: ArrowDataType, + _dict: Option, + (values, validity): Self::DecodedState, + ) -> ParquetResult { + let validity = freeze_validity(validity); + + Ok(FixedSizeBinaryArray::new( + dtype, + values.into_bytes_buffer(), + validity, + )) + } + + fn extend_filtered_with_state( + &mut self, + state: utils::State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + match state.translation { + StateTranslation::Plain(values, size) => decode_fsb_plain( + size, + values, + &mut decoded.0, + pred_true_mask, + &mut decoded.1, + state.is_optional, + filter, + state.page_validity.as_ref(), + ), + StateTranslation::Dictionary(values) => decode_fsb_dict( + self.size, + values, + state.dict.unwrap(), + state.dict_mask, + &mut decoded.0, + pred_true_mask, + &mut decoded.1, + state.is_optional, + filter, + state.page_validity.as_ref(), + ), + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs new file mode 100644 index 000000000000..6faeac4b55f2 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs @@ -0,0 +1,205 @@ +//! APIs to read from Parquet format. + +mod binview; +mod boolean; +mod categorical; +mod dictionary_encoded; +mod fixed_size_binary; +mod nested; +mod nested_utils; +mod null; +mod primitive; +mod simple; +mod utils; + +use arrow::array::{Array, FixedSizeListArray, ListArray, MapArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::offset::Offsets; +use polars_utils::mmap::MemReader; +use simple::page_iter_to_array; + +pub use self::nested_utils::{InitNested, NestedState, init_nested}; +pub use self::utils::filter::{Filter, PredicateFilter}; +use self::utils::freeze_validity; +use super::*; +use crate::parquet::error::ParquetResult; +use crate::parquet::read::get_page_iterator as _get_page_iterator; +use crate::parquet::schema::types::PrimitiveType; + +/// Creates a new iterator of compressed pages. +pub fn get_page_iterator( + column_metadata: &ColumnChunkMetadata, + reader: MemReader, + buffer: Vec, + max_header_size: usize, +) -> PolarsResult { + Ok(_get_page_iterator( + column_metadata, + reader, + buffer, + max_header_size, + )?) +} + +/// Creates a new [`ListArray`] or [`FixedSizeListArray`]. +pub fn create_list( + dtype: ArrowDataType, + nested: &mut NestedState, + values: Box, +) -> Box { + let (length, mut offsets, validity) = nested.pop().unwrap(); + let validity = validity.and_then(freeze_validity); + match dtype.to_logical_type() { + ArrowDataType::List(_) => { + offsets.push(values.len() as i64); + + let offsets = offsets.iter().map(|x| *x as i32).collect::>(); + + let offsets: Offsets = offsets + .try_into() + .expect("i64 offsets do not fit in i32 offsets"); + + Box::new(ListArray::::new( + dtype, + offsets.into(), + values, + validity, + )) + }, + ArrowDataType::LargeList(_) => { + offsets.push(values.len() as i64); + + Box::new(ListArray::::new( + dtype, + offsets.try_into().expect("List too large"), + values, + validity, + )) + }, + ArrowDataType::FixedSizeList(_, _) => { + Box::new(FixedSizeListArray::new(dtype, length, values, validity)) + }, + _ => unreachable!(), + } +} + +/// Creates a new [`MapArray`]. +pub fn create_map( + dtype: ArrowDataType, + nested: &mut NestedState, + values: Box, +) -> Box { + let (_, mut offsets, validity) = nested.pop().unwrap(); + match dtype.to_logical_type() { + ArrowDataType::Map(_, _) => { + offsets.push(values.len() as i64); + let offsets = offsets.iter().map(|x| *x as i32).collect::>(); + + let offsets: Offsets = offsets + .try_into() + .expect("i64 offsets do not fit in i32 offsets"); + + Box::new(MapArray::new( + dtype, + offsets.into(), + values, + validity.and_then(freeze_validity), + )) + }, + _ => unreachable!(), + } +} + +fn is_primitive(dtype: &ArrowDataType) -> bool { + matches!( + dtype.to_physical_type(), + arrow::datatypes::PhysicalType::Primitive(_) + | arrow::datatypes::PhysicalType::Null + | arrow::datatypes::PhysicalType::Boolean + | arrow::datatypes::PhysicalType::Utf8 + | arrow::datatypes::PhysicalType::LargeUtf8 + | arrow::datatypes::PhysicalType::Binary + | arrow::datatypes::PhysicalType::BinaryView + | arrow::datatypes::PhysicalType::Utf8View + | arrow::datatypes::PhysicalType::LargeBinary + | arrow::datatypes::PhysicalType::FixedSizeBinary + | arrow::datatypes::PhysicalType::Dictionary(_) + ) +} + +fn columns_to_iter_recursive( + mut columns: Vec, + mut types: Vec<&PrimitiveType>, + field: Field, + init: Vec, + filter: Option, +) -> ParquetResult<(NestedState, Box, Bitmap)> { + if init.is_empty() && is_primitive(&field.dtype) { + let (_, array, pred_true_mask) = page_iter_to_array( + columns.pop().unwrap(), + types.pop().unwrap(), + field, + filter, + None, + )?; + + return Ok((NestedState::default(), array, pred_true_mask)); + } + + nested::columns_to_iter_recursive(columns, types, field, init, filter) +} + +/// Returns the number of (parquet) columns that a [`ArrowDataType`] contains. +pub fn n_columns(dtype: &ArrowDataType) -> usize { + use arrow::datatypes::PhysicalType::*; + match dtype.to_physical_type() { + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => 1, + List | FixedSizeList | LargeList => { + let a = dtype.to_logical_type(); + if let ArrowDataType::List(inner) = a { + n_columns(&inner.dtype) + } else if let ArrowDataType::LargeList(inner) = a { + n_columns(&inner.dtype) + } else if let ArrowDataType::FixedSizeList(inner, _) = a { + n_columns(&inner.dtype) + } else { + unreachable!() + } + }, + Map => { + let a = dtype.to_logical_type(); + if let ArrowDataType::Map(inner, _) = a { + n_columns(&inner.dtype) + } else { + unreachable!() + } + }, + Struct => { + if let ArrowDataType::Struct(fields) = dtype.to_logical_type() { + fields.iter().map(|inner| n_columns(&inner.dtype)).sum() + } else { + unreachable!() + } + }, + _ => todo!(), + } +} + +/// An iterator adapter that maps multiple iterators of [`PagesIter`] into an iterator of [`Array`]s. +/// +/// For a non-nested datatypes such as [`ArrowDataType::Int32`], this function requires a single element in `columns` and `types`. +/// For nested types, `columns` must be composed by all parquet columns with associated types `types`. +/// +/// The arrays are guaranteed to be at most of size `chunk_size` and data type `field.dtype`. +pub fn column_iter_to_arrays( + columns: Vec, + types: Vec<&PrimitiveType>, + field: Field, + filter: Option, +) -> PolarsResult<(Box, Bitmap)> { + let (_, array, pred_true_mask) = + columns_to_iter_recursive(columns, types, field, vec![], filter)?; + Ok((array, pred_true_mask)) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs new file mode 100644 index 000000000000..e09b2d11e3a9 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -0,0 +1,182 @@ +use arrow::array::StructArray; +use arrow::datatypes::{DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, IntegerType}; +use polars_compute::cast::CastOptionsImpl; + +use self::categorical::CategoricalDecoder; +use self::nested::deserialize::utils::freeze_validity; +use self::nested_utils::NestedContent; +use self::utils::PageDecoder; +use super::*; +use crate::parquet::error::ParquetResult; + +pub fn columns_to_iter_recursive( + mut columns: Vec, + mut types: Vec<&PrimitiveType>, + field: Field, + mut init: Vec, + filter: Option, +) -> ParquetResult<(NestedState, Box, Bitmap)> { + if !field.dtype().is_nested() { + let pages = columns.pop().unwrap(); + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + let (nested, arr, pdm) = page_iter_to_array(pages, type_, field, filter, Some(init))?; + Ok((nested.unwrap(), arr, pdm)) + } else { + match field.dtype() { + ArrowDataType::List(inner) | ArrowDataType::LargeList(inner) => { + init.push(InitNested::List(field.is_nullable)); + let (mut nested, array, ptm) = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + filter, + )?; + let array = create_list(field.dtype().clone(), &mut nested, array); + Ok((nested, array, ptm)) + }, + ArrowDataType::FixedSizeList(inner, width) => { + init.push(InitNested::FixedSizeList(field.is_nullable, *width)); + let (mut nested, array, ptm) = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + filter, + )?; + let array = create_list(field.dtype().clone(), &mut nested, array); + Ok((nested, array, ptm)) + }, + ArrowDataType::Struct(fields) => { + // This definitely does not support Filter predicate yet. + assert!(!matches!(&filter, Some(Filter::Predicate(_)))); + + // @NOTE: + // We go back to front here, because we constantly split off the end of the array + // to grab the relevant columns and types. + // + // Is this inefficient? Yes. Is this how we are going to do it for now? Yes. + + let Some(last_field) = fields.last() else { + return Err(ParquetError::not_supported("Struct has zero fields")); + }; + + let field_to_nested_array = + |mut init: Vec, + columns: &mut Vec, + types: &mut Vec<&PrimitiveType>, + struct_field: &Field| { + init.push(InitNested::Struct(field.is_nullable)); + let n = n_columns(&struct_field.dtype); + let columns = columns.split_off(columns.len() - n); + let types = types.split_off(types.len() - n); + + columns_to_iter_recursive( + columns, + types, + struct_field.clone(), + init, + filter.clone(), + ) + }; + + let (mut nested, last_array, _) = + field_to_nested_array(init.clone(), &mut columns, &mut types, last_field)?; + debug_assert!(matches!(nested.last().unwrap(), NestedContent::Struct)); + let (length, _, struct_validity) = nested.pop().unwrap(); + + let mut field_arrays = Vec::>::with_capacity(fields.len()); + field_arrays.push(last_array); + + for field in fields.iter().rev().skip(1) { + let (mut _nested, array, _) = + field_to_nested_array(init.clone(), &mut columns, &mut types, field)?; + + #[cfg(debug_assertions)] + { + debug_assert!(matches!(_nested.last().unwrap(), NestedContent::Struct)); + debug_assert_eq!( + _nested.pop().unwrap().2.and_then(freeze_validity), + struct_validity.clone().and_then(freeze_validity), + ); + } + + field_arrays.push(array); + } + + field_arrays.reverse(); + let struct_validity = struct_validity.and_then(freeze_validity); + + Ok(( + nested, + StructArray::new( + ArrowDataType::Struct(fields.clone()), + length, + field_arrays, + struct_validity, + ) + .to_boxed(), + Bitmap::new(), + )) + }, + ArrowDataType::Map(inner, _) => { + init.push(InitNested::List(field.is_nullable)); + let (mut nested, array, ptm) = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + filter, + )?; + let array = create_map(field.dtype().clone(), &mut nested, array); + Ok((nested, array, ptm)) + }, + + ArrowDataType::Dictionary(key_type, value_type, _) => { + // @note: this should only hit in two cases: + // - polars enum's and categorical's + // - int -> string which can be turned into categoricals + assert!(matches!(value_type.as_ref(), ArrowDataType::Utf8View)); + + init.push(InitNested::Primitive(field.is_nullable)); + + if field.metadata.as_ref().is_none_or(|md| { + !md.contains_key(DTYPE_ENUM_VALUES) && !md.contains_key(DTYPE_CATEGORICAL) + }) { + let (nested, arr, ptm) = PageDecoder::new( + columns.pop().unwrap(), + ArrowDataType::Utf8View, + binview::BinViewDecoder::new_string(), + Some(init), + )? + .collect_nested(filter)?; + + let arr = polars_compute::cast::cast( + arr.as_ref(), + field.dtype(), + CastOptionsImpl::default(), + ) + .unwrap(); + + Ok((nested, arr, ptm)) + } else { + assert!(matches!(key_type, IntegerType::UInt32)); + + let (nested, arr, ptm) = PageDecoder::new( + columns.pop().unwrap(), + field.dtype().clone(), + CategoricalDecoder::new(), + Some(init), + )? + .collect_boxed(filter)?; + + Ok((nested.unwrap(), arr, ptm)) + } + }, + other => Err(ParquetError::not_supported(format!( + "Deserializing type {other:?} from parquet" + ))), + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs new file mode 100644 index 000000000000..1af6b2d23b22 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -0,0 +1,699 @@ +use arrow::bitmap::utils::BitmapIter; +use arrow::bitmap::{Bitmap, BitmapBuilder, MutableBitmap}; + +use super::utils::PageDecoder; +use super::{Filter, utils}; +use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::ParquetResult; +use crate::parquet::page::{DataPage, split_buffer}; +use crate::parquet::read::levels::get_bit_width; + +pub struct Nested { + validity: Option, + length: usize, + content: NestedContent, + + // We batch the collection of valids and invalids to amortize the costs. This only really works + // when valids and invalids are grouped or there is a disbalance in the amount of valids vs. + // invalids. This, however, is a very common situation. + num_valids: usize, + num_invalids: usize, +} + +#[derive(Debug)] +pub enum NestedContent { + Primitive, + List { offsets: Vec }, + FixedSizeList { width: usize }, + Struct, +} + +impl Nested { + fn primitive(is_nullable: bool) -> Self { + // @NOTE: We allocate with `0` capacity here since we will not be pushing to this bitmap. + // This is because primitive does not keep track of the validity here. It keeps track in + // the decoder. We do still want to put something so that we can check for nullability by + // looking at the option. + let validity = is_nullable.then(|| BitmapBuilder::with_capacity(0)); + + Self { + validity, + length: 0, + content: NestedContent::Primitive, + + num_valids: 0, + num_invalids: 0, + } + } + + fn list_with_capacity(is_nullable: bool, capacity: usize) -> Self { + let offsets = Vec::with_capacity(capacity); + let validity = is_nullable.then(|| BitmapBuilder::with_capacity(capacity)); + Self { + validity, + length: 0, + content: NestedContent::List { offsets }, + + num_valids: 0, + num_invalids: 0, + } + } + + fn fixedlist_with_capacity(is_nullable: bool, width: usize, capacity: usize) -> Self { + let validity = is_nullable.then(|| BitmapBuilder::with_capacity(capacity)); + Self { + validity, + length: 0, + content: NestedContent::FixedSizeList { width }, + + num_valids: 0, + num_invalids: 0, + } + } + + fn struct_with_capacity(is_nullable: bool, capacity: usize) -> Self { + let validity = is_nullable.then(|| BitmapBuilder::with_capacity(capacity)); + Self { + validity, + length: 0, + content: NestedContent::Struct, + + num_valids: 0, + num_invalids: 0, + } + } + + fn take(mut self) -> (usize, Vec, Option) { + if !matches!(self.content, NestedContent::Primitive) { + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(self.num_valids, true); + validity.extend_constant(self.num_invalids, false); + } + + debug_assert!( + self.validity + .as_ref() + .is_none_or(|v| v.len() == self.length) + ); + } + + self.num_valids = 0; + self.num_invalids = 0; + + match self.content { + NestedContent::Primitive => { + debug_assert!(self.validity.is_none_or(|validity| validity.is_empty())); + (self.length, Vec::new(), None) + }, + NestedContent::List { offsets } => (self.length, offsets, self.validity), + NestedContent::FixedSizeList { .. } => (self.length, Vec::new(), self.validity), + NestedContent::Struct => (self.length, Vec::new(), self.validity), + } + } + + fn is_nullable(&self) -> bool { + self.validity.is_some() + } + + fn is_repeated(&self) -> bool { + match self.content { + NestedContent::Primitive => false, + NestedContent::List { .. } => true, + NestedContent::FixedSizeList { .. } => true, + NestedContent::Struct => false, + } + } + + fn is_required(&self) -> bool { + match self.content { + NestedContent::Primitive => false, + NestedContent::List { .. } => false, + NestedContent::FixedSizeList { .. } => false, + NestedContent::Struct => true, + } + } + + /// number of rows + fn len(&self) -> usize { + self.length + } + + fn invalid_num_values(&self) -> usize { + match &self.content { + NestedContent::Primitive => 1, + NestedContent::List { .. } => 0, + NestedContent::FixedSizeList { width } => *width, + NestedContent::Struct => 1, + } + } + + fn push(&mut self, value: i64, is_valid: bool) { + let is_primitive = matches!(self.content, NestedContent::Primitive); + + if is_valid && self.num_invalids != 0 { + debug_assert!(!is_primitive); + + // @NOTE: Having invalid items might not necessarily mean that we have a validity mask. + // + // For instance, if we have a optional struct with a required list in it, that struct + // will have a validity mask and the list will not. In the arrow representation of this + // array, however, the list will still have invalid items where the struct is null. + // + // Array: + // [ + // { 'x': [1] }, + // None, + // { 'x': [1, 2] }, + // ] + // + // Arrow: + // struct = [ list[0] None list[2] ] + // list = { + // values = [ 1, 1, 2 ], + // offsets = [ 0, 1, 1, 3 ], + // } + // + // Parquet: + // [ 1, 1, 2 ] + definition + repetition levels + // + // As you can see we need to insert an invalid item into the list even though it does + // not have a validity mask. + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(self.num_valids, true); + validity.extend_constant(self.num_invalids, false); + } + + self.num_valids = 0; + self.num_invalids = 0; + } + + self.num_valids += usize::from(!is_primitive & is_valid); + self.num_invalids += usize::from(!is_primitive & !is_valid); + + self.length += 1; + if let NestedContent::List { offsets } = &mut self.content { + offsets.push(value); + } + } + + fn push_default(&mut self, length: i64) { + let is_primitive = matches!(self.content, NestedContent::Primitive); + self.num_invalids += usize::from(!is_primitive); + + self.length += 1; + if let NestedContent::List { offsets } = &mut self.content { + offsets.push(length); + } + } +} + +/// Utility structure to create a `Filter` and `Validity` mask for the leaf values. +/// +/// This batches the extending. +pub struct BatchedNestedDecoder<'a> { + pub(crate) num_waiting_valids: usize, + pub(crate) num_waiting_invalids: usize, + + filter: &'a mut MutableBitmap, + validity: &'a mut MutableBitmap, +} + +impl BatchedNestedDecoder<'_> { + fn push_valid(&mut self) -> ParquetResult<()> { + self.push_n_valids(1) + } + + fn push_invalid(&mut self) -> ParquetResult<()> { + self.push_n_invalids(1) + } + + fn push_n_valids(&mut self, n: usize) -> ParquetResult<()> { + if self.num_waiting_invalids == 0 { + self.num_waiting_valids += n; + return Ok(()); + } + + self.filter.extend_constant(self.num_waiting_valids, true); + self.validity.extend_constant(self.num_waiting_valids, true); + + self.filter.extend_constant(self.num_waiting_invalids, true); + self.validity + .extend_constant(self.num_waiting_invalids, false); + + self.num_waiting_valids = n; + self.num_waiting_invalids = 0; + + Ok(()) + } + + fn push_n_invalids(&mut self, n: usize) -> ParquetResult<()> { + self.num_waiting_invalids += n; + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + if self.num_waiting_valids > 0 { + self.filter.extend_constant(self.num_waiting_valids, true); + self.validity.extend_constant(self.num_waiting_valids, true); + self.num_waiting_valids = 0; + } + if self.num_waiting_invalids > 0 { + self.filter.extend_constant(self.num_waiting_invalids, true); + self.validity + .extend_constant(self.num_waiting_invalids, false); + self.num_waiting_invalids = 0; + } + + self.filter.extend_constant(n, false); + self.validity.extend_constant(n, true); + + Ok(()) + } + + fn finalize(self) -> ParquetResult<()> { + self.filter.extend_constant(self.num_waiting_valids, true); + self.validity.extend_constant(self.num_waiting_valids, true); + + self.filter.extend_constant(self.num_waiting_invalids, true); + self.validity + .extend_constant(self.num_waiting_invalids, false); + + Ok(()) + } +} + +/// The initial info of nested data types. +/// The `bool` indicates if the type is nullable. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InitNested { + /// Primitive data types + Primitive(bool), + /// List data types + List(bool), + /// Fixed-Size List data types + FixedSizeList(bool, usize), + /// Struct data types + Struct(bool), +} + +/// Initialize [`NestedState`] from `&[InitNested]`. +pub fn init_nested(init: &[InitNested], capacity: usize) -> NestedState { + use {InitNested as IN, Nested as N}; + + let container = init + .iter() + .map(|init| match init { + IN::Primitive(is_nullable) => N::primitive(*is_nullable), + IN::List(is_nullable) => N::list_with_capacity(*is_nullable, capacity), + IN::FixedSizeList(is_nullable, width) => { + N::fixedlist_with_capacity(*is_nullable, *width, capacity) + }, + IN::Struct(is_nullable) => N::struct_with_capacity(*is_nullable, capacity), + }) + .collect(); + + NestedState::new(container) +} + +/// The state of nested data types. +#[derive(Default)] +pub struct NestedState { + /// The nesteds composing `NestedState`. + nested: Vec, +} + +impl NestedState { + /// Creates a new [`NestedState`]. + fn new(nested: Vec) -> Self { + Self { nested } + } + + pub fn pop(&mut self) -> Option<(usize, Vec, Option)> { + Some(self.nested.pop()?.take()) + } + + pub fn last(&self) -> Option<&NestedContent> { + self.nested.last().map(|v| &v.content) + } + + /// The number of rows in this state + pub fn len(&self) -> usize { + // outermost is the number of rows + self.nested[0].len() + } + + /// Returns the definition and repetition levels for each nesting level + fn levels(&self) -> (Vec, Vec) { + let depth = self.nested.len(); + + let mut def_levels = Vec::with_capacity(depth + 1); + let mut rep_levels = Vec::with_capacity(depth + 1); + + def_levels.push(0); + rep_levels.push(0); + + for i in 0..depth { + let nest = &self.nested[i]; + + let def_delta = nest.is_nullable() as u16 + nest.is_repeated() as u16; + let rep_delta = nest.is_repeated() as u16; + + def_levels.push(def_levels[i] + def_delta); + rep_levels.push(rep_levels[i] + rep_delta); + } + + (def_levels, rep_levels) + } +} + +fn collect_level_values( + target: &mut Vec, + hybrid_rle: HybridRleDecoder<'_>, +) -> ParquetResult<()> { + target.reserve(hybrid_rle.len()); + + for chunk in hybrid_rle.into_chunk_iter() { + let chunk = chunk?; + + match chunk { + HybridRleChunk::Rle(value, size) => { + target.resize(target.len() + size, value as u16); + }, + HybridRleChunk::Bitpacked(decoder) => { + decoder.lower_element::()?.collect_into(target); + }, + } + } + + Ok(()) +} + +/// State to keep track of how many top-level values (i.e. rows) still need to be skipped and +/// collected. +/// +/// This state should be kept between pages because a top-level value / row value may span several +/// pages. +/// +/// - `num_skips = Some(n)` means that it will skip till the `n + 1`-th occurrence of the repetition +/// level of `0` (i.e. the start of a top-level value / row value). +/// - `num_collects = Some(n)` means that it will collect values till the `n + 1`-th occurrence of +/// the repetition level of `0` (i.e. the start of a top-level value / row value). +struct DecodingState { + num_skips: Option, + num_collects: Option, +} + +#[allow(clippy::too_many_arguments)] +fn decode_nested( + mut current_def_levels: &[u16], + mut current_rep_levels: &[u16], + + batched_collector: &mut BatchedNestedDecoder<'_>, + nested: &mut [Nested], + + state: &mut DecodingState, + top_level_filter: &mut BitmapIter<'_>, + + // Amortized allocations + def_levels: &[u16], + rep_levels: &[u16], +) -> ParquetResult<()> { + let max_depth = nested.len(); + let leaf_def_level = *def_levels.last().unwrap(); + + while !current_def_levels.is_empty() { + debug_assert_eq!(current_def_levels.len(), current_rep_levels.len()); + + // Handle skips + if let Some(ref mut num_skips) = state.num_skips { + let mut i = 0; + let mut num_skipped_values = 0; + while i < current_def_levels.len() && (*num_skips > 0 || current_rep_levels[i] != 0) { + let def = current_def_levels[i]; + let rep = current_rep_levels[i]; + + *num_skips -= usize::from(rep == 0); + i += 1; + + // @NOTE: + // We don't need to account for higher def-levels that imply extra values, since we + // don't have those higher levels either. + num_skipped_values += usize::from(def == leaf_def_level); + } + batched_collector.skip_in_place(num_skipped_values)?; + + current_def_levels = ¤t_def_levels[i..]; + current_rep_levels = ¤t_rep_levels[i..]; + + if current_def_levels.is_empty() { + break; + } else { + state.num_skips = None; + } + } + + // Handle collects + if let Some(ref mut num_collects) = state.num_collects { + let mut i = 0; + while i < current_def_levels.len() && (*num_collects > 0 || current_rep_levels[i] != 0) + { + let def = current_def_levels[i]; + let rep = current_rep_levels[i]; + + *num_collects -= usize::from(rep == 0); + i += 1; + + let mut is_required = false; + + for depth in 0..max_depth { + // Defines whether this element is defined at `depth` + // + // e.g. [ [ [ 1 ] ] ] is defined at [ ... ], [ [ ... ] ], [ [ [ ... ] ] ] and + // [ [ [ 1 ] ] ]. + let is_defined_at_this_depth = + rep <= rep_levels[depth] && def >= def_levels[depth]; + + let length = nested + .get(depth + 1) + .map(|x| x.len() as i64) + // the last depth is the leaf, which is always increased by 1 + .unwrap_or(1); + + let nest = &mut nested[depth]; + + let is_valid = !nest.is_nullable() || def > def_levels[depth]; + + if is_defined_at_this_depth && !is_valid { + let mut num_elements = 1; + + nest.push(length, is_valid); + + for embed_depth in depth..max_depth { + let embed_length = nested + .get(embed_depth + 1) + .map(|x| x.len() as i64) + // the last depth is the leaf, which is always increased by 1 + .unwrap_or(1); + + let embed_nest = &mut nested[embed_depth]; + + if embed_depth > depth { + for _ in 0..num_elements { + embed_nest.push_default(embed_length); + } + } + + let embed_num_values = embed_nest.invalid_num_values(); + num_elements *= embed_num_values; + + if embed_num_values == 0 { + break; + } + } + + batched_collector.push_n_invalids(num_elements)?; + + break; + } + + if is_required || is_defined_at_this_depth { + nest.push(length, is_valid); + + if depth == max_depth - 1 { + // the leaf / primitive + let is_valid = (def != def_levels[depth]) || !nest.is_nullable(); + + if is_valid { + batched_collector.push_valid()?; + } else { + batched_collector.push_invalid()?; + } + } + } + + is_required = (is_required || is_defined_at_this_depth) + && nest.is_required() + && !is_valid; + } + } + + current_def_levels = ¤t_def_levels[i..]; + current_rep_levels = ¤t_rep_levels[i..]; + + if current_def_levels.is_empty() { + break; + } else { + state.num_collects = None; + } + } + + if top_level_filter.num_remaining() == 0 { + break; + } + + state.num_skips = Some(top_level_filter.take_leading_zeros()).filter(|v| *v != 0); + state.num_collects = Some(top_level_filter.take_leading_ones()).filter(|v| *v != 0); + } + + Ok(()) +} + +/// Return the definition and repetition level iterators for this page. +fn level_iters(page: &DataPage) -> ParquetResult<(HybridRleDecoder, HybridRleDecoder)> { + let split = split_buffer(page)?; + let def = split.def; + let rep = split.rep; + + let max_def_level = page.descriptor.max_def_level; + let max_rep_level = page.descriptor.max_rep_level; + + let def_iter = HybridRleDecoder::new(def, get_bit_width(max_def_level), page.num_values()); + let rep_iter = HybridRleDecoder::new(rep, get_bit_width(max_rep_level), page.num_values()); + + Ok((def_iter, rep_iter)) +} + +impl PageDecoder { + pub fn collect_nested( + mut self, + filter: Option, + ) -> ParquetResult<(NestedState, D::Output, Bitmap)> { + let init = self.init_nested.as_mut().unwrap(); + + // @TODO: We should probably count the filter so that we don't overallocate + let mut target = self.decoder.with_capacity(self.iter.total_num_values()); + // @TODO: Self capacity + let mut nested_state = init_nested(init, 0); + + if let Some(dict) = self.dict.as_ref() { + self.decoder.apply_dictionary(&mut target, dict)?; + } + + // Amortize the allocations. + let (def_levels, rep_levels) = nested_state.levels(); + + let mut current_def_levels = Vec::::new(); + let mut current_rep_levels = Vec::::new(); + + let (mut decode_state, top_level_filter) = match filter { + None => ( + DecodingState { + num_skips: None, + num_collects: Some(usize::MAX), + }, + Bitmap::new(), + ), + Some(Filter::Range(range)) => ( + DecodingState { + num_skips: Some(range.start), + num_collects: Some(range.len()), + }, + Bitmap::new(), + ), + Some(Filter::Mask(mask)) => ( + DecodingState { + num_skips: None, + num_collects: None, + }, + mask, + ), + Some(Filter::Predicate(_)) => todo!(), + }; + + let mut top_level_filter = top_level_filter.iter(); + + loop { + let Some(page) = self.iter.next() else { + break; + }; + let page = page?; + let page = page.decompress(&mut self.iter)?; + + let (mut def_iter, mut rep_iter) = level_iters(&page)?; + + let num_levels = def_iter.len().min(rep_iter.len()); + def_iter.limit_to(num_levels); + rep_iter.limit_to(num_levels); + + current_def_levels.clear(); + current_rep_levels.clear(); + + collect_level_values(&mut current_def_levels, def_iter)?; + collect_level_values(&mut current_rep_levels, rep_iter)?; + + let mut leaf_filter = MutableBitmap::new(); + let mut leaf_validity = MutableBitmap::new(); + + // @TODO: move this to outside the loop. + let mut batched_collector = BatchedNestedDecoder { + num_waiting_valids: 0, + num_waiting_invalids: 0, + + filter: &mut leaf_filter, + validity: &mut leaf_validity, + }; + + decode_nested( + ¤t_def_levels, + ¤t_rep_levels, + &mut batched_collector, + &mut nested_state.nested, + &mut decode_state, + &mut top_level_filter, + &def_levels, + &rep_levels, + )?; + + batched_collector.finalize()?; + + let leaf_validity = leaf_validity.freeze(); + let leaf_filter = leaf_filter.freeze(); + + let state = utils::State::new_nested( + &self.decoder, + &page, + self.dict.as_ref(), + Some(leaf_validity), + )?; + state.decode( + &mut self.decoder, + &mut target, + &mut BitmapBuilder::new(), // This will not get used or filled + Some(Filter::Mask(leaf_filter)), + )?; + + self.iter.reuse_page_buffer(page); + } + + // we pop the primitive off here. + debug_assert!(matches!( + nested_state.nested.last().unwrap().content, + NestedContent::Primitive + )); + _ = nested_state.pop().unwrap(); + + let array = self.decoder.finalize(self.dtype, self.dict, target)?; + + Ok((nested_state, array, Bitmap::new())) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/null.rs b/crates/polars-parquet/src/arrow/read/deserialize/null.rs new file mode 100644 index 000000000000..b5ad312739fb --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/null.rs @@ -0,0 +1,116 @@ +//! This implements the [`Decoder`][utils::Decoder] trait for the `UNKNOWN` or `Null` nested type. +//! The implementation mostly stubs all the function and just keeps track of the length in the +//! `DecodedState`. + +use arrow::array::NullArray; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::ArrowDataType; + +use super::PredicateFilter; +use super::utils::filter::Filter; +use super::utils::{self}; +use crate::parquet::error::ParquetResult; +use crate::parquet::page::{DataPage, DictPage}; + +pub(crate) struct NullDecoder; +pub(crate) struct NullTranslation { + num_rows: usize, +} + +#[derive(Debug)] +pub(crate) struct NullArrayLength { + length: usize, +} + +impl utils::Decoded for NullArrayLength { + fn len(&self) -> usize { + self.length + } + fn extend_nulls(&mut self, n: usize) { + self.length += n; + } +} + +impl<'a> utils::StateTranslation<'a, NullDecoder> for NullTranslation { + type PlainDecoder = (); + + fn new( + _decoder: &NullDecoder, + page: &'a DataPage, + _dict: Option<&'a ::Dict>, + _page_validity: Option<&Bitmap>, + ) -> ParquetResult { + Ok(NullTranslation { + num_rows: page.num_values(), + }) + } + fn num_rows(&self) -> usize { + self.num_rows + } +} + +impl utils::Decoder for NullDecoder { + type Translation<'a> = NullTranslation; + type Dict = NullArray; + type DecodedState = NullArrayLength; + type Output = NullArray; + + /// Initializes a new state + fn with_capacity(&self, _: usize) -> Self::DecodedState { + NullArrayLength { length: 0 } + } + + fn deserialize_dict(&mut self, _: DictPage) -> ParquetResult { + Ok(NullArray::new_empty(ArrowDataType::Null)) + } + + fn has_predicate_specialization( + &self, + _state: &utils::State<'_, Self>, + _predicate: &PredicateFilter, + ) -> ParquetResult { + // @TODO: This can be enabled for the fast paths + Ok(false) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn arrow::array::Array, + _is_optional: bool, + ) -> ParquetResult<()> { + let additional = additional.as_any().downcast_ref::().unwrap(); + decoded.length += additional.len(); + + Ok(()) + } + + fn finalize( + &self, + dtype: ArrowDataType, + _dict: Option, + decoded: Self::DecodedState, + ) -> ParquetResult { + Ok(NullArray::new(dtype, decoded.length)) + } + + fn extend_filtered_with_state( + &mut self, + state: utils::State<'_, Self>, + decoded: &mut Self::DecodedState, + _pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + if matches!(filter, Some(Filter::Predicate(_))) { + todo!() + } + + let num_rows = match filter { + Some(f) => f.num_rows(0), + None => state.translation.num_rows, + }; + decoded.length += num_rows; + + Ok(()) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs new file mode 100644 index 000000000000..0858d58be408 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs @@ -0,0 +1,253 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::ArrowDataType; +use arrow::types::NativeType; + +use super::super::utils; +use super::{ClosureDecoderFunction, DecoderFunction, PrimitiveDecoder, UnitDecoderFunction}; +use crate::parquet::encoding::{Encoding, byte_stream_split, hybrid_rle}; +use crate::parquet::error::ParquetResult; +use crate::parquet::page::{DataPage, DictPage, split_buffer}; +use crate::parquet::types::{NativeType as ParquetNativeType, decode}; +use crate::read::deserialize::dictionary_encoded; +use crate::read::deserialize::utils::{ + dict_indices_decoder, freeze_validity, unspecialized_decode, +}; +use crate::read::{Filter, PredicateFilter}; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum StateTranslation<'a> { + Plain(&'a [u8]), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), + ByteStreamSplit(byte_stream_split::Decoder<'a>), +} + +impl<'a, P, T, D> utils::StateTranslation<'a, FloatDecoder> for StateTranslation<'a> +where + T: NativeType, + P: ParquetNativeType, + D: DecoderFunction, +{ + type PlainDecoder = &'a [u8]; + + fn new( + _decoder: &FloatDecoder, + page: &'a DataPage, + dict: Option<&'a as utils::Decoder>::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult { + match (page.encoding(), dict) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = + dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))?; + Ok(Self::Dictionary(values)) + }, + (Encoding::Plain, _) => { + let values = split_buffer(page)?.values; + Ok(Self::Plain(values)) + }, + (Encoding::ByteStreamSplit, _) => { + let values = split_buffer(page)?.values; + Ok(Self::ByteStreamSplit(byte_stream_split::Decoder::try_new( + values, + size_of::

(), + )?)) + }, + _ => Err(utils::not_implemented(page)), + } + } + fn num_rows(&self) -> usize { + match self { + Self::Plain(v) => v.len() / size_of::

(), + Self::Dictionary(i) => i.len(), + Self::ByteStreamSplit(i) => i.len(), + } + } +} + +#[derive(Debug)] +pub(crate) struct FloatDecoder(PrimitiveDecoder) +where + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction; + +impl FloatDecoder +where + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction, +{ + #[inline] + fn new(decoder: D) -> Self { + Self(PrimitiveDecoder::new(decoder)) + } +} + +impl FloatDecoder> +where + T: NativeType + ParquetNativeType, + UnitDecoderFunction: Default + DecoderFunction, +{ + pub(crate) fn unit() -> Self { + Self::new(UnitDecoderFunction::::default()) + } +} + +impl FloatDecoder> +where + P: ParquetNativeType, + T: NativeType, + F: Copy + Fn(P) -> T, +{ + pub(crate) fn closure(f: F) -> Self { + Self::new(ClosureDecoderFunction(f, std::marker::PhantomData)) + } +} + +impl utils::Decoded for (Vec, BitmapBuilder) { + fn len(&self) -> usize { + self.0.len() + } + fn extend_nulls(&mut self, n: usize) { + self.0.resize(self.0.len() + n, T::default()); + self.1.extend_constant(n, false); + } +} + +impl utils::Decoder for FloatDecoder +where + T: NativeType, + P: ParquetNativeType, + D: DecoderFunction, +{ + type Translation<'a> = StateTranslation<'a>; + type Dict = PrimitiveArray; + type DecodedState = (Vec, BitmapBuilder); + type Output = PrimitiveArray; + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + BitmapBuilder::with_capacity(capacity), + ) + } + + fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult { + let values = page.buffer.as_ref(); + + let mut target = Vec::with_capacity(page.num_values); + super::plain::decode( + values, + false, + None, + None, + &mut BitmapBuilder::new(), + &mut self.0.intermediate, + &mut target, + &mut BitmapBuilder::new(), + self.0.decoder, + )?; + Ok(PrimitiveArray::new( + T::PRIMITIVE.into(), + target.into(), + None, + )) + } + + fn has_predicate_specialization( + &self, + state: &utils::State<'_, Self>, + predicate: &PredicateFilter, + ) -> ParquetResult { + let mut has_predicate_specialization = false; + + has_predicate_specialization |= + matches!(state.translation, StateTranslation::Dictionary(_)); + has_predicate_specialization |= matches!(state.translation, StateTranslation::Plain(_)) + && predicate.predicate.to_equals_scalar().is_some(); + + // @TODO: This should be implemented + has_predicate_specialization &= state.page_validity.is_none(); + + Ok(has_predicate_specialization) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn arrow::array::Array, + is_optional: bool, + ) -> ParquetResult<()> { + let additional = additional + .as_any() + .downcast_ref::>() + .unwrap(); + decoded.0.extend(additional.values().iter().copied()); + match additional.validity() { + Some(v) => decoded.1.extend_from_bitmap(v), + None if is_optional => decoded.1.extend_constant(additional.len(), true), + None => {}, + } + + Ok(()) + } + + fn extend_filtered_with_state( + &mut self, + mut state: utils::State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + match state.translation { + StateTranslation::Plain(ref mut values) => super::plain::decode( + values, + state.is_optional, + state.page_validity.as_ref(), + filter, + &mut decoded.1, + &mut self.0.intermediate, + &mut decoded.0, + pred_true_mask, + self.0.decoder, + ), + StateTranslation::Dictionary(ref mut indexes) => dictionary_encoded::decode_dict( + indexes.clone(), + state.dict.unwrap().values().as_slice(), + state.dict_mask, + state.is_optional, + state.page_validity.as_ref(), + filter, + &mut decoded.1, + &mut decoded.0, + pred_true_mask, + ), + StateTranslation::ByteStreamSplit(mut decoder) => { + let num_rows = decoder.len(); + let mut iter = decoder.iter_converted(|v| self.0.decoder.decode(decode(v))); + + unspecialized_decode( + num_rows, + || Ok(iter.next().unwrap()), + filter, + state.page_validity, + state.is_optional, + &mut decoded.1, + &mut decoded.0, + ) + }, + } + } + + fn finalize( + &self, + dtype: ArrowDataType, + _dict: Option, + (values, validity): Self::DecodedState, + ) -> ParquetResult { + let validity = freeze_validity(validity); + Ok(PrimitiveArray::try_new(dtype, values.into(), validity).unwrap()) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs new file mode 100644 index 000000000000..fb58ea88d5e2 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs @@ -0,0 +1,305 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::ArrowDataType; +use arrow::types::NativeType; + +use super::super::utils; +use super::{ + AsDecoderFunction, ClosureDecoderFunction, DecoderFunction, IntoDecoderFunction, + PrimitiveDecoder, UnitDecoderFunction, +}; +use crate::parquet::encoding::{Encoding, byte_stream_split, delta_bitpacked, hybrid_rle}; +use crate::parquet::error::ParquetResult; +use crate::parquet::page::{DataPage, DictPage, split_buffer}; +use crate::parquet::types::{NativeType as ParquetNativeType, decode}; +use crate::read::deserialize::dictionary_encoded; +use crate::read::deserialize::utils::{ + dict_indices_decoder, freeze_validity, unspecialized_decode, +}; +use crate::read::{Filter, PredicateFilter}; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum StateTranslation<'a> { + Plain(&'a [u8]), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), + ByteStreamSplit(byte_stream_split::Decoder<'a>), + DeltaBinaryPacked(delta_bitpacked::Decoder<'a>), +} + +impl<'a, P, T, D> utils::StateTranslation<'a, IntDecoder> for StateTranslation<'a> +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive

, + D: DecoderFunction, +{ + type PlainDecoder = &'a [u8]; + + fn new( + _decoder: &IntDecoder, + page: &'a DataPage, + dict: Option<&'a as utils::Decoder>::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult { + match (page.encoding(), dict) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = + dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))?; + Ok(Self::Dictionary(values)) + }, + (Encoding::Plain, _) => { + let values = split_buffer(page)?.values; + Ok(Self::Plain(values)) + }, + (Encoding::ByteStreamSplit, _) => { + let values = split_buffer(page)?.values; + Ok(Self::ByteStreamSplit(byte_stream_split::Decoder::try_new( + values, + size_of::

(), + )?)) + }, + (Encoding::DeltaBinaryPacked, _) => { + let values = split_buffer(page)?.values; + Ok(Self::DeltaBinaryPacked( + delta_bitpacked::Decoder::try_new(values)?.0, + )) + }, + _ => Err(utils::not_implemented(page)), + } + } + fn num_rows(&self) -> usize { + match self { + Self::Plain(v) => v.len() / size_of::

(), + Self::Dictionary(i) => i.len(), + Self::ByteStreamSplit(i) => i.len(), + Self::DeltaBinaryPacked(i) => i.len(), + } + } +} + +/// Decoder of integer parquet type +#[derive(Debug)] +pub(crate) struct IntDecoder(PrimitiveDecoder) +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive

, + D: DecoderFunction; + +impl IntDecoder +where + P: ParquetNativeType, + T: NativeType, + i64: num_traits::AsPrimitive

, + D: DecoderFunction, +{ + #[inline] + fn new(decoder: D) -> Self { + Self(PrimitiveDecoder::new(decoder)) + } +} + +impl IntDecoder> +where + T: NativeType + ParquetNativeType, + i64: num_traits::AsPrimitive, + UnitDecoderFunction: Default + DecoderFunction, +{ + pub(crate) fn unit() -> Self { + Self::new(UnitDecoderFunction::::default()) + } +} + +impl IntDecoder> +where + P: ParquetNativeType, + T: NativeType, + i64: num_traits::AsPrimitive

, + AsDecoderFunction: Default + DecoderFunction, +{ + pub(crate) fn cast_as() -> Self { + Self::new(AsDecoderFunction::::default()) + } +} + +impl IntDecoder> +where + P: ParquetNativeType, + T: NativeType, + i64: num_traits::AsPrimitive

, + IntoDecoderFunction: Default + DecoderFunction, +{ + pub(crate) fn cast_into() -> Self { + Self::new(IntoDecoderFunction::::default()) + } +} + +impl IntDecoder> +where + P: ParquetNativeType, + T: NativeType, + i64: num_traits::AsPrimitive

, + F: Copy + Fn(P) -> T, +{ + pub(crate) fn closure(f: F) -> Self { + Self::new(ClosureDecoderFunction(f, std::marker::PhantomData)) + } +} + +impl utils::Decoder for IntDecoder +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive

, + D: DecoderFunction, +{ + type Translation<'a> = StateTranslation<'a>; + type Dict = PrimitiveArray; + type DecodedState = (Vec, BitmapBuilder); + type Output = PrimitiveArray; + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + BitmapBuilder::with_capacity(capacity), + ) + } + + fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult { + let values = page.buffer.as_ref(); + + let mut target = Vec::with_capacity(page.num_values); + super::plain::decode( + values, + false, + None, + None, + &mut BitmapBuilder::new(), + &mut self.0.intermediate, + &mut target, + &mut BitmapBuilder::new(), + self.0.decoder, + )?; + Ok(PrimitiveArray::new( + T::PRIMITIVE.into(), + target.into(), + None, + )) + } + + fn has_predicate_specialization( + &self, + state: &utils::State<'_, Self>, + predicate: &PredicateFilter, + ) -> ParquetResult { + let mut has_predicate_specialization = false; + + has_predicate_specialization |= + matches!(state.translation, StateTranslation::Dictionary(_)); + has_predicate_specialization |= matches!(state.translation, StateTranslation::Plain(_)) + && predicate.predicate.to_equals_scalar().is_some(); + + // @TODO: This should be implemented + has_predicate_specialization &= state.page_validity.is_none(); + + Ok(has_predicate_specialization) + } + + fn finalize( + &self, + dtype: ArrowDataType, + _dict: Option, + (values, validity): Self::DecodedState, + ) -> ParquetResult { + let validity = freeze_validity(validity); + Ok(PrimitiveArray::try_new(dtype, values.into(), validity).unwrap()) + } + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn arrow::array::Array, + is_optional: bool, + ) -> ParquetResult<()> { + let additional = additional + .as_any() + .downcast_ref::>() + .unwrap(); + decoded.0.extend(additional.values().iter().copied()); + match additional.validity() { + Some(v) => decoded.1.extend_from_bitmap(v), + None if is_optional => decoded.1.extend_constant(additional.len(), true), + None => {}, + } + + Ok(()) + } + + fn extend_filtered_with_state( + &mut self, + mut state: utils::State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + match state.translation { + StateTranslation::Plain(ref mut values) => super::plain::decode( + values, + state.is_optional, + state.page_validity.as_ref(), + filter, + &mut decoded.1, + &mut self.0.intermediate, + &mut decoded.0, + pred_true_mask, + self.0.decoder, + ), + StateTranslation::Dictionary(ref mut indexes) => dictionary_encoded::decode_dict( + indexes.clone(), + state.dict.unwrap().values().as_slice(), + state.dict_mask, + state.is_optional, + state.page_validity.as_ref(), + filter, + &mut decoded.1, + &mut decoded.0, + pred_true_mask, + ), + StateTranslation::ByteStreamSplit(mut decoder) => { + let num_rows = decoder.len(); + let mut iter = decoder.iter_converted(|v| self.0.decoder.decode(decode(v))); + + unspecialized_decode( + num_rows, + || Ok(iter.next().unwrap()), + filter, + state.page_validity, + state.is_optional, + &mut decoded.1, + &mut decoded.0, + ) + }, + StateTranslation::DeltaBinaryPacked(decoder) => { + let num_rows = decoder.len(); + let values = decoder.collect::>()?; + + let mut i = 0; + unspecialized_decode( + num_rows, + || { + use num_traits::AsPrimitive; + let value = values[i]; + i += 1; + Ok(self.0.decoder.decode(value.as_())) + }, + filter, + state.page_validity, + state.is_optional, + &mut decoded.1, + &mut decoded.0, + ) + }, + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs new file mode 100644 index 000000000000..88b8a55932a7 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs @@ -0,0 +1,130 @@ +use arrow::types::NativeType; + +use crate::parquet::types::NativeType as ParquetNativeType; + +mod float; +mod integer; +pub(crate) mod plain; + +pub(crate) use float::FloatDecoder; +pub(crate) use integer::IntDecoder; + +#[derive(Debug)] +pub(crate) struct PrimitiveDecoder +where + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction, +{ + pub(crate) decoder: D, + pub(crate) intermediate: Vec

, + _pd: std::marker::PhantomData<(P, T)>, +} + +impl PrimitiveDecoder +where + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction, +{ + #[inline] + pub(crate) fn new(decoder: D) -> Self { + Self { + decoder, + intermediate: Vec::new(), + _pd: std::marker::PhantomData, + } + } +} + +/// A function that defines how to decode from the +/// [`parquet::types::NativeType`][ParquetNativeType] to the [`arrow::types::NativeType`]. +/// +/// This should almost always be inlined. +pub(crate) trait DecoderFunction: Copy +where + T: NativeType, + P: ParquetNativeType, +{ + const NEED_TO_DECODE: bool; + const CAN_TRANSMUTE: bool = { + let has_same_size = size_of::

() == size_of::(); + let has_same_alignment = align_of::

() == align_of::(); + + has_same_size && has_same_alignment + }; + + fn decode(self, x: P) -> T; +} + +#[derive(Default, Clone, Copy)] +pub(crate) struct UnitDecoderFunction(std::marker::PhantomData); +impl DecoderFunction for UnitDecoderFunction { + const NEED_TO_DECODE: bool = false; + + #[inline(always)] + fn decode(self, x: T) -> T { + x + } +} + +#[derive(Default, Clone, Copy)] +pub(crate) struct AsDecoderFunction( + std::marker::PhantomData<(P, T)>, +); +macro_rules! as_decoder_impl { + ($($p:ty => $t:ty,)+) => { + $( + impl DecoderFunction<$p, $t> for AsDecoderFunction<$p, $t> { + const NEED_TO_DECODE: bool = Self::CAN_TRANSMUTE; + + #[inline(always)] + fn decode(self, x : $p) -> $t { + x as $t + } + } + )+ + }; +} + +as_decoder_impl![ + i32 => i8, + i32 => i16, + i32 => u8, + i32 => u16, + i32 => u32, + i64 => i32, + i64 => u32, + i64 => u64, +]; + +#[derive(Default, Clone, Copy)] +pub(crate) struct IntoDecoderFunction(std::marker::PhantomData<(P, T)>); +impl DecoderFunction for IntoDecoderFunction +where + P: ParquetNativeType + Into, + T: NativeType, +{ + const NEED_TO_DECODE: bool = true; + + #[inline(always)] + fn decode(self, x: P) -> T { + x.into() + } +} + +#[derive(Clone, Copy)] +pub(crate) struct ClosureDecoderFunction(F, std::marker::PhantomData<(P, T)>); +impl DecoderFunction for ClosureDecoderFunction +where + P: ParquetNativeType, + T: NativeType, + F: Copy + Fn(P) -> T, +{ + const NEED_TO_DECODE: bool = true; + + #[inline(always)] + fn decode(self, x: P) -> T { + (self.0)(x) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs new file mode 100644 index 000000000000..3a4e5fcb71b7 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/mod.rs @@ -0,0 +1,507 @@ +use arrow::array::{PrimitiveArray, Splitable}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::types::{AlignedBytes, NativeType, PrimitiveType}; + +use super::DecoderFunction; +use crate::parquet::error::ParquetResult; +use crate::parquet::types::NativeType as ParquetNativeType; +use crate::read::deserialize::dictionary_encoded::{append_validity, constrain_page_validity}; +use crate::read::deserialize::utils::array_chunks::ArrayChunks; +use crate::read::deserialize::utils::freeze_validity; +use crate::read::{Filter, ParquetError}; + +mod predicate; +mod required; + +#[allow(clippy::too_many_arguments)] +pub fn decode>( + values: &[u8], + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut BitmapBuilder, + intermediate: &mut Vec

, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, + dfn: D, +) -> ParquetResult<()> { + let can_filter_on_raw_data = + // Floats have different equality that just byte-wise comparison. + // @TODO: Maybe be smarter about this, because most predicates should not hit this problem. + !matches!(T::PRIMITIVE, PrimitiveType::Float16 | PrimitiveType::Float32 | PrimitiveType::Float64) && + D::CAN_TRANSMUTE && !D::NEED_TO_DECODE; + + match filter { + Some(Filter::Predicate(p)) + if !can_filter_on_raw_data || p.predicate.to_equals_scalar().is_none() => + { + let num_values = values.len() / size_of::(); + + // @TODO: Do something smarter with the validity + let mut unfiltered_target = Vec::with_capacity(num_values); + let mut unfiltered_validity = page_validity + .is_some() + .then(|| BitmapBuilder::with_capacity(num_values)) + .unwrap_or_default(); + + decode_no_incompact_predicates( + values, + is_optional, + page_validity, + None, + &mut unfiltered_validity, + intermediate, + &mut unfiltered_target, + &mut BitmapBuilder::new(), + dfn, + )?; + + let unfiltered_validity = freeze_validity(unfiltered_validity); + + let array = PrimitiveArray::new( + T::PRIMITIVE.into(), + unfiltered_target.into(), + unfiltered_validity, + ); + let intermediate_pred_true_mask = p.predicate.evaluate(&array); + + let array = + polars_compute::filter::filter_with_bitmap(&array, &intermediate_pred_true_mask); + let array = array.as_any().downcast_ref::>().unwrap(); + + pred_true_mask.extend_from_bitmap(&intermediate_pred_true_mask); + target.extend(array.values().iter().copied()); + if is_optional { + match array.validity() { + None => validity.extend_constant(array.len(), true), + Some(v) => validity.extend_from_bitmap(v), + } + } + }, + f => { + decode_no_incompact_predicates( + values, + is_optional, + page_validity, + f, + validity, + intermediate, + target, + pred_true_mask, + dfn, + )?; + }, + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn decode_no_incompact_predicates< + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction, +>( + values: &[u8], + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut BitmapBuilder, + intermediate: &mut Vec

, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, + dfn: D, +) -> ParquetResult<()> { + if cfg!(debug_assertions) && is_optional { + assert_eq!(target.len(), validity.len()); + } + + if D::CAN_TRANSMUTE { + let values = ArrayChunks::<'_, T::AlignedBytes>::new(values).ok_or_else(|| { + ParquetError::oos("Page content does not align with expected element size") + })?; + + let start_length = target.len(); + decode_aligned_bytes_dispatch( + values, + is_optional, + page_validity, + filter, + validity, + ::cast_vec_ref_mut(target), + pred_true_mask, + )?; + + if D::NEED_TO_DECODE { + let to_decode: &mut [P] = bytemuck::cast_slice_mut(&mut target[start_length..]); + + for v in to_decode { + *v = bytemuck::cast(dfn.decode(*v)); + } + } + } else { + let values = ArrayChunks::<'_, P::AlignedBytes>::new(values).ok_or_else(|| { + ParquetError::oos("Page content does not align with expected element size") + })?; + + intermediate.clear(); + decode_aligned_bytes_dispatch( + values, + is_optional, + page_validity, + filter, + validity, + ::cast_vec_ref_mut(intermediate), + pred_true_mask, + )?; + + target.extend(intermediate.iter().copied().map(|v| dfn.decode(v))); + } + + if cfg!(debug_assertions) && is_optional { + assert_eq!(target.len(), validity.len()); + } + + Ok(()) +} + +#[inline(never)] +pub fn decode_aligned_bytes_dispatch( + values: ArrayChunks<'_, B>, + is_optional: bool, + page_validity: Option<&Bitmap>, + filter: Option, + validity: &mut BitmapBuilder, + target: &mut Vec, + pred_true_mask: &mut BitmapBuilder, +) -> ParquetResult<()> { + if is_optional { + append_validity(page_validity, filter.as_ref(), validity, values.len()); + } + + let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); + + match (filter, page_validity) { + (None, None) => required::decode(values, target), + (None, Some(page_validity)) => decode_optional(values, page_validity, target), + + (Some(Filter::Range(rng)), None) => { + required::decode(values.slice(rng.start, rng.len()), target) + }, + (Some(Filter::Range(rng)), Some(mut page_validity)) => { + let mut values = values; + if rng.start > 0 { + let prevalidity; + (prevalidity, page_validity) = page_validity.split_at(rng.start); + page_validity.slice(0, rng.len()); + let values_start = prevalidity.set_bits(); + values = values.slice(values_start, values.len() - values_start); + } + + decode_optional(values, page_validity, target) + }, + + (Some(Filter::Mask(filter)), None) => decode_masked_required(values, filter, target), + (Some(Filter::Mask(filter)), Some(page_validity)) => { + decode_masked_optional(values, page_validity, filter, target) + }, + (Some(Filter::Predicate(p)), None) => { + if let Some(needle) = p.predicate.to_equals_scalar() { + let needle = needle.to_aligned_bytes::().unwrap(); + + let start_num_pred_true = pred_true_mask.set_bits(); + predicate::decode_equals_no_values(values, needle, pred_true_mask); + + if p.include_values { + let num_pred_true = pred_true_mask.set_bits() - start_num_pred_true; + target.resize(target.len() + num_pred_true, needle); + } + } else { + unreachable!() + } + + Ok(()) + }, + (Some(Filter::Predicate(_)), Some(_)) => todo!(), + }?; + + Ok(()) +} + +#[inline(never)] +fn decode_optional( + values: ArrayChunks<'_, B>, + mut validity: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + target.reserve(validity.len()); + + // Handle the leading and trailing zeros. This may allow dispatch to a faster kernel or + // possibly removes iterations from the lower kernel. + let num_leading_nulls = validity.take_leading_zeros(); + target.resize(target.len() + num_leading_nulls, B::zeroed()); + let num_trailing_nulls = validity.take_trailing_zeros(); + + // Dispatch to a faster kernel if possible. + let num_values = validity.set_bits(); + if num_values == validity.len() { + required::decode(values.truncate(validity.len()), target)?; + target.resize(target.len() + num_trailing_nulls, B::zeroed()); + return Ok(()); + } + + assert!(num_values <= values.len()); + + let start_length = target.len(); + let end_length = target.len() + validity.len(); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + let mut validity_iter = validity.fast_iter_u56(); + let mut num_values_remaining = num_values; + let mut value_offset = 0; + + let mut iter = |mut v: u64, len: usize| { + debug_assert!(len < 64); + + let num_chunk_values = v.count_ones() as usize; + + if num_values_remaining == num_chunk_values { + for i in 0..len { + let is_valid = v & 1 != 0; + let value = if is_valid { + unsafe { values.get_unchecked(value_offset) } + } else { + B::zeroed() + }; + unsafe { target_ptr.add(i).write(value) }; + + value_offset += (v & 1) as usize; + v >>= 1; + } + } else { + for i in 0..len { + let value = unsafe { values.get_unchecked(value_offset) }; + unsafe { target_ptr.add(i).write(value) }; + + value_offset += (v & 1) as usize; + v >>= 1; + } + } + + num_values_remaining -= num_chunk_values; + unsafe { + target_ptr = target_ptr.add(len); + } + }; + + let mut num_remaining = validity.len(); + for v in validity_iter.by_ref() { + if num_remaining < 56 { + iter(v, num_remaining); + } else { + iter(v, 56); + } + num_remaining -= 56; + } + + let (v, vl) = validity_iter.remainder(); + + iter(v, vl.min(num_remaining)); + + unsafe { target.set_len(end_length) }; + target.resize(target.len() + num_trailing_nulls, B::zeroed()); + + Ok(()) +} + +#[inline(never)] +fn decode_masked_required( + values: ArrayChunks<'_, B>, + mut mask: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + // Remove leading or trailing filtered values. This may allow dispatch to a faster kernel or + // may remove iterations from the slower kernel below. + let num_leading_filtered = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + let values = values.slice(num_leading_filtered, mask.len()); + + // Dispatch to a faster kernel if possible. + let num_rows = mask.set_bits(); + if num_rows == mask.len() { + return required::decode(values.truncate(num_rows), target); + } + + assert!(mask.len() <= values.len()); + + let start_length = target.len(); + target.reserve(num_rows); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + let mut mask_iter = mask.fast_iter_u56(); + let mut num_rows_left = num_rows; + let mut value_offset = 0; + + let mut iter = |mut f: u64, len: usize| { + if num_rows_left == 0 { + return false; + } + + let mut num_read = 0; + let mut num_written = 0; + + while f != 0 { + let offset = f.trailing_zeros() as usize; + + num_read += offset; + + // SAFETY: + // 1. `values_buffer` starts out as only zeros, which we know is in the + // dictionary following the original `dict.is_empty` check. + // 2. Each time we write to `values_buffer`, it is followed by a + // `verify_dict_indices`. + let value = unsafe { values.get_unchecked(value_offset + num_read) }; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += 1; + + f >>= offset + 1; // Clear least significant bit. + } + + unsafe { + target_ptr = target_ptr.add(num_written); + } + value_offset += len; + num_rows_left -= num_written; + + true + }; + + for f in mask_iter.by_ref() { + if !iter(f, 56) { + break; + } + } + let (f, fl) = mask_iter.remainder(); + iter(f, fl); + + unsafe { target.set_len(start_length + num_rows) }; + + Ok(()) +} + +#[inline(never)] +fn decode_masked_optional( + values: ArrayChunks<'_, B>, + mut validity: Bitmap, + mut mask: Bitmap, + target: &mut Vec, +) -> ParquetResult<()> { + assert_eq!(validity.len(), mask.len()); + + let num_leading_filtered = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + let leading_validity; + (leading_validity, validity) = validity.split_at(num_leading_filtered); + validity.slice(0, mask.len()); + + let num_rows = mask.set_bits(); + let num_values = validity.set_bits(); + + let values = values.slice(leading_validity.set_bits(), num_values); + + // Dispatch to a faster kernel if possible. + if num_rows == mask.len() { + return decode_optional(values, validity, target); + } + if num_values == validity.len() { + return decode_masked_required(values, mask, target); + } + + assert!(num_values <= values.len()); + + let start_length = target.len(); + target.reserve(num_rows); + let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + + let mut validity_iter = validity.fast_iter_u56(); + let mut mask_iter = mask.fast_iter_u56(); + let mut num_values_left = num_values; + let mut num_rows_left = num_rows; + let mut value_offset = 0; + + let mut iter = |mut f: u64, mut v: u64| { + if num_rows_left == 0 { + return false; + } + + let num_chunk_values = v.count_ones() as usize; + + let mut num_read = 0; + let mut num_written = 0; + + if num_chunk_values == num_values_left { + while f != 0 { + let offset = f.trailing_zeros() as usize; + + num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; + v >>= offset; + + let is_valid = v & 1 != 0; + let value = if is_valid { + unsafe { values.get_unchecked(value_offset + num_read) } + } else { + B::zeroed() + }; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += (v & 1) as usize; + + f >>= offset + 1; // Clear least significant bit. + v >>= 1; + } + } else { + while f != 0 { + let offset = f.trailing_zeros() as usize; + + num_read += (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; + v >>= offset; + + let value = unsafe { values.get_unchecked(value_offset + num_read) }; + unsafe { target_ptr.add(num_written).write(value) }; + + num_written += 1; + num_read += (v & 1) as usize; + + f >>= offset + 1; // Clear least significant bit. + v >>= 1; + } + } + + unsafe { + target_ptr = target_ptr.add(num_written); + } + value_offset += num_chunk_values; + num_rows_left -= num_written; + num_values_left -= num_chunk_values; + + true + }; + + for (f, v) in mask_iter.by_ref().zip(validity_iter.by_ref()) { + if !iter(f, v) { + break; + } + } + + let (f, fl) = mask_iter.remainder(); + let (v, vl) = validity_iter.remainder(); + assert_eq!(fl, vl); + iter(f, v); + + unsafe { target.set_len(start_length + num_rows) }; + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/predicate.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/predicate.rs new file mode 100644 index 000000000000..d586b7633cd6 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/predicate.rs @@ -0,0 +1,21 @@ +use arrow::bitmap::BitmapBuilder; +use arrow::types::AlignedBytes; + +use super::ArrayChunks; + +#[inline(never)] +pub fn decode_equals_no_values( + values: ArrayChunks<'_, B>, + needle: B, + pred_true_mask: &mut BitmapBuilder, +) { + pred_true_mask.reserve(values.len()); + for &v in values { + let is_pred_true = B::from_unaligned(v) == needle; + + // SAFETY: We reserved enough before the loop. + unsafe { + pred_true_mask.push_unchecked(is_pred_true); + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/required.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/required.rs new file mode 100644 index 000000000000..d3a6196a8e23 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/plain/required.rs @@ -0,0 +1,33 @@ +use arrow::types::AlignedBytes; + +use super::ArrayChunks; +use crate::parquet::error::ParquetResult; + +#[inline(never)] +pub fn decode( + values: ArrayChunks<'_, B>, + target: &mut Vec, +) -> ParquetResult<()> { + if values.is_empty() { + return Ok(()); + } + + target.reserve(values.len()); + + // SAFETY: Vec guarantees if the `capacity != 0` the pointer to valid since we just reserve + // that pointer. + let dst = unsafe { target.as_mut_ptr().add(target.len()) }; + let src = values.as_ptr(); + + // SAFETY: + // - `src` is valid for read of values.len() elements. + // - `dst` is valid for writes of values.len() elements, it was just reserved. + // - B::Unaligned is always aligned, since it has an alignment of 1 + // - The ranges for src and dst do not overlap + unsafe { + std::ptr::copy_nonoverlapping::(src.cast(), dst.cast(), values.len()); + target.set_len(target.len() + values.len()); + }; + + Ok(()) +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs new file mode 100644 index 000000000000..234ffd9eb14b --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs @@ -0,0 +1,562 @@ +use arrow::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ + ArrowDataType, DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field, IntegerType, IntervalUnit, TimeUnit, +}; +use arrow::types::{NativeType, days_ms, i256}; +use ethnum::I256; +use polars_compute::cast::CastOptionsImpl; + +use super::utils::filter::Filter; +use super::{ + BasicDecompressor, InitNested, NestedState, boolean, fixed_size_binary, null, primitive, +}; +use crate::parquet::error::ParquetResult; +use crate::parquet::schema::types::{ + PhysicalType, PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, +}; +use crate::parquet::types::int96_to_i64_ns; +use crate::read::ParquetError; +use crate::read::deserialize::binview; +use crate::read::deserialize::categorical::CategoricalDecoder; +use crate::read::deserialize::utils::PageDecoder; + +/// An iterator adapter that maps an iterator of Pages a boxed [`Array`] of [`ArrowDataType`] +/// `dtype` with a maximum of `num_rows` elements. +pub fn page_iter_to_array( + pages: BasicDecompressor, + type_: &PrimitiveType, + field: Field, + filter: Option, + init_nested: Option>, +) -> ParquetResult<(Option, Box, Bitmap)> { + use ArrowDataType::*; + + let physical_type = &type_.physical_type; + let logical_type = &type_.logical_type; + let dtype = field.dtype; + + Ok(match (physical_type, dtype.to_logical_type()) { + (_, Null) => { + PageDecoder::new(pages, dtype, null::NullDecoder, init_nested)?.collect_boxed(filter)? + }, + (PhysicalType::Boolean, Boolean) => { + PageDecoder::new(pages, dtype, boolean::BooleanDecoder, init_nested)? + .collect_boxed(filter)? + }, + (PhysicalType::Int32, UInt8) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int32, UInt16) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int32, UInt32) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int64, UInt32) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int32, Int8) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int32, Int16) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int32, Int32 | Date32 | Time32(_)) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::unit(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int64 | PhysicalType::Int96, Timestamp(time_unit, _)) => { + let time_unit = *time_unit; + return timestamp( + pages, + physical_type, + logical_type, + dtype, + filter, + time_unit, + init_nested, + ); + }, + (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => { + let size = FixedSizeBinaryArray::get_size(&dtype); + + PageDecoder::new( + pages, + dtype, + fixed_size_binary::BinaryDecoder { size }, + init_nested, + )? + .collect_boxed(filter)? + }, + (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::YearMonth)) => { + // @TODO: Make a separate decoder for this + + let n = 12; + let (nested, array, ptm) = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(n), + fixed_size_binary::BinaryDecoder { size: n }, + init_nested, + )? + .collect(filter)?; + + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| i32::from_le_bytes(value[..4].try_into().unwrap())) + .collect::>(); + let validity = array.validity().cloned(); + + ( + nested, + PrimitiveArray::::try_new(dtype.clone(), values.into(), validity)?.to_boxed(), + ptm, + ) + }, + (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::DayTime)) => { + // @TODO: Make a separate decoder for this + + let n = 12; + let (nested, array, ptm) = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(n), + fixed_size_binary::BinaryDecoder { size: n }, + init_nested, + )? + .collect(filter)?; + + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_days_ms) + .collect::>(); + let validity = array.validity().cloned(); + + ( + nested, + PrimitiveArray::::try_new(dtype.clone(), values.into(), validity)? + .to_boxed(), + ptm, + ) + }, + (PhysicalType::Int32, Decimal(_, _)) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_into(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int64, Decimal(_, _)) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_into(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::FixedLenByteArray(n), Decimal(_, _)) if *n > 16 => { + return Err(ParquetError::not_supported(format!( + "not implemented: can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}" + ))); + }, + (PhysicalType::FixedLenByteArray(n), Decimal(_, _)) => { + // @TODO: Make a separate decoder for this + + let n = *n; + + let (nested, array, ptm) = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(n), + fixed_size_binary::BinaryDecoder { size: n }, + init_nested, + )? + .collect(filter)?; + + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| super::super::convert_i128(value, n)) + .collect::>(); + let validity = array.validity().cloned(); + + ( + nested, + PrimitiveArray::::try_new(dtype.clone(), values.into(), validity)?.to_boxed(), + ptm, + ) + }, + (PhysicalType::Int32, Decimal256(_, _)) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::closure(|x: i32| i256(I256::new(x as i128))), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int64, Decimal256(_, _)) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::closure(|x: i64| i256(I256::new(x as i128))), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n <= 16 => { + // @TODO: Make a separate decoder for this + + let n = *n; + + let (nested, array, ptm) = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(n), + fixed_size_binary::BinaryDecoder { size: n }, + init_nested, + )? + .collect(filter)?; + + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| i256(I256::new(super::super::convert_i128(value, n)))) + .collect::>(); + let validity = array.validity().cloned(); + + ( + nested, + PrimitiveArray::::try_new(dtype.clone(), values.into(), validity)?.to_boxed(), + ptm, + ) + }, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n <= 32 => { + // @TODO: Make a separate decoder for this + + let n = *n; + + let (nested, array, ptm) = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(n), + fixed_size_binary::BinaryDecoder { size: n }, + init_nested, + )? + .collect(filter)?; + + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_i256) + .collect::>(); + let validity = array.validity().cloned(); + + ( + nested, + PrimitiveArray::::try_new(dtype.clone(), values.into(), validity)?.to_boxed(), + ptm, + ) + }, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n > 32 => { + return Err(ParquetError::not_supported(format!( + "Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}", + ))); + }, + (PhysicalType::Int32, Date64) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::closure(|x: i32| i64::from(x) * 86400000), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int64, Date64) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::unit(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Int64, Int64 | Time64(_) | Duration(_)) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::unit(), + init_nested, + )? + .collect_boxed(filter)?, + + (PhysicalType::Int64, UInt64) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::cast_as(), + init_nested, + )? + .collect_boxed(filter)?, + + // Float16 + (PhysicalType::FixedLenByteArray(2), Float32) => { + // @NOTE: To reduce code bloat, we just use the FixedSizeBinary decoder. + + let (nested, mut fsb_array, ptm) = PageDecoder::new( + pages, + ArrowDataType::FixedSizeBinary(2), + fixed_size_binary::BinaryDecoder { size: 2 }, + init_nested, + )? + .collect(filter)?; + + let validity = fsb_array.take_validity(); + let values = fsb_array.values().as_slice(); + assert_eq!(values.len() % 2, 0); + let values = values.chunks_exact(2); + let values = values + .map(|v| { + // SAFETY: We know that `v` is always of size two. + let le_bytes: [u8; 2] = unsafe { v.try_into().unwrap_unchecked() }; + let v = arrow::types::f16::from_le_bytes(le_bytes); + v.to_f32() + }) + .collect(); + + ( + nested, + PrimitiveArray::::new(dtype, values, validity).to_boxed(), + ptm, + ) + }, + + (PhysicalType::Float, Float32) => PageDecoder::new( + pages, + dtype, + primitive::FloatDecoder::::unit(), + init_nested, + )? + .collect_boxed(filter)?, + (PhysicalType::Double, Float64) => PageDecoder::new( + pages, + dtype, + primitive::FloatDecoder::::unit(), + init_nested, + )? + .collect_boxed(filter)?, + // Don't compile this code with `i32` as we don't use this in polars + (PhysicalType::ByteArray, LargeBinary | LargeUtf8) => { + let is_string = matches!(dtype, LargeUtf8); + PageDecoder::new( + pages, + dtype, + binview::BinViewDecoder { is_string }, + init_nested, + )? + .collect(filter)? + }, + (_, Binary | Utf8) => unreachable!(), + (PhysicalType::ByteArray, BinaryView | Utf8View) => { + let is_string = matches!(dtype, Utf8View); + PageDecoder::new( + pages, + dtype, + binview::BinViewDecoder { is_string }, + init_nested, + )? + .collect(filter)? + }, + (_, Dictionary(key_type, value_type, _)) => { + // @NOTE: This should only hit in two cases: + // - Polars enum's and categorical's + // - Int -> String which can be turned into categoricals + assert_eq!(value_type.as_ref(), &ArrowDataType::Utf8View); + + if field.metadata.is_some_and(|md| { + md.contains_key(DTYPE_ENUM_VALUES) || md.contains_key(DTYPE_CATEGORICAL) + }) && matches!(key_type, IntegerType::UInt32) + { + PageDecoder::new(pages, dtype, CategoricalDecoder::new(), init_nested)? + .collect_boxed(filter)? + } else { + let (nested, array, ptm) = PageDecoder::new( + pages, + ArrowDataType::Utf8View, + binview::BinViewDecoder::new_string(), + init_nested, + )? + .collect(filter)?; + + ( + nested, + polars_compute::cast::cast(array.as_ref(), &dtype, CastOptionsImpl::default()) + .unwrap(), + ptm, + ) + } + }, + (from, to) => { + return Err(ParquetError::not_supported(format!( + "reading parquet type {from:?} to {to:?} still not implemented", + ))); + }, + }) +} + +/// Unify the timestamp unit from parquet TimeUnit into arrow's TimeUnit +/// Returns (a int64 factor, is_multiplier) +fn unify_timestamp_unit( + logical_type: &Option, + time_unit: TimeUnit, +) -> (i64, bool) { + if let Some(PrimitiveLogicalType::Timestamp { unit, .. }) = logical_type { + match (*unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) + | (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => (1, true), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) + | (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => (1000, false), + + (ParquetTimeUnit::Microseconds, TimeUnit::Second) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => (1_000_000, false), + + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => (1_000_000_000, false), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) + | (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => (1_000, true), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => (1_000_000, true), + } + } else { + (1, true) + } +} + +#[inline] +pub fn int96_to_i64_us(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const MICROS_PER_SECOND: i64 = 1_000_000; + + let day = value[2] as i64; + let microseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * MICROS_PER_SECOND + microseconds +} + +#[inline] +pub fn int96_to_i64_ms(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const MILLIS_PER_SECOND: i64 = 1_000; + + let day = value[2] as i64; + let milliseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * MILLIS_PER_SECOND + milliseconds +} + +#[inline] +pub fn int96_to_i64_s(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + + let day = value[2] as i64; + let seconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000_000; + let day_seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + day_seconds + seconds +} + +fn timestamp( + pages: BasicDecompressor, + physical_type: &PhysicalType, + logical_type: &Option, + dtype: ArrowDataType, + filter: Option, + time_unit: TimeUnit, + nested: Option>, +) -> ParquetResult<(Option, Box, Bitmap)> { + if physical_type == &PhysicalType::Int96 { + return match time_unit { + TimeUnit::Nanosecond => PageDecoder::new( + pages, + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_ns(x)), + nested, + )? + .collect_boxed(filter), + TimeUnit::Microsecond => PageDecoder::new( + pages, + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_us(x)), + nested, + )? + .collect_boxed(filter), + TimeUnit::Millisecond => PageDecoder::new( + pages, + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_ms(x)), + nested, + )? + .collect_boxed(filter), + TimeUnit::Second => PageDecoder::new( + pages, + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_s(x)), + nested, + )? + .collect_boxed(filter), + }; + }; + + if physical_type != &PhysicalType::Int64 { + return Err(ParquetError::not_supported( + "can't decode a timestamp from a non-int64 parquet type", + )); + } + + let (factor, is_multiplier) = unify_timestamp_unit(logical_type, time_unit); + match (factor, is_multiplier) { + (1, _) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::::unit(), + nested, + )? + .collect_boxed(filter), + (a, true) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::closure(|x: i64| x * a), + nested, + )? + .collect_boxed(filter), + (a, false) => PageDecoder::new( + pages, + dtype, + primitive::IntDecoder::closure(|x: i64| x / a), + nested, + )? + .collect_boxed(filter), + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs new file mode 100644 index 000000000000..33932223a04c --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs @@ -0,0 +1,79 @@ +use arrow::types::AlignedBytes; + +/// A slice of chunks that fit an [`AlignedBytes`] type. +/// +/// This is essentially the equivalent of [`ChunksExact`][std::slice::ChunksExact], but with a size +/// and type known at compile-time. This makes the compiler able to reason much more about the +/// code. Especially, since the chunk-sizes for this type are almost always powers of 2 and +/// bitshifts or special instructions would be much better to use. +#[derive(Debug, Clone, Copy)] +pub(crate) struct ArrayChunks<'a, B: AlignedBytes> { + pub(crate) bytes: &'a [B::Unaligned], +} + +impl<'a, B: AlignedBytes> ArrayChunks<'a, B> { + /// Create a new [`ArrayChunks`] + /// + /// This returns null if the `bytes` slice's length is not a multiple of the size of `P::Bytes`. + pub(crate) fn new(bytes: &'a [u8]) -> Option { + if bytes.len() % B::SIZE != 0 { + return None; + } + + let bytes = bytemuck::cast_slice(bytes); + + Some(Self { bytes }) + } + + pub(crate) unsafe fn get_unchecked(&self, at: usize) -> B { + B::from_unaligned(*unsafe { self.bytes.get_unchecked(at) }) + } + + pub fn truncate(&self, length: usize) -> ArrayChunks<'a, B> { + let length = length.min(self.bytes.len()); + + Self { + bytes: unsafe { self.bytes.get_unchecked(..length) }, + } + } + + pub fn slice(&self, start: usize, length: usize) -> ArrayChunks<'a, B> { + assert!(start <= self.bytes.len()); + assert!(start + length <= self.bytes.len()); + unsafe { self.slice_unchecked(start, length) } + } + + pub unsafe fn slice_unchecked(&self, start: usize, length: usize) -> ArrayChunks<'a, B> { + debug_assert!(start <= self.bytes.len()); + debug_assert!(start + length <= self.bytes.len()); + Self { + bytes: unsafe { self.bytes.get_unchecked(start..start + length) }, + } + } + + pub fn as_ptr(&self) -> *const B::Unaligned { + self.bytes.as_ptr() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl<'a, B: AlignedBytes> Iterator for ArrayChunks<'a, B> { + type Item = &'a B::Unaligned; + + #[inline(always)] + fn next(&mut self) -> Option { + let item = self.bytes.first()?; + self.bytes = &self.bytes[1..]; + Some(item) + } + + #[inline(always)] + fn size_hint(&self) -> (usize, Option) { + (self.bytes.len(), Some(self.bytes.len())) + } +} + +impl ExactSizeIterator for ArrayChunks<'_, B> {} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs new file mode 100644 index 000000000000..cc72068c6200 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs @@ -0,0 +1,87 @@ +use std::ops::Range; + +use arrow::array::Splitable; +use arrow::bitmap::Bitmap; + +use crate::read::expr::ParquetColumnExprRef; + +#[derive(Clone)] +pub struct PredicateFilter { + pub predicate: ParquetColumnExprRef, + pub include_values: bool, +} + +#[derive(Clone)] +pub enum Filter { + Range(Range), + Mask(Bitmap), + Predicate(PredicateFilter), +} + +impl Filter { + pub fn new_limited(x: usize) -> Self { + Filter::Range(0..x) + } + + pub fn new_ranged(start: usize, end: usize) -> Self { + Filter::Range(start..end) + } + + pub fn new_masked(mask: Bitmap) -> Self { + Filter::Mask(mask) + } + + pub fn num_rows(&self, total_num_rows: usize) -> usize { + match self { + Self::Range(range) => range.len(), + Self::Mask(bitmap) => bitmap.set_bits(), + Self::Predicate { .. } => total_num_rows, + } + } + + pub fn max_offset(&self, total_num_rows: usize) -> usize { + match self { + Self::Range(range) => range.end, + Self::Mask(bitmap) => bitmap.len(), + Self::Predicate { .. } => total_num_rows, + } + } + + pub(crate) fn split_at(&self, at: usize) -> (Self, Self) { + match self { + Self::Range(range) => { + let start = range.start; + let end = range.end; + + if at <= start { + (Self::Range(0..0), Self::Range(start - at..end - at)) + } else if at > end { + (Self::Range(start..end), Self::Range(0..0)) + } else { + (Self::Range(start..at), Self::Range(0..end - at)) + } + }, + Self::Mask(bitmap) => { + let (lhs, rhs) = bitmap.split_at(at); + (Self::Mask(lhs), Self::Mask(rhs)) + }, + Self::Predicate(e) => (Self::Predicate(e.clone()), Self::Predicate(e.clone())), + } + } + + pub(crate) fn opt_split_at(filter: &Option, at: usize) -> (Option, Option) { + let Some(filter) = filter else { + return (None, None); + }; + + let (lhs, rhs) = filter.split_at(at); + (Some(lhs), Some(rhs)) + } + + pub(crate) fn opt_num_rows(filter: &Option, total_num_rows: usize) -> usize { + match filter { + Some(filter) => usize::min(filter.num_rows(total_num_rows), total_num_rows), + None => total_num_rows, + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs new file mode 100644 index 000000000000..3b361089ce75 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs @@ -0,0 +1,686 @@ +pub(crate) mod array_chunks; +pub(crate) mod filter; + +use std::ops::Range; + +use arrow::array::{Array, IntoBoxedArray, Splitable}; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::datatypes::ArrowDataType; +use arrow::pushable::Pushable; +use polars_compute::filter::filter_boolean_kernel; + +use self::filter::Filter; +use super::{BasicDecompressor, InitNested, NestedState, PredicateFilter}; +use crate::parquet::encoding::hybrid_rle::{self, HybridRleChunk, HybridRleDecoder}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::page::{DataPage, DictPage, split_buffer}; +use crate::parquet::schema::Repetition; + +#[derive(Debug)] +pub(crate) struct State<'a, D: Decoder> { + pub(crate) dict: Option<&'a D::Dict>, + pub(crate) dict_mask: Option<&'a Bitmap>, + pub(crate) is_optional: bool, + pub(crate) page_validity: Option, + pub(crate) translation: D::Translation<'a>, +} + +pub(crate) trait StateTranslation<'a, D: Decoder>: Sized { + type PlainDecoder; + + fn new( + decoder: &D, + page: &'a DataPage, + dict: Option<&'a D::Dict>, + page_validity: Option<&Bitmap>, + ) -> ParquetResult; + fn num_rows(&self) -> usize; +} + +impl<'a, D: Decoder> State<'a, D> { + pub fn new( + decoder: &D, + page: &'a DataPage, + dict: Option<&'a D::Dict>, + dict_mask: Option<&'a Bitmap>, + ) -> ParquetResult { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + let mut page_validity = None; + + // Make the page_validity None if there are no nulls in the page + if is_optional && page.null_count().is_none_or(|nc| nc != 0) { + let pv = page_validity_decoder(page)?; + page_validity = decode_page_validity(pv, None)?; + } + + let translation = D::Translation::new(decoder, page, dict, page_validity.as_ref())?; + + Ok(Self { + dict, + dict_mask, + is_optional, + page_validity, + translation, + }) + } + + pub fn new_nested( + decoder: &D, + page: &'a DataPage, + dict: Option<&'a D::Dict>, + mut page_validity: Option, + ) -> ParquetResult { + let translation = D::Translation::new(decoder, page, dict, None)?; + + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + if page_validity + .as_ref() + .is_some_and(|bm| bm.unset_bits() == 0) + { + page_validity = None; + } + + Ok(Self { + dict, + dict_mask: None, + translation, + is_optional, + page_validity, + }) + } + + pub fn decode( + self, + decoder: &mut D, + decoded: &mut D::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()> { + decoder.extend_filtered_with_state(self, decoded, pred_true_mask, filter) + } +} + +pub fn not_implemented(page: &DataPage) -> ParquetError { + let is_optional = page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let required = if is_optional { "optional" } else { "required" }; + ParquetError::not_supported(format!( + "Decoding {:?} \"{:?}\"-encoded {required} parquet pages not yet supported", + page.descriptor.primitive_type.physical_type, + page.encoding(), + )) +} + +pub(crate) type PageValidity<'a> = HybridRleDecoder<'a>; +pub(crate) fn page_validity_decoder(page: &DataPage) -> ParquetResult { + let validity = split_buffer(page)?.def; + let decoder = hybrid_rle::HybridRleDecoder::new(validity, 1, page.num_values()); + Ok(decoder) +} + +pub(crate) fn unspecialized_decode( + mut num_rows: usize, + + mut decode_one: impl FnMut() -> ParquetResult, + + mut filter: Option, + mut page_validity: Option, + + is_optional: bool, + + validity: &mut BitmapBuilder, + target: &mut impl Pushable, +) -> ParquetResult<()> { + match &mut filter { + None => {}, + Some(Filter::Range(range)) => { + match page_validity.as_mut() { + None => { + for _ in 0..range.start { + decode_one()?; + } + }, + Some(pv) => { + let c; + (c, *pv) = pv.split_at(range.start); + for _ in 0..c.set_bits() { + decode_one()?; + } + *pv = std::mem::take(pv).sliced(0, range.len()); + }, + } + + num_rows = range.len(); + filter = None; + }, + Some(Filter::Mask(mask)) => { + let leading_zeros = mask.take_leading_zeros(); + mask.take_trailing_zeros(); + + match page_validity.as_mut() { + None => { + for _ in 0..leading_zeros { + decode_one()?; + } + }, + Some(pv) => { + let c; + (c, *pv) = pv.split_at(leading_zeros); + for _ in 0..c.set_bits() { + decode_one()?; + } + *pv = std::mem::take(pv).sliced(0, mask.len()); + }, + } + + num_rows = mask.len(); + if mask.unset_bits() == 0 { + filter = None; + } + }, + Some(Filter::Predicate(_)) => todo!(), + }; + + page_validity = page_validity.filter(|pv| pv.unset_bits() > 0); + + match (filter, page_validity) { + (None, None) => { + target.reserve(num_rows); + for _ in 0..num_rows { + target.push(decode_one()?); + } + + if is_optional { + validity.extend_constant(num_rows, true); + } + }, + (None, Some(page_validity)) => { + target.reserve(page_validity.len()); + for is_valid in page_validity.iter() { + let v = if is_valid { + decode_one()? + } else { + T::default() + }; + target.push(v); + } + + validity.extend_from_bitmap(&page_validity); + }, + (Some(Filter::Range(_)), _) => unreachable!(), + (Some(Filter::Mask(mut mask)), None) => { + target.reserve(num_rows); + + while !mask.is_empty() { + let num_ones = mask.take_leading_ones(); + for _ in 0..num_ones { + target.push(decode_one()?); + } + + let num_zeros = mask.take_leading_zeros(); + for _ in 0..num_zeros { + decode_one()?; + } + } + + if is_optional { + validity.extend_constant(num_rows, true); + } + }, + (Some(Filter::Mask(mask)), Some(page_validity)) => { + assert_eq!(mask.len(), page_validity.len()); + + let num_rows = mask.set_bits(); + target.reserve(num_rows); + + let mut mask_iter = mask.fast_iter_u56(); + let mut validity_iter = page_validity.fast_iter_u56(); + + let mut iter = |mut f: u64, mut v: u64| { + while f != 0 { + let offset = f.trailing_zeros(); + + let skip = (v & (1u64 << offset).wrapping_sub(1)).count_ones() as usize; + for _ in 0..skip { + decode_one()?; + } + + if (v >> offset) & 1 != 0 { + target.push(decode_one()?); + } else { + target.push(T::default()); + } + + v >>= offset + 1; + f >>= offset + 1; + } + + for _ in 0..v.count_ones() as usize { + decode_one()?; + } + + ParquetResult::Ok(()) + }; + + for (f, v) in mask_iter.by_ref().zip(validity_iter.by_ref()) { + iter(f, v)?; + } + + let (f, fl) = mask_iter.remainder(); + let (v, vl) = validity_iter.remainder(); + + assert_eq!(fl, vl); + + iter(f, v)?; + + validity.extend_from_bitmap(&filter_boolean_kernel(&page_validity, &mask)); + }, + (Some(Filter::Predicate(_)), _) => todo!(), + } + + Ok(()) +} + +/// The state that will be decoded into. +/// +/// This is usually an Array and a validity mask as a MutableBitmap. +pub(super) trait Decoded { + /// The number of items in the container + fn len(&self) -> usize; + /// Extend the decoded state with `n` nulls. + fn extend_nulls(&mut self, n: usize); +} + +/// A decoder that knows how to map `State` -> Array +pub(super) trait Decoder: Sized { + /// The state that this decoder derives from a [`DataPage`]. This is bound to the page. + type Translation<'a>: StateTranslation<'a, Self>; + /// The dictionary representation that the decoder uses + type Dict: Array + Clone; + /// The target state that this Decoder decodes into. + type DecodedState: Decoded; + + type Output: IntoBoxedArray; + + fn evaluate_dict_predicate( + &self, + dict: &Self::Dict, + predicate: &PredicateFilter, + ) -> ParquetResult { + Ok(predicate.predicate.evaluate(dict)) + } + + /// Initializes a new [`Self::DecodedState`]. + fn with_capacity(&self, capacity: usize) -> Self::DecodedState; + + /// Deserializes a [`DictPage`] into [`Self::Dict`]. + fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult; + + fn has_predicate_specialization( + &self, + state: &State<'_, Self>, + predicate: &PredicateFilter, + ) -> ParquetResult; + + fn extend_decoded( + &self, + decoded: &mut Self::DecodedState, + additional: &dyn Array, + is_optional: bool, + ) -> ParquetResult<()>; + + fn unspecialized_predicate_decode( + &mut self, + state: State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + predicate: &PredicateFilter, + dict: Option, + dtype: &ArrowDataType, + ) -> ParquetResult<()> { + let is_optional = state.is_optional; + + let mut intermediate_array = self.with_capacity(state.translation.num_rows()); + if let Some(dict) = dict.as_ref() { + self.apply_dictionary(&mut intermediate_array, dict)?; + } + self.extend_filtered_with_state( + state, + &mut intermediate_array, + &mut BitmapBuilder::new(), + None, + )?; + let intermediate_array = self + .finalize(dtype.underlying_physical_type(), dict, intermediate_array)? + .into_boxed(); + + let mask = if let Some(validity) = intermediate_array.validity() { + let ignore_validity_array = intermediate_array.with_validity(None); + let mask = predicate.predicate.evaluate(ignore_validity_array.as_ref()); + + if predicate.predicate.evaluate_null() { + arrow::bitmap::or_not(&mask, validity) + } else { + &mask & validity + } + } else { + predicate.predicate.evaluate(intermediate_array.as_ref()) + }; + + let filtered = + polars_compute::filter::filter_with_bitmap(intermediate_array.as_ref(), &mask); + + pred_true_mask.extend_from_bitmap(&mask); + self.extend_decoded(decoded, filtered.as_ref(), is_optional)?; + + Ok(()) + } + + fn extend_filtered_with_state( + &mut self, + state: State<'_, Self>, + decoded: &mut Self::DecodedState, + pred_true_mask: &mut BitmapBuilder, + filter: Option, + ) -> ParquetResult<()>; + + fn apply_dictionary( + &mut self, + _decoded: &mut Self::DecodedState, + _dict: &Self::Dict, + ) -> ParquetResult<()> { + Ok(()) + } + + fn finalize( + &self, + dtype: ArrowDataType, + dict: Option, + decoded: Self::DecodedState, + ) -> ParquetResult; +} + +pub struct PageDecoder { + pub iter: BasicDecompressor, + pub dtype: ArrowDataType, + pub dict: Option, + pub decoder: D, + + pub init_nested: Option>, +} + +impl PageDecoder { + pub fn new( + mut iter: BasicDecompressor, + dtype: ArrowDataType, + mut decoder: D, + + init_nested: Option>, + ) -> ParquetResult { + let dict_page = iter.read_dict_page()?; + let dict = dict_page.map(|d| decoder.deserialize_dict(d)).transpose()?; + + Ok(Self { + iter, + dtype, + dict, + decoder, + + init_nested, + }) + } + + pub fn collect( + self, + filter: Option, + ) -> ParquetResult<(Option, D::Output, Bitmap)> { + if self.init_nested.is_some() { + self.collect_nested(filter) + .map(|(nested, arr, ptm)| (Some(nested), arr, ptm)) + } else { + self.collect_flat(filter).map(|(arr, ptm)| (None, arr, ptm)) + } + } + + pub fn collect_flat( + mut self, + mut filter: Option, + ) -> ParquetResult<(D::Output, Bitmap)> { + let mut num_rows_remaining = Filter::opt_num_rows(&filter, self.iter.total_num_values()); + + // @TODO: Don't allocate if include_values == false + let mut target = self.decoder.with_capacity(num_rows_remaining); + let mut pred_true_mask = BitmapBuilder::new(); + + let mut pred_tracks_nulls = true; + let mut dict_mask = None; + if let Some(dict) = self.dict.as_ref() { + self.decoder.apply_dictionary(&mut target, dict)?; + + if let Some(Filter::Predicate(p)) = &filter { + pred_tracks_nulls = p.predicate.evaluate_null(); + pred_true_mask.reserve(num_rows_remaining); + dict_mask = Some(self.decoder.evaluate_dict_predicate(dict, p)?); + } + } + + while num_rows_remaining > 0 { + let Some(page) = self.iter.next() else { + break; + }; + let page = page?; + + let page_num_values = page.num_values(); + + let state_filter; + (state_filter, filter) = Filter::opt_split_at(&filter, page_num_values); + + // Skip the whole page if we don't need any rows from it + if state_filter + .as_ref() + .is_some_and(|f| f.num_rows(page_num_values) == 0) + { + continue; + } + + // Skip a dictionary encoded page if none of the dictionary values match the predicate. + // This is essentially a slower version of statistics skipping. + if dict_mask.as_ref().is_some_and(|dm| dm.set_bits() == 0) + && page.page().header().is_dictionary_encoded() + && (page.page().descriptor.primitive_type.field_info.repetition + != Repetition::Optional + || !pred_tracks_nulls) + { + pred_true_mask.extend_constant(page.num_values(), false); + continue; + } + + let page = page.decompress(&mut self.iter)?; + + let state = State::new(&self.decoder, &page, self.dict.as_ref(), dict_mask.as_ref())?; + + let start_length = target.len(); + match &state_filter { + // Handle the case where column is held equal to Null. This can be the same for all + // non-nested columns. + Some(Filter::Predicate(p)) + if p.predicate + .to_equals_scalar() + .is_some_and(|sc| sc.is_null()) => + { + if state.is_optional { + match &state.page_validity { + None => pred_true_mask.extend_constant(page.num_values(), false), + Some(v) => { + pred_true_mask.extend_from_bitmap(v); + if p.include_values { + target.extend_nulls(v.set_bits()); + } + }, + } + } else { + pred_true_mask.extend_constant(page.num_values(), false); + } + drop(state); + }, + + // For now, we have a function that indicates whether the predicate can actually be + // handled in the kernels. If it cannot be handled in the kernels, catch it here + // and load it as if it weren't filtered. + Some(Filter::Predicate(p)) + if !self.decoder.has_predicate_specialization(&state, p)? => + { + self.decoder.unspecialized_predicate_decode( + state, + &mut target, + &mut pred_true_mask, + p, + self.dict.clone(), + &self.dtype, + )? + }, + _ => state.decode( + &mut self.decoder, + &mut target, + &mut pred_true_mask, + state_filter, + )?, + } + + let end_length = target.len(); + + num_rows_remaining -= end_length - start_length; + + self.iter.reuse_page_buffer(page); + } + + let array = self.decoder.finalize(self.dtype, self.dict, target)?; + Ok((array, pred_true_mask.freeze())) + } + + pub fn collect_boxed( + self, + filter: Option, + ) -> ParquetResult<(Option, Box, Bitmap)> { + use arrow::array::IntoBoxedArray; + let (nested, array, ptm) = self.collect(filter)?; + Ok((nested, array.into_boxed(), ptm)) + } +} + +#[inline] +pub(super) fn dict_indices_decoder( + page: &DataPage, + null_count: usize, +) -> ParquetResult { + let indices_buffer = split_buffer(page)?.values; + + // SPEC: Data page format: the bit width used to encode the entry ids stored as 1 byte (max bit width = 32), + // SPEC: followed by the values encoded using RLE/Bit packed described above (with the given bit width). + let bit_width = indices_buffer[0]; + let indices_buffer = &indices_buffer[1..]; + + Ok(hybrid_rle::HybridRleDecoder::new( + indices_buffer, + bit_width as u32, + page.num_values() - null_count, + )) +} + +/// Freeze a [`MutableBitmap`] into a `Option`. +/// +/// This will turn the several instances where `None` (representing "all valid") suffices. +pub fn freeze_validity(validity: BitmapBuilder) -> Option { + if validity.is_empty() || validity.unset_bits() == 0 { + return None; + } + + let validity = validity.freeze(); + Some(validity) +} + +pub(crate) fn filter_from_range(rng: Range) -> Bitmap { + let mut bm = BitmapBuilder::with_capacity(rng.end); + + bm.extend_constant(rng.start, false); + bm.extend_constant(rng.len(), true); + + bm.freeze() +} + +pub(crate) fn decode_hybrid_rle_into_bitmap( + mut page_validity: HybridRleDecoder<'_>, + limit: Option, + bitmap: &mut BitmapBuilder, +) -> ParquetResult<()> { + assert!(page_validity.num_bits() <= 1); + + let mut limit = limit.unwrap_or(page_validity.len()); + bitmap.reserve(limit); + + while let Some(chunk) = page_validity.next_chunk()? { + if limit == 0 { + break; + } + + match chunk { + HybridRleChunk::Rle(value, size) => { + let size = size.min(limit); + bitmap.extend_constant(size, value != 0); + limit -= size; + }, + HybridRleChunk::Bitpacked(decoder) => { + let len = decoder.len().min(limit); + bitmap.extend_from_slice(decoder.as_slice(), 0, len); + limit -= len; + }, + } + } + + Ok(()) +} + +pub(crate) fn decode_page_validity( + mut page_validity: HybridRleDecoder<'_>, + limit: Option, +) -> ParquetResult> { + assert!(page_validity.num_bits() <= 1); + + let mut num_ones = 0; + + let mut bm = BitmapBuilder::new(); + let limit = limit.unwrap_or(page_validity.len()); + page_validity.limit_to(limit); + let num_values = page_validity.len(); + + // If all values are valid anyway, we will return a None so don't allocate until we disprove + // that that is the case. + while let Some(chunk) = page_validity.next_chunk()? { + match chunk { + HybridRleChunk::Rle(value, size) if value != 0 => num_ones += size, + HybridRleChunk::Rle(value, size) => { + bm.reserve(num_values); + bm.extend_constant(num_ones, true); + bm.extend_constant(size, value != 0); + break; + }, + HybridRleChunk::Bitpacked(decoder) => { + let len = decoder.len(); + bm.reserve(num_values); + bm.extend_constant(num_ones, true); + bm.extend_from_slice(decoder.as_slice(), 0, len); + break; + }, + } + } + + if page_validity.len() == 0 && bm.is_empty() { + return Ok(None); + } + + decode_hybrid_rle_into_bitmap(page_validity, None, &mut bm)?; + Ok(Some(bm.freeze())) +} diff --git a/crates/polars-parquet/src/arrow/read/expr.rs b/crates/polars-parquet/src/arrow/read/expr.rs new file mode 100644 index 000000000000..469e9da44ade --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/expr.rs @@ -0,0 +1,122 @@ +use std::sync::Arc; + +use arrow::array::Array; +use arrow::bitmap::{Bitmap, BitmapBuilder}; +use arrow::types::AlignedBytes; + +#[derive(Clone)] +pub enum ParquetScalar { + Null, + + Boolean(bool), + + Int8(i8), + Int16(i16), + Int32(i32), + Int64(i64), + UInt8(u8), + UInt16(u16), + UInt32(u32), + UInt64(u64), + + Float32(f32), + Float64(f64), + + FixedSizeBinary(Box<[u8]>), + + String(Box), + Binary(Box<[u8]>), +} + +impl ParquetScalar { + pub(crate) fn is_null(&self) -> bool { + matches!(self, Self::Null) + } + + pub(crate) fn to_aligned_bytes(&self) -> Option { + match self { + Self::Int8(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::Int16(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::Int32(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::Int64(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::UInt8(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::UInt16(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::UInt32(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::UInt64(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::Float32(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + Self::Float64(v) => ::try_from(&v.to_le_bytes()) + .ok() + .map(B::from_unaligned), + _ => None, + } + } + + pub(crate) fn as_str(&self) -> Option<&str> { + match self { + Self::String(s) => Some(s.as_ref()), + _ => None, + } + } + + pub(crate) fn as_binary(&self) -> Option<&[u8]> { + match self { + Self::Binary(s) => Some(s.as_ref()), + _ => None, + } + } +} + +pub enum ParquetScalarRange { + Min(ParquetScalar), + Max(ParquetScalar), + Closed(ParquetScalar, ParquetScalar), +} + +pub type ParquetColumnExprRef = Arc; +pub trait ParquetColumnExpr: Send + Sync { + fn evaluate(&self, values: &dyn Array) -> Bitmap { + let mut bm = BitmapBuilder::new(); + self.evaluate_mut(values, &mut bm); + bm.freeze() + } + fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder); + + fn evaluate_null(&self) -> bool; + + fn to_equals_scalar(&self) -> Option; + fn to_range_scalar(&self) -> Option; +} + +impl ParquetScalarRange { + pub fn into_closed_with_min_max( + self, + min: ParquetScalar, + max: ParquetScalar, + ) -> (ParquetScalar, ParquetScalar) { + debug_assert_eq!(std::mem::discriminant(&min), std::mem::discriminant(&max)); + + match self { + Self::Min(min) => (min, max), + Self::Max(max) => (min, max), + Self::Closed(min, max) => (min, max), + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/mod.rs b/crates/polars-parquet/src/arrow/read/mod.rs new file mode 100644 index 000000000000..c8f6f575875c --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/mod.rs @@ -0,0 +1,100 @@ +//! APIs to read from Parquet format. +#![allow(clippy::type_complexity)] + +mod deserialize; +pub mod expr; +pub mod schema; +pub mod statistics; + +use std::io::{Read, Seek}; + +use arrow::types::{NativeType, i256}; +pub use deserialize::{ + Filter, InitNested, NestedState, PredicateFilter, column_iter_to_arrays, create_list, + create_map, get_page_iterator, init_nested, n_columns, +}; +#[cfg(feature = "async")] +use futures::{AsyncRead, AsyncSeek}; +use polars_error::PolarsResult; +pub use schema::{FileMetadata, infer_schema}; + +#[cfg(feature = "async")] +pub use crate::parquet::read::{get_page_stream, read_metadata_async as _read_metadata_async}; +// re-exports of crate::parquet's relevant APIs +pub use crate::parquet::{ + FallibleStreamingIterator, + error::ParquetError, + fallible_streaming_iterator, + metadata::{ColumnChunkMetadata, ColumnDescriptor, RowGroupMetadata}, + page::{CompressedDataPage, DataPageHeader, Page}, + read::{ + BasicDecompressor, MutStreamingIterator, PageReader, ReadColumnIterator, State, decompress, + get_column_iterator, read_metadata as _read_metadata, + }, + schema::types::{ + GroupLogicalType, ParquetType, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, + TimeUnit as ParquetTimeUnit, + }, + types::int96_to_i64_ns, +}; + +/// Returns all [`ColumnChunkMetadata`] associated to `field_name`. +/// For non-nested parquet types, this returns a single column +pub fn get_field_pages<'a, T>( + columns: &'a [ColumnChunkMetadata], + items: &'a [T], + field_name: &str, +) -> Vec<&'a T> { + columns + .iter() + .zip(items) + .filter(|(metadata, _)| metadata.descriptor().path_in_schema[0].as_str() == field_name) + .map(|(_, item)| item) + .collect() +} + +/// Reads parquets' metadata synchronously. +pub fn read_metadata(reader: &mut R) -> PolarsResult { + Ok(_read_metadata(reader)?) +} + +/// Reads parquets' metadata asynchronously. +#[cfg(feature = "async")] +pub async fn read_metadata_async( + reader: &mut R, +) -> PolarsResult { + Ok(_read_metadata_async(reader).await?) +} + +fn convert_year_month(value: &[u8]) -> i32 { + i32::from_le_bytes(value[..4].try_into().unwrap()) +} + +fn convert_days_ms(value: &[u8]) -> arrow::types::days_ms { + arrow::types::days_ms( + i32::from_le_bytes(value[4..8].try_into().unwrap()), + i32::from_le_bytes(value[8..12].try_into().unwrap()), + ) +} + +fn convert_i128(value: &[u8], n: usize) -> i128 { + // Copy the fixed-size byte value to the start of a 16 byte stack + // allocated buffer, then use an arithmetic right shift to fill in + // MSBs, which accounts for leading 1's in negative (two's complement) + // values. + let mut bytes = [0u8; 16]; + bytes[..n].copy_from_slice(value); + i128::from_be_bytes(bytes) >> (8 * (16 - n)) +} + +fn convert_i256(value: &[u8]) -> i256 { + if value[0] >= 128 { + let mut neg_bytes = [255u8; 32]; + neg_bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(neg_bytes) + } else { + let mut bytes = [0u8; 32]; + bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(bytes) + } +} diff --git a/crates/polars-parquet/src/arrow/read/schema/convert.rs b/crates/polars-parquet/src/arrow/read/schema/convert.rs new file mode 100644 index 000000000000..509507774e07 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/schema/convert.rs @@ -0,0 +1,1221 @@ +//! This module has entry points, [`parquet_to_arrow_schema`] and the more configurable [`parquet_to_arrow_schema_with_options`]. +use arrow::datatypes::{ArrowDataType, ArrowSchema, Field, IntervalUnit, TimeUnit}; +use polars_utils::pl_str::PlSmallStr; + +use crate::arrow::read::schema::SchemaInferenceOptions; +use crate::parquet::schema::Repetition; +use crate::parquet::schema::types::{ + FieldInfo, GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType, + PrimitiveConvertedType, PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, +}; + +/// Converts [`ParquetType`]s to a [`Field`], ignoring parquet fields that do not contain +/// any physical column. +pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> ArrowSchema { + parquet_to_arrow_schema_with_options(fields, &None) +} + +/// Like [`parquet_to_arrow_schema`] but with configurable options which affect the behavior of schema inference +pub fn parquet_to_arrow_schema_with_options( + fields: &[ParquetType], + options: &Option, +) -> ArrowSchema { + fields + .iter() + .filter_map(|f| to_field(f, options.as_ref().unwrap_or(&Default::default()))) + .map(|x| (x.name.clone(), x)) + .collect() +} + +fn from_int32( + logical_type: Option, + converted_type: Option, +) -> ArrowDataType { + use PrimitiveLogicalType::*; + match (logical_type, converted_type) { + // handle logical types first + (Some(Integer(t)), _) => match t { + IntegerType::Int8 => ArrowDataType::Int8, + IntegerType::Int16 => ArrowDataType::Int16, + IntegerType::Int32 => ArrowDataType::Int32, + IntegerType::UInt8 => ArrowDataType::UInt8, + IntegerType::UInt16 => ArrowDataType::UInt16, + IntegerType::UInt32 => ArrowDataType::UInt32, + // The above are the only possible annotations for parquet's int32. Anything else + // is a deviation to the parquet specification and we ignore + _ => ArrowDataType::Int32, + }, + (Some(Decimal(precision, scale)), _) => ArrowDataType::Decimal(precision, scale), + (Some(Date), _) => ArrowDataType::Date32, + (Some(Time { unit, .. }), _) => match unit { + ParquetTimeUnit::Milliseconds => ArrowDataType::Time32(TimeUnit::Millisecond), + // MILLIS is the only possible annotation for parquet's int32. Anything else + // is a deviation to the parquet specification and we ignore + _ => ArrowDataType::Int32, + }, + // handle converted types: + (_, Some(PrimitiveConvertedType::Uint8)) => ArrowDataType::UInt8, + (_, Some(PrimitiveConvertedType::Uint16)) => ArrowDataType::UInt16, + (_, Some(PrimitiveConvertedType::Uint32)) => ArrowDataType::UInt32, + (_, Some(PrimitiveConvertedType::Int8)) => ArrowDataType::Int8, + (_, Some(PrimitiveConvertedType::Int16)) => ArrowDataType::Int16, + (_, Some(PrimitiveConvertedType::Int32)) => ArrowDataType::Int32, + (_, Some(PrimitiveConvertedType::Date)) => ArrowDataType::Date32, + (_, Some(PrimitiveConvertedType::TimeMillis)) => { + ArrowDataType::Time32(TimeUnit::Millisecond) + }, + (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + ArrowDataType::Decimal(precision, scale) + }, + (_, _) => ArrowDataType::Int32, + } +} + +fn from_int64( + logical_type: Option, + converted_type: Option, +) -> ArrowDataType { + use PrimitiveLogicalType::*; + match (logical_type, converted_type) { + // handle logical types first + (Some(Integer(integer)), _) => match integer { + IntegerType::UInt64 => ArrowDataType::UInt64, + IntegerType::Int64 => ArrowDataType::Int64, + _ => ArrowDataType::Int64, + }, + ( + Some(Timestamp { + is_adjusted_to_utc, + unit, + }), + _, + ) => { + let timezone = if is_adjusted_to_utc { + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // A TIMESTAMP with isAdjustedToUTC=true is defined as [...] elapsed since the Unix epoch + Some(PlSmallStr::from_static("+00:00")) + } else { + // PARQUET: + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // A TIMESTAMP with isAdjustedToUTC=false represents [...] such + // timestamps should always be displayed the same way, regardless of the local time zone in effect + // ARROW: + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // If the time zone is null or equal to an empty string, the data is "time + // zone naive" and shall be displayed *as is* to the user, not localized + // to the locale of the user. + None + }; + + match unit { + ParquetTimeUnit::Milliseconds => { + ArrowDataType::Timestamp(TimeUnit::Millisecond, timezone) + }, + ParquetTimeUnit::Microseconds => { + ArrowDataType::Timestamp(TimeUnit::Microsecond, timezone) + }, + ParquetTimeUnit::Nanoseconds => { + ArrowDataType::Timestamp(TimeUnit::Nanosecond, timezone) + }, + } + }, + (Some(Time { unit, .. }), _) => match unit { + ParquetTimeUnit::Microseconds => ArrowDataType::Time64(TimeUnit::Microsecond), + ParquetTimeUnit::Nanoseconds => ArrowDataType::Time64(TimeUnit::Nanosecond), + // MILLIS is only possible for int32. Appearing in int64 is a deviation + // to parquet's spec, which we ignore + _ => ArrowDataType::Int64, + }, + (Some(Decimal(precision, scale)), _) => ArrowDataType::Decimal(precision, scale), + // handle converted types: + (_, Some(PrimitiveConvertedType::TimeMicros)) => { + ArrowDataType::Time64(TimeUnit::Microsecond) + }, + (_, Some(PrimitiveConvertedType::TimestampMillis)) => { + ArrowDataType::Timestamp(TimeUnit::Millisecond, None) + }, + (_, Some(PrimitiveConvertedType::TimestampMicros)) => { + ArrowDataType::Timestamp(TimeUnit::Microsecond, None) + }, + (_, Some(PrimitiveConvertedType::Int64)) => ArrowDataType::Int64, + (_, Some(PrimitiveConvertedType::Uint64)) => ArrowDataType::UInt64, + (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + ArrowDataType::Decimal(precision, scale) + }, + + (_, _) => ArrowDataType::Int64, + } +} + +fn from_byte_array( + logical_type: &Option, + converted_type: &Option, +) -> ArrowDataType { + match (logical_type, converted_type) { + (Some(PrimitiveLogicalType::String), _) => ArrowDataType::Utf8View, + (Some(PrimitiveLogicalType::Json), _) => ArrowDataType::BinaryView, + (Some(PrimitiveLogicalType::Bson), _) => ArrowDataType::BinaryView, + (Some(PrimitiveLogicalType::Enum), _) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Json)) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Bson)) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Enum)) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Utf8)) => ArrowDataType::Utf8View, + (_, _) => ArrowDataType::BinaryView, + } +} + +fn from_fixed_len_byte_array( + length: usize, + logical_type: Option, + converted_type: Option, +) -> ArrowDataType { + match (logical_type, converted_type) { + (Some(PrimitiveLogicalType::Decimal(precision, scale)), _) => { + ArrowDataType::Decimal(precision, scale) + }, + (None, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + ArrowDataType::Decimal(precision, scale) + }, + (None, Some(PrimitiveConvertedType::Interval)) => { + // There is currently no reliable way of determining which IntervalUnit + // to return. Thus without the original Arrow schema, the results + // would be incorrect if all 12 bytes of the interval are populated + ArrowDataType::Interval(IntervalUnit::DayTime) + }, + _ => ArrowDataType::FixedSizeBinary(length), + } +} + +/// Maps a [`PhysicalType`] with optional metadata to a [`ArrowDataType`] +fn to_primitive_type_inner( + primitive_type: &PrimitiveType, + options: &SchemaInferenceOptions, +) -> ArrowDataType { + match primitive_type.physical_type { + PhysicalType::Boolean => ArrowDataType::Boolean, + PhysicalType::Int32 => { + from_int32(primitive_type.logical_type, primitive_type.converted_type) + }, + PhysicalType::Int64 => { + from_int64(primitive_type.logical_type, primitive_type.converted_type) + }, + PhysicalType::Int96 => ArrowDataType::Timestamp(options.int96_coerce_to_timeunit, None), + PhysicalType::Float => ArrowDataType::Float32, + PhysicalType::Double => ArrowDataType::Float64, + PhysicalType::ByteArray => { + from_byte_array(&primitive_type.logical_type, &primitive_type.converted_type) + }, + PhysicalType::FixedLenByteArray(length) => from_fixed_len_byte_array( + length, + primitive_type.logical_type, + primitive_type.converted_type, + ), + } +} + +/// Entry point for converting parquet primitive type to arrow type. +/// +/// This function takes care of repetition. +fn to_primitive_type( + primitive_type: &PrimitiveType, + options: &SchemaInferenceOptions, +) -> ArrowDataType { + let base_type = to_primitive_type_inner(primitive_type, options); + + if primitive_type.field_info.repetition == Repetition::Repeated { + ArrowDataType::LargeList(Box::new(Field::new( + primitive_type.field_info.name.clone(), + base_type, + is_nullable(&primitive_type.field_info), + ))) + } else { + base_type + } +} + +fn non_repeated_group( + logical_type: &Option, + converted_type: &Option, + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + debug_assert!(!fields.is_empty()); + match (logical_type, converted_type) { + (Some(GroupLogicalType::List), _) => to_list(fields, parent_name, options), + (None, Some(GroupConvertedType::List)) => to_list(fields, parent_name, options), + (Some(GroupLogicalType::Map), _) => to_list(fields, parent_name, options), + (None, Some(GroupConvertedType::Map) | Some(GroupConvertedType::MapKeyValue)) => { + to_map(fields, options) + }, + _ => to_struct(fields, options), + } +} + +/// Converts a parquet group type to an arrow [`ArrowDataType::Struct`]. +/// Returns [`None`] if all its fields are empty +fn to_struct(fields: &[ParquetType], options: &SchemaInferenceOptions) -> Option { + let fields = fields + .iter() + .filter_map(|f| to_field(f, options)) + .collect::>(); + if fields.is_empty() { + None + } else { + Some(ArrowDataType::Struct(fields)) + } +} + +/// Converts a parquet group type to an arrow [`ArrowDataType::Struct`]. +/// Returns [`None`] if all its fields are empty +fn to_map(fields: &[ParquetType], options: &SchemaInferenceOptions) -> Option { + let inner = to_field(&fields[0], options)?; + Some(ArrowDataType::Map(Box::new(inner), false)) +} + +/// Entry point for converting parquet group type. +/// +/// This function takes care of logical type and repetition. +fn to_group_type( + field_info: &FieldInfo, + logical_type: &Option, + converted_type: &Option, + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + debug_assert!(!fields.is_empty()); + if field_info.repetition == Repetition::Repeated { + Some(ArrowDataType::LargeList(Box::new(Field::new( + field_info.name.clone(), + to_struct(fields, options)?, + is_nullable(field_info), + )))) + } else { + non_repeated_group(logical_type, converted_type, fields, parent_name, options) + } +} + +/// Checks whether this schema is nullable. +pub(crate) fn is_nullable(field_info: &FieldInfo) -> bool { + match field_info.repetition { + Repetition::Optional => true, + Repetition::Repeated => true, + Repetition::Required => false, + } +} + +/// Converts parquet schema to arrow field. +/// Returns `None` iff the parquet type has no associated primitive types, +/// i.e. if it is a column-less group type. +fn to_field(type_: &ParquetType, options: &SchemaInferenceOptions) -> Option { + Some(Field::new( + type_.get_field_info().name.clone(), + to_dtype(type_, options)?, + is_nullable(type_.get_field_info()), + )) +} + +/// Converts a parquet list to arrow list. +/// +/// To fully understand this algorithm, please refer to +/// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). +fn to_list( + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + let item = fields.first().unwrap(); + + let item_type = match item { + ParquetType::PrimitiveType(primitive) => Some(to_primitive_type_inner(primitive, options)), + ParquetType::GroupType { fields, .. } => { + if fields.len() == 1 && item.name() != "array" && { + // item.name() != format!("{parent_name}_tuple") + let cmp = [parent_name, "_tuple"]; + let len_1 = parent_name.len(); + let len = len_1 + "_tuple".len(); + + item.name().len() != len || [&item.name()[..len_1], &item.name()[len_1..]] != cmp + } { + // extract the repetition field + let nested_item = fields.first().unwrap(); + to_dtype(nested_item, options) + } else { + to_struct(fields, options) + } + }, + }?; + + // Check that the name of the list child is "list", in which case we + // get the child nullability and name (normally "element") from the nested + // group type. + // Without this step, the child incorrectly inherits the parent's optionality + let (list_item_name, item_is_optional) = match item { + ParquetType::GroupType { + field_info, fields, .. + } if field_info.name.as_str() == "list" && fields.len() == 1 => { + let field = fields.first().unwrap(); + ( + field.get_field_info().name.clone(), + field.get_field_info().repetition == Repetition::Optional, + ) + }, + _ => ( + item.get_field_info().name.clone(), + item.get_field_info().repetition == Repetition::Optional, + ), + }; + + Some(ArrowDataType::LargeList(Box::new(Field::new( + list_item_name, + item_type, + item_is_optional, + )))) +} + +/// Converts parquet schema to arrow data type. +/// +/// This function discards schema name. +/// +/// If this schema is a primitive type and not included in the leaves, the result is +/// Ok(None). +/// +/// If this schema is a group type and none of its children is reserved in the +/// conversion, the result is Ok(None). +pub(crate) fn to_dtype( + type_: &ParquetType, + options: &SchemaInferenceOptions, +) -> Option { + match type_ { + ParquetType::PrimitiveType(primitive) => Some(to_primitive_type(primitive, options)), + ParquetType::GroupType { + field_info, + logical_type, + converted_type, + fields, + } => { + if fields.is_empty() { + None + } else { + to_group_type( + field_info, + logical_type, + converted_type, + fields, + field_info.name.as_str(), + options, + ) + } + }, + } +} + +#[cfg(test)] +mod tests { + use polars_error::*; + + use super::*; + use crate::parquet::metadata::SchemaDescriptor; + + #[test] + fn test_flat_primitives() -> PolarsResult<()> { + let message = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 uint8 (INTEGER(8,false)); + REQUIRED INT32 uint16 (INTEGER(16,false)); + REQUIRED INT32 int32; + REQUIRED INT64 int64 ; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + OPTIONAL BINARY string_2 (STRING); + } + "; + let expected = &[ + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + Field::new("int16".into(), ArrowDataType::Int16, false), + Field::new("uint8".into(), ArrowDataType::UInt8, false), + Field::new("uint16".into(), ArrowDataType::UInt16, false), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("double".into(), ArrowDataType::Float64, true), + Field::new("float".into(), ArrowDataType::Float32, true), + Field::new("string".into(), ArrowDataType::Utf8View, true), + Field::new("string_2".into(), ArrowDataType::Utf8View, true), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_byte_array_fields() -> PolarsResult<()> { + let message = " + message test_schema { + REQUIRED BYTE_ARRAY binary; + REQUIRED FIXED_LEN_BYTE_ARRAY (20) fixed_binary; + } + "; + let expected = vec![ + Field::new("binary".into(), ArrowDataType::BinaryView, false), + Field::new( + "fixed_binary".into(), + ArrowDataType::FixedSizeBinary(20), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_duplicate_fields() -> PolarsResult<()> { + let message = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + } + "; + let expected = &[ + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(fields, expected); + Ok(()) + } + + #[ignore] + #[test] + fn test_parquet_lists() -> PolarsResult<()> { + let mut arrow_fields = Vec::new(); + + // LIST encoding example taken from parquet-format/LogicalTypes.md + let message_type = " + message test_schema { + REQUIRED GROUP my_list (LIST) { + REPEATED GROUP list { + OPTIONAL BINARY element (UTF8); + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + OPTIONAL GROUP array_of_arrays (LIST) { + REPEATED GROUP list { + REQUIRED GROUP element (LIST) { + REPEATED GROUP list { + REQUIRED INT32 element; + } + } + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP element { + REQUIRED BINARY str (UTF8); + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED INT32 element; + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP element { + REQUIRED BINARY str (UTF8); + REQUIRED INT32 num; + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP array { + REQUIRED BINARY str (UTF8); + } + + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP my_list_tuple { + REQUIRED BINARY str (UTF8); + } + } + REPEATED INT32 name; + } + "; + + // // List (list non-null, elements nullable) + // required group my_list (LIST) { + // repeated group list { + // optional binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Utf8, + true, + ))), + false, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Utf8, + false, + ))), + true, + )); + } + + // Element types can be nested structures. For example, a list of lists: + // + // // List> + // optional group array_of_arrays (LIST) { + // repeated group list { + // required group element (LIST) { + // repeated group list { + // required int32 element; + // } + // } + // } + // } + { + let arrow_inner_list = ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Int32, + false, + ))); + arrow_fields.push(Field::new( + "array_of_arrays".into(), + ArrowDataType::LargeList(Box::new(Field::new( + PlSmallStr::from_static("element"), + arrow_inner_list, + false, + ))), + true, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // }; + // } + { + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Utf8, + false, + ))), + true, + )); + } + + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + { + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Int32, + false, + ))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + { + let arrow_struct = ArrowDataType::Struct(vec![ + Field::new("str".into(), ArrowDataType::Utf8, false), + Field::new("num".into(), ArrowDataType::Int32, false), + ]); + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + arrow_struct, + false, + ))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // Special case: group is named array + { + let arrow_struct = + ArrowDataType::Struct(vec![Field::new("str".into(), ArrowDataType::Utf8, false)]); + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new("array".into(), arrow_struct, false))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // Special case: group named ends in _tuple + { + let arrow_struct = + ArrowDataType::Struct(vec![Field::new("str".into(), ArrowDataType::Utf8, false)]); + arrow_fields.push(Field::new( + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "my_list_tuple".into(), + arrow_struct, + false, + ))), + true, + )); + } + + // One-level encoding: Only allows required lists with required cells + // repeated value_type name + { + arrow_fields.push(Field::new( + "name".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "name".into(), + ArrowDataType::Int32, + false, + ))), + false, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_parquet_list_with_struct() -> PolarsResult<()> { + let mut arrow_fields = Vec::new(); + + let message_type = " + message eventlog { + REQUIRED group events (LIST) { + REPEATED group array { + REQUIRED BYTE_ARRAY event_name (STRING); + REQUIRED INT64 event_time (TIMESTAMP(MILLIS,true)); + } + } + } + "; + + { + let struct_fields = vec![ + Field::new("event_name".into(), ArrowDataType::Utf8View, false), + Field::new( + "event_time".into(), + ArrowDataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), + false, + ), + ]; + arrow_fields.push(Field::new( + "events".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "array".into(), + ArrowDataType::Struct(struct_fields), + false, + ))), + false, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_parquet_list_nullable() -> PolarsResult<()> { + let mut arrow_fields = Vec::new(); + + let message_type = " + message test_schema { + REQUIRED GROUP my_list1 (LIST) { + REPEATED GROUP list { + OPTIONAL BINARY element (UTF8); + } + } + OPTIONAL GROUP my_list2 (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + REQUIRED GROUP my_list3 (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + } + "; + + // // List (list non-null, elements nullable) + // required group my_list1 (LIST) { + // repeated group list { + // optional binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list1".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Utf8View, + true, + ))), + false, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list2 (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list2".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Utf8View, + false, + ))), + true, + )); + } + + // // List (list non-null, elements non-null) + // repeated group my_list3 (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list3".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Utf8View, + false, + ))), + false, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_nested_schema() -> PolarsResult<()> { + let mut arrow_fields = Vec::new(); + { + let group1_fields = vec![ + Field::new("leaf1".into(), ArrowDataType::Boolean, false), + Field::new("leaf2".into(), ArrowDataType::Int32, false), + ]; + let group1_struct = + Field::new("group1".into(), ArrowDataType::Struct(group1_fields), false); + arrow_fields.push(group1_struct); + + let leaf3_field = Field::new("leaf3".into(), ArrowDataType::Int64, false); + arrow_fields.push(leaf3_field); + } + + let message_type = " + message test_schema { + REQUIRED GROUP group1 { + REQUIRED BOOLEAN leaf1; + REQUIRED INT32 leaf2; + } + REQUIRED INT64 leaf3; + } + "; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[ignore] + #[test] + fn test_repeated_nested_schema() -> PolarsResult<()> { + let mut arrow_fields = Vec::new(); + { + arrow_fields.push(Field::new("leaf1".into(), ArrowDataType::Int32, true)); + + let inner_group_list = Field::new( + "innerGroup".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "innerGroup".into(), + ArrowDataType::Struct(vec![Field::new( + "leaf3".into(), + ArrowDataType::Int32, + true, + )]), + false, + ))), + false, + ); + + let outer_group_list = Field::new( + "outerGroup".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "outerGroup".into(), + ArrowDataType::Struct(vec![ + Field::new("leaf2".into(), ArrowDataType::Int32, true), + inner_group_list, + ]), + false, + ))), + false, + ); + arrow_fields.push(outer_group_list); + } + + let message_type = " + message test_schema { + OPTIONAL INT32 leaf1; + REPEATED GROUP outerGroup { + OPTIONAL INT32 leaf2; + REPEATED GROUP innerGroup { + OPTIONAL INT32 leaf3; + } + } + } + "; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[ignore] + #[test] + fn test_column_desc_to_field() -> PolarsResult<()> { + let message_type = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 uint8 (INTEGER(8,false)); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 uint16 (INTEGER(16,false)); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + REPEATED BOOLEAN bools; + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME_MILLIS); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 time_nano (TIME(NANOS,false)); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP_MICROS); + REQUIRED INT64 ts_nano (TIMESTAMP(NANOS,true)); + } + "; + let arrow_fields = vec![ + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + Field::new("uint8".into(), ArrowDataType::UInt8, false), + Field::new("int16".into(), ArrowDataType::Int16, false), + Field::new("uint16".into(), ArrowDataType::UInt16, false), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("double".into(), ArrowDataType::Float64, true), + Field::new("float".into(), ArrowDataType::Float32, true), + Field::new("string".into(), ArrowDataType::Utf8, true), + Field::new( + "bools".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "bools".into(), + ArrowDataType::Boolean, + false, + ))), + false, + ), + Field::new("date".into(), ArrowDataType::Date32, true), + Field::new( + "time_milli".into(), + ArrowDataType::Time32(TimeUnit::Millisecond), + true, + ), + Field::new( + "time_micro".into(), + ArrowDataType::Time64(TimeUnit::Microsecond), + true, + ), + Field::new( + "time_nano".into(), + ArrowDataType::Time64(TimeUnit::Nanosecond), + true, + ), + Field::new( + "ts_milli".into(), + ArrowDataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro".into(), + ArrowDataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "ts_nano".into(), + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_field_to_column_desc() -> PolarsResult<()> { + let message_type = " + message arrow_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INTEGER(16,true)); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (STRING); + OPTIONAL GROUP bools (LIST) { + REPEATED GROUP list { + OPTIONAL BOOLEAN element; + } + } + REQUIRED GROUP bools_non_null (LIST) { + REPEATED GROUP list { + REQUIRED BOOLEAN element; + } + } + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME(MILLIS,false)); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP(MICROS,false)); + REQUIRED GROUP struct { + REQUIRED BOOLEAN bools; + REQUIRED INT32 uint32 (INTEGER(32,false)); + REQUIRED GROUP int32 (LIST) { + REPEATED GROUP list { + OPTIONAL INT32 element; + } + } + } + REQUIRED BINARY dictionary_strings (STRING); + } + "; + + let arrow_fields = vec![ + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + Field::new("int16".into(), ArrowDataType::Int16, false), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("double".into(), ArrowDataType::Float64, true), + Field::new("float".into(), ArrowDataType::Float32, true), + Field::new("string".into(), ArrowDataType::Utf8View, true), + Field::new( + "bools".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Boolean, + true, + ))), + true, + ), + Field::new( + "bools_non_null".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Boolean, + false, + ))), + false, + ), + Field::new("date".into(), ArrowDataType::Date32, true), + Field::new( + "time_milli".into(), + ArrowDataType::Time32(TimeUnit::Millisecond), + true, + ), + Field::new( + "time_micro".into(), + ArrowDataType::Time64(TimeUnit::Microsecond), + true, + ), + Field::new( + "ts_milli".into(), + ArrowDataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro".into(), + ArrowDataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "struct".into(), + ArrowDataType::Struct(vec![ + Field::new("bools".into(), ArrowDataType::Boolean, false), + Field::new("uint32".into(), ArrowDataType::UInt32, false), + Field::new( + "int32".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + ArrowDataType::Int32, + true, + ))), + false, + ), + ]), + false, + ), + Field::new("dictionary_strings".into(), ArrowDataType::Utf8View, false), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_int96_options() -> PolarsResult<()> { + for tu in [ + TimeUnit::Second, + TimeUnit::Microsecond, + TimeUnit::Millisecond, + TimeUnit::Nanosecond, + ] { + let message_type = " + message arrow_schema { + REQUIRED INT96 int96_field; + OPTIONAL GROUP int96_list (LIST) { + REPEATED GROUP list { + OPTIONAL INT96 element; + } + } + REQUIRED GROUP int96_struct { + REQUIRED INT96 int96_field; + } + } + "; + let coerced_to = ArrowDataType::Timestamp(tu, None); + let arrow_fields = vec![ + Field::new("int96_field".into(), coerced_to.clone(), false), + Field::new( + "int96_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + coerced_to.clone(), + true, + ))), + true, + ), + Field::new( + "int96_struct".into(), + ArrowDataType::Struct(vec![Field::new( + "int96_field".into(), + coerced_to.clone(), + false, + )]), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema_with_options( + parquet_schema.fields(), + &Some(SchemaInferenceOptions { + int96_coerce_to_timeunit: tu, + }), + ); + let fields = fields.iter_values().cloned().collect::>(); + assert_eq!(arrow_fields, fields); + } + Ok(()) + } +} diff --git a/crates/polars-parquet/src/arrow/read/schema/metadata.rs b/crates/polars-parquet/src/arrow/read/schema/metadata.rs new file mode 100644 index 000000000000..6a986db5d3d7 --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/schema/metadata.rs @@ -0,0 +1,126 @@ +use arrow::datatypes::{ + ArrowDataType, ArrowSchema, DTYPE_CATEGORICAL, DTYPE_ENUM_VALUES, Field, IntegerType, Metadata, +}; +use arrow::io::ipc::read::deserialize_schema; +use base64::Engine as _; +use base64::engine::general_purpose; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::pl_str::PlSmallStr; + +use super::super::super::ARROW_SCHEMA_META_KEY; +pub use crate::parquet::metadata::KeyValue; + +/// Reads an arrow schema from Parquet's file metadata. Returns `None` if no schema was found. +/// # Errors +/// Errors iff the schema cannot be correctly parsed. +pub fn read_schema_from_metadata(metadata: &mut Metadata) -> PolarsResult> { + metadata + .remove(ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)) + .transpose() +} + +fn convert_field(field: &mut Field) { + // @NOTE: We cast non-Polars dictionaries to normal values because Polars does not have a + // generic dictionary type. + field.dtype = match std::mem::take(&mut field.dtype) { + ArrowDataType::Dictionary(key_type, value_type, sorted) => { + let is_pl_enum_or_categorical = field.metadata.as_ref().is_some_and(|md| { + md.contains_key(DTYPE_ENUM_VALUES) || md.contains_key(DTYPE_CATEGORICAL) + }) && matches!(key_type, IntegerType::UInt32) + && matches!(value_type.as_ref(), ArrowDataType::Utf8View); + let is_int_to_str = matches!( + value_type.as_ref(), + ArrowDataType::Utf8View | ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 + ); + + if is_pl_enum_or_categorical || is_int_to_str { + convert_dtype(ArrowDataType::Dictionary(key_type, value_type, sorted)) + } else { + convert_dtype(*value_type) + } + }, + dt => convert_dtype(dt), + }; +} + +fn convert_dtype(mut dtype: ArrowDataType) -> ArrowDataType { + use ArrowDataType::*; + match dtype { + List(mut field) => { + convert_field(field.as_mut()); + dtype = LargeList(field); + }, + LargeList(ref mut field) | FixedSizeList(ref mut field, _) => convert_field(field.as_mut()), + Struct(ref mut fields) => { + for field in fields { + convert_field(field); + } + }, + Float16 => dtype = Float32, + Binary | LargeBinary => dtype = BinaryView, + Utf8 | LargeUtf8 => dtype = Utf8View, + Dictionary(_, ref mut dtype, _) => { + let dtype = dtype.as_mut(); + *dtype = convert_dtype(std::mem::take(dtype)); + }, + Extension(ref mut ext) => { + ext.inner = convert_dtype(std::mem::take(&mut ext.inner)); + }, + Map(mut field, _ordered) => { + // Polars doesn't support Map. + // A map is physically a `List>` + // So we read as list. + convert_field(field.as_mut()); + dtype = LargeList(field); + }, + _ => {}, + } + + dtype +} + +/// Try to convert Arrow schema metadata into a schema +fn get_arrow_schema_from_metadata(encoded_meta: &str) -> PolarsResult { + let decoded = general_purpose::STANDARD.decode(encoded_meta); + match decoded { + Ok(bytes) => { + let slice = if bytes[0..4] == [255u8; 4] { + &bytes[8..] + } else { + bytes.as_slice() + }; + let mut schema = deserialize_schema(slice).map(|x| x.0)?; + // Convert the data types to the data types we support. + for field in schema.iter_values_mut() { + convert_field(field); + } + Ok(schema) + }, + Err(err) => { + // The C++ implementation returns an error if the schema can't be parsed. + polars_bail!(InvalidOperation: + "unable to decode the encoded schema stored in {ARROW_SCHEMA_META_KEY}, {err:?}" + ) + }, + } +} + +pub(super) fn parse_key_value_metadata(key_value_metadata: &Option>) -> Metadata { + key_value_metadata + .as_ref() + .map(|key_values| { + key_values + .iter() + .filter_map(|kv| { + kv.value.as_ref().map(|value| { + ( + PlSmallStr::from_str(kv.key.as_str()), + PlSmallStr::from_str(value.as_str()), + ) + }) + }) + .collect() + }) + .unwrap_or_default() +} diff --git a/crates/polars-parquet/src/arrow/read/schema/mod.rs b/crates/polars-parquet/src/arrow/read/schema/mod.rs new file mode 100644 index 000000000000..ea27aa03d46d --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/schema/mod.rs @@ -0,0 +1,59 @@ +//! APIs to handle Parquet <-> Arrow schemas. +use arrow::datatypes::{ArrowSchema, TimeUnit}; + +mod convert; +mod metadata; + +pub(crate) use convert::*; +pub use convert::{parquet_to_arrow_schema, parquet_to_arrow_schema_with_options}; +pub use metadata::read_schema_from_metadata; +use polars_error::PolarsResult; + +use self::metadata::parse_key_value_metadata; +pub use crate::parquet::metadata::{FileMetadata, KeyValue, SchemaDescriptor}; +pub use crate::parquet::schema::types::ParquetType; + +/// Options when inferring schemas from Parquet +pub struct SchemaInferenceOptions { + /// When inferring schemas from the Parquet INT96 timestamp type, this is the corresponding TimeUnit + /// in the inferred Arrow Timestamp type. + /// + /// This defaults to `TimeUnit::Nanosecond`, but INT96 timestamps outside of the range of years 1678-2262, + /// will overflow when parsed as `Timestamp(TimeUnit::Nanosecond)`. Setting this to a lower resolution + /// (e.g. TimeUnit::Milliseconds) will result in loss of precision, but support a larger range of dates + /// without overflowing when parsing the data. + pub int96_coerce_to_timeunit: TimeUnit, +} + +impl Default for SchemaInferenceOptions { + fn default() -> Self { + SchemaInferenceOptions { + int96_coerce_to_timeunit: TimeUnit::Nanosecond, + } + } +} + +/// Infers a [`ArrowSchema`] from parquet's [`FileMetadata`]. +/// +/// This first looks for the metadata key `"ARROW:schema"`; if it does not exist, it converts the +/// Parquet types declared in the file's Parquet schema to Arrow's equivalent. +/// +/// # Error +/// This function errors iff the key `"ARROW:schema"` exists but is not correctly encoded, +/// indicating that the file's arrow metadata was incorrectly written. +pub fn infer_schema(file_metadata: &FileMetadata) -> PolarsResult { + infer_schema_with_options(file_metadata, &None) +} + +/// Like [`infer_schema`] but with configurable options which affects the behavior of inference +pub fn infer_schema_with_options( + file_metadata: &FileMetadata, + options: &Option, +) -> PolarsResult { + let mut metadata = parse_key_value_metadata(file_metadata.key_value_metadata()); + + let schema = read_schema_from_metadata(&mut metadata)?; + Ok(schema.unwrap_or_else(|| { + parquet_to_arrow_schema_with_options(file_metadata.schema().fields(), options) + })) +} diff --git a/crates/polars-parquet/src/arrow/read/statistics.rs b/crates/polars-parquet/src/arrow/read/statistics.rs new file mode 100644 index 000000000000..0c579f7e121b --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/statistics.rs @@ -0,0 +1,595 @@ +//! APIs exposing `crate::parquet`'s statistics as arrow's statistics. +use arrow::array::{ + Array, BinaryViewArray, BooleanArray, FixedSizeBinaryArray, MutableBinaryViewArray, + MutableBooleanArray, MutableFixedSizeBinaryArray, MutablePrimitiveArray, NullArray, + PrimitiveArray, Utf8ViewArray, +}; +use arrow::datatypes::{ArrowDataType, Field, IntegerType, IntervalUnit, TimeUnit}; +use arrow::types::{NativeType, days_ms, f16, i256}; +use ethnum::I256; +use polars_utils::IdxSize; +use polars_utils::pl_str::PlSmallStr; + +use super::{ParquetTimeUnit, RowGroupMetadata}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::PhysicalType as ParquetPhysicalType; +use crate::parquet::statistics::Statistics as ParquetStatistics; +use crate::read::{ + ColumnChunkMetadata, PrimitiveLogicalType, convert_days_ms, convert_i128, convert_i256, + convert_year_month, int96_to_i64_ns, +}; + +/// Parquet statistics for a nesting level +#[derive(Debug, PartialEq)] +pub enum Statistics { + Column(Box), + + List(Option>), + FixedSizeList(Option>, usize), + + Struct(Box<[Option]>), + Dictionary(IntegerType, Option>, bool), +} + +/// Arrow-deserialized parquet statistics of a leaf-column +#[derive(Debug, PartialEq)] +pub struct ColumnStatistics { + field: Field, + + logical_type: Option, + physical_type: ParquetPhysicalType, + + /// Statistics of the leaf array of the column + statistics: ParquetStatistics, +} + +#[derive(Debug, PartialEq)] +pub enum ColumnPathSegment { + List { is_large: bool }, + FixedSizeList { width: usize }, + Dictionary { key: IntegerType, is_sorted: bool }, + Struct { column_idx: usize }, +} + +/// Arrow-deserialized parquet statistics of a leaf-column +#[derive(Debug, PartialEq)] +pub struct ArrowColumnStatistics { + pub null_count: Option, + pub distinct_count: Option, + + // While these two are Box, they will only ever contain one valid value. This might + // seems dumb, and don't get me wrong it is, but arrow::Scalar is basically useless. + pub min_value: Option>, + pub max_value: Option>, +} + +/// Arrow-deserialized parquet statistics of a leaf-column +pub struct ArrowColumnStatisticsArrays { + pub null_count: PrimitiveArray, + pub distinct_count: PrimitiveArray, + pub min_value: Box, + pub max_value: Box, +} + +fn timestamp(logical_type: Option<&PrimitiveLogicalType>, time_unit: TimeUnit, x: i64) -> i64 { + let unit = if let Some(PrimitiveLogicalType::Timestamp { unit, .. }) = logical_type { + unit + } else { + return x; + }; + + match (unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) => x / 1_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Second) => x / 1_000_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => x * 1_000_000_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) => x, + (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) => x / 1_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => x / 1_000_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) => x * 1_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) => x, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => x / 1_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => x * 1_000_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => x * 1_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => x, + } +} + +impl ColumnStatistics { + pub fn into_arrow(self) -> ParquetResult { + use ParquetStatistics as S; + let (null_count, distinct_count) = match &self.statistics { + S::Binary(s) => (s.null_count, s.distinct_count), + S::Boolean(s) => (s.null_count, s.distinct_count), + S::FixedLen(s) => (s.null_count, s.distinct_count), + S::Int32(s) => (s.null_count, s.distinct_count), + S::Int64(s) => (s.null_count, s.distinct_count), + S::Int96(s) => (s.null_count, s.distinct_count), + S::Float(s) => (s.null_count, s.distinct_count), + S::Double(s) => (s.null_count, s.distinct_count), + }; + + let null_count = null_count.map(|v| v as u64); + let distinct_count = distinct_count.map(|v| v as u64); + + macro_rules! rmap { + ($expect:ident, $map:expr) => {{ + let s = self.statistics.$expect(); + + let min = s.min_value; + let max = s.max_value; + + let min = ($map)(min)?.map(|x| Box::new(x) as Box); + let max = ($map)(max)?.map(|x| Box::new(x) as Box); + + (min, max) + }}; + ($expect:ident, @prim $from:ty $(as $to:ty)? $(, $map:expr)?) => {{ + rmap!( + $expect, + |x: Option<$from>| { + $( + let x = x.map(|x| x as $to); + )? + $( + let x = x.map($map); + )? + ParquetResult::Ok(x.map(|x| PrimitiveArray::$(<$to>::)?new( + self.field.dtype().clone(), + vec![x].into(), + None, + ))) + } + ) + }}; + (@binary $(, $map:expr)?) => {{ + rmap!( + expect_binary, + |x: Option>| { + $( + let x = x.map($map); + )? + ParquetResult::Ok(x.map(|x| BinaryViewArray::from_slice([Some(x)]))) + } + ) + }}; + (@string) => {{ + rmap!( + expect_binary, + |x: Option>| { + let x = x.map(String::from_utf8).transpose().map_err(|_| { + ParquetError::oos("Invalid UTF8 in Statistics") + })?; + ParquetResult::Ok(x.map(|x| Utf8ViewArray::from_slice([Some(x)]))) + } + ) + }}; + } + + use {ArrowDataType as D, ParquetPhysicalType as PPT}; + let (min_value, max_value) = match (self.field.dtype(), &self.physical_type) { + (D::Null, _) => (None, None), + + (D::Boolean, _) => rmap!(expect_boolean, |x: Option| ParquetResult::Ok( + x.map(|x| BooleanArray::new(ArrowDataType::Boolean, vec![x].into(), None,)) + )), + + (D::Int8, _) => rmap!(expect_int32, @prim i32 as i8), + (D::Int16, _) => rmap!(expect_int32, @prim i32 as i16), + (D::Int32 | D::Date32 | D::Time32(_), _) => rmap!(expect_int32, @prim i32 as i32), + + // some implementations of parquet write arrow's date64 into i32. + (D::Date64, PPT::Int32) => rmap!(expect_int32, @prim i32 as i64, |x| x * 86400000), + + (D::Int64 | D::Time64(_) | D::Duration(_), _) | (D::Date64, PPT::Int64) => { + rmap!(expect_int64, @prim i64 as i64) + }, + + (D::Interval(IntervalUnit::YearMonth), _) => rmap!( + expect_binary, + @prim Vec, + |x| convert_year_month(&x) + ), + (D::Interval(IntervalUnit::DayTime), _) => rmap!( + expect_binary, + @prim Vec, + |x| convert_days_ms(&x) + ), + + (D::UInt8, _) => rmap!(expect_int32, @prim i32 as u8), + (D::UInt16, _) => rmap!(expect_int32, @prim i32 as u16), + (D::UInt32, PPT::Int32) => rmap!(expect_int32, @prim i32 as u32), + + // some implementations of parquet write arrow's u32 into i64. + (D::UInt32, PPT::Int64) => rmap!(expect_int64, @prim i64 as u32), + (D::UInt64, _) => rmap!(expect_int64, @prim i64 as u64), + + (D::Timestamp(time_unit, _), PPT::Int96) => { + rmap!(expect_int96, @prim [u32; 3], |x| { + timestamp(self.logical_type.as_ref(), *time_unit, int96_to_i64_ns(x)) + }) + }, + (D::Timestamp(time_unit, _), PPT::Int64) => { + rmap!(expect_int64, @prim i64, |x| { + timestamp(self.logical_type.as_ref(), *time_unit, x) + }) + }, + + // Read Float16, since we don't have a f16 type in Polars we read it to a Float32. + (_, PPT::FixedLenByteArray(2)) + if matches!( + self.logical_type.as_ref(), + Some(PrimitiveLogicalType::Float16) + ) => + { + rmap!(expect_fixedlen, @prim Vec, |v| f16::from_le_bytes([v[0], v[1]]).to_f32()) + }, + (D::Float32, _) => rmap!(expect_float, @prim f32), + (D::Float64, _) => rmap!(expect_double, @prim f64), + + (D::Decimal(_, _), PPT::Int32) => rmap!(expect_int32, @prim i32 as i128), + (D::Decimal(_, _), PPT::Int64) => rmap!(expect_int64, @prim i64 as i128), + (D::Decimal(_, _), PPT::FixedLenByteArray(n)) if *n > 16 => { + return Err(ParquetError::not_supported(format!( + "Can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}", + ))); + }, + (D::Decimal(_, _), PPT::FixedLenByteArray(n)) => rmap!( + expect_fixedlen, + @prim Vec, + |x| convert_i128(&x, *n) + ), + (D::Decimal256(_, _), PPT::Int32) => { + rmap!(expect_int32, @prim i32, |x: i32| i256(I256::new(x.into()))) + }, + (D::Decimal256(_, _), PPT::Int64) => { + rmap!(expect_int64, @prim i64, |x: i64| i256(I256::new(x.into()))) + }, + (D::Decimal256(_, _), PPT::FixedLenByteArray(n)) if *n > 16 => { + return Err(ParquetError::not_supported(format!( + "Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}", + ))); + }, + (D::Decimal256(_, _), PPT::FixedLenByteArray(_)) => rmap!( + expect_fixedlen, + @prim Vec, + |x| convert_i256(&x) + ), + (D::Binary, _) => rmap!(@binary), + (D::LargeBinary, _) => rmap!(@binary), + (D::Utf8, _) => rmap!(@string), + (D::LargeUtf8, _) => rmap!(@string), + + (D::BinaryView, _) => rmap!(@binary), + (D::Utf8View, _) => rmap!(@string), + + (D::FixedSizeBinary(_), _) => { + rmap!(expect_fixedlen, |x: Option>| ParquetResult::Ok( + x.map(|x| FixedSizeBinaryArray::new( + self.field.dtype().clone(), + x.into(), + None + )) + )) + }, + + other => todo!("{:?}", other), + }; + + Ok(ArrowColumnStatistics { + null_count, + distinct_count, + + min_value, + max_value, + }) + } +} + +/// Deserializes the statistics in the column chunks from a single `row_group` +/// into [`Statistics`] associated from `field`'s name. +/// +/// # Errors +/// This function errors if the deserialization of the statistics fails (e.g. invalid utf8) +pub fn deserialize_all( + field: &Field, + row_groups: &[RowGroupMetadata], + field_idx: usize, +) -> ParquetResult> { + assert!(!row_groups.is_empty()); + use ArrowDataType as D; + match field.dtype() { + // @TODO: These are all a bit more complex, skip for now. + D::List(..) | D::LargeList(..) => Ok(None), + D::Dictionary(..) => Ok(None), + D::FixedSizeList(..) => Ok(None), + D::Struct(..) => Ok(None), + + _ => { + let mut null_count = MutablePrimitiveArray::::with_capacity(row_groups.len()); + let mut distinct_count = + MutablePrimitiveArray::::with_capacity(row_groups.len()); + + let primitive_type = &row_groups[0].parquet_columns()[field_idx] + .descriptor() + .descriptor + .primitive_type; + + let logical_type = &primitive_type.logical_type; + let physical_type = &primitive_type.physical_type; + + macro_rules! rmap { + ($expect:ident, $map:expr, $arr:ty$(, $arg:expr)?) => {{ + let mut min_arr = <$arr>::with_capacity(row_groups.len()$(, $arg)?); + let mut max_arr = <$arr>::with_capacity(row_groups.len()$(, $arg)?); + + for rg in row_groups { + let column = &rg.parquet_columns()[field_idx]; + let s = column.statistics().transpose()?; + + let (v_min, v_max, v_null_count, v_distinct_count) = match s { + None => (None, None, None, None), + Some(s) => { + let s = s.$expect(); + + let min = s.min_value; + let max = s.max_value; + + let min = ($map)(min)?; + let max = ($map)(max)?; + + ( + min, + max, + s.null_count.map(|v| v as IdxSize), + s.distinct_count.map(|v| v as IdxSize), + ) + } + }; + + min_arr.push(v_min); + max_arr.push(v_max); + null_count.push(v_null_count); + distinct_count.push(v_distinct_count); + } + + (min_arr.freeze().to_boxed(), max_arr.freeze().to_boxed()) + }}; + ($expect:ident, $arr:ty, @prim $from:ty $(as $to:ty)? $(, $map:expr)?) => {{ + rmap!( + $expect, + |x: Option<$from>| { + $( + let x = x.map(|x| x as $to); + )? + $( + let x = x.map($map); + )? + ParquetResult::Ok(x) + }, + $arr + ) + }}; + (@binary $(, $map:expr)?) => {{ + rmap!( + expect_binary, + |x: Option>| { + $( + let x = x.map($map); + )? + ParquetResult::Ok(x) + }, + MutableBinaryViewArray<[u8]> + ) + }}; + (@string) => {{ + rmap!( + expect_binary, + |x: Option>| { + let x = x.map(String::from_utf8).transpose().map_err(|_| { + ParquetError::oos("Invalid UTF8 in Statistics") + })?; + ParquetResult::Ok(x) + }, + MutableBinaryViewArray + ) + }}; + } + + use {ArrowDataType as D, ParquetPhysicalType as PPT}; + let (min_value, max_value) = match (field.dtype(), physical_type) { + (D::Null, _) => ( + NullArray::new(ArrowDataType::Null, row_groups.len()).to_boxed(), + NullArray::new(ArrowDataType::Null, row_groups.len()).to_boxed(), + ), + + (D::Boolean, _) => rmap!( + expect_boolean, + |x: Option| ParquetResult::Ok(x), + MutableBooleanArray + ), + + (D::Int8, _) => rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as i8), + (D::Int16, _) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as i16) + }, + (D::Int32 | D::Date32 | D::Time32(_), _) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as i32) + }, + + // some implementations of parquet write arrow's date64 into i32. + (D::Date64, PPT::Int32) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as i64, |x| x * 86400000) + }, + + (D::Int64 | D::Time64(_) | D::Duration(_), _) | (D::Date64, PPT::Int64) => { + rmap!(expect_int64, MutablePrimitiveArray::, @prim i64 as i64) + }, + + (D::Interval(IntervalUnit::YearMonth), _) => rmap!( + expect_binary, + MutablePrimitiveArray::, + @prim Vec, + |x| convert_year_month(&x) + ), + (D::Interval(IntervalUnit::DayTime), _) => rmap!( + expect_binary, + MutablePrimitiveArray::, + @prim Vec, + |x| convert_days_ms(&x) + ), + + (D::UInt8, _) => rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as u8), + (D::UInt16, _) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as u16) + }, + (D::UInt32, PPT::Int32) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as u32) + }, + + // some implementations of parquet write arrow's u32 into i64. + (D::UInt32, PPT::Int64) => { + rmap!(expect_int64, MutablePrimitiveArray::, @prim i64 as u32) + }, + (D::UInt64, _) => { + rmap!(expect_int64, MutablePrimitiveArray::, @prim i64 as u64) + }, + + (D::Timestamp(time_unit, _), PPT::Int96) => { + rmap!(expect_int96, MutablePrimitiveArray::, @prim [u32; 3], |x| { + timestamp(logical_type.as_ref(), *time_unit, int96_to_i64_ns(x)) + }) + }, + (D::Timestamp(time_unit, _), PPT::Int64) => { + rmap!(expect_int64, MutablePrimitiveArray::, @prim i64, |x| { + timestamp(logical_type.as_ref(), *time_unit, x) + }) + }, + + // Read Float16, since we don't have a f16 type in Polars we read it to a Float32. + (_, PPT::FixedLenByteArray(2)) + if matches!(logical_type.as_ref(), Some(PrimitiveLogicalType::Float16)) => + { + rmap!(expect_fixedlen, MutablePrimitiveArray::, @prim Vec, |v| f16::from_le_bytes([v[0], v[1]]).to_f32()) + }, + (D::Float32, _) => rmap!(expect_float, MutablePrimitiveArray::, @prim f32), + (D::Float64, _) => rmap!(expect_double, MutablePrimitiveArray::, @prim f64), + + (D::Decimal(_, _), PPT::Int32) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32 as i128) + }, + (D::Decimal(_, _), PPT::Int64) => { + rmap!(expect_int64, MutablePrimitiveArray::, @prim i64 as i128) + }, + (D::Decimal(_, _), PPT::FixedLenByteArray(n)) if *n > 16 => { + return Err(ParquetError::not_supported(format!( + "Can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}", + ))); + }, + (D::Decimal(_, _), PPT::FixedLenByteArray(n)) => rmap!( + expect_fixedlen, + MutablePrimitiveArray::, + @prim Vec, + |x| convert_i128(&x, *n) + ), + (D::Decimal256(_, _), PPT::Int32) => { + rmap!(expect_int32, MutablePrimitiveArray::, @prim i32, |x: i32| i256(I256::new(x.into()))) + }, + (D::Decimal256(_, _), PPT::Int64) => { + rmap!(expect_int64, MutablePrimitiveArray::, @prim i64, |x: i64| i256(I256::new(x.into()))) + }, + (D::Decimal256(_, _), PPT::FixedLenByteArray(n)) if *n > 16 => { + return Err(ParquetError::not_supported(format!( + "Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}", + ))); + }, + (D::Decimal256(_, _), PPT::FixedLenByteArray(_)) => rmap!( + expect_fixedlen, + MutablePrimitiveArray::, + @prim Vec, + |x| convert_i256(&x) + ), + (D::Binary, _) => rmap!(@binary), + (D::LargeBinary, _) => rmap!(@binary), + (D::Utf8, _) => rmap!(@string), + (D::LargeUtf8, _) => rmap!(@string), + + (D::BinaryView, _) => rmap!(@binary), + (D::Utf8View, _) => rmap!(@string), + + (D::FixedSizeBinary(width), _) => { + rmap!( + expect_fixedlen, + |x: Option>| ParquetResult::Ok(x), + MutableFixedSizeBinaryArray, + *width + ) + }, + + other => todo!("{:?}", other), + }; + + Ok(Some(ArrowColumnStatisticsArrays { + null_count: null_count.freeze(), + distinct_count: distinct_count.freeze(), + min_value, + max_value, + })) + }, + } +} + +/// Deserializes the statistics in the column chunks from a single `row_group` +/// into [`Statistics`] associated from `field`'s name. +/// +/// # Errors +/// This function errors if the deserialization of the statistics fails (e.g. invalid utf8) +pub fn deserialize<'a>( + field: &Field, + columns: &mut impl ExactSizeIterator, +) -> ParquetResult> { + use ArrowDataType as D; + match field.dtype() { + D::List(field) | D::LargeList(field) => Ok(Some(Statistics::List( + deserialize(field.as_ref(), columns)?.map(Box::new), + ))), + D::Dictionary(key, dtype, is_sorted) => Ok(Some(Statistics::Dictionary( + *key, + deserialize( + &Field::new(PlSmallStr::EMPTY, dtype.as_ref().clone(), true), + columns, + )? + .map(Box::new), + *is_sorted, + ))), + D::FixedSizeList(field, width) => Ok(Some(Statistics::FixedSizeList( + deserialize(field.as_ref(), columns)?.map(Box::new), + *width, + ))), + D::Struct(fields) => { + let field_columns = fields + .iter() + .map(|f| deserialize(f, columns)) + .collect::>()?; + Ok(Some(Statistics::Struct(field_columns))) + }, + _ => { + let column = columns.next().unwrap(); + + Ok(column.statistics().transpose()?.map(|statistics| { + let primitive_type = &column.descriptor().descriptor.primitive_type; + + Statistics::Column(Box::new(ColumnStatistics { + field: field.clone(), + + logical_type: primitive_type.logical_type, + physical_type: primitive_type.physical_type, + + statistics, + })) + })) + }, + } +} diff --git a/crates/polars-parquet/src/arrow/write/binary/basic.rs b/crates/polars-parquet/src/arrow/write/binary/basic.rs new file mode 100644 index 000000000000..0afbd37bce22 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/binary/basic.rs @@ -0,0 +1,163 @@ +use arrow::array::{Array, BinaryArray, ValueSize}; +use arrow::bitmap::Bitmap; +use arrow::offset::Offset; +use polars_error::PolarsResult; + +use super::super::{WriteOptions, utils}; +use crate::arrow::read::schema::is_nullable; +use crate::parquet::encoding::{Encoding, delta_bitpacked}; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::{BinaryStatistics, ParquetStatistics}; +use crate::write::utils::invalid_encoding; +use crate::write::{EncodeNullability, Page, StatisticsOptions}; + +pub(crate) fn encode_non_null_values<'a, I: Iterator>( + iter: I, + buffer: &mut Vec, +) { + iter.for_each(|x| { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x); + }) +} + +pub(crate) fn encode_plain( + array: &BinaryArray, + options: EncodeNullability, + buffer: &mut Vec, +) { + if options.is_optional() && array.validity().is_some() { + let len_before = buffer.len(); + let capacity = + array.get_values_size() + (array.len() - array.null_count()) * size_of::(); + buffer.reserve(capacity); + encode_non_null_values(array.non_null_values_iter(), buffer); + // Ensure we allocated properly. + debug_assert_eq!(buffer.len() - len_before, capacity); + } else { + let len_before = buffer.len(); + let capacity = array.get_values_size() + array.len() * size_of::(); + buffer.reserve(capacity); + encode_non_null_values(array.values_iter(), buffer); + // Ensure we allocated properly. + debug_assert_eq!(buffer.len() - len_before, capacity); + } +} + +pub fn array_to_page( + array: &BinaryArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> PolarsResult { + let validity = array.validity(); + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, encode_options, &mut buffer), + Encoding::DeltaLengthByteArray => encode_delta( + array.values(), + array.offsets().buffer(), + array.validity(), + encode_options, + &mut buffer, + ), + _ => return Err(invalid_encoding(encoding, array.dtype())), + } + + let statistics = if options.has_statistics() { + Some(build_statistics(array, type_.clone(), &options.statistics)) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) + .map(Page::Data) +} + +pub(crate) fn build_statistics( + array: &BinaryArray, + primitive_type: PrimitiveType, + options: &StatisticsOptions, +) -> ParquetStatistics { + use polars_compute::min_max::MinMaxKernel; + + BinaryStatistics { + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value: options + .max_value + .then(|| array.max_propagate_nan_kernel().map(<[u8]>::to_vec)) + .flatten(), + min_value: options + .min_value + .then(|| array.min_propagate_nan_kernel().map(<[u8]>::to_vec)) + .flatten(), + } + .serialize() +} + +pub(crate) fn encode_delta( + values: &[u8], + offsets: &[O], + validity: Option<&Bitmap>, + options: EncodeNullability, + buffer: &mut Vec, +) { + if options.is_optional() && validity.is_some() { + if let Some(validity) = validity { + let lengths = offsets + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64) + .zip(validity.iter()) + .flat_map(|(x, is_valid)| if is_valid { Some(x) } else { None }); + let length = offsets.len() - 1 - validity.unset_bits(); + let lengths = utils::ExactSizedIter::new(lengths, length); + + delta_bitpacked::encode(lengths, buffer, 1); + } else { + let lengths = offsets.windows(2).map(|w| (w[1] - w[0]).to_usize() as i64); + delta_bitpacked::encode(lengths, buffer, 1); + } + } else { + let lengths = offsets.windows(2).map(|w| (w[1] - w[0]).to_usize() as i64); + delta_bitpacked::encode(lengths, buffer, 1); + } + + buffer.extend_from_slice( + &values[offsets.first().unwrap().to_usize()..offsets.last().unwrap().to_usize()], + ) +} + +/// Returns the ordering of two binary values. This corresponds to pyarrows' ordering +/// of statistics. +#[inline(always)] +pub(crate) fn ord_binary<'a>(a: &'a [u8], b: &'a [u8]) -> std::cmp::Ordering { + a.cmp(b) +} diff --git a/crates/polars-parquet/src/arrow/write/binary/mod.rs b/crates/polars-parquet/src/arrow/write/binary/mod.rs new file mode 100644 index 000000000000..db9a43582e0c --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/binary/mod.rs @@ -0,0 +1,7 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub(crate) use basic::{build_statistics, encode_plain}; +pub(super) use basic::{encode_non_null_values, ord_binary}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/polars-parquet/src/arrow/write/binary/nested.rs b/crates/polars-parquet/src/arrow/write/binary/nested.rs new file mode 100644 index 000000000000..e122c2c525a2 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/binary/nested.rs @@ -0,0 +1,50 @@ +use arrow::array::{Array, BinaryArray}; +use arrow::offset::Offset; +use polars_error::PolarsResult; + +use super::super::{WriteOptions, nested, utils}; +use super::basic::{build_statistics, encode_plain}; +use crate::arrow::write::Nested; +use crate::parquet::encoding::Encoding; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::read::schema::is_nullable; +use crate::write::EncodeNullability; + +pub fn array_to_page( + array: &BinaryArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> PolarsResult +where + O: Offset, +{ + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, encode_options, &mut buffer); + + let statistics = if options.has_statistics() { + Some(build_statistics(array, type_.clone(), &options.statistics)) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/binview/basic.rs b/crates/polars-parquet/src/arrow/write/binview/basic.rs new file mode 100644 index 000000000000..f184fd542cfe --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/binview/basic.rs @@ -0,0 +1,128 @@ +use arrow::array::{Array, BinaryViewArray}; +use polars_compute::min_max::MinMaxKernel; +use polars_error::PolarsResult; + +use crate::parquet::encoding::delta_bitpacked; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::{BinaryStatistics, ParquetStatistics}; +use crate::read::schema::is_nullable; +use crate::write::binary::encode_non_null_values; +use crate::write::utils::invalid_encoding; +use crate::write::{EncodeNullability, Encoding, Page, StatisticsOptions, WriteOptions, utils}; + +pub(crate) fn encode_plain( + array: &BinaryViewArray, + options: EncodeNullability, + buffer: &mut Vec, +) { + if options.is_optional() && array.validity().is_some() { + // @NOTE: This capacity might overestimate the amount of bytes since the buffers might + // still contain data that is not referenced by any value. + let capacity = + array.total_bytes_len() + (array.len() - array.null_count()) * size_of::(); + buffer.reserve(capacity); + + encode_non_null_values(array.non_null_values_iter(), buffer); + // Append the non-null values. + } else { + // @NOTE: This capacity might overestimate the amount of bytes since the buffers might + // still contain data that is not referenced by any value. + let capacity = array.total_bytes_len() + array.len() * size_of::(); + buffer.reserve(capacity); + + encode_non_null_values(array.values_iter(), buffer); + } +} + +pub(crate) fn encode_delta( + array: &BinaryViewArray, + options: EncodeNullability, + buffer: &mut Vec, +) { + if options.is_optional() && array.validity().is_some() { + let lengths = utils::ExactSizedIter::new( + array.non_null_views_iter().map(|v| v.length as i64), + array.len() - array.null_count(), + ); + delta_bitpacked::encode(lengths, buffer, 1); + + for slice in array.non_null_values_iter() { + buffer.extend_from_slice(slice) + } + } else { + let lengths = + utils::ExactSizedIter::new(array.views().iter().map(|v| v.length as i64), array.len()); + delta_bitpacked::encode(lengths, buffer, 1); + + buffer.extend(array.values_iter().flatten()); + } +} + +pub fn array_to_page( + array: &BinaryViewArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + // TODO! reserve capacity + utils::write_def_levels( + &mut buffer, + is_optional, + array.validity(), + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, encode_options, &mut buffer), + Encoding::DeltaLengthByteArray => encode_delta(array, encode_options, &mut buffer), + _ => return Err(invalid_encoding(encoding, array.dtype())), + } + + let statistics = if options.has_statistics() { + Some(build_statistics(array, type_.clone(), &options.statistics)) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) + .map(Page::Data) +} + +pub(crate) fn build_statistics( + array: &BinaryViewArray, + primitive_type: PrimitiveType, + options: &StatisticsOptions, +) -> ParquetStatistics { + BinaryStatistics { + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value: options + .max_value + .then(|| array.max_propagate_nan_kernel().map(<[u8]>::to_vec)) + .flatten(), + min_value: options + .min_value + .then(|| array.min_propagate_nan_kernel().map(<[u8]>::to_vec)) + .flatten(), + } + .serialize() +} diff --git a/crates/polars-parquet/src/arrow/write/binview/mod.rs b/crates/polars-parquet/src/arrow/write/binview/mod.rs new file mode 100644 index 000000000000..5b0ab6102c22 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/binview/mod.rs @@ -0,0 +1,5 @@ +mod basic; +mod nested; + +pub(crate) use basic::{array_to_page, build_statistics, encode_plain}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/polars-parquet/src/arrow/write/binview/nested.rs b/crates/polars-parquet/src/arrow/write/binview/nested.rs new file mode 100644 index 000000000000..a271a383fb52 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/binview/nested.rs @@ -0,0 +1,46 @@ +use arrow::array::{Array, BinaryViewArray}; +use polars_error::PolarsResult; + +use super::super::{WriteOptions, nested, utils}; +use super::basic::{build_statistics, encode_plain}; +use crate::arrow::write::Nested; +use crate::parquet::encoding::Encoding; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::read::schema::is_nullable; +use crate::write::EncodeNullability; + +pub fn array_to_page( + array: &BinaryViewArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, encode_options, &mut buffer); + + let statistics = if options.has_statistics() { + Some(build_statistics(array, type_.clone(), &options.statistics)) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/boolean/basic.rs b/crates/polars-parquet/src/arrow/write/boolean/basic.rs new file mode 100644 index 000000000000..be63db969633 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/boolean/basic.rs @@ -0,0 +1,127 @@ +use arrow::array::*; +use polars_error::{PolarsResult, polars_bail}; + +use super::super::{WriteOptions, utils}; +use crate::arrow::read::schema::is_nullable; +use crate::parquet::encoding::Encoding; +use crate::parquet::encoding::hybrid_rle::{self, bitpacked_encode}; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::{BooleanStatistics, ParquetStatistics}; +use crate::write::{EncodeNullability, StatisticsOptions}; + +fn encode(iterator: impl Iterator, buffer: &mut Vec) -> PolarsResult<()> { + // encode values using bitpacking + let len = buffer.len(); + let mut buffer = std::io::Cursor::new(buffer); + buffer.set_position(len as u64); + Ok(bitpacked_encode(&mut buffer, iterator)?) +} + +pub(super) fn encode_plain( + array: &BooleanArray, + encode_options: EncodeNullability, + buffer: &mut Vec, +) -> PolarsResult<()> { + if encode_options.is_optional() && array.validity().is_some() { + encode(array.non_null_values_iter(), buffer) + } else { + encode(array.values().iter(), buffer) + } +} + +pub(super) fn encode_hybrid_rle( + array: &BooleanArray, + encode_options: EncodeNullability, + buffer: &mut Vec, +) -> PolarsResult<()> { + buffer.extend_from_slice(&[0; 4]); + let start = buffer.len(); + + if encode_options.is_optional() && array.validity().is_some() { + hybrid_rle::encode(buffer, array.non_null_values_iter(), 1)?; + } else { + hybrid_rle::encode(buffer, array.values().iter(), 1)?; + } + + let length = buffer.len() - start; + + // write the first 4 bytes as length + let length = (length as i32).to_le_bytes(); + (0..4).for_each(|i| buffer[start - 4 + i] = length[i]); + + Ok(()) +} + +pub fn array_to_page( + array: &BooleanArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_nullability = EncodeNullability::new(is_optional); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, encode_nullability, &mut buffer)?, + Encoding::Rle => encode_hybrid_rle(array, encode_nullability, &mut buffer)?, + other => polars_bail!(nyi = "Encoding boolean as {other:?}"), + } + + let statistics = if options.has_statistics() { + Some(build_statistics(array, &options.statistics)) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +} + +pub(super) fn build_statistics( + array: &BooleanArray, + options: &StatisticsOptions, +) -> ParquetStatistics { + use polars_compute::min_max::MinMaxKernel; + use polars_compute::unique::GenericUniqueKernel; + + BooleanStatistics { + null_count: options.null_count.then(|| array.null_count() as i64), + distinct_count: options + .distinct_count + .then(|| array.n_unique_non_null().try_into().ok()) + .flatten(), + max_value: options + .max_value + .then(|| array.max_propagate_nan_kernel()) + .flatten(), + min_value: options + .min_value + .then(|| array.min_propagate_nan_kernel()) + .flatten(), + } + .serialize() +} diff --git a/crates/polars-parquet/src/arrow/write/boolean/mod.rs b/crates/polars-parquet/src/arrow/write/boolean/mod.rs new file mode 100644 index 000000000000..280e2ff9efb5 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/boolean/mod.rs @@ -0,0 +1,5 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/polars-parquet/src/arrow/write/boolean/nested.rs b/crates/polars-parquet/src/arrow/write/boolean/nested.rs new file mode 100644 index 000000000000..326ebb9541e2 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/boolean/nested.rs @@ -0,0 +1,45 @@ +use arrow::array::{Array, BooleanArray}; +use polars_error::PolarsResult; + +use super::super::{EncodeNullability, WriteOptions, nested, utils}; +use super::basic::{build_statistics, encode_plain}; +use crate::arrow::read::schema::is_nullable; +use crate::arrow::write::Nested; +use crate::parquet::encoding::Encoding; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; + +pub fn array_to_page( + array: &BooleanArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, encode_options, &mut buffer)?; + + let statistics = if options.has_statistics() { + Some(build_statistics(array, &options.statistics)) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs new file mode 100644 index 000000000000..64580b81406c --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -0,0 +1,596 @@ +use arrow::array::{ + Array, BinaryViewArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8ViewArray, +}; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType}; +use arrow::legacy::utils::CustomIterTools; +use arrow::trusted_len::TrustMyLength; +use arrow::types::NativeType; +use polars_compute::min_max::MinMaxKernel; +use polars_error::{PolarsResult, polars_bail}; + +use super::binary::{ + build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, +}; +use super::fixed_size_binary::{ + build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain, +}; +use super::pages::PrimitiveNested; +use super::primitive::{ + build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain, +}; +use super::{EncodeNullability, Nested, WriteOptions, binview, nested}; +use crate::arrow::read::schema::is_nullable; +use crate::arrow::write::{slice_nested_leaf, utils}; +use crate::parquet::CowBuffer; +use crate::parquet::encoding::Encoding; +use crate::parquet::encoding::hybrid_rle::encode; +use crate::parquet::page::{DictPage, Page}; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::ParquetStatistics; +use crate::write::DynIter; + +trait MinMaxThreshold { + const DELTA_THRESHOLD: usize; + const BITMASK_THRESHOLD: usize; + + fn from_start_and_offset(start: Self, offset: usize) -> Self; +} + +macro_rules! minmaxthreshold_impls { + ($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => { + $( + impl MinMaxThreshold for $signed { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + ((offset as $unsigned) as $signed) + } + } + impl MinMaxThreshold for $unsigned { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + (offset as $unsigned) + } + } + )+ + }; +} + +minmaxthreshold_impls! { + i8, u8 => 16, u8::MAX as usize, + i16, u16 => 256, u16::MAX as usize, + i32, u32 => 512, u16::MAX as usize, + i64, u64 => 2048, u16::MAX as usize, +} + +enum DictionaryDecision { + NotWorth, + TryAgain, + Found(DictionaryArray), +} + +fn min_max_integer_encode_as_dictionary_optional<'a, E, T>( + array: &'a dyn Array, +) -> DictionaryDecision +where + E: std::fmt::Debug, + T: NativeType + + MinMaxThreshold + + std::cmp::Ord + + TryInto + + std::ops::Sub + + num_traits::CheckedSub + + num_traits::cast::AsPrimitive, + std::ops::RangeInclusive: Iterator, + PrimitiveArray: MinMaxKernel = T>, +{ + let min_max = as MinMaxKernel>::min_max_ignore_nan_kernel( + array.as_any().downcast_ref().unwrap(), + ); + + let Some((min, max)) = min_max else { + return DictionaryDecision::TryAgain; + }; + + debug_assert!(max >= min, "{max} >= {min}"); + let Some(diff) = max.checked_sub(&min) else { + return DictionaryDecision::TryAgain; + }; + + let diff = diff.as_(); + + if diff > T::BITMASK_THRESHOLD { + return DictionaryDecision::TryAgain; + } + + let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1); + + let array = array.as_any().downcast_ref::>().unwrap(); + + if array.has_nulls() { + for v in array.non_null_values_iter() { + let offset = (v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } + } else { + for v in array.values_iter() { + let offset = (*v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } + } + + let cardinality = seen_mask.set_bits(); + + let mut is_worth_it = false; + + is_worth_it |= cardinality <= T::DELTA_THRESHOLD; + is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75; + + if !is_worth_it { + return DictionaryDecision::NotWorth; + } + + let seen_mask = seen_mask.freeze(); + + // SAFETY: We just did the calculation for this. + let indexes = seen_mask + .true_idx_iter() + .map(|idx| T::from_start_and_offset(min, idx)); + let indexes = unsafe { TrustMyLength::new(indexes, cardinality) }; + let indexes = indexes.collect_trusted::>(); + + let mut lookup = vec![0u16; diff + 1]; + + for (i, &idx) in indexes.iter().enumerate() { + lookup[(idx - min).as_()] = i as u16; + } + + use ArrowDataType as DT; + let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None); + let values = Box::new(values); + + let keys: Buffer = array + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .iter() + .map(|v| { + // @NOTE: + // Since the values might contain nulls which have a undefined value. We just + // clamp the values to between the min and max value. This way, they will still + // be valid dictionary keys. + let idx = *v.clamp(&min, &max) - min; + let value = unsafe { lookup.get_unchecked(idx.as_()) }; + (*value).into() + }) + .collect(); + + let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned()); + DictionaryDecision::Found( + DictionaryArray::::try_new( + ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(DT::from(T::PRIMITIVE)), + false, // @TODO: This might be able to be set to true? + ), + keys, + values, + ) + .unwrap(), + ) +} + +pub(crate) fn encode_as_dictionary_optional( + array: &dyn Array, + nested: &[Nested], + type_: PrimitiveType, + options: WriteOptions, +) -> Option>>> { + if array.is_empty() { + let array = DictionaryArray::::new_empty(ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(array.dtype().clone()), + false, // @TODO: This might be able to be set to true? + )); + + return Some(array_to_pages( + &array, + type_, + nested, + options, + Encoding::RleDictionary, + )); + } + + use arrow::types::PrimitiveType as PT; + let fast_dictionary = match array.dtype().to_physical_type() { + PhysicalType::Primitive(pt) => match pt { + PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), + PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), + PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array), + PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array), + PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), + PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), + PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), + PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), + _ => DictionaryDecision::TryAgain, + }, + _ => DictionaryDecision::TryAgain, + }; + + match fast_dictionary { + DictionaryDecision::NotWorth => return None, + DictionaryDecision::Found(dictionary_array) => { + return Some(array_to_pages( + &dictionary_array, + type_, + nested, + options, + Encoding::RleDictionary, + )); + }, + DictionaryDecision::TryAgain => {}, + } + + let dtype = Box::new(array.dtype().clone()); + + let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array); + + if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 { + return None; + } + + // This does the group by. + let array = polars_compute::cast::cast( + array, + &ArrowDataType::Dictionary(IntegerType::UInt32, dtype, false), + Default::default(), + ) + .ok()?; + + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + Some(array_to_pages( + array, + type_, + nested, + options, + Encoding::RleDictionary, + )) +} + +fn serialize_def_levels_simple( + validity: Option<&Bitmap>, + length: usize, + is_optional: bool, + options: WriteOptions, + buffer: &mut Vec, +) -> PolarsResult<()> { + utils::write_def_levels(buffer, is_optional, validity, length, options.version) +} + +fn serialize_keys_values( + array: &DictionaryArray, + validity: Option<&Bitmap>, + buffer: &mut Vec, +) -> PolarsResult<()> { + let keys = array.keys_values_iter().map(|x| x as u32); + if let Some(validity) = validity { + // discard indices whose values are null. + let keys = keys + .zip(validity.iter()) + .filter(|&(_key, is_valid)| is_valid) + .map(|(key, _is_valid)| key); + let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); + + let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits()); + + // num_bits as a single byte + buffer.push(num_bits as u8); + + // followed by the encoded indices. + Ok(encode::(buffer, keys, num_bits)?) + } else { + let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); + + // num_bits as a single byte + buffer.push(num_bits as u8); + + // followed by the encoded indices. + Ok(encode::(buffer, keys, num_bits)?) + } +} + +fn serialize_levels( + validity: Option<&Bitmap>, + length: usize, + type_: &PrimitiveType, + nested: &[Nested], + options: WriteOptions, + buffer: &mut Vec, +) -> PolarsResult<(usize, usize)> { + if nested.len() == 1 { + let is_optional = is_nullable(&type_.field_info); + serialize_def_levels_simple(validity, length, is_optional, options, buffer)?; + let definition_levels_byte_length = buffer.len(); + Ok((0, definition_levels_byte_length)) + } else { + nested::write_rep_and_def(options.version, nested, buffer) + } +} + +fn normalized_validity(array: &DictionaryArray) -> Option { + match (array.keys().validity(), array.values().validity()) { + (None, None) => None, + (keys, None) => keys.cloned(), + // The values can have a different length than the keys + (_, Some(_values)) => { + let iter = (0..array.len()).map(|i| unsafe { !array.is_null_unchecked(i) }); + MutableBitmap::from_trusted_len_iter(iter).into() + }, + } +} + +fn serialize_keys( + array: &DictionaryArray, + type_: PrimitiveType, + nested: &[Nested], + statistics: Option, + options: WriteOptions, +) -> PolarsResult { + let mut buffer = vec![]; + + let (start, len) = slice_nested_leaf(nested); + + let mut nested = nested.to_vec(); + let array = array.clone().sliced(start, len); + if let Some(Nested::Primitive(PrimitiveNested { length, .. })) = nested.last_mut() { + *length = len; + } else { + unreachable!("") + } + // Parquet only accepts a single validity - we "&" the validities into a single one + // and ignore keys whose _value_ is null. + // It's important that we slice before normalizing. + let validity = normalized_validity(&array); + + let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels( + validity.as_ref(), + array.len(), + &type_, + &nested, + options, + &mut buffer, + )?; + + serialize_keys_values(&array, validity.as_ref(), &mut buffer)?; + + let (num_values, num_rows) = if nested.len() == 1 { + (array.len(), array.len()) + } else { + (nested::num_values(&nested), nested[0].len()) + }; + + utils::build_plain_page( + buffer, + num_values, + num_rows, + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::RleDictionary, + ) + .map(Page::Data) +} + +macro_rules! dyn_prim { + ($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{ + let values = $array.values().as_any().downcast_ref().unwrap(); + + let buffer = + primitive_encode_plain::<$from, $to>(values, EncodeNullability::new(false), vec![]); + + let stats: Option = if !$options.statistics.is_empty() { + let mut stats = primitive_build_statistics::<$from, $to>( + values, + $type_.clone(), + &$options.statistics, + ); + stats.null_count = Some($array.null_count() as i64); + Some(stats.serialize()) + } else { + None + }; + ( + DictPage::new(CowBuffer::Owned(buffer), values.len(), false), + stats, + ) + }}; +} + +pub fn array_to_pages( + array: &DictionaryArray, + type_: PrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> PolarsResult>> { + match encoding { + Encoding::PlainDictionary | Encoding::RleDictionary => { + // write DictPage + let (dict_page, mut statistics): (_, Option) = match array + .values() + .dtype() + .to_logical_type() + { + ArrowDataType::Int8 => dyn_prim!(i8, i32, array, options, type_), + ArrowDataType::Int16 => dyn_prim!(i16, i32, array, options, type_), + ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) => { + dyn_prim!(i32, i32, array, options, type_) + }, + ArrowDataType::Int64 + | ArrowDataType::Date64 + | ArrowDataType::Time64(_) + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_), + ArrowDataType::UInt8 => dyn_prim!(u8, i32, array, options, type_), + ArrowDataType::UInt16 => dyn_prim!(u16, i32, array, options, type_), + ArrowDataType::UInt32 => dyn_prim!(u32, i32, array, options, type_), + ArrowDataType::UInt64 => dyn_prim!(u64, i64, array, options, type_), + ArrowDataType::Float32 => dyn_prim!(f32, f32, array, options, type_), + ArrowDataType::Float64 => dyn_prim!(f64, f64, array, options, type_), + ArrowDataType::LargeUtf8 => { + let array = polars_compute::cast::cast( + array.values().as_ref(), + &ArrowDataType::LargeBinary, + Default::default(), + ) + .unwrap(); + let array = array.as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + binary_encode_plain::(array, EncodeNullability::Required, &mut buffer); + let stats = if options.has_statistics() { + Some(binary_build_statistics( + array, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + ( + DictPage::new(CowBuffer::Owned(buffer), array.len(), false), + stats, + ) + }, + ArrowDataType::BinaryView => { + let array = array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let mut buffer = vec![]; + binview::encode_plain(array, EncodeNullability::Required, &mut buffer); + + let stats = if options.has_statistics() { + Some(binview::build_statistics( + array, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + ( + DictPage::new(CowBuffer::Owned(buffer), array.len(), false), + stats, + ) + }, + ArrowDataType::Utf8View => { + let array = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(); + let mut buffer = vec![]; + binview::encode_plain(&array, EncodeNullability::Required, &mut buffer); + + let stats = if options.has_statistics() { + Some(binview::build_statistics( + &array, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + ( + DictPage::new(CowBuffer::Owned(buffer), array.len(), false), + stats, + ) + }, + ArrowDataType::LargeBinary => { + let values = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + binary_encode_plain::(values, EncodeNullability::Required, &mut buffer); + let stats = if options.has_statistics() { + Some(binary_build_statistics( + values, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + ( + DictPage::new(CowBuffer::Owned(buffer), values.len(), false), + stats, + ) + }, + ArrowDataType::FixedSizeBinary(_) => { + let mut buffer = vec![]; + let array = array.values().as_any().downcast_ref().unwrap(); + fixed_binary_encode_plain(array, EncodeNullability::Required, &mut buffer); + let stats = if options.has_statistics() { + let stats = fixed_binary_build_statistics( + array, + type_.clone(), + &options.statistics, + ); + Some(stats.serialize()) + } else { + None + }; + ( + DictPage::new(CowBuffer::Owned(buffer), array.len(), false), + stats, + ) + }, + other => { + polars_bail!( + nyi = + "Writing dictionary arrays to parquet only support data type {other:?}" + ) + }, + }; + + if let Some(stats) = &mut statistics { + stats.null_count = Some(array.null_count() as i64) + } + + // write DataPage pointing to DictPage + let data_page = + serialize_keys(array, type_, nested, statistics, options)?.unwrap_data(); + + Ok(DynIter::new( + [Page::Dict(dict_page), Page::Data(data_page)] + .into_iter() + .map(Ok), + )) + }, + _ => polars_bail!(nyi = "Dictionary arrays only support dictionary encoding"), + } +} diff --git a/crates/polars-parquet/src/arrow/write/file.rs b/crates/polars-parquet/src/arrow/write/file.rs new file mode 100644 index 000000000000..61116359472e --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/file.rs @@ -0,0 +1,111 @@ +use std::io::Write; + +use arrow::datatypes::ArrowSchema; +use polars_error::{PolarsError, PolarsResult}; + +use super::schema::schema_to_metadata_key; +use super::{ThriftFileMetadata, WriteOptions, to_parquet_schema}; +use crate::parquet::metadata::{KeyValue, SchemaDescriptor}; +use crate::parquet::write::{RowGroupIterColumns, WriteOptions as FileWriteOptions}; + +/// Attaches [`ArrowSchema`] to `key_value_metadata` +pub fn add_arrow_schema( + schema: &ArrowSchema, + key_value_metadata: Option>, +) -> Option> { + key_value_metadata + .map(|mut x| { + x.push(schema_to_metadata_key(schema)); + x + }) + .or_else(|| Some(vec![schema_to_metadata_key(schema)])) +} + +/// An interface to write a parquet to a [`Write`] +pub struct FileWriter { + writer: crate::parquet::write::FileWriter, + schema: ArrowSchema, + options: WriteOptions, +} + +// Accessors +impl FileWriter { + /// The options assigned to the file + pub fn options(&self) -> WriteOptions { + self.options + } + + /// The [`SchemaDescriptor`] assigned to this file + pub fn parquet_schema(&self) -> &SchemaDescriptor { + self.writer.schema() + } + + /// The [`ArrowSchema`] assigned to this file + pub fn schema(&self) -> &ArrowSchema { + &self.schema + } +} + +impl FileWriter { + /// Returns a new [`FileWriter`]. + /// # Error + /// If it is unable to derive a parquet schema from [`ArrowSchema`]. + pub fn new_with_parquet_schema( + writer: W, + schema: ArrowSchema, + parquet_schema: SchemaDescriptor, + options: WriteOptions, + ) -> Self { + let created_by = Some("Polars".to_string()); + + Self { + writer: crate::parquet::write::FileWriter::new( + writer, + parquet_schema, + FileWriteOptions { + version: options.version, + write_statistics: options.has_statistics(), + }, + created_by, + ), + schema, + options, + } + } + + /// Returns a new [`FileWriter`]. + /// # Error + /// If it is unable to derive a parquet schema from [`ArrowSchema`]. + pub fn try_new(writer: W, schema: ArrowSchema, options: WriteOptions) -> PolarsResult { + let parquet_schema = to_parquet_schema(&schema)?; + Ok(Self::new_with_parquet_schema( + writer, + schema, + parquet_schema, + options, + )) + } + + /// Writes a row group to the file. + pub fn write(&mut self, row_group: RowGroupIterColumns<'_, PolarsError>) -> PolarsResult<()> { + Ok(self.writer.write(row_group)?) + } + + /// Writes the footer of the parquet file. Returns the total size of the file. + pub fn end(&mut self, key_value_metadata: Option>) -> PolarsResult { + let key_value_metadata = add_arrow_schema(&self.schema, key_value_metadata); + Ok(self.writer.end(key_value_metadata)?) + } + + /// Consumes this writer and returns the inner writer + pub fn into_inner(self) -> W { + self.writer.into_inner() + } + + /// Returns the underlying writer and [`ThriftFileMetadata`] + /// # Panics + /// This function panics if [`Self::end`] has not yet been called + pub fn into_inner_and_metadata(self) -> (W, ThriftFileMetadata) { + self.writer.into_inner_and_metadata() + } +} diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs new file mode 100644 index 000000000000..fe71bd7ca501 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/basic.rs @@ -0,0 +1,47 @@ +use arrow::array::{Array, FixedSizeBinaryArray}; +use polars_error::PolarsResult; + +use super::encode_plain; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; +use crate::read::schema::is_nullable; +use crate::write::{EncodeNullability, Encoding, WriteOptions, utils}; + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + statistics: Option, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, encode_options, &mut buffer); + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics.map(|x| x.serialize()), + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..58f11adfa491 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/mod.rs @@ -0,0 +1,160 @@ +mod basic; +mod nested; + +use arrow::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; +use arrow::types::i256; +pub use basic::array_to_page; +pub use nested::array_to_page as nested_array_to_page; + +use super::binary::ord_binary; +use super::{EncodeNullability, StatisticsOptions}; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; + +pub(crate) fn encode_plain( + array: &FixedSizeBinaryArray, + options: EncodeNullability, + buffer: &mut Vec, +) { + // append the non-null values + if options.is_optional() && array.validity().is_some() { + array.iter().for_each(|x| { + if let Some(x) = x { + buffer.extend_from_slice(x); + } + }) + } else { + buffer.extend_from_slice(array.values()); + } +} + +pub(super) fn build_statistics( + array: &FixedSizeBinaryArray, + primitive_type: PrimitiveType, + options: &StatisticsOptions, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value: options + .max_value + .then(|| { + array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()) + }) + .flatten(), + min_value: options + .min_value + .then(|| { + array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()) + }) + .flatten(), + } +} + +pub(super) fn build_statistics_decimal( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, + options: &StatisticsOptions, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value: options + .max_value + .then(|| { + array + .iter() + .flatten() + .max() + .map(|x| x.to_be_bytes()[16 - size..].to_vec()) + }) + .flatten(), + min_value: options + .min_value + .then(|| { + array + .iter() + .flatten() + .min() + .map(|x| x.to_be_bytes()[16 - size..].to_vec()) + }) + .flatten(), + } +} + +pub(super) fn build_statistics_decimal256_with_i128( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, + options: &StatisticsOptions, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value: options + .max_value + .then(|| { + array + .iter() + .flatten() + .max() + .map(|x| x.0.low().to_be_bytes()[16 - size..].to_vec()) + }) + .flatten(), + min_value: options + .min_value + .then(|| { + array + .iter() + .flatten() + .min() + .map(|x| x.0.low().to_be_bytes()[16 - size..].to_vec()) + }) + .flatten(), + } +} + +pub(super) fn build_statistics_decimal256( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, + options: &StatisticsOptions, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value: options + .max_value + .then(|| { + array + .iter() + .flatten() + .max() + .map(|x| x.0.to_be_bytes()[32 - size..].to_vec()) + }) + .flatten(), + min_value: options + .min_value + .then(|| { + array + .iter() + .flatten() + .min() + .map(|x| x.0.to_be_bytes()[32 - size..].to_vec()) + }) + .flatten(), + } +} diff --git a/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs b/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs new file mode 100644 index 000000000000..7a5c3ccb826a --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/fixed_size_binary/nested.rs @@ -0,0 +1,39 @@ +use arrow::array::{Array, FixedSizeBinaryArray}; +use polars_error::PolarsResult; + +use super::encode_plain; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::FixedLenStatistics; +use crate::read::schema::is_nullable; +use crate::write::{EncodeNullability, Encoding, Nested, WriteOptions, nested, utils}; + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], + statistics: Option, +) -> PolarsResult { + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, encode_options, &mut buffer); + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics.map(|x| x.serialize()), + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/mod.rs b/crates/polars-parquet/src/arrow/write/mod.rs new file mode 100644 index 000000000000..87a97bdc17b6 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/mod.rs @@ -0,0 +1,1031 @@ +//! APIs to write to Parquet format. +//! +//! # Arrow/Parquet Interoperability +//! As of [parquet-format v2.9](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md) +//! there are Arrow [DataTypes](arrow::datatypes::ArrowDataType) which do not have a parquet +//! representation. These include but are not limited to: +//! * `ArrowDataType::Timestamp(TimeUnit::Second, _)` +//! * `ArrowDataType::Int64` +//! * `ArrowDataType::Duration` +//! * `ArrowDataType::Date64` +//! * `ArrowDataType::Time32(TimeUnit::Second)` +//! +//! The use of these arrow types will result in no logical type being stored within a parquet file. + +mod binary; +mod binview; +mod boolean; +mod dictionary; +mod file; +mod fixed_size_binary; +mod nested; +mod pages; +mod primitive; +mod row_group; +mod schema; +mod utils; + +use arrow::array::*; +use arrow::datatypes::*; +use arrow::types::{NativeType, days_ms, i256}; +pub use nested::{num_values, write_rep_and_def}; +pub use pages::{to_leaves, to_nested, to_parquet_leaves}; +use polars_utils::pl_str::PlSmallStr; +pub use utils::write_def_levels; + +pub use crate::parquet::compression::{BrotliLevel, CompressionOptions, GzipLevel, ZstdLevel}; +pub use crate::parquet::encoding::Encoding; +pub use crate::parquet::metadata::{ + Descriptor, FileMetadata, KeyValue, SchemaDescriptor, ThriftFileMetadata, +}; +pub use crate::parquet::page::{CompressedDataPage, CompressedPage, Page}; +use crate::parquet::schema::types::PrimitiveType as ParquetPrimitiveType; +pub use crate::parquet::schema::types::{ + FieldInfo, ParquetType, PhysicalType as ParquetPhysicalType, +}; +pub use crate::parquet::write::{ + Compressor, DynIter, DynStreamingIterator, RowGroupIterColumns, Version, compress, + write_metadata_sidecar, +}; +pub use crate::parquet::{FallibleStreamingIterator, fallible_streaming_iterator}; + +/// The statistics to write +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct StatisticsOptions { + pub min_value: bool, + pub max_value: bool, + pub distinct_count: bool, + pub null_count: bool, +} + +impl Default for StatisticsOptions { + fn default() -> Self { + Self { + min_value: true, + max_value: true, + distinct_count: false, + null_count: true, + } + } +} + +/// Options to encode an array +#[derive(Clone, Copy)] +pub enum EncodeNullability { + Required, + Optional, +} + +/// Currently supported options to write to parquet +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WriteOptions { + /// Whether to write statistics + pub statistics: StatisticsOptions, + /// The page and file version to use + pub version: Version, + /// The compression to apply to every page + pub compression: CompressionOptions, + /// The size to flush a page, defaults to 1024 * 1024 if None + pub data_page_size: Option, +} + +use arrow::compute::aggregate::estimated_bytes_size; +use arrow::match_integer_type; +pub use file::FileWriter; +pub use pages::{Nested, array_to_columns, arrays_to_columns}; +use polars_error::{PolarsResult, polars_bail}; +pub use row_group::{RowGroupIterator, row_group_iter}; +pub use schema::to_parquet_type; + +use self::pages::{FixedSizeListNested, PrimitiveNested, StructNested}; +use crate::write::dictionary::encode_as_dictionary_optional; + +impl StatisticsOptions { + pub fn empty() -> Self { + Self { + min_value: false, + max_value: false, + distinct_count: false, + null_count: false, + } + } + + pub fn full() -> Self { + Self { + min_value: true, + max_value: true, + distinct_count: true, + null_count: true, + } + } + + pub fn is_empty(&self) -> bool { + !(self.min_value || self.max_value || self.distinct_count || self.null_count) + } + + pub fn is_full(&self) -> bool { + self.min_value && self.max_value && self.distinct_count && self.null_count + } +} + +impl WriteOptions { + pub fn has_statistics(&self) -> bool { + !self.statistics.is_empty() + } +} + +impl EncodeNullability { + const fn new(is_optional: bool) -> Self { + if is_optional { + Self::Optional + } else { + Self::Required + } + } + + fn is_optional(self) -> bool { + matches!(self, Self::Optional) + } +} + +/// returns offset and length to slice the leaf values +pub fn slice_nested_leaf(nested: &[Nested]) -> (usize, usize) { + // find the deepest recursive dremel structure as that one determines how many values we must + // take + let mut out = (0, 0); + for nested in nested.iter().rev() { + match nested { + Nested::LargeList(l_nested) => { + let start = *l_nested.offsets.first(); + let end = *l_nested.offsets.last(); + return (start as usize, (end - start) as usize); + }, + Nested::List(l_nested) => { + let start = *l_nested.offsets.first(); + let end = *l_nested.offsets.last(); + return (start as usize, (end - start) as usize); + }, + Nested::FixedSizeList(nested) => return (0, nested.length * nested.width), + Nested::Primitive(nested) => out = (0, nested.length), + Nested::Struct(_) => {}, + } + } + out +} + +fn decimal_length_from_precision(precision: usize) -> usize { + // digits = floor(log_10(2^(8*n - 1) - 1)) + // ceil(digits) = log10(2^(8*n - 1) - 1) + // 10^ceil(digits) = 2^(8*n - 1) - 1 + // 10^ceil(digits) + 1 = 2^(8*n - 1) + // log2(10^ceil(digits) + 1) = (8*n - 1) + // log2(10^ceil(digits) + 1) + 1 = 8*n + // (log2(10^ceil(a) + 1) + 1) / 8 = n + (((10.0_f64.powi(precision as i32) + 1.0).log2() + 1.0) / 8.0).ceil() as usize +} + +/// Creates a parquet [`SchemaDescriptor`] from a [`ArrowSchema`]. +pub fn to_parquet_schema(schema: &ArrowSchema) -> PolarsResult { + let parquet_types = schema + .iter_values() + .map(to_parquet_type) + .collect::>>()?; + Ok(SchemaDescriptor::new( + PlSmallStr::from_static("root"), + parquet_types, + )) +} + +/// Slices the [`Array`] to `Box` and `Vec`. +pub fn slice_parquet_array( + primitive_array: &mut dyn Array, + nested: &mut [Nested], + mut current_offset: usize, + mut current_length: usize, +) { + for nested in nested.iter_mut() { + match nested { + Nested::LargeList(l_nested) => { + l_nested.offsets.slice(current_offset, current_length + 1); + if let Some(validity) = l_nested.validity.as_mut() { + validity.slice(current_offset, current_length) + }; + + // Update the offset/ length so that the Primitive is sliced properly. + current_length = l_nested.offsets.range() as usize; + current_offset = *l_nested.offsets.first() as usize; + }, + Nested::List(l_nested) => { + l_nested.offsets.slice(current_offset, current_length + 1); + if let Some(validity) = l_nested.validity.as_mut() { + validity.slice(current_offset, current_length) + }; + + // Update the offset/ length so that the Primitive is sliced properly. + current_length = l_nested.offsets.range() as usize; + current_offset = *l_nested.offsets.first() as usize; + }, + Nested::Struct(StructNested { + validity, length, .. + }) => { + *length = current_length; + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + }, + Nested::Primitive(PrimitiveNested { + validity, length, .. + }) => { + *length = current_length; + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + primitive_array.slice(current_offset, current_length); + }, + Nested::FixedSizeList(FixedSizeListNested { + validity, + length, + width, + .. + }) => { + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + *length = current_length; + // Update the offset/ length so that the Primitive is sliced properly. + current_length *= *width; + current_offset *= *width; + }, + } + } +} + +/// Get the length of [`Array`] that should be sliced. +pub fn get_max_length(nested: &[Nested]) -> usize { + let mut length = 0; + for nested in nested.iter() { + match nested { + Nested::LargeList(l_nested) => length += l_nested.offsets.range() as usize, + Nested::List(l_nested) => length += l_nested.offsets.range() as usize, + Nested::FixedSizeList(nested) => length += nested.length * nested.width, + _ => {}, + } + } + length +} + +/// Returns an iterator of [`Page`]. +pub fn array_to_pages( + primitive_array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + mut encoding: Encoding, +) -> PolarsResult>> { + if let ArrowDataType::Dictionary(key_type, _, _) = primitive_array.dtype().to_logical_type() { + return match_integer_type!(key_type, |$T| { + dictionary::array_to_pages::<$T>( + primitive_array.as_any().downcast_ref().unwrap(), + type_, + &nested, + options, + encoding, + ) + }); + }; + if let Encoding::RleDictionary = encoding { + // Only take this path for primitive columns + if matches!(nested.first(), Some(Nested::Primitive(_))) { + if let Some(result) = + encode_as_dictionary_optional(primitive_array, nested, type_.clone(), options) + { + return result; + } + } + + // We didn't succeed, fallback to plain + encoding = Encoding::Plain; + } + + let nested = nested.to_vec(); + + let number_of_rows = nested[0].len(); + + // note: this is not correct if the array is sliced - the estimation should happen on the + // primitive after sliced for parquet + let byte_size = estimated_bytes_size(primitive_array); + + const DEFAULT_PAGE_SIZE: usize = 1024 * 1024; + let max_page_size = options.data_page_size.unwrap_or(DEFAULT_PAGE_SIZE); + let max_page_size = max_page_size.min(2usize.pow(31) - 2usize.pow(25)); // allowed maximum page size + let bytes_per_row = if number_of_rows == 0 { + 0 + } else { + ((byte_size as f64) / (number_of_rows as f64)) as usize + }; + let rows_per_page = (max_page_size / (bytes_per_row + 1)).max(1); + + let row_iter = (0..number_of_rows) + .step_by(rows_per_page) + .map(move |offset| { + let length = if offset + rows_per_page > number_of_rows { + number_of_rows - offset + } else { + rows_per_page + }; + (offset, length) + }); + + let primitive_array = primitive_array.to_boxed(); + + let pages = row_iter.map(move |(offset, length)| { + let mut right_array = primitive_array.clone(); + let mut right_nested = nested.clone(); + slice_parquet_array(right_array.as_mut(), &mut right_nested, offset, length); + + array_to_page( + right_array.as_ref(), + type_.clone(), + &right_nested, + options, + encoding, + ) + }); + Ok(DynIter::new(pages)) +} + +/// Converts an [`Array`] to a [`CompressedPage`] based on options, descriptor and `encoding`. +pub fn array_to_page( + array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> PolarsResult { + if nested.len() == 1 { + // special case where validity == def levels + return array_to_page_simple(array, type_, options, encoding); + } + array_to_page_nested(array, type_, nested, options, encoding) +} + +/// Converts an [`Array`] to a [`CompressedPage`] based on options, descriptor and `encoding`. +pub fn array_to_page_simple( + array: &dyn Array, + type_: ParquetPrimitiveType, + options: WriteOptions, + encoding: Encoding, +) -> PolarsResult { + let dtype = array.dtype(); + + match dtype.to_logical_type() { + ArrowDataType::Boolean => boolean::array_to_page( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + // casts below MUST match the casts done at the metadata (field -> parquet type). + ArrowDataType::UInt8 => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::UInt16 => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::UInt32 => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::UInt64 => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Int8 => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Int16 => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Int64 + | ArrowDataType::Date64 + | ArrowDataType::Time64(_) + | ArrowDataType::Timestamp(_, _) + | ArrowDataType::Duration(_) => { + return primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Float32 => primitive::array_to_page_plain::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + ), + ArrowDataType::Float64 => primitive::array_to_page_plain::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + ), + ArrowDataType::LargeUtf8 => { + let array = + polars_compute::cast::cast(array, &ArrowDataType::LargeBinary, Default::default()) + .unwrap(); + return binary::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::LargeBinary => { + return binary::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::BinaryView => { + return binview::array_to_page( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Utf8View => { + let array = + polars_compute::cast::cast(array, &ArrowDataType::BinaryView, Default::default()) + .unwrap(); + return binview::array_to_page( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ); + }, + ArrowDataType::Null => { + let array = Int32Array::new_null(ArrowDataType::Int32, array.len()); + primitive::array_to_page_plain::(&array, options, type_) + }, + ArrowDataType::Interval(IntervalUnit::YearMonth) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let mut values = Vec::::with_capacity(12 * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_le_bytes(); + values.extend_from_slice(bytes); + values.extend_from_slice(&[0; 8]); + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(12), + values.into(), + array.validity().cloned(), + ); + let statistics = if options.has_statistics() { + Some(fixed_size_binary::build_statistics( + &array, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + fixed_size_binary::array_to_page(&array, options, type_, statistics) + }, + ArrowDataType::Interval(IntervalUnit::DayTime) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let mut values = Vec::::with_capacity(12 * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_le_bytes(); + values.extend_from_slice(&[0; 4]); // months + values.extend_from_slice(bytes); // days and seconds + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(12), + values.into(), + array.validity().cloned(), + ); + let statistics = if options.has_statistics() { + Some(fixed_size_binary::build_statistics( + &array, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + fixed_size_binary::array_to_page(&array, options, type_, statistics) + }, + ArrowDataType::FixedSizeBinary(_) => { + let array = array.as_any().downcast_ref().unwrap(); + let statistics = if options.has_statistics() { + Some(fixed_size_binary::build_statistics( + array, + type_.clone(), + &options.statistics, + )) + } else { + None + }; + + fixed_size_binary::array_to_page(array, options, type_, statistics) + }, + ArrowDataType::Decimal256(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i32()) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int32, + values, + array.validity().cloned(), + ); + return primitive::array_to_page_integer::( + &array, options, type_, encoding, + ); + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i64()) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int64, + values, + array.validity().cloned(), + ); + return primitive::array_to_page_integer::( + &array, options, type_, encoding, + ); + } else if precision <= 38 { + let size = decimal_length_from_precision(precision); + let statistics = if options.has_statistics() { + let stats = fixed_size_binary::build_statistics_decimal256_with_i128( + array, + type_.clone(), + size, + &options.statistics, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.0.low().to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_size_binary::array_to_page(&array, options, type_, statistics) + } else { + let size = 32; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let statistics = if options.has_statistics() { + let stats = fixed_size_binary::build_statistics_decimal256( + array, + type_.clone(), + size, + &options.statistics, + ); + Some(stats) + } else { + None + }; + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes(); + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + + fixed_size_binary::array_to_page(&array, options, type_, statistics) + } + }, + ArrowDataType::Decimal(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int32, + values, + array.validity().cloned(), + ); + return primitive::array_to_page_integer::( + &array, options, type_, encoding, + ); + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int64, + values, + array.validity().cloned(), + ); + return primitive::array_to_page_integer::( + &array, options, type_, encoding, + ); + } else { + let size = decimal_length_from_precision(precision); + + let statistics = if options.has_statistics() { + let stats = fixed_size_binary::build_statistics_decimal( + array, + type_.clone(), + size, + &options.statistics, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_size_binary::array_to_page(&array, options, type_, statistics) + } + }, + other => polars_bail!(nyi = "Writing parquet pages for data type {other:?}"), + } + .map(Page::Data) +} + +fn array_to_page_nested( + array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + _encoding: Encoding, +) -> PolarsResult { + use ArrowDataType::*; + match array.dtype().to_logical_type() { + Null => { + let array = Int32Array::new_null(ArrowDataType::Int32, array.len()); + primitive::nested_array_to_page::(&array, options, type_, nested) + }, + Boolean => { + let array = array.as_any().downcast_ref().unwrap(); + boolean::nested_array_to_page(array, options, type_, nested) + }, + LargeUtf8 => { + let array = + polars_compute::cast::cast(array, &LargeBinary, Default::default()).unwrap(); + let array = array.as_any().downcast_ref().unwrap(); + binary::nested_array_to_page::(array, options, type_, nested) + }, + LargeBinary => { + let array = array.as_any().downcast_ref().unwrap(); + binary::nested_array_to_page::(array, options, type_, nested) + }, + BinaryView => { + let array = array.as_any().downcast_ref().unwrap(); + binview::nested_array_to_page(array, options, type_, nested) + }, + Utf8View => { + let array = polars_compute::cast::cast(array, &BinaryView, Default::default()).unwrap(); + let array = array.as_any().downcast_ref().unwrap(); + binview::nested_array_to_page(array, options, type_, nested) + }, + UInt8 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt16 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt32 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt64 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int8 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int16 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int32 | Date32 | Time32(_) => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int64 | Date64 | Time64(_) | Timestamp(_, _) | Duration(_) => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Float32 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Float64 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Decimal(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int32, + values, + array.validity().cloned(), + ); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int64, + values, + array.validity().cloned(), + ); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else { + let size = decimal_length_from_precision(precision); + + let statistics = if options.has_statistics() { + let stats = fixed_size_binary::build_statistics_decimal( + array, + type_.clone(), + size, + &options.statistics, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) + } + }, + Decimal256(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i32()) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int32, + values, + array.validity().cloned(), + ); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i64()) + .collect::>() + .into(); + + let array = PrimitiveArray::::new( + ArrowDataType::Int64, + values, + array.validity().cloned(), + ); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 38 { + let size = decimal_length_from_precision(precision); + let statistics = if options.has_statistics() { + let stats = fixed_size_binary::build_statistics_decimal256_with_i128( + array, + type_.clone(), + size, + &options.statistics, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.0.low().to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) + } else { + let size = 32; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let statistics = if options.has_statistics() { + let stats = fixed_size_binary::build_statistics_decimal256( + array, + type_.clone(), + size, + &options.statistics, + ); + Some(stats) + } else { + None + }; + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes(); + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + + fixed_size_binary::nested_array_to_page(&array, options, type_, nested, statistics) + } + }, + other => polars_bail!(nyi = "Writing nested parquet pages for data type {other:?}"), + } + .map(Page::Data) +} + +fn transverse_recursive T + Clone>( + dtype: &ArrowDataType, + map: F, + encodings: &mut Vec, +) { + use arrow::datatypes::PhysicalType::*; + match dtype.to_physical_type() { + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => encodings.push(map(dtype)), + List | FixedSizeList | LargeList => { + let a = dtype.to_logical_type(); + if let ArrowDataType::List(inner) = a { + transverse_recursive(&inner.dtype, map, encodings) + } else if let ArrowDataType::LargeList(inner) = a { + transverse_recursive(&inner.dtype, map, encodings) + } else if let ArrowDataType::FixedSizeList(inner, _) = a { + transverse_recursive(&inner.dtype, map, encodings) + } else { + unreachable!() + } + }, + Struct => { + if let ArrowDataType::Struct(fields) = dtype.to_logical_type() { + for field in fields { + transverse_recursive(&field.dtype, map.clone(), encodings) + } + } else { + unreachable!() + } + }, + Map => { + if let ArrowDataType::Map(field, _) = dtype.to_logical_type() { + if let ArrowDataType::Struct(fields) = field.dtype.to_logical_type() { + for field in fields { + transverse_recursive(&field.dtype, map.clone(), encodings) + } + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + Union => todo!(), + } +} + +/// Transverses the `dtype` up to its (parquet) columns and returns a vector of +/// items based on `map`. +/// +/// This is used to assign an [`Encoding`] to every parquet column based on the columns' type (see example) +pub fn transverse T + Clone>(dtype: &ArrowDataType, map: F) -> Vec { + let mut encodings = vec![]; + transverse_recursive(dtype, map, &mut encodings); + encodings +} diff --git a/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs b/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs new file mode 100644 index 000000000000..961393bf4ed2 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/nested/dremel/mod.rs @@ -0,0 +1,422 @@ +//! Implements the Dremel encoding part of Parquet with *repetition-levels* and *definition-levels* + +use arrow::bitmap::Bitmap; +use arrow::offset::OffsetsBuffer; +use polars_utils::fixedringbuffer::FixedRingBuffer; + +use super::super::pages::Nested; + +#[cfg(test)] +mod tests; + +/// A Dremel encoding value +#[derive(Clone, Copy)] +pub struct DremelValue { + /// A *repetition-level* value + pub rep: u16, + /// A *definition-level* value + pub def: u16, +} + +/// This tries to mirror the Parquet Schema structures, so that is simple to reason about the +/// Dremel structures. +enum LevelContent<'a> { + /// Always 1 instance + Required, + /// Zero or more instances + Repeated, + /// Zero or one instance + Optional(Option<&'a Bitmap>), +} + +struct Level<'a> { + content: LevelContent<'a>, + /// "Iterator" with number of elements for the next level + lengths: LevelLength<'a>, + /// Remaining number of elements to process. NOTE: This is **not** equal to `length - offset`. + remaining: usize, + /// Offset into level elements + offset: usize, + /// The definition-level associated with this level + definition_depth: u16, + /// The repetition-level associated with this level + repetition_depth: u16, +} + +/// This contains the number of elements on the next level for each +enum LevelLength<'a> { + /// Fixed number of elements based on the validity of this element + Optional(usize), + /// Fixed number of elements irregardless of the validity of this element + Constant(usize), + /// Variable number of elements and calculated from the difference between two `i32` offsets + OffsetsI32(&'a OffsetsBuffer), + /// Variable number of elements and calculated from the difference between two `i64` offsets + OffsetsI64(&'a OffsetsBuffer), +} + +/// A iterator for Dremel *repetition* and *definition-levels* in Parquet +/// +/// This buffers many consequentive repetition and definition-levels as to not have to branch in +/// and out of this code constantly. +pub struct BufferedDremelIter<'a> { + buffer: FixedRingBuffer, + + levels: Box<[Level<'a>]>, + /// Current offset into `levels` that is being explored + current_level: usize, + + last_repetition: u16, +} + +/// return number values of the nested +pub fn num_values(nested: &[Nested]) -> usize { + // @TODO: Make this smarter + // + // This is not that smart because it is really slow, but not doing this would be: + // 1. Error prone + // 2. Repeat much of the logic that you find below + BufferedDremelIter::new(nested).count() +} + +impl Level<'_> { + /// Fetch the number of elements given on the next level at `offset` on this level + fn next_level_length(&self, offset: usize, is_valid: bool) -> usize { + match self.lengths { + LevelLength::Optional(n) if is_valid => n, + LevelLength::Optional(_) => 0, + LevelLength::Constant(n) => n, + LevelLength::OffsetsI32(n) => n.length_at(offset), + LevelLength::OffsetsI64(n) => n.length_at(offset), + } + } +} + +impl<'a> BufferedDremelIter<'a> { + // @NOTE: This can maybe just directly be gotten from the Field and array, this double + // conversion seems rather wasteful. + /// Create a new [`BufferedDremelIter`] from a set of nested structures + /// + /// This creates a structure that resembles (but is not exactly the same) the Parquet schema, + /// we can then iterate this quite well. + pub fn new(nested: &'a [Nested]) -> Self { + let mut levels = Vec::with_capacity(nested.len() * 2 - 1); + + let mut definition_depth = 0u16; + let mut repetition_depth = 0u16; + for n in nested { + match n { + Nested::Primitive(n) => { + let (content, lengths) = if n.is_optional { + definition_depth += 1; + ( + LevelContent::Optional(n.validity.as_ref()), + LevelLength::Optional(1), + ) + } else { + (LevelContent::Required, LevelLength::Constant(1)) + }; + + levels.push(Level { + content, + lengths, + remaining: n.length, + offset: 0, + definition_depth, + repetition_depth, + }); + }, + Nested::List(n) => { + if n.is_optional { + definition_depth += 1; + levels.push(Level { + content: LevelContent::Optional(n.validity.as_ref()), + lengths: LevelLength::Constant(1), + remaining: n.offsets.len_proxy(), + offset: 0, + definition_depth, + repetition_depth, + }); + } + + definition_depth += 1; + levels.push(Level { + content: LevelContent::Repeated, + lengths: LevelLength::OffsetsI32(&n.offsets), + remaining: n.offsets.len_proxy(), + offset: 0, + definition_depth, + repetition_depth, + }); + repetition_depth += 1; + }, + Nested::LargeList(n) => { + if n.is_optional { + definition_depth += 1; + levels.push(Level { + content: LevelContent::Optional(n.validity.as_ref()), + lengths: LevelLength::Constant(1), + remaining: n.offsets.len_proxy(), + offset: 0, + definition_depth, + repetition_depth, + }); + } + + definition_depth += 1; + levels.push(Level { + content: LevelContent::Repeated, + lengths: LevelLength::OffsetsI64(&n.offsets), + remaining: n.offsets.len_proxy(), + offset: 0, + definition_depth, + repetition_depth, + }); + repetition_depth += 1; + }, + Nested::FixedSizeList(n) => { + if n.is_optional { + definition_depth += 1; + levels.push(Level { + content: LevelContent::Optional(n.validity.as_ref()), + lengths: LevelLength::Constant(1), + remaining: n.length, + offset: 0, + definition_depth, + repetition_depth, + }); + } + + definition_depth += 1; + levels.push(Level { + content: LevelContent::Repeated, + lengths: LevelLength::Constant(n.width), + remaining: n.length, + offset: 0, + definition_depth, + repetition_depth, + }); + repetition_depth += 1; + }, + Nested::Struct(n) => { + let content = if n.is_optional { + definition_depth += 1; + LevelContent::Optional(n.validity.as_ref()) + } else { + LevelContent::Required + }; + + levels.push(Level { + content, + lengths: LevelLength::Constant(1), + remaining: n.length, + offset: 0, + definition_depth, + repetition_depth, + }); + }, + }; + } + + let levels = levels.into_boxed_slice(); + + Self { + // This size is rather arbitrary, but it seems good to make it not too, too high as to + // reduce memory consumption. + buffer: FixedRingBuffer::new(256), + + levels, + current_level: 0, + last_repetition: 0, + } + } + + /// Attempt to fill the rest to the buffer with as many values as possible + fn fill(&mut self) { + // First exit condition: + // If the buffer is full stop trying to fetch more values and just pop the first + // element in the buffer. + // + // Second exit condition: + // We have exhausted all elements at the final level, there are no elements left. + while !(self.buffer.is_full() || (self.current_level == 0 && self.levels[0].remaining == 0)) + { + if self.levels[self.current_level].remaining == 0 { + self.last_repetition = u16::min( + self.last_repetition, + self.levels[self.current_level - 1].repetition_depth, + ); + self.current_level -= 1; + continue; + } + + let ns = &mut self.levels; + let lvl = self.current_level; + + let is_last_nesting = ns.len() == self.current_level + 1; + + macro_rules! push_value { + ($def:expr) => { + self.buffer + .push(DremelValue { + rep: self.last_repetition, + def: $def, + }) + .unwrap(); + self.last_repetition = ns[lvl].repetition_depth; + }; + } + + let num_done = match (&ns[lvl].content, is_last_nesting) { + (LevelContent::Required | LevelContent::Optional(None), true) => { + push_value!(ns[lvl].definition_depth); + + 1 + self.buffer.fill_repeat( + DremelValue { + rep: self.last_repetition, + def: ns[lvl].definition_depth, + }, + ns[lvl].remaining - 1, + ) + }, + (LevelContent::Required, false) => { + self.current_level += 1; + ns[lvl + 1].remaining = ns[lvl].next_level_length(ns[lvl].offset, true); + 1 + }, + + (LevelContent::Optional(Some(validity)), true) => { + let num_possible = + usize::min(self.buffer.remaining_capacity(), ns[lvl].remaining); + + let validity = (*validity).clone().sliced(ns[lvl].offset, num_possible); + + // @NOTE: maybe, we can do something here with leading zeros + for is_valid in validity.iter() { + push_value!(ns[lvl].definition_depth - u16::from(!is_valid)); + } + + num_possible + }, + (LevelContent::Optional(None), false) => { + let num_possible = + usize::min(self.buffer.remaining_capacity(), ns[lvl].remaining); + let mut num_done = num_possible; + let def = ns[lvl].definition_depth; + + // @NOTE: maybe, we can do something here with leading zeros + for i in 0..num_possible { + let next_level_length = ns[lvl].next_level_length(ns[lvl].offset + i, true); + + if next_level_length == 0 { + // Zero-sized (fixed) lists + push_value!(def); + } else { + self.current_level += 1; + ns[lvl + 1].remaining = next_level_length; + num_done = i + 1; + break; + } + } + + num_done + }, + (LevelContent::Optional(Some(validity)), false) => { + let mut num_done = 0; + let num_possible = + usize::min(self.buffer.remaining_capacity(), ns[lvl].remaining); + let def = ns[lvl].definition_depth; + + let validity = (*validity).clone().sliced(ns[lvl].offset, num_possible); + + // @NOTE: we can do something here with trailing ones and trailing zeros + for is_valid in validity.iter() { + num_done += 1; + let next_level_length = + ns[lvl].next_level_length(ns[lvl].offset + num_done - 1, is_valid); + + match (is_valid, next_level_length) { + (true, 0) => { + // Zero-sized (fixed) lists + push_value!(def); + }, + (true, _) => { + self.current_level += 1; + ns[lvl + 1].remaining = next_level_length; + break; + }, + (false, 0) => { + push_value!(def - 1); + }, + (false, _) => { + ns[lvl + 1].remaining = next_level_length; + + // @NOTE: + // This is needed for structs and fixed-size lists. These will have + // a non-zero length even if they are invalid. In that case, we + // need to skip over all the elements that would have been read if + // it was valid. + let mut embed_lvl = lvl + 1; + 'embed: while embed_lvl > lvl { + if embed_lvl == ns.len() - 1 { + ns[embed_lvl].offset += ns[embed_lvl].remaining; + } else { + while ns[embed_lvl].remaining > 0 { + let length = ns[embed_lvl] + .next_level_length(ns[embed_lvl].offset, false); + + ns[embed_lvl].offset += 1; + ns[embed_lvl].remaining -= 1; + + if length > 0 { + ns[embed_lvl + 1].remaining = length; + embed_lvl += 1; + continue 'embed; + } + } + } + + embed_lvl -= 1; + } + + push_value!(def - 1); + }, + } + } + + num_done + }, + (LevelContent::Repeated, _) => { + debug_assert!(!is_last_nesting); + let length = ns[lvl].next_level_length(ns[lvl].offset, true); + + if length == 0 { + push_value!(ns[lvl].definition_depth - 1); + } else { + self.current_level += 1; + ns[lvl + 1].remaining = length; + } + + 1 + }, + }; + + ns[lvl].offset += num_done; + ns[lvl].remaining -= num_done; + } + } +} + +impl Iterator for BufferedDremelIter<'_> { + type Item = DremelValue; + + fn next(&mut self) -> Option { + // Use an item from the buffer if it is available + if let Some(item) = self.buffer.pop_front() { + return Some(item); + } + + self.fill(); + self.buffer.pop_front() + } +} diff --git a/crates/polars-parquet/src/arrow/write/nested/dremel/tests.rs b/crates/polars-parquet/src/arrow/write/nested/dremel/tests.rs new file mode 100644 index 000000000000..783d353543b0 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/nested/dremel/tests.rs @@ -0,0 +1,993 @@ +use super::*; + +mod def { + use super::*; + use crate::write::pages::{ListNested, PrimitiveNested, StructNested}; + + fn test(nested: Vec, expected: Vec) { + let mut iter = BufferedDremelIter::new(&nested).map(|d| d.def); + // assert_eq!(iter.size_hint().0, expected.len()); + let result = iter.by_ref().collect::>(); + assert_eq!(result, expected); + // assert_eq!(iter.size_hint().0, 0); + } + + #[test] + fn struct_dbl_optional() { + let a = [true, true, true, false, true, true]; + let b = [true, false, true, false, false, true]; + let nested = vec![ + Nested::Struct(StructNested { + is_optional: true, + validity: Some(a.into()), + length: 6, + }), + Nested::Primitive(PrimitiveNested { + validity: Some(b.into()), + is_optional: true, + length: 6, + }), + ]; + let expected = vec![2, 1, 2, 0, 1, 2]; + + test(nested, expected) + } + + #[test] + fn struct_optional() { + let b = [ + true, false, true, true, false, true, false, false, true, true, + ]; + let nested = vec![ + Nested::Struct(StructNested { + is_optional: true, + validity: None, + length: 10, + }), + Nested::Primitive(PrimitiveNested { + validity: Some(b.into()), + is_optional: true, + length: 10, + }), + ]; + let expected = vec![2, 1, 2, 2, 1, 2, 1, 1, 2, 2]; + + test(nested, expected) + } + + #[test] + fn nested_edge_simple() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: true, + length: 2, + }), + ]; + let expected = vec![3, 3]; + + test(nested, expected) + } + + #[test] + fn struct_optional_1() { + let b = [ + true, false, true, true, false, true, false, false, true, true, + ]; + let nested = vec![ + Nested::Struct(StructNested { + validity: None, + is_optional: true, + length: 10, + }), + Nested::Primitive(PrimitiveNested { + validity: Some(b.into()), + is_optional: true, + length: 10, + }), + ]; + let expected = vec![2, 1, 2, 2, 1, 2, 1, 1, 2, 2]; + + test(nested, expected) + } + + #[test] + fn struct_optional_optional() { + let nested = vec![ + Nested::Struct(StructNested { + is_optional: true, + validity: None, + length: 10, + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: true, + length: 10, + }), + ]; + let expected = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + test(nested, expected) + } + + #[test] + fn l1_required_required() { + let nested = vec![ + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 12, + }), + ]; + let expected = vec![1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1]; + + test(nested, expected) + } + + #[test] + fn l1_optional_optional() { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + + let v0 = [true, false, true, true, true, true, false, true]; + let v1 = [ + true, true, //[0, 1] + true, false, true, //[2, None, 3] + true, true, true, //[4, 5, 6] + true, true, true, //[7, 8, 9] + true, //[10] + ]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(v0.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: Some(v1.into()), + is_optional: true, + length: 12, + }), + ]; + let expected = vec![3, 3, 0, 3, 2, 3, 3, 3, 3, 1, 3, 3, 3, 0, 3]; + + test(nested, expected) + } + + #[test] + fn l2_required_required_required() { + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + [ + [8], + [9, 10] + ] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 10, + }), + ]; + let expected = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + test(nested, expected) + } + + #[test] + fn l2_optional_required_required() { + let a = [true, false, true, true]; + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + None, + [ + [8], + [], + [9, 10] + ] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 2, 5].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 10, + }), + ]; + let expected = vec![3, 3, 3, 3, 3, 3, 3, 0, 1, 3, 2, 3, 3]; + + test(nested, expected) + } + + mod fixedlist { + use super::*; + + #[test] + fn fsl() { + /* [ [ 1, 2 ], None, [ None, 3 ] ] */ + let a = [true, false, true]; + let b = [true, true, false, false, false, true]; + let nested = vec![ + Nested::fixed_size_list(Some(a.into()), true, 2, 3), + Nested::primitive(Some(b.into()), true, 6), + ]; + let expected = vec![3, 3, 0, 2, 3]; + + test(nested, expected) + } + + #[test] + fn fsl_fsl() { + // [ + // [ [ 1, 2, 3 ], [ 4, 5, 6 ] ], + // None, + // [ None, [ 7, None, 9 ] ], + // ] + let a = [true, false, true]; + let b = [true, true, true, true, false, true]; + let c = [ + true, true, true, true, true, true, false, false, false, false, false, false, + false, false, false, true, false, true, + ]; + let nested = vec![ + Nested::fixed_size_list(Some(a.into()), true, 2, 3), + Nested::fixed_size_list(Some(b.into()), true, 3, 6), + Nested::primitive(Some(c.into()), true, 18), + ]; + let expected = vec![5, 5, 5, 5, 5, 5, 0, 2, 5, 4, 5]; + + test(nested, expected) + } + + #[test] + fn fsl_fsl_1() { + // [ + // [ [1, 5, 2], [42, 13, 37] ], + // None, + // [ None, [3, 1, 3] ] + // ] + let a = [true, false, true]; + let b = [true, true, false, false, false, true]; + let c = [ + true, true, true, true, true, true, false, false, false, false, false, false, + false, false, false, true, true, true, + ]; + let nested = vec![ + Nested::fixed_size_list(Some(a.into()), true, 2, 3), + Nested::fixed_size_list(Some(b.into()), true, 3, 6), + Nested::primitive(Some(c.into()), true, 18), + ]; + let expected = vec![5, 5, 5, 5, 5, 5, 0, 2, 5, 5, 5]; + + test(nested, expected) + } + } + + mod simple { + use super::*; + + #[test] + fn none() { + /* [ None ] */ + let a = [false]; + let b = []; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 0, + }), + ]; + let expected = vec![0]; + + test(nested, expected) + } + + #[test] + fn empty() { + /* [ [ ] ] */ + let a = [true]; + let b = []; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 0, + }), + ]; + let expected = vec![1]; + + test(nested, expected) + } + + #[test] + fn list_none() { + /* [ [ None ] ] */ + let a = [true]; + let b = [false]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 0, + }), + ]; + let expected = vec![2]; + + test(nested, expected) + } + + #[test] + fn list_empty() { + /* [ [ [] ] ] */ + let a = [true]; + let b = [true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 0, + }), + ]; + let expected = vec![3]; + + test(nested, expected) + } + + #[test] + fn list_list_one() { + /* [ [ [ 1 ] ] ] */ + let a = [true]; + let b = [true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 1, + }), + ]; + let expected = vec![4]; + + test(nested, expected) + } + } + + #[test] + fn l2_optional_optional_required() { + let a = [true, false, true]; + let b = [true, true, true, true, false]; + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + None, + [ + [8], + [], + None, + ], + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 7, 8, 8, 8].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: false, + length: 8, + }), + ]; + let expected = vec![4, 4, 4, 4, 4, 4, 4, 0, 4, 3, 2]; + + test(nested, expected) + } + + #[test] + fn l2_optional_optional_optional() { + let a = [true, false, true]; + let b = [true, true, true, false]; + let c = [true, true, true, true, false, true, true, true]; + /* + [ + [ + [1,2,3], + [4,None,6,7], + ], + None, + [ + [8], + None, + ], + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 4].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 7, 8, 8].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: Some(c.into()), + is_optional: true, + length: 8, + }), + ]; + let expected = vec![5, 5, 5, 5, 4, 5, 5, 0, 5, 2]; + + test(nested, expected) + } + + /* + [{"a": "a"}, {"a": "b"}], + None, + [{"a": "b"}, None, {"a": "b"}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": "d"}, {"a": "d"}, {"a": "d"}], + None, + [{"a": "e"}], + */ + #[test] + fn nested_list_struct_nullable() { + let a = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let c = [true, false, true, true, true, true, false, true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Struct(StructNested { + validity: Some(b.into()), + is_optional: true, + length: 12, + }), + Nested::Primitive(PrimitiveNested { + validity: Some(a.into()), + is_optional: true, + length: 12, + }), + ]; + let expected = vec![4, 4, 0, 4, 2, 4, 3, 3, 3, 1, 4, 4, 4, 0, 4]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_nullable1() { + let c = [true, false]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Struct(StructNested { + validity: None, + is_optional: true, + length: 1, + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: true, + length: 1, + }), + ]; + let expected = vec![4, 0]; + + test(nested, expected) + } + + #[test] + fn nested_struct_list_nullable() { + // [ + // { "a": [] }, + // { "a", [] }, + // ] + let a = [true, false, true, true, true, true, false, true]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let nested = vec![ + Nested::Struct(StructNested { + validity: None, + is_optional: true, + length: 8, + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: Some(b.into()), + is_optional: true, + length: 12, + }), + ]; + let expected = vec![4, 4, 1, 4, 3, 4, 4, 4, 4, 2, 4, 4, 4, 1, 4]; + + test(nested, expected) + } + + #[test] + fn nested_struct_list_nullable1() { + let a = [true, true, false]; + let nested = vec![ + Nested::Struct(StructNested { + validity: None, + is_optional: true, + length: 3, + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1, 1].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: None, + is_optional: true, + length: 1, + }), + ]; + let expected = vec![4, 2, 1]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable1() { + /* + [ + [{"a": ["b"]}, None], + ] + */ + + let a = [true]; + let b = [true, false]; + let c = [true, false]; + let d = [true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Struct(StructNested { + validity: Some(b.into()), + is_optional: true, + length: 2, + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: Some(d.into()), + is_optional: true, + length: 1, + }), + ]; + /* + 0 6 + 1 6 + 0 0 + 0 6 + 1 2 + */ + let expected = vec![6, 2]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable() { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let a = [true, false, true, true, true, true, false, true]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let c = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ]; + let d = [true, true, true, true, true, false, true, true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Struct(StructNested { + validity: Some(b.into()), + is_optional: true, + length: 12, + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] + .try_into() + .unwrap(), + validity: Some(c.into()), + }), + Nested::Primitive(PrimitiveNested { + validity: Some(d.into()), + is_optional: true, + length: 8, + }), + ]; + let expected = vec![6, 6, 0, 6, 2, 6, 3, 3, 3, 1, 6, 5, 6, 6, 0, 4]; + + test(nested, expected) + } +} + +mod rep { + use super::super::super::super::pages::ListNested; + use super::*; + + fn test(nested: Vec, expected: Vec) { + let mut iter = BufferedDremelIter::new(&nested).map(|d| d.rep); + // assert_eq!(iter.size_hint().0, expected.len()); + assert_eq!(iter.by_ref().collect::>(), expected); + // assert_eq!(iter.size_hint().0, 0); + } + + #[test] + fn struct_required() { + let nested = vec![ + Nested::structure(None, false, 10), + Nested::primitive(None, true, 10), + ]; + let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + test(nested, expected) + } + + #[test] + fn struct_optional() { + let nested = vec![ + Nested::structure(None, true, 10), + Nested::primitive(None, true, 10), + ]; + let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + test(nested, expected) + } + + #[test] + fn l1() { + // [ + // [ 1, 2 ], + // [], + // [ 3, 4, 5 ], + // [ 6, 7, 8 ], + // [], + // [ 9, 10, 11 ], + // [], + // [ 12 ], + // ] + let nested = vec![ + Nested::list( + None, + false, + vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + ), + Nested::primitive(None, false, 12), + ]; + let expected = vec![0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0]; + + test(nested, expected) + } + + #[test] + fn l2() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::primitive(None, false, 10), + ]; + let expected = vec![0, 2, 2, 1, 2, 2, 2, 0, 0, 1, 2]; + + test(nested, expected) + } + + #[test] + fn list_of_struct() { + /* + [ + [{"a": "b"}],[{"a": "c"}] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2].try_into().unwrap(), + validity: None, + }), + Nested::structure(None, true, 2), + Nested::primitive(None, true, 2), + ]; + let expected = vec![0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 3].try_into().unwrap(), + validity: None, + }), + Nested::structure(None, true, 3), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 6, 7].try_into().unwrap(), + validity: None, + }), + Nested::primitive(None, true, 7), + ]; + let expected = vec![0, 2, 2, 1, 2, 2, 0]; + + test(nested, expected) + } + + #[test] + fn struct_list_optional() { + /* + {"f1": ["a", "b", None, "c"]} + */ + let nested = vec![ + Nested::structure(None, true, 1), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 4].try_into().unwrap(), + validity: None, + }), + Nested::primitive(None, true, 4), + ]; + let expected = vec![0, 1, 1, 1]; + + test(nested, expected) + } + + #[test] + fn l2_other() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 1, 1, 3, 5, 5, 8, 8, 9].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 4, 5, 7, 8, 9, 10, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::primitive(None, false, 12), + ]; + let expected = vec![0, 2, 0, 0, 2, 1, 0, 2, 1, 0, 0, 1, 1, 0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_1() { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + [], + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": []}, {"a": []}, {"a": []}], + [], + [{"a": ["d"]}, {"a": ["a"]}, {"a": ["c", "d"]}], + [], + [{"a": []}], + ] + // reps: [0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 2, 0, 0] + */ + let a = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::structure(Some(a.into()), true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 7, 8] + .try_into() + .unwrap(), + validity: None, + }), + Nested::primitive(None, true, 8), + ]; + let expected = vec![0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 2, 0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_2() { + /* + [ + [{"a": []}], + ] + // reps: [0] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: None, + }), + Nested::structure(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: None, + }), + Nested::primitive(None, true, 0), + ]; + let expected = vec![0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_3() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: None, + }), + Nested::structure(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: None, + }), + Nested::primitive(None, true, 0), + ]; + let expected = vec![0, 0]; + // [1, 0], [0] + // pick last + + test(nested, expected) + } +} diff --git a/crates/polars-parquet/src/arrow/write/nested/mod.rs b/crates/polars-parquet/src/arrow/write/nested/mod.rs new file mode 100644 index 000000000000..bee67bbb4bae --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/nested/mod.rs @@ -0,0 +1,109 @@ +mod dremel; + +pub use dremel::num_values; +use polars_error::PolarsResult; + +use super::Nested; +use crate::parquet::encoding::hybrid_rle::encode; +use crate::parquet::read::levels::get_bit_width; +use crate::parquet::write::Version; + +fn write_levels_v1) -> PolarsResult<()>>( + buffer: &mut Vec, + encode: F, +) -> PolarsResult<()> { + buffer.extend_from_slice(&[0; 4]); + let start = buffer.len(); + + encode(buffer)?; + + let end = buffer.len(); + let length = end - start; + + // write the first 4 bytes as length + let length = (length as i32).to_le_bytes(); + (0..4).for_each(|i| buffer[start - 4 + i] = length[i]); + Ok(()) +} + +/// writes the rep levels to a `Vec`. +fn write_rep_levels(buffer: &mut Vec, nested: &[Nested], version: Version) -> PolarsResult<()> { + let max_level = max_rep_level(nested) as i16; + if max_level == 0 { + return Ok(()); + } + let num_bits = get_bit_width(max_level); + + let levels = dremel::BufferedDremelIter::new(nested).map(|d| u32::from(d.rep)); + + match version { + Version::V1 => { + write_levels_v1(buffer, |buffer: &mut Vec| { + encode::(buffer, levels, num_bits)?; + Ok(()) + })?; + }, + Version::V2 => { + encode::(buffer, levels, num_bits)?; + }, + } + + Ok(()) +} + +/// writes the def levels to a `Vec`. +fn write_def_levels(buffer: &mut Vec, nested: &[Nested], version: Version) -> PolarsResult<()> { + let max_level = max_def_level(nested) as i16; + if max_level == 0 { + return Ok(()); + } + let num_bits = get_bit_width(max_level); + + let levels = dremel::BufferedDremelIter::new(nested).map(|d| u32::from(d.def)); + + match version { + Version::V1 => write_levels_v1(buffer, move |buffer: &mut Vec| { + encode::(buffer, levels, num_bits)?; + Ok(()) + }), + Version::V2 => Ok(encode::(buffer, levels, num_bits)?), + } +} + +fn max_def_level(nested: &[Nested]) -> usize { + nested + .iter() + .map(|nested| match nested { + Nested::Primitive(nested) => nested.is_optional as usize, + Nested::List(nested) => 1 + (nested.is_optional as usize), + Nested::LargeList(nested) => 1 + (nested.is_optional as usize), + Nested::Struct(nested) => nested.is_optional as usize, + Nested::FixedSizeList(nested) => 1 + nested.is_optional as usize, + }) + .sum() +} + +fn max_rep_level(nested: &[Nested]) -> usize { + nested + .iter() + .map(|nested| match nested { + Nested::FixedSizeList(_) | Nested::LargeList(_) | Nested::List(_) => 1, + Nested::Primitive(_) | Nested::Struct(_) => 0, + }) + .sum() +} + +/// Write `repetition_levels` and `definition_levels` to buffer. +pub fn write_rep_and_def( + page_version: Version, + nested: &[Nested], + buffer: &mut Vec, +) -> PolarsResult<(usize, usize)> { + write_rep_levels(buffer, nested, page_version)?; + let repetition_levels_byte_length = buffer.len(); + + write_def_levels(buffer, nested, page_version)?; + let definition_levels_byte_length = buffer.len() - repetition_levels_byte_length; + + Ok((repetition_levels_byte_length, definition_levels_byte_length)) +} diff --git a/crates/polars-parquet/src/arrow/write/pages.rs b/crates/polars-parquet/src/arrow/write/pages.rs new file mode 100644 index 000000000000..b3e6bfb0f08e --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/pages.rs @@ -0,0 +1,964 @@ +use std::fmt::Debug; + +use arrow::array::{Array, FixedSizeListArray, ListArray, MapArray, StructArray}; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::PhysicalType; +use arrow::offset::{Offset, OffsetsBuffer}; +use polars_error::{PolarsResult, polars_bail}; + +use super::{Encoding, WriteOptions, array_to_pages}; +use crate::arrow::read::schema::is_nullable; +use crate::parquet::page::Page; +use crate::parquet::schema::types::{ParquetType, PrimitiveType as ParquetPrimitiveType}; +use crate::write::DynIter; + +#[derive(Debug, Clone, PartialEq)] +pub struct PrimitiveNested { + pub is_optional: bool, + pub validity: Option, + pub length: usize, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ListNested { + pub is_optional: bool, + pub offsets: OffsetsBuffer, + pub validity: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct FixedSizeListNested { + pub validity: Option, + pub is_optional: bool, + pub width: usize, + pub length: usize, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct StructNested { + pub is_optional: bool, + pub validity: Option, + pub length: usize, +} + +impl ListNested { + pub fn new(offsets: OffsetsBuffer, validity: Option, is_optional: bool) -> Self { + Self { + is_optional, + offsets, + validity, + } + } +} + +/// Descriptor of nested information of a field +#[derive(Debug, Clone, PartialEq)] +pub enum Nested { + /// a primitive (leaf or parquet column) + Primitive(PrimitiveNested), + List(ListNested), + LargeList(ListNested), + FixedSizeList(FixedSizeListNested), + Struct(StructNested), +} + +impl Nested { + /// Returns the length (number of rows) of the element + pub fn len(&self) -> usize { + match self { + Nested::Primitive(nested) => nested.length, + Nested::List(nested) => nested.offsets.len_proxy(), + Nested::LargeList(nested) => nested.offsets.len_proxy(), + Nested::FixedSizeList(nested) => nested.length, + Nested::Struct(nested) => nested.length, + } + } + + pub fn primitive(validity: Option, is_optional: bool, length: usize) -> Self { + Self::Primitive(PrimitiveNested { + validity, + is_optional, + length, + }) + } + + pub fn list(validity: Option, is_optional: bool, offsets: OffsetsBuffer) -> Self { + Self::List(ListNested { + validity, + is_optional, + offsets, + }) + } + + pub fn large_list( + validity: Option, + is_optional: bool, + offsets: OffsetsBuffer, + ) -> Self { + Self::LargeList(ListNested { + validity, + is_optional, + offsets, + }) + } + + pub fn fixed_size_list( + validity: Option, + is_optional: bool, + width: usize, + length: usize, + ) -> Self { + Self::FixedSizeList(FixedSizeListNested { + validity, + is_optional, + width, + length, + }) + } + + pub fn structure(validity: Option, is_optional: bool, length: usize) -> Self { + Self::Struct(StructNested { + validity, + is_optional, + length, + }) + } +} + +/// Constructs the necessary `Vec>` to write the rep and def levels of `array` to parquet +pub fn to_nested(array: &dyn Array, type_: &ParquetType) -> PolarsResult>> { + let mut nested = vec![]; + + to_nested_recursive(array, type_, &mut nested, vec![])?; + Ok(nested) +} + +fn to_nested_recursive( + array: &dyn Array, + type_: &ParquetType, + nested: &mut Vec>, + mut parents: Vec, +) -> PolarsResult<()> { + let is_optional = is_nullable(type_.get_field_info()); + + use PhysicalType::*; + match array.dtype().to_physical_type() { + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + let fields = if let ParquetType::GroupType { fields, .. } = type_ { + fields + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a struct array", + ) + }; + + parents.push(Nested::Struct(StructNested { + is_optional, + validity: array.validity().cloned(), + length: array.len(), + })); + + for (type_, array) in fields.iter().zip(array.values()) { + to_nested_recursive(array.as_ref(), type_, nested, parents.clone())?; + } + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array", + ) + } + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array", + ) + }; + + parents.push(Nested::FixedSizeList(FixedSizeListNested { + validity: array.validity().cloned(), + length: array.len(), + width: array.size(), + is_optional, + })); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array", + ) + } + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array", + ) + }; + + parents.push(Nested::List(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array", + ) + } + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array", + ) + }; + + parents.push(Nested::LargeList(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a map array", + ) + } + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a map array", + ) + }; + + parents.push(Nested::List(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.field().as_ref(), type_, nested, parents)?; + }, + _ => { + parents.push(Nested::Primitive(PrimitiveNested { + validity: array.validity().cloned(), + is_optional, + length: array.len(), + })); + nested.push(parents) + }, + } + Ok(()) +} + +fn expand_list_validity<'a, O: Offset>( + array: &'a ListArray, + validity: BitmapState, + array_stack: &mut Vec<(&'a dyn Array, BitmapState)>, +) { + let BitmapState::SomeSet(list_validity) = validity else { + array_stack.push(( + array.values().as_ref(), + match validity { + BitmapState::AllSet => BitmapState::AllSet, + BitmapState::SomeSet(_) => unreachable!(), + BitmapState::AllUnset(_) => BitmapState::AllUnset(array.values().len()), + }, + )); + return; + }; + + let offsets = array.offsets().buffer(); + let mut validity = MutableBitmap::with_capacity(array.values().len()); + let mut list_validity_iter = list_validity.iter(); + + // @NOTE: We need to take into account here that the list might only point to a slice of the + // values, therefore we need to extend the validity mask with dummy values to match the length + // of the values array. + + let mut idx = 0; + validity.extend_constant(offsets[0].to_usize(), false); + while list_validity_iter.num_remaining() > 0 { + let num_ones = list_validity_iter.take_leading_ones(); + let num_elements = offsets[idx + num_ones] - offsets[idx]; + validity.extend_constant(num_elements.to_usize(), true); + + idx += num_ones; + + let num_zeros = list_validity_iter.take_leading_zeros(); + let num_elements = offsets[idx + num_zeros] - offsets[idx]; + validity.extend_constant(num_elements.to_usize(), false); + + idx += num_zeros; + } + validity.extend_constant(array.values().len() - validity.len(), false); + + debug_assert_eq!(idx, array.len()); + let validity = validity.freeze(); + + debug_assert_eq!(validity.len(), array.values().len()); + array_stack.push((array.values().as_ref(), BitmapState::SomeSet(validity))); +} + +#[derive(Clone)] +enum BitmapState { + AllSet, + SomeSet(Bitmap), + AllUnset(usize), +} + +impl From> for BitmapState { + fn from(bm: Option<&Bitmap>) -> Self { + let Some(bm) = bm else { + return Self::AllSet; + }; + + let null_count = bm.unset_bits(); + + if null_count == 0 { + Self::AllSet + } else if null_count == bm.len() { + Self::AllUnset(bm.len()) + } else { + Self::SomeSet(bm.clone()) + } + } +} + +impl From for Option { + fn from(bms: BitmapState) -> Self { + match bms { + BitmapState::AllSet => None, + BitmapState::SomeSet(bm) => Some(bm), + BitmapState::AllUnset(len) => Some(Bitmap::new_zeroed(len)), + } + } +} + +impl std::ops::BitAnd for &BitmapState { + type Output = BitmapState; + + fn bitand(self, rhs: Self) -> Self::Output { + use BitmapState as B; + match (self, rhs) { + (B::AllSet, B::AllSet) => B::AllSet, + (B::AllSet, B::SomeSet(v)) | (B::SomeSet(v), B::AllSet) => B::SomeSet(v.clone()), + (B::SomeSet(lhs), B::SomeSet(rhs)) => { + let result = lhs & rhs; + let null_count = result.unset_bits(); + + if null_count == 0 { + B::AllSet + } else if null_count == result.len() { + B::AllUnset(result.len()) + } else { + B::SomeSet(result) + } + }, + (B::AllUnset(len), _) | (_, B::AllUnset(len)) => B::AllUnset(*len), + } + } +} + +/// Convert [`Array`] to a `Vec>` leaves in DFS order. +/// +/// Each leaf array has the validity propagated from the nesting levels above. +pub fn to_leaves(array: &dyn Array, leaves: &mut Vec>) { + use PhysicalType as P; + + leaves.clear(); + let mut array_stack: Vec<(&dyn Array, BitmapState)> = Vec::new(); + + array_stack.push((array, BitmapState::AllSet)); + + while let Some((array, inherited_validity)) = array_stack.pop() { + let child_validity = BitmapState::from(array.validity()); + let validity = (&child_validity) & (&inherited_validity); + + match array.dtype().to_physical_type() { + P::Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + + leaves.reserve(array.len().saturating_sub(1)); + array + .values() + .iter() + .rev() + .for_each(|field| array_stack.push((field.as_ref(), validity.clone()))); + }, + P::List => { + let array = array.as_any().downcast_ref::>().unwrap(); + expand_list_validity(array, validity, &mut array_stack); + }, + P::LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + expand_list_validity(array, validity, &mut array_stack); + }, + P::FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + + let BitmapState::SomeSet(fsl_validity) = validity else { + array_stack.push(( + array.values().as_ref(), + match validity { + BitmapState::AllSet => BitmapState::AllSet, + BitmapState::SomeSet(_) => unreachable!(), + BitmapState::AllUnset(_) => BitmapState::AllUnset(array.values().len()), + }, + )); + continue; + }; + + let num_values = array.values().len(); + let size = array.size(); + + let mut validity = MutableBitmap::with_capacity(num_values); + let mut fsl_validity_iter = fsl_validity.iter(); + + let mut idx = 0; + while fsl_validity_iter.num_remaining() > 0 { + let num_ones = fsl_validity_iter.take_leading_ones(); + let num_elements = num_ones * size; + validity.extend_constant(num_elements, true); + + idx += num_ones; + + let num_zeros = fsl_validity_iter.take_leading_zeros(); + let num_elements = num_zeros * size; + validity.extend_constant(num_elements, false); + + idx += num_zeros; + } + + debug_assert_eq!(idx, array.len()); + + let validity = BitmapState::SomeSet(validity.freeze()); + + array_stack.push((array.values().as_ref(), validity)); + }, + P::Map => { + let array = array.as_any().downcast_ref::().unwrap(); + array_stack.push((array.field().as_ref(), validity)); + }, + P::Null + | P::Boolean + | P::Primitive(_) + | P::Binary + | P::FixedSizeBinary + | P::LargeBinary + | P::Utf8 + | P::LargeUtf8 + | P::Dictionary(_) + | P::BinaryView + | P::Utf8View => { + leaves.push(array.with_validity(validity.into())); + }, + + other => todo!("Writing {:?} to parquet not yet implemented", other), + } + } +} + +/// Convert `ParquetType` to `Vec` leaves in DFS order. +pub fn to_parquet_leaves(type_: ParquetType) -> Vec { + let mut leaves = vec![]; + to_parquet_leaves_recursive(type_, &mut leaves); + leaves +} + +fn to_parquet_leaves_recursive(type_: ParquetType, leaves: &mut Vec) { + match type_ { + ParquetType::PrimitiveType(primitive) => leaves.push(primitive), + ParquetType::GroupType { fields, .. } => { + fields + .into_iter() + .for_each(|type_| to_parquet_leaves_recursive(type_, leaves)); + }, + } +} + +/// Returns a vector of iterators of [`Page`], one per leaf column in the array +pub fn array_to_columns + Send + Sync>( + array: A, + type_: ParquetType, + options: WriteOptions, + encoding: &[Encoding], +) -> PolarsResult>>> { + let array = array.as_ref(); + + let nested = to_nested(array, &type_)?; + + let types = to_parquet_leaves(type_); + + let mut values = Vec::new(); + to_leaves(array, &mut values); + + assert_eq!(encoding.len(), types.len()); + + values + .iter() + .zip(nested) + .zip(types) + .zip(encoding.iter()) + .map(|(((values, nested), type_), encoding)| { + array_to_pages(values.as_ref(), type_, &nested, options, *encoding) + }) + .collect() +} + +pub fn arrays_to_columns + Send + Sync>( + arrays: &[A], + type_: ParquetType, + options: WriteOptions, + encoding: &[Encoding], +) -> PolarsResult>>> { + let array = arrays[0].as_ref(); + let nested = to_nested(array, &type_)?; + + let types = to_parquet_leaves(type_); + + // leaves; index level is nesting depth. + // index i: has a vec because we have multiple chunks. + let mut leaves = vec![]; + + // Ensure we transpose the leaves. So that all the leaves from the same columns are at the same level vec. + let mut scratch = vec![]; + for arr in arrays { + to_leaves(arr.as_ref(), &mut scratch); + for (i, leave) in std::mem::take(&mut scratch).into_iter().enumerate() { + while i < leaves.len() { + leaves.push(vec![]); + } + leaves[i].push(leave); + } + } + + leaves + .into_iter() + .zip(nested) + .zip(types) + .zip(encoding.iter()) + .map(move |(((values, nested), type_), encoding)| { + let iter = values.into_iter().map(|leave_values| { + array_to_pages( + leave_values.as_ref(), + type_.clone(), + &nested, + options, + *encoding, + ) + }); + + // Need a scratch to bubble up the error :/ + let mut scratch = Vec::with_capacity(iter.size_hint().0); + for v in iter { + scratch.push(v?) + } + Ok(DynIter::new(scratch.into_iter().flatten())) + }) + .collect::>>() +} + +#[cfg(test)] +mod tests { + use arrow::array::*; + use arrow::datatypes::*; + + use super::super::{FieldInfo, ParquetPhysicalType}; + use super::*; + use crate::parquet::schema::Repetition; + use crate::parquet::schema::types::{ + GroupLogicalType, PrimitiveConvertedType, PrimitiveLogicalType, + }; + + #[test] + fn test_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + 4, + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".into(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + vec![ + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_struct_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + 4, + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let fields = vec![ + Field::new("b".into(), array.dtype().clone(), true), + Field::new("c".into(), array.dtype().clone(), true), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + 4, + vec![Box::new(array.clone()), Box::new(array)], + None, + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".into(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_.clone(), type_], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + // a.b.b + vec![ + Nested::structure(None, false, 4), + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + // a.b.c + vec![ + Nested::structure(None, false, 4), + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + // a.c.b + vec![ + Nested::structure(None, false, 4), + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + // a.c.c + vec![ + Nested::structure(None, false, 4), + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_list_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + 4, + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let array = ListArray::new( + ArrowDataType::List(Box::new(Field::new( + "l".into(), + array.dtype().clone(), + true, + ))), + vec![0i32, 2, 4].try_into().unwrap(), + Box::new(array), + None, + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".into(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "l".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ParquetType::GroupType { + field_info: FieldInfo { + name: "list".into(), + repetition: Repetition::Repeated, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_], + }], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::structure(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_map() { + let kv_type = ArrowDataType::Struct(vec![ + Field::new("k".into(), ArrowDataType::Utf8, false), + Field::new("v".into(), ArrowDataType::Int32, false), + ]); + let kv_field = Field::new("kv".into(), kv_type.clone(), false); + let map_type = ArrowDataType::Map(Box::new(kv_field), false); + + let key_array = Utf8Array::::from_slice(["k1", "k2", "k3", "k4", "k5", "k6"]).boxed(); + let val_array = Int32Array::from_slice([42, 28, 19, 31, 21, 17]).boxed(); + let kv_array = StructArray::try_new(kv_type, 6, vec![key_array, val_array], None) + .unwrap() + .boxed(); + let offsets = OffsetsBuffer::try_from(vec![0, 2, 3, 4, 6]).unwrap(); + + let array = MapArray::try_new(map_type, offsets, kv_array, None).unwrap(); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "kv".into(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "k".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: Some(PrimitiveLogicalType::String), + converted_type: Some(PrimitiveConvertedType::Utf8), + physical_type: ParquetPhysicalType::ByteArray, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "v".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "m".into(), + repetition: Repetition::Required, + id: None, + }, + logical_type: Some(GroupLogicalType::Map), + converted_type: None, + fields: vec![ParquetType::GroupType { + field_info: FieldInfo { + name: "map".into(), + repetition: Repetition::Repeated, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_], + }], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 3, 4, 6].try_into().unwrap(), + validity: None, + }), + Nested::structure(None, true, 6), + Nested::primitive(None, false, 6), + ], + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 3, 4, 6].try_into().unwrap(), + validity: None, + }), + Nested::structure(None, true, 6), + Nested::primitive(None, false, 6), + ], + ] + ); + } +} diff --git a/crates/polars-parquet/src/arrow/write/primitive/basic.rs b/crates/polars-parquet/src/arrow/write/primitive/basic.rs new file mode 100644 index 000000000000..9f113243279d --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/primitive/basic.rs @@ -0,0 +1,243 @@ +use arrow::array::{Array, PrimitiveArray}; +use arrow::scalar::PrimitiveScalar; +use arrow::types::NativeType; +use polars_error::{PolarsResult, polars_bail}; + +use super::super::{WriteOptions, utils}; +use crate::arrow::read::schema::is_nullable; +use crate::arrow::write::utils::ExactSizedIter; +use crate::parquet::encoding::Encoding; +use crate::parquet::encoding::delta_bitpacked::encode; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::PrimitiveStatistics; +use crate::parquet::types::NativeType as ParquetNativeType; +use crate::read::Page; +use crate::write::{EncodeNullability, StatisticsOptions}; + +pub(crate) fn encode_plain( + array: &PrimitiveArray, + options: EncodeNullability, + mut buffer: Vec, +) -> Vec +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive

, +{ + let is_optional = options.is_optional(); + + if is_optional { + // append the non-null values + let validity = array.validity(); + + if let Some(validity) = validity { + let null_count = validity.unset_bits(); + + if null_count > 0 { + let mut iter = validity.iter(); + let values = array.values().as_slice(); + + buffer.reserve(size_of::() * (array.len() - null_count)); + + let mut offset = 0; + let mut remaining_valid = array.len() - null_count; + while remaining_valid > 0 { + let num_valid = iter.take_leading_ones(); + buffer.extend( + values[offset..offset + num_valid] + .iter() + .flat_map(|value| value.as_().to_le_bytes()), + ); + remaining_valid -= num_valid; + offset += num_valid; + + let num_invalid = iter.take_leading_zeros(); + offset += num_invalid; + } + + return buffer; + } + } + } + + buffer.reserve(size_of::

() * array.len()); + buffer.extend( + array + .values() + .iter() + .flat_map(|value| value.as_().to_le_bytes()), + ); + + buffer +} + +pub(crate) fn encode_delta( + array: &PrimitiveArray, + options: EncodeNullability, + mut buffer: Vec, +) -> Vec +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive

, + P: num_traits::AsPrimitive, +{ + let is_optional = options.is_optional(); + + if is_optional { + // append the non-null values + let iterator = array.non_null_values_iter().map(|x| { + let parquet_native: P = x.as_(); + let integer: i64 = parquet_native.as_(); + integer + }); + let iterator = ExactSizedIter::new(iterator, array.len() - array.null_count()); + encode(iterator, &mut buffer, 1) + } else { + // append all values + let iterator = array.values().iter().map(|x| { + let parquet_native: P = x.as_(); + let integer: i64 = parquet_native.as_(); + integer + }); + encode(iterator, &mut buffer, 1) + } + buffer +} + +pub fn array_to_page_plain( + array: &PrimitiveArray, + options: WriteOptions, + type_: PrimitiveType, +) -> PolarsResult +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive

, +{ + array_to_page(array, options, type_, Encoding::Plain, encode_plain) +} + +pub fn array_to_page_integer( + array: &PrimitiveArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> PolarsResult +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive

, + P: num_traits::AsPrimitive, +{ + match encoding { + Encoding::Plain => array_to_page(array, options, type_, encoding, encode_plain), + Encoding::DeltaBinaryPacked => array_to_page(array, options, type_, encoding, encode_delta), + other => polars_bail!(nyi = "Encoding integer as {other:?}"), + } + .map(Page::Data) +} + +pub fn array_to_page, EncodeNullability, Vec) -> Vec>( + array: &PrimitiveArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, + encode: F, +) -> PolarsResult +where + T: NativeType, + P: ParquetNativeType, + // constraint required to build statistics + T: num_traits::AsPrimitive

, +{ + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + let buffer = encode(array, encode_options, buffer); + + let statistics = if options.has_statistics() { + Some(build_statistics(array, type_.clone(), &options.statistics).serialize()) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +} + +pub fn build_statistics( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + options: &StatisticsOptions, +) -> PrimitiveStatistics

+where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive

, +{ + let (min_value, max_value) = match (options.min_value, options.max_value) { + (true, true) => { + match polars_compute::min_max::dyn_array_min_max_propagate_nan(array as &dyn Array) { + None => (None, None), + Some((l, r)) => (Some(l), Some(r)), + } + }, + (true, false) => ( + polars_compute::min_max::dyn_array_min_propagate_nan(array as &dyn Array), + None, + ), + (false, true) => ( + None, + polars_compute::min_max::dyn_array_max_propagate_nan(array as &dyn Array), + ), + (false, false) => (None, None), + }; + + let min_value = min_value.and_then(|s| { + s.as_any() + .downcast_ref::>() + .unwrap() + .value() + .map(|x| x.as_()) + }); + let max_value = max_value.and_then(|s| { + s.as_any() + .downcast_ref::>() + .unwrap() + .value() + .map(|x| x.as_()) + }); + + PrimitiveStatistics::

(&self, mut pred: P) -> usize + where + P: FnMut(&T) -> bool, + { + self.exponential_search_by(|x| if pred(x) { Less } else { Greater }) + .unwrap_or_else(|i| i) + } +} + +impl ExponentialSearch for &[T] { + fn exponential_search_by(&self, mut f: F) -> Result + where + F: FnMut(&T) -> Ordering, + { + if self.is_empty() { + return Err(0); + } + + let mut bound = 1; + + while bound < self.len() { + // SAFETY + // Bound is always >=0 and < len. + let cmp = f(unsafe { self.get_unchecked(bound) }); + + if cmp == Greater { + break; + } + bound *= 2 + } + let end_bound = std::cmp::min(self.len(), bound); + // SAFETY: + // We checked the end bound and previous bound was within slice as per the `while` condition. + let prev_bound = bound / 2; + + let slice = unsafe { self.get_unchecked(prev_bound..end_bound) }; + + match slice.binary_search_by(f) { + Ok(i) => Ok(i + prev_bound), + Err(i) => Err(i + prev_bound), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_partition_point() { + let v = [1, 2, 3, 3, 5, 6, 7]; + let i = v.as_slice().partition_point_exponential(|&x| x < 5); + assert_eq!(i, 4); + } +} diff --git a/crates/polars-utils/src/cache.rs b/crates/polars-utils/src/cache.rs new file mode 100644 index 000000000000..268def2a2b51 --- /dev/null +++ b/crates/polars-utils/src/cache.rs @@ -0,0 +1,226 @@ +use std::borrow::Borrow; +use std::hash::{BuildHasher, Hash}; + +use foldhash::fast::RandomState; +use hashbrown::HashTable; +use hashbrown::hash_table::Entry; +use slotmap::{Key, SlotMap, new_key_type}; + +/// A cached function that use `LruCache`. +pub struct LruCachedFunc { + func: F, + cache: LruCache, +} + +impl LruCachedFunc +where + F: FnMut(T) -> R, + T: std::hash::Hash + Eq + Clone, + R: Copy, +{ + pub fn new(func: F, size: usize) -> Self { + Self { + func, + cache: LruCache::with_capacity(size.max(1)), + } + } + + pub fn eval(&mut self, x: T, use_cache: bool) -> R { + if use_cache { + *self + .cache + .get_or_insert_with(&x, |xr| (self.func)(xr.clone())) + } else { + (self.func)(x) + } + } +} + +new_key_type! { + struct LruKey; +} + +pub struct LruCache { + table: HashTable, + elements: SlotMap>, + max_capacity: usize, + most_recent: LruKey, + least_recent: LruKey, + build_hasher: S, +} + +struct LruEntry { + key: K, + value: V, + list: LruListNode, +} + +#[derive(Copy, Clone, Default)] +struct LruListNode { + more_recent: LruKey, + less_recent: LruKey, +} + +impl LruCache { + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacity_and_hasher(capacity, RandomState::default()) + } +} + +impl LruCache { + pub fn with_capacity_and_hasher(max_capacity: usize, build_hasher: S) -> Self { + assert!(max_capacity > 0); + Self { + // Allocate one more capacity to prevent double-lookup or realloc + // when doing get_or_insert when full. + table: HashTable::with_capacity(max_capacity + 1), + elements: SlotMap::with_capacity_and_key(max_capacity + 1), + max_capacity, + most_recent: LruKey::null(), + least_recent: LruKey::null(), + build_hasher, + } + } +} + +impl LruCache { + fn lru_list_unlink(&mut self, lru_key: LruKey) { + let list = self.elements[lru_key].list; + if let Some(more_recent) = self.elements.get_mut(list.more_recent) { + more_recent.list.less_recent = list.less_recent; + } else { + self.most_recent = list.less_recent; + } + if let Some(less_recent) = self.elements.get_mut(list.less_recent) { + less_recent.list.more_recent = list.more_recent; + } else { + self.least_recent = list.more_recent; + } + } + + fn lru_list_insert_mru(&mut self, lru_key: LruKey) { + let prev_most_recent_key = self.most_recent; + self.most_recent = lru_key; + if let Some(prev_most_recent) = self.elements.get_mut(prev_most_recent_key) { + prev_most_recent.list.more_recent = lru_key; + } else { + self.least_recent = lru_key; + } + let list = &mut self.elements[lru_key].list; + list.more_recent = LruKey::null(); + list.less_recent = prev_most_recent_key; + } + + pub fn pop_lru(&mut self) -> Option<(K, V)> { + if self.elements.is_empty() { + return None; + } + let lru_key = self.least_recent; + let hash = self.build_hasher.hash_one(&self.elements[lru_key].key); + self.lru_list_unlink(lru_key); + let lru_entry = self.elements.remove(lru_key).unwrap(); + self.table + .find_entry(hash, |k| *k == lru_key) + .unwrap() + .remove(); + Some((lru_entry.key, lru_entry.value)) + } + + pub fn get(&mut self, key: &Q) -> Option<&V> + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let hash = self.build_hasher.hash_one(key); + let lru_key = *self + .table + .find(hash, |lru_key| self.elements[*lru_key].key.borrow() == key)?; + self.lru_list_unlink(lru_key); + self.lru_list_insert_mru(lru_key); + let lru_node = self.elements.get(lru_key).unwrap(); + Some(&lru_node.value) + } + + /// Returns the old value, if any. + pub fn insert(&mut self, key: K, value: V) -> Option { + let hash = self.build_hasher.hash_one(&key); + match self.table.entry( + hash, + |lru_key| self.elements[*lru_key].key == key, + |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key), + ) { + Entry::Occupied(o) => { + let lru_key = *o.get(); + self.lru_list_unlink(lru_key); + self.lru_list_insert_mru(lru_key); + Some(core::mem::replace(&mut self.elements[lru_key].value, value)) + }, + + Entry::Vacant(v) => { + let lru_entry = LruEntry { + key, + value, + list: LruListNode::default(), + }; + let lru_key = self.elements.insert(lru_entry); + v.insert(lru_key); + self.lru_list_insert_mru(lru_key); + if self.elements.len() > self.max_capacity { + self.pop_lru(); + } + None + }, + } + } + + pub fn get_or_insert_with V>(&mut self, key: &Q, f: F) -> &mut V + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + enum Never {} + let Ok(ret) = self.try_get_or_insert_with::(key, |k| Ok(f(k))); + ret + } + + pub fn try_get_or_insert_with Result>( + &mut self, + key: &Q, + f: F, + ) -> Result<&mut V, E> + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + let hash = self.build_hasher.hash_one(key); + match self.table.entry( + hash, + |lru_key| self.elements[*lru_key].key.borrow() == key, + |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key), + ) { + Entry::Occupied(o) => { + let lru_key = *o.get(); + if lru_key != self.most_recent { + self.lru_list_unlink(lru_key); + self.lru_list_insert_mru(lru_key); + } + Ok(&mut self.elements[lru_key].value) + }, + + Entry::Vacant(v) => { + let lru_entry = LruEntry { + value: f(key)?, + key: key.to_owned(), + list: LruListNode::default(), + }; + let lru_key = self.elements.insert(lru_entry); + v.insert(lru_key); + self.lru_list_insert_mru(lru_key); + if self.elements.len() > self.max_capacity { + self.pop_lru(); + } + Ok(&mut self.elements[lru_key].value) + }, + } + } +} diff --git a/crates/polars-utils/src/cardinality_sketch.rs b/crates/polars-utils/src/cardinality_sketch.rs new file mode 100644 index 000000000000..bfac7b297214 --- /dev/null +++ b/crates/polars-utils/src/cardinality_sketch.rs @@ -0,0 +1,74 @@ +use crate::algebraic_ops::alg_add_f64; + +// Computes 2^-n by directly subtracting from the IEEE754 double exponent. +fn inv_pow2(n: u8) -> f64 { + let base = f64::to_bits(1.0); + f64::from_bits(base - ((n as u64) << 52)) +} + +/// HyperLogLog in Practice: Algorithmic Engineering of +/// a State of The Art Cardinality Estimation Algorithm +/// Stefan Heule, Marc Nunkesser, Alexander Hall +/// +/// We use m = 256 which gives a relative error of ~6.5% of the cardinality +/// estimate. We don't bother with stuffing the counts in 6 bits, byte access is +/// fast. +/// +/// The bias correction described in the paper is not implemented, so this is +/// somewhere in between HyperLogLog and HyperLogLog++. +#[derive(Clone)] +pub struct CardinalitySketch { + buckets: Box<[u8; 256]>, +} + +impl Default for CardinalitySketch { + fn default() -> Self { + Self::new() + } +} + +impl CardinalitySketch { + pub fn new() -> Self { + Self { + // This compiles to alloc_zeroed directly. + buckets: vec![0u8; 256].into_boxed_slice().try_into().unwrap(), + } + } + + /// Add a new hash to the sketch. + pub fn insert(&mut self, mut h: u64) { + const ARBITRARY_ODD: u64 = 0x902813a5785dc787; + // We multiply by this arbitrarily chosen odd number and then take the + // top bits to ensure the sketch is influenced by all bits of the hash. + h = h.wrapping_mul(ARBITRARY_ODD); + let idx = (h >> 56) as usize; + let p = 1 + (h << 8).leading_zeros() as u8; + self.buckets[idx] = self.buckets[idx].max(p); + } + + pub fn combine(&mut self, other: &CardinalitySketch) { + *self.buckets = std::array::from_fn(|i| std::cmp::max(self.buckets[i], other.buckets[i])); + } + + pub fn estimate(&self) -> usize { + let m = 256.0; + let alpha_m = 0.7123 / (1.0 + 1.079 / m); + + let mut sum = 0.0; + let mut num_zero = 0; + for x in self.buckets.iter() { + sum = alg_add_f64(sum, inv_pow2(*x)); + num_zero += (*x == 0) as usize; + } + + let est = (alpha_m * m * m) / sum; + let corr_est = if est <= 5.0 / 2.0 * m && num_zero != 0 { + // Small cardinality estimate, full 64-bit logarithm is overkill. + m * (m as f32 / num_zero as f32).ln() as f64 + } else { + est + }; + + corr_est as usize + } +} diff --git a/crates/polars-utils/src/cell.rs b/crates/polars-utils/src/cell.rs new file mode 100644 index 000000000000..ae6b6ae461fc --- /dev/null +++ b/crates/polars-utils/src/cell.rs @@ -0,0 +1,85 @@ +//! Copy pasted from std::cell::SyncUnsafeCell +//! can be removed once the feature stabilizes. +use std::cell::UnsafeCell; + +/// [`UnsafeCell`], but [`Sync`]. +/// +/// This is just an [`UnsafeCell`], except it implements [`Sync`] +/// if `T` implements [`Sync`]. +/// +/// [`UnsafeCell`] doesn't implement [`Sync`], to prevent accidental misuse. +/// You can use [`SyncUnsafeCell`] instead of [`UnsafeCell`] to allow it to be +/// shared between threads, if that's intentional. +/// Providing proper synchronization is still the task of the user, +/// making this type just as unsafe to use. +/// +/// See [`UnsafeCell`] for details. +#[repr(transparent)] +pub struct SyncUnsafeCell { + value: UnsafeCell, +} + +unsafe impl Sync for SyncUnsafeCell {} + +impl SyncUnsafeCell { + /// Constructs a new instance of [`SyncUnsafeCell`] which will wrap the specified value. + #[inline] + pub fn new(value: T) -> Self { + Self { + value: UnsafeCell::new(value), + } + } + + /// Unwraps the value. + #[inline] + pub fn into_inner(self) -> T { + self.value.into_inner() + } +} + +impl SyncUnsafeCell { + /// Gets a mutable pointer to the wrapped value. + /// + /// This can be cast to a pointer of any kind. + /// Ensure that the access is unique (no active references, mutable or not) + /// when casting to `&mut T`, and ensure that there are no mutations + /// or mutable aliases going on when casting to `&T` + #[inline] + pub fn get(&self) -> *mut T { + self.value.get() + } + + /// Returns a mutable reference to the underlying data. + /// + /// This call borrows the [`SyncUnsafeCell`] mutably (at compile-time) which + /// guarantees that we possess the only reference. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } + + /// Gets a mutable pointer to the wrapped value. + /// + /// See [`UnsafeCell::get`] for details. + #[inline] + pub fn raw_get(this: *const Self) -> *mut T { + // We can just cast the pointer from `SyncUnsafeCell` to `T` because + // of #[repr(transparent)] on both SyncUnsafeCell and UnsafeCell. + // See UnsafeCell::raw_get. + this as *const T as *mut T + } +} + +impl Default for SyncUnsafeCell { + /// Creates an `SyncUnsafeCell`, with the `Default` value for T. + fn default() -> SyncUnsafeCell { + SyncUnsafeCell::new(Default::default()) + } +} + +impl From for SyncUnsafeCell { + /// Creates a new [`SyncUnsafeCell`] containing the given value. + fn from(t: T) -> SyncUnsafeCell { + SyncUnsafeCell::new(t) + } +} diff --git a/crates/polars-utils/src/chunks.rs b/crates/polars-utils/src/chunks.rs new file mode 100644 index 000000000000..80f720eb2fd7 --- /dev/null +++ b/crates/polars-utils/src/chunks.rs @@ -0,0 +1,63 @@ +/// A copy of the [`std::slice::Chunks`] that exposes the inner `slice` and `chunk_size`. +#[derive(Clone, Debug)] +pub struct Chunks<'a, T> { + slice: &'a [T], + chunk_size: usize, +} + +impl<'a, T> Iterator for Chunks<'a, T> { + type Item = &'a [T]; + + fn next(&mut self) -> Option { + if self.slice.is_empty() { + return None; + } + + let item; + (item, self.slice) = self.slice.split_at(self.chunk_size.min(self.slice.len())); + + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.slice.len().div_ceil(self.chunk_size); + (len, Some(len)) + } +} + +impl DoubleEndedIterator for Chunks<'_, T> { + fn next_back(&mut self) -> Option { + if self.slice.is_empty() { + return None; + } + + let rem = self.slice.len() % self.chunk_size; + let offset = if rem == 0 { self.chunk_size } else { rem }; + + let item; + (self.slice, item) = self.slice.split_at(self.slice.len() - offset); + + Some(item) + } +} + +impl ExactSizeIterator for Chunks<'_, T> {} + +impl<'a, T> Chunks<'a, T> { + pub const fn new(slice: &'a [T], chunk_size: usize) -> Self { + Self { slice, chunk_size } + } + + pub const fn as_slice(&self) -> &'a [T] { + self.slice + } + + pub const fn chunk_size(&self) -> usize { + self.chunk_size + } + + pub fn skip_in_place(&mut self, n: usize) { + let n = n * self.chunk_size; + self.slice = &self.slice[n.min(self.slice.len())..]; + } +} diff --git a/crates/polars-utils/src/clmul.rs b/crates/polars-utils/src/clmul.rs new file mode 100644 index 000000000000..e5a3258b7909 --- /dev/null +++ b/crates/polars-utils/src/clmul.rs @@ -0,0 +1,172 @@ +#[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] +fn intel_clmul64(x: u64, y: u64) -> u64 { + use core::arch::x86_64::*; + unsafe { + // SAFETY: we have the target feature. + _mm_cvtsi128_si64(_mm_clmulepi64_si128( + _mm_cvtsi64_si128(x as i64), + _mm_cvtsi64_si128(y as i64), + 0, + )) as u64 + } +} + +#[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" +))] +fn arm_clmul64(x: u64, y: u64) -> u64 { + unsafe { + // SAFETY: we have the target feature. + use core::arch::aarch64::*; + vmull_p64(x, y) as u64 + } +} + +#[inline] +pub fn portable_clmul64(x: u64, mut y: u64) -> u64 { + let mut out = 0; + while y > 0 { + let lsb = y & y.wrapping_neg(); + out ^= x.wrapping_mul(lsb); + y ^= lsb; + } + out +} + +// Computes the carryless multiplication of x and y. +#[inline] +pub fn clmul64(x: u64, y: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] + return intel_clmul64(x, y); + + #[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" + ))] + return arm_clmul64(x, y); + + #[allow(unreachable_code)] + portable_clmul64(x, y) +} + +#[inline] +pub fn portable_prefix_xorsum(x: u64) -> u64 { + portable_prefix_xorsum_inclusive(x << 1) +} + +// Computes for each bit i the XOR of bits[0..i]. +#[inline] +pub fn prefix_xorsum(x: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] + return intel_clmul64(x, u64::MAX ^ 1); + + #[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" + ))] + return arm_clmul64(x, u64::MAX ^ 1); + + #[allow(unreachable_code)] + portable_prefix_xorsum(x) +} + +#[inline] +pub fn portable_prefix_xorsum_inclusive(mut x: u64) -> u64 { + for i in 0..6 { + x ^= x << (1 << i); + } + x +} + +// Computes for each bit i the XOR of bits[0..=i]. +#[inline] +pub fn prefix_xorsum_inclusive(x: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] + return intel_clmul64(x, u64::MAX); + + #[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" + ))] + return arm_clmul64(x, u64::MAX); + + #[allow(unreachable_code)] + portable_prefix_xorsum_inclusive(x) +} + +#[cfg(test)] +mod test { + use rand::prelude::*; + + use super::*; + + #[test] + fn test_clmul() { + // Verify platform-specific clmul to portable. + let mut rng = StdRng::seed_from_u64(0xdeadbeef); + for _ in 0..100 { + let x = rng.r#gen(); + let y = rng.r#gen(); + assert_eq!(portable_clmul64(x, y), clmul64(x, y)); + } + + // Verify portable clmul for known test vectors. + assert_eq!( + portable_clmul64(0x8b44729195dde0ef, 0xb976c5ae2726fab0), + 0x4ae14eae84899290 + ); + assert_eq!( + portable_clmul64(0x399b6ed00c44b301, 0x693341db5acb2ff0), + 0x48dfa88344823ff0 + ); + assert_eq!( + portable_clmul64(0xdf4c9f6e60deb640, 0x6d4bcdb217ac4880), + 0x7300ffe474792000 + ); + assert_eq!( + portable_clmul64(0xa7adf3c53a200a51, 0x818cb40fe11b431e), + 0x6a280181d521797e + ); + assert_eq!( + portable_clmul64(0x5e78e12b744f228c, 0x4225ff19e9273266), + 0xa48b73cafb9665a8 + ); + } + + #[test] + fn test_prefix_xorsum() { + // Verify platform-specific prefix_xorsum to portable. + let mut rng = StdRng::seed_from_u64(0xdeadbeef); + for _ in 0..100 { + let x = rng.r#gen(); + assert_eq!(portable_prefix_xorsum(x), prefix_xorsum(x)); + } + + // Verify portable prefix_xorsum for known test vectors. + assert_eq!( + portable_prefix_xorsum(0x8b44729195dde0ef), + 0x0d87a31ee696bf4a + ); + assert_eq!( + portable_prefix_xorsum(0xb976c5ae2726fab0), + 0x2e5b79343a3b5320 + ); + assert_eq!( + portable_prefix_xorsum(0x399b6ed00c44b301), + 0xd1124b600878ddfe + ); + assert_eq!( + portable_prefix_xorsum(0x693341db5acb2ff0), + 0x4e227e926c8dcaa0 + ); + assert_eq!( + portable_prefix_xorsum(0xdf4c9f6e60deb640), + 0x6a7715b44094db80 + ); + } +} diff --git a/crates/polars-utils/src/config.rs b/crates/polars-utils/src/config.rs new file mode 100644 index 000000000000..837d1126e323 --- /dev/null +++ b/crates/polars-utils/src/config.rs @@ -0,0 +1,3 @@ +pub(crate) fn verbose() -> bool { + std::env::var("POLARS_VERBOSE").as_deref().unwrap_or("") == "1" +} diff --git a/crates/polars-utils/src/cpuid.rs b/crates/polars-utils/src/cpuid.rs new file mode 100644 index 000000000000..37f64d158a6f --- /dev/null +++ b/crates/polars-utils/src/cpuid.rs @@ -0,0 +1,63 @@ +// So much conditional stuff going on here... +#![allow(dead_code, unreachable_code, unused)] + +use std::sync::OnceLock; + +#[cfg(target_arch = "x86_64")] +use raw_cpuid::CpuId; + +#[cfg(target_feature = "bmi2")] +#[inline(never)] +#[cold] +fn detect_fast_bmi2() -> bool { + let cpu_id = CpuId::new(); + let vendor = cpu_id.get_vendor_info().expect("could not read cpu vendor"); + if vendor.as_str() == "AuthenticAMD" || vendor.as_str() == "HygonGenuine" { + let features = cpu_id + .get_feature_info() + .expect("could not read cpu feature info"); + let family_id = features.family_id(); + + // Hardcoded blacklist of known-bad AMD families. + // We'll assume any future releases that support BMI2 have a + // proper implementation. + !(0x15..=0x18).contains(&family_id) + } else { + true + } +} + +#[inline(always)] +pub fn has_fast_bmi2() -> bool { + #[cfg(target_feature = "bmi2")] + { + static CACHE: OnceLock = OnceLock::new(); + return *CACHE.get_or_init(detect_fast_bmi2); + } + + false +} + +#[inline] +pub fn is_avx512_enabled() -> bool { + #[cfg(target_arch = "x86_64")] + { + static CACHE: OnceLock = OnceLock::new(); + return *CACHE.get_or_init(|| { + if !std::arch::is_x86_feature_detected!("avx512f") { + return false; + } + + if std::env::var("POLARS_DISABLE_AVX512") + .map(|var| var == "1") + .unwrap_or(false) + { + return false; + } + + true + }); + } + + false +} diff --git a/crates/polars-utils/src/error.rs b/crates/polars-utils/src/error.rs new file mode 100644 index 000000000000..861731ece753 --- /dev/null +++ b/crates/polars-utils/src/error.rs @@ -0,0 +1,50 @@ +use std::borrow::Cow; +use std::fmt::{Display, Formatter}; + +use crate::config::verbose; +use crate::format_pl_smallstr; + +type ErrString = Cow<'static, str>; + +#[derive(Debug)] +pub enum PolarsUtilsError { + ComputeError(ErrString), +} + +impl Display for PolarsUtilsError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + PolarsUtilsError::ComputeError(s) => { + let s = s.as_ref(); + write!(f, "{s}") + }, + } + } +} + +pub type Result = std::result::Result; + +/// Utility whose Display impl truncates the string unless POLARS_VERBOSE is set. +pub struct TruncateErrorDetail<'a>(pub &'a str); + +impl std::fmt::Display for TruncateErrorDetail<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let maybe_truncated = if verbose() { + self.0 + } else { + // Clamp the output on non-verbose + &self.0[..self.0.len().min(4096)] + }; + + f.write_str(maybe_truncated)?; + + if maybe_truncated.len() != self.0.len() { + let n_more = self.0.len() - maybe_truncated.len(); + f.write_str(" ...(set POLARS_VERBOSE=1 to see full response (")?; + f.write_str(&format_pl_smallstr!("{}", n_more))?; + f.write_str(" more characters))")?; + }; + + Ok(()) + } +} diff --git a/crates/polars-utils/src/file.rs b/crates/polars-utils/src/file.rs new file mode 100644 index 000000000000..60271e0a492d --- /dev/null +++ b/crates/polars-utils/src/file.rs @@ -0,0 +1,83 @@ +use std::fs::File; +use std::io::{Read, Seek, Write}; + +impl From for ClosableFile { + fn from(value: File) -> Self { + ClosableFile { inner: value } + } +} + +impl From for File { + fn from(value: ClosableFile) -> Self { + value.inner + } +} + +pub struct ClosableFile { + inner: File, +} + +impl ClosableFile { + #[cfg(unix)] + pub fn close(self) -> std::io::Result<()> { + use std::os::fd::IntoRawFd; + let fd = self.inner.into_raw_fd(); + + match unsafe { libc::close(fd) } { + 0 => Ok(()), + _ => Err(std::io::Error::last_os_error()), + } + } + + #[cfg(not(unix))] + pub fn close(self) -> std::io::Result<()> { + Ok(()) + } +} + +impl AsMut for ClosableFile { + fn as_mut(&mut self) -> &mut File { + &mut self.inner + } +} + +impl AsRef for ClosableFile { + fn as_ref(&self) -> &File { + &self.inner + } +} + +impl Seek for ClosableFile { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.inner.seek(pos) + } +} + +impl Read for ClosableFile { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.inner.read(buf) + } +} + +impl Write for ClosableFile { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.inner.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.inner.flush() + } +} + +pub trait WriteClose: Write { + fn close(self: Box) -> std::io::Result<()> { + Ok(()) + } +} + +impl WriteClose for ClosableFile { + fn close(self: Box) -> std::io::Result<()> { + let f = *self; + f.close() + } +} diff --git a/crates/polars-utils/src/fixedringbuffer.rs b/crates/polars-utils/src/fixedringbuffer.rs new file mode 100644 index 000000000000..4f88824ddc9c --- /dev/null +++ b/crates/polars-utils/src/fixedringbuffer.rs @@ -0,0 +1,216 @@ +/// A ring-buffer with a size determined at creation-time +/// +/// This makes it perfectly suited for buffers that produce and consume at different speeds. +pub struct FixedRingBuffer { + start: usize, + length: usize, + buffer: *mut T, + /// The wanted fixed capacity in the buffer + capacity: usize, + + /// The actually allocated capacity, this should not be used for any calculations and it purely + /// used for the deallocation. + _buffer_capacity: usize, +} + +#[inline(always)] +const fn wrapping_add(x: usize, n: usize, capacity: usize) -> usize { + assert!(n <= capacity); + + let sub = if capacity - n <= x { capacity } else { 0 }; + + x.wrapping_add(n).wrapping_sub(sub) +} + +impl FixedRingBuffer { + pub fn new(capacity: usize) -> Self { + let buffer = Vec::with_capacity(capacity); + + Self { + start: 0, + length: 0, + + _buffer_capacity: buffer.capacity(), + buffer: buffer.leak() as *mut [T] as *mut T, + capacity, + } + } + + #[inline(always)] + pub const fn len(&self) -> usize { + self.length + } + + #[inline(always)] + pub const fn capacity(&self) -> usize { + self.capacity + } + + #[inline(always)] + pub const fn remaining_capacity(&self) -> usize { + self.capacity - self.len() + } + + #[inline(always)] + pub const fn is_empty(&self) -> bool { + self.length == 0 + } + + #[inline(always)] + pub const fn is_full(&self) -> bool { + self.len() == self.capacity + } + + /// Get a reference to all elements in the form of two slices. + /// + /// These are in the listed in the order of being pushed into the buffer. + #[inline] + pub fn as_slices(&self) -> (&[T], &[T]) { + // SAFETY: Only pick the part that is actually defined + if self.capacity - self.length > self.start { + ( + unsafe { + std::slice::from_raw_parts(self.buffer.wrapping_add(self.start), self.length) + }, + &[], + ) + } else { + ( + unsafe { + std::slice::from_raw_parts( + self.buffer.wrapping_add(self.start), + self.capacity - self.start, + ) + }, + unsafe { + std::slice::from_raw_parts( + self.buffer, + wrapping_add(self.start, self.length, self.capacity), + ) + }, + ) + } + } + + /// Pop an item at the front of the [`FixedRingBuffer`] + #[inline] + pub fn pop_front(&mut self) -> Option { + if self.is_empty() { + return None; + } + + // SAFETY: This value is never read again + let item = unsafe { self.buffer.wrapping_add(self.start).read() }; + self.start = wrapping_add(self.start, 1, self.capacity); + self.length -= 1; + Some(item) + } + + /// Push an item into the [`FixedRingBuffer`] + /// + /// Returns `None` if there is no more space + #[inline] + pub fn push(&mut self, value: T) -> Option<()> { + if self.is_full() { + return None; + } + + let offset = wrapping_add(self.start, self.len(), self.capacity); + + unsafe { self.buffer.wrapping_add(offset).write(value) }; + self.length += 1; + + Some(()) + } +} + +impl FixedRingBuffer { + /// Add at most `num` items of `value` into the [`FixedRingBuffer`] + /// + /// This returns the amount of items actually added. + pub fn fill_repeat(&mut self, value: T, num: usize) -> usize { + if num == 0 || self.is_full() { + return 0; + } + + let num = usize::min(num, self.remaining_capacity()); + + let start = wrapping_add(self.start, self.len(), self.capacity); + let end = wrapping_add(start, num, self.capacity); + + if start < end { + unsafe { std::slice::from_raw_parts_mut(self.buffer.wrapping_add(start), num) } + .fill(value); + } else { + unsafe { + std::slice::from_raw_parts_mut( + self.buffer.wrapping_add(start), + self.capacity - start, + ) + } + .fill(value); + + if end != 0 { + unsafe { std::slice::from_raw_parts_mut(self.buffer, end) }.fill(value); + } + } + + self.length += num; + + num + } +} + +impl Drop for FixedRingBuffer { + fn drop(&mut self) { + for i in 0..self.length { + let offset = wrapping_add(self.start, i, self.capacity); + unsafe { self.buffer.wrapping_add(offset).read() }; + } + + unsafe { Vec::from_raw_parts(self.buffer, 0, self._buffer_capacity) }; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + let mut frb = FixedRingBuffer::new(256); + + assert!(frb.pop_front().is_none()); + + frb.push(1).unwrap(); + frb.push(3).unwrap(); + + assert_eq!(frb.pop_front(), Some(1)); + assert_eq!(frb.pop_front(), Some(3)); + assert_eq!(frb.pop_front(), None); + + assert!(!frb.is_full()); + assert_eq!(frb.fill_repeat(42, 300), 256); + assert!(frb.is_full()); + + for _ in 0..256 { + assert_eq!(frb.pop_front(), Some(42)); + assert!(!frb.is_full()); + } + assert_eq!(frb.pop_front(), None); + } + + #[test] + fn boxed() { + let mut frb = FixedRingBuffer::new(256); + + assert!(frb.pop_front().is_none()); + + frb.push(Box::new(1)).unwrap(); + frb.push(Box::new(3)).unwrap(); + + assert_eq!(frb.pop_front(), Some(Box::new(1))); + assert_eq!(frb.pop_front(), Some(Box::new(3))); + assert_eq!(frb.pop_front(), None); + } +} diff --git a/crates/polars-utils/src/float.rs b/crates/polars-utils/src/float.rs new file mode 100644 index 000000000000..30d47397c28e --- /dev/null +++ b/crates/polars-utils/src/float.rs @@ -0,0 +1,103 @@ +/// # Safety +/// unsafe code downstream relies on the correct is_float call +pub unsafe trait IsFloat: private::Sealed + Sized { + fn is_float() -> bool { + false + } + + fn is_f32() -> bool { + false + } + + fn is_f64() -> bool { + false + } + + fn nan_value() -> Self { + unimplemented!() + } + + #[allow(clippy::wrong_self_convention)] + fn is_nan(&self) -> bool + where + Self: Sized, + { + false + } + #[allow(clippy::wrong_self_convention)] + fn is_finite(&self) -> bool + where + Self: Sized, + { + true + } +} + +unsafe impl IsFloat for i8 {} +unsafe impl IsFloat for i16 {} +unsafe impl IsFloat for i32 {} +unsafe impl IsFloat for i64 {} +unsafe impl IsFloat for i128 {} +unsafe impl IsFloat for u8 {} +unsafe impl IsFloat for u16 {} +unsafe impl IsFloat for u32 {} +unsafe impl IsFloat for u64 {} +unsafe impl IsFloat for &str {} +unsafe impl IsFloat for &[u8] {} +unsafe impl IsFloat for bool {} +unsafe impl IsFloat for Option {} + +mod private { + pub trait Sealed {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for i128 {} + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for &str {} + impl Sealed for &[u8] {} + impl Sealed for bool {} + impl Sealed for Option {} +} + +macro_rules! impl_is_float { + ($tp:ty, $is_f32:literal, $is_f64:literal) => { + unsafe impl IsFloat for $tp { + #[inline] + fn is_float() -> bool { + true + } + + fn is_f32() -> bool { + $is_f32 + } + + fn is_f64() -> bool { + $is_f64 + } + + fn nan_value() -> Self { + Self::NAN + } + + #[inline] + fn is_nan(&self) -> bool { + <$tp>::is_nan(*self) + } + + #[inline] + fn is_finite(&self) -> bool { + <$tp>::is_finite(*self) + } + } + }; +} + +impl_is_float!(f32, true, false); +impl_is_float!(f64, false, true); diff --git a/crates/polars-utils/src/floor_divmod.rs b/crates/polars-utils/src/floor_divmod.rs new file mode 100644 index 000000000000..14c02fc8d257 --- /dev/null +++ b/crates/polars-utils/src/floor_divmod.rs @@ -0,0 +1,102 @@ +pub trait FloorDivMod: Sized { + // Returns the flooring division and associated modulo of lhs / rhs. + // This is the same division / modulo combination as Python. + // + // Returns (0, 0) if other == 0. + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self); +} + +macro_rules! impl_float_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + let div = (self / other).floor(); + let mod_ = self - other * div; + (div, mod_) + } + } + }; +} + +macro_rules! impl_unsigned_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + (self / other, self % other) + } + } + }; +} + +macro_rules! impl_signed_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + if other == 0 { + return (0, 0); + } + + // Rust/C-style remainder is in the correct congruence + // class, but may not have the right sign. We want a + // remainder with the same sign as the RHS, which we + // can get by adding RHS to the remainder if the sign of + // the non-zero remainder differs from our RHS. + // + // Similarly, Rust/C-style division truncates instead of floors. + // If the remainder was non-zero and the signs were different + // (we'd have a negative result before truncating), we need to + // subtract 1 from the result. + let mut div = self.wrapping_div(other); + let mut mod_ = self.wrapping_rem(other); + if mod_ != 0 && (self < 0) != (other < 0) { + div -= 1; + mod_ += other; + } + (div, mod_) + } + } + }; +} + +impl_unsigned_div_mod!(u8); +impl_unsigned_div_mod!(u16); +impl_unsigned_div_mod!(u32); +impl_unsigned_div_mod!(u64); +impl_unsigned_div_mod!(u128); +impl_unsigned_div_mod!(usize); +impl_signed_div_mod!(i8); +impl_signed_div_mod!(i16); +impl_signed_div_mod!(i32); +impl_signed_div_mod!(i64); +impl_signed_div_mod!(i128); +impl_signed_div_mod!(isize); +impl_float_div_mod!(f32); +impl_float_div_mod!(f64); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_signed_wrapping_div_mod() { + // Test for all i8, should transfer to other values. + for lhs in i8::MIN..=i8::MAX { + for rhs in i8::MIN..=i8::MAX { + let ans = if rhs != 0 { + let fdiv = (lhs as f64 / rhs as f64).floor(); + let fmod = lhs as f64 - rhs as f64 * fdiv; + + // float -> int conversion saturates, we want wrapping, double convert. + ((fdiv as i32) as i8, (fmod as i32) as i8) + } else { + (0, 0) + }; + + assert_eq!(lhs.wrapping_floor_div_mod(rhs), ans); + } + } + } +} diff --git a/crates/polars-utils/src/fmt.rs b/crates/polars-utils/src/fmt.rs new file mode 100644 index 000000000000..a292561f7cdb --- /dev/null +++ b/crates/polars-utils/src/fmt.rs @@ -0,0 +1,84 @@ +#[macro_export] +macro_rules! format_list_container { + ($e:expr, $start:tt, $end:tt) => {{ + use std::fmt::Write; + let mut out = String::new(); + out.push($start); + let mut iter = $e.into_iter(); + let mut next = iter.next(); + + loop { + if let Some(val) = next { + write!(out, "{val}").unwrap(); + }; + next = iter.next(); + if next.is_some() { + out.push_str(", ") + } else { + break; + } + } + out.push($end); + out + };}; +} + +#[macro_export] +macro_rules! format_list { + ($e:expr) => {{ + use polars_utils::format_list_container; + format_list_container!($e, '[', ']') + };}; +} + +#[macro_export] +macro_rules! format_tuple { + ($e:expr) => {{ + use polars_utils::format_list_container; + format_list_container!($e, '(', ')') + };}; +} + +#[macro_export] +macro_rules! format_list_container_truncated { + ($e:expr, $start:tt, $end:tt, $max:expr, $quote:expr) => {{ + use std::fmt::Write; + let mut out = String::new(); + out.push($start); + let mut iter = $e.into_iter(); + let mut next = iter.next(); + + let mut count = 0; + loop { + if $max == count { + write!(out, "...").unwrap(); + break; + } + count += 1; + + if let Some(val) = next { + write!(out, "{}{}{}", $quote, val, $quote).unwrap(); + }; + next = iter.next(); + if next.is_some() { + out.push_str(", ") + } else { + break; + } + } + out.push($end); + out + };}; +} + +#[macro_export] +macro_rules! format_list_truncated { + ($e:expr, $max:expr) => {{ + use polars_utils::format_list_container_truncated; + format_list_container_truncated!($e, '[', ']', $max, "") + };}; + ($e:expr, $max:expr, $quote:expr) => {{ + use polars_utils::format_list_container_truncated; + format_list_container_truncated!($e, '[', ']', $max, $quote) + };}; +} diff --git a/crates/polars-utils/src/functions.rs b/crates/polars-utils/src/functions.rs new file mode 100644 index 000000000000..f91e8611e3c3 --- /dev/null +++ b/crates/polars-utils/src/functions.rs @@ -0,0 +1,53 @@ +use std::mem::MaybeUninit; +use std::ops::Range; +use std::sync::Arc; + +// The ith portion of a range split in k (as equal as possible) parts. +#[inline(always)] +pub fn range_portion(i: usize, k: usize, r: Range) -> Range { + // Each portion having size n / k leaves n % k elements unaccounted for. + // Make the first n % k portions have 1 extra element. + let n = r.len(); + let base_size = n / k; + let num_one_larger = n % k; + let num_before = base_size * i + i.min(num_one_larger); + let our_size = base_size + (i < num_one_larger) as usize; + r.start + num_before..r.start + num_before + our_size +} + +pub fn arc_map T>(mut arc: Arc, mut f: F) -> Arc { + unsafe { + // Make the Arc unique (cloning if necessary). + Arc::make_mut(&mut arc); + + // If f panics we must be able to drop the Arc without assuming it is initialized. + let mut uninit_arc = Arc::from_raw(Arc::into_raw(arc).cast::>()); + + // Replace the value inside the arc. + let ptr = Arc::get_mut(&mut uninit_arc).unwrap_unchecked() as *mut MaybeUninit; + *ptr = MaybeUninit::new(f(ptr.read().assume_init())); + + // Now the Arc is properly initialized again. + Arc::from_raw(Arc::into_raw(uninit_arc).cast::()) + } +} + +pub fn try_arc_map Result>( + mut arc: Arc, + mut f: F, +) -> Result, E> { + unsafe { + // Make the Arc unique (cloning if necessary). + Arc::make_mut(&mut arc); + + // If f panics we must be able to drop the Arc without assuming it is initialized. + let mut uninit_arc = Arc::from_raw(Arc::into_raw(arc).cast::>()); + + // Replace the value inside the arc. + let ptr = Arc::get_mut(&mut uninit_arc).unwrap_unchecked() as *mut MaybeUninit; + *ptr = MaybeUninit::new(f(ptr.read().assume_init())?); + + // Now the Arc is properly initialized again. + Ok(Arc::from_raw(Arc::into_raw(uninit_arc).cast::())) + } +} diff --git a/crates/polars-utils/src/hashing.rs b/crates/polars-utils/src/hashing.rs new file mode 100644 index 000000000000..5dff05b6577e --- /dev/null +++ b/crates/polars-utils/src/hashing.rs @@ -0,0 +1,179 @@ +use std::hash::{Hash, Hasher}; + +use crate::nulls::IsNull; + +// Hash combine from c++' boost lib. +#[inline(always)] +pub fn _boost_hash_combine(l: u64, r: u64) -> u64 { + l ^ r.wrapping_add(0x9e3779b9u64.wrapping_add(l << 6).wrapping_add(r >> 2)) +} + +#[inline(always)] +pub const fn folded_multiply(a: u64, b: u64) -> u64 { + let full = (a as u128).wrapping_mul(b as u128); + (full as u64) ^ ((full >> 64) as u64) +} + +/// Contains a byte slice and a precomputed hash for that string. +/// During rehashes, we will rehash the hash instead of the string, that makes +/// rehashing cheap and allows cache coherent small hash tables. +#[derive(Eq, Copy, Clone, Debug)] +pub struct BytesHash<'a> { + payload: Option<&'a [u8]>, + pub(super) hash: u64, +} + +impl<'a> BytesHash<'a> { + #[inline] + pub fn new(s: Option<&'a [u8]>, hash: u64) -> Self { + Self { payload: s, hash } + } +} + +impl<'a> IsNull for BytesHash<'a> { + const HAS_NULLS: bool = true; + type Inner = BytesHash<'a>; + + #[inline(always)] + fn is_null(&self) -> bool { + self.payload.is_none() + } + + fn unwrap_inner(self) -> Self::Inner { + assert!(self.payload.is_some()); + self + } +} + +impl Hash for BytesHash<'_> { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} + +impl PartialEq for BytesHash<'_> { + #[inline] + fn eq(&self, other: &Self) -> bool { + (self.hash == other.hash) && (self.payload == other.payload) + } +} + +#[inline(always)] +pub fn hash_to_partition(h: u64, n_partitions: usize) -> usize { + // Assuming h is a 64-bit random number, we note that + // h / 2^64 is almost a uniform random number in [0, 1), and thus + // floor(h * n_partitions / 2^64) is almost a uniform random integer in + // [0, n_partitions). Despite being written with u128 multiplication this + // compiles to a single mul / mulhi instruction on x86-x64/aarch64. + ((h as u128 * n_partitions as u128) >> 64) as usize +} + +#[derive(Clone)] +pub struct HashPartitioner { + num_partitions: usize, + seed: u64, +} + +impl HashPartitioner { + /// Creates a new hash partitioner with the given number of partitions and + /// seed. + #[inline] + pub fn new(num_partitions: usize, mut seed: u64) -> Self { + assert!(num_partitions > 0); + // Make sure seeds bits are properly randomized, and is odd. + const ARBITRARY1: u64 = 0x85921e81c41226a0; + const ARBITRARY2: u64 = 0x3bc1d0faba166294; + const ARBITRARY3: u64 = 0xfbde893e21a73756; + seed = folded_multiply(seed ^ ARBITRARY1, ARBITRARY2); + seed = folded_multiply(seed, ARBITRARY3); + seed |= 1; + Self { + seed, + num_partitions, + } + } + + /// Converts a hash to a partition. It is guaranteed that the output is + /// in the range [0, n_partitions), and that independent HashPartitioners + /// that we initialized with the same num_partitions and seed return the same + /// partition. + #[inline(always)] + pub fn hash_to_partition(&self, hash: u64) -> usize { + // Assuming r is a 64-bit random number, we note that + // r / 2^64 is almost a uniform random number in [0, 1), and thus + // floor(r * n_partitions / 2^64) is almost a uniform random integer in + // [0, n_partitions). Despite being written with u128 multiplication this + // compiles to a single mul / mulhi instruction on x86-x64/aarch64. + let shuffled = hash.wrapping_mul(self.seed); + ((shuffled as u128 * self.num_partitions as u128) >> 64) as usize + } + + /// The partition nulls are put into. + #[inline(always)] + pub fn null_partition(&self) -> usize { + 0 + } + + #[inline(always)] + pub fn num_partitions(&self) -> usize { + self.num_partitions + } +} + +// FIXME: use Hasher interface and support a random state. +pub trait DirtyHash { + // A quick and dirty hash. Only the top bits of the hash are decent, such as + // used in hash_to_partition. + fn dirty_hash(&self) -> u64; +} + +// Multiplication by a 'random' odd number gives a universal hash function in +// the top bits. +const RANDOM_ODD: u64 = 0x55fbfd6bfc5458e9; + +macro_rules! impl_hash_partition_as_u64 { + ($T: ty) => { + impl DirtyHash for $T { + fn dirty_hash(&self) -> u64 { + (*self as u64).wrapping_mul(RANDOM_ODD) + } + } + }; +} + +impl_hash_partition_as_u64!(u8); +impl_hash_partition_as_u64!(u16); +impl_hash_partition_as_u64!(u32); +impl_hash_partition_as_u64!(u64); +impl_hash_partition_as_u64!(i8); +impl_hash_partition_as_u64!(i16); +impl_hash_partition_as_u64!(i32); +impl_hash_partition_as_u64!(i64); + +impl DirtyHash for i128 { + fn dirty_hash(&self) -> u64 { + (*self as u64) + .wrapping_mul(RANDOM_ODD) + .wrapping_add((*self >> 64) as u64) + } +} + +impl DirtyHash for BytesHash<'_> { + fn dirty_hash(&self) -> u64 { + self.hash + } +} + +impl DirtyHash for &T { + fn dirty_hash(&self) -> u64 { + (*self).dirty_hash() + } +} + +// FIXME: we should probably encourage explicit null handling, but for now we'll +// allow directly getting a partition from a nullable value. +impl DirtyHash for Option { + fn dirty_hash(&self) -> u64 { + self.as_ref().map(|s| s.dirty_hash()).unwrap_or(0) + } +} diff --git a/crates/polars-utils/src/idx_map/bytes_idx_map.rs b/crates/polars-utils/src/idx_map/bytes_idx_map.rs new file mode 100644 index 000000000000..21ebd61deb55 --- /dev/null +++ b/crates/polars-utils/src/idx_map/bytes_idx_map.rs @@ -0,0 +1,203 @@ +use hashbrown::hash_table::{ + Entry as TEntry, HashTable, OccupiedEntry as TOccupiedEntry, VacantEntry as TVacantEntry, +}; + +use crate::IdxSize; + +const BASE_KEY_DATA_CAPACITY: usize = 1024; + +struct Key { + key_hash: u64, + key_buffer: u32, + key_offset: usize, + key_length: u32, +} + +impl Key { + unsafe fn get<'k>(&self, key_data: &'k [Vec]) -> &'k [u8] { + let buf = unsafe { key_data.get_unchecked(self.key_buffer as usize) }; + unsafe { buf.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) } + } +} + +/// An IndexMap where the keys are always [u8] slices which are pre-hashed. +pub struct BytesIndexMap { + table: HashTable, + tuples: Vec<(Key, V)>, + key_data: Vec>, + + // Internal random seed used to keep hash iteration order decorrelated. + // We simply store a random odd number and multiply the canonical hash by it. + seed: u64, +} + +impl Default for BytesIndexMap { + fn default() -> Self { + Self { + table: HashTable::new(), + tuples: Vec::new(), + key_data: vec![Vec::with_capacity(BASE_KEY_DATA_CAPACITY)], + seed: rand::random::() | 1, + } + } +} + +impl BytesIndexMap { + pub fn new() -> Self { + Self::default() + } + + pub fn reserve(&mut self, additional: usize) { + self.table.reserve(additional, |i| unsafe { + let tuple = self.tuples.get_unchecked(*i as usize); + tuple.0.key_hash.wrapping_mul(self.seed) + }); + self.tuples.reserve(additional); + } + + pub fn len(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + pub fn is_empty(&self) -> bool { + self.tuples.is_empty() + } + + pub fn get(&self, hash: u64, key: &[u8]) -> Option<&V> { + let idx = self.table.find(hash.wrapping_mul(self.seed), |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.key_hash && key == t.0.get(&self.key_data) + })?; + unsafe { Some(&self.tuples.get_unchecked(*idx as usize).1) } + } + + pub fn contains_key(&self, hash: u64, key: &[u8]) -> bool { + self.table + .find(hash.wrapping_mul(self.seed), |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.key_hash && key == t.0.get(&self.key_data) + }) + .is_some() + } + + pub fn entry<'k>(&mut self, hash: u64, key: &'k [u8]) -> Entry<'_, 'k, V> { + let entry = self.table.entry( + hash.wrapping_mul(self.seed), + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == t.0.key_hash && key == t.0.get(&self.key_data) + }, + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + t.0.key_hash.wrapping_mul(self.seed) + }, + ); + + match entry { + TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry { + entry: o, + tuples: &mut self.tuples, + }), + TEntry::Vacant(v) => Entry::Vacant(VacantEntry { + key, + hash, + entry: v, + tuples: &mut self.tuples, + key_data: &mut self.key_data, + }), + } + } + + /// Gets the hash, key and value at the given index by insertion order. + #[inline(always)] + pub fn get_index(&self, idx: IdxSize) -> Option<(u64, &[u8], &V)> { + let t = self.tuples.get(idx as usize)?; + Some((t.0.key_hash, unsafe { t.0.get(&self.key_data) }, &t.1)) + } + + /// Gets the hash, key and value at the given index by insertion order. + /// + /// # Safety + /// The index must be less than len(). + #[inline(always)] + pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (u64, &[u8], &V) { + let t = unsafe { self.tuples.get_unchecked(idx as usize) }; + unsafe { (t.0.key_hash, t.0.get(&self.key_data), &t.1) } + } + + /// Iterates over the (hash, key) pairs in insertion order. + pub fn iter_hash_keys(&self) -> impl Iterator { + self.tuples + .iter() + .map(|t| unsafe { (t.0.key_hash, t.0.get(&self.key_data)) }) + } + + /// Iterates over the values in insertion order. + pub fn iter_values(&self) -> impl Iterator { + self.tuples.iter().map(|t| &t.1) + } +} + +pub enum Entry<'a, 'k, V> { + Occupied(OccupiedEntry<'a, V>), + Vacant(VacantEntry<'a, 'k, V>), +} + +pub struct OccupiedEntry<'a, V> { + entry: TOccupiedEntry<'a, IdxSize>, + tuples: &'a mut Vec<(Key, V)>, +} + +impl<'a, V> OccupiedEntry<'a, V> { + pub fn index(&self) -> IdxSize { + *self.entry.get() + } + + pub fn into_mut(self) -> &'a mut V { + let idx = self.index(); + unsafe { &mut self.tuples.get_unchecked_mut(idx as usize).1 } + } +} + +pub struct VacantEntry<'a, 'k, V> { + hash: u64, + key: &'k [u8], + entry: TVacantEntry<'a, IdxSize>, + tuples: &'a mut Vec<(Key, V)>, + key_data: &'a mut Vec>, +} + +#[allow(clippy::needless_lifetimes)] +impl<'a, 'k, V> VacantEntry<'a, 'k, V> { + pub fn index(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + pub fn insert(self, value: V) -> &'a mut V { + unsafe { + let tuple_idx: IdxSize = self.tuples.len().try_into().unwrap(); + + let mut num_buffers = self.key_data.len() as u32; + let mut active_buf = self.key_data.last_mut().unwrap_unchecked(); + let key_len = self.key.len(); + if active_buf.len() + key_len > active_buf.capacity() { + let ideal_next_cap = BASE_KEY_DATA_CAPACITY.checked_shl(num_buffers).unwrap(); + let next_capacity = std::cmp::max(ideal_next_cap, key_len); + self.key_data.push(Vec::with_capacity(next_capacity)); + active_buf = self.key_data.last_mut().unwrap_unchecked(); + num_buffers += 1; + } + + let tuple_key = Key { + key_hash: self.hash, + key_buffer: num_buffers - 1, + key_offset: active_buf.len(), + key_length: self.key.len().try_into().unwrap(), + }; + self.tuples.push((tuple_key, value)); + active_buf.extend_from_slice(self.key); + self.entry.insert(tuple_idx); + &mut self.tuples.last_mut().unwrap_unchecked().1 + } + } +} diff --git a/crates/polars-utils/src/idx_map/mod.rs b/crates/polars-utils/src/idx_map/mod.rs new file mode 100644 index 000000000000..5fb2b5cd69cc --- /dev/null +++ b/crates/polars-utils/src/idx_map/mod.rs @@ -0,0 +1,5 @@ +pub mod bytes_idx_map; +pub use bytes_idx_map::BytesIndexMap; + +pub mod total_idx_map; +pub use total_idx_map::TotalIndexMap; diff --git a/crates/polars-utils/src/idx_map/total_idx_map.rs b/crates/polars-utils/src/idx_map/total_idx_map.rs new file mode 100644 index 000000000000..3ae15ea9e14e --- /dev/null +++ b/crates/polars-utils/src/idx_map/total_idx_map.rs @@ -0,0 +1,156 @@ +use hashbrown::hash_table::{ + Entry as TEntry, HashTable, OccupiedEntry as TOccupiedEntry, VacantEntry as TVacantEntry, +}; + +use crate::IdxSize; +use crate::aliases::PlRandomState; +use crate::total_ord::{BuildHasherTotalExt, TotalEq, TotalHash}; + +/// An IndexMap where the keys are hashed and compared with TotalOrd/TotalEq. +pub struct TotalIndexMap { + table: HashTable, + tuples: Vec<(K, V)>, + random_state: PlRandomState, +} + +impl Default for TotalIndexMap { + fn default() -> Self { + Self { + table: HashTable::new(), + tuples: Vec::new(), + random_state: PlRandomState::default(), + } + } +} + +impl TotalIndexMap { + pub fn reserve(&mut self, additional: usize) { + self.table.reserve(additional, |i| unsafe { + let tuple = self.tuples.get_unchecked(*i as usize); + self.random_state.tot_hash_one(&tuple.0) + }); + self.tuples.reserve(additional); + } + + pub fn len(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + pub fn is_empty(&self) -> bool { + self.tuples.is_empty() + } + + pub fn get(&self, key: &K) -> Option<&V> { + let hash = self.random_state.tot_hash_one(key); + let idx = self.table.find(hash, |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == self.random_state.tot_hash_one(&t.0) && key.tot_eq(&t.0) + })?; + unsafe { Some(&self.tuples.get_unchecked(*idx as usize).1) } + } + + pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { + let hash = self.random_state.tot_hash_one(&key); + let entry = self.table.entry( + hash, + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + hash == self.random_state.tot_hash_one(&t.0) && key.tot_eq(&t.0) + }, + |i| unsafe { + let t = self.tuples.get_unchecked(*i as usize); + self.random_state.tot_hash_one(&t.0) + }, + ); + + match entry { + TEntry::Occupied(o) => Entry::Occupied(OccupiedEntry { + entry: o, + tuples: &mut self.tuples, + }), + TEntry::Vacant(v) => Entry::Vacant(VacantEntry { + key, + entry: v, + tuples: &mut self.tuples, + }), + } + } + + /// Insert a key which will never be mapped to. Returns the index of the entry. + /// + /// This is useful for entries which are handled externally. + pub fn push_unmapped_entry(&mut self, key: K, value: V) -> IdxSize { + let ret = self.tuples.len() as IdxSize; + self.tuples.push((key, value)); + ret + } + + /// Gets the key and value at the given index by insertion order. + #[inline(always)] + pub fn get_index(&self, idx: IdxSize) -> Option<(&K, &V)> { + let t = self.tuples.get(idx as usize)?; + Some((&t.0, &t.1)) + } + + /// Gets the key and value at the given index by insertion order. + /// + /// # Safety + /// The index must be less than len(). + #[inline(always)] + pub unsafe fn get_index_unchecked(&self, idx: IdxSize) -> (&K, &V) { + let t = unsafe { self.tuples.get_unchecked(idx as usize) }; + (&t.0, &t.1) + } + + /// Iterates over the keys in insertion order. + pub fn iter_keys(&self) -> impl Iterator { + self.tuples.iter().map(|t| &t.0) + } + + /// Iterates over the values in insertion order. + pub fn iter_values(&self) -> impl Iterator { + self.tuples.iter().map(|t| &t.1) + } +} + +pub enum Entry<'a, K, V> { + Occupied(OccupiedEntry<'a, K, V>), + Vacant(VacantEntry<'a, K, V>), +} + +pub struct OccupiedEntry<'a, K, V> { + entry: TOccupiedEntry<'a, IdxSize>, + tuples: &'a mut Vec<(K, V)>, +} + +impl<'a, K, V> OccupiedEntry<'a, K, V> { + pub fn index(&self) -> IdxSize { + *self.entry.get() + } + + pub fn into_mut(self) -> &'a mut V { + let idx = self.index(); + unsafe { &mut self.tuples.get_unchecked_mut(idx as usize).1 } + } +} + +pub struct VacantEntry<'a, K, V> { + key: K, + entry: TVacantEntry<'a, IdxSize>, + tuples: &'a mut Vec<(K, V)>, +} + +impl<'a, K, V> VacantEntry<'a, K, V> { + pub fn index(&self) -> IdxSize { + self.tuples.len() as IdxSize + } + + pub fn insert(self, value: V) -> &'a mut V { + unsafe { + let tuple_idx: IdxSize = self.tuples.len().try_into().unwrap(); + self.tuples.push((self.key, value)); + self.entry.insert(tuple_idx); + &mut self.tuples.last_mut().unwrap_unchecked().1 + } + } +} diff --git a/crates/polars-utils/src/idx_mapper.rs b/crates/polars-utils/src/idx_mapper.rs new file mode 100644 index 000000000000..d86c50aac4fb --- /dev/null +++ b/crates/polars-utils/src/idx_mapper.rs @@ -0,0 +1,55 @@ +use std::ops::Range; + +/// Reverses indexing direction +pub struct IdxMapper { + total_len: usize, + reverse: bool, +} + +impl IdxMapper { + pub fn new(total_len: usize, reverse: bool) -> Self { + Self { total_len, reverse } + } +} + +impl IdxMapper { + /// # Panics + /// `range.end <= self.total_len` + #[inline] + pub fn map_range(&self, range: Range) -> Range { + if self.reverse { + // len: 5 + // array: [0 1 2 3 4] + // slice: [ 2 3 ] + // in: 1..3 (right-to-left) + // out: 2..4 + map_range::(self.total_len, range) + } else { + range + } + } +} + +/// # Safety +/// `range.end <= total_len` +#[inline] +pub fn map_range(total_len: usize, range: Range) -> Range { + assert!(range.end <= total_len); + if REVERSE { + total_len - range.end..total_len - range.start + } else { + range + } +} + +#[cfg(test)] +mod tests { + use super::IdxMapper; + + #[test] + fn test_idx_map_roundtrip() { + let map = IdxMapper::new(100, true); + + assert_eq!(map.map_range(map.map_range(5..77)), 5..77); + } +} diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs new file mode 100644 index 000000000000..4bb796b7f07a --- /dev/null +++ b/crates/polars-utils/src/idx_vec.rs @@ -0,0 +1,369 @@ +use std::fmt::{Debug, Formatter}; +use std::ops::Deref; + +use crate::index::{IdxSize, NonZeroIdxSize}; + +pub type IdxVec = UnitVec; + +/// A type logically equivalent to `Vec`, but which does not do a +/// memory allocation until at least two elements have been pushed, storing the +/// first element in the data pointer directly. +/// +/// Uses IdxSize internally to store lengths, will panic if trying to reserve +/// for more elements. +#[derive(Eq)] +pub struct UnitVec { + len: IdxSize, + capacity: NonZeroIdxSize, + data: *mut T, +} + +unsafe impl Send for UnitVec {} +unsafe impl Sync for UnitVec {} + +impl UnitVec { + #[inline(always)] + fn data_ptr_mut(&mut self) -> *mut T { + let external = self.data; + let inline = &mut self.data as *mut *mut T as *mut T; + if self.capacity.get() == 1 { + inline + } else { + external + } + } + + #[inline(always)] + fn data_ptr(&self) -> *const T { + let external = self.data; + let inline = &self.data as *const *mut T as *mut T; + if self.capacity.get() == 1 { + inline + } else { + external + } + } + + #[inline] + pub fn new() -> Self { + // This is optimized away, all const. + assert!(size_of::() <= size_of::<*mut T>() && align_of::() <= align_of::<*mut T>()); + Self { + len: 0, + capacity: NonZeroIdxSize::new(1).unwrap(), + data: std::ptr::null_mut(), + } + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.len as usize + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + #[inline(always)] + pub fn capacity(&self) -> usize { + self.capacity.get() as usize + } + + #[inline(always)] + pub fn clear(&mut self) { + self.len = 0; + } + + #[inline(always)] + pub fn push(&mut self, idx: T) { + if self.len == self.capacity.get() { + self.reserve(1); + } + + unsafe { self.push_unchecked(idx) } + } + + #[inline(always)] + /// # Safety + /// Caller must ensure that `UnitVec` has enough capacity. + pub unsafe fn push_unchecked(&mut self, idx: T) { + unsafe { + self.data_ptr_mut().add(self.len as usize).write(idx); + self.len += 1; + } + } + + #[cold] + #[inline(never)] + pub fn reserve(&mut self, additional: usize) { + let new_len = self + .len + .checked_add(additional.try_into().unwrap()) + .unwrap(); + if new_len > self.capacity.get() { + let double = self.capacity.get() * 2; + self.realloc(double.max(new_len).max(8)); + } + } + + /// # Panics + /// Panics if `new_cap <= 1` or `new_cap < self.len` + fn realloc(&mut self, new_cap: IdxSize) { + assert!(new_cap > 1 && new_cap >= self.len); + unsafe { + let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(new_cap as usize)); + let buffer = me.as_mut_ptr(); + std::ptr::copy(self.data_ptr(), buffer, self.len as usize); + self.dealloc(); + self.data = buffer; + self.capacity = NonZeroIdxSize::new(new_cap).unwrap(); + } + } + + fn dealloc(&mut self) { + unsafe { + if self.capacity.get() > 1 { + let _ = Vec::from_raw_parts(self.data, self.len as usize, self.capacity()); + self.capacity = NonZeroIdxSize::new(1).unwrap(); + } + } + } + + pub fn with_capacity(capacity: usize) -> Self { + if capacity <= 1 { + Self::new() + } else { + let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(capacity)); + let data = me.as_mut_ptr(); + Self { + len: 0, + capacity: NonZeroIdxSize::new(capacity.try_into().unwrap()).unwrap(), + data, + } + } + } + + #[inline] + pub fn iter(&self) -> std::slice::Iter<'_, T> { + self.as_slice().iter() + } + + #[inline] + pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> { + self.as_mut_slice().iter_mut() + } + + #[inline] + pub fn as_slice(&self) -> &[T] { + self.as_ref() + } + + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + self.as_mut() + } + + #[inline] + pub fn pop(&mut self) -> Option { + if self.len == 0 { + None + } else { + unsafe { + self.len -= 1; + Some(std::ptr::read(self.as_ptr().add(self.len()))) + } + } + } +} + +impl Extend for UnitVec { + fn extend>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for v in iter { + self.push(v) + } + } +} + +impl Drop for UnitVec { + fn drop(&mut self) { + self.dealloc() + } +} + +impl Clone for UnitVec { + fn clone(&self) -> Self { + unsafe { + if self.capacity.get() == 1 { + Self { ..*self } + } else { + let mut copy = Self::with_capacity(self.len as usize); + std::ptr::copy(self.data_ptr(), copy.data_ptr_mut(), self.len as usize); + copy.len = self.len; + copy + } + } + } +} + +impl Debug for UnitVec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "UnitVec: {:?}", self.as_slice()) + } +} + +impl Default for UnitVec { + fn default() -> Self { + Self { + len: 0, + capacity: NonZeroIdxSize::new(1).unwrap(), + data: std::ptr::null_mut(), + } + } +} + +impl Deref for UnitVec { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +impl AsRef<[T]> for UnitVec { + fn as_ref(&self) -> &[T] { + unsafe { std::slice::from_raw_parts(self.data_ptr(), self.len as usize) } + } +} + +impl AsMut<[T]> for UnitVec { + fn as_mut(&mut self) -> &mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.data_ptr_mut(), self.len as usize) } + } +} + +impl PartialEq for UnitVec { + fn eq(&self, other: &Self) -> bool { + self.as_slice() == other.as_slice() + } +} + +impl FromIterator for UnitVec { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + if iter.size_hint().0 <= 1 { + let mut new = UnitVec::new(); + for v in iter { + new.push(v) + } + new + } else { + let v = iter.collect::>(); + v.into() + } + } +} + +impl From> for UnitVec { + fn from(mut value: Vec) -> Self { + if value.capacity() <= 1 { + let mut new = UnitVec::new(); + if let Some(v) = value.pop() { + new.push(v) + } + new + } else { + let mut me = std::mem::ManuallyDrop::new(value); + UnitVec { + data: me.as_mut_ptr(), + capacity: NonZeroIdxSize::new(me.capacity().try_into().unwrap()).unwrap(), + len: me.len().try_into().unwrap(), + } + } + } +} + +impl From<&[T]> for UnitVec { + fn from(value: &[T]) -> Self { + if value.len() <= 1 { + let mut new = UnitVec::new(); + if let Some(v) = value.first() { + new.push(v.clone()) + } + new + } else { + value.to_vec().into() + } + } +} + +#[macro_export] +macro_rules! unitvec { + () => ( + $crate::idx_vec::UnitVec::new() + ); + ($elem:expr; $n:expr) => ( + let mut new = $crate::idx_vec::UnitVec::new(); + for _ in 0..$n { + new.push($elem) + } + new + ); + ($elem:expr) => ( + {let mut new = $crate::idx_vec::UnitVec::new(); + let v = $elem; + // SAFETY: first element always fits. + unsafe { new.push_unchecked(v) }; + new} + ); + ($($x:expr),+ $(,)?) => ( + vec![$($x),+].into() + ); +} + +mod tests { + + #[test] + #[should_panic] + fn test_unitvec_realloc_zero() { + super::UnitVec::::new().realloc(0); + } + + #[test] + #[should_panic] + fn test_unitvec_realloc_one() { + super::UnitVec::::new().realloc(1); + } + + #[test] + #[should_panic] + fn test_untivec_realloc_lt_len() { + super::UnitVec::::from(&[1, 2][..]).realloc(1) + } + + #[test] + fn test_unitvec_clone() { + { + let v = unitvec![1usize]; + assert_eq!(v, v.clone()); + } + + for n in [ + 26903816120209729usize, + 42566276440897687, + 44435161834424652, + 49390731489933083, + 51201454727649242, + 83861672190814841, + 92169290527847622, + 92476373900398436, + 95488551309275459, + 97499984126814549, + ] { + let v = unitvec![n]; + assert_eq!(v, v.clone()); + } + } +} diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs new file mode 100644 index 000000000000..b6575415b964 --- /dev/null +++ b/crates/polars-utils/src/index.rs @@ -0,0 +1,265 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::fmt::{Debug, Formatter}; + +use polars_error::{PolarsResult, polars_ensure}; + +use crate::nulls::IsNull; + +#[cfg(not(feature = "bigidx"))] +pub type IdxSize = u32; +#[cfg(feature = "bigidx")] +pub type IdxSize = u64; + +#[cfg(not(feature = "bigidx"))] +pub type NonZeroIdxSize = std::num::NonZeroU32; +#[cfg(feature = "bigidx")] +pub type NonZeroIdxSize = std::num::NonZeroU64; + +#[cfg(not(feature = "bigidx"))] +pub type AtomicIdxSize = std::sync::atomic::AtomicU32; +#[cfg(feature = "bigidx")] +pub type AtomicIdxSize = std::sync::atomic::AtomicU64; + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct NullableIdxSize { + pub inner: IdxSize, +} + +impl PartialEq for NullableIdxSize { + fn eq(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl Eq for NullableIdxSize {} + +unsafe impl bytemuck::Zeroable for NullableIdxSize {} +unsafe impl bytemuck::AnyBitPattern for NullableIdxSize {} +unsafe impl bytemuck::NoUninit for NullableIdxSize {} + +impl Debug for NullableIdxSize { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.inner) + } +} + +impl NullableIdxSize { + #[inline(always)] + pub fn is_null_idx(&self) -> bool { + // The left/right join maintain_order algorithms depend on the special value for sorting + self.inner == IdxSize::MAX + } + + #[inline(always)] + pub const fn null() -> Self { + Self { + inner: IdxSize::MAX, + } + } + + #[inline(always)] + pub fn idx(&self) -> IdxSize { + self.inner + } + + #[inline(always)] + pub fn to_opt(&self) -> Option { + if self.is_null_idx() { + None + } else { + Some(self.idx()) + } + } +} + +impl From for NullableIdxSize { + #[inline(always)] + fn from(value: IdxSize) -> Self { + Self { inner: value } + } +} + +pub trait Bounded { + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +pub trait NullCount { + fn null_count(&self) -> usize { + 0 + } +} + +impl NullCount for &T { + fn null_count(&self) -> usize { + (*self).null_count() + } +} + +impl Bounded for &[T] { + fn len(&self) -> usize { + <[T]>::len(self) + } +} + +impl NullCount for &[T] { + fn null_count(&self) -> usize { + 0 + } +} + +pub trait Indexable { + type Item: IsNull; + + fn get(&self, i: usize) -> Self::Item; + + /// # Safety + /// Doesn't do any bound checks. + unsafe fn get_unchecked(&self, i: usize) -> Self::Item; +} + +impl Indexable for &[T] { + type Item = T; + + fn get(&self, i: usize) -> Self::Item { + self[i] + } + + /// # Safety + /// Doesn't do any bound checks. + unsafe fn get_unchecked(&self, i: usize) -> Self::Item { + *<[T]>::get_unchecked(self, i) + } +} + +pub fn check_bounds(idx: &[IdxSize], len: IdxSize) -> PolarsResult<()> { + // We iterate in large uninterrupted chunks to help auto-vectorization. + let Some(max_idx) = idx.iter().copied().max() else { + return Ok(()); + }; + + polars_ensure!(max_idx < len, OutOfBounds: "indices are out of bounds"); + Ok(()) +} + +pub trait ToIdx { + fn to_idx(self, len: u64) -> IdxSize; +} + +macro_rules! impl_to_idx { + ($ty:ty) => { + impl ToIdx for $ty { + #[inline] + fn to_idx(self, _len: u64) -> IdxSize { + self as IdxSize + } + } + }; + ($ty:ty, $ity:ty) => { + impl ToIdx for $ty { + #[inline] + fn to_idx(self, len: u64) -> IdxSize { + let idx = self as $ity; + if idx < 0 { + (idx + len as $ity) as IdxSize + } else { + idx as IdxSize + } + } + } + }; +} + +impl_to_idx!(u8); +impl_to_idx!(u16); +impl_to_idx!(u32); +impl_to_idx!(u64); +impl_to_idx!(i8, i16); +impl_to_idx!(i16, i32); +impl_to_idx!(i32, i64); +impl_to_idx!(i64, i64); + +// Allows for 2^24 (~16M) chunks +// Leaves 2^40 (~1T) rows per chunk +const DEFAULT_CHUNK_BITS: u64 = 24; + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct ChunkId { + swizzled: u64, +} + +impl Debug for ChunkId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.is_null() { + write!(f, "NULL") + } else { + let (chunk, row) = self.extract(); + write!(f, "({chunk}, {row})") + } + } +} + +impl ChunkId { + #[inline(always)] + pub const fn null() -> Self { + Self { swizzled: u64::MAX } + } + + #[inline(always)] + pub fn is_null(&self) -> bool { + self.swizzled == u64::MAX + } + + #[inline(always)] + #[allow(clippy::unnecessary_cast)] + pub fn store(chunk: IdxSize, row: IdxSize) -> Self { + debug_assert!(chunk < !(u64::MAX << CHUNK_BITS) as IdxSize); + let swizzled = ((row as u64) << CHUNK_BITS) | chunk as u64; + + Self { swizzled } + } + + #[inline(always)] + #[allow(clippy::unnecessary_cast)] + pub fn extract(self) -> (IdxSize, IdxSize) { + let row = (self.swizzled >> CHUNK_BITS) as IdxSize; + let mask = (1u64 << CHUNK_BITS) - 1; + let chunk = (self.swizzled & mask) as IdxSize; + (chunk, row) + } + + #[inline(always)] + pub fn inner_mut(&mut self) -> &mut u64 { + &mut self.swizzled + } + + pub fn from_inner(inner: u64) -> Self { + Self { swizzled: inner } + } + + pub fn into_inner(self) -> u64 { + self.swizzled + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_chunk_idx() { + let chunk = 213908; + let row = 813457; + + let ci: ChunkId = ChunkId::store(chunk, row); + let (c, r) = ci.extract(); + + assert_eq!(c, chunk); + assert_eq!(r, row); + } +} diff --git a/crates/polars-utils/src/io.rs b/crates/polars-utils/src/io.rs new file mode 100644 index 000000000000..563f8c399c27 --- /dev/null +++ b/crates/polars-utils/src/io.rs @@ -0,0 +1,35 @@ +use std::fs::File; +use std::io; +use std::path::Path; + +use polars_error::*; + +use crate::config::verbose; + +pub fn _limit_path_len_io_err(path: &Path, err: io::Error) -> PolarsError { + let path = path.to_string_lossy(); + let msg = if path.len() > 88 && !verbose() { + let truncated_path: String = path.chars().skip(path.len() - 88).collect(); + format!("{err}: ...{truncated_path}") + } else { + format!("{err}: {path}") + }; + io::Error::new(err.kind(), msg).into() +} + +pub fn open_file(path: &Path) -> PolarsResult { + File::open(path).map_err(|err| _limit_path_len_io_err(path, err)) +} + +pub fn open_file_write(path: &Path) -> PolarsResult { + std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .map_err(|err| _limit_path_len_io_err(path, err)) +} + +pub fn create_file(path: &Path) -> PolarsResult { + File::create(path).map_err(|err| _limit_path_len_io_err(path, err)) +} diff --git a/crates/polars-utils/src/itertools/enumerate_idx.rs b/crates/polars-utils/src/itertools/enumerate_idx.rs new file mode 100644 index 000000000000..ef079fe33011 --- /dev/null +++ b/crates/polars-utils/src/itertools/enumerate_idx.rs @@ -0,0 +1,101 @@ +use num_traits::{FromPrimitive, One, Zero}; + +/// An iterator that yields the current count and the element during iteration. +/// +/// This `struct` is created by the [`enumerate`] method on [`Iterator`]. See its +/// documentation for more. +/// +/// [`enumerate`]: Iterator::enumerate +/// [`Iterator`]: trait.Iterator.html +#[derive(Clone, Debug)] +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct EnumerateIdx { + iter: I, + count: IdxType, +} + +impl EnumerateIdx { + pub fn new(iter: I) -> Self { + Self { + iter, + count: IdxType::zero(), + } + } +} + +impl Iterator for EnumerateIdx +where + I: Iterator, + IdxType: std::ops::Add + FromPrimitive + std::ops::AddAssign + One + Copy, +{ + type Item = (IdxType, ::Item); + + /// # Overflow Behavior + /// + /// The method does no guarding against overflows, so enumerating more than + /// `idx::MAX` elements either produces the wrong result or panics. If + /// debug assertions are enabled, a panic is guaranteed. + /// + /// # Panics + /// + /// Might panic if the index of the element overflows a `idx`. + #[inline] + fn next(&mut self) -> Option { + let a = self.iter.next()?; + let i = self.count; + self.count += IdxType::one(); + Some((i, a)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let a = self.iter.nth(n)?; + let i = self.count + IdxType::from_usize(n).unwrap(); + self.count = i + IdxType::one(); + Some((i, a)) + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } +} + +impl DoubleEndedIterator for EnumerateIdx +where + I: ExactSizeIterator + DoubleEndedIterator, + IdxType: std::ops::Add + FromPrimitive + std::ops::AddAssign + One + Copy, +{ + #[inline] + fn next_back(&mut self) -> Option<(IdxType, ::Item)> { + let a = self.iter.next_back()?; + let len = IdxType::from_usize(self.iter.len()).unwrap(); + // Can safely add, `ExactSizeIterator` promises that the number of + // elements fits into a `usize`. + Some((self.count + len, a)) + } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option<(IdxType, ::Item)> { + let a = self.iter.nth_back(n)?; + let len = IdxType::from_usize(self.iter.len()).unwrap(); + // Can safely add, `ExactSizeIterator` promises that the number of + // elements fits into a `usize`. + Some((self.count + len, a)) + } +} + +impl ExactSizeIterator for EnumerateIdx +where + I: ExactSizeIterator, + IdxType: std::ops::Add + FromPrimitive + std::ops::AddAssign + One + Copy, +{ + fn len(&self) -> usize { + self.iter.len() + } +} diff --git a/crates/polars-utils/src/itertools/mod.rs b/crates/polars-utils/src/itertools/mod.rs new file mode 100644 index 000000000000..7c755444be1f --- /dev/null +++ b/crates/polars-utils/src/itertools/mod.rs @@ -0,0 +1,106 @@ +use std::cmp::Ordering; + +use crate::IdxSize; + +pub mod enumerate_idx; + +/// Utility extension trait of iterator methods. +pub trait Itertools: Iterator { + /// Equivalent to `.collect::>()`. + fn collect_vec(self) -> Vec + where + Self: Sized, + { + self.collect() + } + + /// Equivalent to `.collect::>()`. + fn try_collect(self) -> Result + where + Self: Sized + Iterator>, + Result: FromIterator>, + { + self.collect() + } + + /// Equivalent to `.collect::, _>>()`. + fn try_collect_vec(self) -> Result, E> + where + Self: Sized + Iterator>, + Result, E>: FromIterator>, + { + self.collect() + } + + fn enumerate_idx(self) -> enumerate_idx::EnumerateIdx + where + Self: Sized, + { + enumerate_idx::EnumerateIdx::new(self) + } + + fn enumerate_u32(self) -> enumerate_idx::EnumerateIdx + where + Self: Sized, + { + enumerate_idx::EnumerateIdx::new(self) + } + + fn all_equal(mut self) -> bool + where + Self: Sized, + Self::Item: PartialEq, + { + match self.next() { + None => true, + Some(a) => self.all(|x| a == x), + } + } + + // Stable copy of the unstable eq_by from the stdlib. + fn eq_by_(mut self, other: I, mut eq: F) -> bool + where + Self: Sized, + I: IntoIterator, + F: FnMut(Self::Item, I::Item) -> bool, + { + let mut other = other.into_iter(); + loop { + match (self.next(), other.next()) { + (None, None) => return true, + (None, Some(_)) => return false, + (Some(_), None) => return false, + (Some(l), Some(r)) => { + if eq(l, r) { + continue; + } else { + return false; + } + }, + } + } + } + + // Stable copy of the unstable partial_cmp_by from the stdlib. + fn partial_cmp_by_(mut self, other: I, mut partial_cmp: F) -> Option + where + Self: Sized, + I: IntoIterator, + F: FnMut(Self::Item, I::Item) -> Option, + { + let mut other = other.into_iter(); + loop { + match (self.next(), other.next()) { + (None, None) => return Some(Ordering::Equal), + (None, Some(_)) => return Some(Ordering::Less), + (Some(_), None) => return Some(Ordering::Greater), + (Some(l), Some(r)) => match partial_cmp(l, r) { + Some(Ordering::Equal) => continue, + ord => return ord, + }, + } + } + } +} + +impl Itertools for T {} diff --git a/crates/polars-utils/src/kahan_sum.rs b/crates/polars-utils/src/kahan_sum.rs new file mode 100644 index 000000000000..4a375ad77779 --- /dev/null +++ b/crates/polars-utils/src/kahan_sum.rs @@ -0,0 +1,54 @@ +use std::ops::{Add, AddAssign}; + +use num_traits::Float; + +#[derive(Debug, Clone)] +pub struct KahanSum { + sum: T, + err: T, +} + +impl KahanSum { + pub fn new(v: T) -> Self { + KahanSum { + sum: v, + err: T::zero(), + } + } + + pub fn sum(&self) -> T { + self.sum + } +} + +impl Default for KahanSum { + fn default() -> Self { + KahanSum { + sum: T::zero(), + err: T::zero(), + } + } +} + +impl AddAssign for KahanSum { + fn add_assign(&mut self, rhs: T) { + if rhs.is_finite() { + let y = rhs - self.err; + let new_sum = self.sum + y; + self.err = (new_sum - self.sum) - y; + self.sum = new_sum; + } else { + self.sum += rhs + } + } +} + +impl Add for KahanSum { + type Output = Self; + + fn add(self, rhs: T) -> Self::Output { + let mut rv = self; + rv += rhs; + rv + } +} diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs new file mode 100644 index 000000000000..70fbebffbe74 --- /dev/null +++ b/crates/polars-utils/src/lib.rs @@ -0,0 +1,75 @@ +#![cfg_attr( + all(target_arch = "aarch64", feature = "nightly"), + feature(stdarch_aarch64_prefetch) +)] +#![cfg_attr(feature = "nightly", feature(core_intrinsics))] // For algebraic ops. +#![cfg_attr(feature = "nightly", feature(select_unpredictable))] // For branchless programming. +#![cfg_attr(feature = "nightly", allow(internal_features))] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +pub mod abs_diff; +pub mod algebraic_ops; +pub mod arena; +pub mod binary_search; +pub mod cache; +pub mod cardinality_sketch; +pub mod cell; +pub mod chunks; +pub mod clmul; +mod config; +pub mod cpuid; +pub mod error; +pub mod floor_divmod; +pub mod functions; +pub mod hashing; +pub mod idx_map; +pub mod idx_mapper; +pub mod idx_vec; +pub mod mem; +pub mod min_max; +pub mod pl_str; +pub mod priority; +pub mod regex_cache; +pub mod select; +pub mod slice; +pub mod slice_enum; +pub mod sort; +pub mod sparse_init_vec; +pub mod sync; +#[cfg(feature = "sysinfo")] +pub mod sys; +pub mod total_ord; + +pub use functions::*; +pub mod file; + +pub mod aliases; +pub mod fixedringbuffer; +pub mod fmt; +pub mod itertools; +pub mod macros; +pub mod vec; +#[cfg(target_family = "wasm")] +pub mod wasm; + +pub mod float; +pub mod index; +pub mod io; +#[cfg(feature = "mmap")] +pub mod mmap; +pub mod nulls; +pub mod partitioned; + +pub use index::{IdxSize, NullableIdxSize}; +pub use io::*; +pub use pl_str::unique_column_name; + +#[cfg(feature = "python")] +pub mod python_function; + +#[cfg(feature = "python")] +pub mod python_convert_registry; + +#[cfg(feature = "serde")] +pub mod pl_serialize; + +pub mod kahan_sum; diff --git a/crates/polars-utils/src/macros.rs b/crates/polars-utils/src/macros.rs new file mode 100644 index 000000000000..50511991773f --- /dev/null +++ b/crates/polars-utils/src/macros.rs @@ -0,0 +1,15 @@ +#[macro_export] +macro_rules! matches_any_order { + ($expression1:expr, $expression2:expr, $( $pattern1:pat_param )|+, $( $pattern2:pat_param )|+) => { + (matches!($expression1, $( $pattern1 )|+) && matches!($expression2, $( $pattern2)|+)) || + matches!($expression2, $( $pattern1 ) |+) && matches!($expression1, $( $pattern2)|+) + } +} + +#[macro_export] +macro_rules! no_call_const { + () => {{ + const { assert!(false, "should not be called") } + unreachable!() + }}; +} diff --git a/crates/polars-utils/src/mem.rs b/crates/polars-utils/src/mem.rs new file mode 100644 index 000000000000..2bc8cfc7afef --- /dev/null +++ b/crates/polars-utils/src/mem.rs @@ -0,0 +1,181 @@ +use std::sync::LazyLock; + +/// # Safety +/// This may break aliasing rules, make sure you are the only owner. +#[allow(clippy::mut_from_ref)] +pub unsafe fn to_mutable_slice(s: &[T]) -> &mut [T] { + let ptr = s.as_ptr() as *mut T; + let len = s.len(); + unsafe { std::slice::from_raw_parts_mut(ptr, len) } +} + +pub static PAGE_SIZE: LazyLock = LazyLock::new(|| { + #[cfg(target_family = "unix")] + unsafe { + libc::sysconf(libc::_SC_PAGESIZE) as usize + } + #[cfg(not(target_family = "unix"))] + { + 4096 + } +}); + +pub mod prefetch { + use super::PAGE_SIZE; + + /// # Safety + /// + /// This should only be called with pointers to valid memory. + unsafe fn prefetch_l2_impl(ptr: *const u8) { + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::*; + unsafe { _mm_prefetch(ptr as *const _, _MM_HINT_T1) }; + } + + #[cfg(all(target_arch = "aarch64", feature = "nightly"))] + { + use std::arch::aarch64::*; + unsafe { _prefetch(ptr as *const _, _PREFETCH_READ, _PREFETCH_LOCALITY2) }; + } + } + + /// Attempt to prefetch the memory in the slice to the L2 cache. + pub fn prefetch_l2(slice: &[u8]) { + if slice.is_empty() { + return; + } + + // @TODO: We can play a bit more with this prefetching. Maybe introduce a maximum number of + // prefetches as to not overwhelm the processor. The linear prefetcher should pick it up + // at a certain point. + + for i in (0..slice.len()).step_by(*PAGE_SIZE) { + unsafe { prefetch_l2_impl(slice[i..].as_ptr()) }; + } + + unsafe { prefetch_l2_impl(slice[slice.len() - 1..].as_ptr()) } + } + + /// `madvise()` with `MADV_SEQUENTIAL` on unix systems. This is a no-op on non-unix systems. + pub fn madvise_sequential(#[allow(unused)] slice: &[u8]) { + #[cfg(target_family = "unix")] + madvise(slice, libc::MADV_SEQUENTIAL); + } + + /// `madvise()` with `MADV_WILLNEED` on unix systems. This is a no-op on non-unix systems. + pub fn madvise_willneed(#[allow(unused)] slice: &[u8]) { + #[cfg(target_family = "unix")] + madvise(slice, libc::MADV_WILLNEED); + } + + /// `madvise()` with `MADV_POPULATE_READ` on linux systems. This a no-op on non-linux systems. + pub fn madvise_populate_read(#[allow(unused)] slice: &[u8]) { + #[cfg(target_os = "linux")] + madvise(slice, libc::MADV_POPULATE_READ); + } + + /// Forcibly reads at least one byte each page. + pub fn force_populate_read(slice: &[u8]) { + for i in (0..slice.len()).step_by(*PAGE_SIZE) { + std::hint::black_box(slice[i]); + } + + std::hint::black_box(slice.last().copied()); + } + + #[cfg(target_family = "unix")] + fn madvise(slice: &[u8], advice: libc::c_int) { + if slice.is_empty() { + return; + } + let ptr = slice.as_ptr(); + + let align = ptr as usize % *PAGE_SIZE; + let ptr = ptr.wrapping_sub(align); + let len = slice.len() + align; + + if unsafe { libc::madvise(ptr as *mut libc::c_void, len, advice) } != 0 { + let err = std::io::Error::last_os_error(); + if let std::io::ErrorKind::InvalidInput = err.kind() { + panic!("{}", err); + } + } + } + + pub fn no_prefetch(_: &[u8]) {} + + /// Get the configured memory prefetch function. + pub fn get_memory_prefetch_func(verbose: bool) -> fn(&[u8]) -> () { + let memory_prefetch_func = match std::env::var("POLARS_MEMORY_PREFETCH").ok().as_deref() { + None => { + // madvise_willneed performed the best on both MacOS on Apple Silicon and Ubuntu on x86-64, + // using PDS-H query 3 SF=10 after clearing file cache as a benchmark. + #[cfg(target_family = "unix")] + { + madvise_willneed + } + #[cfg(not(target_family = "unix"))] + { + no_prefetch + } + }, + Some("no_prefetch") => no_prefetch, + Some("prefetch_l2") => prefetch_l2, + Some("madvise_sequential") => { + #[cfg(target_family = "unix")] + { + madvise_sequential + } + #[cfg(not(target_family = "unix"))] + { + panic!( + "POLARS_MEMORY_PREFETCH=madvise_sequential is not supported by this system" + ); + } + }, + Some("madvise_willneed") => { + #[cfg(target_family = "unix")] + { + madvise_willneed + } + #[cfg(not(target_family = "unix"))] + { + panic!( + "POLARS_MEMORY_PREFETCH=madvise_willneed is not supported by this system" + ); + } + }, + Some("madvise_populate_read") => { + #[cfg(target_os = "linux")] + { + madvise_populate_read + } + #[cfg(not(target_os = "linux"))] + { + panic!( + "POLARS_MEMORY_PREFETCH=madvise_populate_read is not supported by this system" + ); + } + }, + Some("force_populate_read") => force_populate_read, + Some(v) => panic!("invalid value for POLARS_MEMORY_PREFETCH: {}", v), + }; + + if verbose { + let func_name = match memory_prefetch_func as usize { + v if v == no_prefetch as usize => "no_prefetch", + v if v == prefetch_l2 as usize => "prefetch_l2", + v if v == madvise_sequential as usize => "madvise_sequential", + v if v == madvise_willneed as usize => "madvise_willneed", + v if v == madvise_populate_read as usize => "madvise_populate_read", + v if v == force_populate_read as usize => "force_populate_read", + _ => unreachable!(), + }; + + eprintln!("memory prefetch function: {}", func_name); + } + + memory_prefetch_func + } +} diff --git a/crates/polars-utils/src/min_max.rs b/crates/polars-utils/src/min_max.rs new file mode 100644 index 000000000000..4632906b3de8 --- /dev/null +++ b/crates/polars-utils/src/min_max.rs @@ -0,0 +1,174 @@ +// These min/max operators don't follow our total order strictly. Instead +// if exactly one of the two arguments is NaN the skip_nan varieties returns +// the non-nan argument, whereas the propagate_nan varieties give the nan +// argument. If both/neither argument is NaN these extrema follow the normal +// total order. +// +// They also violate the regular total order for Option: on top of the +// above rules None's are always ignored, so only if both arguments are +// None is the output None. +pub trait MinMax: Sized { + // Comparison operators that either consider nan to be the smallest, or the + // largest possible value. Use tot_eq for equality. Prefer directly using + // min/max, they're slightly faster. + fn nan_min_lt(&self, other: &Self) -> bool; + fn nan_max_lt(&self, other: &Self) -> bool; + + // Binary operators that return either the minimum or maximum. + #[inline(always)] + fn min_propagate_nan(self, other: Self) -> Self { + if self.nan_min_lt(&other) { self } else { other } + } + + #[inline(always)] + fn max_propagate_nan(self, other: Self) -> Self { + if self.nan_max_lt(&other) { other } else { self } + } + + #[inline(always)] + fn min_ignore_nan(self, other: Self) -> Self { + if self.nan_max_lt(&other) { self } else { other } + } + + #[inline(always)] + fn max_ignore_nan(self, other: Self) -> Self { + if self.nan_min_lt(&other) { other } else { self } + } +} + +macro_rules! impl_trivial_min_max { + ($T: ty) => { + impl MinMax for $T { + #[inline(always)] + fn nan_min_lt(&self, other: &Self) -> bool { + self < other + } + + #[inline(always)] + fn nan_max_lt(&self, other: &Self) -> bool { + self < other + } + } + }; +} + +// We can't do a blanket impl because Rust complains f32 might implement +// Ord someday. +impl_trivial_min_max!(bool); +impl_trivial_min_max!(u8); +impl_trivial_min_max!(u16); +impl_trivial_min_max!(u32); +impl_trivial_min_max!(u64); +impl_trivial_min_max!(u128); +impl_trivial_min_max!(usize); +impl_trivial_min_max!(i8); +impl_trivial_min_max!(i16); +impl_trivial_min_max!(i32); +impl_trivial_min_max!(i64); +impl_trivial_min_max!(i128); +impl_trivial_min_max!(isize); +impl_trivial_min_max!(char); +impl_trivial_min_max!(&str); +impl_trivial_min_max!(&[u8]); +impl_trivial_min_max!(String); + +macro_rules! impl_float_min_max { + ($T: ty) => { + impl MinMax for $T { + #[inline(always)] + fn nan_min_lt(&self, other: &Self) -> bool { + !(other.is_nan() | (self >= other)) + } + + #[inline(always)] + fn nan_max_lt(&self, other: &Self) -> bool { + !(self.is_nan() | (self >= other)) + } + + #[inline(always)] + fn min_ignore_nan(self, other: Self) -> Self { + <$T>::min(self, other) + } + + #[inline(always)] + fn max_ignore_nan(self, other: Self) -> Self { + <$T>::max(self, other) + } + + #[inline(always)] + fn min_propagate_nan(self, other: Self) -> Self { + if (self < other) | self.is_nan() { + self + } else { + other + } + } + + #[inline(always)] + fn max_propagate_nan(self, other: Self) -> Self { + if (self > other) | self.is_nan() { + self + } else { + other + } + } + } + }; +} + +impl_float_min_max!(f32); +impl_float_min_max!(f64); + +pub trait MinMaxPolicy { + // Is the first argument strictly better than the second, per the policy? + fn is_better(a: &T, b: &T) -> bool; + fn best(a: T, b: T) -> T; +} + +#[derive(Copy, Clone, Debug)] +pub struct MinIgnoreNan; +impl MinMaxPolicy for MinIgnoreNan { + fn is_better(a: &T, b: &T) -> bool { + T::nan_max_lt(a, b) + } + + fn best(a: T, b: T) -> T { + T::min_ignore_nan(a, b) + } +} + +#[derive(Copy, Clone, Debug)] +pub struct MinPropagateNan; +impl MinMaxPolicy for MinPropagateNan { + fn is_better(a: &T, b: &T) -> bool { + T::nan_min_lt(a, b) + } + + fn best(a: T, b: T) -> T { + T::min_propagate_nan(a, b) + } +} + +#[derive(Copy, Clone, Debug)] +pub struct MaxIgnoreNan; +impl MinMaxPolicy for MaxIgnoreNan { + fn is_better(a: &T, b: &T) -> bool { + T::nan_min_lt(b, a) + } + + fn best(a: T, b: T) -> T { + T::max_ignore_nan(a, b) + } +} + +#[derive(Copy, Clone, Debug)] +pub struct MaxPropagateNan; +impl MinMaxPolicy for MaxPropagateNan { + fn is_better(a: &T, b: &T) -> bool { + T::nan_max_lt(b, a) + } + + fn best(a: T, b: T) -> T { + T::max_propagate_nan(a, b) + } +} diff --git a/crates/polars-utils/src/mmap.rs b/crates/polars-utils/src/mmap.rs new file mode 100644 index 000000000000..1f5d498f3266 --- /dev/null +++ b/crates/polars-utils/src/mmap.rs @@ -0,0 +1,489 @@ +use std::ffi::c_void; +use std::fs::File; +use std::io; +use std::mem::ManuallyDrop; +use std::sync::LazyLock; + +pub use memmap::Mmap; + +mod private { + use std::fs::File; + use std::ops::Deref; + use std::sync::Arc; + + use polars_error::PolarsResult; + + use super::MMapSemaphore; + use crate::mem::prefetch::prefetch_l2; + + /// A read-only reference to a slice of memory that can potentially be memory-mapped. + /// + /// A reference count is kept to the underlying buffer to ensure the memory is kept alive. + /// [`MemSlice::slice`] can be used to slice the memory in a zero-copy manner. + /// + /// This still owns the all the original memory and therefore should probably not be a long-lasting + /// structure. + #[derive(Clone, Debug)] + pub struct MemSlice { + // Store the `&[u8]` to make the `Deref` free. + // `slice` is not 'static - it is backed by `inner`. This is safe as long as `slice` is not + // directly accessed, and we are in a private module to guarantee that. Access should only + // be done through `Deref`, which automatically gives the correct lifetime. + slice: &'static [u8], + #[allow(unused)] + inner: MemSliceInner, + } + + /// Keeps the underlying buffer alive. This should be cheaply cloneable. + #[derive(Clone, Debug)] + #[allow(unused)] + enum MemSliceInner { + Bytes(bytes::Bytes), // Separate because it does atomic refcounting internally + Arc(Arc), + } + + impl Deref for MemSlice { + type Target = [u8]; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.slice + } + } + + impl AsRef<[u8]> for MemSlice { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.slice + } + } + + impl Default for MemSlice { + fn default() -> Self { + Self::from_bytes(bytes::Bytes::new()) + } + } + + impl From> for MemSlice { + fn from(value: Vec) -> Self { + Self::from_vec(value) + } + } + + impl MemSlice { + pub const EMPTY: Self = Self::from_static(&[]); + + /// Copy the contents into a new owned `Vec` + #[inline(always)] + pub fn to_vec(self) -> Vec { + <[u8]>::to_vec(self.deref()) + } + + /// Construct a `MemSlice` from an existing `Vec`. This is zero-copy. + #[inline] + pub fn from_vec(v: Vec) -> Self { + Self::from_bytes(bytes::Bytes::from(v)) + } + + /// Construct a `MemSlice` from [`bytes::Bytes`]. This is zero-copy. + #[inline] + pub fn from_bytes(bytes: bytes::Bytes) -> Self { + Self { + slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(bytes.as_ref()) }, + inner: MemSliceInner::Bytes(bytes), + } + } + + #[inline] + pub fn from_mmap(mmap: Arc) -> Self { + Self { + slice: unsafe { + std::mem::transmute::<&[u8], &'static [u8]>(mmap.as_ref().as_ref()) + }, + inner: MemSliceInner::Arc(mmap), + } + } + + #[inline] + pub fn from_arc(slice: &[u8], arc: Arc) -> Self + where + T: std::fmt::Debug + Send + Sync + 'static, + { + Self { + slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(slice) }, + inner: MemSliceInner::Arc(arc), + } + } + + #[inline] + pub fn from_file(file: &File) -> PolarsResult { + let mmap = MMapSemaphore::new_from_file(file)?; + Ok(Self::from_mmap(Arc::new(mmap))) + } + + /// Construct a `MemSlice` that simply wraps around a `&[u8]`. + #[inline] + pub const fn from_static(slice: &'static [u8]) -> Self { + let inner = MemSliceInner::Bytes(bytes::Bytes::from_static(slice)); + Self { slice, inner } + } + + /// Attempt to prefetch the memory belonging to to this [`MemSlice`] + #[inline] + pub fn prefetch(&self) { + prefetch_l2(self.as_ref()); + } + + /// # Panics + /// Panics if range is not in bounds. + #[inline] + #[track_caller] + pub fn slice(&self, range: std::ops::Range) -> Self { + let mut out = self.clone(); + out.slice = &out.slice[range]; + out + } + } + + impl From for MemSlice { + fn from(value: bytes::Bytes) -> Self { + Self::from_bytes(value) + } + } +} + +use memmap::MmapOptions; +use polars_error::PolarsResult; +#[cfg(target_family = "unix")] +use polars_error::polars_bail; +pub use private::MemSlice; +use rayon::{ThreadPool, ThreadPoolBuilder}; + +use crate::mem::PAGE_SIZE; + +/// A cursor over a [`MemSlice`]. +#[derive(Debug, Clone)] +pub struct MemReader { + data: MemSlice, + position: usize, +} + +impl MemReader { + pub fn new(data: MemSlice) -> Self { + Self { data, position: 0 } + } + + #[inline(always)] + pub fn remaining_len(&self) -> usize { + self.data.len() - self.position + } + + #[inline(always)] + pub fn total_len(&self) -> usize { + self.data.len() + } + + #[inline(always)] + pub fn position(&self) -> usize { + self.position + } + + /// Construct a `MemSlice` from an existing `Vec`. This is zero-copy. + #[inline(always)] + pub fn from_vec(v: Vec) -> Self { + Self::new(MemSlice::from_vec(v)) + } + + /// Construct a `MemSlice` from [`bytes::Bytes`]. This is zero-copy. + #[inline(always)] + pub fn from_bytes(bytes: bytes::Bytes) -> Self { + Self::new(MemSlice::from_bytes(bytes)) + } + + // Construct a `MemSlice` that simply wraps around a `&[u8]`. The caller must ensure the + /// slice outlives the returned `MemSlice`. + #[inline] + pub fn from_slice(slice: &'static [u8]) -> Self { + Self::new(MemSlice::from_static(slice)) + } + + #[inline(always)] + pub fn from_reader(mut reader: R) -> io::Result { + let mut vec = Vec::new(); + reader.read_to_end(&mut vec)?; + Ok(Self::from_vec(vec)) + } + + #[inline(always)] + pub fn read_slice(&mut self, n: usize) -> MemSlice { + let start = self.position; + let end = usize::min(self.position + n, self.data.len()); + self.position = end; + self.data.slice(start..end) + } +} + +impl From for MemReader { + fn from(data: MemSlice) -> Self { + Self { data, position: 0 } + } +} + +impl io::Read for MemReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let n = usize::min(buf.len(), self.remaining_len()); + buf[..n].copy_from_slice(&self.data[self.position..self.position + n]); + self.position += n; + Ok(n) + } +} + +impl io::Seek for MemReader { + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + let position = match pos { + io::SeekFrom::Start(position) => usize::min(position as usize, self.total_len()), + io::SeekFrom::End(offset) => { + let Some(position) = self.total_len().checked_add_signed(offset as isize) else { + return Err(io::Error::other("Seek before to before buffer")); + }; + + position + }, + io::SeekFrom::Current(offset) => { + let Some(position) = self.position.checked_add_signed(offset as isize) else { + return Err(io::Error::other("Seek before to before buffer")); + }; + + position + }, + }; + + self.position = position; + + Ok(position as u64) + } +} + +pub static UNMAP_POOL: LazyLock = LazyLock::new(|| { + let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string()); + ThreadPoolBuilder::new() + .num_threads(1) + .thread_name(move |i| format!("{}-unmap-{}", thread_name, i)) + .build() + .expect("could not spawn threads") +}); + +// Keep track of memory mapped files so we don't write to them while reading +// Use a btree as it uses less memory than a hashmap and this thing never shrinks. +// Write handle in Windows is exclusive, so this is only necessary in Unix. +#[cfg(target_family = "unix")] +static MEMORY_MAPPED_FILES: std::sync::LazyLock< + std::sync::Mutex>, +> = std::sync::LazyLock::new(|| std::sync::Mutex::new(Default::default())); + +#[derive(Debug)] +pub struct MMapSemaphore { + #[cfg(target_family = "unix")] + key: (u64, u64), + mmap: ManuallyDrop, +} + +impl Drop for MMapSemaphore { + fn drop(&mut self) { + #[cfg(target_family = "unix")] + { + let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); + if let std::collections::btree_map::Entry::Occupied(mut e) = guard.entry(self.key) { + let v = e.get_mut(); + *v -= 1; + + if *v == 0 { + e.remove_entry(); + } + } + } + + unsafe { + let mmap = ManuallyDrop::take(&mut self.mmap); + // If the unmap is 1 MiB or bigger, we do it in a background thread. + let len = self.mmap.len(); + if len >= 1024 * 1024 { + UNMAP_POOL.spawn(move || { + #[cfg(target_family = "unix")] + { + // If the unmap is bigger than our chunk size (32 MiB), we do it in chunks. + // This is because munmap holds a lock on the unmap file, which we don't + // want to hold for extended periods of time. + let chunk_size = (32_usize * 1024 * 1024).next_multiple_of(*PAGE_SIZE); + if len > chunk_size { + let mmap = ManuallyDrop::new(mmap); + let ptr: *const u8 = mmap.as_ptr(); + let mut offset = 0; + while offset < len { + let remaining = len - offset; + libc::munmap( + ptr.add(offset) as *mut c_void, + remaining.min(chunk_size), + ); + offset += chunk_size; + } + return; + } + } + drop(mmap) + }); + } else { + drop(mmap); + } + } + } +} + +impl MMapSemaphore { + pub fn new_from_file_with_options( + file: &File, + options: MmapOptions, + ) -> PolarsResult { + let mmap = unsafe { options.map(file) }?; + + #[cfg(target_family = "unix")] + { + // FIXME: We aren't handling the case where the file is already open in write-mode here. + + use std::os::unix::fs::MetadataExt; + let metadata = file.metadata()?; + + let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); + let key = (metadata.dev(), metadata.ino()); + match guard.entry(key) { + std::collections::btree_map::Entry::Occupied(mut e) => *e.get_mut() += 1, + std::collections::btree_map::Entry::Vacant(e) => _ = e.insert(1), + } + Ok(Self { + key, + mmap: ManuallyDrop::new(mmap), + }) + } + + #[cfg(not(target_family = "unix"))] + Ok(Self { + mmap: ManuallyDrop::new(mmap), + }) + } + + pub fn new_from_file(file: &File) -> PolarsResult { + Self::new_from_file_with_options(file, MmapOptions::default()) + } + + pub fn as_ptr(&self) -> *const u8 { + self.mmap.as_ptr() + } +} + +impl AsRef<[u8]> for MMapSemaphore { + #[inline] + fn as_ref(&self) -> &[u8] { + self.mmap.as_ref() + } +} + +pub fn ensure_not_mapped( + #[cfg_attr(not(target_family = "unix"), allow(unused))] file_md: &std::fs::Metadata, +) -> PolarsResult<()> { + // TODO: We need to actually register that this file has been write-opened and prevent + // read-opening this file based on that. + #[cfg(target_family = "unix")] + { + use std::os::unix::fs::MetadataExt; + let guard = MEMORY_MAPPED_FILES.lock().unwrap(); + if guard.contains_key(&(file_md.dev(), file_md.ino())) { + polars_bail!(ComputeError: "cannot write to file: already memory mapped"); + } + } + Ok(()) +} + +mod tests { + #[test] + fn test_mem_slice_zero_copy() { + use std::sync::Arc; + + use super::MemSlice; + + { + let vec = vec![1u8, 2, 3, 4, 5]; + let ptr = vec.as_ptr(); + + let mem_slice = MemSlice::from_vec(vec); + let ptr_out = mem_slice.as_ptr(); + + assert_eq!(ptr_out, ptr); + } + + { + let mut vec = vec![1u8, 2, 3, 4, 5]; + vec.truncate(2); + let ptr = vec.as_ptr(); + + let mem_slice = MemSlice::from_vec(vec); + let ptr_out = mem_slice.as_ptr(); + + assert_eq!(ptr_out, ptr); + } + + { + let bytes = bytes::Bytes::from(vec![1u8, 2, 3, 4, 5]); + let ptr = bytes.as_ptr(); + + let mem_slice = MemSlice::from_bytes(bytes); + let ptr_out = mem_slice.as_ptr(); + + assert_eq!(ptr_out, ptr); + } + + { + use crate::mmap::MMapSemaphore; + + let path = "../../examples/datasets/foods1.csv"; + let file = std::fs::File::open(path).unwrap(); + let mmap = MMapSemaphore::new_from_file(&file).unwrap(); + let ptr = mmap.as_ptr(); + + let mem_slice = MemSlice::from_mmap(Arc::new(mmap)); + let ptr_out = mem_slice.as_ptr(); + + assert_eq!(ptr_out, ptr); + } + + { + let vec = vec![1u8, 2, 3, 4, 5]; + let slice = vec.as_slice(); + let ptr = slice.as_ptr(); + + let mem_slice = MemSlice::from_static(unsafe { + std::mem::transmute::<&[u8], &'static [u8]>(slice) + }); + let ptr_out = mem_slice.as_ptr(); + + assert_eq!(ptr_out, ptr); + } + } + + #[test] + fn test_mem_slice_slicing() { + use super::MemSlice; + + { + let vec = vec![1u8, 2, 3, 4, 5]; + let slice = vec.as_slice(); + + let mem_slice = MemSlice::from_static(unsafe { + std::mem::transmute::<&[u8], &'static [u8]>(slice) + }); + + let out = &*mem_slice.slice(3..5); + assert_eq!(out, &slice[3..5]); + assert_eq!(out.as_ptr(), slice[3..5].as_ptr()); + } + } +} diff --git a/crates/polars-utils/src/nulls.rs b/crates/polars-utils/src/nulls.rs new file mode 100644 index 000000000000..3fe40e5eb86a --- /dev/null +++ b/crates/polars-utils/src/nulls.rs @@ -0,0 +1,85 @@ +pub trait IsNull { + const HAS_NULLS: bool; + type Inner; + + fn is_null(&self) -> bool; + + fn unwrap_inner(self) -> Self::Inner; +} + +impl IsNull for Option { + const HAS_NULLS: bool = true; + type Inner = T; + + #[inline(always)] + fn is_null(&self) -> bool { + self.is_none() + } + + #[inline(always)] + fn unwrap_inner(self) -> Self::Inner { + Option::unwrap(self) + } +} + +macro_rules! impl_is_null ( + ($($ty:tt)*) => { + impl IsNull for $($ty)* { + const HAS_NULLS: bool = false; + type Inner = $($ty)*; + + + #[inline(always)] + fn is_null(&self) -> bool { + false + } + + #[inline(always)] + fn unwrap_inner(self) -> $($ty)* { + self + } + } + }; +); + +impl_is_null!(bool); +impl_is_null!(f32); +impl_is_null!(f64); +impl_is_null!(i8); +impl_is_null!(i16); +impl_is_null!(i32); +impl_is_null!(i64); +impl_is_null!(i128); +impl_is_null!(u8); +impl_is_null!(u16); +impl_is_null!(u32); +impl_is_null!(u64); +impl_is_null!(u128); + +impl<'a> IsNull for &'a [u8] { + const HAS_NULLS: bool = false; + type Inner = &'a [u8]; + + #[inline(always)] + fn is_null(&self) -> bool { + false + } + + #[inline(always)] + fn unwrap_inner(self) -> Self::Inner { + self + } +} + +impl<'a, T: IsNull + ?Sized> IsNull for &'a T { + const HAS_NULLS: bool = false; + type Inner = &'a T; + + fn is_null(&self) -> bool { + (*self).is_null() + } + + fn unwrap_inner(self) -> Self::Inner { + self + } +} diff --git a/crates/polars-utils/src/partitioned.rs b/crates/polars-utils/src/partitioned.rs new file mode 100644 index 000000000000..e3f33cc1adbc --- /dev/null +++ b/crates/polars-utils/src/partitioned.rs @@ -0,0 +1,49 @@ +use hashbrown::hash_map::{HashMap, RawEntryBuilder, RawEntryBuilderMut}; + +use crate::aliases::PlRandomState; +use crate::hashing::hash_to_partition; + +pub struct PartitionedHashMap { + inner: Vec>, +} + +impl PartitionedHashMap { + pub fn new(inner: Vec>) -> Self { + Self { inner } + } + + #[inline(always)] + pub fn raw_entry_mut(&mut self, h: u64) -> RawEntryBuilderMut<'_, K, V, S> { + self.raw_entry_and_partition_mut(h).0 + } + + #[inline(always)] + pub fn raw_entry(&self, h: u64) -> RawEntryBuilder<'_, K, V, S> { + self.raw_entry_and_partition(h).0 + } + + #[inline] + pub fn raw_entry_and_partition(&self, h: u64) -> (RawEntryBuilder<'_, K, V, S>, usize) { + let partition = hash_to_partition(h, self.inner.len()); + let current_table = unsafe { self.inner.get_unchecked(partition) }; + (current_table.raw_entry(), partition) + } + + #[inline] + pub fn raw_entry_and_partition_mut( + &mut self, + h: u64, + ) -> (RawEntryBuilderMut<'_, K, V, S>, usize) { + let partition = hash_to_partition(h, self.inner.len()); + let current_table = unsafe { self.inner.get_unchecked_mut(partition) }; + (current_table.raw_entry_mut(), partition) + } + + pub fn inner(&self) -> &[HashMap] { + self.inner.as_ref() + } + + pub fn inner_mut(&mut self) -> &mut Vec> { + &mut self.inner + } +} diff --git a/crates/polars-utils/src/pl_serialize.rs b/crates/polars-utils/src/pl_serialize.rs new file mode 100644 index 000000000000..911c0dfe9461 --- /dev/null +++ b/crates/polars-utils/src/pl_serialize.rs @@ -0,0 +1,213 @@ +//! Centralized Polars serialization entry. +//! +//! Currently provides two serialization scheme's. +//! - Self-describing (and thus more forward compatible) activated with `FC: true` +//! - Compact activated with `FC: false` +use polars_error::{PolarsResult, to_compute_err}; + +fn serialize_impl(writer: W, value: &T) -> PolarsResult<()> +where + W: std::io::Write, + T: serde::ser::Serialize, +{ + if FC { + let mut s = rmp_serde::Serializer::new(writer).with_struct_map(); + value.serialize(&mut s).map_err(to_compute_err) + } else { + bincode::serialize_into(writer, value).map_err(to_compute_err) + } +} + +pub fn deserialize_impl(reader: R) -> PolarsResult +where + T: serde::de::DeserializeOwned, + R: std::io::Read, +{ + if FC { + rmp_serde::from_read(reader).map_err(to_compute_err) + } else { + bincode::deserialize_from(reader).map_err(to_compute_err) + } +} + +/// Mainly used to enable compression when serializing the final outer value. +/// For intermediate serialization steps, the function in the module should +/// be used instead. +pub struct SerializeOptions { + compression: bool, +} + +impl SerializeOptions { + pub fn with_compression(mut self, compression: bool) -> Self { + self.compression = compression; + self + } + + pub fn serialize_into_writer( + &self, + writer: W, + value: &T, + ) -> PolarsResult<()> + where + W: std::io::Write, + T: serde::ser::Serialize, + { + if self.compression { + let writer = flate2::write::ZlibEncoder::new(writer, flate2::Compression::fast()); + serialize_impl::<_, _, FC>(writer, value) + } else { + serialize_impl::<_, _, FC>(writer, value) + } + } + + pub fn deserialize_from_reader(&self, reader: R) -> PolarsResult + where + T: serde::de::DeserializeOwned, + R: std::io::Read, + { + if self.compression { + deserialize_impl::<_, _, FC>(flate2::read::ZlibDecoder::new(reader)) + } else { + deserialize_impl::<_, _, FC>(reader) + } + } + + pub fn serialize_to_bytes(&self, value: &T) -> PolarsResult> + where + T: serde::ser::Serialize, + { + let mut v = vec![]; + + self.serialize_into_writer::<_, _, FC>(&mut v, value)?; + + Ok(v) + } +} + +#[allow(clippy::derivable_impls)] +impl Default for SerializeOptions { + fn default() -> Self { + Self { compression: false } + } +} + +pub fn serialize_into_writer(writer: W, value: &T) -> PolarsResult<()> +where + W: std::io::Write, + T: serde::ser::Serialize, +{ + serialize_impl::<_, _, FC>(writer, value) +} + +pub fn deserialize_from_reader(reader: R) -> PolarsResult +where + T: serde::de::DeserializeOwned, + R: std::io::Read, +{ + deserialize_impl::<_, _, FC>(reader) +} + +pub fn serialize_to_bytes(value: &T) -> PolarsResult> +where + T: serde::ser::Serialize, +{ + let mut v = vec![]; + + serialize_into_writer::<_, _, FC>(&mut v, value)?; + + Ok(v) +} + +/// Potentially avoids copying memory compared to a naive `Vec::::deserialize`. +/// +/// This is essentially boilerplate for visiting bytes without copying where possible. +pub fn deserialize_map_bytes<'de, D, O>( + deserializer: D, + mut func: impl for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O, +) -> Result +where + D: serde::de::Deserializer<'de>, +{ + // Lets us avoid monomorphizing the visitor + let mut out: Option = None; + struct V<'f>(&'f mut (dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>))); + + deserializer.deserialize_bytes(V(&mut |v| drop(out.replace(func(v)))))?; + + return Ok(out.unwrap()); + + impl<'de> serde::de::Visitor<'de> for V<'_> { + type Value = (); + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("deserialize_map_bytes") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + self.0(std::borrow::Cow::Borrowed(v)); + Ok(()) + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: serde::de::Error, + { + self.0(std::borrow::Cow::Owned(v)); + Ok(()) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + // This is not ideal, but we hit here if the serialization format is JSON. + let bytes = std::iter::from_fn(|| seq.next_element::().transpose()) + .collect::, A::Error>>()?; + + self.0(std::borrow::Cow::Owned(bytes)); + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_serde_skip_enum() { + #[derive(Default, Debug, PartialEq)] + struct MyType(Option); + + // Note: serde(skip) must be at the end of enums + #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] + enum Enum { + A, + #[serde(skip)] + B(MyType), + } + + impl Default for Enum { + fn default() -> Self { + Self::B(MyType(None)) + } + } + + let v = Enum::A; + let b = super::serialize_to_bytes::<_, false>(&v).unwrap(); + let r: Enum = super::deserialize_from_reader::<_, _, false>(b.as_slice()).unwrap(); + + assert_eq!(r, v); + + let v = Enum::A; + let b = super::SerializeOptions::default() + .serialize_to_bytes::<_, false>(&v) + .unwrap(); + let r: Enum = super::SerializeOptions::default() + .deserialize_from_reader::<_, _, false>(b.as_slice()) + .unwrap(); + + assert_eq!(r, v); + } +} diff --git a/crates/polars-utils/src/pl_str.rs b/crates/polars-utils/src/pl_str.rs new file mode 100644 index 000000000000..51497a702b8e --- /dev/null +++ b/crates/polars-utils/src/pl_str.rs @@ -0,0 +1,276 @@ +use std::sync::atomic::{AtomicU64, Ordering}; + +#[macro_export] +macro_rules! format_pl_smallstr { + ($($arg:tt)*) => {{ + use std::fmt::Write; + + let mut string = $crate::pl_str::PlSmallStr::EMPTY; + write!(string, $($arg)*).unwrap(); + string + }} +} + +type Inner = compact_str::CompactString; + +/// String type that inlines small strings. +#[derive(Clone, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize), + serde(transparent) +)] +pub struct PlSmallStr(Inner); + +impl PlSmallStr { + pub const EMPTY: Self = Self::from_static(""); + pub const EMPTY_REF: &'static Self = &Self::from_static(""); + + #[inline(always)] + pub const fn from_static(s: &'static str) -> Self { + Self(Inner::const_new(s)) + } + + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn from_str(s: &str) -> Self { + Self(Inner::from(s)) + } + + #[inline(always)] + pub fn from_string(s: String) -> Self { + Self(Inner::from(s)) + } + + #[inline(always)] + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + #[inline(always)] + pub fn as_mut_str(&mut self) -> &mut str { + self.0.as_mut_str() + } + + #[inline(always)] + pub fn into_string(self) -> String { + self.0.into_string() + } +} + +impl Default for PlSmallStr { + #[inline(always)] + fn default() -> Self { + Self::EMPTY + } +} + +// AsRef, Deref and Borrow impls to &str + +impl AsRef for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl core::ops::Deref for PlSmallStr { + type Target = str; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl core::ops::DerefMut for PlSmallStr { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut_str() + } +} + +impl core::borrow::Borrow for PlSmallStr { + #[inline(always)] + fn borrow(&self) -> &str { + self.as_str() + } +} + +// AsRef impls for other types + +impl AsRef for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &std::path::Path { + self.as_str().as_ref() + } +} + +impl AsRef<[u8]> for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl AsRef for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &std::ffi::OsStr { + self.as_str().as_ref() + } +} + +// From impls + +impl From<&str> for PlSmallStr { + #[inline(always)] + fn from(value: &str) -> Self { + Self::from_str(value) + } +} + +impl From for PlSmallStr { + #[inline(always)] + fn from(value: String) -> Self { + Self::from_string(value) + } +} + +impl From<&String> for PlSmallStr { + #[inline(always)] + fn from(value: &String) -> Self { + Self::from_str(value.as_str()) + } +} + +impl From for PlSmallStr { + #[inline(always)] + fn from(value: Inner) -> Self { + Self(value) + } +} + +// FromIterator impls + +impl FromIterator for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: T) -> Self { + Self(Inner::from_iter(iter.into_iter().map(|x| x.0))) + } +} + +impl<'a> FromIterator<&'a PlSmallStr> for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: T) -> Self { + Self(Inner::from_iter(iter.into_iter().map(|x| x.as_str()))) + } +} + +impl FromIterator for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl<'a> FromIterator<&'a char> for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl<'a> FromIterator<&'a str> for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl FromIterator for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl FromIterator> for PlSmallStr { + #[inline(always)] + fn from_iter>>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl<'a> FromIterator> for PlSmallStr { + #[inline(always)] + fn from_iter>>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +// PartialEq impls + +impl PartialEq for PlSmallStr +where + T: AsRef + ?Sized, +{ + #[inline(always)] + fn eq(&self, other: &T) -> bool { + self.as_str() == other.as_ref() + } +} + +impl PartialEq for &str { + #[inline(always)] + fn eq(&self, other: &PlSmallStr) -> bool { + *self == other.as_str() + } +} + +impl PartialEq for String { + #[inline(always)] + fn eq(&self, other: &PlSmallStr) -> bool { + self.as_str() == other.as_str() + } +} + +// Write + +impl core::fmt::Write for PlSmallStr { + #[inline(always)] + fn write_char(&mut self, c: char) -> std::fmt::Result { + self.0.write_char(c) + } + + #[inline(always)] + fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::fmt::Result { + self.0.write_fmt(args) + } + + #[inline(always)] + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.0.write_str(s) + } +} + +// Debug, Display + +impl core::fmt::Debug for PlSmallStr { + #[inline(always)] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_str().fmt(f) + } +} + +impl core::fmt::Display for PlSmallStr { + #[inline(always)] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_str().fmt(f) + } +} + +pub fn unique_column_name() -> PlSmallStr { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let idx = COUNTER.fetch_add(1, Ordering::Relaxed); + format_pl_smallstr!("_POLARS_TMP_{idx}") +} diff --git a/crates/polars-utils/src/priority.rs b/crates/polars-utils/src/priority.rs new file mode 100644 index 000000000000..a249a25dcb32 --- /dev/null +++ b/crates/polars-utils/src/priority.rs @@ -0,0 +1,25 @@ +use std::cmp::Ordering; + +/// A pair which is ordered exclusively by the first element. +#[derive(Copy, Clone, Debug)] +pub struct Priority(pub P, pub T); + +impl Ord for Priority { + fn cmp(&self, other: &Self) -> Ordering { + self.0.cmp(&other.0) + } +} + +impl PartialOrd for Priority { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for Priority { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for Priority {} diff --git a/crates/polars-utils/src/python_convert_registry.rs b/crates/polars-utils/src/python_convert_registry.rs new file mode 100644 index 000000000000..2f0252817b0b --- /dev/null +++ b/crates/polars-utils/src/python_convert_registry.rs @@ -0,0 +1,34 @@ +use std::any::Any; +use std::ops::Deref; +use std::sync::{Arc, LazyLock, RwLock}; + +use pyo3::{Py, PyAny, PyResult}; + +pub type PythonToSinkTarget = Arc) -> PyResult> + Send + Sync>; + +#[derive(Clone)] +pub struct FromPythonConvertRegistry { + pub sink_target: PythonToSinkTarget, +} + +#[derive(Clone)] +pub struct PythonConvertRegistry { + pub from_py: FromPythonConvertRegistry, +} + +static PYTHON_CONVERT_REGISTRY: LazyLock>> = + LazyLock::new(Default::default); + +pub fn get_python_convert_registry() -> PythonConvertRegistry { + PYTHON_CONVERT_REGISTRY + .deref() + .read() + .unwrap() + .as_ref() + .unwrap() + .clone() +} + +pub fn register_converters(registry: PythonConvertRegistry) { + *PYTHON_CONVERT_REGISTRY.deref().write().unwrap() = Some(registry); +} diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs new file mode 100644 index 000000000000..bed42aaae9bd --- /dev/null +++ b/crates/polars-utils/src/python_function.rs @@ -0,0 +1,323 @@ +use polars_error::{PolarsError, polars_bail}; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedBytes; +use pyo3::types::PyBytes; +#[cfg(feature = "serde")] +pub use serde_wrap::{ + PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK, + TrySerializeToBytes, +}; + +/// Wrapper around PyObject from pyo3 with additional trait impls. +#[derive(Debug)] +pub struct PythonObject(pub PyObject); +// Note: We have this because the struct itself used to be called `PythonFunction`, so it's +// referred to as such from a lot of places. +pub type PythonFunction = PythonObject; + +impl std::ops::Deref for PythonObject { + type Target = PyObject; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for PythonObject { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Clone for PythonObject { + fn clone(&self) -> Self { + Python::with_gil(|py| Self(self.0.clone_ref(py))) + } +} + +impl From for PythonObject { + fn from(value: PyObject) -> Self { + Self(value) + } +} + +impl<'py> pyo3::conversion::IntoPyObject<'py> for PythonObject { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + Ok(self.0.into_bound(py)) + } +} + +impl<'py> pyo3::conversion::IntoPyObject<'py> for &PythonObject { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + Ok(self.0.bind(py).clone()) + } +} + +impl Eq for PythonObject {} + +impl PartialEq for PythonObject { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + let eq = self.0.getattr(py, "__eq__").unwrap(); + eq.call1(py, (other.0.clone_ref(py),)) + .unwrap() + .extract::(py) + // equality can be not implemented, so default to false + .unwrap_or(false) + }) + } +} + +#[cfg(feature = "serde")] +mod _serde_impls { + use super::{PySerializeWrap, PythonObject, TrySerializeToBytes, serde_wrap}; + use crate::pl_serialize::deserialize_map_bytes; + + impl PythonObject { + pub fn serialize_with_pyversion( + value: &T, + serializer: S, + ) -> std::result::Result + where + T: AsRef, + S: serde::ser::Serializer, + { + use serde::Serialize; + PySerializeWrap(value.as_ref()).serialize(serializer) + } + + pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result + where + T: From, + D: serde::de::Deserializer<'de>, + { + use serde::Deserialize; + let v: PySerializeWrap = PySerializeWrap::deserialize(d)?; + + Ok(v.0.into()) + } + } + + impl TrySerializeToBytes for PythonObject { + fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult> { + serde_wrap::serialize_pyobject_with_cloudpickle_fallback(&self.0) + } + + fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult { + serde_wrap::deserialize_pyobject_bytes_maybe_cloudpickle(bytes) + } + } + + impl serde::Serialize for PythonObject { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + let bytes = self + .try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))?; + + Vec::::serialize(&bytes, serializer) + } + } + + impl<'a> serde::Deserialize<'a> for PythonObject { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + deserialize_map_bytes(deserializer, |bytes| { + Self::try_deserialize_bytes(&bytes).map_err(|e| D::Error::custom(e.to_string())) + })? + } + } +} + +#[cfg(feature = "serde")] +mod serde_wrap { + use std::sync::LazyLock; + + use polars_error::PolarsResult; + + use super::*; + use crate::config; + use crate::pl_serialize::deserialize_map_bytes; + + pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes(); + /// [minor, micro] + pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version); + + /// Serializes a Python object without additional system metadata. This is intended to be used + /// together with `PySerializeWrap`, which attaches e.g. Python version metadata. + pub trait TrySerializeToBytes: Sized { + fn try_serialize_to_bytes(&self) -> PolarsResult>; + fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult; + } + + /// Serialization wrapper for T: TrySerializeToBytes that attaches Python + /// version metadata. + pub struct PySerializeWrap(pub T); + + impl serde::Serialize for PySerializeWrap<&T> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + let dumped = self + .0 + .try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))?; + + serializer.serialize_bytes( + &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()].concat(), + ) + } + } + + impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + + deserialize_map_bytes(deserializer, |bytes| { + let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject version", + )); + }; + + if magic != SERDE_MAGIC_BYTE_MARK { + return Err(D::Error::custom( + "serialized pyobject did not begin with magic byte mark", + )); + } + + let bytes = rem; + + let [a, b, rem @ ..] = bytes else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject metadata", + )); + }; + + let py3_version = [*a, *b]; + // The validity of cloudpickle is check later when called `try_deserialize`. + let used_cloud_pickle = rem.first(); + + // Cloudpickle uses bytecode to serialize, which is unstable between versions + // So we only allow strict python versions if cloudpickle is used. + if py3_version != *PYTHON3_VERSION && used_cloud_pickle == Some(&1) { + return Err(D::Error::custom(format!( + "python version that pyobject was serialized with {:?} \ + differs from system python version {:?}", + (3, py3_version[0], py3_version[1]), + (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]), + ))); + } + + let bytes = rem; + + T::try_deserialize_bytes(bytes) + .map(Self) + .map_err(|e| D::Error::custom(e.to_string())) + })? + } + } + + pub fn serialize_pyobject_with_cloudpickle_fallback( + py_object: &PyObject, + ) -> PolarsResult> { + Python::with_gil(|py| { + let pickle = PyModule::import(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("dumps") + .unwrap(); + + let dumped = pickle.call1((py_object.clone_ref(py),)); + + let (dumped, used_cloudpickle) = match dumped { + Ok(v) => (v, false), + Err(e) => { + if config::verbose() { + eprintln!( + "serialize_pyobject_with_cloudpickle_fallback(): \ + retrying with cloudpickle due to error: {:?}", + e + ); + } + + let cloudpickle = PyModule::import(py, "cloudpickle")? + .getattr("dumps") + .unwrap(); + let dumped = cloudpickle.call1((py_object.clone_ref(py),))?; + (dumped, true) + }, + }; + + let py_bytes = dumped.extract::()?; + + Ok([&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()].concat()) + }) + .map_err(from_pyerr) + } + + pub fn deserialize_pyobject_bytes_maybe_cloudpickle From>( + bytes: &[u8], + ) -> PolarsResult { + // TODO: Actually deserialize with cloudpickle if it's set. + let [used_cloudpickle @ 0 | used_cloudpickle @ 1, b'C', rem @ ..] = bytes else { + polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes") + }; + + let bytes = rem; + + Python::with_gil(|py| { + let p = if *used_cloudpickle == 1 { + "cloudpickle" + } else { + "pickle" + }; + + let pickle = PyModule::import(py, p) + .expect("unable to import 'pickle'") + .getattr("loads") + .unwrap(); + let arg = (PyBytes::new(py, bytes),); + let pyany_bound = pickle.call1(arg)?; + Ok(PyObject::from(pyany_bound).into()) + }) + .map_err(from_pyerr) + } +} + +/// Get the [minor, micro] Python3 version from the `sys` module. +fn get_python3_version() -> [u8; 2] { + Python::with_gil(|py| { + let version_info = PyModule::import(py, "sys") + .unwrap() + .getattr("version_info") + .unwrap(); + + [ + version_info.getattr("minor").unwrap().extract().unwrap(), + version_info.getattr("micro").unwrap().extract().unwrap(), + ] + }) +} + +fn from_pyerr(e: PyErr) -> PolarsError { + PolarsError::ComputeError(format!("error raised in python: {e}").into()) +} diff --git a/crates/polars-utils/src/regex_cache.rs b/crates/polars-utils/src/regex_cache.rs new file mode 100644 index 000000000000..187859c84736 --- /dev/null +++ b/crates/polars-utils/src/regex_cache.rs @@ -0,0 +1,55 @@ +use std::cell::RefCell; + +use regex::Regex; + +use crate::cache::LruCache; + +// Regex compilation is really heavy, and the resulting regexes can be large as +// well, so we should have a good caching scheme. +// +// TODO: add larger global cache which has time-based flush. + +/// A cache for compiled regular expressions. +pub struct RegexCache { + cache: LruCache, +} + +impl RegexCache { + fn new() -> Self { + Self { + cache: LruCache::with_capacity(32), + } + } + + pub fn compile(&mut self, re: &str) -> Result<&Regex, regex::Error> { + let r = self.cache.try_get_or_insert_with(re, |re| { + #[allow(clippy::disallowed_methods)] + Regex::new(re) + }); + Ok(&*r?) + } +} + +thread_local! { + static LOCAL_REGEX_CACHE: RefCell = RefCell::new(RegexCache::new()); +} + +pub fn compile_regex(re: &str) -> Result { + LOCAL_REGEX_CACHE.with_borrow_mut(|cache| cache.compile(re).cloned()) +} + +pub fn with_regex_cache R>(f: F) -> R { + LOCAL_REGEX_CACHE.with_borrow_mut(f) +} + +#[macro_export] +macro_rules! cached_regex { + () => {}; + + ($vis:vis static $name:ident = $regex:expr; $($rest:tt)*) => { + #[allow(clippy::disallowed_methods)] + $vis static $name: std::sync::LazyLock = std::sync::LazyLock::new(|| regex::Regex::new($regex).unwrap()); + $crate::regex_cache::cached_regex!($($rest)*); + }; +} +pub use cached_regex; diff --git a/crates/polars-utils/src/select.rs b/crates/polars-utils/src/select.rs new file mode 100644 index 000000000000..3dfceff59d56 --- /dev/null +++ b/crates/polars-utils/src/select.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "nightly")] +pub fn select_unpredictable(cond: bool, true_val: T, false_val: T) -> T { + core::hint::select_unpredictable(cond, true_val, false_val) +} + +#[cfg(not(feature = "nightly"))] +pub fn select_unpredictable(cond: bool, true_val: T, false_val: T) -> T { + if cond { true_val } else { false_val } +} diff --git a/crates/polars-utils/src/slice.rs b/crates/polars-utils/src/slice.rs new file mode 100644 index 000000000000..839107bc0cf9 --- /dev/null +++ b/crates/polars-utils/src/slice.rs @@ -0,0 +1,96 @@ +use std::cmp::Ordering; +use std::mem::MaybeUninit; +use std::ops::Range; + +pub trait SliceAble { + /// # Safety + /// no bound checks. + unsafe fn slice_unchecked(&self, range: Range) -> Self; + + fn slice(&self, range: Range) -> Self; +} + +impl SliceAble for &[T] { + unsafe fn slice_unchecked(&self, range: Range) -> Self { + unsafe { self.get_unchecked(range) } + } + + fn slice(&self, range: Range) -> Self { + self.get(range).unwrap() + } +} + +pub trait Extrema { + fn min_value(&self) -> Option<&T>; + fn max_value(&self) -> Option<&T>; +} + +impl Extrema for [T] { + fn min_value(&self) -> Option<&T> { + self.iter() + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + } + + fn max_value(&self) -> Option<&T> { + self.iter() + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + } +} + +pub trait SortedSlice { + fn is_sorted_ascending(&self) -> bool; +} + +impl SortedSlice for [T] { + fn is_sorted_ascending(&self) -> bool { + if self.is_empty() { + true + } else { + let mut previous = self[0]; + let mut sorted = true; + + // don't early stop or branch + // so it autovectorizes + for &v in &self[1..] { + sorted &= previous <= v; + previous = v; + } + sorted + } + } +} + +pub trait Slice2Uninit { + fn as_uninit(&self) -> &[MaybeUninit]; +} + +impl Slice2Uninit for [T] { + #[inline] + fn as_uninit(&self) -> &[MaybeUninit] { + unsafe { std::slice::from_raw_parts(self.as_ptr() as *const MaybeUninit, self.len()) } + } +} + +// Loads a u64 from the given byteslice, as if it were padded with zeros. +#[inline] +pub fn load_padded_le_u64(bytes: &[u8]) -> u64 { + let len = bytes.len(); + if len >= 8 { + return u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + } + + if len >= 4 { + let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); + let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap()); + return (lo as u64) | ((hi as u64) << (8 * (len - 4))); + } + + if len == 0 { + return 0; + } + + let lo = bytes[0] as u64; + let mid = (bytes[len / 2] as u64) << (8 * (len / 2)); + let hi = (bytes[len - 1] as u64) << (8 * (len - 1)); + lo | mid | hi +} diff --git a/crates/polars-utils/src/slice_enum.rs b/crates/polars-utils/src/slice_enum.rs new file mode 100644 index 000000000000..b6a241cc1081 --- /dev/null +++ b/crates/polars-utils/src/slice_enum.rs @@ -0,0 +1,191 @@ +use std::num::TryFromIntError; +use std::ops::Range; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum Slice { + /// Or zero + Positive { + offset: usize, + len: usize, + }, + Negative { + offset_from_end: usize, + len: usize, + }, +} + +impl Slice { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + match self { + Slice::Positive { len, .. } => *len, + Slice::Negative { len, .. } => *len, + } + } + + pub fn len_mut(&mut self) -> &mut usize { + match self { + Slice::Positive { len, .. } => len, + Slice::Negative { len, .. } => len, + } + } + + /// Returns the end position of the slice (offset + len). + /// + /// # Panics + /// Panics if self is negative. + pub fn end_position(&self) -> usize { + let Slice::Positive { offset, len } = self.clone() else { + panic!("cannot use end_position() on a negative slice"); + }; + + offset.saturating_add(len) + } + + /// Returns the equivalent slice to apply from an offsetted position. + /// + /// # Panics + /// Panics if self is negative. + pub fn offsetted(self, position: usize) -> Self { + let Slice::Positive { offset, len } = self else { + panic!("cannot use offsetted() on a negative slice"); + }; + + let (offset, len) = if position <= offset { + (offset - position, len) + } else { + let n_past_offset = position - offset; + (0, len.saturating_sub(n_past_offset)) + }; + + Slice::Positive { offset, len } + } + + /// Restricts the bounds of the slice to within a number of rows. Negative slices will also + /// be translated to the positive equivalent. + pub fn restrict_to_bounds(self, n_rows: usize) -> Self { + match self { + Slice::Positive { offset, len } => { + let offset = offset.min(n_rows); + let len = len.min(n_rows - offset); + Slice::Positive { offset, len } + }, + Slice::Negative { + offset_from_end, + len, + } => { + if n_rows >= offset_from_end { + // Trim extra starting rows + let offset = n_rows - offset_from_end; + let len = len.min(n_rows - offset); + Slice::Positive { offset, len } + } else { + // Slice offset goes past start of data. + let stop_at_n_from_end = offset_from_end.saturating_sub(len); + let len = n_rows.saturating_sub(stop_at_n_from_end); + + Slice::Positive { offset: 0, len } + } + }, + } + } +} + +impl From<(usize, usize)> for Slice { + fn from((offset, len): (usize, usize)) -> Self { + Slice::Positive { offset, len } + } +} + +impl From<(i64, usize)> for Slice { + fn from((offset, len): (i64, usize)) -> Self { + if offset >= 0 { + Slice::Positive { + offset: usize::try_from(offset).unwrap(), + len, + } + } else { + Slice::Negative { + offset_from_end: usize::try_from(-offset).unwrap(), + len, + } + } + } +} + +impl TryFrom for (i64, usize) { + type Error = TryFromIntError; + + fn try_from(value: Slice) -> Result { + match value { + Slice::Positive { offset, len } => Ok((i64::try_from(offset)?, len)), + Slice::Negative { + offset_from_end, + len, + } => Ok((-i64::try_from(offset_from_end)?, len)), + } + } +} + +impl From for Range { + fn from(value: Slice) -> Self { + match value { + Slice::Positive { offset, len } => offset..offset.checked_add(len).unwrap(), + Slice::Negative { .. } => panic!("cannot convert negative slice into range"), + } + } +} + +#[cfg(test)] +mod tests { + use super::Slice; + + #[test] + fn test_slice_offset() { + assert_eq!( + Slice::Positive { offset: 3, len: 10 }.offsetted(1), + Slice::Positive { offset: 2, len: 10 } + ); + assert_eq!( + Slice::Positive { offset: 3, len: 10 }.offsetted(5), + Slice::Positive { offset: 0, len: 8 } + ); + } + + #[test] + fn test_slice_restrict_to_bounds() { + assert_eq!( + Slice::Positive { offset: 3, len: 10 }.restrict_to_bounds(7), + Slice::Positive { offset: 3, len: 4 }, + ); + assert_eq!( + Slice::Positive { offset: 3, len: 10 }.restrict_to_bounds(0), + Slice::Positive { offset: 0, len: 0 }, + ); + assert_eq!( + Slice::Positive { offset: 3, len: 10 }.restrict_to_bounds(1), + Slice::Positive { offset: 1, len: 0 }, + ); + assert_eq!( + Slice::Positive { offset: 2, len: 0 }.restrict_to_bounds(10), + Slice::Positive { offset: 2, len: 0 }, + ); + assert_eq!( + Slice::Negative { + offset_from_end: 3, + len: 1 + } + .restrict_to_bounds(4), + Slice::Positive { offset: 1, len: 1 }, + ); + assert_eq!( + Slice::Negative { + offset_from_end: 3, + len: 1 + } + .restrict_to_bounds(1), + Slice::Positive { offset: 0, len: 0 }, + ); + } +} diff --git a/crates/polars-utils/src/sort.rs b/crates/polars-utils/src/sort.rs new file mode 100644 index 000000000000..54f6c65272ae --- /dev/null +++ b/crates/polars-utils/src/sort.rs @@ -0,0 +1,126 @@ +use std::mem::MaybeUninit; + +use num_traits::FromPrimitive; +use rayon::ThreadPool; +use rayon::prelude::*; + +use crate::IdxSize; +use crate::total_ord::TotalOrd; + +/// This is a perfect sort particularly useful for an arg_sort of an arg_sort +/// The second arg_sort sorts indices from `0` to `len` so can be just assigned to the +/// new index location. +/// +/// Besides that we know that all indices are unique and thus not alias so we can parallelize. +/// +/// This sort does not sort in place and will allocate. +/// +/// - The right indices are used for sorting +/// - The left indices are placed at the location right points to. +/// +/// # Safety +/// The caller must ensure that the right indexes for `&[(_, IdxSize)]` are integers ranging from `0..idx.len` +#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))] +pub unsafe fn perfect_sort(pool: &ThreadPool, idx: &[(IdxSize, IdxSize)], out: &mut Vec) { + let chunk_size = std::cmp::max( + idx.len() / pool.current_num_threads(), + pool.current_num_threads(), + ); + + out.reserve(idx.len()); + let ptr = out.as_mut_ptr() as *const IdxSize as usize; + + pool.install(|| { + idx.par_chunks(chunk_size).for_each(|indices| { + let ptr = ptr as *mut IdxSize; + for (idx_val, idx_location) in indices { + // SAFETY: + // idx_location is in bounds by invariant of this function + // and we ensured we have at least `idx.len()` capacity + unsafe { *ptr.add(*idx_location as usize) = *idx_val }; + } + }); + }); + // SAFETY: + // all elements are written + unsafe { out.set_len(idx.len()) }; +} + +// wasm alternative with different signature +#[cfg(all(not(target_os = "emscripten"), target_family = "wasm"))] +pub unsafe fn perfect_sort( + pool: &crate::wasm::Pool, + idx: &[(IdxSize, IdxSize)], + out: &mut Vec, +) { + let chunk_size = std::cmp::max( + idx.len() / pool.current_num_threads(), + pool.current_num_threads(), + ); + + out.reserve(idx.len()); + let ptr = out.as_mut_ptr() as *const IdxSize as usize; + + pool.install(|| { + idx.par_chunks(chunk_size).for_each(|indices| { + let ptr = ptr as *mut IdxSize; + for (idx_val, idx_location) in indices { + // SAFETY: + // idx_location is in bounds by invariant of this function + // and we ensured we have at least `idx.len()` capacity + *ptr.add(*idx_location as usize) = *idx_val; + } + }); + }); + // SAFETY: + // all elements are written + out.set_len(idx.len()); +} + +unsafe fn assume_init_mut(slice: &mut [MaybeUninit]) -> &mut [T] { + unsafe { &mut *(slice as *mut [MaybeUninit] as *mut [T]) } +} + +pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator>( + v: I, + scratch: &'a mut Vec, + n: usize, +) -> &'a mut [Idx] +where + Idx: FromPrimitive + Copy, +{ + // Needed to be able to write back to back in the same buffer. + debug_assert_eq!(align_of::(), align_of::<(T, Idx)>()); + let size = size_of::<(T, Idx)>(); + let upper_bound = size * n + size; + scratch.reserve(upper_bound); + let scratch_slice = unsafe { + let cap_slice = scratch.spare_capacity_mut(); + let (_, scratch_slice, _) = cap_slice.align_to_mut::>(); + &mut scratch_slice[..n] + }; + + for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) { + *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap())); + } + debug_assert_eq!(n, scratch_slice.len()); + + let scratch_slice = unsafe { assume_init_mut(scratch_slice) }; + scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0)); + + // now we write the indexes in the same array. + // So from to + unsafe { + let src = scratch_slice.as_ptr(); + + let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::(); + + let dst = scratch_slice_aligned_to_idx.as_mut_ptr(); + + for i in 0..n { + dst.add(i).write((*src.add(i)).1); + } + + &mut scratch_slice_aligned_to_idx[..n] + } +} diff --git a/crates/polars-utils/src/sparse_init_vec.rs b/crates/polars-utils/src/sparse_init_vec.rs new file mode 100644 index 000000000000..16de05bad2c7 --- /dev/null +++ b/crates/polars-utils/src/sparse_init_vec.rs @@ -0,0 +1,81 @@ +use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; + +pub struct SparseInitVec { + ptr: *mut T, + len: usize, + cap: usize, + + num_init: AtomicUsize, + init_mask: Vec, +} + +unsafe impl Send for SparseInitVec {} +unsafe impl Sync for SparseInitVec {} + +impl SparseInitVec { + pub fn with_capacity(len: usize) -> Self { + let init_mask = (0..len.div_ceil(8)).map(|_| AtomicU8::new(0)).collect(); + let mut storage = Vec::with_capacity(len); + let cap = storage.capacity(); + let ptr = storage.as_mut_ptr(); + core::mem::forget(storage); + Self { + len, + cap, + ptr, + num_init: AtomicUsize::new(0), + init_mask, + } + } + + pub fn try_set(&self, idx: usize, value: T) -> Result<(), T> { + unsafe { + if idx >= self.len { + return Err(value); + } + + // SAFETY: we use Relaxed orderings as we only ever read data back in methods that take + // self mutably or owned, already implying synchronization. + let init_mask_byte = self.init_mask.get_unchecked(idx / 8); + let bit_mask = 1 << (idx % 8); + if init_mask_byte.fetch_or(bit_mask, Ordering::Relaxed) & bit_mask != 0 { + return Err(value); + } + + self.ptr.add(idx).write(value); + self.num_init.fetch_add(1, Ordering::Relaxed); + } + + Ok(()) + } + + pub fn try_assume_init(mut self) -> Result, Self> { + unsafe { + if *self.num_init.get_mut() == self.len { + let ret = Vec::from_raw_parts(self.ptr, self.len, self.cap); + drop(core::mem::take(&mut self.init_mask)); + core::mem::forget(self); + Ok(ret) + } else { + Err(self) + } + } + } +} + +impl Drop for SparseInitVec { + fn drop(&mut self) { + unsafe { + // Make sure storage gets dropped even if element drop panics. + let _storage = Vec::from_raw_parts(self.ptr, 0, self.cap); + + for idx in 0..self.len { + let init_mask_byte = self.init_mask.get_unchecked_mut(idx / 8); + let bit_mask = 1 << (idx % 8); + if *init_mask_byte.get_mut() & bit_mask != 0 { + self.ptr.add(idx).drop_in_place(); + } + } + } + } +} diff --git a/crates/polars-utils/src/sync.rs b/crates/polars-utils/src/sync.rs new file mode 100644 index 000000000000..ab32995d6aae --- /dev/null +++ b/crates/polars-utils/src/sync.rs @@ -0,0 +1,53 @@ +/// Utility that allows use to send pointers to another thread. +/// This is better than going through `usize` as MIRI can follow these. +#[derive(Debug, PartialEq, Eq)] +#[repr(transparent)] +pub struct SyncPtr(*mut T); + +impl SyncPtr { + /// # Safety + /// + /// This will make a pointer sync and send. + /// Ensure that you don't break aliasing rules. + pub unsafe fn new(ptr: *mut T) -> Self { + Self(ptr) + } + + pub fn from_const(ptr: *const T) -> Self { + Self(ptr as *mut T) + } + + pub fn new_null() -> Self { + Self(std::ptr::null_mut()) + } + + #[inline(always)] + pub fn get(&self) -> *mut T { + self.0 + } + + pub fn is_null(&self) -> bool { + self.0.is_null() + } + + /// # Safety + /// Derefs a raw pointer, no guarantees whatsoever. + pub unsafe fn deref_unchecked(&self) -> &'static T { + unsafe { &*(self.0 as *const T) } + } +} + +impl Copy for SyncPtr {} +impl Clone for SyncPtr { + fn clone(&self) -> SyncPtr { + *self + } +} +unsafe impl Sync for SyncPtr {} +unsafe impl Send for SyncPtr {} + +impl From<*const T> for SyncPtr { + fn from(value: *const T) -> Self { + Self::from_const(value) + } +} diff --git a/crates/polars-utils/src/sys.rs b/crates/polars-utils/src/sys.rs new file mode 100644 index 000000000000..b2a4bee748bc --- /dev/null +++ b/crates/polars-utils/src/sys.rs @@ -0,0 +1,24 @@ +use std::sync::{LazyLock, Mutex}; + +use sysinfo::System; + +/// Startup system is expensive, so we do it once +pub struct MemInfo { + sys: Mutex, +} + +impl MemInfo { + /// This call is quite expensive, cache the results. + pub fn free(&self) -> u64 { + let mut sys = self.sys.lock().unwrap(); + sys.refresh_memory(); + match sys.cgroup_limits() { + Some(limits) => limits.free_memory, + None => sys.available_memory(), + } + } +} + +pub static MEMINFO: LazyLock = LazyLock::new(|| MemInfo { + sys: Mutex::new(System::new()), +}); diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs new file mode 100644 index 000000000000..8a63150ab387 --- /dev/null +++ b/crates/polars-utils/src/total_ord.rs @@ -0,0 +1,625 @@ +use std::cmp::Ordering; +use std::hash::{BuildHasher, Hash, Hasher}; + +use bytemuck::TransparentWrapper; + +use crate::hashing::{BytesHash, DirtyHash}; +use crate::nulls::IsNull; + +/// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +#[inline] +pub fn canonical_f32(x: f32) -> f32 { + // -0.0 + 0.0 becomes 0.0. + let convert_zero = x + 0.0; + if convert_zero.is_nan() { + f32::from_bits(0x7fc00000) // Canonical quiet NaN. + } else { + convert_zero + } +} + +/// Converts an f64 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +#[inline] +pub fn canonical_f64(x: f64) -> f64 { + // -0.0 + 0.0 becomes 0.0. + let convert_zero = x + 0.0; + if convert_zero.is_nan() { + f64::from_bits(0x7ff8000000000000) // Canonical quiet NaN. + } else { + convert_zero + } +} + +/// Alternative trait for Eq. By consistently using this we can still be +/// generic w.r.t Eq while getting a total ordering for floats. +pub trait TotalEq { + fn tot_eq(&self, other: &Self) -> bool; + + #[inline] + fn tot_ne(&self, other: &Self) -> bool { + !(self.tot_eq(other)) + } +} + +/// Alternative trait for Ord. By consistently using this we can still be +/// generic w.r.t Ord while getting a total ordering for floats. +pub trait TotalOrd: TotalEq { + fn tot_cmp(&self, other: &Self) -> Ordering; + + #[inline] + fn tot_lt(&self, other: &Self) -> bool { + self.tot_cmp(other) == Ordering::Less + } + + #[inline] + fn tot_gt(&self, other: &Self) -> bool { + self.tot_cmp(other) == Ordering::Greater + } + + #[inline] + fn tot_le(&self, other: &Self) -> bool { + self.tot_cmp(other) != Ordering::Greater + } + + #[inline] + fn tot_ge(&self, other: &Self) -> bool { + self.tot_cmp(other) != Ordering::Less + } +} + +/// Alternative trait for Hash. By consistently using this we can still be +/// generic w.r.t Hash while being able to hash floats. +pub trait TotalHash { + fn tot_hash(&self, state: &mut H) + where + H: Hasher; + + fn tot_hash_slice(data: &[Self], state: &mut H) + where + H: Hasher, + Self: Sized, + { + for piece in data { + piece.tot_hash(state) + } + } +} + +pub trait BuildHasherTotalExt: BuildHasher { + fn tot_hash_one(&self, x: T) -> u64 + where + T: TotalHash, + Self: Sized, + ::Hasher: Hasher, + { + let mut hasher = self.build_hasher(); + x.tot_hash(&mut hasher); + hasher.finish() + } +} + +impl BuildHasherTotalExt for T {} + +#[repr(transparent)] +pub struct TotalOrdWrap(pub T); +unsafe impl TransparentWrapper for TotalOrdWrap {} + +impl PartialOrd for TotalOrdWrap { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + + #[inline(always)] + fn lt(&self, other: &Self) -> bool { + self.0.tot_lt(&other.0) + } + + #[inline(always)] + fn le(&self, other: &Self) -> bool { + self.0.tot_le(&other.0) + } + + #[inline(always)] + fn gt(&self, other: &Self) -> bool { + self.0.tot_gt(&other.0) + } + + #[inline(always)] + fn ge(&self, other: &Self) -> bool { + self.0.tot_ge(&other.0) + } +} + +impl Ord for TotalOrdWrap { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + self.0.tot_cmp(&other.0) + } +} + +impl PartialEq for TotalOrdWrap { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.0.tot_eq(&other.0) + } + + #[inline(always)] + #[allow(clippy::partialeq_ne_impl)] + fn ne(&self, other: &Self) -> bool { + self.0.tot_ne(&other.0) + } +} + +impl Eq for TotalOrdWrap {} + +impl Hash for TotalOrdWrap { + #[inline(always)] + fn hash(&self, state: &mut H) { + self.0.tot_hash(state); + } +} + +impl Clone for TotalOrdWrap { + #[inline] + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Copy for TotalOrdWrap {} + +impl IsNull for TotalOrdWrap { + const HAS_NULLS: bool = T::HAS_NULLS; + type Inner = T::Inner; + + #[inline(always)] + fn is_null(&self) -> bool { + self.0.is_null() + } + + #[inline(always)] + fn unwrap_inner(self) -> Self::Inner { + self.0.unwrap_inner() + } +} + +impl DirtyHash for f32 { + #[inline(always)] + fn dirty_hash(&self) -> u64 { + canonical_f32(*self).to_bits().dirty_hash() + } +} + +impl DirtyHash for f64 { + #[inline(always)] + fn dirty_hash(&self) -> u64 { + canonical_f64(*self).to_bits().dirty_hash() + } +} + +impl DirtyHash for TotalOrdWrap { + #[inline(always)] + fn dirty_hash(&self) -> u64 { + self.0.dirty_hash() + } +} + +macro_rules! impl_trivial_total { + ($T: ty) => { + impl TotalEq for $T { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + self != other + } + } + + impl TotalOrd for $T { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + self < other + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + self > other + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + self <= other + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + self >= other + } + } + + impl TotalHash for $T { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } + } + }; +} + +// We can't do a blanket impl because Rust complains f32 might implement +// Ord / Eq someday. +impl_trivial_total!(bool); +impl_trivial_total!(u8); +impl_trivial_total!(u16); +impl_trivial_total!(u32); +impl_trivial_total!(u64); +impl_trivial_total!(u128); +impl_trivial_total!(usize); +impl_trivial_total!(i8); +impl_trivial_total!(i16); +impl_trivial_total!(i32); +impl_trivial_total!(i64); +impl_trivial_total!(i128); +impl_trivial_total!(isize); +impl_trivial_total!(char); +impl_trivial_total!(&str); +impl_trivial_total!(&[u8]); +impl_trivial_total!(String); + +macro_rules! impl_float_eq_ord { + ($T:ty) => { + impl TotalEq for $T { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + if self.is_nan() { + other.is_nan() + } else { + self == other + } + } + } + + impl TotalOrd for $T { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + if self.tot_lt(other) { + Ordering::Less + } else if self.tot_gt(other) { + Ordering::Greater + } else { + Ordering::Equal + } + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + !self.tot_ge(other) + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + other.tot_lt(self) + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + other.tot_ge(self) + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + // We consider all NaNs equal, and NaN is the largest possible + // value. Thus if self is NaN we always return true. Otherwise + // self >= other is correct. If other is not NaN it is trivially + // correct, and if it is we note that nothing can be greater or + // equal to NaN except NaN itself, which we already handled earlier. + self.is_nan() | (self >= other) + } + } + }; +} + +impl_float_eq_ord!(f32); +impl_float_eq_ord!(f64); + +impl TotalHash for f32 { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f32(*self).to_bits().hash(state) + } +} + +impl TotalHash for f64 { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f64(*self).to_bits().hash(state) + } +} + +// Blanket implementations. +impl TotalEq for Option { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + match (self, other) { + (None, None) => true, + (Some(a), Some(b)) => a.tot_eq(b), + _ => false, + } + } + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + match (self, other) { + (None, None) => false, + (Some(a), Some(b)) => a.tot_ne(b), + _ => true, + } + } +} + +impl TotalOrd for Option { + #[inline(always)] + fn tot_cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(a), Some(b)) => a.tot_cmp(b), + } + } + + #[inline(always)] + fn tot_lt(&self, other: &Self) -> bool { + match (self, other) { + (None, Some(_)) => true, + (Some(a), Some(b)) => a.tot_lt(b), + _ => false, + } + } + + #[inline(always)] + fn tot_gt(&self, other: &Self) -> bool { + other.tot_lt(self) + } + + #[inline(always)] + fn tot_le(&self, other: &Self) -> bool { + match (self, other) { + (Some(_), None) => false, + (Some(a), Some(b)) => a.tot_lt(b), + _ => true, + } + } + + #[inline(always)] + fn tot_ge(&self, other: &Self) -> bool { + other.tot_le(self) + } +} + +impl TotalHash for Option { + #[inline] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.is_some().tot_hash(state); + if let Some(slf) = self { + slf.tot_hash(state) + } + } +} + +impl TotalEq for &T { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + (*self).tot_eq(*other) + } + + #[inline(always)] + fn tot_ne(&self, other: &Self) -> bool { + (*self).tot_ne(*other) + } +} + +impl TotalHash for &T { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + (*self).tot_hash(state) + } +} + +impl TotalEq for (T, U) { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self.0.tot_eq(&other.0) && self.1.tot_eq(&other.1) + } +} + +impl TotalOrd for (T, U) { + #[inline] + fn tot_cmp(&self, other: &Self) -> Ordering { + self.0 + .tot_cmp(&other.0) + .then_with(|| self.1.tot_cmp(&other.1)) + } +} + +impl TotalHash for BytesHash<'_> { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state) + } +} + +impl TotalEq for BytesHash<'_> { + #[inline(always)] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +/// This elides creating a [`TotalOrdWrap`] for types that don't need it. +pub trait ToTotalOrd { + type TotalOrdItem: Hash + Eq; + type SourceItem; + + fn to_total_ord(&self) -> Self::TotalOrdItem; + + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem; +} + +macro_rules! impl_to_total_ord_identity { + ($T: ty) => { + impl ToTotalOrd for $T { + type TotalOrdItem = $T; + type SourceItem = $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + self.clone() + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } + } + }; +} + +impl_to_total_ord_identity!(bool); +impl_to_total_ord_identity!(u8); +impl_to_total_ord_identity!(u16); +impl_to_total_ord_identity!(u32); +impl_to_total_ord_identity!(u64); +impl_to_total_ord_identity!(u128); +impl_to_total_ord_identity!(usize); +impl_to_total_ord_identity!(i8); +impl_to_total_ord_identity!(i16); +impl_to_total_ord_identity!(i32); +impl_to_total_ord_identity!(i64); +impl_to_total_ord_identity!(i128); +impl_to_total_ord_identity!(isize); +impl_to_total_ord_identity!(char); +impl_to_total_ord_identity!(String); + +macro_rules! impl_to_total_ord_lifetimed_ref_identity { + ($T: ty) => { + impl<'a> ToTotalOrd for &'a $T { + type TotalOrdItem = &'a $T; + type SourceItem = &'a $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + *self + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } + } + }; +} + +impl_to_total_ord_lifetimed_ref_identity!(str); +impl_to_total_ord_lifetimed_ref_identity!([u8]); + +macro_rules! impl_to_total_ord_wrapped { + ($T: ty) => { + impl ToTotalOrd for $T { + type TotalOrdItem = TotalOrdWrap<$T>; + type SourceItem = $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(self.clone()) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } + } + }; +} + +impl_to_total_ord_wrapped!(f32); +impl_to_total_ord_wrapped!(f64); + +/// This is safe without needing to map the option value to TotalOrdWrap, since +/// for example: +/// `TotalOrdWrap>` implements `Eq + Hash`, iff: +/// `Option` implements `TotalEq + TotalHash`, iff: +/// `T` implements `TotalEq + TotalHash` +impl ToTotalOrd for Option { + type TotalOrdItem = TotalOrdWrap>; + type SourceItem = Option; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + +impl ToTotalOrd for &T { + type TotalOrdItem = T::TotalOrdItem; + type SourceItem = T::SourceItem; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + (*self).to_total_ord() + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + T::peel_total_ord(ord_item) + } +} + +impl<'a> ToTotalOrd for BytesHash<'a> { + type TotalOrdItem = BytesHash<'a>; + type SourceItem = BytesHash<'a>; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + *self + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } +} diff --git a/crates/polars-utils/src/vec.rs b/crates/polars-utils/src/vec.rs new file mode 100644 index 000000000000..f1e7df4554c4 --- /dev/null +++ b/crates/polars-utils/src/vec.rs @@ -0,0 +1,181 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::mem::MaybeUninit; + +use num_traits::Zero; + +pub trait IntoRawParts { + fn into_raw_parts(self) -> (*mut T, usize, usize); + + // doesn't take ownership + fn raw_parts(&self) -> (*mut T, usize, usize); +} + +impl IntoRawParts for Vec { + fn into_raw_parts(self) -> (*mut T, usize, usize) { + let mut me = std::mem::ManuallyDrop::new(self); + (me.as_mut_ptr(), me.len(), me.capacity()) + } + + fn raw_parts(&self) -> (*mut T, usize, usize) { + (self.as_ptr() as *mut T, self.len(), self.capacity()) + } +} + +/// Fill current allocation if > 0 +/// otherwise realloc +pub trait ResizeFaster { + fn fill_or_alloc(&mut self, new_len: usize, value: T); +} + +impl ResizeFaster for Vec { + fn fill_or_alloc(&mut self, new_len: usize, value: T) { + if self.capacity() == 0 { + // it is faster to allocate zeroed + // so if the capacity is 0, we alloc (value might be 0) + *self = vec![value; new_len] + } else { + // first clear then reserve so that the reserve doesn't have + // to memcpy in case it needs to realloc. + self.clear(); + self.reserve(new_len); + + // // init the uninit values + let spare = &mut self.spare_capacity_mut()[..new_len]; + let init_value = MaybeUninit::new(value); + spare.fill(init_value); + unsafe { self.set_len(new_len) } + } + } +} +pub trait PushUnchecked { + /// Will push an item and not check if there is enough capacity + /// + /// # Safety + /// Caller must ensure the array has enough capacity to hold `T`. + unsafe fn push_unchecked(&mut self, value: T); +} + +impl PushUnchecked for Vec { + #[inline] + unsafe fn push_unchecked(&mut self, value: T) { + debug_assert!(self.capacity() > self.len()); + let end = self.as_mut_ptr().add(self.len()); + std::ptr::write(end, value); + self.set_len(self.len() + 1); + } +} + +pub trait CapacityByFactor { + fn with_capacity_by_factor(original_len: usize, factor: f64) -> Self; +} + +impl CapacityByFactor for Vec { + fn with_capacity_by_factor(original_len: usize, factor: f64) -> Self { + let cap = (original_len as f64 * factor) as usize; + Vec::with_capacity(cap) + } +} + +// Trait to convert a Vec. +// The reason for this is to reduce code-generation. Conversion functions that are named +// functions should only generate the conversion loop once. +pub trait ConvertVec { + type ItemIn; + + fn convert_owned Out>(self, f: F) -> Vec; + + fn convert Out>(&self, f: F) -> Vec; +} + +impl ConvertVec for Vec { + type ItemIn = T; + + fn convert_owned Out>(self, f: F) -> Vec { + self.into_iter().map(f).collect() + } + + fn convert Out>(&self, f: F) -> Vec { + self.iter().map(f).collect() + } +} + +/// Perform an in-place `Iterator::filter_map` over two vectors at the same time. +pub fn inplace_zip_filtermap( + x: &mut Vec, + y: &mut Vec, + mut f: impl FnMut(T, U) -> Option<(T, U)>, +) { + assert_eq!(x.len(), y.len()); + + let length = x.len(); + + struct OwnedBuffer { + end: *mut T, + length: usize, + } + + impl Drop for OwnedBuffer { + fn drop(&mut self) { + for i in 0..self.length { + unsafe { self.end.wrapping_sub(i + 1).read() }; + } + } + } + + let x_ptr = x.as_mut_ptr(); + let y_ptr = y.as_mut_ptr(); + + let mut x_buf = OwnedBuffer { + end: x_ptr.wrapping_add(length), + length, + }; + let mut y_buf = OwnedBuffer { + end: y_ptr.wrapping_add(length), + length, + }; + + // SAFETY: All items are now owned by `x_buf` and `y_buf`. Since we know that `x_buf` and + // `y_buf` will be dropped before the vecs representing `x` and `y`, this is safe. + unsafe { + x.set_len(0); + y.set_len(0); + } + + // SAFETY: + // + // We know we have a exclusive reference to x and y. + // + // We know that `i` is always smaller than `x.len()` and `y.len()`. Furthermore, we also know + // that `i - num_deleted > 0`. + // + // Items are dropped exactly once, even if `f` panics. + for i in 0..length { + let xi = unsafe { x_ptr.wrapping_add(i).read() }; + let yi = unsafe { y_ptr.wrapping_add(i).read() }; + + x_buf.length -= 1; + y_buf.length -= 1; + + // We hold the invariant here that all items that are not yet deleted are either in + // - `xi` or `yi` + // - `x_buf` or `y_buf` + // ` `x` or `y` + // + // This way if `f` ever panics, we are sure that all items are dropped exactly once. + // Deleted items will be dropped when they are deleted. + let result = f(xi, yi); + + if let Some((xi, yi)) = result { + x.push(xi); + y.push(yi); + } + } + + debug_assert_eq!(x_buf.length, 0); + debug_assert_eq!(y_buf.length, 0); + + // We are safe to forget `x_buf` and `y_buf` here since they will not deallocate anything + // anymore. + std::mem::forget(x_buf); + std::mem::forget(y_buf); +} diff --git a/crates/polars-utils/src/wasm.rs b/crates/polars-utils/src/wasm.rs new file mode 100644 index 000000000000..9325d11a8bc5 --- /dev/null +++ b/crates/polars-utils/src/wasm.rs @@ -0,0 +1,48 @@ +pub struct Pool; + +impl Pool { + pub fn current_num_threads(&self) -> usize { + rayon::current_num_threads() + } + + pub fn current_thread_index(&self) -> Option { + rayon::current_thread_index() + } + + pub fn current_thread_has_pending_tasks(&self) -> Option { + None + } + + pub fn install(&self, op: OP) -> R + where + OP: FnOnce() -> R + Send, + R: Send, + { + op() + } + + pub fn join(&self, oper_a: A, oper_b: B) -> (RA, RB) + where + A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + RA: Send, + RB: Send, + { + rayon::join(oper_a, oper_b) + } + + pub fn spawn(&self, func: F) + where + F: 'static + FnOnce() + Send, + { + rayon::spawn(func); + } + + pub fn scope<'scope, OP, R>(&self, op: OP) -> R + where + OP: FnOnce(&rayon::Scope<'scope>) -> R + Send, + R: Send, + { + rayon::scope(op) + } +} diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml new file mode 100644 index 000000000000..1bb328bb5e78 --- /dev/null +++ b/crates/polars/Cargo.toml @@ -0,0 +1,452 @@ +[package] +name = "polars" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +keywords = ["dataframe", "query-engine", "arrow"] +license = { workspace = true } +readme = "../../README.md" +repository = { workspace = true } +description = "DataFrame library based on Apache Arrow" + +[dependencies] +arrow = { workspace = true } +polars-core = { workspace = true, features = ["algorithm_group_by"] } +polars-error = { workspace = true } +polars-io = { workspace = true, optional = true } +polars-lazy = { workspace = true, optional = true } +polars-ops = { workspace = true, optional = true } +polars-parquet = { workspace = true } +polars-plan = { workspace = true, optional = true } +polars-sql = { workspace = true, optional = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } + +[dev-dependencies] +apache-avro = { version = "0.17", features = ["snappy"] } +arrow = { workspace = true } +avro-schema = { workspace = true, features = ["async"] } +chrono = { workspace = true } +either = { workspace = true } +ethnum = "1" +futures = { workspace = true } +# used to run formal property testing +proptest = { version = "1", default-features = false, features = ["std"] } +rand = { workspace = true } +# used to test async readers +tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { workspace = true, features = ["compat"] } + +[build-dependencies] +version_check = { workspace = true } + +# enable js feature for getrandom to work in wasm +[target.'cfg(target_family = "wasm")'.dependencies] +getrandom = { version = "0.2", features = ["js"] } + +[features] +sql = ["polars-sql"] +rows = ["polars-core/rows"] +simd = ["polars-core/simd", "polars-io/simd", "polars-ops?/simd"] +avx512 = ["polars-core/avx512"] +nightly = ["polars-core/nightly", "polars-ops?/nightly", "simd", "polars-lazy?/nightly", "polars-sql?/nightly"] +docs = ["polars-core/docs"] +temporal = ["polars-core/temporal", "polars-lazy?/temporal", "polars-io/temporal", "polars-time"] +random = ["polars-core/random", "polars-lazy?/random", "polars-ops/random"] +default = [ + "docs", + "zip_with", + "csv", + "temporal", + "fmt", + "dtype-slim", +] +ndarray = ["polars-core/ndarray"] +# serde support for dataframes and series +serde = ["polars-core/serde", "polars-utils/serde", "ir_serde"] +serde-lazy = [ + "polars-core/serde-lazy", + "polars-lazy?/serde", + "polars-time?/serde", + "polars-io?/serde", + "polars-ops?/serde", + "polars-utils/serde", +] +parquet = ["polars-io", "polars-lazy?/parquet", "polars-io/parquet", "polars-sql?/parquet"] +async = ["polars-lazy?/async"] +cloud = ["polars-lazy?/cloud", "polars-io/cloud"] +aws = ["async", "cloud", "polars-io/aws"] +http = ["async", "cloud", "polars-io/http"] +azure = ["async", "cloud", "polars-io/azure"] +gcp = ["async", "cloud", "polars-io/gcp"] +lazy = ["polars-core/lazy", "polars-lazy"] +# commented out until UB is fixed +# parallel = ["polars-core/parallel"] + +# extra utilities for StringChunked +strings = ["polars-core/strings", "polars-lazy?/strings", "polars-ops/strings"] + +# support for ObjectChunked (downcastable Series of any type) +object = ["polars-core/object", "polars-lazy?/object", "polars-io/object"] + +# support for arrows json parsing +json = ["polars-io", "polars-io/json", "polars-lazy?/json", "polars-sql?/json", "dtype-struct"] + +# support for arrows ipc file parsing +ipc = ["polars-io", "polars-io/ipc", "polars-lazy?/ipc", "polars-sql?/ipc"] + +# support for arrows streaming ipc file parsing +ipc_streaming = ["polars-io", "polars-io/ipc_streaming", "polars-lazy?/ipc"] + +# support for apache avro file parsing +avro = ["polars-io", "polars-io/avro"] + +# support for arrows csv file parsing +csv = ["polars-io", "polars-io/csv", "polars-lazy?/csv", "polars-sql?/csv"] + +# slower builds +performant = [ + "polars-core/performant", + "chunked_ids", + "dtype-u8", + "dtype-u16", + "dtype-struct", + "cse", + "polars-ops/performant", + "streaming", + "fused", +] + +# Dataframe formatting. +fmt = ["polars-core/fmt"] +fmt_no_tty = ["polars-core/fmt_no_tty"] + +# extra operations +abs = ["polars-ops/abs", "polars-lazy?/abs"] +approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique", "polars-core/approx_unique"] +arg_where = ["polars-lazy?/arg_where"] +array_any_all = ["polars-lazy?/array_any_all", "dtype-array"] +asof_join = ["polars-lazy?/asof_join", "polars-ops/asof_join"] +iejoin = ["polars-lazy?/iejoin"] +binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"] +bitwise = [ + "polars-core/bitwise", + "polars-plan?/bitwise", + "polars-ops/bitwise", + "polars-lazy?/bitwise", + "polars-sql?/bitwise", +] +business = ["polars-lazy?/business", "polars-ops/business"] +checked_arithmetic = ["polars-core/checked_arithmetic"] +chunked_ids = ["polars-ops?/chunked_ids"] +coalesce = ["polars-lazy?/coalesce"] +concat_str = ["polars-lazy?/concat_str"] +cov = ["polars-lazy/cov"] +cross_join = ["polars-lazy?/cross_join", "polars-ops/cross_join"] +cse = ["polars-lazy?/cse"] +cum_agg = ["polars-ops/cum_agg", "polars-lazy?/cum_agg"] +cumulative_eval = ["polars-lazy?/cumulative_eval"] +cutqcut = ["polars-lazy?/cutqcut"] +dataframe_arithmetic = ["polars-core/dataframe_arithmetic"] +month_start = ["polars-lazy?/month_start"] +month_end = ["polars-lazy?/month_end"] +offset_by = ["polars-lazy?/offset_by"] +decompress = ["polars-io/decompress"] +describe = ["polars-core/describe"] +diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy?/diagonal_concat", "polars-sql?/diagonal_concat"] +diff = ["polars-ops/diff", "polars-lazy?/diff"] +dot_diagram = ["polars-lazy?/dot_diagram"] +dot_product = ["polars-core/dot_product"] +dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] +ewma = ["polars-ops/ewma", "polars-lazy?/ewma"] +ewma_by = ["polars-ops/ewma_by", "polars-lazy?/ewma_by"] +extract_groups = ["polars-lazy?/extract_groups"] +extract_jsonpath = [ + "polars-core/strings", + "polars-ops/extract_jsonpath", + "polars-ops/strings", + "polars-lazy?/extract_jsonpath", +] +find_many = ["polars-plan/find_many"] +fused = ["polars-ops/fused", "polars-lazy?/fused"] +interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] +interpolate_by = ["polars-ops/interpolate_by", "polars-lazy?/interpolate_by"] +is_between = ["polars-lazy?/is_between", "polars-ops/is_between"] +is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"] +is_in = ["polars-lazy?/is_in"] +is_last_distinct = ["polars-lazy?/is_last_distinct", "polars-ops/is_last_distinct"] +is_unique = ["polars-lazy?/is_unique", "polars-ops/is_unique"] +regex = ["polars-lazy?/regex"] +list_any_all = ["polars-lazy?/list_any_all"] +list_count = ["polars-ops/list_count", "polars-lazy?/list_count"] +array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"] +list_drop_nulls = ["polars-lazy?/list_drop_nulls"] +list_eval = ["polars-lazy?/list_eval", "polars-sql?/list_eval"] +list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"] +list_sample = ["polars-lazy?/list_sample"] +list_sets = ["polars-lazy?/list_sets"] +list_to_struct = ["polars-ops/list_to_struct", "polars-lazy?/list_to_struct"] +list_arithmetic = ["polars-core/list_arithmetic"] +array_arithmetic = ["polars-core/array_arithmetic", "dtype-array"] +array_to_struct = ["polars-ops/array_to_struct", "polars-lazy?/array_to_struct"] +log = ["polars-ops/log", "polars-lazy?/log"] +merge_sorted = ["polars-lazy?/merge_sorted"] +meta = ["polars-lazy?/meta"] +mode = ["polars-ops/mode", "polars-lazy?/mode"] +moment = ["polars-ops/moment", "polars-lazy?/moment"] +partition_by = ["polars-core/partition_by"] +pct_change = ["polars-ops/pct_change", "polars-lazy?/pct_change"] +peaks = ["polars-lazy/peaks"] +pivot = ["polars-lazy?/pivot", "polars-ops/pivot", "dtype-struct", "rows"] +product = ["polars-core/product"] +propagate_nans = ["polars-lazy?/propagate_nans"] +range = ["polars-lazy?/range"] +rank = ["polars-lazy?/rank", "polars-ops/rank"] +reinterpret = ["polars-core/reinterpret", "polars-lazy?/reinterpret", "polars-ops/reinterpret"] +repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"] +replace = ["polars-ops/replace", "polars-lazy?/replace"] +rle = ["polars-lazy?/rle"] +rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by", "polars-lazy?/rolling_window_by", "polars-time/rolling_window_by"] +round_series = ["polars-ops/round_series", "polars-lazy?/round_series"] +row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"] +index_of = ["polars-lazy?/index_of"] +search_sorted = ["polars-lazy?/search_sorted"] +semi_anti_join = ["polars-lazy?/semi_anti_join", "polars-ops/semi_anti_join", "polars-sql?/semi_anti_join"] +sign = ["polars-lazy?/sign"] +streaming = ["polars-lazy?/streaming"] +string_encoding = ["polars-ops/string_encoding", "polars-lazy?/string_encoding", "polars-core/strings"] +string_pad = ["polars-lazy?/string_pad", "polars-ops/string_pad"] +string_normalize = ["polars-lazy?/string_normalize", "polars-ops/string_normalize"] +string_reverse = ["polars-lazy?/string_reverse", "polars-ops/string_reverse"] +string_to_integer = ["polars-lazy?/string_to_integer", "polars-ops/string_to_integer"] +take_opt_iter = ["polars-core/take_opt_iter"] +timezones = [ + "polars-core/timezones", + "polars-lazy?/timezones", + "polars-io/timezones", + "polars-ops/timezones", + "polars-sql?/timezones", +] +to_dummies = ["polars-ops/to_dummies"] +top_k = ["polars-lazy?/top_k"] +trigonometry = ["polars-lazy?/trigonometry"] +true_div = ["polars-lazy?/true_div"] +unique_counts = ["polars-ops/unique_counts", "polars-lazy?/unique_counts"] +zip_with = ["polars-core/zip_with"] + +bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx", "polars-utils/bigidx"] +polars_cloud = ["polars-lazy?/polars_cloud"] +ir_serde = ["polars-plan/ir_serde"] + +test = [ + "lazy", + "rolling_window", + "rank", + "round_series", + "csv", + "dtype-categorical", + "cum_agg", + "fmt", + "diff", + "abs", + "parquet", + "ipc", + "ipc_streaming", + "json", +] + +# all opt-in datatypes +dtype-full = [ + "dtype-date", + "dtype-datetime", + "dtype-duration", + "dtype-time", + "dtype-array", + "dtype-i8", + "dtype-i16", + "dtype-i128", + "dtype-decimal", + "dtype-u8", + "dtype-u16", + "dtype-categorical", + "dtype-struct", +] + +# sensible minimal set of opt-in datatypes +dtype-slim = [ + "dtype-date", + "dtype-datetime", + "dtype-duration", +] + +# opt-in datatypes for Series +dtype-date = [ + "polars-core/dtype-date", + "polars-io/dtype-date", + "polars-lazy?/dtype-date", + "polars-time?/dtype-date", + "polars-ops/dtype-date", +] +dtype-datetime = [ + "polars-core/dtype-datetime", + "polars-io/dtype-datetime", + "polars-lazy?/dtype-datetime", + "polars-time?/dtype-datetime", + "polars-ops/dtype-datetime", +] +dtype-duration = [ + "polars-core/dtype-duration", + "polars-io/dtype-duration", + "polars-lazy?/dtype-duration", + "polars-time?/dtype-duration", + "polars-ops/dtype-duration", +] +dtype-time = [ + "polars-core/dtype-time", + "polars-io/dtype-time", + "polars-lazy?/dtype-time", + "polars-time?/dtype-time", + "polars-ops/dtype-time", +] +dtype-array = [ + "polars-core/dtype-array", + "polars-lazy?/dtype-array", + "polars-ops/dtype-array", + "polars-plan?/dtype-array", +] +dtype-i8 = [ + "polars-core/dtype-i8", + "polars-io/dtype-i8", + "polars-lazy?/dtype-i8", + "polars-ops/dtype-i8", + "polars-time?/dtype-i8", +] +dtype-i16 = [ + "polars-core/dtype-i16", + "polars-io/dtype-i16", + "polars-lazy?/dtype-i16", + "polars-ops/dtype-i16", + "polars-time?/dtype-i16", +] +dtype-i128 = [ + "polars-core/dtype-i128", + "polars-io/dtype-i128", + "polars-lazy?/dtype-i128", + "polars-ops/dtype-i128", + "polars-time?/dtype-i128", +] +dtype-decimal = [ + "polars-core/dtype-decimal", + "polars-io/dtype-decimal", + "polars-lazy?/dtype-decimal", + "polars-sql?/dtype-decimal", + "polars-ops/dtype-decimal", +] +dtype-u8 = [ + "polars-core/dtype-u8", + "polars-io/dtype-u8", + "polars-lazy?/dtype-u8", + "polars-ops/dtype-u8", + "polars-time?/dtype-u8", +] +dtype-u16 = [ + "polars-core/dtype-u16", + "polars-io/dtype-u16", + "polars-lazy?/dtype-u16", + "polars-ops/dtype-u16", + "polars-time?/dtype-u16", +] +dtype-categorical = [ + "polars-core/dtype-categorical", + "polars-io/dtype-categorical", + "polars-lazy?/dtype-categorical", + "polars-ops/dtype-categorical", +] +dtype-struct = [ + "polars-core/dtype-struct", + "polars-io/dtype-struct", + "polars-lazy?/dtype-struct", + "polars-ops/dtype-struct", +] +hist = ["polars-ops/hist", "polars-lazy/hist"] + +docs-selection = [ + "csv", + "json", + "parquet", + "ipc", + "ipc_streaming", + "dtype-full", + "is_in", + "rows", + "docs", + "strings", + "object", + "lazy", + "temporal", + "random", + "zip_with", + "round_series", + "checked_arithmetic", + "ndarray", + "repeat_by", + "is_between", + "is_first_distinct", + "is_last_distinct", + "asof_join", + "cross_join", + "semi_anti_join", + "iejoin", + "concat_str", + "string_reverse", + "string_to_integer", + "decompress", + "mode", + "take_opt_iter", + "cum_agg", + "rolling_window", + "rolling_window_by", + "interpolate", + "interpolate_by", + "diff", + "rank", + "range", + "diagonal_concat", + "abs", + "dot_diagram", + "string_encoding", + "product", + "to_dummies", + "describe", + "list_eval", + "cumulative_eval", + "timezones", + "arg_where", + "propagate_nans", + "coalesce", + "dynamic_group_by", + "extract_groups", + "replace", + "approx_unique", + "unique_counts", + "polars_cloud", + "serde", + "ir_serde", + "cloud", + "async", +] + +bench = [ + "lazy", +] + +# All features expect python +full = ["docs-selection", "performant", "fmt"] + +[package.metadata.docs.rs] +# all-features = true +features = ["docs-selection"] +# defines the configuration attribute `docsrs` +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars/LICENSE b/crates/polars/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars/build.rs b/crates/polars/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars/src/docs/eager.rs b/crates/polars/src/docs/eager.rs new file mode 100644 index 000000000000..f15a2d9f1a15 --- /dev/null +++ b/crates/polars/src/docs/eager.rs @@ -0,0 +1,730 @@ +//! +//! # Polars Eager cookbook +//! +//! This page should serve as a cookbook to quickly get you started with most fundamental operations +//! executed on a [`ChunkedArray`], [`Series`] or [`DataFrame`]. +//! +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray +//! [`Series`]: crate::series::Series +//! [`DataFrame`]: crate::frame::DataFrame +//! +//! ## Tree Of Contents +//! +//! * [Creation of data structures](#creation-of-data-structures) +//! - [ChunkedArray](#chunkedarray) +//! - [Series](#series) +//! - [DataFrame](#dataframe) +//! * [Arithmetic](#arithmetic) +//! * [Comparisons](#comparisons) +//! * [Apply functions/ closures](#apply-functions-closures) +//! - [Series / ChunkedArrays](#dataframe-1) +//! - [DataFrame](#dataframe-1) +//! * [Filter](#filter) +//! * [Sort](#sort) +//! * [Joins](#joins) +//! * [GroupBy](#group_by) +//! * [pivot](#pivot) +//! * [Unpivot](#unpivot) +//! * [Explode](#explode) +//! * [IO](#io) +//! - [Read CSV](#read-csv) +//! - [Write CSV](#write-csv) +//! - [Read IPC](#read-ipc) +//! - [Write IPC](#write-ipc) +//! - [Read Parquet](#read-parquet) +//! - [Write Parquet](#write-parquet) +//! * [Various](#various) +//! - [Replace NaN with Missing](#replace-nan-with-missing) +//! - [Extracting data](#extracting-data) +//! +//! ## Creation of data structures +//! +//! ### ChunkedArray +//! +//! ``` +//! use polars::prelude::*; +//! +//! // use iterators +//! let ca: UInt32Chunked = (0..10).map(Some).collect(); +//! +//! // from slices +//! let ca = UInt32Chunked::new("foo".into(), &[1, 2, 3]); +//! +//! // use builders +//! let mut builder = PrimitiveChunkedBuilder::::new("foo".into(), 10); +//! for value in 0..10 { +//! builder.append_value(value); +//! } +//! let ca = builder.finish(); +//! ``` +//! +//! ### Series +//! +//! ``` +//! use polars::prelude::*; +//! +//! // use iterators +//! let s: Series = (0..10).map(Some).collect(); +//! +//! // from slices +//! let s = Series::new("foo".into(), &[1, 2, 3]); +//! +//! // from a chunked-array +//! let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(3)]); +//! let s = ca.into_series(); +//! +//! // into a Column +//! let s = s.into_column(); +//! ``` +//! +//! ### DataFrame +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! # fn example() -> PolarsResult<()> { +//! +//! // use macro +//! let df = df! [ +//! "names" => ["a", "b", "c"], +//! "values" => [1, 2, 3], +//! "values_nulls" => [Some(1), None, Some(3)] +//! ]?; +//! +//! // from a Vec +//! let c1 = Column::new("names".into(), &["a", "b", "c"]); +//! let c2 = Column::new("values".into(), &[Some(1), None, Some(3)]); +//! let df = DataFrame::new(vec![c1, c2])?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Arithmetic +//! Arithmetic can be done on both [`Series`] and [`ChunkedArray`]. The most notable difference is that +//! a [`Series`] coerces the data to match the underlying data types. +//! +//! ``` +//! use polars::prelude::*; +//! # fn example() -> PolarsResult<()> { +//! let s_int = Series::new("a".into(), &[1, 2, 3]); +//! let s_flt = Series::new("b".into(), &[1.0, 2.0, 3.0]); +//! +//! let added = &s_int + &s_flt; +//! let subtracted = &s_int - &s_flt; +//! let multiplied = &s_int * &s_flt; +//! let divided = &s_int / &s_flt; +//! let moduloed = &s_int % &s_flt; +//! +//! +//! // on chunked-arrays we first need to cast to same types +//! let ca_int = s_int.i32()?; +//! let ca_flt = s_flt.f32()?; +//! +//! ca_int.cast(&DataType::Float32)?.f32()? * ca_flt; +//! ca_flt.cast(&DataType::Int32)?.i32()? * ca_int; +//! +//! // we can also do arithmetic with numeric values +//! let multiplied = ca_int * 2.0; +//! let multiplied = s_flt * 2.0; +//! +//! // or broadcast Series to match the operands type +//! let added = &s_int * &Series::new("broadcast_me".into(), &[10]); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! Because Rust's Orphan Rule doesn't allow us to implement left side operations, we need to call +//! such operations directly. +//! +//! ```rust +//! # use polars::prelude::*; +//! let series = Series::new("foo".into(), [1, 2, 3]); +//! +//! // 1 / s +//! let divide_one_by_s = 1.div(&series); +//! +//! // 1 - s +//! let subtract_one_by_s = 1.sub(&series); +//! ``` +//! +//! For [`ChunkedArray`] left hand side operations can be done with the [`apply_values`] method. +//! +//! [`apply_values`]: crate::chunked_array::ops::ChunkApply::apply_values +//! +//! ```rust +//! # use polars::prelude::*; +//! let ca = UInt32Chunked::new("foo".into(), &[1, 2, 3]); +//! +//! // 1 / ca +//! let divide_one_by_ca = ca.apply_values(|rhs| 1 / rhs); +//! ``` +//! +//! ## Comparisons +//! +//! [`Series`] and [`ChunkedArray`] can be used in comparison operations to create _boolean_ masks/predicates. +//! +//! ``` +//! use polars::prelude::*; +//! # fn example() -> PolarsResult<()> { +//! +//! let s = Series::new("a".into(), &[1, 2, 3]); +//! let ca = UInt32Chunked::new("b".into(), &[Some(3), None, Some(1)]); +//! +//! // compare Series with numeric values +//! // == +//! s.equal(2); +//! // != +//! s.not_equal(2); +//! // > +//! s.gt(2); +//! // >= +//! s.gt_eq(2); +//! // < +//! s.lt(2); +//! // <= +//! s.lt_eq(2); +//! +//! +//! // compare Series with Series +//! // == +//! s.equal(&s); +//! // != +//! s.not_equal(&s); +//! // > +//! s.gt(&s); +//! // >= +//! s.gt_eq(&s); +//! // < +//! s.lt(&s); +//! // <= +//! s.lt_eq(&s); +//! +//! +//! // compare chunked-array with numeric values +//! // == +//! ca.equal(2); +//! // != +//! ca.not_equal(2); +//! // > +//! ca.gt(2); +//! // >= +//! ca.gt_eq(2); +//! // < +//! ca.lt(2); +//! // <= +//! ca.lt_eq(2); +//! +//! // compare chunked-array with chunked-array +//! // == +//! ca.equal(&ca); +//! // != +//! ca.not_equal(&ca); +//! // > +//! ca.gt(&ca); +//! // >= +//! ca.gt_eq(&ca); +//! // < +//! ca.lt(&ca); +//! // <= +//! ca.lt_eq(&ca); +//! +//! // use iterators +//! let a: BooleanChunked = ca.iter() +//! .map(|opt_value| { +//! match opt_value { +//! Some(value) => value < 10, +//! None => false +//! }}).collect(); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! +//! ## Apply functions/ closures +//! +//! See all possible [apply methods here](crate::chunked_array::ops::ChunkApply). +//! +//! ### Series / ChunkedArrays +//! +//! ``` +//! use polars::prelude::*; +//! use polars::prelude::arity::unary_elementwise_values; +//! # fn example() -> PolarsResult<()> { +//! +//! // apply a closure over all values +//! let s = Series::new("foo".into(), &[Some(1), Some(2), None]); +//! s.i32()?.apply_values(|value| value * 20); +//! +//! // count string lengths +//! let s = Series::new("foo".into(), &["foo", "bar", "foobar"]); +//! unary_elementwise_values::(s.str()?, |str_val| str_val.len() as u64); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! +//! ### Multiple columns +//! +//! ``` +//! use polars::prelude::*; +//! fn my_black_box_function(a: f32, b: f32) -> f32 { +//! // do something +//! a +//! } +//! +//! fn apply_multiples(col_a: &Series, col_b: &Series) -> Float32Chunked { +//! match (col_a.dtype(), col_b.dtype()) { +//! (DataType::Float32, DataType::Float32) => { +//! // downcast to `ChunkedArray` +//! let a = col_a.f32().unwrap(); +//! let b = col_b.f32().unwrap(); +//! +//! a.into_iter() +//! .zip(b.into_iter()) +//! .map(|(opt_a, opt_b)| match (opt_a, opt_b) { +//! (Some(a), Some(b)) => Some(my_black_box_function(a, b)), +//! // if either value is `None` we propagate that null +//! _ => None, +//! }) +//! .collect() +//! } +//! _ => panic!("unexpected dtypes"), +//! } +//! } +//! ``` +//! +//! ### DataFrame +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! # fn example() -> PolarsResult<()> { +//! +//! let mut df = df![ +//! "letters" => ["a", "b", "c", "d"], +//! "numbers" => [1, 2, 3, 4] +//! ]?; +//! +//! +//! // coerce numbers to floats +//! df.try_apply("number", |s: &Series| s.cast(&DataType::Float64))?; +//! +//! // transform letters to uppercase letters +//! df.try_apply("letters", |s: &Series| { +//! Ok(s.str()?.to_uppercase()) +//! }); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Filter +//! ``` +//! use polars::prelude::*; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! // create a mask to filter out null values +//! let mask = df.column("sepal_width")?.is_not_null(); +//! +//! // select column +//! let s = df.column("sepal_length")?; +//! +//! // apply filter on a Series +//! let filtered_series = s.filter(&mask); +//! +//! // apply the filter on a DataFrame +//! let filtered_df = df.filter(&mask)?; +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Sort +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example() -> PolarsResult<()> { +//! let df = df![ +//! "a" => [1, 2, 3], +//! "b" => ["a", "a", "b"] +//! ]?; +//! // sort this DataFrame by multiple columns +//! +//! // ordering of the columns +//! let descending = vec![true, false]; +//! // columns to sort by +//! let by = [PlSmallStr::from_static("b"), PlSmallStr::from_static("a")]; +//! // do the sort operation +//! let sorted = df.sort( +//! by, +//! SortMultipleOptions::default() +//! .with_order_descending_multi(descending) +//! .with_maintain_order(true) +//! )?; +//! +//! // sorted: +//! +//! // ╭─────┬─────╮ +//! // │ a ┆ b │ +//! // │ --- ┆ --- │ +//! // │ i64 ┆ str │ +//! // ╞═════╪═════╡ +//! // │ 1 ┆ "a" │ +//! // │ 2 ┆ "a" │ +//! // │ 3 ┆ "b" │ +//! // ╰─────┴─────╯ +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Joins +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example() -> PolarsResult<()> { +//! // Create first df. +//! let temp = df!("days" => &[0, 1, 2, 3, 4], +//! "temp" => &[22.1, 19.9, 7., 2., 3.], +//! "other" => &[1, 2, 3, 4, 5] +//! )?; +//! +//! // Create second df. +//! let rain = df!("days" => &[1, 2], +//! "rain" => &[0.1, 0.2], +//! "other" => &[1, 2, 3, 4, 5] +//! )?; +//! +//! // join on a single column +//! temp.left_join(&rain, ["days"], ["days"]); +//! temp.inner_join(&rain, ["days"], ["days"]); +//! temp.full_join(&rain, ["days"], ["days"]); +//! +//! // join on multiple columns +//! temp.join(&rain, vec!["days", "other"], vec!["days", "other"], JoinArgs::new(JoinType::Left), None); +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Groupby +//! +//! Note that Polars lazy is a lot more powerful in and more performant in group_by operations. +//! In lazy a myriad of aggregations can be combined from expressions. +//! +//! See more in: +//! +//! * [Groupby](crate::frame::group_by::GroupBy) +//! +//! ### GroupBy +//! ``` +//! use polars::prelude::*; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! // group_by "groups" | sum "foo" +//! let out = df.group_by(["groups"])? +//! .select(["foo"]) +//! .sum(); +//! +//! # Ok(()) +//! # } +//! +//! ``` +//! +//! ### Pivot +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! let df = df!("foo" => ["A", "A", "B", "B", "C"], +//! "N" => [1, 2, 2, 4, 2], +//! "bar" => ["k", "l", "m", "n", "0"] +//! )?; +//! +//! // group_by "foo" | pivot "bar" column | aggregate "N" +//! let pivoted = pivot::pivot( +//! &df, +//! [PlSmallStr::from_static("foo")], +//! Some([PlSmallStr::from_static("bar")]), +//! Some([PlSmallStr::from_static("N")]), +//! false, Some(first()), +//! None +//! ); +//! +//! // pivoted: +//! // +-----+------+------+------+------+------+ +//! // | foo | o | n | m | l | k | +//! // | --- | --- | --- | --- | --- | --- | +//! // | str | i32 | i32 | i32 | i32 | i32 | +//! // +=====+======+======+======+======+======+ +//! // | "A" | null | null | null | 2 | 1 | +//! // +-----+------+------+------+------+------+ +//! // | "B" | null | 4 | 2 | null | null | +//! // +-----+------+------+------+------+------+ +//! // | "C" | 2 | null | null | null | null | +//! // +-----+------+------+------+------+------+! +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Unpivot +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! let df = df!["A" => &["a", "b", "a"], +//! "B" => &[1, 3, 5], +//! "C" => &[10, 11, 12], +//! "D" => &[2, 4, 6] +//! ]?; +//! +//! let unpivoted = df.unpivot( +//! [PlSmallStr::from_static("A"), PlSmallStr::from_static("B")], +//! [PlSmallStr::from_static("C"), PlSmallStr::from_static("D")], +//! ).unwrap(); +//! // unpivoted: +//! +//! // +-----+-----+----------+-------+ +//! // | A | B | variable | value | +//! // | --- | --- | --- | --- | +//! // | str | i32 | str | i32 | +//! // +=====+=====+==========+=======+ +//! // | "a" | 1 | "C" | 10 | +//! // +-----+-----+----------+-------+ +//! // | "b" | 3 | "C" | 11 | +//! // +-----+-----+----------+-------+ +//! // | "a" | 5 | "C" | 12 | +//! // +-----+-----+----------+-------+ +//! // | "a" | 1 | "D" | 2 | +//! // +-----+-----+----------+-------+ +//! // | "b" | 3 | "D" | 4 | +//! // +-----+-----+----------+-------+ +//! // | "a" | 5 | "D" | 6 | +//! // +-----+-----+----------+-------+ +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Explode +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! let s0 = Series::new("a".into(), &[1i64, 2, 3]); +//! let s1 = Series::new("b".into(), &[1i64, 1, 1]); +//! let s2 = Series::new("c".into(), &[2i64, 2, 2]); +//! // construct a new ListChunked for a slice of Series. +//! let list = Column::new("foo".into(), &[s0, s1, s2]); +//! +//! // construct a few more Series. +//! let s0 = Column::new("B".into(), [1, 2, 3]); +//! let s1 = Column::new("C".into(), [1, 1, 1]); +//! let df = DataFrame::new(vec![list, s0, s1])?; +//! +//! let exploded = df.explode([PlSmallStr::from("foo")])?; +//! // exploded: +//! +//! // +-----+-----+-----+ +//! // | foo | B | C | +//! // | --- | --- | --- | +//! // | i64 | i32 | i32 | +//! // +=====+=====+=====+ +//! // | 1 | 1 | 1 | +//! // +-----+-----+-----+ +//! // | 2 | 1 | 1 | +//! // +-----+-----+-----+ +//! // | 3 | 1 | 1 | +//! // +-----+-----+-----+ +//! // | 1 | 2 | 1 | +//! // +-----+-----+-----+ +//! // | 1 | 2 | 1 | +//! // +-----+-----+-----+ +//! // | 1 | 2 | 1 | +//! // +-----+-----+-----+ +//! // | 2 | 3 | 1 | +//! // +-----+-----+-----+ +//! // | 2 | 3 | 1 | +//! // +-----+-----+-----+ +//! // | 2 | 3 | 1 | +//! // +-----+-----+-----+ +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## IO +//! +//! ### Read CSV +//! +//! ``` +//! use polars::prelude::*; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! // read from path +//! let mut file = std::fs::File::open("iris.csv")?; +//! let df = CsvReader::new(file).finish()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Write CSV +//! +//! ``` +//! use polars::prelude::*; +//! use std::fs::File; +//! +//! # fn example(df: &mut DataFrame) -> PolarsResult<()> { +//! // create a file +//! let mut file = File::create("example.csv").expect("could not create file"); +//! +//! // write DataFrame to file +//! CsvWriter::new(&mut file) +//! .include_header(true) +//! .with_separator(b',') +//! .finish(df); +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Read IPC +//! ``` +//! use polars::prelude::*; +//! use std::fs::File; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! // open file +//! let file = File::open("file.ipc").expect("file not found"); +//! +//! // read to DataFrame +//! let df = IpcReader::new(file) +//! .finish()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Write IPC +//! ``` +//! use polars::prelude::*; +//! use std::fs::File; +//! +//! # fn example(df: &mut DataFrame) -> PolarsResult<()> { +//! // create a file +//! let mut file = File::create("file.ipc").expect("could not create file"); +//! +//! // write DataFrame to file +//! IpcWriter::new(&mut file) +//! .finish(df) +//! # } +//! ``` +//! +//! ### Read Parquet +//! +//! ``` +//! use polars::prelude::*; +//! use std::fs::File; +//! +//! # fn example(df: &DataFrame) -> PolarsResult<()> { +//! // open file +//! let file = File::open("some_file.parquet").unwrap(); +//! +//! // read to DataFrame +//! let df = ParquetReader::new(file).finish()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ### Write Parquet +//! ``` +//! use polars::prelude::*; +//! use std::fs::File; +//! +//! # fn example(df: &mut DataFrame) -> PolarsResult { +//! // create a file +//! let file = File::create("example.parquet").expect("could not create file"); +//! +//! ParquetWriter::new(file) +//! .finish(df) +//! # } +//! ``` +//! +//! # Various +//! +//! ## Replace NaN with Missing. +//! The floating point [Not a Number: NaN](https://en.wikipedia.org/wiki/NaN) is conceptually different +//! than missing data in Polars. In the snippet below we show how we can replace [`NaN`] values with +//! missing values, by setting them to [`None`]. +//! +//! [`NaN`]: https://doc.rust-lang.org/std/primitive.f64.html#associatedconstant.NAN +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! /// Replaces NaN with missing values. +//! fn fill_nan_with_nulls() -> PolarsResult { +//! let nan = f64::NAN; +//! +//! let mut df = df! { +//! "a" => [nan, 1.0, 2.0], +//! "b" => [nan, 1.0, 2.0] +//! } +//! .unwrap(); +//! +//! for idx in 0..df.width() { +//! df.try_apply_at_idx(idx, |series| { +//! let mask = series.is_nan()?; +//! let ca = series.f64()?; +//! ca.set(&mask, None) +//! })?; +//! } +//! Ok(df) +//! } +//! ``` +//! +//! ## Extracting data +//! +//! To iterate over the values of a [`Series`], or to convert the [`Series`] into another structure +//! such as a [`Vec`], we must first downcast to a data type aware [`ChunkedArray`]. +//! +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! fn extract_data() -> PolarsResult<()> { +//! let df = df! [ +//! "a" => [None, Some(1.0f32), Some(2.0)], +//! "str" => ["foo", "bar", "ham"] +//! ]?; +//! +//! // first extract ChunkedArray to get the inner type. +//! let ca = df.column("a")?.f32()?; +//! +//! // Then convert to vec +//! let to_vec: Vec> = Vec::from(ca); +//! +//! // We can also do this with iterators +//! let ca = df.column("str")?.str()?; +//! let to_vec: Vec> = ca.into_iter().collect(); +//! let to_vec_no_options: Vec<&str> = ca.into_no_null_iter().collect(); +//! +//! Ok(()) +//! } +//! ``` +//! +//! diff --git a/crates/polars/src/docs/lazy.rs b/crates/polars/src/docs/lazy.rs new file mode 100644 index 000000000000..cf0c6f9f35e6 --- /dev/null +++ b/crates/polars/src/docs/lazy.rs @@ -0,0 +1,287 @@ +//! +//! # Polars Lazy cookbook +//! +//! This page should serve as a cookbook to quickly get you started with Polars' query engine. +//! The lazy API allows you to create complex well performing queries on top of Polars eager. +//! +//! ## Tree Of Contents +//! +//! * [Start a lazy computation](#start-a-lazy-computation) +//! * [Filter](#filter) +//! * [Sort](#sort) +//! * [GroupBy](#group_by) +//! * [Joins](#joins) +//! * [Conditionally apply](#conditionally-apply) +//! * [Black box function](#black-box-function) +//! +//! ## Start a lazy computation +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example() -> PolarsResult<()> { +//! let df = df![ +//! "a" => [1, 2, 3], +//! "b" => [None, Some("a"), Some("b")] +//! ]?; +//! // from an eager DataFrame +//! let lf: LazyFrame = df.lazy(); +//! +//! // scan a csv file lazily +//! let lf: LazyFrame = LazyCsvReader::new("some_path") +//! .with_has_header(true) +//! .finish()?; +//! +//! // scan a parquet file lazily +//! let lf: LazyFrame = LazyFrame::scan_parquet("some_path", Default::default())?; +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Filter +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example() -> PolarsResult<()> { +//! let df = df![ +//! "a" => [1, 2, 3], +//! "b" => [None, Some("a"), Some("b")] +//! ]?; +//! +//! let filtered = df.lazy() +//! .filter(col("a").gt(lit(2))) +//! .collect()?; +//! +//! // filtered: +//! +//! // ╭─────┬─────╮ +//! // │ a ┆ b │ +//! // │ --- ┆ --- │ +//! // │ i64 ┆ str │ +//! // ╞═════╪═════╡ +//! // │ 3 ┆ "c" │ +//! // ╰─────┴─────╯ +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Sort +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! +//! # fn example() -> PolarsResult<()> { +//! let df = df![ +//! "a" => [1, 2, 3], +//! "b" => ["a", "a", "b"] +//! ]?; +//! // sort this DataFrame by multiple columns +//! +//! let sorted = df.lazy() +//! .sort_by_exprs(vec![col("b"), col("a")], SortMultipleOptions::default()) +//! .collect()?; +//! +//! // sorted: +//! +//! // ╭─────┬─────╮ +//! // │ a ┆ b │ +//! // │ --- ┆ --- │ +//! // │ i64 ┆ str │ +//! // ╞═════╪═════╡ +//! // │ 1 ┆ "a" │ +//! // │ 2 ┆ "a" │ +//! // │ 3 ┆ "b" │ +//! // ╰─────┴─────╯ +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Groupby +//! +//! This example is from the polars [user guide](https://docs.pola.rs/user-guide/concepts/expressions-and-contexts/#group_by-and-aggregations). +//! +//! ``` +//! use polars::prelude::*; +//! # fn example() -> PolarsResult<()> { +//! +//! let df = LazyCsvReader::new("reddit.csv") +//! .with_has_header(true) +//! .with_separator(b',') +//! .finish()? +//! .group_by([col("comment_karma")]) +//! .agg([col("name").n_unique().alias("unique_names"), col("link_karma").max()]) +//! // take only 100 rows. +//! .fetch(100)?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Joins +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! # fn example() -> PolarsResult<()> { +//! let df_a = df![ +//! "a" => [1, 2, 1, 1], +//! "b" => ["a", "b", "c", "c"], +//! "c" => [0, 1, 2, 3] +//! ]?; +//! +//! let df_b = df![ +//! "foo" => [1, 1, 1], +//! "bar" => ["a", "c", "c"], +//! "ham" => ["let", "var", "const"] +//! ]?; +//! +//! let lf_a = df_a.clone().lazy(); +//! let lf_b = df_b.clone().lazy(); +//! +//! let joined = lf_a.join(lf_b, vec![col("a")], vec![col("foo")], JoinArgs::new(JoinType::Full)).collect()?; +//! // joined: +//! +//! // ╭─────┬─────┬─────┬──────┬─────────╮ +//! // │ b ┆ c ┆ a ┆ bar ┆ ham │ +//! // │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +//! // │ str ┆ i64 ┆ i64 ┆ str ┆ str │ +//! // ╞═════╪═════╪═════╪══════╪═════════╡ +//! // │ "a" ┆ 0 ┆ 1 ┆ "a" ┆ "let" │ +//! // │ "a" ┆ 0 ┆ 1 ┆ "c" ┆ "var" │ +//! // │ "a" ┆ 0 ┆ 1 ┆ "c" ┆ "const" │ +//! // │ "b" ┆ 1 ┆ 2 ┆ null ┆ null │ +//! // │ "c" ┆ 2 ┆ 1 ┆ null ┆ null │ +//! // │ "c" ┆ 3 ┆ 1 ┆ null ┆ null │ +//! // ╰─────┴─────┴─────┴──────┴─────────╯ +//! +//! // other join syntax options +//! # let lf_a = df_a.clone().lazy(); +//! # let lf_b = df_b.clone().lazy(); +//! let inner = lf_a.inner_join(lf_b, col("a"), col("foo")).collect()?; +//! +//! # let lf_a = df_a.clone().lazy(); +//! # let lf_b = df_b.clone().lazy(); +//! let left = lf_a.left_join(lf_b, col("a"), col("foo")).collect()?; +//! +//! # let lf_a = df_a.clone().lazy(); +//! # let lf_b = df_b.clone().lazy(); +//! let outer = lf_a.full_join(lf_b, col("a"), col("foo")).collect()?; +//! +//! # let lf_a = df_a.clone().lazy(); +//! # let lf_b = df_b.clone().lazy(); +//! let joined_with_builder = lf_a.join_builder() +//! .with(lf_b) +//! .left_on(vec![col("a")]) +//! .right_on(vec![col("foo")]) +//! .how(JoinType::Inner) +//! .force_parallel(true) +//! .finish() +//! .collect()?; +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Conditionally apply +//! If we want to create a new column based on some condition, we can use the [`when`]/[`then`]/[`otherwise`] expressions. +//! +//! * [`when`] - accepts a predicate expression +//! * [`then`] - expression to use when `predicate == true` +//! * [`otherwise`] - expression to use when `predicate == false` +//! +//! [`when`]: polars_lazy::dsl::Then::when +//! [`then`]: polars_lazy::dsl::When::then +//! [`otherwise`]: polars_lazy::dsl::Then::otherwise +//! +//! ``` +//! use polars::prelude::*; +//! use polars::df; +//! # fn example() -> PolarsResult<()> { +//! let df = df![ +//! "range" => [1, 2, 3, 4, 5, 6, 8, 9, 10], +//! "left" => (0..10).map(|_| Some("foo")).collect::>(), +//! "right" => (0..10).map(|_| Some("bar")).collect::>() +//! ]?; +//! +//! let new = df.lazy() +//! .with_column(when(col("range").gt_eq(lit(5))) +//! .then(col("left")) +//! .otherwise(col("right")).alias("foo_or_bar") +//! ).collect()?; +//! +//! // new: +//! +//! // ╭───────┬───────┬───────┬────────────╮ +//! // │ range ┆ left ┆ right ┆ foo_or_bar │ +//! // │ --- ┆ --- ┆ --- ┆ --- │ +//! // │ i64 ┆ str ┆ str ┆ str │ +//! // ╞═══════╪═══════╪═══════╪════════════╡ +//! // │ 0 ┆ "foo" ┆ "bar" ┆ "bar" │ +//! // │ 1 ┆ "foo" ┆ "bar" ┆ "bar" │ +//! // │ 2 ┆ "foo" ┆ "bar" ┆ "bar" │ +//! // │ 3 ┆ "foo" ┆ "bar" ┆ "bar" │ +//! // │ … ┆ … ┆ … ┆ … │ +//! // │ 5 ┆ "foo" ┆ "bar" ┆ "foo" │ +//! // │ 6 ┆ "foo" ┆ "bar" ┆ "foo" │ +//! // │ 7 ┆ "foo" ┆ "bar" ┆ "foo" │ +//! // │ 8 ┆ "foo" ┆ "bar" ┆ "foo" │ +//! // │ 9 ┆ "foo" ┆ "bar" ┆ "foo" │ +//! // ╰───────┴───────┴───────┴────────────╯ +//! +//! # Ok(()) +//! # } +//! ``` +//! +//! # Black box function +//! +//! The expression API should be expressive enough for most of what you want to achieve, but it can happen +//! that you need to pass the values to an external function you do not control. The snippet below +//! shows how we use the [`Struct`] datatype to be able to apply a function over multiple inputs. +//! +//! [`Struct`]: crate::datatypes::DataType::Struct +//! +//! ```ignore +//! use polars::prelude::*; +//! fn my_black_box_function(a: f32, b: f32) -> f32 { +//! // do something +//! a +//! } +//! +//! fn apply_multiples() -> PolarsResult { +//! df![ +//! "a" => [1.0f32, 2.0, 3.0], +//! "b" => [3.0f32, 5.1, 0.3] +//! ]? +//! .lazy() +//! .select([as_struct(vec![col("a"), col("b")]).map( +//! |s| { +//! let ca = s.struct_()?; +//! +//! let series_a = ca.field_by_name("a")?; +//! let series_b = ca.field_by_name("b")?; +//! let chunked_a = series_a.f32()?; +//! let chunked_b = series_b.f32()?; +//! +//! let out: Float32Chunked = chunked_a +//! .into_iter() +//! .zip(chunked_b.into_iter()) +//! .map(|(opt_a, opt_b)| match (opt_a, opt_b) { +//! (Some(a), Some(b)) => Some(my_black_box_function(a, b)), +//! _ => None, +//! }) +//! .collect(); +//! +//! Ok(Some(out.into_series())) +//! }, +//! GetOutput::from_type(DataType::Float32), +//! )]) +//! .collect() +//! } +//! +//! ``` +//! +//! diff --git a/crates/polars/src/docs/mod.rs b/crates/polars/src/docs/mod.rs new file mode 100644 index 000000000000..be809c6ea356 --- /dev/null +++ b/crates/polars/src/docs/mod.rs @@ -0,0 +1,2 @@ +pub mod eager; +pub mod lazy; diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs new file mode 100644 index 000000000000..c7d3720a8b20 --- /dev/null +++ b/crates/polars/src/lib.rs @@ -0,0 +1,435 @@ +//! # Polars: *DataFrames in Rust* +//! +//! Polars is a DataFrame library for Rust. It is based on [Apache Arrow](https://arrow.apache.org/)'s memory model. +//! Apache Arrow provides very cache efficient columnar data structures and is becoming the defacto +//! standard for columnar data. +//! +//! ## Quickstart +//! We recommend building queries directly with [polars-lazy]. This allows you to combine +//! expressions into powerful aggregations and column selections. All expressions are evaluated +//! in parallel and queries are optimized just in time. +//! +//! [polars-lazy]: polars_lazy +//! +//! ```no_run +//! use polars::prelude::*; +//! # fn example() -> PolarsResult<()> { +//! +//! let lf1 = LazyFrame::scan_parquet("myfile_1.parquet", Default::default())? +//! .group_by([col("ham")]) +//! .agg([ +//! // expressions can be combined into powerful aggregations +//! col("foo") +//! .sort_by([col("ham").rank(Default::default(), None)], SortMultipleOptions::default()) +//! .last() +//! .alias("last_foo_ranked_by_ham"), +//! // every expression runs in parallel +//! col("foo").cum_min(false).alias("cumulative_min_per_group"), +//! // every expression runs in parallel +//! col("foo").reverse().implode().alias("reverse_group"), +//! ]); +//! +//! let lf2 = LazyFrame::scan_parquet("myfile_2.parquet", Default::default())? +//! .select([col("ham"), col("spam")]); +//! +//! let df = lf1 +//! .join(lf2, [col("reverse")], [col("foo")], JoinArgs::new(JoinType::Left)) +//! // now we finally materialize the result. +//! .collect()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! This means that Polars data structures can be shared zero copy with processes in many different +//! languages. +//! +//! ## Tree Of Contents +//! +//! * [Cookbooks](#cookbooks) +//! * [Data structures](#data-structures) +//! - [DataFrame](#dataframe) +//! - [Series](#series) +//! - [ChunkedArray](#chunkedarray) +//! * [SIMD](#simd) +//! * [API](#api) +//! * [Expressions](#expressions) +//! * [Compile times](#compile-times) +//! * [Performance](#performance-and-string-data) +//! - [Custom allocator](#custom-allocator) +//! * [Config](#config-with-env-vars) +//! * [User guide](#user-guide) +//! +//! ## Cookbooks +//! See examples in the cookbooks: +//! +//! * [Eager](crate::docs::eager) +//! * [Lazy](crate::docs::lazy) +//! +//! ## Data Structures +//! The base data structures provided by polars are [`DataFrame`], [`Series`], and [`ChunkedArray`]. +//! We will provide a short, top-down view of these data structures. +//! +//! [`DataFrame`]: crate::frame::DataFrame +//! [`Series`]: crate::series::Series +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray +//! +//! ### DataFrame +//! A [`DataFrame`] is a two-dimensional data structure backed by a [`Series`] and can be +//! seen as an abstraction on [`Vec`]. Operations that can be executed on a [`DataFrame`] are +//! similar to what is done in a `SQL` like query. You can `GROUP`, `JOIN`, `PIVOT` etc. +//! +//! [`Vec`]: std::vec::Vec +//! +//! ### Series +//! [`Series`] are the type-agnostic columnar data representation of Polars. The [`Series`] struct and +//! [`SeriesTrait`] trait provide many operations out of the box. Most type-agnostic operations are provided +//! by [`Series`]. Type-aware operations require downcasting to the typed data structure that is wrapped +//! by the [`Series`]. The underlying typed data structure is a [`ChunkedArray`]. +//! +//! [`SeriesTrait`]: crate::series::SeriesTrait +//! +//! ### ChunkedArray +//! [`ChunkedArray`] are wrappers around an arrow array, that can contain multiples chunks, e.g. +//! [`Vec`]. These are the root data structures of Polars, and implement many operations. +//! Most operations are implemented by traits defined in [chunked_array::ops], +//! or on the [`ChunkedArray`] struct. +//! +//! [`ChunkedArray`]: crate::chunked_array::ChunkedArray +//! +//! ## SIMD +//! Polars / Arrow uses packed_simd to speed up kernels with SIMD operations. SIMD is an optional +//! `feature = "nightly"`, and requires a nightly compiler. If you don't need SIMD, **Polars runs on stable!** +//! +//! ## API +//! Polars supports an eager and a lazy API. The eager API directly yields results, but is overall +//! more verbose and less capable of building elegant composite queries. We recommend to use the Lazy API +//! whenever you can. +//! +//! As neither API is async they should be wrapped in _spawn_blocking_ when used in an async context +//! to avoid blocking the async thread pool of the runtime. +//! +//! ## Expressions +//! Polars has a powerful concept called expressions. +//! Polars expressions can be used in various contexts and are a functional mapping of +//! `Fn(Series) -> Series`, meaning that they have [`Series`] as input and [`Series`] as output. +//! By looking at this functional definition, we can see that the output of an [`Expr`] also can serve +//! as the input of an [`Expr`]. +//! +//! [`Expr`]: polars_lazy::dsl::Expr +//! +//! That may sound a bit strange, so lets give an example. The following is an expression: +//! +//! `col("foo").sort().head(2)` +//! +//! The snippet above says select column `"foo"` then sort this column and then take the first 2 values +//! of the sorted output. +//! The power of expressions is that every expression produces a new expression and that they can +//! be piped together. +//! You can run an expression by passing them on one of polars execution contexts. +//! Here we run two expressions in the **select** context: +//! +//! ```no_run +//! # use polars::prelude::*; +//! # fn example() -> PolarsResult<()> { +//! # let df = DataFrame::default(); +//! df.lazy() +//! .select([ +//! col("foo").sort(Default::default()).head(None), +//! col("bar").filter(col("foo").eq(lit(1))).sum(), +//! ]) +//! .collect()?; +//! # Ok(()) +//! # } +//! ``` +//! All expressions are run in parallel, meaning that separate polars expressions are embarrassingly parallel. +//! (Note that within an expression there may be more parallelization going on). +//! +//! Understanding Polars expressions is most important when starting with the Polars library. Read more +//! about them in the [user guide](https://docs.pola.rs/user-guide/expressions). +//! +//! ### Eager +//! Read more in the pages of the following data structures /traits. +//! +//! * [DataFrame struct](crate::frame::DataFrame) +//! * [Series struct](crate::series::Series) +//! * [Series trait](crate::series::SeriesTrait) +//! * [ChunkedArray struct](crate::chunked_array::ChunkedArray) +//! * [ChunkedArray operations traits](crate::chunked_array::ops) +//! +//! ### Lazy +//! Unlock full potential with lazy computation. This allows query optimizations and provides Polars +//! the full query context so that the fastest algorithm can be chosen. +//! +//! **[Read more in the lazy module.](polars_lazy)** +//! +//! ## Compile times +//! A DataFrame library typically consists of +//! +//! * Tons of features +//! * A lot of datatypes +//! +//! Both of these really put strain on compile times. To keep Polars lean, we make both **opt-in**, +//! meaning that you only pay the compilation cost if you need it. +//! +//! ## Compile times and opt-in features +//! The opt-in features are (not including dtype features): +//! +//! * `lazy` - Lazy API +//! - `regex` - Use regexes in [column selection] +//! - `dot_diagram` - Create dot diagrams from lazy logical plans. +//! * `sql` - Pass SQL queries to Polars. +//! * `streaming` - Process datasets larger than RAM. +//! * `random` - Generate arrays with randomly sampled values +//! * `ndarray`- Convert from [`DataFrame`] to [ndarray](https://docs.rs/ndarray/) +//! * `temporal` - Conversions between [Chrono](https://docs.rs/chrono/) and Polars for temporal data types +//! * `timezones` - Activate timezone support. +//! * `strings` - Extra string utilities for [`StringChunked`] +//! - `string_pad` - `zfill`, `ljust`, `rjust` +//! - `string_to_integer` - `parse_int` +//! * `object` - Support for generic ChunkedArrays called [`ObjectChunked`] (generic over `T`). +//! These are downcastable from Series through the [Any](https://doc.rust-lang.org/std/any/index.html) trait. +//! * Performance related: +//! - `nightly` - Several nightly only features such as SIMD and specialization. +//! - `performant` - more fast paths, slower compile times. +//! - `bigidx` - Activate this feature if you expect >> 2^32 rows. This is rarely needed. +//! This allows Polars to scale up beyond 2^32 rows by using an index with a `u64` data type. +//! Polars will be a bit slower with this feature activated as many data structures +//! are less cache efficient. +//! - `cse` - Activate common subplan elimination optimization +//! * IO related: +//! - `serde` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. +//! Can be used for JSON and more serde supported serialization formats. +//! - `serde-lazy` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. +//! Can be used for JSON and more serde supported serialization formats. +//! - `parquet` - Read Apache Parquet format +//! - `json` - JSON serialization +//! - `ipc` - Arrow's IPC format serialization +//! - `decompress` - Automatically infer compression of csvs and decompress them. +//! Supported compressions: +//! - gzip +//! - zlib +//! - zstd +//! +//! [`StringChunked`]: crate::datatypes::StringChunked +//! [column selection]: polars_lazy::dsl::col +//! [`ObjectChunked`]: polars_core::datatypes::ObjectChunked +//! +//! +//! * [`DataFrame`] operations: +//! - `dynamic_group_by` - Groupby based on a time window instead of predefined keys. +//! Also activates rolling window group by operations. +//! - `sort_multiple` - Allow sorting a [`DataFrame`] on multiple columns +//! - `rows` - Create [`DataFrame`] from rows and extract rows from [`DataFrame`]s. +//! Also activates `pivot` and `transpose` operations +//! - `asof_join` - Join ASOF, to join on nearest keys instead of exact equality match. +//! - `cross_join` - Create the Cartesian product of two [`DataFrame`]s. +//! - `semi_anti_join` - SEMI and ANTI joins. +//! - `row_hash` - Utility to hash [`DataFrame`] rows to [`UInt64Chunked`] +//! - `diagonal_concat` - Concat diagonally thereby combining different schemas. +//! - `dataframe_arithmetic` - Arithmetic on ([`Dataframe`] and [`DataFrame`]s) and ([`DataFrame`] on [`Series`]) +//! - `partition_by` - Split into multiple [`DataFrame`]s partitioned by groups. +//! * [`Series`]/[`Expr`] operations: +//! - `is_in` - Check for membership in [`Series`]. +//! - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip). +//! - `round_series` - Round underlying float types of [`Series`]. +//! - `repeat_by` - Repeat element in an Array N times, where N is given by another array. +//! - `is_first_distinct` - Check if element is first unique value. +//! - `is_last_distinct` - Check if element is last unique value. +//! - `is_between` - Check if this expression is between the given lower and upper bounds. +//! - `checked_arithmetic` - checked arithmetic/ returning [`None`] on invalid operations. +//! - `dot_product` - Dot/inner product on [`Series`] and [`Expr`]. +//! - `concat_str` - Concat string data in linear time. +//! - `reinterpret` - Utility to reinterpret bits to signed/unsigned +//! - `take_opt_iter` - Take from a [`Series`] with [`Iterator>`](std::iter::Iterator). +//! - `mode` - [Return the most occurring value(s)](polars_ops::chunked_array::mode) +//! - `cum_agg` - [`cum_sum`], [`cum_min`], [`cum_max`] aggregation. +//! - `rolling_window` - rolling window functions, like [`rolling_mean`] +//! - `interpolate` - [interpolate None values](polars_ops::series::interpolate()) +//! - `extract_jsonpath` - [Run jsonpath queries on StringChunked](https://goessner.net/articles/JsonPath/) +//! - `list` - List utils. +//! - `list_gather` take sublist by multiple indices +//! - `rank` - Ranking algorithms. +//! - `moment` - Kurtosis and skew statistics +//! - `ewma` - Exponential moving average windows +//! - `abs` - Get absolute values of [`Series`]. +//! - `arange` - Range operation on [`Series`]. +//! - `product` - Compute the product of a [`Series`]. +//! - `diff` - [`diff`] operation. +//! - `pct_change` - Compute change percentages. +//! - `unique_counts` - Count unique values in expressions. +//! - `log` - Logarithms for [`Series`]. +//! - `list_to_struct` - Convert [`List`] to [`Struct`] dtypes. +//! - `list_count` - Count elements in lists. +//! - `list_eval` - Apply expressions over list elements. +//! - `list_sets` - Compute UNION, INTERSECTION, and DIFFERENCE on list types. +//! - `cumulative_eval` - Apply expressions over cumulatively increasing windows. +//! - `arg_where` - Get indices where condition holds. +//! - `search_sorted` - Find indices where elements should be inserted to maintain order. +//! - `offset_by` - Add an offset to dates that take months and leap years into account. +//! - `trigonometry` - Trigonometric functions. +//! - `sign` - Compute the element-wise sign of a [`Series`]. +//! - `propagate_nans` - NaN propagating min/max aggregations. +//! - `extract_groups` - Extract multiple regex groups from strings. +//! - `cov` - Covariance and correlation functions. +//! - `find_many` - Find/replace multiple string patterns at once. +//! * [`DataFrame`] pretty printing +//! - `fmt` - Activate [`DataFrame`] formatting +//! +//! [`UInt64Chunked`]: crate::datatypes::UInt64Chunked +//! [`cum_sum`]: polars_ops::prelude::cum_sum +//! [`cum_min`]: polars_ops::prelude::cum_min +//! [`cum_max`]: polars_ops::prelude::cum_max +//! [`rolling_mean`]: crate::series::Series#method.rolling_mean +//! [`diff`]: polars_ops::prelude::diff +//! [`List`]: crate::datatypes::DataType::List +//! [`Struct`]: crate::datatypes::DataType::Struct +//! +//! ## Compile times and opt-in data types +//! As mentioned above, Polars [`Series`] are wrappers around +//! [`ChunkedArray`] without the generic parameter `T`. +//! To get rid of the generic parameter, all the possible values of `T` are compiled +//! for [`Series`]. This gets more expensive the more types you want for a [`Series`]. In order to reduce +//! the compile times, we have decided to default to a minimal set of types and make more [`Series`] types +//! opt-in. +//! +//! Note that if you get strange compile time errors, you probably need to opt-in for that [`Series`] dtype. +//! The opt-in dtypes are: +//! +//! | data type | feature flag | +//! |-------------------------|-------------------| +//! | Date | dtype-date | +//! | Datetime | dtype-datetime | +//! | Time | dtype-time | +//! | Duration | dtype-duration | +//! | Int8 | dtype-i8 | +//! | Int16 | dtype-i16 | +//! | UInt8 | dtype-u8 | +//! | UInt16 | dtype-u16 | +//! | Categorical | dtype-categorical | +//! | Struct | dtype-struct | +//! +//! +//! Or you can choose one of the preconfigured pre-sets. +//! +//! * `dtype-full` - all opt-in dtypes. +//! * `dtype-slim` - slim preset of opt-in dtypes. +//! +//! ## Performance +//! To get the best performance out of Polars we recommend compiling on a nightly compiler +//! with the features `simd` and `performant` activated. The activated cpu features also influence +//! the amount of simd acceleration we can use. +//! +//! See the features we activate for our python builds, or if you just run locally and want to +//! use all available features on your cpu, set `RUSTFLAGS='-C target-cpu=native'`. +//! +//! ### Custom allocator +//! An OLAP query engine does a lot of heap allocations. It is recommended to use a custom +//! allocator, (we have found this to have up to ~25% runtime influence). +//! [JeMalloc](https://crates.io/crates/tikv-jemallocator) and +//! [Mimalloc](https://crates.io/crates/mimalloc) for instance, show a significant +//! performance gain in runtime as well as memory usage. +//! +//! #### Jemalloc Usage +//! ```ignore +//! use tikv_jemallocator::Jemalloc; +//! +//! #[global_allocator] +//! static GLOBAL: Jemalloc = Jemalloc; +//! ``` +//! +//! #### Cargo.toml +//! ```toml +//! [dependencies] +//! tikv-jemallocator = { version = "*" } +//! ``` +//! +//! #### Mimalloc Usage +//! +//! ```ignore +//! use mimalloc::MiMalloc; +//! +//! #[global_allocator] +//! static GLOBAL: MiMalloc = MiMalloc; +//! ``` +//! +//! #### Cargo.toml +//! ```toml +//! [dependencies] +//! mimalloc = { version = "*", default-features = false } +//! ``` +//! +//! #### Notes +//! [Benchmarks](https://github.com/pola-rs/polars/pull/3108) have shown that on Linux and macOS JeMalloc +//! outperforms Mimalloc on all tasks and is therefore the default allocator used for the Python bindings on Unix platforms. +//! +//! ## Config with ENV vars +//! +//! * `POLARS_FMT_TABLE_FORMATTING` -> define styling of tables using any of the following options (default = UTF8_FULL_CONDENSED). These options are defined by comfy-table which provides examples for each at +//! * `ASCII_FULL` +//! * `ASCII_FULL_CONDENSED` +//! * `ASCII_NO_BORDERS` +//! * `ASCII_BORDERS_ONLY` +//! * `ASCII_BORDERS_ONLY_CONDENSED` +//! * `ASCII_HORIZONTAL_ONLY` +//! * `ASCII_MARKDOWN` +//! * `MARKDOWN` +//! * `UTF8_FULL` +//! * `UTF8_FULL_CONDENSED` +//! * `UTF8_NO_BORDERS` +//! * `UTF8_BORDERS_ONLY` +//! * `UTF8_HORIZONTAL_ONLY` +//! * `NOTHING` +//! * `POLARS_FMT_TABLE_CELL_ALIGNMENT` -> define cell alignment using any of the following options (default = LEFT): +//! * `LEFT` +//! * `CENTER` +//! * `RIGHT` +//! * `POLARS_FMT_TABLE_DATAFRAME_SHAPE_BELOW` -> print shape information below the table. +//! * `POLARS_FMT_TABLE_HIDE_COLUMN_NAMES` -> hide table column names. +//! * `POLARS_FMT_TABLE_HIDE_COLUMN_DATA_TYPES` -> hide data types for columns. +//! * `POLARS_FMT_TABLE_HIDE_COLUMN_SEPARATOR` -> hide separator that separates column names from rows. +//! * `POLARS_FMT_TABLE_HIDE_DATAFRAME_SHAPE_INFORMATION"` -> omit table shape information. +//! * `POLARS_FMT_TABLE_INLINE_COLUMN_DATA_TYPE` -> put column data type on the same line as the column name. +//! * `POLARS_FMT_TABLE_ROUNDED_CORNERS` -> apply rounded corners to UTF8-styled tables. +//! * `POLARS_FMT_MAX_COLS` -> maximum number of columns shown when formatting DataFrames. +//! * `POLARS_FMT_MAX_ROWS` -> maximum number of rows shown when formatting DataFrames, `-1` to show all. +//! * `POLARS_FMT_STR_LEN` -> maximum number of characters printed per string value. +//! * `POLARS_TABLE_WIDTH` -> width of the tables used during DataFrame formatting. +//! * `POLARS_MAX_THREADS` -> maximum number of threads used to initialize thread pool (on startup). +//! * `POLARS_VERBOSE` -> print logging info to stderr. +//! * `POLARS_NO_PARTITION` -> polars may choose to partition the group_by operation, based on data +//! cardinality. Setting this env var will turn partitioned group_by's off. +//! * `POLARS_PARTITION_UNIQUE_COUNT` -> at which (estimated) key count a partitioned group_by should run. +//! defaults to `1000`, any higher cardinality will run default group_by. +//! * `POLARS_FORCE_PARTITION` -> force partitioned group_by if the keys and aggregations allow it. +//! * `POLARS_ALLOW_EXTENSION` -> allows for [`ObjectChunked`] to be used in arrow, opening up possibilities like using +//! `T` in complex lazy expressions. However this does require `unsafe` code allow this. +//! * `POLARS_NO_PARQUET_STATISTICS` -> if set, statistics in parquet files are ignored. +//! * `POLARS_PANIC_ON_ERR` -> panic instead of returning an Error. +//! * `POLARS_BACKTRACE_IN_ERR` -> include a Rust backtrace in Error messages. +//! * `POLARS_NO_CHUNKED_JOIN` -> force rechunk before joins. +//! +//! ## User guide +//! +//! If you want to read more, check the [user guide](https://docs.pola.rs/). +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![allow(ambiguous_glob_reexports)] +pub mod docs; +pub mod prelude; +#[cfg(feature = "sql")] +pub mod sql; + +pub use polars_core::{ + apply_method_all_arrow_series, chunked_array, datatypes, df, error, frame, functions, series, + testing, +}; +#[cfg(feature = "dtype-categorical")] +pub use polars_core::{enable_string_cache, using_string_cache}; +#[cfg(feature = "polars-io")] +pub use polars_io as io; +#[cfg(feature = "lazy")] +pub use polars_lazy as lazy; +#[cfg(feature = "temporal")] +pub use polars_time as time; + +/// Polars crate version +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/crates/polars/src/prelude.rs b/crates/polars/src/prelude.rs new file mode 100644 index 000000000000..d472c4f19bdf --- /dev/null +++ b/crates/polars/src/prelude.rs @@ -0,0 +1,10 @@ +pub use polars_core::prelude::*; +pub use polars_core::utils::NoNull; +#[cfg(feature = "polars-io")] +pub use polars_io::prelude::*; +#[cfg(feature = "lazy")] +pub use polars_lazy::prelude::*; +#[cfg(feature = "polars-ops")] +pub use polars_ops::prelude::*; +#[cfg(feature = "temporal")] +pub use polars_time::prelude::*; diff --git a/crates/polars/src/sql.rs b/crates/polars/src/sql.rs new file mode 100644 index 000000000000..034b9722c784 --- /dev/null +++ b/crates/polars/src/sql.rs @@ -0,0 +1,2 @@ +pub use polars_sql::function_registry::*; +pub use polars_sql::{SQLContext, keywords, sql_expr}; diff --git a/crates/polars/tests/it/arrow/array/binary/mod.rs b/crates/polars/tests/it/arrow/array/binary/mod.rs new file mode 100644 index 000000000000..8f7e8d8f63fd --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mod.rs @@ -0,0 +1,234 @@ +use arrow::array::{Array, BinaryArray, Splitable}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use polars_error::PolarsResult; + +mod mutable; +mod mutable_values; +mod to_mutable; + +fn array() -> BinaryArray { + vec![Some(b"hello".to_vec()), None, Some(b"hello2".to_vec())] + .into_iter() + .collect() +} + +#[test] +fn basics() { + let array = array(); + + assert_eq!(array.value(0), b"hello"); + assert_eq!(array.value(1), b""); + assert_eq!(array.value(2), b"hello2"); + assert_eq!(unsafe { array.value_unchecked(2) }, b"hello2"); + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 5, 11]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = BinaryArray::::new( + ArrowDataType::Binary, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), b""); + assert_eq!(array.value(1), b"hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn split_at() { + let (lhs, rhs) = array().split_at(1); + + assert_eq!(lhs.value(0), b"hello"); + assert_eq!(rhs.value(0), b""); + assert_eq!(rhs.value(1), b"hello2"); + + // note how this keeps everything: the offsets were sliced + assert_eq!(lhs.values().as_slice(), b"hellohello2"); + assert_eq!(rhs.values().as_slice(), b"hellohello2"); + assert_eq!(lhs.offsets().as_slice(), &[0, 5]); + assert_eq!(rhs.offsets().as_slice(), &[5, 5, 11]); + assert_eq!(lhs.validity().map_or(0, |v| v.set_bits()), 0); + assert_eq!(rhs.validity().map_or(0, |v| v.set_bits()), 1); +} + +#[test] +fn empty() { + let array = BinaryArray::::new_empty(ArrowDataType::Binary); + assert_eq!(array.values().as_slice(), b""); + assert_eq!(array.offsets().as_slice(), &[0]); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let array = BinaryArray::::from([Some(b"hello".as_ref()), Some(b" ".as_ref()), None]); + + let a = array.validity().unwrap(); + assert_eq!(a, &Bitmap::from([true, true, false])); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat_n(b"hello", 2).map(Some); + let a = BinaryArray::::from_trusted_len_iter(iter); + assert_eq!(a.len(), 2); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat_n(b"hello".as_ref(), 2) + .map(Some) + .map(PolarsResult::Ok); + let a = BinaryArray::::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat_n(b"hello", 2).map(Some); + let a: BinaryArray = iter.collect(); + assert_eq!(a.len(), 2); +} + +#[test] +fn with_validity() { + let array = BinaryArray::::from([Some(b"hello".as_ref()), Some(b" ".as_ref()), None]); + + let array = array.with_validity(None); + + let a = array.validity(); + assert_eq!(a, None); +} + +#[test] +#[should_panic] +fn wrong_offsets() { + let offsets = vec![0, 5, 4].try_into().unwrap(); // invalid offsets + let values = Buffer::from(b"abbbbb".to_vec()); + BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); +} + +#[test] +#[should_panic] +fn wrong_dtype() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + BinaryArray::::new(ArrowDataType::Int8, offsets, values, None); +} + +#[test] +#[should_panic] +fn value_with_wrong_offsets_panics() { + let offsets = vec![0, 10, 11, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + // the 10-11 is not checked + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + + // but access is still checked (and panics) + // without checks, this would result in reading beyond bounds + array.value(0); +} + +#[test] +#[should_panic] +fn index_out_of_bounds_panics() { + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + let array = BinaryArray::::new(ArrowDataType::Utf8, offsets, values, None); + + array.value(3); +} + +#[test] +#[should_panic] +fn value_unchecked_with_wrong_offsets_panics() { + let offsets = vec![0, 10, 11, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + // the 10-11 is not checked + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + + // but access is still checked (and panics) + // without checks, this would result in reading beyond bounds, + // even if `0` is in bounds + unsafe { array.value_unchecked(0) }; +} + +#[test] +fn debug() { + let array = BinaryArray::::from([Some([1, 2].as_ref()), Some(&[]), None]); + + assert_eq!(format!("{array:?}"), "BinaryArray[[1, 2], [], None]"); +} + +#[test] +fn into_mut_1() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = offsets.clone(); // cloned offsets + assert_eq!(a, offsets); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let validity = Some([true].into()); + let a = validity.clone(); // cloned validity + assert_eq!(a, validity); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_4() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let validity = Some([true].into()); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn rev_iter() { + let array = BinaryArray::::from([Some("hello".as_bytes()), Some(" ".as_bytes()), None]); + + assert_eq!( + array.into_iter().rev().collect::>(), + vec![None, Some(" ".as_bytes()), Some("hello".as_bytes())] + ); +} + +#[test] +fn iter_nth() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + + assert_eq!(array.iter().nth(1), Some(Some(" ".as_bytes()))); + assert_eq!(array.iter().nth(10), None); +} diff --git a/crates/polars/tests/it/arrow/array/binary/mutable.rs b/crates/polars/tests/it/arrow/array/binary/mutable.rs new file mode 100644 index 000000000000..d57deb22faa5 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mutable.rs @@ -0,0 +1,215 @@ +use std::ops::Deref; + +use arrow::array::{BinaryArray, MutableArray, MutableBinaryArray, TryExtendFromSelf}; +use arrow::bitmap::Bitmap; +use polars_error::PolarsError; + +#[test] +fn new() { + assert_eq!(MutableBinaryArray::::new().len(), 0); + + let a = MutableBinaryArray::::with_capacity(2); + assert_eq!(a.len(), 0); + assert!(a.offsets().capacity() >= 2); + assert_eq!(a.values().capacity(), 0); + + let a = MutableBinaryArray::::with_capacities(2, 60); + assert_eq!(a.len(), 0); + assert!(a.offsets().capacity() >= 2); + assert!(a.values().capacity() >= 60); +} + +#[test] +fn from_iter() { + let iter = (0..3u8).map(|x| Some(vec![x; x as usize])); + let a: MutableBinaryArray = iter.clone().collect(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = unsafe { MutableBinaryArray::::from_trusted_len_iter_unchecked(iter) }; + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); +} + +#[test] +fn from_trusted_len_iter() { + let data = [vec![0; 0], vec![1; 1], vec![2; 2]]; + let a: MutableBinaryArray = data.iter().cloned().map(Some).collect(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::from_trusted_len_iter(data.iter().cloned().map(Some)); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::try_from_trusted_len_iter::( + data.iter().cloned().map(Some).map(Ok), + ) + .unwrap(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::from_trusted_len_values_iter(data.iter().cloned()); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); +} + +#[test] +fn push_null() { + let mut array = MutableBinaryArray::::new(); + array.push::<&str>(None); + + let array: BinaryArray = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableBinaryArray::::new(); + a.push(Some(b"first")); + a.push(Some(b"second")); + a.push::>(None); + a.push_null(); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some(b"second".to_vec())); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(b"first".to_vec())); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutableBinaryArray::::new(); + a.push(Some(b"first")); + a.push(Some(b"second")); + a.push(Some(b"third")); + a.push(Some(b"fourth")); + + for _ in 0..4 { + a.push(Some(b"aaaa")); + } + + a.push(Some(b"bbbb")); + + assert_eq!(a.pop(), Some(b"bbbb".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.len(), 5); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"fourth".to_vec())); + assert_eq!(a.pop(), Some(b"third".to_vec())); + assert_eq!(a.pop(), Some(b"second".to_vec())); + assert_eq!(a.pop(), Some(b"first".to_vec())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); +} + +#[test] +fn extend_trusted_len_values() { + let mut array = MutableBinaryArray::::new(); + + array.extend_trusted_len_values(vec![b"first".to_vec(), b"second".to_vec()].into_iter()); + array.extend_trusted_len_values(vec![b"third".to_vec()].into_iter()); + array.extend_trusted_len(vec![None, Some(b"fourth".to_vec())].into_iter()); + + let array: BinaryArray = array.into(); + + assert_eq!(array.values().as_slice(), b"firstsecondthirdfourth"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 11, 16, 16, 22]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00010111], 5)) + ); +} + +#[test] +fn extend_trusted_len() { + let mut array = MutableBinaryArray::::new(); + + array.extend_trusted_len(vec![Some(b"first".to_vec()), Some(b"second".to_vec())].into_iter()); + array.extend_trusted_len(vec![None, Some(b"third".to_vec())].into_iter()); + + let array: BinaryArray = array.into(); + + assert_eq!(array.values().as_slice(), b"firstsecondthird"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 11, 11, 16]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00001011], 4)) + ); +} + +#[test] +fn extend_from_self() { + let mut a = MutableBinaryArray::::from([Some(b"aa"), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableBinaryArray::::from([Some(b"aa"), None, Some(b"aa"), None]) + ); +} + +#[test] +fn test_set_validity() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + array.set_validity(Some([false, false, true].into())); + + assert!(!array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); +} + +#[test] +fn test_apply_validity() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} + +#[test] +fn test_apply_validity_with_no_validity_inited() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(array.is_valid(1)); + assert!(array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/binary/mutable_values.rs b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs new file mode 100644 index 000000000000..8d9500f2911d --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs @@ -0,0 +1,101 @@ +use arrow::array::{MutableArray, MutableBinaryValuesArray}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacity() { + let mut b = MutableBinaryValuesArray::::with_capacity(100); + + assert_eq!(b.values().capacity(), 0); + assert!(b.offsets().capacity() >= 100); + b.shrink_to_fit(); + assert!(b.offsets().capacity() < 100); +} + +#[test] +fn offsets_must_be_in_bounds() { + let offsets = vec![0, 10].try_into().unwrap(); + let values = b"abbbbb".to_vec(); + assert!( + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).is_err() + ); +} + +#[test] +fn dtype_must_be_consistent() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec(); + assert!( + MutableBinaryValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err() + ); +} + +#[test] +fn as_box() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + let _ = b.as_box(); +} + +#[test] +fn as_arc() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + let _ = b.as_arc(); +} + +#[test] +fn extend_trusted_len() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 2, 3, 4].try_into().unwrap(); + let values = b"abab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn from_trusted_len() { + let mut b = MutableBinaryValuesArray::::from_trusted_len_iter(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn extend_from_iter() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let a = b.clone(); + b.extend_trusted_len(a.iter()); + + let offsets = vec![0, 2, 3, 4, 6, 7, 8].try_into().unwrap(); + let values = b"abababab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} diff --git a/crates/polars/tests/it/arrow/array/binary/to_mutable.rs b/crates/polars/tests/it/arrow/array/binary/to_mutable.rs new file mode 100644 index 000000000000..8f07d3a166b3 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/to_mutable.rs @@ -0,0 +1,70 @@ +use arrow::array::BinaryArray; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +#[test] +fn not_shared() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + assert!(array.into_mut().is_right()); +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_validity() { + let validity = Bitmap::from([true]); + let array = BinaryArray::::new( + ArrowDataType::Binary, + vec![0, 1].try_into().unwrap(), + b"a".to_vec().into(), + Some(validity.clone()), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_values() { + let values: Buffer = b"a".to_vec().into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + vec![0, 1].try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets_values() { + let offsets: Buffer = vec![0, 1].into(); + let values: Buffer = b"a".to_vec().into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + offsets.clone().try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets() { + let offsets: Buffer = vec![0, 1].into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + offsets.clone().try_into().unwrap(), + b"a".to_vec().into(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_all() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + assert!(array.clone().into_mut().is_left()) +} diff --git a/crates/polars/tests/it/arrow/array/binview/mod.rs b/crates/polars/tests/it/arrow/array/binview/mod.rs new file mode 100644 index 000000000000..05498d9b511c --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binview/mod.rs @@ -0,0 +1,35 @@ +use std::sync::Arc; + +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +fn array() -> BinaryViewArrayGeneric { + let datatype = ArrowDataType::Utf8View; + + let hello = View::new_from_bytes(b"hello", 0, 0); + let there = View::new_from_bytes(b"there", 0, 6); + let bye = View::new_from_bytes(b"bye", 1, 0); + let excl = View::new_from_bytes(b"!!!", 1, 3); + let hello_there = View::new_from_bytes(b"hello there", 1, 0); + + let views = Buffer::from(vec![hello, there, bye, excl, hello_there]); + let buffers = Arc::new([ + Buffer::from(b"hello there".to_vec()), + Buffer::from(b"bye!!!".to_vec()), + ]); + let validity = None; + + BinaryViewArrayGeneric::try_new(datatype, views, buffers, validity).unwrap() +} + +#[test] +fn split_at() { + let (lhs, rhs) = array().split_at(2); + + assert_eq!(lhs.value(0), "hello"); + assert_eq!(lhs.value(1), "there"); + assert_eq!(rhs.value(0), "bye"); + assert_eq!(rhs.value(1), "!!!"); + assert_eq!(rhs.value(2), "hello there"); +} diff --git a/crates/polars/tests/it/arrow/array/boolean/mod.rs b/crates/polars/tests/it/arrow/array/boolean/mod.rs new file mode 100644 index 000000000000..293c1407fb39 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/boolean/mod.rs @@ -0,0 +1,160 @@ +use arrow::array::{Array, BooleanArray, Splitable}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +mod mutable; + +fn array() -> BooleanArray { + vec![Some(true), None, Some(false)].into_iter().collect() +} + +#[test] +fn basics() { + let array = array(); + + assert_eq!(array.dtype(), &ArrowDataType::Boolean); + + assert!(array.value(0)); + assert!(!array.value(1)); + assert!(!array.value(2)); + assert!(!unsafe { array.value_unchecked(2) }); + assert_eq!(array.values(), &Bitmap::from_u8_slice([0b00000001], 3)); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = BooleanArray::new( + ArrowDataType::Boolean, + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert!(!array.value(0)); + assert!(!array.value(1)); +} + +#[test] +fn split_at() { + let (lhs, rhs) = array().split_at(1); + + assert!(lhs.is_valid(0)); + assert!(!rhs.is_valid(0)); + assert!(rhs.is_valid(1)); + + assert!(lhs.value(0)); + assert!(!rhs.value(0)); + assert!(!rhs.value(1)); +} + +#[test] +fn try_new_invalid() { + assert!(BooleanArray::try_new(ArrowDataType::Int32, [true].into(), None).is_err()); + assert!( + BooleanArray::try_new( + ArrowDataType::Boolean, + [true].into(), + Some([false, true].into()) + ) + .is_err() + ); +} + +#[test] +fn with_validity() { + let bitmap = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let a = a.with_validity(Some(Bitmap::from([true, false, true]))); + assert!(a.validity().is_some()); +} + +#[test] +fn debug() { + let array = BooleanArray::from([Some(true), None, Some(false)]); + assert_eq!(format!("{array:?}"), "BooleanArray[true, None, false]"); +} + +#[test] +fn into_mut_valid() { + let bitmap = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let _ = a.into_mut().right().unwrap(); + + let bitmap = Bitmap::from([true, false, true]); + let validity = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, Some(validity)); + let _ = a.into_mut().right().unwrap(); +} + +#[test] +fn into_mut_invalid() { + let bitmap = Bitmap::from([true, false, true]); + let _other = bitmap.clone(); // values is shared + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let _ = a.into_mut().left().unwrap(); + + let bitmap = Bitmap::from([true, false, true]); + let validity = Bitmap::from([true, false, true]); + let _other = validity.clone(); // validity is shared + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, Some(validity)); + let _ = a.into_mut().left().unwrap(); +} + +#[test] +fn empty() { + let array = BooleanArray::new_empty(ArrowDataType::Boolean); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat_n(true, 2).map(Some); + let a = BooleanArray::from_trusted_len_iter(iter.clone()); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::from_trusted_len_iter_unchecked(iter) }; + assert_eq!(a.len(), 2); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat_n(true, 2).map(Some).map(PolarsResult::Ok); + let a = BooleanArray::try_from_trusted_len_iter(iter.clone()).unwrap(); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::try_from_trusted_len_iter_unchecked(iter).unwrap() }; + assert_eq!(a.len(), 2); +} + +#[test] +fn from_trusted_len_values_iter() { + let iter = std::iter::repeat_n(true, 2); + let a = BooleanArray::from_trusted_len_values_iter(iter.clone()); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::from_trusted_len_values_iter_unchecked(iter) }; + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat_n(true, 2).map(Some); + let a: BooleanArray = iter.collect(); + assert_eq!(a.len(), 2); +} + +#[test] +fn into_iter() { + let data = vec![Some(true), None, Some(false)]; + let rev = data.clone().into_iter().rev(); + + let array: BooleanArray = data.clone().into_iter().collect(); + + assert_eq!(array.clone().into_iter().collect::>(), data); + + assert!(array.into_iter().rev().eq(rev)) +} diff --git a/crates/polars/tests/it/arrow/array/boolean/mutable.rs b/crates/polars/tests/it/arrow/array/boolean/mutable.rs new file mode 100644 index 000000000000..bbacf16d2d93 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/boolean/mutable.rs @@ -0,0 +1,177 @@ +use arrow::array::{MutableArray, MutableBooleanArray, TryExtendFromSelf}; +use arrow::bitmap::MutableBitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +#[test] +fn set() { + let mut a = MutableBooleanArray::from(&[Some(false), Some(true), Some(false)]); + + a.set(1, None); + a.set(0, Some(true)); + assert_eq!( + a, + MutableBooleanArray::from([Some(true), None, Some(false)]) + ); + assert_eq!(a.values(), &MutableBitmap::from([true, false, false])); +} + +#[test] +fn push() { + let mut a = MutableBooleanArray::new(); + a.push_value(true); + a.push_value(false); + a.push(None); + a.push_null(); + assert_eq!( + a, + MutableBooleanArray::from([Some(true), Some(false), None, None]) + ); +} + +#[test] +fn pop() { + let mut a = MutableBooleanArray::new(); + a.push(Some(true)); + a.push(Some(false)); + a.push(None); + a.push_null(); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(true)); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutableBooleanArray::new(); + for _ in 0..4 { + a.push(Some(true)); + } + + for _ in 0..4 { + a.push(Some(false)); + } + + a.push(Some(true)); + + assert_eq!(a.pop(), Some(true)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.len(), 5); + + assert_eq!( + a, + MutableBooleanArray::from([Some(true), Some(true), Some(true), Some(true), Some(false)]) + ); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat_n(true, 2).map(Some); + let a = MutableBooleanArray::from_trusted_len_iter(iter); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true)])); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat_n(true, 2).map(Some); + let a: MutableBooleanArray = iter.collect(); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true)])); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = vec![Some(true), Some(true), None] + .into_iter() + .map(PolarsResult::Ok); + let a = MutableBooleanArray::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true), None])); +} + +#[test] +fn reserve() { + let mut a = MutableBooleanArray::try_new( + ArrowDataType::Boolean, + MutableBitmap::new(), + Some(MutableBitmap::new()), + ) + .unwrap(); + + a.reserve(10); + assert!(a.validity().unwrap().capacity() > 0); + assert!(a.values().capacity() > 0) +} + +#[test] +fn extend_trusted_len() { + let mut a = MutableBooleanArray::new(); + + a.extend_trusted_len(vec![Some(true), Some(false)].into_iter()); + assert_eq!(a.validity(), None); + + a.extend_trusted_len(vec![None, Some(true)].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, true, false, true])) + ); + assert_eq!(a.values(), &MutableBitmap::from([true, false, false, true])); +} + +#[test] +fn extend_trusted_len_values() { + let mut a = MutableBooleanArray::new(); + + a.extend_trusted_len_values(vec![true, true, false].into_iter()); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &MutableBitmap::from([true, true, false])); + + let mut a = MutableBooleanArray::new(); + a.push(None); + a.extend_trusted_len_values(vec![true, false].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); + assert_eq!(a.values(), &MutableBitmap::from([false, true, false])); +} + +#[test] +fn into_iter() { + let ve = MutableBitmap::from([true, false]) + .into_iter() + .collect::>(); + assert_eq!(ve, vec![true, false]); + let ve = MutableBitmap::from([true, false]) + .iter() + .collect::>(); + assert_eq!(ve, vec![true, false]); +} + +#[test] +fn shrink_to_fit() { + let mut a = MutableBitmap::with_capacity(100); + a.push(true); + a.shrink_to_fit(); + assert_eq!(a.capacity(), 8); +} + +#[test] +fn extend_from_self() { + let mut a = MutableBooleanArray::from([Some(true), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableBooleanArray::from([Some(true), None, Some(true), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/dictionary/mod.rs b/crates/polars/tests/it/arrow/array/dictionary/mod.rs new file mode 100644 index 000000000000..bb05634c7683 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/dictionary/mod.rs @@ -0,0 +1,209 @@ +mod mutable; + +use arrow::array::*; +use arrow::datatypes::ArrowDataType; + +#[test] +fn try_new_ok() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let dtype = ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.dtype().clone()), false); + let array = + DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + assert_eq!(array.keys(), &PrimitiveArray::from_vec(vec![1i32, 0])); + assert_eq!( + &Utf8Array::::from_slice(["a", "aa"]) as &dyn Array, + array.values().as_ref(), + ); + assert!(!array.is_ordered()); + + assert_eq!(format!("{array:?}"), "DictionaryArray[aa, a]"); +} + +#[test] +fn split_at() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let dtype = ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.dtype().clone()), false); + let array = + DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let (lhs, rhs) = array.split_at(1); + + assert_eq!(format!("{lhs:?}"), "DictionaryArray[aa]"); + assert_eq!(format!("{rhs:?}"), "DictionaryArray[a]"); +} + +#[test] +fn try_new_incorrect_key() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let dtype = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(values.dtype().clone()), false); + + let r = DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_nulls() { + let key: Option = None; + let keys = PrimitiveArray::from_iter([key]); + let value: &[&str] = &[]; + let values = Utf8Array::::from_slice(value); + + let dtype = ArrowDataType::Dictionary(u32::KEY_TYPE, Box::new(values.dtype().clone()), false); + let r = DictionaryArray::try_new(dtype, keys, values.boxed()).is_ok(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_dt() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let dtype = ArrowDataType::Int32; + + let r = DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_values_dt() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let dtype = ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(ArrowDataType::LargeUtf8), false); + + let r = DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds() { + let values = Utf8Array::::from_slice(["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![2, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds_neg() { + let values = Utf8Array::::from_slice(["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![-1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn new_null() { + let dt = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(ArrowDataType::Int32), false); + let array = DictionaryArray::::new_null(dt, 2); + + assert_eq!(format!("{array:?}"), "DictionaryArray[None, None]"); +} + +#[test] +fn new_empty() { + let dt = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(ArrowDataType::Int32), false); + let array = DictionaryArray::::new_empty(dt); + + assert_eq!(format!("{array:?}"), "DictionaryArray[]"); +} + +#[test] +fn with_validity() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let array = array.with_validity(Some([true, false].into())); + + assert_eq!(format!("{array:?}"), "DictionaryArray[aa, None]"); +} + +#[test] +fn rev_iter() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.into_iter(); + assert_eq!(iter.by_ref().rev().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn iter_values() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.values_iter(); + assert_eq!(iter.by_ref().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn keys_values_iter() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + assert_eq!(array.keys_values_iter().collect::>(), vec![1, 0]); +} + +#[test] +fn iter_values_typed() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + let iter = array.values_iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!(iter.collect::>(), vec!["aa", "a", "a"]); + + let iter = array.iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!( + iter.collect::>(), + vec![Some("aa"), Some("a"), Some("a")] + ); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let iter = array.values_iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic_2() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let iter = array.iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} diff --git a/crates/polars/tests/it/arrow/array/dictionary/mutable.rs b/crates/polars/tests/it/arrow/array/dictionary/mutable.rs new file mode 100644 index 000000000000..571d0d7d8f18 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/dictionary/mutable.rs @@ -0,0 +1,169 @@ +use std::borrow::Borrow; +use std::fmt::Debug; +use std::hash::Hash; + +use arrow::array::indexable::{AsIndexed, Indexable}; +use arrow::array::*; +use polars_error::PolarsResult; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; + +#[test] +fn primitive() -> PolarsResult<()> { + let data = vec![Some(1), Some(2), Some(1)]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn utf8_natural() -> PolarsResult<()> { + let data = vec![Some("a"), Some("b"), Some("a")]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn binary_natural() -> PolarsResult<()> { + let data = vec![ + Some("a".as_bytes()), + Some("b".as_bytes()), + Some("a".as_bytes()), + ]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn push_utf8() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + + assert_eq!( + new.values().values(), + MutableUtf8Array::::from_iter_values(["A", "B", "C"].into_iter()).values() + ); + + let mut expected_keys = MutablePrimitiveArray::::from_slice([0, 1]); + expected_keys.push(None); + expected_keys.push(Some(2)); + expected_keys.push(Some(0)); + expected_keys.push(Some(1)); + assert_eq!(*new.keys(), expected_keys); +} + +#[test] +fn into_empty() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + let values = new.values().clone(); + let empty = new.into_empty(); + assert_eq!(empty.values(), &values); + assert!(empty.is_empty()); +} + +#[test] +fn from_values() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + let mut values = new.values().clone(); + let empty = MutableDictionaryArray::::from_values(values.clone()).unwrap(); + assert_eq!(empty.values(), &values); + assert!(empty.is_empty()); + values.push(Some("A")); + assert!(MutableDictionaryArray::::from_values(values).is_err()); +} + +#[test] +fn try_empty() { + let mut values = MutableUtf8Array::::new(); + MutableDictionaryArray::::try_empty(values.clone()).unwrap(); + values.push(Some("A")); + assert!(MutableDictionaryArray::::try_empty(values.clone()).is_err()); +} + +fn test_push_ex(values: Vec, gen_: impl Fn(usize) -> T) +where + M: MutableArray + Indexable + TryPush> + TryExtend> + Default + 'static, + M::Type: Eq + Hash + Debug, + T: AsIndexed + Default + Clone + Eq + Hash, +{ + for is_extend in [false, true] { + let mut set = PlHashSet::new(); + let mut arr = MutableDictionaryArray::::new(); + macro_rules! push { + ($v:expr) => { + if is_extend { + arr.try_extend(std::iter::once($v)) + } else { + arr.try_push($v) + } + }; + } + arr.push_null(); + push!(None).unwrap(); + assert_eq!(arr.len(), 2); + assert_eq!(arr.values().len(), 0); + for (i, v) in values.iter().cloned().enumerate() { + push!(Some(v.clone())).unwrap(); + let is_dup = !set.insert(v.clone()); + if !is_dup { + assert_eq!(arr.values().value_at(i).borrow(), v.as_indexed()); + assert_eq!(arr.keys().value_at(arr.keys().len() - 1), i as u8); + } + assert_eq!(arr.values().len(), set.len()); + assert_eq!(arr.len(), 3 + i); + } + for i in 0..256 - set.len() { + push!(Some(gen_(i))).unwrap(); + } + assert!(push!(Some(gen_(256))).is_err()); + } +} + +#[test] +fn test_push_utf8_ex() { + test_push_ex::, _>(vec!["a".into(), "b".into(), "a".into()], |i| { + i.to_string() + }) +} + +#[test] +fn test_push_i64_ex() { + test_push_ex::, _>(vec![10, 20, 30, 20], |i| 1000 + i as i64); +} + +#[test] +fn test_big_dict() { + let n = 10; + let strings = (0..10).map(|i| i.to_string()).collect::>(); + let mut arr = MutableDictionaryArray::>::new(); + for s in &strings { + arr.try_push(Some(s)).unwrap(); + } + assert_eq!(arr.values().len(), n); + for _ in 0..10_000 { + for s in &strings { + arr.try_push(Some(s)).unwrap(); + } + } + assert_eq!(arr.values().len(), n); +} diff --git a/crates/polars/tests/it/arrow/array/equal/boolean.rs b/crates/polars/tests/it/arrow/array/equal/boolean.rs new file mode 100644 index 000000000000..e20be510879f --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/boolean.rs @@ -0,0 +1,53 @@ +use arrow::array::*; + +use super::test_equal; + +#[test] +fn test_boolean_equal() { + let a = BooleanArray::from_slice([false, false, true]); + let b = BooleanArray::from_slice([false, false, true]); + test_equal(&a, &b, true); + + let b = BooleanArray::from_slice([false, false, false]); + test_equal(&a, &b, false); +} + +#[test] +fn test_boolean_equal_null() { + let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + test_equal(&a, &b, true); + + let b = BooleanArray::from(vec![None, None, None, Some(true)]); + test_equal(&a, &b, false); + + let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]); + test_equal(&a, &b, false); +} + +#[test] +fn test_boolean_equal_offset() { + let a = BooleanArray::from_slice(vec![false, true, false, true, false, false, true]); + let b = BooleanArray::from_slice(vec![true, false, false, false, true, false, true, true]); + test_equal(&a, &b, false); + + let a_slice = a.sliced(2, 3); + let b_slice = b.sliced(3, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.sliced(3, 4); + let b_slice = b.sliced(4, 4); + test_equal(&a_slice, &b_slice, false); + + // Elements fill in `u8`'s exactly. + let mut vector = vec![false, false, true, true, true, true, true, true]; + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector.clone()); + test_equal(&a, &b, true); + + // Elements fill in `u8`s + suffix bits. + vector.push(true); + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector); + test_equal(&a, &b, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/dictionary.rs b/crates/polars/tests/it/arrow/array/equal/dictionary.rs new file mode 100644 index 000000000000..b429c71b4e69 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/dictionary.rs @@ -0,0 +1,97 @@ +use arrow::array::*; + +use super::test_equal; + +fn create_dictionary_array(values: &[Option<&str>], keys: &[Option]) -> DictionaryArray { + let keys = Int16Array::from(keys); + let values = Utf8Array::::from(values); + + DictionaryArray::try_from_keys(keys, values.boxed()).unwrap() +} + +#[test] +fn dictionary_equal() { + // (a, b, c), (0, 1, 0, 2) => (a, b, a, c) + let a = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), Some(1), Some(0), Some(2)], + ); + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), Some(1)], + ); + test_equal(&a, &b, true); + + // different len + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(1)], + ); + test_equal(&a, &b, false); + + // different key + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), Some(0)], + ); + test_equal(&a, &b, false); + + // different values, same keys + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("d")], + &[Some(0), Some(1), Some(0), Some(2)], + ); + test_equal(&a, &b, false); +} + +#[test] +fn dictionary_equal_null() { + // (a, b, c), (1, 2, 1, 3) => (a, b, a, c) + let a = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), None, Some(0), Some(2)], + ); + + // equal to self + test_equal(&a, &a, true); + + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), None, Some(0), Some(1)], + ); + test_equal(&a, &b, true); + + // different null position + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), None], + ); + test_equal(&a, &b, false); + + // different key + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), None, Some(0), Some(0)], + ); + test_equal(&a, &b, false); + + // different values, same keys + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("d")], + &[Some(0), None, Some(0), Some(2)], + ); + test_equal(&a, &b, false); + + // different nulls in keys and values + let a = create_dictionary_array( + &[Some("a"), Some("b"), None], + &[Some(0), None, Some(0), Some(2)], + ); + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), None, Some(0), None], + ); + test_equal(&a, &b, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs b/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs new file mode 100644 index 000000000000..0c9629d85422 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs @@ -0,0 +1,84 @@ +use arrow::array::{ + FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryExtend, +}; + +use super::test_equal; + +/// Create a fixed size list of 2 value lengths +fn create_fixed_size_list_array, T: AsRef<[Option]>>( + data: T, +) -> FixedSizeListArray { + let data = data.as_ref().iter().map(|x| { + Some(match x { + Some(x) => x.as_ref().iter().map(|x| Some(*x)).collect::>(), + None => std::iter::repeat_n(None, 3).collect::>(), + }) + }); + + let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + list.try_extend(data).unwrap(); + list.into() +} + +#[test] +fn test_fixed_size_list_equal() { + let a = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); +} + +// Test the case where null_count > 0 +#[test] +fn test_fixed_list_null() { + let a = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None]); + /* + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(&a, &b, true); + + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + Some(&[7, 8, 9]), + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(&a, &b, false); + */ + + let b = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None]); + test_equal(&a, &b, false); +} + +#[test] +fn test_fixed_list_offsets() { + // Test the case where offset != 0 + let a = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None]); + let b = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None]); + + let a_slice = a.clone().sliced(0, 3); + let b_slice = b.clone().sliced(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.clone().sliced(0, 5); + let b_slice = b.clone().sliced(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.sliced(4, 1); + let b_slice = b.sliced(4, 1); + test_equal(&a_slice, &b_slice, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/list.rs b/crates/polars/tests/it/arrow/array/equal/list.rs new file mode 100644 index 000000000000..6deec984f7f6 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/list.rs @@ -0,0 +1,90 @@ +use arrow::array::{Int32Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +use super::test_equal; + +fn create_list_array, T: AsRef<[Option]>>(data: T) -> ListArray { + let iter = data.as_ref().iter().map(|x| { + x.as_ref() + .map(|x| x.as_ref().iter().map(|x| Some(*x)).collect::>()) + }); + let mut array = MutableListArray::>::new(); + array.try_extend(iter).unwrap(); + array.into() +} + +#[test] +fn test_list_equal() { + let a = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); +} + +// Test the case where null_count > 0 +#[test] +fn test_list_null() { + let a = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + test_equal(&a, &b, true); + + let b = create_list_array([ + Some(&[1, 2]), + None, + Some(&[5, 6]), + Some(&[3, 4]), + None, + None, + ]); + test_equal(&a, &b, false); + + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + test_equal(&a, &b, false); +} + +// Test the case where offset != 0 +#[test] +fn test_list_offsets() { + let a = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + + let a_slice = a.clone().sliced(0, 3); + let b_slice = b.clone().sliced(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.clone().sliced(0, 5); + let b_slice = b.clone().sliced(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.sliced(4, 1); + let b_slice = b.sliced(4, 1); + test_equal(&a_slice, &b_slice, true); +} + +#[test] +fn test_bla() { + let offsets = vec![0, 3, 3, 6].try_into().unwrap(); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let values = Box::new(Int32Array::from([ + Some(1), + Some(2), + Some(3), + Some(4), + None, + Some(6), + ])); + let validity = Bitmap::from([true, false, true]); + let lhs = ListArray::::new(dtype, offsets, values, Some(validity)); + let lhs = lhs.sliced(1, 2); + + let offsets = vec![0, 0, 3].try_into().unwrap(); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let values = Box::new(Int32Array::from([Some(4), None, Some(6)])); + let validity = Bitmap::from([false, true]); + let rhs = ListArray::::new(dtype, offsets, values, Some(validity)); + + assert_eq!(lhs, rhs); +} diff --git a/crates/polars/tests/it/arrow/array/equal/mod.rs b/crates/polars/tests/it/arrow/array/equal/mod.rs new file mode 100644 index 000000000000..87f7ffeff251 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/mod.rs @@ -0,0 +1,50 @@ +use arrow::array::*; + +mod dictionary; +mod fixed_size_list; +mod list; +mod primitive; +mod utf8; + +pub fn test_equal(lhs: &dyn Array, rhs: &dyn Array, expected: bool) { + // equality is symmetric + assert!(equal(lhs, lhs), "\n{lhs:?}\n{lhs:?}"); + assert!(equal(rhs, rhs), "\n{rhs:?}\n{rhs:?}"); + + assert_eq!(equal(lhs, rhs), expected, "\n{lhs:?}\n{rhs:?}"); + assert_eq!(equal(rhs, lhs), expected, "\n{rhs:?}\n{lhs:?}"); +} + +#[allow(clippy::type_complexity)] +fn binary_cases() -> Vec<(Vec>, Vec>, bool)> { + let base = vec![ + Some("hello".to_owned()), + None, + None, + Some("world".to_owned()), + None, + None, + ]; + let not_base = vec![ + Some("hello".to_owned()), + Some("foo".to_owned()), + None, + Some("world".to_owned()), + None, + None, + ]; + vec![ + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("world".to_owned())], + true, + ), + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("arrow".to_owned())], + false, + ), + (base.clone(), base.clone(), true), + (base, not_base, false), + ] +} diff --git a/crates/polars/tests/it/arrow/array/equal/primitive.rs b/crates/polars/tests/it/arrow/array/equal/primitive.rs new file mode 100644 index 000000000000..e50711eb9728 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/primitive.rs @@ -0,0 +1,90 @@ +use arrow::array::*; + +use super::test_equal; + +#[test] +fn test_primitive() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(3)], + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(4)], + false, + ), + ( + vec![Some(1), Some(2), None], + vec![Some(1), Some(2), None], + true, + ), + ( + vec![Some(1), None, Some(3)], + vec![Some(1), Some(2), None], + false, + ), + ( + vec![Some(1), None, None], + vec![Some(1), Some(2), None], + false, + ), + ]; + + for (lhs, rhs, expected) in cases { + let lhs = Int32Array::from(&lhs); + let rhs = Int32Array::from(&rhs); + test_equal(&lhs, &rhs, expected); + } +} + +#[test] +fn test_primitive_slice() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + (0, 1), + vec![Some(1), Some(2), Some(3)], + (0, 1), + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + (1, 1), + vec![Some(1), Some(2), Some(3)], + (2, 1), + false, + ), + ( + vec![Some(1), Some(2), None], + (1, 1), + vec![Some(1), None, Some(2)], + (2, 1), + true, + ), + ( + vec![None, Some(2), None], + (1, 1), + vec![None, None, Some(2)], + (2, 1), + true, + ), + ( + vec![Some(1), None, Some(2), None, Some(3)], + (2, 2), + vec![None, Some(2), None, Some(3)], + (1, 2), + true, + ), + ]; + + for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { + let lhs = Int32Array::from(&lhs); + let lhs = lhs.sliced(slice_lhs.0, slice_lhs.1); + let rhs = Int32Array::from(&rhs); + let rhs = rhs.sliced(slice_rhs.0, slice_rhs.1); + + test_equal(&lhs, &rhs, expected); + } +} diff --git a/crates/polars/tests/it/arrow/array/equal/utf8.rs b/crates/polars/tests/it/arrow/array/equal/utf8.rs new file mode 100644 index 000000000000..a9f9e6cff069 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/utf8.rs @@ -0,0 +1,26 @@ +use arrow::array::*; +use arrow::offset::Offset; + +use super::{binary_cases, test_equal}; + +fn test_generic_string_equal() { + let cases = binary_cases(); + + for (lhs, rhs, expected) in cases { + let lhs = lhs.iter().map(|x| x.as_deref()); + let rhs = rhs.iter().map(|x| x.as_deref()); + let lhs = Utf8Array::::from_trusted_len_iter(lhs); + let rhs = Utf8Array::::from_trusted_len_iter(rhs); + test_equal(&lhs, &rhs, expected); + } +} + +#[test] +fn utf8_equal() { + test_generic_string_equal::() +} + +#[test] +fn large_utf8_equal() { + test_generic_string_equal::() +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..f1ad6334912e --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs @@ -0,0 +1,103 @@ +use arrow::array::FixedSizeBinaryArray; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowDataType, ExtensionType}; + +mod mutable; + +#[test] +fn basics() { + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + Buffer::from(vec![1, 2, 3, 4, 5, 6]), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(array.size(), 2); + assert_eq!(array.len(), 3); + assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); + + assert_eq!(array.value(0), [1, 2]); + assert_eq!(array.value(2), [5, 6]); + + let array = array.sliced(1, 2); + + assert_eq!(array.value(1), [5, 6]); +} + +#[test] +fn with_validity() { + let a = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + vec![1, 2, 3, 4, 5, 6].into(), + None, + ); + let a = a.with_validity(Some(Bitmap::from([true, false, true]))); + assert!(a.validity().is_some()); +} + +#[test] +fn debug() { + let a = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + vec![1, 2, 3, 4, 5, 6].into(), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(format!("{a:?}"), "FixedSizeBinary(2)[[1, 2], None, [5, 6]]"); +} + +#[test] +fn empty() { + let array = FixedSizeBinaryArray::new_empty(ArrowDataType::FixedSizeBinary(2)); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn null() { + let array = FixedSizeBinaryArray::new_null(ArrowDataType::FixedSizeBinary(2), 2); + assert_eq!(array.values().len(), 4); + assert_eq!(array.validity().cloned(), Some([false, false].into())); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat_n(vec![1u8, 2], 2).map(Some); + let a = FixedSizeBinaryArray::from_iter(iter, 2); + assert_eq!(a.len(), 2); +} + +#[test] +fn wrong_size() { + let values = Buffer::from(b"abb".to_vec()); + assert!( + FixedSizeBinaryArray::try_new(ArrowDataType::FixedSizeBinary(2), values, None).is_err() + ); +} + +#[test] +fn wrong_len() { + let values = Buffer::from(b"abba".to_vec()); + let validity = Some([true, false, false].into()); // it should be 2 + assert!( + FixedSizeBinaryArray::try_new(ArrowDataType::FixedSizeBinary(2), values, validity).is_err() + ); +} + +#[test] +fn wrong_dtype() { + let values = Buffer::from(b"abba".to_vec()); + assert!(FixedSizeBinaryArray::try_new(ArrowDataType::Binary, values, None).is_err()); +} + +#[test] +fn to() { + let values = Buffer::from(b"abba".to_vec()); + let a = FixedSizeBinaryArray::new(ArrowDataType::FixedSizeBinary(2), values, None); + + let extension = ArrowDataType::Extension(Box::new(ExtensionType { + name: "a".into(), + inner: ArrowDataType::FixedSizeBinary(2), + metadata: None, + })); + let _ = a.to(extension); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs new file mode 100644 index 000000000000..ad89efd256b1 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs @@ -0,0 +1,173 @@ +use arrow::array::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn basic() { + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + None, + ) + .unwrap(); + assert_eq!(a.len(), 2); + assert_eq!(a.dtype(), &ArrowDataType::FixedSizeBinary(2)); + assert_eq!(a.values(), &Vec::from([1, 2, 3, 4])); + assert_eq!(a.validity(), None); + assert_eq!(a.value(1), &[3, 4]); + assert_eq!(unsafe { a.value_unchecked(1) }, &[3, 4]); +} + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + None, + ) + .unwrap(); + assert_eq!(a, a); + let b = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2]), + None, + ) + .unwrap(); + assert_eq!(b, b); + assert!(a != b); + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let b = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + Some(MutableBitmap::from([false, true])), + ) + .unwrap(); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); +} + +#[test] +fn try_from_iter() { + let array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + assert_eq!(array.len(), 4); +} + +#[test] +fn push_null() { + let mut array = MutableFixedSizeBinaryArray::new(2); + array.push::<&[u8]>(None); + + let array: FixedSizeBinaryArray = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableFixedSizeBinaryArray::new(2); + a.push(Some(b"aa")); + a.push::<&[u8]>(None); + a.push(Some(b"bb")); + a.push::<&[u8]>(None); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), Some(b"bb".to_vec())); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(b"aa".to_vec())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); + assert!(a.is_empty()); +} + +#[test] +fn pop_all_some() { + let mut a = MutableFixedSizeBinaryArray::new(2); + a.push(Some(b"aa")); + a.push(Some(b"bb")); + a.push(Some(b"cc")); + a.push(Some(b"dd")); + + for _ in 0..4 { + a.push(Some(b"11")); + } + + a.push(Some(b"22")); + + assert_eq!(a.pop(), Some(b"22".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.len(), 5); + + assert_eq!( + a, + MutableFixedSizeBinaryArray::try_from_iter( + vec![ + Some(b"aa"), + Some(b"bb"), + Some(b"cc"), + Some(b"dd"), + Some(b"11"), + ], + 2, + ) + .unwrap() + ); +} + +#[test] +fn as_arc() { + let mut array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + + let array = array.as_arc(); + assert_eq!(array.len(), 4); +} + +#[test] +fn as_box() { + let mut array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + + let array = array.as_box(); + assert_eq!(array.len(), 4); +} + +#[test] +fn shrink_to_fit_and_capacity() { + let mut array = MutableFixedSizeBinaryArray::with_capacity(2, 100); + array.push(Some([1, 2])); + array.shrink_to_fit(); + assert_eq!(array.capacity(), 1); +} + +#[test] +fn extend_from_self() { + let mut a = MutableFixedSizeBinaryArray::from([Some([1u8, 2u8]), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableFixedSizeBinaryArray::from([Some([1u8, 2u8]), None, Some([1u8, 2u8]), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs new file mode 100644 index 000000000000..6524f8769d54 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs @@ -0,0 +1,129 @@ +mod mutable; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field}; + +fn data() -> FixedSizeListArray { + let values = Int32Array::from_slice([10, 20, 0, 0]); + + FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), values.dtype().clone(), true)), + 2, + ), + 2, + values.boxed(), + Some([true, false].into()), + ) + .unwrap() +} + +#[test] +fn basics() { + let array = data(); + assert_eq!(array.size(), 2); + assert_eq!(array.len(), 2); + assert_eq!(array.validity(), Some(&Bitmap::from([true, false]))); + + assert_eq!(array.value(0).as_ref(), Int32Array::from_slice([10, 20])); + assert_eq!(array.value(1).as_ref(), Int32Array::from_slice([0, 0])); + + let array = array.sliced(1, 1); + + assert_eq!(array.value(0).as_ref(), Int32Array::from_slice([0, 0])); +} + +#[test] +fn split_at() { + let (lhs, rhs) = data().split_at(1); + + assert_eq!(lhs.value(0).as_ref(), Int32Array::from_slice([10, 20])); + assert_eq!(rhs.value(0).as_ref(), Int32Array::from_slice([0, 0])); +} + +#[test] +fn with_validity() { + let array = data(); + + let a = array.with_validity(None); + assert!(a.validity().is_none()); +} + +#[test] +fn debug() { + let array = data(); + + assert_eq!(format!("{array:?}"), "FixedSizeListArray[[10, 20], None]"); +} + +#[test] +fn empty() { + let array = FixedSizeListArray::new_empty(ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2, + )); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn null() { + let array = FixedSizeListArray::new_null( + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2, + ), + 2, + ); + assert_eq!(array.values().len(), 4); + assert_eq!(array.validity().cloned(), Some([false, false].into())); +} + +#[test] +fn wrong_size() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!( + FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2 + ), + 2, + values.boxed(), + None + ) + .is_err() + ); +} + +#[test] +fn wrong_len() { + let values = Int32Array::from_slice([10, 20, 0, 0]); + assert!( + FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2 + ), + 2, + values.boxed(), + Some([true, false, false].into()), // it should be 2 + ) + .is_err() + ); +} + +#[test] +fn wrong_dtype() { + let values = Int32Array::from_slice([10, 20, 0, 0]); + assert!( + FixedSizeListArray::try_new( + ArrowDataType::Binary, + 2, + values.boxed(), + Some([true, false, false, false].into()), + ) + .is_err() + ); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs new file mode 100644 index 000000000000..dc42b85c7e87 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs @@ -0,0 +1,92 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn primitive() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![None, None, None]), + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + list.try_extend(data).unwrap(); + let list: FixedSizeListArray = list.into(); + + let a = list.value(0); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![Some(1i32), Some(2), Some(3)]); + assert_eq!(a, &expected); + + let a = list.value(1); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![None, None, None]); + assert_eq!(a, &expected) +} + +#[test] +fn new_with_field() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![None, None, None]), + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut list = MutableFixedSizeListArray::new_with_field( + MutablePrimitiveArray::::new(), + "custom_items".into(), + false, + 3, + ); + list.try_extend(data).unwrap(); + let list: FixedSizeListArray = list.into(); + + assert_eq!( + list.dtype(), + &ArrowDataType::FixedSizeList( + Box::new(Field::new( + "custom_items".into(), + ArrowDataType::Int32, + false + )), + 3 + ) + ); + + let a = list.value(0); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![Some(1i32), Some(2), Some(3)]); + assert_eq!(a, &expected); + + let a = list.value(1); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![None, None, None]); + assert_eq!(a, &expected) +} + +#[test] +fn extend_from_self() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let mut a = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + a.try_extend(data.clone()).unwrap(); + + a.try_extend_from_self(&a.clone()).unwrap(); + let a: FixedSizeListArray = a.into(); + + let mut expected = data.clone(); + expected.extend(data); + + let mut b = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + b.try_extend(expected).unwrap(); + let b: FixedSizeListArray = b.into(); + + assert_eq!(a, b); +} diff --git a/crates/polars/tests/it/arrow/array/list/mod.rs b/crates/polars/tests/it/arrow/array/list/mod.rs new file mode 100644 index 000000000000..37ab5d0e7e91 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/list/mod.rs @@ -0,0 +1,89 @@ +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +mod mutable; + +#[test] +fn debug() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + dtype, + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + assert_eq!(format!("{array:?}"), "ListArray[[1, 2], [], [3], [4, 5]]"); +} + +#[test] +fn split_at() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + dtype, + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + let (lhs, rhs) = array.split_at(2); + + assert_eq!(format!("{lhs:?}"), "ListArray[[1, 2], []]"); + assert_eq!(format!("{rhs:?}"), "ListArray[[3], [4, 5]]"); +} + +#[test] +#[should_panic] +fn test_nested_panic() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + dtype.clone(), + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + // The datatype for the nested array has to be created considering + // the nested structure of the child data + let _ = ListArray::::new( + dtype, + vec![0, 2, 4].try_into().unwrap(), + Box::new(array), + None, + ); +} + +#[test] +fn test_nested_display() { + let values = Buffer::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + dtype, + vec![0, 2, 4, 7, 7, 8, 10].try_into().unwrap(), + Box::new(values), + None, + ); + + let dtype = ListArray::::default_datatype(array.dtype().clone()); + let nested = ListArray::::new( + dtype, + vec![0, 2, 5, 6].try_into().unwrap(), + Box::new(array), + None, + ); + + let expected = "ListArray[[[1, 2], [3, 4]], [[5, 6, 7], [], [8]], [[9, 10]]]"; + assert_eq!(format!("{nested:?}"), expected); +} diff --git a/crates/polars/tests/it/arrow/array/list/mutable.rs b/crates/polars/tests/it/arrow/array/list/mutable.rs new file mode 100644 index 000000000000..6b4c60e3b459 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/list/mutable.rs @@ -0,0 +1,76 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +#[test] +fn basics() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + let array: ListArray = array.into(); + + let values = PrimitiveArray::::new( + ArrowDataType::Int32, + Buffer::from(vec![1, 2, 3, 4, 0, 6]), + Some(Bitmap::from([true, true, true, true, false, true])), + ); + + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); + let expected = ListArray::::new( + dtype, + vec![0, 3, 3, 6].try_into().unwrap(), + Box::new(values), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(expected, array); +} + +#[test] +fn with_capacity() { + let array = MutableListArray::>::with_capacity(10); + assert!(array.offsets().capacity() >= 10); + assert_eq!(array.offsets().len_proxy(), 0); + assert_eq!(array.values().values().capacity(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn push() { + let mut array = MutableListArray::>::new(); + array + .try_push(Some(vec![Some(1i32), Some(2), Some(3)])) + .unwrap(); + assert_eq!(array.len(), 1); + assert_eq!(array.values().values().as_ref(), [1, 2, 3]); + assert_eq!(array.offsets().as_slice(), [0, 3]); + assert_eq!(array.validity(), None); +} + +#[test] +fn extend_from_self() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let mut a = MutableListArray::>::new(); + a.try_extend(data.clone()).unwrap(); + + a.try_extend_from_self(&a.clone()).unwrap(); + let a: ListArray = a.into(); + + let mut expected = data.clone(); + expected.extend(data); + + let mut b = MutableListArray::>::new(); + b.try_extend(expected).unwrap(); + let b: ListArray = b.into(); + + assert_eq!(a, b); +} diff --git a/crates/polars/tests/it/arrow/array/map/mod.rs b/crates/polars/tests/it/arrow/array/map/mod.rs new file mode 100644 index 000000000000..44702f118a89 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/map/mod.rs @@ -0,0 +1,104 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +fn dt() -> ArrowDataType { + ArrowDataType::Struct(vec![ + Field::new("a".into(), ArrowDataType::Utf8, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]) +} + +fn array() -> MapArray { + let dtype = ArrowDataType::Map(Box::new(Field::new("a".into(), dt(), true)), false); + + let field = StructArray::new( + dt(), + 3, + vec![ + Box::new(Utf8Array::::from_slice(["a", "aa", "aaa"])) as _, + Box::new(Utf8Array::::from_slice(["b", "bb", "bbb"])), + ], + None, + ); + + MapArray::new( + dtype, + vec![0, 1, 2, 3].try_into().unwrap(), + Box::new(field), + None, + ) +} + +#[test] +fn basics() { + let array = array(); + + assert_eq!( + array.value(0), + Box::new(StructArray::new( + dt(), + 1, + vec![ + Box::new(Utf8Array::::from_slice(["a"])) as _, + Box::new(Utf8Array::::from_slice(["b"])), + ], + None, + )) as Box + ); + + let sliced = array.sliced(1, 1); + assert_eq!( + sliced.value(0), + Box::new(StructArray::new( + dt(), + 1, + vec![ + Box::new(Utf8Array::::from_slice(["aa"])) as _, + Box::new(Utf8Array::::from_slice(["bb"])), + ], + None, + )) as Box + ); +} + +#[test] +fn split_at() { + let (lhs, rhs) = array().split_at(1); + + assert_eq!( + lhs.value(0), + Box::new(StructArray::new( + dt(), + 1, + vec![ + Box::new(Utf8Array::::from_slice(["a"])) as _, + Box::new(Utf8Array::::from_slice(["b"])), + ], + None, + )) as Box + ); + assert_eq!( + rhs.value(0), + Box::new(StructArray::new( + dt(), + 1, + vec![ + Box::new(Utf8Array::::from_slice(["aa"])) as _, + Box::new(Utf8Array::::from_slice(["bb"])), + ], + None, + )) as Box + ); + assert_eq!( + rhs.value(1), + Box::new(StructArray::new( + dt(), + 1, + vec![ + Box::new(Utf8Array::::from_slice(["aaa"])) as _, + Box::new(Utf8Array::::from_slice(["bbb"])), + ], + None, + )) as Box + ); +} diff --git a/crates/polars/tests/it/arrow/array/mod.rs b/crates/polars/tests/it/arrow/array/mod.rs new file mode 100644 index 000000000000..786253cb96ae --- /dev/null +++ b/crates/polars/tests/it/arrow/array/mod.rs @@ -0,0 +1,169 @@ +mod binary; +mod binview; +mod boolean; +mod dictionary; +mod equal; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod primitive; +mod struct_; +mod union; +mod utf8; + +use arrow::array::{Array, PrimitiveArray, clone, new_empty_array, new_null_array}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, ExtensionType, Field, UnionMode}; +use union::union_type; + +#[test] +fn nulls() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), + ]; + let a = datatypes + .into_iter() + .all(|x| new_null_array(x, 10).null_count() == 10); + assert!(a); + + // unions' null count is always 0 + let datatypes = vec![ + union_type( + vec![Field::new("a".into(), ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + union_type( + vec![Field::new("a".into(), ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ]; + let a = datatypes + .into_iter() + .all(|x| new_null_array(x, 10).null_count() == 0); + assert!(a); +} + +#[test] +fn empty() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Extension(Box::new(ExtensionType { + name: "ext".into(), + inner: ArrowDataType::Int32, + metadata: None, + })), + true, + ))), + union_type( + vec![Field::new("a".into(), ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + union_type( + vec![Field::new("a".into(), ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Int32, true)]), + ]; + let a = datatypes.into_iter().all(|x| new_empty_array(x).is_empty()); + assert!(a); +} + +#[test] +fn empty_extension() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), + union_type( + vec![Field::new("a".into(), ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + union_type( + vec![Field::new("a".into(), ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Int32, true)]), + ]; + let a = datatypes + .into_iter() + .map(|dt| { + ArrowDataType::Extension(Box::new(ExtensionType { + name: "ext".into(), + inner: dt, + metadata: None, + })) + }) + .all(|x| { + let a = new_empty_array(x); + a.is_empty() && matches!(a.dtype(), ArrowDataType::Extension(_)) + }); + assert!(a); +} + +#[test] +fn test_clone() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), + ]; + let a = datatypes + .into_iter() + .all(|x| clone(new_null_array(x.clone(), 10).as_ref()) == new_null_array(x, 10)); + assert!(a); +} + +#[test] +fn test_with_validity() { + let arr = PrimitiveArray::from_slice([1i32, 2, 3]); + let validity = Bitmap::from(&[true, false, true]); + let arr = arr.with_validity(Some(validity)); + let arr_ref = arr.as_any().downcast_ref::>().unwrap(); + + let expected = PrimitiveArray::from(&[Some(1i32), None, Some(3)]); + assert_eq!(arr_ref, &expected); +} + +// check that we ca derive stuff +#[allow(dead_code)] +#[derive(PartialEq, Clone, Debug)] +struct A { + array: Box, +} diff --git a/crates/polars/tests/it/arrow/array/primitive/fmt.rs b/crates/polars/tests/it/arrow/array/primitive/fmt.rs new file mode 100644 index 000000000000..eb6c067b9ec2 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/fmt.rs @@ -0,0 +1,224 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::types::{days_ms, months_days_ns}; + +#[test] +fn debug_int32() { + let array = Int32Array::from(&[Some(1), None, Some(2)]); + assert_eq!(format!("{array:?}"), "Int32[1, None, 2]"); +} + +#[test] +fn debug_date32() { + let array = Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Date32); + assert_eq!(format!("{array:?}"), "Date32[1970-01-02, None, 1970-01-03]"); +} + +#[test] +fn debug_time32s() { + let array = + Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Time32(TimeUnit::Second)); + assert_eq!( + format!("{array:?}"), + "Time32(Second)[00:00:01, None, 00:00:02]" + ); +} + +#[test] +fn debug_time32ms() { + let array = Int32Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Time32(TimeUnit::Millisecond)); + assert_eq!( + format!("{array:?}"), + "Time32(Millisecond)[00:00:00.001, None, 00:00:00.002]" + ); +} + +#[test] +fn debug_interval_d() { + let array = Int32Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Interval(IntervalUnit::YearMonth)); + assert_eq!(format!("{array:?}"), "Interval(YearMonth)[1m, None, 2m]"); +} + +#[test] +fn debug_int64() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Int64); + assert_eq!(format!("{array:?}"), "Int64[1, None, 2]"); +} + +#[test] +fn debug_date64() { + let array = Int64Array::from(&[Some(1), None, Some(86400000)]).to(ArrowDataType::Date64); + assert_eq!(format!("{array:?}"), "Date64[1970-01-01, None, 1970-01-02]"); +} + +#[test] +fn debug_time64us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Time64(TimeUnit::Microsecond)); + assert_eq!( + format!("{array:?}"), + "Time64(Microsecond)[00:00:00.000001, None, 00:00:00.000002]" + ); +} + +#[test] +fn debug_time64ns() { + let array = + Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Time64(TimeUnit::Nanosecond)); + assert_eq!( + format!("{array:?}"), + "Time64(Nanosecond)[00:00:00.000000001, None, 00:00:00.000000002]" + ); +} + +#[test] +fn debug_timestamp_s() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Second, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Second, None)[1970-01-01 00:00:01, None, 1970-01-01 00:00:02]" + ); +} + +#[test] +fn debug_timestamp_ms() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Millisecond, None)[1970-01-01 00:00:00.001, None, 1970-01-01 00:00:00.002]" + ); +} + +#[test] +fn debug_timestamp_us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Microsecond, None)[1970-01-01 00:00:00.000001, None, 1970-01-01 00:00:00.000002]" + ); +} + +#[test] +fn debug_timestamp_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, None)[1970-01-01 00:00:00.000000001, None, 1970-01-01 00:00:00.000000002]" + ); +} + +#[test] +fn debug_timestamp_tz_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("+02:00".into()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"+02:00\"))[1970-01-01 02:00:00.000000001 +02:00, None, 1970-01-01 02:00:00.000000002 +02:00]" + ); +} + +#[test] +fn debug_timestamp_tz_not_parsable() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("aa".into()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"aa\"))[1 (aa), None, 2 (aa)]" + ); +} + +#[cfg(feature = "timezones")] +#[test] +fn debug_timestamp_tz1_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("Europe/Lisbon".into()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"Europe/Lisbon\"))[1970-01-01 01:00:00.000000001 CET, None, 1970-01-01 01:00:00.000000002 CET]" + ); +} + +#[test] +fn debug_duration_ms() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Millisecond)); + assert_eq!( + format!("{array:?}"), + "Duration(Millisecond)[1ms, None, 2ms]" + ); +} + +#[test] +fn debug_duration_s() { + let array = + Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Duration(TimeUnit::Second)); + assert_eq!(format!("{array:?}"), "Duration(Second)[1s, None, 2s]"); +} + +#[test] +fn debug_duration_us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Microsecond)); + assert_eq!( + format!("{array:?}"), + "Duration(Microsecond)[1us, None, 2us]" + ); +} + +#[test] +fn debug_duration_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(format!("{array:?}"), "Duration(Nanosecond)[1ns, None, 2ns]"); +} + +#[test] +fn debug_decimal() { + let array = + Int128Array::from(&[Some(12345), None, Some(23456)]).to(ArrowDataType::Decimal(5, 2)); + assert_eq!(format!("{array:?}"), "Decimal(5, 2)[123.45, None, 234.56]"); +} + +#[test] +fn debug_decimal1() { + let array = + Int128Array::from(&[Some(12345), None, Some(23456)]).to(ArrowDataType::Decimal(5, 1)); + assert_eq!(format!("{array:?}"), "Decimal(5, 1)[1234.5, None, 2345.6]"); +} + +#[test] +fn debug_interval_days_ms() { + let array = DaysMsArray::from(&[Some(days_ms::new(1, 1)), None, Some(days_ms::new(2, 2))]); + assert_eq!( + format!("{array:?}"), + "Interval(DayTime)[1d1ms, None, 2d2ms]" + ); +} + +#[test] +fn debug_months_days_ns() { + let data = &[ + Some(months_days_ns::new(1, 1, 2)), + None, + Some(months_days_ns::new(2, 3, 3)), + ]; + + let array = MonthsDaysNsArray::from(&data); + + assert_eq!( + format!("{array:?}"), + "Interval(MonthDayNano)[1m1d2ns, None, 2m3d3ns]" + ); +} diff --git a/crates/polars/tests/it/arrow/array/primitive/mod.rs b/crates/polars/tests/it/arrow/array/primitive/mod.rs new file mode 100644 index 000000000000..d0630c854103 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/mod.rs @@ -0,0 +1,155 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::*; +use arrow::types::months_days_ns; + +mod fmt; +mod mutable; +mod to_mutable; + +fn array() -> Int32Array { + vec![Some(1), None, Some(10)].into_iter().collect() +} + +#[test] +fn basics() { + let array = array(); + + assert_eq!(array.value(0), 1); + assert_eq!(array.value(1), 0); + assert_eq!(array.value(2), 10); + assert_eq!(array.values().as_slice(), &[1, 0, 10]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = Int32Array::new( + ArrowDataType::Int32, + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), 0); + assert_eq!(array.value(1), 10); + assert_eq!(array.values().as_slice(), &[0, 10]); + + unsafe { + assert_eq!(array.value_unchecked(0), 0); + assert_eq!(array.value_unchecked(1), 10); + } +} + +#[test] +fn split_at() { + let (lhs, rhs) = array().split_at(1); + + assert!(lhs.is_valid(0)); + assert!(!rhs.is_valid(0)); + assert!(rhs.is_valid(1)); + + assert_eq!(lhs.value(0), 1); + assert_eq!(rhs.value(0), 0); + assert_eq!(rhs.value(1), 10); +} + +#[test] +fn empty() { + let array = Int32Array::new_empty(ArrowDataType::Int32); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let data = vec![Some(1), None, Some(10)]; + + let array = PrimitiveArray::from(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_iter(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_trusted_len_iter(data.into_iter()); + assert_eq!(array.len(), 3); + + let data = vec![1i32, 2, 3]; + + let array = PrimitiveArray::from_values(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_trusted_len_values_iter(data.into_iter()); + assert_eq!(array.len(), 3); +} + +#[test] +fn months_days_ns_from_slice() { + let data = &[ + months_days_ns::new(1, 1, 2), + months_days_ns::new(1, 1, 3), + months_days_ns::new(2, 3, 3), + ]; + + let array = MonthsDaysNsArray::from_slice(data); + + let a = array.values().as_slice(); + assert_eq!(a, data.as_ref()); +} + +#[test] +fn wrong_dtype() { + let values = Buffer::from(b"abbb".to_vec()); + assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, None).is_err()); +} + +#[test] +fn wrong_len() { + let values = Buffer::from(b"abbb".to_vec()); + let validity = Some([true, false].into()); + assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, validity).is_err()); +} + +#[test] +fn into_mut_1() { + let values = Buffer::::from(vec![0, 1]); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let values = Buffer::::from(vec![0, 1]); + let validity = Some([true, false].into()); + let a = validity.clone(); // cloned values + assert_eq!(a, validity); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let values = Buffer::::from(vec![0, 1]); + let validity = Some([true, false].into()); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn into_iter() { + let data = vec![Some(1), None, Some(10)]; + let rev = data.clone().into_iter().rev(); + + let array: Int32Array = data.clone().into_iter().collect(); + + assert_eq!(array.clone().into_iter().collect::>(), data); + + assert!(array.into_iter().rev().eq(rev)) +} diff --git a/crates/polars/tests/it/arrow/array/primitive/mutable.rs b/crates/polars/tests/it/arrow/array/primitive/mutable.rs new file mode 100644 index 000000000000..e5f3ef81d77b --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/mutable.rs @@ -0,0 +1,328 @@ +use arrow::array::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +#[test] +fn from_and_into_data() { + let a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + assert_eq!(a.len(), 2); + let (a, b, c) = a.into_inner(); + assert_eq!(a, ArrowDataType::Int32); + assert_eq!(b, Vec::from([1i32, 0])); + assert_eq!(c, Some(MutableBitmap::from([true, false]))); +} + +#[test] +fn from_vec() { + let a = MutablePrimitiveArray::from_vec(Vec::from([1i32, 0])); + assert_eq!(a.len(), 2); +} + +#[test] +fn to() { + let a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.dtype(), &ArrowDataType::Date32); +} + +#[test] +fn values_mut_slice() { + let mut a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let values = a.values_mut_slice(); + + values[0] = 10; + assert_eq!(a.values()[0], 10); +} + +#[test] +fn push() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.push(None); + a.push_null(); + assert_eq!(a.len(), 3); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + assert!(!a.is_valid(2)); + + assert_eq!(a.values(), &Vec::from([1, 0, 0])); +} + +#[test] +fn pop() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.push(None); + a.push(Some(2)); + a.push_null(); + assert_eq!(a.pop(), None); + assert_eq!(a.pop(), Some(2)); + assert_eq!(a.pop(), None); + assert!(a.is_valid(0)); + assert_eq!(a.values(), &Vec::from([1])); + assert_eq!(a.pop(), Some(1)); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutablePrimitiveArray::::new(); + for v in 0..8 { + a.push(Some(v)); + } + + a.push(Some(8)); + assert_eq!(a.pop(), Some(8)); + assert_eq!(a.pop(), Some(7)); + assert_eq!(a.pop(), Some(6)); + assert_eq!(a.pop(), Some(5)); + assert_eq!(a.pop(), Some(4)); + assert_eq!(a.len(), 4); + assert!(a.is_valid(0)); + assert!(a.is_valid(1)); + assert!(a.is_valid(2)); + assert!(a.is_valid(3)); + assert_eq!(a.values(), &Vec::from([0, 1, 2, 3])); +} + +#[test] +fn set() { + let mut a = MutablePrimitiveArray::::from([Some(1), None]); + + a.set(0, Some(2)); + a.set(1, Some(1)); + + assert_eq!(a.len(), 2); + assert!(a.is_valid(0)); + assert!(a.is_valid(1)); + + assert_eq!(a.values(), &Vec::from([2, 1])); + + let mut a = MutablePrimitiveArray::::from_slice([1, 2]); + + a.set(0, Some(2)); + a.set(1, None); + + assert_eq!(a.len(), 2); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + + assert_eq!(a.values(), &Vec::from([2, 0])); +} + +#[test] +fn from_iter() { + let a = MutablePrimitiveArray::::from_iter((0..2).map(Some)); + assert_eq!(a.len(), 2); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); +} + +#[test] +fn natural_arc() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).into_arc(); + assert_eq!(a.len(), 2); +} + +#[test] +fn as_arc() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).as_arc(); + assert_eq!(a.len(), 2); +} + +#[test] +fn as_box() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).as_box(); + assert_eq!(a.len(), 2); +} + +#[test] +fn shrink_to_fit_and_capacity() { + let mut a = MutablePrimitiveArray::::with_capacity(100); + a.push(Some(1)); + a.try_push(None).unwrap(); + assert!(a.capacity() >= 100); + (&mut a as &mut dyn MutableArray).shrink_to_fit(); + assert_eq!(a.capacity(), 2); +} + +#[test] +fn only_nulls() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.push(None); + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([false, false]))); +} + +#[test] +fn from_trusted_len() { + let a = + MutablePrimitiveArray::::from_trusted_len_iter(vec![Some(1), None, None].into_iter()); + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([true, false, false]))); + + let a = unsafe { + MutablePrimitiveArray::::from_trusted_len_iter_unchecked( + vec![Some(1), None].into_iter(), + ) + }; + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([true, false]))); +} + +#[test] +fn extend_trusted_len() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len(vec![Some(1), Some(2)].into_iter()); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); + a.extend_trusted_len(vec![None, Some(4)].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, true, false, true])) + ); + assert_eq!(a.values(), &Vec::::from([1, 2, 0, 4])); +} + +#[test] +fn extend_constant_no_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.extend_constant(2, Some(3)); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 3, 3])); +} + +#[test] +fn extend_constant_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.extend_constant(2, None); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, false, false])) + ); + assert_eq!(a.values(), &Vec::::from([1, 0, 0])); +} + +#[test] +fn extend_constant_validity_inverse() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_constant(2, Some(1)); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); + assert_eq!(a.values(), &Vec::::from([0, 1, 1])); +} + +#[test] +fn extend_constant_validity_none() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_constant(2, None); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, false, false])) + ); + assert_eq!(a.values(), &Vec::::from([0, 0, 0])); +} + +#[test] +fn extend_trusted_len_values() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len_values(vec![1, 2, 3].into_iter()); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 2, 3])); + + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_trusted_len_values(vec![1, 2].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); +} + +#[test] +fn extend_from_slice() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_from_slice(&[1, 2, 3]); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 2, 3])); + + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_from_slice(&[1, 2]); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); +} + +#[test] +fn set_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len(vec![Some(1), Some(2)].into_iter()); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); + + // test that upon conversion to array the bitmap is set to None + let arr: PrimitiveArray<_> = a.clone().into(); + assert_eq!(arr.validity(), None); + + // test set_validity + a.set_validity(Some(MutableBitmap::from([false, true]))); + assert_eq!(a.validity(), Some(&MutableBitmap::from([false, true]))); +} + +#[test] +fn set_values() { + let mut a = MutablePrimitiveArray::::from_slice([1, 2]); + a.set_values(Vec::from([1, 3])); + assert_eq!(a.values().as_slice(), [1, 3]); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat_n(Some(1), 2).map(PolarsResult::Ok); + let a = MutablePrimitiveArray::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a, MutablePrimitiveArray::from([Some(1), Some(1)])); +} + +#[test] +fn wrong_dtype() { + assert!(MutablePrimitiveArray::::try_new(ArrowDataType::Utf8, vec![], None).is_err()); +} + +#[test] +fn extend_from_self() { + let mut a = MutablePrimitiveArray::from([Some(1), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutablePrimitiveArray::from([Some(1), None, Some(1), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs b/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs new file mode 100644 index 000000000000..0cc32155a318 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs @@ -0,0 +1,53 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use either::Either; + +#[test] +fn array_to_mutable() { + let data = vec![1, 2, 3]; + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), None); + + // to mutable push and freeze again + let mut mut_arr = arr.into_mut().unwrap_right(); + mut_arr.push(Some(5)); + let immut: PrimitiveArray = mut_arr.into(); + assert_eq!(immut.values().as_slice(), [1, 2, 3, 5]); + + // let's cause a realloc and see if miri is ok + let mut mut_arr = immut.into_mut().unwrap_right(); + mut_arr.extend_constant(256, Some(9)); + let immut: PrimitiveArray = mut_arr.into(); + assert_eq!(immut.values().len(), 256 + 4); +} + +#[test] +fn array_to_mutable_not_owned() { + let data = vec![1, 2, 3]; + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), None); + let arr2 = arr.clone(); + + // to the `to_mutable` should fail and we should get back the original array + match arr2.into_mut() { + Either::Left(arr2) => { + assert_eq!(arr, arr2); + }, + _ => panic!(), + } +} + +#[test] +#[allow(clippy::redundant_clone)] +fn array_to_mutable_validity() { + let data = vec![1, 2, 3]; + + // both have a single reference should be ok + let bitmap = Bitmap::from_iter([true, false, true]); + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.clone().into(), Some(bitmap)); + assert!(matches!(arr.into_mut(), Either::Right(_))); + + // now we clone the bitmap increasing the ref count + let bitmap = Bitmap::from_iter([true, false, true]); + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), Some(bitmap.clone())); + assert!(matches!(arr.into_mut(), Either::Left(_))); +} diff --git a/crates/polars/tests/it/arrow/array/struct_/iterator.rs b/crates/polars/tests/it/arrow/array/struct_/iterator.rs new file mode 100644 index 000000000000..7e3430e355d9 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/iterator.rs @@ -0,0 +1,29 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::scalar::new_scalar; + +#[test] +fn test_simple_iter() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + boolean.len(), + vec![boolean.clone(), int.clone()], + None, + ); + + for (i, item) in array.iter().enumerate() { + let expected = Some(vec![ + new_scalar(boolean.as_ref(), i), + new_scalar(int.as_ref(), i), + ]); + assert_eq!(expected, item); + } +} diff --git a/crates/polars/tests/it/arrow/array/struct_/mod.rs b/crates/polars/tests/it/arrow/array/struct_/mod.rs new file mode 100644 index 000000000000..9492a1b6bdba --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/mod.rs @@ -0,0 +1,44 @@ +mod iterator; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::*; + +fn array() -> StructArray { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), + ]; + + StructArray::new( + ArrowDataType::Struct(fields), + boolean.len(), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ) +} + +#[test] +fn debug() { + let array = array(); + + assert_eq!( + format!("{array:?}"), + "StructArray[{b: false, c: 42}, {b: false, c: 28}, None, {b: true, c: 31}]" + ); +} + +#[test] +fn split_at() { + let array = array(); + + let (lhs, rhs) = array.split_at(1); + assert_eq!(format!("{lhs:?}"), "StructArray[{b: false, c: 42}]"); + assert_eq!( + format!("{rhs:?}"), + "StructArray[{b: false, c: 28}, None, {b: true, c: 31}]" + ); +} diff --git a/crates/polars/tests/it/arrow/array/union.rs b/crates/polars/tests/it/arrow/array/union.rs new file mode 100644 index 000000000000..32d36b868017 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/union.rs @@ -0,0 +1,375 @@ +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::*; +use arrow::scalar::{PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar, new_scalar}; +use polars_error::PolarsResult; + +pub fn union_type(fields: Vec, ids: Option>, mode: UnionMode) -> ArrowDataType { + ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })) +} + +fn next_unwrap(iter: &mut I) -> T +where + I: Iterator>, + T: Clone + 'static, +{ + iter.next() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone() +} + +#[test] +fn sparse_debug() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype, types, fields, None); + + assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); + + Ok(()) +} + +#[test] +fn dense_debug() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + let offsets = Some(vec![0, 1, 0].into()); + + let array = UnionArray::new(dtype, types, fields, offsets); + + assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); + + Ok(()) +} + +#[test] +fn slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::LargeUtf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype.clone(), types, fields.clone(), None); + + let result = array.sliced(1, 2); + + let sliced_types = Buffer::from(vec![0, 1]); + let sliced_fields = vec![ + Int32Array::from(&[None, Some(2)]).boxed(), + Utf8Array::::from([Some("b"), Some("c")]).boxed(), + ]; + let expected = UnionArray::new(dtype, sliced_types, sliced_fields, None); + + assert_eq!(expected, result); + Ok(()) +} + +#[test] +fn iter_sparse() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype, types, fields.clone(), None); + let mut iter = array.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(1) + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &None + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + Some("c") + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_dense() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), None]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype, types, fields.clone(), Some(offsets)); + let mut iter = array.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(1) + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &None + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + Some("c") + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_sparse_slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype, types, fields.clone(), None); + let array_slice = array.sliced(1, 1); + let mut iter = array_slice.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_dense_slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), Some(3)]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype, types, fields.clone(), Some(offsets)); + let array_slice = array.sliced(1, 1); + let mut iter = array_slice.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn scalar() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), None]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(dtype, types, fields.clone(), Some(offsets)); + + let scalar = new_scalar(&array, 0); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + &Some(1) + ); + assert_eq!(union_scalar.type_(), 0); + let scalar = new_scalar(&array, 1); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + &None + ); + assert_eq!(union_scalar.type_(), 0); + + let scalar = new_scalar(&array, 2); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + Some("c") + ); + assert_eq!(union_scalar.type_(), 1); + + Ok(()) +} + +#[test] +fn dense_without_offsets_is_error() { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(dtype, types, fields.clone(), None).is_err()); +} + +#[test] +fn fields_must_match() { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int64, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(dtype, types, fields.clone(), None).is_err()); +} + +#[test] +fn sparse_with_offsets_is_error() { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + let offsets = vec![0, 1, 0].into(); + + assert!(UnionArray::try_new(dtype, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn offsets_must_be_in_bounds() { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length og types + let offsets = vec![0, 1].into(); + + assert!(UnionArray::try_new(dtype, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn sparse_with_wrong_offsets1_is_error() { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length of types + let offsets = vec![0, 1, 10].into(); + + assert!(UnionArray::try_new(dtype, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn types_must_be_in_bounds() -> PolarsResult<()> { + let fields = vec![ + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), + ]; + let dtype = union_type(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + // 10 > num fields + let types = vec![0, 10].into(); + + assert!(UnionArray::try_new(dtype, types, fields.clone(), None).is_err()); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mod.rs b/crates/polars/tests/it/arrow/array/utf8/mod.rs new file mode 100644 index 000000000000..89774dff4ae9 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mod.rs @@ -0,0 +1,255 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use polars_error::PolarsResult; + +mod mutable; +mod mutable_values; +mod to_mutable; + +fn array() -> Utf8Array { + vec![Some("hello"), None, Some("hello2")] + .into_iter() + .collect() +} + +#[test] +fn basics() { + let array = array(); + + assert_eq!(array.value(0), "hello"); + assert_eq!(array.value(1), ""); + assert_eq!(array.value(2), "hello2"); + assert_eq!(unsafe { array.value_unchecked(2) }, "hello2"); + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 5, 11]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = Utf8Array::::new( + ArrowDataType::Utf8, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), ""); + assert_eq!(array.value(1), "hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn split_at() { + let (lhs, rhs) = array().split_at(1); + + assert_eq!(lhs.value(0), "hello"); + assert_eq!(rhs.value(0), ""); + assert_eq!(rhs.value(1), "hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(lhs.values().as_slice(), b"hellohello2"); + assert_eq!(rhs.values().as_slice(), b"hellohello2"); + assert_eq!(lhs.offsets().as_slice(), &[0, 5]); + assert_eq!(rhs.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn empty() { + let array = Utf8Array::::new_empty(ArrowDataType::Utf8); + assert_eq!(array.values().as_slice(), b""); + assert_eq!(array.offsets().as_slice(), &[0]); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + let a = array.validity().unwrap(); + assert_eq!(a, &Bitmap::from([true, true, false])); +} + +#[test] +fn from_slice() { + let b = Utf8Array::::from_slice(["a", "b", "cc"]); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn from_iter_values() { + let b = Utf8Array::::from_iter_values(["a", "b", "cc"].iter()); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn from_trusted_len_iter() { + let b = + Utf8Array::::from_trusted_len_iter(vec![Some("a"), Some("b"), Some("cc")].into_iter()); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn try_from_trusted_len_iter() { + let b = Utf8Array::::try_from_trusted_len_iter( + vec![Some("a"), Some("b"), Some("cc")] + .into_iter() + .map(PolarsResult::Ok), + ) + .unwrap(); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn not_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150].into(); // invalid utf8 + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn not_utf8_individually() { + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = vec![207, 128].into(); // each is invalid utf8, but together is valid + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn wrong_dtype() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + assert!(Utf8Array::::try_new(ArrowDataType::Int32, offsets, values, None).is_err()); +} + +#[test] +fn out_of_bounds_offsets_panics() { + // the 10 is out of bounds + let offsets = vec![0, 10, 11].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +#[should_panic] +fn index_out_of_bounds_panics() { + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + + array.value(3); +} + +#[test] +fn debug() { + let array = Utf8Array::::from([Some("aa"), Some(""), None]); + + assert_eq!(format!("{array:?}"), "Utf8Array[aa, , None]"); +} + +#[test] +fn into_mut_1() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let a = offsets.clone(); // cloned offsets + assert_eq!(a, offsets); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let validity = Some([true].into()); + let a = validity.clone(); // cloned validity + assert_eq!(a, validity); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_4() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let validity = Some([true].into()); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn rev_iter() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + assert_eq!( + array.into_iter().rev().collect::>(), + vec![None, Some(" "), Some("hello")] + ); +} + +#[test] +fn iter_nth() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + assert_eq!(array.iter().nth(1), Some(Some(" "))); + assert_eq!(array.iter().nth(10), None); +} + +#[test] +fn test_apply_validity() { + let mut array = Utf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|bitmap| { + let mut mut_bitmap = bitmap.into_mut().right().unwrap(); + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap.into() + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable.rs b/crates/polars/tests/it/arrow/array/utf8/mutable.rs new file mode 100644 index 000000000000..7f1725957085 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mutable.rs @@ -0,0 +1,242 @@ +use arrow::array::{MutableArray, MutableUtf8Array, TryExtendFromSelf, Utf8Array}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacities() { + let b = MutableUtf8Array::::with_capacities(1, 10); + + assert!(b.values().capacity() >= 10); + assert!(b.offsets().capacity() >= 1); +} + +#[test] +fn push_null() { + let mut array = MutableUtf8Array::::new(); + array.push::<&str>(None); + + let array: Utf8Array = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableUtf8Array::::new(); + a.push(Some("first")); + a.push(Some("second")); + a.push(Some("third")); + a.push::<&str>(None); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), Some("third".to_owned())); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some("second".to_string())); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some("first".to_string())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); + assert!(a.is_empty()); +} + +#[test] +fn pop_all_some() { + let mut a = MutableUtf8Array::::new(); + a.push(Some("first")); + a.push(Some("second")); + a.push(Some("third")); + a.push(Some("fourth")); + for _ in 0..4 { + a.push(Some("aaaa")); + } + a.push(Some("こんにちは")); + + assert_eq!(a.pop(), Some("こんにちは".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.len(), 5); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("fourth".to_string())); + assert_eq!(a.pop(), Some("third".to_string())); + assert_eq!(a.pop(), Some("second".to_string())); + assert_eq!(a.pop(), Some("first".to_string())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); +} + +/// Safety guarantee +#[test] +fn not_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150]; // invalid utf8 + assert!(MutableUtf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn wrong_dtype() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![1, 2, 3, 4]; + assert!(MutableUtf8Array::::try_new(ArrowDataType::Int8, offsets, values, None).is_err()); +} + +#[test] +fn test_extend_trusted_len_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len_values(["hi", "there"].iter()); + array.extend_trusted_len_values(["hello"].iter()); + array.extend_trusted_len(vec![Some("again"), None].into_iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17, 17]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00001111], 5)) + ); +} + +#[test] +fn test_extend_trusted_len() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 7, 12, 17]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00011011], 5)) + ); +} + +#[test] +fn test_extend_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_values([Some("hi"), None, Some("there"), None].iter().flatten()); + array.extend_values([Some("hello"), None].iter().flatten()); + array.extend_values(vec![Some("again"), None].into_iter().flatten()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17]); + assert_eq!(array.validity(), None,); +} + +#[test] +fn test_extend() { + let mut array = MutableUtf8Array::::new(); + + array.extend([Some("hi"), None, Some("there"), None]); + + let array: Utf8Array = array.into(); + + assert_eq!( + array, + Utf8Array::::from([Some("hi"), None, Some("there"), None]) + ); +} + +#[test] +fn as_arc() { + let mut array = MutableUtf8Array::::new(); + + array.extend([Some("hi"), None, Some("there"), None]); + + assert_eq!( + Utf8Array::::from([Some("hi"), None, Some("there"), None]), + array.as_arc().as_ref() + ); +} + +#[test] +fn test_iter() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let result = array.iter().collect::>(); + assert_eq!( + result, + vec![ + Some("hi"), + Some("there"), + None, + Some("hello"), + Some("again"), + ] + ); +} + +#[test] +fn as_box_twice() { + let mut a = MutableUtf8Array::::new(); + let _ = a.as_box(); + let _ = a.as_box(); + let mut a = MutableUtf8Array::::new(); + let _ = a.as_arc(); + let _ = a.as_arc(); +} + +#[test] +fn extend_from_self() { + let mut a = MutableUtf8Array::::from([Some("aa"), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableUtf8Array::::from([Some("aa"), None, Some("aa"), None]) + ); +} + +#[test] +fn test_set_validity() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([false, false, true].into())); + + assert!(!array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); +} + +#[test] +fn test_apply_validity() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} + +#[test] +fn test_apply_validity_with_no_validity_inited() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(array.is_valid(1)); + assert!(array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs new file mode 100644 index 000000000000..d70316e18950 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs @@ -0,0 +1,105 @@ +use arrow::array::{MutableArray, MutableUtf8ValuesArray}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacity() { + let mut b = MutableUtf8ValuesArray::::with_capacity(100); + + assert_eq!(b.values().capacity(), 0); + assert!(b.offsets().capacity() >= 100); + b.shrink_to_fit(); + assert!(b.offsets().capacity() < 100); +} + +#[test] +fn offsets_must_be_in_bounds() { + let offsets = vec![0, 10].try_into().unwrap(); + let values = b"abbbbb".to_vec(); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).is_err()); +} + +#[test] +fn dtype_must_be_consistent() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec(); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err()); +} + +#[test] +fn must_be_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150]; + assert!(std::str::from_utf8(&values).is_err()); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).is_err()); +} + +#[test] +fn as_box() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + let _ = b.as_box(); +} + +#[test] +fn as_arc() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + let _ = b.as_arc(); +} + +#[test] +fn extend_trusted_len() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 2, 3, 4].try_into().unwrap(); + let values = b"abab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn from_trusted_len() { + let mut b = MutableUtf8ValuesArray::::from_trusted_len_iter(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn extend_from_iter() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let a = b.clone(); + b.extend_trusted_len(a.iter()); + + let offsets = vec![0, 2, 3, 4, 6, 7, 8].try_into().unwrap(); + let values = b"abababab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} diff --git a/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs b/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs new file mode 100644 index 000000000000..5f2624368bf7 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs @@ -0,0 +1,71 @@ +use arrow::array::Utf8Array; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; + +#[test] +fn not_shared() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + assert!(array.into_mut().is_right()); +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_validity() { + let validity = Bitmap::from([true]); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + vec![0, 1].try_into().unwrap(), + b"a".to_vec().into(), + Some(validity.clone()), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_values() { + let values: Buffer = b"a".to_vec().into(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + vec![0, 1].try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets_values() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values: Buffer = b"a".to_vec().into(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + offsets.clone(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + offsets.clone(), + b"a".to_vec().into(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_all() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + assert!(array.clone().into_mut().is_left()) +} diff --git a/crates/polars/tests/it/arrow/bitmap/assign_ops.rs b/crates/polars/tests/it/arrow/bitmap/assign_ops.rs new file mode 100644 index 000000000000..31ebfdad2907 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/assign_ops.rs @@ -0,0 +1,72 @@ +use arrow::bitmap::{Bitmap, MutableBitmap, binary_assign, unary_assign}; +use proptest::prelude::*; + +use super::bitmap_strategy; + +#[test] +fn basics() { + let mut b = MutableBitmap::from_iter(std::iter::repeat_n(true, 10)); + unary_assign(&mut b, |x: u8| !x); + assert_eq!(b, MutableBitmap::from_iter(std::iter::repeat_n(false, 10))); + + let mut b = MutableBitmap::from_iter(std::iter::repeat_n(true, 10)); + let c = Bitmap::from_iter(std::iter::repeat_n(true, 10)); + binary_assign(&mut b, &c, |x: u8, y| x | y); + assert_eq!(b, MutableBitmap::from_iter(std::iter::repeat_n(true, 10))); +} + +#[test] +fn binary_assign_oob() { + // this check we don't have an oob access if the bitmaps are size T + 1 + // and we do some slicing. + let a = MutableBitmap::from_iter(std::iter::repeat_n(true, 65)); + let b = MutableBitmap::from_iter(std::iter::repeat_n(true, 65)); + + let a: Bitmap = a.into(); + let a = a.sliced(10, 20); + + let b: Bitmap = b.into(); + let b = b.sliced(10, 20); + + let mut a = a.make_mut(); + + binary_assign(&mut a, &b, |x: u64, y| x & y); +} + +#[test] +fn fast_paths() { + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([true, true]); + let b = b & &c; + assert_eq!(b, MutableBitmap::from_iter([true, false])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([false, false]); + let b = b & &c; + assert_eq!(b, MutableBitmap::from_iter([false, false])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([true, true]); + let b = b | &c; + assert_eq!(b, MutableBitmap::from_iter([true, true])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([false, false]); + let b = b | &c; + assert_eq!(b, MutableBitmap::from_iter([true, false])); +} + +proptest! { + /// Asserts that !bitmap equals all bits flipped + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn not(b in bitmap_strategy()) { + let not_b: MutableBitmap = b.iter().map(|x| !x).collect(); + + let mut b = b.make_mut(); + + unary_assign(&mut b, |x: u8| !x); + + assert_eq!(b, not_b); + } +} diff --git a/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs b/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs new file mode 100644 index 000000000000..bea0b248d2ea --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs @@ -0,0 +1,40 @@ +use arrow::bitmap::{Bitmap, and, or, xor}; +use proptest::prelude::*; + +use super::bitmap_strategy; + +proptest! { + /// Asserts that !bitmap equals all bits flipped + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn not(bitmap in bitmap_strategy()) { + let not_bitmap: Bitmap = bitmap.iter().map(|x| !x).collect(); + + assert_eq!(!&bitmap, not_bitmap); + } +} + +#[test] +fn test_fast_paths() { + let all_true = Bitmap::from(&[true, true]); + let all_false = Bitmap::from(&[false, false]); + let toggled = Bitmap::from(&[true, false]); + + assert_eq!(and(&all_true, &all_true), all_true); + assert_eq!(and(&all_false, &all_true), all_false); + assert_eq!(and(&all_true, &all_false), all_false); + assert_eq!(and(&toggled, &all_false), all_false); + assert_eq!(and(&toggled, &all_true), toggled); + + assert_eq!(or(&all_true, &all_true), all_true); + assert_eq!(or(&all_true, &all_false), all_true); + assert_eq!(or(&all_false, &all_true), all_true); + assert_eq!(or(&all_false, &all_false), all_false); + assert_eq!(or(&toggled, &all_false), toggled); + + assert_eq!(xor(&all_true, &all_true), all_false); + assert_eq!(xor(&all_true, &all_false), all_true); + assert_eq!(xor(&all_false, &all_true), all_true); + assert_eq!(xor(&all_false, &all_false), all_false); + assert_eq!(xor(&toggled, &toggled), all_false); +} diff --git a/crates/polars/tests/it/arrow/bitmap/immutable.rs b/crates/polars/tests/it/arrow/bitmap/immutable.rs new file mode 100644 index 000000000000..336322534de2 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/immutable.rs @@ -0,0 +1,78 @@ +use arrow::array::Splitable; +use arrow::bitmap::Bitmap; + +#[test] +fn as_slice() { + let b = Bitmap::from([true, true, true, true, true, true, true, true, true]); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0b11111111, 0b1]); + assert_eq!(offset, 0); + assert_eq!(length, 9); +} + +#[test] +fn as_slice_offset() { + let b = Bitmap::from([true, true, true, true, true, true, true, true, true]); + let b = b.sliced(8, 1); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0b1]); + assert_eq!(offset, 0); + assert_eq!(length, 1); +} + +#[test] +fn as_slice_offset_middle() { + let b = Bitmap::from_u8_slice([0, 0, 0, 0b00010101], 27); + let b = b.sliced(22, 5); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0, 0b00010101]); + assert_eq!(offset, 6); + assert_eq!(length, 5); +} + +#[test] +fn split_at_unset_bits() { + let bm = Bitmap::from_u8_slice([0b01101010, 0, 0, 0b100], 27); + + assert_eq!(bm.unset_bits(), 22); + + let (lhs, rhs) = bm.split_at(5); + assert_eq!(lhs.lazy_unset_bits(), Some(3)); + assert_eq!(rhs.lazy_unset_bits(), Some(19)); + + let (lhs, rhs) = bm.split_at(22); + assert_eq!(lhs.lazy_unset_bits(), Some(18)); + assert_eq!(rhs.lazy_unset_bits(), Some(4)); + + let (lhs, rhs) = bm.split_at(0); + assert_eq!(lhs.lazy_unset_bits(), Some(0)); + assert_eq!(rhs.lazy_unset_bits(), Some(22)); + + let (lhs, rhs) = bm.split_at(27); + assert_eq!(lhs.lazy_unset_bits(), Some(22)); + assert_eq!(rhs.lazy_unset_bits(), Some(0)); + + let bm = Bitmap::new_zeroed(1024); + let (lhs, rhs) = bm.split_at(512); + assert_eq!(lhs.lazy_unset_bits(), Some(512)); + assert_eq!(rhs.lazy_unset_bits(), Some(512)); + + let bm = Bitmap::new_with_value(true, 1024); + let (lhs, rhs) = bm.split_at(512); + assert_eq!(lhs.lazy_unset_bits(), Some(0)); + assert_eq!(rhs.lazy_unset_bits(), Some(0)); +} + +#[test] +fn debug() { + let b = Bitmap::from([true, true, false, true, true, true, true, true, true]); + let b = b.sliced(2, 7); + + assert_eq!( + format!("{b:?}"), + "Bitmap { len: 7, offset: 2, bytes: [0b111110__, 0b_______1] }" + ); +} diff --git a/crates/polars/tests/it/arrow/bitmap/mod.rs b/crates/polars/tests/it/arrow/bitmap/mod.rs new file mode 100644 index 000000000000..ac5b984f85a7 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/mod.rs @@ -0,0 +1,137 @@ +mod assign_ops; +mod bitmap_ops; +mod immutable; +mod mutable; +mod utils; + +use arrow::array::Splitable; +use arrow::bitmap::Bitmap; +use proptest::prelude::*; + +/// Returns a strategy of an arbitrary sliced [`Bitmap`] of size up to 1000 +pub(crate) fn bitmap_strategy() -> impl Strategy { + prop::collection::vec(any::(), 1..1000) + .prop_flat_map(|vec| { + let len = vec.len(); + (Just(vec), 0..len) + }) + .prop_flat_map(|(vec, index)| { + let len = vec.len(); + (Just(vec), Just(index), 0..len - index) + }) + .prop_flat_map(|(vec, index, len)| { + let bitmap = Bitmap::from(&vec); + let bitmap = bitmap.sliced(index, len); + Just(bitmap) + }) +} + +fn create_bitmap>(bytes: P, len: usize) -> Bitmap { + let buffer = Vec::::from(bytes.as_ref()); + Bitmap::from_u8_vec(buffer, len) +} + +#[test] +fn eq() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + assert!(lhs != rhs); +} + +#[test] +fn eq_len() { + let lhs = create_bitmap([0b01101010], 6); + let rhs = create_bitmap([0b00101010], 6); + assert!(lhs == rhs); + let rhs = create_bitmap([0b00001010], 6); + assert!(lhs != rhs); +} + +#[test] +fn eq_slice() { + let lhs = create_bitmap([0b10101010], 8).sliced(1, 7); + let rhs = create_bitmap([0b10101011], 8).sliced(1, 7); + assert!(lhs == rhs); + + let lhs = create_bitmap([0b10101010], 8).sliced(2, 6); + let rhs = create_bitmap([0b10101110], 8).sliced(2, 6); + assert!(lhs != rhs); +} + +#[test] +fn and() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + let expected = create_bitmap([0b01001010], 8); + assert_eq!(&lhs & &rhs, expected); +} + +#[test] +fn or_large() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000010, 0b11111111, + ]; + let input1: &[u8] = &[ + 0b00000000, 0b00000001, 0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000, + 0b10000000, 0b11111111, + ]; + let expected: &[u8] = &[ + 0b00000000, 0b00000001, 0b10000010, 0b10000100, 0b10001000, 0b10010000, 0b10100000, + 0b11000010, 0b11111111, + ]; + + let lhs = create_bitmap(input, 62); + let rhs = create_bitmap(input1, 62); + let expected = create_bitmap(expected, 62); + assert_eq!(&lhs | &rhs, expected); +} + +#[test] +fn and_offset() { + let lhs = create_bitmap([0b01101011], 8).sliced(1, 7); + let rhs = create_bitmap([0b01001111], 8).sliced(1, 7); + let expected = create_bitmap([0b01001010], 8).sliced(1, 7); + assert_eq!(&lhs & &rhs, expected); +} + +#[test] +fn or() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + let expected = create_bitmap([0b01101110], 8); + assert_eq!(&lhs | &rhs, expected); +} + +#[test] +fn not() { + let lhs = create_bitmap([0b01101010], 6); + let expected = create_bitmap([0b00010101], 6); + assert_eq!(!&lhs, expected); +} + +#[test] +fn subslicing_gives_correct_null_count() { + let base = Bitmap::from([false, true, true, false, false, true, true, true]); + assert_eq!(base.unset_bits(), 3); + + let view1 = base.clone().sliced(0, 1); + let view2 = base.sliced(1, 7); + assert_eq!(view1.unset_bits(), 1); + assert_eq!(view2.unset_bits(), 2); + + let view3 = view2.sliced(0, 1); + assert_eq!(view3.unset_bits(), 0); +} + +#[test] +fn split_at() { + let bm = create_bitmap([0b01101010], 8); + + let (lhs, rhs) = bm.split_at(5); + assert_eq!( + &lhs.iter().collect::>(), + &[false, true, false, true, false] + ); + assert_eq!(&rhs.iter().collect::>(), &[true, true, false]); +} diff --git a/crates/polars/tests/it/arrow/bitmap/mutable.rs b/crates/polars/tests/it/arrow/bitmap/mutable.rs new file mode 100644 index 000000000000..c7c6cf168a46 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/mutable.rs @@ -0,0 +1,442 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; + +#[test] +fn from_slice() { + let slice = &[true, false, true]; + let a = MutableBitmap::from(slice); + assert_eq!(a.iter().collect::>(), slice); +} + +#[test] +fn from_len_zeroed() { + let a = MutableBitmap::from_len_zeroed(10); + assert_eq!(a.len(), 10); + assert_eq!(a.unset_bits(), 10); +} + +#[test] +fn from_len_set() { + let a = MutableBitmap::from_len_set(10); + assert_eq!(a.len(), 10); + assert_eq!(a.unset_bits(), 0); +} + +#[test] +fn try_new_invalid() { + assert!(MutableBitmap::try_new(vec![], 2).is_err()); +} + +#[test] +fn clear() { + let mut a = MutableBitmap::from_len_zeroed(10); + a.clear(); + assert_eq!(a.len(), 0); +} + +#[test] +fn trusted_len() { + let data = vec![true; 65]; + let bitmap = MutableBitmap::from_trusted_len_iter(data.into_iter()); + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 65); + + assert_eq!(bitmap.as_slice().0[8], 0b00000001); +} + +#[test] +fn trusted_len_small() { + let data = vec![true; 7]; + let bitmap = MutableBitmap::from_trusted_len_iter(data.into_iter()); + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 7); + + assert_eq!(bitmap.as_slice().0[0], 0b01111111); +} + +#[test] +fn push() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(true); + bitmap.push(false); + bitmap.push(false); + for _ in 0..7 { + bitmap.push(true) + } + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 10); + + assert_eq!(bitmap.as_slice().0, &[0b11111001, 0b00000011]); +} + +#[test] +fn push_small() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(true); + bitmap.push(true); + bitmap.push(false); + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.as_slice().0[0], 0b00000011); +} + +#[test] +fn push_exact_zeros() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(false) + } + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 8); + assert_eq!(bitmap.as_slice().0.len(), 1); +} + +#[test] +fn push_exact_ones() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(true) + } + let bitmap: Option = bitmap.into(); + assert!(bitmap.is_none()); +} + +#[test] +fn pop() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(false); + bitmap.push(true); + bitmap.push(false); + bitmap.push(true); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 3); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 2); + + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 2); + assert_eq!(bitmap.as_slice().0[0], 0b00001010); +} + +#[test] +fn pop_large() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(true); + } + + bitmap.push(false); + bitmap.push(true); + bitmap.push(false); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 10); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 9); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 8); + + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 8); + assert_eq!(bitmap.as_slice().0, &[0b11111111]); +} + +#[test] +fn pop_all() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(false); + bitmap.push(true); + bitmap.push(true); + bitmap.push(true); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 2); + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 0); + assert_eq!(bitmap.pop(), None); + assert_eq!(bitmap.len(), 0); +} + +#[test] +fn capacity() { + let b = MutableBitmap::with_capacity(10); + assert!(b.capacity() >= 10); +} + +#[test] +fn capacity_push() { + let mut b = MutableBitmap::with_capacity(512); + (0..512).for_each(|_| b.push(true)); + assert_eq!(b.capacity(), 512); + b.reserve(8); + assert_eq!(b.capacity(), 1024); +} + +#[test] +fn extend() { + let mut b = MutableBitmap::new(); + + let iter = (0..512).map(|i| i % 6 == 0); + unsafe { b.extend_from_trusted_len_iter_unchecked(iter) }; + let b: Bitmap = b.into(); + for (i, v) in b.iter().enumerate() { + assert_eq!(i % 6 == 0, v); + } +} + +#[test] +fn extend_offset() { + let mut b = MutableBitmap::new(); + b.push(true); + + let iter = (0..512).map(|i| i % 6 == 0); + unsafe { b.extend_from_trusted_len_iter_unchecked(iter) }; + let b: Bitmap = b.into(); + let mut iter = b.iter().enumerate(); + assert!(iter.next().unwrap().1); + for (i, v) in iter { + assert_eq!((i - 1) % 6 == 0, v); + } +} + +#[test] +fn set() { + let mut bitmap = MutableBitmap::from_len_zeroed(12); + bitmap.set(0, true); + assert!(bitmap.get(0)); + bitmap.set(0, false); + assert!(!bitmap.get(0)); + + bitmap.set(11, true); + assert!(bitmap.get(11)); + bitmap.set(11, false); + assert!(!bitmap.get(11)); + bitmap.set(11, true); + + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 12); + assert_eq!(bitmap.as_slice().0[0], 0b00000000); +} + +#[test] +fn extend_from_bitmap() { + let other = Bitmap::from(&[true, false, true]); + let mut bitmap = MutableBitmap::new(); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.as_slice()[0], 0b00000101); + + // this call iterates over all bits + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 6); + assert_eq!(bitmap.as_slice()[0], 0b00101101); +} + +#[test] +fn extend_from_bitmap_offset() { + let other = Bitmap::from_u8_slice([0b00111111], 8); + let mut bitmap = MutableBitmap::from_vec(vec![1, 0, 0b00101010], 22); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 22 + 8); + assert_eq!(bitmap.as_slice(), &[1, 0, 0b11101010, 0b00001111]); + + // more than one byte + let other = Bitmap::from_u8_slice([0b00111111, 0b00001111, 0b0001100], 20); + let mut bitmap = MutableBitmap::from_vec(vec![1, 0, 0b00101010], 22); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 22 + 20); + assert_eq!( + bitmap.as_slice(), + &[1, 0, 0b11101010, 0b11001111, 0b0000011, 0b0000011] + ); +} + +#[test] +fn debug() { + let mut b = MutableBitmap::new(); + assert_eq!(format!("{b:?}"), "Bitmap { len: 0, offset: 0, bytes: [] }"); + b.push(true); + b.push(false); + assert_eq!( + format!("{b:?}"), + "Bitmap { len: 2, offset: 0, bytes: [0b______01] }" + ); + b.push(false); + b.push(false); + b.push(false); + b.push(false); + b.push(true); + b.push(true); + assert_eq!( + format!("{b:?}"), + "Bitmap { len: 8, offset: 0, bytes: [0b11000001] }" + ); + b.push(true); + assert_eq!( + format!("{b:?}"), + "Bitmap { len: 9, offset: 0, bytes: [0b11000001, 0b_______1] }" + ); +} + +#[test] +fn extend_set() { + let mut b = MutableBitmap::new(); + b.extend_constant(6, true); + assert_eq!(b.as_slice(), &[0b11111111]); + assert_eq!(b.len(), 6); + + let mut b = MutableBitmap::from(&[false]); + b.extend_constant(6, true); + assert_eq!(b.as_slice(), &[0b01111110]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[false]); + b.extend_constant(9, true); + assert_eq!(b.as_slice(), &[0b11111110, 0b11111111]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[false, false, false, false]); + b.extend_constant(2, true); + assert_eq!(b.as_slice(), &[0b00110000]); + assert_eq!(b.len(), 4 + 2); + + let mut b = MutableBitmap::from(&[false, false, false, false]); + b.extend_constant(8, true); + assert_eq!(b.as_slice(), &[0b11110000, 0b11111111]); + assert_eq!(b.len(), 4 + 8); + + let mut b = MutableBitmap::from(&[true, true]); + b.extend_constant(3, true); + assert_eq!(b.as_slice(), &[0b00011111]); + assert_eq!(b.len(), 2 + 3); +} + +#[test] +fn extend_unset() { + let mut b = MutableBitmap::new(); + b.extend_constant(6, false); + assert_eq!(b.as_slice(), &[0b0000000]); + assert_eq!(b.len(), 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_constant(6, false); + assert_eq!(b.as_slice(), &[0b00000001]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_constant(9, false); + assert_eq!(b.as_slice(), &[0b0000001, 0b00000000]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true]); + b.extend_constant(2, false); + assert_eq!(b.as_slice(), &[0b00001111]); + assert_eq!(b.len(), 4 + 2); +} + +#[test] +fn extend_bitmap() { + let mut b = MutableBitmap::from(&[true]); + b.extend_from_slice(&[0b00011001], 0, 6); + assert_eq!(b.as_slice(), &[0b00110011]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_from_slice(&[0b00011001, 0b00011001], 0, 9); + assert_eq!(b.as_slice(), &[0b00110011, 0b00110010]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true]); + b.extend_from_slice(&[0b00011001, 0b00011001], 0, 9); + assert_eq!(b.as_slice(), &[0b10011111, 0b10010001]); + assert_eq!(b.len(), 4 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true, true]); + b.extend_from_slice(&[0b00001011], 0, 4); + assert_eq!(b.as_slice(), &[0b01111111, 0b00000001]); + assert_eq!(b.len(), 5 + 4); +} + +// TODO! undo miri ignore once issue is fixed in miri +// this test was a memory hog and lead to OOM in CI +// given enough memory it was able to pass successfully on a local +#[test] +#[cfg_attr(miri, ignore)] +fn extend_constant1() { + use std::iter::FromIterator; + for i in 0..64 { + for j in 0..64 { + let mut b = MutableBitmap::new(); + b.extend_constant(i, false); + b.extend_constant(j, true); + assert_eq!( + b, + MutableBitmap::from_iter( + std::iter::repeat_n(false, i).chain(std::iter::repeat_n(true, j)) + ) + ); + + let mut b = MutableBitmap::new(); + b.extend_constant(i, true); + b.extend_constant(j, false); + assert_eq!( + b, + MutableBitmap::from_iter( + std::iter::repeat_n(true, i).chain(std::iter::repeat_n(false, j)) + ) + ); + } + } +} + +#[test] +fn extend_bitmap_one() { + for offset in 0..7 { + let mut b = MutableBitmap::new(); + for _ in 0..4 { + b.extend_from_slice(&[!0], offset, 1); + b.extend_from_slice(&[!0], offset, 1); + } + assert_eq!(b.as_slice(), &[0b11111111]); + } +} + +#[test] +fn extend_bitmap_other() { + let mut a = MutableBitmap::from([true, true, true, false, true, true, true, false, true, true]); + a.extend_from_slice(&[0b01111110u8, 0b10111111, 0b11011111, 0b00000111], 20, 2); + assert_eq!( + a, + MutableBitmap::from([ + true, true, true, false, true, true, true, false, true, true, true, false + ]) + ); +} + +#[test] +fn shrink_to_fit() { + let mut a = MutableBitmap::with_capacity(1025); + a.push(false); + a.shrink_to_fit(); + assert!(a.capacity() < 1025); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs b/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs new file mode 100644 index 000000000000..104db7fdc3bb --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs @@ -0,0 +1,33 @@ +use arrow::bitmap::utils::BitChunksExact; + +#[test] +fn basics() { + let mut iter = BitChunksExact::::new(&[0b11111111u8, 0b00000001u8], 9); + assert_eq!(iter.next().unwrap(), 0b11111111u8); + assert_eq!(iter.remainder(), 0b00000001u8); +} + +#[test] +fn basics_u16_small() { + let mut iter = BitChunksExact::::new(&[0b11111111u8], 7); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0000_1111_1111u16); +} + +#[test] +fn basics_u16() { + let mut iter = BitChunksExact::::new(&[0b11111111u8, 0b00000001u8], 9); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0001_1111_1111u16); +} + +#[test] +fn remainder_u16() { + let mut iter = BitChunksExact::::new( + &[0b11111111u8, 0b00000001u8, 0b00000001u8, 0b11011011u8], + 23, + ); + assert_eq!(iter.next(), Some(511)); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 1u16); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs b/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs new file mode 100644 index 000000000000..d19b6e51b5ed --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs @@ -0,0 +1,163 @@ +use arrow::bitmap::utils::BitChunks; +use arrow::types::BitChunkIter; + +#[test] +fn basics() { + let mut iter = BitChunks::::new(&[0b00000001u8, 0b00000010u8], 0, 16); + assert_eq!(iter.next().unwrap(), 0b0000_0010_0000_0001u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn remainder() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000010u8, 0b00000100u8], 0, 18); + assert_eq!(a.remainder(), 0b00000100u16); +} + +#[test] +fn remainder_saturating() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000010u8, 0b00000010u8], 0, 18); + assert_eq!(a.remainder(), 0b0000_0000_0000_0010u16); +} + +#[test] +fn basics_offset() { + let mut iter = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b00000001u8], 1, 16); + assert_eq!(iter.remainder(), 0); + assert_eq!(iter.next().unwrap(), 0b1000_0001_1000_0000u16); + assert_eq!(iter.next(), None); +} + +#[test] +fn basics_offset_remainder() { + let mut a = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b10000001u8], 1, 15); + assert_eq!(a.next(), None); + assert_eq!(a.remainder(), 0b1000_0001_1000_0000u16); + assert_eq!(a.remainder_len(), 15); +} + +#[test] +fn offset_remainder_saturating() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b00000011u8], 1, 17); + assert_eq!(a.remainder(), 0b0000_0000_0000_0001u16); +} + +#[test] +fn offset_remainder_saturating2() { + let a = BitChunks::::new(&[0b01001001u8, 0b00000001], 1, 8); + assert_eq!(a.remainder(), 0b1010_0100u64); +} + +#[test] +fn offset_remainder_saturating3() { + let input: &[u8] = &[0b01000000, 0b01000001]; + let a = BitChunks::::new(input, 8, 2); + assert_eq!(a.remainder(), 0b0100_0001u64); +} + +#[test] +fn basics_multiple() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 0, + 4 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0010_0000_0001u16); + assert_eq!(iter.next().unwrap(), 0b0000_1000_0000_0100u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn basics_multiple_offset() { + let mut iter = BitChunks::::new( + &[ + 0b00000001u8, + 0b00000010u8, + 0b00000100u8, + 0b00001000u8, + 0b00000001u8, + ], + 1, + 4 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0001_0000_0000u16); + assert_eq!(iter.next().unwrap(), 0b1000_0100_0000_0010u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn remainder_large() { + let input: &[u8] = &[ + 0b00100100, 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00100100, + 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00000100, + ]; + let mut iter = BitChunks::::new(input, 0, 8 * 12 + 4); + assert_eq!(iter.remainder_len(), 100 - 96); + + for j in 0..12 { + let mut a = BitChunkIter::new(iter.next().unwrap(), 8); + for i in 0..8 { + assert_eq!(a.next().unwrap(), (j * 8 + i + 1) % 3 == 0); + } + } + assert_eq!(None, iter.next()); + + let expected_remainder = 0b00000100u8; + assert_eq!(iter.remainder(), expected_remainder); + + let mut a = BitChunkIter::new(expected_remainder, 8); + for i in 0..4 { + assert_eq!(a.next().unwrap(), (i + 1) % 3 == 0); + } +} + +#[test] +fn basics_1() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 8, + 3 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0100_0000_0010u16); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0000_0000_1000u16); + assert_eq!(iter.remainder_len(), 8); +} + +#[test] +fn basics_2() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 7, + 3 * 8, + ); + assert_eq!(iter.remainder(), 0b0000_0000_0001_0000u16); + assert_eq!(iter.next().unwrap(), 0b0000_1000_0000_0100u16); + assert_eq!(iter.next(), None); +} + +#[test] +fn remainder_1() { + let mut iter = BitChunks::::new(&[0b11111111u8, 0b00000001u8], 0, 9); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b1_1111_1111u64); +} + +#[test] +fn remainder_2() { + // (i % 3 == 0) in bitmap + let input: &[u8] = &[ + 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00100100, 0b01001001, + 0b10010010, 0b00100100, 0b01001001, /* 73 */ + 0b10010010, /* 146 */ + 0b00100100, 0b00001001, + ]; + let offset = 10; // 8 + 2 + let length = 90; + + let mut iter = BitChunks::::new(input, offset, length); + let first: u64 = 0b0100100100100100100100100100100100100100100100100100100100100100; + assert_eq!(first, iter.next().unwrap()); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b10010010010010010010010010u64); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs new file mode 100644 index 000000000000..08cfbf31c62e --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs @@ -0,0 +1,38 @@ +use arrow::bitmap::utils::fmt; + +struct A<'a>(&'a [u8], usize, usize); + +impl std::fmt::Debug for A<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt(self.0, self.1, self.2, f) + } +} + +#[test] +fn test_debug() -> std::fmt::Result { + let test = |bytes, offset, len, bytes_str| { + assert_eq!( + format!("{:?}", A(bytes, offset, len)), + format!("Bitmap {{ len: {len}, offset: {offset}, bytes: {bytes_str} }}") + ); + }; + test(&[1], 0, 0, "[]"); + test(&[0b11000001], 0, 8, "[0b11000001]"); + test(&[0b11000001, 1], 0, 9, "[0b11000001, 0b_______1]"); + test(&[1], 0, 2, "[0b______01]"); + test(&[1], 1, 2, "[0b_____00_]"); + test(&[1], 2, 2, "[0b____00__]"); + test(&[1], 3, 2, "[0b___00___]"); + test(&[1], 4, 2, "[0b__00____]"); + test(&[1], 5, 2, "[0b_00_____]"); + test(&[1], 6, 2, "[0b00______]"); + test(&[0b11000001, 1], 1, 9, "[0b1100000_, 0b______01]"); + test(&[0b11000001, 1, 1, 1], 1, 9, "[0b1100000_, 0b______01]"); + test( + &[0b11000001, 1, 1], + 2, + 16, + "[0b110000__, 0b00000001, 0b______01]", + ); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs b/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs new file mode 100644 index 000000000000..307937f4afa4 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs @@ -0,0 +1,45 @@ +use arrow::bitmap::utils::BitmapIter; + +#[test] +fn basic() { + let values = &[0b01011011u8]; + let iter = BitmapIter::new(values, 0, 6); + let result = iter.collect::>(); + assert_eq!(result, vec![true, true, false, true, true, false]) +} + +#[test] +fn large() { + let values = &[0b01011011u8]; + let values = std::iter::repeat_n(values, 63) + .flatten() + .copied() + .collect::>(); + let len = 63 * 8; + let iter = BitmapIter::new(&values, 0, len); + assert_eq!(iter.count(), len); +} + +#[test] +fn offset() { + let values = &[0b01011011u8]; + let iter = BitmapIter::new(values, 2, 4); + let result = iter.collect::>(); + assert_eq!(result, vec![false, true, true, false]) +} + +#[test] +fn rev() { + let values = &[0b01011011u8, 0b01011011u8]; + let iter = BitmapIter::new(values, 2, 13); + let result = iter.rev().collect::>(); + assert_eq!( + result, + vec![ + false, true, true, false, true, false, true, true, false, true, true, false, true + ] + .into_iter() + .rev() + .collect::>() + ) +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/mod.rs b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs new file mode 100644 index 000000000000..ebd8d983dec0 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs @@ -0,0 +1,85 @@ +use arrow::bitmap::utils::*; +use proptest::prelude::*; + +use super::bitmap_strategy; + +mod bit_chunks_exact; +mod chunk_iter; +mod fmt; +mod iterator; +mod slice_iterator; +mod zip_validity; + +#[test] +fn get_bit_basics() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + unsafe { + for i in 0..8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 8)); + for i in 8 + 1..2 * 8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 2 * 8 + 1)); + for i in 2 * 8 + 2..3 * 8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 3 * 8 + 2)); + for i in 3 * 8 + 3..4 * 8 { + assert!(!get_bit_unchecked(input, i)); + } + assert!(get_bit_unchecked(input, 4 * 8 + 3)); + } +} + +#[test] +fn count_zeros_basics() { + let input: &[u8] = &[ + 0b01001001, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + assert_eq!(count_zeros(input, 0, 8), 8 - 3); + assert_eq!(count_zeros(input, 1, 7), 7 - 2); + assert_eq!(count_zeros(input, 1, 8), 8 - 3); + assert_eq!(count_zeros(input, 2, 7), 7 - 3); + assert_eq!(count_zeros(input, 0, 32), 32 - 6); + assert_eq!(count_zeros(input, 9, 2), 2); + + let input: &[u8] = &[0b01000000, 0b01000001]; + assert_eq!(count_zeros(input, 8, 2), 1); + assert_eq!(count_zeros(input, 8, 3), 2); + assert_eq!(count_zeros(input, 8, 4), 3); + assert_eq!(count_zeros(input, 8, 5), 4); + assert_eq!(count_zeros(input, 8, 6), 5); + assert_eq!(count_zeros(input, 8, 7), 5); + assert_eq!(count_zeros(input, 8, 8), 6); + + let input: &[u8] = &[0b01000000, 0b01010101]; + assert_eq!(count_zeros(input, 9, 2), 1); + assert_eq!(count_zeros(input, 10, 2), 1); + assert_eq!(count_zeros(input, 11, 2), 1); + assert_eq!(count_zeros(input, 12, 2), 1); + assert_eq!(count_zeros(input, 13, 2), 1); + assert_eq!(count_zeros(input, 14, 2), 1); +} + +#[test] +fn count_zeros_1() { + // offset = 10, len = 90 => remainder + let input: &[u8] = &[73, 146, 36, 73, 146, 36, 73, 146, 36, 73, 146, 36, 9]; + assert_eq!(count_zeros(input, 10, 90), 60); +} + +proptest! { + /// Asserts that `Bitmap::null_count` equals the number of unset bits + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn null_count(bitmap in bitmap_strategy()) { + let sum_of_sets: usize = (0..bitmap.len()).map(|x| (!bitmap.get_bit(x)) as usize).sum(); + assert_eq!(bitmap.unset_bits(), sum_of_sets); + } +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs b/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs new file mode 100644 index 000000000000..2dabf0d24af7 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs @@ -0,0 +1,150 @@ +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::SlicesIterator; +use proptest::prelude::*; + +use super::bitmap_strategy; + +proptest! { + /// Asserts that: + /// * `slots` is the number of set bits in the bitmap + /// * the sum of the lens of the slices equals `slots` + /// * each item on each slice is set + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn check_invariants(bitmap in bitmap_strategy()) { + let iter = SlicesIterator::new(&bitmap); + + let slots = iter.slots(); + + assert_eq!(bitmap.len() - bitmap.unset_bits(), slots); + + let slices = iter.collect::>(); + let mut sum = 0; + for (start, len) in slices { + sum += len; + for i in start..(start+len) { + assert!(bitmap.get_bit(i)); + } + } + assert_eq!(sum, slots); + } +} + +#[test] +fn single_set() { + let values = (0..16).map(|i| i == 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn single_unset() { + let values = (0..64).map(|i| i != 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(0, 1), (2, 62)]); + assert_eq!(count, 64 - 1); +} + +#[test] +fn generic() { + let values = (0..130).map(|i| i % 62 != 0).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 61), (63, 61), (125, 5)]); + assert_eq!(count, 61 + 61 + 5); +} + +#[test] +fn incomplete_byte() { + let values = (0..6).map(|i| i == 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn incomplete_byte1() { + let values = (0..12).map(|i| i == 9).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(9, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn end_of_byte() { + let values = (0..16).map(|i| i != 7).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(0, 7), (8, 8)]); + assert_eq!(count, 15); +} + +#[test] +fn bla() { + let values = vec![true, true, true, true, true, true, true, false] + .into_iter() + .collect::(); + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + assert_eq!(values.unset_bits() + iter.slots(), values.len()); + + let total = iter.into_iter().fold(0, |acc, x| acc + x.1); + + assert_eq!(count, total); +} + +#[test] +fn past_end_should_not_be_returned() { + let values = Bitmap::from_u8_slice([0b11111010], 3); + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + assert_eq!(values.unset_bits() + iter.slots(), values.len()); + + let total = iter.into_iter().fold(0, |acc, x| acc + x.1); + + assert_eq!(count, total); +} + +#[test] +fn sliced() { + let values = Bitmap::from_u8_slice([0b11111010, 0b11111011], 16); + let values = values.sliced(8, 2); + let iter = SlicesIterator::new(&values); + + let chunks = iter.collect::>(); + + // the first "11" in the second byte + assert_eq!(chunks, vec![(0, 2)]); +} + +#[test] +fn remainder_1() { + let values = Bitmap::from_u8_slice([0, 0, 0b00000000, 0b00010101], 27); + let values = values.sliced(22, 5); + let iter = SlicesIterator::new(&values); + let chunks = iter.collect::>(); + assert_eq!(chunks, vec![(2, 1), (4, 1)]); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs b/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs new file mode 100644 index 000000000000..6ec488521ac5 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs @@ -0,0 +1,106 @@ +use arrow::bitmap::Bitmap; +use arrow::bitmap::utils::{BitmapIter, ZipValidity}; + +#[test] +fn basic() { + let a = Bitmap::from([true, false]); + let a = Some(a.iter()); + let values = vec![0, 1]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some(0), None]); +} + +#[test] +fn complete() { + let a = Bitmap::from([true, false, true, false, true, false, true, false]); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![Some(0), None, Some(2), None, Some(4), None, Some(6), None] + ); +} + +#[test] +fn slices() { + let a = Bitmap::from([true, false]); + let a = Some(a.iter()); + let offsets = [0, 2, 3]; + let values = [1, 2, 3]; + let iter = offsets.windows(2).map(|x| { + let start = x[0]; + let end = x[1]; + &values[start..end] + }); + let zip = ZipValidity::new(iter, a); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some([1, 2].as_ref()), None]); +} + +#[test] +fn byte() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![ + Some(0), + None, + Some(2), + None, + None, + Some(5), + Some(6), + None, + Some(8) + ] + ); +} + +#[test] +fn offset() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]).sliced(1, 8); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![None, Some(1), None, None, Some(4), Some(5), None, Some(7)] + ); +} + +#[test] +fn none() { + let values = vec![0, 1, 2]; + let zip = ZipValidity::new(values.into_iter(), None::); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some(0), Some(1), Some(2)]); +} + +#[test] +fn rev() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]).sliced(1, 8); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let result = zip.rev().collect::>(); + let expected = vec![None, Some(1), None, None, Some(4), Some(5), None, Some(7)] + .into_iter() + .rev() + .collect::>(); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/buffer/immutable.rs b/crates/polars/tests/it/arrow/buffer/immutable.rs new file mode 100644 index 000000000000..cc8742ba73ae --- /dev/null +++ b/crates/polars/tests/it/arrow/buffer/immutable.rs @@ -0,0 +1,45 @@ +use arrow::buffer::Buffer; + +#[test] +fn new() { + let buffer = Buffer::::new(); + assert_eq!(buffer.len(), 0); + assert!(buffer.is_empty()); +} + +#[test] +fn from_slice() { + let buffer = Buffer::::from(vec![0, 1, 2]); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +fn slice() { + let buffer = Buffer::::from(vec![0, 1, 2, 3]); + let buffer = buffer.sliced(1, 2); + assert_eq!(buffer.len(), 2); + assert_eq!(buffer.as_slice(), &[1, 2]); +} + +#[test] +fn from_iter() { + let buffer = (0..3).collect::>(); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +fn debug() { + let buffer = Buffer::::from(vec![0, 1, 2, 3]); + let buffer = buffer.sliced(1, 2); + let a = format!("{buffer:?}"); + assert_eq!(a, "[1, 2]") +} + +#[test] +fn from_vec() { + let buffer = Buffer::::from(vec![0, 1, 2]); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} diff --git a/crates/polars/tests/it/arrow/buffer/mod.rs b/crates/polars/tests/it/arrow/buffer/mod.rs new file mode 100644 index 000000000000..723312cd1a87 --- /dev/null +++ b/crates/polars/tests/it/arrow/buffer/mod.rs @@ -0,0 +1 @@ +mod immutable; diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs new file mode 100644 index 000000000000..942fd38cc4a7 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -0,0 +1,32 @@ +use arrow::array::*; +use arrow::compute::aggregate::estimated_bytes_size; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn primitive() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5]); + assert_eq!(5 * size_of::(), estimated_bytes_size(&a)); +} + +#[test] +fn boolean() { + let a = BooleanArray::from_slice([true]); + assert_eq!(1, estimated_bytes_size(&a)); +} + +#[test] +fn utf8() { + let a = Utf8Array::::from_slice(["aaa"]); + assert_eq!(3 + 2 * size_of::(), estimated_bytes_size(&a)); +} + +#[test] +fn fixed_size_list() { + let dtype = ArrowDataType::FixedSizeList( + Box::new(Field::new("elem".into(), ArrowDataType::Float32, false)), + 3, + ); + let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let a = FixedSizeListArray::new(dtype, 2, values, None); + assert_eq!(6 * size_of::(), estimated_bytes_size(&a)); +} diff --git a/crates/polars/tests/it/arrow/compute/aggregate/mod.rs b/crates/polars/tests/it/arrow/compute/aggregate/mod.rs new file mode 100644 index 000000000000..2363540356a2 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/mod.rs @@ -0,0 +1 @@ +mod memory; diff --git a/crates/polars/tests/it/arrow/compute/arity_assign.rs b/crates/polars/tests/it/arrow/compute/arity_assign.rs new file mode 100644 index 000000000000..b8ba89dda238 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/arity_assign.rs @@ -0,0 +1,21 @@ +use arrow::array::Int32Array; +use arrow::compute::arity_assign::{binary, unary}; + +#[test] +fn test_unary_assign() { + let mut a = Int32Array::from([Some(5), Some(6), None, Some(10)]); + + unary(&mut a, |x| x + 10); + + assert_eq!(a, Int32Array::from([Some(15), Some(16), None, Some(20)])) +} + +#[test] +fn test_binary_assign() { + let mut a = Int32Array::from([Some(5), Some(6), None, Some(10)]); + let b = Int32Array::from([Some(1), Some(2), Some(1), None]); + + binary(&mut a, &b, |x, y| x + y); + + assert_eq!(a, Int32Array::from([Some(6), Some(8), None, None])) +} diff --git a/crates/polars/tests/it/arrow/compute/bitwise.rs b/crates/polars/tests/it/arrow/compute/bitwise.rs new file mode 100644 index 000000000000..e2a380fbd707 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/bitwise.rs @@ -0,0 +1,41 @@ +use arrow::array::*; +use arrow::compute::bitwise::*; + +#[test] +fn test_xor() { + let a = Int32Array::from(&[Some(2), Some(4), Some(6), Some(7)]); + let b = Int32Array::from(&[None, Some(6), Some(9), Some(7)]); + let result = xor(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(15), Some(0)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_and() { + let a = Int32Array::from(&[Some(1), Some(2), Some(15)]); + let b = Int32Array::from(&[None, Some(2), Some(6)]); + let result = and(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(6)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_or() { + let a = Int32Array::from(&[Some(1), Some(2), Some(0)]); + let b = Int32Array::from(&[None, Some(2), Some(0)]); + let result = or(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(0)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_not() { + let a = Int8Array::from(&[None, Some(1i8), Some(-100i8)]); + let result = not(&a); + let expected = Int8Array::from(&[None, Some(-2), Some(99)]); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/compute/boolean.rs b/crates/polars/tests/it/arrow/compute/boolean.rs new file mode 100644 index 000000000000..c209421976c3 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/boolean.rs @@ -0,0 +1,453 @@ +use std::iter::FromIterator; + +use arrow::array::*; +use arrow::compute::boolean::*; +use arrow::scalar::BooleanScalar; + +#[test] +fn array_and() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(c, expected); +} + +#[test] +fn array_or() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + let c = or(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, true, true, true]); + + assert_eq!(c, expected); +} + +#[test] +fn array_or_validity() { + let a = BooleanArray::from(vec![ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(vec![ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = or(&a, &b); + + let expected = BooleanArray::from(vec![ + None, + None, + None, + None, + Some(false), + Some(true), + None, + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_not() { + let a = BooleanArray::from_slice(vec![false, true]); + let c = not(&a); + + let expected = BooleanArray::from_slice(vec![true, false]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_validity() { + let a = BooleanArray::from(vec![ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(vec![ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = and(&a, &b); + + let expected = BooleanArray::from(vec![ + None, + None, + None, + None, + Some(false), + Some(false), + None, + Some(false), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_sliced_same_offset() { + let a = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, false, true, true, + ]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let a = a.sliced(8, 4); + let b = b.sliced(8, 4); + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_same_offset_mod8() { + let a = BooleanArray::from_slice(vec![ + false, false, true, true, false, false, false, false, false, false, false, false, + ]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let a = a.sliced(0, 4); + let b = b.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_offset1() { + let a = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, false, true, true, + ]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + + let a = a.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_offset2() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let b = b.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_validity_offset() { + let a = BooleanArray::from(vec![None, Some(false), Some(true), None, Some(true)]); + let a = a.sliced(1, 4); + let a = a.as_any().downcast_ref::().unwrap(); + + let b = BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(true), + Some(true), + ]); + + let b = b.sliced(2, 4); + let b = b.as_any().downcast_ref::().unwrap(); + + let c = and(a, b); + + let expected = BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); + + assert_eq!(expected, c); +} + +#[test] +fn test_nonnull_array_is_null() { + let a = Int32Array::from_slice([1, 2, 3, 4]); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, false, false, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_with_offset_is_null() { + let a = Int32Array::from_slice(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); + let a = a.sliced(8, 4); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, false, false, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_is_not_null() { + let a = Int32Array::from_slice([1, 2, 3, 4]); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, true, true, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_with_offset_is_not_null() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); + let a = a.sliced(8, 4); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice([true, true, true, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_is_null() { + let a = Int32Array::from(vec![Some(1), None, Some(3), None]); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, true, false, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_with_offset_is_null() { + let a = Int32Array::from(vec![ + None, + None, + None, + None, + None, + None, + None, + None, + // offset 8, previous None values are skipped by the slice + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + ]); + let a = a.sliced(8, 4); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, true, false, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_is_not_null() { + let a = Int32Array::from(vec![Some(1), None, Some(3), None]); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, false, true, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_with_offset_is_not_null() { + let a = Int32Array::from(vec![ + None, + None, + None, + None, + None, + None, + None, + None, + // offset 8, previous None values are skipped by the slice + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + ]); + let a = a.sliced(8, 4); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, false, true, false]); + + assert_eq!(expected, res); +} + +#[test] +fn array_and_scalar() { + let array = BooleanArray::from_slice([false, false, true, true]); + + let scalar = BooleanScalar::new(Some(true)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, true, true]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(Some(false)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, false, false]); + + assert_eq!(real, expected); +} + +#[test] +fn array_and_scalar_validity() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + + let scalar = BooleanScalar::new(Some(true)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(false), Some(true)]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(None); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); + + let array = BooleanArray::from_slice([true, false, true]); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); +} + +#[test] +fn array_or_scalar() { + let array = BooleanArray::from_slice([false, false, true, true]); + + let scalar = BooleanScalar::new(Some(true)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([true, true, true, true]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(Some(false)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, true, true]); + assert_eq!(real, expected); +} + +#[test] +fn array_or_scalar_validity() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + + let scalar = BooleanScalar::new(Some(true)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(true), Some(true)]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(None); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); + + let array = BooleanArray::from_slice([true, false, true]); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); +} + +#[test] +fn test_any_all() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + assert!(any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[None, Some(false), Some(false)]); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[None, Some(true), Some(true)]); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from_iter(std::iter::repeat_n(false, 10).map(Some)); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from_iter(std::iter::repeat_n(true, 10).map(Some)); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from_iter([true, false, true, true].map(Some)); + assert!(any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[Some(true)]); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from(&[Some(false)]); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[]); + assert!(!any(&array)); + assert!(all(&array)); +} diff --git a/crates/polars/tests/it/arrow/compute/boolean_kleene.rs b/crates/polars/tests/it/arrow/compute/boolean_kleene.rs new file mode 100644 index 000000000000..515490796d38 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/boolean_kleene.rs @@ -0,0 +1,223 @@ +use arrow::array::BooleanArray; +use arrow::compute::boolean_kleene::*; +use arrow::scalar::BooleanScalar; + +#[test] +fn and_generic() { + let lhs = BooleanArray::from(&[ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let rhs = BooleanArray::from(&[ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = and(&lhs, &rhs); + + let expected = BooleanArray::from(&[ + None, + Some(false), + None, + Some(false), + Some(false), + Some(false), + None, + Some(false), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_generic() { + let a = BooleanArray::from(&[ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(&[ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = or(&a, &b); + + let expected = BooleanArray::from(&[ + None, + None, + Some(true), + None, + Some(false), + Some(true), + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_right_nulls() { + let a = BooleanArray::from_slice([false, false, false, true, true, true]); + + let b = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let c = or(&a, &b); + + let expected = BooleanArray::from(&[ + Some(true), + Some(false), + None, + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_left_nulls() { + let a = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + Some(false), + None, + ]); + + let b = BooleanArray::from_slice([false, false, false, true, true, true]); + + let c = or(&a, &b); + + let expected = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_true() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(true)); + let result = and_scalar(&array, &scalar); + + // Should be same as argument array if scalar is true. + assert_eq!(result, array); +} + +#[test] +fn array_and_false() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(false)); + let result = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[ + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + ]); + + assert_eq!(result, expected); +} + +#[test] +fn array_and_none() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(None); + let result = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(false), None, None, Some(false), None]); + + assert_eq!(result, expected); +} + +#[test] +fn array_or_true() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(true)); + let result = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[ + Some(true), + Some(true), + Some(true), + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(result, expected); +} + +#[test] +fn array_or_false() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(false)); + let result = or_scalar(&array, &scalar); + + // Should be same as argument array if scalar is false. + assert_eq!(result, array); +} + +#[test] +fn array_or_none() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(None); + let result = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[Some(true), None, None, Some(true), None, None]); + + assert_eq!(result, expected); +} + +#[test] +fn array_empty() { + let array = BooleanArray::from(&[]); + assert_eq!(any(&array), Some(false)); + assert_eq!(all(&array), Some(true)); +} diff --git a/crates/polars/tests/it/arrow/compute/mod.rs b/crates/polars/tests/it/arrow/compute/mod.rs new file mode 100644 index 000000000000..86e0ec542dfd --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/mod.rs @@ -0,0 +1,6 @@ +mod aggregate; +mod bitwise; +mod boolean; +mod boolean_kleene; + +mod arity_assign; diff --git a/crates/polars/tests/it/arrow/ffi/data.rs b/crates/polars/tests/it/arrow/ffi/data.rs new file mode 100644 index 000000000000..24caed1c482e --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/data.rs @@ -0,0 +1,53 @@ +use arrow::array::*; +use arrow::datatypes::Field; +use arrow::ffi; +use polars_error::PolarsResult; + +fn _test_round_trip(array: Box, expected: Box) -> PolarsResult<()> { + let field = Field::new("a".into(), array.dtype().clone(), true); + + // export array and corresponding dtype + let array_ffi = ffi::export_array_to_c(array); + let schema_ffi = ffi::export_field_to_c(&field); + + // import references + let result_field = unsafe { ffi::import_field_from_c(&schema_ffi)? }; + let result_array = unsafe { ffi::import_array_from_c(array_ffi, result_field.dtype.clone())? }; + + assert_eq!(&result_array, &expected); + assert_eq!(result_field, field); + Ok(()) +} + +fn test_round_trip(expected: impl Array + Clone + 'static) -> PolarsResult<()> { + let array: Box = Box::new(expected.clone()); + let expected = Box::new(expected) as Box; + _test_round_trip(array.clone(), clone(expected.as_ref()))?; + + // sliced + _test_round_trip(array.sliced(1, 2), expected.sliced(1, 2)) +} + +#[test] +fn bool_nullable() -> PolarsResult<()> { + let data = BooleanArray::from(&[Some(true), None, Some(false), None]); + test_round_trip(data) +} + +#[test] +fn binview_nullable_inlined() -> PolarsResult<()> { + let data = Utf8ViewArray::from_slice([Some("foo"), None, Some("barbar"), None]); + test_round_trip(data) +} + +#[test] +fn binview_nullable_buffered() -> PolarsResult<()> { + let data = Utf8ViewArray::from_slice([ + Some("foobaroiwalksdfjoiei"), + None, + Some("barbar"), + None, + Some("aoisejiofjfoiewjjwfoiwejfo"), + ]); + test_round_trip(data) +} diff --git a/crates/polars/tests/it/arrow/ffi/mod.rs b/crates/polars/tests/it/arrow/ffi/mod.rs new file mode 100644 index 000000000000..1ca8fa75c400 --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/mod.rs @@ -0,0 +1,3 @@ +mod data; + +mod stream; diff --git a/crates/polars/tests/it/arrow/ffi/stream.rs b/crates/polars/tests/it/arrow/ffi/stream.rs new file mode 100644 index 000000000000..542ce1c1e0fb --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/stream.rs @@ -0,0 +1,44 @@ +use arrow::array::*; +use arrow::datatypes::Field; +use arrow::ffi; +use polars_error::{PolarsError, PolarsResult}; + +fn _test_round_trip(arrays: Vec>) -> PolarsResult<()> { + let field = Field::new("a".into(), arrays[0].dtype().clone(), true); + let iter = Box::new(arrays.clone().into_iter().map(Ok)) as _; + + let mut stream = Box::new(ffi::ArrowArrayStream::empty()); + + *stream = ffi::export_iterator(iter, field.clone()); + + // import + let mut stream = unsafe { ffi::ArrowArrayStreamReader::try_new(stream)? }; + + let mut produced_arrays: Vec> = vec![]; + while let Some(array) = unsafe { stream.next() } { + produced_arrays.push(array?); + } + + assert_eq!(produced_arrays, arrays); + assert_eq!(stream.field(), &field); + Ok(()) +} + +#[test] +fn round_trip() -> PolarsResult<()> { + let array = Int32Array::from(&[Some(2), None, Some(1), None]); + let array: Box = Box::new(array); + + _test_round_trip(vec![array.clone(), array.clone(), array]) +} + +#[test] +fn stream_reader_try_new_invalid_argument_error_on_released_stream() { + let released_stream = Box::new(ffi::ArrowArrayStream::empty()); + let reader = unsafe { ffi::ArrowArrayStreamReader::try_new(released_stream) }; + // poor man's assert_matches: + match reader { + Err(PolarsError::InvalidOperation(_)) => {}, + _ => panic!("ArrowArrayStreamReader::try_new did not return an InvalidArgumentError"), + } +} diff --git a/crates/polars/tests/it/arrow/io/ipc/mod.rs b/crates/polars/tests/it/arrow/io/ipc/mod.rs new file mode 100644 index 000000000000..245be65f98a1 --- /dev/null +++ b/crates/polars/tests/it/arrow/io/ipc/mod.rs @@ -0,0 +1,85 @@ +use std::io::Cursor; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{ArrowSchema, ArrowSchemaRef, Field}; +use arrow::io::ipc::IpcField; +use arrow::io::ipc::read::{FileReader, read_file_metadata}; +use arrow::io::ipc::write::*; +use arrow::record_batch::RecordBatchT; +use polars::prelude::PlSmallStr; +use polars_error::*; + +pub(crate) fn write( + batches: &[RecordBatchT>], + schema: &ArrowSchemaRef, + ipc_fields: Option>, + compression: Option, +) -> PolarsResult> { + let result = vec![]; + let options = WriteOptions { compression }; + let mut writer = FileWriter::try_new(result, schema.clone(), ipc_fields.clone(), options)?; + for batch in batches { + writer.write(batch, ipc_fields.as_ref().map(|x| x.as_ref()))?; + } + writer.finish()?; + Ok(writer.into_inner()) +} + +fn round_trip( + columns: RecordBatchT>, + schema: ArrowSchemaRef, + ipc_fields: Option>, + compression: Option, +) -> PolarsResult<()> { + let (expected_schema, expected_batches) = (schema.clone(), vec![columns]); + + let result = write(&expected_batches, &schema, ipc_fields, compression)?; + let mut reader = Cursor::new(result); + let metadata = read_file_metadata(&mut reader)?; + let schema = metadata.schema.clone(); + + let reader = FileReader::new(reader, metadata, None, None); + + assert_eq!(schema, expected_schema); + + let batches = reader.collect::>>()?; + + assert_eq!(batches, expected_batches); + Ok(()) +} + +fn prep_schema(array: &dyn Array) -> ArrowSchemaRef { + let name = PlSmallStr::from_static("a"); + Arc::new(ArrowSchema::from_iter([Field::new( + name, + array.dtype().clone(), + true, + )])) +} + +#[test] +fn write_boolean() -> PolarsResult<()> { + let array = BooleanArray::from([Some(true), Some(false), None, Some(true)]).boxed(); + let schema = prep_schema(array.as_ref()); + let columns = RecordBatchT::try_new(4, schema.clone(), vec![array])?; + round_trip(columns, schema, None, Some(Compression::ZSTD)) +} + +#[test] +fn write_sliced_utf8() -> PolarsResult<()> { + let array = Utf8Array::::from_slice(["aa", "bb"]) + .sliced(1, 1) + .boxed(); + let schema = prep_schema(array.as_ref()); + let columns = RecordBatchT::try_new(array.len(), schema.clone(), vec![array])?; + round_trip(columns, schema, None, Some(Compression::ZSTD)) +} + +#[test] +fn write_binview() -> PolarsResult<()> { + let array = Utf8ViewArray::from_slice([Some("foo"), Some("bar"), None, Some("hamlet")]).boxed(); + let schema = prep_schema(array.as_ref()); + let columns = RecordBatchT::try_new(array.len(), schema.clone(), vec![array])?; + round_trip(columns, schema, None, Some(Compression::ZSTD)) +} diff --git a/crates/polars/tests/it/arrow/io/mod.rs b/crates/polars/tests/it/arrow/io/mod.rs new file mode 100644 index 000000000000..c00b27ad365d --- /dev/null +++ b/crates/polars/tests/it/arrow/io/mod.rs @@ -0,0 +1 @@ +mod ipc; diff --git a/crates/polars/tests/it/arrow/mod.rs b/crates/polars/tests/it/arrow/mod.rs new file mode 100644 index 000000000000..492ab6567542 --- /dev/null +++ b/crates/polars/tests/it/arrow/mod.rs @@ -0,0 +1,12 @@ +mod ffi; +#[cfg(feature = "ipc")] +mod io; + +mod scalar; + +mod array; +mod bitmap; + +mod buffer; + +mod compute; diff --git a/crates/polars/tests/it/arrow/scalar/binary.rs b/crates/polars/tests/it/arrow/scalar/binary.rs new file mode 100644 index 000000000000..09cb8fba254b --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/binary.rs @@ -0,0 +1,31 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{BinaryScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = BinaryScalar::::from(Some("a")); + let b = BinaryScalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BinaryScalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = BinaryScalar::::from(Some("a")); + + assert_eq!(a.value(), Some(b"a".as_ref())); + assert_eq!(a.dtype(), &ArrowDataType::Binary); + assert!(a.is_valid()); + + let a = BinaryScalar::::from(None::<&str>); + + assert_eq!(a.dtype(), &ArrowDataType::LargeBinary); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/boolean.rs b/crates/polars/tests/it/arrow/scalar/boolean.rs new file mode 100644 index 000000000000..76128d7355bd --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/boolean.rs @@ -0,0 +1,26 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{BooleanScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = BooleanScalar::from(Some(true)); + let b = BooleanScalar::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BooleanScalar::from(Some(false)); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = BooleanScalar::new(Some(true)); + + assert_eq!(a.value(), Some(true)); + assert_eq!(a.dtype(), &ArrowDataType::Boolean); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs new file mode 100644 index 000000000000..779e80e86c37 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs @@ -0,0 +1,26 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{FixedSizeBinaryScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); + let b = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); + + assert_eq!(a.value(), Some(b"a".as_ref())); + assert_eq!(a.dtype(), &ArrowDataType::FixedSizeBinary(1)); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs new file mode 100644 index 000000000000..eb4084792c33 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs @@ -0,0 +1,47 @@ +use arrow::array::BooleanArray; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{FixedSizeListScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Boolean, true)), + 2, + ); + let a = FixedSizeListScalar::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + let b = FixedSizeListScalar::new(dt.clone(), None); + + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + + let b = FixedSizeListScalar::new(dt, Some(BooleanArray::from_slice([true, true]).boxed())); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Boolean, true)), + 2, + ); + let a = FixedSizeListScalar::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + assert_eq!( + BooleanArray::from_slice([true, false]), + a.values().unwrap().as_ref() + ); + assert_eq!(a.dtype(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/list.rs b/crates/polars/tests/it/arrow/scalar/list.rs new file mode 100644 index 000000000000..d8acce251fd1 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/list.rs @@ -0,0 +1,43 @@ +use arrow::array::BooleanArray; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{ListScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Boolean, + true, + ))); + let a = ListScalar::::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + let b = ListScalar::::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = ListScalar::::new(dt, Some(BooleanArray::from_slice([true, true]).boxed())); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Boolean, + true, + ))); + let a = ListScalar::::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + assert_eq!(BooleanArray::from_slice([true, false]), a.values().as_ref()); + assert_eq!(a.dtype(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/map.rs b/crates/polars/tests/it/arrow/scalar/map.rs new file mode 100644 index 000000000000..3a12e8bffcd6 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/map.rs @@ -0,0 +1,69 @@ +use arrow::array::{BooleanArray, StructArray, Utf8Array}; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{MapScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let kv_dt = ArrowDataType::Struct(vec![ + Field::new("key".into(), ArrowDataType::Utf8, false), + Field::new("value".into(), ArrowDataType::Boolean, true), + ]); + let kv_array1 = StructArray::try_new( + kv_dt.clone(), + 2, + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + let kv_array2 = StructArray::try_new( + kv_dt.clone(), + 2, + vec![ + Utf8Array::::from([Some("k1"), Some("k3")]).boxed(), + BooleanArray::from_slice([true, true]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = ArrowDataType::Map(Box::new(Field::new("entries".into(), kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array1))); + let b = MapScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = MapScalar::new(dt, Some(Box::new(kv_array2))); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let kv_dt = ArrowDataType::Struct(vec![ + Field::new("key".into(), ArrowDataType::Utf8, false), + Field::new("value".into(), ArrowDataType::Boolean, true), + ]); + let kv_array = StructArray::try_new( + kv_dt.clone(), + 2, + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = ArrowDataType::Map(Box::new(Field::new("entries".into(), kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array.clone()))); + + assert_eq!(kv_array, a.values().as_ref()); + assert_eq!(a.dtype(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/mod.rs b/crates/polars/tests/it/arrow/scalar/mod.rs new file mode 100644 index 000000000000..2f490f51e4fb --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/mod.rs @@ -0,0 +1,17 @@ +mod binary; +mod boolean; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod utf8; + +// check that `PartialEq` can be derived +#[allow(dead_code)] +#[derive(PartialEq)] +struct A { + array: Box, +} diff --git a/crates/polars/tests/it/arrow/scalar/null.rs b/crates/polars/tests/it/arrow/scalar/null.rs new file mode 100644 index 000000000000..25f68534aa87 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/null.rs @@ -0,0 +1,19 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{NullScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = NullScalar::new(); + assert_eq!(a, a); +} + +#[test] +fn basics() { + let a = NullScalar::default(); + + assert_eq!(a.dtype(), &ArrowDataType::Null); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/primitive.rs b/crates/polars/tests/it/arrow/scalar/primitive.rs new file mode 100644 index 000000000000..d9c2f04b4d37 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/primitive.rs @@ -0,0 +1,36 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{PrimitiveScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = PrimitiveScalar::from(Some(2i32)); + let b = PrimitiveScalar::::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = PrimitiveScalar::::from(Some(1i32)); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = PrimitiveScalar::from(Some(2i32)); + + assert_eq!(a.value(), &Some(2i32)); + assert_eq!(a.dtype(), &ArrowDataType::Int32); + + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.dtype(), &ArrowDataType::Date32); + + let a = PrimitiveScalar::::from(None); + + assert_eq!(a.dtype(), &ArrowDataType::Int32); + assert!(!a.is_valid()); + + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.dtype(), &ArrowDataType::Date32); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/struct_.rs b/crates/polars/tests/it/arrow/scalar/struct_.rs new file mode 100644 index 000000000000..1b4de73ef25c --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/struct_.rs @@ -0,0 +1,41 @@ +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{BooleanScalar, Scalar, StructScalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Boolean, true)]); + let a = StructScalar::new( + dt.clone(), + Some(vec![ + Box::new(BooleanScalar::from(Some(true))) as Box + ]), + ); + let b = StructScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = StructScalar::new( + dt, + Some(vec![ + Box::new(BooleanScalar::from(Some(false))) as Box + ]), + ); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Boolean, true)]); + + let values = vec![Box::new(BooleanScalar::from(Some(true))) as Box]; + + let a = StructScalar::new(dt.clone(), Some(values.clone())); + + assert_eq!(a.values(), &values); + assert_eq!(a.dtype(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/utf8.rs b/crates/polars/tests/it/arrow/scalar/utf8.rs new file mode 100644 index 000000000000..249922bcd60c --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/utf8.rs @@ -0,0 +1,31 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{Scalar, Utf8Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = Utf8Scalar::::from(Some("a")); + let b = Utf8Scalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = Utf8Scalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = Utf8Scalar::::from(Some("a")); + + assert_eq!(a.value(), Some("a")); + assert_eq!(a.dtype(), &ArrowDataType::Utf8); + assert!(a.is_valid()); + + let a = Utf8Scalar::::from(None::<&str>); + + assert_eq!(a.dtype(), &ArrowDataType::LargeUtf8); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/chunks/mod.rs b/crates/polars/tests/it/chunks/mod.rs new file mode 100644 index 000000000000..ab7fe5c8ec35 --- /dev/null +++ b/crates/polars/tests/it/chunks/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "parquet")] +mod parquet; diff --git a/crates/polars/tests/it/chunks/parquet.rs b/crates/polars/tests/it/chunks/parquet.rs new file mode 100644 index 000000000000..73770d0d0faa --- /dev/null +++ b/crates/polars/tests/it/chunks/parquet.rs @@ -0,0 +1,44 @@ +use std::io::{Seek, SeekFrom}; + +use polars::prelude::*; + +#[test] +fn test_cast_join_14872() { + let df1 = df![ + "ints" => [1] + ] + .unwrap(); + + let mut df2 = df![ + "ints" => [0, 1], + "strings" => vec![Series::new("".into(), ["a"]); 2], + ] + .unwrap(); + + let mut buf = std::io::Cursor::new(vec![]); + ParquetWriter::new(&mut buf) + .with_row_group_size(Some(1)) + .finish(&mut df2) + .unwrap(); + + let _ = buf.seek(SeekFrom::Start(0)); + let df2 = ParquetReader::new(buf).finish().unwrap(); + + let out = df1 + .join( + &df2, + ["ints"], + ["ints"], + JoinArgs::new(JoinType::Left), + None, + ) + .unwrap(); + + let expected = df![ + "ints" => [1], + "strings" => vec![Series::new("".into(), ["a"]); 1], + ] + .unwrap(); + + assert!(expected.equals(&out)); +} diff --git a/crates/polars/tests/it/core/date_like.rs b/crates/polars/tests/it/core/date_like.rs new file mode 100644 index 000000000000..0d08c6079539 --- /dev/null +++ b/crates/polars/tests/it/core/date_like.rs @@ -0,0 +1,176 @@ +use super::*; + +#[test] +#[cfg(feature = "dtype-datetime")] +#[cfg_attr(miri, ignore)] +fn test_datelike_join() -> PolarsResult<()> { + let s = Column::new("foo".into(), &[1, 2, 3]); + let mut s1 = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; + s1.rename("bar".into()); + + let df = DataFrame::new(vec![s, s1])?; + + let out = df.left_join(&df.clone(), ["bar"], ["bar"])?; + assert!(matches!( + out.column("bar")?.dtype(), + DataType::Datetime(TimeUnit::Nanoseconds, None) + )); + + let out = df.inner_join(&df.clone(), ["bar"], ["bar"])?; + assert!(matches!( + out.column("bar")?.dtype(), + DataType::Datetime(TimeUnit::Nanoseconds, None) + )); + + let out = df.full_join(&df.clone(), ["bar"], ["bar"])?; + assert!(matches!( + out.column("bar")?.dtype(), + DataType::Datetime(TimeUnit::Nanoseconds, None) + )); + Ok(()) +} + +#[test] +#[cfg(all(feature = "dtype-datetime", feature = "dtype-duration"))] +fn test_datelike_methods() -> PolarsResult<()> { + let s = Series::new("foo".into(), &[1, 2, 3]); + let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; + + let out = s.subtract(&s)?; + assert!(matches!( + out.dtype(), + DataType::Duration(TimeUnit::Nanoseconds) + )); + + let mut a = s.clone(); + a.append(&s).unwrap(); + assert_eq!(a.len(), 6); + + Ok(()) +} + +#[test] +#[cfg(all(feature = "dtype-datetime", feature = "dtype-duration"))] +fn test_arithmetic_dispatch() { + let s = Int64Chunked::new("".into(), &[1, 2, 3]) + .into_datetime(TimeUnit::Nanoseconds, None) + .into_series(); + + // check if we don't panic. + let out = &s * 100; + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = &s / 100; + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = &s + 100; + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = &s - 100; + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = &s % 100; + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + + let out = 100.mul(&s); + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = 100.div(&s); + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = 100.sub(&s); + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = 100.add(&s); + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + let out = 100.rem(&s); + assert_eq!( + out.dtype(), + &DataType::Datetime(TimeUnit::Nanoseconds, None) + ); +} + +#[test] +#[cfg(feature = "dtype-duration")] +fn test_duration() -> PolarsResult<()> { + let a = Int64Chunked::new("".into(), &[1, 2, 3]) + .into_datetime(TimeUnit::Nanoseconds, None) + .into_series(); + let b = Int64Chunked::new("".into(), &[2, 3, 4]) + .into_datetime(TimeUnit::Nanoseconds, None) + .into_series(); + let c = Int64Chunked::new("".into(), &[1, 1, 1]) + .into_duration(TimeUnit::Nanoseconds) + .into_series(); + assert_eq!( + *b.subtract(&a)?.dtype(), + DataType::Duration(TimeUnit::Nanoseconds) + ); + assert_eq!( + *a.add_to(&c)?.dtype(), + DataType::Datetime(TimeUnit::Nanoseconds, None) + ); + assert_eq!( + b.subtract(&a)?, + Int64Chunked::full("".into(), 1, a.len()) + .into_duration(TimeUnit::Nanoseconds) + .into_series() + ); + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-duration")] +fn test_duration_date_arithmetic() -> PolarsResult<()> { + let date1 = Int32Chunked::new("".into(), &[1, 1, 1]) + .into_date() + .into_series(); + let date2 = Int32Chunked::new("".into(), &[2, 3, 4]) + .into_date() + .into_series(); + + let diff_ms = &date2 - &date1; + let diff_ms = diff_ms?; + let diff_us = diff_ms + .cast(&DataType::Duration(TimeUnit::Microseconds)) + .unwrap(); + let diff_ns = diff_ms + .cast(&DataType::Duration(TimeUnit::Nanoseconds)) + .unwrap(); + + // `+` is commutative for date and duration + assert_series_eq(&(&diff_ms + &date1)?, &(&date1 + &diff_ms)?); + assert_series_eq(&(&diff_us + &date1)?, &(&date1 + &diff_us)?); + assert_series_eq(&(&diff_ns + &date1)?, &(&date1 + &diff_ns)?); + + // `+` is correct date and duration + assert_series_eq(&(&diff_ms + &date1)?, &date2); + assert_series_eq(&(&diff_us + &date1)?, &date2); + assert_series_eq(&(&diff_ns + &date1)?, &date2); + + Ok(()) +} + +fn assert_series_eq(s1: &Series, s2: &Series) { + assert!(s1.equals(s2)) +} diff --git a/crates/polars/tests/it/core/group_by.rs b/crates/polars/tests/it/core/group_by.rs new file mode 100644 index 000000000000..12241bc5b2eb --- /dev/null +++ b/crates/polars/tests/it/core/group_by.rs @@ -0,0 +1,102 @@ +use polars_core::series::IsSorted; + +use super::*; + +#[test] +fn test_sorted_group_by() -> PolarsResult<()> { + // nulls last + let mut s = Series::new( + "a".into(), + &[Some(1), Some(1), Some(1), Some(6), Some(6), None], + ); + s.set_sorted_flag(IsSorted::Ascending); + for mt in [true, false] { + let out = s.group_tuples(mt, false)?; + assert_eq!(out.unwrap_slice(), &[[0, 3], [3, 2], [5, 1]]); + } + + // nulls first + let mut s = Series::new( + "a".into(), + &[None, None, Some(1), Some(1), Some(1), Some(6), Some(6)], + ); + s.set_sorted_flag(IsSorted::Ascending); + for mt in [true, false] { + let out = s.group_tuples(mt, false)?; + assert_eq!(out.unwrap_slice(), &[[0, 2], [2, 3], [5, 2]]); + } + + // nulls last + let mut s = Series::new( + "a".into(), + &[Some(1), Some(1), Some(1), Some(6), Some(6), None], + ); + s.set_sorted_flag(IsSorted::Ascending); + for mt in [true, false] { + let out = s.group_tuples(mt, false)?; + assert_eq!(out.unwrap_slice(), &[[0, 3], [3, 2], [5, 1]]); + } + + // nulls first descending sorted + let mut s = Series::new( + "a".into(), + &[ + None, + None, + Some(3), + Some(3), + Some(1), + Some(1), + Some(1), + Some(-1), + ], + ); + s.set_sorted_flag(IsSorted::Descending); + for mt in [false, true] { + let out = s.group_tuples(mt, false)?; + assert_eq!(out.unwrap_slice(), &[[0, 2], [2, 2], [4, 3], [7, 1]]); + } + + // nulls last descending sorted + let mut s = Series::new( + "a".into(), + &[ + Some(15), + Some(15), + Some(15), + Some(15), + Some(14), + Some(13), + Some(11), + Some(11), + Some(3), + Some(3), + Some(1), + Some(1), + Some(1), + Some(-1), + None, + None, + None, + ], + ); + s.set_sorted_flag(IsSorted::Descending); + for mt in [false, true] { + let out = s.group_tuples(mt, false)?; + assert_eq!( + out.unwrap_slice(), + &[ + [0, 4], + [4, 1], + [5, 1], + [6, 2], + [8, 2], + [10, 3], + [13, 1], + [14, 3] + ] + ); + } + + Ok(()) +} diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs new file mode 100644 index 000000000000..7ab9d3c3a789 --- /dev/null +++ b/crates/polars/tests/it/core/joins.rs @@ -0,0 +1,690 @@ +use polars_core::utils::{accumulate_dataframes_vertical, split_df}; +#[cfg(feature = "dtype-categorical")] +use polars_core::{SINGLE_LOCK, disable_string_cache}; + +use super::*; + +#[test] +fn test_chunked_left_join() -> PolarsResult<()> { + let mut band_members = df![ + "name" => ["john", "paul", "mick", "bob"], + "band" => ["beatles", "beatles", "stones", "wailers"], + ]?; + + let mut band_instruments = df![ + "name" => ["john", "paul", "keith"], + "plays" => ["guitar", "bass", "guitar"] + ]?; + + let band_instruments = + accumulate_dataframes_vertical(split_df(&mut band_instruments, 2, false))?; + let band_members = accumulate_dataframes_vertical(split_df(&mut band_members, 2, false))?; + assert_eq!(band_instruments.first_col_n_chunks(), 2); + assert_eq!(band_members.first_col_n_chunks(), 2); + + let out = band_instruments.join( + &band_members, + ["name"], + ["name"], + JoinArgs::new(JoinType::Left), + None, + )?; + let expected = df![ + "name" => ["john", "paul", "keith"], + "plays" => ["guitar", "bass", "guitar"], + "band" => [Some("beatles"), Some("beatles"), None], + ]?; + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +fn create_frames() -> (DataFrame, DataFrame) { + let s0 = Column::new("days".into(), &[0, 1, 2]); + let s1 = Column::new("temp".into(), &[22.1, 19.9, 7.]); + let s2 = Column::new("rain".into(), &[0.2, 0.1, 0.3]); + let temp = DataFrame::new(vec![s0, s1, s2]).unwrap(); + + let s0 = Column::new("days".into(), &[1, 2, 3, 1]); + let s1 = Column::new("rain".into(), &[0.1, 0.2, 0.3, 0.4]); + let rain = DataFrame::new(vec![s0, s1]).unwrap(); + (temp, rain) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_inner_join() { + let (temp, rain) = create_frames(); + + for i in 1..8 { + unsafe { std::env::set_var("POLARS_MAX_THREADS", format!("{}", i)) }; + let joined = temp.inner_join(&rain, ["days"], ["days"]).unwrap(); + + let join_col_days = Column::new("days".into(), &[1, 2, 1]); + let join_col_temp = Column::new("temp".into(), &[19.9, 7., 19.9]); + let join_col_rain = Column::new("rain".into(), &[0.1, 0.3, 0.1]); + let join_col_rain_right = Column::new("rain_right".into(), [0.1, 0.2, 0.4].as_ref()); + let true_df = DataFrame::new(vec![ + join_col_days, + join_col_temp, + join_col_rain, + join_col_rain_right, + ]) + .unwrap(); + + assert!(joined.equals(&true_df)); + } +} + +#[test] +#[allow(clippy::float_cmp)] +#[cfg_attr(miri, ignore)] +fn test_left_join() { + for i in 1..8 { + unsafe { std::env::set_var("POLARS_MAX_THREADS", format!("{}", i)) }; + let s0 = Column::new("days".into(), &[0, 1, 2, 3, 4]); + let s1 = Column::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); + let temp = DataFrame::new(vec![s0, s1]).unwrap(); + + let s0 = Column::new("days".into(), &[1, 2]); + let s1 = Column::new("rain".into(), &[0.1, 0.2]); + let rain = DataFrame::new(vec![s0, s1]).unwrap(); + let joined = temp.left_join(&rain, ["days"], ["days"]).unwrap(); + assert_eq!( + (joined + .column("rain") + .unwrap() + .as_materialized_series() + .sum::() + .unwrap() + * 10.) + .round(), + 3. + ); + assert_eq!(joined.column("rain").unwrap().null_count(), 3); + + // test join on string + let s0 = Column::new("days".into(), &["mo", "tue", "wed", "thu", "fri"]); + let s1 = Column::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); + let temp = DataFrame::new(vec![s0, s1]).unwrap(); + + let s0 = Column::new("days".into(), &["tue", "wed"]); + let s1 = Column::new("rain".into(), &[0.1, 0.2]); + let rain = DataFrame::new(vec![s0, s1]).unwrap(); + let joined = temp.left_join(&rain, ["days"], ["days"]).unwrap(); + assert_eq!( + (joined + .column("rain") + .unwrap() + .as_materialized_series() + .sum::() + .unwrap() + * 10.) + .round(), + 3. + ); + assert_eq!(joined.column("rain").unwrap().null_count(), 3); + } +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_full_outer_join() -> PolarsResult<()> { + let (temp, rain) = create_frames(); + let joined = temp.join( + &rain, + ["days"], + ["days"], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, + )?; + assert_eq!(joined.height(), 5); + assert_eq!( + joined + .column("days")? + .as_materialized_series() + .sum::() + .unwrap(), + 7 + ); + + let df_left = df!( + "a"=> ["a", "b", "a", "z"], + "b"=>[1, 2, 3, 4], + "c"=>[6, 5, 4, 3] + )?; + let df_right = df!( + "a"=> ["b", "c", "b", "a"], + "k"=> [0, 3, 9, 6], + "c"=> [1, 0, 2, 1] + )?; + + let out = df_left.join( + &df_right, + ["a"], + ["a"], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, + )?; + assert_eq!(out.column("c_right")?.null_count(), 1); + + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_join_with_nulls() { + let dts = &[20, 21, 22, 23, 24, 25, 27, 28]; + let vals = &[1.2, 2.4, 4.67, 5.8, 4.4, 3.6, 7.6, 6.5]; + let df = DataFrame::new(vec![ + Column::new("date".into(), dts), + Column::new("val".into(), vals), + ]) + .unwrap(); + + let vals2 = &[Some(1.1), None, Some(3.3), None, None]; + let df2 = DataFrame::new(vec![ + Column::new("date".into(), &dts[3..]), + Column::new("val2".into(), vals2), + ]) + .unwrap(); + + let joined = df.left_join(&df2, ["date"], ["date"]).unwrap(); + assert_eq!( + joined + .column("val2") + .unwrap() + .f64() + .unwrap() + .get(joined.height() - 1), + None + ); +} + +fn get_dfs() -> (DataFrame, DataFrame) { + let df_a = df! { + "a" => &[1, 2, 1, 1], + "b" => &["a", "b", "c", "c"], + "c" => &[0, 1, 2, 3] + } + .unwrap(); + + let df_b = df! { + "foo" => &[1, 1, 1], + "bar" => &["a", "c", "c"], + "ham" => &["let", "var", "const"] + } + .unwrap(); + (df_a, df_b) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_join_multiple_columns() { + let (mut df_a, mut df_b) = get_dfs(); + + // First do a hack with concatenated string dummy column + let mut s = df_a + .column("a") + .unwrap() + .cast(&DataType::String) + .unwrap() + .str() + .unwrap() + + df_a.column("b").unwrap().str().unwrap(); + s.rename("dummy".into()); + + df_a.with_column(s).unwrap(); + let mut s = df_b + .column("foo") + .unwrap() + .cast(&DataType::String) + .unwrap() + .str() + .unwrap() + + df_b.column("bar").unwrap().str().unwrap(); + s.rename("dummy".into()); + df_b.with_column(s).unwrap(); + + let joined = df_a.left_join(&df_b, ["dummy"], ["dummy"]).unwrap(); + let ham_col = joined.column("ham").unwrap(); + let ca = ham_col.str().unwrap(); + + let correct_ham = &[ + Some("let"), + None, + Some("var"), + Some("const"), + Some("var"), + Some("const"), + ]; + + assert_eq!(Vec::from(ca), correct_ham); + + // now check the join with multiple columns + let joined = df_a + .join( + &df_b, + ["a", "b"], + ["foo", "bar"], + JoinType::Left.into(), + None, + ) + .unwrap(); + let ca = joined.column("ham").unwrap().str().unwrap(); + assert_eq!(Vec::from(ca), correct_ham); + let joined_inner_hack = df_a.inner_join(&df_b, ["dummy"], ["dummy"]).unwrap(); + let joined_inner = df_a + .join( + &df_b, + ["a", "b"], + ["foo", "bar"], + JoinType::Inner.into(), + None, + ) + .unwrap(); + + assert!( + joined_inner_hack + .column("ham") + .unwrap() + .equals_missing(joined_inner.column("ham").unwrap()) + ); + + let joined_full_outer_hack = df_a.full_join(&df_b, ["dummy"], ["dummy"]).unwrap(); + let joined_full_outer = df_a + .join( + &df_b, + ["a", "b"], + ["foo", "bar"], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, + ) + .unwrap(); + assert!( + joined_full_outer_hack + .column("ham") + .unwrap() + .equals_missing(joined_full_outer.column("ham").unwrap()) + ); +} + +#[test] +#[cfg_attr(miri, ignore)] +#[cfg(feature = "dtype-categorical")] +fn test_join_categorical() { + let _guard = SINGLE_LOCK.lock(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); + + let (mut df_a, mut df_b) = get_dfs(); + + df_a.try_apply("b", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + }) + .unwrap(); + df_b.try_apply("bar", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + }) + .unwrap(); + + let out = df_a + .join(&df_b, ["b"], ["bar"], JoinType::Left.into(), None) + .unwrap(); + assert_eq!(out.shape(), (6, 5)); + let correct_ham = &[ + Some("let"), + None, + Some("var"), + Some("const"), + Some("var"), + Some("const"), + ]; + let ham_col = out.column("ham").unwrap(); + let ca = ham_col.str().unwrap(); + + assert_eq!(Vec::from(ca), correct_ham); + + // test dispatch + for jt in [JoinType::Left, JoinType::Inner, JoinType::Full] { + let out = df_a.join(&df_b, ["b"], ["bar"], jt.into(), None).unwrap(); + let out = out.column("b").unwrap(); + assert_eq!( + out.dtype(), + &DataType::Categorical(None, Default::default()) + ); + } + + // Test error when joining on different string cache + let (mut df_a, mut df_b) = get_dfs(); + df_a.try_apply("b", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + }) + .unwrap(); + + // Create a new string cache + drop(_sc); + let _sc = StringCacheHolder::hold(); + + df_b.try_apply("bar", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + }) + .unwrap(); + let out = df_a.join(&df_b, ["b"], ["bar"], JoinType::Left.into(), None); + assert!(out.is_err()); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_empty_df_join() -> PolarsResult<()> { + let empty: Vec = vec![]; + let empty_df = DataFrame::new(vec![ + Column::new("key".into(), &empty), + Column::new("eval".into(), &empty), + ]) + .unwrap(); + + let df = DataFrame::new(vec![ + Column::new("key".into(), &["foo"]), + Column::new("aval".into(), &[4]), + ]) + .unwrap(); + + let out = empty_df.inner_join(&df, ["key"], ["key"]).unwrap(); + assert_eq!(out.height(), 0); + let out = empty_df.left_join(&df, ["key"], ["key"]).unwrap(); + assert_eq!(out.height(), 0); + let out = empty_df.full_join(&df, ["key"], ["key"]).unwrap(); + assert_eq!(out.height(), 1); + df.left_join(&empty_df, ["key"], ["key"])?; + df.inner_join(&empty_df, ["key"], ["key"])?; + df.full_join(&empty_df, ["key"], ["key"])?; + + let empty: Vec = vec![]; + let _empty_df = DataFrame::new(vec![ + Column::new("key".into(), &empty), + Column::new("eval".into(), &empty), + ]) + .unwrap(); + + let df = df![ + "key" => [1i32, 2], + "vals" => [1, 2], + ]?; + + // https://github.com/pola-rs/polars/issues/1824 + let empty: Vec = vec![]; + let empty_df = DataFrame::new(vec![ + Column::new("key".into(), &empty), + Column::new("1val".into(), &empty), + Column::new("2val".into(), &empty), + ])?; + + let out = df.left_join(&empty_df, ["key"], ["key"])?; + assert_eq!(out.shape(), (2, 4)); + + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_unit_df_join() -> PolarsResult<()> { + let df1 = df![ + "a" => [1], + "b" => [2] + ]?; + + let df2 = df![ + "a" => [1, 2, 3, 4], + "b" => [Some(1), None, Some(3), Some(4)] + ]?; + + let out = df1.left_join(&df2, ["a"], ["a"])?; + let expected = df![ + "a" => [1], + "b" => [2], + "b_right" => [1] + ]?; + assert!(out.equals(&expected)); + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_join_err() -> PolarsResult<()> { + let df1 = df![ + "a" => [1, 2], + "b" => ["foo", "bar"] + ]?; + + let df2 = df![ + "a" => [1, 2, 3, 4], + "b" => [true, true, true, false] + ]?; + + // dtypes don't match, error + assert!( + df1.join( + &df2, + vec!["a", "b"], + vec!["a", "b"], + JoinType::Left.into(), + None + ) + .is_err() + ); + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_joins_with_duplicates() -> PolarsResult<()> { + // test joins with duplicates in both dataframes + + let df_left = df![ + "col1" => [1, 1, 2], + "int_col" => [1, 2, 3] + ] + .unwrap(); + + let df_right = df![ + "join_col1" => [1, 1, 1, 1, 1, 3], + "dbl_col" => [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + ] + .unwrap(); + + let df_inner_join = df_left + .inner_join(&df_right, ["col1"], ["join_col1"]) + .unwrap(); + + assert_eq!(df_inner_join.height(), 10); + assert_eq!(df_inner_join.column("col1")?.null_count(), 0); + assert_eq!(df_inner_join.column("int_col")?.null_count(), 0); + assert_eq!(df_inner_join.column("dbl_col")?.null_count(), 0); + + let df_left_join = df_left + .left_join(&df_right, ["col1"], ["join_col1"]) + .unwrap(); + + assert_eq!(df_left_join.height(), 11); + assert_eq!(df_left_join.column("col1")?.null_count(), 0); + assert_eq!(df_left_join.column("int_col")?.null_count(), 0); + assert_eq!(df_left_join.column("dbl_col")?.null_count(), 1); + + let df_full_outer_join = df_left + .join( + &df_right, + ["col1"], + ["join_col1"], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, + ) + .unwrap(); + + // ensure the column names don't get swapped by the drop we do + assert_eq!( + df_full_outer_join.get_column_names(), + &["col1", "int_col", "dbl_col"] + ); + assert_eq!(df_full_outer_join.height(), 12); + assert_eq!(df_full_outer_join.column("col1")?.null_count(), 0); + assert_eq!(df_full_outer_join.column("int_col")?.null_count(), 1); + assert_eq!(df_full_outer_join.column("dbl_col")?.null_count(), 1); + + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_multi_joins_with_duplicates() -> PolarsResult<()> { + // test joins with multiple join columns and duplicates in both + // dataframes + + let df_left = df![ + "col1" => [1, 1, 1], + "join_col2" => ["a", "a", "b"], + "int_col" => [1, 2, 3] + ] + .unwrap(); + + let df_right = df![ + "join_col1" => [1, 1, 1, 1, 1, 2], + "col2" => ["a", "a", "a", "a", "a", "c"], + "dbl_col" => [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + ] + .unwrap(); + + let df_inner_join = df_left + .join( + &df_right, + ["col1", "join_col2"], + ["join_col1", "col2"], + JoinType::Inner.into(), + None, + ) + .unwrap(); + + assert_eq!(df_inner_join.height(), 10); + assert_eq!(df_inner_join.column("col1")?.null_count(), 0); + assert_eq!(df_inner_join.column("join_col2")?.null_count(), 0); + assert_eq!(df_inner_join.column("int_col")?.null_count(), 0); + assert_eq!(df_inner_join.column("dbl_col")?.null_count(), 0); + + let df_left_join = df_left + .join( + &df_right, + ["col1", "join_col2"], + ["join_col1", "col2"], + JoinType::Left.into(), + None, + ) + .unwrap(); + + assert_eq!(df_left_join.height(), 11); + assert_eq!(df_left_join.column("col1")?.null_count(), 0); + assert_eq!(df_left_join.column("join_col2")?.null_count(), 0); + assert_eq!(df_left_join.column("int_col")?.null_count(), 0); + assert_eq!(df_left_join.column("dbl_col")?.null_count(), 1); + + let df_full_outer_join = df_left + .join( + &df_right, + ["col1", "join_col2"], + ["join_col1", "col2"], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, + ) + .unwrap(); + + assert_eq!(df_full_outer_join.height(), 12); + assert_eq!(df_full_outer_join.column("col1")?.null_count(), 0); + assert_eq!(df_full_outer_join.column("join_col2")?.null_count(), 0); + assert_eq!(df_full_outer_join.column("int_col")?.null_count(), 1); + assert_eq!(df_full_outer_join.column("dbl_col")?.null_count(), 1); + + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +fn test_join_floats() -> PolarsResult<()> { + let df_a = df! { + "a" => &[1.0, 2.0, 1.0, 1.0], + "b" => &["a", "b", "c", "c"], + "c" => &[0.0, 1.0, 2.0, 3.0] + }?; + + let df_b = df! { + "foo" => &[1.0, 2.0, 1.0], + "bar" => &[1.0, 1.0, 1.0], + "ham" => &["let", "var", "const"] + }?; + + let out = df_a.join( + &df_b, + vec!["a", "c"], + vec!["foo", "bar"], + JoinType::Left.into(), + None, + )?; + assert_eq!( + Vec::from(out.column("ham")?.str()?), + &[None, Some("var"), None, None] + ); + + let out = df_a.join( + &df_b, + vec!["a", "c"], + vec!["foo", "bar"], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, + )?; + assert_eq!( + out.dtypes(), + &[ + DataType::Float64, + DataType::String, + DataType::Float64, + DataType::String + ] + ); + Ok(()) +} + +#[test] +#[cfg_attr(miri, ignore)] +#[cfg(feature = "lazy")] +fn test_4_threads_bit_offset() -> PolarsResult<()> { + // run this locally with a thread pool size of 4 + // this was an obscure bug caused by not taking the offset of a bit into account. + let n = 8i64; + let mut left_a = (0..n).map(Some).collect::(); + let mut left_b = (0..n) + .map(|i| if i % 2 == 0 { None } else { Some(0) }) + .collect::(); + left_a.rename("a".into()); + left_b.rename("b".into()); + let left_df = DataFrame::new(vec![left_a.into_column(), left_b.into_column()])?; + + let i = 1; + let len = 8; + let range = i..i + len; + let mut right_a = range.clone().map(Some).collect::(); + let mut right_b = range + .map(|i| if i % 3 == 0 { None } else { Some(1) }) + .collect::(); + right_a.rename("a".into()); + right_b.rename("b".into()); + + let right_df = DataFrame::new(vec![right_a.into_column(), right_b.into_column()])?; + let out = JoinBuilder::new(left_df.lazy()) + .with(right_df.lazy()) + .on([col("a"), col("b")]) + .how(JoinType::Inner) + .join_nulls(true) + .finish() + .collect() + .unwrap(); + assert_eq!(out.shape(), (1, 2)); + Ok(()) +} diff --git a/crates/polars/tests/it/core/list.rs b/crates/polars/tests/it/core/list.rs new file mode 100644 index 000000000000..f485ccadd482 --- /dev/null +++ b/crates/polars/tests/it/core/list.rs @@ -0,0 +1,16 @@ +use polars::prelude::*; + +#[test] +fn test_to_list_logical() -> PolarsResult<()> { + let ca = StringChunked::new("a".into(), &["2021-01-01", "2021-01-02", "2021-01-03"]); + let out = ca.as_date(None, false)?.into_series(); + let out = out.implode().unwrap(); + assert_eq!(out.len(), 1); + let s = format!("{:?}", out); + // check if dtype is maintained all the way to formatting + assert!(s.contains("[2021-01-01, 2021-01-02, 2021-01-03]")); + + let expl = out.explode().unwrap(); + assert_eq!(expl.dtype(), &DataType::Date); + Ok(()) +} diff --git a/crates/polars/tests/it/core/mod.rs b/crates/polars/tests/it/core/mod.rs new file mode 100644 index 000000000000..76adb01d5677 --- /dev/null +++ b/crates/polars/tests/it/core/mod.rs @@ -0,0 +1,13 @@ +mod date_like; +mod group_by; +mod joins; +mod list; +mod ops; +#[cfg(feature = "pivot")] +mod pivot; +#[cfg(feature = "rolling_window")] +mod rolling_window; +mod series; +mod utils; + +use polars::prelude::*; diff --git a/crates/polars/tests/it/core/ops/mod.rs b/crates/polars/tests/it/core/ops/mod.rs new file mode 100644 index 000000000000..0945ca71ff16 --- /dev/null +++ b/crates/polars/tests/it/core/ops/mod.rs @@ -0,0 +1,2 @@ +use super::*; +mod take; diff --git a/crates/polars/tests/it/core/ops/take.rs b/crates/polars/tests/it/core/ops/take.rs new file mode 100644 index 000000000000..373c644da066 --- /dev/null +++ b/crates/polars/tests/it/core/ops/take.rs @@ -0,0 +1,14 @@ +use super::*; + +#[test] +fn test_list_gather_nulls_and_empty() { + let a: &[i32] = &[]; + let a = Series::new("".into(), a); + let b = Series::new("".into(), &[None, Some(a.clone())]); + let indices = [Some(0 as IdxSize), Some(1), None] + .into_iter() + .collect_ca("".into()); + let out = b.take(&indices).unwrap(); + let expected = Series::new("".into(), &[None, Some(a), None]); + assert!(out.equals_missing(&expected)) +} diff --git a/crates/polars/tests/it/core/pivot.rs b/crates/polars/tests/it/core/pivot.rs new file mode 100644 index 000000000000..04ea549bdff3 --- /dev/null +++ b/crates/polars/tests/it/core/pivot.rs @@ -0,0 +1,272 @@ +use chrono::NaiveDate; +use polars::prelude::*; +use polars_ops::pivot::{PivotAgg, pivot, pivot_stable}; + +#[test] +#[cfg(feature = "dtype-date")] +fn test_pivot_date_() -> PolarsResult<()> { + let mut df = df![ + "index" => [8, 2, 3, 6, 3, 6, 2, 2], + "values1" => [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000], + "values2" => [1, 1, 1, 1, 1, 1, 1, 1], + ]?; + df.try_apply("values1", |s| s.cast(&DataType::Date))?; + + // Test with date as the `columns` input + let out = pivot( + &df, + ["values1"], + Some(["index"]), + Some(["values2"]), + true, + Some(PivotAgg::Count), + None, + )?; + + let first = 1 as IdxSize; + let expected = df![ + "index" => [8i32, 2, 3, 6], + "1972-09-27" => [first, 3, 2, 2] + ]?; + assert!(out.equals_missing(&expected)); + + // Test with date as the `values` input. + let mut out = pivot_stable( + &df, + ["values2"], + Some(["index"]), + Some(["values1"]), + true, + Some(PivotAgg::First), + None, + )?; + out.try_apply("1", |s| { + let ca = s.date()?; + ca.to_string("%Y-%d-%m") + })?; + + let expected = df![ + "index" => [8i32, 2, 3, 6], + "1" => ["1972-27-09", "1972-27-09", "1972-27-09", "1972-27-09"] + ]?; + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_pivot_old() { + let s0 = Column::new("index".into(), ["A", "A", "B", "B", "C"].as_ref()); + let s2 = Column::new("columns".into(), ["k", "l", "m", "m", "l"].as_ref()); + let s1 = Column::new("values".into(), [1, 2, 2, 4, 2].as_ref()); + let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); + + let pvt = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::Sum), + None, + ) + .unwrap(); + assert_eq!(pvt.get_column_names(), &["index", "k", "l", "m"]); + assert_eq!( + Vec::from(&pvt.column("m").unwrap().i32().unwrap().sort(false)), + &[None, None, Some(6)] + ); + let pvt = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::Min), + None, + ) + .unwrap(); + assert_eq!( + Vec::from(&pvt.column("m").unwrap().i32().unwrap().sort(false)), + &[None, None, Some(2)] + ); + let pvt = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::Max), + None, + ) + .unwrap(); + assert_eq!( + Vec::from(&pvt.column("m").unwrap().i32().unwrap().sort(false)), + &[None, None, Some(4)] + ); + let pvt = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::Mean), + None, + ) + .unwrap(); + assert_eq!( + Vec::from(&pvt.column("m").unwrap().f64().unwrap().sort(false)), + &[None, None, Some(3.0)] + ); + let pvt = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::Count), + None, + ) + .unwrap(); + assert_eq!( + Vec::from(&pvt.column("m").unwrap().idx().unwrap().sort(false)), + &[None, None, Some(2)] + ); +} + +#[test] +#[cfg(feature = "dtype-categorical")] +fn test_pivot_categorical() -> PolarsResult<()> { + let mut df = df![ + "index" => [1, 1, 1, 1, 1, 1, 1, 1], + "columns" => ["a", "b", "c", "a", "b", "c", "a", "b"], + "values" => [8, 2, 3, 6, 3, 6, 2, 2], + ]?; + df.try_apply("columns", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + })?; + + let out = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + true, + Some(PivotAgg::Count), + None, + )?; + assert_eq!(out.get_column_names(), &["index", "a", "b", "c"]); + + Ok(()) +} + +#[test] +fn test_pivot_new() -> PolarsResult<()> { + let df = df![ + "index1"=> ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "index2"=> ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "cols1"=> ["small", "large", "large", "small", "small", "large", "small", "small", "large"], + "cols2"=> ["jam", "egg", "egg", "egg", "jam", "jam", "potato", "jam", "jam"], + "values1"=> [1, 2, 2, 3, 3, 4, 5, 6, 7], + "values2"=> [2, 4, 5, 5, 6, 6, 8, 9, 9] + ]?; + + let out = (pivot_stable( + &df, + ["cols1"], + Some(["index1", "index2"]), + Some(["values1"]), + true, + Some(PivotAgg::Sum), + None, + ))?; + let expected = df![ + "index1" => ["foo", "foo", "bar", "bar"], + "index2" => ["one", "two", "one", "two"], + "large" => [Some(4), None, Some(4), Some(7)], + "small" => [1, 6, 5, 6], + ]?; + assert!(out.equals_missing(&expected)); + + let out = pivot_stable( + &df, + ["cols1", "cols2"], + Some(["index1", "index2"]), + Some(["values1"]), + true, + Some(PivotAgg::Sum), + None, + )?; + let expected = df![ + "index1" => ["foo", "foo", "bar", "bar"], + "index2" => ["one", "two", "one", "two"], + "{\"large\",\"egg\"}" => [Some(4), None, None, None], + "{\"large\",\"jam\"}" => [None, None, Some(4), Some(7)], + "{\"small\",\"egg\"}" => [None, Some(3), None, None], + "{\"small\",\"jam\"}" => [Some(1), Some(3), None, Some(6)], + "{\"small\",\"potato\"}" => [None, None, Some(5), None], + ]?; + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_pivot_2() -> PolarsResult<()> { + let df = df![ + "index" => [Some("name1"), Some("name2"), None, Some("name1"), Some("name2")], + "columns"=> ["avg", "avg", "act", "test", "test"], + "values"=> [0.0, 0.1, 1.0, 0.4, 0.2] + ]?; + + let out = pivot_stable( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::First), + None, + )?; + let expected = df![ + "index" => [Some("name1"), Some("name2"), None], + "avg" => [Some(0.0), Some(0.1), None], + "act" => [None, None, Some(1.)], + "test" => [Some(0.4), Some(0.2), None], + ]?; + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-datetime")] +fn test_pivot_datetime() -> PolarsResult<()> { + let dt = NaiveDate::from_ymd_opt(2021, 1, 1) + .unwrap() + .and_hms_opt(12, 15, 0) + .unwrap(); + let df = df![ + "index" => [dt, dt, dt, dt], + "columns" => ["x", "x", "y", "y"], + "values" => [100, 50, 500, -80] + ]?; + + let out = pivot( + &df, + ["columns"], + Some(["index"]), + Some(["values"]), + false, + Some(PivotAgg::Sum), + None, + )?; + let expected = df![ + "index" => [dt], + "x" => [150], + "y" => [420] + ]?; + assert!(out.equals(&expected)); + + Ok(()) +} diff --git a/crates/polars/tests/it/core/random.rs b/crates/polars/tests/it/core/random.rs new file mode 100644 index 000000000000..3310ce07fae3 --- /dev/null +++ b/crates/polars/tests/it/core/random.rs @@ -0,0 +1,31 @@ +use polars_core::series::IsSorted; +use super::*; + +#[test] +fn test_sample_sorted() { + let s = Series::new("a", [1, 2, 3]).sort(false); + matches!(s.is_sorted_flag(),IsSorted::Ascending); + let out = s.sample_frac(1.5, true, false, None).unwrap(); + matches!(s.is_sorted_flag(),IsSorted::Not); +} + +#[test] +fn test_sample() { + let df = df![ + "foo" => &[1, 2, 3, 4, 5] + ] + .unwrap(); + + // default samples are random and don't require seeds + assert!(df.sample_n(3, false, false, None).is_ok()); + assert!(df.sample_frac(0.4, false, false, None).is_ok()); + // with seeding + assert!(df.sample_n(3, false, false, Some(0)).is_ok()); + assert!(df.sample_frac(0.4, false, false, Some(0)).is_ok()); + // without replacement can not sample more than 100% + assert!(df.sample_frac(2.0, false, false, Some(0)).is_err()); + assert!(df.sample_n(3, true, false, Some(0)).is_ok()); + assert!(df.sample_frac(0.4, true, false, Some(0)).is_ok()); + // with replacement can sample more than 100% + assert!(df.sample_frac(2.0, true, false, Some(0)).is_ok()); +} diff --git a/crates/polars/tests/it/core/rolling_window.rs b/crates/polars/tests/it/core/rolling_window.rs new file mode 100644 index 000000000000..71a360d41772 --- /dev/null +++ b/crates/polars/tests/it/core/rolling_window.rs @@ -0,0 +1,283 @@ +use super::*; + +#[test] +fn test_rolling() { + let s = Int32Chunked::new("foo".into(), &[1, 2, 3, 2, 1]).into_series(); + let a = s + .rolling_sum(RollingOptionsFixedWindow { + window_size: 2, + min_periods: 1, + ..Default::default() + }) + .unwrap(); + let a = a.i32().unwrap(); + assert_eq!( + Vec::from(a), + [1, 3, 5, 5, 3] + .iter() + .copied() + .map(Some) + .collect::>() + ); + let a = s + .rolling_min(RollingOptionsFixedWindow { + window_size: 2, + min_periods: 1, + ..Default::default() + }) + .unwrap(); + let a = a.i32().unwrap(); + assert_eq!( + Vec::from(a), + [1, 1, 2, 2, 1] + .iter() + .copied() + .map(Some) + .collect::>() + ); + let a = s + .rolling_max(RollingOptionsFixedWindow { + window_size: 2, + weights: Some(vec![1., 1.]), + min_periods: 1, + ..Default::default() + }) + .unwrap(); + + let a = a.f64().unwrap(); + assert_eq!( + Vec::from(a), + [1., 2., 3., 3., 2.] + .iter() + .copied() + .map(Some) + .collect::>() + ); +} + +#[test] +fn test_rolling_min_periods() { + let s = Int32Chunked::new("foo".into(), &[1, 2, 3, 2, 1]).into_series(); + let a = s + .rolling_max(RollingOptionsFixedWindow { + window_size: 2, + min_periods: 2, + ..Default::default() + }) + .unwrap(); + let a = a.i32().unwrap(); + assert_eq!(Vec::from(a), &[None, Some(2), Some(3), Some(3), Some(2)]); +} + +#[test] +fn test_rolling_mean() { + let s = Float64Chunked::new( + "foo".into(), + &[ + Some(0.0), + Some(1.0), + Some(2.0), + None, + None, + Some(5.0), + Some(6.0), + ], + ) + .into_series(); + + // check err on wrong input + assert!( + s.rolling_mean(RollingOptionsFixedWindow { + window_size: 1, + min_periods: 2, + ..Default::default() + }) + .is_err() + ); + + // validate that we divide by the proper window length. (same as pandas) + let a = s + .rolling_mean(RollingOptionsFixedWindow { + window_size: 3, + min_periods: 1, + center: false, + ..Default::default() + }) + .unwrap(); + let a = a.f64().unwrap(); + assert_eq!( + Vec::from(a), + &[ + Some(0.0), + Some(0.5), + Some(1.0), + Some(1.5), + Some(2.0), + Some(5.0), + Some(5.5) + ] + ); + + // check centered rolling window + let a = s + .rolling_mean(RollingOptionsFixedWindow { + window_size: 3, + min_periods: 1, + center: true, + ..Default::default() + }) + .unwrap(); + let a = a.f64().unwrap(); + assert_eq!( + Vec::from(a), + &[ + Some(0.5), + Some(1.0), + Some(1.5), + Some(2.0), + Some(5.0), + Some(5.5), + Some(5.5) + ] + ); + + // integers + let ca = Int32Chunked::from_slice("".into(), &[1, 8, 6, 2, 16, 10]); + let out = ca + .into_series() + .rolling_mean(RollingOptionsFixedWindow { + window_size: 2, + weights: None, + min_periods: 2, + center: false, + ..Default::default() + }) + .unwrap(); + + let out = out.f64().unwrap(); + assert_eq!( + Vec::from(out), + &[None, Some(4.5), Some(7.0), Some(4.0), Some(9.0), Some(13.0)] + ); +} + +#[test] +fn test_rolling_map() { + let ca = Float64Chunked::new( + "foo".into(), + &[ + Some(0.0), + Some(1.0), + Some(2.0), + None, + None, + Some(5.0), + Some(6.0), + ], + ); + + let out = ca + .rolling_map( + &|s| s.sum_reduce().unwrap().into_series(s.name().clone()), + RollingOptionsFixedWindow { + window_size: 3, + min_periods: 3, + ..Default::default() + }, + ) + .unwrap(); + + let out = out.f64().unwrap(); + + assert_eq!( + Vec::from(out), + &[None, None, Some(3.0), None, None, None, None] + ); +} + +#[test] +fn test_rolling_var() { + let s = Float64Chunked::new( + "foo".into(), + &[ + Some(0.0), + Some(1.0), + Some(2.0), + None, + None, + Some(5.0), + Some(6.0), + ], + ) + .into_series(); + // window larger than array + assert_eq!( + s.rolling_var(RollingOptionsFixedWindow { + window_size: 10, + min_periods: 10, + ..Default::default() + }) + .unwrap() + .null_count(), + s.len() + ); + + let options = RollingOptionsFixedWindow { + window_size: 3, + min_periods: 3, + ..Default::default() + }; + let out = s + .rolling_var(options.clone()) + .unwrap() + .cast(&DataType::Int32) + .unwrap(); + let out = out.i32().unwrap(); + assert_eq!( + Vec::from(out), + &[None, None, Some(1), None, None, None, None] + ); + + let s = Float64Chunked::from_slice("".into(), &[0.0, 2.0, 8.0, 3.0, 12.0, 1.0]).into_series(); + let out = s + .rolling_var(options) + .unwrap() + .cast(&DataType::Int32) + .unwrap(); + let out = out.i32().unwrap(); + + assert_eq!( + Vec::from(out), + &[None, None, Some(17), Some(10), Some(20), Some(34)] + ); + + // check centered rolling window + let out = s + .rolling_var(RollingOptionsFixedWindow { + window_size: 4, + min_periods: 3, + center: true, + ..Default::default() + }) + .unwrap(); + let out = out.f64().unwrap().to_vec(); + + let exp_res = &[ + None, + Some(17.333333333333332), + Some(11.583333333333334), + Some(21.583333333333332), + Some(24.666666666666668), + Some(34.33333333333334), + ]; + let test_res = out.iter().zip(exp_res.iter()).all(|(&a, &b)| match (a, b) { + (None, None) => true, + (Some(a), Some(b)) => (a - b).abs() < 1e-12, + (_, _) => false, + }); + assert!( + test_res, + "{:?} is not approximately equal to {:?}", + out, exp_res + ); +} diff --git a/crates/polars/tests/it/core/series.rs b/crates/polars/tests/it/core/series.rs new file mode 100644 index 000000000000..3d740ad5d940 --- /dev/null +++ b/crates/polars/tests/it/core/series.rs @@ -0,0 +1,43 @@ +use polars::prelude::*; +use polars::series::*; + +#[test] +fn test_series_arithmetic() -> PolarsResult<()> { + let a = &Series::new("a".into(), &[1, 100, 6, 40]); + let b = &Series::new("b".into(), &[-1, 2, 3, 4]); + assert_eq!((a + b)?, Series::new("a".into(), &[0, 102, 9, 44])); + assert_eq!((a - b)?, Series::new("a".into(), &[2, 98, 3, 36])); + assert_eq!((a * b)?, Series::new("a".into(), &[-1, 200, 18, 160])); + assert_eq!((a / b)?, Series::new("a".into(), &[-1, 50, 2, 10])); + + Ok(()) +} + +#[test] +fn test_min_max_sorted_asc() { + let a = &mut Series::new("a".into(), &[1, 2, 3, 4]); + a.set_sorted_flag(IsSorted::Ascending); + assert_eq!(a.max().unwrap(), Some(4)); + assert_eq!(a.min().unwrap(), Some(1)); +} + +#[test] +fn test_min_max_sorted_desc() { + let a = &mut Series::new("a".into(), &[4, 3, 2, 1]); + a.set_sorted_flag(IsSorted::Descending); + assert_eq!(a.max().unwrap(), Some(4)); + assert_eq!(a.min().unwrap(), Some(1)); +} + +#[test] +fn test_construct_list_of_null_series() { + let s = Series::new( + "a".into(), + [ + Series::new_null("a1".into(), 1), + Series::new_null("a1".into(), 1), + ], + ); + assert_eq!(s.null_count(), 0); + assert_eq!(s.field().name(), "a"); +} diff --git a/crates/polars/tests/it/core/utils.rs b/crates/polars/tests/it/core/utils.rs new file mode 100644 index 000000000000..cc131643835a --- /dev/null +++ b/crates/polars/tests/it/core/utils.rs @@ -0,0 +1,19 @@ +use super::*; + +#[test] +fn test_df_macro_trailing_commas() -> PolarsResult<()> { + let a = df! { + "a" => &["a one", "a two"], + "b" => &["b one", "b two"], + "c" => &[1, 2] + }?; + + let b = df! { + "a" => &["a one", "a two"], + "b" => &["b one", "b two"], + "c" => &[1, 2], + }?; + + assert!(a.equals(&b)); + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/mod.rs b/crates/polars/tests/it/io/avro/mod.rs new file mode 100644 index 000000000000..e801a6d1d0b5 --- /dev/null +++ b/crates/polars/tests/it/io/avro/mod.rs @@ -0,0 +1,6 @@ +//! Read and write from and to Apache Avro + +mod read; +mod read_async; +mod write; +mod write_async; diff --git a/crates/polars/tests/it/io/avro/read.rs b/crates/polars/tests/it/io/avro/read.rs new file mode 100644 index 000000000000..07fb5ef8d7c0 --- /dev/null +++ b/crates/polars/tests/it/io/avro/read.rs @@ -0,0 +1,377 @@ +use std::sync::Arc; + +use apache_avro::types::{Record, Value}; +use apache_avro::{Codec, Days, Duration, Millis, Months, Schema as AvroSchema, Writer}; +use arrow::array::*; +use arrow::datatypes::*; +use arrow::io::avro::avro_schema::read::read_metadata; +use arrow::io::avro::read; +use arrow::record_batch::RecordBatchT; +use polars_error::PolarsResult; + +pub(super) fn schema() -> (AvroSchema, ArrowSchema) { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "a", "type": "long"}, + {"name": "b", "type": "string"}, + {"name": "c", "type": "int"}, + { + "name": "date", + "type": "int", + "logicalType": "date" + }, + {"name": "d", "type": "bytes"}, + {"name": "e", "type": "double"}, + {"name": "f", "type": "boolean"}, + {"name": "g", "type": ["null", "string"], "default": null}, + {"name": "h", "type": { + "type": "array", + "items": { + "name": "item", + "type": ["null", "int"], + "default": null + } + }}, + {"name": "i", "type": { + "type": "record", + "name": "bla", + "fields": [ + {"name": "e", "type": "double"} + ] + }}, + {"name": "nullable_struct", "type": [ + "null", { + "type": "record", + "name": "foo", + "fields": [ + {"name": "e", "type": "double"} + ] + }] + , "default": null + } + ] + } +"#; + + let schema = ArrowSchema::from_iter([ + Field::new("a".into(), ArrowDataType::Int64, false), + Field::new("b".into(), ArrowDataType::Utf8, false), + Field::new("c".into(), ArrowDataType::Int32, false), + Field::new("date".into(), ArrowDataType::Date32, false), + Field::new("d".into(), ArrowDataType::Binary, false), + Field::new("e".into(), ArrowDataType::Float64, false), + Field::new("f".into(), ArrowDataType::Boolean, false), + Field::new("g".into(), ArrowDataType::Utf8, true), + Field::new( + "h".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))), + false, + ), + Field::new( + "i".into(), + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), + false, + ), + Field::new( + "nullable_struct".into(), + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), + true, + ), + ]); + + (AvroSchema::parse_str(raw_schema).unwrap(), schema) +} + +pub(super) fn data() -> RecordBatchT> { + let data = vec![ + Some(vec![Some(1i32), None, Some(3)]), + Some(vec![Some(1i32), None, Some(3)]), + ]; + + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + + let columns = vec![ + Int64Array::from_slice([27, 47]).boxed(), + Utf8Array::::from_slice(["foo", "bar"]).boxed(), + Int32Array::from_slice([1, 1]).boxed(), + Int32Array::from_slice([1, 2]) + .to(ArrowDataType::Date32) + .boxed(), + BinaryArray::::from_slice([b"foo", b"bar"]).boxed(), + PrimitiveArray::::from_slice([1.0, 2.0]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + Utf8Array::::from([Some("foo"), None]).boxed(), + array.into_box(), + StructArray::new( + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), + 2, + vec![PrimitiveArray::::from_slice([1.0, 2.0]).boxed()], + None, + ) + .boxed(), + StructArray::new( + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), + 2, + vec![PrimitiveArray::::from_slice([1.0, 0.0]).boxed()], + Some([true, false].into()), + ) + .boxed(), + ]; + + let (_, schema) = schema(); + + RecordBatchT::try_new(2, Arc::new(schema), columns).unwrap() +} + +pub(super) fn write_avro(codec: Codec) -> Result, apache_avro::Error> { + let (avro, _) = schema(); + // a writer needs a schema and something to write to + let mut writer = Writer::with_codec(&avro, Vec::new(), codec); + + // the Record type models our Record schema + let mut record = Record::new(writer.schema()).unwrap(); + record.put("a", 27i64); + record.put("b", "foo"); + record.put("c", 1i32); + record.put("date", 1i32); + record.put("d", b"foo".as_ref()); + record.put("e", 1.0f64); + record.put("f", true); + record.put("g", Some("foo")); + record.put( + "h", + Value::Array(vec![ + Value::Union(1, Box::new(Value::Int(1))), + Value::Union(0, Box::new(Value::Null)), + Value::Union(1, Box::new(Value::Int(3))), + ]), + ); + record.put( + "i", + Value::Record(vec![("e".to_string(), Value::Double(1.0f64))]), + ); + record.put( + "duration", + Value::Duration(Duration::new(Months::new(1), Days::new(1), Millis::new(1))), + ); + record.put( + "nullable_struct", + Value::Union( + 1, + Box::new(Value::Record(vec![( + "e".to_string(), + Value::Double(1.0f64), + )])), + ), + ); + writer.append(record)?; + + let mut record = Record::new(writer.schema()).unwrap(); + record.put("b", "bar"); + record.put("a", 47i64); + record.put("c", 1i32); + record.put("date", 2i32); + record.put("d", b"bar".as_ref()); + record.put("e", 2.0f64); + record.put("f", false); + record.put("g", None::<&str>); + record.put( + "i", + Value::Record(vec![("e".to_string(), Value::Double(2.0f64))]), + ); + record.put( + "h", + Value::Array(vec![ + Value::Union(1, Box::new(Value::Int(1))), + Value::Union(0, Box::new(Value::Null)), + Value::Union(1, Box::new(Value::Int(3))), + ]), + ); + record.put("nullable_struct", Value::Union(0, Box::new(Value::Null))); + writer.append(record)?; + writer.into_inner() +} + +pub(super) fn read_avro( + mut avro: &[u8], + projection: Option>, +) -> PolarsResult<(RecordBatchT>, ArrowSchema)> { + let file = &mut avro; + + let metadata = read_metadata(file)?; + let schema = read::infer_schema(&metadata.record)?; + + let mut reader = read::Reader::new(file, metadata, schema.clone(), projection.clone()); + + let schema = if let Some(projection) = projection { + schema + .into_iter_values() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect() + } else { + schema + }; + + reader.next().unwrap().map(|x| (x, schema)) +} + +fn test(codec: Codec) -> PolarsResult<()> { + let avro = write_avro(codec).unwrap(); + let expected = data(); + let (_, expected_schema) = schema(); + + let (result, schema) = read_avro(&avro, None)?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + Ok(()) +} + +#[test] +fn read_without_codec() -> PolarsResult<()> { + test(Codec::Null) +} + +#[test] +fn read_deflate() -> PolarsResult<()> { + test(Codec::Deflate) +} + +#[test] +fn read_snappy() -> PolarsResult<()> { + test(Codec::Snappy) +} + +#[test] +fn test_projected() -> PolarsResult<()> { + let expected = data(); + let expected_schema = expected.schema(); + + let avro = write_avro(Codec::Null).unwrap(); + + for i in 0..expected_schema.len() { + let mut projection = vec![false; expected_schema.len()]; + projection[i] = true; + + let length = expected.first().map_or(0, |arr| arr.len()); + let (expected_schema_2, expected_arrays) = expected.clone().into_schema_and_arrays(); + let expected_schema_2 = expected_schema_2 + .as_ref() + .clone() + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect(); + let expected_arrays = expected_arrays + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect(); + let expected = RecordBatchT::new(length, Arc::new(expected_schema_2), expected_arrays); + + let expected_schema = expected_schema + .clone() + .into_iter_values() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect(); + + let (result, schema) = read_avro(&avro, Some(projection))?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + } + Ok(()) +} + +fn schema_list() -> (AvroSchema, ArrowSchema) { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "h", "type": { + "type": "array", + "items": { + "name": "item", + "type": "int" + } + }} + ] + } +"#; + + let schema = ArrowSchema::from_iter([Field::new( + "h".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + false, + ))), + false, + )]); + + (AvroSchema::parse_str(raw_schema).unwrap(), schema) +} + +pub(super) fn data_list() -> RecordBatchT> { + let data = [Some(vec![Some(1i32), Some(2), Some(3)]), Some(vec![])]; + + let mut array = MutableListArray::>::new_from( + Default::default(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + false, + ))), + 0, + ); + array.try_extend(data).unwrap(); + + let length = array.len(); + let (_, schema) = schema_list(); + let columns = vec![array.into_box()]; + + RecordBatchT::try_new(length, Arc::new(schema), columns).unwrap() +} + +pub(super) fn write_list(codec: Codec) -> Result, apache_avro::Error> { + let (avro, _) = schema_list(); + // a writer needs a schema and something to write to + let mut writer = Writer::with_codec(&avro, Vec::new(), codec); + + // the Record type models our Record schema + let mut record = Record::new(writer.schema()).unwrap(); + record.put( + "h", + Value::Array(vec![Value::Int(1), Value::Int(2), Value::Int(3)]), + ); + writer.append(record)?; + + let mut record = Record::new(writer.schema()).unwrap(); + record.put("h", Value::Array(vec![])); + writer.append(record)?; + Ok(writer.into_inner().unwrap()) +} + +#[test] +fn test_list() -> PolarsResult<()> { + let avro = write_list(Codec::Null).unwrap(); + let expected = data_list(); + let expected_schema = expected.schema(); + + let (result, schema) = read_avro(&avro, None)?; + + assert_eq!(&schema, expected_schema); + assert_eq!(result, expected); + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/read_async.rs b/crates/polars/tests/it/io/avro/read_async.rs new file mode 100644 index 000000000000..2c90354f56d2 --- /dev/null +++ b/crates/polars/tests/it/io/avro/read_async.rs @@ -0,0 +1,42 @@ +use apache_avro::Codec; +use arrow::io::avro::avro_schema::read_async::{block_stream, read_metadata}; +use arrow::io::avro::read; +use futures::{StreamExt, pin_mut}; +use polars_error::PolarsResult; + +use super::read::{schema, write_avro}; + +async fn test(codec: Codec) -> PolarsResult<()> { + let avro_data = write_avro(codec).unwrap(); + let (_, expected_schema) = schema(); + + let mut reader = &mut &avro_data[..]; + + let metadata = read_metadata(&mut reader).await?; + let schema = read::infer_schema(&metadata.record)?; + + assert_eq!(schema, expected_schema); + + let blocks = block_stream(&mut reader, metadata.marker).await; + + pin_mut!(blocks); + while let Some(block) = blocks.next().await.transpose()? { + assert!(block.number_of_rows > 0 || block.data.is_empty()) + } + Ok(()) +} + +#[tokio::test] +async fn read_without_codec() -> PolarsResult<()> { + test(Codec::Null).await +} + +#[tokio::test] +async fn read_deflate() -> PolarsResult<()> { + test(Codec::Deflate).await +} + +#[tokio::test] +async fn read_snappy() -> PolarsResult<()> { + test(Codec::Snappy).await +} diff --git a/crates/polars/tests/it/io/avro/write.rs b/crates/polars/tests/it/io/avro/write.rs new file mode 100644 index 000000000000..810abc832b3a --- /dev/null +++ b/crates/polars/tests/it/io/avro/write.rs @@ -0,0 +1,487 @@ +use std::io::Cursor; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::*; +use arrow::io::avro::avro_schema::file::{Block, CompressedBlock, Compression}; +use arrow::io::avro::avro_schema::write::{compress, write_block, write_metadata}; +use arrow::io::avro::write; +use arrow::record_batch::RecordBatchT; +use avro_schema::schema::{Field as AvroField, Record, Schema as AvroSchema}; +use polars::io::avro::{AvroReader, AvroWriter}; +use polars::io::{SerReader, SerWriter}; +use polars::prelude::df; +use polars_error::PolarsResult; + +use super::read::read_avro; + +pub(super) fn schema() -> ArrowSchema { + ArrowSchema::from_iter([ + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("int64 nullable".into(), ArrowDataType::Int64, true), + Field::new("utf8".into(), ArrowDataType::Utf8, false), + Field::new("utf8 nullable".into(), ArrowDataType::Utf8, true), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int32 nullable".into(), ArrowDataType::Int32, true), + Field::new("date".into(), ArrowDataType::Date32, false), + Field::new("date nullable".into(), ArrowDataType::Date32, true), + Field::new("binary".into(), ArrowDataType::Binary, false), + Field::new("binary nullable".into(), ArrowDataType::Binary, true), + Field::new("float32".into(), ArrowDataType::Float32, false), + Field::new("float32 nullable".into(), ArrowDataType::Float32, true), + Field::new("float64".into(), ArrowDataType::Float64, false), + Field::new("float64 nullable".into(), ArrowDataType::Float64, true), + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("boolean nullable".into(), ArrowDataType::Boolean, true), + Field::new( + "list".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))), + false, + ), + Field::new( + "list nullable".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))), + true, + ), + ]) +} + +pub(super) fn data() -> RecordBatchT> { + let list_dt = ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))); + let list_dt1 = ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))); + + let columns = vec![ + Box::new(Int64Array::from_slice([27, 47])) as Box, + Box::new(Int64Array::from([Some(27), None])), + Box::new(Utf8Array::::from_slice(["foo", "bar"])), + Box::new(Utf8Array::::from([Some("foo"), None])), + Box::new(Int32Array::from_slice([1, 1])), + Box::new(Int32Array::from([Some(1), None])), + Box::new(Int32Array::from_slice([1, 2]).to(ArrowDataType::Date32)), + Box::new(Int32Array::from([Some(1), None]).to(ArrowDataType::Date32)), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + Box::new(PrimitiveArray::::from_slice([1.0, 2.0])), + Box::new(PrimitiveArray::::from([Some(1.0), None])), + Box::new(PrimitiveArray::::from_slice([1.0, 2.0])), + Box::new(PrimitiveArray::::from([Some(1.0), None])), + Box::new(BooleanArray::from_slice([true, false])), + Box::new(BooleanArray::from([Some(true), None])), + Box::new(ListArray::::new( + list_dt, + vec![0, 2, 5].try_into().unwrap(), + Box::new(PrimitiveArray::::from([ + None, + Some(1), + None, + Some(3), + Some(4), + ])), + None, + )), + Box::new(ListArray::::new( + list_dt1, + vec![0, 2, 2].try_into().unwrap(), + Box::new(PrimitiveArray::::from([None, Some(1)])), + Some([true, false].into()), + )), + ]; + let schema = schema(); + + RecordBatchT::new(2, Arc::new(schema), columns) +} + +pub(super) fn serialize_to_block>( + columns: &RecordBatchT, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult { + let record = write::to_record(schema, "".to_string())?; + + let mut serializers = columns + .arrays() + .iter() + .map(|x| x.as_ref()) + .zip(record.fields.iter()) + .map(|(array, field)| write::new_serializer(array, &field.schema)) + .collect::>(); + let mut block = Block::new(columns.len(), vec![]); + + write::serialize(&mut serializers, &mut block); + + let mut compressed_block = CompressedBlock::default(); + + compress(&mut block, &mut compressed_block, compression)?; + + Ok(compressed_block) +} + +fn write_avro>( + columns: &RecordBatchT, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult> { + let compressed_block = serialize_to_block(columns, schema, compression)?; + + let avro_fields = write::to_record(schema, "".to_string())?; + let mut file = vec![]; + + write_metadata(&mut file, avro_fields, compression)?; + + write_block(&mut file, &compressed_block)?; + + Ok(file) +} + +fn roundtrip(compression: Option) -> PolarsResult<()> { + let expected = data(); + let expected_schema = schema(); + + let data = write_avro(&expected, &expected_schema, compression)?; + + let (result, read_schema) = read_avro(&data, None)?; + + assert_eq!(expected_schema, read_schema); + for (c1, c2) in result.columns().iter().zip(expected.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + Ok(()) +} + +#[test] +fn no_compression() -> PolarsResult<()> { + roundtrip(None) +} + +#[test] +fn snappy() -> PolarsResult<()> { + roundtrip(Some(Compression::Snappy)) +} + +#[test] +fn deflate() -> PolarsResult<()> { + roundtrip(Some(Compression::Deflate)) +} + +fn large_format_schema() -> ArrowSchema { + ArrowSchema::from_iter([ + Field::new("large_utf8".into(), ArrowDataType::LargeUtf8, false), + Field::new("large_utf8_nullable".into(), ArrowDataType::LargeUtf8, true), + Field::new("large_binary".into(), ArrowDataType::LargeBinary, false), + Field::new( + "large_binary_nullable".into(), + ArrowDataType::LargeBinary, + true, + ), + ]) +} + +fn large_format_data() -> RecordBatchT> { + let columns = vec![ + Box::new(Utf8Array::::from_slice(["a", "b"])) as Box, + Box::new(Utf8Array::::from([Some("a"), None])), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + ]; + let schema = large_format_schema(); + + RecordBatchT::new(2, Arc::new(schema), columns) +} + +fn large_format_expected_schema() -> ArrowSchema { + ArrowSchema::from_iter([ + Field::new("large_utf8".into(), ArrowDataType::Utf8, false), + Field::new("large_utf8_nullable".into(), ArrowDataType::Utf8, true), + Field::new("large_binary".into(), ArrowDataType::Binary, false), + Field::new("large_binary_nullable".into(), ArrowDataType::Binary, true), + ]) +} + +fn large_format_expected_data() -> RecordBatchT> { + let columns = vec![ + Box::new(Utf8Array::::from_slice(["a", "b"])) as Box, + Box::new(Utf8Array::::from([Some("a"), None])), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + ]; + let schema = large_format_expected_schema(); + + RecordBatchT::new(2, Arc::new(schema), columns) +} + +#[test] +fn check_large_format() -> PolarsResult<()> { + let write_schema = large_format_schema(); + let write_data = large_format_data(); + + let data = write_avro(&write_data, &write_schema, None)?; + let (result, read_schame) = read_avro(&data, None)?; + + let expected_schema = large_format_expected_schema(); + assert_eq!(read_schame, expected_schema); + + let expected_data = large_format_expected_data(); + for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + + Ok(()) +} + +fn struct_schema() -> ArrowSchema { + ArrowSchema::from_iter([ + Field::new( + "struct".into(), + ArrowDataType::Struct(vec![ + Field::new("item1".into(), ArrowDataType::Int32, false), + Field::new("item2".into(), ArrowDataType::Int32, true), + ]), + false, + ), + Field::new( + "struct nullable".into(), + ArrowDataType::Struct(vec![ + Field::new("item1".into(), ArrowDataType::Int32, false), + Field::new("item2".into(), ArrowDataType::Int32, true), + ]), + true, + ), + ]) +} + +fn struct_data() -> RecordBatchT> { + let struct_dt = ArrowDataType::Struct(vec![ + Field::new("item1".into(), ArrowDataType::Int32, false), + Field::new("item2".into(), ArrowDataType::Int32, true), + ]); + let schema = struct_schema(); + + RecordBatchT::new( + 2, + Arc::new(schema), + vec![ + Box::new(StructArray::new( + struct_dt.clone(), + 2, + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + None, + )), + Box::new(StructArray::new( + struct_dt, + 2, + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + Some([true, false].into()), + )), + ], + ) +} + +fn avro_record() -> Record { + Record { + name: "".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "struct".to_string(), + doc: None, + schema: AvroSchema::Record(Record { + name: "r1".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "item1".to_string(), + doc: None, + schema: AvroSchema::Int(None), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "item2".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Int(None), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + }), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "struct nullable".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Record(Record { + name: "r2".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "item1".to_string(), + doc: None, + schema: AvroSchema::Int(None), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "item2".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Int(None), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + }), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + } +} + +#[test] +fn avro_record_schema() -> PolarsResult<()> { + let arrow_schema = struct_schema(); + let record = write::to_record(&arrow_schema, "".to_string())?; + assert_eq!(record, avro_record()); + Ok(()) +} + +#[test] +fn struct_() -> PolarsResult<()> { + let write_schema = struct_schema(); + let write_data = struct_data(); + + let data = write_avro(&write_data, &write_schema, None)?; + let (result, read_schema) = read_avro(&data, None)?; + + let expected_schema = struct_schema(); + assert_eq!(read_schema, expected_schema); + + let expected_data = struct_data(); + for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + + Ok(()) +} + +#[test] +fn test_write_and_read_with_compression() -> PolarsResult<()> { + let mut write_df = df!( + "i64" => &[1, 2], + "f64" => &[0.1, 0.2], + "string" => &["a", "b"] + )?; + + let compressions = vec![None, Some(Compression::Deflate), Some(Compression::Snappy)]; + + for compression in compressions.into_iter() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + + AvroWriter::new(&mut buf) + .with_compression(compression) + .finish(&mut write_df)?; + buf.set_position(0); + + let read_df = AvroReader::new(buf).finish()?; + assert!(write_df.equals(&read_df)); + } + + Ok(()) +} + +#[test] +fn test_with_projection() -> PolarsResult<()> { + let mut df = df!( + "i64" => &[1, 2], + "f64" => &[0.1, 0.2], + "string" => &["a", "b"] + )?; + + let expected_df = df!( + "i64" => &[1, 2], + "f64" => &[0.1, 0.2] + )?; + + let mut buf: Cursor> = Cursor::new(Vec::new()); + + AvroWriter::new(&mut buf).finish(&mut df)?; + buf.set_position(0); + + let read_df = AvroReader::new(buf) + .with_projection(Some(vec![0, 1])) + .finish()?; + + assert!(expected_df.equals(&read_df)); + + Ok(()) +} + +#[test] +fn test_with_columns() -> PolarsResult<()> { + let mut df = df!( + "i64" => &[1, 2], + "f64" => &[0.1, 0.2], + "string" => &["a", "b"] + )?; + + let expected_df = df!( + "i64" => &[1, 2], + "string" => &["a", "b"] + )?; + + let mut buf: Cursor> = Cursor::new(Vec::new()); + + AvroWriter::new(&mut buf).finish(&mut df)?; + buf.set_position(0); + + let read_df = AvroReader::new(buf) + .with_columns(Some(vec!["i64".to_string(), "string".to_string()])) + .finish()?; + + assert!(expected_df.equals(&read_df)); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/write_async.rs b/crates/polars/tests/it/io/avro/write_async.rs new file mode 100644 index 000000000000..77cb212f89db --- /dev/null +++ b/crates/polars/tests/it/io/avro/write_async.rs @@ -0,0 +1,48 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::io::avro::write; +use arrow::record_batch::RecordBatchT; +use avro_schema::file::Compression; +use avro_schema::write_async::{write_block, write_metadata}; +use polars_error::PolarsResult; + +use super::read::read_avro; +use super::write::{data, schema, serialize_to_block}; + +async fn write_avro>( + columns: &RecordBatchT, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult> { + // usually done on a different thread pool + let compressed_block = serialize_to_block(columns, schema, compression)?; + + let record = write::to_record(schema, "".to_string())?; + let mut file = vec![]; + + write_metadata(&mut file, record, compression).await?; + + write_block(&mut file, &compressed_block).await?; + + Ok(file) +} + +async fn roundtrip(compression: Option) -> PolarsResult<()> { + let expected = data(); + let expected_schema = schema(); + + let data = write_avro(&expected, &expected_schema, compression).await?; + + let (result, read_schema) = read_avro(&data, None)?; + + assert_eq!(expected_schema, read_schema); + for (c1, c2) in result.columns().iter().zip(expected.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + Ok(()) +} + +#[tokio::test] +async fn no_compression() -> PolarsResult<()> { + roundtrip(None).await +} diff --git a/crates/polars/tests/it/io/csv.rs b/crates/polars/tests/it/io/csv.rs new file mode 100644 index 000000000000..eb1c2f629a19 --- /dev/null +++ b/crates/polars/tests/it/io/csv.rs @@ -0,0 +1,1405 @@ +use std::io::Cursor; +use std::num::NonZeroUsize; + +use polars::io::RowIndex; +use polars_core::utils::concat_df; + +use super::*; + +const FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; + +#[test] +fn write_csv() { + let mut buf: Vec = Vec::new(); + let mut df = create_df(); + + CsvWriter::new(&mut buf) + .include_header(true) + .with_batch_size(NonZeroUsize::new(1).unwrap()) + .finish(&mut df) + .expect("csv written"); + let csv = std::str::from_utf8(&buf).unwrap(); + assert_eq!("days,temp\n0,22.1\n1,19.9\n2,7.0\n3,2.0\n4,3.0\n", csv); + + let mut buf: Vec = Vec::new(); + CsvWriter::new(&mut buf) + .include_header(false) + .finish(&mut df) + .expect("csv written"); + let csv = std::str::from_utf8(&buf).unwrap(); + assert_eq!("0,22.1\n1,19.9\n2,7.0\n3,2.0\n4,3.0\n", csv); + + let mut buf: Vec = Vec::new(); + CsvWriter::new(&mut buf) + .include_header(false) + .with_line_terminator("\r\n".into()) + .finish(&mut df) + .expect("csv written"); + let csv = std::str::from_utf8(&buf).unwrap(); + assert_eq!("0,22.1\r\n1,19.9\r\n2,7.0\r\n3,2.0\r\n4,3.0\r\n", csv); +} + +#[test] +#[cfg(feature = "timezones")] +fn write_dates() { + use chrono; + + let s0 = Column::new( + "date".into(), + [chrono::NaiveDate::from_yo_opt(2024, 33), None], + ); + let s1 = Column::new( + "time".into(), + [None, chrono::NaiveTime::from_hms_opt(19, 50, 0)], + ); + let s2 = Column::new( + "datetime".into(), + [ + Some(chrono::NaiveDateTime::new( + chrono::NaiveDate::from_ymd_opt(2000, 12, 1).unwrap(), + chrono::NaiveTime::from_num_seconds_from_midnight_opt(99, 49575634).unwrap(), + )), + None, + ], + ); + let mut df = DataFrame::new(vec![s0, s1, s2.clone()]).unwrap(); + + let mut buf: Vec = Vec::new(); + CsvWriter::new(&mut buf) + .include_header(true) + .with_batch_size(NonZeroUsize::new(1).unwrap()) + .finish(&mut df) + .expect("csv written"); + let csv = std::str::from_utf8(&buf).unwrap(); + assert_eq!( + "date,time,datetime\n2024-02-02,,2000-12-01T00:01:39.049\n,19:50:00.000000000,\n", + csv, + ); + + buf.clear(); + CsvWriter::new(&mut buf) + .include_header(true) + .with_batch_size(NonZeroUsize::new(1).unwrap()) + .with_date_format(Some("%d/%m/%Y".into())) + .with_time_format(Some("%H%M%S".into())) + .with_datetime_format(Some("%Y-%m-%d %H:%M:%S".into())) + .finish(&mut df) + .expect("csv written"); + let csv = std::str::from_utf8(&buf).unwrap(); + assert_eq!( + "date,time,datetime\n02/02/2024,,2000-12-01 00:01:39\n,195000,\n", + csv, + ); + + buf.clear(); + CsvWriter::new(&mut buf) + .include_header(true) + .with_batch_size(NonZeroUsize::new(1).unwrap()) + .with_date_format(Some("%".into())) + .finish(&mut df) + .expect_err("invalid date/time format should err"); + + buf.clear(); + CsvWriter::new(&mut buf) + .include_header(true) + .with_batch_size(NonZeroUsize::new(1).unwrap()) + .with_date_format(Some("%H".into())) + .finish(&mut df) + .expect_err("invalid date/time format should err"); + + buf.clear(); + CsvWriter::new(&mut buf) + .include_header(true) + .with_batch_size(NonZeroUsize::new(1).unwrap()) + .with_datetime_format(Some("%Z".into())) + .finish(&mut df) + .expect_err("invalid date/time format should err"); + + let with_timezone = polars_ops::chunked_array::replace_time_zone( + s2.slice(0, 1).datetime().unwrap(), + Some("America/New_York"), + &StringChunked::new("".into(), ["raise"]), + NonExistent::Raise, + ) + .unwrap() + .into_column(); + let mut with_timezone_df = DataFrame::new(vec![with_timezone]).unwrap(); + buf.clear(); + CsvWriter::new(&mut buf) + .include_header(false) + .finish(&mut with_timezone_df) + .expect("csv written"); + let csv = std::str::from_utf8(&buf).unwrap(); + assert_eq!("2000-12-01T00:01:39.049-0500\n", csv); +} + +#[test] +fn test_read_csv_file() { + let file = std::fs::File::open(FOODS_CSV).unwrap(); + let df = CsvReadOptions::default() + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + assert_eq!(df.shape(), (27, 4)); +} + +#[test] +fn test_read_csv_filter() -> PolarsResult<()> { + let df = CsvReadOptions::default() + .try_into_reader_with_file_path(Some(FOODS_CSV.into()))? + .finish()?; + + let out = df.filter(&df.column("fats_g")?.as_materialized_series().gt(4)?)?; + + // This fails if all columns are not equal. + println!("{out}"); + + Ok(()) +} + +#[test] +fn test_parser() -> PolarsResult<()> { + let s = r#" + "sepal_length","sepal_width","petal_length","petal_width","variety" + 5.1,3.5,1.4,.2,"Setosa" + 4.9,3,1.4,.2,"Setosa" + 4.7,3.2,1.3,.2,"Setosa" + 4.6,3.1,1.5,.2,"Setosa" + 5,3.6,1.4,.2,"Setosa" + 5.4,3.9,1.7,.4,"Setosa" + 4.6,3.4,1.4,.3,"Setosa" +"#; + + let file = Cursor::new(s); + CsvReadOptions::default() + .with_infer_schema_length(Some(100)) + .with_has_header(true) + .with_ignore_errors(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + let s = r#" + "sepal_length","sepal_width","petal_length","petal_width","variety" + 5.1,3.5,1.4,.2,"Setosa" + 5.1,3.5,1.4,.2,"Setosa" + "#; + + let file = Cursor::new(s); + + // just checks if unwrap doesn't panic + CsvReadOptions::default() + // we also check if infer schema ignores errors + .with_infer_schema_length(Some(10)) + .with_has_header(true) + .with_ignore_errors(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + let s = r#""sepal_length","sepal_width","petal_length","petal_width","variety" + 5.1,3.5,1.4,.2,"Setosa" + 4.9,3,1.4,.2,"Setosa" + 4.7,3.2,1.3,.2,"Setosa" + 4.6,3.1,1.5,.2,"Setosa" + 5,3.6,1.4,.2,"Setosa" + 5.4,3.9,1.7,.4,"Setosa" + 4.6,3.4,1.4,.3,"Setosa" +"#; + + let file = Cursor::new(s); + let df = CsvReadOptions::default() + .with_infer_schema_length(Some(100)) + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + let col = df.column("variety").unwrap(); + assert_eq!(col.get(0)?, AnyValue::String("Setosa")); + assert_eq!(col.get(2)?, AnyValue::String("Setosa")); + + assert_eq!("sepal_length", df.get_columns()[0].name().as_str()); + assert_eq!(df.height(), 7); + + // test windows line endings + let s = "head_1,head_2\r\n1,2\r\n1,2\r\n1,2\r\n"; + + let file = Cursor::new(s); + let df = CsvReadOptions::default() + .with_infer_schema_length(Some(100)) + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + assert_eq!("head_1", df.get_columns()[0].name().as_str()); + assert_eq!(df.shape(), (3, 2)); + + // test windows line ending with 1 byte char column and no line endings for last line. + let s = "head_1\r\n1\r\n2\r\n3"; + + let file = Cursor::new(s); + let df = CsvReadOptions::default() + .with_infer_schema_length(Some(100)) + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + assert_eq!("head_1", df.get_columns()[0].name().as_str()); + assert_eq!(df.shape(), (3, 1)); + Ok(()) +} + +#[test] +fn test_tab_sep() { + let csv = br#"1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99217 Hospital observation care on day of discharge N 68 67 68 73.821029412 381.30882353 57.880294118 58.2125 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99218 Hospital observation care, typically 30 minutes N 19 19 19 100.88315789 476.94736842 76.795263158 77.469473684 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99220 Hospital observation care, typically 70 minutes N 26 26 26 188.11076923 1086.9230769 147.47923077 147.79346154 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99221 Initial hospital inpatient care, typically 30 minutes per day N 24 24 24 102.24 474.58333333 80.155 80.943333333 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99222 Initial hospital inpatient care, typically 50 minutes per day N 17 17 17 138.04588235 625 108.22529412 109.22 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99223 Initial hospital inpatient care, typically 70 minutes per day N 86 82 86 204.85395349 1093.5 159.25906977 161.78093023 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99232 Subsequent hospital inpatient care, typically 25 minutes per day N 360 206 360 73.565666667 360.57222222 57.670305556 58.038833333 +1003000126 ENKESHAFI ARDALAN M.D. M I 900 SETON DR CUMBERLAND 21502 MD US Internal Medicine Y F 99233 Subsequent hospital inpatient care, typically 35 minutes per day N 284 148 284 105.34971831 576.98943662 82.512992958 82.805774648 +"#.as_ref(); + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_infer_schema_length(Some(100)) + .with_has_header(false) + .with_ignore_errors(true) + .map_parse_options(|parse_options| parse_options.with_separator(b'\t')) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.shape(), (8, 26)) +} + +#[test] +fn test_projection() -> PolarsResult<()> { + let df = CsvReadOptions::default() + .with_projection(Some(vec![0, 2].into())) + .try_into_reader_with_file_path(Some(FOODS_CSV.into()))? + .finish()?; + let col_1 = df.select_at_idx(0).unwrap(); + assert_eq!(col_1.get(0)?, AnyValue::String("vegetables")); + assert_eq!(col_1.get(1)?, AnyValue::String("seafood")); + assert_eq!(col_1.get(2)?, AnyValue::String("meat")); + + let col_2 = df.select_at_idx(1).unwrap(); + assert_eq!(col_2.get(0)?, AnyValue::Float64(0.5)); + assert_eq!(col_2.get(1)?, AnyValue::Float64(5.0)); + assert_eq!(col_2.get(2)?, AnyValue::Float64(5.0)); + Ok(()) +} + +#[test] +fn test_missing_data() { + // missing data should not lead to parser error. + let csv = r#"column_1,column_2,column_3 +1,2,3 +1,,3 +"#; + + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish().unwrap(); + assert!( + df.column("column_1") + .unwrap() + .equals(&Column::new("column_1".into(), &[1_i64, 1])) + ); + assert!( + df.column("column_2") + .unwrap() + .equals_missing(&Column::new("column_2".into(), &[Some(2_i64), None])) + ); + assert!( + df.column("column_3") + .unwrap() + .equals(&Column::new("column_3".into(), &[3_i64, 3])) + ); +} + +#[test] +fn test_escape_comma() { + let csv = r#"column_1,column_2,column_3 +-86.64408227,"Autauga, Alabama, US",11 +-86.64408227,"Autauga, Alabama, US",12 +"#; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish().unwrap(); + assert_eq!(df.shape(), (2, 3)); + assert!( + df.column("column_3") + .unwrap() + .equals(&Column::new("column_3".into(), &[11_i64, 12])) + ); +} + +#[test] +fn test_escape_double_quotes() { + let csv = r#"column_1,column_2,column_3 +-86.64408227,"with ""double quotes"" US",11 +-86.64408227,"with ""double quotes followed"", by comma",12 +"#; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish().unwrap(); + assert_eq!(df.shape(), (2, 3)); + assert!(df.column("column_2").unwrap().equals(&Column::new( + "column_2".into(), + &[ + r#"with "double quotes" US"#, + r#"with "double quotes followed", by comma"# + ] + ))); +} + +#[test] +fn test_newline_in_custom_quote_char() { + // newline inside custom quote char (default is ") should parse correctly + let csv = r#"column_1,column_2 + 1,'foo + bar' + 2,'bar' +"#; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_quote_char(Some(b'\''))) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.shape(), (2, 2)); +} + +#[test] +fn test_escape_2() { + // this is harder than it looks. + // Fields: + // * hello + // * "," + // * " " + // * world + // * "!" + let csv = r#"hello,","," ",world,"!" +hello,","," ",world,"!" +hello,","," ",world,"!" +hello,","," ",world,"!" +"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .with_n_threads(Some(1)) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + for (col, val) in [ + ("column_1", "hello"), + ("column_2", ","), + ("column_3", " "), + ("column_4", "world"), + ("column_5", "!"), + ] { + assert!( + df.column(col) + .unwrap() + .equals(&Column::new(col.into(), &[val; 4])) + ); + } +} + +#[test] +fn test_very_long_utf8() { + let csv = r#"column_1,column_2,column_3 +-86.64408227,"Lorem Ipsum is simply dummy text of the printing and typesetting +industry. Lorem Ipsum has been the industry's standard dummy text ever since th +e 1500s, when an unknown printer took a galley of type and scrambled it to make +a type specimen book. It has survived not only five centuries, but also the leap +into electronic typesetting, remaining essentially unchanged. It was popularised +in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, +and more recently with desktop publishing software like Aldus PageMaker including +versions of Lorem Ipsum.",11 +"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + + assert!(df.column("column_2").unwrap().equals(&Column::new( + "column_2".into(), + &[ + r#"Lorem Ipsum is simply dummy text of the printing and typesetting +industry. Lorem Ipsum has been the industry's standard dummy text ever since th +e 1500s, when an unknown printer took a galley of type and scrambled it to make +a type specimen book. It has survived not only five centuries, but also the leap +into electronic typesetting, remaining essentially unchanged. It was popularised +in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, +and more recently with desktop publishing software like Aldus PageMaker including +versions of Lorem Ipsum."#, + ] + ))); +} + +#[test] +fn test_nulls_parser() { + // test it does not fail on the leading comma. + let csv = r#"id1,id2,id3,id4,id5,id6,v1,v2,v3 +id047,id023,id0000084849,90,96,35790,2,9,93.348148 +,id022,id0000031441,50,44,71525,3,11,81.013682 +id090,id048,id0000067778,24,2,51862,4,9, +"#; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(true) + .with_n_threads(Some(1)) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.shape(), (3, 9)); +} + +#[test] +fn test_new_line_escape() { + let s = r#""sepal_length","sepal_width","petal_length","petal_width","variety" + 5.1,3.5,1.4,.2,"Setosa + texts after new line character" + 4.9,3,1.4,.2,"Setosa" + "#; + + let file = Cursor::new(s); + CsvReadOptions::default() + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); +} + +#[test] +fn test_new_line_escape_on_header() { + let s = r#""length","header with +new line character","width" +5.1,3.5,1.4 +"#; + let file: Cursor<&str> = Cursor::new(s); + let df = CsvReadOptions::default() + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.shape(), (1, 3)); + assert_eq!( + df.get_column_names(), + &["length", "header with\nnew line character", "width"] + ); +} + +#[test] +fn test_quoted_numeric() { + // CSV fields may be quoted + let s = r#""foo","bar" +"4.9","3" +"1.4","2" +"#; + + let file = Cursor::new(s); + let df = CsvReadOptions::default() + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.column("bar").unwrap().dtype(), &DataType::Int64); + assert_eq!(df.column("foo").unwrap().dtype(), &DataType::Float64); +} + +#[test] +fn test_empty_bytes_to_dataframe() { + let fields = vec![Field::new("test_field".into(), DataType::String)]; + let schema = Schema::from_iter(fields); + let file = Cursor::new(vec![]); + + let result = CsvReadOptions::default() + .with_has_header(false) + .with_columns(Some(schema.iter_names_cloned().collect())) + .with_schema(Some(Arc::new(schema))) + .into_reader_with_file_handle(file) + .finish(); + + assert!(result.is_ok()) +} + +#[test] +fn test_carriage_return() { + let csv = "\"foo\",\"bar\"\r\n\"158252579.00\",\"7.5800\"\r\n\"158252579.00\",\"7.5800\"\r\n"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(true) + .with_n_threads(Some(1)) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.shape(), (2, 2)); +} + +#[test] +fn test_missing_value() { + let csv = r#"foo,bar,ham +1,2,3 +1,2,3 +1,2 +"#; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(true) + .with_schema(Some(Arc::new(Schema::from_iter([ + Field::new("foo".into(), DataType::UInt32), + Field::new("bar".into(), DataType::UInt32), + Field::new("ham".into(), DataType::UInt32), + ])))) + .into_reader_with_file_handle(file) + .finish() + .unwrap(); + assert_eq!(df.column("ham").unwrap().len(), 3) +} + +#[test] +#[cfg(feature = "temporal")] +fn test_with_dtype() -> PolarsResult<()> { + // test if timestamps can be parsed as Datetime + let csv = r#"a,b,c,d,e +AUDCAD,1616455919,0.91212,0.95556,1 +AUDCAD,1616455920,0.92212,0.95556,1 +AUDCAD,1616455921,0.96212,0.95666,1 +"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(true) + .with_schema_overwrite(Some(Arc::new(Schema::from_iter([Field::new( + "b".into(), + DataType::Datetime(TimeUnit::Nanoseconds, None), + )])))) + .with_ignore_errors(true) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!( + df.dtypes(), + &[ + DataType::String, + DataType::Datetime(TimeUnit::Nanoseconds, None), + DataType::Float64, + DataType::Float64, + DataType::Int64 + ] + ); + Ok(()) +} + +#[test] +fn test_skip_rows() -> PolarsResult<()> { + let csv = r"#doc source pos typeindex type topic +#alpha : 25.0 25.0 +#beta : 0.1 +0 NA 0 0 57 0 +0 NA 0 0 57 0 +0 NA 5 5 513 0 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .with_skip_rows(3) + .map_parse_options(|parse_options| parse_options.with_separator(b' ')) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.height(), 3); + Ok(()) +} + +#[test] +fn test_projection_idx() -> PolarsResult<()> { + let csv = r"#0 NA 0 0 57 0 +0 NA 0 0 57 0 +0 NA 5 5 513 0 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .with_projection(Some(Arc::new(vec![4, 5]))) + .map_parse_options(|parse_options| parse_options.with_separator(b' ')) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.width(), 2); + + // this should give out of bounds error + let file = Cursor::new(csv); + let out = CsvReadOptions::default() + .with_has_header(false) + .with_projection(Some(Arc::new(vec![4, 6]))) + .map_parse_options(|parse_options| parse_options.with_separator(b' ')) + .into_reader_with_file_handle(file) + .finish(); + + assert!(out.is_err()); + Ok(()) +} + +#[test] +fn test_missing_fields() -> PolarsResult<()> { + let csv = r"1,2,3,4,5 +1,2,3 +1,2,3,4,5 +1,3,5 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .into_reader_with_file_handle(file) + .finish()?; + + use polars_core::df; + let expect = df![ + "column_1" => [1, 1, 1, 1], + "column_2" => [2, 2, 2, 3], + "column_3" => [3, 3, 3, 5], + "column_4" => [Some(4), None, Some(4), None], + "column_5" => [Some(5), None, Some(5), None] + ]?; + assert!(df.equals_missing(&expect)); + Ok(()) +} + +#[test] +fn test_comment_lines() -> PolarsResult<()> { + let csv = r"1,2,3,4,5 +# this is a comment +1,2,3,4,5 +# this is also a comment +1,2,3,4,5 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .map_parse_options(|parse_options| parse_options.with_comment_prefix(Some("#"))) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (3, 5)); + + let csv = r"!str,2,3,4,5 +!#& this is a comment +!str,2,3,4,5 +!#& this is also a comment +!str,2,3,4,5 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .map_parse_options(|parse_options| parse_options.with_comment_prefix(Some("!#&"))) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (3, 5)); + + let csv = r"a,b,c,d,e +1,2,3,4,5 +% this is a comment +1,2,3,4,5 +% this is also a comment +1,2,3,4,5 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(true) + .map_parse_options(|parse_options| parse_options.with_comment_prefix(Some("%"))) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (3, 5)); + + Ok(()) +} + +#[test] +fn test_null_values_argument() -> PolarsResult<()> { + let csv = r"1,a,foo +null-value,b,bar +3,null-value,ham +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| { + parse_options.with_null_values(Some(NullValues::AllColumnsSingle("null-value".into()))) + }) + .into_reader_with_file_handle(file) + .finish()?; + assert!(df.get_columns()[0].null_count() > 0); + Ok(()) +} + +#[test] +fn test_no_newline_at_end() -> PolarsResult<()> { + let csv = r"a,b +foo,foo +bar,bar"; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + + use polars_core::df; + let expect = df![ + "a" => ["foo", "bar"], + "b" => ["foo", "bar"] + ]?; + assert!(df.equals(&expect)); + Ok(()) +} + +#[test] +#[cfg(feature = "temporal")] +fn test_automatic_datetime_parsing() -> PolarsResult<()> { + let csv = r"timestamp,open,high +2021-01-01 00:00:00,0.00305500,0.00306000 +2021-01-01 00:15:00,0.00298800,0.00300400 +2021-01-01 00:30:00,0.00298300,0.00300100 +2021-01-01 00:45:00,0.00299400,0.00304000 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .into_reader_with_file_handle(file) + .finish()?; + + let ts = df.column("timestamp")?; + assert_eq!( + ts.dtype(), + &DataType::Datetime(TimeUnit::Microseconds, None) + ); + assert_eq!(ts.null_count(), 0); + + Ok(()) +} + +#[test] +#[cfg(feature = "temporal")] +fn test_automatic_datetime_parsing_default_formats() -> PolarsResult<()> { + let csv = r"ts_dmy,ts_dmy_f,ts_dmy_p +01/01/2021 00:00:00,31-01-2021T00:00:00.123,31-01-2021 11:00 +01/01/2021 00:15:00,31-01-2021T00:15:00.123,31-01-2021 01:00 +01/01/2021 00:30:00,31-01-2021T00:30:00.123,31-01-2021 01:15 +01/01/2021 00:45:00,31-01-2021T00:45:00.123,31-01-2021 01:30 +"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .into_reader_with_file_handle(file) + .finish()?; + + for col in df.get_column_names() { + let ts = df.column(col)?; + assert_eq!( + ts.dtype(), + &DataType::Datetime(TimeUnit::Microseconds, None) + ); + assert_eq!(ts.null_count(), 0); + } + + Ok(()) +} + +#[test] +fn test_no_quotes() -> PolarsResult<()> { + let rolling_stones = r#"linenum,last_name,first_name +1,Jagger,Mick +2,O"Brian,Mary +3,Richards,Keith +4,L"Etoile,Bennet +5,Watts,Charlie +6,Smith,D"Shawn +7,Wyman,Bill +8,Woods,Ron +9,Jones,Brian +"#; + + let file = Cursor::new(rolling_stones); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_quote_char(None)) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (9, 3)); + + Ok(()) +} + +#[test] +fn test_utf8() -> PolarsResult<()> { + // first part is valid ascii. later we have removed some bytes from the emoji. + let invalid_utf8 = [ + 111, 10, 98, 97, 114, 10, 104, 97, 109, 10, 115, 112, 97, 109, 10, 106, 97, 109, 10, 107, + 97, 109, 10, 108, 97, 109, 10, 207, 128, 10, 112, 97, 109, 10, 115, 116, 97, 109, 112, 10, + 240, 159, 137, 10, 97, 115, 99, 105, 105, 10, 240, 159, 144, 172, 10, 99, 105, 97, 111, + ]; + let file = Cursor::new(invalid_utf8); + assert!(CsvReader::new(file).finish().is_err()); + + Ok(()) +} + +#[test] +fn test_header_inference() -> PolarsResult<()> { + let csv = r#"not_a_header,really,even_if,it_looks_like_one +1,2,3,4 +4,3,2,1 +"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.dtypes(), vec![DataType::String; 4]); + Ok(()) +} + +#[test] +fn test_header_with_comments() -> PolarsResult<()> { + let csv = "# ignore me\na,b,c\nd,e,f"; + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_comment_prefix(Some("#"))) + .into_reader_with_file_handle(file) + .finish()?; + // 1 row. + assert_eq!(df.shape(), (1, 3)); + + Ok(()) +} + +#[test] +#[cfg(feature = "temporal")] +fn test_ignore_parse_dates() -> PolarsResult<()> { + // if parse dates is set, a given schema should still prevail above date parsing. + let csv = r#"a,b,c +1,i,16200126 +2,j,16250130 +3,k,17220012 +4,l,17290009"#; + + use DataType::*; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_dtype_overwrite(Some(vec![String, String, String].into())) + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.dtypes(), &[String, String, String]); + Ok(()) +} + +#[test] +fn test_projection_and_quoting() -> PolarsResult<()> { + let csv = "a,b,c,d +A1,'B1',C1,1 +A2,\"B2\",C2,2 +A3,\"B3\",C3,3 +A3,\"B4_\"\"with_embedded_double_quotes\"\"\",C4,4"; + + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + assert_eq!(df.shape(), (4, 4)); + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_n_threads(Some(1)) + .with_projection(Some(vec![0, 2].into())) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (4, 2)); + + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_n_threads(Some(1)) + .with_projection(Some(vec![1].into())) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (4, 1)); + + Ok(()) +} + +#[test] +fn test_infer_schema_0_rows() -> PolarsResult<()> { + let csv = r#"a,b,c,d +1,a,1.0,true +1,a,1.0,false +"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_infer_schema_length(Some(0)) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!( + df.dtypes(), + &[ + DataType::String, + DataType::String, + DataType::String, + DataType::String + ] + ); + Ok(()) +} + +#[test] +fn test_infer_schema_eol() -> PolarsResult<()> { + // no eol after header + let no_eol = "colx,coly\nabcdef,1234"; + let file = Cursor::new(no_eol); + let df = CsvReader::new(file).finish()?; + assert_eq!(df.dtypes(), &[DataType::String, DataType::Int64]); + Ok(()) +} + +#[test] +fn test_whitespace_separators() -> PolarsResult<()> { + let tsv = "\ta\tb\tc\n1\ta1\tb1\tc1\n2\ta2\tb2\tc2\n".to_string(); + + let contents = vec![ + (tsv.replace('\t', " "), b' '), + (tsv.replace('\t', "-"), b'-'), + (tsv, b'\t'), + ]; + + for (content, sep) in contents { + let file = Cursor::new(&content); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_separator(sep)) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.shape(), (2, 4)); + assert_eq!(df.get_column_names(), &["", "a", "b", "c"]); + } + + Ok(()) +} + +#[test] +fn test_scientific_floats() -> PolarsResult<()> { + let csv = r#"foo,bar +10000001,1e-5 +10000002,.04 +"#; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + assert_eq!(df.shape(), (2, 2)); + assert_eq!(df.dtypes(), &[DataType::Int64, DataType::Float64]); + + Ok(()) +} + +#[test] +fn test_tsv_header_offset() -> PolarsResult<()> { + let csv = "foo\tbar\n\t1000011\t1\n\t1000026\t2\n\t1000949\t2"; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| { + parse_options + .with_truncate_ragged_lines(true) + .with_separator(b'\t') + }) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.shape(), (3, 2)); + assert_eq!(df.dtypes(), &[DataType::String, DataType::Int64]); + let a = df.column("foo")?; + let a = a.str()?; + assert_eq!(a.get(0), None); + + Ok(()) +} + +#[test] +fn test_null_values_infer_schema() -> PolarsResult<()> { + let csv = r#"a,b +1,2 +3,NA +5,6"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| { + parse_options.with_null_values(Some(NullValues::AllColumnsSingle("NA".into()))) + }) + .into_reader_with_file_handle(file) + .finish()?; + let expected = &[DataType::Int64, DataType::Int64]; + assert_eq!(df.dtypes(), expected); + Ok(()) +} + +#[test] +fn test_comma_separated_field_in_tsv() -> PolarsResult<()> { + let csv = "first\tsecond\n1\t2.3,2.4\n3\t4.5,4.6\n"; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_separator(b'\t')) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.dtypes(), &[DataType::Int64, DataType::String]); + Ok(()) +} + +#[test] +fn test_quoted_projection() -> PolarsResult<()> { + let csv = r#"c1,c2,c3,c4,c5 +a,"b",c,d,1 +a,"b",c,d,1 +a,b,c,d,1"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_projection(Some(Arc::new(vec![1, 4]))) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (3, 2)); + + Ok(()) +} + +#[test] +fn test_last_line_incomplete() -> PolarsResult<()> { + // test a last line that is incomplete and not finishes with a new line char + let csv = "b5bbf310dffe3372fd5d37a18339fea5,6a2752ffad059badb5f1f3c7b9e4905d,-2,0.033191,811.619 0.487341,16,GGTGTGAAATTTCACACC,TTTAATTATAATTAAG,+ +b5bbf310dffe3372fd5d37a18339fea5,e3fd7b95be3453a34361da84f815687d,-2,0.0335936,821.465 0.490834,1"; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (2, 9)); + Ok(()) +} + +#[test] +fn test_quoted_bool_ints() -> PolarsResult<()> { + let csv = r#"foo,bar,baz +1,"4","false" +3,"5","false" +5,"6","true" +"#; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let expected = df![ + "foo" => [1, 3, 5], + "bar" => [4, 5, 6], + "baz" => [false, false, true], + ]?; + assert!(df.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_skip_inference() -> PolarsResult<()> { + let csv = r#"metadata +line +foo,bar +1,2 +3,4 +5,6 +"#; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_skip_rows(2) + .into_reader_with_file_handle(file.clone()) + .finish()?; + assert_eq!(df.get_column_names(), &["foo", "bar"]); + assert_eq!(df.shape(), (3, 2)); + let df = CsvReadOptions::default() + .with_skip_rows(2) + .with_skip_rows_after_header(2) + .into_reader_with_file_handle(file.clone()) + .finish()?; + assert_eq!(df.get_column_names(), &["foo", "bar"]); + assert_eq!(df.shape(), (1, 2)); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_truncate_ragged_lines(true)) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (5, 1)); + + Ok(()) +} + +#[test] +fn test_with_row_index() -> PolarsResult<()> { + let df = CsvReadOptions::default() + .with_row_index(Some(RowIndex { + name: "rc".into(), + offset: 0, + })) + .try_into_reader_with_file_path(Some(FOODS_CSV.into()))? + .finish()?; + let rc = df.column("rc")?; + assert_eq!( + rc.idx()?.into_no_null_iter().collect::>(), + (0 as IdxSize..27).collect::>() + ); + let df = CsvReadOptions::default() + .with_row_index(Some(RowIndex { + name: "rc_2".into(), + offset: 10, + })) + .try_into_reader_with_file_path(Some(FOODS_CSV.into()))? + .finish()?; + let rc = df.column("rc_2")?; + assert_eq!( + rc.idx()?.into_no_null_iter().collect::>(), + (10 as IdxSize..37).collect::>() + ); + Ok(()) +} + +#[test] +fn test_empty_string_cols() -> PolarsResult<()> { + let csv = "\nabc\n\nxyz\n"; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .into_reader_with_file_handle(file) + .finish()?; + let s = df.column("column_1")?; + let ca = s.str()?; + assert_eq!( + ca.iter().collect::>(), + &[None, Some("abc"), None, Some("xyz")] + ); + + let csv = ",\nabc,333\n,666\nxyz,999"; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(false) + .into_reader_with_file_handle(file) + .finish()?; + let expected = df![ + "column_1" => [None, Some("abc"), None, Some("xyz")], + "column_2" => [None, Some(333i64), Some(666), Some(999)] + ]?; + assert!(df.equals_missing(&expected)); + Ok(()) +} + +#[test] +fn test_empty_col_names() -> PolarsResult<()> { + let csv = "a,b,c\n1,2,3"; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let expected = df![ + "a" => [1i64], + "b" => [2i64], + "c" => [3i64] + ]?; + assert!(df.equals(&expected)); + + let csv = "a,,c\n1,2,3"; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let expected = df![ + "a" => [1i64], + "" => [2i64], + "c" => [3i64] + ]?; + assert!(df.equals(&expected)); + + let csv = "a,b,\n1,2,3"; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let expected = df![ + "a" => [1i64], + "b" => [2i64], + "" => [3i64] + ]?; + assert!(df.equals(&expected)); + + let csv = "a,b,,\n1,2,3"; + let file = Cursor::new(csv); + let df_result = CsvReader::new(file).finish()?; + assert_eq!(df_result.shape(), (1, 4)); + + let csv = "a,b\n1,2,3"; + let file = Cursor::new(csv); + let df_result = CsvReader::new(file).finish(); + assert!(df_result.is_err()); + Ok(()) +} + +#[test] +fn test_trailing_empty_string_cols() -> PolarsResult<()> { + let csv = "colx\nabc\nxyz\n\"\""; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let col = df.column("colx")?; + let col = col.str()?; + assert_eq!( + col.into_no_null_iter().collect::>(), + &["abc", "xyz", ""] + ); + + let csv = "colx,coly\nabc,def\nxyz,mno\n,"; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + + assert_eq!( + df.get(1).unwrap(), + &[AnyValue::String("xyz"), AnyValue::String("mno")] + ); + assert_eq!(df.get(2).unwrap(), &[AnyValue::Null, AnyValue::Null]); + + Ok(()) +} + +#[test] +fn test_escaping_quotes() -> PolarsResult<()> { + let csv = "a\n\"\"\"\""; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let col = df.column("a")?; + let col = col.str()?; + assert_eq!(col.into_no_null_iter().collect::>(), &["\""]); + Ok(()) +} + +#[test] +fn test_header_only() -> PolarsResult<()> { + let csv = "a,b,c"; + let file = Cursor::new(csv); + + // no header + let df = CsvReadOptions::default() + .with_has_header(false) + .into_reader_with_file_handle(file) + .finish()?; + assert_eq!(df.shape(), (1, 3)); + + // has header + for csv in &["x,y,z", "x,y,z\n"] { + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.shape(), (0, 3)); + assert_eq!( + df.dtypes(), + &[DataType::String, DataType::String, DataType::String] + ); + } + + Ok(()) +} + +#[test] +fn test_empty_csv() { + let csv = ""; + let file = Cursor::new(csv); + for h in [true, false] { + assert!(matches!( + CsvReadOptions::default() + .with_has_header(h) + .into_reader_with_file_handle(file.clone()) + .finish(), + Err(PolarsError::NoData(_)) + )) + } +} + +#[test] +fn test_try_parse_dates_empty() -> PolarsResult<()> { + let csv = "date +1745-04-02 +1742-03-21 +1743-06-16 +1730-07-22 + +1739-03-16 +"; + let file = Cursor::new(csv); + + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.dtypes(), &[DataType::Date]); + assert_eq!(df.column("date")?.null_count(), 1); + Ok(()) +} + +#[test] +fn test_try_parse_dates_3380() -> PolarsResult<()> { + let csv = "lat;lon;validdate;t_2m:C;precip_1h:mm +46.685;7.953;2022-05-10T07:07:12Z;6.1;0.00 +46.685;7.953;2022-05-10T08:07:12Z;8.8;0.00"; + let file = Cursor::new(csv); + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| { + parse_options + .with_separator(b';') + .with_try_parse_dates(true) + }) + .into_reader_with_file_handle(file) + .finish()?; + + assert_eq!(df.column("validdate")?.null_count(), 0); + Ok(()) +} + +#[test] +fn test_leading_whitespace_with_quote() -> PolarsResult<()> { + let csv = r#" +"ABC","DEF", +"24.5"," 4.1" +"#; + let file = Cursor::new(csv); + let df = CsvReader::new(file).finish()?; + let col_1 = df.column("ABC").unwrap(); + let col_2 = df.column("DEF").unwrap(); + assert_eq!(col_1.get(0)?, AnyValue::Float64(24.5)); + assert_eq!(col_2.get(0)?, AnyValue::String(" 4.1")); + Ok(()) +} + +#[test] +fn test_read_io_reader() { + let path = "../../examples/datasets/foods1.csv"; + let file = std::fs::File::open(path).unwrap(); + let mut reader = CsvReadOptions::default() + .with_chunk_size(5) + .try_into_reader_with_file_path(Some(path.into())) + .unwrap(); + + let mut reader = reader.batched_borrowed().unwrap(); + let batches = reader.next_batches(5).unwrap().unwrap(); + assert_eq!(batches.len(), 5); + let df = concat_df(&batches).unwrap(); + assert!(df.height() > 0); + let expected = CsvReader::new(file) + .finish() + .unwrap() + .head(Some(df.height())); + assert_eq!(&df, &expected); +} diff --git a/crates/polars/tests/it/io/ipc.rs b/crates/polars/tests/it/io/ipc.rs new file mode 100644 index 000000000000..b8930f12baf7 --- /dev/null +++ b/crates/polars/tests/it/io/ipc.rs @@ -0,0 +1,153 @@ +use std::io::{Cursor, Seek, SeekFrom}; + +use polars::prelude::*; + +#[test] +fn test_ipc_compression_variadic_buffers() { + let mut df = df![ + "foo" => std::iter::repeat_n("Home delivery vat 24 %",3).collect::>() + ] + .unwrap(); + + let mut file = std::io::Cursor::new(vec![]); + IpcWriter::new(&mut file) + .with_compression(Some(IpcCompression::LZ4)) + .with_compat_level(CompatLevel::newest()) + .finish(&mut df) + .unwrap(); + + file.seek(SeekFrom::Start(0)).unwrap(); + let out = IpcReader::new(file).finish().unwrap(); + + assert_eq!(out.shape(), (3, 1)); +} + +#[cfg(test)] +pub(crate) fn create_df() -> DataFrame { + let s0 = Column::new("days".into(), [0, 1, 2, 3, 4].as_ref()); + let s1 = Column::new("temp".into(), [22.1, 19.9, 7., 2., 3.].as_ref()); + DataFrame::new(vec![s0, s1]).unwrap() +} + +#[test] +fn write_and_read_ipc() { + // Vec : Write + Read + // Cursor>: Seek + let mut buf: Cursor> = Cursor::new(Vec::new()); + let mut df = create_df(); + + IpcWriter::new(&mut buf) + .finish(&mut df) + .expect("ipc writer"); + + buf.set_position(0); + + let df_read = IpcReader::new(buf).finish().unwrap(); + assert!(df.equals(&df_read)); +} + +#[test] +fn test_read_ipc_with_projection() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + + IpcWriter::new(&mut buf) + .finish(&mut df) + .expect("ipc writer"); + buf.set_position(0); + + let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + let df_read = IpcReader::new(buf) + .with_projection(Some(vec![1, 2])) + .finish() + .unwrap(); + assert_eq!(df_read.shape(), (3, 2)); + df_read.equals(&expected); +} + +#[test] +fn test_read_ipc_with_columns() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + + IpcWriter::new(&mut buf) + .finish(&mut df) + .expect("ipc writer"); + buf.set_position(0); + + let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + let df_read = IpcReader::new(buf) + .with_columns(Some(vec!["c".to_string(), "b".to_string()])) + .finish() + .unwrap(); + df_read.equals(&expected); + + for compat_level in [0, 1].map(|level| CompatLevel::with_level(level).unwrap()) { + let mut buf: Cursor> = Cursor::new(Vec::new()); + let mut df = df![ + "letters" => ["x", "y", "z"], + "ints" => [123, 456, 789], + "floats" => [4.5, 10.0, 10.0], + "other" => ["misc", "other", "value"], + ] + .unwrap(); + IpcWriter::new(&mut buf) + .with_compat_level(compat_level) + .finish(&mut df) + .expect("ipc writer"); + buf.set_position(0); + let expected = df![ + "letters" => ["x", "y", "z"], + "floats" => [4.5, 10.0, 10.0], + "other" => ["misc", "other", "value"], + "ints" => [123, 456, 789], + ] + .unwrap(); + let df_read = IpcReader::new(&mut buf) + .with_columns(Some(vec![ + "letters".to_string(), + "floats".to_string(), + "other".to_string(), + "ints".to_string(), + ])) + .finish() + .unwrap(); + assert!(df_read.equals(&expected)); + } +} + +#[test] +fn test_write_with_compression() { + let mut df = create_df(); + + let compressions = vec![None, Some(IpcCompression::LZ4), Some(IpcCompression::ZSTD)]; + + for compression in compressions.into_iter() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + IpcWriter::new(&mut buf) + .with_compression(compression) + .finish(&mut df) + .expect("ipc writer"); + buf.set_position(0); + + let df_read = IpcReader::new(buf) + .finish() + .unwrap_or_else(|_| panic!("IPC reader: {:?}", compression)); + assert!(df.equals(&df_read)); + } +} + +#[test] +fn write_and_read_ipc_empty_series() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + let chunked_array = Float64Chunked::new("empty".into(), &[0_f64; 0]); + let mut df = DataFrame::new(vec![chunked_array.into_column()]).unwrap(); + IpcWriter::new(&mut buf) + .finish(&mut df) + .expect("ipc writer"); + + buf.set_position(0); + + let df_read = IpcReader::new(buf).finish().unwrap(); + assert!(df.equals(&df_read)); +} diff --git a/crates/polars/tests/it/io/ipc_stream.rs b/crates/polars/tests/it/io/ipc_stream.rs new file mode 100644 index 000000000000..a4440c55ac2d --- /dev/null +++ b/crates/polars/tests/it/io/ipc_stream.rs @@ -0,0 +1,159 @@ +#[cfg(test)] +mod test { + use std::io::Cursor; + + use polars_core::prelude::*; + use polars_core::{assert_df_eq, df}; + use polars_io::ipc::*; + use polars_io::{SerReader, SerWriter}; + + use crate::io::create_df; + + fn create_ipc_stream(mut df: DataFrame) -> Cursor> { + let mut buf: Cursor> = Cursor::new(Vec::new()); + + IpcStreamWriter::new(&mut buf) + .finish(&mut df) + .expect("failed to write ICP stream"); + + buf.set_position(0); + + buf + } + + #[test] + fn write_and_read_ipc_stream() { + let df = create_df(); + + let reader = create_ipc_stream(df); + + let actual = IpcStreamReader::new(reader).finish().unwrap(); + + let expected = create_df(); + assert_df_eq!(actual, expected); + } + + #[test] + fn test_read_ipc_stream_with_projection() { + let df = df!( + "a" => [1], + "b" => [2], + "c" => [3], + ) + .unwrap(); + + let reader = create_ipc_stream(df); + + let actual = IpcStreamReader::new(reader) + .with_projection(Some(vec![1, 2])) + .finish() + .unwrap(); + + let expected = df!( + "b" => [2], + "c" => [3], + ) + .unwrap(); + assert_df_eq!(actual, expected); + } + + #[test] + fn test_read_ipc_stream_with_columns() { + let df = df!( + "a" => [1], + "b" => [2], + "c" => [3], + ) + .unwrap(); + + let reader = create_ipc_stream(df); + + let actual = IpcStreamReader::new(reader) + .with_columns(Some(vec!["c".to_string(), "b".to_string()])) + .finish() + .unwrap(); + + let expected = df!( + "c" => [3], + "b" => [2], + ) + .unwrap(); + assert_df_eq!(actual, expected); + } + + #[test] + fn test_read_ipc_stream_with_columns_reorder() { + let df = df![ + "a" => [1], + "b" => [2], + "c" => [3], + ] + .unwrap(); + + let reader = create_ipc_stream(df); + + let actual = IpcStreamReader::new(reader) + .with_columns(Some(vec![ + "b".to_string(), + "c".to_string(), + "a".to_string(), + ])) + .finish() + .unwrap(); + + let expected = df![ + "b" => [2], + "c" => [3], + "a" => [1], + ] + .unwrap(); + assert_df_eq!(actual, expected); + } + + #[test] + fn test_read_invalid_stream() { + let buf: Cursor> = Cursor::new(Vec::new()); + assert!(IpcStreamReader::new(buf.clone()).arrow_schema().is_err()); + assert!(IpcStreamReader::new(buf).finish().is_err()); + } + + #[test] + fn test_write_with_lz4_compression() { + test_write_with_compression(IpcCompression::LZ4); + } + + #[test] + fn test_write_with_zstd_compression() { + test_write_with_compression(IpcCompression::ZSTD); + } + + fn test_write_with_compression(compression: IpcCompression) { + let reader = { + let mut writer: Cursor> = Cursor::new(Vec::new()); + IpcStreamWriter::new(&mut writer) + .with_compression(Some(compression)) + .finish(&mut create_df()) + .unwrap(); + writer.set_position(0); + writer + }; + + let actual = IpcStreamReader::new(reader).finish().unwrap(); + assert_df_eq!(actual, create_df()); + } + + #[test] + fn write_and_read_ipc_stream_empty_series() { + fn df() -> DataFrame { + DataFrame::new(vec![ + Float64Chunked::new("empty".into(), &[0_f64; 0]).into_column(), + ]) + .unwrap() + } + + let reader = create_ipc_stream(df()); + + let actual = IpcStreamReader::new(reader).finish().unwrap(); + assert_df_eq!(df(), actual); + } +} diff --git a/crates/polars/tests/it/io/json.rs b/crates/polars/tests/it/io/json.rs new file mode 100644 index 000000000000..0105d53eecd3 --- /dev/null +++ b/crates/polars/tests/it/io/json.rs @@ -0,0 +1,187 @@ +use std::io::Cursor; +use std::num::NonZeroUsize; + +use super::*; + +#[test] +fn read_json() { + let basic_json = r#"{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":-10, "b":-3.5, "c":true, "d":"4"} +{"a":2, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":7, "b":-3.5, "c":true, "d":"4"} +{"a":1, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":5, "b":-3.5, "c":true, "d":"4"} +{"a":1, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":1, "b":-3.5, "c":true, "d":"4"} +{"a":100000000000000, "b":0.6, "c":false, "d":"text"} +"#; + let file = Cursor::new(basic_json); + let df = JsonReader::new(file) + .infer_schema_len(NonZeroUsize::new(3)) + .with_json_format(JsonFormat::JsonLines) + .with_batch_size(NonZeroUsize::new(3).unwrap()) + .finish() + .unwrap(); + assert_eq!("a", df.get_columns()[0].name().as_str()); + assert_eq!("d", df.get_columns()[3].name().as_str()); + assert_eq!((12, 4), df.shape()); +} +#[test] +fn read_json_with_whitespace() { + let basic_json = r#"{ "a":1, "b":2.0, "c" :false , "d":"4"} +{"a":-10, "b":-3.5, "c":true, "d":"4"} +{"a":2, "b":0.6, "c":false, "d":"text" } +{"a":1, "b":2.0, "c":false, "d":"4"} + + +{"a": 7, "b":-3.5, "c":true, "d":"4"} +{"a":1, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d" :"4"} +{"a":5, "b":-3.5, "c":true , "d":"4"} + +{"a":1, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":1, "b":32.5, "c":false, "d":"99"} +{ "a":100000000000000, "b":0.6, "c":false, "d":"text"}"#; + let file = Cursor::new(basic_json); + let df = JsonReader::new(file) + .infer_schema_len(NonZeroUsize::new(3)) + .with_json_format(JsonFormat::JsonLines) + .with_batch_size(NonZeroUsize::new(3).unwrap()) + .finish() + .unwrap(); + assert_eq!("a", df.get_columns()[0].name().as_str()); + assert_eq!("d", df.get_columns()[3].name().as_str()); + assert_eq!((12, 4), df.shape()); +} +#[test] +fn read_json_with_escapes() { + let escaped_json = r#"{"id": 1, "text": "\""} + {"text": "\n{\n\t\t\"inner\": \"json\n}\n", "id": 10} + {"id": 0, "text":"\"","date":"2013-08-03 15:17:23"} + {"id": 1, "text":"\"123\"","date":"2009-05-19 21:07:53"} + {"id": 2, "text":"/....","date":"2009-05-19 21:07:53"} + {"id": 3, "text":"\n\n..","date":"2"} + {"id": 4, "text":"\"'/\n...","date":"2009-05-19 21:07:53"} + {"id": 5, "text":".h\"h1hh\\21hi1e2emm...","date":"2009-05-19 21:07:53"} + {"id": 6, "text":"xxxx....","date":"2009-05-19 21:07:53"} + {"id": 7, "text":".\"quoted text\".","date":"2009-05-19 21:07:53"} + +"#; + let file = Cursor::new(escaped_json); + let df = JsonLineReader::new(file) + .infer_schema_len(NonZeroUsize::new(6)) + .finish() + .unwrap(); + assert_eq!("id", df.get_columns()[0].name().as_str()); + assert_eq!( + AnyValue::String("\""), + df.column("text").unwrap().get(0).unwrap() + ); + assert_eq!("text", df.get_columns()[1].name().as_str()); + assert_eq!((10, 3), df.shape()); +} + +#[test] +fn read_unordered_json() { + let unordered_json = r#"{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":-10, "b":-3.5, "c":true, "d":"4"} +{"a":2, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":7, "b":-3.5, "c":true, "d":"4"} +{"a":1, "b":0.6, "c":false, "d":"text"} +{"d":"1", "c":false, "d":"4", "b":2.0} +{"b":-3.5, "c":true, "d":"4", "a":5} +{"d":"text", "a":1, "c":false, "b":0.6} +{"a":1, "b":2.0, "c":false, "d":"4"} +{"a":1, "b":-3.5, "c":true, "d":"4"} +{"a":100000000000000, "b":0.6, "c":false, "d":"text"} +"#; + let file = Cursor::new(unordered_json); + let df = JsonReader::new(file) + .infer_schema_len(NonZeroUsize::new(3)) + .with_json_format(JsonFormat::JsonLines) + .with_batch_size(NonZeroUsize::new(3).unwrap()) + .finish() + .unwrap(); + assert_eq!("a", df.get_columns()[0].name().as_str()); + assert_eq!("d", df.get_columns()[3].name().as_str()); + assert_eq!((12, 4), df.shape()); +} + +#[test] +fn read_ndjson_with_trailing_newline() { + let data = r#"{"Column1":"Value1"} +"#; + + let file = Cursor::new(data); + let df = JsonReader::new(file) + .with_json_format(JsonFormat::JsonLines) + .finish() + .unwrap(); + + let expected = df! { + "Column1" => ["Value1"] + } + .unwrap(); + assert!(expected.equals(&df)); +} +#[test] +#[cfg(feature = "dtype-struct")] +fn test_read_ndjson_iss_5875() { + let jsonlines = r#" + {"struct": {"int_inner": [1, 2, 3], "float_inner": 5.0, "str_inner": ["a", "b", "c"]}} + {"struct": {"int_inner": [4, 5, 6]}, "float": 4.0} + "#; + let cursor = Cursor::new(jsonlines); + + let df = JsonLineReader::new(cursor).finish(); + assert!(df.is_ok()); + + let field_int_inner = Field::new( + "int_inner".into(), + DataType::List(Box::new(DataType::Int64)), + ); + let field_float_inner = Field::new("float_inner".into(), DataType::Float64); + let field_str_inner = Field::new( + "str_inner".into(), + DataType::List(Box::new(DataType::String)), + ); + + let mut schema = Schema::default(); + schema.with_column( + "struct".into(), + DataType::Struct(vec![field_int_inner, field_float_inner, field_str_inner]), + ); + schema.with_column("float".into(), DataType::Float64); + + assert_eq!(&schema, &(**df.unwrap().schema())); +} + +#[test] +#[cfg(feature = "dtype-struct")] +fn test_read_ndjson_iss_5875_part3() { + let jsonlines = r#" + {"key1":"value1", "key2": "value2", "key3": {"k1": 2, "k3": "value5", "k10": 5}} + {"key1":"value5", "key2": "value4", "key3": {"k1": 2, "k5": "value5", "k10": 4}} + {"key1":"value6", "key3": {"k1": 5, "k3": "value5"}}"#; + + let cursor = Cursor::new(jsonlines); + + let df = JsonLineReader::new(cursor).finish(); + assert!(df.is_ok()); +} + +#[test] +#[cfg(feature = "dtype-struct")] +fn test_read_ndjson_iss_6148() { + let json = b"{\"a\":1,\"b\":{}}\n{\"a\":2,\"b\":{}}\n"; + + let cursor = Cursor::new(json); + + let df = JsonLineReader::new(cursor).finish(); + assert!(df.is_ok()); +} diff --git a/crates/polars/tests/it/io/mod.rs b/crates/polars/tests/it/io/mod.rs new file mode 100644 index 000000000000..6ea615799996 --- /dev/null +++ b/crates/polars/tests/it/io/mod.rs @@ -0,0 +1,23 @@ +mod csv; + +#[cfg(feature = "json")] +mod json; + +#[cfg(feature = "parquet")] +mod parquet; + +#[cfg(feature = "avro")] +mod avro; + +#[cfg(feature = "ipc")] +mod ipc; +#[cfg(feature = "ipc_streaming")] +mod ipc_stream; + +use polars::prelude::*; + +pub(crate) fn create_df() -> DataFrame { + let s0 = Column::new("days".into(), [0, 1, 2, 3, 4].as_ref()); + let s1 = Column::new("temp".into(), [22.1, 19.9, 7., 2., 3.].as_ref()); + DataFrame::new(vec![s0, s1]).unwrap() +} diff --git a/crates/polars/tests/it/io/parquet/arrow/mod.rs b/crates/polars/tests/it/io/parquet/arrow/mod.rs new file mode 100644 index 000000000000..364b95f220a7 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/mod.rs @@ -0,0 +1,932 @@ +mod read; +mod write; + +use std::io::{Cursor, Read, Seek}; +use std::sync::Arc; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::*; +use arrow::record_batch::RecordBatchT; +use arrow::types::{NativeType, i256}; +use ethnum::AsI256; +use polars_error::PolarsResult; +use polars_parquet::read::{self as p_read}; +use polars_parquet::write::*; + +use super::read::file::FileReader; + +fn new_struct( + arrays: Vec>, + length: usize, + names: Vec, + validity: Option, +) -> StructArray { + let fields = names + .into_iter() + .zip(arrays.iter()) + .map(|(n, a)| Field::new(n.into(), a.dtype().clone(), true)) + .collect(); + StructArray::new(ArrowDataType::Struct(fields), length, arrays, validity) +} + +pub fn read_column(mut reader: R, column: &str) -> PolarsResult> { + let metadata = p_read::read_metadata(&mut reader)?; + let schema = p_read::infer_schema(&metadata)?; + + let schema = schema.filter(|_, f| f.name == column); + + let mut reader = FileReader::new(reader, metadata.row_groups, schema, None); + + let array = reader.next().unwrap()?.into_arrays().pop().unwrap(); + + Ok(array) +} + +pub fn pyarrow_nested_edge(column: &str) -> Box { + match column { + "simple" => { + // [[0, 1]] + let data = [Some(vec![Some(0), Some(1)])]; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "null" => { + // [None] + let data = [None::>>]; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "empty" => { + // [None] + let data: [Option>>; 0] = []; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "struct_list_nullable" => { + // [ + // {"f1": ["a", "b", None, "c"]} + // ] + let a = ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + ArrowDataType::Utf8View, + true, + ))), + vec![0, 4].try_into().unwrap(), + Utf8ViewArray::from_slice([Some("a"), Some("b"), None, Some("c")]).boxed(), + None, + ); + StructArray::new( + ArrowDataType::Struct(vec![Field::new("f1".into(), a.dtype().clone(), true)]), + a.len(), + vec![a.boxed()], + None, + ) + .boxed() + }, + "list_struct_list_nullable" => { + let values = pyarrow_nested_edge("struct_list_nullable"); + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + values.dtype().clone(), + true, + ))), + vec![0, 1].try_into().unwrap(), + values, + None, + ) + .boxed() + }, + _ => todo!(), + } +} + +pub fn pyarrow_nested_nullable(column: &str) -> Box { + let i64_values = &[ + Some(0), + Some(1), + Some(2), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]; + let offsets = vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(); + + let values = match column { + "list_int64" => { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + PrimitiveArray::::from(i64_values).boxed() + }, + "list_int64_required" | "list_int64_optional_required" | "list_int64_required_required" => { + // [[0, 1], None, [2, 0, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + PrimitiveArray::::from(&[ + Some(0), + Some(1), + Some(2), + Some(0), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]) + .boxed() + }, + "list_int16" => PrimitiveArray::::from(&[ + Some(0), + Some(1), + Some(2), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]) + .boxed(), + "list_bool" => BooleanArray::from(&[ + Some(false), + Some(true), + Some(true), + None, + Some(false), + Some(true), + Some(false), + Some(true), + Some(false), + Some(false), + Some(false), + Some(true), + ]) + .boxed(), + /* + string = [ + ["Hello", "bbb"], + None, + ["aa", None, ""], + ["bbb", "aa", "ccc"], + [], + ["abc", "bbb", "bbb"], + None, + [""], + ] + */ + "list_utf8" => Utf8ViewArray::from_slice([ + Some("Hello".to_string()), + Some("bbb".to_string()), + Some("aa".to_string()), + None, + Some("".to_string()), + Some("bbb".to_string()), + Some("aa".to_string()), + Some("ccc".to_string()), + Some("abc".to_string()), + Some("bbb".to_string()), + Some("bbb".to_string()), + Some("".to_string()), + ]) + .boxed(), + "list_large_binary" => Box::new(BinaryArray::::from([ + Some(b"Hello".to_vec()), + Some(b"bbb".to_vec()), + Some(b"aa".to_vec()), + None, + Some(b"".to_vec()), + Some(b"bbb".to_vec()), + Some(b"aa".to_vec()), + Some(b"ccc".to_vec()), + Some(b"abc".to_vec()), + Some(b"bbb".to_vec()), + Some(b"bbb".to_vec()), + Some(b"".to_vec()), + ])), + "list_decimal" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::>(); + Box::new(PrimitiveArray::::from(values).to(ArrowDataType::Decimal(9, 0))) + }, + "list_decimal256" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| i256(x.as_i256()))) + .collect::>(); + let array = PrimitiveArray::::from(values).to(ArrowDataType::Decimal256(9, 0)); + Box::new(array) + }, + "list_nested_i64" + | "list_nested_inner_required_i64" + | "list_nested_inner_required_required_i64" => { + Box::new(NullArray::new(ArrowDataType::Null, 1)) + }, + "struct_list_nullable" => pyarrow_nested_nullable("list_utf8"), + "list_struct_nullable" => { + let array = Utf8ViewArray::from_slice([ + Some("a"), + Some("b"), + // + Some("b"), + None, + Some("b"), + // + None, + None, + None, + // + Some("d"), + Some("d"), + Some("d"), + // + Some("e"), + ]) + .boxed(); + + let len = array.len(); + new_struct( + vec![array], + len, + vec!["a".to_string()], + Some( + [ + true, true, // + true, false, true, // + true, true, true, // + true, true, true, // + true, + ] + .into(), + ), + ) + .boxed() + }, + "list_struct_list_nullable" => { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let array = Utf8ViewArray::from_slice([ + Some("a"), + Some("b"), + // + Some("b"), + Some("b"), + // + Some("d"), + None, + Some("c"), + Some("d"), + ]) + .boxed(); + + let array = ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + array.dtype().clone(), + true, + ))), + vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] + .try_into() + .unwrap(), + array, + Some( + [ + true, true, true, false, true, false, false, false, true, true, true, true, + ] + .into(), + ), + ) + .boxed(); + + let len = array.len(); + new_struct( + vec![array], + len, + vec!["a".to_string()], + Some( + [ + true, true, // + true, false, true, // + true, true, true, // + true, true, true, // + true, + ] + .into(), + ), + ) + .boxed() + }, + other => unreachable!("{}", other), + }; + + match column { + "list_int64_required_required" => { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let dtype = ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + ArrowDataType::Int64, + false, + ))); + ListArray::::new(dtype, offsets, values, None).boxed() + }, + "list_int64_optional_required" => { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let dtype = ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + ArrowDataType::Int64, + true, + ))); + ListArray::::new(dtype, offsets, values, None).boxed() + }, + "list_nested_i64" => { + // [[0, 1]], None, [[2, None], [3]], [[4, 5], [6]], [], [[7], None, [9]], [[], [None], None], [[10]] + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), None]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![Some(vec![Some(7)]), None, Some(vec![Some(9)])]), + Some(vec![Some(vec![]), Some(vec![None]), None]), + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "list_nested_inner_required_i64" => { + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), Some(3)]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![Some(vec![Some(7)]), None, Some(vec![Some(9)])]), + None, + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "list_nested_inner_required_required_i64" => { + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), Some(3)]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![ + Some(vec![Some(7)]), + Some(vec![Some(8)]), + Some(vec![Some(9)]), + ]), + None, + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "struct_list_nullable" => { + let len = values.len(); + new_struct(vec![values], len, vec!["a".to_string()], None).boxed() + }, + _ => { + let field = match column { + "list_int64" => Field::new("item".into(), ArrowDataType::Int64, true), + "list_int64_required" => Field::new("item".into(), ArrowDataType::Int64, false), + "list_int16" => Field::new("item".into(), ArrowDataType::Int16, true), + "list_bool" => Field::new("item".into(), ArrowDataType::Boolean, true), + "list_utf8" => Field::new("item".into(), ArrowDataType::Utf8View, true), + "list_large_binary" => Field::new("item".into(), ArrowDataType::LargeBinary, true), + "list_decimal" => Field::new("item".into(), ArrowDataType::Decimal(9, 0), true), + "list_decimal256" => { + Field::new("item".into(), ArrowDataType::Decimal256(9, 0), true) + }, + "list_struct_nullable" => Field::new("item".into(), values.dtype().clone(), true), + "list_struct_list_nullable" => { + Field::new("item".into(), values.dtype().clone(), true) + }, + other => unreachable!("{}", other), + }; + + let validity = Some(Bitmap::from([ + true, false, true, true, true, true, false, true, + ])); + // [0, 2, 2, 5, 8, 8, 11, 11, 12] + // [[a1, a2], None, [a3, a4, a5], [a6, a7, a8], [], [a9, a10, a11], None, [a12]] + let dtype = ArrowDataType::LargeList(Box::new(field)); + ListArray::::new(dtype, offsets, values, validity).boxed() + }, + } +} + +pub fn pyarrow_nullable(column: &str) -> Box { + let i64_values = &[ + Some(-256), + Some(-1), + None, + Some(3), + None, + Some(5), + Some(6), + Some(7), + None, + Some(9), + ]; + let u32_values = &[ + Some(0), + Some(1), + None, + Some(3), + None, + Some(5), + Some(6), + Some(7), + None, + Some(9), + ]; + + match column { + "int64" => Box::new(PrimitiveArray::::from(i64_values)), + "float64" => Box::new(PrimitiveArray::::from(&[ + Some(0.0), + Some(1.0), + None, + Some(3.0), + None, + Some(5.0), + Some(6.0), + Some(7.0), + None, + Some(9.0), + ])), + "string" => Box::new(Utf8ViewArray::from_slice([ + Some("Hello".to_string()), + None, + Some("aa".to_string()), + Some("".to_string()), + None, + Some("abc".to_string()), + None, + None, + Some("def".to_string()), + Some("aaa".to_string()), + ])), + "bool" => Box::new(BooleanArray::from([ + Some(true), + None, + Some(false), + Some(false), + None, + Some(true), + None, + None, + Some(true), + Some(true), + ])), + "timestamp_ms" => Box::new( + PrimitiveArray::::from_iter(u32_values.iter().map(|x| x.map(|x| x as i64))) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + "uint32" => Box::new(PrimitiveArray::::from(u32_values)), + "int32_dict" => { + let keys = PrimitiveArray::::from([Some(0), Some(1), None, Some(1)]); + let values = Box::new(PrimitiveArray::::from_slice([10, 200])); + Box::new(DictionaryArray::try_from_keys(keys, values).unwrap()) + }, + "timestamp_us" => Box::new( + PrimitiveArray::::from(i64_values) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + "timestamp_s" => Box::new( + PrimitiveArray::::from(i64_values) + .to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + "timestamp_s_utc" => Box::new(PrimitiveArray::::from(i64_values).to( + ArrowDataType::Timestamp(TimeUnit::Second, Some("UTC".into())), + )), + _ => unreachable!(), + } +} + +// these values match the values in `integration` +pub fn pyarrow_required(column: &str) -> Box { + let i64_values = &[ + Some(-256), + Some(-1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + + match column { + "int64" => Box::new(PrimitiveArray::::from(i64_values)), + "bool" => Box::new(BooleanArray::from_slice([ + true, true, false, false, false, true, true, true, true, true, + ])), + "string" => Box::new(Utf8ViewArray::from_slice([ + Some("Hello"), + Some("bbb"), + Some("aa"), + Some(""), + Some("bbb"), + Some("abc"), + Some("bbb"), + Some("bbb"), + Some("def"), + Some("aaa"), + ])), + _ => unreachable!(), + } +} + +pub fn pyarrow_struct(column: &str) -> Box { + let boolean = [ + Some(true), + None, + Some(false), + Some(false), + None, + Some(true), + None, + None, + Some(true), + Some(true), + ]; + let boolean = BooleanArray::from(boolean).boxed(); + + let string = [ + Some("Hello"), + None, + Some("aa"), + Some(""), + None, + Some("abc"), + None, + None, + Some("def"), + Some("aaa"), + ]; + let string = Utf8ViewArray::from_slice(string).boxed(); + + let mask = [true, true, false, true, true, true, true, true, true, true]; + + let fields = vec![ + Field::new("f1".into(), ArrowDataType::Utf8View, true), + Field::new("f2".into(), ArrowDataType::Boolean, true), + ]; + match column { + "struct" => StructArray::new( + ArrowDataType::Struct(fields), + string.len(), + vec![string, boolean], + None, + ) + .boxed(), + "struct_nullable" => { + let len = string.len(); + let values = vec![string, boolean]; + StructArray::new( + ArrowDataType::Struct(fields), + len, + values, + Some(mask.into()), + ) + .boxed() + }, + "struct_struct" => { + let struct_ = pyarrow_struct("struct"); + Box::new(StructArray::new( + ArrowDataType::Struct(vec![ + Field::new("f1".into(), ArrowDataType::Struct(fields), true), + Field::new("f2".into(), ArrowDataType::Boolean, true), + ]), + struct_.len(), + vec![struct_, boolean], + None, + )) + }, + "struct_struct_nullable" => { + let struct_ = pyarrow_struct("struct"); + Box::new(StructArray::new( + ArrowDataType::Struct(vec![ + Field::new("f1".into(), ArrowDataType::Struct(fields), true), + Field::new("f2".into(), ArrowDataType::Boolean, true), + ]), + struct_.len(), + vec![struct_, boolean], + Some(mask.into()), + )) + }, + _ => todo!(), + } +} + +fn integration_write( + schema: &ArrowSchema, + chunks: &[RecordBatchT>], +) -> PolarsResult> { + let options = WriteOptions { + statistics: StatisticsOptions::full(), + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_page_size: None, + }; + + let encodings = schema + .iter_values() + .map(|f| { + transverse(&f.dtype, |x| { + if let ArrowDataType::Dictionary(..) = x { + Encoding::RleDictionary + } else { + Encoding::Plain + } + }) + }) + .collect(); + + let row_groups = + RowGroupIterator::try_new(chunks.iter().cloned().map(Ok), schema, options, encodings)?; + + let writer = Cursor::new(vec![]); + + let mut writer = FileWriter::try_new(writer, schema.clone(), options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + Ok(writer.into_inner().into_inner()) +} + +type IntegrationRead = (ArrowSchema, Vec>>); + +fn integration_read(data: &[u8], limit: Option) -> PolarsResult { + let mut reader = Cursor::new(data); + let metadata = p_read::read_metadata(&mut reader)?; + let schema = p_read::infer_schema(&metadata)?; + + let reader = FileReader::new( + Cursor::new(data), + metadata.row_groups, + schema.clone(), + limit, + ); + + let batches = reader.collect::>>()?; + + Ok((schema, batches)) +} + +fn assert_roundtrip( + schema: ArrowSchema, + chunk: RecordBatchT>, + limit: Option, +) -> PolarsResult<()> { + let r = integration_write(&schema, &[chunk.clone()])?; + + let (new_schema, new_chunks) = integration_read(&r, limit)?; + + let expected = if let Some(limit) = limit { + let length = chunk.len().min(limit); + let expected = chunk + .into_arrays() + .into_iter() + .map(|x| x.sliced(0, limit)) + .collect::>(); + RecordBatchT::new(length, Arc::new(schema.clone()), expected) + } else { + chunk + }; + + assert_eq!(new_schema, schema); + assert_eq!(new_chunks, vec![expected]); + Ok(()) +} + +fn data>( + mut iter: I, + inner_is_nullable: bool, +) -> Box { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data = vec![ + Some(vec![Some(iter.next().unwrap()), Some(iter.next().unwrap())]), + Some(vec![]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![]), + Some(vec![Some(iter.next().unwrap())]), + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item".into(), + inner_is_nullable, + ); + array.try_extend(data).unwrap(); + array.into_box() +} + +fn assert_array_roundtrip( + is_nullable: bool, + array: Box, + limit: Option, +) -> PolarsResult<()> { + let schema = + ArrowSchema::from_iter([Field::new("a1".into(), array.dtype().clone(), is_nullable)]); + let chunk = RecordBatchT::try_new(array.len(), Arc::new(schema.clone()), vec![array])?; + + assert_roundtrip(schema, chunk, limit) +} + +fn test_list_array_required_required(limit: Option) -> PolarsResult<()> { + assert_array_roundtrip(false, data(0..12i8, false), limit)?; + assert_array_roundtrip(false, data(0..12i16, false), limit)?; + assert_array_roundtrip(false, data(0..12i64, false), limit)?; + assert_array_roundtrip(false, data(0..12i64, false), limit)?; + assert_array_roundtrip(false, data(0..12u8, false), limit)?; + assert_array_roundtrip(false, data(0..12u16, false), limit)?; + assert_array_roundtrip(false, data(0..12u32, false), limit)?; + assert_array_roundtrip(false, data(0..12u64, false), limit)?; + assert_array_roundtrip(false, data((0..12).map(|x| (x as f32) * 1.0), false), limit)?; + assert_array_roundtrip( + false, + data((0..12).map(|x| (x as f64) * 1.0f64), false), + limit, + ) +} + +#[test] +fn list_array_required_required() -> PolarsResult<()> { + test_list_array_required_required(None) +} + +#[test] +fn list_array_optional_optional() -> PolarsResult<()> { + assert_array_roundtrip(true, data(0..12, true), None) +} + +#[test] +fn list_array_required_optional() -> PolarsResult<()> { + assert_array_roundtrip(true, data(0..12, false), None) +} + +#[test] +fn list_array_optional_required() -> PolarsResult<()> { + assert_array_roundtrip(false, data(0..12, true), None) +} + +#[test] +fn list_slice() -> PolarsResult<()> { + let data = vec![ + Some(vec![None, Some(2)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(5), Some(6)]), + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item".into(), + true, + ); + array.try_extend(data).unwrap(); + let a: ListArray = array.into(); + let a = a.sliced(2, 1); + assert_array_roundtrip(false, a.boxed(), None) +} + +#[test] +fn struct_slice() -> PolarsResult<()> { + let a = pyarrow_nested_nullable("struct_list_nullable"); + + let a = a.sliced(2, 1); + assert_array_roundtrip(true, a, None) +} + +#[test] +fn list_struct_slice() -> PolarsResult<()> { + let a = pyarrow_nested_nullable("list_struct_nullable"); + + let a = a.sliced(2, 1); + assert_array_roundtrip(true, a, None) +} + +#[test] +fn list_int_nullable() -> PolarsResult<()> { + let data = vec![ + Some(vec![Some(1)]), + None, + Some(vec![None, Some(2)]), + Some(vec![]), + Some(vec![Some(3)]), + None, + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item".into(), + true, + ); + array.try_extend(data).unwrap(); + assert_array_roundtrip(true, array.into_box(), None) +} + +#[test] +fn limit_list() -> PolarsResult<()> { + test_list_array_required_required(Some(2)) +} + +#[test] +fn filter_chunk() -> PolarsResult<()> { + let field = Field::new("c1".into(), ArrowDataType::Int16, true); + let schema = ArrowSchema::from_iter([field]); + let chunk1 = RecordBatchT::new( + 2, + Arc::new(schema.clone()), + vec![PrimitiveArray::from_slice([1i16, 3]).boxed()], + ); + let chunk2 = RecordBatchT::new( + 2, + Arc::new(schema.clone()), + vec![PrimitiveArray::from_slice([2i16, 4]).boxed()], + ); + + let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?; + + let mut reader = Cursor::new(r); + + let metadata = p_read::read_metadata(&mut reader)?; + + let new_schema = p_read::infer_schema(&metadata)?; + assert_eq!(new_schema, schema); + + // select chunk 1 + let row_groups = metadata + .row_groups + .into_iter() + .enumerate() + .filter(|(index, _)| *index == 0) + .map(|(_, row_group)| row_group) + .collect(); + + let reader = FileReader::new(reader, row_groups, schema, None); + + let new_chunks = reader.collect::>>()?; + + assert_eq!(new_chunks, vec![chunk1]); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/read.rs b/crates/polars/tests/it/io/parquet/arrow/read.rs new file mode 100644 index 000000000000..a1fce15e3814 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/read.rs @@ -0,0 +1,149 @@ +use std::path::PathBuf; + +use polars_parquet::arrow::read::*; + +use super::*; +use crate::io::parquet::read::file::FileReader; +#[cfg(feature = "parquet")] +#[test] +fn all_types() -> PolarsResult<()> { + use crate::io::parquet::read::file::FileReader; + + let dir = env!("CARGO_MANIFEST_DIR"); + let path = PathBuf::from(dir).join("../../docs/assets/data/alltypes_plain.parquet"); + + let mut reader = std::fs::File::open(path)?; + + let metadata = read_metadata(&mut reader)?; + let schema = infer_schema(&metadata)?; + let reader = FileReader::new(reader, metadata.row_groups, schema, None); + + let batches = reader.collect::>>()?; + assert_eq!(batches.len(), 1); + + let result = batches[0].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([4, 5, 6, 7, 2, 3, 0, 1])); + + let result = batches[0].columns()[6] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &Float32Array::from_slice([0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]) + ); + + let result = batches[0].columns()[9] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &BinaryViewArray::from_slice_values([[48], [49], [48], [49], [48], [49], [48], [49]]) + ); + + Ok(()) +} + +#[cfg(feature = "parquet")] +#[test] +fn all_types_chunked() -> PolarsResult<()> { + // this has one batch with 8 elements + + use crate::io::parquet::read::file::FileReader; + let dir = env!("CARGO_MANIFEST_DIR"); + let path = PathBuf::from(dir).join("../../docs/assets/data/alltypes_plain.parquet"); + let mut reader = std::fs::File::open(path)?; + + let metadata = read_metadata(&mut reader)?; + let schema = infer_schema(&metadata)?; + // chunk it in 5 (so, (5,3)) + let reader = FileReader::new(reader, metadata.row_groups, schema, None); + + let batches = reader.collect::>>()?; + assert_eq!(batches.len(), 1); + + assert_eq!(batches[0].len(), 8); + + let result = batches[0].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([4, 5, 6, 7, 2, 3, 0, 1])); + + let result = batches[0].columns()[6] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &Float32Array::from_slice([0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]) + ); + + let result = batches[0].columns()[9] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &BinaryViewArray::from_slice_values([[48], [49], [48], [49], [48], [49], [48], [49]]) + ); + + Ok(()) +} + +#[test] +fn read_int96_timestamps() -> PolarsResult<()> { + let timestamp_data = &[ + 0x50, 0x41, 0x52, 0x31, 0x15, 0x04, 0x15, 0x48, 0x15, 0x3c, 0x4c, 0x15, 0x06, 0x15, 0x00, + 0x12, 0x00, 0x00, 0x24, 0x00, 0x00, 0x0d, 0x01, 0x08, 0x9f, 0xd5, 0x1f, 0x0d, 0x0a, 0x44, + 0x00, 0x00, 0x59, 0x68, 0x25, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, + 0xfb, 0x2a, 0x00, 0x15, 0x00, 0x15, 0x14, 0x15, 0x18, 0x2c, 0x15, 0x06, 0x15, 0x10, 0x15, + 0x06, 0x15, 0x06, 0x1c, 0x00, 0x00, 0x00, 0x0a, 0x24, 0x02, 0x00, 0x00, 0x00, 0x06, 0x01, + 0x02, 0x03, 0x24, 0x00, 0x26, 0x9e, 0x01, 0x1c, 0x15, 0x06, 0x19, 0x35, 0x10, 0x00, 0x06, + 0x19, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x73, 0x15, 0x02, + 0x16, 0x06, 0x16, 0x9e, 0x01, 0x16, 0x96, 0x01, 0x26, 0x60, 0x26, 0x08, 0x29, 0x2c, 0x15, + 0x04, 0x15, 0x00, 0x15, 0x02, 0x00, 0x15, 0x00, 0x15, 0x10, 0x15, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x04, 0x19, 0x2c, 0x35, 0x00, 0x18, 0x06, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x15, + 0x02, 0x00, 0x15, 0x06, 0x25, 0x02, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x73, 0x00, 0x16, 0x06, 0x19, 0x1c, 0x19, 0x1c, 0x26, 0x9e, 0x01, 0x1c, 0x15, + 0x06, 0x19, 0x35, 0x10, 0x00, 0x06, 0x19, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x73, 0x15, 0x02, 0x16, 0x06, 0x16, 0x9e, 0x01, 0x16, 0x96, 0x01, 0x26, + 0x60, 0x26, 0x08, 0x29, 0x2c, 0x15, 0x04, 0x15, 0x00, 0x15, 0x02, 0x00, 0x15, 0x00, 0x15, + 0x10, 0x15, 0x02, 0x00, 0x00, 0x00, 0x16, 0x9e, 0x01, 0x16, 0x06, 0x26, 0x08, 0x16, 0x96, + 0x01, 0x14, 0x00, 0x00, 0x28, 0x20, 0x70, 0x61, 0x72, 0x71, 0x75, 0x65, 0x74, 0x2d, 0x63, + 0x70, 0x70, 0x2d, 0x61, 0x72, 0x72, 0x6f, 0x77, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x20, 0x31, 0x32, 0x2e, 0x30, 0x2e, 0x30, 0x19, 0x1c, 0x1c, 0x00, 0x00, 0x00, 0x95, + 0x00, 0x00, 0x00, 0x50, 0x41, 0x52, 0x31, + ]; + + let parse = |time_unit: TimeUnit| { + let mut reader = Cursor::new(timestamp_data); + let metadata = read_metadata(&mut reader)?; + let schema = arrow::datatypes::ArrowSchema::from_iter([arrow::datatypes::Field::new( + "timestamps".into(), + arrow::datatypes::ArrowDataType::Timestamp(time_unit, None), + false, + )]); + let reader = FileReader::new(reader, metadata.row_groups, schema, None); + reader.collect::>>() + }; + + // This data contains int96 timestamps in the year 1000 and 3000, which are out of range for + // Timestamp(TimeUnit::Nanoseconds) and will cause a panic in dev builds/overflow in release builds + // However, the code should work for the Microsecond/Millisecond time units + for time_unit in [ + arrow::datatypes::TimeUnit::Microsecond, + arrow::datatypes::TimeUnit::Millisecond, + arrow::datatypes::TimeUnit::Second, + ] { + parse(time_unit).expect("Should not error"); + } + std::panic::catch_unwind(|| parse(arrow::datatypes::TimeUnit::Nanosecond)) + .expect_err("Should be a panic error"); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/write.rs b/crates/polars/tests/it/io/parquet/arrow/write.rs new file mode 100644 index 000000000000..ba57d822906e --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/write.rs @@ -0,0 +1,466 @@ +use polars_parquet::arrow::write::*; + +use super::*; + +fn round_trip( + column: &str, + file: &str, + version: Version, + compression: CompressionOptions, + encodings: Vec, +) -> PolarsResult<()> { + round_trip_opt_stats(column, file, version, compression, encodings) +} + +fn round_trip_opt_stats( + column: &str, + file: &str, + version: Version, + compression: CompressionOptions, + encodings: Vec, +) -> PolarsResult<()> { + let array = match file { + "nested" => pyarrow_nested_nullable(column), + "nullable" => pyarrow_nullable(column), + "required" => pyarrow_required(column), + "struct" => pyarrow_struct(column), + "nested_edge" => pyarrow_nested_edge(column), + _ => unreachable!(), + }; + + let field = Field::new("a1".into(), array.dtype().clone(), true); + let schema = ArrowSchema::from_iter([field]); + + let options = WriteOptions { + statistics: StatisticsOptions::full(), + compression, + version, + data_page_size: None, + }; + + let iter = vec![RecordBatchT::try_new( + array.len(), + Arc::new(schema.clone()), + vec![array.clone()], + )]; + + let row_groups = + RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?; + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(writer, schema, options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + std::fs::write("list_struct_list_nullable.parquet", &data).unwrap(); + + let result = read_column(&mut Cursor::new(data), "a1")?; + + assert_eq!(array.as_ref(), result.as_ref()); + Ok(()) +} + +#[test] +fn int64_optional_v1() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_required_v1() -> PolarsResult<()> { + round_trip( + "int64", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_optional_v2() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_optional_delta() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaBinaryPacked], + ) +} + +#[test] +fn int64_required_delta() -> PolarsResult<()> { + round_trip( + "int64", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaBinaryPacked], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn int64_optional_v2_compressed() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v1() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_required_v1() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v2() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_required_v2() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn utf8_optional_v2_compressed() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn utf8_required_v2_compressed() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_optional_v1() -> PolarsResult<()> { + round_trip( + "bool", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_required_v1() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_optional_v2_uncompressed() -> PolarsResult<()> { + round_trip( + "bool", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_required_v2_uncompressed() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn bool_required_v2_compressed() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_optional_v2() -> PolarsResult<()> { + round_trip( + "list_int64", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_optional_v1() -> PolarsResult<()> { + round_trip( + "list_int64", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_required_required_v1() -> PolarsResult<()> { + round_trip( + "list_int64_required_required", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_required_required_v2() -> PolarsResult<()> { + round_trip( + "list_int64_required_required", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_bool_optional_v2() -> PolarsResult<()> { + round_trip( + "list_bool", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_bool_optional_v1() -> PolarsResult<()> { + round_trip( + "list_bool", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_utf8_optional_v2() -> PolarsResult<()> { + round_trip( + "list_utf8", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_utf8_optional_v1() -> PolarsResult<()> { + round_trip( + "list_utf8", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_nested_inner_required_required_i64() -> PolarsResult<()> { + round_trip_opt_stats( + "list_nested_inner_required_required_i64", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn v1_nested_struct_list_nullable() -> PolarsResult<()> { + round_trip_opt_stats( + "struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn v1_nested_list_struct_list_nullable() -> PolarsResult<()> { + round_trip_opt_stats( + "list_struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v2_delta() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaLengthByteArray], + ) +} + +#[test] +fn utf8_required_v2_delta() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaLengthByteArray], + ) +} + +#[test] +fn struct_v1() -> PolarsResult<()> { + round_trip( + "struct", + "struct", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain, Encoding::Plain], + ) +} + +#[test] +fn struct_v2() -> PolarsResult<()> { + round_trip( + "struct", + "struct", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain, Encoding::Plain], + ) +} + +#[test] +fn nested_edge_simple() -> PolarsResult<()> { + round_trip( + "simple", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn nested_edge_null() -> PolarsResult<()> { + round_trip( + "null", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn v1_nested_edge_struct_list_nullable() -> PolarsResult<()> { + round_trip( + "struct_list_nullable", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn nested_edge_list_struct_list_nullable() -> PolarsResult<()> { + round_trip( + "list_struct_list_nullable", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} diff --git a/crates/polars/tests/it/io/parquet/mod.rs b/crates/polars/tests/it/io/parquet/mod.rs new file mode 100644 index 000000000000..5d088aab3b15 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/mod.rs @@ -0,0 +1,208 @@ +#![forbid(unsafe_code)] +mod arrow; +pub(crate) mod read; +mod roundtrip; +mod write; + +use std::io::Cursor; +use std::path::PathBuf; + +use polars::prelude::*; + +// The dynamic representation of values in native Rust. This is not exhaustive. +// todo: maybe refactor this into serde/json? +#[derive(Debug, PartialEq)] +pub enum Array { + Int32(Vec>), + Int64(Vec>), + Int96(Vec>), + Float(Vec>), + Double(Vec>), + Boolean(Vec>), + Binary(Vec>>), + FixedLenBinary(Vec>>), + List(Vec>), + Struct(Vec, Vec), +} + +use polars_parquet::parquet::schema::types::{PhysicalType, PrimitiveType}; +use polars_parquet::parquet::statistics::*; + +pub fn alltypes_plain(column: &str) -> Array { + match column { + "id" => { + let expected = vec![4, 5, 6, 7, 2, 3, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "id-short-array" => { + let expected = vec![4]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "bool_col" => { + let expected = vec![true, false, true, false, true, false, true, false]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Boolean(expected) + }, + "tinyint_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "smallint_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "int_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "bigint_col" => { + let expected = vec![0, 10, 0, 10, 0, 10, 0, 10]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int64(expected) + }, + "float_col" => { + let expected = vec![0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Float(expected) + }, + "double_col" => { + let expected = vec![0.0, 10.1, 0.0, 10.1, 0.0, 10.1, 0.0, 10.1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Double(expected) + }, + "date_string_col" => { + let expected = vec![ + vec![48, 51, 47, 48, 49, 47, 48, 57], + vec![48, 51, 47, 48, 49, 47, 48, 57], + vec![48, 52, 47, 48, 49, 47, 48, 57], + vec![48, 52, 47, 48, 49, 47, 48, 57], + vec![48, 50, 47, 48, 49, 47, 48, 57], + vec![48, 50, 47, 48, 49, 47, 48, 57], + vec![48, 49, 47, 48, 49, 47, 48, 57], + vec![48, 49, 47, 48, 49, 47, 48, 57], + ]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Binary(expected) + }, + "string_col" => { + let expected = vec![ + vec![48], + vec![49], + vec![48], + vec![49], + vec![48], + vec![49], + vec![48], + vec![49], + ]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Binary(expected) + }, + "timestamp_col" => { + todo!() + }, + _ => unreachable!(), + } +} + +pub fn alltypes_statistics(column: &str) -> Statistics { + match column { + "id" => PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(7), + } + .into(), + "id-short-array" => PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(4), + max_value: Some(4), + } + .into(), + "bool_col" => BooleanStatistics { + null_count: Some(0), + distinct_count: None, + min_value: Some(false), + max_value: Some(true), + } + .into(), + "tinyint_col" | "smallint_col" | "int_col" => PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(1), + } + .into(), + "bigint_col" => PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int64), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(10), + } + .into(), + "float_col" => PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Float), + null_count: Some(0), + distinct_count: None, + min_value: Some(0.0), + max_value: Some(1.1), + } + .into(), + "double_col" => PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Double), + null_count: Some(0), + distinct_count: None, + min_value: Some(0.0), + max_value: Some(10.1), + } + .into(), + "date_string_col" => BinaryStatistics { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::ByteArray), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![48, 49, 47, 48, 49, 47, 48, 57]), + max_value: Some(vec![48, 52, 47, 48, 49, 47, 48, 57]), + } + .into(), + "string_col" => BinaryStatistics { + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::ByteArray), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![48]), + max_value: Some(vec![49]), + } + .into(), + "timestamp_col" => { + todo!() + }, + _ => unreachable!(), + } +} + +#[test] +fn test_vstack_empty_3220() -> PolarsResult<()> { + let df1 = df! { + "a" => ["1", "2"], + "b" => [1, 2] + }?; + let empty_df = df1.head(Some(0)); + let mut stacked = df1.clone(); + stacked.vstack_mut(&empty_df)?; + stacked.vstack_mut(&df1)?; + let mut buf = Cursor::new(Vec::new()); + ParquetWriter::new(&mut buf).finish(&mut stacked)?; + let read_df = ParquetReader::new(buf).finish()?; + assert!(stacked.equals(&read_df)); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/binary.rs b/crates/polars/tests/it/io/parquet/read/binary.rs new file mode 100644 index 000000000000..086dd06bb9b9 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/binary.rs @@ -0,0 +1,37 @@ +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::page::DataPage; + +use super::dictionary::BinaryPageDict; +use super::utils::deserialize_optional; +use crate::io::parquet::read::utils::FixedLenBinaryPageState; +use crate::io::parquet::read::{hybrid_rle_fn_collect, hybrid_rle_iter}; + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&BinaryPageDict>, +) -> ParquetResult>>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = FixedLenBinaryPageState::try_new(page, dict)?; + + match state { + FixedLenBinaryPageState::Optional(validity, values) => { + deserialize_optional(validity, values.map(|x| Ok(x.to_vec()))) + }, + FixedLenBinaryPageState::Required(values) => values + .map(|x| Ok(x.to_vec())) + .map(Some) + .map(|x| x.transpose()) + .collect(), + FixedLenBinaryPageState::RequiredDictionary(dict) => { + hybrid_rle_fn_collect(dict.indexes, |x| { + dict.dict.value(x as usize).map(<[u8]>::to_vec).map(Some) + }) + }, + FixedLenBinaryPageState::OptionalDictionary(validity, dict) => { + let values = hybrid_rle_iter(dict.indexes)? + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec())); + deserialize_optional(validity, values) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/boolean.rs b/crates/polars/tests/it/io/parquet/read/boolean.rs new file mode 100644 index 000000000000..03a3d980c3bf --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/boolean.rs @@ -0,0 +1,21 @@ +use polars_parquet::parquet::encoding::hybrid_rle::BitmapIter; +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::page::DataPage; + +use super::utils::deserialize_optional; +use crate::io::parquet::read::utils::BooleanPageState; + +pub fn page_to_vec(page: &DataPage) -> ParquetResult>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = BooleanPageState::try_new(page)?; + + match state { + BooleanPageState::Optional(validity, mut values) => { + deserialize_optional(validity, values.by_ref().map(Ok)) + }, + BooleanPageState::Required(bitmap, length) => { + Ok(BitmapIter::new(bitmap, 0, length).map(Some).collect()) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/deserialize.rs b/crates/polars/tests/it/io/parquet/read/deserialize.rs new file mode 100644 index 000000000000..1cd906b2c0e8 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/deserialize.rs @@ -0,0 +1,311 @@ +use polars_parquet::parquet::indexes::Interval; + +#[test] +fn bitmap_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![HybridEncoded::Bitmap(&[0b01000011], 7)].into_iter(), + vec![Interval::new(1, 2)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 1, + length: 2, + } + ] + ); +} + +#[test] +fn bitmap_complete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![HybridEncoded::Bitmap(&[0b01000011], 8)].into_iter(), + vec![Interval::new(0, 8)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 0, + length: 8, + }] + ); +} + +#[test] +fn bitmap_interval_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Bitmap(&[0b01000011], 8), + HybridEncoded::Bitmap(&[0b11111111], 8), + ] + .into_iter(), + vec![Interval::new(0, 10)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 0, + length: 8, + }, + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 0, + length: 2, + } + ] + ); +} + +#[test] +fn bitmap_interval_run_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Bitmap(&[0b01100011], 8), + HybridEncoded::Bitmap(&[0b11111111], 8), + ] + .into_iter(), + vec![Interval::new(0, 5), Interval::new(7, 4)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 0, + length: 5, + }, + FilteredHybridEncoded::Skipped(2), + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 7, + length: 1, + }, + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 0, + length: 3, + } + ] + ); +} + +#[test] +fn bitmap_interval_run_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Bitmap(&[0b01100011], 8), + HybridEncoded::Bitmap(&[0b11111111], 8), + ] + .into_iter(), + vec![Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(4), + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 1, + length: 2, + }, + ] + ); +} + +#[test] +fn bitmap_interval_run_offset_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Bitmap(&[0b01100011], 8), + HybridEncoded::Bitmap(&[0b11111111], 8), + ] + .into_iter(), + vec![Interval::new(0, 1), Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 0, + length: 1, + }, + FilteredHybridEncoded::Skipped(3), + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 1, + length: 2, + }, + ] + ); +} + +#[test] +fn repeated_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![HybridEncoded::Repeated(true, 7)].into_iter(), + vec![Interval::new(1, 2)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Repeated { + is_set: true, + length: 2, + } + ] + ); +} + +#[test] +fn repeated_complete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![HybridEncoded::Repeated(true, 8)].into_iter(), + vec![Interval::new(0, 8)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![FilteredHybridEncoded::Repeated { + is_set: true, + length: 8, + }] + ); +} + +#[test] +fn repeated_interval_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Repeated(true, 8), + HybridEncoded::Repeated(false, 8), + ] + .into_iter(), + vec![Interval::new(0, 10)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 8, + }, + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + } + ] + ); +} + +#[test] +fn repeated_interval_run_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Repeated(true, 8), + HybridEncoded::Repeated(false, 8), + ] + .into_iter(), + vec![Interval::new(0, 5), Interval::new(7, 4)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 5, + }, + FilteredHybridEncoded::Skipped(2), + FilteredHybridEncoded::Repeated { + is_set: true, + length: 1, + }, + FilteredHybridEncoded::Repeated { + is_set: false, + length: 3, + } + ] + ); +} + +#[test] +fn repeated_interval_run_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Repeated(true, 8), + HybridEncoded::Repeated(false, 8), + ] + .into_iter(), + vec![Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(8), + FilteredHybridEncoded::Skipped(0), + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + }, + ] + ); +} + +#[test] +fn repeated_interval_run_offset_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + HybridEncoded::Repeated(true, 8), + HybridEncoded::Repeated(false, 8), + ] + .into_iter(), + vec![Interval::new(0, 1), Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::>(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 1, + }, + FilteredHybridEncoded::Skipped(7), + FilteredHybridEncoded::Skipped(0), + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + }, + ] + ); +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs b/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs new file mode 100644 index 000000000000..2160adaf6bc1 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs @@ -0,0 +1,48 @@ +use polars_parquet::parquet::encoding::get_length; +use polars_parquet::parquet::error::ParquetError; + +#[derive(Debug)] +pub struct BinaryPageDict { + values: Vec>, +} + +impl BinaryPageDict { + pub fn new(values: Vec>) -> Self { + Self { values } + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&[u8], ParquetError> { + self.values + .get(index) + .map(|x| x.as_ref()) + .ok_or_else(|| ParquetError::OutOfSpec("invalid index".to_string())) + } +} + +fn read_plain(bytes: &[u8], length: usize) -> Result>, ParquetError> { + let mut bytes = bytes; + let mut values = Vec::new(); + + for _ in 0..length { + let slot_length = get_length(bytes).unwrap(); + bytes = &bytes[4..]; + + if slot_length > bytes.len() { + return Err(ParquetError::OutOfSpec( + "The string on a dictionary page has a length that is out of bounds".to_string(), + )); + } + let (result, remaining) = bytes.split_at(slot_length); + + values.push(result.to_vec()); + bytes = remaining; + } + + Ok(values) +} + +pub fn read(buf: &[u8], num_values: usize) -> Result { + let values = read_plain(buf, num_values)?; + Ok(BinaryPageDict::new(values)) +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs b/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs new file mode 100644 index 000000000000..2cc0ed67caa6 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs @@ -0,0 +1,35 @@ +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; + +#[derive(Debug)] +pub struct FixedLenByteArrayPageDict { + values: Vec, + size: usize, +} + +impl FixedLenByteArrayPageDict { + pub fn new(values: Vec, size: usize) -> Self { + Self { values, size } + } + + #[inline] + pub fn value(&self, index: usize) -> ParquetResult<&[u8]> { + self.values + .get(index * self.size..(index + 1) * self.size) + .ok_or_else(|| { + ParquetError::OutOfSpec( + "The data page has an index larger than the dictionary page values".to_string(), + ) + }) + } +} + +pub fn read( + buf: &[u8], + size: usize, + num_values: usize, +) -> ParquetResult { + let length = size.saturating_mul(num_values); + let values = buf.get(..length).ok_or_else(|| ParquetError::OutOfSpec("Fixed sized binary declares a number of values times size larger than the page buffer".to_string()))?.to_vec(); + + Ok(FixedLenByteArrayPageDict::new(values, size)) +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs b/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs new file mode 100644 index 000000000000..4fd07db84941 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs @@ -0,0 +1,56 @@ +mod binary; +mod fixed_len_binary; +mod primitive; + +pub use binary::BinaryPageDict; +pub use fixed_len_binary::FixedLenByteArrayPageDict; +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; +use polars_parquet::parquet::page::DictPage; +use polars_parquet::parquet::schema::types::PhysicalType; +pub use primitive::PrimitivePageDict; + +pub enum DecodedDictPage { + Int32(PrimitivePageDict), + Int64(PrimitivePageDict), + Int96(PrimitivePageDict<[u32; 3]>), + Float(PrimitivePageDict), + Double(PrimitivePageDict), + ByteArray(BinaryPageDict), + FixedLenByteArray(FixedLenByteArrayPageDict), +} + +pub fn deserialize(page: &DictPage, physical_type: PhysicalType) -> ParquetResult { + _deserialize(&page.buffer, page.num_values, page.is_sorted, physical_type) +} + +fn _deserialize( + buf: &[u8], + num_values: usize, + is_sorted: bool, + physical_type: PhysicalType, +) -> ParquetResult { + match physical_type { + PhysicalType::Boolean => Err(ParquetError::OutOfSpec( + "Boolean physical type cannot be dictionary-encoded".to_string(), + )), + PhysicalType::Int32 => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Int32) + }, + PhysicalType::Int64 => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Int64) + }, + PhysicalType::Int96 => { + primitive::read::<[u32; 3]>(buf, num_values, is_sorted).map(DecodedDictPage::Int96) + }, + PhysicalType::Float => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Float) + }, + PhysicalType::Double => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Double) + }, + PhysicalType::ByteArray => binary::read(buf, num_values).map(DecodedDictPage::ByteArray), + PhysicalType::FixedLenByteArray(size) => { + fixed_len_binary::read(buf, size, num_values).map(DecodedDictPage::FixedLenByteArray) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs new file mode 100644 index 000000000000..d78bdd9abf83 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs @@ -0,0 +1,47 @@ +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; +use polars_parquet::parquet::types::{NativeType, decode}; + +#[derive(Debug)] +pub struct PrimitivePageDict { + values: Vec, +} + +impl PrimitivePageDict { + pub fn new(values: Vec) -> Self { + Self { values } + } + + pub fn values(&self) -> &[T] { + &self.values + } + + #[inline] + pub fn value(&self, index: usize) -> ParquetResult<&T> { + self.values.get(index).ok_or_else(|| { + ParquetError::OutOfSpec( + "The data page has an index larger than the dictionary page values".to_string(), + ) + }) + } +} + +pub fn read( + buf: &[u8], + num_values: usize, + _is_sorted: bool, +) -> ParquetResult> { + let size_of = size_of::(); + + let typed_size = num_values.wrapping_mul(size_of); + + let values = buf.get(..typed_size).ok_or_else(|| { + ParquetError::OutOfSpec( + "The number of values declared in the dict page does not match the length of the page" + .to_string(), + ) + })?; + + let values = values.chunks_exact(size_of).map(decode::).collect(); + + Ok(PrimitivePageDict::new(values)) +} diff --git a/crates/polars/tests/it/io/parquet/read/file.rs b/crates/polars/tests/it/io/parquet/read/file.rs new file mode 100644 index 000000000000..51af12b2a51a --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/file.rs @@ -0,0 +1,175 @@ +use std::io::{Read, Seek}; +use std::sync::Arc; + +use arrow::array::Array; +use arrow::datatypes::ArrowSchema; +use arrow::record_batch::RecordBatchT; +use polars_error::PolarsResult; +use polars_parquet::read::{Filter, RowGroupMetadata}; + +use super::row_group::{RowGroupDeserializer, read_columns_many}; + +/// An iterator of [`RecordBatchT`]s coming from row groups of a parquet file. +/// +/// This can be thought of a flatten chain of [`Iterator`] - each row group is sequentially +/// mapped to an [`Iterator`] and each iterator is iterated upon until either the limit +/// or the last iterator ends. +/// # Implementation +/// This iterator is single threaded on both IO-bounded and CPU-bounded tasks, and mixes them. +pub struct FileReader { + row_groups: RowGroupReader, + remaining_rows: usize, + current_row_group: Option, +} + +impl FileReader { + /// Returns a new [`FileReader`]. + pub fn new( + reader: R, + row_groups: Vec, + schema: ArrowSchema, + limit: Option, + ) -> Self { + let row_groups = RowGroupReader::new(reader, schema, row_groups, limit); + + Self { + row_groups, + remaining_rows: limit.unwrap_or(usize::MAX), + current_row_group: None, + } + } + + fn next_row_group(&mut self) -> PolarsResult> { + let result = self.row_groups.next().transpose()?; + + // If current_row_group is None, then there will be no elements to remove. + if self.current_row_group.is_some() { + self.remaining_rows = self.remaining_rows.saturating_sub( + result + .as_ref() + .map(|x| x.num_rows()) + .unwrap_or(self.remaining_rows), + ); + } + Ok(result) + } +} + +impl Iterator for FileReader { + type Item = PolarsResult>>; + + fn next(&mut self) -> Option { + if self.remaining_rows == 0 { + // reached the limit + return None; + } + + if let Some(row_group) = &mut self.current_row_group { + match row_group.next() { + // no more chunks in the current row group => try a new one + None => match self.next_row_group() { + Ok(Some(row_group)) => { + self.current_row_group = Some(row_group); + // new found => pull again + self.next() + }, + Ok(None) => { + self.current_row_group = None; + None + }, + Err(e) => Some(Err(e)), + }, + other => other, + } + } else { + match self.next_row_group() { + Ok(Some(row_group)) => { + self.current_row_group = Some(row_group); + self.next() + }, + Ok(None) => { + self.current_row_group = None; + None + }, + Err(e) => Some(Err(e)), + } + } + } +} + +/// An [`Iterator`] from row groups of a parquet file. +/// +/// # Implementation +/// Advancing this iterator is IO-bounded - each iteration reads all the column chunks from the file +/// to memory and attaches [`RowGroupDeserializer`] to them so that they can be iterated in chunks. +pub struct RowGroupReader { + reader: R, + schema: ArrowSchema, + row_groups: std::vec::IntoIter, + remaining_rows: usize, +} + +impl RowGroupReader { + /// Returns a new [`RowGroupReader`] + pub fn new( + reader: R, + schema: ArrowSchema, + row_groups: Vec, + limit: Option, + ) -> Self { + Self { + reader, + schema, + row_groups: row_groups.into_iter(), + remaining_rows: limit.unwrap_or(usize::MAX), + } + } + + #[inline] + fn _next(&mut self) -> PolarsResult> { + if self.schema.is_empty() { + return Ok(None); + } + if self.remaining_rows == 0 { + // reached the limit + return Ok(None); + } + + let row_group = if let Some(row_group) = self.row_groups.next() { + row_group + } else { + return Ok(None); + }; + + let num_rows = row_group.num_rows(); + + let column_schema = self.schema.iter_values().cloned().collect(); + let column_chunks = read_columns_many( + &mut self.reader, + &row_group, + &self.schema, + Some(Filter::new_limited(self.remaining_rows)), + )?; + + let result = RowGroupDeserializer::new( + Arc::new(column_schema), + column_chunks, + num_rows, + Some(self.remaining_rows), + ); + self.remaining_rows = self.remaining_rows.saturating_sub(num_rows); + Ok(Some(result)) + } +} + +impl Iterator for RowGroupReader { + type Item = PolarsResult; + + fn next(&mut self) -> Option { + self._next().transpose() + } + + fn size_hint(&self) -> (usize, Option) { + self.row_groups.size_hint() + } +} diff --git a/crates/polars/tests/it/io/parquet/read/fixed_binary.rs b/crates/polars/tests/it/io/parquet/read/fixed_binary.rs new file mode 100644 index 000000000000..6afe24bde21e --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/fixed_binary.rs @@ -0,0 +1,32 @@ +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::page::DataPage; + +use super::dictionary::FixedLenByteArrayPageDict; +use super::utils::{FixedLenBinaryPageState, deserialize_optional}; +use crate::io::parquet::read::hybrid_rle_iter; + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&FixedLenByteArrayPageDict>, +) -> ParquetResult>>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = FixedLenBinaryPageState::try_new(page, dict)?; + + match state { + FixedLenBinaryPageState::Optional(validity, values) => { + deserialize_optional(validity, values.map(|x| Ok(x.to_vec()))) + }, + FixedLenBinaryPageState::Required(values) => { + Ok(values.map(|x| x.to_vec()).map(Some).collect()) + }, + FixedLenBinaryPageState::RequiredDictionary(dict) => hybrid_rle_iter(dict.indexes)? + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec()).map(Some)) + .collect(), + FixedLenBinaryPageState::OptionalDictionary(validity, dict) => { + let values = hybrid_rle_iter(dict.indexes)? + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec())); + deserialize_optional(validity, values) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/mod.rs b/crates/polars/tests/it/io/parquet/read/mod.rs new file mode 100644 index 000000000000..cb7568676365 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/mod.rs @@ -0,0 +1,401 @@ +mod binary; +/// Serialization to Rust's Native types. +/// In comparison to Arrow, this in-memory format does not leverage logical types nor SIMD operations, +/// but OTOH it has no external dependencies and is very familiar to Rust developers. +mod boolean; +mod dictionary; +pub(crate) mod file; +mod fixed_binary; +mod primitive; +mod primitive_nested; +pub(crate) mod row_group; +mod struct_; +mod utils; + +use std::fs::File; + +use dictionary::DecodedDictPage; +use polars_parquet::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder}; +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; +use polars_parquet::parquet::metadata::ColumnChunkMetadata; +use polars_parquet::parquet::page::DataPage; +use polars_parquet::parquet::read::{BasicDecompressor, get_column_iterator, read_metadata}; +use polars_parquet::parquet::schema::Repetition; +use polars_parquet::parquet::schema::types::{GroupConvertedType, ParquetType}; +use polars_parquet::parquet::types::int96_to_i64_ns; +use polars_parquet::read::PageReader; +use polars_utils::mmap::MemReader; + +use super::*; + +pub fn hybrid_rle_iter(d: HybridRleDecoder) -> ParquetResult> { + Ok(d.collect()?.into_iter()) +} + +pub fn hybrid_rle_fn_collect( + d: HybridRleDecoder, + mut f: impl FnMut(u32) -> ParquetResult, +) -> ParquetResult> { + let mut target = Vec::with_capacity(d.len()); + + for chunk in d.into_chunk_iter() { + match chunk? { + HybridRleChunk::Rle(value, size) => { + target.resize(target.len() + size, f(value)?); + }, + HybridRleChunk::Bitpacked(mut decoder) => { + let mut chunked = decoder.chunked(); + for dchunk in chunked.by_ref() { + for v in dchunk { + target.push(f(v)?); + } + } + + if let Some((dchunk, l)) = chunked.remainder() { + for &v in &dchunk[..l] { + target.push(f(v)?); + } + } + }, + } + } + + Ok(target) +} + +pub fn get_path() -> PathBuf { + let dir = env!("CARGO_MANIFEST_DIR"); + PathBuf::from(dir).join("../../docs/assets/data") +} + +/// Reads a page into an [`Array`]. +/// This is CPU-intensive: decompress, decode and de-serialize. +pub fn page_to_array(page: &DataPage, dict: Option<&DecodedDictPage>) -> ParquetResult { + let physical_type = page.descriptor.primitive_type.physical_type; + match page.descriptor.max_rep_level { + 0 => match physical_type { + PhysicalType::Boolean => Ok(Array::Boolean(boolean::page_to_vec(page)?)), + PhysicalType::Int32 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int32(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int32) + }, + PhysicalType::Int64 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int64(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int64) + }, + PhysicalType::Int96 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int96(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int96) + }, + PhysicalType::Float => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Float(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Float) + }, + PhysicalType::Double => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Double(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Double) + }, + PhysicalType::ByteArray => { + let dict = dict.map(|dict| { + if let DecodedDictPage::ByteArray(dict) = dict { + dict + } else { + panic!() + } + }); + + binary::page_to_vec(page, dict).map(Array::Binary) + }, + PhysicalType::FixedLenByteArray(_) => { + let dict = dict.map(|dict| { + if let DecodedDictPage::FixedLenByteArray(dict) = dict { + dict + } else { + panic!() + } + }); + + fixed_binary::page_to_vec(page, dict).map(Array::FixedLenBinary) + }, + }, + _ => match dict { + None => match physical_type { + PhysicalType::Int64 => Ok(primitive_nested::page_to_array::(page, None)?), + _ => todo!(), + }, + Some(_) => match physical_type { + PhysicalType::Int64 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int64(dict) = dict { + dict + } else { + panic!() + } + }); + Ok(primitive_nested::page_dict_to_array(page, dict)?) + }, + _ => todo!(), + }, + }, + } +} + +/// Reads columns into an [`Array`]. +/// This is CPU-intensive: decompress, decode and de-serialize. +pub fn columns_to_array<'a, I>(mut columns: I, field: &ParquetType) -> ParquetResult +where + I: Iterator>, +{ + let mut validity = vec![]; + let mut has_filled = false; + let mut arrays = vec![]; + while let Some((pages, column)) = columns.next().transpose()? { + let mut iterator = BasicDecompressor::new(pages, vec![]); + + let dict = iterator + .read_dict_page()? + .map(|dict| dictionary::deserialize(&dict, column.physical_type())) + .transpose()?; + while let Some(page) = iterator.next().transpose()? { + let page = page.decompress(&mut iterator)?; + if !has_filled { + struct_::extend_validity(&mut validity, &page)?; + } + arrays.push(page_to_array(&page, dict.as_ref())?) + } + has_filled = true; + } + + match field { + ParquetType::PrimitiveType { .. } => arrays + .pop() + .ok_or_else(|| ParquetError::OutOfSpec("".to_string())), + ParquetType::GroupType { converted_type, .. } => { + if let Some(converted_type) = converted_type { + match converted_type { + GroupConvertedType::List => Ok(arrays.pop().unwrap()), + _ => todo!(), + } + } else { + Ok(Array::Struct(arrays, validity)) + } + }, + } +} + +pub fn read_column( + mut reader: MemReader, + row_group: usize, + field_name: &str, +) -> ParquetResult<(Array, Option)> { + let metadata = read_metadata(&mut reader)?; + + let field = metadata + .schema() + .fields() + .iter() + .find(|field| field.name() == field_name) + .ok_or_else(|| ParquetError::OutOfSpec("column does not exist".to_string()))?; + + let columns = get_column_iterator( + reader, + &metadata.row_groups[row_group], + field.name(), + usize::MAX, + ); + + let mut statistics = metadata.row_groups[row_group] + .columns_under_root_iter(field.name()) + .unwrap() + .map(|column_meta| column_meta.statistics().transpose()) + .collect::>>()?; + + let array = columns_to_array(columns, field)?; + + Ok((array, statistics.pop().unwrap())) +} + +fn get_column(path: &str, column: &str) -> ParquetResult<(Array, Option)> { + let file = File::open(path).unwrap(); + let memreader = MemReader::from_reader(file).unwrap(); + read_column(memreader, 0, column) +} + +fn test_column(column: &str) -> ParquetResult<()> { + let mut path = get_path(); + path.push("alltypes_plain.parquet"); + let path = path.to_str().unwrap(); + let (result, statistics) = get_column(path, column)?; + // the file does not have statistics + assert_eq!(statistics.as_ref(), None); + assert_eq!(result, alltypes_plain(column)); + Ok(()) +} + +#[test] +fn int32() -> ParquetResult<()> { + test_column("id") +} + +#[test] +fn bool() -> ParquetResult<()> { + test_column("bool_col") +} + +#[test] +fn tinyint_col() -> ParquetResult<()> { + test_column("tinyint_col") +} + +#[test] +fn smallint_col() -> ParquetResult<()> { + test_column("smallint_col") +} + +#[test] +fn int_col() -> ParquetResult<()> { + test_column("int_col") +} + +#[test] +fn bigint_col() -> ParquetResult<()> { + test_column("bigint_col") +} + +#[test] +fn float_col() -> ParquetResult<()> { + test_column("float_col") +} + +#[test] +fn double_col() -> ParquetResult<()> { + test_column("double_col") +} + +#[test] +fn timestamp_col() -> ParquetResult<()> { + let mut path = get_path(); + path.push("alltypes_plain.parquet"); + let path = path.to_str().unwrap(); + + let expected = vec![ + 1235865600000000000i64, + 1235865660000000000, + 1238544000000000000, + 1238544060000000000, + 1233446400000000000, + 1233446460000000000, + 1230768000000000000, + 1230768060000000000, + ]; + + let expected = expected.into_iter().map(Some).collect::>(); + let (array, _) = get_column(path, "timestamp_col")?; + if let Array::Int96(array) = array { + let a = array + .into_iter() + .map(|x| x.map(int96_to_i64_ns)) + .collect::>(); + assert_eq!(expected, a); + } else { + panic!("Timestamp expected"); + }; + Ok(()) +} + +#[test] +fn test_metadata() -> ParquetResult<()> { + let mut testdata = get_path(); + testdata.push("alltypes_plain.parquet"); + let mut file = File::open(testdata).unwrap(); + + let metadata = read_metadata(&mut file)?; + + let columns = metadata.schema_descr.columns(); + + /* + from pyarrow: + required group field_id=0 schema { + optional int32 field_id=1 id; + optional boolean field_id=2 bool_col; + optional int32 field_id=3 tinyint_col; + optional int32 field_id=4 smallint_col; + optional int32// pub enum Value { + // UInt32(Option), + // Int32(Option), + // Int64(Option), + // Int96(Option<[u32; 3]>), + // Float32(Option), + // Float64(Option), + // Boolean(Option), + // Binary(Option>), + // FixedLenBinary(Option>), + // List(Option), + // } + field_id=5 int_col; + optional int64 field_id=6 bigint_col; + optional float field_id=7 float_col; + optional double field_id=8 double_col; + optional binary field_id=9 date_string_col; + optional binary field_id=10 string_col; + optional int96 field_id=11 timestamp_col; + } + */ + let expected = vec![ + PhysicalType::Int32, + PhysicalType::Boolean, + PhysicalType::Int32, + PhysicalType::Int32, + PhysicalType::Int32, + PhysicalType::Int64, + PhysicalType::Float, + PhysicalType::Double, + PhysicalType::ByteArray, + PhysicalType::ByteArray, + PhysicalType::Int96, + ]; + + let result = columns + .iter() + .map(|column| { + assert_eq!( + column.descriptor.primitive_type.field_info.repetition, + Repetition::Optional + ); + column.descriptor.primitive_type.physical_type + }) + .collect::>(); + + assert_eq!(expected, result); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/primitive.rs b/crates/polars/tests/it/io/parquet/read/primitive.rs new file mode 100644 index 000000000000..ef450c1e4009 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/primitive.rs @@ -0,0 +1,56 @@ +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::page::DataPage; +use polars_parquet::parquet::types::NativeType; +use polars_parquet::read::ParquetError; + +use super::dictionary::PrimitivePageDict; +use super::hybrid_rle_iter; +use super::utils::{NativePageState, deserialize_optional}; +use crate::io::parquet::read::hybrid_rle_fn_collect; + +/// The deserialization state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum PageState<'a, T> +where + T: NativeType, +{ + Nominal(NativePageState<'a, T, &'a PrimitivePageDict>), +} + +impl<'a, T: NativeType> PageState<'a, T> { + /// Tries to create [`NativePageState`] + /// # Error + /// Errors iff the page is not a `NativePageState` + pub fn try_new( + page: &'a DataPage, + dict: Option<&'a PrimitivePageDict>, + ) -> Result { + NativePageState::try_new(page, dict).map(Self::Nominal) + } +} + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> ParquetResult>> { + assert_eq!(page.descriptor.max_rep_level, 0); + let state = PageState::::try_new(page, dict)?; + + match state { + PageState::Nominal(state) => match state { + NativePageState::Optional(validity, mut values) => { + deserialize_optional(validity, values.by_ref().map(Ok)) + }, + NativePageState::Required(values) => Ok(values.map(Some).collect()), + NativePageState::RequiredDictionary(dict) => hybrid_rle_fn_collect(dict.indexes, |x| { + dict.dict.value(x as usize).copied().map(Some) + }), + NativePageState::OptionalDictionary(validity, dict) => { + let values = + hybrid_rle_iter(dict.indexes)?.map(|x| dict.dict.value(x as usize).copied()); + deserialize_optional(validity, values) + }, + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs new file mode 100644 index 000000000000..d7faaf6a9338 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs @@ -0,0 +1,284 @@ +use polars_parquet::parquet::encoding::bitpacked::{Unpackable, Unpacked}; +use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; +use polars_parquet::parquet::encoding::{Encoding, bitpacked, uleb128}; +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; +use polars_parquet::parquet::page::{DataPage, EncodedSplitBuffer, split_buffer}; +use polars_parquet::parquet::read::levels::get_bit_width; +use polars_parquet::parquet::types::NativeType; + +use super::dictionary::PrimitivePageDict; +use super::{Array, hybrid_rle_iter}; + +fn read_buffer(values: &[u8]) -> impl Iterator + '_ { + let chunks = values.chunks_exact(size_of::()); + chunks.map(|chunk| { + // unwrap is infalible due to the chunk size. + let chunk: T::Bytes = match chunk.try_into() { + Ok(v) => v, + Err(_) => panic!(), + }; + T::from_le_bytes(chunk) + }) +} + +// todo: generalize i64 -> T +fn compose_array, F: Iterator, G: Iterator>( + rep_levels: I, + def_levels: F, + max_rep: u32, + max_def: u32, + mut values: G, +) -> Result { + let mut outer = vec![]; + let mut inner = vec![]; + + assert_eq!(max_rep, 1); + assert_eq!(max_def, 3); + let mut prev_def = 0; + rep_levels + .into_iter() + .zip(def_levels.into_iter()) + .try_for_each(|(rep, def)| { + match rep { + 1 => {}, + 0 => { + if prev_def > 1 { + let old = std::mem::take(&mut inner); + outer.push(Some(Array::Int64(old))); + } + }, + _ => unreachable!(), + } + match def { + 3 => inner.push(Some(values.next().unwrap())), + 2 => inner.push(None), + 1 => outer.push(Some(Array::Int64(vec![]))), + 0 => outer.push(None), + _ => unreachable!(), + } + prev_def = def; + Ok::<(), ParquetError>(()) + })?; + outer.push(Some(Array::Int64(inner))); + Ok(Array::List(outer)) +} + +fn read_array_impl>( + rep_levels: &[u8], + def_levels: &[u8], + values: I, + length: usize, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let max_rep_level = rep_level_encoding.1 as u32; + let max_def_level = def_level_encoding.1 as u32; + + match ( + (rep_level_encoding.0, max_rep_level == 0), + (def_level_encoding.0, max_def_level == 0), + ) { + ((Encoding::Rle, true), (Encoding::Rle, true)) => compose_array( + std::iter::repeat_n(0, length), + std::iter::repeat_n(0, length), + max_rep_level, + max_def_level, + values, + ), + ((Encoding::Rle, false), (Encoding::Rle, true)) => { + let num_bits = get_bit_width(rep_level_encoding.1); + let rep_levels = HybridRleDecoder::new(rep_levels, num_bits, length); + compose_array( + hybrid_rle_iter(rep_levels)?, + std::iter::repeat_n(0, length), + max_rep_level, + max_def_level, + values, + ) + }, + ((Encoding::Rle, true), (Encoding::Rle, false)) => { + let num_bits = get_bit_width(def_level_encoding.1); + let def_levels = HybridRleDecoder::new(def_levels, num_bits, length); + compose_array( + std::iter::repeat_n(0, length), + hybrid_rle_iter(def_levels)?, + max_rep_level, + max_def_level, + values, + ) + }, + ((Encoding::Rle, false), (Encoding::Rle, false)) => { + let rep_levels = + HybridRleDecoder::new(rep_levels, get_bit_width(rep_level_encoding.1), length); + let def_levels = + HybridRleDecoder::new(def_levels, get_bit_width(def_level_encoding.1), length); + compose_array( + hybrid_rle_iter(rep_levels)?, + hybrid_rle_iter(def_levels)?, + max_rep_level, + max_def_level, + values, + ) + }, + _ => todo!(), + } +} + +fn read_array( + rep_levels: &[u8], + def_levels: &[u8], + values: &[u8], + length: u32, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let values = read_buffer::(values); + read_array_impl::<_>( + rep_levels, + def_levels, + values, + length as usize, + rep_level_encoding, + def_level_encoding, + ) +} + +pub fn page_to_array( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result { + let EncodedSplitBuffer { + rep: rep_levels, + def: def_levels, + values, + } = split_buffer(page)?; + + match (&page.encoding(), dict) { + (Encoding::Plain, None) => read_array( + rep_levels, + def_levels, + values, + page.num_values() as u32, + ( + &page.repetition_level_encoding(), + page.descriptor.max_rep_level, + ), + ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ), + ), + _ => todo!(), + } +} + +pub struct DecoderIter<'a, T: Unpackable> { + pub(crate) decoder: bitpacked::Decoder<'a, T>, + pub(crate) buffered: T::Unpacked, + pub(crate) unpacked_start: usize, + pub(crate) unpacked_end: usize, +} + +impl Iterator for DecoderIter<'_, T> { + type Item = T; + + fn next(&mut self) -> Option { + if self.unpacked_start >= self.unpacked_end { + let length; + (self.buffered, length) = self.decoder.chunked().next_inexact()?; + debug_assert!(length > 0); + self.unpacked_start = 1; + self.unpacked_end = length; + return Some(self.buffered[0]); + } + + let v = self.buffered[self.unpacked_start]; + self.unpacked_start += 1; + Some(v) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.decoder.len() + self.unpacked_end - self.unpacked_start; + (len, Some(len)) + } +} + +impl ExactSizeIterator for DecoderIter<'_, T> {} + +impl<'a, T: Unpackable> DecoderIter<'a, T> { + pub fn new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { + assert!(num_bits > 0); + Ok(Self { + decoder: bitpacked::Decoder::try_new(packed, num_bits, length)?, + buffered: T::Unpacked::zero(), + unpacked_start: 0, + unpacked_end: 0, + }) + } +} + +fn read_dict_array( + rep_levels: &[u8], + def_levels: &[u8], + values: &[u8], + length: u32, + dict: &PrimitivePageDict, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let dict_values = dict.values(); + + let bit_width = values[0]; + let values = &values[1..]; + + let (_, consumed) = uleb128::decode(values); + let values = &values[consumed..]; + + let indices = DecoderIter::::new(values, bit_width as usize, length as usize)?; + + let values = indices.map(|id| dict_values[id as usize]); + + read_array_impl::<_>( + rep_levels, + def_levels, + values, + length as usize, + rep_level_encoding, + def_level_encoding, + ) +} + +pub fn page_dict_to_array( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result { + assert_eq!(page.descriptor.max_rep_level, 1); + + let EncodedSplitBuffer { + rep: rep_levels, + def: def_levels, + values, + } = split_buffer(page)?; + + match (page.encoding(), dict) { + (Encoding::PlainDictionary, Some(dict)) => read_dict_array( + rep_levels, + def_levels, + values, + page.num_values() as u32, + dict, + ( + &page.repetition_level_encoding(), + page.descriptor.max_rep_level, + ), + ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ), + ), + (_, None) => Err(ParquetError::OutOfSpec( + "A dictionary-encoded page MUST be preceded by a dictionary page".to_string(), + )), + _ => todo!(), + } +} diff --git a/crates/polars/tests/it/io/parquet/read/row_group.rs b/crates/polars/tests/it/io/parquet/read/row_group.rs new file mode 100644 index 000000000000..bef9f8833016 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/row_group.rs @@ -0,0 +1,167 @@ +use std::io::{Read, Seek}; + +use arrow::array::Array; +use arrow::datatypes::{ArrowSchemaRef, Field}; +use arrow::record_batch::RecordBatchT; +use polars::prelude::ArrowSchema; +use polars_error::PolarsResult; +use polars_parquet::arrow::read::{Filter, column_iter_to_arrays}; +use polars_parquet::parquet::metadata::ColumnChunkMetadata; +use polars_parquet::parquet::read::{BasicDecompressor, PageReader}; +use polars_parquet::read::RowGroupMetadata; +use polars_utils::mmap::MemReader; + +/// An [`Iterator`] of [`RecordBatchT`] that (dynamically) adapts a vector of iterators of [`Array`] into +/// an iterator of [`RecordBatchT`]. +/// +/// This struct tracks advances each of the iterators individually and combines the +/// result in a single [`RecordBatchT`]. +/// +/// # Implementation +/// This iterator is single-threaded and advancing it is CPU-bounded. +pub struct RowGroupDeserializer { + num_rows: usize, + remaining_rows: usize, + column_schema: ArrowSchemaRef, + column_chunks: Vec>, +} + +impl RowGroupDeserializer { + /// Creates a new [`RowGroupDeserializer`]. + /// + /// # Panic + /// This function panics iff any of the `column_chunks` + /// do not return an array with an equal length. + pub fn new( + column_schema: ArrowSchemaRef, + column_chunks: Vec>, + num_rows: usize, + limit: Option, + ) -> Self { + Self { + num_rows, + remaining_rows: limit.unwrap_or(usize::MAX).min(num_rows), + column_schema, + column_chunks, + } + } + + /// Returns the number of rows on this row group + pub fn num_rows(&self) -> usize { + self.num_rows + } +} + +impl Iterator for RowGroupDeserializer { + type Item = PolarsResult>>; + + fn next(&mut self) -> Option { + if self.remaining_rows == 0 { + return None; + } + let length = self.column_chunks.first().map_or(0, |chunk| chunk.len()); + let chunk = RecordBatchT::try_new( + length, + self.column_schema.clone(), + std::mem::take(&mut self.column_chunks), + ); + self.remaining_rows = self.remaining_rows.saturating_sub( + chunk + .as_ref() + .map(|x| x.len()) + .unwrap_or(self.remaining_rows), + ); + + Some(chunk) + } +} + +/// Reads all columns that are part of the parquet field `field_name` +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns associated to +/// the field (one for non-nested types) +pub fn read_columns<'a, R: Read + Seek>( + reader: &mut R, + row_group_metadata: &'a RowGroupMetadata, + field_name: &'a str, +) -> PolarsResult)>> { + row_group_metadata + .columns_under_root_iter(field_name) + .unwrap() + .map(|meta| _read_single_column(reader, meta)) + .collect() +} + +fn _read_single_column<'a, R>( + reader: &mut R, + meta: &'a ColumnChunkMetadata, +) -> PolarsResult<(&'a ColumnChunkMetadata, Vec)> +where + R: Read + Seek, +{ + let byte_range = meta.byte_range(); + let length = byte_range.end - byte_range.start; + reader.seek(std::io::SeekFrom::Start(byte_range.start))?; + + let mut chunk = vec![]; + chunk.try_reserve(length as usize)?; + reader.by_ref().take(length).read_to_end(&mut chunk)?; + Ok((meta, chunk)) +} + +/// Converts a vector of columns associated with the parquet field whose name is [`Field`] +/// to an iterator of [`Array`], [`ArrayIter`] of chunk size `chunk_size`. +pub fn to_deserializer( + columns: Vec<(&ColumnChunkMetadata, Vec)>, + field: Field, + filter: Option, +) -> PolarsResult> { + let (columns, types): (Vec<_>, Vec<_>) = columns + .into_iter() + .map(|(column_meta, chunk)| { + let len = chunk.len(); + let pages = PageReader::new( + MemReader::from_vec(chunk), + column_meta, + vec![], + len * 2 + 1024, + ); + ( + BasicDecompressor::new(pages, vec![]), + &column_meta.descriptor().descriptor.primitive_type, + ) + }) + .unzip(); + + column_iter_to_arrays(columns, types, field, filter).map(|v| v.0) +} + +/// Returns a vector of iterators of [`Array`] ([`ArrayIter`]) corresponding to the top +/// level parquet fields whose name matches `fields`'s names. +/// +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns in the row group - +/// it reads all the columns to memory from the row group associated to the requested fields. +/// +/// This operation is single-threaded. For readers with stronger invariants +/// (e.g. implement [`Clone`]) you can use [`read_columns`] to read multiple columns at once +/// and convert them to [`ArrayIter`] via [`to_deserializer`]. +pub fn read_columns_many( + reader: &mut R, + row_group: &RowGroupMetadata, + fields: &ArrowSchema, + filter: Option, +) -> PolarsResult>> { + // reads all the necessary columns for all fields from the row group + // This operation is IO-bounded `O(C)` where C is the number of columns in the row group + let field_columns = fields + .iter_values() + .map(|field| read_columns(reader, row_group, &field.name)) + .collect::>>()?; + + field_columns + .into_iter() + .zip(fields.iter_values().cloned()) + .map(|(columns, field)| to_deserializer(columns.clone(), field, filter.clone())) + .collect() +} diff --git a/crates/polars/tests/it/io/parquet/read/struct_.rs b/crates/polars/tests/it/io/parquet/read/struct_.rs new file mode 100644 index 000000000000..d7849920bb54 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/struct_.rs @@ -0,0 +1,32 @@ +use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::page::{DataPage, EncodedSplitBuffer, split_buffer}; +use polars_parquet::parquet::read::levels::get_bit_width; + +use super::hybrid_rle_iter; + +pub fn extend_validity(val: &mut Vec, page: &DataPage) -> ParquetResult<()> { + let EncodedSplitBuffer { + rep: _, + def: def_levels, + values: _, + } = split_buffer(page)?; + let length = page.num_values(); + + if page.descriptor.max_def_level == 0 { + return Ok(()); + } + + let def_level_encoding = ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ); + + let def_levels = HybridRleDecoder::new(def_levels, get_bit_width(def_level_encoding.1), length); + + val.reserve(length); + hybrid_rle_iter(def_levels)?.for_each(|x| { + val.push(x != 0); + }); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/utils.rs b/crates/polars/tests/it/io/parquet/read/utils.rs new file mode 100644 index 000000000000..45294beab979 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/utils.rs @@ -0,0 +1,290 @@ +use polars_parquet::parquet::encoding::hybrid_rle::{self, BitmapIter, HybridRleDecoder}; +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; +use polars_parquet::parquet::page::{DataPage, EncodedSplitBuffer, split_buffer}; +use polars_parquet::parquet::read::levels::get_bit_width; +use polars_parquet::parquet::schema::Repetition; +use polars_parquet::parquet::types::{NativeType, decode}; +use polars_parquet::read::PhysicalType; +use polars_parquet::write::Encoding; + +pub(super) fn dict_indices_decoder(page: &DataPage) -> ParquetResult { + let EncodedSplitBuffer { + rep: _, + def: _, + values: indices_buffer, + } = split_buffer(page)?; + + // SPEC: Data page format: the bit width used to encode the entry ids stored as 1 byte (max bit width = 32), + // SPEC: followed by the values encoded using RLE/Bit packed described above (with the given bit width). + let bit_width = indices_buffer[0]; + if bit_width > 32 { + panic!("Bit width of dictionary pages cannot be larger than 32",); + } + let indices_buffer = &indices_buffer[1..]; + + Ok(hybrid_rle::HybridRleDecoder::new( + indices_buffer, + bit_width as u32, + page.num_values(), + )) +} + +/// Decoder of definition levels. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum DefLevelsDecoder<'a> { + /// When the maximum definition level is larger than 1 + Levels(HybridRleDecoder<'a>, u32), +} + +impl<'a> DefLevelsDecoder<'a> { + pub fn try_new(page: &'a DataPage) -> ParquetResult { + let EncodedSplitBuffer { + rep: _, + def: def_levels, + values: _, + } = split_buffer(page)?; + + let max_def_level = page.descriptor.max_def_level; + Ok({ + let iter = + HybridRleDecoder::new(def_levels, get_bit_width(max_def_level), page.num_values()); + Self::Levels(iter, max_def_level as u32) + }) + } +} + +pub fn deserialize_optional>>( + validity: DefLevelsDecoder, + values: I, +) -> ParquetResult>> { + match validity { + DefLevelsDecoder::Levels(levels, max_level) => { + deserialize_levels(levels, max_level, values) + }, + } +} + +fn deserialize_levels>>( + levels: HybridRleDecoder, + max: u32, + mut values: I, +) -> Result>, ParquetError> { + levels + .collect()? + .into_iter() + .map(|x| { + if x == max { + values.next().transpose() + } else { + Ok(None) + } + }) + .collect() +} + +#[derive(Debug)] +pub struct FixexBinaryIter<'a> { + values: std::slice::ChunksExact<'a, u8>, +} + +impl<'a> FixexBinaryIter<'a> { + pub fn new(values: &'a [u8], size: usize) -> Self { + let values = values.chunks_exact(size); + Self { values } + } +} + +impl<'a> Iterator for FixexBinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + self.values.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.values.size_hint() + } +} + +#[derive(Debug)] +pub struct Dictionary<'a, P> { + pub indexes: hybrid_rle::HybridRleDecoder<'a>, + pub dict: P, +} + +impl<'a, P> Dictionary<'a, P> { + pub fn try_new(page: &'a DataPage, dict: P) -> ParquetResult { + let indexes = dict_indices_decoder(page)?; + + Ok(Self { indexes, dict }) + } +} + +#[allow(clippy::large_enum_variant)] +pub enum FixedLenBinaryPageState<'a, P> { + Optional(DefLevelsDecoder<'a>, FixexBinaryIter<'a>), + Required(FixexBinaryIter<'a>), + RequiredDictionary(Dictionary<'a, P>), + OptionalDictionary(DefLevelsDecoder<'a>, Dictionary<'a, P>), +} + +impl<'a, P> FixedLenBinaryPageState<'a, P> { + pub fn try_new(page: &'a DataPage, dict: Option

) -> ParquetResult { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + let size: usize = if let PhysicalType::FixedLenByteArray(size) = + page.descriptor.primitive_type.physical_type + { + size + } else { + return Err(ParquetError::InvalidParameter( + "FixedLenBinaryPageState must be initialized by pages of FixedLenByteArray" + .to_string(), + )); + }; + + match (page.encoding(), dict, is_optional) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false) => { + Dictionary::try_new(page, dict).map(Self::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true) => { + Ok(Self::OptionalDictionary( + DefLevelsDecoder::try_new(page)?, + Dictionary::try_new(page, dict)?, + )) + }, + (Encoding::Plain, _, true) => { + let EncodedSplitBuffer { + rep: _, + def: _, + values, + } = split_buffer(page)?; + + let validity = DefLevelsDecoder::try_new(page)?; + let values = FixexBinaryIter::new(values, size); + + Ok(Self::Optional(validity, values)) + }, + (Encoding::Plain, _, false) => { + let EncodedSplitBuffer { + rep: _, + def: _, + values, + } = split_buffer(page)?; + let values = FixexBinaryIter::new(values, size); + + Ok(Self::Required(values)) + }, + _ => Err(ParquetError::FeatureNotSupported(format!( + "Viewing page for encoding {:?} for binary type", + page.encoding(), + ))), + } + } +} + +/// Typedef of an iterator over PLAIN page values +pub type Casted<'a, T> = std::iter::Map, fn(&'a [u8]) -> T>; + +/// Views the values of the data page as [`Casted`] to [`NativeType`]. +pub fn native_cast(page: &DataPage) -> ParquetResult> { + let EncodedSplitBuffer { + rep: _, + def: _, + values, + } = split_buffer(page)?; + if values.len() % size_of::() != 0 { + panic!("A primitive page data's len must be a multiple of the type"); + } + + Ok(values.chunks_exact(size_of::()).map(decode::)) +} + +/// The deserialization state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum NativePageState<'a, T, P> +where + T: NativeType, +{ + /// A page of optional values + Optional(DefLevelsDecoder<'a>, Casted<'a, T>), + /// A page of required values + Required(Casted<'a, T>), + /// A page of required, dictionary-encoded values + RequiredDictionary(Dictionary<'a, P>), + /// A page of optional, dictionary-encoded values + OptionalDictionary(DefLevelsDecoder<'a>, Dictionary<'a, P>), +} + +impl<'a, T: NativeType, P> NativePageState<'a, T, P> { + /// Tries to create [`NativePageState`] + /// # Error + /// Errors iff the page is not a `NativePageState` + pub fn try_new(page: &'a DataPage, dict: Option

) -> ParquetResult { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + match (page.encoding(), dict, is_optional) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false) => { + Dictionary::try_new(page, dict).map(Self::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true) => { + Ok(Self::OptionalDictionary( + DefLevelsDecoder::try_new(page)?, + Dictionary::try_new(page, dict)?, + )) + }, + (Encoding::Plain, _, true) => { + let validity = DefLevelsDecoder::try_new(page)?; + let values = native_cast(page)?; + + Ok(Self::Optional(validity, values)) + }, + (Encoding::Plain, _, false) => native_cast(page).map(Self::Required), + _ => Err(ParquetError::FeatureNotSupported(format!( + "Viewing page for encoding {:?} for native type {}", + page.encoding(), + std::any::type_name::() + ))), + } + } +} + +// The state of a `DataPage` of `Boolean` parquet boolean type +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum BooleanPageState<'a> { + Optional(DefLevelsDecoder<'a>, BitmapIter<'a>), + Required(&'a [u8], usize), +} + +impl<'a> BooleanPageState<'a> { + pub fn try_new(page: &'a DataPage) -> ParquetResult { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + match (page.encoding(), is_optional) { + (Encoding::Plain, true) => { + let validity = DefLevelsDecoder::try_new(page)?; + + let values = split_buffer(page)?.values; + let values = BitmapIter::new(values, 0, values.len() * 8); + + Ok(Self::Optional(validity, values)) + }, + (Encoding::Plain, false) => { + let values = split_buffer(page)?.values; + Ok(Self::Required(values, page.num_values())) + }, + _ => Err(ParquetError::InvalidParameter(format!( + "Viewing page for encoding {:?} for boolean type not supported", + page.encoding(), + ))), + } + } +} diff --git a/crates/polars/tests/it/io/parquet/roundtrip.rs b/crates/polars/tests/it/io/parquet/roundtrip.rs new file mode 100644 index 000000000000..62a8ba412cab --- /dev/null +++ b/crates/polars/tests/it/io/parquet/roundtrip.rs @@ -0,0 +1,86 @@ +use std::io::Cursor; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Utf8ViewArray}; +use arrow::datatypes::{ArrowSchema, Field}; +use arrow::record_batch::RecordBatchT; +use polars_error::PolarsResult; +use polars_parquet::arrow::write::{FileWriter, WriteOptions}; +use polars_parquet::read::read_metadata; +use polars_parquet::write::{ + CompressionOptions, Encoding, RowGroupIterator, StatisticsOptions, Version, +}; + +use crate::io::parquet::read::file::FileReader; + +fn round_trip( + array: &ArrayRef, + version: Version, + compression: CompressionOptions, + encodings: Vec, +) -> PolarsResult<()> { + let field = Field::new("a1".into(), array.dtype().clone(), true); + let schema = ArrowSchema::from_iter([field]); + + let options = WriteOptions { + statistics: StatisticsOptions::full(), + compression, + version, + data_page_size: None, + }; + + let iter = vec![RecordBatchT::try_new( + array.len(), + Arc::new(schema.clone()), + vec![array.clone()], + )]; + + let row_groups = + RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?; + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(writer, schema.clone(), options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + let mut reader = Cursor::new(data); + let md = read_metadata(&mut reader).unwrap(); + // say we found that we only need to read the first two row groups, "0" and "1" + let row_groups = md + .row_groups + .into_iter() + .enumerate() + .filter(|(index, _)| *index == 0 || *index == 1) + .map(|(_, row_group)| row_group) + .collect(); + + // we can then read the row groups into chunks + let chunks = FileReader::new(reader, row_groups, schema, None); + + let mut arrays = vec![]; + for chunk in chunks { + let chunk = chunk?; + arrays.push(chunk.first().unwrap().clone()) + } + assert_eq!(arrays.len(), 1); + + assert_eq!(array.as_ref(), arrays[0].as_ref()); + Ok(()) +} + +#[test] +fn roundtrip_binview() -> PolarsResult<()> { + let array = Utf8ViewArray::from_slice([Some("foo"), Some("bar"), None, Some("hamlet")]); + + round_trip( + &array.boxed(), + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} diff --git a/crates/polars/tests/it/io/parquet/write/binary.rs b/crates/polars/tests/it/io/parquet/write/binary.rs new file mode 100644 index 000000000000..31afccc8b425 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/binary.rs @@ -0,0 +1,88 @@ +use polars_parquet::parquet::CowBuffer; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::encoding::hybrid_rle::encode; +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::metadata::Descriptor; +use polars_parquet::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; +use polars_parquet::parquet::statistics::BinaryStatistics; +use polars_parquet::parquet::types::ord_binary; +use polars_parquet::parquet::write::WriteOptions; + +fn unzip_option(array: &[Option>]) -> ParquetResult<(Vec, Vec)> { + // leave the first 4 bytes announcing the length of the def level + // this will be overwritten at the end, once the length is known. + // This is unknown at this point because of the uleb128 encoding, + // whose length is variable. + let mut validity = std::io::Cursor::new(vec![0; 4]); + validity.set_position(4); + + let mut values = vec![]; + let iter = array.iter().map(|value| { + if let Some(item) = value { + values.extend_from_slice(&(item.len() as i32).to_le_bytes()); + values.extend_from_slice(item.as_ref()); + true + } else { + false + } + }); + encode::(&mut validity, iter, 1)?; + + // write the length, now that it is known + let mut validity = validity.into_inner(); + let length = validity.len() - 4; + // todo: pay this small debt (loop?) + let length = length.to_le_bytes(); + validity[0] = length[0]; + validity[1] = length[1]; + validity[2] = length[2]; + validity[3] = length[3]; + + Ok((values, validity)) +} + +pub fn array_to_page_v1( + array: &[Option>], + options: &WriteOptions, + descriptor: &Descriptor, +) -> ParquetResult { + let (values, mut buffer) = unzip_option(array)?; + + buffer.extend_from_slice(&values); + + let statistics = if options.write_statistics { + let statistics = &BinaryStatistics { + primitive_type: descriptor.primitive_type.clone(), + null_count: Some((array.len() - array.iter().flatten().count()) as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .cloned(), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .cloned(), + }; + Some(statistics.serialize()) + } else { + None + }; + + let header = DataPageHeaderV1 { + num_values: array.len() as i32, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }; + + Ok(Page::Data(DataPage::new( + DataPageHeader::V1(header), + CowBuffer::Owned(buffer), + descriptor.clone(), + array.len(), + ))) +} diff --git a/crates/polars/tests/it/io/parquet/write/mod.rs b/crates/polars/tests/it/io/parquet/write/mod.rs new file mode 100644 index 000000000000..63f10d7a5704 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/mod.rs @@ -0,0 +1,299 @@ +mod binary; +mod primitive; +mod sidecar; + +use std::io::{Cursor, Read, Seek}; + +use polars::io::SerReader; +use polars::io::parquet::read::ParquetReader; +use polars::io::parquet::write::ParquetWriter; +use polars_core::df; +use polars_core::prelude::*; +use polars_parquet::parquet::compression::{BrotliLevel, CompressionOptions}; +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::metadata::{Descriptor, SchemaDescriptor}; +use polars_parquet::parquet::page::Page; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType}; +use polars_parquet::parquet::statistics::Statistics; +use polars_parquet::parquet::write::{ + Compressor, DynIter, DynStreamingIterator, FileWriter, Version, WriteOptions, +}; +use polars_parquet::read::read_metadata; +use polars_utils::mmap::MemReader; +use primitive::array_to_page_v1; + +use super::{Array, alltypes_plain, alltypes_statistics}; + +pub fn array_to_page( + array: &Array, + options: &WriteOptions, + descriptor: &Descriptor, +) -> ParquetResult { + // using plain encoding format + match array { + Array::Int32(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Int64(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Int96(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Float(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Double(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Binary(array) => binary::array_to_page_v1(array, options, descriptor), + _ => todo!(), + } +} + +fn read_column(reader: &mut R) -> ParquetResult<(Array, Option)> { + let memreader = MemReader::from_reader(reader)?; + let (a, statistics) = super::read::read_column(memreader, 0, "col")?; + Ok((a, statistics)) +} + +fn test_column(column: &str, compression: CompressionOptions) -> ParquetResult<()> { + let array = alltypes_plain(column); + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + // prepare schema + let type_ = match array { + Array::Int32(_) => PhysicalType::Int32, + Array::Int64(_) => PhysicalType::Int64, + Array::Int96(_) => PhysicalType::Int96, + Array::Float(_) => PhysicalType::Float, + Array::Double(_) => PhysicalType::Double, + Array::Binary(_) => PhysicalType::ByteArray, + _ => todo!(), + }; + + let schema = SchemaDescriptor::new( + "schema".into(), + vec![ParquetType::from_physical("col".into(), type_)], + ); + + let a = schema.columns(); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page( + &array, + &options, + &a[0].descriptor, + ))), + compression, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + let (result, statistics) = read_column(&mut Cursor::new(data))?; + assert_eq!(array, result); + let stats = alltypes_statistics(column); + assert_eq!(statistics.as_ref(), Some(stats).as_ref(),); + Ok(()) +} + +#[test] +fn int32() -> ParquetResult<()> { + test_column("id", CompressionOptions::Uncompressed) +} + +#[test] +fn int32_snappy() -> ParquetResult<()> { + test_column("id", CompressionOptions::Snappy) +} + +#[test] +fn int32_lz4() -> ParquetResult<()> { + test_column("id", CompressionOptions::Lz4Raw) +} + +#[test] +fn int32_lz4_short_i32_array() -> ParquetResult<()> { + test_column("id-short-array", CompressionOptions::Lz4Raw) +} + +#[test] +fn int32_brotli() -> ParquetResult<()> { + test_column( + "id", + CompressionOptions::Brotli(Some(BrotliLevel::default())), + ) +} + +#[test] +#[ignore = "Native boolean writer not yet implemented"] +fn bool() -> ParquetResult<()> { + test_column("bool_col", CompressionOptions::Uncompressed) +} + +#[test] +fn tinyint() -> ParquetResult<()> { + test_column("tinyint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn smallint_col() -> ParquetResult<()> { + test_column("smallint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn int_col() -> ParquetResult<()> { + test_column("int_col", CompressionOptions::Uncompressed) +} + +#[test] +fn bigint_col() -> ParquetResult<()> { + test_column("bigint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn float_col() -> ParquetResult<()> { + test_column("float_col", CompressionOptions::Uncompressed) +} + +#[test] +fn double_col() -> ParquetResult<()> { + test_column("double_col", CompressionOptions::Uncompressed) +} + +#[test] +fn basic() -> ParquetResult<()> { + let array = vec![ + Some(0), + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + ]; + + let options = WriteOptions { + write_statistics: false, + version: Version::V1, + }; + + let schema = SchemaDescriptor::new( + "schema".into(), + vec![ParquetType::from_physical( + "col".into(), + PhysicalType::Int32, + )], + ); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page_v1( + &array, + &options, + &schema.columns()[0].descriptor, + ))), + CompressionOptions::Uncompressed, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + // validated against an equivalent array produced by pyarrow. + let expected = 51; + assert_eq!( + metadata.row_groups[0] + .columns_under_root_iter("col") + .unwrap() + .next() + .unwrap() + .uncompressed_size(), + expected + ); + + Ok(()) +} + +#[test] +fn test_parquet() { + // In CI: This test will be skipped because the file does not exist. + if let Ok(r) = polars_utils::open_file("data/simple.parquet".as_ref()) { + let reader = ParquetReader::new(r); + let df = reader.finish().unwrap(); + assert_eq!(df.get_column_names(), ["a", "b"]); + assert_eq!(df.shape(), (3, 2)); + } +} + +#[test] +#[cfg(feature = "dtype-datetime")] +fn test_parquet_datetime_round_trip() -> PolarsResult<()> { + use std::io::{Cursor, Seek, SeekFrom}; + + let mut f = Cursor::new(vec![]); + + let mut df = df![ + "datetime" => [Some(191845729i64), Some(89107598), None, Some(3158971092)] + ]?; + + df.try_apply("datetime", |s| { + s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None)) + })?; + + ParquetWriter::new(&mut f).finish(&mut df)?; + + f.seek(SeekFrom::Start(0))?; + + let read = ParquetReader::new(f).finish()?; + assert!(read.equals_missing(&df)); + Ok(()) +} + +#[test] +fn test_read_parquet_with_projection() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + + ParquetWriter::new(&mut buf) + .finish(&mut df) + .expect("parquet writer"); + buf.set_position(0); + + let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + let df_read = ParquetReader::new(buf) + .with_projection(Some(vec![1, 2])) + .finish() + .unwrap(); + assert_eq!(df_read.shape(), (3, 2)); + df_read.equals(&expected); +} + +#[test] +fn test_read_parquet_with_columns() { + let mut buf: Cursor> = Cursor::new(Vec::new()); + let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + + ParquetWriter::new(&mut buf) + .finish(&mut df) + .expect("parquet writer"); + buf.set_position(0); + + let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + let df_read = ParquetReader::new(buf) + .with_columns(Some(vec!["c".to_string(), "b".to_string()])) + .finish() + .unwrap(); + assert_eq!(df_read.shape(), (3, 2)); + df_read.equals(&expected); +} diff --git a/crates/polars/tests/it/io/parquet/write/primitive.rs b/crates/polars/tests/it/io/parquet/write/primitive.rs new file mode 100644 index 000000000000..364728579f9d --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/primitive.rs @@ -0,0 +1,79 @@ +use polars_parquet::parquet::CowBuffer; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::encoding::hybrid_rle::encode; +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::metadata::Descriptor; +use polars_parquet::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; +use polars_parquet::parquet::statistics::PrimitiveStatistics; +use polars_parquet::parquet::types::NativeType; +use polars_parquet::parquet::write::WriteOptions; + +fn unzip_option(array: &[Option]) -> ParquetResult<(Vec, Vec)> { + // leave the first 4 bytes announcing the length of the def level + // this will be overwritten at the end, once the length is known. + // This is unknown at this point because of the uleb128 encoding, + // whose length is variable. + let mut validity = std::io::Cursor::new(vec![0; 4]); + validity.set_position(4); + + let mut values = vec![]; + let iter = array.iter().map(|value| { + if let Some(item) = value { + values.extend_from_slice(item.to_le_bytes().as_ref()); + true + } else { + false + } + }); + encode::(&mut validity, iter, 1)?; + + // write the length, now that it is known + let mut validity = validity.into_inner(); + let length = validity.len() - 4; + // todo: pay this small debt (loop?) + let length = length.to_le_bytes(); + validity[0] = length[0]; + validity[1] = length[1]; + validity[2] = length[2]; + validity[3] = length[3]; + + Ok((values, validity)) +} + +pub fn array_to_page_v1( + array: &[Option], + options: &WriteOptions, + descriptor: &Descriptor, +) -> ParquetResult { + let (values, mut buffer) = unzip_option(array)?; + + buffer.extend_from_slice(&values); + + let statistics = if options.write_statistics { + let statistics = &PrimitiveStatistics { + primitive_type: descriptor.primitive_type.clone(), + null_count: Some((array.len() - array.iter().flatten().count()) as i64), + distinct_count: None, + max_value: array.iter().flatten().max_by(|x, y| x.ord(y)).copied(), + min_value: array.iter().flatten().min_by(|x, y| x.ord(y)).copied(), + }; + Some(statistics.serialize()) + } else { + None + }; + + let header = DataPageHeaderV1 { + num_values: array.len() as i32, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }; + + Ok(Page::Data(DataPage::new( + DataPageHeader::V1(header), + CowBuffer::Owned(buffer), + descriptor.clone(), + array.len(), + ))) +} diff --git a/crates/polars/tests/it/io/parquet/write/sidecar.rs b/crates/polars/tests/it/io/parquet/write/sidecar.rs new file mode 100644 index 000000000000..9e0745d5c0e2 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/sidecar.rs @@ -0,0 +1,52 @@ +use polars_parquet::parquet::error::ParquetError; +use polars_parquet::parquet::metadata::SchemaDescriptor; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType}; +use polars_parquet::parquet::write::{FileWriter, Version, WriteOptions, write_metadata_sidecar}; + +#[test] +fn basic() -> Result<(), ParquetError> { + let schema = SchemaDescriptor::new( + "schema".into(), + vec![ParquetType::from_physical("c1".into(), PhysicalType::Int32)], + ); + + let mut metadatas = vec![]; + for i in 0..10 { + // say we will write 10 files + let relative_path = format!("part-{i}.parquet"); + let writer = std::io::Cursor::new(vec![]); + let mut writer = FileWriter::new( + writer, + schema.clone(), + WriteOptions { + write_statistics: true, + version: Version::V2, + }, + None, + ); + writer.end(None)?; + let (_, mut metadata) = writer.into_inner_and_metadata(); + + // once done, we write their relative paths: + metadata.row_groups.iter_mut().for_each(|row_group| { + row_group + .columns + .iter_mut() + .for_each(|column| column.file_path = Some(relative_path.clone())) + }); + metadatas.push(metadata); + } + + // merge their row groups + let first = metadatas.pop().unwrap(); + let sidecar = metadatas.into_iter().fold(first, |mut acc, metadata| { + acc.row_groups.extend(metadata.row_groups); + acc + }); + + // and write the metadata on a separate file + let mut writer = std::io::Cursor::new(vec![]); + write_metadata_sidecar(&mut writer, &sidecar)?; + + Ok(()) +} diff --git a/crates/polars/tests/it/io/partitioned.rs b/crates/polars/tests/it/io/partitioned.rs new file mode 100644 index 000000000000..c9a0b9758497 --- /dev/null +++ b/crates/polars/tests/it/io/partitioned.rs @@ -0,0 +1,51 @@ +use std::io::BufReader; +use std::path::PathBuf; + +use polars::io::ipc::{IpcReader, IpcWriterOption}; +use polars::io::prelude::SerReader; +use polars::io::PartitionedWriter; +use polars_error::PolarsResult; + +#[test] +#[cfg(feature = "ipc")] +fn test_ipc_partition() -> PolarsResult<()> { + let tmp_dir = tempfile::tempdir()?; + + let df = df!("a" => [1, 1, 2, 3], "b" => [2, 2, 3, 4], "c" => [2, 3, 4, 5]).unwrap(); + let by = ["a", "b"]; + let rootdir = tmp_dir.path().join("ipc-partition"); + + let option = IpcWriterOption::new(); + + PartitionedWriter::new(option, rootdir.clone(), by).finish(&df)?; + + let expected_dfs = [ + df!("a" => [1, 1], "b" => [2, 2], "c" => [2, 3])?, + df!("a" => [2], "b" => [3], "c" => [4])?, + df!("a" => [3], "b" => [4], "c" => [5])?, + ]; + + let expected: Vec<(PathBuf, DataFrame)> = ["a=1/b=2", "a=2/b=3", "a=3/b=4"] + .into_iter() + .zip(expected_dfs) + .map(|(p, df)| (rootdir.join(p), df)) + .collect(); + + for (expected_dir, expected_df) in expected.iter() { + assert!(expected_dir.exists()); + + let ipc_paths = std::fs::read_dir(expected_dir)? + .map(|e| { + let entry = e?; + Ok(entry.path()) + }) + .collect::>>()?; + + assert_eq!(ipc_paths.len(), 1); + let reader = BufReader::new(polars_utils::open_file(&ipc_paths[0])?); + let df = IpcReader::new(reader).finish()?; + assert!(expected_df.equals(&df)); + } + + Ok(()) +} diff --git a/crates/polars/tests/it/joins.rs b/crates/polars/tests/it/joins.rs new file mode 100644 index 000000000000..19e4911df3a9 --- /dev/null +++ b/crates/polars/tests/it/joins.rs @@ -0,0 +1,58 @@ +#[cfg(feature = "lazy")] +use polars::prelude::*; + +#[test] +#[cfg(feature = "lazy")] +fn join_nans_outer() -> PolarsResult<()> { + let df1 = df! { + "w" => [Some(2.5), None, Some(f64::NAN), None, Some(2.5), Some(f64::NAN), None, Some(3.0)], + "t" => [Some("xl"), Some("xl"), Some("xl"), Some("xl"), Some("xl"), Some("xl"), Some("xl"), Some("l")], + "c" => [Some(10), Some(5), Some(3), Some(2), Some(9), Some(4), Some(11), Some(3)], + }? + .lazy(); + let a1 = df1 + .clone() + .group_by(vec![col("w").alias("w"), col("t")]) + .agg(vec![col("c").sum().alias("c_sum")]); + let a2 = df1 + .group_by(vec![col("w").alias("w"), col("t")]) + .agg(vec![col("c").max().alias("c_max")]); + + let res = a1 + .join_builder() + .with(a2) + .left_on(vec![col("w"), col("t")]) + .right_on(vec![col("w"), col("t")]) + .how(JoinType::Full) + .coalesce(JoinCoalesce::CoalesceColumns) + .join_nulls(true) + .finish() + .collect()?; + + assert_eq!(res.shape(), (4, 4)); + Ok(()) +} + +#[test] +#[cfg(feature = "lazy")] +fn join_empty_datasets() -> PolarsResult<()> { + let a = DataFrame::new(Vec::from([Column::new_empty( + "foo".into(), + &DataType::Int64, + )])) + .unwrap(); + let b = DataFrame::new(Vec::from([ + Column::new_empty("foo".into(), &DataType::Int64), + Column::new_empty("bar".into(), &DataType::Int64), + ])) + .unwrap(); + + a.lazy() + .group_by([col("foo")]) + .agg([all().last()]) + .inner_join(b.lazy(), "foo", "foo") + .collect() + .unwrap(); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs new file mode 100644 index 000000000000..10c386037d17 --- /dev/null +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -0,0 +1,37 @@ +use super::*; + +#[test] +#[cfg(feature = "temporal")] +fn test_lazy_agg() { + let s0 = DateChunked::parse_from_str_slice( + "date".into(), + &[ + "2020-08-21", + "2020-08-21", + "2020-08-22", + "2020-08-23", + "2020-08-22", + ], + "%Y-%m-%d", + ) + .into_column(); + let s1 = Column::new("temp".into(), [20, 10, 7, 9, 1].as_ref()); + let s2 = Column::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01].as_ref()); + let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); + + let lf = df + .lazy() + .group_by([col("date")]) + .agg([ + col("rain").min().alias("min"), + col("rain").sum().alias("sum"), + col("rain") + .quantile(lit(0.5), QuantileMethod::default()) + .alias("median_rain"), + ]) + .sort(["date"], Default::default()); + + let new = lf.collect().unwrap(); + let min = new.column("min").unwrap(); + assert_eq!(min, &Column::new("min".into(), [0.1f64, 0.01, 0.1])); +} diff --git a/crates/polars/tests/it/lazy/cse.rs b/crates/polars/tests/it/lazy/cse.rs new file mode 100644 index 000000000000..5178d55fbab0 --- /dev/null +++ b/crates/polars/tests/it/lazy/cse.rs @@ -0,0 +1,40 @@ +use super::*; + +#[test] +#[cfg(feature = "semi_anti_join")] +fn test_cse_union_schema_6504() -> PolarsResult<()> { + use polars_core::df; + let q1: LazyFrame = df![ + "a" => [1], + "b" => [2], + ]? + .lazy(); + let q2: LazyFrame = df![ + "b" => [1], + ]? + .lazy(); + + let q3 = q2 + .join(q1.clone(), [col("b")], [col("b")], JoinType::Anti.into()) + .with_column(lit(0).alias("a")) + .select([col("a"), col("b")]); + + let out = concat( + [q1, q3], + UnionArgs { + rechunk: false, + parallel: false, + ..Default::default() + }, + ) + .unwrap() + .with_comm_subplan_elim(true) + .collect()?; + let expected = df![ + "a" => [1, 0], + "b" => [2, 1], + ]?; + assert!(out.equals(&expected)); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/cwc.rs b/crates/polars/tests/it/lazy/cwc.rs new file mode 100644 index 000000000000..df4949d60553 --- /dev/null +++ b/crates/polars/tests/it/lazy/cwc.rs @@ -0,0 +1,123 @@ +use polars::prelude::*; + +#[test] +#[ignore = "fuzz test: Takes to long"] +fn fuzz_cluster_with_columns() { + const PRIMES: &[i32] = &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + use rand::Rng; + + macro_rules! to_str { + ($col:expr) => { + std::str::from_utf8(std::slice::from_ref(&$col)).unwrap() + }; + } + + fn rnd_prime(rng: &'_ mut rand::rngs::ThreadRng) -> i32 { + PRIMES[rng.gen_range(0..PRIMES.len())] + } + + fn sample(rng: &'_ mut rand::rngs::ThreadRng, slice: &[u8]) -> u8 { + assert!(!slice.is_empty()); + slice[rng.gen_range(0..slice.len())] + } + + fn gen_expr(rng: &mut rand::rngs::ThreadRng, used_cols: &[u8]) -> Expr { + let mut depth = 0; + + use rand::Rng; + + fn leaf(rng: &mut rand::rngs::ThreadRng, used_cols: &[u8]) -> Expr { + if rng.r#gen() { + lit(rnd_prime(rng)) + } else { + col(to_str!(sample(rng, used_cols))) + } + } + + let mut e = leaf(rng, used_cols); + + loop { + if depth >= 10 || rng.r#gen() { + return e; + } else { + e = e * col(to_str!(sample(rng, used_cols))); + } + + depth += 1; + } + } + + use std::ops::RangeInclusive; + + const NUM_ORIGINAL_COLS: RangeInclusive = 1..=6; + const NUM_WITH_COLUMNS: RangeInclusive = 1..=64; + const NUM_EXPRS: RangeInclusive = 1..=8; + + let mut rng = rand::thread_rng(); + let rng = &mut rng; + + let mut unused_cols: Vec = Vec::with_capacity(26); + let mut used_cols: Vec = Vec::with_capacity(26); + + let mut columns: Vec = Vec::with_capacity(*NUM_ORIGINAL_COLS.end()); + + let mut used: Vec = Vec::with_capacity(26); + + let num_fuzzes = 100_000; + for _ in 0..num_fuzzes { + unused_cols.clear(); + used_cols.clear(); + unused_cols.extend(b'a'..=b'z'); + + let num_with_columns = rng.gen_range(NUM_WITH_COLUMNS.clone()); + let num_columns = rng.gen_range(NUM_ORIGINAL_COLS.clone()); + + for _ in 0..num_columns { + let column = rng.gen_range(0..unused_cols.len()); + let column = unused_cols.swap_remove(column); + + columns.push(Column::new(to_str!(column).into(), vec![rnd_prime(rng)])); + used_cols.push(column); + } + + let mut lf = DataFrame::new(std::mem::take(&mut columns)).unwrap().lazy(); + + for _ in 0..num_with_columns { + let num_exprs = rng.gen_range(0..8); + let mut exprs = Vec::with_capacity(*NUM_EXPRS.end()); + used.clear(); + + for _ in 0..num_exprs { + let col = loop { + let col = if unused_cols.is_empty() || rng.r#gen() { + sample(rng, &used_cols) + } else { + sample(rng, &unused_cols) + }; + + if !used.contains(&col) { + break col; + } + }; + + used.push(col); + + exprs.push(gen_expr(rng, &used_cols).alias(to_str!(col))); + } + + lf = lf.with_columns(exprs); + + for u in &used { + if let Some(idx) = unused_cols.iter().position(|x| x == u) { + unused_cols.remove(idx); + used_cols.push(*u); + } + } + } + + lf = lf.without_optimizations(); + let cwc = lf.clone().with_cluster_with_columns(true); + + assert_eq!(lf.collect().unwrap(), cwc.collect().unwrap()); + } +} diff --git a/crates/polars/tests/it/lazy/explodes.rs b/crates/polars/tests/it/lazy/explodes.rs new file mode 100644 index 000000000000..b42e23edb3a2 --- /dev/null +++ b/crates/polars/tests/it/lazy/explodes.rs @@ -0,0 +1,20 @@ +// used only if feature="strings" +#[allow(unused_imports)] +use super::*; + +#[cfg(feature = "strings")] +#[test] +fn test_explode_row_numbers() -> PolarsResult<()> { + let df = df![ + "text" => ["one two three four", "uno dos tres cuatro"] + ]? + .lazy() + .select([col("text").str().split(lit(" ")).alias("tokens")]) + .with_row_index("index", None) + .explode([col("tokens")]) + .select([col("index"), col("tokens")]) + .collect()?; + + assert_eq!(df.shape(), (8, 2)); + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/apply.rs b/crates/polars/tests/it/lazy/expressions/apply.rs new file mode 100644 index 000000000000..8006d7da8291 --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/apply.rs @@ -0,0 +1,117 @@ +use super::*; + +#[test] +#[cfg(feature = "range")] +fn test_int_range_agg() -> PolarsResult<()> { + let df = df![ + "x" => [5, 5, 4, 4, 2, 2] + ]?; + + let out = df + .lazy() + .with_columns([int_range(lit(0i32), len(), 1, DataType::Int64).over([col("x")])]) + .collect()?; + assert_eq!( + Vec::from_iter(out.column("literal")?.i64()?.into_no_null_iter()), + &[0, 1, 0, 1, 0, 1] + ); + + Ok(()) +} + +#[test] +#[cfg(all(feature = "unique_counts", feature = "log"))] +fn test_groups_update() -> PolarsResult<()> { + let df = df!["group" => ["A" ,"A", "A", "B", "B", "B", "B"], + "id"=> [1, 1, 2, 3, 4, 3, 5] + ]?; + + let out = df + .lazy() + .group_by_stable([col("group")]) + .agg([col("id").unique_counts().log(2.0)]) + .explode([col("id")]) + .collect()?; + assert_eq!( + out.column("id")? + .f64()? + .into_no_null_iter() + .collect::>(), + &[1.0, 0.0, 1.0, 0.0, 0.0] + ); + Ok(()) +} + +#[test] +#[cfg(feature = "log")] +fn test_groups_update_binary_shift_log() -> PolarsResult<()> { + let out = df![ + "a" => [1, 2, 3, 5], + "b" => [1, 2, 1, 2], + ]? + .lazy() + .group_by([col("b")]) + .agg([col("a") - col("a").shift(lit(1)).log(2.0)]) + .sort(["b"], Default::default()) + .explode([col("a")]) + .collect()?; + assert_eq!( + Vec::from(out.column("a")?.f64()?), + &[None, Some(3.0), None, Some(4.0)] + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "cum_agg")] +fn test_expand_list() -> PolarsResult<()> { + let out = df![ + "a" => [1, 2], + "b" => [2, 3], + ]? + .lazy() + .select([cols(["a", "b"]).cum_sum(false)]) + .collect()?; + + let expected = df![ + "a" => [1, 3], + "b" => [2, 5] + ]?; + + assert!(out.equals(&expected)); + + Ok(()) +} + +#[test] +fn test_apply_groups_empty() -> PolarsResult<()> { + let df = df![ + "id" => [1, 1], + "hi" => ["here", "here"] + ]?; + let out = df + .clone() + .lazy() + .filter(col("id").eq(lit(2))) + .group_by([col("id")]) + .agg([col("hi").drop_nulls().unique()]) + .explain(true) + .unwrap(); + println!("{}", out); + + let out = df + .lazy() + .filter(col("id").eq(lit(2))) + .group_by([col("id")]) + .agg([col("hi").drop_nulls().unique()]) + .collect()?; + + assert_eq!( + out.dtypes(), + &[DataType::Int32, DataType::List(Box::new(DataType::String))] + ); + assert_eq!(out.shape(), (0, 2)); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs new file mode 100644 index 000000000000..cb5f599a602b --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -0,0 +1,369 @@ +use super::*; + +#[test] +#[cfg(feature = "unique_counts")] +fn test_list_broadcast() { + // simply test if this runs + df![ + "g" => [1, 1, 1], + "a" => [1, 2, 3], + ] + .unwrap() + .lazy() + .group_by([col("g")]) + .agg([col("a").unique_counts() * len()]) + .collect() + .unwrap(); +} + +#[test] +fn ternary_expand_sizes() -> PolarsResult<()> { + let df = df! { + "a" => [Some("a1"), None, None], + "b" => [Some("b1"), Some("b2"), None] + }?; + let out = df + .lazy() + .with_column( + when(not(lit(true))) + .then(lit("unexpected")) + .when(not(col("a").is_null())) + .then(col("a")) + .when(not(col("b").is_null())) + .then(col("b")) + .otherwise(lit("otherwise")) + .alias("c"), + ) + .collect()?; + let vals = out + .column("c")? + .str()? + .into_no_null_iter() + .collect::>(); + assert_eq!(vals, &["a1", "b2", "otherwise"]); + Ok(()) +} + +#[test] +#[cfg(feature = "strings")] +fn includes_null_predicate_3038() -> PolarsResult<()> { + let df = df! { + "a" => [Some("a1"), None, None], + }?; + let res = df + .lazy() + .with_column( + when(col("a").map( + move |s| { + s.str()? + .to_lowercase() + .contains("not_exist", true) + .map(|ca| Some(ca.into_column())) + }, + GetOutput::from_type(DataType::Boolean), + )) + .then(lit("unexpected")) + .when(col("a").eq(lit("a1".to_string()))) + .then(lit("good hit")) + .otherwise(Expr::Literal(LiteralValue::untyped_null())) + .alias("b"), + ) + .collect()?; + + let exp_df = df! { + "a" => [Some("a1"), None, None], + "b" => [Some("good hit"), None, None], + }?; + assert!(res.equals_missing(&exp_df)); + + let df = df! { + "a" => ["a1", "a2", "a3", "a4", "a2"], + "b" => [Some("tree"), None, None, None, None], + }?; + let res = df + .lazy() + .with_column( + when(col("b").map( + move |s| { + s.str()? + .to_lowercase() + .contains_literal("non-existent") + .map(|ca| Some(ca.into_column())) + }, + GetOutput::from_type(DataType::Boolean), + )) + .then(lit("weird-1")) + .when(col("a").eq(lit("a1".to_string()))) + .then(lit("ok1")) + .when(col("a").eq(lit("a2".to_string()))) + .then(lit("ok2")) + .when(lit(true)) + .then(lit("ft")) + .otherwise(Expr::Literal(LiteralValue::untyped_null())) + .alias("c"), + ) + .collect()?; + let exp_df = df! { + "a" => ["a1", "a2", "a3", "a4", "a2"], + "b" => [Some("tree"), None, None, None, None], + "c" => ["ok1", "ok2", "ft", "ft", "ok2"] + }?; + assert!(res.equals_missing(&exp_df)); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-categorical")] +fn test_when_then_otherwise_cats() -> PolarsResult<()> { + polars::enable_string_cache(); + + let lf = df!["book" => [Some("bookA"), + None, + Some("bookB"), + None, + Some("bookA"), + Some("bookC"), + Some("bookC"), + Some("bookC")], + "user" => [Some("bob"), Some("bob"), Some("bob"), Some("tim"), Some("lucy"), Some("lucy"), None, None] + ]?.lazy(); + + let out = lf + .with_column(col("book").cast(DataType::Categorical(None, Default::default()))) + .with_column(col("user").cast(DataType::Categorical(None, Default::default()))) + .with_column( + when(col("book").is_null()) + .then(col("user")) + .otherwise(col("book")) + .alias("a"), + ) + .collect()?; + + assert_eq!( + out.column("a")? + .categorical()? + .iter_str() + .flatten() + .collect::>(), + &[ + "bookA", "bob", "bookB", "tim", "bookA", "bookC", "bookC", "bookC" + ] + ); + + Ok(()) +} + +#[test] +fn test_when_then_otherwise_single_bool() -> PolarsResult<()> { + let df = df![ + "key" => ["a", "b", "b"], + "val" => [Some(1), Some(2), None] + ]?; + + let out = df + .lazy() + .group_by_stable([col("key")]) + .agg([when(col("val").null_count().gt(lit(0))) + .then(Null {}.lit()) + .otherwise(col("val").sum()) + .alias("sum_null_prop")]) + .collect()?; + + let expected = df![ + "key" => ["a", "b"], + "sum_null_prop" => [Some(1), None] + ]?; + + assert!(out.equals_missing(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "unique_counts")] +fn test_update_groups_in_cast() -> PolarsResult<()> { + let df = df![ + "group" => ["A" ,"A", "A", "B", "B", "B", "B"], + "id"=> [1, 2, 1, 4, 5, 4, 6], + ]?; + + // optimized to + // col("id").unique_counts().cast(int64) * -1 + // in aggregation that cast coerces a list and the cast may forget to update groups + let out = df + .lazy() + .group_by_stable([col("group")]) + .agg([col("id").unique_counts() * lit(-1)]) + .collect()?; + + let expected = df![ + "group" => ["A" ,"B"], + "id"=> [AnyValue::List(Series::new("".into(), [-2i64, -1])), AnyValue::List(Series::new("".into(), [-2i64, -1, -1]))] + ]?; + + assert!(out.equals(&expected)); + Ok(()) +} + +#[test] +fn test_when_then_otherwise_sum_in_agg() -> PolarsResult<()> { + let df = df![ + "groups" => [1, 1, 2, 2], + "dist_a" => [0.1, 0.2, 0.5, 0.5], + "dist_b" => [0.8, 0.2, 0.5, 0.2], + ]?; + + let q = df + .lazy() + .group_by([col("groups")]) + .agg([when(all().exclude(["groups"]).sum().eq(lit(1))) + .then(all().exclude(["groups"]).sum()) + .otherwise(lit(LiteralValue::untyped_null()))]) + .sort(["groups"], Default::default()); + + let expected = df![ + "groups" => [1, 2], + "dist_a" => [None, Some(1.0f64)], + "dist_b" => [Some(1.0f64), None] + ]?; + assert!(q.collect()?.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_binary_over_3930() -> PolarsResult<()> { + let df = df![ + "class" => ["a", "a", "a", "b", "b", "b"], + "score" => [0.2, 0.5, 0.1, 0.3, 0.4, 0.2] + ]?; + + let ss = col("score").pow(2); + let mdiff = (ss.clone().shift(lit(-1)) - ss.shift(lit(1))) / lit(2); + let out = df.lazy().select([mdiff.over([col("class")])]).collect()?; + + let out = out.column("score")?; + let out = out.f64()?; + + assert_eq!( + Vec::from(out), + &[ + None, + Some(-0.015000000000000003), + None, + None, + Some(-0.024999999999999994), + None + ] + ); + + Ok(()) +} + +#[test] +#[cfg(feature = "rank")] +fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { + let df = df![ + "name" => ["a", "b", "a", "b"], + "value" => [1, 3, 2, 4] + ]?; + + let out = df + .clone() + .lazy() + .group_by([col("name")]) + .agg([when(col("value").sum().eq(lit(3))) + .then(col("value").rank(Default::default(), None)) + .otherwise(lit(Series::new("".into(), &[10 as IdxSize])))]) + .sort(["name"], Default::default()) + .collect()?; + + let out = out.column("value")?; + assert_eq!( + out.get(0)?, + AnyValue::List(Series::new("".into(), &[1 as IdxSize, 2 as IdxSize])) + ); + assert_eq!( + out.get(1)?, + AnyValue::List(Series::new("".into(), &[10 as IdxSize, 10 as IdxSize])) + ); + + let out = df + .clone() + .lazy() + .group_by([col("name")]) + .agg([when(col("value").sum().eq(lit(3))) + .then(lit(Series::new("".into(), &[10 as IdxSize])).alias("value")) + .otherwise(col("value").rank(Default::default(), None))]) + .sort(["name"], Default::default()) + .collect()?; + + let out = out.column("value")?; + assert_eq!( + out.get(1)?, + AnyValue::List(Series::new("".into(), &[1 as IdxSize, 2])) + ); + assert_eq!( + out.get(0)?, + AnyValue::List(Series::new("".into(), &[10 as IdxSize, 10 as IdxSize])) + ); + + let out = df + .clone() + .lazy() + .group_by([col("name")]) + .agg([when(col("value").sum().eq(lit(3))) + .then(col("value").rank(Default::default(), None)) + .otherwise(Null {}.lit())]) + .sort(["name"], Default::default()) + .collect()?; + + let out = out.column("value")?; + assert!(matches!(out.get(0)?, AnyValue::List(_))); + assert!(matches!(out.get(1)?, AnyValue::List(_))); + + // swapped branch + let out = df + .lazy() + .group_by([col("name")]) + .agg([when(col("value").sum().eq(lit(3))) + .then(Null {}.lit().alias("value")) + .otherwise(col("value").rank(Default::default(), None))]) + .sort(["name"], Default::default()) + .collect()?; + + let out = out.column("value")?; + assert!(matches!(out.get(1)?, AnyValue::List(_))); + assert!(matches!(out.get(0)?, AnyValue::List(_))); + + Ok(()) +} + +#[test] +fn test_binary_group_consistency() -> PolarsResult<()> { + let lf = df![ + "name" => ["a", "b", "c", "d"], + "category" => [1, 2, 3, 4], + "score" => [3, 5, 1, 2], + ]? + .lazy(); + + let out = lf + .group_by([col("category")]) + .agg([col("name").filter(col("score").eq(col("score").max()))]) + .sort(["category"], Default::default()) + .collect()?; + let out = out.column("name")?; + + assert_eq!(out.dtype(), &DataType::List(Box::new(DataType::String))); + assert_eq!( + out.explode()? + .str()? + .into_no_null_iter() + .collect::>(), + &["a", "b", "c", "d"] + ); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/expand.rs b/crates/polars/tests/it/lazy/expressions/expand.rs new file mode 100644 index 000000000000..32e7a7980b14 --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/expand.rs @@ -0,0 +1,46 @@ +use chrono::NaiveDate; + +use super::*; + +#[test] +fn test_expand_datetimes_3042() -> PolarsResult<()> { + let low = NaiveDate::from_ymd_opt(2020, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let high = NaiveDate::from_ymd_opt(2020, 2, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let date_range = polars_time::date_range( + "dt1".into(), + low, + high, + Duration::parse("1w"), + ClosedWindow::Left, + TimeUnit::Milliseconds, + None, + )? + .into_series(); + + let out = df![ + "dt1" => date_range.clone(), + "dt2" => date_range, + ]? + .lazy() + .with_column( + dtype_col(&DataType::Datetime(TimeUnit::Milliseconds, None)) + .dt() + .to_string("%m/%d/%Y"), + ) + .limit(3) + .collect()?; + + let expected = df![ + "dt1" => ["01/01/2020", "01/08/2020", "01/15/2020"], + "dt2" => ["01/01/2020", "01/08/2020", "01/15/2020"], + ]?; + assert!(out.equals(&expected)); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/filter.rs b/crates/polars/tests/it/lazy/expressions/filter.rs new file mode 100644 index 000000000000..2d60525c3d1a --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/filter.rs @@ -0,0 +1,36 @@ +use super::*; + +#[test] +fn test_filter_in_group_by_agg() -> PolarsResult<()> { + // This tests if the filter is correctly handled by the binary expression. + // This could lead to UB if it were not the case. The filter creates an empty column. + // but the group tuples could still be untouched leading to out of bounds aggregation. + let df = df![ + "a" => [1, 1, 2], + "b" => [1, 2, 3] + ]?; + + let out = df + .clone() + .lazy() + .group_by([col("a")]) + .agg([(col("b").filter(col("b").eq(lit(100))) * lit(2)) + .mean() + .alias("b_mean")]) + .collect()?; + + assert_eq!(out.column("b_mean")?.null_count(), 2); + + let out = df + .lazy() + .group_by([col("a")]) + .agg([(col("b") + .filter(col("b").eq(lit(100))) + .map(|v| Ok(Some(v)), GetOutput::same_type())) + .mean() + .alias("b_mean")]) + .collect()?; + assert_eq!(out.column("b_mean")?.null_count(), 2); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/is_in.rs b/crates/polars/tests/it/lazy/expressions/is_in.rs new file mode 100644 index 000000000000..46a46e7eee3e --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/is_in.rs @@ -0,0 +1,20 @@ +use super::*; + +#[test] +fn test_is_in() -> PolarsResult<()> { + let df = df![ + "x" => [1, 2, 3], + "y" => ["a", "b", "c"] + ]?; + let s = Series::new("a".into(), ["a", "b"]); + + let out = df + .lazy() + .select([col("y").is_in(lit(s), false).alias("isin")]) + .collect()?; + assert_eq!( + Vec::from(out.column("isin")?.bool()?), + &[Some(true), Some(true), Some(false)] + ); + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/literals.rs b/crates/polars/tests/it/lazy/expressions/literals.rs new file mode 100644 index 000000000000..a2e1cf7822b4 --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/literals.rs @@ -0,0 +1,19 @@ +use super::*; + +#[test] +fn test_datetime_as_lit() { + let Expr::Alias(e, name) = datetime(Default::default()) else { + panic!() + }; + assert_eq!(name, "datetime"); + assert!(matches!(e.as_ref(), Expr::Literal(_))) +} + +#[test] +fn test_duration_as_lit() { + let Expr::Alias(e, name) = duration(Default::default()) else { + panic!() + }; + assert_eq!(name, "duration"); + assert!(matches!(e.as_ref(), Expr::Literal(_))) +} diff --git a/crates/polars/tests/it/lazy/expressions/mod.rs b/crates/polars/tests/it/lazy/expressions/mod.rs new file mode 100644 index 000000000000..e52e550c5090 --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/mod.rs @@ -0,0 +1,11 @@ +mod apply; +mod arity; +mod expand; +mod filter; +#[cfg(feature = "is_in")] +mod is_in; +mod literals; +mod slice; +mod window; + +use super::*; diff --git a/crates/polars/tests/it/lazy/expressions/slice.rs b/crates/polars/tests/it/lazy/expressions/slice.rs new file mode 100644 index 000000000000..0b4cbf6ff93a --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/slice.rs @@ -0,0 +1,28 @@ +use polars_core::prelude::*; + +use super::*; + +#[test] +fn test_slice_args() -> PolarsResult<()> { + let groups: StringChunked = std::iter::repeat_n("a", 10) + .chain(std::iter::repeat_n("b", 20)) + .collect(); + + let df = df![ + "groups" => groups.into_series(), + "vals" => 0i32..30 + ]? + .lazy() + .group_by_stable([col("groups")]) + .agg([col("vals").slice(lit(0i64), (len() * lit(0.2)).cast(DataType::Int32))]) + .collect()?; + + let out = df.column("vals")?.explode()?; + let out = out.i32().unwrap(); + assert_eq!( + out.into_no_null_iter().collect::>(), + &[0, 1, 10, 11, 12, 13] + ); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs new file mode 100644 index 000000000000..fb52ac3810bb --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -0,0 +1,399 @@ +use crate::lazy::*; + +#[test] +fn test_lazy_window_functions() { + let df = df! { + "groups" => &[1, 1, 2, 2, 1, 2, 3, 3, 1], + "values" => &[1, 2, 3, 4, 5, 6, 7, 8, 8] + } + .unwrap(); + + // sums + // 1 => 16 + // 2 => 13 + // 3 => 15 + let correct = [16, 16, 13, 13, 16, 13, 15, 15, 16] + .iter() + .copied() + .map(Some) + .collect::>(); + + // test if groups is available after projection pushdown. + let _ = df + .clone() + .lazy() + .select(&[avg("values").over([col("groups")]).alias("part")]) + .collect() + .unwrap(); + // test if partition aggregation is correct + let out = df + .lazy() + .select([col("groups"), sum("values").over([col("groups")])]) + .collect() + .unwrap(); + assert_eq!( + Vec::from(out.select_at_idx(1).unwrap().i32().unwrap()), + correct + ); +} + +#[test] +fn test_shift_and_fill_window_function() -> PolarsResult<()> { + let df = fruits_cars(); + + // a ternary expression with a final list aggregation + let out1 = df + .clone() + .lazy() + .select([ + col("fruits"), + col("B").shift_and_fill(lit(-1), lit(-1)).over_with_options( + [col("fruits")], + None, + WindowMapping::Join, + ), + ]) + .collect()?; + + // same expression, no final list aggregation + let out2 = df + .lazy() + .select([ + col("fruits"), + col("B").shift_and_fill(lit(-1), lit(-1)).over_with_options( + [col("fruits")], + None, + WindowMapping::Join, + ), + ]) + .collect()?; + + assert!(out1.equals(&out2)); + + Ok(()) +} + +#[test] +fn test_exploded_window_function() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .clone() + .lazy() + .sort(["fruits"], Default::default()) + .select([ + col("fruits"), + col("B") + .shift(lit(1)) + .over_with_options([col("fruits")], None, WindowMapping::Explode) + .alias("shifted"), + ]) + .collect()?; + + assert_eq!( + Vec::from(out.column("shifted")?.i32()?), + &[None, Some(3), None, Some(5), Some(4)] + ); + + // this tests if cast succeeds in aggregation context + // we implicitly also test that a literal does not upcast a column + let out = df + .lazy() + .sort(["fruits"], Default::default()) + .select([ + col("fruits"), + col("B") + .shift_and_fill(lit(1), lit(-1.0f32)) + .over_with_options([col("fruits")], None, WindowMapping::Explode) + .alias("shifted"), + ]) + .collect()?; + + // even though we fill with f32, cast i32 -> f32 can overflow so the result is f64 + assert_eq!( + Vec::from(out.column("shifted")?.f64()?), + &[Some(-1.0), Some(3.0), Some(-1.0), Some(5.0), Some(4.0)] + ); + Ok(()) +} + +#[test] +fn test_reverse_in_groups() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .sort(["fruits"], Default::default()) + .select([ + col("B"), + col("fruits"), + col("B").reverse().over([col("fruits")]).alias("rev"), + ]) + .collect()?; + + assert_eq!( + Vec::from(out.column("rev")?.i32()?), + &[Some(2), Some(3), Some(1), Some(4), Some(5)] + ); + Ok(()) +} + +#[test] +fn test_sort_by_in_groups() -> PolarsResult<()> { + let df = fruits_cars(); + + let out = df + .lazy() + .sort(["cars"], Default::default()) + .select([ + col("fruits"), + col("cars"), + col("A") + .sort_by([col("B")], SortMultipleOptions::default()) + .over([col("cars")]) + .alias("sorted_A_by_B"), + ]) + .collect()?; + + assert_eq!( + Vec::from(out.column("sorted_A_by_B")?.i32()?), + &[Some(2), Some(5), Some(4), Some(3), Some(1)] + ); + Ok(()) +} +#[test] +#[cfg(feature = "cum_agg")] +fn test_literal_window_fn() -> PolarsResult<()> { + let df = df![ + "chars" => ["a", "a", "b"] + ]?; + + let out = df + .lazy() + .select([repeat(1, len()) + .cum_sum(false) + .over_with_options([col("chars")], None, WindowMapping::Join) + .alias("foo")]) + .collect()?; + + let out = out.column("foo")?; + assert!(matches!(out.dtype(), DataType::List(_))); + let flat = out.explode()?; + let flat = flat.i32()?; + assert_eq!( + Vec::from(flat), + &[Some(1), Some(2), Some(1), Some(2), Some(1)] + ); + + Ok(()) +} + +#[test] +fn test_window_mapping() -> PolarsResult<()> { + let df = fruits_cars(); + + // no aggregation + let out = df + .clone() + .lazy() + .select([col("A").over([col("fruits")])]) + .collect()?; + + assert!(out.column("A")?.equals(df.column("A")?)); + + let out = df + .clone() + .lazy() + .select([col("A"), lit(0).over([col("fruits")])]) + .collect()?; + + assert_eq!(out.shape(), (5, 2)); + + let out = df + .clone() + .lazy() + .select([(lit(10) + col("A")).alias("foo").over([col("fruits")])]) + .collect()?; + + let expected = Column::new("foo".into(), [11, 12, 13, 14, 15]); + assert!(out.column("foo")?.equals(&expected)); + + let out = df + .clone() + .lazy() + .select([ + col("fruits"), + col("B"), + col("A"), + (col("B").sum() + col("A")) + .alias("foo") + .over([col("fruits")]), + ]) + .collect()?; + let expected = Column::new("foo".into(), [11, 12, 8, 9, 15]); + assert!(out.column("foo")?.equals(&expected)); + + let out = df + .clone() + .lazy() + .select([ + col("fruits"), + col("A"), + col("B"), + (col("B").shift(lit(1)) - col("A")) + .alias("foo") + .over([col("fruits")]), + ]) + .collect()?; + let expected = Column::new("foo".into(), [None, Some(3), None, Some(-1), Some(-1)]); + assert!(out.column("foo")?.equals_missing(&expected)); + + // now sorted + // this will trigger a fast path + let df = df.sort(["fruits"], Default::default())?; + + let out = df + .clone() + .lazy() + .select([(lit(10) + col("A")).alias("foo").over([col("fruits")])]) + .collect()?; + let expected = Column::new("foo".into(), [13, 14, 11, 12, 15]); + assert!(out.column("foo")?.equals(&expected)); + + let out = df + .clone() + .lazy() + .select([ + col("fruits"), + col("B"), + col("A"), + (col("B").sum() + col("A")) + .alias("foo") + .over([col("fruits")]), + ]) + .collect()?; + + let expected = Column::new("foo".into(), [8, 9, 11, 12, 15]); + assert!(out.column("foo")?.equals(&expected)); + + let out = df + .lazy() + .select([ + col("fruits"), + col("A"), + col("B"), + (col("B").shift(lit(1)) - col("A")) + .alias("foo") + .over([col("fruits")]), + ]) + .collect()?; + + let expected = Column::new("foo".into(), [None, Some(-1), None, Some(3), Some(-1)]); + assert!(out.column("foo")?.equals_missing(&expected)); + + Ok(()) +} + +#[test] +fn test_window_exprs_in_binary_exprs() -> PolarsResult<()> { + let q = df![ + "value" => 0..8, + "cat" => [0, 0, 0, 0, 1, 1, 1, 1] + ]? + .lazy() + .with_columns([ + (col("value") - col("value").mean().over([col("cat")])) + .cast(DataType::Int32) + .alias("centered"), + (col("value") - col("value").std(1).over([col("cat")])) + .cast(DataType::Int32) + .alias("scaled"), + ((col("value") - col("value").mean().over([col("cat")])) + / col("value").std(1).over([col("cat")])) + .cast(DataType::Int32) + .alias("stdized"), + ((col("value") - col("value").mean()).over([col("cat")]) / col("value").std(1)) + .cast(DataType::Int32) + .alias("stdized2"), + ((col("value") - col("value").mean()) / col("value").std(1)) + .over([col("cat")]) + .cast(DataType::Int32) + .alias("stdized3"), + ]) + .sum(); + + let df = q.collect()?; + + let expected = df![ + "value" => [28], + "cat" => [4], + "centered" => [0], + "scaled" => [14], + "stdized" => [0], + "stdized2" => [0], + "stdized3" => [0] + ]?; + + assert!(df.equals(&expected)); + + Ok(()) +} + +#[test] +fn test_window_exprs_any_all() -> PolarsResult<()> { + let df = df![ + "var1"=> ["A", "B", "C", "C", "D", "D", "E", "E"], + "var2"=> [false, true, false, false, false, true, true, true], + ]? + .lazy() + .select([ + col("var2").any(true).over([col("var1")]).alias("any"), + col("var2").all(true).over([col("var1")]).alias("all"), + ]) + .collect()?; + + let expected = df![ + "any" => [false, true, false, false, true, true, true, true], + "all" => [false, true, false, false, false, false, true, true], + ]?; + assert!(df.equals(&expected)); + Ok(()) +} + +#[test] +fn test_window_naive_any() -> PolarsResult<()> { + let df = df![ + "row_id" => [0, 0, 1, 1, 1], + "boolvar" => [true, false, true, false, false] + ]?; + + let df = df + .lazy() + .with_column( + col("boolvar") + .sum() + .gt(lit(0)) + .over([col("row_id")]) + .alias("res"), + ) + .collect()?; + + let res = df.column("res")?; + assert_eq!(res.as_materialized_series().sum::().unwrap(), 5); + Ok(()) +} + +#[test] +fn test_window_map_empty_df_3542() -> PolarsResult<()> { + let df = df![ + "x" => ["a", "b", "c"], + "y" => [Some(1), None, Some(3)] + ]?; + let out = df + .lazy() + .filter(col("y").lt(0)) + .select([col("y").fill_null(0).last().over([col("y")])]) + .collect()?; + assert_eq!(out.height(), 0); + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/exprs.rs b/crates/polars/tests/it/lazy/exprs.rs new file mode 100644 index 000000000000..535ef5867917 --- /dev/null +++ b/crates/polars/tests/it/lazy/exprs.rs @@ -0,0 +1,115 @@ +use polars::prelude::*; + +#[ignore] +#[test] +fn fuzz_exprs() { + const PRIMES: &[i32] = &[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + use rand::Rng; + + let lf = DataFrame::new(vec![ + Column::new("A".into(), vec![1, 2, 3, 4, 5]), + Column::new("B".into(), vec![Some(5), Some(4), None, Some(2), Some(1)]), + Column::new( + "C".into(), + vec!["str", "", "a quite long string", "my", "string"], + ), + ]) + .unwrap() + .lazy(); + let empty = DataFrame::new(vec![ + Column::new("A".into(), Vec::::new()), + Column::new("B".into(), Vec::::new()), + Column::new("C".into(), Vec::<&str>::new()), + ]) + .unwrap() + .lazy(); + + fn rnd_prime(rng: &'_ mut rand::rngs::ThreadRng) -> i32 { + PRIMES[rng.gen_range(0..PRIMES.len())] + } + + fn gen_expr(rng: &mut rand::rngs::ThreadRng) -> Expr { + let mut depth = 0; + + use rand::Rng; + + fn leaf(rng: &mut rand::rngs::ThreadRng) -> Expr { + match rng.r#gen::() % 4 { + 0 => col("A"), + 1 => col("B"), + 2 => col("C"), + _ => lit(rnd_prime(rng)), + } + } + + let mut e = leaf(rng); + + loop { + if depth >= 10 || rng.r#gen::() % 4 == 0 { + return e; + } else { + let rhs = leaf(rng); + + e = match rng.r#gen::() % 19 { + 0 => e.eq(rhs), + 1 => e.eq_missing(rhs), + 2 => e.neq(rhs), + 3 => e.neq_missing(rhs), + 4 => e.lt(rhs), + 5 => e.lt_eq(rhs), + 6 => e.gt(rhs), + 7 => e.gt_eq(rhs), + 8 => e + rhs, + 9 => e - rhs, + 10 => e * rhs, + 11 => e / rhs, + 12 => Expr::BinaryExpr { + left: Arc::new(e), + right: Arc::new(rhs), + op: Operator::TrueDivide, + }, + 13 => e.floor_div(rhs), + 14 => e % rhs, + 15 => e.and(rhs), + 16 => e.or(rhs), + 17 => e.xor(rhs), + 18 => e.logical_and(rhs), + 19 => e.logical_or(rhs), + _ => unreachable!(), + }; + } + + depth += 1; + } + } + + let mut rng = rand::thread_rng(); + let rng = &mut rng; + + let num_fuzzes = 100_000; + for _ in 0..num_fuzzes { + let exprs = vec![ + gen_expr(rng).alias("X"), + gen_expr(rng).alias("Y"), + gen_expr(rng).alias("Z"), + gen_expr(rng).alias("W"), + gen_expr(rng).alias("I"), + gen_expr(rng).alias("J"), + ]; + + let wc = match rng.r#gen::() % 2 { + 0 => lf.clone(), + _ => empty.clone(), + }; + let wc = wc.with_columns(exprs); + + let unoptimized = wc.clone().without_optimizations(); + let optimized = wc; + + match (optimized.collect(), unoptimized.collect()) { + (Ok(o), Ok(u)) => assert_eq!(o, u), + (Err(_), Err(_)) => {}, + (_, _) => panic!("One failed!"), + } + } +} diff --git a/crates/polars/tests/it/lazy/folds.rs b/crates/polars/tests/it/lazy/folds.rs new file mode 100644 index 000000000000..78bd97e5f03c --- /dev/null +++ b/crates/polars/tests/it/lazy/folds.rs @@ -0,0 +1,27 @@ +use super::*; + +#[test] +fn test_fold_wildcard() -> PolarsResult<()> { + let df1 = df![ + "a" => [1, 2, 3], + "b" => [1, 2, 3] + ]?; + + let out = df1 + .clone() + .lazy() + .select([fold_exprs(lit(0), |a, b| (&a + &b).map(Some), [col("*")]).alias("foo")]) + .collect()?; + + assert_eq!( + Vec::from(out.column("foo")?.i32()?), + &[Some(2), Some(4), Some(6)] + ); + + // test if we don't panic due to wildcard + let _out = df1 + .lazy() + .select([polars_lazy::dsl::all_horizontal([col("*").is_not_null()])?]) + .collect()?; + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/functions.rs b/crates/polars/tests/it/lazy/functions.rs new file mode 100644 index 000000000000..ff6dfec5a0b1 --- /dev/null +++ b/crates/polars/tests/it/lazy/functions.rs @@ -0,0 +1,26 @@ +use super::*; + +#[test] +#[cfg(all(feature = "concat_str", feature = "strings"))] +fn test_format_str() { + let a = df![ + "a" => [1, 2], + "b" => ["a", "b"] + ] + .unwrap(); + + let out = a + .lazy() + .select([format_str("({}, {}]", [col("a"), col("b")]) + .unwrap() + .alias("formatted")]) + .collect() + .unwrap(); + + let expected = df![ + "formatted" => ["(1, a]", "(2, b]"] + ] + .unwrap(); + + assert!(out.equals_missing(&expected)); +} diff --git a/crates/polars/tests/it/lazy/group_by.rs b/crates/polars/tests/it/lazy/group_by.rs new file mode 100644 index 000000000000..0de49977d4ed --- /dev/null +++ b/crates/polars/tests/it/lazy/group_by.rs @@ -0,0 +1,185 @@ +// used only if feature="dtype-duration", "dtype-struct" +#[allow(unused_imports)] +use polars_core::SINGLE_LOCK; +#[cfg(feature = "rank")] +use polars_core::series::ops::NullBehavior; + +use super::*; + +#[test] +#[cfg(feature = "rank")] +fn test_filter_sort_diff_2984() -> PolarsResult<()> { + // make sure that sort does not oob if filter returns no values + let df = df![ + "group"=> ["A" ,"A", "A", "B", "B", "B", "B"], + "id"=> [1, 2, 1, 4, 5, 4, 6], + ]?; + + let out = df + .lazy() + // don't use stable in this test, it hides wrong state + .group_by([col("group")]) + .agg([col("id") + .filter(col("id").lt(lit(3))) + .sort(Default::default()) + .diff(lit(1), Default::default()) + .sum()]) + .sort(["group"], Default::default()) + .collect()?; + + assert_eq!(Vec::from(out.column("id")?.i32()?), &[Some(1), Some(0)]); + Ok(()) +} + +#[test] +fn test_filter_after_tail() -> PolarsResult<()> { + let df = df![ + "a" => ["foo", "foo", "bar"], + "b" => [1, 2, 3] + ]?; + + let out = df + .lazy() + .group_by_stable([col("a")]) + .tail(Some(1)) + .filter(col("b").eq(lit(3))) + .with_predicate_pushdown(false) + .collect()?; + + let expected = df![ + "a" => ["bar"], + "b" => [3] + ]?; + assert!(out.equals(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "diff")] +fn test_filter_diff_arithmetic() -> PolarsResult<()> { + let df = df![ + "user" => [1, 1, 1, 1, 2], + "group" => [1, 2, 1, 1, 2], + "value" => [1, 5, 14, 17, 20] + ]?; + + let out = df + .lazy() + .group_by([col("user")]) + .agg([(col("value") + .filter(col("group").eq(lit(1))) + .diff(lit(1), Default::default()) + * lit(2)) + .alias("diff")]) + .sort(["user"], Default::default()) + .explode([col("diff")]) + .collect()?; + + let out = out.column("diff")?; + assert_eq!( + out, + &Column::new("diff".into(), &[None, Some(26), Some(6), None]) + ); + + Ok(()) +} + +#[test] +fn test_group_by_lit_agg() -> PolarsResult<()> { + let df = df![ + "group" => [1, 2, 1, 1, 2], + ]?; + + let out = df + .lazy() + .group_by([col("group")]) + .agg([lit("foo").alias("foo")]) + .collect()?; + + assert_eq!(out.column("foo")?.dtype(), &DataType::String); + + Ok(()) +} + +#[test] +#[cfg(feature = "diff")] +fn test_group_by_agg_list_with_not_aggregated() -> PolarsResult<()> { + let df = df![ + "group" => ["a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b"], + "value" => [0, 2, 3, 6, 2, 4, 7, 9, 3, 4, 6, 7, ], + ]?; + + let out = df + .lazy() + .group_by([col("group")]) + .agg([ + when(col("value").diff(lit(1), NullBehavior::Ignore).gt_eq(0)) + .then(col("value").diff(lit(1), NullBehavior::Ignore)) + .otherwise(col("value")), + ]) + .sort(["group"], Default::default()) + .collect()?; + + let out = out.column("value")?; + let out = out.explode()?; + assert_eq!( + out, + Column::new("value".into(), &[0, 2, 1, 3, 2, 2, 7, 2, 3, 1, 2, 1]) + ); + Ok(()) +} + +#[test] +#[cfg(all(feature = "dtype-duration", feature = "dtype-decimal"))] +fn test_logical_mean_partitioned_group_by_block() -> PolarsResult<()> { + let _guard = SINGLE_LOCK.lock(); + let df = df![ + "decimal" => [1, 1, 2], + "duration" => [1000, 2000, 3000] + ]?; + + let out = df + .lazy() + .with_column(col("decimal").cast(DataType::Decimal(None, Some(2)))) + .with_column(col("duration").cast(DataType::Duration(TimeUnit::Microseconds))) + .group_by([col("decimal")]) + .agg([col("duration").mean()]) + .sort(["duration"], Default::default()) + .collect()?; + + let duration = out.column("duration")?; + + assert_eq!( + duration.get(0)?, + AnyValue::Duration(1500, TimeUnit::Microseconds) + ); + + Ok(()) +} + +#[test] +fn test_filter_aggregated_expression() -> PolarsResult<()> { + let df: DataFrame = df![ + "day" => [2, 2, 2, 2, 2, 2, 1, 1], + "y" => [Some(4), Some(5), Some(8), Some(7), Some(9), None, None, None], + "x" => [1, 2, 3, 4, 5, 6, 1, 2], + ]?; + + let f = col("y").is_not_null().and(col("x").is_not_null()); + + let df = df + .lazy() + .group_by([col("day")]) + .agg([(col("x") - col("x").first()).filter(f)]) + .sort(["day"], Default::default()) + .collect() + .unwrap(); + let x = df.column("x")?; + + assert_eq!( + x.get(1).unwrap(), + AnyValue::List(Series::new("".into(), [0, 1, 2, 3, 4])) + ); + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/group_by_dynamic.rs b/crates/polars/tests/it/lazy/group_by_dynamic.rs new file mode 100644 index 000000000000..e4c57a4a6ea1 --- /dev/null +++ b/crates/polars/tests/it/lazy/group_by_dynamic.rs @@ -0,0 +1,63 @@ +// used only if feature="temporal", "dtype-date", "dynamic_group_by" +#[allow(unused_imports)] +use chrono::prelude::*; + +// used only if feature="temporal", "dtype-date", "dynamic_group_by" +#[allow(unused_imports)] +use super::*; + +#[test] +#[cfg(all( + feature = "temporal", + feature = "dtype-date", + feature = "dynamic_group_by" +))] +fn test_group_by_dynamic_week_bounds() -> PolarsResult<()> { + let start = NaiveDate::from_ymd_opt(2022, 2, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let stop = NaiveDate::from_ymd_opt(2022, 2, 14) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let range = polars_time::date_range( + "dt".into(), + start, + stop, + Duration::parse("1d"), + ClosedWindow::Left, + TimeUnit::Milliseconds, + None, + )? + .into_series(); + + let a = Int32Chunked::full("a".into(), 1, range.len()); + let df = df![ + "dt" => range, + "a" => a + ]?; + + let out = df + .lazy() + .group_by_dynamic( + col("dt"), + [], + DynamicGroupOptions { + every: Duration::parse("1w"), + period: Duration::parse("1w"), + offset: Duration::parse("0w"), + closed_window: ClosedWindow::Left, + label: Label::DataPoint, + include_boundaries: true, + start_by: StartBy::DataPoint, + ..Default::default() + }, + ) + .agg([col("a").sum()]) + .collect()?; + let a = out.column("a")?; + assert_eq!(a.get(0)?, AnyValue::Int32(7)); + assert_eq!(a.get(1)?, AnyValue::Int32(6)); + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/mod.rs b/crates/polars/tests/it/lazy/mod.rs new file mode 100644 index 000000000000..4dc3c9868a26 --- /dev/null +++ b/crates/polars/tests/it/lazy/mod.rs @@ -0,0 +1,27 @@ +mod aggregation; +#[cfg(feature = "cse")] +mod cse; +mod cwc; +mod explodes; +mod expressions; +mod exprs; +mod folds; +mod functions; +mod group_by; +mod group_by_dynamic; +mod predicate_queries; +mod projection_queries; +mod queries; +mod schema; + +use polars::prelude::*; + +pub(crate) fn fruits_cars() -> DataFrame { + df!( + "A"=> [1, 2, 3, 4, 5], + "fruits"=> ["banana", "banana", "apple", "apple", "banana"], + "B"=> [5, 4, 3, 2, 1], + "cars"=> ["beetle", "audi", "beetle", "beetle", "beetle"] + ) + .unwrap() +} diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs new file mode 100644 index 000000000000..de7b1cee5d6d --- /dev/null +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -0,0 +1,273 @@ +// used only if feature="is_in", feature="dtype-categorical" +#[cfg(all(feature = "is_in", feature = "dtype-categorical"))] +use polars_core::{SINGLE_LOCK, StringCacheHolder, disable_string_cache}; + +use super::*; + +#[test] +fn test_predicate_after_renaming() -> PolarsResult<()> { + let df = df![ + "foo" => [1, 2, 3], + "bar" => [3, 2, 1] + ]? + .lazy() + .rename(["foo", "bar"], ["foo2", "bar2"], true) + .filter(col("foo2").eq(col("bar2"))) + .collect()?; + + let expected = df![ + "foo2" => [2], + "bar2" => [2], + ]?; + assert!(df.equals(&expected)); + + Ok(()) +} + +#[test] +fn filter_true_lit() -> PolarsResult<()> { + let df = df! { + "a" => [Some(true), Some(false), None], + "b" => ["1", "2", "3"] + }?; + let filter = col("a").eq(lit(true)); + let with_true = df.clone().lazy().filter(filter.clone()).collect()?; + let with_not_true = df + .clone() + .lazy() + .filter(not(filter)) + .with_predicate_pushdown(false) + .with_projection_pushdown(false) + .collect()?; + let with_null = df + .clone() + .lazy() + .filter(col("a").is_null()) + .with_predicate_pushdown(false) + .with_projection_pushdown(false) + .collect()?; + let res = with_true.vstack(&with_not_true)?; + let res = res.vstack(&with_null)?; + assert!(res.equals_missing(&df)); + Ok(()) +} + +fn create_n_filters(col_name: &str, num_filters: usize) -> Vec { + (0..num_filters) + .map(|i| col(col_name).eq(lit(format!("{}", i)))) + .collect() +} + +fn and_filters(expr: Vec) -> Expr { + expr.into_iter().reduce(polars::prelude::Expr::and).unwrap() +} + +#[test] +fn test_many_filters() -> PolarsResult<()> { + // just check if it runs. in #3210 + // we had terrible tree traversion perf. + let df = df! { + "id" => ["1", "2"] + }?; + let filters = create_n_filters("id", 30); + let _ = df + .lazy() + .filter(and_filters(filters)) + .with_predicate_pushdown(false) + .collect()?; + + Ok(()) +} + +#[test] +fn test_filter_no_combine() -> PolarsResult<()> { + let df = df![ + "vals" => [1, 2, 3, 4, 5] + ]?; + + let out = df + .lazy() + .filter(col("vals").gt(lit(1))) + // should be > 2 + // if optimizer would combine predicates this would be flawed + .filter(col("vals").gt(col("vals").min())) + .collect()?; + + assert_eq!( + Vec::from(out.column("vals")?.i32()?), + &[Some(3), Some(4), Some(5)] + ); + + Ok(()) +} + +#[test] +fn test_filter_block_join() -> PolarsResult<()> { + let df_a = df![ + "a" => ["a", "b", "c"], + "c" => [1, 4, 6] + ]?; + let df_b = df![ + "a" => ["a", "a", "c"], + "d" => [2, 4, 3] + ]?; + + let out = df_a + .lazy() + .left_join(df_b.lazy(), "a", "a") + // mean is influence by join + .filter(col("c").mean().eq(col("d"))) + .collect()?; + assert_eq!(out.shape(), (1, 3)); + + Ok(()) +} + +#[test] +#[cfg(all(feature = "is_in", feature = "dtype-categorical"))] +fn test_is_in_categorical_3420() -> PolarsResult<()> { + let df = df![ + "a" => ["a", "b", "c", "d", "e"], + "b" => [1, 2, 3, 4, 5] + ]?; + + let _guard = SINGLE_LOCK.lock(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); + + let s = Series::new("x".into(), ["a", "b", "c"]) + .strict_cast(&DataType::Categorical(None, Default::default()))?; + let out = df + .lazy() + .with_column(col("a").strict_cast(DataType::Categorical(None, Default::default()))) + .filter(col("a").is_in(lit(s).alias("x"), false)) + .collect()?; + + let mut expected = df![ + "a" => ["a", "b", "c"], + "b" => [1, 2, 3] + ]?; + expected.try_apply("a", |s| { + s.cast(&DataType::Categorical(None, Default::default())) + })?; + assert!(out.equals(&expected)); + Ok(()) +} + +#[test] +fn test_predicate_pushdown_blocked_by_outer_join() -> PolarsResult<()> { + let df1 = df! { + "a" => ["a1", "a2"], + "b" => ["b1", "b2"] + }?; + let df2 = df! { + "b" => ["b2", "b3"], + "c" => ["c2", "c3"] + }?; + let df = df1.lazy().full_join(df2.lazy(), col("b"), col("b")); + let out = df.filter(col("a").eq(lit("a1"))).collect()?; + let null: Option<&str> = None; + let expected = df![ + "a" => ["a1"], + "b" => ["b1"], + "b_right" => [null], + "c" => [null], + ]?; + assert!(out.equals_missing(&expected)); + Ok(()) +} + +#[test] +fn test_binaryexpr_pushdown_left_join_9506() -> PolarsResult<()> { + let df1 = df! { + "a" => ["a1"], + "b" => ["b1"] + }?; + let df2 = df! { + "b" => ["b1"], + "c" => ["c1"] + }?; + let df = df1.lazy().left_join(df2.lazy(), col("b"), col("b")); + let out = df.filter(col("c").eq(lit("c2"))).collect()?; + assert!(out.is_empty()); + Ok(()) +} + +#[test] +fn test_count_blocked_at_union_3963() -> PolarsResult<()> { + let lf1 = df![ + "k" => ["x", "x", "y"], + "v" => [3, 2, 6] + ]? + .lazy(); + + let lf2 = df![ + "k" => ["a", "a", "b"], + "v" => [1, 8, 5] + ]? + .lazy(); + + let expected = df![ + "k" => ["x", "x", "a", "a"], + "v" => [3, 2, 1, 8] + ]?; + + for rechunk in [true, false] { + let out = concat( + [lf1.clone(), lf2.clone()], + UnionArgs { + rechunk, + parallel: true, + ..Default::default() + }, + )? + .filter(len().over([col("k")]).gt(lit(1))) + .collect()?; + + assert!(out.equals(&expected)); + } + + Ok(()) +} + +#[test] +fn test_predicate_on_join_select_4884() -> PolarsResult<()> { + let lf = df![ + "x" => [0, 1], + "y" => [1, 2], + ]? + .lazy(); + let out = (lf.clone().join_builder().with(lf)) + .left_on([col("y")]) + .right_on([col("x")]) + .suffix("_right") + .finish() + .select([col("x"), col("y_right").alias("y")]) + .filter(col("x").neq(col("y")).and(col("y").eq(2))) + .collect()?; + + let expected = df![ + "x" => [0], + "y" => [2], + ]?; + assert_eq!(out, expected); + Ok(()) +} + +#[test] +fn test_predicate_pushdown_block_8847() -> PolarsResult<()> { + let ldf = df![ + "A" => [1, 2, 3] + ]? + .lazy(); + + let q = ldf + .with_column(lit(1).strict_cast(DataType::Int32).alias("B")) + .drop_nulls(None) + .filter(col("B").eq(lit(1))); + + let out = q.collect()?; + assert_eq!(out.get_column_names(), &["A", "B"]); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/projection_queries.rs b/crates/polars/tests/it/lazy/projection_queries.rs new file mode 100644 index 000000000000..e5870e81ce4e --- /dev/null +++ b/crates/polars/tests/it/lazy/projection_queries.rs @@ -0,0 +1,188 @@ +use polars::prelude::*; + +#[test] +fn test_sum_after_filter() -> PolarsResult<()> { + let df = df![ + "ids" => 0..10, + "values" => 10..20, + ]? + .lazy() + .filter(not(col("ids").eq(lit(5)))) + .select([col("values").sum()]) + .collect()?; + + assert_eq!(df.column("values")?.get(0)?, AnyValue::Int32(130)); + Ok(()) +} + +#[test] +fn test_swap_rename() -> PolarsResult<()> { + let df = df![ + "a" => [1], + "b" => [2], + ]? + .lazy() + .rename(["a", "b"], ["b", "a"], true) + .collect()?; + + let expected = df![ + "b" => [1], + "a" => [2], + ]?; + assert!(df.equals(&expected)); + Ok(()) +} + +#[test] +fn test_full_outer_join_with_column_2988() -> PolarsResult<()> { + let ldf1 = df![ + "key1" => ["foo", "bar"], + "key2" => ["foo", "bar"], + "val1" => [3, 1] + ]? + .lazy(); + + let ldf2 = df![ + "key1" => ["bar", "baz"], + "key2" => ["bar", "baz"], + "val2" => [6, 8] + ]? + .lazy(); + + let out = ldf1 + .join( + ldf2, + [col("key1"), col("key2")], + [col("key1"), col("key2")], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + ) + .with_columns([col("key1")]) + .collect()?; + assert_eq!(out.get_column_names(), &["key1", "key2", "val1", "val2"]); + assert_eq!( + Vec::from(out.column("key1")?.str()?), + &[Some("bar"), Some("baz"), Some("foo")] + ); + assert_eq!( + Vec::from(out.column("key2")?.str()?), + &[Some("bar"), Some("baz"), Some("foo")] + ); + assert_eq!( + Vec::from(out.column("val1")?.i32()?), + &[Some(1), None, Some(3)] + ); + assert_eq!( + Vec::from(out.column("val2")?.i32()?), + &[Some(6), Some(8), None] + ); + + Ok(()) +} + +#[test] +fn test_many_aliasing_projections_5070() -> PolarsResult<()> { + let df = df! { + "date" => [1, 2, 3], + "val" => [1, 2, 3], + }?; + + let out = df + .lazy() + .filter(col("date").gt(lit(1))) + .select([col("*")]) + .with_columns([col("val").max().alias("max")]) + .with_column(col("max").alias("diff")) + .with_column((col("val") / col("diff")).alias("output")) + .select([all().exclude(["max", "diff"])]) + .collect()?; + let expected = df![ + "date" => [2, 3], + "val" => [2, 3], + "output" => [0, 1], + ]?; + assert!(out.equals(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "cum_agg")] +fn test_projection_5086() -> PolarsResult<()> { + let df = df![ + "a" => ["a", "a", "a", "b"], + "b" => [1, 0, 1, 0], + "c" => [0, 1, 2, 0], + ]?; + + let out = df + .lazy() + .select([ + col("a"), + col("b") + .gather("c") + .cum_sum(false) + .over([col("a")]) + .gt(lit(0)), + ]) + .select([ + col("a"), + col("b") + .xor(col("b").shift(lit(1)).over([col("a")])) + .fill_null(lit(true)) + .alias("keep"), + ]) + .collect()?; + + let expected = df![ + "a" => ["a", "a", "a", "b"], + "keep" => [true, false, false, true] + ]?; + + assert!(out.equals(&expected)); + + Ok(()) +} + +#[test] +#[cfg(feature = "dtype-struct")] +fn test_unnest_pushdown() -> PolarsResult<()> { + let df = df![ + "collection" => Series::full_null("".into(), 1, &DataType::Int32), + "users" => Series::full_null("".into(), 1, &DataType::List(Box::new(DataType::Struct(vec![Field::new("email".into(), DataType::String)])))), + ]?; + + let out = df + .lazy() + .explode(["users"]) + .unnest(["users"]) + .select([col("email")]) + .collect()?; + + assert_eq!(out.get_column_names(), &["email"]); + + Ok(()) +} + +#[test] +fn test_join_duplicate_7314() -> PolarsResult<()> { + let df_a: DataFrame = df![ + "a" => [1, 2, 2], + "b" => [4, 5, 6], + "c" => [1, 1, 1], + ]?; + + let df_b: DataFrame = df![ + "a" => [1, 2, 2], + "b" => [4, 5, 6], + "d" => [1, 1, 1], + ]?; + + let out = df_a + .lazy() + .inner_join(df_b.lazy(), col("a"), col("b")) + .select([col("a"), col("c") * col("d")]) + .collect()?; + + assert_eq!(out.get_column_names(), &["a", "c"]); + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/queries.rs b/crates/polars/tests/it/lazy/queries.rs new file mode 100644 index 000000000000..2103892ee12a --- /dev/null +++ b/crates/polars/tests/it/lazy/queries.rs @@ -0,0 +1,272 @@ +use polars_core::series::IsSorted; + +use super::*; + +#[test] +fn test_with_duplicate_column_empty_df() { + let a = Int32Chunked::from_slice("a".into(), &[]); + + assert_eq!( + DataFrame::new(vec![a.into_column()]) + .unwrap() + .lazy() + .with_columns([lit(true).alias("a")]) + .collect() + .unwrap() + .get_column_names(), + &["a"] + ); +} + +#[test] +fn test_drop() -> PolarsResult<()> { + // dropping all columns is a special case. It may fail because a projection + // that projects nothing could be misinterpreted as select all. + let out = df![ + "a" => [1], + ]? + .lazy() + .drop(["a"]) + .collect()?; + assert_eq!(out.width(), 0); + Ok(()) +} + +#[test] +#[cfg(feature = "dynamic_group_by")] +fn test_special_group_by_schemas() -> PolarsResult<()> { + let df = df![ + "a" => [1, 2, 3, 4, 5], + "b" => [1, 2, 3, 4, 5], + ]?; + + let out = df + .clone() + .lazy() + .with_column(col("a").set_sorted_flag(IsSorted::Ascending)) + .rolling( + col("a"), + [], + RollingGroupOptions { + period: Duration::parse("2i"), + offset: Duration::parse("0i"), + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([col("b").sum().alias("sum")]) + .select([col("a"), col("sum")]) + .collect()?; + + assert_eq!( + out.column("sum")? + .i32()? + .into_no_null_iter() + .collect::>(), + &[3, 5, 7, 9, 5] + ); + + let out = df + .lazy() + .with_column(col("a").set_sorted_flag(IsSorted::Ascending)) + .group_by_dynamic( + col("a"), + [], + DynamicGroupOptions { + every: Duration::parse("2i"), + period: Duration::parse("2i"), + offset: Duration::parse("0i"), + label: Label::DataPoint, + include_boundaries: false, + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([col("b").sum().alias("sum")]) + .select([col("a"), col("sum")]) + .collect()?; + + assert_eq!( + out.column("sum")? + .i32()? + .into_no_null_iter() + .collect::>(), + &[1, 5, 9] + ); + + Ok(()) +} + +#[test] +fn max_on_empty_df_3027() -> PolarsResult<()> { + let df = df! { + "id" => ["1"], + "name" => ["one"], + "numb" => [1] + }? + .head(Some(0)); + + let out = df + .lazy() + .group_by(&[col("id"), col("name")]) + .agg(&[col("numb").max()]) + .collect()?; + assert_eq!(out.shape(), (0, 3)); + Ok(()) +} + +#[test] +fn test_alias_before_cast() -> PolarsResult<()> { + let out = df![ + "a" => [1, 2, 3], + ]? + .lazy() + .select([col("a").alias("d").cast(DataType::Int32)]) + .select([all()]) + .collect()?; + assert_eq!( + Vec::from(out.column("d")?.i32()?), + &[Some(1), Some(2), Some(3)] + ); + Ok(()) +} + +#[test] +fn test_sorted_path() -> PolarsResult<()> { + // start with a sorted column and see if the metadata remains preserved + + let payloads = &[1, 2, 3]; + let df = df![ + "a"=> [AnyValue::List(Series::new("".into(), payloads)), AnyValue::List(Series::new("".into(), payloads)), AnyValue::List(Series::new("".into(), payloads))] + ]?; + + let out = df + .lazy() + .with_row_index("index", None) + .explode(["a"]) + .group_by(["index"]) + .agg([col("a").count().alias("count")]) + .collect()?; + + let s = out.column("index")?; + assert_eq!(s.is_sorted_flag(), IsSorted::Ascending); + + Ok(()) +} + +#[test] +fn test_sorted_path_joins() -> PolarsResult<()> { + let dfa = df![ + "a"=> [1, 2, 3] + ]?; + + let dfb = df![ + "a"=> [1, 2, 3] + ]?; + + let out = dfa + .lazy() + .with_column(col("a").set_sorted_flag(IsSorted::Ascending)) + .join(dfb.lazy(), [col("a")], [col("a")], JoinType::Left.into()) + .collect()?; + + let s = out.column("a")?; + assert_eq!(s.is_sorted_flag(), IsSorted::Ascending); + + Ok(()) +} + +#[test] +fn test_unknown_supertype_ignore() -> PolarsResult<()> { + let df = df![ + "col1" => [0., 3., 2., 1.], + "col2" => [0., 0., 1., 1.], + ]?; + + let out = df + .lazy() + .with_columns([(col("col1").fill_null(0f64) + col("col2"))]) + .collect()?; + assert_eq!(out.shape(), (4, 2)); + Ok(()) +} + +#[test] +fn test_apply_multiple_columns() -> PolarsResult<()> { + let df = fruits_cars(); + + let multiply = |s: &mut [Column]| (&(&s[0] * &s[0])? * &s[1]).map(Some); + + let out = df + .clone() + .lazy() + .select([map_multiple( + multiply, + [col("A"), col("B")], + GetOutput::from_type(DataType::Int32), + )]) + .collect()?; + let out = out.column("A")?; + let out = out.i32()?; + assert_eq!( + Vec::from(out), + &[Some(5), Some(16), Some(27), Some(32), Some(25)] + ); + + let out = df + .lazy() + .group_by_stable([col("cars")]) + .agg([apply_multiple( + multiply, + [col("A"), col("B")], + GetOutput::from_type(DataType::Int32), + false, + )]) + .collect()?; + + let out = out.column("A")?; + let out = out.list()?.get_as_series(1).unwrap(); + let out = out.i32()?; + + assert_eq!(Vec::from(out), &[Some(16)]); + Ok(()) +} + +#[test] +fn test_group_by_on_lists() -> PolarsResult<()> { + let s0 = Column::new("".into(), [1i32, 2, 3]); + let s1 = Column::new("groups".into(), [4i32, 5]); + + let mut builder = + ListPrimitiveChunkedBuilder::::new("arrays".into(), 10, 10, DataType::Int32); + builder.append_series(s0.as_materialized_series()).unwrap(); + builder.append_series(s1.as_materialized_series()).unwrap(); + let s2 = builder.finish().into_column(); + + let df = DataFrame::new(vec![s1, s2])?; + let out = df + .clone() + .lazy() + .group_by([col("groups")]) + .agg([col("arrays").first()]) + .collect()?; + + assert_eq!( + out.column("arrays")?.dtype(), + &DataType::List(Box::new(DataType::Int32)) + ); + + let out = df + .lazy() + .group_by([col("groups")]) + .agg([col("arrays").implode()]) + .collect()?; + + // a list of lists + assert_eq!( + out.column("arrays")?.dtype(), + &DataType::List(Box::new(DataType::List(Box::new(DataType::Int32)))) + ); + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/schema.rs b/crates/polars/tests/it/lazy/schema.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/crates/polars/tests/it/lazy/schema.rs @@ -0,0 +1 @@ + diff --git a/crates/polars/tests/it/main.rs b/crates/polars/tests/it/main.rs new file mode 100644 index 000000000000..5884df346bb0 --- /dev/null +++ b/crates/polars/tests/it/main.rs @@ -0,0 +1,15 @@ +#![cfg_attr(feature = "nightly", allow(clippy::result_large_err))] // remove once stable +#![cfg_attr(feature = "nightly", allow(clippy::manual_repeat_n))] // remove once stable +#![cfg_attr(feature = "nightly", allow(clippy::len_zero))] // remove once stable +mod core; +mod io; +mod joins; +#[cfg(feature = "lazy")] +mod lazy; +mod schema; +mod time; + +mod arrow; +mod chunks; + +pub static FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; diff --git a/crates/polars/tests/it/schema.rs b/crates/polars/tests/it/schema.rs new file mode 100644 index 000000000000..855fb764281b --- /dev/null +++ b/crates/polars/tests/it/schema.rs @@ -0,0 +1,576 @@ +use polars::prelude::*; + +#[test] +fn test_schema_rename() { + use DataType::*; + + #[track_caller] + fn test_case(old: &str, new: &str, expected: Option<(&str, Vec)>) { + fn make_schema() -> Schema { + Schema::from_iter([ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ]) + } + let mut schema = make_schema(); + let res = schema.rename(old, new.into()); + if let Some((old_name, expected_fields)) = expected { + assert_eq!(res.unwrap(), old_name); + assert_eq!(schema, Schema::from_iter(expected_fields)); + } else { + assert!(res.is_none()); + assert_eq!(schema, make_schema()); + } + } + + test_case( + "a", + "anton", + Some(( + "a", + vec![ + Field::new("anton".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + )), + ); + + test_case( + "b", + "bantam", + Some(( + "b", + vec![ + Field::new("a".into(), UInt64), + Field::new("bantam".into(), Int32), + Field::new("c".into(), Int8), + ], + )), + ); + + test_case("d", "dan", None); +} + +#[test] +fn test_schema_insert_at_index() { + use DataType::*; + + #[track_caller] + fn test_case( + schema: &Schema, + index: usize, + name: &str, + expected: (Option, Vec), + ) { + println!("{index:?} -- {name:?} -- {expected:?}"); + let new = schema + .new_inserting_at_index(index, name.into(), String) + .unwrap(); + + let mut new_mut = schema.clone(); + let old_dtype = new_mut.insert_at_index(index, name.into(), String).unwrap(); + + let (expected_dtype, expected_fields) = expected; + let expected = Schema::from_iter(expected_fields); + + assert_eq!(expected, new); + + assert_eq!(expected, new_mut); + assert_eq!(expected_dtype, old_dtype); + } + + let schema = Schema::from_iter([ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ]); + + test_case( + &schema, + 0, + "new", + ( + None, + vec![ + Field::new("new".into(), String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + ), + ); + + test_case( + &schema, + 0, + "a", + ( + Some(UInt64), + vec![ + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + ), + ); + + test_case( + &schema, + 0, + "b", + ( + Some(Int32), + vec![ + Field::new("b".into(), String), + Field::new("a".into(), UInt64), + Field::new("c".into(), Int8), + ], + ), + ); + + test_case( + &schema, + 1, + "a", + ( + Some(UInt64), + vec![ + Field::new("b".into(), Int32), + Field::new("a".into(), String), + Field::new("c".into(), Int8), + ], + ), + ); + + test_case( + &schema, + 2, + "a", + ( + Some(UInt64), + vec![ + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("a".into(), String), + ], + ), + ); + + test_case( + &schema, + 3, + "a", + ( + Some(UInt64), + vec![ + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("a".into(), String), + ], + ), + ); + + test_case( + &schema, + 3, + "new", + ( + None, + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("new".into(), String), + ], + ), + ); + + test_case( + &schema, + 2, + "c", + ( + Some(Int8), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), + ], + ), + ); + + test_case( + &schema, + 3, + "c", + ( + Some(Int8), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), + ], + ), + ); + + assert!( + schema + .new_inserting_at_index(4, "oob".into(), String) + .is_err() + ); +} + +#[test] +fn test_with_column() { + use DataType::*; + + #[track_caller] + fn test_case( + schema: &Schema, + new_name: &str, + new_dtype: DataType, + expected: (Option, Vec), + ) { + let mut schema = schema.clone(); + let old_dtype = schema.with_column(new_name.into(), new_dtype); + let (exp_dtype, exp_fields) = expected; + assert_eq!(exp_dtype, old_dtype); + assert_eq!(Schema::from_iter(exp_fields), schema); + } + + let schema = Schema::from_iter([ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ]); + + test_case( + &schema, + "a", + String, + ( + Some(UInt64), + vec![ + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + ), + ); + + test_case( + &schema, + "b", + String, + ( + Some(Int32), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), String), + Field::new("c".into(), Int8), + ], + ), + ); + + test_case( + &schema, + "c", + String, + ( + Some(Int8), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), + ], + ), + ); + + test_case( + &schema, + "d", + String, + ( + None, + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), String), + ], + ), + ); +} + +#[test] +fn test_getters() { + use DataType::*; + + macro_rules! test_case { + ($schema:expr, $method:ident, name: $name:expr, $expected:expr) => {{ + assert_eq!($expected, $schema.$method($name).unwrap()); + assert!($schema.$method("NOT_FOUND").is_none()); + }}; + ($schema:expr, $method:ident, index: $index:expr, $expected:expr) => {{ + assert_eq!($expected, $schema.$method($index).unwrap()); + assert!($schema.$method(usize::MAX).is_none()); + }}; + } + + let mut schema = Schema::from_iter([ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ]); + + test_case!(schema, get, name: "a", &UInt64); + test_case!(schema, get_full, name: "a", (0, &"a".into(), &UInt64)); + test_case!(schema, get_field, name: "a", Field::new("a".into(), UInt64)); + test_case!(schema, get_at_index, index: 1, (&"b".into(), &Int32)); + test_case!(schema, get_at_index_mut, index: 1, (&mut "b".into(), &mut Int32)); + + assert!(schema.contains("a")); + assert!(!schema.contains("NOT_FOUND")); +} + +#[test] +fn test_removal() { + use DataType::*; + + #[track_caller] + fn test_case( + schema: &Schema, + to_remove: &str, + dtype: Option, + swapped_expected: Vec, + shifted_expected: Vec, + ) { + #[track_caller] + fn test_it(expected: (Option, Vec), actual: (Option, Schema)) { + let (exp_dtype, exp_fields) = expected; + let (act_dtype, act_schema) = actual; + + assert_eq!(Schema::from_iter(exp_fields), act_schema); + assert_eq!(exp_dtype, act_dtype); + } + + let mut swapped = schema.clone(); + let swapped_res = swapped.remove(to_remove); + + test_it((dtype.clone(), swapped_expected), (swapped_res, swapped)); + + let mut shifted = schema.clone(); + let shifted_res = shifted.shift_remove(to_remove); + + test_it((dtype, shifted_expected), (shifted_res, shifted)); + } + + let schema = Schema::from_iter([ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), + ]); + + test_case( + &schema, + "a", + Some(UInt64), + vec![ + Field::new("d".into(), Float64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + vec![ + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), + ], + ); + + test_case( + &schema, + "b", + Some(Int32), + vec![ + Field::new("a".into(), UInt64), + Field::new("d".into(), Float64), + Field::new("c".into(), Int8), + ], + vec![ + Field::new("a".into(), UInt64), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), + ], + ); + + test_case( + &schema, + "c", + Some(Int8), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("d".into(), Float64), + ], + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("d".into(), Float64), + ], + ); + + test_case( + &schema, + "d", + Some(Float64), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + ); + + test_case( + &schema, + "NOT_FOUND", + None, + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), + ], + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), + ], + ); +} + +#[test] +fn test_set_dtype() { + use DataType::*; + + #[track_caller] + fn test_case( + schema: &Schema, + name: &str, + index: usize, + expected: (Option, Vec), + ) { + // test set_dtype + { + let mut schema = schema.clone(); + let old_dtype = schema.set_dtype(name, String); + let (exp_dtype, exp_fields) = &expected; + assert_eq!(&old_dtype, exp_dtype); + assert_eq!(Schema::from_iter(exp_fields.clone()), schema); + } + + // test set_dtype_at_index + { + let mut schema = schema.clone(); + let old_dtype = schema.set_dtype_at_index(index, String); + let (exp_dtype, exp_fields) = &expected; + assert_eq!(&old_dtype, exp_dtype); + assert_eq!(Schema::from_iter(exp_fields.clone()), schema); + } + } + + let schema = Schema::from_iter([ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ]); + + test_case( + &schema, + "a", + 0, + ( + Some(UInt64), + vec![ + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + ), + ); + test_case( + &schema, + "b", + 1, + ( + Some(Int32), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), String), + Field::new("c".into(), Int8), + ], + ), + ); + test_case( + &schema, + "c", + 2, + ( + Some(Int8), + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), + ], + ), + ); + test_case( + &schema, + "d", + 3, + ( + None, + vec![ + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + ], + ), + ); +} + +#[test] +fn test_infer_schema() { + use DataType::{Int32, Null, String}; + use polars_core::frame::row::infer_schema; + + // Sample data as a vector of tuples (column name, value) + let data: Vec> = vec![ + vec![(PlSmallStr::from("a"), DataType::String)], + vec![(PlSmallStr::from("b"), DataType::Int32)], + vec![(PlSmallStr::from("c"), DataType::Null)], + ]; + + // Create an iterator over the sample data + let iter = data.into_iter(); + + // Infer the schema + let schema = infer_schema(iter, 3); + + let exp_fields = vec![ + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Null), + ]; + + // Check the inferred schema + assert_eq!(Schema::from_iter(exp_fields.clone()), schema); +} diff --git a/crates/polars/tests/it/time/date_range.rs b/crates/polars/tests/it/time/date_range.rs new file mode 100644 index 000000000000..f733c05e2acb --- /dev/null +++ b/crates/polars/tests/it/time/date_range.rs @@ -0,0 +1,80 @@ +use chrono::NaiveDate; +use polars::prelude::*; +#[allow(unused_imports)] +use polars::time::date_range; + +#[test] +fn test_time_units_9413() { + let start = NaiveDate::from_ymd_opt(2022, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let stop = NaiveDate::from_ymd_opt(2022, 1, 5) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let actual = date_range( + "date".into(), + start, + stop, + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None, + ) + .map(|date_range| date_range.into_series()); + let result = format!("{:?}", actual); + let expected = r#"Ok(shape: (5,) +Series: 'date' [datetime[ms]] +[ + 2022-01-01 00:00:00 + 2022-01-02 00:00:00 + 2022-01-03 00:00:00 + 2022-01-04 00:00:00 + 2022-01-05 00:00:00 +])"#; + assert_eq!(result, expected); + let actual = date_range( + "date".into(), + start, + stop, + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Microseconds, + None, + ) + .map(|date_range| date_range.into_series()); + let result = format!("{:?}", actual); + let expected = r#"Ok(shape: (5,) +Series: 'date' [datetime[μs]] +[ + 2022-01-01 00:00:00 + 2022-01-02 00:00:00 + 2022-01-03 00:00:00 + 2022-01-04 00:00:00 + 2022-01-05 00:00:00 +])"#; + assert_eq!(result, expected); + let actual = date_range( + "date".into(), + start, + stop, + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Nanoseconds, + None, + ) + .map(|date_range| date_range.into_series()); + let result = format!("{:?}", actual); + let expected = r#"Ok(shape: (5,) +Series: 'date' [datetime[ns]] +[ + 2022-01-01 00:00:00 + 2022-01-02 00:00:00 + 2022-01-03 00:00:00 + 2022-01-04 00:00:00 + 2022-01-05 00:00:00 +])"#; + assert_eq!(result, expected); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/time/mod.rs b/crates/polars/tests/it/time/mod.rs new file mode 100644 index 000000000000..4473be6627a9 --- /dev/null +++ b/crates/polars/tests/it/time/mod.rs @@ -0,0 +1 @@ +mod date_range; diff --git a/docker-compose.yaml b/docker-compose.yaml deleted file mode 100644 index b4fc8a36df75..000000000000 --- a/docker-compose.yaml +++ /dev/null @@ -1,13 +0,0 @@ -version: '3' -services: - notebook: - build: . - ports: - - "8891:8891" - command: ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8891", "--NotebookApp.token=''"] - environment: - - EVCXR_TMPDIR=/target - volumes: - - ./polars:/polars - - ./target:/target - - ./examples:/examples diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000000..f0a24a43924d --- /dev/null +++ b/docs/README.md @@ -0,0 +1,11 @@ +The documentation is split across two subfolders, `source` and `assets`. The folder `source` +contains the static source files that make up the user guide, which are mostly markdown files and +the snippets of code. The folder `assets` contains (dynamically generated) assets used by those +files, including data files for the snippets and images with plots or diagrams. + +Do _not_ merge the two folders together. In +[PR #18773](https://github.com/pola-rs/polars/pull/18773) we introduced this split to fix the MkDocs +server live reloading. If everything is in one folder `docs`, the MkDocs server will watch the +folder `docs`. When you make one change the MkDocs server live reloads and rebuilds the docs. This +triggers scripts that build asset files, which change the folder `docs`, leading to an infinite +reloading loop. diff --git a/docs/assets/data/alltypes_plain.parquet b/docs/assets/data/alltypes_plain.parquet new file mode 100644 index 000000000000..a63f5dca7c38 Binary files /dev/null and b/docs/assets/data/alltypes_plain.parquet differ diff --git a/docs/assets/data/apple_stock.csv b/docs/assets/data/apple_stock.csv new file mode 100644 index 000000000000..6c3f9752d587 --- /dev/null +++ b/docs/assets/data/apple_stock.csv @@ -0,0 +1,101 @@ +Date,Close +1981-02-23,24.62 +1981-05-06,27.38 +1981-05-18,28.0 +1981-09-25,14.25 +1982-07-08,11.0 +1983-01-03,28.5 +1983-04-06,40.0 +1983-10-03,23.13 +1984-07-27,27.13 +1984-08-17,27.5 +1984-08-24,28.12 +1985-05-07,20.0 +1985-09-03,14.75 +1985-12-06,19.75 +1986-03-12,24.75 +1986-04-09,27.13 +1986-04-17,29.0 +1986-09-17,34.25 +1986-11-26,40.5 +1987-02-25,69.13 +1987-04-15,71.0 +1988-02-23,42.75 +1988-03-07,46.88 +1988-03-23,42.5 +1988-12-12,38.5 +1988-12-19,40.75 +1989-04-17,39.25 +1989-11-13,46.5 +1990-11-23,36.38 +1991-03-22,63.25 +1991-05-17,47.0 +1991-06-03,49.25 +1991-06-18,42.12 +1992-06-25,45.62 +1992-10-12,44.0 +1993-07-06,37.75 +1993-09-15,24.5 +1993-09-30,23.37 +1993-11-09,30.12 +1994-01-24,35.0 +1994-03-15,37.62 +1994-06-27,26.25 +1994-07-08,27.06 +1994-12-21,38.38 +1995-07-06,47.0 +1995-10-16,36.13 +1995-11-17,40.13 +1995-12-12,38.0 +1996-01-31,27.63 +1996-02-05,29.25 +1996-07-15,17.19 +1996-09-20,22.87 +1996-12-23,23.25 +1997-03-17,16.5 +1997-05-09,17.06 +1997-08-06,26.31 +1997-09-30,21.69 +1998-02-09,19.19 +1998-03-12,27.0 +1998-05-07,30.19 +1998-05-12,30.12 +1999-07-09,55.63 +1999-12-08,110.06 +2000-01-14,100.44 +2000-06-27,51.75 +2000-07-05,51.62 +2000-07-19,52.69 +2000-08-07,47.94 +2000-08-28,58.06 +2000-09-26,51.44 +2001-03-02,19.25 +2001-12-10,22.54 +2002-01-25,23.25 +2002-03-07,24.38 +2002-08-16,15.81 +2002-10-03,14.3 +2003-11-18,20.41 +2004-02-26,23.04 +2004-03-08,26.0 +2004-09-22,36.92 +2005-06-24,37.76 +2005-12-07,73.95 +2005-12-22,74.02 +2006-06-22,59.58 +2006-11-28,91.81 +2007-08-13,127.79 +2007-12-04,179.81 +2007-12-31,198.08 +2008-05-09,183.45 +2008-06-27,170.09 +2009-08-03,166.43 +2010-04-01,235.97 +2010-12-10,320.56 +2011-04-28,346.75 +2011-12-02,389.7 +2012-05-16,546.08 +2012-12-04,575.85 +2013-07-05,417.42 +2013-11-07,512.49 +2014-02-25,522.06 \ No newline at end of file diff --git a/docs/assets/data/iris.csv b/docs/assets/data/iris.csv new file mode 100644 index 000000000000..d6b466b31892 --- /dev/null +++ b/docs/assets/data/iris.csv @@ -0,0 +1,151 @@ +sepal_length,sepal_width,petal_length,petal_width,species +5.1,3.5,1.4,.2,Setosa +4.9,3,1.4,.2,Setosa +4.7,3.2,1.3,.2,Setosa +4.6,3.1,1.5,.2,Setosa +5,3.6,1.4,.2,Setosa +5.4,3.9,1.7,.4,Setosa +4.6,3.4,1.4,.3,Setosa +5,3.4,1.5,.2,Setosa +4.4,2.9,1.4,.2,Setosa +4.9,3.1,1.5,.1,Setosa +5.4,3.7,1.5,.2,Setosa +4.8,3.4,1.6,.2,Setosa +4.8,3,1.4,.1,Setosa +4.3,3,1.1,.1,Setosa +5.8,4,1.2,.2,Setosa +5.7,4.4,1.5,.4,Setosa +5.4,3.9,1.3,.4,Setosa +5.1,3.5,1.4,.3,Setosa +5.7,3.8,1.7,.3,Setosa +5.1,3.8,1.5,.3,Setosa +5.4,3.4,1.7,.2,Setosa +5.1,3.7,1.5,.4,Setosa +4.6,3.6,1,.2,Setosa +5.1,3.3,1.7,.5,Setosa +4.8,3.4,1.9,.2,Setosa +5,3,1.6,.2,Setosa +5,3.4,1.6,.4,Setosa +5.2,3.5,1.5,.2,Setosa +5.2,3.4,1.4,.2,Setosa +4.7,3.2,1.6,.2,Setosa +4.8,3.1,1.6,.2,Setosa +5.4,3.4,1.5,.4,Setosa +5.2,4.1,1.5,.1,Setosa +5.5,4.2,1.4,.2,Setosa +4.9,3.1,1.5,.2,Setosa +5,3.2,1.2,.2,Setosa +5.5,3.5,1.3,.2,Setosa +4.9,3.6,1.4,.1,Setosa +4.4,3,1.3,.2,Setosa +5.1,3.4,1.5,.2,Setosa +5,3.5,1.3,.3,Setosa +4.5,2.3,1.3,.3,Setosa +4.4,3.2,1.3,.2,Setosa +5,3.5,1.6,.6,Setosa +5.1,3.8,1.9,.4,Setosa +4.8,3,1.4,.3,Setosa +5.1,3.8,1.6,.2,Setosa +4.6,3.2,1.4,.2,Setosa +5.3,3.7,1.5,.2,Setosa +5,3.3,1.4,.2,Setosa +7,3.2,4.7,1.4,Versicolor +6.4,3.2,4.5,1.5,Versicolor +6.9,3.1,4.9,1.5,Versicolor +5.5,2.3,4,1.3,Versicolor +6.5,2.8,4.6,1.5,Versicolor +5.7,2.8,4.5,1.3,Versicolor +6.3,3.3,4.7,1.6,Versicolor +4.9,2.4,3.3,1,Versicolor +6.6,2.9,4.6,1.3,Versicolor +5.2,2.7,3.9,1.4,Versicolor +5,2,3.5,1,Versicolor +5.9,3,4.2,1.5,Versicolor +6,2.2,4,1,Versicolor +6.1,2.9,4.7,1.4,Versicolor +5.6,2.9,3.6,1.3,Versicolor +6.7,3.1,4.4,1.4,Versicolor +5.6,3,4.5,1.5,Versicolor +5.8,2.7,4.1,1,Versicolor +6.2,2.2,4.5,1.5,Versicolor +5.6,2.5,3.9,1.1,Versicolor +5.9,3.2,4.8,1.8,Versicolor +6.1,2.8,4,1.3,Versicolor +6.3,2.5,4.9,1.5,Versicolor +6.1,2.8,4.7,1.2,Versicolor +6.4,2.9,4.3,1.3,Versicolor +6.6,3,4.4,1.4,Versicolor +6.8,2.8,4.8,1.4,Versicolor +6.7,3,5,1.7,Versicolor +6,2.9,4.5,1.5,Versicolor +5.7,2.6,3.5,1,Versicolor +5.5,2.4,3.8,1.1,Versicolor +5.5,2.4,3.7,1,Versicolor +5.8,2.7,3.9,1.2,Versicolor +6,2.7,5.1,1.6,Versicolor +5.4,3,4.5,1.5,Versicolor +6,3.4,4.5,1.6,Versicolor +6.7,3.1,4.7,1.5,Versicolor +6.3,2.3,4.4,1.3,Versicolor +5.6,3,4.1,1.3,Versicolor +5.5,2.5,4,1.3,Versicolor +5.5,2.6,4.4,1.2,Versicolor +6.1,3,4.6,1.4,Versicolor +5.8,2.6,4,1.2,Versicolor +5,2.3,3.3,1,Versicolor +5.6,2.7,4.2,1.3,Versicolor +5.7,3,4.2,1.2,Versicolor +5.7,2.9,4.2,1.3,Versicolor +6.2,2.9,4.3,1.3,Versicolor +5.1,2.5,3,1.1,Versicolor +5.7,2.8,4.1,1.3,Versicolor +6.3,3.3,6,2.5,Virginica +5.8,2.7,5.1,1.9,Virginica +7.1,3,5.9,2.1,Virginica +6.3,2.9,5.6,1.8,Virginica +6.5,3,5.8,2.2,Virginica +7.6,3,6.6,2.1,Virginica +4.9,2.5,4.5,1.7,Virginica +7.3,2.9,6.3,1.8,Virginica +6.7,2.5,5.8,1.8,Virginica +7.2,3.6,6.1,2.5,Virginica +6.5,3.2,5.1,2,Virginica +6.4,2.7,5.3,1.9,Virginica +6.8,3,5.5,2.1,Virginica +5.7,2.5,5,2,Virginica +5.8,2.8,5.1,2.4,Virginica +6.4,3.2,5.3,2.3,Virginica +6.5,3,5.5,1.8,Virginica +7.7,3.8,6.7,2.2,Virginica +7.7,2.6,6.9,2.3,Virginica +6,2.2,5,1.5,Virginica +6.9,3.2,5.7,2.3,Virginica +5.6,2.8,4.9,2,Virginica +7.7,2.8,6.7,2,Virginica +6.3,2.7,4.9,1.8,Virginica +6.7,3.3,5.7,2.1,Virginica +7.2,3.2,6,1.8,Virginica +6.2,2.8,4.8,1.8,Virginica +6.1,3,4.9,1.8,Virginica +6.4,2.8,5.6,2.1,Virginica +7.2,3,5.8,1.6,Virginica +7.4,2.8,6.1,1.9,Virginica +7.9,3.8,6.4,2,Virginica +6.4,2.8,5.6,2.2,Virginica +6.3,2.8,5.1,1.5,Virginica +6.1,2.6,5.6,1.4,Virginica +7.7,3,6.1,2.3,Virginica +6.3,3.4,5.6,2.4,Virginica +6.4,3.1,5.5,1.8,Virginica +6,3,4.8,1.8,Virginica +6.9,3.1,5.4,2.1,Virginica +6.7,3.1,5.6,2.4,Virginica +6.9,3.1,5.1,2.3,Virginica +5.8,2.7,5.1,1.9,Virginica +6.8,3.2,5.9,2.3,Virginica +6.7,3.3,5.7,2.5,Virginica +6.7,3,5.2,2.3,Virginica +6.3,2.5,5,1.9,Virginica +6.5,3,5.2,2,Virginica +6.2,3.4,5.4,2.3,Virginica +5.9,3,5.1,1.8,Virginica \ No newline at end of file diff --git a/docs/assets/data/reddit.csv b/docs/assets/data/reddit.csv new file mode 100644 index 000000000000..88f91e3df7db --- /dev/null +++ b/docs/assets/data/reddit.csv @@ -0,0 +1,100 @@ +id,name,created_utc,updated_on,comment_karma,link_karma +1,truman48lamb_jasonbroken,1397113470,1536527864,0,0 +2,johnethen06_jasonbroken,1397113483,1536527864,0,0 +3,yaseinrez_jasonbroken,1397113483,1536527864,0,1 +4,Valve92_jasonbroken,1397113503,1536527864,0,0 +5,srbhuyan_jasonbroken,1397113506,1536527864,0,0 +6,taojianlong_jasonbroken,1397113510,1536527864,4,0 +7,YourPalGrant92_jasonbroken,1397113513,1536527864,0,0 +8,Lucki87_jasonbroken,1397113515,1536527864,0,0 +9,punkstock_jasonbroken,1397113517,1536527864,0,0 +10,duder_con_chile_jasonbroken,1397113519,1536527864,0,2 +11,IHaveBigBalls_jasonbroken,1397113520,1536527864,0,0 +12,Foggybanana_jasonbroken,1397113523,1536527864,0,0 +13,Thedrinkdriver_jasonbroken,1397113527,1536527864,-9,0 +14,littlemissd_jasonbroken,1397113530,1536527864,0,-3 +15,phonethaway_jasonbroken,1397113537,1536527864,0,0 +16,DreamingOfWinterfell_jasonbroken,1397113538,1536527864,0,0 +17,ssaig_jasonbroken,1397113544,1536527864,1,0 +18,divinetribe_jasonbroken,1397113549,1536527864,0,0 +19,fdbvfdssdgfds_jasonbroken,1397113552,1536527864,3,0 +20,hjtrsh54yh43_jasonbroken,1397113559,1536527864,-1,-1 +21,Dalin86_jasonbroken,1397113561,1536527864,0,0 +22,sgalex_jasonbroken,1397113561,1536527864,0,0 +23,beszhthw_jasonbroken,1397113566,1536527864,0,0 +24,WojkeN_jasonbroken,1397113572,1536527864,-8,0 +25,LixksHD_jasonbroken,1397113572,1536527864,0,0 +26,bradhrvf78_jasonbroken,1397113574,1536527864,0,0 +27,ravenfeathers_jasonbroken,1397113576,1536527864,0,0 +28,jayne101_jasonbroken,1397113583,1536527864,0,0 +29,jdennis6701_jasonbroken,1397113585,1536527864,0,0 +30,Puppy243_jasonbroken,1397113592,1536527864,0,0 +31,sissyt_jasonbroken,1397113609,1536527864,0,0 +32,fengye78_jasonbroken,1397113613,1536527864,0,0 +33,bigspender1988_jasonbroken,1397113614,1536527864,0,21 +34,bitdownworld_jasonbroken,1397113618,1536527864,0,0 +35,adhyufsdtha12_jasonbroken,1397113619,1536527864,0,0 +36,Haydenac_jasonbroken,1397113635,1536527864,0,0 +37,ihatewhoweare_jasonbroken,1397113636,1536527864,61,0 +38,HungDaddy69__jasonbroken,1397113641,1536527864,0,0 +39,FSUJohnny24_jasonbroken,1397113646,1536527864,0,0 +40,Toejimon_jasonbroken,1397113650,1536527864,0,0 +41,mine69flesh_jasonbroken,1397113651,1536527864,0,0 +42,brycentkt_jasonbroken,1397113653,1536527864,0,0 +43,hmmmitsbig,1397113655,1536527864,0,0 +77714,hockeyschtick,1137474000,1536497404,11104,451 +77715,kbmunkholm,1137474000,1536528267,0,0 +77716,dickb,1137588452,1536528267,0,0 +77717,stephenjcole,1137474000,1536528267,0,2 +77718,rosetree,1137474000,1536528267,0,0 +77719,benhawK,1138180921,1536528267,0,0 +77720,joenowak,1137474000,1536528268,0,0 +77721,constant,1137474000,1536528268,1,0 +77722,jpscott,1137474000,1536528268,0,1 +77723,meryn,1137474000,1536528268,0,2 +77724,momerath,1128916800,1536528268,2490,101 +77725,inuse,1137474000,1536528269,0,0 +77726,dubert11,1137474000,1536528269,38,59 +77727,CaliMark,1137474000,1536528269,0,0 +77728,Maniac,1137474000,1536528269,0,0 +77729,earlpearl,1137474000,1536528269,0,0 +77730,ghost,1137474000,1536497404,767,0 +77731,paulzg,1137474000,1536528270,0,0 +77732,rshawgo,1137474000,1536497404,707,6883 +77733,spage,1137474000,1536528270,0,0 +77734,HrothgarReborn,1137474000,1536528270,0,0 +77735,darknessvisible,1137474000,1536528270,26133,139 +77736,finleyt,1137714898,1536528270,0,0 +77737,Dalton,1137474000,1536528271,118,2 +77738,graemes,1137474000,1536528271,0,0 +77739,lettuce,1137780958,1536497404,4546,724 +77740,mudkicker,1137474000,1536528271,0,0 +77741,mydignet,1139649149,1536528271,0,0 +77742,markbo,1137474000,1536528271,0,0 +77743,mrfrostee,1137474000,1536528272,227,43 +77744,parappayo,1136350800,1536528272,53,164 +77745,danastasi,1137474000,1536528272,2335,146 +77747,AltherrWeb,1137474000,1536528272,1387,1605 +77748,dtpetty,1137474000,1536528273,0,0 +77749,jamesluke4,1137474000,1536528273,0,0 +77750,sankeld,1137474000,1536528273,9,45 +77751,iampivot,1139479524,1536497404,2640,31 +77752,mcaamano,1137474000,1536528273,0,0 +77753,wonsungi,1137596632,1536528273,0,0 +77754,naotakem,1137474000,1536528274,0,0 +77755,bis,1137474000,1536497404,2191,285 +77756,imeinzen,1137474000,1536528274,0,0 +77757,zrenneh,1137474000,1536528274,79,0 +77758,onclephilippe,1137474000,1536528274,0,0 +77759,Mokzaio415,1139422169,1536528274,0,0 +77761,-brisse,1137474000,1536528275,14,1 +77762,coolin86,1138303196,1536528275,40,7 +77763,Lunchy,1137599510,1536528275,65,0 +77764,jannemans,1137474000,1536528275,0,0 +77765,compostellas,1137474000,1536528276,6,0 +77766,genericbob,1137474000,1536528276,291,14 +77767,domlexch,1139482978,1536528276,0,0 +77768,TinheadNed,1139665457,1536497404,4434,103 +77769,patopurifik,1137474000,1536528276,0,0 +77770,PoPPo,1139057558,1536528276,0,0 +77771,tandrews,1137474000,1536528277,0,0 diff --git a/docs/assets/images/.gitignore b/docs/assets/images/.gitignore new file mode 100644 index 000000000000..d6b7ef32c847 --- /dev/null +++ b/docs/assets/images/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/docs/mlc-config.json b/docs/mlc-config.json new file mode 100644 index 000000000000..888be1eaf3da --- /dev/null +++ b/docs/mlc-config.json @@ -0,0 +1,13 @@ +{ + "ignorePatterns": [ + { + "pattern": "^https://crates.io/" + }, + { + "pattern": "^https://stackoverflow.com/" + }, + { + "pattern": "^https://marketplace.visualstudio.com/" + } + ] +} diff --git a/docs/source/_build/API_REFERENCE_LINKS.yml b/docs/source/_build/API_REFERENCE_LINKS.yml new file mode 100644 index 000000000000..b5feaf67aff1 --- /dev/null +++ b/docs/source/_build/API_REFERENCE_LINKS.yml @@ -0,0 +1,486 @@ +python: + agg: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.dataframe.group_by.GroupBy.agg.html + alias: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.alias.html + all: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.all.html + approx_n_unique: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.approx_n_unique.html + Array: https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Array.html + cast: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.cast.html + Categorical: https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Categorical.html + col: https://docs.pola.rs/api/python/stable/reference/expressions/col.html + collect: https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.collect.html + concat: https://docs.pola.rs/api/python/stable/reference/api/polars.concat.html + concat_list: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.concat_list.html + concat_str: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.concat_str.html + Config: https://docs.pola.rs/api/python/stable/reference/config.html + CredentialProviderAWS: https://docs.pola.rs/api/python/stable/reference/api/polars.CredentialProviderAWS.html + CredentialProviderAzure: https://docs.pola.rs/api/python/stable/reference/api/polars.CredentialProviderAzure.html + CredentialProviderGCP: https://docs.pola.rs/api/python/stable/reference/api/polars.CredentialProviderGCP.html + cs.by_name: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.by_name + cs.contains: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.contains + cs.first: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.first + cs.matches: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.matches + cs.numeric: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.numeric + cs.temporal: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.temporal + DataFrame: https://docs.pola.rs/api/python/stable/reference/dataframe/index.html + DataFrame.explode: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.explode.html + date_range: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.date_range.html + datetime_range: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.datetime_range.html + describe: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.describe.html + dt.convert_time_zone: + name: dt.convert_time_zone + link: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.dt.convert_time_zone.html + feature_flags: [timezone] + dt.replace_time_zone: + name: dt.replace_time_zone + link: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.dt.replace_time_zone.html + feature_flags: [timezone] + dt.to_string: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.dt.to_string.html + dt.year: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.dt.year.html + element: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.element.html + Enum: https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.Enum.html + estimated_size: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.estimated_size.html + exclude: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.exclude.html + expand_selector: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.expand_selector + explain: https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.explain.html + explode: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.explode.html + Expr.arr: + name: arr namespace + link: https://docs.pola.rs/api/python/stable/reference/expressions/array.html + Expr.dt: + name: dt namespace + link: https://docs.pola.rs/api/python/stable/reference/expressions/temporal.html + Expr.list: + name: list namespace + link: https://docs.pola.rs/api/python/stable/reference/expressions/list.html + Expr.name: + name: name namespace + link: https://docs.pola.rs/api/python/stable/reference/expressions/name.html + Expr.str: + name: str namespace + link: https://docs.pola.rs/api/python/stable/reference/expressions/string.html + fetch: https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.fetch.html + fill_nan: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.fill_nan.html + fill_null: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.fill_null.html + filter: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.filter.html + fold: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.fold.html + from_arrow: + name: from_arrow + link: https://docs.pola.rs/api/python/stable/reference/api/polars.from_arrow.html + feature_flags: [fsspec, pyarrow] + glimpse: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.glimpse.html + group_by: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by.html + group_by_dynamic: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html + head: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.head.html + implode: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.implode.html + interpolate: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.interpolate.html + is_between: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_between.html + is_duplicated: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_duplicated.html + is_null: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.is_null.html + is_selector: https://docs.pola.rs/api/python/stable/reference/selectors.html#polars.selectors.is_selector + join: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html + join_asof: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_asof.html + join_where: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_where.html + lazy: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.lazy.html + LazyFrame: https://docs.pola.rs/api/python/stable/reference/lazyframe/index.html + List: https://docs.pola.rs/api/python/stable/reference/api/polars.datatypes.List.html + list.eval: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.list.eval.html + map: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.name.map.html + map_elements: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_elements.html + max: https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.max.html + min: https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.min.html + n_unique: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.n_unique.html + np.log: + name: log + link: https://numpy.org/doc/stable/reference/generated/numpy.log.html + feature_flags: [numpy] + null_count: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.null_count.html + operators: https://docs.pola.rs/api/python/stable/reference/expressions/operators.html + over: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.over.html + pivot: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pivot.html + prefix: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.name.prefix.html + rank: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.rank.html + read_csv: https://docs.pola.rs/api/python/stable/reference/api/polars.read_csv.html + read_database: + name: read_database + link: https://docs.pola.rs/api/python/stable/reference/api/polars.read_database.html + read_database_connectorx: + name: read_database + link: https://docs.pola.rs/api/python/stable/reference/api/polars.read_database.html + feature_flags: [connectorx] + read_database_uri: https://docs.pola.rs/api/python/stable/reference/api/polars.read_database_uri.html + read_excel: https://docs.pola.rs/api/python/stable/reference/api/polars.read_excel.html + read_ipc: https://docs.pola.rs/api/python/stable/reference/api/polars.read_ipc.html + read_json: https://docs.pola.rs/api/python/stable/reference/api/polars.read_json.html + read_ndjson: https://docs.pola.rs/api/python/stable/reference/api/polars.read_ndjson.html + read_parquet: https://docs.pola.rs/api/python/stable/reference/api/polars.read_parquet.html + round: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.round.html#polars.Expr.round + sample: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.sample.html + scan_csv: https://docs.pola.rs/api/python/stable/reference/api/polars.scan_csv.html + scan_ipc: https://docs.pola.rs/api/python/stable/reference/api/polars.scan_ipc.html + scan_ndjson: https://docs.pola.rs/api/python/stable/reference/api/polars.scan_ndjson.html + scan_parquet: https://docs.pola.rs/api/python/stable/reference/api/polars.scan_parquet.html + scan_pyarrow_dataset: https://docs.pola.rs/api/python/stable/reference/api/polars.scan_pyarrow_dataset.html + select: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.select.html + selectors: https://docs.pola.rs/api/python/stable/reference/selectors.html + Series: https://docs.pola.rs/api/python/stable/reference/series/index.html + Series.arr: https://docs.pola.rs/api/python/stable/reference/series/array.html + Series.dt.day: https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.dt.day.html + show_graph: https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.show_graph.html + sort: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.sort.html + SQLContext: https://docs.pola.rs/api/python/stable/reference/sql/python_api.html#polars.SQLContext + SQLexecute: + name: execute + link: https://docs.pola.rs/api/python/stable/reference/sql/api/polars.SQLContext.execute.html + SQLregister: + name: register + link: https://docs.pola.rs/api/python/stable/reference/sql/api/polars.SQLContext.register.html#polars.SQLContext.register + SQLregister_many: + name: register_many + link: https://docs.pola.rs/api/python/stable/reference/sql/api/polars.SQLContext.register_many.html + str.contains: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.contains.html + str.ends_with: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.ends_with.html + str.extract: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract.html + str.extract_all: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract_all.html + str.head: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.head.html + str.len_bytes: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.len_bytes.html + str.len_chars: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.len_chars.html + str.replace: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace.html + str.replace_all: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace_all.html + str.slice: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.slice.html + str.split: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.split.html + str.starts_with: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.starts_with.html + str.strip_chars: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars.html + str.strip_chars_end: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars_end.html + str.strip_chars_start: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars_start.html + str.strip_prefix: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_prefix.html + str.strip_suffix: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_suffix.html + str.tail: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.tail.html + str.to_date: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_date.html + str.to_datetime: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_datetime.html + str.to_lowercase: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_lowercase.html + str.to_titlecase: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_titlecase.html + str.to_uppercase: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_uppercase.html + StringCache: https://docs.pola.rs/api/python/stable/reference/api/polars.StringCache.html + struct: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.struct.html + struct.field: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.struct.field.html + struct.rename_fields: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.struct.rename_fields.html + suffix: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.name.suffix.html + tail: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.tail.html + unique: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.unique.html + unique_counts: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.unique_counts.html + unnest: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.unnest.html + unpivot: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.unpivot.html + upsample: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.upsample.html + value_counts: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.value_counts.html + vstack: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.vstack.html + when: https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.when.html + with_columns: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.with_columns.html + write_csv: https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_csv.html + write_database: https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_database.html + write_excel: https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_excel.html + write_json: https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_json.html + write_ndjson: https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_ndjson.html + write_parquet: https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_parquet.html + Workspace: https://docs.cloud.pola.rs/reference/workspace/workspace.html + ComputeContext: https://docs.cloud.pola.rs/reference/compute/compute.html + LazyFrameExt : https://docs.cloud.pola.rs/reference/query/lazyframeext.html + QueryResult : https://docs.cloud.pola.rs/reference/query/query_result.html + InteractiveQuery : https://docs.cloud.pola.rs/reference/query/interactive_query.html + BatchQuery: https://docs.cloud.pola.rs/reference/query/batch_query.html + login: https://docs.cloud.pola.rs/reference/auth/api/polars_cloud.login.html + +rust: + agg: https://docs.rs/polars/latest/polars/prelude/struct.LazyGroupBy.html#method.agg + alias: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.alias + all: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/functions/fn.all.html + approx_n_unique: + name: approx_n_unique + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.approx_n_unique + feature_flags: [approx_unique] + arr.eval: + name: arr + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.arr + feature_flags: [list_eval, rank] + Array: + name: Array + link: https://docs.pola.rs/api/rust/dev/polars/datatypes/enum.DataType.html#variant.Array + feature_flags: [dtype-array] + cast: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.cast + Categorical: + name: Categorical + link: https://docs.pola.rs/api/rust/dev/polars/prelude/enum.DataType.html#variant.Categorical + feature_flags: [dtype-categorical] + col: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/fn.col.html + collect: + name: collect + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.LazyFrame.html#method.collect + feature_flags: [streaming] + concat: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/functions/fn.concat.html + concat_list: + name: concat_lst + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/fn.concat_list.html + concat_str: + name: concat_str + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/fn.concat_str.html + feature_flags: [concat_str] + cross_join: + name: cross_join + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.LazyFrame.html#method.cross_join + feature_flags: [cross_join] + cs.by_name: https://github.com/pola-rs/polars/issues/10594 + cs.contains: https://github.com/pola-rs/polars/issues/10594 + cs.first: https://github.com/pola-rs/polars/issues/10594 + cs.matches: https://github.com/pola-rs/polars/issues/10594 + cs.numeric: https://github.com/pola-rs/polars/issues/10594 + cs.temporal: https://github.com/pola-rs/polars/issues/10594 + DataFrame: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html + DataFrame.explode: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.explode + date_range: + name: date_range + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/functions/fn.date_range.html + feature_flags: [range, dtype-date] + datetime_range: + name: datetime_range + link: https://docs.rs/polars/latest/polars/prelude/fn.datetime_range.html + feature_flags: [lazy, dtype-datetime] + describe: + name: describe + link: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.describe + feature_flags: [describe] + dt.convert_time_zone: + name: dt.convert_time_zone + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.convert_time_zone + feature_flags: [timezones] + dt.replace_time_zone: + name: dt.replace_time_zone + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.replace_time_zone + feature_flags: [timezones] + dt.to_string: + name: dt.to_string + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.to_string + feature_flags: [temporal] + dt.year: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.year + dtype_col: https://docs.rs/polars/latest/polars/prelude/fn.dtype_col.html + dtype_cols: https://docs.rs/polars/latest/polars/prelude/fn.dtype_cols.html + element: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/fn.col.html + estimated_size: https://docs.rs/polars/latest/polars/frame/struct.DataFrame.html#method.estimated_size + exclude: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.exclude + expand_selector: https://github.com/pola-rs/polars/issues/10594 + explain: https://docs.rs/polars/latest/polars/prelude/struct.LazyFrame.html#method.explain + explode: https://docs.rs/polars/latest/polars/frame/struct.DataFrame.html#method.explode + Expr.arr: + name: '`arr` namespace' + link: https://docs.pola.rs/api/rust/dev/polars/prelude/enum.Expr.html#method.arr + feature_flags: [dtype-array] + Expr.dt: + name: dt namespace + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html + feature_flags: [temporal] + Expr.list: + name: list namespace + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/struct.ListNameSpace.html + Expr.name: + name: name namespace + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/struct.ExprNameNameSpace.html + feature_flags: [lazy] + Expr.str: + name: str namespace + link: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.StringNameSpaceImpl.html + feature_flags: [strings] + fill_nan: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.fill_nan + fill_null: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.fill_null + filter: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.filter + fold: + name: fold_exprs + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/fn.fold_exprs.html + group_by: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.group_by + group_by_dynamic: + name: group_by_dynamic + link: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.group_by_dynamic + feature_flags: [dynamic_group_by] + head: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.head + implode: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.implode + interpolate: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.interpolate + is_between: + name: is_between + link: https://docs.pola.rs/api/rust/dev/polars/prelude/enum.Expr.html#method.is_between + feature_flags: [is_between] + is_duplicated: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.is_duplicated + is_null: https://docs.pola.rs/api/rust/dev/polars/prelude/enum.Expr.html#method.is_null + is_selector: https://github.com/pola-rs/polars/issues/10594 + join: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.DataFrameJoinOps.html#method.join + join-semi_anti_join_flag: + name: join + link: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.DataFrameJoinOps.html#method.join + feature_flags: [semi_anti_join] + join_asof_by: + name: join_asof_by + link: https://docs.pola.rs/api/rust/dev/polars/prelude/trait.AsofJoinBy.html#method.join_asof_by + feature_flags: [asof_join] + join_where: + name: join_where + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.JoinBuilder.html#method.join_where + feature_flags: [iejoin] + LazyFrame: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.LazyFrame.html + List: https://docs.pola.rs/api/rust/dev/polars/datatypes/enum.DataType.html#variant.List + list.eval: + name: list.eval + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/trait.ListNameSpaceExtension.html#method.eval + feature_flags: [list_eval] + map: https://docs.rs/polars/latest/polars/prelude/struct.ExprNameNameSpace.html#method.map + max: https://docs.pola.rs/api/rust/dev/polars/series/struct.Series.html#method.max + min: https://docs.pola.rs/api/rust/dev/polars/series/struct.Series.html#method.min + n_unique: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.n_unique + null_count: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.null_count + operators: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Operator.html + over: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.over + pivot: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/pivot/fn.pivot.html + prefix: https://docs.rs/polars/latest/polars/prelude/struct.ExprNameNameSpace.html#method.prefix + rank: https://docs.rs/polars/latest/polars/prelude/enum.Expr.html#method.rank + read_csv: + name: CsvReader + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.CsvReader.html + feature_flags: [csv] + read_ipc: + name: IpcReader + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.IpcReader.html + feature_flags: [ipc] + read_json: + name: JsonReader + link: https://docs.pola.rs/api/rust/dev/polars_io/json/struct.JsonReader.html + feature_flags: [json] + read_ndjson: + name: JsonLineReader + link: https://docs.pola.rs/api/rust/dev/polars_io/ndjson/core/struct.JsonLineReader.html + feature_flags: [json] + read_parquet: + name: ParquetReader + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.ParquetReader.html + feature_flags: [parquet] + round: + name: round + link: https://docs.pola.rs/api/rust/dev/polars/prelude/enum.Expr.html#method.round + feature_flags: [round_series] + sample: + name: sample_n + link: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.sample_n + scan_csv: + name: LazyCsvReader + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.LazyCsvReader.html + feature_flags: [csv] + scan_ndjson: + name: LazyJsonLineReader + link: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyJsonLineReader.html + feature_flags: [json] + scan_parquet: + name: scan_parquet + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.LazyFrame.html#method.scan_parquet + feature_flags: [parquet] + scan_pyarrow_dataset: https://docs.pola.rs/api/python/stable/reference/api/polars.scan_pyarrow_dataset.html + select: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.select + selectors: https://github.com/pola-rs/polars/issues/10594 + Series: https://docs.pola.rs/api/rust/dev/polars/series/struct.Series.html + Series.arr: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/struct.ArrayNameSpace.html + Series.dt.day: + name: dt.day + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/dt/struct.DateLikeNameSpace.html#method.day + feature_flags: [temporal] + sort: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.sort + str.contains: + name: str.contains + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.contains + feature_flags: [regex] + str.ends_with: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.ends_with + str.extract: + name: str.extract + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.extract + str.extract_all: + name: str.extract_all + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.extract_all + str.head: + name: str.head + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.head + str.len_bytes: + name: str.len_bytes + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.len_bytes + str.len_chars: + name: str.len_chars + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.len_chars + str.replace: + name: str.replace + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.replace + feature_flags: [regex] + str.replace_all: + name: str.replace_all + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.replace_all + feature_flags: [regex] + str.split: + name: str.split + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.split + str.starts_with: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.starts_with + str.str_head: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.str_head + str.str_slice: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.str_slice + str.str_tail: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.str_tail + str.strip_chars: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.strip_chars + str.strip_chars_end: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.strip_chars_end + str.strip_chars_start: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.strip_chars_start + str.strip_prefix: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.strip_prefix + str.strip_suffix: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.strip_suffix + str.tail: + name: str.tail + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.tail + str.to_date: + name: str.replace_all + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.to_date + feature_flags: [dtype-date] + str.to_datetime: + name: str.replace_all + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/string/struct.StringNameSpace.html#method.to_datetime + feature_flags: [dtype-datetime] + str.to_lowercase: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.to_lowercase + str.to_titlecase: + name: str.to_titlecase + link: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.to_titlecase + feature_flags: [nightly] + str.to_uppercase: https://docs.rs/polars/latest/polars/prelude/trait.StringNameSpaceImpl.html#method.to_uppercase + struct: + name: Struct + link: https://docs.pola.rs/api/rust/dev/polars/datatypes/enum.DataType.html#variant.Struct + feature_flags: [dtype-struct] + struct.field: + name: struct.field_by_name + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/struct.StructNameSpace.html#method.field_by_name + struct.rename_fields: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/struct.StructNameSpace.html#method.rename_fields + suffix: https://docs.rs/polars/latest/polars/prelude/struct.ExprNameNameSpace.html#method.suffix + tail: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.tail + unique: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.unique + unique_counts: + name: unique_counts + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.unique_counts + feature_flags: [unique_counts] + unnest: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.unnest + unpivot: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.unpivot + upsample: https://docs.pola.rs/api/rust/dev/polars/frame/struct.DataFrame.html#method.upsample + value_counts: + name: value_counts + link: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/enum.Expr.html#method.value_counts + feature_flags: [dtype-struct] + vstack: https://docs.pola.rs/api/rust/dev/polars_core/frame/struct.DataFrame.html#method.vstack + when: https://docs.pola.rs/api/rust/dev/polars_lazy/dsl/fn.when.html + with_columns: https://docs.pola.rs/api/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.with_columns + write_csv: + name: CsvWriter + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.CsvWriter.html + feature_flags: [csv] + write_json: + name: JsonWriter + link: https://docs.pola.rs/api/rust/dev/polars_io/json/struct.JsonWriter.html + feature_flags: [json] + write_ndjson: + name: JsonWriter + link: https://docs.pola.rs/api/rust/dev/polars_io/json/struct.JsonWriter.html + feature_flags: [json] + write_parquet: + name: ParquetWriter + link: https://docs.pola.rs/api/rust/dev/polars/prelude/struct.ParquetWriter.html + feature_flags: [parquet] diff --git a/docs/source/_build/assets/logo.png b/docs/source/_build/assets/logo.png new file mode 100644 index 000000000000..9b5486edce3b Binary files /dev/null and b/docs/source/_build/assets/logo.png differ diff --git a/docs/source/_build/css/extra.css b/docs/source/_build/css/extra.css new file mode 100644 index 000000000000..0256afdc1926 --- /dev/null +++ b/docs/source/_build/css/extra.css @@ -0,0 +1,75 @@ +:root { + --md-primary-fg-color: #0075ff; + --md-text-font: 'Proxima Nova', sans-serif; +} + + +span .md-typeset .emojione, +.md-typeset .gemoji, +.md-typeset .twemoji { + vertical-align: text-bottom; +} + +@font-face { + font-family: 'Proxima Nova', sans-serif; + src: 'https://fonts.cdnfonts.com/css/proxima-nova-2' +} + +:root { + --md-code-font: "Source Code Pro" !important; +} + +.contributor_icon { + height: 40px; + width: 40px; + border-radius: 20px; + margin: 0 5px; +} + +.feature-flag { + background-color: rgba(255, 245, 214, .5); + border: none; + padding: 0px 5px; + text-align: center; + text-decoration: none; + display: inline-block; + margin: 4px 2px; + cursor: pointer; + font-size: .85em; +} + +[data-md-color-scheme=slate] .feature-flag { + background-color: var(--md-code-bg-color); +} + +.md-typeset ol li, +.md-typeset ul li { + margin-bottom: 0em !important; +} + +:root { + --md-admonition-icon--rust: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 512 512'%3E%3C!--! Font Awesome Free 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2023 Fonticons, Inc.--%3E%3Cpath d='m508.52 249.75-21.82-13.51c-.17-2-.34-3.93-.55-5.88l18.72-17.5a7.35 7.35 0 0 0-2.44-12.25l-24-9c-.54-1.88-1.08-3.78-1.67-5.64l15-20.83a7.35 7.35 0 0 0-4.79-11.54l-25.42-4.15c-.9-1.73-1.79-3.45-2.73-5.15l10.68-23.42a7.35 7.35 0 0 0-6.95-10.39l-25.82.91q-1.79-2.22-3.61-4.4L439 81.84a7.36 7.36 0 0 0-8.84-8.84L405 78.93q-2.17-1.83-4.4-3.61l.91-25.82a7.35 7.35 0 0 0-10.39-7L367.7 53.23c-1.7-.94-3.43-1.84-5.15-2.73l-4.15-25.42a7.35 7.35 0 0 0-11.54-4.79L326 35.26c-1.86-.59-3.75-1.13-5.64-1.67l-9-24a7.35 7.35 0 0 0-12.25-2.44l-17.5 18.72c-1.95-.21-3.91-.38-5.88-.55L262.25 3.48a7.35 7.35 0 0 0-12.5 0L236.24 25.3c-2 .17-3.93.34-5.88.55l-17.5-18.72a7.35 7.35 0 0 0-12.25 2.44l-9 24c-1.89.55-3.79 1.08-5.66 1.68l-20.82-15a7.35 7.35 0 0 0-11.54 4.79l-4.15 25.41c-1.73.9-3.45 1.79-5.16 2.73l-23.4-10.63a7.35 7.35 0 0 0-10.39 7l.92 25.81c-1.49 1.19-3 2.39-4.42 3.61L81.84 73A7.36 7.36 0 0 0 73 81.84L78.93 107c-1.23 1.45-2.43 2.93-3.62 4.41l-25.81-.91a7.42 7.42 0 0 0-6.37 3.26 7.35 7.35 0 0 0-.57 7.13l10.66 23.41c-.94 1.7-1.83 3.43-2.73 5.16l-25.41 4.14a7.35 7.35 0 0 0-4.79 11.54l15 20.82c-.59 1.87-1.13 3.77-1.68 5.66l-24 9a7.35 7.35 0 0 0-2.44 12.25l18.72 17.5c-.21 1.95-.38 3.91-.55 5.88l-21.86 13.5a7.35 7.35 0 0 0 0 12.5l21.82 13.51c.17 2 .34 3.92.55 5.87l-18.72 17.5a7.35 7.35 0 0 0 2.44 12.25l24 9c.55 1.89 1.08 3.78 1.68 5.65l-15 20.83a7.35 7.35 0 0 0 4.79 11.54l25.42 4.15c.9 1.72 1.79 3.45 2.73 5.14l-10.63 23.43a7.35 7.35 0 0 0 .57 7.13 7.13 7.13 0 0 0 6.37 3.26l25.83-.91q1.77 2.22 3.6 4.4L73 430.16a7.36 7.36 0 0 0 8.84 8.84l25.16-5.93q2.18 1.83 4.41 3.61l-.92 25.82a7.35 7.35 0 0 0 10.39 6.95l23.43-10.68c1.69.94 3.42 1.83 5.14 2.73l4.15 25.42a7.34 7.34 0 0 0 11.54 4.78l20.83-15c1.86.6 3.76 1.13 5.65 1.68l9 24a7.36 7.36 0 0 0 12.25 2.44l17.5-18.72c1.95.21 3.92.38 5.88.55l13.51 21.82a7.35 7.35 0 0 0 12.5 0l13.51-21.82c2-.17 3.93-.34 5.88-.56l17.5 18.73a7.36 7.36 0 0 0 12.25-2.44l9-24c1.89-.55 3.78-1.08 5.65-1.68l20.82 15a7.34 7.34 0 0 0 11.54-4.78l4.15-25.42c1.72-.9 3.45-1.79 5.15-2.73l23.42 10.68a7.35 7.35 0 0 0 10.39-6.95l-.91-25.82q2.22-1.79 4.4-3.61l25.15 5.93a7.36 7.36 0 0 0 8.84-8.84L433.07 405q1.83-2.17 3.61-4.4l25.82.91a7.23 7.23 0 0 0 6.37-3.26 7.35 7.35 0 0 0 .58-7.13l-10.68-23.42c.94-1.7 1.83-3.43 2.73-5.15l25.42-4.15a7.35 7.35 0 0 0 4.79-11.54l-15-20.83c.59-1.87 1.13-3.76 1.67-5.65l24-9a7.35 7.35 0 0 0 2.44-12.25l-18.72-17.5c.21-1.95.38-3.91.55-5.87l21.82-13.51a7.35 7.35 0 0 0 0-12.5Zm-151 129.08A13.91 13.91 0 0 0 341 389.51l-7.64 35.67a187.51 187.51 0 0 1-156.36-.74l-7.64-35.66a13.87 13.87 0 0 0-16.46-10.68l-31.51 6.76a187.38 187.38 0 0 1-16.26-19.21H258.3c1.72 0 2.89-.29 2.89-1.91v-54.19c0-1.57-1.17-1.91-2.89-1.91h-44.83l.05-34.35H262c4.41 0 23.66 1.28 29.79 25.87 1.91 7.55 6.17 32.14 9.06 40 2.89 8.82 14.6 26.46 27.1 26.46H407a187.3 187.3 0 0 1-17.34 20.09Zm25.77 34.49A15.24 15.24 0 1 1 368 398.08h.44a15.23 15.23 0 0 1 14.8 15.24Zm-225.62-.68a15.24 15.24 0 1 1-15.25-15.25h.45a15.25 15.25 0 0 1 14.75 15.25Zm-88.1-178.49 32.83-14.6a13.88 13.88 0 0 0 7.06-18.33L102.69 186h26.56v119.73h-53.6a187.65 187.65 0 0 1-6.08-71.58Zm-11.26-36.06a15.24 15.24 0 0 1 15.23-15.25H74a15.24 15.24 0 1 1-15.67 15.24Zm155.16 24.49.05-35.32h63.26c3.28 0 23.07 3.77 23.07 18.62 0 12.29-15.19 16.7-27.68 16.7ZM399 306.71c-9.8 1.13-20.63-4.12-22-10.09-5.78-32.49-15.39-39.4-30.57-51.4 18.86-11.95 38.46-29.64 38.46-53.26 0-25.52-17.49-41.59-29.4-49.48-16.76-11-35.28-13.23-40.27-13.23h-198.9a187.49 187.49 0 0 1 104.89-59.19l23.47 24.6a13.82 13.82 0 0 0 19.6.44l26.26-25a187.51 187.51 0 0 1 128.37 91.43l-18 40.57a14 14 0 0 0 7.09 18.33l34.59 15.33a187.12 187.12 0 0 1 .4 32.54h-19.28c-1.91 0-2.69 1.27-2.69 3.13v8.82C421 301 409.31 305.58 399 306.71ZM240 60.21A15.24 15.24 0 0 1 255.21 45h.45A15.24 15.24 0 1 1 240 60.21ZM436.84 214a15.24 15.24 0 1 1 0-30.48h.44a15.24 15.24 0 0 1-.44 30.48Z'/%3E%3C/svg%3E"); +} + +.md-typeset .admonition.rust, +.md-typeset details.rust { + border-color: rgb(205, 121, 44); +} + +.md-typeset .rust>.admonition-title, +.md-typeset .rust>summary { + background-color: rgb(205, 121, 44, .1); +} + +.md-typeset .rust>.admonition-title::before, +.md-typeset .rust>summary::before { + background-color: rgb(205, 121, 44); + -webkit-mask-image: var(--md-admonition-icon--rust); + mask-image: var(--md-admonition-icon--rust); +} + +/* Adapt Excalidraw diagrams to dark mode. */ +body[data-md-color-scheme="slate"] .excalidraw svg { + will-change: filter; + filter: invert(100%) hue-rotate(180deg); +} diff --git a/docs/source/_build/js/mathjax.js b/docs/source/_build/js/mathjax.js new file mode 100644 index 000000000000..5b34852b5eee --- /dev/null +++ b/docs/source/_build/js/mathjax.js @@ -0,0 +1,19 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) diff --git a/docs/source/_build/overrides/404.html b/docs/source/_build/overrides/404.html new file mode 100644 index 000000000000..986222099a22 --- /dev/null +++ b/docs/source/_build/overrides/404.html @@ -0,0 +1,222 @@ +{% extends "base.html" %} +{% block content %} +

+ +{% endblock %} diff --git a/docs/source/_build/scripts/macro.py b/docs/source/_build/scripts/macro.py new file mode 100644 index 000000000000..651786b0044b --- /dev/null +++ b/docs/source/_build/scripts/macro.py @@ -0,0 +1,189 @@ +from collections import OrderedDict +import os +from typing import Any, List, Optional, Set +import yaml +import logging + + +from mkdocs_macros.plugin import MacrosPlugin + +# Supported Languages and their metadata +LANGUAGES = OrderedDict( + python={ + "extension": ".py", + "display_name": "Python", + "icon_name": "python", + "code_name": "python", + }, + rust={ + "extension": ".rs", + "display_name": "Rust", + "icon_name": "rust", + "code_name": "rust", + }, +) + +# Load all links to reference docs +with open("docs/source/_build/API_REFERENCE_LINKS.yml", "r") as f: + API_REFERENCE_LINKS = yaml.load(f, Loader=yaml.CLoader) + + +def create_feature_flag_link(feature_name: str) -> str: + """Create a feature flag warning telling the user to activate a certain feature before running the code + + Args: + feature_name (str): name of the feature + + Returns: + str: Markdown formatted string with a link and the feature flag message + """ + return f'[:material-flag-plus: Available on feature {feature_name}](/user-guide/installation/#feature-flags "To use this functionality enable the feature flag {feature_name}"){{.feature-flag}}' + + +def create_feature_flag_links(language: str, api_functions: List[str]) -> List[str]: + """Generate markdown feature flags for the code tabs based on the api_functions. + It checks for the key feature_flag in the configuration yaml for the function and if it exists print out markdown + + Args: + language (str): programming languages + api_functions (List[str]): Api functions that are called + + Returns: + List[str]: Per unique feature flag a markdown formatted string for the feature flag + """ + api_functions_info = [ + info + for f in api_functions + if (info := API_REFERENCE_LINKS.get(language).get(f)) + ] + feature_flags: Set[str] = { + flag + for info in api_functions_info + if type(info) == dict and info.get("feature_flags") + for flag in info.get("feature_flags") + } + + return [create_feature_flag_link(flag) for flag in feature_flags] + + +def create_api_function_link(language: str, function_key: str) -> Optional[str]: + """Create an API link in markdown with an icon of the YAML file + + Args: + language (str): programming language + function_key (str): Key to the specific function + + Returns: + str: If the function is found than the link else None + """ + info = API_REFERENCE_LINKS.get(language, {}).get(function_key) + + if info is None: + logging.warning(f"Could not find {function_key} for language {language}") + return None + else: + # Either be a direct link + if type(info) == str: + return f"[:material-api: `{function_key}`]({info})" + else: + function_name = info["name"] + link = info["link"] + return f"[:material-api: `{function_name}`]({link})" + + +def code_tab( + base_path: str, + section: Optional[str], + language_info: dict, + api_functions: List[str], +) -> str: + """Generate a single tab for the code block corresponding to a specific language. + It gets the code at base_path and possible section and pretty prints markdown for it + + Args: + base_path (str): path where the code is located + section (str, optional): section in the code that should be displayed + language_info (dict): Language specific information (icon name, display name, ...) + api_functions (List[str]): List of api functions which should be linked + + Returns: + str: A markdown formatted string represented a single tab + """ + language = language_info["code_name"] + + # Create feature flags + feature_flags_links = create_feature_flag_links(language, api_functions) + + # Create API Links if they are defined in the YAML + api_functions = [ + link for f in api_functions if (link := create_api_function_link(language, f)) + ] + language_headers = " ·".join(api_functions + feature_flags_links) + + # Create path for Snippets extension + snippets_file_name = f"{base_path}:{section}" if section else f"{base_path}" + + # See Content Tabs for details https://squidfunk.github.io/mkdocs-material/reference/content-tabs/ + return f"""=== \":fontawesome-brands-{language_info['icon_name']}: {language_info['display_name']}\" + {language_headers} + ```{language} + --8<-- \"{snippets_file_name}\" + ``` + """ + + +def define_env(env: MacrosPlugin) -> None: + @env.macro + def code_header( + language: str, section: str = [], api_functions: List[str] = [] + ) -> str: + language_info = LANGUAGES[language] + + language = language_info["code_name"] + + # Create feature flags + feature_flags_links = create_feature_flag_links(language, api_functions) + + # Create API Links if they are defined in the YAML + api_functions = [ + link + for f in api_functions + if (link := create_api_function_link(language, f)) + ] + language_headers = " ·".join(api_functions + feature_flags_links) + return f"""=== \":fontawesome-brands-{language_info['icon_name']}: {language_info['display_name']}\" + {language_headers}""" + + @env.macro + def code_block( + path: str, + section: str = None, + api_functions: List[str] = None, + python_api_functions: List[str] = None, + rust_api_functions: List[str] = None, + ) -> str: + """Dynamically generate a code block for the code located under {language}/path + + Args: + path (str): base_path for each language + section (str, optional): Optional segment within the code file. Defaults to None. + api_functions (List[str], optional): API functions that should be linked. Defaults to None. + Returns: + str: Markdown tabbed code block with possible links to api functions and feature flags + """ + result = [] + + for language, info in LANGUAGES.items(): + base_path = f"{language}/{path}{info['extension']}" + full_path = "docs/source/src/" + base_path + if language == "python": + extras = python_api_functions or [] + else: + extras = rust_api_functions or [] + # Check if file exists for the language + if os.path.exists(full_path): + result.append( + code_tab(base_path, section, info, api_functions + extras) + ) + + return "\n".join(result) diff --git a/docs/source/_build/scripts/people.py b/docs/source/_build/scripts/people.py new file mode 100644 index 000000000000..fd94874fa178 --- /dev/null +++ b/docs/source/_build/scripts/people.py @@ -0,0 +1,41 @@ +import itertools +from github import Github, Auth +import os + +token = os.getenv("GITHUB_TOKEN") +auth = Auth.Token(token) if token else None +g = Github(auth=auth) + +ICON_TEMPLATE = '{login}' + + +def get_people_md(): + repo = g.get_repo("pola-rs/polars") + contributors = repo.get_contributors() + with open("./docs/assets/people.md", "w") as f: + for c in itertools.islice(contributors, 50): + # We love dependabot, but he doesn't need a spot on our website + if c.login == "dependabot[bot]": + continue + + f.write( + ICON_TEMPLATE.format( + login=c.login, + avatar_url=c.avatar_url, + html_url=c.html_url, + ) + + "\n" + ) + + +def on_startup(command, dirty): + """Mkdocs hook to autogenerate docs/assets/people.md on startup""" + try: + get_people_md() + except Exception as e: + msg = f"WARNING:{__file__}: Could not generate docs/assets/people.md. Got error: {str(e)}" + print(msg) + + +if __name__ == "__main__": + get_people_md() diff --git a/docs/source/_build/snippets/under_construction.md b/docs/source/_build/snippets/under_construction.md new file mode 100644 index 000000000000..00b0cc4af922 --- /dev/null +++ b/docs/source/_build/snippets/under_construction.md @@ -0,0 +1,4 @@ +!!! warning ":construction: Under Construction :construction:" + + This section is still under development. Want to help out? Consider contributing and making a [pull request](https://github.com/pola-rs/polars) to our repository. + Please read our [contributing guide](https://docs.pola.rs/development/contributing/) on how to proceed. diff --git a/docs/source/api/reference.md b/docs/source/api/reference.md new file mode 100644 index 000000000000..199401411f8c --- /dev/null +++ b/docs/source/api/reference.md @@ -0,0 +1,14 @@ +# Reference guide + +The API reference contains detailed descriptions of all public functions and objects. It's the best +place to look if you need information on a specific function. + +## Python + +The Python API reference is built using Sphinx. It's available in +[our docs](https://docs.pola.rs/api/python/stable/reference/index.html). + +## Rust + +The Rust API reference is built using Cargo. It's available on +[docs.rs](https://docs.rs/polars/latest/polars/). diff --git a/docs/source/development/contributing/ci.md b/docs/source/development/contributing/ci.md new file mode 100644 index 000000000000..8488417826a1 --- /dev/null +++ b/docs/source/development/contributing/ci.md @@ -0,0 +1,63 @@ +# Continuous integration + +Polars uses GitHub Actions as its continuous integration (CI) tool. The setup is reasonably complex, +as far as CI setups go. This page explains some of the design choices. + +## Goal + +Overall, the CI suite aims to achieve the following: + +- Enforce code correctness by running automated tests. +- Enforce code quality by running automated linting checks. +- Enforce code performance by running benchmark tests. +- Enforce that code is properly documented. +- Allow maintainers to easily publish new releases. + +We rely on a wide range of tools to achieve this for both the Rust and the Python code base, and +thus a lot of checks are triggered on each pull request. + +It's entirely possible that you submit a relatively trivial fix that subsequently fails a bunch of +checks. Do not despair - check the logs to see what went wrong and try to fix it. You can run the +failing command locally to verify that everything works correctly. If you can't figure it out, ask a +maintainer for help! + +## Design + +The CI setup is designed with the following requirements in mind: + +- Get feedback on each step individually. We want to avoid our test job being cancelled because a + linting check failed, only to find out later that we also have a failing test. +- Get feedback on each check as quickly as possible. We want to be able to iterate quickly if it + turns out our code does not pass some of the checks. +- Only run checks when they need to be run. A change to the Rust code does not warrant a linting + check of the Python code, for example. + +This results in a modular setup with many separate workflows and jobs that rely heavily on caching. + +### Modular setup + +The repository consists of two main parts: the Rust code base and the Python code base. Both code +bases are interdependent: Rust code is tested through Python tests, and the Python code relies on +the Rust implementation for most functionality. + +To make sure CI jobs are only run when they need to be run, each workflow is triggered only when +relevant files are modified. + +### Caching + +The main challenge is that the Rust code base for Polars is quite large, and consequently, compiling +the project from scratch is slow. This is addressed by caching the Rust build artifacts. + +However, since GitHub Actions does not allow sharing caches between feature branches, we need to run +the workflows on the main branch as well - at least the part that builds the Rust cache. This leads +to many workflows that trigger both on pull request AND on push to the main branch, with individual +steps of jobs enabled or disabled based on the branch it runs on. + +Care must also be taken not to exceed the maximum cache space of 10Gb allotted to open source GitHub +repositories. Hence we do not do any caching on feature branches - we always use the cache available +from the main branch. This also avoids any extra time that would be required to store the cache. + +## Releases + +The release jobs for Rust and Python are triggered manually. Refer to the +[contributing guide](./index.md#release-flow) for the full release process. diff --git a/docs/source/development/contributing/code-style.md b/docs/source/development/contributing/code-style.md new file mode 100644 index 000000000000..c22b9f3ed7ac --- /dev/null +++ b/docs/source/development/contributing/code-style.md @@ -0,0 +1,94 @@ +# Code style + +This page contains some guidance on code style. + +!!! info + + Additional information will be added to this page later. + +## Rust + +### Naming conventions + +Naming conventions for variables: + +```rust +let s: Series = ... +let ca: ChunkedArray = ... +let arr: ArrayRef = ... +let arr: PrimitiveArray = ... +let dtype: DataType = ... +let dtype: ArrowDataType = ... +``` + +### Code example + +```rust +use std::ops::Add; + +use polars::export::arrow::array::*; +use polars::export::arrow::compute::arity::binary; +use polars::export::arrow::types::NativeType; +use polars::prelude::*; +use polars_core::utils::{align_chunks_binary, combine_validities_and}; +use polars_core::with_match_physical_numeric_polars_type; + +// Prefer to do the compute closest to the arrow arrays. +// this will tend to be faster as iterators can work directly on slices and don't have +// to go through boxed traits +fn compute_kernel(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> PrimitiveArray +where + T: Add + NativeType, +{ + // process the null data separately + // this saves an expensive branch and bitoperation when iterating + let validity_1 = arr_1.validity(); + let validity_2 = arr_2.validity(); + + let validity = combine_validities_and(validity_1, validity_2); + + // process the numerical data as if there were no validities + let values_1: &[T] = arr_1.values().as_slice(); + let values_2: &[T] = arr_2.values().as_slice(); + + let values = values_1 + .iter() + .zip(values_2) + .map(|(a, b)| *a + *b) + .collect::>(); + + PrimitiveArray::from_data_default(values.into(), validity) +} + +// Same kernel as above, but uses the `binary` abstraction. Prefer this, +#[allow(dead_code)] +fn compute_kernel2(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> PrimitiveArray +where + T: Add + NativeType, +{ + binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| a + b) +} + +fn compute_chunked_array_2_args( + ca_1: &ChunkedArray, + ca_2: &ChunkedArray, +) -> ChunkedArray { + // This ensures both ChunkedArrays have the same number of chunks with the + // same offset and the same length. + let (ca_1, ca_2) = align_chunks_binary(ca_1, ca_2); + let chunks = ca_1 + .downcast_iter() + .zip(ca_2.downcast_iter()) + .map(|(arr_1, arr_2)| compute_kernel(arr_1, arr_2)); + ChunkedArray::from_chunk_iter(ca_1.name(), chunks) +} + +pub fn compute_expr_2_args(arg_1: &Series, arg_2: &Series) -> Series { + // Dispatch the numerical series to `compute_chunked_array_2_args`. + with_match_physical_numeric_polars_type!(arg_1.dtype(), |$T| { + let ca_1: &ChunkedArray<$T> = arg_1.as_ref().as_ref().as_ref(); + let ca_2: &ChunkedArray<$T> = arg_2.as_ref().as_ref().as_ref(); + compute_chunked_array_2_args(ca_1, ca_2).into_series() + }) +} +``` diff --git a/docs/source/development/contributing/ide.md b/docs/source/development/contributing/ide.md new file mode 100644 index 000000000000..4e500001e54d --- /dev/null +++ b/docs/source/development/contributing/ide.md @@ -0,0 +1,139 @@ +# IDE configuration + +Using an integrated development environments (IDE) and configuring it properly will help you work on +Polars more effectively. This page contains some recommendations for configuring popular IDEs. + +## Visual Studio Code + +Make sure to configure VSCode to use the virtual environment created by the Makefile. + +### Extensions + +The extensions below are recommended. + +#### rust-analyzer + +If you work on the Rust code at all, you will need the +[rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer) +extension. This extension provides code completion for the Rust code. + +For it to work well for the Polars code base, add the following settings to your +`.vscode/settings.json`: + +```json +{ + "rust-analyzer.cargo.features": "all", + "rust-analyzer.cargo.targetDir": true +} +``` + +#### Ruff + +The [Ruff](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) extension will +help you conform to the formatting requirements of the Python code. We use both the Ruff linter and +formatter. It is recommended to configure the extension to use the Ruff installed in your +environment. This will make it use the correct Ruff version and configuration. + +```json +{ + "ruff.importStrategy": "fromEnvironment" +} +``` + +#### CodeLLDB + +The [CodeLLDB](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb) extension is +useful for debugging Rust code. You can also debug Rust code called from Python (see section below). + +### Debugging + +Due to the way that Python and Rust interoperate, debugging the Rust side of development from Python +calls can be difficult. This guide shows how to set up a debugging environment that makes debugging +Rust code called from a Python script painless. + +#### Preparation + +Start by installing the CodeLLDB extension (see above). Then add the following two configurations to +your `launch.json` file. This file is usually found in the `.vscode` folder of your project root. +See the +[official VSCode documentation](https://code.visualstudio.com/docs/editor/debugging#_launch-configurations) +for more information about the `launch.json` file. + +
launch.json + +```json +{ + "configurations": [ + { + "name": "Debug Rust/Python", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/py-polars/debug/launch.py", + "args": [ + "${file}" + ], + "console": "internalConsole", + "justMyCode": true, + "serverReadyAction": { + "pattern": "pID = ([0-9]+)", + "action": "startDebugging", + "name": "Rust LLDB" + } + }, + { + "name": "Rust LLDB", + "pid": "0", + "type": "lldb", + "request": "attach", + "program": "${workspaceFolder}/py-polars/.venv/bin/python", + "stopOnEntry": false, + "sourceLanguages": [ + "rust" + ], + "presentation": { + "hidden": true + } + } + ] +} +``` + +
+ +!!! info + + On some systems, the LLDB debugger will not attach unless [ptrace protection](https://linux-audit.com/protect-ptrace-processes-kernel-yama-ptrace_scope) is disabled. + To disable, run the following command: + + ```shell + echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope + ``` + +#### Running the debugger + +1. Create a Python script containing Polars code. Ensure that your virtual environment is activated. + +2. Set breakpoints in any `.rs` or `.py` file. + +3. In the `Run and Debug` panel on the left, select `Debug Rust/Python` from the drop-down menu on + top and click the `Start Debugging` button. + +At this point, your debugger should stop on breakpoints in any `.rs` file located within the +codebase. + +#### Details + +The debugging feature runs via the specially-designed VSCode launch configuration shown above. The +initial Python debugger is launched using a special launch script located at +`py-polars/debug/launch.py` and passes the name of the script to be debugged (the target script) as +an input argument. The launch script determines the process ID, writes this value into the +`launch.json` configuration file, compiles the target script and runs it in the current environment. +At this point, a second (Rust) debugger is attached to the Python debugger. The result is two +simultaneous debuggers operating on the same running instance. Breakpoints in the Python code will +stop on the Python debugger and breakpoints in the Rust code will stop on the Rust debugger. + +## JetBrains (PyCharm, RustRover, CLion) + +!!! info + + More information needed. diff --git a/docs/source/development/contributing/index.md b/docs/source/development/contributing/index.md new file mode 100644 index 000000000000..97765dd8ddc0 --- /dev/null +++ b/docs/source/development/contributing/index.md @@ -0,0 +1,430 @@ +--- +render_macros: false +--- + +# Overview + +Thanks for taking the time to contribute! We appreciate all contributions, from reporting bugs to +implementing new features. If you're unclear on how to proceed after reading this guide, please +contact us on [Discord](https://discord.gg/4UfP5cfBE7). + +## Reporting bugs + +We use [GitHub issues](https://github.com/pola-rs/polars/issues) to track bugs and suggested +enhancements. You can report a bug by opening a +[new issue](https://github.com/pola-rs/polars/issues/new/choose). Use the appropriate issue type for +the language you are using +([Rust](https://github.com/pola-rs/polars/issues/new?labels=bug&template=bug_report_rust.yml) / +[Python](https://github.com/pola-rs/polars/issues/new?labels=bug&template=bug_report_python.yml)). + +Before creating a bug report, please check that your bug has not already been reported, and that +your bug exists on the latest version of Polars. If you find a closed issue that seems to report the +same bug you're experiencing, open a new issue and include a link to the original issue in your +issue description. + +Please include as many details as possible in your bug report. The information helps the maintainers +resolve the issue faster. + +## Suggesting enhancements + +We use [GitHub issues](https://github.com/pola-rs/polars/issues) to track bugs and suggested +enhancements. You can suggest an enhancement by opening a +[new feature request](https://github.com/pola-rs/polars/issues/new?labels=enhancement&template=feature_request.yml). +Before creating an enhancement suggestion, please check that a similar issue does not already exist. + +Please describe the behavior you want and why, and provide examples of how Polars would be used if +your feature were added. + +## Contributing to the codebase + +### Picking an issue + +Pick an issue by going through the [issue tracker](https://github.com/pola-rs/polars/issues) and +finding an issue you would like to work on. Feel free to pick any issue with an +[accepted](https://github.com/pola-rs/polars/issues?q=is%3Aopen+is%3Aissue+label%3Aaccepted) label +that is not already assigned. We use the +[help wanted](https://github.com/pola-rs/polars/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) +label to indicate issues that are high on our wishlist. + +If you are a first time contributor, you might want to look for issues labeled +[good first issue](https://github.com/pola-rs/polars/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). +The Polars code base is quite complex, so starting with a small issue will help you find your way +around! + +If you would like to take on an issue, please comment on the issue to let others know. You may use +the issue to discuss possible solutions. + +### Setting up your local environment + +The Polars development flow relies on both Rust and Python, which means setting up your local +development environment is not trivial. If you run into problems, please contact us on +[Discord](https://discord.gg/4UfP5cfBE7). + +!!! note + + If you are a Windows user, the steps below might not work as expected. + Try developing using [WSL](https://learn.microsoft.com/en-us/windows/wsl/install). + Under native Windows, you may have to manually copy the contents of `toolchain.toml` to `py-polars/toolchain.toml`, as Git for Windows may not correctly handle symbolic links. + +#### Configuring Git + +For contributing to Polars you need a free [GitHub account](https://github.com) and have +[git](https://git-scm.com) installed on your machine. Start by +[forking](https://docs.github.com/en/get-started/quickstart/fork-a-repo) the Polars repository, then +clone your forked repository using `git`: + +```bash +git clone https://github.com//polars.git +cd polars +``` + +Optionally set the `upstream` remote to be able to sync your fork with the Polars repository in the +future: + +```bash +git remote add upstream https://github.com/pola-rs/polars.git +git fetch upstream +``` + +#### Installing dependencies + +In order to work on Polars effectively, you will need [Rust](https://www.rust-lang.org/), +[Python](https://www.python.org/), and [dprint](https://dprint.dev/). + +First, install Rust using [rustup](https://www.rust-lang.org/tools/install). After the initial +installation, you will also need to install the nightly toolchain: + +```bash +rustup toolchain install nightly --component miri +``` + +Next, install Python, for example using [pyenv](https://github.com/pyenv/pyenv#installation). We +recommend using the latest Python version (`3.13`). Make sure you deactivate any active virtual +environments (command: `deactivate`) or conda environments (command: `conda deactivate`), as the +steps below will create a new [virtual environment](https://docs.python.org/3/tutorial/venv.html) +for Polars. You will need Python even if you intend to work on the Rust code only, as we rely on the +Python tests to verify all functionality. + +Finally, install [dprint](https://dprint.dev/install/). This is not strictly required, but it is +recommended as we use it to autoformat certain file types. + +You can now check that everything works correctly by going into the `py-polars` directory and +running the test suite (warning: this may be slow the first time you run it): + +```bash +cd py-polars +make test +``` + +!!! note + + You need to have [CMake](https://cmake.org/) installed for `make test` to work. + +This will do a number of things: + +- Use Python to create a virtual environment in the `.venv` folder. +- Use [pip](https://pip.pypa.io/) and [uv](https://github.com/astral-sh/uv) to install all Python + dependencies for development, linting, and building documentation. +- Use Rust to compile and install Polars in your virtual environment. _At least 8GB of RAM is + recommended for this step to run smoothly._ +- Use [pytest](https://docs.pytest.org/) to run the Python unittests in your virtual environment + +!!! note + + There are a small number of specialized dependencies that are not installed by default. + If you are running specific tests and encounter an error message about a missing dependency, + try running `make requirements-all` to install _all_ known dependencies). + +Check if linting also works correctly by running: + +```bash +make pre-commit +``` + +Note that we do not actually use the [pre-commit](https://pre-commit.com/) tool. We use the Makefile +to conveniently run the following formatting and linting tools: + +- [ruff](https://github.com/charliermarsh/ruff) +- [mypy](http://mypy-lang.org/) +- [rustfmt](https://github.com/rust-lang/rustfmt) +- [clippy](https://doc.rust-lang.org/nightly/clippy/index.html) +- [dprint](https://dprint.dev/) + +If this all runs correctly, you're ready to start contributing to the Polars codebase! + +#### Updating the development environment + +Dependencies are updated regularly - at least once per month. If you do not keep your environment +up-to-date, you may notice tests or CI checks failing, or you may not be able to build Polars at +all. + +To update your environment, first make sure your fork is in sync with the Polars repository: + +```bash +git checkout main +git fetch upstream +git rebase upstream/main +git push origin main +``` + +Update all Python dependencies to their latest versions by running: + +```bash +make requirements +``` + +If the Rust toolchain version has been updated, you should update your Rust toolchain. Follow it up +by running `cargo clean` to make sure your Cargo folder does not grow too large: + +```bash +rustup update +cargo clean +``` + +### Working on your issue + +Create a new git branch from the `main` branch in your local repository, and start coding! + +The Rust code is located in the `crates` directory, while the Python codebase is located in the +`py-polars` directory. Both directories contain a `Makefile` with helpful commands. Most notably: + +- `make test` to run the test suite (see the [test suite docs](./test.md) for more info) +- `make pre-commit` to run autoformatting and linting + +Note that your work cannot be merged if these checks fail! Run `make help` to get a list of other +helpful commands. + +Two other things to keep in mind: + +- If you add code that should be tested, add tests. +- If you change the public API, [update the documentation](#api-reference). + +### Pull requests + +When you have resolved your issue, +[open a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) +in the Polars repository. Please adhere to the following guidelines: + + +- Title: + - Start your pull request title with a [conventional commit](https://www.conventionalcommits.org/) tag. + This helps us add your contribution to the right section of the changelog. + We use the [Angular convention](https://github.com/angular/angular/blob/22b96b9/CONTRIBUTING.md#type). + Scope can be `rust` and/or `python`, depending on your contribution: this tag determines which changelog(s) will include your change. + Omit the scope if your change affects both Rust and Python. + - Use a descriptive title starting with an uppercase letter. + This text will end up in the [changelog](https://github.com/pola-rs/polars/releases), so make sure the text is meaningful to the user. + Use single backticks to annotate code snippets. + Use active language and do not end your title with punctuation. + - Example: ``fix(python): Fix `DataFrame.top_k` not handling nulls correctly`` +- Description: + - In the pull request description, [link](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) to the issue you were working on. + - Add any relevant information to the description that you think may help the maintainers review your code. +- Make sure your branch is [rebased](https://docs.github.com/en/get-started/using-git/about-git-rebase) against the latest version of the `main` branch. +- Make sure all [GitHub Actions checks](./ci.md) pass. + + +After you have opened your pull request, a maintainer will review it and possibly leave some +comments. Once all issues are resolved, the maintainer will merge your pull request, and your work +will be part of the next Polars release! + +Keep in mind that your work does not have to be perfect right away! If you are stuck or unsure about +your solution, feel free to open a draft pull request and ask for help. + +## Contributing to documentation + +The most important components of Polars documentation are the +[user guide](https://docs.pola.rs/user-guide/), the +[API references](https://docs.pola.rs/api/python/stable/reference/index.html), and the database of +questions on [StackOverflow](https://stackoverflow.com/). + +### User guide + +The user guide is maintained in the `docs/source/user-guide` folder. Before creating a PR first +raise an issue to discuss what you feel is missing or could be improved. + +#### Building and serving the user guide + +The user guide is built using [MkDocs](https://www.mkdocs.org/). You install the dependencies for +building the user guide by running `make build` in the root of the repo. Additionally, you need to +make sure the [graphviz](https://graphviz.org/) `dot` binary is on your path. + +Activate the virtual environment and run `mkdocs serve` to build and serve the user guide, so you +can view it locally and see updates as you make changes. + +#### Creating a new user guide page + +Each user guide page is based on a `.md` markdown file. This file must be listed in `mkdocs.yml`. + +#### Adding a shell code block + +To add a code block with code to be run in a shell with tabs for Python and Rust, use the following +format: + +```` +=== ":fontawesome-brands-python: Python" + + ```shell + $ pip install fsspec + ``` + +=== ":fontawesome-brands-rust: Rust" + + ```shell + $ cargo add aws_sdk_s3 + ``` +```` + +#### Adding a code block + +The snippets for Python and Rust code blocks are in the `docs/source/src/python/` and +`docs/source/src/rust/` directories, respectively. To add a code snippet with Python or Rust code to +a `.md` page, use the following format: + +``` +{{code_block('user-guide/io/cloud-storage','read_parquet',['read_parquet','read_csv'])}} +``` + +- The first argument is a path to either or both files called + `docs/source/src/python/user-guide/io/cloud-storage.py` and + `docs/source/src/rust/user-guide/io/cloud-storage.rs`. +- The second argument is the name given at the start and end of each snippet in the `.py` or `.rs` + file +- The third argument is a list of links to functions in the API docs. For each element of the list + there must be a corresponding entry in `docs/source/_build/API_REFERENCE_LINKS.yml` + +If the corresponding `.py` and `.rs` snippet files both exist then each snippet named in the second +argument to `code_block` above must exist or the build will fail. An empty snippet should be added +to the `.py` or `.rs` file if the snippet is not needed. + +Each snippet is formatted as follows: + +```python +# --8<-- [start:read_parquet] +import polars as pl + +df = pl.read_parquet("file.parquet") +# --8<-- [end:read_parquet] +``` + +The snippet is delimited by `--8<-- [start:]` and `--8<-- [end:]`. The +snippet name must match the name given in the second argument to `code_block` above. + +In some cases, you may need to add links to different functions for the Python and Rust APIs. When +that is the case, you can use the two extra optional arguments that `code_block` accepts, that can +be used to pass Python-only and Rust-only links: + +``` +{{code_block('path', 'snippet_name', ['common_api_links'], ['python_only_links'], ['rust_only_links'])}} +``` + +#### Linting + +Before committing, install `dprint` (see above) and run `dprint fmt` from the `docs` directory to +lint the markdown files. + +### API reference + +Polars has separate API references for [Rust](https://docs.pola.rs/api/rust/dev/polars/) and +[Python](https://docs.pola.rs/api/python/dev/reference/index.html). These are generated directly +from the codebase, so in order to contribute, you will have to follow the steps outlined in +[this section](#contributing-to-the-codebase) above. + +#### Rust + +Rust Polars uses `cargo doc` to build its documentation. Contributions to improve or clarify the API +reference are welcome. + +#### Python + +For the Python API reference, we always welcome good docstring examples. There are still parts of +the API that do not have any code examples. This is a great way to start contributing to Polars! + +Note that we follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) +convention. Docstring examples should also follow the [Black](https://black.readthedocs.io/) +codestyle. From the `py-polars` directory, run `make fmt` to make sure your additions pass the +linter, and run `make doctest` to make sure your docstring examples are valid. + +Polars uses Sphinx to build the API reference. This means docstrings in general should follow the +[reST](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) format. If you want +to build the API reference locally, go to the `py-polars/docs` directory and run `make html`. The +resulting HTML files will be in `py-polars/docs/build/html`. + +New additions to the API should be added manually to the API reference by adding an entry to the +correct `.rst` file in the `py-polars/docs/source/reference` directory. + +### StackOverflow + +We use StackOverflow to create a database of high quality questions and answers that is searchable +and remains up-to-date. There is a separate tag for each language: + +- [Python Polars](https://stackoverflow.com/questions/tagged/python-polars) +- [Rust Polars](https://stackoverflow.com/questions/tagged/rust-polars) + +Contributions in the form of well-formulated questions or answers are always welcome! If you add a +new question, please notify us by adding a +[matching issue](https://github.com/pola-rs/polars/issues/new?&labels=question&template=question.yml) +to our GitHub issue tracker. + +## Release flow + +_This section is intended for Polars maintainers._ + +Polars releases Rust crates to [crates.io](https://crates.io/crates/polars) and Python packages to +[PyPI](https://pypi.org/project/polars/). + +New releases are marked by an official [GitHub release](https://github.com/pola-rs/polars/releases) +and an associated git tag. We utilize +[Release Drafter](https://github.com/release-drafter/release-drafter) to automatically draft GitHub +releases with release notes. + +### Steps + +The steps for releasing a new Rust or Python version are similar. The release process is mostly +automated through GitHub Actions, but some manual steps are required. Follow the steps below to +release a new version. + +Start by bumping the version number in the source code: + +1. Check the [releases page](https://github.com/pola-rs/polars/releases) on GitHub and find the + appropriate draft release. Note the version number associated with this release. +2. Make sure your fork is up-to-date with the latest version of the main Polars repository, and + create a new branch. +3. Bump the version number. + +- _Rust:_ Update the version number in all `Cargo.toml` files in the `polars` directory and + subdirectories. You'll probably want to use some search/replace strategy, as there are quite a few + crates that need to be updated. +- _Python:_ Update the version number in + [`py-polars/Cargo.toml`](https://github.com/pola-rs/polars/blob/main/py-polars/Cargo.toml#L3) to + match the version of the draft release. + +4. From the `py-polars` directory, run `make build` to generate a new `Cargo.lock` file. +5. Create a new commit with all files added. The name of the commit should follow the format + `release(): Polars `. For example: + `release(python): Python Polars 0.16.1` +6. Push your branch and open a new pull request to the `main` branch of the main Polars repository. +7. Wait for the GitHub Actions checks to pass, then squash and merge your pull request. + +Directly after merging your pull request, release the new version: + +8. Go to the release workflow + ([Python](https://github.com/pola-rs/polars/actions/workflows/release-python.yml)/[Rust](https://github.com/pola-rs/polars/actions/workflows/release-rust.yml)), + click _Run workflow_ in the top right, and click the green button. This will trigger the + workflow, which will build all release artifacts and publish them. +9. Wait for the workflow to finish, then check + [crates.io](https://crates.io/crates/polars)/[PyPI](https://pypi.org/project/polars/)/[GitHub](https://github.com/pola-rs/polars/releases) + to verify that the new Polars release is now available. + +### Troubleshooting + +It may happen that one or multiple release jobs fail. If so, you should first try to simply re-run +the failed jobs from the GitHub Actions UI. + +If that doesn't help, you will have to figure out what's wrong and commit a fix. Once your fix has +made it to the `main` branch, simply re-trigger the release workflow. + +## License + +Any contributions you make to this project will fall under the +[MIT License](https://github.com/pola-rs/polars/blob/main/LICENSE) that covers the Polars project. diff --git a/docs/source/development/contributing/test.md b/docs/source/development/contributing/test.md new file mode 100644 index 000000000000..b219bcb085fe --- /dev/null +++ b/docs/source/development/contributing/test.md @@ -0,0 +1,141 @@ +# Test suite + +!!! info + + Additional information on the Rust test suite will be added to this page later. + +The `py-polars/tests` folder contains the main Polars test suite. This page contains some +information on the various components of the test suite, as well as guidelines for writing new +tests. + +The test suite contains four main components, each confined to their own folder: unit tests, +parametric tests, benchmark tests, and doctests. + +Note that this test suite is indirectly responsible for testing Rust Polars as well. The Rust test +suite is kept small to reduce compilation times. A lot of the Rust functionality is tested here +instead. + +## Unit tests + +The `unit` folder contains all regular unit tests. These tests are intended to make sure all Polars +functionality works as intended. + +### Running unit tests + +Run unit tests by running `make test` from the `py-polars` folder. This will compile the Rust +bindings and then run the unit tests. + +If you're working in the Python code only, you can avoid recompiling every time by simply running +`pytest` instead from your virtual environment. + +By default, "slow" tests and "ci-only" tests are skipped for local test runs. Such tests are marked +using a [custom pytest marker](https://docs.pytest.org/en/latest/example/markers.html). To run these +tests specifically, you can run `pytest -m slow`, `pytest -m ci_only`, `pytest -m slow ci_only` or +run `pytest -m ""` to run _all_ tests, regardless of marker. + +Note that the "ci-only" tests may require you to run `make requirements-all` to get additional +dependencies (such as `torch`) that are otherwise not installed as part of the default Polars +development environment. + +Tests can be run in parallel by running `pytest -n auto`. The parallelization is handled by +[`pytest-xdist`](https://pytest-xdist.readthedocs.io/en/latest/). + +### Writing unit tests + +Whenever you add new functionality, you should also add matching unit tests. Add your tests to +appropriate test module in the `unit` folder. Some guidelines to keep in mind: + +- Try to fully cover all possible inputs and edge cases you can think of. +- Utilize pytest tools like [`fixture`](https://docs.pytest.org/en/latest/explanation/fixtures.html) + and [`parametrize`](https://docs.pytest.org/en/latest/how-to/parametrize.html) where appropriate. +- Since many tests will require some data to be defined first, it can be efficient to run multiple + checks in a single test. This can also be addressed using pytest fixtures. +- Unit tests should not depend on external factors, otherwise test parallelization will break. + +## Parametric tests + +The `parametric` folder contains parametric tests written using the +[Hypothesis](https://hypothesis.readthedocs.io/) framework. These tests are intended to find and +test edge cases by generating many random datapoints. + +### Running parametric tests + +Run parametric tests by running `pytest -m hypothesis`. + +Note that parametric tests are excluded by default when running `pytest`. You must explicitly +specify `-m hypothesis` to run them. + +These tests _will_ be included when calculating test coverage, and will also be run as part of the +`make test-all` make command. + +## Doctests + +The `docs` folder contains a script for running +[`doctest`](https://docs.python.org/3/library/doctest.html). This folder does not contain any actual +tests - rather, the script checks all docstrings in the Polars package for `Examples` sections, runs +the code examples, and verifies the output. + +The aim of running `doctest` is to make sure the `Examples` sections in our docstrings are valid and +remain up-to-date with code changes. + +### Running `doctest` + +To run the `doctest` module, run `make doctest` from the `py-polars` folder. You can also run the +script directly from your virtual environment. + +Note that doctests are _not_ run using pytest. While pytest does have the capability to run doc +examples, configuration options are too limited for our purposes. + +Doctests will _not_ count towards test coverage. They are not a substitute for unit tests, but +rather intended to convey the intended use of the Polars API to the user. + +### Writing doc examples + +Almost all classes/methods/functions that are part of Polars' public API should include code +examples in their docstring. These examples help users understand basic usage and allow us to +illustrate more advanced concepts as well. Some guidelines for writing a good docstring `Examples` +section: + +- Start with a minimal example that showcases the default functionality. +- Showcase the effect of its parameters. +- Showcase any special interactions when combined with other code. +- Keep it succinct and avoid multiple examples showcasing the same thing. + +There are many great docstring examples already, just check other code if you need inspiration! + +In addition to the [regular options](https://docs.python.org/3/library/doctest.html#option-flags) +available when writing doctests, the script configuration allows for a new `IGNORE_RESULT` +directive. Use this directive if you want to ensure the code runs, but the output may be random by +design or not interesting to check. + +```python +>>> df.sample(n=2) # doctest: +IGNORE_RESULT +``` + +## Benchmark tests + +The `benchmark` folder contains code for running various benchmark tests. The aim of this part of +the test suite is to spot performance regressions in the code, and to verify that Polars +functionality works as expected when run on a release build or at a larger scale. + +Polars uses [CodSpeed](https://codspeed.io/pola-rs/polars) for tracking the performance of the +benchmark tests. + +### Generating data + +For most tests, a relatively large dataset must be generated first. This is done as part of the +`pytest` setup process. + +The data generation logic was taken from the +[H2O.ai database benchmark](https://github.com/h2oai/db-benchmark), which is the foundation for many +of the benchmark tests. + +### Running the benchmark tests + +The benchmark tests can be run using pytest. Run `pytest -m benchmark --durations 0 -v` to run these +tests and report run duration. + +Note that benchmark tests are excluded by default when running `pytest`. You must explicitly specify +`-m benchmark` to run them. They will also be excluded when calculating test coverage. + +These tests _will_ be run as part of the `make test-all` make command. diff --git a/docs/source/development/versioning.md b/docs/source/development/versioning.md new file mode 100644 index 000000000000..402c935b28aa --- /dev/null +++ b/docs/source/development/versioning.md @@ -0,0 +1,125 @@ +# Versioning + +## Version changes + +Polars adheres to the [semantic versioning](https://semver.org/) specification: + +- Breaking changes lead to a **major** version increase (`1.0.0`, `2.0.0`, ...) +- New features and performance improvements lead to a **minor** version increase (`1.1.0`, `1.2.0`, + ...) +- Other changes lead to a **patch** version increase (`1.0.1`, `1.0.2`, ...) + +## Policy for breaking changes + +Polars takes backwards compatibility seriously, but we are not afraid to change things if it leads +to a better product. + +### Philosophy + +We don't always get it right on the first try. We learn as we go along and get feedback from our +users. Sometimes, we're a little too eager to get out a new feature and didn't ponder all the +possible implications. + +If this happens, we correct our mistakes and introduce a breaking change. Most of the time, this is +no big deal. Users get a deprecation warning, they do a quick search-and-replace in their code base, +and that's that. + +At times, we run into an issue requires more effort on our user's part to fix. A change in the query +engine can seriously impact the assumptions in a data pipeline. We do not make such changes lightly, +but we will make them if we believe it makes Polars better. + +Freeing ourselves of past indiscretions is important to keep Polars moving forward. We know it takes +time and energy for our users to keep up with new releases but, in the end, it benefits everyone for +Polars to be the best product possible. + +### What qualifies as a breaking change + +**A breaking change occurs when an existing component of the public API is changed or removed.** + +A feature is part of the public API if it is documented in the +[API reference](https://docs.pola.rs/api/python/stable/reference/index.html). + +Examples of breaking changes: + +- A deprecated function or method is removed. +- The default value of a parameter is changed. +- The outcome of a query has changed due to changes to the query engine. + +Examples of changes that are _not_ considered breaking: + +- An undocumented function is removed. +- The module path of a public class is changed. +- An optional parameter is added to an existing method. + +Bug fixes are not considered a breaking change, even though it may impact some users' +[workflows](https://xkcd.com/1172/). + +### Unstable functionality + +Some parts of the public API are marked as **unstable**. You can recognize this functionality from +the warning in the API reference, or from the warning issued when the configuration option +`warn_unstable` is active. There are a number of reasons functionality may be marked as unstable: + +- We are unsure about the exact API. The name, function signature, or implementation are likely to + change in the future. +- The functionality is not tested extensively yet. Bugs may pop up when used in real-world + scenarios. +- The functionality does not yet integrate well with the full Polars API. You may find it works in + one context but not in another. + +Releasing functionality as unstable allows us to gather important feedback from users that use +Polars in real-world scenarios. This helps us fine-tune things before giving it our final stamp of +approval. Users that are only interested in solid, well-tested functionality can avoid this part of +the API. + +Functionality marked as unstable may change at any point without it being considered a breaking +change. + +### Deprecation warnings + +If we decide to introduce a breaking change, the existing behavior is deprecated _if possible_. For +example, if we choose to rename a function, the new function is added alongside the old function, +and using the old function will result in a deprecation warning. + +Not all changes can be deprecated nicely. A change to the query engine may have effects across a +large part of the API. Such changes will not be warned for, but _will_ be included in the changelog +and the migration guide. + +!!! warning Rust users only + + Breaking changes to the Rust API are not deprecated first, but _will_ be listed in the changelog. + Supporting deprecated functionality would slow down development too much at this point in time. + +### Deprecation period + +As a rule, deprecated functionality is removed two breaking releases after the deprecation happens. +For example, a function deprecated in version `1.2.3` will be retained in version `2.0.0` and +removed in version `3.0.0`. + +An exception to this rule are deprecations introduced with a breaking release. These will be +enforced on the next breaking release. For example, a function deprecated in version `2.0.0` will be +removed in version `3.0.0`. + +This means that if your program does not raise any deprecation warnings, it should be mostly safe to +upgrade to the next major version. As breaking releases happen about once every six months, this +allows six to twelve months to adjust to any pending breaking changes. + +**In some cases, we may decide to adjust the deprecation period.** If retaining the deprecated +functionality blocks other improvements to Polars, we may shorten the deprecation period to a single +breaking release. This will be mentioned in the warning message. If the deprecation affects many +users, we may extend the deprecation period. + +## Release frequency + +Polars does not have a set release schedule. We issue a new release whenever we feel like we have +something new and valuable to offer to our users. In practice, a new minor version is released about +once every one or two weeks. + +### Breaking releases + +Over time, issues pop up that require a breaking change to address. When enough issues have +accumulated, we issue a breaking release. + +So far, breaking releases have happened about once every three to six months. The rate and severity +of breaking changes will continue to diminish as Polars grows more solid. From this point on, we +expect new major versions to be released about once every six months. diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 000000000000..1a98cd3c43db --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,81 @@ +![logo](https://raw.githubusercontent.com/pola-rs/polars-static/master/banner/polars_github_banner.svg) + +

Blazingly Fast DataFrame Library

+ + +Polars is a blazingly fast DataFrame library for manipulating structured data. The core is written +in Rust, and available for Python, R and NodeJS. + +## Key features + +- **Fast**: Written from scratch in Rust, designed close to the machine and without external + dependencies. +- **I/O**: First class support for all common data storage layers: local, cloud storage & databases. +- **Intuitive API**: Write your queries the way they were intended. Polars, internally, will + determine the most efficient way to execute using its query optimizer. +- **Out of Core**: The streaming API allows you to process your results without requiring all your + data to be in memory at the same time. +- **Parallel**: Utilises the power of your machine by dividing the workload among the available CPU + cores without any additional configuration. +- **Vectorized Query Engine** +- **GPU Support**: Optionally run queries on NVIDIA GPUs for maximum performance for in-memory + workloads. +- **[Apache Arrow support](https://arrow.apache.org/)**: Polars can consume and produce Arrow data + often with zero-copy operations. Note that Polars is not built on a Pyarrow/Arrow implementation. + Instead, Polars has its own compute and buffer implementations. + + + +!!! info "Users new to DataFrames" + A DataFrame is a 2-dimensional data structure that is useful for data manipulation and analysis. With labeled axes for rows and columns, each column can contain different data types, making complex data operations such as merging and aggregation much easier. Due to their flexibility and intuitive way of storing and working with data, DataFrames have become increasingly popular in modern data analytics and engineering. + + + +## Philosophy + +The goal of Polars is to provide a lightning fast DataFrame library that: + +- Utilizes all available cores on your machine. +- Optimizes queries to reduce unneeded work/memory allocations. +- Handles datasets much larger than your available RAM. +- A consistent and predictable API. +- Adheres to a strict schema (data-types should be known before running the query). + +Polars is written in Rust which gives it C/C++ performance and allows it to fully control +performance-critical parts in a query engine. + +## Example + +{{code_block('home/example','example',['scan_csv','filter','group_by','collect'])}} + +A more extensive introduction can be found in the [next chapter](user-guide/getting-started.md). + +## Community + +Polars has a very active community with frequent releases (approximately weekly). Below are some of +the top contributors to the project: + +--8<-- "docs/assets/people.md" + +## Contributing + +We appreciate all contributions, from reporting bugs to implementing new features. Read our +[contributing guide](development/contributing/index.md) to learn more. + +## License + +This project is licensed under the terms of the +[MIT license](https://github.com/pola-rs/polars/blob/main/LICENSE). diff --git a/docs/source/polars-cloud/cli.md b/docs/source/polars-cloud/cli.md new file mode 100644 index 000000000000..6da6e29bb689 --- /dev/null +++ b/docs/source/polars-cloud/cli.md @@ -0,0 +1,55 @@ +# CLI + +Polars cloud comes with a command line interface (CLI) out of the box. This allows you to interact +with polars cloud resources from the terminal. + +```bash +pc --help +``` + +``` +usage: pc [-h] [-v] [-V] {login,workspace,compute} ... + +positional arguments: +{login,workspace,compute} +login Authenticate with Polars Cloud by logging in through the browser +workspace Manage Polars Cloud workspaces. +compute Manage Polars Cloud compute clusters. + +options: +-h, --help show this help message and exit +-v, --verbose Output debug logging messages. +-V, --version Display the version of the Polars Cloud client. +``` + +### Authentication + +You can authenticate with Polars Cloud from the CLI using + +```bash +pc login +``` + +This refreshes your access token and saves it to disk. + +### Workspaces + +Create and setup a new workspace + +```bash +pc workspace setup +``` + +List all workspaces + +```bash +pc workspace list +``` + +``` +NAME ID STATUS +test-workspace 0194ac0e-5122-7a90-af5e-b1f60b1989f4 Active +polars-ci-2025… 0194287a-e0a5-7642-8058-0f79a39f5b98 Uninitialized +``` + +### Compute diff --git a/docs/source/polars-cloud/connect-cloud.md b/docs/source/polars-cloud/connect-cloud.md new file mode 100644 index 000000000000..119a59bd5722 --- /dev/null +++ b/docs/source/polars-cloud/connect-cloud.md @@ -0,0 +1,67 @@ +# Connect cloud environment + +To use Polars Cloud, you must connect your workspaces to a cloud environment. + +If you log in to the Polars Cloud dashboard for the first time with an account that isn’t connected +to a cloud environment, you will see a blue bar at the top of the screen. You can explore Polars +Cloud in this state, but you won’t be able to execute any queries yet. + +![An overview of the Polars Cloud dashboard showing a button to connect your cloud environment](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/connect-cloud/dashboard.png) + +When you click the blue bar you will be redirected to the start of the set up flow. + +## 1. Set workspace name + +In the first step of the setup flow, you’ll name your workspace. You can keep the default name +"Personal Workspace" or use the name of your team or department. This workspace name will be used by +the compute context to run your queries remotely. + + + +!!! tip "Naming your workspace" + If you’re unsure, you can use a temporary name. You can change the workspace name later under the workspace settings. + + + +![Connect your cloud screen where you can input a workspace name](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/connect-cloud/workspace-naming.png) + +## 2. Deploy to AWS + +After naming your workspace, click Deploy to Amazon. This opens a screen in AWS with a +CloudFormation template. This template installs the necessary roles in your AWS environment. + +![CloudFormation stack image as step of the setupflow](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/connect-cloud/cloudformation.png) + +If you want to learn more about what Polars Cloud installs in your environment, you can read more on +[the AWS Infrastructure page](providers/aws/infra.md). + + +!!! info "No permissions to deploy the stack in AWS" + If you don't have the required permissions to deploy CloudFormation stacks in your AWS environment, you can copy the URL and share it with your operations team or someone with the permissions. With the URL they can deploy the stack for you. + + +## 3. Deploying the environment + +After you click Create stack, the CloudFormation stack will be deployed in your environment. This +process usually takes around 5 minutes. You can monitor the progress in your AWS environment or in +the Polars setup flow. + +![Progress screen in the set up flow](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/connect-cloud/progress-page.png) + +When the CloudFormation stack deployment completes, you’ll see a confirmation message. + +![Final screen of the set up flow indication successful deployment](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/connect-cloud/successful-setup.png) + +If you click "Start exploring", you will be redirected to the Polars Cloud dashboard. + +You can now run your Polars query remotely in the cloud. Go to the +[getting started section](quickstart.md) to your first query in minutes, +[learn more how to run queries remote](run/compute-context.md) or manage your workspace to invite +your team. + + + +!!! info "Only connect a workspace once" + You only have to connect your workspace once. If you invite your team to a workspace that is connected to a cloud environment they can immediately run queries remotely. + + diff --git a/docs/source/polars-cloud/explain/authentication.md b/docs/source/polars-cloud/explain/authentication.md new file mode 100644 index 000000000000..9973942ff296 --- /dev/null +++ b/docs/source/polars-cloud/explain/authentication.md @@ -0,0 +1,38 @@ +# Logging in + +Polars cloud allows authentication through short-lived authentication tokens. There are two ways you +can obtain an access token: + +- command line interface +- python client + +After a successful `login` Polars Cloud stores the token in `{$HOME}/.polars`. You can alter this +path by setting the environment variable `POLARS_CLOUD_ACCESS_TOKEN_PATH`. + +### Command Line Interface (CLI) + +Authenticate with CLI using the following command + +```bash +pc login +``` + +### Python client + +Authenticate with the Polars Cloud using + +{{code_block('polars-cloud/authentication','login',['login'])}} + +Both methods redirect you to the browser where you can provide your login credentials and continue +the sign in process. + +## Service accounts + +Both flows described above are for interactive logins where a person is present in the process. For +non-interactive workflows such as orchestration tools there are service accounts. These allow you to +login programmatically. + +To create a service account go to the Polars Cloud dashboard under Settings and service accounts. +Here you can create a new service account for your workspace. To authenticate set the +`POLARS_CLOUD_CLIENT_ID` and `POLARS_CLOUD_CLIENT_SECRET` environment variables. Polars Cloud will +automatically pick these up if there are no access tokens present in the path. diff --git a/docs/source/polars-cloud/faq.md b/docs/source/polars-cloud/faq.md new file mode 100644 index 000000000000..a56cf4f6810e --- /dev/null +++ b/docs/source/polars-cloud/faq.md @@ -0,0 +1,66 @@ +# FAQ + +On this page you can find answers to some frequently asked questions around Polars Cloud. + +## Who is behind Polars Cloud? + +Polars Cloud is built by the organization behind the open source Polars project. We are committed to +improve Polars open source for all single machine workloads. Polars Cloud will extend Polars +functionalities for remote and distributed compute. + +## Where does the compute run? + +All compute runs in your own cloud environment. The main reason is that this ensures that your data +never leaves your environment and that the compute is always close to your data. + +You can learn more about how this setup in +[the infrastructure section of the documentation](providers/aws/infra.md). + +## Can you run Polars Cloud on-premise? + +Currently, Polars Cloud is only available to organizations that are on AWS. Support for on-premise +infrastructure is on our roadmap and will become available soon. + +## What does Polars Cloud offer me beyond Polars? + +Polars Cloud offers a managed service that enables scalable data processing with the flexibility and +expressiveness of the Polars API. It extends the open source Polars project with the following +capabilities: + +- Distributed engine to scale workloads horizontally. +- Cost-optimized serverless architecture that automatically scales compute resources +- Built-in fault tolerance mechanisms ensuring query completion even during hardware failures or + system interruptions +- Comprehensive monitoring and analytics tools providing detailed insights into query performance + and resource utilization. + +## What are the main use cases for Polars Cloud? + +Polars Cloud offers both a batch as an interactive mode to users. Batch mode can be used for ETL +workloads or one-off large scale analytic jobs. Interactive mode is for users that are looking to do +data exploration on a larger scale data processing that requires more compute than their own machine +can offer. + +## How can Polars Cloud integrate with my workflow? + +One of our key priorities is ensuring that running remote queries feels as native and seamless as +running them locally. Every user should be able to scale their queries effortlessly. + +Polars Cloud is completely environment agnostic. This allows you to run your queries from anywhere +such as your own machine, Jupyter/Marimo notebooks, Airflow DAGs, AWS Lambda functions, or your +servers. By not tying you to a specific platform, Polars Cloud gives you the flexibility to execute +your queries wherever it best fits your workflow. + +## What is the pricing model of Polars Cloud? + +Polars Cloud is available at no additional cost in this early stage. You only pay for the resources +you use in your own cloud environment. We are exploring different usage based pricing models that +are geared towards running queries as fast and efficient as possible. + +## Will the distributed engine be available in open source? + +The distributed engine is only available in Polars Cloud. There are no plans to make it available in +the open source project. Polars is focused on single node compute, as it makes efficient use of the +available resources. Users already report utilizing Polars to process hundreds of gigabytes of data +on single (large) compute instance. The distributed engine is geared towards teams and organizations +that are I/O bound or want to scale their Polars queries beyond single machines. diff --git a/docs/source/polars-cloud/index.md b/docs/source/polars-cloud/index.md new file mode 100644 index 000000000000..4c071f1b0696 --- /dev/null +++ b/docs/source/polars-cloud/index.md @@ -0,0 +1,55 @@ +![Image showing the Polars Cloud logo](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/polars-cloud.svg) + +# Introducing Polars Cloud + + + +!!! tip "Polars Cloud is in alpha stage" + Polars Cloud is currently available to a select group of individuals and companies for early-stage testing. You can learn more about Polars Cloud and its goals in [our recent announcement post](https://pola.rs/posts/polars-cloud-what-we-are-building/). + + + +DataFrame implementations always differed from SQL and databases. SQL could run anywhere from +embedded databases to massive data warehouses. Yet, DataFrame users have been forced to choose +between a solution for local work or solutions geared towards distributed computing, each with their +own APIs and limitations. + +Polars is bridging this gap with **Polars Cloud**. Build on top of the popular open source project, +Polars Cloud enables you to write DataFrame code once and run it anywhere. The distributed engine +available with Polars Cloud allows to scale your Polars queries beyond a single machine. + +## Key Features of Polars Cloud + +- **Unified DataFrame Experience**: Run a Polars query seamlessly on your local machine and at scale + with our new distributed engine. All from the same API. +- **Serverless Compute**: Effortlessly start compute resources without managing infrastructure with + options to execute queries on both CPU and GPU. +- **Any Environment**: Start a remote query from a notebook on your machine, Airflow DAG, AWS + Lambda, or your server. Get the flexibility to embed Polars Cloud in any environment. + +## Install Polars Cloud + +Simply extend the capabilities of Polars with: + +```bash +pip install polars polars_cloud +``` + +## Example query + +To run your query in the cloud, simply write Polars queries like you are used to, but call +`LazyFrame.remote()` to indicate that the query should be run remotely. + +{{code_block('polars-cloud/index','index',['ComputeContext','LazyFrameExt'])}} + +## Sign up today and start for free + +Polars Cloud is still in an early development stage and available at no additional cost. You only +pay for the resources you use in your own cloud environment. + +## Cloud availability + +Currently, Polars Cloud is available on AWS. Other cloud providers and on-premise solutions are on +the roadmap and will become available in the upcoming months. + +![AWS logo](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/aws-logo.svg) diff --git a/docs/source/polars-cloud/integrations/airflow.md b/docs/source/polars-cloud/integrations/airflow.md new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/docs/source/polars-cloud/providers/aws/infra.md b/docs/source/polars-cloud/providers/aws/infra.md new file mode 100644 index 000000000000..2c8849558f60 --- /dev/null +++ b/docs/source/polars-cloud/providers/aws/infra.md @@ -0,0 +1,73 @@ +# Infrastructure + +Polars Cloud manages the hardware for you by spinning up and down raw EC2 instances. In order to do +this it needs permissions in your own cloud environment. None of the resources below have costs +associated with them. While no compute clusters are running Polars Cloud will not create any AWS +costs. The recommended way of doing this is running `pc workspace setup`. + +## Recommended setup + +When you deploy Polars Cloud the following infrastructure is setup. + +
+![AWS infrastructure](https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/polars_cloud/aws-infra.png) +
+ +1. A `VPC` and `subnet` in which Polars EC2 workers can run. +1. Two `security groups`. One for batch mode, which does not have any public ports, and one for + interactive mode, which allows direct communication between your local environment and the + cluster. +1. `PolarsWorker` IAM role. Polars EC2 workers run under this IAM role. +1. `UserInitiated` & `Unattended` IAM role. The `UserInitiated` role has the permissions to start + Polars EC2 workers in your environment. The `Unattended` role can terminate unused compute + clusters that you might have forgot about. + +## Security + +By design Polars Cloud never has access to the data inside your cloud environment. The data never +leaves your environment. + +### IAM permissions + +The list below show an overview of the required permissions for each of the roles. + +??? User Initiated + + - ec2:CreateTags + - ec2:RunInstances + - ec2:DescribeInstances + - ec2:DescribeInstanceTypeOfferings + - ec2:DescribeInstanceTypes + - ec2:TerminateInstances + - ec2:CreateFleet + - ec2:CreateLaunchTemplate + - ec2:CreateLaunchTemplateVersion + - ec2:DescribeLaunchTemplates + +??? Unattended + + - ec2:DescribeInstances + - ec2:TerminateInstances + - ec2:DescribeFleets + - ec2:DeleteLaunchTemplate + - ec2:DeleteLaunchTemplateVersions + - ec2:DeleteFleets + - sts:GetCallerIdentity + - sts:TagSession + - cloudwatch:GetMetricData + - logs:GetLogEvents + - logs:FilterLogEvents + - logs:DescribeLogStreams + +??? Worker + + - logs:CreateLogGroup + - logs:PutRetentionPolicy + - cloudwatch:PutMetricData + +## Custom setup + +Depending on your enterprise needs or existing infrastructure, you may not require certain +components (e.g. VPC, subnet) of the default setup of Polars Cloud. Or you have additional security +requirements in place. Together with our team of engineers we can integrate Polars Cloud with your +existing infrastructure. Please contact us directly. diff --git a/docs/source/polars-cloud/providers/aws/permissions.md b/docs/source/polars-cloud/providers/aws/permissions.md new file mode 100644 index 000000000000..c480468a4631 --- /dev/null +++ b/docs/source/polars-cloud/providers/aws/permissions.md @@ -0,0 +1,16 @@ +# Permissions + +The workspace is an isolation for all resources living within your cloud environment. Every +workspace has a single instance profile which defines the permissions for the compute. This profile +is attached to the compute within your environment. By default, the profile can read and write from +S3, but you can easily adjust depending on your own infrastructure stack. + +## Adding or removing permissions + +If you want Polars Cloud to be able to read from other data sources than `S3` within your cloud +environment you must provide the access control from directly within AWS. To do this go to `IAM` +within the aws console and locate the role called `polars--IAMWorkerRole-`. +Here you can adjust the permissions of the workspace for instance: + +- [Narrow down the S3 access to certain buckets](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_examples_s3_deny-except-bucket.html) +- [Provide IAM access to rds database](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.IAMPolicy.html) diff --git a/docs/source/polars-cloud/quickstart.md b/docs/source/polars-cloud/quickstart.md new file mode 100644 index 000000000000..964f79f1533d --- /dev/null +++ b/docs/source/polars-cloud/quickstart.md @@ -0,0 +1,69 @@ +# Getting started + +Polars Cloud is a managed compute platform for your Polars queries. It allows you to effortlessly +run your local queries in your cloud environment, both in an interactive setting as well as for ETL +or batch jobs. By working in a 'Bring your own Cloud' model the data never leaves your environment. + +## Installation + +Install the Polars Cloud python library in your environment + +```bash +$ pip install polars polars-cloud +``` + +Create an account and login by running the command below. + +```bash +$ pc login +``` + +## Connect your cloud + +Polars Cloud currently exclusively supports AWS as a cloud provider. + +Polars Cloud needs permission to manage hardware in your environment. This is done by deploying our +cloudformation template. See our [infrastructure](providers/aws/infra.md) section for more details. + +To connect your cloud run: + +```bash +$ pc setup workspace -n +``` + +This redirects you to the browser where you can connect Polars Cloud to your AWS environment. +Alternatively, you can follow the steps in the browser and create the workspace there. + +## Run your queries + +Now that we are done with the setup, we can start running queries. The general principle here is +writing Polars like you're always used to and calling `.remote()` on your `LazyFrame`. The following +example shows how to create a compute cluster and run a simple Polars query. + +{{code_block('polars-cloud/quickstart','general',['ComputeContext','LazyFrameExt'])}} + +Let us go through the code line by line. First we need to define the hardware the cluster will run +on. This can be provided in terms of cpu & memory or by specifying the the exact instance type in +AWS. + +```python +ctx = pc.ComputeContext(memory=8, cpus=2 , cluster_size=1) +``` + +Then we write a regular lazy Polars query. In this simple example we compute the maximum of column +`a` over column `b`. + +```python +df = pl.LazyFrame({ + "a": [1, 2, 3], + "b": [4, 4, 5] +}) +lf = df.with_columns( + c = pl.col("a").max().over("b") +) +``` + +Finally we are going to run our query on the compute cluster. We use `.remote()` to signify that we +want to run the query remotely. This gives back a special version of the `LazyFrame` with extension +methods. Up until this point nothing has executed yet, calling `.write_parquet()` sends the query to +g diff --git a/docs/source/polars-cloud/release-notes.md b/docs/source/polars-cloud/release-notes.md new file mode 100644 index 000000000000..f5847b75fd7b --- /dev/null +++ b/docs/source/polars-cloud/release-notes.md @@ -0,0 +1,63 @@ +# Release Notes + +## Polars Cloud 0.0.7 (2025-04-11) + +### Enhancements + +- We now default to the streaming engine for distributed workloads A big benefit is that query + results are now streamed directly to final S3/external storage instead of being collected into + memory first. This enables handling much larger results without causing out-of-memory errors. +- Polars Cloud now supports organizations. An organization can contain multiple workspaces and will + eventually serve as the central entity for managing user roles and billing. During migration, an + organization was automatically created for each of your workspaces. If you have multiple + workspaces with the same name, you can now select one by specifying the organization: + ```python + workspace = Workspace("my_workspace", organization="my_organization") + context = ComputeContext(workspace=workspace) + ``` +- Polars Cloud login tokens are now automatically refreshed for up to 8 hours after the initial + login. No more need to login every 15 minutes. +- A clear error message is now shown when connecting to Polars Cloud with an out-of-date client + version. +- The portal now has a feedback button in the bottom right corner so you can easily report feedback + to us. We greatly appreciate it, so please use it as much as you would like. +- Distributed queries now display the physical plan in addition to the logical plan. The logical + plan is the optimized query and how it would be run on a single node. The physical plan represents + how we chop this up into stages and tasks so that it can be run on multiple nodes in a distributed + manner. +- The workspace compute listing page has been reworked for improved clarity and shows more + information. + +### Bug fixes + +- Improved out-of-memory reporting +- Ctrl-C now works to exit during login +- Workspace members can now be re-invited if they were previously removed + +## Polars Cloud 0.0.6 (2025-03-31) + +This version of Polars Cloud is only compatible with the 1.26.x releases of Polars Open Source. + +### Enhancements + +- Added support for [Polars IO plugins](https://docs.pola.rs/user-guide/plugins/io_plugins/). + - These allow you to register different file formats as sources to the Polars engines which allows + you to benefit from optimizations like projection pushdown, predicate pushdown, early stopping + and support of our streaming engine. +- Added `.show()` for remote queries + - You can now call `.remote().show()` instead of `remote().limit().collect().collect()` +- All Polars features that rely on external Python dependencies are now available in Polars Cloud + such as: + - `scan_iceberg()` + - `scan_delta()` + - `read_excel()` + - `read_database()` + +### Bug fixes + +- Fixed an issue where specifying CPU/memory minimums could select an incompatible legacy EC2 + instance. +- The API now returns clear errors when trying to start a query on an uninitialized workspace. +- Fixed a 404 error when opening a query overview page in a workspace different from the currently + selected workspace. +- Viewing another workspace members compute detail page no longer shows a blank screen. diff --git a/docs/source/polars-cloud/run/compute-context.md b/docs/source/polars-cloud/run/compute-context.md new file mode 100644 index 000000000000..b76007b5efcc --- /dev/null +++ b/docs/source/polars-cloud/run/compute-context.md @@ -0,0 +1,69 @@ +# Defining a compute context + +The compute context defines the hardware configuration used to execute your queries. This can be +either a single node or, for distributed execution, multiple nodes. This section explains how to set +up and manage your compute context. + +{{code_block('polars-cloud/compute-context','compute',['ComputeContext'])}} + +## Setting the context + +You can define your compute context in three ways: + +1. Use your workspace default +2. Specify CPUs and RAM requirements +3. Select a specific instance type + +### Workspace default + +In the Polars Cloud dashboard, you can set default requirements from your cloud service provider to +be used for all queries. You can also manually define storage and the default cluster size. + +Polars Cloud will use these defaults if no other parameters are passed to the `ComputeContext`. + +{{code_block('polars-cloud/compute-context','default-compute',['ComputeContext'])}} + +Find out more about how to [set workspace defaults](../workspace/settings.md) in the workspace +settings section. + +### Define hardware specifications + +You can directly specify the `cpus` and `memory` requirements in your `ComputeContext`. When set, +Polars Cloud will select the most suitable instance type from your cloud service provider that meets +the specifications. The requirements are lower bounds, meaning the machine will have at least that +number of CPUs and memory. + +{{code_block('polars-cloud/compute-context','defined-compute',['ComputeContext'])}} + +### Set instance type + +For more control, you can specify the exact instance type for Polars to use. This is useful when you +have specific hardware requirements in a production environment. + +{{code_block('polars-cloud/compute-context','set-compute',['ComputeContext'])}} + +## Applying the compute context + +Once defined, you can apply your compute context to queries in three ways: + +1. By directly passing the context to the remote query: + + ```python + query.remote(context=ctx).sink_parquet(...) + ``` + +2. By globally setting the compute context. This way you set it once and don't need to provide it to + every `remote` call: + + ```python + pc.set_compute_context(ctx) + + query.remote().sink_parquet(...) + ``` + +3. When a default compute context is set via the Polars Cloud dashboard. It is no longer required to + define a compoute context. + + ```python + query.remote().sink_parquet(...) + ``` diff --git a/docs/source/polars-cloud/run/distributed-engine.md b/docs/source/polars-cloud/run/distributed-engine.md new file mode 100644 index 000000000000..c2db4b9fe69f --- /dev/null +++ b/docs/source/polars-cloud/run/distributed-engine.md @@ -0,0 +1,63 @@ +# Distributed query execution + +With the introduction of Polars Cloud, we also introduced the distributed engine. This engine +enables users to horizontally scale workloads across multiple machines. + +Polars has always been optimized for fast and efficient performance on a single machine. However, +when querying large datasets from cloud storage, performance is often constrained by the I/O +limitations of a single node. By scaling horizontally, these download limitations can be +significantly reduced, allowing users to process data at scale. + + + +!!! info "Distributed engine is early stage" + The distributed engine is still in the very early stages of development. Major performance improvements are planned for the near future. When an operation is not yet available in a distributed manner, Polars Cloud will execute it on a single node. + + Find out which operations are [currently supported in the distributed engine](https://github.com/pola-rs/polars/issues/21487). + + + +## Using distributed engine + +To execute queries using the distributed engine, you can call the `distributed()` method. + +```python +lf: LazyFrame + +result = ( + lf.remote() + .distributed() + .collect() +) +``` + +### Example + +```python +import polars as pl +import polars_cloud as pc +from datetime import date + +query = ( + pl.scan_parquet("s3://dataset/") + .filter(pl.col("l_shipdate") <= date(1998, 9, 2)) + .group_by("l_returnflag", "l_linestatus") + .agg( + avg_price=pl.mean("l_extendedprice"), + avg_disc=pl.mean("l_discount"), + count_order=pl.len(), + ) +) + +result = ( + query.remote(pc.ComputeContext(cpus=16, memory=64, cluster_size=32)) + .distributed() + .sink_parquet("s3://output/result.parquet") +) +``` + +## Working with large datasets in the distributed engine + +The distributed engine can only read sources partitioned with direct scan_ methods such as +`scan_parquet` and `scan_csv`. Open table formats like `scan_iceberg` are not yet supported in a +distributed fashion and will run on a single node when utilized. diff --git a/docs/source/polars-cloud/run/interactive-batch.md b/docs/source/polars-cloud/run/interactive-batch.md new file mode 100644 index 000000000000..f578133fc791 --- /dev/null +++ b/docs/source/polars-cloud/run/interactive-batch.md @@ -0,0 +1,102 @@ +# Interactive or batch mode + +In Polars Cloud, a user can define two types of compute modes: batch & interactive. Batch mode is +designed for batch job style queries. These kinds of queries are typically scheduled and run once in +a certain period. Interactive mode allows for exploratory workflows where a user interacts with the +dataset and requires more compute resources than are locally available. + +The rest of this page will give examples on how to set up one or the other. More information on the +architectural differences and implications can be found on +[the infrastructure page](../providers/aws/infra.md). + +Below we create a simple dataframe to use as an example to demonstrate the difference between both +modes. + +{{code_block('polars-cloud/interactive-batch','example',[])}} + +## Batch + +Batch workflows are systematic data processing pipelines, written for scheduled execution. They +typically process large volumes of data in scheduled intervals (e.g. hourly, daily, etc.). A key +characteristic is that the executed job has a defined lifetime. A predefined compute instance should +spin up at a certain time and shut down when the job is executed. + +Polars Cloud makes it easy to run your query at scale whenever your use case requires it. You can +develop your query on your local machine and define a compute context and destination to execute it +in your cloud environment. + +{{code_block('polars-cloud/interactive-batch','batch',['ComputeContext'])}} + +The query you execute in batch mode runs in your cloud environment. The data and results of the +query are not sent to Polars Cloud, ensuring that your data and output remain secure. + +```python +lf.remote(context=ctx).sink_parquet("s3://bucket/output.parquet") +``` + +## Interactive + +Polars Cloud also supports interactive workflows. Different from batch mode, results are being +interactively updated. Polars Cloud will not automatically close the cluster when a result has been +produced, but the cluster stays active and intermediate state can still be accessed. In interactive +mode you directly communicate with the compute nodes. + +Because this mode is used for exploratory use cases and short feedback cycles, the queries are not +logged to Polars Cloud and will not be available for later inspection. + +{{code_block('polars-cloud/interactive-batch','interactive',['ComputeContext'])}} + +The initial query remains the same. In the compute context the parameter `interactive` should be set +to `True`. + +When calling `.collect()` on your remote query execution, the output is written to a temporary +location. These intermediate result files are automatically deleted after several hours. The output +of the remote query is a LazyFrame. + +```python +print(type(res1)) +``` + +``` + +``` + +If you want to inspect the results you can call collect again. + +```python +print(res1.collect()) +``` + +```text +shape: (4, 3) +┌────────────────┬────────────┬───────────┐ +│ name ┆ birth_year ┆ bmi │ +│ --- ┆ --- ┆ --- │ +│ str ┆ i32 ┆ f64 │ +╞════════════════╪════════════╪═══════════╡ +│ Chloe Cooper ┆ 1983 ┆ 19.687787 │ +│ Ben Brown ┆ 1985 ┆ 23.141498 │ +│ Alice Archer ┆ 1997 ┆ 23.791913 │ +│ Daniel Donovan ┆ 1981 ┆ 27.134694 │ +└────────────────┴────────────┴───────────┘ +``` + +To continue your exploration you can use the returned LazyFrame to build another query. + +{{code_block('polars-cloud/interactive-batch','interactive-next',[])}} + +```python +print(res2.collect()) +``` + +```text +shape: (2, 3) +┌──────────────┬────────────┬───────────┐ +│ name ┆ birth_year ┆ bmi │ +│ --- ┆ --- ┆ --- │ +│ str ┆ i32 ┆ f64 │ +╞══════════════╪════════════╪═══════════╡ +│ Chloe Cooper ┆ 1983 ┆ 19.687787 │ +│ Ben Brown ┆ 1985 ┆ 23.141498 │ +└──────────────┴────────────┴───────────┘ +``` diff --git a/docs/source/polars-cloud/run/service-accounts.md b/docs/source/polars-cloud/run/service-accounts.md new file mode 100644 index 000000000000..88127c66376e --- /dev/null +++ b/docs/source/polars-cloud/run/service-accounts.md @@ -0,0 +1,50 @@ +# Using service accounts + +Service accounts function as programmatic identities that enable secure machine-to-machine +communication for remote data processing workflows. These accounts facilitate automated processes +(typically in production environments) without requiring interactive login sessions. + +## Create a service account + +In the workspace settings page, navigate to the Service Accounts section. Here, you can create and +manage service accounts associated with your workspace. + +To create a new Service Account: + +1. Click the Create New Service Account button +2. Provide a name and description for the account +3. Copy and securely store the Client ID and Client Secret + + + +!!! info "Client ID and Secret visible once" + If you lose the Client ID or Client Secret, you will need to generate a new service account. + + + +## Set up your environment to use a Service Account + +To authenticate using a Service Account, you must set environment variables. The following variables +should be defined using the credentials provided when the service account was created: + +```bash +export POLARS_CLOUD_CLIENT_ID="CLIENT_ID_HERE" +export POLARS_CLOUD_CLIENT_SECRET="CLIENT_SECRET_HERE" +``` + +## Execute query with Service Account + +Once the environment variables are set, you do not need to log in to Polars Cloud manually. Your +process will automatically connect to the workspace. + +## Revoking access or deleting a Service Account + +When a service account is no longer needed, it is recommended to revoke its access. Keep in mind +that deleting a service account is irreversible, and any processes or applications relying on this +account will lose access immediately. + +1. Navigate to the Workspace Settings page. +2. Go to the Service Accounts section. +3. Locate the service account to delete. +4. Click the three-dot menu next to the service account and select Delete. +5. Confirm the deletion of the Service Account. diff --git a/docs/source/polars-cloud/run/workflow.md b/docs/source/polars-cloud/run/workflow.md new file mode 100644 index 000000000000..f933b062292e --- /dev/null +++ b/docs/source/polars-cloud/run/workflow.md @@ -0,0 +1,155 @@ +# From local to cloud query execution + +Data processing and analytics often begins small but can quickly grow beyond the capabilities of +your local machine. A typical workflow starts with exploring a sample dataset locally, developing +the analytical approach, and then scaling up to process the full dataset in the cloud. + +This pattern allows you to iterate quickly during development while still handling larger datasets +in production. With Polars Cloud, you can maintain this natural workflow without rewriting your code +when moving from local to cloud execution, without requiring any migrations between local and +production tooling. + +## Local exploration + +For this workflow, we define the following simple mocked dataset that will act as a sample to +demonstrate the workflow. Here we will create the LazyFrame ourselves, but it could also be read as +(remote) file. + +```python +import polars as pl + +lf = pl.LazyFrame( + { + "region": [ + "Australia", + "California", + "Benelux", + "Siberia", + "Mediterranean", + "Congo", + "Borneo", + ], + "temperature": [32.1, 28.5, 30.2, 22.7, 29.3, 31.8, 33.2], + "humidity": [40, 35, 75, 30, 45, 80, 70], + "burn_area": [120, 85, 30, 65, 95, 25, 40], + "vegetation_density": [0.6, 0.7, 0.9, 0.4, 0.5, 0.9, 0.8], + } +) +``` + +A simple transformation will done to create a new column. + +```python +( + lf.with_columns( + ( + (pl.col("temperature") / 10) + * (1 - pl.col("humidity") / 100) + * pl.col("vegetation_density") + ).alias("fire_risk"), + ).filter(pl.col("humidity") < 70) + .sort(by="fire_risk", descending=True) + .collect() +) +``` + +```text +shape: (4, 6) +┌───────────────┬─────────────┬──────────┬───────────┬────────────────────┬───────────┐ +│ region ┆ temperature ┆ humidity ┆ burn_area ┆ vegetation_density ┆ fire_risk │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ f64 ┆ i64 ┆ i64 ┆ f64 ┆ f64 │ +╞═══════════════╪═════════════╪══════════╪═══════════╪════════════════════╪═══════════╡ +│ California ┆ 28.5 ┆ 35 ┆ 85 ┆ 0.7 ┆ 1.29675 │ +│ Australia ┆ 32.1 ┆ 40 ┆ 120 ┆ 0.6 ┆ 1.1556 │ +│ Mediterranean ┆ 29.3 ┆ 45 ┆ 95 ┆ 0.5 ┆ 0.80575 │ +│ Siberia ┆ 22.7 ┆ 30 ┆ 65 ┆ 0.4 ┆ 0.6356 │ +└───────────────┴─────────────┴──────────┴───────────┴────────────────────┴───────────┘ +``` + +## Run at scale in the cloud + +Imagine that there is a larger dataset stored in a cloud provider’s storage solution. The dataset is +so large that it doesn’t fit on our local machine. However, through local analysis, we have verified +that the defined query correctly calculates the column we are looking for. + +With Polars Cloud, we can easily run the same query at scale. First, we make small changes to our +query to point to our resources in the cloud. + +```python +lf = pl.scan_parquet("s3://climate-data/global/*.parquet") + +query = ( + lf.with_columns( + [ + ( + (pl.col("temperature") / 10) + * (1 - pl.col("humidity") / 100) + * pl.col("vegetation_density") + ).alias("fire_risk"), + ] + ) + .filter(pl.col("humidity") < 70) + .sort(by="fire_risk", descending=True) +) +``` + +Next, we set our compute context and call `.remote(context=ctx)` on our query. + +```python +import polars_cloud as pc + +ctx = pc.ComputeContext( + workspace="environmental-analysis", + memory=32, + cpus=8 +) + +query.remote(context=ctx).sink_parquet("s3://bucket/result.parquet") +``` + +### Continue analysis in interactive mode + +Running `.sink_parquet()` will write the results to the defined bucket on S3. Alternatively, we can +take a more interactive approach by adding the parameter `interactive=True` to our compute context. + +```python +ctx = pc.ComputeContext( + workspace="environmental-analysis", + memory=32, + cpus=8, + interactive=True, # set interactive to True +) + +result = query.remote(context=ctx).collect() + +print(result.collect()) +``` + +```text +shape: (4, 6) +┌───────────────┬─────────────┬──────────┬───────────┬────────────────────┬───────────┐ +│ region ┆ temperature ┆ humidity ┆ burn_area ┆ vegetation_density ┆ fire_risk │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ f64 ┆ i64 ┆ i64 ┆ f64 ┆ f64 │ +╞═══════════════╪═════════════╪══════════╪═══════════╪════════════════════╪═══════════╡ +│ California ┆ 28.5 ┆ 35 ┆ 85 ┆ 0.7 ┆ 1.29675 │ +│ Australia ┆ 32.1 ┆ 40 ┆ 120 ┆ 0.6 ┆ 1.1556 │ +│ Mediterranean ┆ 29.3 ┆ 45 ┆ 95 ┆ 0.5 ┆ 0.80575 │ +│ Siberia ┆ 22.7 ┆ 30 ┆ 65 ┆ 0.4 ┆ 0.6356 │ +└───────────────┴─────────────┴──────────┴───────────┴────────────────────┴───────────┘ +``` + +We can call `.collect()` instead of `.sink_parquet()`. This will store your results to a temporary +location which can be used to further iterate upon. A LazyFrame is returned that can be used in the +next steps of the workflow. + +```python +res2 = ( + result + .filter(pl.col("fire_risk") > 1) + .sink_parquet("s3://bucket/output-interactive.parquet") +) +``` + +The result of your interactive workflow can be written to S3. diff --git a/docs/source/polars-cloud/workspace/settings.md b/docs/source/polars-cloud/workspace/settings.md new file mode 100644 index 000000000000..d577ea02df80 --- /dev/null +++ b/docs/source/polars-cloud/workspace/settings.md @@ -0,0 +1,61 @@ +# Configuring workspace settings + +The Workspace Settings page provides a centralized interface for managing your workspace +configuration. + +## General + +The General section contains basic information about your workspace: + +- **Workspace Name**: Displays the current name of your workspace. You can modify this by clicking + the "Edit" button. +- **Description**: Provides space for a detailed description of your workspace's purpose. The + description can be modified by clicking the "Edit" button. + +## Default Compute Configuration + +This section allows you to define the default computational resources allocated to jobs and +processes within your workspace. + +- **Set Cluster Defaults**: Clicking this button lets you configure either a Resource based default + or Instance based default. + - With Resource based custom amount of vCPUs, RAM, Storage and cluster size can be defined. + - Instance based allows to select a default instance type from the cloud service provider. + +Setting default compute configurations eliminates the need to explicitly define a compute context. +More information on configuration can be found in the section on +[setting up compute context](../run/compute-context.md). + +## Query and compute labels + +Labels help organize and categorize queries and compute within your workspace. The labels can only +be set from the dashboard. + +```python +ctx= plc.ComputeContext( + workspace="PolarsCloudDemo", + labels=["marketing", "cltv"], +) +``` + +- **Create New Label**: Create custom labels that can be applied to various resources for better + organization and insight in usage. + +## Service Accounts + +The Service Accounts section manages programmatic access to your workspace. + +- **Token Description**: Displays the purpose of each service account token. +- **Created At**: Shows when a service account was created. +- **Action**: Provides options to delete service accounts. +- **Create New Service Account**: Allows you to generate new service accounts for API access and + automation. + +You can find more information about using Service Accounts in your workflows in the "Use service +accounts" section. + +## Disable workspace + +Users have the ability to disable the entire workspace. This should be used with caution. Disabling +the workspace will stop processes, remove access and permanently delete all data and content within +the workspace. diff --git a/docs/source/polars-cloud/workspace/team.md b/docs/source/polars-cloud/workspace/team.md new file mode 100644 index 000000000000..e4ac78a5f9ca --- /dev/null +++ b/docs/source/polars-cloud/workspace/team.md @@ -0,0 +1,32 @@ +# Manage workspace members + +The Team page serves as a central hub for managing who has access to the workspace. Administrators +are able to invite collaborators, monitor membership status, and manage user permissions. + +## Adding your team to the workspace + +### Email invitations + +The primary method for adding team members is through email invitations: + +1. Enter the email address of the person you wish to add in the provided input field +2. Click "Send email" to deliver an invitation directly to their inbox + +Invitees will receive in invitation email that they can use to register and get added to the +workspace. + +### Invitation links + +With an invite link you can share the personal invitation link with an invitee. + +1. Click "Generate invite link" to create a shareable URL +2. Copy and share the link to the invitee. + +The invitee can use this link to join the workspace without requiring an email invitation. + + + +!!! info "Inviting existing users" + Users that are already part of another workspace will get a prompt to join the workspace they are invited for. + + diff --git a/docs/source/pyproject.toml b/docs/source/pyproject.toml new file mode 100644 index 000000000000..d525f50502a4 --- /dev/null +++ b/docs/source/pyproject.toml @@ -0,0 +1,8 @@ +[tool.ruff] +fix = true + +[tool.ruff.lint] +ignore = [ + "E402", # Module level import not at top of file + "F811", # Redefinition of unused variable +] diff --git a/docs/source/releases/changelog.md b/docs/source/releases/changelog.md new file mode 100644 index 000000000000..d3afd8af6807 --- /dev/null +++ b/docs/source/releases/changelog.md @@ -0,0 +1,6 @@ +# Changelog + +Polars uses GitHub to manage both Python and Rust releases. + +Refer to our [GitHub releases page](https://github.com/pola-rs/polars/releases) for the changelog +associated with each new release. diff --git a/docs/source/releases/upgrade/0.19.md b/docs/source/releases/upgrade/0.19.md new file mode 100644 index 000000000000..68cf4ed064e1 --- /dev/null +++ b/docs/source/releases/upgrade/0.19.md @@ -0,0 +1,200 @@ +# Version 0.19 + +## Breaking changes + +### Aggregation functions no longer support horizontal computation + +This impacts aggregation functions like `sum`, `min`, and `max`. These functions were overloaded to +support both vertical and horizontal computation. Recently, new dedicated functionality for +horizontal computation was released, and horizontal computation was deprecated. + +Restore the old behavior by using the horizontal variant, e.g. `sum_horizontal`. + +**Example** + +Before: + +```shell +>>> df = pl.DataFrame({'a': [1, 2], 'b': [11, 12]}) +>>> df.select(pl.sum('a', 'b')) # horizontal computation +shape: (2, 1) +┌─────┐ +│ sum │ +│ --- │ +│ i64 │ +╞═════╡ +│ 12 │ +│ 14 │ +└─────┘ +``` + +After: + +```shell +>>> df = pl.DataFrame({'a': [1, 2], 'b': [11, 12]}) +>>> df.select(pl.sum('a', 'b')) # vertical computation +shape: (1, 2) +┌─────┬─────┐ +│ a ┆ b │ +│ --- ┆ --- │ +│ i64 ┆ i64 │ +╞═════╪═════╡ +│ 3 ┆ 23 │ +└─────┴─────┘ +``` + +### Update to `all` / `any` + +`all` will now ignore null values by default, rather than treat them as `False`. + +For both `any` and `all`, the `drop_nulls` parameter has been renamed to `ignore_nulls` and is now +keyword-only. Also fixed an issue when setting this parameter to `False` would erroneously result in +`None` output in some cases. + +To restore the old behavior, set `ignore_nulls` to `False` and check for `None` output. + +**Example** + +Before: + +```shell +>>> pl.Series([True, None]).all() +False +``` + +After: + +```shell +>>> pl.Series([True, None]).all() +True +``` + +### Improved error types for many methods + +Improving our error messages is an ongoing effort. We did a sweep of our Python code base and made +many improvements to error messages and error types. Most notably, many `ValueError`s were changed +to `TypeError`s. + +If your code relies on handling Polars exceptions, you may have to make some adjustments. + +**Example** + +Before: + +```shell +>>> pl.Series(values=15) +... +ValueError: Series constructor called with unsupported type; got 'int' +``` + +After: + +```shell +>>> pl.Series(values=15) +... +TypeError: Series constructor called with unsupported type 'int' for the `values` parameter +``` + +### Updates to expression input parsing + +Methods like `select` and `with_columns` accept one or more expressions. But they also accept +strings, integers, lists, and other inputs that we try to interpret as expressions. We updated our +internal logic to parse inputs more consistently. + +**Example** + +Before: + +```shell +>>> pl.DataFrame({'a': [1, 2]}).with_columns(None) +shape: (2, 1) +┌─────┐ +│ a │ +│ --- │ +│ i64 │ +╞═════╡ +│ 1 │ +│ 2 │ +└─────┘ +``` + +After: + +```shell +>>> pl.DataFrame({'a': [1, 2]}).with_columns(None) +shape: (2, 2) +┌─────┬─────────┐ +│ a ┆ literal │ +│ --- ┆ --- │ +│ i64 ┆ null │ +╞═════╪═════════╡ +│ 1 ┆ null │ +│ 2 ┆ null │ +└─────┴─────────┘ +``` + +### `shuffle` / `sample` now use an internal Polars seed + +If you used the built-in Python `random.seed` function to control the randomness of Polars +expressions, this will no longer work. Instead, use the new `set_random_seed` function. + +**Example** + +Before: + +```python +import random + +random.seed(1) +``` + +After: + +```python +import polars as pl + +pl.set_random_seed(1) +``` + +## Deprecations + +Creating a consistent and intuitive API is hard; finding the right name for each function, method, +and parameter might be the hardest part. The new version comes with several naming changes, and you +will most likely run into deprecation warnings when upgrading to `0.19`. + +If you want to upgrade without worrying about deprecation warnings right now, you can add the +following snippet to your code: + +```python +import warnings + +warnings.filterwarnings("ignore", category=DeprecationWarning) +``` + +### `groupby` renamed to `group_by` + +This is not a change we make lightly, as it will impact almost all our users. But "group by" is +really two different words, and our naming strategy dictates that these should be separated by an +underscore. + +Most likely, a simple search and replace will be enough to take care of this update: + +- Search: `.groupby(` +- Replace: `.group_by(` + +### `apply` renamed to `map_*` + +`apply` is probably the most misused part of our API. Many Polars users come from pandas, where +`apply` has a completely different meaning. + +We now consolidate all our functionality for user-defined functions under the name `map`. This +results in the following renaming: + +| Before | After | +| --------------------------- | -------------- | +| `Series/Expr.apply` | `map_elements` | +| `Series/Expr.rolling_apply` | `rolling_map` | +| `DataFrame.apply` | `map_rows` | +| `GroupBy.apply` | `map_groups` | +| `apply` | `map_groups` | +| `map` | `map_batches` | diff --git a/docs/source/releases/upgrade/0.20.md b/docs/source/releases/upgrade/0.20.md new file mode 100644 index 000000000000..58901798c8b4 --- /dev/null +++ b/docs/source/releases/upgrade/0.20.md @@ -0,0 +1,543 @@ +# Version 0.20 + +## Breaking changes + +### Change default `join` behavior with regard to null values + +Previously, null values in the join key were considered a value like any other value. This meant +that null values in the left frame would be joined with null values in the right frame. This is +expensive and does not match default behavior in SQL. + +Default behavior has now been changed to ignore null values in the join key. The previous behavior +can be retained by setting `join_nulls=True`. + +**Example** + +Before: + +```pycon +>>> df1 = pl.DataFrame({"a": [1, 2, None], "b": [4, 4, 4]}) +>>> df2 = pl.DataFrame({"a": [None, 2, 3], "c": [5, 5, 5]}) +>>> df1.join(df2, on="a", how="inner") +shape: (2, 3) +┌──────┬─────┬─────┐ +│ a ┆ b ┆ c │ +│ --- ┆ --- ┆ --- │ +│ i64 ┆ i64 ┆ i64 │ +╞══════╪═════╪═════╡ +│ null ┆ 4 ┆ 5 │ +│ 2 ┆ 4 ┆ 5 │ +└──────┴─────┴─────┘ +``` + +After: + +```pycon +>>> df1.join(df2, on="a", how="inner") +shape: (1, 3) +┌─────┬─────┬─────┐ +│ a ┆ b ┆ c │ +│ --- ┆ --- ┆ --- │ +│ i64 ┆ i64 ┆ i64 │ +╞═════╪═════╪═════╡ +│ 2 ┆ 4 ┆ 5 │ +└─────┴─────┴─────┘ +>>> df1.join(df2, on="a", how="inner", nulls_equal=True) # Keeps previous behavior +shape: (2, 3) +┌──────┬─────┬─────┐ +│ a ┆ b ┆ c │ +│ --- ┆ --- ┆ --- │ +│ i64 ┆ i64 ┆ i64 │ +╞══════╪═════╪═════╡ +│ null ┆ 4 ┆ 5 │ +│ 2 ┆ 4 ┆ 5 │ +└──────┴─────┴─────┘ +``` + +### Preserve left and right join keys in outer joins + +Previously, the result of an outer join did not contain the join keys of the left and right frames. +Rather, it contained a coalesced version of the left key and right key. This loses information and +does not conform to default SQL behavior. + +The behavior has been changed to include the original join keys. Name clashes are solved by +appending a suffix (`_right` by default) to the right join key name. The previous behavior can be +retained by setting `how="outer_coalesce"`. + +**Example** + +Before: + +```pycon +>>> df1 = pl.DataFrame({"L1": ["a", "b", "c"], "L2": [1, 2, 3]}) +>>> df2 = pl.DataFrame({"L1": ["a", "c", "d"], "R2": [7, 8, 9]}) +>>> df1.join(df2, on="L1", how="outer") +shape: (4, 3) +┌─────┬──────┬──────┐ +│ L1 ┆ L2 ┆ R2 │ +│ --- ┆ --- ┆ --- │ +│ str ┆ i64 ┆ i64 │ +╞═════╪══════╪══════╡ +│ a ┆ 1 ┆ 7 │ +│ c ┆ 3 ┆ 8 │ +│ d ┆ null ┆ 9 │ +│ b ┆ 2 ┆ null │ +└─────┴──────┴──────┘ +``` + +After: + +```pycon +>>> df1.join(df2, on="L1", how="outer") +shape: (4, 4) +┌──────┬──────┬──────────┬──────┐ +│ L1 ┆ L2 ┆ L1_right ┆ R2 │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ i64 ┆ str ┆ i64 │ +╞══════╪══════╪══════════╪══════╡ +│ a ┆ 1 ┆ a ┆ 7 │ +│ b ┆ 2 ┆ null ┆ null │ +│ c ┆ 3 ┆ c ┆ 8 │ +│ null ┆ null ┆ d ┆ 9 │ +└──────┴──────┴──────────┴──────┘ +>>> df1.join(df2, on="a", how="outer_coalesce") # Keeps previous behavior +shape: (4, 3) +┌─────┬──────┬──────┐ +│ L1 ┆ L2 ┆ R2 │ +│ --- ┆ --- ┆ --- │ +│ str ┆ i64 ┆ i64 │ +╞═════╪══════╪══════╡ +│ a ┆ 1 ┆ 7 │ +│ c ┆ 3 ┆ 8 │ +│ d ┆ null ┆ 9 │ +│ b ┆ 2 ┆ null │ +└─────┴──────┴──────┘ +``` + +### `count` now ignores null values + +The `count` method for `Expr` and `Series` now ignores null values. Use `len` to get the count with +null values included. + +Note that `pl.count()` and `group_by(...).count()` are unchanged. These count the number of rows in +the context, so nulls are not applicable in the same way. + +This brings behavior more in line with the SQL standard, where `COUNT(col)` ignores null values but +`COUNT(*)` counts rows regardless of null values. + +**Example** + +Before: + +```pycon +>>> df = pl.DataFrame({"a": [1, 2, None]}) +>>> df.select(pl.col("a").count()) +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ u32 │ +╞═════╡ +│ 3 │ +└─────┘ +``` + +After: + +```pycon +>>> df.select(pl.col("a").count()) +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ u32 │ +╞═════╡ +│ 2 │ +└─────┘ +>>> df.select(pl.col("a").len()) # Mirrors previous behavior +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ u32 │ +╞═════╡ +│ 3 │ +└─────┘ +``` + +### `NaN` values are now considered equal + +Floating point `NaN` values were treated as unequal across Polars operations. This has been +corrected to better match user expectation and existing standards. + +While this is considered a bug fix, it is included in this guide in order to draw attention to +possible impact on user workflows that may contain `NaN` values. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([1.0, float("nan"), float("inf")]) +>>> s == s +shape: (3,) +Series: '' [bool] +[ + true + false + true +] +``` + +After: + +```pycon +>>> s == s +shape: (3,) +Series: '' [bool] +[ + true + true + true +] +``` + +### Assertion utils updates to exact checking and `NaN` equality + +The assertion utility functions `assert_frame_equal` and `assert_series_equal` would use the +tolerance parameters `atol` and `rtol` to do approximate checking, unless `check_exact` was set to +`True`. This could lead to some surprising behavior, as integers are generally thought of as exact +values. Integer values are now always checked exactly. To do inexact checking, convert to float +first. + +Additionally, the `nans_compare_equal` parameter has been removed and `NaN` values are now always +considered equal, which was the previous default behavior. This parameter had previously been +deprecated but has been removed before the end of the standard deprecation period to facilitate the +change to `NaN` equality. + +**Example** + +Before: + +```pycon +>>> from polars.testing import assert_frame_equal +>>> df1 = pl.DataFrame({"id": [123456]}) +>>> df2 = pl.DataFrame({"id": [123457]}) +>>> assert_frame_equal(df1, df2) # Passes +``` + +After: + +```pycon +>>> assert_frame_equal(df1, df2) +... +AssertionError: DataFrames are different (value mismatch for column 'id') +[left]: [123456] +[right]: [123457] +``` + +### Allow all `DataType` objects to be instantiated + +Polars data types are subclasses of the `DataType` class. We had a 'hack' in place that +automatically converted data types instantiated without any arguments to the `class`, rather than +actually instantiating it. The idea was to allow specifying data types as `Int64` rather than +`Int64()`, which is more succinct. However, this caused some unexpected behavior when working +directly with data type objects, especially as there was a discrepancy with data types like +`Datetime` which _will_ be instantiated in many cases. + +Going forward, instantiating a data type will always return an instance of that class. Both classes +an instances are handled by Polars, so the previous short syntax is still available. Methods that +return data types like `Series.dtype` and `DataFrame.schema` now always return instantiated data +types objects. + +You may have to update some of your data type checks if you were not already using the equality +operator (`==`), as well as update some type hints. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([1, 2, 3], dtype=pl.Int8) +>>> s.dtype == pl.Int8 +True +>>> s.dtype is pl.Int8 +True +>>> isinstance(s.dtype, pl.Int8) +False +``` + +After: + +```pycon +>>> s.dtype == pl.Int8 +True +>>> s.dtype is pl.Int8 +False +>>> isinstance(s.dtype, pl.Int8) +True +``` + +### Update constructors for `Decimal` and `Array` data types + +The data types `Decimal` and `Array` have had their parameters switched around. The new constructors +should more closely match user expectations. + +**Example** + +Before: + +```pycon +>>> pl.Array(2, pl.Int16) +Array(Int16, 2) +>>> pl.Decimal(5, 10) +Decimal(precision=10, scale=5) +``` + +After: + +```pycon +>>> pl.Array(pl.Int16, 2) +Array(Int16, width=2) +>>> pl.Decimal(10, 5) +Decimal(precision=10, scale=5) +``` + +### `DataType.is_nested` changed from a property to a class method + +This is a minor change, but a very important one to properly update. Failure to update accordingly +may result in faulty logic, as Python will evaluate the _method_ to `True`. For example, +`if dtype.is_nested` will now evaluate to `True` regardless of the data type, because it returns the +method, which Python considers truthy. + +**Example** + +Before: + +``` +>>> pl.List(pl.Int8).is_nested +True +``` + +After: + +``` +>>> pl.List(pl.Int8).is_nested() +True +``` + +### Smaller integer data types for datetime components `dt.month`, `dt.week` + +Most datetime components such as `month` and `week` would previously return a `UInt32` type. This +has been updated to the smallest appropriate signed integer type. This should reduce memory +consumption. + +| Method | Dtype old | Dtype new | +| ----------- | --------- | --------- | +| year | i32 | i32 | +| iso_year | i32 | i32 | +| quarter | u32 | i8 | +| month | u32 | i8 | +| week | u32 | i8 | +| day | u32 | i8 | +| weekday | u32 | i8 | +| ordinal_day | u32 | i16 | +| hour | u32 | i8 | +| minute | u32 | i8 | +| second | u32 | i8 | +| millisecond | u32 | i32* | +| microsecond | u32 | i32 | +| nanosecond | u32 | i32 | + +_*Technically, `millisecond` can be an `i16`. This may be updated in the future._ + +**Example** + +Before: + +```pycon +>>> from datetime import date +>>> s = pl.Series([date(2023, 12, 31), date(2024, 1, 1)]) +>>> s.dt.month() +shape: (2,) +Series: '' [u32] +[ + 12 + 1 +] +``` + +After: + +```pycon +>>> s.dt.month() +shape: (2,) +Series: '' [u8] +[ + 12 + 1 +] +``` + +### Series now defaults to `Null` data type when no data is present + +This replaces the previous behavior of initializing as a `Float32` type. + +**Example** + +Before: + +```pycon +>>> pl.Series("a", [None]) +shape: (1,) +Series: 'a' [f32] +[ + null +] +``` + +After: + +```pycon +>>> pl.Series("a", [None]) +shape: (1,) +Series: 'a' [null] +[ + null +] +``` + +### `replace` reimplemented with slightly different behavior + +The new implementation is mostly backwards compatible. Please do note the following: + +1. The logic for determining the return data type has changed. You may want to specify + `return_dtype` to override the inferred data type, or take advantage of the new function + signature (separate `old` and `new` parameters) to influence the return type. +2. The previous workaround for referencing other columns as default by using a struct column no + longer works. It now simply works as expected, no workaround needed. + +**Example** + +Before: + +```pycon +>>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}, schema={"a": pl.Int8, "b": pl.Float64}) +>>> df.select(pl.col("a").replace({2: 100})) +shape: (4, 1) +┌─────┐ +│ a │ +│ --- │ +│ i8 │ +╞═════╡ +│ 1 │ +│ 100 │ +│ 100 │ +│ 3 │ +└─────┘ +>>> df.select(pl.struct("a", "b").replace({2: 100}, default=pl.col("b"))) +shape: (4, 1) +┌───────┐ +│ a │ +│ --- │ +│ f64 │ +╞═══════╡ +│ 1.5 │ +│ 100.0 │ +│ 100.0 │ +│ 1.0 │ +└───────┘ +``` + +After: + +```pycon +>>> df.select(pl.col("a").replace({2: 100})) +shape: (4, 1) +┌─────┐ +│ a │ +│ --- │ +│ i64 │ +╞═════╡ +│ 1 │ +│ 100 │ +│ 100 │ +│ 3 │ +└─────┘ +>>> df.select(pl.col("a").replace({2: 100}, default=pl.col("b"))) # No struct needed +shape: (4, 1) +┌───────┐ +│ a │ +│ --- │ +│ f64 │ +╞═══════╡ +│ 1.5 │ +│ 100.0 │ +│ 100.0 │ +│ 1.0 │ +└───────┘ +``` + +### `value_counts` resulting column renamed from `counts` to `count` + +The resulting struct field for the `value_counts` method has been renamed from `counts` to `count`. + +**Example** + +Before: + +```pycon +>>> s = pl.Series("a", ["x", "x", "y"]) +>>> s.value_counts() +shape: (2, 2) +┌─────┬────────┐ +│ a ┆ counts │ +│ --- ┆ --- │ +│ str ┆ u32 │ +╞═════╪════════╡ +│ x ┆ 2 │ +│ y ┆ 1 │ +└─────┴────────┘ +``` + +After: + +```pycon +>>> s.value_counts() +shape: (2, 2) +┌─────┬───────┐ +│ a ┆ count │ +│ --- ┆ --- │ +│ str ┆ u32 │ +╞═════╪═══════╡ +│ x ┆ 2 │ +│ y ┆ 1 │ +└─────┴───────┘ +``` + +### Update `read_parquet` to use Object Store rather than fsspec + +If you were using `read_parquet`, installing `fsspec` as an optional dependency is no longer +required. The new Object Store implementation was already in use for `scan_parquet`. It may have +slightly different behavior in certain cases, such as how credentials are detected and how downloads +are performed. + +The resulting `DataFrame` should be identical between versions. + +## Deprecations + +### Cumulative functions renamed from `cum*` to `cum_*` + +Technically, this deprecation was introduced in version `0.19.14`, but many users will first +encounter it when upgrading to `0.20`. It's a relatively impactful change, which is why we mention +it here. + +| Old name | New name | +| ----------- | ------------ | +| `cumfold` | `cum_fold` | +| `cumreduce` | `cum_reduce` | +| `cumsum` | `cum_sum` | +| `cumprod` | `cum_prod` | +| `cummin` | `cum_min` | +| `cummax` | `cum_max` | +| `cumcount` | `cum_count` | diff --git a/docs/source/releases/upgrade/1.md b/docs/source/releases/upgrade/1.md new file mode 100644 index 000000000000..8caaad765f8b --- /dev/null +++ b/docs/source/releases/upgrade/1.md @@ -0,0 +1,1071 @@ +# Version 1 + +## Breaking changes + +### Properly apply `strict` parameter in Series constructor + +The behavior of the Series constructor has been updated. Generally, it will be more strict, unless +the user passes `strict=False`. + +Strict construction is more efficient than non-strict construction, so make sure to pass values of +the same data type to the constructor for the best performance. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([1, 2, 3.5]) +shape: (3,) +Series: '' [f64] +[ + 1.0 + 2.0 + 3.5 +] +>>> s = pl.Series([1, 2, 3.5], strict=False) +shape: (3,) +Series: '' [i64] +[ + 1 + 2 + null +] +>>> s = pl.Series([1, 2, 3.5], strict=False, dtype=pl.Int8) +Series: '' [i8] +[ + 1 + 2 + null +] +``` + +After: + +```pycon +>>> s = pl.Series([1, 2, 3.5]) +Traceback (most recent call last): +... +TypeError: unexpected value while building Series of type Int64; found value of type Float64: 3.5 + +Hint: Try setting `strict=False` to allow passing data with mixed types. +>>> s = pl.Series([1, 2, 3.5], strict=False) +shape: (3,) +Series: '' [f64] +[ + 1.0 + 2.0 + 3.5 +] +>>> s = pl.Series([1, 2, 3.5], strict=False, dtype=pl.Int8) +Series: '' [i8] +[ + 1 + 2 + 3 +] +``` + +### Change data orientation inference logic for DataFrame construction + +Polars no longer inspects data types to infer the orientation of the data passed to the DataFrame +constructor. Data orientation is inferred based on the data and schema dimensions. + +Additionally, a warning is raised whenever row orientation is inferred. Because of some confusing +edge cases, users should pass `orient="row"` to make explicit that their input is row-based. + +**Example** + +Before: + +```pycon +>>> data = [[1, "a"], [2, "b"]] +>>> pl.DataFrame(data) +shape: (2, 2) +┌──────────┬──────────┐ +│ column_0 ┆ column_1 │ +│ --- ┆ --- │ +│ i64 ┆ str │ +╞══════════╪══════════╡ +│ 1 ┆ a │ +│ 2 ┆ b │ +└──────────┴──────────┘ +``` + +After: + +```pycon +>>> pl.DataFrame(data) +Traceback (most recent call last): +... +TypeError: unexpected value while building Series of type Int64; found value of type String: "a" + +Hint: Try setting `strict=False` to allow passing data with mixed types. +``` + +Use instead: + +```pycon +>>> pl.DataFrame(data, orient="row") +shape: (2, 2) +┌──────────┬──────────┐ +│ column_0 ┆ column_1 │ +│ --- ┆ --- │ +│ i64 ┆ str │ +╞══════════╪══════════╡ +│ 1 ┆ a │ +│ 2 ┆ b │ +└──────────┴──────────┘ +``` + +### Consistently convert to given time zone in Series constructor + +!!! danger + + This change may silently impact the results of your pipelines. + If you work with time zones, please make sure to account for this change. + +Handling of time zone information in the Series and DataFrame constructors was inconsistent. +Row-wise construction would convert to the given time zone, while column-wise construction would +_replace_ the time zone. The inconsistency has been fixed by always converting to the time zone +specified in the data type. + +**Example** + +Before: + +```pycon +>>> from datetime import datetime +>>> pl.Series([datetime(2020, 1, 1)], dtype=pl.Datetime('us', 'Europe/Amsterdam')) +shape: (1,) +Series: '' [datetime[μs, Europe/Amsterdam]] +[ + 2020-01-01 00:00:00 CET +] +``` + +After: + +```pycon +>>> from datetime import datetime +>>> pl.Series([datetime(2020, 1, 1)], dtype=pl.Datetime('us', 'Europe/Amsterdam')) +shape: (1,) +Series: '' [datetime[μs, Europe/Amsterdam]] +[ + 2020-01-01 01:00:00 CET +] +``` + +### Update some error types to more appropriate variants + +We have updated a lot of error types to more accurately represent the problem. Most commonly, +`ComputeError` types were changed to `InvalidOperationError` or `SchemaError`. + +**Example** + +Before: + +```pycon +>>> s = pl.Series("a", [100, 200, 300]) +>>> s.cast(pl.UInt8) +Traceback (most recent call last): +... +polars.exceptions.ComputeError: conversion from `i64` to `u8` failed in column 'a' for 1 out of 3 values: [300] +``` + +After: + +```pycon +>>> s.cast(pl.UInt8) +Traceback (most recent call last): +... +polars.exceptions.InvalidOperationError: conversion from `i64` to `u8` failed in column 'a' for 1 out of 3 values: [300] +``` + +### Update `read/scan_parquet` to disable Hive partitioning by default for file inputs + +Parquet reading functions now also support directory inputs. Hive partitioning is enabled by default +for directories, but is now _disabled_ by default for file inputs. File inputs include single files, +globs, and lists of files. Explicitly pass `hive_partitioning=True` to restore previous behavior. + +**Example** + +Before: + +```pycon +>>> pl.read_parquet("dataset/a=1/foo.parquet") +shape: (2, 2) +┌─────┬─────┐ +│ a ┆ x │ +│ --- ┆ --- │ +│ i64 ┆ f64 │ +╞═════╪═════╡ +│ 1 ┆ 1.0 │ +│ 1 ┆ 2.0 │ +└─────┴─────┘ +``` + +After: + +```pycon +>>> pl.read_parquet("dataset/a=1/foo.parquet") +shape: (2, 1) +┌─────┐ +│ x │ +│ --- │ +│ f64 │ +╞═════╡ +│ 1.0 │ +│ 2.0 │ +└─────┘ +>>> pl.read_parquet("dataset/a=1/foo.parquet", hive_partitioning=True) +shape: (2, 2) +┌─────┬─────┐ +│ a ┆ x │ +│ --- ┆ --- │ +│ i64 ┆ f64 │ +╞═════╪═════╡ +│ 1 ┆ 1.0 │ +│ 1 ┆ 2.0 │ +└─────┴─────┘ +``` + +### Update `reshape` to return Array types instead of List types + +`reshape` now returns an Array type instead of a List type. + +Users can restore the old functionality by calling `.arr.to_list()` on the output. Note that this is +not more expensive than it would be to create a List type directly, because reshaping into an array +is basically free. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([1, 2, 3, 4, 5, 6]) +>>> s.reshape((2, 3)) +shape: (2,) +Series: '' [list[i64]] +[ + [1, 2, 3] + [4, 5, 6] +] +``` + +After: + +```pycon +>>> s.reshape((2, 3)) +shape: (2,) +Series: '' [array[i64, 3]] +[ + [1, 2, 3] + [4, 5, 6] +] +``` + +### Read 2D NumPy arrays as `Array` type instead of `List` + +The Series constructor now parses 2D NumPy arrays as an `Array` type rather than a `List` type. + +**Example** + +Before: + +```pycon +>>> import numpy as np +>>> arr = np.array([[1, 2], [3, 4]]) +>>> pl.Series(arr) +shape: (2,) +Series: '' [list[i64]] +[ + [1, 2] + [3, 4] +] +``` + +After: + +```pycon +>>> import numpy as np +>>> arr = np.array([[1, 2], [3, 4]]) +>>> pl.Series(arr) +shape: (2,) +Series: '' [array[i64, 2]] +[ + [1, 2] + [3, 4] +] +``` + +### Split `replace` functionality into two separate methods + +The API for `replace` has proven to be confusing to many users, particularly with regards to the +`default` argument and the resulting data type. + +It has been split up into two methods: `replace` and `replace_strict`. `replace` now always keeps +the existing data type _(breaking, see example below)_ and is meant for replacing some values in +your existing column. Its parameters `default` and `return_dtype` have been deprecated. + +The new method `replace_strict` is meant for creating a new column, mapping some or all of the +values of the original column, and optionally specifying a default value. If no default is provided, +it raises an error if any non-null values are not mapped. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([1, 2, 3]) +>>> s.replace(1, "a") +shape: (3,) +Series: '' [str] +[ + "a" + "2" + "3" +] +``` + +After: + +```pycon +>>> s.replace(1, "a") +Traceback (most recent call last): +... +polars.exceptions.InvalidOperationError: conversion from `str` to `i64` failed in column 'literal' for 1 out of 1 values: ["a"] +>>> s.replace_strict(1, "a", default=s) +shape: (3,) +Series: '' [str] +[ + "a" + "2" + "3" +] +``` + +### Preserve nulls in `ewm_mean`, `ewm_std`, and `ewm_var` + +Polars will no longer forward-fill null values in `ewm` methods. The user can call `.forward_fill()` +on the output to achieve the same result. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([1, 4, None, 3]) +>>> s.ewm_mean(alpha=.9, ignore_nulls=False) +shape: (4,) +Series: '' [f64] +[ + 1.0 + 3.727273 + 3.727273 + 3.007913 +] +``` + +After: + +```pycon +>>> s.ewm_mean(alpha=.9, ignore_nulls=False) +shape: (4,) +Series: '' [f64] +[ + 1.0 + 3.727273 + null + 3.007913 +] +``` + +### Update `clip` to no longer propagate nulls in the given bounds + +Null values in the bounds no longer set the value to null - instead, the original value is retained. + +**Before** + +```pycon +>>> df = pl.DataFrame({"a": [0, 1, 2], "min": [1, None, 1]}) +>>> df.select(pl.col("a").clip("min")) +shape: (3, 1) +┌──────┐ +│ a │ +│ --- │ +│ i64 │ +╞══════╡ +│ 1 │ +│ null │ +│ 2 │ +└──────┘ +``` + +**After** + +```pycon +>>> df.select(pl.col("a").clip("min")) +shape: (3, 1) +┌──────┐ +│ a │ +│ --- │ +│ i64 │ +╞══════╡ +│ 1 │ +│ 1 │ +│ 2 │ +└──────┘ +``` + +### Change `str.to_datetime` to default to microsecond precision for format specifiers `"%f"` and `"%.f"` + +In `.str.to_datetime`, when specifying `%.f` as the format, the default was to set the resulting +datatype to nanosecond precision. This has been changed to microsecond precision. + +#### Example + +**Before** + +```pycon +>>> s = pl.Series(["2022-08-31 00:00:00.123456789"]) +>>> s.str.to_datetime(format="%Y-%m-%d %H:%M:%S%.f") +shape: (1,) +Series: '' [datetime[ns]] +[ + 2022-08-31 00:00:00.123456789 +] +``` + +**After** + +```pycon +>>> s.str.to_datetime(format="%Y-%m-%d %H:%M:%S%.f") +shape: (1,) +Series: '' [datetime[us]] +[ + 2022-08-31 00:00:00.123456 +] +``` + +### Update resulting column names in `pivot` when pivoting by multiple values + +In `DataFrame.pivot`, when specifying multiple `values` columns, the result would redundantly +include the `column` column in the column names. This has been addressed. + +**Example** + +Before: + +```python +>>> df = pl.DataFrame( +... { +... "name": ["Cady", "Cady", "Karen", "Karen"], +... "subject": ["maths", "physics", "maths", "physics"], +... "test_1": [98, 99, 61, 58], +... "test_2": [100, 100, 60, 60], +... } +... ) +>>> df.pivot(index='name', columns='subject', values=['test_1', 'test_2']) +shape: (2, 5) +┌───────┬──────────────────────┬────────────────────────┬──────────────────────┬────────────────────────┐ +│ name ┆ test_1_subject_maths ┆ test_1_subject_physics ┆ test_2_subject_maths ┆ test_2_subject_physics │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═══════╪══════════════════════╪════════════════════════╪══════════════════════╪════════════════════════╡ +│ Cady ┆ 98 ┆ 99 ┆ 100 ┆ 100 │ +│ Karen ┆ 61 ┆ 58 ┆ 60 ┆ 60 │ +└───────┴──────────────────────┴────────────────────────┴──────────────────────┴────────────────────────┘ +``` + +After: + +```python +>>> df = pl.DataFrame( +... { +... "name": ["Cady", "Cady", "Karen", "Karen"], +... "subject": ["maths", "physics", "maths", "physics"], +... "test_1": [98, 99, 61, 58], +... "test_2": [100, 100, 60, 60], +... } +... ) +>>> df.pivot('subject', index='name') +┌───────┬──────────────┬────────────────┬──────────────┬────────────────┐ +│ name ┆ test_1_maths ┆ test_1_physics ┆ test_2_maths ┆ test_2_physics │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═══════╪══════════════╪════════════════╪══════════════╪════════════════╡ +│ Cady ┆ 98 ┆ 99 ┆ 100 ┆ 100 │ +│ Karen ┆ 61 ┆ 58 ┆ 60 ┆ 60 │ +└───────┴──────────────┴────────────────┴──────────────┴────────────────┘ +``` + +Note that the function signature has also changed: + +- `columns` has been renamed to `on`, and is now the first positional argument. +- `index` and `values` are both optional. If `index` is not specified, then it will use all columns + not specified in `on` and `values`. If `values` is not specified, it will use all columns not + specified in `on` and `index`. + +### Support Decimal types by default when converting from Arrow + +Update conversion from Arrow to always convert Decimals into Polars Decimals, rather than cast to +Float64. `Config.activate_decimals` has been removed. + +**Example** + +Before: + +```pycon +>>> from decimal import Decimal as D +>>> import pyarrow as pa +>>> arr = pa.array([D("1.01"), D("2.25")]) +>>> pl.from_arrow(arr) +shape: (2,) +Series: '' [f64] +[ + 1.01 + 2.25 +] +``` + +After: + +```pycon +>>> pl.from_arrow(arr) +shape: (2,) +Series: '' [decimal[3,2]] +[ + 1.01 + 2.25 +] +``` + +### Remove serde functionality from `pl.read_json` and `DataFrame.write_json` + +`pl.read_json` no longer supports reading JSON files produced by `DataFrame.serialize`. Users should +use `pl.DataFrame.deserialize` instead. + +`DataFrame.write_json` now only writes row-oriented JSON. The parameters `row_oriented` and `pretty` +have been removed. Users should use `DataFrame.serialize` to serialize a DataFrame. + +**Example - `write_json`** + +Before: + +```pycon +>>> df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) +>>> df.write_json() +'{"columns":[{"name":"a","datatype":"Int64","bit_settings":"","values":[1,2]},{"name":"b","datatype":"Float64","bit_settings":"","values":[3.0,4.0]}]}' +``` + +After: + +```pycon +>>> df.write_json() # Same behavior as previously `df.write_json(row_oriented=True)` +'[{"a":1,"b":3.0},{"a":2,"b":4.0}]' +``` + +**Example - `read_json`** + +Before: + +```pycon +>>> import io +>>> df_ser = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"","values":[1,2]},{"name":"b","datatype":"Float64","bit_settings":"","values":[3.0,4.0]}]}' +>>> pl.read_json(io.StringIO(df_ser)) +shape: (2, 2) +┌─────┬─────┐ +│ a ┆ b │ +│ --- ┆ --- │ +│ i64 ┆ f64 │ +╞═════╪═════╡ +│ 1 ┆ 3.0 │ +│ 2 ┆ 4.0 │ +└─────┴─────┘ +``` + +After: + +```pycon +>>> pl.read_json(io.StringIO(df_ser)) # Format no longer supported: data is treated as a single row +shape: (1, 1) +┌─────────────────────────────────┐ +│ columns │ +│ --- │ +│ list[struct[4]] │ +╞═════════════════════════════════╡ +│ [{"a","Int64","",[1.0, 2.0]}, … │ +└─────────────────────────────────┘ +``` + +Use instead: + +```pycon +>>> pl.DataFrame.deserialize(io.StringIO(df_ser)) +shape: (2, 2) +┌─────┬─────┐ +│ a ┆ b │ +│ --- ┆ --- │ +│ i64 ┆ f64 │ +╞═════╪═════╡ +│ 1 ┆ 3.0 │ +│ 2 ┆ 4.0 │ +└─────┴─────┘ +``` + +### `Series.equals` no longer checks names by default + +Previously, `Series.equals` would return `False` if the Series names didn't match. The method now no +longer checks the names by default. The previous behavior can be retained by setting +`check_names=True`. + +**Example** + +Before: + +```pycon +>>> s1 = pl.Series("foo", [1, 2, 3]) +>>> s2 = pl.Series("bar", [1, 2, 3]) +>>> s1.equals(s2) +False +``` + +After: + +```pycon +>>> s1.equals(s2) +True +>>> s1.equals(s2, check_names=True) +False +``` + +### Remove `columns` parameter from `nth` expression function + +The `columns` parameter was removed in favor of treating positional inputs as additional indices. +Use `Expr.get` instead to get the same functionality. + +**Example** + +Before: + +```pycon +>>> df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) +>>> df.select(pl.nth(1, "a")) +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ i64 │ +╞═════╡ +│ 2 │ +└─────┘ +``` + +After: + +```pycon +>>> df.select(pl.nth(1, "a")) +... +TypeError: argument 'indices': 'str' object cannot be interpreted as an integer +``` + +Use instead: + +```pycon +>>> df.select(pl.col("a").get(1)) +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ i64 │ +╞═════╡ +│ 2 │ +└─────┘ +``` + +### Rename struct fields of `rle` output + +The struct fields of the `rle` method have been renamed from `lengths/values` to `len/value`. The +data type of the `len` field has also been updated to match the index type (was previously `Int32`, +now `UInt32`). + +**Before** + +```pycon +>>> s = pl.Series(["a", "a", "b", "c", "c", "c"]) +>>> s.rle().struct.unnest() +shape: (3, 2) +┌─────────┬────────┐ +│ lengths ┆ values │ +│ --- ┆ --- │ +│ i32 ┆ str │ +╞═════════╪════════╡ +│ 2 ┆ a │ +│ 1 ┆ b │ +│ 3 ┆ c │ +└─────────┴────────┘ +``` + +**After** + +```pycon +>>> s.rle().struct.unnest() +shape: (3, 2) +┌─────┬───────┐ +│ len ┆ value │ +│ --- ┆ --- │ +│ u32 ┆ str │ +╞═════╪═══════╡ +│ 2 ┆ a │ +│ 1 ┆ b │ +│ 3 ┆ c │ +└─────┴───────┘ +``` + +### Update `set_sorted` to only accept a single column + +Calling `set_sorted` indicates that a column is sorted _individually_. Passing multiple columns +indicates that each of those columns are also sorted individually. However, many users assumed this +meant that the columns were sorted as a group, which led to incorrect results. + +To help users avoid this pitfall, we removed the possibility to specify multiple columns in +`set_sorted`. To set multiple columns as sorted, simply call `set_sorted` multiple times. + +**Example** + +Before: + +```pycon +>>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0], "c": [9, 7, 8]}) +>>> df.set_sorted("a", "b") +``` + +After: + +```pycon +>>> df.set_sorted("a", "b") +Traceback (most recent call last): +... +TypeError: DataFrame.set_sorted() takes 2 positional arguments but 3 were given +``` + +Use instead: + +```pycon +>>> df.set_sorted("a").set_sorted("b") +``` + +### Default to raising on out-of-bounds indices in all `get`/`gather` operations + +The default behavior was inconsistent between `get` and `gather` operations in various places. Now +all such operations will raise by default. Pass `null_on_oob=True` to restore previous behavior. + +**Example** + +Before: + +```pycon +>>> s = pl.Series([[0, 1, 2], [0]]) +>>> s.list.get(1) +shape: (2,) +Series: '' [i64] +[ + 1 + null +] +``` + +After: + +```pycon +>>> s.list.get(1) +Traceback (most recent call last): +... +polars.exceptions.ComputeError: get index is out of bounds +``` + +Use instead: + +```pycon +>>> s.list.get(1, null_on_oob=True) +shape: (2,) +Series: '' [i64] +[ + 1 + null +] +``` + +### Change default engine for `read_excel` to `"calamine"` + +The `calamine` engine (available through the `fastexcel` package) has been added to Polars +relatively recently. It's much faster than the other engines, and was already the default for `xlsb` +and `xls` files. We now made it the default for all Excel files. + +There may be subtle differences between this engine and the previous default (`xlsx2csv`). One clear +difference is that the `calamine` engine does not support the `engine_options` parameter. If you +cannot get your desired behavior with the `calamine` engine, specify `engine="xlsx2csv"` to restore +previous behavior. + +### Example + +Before: + +```pycon +>>> pl.read_excel("data.xlsx", engine_options={"skip_empty_lines": True}) +``` + +After: + +```pycon +>>> pl.read_excel("data.xlsx", engine_options={"skip_empty_lines": True}) +Traceback (most recent call last): +... +TypeError: read_excel() got an unexpected keyword argument 'skip_empty_lines' +``` + +Instead, explicitly specify the `xlsx2csv` engine or omit the `engine_options`: + +```pycon +>>> pl.read_excel("data.xlsx", engine="xlsx2csv", engine_options={"skip_empty_lines": True}) +``` + +### Remove class variables from some DataTypes + +Some DataType classes had class variables. The `Datetime` class, for example, had `time_unit` and +`time_zone` as class variables. This was unintended: these should have been instance variables. This +has now been corrected. + +**Example** + +Before: + +```pycon +>>> dtype = pl.Datetime +>>> dtype.time_unit is None +True +``` + +After: + +```pycon +>>> dtype.time_unit is None +Traceback (most recent call last): +... +AttributeError: type object 'Datetime' has no attribute 'time_unit' +``` + +Use instead: + +```pycon +>>> getattr(dtype, "time_unit", None) is None +True +``` + +### Change default `offset` in `group_by_dynamic` from 'negative `every`' to 'zero' + +This affects the start of the first window in `group_by_dynamic`. The new behavior should align more +with user expectations. + +**Example** + +Before: + +```pycon +>>> from datetime import date +>>> df = pl.DataFrame({ +... "ts": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], +... "value": [1, 2, 3], +... }) +>>> df.group_by_dynamic("ts", every="1d", period="2d").agg("value") +shape: (4, 2) +┌────────────┬───────────┐ +│ ts ┆ value │ +│ --- ┆ --- │ +│ date ┆ list[i64] │ +╞════════════╪═══════════╡ +│ 2019-12-31 ┆ [1] │ +│ 2020-01-01 ┆ [1, 2] │ +│ 2020-01-02 ┆ [2, 3] │ +│ 2020-01-03 ┆ [3] │ +└────────────┴───────────┘ +``` + +After: + +```pycon +>>> df.group_by_dynamic("ts", every="1d", period="2d").agg("value") +shape: (3, 2) +┌────────────┬───────────┐ +│ ts ┆ value │ +│ --- ┆ --- │ +│ date ┆ list[i64] │ +╞════════════╪═══════════╡ +│ 2020-01-01 ┆ [1, 2] │ +│ 2020-01-02 ┆ [2, 3] │ +│ 2020-01-03 ┆ [3] │ +└────────────┴───────────┘ +``` + +### Change default serialization format of `LazyFrame/DataFrame/Expr` + +The only serialization format available for the `serialize/deserialize` methods on Polars objects +was JSON. We added a more optimized binary format and made this the default. JSON serialization is +still available by passing `format="json"`. + +**Example** + +Before: + +```pycon +>>> lf = pl.LazyFrame({"a": [1, 2, 3]}).sum() +>>> serialized = lf.serialize() +>>> serialized +'{"MapFunction":{"input":{"DataFrameScan":{"df":{"columns":[{"name":...' +>>> from io import StringIO +>>> pl.LazyFrame.deserialize(StringIO(serialized)).collect() +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ i64 │ +╞═════╡ +│ 6 │ +└─────┘ +``` + +After: + +```pycon +>>> lf = pl.LazyFrame({"a": [1, 2, 3]}).sum() +>>> serialized = lf.serialize() +>>> serialized +b'\xa1kMapFunction\xa2einput\xa1mDataFrameScan\xa4bdf...' +>>> from io import BytesIO # Note: using BytesIO instead of StringIO +>>> pl.LazyFrame.deserialize(BytesIO(serialized)).collect() +shape: (1, 1) +┌─────┐ +│ a │ +│ --- │ +│ i64 │ +╞═════╡ +│ 6 │ +└─────┘ +``` + +### Constrain access to globals from `DataFrame.sql` in favor of `pl.sql` + +The `sql` methods on `DataFrame` and `LazyFrame` can no longer access global variables. These +methods should be used for operating on the frame itself. For global access, there is now the +top-level `sql` function. + +**Example** + +Before: + +```pycon +>>> df1 = pl.DataFrame({"id1": [1, 2]}) +>>> df2 = pl.DataFrame({"id2": [3, 4]}) +>>> df1.sql("SELECT * FROM df1 CROSS JOIN df2") +shape: (4, 2) +┌─────┬─────┐ +│ id1 ┆ id2 │ +│ --- ┆ --- │ +│ i64 ┆ i64 │ +╞═════╪═════╡ +│ 1 ┆ 3 │ +│ 1 ┆ 4 │ +│ 2 ┆ 3 │ +│ 2 ┆ 4 │ +└─────┴─────┘ +``` + +After: + +```pycon +>>> df1.sql("SELECT * FROM df1 CROSS JOIN df2") +Traceback (most recent call last): +... +polars.exceptions.SQLInterfaceError: relation 'df1' was not found +``` + +Use instead: + +```pycon +>>> pl.sql("SELECT * FROM df1 CROSS JOIN df2", eager=True) +shape: (4, 2) +┌─────┬─────┐ +│ id1 ┆ id2 │ +│ --- ┆ --- │ +│ i64 ┆ i64 │ +╞═════╪═════╡ +│ 1 ┆ 3 │ +│ 1 ┆ 4 │ +│ 2 ┆ 3 │ +│ 2 ┆ 4 │ +└─────┴─────┘ +``` + +### Remove re-export of type aliases + +We have a lot of type aliases defined in the `polars.type_aliases` module. Some of these were +re-exported at the top-level and in the `polars.datatypes` module. These re-exports have been +removed. + +We plan on adding a public `polars.typing` module in the future with a number of curated type +aliases. Until then, please define your own type aliases, or import from our `polars.type_aliases` +module. Note that the `type_aliases` module is not technically public, so use at your own risk. + +**Example** + +Before: + +```python +def foo(dtype: pl.PolarsDataType) -> None: ... +``` + +After: + +```python +PolarsDataType = pl.DataType | type[pl.DataType] + +def foo(dtype: PolarsDataType) -> None: ... +``` + +### Streamline optional dependency definitions in `pyproject.toml` + +We revisited to optional dependency definitions and made some minor changes. If you were using the +extras `fastexcel`, `gevent`, `matplotlib`, or `async`, this is a breaking change. Please update +your Polars installation to use the new extras. + +**Example** + +Before: + +```bash +pip install 'polars[fastexcel,gevent,matplotlib]' +``` + +After: + +```bash +pip install 'polars[calamine,async,graph]' +``` + +## Deprecations + +### Issue `PerformanceWarning` when LazyFrame properties `schema/dtypes/columns/width` are used + +Recent improvements to the correctness of the schema resolving in the lazy engine have had +significant performance impact on the cost of resolving the schema. It is no longer 'free' - in +fact, in complex pipelines with lazy file reading, resolving the schema can be relatively expensive. + +Because of this, the schema-related properties on LazyFrame were no longer good API design. +Properties represent information that is already available, and just needs to be retrieved. However, +for the LazyFrame properties, accessing these may have significant performance cost. + +To solve this, we added the `LazyFrame.collect_schema` method, which retrieves the schema and +returns a `Schema` object. The properties raise a `PerformanceWarning` and tell the user to use +`collect_schema` instead. We chose not to deprecate the properties for now to facilitate writing +code that is generic for both DataFrames and LazyFrames. diff --git a/docs/source/releases/upgrade/index.md b/docs/source/releases/upgrade/index.md new file mode 100644 index 000000000000..256764fb006e --- /dev/null +++ b/docs/source/releases/upgrade/index.md @@ -0,0 +1,26 @@ +# About + +Polars releases an upgrade guide alongside each breaking release. This guide is intended to help you +upgrade from an older Polars version to the new version. + +Each guide contains all breaking changes that were not previously deprecated, as well as any +significant new deprecations. + +A full list of all changes is available in the [changelog](../changelog.md). + +!!! tip + + It can be useful to upgrade to the latest non-breaking version before upgrading to a new breaking version. + This way, you can run your code and address any deprecation warnings. + The upgrade to the new breaking version should then go much more smoothly! + +!!! tip + + One of our maintainers has created a tool for automatically upgrading your Polars code to a later version. + It's based on the well-known pyupgrade tool. + Try out [polars-upgrade](https://github.com/MarcoGorelli/polars-upgrade) and let us know what you think! + +!!! rust "Note" + + There are no upgrade guides yet for Rust releases. + These will be added once the rate of breaking changes to the Rust API slows down and a [deprecation policy](../../development/versioning.md#deprecation-period) is added. diff --git a/docs/source/requirements.txt b/docs/source/requirements.txt new file mode 100644 index 000000000000..60be972f1e3e --- /dev/null +++ b/docs/source/requirements.txt @@ -0,0 +1,18 @@ +altair +pandas +pyarrow +graphviz +hvplot +matplotlib +plotnine +seaborn +plotly +numba >= 0.54; python_version < '3.13' # numba does not support Python 3.13 +numpy + +mkdocs-material==9.6.1 +mkdocs-macros-plugin==1.3.7 +mkdocs-redirects==1.2.1 +material-plausible-plugin==0.3.0 +markdown-exec[ansi]==1.10.0 +pygithub==2.6.1 diff --git a/docs/source/src/python/home/example.py b/docs/source/src/python/home/example.py new file mode 100644 index 000000000000..0011842828b0 --- /dev/null +++ b/docs/source/src/python/home/example.py @@ -0,0 +1,12 @@ +# --8<-- [start:example] +import polars as pl + +q = ( + pl.scan_csv("docs/assets/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.all().sum()) +) + +df = q.collect() +# --8<-- [end:example] diff --git a/docs/source/src/python/polars-cloud/authentication.py b/docs/source/src/python/polars-cloud/authentication.py new file mode 100644 index 000000000000..a87876bdd227 --- /dev/null +++ b/docs/source/src/python/polars-cloud/authentication.py @@ -0,0 +1,7 @@ +""" +# --8<-- [start:login] +import polars_cloud as pc + +workspace = pc.login() +# --8<-- [end:login] +""" diff --git a/docs/source/src/python/polars-cloud/compute-context.py b/docs/source/src/python/polars-cloud/compute-context.py new file mode 100644 index 000000000000..575a979fb71b --- /dev/null +++ b/docs/source/src/python/polars-cloud/compute-context.py @@ -0,0 +1,32 @@ +""" +import polars_cloud as pc + +# --8<-- [start:compute] +ctx = pc.ComputeContext( + workspace="your-workspace", + instance_type="t2.micro", + cluster_size=2, + labels=["docs"], +) +# --8<-- [end:compute] + +# --8<-- [start:default-compute] +ctx = pc.ComputeContext(workspace="your-workspace") +# --8<-- [end:default-compute] + +# --8<-- [start:defined-compute] +ctx = pc.ComputeContext( + workspace="your-workspace", + memory=8, + cpus=2, +) +# --8<-- [end:defined-compute] + +# --8<-- [start:set-compute] +ctx = pc.ComputeContext( + workspace="your-workspace", + instance_type="t2.micro", + cluster_size=2 +) +# --8<-- [end:set-compute] +""" diff --git a/docs/source/src/python/polars-cloud/index.py b/docs/source/src/python/polars-cloud/index.py new file mode 100644 index 000000000000..97a1cf7be425 --- /dev/null +++ b/docs/source/src/python/polars-cloud/index.py @@ -0,0 +1,23 @@ +""" +# --8<-- [start:index] +import polars as pl +import polars_cloud as pc + +ctx = pc.ComputeContext(cpus=16, memory=64) + +query = ( + pl.scan_parquet("s3://my-dataset/") + .group_by("l_returnflag", "l_linestatus") + .agg( + avg_price=pl.mean("l_extendedprice"), + avg_disc=pl.mean("l_discount"), + count_order=pl.len(), + ) +) + +( + query.remote(context=ctx) + .sink_parquet("s3://my-dst/") +) +# --8<-- [end:index] +""" diff --git a/docs/source/src/python/polars-cloud/interactive-batch.py b/docs/source/src/python/polars-cloud/interactive-batch.py new file mode 100644 index 000000000000..bb5ae83181f0 --- /dev/null +++ b/docs/source/src/python/polars-cloud/interactive-batch.py @@ -0,0 +1,59 @@ +""" +# --8<-- [start:example] +import polars as pl +import polars_cloud as pc +import datetime as dt + +lf = pl.LazyFrame( + { + "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate": [ + dt.date(1997, 1, 10), + dt.date(1985, 2, 15), + dt.date(1983, 3, 22), + dt.date(1981, 4, 30), + ], + "weight": [57.9, 72.5, 53.6, 83.1], # (kg) + "height": [1.56, 1.77, 1.65, 1.75], # (m) + } +) +# --8<-- [end:example] + +# --8<-- [start:batch] +ctx = pc.ComputeContext(workspace="your-workspace", cpus=24, memory=64) + +lf = lf.select( + pl.col("name"), + pl.col("birthdate").dt.year().alias("birth_year"), + (pl.col("weight") / (pl.col("height") ** 2)).alias("bmi"), +).sort(by="bmi") + +lf.remote(context=ctx).sink_parquet("s3://bucket/output.parquet") +# --8<-- [end:batch] + +# --8<-- [start:interactive] +ctx = pc.ComputeContext( + workspace="your-workspace", cpus=24, memory=64, interactive=True +) + +lf = lf.select( + pl.col("name"), + pl.col("birthdate").dt.year().alias("birth_year"), + (pl.col("weight") / (pl.col("height") ** 2)).alias("bmi"), +).sort(by="bmi") + +res1 = lf.remote(context=ctx).collect() + +# --8<-- [end:interactive] + +# --8<-- [start:interactive-next] +res2 = ( + res1 + .filter( + pl.col("birth_year").is_in([1983, 1985]), + ) + .remote(context=ctx) + .collect() +) +# --8<-- [end:interactive-next] +""" diff --git a/docs/source/src/python/polars-cloud/quickstart.py b/docs/source/src/python/polars-cloud/quickstart.py new file mode 100644 index 000000000000..80d2d06ddcac --- /dev/null +++ b/docs/source/src/python/polars-cloud/quickstart.py @@ -0,0 +1,20 @@ +""" +# --8<-- [start:general] +import polars_cloud as pc +import polars as pl + +ctx = pc.ComputeContext(memory=8, cpus=2, cluster_size=1) +lf = pl.LazyFrame( + { + "a": [1, 2, 3], + "b": [4, 4, 5], + } +).with_columns( + pl.col("a").max().over("b").alias("c"), +) +( + lf.remote(context=ctx) + .sink_parquet(uri="s3://my-bucket/result.parquet") +) +# --8<-- [end:general] +""" diff --git a/docs/source/src/python/user-guide/concepts/data-types-and-structures.py b/docs/source/src/python/user-guide/concepts/data-types-and-structures.py new file mode 100644 index 000000000000..20067c0a4d19 --- /dev/null +++ b/docs/source/src/python/user-guide/concepts/data-types-and-structures.py @@ -0,0 +1,84 @@ +# --8<-- [start:series] +import polars as pl + +s = pl.Series("ints", [1, 2, 3, 4, 5]) +print(s) +# --8<-- [end:series] + +# --8<-- [start:series-dtype] +s1 = pl.Series("ints", [1, 2, 3, 4, 5]) +s2 = pl.Series("uints", [1, 2, 3, 4, 5], dtype=pl.UInt64) +print(s1.dtype, s2.dtype) +# --8<-- [end:series-dtype] + +# --8<-- [start:df] +from datetime import date + +df = pl.DataFrame( + { + "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate": [ + date(1997, 1, 10), + date(1985, 2, 15), + date(1983, 3, 22), + date(1981, 4, 30), + ], + "weight": [57.9, 72.5, 53.6, 83.1], # (kg) + "height": [1.56, 1.77, 1.65, 1.75], # (m) + } +) + +print(df) +# --8<-- [end:df] + +# --8<-- [start:schema] +print(df.schema) +# --8<-- [end:schema] + +# --8<-- [start:head] +print(df.head(3)) +# --8<-- [end:head] + +# --8<-- [start:glimpse] +print(df.glimpse(return_as_string=True)) +# --8<-- [end:glimpse] + +# --8<-- [start:tail] +print(df.tail(3)) +# --8<-- [end:tail] + +# --8<-- [start:sample] +import random + +random.seed(42) # For reproducibility. + +print(df.sample(2)) +# --8<-- [end:sample] + +# --8<-- [start:describe] +print(df.describe()) +# --8<-- [end:describe] + +# --8<-- [start:schema-def] +df = pl.DataFrame( + { + "name": ["Alice", "Ben", "Chloe", "Daniel"], + "age": [27, 39, 41, 43], + }, + schema={"name": None, "age": pl.UInt8}, +) + +print(df) +# --8<-- [end:schema-def] + +# --8<-- [start:schema_overrides] +df = pl.DataFrame( + { + "name": ["Alice", "Ben", "Chloe", "Daniel"], + "age": [27, 39, 41, 43], + }, + schema_overrides={"age": pl.UInt8}, +) + +print(df) +# --8<-- [end:schema_overrides] diff --git a/docs/source/src/python/user-guide/concepts/expressions.py b/docs/source/src/python/user-guide/concepts/expressions.py new file mode 100644 index 000000000000..7a4cb0637ba3 --- /dev/null +++ b/docs/source/src/python/user-guide/concepts/expressions.py @@ -0,0 +1,105 @@ +# --8<-- [start:expression] +import polars as pl + +pl.col("weight") / (pl.col("height") ** 2) +# --8<-- [end:expression] + +# --8<-- [start:print-expr] +bmi_expr = pl.col("weight") / (pl.col("height") ** 2) +print(bmi_expr) +# --8<-- [end:print-expr] + +# --8<-- [start:df] +from datetime import date + +df = pl.DataFrame( + { + "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate": [ + date(1997, 1, 10), + date(1985, 2, 15), + date(1983, 3, 22), + date(1981, 4, 30), + ], + "weight": [57.9, 72.5, 53.6, 83.1], # (kg) + "height": [1.56, 1.77, 1.65, 1.75], # (m) + } +) + +print(df) +# --8<-- [end:df] + +# --8<-- [start:select-1] +result = df.select( + bmi=bmi_expr, + avg_bmi=bmi_expr.mean(), + ideal_max_bmi=25, +) +print(result) +# --8<-- [end:select-1] + +# --8<-- [start:select-2] +result = df.select(deviation=(bmi_expr - bmi_expr.mean()) / bmi_expr.std()) +print(result) +# --8<-- [end:select-2] + +# --8<-- [start:with_columns-1] +result = df.with_columns( + bmi=bmi_expr, + avg_bmi=bmi_expr.mean(), + ideal_max_bmi=25, +) +print(result) +# --8<-- [end:with_columns-1] + +# --8<-- [start:filter-1] +result = df.filter( + pl.col("birthdate").is_between(date(1982, 12, 31), date(1996, 1, 1)), + pl.col("height") > 1.7, +) +print(result) +# --8<-- [end:filter-1] + +# --8<-- [start:group_by-1] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), +).agg(pl.col("name")) +print(result) +# --8<-- [end:group_by-1] + +# --8<-- [start:group_by-2] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + (pl.col("height") < 1.7).alias("short?"), +).agg(pl.col("name")) +print(result) +# --8<-- [end:group_by-2] + +# --8<-- [start:group_by-3] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + (pl.col("height") < 1.7).alias("short?"), +).agg( + pl.len(), + pl.col("height").max().alias("tallest"), + pl.col("weight", "height").mean().name.prefix("avg_"), +) +print(result) +# --8<-- [end:group_by-3] + +# --8<-- [start:expression-expansion-1] +expr = (pl.col(pl.Float64) * 1.1).name.suffix("*1.1") +result = df.select(expr) +print(result) +# --8<-- [end:expression-expansion-1] + +# --8<-- [start:expression-expansion-2] +df2 = pl.DataFrame( + { + "ints": [1, 2, 3, 4], + "letters": ["A", "B", "C", "D"], + } +) +result = df2.select(expr) +print(result) +# --8<-- [end:expression-expansion-2] diff --git a/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py b/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py new file mode 100644 index 000000000000..dd48c65b2378 --- /dev/null +++ b/docs/source/src/python/user-guide/concepts/lazy-vs-eager.py @@ -0,0 +1,45 @@ +# --8<-- [start:import] +import polars as pl + +# --8<-- [end:import] + +# --8<-- [start:eager] + +df = pl.read_csv("docs/assets/data/iris.csv") +df_small = df.filter(pl.col("sepal_length") > 5) +df_agg = df_small.group_by("species").agg(pl.col("sepal_width").mean()) +print(df_agg) +# --8<-- [end:eager] + +# --8<-- [start:lazy] +q = ( + pl.scan_csv("docs/assets/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.col("sepal_width").mean()) +) + +df = q.collect() +# --8<-- [end:lazy] + +# --8<-- [start:explain] +print(q.explain()) +# --8<-- [end:explain] + +# --8<-- [start:explain-expression-expansion] +schema = pl.Schema( + { + "int_1": pl.Int16, + "int_2": pl.Int32, + "float_1": pl.Float64, + "float_2": pl.Float64, + "float_3": pl.Float64, + } +) + +print( + pl.LazyFrame(schema=schema) + .select((pl.col(pl.Float64) * 1.1).name.suffix("*1.1")) + .explain() +) +# --8<-- [end:explain-expression-expansion] diff --git a/docs/source/src/python/user-guide/concepts/streaming.py b/docs/source/src/python/user-guide/concepts/streaming.py new file mode 100644 index 000000000000..7c2dec7e44b8 --- /dev/null +++ b/docs/source/src/python/user-guide/concepts/streaming.py @@ -0,0 +1,27 @@ +# --8<-- [start:import] +import polars as pl + +# --8<-- [end:import] + +# --8<-- [start:streaming] +q1 = ( + pl.scan_csv("docs/assets/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.col("sepal_width").mean()) +) +df = q1.collect(engine="streaming") +# --8<-- [end:streaming] + +# --8<-- [start:example] +print(q1.explain(engine="streaming")) + +# --8<-- [end:example] + +# --8<-- [start:example2] +q2 = pl.scan_csv("docs/assets/data/iris.csv").with_columns( + pl.col("sepal_length").mean().over("species") +) + +print(q2.explain(engine="streaming")) +# --8<-- [end:example2] diff --git a/docs/source/src/python/user-guide/expressions/aggregation.py b/docs/source/src/python/user-guide/expressions/aggregation.py new file mode 100644 index 000000000000..d414037cc9f3 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/aggregation.py @@ -0,0 +1,180 @@ +# --8<-- [start:dataframe] +import polars as pl + +url = "hf://datasets/nameexhaustion/polars-docs/legislators-historical.csv" + +schema_overrides = { + "first_name": pl.Categorical, + "gender": pl.Categorical, + "type": pl.Categorical, + "state": pl.Categorical, + "party": pl.Categorical, +} + +dataset = ( + pl.read_csv(url, schema_overrides=schema_overrides) + .with_columns(pl.col("first", "middle", "last").name.suffix("_name")) + .with_columns(pl.col("birthday").str.to_date(strict=False)) +) +# --8<-- [end:dataframe] + +# --8<-- [start:basic] +q = ( + dataset.lazy() + .group_by("first_name") + .agg( + pl.len(), + pl.col("gender"), + pl.first("last_name"), # Short for `pl.col("last_name").first()` + ) + .sort("len", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:basic] + +# --8<-- [start:conditional] +q = ( + dataset.lazy() + .group_by("state") + .agg( + (pl.col("party") == "Anti-Administration").sum().alias("anti"), + (pl.col("party") == "Pro-Administration").sum().alias("pro"), + ) + .sort("pro", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:conditional] + +# --8<-- [start:nested] +q = ( + dataset.lazy() + .group_by("state", "party") + .agg(pl.len().alias("count")) + .filter( + (pl.col("party") == "Anti-Administration") + | (pl.col("party") == "Pro-Administration") + ) + .sort("count", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:nested] + + +# --8<-- [start:filter] +from datetime import date + + +def compute_age(): + return date.today().year - pl.col("birthday").dt.year() + + +def avg_birthday(gender: str) -> pl.Expr: + return ( + compute_age() + .filter(pl.col("gender") == gender) + .mean() + .alias(f"avg {gender} birthday") + ) + + +q = ( + dataset.lazy() + .group_by("state") + .agg( + avg_birthday("M"), + avg_birthday("F"), + (pl.col("gender") == "M").sum().alias("# male"), + (pl.col("gender") == "F").sum().alias("# female"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:filter] + + +# --8<-- [start:filter-nested] +q = ( + dataset.lazy() + .group_by("state", "gender") + .agg( + # The function `avg_birthday` is not needed: + compute_age().mean().alias("avg birthday"), + pl.len().alias("#"), + ) + .sort("#", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:filter-nested] + + +# --8<-- [start:sort] +def get_name() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_name().first().alias("youngest"), + get_name().last().alias("oldest"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort] + + +# --8<-- [start:sort2] +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_name().first().alias("youngest"), + get_name().last().alias("oldest"), + get_name().sort().first().alias("alphabetical_first"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort2] + + +# --8<-- [start:sort3] +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_name().first().alias("youngest"), + get_name().last().alias("oldest"), + get_name().sort().first().alias("alphabetical_first"), + pl.col("gender").sort_by(get_name()).first(), + ) + .sort("state") + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort3] diff --git a/docs/source/src/python/user-guide/expressions/casting.py b/docs/source/src/python/user-guide/expressions/casting.py new file mode 100644 index 000000000000..e320ffa4a0c6 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/casting.py @@ -0,0 +1,137 @@ +# --8<-- [start:dfnum] +import polars as pl + +df = pl.DataFrame( + { + "integers": [1, 2, 3], + "big_integers": [10000002, 2, 30000003], + "floats": [4.0, 5.8, -6.3], + } +) + +print(df) +# --8<-- [end:dfnum] + +# --8<-- [start:castnum] +result = df.select( + pl.col("integers").cast(pl.Float32).alias("integers_as_floats"), + pl.col("floats").cast(pl.Int32).alias("floats_as_integers"), +) +print(result) +# --8<-- [end:castnum] + + +# --8<-- [start:downcast] +print(f"Before downcasting: {df.estimated_size()} bytes") +result = df.with_columns( + pl.col("integers").cast(pl.Int16), + pl.col("floats").cast(pl.Float32), +) +print(f"After downcasting: {result.estimated_size()} bytes") +# --8<-- [end:downcast] + +# --8<-- [start:overflow] +from polars.exceptions import InvalidOperationError + +try: + result = df.select(pl.col("big_integers").cast(pl.Int8)) + print(result) +except InvalidOperationError as err: + print(err) +# --8<-- [end:overflow] + +# --8<-- [start:overflow2] +result = df.select(pl.col("big_integers").cast(pl.Int8, strict=False)) +print(result) +# --8<-- [end:overflow2] + + +# --8<-- [start:strings] +df = pl.DataFrame( + { + "integers_as_strings": ["1", "2", "3"], + "floats_as_strings": ["4.0", "5.8", "-6.3"], + "floats": [4.0, 5.8, -6.3], + } +) + +result = df.select( + pl.col("integers_as_strings").cast(pl.Int32), + pl.col("floats_as_strings").cast(pl.Float64), + pl.col("floats").cast(pl.String), +) +print(result) +# --8<-- [end:strings] + + +# --8<-- [start:strings2] +df = pl.DataFrame( + { + "floats": ["4.0", "5.8", "- 6 . 3"], + } +) +try: + result = df.select(pl.col("floats").cast(pl.Float64)) +except InvalidOperationError as err: + print(err) +# --8<-- [end:strings2] + +# --8<-- [start:bool] +df = pl.DataFrame( + { + "integers": [-1, 0, 2, 3, 4], + "floats": [0.0, 1.0, 2.0, 3.0, 4.0], + "bools": [True, False, True, False, True], + } +) + +result = df.select( + pl.col("integers").cast(pl.Boolean), + pl.col("floats").cast(pl.Boolean), + pl.col("bools").cast(pl.Int8), +) +print(result) +# --8<-- [end:bool] + +# --8<-- [start:dates] +from datetime import date, datetime, time + +df = pl.DataFrame( + { + "date": [ + date(1970, 1, 1), # epoch + date(1970, 1, 10), # 9 days later + ], + "datetime": [ + datetime(1970, 1, 1, 0, 0, 0), # epoch + datetime(1970, 1, 1, 0, 1, 0), # 1 minute later + ], + "time": [ + time(0, 0, 0), # reference time + time(0, 0, 1), # 1 second later + ], + } +) + +result = df.select( + pl.col("date").cast(pl.Int64).alias("days_since_epoch"), + pl.col("datetime").cast(pl.Int64).alias("us_since_epoch"), + pl.col("time").cast(pl.Int64).alias("ns_since_midnight"), +) +print(result) +# --8<-- [end:dates] + +# --8<-- [start:dates2] +df = pl.DataFrame( + { + "date": [date(2022, 1, 1), date(2022, 1, 2)], + "string": ["2022-01-01", "2022-01-02"], + } +) + +result = df.select( + pl.col("date").dt.to_string("%Y-%m-%d"), + pl.col("string").str.to_datetime("%Y-%m-%d"), +) +print(result) +# --8<-- [end:dates2] diff --git a/docs/source/src/python/user-guide/expressions/categoricals.py b/docs/source/src/python/user-guide/expressions/categoricals.py new file mode 100644 index 000000000000..c45c7f9dcef8 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/categoricals.py @@ -0,0 +1,215 @@ +# --8<-- [start:enum-example] +import polars as pl + +bears_enum = pl.Enum(["Polar", "Panda", "Brown"]) +bears = pl.Series(["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=bears_enum) +print(bears) +# --8<-- [end:enum-example] + +# --8<-- [start:enum-wrong-value] +from polars.exceptions import InvalidOperationError + +try: + bears_kind_of = pl.Series( + ["Polar", "Panda", "Brown", "Polar", "Shark"], + dtype=bears_enum, + ) +except InvalidOperationError as exc: + print("InvalidOperationError:", exc) +# --8<-- [end:enum-wrong-value] + +# --8<-- [start:log-levels] +log_levels = pl.Enum(["debug", "info", "warning", "error"]) + +logs = pl.DataFrame( + { + "level": ["debug", "info", "debug", "error"], + "message": [ + "process id: 525", + "Service started correctly", + "startup time: 67ms", + "Cannot connect to DB!", + ], + }, + schema_overrides={ + "level": log_levels, + }, +) + +non_debug_logs = logs.filter( + pl.col("level") > "debug", +) +print(non_debug_logs) +# --8<-- [end:log-levels] + +# --8<-- [start:categorical-example] +bears_cat = pl.Series( + ["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=pl.Categorical +) +print(bears_cat) +# --8<-- [end:categorical-example] + +# --8<-- [start:categorical-comparison-string] +print(bears_cat < "Cat") +# --8<-- [end:categorical-comparison-string] + +# --8<-- [start:categorical-comparison-string-column] +bears_str = pl.Series( + ["Panda", "Brown", "Brown", "Polar", "Polar"], +) +print(bears_cat == bears_str) +# --8<-- [end:categorical-comparison-string-column] + +# --8<-- [start:categorical-comparison-categorical-column] +from polars.exceptions import StringCacheMismatchError + +bears_cat2 = pl.Series( + ["Panda", "Brown", "Brown", "Polar", "Polar"], + dtype=pl.Categorical, +) + +try: + print(bears_cat == bears_cat2) +except StringCacheMismatchError as exc: + exc_str = str(exc).splitlines()[0] + print("StringCacheMismatchError:", exc_str) +# --8<-- [end:categorical-comparison-categorical-column] + +# --8<-- [start:stringcache-categorical-equality] +with pl.StringCache(): + bears_cat = pl.Series( + ["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=pl.Categorical + ) + bears_cat2 = pl.Series( + ["Panda", "Brown", "Brown", "Polar", "Polar"], dtype=pl.Categorical + ) + +print(bears_cat == bears_cat2) +# --8<-- [end:stringcache-categorical-equality] + +# --8<-- [start:stringcache-categorical-comparison-lexical] +with pl.StringCache(): + bears_cat = pl.Series( + ["Polar", "Panda", "Brown", "Brown", "Polar"], + dtype=pl.Categorical(ordering="lexical"), + ) + bears_cat2 = pl.Series( + ["Panda", "Brown", "Brown", "Polar", "Polar"], dtype=pl.Categorical + ) + +print(bears_cat > bears_cat2) +# --8<-- [end:stringcache-categorical-comparison-lexical] + +# --8<-- [start:stringcache-categorical-comparison-physical] +with pl.StringCache(): + bears_cat = pl.Series( + # Polar < Panda < Brown + ["Polar", "Panda", "Brown", "Brown", "Polar"], + dtype=pl.Categorical, + ) + bears_cat2 = pl.Series( + ["Panda", "Brown", "Brown", "Polar", "Polar"], dtype=pl.Categorical + ) + +print(bears_cat > bears_cat2) +# --8<-- [end:stringcache-categorical-comparison-physical] + +# --8<-- [start:concatenating-categoricals] +import warnings + +from polars.exceptions import CategoricalRemappingWarning + +male_bears = pl.DataFrame( + { + "species": ["Polar", "Brown", "Panda"], + "weight": [450, 500, 110], # kg + }, + schema_overrides={"species": pl.Categorical}, +) +female_bears = pl.DataFrame( + { + "species": ["Brown", "Polar", "Panda"], + "weight": [340, 200, 90], # kg + }, + schema_overrides={"species": pl.Categorical}, +) + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=CategoricalRemappingWarning) + bears = pl.concat([male_bears, female_bears], how="vertical") + +print(bears) +# --8<-- [end:concatenating-categoricals] + + +# --8<-- [start:example] +import polars as pl + +bears_enum = pl.Enum(["Polar", "Panda", "Brown"]) +bears = pl.Series(["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=bears_enum) +print(bears) + +cat_bears = pl.Series( + ["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=pl.Categorical +) +# --8<-- [end:example] + + +# --8<-- [start:append] +cat_bears = pl.Series( + ["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=pl.Categorical +) +cat2_series = pl.Series( + ["Panda", "Brown", "Brown", "Polar", "Polar"], dtype=pl.Categorical +) + +# Triggers a CategoricalRemappingWarning. +print(cat_bears.append(cat2_series)) +# --8<-- [end:append] + +# --8<-- [start:enum_append] +dtype = pl.Enum(["Polar", "Panda", "Brown"]) +cat_bears = pl.Series(["Polar", "Panda", "Brown", "Brown", "Polar"], dtype=dtype) +cat2_series = pl.Series(["Panda", "Brown", "Brown", "Polar", "Polar"], dtype=dtype) +print(cat_bears.append(cat2_series)) +# --8<-- [end:enum_append] + +# --8<-- [start:enum_error] +dtype = pl.Enum(["Polar", "Panda", "Brown"]) +try: + cat_bears = pl.Series(["Polar", "Panda", "Brown", "Black"], dtype=dtype) +except Exception as e: + print(e) +# --8<-- [end:enum_error] + +# --8<-- [start:equality] +dtype = pl.Enum(["Polar", "Panda", "Brown"]) +cat_bears = pl.Series(["Brown", "Panda", "Polar"], dtype=dtype) +cat_series2 = pl.Series(["Polar", "Panda", "Brown"], dtype=dtype) +print(cat_bears == cat_series2) +# --8<-- [end:equality] + +# --8<-- [start:global_equality] +with pl.StringCache(): + cat_bears = pl.Series(["Brown", "Panda", "Polar"], dtype=pl.Categorical) + cat_series2 = pl.Series(["Polar", "Panda", "Black"], dtype=pl.Categorical) + print(cat_bears == cat_series2) +# --8<-- [end:global_equality] + +# --8<-- [start:equality] +dtype = pl.Enum(["Polar", "Panda", "Brown"]) +cat_bears = pl.Series(["Brown", "Panda", "Polar"], dtype=dtype) +cat_series2 = pl.Series(["Polar", "Panda", "Brown"], dtype=dtype) +print(cat_bears == cat_series2) +# --8<-- [end:equality] + +# --8<-- [start:str_compare_single] +cat_bears = pl.Series(["Brown", "Panda", "Polar"], dtype=pl.Categorical) +print(cat_bears <= "Cat") +# --8<-- [end:str_compare_single] + +# --8<-- [start:str_compare] +cat_bears = pl.Series(["Brown", "Panda", "Polar"], dtype=pl.Categorical) +cat_series_utf = pl.Series(["Panda", "Panda", "Polar"]) +print(cat_bears <= cat_series_utf) +# --8<-- [end:str_compare] diff --git a/docs/source/src/python/user-guide/expressions/column-selections.py b/docs/source/src/python/user-guide/expressions/column-selections.py new file mode 100644 index 000000000000..61f3fdb44a09 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/column-selections.py @@ -0,0 +1,97 @@ +# --8<-- [start:selectors_df] +from datetime import date, datetime + +import polars as pl + +df = pl.DataFrame( + { + "id": [9, 4, 2], + "place": ["Mars", "Earth", "Saturn"], + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 3), "1d", eager=True), + "sales": [33.4, 2142134.1, 44.7], + "has_people": [False, True, False], + "logged_at": pl.datetime_range( + datetime(2022, 12, 1), datetime(2022, 12, 1, 0, 0, 2), "1s", eager=True + ), + } +).with_row_index("index") +print(df) +# --8<-- [end:selectors_df] + +# --8<-- [start:all] +out = df.select(pl.col("*")) + +# Is equivalent to +out = df.select(pl.all()) +print(out) +# --8<-- [end:all] + +# --8<-- [start:exclude] +out = df.select(pl.col("*").exclude("logged_at", "index")) +print(out) +# --8<-- [end:exclude] + +# --8<-- [start:expansion_by_names] +out = df.select(pl.col("date", "logged_at").dt.to_string("%Y-%h-%d")) +print(out) +# --8<-- [end:expansion_by_names] + +# --8<-- [start:expansion_by_regex] +out = df.select(pl.col("^.*(as|sa).*$")) +print(out) +# --8<-- [end:expansion_by_regex] + +# --8<-- [start:expansion_by_dtype] +out = df.select(pl.col(pl.Int64, pl.UInt32, pl.Boolean).n_unique()) +print(out) +# --8<-- [end:expansion_by_dtype] + +# --8<-- [start:selectors_intro] +import polars.selectors as cs + +out = df.select(cs.integer(), cs.string()) +print(out) +# --8<-- [end:selectors_intro] + +# --8<-- [start:selectors_diff] +out = df.select(cs.numeric() - cs.first()) +print(out) +# --8<-- [end:selectors_diff] + +# --8<-- [start:selectors_union] +out = df.select(cs.by_name("index") | ~cs.numeric()) +print(out) +# --8<-- [end:selectors_union] + +# --8<-- [start:selectors_by_name] +out = df.select(cs.contains("index"), cs.matches(".*_.*")) +print(out) +# --8<-- [end:selectors_by_name] + +# --8<-- [start:selectors_to_expr] +out = df.select(cs.temporal().as_expr().dt.to_string("%Y-%h-%d")) +print(out) +# --8<-- [end:selectors_to_expr] + +# --8<-- [start:selectors_is_selector_utility] +from polars.selectors import is_selector + +out = cs.numeric() +print(is_selector(out)) + +out = cs.boolean() | cs.numeric() +print(is_selector(out)) + +out = cs.numeric() + pl.lit(123) +print(is_selector(out)) +# --8<-- [end:selectors_is_selector_utility] + +# --8<-- [start:selectors_colnames_utility] +from polars.selectors import expand_selector + +out = cs.temporal() +print(expand_selector(df, out)) + +out = ~(cs.temporal() | cs.numeric()) +print(expand_selector(df, out)) +# --8<-- [end:selectors_colnames_utility] diff --git a/docs/source/src/python/user-guide/expressions/expression-expansion.py b/docs/source/src/python/user-guide/expressions/expression-expansion.py new file mode 100644 index 000000000000..b34169e49736 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/expression-expansion.py @@ -0,0 +1,198 @@ +# --8<-- [start:df] +import polars as pl + +df = pl.DataFrame( + { # As of 14th October 2024, ~3pm UTC + "ticker": ["AAPL", "NVDA", "MSFT", "GOOG", "AMZN"], + "company_name": ["Apple", "NVIDIA", "Microsoft", "Alphabet (Google)", "Amazon"], + "price": [229.9, 138.93, 420.56, 166.41, 188.4], + "day_high": [231.31, 139.6, 424.04, 167.62, 189.83], + "day_low": [228.6, 136.3, 417.52, 164.78, 188.44], + "year_high": [237.23, 140.76, 468.35, 193.31, 201.2], + "year_low": [164.08, 39.23, 324.39, 121.46, 118.35], + } +) + +print(df) +# --8<-- [end:df] + +# --8<-- [start:col-with-names] +eur_usd_rate = 1.09 # As of 14th October 2024 + +result = df.with_columns( + ( + pl.col( + "price", + "day_high", + "day_low", + "year_high", + "year_low", + ) + / eur_usd_rate + ).round(2) +) +print(result) +# --8<-- [end:col-with-names] + +# --8<-- [start:expression-list] +exprs = [ + (pl.col("price") / eur_usd_rate).round(2), + (pl.col("day_high") / eur_usd_rate).round(2), + (pl.col("day_low") / eur_usd_rate).round(2), + (pl.col("year_high") / eur_usd_rate).round(2), + (pl.col("year_low") / eur_usd_rate).round(2), +] + +result2 = df.with_columns(exprs) +print(result.equals(result2)) +# --8<-- [end:expression-list] + +# --8<-- [start:col-with-dtype] +result = df.with_columns((pl.col(pl.Float64) / eur_usd_rate).round(2)) +print(result) +# --8<-- [end:col-with-dtype] + +# --8<-- [start:col-with-dtypes] +result2 = df.with_columns( + ( + pl.col( + pl.Float32, + pl.Float64, + ) + / eur_usd_rate + ).round(2) +) +print(result.equals(result2)) +# --8<-- [end:col-with-dtypes] + +# --8<-- [start:col-with-regex] +result = df.select(pl.col("ticker", "^.*_high$", "^.*_low$")) +print(result) +# --8<-- [end:col-with-regex] + +# --8<-- [start:col-error] +try: + df.select(pl.col("ticker", pl.Float64)) +except TypeError as err: + print("TypeError:", err) +# --8<-- [end:col-error] + +# --8<-- [start:all] +result = df.select(pl.all()) +print(result.equals(df)) +# --8<-- [end:all] + +# --8<-- [start:all-exclude] +result = df.select(pl.all().exclude("^day_.*$")) +print(result) +# --8<-- [end:all-exclude] + +# --8<-- [start:col-exclude] +result = df.select(pl.col(pl.Float64).exclude("^day_.*$")) +print(result) +# --8<-- [end:col-exclude] + +# --8<-- [start:duplicate-error] +from polars.exceptions import DuplicateError + +gbp_usd_rate = 1.31 # As of 14th October 2024 + +try: + df.select( + pl.col("price") / gbp_usd_rate, # This would be named "price"... + pl.col("price") / eur_usd_rate, # And so would this. + ) +except DuplicateError as err: + print("DuplicateError:", err) +# --8<-- [end:duplicate-error] + +# --8<-- [start:alias] +result = df.select( + (pl.col("price") / gbp_usd_rate).alias("price (GBP)"), + (pl.col("price") / eur_usd_rate).alias("price (EUR)"), +) +# --8<-- [end:alias] + +# --8<-- [start:prefix-suffix] +result = df.select( + (pl.col("^year_.*$") / eur_usd_rate).name.prefix("in_eur_"), + (pl.col("day_high", "day_low") / gbp_usd_rate).name.suffix("_gbp"), +) +print(result) +# --8<-- [end:prefix-suffix] + +# --8<-- [start:name-map] +# There is also `.name.to_uppercase`, so this usage of `.map` is moot. +result = df.select(pl.all().name.map(str.upper)) +print(result) +# --8<-- [end:name-map] + +# --8<-- [start:for-with_columns] +result = df +for tp in ["day", "year"]: + result = result.with_columns( + (pl.col(f"{tp}_high") - pl.col(f"{tp}_low")).alias(f"{tp}_amplitude") + ) +print(result) +# --8<-- [end:for-with_columns] + + +# --8<-- [start:yield-expressions] +def amplitude_expressions(time_periods): + for tp in time_periods: + yield (pl.col(f"{tp}_high") - pl.col(f"{tp}_low")).alias(f"{tp}_amplitude") + + +result = df.with_columns(amplitude_expressions(["day", "year"])) +print(result) +# --8<-- [end:yield-expressions] + +# --8<-- [start:selectors] +import polars.selectors as cs + +result = df.select(cs.string() | cs.ends_with("_high")) +print(result) +# --8<-- [end:selectors] + +# --8<-- [start:selectors-set-operations] +result = df.select(cs.contains("_") - cs.string()) +print(result) +# --8<-- [end:selectors-set-operations] + +# --8<-- [start:selectors-expressions] +result = df.select((cs.contains("_") - cs.string()) / eur_usd_rate) +print(result) +# --8<-- [end:selectors-expressions] + +# --8<-- [start:selector-ambiguity] +people = pl.DataFrame( + { + "name": ["Anna", "Bob"], + "has_partner": [True, False], + "has_kids": [False, False], + "has_tattoos": [True, False], + "is_alive": [True, True], + } +) + +wrong_result = people.select((~cs.starts_with("has_")).name.prefix("not_")) +print(wrong_result) +# --8<-- [end:selector-ambiguity] + +# --8<-- [start:as_expr] +result = people.select((~cs.starts_with("has_").as_expr()).name.prefix("not_")) +print(result) +# --8<-- [end:as_expr] + +# --8<-- [start:is_selector] +print(cs.is_selector(~cs.starts_with("has_").as_expr())) +# --8<-- [end:is_selector] + +# --8<-- [start:expand_selector] +print( + cs.expand_selector( + people, + cs.starts_with("has_"), + ) +) +# --8<-- [end:expand_selector] diff --git a/docs/source/src/python/user-guide/expressions/folds.py b/docs/source/src/python/user-guide/expressions/folds.py new file mode 100644 index 000000000000..f0be44b29cb5 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/folds.py @@ -0,0 +1,89 @@ +# --8<-- [start:mansum] +import operator +import polars as pl + +df = pl.DataFrame( + { + "label": ["foo", "bar", "spam"], + "a": [1, 2, 3], + "b": [10, 20, 30], + } +) + +result = df.select( + pl.fold( + acc=pl.lit(0), + function=operator.add, + exprs=pl.col("a", "b"), + ).alias("sum_fold"), + pl.sum_horizontal(pl.col("a", "b")).alias("sum_horz"), +) + +print(result) +# --8<-- [end:mansum] + +# --8<-- [start:mansum-explicit] +acc = pl.lit(0) +f = operator.add + +result = df.select( + f(f(acc, pl.col("a")), pl.col("b")), + pl.fold(acc=acc, function=f, exprs=pl.col("a", "b")).alias("sum_fold"), +) + +print(result) +# --8<-- [end:mansum-explicit] + +# --8<-- [start:manprod] +result = df.select( + pl.fold( + acc=pl.lit(0), + function=operator.mul, + exprs=pl.col("a", "b"), + ).alias("prod"), +) + +print(result) +# --8<-- [end:manprod] + +# --8<-- [start:manprod-fixed] +result = df.select( + pl.fold( + acc=pl.lit(1), + function=operator.mul, + exprs=pl.col("a", "b"), + ).alias("prod"), +) + +print(result) +# --8<-- [end:manprod-fixed] + +# --8<-- [start:conditional] +df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [0, 1, 2], + } +) + +result = df.filter( + pl.fold( + acc=pl.lit(True), + function=lambda acc, x: acc & x, + exprs=pl.all() > 1, + ) +) +print(result) +# --8<-- [end:conditional] + +# --8<-- [start:string] +df = pl.DataFrame( + { + "a": ["a", "b", "c"], + "b": [1, 2, 3], + } +) + +result = df.select(pl.concat_str(["a", "b"])) +print(result) +# --8<-- [end:string] diff --git a/docs/source/src/python/user-guide/expressions/lists.py b/docs/source/src/python/user-guide/expressions/lists.py new file mode 100644 index 000000000000..6cffd6520317 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/lists.py @@ -0,0 +1,184 @@ +# --8<-- [start:list-example] +from datetime import datetime +import polars as pl + +df = pl.DataFrame( + { + "names": [ + ["Anne", "Averill", "Adams"], + ["Brandon", "Brooke", "Borden", "Branson"], + ["Camila", "Campbell"], + ["Dennis", "Doyle"], + ], + "children_ages": [ + [5, 7], + [], + [], + [8, 11, 18], + ], + "medical_appointments": [ + [], + [], + [], + [datetime(2022, 5, 22, 16, 30)], + ], + } +) + +print(df) +# --8<-- [end:list-example] + +# --8<-- [start:array-example] +df = pl.DataFrame( + { + "bit_flags": [ + [True, True, True, True, False], + [False, True, True, True, True], + ], + "tic_tac_toe": [ + [ + [" ", "x", "o"], + [" ", "x", " "], + ["o", "x", " "], + ], + [ + ["o", "x", "x"], + [" ", "o", "x"], + [" ", " ", "o"], + ], + ], + }, + schema={ + "bit_flags": pl.Array(pl.Boolean, 5), + "tic_tac_toe": pl.Array(pl.String, (3, 3)), + }, +) + +print(df) +# --8<-- [end:array-example] + +# --8<-- [start:numpy-array-inference] +import numpy as np + +array = np.arange(0, 120).reshape((5, 2, 3, 4)) # 4D array + +print(pl.Series(array).dtype) # Column with the 3D subarrays +# --8<-- [end:numpy-array-inference] + +# --8<-- [start:weather] +weather = pl.DataFrame( + { + "station": [f"Station {idx}" for idx in range(1, 6)], + "temperatures": [ + "20 5 5 E1 7 13 19 9 6 20", + "18 8 16 11 23 E2 8 E2 E2 E2 90 70 40", + "19 24 E9 16 6 12 10 22", + "E2 E0 15 7 8 10 E1 24 17 13 6", + "14 8 E0 16 22 24 E1", + ], + } +) + +print(weather) +# --8<-- [end:weather] + +# --8<-- [start:split] +weather = weather.with_columns( + pl.col("temperatures").str.split(" "), +) +print(weather) +# --8<-- [end:split] + +# --8<-- [start:explode] +result = weather.explode("temperatures") +print(result) +# --8<-- [end:explode] + +# --8<-- [start:list-slicing] +result = weather.with_columns( + pl.col("temperatures").list.head(3).alias("head"), + pl.col("temperatures").list.tail(3).alias("tail"), + pl.col("temperatures").list.slice(-3, 2).alias("two_next_to_last"), +) +print(result) +# --8<-- [end:list-slicing] + +# --8<-- [start:element-wise-casting] +result = weather.with_columns( + pl.col("temperatures") + .list.eval(pl.element().cast(pl.Int64, strict=False).is_null()) + .list.sum() + .alias("errors"), +) +print(result) +# --8<-- [end:element-wise-casting] + +# --8<-- [start:element-wise-regex] +result2 = weather.with_columns( + pl.col("temperatures") + .list.eval(pl.element().str.contains("(?i)[a-z]")) + .list.sum() + .alias("errors"), +) +print(result.equals(result2)) +# --8<-- [end:element-wise-regex] + +# --8<-- [start:weather_by_day] +weather_by_day = pl.DataFrame( + { + "station": [f"Station {idx}" for idx in range(1, 11)], + "day_1": [17, 11, 8, 22, 9, 21, 20, 8, 8, 17], + "day_2": [15, 11, 10, 8, 7, 14, 18, 21, 15, 13], + "day_3": [16, 15, 24, 24, 8, 23, 19, 23, 16, 10], + } +) +print(weather_by_day) +# --8<-- [end:weather_by_day] + +# --8<-- [start:rank_pct] +rank_pct = (pl.element().rank(descending=True) / pl.all().count()).round(2) + +result = weather_by_day.with_columns( + # create the list of homogeneous data + pl.concat_list(pl.all().exclude("station")).alias("all_temps") +).select( + # select all columns except the intermediate list + pl.all().exclude("all_temps"), + # compute the rank by calling `list.eval` + pl.col("all_temps").list.eval(rank_pct, parallel=True).alias("temps_rank"), +) + +print(result) +# --8<-- [end:rank_pct] + +# --8<-- [start:array-overview] +df = pl.DataFrame( + { + "first_last": [ + ["Anne", "Adams"], + ["Brandon", "Branson"], + ["Camila", "Campbell"], + ["Dennis", "Doyle"], + ], + "fav_numbers": [ + [42, 0, 1], + [2, 3, 5], + [13, 21, 34], + [73, 3, 7], + ], + }, + schema={ + "first_last": pl.Array(pl.String, 2), + "fav_numbers": pl.Array(pl.Int32, 3), + }, +) + +result = df.select( + pl.col("first_last").arr.join(" ").alias("name"), + pl.col("fav_numbers").arr.sort(), + pl.col("fav_numbers").arr.max().alias("largest_fav"), + pl.col("fav_numbers").arr.sum().alias("summed"), + pl.col("fav_numbers").arr.contains(3).alias("likes_3"), +) +print(result) +# --8<-- [end:array-overview] diff --git a/docs/source/src/python/user-guide/expressions/missing-data.py b/docs/source/src/python/user-guide/expressions/missing-data.py new file mode 100644 index 000000000000..e61af94a79e4 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/missing-data.py @@ -0,0 +1,97 @@ +# --8<-- [start:dataframe] +import polars as pl + +df = pl.DataFrame( + { + "value": [1, None], + }, +) +print(df) +# --8<-- [end:dataframe] + + +# --8<-- [start:count] +null_count_df = df.null_count() +print(null_count_df) +# --8<-- [end:count] + + +# --8<-- [start:isnull] +is_null_series = df.select( + pl.col("value").is_null(), +) +print(is_null_series) +# --8<-- [end:isnull] + + +# --8<-- [start:dataframe2] +df = pl.DataFrame( + { + "col1": [0.5, 1, 1.5, 2, 2.5], + "col2": [1, None, 3, None, 5], + }, +) +print(df) +# --8<-- [end:dataframe2] + + +# --8<-- [start:fill] +fill_literal_df = df.with_columns( + pl.col("col2").fill_null(3), +) +print(fill_literal_df) +# --8<-- [end:fill] + +# --8<-- [start:fillexpr] +fill_expression_df = df.with_columns( + pl.col("col2").fill_null((2 * pl.col("col1")).cast(pl.Int64)), +) +print(fill_expression_df) +# --8<-- [end:fillexpr] + +# --8<-- [start:fillstrategy] +fill_forward_df = df.with_columns( + pl.col("col2").fill_null(strategy="forward").alias("forward"), + pl.col("col2").fill_null(strategy="backward").alias("backward"), +) +print(fill_forward_df) +# --8<-- [end:fillstrategy] + +# --8<-- [start:fillinterpolate] +fill_interpolation_df = df.with_columns( + pl.col("col2").interpolate(), +) +print(fill_interpolation_df) +# --8<-- [end:fillinterpolate] + +# --8<-- [start:nan] +import numpy as np + +nan_df = pl.DataFrame( + { + "value": [1.0, np.nan, float("nan"), 3.0], + }, +) +print(nan_df) +# --8<-- [end:nan] + +# --8<-- [start:nan-computed] +df = pl.DataFrame( + { + "dividend": [1, 0, -1], + "divisor": [1, 0, -1], + } +) +result = df.select(pl.col("dividend") / pl.col("divisor")) +print(result) +# --8<-- [end:nan-computed] + +# --8<-- [start:nanfill] +mean_nan_df = nan_df.with_columns( + pl.col("value").fill_nan(None).alias("replaced"), +).select( + pl.all().mean().name.suffix("_mean"), + pl.all().sum().name.suffix("_sum"), +) +print(mean_nan_df) +# --8<-- [end:nanfill] diff --git a/docs/source/src/python/user-guide/expressions/numpy-example.py b/docs/source/src/python/user-guide/expressions/numpy-example.py new file mode 100644 index 000000000000..ff5bd240b689 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/numpy-example.py @@ -0,0 +1,7 @@ +import polars as pl +import numpy as np + +df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + +out = df.select(np.log(pl.all()).name.suffix("_log")) +print(out) diff --git a/docs/source/src/python/user-guide/expressions/operations.py b/docs/source/src/python/user-guide/expressions/operations.py new file mode 100644 index 000000000000..556ed512c757 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/operations.py @@ -0,0 +1,132 @@ +# --8<-- [start:dataframe] +import polars as pl +import numpy as np + +np.random.seed(42) # For reproducibility. + +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", "spam"], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "A", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:arithmetic] +result = df.select( + (pl.col("nrs") + 5).alias("nrs + 5"), + (pl.col("nrs") - 5).alias("nrs - 5"), + (pl.col("nrs") * pl.col("random")).alias("nrs * random"), + (pl.col("nrs") / pl.col("random")).alias("nrs / random"), + (pl.col("nrs") ** 2).alias("nrs ** 2"), + (pl.col("nrs") % 3).alias("nrs % 3"), +) + +print(result) +# --8<-- [end:arithmetic] + +# --8<-- [start:operator-overloading] +# Python only: +result_named_operators = df.select( + (pl.col("nrs").add(5)).alias("nrs + 5"), + (pl.col("nrs").sub(5)).alias("nrs - 5"), + (pl.col("nrs").mul(pl.col("random"))).alias("nrs * random"), + (pl.col("nrs").truediv(pl.col("random"))).alias("nrs / random"), + (pl.col("nrs").pow(2)).alias("nrs ** 2"), + (pl.col("nrs").mod(3)).alias("nrs % 3"), +) + +print(result.equals(result_named_operators)) +# --8<-- [end:operator-overloading] + +# --8<-- [start:comparison] +result = df.select( + (pl.col("nrs") > 1).alias("nrs > 1"), # .gt + (pl.col("nrs") >= 3).alias("nrs >= 3"), # ge + (pl.col("random") < 0.2).alias("random < .2"), # .lt + (pl.col("random") <= 0.5).alias("random <= .5"), # .le + (pl.col("nrs") != 1).alias("nrs != 1"), # .ne + (pl.col("nrs") == 1).alias("nrs == 1"), # .eq +) +print(result) +# --8<-- [end:comparison] + +# --8<-- [start:boolean] +# Boolean operators & | ~ +result = df.select( + ((~pl.col("nrs").is_null()) & (pl.col("groups") == "A")).alias( + "number not null and group A" + ), + ((pl.col("random") < 0.5) | (pl.col("groups") == "B")).alias( + "random < 0.5 or group B" + ), +) + +print(result) + +# Corresponding named functions `and_`, `or_`, and `not_`. +result2 = df.select( + (pl.col("nrs").is_null().not_().and_(pl.col("groups") == "A")).alias( + "number not null and group A" + ), + ((pl.col("random") < 0.5).or_(pl.col("groups") == "B")).alias( + "random < 0.5 or group B" + ), +) +print(result.equals(result2)) +# --8<-- [end:boolean] + +# --8<-- [start:bitwise] +result = df.select( + pl.col("nrs"), + (pl.col("nrs") & 6).alias("nrs & 6"), + (pl.col("nrs") | 6).alias("nrs | 6"), + (~pl.col("nrs")).alias("not nrs"), + (pl.col("nrs") ^ 6).alias("nrs ^ 6"), +) + +print(result) +# --8<-- [end:bitwise] + +# --8<-- [start:count] +long_df = pl.DataFrame({"numbers": np.random.randint(0, 100_000, 100_000)}) + +result = long_df.select( + pl.col("numbers").n_unique().alias("n_unique"), + pl.col("numbers").approx_n_unique().alias("approx_n_unique"), +) + +print(result) +# --8<-- [end:count] + +# --8<-- [start:value_counts] +result = df.select( + pl.col("names").value_counts().alias("value_counts"), +) + +print(result) +# --8<-- [end:value_counts] + +# --8<-- [start:unique_counts] +result = df.select( + pl.col("names").unique(maintain_order=True).alias("unique"), + pl.col("names").unique_counts().alias("unique_counts"), +) + +print(result) +# --8<-- [end:unique_counts] + +# --8<-- [start:collatz] +result = df.select( + pl.col("nrs"), + pl.when(pl.col("nrs") % 2 == 1) # Is the number odd? + .then(3 * pl.col("nrs") + 1) # If so, multiply by 3 and add 1. + .otherwise(pl.col("nrs") // 2) # If not, divide by 2. + .alias("Collatz"), +) + +print(result) +# --8<-- [end:collatz] diff --git a/docs/source/src/python/user-guide/expressions/strings.py b/docs/source/src/python/user-guide/expressions/strings.py new file mode 100644 index 000000000000..a68d032d702f --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/strings.py @@ -0,0 +1,112 @@ +# --8<-- [start:df] +import polars as pl + +df = pl.DataFrame( + { + "language": ["English", "Dutch", "Portuguese", "Finish"], + "fruit": ["pear", "peer", "pêra", "päärynä"], + } +) + +result = df.with_columns( + pl.col("fruit").str.len_bytes().alias("byte_count"), + pl.col("fruit").str.len_chars().alias("letter_count"), +) +print(result) +# --8<-- [end:df] + +# --8<-- [start:existence] +result = df.select( + pl.col("fruit"), + pl.col("fruit").str.starts_with("p").alias("starts_with_p"), + pl.col("fruit").str.contains("p..r").alias("p..r"), + pl.col("fruit").str.contains("e+").alias("e+"), + pl.col("fruit").str.ends_with("r").alias("ends_with_r"), +) +print(result) +# --8<-- [end:existence] + +# --8<-- [start:extract] +df = pl.DataFrame( + { + "urls": [ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + } +) +result = df.select( + pl.col("urls").str.extract(r"candidate=(\w+)", group_index=1), +) +print(result) +# --8<-- [end:extract] + + +# --8<-- [start:extract_all] +df = pl.DataFrame({"text": ["123 bla 45 asd", "xyz 678 910t"]}) +result = df.select( + pl.col("text").str.extract_all(r"(\d+)").alias("extracted_nrs"), +) +print(result) +# --8<-- [end:extract_all] + + +# --8<-- [start:replace] +df = pl.DataFrame({"text": ["123abc", "abc456"]}) +result = df.with_columns( + pl.col("text").str.replace(r"\d", "-"), + pl.col("text").str.replace_all(r"\d", "-").alias("text_replace_all"), +) +print(result) +# --8<-- [end:replace] + +# --8<-- [start:casing] +addresses = pl.DataFrame( + { + "addresses": [ + "128 PERF st", + "Rust blVD, 158", + "PoLaRs Av, 12", + "1042 Query sq", + ] + } +) + +addresses = addresses.select( + pl.col("addresses").alias("originals"), + pl.col("addresses").str.to_titlecase(), + pl.col("addresses").str.to_lowercase().alias("lower"), + pl.col("addresses").str.to_uppercase().alias("upper"), +) +print(addresses) +# --8<-- [end:casing] + +# --8<-- [start:strip] +addr = pl.col("addresses") +chars = ", 0123456789" +result = addresses.select( + addr.str.strip_chars(chars).alias("strip"), + addr.str.strip_chars_end(chars).alias("end"), + addr.str.strip_chars_start(chars).alias("start"), + addr.str.strip_prefix("128 ").alias("prefix"), + addr.str.strip_suffix(", 158").alias("suffix"), +) +print(result) +# --8<-- [end:strip] + +# --8<-- [start:slice] +df = pl.DataFrame( + { + "fruits": ["pear", "mango", "dragonfruit", "passionfruit"], + "n": [1, -1, 4, -4], + } +) + +result = df.with_columns( + pl.col("fruits").str.slice(pl.col("n")).alias("slice"), + pl.col("fruits").str.head(pl.col("n")).alias("head"), + pl.col("fruits").str.tail(pl.col("n")).alias("tail"), +) +print(result) +# --8<-- [end:slice] diff --git a/docs/source/src/python/user-guide/expressions/structs.py b/docs/source/src/python/user-guide/expressions/structs.py new file mode 100644 index 000000000000..f500343b428d --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/structs.py @@ -0,0 +1,119 @@ +# --8<-- [start:ratings_df] +import polars as pl + +ratings = pl.DataFrame( + { + "Movie": ["Cars", "IT", "ET", "Cars", "Up", "IT", "Cars", "ET", "Up", "Cars"], + "Theatre": ["NE", "ME", "IL", "ND", "NE", "SD", "NE", "IL", "IL", "NE"], + "Avg_Rating": [4.5, 4.4, 4.6, 4.3, 4.8, 4.7, 4.5, 4.9, 4.7, 4.6], + "Count": [30, 27, 26, 29, 31, 28, 28, 26, 33, 28], + } +) +print(ratings) +# --8<-- [end:ratings_df] + +# --8<-- [start:state_value_counts] +result = ratings.select(pl.col("Theatre").value_counts(sort=True)) +print(result) +# --8<-- [end:state_value_counts] + +# --8<-- [start:struct_unnest] +result = ratings.select(pl.col("Theatre").value_counts(sort=True)).unnest("Theatre") +print(result) +# --8<-- [end:struct_unnest] + +# --8<-- [start:series_struct] +rating_series = pl.Series( + "ratings", + [ + {"Movie": "Cars", "Theatre": "NE", "Avg_Rating": 4.5}, + {"Movie": "Toy Story", "Theatre": "ME", "Avg_Rating": 4.9}, + ], +) +print(rating_series) +# --8<-- [end:series_struct] + +# --8<-- [start:series_struct_error] +null_rating_series = pl.Series( + "ratings", + [ + {"Movie": "Cars", "Theatre": "NE", "Avg_Rating": 4.5}, + {"Mov": "Toy Story", "Theatre": "ME", "Avg_Rating": 4.9}, + {"Movie": "Snow White", "Theatre": "IL", "Avg_Rating": "4.7"}, + ], + strict=False, # To show the final structs with `null` values. +) +print(null_rating_series) +# --8<-- [end:series_struct_error] + +# --8<-- [start:series_struct_extract] +result = rating_series.struct.field("Movie") +print(result) +# --8<-- [end:series_struct_extract] + +# --8<-- [start:series_struct_rename] +result = rating_series.struct.rename_fields(["Film", "State", "Value"]) +print(result) +# --8<-- [end:series_struct_rename] + +# --8<-- [start:struct-rename-check] +print( + result.to_frame().unnest("ratings"), +) +# --8<-- [end:struct-rename-check] + +# --8<-- [start:struct_duplicates] +result = ratings.filter(pl.struct("Movie", "Theatre").is_duplicated()) +print(result) +# --8<-- [end:struct_duplicates] + +# --8<-- [start:struct_ranking] +result = ratings.with_columns( + pl.struct("Count", "Avg_Rating") + .rank("dense", descending=True) + .over("Movie", "Theatre") + .alias("Rank") +).filter(pl.struct("Movie", "Theatre").is_duplicated()) + +print(result) +# --8<-- [end:struct_ranking] + +# --8<-- [start:multi_column_apply] +df = pl.DataFrame({"keys": ["a", "a", "b"], "values": [10, 7, 1]}) + +result = df.select( + pl.struct(["keys", "values"]) + .map_elements(lambda x: len(x["keys"]) + x["values"], return_dtype=pl.Int64) + .alias("solution_map_elements"), + (pl.col("keys").str.len_bytes() + pl.col("values")).alias("solution_expr"), +) +print(result) +# --8<-- [end:multi_column_apply] + + +# --8<-- [start:ack] +def ack(m, n): + if not m: + return n + 1 + if not n: + return ack(m - 1, 1) + return ack(m - 1, ack(m, n - 1)) + + +# --8<-- [end:ack] + +# --8<-- [start:struct-ack] +values = pl.DataFrame( + { + "m": [0, 0, 0, 1, 1, 1, 2], + "n": [2, 3, 4, 1, 2, 3, 1], + } +) +result = values.with_columns( + pl.struct(["m", "n"]) + .map_elements(lambda s: ack(s["m"], s["n"]), return_dtype=pl.Int64) + .alias("ack") +) + +print(result) +# --8<-- [end:struct-ack] diff --git a/docs/source/src/python/user-guide/expressions/user-defined-functions.py b/docs/source/src/python/user-guide/expressions/user-defined-functions.py new file mode 100644 index 000000000000..3962d8ac0b41 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/user-defined-functions.py @@ -0,0 +1,114 @@ +# --8<-- [start:setup] + +import warnings + +import polars as pl +from polars.exceptions import PolarsInefficientMapWarning + +warnings.simplefilter("ignore", PolarsInefficientMapWarning) +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "keys": ["a", "a", "b", "b"], + "values": [10, 7, 1, 23], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:individual_log] +import math + + +def my_log(value): + return math.log(value) + + +out = df.select(pl.col("values").map_elements(my_log, return_dtype=pl.Float64)) +print(out) +# --8<-- [end:individual_log] + + +# --8<-- [start:diff_from_mean] +def diff_from_mean(series): + # This will be very slow for non-trivial Series, since it's all Python + # code: + total = 0 + for value in series: + total += value + mean = total / len(series) + return pl.Series([value - mean for value in series]) + + +# Apply our custom function to a full Series with map_batches(): +out = df.select(pl.col("values").map_batches(diff_from_mean)) +print("== select() with UDF ==") +print(out) + +# Apply our custom function per group: +print("== group_by() with UDF ==") +out = df.group_by("keys").agg(pl.col("values").map_batches(diff_from_mean)) +print(out) +# --8<-- [end:diff_from_mean] + +# --8<-- [start:np_log] +import numpy as np + +out = df.select(pl.col("values").map_batches(np.log)) +print(out) +# --8<-- [end:np_log] + +# --8<-- [start:diff_from_mean_numba] +from numba import float64, guvectorize, int64 + + +# This will be compiled to machine code, so it will be fast. The Series is +# converted to a NumPy array before being passed to the function. See the +# Numba documentation for more details: +# https://numba.readthedocs.io/en/stable/user/vectorize.html +@guvectorize([(int64[:], float64[:])], "(n)->(n)") +def diff_from_mean_numba(arr, result): + total = 0 + for value in arr: + total += value + mean = total / len(arr) + for i, value in enumerate(arr): + result[i] = value - mean + + +out = df.select(pl.col("values").map_batches(diff_from_mean_numba)) +print("== select() with UDF ==") +print(out) + +out = df.group_by("keys").agg(pl.col("values").map_batches(diff_from_mean_numba)) +print("== group_by() with UDF ==") +print(out) +# --8<-- [end:diff_from_mean_numba] + + +# --8<-- [start:combine] +# Add two arrays together: +@guvectorize([(int64[:], int64[:], float64[:])], "(n),(n)->(n)") +def add(arr, arr2, result): + for i in range(len(arr)): + result[i] = arr[i] + arr2[i] + + +df3 = pl.DataFrame({"values1": [1, 2, 3], "values2": [10, 20, 30]}) + +out = df3.select( + # Create a struct that has two columns in it: + pl.struct(["values1", "values2"]) + # Pass the struct to a lambda that then passes the individual columns to + # the add() function: + .map_batches( + lambda combined: add( + combined.struct.field("values1"), combined.struct.field("values2") + ) + ) + .alias("add_columns") +) +print(out) +# --8<-- [end:combine] diff --git a/docs/source/src/python/user-guide/expressions/window.py b/docs/source/src/python/user-guide/expressions/window.py new file mode 100644 index 000000000000..f82da48d75f1 --- /dev/null +++ b/docs/source/src/python/user-guide/expressions/window.py @@ -0,0 +1,149 @@ +# --8<-- [start:pokemon] +import polars as pl + +types = ( + "Grass Water Fire Normal Ground Electric Psychic Fighting Bug Steel " + "Flying Dragon Dark Ghost Poison Rock Ice Fairy".split() +) +type_enum = pl.Enum(types) +# then let's load some csv data with information about pokemon +pokemon = pl.read_csv( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv", +).cast({"Type 1": type_enum, "Type 2": type_enum}) +print(pokemon.head()) +# --8<-- [end:pokemon] + +# --8<-- [start:rank] +result = pokemon.select( + pl.col("Name", "Type 1"), + pl.col("Speed").rank("dense", descending=True).over("Type 1").alias("Speed rank"), +) + +print(result) +# --8<-- [end:rank] + +# --8<-- [start:rank-multiple] +result = pokemon.select( + pl.col("Name", "Type 1", "Type 2"), + pl.col("Speed") + .rank("dense", descending=True) + .over("Type 1", "Type 2") + .alias("Speed rank"), +) + +print(result) +# --8<-- [end:rank-multiple] + +# --8<-- [start:rank-explode] +result = ( + pokemon.group_by("Type 1") + .agg( + pl.col("Name"), + pl.col("Speed").rank("dense", descending=True).alias("Speed rank"), + ) + .select(pl.col("Name"), pl.col("Type 1"), pl.col("Speed rank")) + .explode("Name", "Speed rank") +) + +print(result) +# --8<-- [end:rank-explode] + +# --8<-- [start:athletes] +athletes = pl.DataFrame( + { + "athlete": list("ABCDEF"), + "country": ["PT", "NL", "NL", "PT", "PT", "NL"], + "rank": [6, 1, 5, 4, 2, 3], + } +) +print(athletes) +# --8<-- [end:athletes] + +# --8<-- [start:athletes-sort-over-country] +result = athletes.select( + pl.col("athlete", "rank").sort_by(pl.col("rank")).over(pl.col("country")), + pl.col("country"), +) + +print(result) +# --8<-- [end:athletes-sort-over-country] + +# --8<-- [start:athletes-explode] +result = athletes.select( + pl.all() + .sort_by(pl.col("rank")) + .over(pl.col("country"), mapping_strategy="explode"), +) + +print(result) +# --8<-- [end:athletes-explode] + +# --8<-- [start:athletes-join] +result = athletes.with_columns( + pl.col("rank").sort().over(pl.col("country"), mapping_strategy="join"), +) + +print(result) +# --8<-- [end:athletes-join] + +# --8<-- [start:pokemon-mean] +result = pokemon.select( + pl.col("Name", "Type 1", "Speed"), + pl.col("Speed").mean().over(pl.col("Type 1")).alias("Mean speed in group"), +) + +print(result) +# --8<-- [end:pokemon-mean] + + +# --8<-- [start:group_by] +result = pokemon.select( + "Type 1", + "Type 2", + pl.col("Attack").mean().over("Type 1").alias("avg_attack_by_type"), + pl.col("Defense") + .mean() + .over(["Type 1", "Type 2"]) + .alias("avg_defense_by_type_combination"), + pl.col("Attack").mean().alias("avg_attack"), +) +print(result) +# --8<-- [end:group_by] + +# --8<-- [start:operations] +filtered = pokemon.filter(pl.col("Type 2") == "Psychic").select( + "Name", + "Type 1", + "Speed", +) +print(filtered) +# --8<-- [end:operations] + +# --8<-- [start:sort] +result = filtered.with_columns( + pl.col("Name", "Speed").sort_by("Speed", descending=True).over("Type 1"), +) +print(result) +# --8<-- [end:sort] + +# --8<-- [start:examples] +result = pokemon.sort("Type 1").select( + pl.col("Type 1").head(3).over("Type 1", mapping_strategy="explode"), + pl.col("Name") + .sort_by(pl.col("Speed"), descending=True) + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("fastest/group"), + pl.col("Name") + .sort_by(pl.col("Attack"), descending=True) + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("strongest/group"), + pl.col("Name") + .sort() + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("sorted_by_alphabet"), +) +print(result) +# --8<-- [end:examples] diff --git a/docs/source/src/python/user-guide/getting-started.py b/docs/source/src/python/user-guide/getting-started.py new file mode 100644 index 000000000000..d68207ebd60d --- /dev/null +++ b/docs/source/src/python/user-guide/getting-started.py @@ -0,0 +1,135 @@ +# --8<-- [start:df] +import polars as pl +import datetime as dt + +df = pl.DataFrame( + { + "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate": [ + dt.date(1997, 1, 10), + dt.date(1985, 2, 15), + dt.date(1983, 3, 22), + dt.date(1981, 4, 30), + ], + "weight": [57.9, 72.5, 53.6, 83.1], # (kg) + "height": [1.56, 1.77, 1.65, 1.75], # (m) + } +) + +print(df) +# --8<-- [end:df] + +# --8<-- [start:csv] +df.write_csv("docs/assets/data/output.csv") +df_csv = pl.read_csv("docs/assets/data/output.csv", try_parse_dates=True) +print(df_csv) +# --8<-- [end:csv] + +# --8<-- [start:select] +result = df.select( + pl.col("name"), + pl.col("birthdate").dt.year().alias("birth_year"), + (pl.col("weight") / (pl.col("height") ** 2)).alias("bmi"), +) +print(result) +# --8<-- [end:select] + +# --8<-- [start:expression-expansion] +result = df.select( + pl.col("name"), + (pl.col("weight", "height") * 0.95).round(2).name.suffix("-5%"), +) +print(result) +# --8<-- [end:expression-expansion] + +# --8<-- [start:with_columns] +result = df.with_columns( + birth_year=pl.col("birthdate").dt.year(), + bmi=pl.col("weight") / (pl.col("height") ** 2), +) +print(result) +# --8<-- [end:with_columns] + +# --8<-- [start:filter] +result = df.filter(pl.col("birthdate").dt.year() < 1990) +print(result) +# --8<-- [end:filter] + +# --8<-- [start:filter-multiple] +result = df.filter( + pl.col("birthdate").is_between(dt.date(1982, 12, 31), dt.date(1996, 1, 1)), + pl.col("height") > 1.7, +) +print(result) +# --8<-- [end:filter-multiple] + +# --8<-- [start:group_by] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + maintain_order=True, +).len() +print(result) +# --8<-- [end:group_by] + +# --8<-- [start:group_by-agg] +result = df.group_by( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + maintain_order=True, +).agg( + pl.len().alias("sample_size"), + pl.col("weight").mean().round(2).alias("avg_weight"), + pl.col("height").max().alias("tallest"), +) +print(result) +# --8<-- [end:group_by-agg] + +# --8<-- [start:complex] +result = ( + df.with_columns( + (pl.col("birthdate").dt.year() // 10 * 10).alias("decade"), + pl.col("name").str.split(by=" ").list.first(), + ) + .select( + pl.all().exclude("birthdate"), + ) + .group_by( + pl.col("decade"), + maintain_order=True, + ) + .agg( + pl.col("name"), + pl.col("weight", "height").mean().round(2).name.prefix("avg_"), + ) +) +print(result) +# --8<-- [end:complex] + +# --8<-- [start:join] +df2 = pl.DataFrame( + { + "name": ["Ben Brown", "Daniel Donovan", "Alice Archer", "Chloe Cooper"], + "parent": [True, False, False, False], + "siblings": [1, 2, 3, 4], + } +) + +print(df.join(df2, on="name", how="left")) +# --8<-- [end:join] + +# --8<-- [start:concat] +df3 = pl.DataFrame( + { + "name": ["Ethan Edwards", "Fiona Foster", "Grace Gibson", "Henry Harris"], + "birthdate": [ + dt.date(1977, 5, 10), + dt.date(1975, 6, 23), + dt.date(1973, 7, 22), + dt.date(1971, 8, 3), + ], + "weight": [67.9, 72.5, 57.6, 93.1], # (kg) + "height": [1.76, 1.6, 1.66, 1.8], # (m) + } +) + +print(pl.concat([df, df3], how="vertical")) +# --8<-- [end:concat] diff --git a/docs/source/src/python/user-guide/io/bigquery.py b/docs/source/src/python/user-guide/io/bigquery.py new file mode 100644 index 000000000000..3b74bb163008 --- /dev/null +++ b/docs/source/src/python/user-guide/io/bigquery.py @@ -0,0 +1,41 @@ +""" +# --8<-- [start:read] +import polars as pl +from google.cloud import bigquery + +client = bigquery.Client() + +# Perform a query. +QUERY = ( + 'SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` ' + 'WHERE state = "TX" ' + 'LIMIT 100') +query_job = client.query(QUERY) # API request +rows = query_job.result() # Waits for query to finish + +df = pl.from_arrow(rows.to_arrow()) +# --8<-- [end:read] + +# --8<-- [start:write] +from google.cloud import bigquery + +client = bigquery.Client() + +# Write DataFrame to stream as parquet file; does not hit disk +with io.BytesIO() as stream: + df.write_parquet(stream) + stream.seek(0) + parquet_options = bigquery.ParquetOptions() + parquet_options.enable_list_inference = True + job = client.load_table_from_file( + stream, + destination='tablename', + project='projectname', + job_config=bigquery.LoadJobConfig( + source_format=bigquery.SourceFormat.PARQUET, + parquet_options=parquet_options, + ), + ) +job.result() # Waits for the job to complete +# --8<-- [end:write] +""" diff --git a/docs/source/src/python/user-guide/io/cloud-storage.py b/docs/source/src/python/user-guide/io/cloud-storage.py new file mode 100644 index 000000000000..07be82588a3b --- /dev/null +++ b/docs/source/src/python/user-guide/io/cloud-storage.py @@ -0,0 +1,120 @@ +""" +# --8<-- [start:read_parquet] +import polars as pl + +source = "s3://bucket/*.parquet" + +df = pl.read_parquet(source) +# --8<-- [end:read_parquet] + +# --8<-- [start:scan_parquet_query] +import polars as pl + +source = "s3://bucket/*.parquet" + +df = pl.scan_parquet(source).filter(pl.col("id") < 100).select("id","value").collect() +# --8<-- [end:scan_parquet_query] + + +# --8<-- [start:scan_parquet_storage_options_aws] +import polars as pl + +source = "s3://bucket/*.parquet" + +storage_options = { + "aws_access_key_id": "", + "aws_secret_access_key": "", + "aws_region": "us-east-1", +} +df = pl.scan_parquet(source, storage_options=storage_options).collect() +# --8<-- [end:scan_parquet_storage_options_aws] + +# --8<-- [start:credential_provider_class] +lf = pl.scan_parquet( + "s3://.../...", + credential_provider=pl.CredentialProviderAWS( + profile_name="..." + assume_role={ + "RoleArn": f"...", + "RoleSessionName": "...", + } + ), +) + +df = lf.collect() +# --8<-- [end:credential_provider_class] + +# --8<-- [start:credential_provider_custom_func] +def get_credentials() -> pl.CredentialProviderFunctionReturn: + expiry = None + + return { + "aws_access_key_id": "...", + "aws_secret_access_key": "...", + "aws_session_token": "...", + }, expiry + + +lf = pl.scan_parquet( + "s3://.../...", + credential_provider=get_credentials, +) + +df = lf.collect() +# --8<-- [end:credential_provider_custom_func] + +# --8<-- [start:credential_provider_custom_func_azure] +def credential_provider(): + credential = DefaultAzureCredential(exclude_managed_identity_credential=True) + token = credential.get_token("https://storage.azure.com/.default") + + return {"bearer_token": token.token}, token.expires_on + + +pl.scan_parquet( + "abfss://...@.../...", + credential_provider=credential_provider, +) + +# Note that for the above case, this shortcut is also available: + +pl.scan_parquet( + "abfss://...@.../...", + credential_provider=pl.CredentialProviderAzure( + credentials=DefaultAzureCredential(exclude_managed_identity_credential=True) + ), +) + +# --8<-- [end:credential_provider_custom_func_azure] + +# --8<-- [start:scan_pyarrow_dataset] +import polars as pl +import pyarrow.dataset as ds + +dset = ds.dataset("s3://my-partitioned-folder/", format="parquet") +( + pl.scan_pyarrow_dataset(dset) + .filter(pl.col("foo") == "a") + .select(["foo", "bar"]) + .collect() +) +# --8<-- [end:scan_pyarrow_dataset] + +# --8<-- [start:write_parquet] + +import polars as pl +import s3fs + +df = pl.DataFrame({ + "foo": ["a", "b", "c", "d", "d"], + "bar": [1, 2, 3, 4, 5], +}) + +fs = s3fs.S3FileSystem() +destination = "s3://bucket/my_file.parquet" + +# write parquet +with fs.open(destination, mode='wb') as f: + df.write_parquet(f) +# --8<-- [end:write_parquet] +""" diff --git a/docs/source/src/python/user-guide/io/csv.py b/docs/source/src/python/user-guide/io/csv.py new file mode 100644 index 000000000000..8816bc216321 --- /dev/null +++ b/docs/source/src/python/user-guide/io/csv.py @@ -0,0 +1,19 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_csv("docs/assets/data/path.csv") +# --8<-- [end:read] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_csv("docs/assets/data/path.csv") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_csv("docs/assets/data/path.csv") +# --8<-- [end:scan] diff --git a/docs/source/src/python/user-guide/io/database.py b/docs/source/src/python/user-guide/io/database.py new file mode 100644 index 000000000000..1a6536f55222 --- /dev/null +++ b/docs/source/src/python/user-guide/io/database.py @@ -0,0 +1,43 @@ +""" +# --8<-- [start:read_uri] +import polars as pl + +uri = "postgresql://username:password@server:port/database" +query = "SELECT * FROM foo" + +pl.read_database_uri(query=query, uri=uri) +# --8<-- [end:read_uri] + +# --8<-- [start:read_cursor] +import polars as pl +from sqlalchemy import create_engine + +conn = create_engine(f"sqlite:///test.db") + +query = "SELECT * FROM foo" + +pl.read_database(query=query, connection=conn.connect()) +# --8<-- [end:read_cursor] + + +# --8<-- [start:adbc] +uri = "postgresql://username:password@server:port/database" +query = "SELECT * FROM foo" + +pl.read_database_uri(query=query, uri=uri, engine="adbc") +# --8<-- [end:adbc] + +# --8<-- [start:write] +uri = "postgresql://username:password@server:port/database" +df = pl.DataFrame({"foo": [1, 2, 3]}) + +df.write_database(table_name="records", connection=uri) +# --8<-- [end:write] + +# --8<-- [start:write_adbc] +uri = "postgresql://username:password@server:port/database" +df = pl.DataFrame({"foo": [1, 2, 3]}) + +df.write_database(table_name="records", connection=uri, engine="adbc") +# --8<-- [end:write_adbc] +""" diff --git a/docs/source/src/python/user-guide/io/excel.py b/docs/source/src/python/user-guide/io/excel.py new file mode 100644 index 000000000000..3931d0d41cac --- /dev/null +++ b/docs/source/src/python/user-guide/io/excel.py @@ -0,0 +1,26 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_excel("docs/assets/data/path.xlsx") +# --8<-- [end:read] + +# --8<-- [start:read_sheet_name] +df = pl.read_excel("docs/assets/data/path.xlsx", sheet_name="Sales") +# --8<-- [end:read_sheet_name] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_excel("docs/assets/data/path.xlsx") +# --8<-- [end:write] + +""" +# --8<-- [start:write_sheet_name] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_excel("docs/assets/data/path.xlsx", worksheet="Sales") +# --8<-- [end:write_sheet_name] +""" diff --git a/docs/source/src/python/user-guide/io/hive.py b/docs/source/src/python/user-guide/io/hive.py new file mode 100644 index 000000000000..6d9c4d5f88d1 --- /dev/null +++ b/docs/source/src/python/user-guide/io/hive.py @@ -0,0 +1,132 @@ +# --8<-- [start:init_paths] +import polars as pl +from pathlib import Path + +dfs = [ + pl.DataFrame({"x": [1, 2]}), + pl.DataFrame({"x": [3, 4, 5]}), + pl.DataFrame({"x": [6, 7]}), + pl.DataFrame({"x": [8, 9, 10, 11]}), +] + +parts = [ + "year=2023/month=11", + "year=2023/month=12", + "year=2024/month=01", + "year=2024/month=02", +] + +for df, part in zip(dfs, parts): + path = Path("docs/assets/data/hive/") / part / "data.parquet" + Path(path).parent.mkdir(exist_ok=True, parents=True) + df.write_parquet(path) + + path = Path("docs/assets/data/hive_mixed/") / part / "data.parquet" + Path(path).parent.mkdir(exist_ok=True, parents=True) + df.write_parquet(path) + +# Make sure the file is not empty because path expansion ignores empty files. +Path("docs/assets/data/hive_mixed/description.txt").write_text("A") + + +def print_paths(path: str) -> None: + def dir_recurse(path: Path): + if path.is_dir(): + for p in path.iterdir(): + yield from dir_recurse(p) + else: + yield path + + df = ( + pl.Series( + "File path", + (str(x) for x in dir_recurse(Path(path))), + dtype=pl.String, + ) + .sort() + .to_frame() + ) + + with pl.Config( + tbl_hide_column_data_types=True, + tbl_hide_dataframe_shape=True, + fmt_str_lengths=999, + ): + print(df) + + +print_paths("docs/assets/data/hive/") +# --8<-- [end:init_paths] + +# --8<-- [start:show_mixed_paths] +print_paths("docs/assets/data/hive_mixed/") +# --8<-- [end:show_mixed_paths] + +# --8<-- [start:scan_dir] +import polars as pl + +df = pl.scan_parquet("docs/assets/data/hive/").collect() + +with pl.Config(tbl_rows=99): + print(df) +# --8<-- [end:scan_dir] + +# --8<-- [start:scan_dir_err] +from pathlib import Path + +try: + pl.scan_parquet("docs/assets/data/hive_mixed/").collect() +except Exception as e: + print(e) + +# --8<-- [end:scan_dir_err] + +# --8<-- [start:scan_glob] +df = pl.scan_parquet( + # Glob to match all files ending in `.parquet` + "docs/assets/data/hive_mixed/**/*.parquet", + hive_partitioning=True, +).collect() + +with pl.Config(tbl_rows=99): + print(df) + +# --8<-- [end:scan_glob] + +# --8<-- [start:scan_file_no_hive] +df = pl.scan_parquet( + [ + "docs/assets/data/hive/year=2024/month=01/data.parquet", + "docs/assets/data/hive/year=2024/month=02/data.parquet", + ], +).collect() + +print(df) + +# --8<-- [end:scan_file_no_hive] + +# --8<-- [start:scan_file_hive] +df = pl.scan_parquet( + [ + "docs/assets/data/hive/year=2024/month=01/data.parquet", + "docs/assets/data/hive/year=2024/month=02/data.parquet", + ], + hive_partitioning=True, +).collect() + +print(df) + +# --8<-- [end:scan_file_hive] + +# --8<-- [start:write_parquet_partitioned_show_data] +df = pl.DataFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 2], "c": 1}) +print(df) +# --8<-- [end:write_parquet_partitioned_show_data] + +# --8<-- [start:write_parquet_partitioned] +df.write_parquet("docs/assets/data/hive_write/", partition_by=["a", "b"]) +# --8<-- [end:write_parquet_partitioned] + +# --8<-- [start:write_parquet_partitioned_show_paths] +print_paths("docs/assets/data/hive_write/") +# --8<-- [end:write_parquet_partitioned_show_paths] diff --git a/docs/source/src/python/user-guide/io/hugging-face.py b/docs/source/src/python/user-guide/io/hugging-face.py new file mode 100644 index 000000000000..cb7d21583682 --- /dev/null +++ b/docs/source/src/python/user-guide/io/hugging-face.py @@ -0,0 +1,96 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:scan_iris_csv] +print(pl.scan_csv("hf://datasets/nameexhaustion/polars-docs/iris.csv").collect()) +# --8<-- [end:scan_iris_csv] + +# --8<-- [start:scan_iris_ndjson] +print(pl.scan_ndjson("hf://datasets/nameexhaustion/polars-docs/iris.jsonl").collect()) +# --8<-- [end:scan_iris_ndjson] + +# --8<-- [start:scan_iris_repr] +print( + """\ +shape: (150, 5) +┌──────────────┬─────────────┬──────────────┬─────────────┬───────────┐ +│ sepal_length ┆ sepal_width ┆ petal_length ┆ petal_width ┆ species │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ str │ +╞══════════════╪═════════════╪══════════════╪═════════════╪═══════════╡ +│ 5.1 ┆ 3.5 ┆ 1.4 ┆ 0.2 ┆ setosa │ +│ 4.9 ┆ 3.0 ┆ 1.4 ┆ 0.2 ┆ setosa │ +│ 4.7 ┆ 3.2 ┆ 1.3 ┆ 0.2 ┆ setosa │ +│ 4.6 ┆ 3.1 ┆ 1.5 ┆ 0.2 ┆ setosa │ +│ 5.0 ┆ 3.6 ┆ 1.4 ┆ 0.2 ┆ setosa │ +│ … ┆ … ┆ … ┆ … ┆ … │ +│ 6.7 ┆ 3.0 ┆ 5.2 ┆ 2.3 ┆ virginica │ +│ 6.3 ┆ 2.5 ┆ 5.0 ┆ 1.9 ┆ virginica │ +│ 6.5 ┆ 3.0 ┆ 5.2 ┆ 2.0 ┆ virginica │ +│ 6.2 ┆ 3.4 ┆ 5.4 ┆ 2.3 ┆ virginica │ +│ 5.9 ┆ 3.0 ┆ 5.1 ┆ 1.8 ┆ virginica │ +└──────────────┴─────────────┴──────────────┴─────────────┴───────────┘ +""" +) +# --8<-- [end:scan_iris_repr] + +# --8<-- [start:scan_parquet_hive] +print(pl.scan_parquet("hf://datasets/nameexhaustion/polars-docs/hive_dates/").collect()) +# --8<-- [end:scan_parquet_hive] + +# --8<-- [start:scan_parquet_hive_repr] +print( + """\ +shape: (4, 3) +┌────────────┬────────────────────────────┬─────┐ +│ date1 ┆ date2 ┆ x │ +│ --- ┆ --- ┆ --- │ +│ date ┆ datetime[μs] ┆ i32 │ +╞════════════╪════════════════════════════╪═════╡ +│ 2024-01-01 ┆ 2023-01-01 00:00:00 ┆ 1 │ +│ 2024-02-01 ┆ 2023-02-01 00:00:00 ┆ 2 │ +│ 2024-03-01 ┆ null ┆ 3 │ +│ null ┆ 2023-03-01 01:01:01.000001 ┆ 4 │ +└────────────┴────────────────────────────┴─────┘ +""" +) +# --8<-- [end:scan_parquet_hive_repr] + +# --8<-- [start:scan_ipc] +print(pl.scan_ipc("hf://spaces/nameexhaustion/polars-docs/orders.feather").collect()) +# --8<-- [end:scan_ipc] + +# --8<-- [start:scan_ipc_repr] +print( + """\ +shape: (10, 9) +┌────────────┬───────────┬───────────────┬──────────────┬───┬─────────────────┬─────────────────┬────────────────┬─────────────────────────┐ +│ o_orderkey ┆ o_custkey ┆ o_orderstatus ┆ o_totalprice ┆ … ┆ o_orderpriority ┆ o_clerk ┆ o_shippriority ┆ o_comment │ +│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ i64 ┆ str ┆ f64 ┆ ┆ str ┆ str ┆ i64 ┆ str │ +╞════════════╪═══════════╪═══════════════╪══════════════╪═══╪═════════════════╪═════════════════╪════════════════╪═════════════════════════╡ +│ 1 ┆ 36901 ┆ O ┆ 173665.47 ┆ … ┆ 5-LOW ┆ Clerk#000000951 ┆ 0 ┆ nstructions sleep │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ furiously am… │ +│ 2 ┆ 78002 ┆ O ┆ 46929.18 ┆ … ┆ 1-URGENT ┆ Clerk#000000880 ┆ 0 ┆ foxes. pending accounts │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ at th… │ +│ 3 ┆ 123314 ┆ F ┆ 193846.25 ┆ … ┆ 5-LOW ┆ Clerk#000000955 ┆ 0 ┆ sly final accounts │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ boost. care… │ +│ 4 ┆ 136777 ┆ O ┆ 32151.78 ┆ … ┆ 5-LOW ┆ Clerk#000000124 ┆ 0 ┆ sits. slyly regular │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ warthogs c… │ +│ 5 ┆ 44485 ┆ F ┆ 144659.2 ┆ … ┆ 5-LOW ┆ Clerk#000000925 ┆ 0 ┆ quickly. bold deposits │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ sleep s… │ +│ 6 ┆ 55624 ┆ F ┆ 58749.59 ┆ … ┆ 4-NOT SPECIFIED ┆ Clerk#000000058 ┆ 0 ┆ ggle. special, final │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ requests … │ +│ 7 ┆ 39136 ┆ O ┆ 252004.18 ┆ … ┆ 2-HIGH ┆ Clerk#000000470 ┆ 0 ┆ ly special requests │ +│ 32 ┆ 130057 ┆ O ┆ 208660.75 ┆ … ┆ 2-HIGH ┆ Clerk#000000616 ┆ 0 ┆ ise blithely bold, │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ regular req… │ +│ 33 ┆ 66958 ┆ F ┆ 163243.98 ┆ … ┆ 3-MEDIUM ┆ Clerk#000000409 ┆ 0 ┆ uriously. furiously │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ final requ… │ +│ 34 ┆ 61001 ┆ O ┆ 58949.67 ┆ … ┆ 3-MEDIUM ┆ Clerk#000000223 ┆ 0 ┆ ly final packages. │ +│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ fluffily fi… │ +└────────────┴───────────┴───────────────┴──────────────┴───┴─────────────────┴─────────────────┴────────────────┴─────────────────────────┘ +""" +) +# --8<-- [end:scan_ipc_repr] diff --git a/docs/source/src/python/user-guide/io/json.py b/docs/source/src/python/user-guide/io/json.py new file mode 100644 index 000000000000..0efa8908d262 --- /dev/null +++ b/docs/source/src/python/user-guide/io/json.py @@ -0,0 +1,27 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_json("docs/assets/data/path.json") +# --8<-- [end:read] + +# --8<-- [start:readnd] +df = pl.read_ndjson("docs/assets/data/path.json") +# --8<-- [end:readnd] + +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_json("docs/assets/data/path.json") +# --8<-- [end:write] + +""" +# --8<-- [start:scan] +df = pl.scan_ndjson("docs/assets/data/path.json") +# --8<-- [end:scan] + +""" diff --git a/docs/source/src/python/user-guide/io/multiple.py b/docs/source/src/python/user-guide/io/multiple.py new file mode 100644 index 000000000000..08aca544b7dc --- /dev/null +++ b/docs/source/src/python/user-guide/io/multiple.py @@ -0,0 +1,42 @@ +# --8<-- [start:create] +import polars as pl + +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "ham", "spam"]}) + +for i in range(5): + df.write_csv(f"docs/assets/data/my_many_files_{i}.csv") +# --8<-- [end:create] + +# --8<-- [start:read] +df = pl.read_csv("docs/assets/data/my_many_files_*.csv") +print(df) +# --8<-- [end:read] + +# --8<-- [start:creategraph] +import base64 + +pl.scan_csv("docs/assets/data/my_many_files_*.csv").show_graph( + output_path="docs/assets/images/multiple.png", show=False +) +with open("docs/assets/images/multiple.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:creategraph] + +# --8<-- [start:graph] +pl.scan_csv("docs/assets/data/my_many_files_*.csv").show_graph() +# --8<-- [end:graph] + +# --8<-- [start:glob] +import glob + +import polars as pl + +queries = [] +for file in glob.glob("docs/assets/data/my_many_files_*.csv"): + q = pl.scan_csv(file).group_by("bar").agg(pl.len(), pl.sum("foo")) + queries.append(q) + +dataframes = pl.collect_all(queries) +print(dataframes) +# --8<-- [end:glob] diff --git a/docs/source/src/python/user-guide/io/parquet.py b/docs/source/src/python/user-guide/io/parquet.py new file mode 100644 index 000000000000..f92c73e7bc16 --- /dev/null +++ b/docs/source/src/python/user-guide/io/parquet.py @@ -0,0 +1,19 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_parquet("docs/assets/data/path.parquet") +# --8<-- [end:read] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_parquet("docs/assets/data/path.parquet") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_parquet("docs/assets/data/path.parquet") +# --8<-- [end:scan] diff --git a/docs/source/src/python/user-guide/io/sheets_colab.py b/docs/source/src/python/user-guide/io/sheets_colab.py new file mode 100644 index 000000000000..332aea0d1209 --- /dev/null +++ b/docs/source/src/python/user-guide/io/sheets_colab.py @@ -0,0 +1,29 @@ +""" +# --8<-- [start:open] +import polars as pl +from google.colab import sheets +url = "https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms" +sheet = sheets.InteractiveSheet(url=url, backend="polars", display=False) +sheet.as_df() +# --8<-- [end:open] +# --8<-- [start:create_title] +sheet = sheets.InteractiveSheet(title="Colab <3 Polars", backend="polars") +# --8<-- [end:create_title] +# --8<-- [start:create_df] +df = pl.DataFrame({"a": [1,2,3], "b": ["a", "b", "c"]}) +sheet = sheets.InteractiveSheet(df=df, title="Colab <3 Polars", backend="polars") +# --8<-- [end:create_df] +# --8<-- [start:update] +sheet.update(df) +# --8<-- [end:update] +# --8<-- [start:update_loc] +sheet.update(df, clear=False) +sheet.update(df, location="D3") +sheet.update(df, location=(3, 4)) +# --8<-- [end:update_loc] +# --8<-- [start:update_loop] +for i, df in dfs: + df = pl.select(x=pl.arange(5)).with_columns(pow=pl.col("x") ** i) + sheet.update(df, loc=(1, i * 3), clear=i == 0) +# --8<-- [end:update_loop] +""" diff --git a/docs/source/src/python/user-guide/lazy/execution.py b/docs/source/src/python/user-guide/lazy/execution.py new file mode 100644 index 000000000000..aa425102d5aa --- /dev/null +++ b/docs/source/src/python/user-guide/lazy/execution.py @@ -0,0 +1,40 @@ +""" +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +q1 = ( + pl.scan_csv("docs/assets/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:df] + +# --8<-- [start:collect] +q4 = ( + pl.scan_csv(f"docs/assets/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect() +) +# --8<-- [end:collect] +# --8<-- [start:stream] +q5 = ( + pl.scan_csv(f"docs/assets/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect(engine='streaming') +) +# --8<-- [end:stream] +# --8<-- [start:partial] +q9 = ( + pl.scan_csv(f"docs/assets/data/reddit.csv") + .head(10) + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect() +) +# --8<-- [end:partial] +""" diff --git a/docs/source/src/python/user-guide/lazy/gpu.py b/docs/source/src/python/user-guide/lazy/gpu.py new file mode 100644 index 000000000000..21d246d09d76 --- /dev/null +++ b/docs/source/src/python/user-guide/lazy/gpu.py @@ -0,0 +1,46 @@ +# --8<-- [start:setup] +import polars as pl + +df = pl.LazyFrame({"a": [1.242, 1.535]}) + +q = df.select(pl.col("a").round(1)) +# --8<-- [end:setup] + +# Avoiding requiring the GPU engine for doc build output +# --8<-- [start:simple-result] +result = q.collect() +print(result) +# --8<-- [end:simple-result] + + +# --8<-- [start:engine-setup] +q = df.select((pl.col("a") ** 4)) + +# --8<-- [end:engine-setup] + +# --8<-- [start:engine-result] +result = q.collect() +print(result) +# --8<-- [end:engine-result] + +# --8<-- [start:fallback-setup] +df = pl.LazyFrame( + { + "key": [1, 1, 1, 2, 3, 3, 2, 2], + "value": [1, 2, 3, 4, 5, 6, 7, 8], + } +) + +q = df.select(pl.col("value").sum().over("key")) + +# --8<-- [end:fallback-setup] + +# --8<-- [start:fallback-result] +print( + "PerformanceWarning: Query execution with GPU not supported, reason: \n" + ": Grouped rolling window not implemented" +) +print("# some details elided") +print() +print(q.collect()) +# --8<- [end:fallback-result] diff --git a/docs/source/src/python/user-guide/lazy/multiplexing.py b/docs/source/src/python/user-guide/lazy/multiplexing.py new file mode 100644 index 000000000000..d7055a19d902 --- /dev/null +++ b/docs/source/src/python/user-guide/lazy/multiplexing.py @@ -0,0 +1,81 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np +import tempfile +import base64 +import polars.testing + + +def show_plan(q: pl.LazyFrame, optimized: bool = True): + with tempfile.NamedTemporaryFile() as fp: + q.show_graph(show=False, output_path=fp.name, optimized=optimized) + with open(fp.name, "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') + + +# --8<-- [end:setup] + + +# --8<-- [start:dataframe] +np.random.seed(0) +a = np.arange(0, 10) +np.random.shuffle(a) +df = pl.DataFrame({"n": a}) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:eager] +# A group-by doesn't guarantee order +df1 = df.group_by("n").len() + +# Take the lower half and the upper half in a list +out = [df1.slice(offset=i * 5, length=5) for i in range(2)] + +# Assert df1 is equal to the sum of both halves +pl.testing.assert_frame_equal(df1, pl.concat(out)) +# --8<-- [end:eager] + +""" +# --8<-- [start:lazy] +lf1 = df.lazy().group_by("n").len() + +out = [lf1.slice(offset=i * 5, length=5).collect() for i in range(2)] + +pl.testing.assert_frame_equal(lf1.collect(), pl.concat(out)) +# --8<-- [end:lazy] +""" + +# --8<-- [start:plan_0] +q1 = df.lazy().group_by("n").len() +show_plan(q1, optimized=False) +# --8<-- [end:plan_0] + +# --8<-- [start:plan_1] +q1 = df.lazy().group_by("n").len() +q2 = q1.slice(offset=0, length=5) +show_plan(q2, optimized=False) +# --8<-- [end:plan_1] + +# --8<-- [start:plan_2] +q1 = df.lazy().group_by("n").len() +q2 = q1.slice(offset=5, length=5) +show_plan(q2, optimized=False) +# --8<-- [end:plan_2] + + +# --8<-- [start:collect_all] +lf1 = df.lazy().group_by("n").len() + +out = [lf1.slice(offset=i * 5, length=5) for i in range(2)] +results = pl.collect_all([lf1] + out) + +pl.testing.assert_frame_equal(results[0], pl.concat(results[1:])) +# --8<-- [end:collect_all] + +# --8<-- [start:explain_all] +lf1 = df.lazy().group_by("n").len() +out = [lf1.slice(offset=i * 5, length=5) for i in range(2)] + +print(pl.explain_all([lf1] + out)) +# --8<-- [end:explain_all] diff --git a/docs/source/src/python/user-guide/lazy/query-plan.py b/docs/source/src/python/user-guide/lazy/query-plan.py new file mode 100644 index 000000000000..f0aae1383455 --- /dev/null +++ b/docs/source/src/python/user-guide/lazy/query-plan.py @@ -0,0 +1,50 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:plan] +q1 = ( + pl.scan_csv("docs/assets/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:plan] + +# --8<-- [start:createplan] +import base64 + +q1.show_graph( + optimized=False, show=False, output_path="docs/assets/images/query_plan.png" +) +with open("docs/assets/images/query_plan.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:createplan] + +""" +# --8<-- [start:showplan] +q1.show_graph(optimized=False) +# --8<-- [end:showplan] +""" + +# --8<-- [start:describe] +q1.explain(optimized=False) +# --8<-- [end:describe] + +# --8<-- [start:createplan2] +q1.show_graph(show=False, output_path="docs/assets/images/query_plan_optimized.png") +with open("docs/assets/images/query_plan_optimized.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:createplan2] + +""" +# --8<-- [start:show] +q1.show_graph() +# --8<-- [end:show] +""" + +# --8<-- [start:optimized] +q1.explain() +# --8<-- [end:optimized] diff --git a/docs/source/src/python/user-guide/lazy/schema.py b/docs/source/src/python/user-guide/lazy/schema.py new file mode 100644 index 000000000000..b77469434f51 --- /dev/null +++ b/docs/source/src/python/user-guide/lazy/schema.py @@ -0,0 +1,40 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:schema] +lf = pl.LazyFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}) + +print(lf.collect_schema()) +# --8<-- [end:schema] + +# --8<-- [start:lazyround] +lf = pl.LazyFrame({"foo": ["a", "b", "c"]}).with_columns(pl.col("foo").round(2)) +# --8<-- [end:lazyround] + +# --8<-- [start:typecheck] +try: + print(lf.collect()) +except Exception as e: + print(f"{type(e).__name__}: {e}") +# --8<-- [end:typecheck] + +# --8<-- [start:lazyeager] +lazy_eager_query = ( + pl.LazyFrame( + { + "id": ["a", "b", "c"], + "month": ["jan", "feb", "mar"], + "values": [0, 1, 2], + } + ) + .with_columns((2 * pl.col("values")).alias("double_values")) + .collect() + .pivot(index="id", on="month", values="double_values", aggregate_function="first") + .lazy() + .filter(pl.col("mar").is_null()) + .collect() +) +print(lazy_eager_query) +# --8<-- [end:lazyeager] diff --git a/docs/source/src/python/user-guide/lazy/using.py b/docs/source/src/python/user-guide/lazy/using.py new file mode 100644 index 000000000000..0c4a37bd21d1 --- /dev/null +++ b/docs/source/src/python/user-guide/lazy/using.py @@ -0,0 +1,18 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:dataframe] +q1 = ( + pl.scan_csv(f"docs/assets/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:dataframe] +""" + +# --8<-- [start:fromdf] +q3 = pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy() +# --8<-- [end:fromdf] diff --git a/docs/source/src/python/user-guide/misc/arrow.py b/docs/source/src/python/user-guide/misc/arrow.py new file mode 100644 index 000000000000..56e9cafa8137 --- /dev/null +++ b/docs/source/src/python/user-guide/misc/arrow.py @@ -0,0 +1,13 @@ +# --8<-- [start:to_arrow] +import polars as pl + +df = pl.DataFrame({"foo": [1, 2, 3], "bar": ["ham", "spam", "jam"]}) + +arrow_table = df.to_arrow() +print(arrow_table) +# --8<-- [end:to_arrow] + +# --8<-- [start:to_arrow_zero] +arrow_table_zero_copy = df.to_arrow(compat_level=pl.CompatLevel.newest()) +print(arrow_table_zero_copy) +# --8<-- [end:to_arrow_zero] diff --git a/docs/source/src/python/user-guide/misc/arrow_pycapsule.py b/docs/source/src/python/user-guide/misc/arrow_pycapsule.py new file mode 100644 index 000000000000..9ba1c9901307 --- /dev/null +++ b/docs/source/src/python/user-guide/misc/arrow_pycapsule.py @@ -0,0 +1,28 @@ +# --8<-- [start:to_arrow] +import polars as pl +import pyarrow as pa + +df = pl.DataFrame({"foo": [1, 2, 3], "bar": ["ham", "spam", "jam"]}) +arrow_table = pa.table(df) +print(arrow_table) +# --8<-- [end:to_arrow] + +# --8<-- [start:to_polars] +polars_df = pl.DataFrame(arrow_table) +print(polars_df) +# --8<-- [end:to_polars] + +# --8<-- [start:to_arrow_series] +arrow_chunked_array = pa.chunked_array(df["foo"]) +print(arrow_chunked_array) +# --8<-- [end:to_arrow_series] + +# --8<-- [start:to_polars_series] +polars_series = pl.Series(arrow_chunked_array) +print(polars_series) +# --8<-- [end:to_polars_series] + +# --8<-- [start:to_arrow_array_rechunk] +arrow_array = pa.array(df["foo"]) +print(arrow_array) +# --8<-- [end:to_arrow_array_rechunk] diff --git a/docs/source/src/python/user-guide/misc/multiprocess.py b/docs/source/src/python/user-guide/misc/multiprocess.py new file mode 100644 index 000000000000..6876b0553752 --- /dev/null +++ b/docs/source/src/python/user-guide/misc/multiprocess.py @@ -0,0 +1,85 @@ +""" +# --8<-- [start:recommendation] +from multiprocessing import get_context + + +def my_fun(s): + print(s) + + +with get_context("spawn").Pool() as pool: + pool.map(my_fun, ["input1", "input2", ...]) + +# --8<-- [end:recommendation] + +# --8<-- [start:example1] +import multiprocessing +import polars as pl + + +def test_sub_process(df: pl.DataFrame, job_id): + df_filtered = df.filter(pl.col("a") > 0) + print(f"Filtered (job_id: {job_id})", df_filtered, sep="\n") + + +def create_dataset(): + return pl.DataFrame({"a": [0, 2, 3, 4, 5], "b": [0, 4, 5, 56, 4]}) + + +def setup(): + # some setup work + df = create_dataset() + df.write_parquet("/tmp/test.parquet") + + +def main(): + test_df = pl.read_parquet("/tmp/test.parquet") + + for i in range(0, 5): + proc = multiprocessing.get_context("spawn").Process( + target=test_sub_process, args=(test_df, i) + ) + proc.start() + proc.join() + + print(f"Executed sub process {i}") + + +if __name__ == "__main__": + setup() + main() + +# --8<-- [end:example1] +""" + +# --8<-- [start:example2] +import multiprocessing +import polars as pl + + +def test_sub_process(df: pl.DataFrame, job_id): + df_filtered = df.filter(pl.col("a") > 0) + print(f"Filtered (job_id: {job_id})", df_filtered, sep="\n") + + +def create_dataset(): + return pl.DataFrame({"a": [0, 2, 3, 4, 5], "b": [0, 4, 5, 56, 4]}) + + +def main(): + test_df = create_dataset() + + for i in range(0, 5): + proc = multiprocessing.get_context("fork").Process( + target=test_sub_process, args=(test_df, i) + ) + proc.start() + proc.join() + + print(f"Executed sub process {i}") + + +if __name__ == "__main__": + main() + +# --8<-- [end:example2] diff --git a/docs/source/src/python/user-guide/misc/styling.py b/docs/source/src/python/user-guide/misc/styling.py new file mode 100644 index 000000000000..e001de49dd29 --- /dev/null +++ b/docs/source/src/python/user-guide/misc/styling.py @@ -0,0 +1,161 @@ +# --8<-- [start:setup] +import warnings + +# great-tables throws a Deprecation warning in `as_raw_html` under Python 3.13 +# https://github.com/posit-dev/great-tables/pull/563 +warnings.filterwarnings( + "ignore", "'count' is passed as positional argument", category=DeprecationWarning +) +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +import polars as pl +import polars.selectors as cs + +path = "docs/assets/data/iris.csv" + +df = ( + pl.scan_csv(path) + .group_by("species") + .agg(cs.starts_with("petal").mean().round(3)) + .collect() +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:structure-header] +df.style.tab_header(title="Iris Data", subtitle="Mean measurement values per species") +# --8<-- [end:structure-header] + +# --8<-- [start:structure-header-out] +print( + df.style.tab_header( + title="Iris Data", subtitle="Mean measurement values per species" + ).as_raw_html() +) +# --8<-- [end:structure-header-out] + + +# --8<-- [start:structure-stub] +df.style.tab_stub(rowname_col="species") +# --8<-- [end:structure-stub] + +# --8<-- [start:structure-stub-out] +print(df.style.tab_stub(rowname_col="species").as_raw_html()) +# --8<-- [end:structure-stub-out] + +# --8<-- [start:structure-spanner] +( + df.style.tab_spanner("Petal", cs.starts_with("petal")).cols_label( + petal_length="Length", petal_width="Width" + ) +) +# --8<-- [end:structure-spanner] + +# --8<-- [start:structure-spanner-out] +print( + df.style.tab_spanner("Petal", cs.starts_with("petal")) + .cols_label(petal_length="Length", petal_width="Width") + .as_raw_html() +) +# --8<-- [end:structure-spanner-out] + +# --8<-- [start:format-number] +df.style.fmt_number("petal_width", decimals=1) +# --8<-- [end:format-number] + + +# --8<-- [start:format-number-out] +print(df.style.fmt_number("petal_width", decimals=1).as_raw_html()) +# --8<-- [end:format-number-out] + + +# --8<-- [start:style-simple] +from great_tables import loc, style + +df.style.tab_style( + style.fill("yellow"), + loc.body( + rows=pl.col("petal_length") == pl.col("petal_length").max(), + ), +) +# --8<-- [end:style-simple] + +# --8<-- [start:style-simple-out] +from great_tables import loc, style + +print( + df.style.tab_style( + style.fill("yellow"), + loc.body( + rows=pl.col("petal_length") == pl.col("petal_length").max(), + ), + ).as_raw_html() +) +# --8<-- [end:style-simple-out] + + +# --8<-- [start:style-bold-column] +from great_tables import loc, style + +df.style.tab_style( + style.text(weight="bold"), + loc.body(columns="species"), +) +# --8<-- [end:style-bold-column] + +# --8<-- [start:style-bold-column-out] +from great_tables import loc, style + +print( + df.style.tab_style( + style.text(weight="bold"), + loc.body(columns="species"), + ).as_raw_html() +) +# --8<-- [end:style-bold-column-out] + +# --8<-- [start:full-example] +from great_tables import loc, style + +( + df.style.tab_header( + title="Iris Data", subtitle="Mean measurement values per species" + ) + .tab_stub(rowname_col="species") + .cols_label(petal_length="Length", petal_width="Width") + .tab_spanner("Petal", cs.starts_with("petal")) + .fmt_number("petal_width", decimals=2) + .tab_style( + style.fill("yellow"), + loc.body( + rows=pl.col("petal_length") == pl.col("petal_length").max(), + ), + ) +) +# --8<-- [end:full-example] + +# --8<-- [start:full-example-out] +from great_tables import loc, style + +print( + df.style.tab_header( + title="Iris Data", subtitle="Mean measurement values per species" + ) + .tab_stub(rowname_col="species") + .cols_label(petal_length="Length", petal_width="Width") + .tab_spanner("Petal", cs.starts_with("petal")) + .fmt_number("petal_width", decimals=2) + .tab_style( + style.fill("yellow"), + loc.body( + rows=pl.col("petal_length") == pl.col("petal_length").max(), + ), + ) + .tab_style( + style.text(weight="bold"), + loc.body(columns="species"), + ) + .as_raw_html() +) +# --8<-- [end:full-example-out] diff --git a/docs/source/src/python/user-guide/misc/visualization.py b/docs/source/src/python/user-guide/misc/visualization.py new file mode 100644 index 000000000000..939e8edc8a2e --- /dev/null +++ b/docs/source/src/python/user-guide/misc/visualization.py @@ -0,0 +1,220 @@ +# --8<-- [start:dataframe] +import polars as pl + +path = "docs/assets/data/iris.csv" + +df = pl.read_csv(path) +print(df) +# --8<-- [end:dataframe] + +""" +# --8<-- [start:hvplot_show_plot] +import hvplot.polars +df.hvplot.scatter( + x="sepal_width", + y="sepal_length", + by="species", + width=650, + title="Irises", + xlabel='Sepal Width', + ylabel='Sepal Length', +) +# --8<-- [end:hvplot_show_plot] +""" + +# --8<-- [start:hvplot_make_plot] +import hvplot.polars + +plot = df.hvplot.scatter( + x="sepal_width", + y="sepal_length", + by="species", + width=650, + title="Irises", + xlabel="Sepal Width", + ylabel="Sepal Length", +) +hvplot.save(plot, "docs/assets/images/hvplot_scatter.html", resources="cdn") +with open("docs/assets/images/hvplot_scatter.html", "r") as f: + chart_html = f.read() + print(f"{chart_html}") +# --8<-- [end:hvplot_make_plot] + +""" +# --8<-- [start:matplotlib_show_plot] +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +ax.scatter( + x=df["sepal_width"], + y=df["sepal_length"], + c=df["species"].cast(pl.Categorical).to_physical(), +) +ax.set_title('Irises') +ax.set_xlabel('Sepal Width') +ax.set_ylabel('Sepal Length') +# --8<-- [end:matplotlib_show_plot] +""" + +# --8<-- [start:matplotlib_make_plot] +import base64 + +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +ax.scatter( + x=df["sepal_width"], + y=df["sepal_length"], + c=df["species"].cast(pl.Categorical).to_physical(), +) +ax.set_title("Irises") +ax.set_xlabel("Sepal Width") +ax.set_ylabel("Sepal Length") +fig.savefig("docs/assets/images/matplotlib_scatter.png") +with open("docs/assets/images/matplotlib_scatter.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:matplotlib_make_plot] + +""" +# --8<-- [start:plotnine_show_plot] +from plotnine import ggplot, aes, geom_point, labs + +( + ggplot(df, mapping=aes(x="sepal_width", y="sepal_length", color="species")) + + geom_point() + + labs(title="Irises", x="Sepal Width", y="Sepal Length") +) +# --8<-- [end:plotnine_show_plot] +""" + +# --8<-- [start:plotnine_make_plot] +import base64 + +from plotnine import ggplot, aes, geom_point, labs + +fig_path = "docs/assets/images/plotnine.png" + +( + ggplot(df, mapping=aes(x="sepal_width", y="sepal_length", color="species")) + + geom_point() + + labs(title="Irises", x="Sepal Width", y="Sepal Length") +).save(fig_path, dpi=300, verbose=False) + +with open(fig_path, "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:plotnine_make_plot] + +""" +# --8<-- [start:seaborn_show_plot] +import seaborn as sns +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +sns.scatterplot( + df, + x="sepal_width", + y="sepal_length", + hue="species", + ax=ax, +) +ax.set_title('Irises') +ax.set_xlabel('Sepal Width') +ax.set_ylabel('Sepal Length') +# --8<-- [end:seaborn_show_plot] +""" + +# --8<-- [start:seaborn_make_plot] +import seaborn as sns +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +sns.scatterplot( + df, + x="sepal_width", + y="sepal_length", + hue="species", + ax=ax, +) +ax.set_title("Irises") +ax.set_xlabel("Sepal Width") +ax.set_ylabel("Sepal Length") +fig.savefig("docs/assets/images/seaborn_scatter.png") +with open("docs/assets/images/seaborn_scatter.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:seaborn_make_plot] + +""" +# --8<-- [start:plotly_show_plot] +import plotly.express as px + +px.scatter( + df, + x="sepal_width", + y="sepal_length", + color="species", + width=650, + title="Irises", + labels={'sepal_width': 'Sepal Width', 'sepal_length': 'Sepal Length'} +) +# --8<-- [end:plotly_show_plot] +""" + +# --8<-- [start:plotly_make_plot] +import plotly.express as px + +fig = px.scatter( + df, + x="sepal_width", + y="sepal_length", + color="species", + width=650, + title="Irises", + labels={"sepal_width": "Sepal Width", "sepal_length": "Sepal Length"}, +) +fig.write_html( + "docs/assets/images/plotly_scatter.html", full_html=False, include_plotlyjs="cdn" +) +with open("docs/assets/images/plotly_scatter.html", "r") as f: + chart_html = f.read() + print(f"{chart_html}") +# --8<-- [end:plotly_make_plot] + +""" +# --8<-- [start:altair_show_plot] +chart = ( + df.plot.point( + x="sepal_width", + y="sepal_length", + color="species", + ) + .properties(width=500, title="Irises") + .configure_scale(zero=False) + .configure_axisX(tickMinStep=1) +) +chart.encoding.x.title = "Sepal Width" +chart.encoding.y.title = "Sepal Length" +chart +# --8<-- [end:altair_show_plot] +""" + +# --8<-- [start:altair_make_plot] +chart = ( + df.plot.point( + x="sepal_width", + y="sepal_length", + color="species", + ) + .properties(width=500, title="Irises") + .configure_scale(zero=False) + .configure_axisX(tickMinStep=1) +) +chart.encoding.x.title = "Sepal Width" +chart.encoding.y.title = "Sepal Length" +chart.save("docs/assets/images/altair_scatter.html") +with open("docs/assets/images/altair_scatter.html", "r") as f: + chart_html = f.read() + print(f"{chart_html}") +# --8<-- [end:altair_make_plot] diff --git a/docs/source/src/python/user-guide/sql/create.py b/docs/source/src/python/user-guide/sql/create.py new file mode 100644 index 000000000000..dd932ebebc7b --- /dev/null +++ b/docs/source/src/python/user-guide/sql/create.py @@ -0,0 +1,21 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:create] +data = {"name": ["Alice", "Bob", "Charlie", "David"], "age": [25, 30, 35, 40]} +df = pl.LazyFrame(data) + +ctx = pl.SQLContext(my_table=df, eager=True) + +result = ctx.execute( + """ + CREATE TABLE older_people + AS + SELECT * FROM my_table WHERE age > 30 +""" +) + +print(ctx.execute("SELECT * FROM older_people")) +# --8<-- [end:create] diff --git a/docs/source/src/python/user-guide/sql/cte.py b/docs/source/src/python/user-guide/sql/cte.py new file mode 100644 index 000000000000..c44b906cf3ad --- /dev/null +++ b/docs/source/src/python/user-guide/sql/cte.py @@ -0,0 +1,24 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:cte] +ctx = pl.SQLContext() +df = pl.LazyFrame( + {"name": ["Alice", "Bob", "Charlie", "David"], "age": [25, 30, 35, 40]} +) +ctx.register("my_table", df) + +result = ctx.execute( + """ + WITH older_people AS ( + SELECT * FROM my_table WHERE age > 30 + ) + SELECT * FROM older_people WHERE STARTS_WITH(name,'C') +""", + eager=True, +) + +print(result) +# --8<-- [end:cte] diff --git a/docs/source/src/python/user-guide/sql/intro.py b/docs/source/src/python/user-guide/sql/intro.py new file mode 100644 index 000000000000..2a6630c9a8a6 --- /dev/null +++ b/docs/source/src/python/user-guide/sql/intro.py @@ -0,0 +1,100 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:context] +ctx = pl.SQLContext() +# --8<-- [end:context] + +# --8<-- [start:register_context] +df = pl.DataFrame({"a": [1, 2, 3]}) +lf = pl.LazyFrame({"b": [4, 5, 6]}) + +# Register all dataframes in the global namespace: registers both "df" and "lf" +ctx = pl.SQLContext(register_globals=True) + +# Register an explicit mapping of identifier name to frame +ctx = pl.SQLContext(frames={"table_one": df, "table_two": lf}) + +# Register frames using kwargs; dataframe df as "df" and lazyframe lf as "lf" +ctx = pl.SQLContext(df=df, lf=lf) +# --8<-- [end:register_context] + +# --8<-- [start:register_pandas] +import pandas as pd + +df_pandas = pd.DataFrame({"c": [7, 8, 9]}) +ctx = pl.SQLContext(df_pandas=pl.from_pandas(df_pandas)) +# --8<-- [end:register_pandas] + +# --8<-- [start:execute] +# For local files use scan_csv instead +pokemon = pl.read_csv( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv" +) +with pl.SQLContext(register_globals=True, eager=True) as ctx: + df_small = ctx.execute("SELECT * from pokemon LIMIT 5") + print(df_small) +# --8<-- [end:execute] + +# --8<-- [start:prepare_multiple_sources] +with open("docs/assets/data/products_categories.json", "w") as temp_file: + json_data = """{"product_id": 1, "category": "Category 1"} +{"product_id": 2, "category": "Category 1"} +{"product_id": 3, "category": "Category 2"} +{"product_id": 4, "category": "Category 2"} +{"product_id": 5, "category": "Category 3"}""" + + temp_file.write(json_data) + +with open("docs/assets/data/products_masterdata.csv", "w") as temp_file: + csv_data = """product_id,product_name +1,Product A +2,Product B +3,Product C +4,Product D +5,Product E""" + + temp_file.write(csv_data) + +sales_data = pd.DataFrame( + { + "product_id": [1, 2, 3, 4, 5], + "sales": [100, 200, 150, 250, 300], + } +) +# --8<-- [end:prepare_multiple_sources] + +# --8<-- [start:execute_multiple_sources] +# Input data: +# products_masterdata.csv with schema {'product_id': Int64, 'product_name': String} +# products_categories.json with schema {'product_id': Int64, 'category': String} +# sales_data is a Pandas DataFrame with schema {'product_id': Int64, 'sales': Int64} + +with pl.SQLContext( + products_masterdata=pl.scan_csv("docs/assets/data/products_masterdata.csv"), + products_categories=pl.scan_ndjson("docs/assets/data/products_categories.json"), + sales_data=pl.from_pandas(sales_data), + eager=True, +) as ctx: + query = """ + SELECT + product_id, + product_name, + category, + sales + FROM + products_masterdata + LEFT JOIN products_categories USING (product_id) + LEFT JOIN sales_data USING (product_id) + """ + print(ctx.execute(query)) +# --8<-- [end:execute_multiple_sources] + +# --8<-- [start:clean_multiple_sources] +import os + +os.remove("docs/assets/data/products_categories.json") +os.remove("docs/assets/data/products_masterdata.csv") +# --8<-- [end:clean_multiple_sources] diff --git a/docs/source/src/python/user-guide/sql/select.py b/docs/source/src/python/user-guide/sql/select.py new file mode 100644 index 000000000000..15741a33d241 --- /dev/null +++ b/docs/source/src/python/user-guide/sql/select.py @@ -0,0 +1,106 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:df] +df = pl.DataFrame( + { + "city": [ + "New York", + "Los Angeles", + "Chicago", + "Houston", + "Phoenix", + "Amsterdam", + ], + "country": ["USA", "USA", "USA", "USA", "USA", "Netherlands"], + "population": [8399000, 3997000, 2705000, 2320000, 1680000, 900000], + } +) + +ctx = pl.SQLContext(population=df, eager=True) + +print(ctx.execute("SELECT * FROM population")) +# --8<-- [end:df] + +# --8<-- [start:group_by] +result = ctx.execute( + """ + SELECT country, AVG(population) as avg_population + FROM population + GROUP BY country + """ +) +print(result) +# --8<-- [end:group_by] + + +# --8<-- [start:orderby] +result = ctx.execute( + """ + SELECT city, population + FROM population + ORDER BY population + """ +) +print(result) +# --8<-- [end:orderby] + +# --8<-- [start:join] +income = pl.DataFrame( + { + "city": [ + "New York", + "Los Angeles", + "Chicago", + "Houston", + "Amsterdam", + "Rotterdam", + "Utrecht", + ], + "country": [ + "USA", + "USA", + "USA", + "USA", + "Netherlands", + "Netherlands", + "Netherlands", + ], + "income": [55000, 62000, 48000, 52000, 42000, 38000, 41000], + } +) +ctx.register_many(income=income) +result = ctx.execute( + """ + SELECT country, city, income, population + FROM population + LEFT JOIN income on population.city = income.city + """ +) +print(result) +# --8<-- [end:join] + + +# --8<-- [start:functions] +result = ctx.execute( + """ + SELECT city, population + FROM population + WHERE STARTS_WITH(country,'U') + """ +) +print(result) +# --8<-- [end:functions] + +# --8<-- [start:tablefunctions] +result = ctx.execute( + """ + SELECT * + FROM read_csv('docs/assets/data/iris.csv') + """ +) +print(result) +# --8<-- [end:tablefunctions] diff --git a/docs/source/src/python/user-guide/sql/show.py b/docs/source/src/python/user-guide/sql/show.py new file mode 100644 index 000000000000..cedf425dc54b --- /dev/null +++ b/docs/source/src/python/user-guide/sql/show.py @@ -0,0 +1,26 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:show] +# Create some DataFrames and register them with the SQLContext +df1 = pl.LazyFrame( + { + "name": ["Alice", "Bob", "Charlie", "David"], + "age": [25, 30, 35, 40], + } +) +df2 = pl.LazyFrame( + { + "name": ["Ellen", "Frank", "Gina", "Henry"], + "age": [45, 50, 55, 60], + } +) +ctx = pl.SQLContext(mytable1=df1, mytable2=df2) + +tables = ctx.execute("SHOW TABLES", eager=True) + +print(tables) +# --8<-- [end:show] diff --git a/docs/source/src/python/user-guide/transformations/concatenation.py b/docs/source/src/python/user-guide/transformations/concatenation.py new file mode 100644 index 000000000000..60a24aed1930 --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/concatenation.py @@ -0,0 +1,98 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:vertical] +df_v1 = pl.DataFrame( + { + "a": [1], + "b": [3], + } +) +df_v2 = pl.DataFrame( + { + "a": [2], + "b": [4], + } +) +df_vertical_concat = pl.concat( + [ + df_v1, + df_v2, + ], + how="vertical", +) +print(df_vertical_concat) +# --8<-- [end:vertical] + +# --8<-- [start:horizontal] +df_h1 = pl.DataFrame( + { + "l1": [1, 2], + "l2": [3, 4], + } +) +df_h2 = pl.DataFrame( + { + "r1": [5, 6], + "r2": [7, 8], + "r3": [9, 10], + } +) +df_horizontal_concat = pl.concat( + [ + df_h1, + df_h2, + ], + how="horizontal", +) +print(df_horizontal_concat) +# --8<-- [end:horizontal] + +# --8<-- [start:horizontal_different_lengths] +df_h1 = pl.DataFrame( + { + "l1": [1, 2], + "l2": [3, 4], + } +) +df_h2 = pl.DataFrame( + { + "r1": [5, 6, 7], + "r2": [8, 9, 10], + } +) +df_horizontal_concat = pl.concat( + [ + df_h1, + df_h2, + ], + how="horizontal", +) +print(df_horizontal_concat) +# --8<-- [end:horizontal_different_lengths] + +# --8<-- [start:cross] +df_d1 = pl.DataFrame( + { + "a": [1], + "b": [3], + } +) +df_d2 = pl.DataFrame( + { + "a": [2], + "d": [4], + } +) + +df_diagonal_concat = pl.concat( + [ + df_d1, + df_d2, + ], + how="diagonal", +) +print(df_diagonal_concat) +# --8<-- [end:cross] diff --git a/docs/source/src/python/user-guide/transformations/joins.py b/docs/source/src/python/user-guide/transformations/joins.py new file mode 100644 index 000000000000..e44cbdc560c1 --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/joins.py @@ -0,0 +1,185 @@ +# --8<-- [start:prep-data] +import pathlib +import requests + + +DATA = [ + ( + "https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/data/monopoly_props_groups.csv", + "docs/assets/data/monopoly_props_groups.csv", + ), + ( + "https://raw.githubusercontent.com/pola-rs/polars-static/refs/heads/master/data/monopoly_props_prices.csv", + "docs/assets/data/monopoly_props_prices.csv", + ), +] + + +for url, dest in DATA: + if pathlib.Path(dest).exists(): + continue + with open(dest, "wb") as f: + f.write(requests.get(url, timeout=10).content) +# --8<-- [end:prep-data] + +# --8<-- [start:props_groups] +import polars as pl + +props_groups = pl.read_csv("docs/assets/data/monopoly_props_groups.csv").head(5) +print(props_groups) +# --8<-- [end:props_groups] + +# --8<-- [start:props_prices] +props_prices = pl.read_csv("docs/assets/data/monopoly_props_prices.csv").head(5) +print(props_prices) +# --8<-- [end:props_prices] + +# --8<-- [start:equi-join] +result = props_groups.join(props_prices, on="property_name") +print(result) +# --8<-- [end:equi-join] + +# --8<-- [start:props_groups2] +props_groups2 = props_groups.with_columns( + pl.col("property_name").str.to_lowercase(), +) +print(props_groups2) +# --8<-- [end:props_groups2] + +# --8<-- [start:props_prices2] +props_prices2 = props_prices.select( + pl.col("property_name").alias("name"), pl.col("cost") +) +print(props_prices2) +# --8<-- [end:props_prices2] + +# --8<-- [start:join-key-expression] +result = props_groups2.join( + props_prices2, + left_on="property_name", + right_on=pl.col("name").str.to_lowercase(), +) +print(result) +# --8<-- [end:join-key-expression] + +# --8<-- [start:inner-join] +result = props_groups.join(props_prices, on="property_name", how="inner") +print(result) +# --8<-- [end:inner-join] + +# --8<-- [start:left-join] +result = props_groups.join(props_prices, on="property_name", how="left") +print(result) +# --8<-- [end:left-join] + +# --8<-- [start:right-join] +result = props_groups.join(props_prices, on="property_name", how="right") +print(result) +# --8<-- [end:right-join] + +# --8<-- [start:left-right-join-equals] +print( + result.equals( + props_prices.join( + props_groups, + on="property_name", + how="left", + # Reorder the columns to match the order from above. + ).select(pl.col("group"), pl.col("property_name"), pl.col("cost")) + ) +) +# --8<-- [end:left-right-join-equals] + +# --8<-- [start:full-join] +result = props_groups.join(props_prices, on="property_name", how="full") +print(result) +# --8<-- [end:full-join] + +# --8<-- [start:full-join-coalesce] +result = props_groups.join( + props_prices, + on="property_name", + how="full", + coalesce=True, +) +print(result) +# --8<-- [end:full-join-coalesce] + +# --8<-- [start:semi-join] +result = props_groups.join(props_prices, on="property_name", how="semi") +print(result) +# --8<-- [end:semi-join] + +# --8<-- [start:anti-join] +result = props_groups.join(props_prices, on="property_name", how="anti") +print(result) +# --8<-- [end:anti-join] + +# --8<-- [start:players] +players = pl.DataFrame( + { + "name": ["Alice", "Bob"], + "cash": [78, 135], + } +) +print(players) +# --8<-- [end:players] + +# --8<-- [start:non-equi] +result = players.join_where(props_prices, pl.col("cash") > pl.col("cost")) +print(result) +# --8<-- [end:non-equi] + +# --8<-- [start:df_trades] +from datetime import datetime + +df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 3, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + } +) +print(df_trades) +# --8<-- [end:df_trades] + +# --8<-- [start:df_quotes] +df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 2, 0), + datetime(2020, 1, 1, 9, 4, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "C", "A"], + "quote": [100, 300, 501, 102], + } +) + +print(df_quotes) +# --8<-- [end:df_quotes] + +# --8<-- [start:asof] +df_asof_join = df_trades.join_asof(df_quotes, on="time", by="stock") +print(df_asof_join) +# --8<-- [end:asof] + +# --8<-- [start:asof-tolerance] +df_asof_tolerance_join = df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="1m" +) +print(df_asof_tolerance_join) +# --8<-- [end:asof-tolerance] + +# --8<-- [start:cartesian-product] +tokens = pl.DataFrame({"monopoly_token": ["hat", "shoe", "boat"]}) + +result = players.select(pl.col("name")).join(tokens, how="cross") +print(result) +# --8<-- [end:cartesian-product] diff --git a/docs/source/src/python/user-guide/transformations/pivot.py b/docs/source/src/python/user-guide/transformations/pivot.py new file mode 100644 index 000000000000..95354072d24e --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/pivot.py @@ -0,0 +1,31 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "n", "o"], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:eager] +out = df.pivot("bar", index="foo", values="N", aggregate_function="first") +print(out) +# --8<-- [end:eager] + +# --8<-- [start:lazy] +q = ( + df.lazy() + .collect() + .pivot(index="foo", on="bar", values="N", aggregate_function="first") + .lazy() +) +out = q.collect() +print(out) +# --8<-- [end:lazy] diff --git a/docs/source/src/python/user-guide/transformations/time-series/filter.py b/docs/source/src/python/user-guide/transformations/time-series/filter.py new file mode 100644 index 000000000000..f3344784373d --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/time-series/filter.py @@ -0,0 +1,30 @@ +# --8<-- [start:df] +import polars as pl +from datetime import datetime + +df = pl.read_csv("docs/assets/data/apple_stock.csv", try_parse_dates=True) +print(df) +# --8<-- [end:df] + +# --8<-- [start:filter] +filtered_df = df.filter( + pl.col("Date") == datetime(1995, 10, 16), +) +print(filtered_df) +# --8<-- [end:filter] + +# --8<-- [start:range] +filtered_range_df = df.filter( + pl.col("Date").is_between(datetime(1995, 7, 1), datetime(1995, 11, 1)), +) +print(filtered_range_df) +# --8<-- [end:range] + +# --8<-- [start:negative] +ts = pl.Series(["-1300-05-23", "-1400-03-02"]).str.to_date() + +negative_dates_df = pl.DataFrame({"ts": ts, "values": [3, 4]}) + +negative_dates_filtered_df = negative_dates_df.filter(pl.col("ts").dt.year() < -1300) +print(negative_dates_filtered_df) +# --8<-- [end:negative] diff --git a/docs/source/src/python/user-guide/transformations/time-series/parsing.py b/docs/source/src/python/user-guide/transformations/time-series/parsing.py new file mode 100644 index 000000000000..2a9e6335f232 --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/time-series/parsing.py @@ -0,0 +1,43 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.read_csv("docs/assets/data/apple_stock.csv", try_parse_dates=True) +print(df) +# --8<-- [end:df] + + +# --8<-- [start:cast] +df = pl.read_csv("docs/assets/data/apple_stock.csv", try_parse_dates=False) + +df = df.with_columns(pl.col("Date").str.to_date("%Y-%m-%d")) +print(df) +# --8<-- [end:cast] + + +# --8<-- [start:df3] +df_with_year = df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:df3] + +# --8<-- [start:extract] +df_with_year = df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:extract] + +# --8<-- [start:mixed] +data = [ + "2021-03-27T00:00:00+0100", + "2021-03-28T00:00:00+0100", + "2021-03-29T00:00:00+0200", + "2021-03-30T00:00:00+0200", +] +mixed_parsed = ( + pl.Series(data) + .str.to_datetime("%Y-%m-%dT%H:%M:%S%z") + .dt.convert_time_zone("Europe/Brussels") +) +print(mixed_parsed) +# --8<-- [end:mixed] diff --git a/docs/source/src/python/user-guide/transformations/time-series/resampling.py b/docs/source/src/python/user-guide/transformations/time-series/resampling.py new file mode 100644 index 000000000000..80a7b2597a67 --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/time-series/resampling.py @@ -0,0 +1,36 @@ +# --8<-- [start:setup] +from datetime import datetime + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + "values": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:upsample] +out1 = df.upsample(time_column="time", every="15m").fill_null(strategy="forward") +print(out1) +# --8<-- [end:upsample] + +# --8<-- [start:upsample2] +out2 = ( + df.upsample(time_column="time", every="15m") + .interpolate() + .fill_null(strategy="forward") +) +print(out2) +# --8<-- [end:upsample2] diff --git a/docs/source/src/python/user-guide/transformations/time-series/rolling.py b/docs/source/src/python/user-guide/transformations/time-series/rolling.py new file mode 100644 index 000000000000..731a3b4a61db --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/time-series/rolling.py @@ -0,0 +1,66 @@ +# --8<-- [start:setup] +from datetime import date, datetime + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.read_csv("docs/assets/data/apple_stock.csv", try_parse_dates=True) +df = df.sort("Date") +print(df) +# --8<-- [end:df] + +# --8<-- [start:group_by] +annual_average_df = df.group_by_dynamic("Date", every="1y").agg(pl.col("Close").mean()) + +df_with_year = annual_average_df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:group_by] + +# --8<-- [start:group_by_dyn] +df = ( + pl.date_range( + start=date(2021, 1, 1), + end=date(2021, 12, 31), + interval="1d", + eager=True, + ) + .alias("time") + .to_frame() +) + +out = df.group_by_dynamic("time", every="1mo", period="1mo", closed="left").agg( + pl.col("time").cum_count().reverse().head(3).alias("day/eom"), + ((pl.col("time") - pl.col("time").first()).last().dt.total_days() + 1).alias( + "days_in_month" + ), +) +print(out) +# --8<-- [end:group_by_dyn] + +# --8<-- [start:group_by_roll] +df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + } +) +print(df) +# --8<-- [end:group_by_roll] + +# --8<-- [start:group_by_dyn2] +out = df.group_by_dynamic( + "time", + every="1h", + closed="both", + group_by="groups", + include_boundaries=True, +).agg(pl.len()) +print(out) +# --8<-- [end:group_by_dyn2] diff --git a/docs/source/src/python/user-guide/transformations/time-series/timezones.py b/docs/source/src/python/user-guide/transformations/time-series/timezones.py new file mode 100644 index 000000000000..0f5470b08e30 --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/time-series/timezones.py @@ -0,0 +1,27 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:example] +ts = ["2021-03-27 03:00", "2021-03-28 03:00"] +tz_naive = pl.Series("tz_naive", ts).str.to_datetime() +tz_aware = tz_naive.dt.replace_time_zone("UTC").rename("tz_aware") +time_zones_df = pl.DataFrame([tz_naive, tz_aware]) +print(time_zones_df) +# --8<-- [end:example] + +# --8<-- [start:example2] +time_zones_operations = time_zones_df.select( + [ + pl.col("tz_aware") + .dt.replace_time_zone("Europe/Brussels") + .alias("replace time zone"), + pl.col("tz_aware") + .dt.convert_time_zone("Asia/Kathmandu") + .alias("convert time zone"), + pl.col("tz_aware").dt.replace_time_zone(None).alias("unset time zone"), + ] +) +print(time_zones_operations) +# --8<-- [end:example2] diff --git a/docs/source/src/python/user-guide/transformations/unpivot.py b/docs/source/src/python/user-guide/transformations/unpivot.py new file mode 100644 index 000000000000..3e79279eec17 --- /dev/null +++ b/docs/source/src/python/user-guide/transformations/unpivot.py @@ -0,0 +1,18 @@ +# --8<-- [start:df] +import polars as pl + +df = pl.DataFrame( + { + "A": ["a", "b", "a"], + "B": [1, 3, 5], + "C": [10, 11, 12], + "D": [2, 4, 6], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:unpivot] +out = df.unpivot(["C", "D"], index=["A", "B"]) +print(out) +# --8<-- [end:unpivot] diff --git a/docs/source/src/rust/Cargo.toml b/docs/source/src/rust/Cargo.toml new file mode 100644 index 000000000000..1a08751ad4ad --- /dev/null +++ b/docs/source/src/rust/Cargo.toml @@ -0,0 +1,157 @@ +[package] +name = "polars-doc-examples" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Code examples included in the Polars documentation website" + +[dependencies] +aws-config = { version = "1" } +aws-sdk-s3 = { version = "1" } +aws-smithy-checksums = { version = "0.60.10" } +chrono = { workspace = true } +rand = { workspace = true } +reqwest = { workspace = true, features = ["blocking", "default-tls"] } +tokio = { workspace = true } + +[dependencies.polars] +workspace = true + +[[bin]] +name = "home" +path = "home/example.rs" +required-features = ["polars/lazy", "polars/csv"] + +[[bin]] +name = "getting-started" +path = "user-guide/getting-started.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/round_series", "polars/strings", "polars/is_between"] + +[[bin]] +name = "concepts-data-types-and-structures" +path = "user-guide/concepts/data-types-and-structures.rs" +required-features = ["polars/lazy", "polars/temporal"] + +[[bin]] +name = "concepts-expressions" +path = "user-guide/concepts/expressions.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/is_between"] +[[bin]] +name = "concepts-lazy-vs-eager" +path = "user-guide/concepts/lazy-vs-eager.rs" +required-features = ["polars/lazy", "polars/csv"] +[[bin]] +name = "concepts-streaming" +path = "user-guide/concepts/streaming.rs" +required-features = ["polars/lazy", "polars/csv", "polars/streaming"] + +[[bin]] +name = "expressions-aggregation" +path = "user-guide/expressions/aggregation.rs" +required-features = ["polars/lazy", "polars/csv", "polars/temporal", "polars/dtype-categorical"] +[[bin]] +name = "expressions-casting" +path = "user-guide/expressions/casting.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/strings", "polars/dtype-u8"] +[[bin]] +name = "expressions-column-selections" +path = "user-guide/expressions/column-selections.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/regex"] +[[bin]] +name = "expressions-folds" +path = "user-guide/expressions/folds.rs" +required-features = ["polars/lazy", "polars/strings", "polars/concat_str", "polars/temporal"] +[[bin]] +name = "expressions-expression-expansion" +path = "user-guide/expressions/expression-expansion.rs" +required-features = ["polars/lazy", "polars/round_series", "polars/regex"] +[[bin]] +name = "expressions-lists" +path = "user-guide/expressions/lists.rs" +required-features = ["polars/lazy"] +[[bin]] +name = "expressions-missing-data" +path = "user-guide/expressions/missing-data.rs" +required-features = ["polars/lazy", "polars/interpolate"] +[[bin]] +name = "expressions-operations" +path = "user-guide/expressions/operations.rs" +required-features = ["polars/lazy", "polars/approx_unique", "polars/dtype-struct", "polars/unique_counts"] +[[bin]] +name = "expressions-strings" +path = "user-guide/expressions/strings.rs" +required-features = ["polars/lazy", "polars/strings", "polars/regex"] +[[bin]] +name = "expressions-structs" +path = "user-guide/expressions/structs.rs" +required-features = ["polars/lazy", "polars/dtype-struct", "polars/rank", "polars/strings", "polars/temporal"] +[[bin]] +name = "expressions-window" +path = "user-guide/expressions/window.rs" +required-features = ["polars/lazy", "polars/csv", "polars/rank"] + +[[bin]] +name = "io-cloud-storage" +path = "user-guide/io/cloud-storage.rs" +required-features = ["polars/csv"] +[[bin]] +name = "io-csv" +path = "user-guide/io/csv.rs" +required-features = ["polars/lazy", "polars/csv"] +[[bin]] +name = "io-json" +path = "user-guide/io/json.rs" +required-features = ["polars/lazy", "polars/json"] +[[bin]] +name = "io-parquet" +path = "user-guide/io/parquet.rs" +required-features = ["polars/lazy", "polars/parquet"] + +[[bin]] +name = "transformations-concatenation" +path = "user-guide/transformations/concatenation.rs" +required-features = ["polars/lazy", "polars/diagonal_concat"] +[[bin]] +name = "transformations-joins" +path = "user-guide/transformations/joins.rs" +required-features = [ + "polars/lazy", + "polars/strings", + "polars/semi_anti_join", + "polars/iejoin", + "polars/cross_join", + "polars/temporal", + "polars/asof_join", +] +[[bin]] +name = "transformations-unpivot" +path = "user-guide/transformations/unpivot.rs" +required-features = ["polars/pivot"] +[[bin]] +name = "transformations-pivot" +path = "user-guide/transformations/pivot.rs" +required-features = ["polars/lazy", "polars/pivot"] + +[[bin]] +name = "transformations-time-series-filter" +path = "user-guide/transformations/time-series/filter.rs" +required-features = ["polars/lazy", "polars/strings", "polars/temporal"] +[[bin]] +name = "transformations-time-series-parsing" +path = "user-guide/transformations/time-series/parsing.rs" +required-features = ["polars/lazy", "polars/strings", "polars/temporal", "polars/timezones"] +[[bin]] +name = "transformations-time-series-resampling" +path = "user-guide/transformations/time-series/resampling.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/interpolate"] +[[bin]] +name = "transformations-time-series-rolling" +path = "user-guide/transformations/time-series/rolling.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/dynamic_group_by", "polars/cum_agg"] +[[bin]] +name = "transformations-time-series-timezones" +path = "user-guide/transformations/time-series/timezones.rs" +required-features = ["polars/lazy", "polars/temporal", "polars/timezones", "polars/strings"] diff --git a/docs/source/src/rust/home/example.rs b/docs/source/src/rust/home/example.rs new file mode 100644 index 000000000000..17f1c0430376 --- /dev/null +++ b/docs/source/src/rust/home/example.rs @@ -0,0 +1,18 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:example] + use polars::prelude::*; + + let q = LazyCsvReader::new("docs/assets/data/iris.csv") + .with_has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("*").sum()]); + + let df = q.collect()?; + // --8<-- [end:example] + + println!("{}", df); + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/concepts/data-types-and-structures.rs b/docs/source/src/rust/user-guide/concepts/data-types-and-structures.rs new file mode 100644 index 000000000000..0b0b92a4f972 --- /dev/null +++ b/docs/source/src/rust/user-guide/concepts/data-types-and-structures.rs @@ -0,0 +1,62 @@ +fn main() { + // --8<-- [start:series] + use polars::prelude::*; + + let s = Series::new("ints".into(), &[1, 2, 3, 4, 5]); + + println!("{}", s); + // --8<-- [end:series] + + // --8<-- [start:series-dtype] + let s1 = Series::new("ints".into(), &[1, 2, 3, 4, 5]); + let s2 = Series::new("uints".into(), &[1, 2, 3, 4, 5]) + .cast(&DataType::UInt64) // Here, we actually cast after inference. + .unwrap(); + println!("{} {}", s1.dtype(), s2.dtype()); // i32 u64 + // --8<-- [end:series-dtype] + + // --8<-- [start:df] + use chrono::prelude::*; + + let df: DataFrame = df!( + "name" => ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate" => [ + NaiveDate::from_ymd_opt(1997, 1, 10).unwrap(), + NaiveDate::from_ymd_opt(1985, 2, 15).unwrap(), + NaiveDate::from_ymd_opt(1983, 3, 22).unwrap(), + NaiveDate::from_ymd_opt(1981, 4, 30).unwrap(), + ], + "weight" => [57.9, 72.5, 53.6, 83.1], // (kg) + "height" => [1.56, 1.77, 1.65, 1.75], // (m) + ) + .unwrap(); + println!("{}", df); + // --8<-- [end:df] + + // --8<-- [start:schema] + println!("{:?}", df.schema()); + // --8<-- [end:schema] + + // --8<-- [start:head] + let df_head = df.head(Some(3)); + + println!("{}", df_head); + // --8<-- [end:head] + + // --8<-- [start:tail] + let df_tail = df.tail(Some(3)); + + println!("{}", df_tail); + // --8<-- [end:tail] + + // --8<-- [start:sample] + let n = Series::new("".into(), &[2]); + let sampled_df = df.sample_n(&n, false, false, None).unwrap(); + + println!("{}", sampled_df); + // --8<-- [end:sample] + + // --8<-- [start:describe] + // Not available in Rust + // --8<-- [end:describe] +} diff --git a/docs/source/src/rust/user-guide/concepts/expressions.rs b/docs/source/src/rust/user-guide/concepts/expressions.rs new file mode 100644 index 000000000000..c74e9847e3fb --- /dev/null +++ b/docs/source/src/rust/user-guide/concepts/expressions.rs @@ -0,0 +1,135 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + use chrono::prelude::*; + use polars::prelude::*; + + let df: DataFrame = df!( + "name" => ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate" => [ + NaiveDate::from_ymd_opt(1997, 1, 10).unwrap(), + NaiveDate::from_ymd_opt(1985, 2, 15).unwrap(), + NaiveDate::from_ymd_opt(1983, 3, 22).unwrap(), + NaiveDate::from_ymd_opt(1981, 4, 30).unwrap(), + ], + "weight" => [57.9, 72.5, 53.6, 83.1], // (kg) + "height" => [1.56, 1.77, 1.65, 1.75], // (m) + ) + .unwrap(); + println!("{}", df); + // --8<-- [end:df] + + // --8<-- [start:select-1] + let bmi = col("weight") / col("height").pow(2); + let result = df + .clone() + .lazy() + .select([ + bmi.clone().alias("bmi"), + bmi.clone().mean().alias("avg_bmi"), + lit(25).alias("ideal_max_bmi"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:select-1] + + // --8<-- [start:select-2] + let result = df + .clone() + .lazy() + .select([((bmi.clone() - bmi.clone().mean()) / bmi.clone().std(1)).alias("deviation")]) + .collect()?; + println!("{}", result); + // --8<-- [end:select-2] + + // --8<-- [start:with_columns-1] + let result = df + .clone() + .lazy() + .with_columns([ + bmi.clone().alias("bmi"), + bmi.clone().mean().alias("avg_bmi"), + lit(25).alias("ideal_max_bmi"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:with_columns-1] + + // --8<-- [start:filter-1] + let result = df + .clone() + .lazy() + .filter( + col("birthdate") + .is_between( + lit(NaiveDate::from_ymd_opt(1982, 12, 31).unwrap()), + lit(NaiveDate::from_ymd_opt(1996, 1, 1).unwrap()), + ClosedInterval::Both, + ) + .and(col("height").gt(lit(1.7))), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:filter-1] + + // --8<-- [start:group_by-1] + let result = df + .clone() + .lazy() + .group_by([(col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade")]) + .agg([col("name")]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-1] + + // --8<-- [start:group_by-2] + let result = df + .clone() + .lazy() + .group_by([ + (col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade"), + (col("height").lt(lit(1.7)).alias("short?")), + ]) + .agg([col("name")]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-2] + + // --8<-- [start:group_by-3] + let result = df + .clone() + .lazy() + .group_by([ + (col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade"), + (col("height").lt(lit(1.7)).alias("short?")), + ]) + .agg([ + len(), + col("height").max().alias("tallest"), + cols(["weight", "height"]).mean().name().prefix("avg_"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-3] + + // --8<-- [start:expression-expansion-1] + let expr = (dtype_col(&DataType::Float64) * lit(1.1)) + .name() + .suffix("*1.1"); + let result = df.clone().lazy().select([expr.clone()]).collect()?; + println!("{}", result); + // --8<-- [end:expression-expansion-1] + + // --8<-- [start:expression-expansion-2] + let df2: DataFrame = df!( + "ints" => [1, 2, 3, 4], + "letters" => ["A", "B", "C", "D"], + ) + .unwrap(); + let result = df2.clone().lazy().select([expr.clone()]).collect()?; + println!("{}", result); + // --8<-- [end:expression-expansion-2] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs b/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs new file mode 100644 index 000000000000..89f3a6610748 --- /dev/null +++ b/docs/source/src/rust/user-guide/concepts/lazy-vs-eager.rs @@ -0,0 +1,42 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:eager] + let df = CsvReadOptions::default() + .try_into_reader_with_file_path(Some("docs/assets/data/iris.csv".into())) + .unwrap() + .finish() + .unwrap(); + let mask = df.column("sepal_length")?.f64()?.gt(5.0); + let df_small = df.filter(&mask)?; + #[allow(deprecated)] + let df_agg = df_small + .group_by(["species"])? + .select(["sepal_width"]) + .mean()?; + println!("{}", df_agg); + // --8<-- [end:eager] + + // --8<-- [start:lazy] + let q = LazyCsvReader::new("docs/assets/data/iris.csv") + .with_has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + let df = q.collect()?; + println!("{}", df); + // --8<-- [end:lazy] + + // --8<-- [start:explain] + let q = LazyCsvReader::new("docs/assets/data/iris.csv") + .with_has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + println!("{}", q.explain(true)?); + // --8<-- [end:explain] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/concepts/streaming.rs b/docs/source/src/rust/user-guide/concepts/streaming.rs new file mode 100644 index 000000000000..6b4182dd67ff --- /dev/null +++ b/docs/source/src/rust/user-guide/concepts/streaming.rs @@ -0,0 +1,36 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:streaming] + let q1 = LazyCsvReader::new("docs/assets/data/iris.csv") + .with_has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + + let df = q1.clone().with_streaming(true).collect()?; + println!("{}", df); + // --8<-- [end:streaming] + + // --8<-- [start:example] + let query_plan = q1.with_streaming(true).explain(true)?; + println!("{}", query_plan); + // --8<-- [end:example] + + // --8<-- [start:example2] + let q2 = LazyCsvReader::new("docs/assets/data/iris.csv") + .finish()? + .with_columns(vec![ + col("sepal_length") + .mean() + .over(vec![col("species")]) + .alias("sepal_length_mean"), + ]); + + let query_plan = q2.with_streaming(true).explain(true)?; + println!("{}", query_plan); + // --8<-- [end:example2] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/aggregation.rs b/docs/source/src/rust/user-guide/expressions/aggregation.rs new file mode 100644 index 000000000000..eb63627277b4 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/aggregation.rs @@ -0,0 +1,242 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use std::io::Cursor; + + use polars::prelude::*; + use reqwest::blocking::Client; + + let url = "https://huggingface.co/datasets/nameexhaustion/polars-docs/resolve/main/legislators-historical.csv"; + + let mut schema = Schema::default(); + schema.with_column( + "first_name".into(), + DataType::Categorical(None, Default::default()), + ); + schema.with_column( + "gender".into(), + DataType::Categorical(None, Default::default()), + ); + schema.with_column( + "type".into(), + DataType::Categorical(None, Default::default()), + ); + schema.with_column( + "state".into(), + DataType::Categorical(None, Default::default()), + ); + schema.with_column( + "party".into(), + DataType::Categorical(None, Default::default()), + ); + schema.with_column("birthday".into(), DataType::Date); + + let data = Client::new().get(url).send()?.bytes()?; + + let dataset = CsvReadOptions::default() + .with_has_header(true) + .with_schema_overwrite(Some(Arc::new(schema))) + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .into_reader_with_file_handle(Cursor::new(data)) + .finish()? + .lazy() + .with_columns([ + col("first").name().suffix("_name"), + col("middle").name().suffix("_name"), + col("last").name().suffix("_name"), + ]) + .collect()?; + + println!("{}", &dataset); + // --8<-- [end:dataframe] + + // --8<-- [start:basic] + let df = dataset + .clone() + .lazy() + .group_by(["first_name"]) + .agg([len(), col("gender"), col("last_name").first()]) + .sort( + ["len"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:basic] + + // --8<-- [start:conditional] + let df = dataset + .clone() + .lazy() + .group_by(["state"]) + .agg([ + (col("party").eq(lit("Anti-Administration"))) + .sum() + .alias("anti"), + (col("party").eq(lit("Pro-Administration"))) + .sum() + .alias("pro"), + ]) + .sort( + ["pro"], + SortMultipleOptions::default().with_order_descending(true), + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:conditional] + + // --8<-- [start:nested] + let df = dataset + .clone() + .lazy() + .group_by(["state", "party"]) + .agg([len().alias("count")]) + .filter( + col("party") + .eq(lit("Anti-Administration")) + .or(col("party").eq(lit("Pro-Administration"))), + ) + .sort( + ["count"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:nested] + + // --8<-- [start:filter] + fn compute_age() -> Expr { + lit(2024) - col("birthday").dt().year() + } + + fn avg_birthday(gender: &str) -> Expr { + compute_age() + .filter(col("gender").eq(lit(gender))) + .mean() + .alias(format!("avg {} birthday", gender)) + } + + let df = dataset + .clone() + .lazy() + .group_by(["state"]) + .agg([ + avg_birthday("M"), + avg_birthday("F"), + (col("gender").eq(lit("M"))).sum().alias("# male"), + (col("gender").eq(lit("F"))).sum().alias("# female"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:filter] + + // --8<-- [start:filter-nested] + let df = dataset + .clone() + .lazy() + .group_by(["state", "gender"]) + .agg([compute_age().mean().alias("avg birthday"), len().alias("#")]) + .sort( + ["#"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:filter-nested] + + // --8<-- [start:sort] + fn get_name() -> Expr { + col("first_name") + lit(" ") + col("last_name") + } + + let df = dataset + .clone() + .lazy() + .sort( + ["birthday"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), + ) + .group_by(["state"]) + .agg([ + get_name().first().alias("youngest"), + get_name().last().alias("oldest"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort] + + // --8<-- [start:sort2] + let df = dataset + .clone() + .lazy() + .sort( + ["birthday"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), + ) + .group_by(["state"]) + .agg([ + get_name().first().alias("youngest"), + get_name().last().alias("oldest"), + get_name() + .sort(Default::default()) + .first() + .alias("alphabetical_first"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort2] + + // --8<-- [start:sort3] + let df = dataset + .clone() + .lazy() + .sort( + ["birthday"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), + ) + .group_by(["state"]) + .agg([ + get_name().first().alias("youngest"), + get_name().last().alias("oldest"), + get_name() + .sort(Default::default()) + .first() + .alias("alphabetical_first"), + col("gender") + .sort_by(["first_name"], SortMultipleOptions::default()) + .first(), + ]) + .sort(["state"], SortMultipleOptions::default()) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort3] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/casting.rs b/docs/source/src/rust/user-guide/expressions/casting.rs new file mode 100644 index 000000000000..6e0cb2b9576a --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/casting.rs @@ -0,0 +1,180 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:dfnum] + use polars::prelude::*; + + let df = df! ( + "integers"=> [1, 2, 3], + "big_integers"=> [10000002, 2, 30000003], + "floats"=> [4.0, 5.8, -6.3], + )?; + + println!("{}", df); + // --8<-- [end:dfnum] + + // --8<-- [start:castnum] + let result = df + .clone() + .lazy() + .select([ + col("integers") + .cast(DataType::Float32) + .alias("integers_as_floats"), + col("floats") + .cast(DataType::Int32) + .alias("floats_as_integers"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:castnum] + + // --8<-- [start:downcast] + println!("Before downcasting: {} bytes", df.estimated_size()); + let result = df + .clone() + .lazy() + .with_columns([ + col("integers").cast(DataType::Int16), + col("floats").cast(DataType::Float32), + ]) + .collect()?; + println!("After downcasting: {} bytes", result.estimated_size()); + // --8<-- [end:downcast] + + // --8<-- [start:overflow] + let result = df + .clone() + .lazy() + .select([col("big_integers").strict_cast(DataType::Int8)]) + .collect(); + if let Err(e) = result { + println!("{}", e) + }; + // --8<-- [end:overflow] + + // --8<-- [start:overflow2] + let result = df + .clone() + .lazy() + .select([col("big_integers").cast(DataType::Int8)]) + .collect()?; + println!("{}", result); + // --8<-- [end:overflow2] + + // --8<-- [start:strings] + let df = df! ( + "integers_as_strings" => ["1", "2", "3"], + "floats_as_strings" => ["4.0", "5.8", "-6.3"], + "floats" => [4.0, 5.8, -6.3], + )?; + + let result = df + .clone() + .lazy() + .select([ + col("integers_as_strings").cast(DataType::Int32), + col("floats_as_strings").cast(DataType::Float64), + col("floats").cast(DataType::String), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:strings] + + // --8<-- [start:strings2] + let df = df! ("floats" => ["4.0", "5.8", "- 6 . 3"])?; + + let result = df + .clone() + .lazy() + .select([col("floats").strict_cast(DataType::Float64)]) + .collect(); + if let Err(e) = result { + println!("{}", e) + }; + // --8<-- [end:strings2] + + // --8<-- [start:bool] + let df = df! ( + "integers"=> [-1, 0, 2, 3, 4], + "floats"=> [0.0, 1.0, 2.0, 3.0, 4.0], + "bools"=> [true, false, true, false, true], + )?; + + let result = df + .clone() + .lazy() + .select([ + col("integers").cast(DataType::Boolean), + col("floats").cast(DataType::Boolean), + col("bools").cast(DataType::UInt8), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:bool] + + // --8<-- [start:dates] + use chrono::prelude::*; + + let df = df!( + "date" => [ + NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(), // epoch + NaiveDate::from_ymd_opt(1970, 1, 10).unwrap(), // 9 days later + ], + "datetime" => [ + NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), // epoch + NaiveDate::from_ymd_opt(1970, 1, 1).unwrap().and_hms_opt(0, 1, 0).unwrap(), // 1 minute later + ], + "time" => [ + NaiveTime::from_hms_opt(0, 0, 0).unwrap(), // reference time + NaiveTime::from_hms_opt(0, 0, 1).unwrap(), // 1 second later + ] + ) + .unwrap() + .lazy() + // Make the time unit match that of Python's for the same results. + .with_column(col("datetime").cast(DataType::Datetime(TimeUnit::Microseconds, None))) + .collect()?; + + let result = df + .clone() + .lazy() + .select([ + col("date").cast(DataType::Int64).alias("days_since_epoch"), + col("datetime") + .cast(DataType::Int64) + .alias("us_since_epoch"), + col("time").cast(DataType::Int64).alias("ns_since_midnight"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:dates] + + // --8<-- [start:dates2] + let df = df! ( + "date" => [ + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 2).unwrap(), + ], + "string" => [ + "2022-01-01", + "2022-01-02", + ], + )?; + + let result = df + .clone() + .lazy() + .select([ + col("date").dt().to_string("%Y-%m-%d"), + col("string").str().to_datetime( + Some(TimeUnit::Microseconds), + None, + StrptimeOptions::default(), + lit("raise"), + ), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:dates2] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/column-selections.rs b/docs/source/src/rust/user-guide/expressions/column-selections.rs new file mode 100644 index 000000000000..0dff5ab38c62 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/column-selections.rs @@ -0,0 +1,100 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:selectors_df] + + use chrono::prelude::*; + use polars::time::*; + + let df = df!( + "id" => &[9, 4, 2], + "place" => &["Mars", "Earth", "Saturn"], + "date" => date_range("date".into(), + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), Duration::parse("1d"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + "sales" => &[33.4, 2142134.1, 44.7], + "has_people" => &[false, true, false], + "logged_at" => date_range("logged_at".into(), + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 2).unwrap(), Duration::parse("1s"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + )? + .with_row_index("index".into(), None)?; + println!("{}", &df); + // --8<-- [end:selectors_df] + + // --8<-- [start:all] + let out = df.clone().lazy().select([col("*")]).collect()?; + println!("{}", &out); + + // Is equivalent to + let out = df.clone().lazy().select([all()]).collect()?; + println!("{}", &out); + // --8<-- [end:all] + + // --8<-- [start:exclude] + let out = df + .clone() + .lazy() + .select([col("*").exclude(["logged_at", "index"])]) + .collect()?; + println!("{}", &out); + // --8<-- [end:exclude] + + // --8<-- [start:expansion_by_names] + let out = df + .clone() + .lazy() + .select([cols(["date", "logged_at"]).dt().to_string("%Y-%h-%d")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:expansion_by_names] + + // --8<-- [start:expansion_by_regex] + let out = df.clone().lazy().select([col("^.*(as|sa).*$")]).collect()?; + println!("{}", &out); + // --8<-- [end:expansion_by_regex] + + // --8<-- [start:expansion_by_dtype] + let out = df + .clone() + .lazy() + .select([dtype_cols([DataType::Int64, DataType::UInt32, DataType::Boolean]).n_unique()]) + .collect()?; + // gives different result than python as the id col is i32 in rust + println!("{}", &out); + // --8<-- [end:expansion_by_dtype] + + // --8<-- [start:selectors_intro] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_intro] + + // --8<-- [start:selectors_diff] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_diff] + + // --8<-- [start:selectors_union] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_union] + + // --8<-- [start:selectors_by_name] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_by_name] + + // --8<-- [start:selectors_to_expr] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_to_expr] + + // --8<-- [start:selectors_is_selector_utility] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_is_selector_utility] + + // --8<-- [start:selectors_colnames_utility] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_colnames_utility] + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/expression-expansion.rs b/docs/source/src/rust/user-guide/expressions/expression-expansion.rs new file mode 100644 index 000000000000..a1752a7d0334 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/expression-expansion.rs @@ -0,0 +1,215 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:df] + use polars::prelude::*; + + // Data as of 14th October 2024, ~3pm UTC + let df = df!( + "ticker" => ["AAPL", "NVDA", "MSFT", "GOOG", "AMZN"], + "company_name" => ["Apple", "NVIDIA", "Microsoft", "Alphabet (Google)", "Amazon"], + "price" => [229.9, 138.93, 420.56, 166.41, 188.4], + "day_high" => [231.31, 139.6, 424.04, 167.62, 189.83], + "day_low" => [228.6, 136.3, 417.52, 164.78, 188.44], + "year_high" => [237.23, 140.76, 468.35, 193.31, 201.2], + "year_low" => [164.08, 39.23, 324.39, 121.46, 118.35], + )?; + + println!("{}", df); + // --8<-- [end:df] + + // --8<-- [start:col-with-names] + let eur_usd_rate = 1.09; // As of 14th October 2024 + + let result = df + .clone() + .lazy() + .with_column( + (cols(["price", "day_high", "day_low", "year_high", "year_low"]) / lit(eur_usd_rate)) + .round(2), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:col-with-names] + + // --8<-- [start:expression-list] + let exprs = [ + (col("price") / lit(eur_usd_rate)).round(2), + (col("day_high") / lit(eur_usd_rate)).round(2), + (col("day_low") / lit(eur_usd_rate)).round(2), + (col("year_high") / lit(eur_usd_rate)).round(2), + (col("year_low") / lit(eur_usd_rate)).round(2), + ]; + + let result2 = df.clone().lazy().with_columns(exprs).collect()?; + println!("{}", result.equals(&result2)); + // --8<-- [end:expression-list] + + // --8<-- [start:col-with-dtype] + let result = df + .clone() + .lazy() + .with_column((dtype_col(&DataType::Float64) / lit(eur_usd_rate)).round(2)) + .collect()?; + println!("{}", result); + // --8<-- [end:col-with-dtype] + + // --8<-- [start:col-with-dtypes] + let result2 = df + .clone() + .lazy() + .with_column( + (dtype_cols([DataType::Float32, DataType::Float64]) / lit(eur_usd_rate)).round(2), + ) + .collect()?; + println!("{}", result.equals(&result2)); + // --8<-- [end:col-with-dtypes] + + // --8<-- [start:col-with-regex] + // NOTE: Using regex inside `col`/`cols` requires the feature flag `regex`. + let result = df + .clone() + .lazy() + .select([cols(["ticker", "^.*_high$", "^.*_low$"])]) + .collect()?; + println!("{}", result); + // --8<-- [end:col-with-regex] + + // --8<-- [start:all] + let result = df.clone().lazy().select([all()]).collect()?; + println!("{}", result.equals(&df)); + // --8<-- [end:all] + + // --8<-- [start:all-exclude] + let result = df + .clone() + .lazy() + .select([all().exclude(["^day_.*$"])]) + .collect()?; + println!("{}", result); + // --8<-- [end:all-exclude] + + // --8<-- [start:col-exclude] + let result = df + .clone() + .lazy() + .select([dtype_col(&DataType::Float64).exclude(["^day_.*$"])]) + .collect()?; + println!("{}", result); + // --8<-- [end:col-exclude] + + // --8<-- [start:duplicate-error] + let gbp_usd_rate = 1.31; // As of 14th October 2024 + + let result = df + .clone() + .lazy() + .select([ + col("price") / lit(gbp_usd_rate), + col("price") / lit(eur_usd_rate), + ]) + .collect(); + match result { + Ok(df) => println!("{}", df), + Err(e) => println!("{}", e), + }; + // --8<-- [end:duplicate-error] + + // --8<-- [start:alias] + let _result = df + .clone() + .lazy() + .select([ + (col("price") / lit(gbp_usd_rate)).alias("price (GBP)"), + (col("price") / lit(eur_usd_rate)).alias("price (EUR)"), + ]) + .collect()?; + // --8<-- [end:alias] + + // --8<-- [start:prefix-suffix] + let result = df + .clone() + .lazy() + .select([ + (col("^year_.*$") / lit(eur_usd_rate)) + .name() + .prefix("in_eur_"), + (cols(["day_high", "day_low"]) / lit(gbp_usd_rate)) + .name() + .suffix("_gbp"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:prefix-suffix] + + // --8<-- [start:name-map] + // There is also `name().to_uppercase()`, so this usage of `map` is moot. + let result = df + .clone() + .lazy() + .select([all() + .name() + .map(|name| Ok(PlSmallStr::from_string(name.to_ascii_uppercase())))]) + .collect()?; + println!("{}", result); + // --8<-- [end:name-map] + + // --8<-- [start:for-with_columns] + let mut result = df.clone().lazy(); + for tp in ["day", "year"] { + let high = format!("{}_high", tp); + let low = format!("{}_low", tp); + let aliased = format!("{}_amplitude", tp); + result = result.with_column((col(high) - col(low)).alias(aliased)) + } + let result = result.collect()?; + println!("{}", result); + // --8<-- [end:for-with_columns] + + // --8<-- [start:yield-expressions] + let mut exprs: Vec = vec![]; + for tp in ["day", "year"] { + let high = format!("{}_high", tp); + let low = format!("{}_low", tp); + let aliased = format!("{}_amplitude", tp); + exprs.push((col(high) - col(low)).alias(aliased)) + } + let result = df.clone().lazy().with_columns(exprs).collect()?; + println!("{}", result); + // --8<-- [end:yield-expressions] + + // --8<-- [start:selectors] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors] + + // --8<-- [start:selectors-set-operations] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors-set-operations] + + // --8<-- [start:selectors-expressions] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors-expressions] + + // --8<-- [start:selector-ambiguity] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selector-ambiguity] + + // --8<-- [start:as_expr] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:as_expr] + + // --8<-- [start:is_selector] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:is_selector] + + // --8<-- [start:expand_selector] + // Selectors are not available in Rust yet. + // Refer to https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:expand_selector] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/folds.rs b/docs/source/src/rust/user-guide/expressions/folds.rs new file mode 100644 index 000000000000..3ab11e8772c8 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/folds.rs @@ -0,0 +1,113 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:mansum] + use polars::lazy::dsl::sum_horizontal; + use polars::prelude::*; + + let df = df!( + "label" => ["foo", "bar", "spam"], + "a" => [1, 2, 3], + "b" => [10, 20, 30], + )?; + + let result = df + .clone() + .lazy() + .select([ + fold_exprs( + lit(0), + |acc, val| (&acc + &val).map(Some), + [col("a"), col("b")], + ) + .alias("sum_fold"), + sum_horizontal([col("a"), col("b")], true)?.alias("sum_horz"), + ]) + .collect()?; + + println!("{:?}", result); + // --8<-- [end:mansum] + + // --8<-- [start:mansum-explicit] + let acc = lit(0); + let f = |acc: Expr, val: Expr| acc + val; + + let result = df + .clone() + .lazy() + .select([ + f(f(acc, col("a")), col("b")), + fold_exprs( + lit(0), + |acc, val| (&acc + &val).map(Some), + [col("a"), col("b")], + ) + .alias("sum_fold"), + ]) + .collect()?; + + println!("{:?}", result); + // --8<-- [end:mansum-explicit] + + // --8<-- [start:manprod] + let result = df + .clone() + .lazy() + .select([fold_exprs( + lit(0), + |acc, val| (&acc * &val).map(Some), + [col("a"), col("b")], + ) + .alias("prod")]) + .collect()?; + + println!("{:?}", result); + // --8<-- [end:manprod] + + // --8<-- [start:manprod-fixed] + let result = df + .clone() + .lazy() + .select([fold_exprs( + lit(1), + |acc, val| (&acc * &val).map(Some), + [col("a"), col("b")], + ) + .alias("prod")]) + .collect()?; + + println!("{:?}", result); + // --8<-- [end:manprod-fixed] + + // --8<-- [start:conditional] + let df = df!( + "a" => [1, 2, 3], + "b" => [0, 1, 2], + )?; + + let result = df + .clone() + .lazy() + .filter(fold_exprs( + lit(true), + |acc, val| (&acc & &val).map(Some), + [col("*").gt(1)], + )) + .collect()?; + + println!("{:?}", result); + // --8<-- [end:conditional] + + // --8<-- [start:string] + let df = df!( + "a" => ["a", "b", "c"], + "b" => [1, 2, 3], + )?; + + let result = df + .lazy() + .select([concat_str([col("a"), col("b")], "", false)]) + .collect()?; + println!("{:?}", result); + // --8<-- [end:string] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/lists.rs b/docs/source/src/rust/user-guide/expressions/lists.rs new file mode 100644 index 000000000000..ee8bf9597ce7 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/lists.rs @@ -0,0 +1,51 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:list-example] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:list-example] + + // --8<-- [start:array-example] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:array-example] + + // --8<-- [start:numpy-array-inference] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:numpy-array-inference] + + // --8<-- [start:weather] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:weather] + + // --8<-- [start:split] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:split] + + // --8<-- [start:explode] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:explode] + + // --8<-- [start:list-slicing] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:list-slicing] + + // --8<-- [start:element-wise-casting] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:element-wise-casting] + + // --8<-- [start:element-wise-regex] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:element-wise-regex] + + // --8<-- [start:weather_by_day] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:weather_by_day] + + // --8<-- [start:rank_pct] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:rank_pct] + + // --8<-- [start:array-overview] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:array-overview] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/missing-data.rs b/docs/source/src/rust/user-guide/expressions/missing-data.rs new file mode 100644 index 000000000000..35534c54523e --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/missing-data.rs @@ -0,0 +1,118 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use polars::prelude::*; + let df = df! ( + "value" => &[Some(1), None], + )?; + + println!("{}", df); + // --8<-- [end:dataframe] + + // --8<-- [start:count] + let null_count_df = df.null_count(); + println!("{}", null_count_df); + // --8<-- [end:count] + + // --8<-- [start:isnull] + let is_null_series = df + .clone() + .lazy() + .select([col("value").is_null()]) + .collect()?; + println!("{}", is_null_series); + // --8<-- [end:isnull] + + // --8<-- [start:dataframe2] + let df = df! ( + "col1" => [0.5, 1.0, 1.5, 2.0, 2.5], + "col2" => [Some(1), None, Some(3), None, Some(5)], + )?; + + println!("{}", df); + // --8<-- [end:dataframe2] + + // --8<-- [start:fill] + let fill_literal_df = df + .clone() + .lazy() + .with_column(col("col2").fill_null(3)) + .collect()?; + + println!("{}", fill_literal_df); + // --8<-- [end:fill] + + // --8<-- [start:fillstrategy] + + let fill_literal_df = df + .clone() + .lazy() + .with_columns([ + col("col2") + .fill_null_with_strategy(FillNullStrategy::Forward(None)) + .alias("forward"), + col("col2") + .fill_null_with_strategy(FillNullStrategy::Backward(None)) + .alias("backward"), + ]) + .collect()?; + + println!("{}", fill_literal_df); + // --8<-- [end:fillstrategy] + + // --8<-- [start:fillexpr] + let fill_expression_df = df + .clone() + .lazy() + .with_column(col("col2").fill_null((lit(2) * col("col1")).cast(DataType::Int64))) + .collect()?; + + println!("{}", fill_expression_df); + // --8<-- [end:fillexpr] + + // --8<-- [start:fillinterpolate] + let fill_interpolation_df = df + .clone() + .lazy() + .with_column(col("col2").interpolate(InterpolationMethod::Linear)) + .collect()?; + + println!("{}", fill_interpolation_df); + // --8<-- [end:fillinterpolate] + + // --8<-- [start:nan] + let nan_df = df!( + "value" => [1.0, f64::NAN, f64::NAN, 3.0], + )?; + println!("{}", nan_df); + // --8<-- [end:nan] + + // --8<-- [start:nan-computed] + let df = df!( + "dividend" => [1.0, 0.0, -1.0], + "divisor" => [1.0, 0.0, -1.0], + )?; + + let result = df + .clone() + .lazy() + .select([col("dividend") / col("divisor")]) + .collect()?; + + println!("{}", result); + // --8<-- [end:nan-computed] + + // --8<-- [start:nanfill] + let mean_nan_df = nan_df + .clone() + .lazy() + .with_column(col("value").fill_nan(Null {}.lit()).alias("replaced")) + .select([ + col("*").mean().name().suffix("_mean"), + col("*").sum().name().suffix("_sum"), + ]) + .collect()?; + + println!("{}", mean_nan_df); + // --8<-- [end:nanfill] + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/operations.rs b/docs/source/src/rust/user-guide/expressions/operations.rs new file mode 100644 index 000000000000..55fbf9412f0e --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/operations.rs @@ -0,0 +1,138 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use polars::prelude::*; + + let df = df! ( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &["foo", "ham", "spam", "egg", "spam"], + "random" => &[0.37454, 0.950714, 0.731994, 0.598658, 0.156019], + "groups" => &["A", "A", "B", "A", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:arithmetic] + let result = df + .clone() + .lazy() + .select([ + (col("nrs") + lit(5)).alias("nrs + 5"), + (col("nrs") - lit(5)).alias("nrs - 5"), + (col("nrs") * col("random")).alias("nrs * random"), + (col("nrs") / col("random")).alias("nrs / random"), + (col("nrs").pow(lit(2))).alias("nrs ** 2"), + (col("nrs") % lit(3)).alias("nrs % 3"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:arithmetic] + + // --8<-- [start:comparison] + let result = df + .clone() + .lazy() + .select([ + col("nrs").gt(1).alias("nrs > 1"), + col("nrs").gt_eq(3).alias("nrs >= 3"), + col("random").lt_eq(0.2).alias("random < .2"), + col("random").lt_eq(0.5).alias("random <= .5"), + col("nrs").neq(1).alias("nrs != 1"), + col("nrs").eq(1).alias("nrs == 1"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:comparison] + + // --8<-- [start:boolean] + let result = df + .clone() + .lazy() + .select([ + ((col("nrs").is_null()).not().and(col("groups").eq(lit("A")))) + .alias("number not null and group A"), + (col("random").lt(lit(0.5)).or(col("groups").eq(lit("B")))) + .alias("random < 0.5 or group B"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:boolean] + + // --8<-- [start:bitwise] + let result = df + .clone() + .lazy() + .select([ + col("nrs"), + col("nrs").and(lit(6)).alias("nrs & 6"), + col("nrs").or(lit(6)).alias("nrs | 6"), + col("nrs").not().alias("not nrs"), + col("nrs").xor(lit(6)).alias("nrs ^ 6"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:bitwise] + + // --8<-- [start:count] + use rand::distributions::{Distribution, Uniform}; + use rand::thread_rng; + + let mut rng = thread_rng(); + let between = Uniform::new_inclusive(0, 100_000); + let arr: Vec = between.sample_iter(&mut rng).take(100_100).collect(); + + let long_df = df!( + "numbers" => &arr + )?; + + let result = long_df + .clone() + .lazy() + .select([ + col("numbers").n_unique().alias("n_unique"), + col("numbers").approx_n_unique().alias("approx_n_unique"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:count] + + // --8<-- [start:value_counts] + let result = df + .clone() + .lazy() + .select([col("names") + .value_counts(false, false, "count", false) + .alias("value_counts")]) + .collect()?; + println!("{}", result); + // --8<-- [end:value_counts] + + // --8<-- [start:unique_counts] + let result = df + .clone() + .lazy() + .select([ + col("names").unique_stable().alias("unique"), + col("names").unique_counts().alias("unique_counts"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:unique_counts] + + // --8<-- [start:collatz] + let result = df + .clone() + .lazy() + .select([ + col("nrs"), + when((col("nrs") % lit(2)).eq(lit(1))) + .then(lit(3) * col("nrs") + lit(1)) + .otherwise(col("nrs") / lit(2)) + .alias("Collatz"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:collatz] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/strings.rs b/docs/source/src/rust/user-guide/expressions/strings.rs new file mode 100644 index 000000000000..c77eb35302ef --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/strings.rs @@ -0,0 +1,170 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:df] + use polars::prelude::*; + + let df = df! ( + "language" => ["English", "Dutch", "Portuguese", "Finish"], + "fruit" => ["pear", "peer", "pêra", "päärynä"], + )?; + + let result = df + .clone() + .lazy() + .with_columns([ + col("fruit").str().len_bytes().alias("byte_count"), + col("fruit").str().len_chars().alias("letter_count"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:df] + + // --8<-- [start:existence] + let result = df + .clone() + .lazy() + .select([ + col("fruit"), + col("fruit") + .str() + .starts_with(lit("p")) + .alias("starts_with_p"), + col("fruit").str().contains(lit("p..r"), true).alias("p..r"), + col("fruit").str().contains(lit("e+"), true).alias("e+"), + col("fruit").str().ends_with(lit("r")).alias("ends_with_r"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:existence] + + // --8<-- [start:extract] + let df = df! ( + "urls" => [ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + )?; + + let result = df + .clone() + .lazy() + .select([col("urls").str().extract(lit(r"candidate=(\w+)"), 1)]) + .collect()?; + + println!("{}", result); + // --8<-- [end:extract] + + // --8<-- [start:extract_all] + let df = df! ( + "text" => ["123 bla 45 asd", "xyz 678 910t"] + )?; + + let result = df + .clone() + .lazy() + .select([col("text") + .str() + .extract_all(lit(r"(\d+)")) + .alias("extracted_nrs")]) + .collect()?; + + println!("{}", result); + // --8<-- [end:extract_all] + + // --8<-- [start:replace] + let df = df! ( + "text" => ["123abc", "abc456"] + )?; + + let result = df + .clone() + .lazy() + .with_columns([ + col("text").str().replace(lit(r"\d"), lit("-"), false), + col("text") + .str() + .replace_all(lit(r"\d"), lit("-"), false) + .alias("text_replace_all"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:replace] + + // --8<-- [start:casing] + let addresses = df! ( + "addresses" => [ + "128 PERF st", + "Rust blVD, 158", + "PoLaRs Av, 12", + "1042 Query sq", + ] + )?; + + let addresses = addresses + .clone() + .lazy() + .select([ + col("addresses").alias("originals"), + col("addresses").str().to_titlecase(), + col("addresses").str().to_lowercase().alias("lower"), + col("addresses").str().to_uppercase().alias("upper"), + ]) + .collect()?; + + println!("{}", addresses); + // --8<-- [end:casing] + + // --8<-- [start:strip] + let addr = col("addresses"); + let chars = lit(", 0123456789"); + let result = addresses + .clone() + .lazy() + .select([ + addr.clone().str().strip_chars(chars.clone()).alias("strip"), + addr.clone() + .str() + .strip_chars_end(chars.clone()) + .alias("end"), + addr.clone() + .str() + .strip_chars_start(chars.clone()) + .alias("start"), + addr.clone().str().strip_prefix(lit("128 ")).alias("prefix"), + addr.clone() + .str() + .strip_suffix(lit(", 158")) + .alias("suffix"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:strip] + + // --8<-- [start:slice] + let df = df! ( + "fruits" => ["pear", "mango", "dragonfruit", "passionfruit"], + "n" => [1, -1, 4, -4], + )?; + + let result = df + .clone() + .lazy() + .with_columns([ + col("fruits") + .str() + .slice(col("n"), lit(NULL)) + .alias("slice"), + col("fruits").str().head(col("n")).alias("head"), + col("fruits").str().tail(col("n")).alias("tail"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:slice] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/structs.rs b/docs/source/src/rust/user-guide/expressions/structs.rs new file mode 100644 index 000000000000..d62717f30628 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/structs.rs @@ -0,0 +1,146 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:ratings_df] + use polars::prelude::*; + let ratings = df!( + "Movie"=> ["Cars", "IT", "ET", "Cars", "Up", "IT", "Cars", "ET", "Up", "Cars"], + "Theatre"=> ["NE", "ME", "IL", "ND", "NE", "SD", "NE", "IL", "IL", "NE"], + "Avg_Rating"=> [4.5, 4.4, 4.6, 4.3, 4.8, 4.7, 4.5, 4.9, 4.7, 4.6], + "Count"=> [30, 27, 26, 29, 31, 28, 28, 26, 33, 28], + + )?; + println!("{}", &ratings); + // --8<-- [end:ratings_df] + + // --8<-- [start:state_value_counts] + let result = ratings + .clone() + .lazy() + .select([col("Theatre").value_counts(true, true, "count", false)]) + .collect()?; + println!("{}", result); + // --8<-- [end:state_value_counts] + + // --8<-- [start:struct_unnest] + let result = ratings + .clone() + .lazy() + .select([col("Theatre").value_counts(true, true, "count", false)]) + .unnest(["Theatre"]) + .collect()?; + println!("{}", result); + // --8<-- [end:struct_unnest] + + // --8<-- [start:series_struct] + // Don't think we can make it the same way in rust, but this works + let rating_series = df!( + "Movie" => &["Cars", "Toy Story"], + "Theatre" => &["NE", "ME"], + "Avg_Rating" => &[4.5, 4.9], + )? + .into_struct("ratings".into()) + .into_series(); + println!("{}", &rating_series); + // // --8<-- [end:series_struct] + + // --8<-- [start:series_struct_error] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:series_struct_error] + + // --8<-- [start:series_struct_extract] + let result = rating_series.struct_()?.field_by_name("Movie")?; + println!("{}", result); + // --8<-- [end:series_struct_extract] + + // --8<-- [start:series_struct_rename] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:series_struct_rename] + + // --8<-- [start:struct-rename-check] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:struct-rename-check] + + // --8<-- [start:struct_duplicates] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:struct_duplicates] + + // --8<-- [start:struct_ranking] + let result = ratings + .clone() + .lazy() + .with_columns([as_struct(vec![col("Count"), col("Avg_Rating")]) + .rank( + RankOptions { + method: RankMethod::Dense, + descending: true, + }, + None, + ) + .over([col("Movie"), col("Theatre")]) + .alias("Rank")]) + // .filter(as_struct(&[col("Movie"), col("Theatre")]).is_duplicated()) + // Error: .is_duplicated() not available if you try that + // https://github.com/pola-rs/polars/issues/3803 + .filter(len().over([col("Movie"), col("Theatre")]).gt(lit(1))) + .collect()?; + println!("{}", result); + // --8<-- [end:struct_ranking] + + // --8<-- [start:multi_column_apply] + let df = df!( + "keys" => ["a", "a", "b"], + "values" => [10, 7, 1], + )?; + + let result = df + .lazy() + .select([ + // pack to struct to get access to multiple fields in a custom `apply/map` + as_struct(vec![col("keys"), col("values")]) + // we will compute the len(a) + b + .apply( + |s| { + // downcast to struct + let ca = s.struct_()?; + + // get the fields as Series + let s_a = &ca.fields_as_series()[0]; + let s_b = &ca.fields_as_series()[1]; + + // downcast the `Series` to their known type + let ca_a = s_a.str()?; + let ca_b = s_b.i32()?; + + // iterate both `ChunkedArrays` + let result: Int32Chunked = ca_a + .into_iter() + .zip(ca_b) + .map(|(opt_a, opt_b)| match (opt_a, opt_b) { + (Some(a), Some(b)) => Some(a.len() as i32 + b), + _ => None, + }) + .collect(); + + Ok(Some(result.into_column())) + }, + GetOutput::from_type(DataType::Int32), + ) + // note: the `'solution_map_elements'` alias is just there to show how you + // get the same output as in the Python API example. + .alias("solution_map_elements"), + (col("keys").str().count_matches(lit("."), true) + col("values")) + .alias("solution_expr"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:multi_column_apply] + + // --8<-- [start:ack] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:ack] + + // --8<-- [start:struct-ack] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:struct-ack] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/user-defined-functions.rs b/docs/source/src/rust/user-guide/expressions/user-defined-functions.rs new file mode 100644 index 000000000000..b83898ef6c7c --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/user-defined-functions.rs @@ -0,0 +1,27 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + let df = df!( + "keys" => &["a", "a", "b", "b"], + "values" => &[10, 7, 1, 23], + )?; + println!("{}", df); + // --8<-- [end:dataframe] + + // --8<-- [start:individual_log] + // --8<-- [end:individual_log] + + // --8<-- [start:diff_from_mean] + // --8<-- [end:diff_from_mean] + + // --8<-- [start:np_log] + // --8<-- [end:np_log] + + // --8<-- [start:diff_from_mean_numba] + // --8<-- [end:diff_from_mean_numba] + + // --8<-- [start:combine] + // --8<-- [end:combine] + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/expressions/window.rs b/docs/source/src/rust/user-guide/expressions/window.rs new file mode 100644 index 000000000000..9579b3c0adb7 --- /dev/null +++ b/docs/source/src/rust/user-guide/expressions/window.rs @@ -0,0 +1,172 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:pokemon] + use polars::prelude::*; + use reqwest::blocking::Client; + + let data: Vec = Client::new() + .get("https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv") + .send()? + .text()? + .bytes() + .collect(); + + let file = std::io::Cursor::new(data); + let df = CsvReadOptions::default() + .with_has_header(true) + .into_reader_with_file_handle(file) + .finish()?; + + println!("{}", df.head(Some(5))); + // --8<-- [end:pokemon] + + // --8<-- [start:rank] + let result = df + .clone() + .lazy() + .select([ + col("Name"), + col("Type 1"), + col("Speed") + .rank( + RankOptions { + method: RankMethod::Dense, + descending: true, + }, + None, + ) + .over(["Type 1"]) + .alias("Speed rank"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:rank] + + // --8<-- [start:rank-multiple] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:rank-multiple] + + // --8<-- [start:rank-explode] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:rank-explode] + + // --8<-- [start:athletes] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:athletes] + + // --8<-- [start:athletes-sort-over-country] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:athletes-sort-over-country] + + // --8<-- [start:athletes-explode] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:athletes-explode] + + // --8<-- [start:athletes-join] + // Contribute the Rust translation of the Python example by opening a PR. + // --8<-- [end:athletes-join] + + // --8<-- [start:pokemon-mean] + let result = df + .clone() + .lazy() + .select([ + col("Name"), + col("Type 1"), + col("Speed"), + col("Speed") + .mean() + .over(["Type 1"]) + .alias("Mean speed in group"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:pokemon-mean] + + // --8<-- [start:group_by] + let result = df + .clone() + .lazy() + .select([ + col("Type 1"), + col("Type 2"), + col("Attack") + .mean() + .over(["Type 1"]) + .alias("avg_attack_by_type"), + col("Defense") + .mean() + .over(["Type 1", "Type 2"]) + .alias("avg_defense_by_type_combination"), + col("Attack").mean().alias("avg_attack"), + ]) + .collect()?; + + println!("{}", result); + // --8<-- [end:group_by] + + // --8<-- [start:operations] + let filtered = df + .clone() + .lazy() + .filter(col("Type 2").eq(lit("Psychic"))) + .select([col("Name"), col("Type 1"), col("Speed")]) + .collect()?; + + println!("{}", filtered); + // --8<-- [end:operations] + + // --8<-- [start:sort] + let result = filtered + .lazy() + .with_columns([cols(["Name", "Speed"]) + .sort_by( + ["Speed"], + SortMultipleOptions::default().with_order_descending(true), + ) + .over(["Type 1"])]) + .collect()?; + println!("{}", result); + // --8<-- [end:sort] + + // --8<-- [start:examples] + let result = df + .clone() + .lazy() + .select([ + col("Type 1") + .head(Some(3)) + .over_with_options(["Type 1"], None, WindowMapping::Explode) + .flatten(), + col("Name") + .sort_by( + ["Speed"], + SortMultipleOptions::default().with_order_descending(true), + ) + .head(Some(3)) + .over_with_options(["Type 1"], None, WindowMapping::Explode) + .flatten() + .alias("fastest/group"), + col("Name") + .sort_by( + ["Attack"], + SortMultipleOptions::default().with_order_descending(true), + ) + .head(Some(3)) + .over_with_options(["Type 1"], None, WindowMapping::Explode) + .flatten() + .alias("strongest/group"), + col("Name") + .sort(Default::default()) + .head(Some(3)) + .over_with_options(["Type 1"], None, WindowMapping::Explode) + .flatten() + .alias("sorted_by_alphabet"), + ]) + .collect()?; + println!("{:?}", result); + // --8<-- [end:examples] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/getting-started.rs b/docs/source/src/rust/user-guide/getting-started.rs new file mode 100644 index 000000000000..b5b249304209 --- /dev/null +++ b/docs/source/src/rust/user-guide/getting-started.rs @@ -0,0 +1,195 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:df] + use chrono::prelude::*; + use polars::prelude::*; + + let mut df: DataFrame = df!( + "name" => ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"], + "birthdate" => [ + NaiveDate::from_ymd_opt(1997, 1, 10).unwrap(), + NaiveDate::from_ymd_opt(1985, 2, 15).unwrap(), + NaiveDate::from_ymd_opt(1983, 3, 22).unwrap(), + NaiveDate::from_ymd_opt(1981, 4, 30).unwrap(), + ], + "weight" => [57.9, 72.5, 53.6, 83.1], // (kg) + "height" => [1.56, 1.77, 1.65, 1.75], // (m) + ) + .unwrap(); + println!("{}", df); + // --8<-- [end:df] + + // --8<-- [start:csv] + use std::fs::File; + + let mut file = File::create("docs/assets/data/output.csv").expect("could not create file"); + CsvWriter::new(&mut file) + .include_header(true) + .with_separator(b',') + .finish(&mut df)?; + let df_csv = CsvReadOptions::default() + .with_has_header(true) + .with_parse_options(CsvParseOptions::default().with_try_parse_dates(true)) + .try_into_reader_with_file_path(Some("docs/assets/data/output.csv".into()))? + .finish()?; + println!("{}", df_csv); + // --8<-- [end:csv] + + // --8<-- [start:select] + let result = df + .clone() + .lazy() + .select([ + col("name"), + col("birthdate").dt().year().alias("birth_year"), + (col("weight") / col("height").pow(2)).alias("bmi"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:select] + + // --8<-- [start:expression-expansion] + let result = df + .clone() + .lazy() + .select([ + col("name"), + (cols(["weight", "height"]) * lit(0.95)) + .round(2) + .name() + .suffix("-5%"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:expression-expansion] + + // --8<-- [start:with_columns] + let result = df + .clone() + .lazy() + .with_columns([ + col("birthdate").dt().year().alias("birth_year"), + (col("weight") / col("height").pow(2)).alias("bmi"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:with_columns] + + // --8<-- [start:filter] + let result = df + .clone() + .lazy() + .filter(col("birthdate").dt().year().lt(lit(1990))) + .collect()?; + println!("{}", result); + // --8<-- [end:filter] + + // --8<-- [start:filter-multiple] + let result = df + .clone() + .lazy() + .filter( + col("birthdate") + .is_between( + lit(NaiveDate::from_ymd_opt(1982, 12, 31).unwrap()), + lit(NaiveDate::from_ymd_opt(1996, 1, 1).unwrap()), + ClosedInterval::Both, + ) + .and(col("height").gt(lit(1.7))), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:filter-multiple] + + // --8<-- [start:group_by] + // Use `group_by_stable` if you want the Python behaviour of `maintain_order=True`. + let result = df + .clone() + .lazy() + .group_by([(col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade")]) + .agg([len()]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by] + + // --8<-- [start:group_by-agg] + let result = df + .clone() + .lazy() + .group_by([(col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade")]) + .agg([ + len().alias("sample_size"), + col("weight").mean().round(2).alias("avg_weight"), + col("height").max().alias("tallest"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:group_by-agg] + + // --8<-- [start:complex] + let result = df + .clone() + .lazy() + .with_columns([ + (col("birthdate").dt().year() / lit(10) * lit(10)).alias("decade"), + col("name").str().split(lit(" ")).list().first(), + ]) + .select([all().exclude(["birthdate"])]) + .group_by([col("decade")]) + .agg([ + col("name"), + cols(["weight", "height"]) + .mean() + .round(2) + .name() + .prefix("avg_"), + ]) + .collect()?; + println!("{}", result); + // --8<-- [end:complex] + + // --8<-- [start:join] + let df2: DataFrame = df!( + "name" => ["Ben Brown", "Daniel Donovan", "Alice Archer", "Chloe Cooper"], + "parent" => [true, false, false, false], + "siblings" => [1, 2, 3, 4], + ) + .unwrap(); + + let result = df + .clone() + .lazy() + .join( + df2.clone().lazy(), + [col("name")], + [col("name")], + JoinArgs::new(JoinType::Left), + ) + .collect()?; + + println!("{}", result); + // --8<-- [end:join] + + // --8<-- [start:concat] + let df3: DataFrame = df!( + "name" => ["Ethan Edwards", "Fiona Foster", "Grace Gibson", "Henry Harris"], + "birthdate" => [ + NaiveDate::from_ymd_opt(1977, 5, 10).unwrap(), + NaiveDate::from_ymd_opt(1975, 6, 23).unwrap(), + NaiveDate::from_ymd_opt(1973, 7, 22).unwrap(), + NaiveDate::from_ymd_opt(1971, 8, 3).unwrap(), + ], + "weight" => [67.9, 72.5, 57.6, 93.1], // (kg) + "height" => [1.76, 1.6, 1.66, 1.8], // (m) + ) + .unwrap(); + + let result = concat( + [df.clone().lazy(), df3.clone().lazy()], + UnionArgs::default(), + )? + .collect()?; + println!("{}", result); + // --8<-- [end:concat] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/io/cloud-storage.rs b/docs/source/src/rust/user-guide/io/cloud-storage.rs new file mode 100644 index 000000000000..5de0b2bb7d63 --- /dev/null +++ b/docs/source/src/rust/user-guide/io/cloud-storage.rs @@ -0,0 +1,49 @@ +// --8<-- [start:read_parquet] +use aws_config::BehaviorVersion; +use polars::prelude::*; + +#[tokio::main] +async fn main() { + let bucket = ""; + let path = ""; + + let config = aws_config::load_defaults(BehaviorVersion::latest()).await; + let client = aws_sdk_s3::Client::new(&config); + + let object = client + .get_object() + .bucket(bucket) + .key(path) + .send() + .await + .unwrap(); + + let bytes = object.body.collect().await.unwrap().into_bytes(); + + let cursor = std::io::Cursor::new(bytes); + let df = CsvReader::new(cursor).finish().unwrap(); + + println!("{:?}", df); +} +// --8<-- [end:read_parquet] + +// --8<-- [start:scan_parquet_query] +// --8<-- [end:scan_parquet_query] + +// --8<-- [start:scan_parquet_storage_options_aws] +// --8<-- [end:scan_parquet_storage_options_aws] + +// --8<-- [start:credential_provider_class] +// --8<-- [end:credential_provider_class] + +// --8<-- [start:credential_provider_custom_func] +// --8<-- [end:credential_provider_custom_func] + +// --8<-- [start:credential_provider_custom_func_azure] +// --8<-- [end:credential_provider_custom_func_azure] + +// --8<-- [start:scan_pyarrow_dataset] +// --8<-- [end:scan_pyarrow_dataset] + +// --8<-- [start:write_parquet] +// --8<-- [end:write_parquet] diff --git a/docs/source/src/rust/user-guide/io/csv.rs b/docs/source/src/rust/user-guide/io/csv.rs new file mode 100644 index 000000000000..1406a9e098cc --- /dev/null +++ b/docs/source/src/rust/user-guide/io/csv.rs @@ -0,0 +1,34 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:read] + use polars::prelude::*; + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/assets/data/path.csv").unwrap(); + CsvWriter::new(&mut file).finish(&mut df).unwrap(); + // --8<-- [end:write] + + let df = CsvReadOptions::default() + .try_into_reader_with_file_path(Some("docs/assets/data/path.csv".into())) + .unwrap() + .finish() + .unwrap(); + // --8<-- [end:read] + println!("{}", df); + + // --8<-- [start:scan] + let lf = LazyCsvReader::new("docs/assets/data/path.csv") + .finish() + .unwrap(); + // --8<-- [end:scan] + println!("{}", lf.collect()?); + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/io/json.rs b/docs/source/src/rust/user-guide/io/json.rs new file mode 100644 index 000000000000..468e20babe8b --- /dev/null +++ b/docs/source/src/rust/user-guide/io/json.rs @@ -0,0 +1,48 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/assets/data/path.json").unwrap(); + + // json + JsonWriter::new(&mut file) + .with_json_format(JsonFormat::Json) + .finish(&mut df) + .unwrap(); + + // ndjson + JsonWriter::new(&mut file) + .with_json_format(JsonFormat::JsonLines) + .finish(&mut df) + .unwrap(); + // --8<-- [end:write] + + // --8<-- [start:read] + use polars::prelude::*; + + let mut file = std::fs::File::open("docs/assets/data/path.json").unwrap(); + let df = JsonReader::new(&mut file).finish()?; + // --8<-- [end:read] + println!("{}", df); + + // --8<-- [start:readnd] + let mut file = std::fs::File::open("docs/assets/data/path.json").unwrap(); + let df = JsonLineReader::new(&mut file).finish().unwrap(); + // --8<-- [end:readnd] + println!("{}", df); + + // --8<-- [start:scan] + let lf = LazyJsonLineReader::new("docs/assets/data/path.json") + .finish() + .unwrap(); + // --8<-- [end:scan] + println!("{}", lf.collect()?); + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/io/parquet.rs b/docs/source/src/rust/user-guide/io/parquet.rs new file mode 100644 index 000000000000..a554c7051040 --- /dev/null +++ b/docs/source/src/rust/user-guide/io/parquet.rs @@ -0,0 +1,29 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/assets/data/path.parquet").unwrap(); + ParquetWriter::new(&mut file).finish(&mut df).unwrap(); + // --8<-- [end:write] + + // --8<-- [start:read] + let mut file = std::fs::File::open("docs/assets/data/path.parquet").unwrap(); + + let df = ParquetReader::new(&mut file).finish().unwrap(); + // --8<-- [end:read] + println!("{}", df); + + // --8<-- [start:scan] + let args = ScanArgsParquet::default(); + let lf = LazyFrame::scan_parquet("docs/assets/data/path.parquet", args).unwrap(); + // --8<-- [end:scan] + println!("{}", lf.collect()?); + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/concatenation.rs b/docs/source/src/rust/user-guide/transformations/concatenation.rs new file mode 100644 index 000000000000..4b7d183316c9 --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/concatenation.rs @@ -0,0 +1,62 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:vertical] + let df_v1 = df!( + "a"=> &[1], + "b"=> &[3], + )?; + let df_v2 = df!( + "a"=> &[2], + "b"=> &[4], + )?; + let df_vertical_concat = concat( + [df_v1.clone().lazy(), df_v2.clone().lazy()], + UnionArgs::default(), + )? + .collect()?; + println!("{}", &df_vertical_concat); + // --8<-- [end:vertical] + + // --8<-- [start:horizontal] + let df_h1 = df!( + "l1"=> &[1, 2], + "l2"=> &[3, 4], + )?; + let df_h2 = df!( + "r1"=> &[5, 6], + "r2"=> &[7, 8], + "r3"=> &[9, 10], + )?; + let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2], true)?; + println!("{}", &df_horizontal_concat); + // --8<-- [end:horizontal] + // + // --8<-- [start:horizontal_different_lengths] + let df_h1 = df!( + "l1"=> &[1, 2], + "l2"=> &[3, 4], + )?; + let df_h2 = df!( + "r1"=> &[5, 6, 7], + "r2"=> &[8, 9, 10], + )?; + let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2], true)?; + println!("{}", &df_horizontal_concat); + // --8<-- [end:horizontal_different_lengths] + + // --8<-- [start:cross] + let df_d1 = df!( + "a"=> &[1], + "b"=> &[3], + )?; + let df_d2 = df!( + "a"=> &[2], + "d"=> &[4],)?; + let df_diagonal_concat = polars::functions::concat_df_diagonal(&[df_d1, df_d2])?; + println!("{}", &df_diagonal_concat); + // --8<-- [end:cross] + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/joins.rs b/docs/source/src/rust/user-guide/transformations/joins.rs new file mode 100644 index 000000000000..45037abf3eeb --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/joins.rs @@ -0,0 +1,291 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // NOTE: This assumes the data has been downloaded and is available. + // See the corresponding Python script for the remote location of the data. + + // --8<-- [start:props_groups] + let props_groups = CsvReadOptions::default() + .with_has_header(true) + .try_into_reader_with_file_path(Some("docs/assets/data/monopoly_props_groups.csv".into()))? + .finish()? + .head(Some(5)); + println!("{}", props_groups); + // --8<-- [end:props_groups] + + // --8<-- [start:props_prices] + let props_prices = CsvReadOptions::default() + .with_has_header(true) + .try_into_reader_with_file_path(Some("docs/assets/data/monopoly_props_prices.csv".into()))? + .finish()? + .head(Some(5)); + println!("{}", props_prices); + // --8<-- [end:props_prices] + + // --8<-- [start:equi-join] + // In Rust, we cannot use the shorthand of specifying a common + // column name just once. + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::default(), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:equi-join] + + // --8<-- [start:props_groups2] + let props_groups2 = props_groups + .clone() + .lazy() + .with_column(col("property_name").str().to_lowercase()) + .collect()?; + println!("{}", props_groups2); + // --8<-- [end:props_groups2] + + // --8<-- [start:props_prices2] + let props_prices2 = props_prices + .clone() + .lazy() + .select([col("property_name").alias("name"), col("cost")]) + .collect()?; + println!("{}", props_prices2); + // --8<-- [end:props_prices2] + + // --8<-- [start:join-key-expression] + let result = props_groups2 + .clone() + .lazy() + .join( + props_prices2.clone().lazy(), + [col("property_name")], + [col("name").str().to_lowercase()], + JoinArgs::default(), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:join-key-expression] + + // --8<-- [start:inner-join] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:inner-join] + + // --8<-- [start:left-join] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Left), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:left-join] + + // --8<-- [start:right-join] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Right), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:right-join] + + // --8<-- [start:left-right-join-equals] + // `equals_missing` is needed instead of `equals` + // so that missing values compare as equal. + let dfs_match = result.equals_missing( + &props_prices + .clone() + .lazy() + .join( + props_groups.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Left), + ) + .select([ + // Reorder the columns to match the order of `result`. + col("group"), + col("property_name"), + col("cost"), + ]) + .collect()?, + ); + println!("{}", dfs_match); + // --8<-- [end:left-right-join-equals] + + // --8<-- [start:full-join] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Full), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:full-join] + + // --8<-- [start:full-join-coalesce] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:full-join-coalesce] + + // --8<-- [start:semi-join] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Semi), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:semi-join] + + // --8<-- [start:anti-join] + let result = props_groups + .clone() + .lazy() + .join( + props_prices.clone().lazy(), + [col("property_name")], + [col("property_name")], + JoinArgs::new(JoinType::Anti), + ) + .collect()?; + println!("{}", result); + // --8<-- [end:anti-join] + + // --8<-- [start:players] + let players = df!( + "name" => ["Alice", "Bob"], + "cash" => [78, 135], + )?; + println!("{}", players); + // --8<-- [end:players] + + // --8<-- [start:non-equi] + let result = players + .clone() + .lazy() + .join_builder() + .with(props_prices.clone().lazy()) + .join_where(vec![col("cash").cast(DataType::Int64).gt(col("cost"))]) + .collect()?; + println!("{}", result); + // --8<-- [end:non-equi] + + // --8<-- [start:df_trades] + use chrono::prelude::*; + + let df_trades = df!( + "time" => [ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 3, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock" => ["A", "B", "B", "C"], + "trade" => [101, 299, 301, 500], + )?; + println!("{}", df_trades); + // --8<-- [end:df_trades] + + // --8<-- [start:df_quotes] + let df_quotes = df!( + "time" => [ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 2, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 4, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock" => ["A", "B", "C", "A"], + "quote" => [100, 300, 501, 102], + )?; + println!("{}", df_quotes); + // --8<-- [end:df_quotes] + + // --8<-- [start:asof] + let result = df_trades.join_asof_by( + &df_quotes, + "time", + "time", + ["stock"], + ["stock"], + AsofStrategy::Backward, + None, + true, + true, + )?; + println!("{}", result); + // --8<-- [end:asof] + + // --8<-- [start:asof-tolerance] + let result = df_trades.join_asof_by( + &df_quotes, + "time", + "time", + ["stock"], + ["stock"], + AsofStrategy::Backward, + Some(AnyValue::Duration(60000, TimeUnit::Milliseconds)), + true, + true, + )?; + println!("{}", result); + // --8<-- [end:asof-tolerance] + + // --8<-- [start:cartesian-product] + let tokens = df!( + "monopoly_token" => ["hat", "shoe", "boat"], + )?; + + let result = players + .clone() + .lazy() + .select([col("name")]) + .cross_join(tokens.clone().lazy(), None) + .collect()?; + println!("{}", result); + // --8<-- [end:cartesian-product] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/pivot.rs b/docs/source/src/rust/user-guide/transformations/pivot.rs new file mode 100644 index 000000000000..5072ed82d52c --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/pivot.rs @@ -0,0 +1,38 @@ +// --8<-- [start:setup] +use polars::prelude::pivot::pivot; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "foo"=> ["A", "A", "B", "B", "C"], + "bar"=> ["k", "l", "m", "n", "o"], + "N"=> [1, 2, 2, 4, 2], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:eager] + let out = pivot(&df, ["foo"], Some(["bar"]), Some(["N"]), false, None, None)?; + println!("{}", &out); + // --8<-- [end:eager] + + // --8<-- [start:lazy] + let q = df.lazy(); + let q2 = pivot( + &q.collect()?, + ["foo"], + Some(["bar"]), + Some(["N"]), + false, + None, + None, + )? + .lazy(); + let out = q2.collect()?; + println!("{}", &out); + // --8<-- [end:lazy] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/time-series/filter.rs b/docs/source/src/rust/user-guide/transformations/time-series/filter.rs new file mode 100644 index 000000000000..451af20ada80 --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/time-series/filter.rs @@ -0,0 +1,57 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .try_into_reader_with_file_path(Some("docs/assets/data/apple_stock.csv".into())) + .unwrap() + .finish() + .unwrap(); + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:filter] + let filtered_df = df + .clone() + .lazy() + .filter(col("Date").eq(lit(NaiveDate::from_ymd_opt(1995, 10, 16).unwrap()))) + .collect()?; + println!("{}", &filtered_df); + // --8<-- [end:filter] + + // --8<-- [start:range] + let filtered_range_df = df + .clone() + .lazy() + .filter( + col("Date") + .gt(lit(NaiveDate::from_ymd_opt(1995, 7, 1).unwrap())) + .and(col("Date").lt(lit(NaiveDate::from_ymd_opt(1995, 11, 1).unwrap()))), + ) + .collect()?; + println!("{}", &filtered_range_df); + // --8<-- [end:range] + + // --8<-- [start:negative] + let negative_dates_df = df!( + "ts"=> &["-1300-05-23", "-1400-03-02"], + "values"=> &[3, 4])? + .lazy() + .with_column(col("ts").str().to_date(StrptimeOptions::default())) + .collect()?; + + let negative_dates_filtered_df = negative_dates_df + .clone() + .lazy() + .filter(col("ts").dt().year().lt(-1300)) + .collect()?; + println!("{}", &negative_dates_filtered_df); + // --8<-- [end:negative] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/time-series/parsing.rs b/docs/source/src/rust/user-guide/transformations/time-series/parsing.rs new file mode 100644 index 000000000000..b6f75e6240a7 --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/time-series/parsing.rs @@ -0,0 +1,76 @@ +// --8<-- [start:setup] +use polars::io::prelude::*; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .try_into_reader_with_file_path(Some("docs/assets/data/apple_stock.csv".into())) + .unwrap() + .finish() + .unwrap(); + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:cast] + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(false)) + .try_into_reader_with_file_path(Some("docs/assets/data/apple_stock.csv".into())) + .unwrap() + .finish() + .unwrap(); + let df = df + .clone() + .lazy() + .with_columns([col("Date").str().to_date(StrptimeOptions::default())]) + .collect()?; + println!("{}", &df); + // --8<-- [end:cast] + + // --8<-- [start:df3] + let df_with_year = df + .clone() + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:df3] + + // --8<-- [start:extract] + let df_with_year = df + .clone() + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:extract] + + // --8<-- [start:mixed] + let data = [ + "2021-03-27T00:00:00+0100", + "2021-03-28T00:00:00+0100", + "2021-03-29T00:00:00+0200", + "2021-03-30T00:00:00+0200", + ]; + let q = col("date") + .str() + .to_datetime( + Some(TimeUnit::Microseconds), + None, + StrptimeOptions { + format: Some("%Y-%m-%dT%H:%M:%S%z".into()), + ..Default::default() + }, + lit("raise"), + ) + .dt() + .convert_time_zone("Europe/Brussels".into()); + let mixed_parsed = df!("date" => &data)?.lazy().select([q]).collect()?; + + println!("{}", &mixed_parsed); + // --8<-- [end:mixed] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/time-series/resampling.rs b/docs/source/src/rust/user-guide/transformations/time-series/resampling.rs new file mode 100644 index 000000000000..dec19f65fc26 --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/time-series/resampling.rs @@ -0,0 +1,50 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let time = polars::time::date_range( + "time".into(), + NaiveDate::from_ymd_opt(2021, 12, 16) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 16) + .unwrap() + .and_hms_opt(3, 0, 0) + .unwrap(), + Duration::parse("30m"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None, + )?; + let df = df!( + "time" => time, + "groups" => &["a", "a", "a", "b", "b", "a", "a"], + "values" => &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:upsample] + let out1 = df + .clone() + .upsample::<[String; 0]>([], "time", Duration::parse("15m"))? + .fill_null(FillNullStrategy::Forward(None))?; + println!("{}", &out1); + // --8<-- [end:upsample] + + // --8<-- [start:upsample2] + let out2 = df + .clone() + .upsample::<[String; 0]>([], "time", Duration::parse("15m"))? + .lazy() + .with_columns([col("values").interpolate(InterpolationMethod::Linear)]) + .collect()? + .fill_null(FillNullStrategy::Forward(None))?; + println!("{}", &out2); + // --8<-- [end:upsample2] + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs new file mode 100644 index 000000000000..7378816b786a --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/time-series/rolling.rs @@ -0,0 +1,152 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReadOptions::default() + .map_parse_options(|parse_options| parse_options.with_try_parse_dates(true)) + .try_into_reader_with_file_path(Some("docs/assets/data/apple_stock.csv".into())) + .unwrap() + .finish() + .unwrap() + .sort( + ["Date"], + SortMultipleOptions::default().with_maintain_order(true), + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:group_by] + let annual_average_df = df + .clone() + .lazy() + .group_by_dynamic( + col("Date"), + [], + DynamicGroupOptions { + every: Duration::parse("1y"), + period: Duration::parse("1y"), + offset: Duration::parse("0"), + ..Default::default() + }, + ) + .agg([col("Close").mean()]) + .collect()?; + + let df_with_year = annual_average_df + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:group_by] + + // --8<-- [start:group_by_dyn] + let time = polars::time::date_range( + "time".into(), + NaiveDate::from_ymd_opt(2021, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 31) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None, + )? + .cast(&DataType::Date)?; + + let df = df!( + "time" => time, + )?; + + let out = df + .clone() + .lazy() + .group_by_dynamic( + col("time"), + [], + DynamicGroupOptions { + every: Duration::parse("1mo"), + period: Duration::parse("1mo"), + offset: Duration::parse("0"), + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([ + col("time") + .cum_count(true) // python example has false + .reverse() + .head(Some(3)) + .alias("day/eom"), + ((col("time").last() - col("time").first()).map( + // had to use map as .duration().days() is not available + |s| { + Ok(Some( + s.duration()? + .into_iter() + .map(|d| d.map(|v| v / 1000 / 24 / 60 / 60)) + .collect::() + .into_column(), + )) + }, + GetOutput::from_type(DataType::Int64), + ) + lit(1)) + .alias("days_in_month"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:group_by_dyn] + + // --8<-- [start:group_by_roll] + let time = polars::time::date_range( + "time".into(), + NaiveDate::from_ymd_opt(2021, 12, 16) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 16) + .unwrap() + .and_hms_opt(3, 0, 0) + .unwrap(), + Duration::parse("30m"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None, + )?; + let df = df!( + "time" => time, + "groups"=> ["a", "a", "a", "b", "b", "a", "a"], + )?; + println!("{}", &df); + // --8<-- [end:group_by_roll] + + // --8<-- [start:group_by_dyn2] + let out = df + .clone() + .lazy() + .group_by_dynamic( + col("time"), + [col("groups")], + DynamicGroupOptions { + every: Duration::parse("1h"), + period: Duration::parse("1h"), + offset: Duration::parse("0"), + include_boundaries: true, + closed_window: ClosedWindow::Both, + ..Default::default() + }, + ) + .agg([len()]) + .collect()?; + println!("{}", &out); + // --8<-- [end:group_by_dyn2] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/time-series/timezones.rs b/docs/source/src/rust/user-guide/transformations/time-series/timezones.rs new file mode 100644 index 000000000000..476a7a332b5c --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/time-series/timezones.rs @@ -0,0 +1,52 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:example] + let ts = ["2021-03-27 03:00", "2021-03-28 03:00"]; + let tz_naive = Column::new("tz_naive".into(), &ts); + let time_zones_df = DataFrame::new(vec![tz_naive])? + .lazy() + .select([col("tz_naive").str().to_datetime( + Some(TimeUnit::Milliseconds), + None, + StrptimeOptions::default(), + lit("raise"), + )]) + .with_columns([col("tz_naive") + .dt() + .replace_time_zone(Some("UTC".into()), lit("raise"), NonExistent::Raise) + .alias("tz_aware")]) + .collect()?; + + println!("{}", &time_zones_df); + // --8<-- [end:example] + + // --8<-- [start:example2] + let time_zones_operations = time_zones_df + .lazy() + .select([ + col("tz_aware") + .dt() + .replace_time_zone( + Some("Europe/Brussels".into()), + lit("raise"), + NonExistent::Raise, + ) + .alias("replace time zone"), + col("tz_aware") + .dt() + .convert_time_zone("Asia/Kathmandu".into()) + .alias("convert time zone"), + col("tz_aware") + .dt() + .replace_time_zone(None, lit("raise"), NonExistent::Raise) + .alias("unset time zone"), + ]) + .collect()?; + println!("{}", &time_zones_operations); + // --8<-- [end:example2] + + Ok(()) +} diff --git a/docs/source/src/rust/user-guide/transformations/unpivot.rs b/docs/source/src/rust/user-guide/transformations/unpivot.rs new file mode 100644 index 000000000000..a094d7364e7d --- /dev/null +++ b/docs/source/src/rust/user-guide/transformations/unpivot.rs @@ -0,0 +1,21 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "A"=> &["a", "b", "a"], + "B"=> &[1, 3, 5], + "C"=> &[10, 11, 12], + "D"=> &[2, 4, 6], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:unpivot] + let out = df.unpivot(["A", "B"], ["C", "D"])?; + println!("{}", &out); + // --8<-- [end:unpivot] + Ok(()) +} diff --git a/docs/source/user-guide/concepts/_streaming.md b/docs/source/user-guide/concepts/_streaming.md new file mode 100644 index 000000000000..c810f4d98bfb --- /dev/null +++ b/docs/source/user-guide/concepts/_streaming.md @@ -0,0 +1,55 @@ +# Streaming + + + +One additional benefit of the lazy API is that it allows queries to be executed in a streaming +manner. Instead of processing all the data at once, Polars can execute the query in batches allowing +you to process datasets that do not fit in memory. + +To tell Polars we want to execute a query in streaming mode we pass the `engine="streaming"` +argument to `collect` + +{{code_block('user-guide/concepts/streaming','streaming',['collect'])}} + +## When is streaming available? + +Streaming is still in development. We can ask Polars to execute any lazy query in streaming mode. +However, not all lazy operations support streaming. If there is an operation for which streaming is +not supported, Polars will run the query in non-streaming mode. + +Streaming is supported for many operations including: + +- `filter`, `slice`, `head`, `tail` +- `with_columns`, `select` +- `group_by` +- `join` +- `unique` +- `sort` +- `explode`, `unpivot` +- `scan_csv`, `scan_parquet`, `scan_ipc` + +This list is not exhaustive. Polars is in active development, and more operations can be added +without explicit notice. + +### Example with supported operations + +To determine which parts of your query are streaming, use the `explain` method. Below is an example +that demonstrates how to inspect the query plan. More information about the query plan can be found +in the chapter on the [Lazy API](https://docs.pola.rs/user-guide/lazy/query-plan/). + +{{code_block('user-guide/concepts/streaming', 'example',['explain'])}} + +```python exec="on" result="text" session="user-guide/streaming" +--8<-- "python/user-guide/concepts/streaming.py:import" +--8<-- "python/user-guide/concepts/streaming.py:streaming" +--8<-- "python/user-guide/concepts/streaming.py:example" +``` + +### Example with non-streaming operations + +{{code_block('user-guide/concepts/streaming', 'example2',['explain'])}} + +```python exec="on" result="text" session="user-guide/streaming" +--8<-- "python/user-guide/concepts/streaming.py:import" +--8<-- "python/user-guide/concepts/streaming.py:example2" +``` diff --git a/docs/source/user-guide/concepts/data-types-and-structures.md b/docs/source/user-guide/concepts/data-types-and-structures.md new file mode 100644 index 000000000000..93d29f60f484 --- /dev/null +++ b/docs/source/user-guide/concepts/data-types-and-structures.md @@ -0,0 +1,213 @@ +# Data types and structures + +## Data types + +Polars supports a variety of data types that fall broadly under the following categories: + +- Numeric data types: signed integers, unsigned integers, floating point numbers, and decimals. +- Nested data types: lists, structs, and arrays. +- Temporal: dates, datetimes, times, and time deltas. +- Miscellaneous: strings, binary data, Booleans, categoricals, enums, and objects. + +All types support missing values represented by the special value `null`. This is not to be +conflated with the special value `NaN` in floating number data types; see the +[section about floating point numbers](#floating-point-numbers) for more information. + +You can also find a +[full table with all data types supported in the appendix](#appendix-full-data-types-table) with +notes on when to use each data type and with links to relevant parts of the documentation. + +## Series + +The core base data structures provided by Polars are series and dataframes. A series is a +1-dimensional homogeneous data structure. By “homogeneous” we mean that all elements inside a series +have the same data type. The snippet below shows how to create a named series: + +{{code_block('user-guide/concepts/data-types-and-structures','series',['Series'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:series" +``` + +When creating a series, Polars will infer the data type from the values you provide. You can specify +a concrete data type to override the inference mechanism: + +{{code_block('user-guide/concepts/data-types-and-structures','series-dtype',['Series'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:series-dtype" +``` + +## Dataframe + +A dataframe is a 2-dimensional heterogeneous data structure that contains uniquely named series. By +holding your data in a dataframe you will be able to use the Polars API to write queries that +manipulate your data. You will be able to do this by using the +[contexts and expressions provided by Polars](expressions-and-contexts.md) that we will talk about +next. + +The snippet below shows how to create a dataframe from a dictionary of lists: + +{{code_block('user-guide/concepts/data-types-and-structures','df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:df" +``` + +### Inspecting a dataframe + +In this subsection we will show some useful methods to quickly inspect a dataframe. We will use the +dataframe we created earlier as a starting point. + +#### Head + +The function `head` shows the first rows of a dataframe. By default, you get the first 5 rows but +you can also specify the number of rows you want: + +{{code_block('user-guide/concepts/data-types-and-structures','head',['head'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:head" +``` + +#### Glimpse + +The function `glimpse` is another function that shows the values of the first few rows of a +dataframe, but formats the output differently from `head`. Here, each line of the output corresponds +to a single column, making it easier to take inspect wider dataframes: + +=== ":fontawesome-brands-python: Python" +[:material-api: `glimpse`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.glimpse.html) + +```python +--8<-- "python/user-guide/concepts/data-types-and-structures.py:glimpse" +``` + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:glimpse" +``` + +!!! info + + `glimpse` is only available for Python users. + +#### Tail + +The function `tail` shows the last rows of a dataframe. By default, you get the last 5 rows but you +can also specify the number of rows you want, similar to how `head` works: + +{{code_block('user-guide/concepts/data-types-and-structures','tail',['tail'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:tail" +``` + +#### Sample + +If you think the first or last rows of your dataframe are not representative of your data, you can +use `sample` to get an arbitrary number of randomly selected rows from the DataFrame. Note that the +rows are not necessarily returned in the same order as they appear in the dataframe: + +{{code_block('user-guide/concepts/data-types-and-structures','sample',['sample'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:sample" +``` + +#### Describe + +You can also use `describe` to compute summary statistics for all columns of your dataframe: + +{{code_block('user-guide/concepts/data-types-and-structures','describe',['describe'])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:describe" +``` + +## Schema + +When talking about data (in a dataframe or otherwise) we can refer to its schema. The schema is a +mapping of column or series names to the data types of those same columns or series. + +You can check the schema of a dataframe with `schema`: + +{{code_block('user-guide/concepts/data-types-and-structures','schema',[])}} + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:schema" +``` + +Much like with series, Polars will infer the schema of a dataframe when you create it but you can +override the inference system if needed. + +In Python, you can specify an explicit schema by using a dictionary to map column names to data +types. You can use the value `None` if you do not wish to override inference for a given column: + +```python +--8<-- "python/user-guide/concepts/data-types-and-structures.py:schema-def" +``` + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:schema-def" +``` + +If you only need to override the inference of some columns, the parameter `schema_overrides` tends +to be more convenient because it lets you omit columns for which you do not want to override the +inference: + +```python +--8<-- "python/user-guide/concepts/data-types-and-structures.py:schema_overrides" +``` + +```python exec="on" result="text" session="user-guide/data-types-and-structures" +--8<-- "python/user-guide/concepts/data-types-and-structures.py:schema_overrides" +``` + +## Data types internals + +Polars utilizes the [Arrow Columnar Format](https://arrow.apache.org/docs/format/Columnar.html) for +its data orientation. Following this specification allows Polars to transfer data to/from other +tools that also use the Arrow specification with little to no overhead. + +Polars gets most of its performance from its query engine, the optimizations it performs on your +query plans, and from the parallelization that it employs when running +[your expressions](expressions-and-contexts.md#expressions). + +## Floating point numbers + +Polars generally follows the IEEE 754 floating point standard for `Float32` and `Float64`, with some +exceptions: + +- Any `NaN` compares equal to any other `NaN`, and greater than any non-`NaN` value. +- Operations do not guarantee any particular behavior on the sign of zero or `NaN`, nor on the + payload of `NaN` values. This is not just limited to arithmetic operations, e.g. a sort or group + by operation may canonicalize all zeroes to +0 and all `NaN`s to a positive `NaN` without payload + for efficient equality checks. + +Polars always attempts to provide reasonably accurate results for floating point computations but +does not provide guarantees on the error unless mentioned otherwise. Generally speaking 100% +accurate results are infeasibly expensive to achieve (requiring much larger internal representations +than 64-bit floats), and thus some error is always to be expected. + +## Appendix: full data types table + +| Type(s) | Details | +| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `Boolean` | Boolean type that is bit packed efficiently. | +| `Int8`, `Int16`, `Int32`, `Int64` | Varying-precision signed integer types. | +| `UInt8`, `UInt16`, `UInt32`, `UInt64` | Varying-precision unsigned integer types. | +| `Float32`, `Float64` | Varying-precision signed floating point numbers. | +| `Decimal` | Decimal 128-bit type with optional precision and non-negative scale. Use this if you need fine-grained control over the precision of your floats and the operations you make on them. See [Python's `decimal.Decimal`](https://docs.python.org/3/library/decimal.html) for documentation on what a decimal data type is. | +| `String` | Variable length UTF-8 encoded string data, typically Human-readable. | +| `Binary` | Stores arbitrary, varying length raw binary data. | +| `Date` | Represents a calendar date. | +| `Time` | Represents a time of day. | +| `Datetime` | Represents a calendar date and time of day. | +| `Duration` | Represents a time duration. | +| `Array` | Arrays with a known, fixed shape per series; akin to numpy arrays. [Learn more about how arrays and lists differ and how to work with both](../expressions/lists-and-arrays.md). | +| `List` | Homogeneous 1D container with variable length. [Learn more about how arrays and lists differ and how to work with both](../expressions/lists-and-arrays.md). | +| `Object` | Wraps arbitrary Python objects. | +| `Categorical` | Efficient encoding of string data where the categories are inferred at runtime. [Learn more about how categoricals and enums differ and how to work with both](../expressions/categorical-data-and-enums.md). | +| `Enum` | Efficient ordered encoding of a set of predetermined string categories. [Learn more about how categoricals and enums differ and how to work with both](../expressions/categorical-data-and-enums.md). | +| `Struct` | Composite product type that can store multiple fields. [Learn more about the data type `Struct` in its dedicated documentation section.](../expressions/structs.md). | +| `Null` | Represents null values. | diff --git a/docs/source/user-guide/concepts/expressions-and-contexts.md b/docs/source/user-guide/concepts/expressions-and-contexts.md new file mode 100644 index 000000000000..649b446c6071 --- /dev/null +++ b/docs/source/user-guide/concepts/expressions-and-contexts.md @@ -0,0 +1,229 @@ +# Expressions and contexts + +Polars has developed its own Domain Specific Language (DSL) for transforming data. The language is +very easy to use and allows for complex queries that remain human readable. Expressions and +contexts, which will be introduced here, are very important in achieving this readability while also +allowing the Polars query engine to optimize your queries to make them run as fast as possible. + +## Expressions + +In Polars, an _expression_ is a lazy representation of a data transformation. Expressions are +modular and flexible, which means you can use them as building blocks to build more complex +expressions. Here is an example of a Polars expression: + +```python +--8<-- "python/user-guide/concepts/expressions.py:expression" +``` + +As you might be able to guess, this expression takes a column named “weight” and divides its values +by the square of the values in a column “height”, computing a person's BMI. + +The code above expresses an abstract computation that we can save in a variable, manipulate further, +or just print: + +```python +--8<-- "python/user-guide/concepts/expressions.py:print-expr" +``` + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:expression" +--8<-- "python/user-guide/concepts/expressions.py:print-expr" +``` + +Because expressions are lazy, no computations have taken place yet. That's what we need contexts +for. + +## Contexts + +Polars expressions need a _context_ in which they are executed to produce a result. Depending on the +context it is used in, the same Polars expression can produce different results. In this section, we +will learn about the four most common contexts that Polars provides[^1]: + +1. `select` +2. `with_columns` +3. `filter` +4. `group_by` + +We use the dataframe below to show how each of the contexts works. + +{{code_block('user-guide/concepts/expressions','df',[])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:df" +``` + +### `select` + +The selection context `select` applies expressions over columns. The context `select` may produce +new columns that are aggregations, combinations of other columns, or literals: + +{{code_block('user-guide/concepts/expressions','select-1',['select'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:select-1" +``` + +The expressions in a context `select` must produce series that are all the same length or they must +produce a scalar. Scalars will be broadcast to match the length of the remaining series. Literals, +like the number used above, are also broadcast. + +Note that broadcasting can also occur within expressions. For instance, consider the expression +below: + +{{code_block('user-guide/concepts/expressions','select-2',['select'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:select-2" +``` + +Both the subtraction and the division use broadcasting within the expression because the +subexpressions that compute the mean and the standard deviation evaluate to single values. + +The context `select` is very flexible and powerful and allows you to evaluate arbitrary expressions +independent of, and in parallel to, each other. This is also true of the other contexts that we will +see next. + +### `with_columns` + +The context `with_columns` is very similar to the context `select`. The main difference between the +two is that the context `with_columns` creates a new dataframe that contains the columns from the +original dataframe and the new columns according to its input expressions, whereas the context +`select` only includes the columns selected by its input expressions: + +{{code_block('user-guide/concepts/expressions','with_columns-1',['with_columns'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:with_columns-1" +``` + +Because of this difference between `select` and `with_columns`, the expressions used in a context +`with_columns` must produce series that have the same length as the original columns in the +dataframe, whereas it is enough for the expressions in the context `select` to produce series that +have the same length among them. + +### `filter` + +The context `filter` filters the rows of a dataframe based on one or more expressions that evaluate +to the Boolean data type. + +{{code_block('user-guide/concepts/expressions','filter-1',['filter'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:filter-1" +``` + +### `group_by` and aggregations + +In the context `group_by`, rows are grouped according to the unique values of the grouping +expressions. You can then apply expressions to the resulting groups, which may be of variable +lengths. + +When using the context `group_by`, you can use an expression to compute the groupings dynamically: + +{{code_block('user-guide/concepts/expressions','group_by-1',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:group_by-1" +``` + +After using `group_by` we use `agg` to apply aggregating expressions to the groups. Since in the +example above we only specified the name of a column, we get the groups of that column as lists. + +We can specify as many grouping expressions as we'd like and the context `group_by` will group the +rows according to the distinct values across the expressions specified. Here, we group by a +combination of decade of birth and whether the person is shorter than 1.7 metres: + +{{code_block('user-guide/concepts/expressions','group_by-2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:group_by-2" +``` + +The resulting dataframe, after applying aggregating expressions, contains one column per each +grouping expression on the left and then as many columns as needed to represent the results of the +aggregating expressions. In turn, we can specify as many aggregating expressions as we want: + +{{code_block('user-guide/concepts/expressions','group_by-3',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:group_by-3" +``` + +See also `group_by_dynamic` and `rolling` for other grouping contexts. + +## Expression expansion + +The last example contained two grouping expressions and three aggregating expressions, and yet the +resulting dataframe contained six columns instead of five. If we look closely, the last aggregating +expression mentioned two different columns: “weight” and “height”. + +Polars expressions support a feature called _expression expansion_. Expression expansion is like a +shorthand notation for when you want to apply the same transformation to multiple columns. As we +have seen, the expression + +```python +pl.col("weight", "height").mean().name.prefix("avg_") +``` + +will compute the mean value of the columns “weight” and “height” and will rename them as +“avg_weight” and “avg_height”, respectively. In fact, the expression above is equivalent to using +the two following expressions: + +```python +[ + pl.col("weight").mean().alias("avg_weight"), + pl.col("height").mean().alias("avg_height"), +] +``` + +In this case, this expression expands into two independent expressions that Polars can execute in +parallel. In other cases, we may not be able to know in advance how many independent expressions an +expression will unfold into. + +Consider this simple but elucidative example: + +```python +(pl.col(pl.Float64) * 1.1).name.suffix("*1.1") +``` + +This expression will multiply all columns with data type `Float64` by `1.1`. The number of columns +this applies to depends on the schema of each dataframe. In the case of the dataframe we have been +using, it applies to two columns: + +{{code_block('user-guide/concepts/expressions','expression-expansion-1',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:expression-expansion-1" +``` + +In the case of the dataframe `df2` below, the same expression expands to 0 columns because no column +has the data type `Float64`: + +{{code_block('user-guide/concepts/expressions','expression-expansion-2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/concepts/expressions-and-contexts" +--8<-- "python/user-guide/concepts/expressions.py:expression-expansion-2" +``` + +It is equally easy to imagine a scenario where the same expression would expand to dozens of +columns. + +Next, you will learn about +[the lazy API and the function `explain`](lazy-api.md#previewing-the-query-plan), which you can use +to preview what an expression will expand to given a schema. + +## Conclusion + +Because expressions are lazy, when you use an expression inside a context Polars can try to simplify +your expression before running the data transformation it expresses. Separate expressions within a +context are embarrassingly parallel and Polars will take advantage of that, while also parallelizing +expression execution when using expression expansion. Further performance gains can be obtained when +using [the lazy API of Polars](lazy-api.md), which is introduced next. + +We have only scratched the surface of the capabilities of expressions. There are a ton more +expressions and they can be combined in a variety of ways. See the +[section on expressions](../expressions/index.md) for a deeper dive on the different types of +expressions available. + +[^1]: There are additional List and SQL contexts which are covered later in this guide. But for +simplicity, we leave them out of scope for now. diff --git a/docs/source/user-guide/concepts/index.md b/docs/source/user-guide/concepts/index.md new file mode 100644 index 000000000000..b926828716f4 --- /dev/null +++ b/docs/source/user-guide/concepts/index.md @@ -0,0 +1,8 @@ +# Concepts + +This chapter describes the core concepts of the Polars API. Understanding these will help you +optimise your queries on a daily basis. We will cover the following topics: + +- [Data types and structures](data-types-and-structures.md) +- [Expressions and contexts](expressions-and-contexts.md) +- [Lazy API](lazy-api.md) diff --git a/docs/source/user-guide/concepts/lazy-api.md b/docs/source/user-guide/concepts/lazy-api.md new file mode 100644 index 000000000000..349bd058c2bb --- /dev/null +++ b/docs/source/user-guide/concepts/lazy-api.md @@ -0,0 +1,80 @@ +# Lazy API + +Polars supports two modes of operation: lazy and eager. The examples so far have used the eager API, +in which the query is executed immediately. In the lazy API, the query is only evaluated once it is +_collected_. Deferring the execution to the last minute can have significant performance advantages +and is why the lazy API is preferred in most cases. Let us demonstrate this with an example: + +{{code_block('user-guide/concepts/lazy-vs-eager','eager',['read_csv'])}} + +In this example we use the eager API to: + +1. Read the iris [dataset](https://archive.ics.uci.edu/dataset/53/iris). +1. Filter the dataset based on sepal length. +1. Calculate the mean of the sepal width per species. + +Every step is executed immediately returning the intermediate results. This can be very wasteful as +we might do work or load extra data that is not being used. If we instead used the lazy API and +waited on execution until all the steps are defined then the query planner could perform various +optimizations. In this case: + +- Predicate pushdown: Apply filters as early as possible while reading the dataset, thus only + reading rows with sepal length greater than 5. +- Projection pushdown: Select only the columns that are needed while reading the dataset, thus + removing the need to load additional columns (e.g., petal length and petal width). + +{{code_block('user-guide/concepts/lazy-vs-eager','lazy',['scan_csv'])}} + +These will significantly lower the load on memory & CPU thus allowing you to fit bigger datasets in +memory and process them faster. Once the query is defined you call `collect` to inform Polars that +you want to execute it. You can +[learn more about the lazy API in its dedicated chapter](../lazy/index.md). + +!!! info "Eager API" + + In many cases the eager API is actually calling the lazy API under the hood and immediately collecting the result. This has the benefit that within the query itself optimization(s) made by the query planner can still take place. + +## When to use which + +In general, the lazy API should be preferred unless you are either interested in the intermediate +results or are doing exploratory work and don't know yet what your query is going to look like. + +## Previewing the query plan + +When using the lazy API you can use the function `explain` to ask Polars to create a description of +the query plan that will be executed once you collect the results. This can be useful if you want to +see what types of optimizations Polars performs on your queries. We can ask Polars to explain the +query `q` we defined above: + +{{code_block('user-guide/concepts/lazy-vs-eager','explain',['explain'])}} + +```python exec="on" result="text" session="user-guide/concepts/lazy-api" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:import" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:lazy" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:explain" +``` + +Immediately, we can see in the explanation that Polars did apply predicate pushdown, as it is only +reading rows where the sepal length is greater than 5, and it did apply projection pushdown, as it +is only reading the columns that are needed by the query. + +The function `explain` can also be used to see how expression expansion will unfold in the context +of a given schema. Consider the example expression from the +[section on expression expansion](expressions-and-contexts.md#expression-expansion): + +```python +(pl.col(pl.Float64) * 1.1).name.suffix("*1.1") +``` + +We can use `explain` to see how this expression would evaluate against an arbitrary schema: + +=== ":fontawesome-brands-python: Python" +[:material-api: `explain`](https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.explain.html) + +```python +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:explain-expression-expansion" +``` + +```python exec="on" result="text" session="user-guide/concepts/lazy-api" +--8<-- "python/user-guide/concepts/lazy-vs-eager.py:explain-expression-expansion" +``` diff --git a/docs/source/user-guide/ecosystem.md b/docs/source/user-guide/ecosystem.md new file mode 100644 index 000000000000..ba353cf566ac --- /dev/null +++ b/docs/source/user-guide/ecosystem.md @@ -0,0 +1,114 @@ +# Ecosystem + +## Introduction + +On this page you can find a non-exhaustive list of libraries and tools that support Polars. As the +data ecosystem is evolving fast, more libraries will likely support Polars in the future. One of the +main drivers is that Polars makes adheres its memory layout to the `Apache Arrow` spec. + +### Table of contents: + +- [Apache Arrow](#apache-arrow) +- [Data visualisation](#data-visualisation) +- [IO](#io) +- [Machine learning](#machine-learning) +- [Other](#other) + +--- + +### Apache Arrow + +[Apache Arrow](https://arrow.apache.org/) enables zero-copy reads of data within the same process, +meaning that data can be directly accessed in its in-memory format without the need for copying or +serialisation. This enhances performance when integrating with different tools using Apache Arrow. +Polars is compatible with a wide range of libraries that also make use of Apache Arrow, like Pandas +and DuckDB. + +### Data visualisation + +See the [dedicated visualization section](misc/visualization.md). + +### IO + +#### Delta Lake + +The [Delta Lake](https://github.com/delta-io/delta-rs) project aims to unlock the power of the +Deltalake for as many users and projects as possible by providing native low-level APIs aimed at +developers and integrators, as well as a high-level operations API that lets you query, inspect, and +operate your Delta Lake with ease. Delta Lake builds on the native Polars Parquet reader allowing +you to write standard Polars queries against a DeltaTable. + +Read how to use Delta Lake with Polars +[at Delta Lake](https://delta-io.github.io/delta-rs/integrations/delta-lake-polars/#reading-a-delta-lake-table-with-polars). + +### Machine Learning + +#### Scikit Learn + +The [Scikit Learn](https://scikit-learn.org/stable/) machine learning package accepts a Polars +`DataFrame` as input/output to all transformers and as input to models. +[skrub](https://skrub-data.org) helps encoding DataFrames for scikit-learn estimators (eg converting +dates or strings). + +#### XGBoost & LightGBM + +XGBoost and LightGBM are gradient boosting packages for doing regression or classification on +tabular data. +[XGBoost accepts Polars `DataFrame` and `LazyFrame` as input](https://xgboost.readthedocs.io/en/latest/python/python_intro.html) +while LightGBM accepts Polars `DataFrame` as input. + +#### Time series forecasting + +The +[Nixtla time series forecasting packages](https://nixtlaverse.nixtla.io/statsforecast/docs/getting-started/getting_started_complete_polars.html) +accept a Polars `DataFrame` as input. + +#### Hugging Face + +Hugging Face is a platform for working with machine learning datasets and models. +[Polars can be used to work with datasets downloaded from Hugging Face](io/hugging-face.md). + +#### Deep learning frameworks + +A `DataFrame` can be transformed +[into a PyTorch format using `to_torch`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.to_torch.html) +or +[into a JAX format using `to_jax`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.to_jax.html). + +### Other + +#### DuckDB + +[DuckDB](https://duckdb.org) is a high-performance analytical database system. It is designed to be +fast, reliable, portable, and easy to use. DuckDB provides a rich SQL dialect, with support far +beyond basic SQL. DuckDB supports arbitrary and nested correlated subqueries, window functions, +collations, complex types (arrays, structs), and more. Read about integration with Polars +[on the DuckDB website](https://duckdb.org/docs/guides/python/polars). + +#### Great Tables + +With [Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html) anyone can make +wonderful-looking tables in Python. Here is a +[blog post](https://posit-dev.github.io/great-tables/blog/polars-styling/) on how to use Great +Tables with Polars. + +#### LanceDB + +[LanceDB](https://lancedb.com/) is a developer-friendly, serverless vector database for AI +applications. They have added a direct integration with Polars. LanceDB can ingest Polars +dataframes, return results as polars dataframes, and export the entire table as a polars lazyframe. +You can find a quick tutorial in their blog +[LanceDB + Polars](https://blog.lancedb.com/lancedb-polars-2d5eb32a8aa3) + +#### Mage + +[Mage](https://www.mage.ai) is an open-source data pipeline tool for transforming and integrating +data. Learn about integration between Polars and Mage at +[docs.mage.ai](https://docs.mage.ai/integrations/polars). + +#### marimo + +[marimo](https://marimo.io) is a reactive notebook for Python and SQL that models notebooks as +dataflow graphs. It offers built-in support for Polars, allowing seamless integration of Polars +dataframes in an interactive, reactive environment - such as displaying rich Polars tables, no-code +transformations of Polars dataframes, or selecting points on a Polars-backed reactive chart. diff --git a/docs/source/user-guide/expressions/aggregation.md b/docs/source/user-guide/expressions/aggregation.md new file mode 100644 index 000000000000..c162b4094970 --- /dev/null +++ b/docs/source/user-guide/expressions/aggregation.md @@ -0,0 +1,149 @@ +# Aggregation + +The Polars [context](../concepts/expressions-and-contexts.md#contexts) `group_by` lets you apply +expressions on subsets of columns, as defined by the unique values of the column over which the data +is grouped. This is a very powerful capability that we explore in this section of the user guide. + +We start by reading in a +[US congress `dataset`](https://github.com/unitedstates/congress-legislators): + +{{code_block('user-guide/expressions/aggregation','dataframe',['DataFrame','Categorical'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:dataframe" +``` + +## Basic aggregations + +You can easily apply multiple expressions to your aggregated values. Simply list all of the +expressions you want inside the function `agg`. There is no upper bound on the number of +aggregations you can do and you can make any combination you want. In the snippet below we will +group the data based on the column “first_name” and then we will apply the following aggregations: + +- count the number of rows in the group (which means we count how many people in the data set have + each unique first name); +- combine the values of the column “gender” into a list by referring the column but omitting an + aggregate function; and +- get the first value of the column “last_name” within the group. + +After computing the aggregations, we immediately sort the result and limit it to the top five rows +so that we have a nice summary overview: + +{{code_block('user-guide/expressions/aggregation','basic',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:basic" +``` + +It's that easy! Let's turn it up a notch. + +## Conditionals + +Let's say we want to know how many delegates of a state are “Pro” or “Anti” administration. We can +query that directly in the aggregation without the need for a `lambda` or grooming the dataframe: + +{{code_block('user-guide/expressions/aggregation','conditional',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:conditional" +``` + +## Filtering + +We can also filter the groups. Let's say we want to compute a mean per group, but we don't want to +include all values from that group, and we also don't want to actually filter the rows from the +dataframe because we need those rows for another aggregation. + +In the example below we show how this can be done. + +!!! note + + Note that we can define Python functions for clarity. + These functions don't cost us anything because they return Polars expressions, we don't apply a custom function over a series during runtime of the query. + Of course, you can write functions that return expressions in Rust, too. + +{{code_block('user-guide/expressions/aggregation','filter',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:filter" +``` + +Do the average age values look nonsensical? That's because we are working with historical data that +dates back to the 1800s and we are doing our computations assuming everyone represented in the +dataset is still alive and kicking. + +## Nested grouping + +The two previous queries could have been done with a nested `group_by`, but that wouldn't have let +us show off some of these features. 😉 To do a nested `group_by`, simply list the columns that will +be used for grouping. + +First, we use a nested `group_by` to figure out how many delegates of a state are “Pro” or “Anti” +administration: + +{{code_block('user-guide/expressions/aggregation','nested',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:nested" +``` + +Next, we use a nested `group_by` to compute the average age of delegates per state and per gender: + +{{code_block('user-guide/expressions/aggregation','filter-nested',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:filter-nested" +``` + +Note that we get the same results but the format of the data is different. Depending on the +situation, one format may be more suitable than the other. + +## Sorting + +It is common to see a dataframe being sorted for the sole purpose of managing the ordering during a +grouping operation. Let's say that we want to get the names of the oldest and youngest politicians +per state. We could start by sorting and then grouping: + +{{code_block('user-guide/expressions/aggregation','sort',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort" +``` + +However, if we also want to sort the names alphabetically, we need to perform an extra sort +operation. Luckily, we can sort in a `group_by` context without changing the sorting of the +underlying dataframe: + +{{code_block('user-guide/expressions/aggregation','sort2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort2" +``` + +We can even sort a column with the order induced by another column, and this also works inside the +context `group_by`. This modification to the previous query lets us check if the delegate with the +first name is male or female: + +{{code_block('user-guide/expressions/aggregation','sort3',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort3" +``` + +## Do not kill parallelization + +!!! warning "Python users only" + + The following section is specific to Python, and doesn't apply to Rust. + Within Rust, blocks and closures (lambdas) can, and will, be executed concurrently. + +Python is generally slower than Rust. Besides the overhead of running “slow” bytecode, Python has to +remain within the constraints of the Global Interpreter Lock (GIL). This means that if you were to +use a `lambda` or a custom Python function to apply during a parallelized phase, Polars' speed is +capped running Python code, preventing any multiple threads from executing the function. + +Polars will try to parallelize the computation of the aggregating functions over the groups, so it +is recommended that you avoid using `lambda`s and custom Python functions as much as possible. +Instead, try to stay within the realm of the Polars expression API. This is not always possible, +though, so if you want to learn more about using `lambda`s you can go +[the user guide section on using user-defined functions](user-defined-python-functions.md). diff --git a/docs/source/user-guide/expressions/athletes_over_country.svg b/docs/source/user-guide/expressions/athletes_over_country.svg new file mode 100644 index 000000000000..3ab07cc73ef7 --- /dev/null +++ b/docs/source/user-guide/expressions/athletes_over_country.svg @@ -0,0 +1,84 @@ + + + + + + A + + B + + C + + D + + E + + F + + PT + + NL + + NL + + PT + + PT + + NL + + 6 + + 1 + + 5 + + 4 + + 2 + + 3 + + E + + B + + F + + D + + A + + C + + PT + + NL + + NL + + PT + + PT + + NL + + 2 + + 1 + + 3 + + 4 + + 6 + + 5 + + + + NL + NL + diff --git a/docs/source/user-guide/expressions/athletes_over_country_explode.svg b/docs/source/user-guide/expressions/athletes_over_country_explode.svg new file mode 100644 index 000000000000..d49db911465c --- /dev/null +++ b/docs/source/user-guide/expressions/athletes_over_country_explode.svg @@ -0,0 +1,85 @@ + + + + + + A + + B + + C + + D + + E + + F + + PT + + NL + + NL + + PT + + PT + + NL + + 6 + + 1 + + 5 + + 4 + + 2 + + 3 + + E + + B + + F + + D + + A + + C + + PT + + NL + + NL + + PT + + PT + + NL + + 2 + + 1 + + 3 + + 4 + + 6 + + 5 + + NL + NL + + + NL + diff --git a/docs/source/user-guide/expressions/basic-operations.md b/docs/source/user-guide/expressions/basic-operations.md new file mode 100644 index 000000000000..b3eaa6364850 --- /dev/null +++ b/docs/source/user-guide/expressions/basic-operations.md @@ -0,0 +1,141 @@ +# Basic operations + +This section shows how to do basic operations on dataframe columns, like do basic arithmetic +calculations, perform comparisons, and other general-purpose operations. We will use the following +dataframe for the examples that follow: + +{{code_block('user-guide/expressions/operations', 'dataframe', ['DataFrame'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:dataframe" +``` + +## Basic arithmetic + +Polars supports basic arithmetic between series of the same length, or between series and literals. +When literals are mixed with series, the literals are broadcast to match the length of the series +they are being used with. + +{{code_block('user-guide/expressions/operations', 'arithmetic', ['operators'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:arithmetic" +``` + +The example above shows that when an arithmetic operation takes `null` as one of its operands, the +result is `null`. + +Polars uses operator overloading to allow you to use your language's native arithmetic operators +within your expressions. If you prefer, in Python you can use the corresponding named functions, as +the snippet below demonstrates: + +```python +--8<-- "python/user-guide/expressions/operations.py:operator-overloading" +``` + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:operator-overloading" +``` + +## Comparisons + +Like with arithmetic operations, Polars supports comparisons via the overloaded operators or named +functions: + +{{code_block('user-guide/expressions/operations','comparison',['operators'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:comparison" +``` + +## Boolean and bitwise operations + +Depending on the language, you may use the operators `&`, `|`, and `~`, for the Boolean operations +“and”, “or”, and “not”, respectively, or the functions of the same name: + +{{code_block('user-guide/expressions/operations', 'boolean', ['operators'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:boolean" +``` + +??? info "Python trivia" + + The Python functions are called `and_`, `or_`, and `not_`, because the words `and`, `or`, and `not` are reserved keywords in Python. + Similarly, we cannot use the keywords `and`, `or`, and `not`, as the Boolean operators because these Python keywords will interpret their operands in the context of Truthy and Falsy through the dunder method `__bool__`. + Thus, we overload the bitwise operators `&`, `|`, and `~`, as the Boolean operators because they are the second best choice. + +These operators/functions can also be used for the respective bitwise operations, alongside the +bitwise operator `^` / function `xor`: + +{{code_block('user-guide/expressions/operations', 'bitwise', [])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:bitwise" +``` + +## Counting (unique) values + +Polars has two functions to count the number of unique values in a series. The function `n_unique` +can be used to count the exact number of unique values in a series. However, for very large data +sets, this operation can be quite slow. In those cases, if an approximation is good enough, you can +use the function `approx_n_unique` that uses the algorithm +[HyperLogLog++](https://en.wikipedia.org/wiki/HyperLogLog) to estimate the result. + +The example below shows an example series where the `approx_n_unique` estimation is wrong by 0.9%: + +{{code_block('user-guide/expressions/operations', 'count', ['n_unique', 'approx_n_unique'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:count" +``` + +You can get more information about the unique values and their counts with the function +`value_counts`, that Polars also provides: + +{{code_block('user-guide/expressions/operations', 'value_counts', ['value_counts'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:value_counts" +``` + +The function `value_counts` returns the results in +[structs, a data type that we will explore in a later section](structs.md). + +Alternatively, if you only need a series with the unique values or a series with the unique counts, +they are one function away: + +{{code_block('user-guide/expressions/operations', 'unique_counts', ['unique', 'unique_counts'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:unique_counts" +``` + +Note that we need to specify `maintain_order=True` in the function `unique` so that the order of the +results is consistent with the order of the results in `unique_counts`. See the API reference for +more information. + +## Conditionals + +Polars supports something akin to a ternary operator through the function `when`, which is followed +by one function `then` and an optional function `otherwise`. + +The function `when` accepts a predicate expression. The values that evaluate to `True` are replaced +by the corresponding values of the expression inside the function `then`. The values that evaluate +to `False` are replaced by the corresponding values of the expression inside the function +`otherwise` or `null`, if `otherwise` is not provided. + +The example below applies one step of the +[Collatz conjecture](https://en.wikipedia.org/wiki/Collatz_conjecture) to the numbers in the column +“nrs”: + +{{code_block('user-guide/expressions/operations', 'collatz', ['when'])}} + +```python exec="on" result="text" session="expressions/operations" +--8<-- "python/user-guide/expressions/operations.py:collatz" +``` + +You can also emulate a chain of an arbitrary number of conditionals, akin to Python's `elif` +statement, by chaining an arbitrary number of consecutive blocks of `.when(...).then(...)`. In those +cases, and for each given value, Polars will only consider a replacement expression that is deeper +within the chain if the previous predicates all failed for that value. diff --git a/docs/source/user-guide/expressions/casting.md b/docs/source/user-guide/expressions/casting.md new file mode 100644 index 000000000000..4642ee1148f3 --- /dev/null +++ b/docs/source/user-guide/expressions/casting.md @@ -0,0 +1,130 @@ +# Casting + +Casting converts the [underlying data type of a column](../concepts/data-types-and-structures.md) to +a new one. Casting is available through the function `cast`. + +The function `cast` includes a parameter `strict` that determines how Polars behaves when it +encounters a value that cannot be converted from the source data type to the target data type. The +default behaviour is `strict=True`, which means that Polars will thrown an error to notify the user +of the failed conversion while also providing details on the values that couldn't be cast. On the +other hand, if `strict=False`, any values that cannot be converted to the target data type will be +quietly converted to `null`. + +## Basic example + +Let's take a look at the following dataframe which contains both integers and floating point +numbers: + +{{code_block('user-guide/expressions/casting', 'dfnum', [])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:dfnum" +``` + +To perform casting operations between floats and integers, or vice versa, we use the function +`cast`: + +{{code_block('user-guide/expressions/casting','castnum',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:castnum" +``` + +Note that floating point numbers are truncated when casting to an integer data type. + +## Downcasting numerical data types + +You can reduce the memory footprint of a column by changing the precision associated with its +numeric data type. As an illustration, the code below demonstrates how casting from `Int64` to +`Int16` and from `Float64` to `Float32` can be used to lower memory usage: + +{{code_block('user-guide/expressions/casting','downcast',['cast', 'estimated_size'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:downcast" +``` + +When performing downcasting it is crucial to ensure that the chosen number of bits (such as 64, 32, +or 16) is sufficient to accommodate the largest and smallest numbers in the column. For example, a +32-bit signed integer (`Int32`) represents integers between -2147483648 and 2147483647, inclusive, +while an 8-bit signed integer only represents integers between -128 and 127, inclusive. Attempting +to downcast to a data type with insufficient precision results in an error thrown by Polars: + +{{code_block('user-guide/expressions/casting','overflow',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:overflow" +``` + +If you set the parameter `strict` to `False` the overflowing/underflowing values are converted to +`null`: + +{{code_block('user-guide/expressions/casting','overflow2',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:overflow2" +``` + +## Converting strings to numeric data types + +Strings that represent numbers can be converted to the appropriate data types via casting. The +opposite conversion is also possible: + +{{code_block('user-guide/expressions/casting','strings',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:strings" +``` + +In case the column contains a non-numerical value, or a poorly formatted one, Polars will throw an +error with details on the conversion error. You can set `strict=False` to circumvent the error and +get a `null` value instead. + +{{code_block('user-guide/expressions/casting','strings2',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:strings2" +``` + +## Booleans + +Booleans can be expressed as either 1 (`True`) or 0 (`False`). It's possible to perform casting +operations between a numerical data type and a Boolean, and vice versa. + +When converting numbers to Booleans, the number 0 is converted to `False` and all other numbers are +converted to `True`, in alignment with Python's Truthy and Falsy values for numbers: + +{{code_block('user-guide/expressions/casting','bool',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:bool" +``` + +## Parsing / formatting temporal data types + +All temporal data types are represented internally as the number of time units elapsed since a +reference moment, usually referred to as the epoch. For example, values of the data type `Date` are +stored as the number of days since the epoch. For the data type `Datetime` the time unit is the +microsecond (us) and for `Time` the time unit is the nanosecond (ns). + +Casting between numerical types and temporal data types is allowed and exposes this relationship: + +{{code_block('user-guide/expressions/casting','dates',['cast'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:dates" +``` + +To format temporal data types as strings we can use the function `dt.to_string` and to parse +temporal data types from strings we can use the function `str.to_datetime`. Both functions adopt the +[chrono format syntax](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) for +formatting. + +{{code_block('user-guide/expressions/casting','dates2',['dt.to_string','str.to_date'])}} + +```python exec="on" result="text" session="user-guide/casting" +--8<-- "python/user-guide/expressions/casting.py:dates2" +``` + +It's worth noting that `str.to_datetime` features additional options that support timezone +functionality. Refer to the API documentation for further information. diff --git a/docs/source/user-guide/expressions/categorical-data-and-enums.md b/docs/source/user-guide/expressions/categorical-data-and-enums.md new file mode 100644 index 000000000000..64ce602b40a2 --- /dev/null +++ b/docs/source/user-guide/expressions/categorical-data-and-enums.md @@ -0,0 +1,541 @@ +# Categorical data and enums + +A column that holds string values that can only take on one of a limited number of possible values +is a column that holds [categorical data](https://en.wikipedia.org/wiki/Categorical_variable). +Usually, the number of possible values is much smaller than the length of the column. Some typical +examples include your nationality, the operating system of your computer, or the license that your +favorite open source project uses. + +When working with categorical data you can use Polars' dedicated types, `Categorical` and `Enum`, to +make your queries more performant. Now, we will see what are the differences between the two data +types `Categorical` and `Enum` and when you should use one data type or the other. We also include +some notes on +[why the data types `Categorical` and `Enum` are more efficient than using the plain string values](#performance-considerations-on-categorical-data-types) +in the end of this user guide section. + +## `Enum` vs `Categorical` + +In short, you should prefer `Enum` over `Categorical` whenever possible. When the categories are +fixed and known up front, use `Enum`. When you don't know the categories or they are not fixed then +you must use `Categorical`. In case your requirements change along the way you can always cast from +one to the other. + +## Data type `Enum` + +### Creating an `Enum` + +The data type `Enum` is an ordered categorical data type. To use the data type `Enum` you have to +specify the categories in advance to create a new data type that is a variant of an `Enum`. Then, +when creating a new series, a new dataframe, or when casting a string column, you can use that +`Enum` variant. + +{{code_block('user-guide/expressions/categoricals', 'enum-example', ['Enum'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:enum-example" +``` + +### Invalid values + +Polars will raise an error if you try to specify a data type `Enum` whose categories do not include +all the values present: + +{{code_block('user-guide/expressions/categoricals', 'enum-wrong-value', ['Enum'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:enum-wrong-value" +``` + +If you are in a position where you cannot know all of the possible values in advance and erroring on +unknown values is semantically wrong, you may need to +[use the data type `Categorical`](#data-type-categorical). + +### Category ordering and comparison + +The data type `Enum` is ordered and the order is induced by the order in which you specify the +categories. The example below uses log levels as an example of where an ordered `Enum` is useful: + +{{code_block('user-guide/expressions/categoricals', 'log-levels', ['Enum'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:log-levels" +``` + +This example shows that we can compare `Enum` values with a string, but this only works if the +string matches one of the `Enum` values. If we compared the column “level” with any string other +than `"debug"`, `"info"`, `"warning"`, or `"error"`, Polars would raise an exception. + +Columns with the data type `Enum` can also be compared with other columns that have the same data +type `Enum` or columns that hold strings, but only if all the strings are valid `Enum` values. + +## Data type `Categorical` + +The data type `Categorical` can be seen as a more flexible version of `Enum`. + +### Creating a `Categorical` series + +To use the data type `Categorical`, you can cast a column of strings or specify `Categorical` as the +data type of a series or dataframe column: + +{{code_block('user-guide/expressions/categoricals', 'categorical-example', ['Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:categorical-example" +``` + +Having Polars infer the categories for you may sound strictly better than listing the categories +beforehand, but this inference comes with a performance cost. That is why, whenever possible, you +should use `Enum`. You can learn more by +[reading the subsection about the data type `Categorical` and its encodings](#data-type-categorical-and-encodings). + +### Lexical comparison with strings + +When comparing a `Categorical` column with a string, Polars will perform a lexical comparison: + +{{code_block('user-guide/expressions/categoricals', 'categorical-comparison-string', +['Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:categorical-comparison-string" +``` + +You can also compare a column of strings with your `Categorical` column, and the comparison will +also be lexical: + +{{code_block('user-guide/expressions/categoricals', 'categorical-comparison-string-column', +['Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:categorical-comparison-string-column" +``` + +Although it is possible to compare a string column with a categorical column, it is typically more +efficient to compare two categorical columns. We will see how to do that next. + +### Comparing `Categorical` columns and the string cache + +You are told that comparing columns with the data type `Categorical` is more efficient than if one +of them is a string column. So, you change your code so that the second column is also a categorical +column and then you perform your comparison... But Polars raises an exception: + +{{code_block('user-guide/expressions/categoricals', 'categorical-comparison-categorical-column', +['Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:categorical-comparison-categorical-column" +``` + +By default, the values in columns with the data type `Categorical` are +[encoded in the order they are seen in the column](#encodings), and independently from other +columns, which means that Polars cannot compare efficiently two categorical columns that were +created independently. + +Enabling the Polars string cache and creating the columns with the cache enabled fixes this issue: + +{{code_block('user-guide/expressions/categoricals', 'stringcache-categorical-equality', +['StringCache', 'Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:stringcache-categorical-equality" +``` + +Note that using [the string cache comes at a performance cost](#using-the-global-string-cache). + +### Combining `Categorical` columns + +The string cache is also useful in any operation that combines or mixes two columns with the data +type `Categorical` in any way. An example of this is when +[concatenating two dataframes vertically](../getting-started.md#concatenating-dataframes): + +{{code_block('user-guide/expressions/categoricals', 'concatenating-categoricals', ['StringCache', +'Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:concatenating-categoricals" +``` + +In this case, Polars issues a warning complaining about an expensive reenconding that implies taking +a performance hit. Polars then suggests using the data type `Enum` if possible, or using the string +cache. To understand the issue with this operation and why Polars raises an error, please read the +final section about +[the performance considerations of using categorical data types](#performance-considerations-on-categorical-data-types). + +### Comparison between `Categorical` columns is not lexical + +When comparing two columns with data type `Categorical`, Polars does not perform lexical comparison +between the values by default. If you want lexical ordering, you need to specify so when creating +the column: + +{{code_block('user-guide/expressions/categoricals', 'stringcache-categorical-comparison-lexical', +['StringCache', 'Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:stringcache-categorical-comparison-lexical" +``` + +Otherwise, the order is inferred together with the values: + +{{code_block('user-guide/expressions/categoricals', 'stringcache-categorical-comparison-physical', +['StringCache', 'Categorical'])}} + +```python exec="on" result="text" session="expressions/categoricals" +--8<-- "python/user-guide/expressions/categoricals.py:stringcache-categorical-comparison-physical" +``` + +## Performance considerations on categorical data types + +This part of the user guide explains + +- why categorical data types are more performant than the string literals; and +- why Polars needs a string cache when doing some operations with the data type `Categorical`. + +### Encodings + +Categorical data represents string data where the values in the column have a finite set of values +(usually way smaller than the length of the column). Storing these values as plain strings is a +waste of memory and performance as we will be repeating the same string over and over again. +Additionally, in operations like joins we have to perform expensive string comparisons. + +Categorical data types like `Enum` and `Categorical` let you encode the string values in a cheaper +way, establishing a relationship between a cheap encoding value and the original string literal. + +As an example of a sensible encoding, Polars could choose to represent the finite set of categories +as positive integers. With that in mind, the diagram below shows a regular string column and a +possible representation of a Polars column with the categorical data type: + + + + + + +
String Column Categorical Column
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Series
Polar
Panda
Brown
Panda
Brown
Brown
Polar
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Physical
0
1
2
1
2
2
0
+ +
+ + + + + + + + + + + + + + + + + +
Categories
Polar
Panda
Brown
+
+
+ +The physical `0` in this case encodes (or maps) to the value 'Polar', the value `1` encodes to +'Panda', and the value `2` to 'Brown'. This encoding has the benefit of only storing the string +values once. Additionally, when we perform operations (e.g. sorting, counting) we can work directly +on the physical representation which is much faster than the working with string data. + +### Encodings for the data type `Enum` are global + +When working with the data type `Enum` we specify the categories in advance. This way, Polars can +ensure different columns and even different datasets have the same encoding and there is no need for +expensive re-encoding or cache lookups. + +### Data type `Categorical` and encodings + +The fact that the categories for the data type `Categorical` are inferred come at a cost. The main +cost here is that we have no control over our encodings. + +Consider the following scenario where we append the following two categorical series: + +{{code_block('user-guide/concepts/data-types/categoricals','append',[])}} + +Polars encodes the string values in the order they appear. So, the series would look like this: + + + + + + +
cat_series cat2_series
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + +
Physical
0
1
2
2
0
+ +
+ + + + + + + + + + + + + + + + + +
Categories
Polar
Panda
Brown
+ +
+
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + +
Physical
0
1
1
2
2
+ +
+ + + + + + + + + + + + + + + + + + +
Categories
Panda
Brown
Polar
+ +
+
+ +Combining the series becomes a non-trivial task which is expensive as the physical value of `0` +represents something different in both series. Polars does support these types of operations for +convenience, however these should be avoided due to its slower performance as it requires making +both encodings compatible first before doing any merge operations. + +### Using the global string cache + +One way to handle this reencoding problem is to enable the string cache. Under the string cache, the +diagram would instead look like this: + + + + + + +
SeriesString cache
+ + + + + + +
cat_seriescat2_series
+ + + + + + + + + + + + + + + + + + + + + + + +
Physical
0
1
2
2
0
+
+ + + + + + + + + + + + + + + + + + + + + + + +
Physical
1
2
2
0
0
+
+
+ + + + + + + + + + + + + + + + + +
Categories
Polar
Panda
Brown
+
+ +When you enable the string cache, strings are no longer encoded in the order they appear on a +per-column basis. Instead, the encoding is shared across columns. The value 'Polar' will always be +encoded by the same value for all categorical columns created under the string cache. Merge +operations (e.g. appends, joins) become cheap again as there is no need to make the encodings +compatible first, solving the problem we had above. + +However, the string cache does come at a small performance hit during construction of the series as +we need to look up or insert the string values in the cache. Therefore, it is preferred to use the +data type `Enum` if you know your categories in advance. diff --git a/docs/source/user-guide/expressions/expression-expansion.md b/docs/source/user-guide/expressions/expression-expansion.md new file mode 100644 index 000000000000..23c1bdb3fb25 --- /dev/null +++ b/docs/source/user-guide/expressions/expression-expansion.md @@ -0,0 +1,413 @@ +# Expression expansion + +As you've seen in +[the section about expressions and contexts](../concepts/expressions-and-contexts.md), expression +expansion is a feature that enables you to write a single expression that can expand to multiple +different expressions, possibly depending on the schema of the context in which the expression is +used. + +This feature isn't just decorative or syntactic sugar. It allows for a very powerful application of +[DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself) principles in your code: a single +expression that specifies multiple columns expands into a list of expressions, which means you can +write one single expression and reuse the computation that it represents. + +In this section we will show several forms of expression expansion and we will be using the +dataframe that you can see below for that effect: + +{{code_block('user-guide/expressions/expression-expansion', 'df', [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:df" +``` + +## Function `col` + +The function `col` is the most common way of making use of expression expansion features in Polars. +Typically used to refer to one column of a dataframe, in this section we explore other ways in which +you can use `col` (or its variants, when in Rust). + +### Explicit expansion by column name + +The simplest form of expression expansion happens when you provide multiple column names to the +function `col`. + +The example below uses a single function `col` with multiple column names to convert the values in +USD to EUR: + +{{code_block('user-guide/expressions/expression-expansion', 'col-with-names', ['col'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:col-with-names" +``` + +When you list the column names you want the expression to expand to, you can predict what the +expression will expand to. In this case, the expression that does the currency conversion is +expanded to a list of five expressions: + +{{code_block('user-guide/expressions/expression-expansion', 'expression-list', ['col'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:expression-list" +``` + +### Expansion by data type + +We had to type five column names in the previous example but the function `col` can also +conveniently accept one or more data types. If you provide data types instead of column names, the +expression is expanded to all columns that match one of the data types provided. + +The example below performs the exact same computation as before: + +{{code_block('user-guide/expressions/expression-expansion', 'col-with-dtype', [], ['col'], +['dtype_col'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:col-with-dtype" +``` + +When we use a data type with expression expansion we cannot know, beforehand, how many columns a +single expression will expand to. We need the schema of the input dataframe if we want to determine +what is the final list of expressions that is to be applied. + +If we weren't sure about whether the price columns where of the type `Float64` or `Float32`, we +could specify both data types: + +{{code_block('user-guide/expressions/expression-expansion', 'col-with-dtypes', [], ['col'], +['dtype_cols'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:col-with-dtypes" +``` + +### Expansion by pattern matching + +You can also use regular expressions to specify patterns that are used to match the column names. To +distinguish between a regular column name and expansion by pattern matching, regular expressions +start and end with `^` and `$`, respectively. This also means that the pattern must match against +the whole column name string. + +Regular expressions can be mixed with regular column names: + +{{code_block('user-guide/expressions/expression-expansion', 'col-with-regex', ['col'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:col-with-regex" +``` + +### Arguments cannot be of mixed types + +In Python, the function `col` accepts an arbitrary number of strings (as +[column names](#explicit-expansion-by-column-name) or as +[regular expressions](#expansion-by-pattern-matching)) or an arbitrary number of data types, but you +cannot mix both in the same function call: + +```python +--8<-- "python/user-guide/expressions/expression-expansion.py:col-error" +``` + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:col-error" +``` + +## Selecting all columns + +Polars provides the function `all` as shorthand notation to refer to all columns of a dataframe: + +{{code_block('user-guide/expressions/expression-expansion', 'all', ['all'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:all" +``` + +!!! note + + The function `all` is syntactic sugar for `col("*")`, but since the argument `"*"` is a special case and `all` reads more like English, the usage of `all` is preferred. + +## Excluding columns + +Polars also provides a mechanism to exclude certain columns from expression expansion. For that, you +use the function `exclude`, which accepts exactly the same types of arguments as `col`: + +{{code_block('user-guide/expressions/expression-expansion', 'all-exclude', ['exclude'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:all-exclude" +``` + +Naturally, the function `exclude` can also be used after the function `col`: + +{{code_block('user-guide/expressions/expression-expansion', 'col-exclude', ['exclude'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:col-exclude" +``` + +## Column renaming + +By default, when you apply an expression to a column, the result keeps the same name as the original +column. + +Preserving the column name can be semantically wrong and in certain cases Polars may even raise an +error if duplicate names occur: + +{{code_block('user-guide/expressions/expression-expansion', 'duplicate-error', [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:duplicate-error" +``` + +To prevent errors like this, and to allow users to rename their columns when appropriate, Polars +provides a series of functions that let you change the name of a column or a group of columns. + +### Renaming a single column with `alias` + +The function `alias` has been used thoroughly in the documentation already and it lets you rename a +single column: + +{{code_block('user-guide/expressions/expression-expansion', 'alias', ['alias'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:alias" +``` + +### Prefixing and suffixing column names + +When using expression expansion you cannot use the function `alias` because the function `alias` is +designed specifically to rename a single column. + +When it suffices to add a static prefix or a static suffix to the existing names, we can use the +functions `prefix` and `suffix` from the namespace `name`: + +{{code_block('user-guide/expressions/expression-expansion', 'prefix-suffix', ['Expr.name', 'prefix', +'suffix'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:prefix-suffix" +``` + +### Dynamic name replacement + +If a static prefix/suffix is not enough, the namespace `name` also provides the function `map` that +accepts a callable that accepts the old column names and produces the new ones: + +{{code_block('user-guide/expressions/expression-expansion', 'name-map', ['Expr.name', 'map'])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:name-map" +``` + +See the API reference for the full contents of the namespace `name`. + +## Programmatically generating expressions + +Expression expansion is a very useful feature but it does not solve all of your problems. For +example, if we want to compute the day and year amplitude of the prices of the stocks in our +dataframe, expression expansion won't help us. + +At first, you may think about using a `for` loop: + +{{code_block('user-guide/expressions/expression-expansion', 'for-with_columns', [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:for-with_columns" +``` + +Do not do this. Instead, generate all of the expressions you want to compute programmatically and +use them only once in a context. Loosely speaking, you want to swap the `for` loop with the context +`with_columns`. In practice, you could do something like the following: + +{{code_block('user-guide/expressions/expression-expansion', 'yield-expressions', [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:yield-expressions" +``` + +This produces the same final result and by specifying all of the expressions in one go we give +Polars the opportunity to: + +1. do a better job at optimising the query; and +2. parallelise the execution of the actual computations. + +## More flexible column selections + +Polars comes with the submodule `selectors` that provides a number of functions that allow you to +write more flexible column selections for expression expansion. + +!!! warning + + This functionality is not available in Rust yet. Refer to [Polars issue #10594](https://github.com/pola-rs/polars/issues/10594). + +As a first example, here is how we can use the functions `string` and `ends_with`, and the set +operations that the functions from `selectors` support, to select all string columns and the columns +whose names end with `"_high"`: + +{{code_block('user-guide/expressions/expression-expansion', 'selectors', [], ['selectors'], [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:selectors" +``` + +The submodule `selectors` provides +[a number of selectors that match based on the data type of the columns](#selectors-for-data-types), +of which the most useful are the functions that match a whole category of types, like `cs.numeric` +for all numeric data types or `cs.temporal` for all temporal data types. + +The submodule `selectors` also provides +[a number of selectors that match based on patterns in the column names](#selectors-for-column-name-patterns) +which make it more convenient to specify common patterns you may want to check for, like the +function `cs.ends_with` that was shown above. + +### Combining selectors with set operations + +We can combine multiple selectors using set operations and the usual Python operators: + + + +| Operator | Operation | +| ----------------------- | -------------------- | +| `A | B` | Union | +| `A & B` | Intersection | +| `A - B` | Difference | +| `A ^ B` | Symmetric difference | +| `~A` | Complement | + + +The next example matches all non-string columns that contain an underscore in the name: + +{{code_block('user-guide/expressions/expression-expansion', 'selectors-set-operations', [], +['selectors'], [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:selectors-set-operations" +``` + +### Resolving operator ambiguity + +Expression functions can be chained on top of selectors: + +{{code_block('user-guide/expressions/expression-expansion', 'selectors-expressions', [], +['selectors'], [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:selectors-expressions" +``` + +However, some operators have been overloaded to operate both on Polars selectors and on expressions. +For example, the operator `~` on a selector represents +[the set operation “complement”](#combining-selectors-with-set-operations) and on an expression +represents the Boolean operation of negation. + +When you use a selector and then want to use, in the context of an expression, one of the +[operators that act as set operators for selectors](#combining-selectors-with-set-operations), you +can use the function `as_expr`. + +Below, we want to negate the Boolean values in the columns “has_partner”, “has_kids”, and +“has_tattoos”. If we are not careful, the combination of the operator `~` and the selector +`cs.starts_with("has_")` will actually select the columns that we do not care about: + +{{code_block('user-guide/expressions/expression-expansion', 'selector-ambiguity', [], [], [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:selector-ambiguity" +``` + +The correct solution uses `as_expr`: + +{{code_block('user-guide/expressions/expression-expansion', 'as_expr', [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:as_expr" +``` + +### Debugging selectors + +When you are not sure whether you have a Polars selector at hand or not, you can use the function +`cs.is_selector` to check: + +{{code_block('user-guide/expressions/expression-expansion', 'is_selector', [], ['is_selector'], +[])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:is_selector" +``` + +This should help you avoid any ambiguous situations where you think you are operating with +expressions but are in fact operating with selectors. + +Another helpful debugging utility is the function `expand_selector`. Given a target frame or schema, +you can check what columns a given selector will expand to: + +{{code_block('user-guide/expressions/expression-expansion', 'expand_selector', [], +['expand_selector'], [])}} + +```python exec="on" result="text" session="expressions/expression-expansion" +--8<-- "python/user-guide/expressions/expression-expansion.py:expand_selector" +``` + +### Complete reference + +The tables below group the functions available in the submodule `selectors` by their type of +behaviour. + +#### Selectors for data types + +Selectors that match based on the data type of the column: + +| Selector function | Data type(s) matched | +| ------------------ | ------------------------------------------------------------------ | +| `binary` | `Binary` | +| `boolean` | `Boolean` | +| `by_dtype` | Data types specified as arguments | +| `categorical` | `Categorical` | +| `date` | `Date` | +| `datetime` | `Datetime`, optionally filtering by time unit/zone | +| `decimal` | `Decimal` | +| `duration` | `Duration`, optionally filtering by time unit | +| `float` | All float types, regardless of precision | +| `integer` | All integer types, signed and unsigned, regardless of precision | +| `numeric` | All numeric types, namely integers, floats, and `Decimal` | +| `signed_integer` | All signed integer types, regardless of precision | +| `string` | `String` | +| `temporal` | All temporal data types, namely `Date`, `Datetime`, and `Duration` | +| `time` | `Time` | +| `unsigned_integer` | All unsigned integer types, regardless of precision | + +#### Selectors for column name patterns + +Selectors that match based on column name patterns: + +| Selector function | Columns selected | +| ----------------- | ------------------------------------------------------------ | +| `alpha` | Columns with alphabetical names | +| `alphanumeric` | Columns with alphanumeric names (letters and the digits 0-9) | +| `by_name` | Columns with the names specified as arguments | +| `contains` | Columns whose names contain the substring specified | +| `digit` | Columns with numeric names (only the digits 0-9) | +| `ends_with` | Columns whose names end with the given substring | +| `matches` | Columns whose names match the given regex pattern | +| `starts_with` | Columns whose names start with the given substring | + +#### Positional selectors + +Selectors that match based on the position of the columns: + +| Selector function | Columns selected | +| ----------------- | ------------------------------------ | +| `all` | All columns | +| `by_index` | The columns at the specified indices | +| `first` | The first column in the context | +| `last` | The last column in the context | + +#### Miscellaneous functions + +The submodule `selectors` also provides the following functions: + +| Function | Behaviour | +| ----------------- | ------------------------------------------------------------------------------------- | +| `as_expr`* | Convert a selector to an expression | +| `exclude` | Selects all columns except those matching the given names, data types, or selectors | +| `expand_selector` | Expand selector to matching columns with respect to a specific frame or target schema | +| `is_selector` | Check whether the given object/expression is a selector | + +*`as_expr` isn't a function defined on the submodule `selectors`, but rather a method defined on +selectors. diff --git a/docs/source/user-guide/expressions/folds.md b/docs/source/user-guide/expressions/folds.md new file mode 100644 index 000000000000..d6d7647401f1 --- /dev/null +++ b/docs/source/user-guide/expressions/folds.md @@ -0,0 +1,86 @@ +# Folds + +Polars provides many expressions to perform computations across columns, like `sum_horizontal`, +`mean_horizontal`, and `min_horizontal`. However, these are just special cases of a general +algorithm called a fold, and Polars provides a general mechanism for you to compute custom folds for +when the specialised versions of Polars are not enough. + +Folds computed with the function `fold` operate on the full columns for maximum speed. They utilize +the data layout very efficiently and often have vectorized execution. + +## Basic example + +As a first example, we will reimplement `sum_horizontal` with the function `fold`: + +{{code_block('user-guide/expressions/folds','mansum',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:mansum" +``` + +The function `fold` expects a function `f` as the parameter `function` and `f` should accept two +arguments. The first argument is the accumulated result, which we initialise as zero, and the second +argument takes the successive values of the expressions listed in the parameter `exprs`. In our +case, they're the two columns “a” and “b”. + +The snippet below includes a third explicit expression that represents what the function `fold` is +doing above: + +{{code_block('user-guide/expressions/folds','mansum-explicit',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:mansum-explicit" +``` + +??? tip "`fold` in Python" + + Most programming languages include a higher-order function that implements the algorithm that the function `fold` in Polars implements. + The Polars `fold` is very similar to Python's `functools.reduce`. + You can [learn more about the power of `functools.reduce` in this article](http://mathspp.com/blog/pydonts/the-power-of-reduce). + +## The initial value `acc` + +The initial value chosen for the accumulator `acc` is typically, but not always, the +[identity element](https://en.wikipedia.org/wiki/Identity_element) of the operation you want to +apply. For example, if we wanted to multiply across the columns, we would not get the correct result +if our accumulator was set to zero: + +{{code_block('user-guide/expressions/folds','manprod',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:manprod" +``` + +To fix this, the accumulator `acc` should be set to `1`: + +{{code_block('user-guide/expressions/folds','manprod-fixed',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:manprod-fixed" +``` + +## Conditional + +In the case where you'd want to apply a condition/predicate across all columns in a dataframe, a +fold can be a very concise way to express this. + +{{code_block('user-guide/expressions/folds','conditional',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:conditional" +``` + +The snippet above filters all rows where all columns are greater than 1. + +## Folds and string data + +Folds could be used to concatenate string data. However, due to the materialization of intermediate +columns, this operation will have squared complexity. + +Therefore, we recommend using the function `concat_str` for this: + +{{code_block('user-guide/expressions/folds','string',['concat_str'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:string" +``` diff --git a/docs/source/user-guide/expressions/index.md b/docs/source/user-guide/expressions/index.md new file mode 100644 index 000000000000..2d07da0a98be --- /dev/null +++ b/docs/source/user-guide/expressions/index.md @@ -0,0 +1,25 @@ +# Expressions + +We +[introduced the concept of “expressions” in a previous section](../concepts/expressions-and-contexts.md#expressions). +In this section we will focus on exploring the types of expressions that Polars offers. Each section +gives an overview of what they do and provides additional examples. + + +- Essentials: + - [Basic operations](basic-operations.md) – how to do basic operations on dataframe columns, like arithmetic calculations, comparisons, and other common, general-purpose operations + - [Expression expansion](expression-expansion.md) – what is expression expansion and how to use it + - [Casting](casting.md) – how to convert / cast values to different data types +- How to work with specific types of data or data type namespaces: + - [Strings](strings.md) – how to work with strings and the namespace `str` + - [Lists and arrays](lists-and-arrays.md) – the differences between the data types `List` and `Array`, when to use them, and how to use them + - [Categorical data and enums](categorical-data-and-enums.md) – the differences between the data types `Categorical` and `Enum`, when to use them, and how to use them + - [Structs](structs.md) – when to use the data type `Struct` and how to use it + - [Missing data](missing-data.md) – how to work with missing data and how to fill missing data +- Types of operations: + - [Aggregation](aggregation.md) – how to work with aggregating contexts like `group_by` + - [Window functions](window-functions.md) – how to apply window functions over columns in a dataframe + - [Folds](folds.md) – how to perform arbitrary computations horizontally across columns +- [User-defined Python functions](user-defined-python-functions.md) – how to apply user-defined Python functions to dataframe columns or to column values +- [Numpy functions](numpy-functions.md) – how to use NumPy native functions on Polars dataframes and series + diff --git a/docs/source/user-guide/expressions/lists-and-arrays.md b/docs/source/user-guide/expressions/lists-and-arrays.md new file mode 100644 index 000000000000..bb95e3446421 --- /dev/null +++ b/docs/source/user-guide/expressions/lists-and-arrays.md @@ -0,0 +1,204 @@ +# Lists and arrays + +Polars has first-class support for two homogeneous container data types: `List` and `Array`. Polars +supports many operations with the two data types and their APIs overlap, so this section of the user +guide has the objective of clarifying when one data type should be chosen in favour of the other. + +## Lists vs arrays + +### The data type `List` + +The data type list is suitable for columns whose values are homogeneous 1D containers of varying +lengths. + +The dataframe below contains three examples of columns with the data type `List`: + +{{code_block('user-guide/expressions/lists', 'list-example', ['List'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:list-example" +``` + +Note that the data type `List` is different from Python's type `list`, where elements can be of any +type. If you want to store true Python lists in a column, you can do so with the data type `Object` +and your column will not have the list manipulation features that we're about to discuss. + +### The data type `Array` + +The data type `Array` is suitable for columns whose values are homogeneous containers of an +arbitrary dimension with a known and fixed shape. + +The dataframe below contains two examples of columns with the data type `Array`. + +{{code_block('user-guide/expressions/lists', 'array-example', ['Array'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:array-example" +``` + +The example above shows how to specify that the columns “bit_flags” and “tic_tac_toe” have the data +type `Array`, parametrised by the data type of the elements contained within and by the shape of +each array. + +In general, Polars does not infer that a column has the data type `Array` for performance reasons, +and defaults to the appropriate variant of the data type `List`. In Python, an exception to this +rule is when you provide a NumPy array to build a column. In that case, Polars has the guarantee +from NumPy that all subarrays have the same shape, so an array of $n + 1$ dimensions will generate a +column of $n$ dimensional arrays: + +{{code_block('user-guide/expressions/lists', 'numpy-array-inference', ['Array'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:numpy-array-inference" +``` + +### When to use each + +In short, prefer the data type `Array` over `List` because it is more memory efficient and more +performant. If you cannot use `Array`, then use `List`: + +- when the values within a column do not have a fixed shape; or +- when you need functions that are only available in the list API. + +## Working with lists + +### The namespace `list` + +Polars provides many functions to work with values of the data type `List` and these are grouped +inside the namespace `list`. We will explore this namespace a bit now. + +!!! warning "`arr` then, `list` now" + + In previous versions of Polars, the namespace for list operations used to be `arr`. + `arr` is now the namespace for the data type `Array`. + If you find references to the namespace `arr` on StackOverflow or other sources, note that those sources _may_ be outdated. + +The dataframe `weather` defined below contains data from different weather stations across a region. +When the weather station is unable to get a result, an error code is recorded instead of the actual +temperature at that time. + +{{code_block('user-guide/expressions/lists', 'weather', [])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:weather" +``` + +### Programmatically creating lists + +Given the dataframe `weather` defined previously, it is very likely we need to run some analysis on +the temperatures that are captured by each station. To make this happen, we need to first be able to +get individual temperature measurements. We +[can use the namespace `str`](strings.md#the-string-namespace) for this: + +{{code_block('user-guide/expressions/lists', 'split', ['str.split'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:split" +``` + +A natural follow-up would be to explode the list of temperatures so that each measurement is in its +own row: + +{{code_block('user-guide/expressions/lists', 'explode', ['explode'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:explode" +``` + +However, in Polars we often do not need to do this to operate on the list elements. + +### Operating on lists + +Polars provides several standard operations on columns with the `List` data type. +[Similar to what you can do with strings](strings.md#slicing), lists can be sliced with the +functions `head`, `tail`, and `slice`: + +{{code_block('user-guide/expressions/lists', 'list-slicing', ['Expr.list'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:list-slicing" +``` + +### Element-wise computation within lists + +If we need to identify the stations that are giving the most number of errors we need to + +1. try to convert the measurements into numbers; +2. count the number of non-numeric values (i.e., `null` values) in the list, by row; and +3. rename this output column as “errors” so that we can easily identify the stations. + +To perform these steps, we need to perform a casting operation on each measurement within the list +values. The function `eval` is used as the entry point to perform operations on the elements of the +list. Within it, you can use the context `element` to refer to each single element of the list +individually, and then you can use any Polars expression on the element: + +{{code_block('user-guide/expressions/lists', 'element-wise-casting', ['element'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:element-wise-casting" +``` + +Another alternative would be to use a regular expression to check if a measurement starts with a +letter: + +{{code_block('user-guide/expressions/lists', 'element-wise-regex', ['element'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:element-wise-regex" +``` + +If you are unfamiliar with the namespace `str` or the notation `(?i)` in the regex, now is a good +time to +[look at how to work with strings and regular expressions in Polars](strings.md#check-for-the-existence-of-a-pattern). + +### Row-wise computations + +The function `eval` gives us access to the list elements and `pl.element` refers to each individual +element, but we can also use `pl.all()` to refer to all of the elements of the list. + +To show this in action, we will start by creating another dataframe with some more weather data: + +{{code_block('user-guide/expressions/lists', 'weather_by_day', [])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:weather_by_day" +``` + +Now, we will calculate the percentage rank of the temperatures by day, measured across stations. +Polars does not provide a function to do this directly, but because expressions are so versatile we +can create our own percentage rank expression for highest temperature. Let's try that: + +{{code_block('user-guide/expressions/lists', 'rank_pct', ['element', 'rank'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:rank_pct" +``` + +## Working with arrays + +### Creating an array column + +As [we have seen above](#the-data-type-array), Polars usually does not infer the data type `Array` +automatically. You have to specify the data type `Array` when creating a series/dataframe or +[cast a column](casting.md) explicitly unless you create the column out of a NumPy array. + +### The namespace `arr` + +The data type `Array` was recently introduced and is still pretty nascent in features that it +offers. Even so, the namespace `arr` aggregates several functions that you can use to work with +arrays. + +!!! warning "`arr` then, `list` now" + + In previous versions of Polars, the namespace for list operations used to be `arr`. + `arr` is now the namespace for the data type `Array`. + If you find references to the namespace `arr` on StackOverflow or other sources, note that those sources _may_ be outdated. + +The API documentation should give you a good overview of the functions in the namespace `arr`, of +which we present a couple: + +{{code_block('user-guide/expressions/lists', 'array-overview', ['Expr.arr'])}} + +```python exec="on" result="text" session="expressions/lists" +--8<-- "python/user-guide/expressions/lists.py:array-overview" +``` diff --git a/docs/source/user-guide/expressions/missing-data.md b/docs/source/user-guide/expressions/missing-data.md new file mode 100644 index 000000000000..065aa38da5c7 --- /dev/null +++ b/docs/source/user-guide/expressions/missing-data.md @@ -0,0 +1,189 @@ +# Missing data + +This section of the user guide teaches how to work with missing data in Polars. + +## `null` and `NaN` values + +In Polars, missing data is represented by the value `null`. This missing value `null` is used for +all data types, including numerical types. + +Polars also supports the value `NaN` (“Not a Number”) for columns with floating point numbers. The +value `NaN` is considered to be a valid floating point value, which is different from missing data. +[We discuss the value `NaN` separately below](#not-a-number-or-nan-values). + +When creating a series or a dataframe, you can set a value to `null` by using the appropriate +construct for your language: + +{{code_block('user-guide/expressions/missing-data','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:dataframe" +``` + +!!! info "Difference from pandas" + + In pandas, the value used to represent missing data depends on the data type of the column. + In Polars, missing data is always represented by the value `null`. + +## Missing data metadata + +Polars keeps track of some metadata regarding the missing data of each series. This metadata allows +Polars to answer some basic queries about missing values in a very efficient way, namely how many +values are missing and which ones are missing. + +To determine how many values are missing from a column you can use the function `null_count`: + +{{code_block('user-guide/expressions/missing-data','count',['null_count'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:count" +``` + +The function `null_count` can be called on a dataframe, a column from a dataframe, or on a series +directly. The function `null_count` is a cheap operation because the result is already known. + +Polars uses something called a “validity bitmap” to know which values are missing in a series. The +validity bitmap is memory efficient as it is bit encoded. If a series has length $n$, then its +validity bitmap will cost $n / 8$ bytes. The function `is_null` uses the validity bitmap to +efficiently report which values are `null` and which are not: + +{{code_block('user-guide/expressions/missing-data','isnull',['is_null'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:isnull" +``` + +The function `is_null` can be used on a column of a dataframe or on a series directly. Again, this +is a cheap operation because the result is already known by Polars. + +??? info "Why does Polars waste memory on a validity bitmap?" + + It all comes down to a tradeoff. + By using a bit more memory per column, Polars can be much more efficient when performing most operations on your columns. + If the validity bitmap wasn't known, every time you wanted to compute something you would have to check each position of the series to see if a legal value was present or not. + With the validity bitmap, Polars knows automatically the positions where your operations can be applied. + +## Filling missing data + +Missing data in a series can be filled with the function `fill_null`. You can specify how missing +data is effectively filled in a couple of different ways: + +- a literal of the correct data type; +- a Polars expression, such as replacing with values computed from another column; +- a strategy based on neighbouring values, such as filling forwards or backwards; and +- interpolation. + +To illustrate how each of these methods work we start by defining a simple dataframe with two +missing values in the second column: + +{{code_block('user-guide/expressions/missing-data','dataframe2',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:dataframe2" +``` + +### Fill with a specified literal value + +You can fill the missing data with a specified literal value. This literal value will replace all of +the occurrences of the value `null`: + +{{code_block('user-guide/expressions/missing-data','fill',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fill" +``` + +However, this is actually just a special case of the general case where +[the function `fill_null` replaces missing values with the corresponding values from the result of a Polars expression](#fill-with-a-strategy-based-on-neighbouring-values), +as seen next. + +### Fill with an expression + +In the general case, the missing data can be filled by extracting the corresponding values from the +result of a general Polars expression. For example, we can fill the second column with values taken +from the double of the first column: + +{{code_block('user-guide/expressions/missing-data','fillexpr',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fillexpr" +``` + +### Fill with a strategy based on neighbouring values + +You can also fill the missing data by following a fill strategy based on the neighbouring values. +The two simpler strategies look for the first non-`null` value that comes immediately before or +immediately after the value `null` that is being filled: + +{{code_block('user-guide/expressions/missing-data','fillstrategy',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fillstrategy" +``` + +You can find other fill strategies in the API docs. + +### Fill with interpolation + +Additionally, you can fill intermediate missing data with interpolation by using the function +`interpolate` instead of the function `fill_null`: + +{{code_block('user-guide/expressions/missing-data','fillinterpolate',['interpolate'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:fillinterpolate" +``` + +Note: With interpolate, nulls at the beginning and end of the series remain null. + +## Not a Number, or `NaN` values + +Missing data in a series is only ever represented by the value `null`, regardless of the data type +of the series. Columns with a floating point data type can sometimes have the value `NaN`, which +might be confused with `null`. + +The special value `NaN` can be created directly: + +{{code_block('user-guide/expressions/missing-data','nan',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:nan" +``` + +And it might also arise as the result of a computation: + +{{code_block('user-guide/expressions/missing-data','nan-computed',[])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:nan-computed" +``` + +!!! info + + By default, a `NaN` value in an integer column causes the column to be cast to a float data type in pandas. + This does not happen in Polars; instead, an exception is raised. + +`NaN` values are considered to be a type of floating point data and are **not considered to be +missing data** in Polars. This means: + +- `NaN` values are **not** counted with the function `null_count`; and +- `NaN` values are filled when you use the specialised function `fill_nan` method but are **not** + filled with the function `fill_null`. + +Polars has the functions `is_nan` and `fill_nan`, which work in a similar way to the functions +`is_null` and `fill_null`. Unlike with missing data, Polars does not hold any metadata regarding the +`NaN` values, so the function `is_nan` entails actual computation. + +One further difference between the values `null` and `NaN` is that numerical aggregating functions, +like `mean` and `sum`, skip the missing values when computing the result, whereas the value `NaN` is +considered for the computation and typically propagates into the result. If desirable, this behavior +can be avoided by replacing the occurrences of the value `NaN` with the value `null`: + +{{code_block('user-guide/expressions/missing-data','nanfill',['fill_nan'])}} + +```python exec="on" result="text" session="user-guide/missing-data" +--8<-- "python/user-guide/expressions/missing-data.py:nanfill" +``` + +You can learn more about the value `NaN` in +[the section about floating point number data types](../concepts/data-types-and-structures.md#floating-point-numbers). diff --git a/docs/source/user-guide/expressions/numpy-functions.md b/docs/source/user-guide/expressions/numpy-functions.md new file mode 100644 index 000000000000..971500f2d943 --- /dev/null +++ b/docs/source/user-guide/expressions/numpy-functions.md @@ -0,0 +1,28 @@ +# Numpy functions + +Polars expressions support NumPy [ufuncs](https://numpy.org/doc/stable/reference/ufuncs.html). See +[the NumPy documentation for a list of all supported NumPy functions](https://numpy.org/doc/stable/reference/ufuncs.html#available-ufuncs). + +This means that if a function is not provided by Polars, we can use NumPy and we still have fast +columnar operations through the NumPy API. + +## Example + +{{code_block('user-guide/expressions/numpy-example',api_functions=['DataFrame','np.log'])}} + +```python exec="on" result="text" session="user-guide/numpy" +--8<-- "python/user-guide/expressions/numpy-example.py" +``` + +## Interoperability + +Polars' series have support for NumPy universal functions (ufuncs) and generalized ufuncs. +Element-wise functions such as `np.exp`, `np.cos`, `np.div`, etc, all work with almost zero +overhead. + +However, bear in mind that +[Polars keeps track of missing values with a separate bitmask](missing-data.md) and NumPy does not +receive this information. This can lead to a window function or a `np.convolve` giving flawed or +incomplete results, so an error will be raised if you pass a series with missing data to a +generalized ufunc. Convert a Polars series to a NumPy array with the function `to_numpy`. Missing +values will be replaced by `np.nan` during the conversion. diff --git a/docs/source/user-guide/expressions/speed_rank_by_type.svg b/docs/source/user-guide/expressions/speed_rank_by_type.svg new file mode 100644 index 000000000000..4324508c228b --- /dev/null +++ b/docs/source/user-guide/expressions/speed_rank_by_type.svg @@ -0,0 +1,102 @@ + + + + + + Bulbasaur + + Ivysaur + + Venusaur + + + VenusaurMega + Venusaur + + + Charmander + + ... + + Oddish + + Gloom + + ... + + Grass + + Grass + + Grass + + Grass + + Fire + + ... + + Grass + + Grass + + ... + + 45 + + 60 + + 80 + + 80 + + 65 + + ... + + 30 + + 40 + + ... + + + + + + 6 + + 3 + + 1 + + 1 + + 7 + + ... + + 8 + + 7 + + ... + + Name + + Type 1 + + Speed + + Speed rank + + Golbat + + Poison + + 90 + + 1 + diff --git a/docs/source/user-guide/expressions/strings.md b/docs/source/user-guide/expressions/strings.md new file mode 100644 index 000000000000..826e49cc7952 --- /dev/null +++ b/docs/source/user-guide/expressions/strings.md @@ -0,0 +1,162 @@ +# Strings + +The following section discusses operations performed on string data, which is a frequently used data +type when working with dataframes. String processing functions are available in the namespace `str`. + +Working with strings in other dataframe libraries can be highly inefficient due to the fact that +strings have unpredictable lengths. Polars mitigates these inefficiencies by +[following the Arrow Columnar Format specification](../concepts/data-types-and-structures.md#data-types-internals), +so you can write performant data queries on string data too. + +## The string namespace + +When working with string data you will likely need to access the namespace `str`, which aggregates +40+ functions that let you work with strings. As an example of how to access functions from within +that namespace, the snippet below shows how to compute the length of the strings in a column in +terms of the number of bytes and the number of characters: + +{{code_block('user-guide/expressions/strings','df',['str.len_bytes','str.len_chars'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:df" +``` + +!!! note + + If you are working exclusively with ASCII text, then the results of the two computations will be the same and using `len_bytes` is recommended since it is faster. + +## Parsing strings + +Polars offers multiple methods for checking and parsing elements of a string column, namely checking +for the existence of given substrings or patterns, and counting, extracting, or replacing, them. We +will demonstrate some of these operations in the upcoming examples. + +### Check for the existence of a pattern + +We can use the function `contains` to check for the presence of a pattern within a string. By +default, the argument to the function `contains` is interpreted as a regular expression. If you want +to specify a literal substring, set the parameter `literal` to `True`. + +For the special cases where you want to check if the strings start or end with a fixed substring, +you can use the functions `starts_with` or `ends_with`, respectively. + +{{code_block('user-guide/expressions/strings','existence',['str.contains', +'str.starts_with','str.ends_with'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:existence" +``` + +### Regex specification + +Polars relies on the Rust crate `regex` to work with regular expressions, so you may need to +[refer to the syntax documentation](https://docs.rs/regex/latest/regex/#syntax) to see what features +and flags are supported. In particular, note that the flavor of regex supported by Polars is +different from Python's module `re`. + +### Extract a pattern + +The function `extract` allows us to extract patterns from the string values in a column. The +function `extract` accepts a regex pattern with one or more capture groups and extracts the capture +group specified as the second argument. + +{{code_block('user-guide/expressions/strings','extract',['str.extract'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:extract" +``` + +To extract all occurrences of a pattern within a string, we can use the function `extract_all`. In +the example below, we extract all numbers from a string using the regex pattern `(\d+)`, which +matches one or more digits. The resulting output of the function `extract_all` is a list containing +all instances of the matched pattern within the string. + +{{code_block('user-guide/expressions/strings','extract_all',['str.extract_all'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:extract_all" +``` + +### Replace a pattern + +Akin to the functions `extract` and `extract_all`, Polars provides the functions `replace` and +`replace_all`. These accept a regex pattern or a literal substring (if the parameter `literal` is +set to `True`) and perform the replacements specified. The function `replace` will make at most one +replacement whereas the function `replace_all` will make all the non-overlapping replacements it +finds. + +{{code_block('user-guide/expressions/strings','replace',['str.replace', 'str.replace_all'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:replace" +``` + +## Modifying strings + +### Case conversion + +Converting the casing of a string is a common operation and Polars supports it out of the box with +the functions `to_lowercase`, `to_titlecase`, and `to_uppercase`: + +{{code_block('user-guide/expressions/strings','casing', ['str.to_lowercase', 'str.to_titlecase', +'str.to_uppercase'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:casing" +``` + +### Stripping characters from the ends + +Polars provides five functions in the namespace `str` that let you strip characters from the ends of +the string: + +| Function | Behaviour | +| ------------------- | --------------------------------------------------------------------- | +| `strip_chars` | Removes leading and trailing occurrences of the characters specified. | +| `strip_chars_end` | Removes trailing occurrences of the characters specified. | +| `strip_chars_start` | Removes leading occurrences of the characters specified. | +| `strip_prefix` | Removes an exact substring prefix if present. | +| `strip_suffix` | Removes an exact substring suffix if present. | + +??? info "Similarity to Python string methods" + + `strip_chars` is similar to Python's string method `strip` and `strip_prefix`/`strip_suffix` + are similar to Python's string methods `removeprefix` and `removesuffix`, respectively. + +It is important to understand that the first three functions interpret their string argument as a +set of characters whereas the functions `strip_prefix` and `strip_suffix` do interpret their string +argument as a literal string. + +{{code_block('user-guide/expressions/strings', 'strip', ['str.strip_chars', 'str.strip_chars_end', +'str.strip_chars_start', 'str.strip_prefix', 'str.strip_suffix'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:strip" +``` + +If no argument is provided, the three functions `strip_chars`, `strip_chars_end`, and +`strip_chars_start`, remove whitespace by default. + +### Slicing + +Besides [extracting substrings as specified by patterns](#extract-a-pattern), you can also slice +strings at specified offsets to produce substrings. The general-purpose function for slicing is +`slice` and it takes the starting offset and the optional _length_ of the slice. If the length of +the slice is not specified or if it's past the end of the string, Polars slices the string all the +way to the end. + +The functions `head` and `tail` are specialised versions used for slicing the beginning and end of a +string, respectively. + +{{code_block('user-guide/expressions/strings', 'slice', [], ['str.slice', 'str.head', 'str.tail'], +['str.str_slice', 'str.str_head', 'str.str_tail'])}} + +```python exec="on" result="text" session="expressions/strings" +--8<-- "python/user-guide/expressions/strings.py:slice" +``` + +## API documentation + +In addition to the examples covered above, Polars offers various other string manipulation +functions. To explore these additional methods, you can go to the API documentation of your chosen +programming language for Polars. diff --git a/docs/source/user-guide/expressions/structs.md b/docs/source/user-guide/expressions/structs.md new file mode 100644 index 000000000000..faf4b3763182 --- /dev/null +++ b/docs/source/user-guide/expressions/structs.md @@ -0,0 +1,166 @@ +# Structs + +The data type `Struct` is a composite data type that can store multiple fields in a single column. + +!!! tip "Python analogy" + + For Python users, the data type `Struct` is kind of like a Python + dictionary. Even better, if you are familiar with Python typing, you can think of the data type + `Struct` as `typing.TypedDict`. + +In this page of the user guide we will see situations in which the data type `Struct` arises, we +will understand why it does arise, and we will see how to work with `Struct` values. + +Let's start with a dataframe that captures the average rating of a few movies across some states in +the US: + +{{code_block('user-guide/expressions/structs','ratings_df',['DataFrame'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:ratings_df" +``` + +## Encountering the data type `Struct` + +A common operation that will lead to a `Struct` column is the ever so popular `value_counts` +function that is commonly used in exploratory data analysis. Checking the number of times a state +appears in the data is done as so: + +{{code_block('user-guide/expressions/structs','state_value_counts',['value_counts'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:state_value_counts" +``` + +Quite unexpected an output, especially if coming from tools that do not have such a data type. We're +not in peril, though. To get back to a more familiar output, all we need to do is use the function +`unnest` on the `Struct` column: + +{{code_block('user-guide/expressions/structs','struct_unnest',['unnest'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_unnest" +``` + +The function `unnest` will turn each field of the `Struct` into its own column. + +!!! note "Why `value_counts` returns a `Struct`" + + Polars expressions always operate on a single series and return another series. + `Struct` is the data type that allows us to provide multiple columns as input to an expression, or to output multiple columns from an expression. + Thus, we can use the data type `Struct` to specify each value and its count when we use `value_counts`. + +## Inferring the data type `Struct` from dictionaries + +When building series or dataframes, Polars will convert dictionaries to the data type `Struct`: + +{{code_block('user-guide/expressions/structs','series_struct',['Series'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct" +``` + +The number of fields, their names, and their types, are inferred from the first dictionary seen. +Subsequent incongruences can result in `null` values or in errors: + +{{code_block('user-guide/expressions/structs','series_struct_error',['Series'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct_error" +``` + +## Extracting individual values of a `Struct` + +Let's say that we needed to obtain just the field `"Movie"` from the `Struct` in the series that we +created above. We can use the function `field` to do so: + +{{code_block('user-guide/expressions/structs','series_struct_extract',['struct.field'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct_extract" +``` + +## Renaming individual fields of a `Struct` + +What if we need to rename individual fields of a `Struct` column? We use the function +`rename_fields`: + +{{code_block('user-guide/expressions/structs','series_struct_rename',['struct.rename_fields'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct_rename" +``` + +To be able to actually see that the field names were change, we will create a dataframe where the +only column is the result and then we use the function `unnest` so that each field becomes its own +column. The column names will reflect the renaming operation we just did: + +{{code_block('user-guide/expressions/structs','struct-rename-check',['struct.rename_fields'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:struct-rename-check" +``` + +## Practical use-cases of `Struct` columns + +### Identifying duplicate rows + +Let's get back to the `ratings` data. We want to identify cases where there are duplicates at a +“Movie” and “Theatre” level. + +This is where the data type `Struct` shines: + +{{code_block('user-guide/expressions/structs','struct_duplicates',['is_duplicated', 'struct'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_duplicates" +``` + +We can identify the unique cases at this level also with `is_unique`! + +### Multi-column ranking + +Suppose, given that we know there are duplicates, we want to choose which rating gets a higher +priority. We can say that the column “Count” is the most important, and if there is a tie in the +column “Count” then we consider the column “Avg_Rating”. + +We can then do: + +{{code_block('user-guide/expressions/structs','struct_ranking',['is_duplicated', 'struct'])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_ranking" +``` + +That's a pretty complex set of requirements done very elegantly in Polars! To learn more about the +function `over`, used above, [see the user guide section on window functions](window-functions.md). + +### Using multiple columns in a single expression + +As mentioned earlier, the data type `Struct` is also useful if you need to pass multiple columns as +input to an expression. As an example, suppose we want to compute +[the Ackermann function](https://en.wikipedia.org/wiki/Ackermann_function) on two columns of a +dataframe. There is no way of composing Polars expressions to compute the Ackermann function[^1], so +we define a custom function: + +{{code_block('user-guide/expressions/structs', 'ack', [])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:ack" +``` + +Now, to compute the values of the Ackermann function on those arguments, we start by creating a +`Struct` with fields `m` and `n` and then use the function `map_elements` to apply the function +`ack` to each value: + +{{code_block('user-guide/expressions/structs','struct-ack',[], ['map_elements'], [])}} + +```python exec="on" result="text" session="expressions/structs" +--8<-- "python/user-guide/expressions/structs.py:struct-ack" +``` + +Refer to +[this section of the user guide to learn more about applying user-defined Python functions to your data](user-defined-python-functions.md). + +[^1]: To say that something cannot be done is quite a bold claim. If you prove us wrong, please let +us know! diff --git a/docs/source/user-guide/expressions/user-defined-python-functions.md b/docs/source/user-guide/expressions/user-defined-python-functions.md new file mode 100644 index 000000000000..1cf250e2d73f --- /dev/null +++ b/docs/source/user-guide/expressions/user-defined-python-functions.md @@ -0,0 +1,178 @@ +# User-defined Python functions + + + +Polars expressions are quite powerful and flexible, so there is much less need for custom Python +functions compared to other libraries. Still, you may need to pass an expression's state to a third +party library or apply your black box function to data in Polars. + +In this part of the documentation we'll be using two APIs that allows you to do this: + +- [:material-api: `map_elements`](https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.Expr.map_elements.html): + Call a function separately on each value in the `Series`. +- [:material-api: `map_batches`](https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.Expr.map_batches.html): + Always passes the full `Series` to the function. + +## Processing individual values with `map_elements()` + +Let's start with the simplest case: we want to process each value in a `Series` individually. Here +is our data: + +{{code_block('user-guide/expressions/user-defined-functions','dataframe',[])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:setup" +--8<-- "python/user-guide/expressions/user-defined-functions.py:dataframe" +``` + +We'll call `math.log()` on each individual value: + +{{code_block('user-guide/expressions/user-defined-functions','individual_log',[])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:individual_log" +``` + +While this works, `map_elements()` has two problems: + +1. **Limited to individual items:** Often you'll want to have a calculation that needs to operate on + the whole `Series`, rather than individual items one by one. +2. **Performance overhead:** Even if you do want to process each item individually, calling a + function for each individual item is slow; all those extra function calls add a lot of overhead. + +Let's start by solving the first problem, and then we'll see how to solve the second problem. + +## Processing a whole `Series` with `map_batches()` + +We want to run a custom function on the contents of a whole `Series`. For demonstration purposes, +let's say we want to calculate the difference between the mean of a `Series` and each value. + +We can use the `map_batches()` API to run this function on either the full `Series` or individual +groups in a `group_by()`: + +{{code_block('user-guide/expressions/user-defined-functions','diff_from_mean',[])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:diff_from_mean" +``` + +## Fast operations with user-defined functions + +The problem with a pure-Python implementation is that it's slow. In general, you want to minimize +how much Python code you call if you want fast results. + +To maximize speed, you'll want to make sure that you're using a function written in a compiled +language. For numeric calculations Polars supports a pair of interfaces defined by NumPy called +["ufuncs"](https://numpy.org/doc/stable/reference/ufuncs.html) and +["generalized ufuncs"](https://numpy.org/neps/nep-0005-generalized-ufuncs.html). The former runs on +each item individually, and the latter accepts a whole NumPy array, which allows for more flexible +operations. + +[NumPy](https://numpy.org/doc/stable/reference/ufuncs.html) and other libraries like +[SciPy](https://docs.scipy.org/doc/scipy/reference/special.html#module-scipy.special) come with +pre-written ufuncs you can use with Polars. For example: + +{{code_block('user-guide/expressions/user-defined-functions','np_log',[])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:np_log" +``` + +Notice that we can use `map_batches()`, because `numpy.log()` is able to run on both individual +items and on whole NumPy arrays. This means it will run much faster than our original example, since +we only have a single Python call and then all processing happens in a fast low-level language. + +## Example: A fast custom function using Numba + +The pre-written functions NumPy provides are helpful, but our goal is to write our own functions. +For example, let's say we want a fast version of our `diff_from_mean()` example above. The easiest +way to write this in Python is to use [Numba](https://numba.readthedocs.io/en/stable/), which allows +you to write custom functions in (a subset) of Python while still getting the benefit of compiled +code. + +In particular, Numba provides a decorator called +[`@guvectorize`](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-guvectorize-decorator). +This creates a generalized ufunc by compiling a Python function to fast machine code, in a way that +allows it to be used by Polars. + +In the following example the `diff_from_mean_numba()` will be compiled to fast machine code at +import time, which will take a little time. After that all calls to the function will run quickly. +The `Series` will be converted to a NumPy array before being passed to the function: + +{{code_block('user-guide/expressions/user-defined-functions','diff_from_mean_numba',[])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:diff_from_mean_numba" +``` + +## Missing data is not allowed when calling generalized ufuncs + +Before being passed to a user-defined function like `diff_from_mean_numba()`, a `Series` will be +converted to a NumPy array. Unfortunately, NumPy arrays don't have a concept of missing data. If +there is missing data in the original `Series`, this means the resulting array won't actually match +the `Series`. + +If you're calculating results item by item, this doesn't matter. For example, `numpy.log()` gets +called on each individual value separately, so those missing values don't change the calculation. +But if the result of a user-defined function depend on multiple values in the `Series`, it's not +clear what exactly should happen with the missing values. + +Therefore, when calling generalized ufuncs such as Numba functions decorated with `@guvectorize`, +Polars will raise an error if you try to pass in a `Series` with missing data. How do you get rid of +missing data? Either [fill it in](missing-data.md) or +[drop it](https://docs.pola.rs/py-polars/html/reference/dataframe/api/polars.DataFrame.drop_nulls.html) +before calling your custom function. + +## Combining multiple column values + +If you want to pass multiple columns to a user-defined function, you can use `Struct`s, which are +[covered in detail in a different section](structs.md). The basic idea is to combine multiple +columns into a `Struct`, and then the function can extract the columns back out: + +{{code_block('user-guide/expressions/user-defined-functions','combine',[])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:combine" +``` + +## Streaming calculations + +Passing the full `Series` to the user-defined function has a cost: it may use a lot of memory, as +its contents are copied into a NumPy array. You can use the `is_elementwise=True` argument to +[:material-api: `map_batches`](https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.Expr.map_batches.html) +to stream results into the function, which means it might not get all values at once. + +!!! note + + The `is_elementwise` argument can lead to incorrect results if set incorrectly. + If you set `is_elementwise=True`, make sure that your function actually operates + element-by-element (e.g. "calculate the logarithm of each value") - our example function `diff_from_mean()`, for instance, does not. + +## Return types + +Custom Python functions are often black boxes; Polars doesn't know what your function is doing or +what it will return. The return data type is therefore automatically inferred. We do that by waiting +for the first non-null value. That value will then be used to determine the type of the resulting +`Series`. + +The mapping of Python types to Polars data types is as follows: + +- `int` -> `Int64` +- `float` -> `Float64` +- `bool` -> `Boolean` +- `str` -> `String` +- `list[tp]` -> `List[tp]` (where the inner type is inferred with the same rules) +- `dict[str, [tp]]` -> `struct` +- `Any` -> `object` (Prevent this at all times) + +Rust types map as follows: + +- `i32` or `i64` -> `Int64` +- `f32` or `f64` -> `Float64` +- `bool` -> `Boolean` +- `String` or `str` -> `String` +- `Vec` -> `List[tp]` (where the inner type is inferred with the same rules) + +You can pass a `return_dtype` argument to +[:material-api: `map_batches`](https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.Expr.map_batches.html) +if you want to override the inferred type. diff --git a/docs/source/user-guide/expressions/window-functions.md b/docs/source/user-guide/expressions/window-functions.md new file mode 100644 index 000000000000..7acc249785aa --- /dev/null +++ b/docs/source/user-guide/expressions/window-functions.md @@ -0,0 +1,165 @@ +# Window functions + +Window functions are expressions with superpowers. They allow you to perform aggregations on groups +within the context `select`. Let's get a feel for what that means. + +First, we load a Pokémon dataset: + +{{code_block('user-guide/expressions/window','pokemon',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:pokemon" +``` + +## Operations per group + +Window functions are ideal when we want to perform an operation within a group. For instance, +suppose we want to rank our Pokémon by the column “Speed”. However, instead of a global ranking, we +want to rank the speed within each group defined by the column “Type 1”. We write the expression to +rank the data by the column “Speed” and then we add the function `over` to specify that this should +happen over the unique values of the column “Type 1”: + +{{code_block('user-guide/expressions/window','rank',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:rank" +``` + +To help visualise this operation, you may imagine that Polars selects the subsets of the data that +share the same value for the column “Type 1” and then computes the ranking expression only for those +values. Then, the results for that specific group are projected back to the original rows and Polars +does this for all of the existing groups. The diagram below highlights the ranking computation for +the Pokémon with “Type 1” equal to “Grass”. + +
+--8<-- "docs/source/user-guide/expressions/speed_rank_by_type.svg" +
+ +Note how the row for the Pokémon “Golbat” has a “Speed” value of `90`, which is greater than the +value `80` of the Pokémon “Venusaur”, and yet the latter was ranked 1 because “Golbat” and “Venusar” +do not share the same value for the column “Type 1”. + +The function `over` accepts an arbitrary number of expressions to specify the groups over which to +perform the computations. We can repeat the ranking above, but over the combination of the columns +“Type 1” and “Type 2” for a more fine-grained ranking: + +{{code_block('user-guide/expressions/window','rank-multiple',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:rank-multiple" +``` + +In general, the results you get with the function `over` can also be achieved with +[an aggregation](aggregation.md) followed by a call to the function `explode`, although the rows +would be in a different order: + +{{code_block('user-guide/expressions/window','rank-explode',['explode'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:rank-explode" +``` + +This shows that, usually, `group_by` and `over` produce results of different shapes: + +- `group_by` usually produces a resulting dataframe with as many rows as groups used for + aggregating; and +- `over` usually produces a dataframe with the same number of rows as the original. + +The function `over` does not always produce results with the same number of rows as the original +dataframe, and that is what we explore next. + +## Mapping results to dataframe rows + +The function `over` accepts a parameter `mapping_strategy` that determines how the results of the +expression over the group are mapped back to the rows of the dataframe. + +### `group_to_rows` + +The default behaviour is `"group_to_rows"`: the result of the expression over the group should be +the same length as the group and the results are mapped back to the rows of that group. + +If the order of the rows is not relevant, the option `"explode"` is more performant. Instead of +mapping the resulting values to the original rows, Polars creates a new dataframe where values from +the same group are next to each other. To help understand the distinction, consider the following +dataframe: + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:athletes" +``` + +We can sort the athletes by rank within their own countries. If we do so, the Dutch athletes were in +the second, third, and sixth, rows, and they will remain there. What will change is the order of the +names of the athletes, which goes from “B”, “C”, and “F”, to “B”, “F”, and “C”: + +{{code_block('user-guide/expressions/window','athletes-sort-over-country',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:athletes-sort-over-country" +``` + +The diagram below represents this transformation: + +
+--8<-- "docs/source/user-guide/expressions/athletes_over_country.svg" +
+ +### `explode` + +If we set the parameter `mapping_strategy` to `"explode"`, then athletes of the same country are +grouped together, but the final order of the rows – with respect to the countries – will not be the +same, as the diagram shows: + +
+--8<-- "docs/source/user-guide/expressions/athletes_over_country_explode.svg" +
+ +Because Polars does not need to keep track of the positions of the rows of each group, using +`"explode"` is typically faster than `"group_to_rows"`. However, using `"explode"` also requires +more care because it implies reordering the other columns that we wish to keep. The code that +produces this result follows + +{{code_block('user-guide/expressions/window','athletes-explode',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:athletes-explode" +``` + +### `join` + +Another possible value for the parameter `mapping_strategy` is `"join"`, which aggregates the +resulting values in a list and repeats the list over all rows of the same group: + +{{code_block('user-guide/expressions/window','athletes-join',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:athletes-join" +``` + +## Windowed aggregation expressions + +In case the expression applied to the values of a group produces a scalar value, the scalar is +broadcast across the rows of the group: + +{{code_block('user-guide/expressions/window','pokemon-mean',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:pokemon-mean" +``` + +## More examples + +For more exercises, below are some window functions for us to compute: + +- sort all Pokémon by type; +- select the first `3` Pokémon per type as `"Type 1"`; +- sort the Pokémon within a type by speed in descending order and select the first `3` as + `"fastest/group"`; +- sort the Pokémon within a type by attack in descending order and select the first `3` as + `"strongest/group"`; and +- sort the Pokémon within a type by name and select the first `3` as `"sorted_by_alphabet"`. + +{{code_block('user-guide/expressions/window','examples',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:examples" +``` diff --git a/docs/source/user-guide/getting-started.md b/docs/source/user-guide/getting-started.md new file mode 100644 index 000000000000..a59d32c6b301 --- /dev/null +++ b/docs/source/user-guide/getting-started.md @@ -0,0 +1,211 @@ +# Getting started + +This chapter is here to help you get started with Polars. It covers all the fundamental features and +functionalities of the library, making it easy for new users to familiarise themselves with the +basics from initial installation and setup to core functionalities. If you're already an advanced +user or familiar with dataframes, feel free to skip ahead to the +[next chapter about installation options](installation.md). + +## Installing Polars + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars -F lazy + + # Or Cargo.toml + [dependencies] + polars = { version = "x", features = ["lazy", ...]} + ``` + +## Reading & writing + +Polars supports reading and writing for common file formats (e.g., csv, json, parquet), cloud +storage (S3, Azure Blob, BigQuery) and databases (e.g., postgres, mysql). Below, we create a small +dataframe and show how to write it to disk and read it back. + +{{code_block('user-guide/getting-started','df',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:df" +``` + +In the example below we write the dataframe to a csv file called `output.csv`. After that, we read +it back using `read_csv` and then print the result for inspection. + +{{code_block('user-guide/getting-started','csv',['read_csv','write_csv'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:csv" +``` + +For more examples on the CSV file format and other data formats, see the [IO section](io/index.md) +of the user guide. + +## Expressions and contexts + +_Expressions_ are one of the main strengths of Polars because they provide a modular and flexible +way of expressing data transformations. + +Here is an example of a Polars expression: + +```py +pl.col("weight") / (pl.col("height") ** 2) +``` + +As you might be able to guess, this expression takes the column named “weight” and divides its +values by the square of the values in the column “height”, computing a person's BMI. Note that the +code above expresses an abstract computation: it's only inside a Polars _context_ that the +expression materalizes into a series with the results. + +Below, we will show examples of Polars expressions inside different contexts: + +- `select` +- `with_columns` +- `filter` +- `group_by` + +For a more +[detailed exploration of expressions and contexts see the respective user guide section](concepts/expressions-and-contexts.md). + +### `select` + +The context `select` allows you to select and manipulate columns from a dataframe. In the simplest +case, each expression you provide will map to a column in the result dataframe: + +{{code_block('user-guide/getting-started','select',['select','alias','Expr.dt'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:select" +``` + +Polars also supports a feature called “expression expansion”, in which one expression acts as +shorthand for multiple expressions. In the example below, we use expression expansion to manipulate +the columns “weight” and “height” with a single expression. When using expression expansion you can +use `.name.suffix` to add a suffix to the names of the original columns: + +{{code_block('user-guide/getting-started','expression-expansion',['select','alias','Expr.name'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:expression-expansion" +``` + +You can check other sections of the user guide to learn more about +[basic operations](expressions/basic-operations.md) or +[column selections in expression expansion](expressions/expression-expansion.md). + +### `with_columns` + +The context `with_columns` is very similar to the context `select` but `with_columns` adds columns +to the dataframe instead of selecting them. Notice how the resulting dataframe contains the four +columns of the original dataframe plus the two new columns introduced by the expressions inside +`with_columns`: + +{{code_block('user-guide/getting-started','with_columns',['with_columns'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:with_columns" +``` + +In the example above we also decided to use named expressions instead of the method `alias` to +specify the names of the new columns. Other contexts like `select` and `group_by` also accept named +expressions. + +### `filter` + +The context `filter` allows us to create a second dataframe with a subset of the rows of the +original one: + +{{code_block('user-guide/getting-started','filter',['filter','Expr.dt'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:filter" +``` + +You can also provide multiple predicate expressions as separate parameters, which is more convenient +than putting them all together with `&`: + +{{code_block('user-guide/getting-started','filter-multiple',['filter','is_between'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:filter-multiple" +``` + +### `group_by` + +The context `group_by` can be used to group together the rows of the dataframe that share the same +value across one or more expressions. The example below counts how many people were born in each +decade: + +{{code_block('user-guide/getting-started','group_by',['group_by','alias','Expr.dt'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:group_by" +``` + +The keyword argument `maintain_order` forces Polars to present the resulting groups in the same +order as they appear in the original dataframe. This slows down the grouping operation but is used +here to ensure reproducibility of the examples. + +After using the context `group_by` we can use `agg` to compute aggregations over the resulting +groups: + +{{code_block('user-guide/getting-started','group_by-agg',['group_by','agg'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:group_by-agg" +``` + +### More complex queries + +Contexts and the expressions within can be chained to create more complex queries according to your +needs. In the example below we combine some of the contexts we have seen so far to create a more +complex query: + +{{code_block('user-guide/getting-started','complex',['group_by','agg','select','with_columns','Expr.str','Expr.list'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:complex" +``` + +## Combining dataframes + +Polars provides a number of tools to combine two dataframes. In this section, we show an example of +a join and an example of a concatenation. + +### Joining dataframes + +Polars provides many different join algorithms. The example below shows how to use a left outer join +to combine two dataframes when a column can be used as a unique identifier to establish a +correspondence between rows across the dataframes: + +{{code_block('user-guide/getting-started','join',['join'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:join" +``` + +Polars provides many different join algorithms that you can learn about in the +[joins section of the user guide](transformations/joins.md). + +### Concatenating dataframes + +Concatenating dataframes creates a taller or wider dataframe, depending on the method used. Assuming +we have a second dataframe with data from other people, we could use vertical concatenation to +create a taller dataframe: + +{{code_block('user-guide/getting-started','concat',['concat'])}} + +```python exec="on" result="text" session="getting-started" +--8<-- "python/user-guide/getting-started.py:concat" +``` + +Polars provides vertical and horizontal concatenation, as well as diagonal concatenation. You can +learn more about these in the +[concatenations section of the user guide](transformations/concatenation.md). diff --git a/docs/source/user-guide/gpu-support.md b/docs/source/user-guide/gpu-support.md new file mode 100644 index 000000000000..ba3ad5c241d7 --- /dev/null +++ b/docs/source/user-guide/gpu-support.md @@ -0,0 +1,205 @@ +# GPU Support [Open Beta] + +Polars provides an in-memory, GPU-accelerated execution engine for Python users of the Lazy API on +NVIDIA GPUs using [RAPIDS cuDF](https://docs.rapids.ai/api/cudf/stable/). This functionality is +available in Open Beta and is undergoing rapid development. + +### System Requirements + +- NVIDIA Volta™ or higher GPU with [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0+ +- CUDA 11 or CUDA 12 +- Linux or Windows Subsystem for Linux 2 (WSL2) + +See the [RAPIDS installation guide](https://docs.rapids.ai/install#system-req) for full details. + +### Installation + +You can install the GPU backend for Polars with a feature flag as part of a normal +[installation](installation.md). + +=== ":fontawesome-brands-python: Python" + +```bash +pip install polars[gpu] +``` + +!!! note Installation on a CUDA 11 system + + If you have CUDA 11, the installation line also needs the NVIDIA package index to get the CUDA 11 package. + + === ":fontawesome-brands-python: Python" + ```bash + pip install --extra-index-url=https://pypi.nvidia.com polars cudf-polars-cu11 + ``` + +### Usage + +Having built a query using the lazy API [as normal](lazy/index.md), GPU-enabled execution is +requested by running `.collect(engine="gpu")` instead of `.collect()`. + +{{ code_header("python", [], []) }} + +```python +--8<-- "python/user-guide/lazy/gpu.py:setup" + +result = q.collect(engine="gpu") +print(result) +``` + +```python exec="on" result="text" session="user-guide/lazy" +--8<-- "python/user-guide/lazy/gpu.py:setup" +--8<-- "python/user-guide/lazy/gpu.py:simple-result" +``` + +For more detailed control over the execution, for example to specify which GPU to use on a multi-GPU +node, we can provide a `GPUEngine` object. By default, the GPU engine will use a configuration +applicable to most use cases. + +{{ code_header("python", [], []) }} + +```python +--8<-- "python/user-guide/lazy/gpu.py:engine-setup" +result = q.collect(engine=pl.GPUEngine(device=1)) +print(result) +``` + +```python exec="on" result="text" session="user-guide/lazy" +--8<-- "python/user-guide/lazy/gpu.py:engine-setup" +--8<-- "python/user-guide/lazy/gpu.py:engine-result" +``` + +### How It Works + +When you use the GPU-accelerated engine, Polars creates and optimizes a query plan and dispatches to +a [RAPIDS](https://rapids.ai/) cuDF-based physical execution engine to compute the results on NVIDIA +GPUs. The final result is returned as a normal CPU-backed Polars dataframe. + +### What's Supported on the GPU? + +GPU support is currently in Open Beta and the engine is undergoing rapid development. The engine +currently supports many, but not all, of the core expressions and data types. + +Since expressions are composable, it's not feasible to list a full matrix of expressions supported +on the GPU. Instead, we provide a list of the high-level categories of expressions and interfaces +that are currently supported and not supported. + +#### Supported + +- LazyFrame API +- SQL API +- I/O from CSV, Parquet, ndjson, and in-memory CPU DataFrames. +- Operations on numeric, logical, string, and datetime types +- String processing +- Aggregations and grouped aggregations +- Joins +- Filters +- Missing data +- Concatenation + +#### Not Supported + +- Eager DataFrame API +- Streaming API +- Operations on categorical, struct, and list data types +- Rolling aggregations +- Time series resampling +- Timezones +- Folds +- User-defined functions +- JSON, Excel, and Database file formats + +#### Did my query use the GPU? + +The release of the GPU engine in Open Beta implies that we expect things to work well, but there are +still some rough edges we're working on. In particular the full breadth of the Polars expression API +is not yet supported. With fallback to the CPU, your query _should_ complete, but you might not +observe any change in the time it takes to execute. There are two ways to get more information on +whether the query ran on the GPU. + +When running in verbose mode, any queries that cannot execute on the GPU will issue a +`PerformanceWarning`: + +{{ code_header("python", [], []) }} + +```python +--8<-- "python/user-guide/lazy/gpu.py:fallback-setup" + +with pl.Config() as cfg: + cfg.set_verbose(True) + result = q.collect(engine="gpu") + +print(result) +``` + +```python exec="on" result="text" session="user-guide/lazy" +--8<-- "python/user-guide/lazy/gpu.py:fallback-setup" +print( + "PerformanceWarning: Query execution with GPU not supported, reason: \n" + ": Grouped rolling window not implemented" +) +print("# some details elided") +print() +print(q.collect()) +``` + +To disable fallback, and have the GPU engine raise an exception if a query is unsupported, we can +pass an appropriately configured `GPUEngine` object: + +{{ code_header("python", [], []) }} + +```python +q.collect(engine=pl.GPUEngine(raise_on_fail=True)) +``` + +```pytb +Traceback (most recent call last): + File "", line 1, in + File "/home/coder/third-party/polars/py-polars/polars/lazyframe/frame.py", line 2035, in collect + return wrap_df(ldf.collect(callback)) +polars.exceptions.ComputeError: 'cuda' conversion failed: NotImplementedError: Grouped rolling window not implemented +``` + +Currently, only the proximal cause of failure to execute on the GPU is reported, we plan to extend +this functionality to report all unsupported operations for a query. + +### Testing + +The Polars and NVIDIA RAPIDS teams run comprehensive unit and integration tests to ensure that the +GPU-accelerated Polars backend works smoothly. + +The **full** Polars test suite is run on every commit made to the GPU engine, ensuring consistency +of results. + +The GPU engine currently passes 99.2% of the Polars unit tests with CPU fallback enabled. Without +CPU fallback, the GPU engine passes 88.8% of the Polars unit tests. With fallback, there are +approximately 100 failing tests: around 40 of these fail due to mismatching debug output; there are +some cases where the GPU engine produces the a correct result but uses a different data type; the +remainder are cases where we do not correctly determine that a query is unsupported and therefore +fail at runtime, instead of falling back. + +### When Should I Use a GPU? + +Based on our benchmarking, you're most likely to observe speedups using the GPU engine when your +workflow's profile is dominated by grouped aggregations and joins. In contrast I/O bound queries +typically show similar performance on GPU and CPU. GPUs typically have less RAM than CPU systems, +therefore very large datasets will fail due to out of memory errors. Based on our testing, raw +datasets of 50-100 GiB fit (depending on the workflow) well with a GPU with 80GiB of memory. + +### CPU-GPU Interoperability + +Both the CPU and GPU engine use the Apache Arrow columnar memory specification, making it possible +to quickly move data between the CPU and GPU. Additionally, files written by one engine can be read +by the other engine. + +When using GPU mode, your workflow won't fail if something isn't supported. When you run +`collect(engine="gpu")`, the optimized query plan is inspected to see whether it can be executed on +the GPU. If it can't, it will transparently fall back to the standard Polars engine and run on the +CPU. + +GPU execution is only available in the Lazy API, so materialized DataFrames will reside in CPU +memory when the query execution finishes. + +### Providing feedback + +Please report issues, and missing features, on the Polars +[issue tracker](https://github.com/pola-rs/polars/issues). diff --git a/docs/source/user-guide/installation.md b/docs/source/user-guide/installation.md new file mode 100644 index 000000000000..0a5334625746 --- /dev/null +++ b/docs/source/user-guide/installation.md @@ -0,0 +1,271 @@ +# Installation + +Polars is a library and installation is as simple as invoking the package manager of the +corresponding programming language. + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars + + # Or for legacy CPUs without AVX2 support + pip install polars-lts-cpu + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars -F lazy + + # Or Cargo.toml + [dependencies] + polars = { version = "x", features = ["lazy", ...]} + ``` + +## Big Index + +By default, Polars dataframes are limited to $2^{32}$ rows (~4.3 billion). Increase this limit to +$2^{64}$ (~18 quintillion) by enabling the big index extension: + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars-u64-idx + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars -F bigidx + + # Or Cargo.toml + [dependencies] + polars = { version = "x", features = ["bigidx", ...] } + ``` + +## Legacy CPU + +To install Polars for Python on an old CPU without +[AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, run: + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars-lts-cpu + ``` + +## Importing + +To use the library, simply import it into your project: + +=== ":fontawesome-brands-python: Python" + + ``` python + import polars as pl + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` rust + use polars::prelude::*; + ``` + +## Feature flags + +By using the above command you install the core of Polars onto your system. However, depending on +your use case, you might want to install the optional dependencies as well. These are made optional +to minimize the footprint. The flags are different depending on the programming language. Throughout +the user guide we will mention when a functionality used requires an additional dependency. + +### Python + +```text +# For example +pip install 'polars[numpy,fsspec]' +``` + +#### All + +| Tag | Description | +| --- | ---------------------------------- | +| all | Install all optional dependencies. | + +#### GPU + +| Tag | Description | +| --- | --------------------------- | +| gpu | Run queries on NVIDIA GPUs. | + +!!! note + + See [GPU support](gpu-support.md) for more detailed instructions and + prerequisites. + +#### Interoperability + +| Tag | Description | +| -------- | -------------------------------------------------- | +| pandas | Convert data to and from pandas dataframes/series. | +| numpy | Convert data to and from NumPy arrays. | +| pyarrow | Convert data to and from PyArrow tables/arrays. | +| pydantic | Convert data from Pydantic models to Polars. | + +#### Excel + +| Tag | Description | +| ---------- | ------------------------------------------------ | +| calamine | Read from Excel files with the calamine engine. | +| openpyxl | Read from Excel files with the openpyxl engine. | +| xlsx2csv | Read from Excel files with the xlsx2csv engine. | +| xlsxwriter | Write to Excel files with the XlsxWriter engine. | +| excel | Install all supported Excel engines. | + +#### Database + +| Tag | Description | +| ---------- | ------------------------------------------------------------------------------------ | +| adbc | Read from and write to databases with the Arrow Database Connectivity (ADBC) engine. | +| connectorx | Read from databases with the ConnectorX engine. | +| sqlalchemy | Write to databases with the SQLAlchemy engine. | +| database | Install all supported database engines. | + +#### Cloud + +| Tag | Description | +| ------ | ------------------------------------------- | +| fsspec | Read from and write to remote file systems. | + +#### Other I/O + +| Tag | Description | +| --------- | ------------------------------------ | +| deltalake | Read from and write to Delta tables. | +| iceberg | Read from Apache Iceberg tables. | + +#### Other + +| Tag | Description | +| ----------- | ----------------------------------------------- | +| async | Collect LazyFrames asynchronously. | +| cloudpickle | Serialize user-defined functions. | +| graph | Visualize LazyFrames as a graph. | +| plot | Plot dataframes through the `plot` namespace. | +| style | Style dataframes through the `style` namespace. | +| timezone | Timezone support[^note]. | + +[^note]: Only needed if you are on Windows. + +### Rust + +```toml +# Cargo.toml +[dependencies] +polars = { version = "0.26.1", features = ["lazy", "temporal", "describe", "json", "parquet", "dtype-datetime"] } +``` + +The opt-in features are: + + + +- Additional data types: + - `dtype-date` + - `dtype-datetime` + - `dtype-time` + - `dtype-duration` + - `dtype-i8` + - `dtype-i16` + - `dtype-u8` + - `dtype-u16` + - `dtype-categorical` + - `dtype-struct` +- `lazy` - Lazy API: + - `regex` - Use regexes in column selection. + - `dot_diagram` - Create dot diagrams from lazy logical plans. +- `sql` - Pass SQL queries to Polars. +- `streaming` - Be able to process datasets that are larger than RAM. +- `random` - Generate arrays with randomly sampled values +- `ndarray`- Convert from `DataFrame` to `ndarray` +- `temporal` - Conversions between [Chrono](https://docs.rs/chrono/) and Polars for temporal data types +- `timezones` - Activate timezone support. +- `strings` - Extra string utilities for `StringChunked`: + - `string_pad` - for `pad_start`, `pad_end`, `zfill`. + - `string_to_integer` - for `parse_int`. +- `object` - Support for generic ChunkedArrays called `ObjectChunked` (generic over `T`). + These are downcastable from Series through the [Any](https://doc.rust-lang.org/std/any/index.html) trait. +- Performance related: + - `nightly` - Several nightly only features such as SIMD and specialization. + - `performant` - more fast paths, slower compile times. + - `bigidx` - Activate this feature if you expect >> $2^{32}$ rows. + This allows polars to scale up way beyond that by using `u64` as an index. + Polars will be a bit slower with this feature activated as many data structures + are less cache efficient. + - `cse` - Activate common subplan elimination optimization. +- IO related: + - `serde` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. + Can be used for JSON and more serde supported serialization formats. + - `serde-lazy` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. + Can be used for JSON and more serde supported serialization formats. + - `parquet` - Read Apache Parquet format. + - `json` - JSON serialization. + - `ipc` - Arrow's IPC format serialization. + - `decompress` - Automatically infer compression of csvs and decompress them. + Supported compressions: + - gzip + - zlib + - zstd +- Dataframe operations: + - `dynamic_group_by` - Group by based on a time window instead of predefined keys. + Also activates rolling window group by operations. + - `sort_multiple` - Allow sorting a dataframe on multiple columns. + - `rows` - Create dataframe from rows and extract rows from `dataframes`. + Also activates `pivot` and `transpose` operations. + - `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match. + - `cross_join` - Create the Cartesian product of two dataframes. + - `semi_anti_join` - SEMI and ANTI joins. + - `row_hash` - Utility to hash dataframe rows to `UInt64Chunked`. + - `diagonal_concat` - Diagonal concatenation thereby combining different schemas. + - `dataframe_arithmetic` - Arithmetic between dataframes and other dataframes or series. + - `partition_by` - Split into multiple dataframes partitioned by groups. +- Series/expression operations: + - `is_in` - Check for membership in Series. + - `zip_with` - Zip two `Series` / `ChunkedArray`s. + - `round_series` - round underlying float types of series. + - `repeat_by` - Repeat element in an array a number of times specified by another array. + - `is_first_distinct` - Check if element is first unique value. + - `is_last_distinct` - Check if element is last unique value. + - `checked_arithmetic` - checked arithmetic returning `None` on invalid operations. + - `dot_product` - Dot/inner product on series and expressions. + - `concat_str` - Concatenate string data in linear time. + - `reinterpret` - Utility to reinterpret bits to signed/unsigned. + - `take_opt_iter` - Take from a series with `Iterator>`. + - `mode` - Return the most frequently occurring value(s). + - `cum_agg` - `cum_sum`, `cum_min`, and `cum_max`, aggregations. + - `rolling_window` - rolling window functions, like `rolling_mean`. + - `interpolate` - Interpolate intermediate `None` values. + - `extract_jsonpath` - [Run `jsonpath` queries on `StringChunked`](https://goessner.net/articles/JsonPath/). + - `list` - List utils: + - `list_gather` - take sublist by multiple indices. + - `rank` - Ranking algorithms. + - `moment` - Kurtosis and skew statistics. + - `ewma` - Exponential moving average windows. + - `abs` - Get absolute values of series. + - `arange` - Range operation on series. + - `product` - Compute the product of a series. + - `diff` - `diff` operation. + - `pct_change` - Compute change percentages. + - `unique_counts` - Count unique values in expressions. + - `log` - Logarithms for series. + - `list_to_struct` - Convert `List` to `Struct` data types. + - `list_count` - Count elements in lists. + - `list_eval` - Apply expressions over list elements. + - `cumulative_eval` - Apply expressions over cumulatively increasing windows. + - `arg_where` - Get indices where condition holds. + - `search_sorted` - Find indices where elements should be inserted to maintain order. + - `offset_by` - Add an offset to dates that take months and leap years into account. + - `trigonometry` - Trigonometric functions. + - `sign` - Compute the element-wise sign of a series. + - `propagate_nans` - `NaN`-propagating min/max aggregations. +- Dataframe pretty printing: + - `fmt` - Activate dataframe formatting. + + diff --git a/docs/source/user-guide/io/bigquery.md b/docs/source/user-guide/io/bigquery.md new file mode 100644 index 000000000000..21287cd448d2 --- /dev/null +++ b/docs/source/user-guide/io/bigquery.md @@ -0,0 +1,19 @@ +# Google BigQuery + +To read or write from GBQ, additional dependencies are needed: + +=== ":fontawesome-brands-python: Python" + +```shell +$ pip install google-cloud-bigquery +``` + +## Read + +We can load a query into a `DataFrame` like this: + +{{code_block('user-guide/io/bigquery','read',['from_arrow'])}} + +## Write + +{{code_block('user-guide/io/bigquery','write',[])}} diff --git a/docs/source/user-guide/io/cloud-storage.md b/docs/source/user-guide/io/cloud-storage.md new file mode 100644 index 000000000000..476111577ff1 --- /dev/null +++ b/docs/source/user-guide/io/cloud-storage.md @@ -0,0 +1,83 @@ +# Cloud storage + +Polars can read and write to AWS S3, Azure Blob Storage and Google Cloud Storage. The API is the +same for all three storage providers. + +To read from cloud storage, additional dependencies may be needed depending on the use case and +cloud storage provider: + +=== ":fontawesome-brands-python: Python" + + ```shell + $ pip install fsspec s3fs adlfs gcsfs + ``` + +=== ":fontawesome-brands-rust: Rust" + + ```shell + $ cargo add aws_sdk_s3 aws_config tokio --features tokio/full + ``` + +## Reading from cloud storage + +Polars supports reading Parquet, CSV, IPC and NDJSON files from cloud storage: + +{{code_block('user-guide/io/cloud-storage','read_parquet',['read_parquet','read_csv','read_ipc'])}} + +## Scanning from cloud storage with query optimisation + +Using `pl.scan_*` functions to read from cloud storage can benefit from +[predicate and projection pushdowns](../lazy/optimizations.md), where the query optimizer will apply +them before the file is downloaded. This can significantly reduce the amount of data that needs to +be downloaded. The query evaluation is triggered by calling `collect`. + +{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}} + +## Cloud authentication + +Polars is able to automatically load default credential configurations for some cloud providers. For +cases when this does not happen, it is possible to manually configure the credentials for Polars to +use for authentication. This can be done in a few ways: + +### Using `storage_options`: + +- Credentials can be passed as configuration keys in a dict with the `storage_options` parameter: + +{{code_block('user-guide/io/cloud-storage','scan_parquet_storage_options_aws',['scan_parquet'])}} + +### Using one of the available `CredentialProvider*` utility classes + +- There may be a utility class `pl.CredentialProvider*` that provides the required authentication + functionality. For example, `pl.CredentialProviderAWS` supports selecting AWS profiles, as well as + assuming an IAM role: + +{{code_block('user-guide/io/cloud-storage','credential_provider_class',['scan_parquet', +'CredentialProviderAWS'])}} + +### Using a custom `credential_provider` function + +- Some environments may require custom authentication logic (e.g. AWS IAM role-chaining). For these + cases a Python function can be provided for Polars to use to retrieve credentials: + +{{code_block('user-guide/io/cloud-storage','credential_provider_custom_func',['scan_parquet'])}} + +- Example for Azure: + +{{code_block('user-guide/io/cloud-storage','credential_provider_custom_func_azure',['scan_parquet', +'CredentialProviderAzure'])}} + +## Scanning with PyArrow + +We can also scan from cloud storage using PyArrow. This is particularly useful for partitioned +datasets such as Hive partitioning. + +We first create a PyArrow dataset and then create a `LazyFrame` from the dataset. + +{{code_block('user-guide/io/cloud-storage','scan_pyarrow_dataset',['scan_pyarrow_dataset'])}} + +## Writing to cloud storage + +We can write a `DataFrame` to cloud storage in Python using s3fs for S3, adlfs for Azure Blob +Storage and gcsfs for Google Cloud Storage. In this example, we write a Parquet file to S3. + +{{code_block('user-guide/io/cloud-storage','write_parquet',['write_parquet'])}} diff --git a/docs/source/user-guide/io/csv.md b/docs/source/user-guide/io/csv.md new file mode 100644 index 000000000000..0681daddbd5e --- /dev/null +++ b/docs/source/user-guide/io/csv.md @@ -0,0 +1,21 @@ +# CSV + +## Read & write + +Reading a CSV file should look familiar: + +{{code_block('user-guide/io/csv','read',['read_csv'])}} + +Writing a CSV file is similar with the `write_csv` function: + +{{code_block('user-guide/io/csv','write',['write_csv'])}} + +## Scan + +Polars allows you to _scan_ a CSV input. Scanning delays the actual parsing of the file and instead +returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/csv','scan',['scan_csv'])}} + +If you want to know why this is desirable, you can read more about these Polars optimizations +[here](../concepts/lazy-api.md). diff --git a/docs/source/user-guide/io/database.md b/docs/source/user-guide/io/database.md new file mode 100644 index 000000000000..e9b3a6075799 --- /dev/null +++ b/docs/source/user-guide/io/database.md @@ -0,0 +1,107 @@ +# Databases + +## Read from a database + +Polars can read from a database using the `pl.read_database_uri` and `pl.read_database` functions. + +### Difference between `read_database_uri` and `read_database` + +Use `pl.read_database_uri` if you want to specify the database connection with a connection string +called a `uri`. For example, the following snippet shows a query to read all columns from the `foo` +table in a Postgres database where we use the `uri` to connect: + +{{code_block('user-guide/io/database','read_uri',['read_database_uri'])}} + +On the other hand, use `pl.read_database` if you want to connect via a connection engine created +with a library like SQLAlchemy. + +{{code_block('user-guide/io/database','read_cursor',['read_database'])}} + +Note that `pl.read_database_uri` is likely to be faster than `pl.read_database` if you are using a +SQLAlchemy or DBAPI2 connection as these connections may load the data row-wise into Python before +copying the data again to the column-wise Apache Arrow format. + +### Engines + +Polars doesn't manage connections and data transfer from databases by itself. Instead, external +libraries (known as _engines_) handle this. + +When using `pl.read_database`, you specify the engine when you create the connection object. When +using `pl.read_database_uri`, you can specify one of two engines to read from the database: + +- [ConnectorX](https://github.com/sfu-db/connector-x) and +- [ADBC](https://arrow.apache.org/docs/format/ADBC.html) + +Both engines have native support for Apache Arrow and so can read data directly into a Polars +`DataFrame` without copying the data. + +#### ConnectorX + +ConnectorX is the default engine and +[supports numerous databases](https://github.com/sfu-db/connector-x#sources) including Postgres, +Mysql, SQL Server and Redshift. ConnectorX is written in Rust and stores data in Arrow format to +allow for zero-copy to Polars. + +To read from one of the supported databases with `ConnectorX` you need to activate the additional +dependency `ConnectorX` when installing Polars or install it manually with + +```shell +$ pip install connectorx +``` + +#### ADBC + +ADBC (Arrow Database Connectivity) is an engine supported by the Apache Arrow project. ADBC aims to +be both an API standard for connecting to databases and libraries implementing this standard in a +range of languages. + +It is still early days for ADBC so support for different databases is limited. At present, drivers +for ADBC are only available for [Postgres](https://pypi.org/project/adbc-driver-postgresql/), +[SQLite](https://pypi.org/project/adbc-driver-sqlite/) and +[Snowflake](https://pypi.org/project/adbc-driver-snowflake/). To install ADBC, you need to install +the driver for your database. For example, to install the driver for SQLite, you run: + +```shell +$ pip install adbc-driver-sqlite +``` + +As ADBC is not the default engine, you must specify the engine as an argument to +`pl.read_database_uri`. + +{{code_block('user-guide/io/database','adbc',['read_database_uri'])}} + +## Write to a database + +We can write to a database with Polars using the `pl.write_database` function. + +### Engines + +As with reading from a database above, Polars uses an _engine_ to write to a database. The currently +supported engines are: + +- [SQLAlchemy](https://www.sqlalchemy.org/) and +- Arrow Database Connectivity (ADBC) + +#### SQLAlchemy + +With the default engine SQLAlchemy you can write to any database supported by SQLAlchemy. To use +this engine you need to install SQLAlchemy and Pandas + +```shell +$ pip install SQLAlchemy pandas +``` + +In this example, we write the `DataFrame` to a table called `records` in the database + +{{code_block('user-guide/io/database','write',['write_database'])}} + +In the SQLAlchemy approach, Polars converts the `DataFrame` to a Pandas `DataFrame` backed by +PyArrow and then uses SQLAlchemy methods on a Pandas `DataFrame` to write to the database. + +#### ADBC + +ADBC can also be used to write to a database. Writing is supported for the same databases that +support reading with ADBC. As shown above, you need to install the appropriate ADBC driver for your +database. + +{{code_block('user-guide/io/database','write_adbc',['write_database'])}} diff --git a/docs/source/user-guide/io/excel.md b/docs/source/user-guide/io/excel.md new file mode 100644 index 000000000000..ce0a9a5ee208 --- /dev/null +++ b/docs/source/user-guide/io/excel.md @@ -0,0 +1,64 @@ +# Excel + +Polars can read and write to Excel files from Python. From a performance perspective, we recommend +using other formats if possible, such as Parquet or CSV files. + +## Read + +Polars does not have a native Excel reader. Instead, it uses an external library called an "engine" +to parse Excel files into a form that Polars can parse. The available engines are: + +- fastexcel: This engine is based on the Rust [calamine](https://github.com/tafia/calamine) crate + and is (by far) the fastest reader. +- xlsx2csv: This reader parses the .xlsx file to an in-memory CSV that Polars then reads with its + own CSV reader. +- openpyxl: Typically slower than xls2csv, but can provide more flexibility for files that are + difficult to parse. + +We recommend working with the default fastexcel engine. The xlsx2csv and openpyxl engines are slower +but may have more features for parsing tricky data. These engines may be helpful if the fastexcel +reader does not work for a specific Excel file. + +To use one of these engines, the appropriate Python package must be installed as an additional +dependency. + +=== ":fontawesome-brands-python: Python" + + ```shell + $ pip install fastexcel xlsx2csv openpyxl + ``` + +The default engine for reading .xslx files is fastexcel. This engine uses the Rust calamine crate to +read .xslx files into an Apache Arrow in-memory representation that Polars can read without needing +to copy the data. + +{{code_block('user-guide/io/excel','read',['read_excel'])}} + +We can specify the sheet name to read with the `sheet_name` argument. If we do not specify a sheet +name, the first sheet will be read. + +{{code_block('user-guide/io/excel','read_sheet_name',['read_excel'])}} + +## Write + +We need the xlswriter library installed as an additional dependency to write to Excel files. + +=== ":fontawesome-brands-python: Python" + + ```shell + $ pip install xlsxwriter + ``` + +Writing to Excel files is not currently available in Rust Polars, though it is possible to +[use this crate](https://docs.rs/crate/xlsxwriter/latest) to write to Excel files from Rust. + +Writing a `DataFrame` to an Excel file is done with the `write_excel` method: + +{{code_block('user-guide/io/excel','write',['write_excel'])}} + +The name of the worksheet can be specified with the `worksheet` argument. + +{{code_block('user-guide/io/excel','write_sheet_name',['write_excel'])}} + +Polars can create rich Excel files with multiple sheets and formatting. For more details, see the +API docs for `write_excel`. diff --git a/docs/source/user-guide/io/hive.md b/docs/source/user-guide/io/hive.md new file mode 100644 index 000000000000..d0de9ada2c59 --- /dev/null +++ b/docs/source/user-guide/io/hive.md @@ -0,0 +1,101 @@ +## Scanning hive partitioned data + +Polars supports scanning hive partitioned parquet and IPC datasets, with planned support for other +formats in the future. + +Hive partition parsing is enabled by default if `scan_parquet` receives a single directory path, +otherwise it is disabled by default. This can be explicitly configured using the `hive_partitioning` +parameter. + +### Scanning a hive directory + +For this example the following directory structure is used: + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:init_paths" +``` + +Simply pass the directory to `scan_parquet`, and all files will be loaded with the hive parts in the +path included in the output: + +{{code_block('user-guide/io/hive','scan_dir',['scan_parquet'])}} + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:scan_dir" +``` + +### Handling mixed files + +Passing a directory to `scan_parquet` may not work if there are files with different extensions in +the directory. + +For this example the following directory structure is used: + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:show_mixed_paths" +``` + +{{code_block('user-guide/io/hive','scan_dir_err',['scan_parquet'])}} + +The above fails as `description.txt` is not a valid parquet file: + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:scan_dir_err" +``` + +In this situation, a glob pattern can be used to be more specific about which files to load. Note +that `hive_partitioning` must explicitly set to `True`: + +{{code_block('user-guide/io/hive','scan_glob',['scan_parquet'])}} + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:scan_glob" +``` + +### Scanning file paths with hive parts + +`hive_partitioning` is not enabled by default for file paths: + +{{code_block('user-guide/io/hive','scan_file_no_hive',['scan_parquet'])}} + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:scan_file_no_hive" +``` + +Pass `hive_partitioning=True` to enable hive partition parsing: + +{{code_block('user-guide/io/hive','scan_file_hive',['scan_parquet'])}} + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:scan_file_hive" +``` + +## Writing hive partitioned data + +> Note: The following functionality is considered _unstable_, and is subject to change. + +Polars supports writing hive partitioned parquet datasets, with planned support for other formats. + +### Example + +For this example the following DataFrame is used: + +{{code_block('user-guide/io/hive','write_parquet_partitioned_show_data',[])}} + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:write_parquet_partitioned_show_data" +``` + +We will write it to a hive-partitioned parquet dataset, partitioned by the columns `a` and `b`: + +{{code_block('user-guide/io/hive','write_parquet_partitioned',['write_parquet'])}} + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:write_parquet_partitioned" +``` + +The output is a hive partitioned parquet dataset with the following paths: + +```python exec="on" result="text" session="user-guide/io/hive" +--8<-- "python/user-guide/io/hive.py:write_parquet_partitioned_show_paths" +``` diff --git a/docs/source/user-guide/io/hugging-face.md b/docs/source/user-guide/io/hugging-face.md new file mode 100644 index 000000000000..692586137831 --- /dev/null +++ b/docs/source/user-guide/io/hugging-face.md @@ -0,0 +1,93 @@ +# Hugging Face + +## Scanning datasets from Hugging Face + +All cloud-enabled scan functions, and their `read_` counterparts transparently support scanning from +Hugging Face: + +| Scan | Read | +| --------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | +| [scan_parquet](https://docs.pola.rs/api/python/stable/reference/api/polars.scan_parquet.html) | [read_parquet](https://docs.pola.rs/api/python/stable/reference/api/polars.read_parquet.html) | +| [scan_csv](https://docs.pola.rs/api/python/stable/reference/api/polars.scan_csv.html) | [read_csv](https://docs.pola.rs/api/python/stable/reference/api/polars.read_csv.html) | +| [scan_ndjson](https://docs.pola.rs/api/python/stable/reference/api/polars.scan_ndjson.html) | [read_ndjson](https://docs.pola.rs/api/python/stable/reference/api/polars.read_ndjson.html) | +| [scan_ipc](https://docs.pola.rs/api/python/stable/reference/api/polars.scan_ipc.html) | [read_ipc](https://docs.pola.rs/api/python/stable/reference/api/polars.read_ipc.html) | + +### Path format + +To scan from Hugging Face, a `hf://` path can be passed to the scan functions. The `hf://` path +format is defined as `hf://BUCKET/REPOSITORY@REVISION/PATH`, where: + +- `BUCKET` is one of `datasets` or `spaces` +- `REPOSITORY` is the location of the repository, this is usually in the format of + `username/repo_name`. A branch can also be optionally specified by appending `@branch` +- `REVISION` is the name of the branch (or commit) to use. This is optional and defaults to `main` + if not given. +- `PATH` is a file or directory path, or a glob pattern from the repository root. + +Example `hf://` paths: + +| Path | Path components | +| ----------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| hf://datasets/nameexhaustion/polars-docs/iris.csv | Bucket: datasets
Repository: nameexhaustion/polars-docs
Branch: main
Path: iris.csv
[Web URL](https://huggingface.co/datasets/nameexhaustion/polars-docs/tree/main/) | +| hf://datasets/nameexhaustion/polars-docs@foods/\*.csv | Bucket: datasets
Repository: nameexhaustion/polars-docs
Branch: foods
Path: \*.csv
[Web URL](https://huggingface.co/datasets/nameexhaustion/polars-docs/tree/foods/) | +| hf://datasets/nameexhaustion/polars-docs/hive_dates/ | Bucket: datasets
Repository: nameexhaustion/polars-docs
Branch: main
Path: hive_dates/
[Web URL](https://huggingface.co/datasets/nameexhaustion/polars-docs/tree/main/hive_dates/) | +| hf://spaces/nameexhaustion/polars-docs/orders.feather | Bucket: spaces
Repository: nameexhaustion/polars-docs
Branch: main
Path: orders.feather
[Web URL](https://huggingface.co/spaces/nameexhaustion/polars-docs/tree/main/) | + +### Authentication + +A Hugging Face API key can be passed to Polars to access private locations using either of the +following methods: + +- Passing a `token` in `storage_options` to the scan function, e.g. + `scan_parquet(..., storage_options={'token': ''})` +- Setting the `HF_TOKEN` environment variable, e.g. `export HF_TOKEN=` + +### Examples + +#### CSV + +```python exec="on" result="text" session="user-guide/io/hugging-face" +--8<-- "python/user-guide/io/hugging-face.py:setup" +``` + +{{code_block('user-guide/io/hugging-face','scan_iris_csv',['scan_csv'])}} + +```python exec="on" result="text" session="user-guide/io/hugging-face" +--8<-- "python/user-guide/io/hugging-face.py:scan_iris_repr" +``` + +See this file at +[https://huggingface.co/datasets/nameexhaustion/polars-docs/blob/main/iris.csv](https://huggingface.co/datasets/nameexhaustion/polars-docs/blob/main/iris.csv) + +#### NDJSON + +{{code_block('user-guide/io/hugging-face','scan_iris_ndjson',['scan_ndjson'])}} + +```python exec="on" result="text" session="user-guide/io/hugging-face" +--8<-- "python/user-guide/io/hugging-face.py:scan_iris_repr" +``` + +See this file at +[https://huggingface.co/datasets/nameexhaustion/polars-docs/blob/main/iris.jsonl](https://huggingface.co/datasets/nameexhaustion/polars-docs/blob/main/iris.jsonl) + +#### Parquet + +{{code_block('user-guide/io/hugging-face','scan_parquet_hive_repr',['scan_parquet'])}} + +```python exec="on" result="text" session="user-guide/io/hugging-face" +--8<-- "python/user-guide/io/hugging-face.py:scan_parquet_hive_repr" +``` + +See this folder at +[https://huggingface.co/datasets/nameexhaustion/polars-docs/tree/main/hive_dates/](https://huggingface.co/datasets/nameexhaustion/polars-docs/tree/main/hive_dates/) + +#### IPC + +{{code_block('user-guide/io/hugging-face','scan_ipc',['scan_ipc'])}} + +```python exec="on" result="text" session="user-guide/io/hugging-face" +--8<-- "python/user-guide/io/hugging-face.py:scan_ipc_repr" +``` + +See this file at +[https://huggingface.co/spaces/nameexhaustion/polars-docs/blob/main/orders.feather](https://huggingface.co/spaces/nameexhaustion/polars-docs/blob/main/orders.feather) diff --git a/docs/source/user-guide/io/index.md b/docs/source/user-guide/io/index.md new file mode 100644 index 000000000000..a4f56d188f81 --- /dev/null +++ b/docs/source/user-guide/io/index.md @@ -0,0 +1,15 @@ +# IO + +Reading and writing your data is crucial for a DataFrame library. In this chapter you will learn +more on how to read and write to different file formats that are supported by Polars. + +- [CSV](csv.md) +- [Excel](excel.md) +- [Parquet](parquet.md) +- [Json](json.md) +- [Multiple](multiple.md) +- [Hive](hive.md) +- [Database](database.md) +- [Cloud storage](cloud-storage.md) +- [Google Big Query](bigquery.md) +- [Hugging Face](hugging-face.md) diff --git a/docs/source/user-guide/io/json.md b/docs/source/user-guide/io/json.md new file mode 100644 index 000000000000..043ffe48fd19 --- /dev/null +++ b/docs/source/user-guide/io/json.md @@ -0,0 +1,31 @@ +# JSON files + +Polars can read and write both standard JSON and newline-delimited JSON (NDJSON). + +## Read + +### JSON + +Reading a JSON file should look familiar: + +{{code_block('user-guide/io/json','read',['read_json'])}} + +### Newline Delimited JSON + +JSON objects that are delimited by newlines can be read into Polars in a much more performant way +than standard json. + +Polars can read an NDJSON file into a `DataFrame` using the `read_ndjson` function: + +{{code_block('user-guide/io/json','readnd',['read_ndjson'])}} + +## Write + +{{code_block('user-guide/io/json','write',['write_json','write_ndjson'])}} + +## Scan + +Polars allows you to _scan_ a JSON input **only for newline delimited json**. Scanning delays the +actual parsing of the file and instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/json','scan',['scan_ndjson'])}} diff --git a/docs/source/user-guide/io/multiple.md b/docs/source/user-guide/io/multiple.md new file mode 100644 index 000000000000..ab21409b254c --- /dev/null +++ b/docs/source/user-guide/io/multiple.md @@ -0,0 +1,40 @@ +## Dealing with multiple files. + +Polars can deal with multiple files differently depending on your needs and memory strain. + +Let's create some files to give us some context: + +{{code_block('user-guide/io/multiple','create',['write_csv'])}} + +## Reading into a single `DataFrame` + +To read multiple files into a single `DataFrame`, we can use globbing patterns: + +{{code_block('user-guide/io/multiple','read',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:create" +--8<-- "python/user-guide/io/multiple.py:read" +``` + +To see how this works we can take a look at the query plan. Below we see that all files are read +separately and concatenated into a single `DataFrame`. Polars will try to parallelize the reading. + +{{code_block('user-guide/io/multiple','graph',['show_graph'])}} + +```python exec="on" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:creategraph" +``` + +## Reading and processing in parallel + +If your files don't have to be in a single table you can also build a query plan for each file and +execute them in parallel on the Polars thread pool. + +All query plan execution is embarrassingly parallel and doesn't require any communication. + +{{code_block('user-guide/io/multiple','glob',['scan_csv'])}} + +```python exec="on" result="text" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:glob" +``` diff --git a/docs/source/user-guide/io/parquet.md b/docs/source/user-guide/io/parquet.md new file mode 100644 index 000000000000..2906a54e3f78 --- /dev/null +++ b/docs/source/user-guide/io/parquet.md @@ -0,0 +1,34 @@ +# Parquet + +Loading or writing [`Parquet` files](https://parquet.apache.org/) is lightning fast as the layout of +data in a Polars `DataFrame` in memory mirrors the layout of a Parquet file on disk in many +respects. + +Unlike CSV, Parquet is a columnar format. This means that the data is stored in columns rather than +rows. This is a more efficient way of storing data as it allows for better compression and faster +access to data. + +## Read + +We can read a `Parquet` file into a `DataFrame` using the `read_parquet` function: + +{{code_block('user-guide/io/parquet','read',['read_parquet'])}} + +## Write + +{{code_block('user-guide/io/parquet','write',['write_parquet'])}} + +## Scan + +Polars allows you to _scan_ a `Parquet` input. Scanning delays the actual parsing of the file and +instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/parquet','scan',['scan_parquet'])}} + +If you want to know why this is desirable, you can read more about those Polars optimizations +[here](../concepts/lazy-api.md). + +When we scan a `Parquet` file stored in the cloud, we can also apply predicate and projection +pushdowns. This can significantly reduce the amount of data that needs to be downloaded. For +scanning a Parquet file in the cloud, see +[Cloud storage](cloud-storage.md/#scanning-from-cloud-storage-with-query-optimisation). diff --git a/docs/source/user-guide/io/sheets_colab.md b/docs/source/user-guide/io/sheets_colab.md new file mode 100644 index 000000000000..9657b21dfbf4 --- /dev/null +++ b/docs/source/user-guide/io/sheets_colab.md @@ -0,0 +1,47 @@ +# Google Sheets (via Colab) + +Google Colab provides a utility class to read from and write to Google Sheets. + +## Opening and reading from a sheet + +We can open existing sheets by initializing `sheets.InteractiveSheet` with either: + +- the `url` parameter, for example + https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/ +- the `sheet_id` parameter for example 1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms + +By default the left-most worksheets will be used, we can change this by providing either +`worksheet_id` or `worksheet_name`. + +The first time in each session that we use `InteractiveSheet` we will need to give Colab permission +to edit our drive assets on our behalf. + +{{code_block('user-guide/io/sheets_colab','open',[])}} + +## Creating a new sheet + +When you don't provide the source of the spreadsheet one will be created for you. + +{{code_block('user-guide/io/sheets_colab','create_title',[])}} + +When you pass the `df` parameter the data will be written to the sheet immediately. + +{{code_block('user-guide/io/sheets_colab','create_df',[])}} + +## Writing to a sheet + +By default the `update` method will clear the worksheet and write the dataframe in the top left +corner. + +{{code_block('user-guide/io/sheets_colab','update',[])}} + +We can modify where the data is written with the `location` parameter and whether the worksheet is +cleared before with `clear`. + +{{code_block('user-guide/io/sheets_colab','update_loc',[])}} + +A good way to write multiple dataframes onto a worksheet in a loop is: + +{{code_block('user-guide/io/sheets_colab','update_loop',[])}} + +This clears the worksheet then writes the dataframes next to each other, one every five columns. diff --git a/docs/source/user-guide/lazy/execution.md b/docs/source/user-guide/lazy/execution.md new file mode 100644 index 000000000000..f35e52be8a83 --- /dev/null +++ b/docs/source/user-guide/lazy/execution.md @@ -0,0 +1,90 @@ +# Query execution + +Our example query on the Reddit dataset is: + +{{code_block('user-guide/lazy/execution','df',['scan_csv'])}} + +If we were to run the code above on the Reddit CSV the query would not be evaluated. Instead Polars +takes each line of code, adds it to the internal query graph and optimizes the query graph. + +When we execute the code Polars executes the optimized query graph by default. + +### Execution on the full dataset + +We can execute our query on the full dataset by calling the `.collect` method on the query. + +{{code_block('user-guide/lazy/execution','collect',['scan_csv','collect'])}} + +```text +shape: (14_029, 6) +┌─────────┬───────────────────────────┬─────────────┬────────────┬───────────────┬────────────┐ +│ id ┆ name ┆ created_utc ┆ updated_on ┆ comment_karma ┆ link_karma │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═════════╪═══════════════════════════╪═════════════╪════════════╪═══════════════╪════════════╡ +│ 6 ┆ TAOJIANLONG_JASONBROKEN ┆ 1397113510 ┆ 1536527864 ┆ 4 ┆ 0 │ +│ 17 ┆ SSAIG_JASONBROKEN ┆ 1397113544 ┆ 1536527864 ┆ 1 ┆ 0 │ +│ 19 ┆ FDBVFDSSDGFDS_JASONBROKEN ┆ 1397113552 ┆ 1536527864 ┆ 3 ┆ 0 │ +│ 37 ┆ IHATEWHOWEARE_JASONBROKEN ┆ 1397113636 ┆ 1536527864 ┆ 61 ┆ 0 │ +│ … ┆ … ┆ … ┆ … ┆ … ┆ … │ +│ 1229384 ┆ DSFOX ┆ 1163177415 ┆ 1536497412 ┆ 44411 ┆ 7917 │ +│ 1229459 ┆ NEOCARTY ┆ 1163177859 ┆ 1536533090 ┆ 40 ┆ 0 │ +│ 1229587 ┆ TEHSMA ┆ 1163178847 ┆ 1536497412 ┆ 14794 ┆ 5707 │ +│ 1229621 ┆ JEREMYLOW ┆ 1163179075 ┆ 1536497412 ┆ 411 ┆ 1063 │ +└─────────┴───────────────────────────┴─────────────┴────────────┴───────────────┴────────────┘ +``` + +Above we see that from the 10 million rows there are 14,029 rows that match our predicate. + +With the default `collect` method Polars processes all of your data as one batch. This means that +all the data has to fit into your available memory at the point of peak memory usage in your query. + +!!! warning "Reusing `LazyFrame` objects" + + Remember that `LazyFrame`s are query plans i.e. a promise on computation and is not guaranteed to cache common subplans. This means that every time you reuse it in separate downstream queries after it is defined, it is computed all over again. If you define an operation on a `LazyFrame` that doesn't maintain row order (such as a `group_by`), then the order will also change every time it is run. To avoid this, use `maintain_order=True` arguments for such operations. + +### Execution on larger-than-memory data + +If your data requires more memory than you have available Polars may be able to process the data in +batches using _streaming_ mode. To use streaming mode you simply pass the `engine="streaming"` +argument to `collect` + +{{code_block('user-guide/lazy/execution','stream',['scan_csv','collect'])}} + +### Execution on a partial dataset + +While you're writing, optimizing or checking your query on a large dataset, querying all available +data may lead to a slow development process. + +Instead, you can scan a subset of your partitions or use `.head`/`.collect` at the beginning and end +of your query, respectively. Keep in mind that the results of aggregations and filters on subsets of +your data may not be representative of the result you would get on the full data. + +{{code_block('user-guide/lazy/execution','partial',['scan_csv','collect','head'])}} + +```text +shape: (1, 6) +┌─────┬─────────────────────────┬─────────────┬────────────┬───────────────┬────────────┐ +│ id ┆ name ┆ created_utc ┆ updated_on ┆ comment_karma ┆ link_karma │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═════╪═════════════════════════╪═════════════╪════════════╪═══════════════╪════════════╡ +│ 6 ┆ TAOJIANLONG_JASONBROKEN ┆ 1397113510 ┆ 1536527864 ┆ 4 ┆ 0 │ +└─────┴─────────────────────────┴─────────────┴────────────┴───────────────┴────────────┘ +``` + +## Diverging queries + +It is very common that a query diverges at one point. In these cases it is recommended to use +`collect_all` as they will ensure that diverging queries execute only once. + +```python +# Some expensive LazyFrame +lf: LazyFrame + +lf_1 = LazyFrame.select(pl.all().sum()) + +lf_2 = lf.some_other_computation() + +pl.collect_all([lf_1, lf_2]) # this will execute lf only once! +``` diff --git a/docs/source/user-guide/lazy/gpu.md b/docs/source/user-guide/lazy/gpu.md new file mode 100644 index 000000000000..e58575fb2b36 --- /dev/null +++ b/docs/source/user-guide/lazy/gpu.md @@ -0,0 +1,25 @@ +# GPU Support + +Polars provides an in-memory, GPU-accelerated execution engine for the Lazy API in Python using +[RAPIDS cuDF](https://docs.rapids.ai/api/cudf/stable/) on NVIDIA GPUs. This functionality is +available in Open Beta, is undergoing rapid development, and is currently a single GPU +implementation. + +If you install Polars with the [GPU feature flag](../installation.md), you can trigger GPU-based +execution by running `.collect(engine="gpu")` instead of `.collect()`. + +{{ code_header("python", [], []) }} + +```python +--8<-- "python/user-guide/lazy/gpu.py:setup" + +result = q.collect(engine="gpu") +print(result) +``` + +```python exec="on" result="text" session="user-guide/lazy" +--8<-- "python/user-guide/lazy/gpu.py:setup" +--8<-- "python/user-guide/lazy/gpu.py:simple-result" +``` + +Learn more in the [GPU Support guide](../gpu-support.md). diff --git a/docs/source/user-guide/lazy/index.md b/docs/source/user-guide/lazy/index.md new file mode 100644 index 000000000000..5443e1349c10 --- /dev/null +++ b/docs/source/user-guide/lazy/index.md @@ -0,0 +1,13 @@ +# Lazy + +The Lazy chapter is a guide for working with `LazyFrames`. It covers the functionalities like how to +use it and how to optimise it. You can also find more information about the query plan or gain more +insight in the streaming capabilities. + +- [Using lazy API](using.md) +- [Optimisations](optimizations.md) +- [Schemas](schemas.md) +- [Query plan](query-plan.md) +- [Execution](execution.md) +- [Sources & Sinks](sources_sinks.md) +- [GPU Support](gpu.md) diff --git a/docs/source/user-guide/lazy/multiplexing.md b/docs/source/user-guide/lazy/multiplexing.md new file mode 100644 index 000000000000..28bb42d74ec6 --- /dev/null +++ b/docs/source/user-guide/lazy/multiplexing.md @@ -0,0 +1,86 @@ +# Multiplexing queries + +
+```python exec="on" result="text" session="user-guide/lazy/multiplexing" +--8<-- "python/user-guide/lazy/multiplexing.py:setup" +``` +
+ +In the [Sources and Sinks](./sources_sinks.md) page, we already discussed multiplexing as a way to +split a query into multiple sinks. This page will go a bit deeper in this concept, as it is +important to understand when combining `LazyFrame`s with procedural programming constructs. + +When dealing with eager dataframes, it is very common to keep state in a temporary variable. Let's +look at the following example. Below we create a `DataFrame` with 10 unique elements in a random +order (so that Polars doesn't hit any fast paths for sorted keys). + +{{code_block('user-guide/lazy/multiplexing','dataframe',[])}} + +```python exec="on" result="text" session="user-guide/lazy/multiplexing" +--8<-- "python/user-guide/lazy/multiplexing.py:dataframe" +``` + +## Eager + +If you deal with the Polars eager API, making a variable and iterating over that temporary +`DataFrame` gives the result you expect, as the result of the group-by is stored in `df1`. Even +though the output order is unstable, it doesn't matter as it is eagerly evaluated. The follow +snippet therefore doesn't raise and the assert passes. +{{code_block('user-guide/lazy/multiplexing','eager',[])}} + +## Lazy + +Now if we tried this naively with `LazyFrame`s, this would fail. + +{{code_block('user-guide/lazy/multiplexing','lazy',[])}} + +```python +AssertionError: DataFrames are different (value mismatch for column 'n') +[left]: [9, 2, 0, 5, 3, 1, 7, 8, 6, 4] +[right]: [0, 9, 6, 8, 2, 5, 4, 3, 1, 7] +``` + +The reason this fails is that `lf1` doesn't contain the materialized result of +`df.lazy().group_by("n").len()`, it instead holds the query plan in that variable. + +```python exec="on" session="user-guide/lazy/multiplexing" +--8<-- "python/user-guide/lazy/multiplexing.py:plan_0" +``` + +This means that every time we branch of this `LazyFrame` and call `collect` we re-evaluate the +group-by. Besides being expensive, this also leads to unexpected results if you assume that the +output is stable (which isn't the case here). + +In the example above you are actually evaluating 2 query plans: + +**Plan 1** + +```python exec="on" session="user-guide/lazy/multiplexing" +--8<-- "python/user-guide/lazy/multiplexing.py:plan_1" +``` + +**Plan 2** + +```python exec="on" session="user-guide/lazy/multiplexing" +--8<-- "python/user-guide/lazy/multiplexing.py:plan_2" +``` + +## Combine the query plans + +To circumvent this, we must give Polars the opportunity to look at all the query plans in a single +optimization and execution pass. This can be done by passing the diverging `LazyFrame`'s to the +`collect_all` function. + +{{code_block('user-guide/lazy/multiplexing','collect_all',[])}} + +If we explain the combined queries with `pl.explain_all`, we can also observe that they are shared +under a single "SINK_MULTIPLE" evaluation and that the optimizer has recognized that parts of the +query come from the same subplan, indicated by the inserted "CACHE" nodes. + +```python exec="on" result="text" session="user-guide/lazy/multiplexing" +--8<-- "python/user-guide/lazy/multiplexing.py:explain_all" +``` + +Combining related subplans in a single execution unit with `pl.collect_all` can thus lead to large +performance increases and allows diverging query plans, storing temporary tables, and a more +procedural programming style. diff --git a/docs/source/user-guide/lazy/optimizations.md b/docs/source/user-guide/lazy/optimizations.md new file mode 100644 index 000000000000..70678f4dbf13 --- /dev/null +++ b/docs/source/user-guide/lazy/optimizations.md @@ -0,0 +1,18 @@ +# Optimizations + +If you use Polars' lazy API, Polars will run several optimizations on your query. Some of them are +executed up front, others are determined just in time as the materialized data comes in. + +Here is a non-complete overview of optimizations done by polars, what they do and how often they +run. + +| Optimization | Explanation | runs | +| -------------------------- | ------------------------------------------------------------------------------------------------------------ | ----------------------------- | +| Predicate pushdown | Applies filters as early as possible/ at scan level. | 1 time | +| Projection pushdown | Select only the columns that are needed at the scan level. | 1 time | +| Slice pushdown | Only load the required slice from the scan level. Don't materialize sliced outputs (e.g. join.head(10)). | 1 time | +| Common subplan elimination | Cache subtrees/file scans that are used by multiple subtrees in the query plan. | 1 time | +| Simplify expressions | Various optimizations, such as constant folding and replacing expensive operations with faster alternatives. | until fixed point | +| Join ordering | Estimates the branches of joins that should be executed first in order to reduce memory pressure. | 1 time | +| Type coercion | Coerce types such that operations succeed and run on minimal required memory. | until fixed point | +| Cardinality estimation | Estimates cardinality in order to determine optimal group by strategy. | 0/n times; dependent on query | diff --git a/docs/source/user-guide/lazy/query-plan.md b/docs/source/user-guide/lazy/query-plan.md new file mode 100644 index 000000000000..ae7799b751f5 --- /dev/null +++ b/docs/source/user-guide/lazy/query-plan.md @@ -0,0 +1,103 @@ +# Query plan + +For any lazy query Polars has both: + +- a non-optimized plan with the set of steps code as we provided it and +- an optimized plan with changes made by the query optimizer + +We can understand both the non-optimized and optimized query plans with visualization and by +printing them as text. + +
+```python exec="on" result="text" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:setup" +``` +
+ +Below we consider the following query: + +{{code_block('user-guide/lazy/query-plan','plan',[])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:plan" +``` + +## Non-optimized query plan + +### Graphviz visualization + +To create visualizations of the query plan, +[Graphviz should be installed](https://graphviz.org/download/) and added to your PATH. + +First we visualize the non-optimized plan by setting `optimized=False`. + +{{code_block('user-guide/lazy/query-plan','showplan',['show_graph'])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:createplan" +``` + +The query plan visualization should be read from bottom to top. In the visualization: + +- each box corresponds to a stage in the query plan +- the `sigma` stands for `SELECTION` and indicates any filter conditions +- the `pi` stands for `PROJECTION` and indicates choosing a subset of columns + +### Printed query plan + +We can also print the non-optimized plan with `explain(optimized=False)` + +{{code_block('user-guide/lazy/query-plan','describe',['explain'])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:describe" +``` + +```text +FILTER [(col("comment_karma")) > (0)] FROM WITH_COLUMNS: + [col("name").str.uppercase()] + + CSV SCAN data/reddit.csv + PROJECT */6 COLUMNS +``` + +The printed plan should also be read from bottom to top. This non-optimized plan is roughly equal +to: + +- read from the `data/reddit.csv` file +- read all 6 columns (where the * wildcard in PROJECT \*/6 COLUMNS means take all columns) +- transform the `name` column to uppercase +- apply a filter on the `comment_karma` column + +## Optimized query plan + +Now we visualize the optimized plan with `show_graph`. + +{{code_block('user-guide/lazy/query-plan','show',['show_graph'])}} + +```python exec="on" session="user-guide/lazy/query-plan" +--8<-- "python/user-guide/lazy/query-plan.py:createplan2" +``` + +We can also print the optimized plan with `explain` + +{{code_block('user-guide/lazy/query-plan','optimized',['explain'])}} + +```text + WITH_COLUMNS: + [col("name").str.uppercase()] + + CSV SCAN data/reddit.csv + PROJECT */6 COLUMNS + SELECTION: [(col("comment_karma")) > (0)] +``` + +The optimized plan is to: + +- read the data from the Reddit CSV +- apply the filter on the `comment_karma` column while the CSV is being read line-by-line +- transform the `name` column to uppercase + +In this case the query optimizer has identified that the `filter` can be applied while the CSV is +read from disk rather than reading the whole file into memory and then applying the filter. This +optimization is called _Predicate Pushdown_. diff --git a/docs/source/user-guide/lazy/schemas.md b/docs/source/user-guide/lazy/schemas.md new file mode 100644 index 000000000000..df7f8d4f6991 --- /dev/null +++ b/docs/source/user-guide/lazy/schemas.md @@ -0,0 +1,76 @@ +# Schema + +The schema of a Polars `DataFrame` or `LazyFrame` sets out the names of the columns and their +datatypes. You can see the schema with the `.collect_schema` method on a `DataFrame` or `LazyFrame` + +{{code_block('user-guide/lazy/schema','schema',['LazyFrame'])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:setup" +--8<-- "python/user-guide/lazy/schema.py:schema" +``` + +The schema plays an important role in the lazy API. + +## Type checking in the lazy API + +One advantage of the lazy API is that Polars will check the schema before any data is processed. +This check happens when you execute your lazy query. + +We see how this works in the following simple example where we call the `.round` expression on the +string column `foo`. + +{{code_block('user-guide/lazy/schema','lazyround',['with_columns'])}} + +The `.round` expression is only valid for columns with a numeric data type. Calling `.round` on a +string column means the operation will raise an `InvalidOperationError` when we evaluate the query +with `collect`. This schema check happens before the data is processed when we call `collect`. + +{{code_block('user-guide/lazy/schema','typecheck',[])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:lazyround" +--8<-- "python/user-guide/lazy/schema.py:typecheck" +``` + +If we executed this query in eager mode the error would only be found once the data had been +processed in all earlier steps. + +When we execute a lazy query Polars checks for any potential `InvalidOperationError` before the +time-consuming step of actually processing the data in the pipeline. + +## The lazy API must know the schema + +In the lazy API the Polars query optimizer must be able to infer the schema at every step of a query +plan. This means that operations where the schema is not knowable in advance cannot be used with the +lazy API. + +The classic example of an operation where the schema is not knowable in advance is a `.pivot` +operation. In a `.pivot` the new column names come from data in one of the columns. As these column +names cannot be known in advance a `.pivot` is not available in the lazy API. + +## Dealing with operations not available in the lazy API + +If your pipeline includes an operation that is not available in the lazy API it is normally best to: + +- run the pipeline in lazy mode up until that point +- execute the pipeline with `.collect` to materialize a `DataFrame` +- do the non-lazy operation on the `DataFrame` +- convert the output back to a `LazyFrame` with `.lazy` and continue in lazy mode + +We show how to deal with a non-lazy operation in this example where we: + +- create a simple `DataFrame` +- convert it to a `LazyFrame` with `.lazy` +- do a transformation using `.with_columns` +- execute the query before the pivot with `.collect` to get a `DataFrame` +- do the `.pivot` on the `DataFrame` +- convert back in lazy mode +- do a `.filter` +- finish by executing the query with `.collect` to get a `DataFrame` + +{{code_block('user-guide/lazy/schema','lazyeager',['collect','lazy','pivot','filter'])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:lazyeager" +``` diff --git a/docs/source/user-guide/lazy/sources_sinks.md b/docs/source/user-guide/lazy/sources_sinks.md new file mode 100644 index 000000000000..4d1b345aa9e6 --- /dev/null +++ b/docs/source/user-guide/lazy/sources_sinks.md @@ -0,0 +1,53 @@ +# Sources and sinks + +## Scan + +When using the `LazyFrame` API, it is important to favor `scan_*` (`scan_parquet`, `scan_csv`, etc.) +over `read_*`. A Polars `scan` is lazy and will delay execution until the query is collected. The +benefit of this, is that the Polars optimizer can push optimization into the readers. They can skip +reading columns and rows that aren't required. Another benefit is that, during streaming execution, +the engine already can process batches before the file is completely read. + +## Sink + +Sinks can execute a query and stream the results to storage (being disk or cloud). The benefit of +sinking data to storage is that you don't necessarily have to store all data in RAM, but can process +data in batches. + +If we would want to convert many csv files to parquet, whilst dropping the missing data, we could do +something like the query below. We use a partitioning strategy that defines how many rows may be in +a single parquet file, before we generate a new file + +```python +lf = scan_csv("my_dataset/*.csv").filter(pl.all().is_not_null()) +lf.sink_parquet( + pl.PartitionMaxSize( + "my_table_{part}.parquet" + max_size=512_000 + ) +) +``` + +This will create the following files on disk: + +```text +my_table_0.parquet +my_table_1.parquet +... +my_table_n.parquet +``` + +## Multiplexing sinks + +Sinks can also multiplex. Meaning that we write to different sinks in a single query. In the code +snippet below, we take a `LazyFrame` and sink it into 2 sinks at the same time. + +```python +# Some expensive computation +lf: LazyFrame + +q1 = lf.sink_parquet(.., lazy=True) +q2 = lf.sink_ipc(.., lazy=True) + +lf.collect_all([q1, q2]) +``` diff --git a/docs/source/user-guide/lazy/using.md b/docs/source/user-guide/lazy/using.md new file mode 100644 index 000000000000..62a58ae07ee1 --- /dev/null +++ b/docs/source/user-guide/lazy/using.md @@ -0,0 +1,41 @@ +# Usage + +With the lazy API, Polars doesn't run each query line-by-line but instead processes the full query +end-to-end. To get the most out of Polars it is important that you use the lazy API because: + +- the lazy API allows Polars to apply automatic query optimization with the query optimizer +- the lazy API allows you to work with larger than memory datasets using streaming +- the lazy API can catch schema errors before processing the data + +Here we see how to use the lazy API starting from either a file or an existing `DataFrame`. + +## Using the lazy API from a file + +In the ideal case we would use the lazy API right from a file as the query optimizer may help us to +reduce the amount of data we read from the file. + +We create a lazy query from the Reddit CSV data and apply some transformations. + +By starting the query with `pl.scan_csv` we are using the lazy API. + +{{code_block('user-guide/lazy/using','dataframe',['scan_csv','with_columns','filter','col'])}} + +A `pl.scan_` function is available for a number of file types including CSV, IPC, Parquet and JSON. + +In this query we tell Polars that we want to: + +- load data from the Reddit CSV file +- convert the `name` column to uppercase +- apply a filter to the `comment_karma` column + +The lazy query will not be executed at this point. See this page on +[executing lazy queries](execution.md) for more on running lazy queries. + +## Using the lazy API from a `DataFrame` + +An alternative way to access the lazy API is to call `.lazy` on a `DataFrame` that has already been +created in memory. + +{{code_block('user-guide/lazy/using','fromdf',['lazy'])}} + +By calling `.lazy` we convert the `DataFrame` to a `LazyFrame`. diff --git a/docs/source/user-guide/migration/pandas.md b/docs/source/user-guide/migration/pandas.md new file mode 100644 index 000000000000..555f96a88331 --- /dev/null +++ b/docs/source/user-guide/migration/pandas.md @@ -0,0 +1,428 @@ +# Coming from Pandas + +Here we set out the key points that anyone who has experience with pandas and wants to try Polars +should know. We include both differences in the concepts the libraries are built on and differences +in how you should write Polars code compared to pandas code. + +## Differences in concepts between Polars and pandas + +### Polars does not have a multi-index/index + +pandas gives a label to each row with an index. Polars does not use an index and each row is indexed +by its integer position in the table. + +Polars aims to have predictable results and readable queries, as such we think an index does not +help us reach that objective. We believe the semantics of a query should not change by the state of +an index or a `reset_index` call. + +In Polars a DataFrame will always be a 2D table with heterogeneous data-types. The data-types may +have nesting, but the table itself will not. Operations like resampling will be done by specialized +functions or methods that act like 'verbs' on a table explicitly stating the columns that that +'verb' operates on. As such, it is our conviction that not having indices make things simpler, more +explicit, more readable and less error-prone. + +Note that an 'index' data structure as known in databases will be used by Polars as an optimization +technique. + +### Polars adheres to the Apache Arrow memory format to represent data in memory while pandas uses NumPy arrays + +Polars represents data in memory according to the Arrow memory spec while pandas by default +represents data in memory with NumPy arrays. Apache Arrow is an emerging standard for in-memory +columnar analytics that can accelerate data load times, reduce memory usage and accelerate +calculations. + +Polars can convert data to NumPy format with the `to_numpy` method. + +### Polars has more support for parallel operations than pandas + +Polars exploits the strong support for concurrency in Rust to run many operations in parallel. While +some operations in pandas are multi-threaded the core of the library is single-threaded and an +additional library such as `Dask` must be used to parallelize operations. Polars is faster than all +open source solutions that parallelize pandas code. + +### Polars has support for different engines + +Polars has native support for an engine optimized for in-memory processing and a streaming engine +optimized for large scale data processing. Furthermore Polars has native integration with a CuDF +supported engine. All these engines benefit from Polars' query optimizer and Polars ensures semantic +correctness between all those engines. In pandas the implementation can dispatch between numpy and +Pyarrow, but because of pandas' loose strictness guarantees, the data-type outputs and semantics +between those backends can differ. This can lead to subtle bugs. + +### Polars can lazily evaluate queries and apply query optimization + +Eager evaluation is when code is evaluated as soon as you run the code. Lazy evaluation is when +running a line of code means that the underlying logic is added to a query plan rather than being +evaluated. + +Polars supports eager evaluation and lazy evaluation whereas pandas only supports eager evaluation. +The lazy evaluation mode is powerful because Polars carries out automatic query optimization when it +examines the query plan and looks for ways to accelerate the query or reduce memory usage. + +`Dask` also supports lazy evaluation when it generates a query plan. + +### Polars is strict + +Polars is strict about data types. Data type resolution in Polars is dependent on the operation +graph, whereas pandas converts types loosely (e.g. new missing data can lead to integer columns +being converted to floats). This strictness leads to fewer bugs and more predictable behavior. + +### Polars has a more verstatile API + +Polars is built on expressions and allows expression inputs in almost all operations. This means +that when you understand how expressions work, your knowledge in Polars extrapolates. Pandas doesn't +have an expression system and often requires Python `lambda`s to express the complexity you want. +Polars sees the requirement of a Python `lambda` as a lack of expressiveness of its API, and tries +to give you native support whenever possible. + +## Key syntax differences + +Users coming from pandas generally need to know one thing... + +``` +polars != pandas +``` + +If your Polars code looks like it could be pandas code, it might run, but it likely runs slower than +it should. + +Let's go through some typical pandas code and see how we might rewrite it in Polars. + +### Selecting data + +As there is no index in Polars there is no `.loc` or `iloc` method in Polars - and there is also no +`SettingWithCopyWarning` in Polars. + +However, the best way to select data in Polars is to use the expression API. For example, if you +want to select a column in pandas, you can do one of the following: + +```python +df["a"] +df.loc[:,"a"] +``` + +but in Polars you would use the `.select` method: + +```python +df.select("a") +``` + +If you want to select rows based on the values then in Polars you use the `.filter` method: + +```python +df.filter(pl.col("a") < 10) +``` + +As noted in the section on expressions below, Polars can run operations in `.select` and `filter` in +parallel and Polars can carry out query optimization on the full set of data selection criteria. + +### Be lazy + +Working in lazy evaluation mode is straightforward and should be your default in Polars as the lazy +mode allows Polars to do query optimization. + +We can run in lazy mode by either using an implicitly lazy function (such as `scan_csv`) or +explicitly using the `lazy` method. + +Take the following simple example where we read a CSV file from disk and do a group by. The CSV file +has numerous columns but we just want to do a group by on one of the id columns (`id1`) and then sum +by a value column (`v1`). In pandas this would be: + +```python +df = pd.read_csv(csv_file, usecols=["id1","v1"]) +grouped_df = df.loc[:,["id1","v1"]].groupby("id1").sum("v1") +``` + +In Polars you can build this query in lazy mode with query optimization and evaluate it by replacing +the eager pandas function `read_csv` with the implicitly lazy Polars function `scan_csv`: + +```python +df = pl.scan_csv(csv_file) +grouped_df = df.group_by("id1").agg(pl.col("v1").sum()).collect() +``` + +Polars optimizes this query by identifying that only the `id1` and `v1` columns are relevant and so +will only read these columns from the CSV. By calling the `.collect` method at the end of the second +line we instruct Polars to eagerly evaluate the query. + +If you do want to run this query in eager mode you can just replace `scan_csv` with `read_csv` in +the Polars code. + +Read more about working with lazy evaluation in the [lazy API](../lazy/using.md) section. + +### Express yourself + +A typical pandas script consists of multiple data transformations that are executed sequentially. +However, in Polars these transformations can be executed in parallel using expressions. + +#### Column assignment + +We have a dataframe `df` with a column called `value`. We want to add two new columns, a column +called `tenXValue` where the `value` column is multiplied by 10 and a column called `hundredXValue` +where the `value` column is multiplied by 100. + +In pandas this would be: + +```python +df.assign( + tenXValue=lambda df_: df_.value * 10, + hundredXValue=lambda df_: df_.value * 100 +) +``` + +These column assignments are executed sequentially. + +In Polars we add columns to `df` using the `.with_columns` method: + +```python +df.with_columns( + tenXValue=pl.col("value") * 10, + hundredXValue=pl.col("value") * 100, +) +``` + +These column assignments are executed in parallel. + +#### Column assignment based on predicate + +In this case we have a dataframe `df` with columns `a`,`b` and `c`. We want to re-assign the values +in column `a` based on a condition. When the value in column `c` is equal to 2 then we replace the +value in `a` with the value in `b`. + +In pandas this would be: + +```python +df.assign(a=lambda df_: df_["a"].mask(df_["c"] == 2, df_["b"])) +``` + +while in Polars this would be: + +```python +df.with_columns( + pl.when(pl.col("c") == 2) + .then(pl.col("b")) + .otherwise(pl.col("a")).alias("a") +) +``` + +Polars can compute every branch of an `if -> then -> otherwise` in parallel. This is valuable, when +the branches get more expensive to compute. + +#### Filtering + +We want to filter the dataframe `df` with housing data based on some criteria. + +In pandas you filter the dataframe by passing Boolean expressions to the `query` method: + +```python +df.query("m2_living > 2500 and price < 300000") +``` + +or by directly evaluating a mask: + +```python +df[(df["m2_living"] > 2500) & (df["price"] < 300000)] +``` + +while in Polars you call the `filter` method: + +```python +df.filter( + (pl.col("m2_living") > 2500) & (pl.col("price") < 300000) +) +``` + +The query optimizer in Polars can also detect if you write multiple filters separately and combine +them into a single filter in the optimized plan. + +## pandas transform + +The pandas documentation demonstrates an operation on a group by called `transform`. In this case we +have a dataframe `df` and we want a new column showing the number of rows in each group. + +In pandas we have: + +```python +df = pd.DataFrame({ + "c": [1, 1, 1, 2, 2, 2, 2], + "type": ["m", "n", "o", "m", "m", "n", "n"], +}) + +df["size"] = df.groupby("c")["type"].transform(len) +``` + +Here pandas does a group by on `"c"`, takes column `"type"`, computes the group length and then +joins the result back to the original `DataFrame` producing: + +``` + c type size +0 1 m 3 +1 1 n 3 +2 1 o 3 +3 2 m 4 +4 2 m 4 +5 2 n 4 +6 2 n 4 +``` + +In Polars the same can be achieved with `window` functions: + +```python +df.with_columns( + pl.col("type").count().over("c").alias("size") +) +``` + +``` +shape: (7, 3) +┌─────┬──────┬──────┐ +│ c ┆ type ┆ size │ +│ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ u32 │ +╞═════╪══════╪══════╡ +│ 1 ┆ m ┆ 3 │ +│ 1 ┆ n ┆ 3 │ +│ 1 ┆ o ┆ 3 │ +│ 2 ┆ m ┆ 4 │ +│ 2 ┆ m ┆ 4 │ +│ 2 ┆ n ┆ 4 │ +│ 2 ┆ n ┆ 4 │ +└─────┴──────┴──────┘ +``` + +Because we can store the whole operation in a single expression, we can combine several `window` +functions and even combine different groups! + +Polars will cache window expressions that are applied over the same group, so storing them in a +single `with_columns` is both convenient **and** optimal. In the following example we look at a case +where we are calculating group statistics over `"c"` twice: + +```python +df.with_columns( + pl.col("c").count().over("c").alias("size"), + pl.col("c").sum().over("type").alias("sum"), + pl.col("type").reverse().over("c").alias("reverse_type") +) +``` + +``` +shape: (7, 5) +┌─────┬──────┬──────┬─────┬──────────────┐ +│ c ┆ type ┆ size ┆ sum ┆ reverse_type │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ u32 ┆ i64 ┆ str │ +╞═════╪══════╪══════╪═════╪══════════════╡ +│ 1 ┆ m ┆ 3 ┆ 5 ┆ o │ +│ 1 ┆ n ┆ 3 ┆ 5 ┆ n │ +│ 1 ┆ o ┆ 3 ┆ 1 ┆ m │ +│ 2 ┆ m ┆ 4 ┆ 5 ┆ n │ +│ 2 ┆ m ┆ 4 ┆ 5 ┆ n │ +│ 2 ┆ n ┆ 4 ┆ 5 ┆ m │ +│ 2 ┆ n ┆ 4 ┆ 5 ┆ m │ +└─────┴──────┴──────┴─────┴──────────────┘ +``` + +## Missing data + +pandas uses `NaN` and/or `None` values to indicate missing values depending on the dtype of the +column. In addition the behaviour in pandas varies depending on whether the default dtypes or +optional nullable arrays are used. In Polars missing data corresponds to a `null` value for all data +types. + +For float columns Polars permits the use of `NaN` values. These `NaN` values are not considered to +be missing data but instead a special floating point value. + +In pandas an integer column with missing values is cast to be a float column with `NaN` values for +the missing values (unless using optional nullable integer dtypes). In Polars any missing values in +an integer column are simply `null` values and the column remains an integer column. + +See the [missing data](../expressions/missing-data.md) section for more details. + +## Pipe littering + +A common usage in pandas is utilizing `pipe` to apply some function to a `DataFrame`. Copying this +coding style to Polars is unidiomatic and leads to suboptimal query plans. + +The snippet below shows a common pattern in pandas. + +```python +def add_foo(df: pd.DataFrame) -> pd.DataFrame: + df["foo"] = ... + return df + +def add_bar(df: pd.DataFrame) -> pd.DataFrame: + df["bar"] = ... + return df + + +def add_ham(df: pd.DataFrame) -> pd.DataFrame: + df["ham"] = ... + return df + +(df + .pipe(add_foo) + .pipe(add_bar) + .pipe(add_ham) +) +``` + +If we do this in polars, we would create 3 `with_columns` contexts, that forces Polars to run the 3 +pipes sequentially, utilizing zero parallelism. + +The way to get similar abstractions in polars is creating functions that create expressions. The +snippet below creates 3 expressions that run on a single context and thus are allowed to run in +parallel. + +```python +def get_foo(input_column: str) -> pl.Expr: + return pl.col(input_column).some_computation().alias("foo") + +def get_bar(input_column: str) -> pl.Expr: + return pl.col(input_column).some_computation().alias("bar") + +def get_ham(input_column: str) -> pl.Expr: + return pl.col(input_column).some_computation().alias("ham") + +# This single context will run all 3 expressions in parallel +df.with_columns( + get_ham("col_a"), + get_bar("col_b"), + get_foo("col_c"), +) +``` + +If you need the schema in the functions that generate the expressions, you can utilize a single +`pipe`: + +```python +from collections import OrderedDict + +def get_foo(input_column: str, schema: OrderedDict) -> pl.Expr: + if "some_col" in schema: + # branch_a + ... + else: + # branch b + ... + +def get_bar(input_column: str, schema: OrderedDict) -> pl.Expr: + if "some_col" in schema: + # branch_a + ... + else: + # branch b + ... + +def get_ham(input_column: str) -> pl.Expr: + return pl.col(input_column).some_computation().alias("ham") + +# Use pipe (just once) to get hold of the schema of the LazyFrame. +lf.pipe(lambda lf: lf.with_columns( + get_ham("col_a"), + get_bar("col_b", lf.schema), + get_foo("col_c", lf.schema), +) +``` + +Another benefit of writing functions that return expressions, is that these functions are composable +as expressions can be chained and partially applied, leading to much more flexibility in the design. diff --git a/docs/source/user-guide/migration/spark.md b/docs/source/user-guide/migration/spark.md new file mode 100644 index 000000000000..ff6740b3b11b --- /dev/null +++ b/docs/source/user-guide/migration/spark.md @@ -0,0 +1,200 @@ +# Coming from Apache Spark + +## Column-based API vs. Row-based API + +Whereas the `Spark` `DataFrame` is analogous to a collection of rows, a Polars `DataFrame` is closer +to a collection of columns. This means that you can combine columns in Polars in ways that are not +possible in `Spark`, because `Spark` preserves the relationship of the data in each row. + +Consider this sample dataset: + +```python +import polars as pl + +df = pl.DataFrame({ + "foo": ["a", "b", "c", "d", "d"], + "bar": [1, 2, 3, 4, 5], +}) + +dfs = spark.createDataFrame( + [ + ("a", 1), + ("b", 2), + ("c", 3), + ("d", 4), + ("d", 5), + ], + schema=["foo", "bar"], +) +``` + +### Example 1: Combining `head` and `sum` + +In Polars you can write something like this: + +```python +df.select( + pl.col("foo").sort().head(2), + pl.col("bar").filter(pl.col("foo") == "d").sum() +) +``` + +Output: + +``` +shape: (2, 2) +┌─────┬─────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ str ┆ i64 │ +╞═════╪═════╡ +│ a ┆ 9 │ +├╌╌╌╌╌┼╌╌╌╌╌┤ +│ b ┆ 9 │ +└─────┴─────┘ +``` + +The expressions on columns `foo` and `bar` are completely independent. Since the expression on `bar` +returns a single value, that value is repeated for each value output by the expression on `foo`. But +`a` and `b` have no relation to the data that produced the sum of `9`. + +To do something similar in `Spark`, you'd need to compute the sum separately and provide it as a +literal: + +```python +from pyspark.sql.functions import col, sum, lit + +bar_sum = ( + dfs + .where(col("foo") == "d") + .groupBy() + .agg(sum(col("bar"))) + .take(1)[0][0] +) + +( + dfs + .orderBy("foo") + .limit(2) + .withColumn("bar", lit(bar_sum)) + .show() +) +``` + +Output: + +``` ++---+---+ +|foo|bar| ++---+---+ +| a| 9| +| b| 9| ++---+---+ +``` + +### Example 2: Combining Two `head`s + +In Polars you can combine two different `head` expressions on the same DataFrame, provided that they +return the same number of values. + +```python +df.select( + pl.col("foo").sort().head(2), + pl.col("bar").sort(descending=True).head(2), +) +``` + +Output: + +``` +shape: (3, 2) +┌─────┬─────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ str ┆ i64 │ +╞═════╪═════╡ +│ a ┆ 5 │ +├╌╌╌╌╌┼╌╌╌╌╌┤ +│ b ┆ 4 │ +└─────┴─────┘ +``` + +Again, the two `head` expressions here are completely independent, and the pairing of `a` to `5` and +`b` to `4` results purely from the juxtaposition of the two columns output by the expressions. + +To accomplish something similar in `Spark`, you would need to generate an artificial key that +enables you to join the values in this way. + +```python +from pyspark.sql import Window +from pyspark.sql.functions import row_number + +foo_dfs = ( + dfs + .withColumn( + "rownum", + row_number().over(Window.orderBy("foo")) + ) +) + +bar_dfs = ( + dfs + .withColumn( + "rownum", + row_number().over(Window.orderBy(col("bar").desc())) + ) +) + +( + foo_dfs.alias("foo") + .join(bar_dfs.alias("bar"), on="rownum") + .select("foo.foo", "bar.bar") + .limit(2) + .show() +) +``` + +Output: + +``` ++---+---+ +|foo|bar| ++---+---+ +| a| 5| +| b| 4| ++---+---+ +``` + +### Example 3: Composing expressions + +Polars allows you compose expressions quite liberally. For example, if you want to find the rolling +mean of a lagged variable, you can compose `shift` and `rolling_mean` and evaluate them in a single +`over` expression: + +```python +df.with_columns( + feature=pl.col('price').shift(7).rolling_mean(7).over('store', order_by='date') +) +``` + +In PySpark however this is not allowed. They allow composing expressions such as +`F.mean(F.abs("price")).over(window)` because `F.abs` is an elementwise function, but not +`F.mean(F.lag("price", 1)).over(window)` because `F.lag` is a window function. To produce the same +result, both `F.lag` and `F.mean` need their own window. + +```python +from pyspark.sql import Window +from pyspark.sql import functions as F + +window = Window().partitionBy("store").orderBy("date") +rolling_window = window.rowsBetween(-6, 0) +( + df.withColumn("lagged_price", F.lag("price", 7).over(window)).withColumn( + "feature", + F.when( + F.count("lagged_price").over(rolling_window) >= 7, + F.mean("lagged_price").over(rolling_window), + ), + ) +) +``` diff --git a/docs/source/user-guide/misc/arrow.md b/docs/source/user-guide/misc/arrow.md new file mode 100644 index 000000000000..23985e9d94c6 --- /dev/null +++ b/docs/source/user-guide/misc/arrow.md @@ -0,0 +1,148 @@ +# Arrow producer/consumer + +## Using pyarrow + +Polars can move data in and out of arrow zero copy. This can be done either via pyarrow or natively. +Let's first start by showing the pyarrow solution: + +{{code_block('user-guide/misc/arrow','to_arrow',[])}} + +``` +pyarrow.Table +foo: int64 +bar: large_string +---- +foo: [[1,2,3]] +bar: [["ham","spam","jam"]] +``` + +Or if you want to ensure the output is zero-copy: + +{{code_block('user-guide/misc/arrow','to_arrow_zero',[])}} + +``` +pyarrow.Table +foo: int64 +bar: string_view +---- +foo: [[1,2,3]] +bar: [["ham","spam","jam"]] +``` + +Importing from pyarrow can be achieved with `pl.from_arrow`. + +## Using the Arrow PyCapsule Interface + +As of Polars v1.3 and higher, Polars implements the +[Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html), +a protocol for sharing Arrow data across Python libraries. + +### Exporting data from Polars to pyarrow + +To convert a Polars `DataFrame` to a `pyarrow.Table`, use the `pyarrow.table` constructor: + +!!! note + + This requires pyarrow v15 or higher. + +{{code_block('user-guide/misc/arrow_pycapsule','to_arrow',[])}} + +``` +pyarrow.Table +foo: int64 +bar: string_view +---- +foo: [[1,2,3]] +bar: [["ham","spam","jam"]] +``` + +To convert a Polars `Series` to a `pyarrow.ChunkedArray`, use the `pyarrow.chunked_array` +constructor. + +{{code_block('user-guide/misc/arrow_pycapsule','to_arrow_series',[])}} + +``` +[ + [ + 1, + 2, + 3 + ] +] +``` + +You can also pass a `Series` to the `pyarrow.array` constructor to create a contiguous array. Note +that this will not be zero-copy if the underlying `Series` had multiple chunks. + +{{code_block('user-guide/misc/arrow_pycapsule','to_arrow_array_rechunk',[])}} + +``` +[ + 1, + 2, + 3 +] +``` + +### Importing data from pyarrow to Polars + +We can pass the pyarrow `Table` back to Polars by using the `polars.DataFrame` constructor: + +{{code_block('user-guide/misc/arrow_pycapsule','to_polars',[])}} + +``` +shape: (3, 2) +┌─────┬──────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ i64 ┆ str │ +╞═════╪══════╡ +│ 1 ┆ ham │ +│ 2 ┆ spam │ +│ 3 ┆ jam │ +└─────┴──────┘ +``` + +Similarly, we can pass the pyarrow `ChunkedArray` or `Array` back to Polars by using the +`polars.Series` constructor: + +{{code_block('user-guide/misc/arrow_pycapsule','to_polars_series',[])}} + +``` +shape: (3,) +Series: '' [i64] +[ + 1 + 2 + 3 +] +``` + +### Usage with other arrow libraries + +There's a [growing list](https://github.com/apache/arrow/issues/39195#issuecomment-2245718008) of +libraries that support the PyCapsule Interface directly. Polars `Series` and `DataFrame` objects +work automatically with every such library. + +### For library maintainers + +If you're developing a library that you wish to integrate with Polars, it's suggested to implement +the +[Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html) +yourself. This comes with a number of benefits: + +- Zero-copy exchange for both Polars Series and DataFrame +- No required dependency on pyarrow. +- No direct dependency on Polars. +- Harder to cause memory leaks than handling pointers as raw integers. +- Automatic zero-copy integration other PyCapsule Interface-supported libraries. + +## Using Polars directly + +Polars can also consume and export to and import from the +[Arrow C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) directly. This is +recommended for libraries that don't support the Arrow PyCapsule Interface and want to interop with +Polars without requiring a pyarrow installation. + +- To export `ArrowArray` C structs, Polars exposes: `Series._export_arrow_to_c`. +- To import an `ArrowArray` C struct, Polars exposes `Series._import_arrow_from_c`. diff --git a/docs/source/user-guide/misc/comparison.md b/docs/source/user-guide/misc/comparison.md new file mode 100644 index 000000000000..3520a26819f7 --- /dev/null +++ b/docs/source/user-guide/misc/comparison.md @@ -0,0 +1,63 @@ +# Comparison with other tools + +These are several libraries and tools that share similar functionalities with Polars. This often +leads to questions from data experts about what the differences are. Below is a short comparison +between some of the more popular data processing tools and Polars, to help data experts make a +deliberate decision on which tool to use. + +You can find performance benchmarks (h2oai benchmark) of these tools here: +[Polars blog post](https://pola.rs/posts/benchmarks/) or a more recent benchmark +[done by DuckDB](https://duckdblabs.github.io/db-benchmark/) + +### Pandas + +Pandas stands as a widely-adopted and comprehensive tool in Python data analysis, renowned for its +rich feature set and strong community support. However, due to its single threaded nature, it can +struggle with performance and memory usage on medium and large datasets. + +In contrast, Polars is optimised for high-performance multithreaded computing on single nodes, +providing significant improvements in speed and memory efficiency, particularly for medium to large +data operations. Its more composable and stricter API results in greater expressiveness and fewer +schema-related bugs. + +### Dask + +Dask extends Pandas' capabilities to large, distributed datasets. Dask mimics Pandas' API, offering +a familiar environment for Pandas users, but with the added benefit of parallel and distributed +computing. + +While Dask excels at scaling Pandas workflows across clusters, it only supports a subset of the +Pandas API and therefore cannot be used for all use cases. Polars offers a more versatile API that +delivers strong performance within the constraints of a single node. + +The choice between Dask and Polars often comes down to familiarity with the Pandas API and the need +for distributed processing for extremely large datasets versus the need for efficiency and speed in +a vertically scaled environment for a wide range of use cases. + +### Modin + +Similar to Dask. In 2023, Snowflake acquired Ponder, the organisation that maintains Modin. + +### Spark + +Spark (specifically PySpark) represents a different approach to large-scale data processing. While +Polars has an optimised performance for single-node environments, Spark is designed for distributed +data processing across clusters, making it suitable for extremely large datasets. + +However, Spark's distributed nature can introduce complexity and overhead, especially for small +datasets and tasks that can run on a single machine. Another consideration is collaboration between +data scientists and engineers. As they typically work with different tools (Pandas and Pyspark), +refactoring is often required by engineers to deploy data scientists' data processing pipelines. +Polars offers a single syntax that, due to vertical scaling, works in local environments and on a +single machine in the cloud. + +The choice between Polars and Spark often depends on the scale of data and the specific requirements +of the processing task. If you need to process TBs of data, Spark is a better choice. + +### DuckDB + +Polars and DuckDB have many similarities. However, DuckDB is focused on providing an in-process SQL +OLAP database management system, while Polars is focused on providing a scalable `DataFrame` +interface to many languages. The different front-ends lead to different optimisation strategies and +different algorithm prioritisation. The interoperability between both is zero-copy. DuckDB offers a +guide on [how to integrate with Polars](https://duckdb.org/docs/guides/python/polars.html). diff --git a/docs/source/user-guide/misc/multiprocessing.md b/docs/source/user-guide/misc/multiprocessing.md new file mode 100644 index 000000000000..b8f6fe0db263 --- /dev/null +++ b/docs/source/user-guide/misc/multiprocessing.md @@ -0,0 +1,130 @@ +# Multiprocessing + +TLDR: if you find that using Python's built-in `multiprocessing` module together with Polars results +in a Polars error about multiprocessing methods, you should make sure you are using `spawn`, not +`fork`, as the starting method: + +{{code_block('user-guide/misc/multiprocess','recommendation',[])}} + +## When not to use multiprocessing + +Before we dive into the details, it is important to emphasize that Polars has been built from the +start to use all your CPU cores. It does this by executing computations which can be done in +parallel in separate threads. For example, requesting two expressions in a `select` statement can be +done in parallel, with the results only being combined at the end. Another example is aggregating a +value within groups using `group_by().agg()`, each group can be evaluated separately. It is +very unlikely that the `multiprocessing` module can improve your code performance in these cases. If +you're using the GPU Engine with Polars you should also avoid manual multiprocessing. When used +simultaneously, they can compete for system memory and processing power, leading to reduced +performance. + +See [the optimizations section](../lazy/optimizations.md) for more optimizations. + +## When to use multiprocessing + +Although Polars is multithreaded, other libraries may be single-threaded. When the other library is +the bottleneck, and the problem at hand is parallelizable, it makes sense to use multiprocessing to +gain a speed up. + +## The problem with the default multiprocessing config + +### Summary + +The [Python multiprocessing documentation](https://docs.python.org/3/library/multiprocessing.html) +lists the three methods to create a process pool: + +1. spawn +1. fork +1. forkserver + +The description of fork is (as of 2022-10-15): + +> The parent process uses os.fork() to fork the Python interpreter. The child process, when it +> begins, is effectively identical to the parent process. All resources of the parent are inherited +> by the child process. Note that safely forking a multithreaded process is problematic. + +> Available on Unix only. The default on Unix. + +The short summary is: Polars is multithreaded as to provide strong performance out-of-the-box. Thus, +it cannot be combined with `fork`. If you are on Unix (Linux, BSD, etc), you are using `fork`, +unless you explicitly override it. + +The reason you may not have encountered this before is that pure Python code, and most Python +libraries, are (mostly) single threaded. Alternatively, you are on Windows or MacOS, on which `fork` +is not even available as a method (for MacOS it was up to Python 3.7). + +Thus one should use `spawn`, or `forkserver`, instead. `spawn` is available on all platforms and the +safest choice, and hence the recommended method. + +### Example + +The problem with `fork` is in the copying of the parent's process. Consider the example below, which +is a slightly modified example posted on the +[Polars issue tracker](https://github.com/pola-rs/polars/issues/3144): + +{{code_block('user-guide/misc/multiprocess','example1',[])}} + +Using `fork` as the method, instead of `spawn`, will cause a dead lock. + +The fork method is equivalent to calling `os.fork()`, which is a system call as defined in +[the POSIX standard](https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html): + +> A process shall be created with a single thread. If a multi-threaded process calls fork(), the new +> process shall contain a replica of the calling thread and its entire address space, possibly +> including the states of mutexes and other resources. Consequently, to avoid errors, the child +> process may only execute async-signal-safe operations until such time as one of the exec functions +> is called. + +In contrast, `spawn` will create a completely new fresh Python interpreter, and not inherit the +state of mutexes. + +So what happens in the code example? For reading the file with `pl.read_parquet` the file has to be +locked. Then `os.fork()` is called, copying the state of the parent process, including mutexes. Thus +all child processes will copy the file lock in an acquired state, leaving them hanging indefinitely +waiting for the file lock to be released, which never happens. + +What makes debugging these issues tricky is that `fork` can work. Change the example to not having +the call to `pl.read_parquet`: + +{{code_block('user-guide/misc/multiprocess','example2',[])}} + +This works fine. Therefore debugging these issues in larger code bases, i.e. not the small toy +examples here, can be a real pain, as a seemingly unrelated change can break your multiprocessing +code. In general, one should therefore never use the `fork` start method with multithreaded +libraries unless there are very specific requirements that cannot be met otherwise. + +### Pro's and cons of fork + +Based on the example, you may think, why is `fork` available in Python to start with? + +First, probably because of historical reasons: `spawn` was added to Python in version 3.4, whilst +`fork` has been part of Python from the 2.x series. + +Second, there are several limitations for `spawn` and `forkserver` that do not apply to `fork`, in +particular all arguments should be pickleable. See the +[Python multiprocessing docs](https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods) +for more information. + +Third, because it is faster to create new processes compared to `spawn`, as `spawn` is effectively +`fork` + creating a brand new Python process without the locks by calling +[execv](https://pubs.opengroup.org/onlinepubs/9699919799/functions/exec.html). Hence the warning in +the Python docs that it is slower: there is more overhead to `spawn`. However, in almost all cases, +one would like to use multiple processes to speed up computations that take multiple minutes or even +hours, meaning the overhead is negligible in the grand scheme of things. And more importantly, it +actually works in combination with multithreaded libraries. + +Fourth, `spawn` starts a new process, and therefore it requires code to be importable, in contrast +to `fork`. In particular, this means that when using `spawn` the relevant code should not be in the +global scope, such as in Jupyter notebooks or in plain scripts. Hence in the examples above, we +define functions where we spawn within, and run those functions from a `__main__` clause. This is +not an issue for typical projects, but during quick experimentation in notebooks it could fail. + +## References + +1. https://docs.python.org/3/library/multiprocessing.html + +1. https://pythonspeed.com/articles/python-multiprocessing/ + +1. https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html + +1. https://bnikolic.co.uk/blog/python/parallelism/2019/11/13/python-forkserver-preload.html diff --git a/docs/source/user-guide/misc/polars_llms.md b/docs/source/user-guide/misc/polars_llms.md new file mode 100644 index 000000000000..6e1459701c21 --- /dev/null +++ b/docs/source/user-guide/misc/polars_llms.md @@ -0,0 +1,77 @@ +# Generating Polars code with LLMs + +Large Language Models (LLMs) can sometimes return Pandas code or invalid Polars code in their +output. This guide presents approaches that help LLMs generate valid Polars code more consistently. + +These approaches have been developed by the Polars community through testing model responses to +various inputs. If you find additional effective approaches for generating Polars code from LLMs, +please raise a [pull request](https://github.com/pola-rs/polars/pulls). + +## System prompt + +Many LLMs allow you to provide a system prompt that is included with every individual prompt you +send to the model. In the system prompt, you can specify your preferred defaults, such as "Use +Polars as the default dataframe library". Including such a system prompt typically leads to models +consistently generating Polars code rather than Pandas code. + +You can set this system prompt in the settings menu of both web-based LLMs like ChatGPT and +IDE-based LLMs like Cursor. Refer to each application's documentation for specific instructions. + +## Enable web search + +Some LLMs can search the web to access information beyond their pre-training data. Enabling web +search allows an LLM to reference up-to-date Polars documentation for the current API. + +Some IDE-based LLMs can index the Polars API documentation and reference this when generating code. +For example, in Cursor you can add Polars as a custom docs source and instruct the agent to +reference the Polars documentation in a prompt. + +However, web search does not yet guarantee that valid code will be produced. If a model is confident +in a result based on its pre-training data, it may not incorporate web search results in its output. + +The Polars API pages also have AI-enabled search to help you find the information you need more +easily. + +## Provide examples + +You can guide LLMs to use correct syntax by including relevant examples in your prompt. + +For instance, this basic query: + +```python +df = pl.DataFrame({ + "id": ["a", "b", "a", "b", "c"], + "score": [1, 2, 1, 3, 3], + "year": [2020, 2020, 2021, 2021, 2021], +}) +# Compute average of score by id +``` + +Often results in outdated `groupby` syntax instead of the correct `group_by`. + +However, including a simple example from the Polars `group_by` documentation (preferably with web +search enabled) like this: + +```python +df = pl.DataFrame({ + "id": ["a", "b", "a", "b", "c"], + "score": [1, 2, 1, 3, 3], + "year": [2020, 2020, 2021, 2021, 2021], +}) +# Compute average of score by id +# Examples of Polars code: + +# df.group_by("a").agg(pl.col("b").mean()) +``` + +Produces valid outputs more often. This approach has been validated across several leading models. + +The combination of web search and examples is more effective than either independently. Model +outputs indicate that when an example contradicts the model's pre-trained expectations, it seems +more likely to trigger a web search for verification. + +Additionally, explicit instructions like "use `group_by` instead of `groupby`" can be effective in +guiding the model to use correct syntax. + +Common examples such as `df.group_by("a").agg(pl.col("b").mean())` can also be added the system +prompt for more consistency. diff --git a/docs/source/user-guide/misc/styling.md b/docs/source/user-guide/misc/styling.md new file mode 100644 index 000000000000..d64d3741e15c --- /dev/null +++ b/docs/source/user-guide/misc/styling.md @@ -0,0 +1,68 @@ +# Styling + +Data in a Polars `DataFrame` can be styled for presentation use the `DataFrame.style` property. This +returns a `GT` object from +[Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html), which enables +structuring, formatting, and styling for table display. + +{{code_block('user-guide/misc/styling','dataframe',[])}} + +```python exec="on" result="text" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:dataframe" +``` + +## Structure: add header title + +{{code_block('user-guide/misc/styling','structure-header',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:structure-header-out" +``` + +## Structure: add row stub + +{{code_block('user-guide/misc/styling','structure-stub',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:structure-stub-out" +``` + +## Structure: add column spanner + +{{code_block('user-guide/misc/styling','structure-spanner',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:structure-spanner-out" +``` + +## Format: limit decimal places + +{{code_block('user-guide/misc/styling','format-number',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:format-number-out" +``` + +## Style: highlight max row + +{{code_block('user-guide/misc/styling','style-simple',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:style-simple-out" +``` + +## Style: bold species column + +{{code_block('user-guide/misc/styling','style-bold-column',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:style-bold-column-out" +``` + +## Full example + +{{code_block('user-guide/misc/styling','full-example',[])}} + +```python exec="on" session="user-guide/misc/styling" +--8<-- "python/user-guide/misc/styling.py:full-example-out" +``` diff --git a/docs/source/user-guide/misc/visualization.md b/docs/source/user-guide/misc/visualization.md new file mode 100644 index 000000000000..99494ac6db0a --- /dev/null +++ b/docs/source/user-guide/misc/visualization.md @@ -0,0 +1,111 @@ +# Visualization + +Data in a Polars `DataFrame` can be visualized using common visualization libraries. + +We illustrate plotting capabilities using the Iris dataset. We read a CSV and then plot one column +against another, colored by a yet another column. + +{{code_block('user-guide/misc/visualization','dataframe',[])}} + +```python exec="on" result="text" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:dataframe" +``` + +## Built-in plotting with Altair + +Polars has a `plot` method to create plots using [Altair](https://altair-viz.github.io/): + +{{code_block('user-guide/misc/visualization','altair_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:altair_make_plot" +``` + +This is shorthand for: + +```python +import altair as alt + +( + alt.Chart(df).mark_point(tooltip=True).encode( + x="sepal_length", + y="sepal_width", + color="species", + ) + .properties(width=500) + .configure_scale(zero=False) +) +``` + +and is only provided for convenience, and to signal that Altair is known to work well with Polars. + +For configuration, we suggest reading +[Chart Configuration](https://altair-viz.github.io/altair-tutorial/notebooks/08-Configuration.html). +For example, you can: + +- Change the width/height/title with `.properties(width=500, height=350, title="My amazing plot")`. +- Change the x-axis label rotation with `.configure_axisX(labelAngle=30)`. +- Change the opacity of the points in your scatter plot with `.configure_point(opacity=.5)`. + +## hvPlot + +If you import `hvplot.polars`, then it registers a `hvplot` method which you can use to create +interactive plots using [hvPlot](https://hvplot.holoviz.org/). + +{{code_block('user-guide/misc/visualization','hvplot_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:hvplot_make_plot" +``` + +## Matplotlib + +To create a scatter plot we can pass columns of a `DataFrame` directly to Matplotlib as a `Series` +for each column. Matplotlib does not have explicit support for Polars objects but can accept a +Polars `Series` by converting it to a NumPy array (which is zero-copy for numeric data without null +values). + +Note that because the column `'species'` isn't numeric, we need to first convert it to numeric +values so that it can be passed as an argument to `c`. + +{{code_block('user-guide/misc/visualization','matplotlib_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:matplotlib_make_plot" +``` + +## Plotnine + +[Plotnine](https://plotnine.org/) is a reimplementation of ggplot2 in Python, bringing the Grammar +of Graphics to Python users with an interface similar to its R counterpart. It supports Polars +`DataFrame` by internally converting it to a pandas `DataFrame`. + +{{code_block('user-guide/misc/visualization','plotnine_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:plotnine_make_plot" +``` + +## Seaborn and Plotly + +[Seaborn](https://seaborn.pydata.org/) and [Plotly](https://plotly.com/) can accept a Polars +`DataFrame` by leveraging the +[dataframe interchange protocol](https://data-apis.org/dataframe-api/), which offers zero-copy +conversion where possible. Note that the protocol does not support all Polars data types (e.g. +`List`) so your mileage may vary here. + +### Seaborn + +{{code_block('user-guide/misc/visualization','seaborn_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:seaborn_make_plot" +``` + +### Plotly + +{{code_block('user-guide/misc/visualization','plotly_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:plotly_make_plot" +``` diff --git a/docs/source/user-guide/plugins/expr_plugins.md b/docs/source/user-guide/plugins/expr_plugins.md new file mode 100644 index 000000000000..3c8765f1759b --- /dev/null +++ b/docs/source/user-guide/plugins/expr_plugins.md @@ -0,0 +1,265 @@ +# Expression Plugins + + + +Expression plugins are the preferred way to create user defined functions. They allow you to compile +a Rust function and register that as an expression into the Polars library. The Polars engine will +dynamically link your function at runtime and your expression will run almost as fast as native +expressions. Note that this works without any interference of Python and thus no GIL contention. + +They will benefit from the same benefits default expressions have: + +- Optimization +- Parallelism +- Rust native performance + +To get started we will see what is needed to create a custom expression. + +## Our first custom expression: Pig Latin + +For our first expression we are going to create a pig latin converter. Pig latin is a silly language +where in every word the first letter is removed, added to the back and finally "ay" is added. So the +word "pig" would convert to "igpay". + +We could of course already do that with expressions, e.g. +`col("name").str.slice(1) + col("name").str.slice(0, 1) + "ay"`, but a specialized function for this +would perform better and allows us to learn about the plugins. + +### Setting up + +We start with a new library as the following `Cargo.toml` file + +```toml +[package] +name = "expression_lib" +version = "0.1.0" +edition = "2021" + +[lib] +name = "expression_lib" +crate-type = ["cdylib"] + +[dependencies] +polars = { version = "*" } +pyo3 = { version = "*", features = ["extension-module", "abi3-py38"] } +pyo3-polars = { version = "*", features = ["derive"] } +serde = { version = "*", features = ["derive"] } +``` + +### Writing the expression + +In this library we create a helper function that converts a `&str` to pig-latin, and we create the +function that we will expose as an expression. To expose a function we must add the +`#[polars_expr(output_type=DataType)]` attribute and the function must always accept +`inputs: &[Series]` as its first argument. + +```rust +// src/expressions.rs +use polars::prelude::*; +use pyo3_polars::derive::polars_expr; +use std::fmt::Write; + +fn pig_latin_str(value: &str, output: &mut String) { + if let Some(first_char) = value.chars().next() { + write!(output, "{}{}ay", &value[1..], first_char).unwrap() + } +} + +#[polars_expr(output_type=String)] +fn pig_latinnify(inputs: &[Series]) -> PolarsResult { + let ca = inputs[0].str()?; + let out: StringChunked = ca.apply_into_string_amortized(pig_latin_str); + Ok(out.into_series()) +} +``` + +Note that we use `apply_into_string_amortized`, as opposed to `apply_values`, to avoid allocating a +new string for each row. If your plugin takes in multiple inputs, operates elementwise, and produces +a `String` output, then you may want to look at the `binary_elementwise_into_string_amortized` +utility function in `polars::prelude::arity`. + +This is all that is needed on the Rust side. On the Python side we must setup a folder with the same +name as defined in the `Cargo.toml`, in this case "expression_lib". We will create a folder in the +same directory as our Rust `src` folder named `expression_lib` and we create an +`expression_lib/__init__.py`. The resulting file structure should look something like this: + +``` +├── 📁 expression_lib/ # name must match "lib.name" in Cargo.toml +| └── __init__.py +| +├── 📁src/ +| ├── lib.rs +| └── expressions.rs +| +├── Cargo.toml +└── pyproject.toml +``` + +Then we create a new class `Language` that will hold the expressions for our new `expr.language` +namespace. The function name of our expression can be registered. Note that it is important that +this name is correct, otherwise the main Polars package cannot resolve the function name. +Furthermore we can set additional keyword arguments that explain to Polars how this expression +behaves. In this case we tell Polars that this function is elementwise. This allows Polars to run +this expression in batches. Whereas for other operations this would not be allowed, think for +instance of a sort, or a slice. + +```python +# expression_lib/__init__.py +from pathlib import Path +from typing import TYPE_CHECKING + +import polars as pl +from polars.plugins import register_plugin_function +from polars._typing import IntoExpr + +PLUGIN_PATH = Path(__file__).parent + +def pig_latinnify(expr: IntoExpr) -> pl.Expr: + """Pig-latinnify expression.""" + return register_plugin_function( + plugin_path=PLUGIN_PATH, + function_name="pig_latinnify", + args=expr, + is_elementwise=True, + ) +``` + +We can then compile this library in our environment by installing `maturin` and running +`maturin develop --release`. + +And that's it. Our expression is ready to use! + +```python +import polars as pl +from expression_lib import pig_latinnify + +df = pl.DataFrame( + { + "convert": ["pig", "latin", "is", "silly"], + } +) +out = df.with_columns(pig_latin=pig_latinnify("convert")) +``` + +Alternatively, you can +[register a custom namespace](https://docs.pola.rs/api/python/stable/reference/api/polars.api.register_expr_namespace.html#polars.api.register_expr_namespace), +which enables you to write: + +```python +out = df.with_columns( + pig_latin=pl.col("convert").language.pig_latinnify(), +) +``` + +## Accepting kwargs + +If you want to accept `kwargs` (keyword arguments) in a polars expression, all you have to do is +define a Rust `struct` and make sure that it derives `serde::Deserialize`. + +```rust +/// Provide your own kwargs struct with the proper schema and accept that type +/// in your plugin expression. +#[derive(Deserialize)] +pub struct MyKwargs { + float_arg: f64, + integer_arg: i64, + string_arg: String, + boolean_arg: bool, +} + +/// If you want to accept `kwargs`. You define a `kwargs` argument +/// on the second position in you plugin. You can provide any custom struct that is deserializable +/// with the pickle protocol (on the Rust side). +#[polars_expr(output_type=String)] +fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult { + let input = &input[0]; + let input = input.cast(&DataType::String)?; + let ca = input.str().unwrap(); + + Ok(ca + .apply_into_string_amortized(|val, buf| { + write!( + buf, + "{}-{}-{}-{}-{}", + val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg + ) + .unwrap() + }) + .into_series()) +} +``` + +On the Python side the kwargs can be passed when we register the plugin. + +```python +def append_args( + expr: IntoExpr, + float_arg: float, + integer_arg: int, + string_arg: str, + boolean_arg: bool, +) -> pl.Expr: + """ + This example shows how arguments other than `Series` can be used. + """ + return register_plugin_function( + plugin_path=PLUGIN_PATH, + function_name="append_kwargs", + args=expr, + kwargs={ + "float_arg": float_arg, + "integer_arg": integer_arg, + "string_arg": string_arg, + "boolean_arg": boolean_arg, + }, + is_elementwise=True, + ) +``` + +## Output data types + +Output data types of course don't have to be fixed. They often depend on the input types of an +expression. To accommodate this you can provide the `#[polars_expr()]` macro with an +`output_type_func` argument that points to a function. This function can map input fields `&[Field]` +to an output `Field` (name and data type). + +In the snippet below is an example where we use the utility `FieldsMapper` to help with this +mapping. + +```rust +use polars_plan::dsl::FieldsMapper; + +fn haversine_output(input_fields: &[Field]) -> PolarsResult { + FieldsMapper::new(input_fields).map_to_float_dtype() +} + +#[polars_expr(output_type_func=haversine_output)] +fn haversine(inputs: &[Series]) -> PolarsResult { + let out = match inputs[0].dtype() { + DataType::Float32 => { + let start_lat = inputs[0].f32().unwrap(); + let start_long = inputs[1].f32().unwrap(); + let end_lat = inputs[2].f32().unwrap(); + let end_long = inputs[3].f32().unwrap(); + crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)? + .into_series() + } + DataType::Float64 => { + let start_lat = inputs[0].f64().unwrap(); + let start_long = inputs[1].f64().unwrap(); + let end_lat = inputs[2].f64().unwrap(); + let end_long = inputs[3].f64().unwrap(); + crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)? + .into_series() + } + _ => polars_bail!(InvalidOperation: "only supported for float types"), + }; + Ok(out) +} +``` + +That's all you need to know to get started. Take a look at +[this repo](https://github.com/pola-rs/pyo3-polars/tree/main/example/derive_expression) to see how +this all fits together, and at +[this tutorial](https://marcogorelli.github.io/polars-plugins-tutorial/) to gain a more thorough +understanding. diff --git a/docs/source/user-guide/plugins/index.md b/docs/source/user-guide/plugins/index.md new file mode 100644 index 000000000000..3c1912ca7685 --- /dev/null +++ b/docs/source/user-guide/plugins/index.md @@ -0,0 +1,42 @@ +# Plugins + +Polars allows you to extend its functionality with either Expression plugins or IO plugins. + +- [Expression plugins](./expr_plugins.md) +- [IO plugins](./io_plugins.md) + +## Community plugins + +Here is a curated (non-exhaustive) list of community-implemented plugins. + +### Various + +- [polars-xdt](https://github.com/pola-rs/polars-xdt) Polars plugin with extra datetime-related + functionality which isn't quite in-scope for the main library +- [polars-hash](https://github.com/ion-elgreco/polars-hash) Stable non-cryptographic and + cryptographic hashing functions for Polars + +### Data science + +- [polars-distance](https://github.com/ion-elgreco/polars-distance) Polars plugin for pairwise + distance functions +- [polars-ds](https://github.com/abstractqqq/polars_ds_extension) Polars extension aiming to + simplify common numerical/string data analysis procedures + +### Geo + +- [polars-st](https://github.com/Oreilles/polars-st) Polars ST provides spatial operations on Polars + DataFrames, Series and Expressions. Just like Shapely and Geopandas. +- [polars-reverse-geocode](https://github.com/MarcoGorelli/polars-reverse-geocode) Offline reverse + geocoder for finding the closest city to a given (latitude, longitude) pair. +- [polars-h3](https://github.com/Filimoa/polars-h3) This is a Polars extension that adds support for + the H3 discrete global grid system, so you can index points and geometries to hexagons directly in + Polars. + +## Other material + +- [Ritchie Vink - Keynote on Polars Plugins](https://youtu.be/jKW-CBV7NUM) +- [Polars plugins tutorial](https://marcogorelli.github.io/polars-plugins-tutorial/) Learn how to + write a plugin by going through some very simple and minimal examples +- [cookiecutter-polars-plugin](https://github.com/MarcoGorelli/cookiecutter-polars-plugins) Project + template for Polars Plugins diff --git a/docs/source/user-guide/plugins/io_plugins.md b/docs/source/user-guide/plugins/io_plugins.md new file mode 100644 index 000000000000..11fc6bff415f --- /dev/null +++ b/docs/source/user-guide/plugins/io_plugins.md @@ -0,0 +1,195 @@ +# IO Plugins + +Besides [expression plugins](./expr_plugins.md), we also support IO plugins. These allow you to +register different file formats as sources to the Polars engines. Because sources can move data zero +copy via Arrow FFI and sources can produce large chunks of data before returning, we've decided to +interface to IO plugins via Python for now, as we don't think the short time the GIL is needed +should lead to any contention. + +E.g. an IO source can read their dataframe's in rust and only at the rendez-vous move the data +zero-copy having only a short time the GIL is needed. + +## Use case + +You want IO plugins if you have a source file not supported by Polars and you want to benefit from +optimizations like projection pushdown, predicate pushdown, early stopping and support of our +streaming engine. + +## Example + +So let's write a simple, very bad, custom CSV source and register that as an IO plugin. I want to +stress that this is a very bad example and is only given for learning purposes. + +First we define some imports we need: + +```python +# Use python for csv parsing. +import csv +import polars as pl +# Used to register a new generator on every instantiation. +from polars.io.plugins import register_io_source +from typing import Iterator +import io +``` + +### Parsing the schema + +Every `scan` function in Polars has to be able to provide the schema of the data it reads. For this +simple csv parser we will always read the data as `pl.String`. The only thing that differs are the +field names and the number of fields. + +```python +def parse_schema(csv_str: str) -> pl.Schema: + first_line = csv_str.split("\n")[0] + + return pl.Schema({k: pl.String for k in first_line.split(",")}) +``` + +If we run this with small csv file `"a,b,c\n1,2,3"` we get the schema: +`Schema([('a', String), ('b', String), ('c', String)])`. + +```python +>>> print(parse_schema("a,b,c\n1,2,3")) +Schema([('a', String), ('b', String), ('c', String)]) +``` + +### Writing the source + +Next up is the actual source. For this we create an outer and an inner function. The outer function +`my_scan_csv` is the user facing function. This function will accept the file name and other +potential arguments you would need for reading the source. For csv files, these arguments could be +"delimiter", "quote_char" and such. + +This outer function calls `register_io_source` which accepts a `callable` and a `schema`. The schema +is the Polars schema of the complete source file (independent of projection pushdown). + +The callable is a function that will return a generator that produces `pl.DataFrame` objects. + +The arguments of this function are predefined and this function must accept: + +- `with_columns` + + Columns that are projected. The reader must project these columns if applied + +- `predicate` + + Polars expression. The reader must filter their rows accordingly. + +- `n_rows` + + Materialize only n rows from the source. The reader can stop when `n_rows` are read. + +- `batch_size` + + A hint of the ideal batch size the reader's generator must produce. + +The inner function is the actual implementation of the IO source and can also call into Rust/C++ or +wherever the IO plugin is written. If you want to see an IO source implemented in Rust, take a look +at our [plugins repository](https://github.com/pola-rs/pyo3-polars/tree/main/example/io_plugin). + +```python +def my_scan_csv(csv_str: str) -> pl.LazyFrame: + schema = parse_schema(csv_str) + + def source_generator( + with_columns: list[str] | None, + predicate: pl.Expr | None, + n_rows: int | None, + batch_size: int | None, + ) -> Iterator[pl.DataFrame]: + """ + Generator function that creates the source. + This function will be registered as IO source. + """ + if batch_size is None: + batch_size = 100 + + # Initialize the reader. + reader = csv.reader(io.StringIO(csv_str), delimiter=',') + # Skip the header. + _ = next(reader) + + # Ensure we don't read more rows than requested from the engine + while n_rows is None or n_rows > 0: + if n_rows is not None: + batch_size = min(batch_size, n_rows) + + rows = [] + + for _ in range(batch_size): + try: + row = next(reader) + except StopIteration: + n_rows = 0 + break + rows.append(row) + + df = pl.from_records(rows, schema=schema, orient="row") + n_rows -= df.height + + # If we would make a performant reader, we would not read these + # columns at all. + if with_columns is not None: + df = df.select(with_columns) + + # If the source supports predicate pushdown, the expression can be parsed + # to skip rows/groups. + if predicate is not None: + df = df.filter(predicate) + + yield df + + return register_io_source(io_source=source_generator, schema=schema) +``` + +### Taking it for a (very slow) spin + +Finally we can test our source: + +```python +csv_str1 = """a,b,c,d +1,2,3,4 +9,10,11,2 +1,2,3,4 +1,122,3,4""" + +print(my_scan_csv(csv_str1).collect()) + + +csv_str2 = """a,b +1,2 +9,10 +1,2 +1,122""" + +print(my_scan_csv(csv_str2).head(2).collect()) +``` + +Running the script above would print the following output to the console: + +``` +shape: (4, 4) +┌─────┬─────┬─────┬─────┐ +│ a ┆ b ┆ c ┆ d │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ str ┆ str ┆ str │ +╞═════╪═════╪═════╪═════╡ +│ 1 ┆ 2 ┆ 3 ┆ 4 │ +│ 9 ┆ 10 ┆ 11 ┆ 2 │ +│ 1 ┆ 2 ┆ 3 ┆ 4 │ +│ 1 ┆ 122 ┆ 3 ┆ 4 │ +└─────┴─────┴─────┴─────┘ +shape: (2, 2) +┌─────┬─────┐ +│ a ┆ b │ +│ --- ┆ --- │ +│ str ┆ str │ +╞═════╪═════╡ +│ 1 ┆ 2 │ +│ 9 ┆ 10 │ +└─────┴─────┘ +``` + +## Further reading + +- [Rust example (distribution source)](https://github.com/pola-rs/pyo3-polars/tree/main/example/io_plugin) diff --git a/docs/source/user-guide/sql/create.md b/docs/source/user-guide/sql/create.md new file mode 100644 index 000000000000..3fb272161ec3 --- /dev/null +++ b/docs/source/user-guide/sql/create.md @@ -0,0 +1,33 @@ +# CREATE + +In Polars, the `SQLContext` provides a way to execute SQL statements against `LazyFrames` and +`DataFrames` using SQL syntax. One of the SQL statements that can be executed using `SQLContext` is +the `CREATE TABLE` statement, which is used to create a new table. + +The syntax for the `CREATE TABLE` statement in Polars is as follows: + +``` +CREATE TABLE table_name +AS +SELECT ... +``` + +In this syntax, `table_name` is the name of the new table that will be created, and `SELECT ...` is +a SELECT statement that defines the data that will be inserted into the table. + +Here's an example of how to use the `CREATE TABLE` statement in Polars: + +{{code_block('user-guide/sql/create','create',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/create.py:setup" +--8<-- "python/user-guide/sql/create.py:create" +``` + +In this example, we use the `execute()` method of the `SQLContext` to execute a `CREATE TABLE` +statement that creates a new table called `older_people` based on a SELECT statement that selects +all rows from the `my_table` DataFrame where the `age` column is greater than 30. + +!!! note Result + + Note that the result of a `CREATE TABLE` statement is not the table itself. The table is registered in the `SQLContext`. In case you want to turn the table back to a `DataFrame` you can use a `SELECT * FROM ...` statement diff --git a/docs/source/user-guide/sql/cte.md b/docs/source/user-guide/sql/cte.md new file mode 100644 index 000000000000..90ec01aa3e33 --- /dev/null +++ b/docs/source/user-guide/sql/cte.md @@ -0,0 +1,40 @@ +# Common Table Expressions + +Common Table Expressions (CTEs) are a feature of SQL that allow you to define a temporary named +result set that can be referenced within a SQL statement. CTEs provide a way to break down complex +SQL queries into smaller, more manageable pieces, making them easier to read, write, and maintain. + +A CTE is defined using the `WITH` keyword followed by a comma-separated list of subqueries, each of +which defines a named result set that can be used in subsequent queries. The syntax for a CTE is as +follows: + +``` +WITH cte_name AS ( + subquery +) +SELECT ... +``` + +In this syntax, `cte_name` is the name of the CTE, and `subquery` is the subquery that defines the +result set. The CTE can then be referenced in subsequent queries as if it were a table or view. + +CTEs are particularly useful when working with complex queries that involve multiple levels of +subqueries, as they allow you to break down the query into smaller, more manageable pieces that are +easier to understand and debug. Additionally, CTEs can help improve query performance by allowing +the database to optimize and cache the results of subqueries, reducing the number of times they need +to be executed. + +Polars supports Common Table Expressions (CTEs) using the WITH clause in SQL syntax. Below is an +example + +{{code_block('user-guide/sql/cte','cte',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/cte" +--8<-- "python/user-guide/sql/cte.py:setup" +--8<-- "python/user-guide/sql/cte.py:cte" +``` + +In this example, we use the `execute()` method of the `SQLContext` to execute a SQL query that +includes a CTE. The CTE selects all rows from the `my_table` LazyFrame where the `age` column is +greater than 30 and gives it the alias `older_people`. We then execute a second SQL query that +selects all rows from the `older_people` CTE where the `name` column starts with the letter 'C'. diff --git a/docs/source/user-guide/sql/intro.md b/docs/source/user-guide/sql/intro.md new file mode 100644 index 000000000000..0c475c9d1cf2 --- /dev/null +++ b/docs/source/user-guide/sql/intro.md @@ -0,0 +1,126 @@ +# Introduction + +While Polars supports interaction with SQL, it's recommended that users familiarize themselves with +the [expression syntax](../concepts/expressions-and-contexts.md#expressions) to produce more +readable and expressive code. As the DataFrame interface is primary, new features are typically +added to the expression API first. However, if you already have an existing SQL codebase or prefer +the use of SQL, Polars does offers support for this. + +!!! note Execution + + There is no separate SQL engine because Polars translates SQL queries into [expressions](../concepts/expressions-and-contexts.md#expressions), which are then executed using its own engine. This approach ensures that Polars maintains its performance and scalability advantages as a native DataFrame library, while still providing users with the ability to work with SQL. + +## Context + +Polars uses the `SQLContext` object to manage SQL queries. The context contains a mapping of +`DataFrame` and `LazyFrame` identifier names to their corresponding datasets[^1]. The example below +starts a `SQLContext`: + +{{code_block('user-guide/sql/intro','context',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:setup" +--8<-- "python/user-guide/sql/intro.py:context" +``` + +## Register Dataframes + +There are several ways to register DataFrames during `SQLContext` initialization. + +- register all `LazyFrame` and `DataFrame` objects in the global namespace. +- register explicitly via a dictionary mapping, or kwargs. + +{{code_block('user-guide/sql/intro','register_context',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:register_context" +``` + +We can also register Pandas DataFrames by converting them to Polars first. + +{{code_block('user-guide/sql/intro','register_pandas',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:register_pandas" +``` + +!!! note Pandas + + Converting a Pandas DataFrame backed by Numpy will trigger a potentially expensive conversion; however, if the Pandas DataFrame is already backed by Arrow then the conversion will be significantly cheaper (and in some cases close to free). + +Once the `SQLContext` is initialized, we can register additional Dataframes or unregister existing +Dataframes with: + +- `register` +- `register_globals` +- `register_many` +- `unregister` + +## Execute queries and collect results + +SQL queries are always executed in lazy mode to take advantage of the full set of query planning +optimizations, so we have two options to collect the result: + +- Set the parameter `eager_execution` to True in `SQLContext`; this ensures that Polars + automatically collects the LazyFrame results from `execute` calls. +- Set the parameter `eager` to True when executing a query with `execute`, or explicitly collect the + result using `collect`. + +We execute SQL queries by calling `execute` on a `SQLContext`. + +{{code_block('user-guide/sql/intro','execute',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:execute" +``` + +## Execute queries from multiple sources + +SQL queries can be executed just as easily from multiple sources. In the example below, we register: + +- a CSV file (loaded lazily) +- a NDJSON file (loaded lazily) +- a Pandas DataFrame + +And join them together using SQL. Lazy reading allows to only load the necessary rows and columns +from the files. + +In the same way, it's possible to register cloud datalakes (S3, Azure Data Lake). A PyArrow dataset +can point to the datalake, then Polars can read it with `scan_pyarrow_dataset`. + +{{code_block('user-guide/sql/intro','execute_multiple_sources',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:prepare_multiple_sources" +--8<-- "python/user-guide/sql/intro.py:execute_multiple_sources" +--8<-- "python/user-guide/sql/intro.py:clean_multiple_sources" +``` + +[^1]: Additionally it also tracks the [common table expressions](./cte.md) as well. + +## Compatibility + +Polars does not support the complete SQL specification, but it does support a subset of the most +common statement types. + +!!! note Dialect + + Where possible, Polars aims to follow PostgreSQL syntax definitions and function behaviour. + +For example, here is a non-exhaustive list of some of the supported functionality: + +- Write a `CREATE` statements: `CREATE TABLE xxx AS ...` +- Write a `SELECT` statements containing:`WHERE`,`ORDER`,`LIMIT`,`GROUP BY`,`UNION` and `JOIN` + clauses ... +- Write Common Table Expressions (CTE's) such as: `WITH tablename AS` +- Explain a query: `EXPLAIN SELECT ...` +- List registered tables: `SHOW TABLES` +- Drop a table: `DROP TABLE tablename` +- Truncate a table: `TRUNCATE TABLE tablename` + +The following are some features that are not yet supported: + +- `INSERT`, `UPDATE` or `DELETE` statements +- Meta queries such as `ANALYZE` + +In the upcoming sections we will cover each of the statements in more detail. diff --git a/docs/source/user-guide/sql/select.md b/docs/source/user-guide/sql/select.md new file mode 100644 index 000000000000..df223705a770 --- /dev/null +++ b/docs/source/user-guide/sql/select.md @@ -0,0 +1,82 @@ +# SELECT + +In Polars SQL, the `SELECT` statement is used to retrieve data from a table into a `DataFrame`. The +basic syntax of a `SELECT` statement in Polars SQL is as follows: + +```sql +SELECT column1, column2, ... +FROM table_name; +``` + +Here, `column1`, `column2`, etc. are the columns that you want to select from the table. You can +also use the wildcard `*` to select all columns. `table_name` is the name of the table or that you +want to retrieve data from. In the sections below we will cover some of the more common SELECT +variants + +{{code_block('user-guide/sql/select','df',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:setup" +--8<-- "python/user-guide/sql/select.py:df" +``` + +### GROUP BY + +The `GROUP BY` statement is used to group rows in a table by one or more columns and compute +aggregate functions on each group. + +{{code_block('user-guide/sql/select','group_by',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:group_by" +``` + +### ORDER BY + +The `ORDER BY` statement is used to sort the result set of a query by one or more columns in +ascending or descending order. + +{{code_block('user-guide/sql/select','orderby',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:orderby" +``` + +### JOIN + +{{code_block('user-guide/sql/select','join',['SQLregister_many','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:join" +``` + +### Functions + +Polars provides a wide range of SQL functions, including: + +- Mathematical functions: `ABS`, `EXP`, `LOG`, `ASIN`, `ACOS`, `ATAN`, etc. +- String functions: `LOWER`, `UPPER`, `LTRIM`, `RTRIM`, `STARTS_WITH`,`ENDS_WITH`. +- Aggregation functions: `SUM`, `AVG`, `MIN`, `MAX`, `COUNT`, `STDDEV`, `FIRST` etc. +- Array functions: `EXPLODE`, `UNNEST`,`ARRAY_SUM`,`ARRAY_REVERSE`, etc. + +For a full list of supported functions go the +[API documentation](https://docs.rs/polars-sql/latest/src/polars_sql/keywords.rs.html). The example +below demonstrates how to use a function in a query + +{{code_block('user-guide/sql/select','functions',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:functions" +``` + +### Table Functions + +In the examples earlier we first generated a DataFrame which we registered in the `SQLContext`. +Polars also support directly reading from CSV, Parquet, JSON and IPC in your SQL query using table +functions `read_xxx`. + +{{code_block('user-guide/sql/select','tablefunctions',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/select.py:tablefunctions" +``` diff --git a/docs/source/user-guide/sql/show.md b/docs/source/user-guide/sql/show.md new file mode 100644 index 000000000000..55a0496eadd3 --- /dev/null +++ b/docs/source/user-guide/sql/show.md @@ -0,0 +1,30 @@ +# SHOW TABLES + +In Polars, the `SHOW TABLES` statement is used to list all the tables that have been registered in +the current `SQLContext`. When you register a DataFrame with the `SQLContext`, you give it a name +that can be used to refer to the DataFrame in subsequent SQL statements. The `SHOW TABLES` statement +allows you to see a list of all the registered tables, along with their names. + +The syntax for the `SHOW TABLES` statement in Polars is as follows: + +``` +SHOW TABLES +``` + +Here's an example of how to use the `SHOW TABLES` statement in Polars: + +{{code_block('user-guide/sql/show','show',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/show" +--8<-- "python/user-guide/sql/show.py:setup" +--8<-- "python/user-guide/sql/show.py:show" +``` + +In this example, we create two DataFrames and register them with the `SQLContext` using different +names. We then execute a `SHOW TABLES` statement using the `execute()` method of the `SQLContext` +object, which returns a DataFrame containing a list of all the registered tables and their names. +The resulting DataFrame is then printed using the `print()` function. + +Note that the `SHOW TABLES` statement only lists tables that have been registered with the current +`SQLContext`. If you register a DataFrame with a different `SQLContext` or in a different Python +session, it will not appear in the list of tables returned by `SHOW TABLES`. diff --git a/docs/source/user-guide/transformations/concatenation.md b/docs/source/user-guide/transformations/concatenation.md new file mode 100644 index 000000000000..f43f8e781991 --- /dev/null +++ b/docs/source/user-guide/transformations/concatenation.md @@ -0,0 +1,75 @@ +# Concatenation + +There are a number of ways to concatenate data from separate DataFrames: + +- two dataframes with **the same columns** can be **vertically** concatenated to make a **longer** + dataframe +- two dataframes with **non-overlapping columns** can be **horizontally** concatenated to make a + **wider** dataframe +- two dataframes with **different numbers of rows and columns** can be **diagonally** concatenated + to make a dataframe which might be longer and/ or wider. Where column names overlap values will be + vertically concatenated. Where column names do not overlap new rows and columns will be added. + Missing values will be set as `null` + +## Vertical concatenation - getting longer + +In a vertical concatenation you combine all of the rows from a list of `DataFrames` into a single +longer `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','vertical',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:setup" +--8<-- "python/user-guide/transformations/concatenation.py:vertical" +``` + +Vertical concatenation fails when the dataframes do not have the same column names. + +## Horizontal concatenation - getting wider + +In a horizontal concatenation you combine all of the columns from a list of `DataFrames` into a +single wider `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','horizontal',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:horizontal" +``` + +Horizontal concatenation fails when dataframes have overlapping columns. + +When dataframes have different numbers of rows, columns will be padded with `null` values at the end +up to the maximum length. + +{{code_block('user-guide/transformations/concatenation','horizontal_different_lengths',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:horizontal_different_lengths" +``` + +## Diagonal concatenation - getting longer, wider and `null`ier + +In a diagonal concatenation you combine all of the row and columns from a list of `DataFrames` into +a single longer and/or wider `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','cross',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:cross" +``` + +Diagonal concatenation generates nulls when the column names do not overlap. + +When the dataframe shapes do not match and we have an overlapping semantic key then +[we can join the dataframes](joins.md) instead of concatenating them. + +## Rechunking + +Before a concatenation we have two dataframes `df1` and `df2`. Each column in `df1` and `df2` is in +one or more chunks in memory. By default, during concatenation the chunks in each column are not +made contiguous. This makes the concat operation faster and consume less memory but it may slow down +future operations that would benefit from having the data be in contiguous memory. The process of +copying the fragmented chunks into a single new chunk is known as **rechunking**. Rechunking is an +expensive operation. Prior to version 0.20.26, the default was to perform a rechunk but in new +versions, the default is not to. If you do want Polars to rechunk the concatenated `DataFrame` you +specify `rechunk = True` when doing the concatenation. diff --git a/docs/source/user-guide/transformations/index.md b/docs/source/user-guide/transformations/index.md new file mode 100644 index 000000000000..452aaeb3e892 --- /dev/null +++ b/docs/source/user-guide/transformations/index.md @@ -0,0 +1,11 @@ +# Transformations + +The focus of this section is to describe different types of data transformations and provide some +examples on how to use them. + + + +- [Joins](joins.md) +- [Concatenation](concatenation.md) +- [Pivot](pivot.md) +- [Unpivot](unpivot.md) diff --git a/docs/source/user-guide/transformations/joins.md b/docs/source/user-guide/transformations/joins.md new file mode 100644 index 000000000000..f50418532347 --- /dev/null +++ b/docs/source/user-guide/transformations/joins.md @@ -0,0 +1,306 @@ +# Joins + +A join operation combines columns from one or more dataframes into a new dataframe. The different +“joining strategies” and matching criteria used by the different types of joins influence how +columns are combined and also what rows are included in the result of the join operation. + +The most common type of join is an “equi join”, in which rows are matched by a key expression. +Polars supports several joining strategies for equi joins, which determine exactly how we handle the +matching of rows. Polars also supports “non-equi joins”, a type of join where the matching criterion +is not an equality, and a type of join where rows are matched by key proximity, called “asof join”. + +## Quick reference table + +The table below acts as a quick reference for people who know what they are looking for. If you want +to learn about joins in general and how to work with them in Polars, feel free to skip the table and +keep reading below. + +=== ":fontawesome-brands-python: Python" + + [:material-api: `join`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html) + [:material-api: `join_where`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_where.html) + [:material-api: `join_asof`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join_asof.html) + +=== ":fontawesome-brands-rust: Rust" + + [:material-api: `join`](https://docs.pola.rs/api/rust/dev/polars/prelude/trait.DataFrameJoinOps.html#method.join) + ([:material-flag-plus: semi_anti_join](/user-guide/installation/#feature-flags "Enable the feature flag semi_anti_join for semi and for anti joins"){.feature-flag} needed for some options.) + [:material-api: `join_asof_by`](https://docs.pola.rs/api/rust/dev/polars/prelude/trait.AsofJoinBy.html#method.join_asof_by) + [:material-flag-plus: Available on feature asof_join](/user-guide/installation/#feature-flags "To use this functionality enable the feature flag asof_join"){.feature-flag} + [:material-api: `join_where`](https://docs.rs/polars/latest/polars/prelude/struct.JoinBuilder.html#method.join_where) + [:material-flag-plus: Available on feature iejoin](/user-guide/installation/#feature-flags "To use this functionality enable the feature flag iejoin"){.feature-flag} + +| Type | Function | Brief description | +| --------------------- | -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Equi inner join | `join(..., how="inner")` | Keeps rows that matched both on the left and right. | +| Equi left outer join | `join(..., how="left")` | Keeps all rows from the left plus matching rows from the right. Non-matching rows from the left have their right columns filled with `null`. | +| Equi right outer join | `join(..., how="right")` | Keeps all rows from the right plus matching rows from the left. Non-matching rows from the right have their left columns filled with `null`. | +| Equi full join | `join(..., how="full")` | Keeps all rows from either dataframe, regardless of whether they match or not. Non-matching rows from one side have the columns from the other side filled with `null`. | +| Equi semi join | `join(..., how="semi")` | Keeps rows from the left that have a match on the right. | +| Equi anti join | `join(..., how="anti")` | Keeps rows from the left that do not have a match on the right. | +| Non-equi inner join | `join_where` | Finds all possible pairings of rows from the left and right that satisfy the given predicate(s). | +| Asof join | `join_asof`/`join_asof_by` | Like a left outer join, but matches on the nearest key instead of on exact key matches. | +| Cartesian product | `join(..., how="cross")` | Computes the [Cartesian product](https://en.wikipedia.org/wiki/Cartesian_product) of the two dataframes. | + +## Equi joins + +In an equi join, rows are matched by checking equality of a key expression. You can do an equi join +with the function `join` by specifying the name of the column to be used as key. For the examples, +we will be loading some (modified) Monopoly property data. + +First, we load a dataframe that contains property names and their colour group in the game: + +{{code_block('user-guide/transformations/joins','props_groups',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:prep-data" +--8<-- "python/user-guide/transformations/joins.py:props_groups" +``` + +Next, we load a dataframe that contains property names and their price in the game: + +{{code_block('user-guide/transformations/joins','props_prices',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:props_prices" +``` + +Now, we join both dataframes to create a dataframe that contains property names, colour groups, and +prices: + +{{code_block('user-guide/transformations/joins','equi-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:equi-join" +``` + +The result has four rows but both dataframes used in the operation had five rows. Polars uses a +joining strategy to determine what happens with rows that have multiple matches or with rows that +have no match at all. By default, Polars computes an “inner join” but there are +[other join strategies that we show next](#join-strategies). + +In the example above, the two dataframes conveniently had the column we wish to use as key with the +same name and with the values in the exact same format. Suppose, for the sake of argument, that one +of the dataframes had a differently named column and the other had the property names in lower case: + +{{code_block('user-guide/transformations/joins','props_groups2',['Expr.str'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:props_groups2" +``` + +{{code_block('user-guide/transformations/joins','props_prices2',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:props_prices2" +``` + +In a situation like this, where we may want to perform the same join as before, we can leverage +`join`'s flexibility and specify arbitrary expressions to compute the joining key on the left and on +the right, allowing one to compute row keys dynamically: + +{{code_block('user-guide/transformations/joins', 'join-key-expression', ['join', 'Expr.str'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:join-key-expression" +``` + +Because we are joining on the right with an expression, Polars preserves the column “property_name” +from the left and the column “name” from the right so we can have access to the original values that +the key expressions were applied to. + +## Join strategies + +When computing a join with `df1.join(df2, ...)`, we can specify one of many different join +strategies. A join strategy specifies what rows to keep from each dataframe based on whether they +match rows from the other dataframe. + +### Inner join + +In an inner join the resulting dataframe only contains the rows from the left and right dataframes +that matched. That is the default strategy used by `join` and above we can see an example of that. +We repeat the example here and explicitly specify the join strategy: + +{{code_block('user-guide/transformations/joins','inner-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:inner-join" +``` + +The result does not include the row from `props_groups` that contains “The Shire” and the result +also does not include the row from `props_prices` that contains “Sesame Street”. + +### Left join + +A left outer join is a join where the result contains all the rows from the left dataframe and the +rows of the right dataframe that matched any rows from the left dataframe. + +{{code_block('user-guide/transformations/joins','left-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:left-join" +``` + +If there are any rows from the left dataframe that have no matching rows on the right dataframe, +they get the value `null` on the new columns. + +### Right join + +Computationally speaking, a right outer join is exactly the same as a left outer join, but with the +arguments swapped. Here is an example: + +{{code_block('user-guide/transformations/joins','right-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:right-join" +``` + +We show that `df1.join(df2, how="right", ...)` is the same as `df2.join(df1, how="left", ...)`, up +to the order of the columns of the result, with the computation below: + +{{code_block('user-guide/transformations/joins','left-right-join-equals',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:left-right-join-equals" +``` + +### Full join + +A full outer join will keep all of the rows from the left and right dataframes, even if they don't +have matching rows in the other dataframe: + +{{code_block('user-guide/transformations/joins','full-join',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:full-join" +``` + +In this case, we see that we get two columns `property_name` and `property_name_right` to make up +for the fact that we are matching on the column `property_name` of both dataframes and there are +some names for which there are no matches. The two columns help differentiate the source of each row +data. If we wanted to force `join` to coalesce the two columns `property_name` into a single column, +we could set `coalesce=True` explicitly: + +{{code_block('user-guide/transformations/joins','full-join-coalesce',['join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:full-join-coalesce" +``` + +When not set, the parameter `coalesce` is determined automatically from the join strategy and the +key(s) specified, which is why the inner, left, and right, joins acted as if `coalesce=True`, even +though we didn't set it. + +### Semi join + +A semi join will return the rows of the left dataframe that have a match in the right dataframe, but +we do not actually join the matching rows: + +{{code_block('user-guide/transformations/joins', 'semi-join', [], ['join'], +['join-semi_anti_join_flag'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:semi-join" +``` + +A semi join acts as a sort of row filter based on a second dataframe. + +### Anti join + +Conversely, an anti join will return the rows of the left dataframe that do not have a match in the +right dataframe: + +{{code_block('user-guide/transformations/joins', 'anti-join', [], ['join'], +['join-semi_anti_join_flag'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:anti-join" +``` + +## Non-equi joins + +In a non-equi join matches between the left and right dataframes are computed differently. Instead +of looking for matches on key expressions, we provide a single predicate that determines what rows +of the left dataframe can be paired up with what rows of the right dataframe. + +For example, consider the following Monopoly players and their current cash: + +{{code_block('user-guide/transformations/joins','players',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:players" +``` + +Using a non-equi join we can easily build a dataframe with all the possible properties that each +player could be interested in buying. We use the function `join_where` to compute a non-equi join: + +{{code_block('user-guide/transformations/joins','non-equi',['join_where'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:non-equi" +``` + +You can provide multiple expressions as predicates but they all must use comparison operators that +evaluate to a Boolean result and must refer to columns from both dataframes. + +!!! note + + `join_where` is still experimental and doesn't yet support arbitrary Boolean expressions as predicates. + +## Asof join + +An `asof` join is like a left join except that we match on nearest key rather than equal keys. In +Polars we can do an asof join with the `join_asof` method. + +For the asof join we will consider a scenario inspired by the stock market. Suppose a stock market +broker has a dataframe called `df_trades` showing transactions it has made for different stocks. + +{{code_block('user-guide/transformations/joins','df_trades',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df_trades" +``` + +The broker has another dataframe called `df_quotes` showing prices it has quoted for these stocks: + +{{code_block('user-guide/transformations/joins','df_quotes',[])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df_quotes" +``` + +You want to produce a dataframe showing for each trade the most recent quote provided _on or before_ +the time of the trade. You do this with `join_asof` (using the default `strategy = "backward"`). To +avoid joining between trades on one stock with a quote on another you must specify an exact +preliminary join on the stock column with `by="stock"`. + +{{code_block('user-guide/transformations/joins','asof', [], ['join_asof'], ['join_asof_by'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:asof" +``` + +If you want to make sure that only quotes within a certain time range are joined to the trades you +can specify the `tolerance` argument. In this case we want to make sure that the last preceding +quote is within 1 minute of the trade so we set `tolerance = "1m"`. + +{{code_block('user-guide/transformations/joins','asof-tolerance', [], ['join_asof'], +['join_asof_by'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:asof-tolerance" +``` + +## Cartesian product + +Polars allows you to compute the +[Cartesian product](https://en.wikipedia.org/wiki/Cartesian_product) of two dataframes, producing a +dataframe where all rows of the left dataframe are paired up with all the rows of the right +dataframe. To compute the Cartesian product of two dataframes, you can pass the strategy +`how="cross"` to the function `join` without specifying any of `on`, `left_on`, and `right_on`: + +{{code_block('user-guide/transformations/joins','cartesian-product',[],['join'],['cross_join'])}} + +```python exec="on" result="text" session="transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:cartesian-product" +``` diff --git a/docs/source/user-guide/transformations/pivot.md b/docs/source/user-guide/transformations/pivot.md new file mode 100644 index 000000000000..7ddc7feb24ca --- /dev/null +++ b/docs/source/user-guide/transformations/pivot.md @@ -0,0 +1,47 @@ +# Pivots + +Pivot a column in a `DataFrame` and perform one of the following aggregations: + +- first +- last +- sum +- min +- max +- mean +- median +- len + +The pivot operation consists of a group by one, or multiple columns (these will be the new y-axis), +the column that will be pivoted (this will be the new x-axis) and an aggregation. + +## Dataset + +{{code_block('user-guide/transformations/pivot','df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:setup" +--8<-- "python/user-guide/transformations/pivot.py:df" +``` + +## Eager + +{{code_block('user-guide/transformations/pivot','eager',['pivot'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:eager" +``` + +## Lazy + +A Polars `LazyFrame` always need to know the schema of a computation statically (before collecting +the query). As a pivot's output schema depends on the data, and it is therefore impossible to +determine the schema without running the query. + +Polars could have abstracted this fact for you just like Spark does, but we don't want you to shoot +yourself in the foot with a shotgun. The cost should be clear upfront. + +{{code_block('user-guide/transformations/pivot','lazy',['pivot'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:lazy" +``` diff --git a/docs/source/user-guide/transformations/time-series/filter.md b/docs/source/user-guide/transformations/time-series/filter.md new file mode 100644 index 000000000000..05b19aed57b5 --- /dev/null +++ b/docs/source/user-guide/transformations/time-series/filter.md @@ -0,0 +1,50 @@ +# Filtering + +Filtering date columns works in the same way as with other types of columns using the `.filter` +method. + +Polars uses Python's native `datetime`, `date` and `timedelta` for equality comparisons between the +datatypes `pl.Datetime`, `pl.Date` and `pl.Duration`. + +In the following example we use a time series of Apple stock prices. + +{{code_block('user-guide/transformations/time-series/filter','df',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:df" +``` + +## Filtering by single dates + +We can filter by a single date using an equality comparison in a filter expression: + +{{code_block('user-guide/transformations/time-series/filter','filter',['filter'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:filter" +``` + +Note we are using the lowercase `datetime` method rather than the uppercase `Datetime` data type. + +## Filtering by a date range + +We can filter by a range of dates using the `is_between` method in a filter expression with the +start and end dates: + +{{code_block('user-guide/transformations/time-series/filter','range',['filter','is_between'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:range" +``` + +## Filtering with negative dates + +Say you are working with an archeologist and are dealing in negative dates. Polars can parse and +store them just fine, but the Python `datetime` library does not. So for filtering, you should use +attributes in the `.dt` namespace: + +{{code_block('user-guide/transformations/time-series/filter','negative',['str.to_date'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:negative" +``` diff --git a/docs/source/user-guide/transformations/time-series/parsing.md b/docs/source/user-guide/transformations/time-series/parsing.md new file mode 100644 index 000000000000..112f0695e9fc --- /dev/null +++ b/docs/source/user-guide/transformations/time-series/parsing.md @@ -0,0 +1,68 @@ +# Parsing + +Polars has native support for parsing time series data and doing more sophisticated operations such +as temporal grouping and resampling. + +## Datatypes + +Polars has the following datetime datatypes: + +- `Date`: Date representation e.g. 2014-07-08. It is internally represented as days since UNIX epoch + encoded by a 32-bit signed integer. +- `Datetime`: Datetime representation e.g. 2014-07-08 07:00:00. It is internally represented as a 64 + bit integer since the Unix epoch and can have different units such as ns, us, ms. +- `Duration`: A time delta type that is created when subtracting `Date/Datetime`. Similar to + `timedelta` in Python. +- `Time`: Time representation, internally represented as nanoseconds since midnight. + +## Parsing dates from a file + +When loading from a CSV file Polars attempts to parse dates and times if the `try_parse_dates` flag +is set to `True`: + +{{code_block('user-guide/transformations/time-series/parsing','df',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:setup" +--8<-- "python/user-guide/transformations/time-series/parsing.py:df" +``` + +This flag will trigger schema inference on a number of rows, as configured by the +`infer_schema_length` setting (100 rows by default). Schema inference is computationally expensive +and can slow down file loading if a high number of rows is used. + +On the other hand binary formats such as parquet have a schema that is respected by Polars. + +## Casting strings to dates + +You can also cast a column of datetimes encoded as strings to a datetime type. You do this by +calling the string `str.to_date` method and passing the format of the date string: + +{{code_block('user-guide/transformations/time-series/parsing','cast',['read_csv','str.to_date'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:cast" +``` + +[The format string specification can be found here.](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). + +## Extracting date features from a date column + +You can extract data features such as the year or day from a date column using the `.dt` namespace: + +{{code_block('user-guide/transformations/time-series/parsing','extract',['dt.year'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:extract" +``` + +## Mixed offsets + +If you have mixed offsets (say, due to crossing daylight saving time), then you can use `utc=True` +and then convert to your time zone: + +{{code_block('user-guide/transformations/time-series/parsing','mixed',['str.to_datetime','dt.convert_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:mixed" +``` diff --git a/docs/source/user-guide/transformations/time-series/resampling.md b/docs/source/user-guide/transformations/time-series/resampling.md new file mode 100644 index 000000000000..e82abe3f1d2f --- /dev/null +++ b/docs/source/user-guide/transformations/time-series/resampling.md @@ -0,0 +1,47 @@ +# Resampling + +We can resample by either: + +- upsampling (moving data to a higher frequency) +- downsampling (moving data to a lower frequency) +- combinations of these e.g. first upsample and then downsample + +## Downsampling to a lower frequency + +Polars views downsampling as a special case of the **group_by** operation and you can do this with +`group_by_dynamic` and `group_by_rolling` - +[see the temporal group by page for examples](rolling.md). + +## Upsampling to a higher frequency + +Let's go through an example where we generate data at 30 minute intervals: + +{{code_block('user-guide/transformations/time-series/resampling','df',['DataFrame','date_range'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:setup" +--8<-- "python/user-guide/transformations/time-series/resampling.py:df" +``` + +Upsampling can be done by defining the new sampling interval. By upsampling we are adding in extra +rows where we do not have data. As such upsampling by itself gives a DataFrame with nulls. These +nulls can then be filled with a fill strategy or interpolation. + +### Upsampling strategies + +In this example we upsample from the original 30 minutes to 15 minutes and then use a `forward` +strategy to replace the nulls with the previous non-null value: + +{{code_block('user-guide/transformations/time-series/resampling','upsample',['upsample'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:upsample" +``` + +In this example we instead fill the nulls by linear interpolation: + +{{code_block('user-guide/transformations/time-series/resampling','upsample2',['upsample','interpolate','fill_null'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:upsample2" +``` diff --git a/docs/source/user-guide/transformations/time-series/rolling.md b/docs/source/user-guide/transformations/time-series/rolling.md new file mode 100644 index 000000000000..126a68290900 --- /dev/null +++ b/docs/source/user-guide/transformations/time-series/rolling.md @@ -0,0 +1,163 @@ +# Grouping + +## Grouping by fixed windows + +We can calculate temporal statistics using `group_by_dynamic` to group rows into days/months/years +etc. + +### Annual average example + +In following simple example we calculate the annual average closing price of Apple stock prices. We +first load the data from CSV: + +{{code_block('user-guide/transformations/time-series/rolling','df',['upsample'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:setup" +--8<-- "python/user-guide/transformations/time-series/rolling.py:df" +``` + +!!! info + + The dates are sorted in ascending order - if they are not sorted in this way the `group_by_dynamic` output will not be correct! + +To get the annual average closing price we tell `group_by_dynamic` that we want to: + +- group by the `Date` column on an annual (`1y`) basis +- take the mean values of the `Close` column for each year: + +{{code_block('user-guide/transformations/time-series/rolling','group_by',['group_by_dynamic'])}} + +The annual average closing price is then: + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by" +``` + +### Parameters for `group_by_dynamic` + +A dynamic window is defined by a: + +- **every**: indicates the interval of the window +- **period**: indicates the duration of the window +- **offset**: can be used to offset the start of the windows + +The value for `every` sets how often the groups start. The time period values are flexible - for +example we could take: + +- the average over 2 year intervals by replacing `1y` with `2y` +- the average over 18 month periods by replacing `1y` with `1y6mo` + +We can also use the `period` parameter to set how long the time period for each group is. For +example, if we set the `every` parameter to be `1y` and the `period` parameter to be `2y` then we +would get groups at one year intervals where each groups spanned two years. + +If the `period` parameter is not specified then it is set equal to the `every` parameter so that if +the `every` parameter is set to be `1y` then each group spans `1y` as well. + +Because _**every**_ does not have to be equal to _**period**_, we can create many groups in a very +flexible way. They may overlap or leave boundaries between them. + +Let's see how the windows for some parameter combinations would look. Let's start out boring. 🥱 + +- every: 1 day -> `"1d"` +- period: 1 day -> `"1d"` + +```text +this creates adjacent windows of the same size +|--| + |--| + |--| +``` + +- every: 1 day -> `"1d"` +- period: 2 days -> `"2d"` + +```text +these windows have an overlap of 1 day +|----| + |----| + |----| +``` + +- every: 2 days -> `"2d"` +- period: 1 day -> `"1d"` + +```text +this would leave gaps between the windows +data points that in these gaps will not be a member of any group +|--| + |--| + |--| +``` + +#### `truncate` + +The `truncate` parameter is a Boolean variable that determines what datetime value is associated +with each group in the output. In the example above the first data point is on 23rd February 1981. +If `truncate = True` (the default) then the date for the first year in the annual average is 1st +January 1981. However, if `truncate = False` then the date for the first year in the annual average +is the date of the first data point on 23rd February 1981. Note that `truncate` only affects what's +shown in the `Date` column and does not affect the window boundaries. + +### Using expressions in `group_by_dynamic` + +We aren't restricted to using simple aggregations like `mean` in a group by operation - we can use +the full range of expressions available in Polars. + +In the snippet below we create a `date range` with every **day** (`"1d"`) in 2021 and turn this into +a `DataFrame`. + +Then in the `group_by_dynamic` we create dynamic windows that start every **month** (`"1mo"`) and +have a window length of `1` month. The values that match these dynamic windows are then assigned to +that group and can be aggregated with the powerful expression API. + +Below we show an example where we use **group_by_dynamic** to compute: + +- the number of days until the end of the month +- the number of days in a month + +{{code_block('user-guide/transformations/time-series/rolling','group_by_dyn',['group_by_dynamic','DataFrame.explode','date_range'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_dyn" +``` + +## Grouping by rolling windows + +The rolling operation, `rolling`, is another entrance to the `group_by`/`agg` context. But different +from the `group_by_dynamic` where the windows are fixed by a parameter `every` and `period`. In a +`rolling`, the windows are not fixed at all! They are determined by the values in the +`index_column`. + +So imagine having a time column with the values `{2021-01-06, 2021-01-10}` and a `period="5d"` this +would create the following windows: + +```text +2021-01-01 2021-01-06 + |----------| + + 2021-01-05 2021-01-10 + |----------| +``` + +Because the windows of a rolling group by are always determined by the values in the `DataFrame` +column, the number of groups is always equal to the original `DataFrame`. + +## Combining group by operations + +Rolling and dynamic group by operations can be combined with normal group by operations. + +Below is an example with a dynamic group by. + +{{code_block('user-guide/transformations/time-series/rolling','group_by_roll',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_roll" +``` + +{{code_block('user-guide/transformations/time-series/rolling','group_by_dyn2',['group_by_dynamic'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_dyn2" +``` diff --git a/docs/source/user-guide/transformations/time-series/timezones.md b/docs/source/user-guide/transformations/time-series/timezones.md new file mode 100644 index 000000000000..25e6c873c50b --- /dev/null +++ b/docs/source/user-guide/transformations/time-series/timezones.md @@ -0,0 +1,46 @@ +--- +hide: + - toc +--- + +# Time zones + +!!! quote "Tom Scott" + + You really should never, ever deal with time zones if you can help it. + +The `Datetime` datatype can have a time zone associated with it. Examples of valid time zones are: + +- `None`: no time zone, also known as "time zone naive". +- `UTC`: Coordinated Universal Time. +- `Asia/Kathmandu`: time zone in "area/location" format. See the + [list of tz database time zones](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) to + see what's available. + +Caution: Fixed offsets such as +02:00, should not be used for handling time zones. It's advised to +use the "Area/Location" format mentioned above, as it can manage timezones more effectively. + +Note that, because a `Datetime` can only have a single time zone, it is impossible to have a column +with multiple time zones. If you are parsing data with multiple offsets, you may want to pass +`utc=True` to convert them all to a common time zone (`UTC`), see +[parsing dates and times](parsing.md). + +The main methods for setting and converting between time zones are: + +- `dt.convert_time_zone`: convert from one time zone to another. +- `dt.replace_time_zone`: set/unset/change time zone. + +Let's look at some examples of common operations: + +{{code_block('user-guide/transformations/time-series/timezones','example',['str.to_datetime','dt.replace_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/timezones" +--8<-- "python/user-guide/transformations/time-series/timezones.py:setup" +--8<-- "python/user-guide/transformations/time-series/timezones.py:example" +``` + +{{code_block('user-guide/transformations/time-series/timezones','example2',['dt.convert_time_zone','dt.replace_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/timezones" +--8<-- "python/user-guide/transformations/time-series/timezones.py:example2" +``` diff --git a/docs/source/user-guide/transformations/unpivot.md b/docs/source/user-guide/transformations/unpivot.md new file mode 100644 index 000000000000..83715a001dc8 --- /dev/null +++ b/docs/source/user-guide/transformations/unpivot.md @@ -0,0 +1,21 @@ +# Unpivots + +Unpivot unpivots a DataFrame from wide format to long format + +## Dataset + +{{code_block('user-guide/transformations/unpivot','df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/unpivot" +--8<-- "python/user-guide/transformations/unpivot.py:df" +``` + +## Eager + lazy + +`Eager` and `lazy` have the same API. + +{{code_block('user-guide/transformations/unpivot','unpivot',['unpivot'])}} + +```python exec="on" result="text" session="user-guide/transformations/unpivot" +--8<-- "python/user-guide/transformations/unpivot.py:unpivot" +``` diff --git a/dprint.json b/dprint.json new file mode 100644 index 000000000000..38bd6131b936 --- /dev/null +++ b/dprint.json @@ -0,0 +1,23 @@ +{ + "includes": [ + "**/*.{md,toml,json}" + ], + "markdown": { + "lineWidth": 100, + "textWrap": "always" + }, + "excludes": [ + ".venv/", + "**/target/", + "py-polars/.hypothesis/", + "py-polars/.mypy_cache/", + "py-polars/.pytest_cache/", + "py-polars/.ruff_cache/", + "py-polars/tests/unit/io/files/" + ], + "plugins": [ + "https://plugins.dprint.dev/json-0.19.3.wasm", + "https://plugins.dprint.dev/markdown-0.17.1.wasm", + "https://plugins.dprint.dev/toml-0.6.2.wasm" + ] +} diff --git a/examples/10_minutes_to_polars.ipynb b/examples/10_minutes_to_polars.ipynb deleted file mode 100644 index 9d5349c21028..000000000000 --- a/examples/10_minutes_to_polars.ipynb +++ /dev/null @@ -1,1744 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 10 minutes to polars\n", - "This as short introduction to Polars to get you started with the basic concepts of data wrangling. It is very much influenced by [10 minutes to pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/10min.html).\n", - "\n", - "We start by importing Polars. If you run this for the first time, get a coffee. This will take a while" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "use polars::prelude::*;\n", - "\n", - "#[macro_use]\n", - "extern crate polars;" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Object creation\n", - "Creating a `Series` by passing a list of nullable values. Note that we use `Option` to describe missing values." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: i32\n", - "[\n", - "\t1\n", - "\t3\n", - "\t5\n", - "\tnull\n", - "\t6\n", - "\t8\n", - "]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Series::new(\n", - " \"some_values with ones\", \n", - " &[Some(1), Some(3), Some(5), None, Some(6), Some(8)]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If we dont have any missing values, we can just pass a slice of `T`." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: i32\n", - "[\n", - "\t1\n", - "\t3\n", - "\t5\n", - "\t7\n", - "\t6\n", - "\t8\n", - "]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Series::new(\n", - " \"some_non_null_values\", \n", - " &[1, 3, 5, 7, 6, 8]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `Series` are actually an `Enum` around different typed values of a `ChunkedArray`. \n", - "You can think of a `ChunedkArray` as an array with a known type. Every `ChunkedArray` has a type alias that makes them more convenient to use. \n", - "\n", - "Some examples are:\n", - "\n", - "| Type | Alias |\n", - "|-----------------------------|------------------|\n", - "| `ChunkedArray` | `Float64Chunked` |\n", - "| `ChunkedArray` | `UInt32Chunked` |\n", - "| `ChunkedArray` | `BooleanChunked` |\n", - "| `ChunkedArray` | `Utf8Chunked` |\n", - "\n", - "See all available data types [here](https://ritchie46.github.io/polars/polars/datatypes/index.html).\n", - "\n", - "Create a `ChunkedArray` with null values:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[PrimitiveArray\n", - "[\n", - " null,\n", - " 1,\n", - " 2,\n", - "]]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Int64Chunked::new_from_opt_slice(\"nullable\", &[None, Some(1), Some(2)])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or create a `ChunkedArray` without null values." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[PrimitiveArray\n", - "[\n", - " 1,\n", - " 2,\n", - " 3,\n", - "]]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Int64Chunked::new_from_slice(\"non-nullable\", &[1, 2, 3])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Converting from `Series` to a `ChunkedArray` can be done by defining there type." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Ok([PrimitiveArray\n", - "[\n", - " 1,\n", - " 2,\n", - " 3,\n", - "]])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let s = Series::new(\"values\", &[1, 2, 3]);\n", - "s.i32()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This will return an `Err` if you specify the wrong type." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Err(DataTypeMisMatch)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s.i64()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "But we can cast a `Series` to the proper type and then unpack." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Ok([PrimitiveArray\n", - "[\n", - " 1,\n", - " 2,\n", - " 3,\n", - "]])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s.cast::().unwrap().i64()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Below we use pattern matching to check if the cast was successful. Note that the clones on a `ChunkedArray` and a `Series` are very cheap, as the underlying data is wrapped by an `Arc`. " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[PrimitiveArray\n", - "[\n", - " 1,\n", - " 2,\n", - " 3,\n", - "]]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let ca = match s.i64() {\n", - " Err(_) => {\n", - " s.cast::()\n", - " .unwrap()\n", - " .i64()\n", - " .map(|ca| ca.clone())\n", - " .unwrap()\n", - " },\n", - " Ok(ca) => ca.clone()\n", - "};\n", - "ca" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Converting from a `ChunkedArray` to a `Series`." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: i64\n", - "[\n", - "\t1\n", - "\t2\n", - "\t3\n", - "]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ca.into_series()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A `DataFrame` is created from a `Vec` of `Series`." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-21 | 1 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-21 | 2 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-23 | 4 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 5 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let dates = &[\n", - " \"2020-08-21\",\n", - " \"2020-08-21\",\n", - " \"2020-08-22\",\n", - " \"2020-08-23\",\n", - " \"2020-08-22\",\n", - "];\n", - "let fmt = \"%Y-%m-%d\";\n", - "let s0 = Date32Chunked::parse_from_str_slice(\"dates\", dates, fmt).into();\n", - "let s1 = Series::new(\"n\", &[1, 2, 3, 4, 5]);\n", - "let s2 = Utf8Chunked::full(\"foos\", \"foo\", 5).into();\n", - "\n", - "let df = DataFrame::new(vec![s0, s1, s2]).expect(\"something went wrong\");\n", - "df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The columns of the resulting `DataFrame` have different data types." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Column: 'dates',\t dtype: Date32(Day)\n", - "Column: 'n',\t dtype: Int32\n", - "Column: 'foos',\t dtype: Utf8\n" - ] - }, - { - "data": { - "text/plain": [ - "()" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.dtypes()\n", - " .iter()\n", - " .zip(df.columns().iter())\n", - " .for_each(|(dtype, name)| \n", - " println!(\"Column: '{}',\\t dtype: {:?}\", name, dtype))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[\"dates\", \"n\", \"foos\"]" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.columns()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Viewing data\n", - "\n", - "Here is how to view the top and bottom rows of a DataFrame." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-21 | 1 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-21 | 2 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.head(Some(3))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-23 | 4 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 5 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.tail(Some(3))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Sorting by a column:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-23 | 4 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 5 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-21 | 1 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-21 | 2 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let reverse = true;\n", - "df.sort(\"dates\", reverse).expect(\"column not sortable\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Selection\n", - "Selecting a single column, which yields a `Result`:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: date32(day)\n", - "[\n", - "\t2020-08-21\n", - "\t2020-08-21\n", - "\t2020-08-22\n", - "\t2020-08-23\n", - "\t2020-08-22\n", - "]" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.column(\"dates\")\n", - " .expect(\"columns don't exist\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Selecting 1 or multiple columns, which yield another `Result`:" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+\n", - "| dates |\n", - "| --- |\n", - "| date32 |\n", - "+============+\n", - "| 2020-08-21 |\n", - "+------------+\n", - "| 2020-08-21 |\n", - "+------------+\n", - "| 2020-08-22 |\n", - "+------------+\n", - "| 2020-08-23 |\n", - "+------------+\n", - "| 2020-08-22 |\n", - "+------------+\n" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.select(\"dates\")\n", - " .expect(\"column does not exist\")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+\n", - "| dates | n |\n", - "| --- | --- |\n", - "| date32 | i32 |\n", - "+============+=====+\n", - "| 2020-08-21 | 1 |\n", - "+------------+-----+\n", - "| 2020-08-21 | 2 |\n", - "+------------+-----+\n", - "| 2020-08-22 | 3 |\n", - "+------------+-----+\n", - "| 2020-08-23 | 4 |\n", - "+------------+-----+\n", - "| 2020-08-22 | 5 |\n", - "+------------+-----+\n" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.select(&[\"dates\", \"n\"])\n", - " .expect(\"column does not exist\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A `DataFrame` can also be sliced in to a subset of the DataFrame." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-23 | 4 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let offset = 2;\n", - "let length = 2;\n", - "df.slice(offset, length)\n", - " .expect(\"slice was not within bounds\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Select a column by index:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: i32\n", - "[\n", - "\t1\n", - "\t2\n", - "\t3\n", - "\t4\n", - "\t5\n", - "]" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.select_at_idx(1)\n", - " .expect(\"column was not within bounds\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Boolean indexing\n", - "Boolean indexes can be used to filter data. Note that this also works on `Series` and `ChunkedArray`. We also use the `as_result!` macro. This utility expects a block that returns a `Result`. This makes it to convenient to use the `?` operator." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-23 | 4 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-22 | 5 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " // select the n column\n", - " let n_s = df.column(\"n\")?;\n", - " let mask = n_s.gt(2);\n", - "\n", - " // filter values > 2\n", - " df.filter(&mask)\n", - "}).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Filter all values in the \"n\" column greater than 2 and smaller than 5:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+\n", - "| dates | n | foos |\n", - "| --- | --- | --- |\n", - "| date32 | i32 | str |\n", - "+============+=====+=======+\n", - "| 2020-08-22 | 3 | \"foo\" |\n", - "+------------+-----+-------+\n", - "| 2020-08-23 | 4 | \"foo\" |\n", - "+------------+-----+-------+\n" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " // select the n column\n", - " let n_s = df.column(\"n\")?;\n", - " \n", - " // create the boolean mask\n", - " let mask = (n_s.gt(2) & n_s.lt(5))?;\n", - "\n", - " // filter values > 2\n", - " df.filter(&mask)\n", - "}).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For all the comparison methods available on `Series` and `ChunkArrays` check the [ChunkCompare trait](https://ritchie46.github.io/polars/polars/chunked_array/ops/trait.ChunkCompare.html).\n", - "\n", - "# Setting\n", - "Setting a new column can be done with the `hstack` operation. This is operation adds new columns to the existing `DataFrame`." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+-------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=======+=======+\n", - "| 2020-08-21 | 1 | \"foo\" | \"mo\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-21 | 2 | \"foo\" | \"tue\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 3 | \"foo\" | \"wed\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-23 | 4 | \"foo\" | \"thu\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 5 | \"foo\" | \"fri\" |\n", - "+------------+-----+-------+-------+\n" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let mut df = df;\n", - "let s = Series::new(\"days\", &[\"mo\", \"tue\", \"wed\", \"thu\", \"fri\"]);\n", - "df.hstack(&[s]).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It isn't possible to get mutable access to the columns of a `DataFrame`, because this would give you the possibility to invalidate the `DataFrame` (for instance by replacing the column with a `Series` with a different length).\n", - "\n", - "Luckely there are other ways to mutate a DataFrame. We could for instance replace a column in the `DataFrame`:" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+-------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=======+=======+\n", - "| 2020-08-21 | 1 | \"bar\" | \"mo\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-21 | 2 | \"bar\" | \"tue\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 3 | \"bar\" | \"wed\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-23 | 4 | \"bar\" | \"thu\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 5 | \"bar\" | \"fri\" |\n", - "+------------+-----+-------+-------+\n" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let s = Utf8Chunked::full(\"bars\", \"bar\", 5);\n", - "df.replace(\"foos\", s).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or if we want to use the column we're replacing to determine the new column's values we can use the `apply` method and use a closure to create the new column.\n", - "\n", - "Below we use this determine `n + 1`:" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+-------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=======+=======+\n", - "| 2020-08-21 | 2 | \"bar\" | \"mo\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-21 | 3 | \"bar\" | \"tue\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 4 | \"bar\" | \"wed\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-23 | 5 | \"bar\" | \"thu\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 6 | \"bar\" | \"fri\" |\n", - "+------------+-----+-------+-------+\n" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.apply(\"n\", |s| s + 1).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Both the `replace` and the `apply` methods exist for selection by index; \n", - "* `replace_at_idx`\n", - "* `apply_at_idx`" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+-------+-------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=======+=======+\n", - "| 2020-08-21 | 4 | \"bar\" | \"mo\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-21 | 6 | \"bar\" | \"tue\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 8 | \"bar\" | \"wed\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-23 | 10 | \"bar\" | \"thu\" |\n", - "+------------+-----+-------+-------+\n", - "| 2020-08-22 | 12 | \"bar\" | \"fri\" |\n", - "+------------+-----+-------+-------+\n" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.apply_at_idx(1, |s| s * 2)\n", - " .unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or we can apply a closure to the values that are valid under a condition constraint:" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+---------------------+-------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=====================+=======+\n", - "| 2020-08-21 | 4 | \"not_within_bounds\" | \"mo\" |\n", - "+------------+-----+---------------------+-------+\n", - "| 2020-08-21 | 6 | \"bar\" | \"tue\" |\n", - "+------------+-----+---------------------+-------+\n", - "| 2020-08-22 | 8 | \"bar\" | \"wed\" |\n", - "+------------+-----+---------------------+-------+\n", - "| 2020-08-23 | 10 | \"not_within_bounds\" | \"thu\" |\n", - "+------------+-----+---------------------+-------+\n", - "| 2020-08-22 | 12 | \"not_within_bounds\" | \"fri\" |\n", - "+------------+-----+---------------------+-------+\n" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " let mask = (df.column(\"n\")?.gt(4) & df.column(\"n\")?.lt(10))?;\n", - " \n", - " df.may_apply(\"foos\", |s| {\n", - " s.utf8()?\n", - " .set(&!mask, Some(\"not_within_bounds\"))\n", - " }\n", - " )\n", - " }\n", - ").unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Iterators\n", - "Every `ChunkedArray` implements the [IntoIterator trait](https://doc.rust-lang.org/std/iter/trait.IntoIterator.html) which gives us all the powerful trait methods available for iterators." - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Some(15)\n" - ] - }, - { - "data": { - "text/plain": [ - "Ok(())" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " let s = Series::new(\"a\", [1, 2, 3, 4, 5]);\n", - " \n", - " let v = s.i32()?\n", - " .into_iter()\n", - " .sum::>();\n", - " \n", - " println!(\"{:?}\", v);\n", - " \n", - " Ok(())\n", - "})" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+------------+-----+---------------------+----------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=====================+==========+\n", - "| 2020-08-21 | 4 | \"not_within_bounds\" | \"mo_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-21 | 6 | \"bar\" | \"tue_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 8 | \"bar\" | \"wed_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-23 | 10 | \"not_within_bounds\" | \"thu_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 12 | \"not_within_bounds\" | \"fri_ay\" |\n", - "+------------+-----+---------------------+----------+\n" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " // adds \"ay\" to every word.\n", - " fn to_pig_latin(opt_val: Option<&str>) -> Option {\n", - " opt_val.map(|val| format!(\"{}_ay\", val))\n", - " }\n", - " \n", - " // may apply takes a closure that may fail.\n", - " df.may_apply(\"days\", |s| {\n", - " let ca: Utf8Chunked = s.utf8()?\n", - " .into_iter()\n", - " .map(to_pig_latin)\n", - " .collect();\n", - " Ok(ca)\n", - " });\n", - " \n", - " Ok(df.clone())\n", - "}).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Concat\n", - "\n", - "Polars provides various facilities for easily combining `DataFrames` and `Series`.\n", - "\n", - "We can concatenate a `DataFrame` with `hstack`:" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+------------+-----+---------------------+----------+------------+-----+---------------------+----------+\n", - "| dates | n | foos | days | dates | n | foos | days |\n", - "| --- | --- | --- | --- | --- | --- | --- | --- |\n", - "| date32 | i32 | str | str | date32 | i32 | str | str |\n", - "+============+=====+=====================+==========+============+=====+=====================+==========+\n", - "| 2020-08-21 | 4 | \"not_within_bounds\" | \"mo_ay\" | 2020-08-21 | 4 | \"not_within_bounds\" | \"mo_ay\" |\n", - "+------------+-----+---------------------+----------+------------+-----+---------------------+----------+\n", - "| 2020-08-21 | 6 | \"bar\" | \"tue_ay\" | 2020-08-21 | 6 | \"bar\" | \"tue_ay\" |\n", - "+------------+-----+---------------------+----------+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 8 | \"bar\" | \"wed_ay\" | 2020-08-22 | 8 | \"bar\" | \"wed_ay\" |\n", - "+------------+-----+---------------------+----------+------------+-----+---------------------+----------+\n", - "| 2020-08-23 | 10 | \"not_within_bounds\" | \"thu_ay\" | 2020-08-23 | 10 | \"not_within_bounds\" | \"thu_ay\" |\n", - "+------------+-----+---------------------+----------+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 12 | \"not_within_bounds\" | \"fri_ay\" | 2020-08-22 | 12 | \"not_within_bounds\" | \"fri_ay\" |\n", - "+------------+-----+---------------------+----------+------------+-----+---------------------+----------+\n", - "\n" - ] - } - ], - "source": [ - "{\n", - " let mut df1 = df.clone(); \n", - " \n", - " df1.hstack(df.get_columns());\n", - " \n", - " println!(\"{:?}\", df1);\n", - "};" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or append the rows of a second DataFrame:" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+------------+-----+---------------------+----------+\n", - "| dates | n | foos | days |\n", - "| --- | --- | --- | --- |\n", - "| date32 | i32 | str | str |\n", - "+============+=====+=====================+==========+\n", - "| 2020-08-21 | 4 | \"not_within_bounds\" | \"mo_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-21 | 6 | \"bar\" | \"tue_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 8 | \"bar\" | \"wed_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-23 | 10 | \"not_within_bounds\" | \"thu_ay\" |\n", - "+------------+-----+---------------------+----------+\n" - ] - } - ], - "source": [ - "{\n", - " let mut df1 = df.clone(); \n", - " \n", - " df1.vstack(&df);\n", - " \n", - " println!(\"{:?}\", df1);\n", - "};" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Join\n", - "SQL-style joins. " - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "| 2020-08-22 | 12 | \"not_within_bounds\" | \"fri_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-21 | 4 | \"not_within_bounds\" | \"mo_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-21 | 6 | \"bar\" | \"tue_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 8 | \"bar\" | \"wed_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-23 | 10 | \"not_within_bounds\" | \"thu_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "| 2020-08-22 | 12 | \"not_within_bounds\" | \"fri_ay\" |\n", - "+------------+-----+---------------------+----------+\n", - "\n", - "+-------+------+\n", - "| key | lval |\n", - "| --- | --- |\n", - "| str | i32 |\n", - "+=======+======+\n", - "| \"foo\" | 1 |\n", - "+-------+------+\n", - "| \"foo\" | 2 |\n", - "+-------+------+\n", - "\n", - "+-------+------+\n", - "| key | rval |\n", - "| --- | --- |\n", - "| str | i32 |\n", - "+=======+======+\n", - "| \"foo\" | 4 |\n", - "+-------+------+\n", - "| \"foo\" | 5 |\n", - "+-------+------+\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "+-------+------+------+\n", - "| key | lval | rval |\n", - "| --- | --- | --- |\n", - "| str | i32 | i32 |\n", - "+=======+======+======+\n", - "| \"foo\" | 1 | 4 |\n", - "+-------+------+------+\n", - "| \"foo\" | 2 | 4 |\n", - "+-------+------+------+\n", - "| \"foo\" | 1 | 5 |\n", - "+-------+------+------+\n", - "| \"foo\" | 2 | 5 |\n", - "+-------+------+------+\n" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " let left = DataFrame::new(vec![\n", - " Series::new(\"key\", &[\"foo\", \"foo\"]),\n", - " Series::new(\"lval\", &[1, 2]),\n", - " ])?;\n", - " \n", - " let right = DataFrame::new(vec![\n", - " Series::new(\"key\", &[\"foo\", \"foo\"]),\n", - " Series::new(\"rval\", &[4, 5]),\n", - " ])?;\n", - " \n", - " println!(\"{:?}\", left);\n", - " println!(\"{:?}\", right);\n", - " \n", - " left.inner_join(&right, \"key\", \"key\")\n", - "}).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Another example that can be given is:" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+-------+------+\n", - "| key | lval |\n", - "| --- | --- |\n", - "| str | i32 |\n", - "+=======+======+\n", - "| \"foo\" | 1 |\n", - "+-------+------+\n", - "| \"bar\" | 2 |\n", - "+-------+------+\n", - "\n", - "+-------+------+\n", - "| key | rval |\n", - "| --- | --- |\n", - "| str | i32 |\n", - "+=======+======+\n", - "| \"foo\" | 4 |\n", - "+-------+------+\n", - "| \"bar\" | 5 |\n", - "+-------+------+\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "+-------+------+------+\n", - "| key | lval | rval |\n", - "| --- | --- | --- |\n", - "| str | i32 | i32 |\n", - "+=======+======+======+\n", - "| \"foo\" | 1 | 4 |\n", - "+-------+------+------+\n", - "| \"bar\" | 2 | 5 |\n", - "+-------+------+------+\n" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " \n", - " let left = DataFrame::new(vec![\n", - " Series::new(\"key\", &[\"foo\", \"bar\"]),\n", - " Series::new(\"lval\", &[1, 2]),\n", - " ])?;\n", - " \n", - " let right = DataFrame::new(vec![\n", - " Series::new(\"key\", &[\"foo\", \"bar\"]),\n", - " Series::new(\"rval\", &[4, 5]),\n", - " ])?;\n", - " \n", - " println!(\"{:?}\", left);\n", - " println!(\"{:?}\", right);\n", - " \n", - " left.inner_join(&right, \"key\", \"key\")\n", - "}).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Grouping\n", - "\n", - "By \"group by\" we are referring to a process involving one or more of the following steps:\n", - "* **Splitting** the data into groups based on some criteria\n", - "* **Applying** a function to each group independently\n", - "* **Combining** the results into a data structure" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+---------+-----+-----+\n", - "| A | B | C | D |\n", - "| --- | --- | --- | --- |\n", - "| str | str | i32 | i32 |\n", - "+=======+=========+=====+=====+\n", - "| \"foo\" | \"one\" | 1 | 1 |\n", - "+-------+---------+-----+-----+\n", - "| \"bar\" | \"one\" | 1 | 2 |\n", - "+-------+---------+-----+-----+\n", - "| \"foo\" | \"two\" | 1 | 3 |\n", - "+-------+---------+-----+-----+\n", - "| \"bar\" | \"three\" | 1 | 4 |\n", - "+-------+---------+-----+-----+\n", - "| \"foo\" | \"two\" | 1 | 5 |\n", - "+-------+---------+-----+-----+\n", - "| \"bar\" | \"two\" | 1 | 6 |\n", - "+-------+---------+-----+-----+\n", - "| \"foo\" | \"one\" | 1 | 7 |\n", - "+-------+---------+-----+-----+\n", - "| \"foo\" | \"three\" | 1 | 8 |\n", - "+-------+---------+-----+-----+\n" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let df = DataFrame::new(vec![\n", - " Series::new(\"A\", &[\"foo\", \"bar\", \"foo\", \"bar\",\n", - " \"foo\", \"bar\", \"foo\", \"foo\"]),\n", - " Series::new(\"B\", &[\"one\", \"one\", \"two\", \"three\",\n", - " \"two\", \"two\", \"one\", \"three\"]),\n", - " Int32Chunked::full(\"C\", 1, 8).into(),\n", - " Series::new(\"D\", &[1, 2, 3, 4,\n", - " 5, 6, 7, 8])\n", - "]).unwrap();\n", - "df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Grouping and then applying the `sum()` method to the resulting groups:" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-------+\n", - "| A | C_sum |\n", - "| --- | --- |\n", - "| str | i32 |\n", - "+=======+=======+\n", - "| \"foo\" | 5 |\n", - "+-------+-------+\n", - "| \"bar\" | 3 |\n", - "+-------+-------+\n" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " (&df).groupby(\"A\")?.select(\"C\").sum()\n", - "}).unwrap()" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+---------+-------+\n", - "| A | B | C_sum |\n", - "| --- | --- | --- |\n", - "| str | str | i32 |\n", - "+=======+=========+=======+\n", - "| \"bar\" | \"three\" | 1 |\n", - "+-------+---------+-------+\n", - "| \"foo\" | \"one\" | 2 |\n", - "+-------+---------+-------+\n", - "| \"bar\" | \"two\" | 1 |\n", - "+-------+---------+-------+\n", - "| \"foo\" | \"three\" | 1 |\n", - "+-------+---------+-------+\n", - "| \"bar\" | \"one\" | 1 |\n", - "+-------+---------+-------+\n", - "| \"foo\" | \"two\" | 2 |\n", - "+-------+---------+-------+\n" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " (&df).groupby(&[\"A\", \"B\"])?.select(\"C\").sum()\n", - "}).unwrap()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pivot tables\n", - "Pivots create a summary table by a applying a groupby and defining a pivot column and values to aggregate." - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-----+-------+-----+\n", - "| A | B | C | E |\n", - "| --- | --- | --- | --- |\n", - "| str | str | str | i32 |\n", - "+=========+=====+=======+=====+\n", - "| \"one\" | \"A\" | \"foo\" | 0 |\n", - "+---------+-----+-------+-----+\n", - "| \"one\" | \"B\" | \"foo\" | 1 |\n", - "+---------+-----+-------+-----+\n", - "| \"two\" | \"C\" | \"foo\" | 2 |\n", - "+---------+-----+-------+-----+\n", - "| \"three\" | \"A\" | \"bar\" | 3 |\n", - "+---------+-----+-------+-----+\n", - "| \"one\" | \"B\" | \"bar\" | 4 |\n", - "+---------+-----+-------+-----+\n", - "| \"one\" | \"C\" | \"bar\" | 5 |\n", - "+---------+-----+-------+-----+\n", - "| \"two\" | \"A\" | \"foo\" | 6 |\n", - "+---------+-----+-------+-----+\n", - "| \"three\" | \"B\" | \"foo\" | 7 |\n", - "+---------+-----+-------+-----+\n", - "| \"one\" | \"C\" | \"foo\" | 8 |\n", - "+---------+-----+-------+-----+\n", - "| \"one\" | \"A\" | \"bar\" | 9 |\n", - "+---------+-----+-------+-----+\n" - ] - }, - "execution_count": 58, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "let s0 = Series::new(\"A\", &[\"one\", \"one\", \"two\", \"three\",\n", - " \"one\", \"one\", \"two\", \"three\",\n", - " \"one\", \"one\", \"two\", \"three\"\n", - "]);\n", - "let s1 = Series::new(\"B\", &[\"A\", \"B\", \"C\",\n", - " \"A\", \"B\", \"C\",\n", - " \"A\", \"B\", \"C\",\n", - " \"A\", \"B\", \"C\",\n", - "]);\n", - "let s2 = Series::new(\"C\", &[\"foo\", \"foo\", \"foo\", \"bar\", \"bar\", \"bar\",\n", - " \"foo\", \"foo\", \"foo\", \"bar\", \"bar\", \"bar\"\n", - "]);\n", - "let s3 = Series::new(\"E\", &((0..12).collect::>()));\n", - "\n", - "let df = DataFrame::new(vec![s0, s1, s2, s3]).unwrap();\n", - "df" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-----+-----+\n", - "| A | foo | bar |\n", - "| --- | --- | --- |\n", - "| str | i32 | i32 |\n", - "+=========+=====+=====+\n", - "| \"one\" | 9 | 18 |\n", - "+---------+-----+-----+\n", - "| \"three\" | 7 | 14 |\n", - "+---------+-----+-----+\n", - "| \"two\" | 8 | 10 |\n", - "+---------+-----+-----+\n" - ] - }, - "execution_count": 61, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " (&df).groupby(&[\"A\"])?.pivot(\"C\", \"E\").sum()\n", - "}).unwrap()" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-----+------+------+\n", - "| A | B | foo | bar |\n", - "| --- | --- | --- | --- |\n", - "| str | str | i32 | i32 |\n", - "+=========+=====+======+======+\n", - "| \"three\" | \"A\" | null | 3 |\n", - "+---------+-----+------+------+\n", - "| \"one\" | \"C\" | 8 | 5 |\n", - "+---------+-----+------+------+\n", - "| \"two\" | \"A\" | 6 | null |\n", - "+---------+-----+------+------+\n", - "| \"one\" | \"A\" | 0 | 9 |\n", - "+---------+-----+------+------+\n", - "| \"two\" | \"C\" | 2 | null |\n", - "+---------+-----+------+------+\n", - "| \"three\" | \"B\" | 7 | null |\n", - "+---------+-----+------+------+\n", - "| \"one\" | \"B\" | 1 | 4 |\n", - "+---------+-----+------+------+\n", - "| \"two\" | \"B\" | null | 10 |\n", - "+---------+-----+------+------+\n", - "| \"three\" | \"C\" | null | 11 |\n", - "+---------+-----+------+------+\n" - ] - }, - "execution_count": 59, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "as_result!({\n", - " (&df).groupby(&[\"A\", \"B\"])?.pivot(\"C\", \"E\").sum()\n", - "}).unwrap()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/10_minutes_to_pypolars.ipynb b/examples/10_minutes_to_pypolars.ipynb deleted file mode 100644 index 527ad324cff7..000000000000 --- a/examples/10_minutes_to_pypolars.ipynb +++ /dev/null @@ -1,1652 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 10 minutes to pypolars\n", - "This as short introduction to Polars to get you started with the basic concepts of data wrangling. It is very much influenced by [10 minutes to pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/10min.html).\n", - "\n", - "Pypolars are the python bindings to Polars. It currently supports only a subset of the datatypes and operations supported by Polars. \n", - "However it should be enough to give your slow pipelines a boost." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import pypolars as pl\n", - "import numpy as np\n", - "np.random.seed(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Object creation\n", - "\n", - "Creating a `Series` by passing a list or array of values." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'a' [i64]\n", - "[\n", - "\t1\n", - "\t2\n", - "\t3\n", - "]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pl.Series(\"a\", [1, 2, 3])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A `Series` can also have nullable values." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'with nullable values' [i64]\n", - "[\n", - "\t1\n", - "\tnull\n", - "\t3\n", - "]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s = pl.Series(\"with nullable values\", [1, None, 3], nullable=True)\n", - "s" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Series have a data type and can be casted" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "i64\n" - ] - }, - { - "data": { - "text/plain": [ - "Series: 'with nullable values' [f32]\n", - "[\n", - "\t1\n", - "\tnull\n", - "\t3\n", - "]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(s.dtype)\n", - "s.cast_f32()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "A `DataFrame` can be created by passing a dictionary with keys as column names and \n", - "list values." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.417 | 0 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.72 | 1 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0e0 | 2 | \"h\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pl.DataFrame({\n", - " \"foo\": np.random.rand(10),\n", - " \"bar\": np.arange(10),\n", - " \"ham\": [\"h\"] * 3 + [\"a\"] * 3 + [\"m\"] * 4\n", - "})\n", - "df.head(3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The columns of the result `DataFrame` have different types and names" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['f64', 'i64', 'str']\n", - "['foo', 'bar', 'ham']\n" - ] - } - ], - "source": [ - "print(df.dtypes)\n", - "print(df.columns)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Viewing data\n", - "\n", - "We can view the top and bottom rows of a `DataFrame`" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.417 | 0 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.72 | 1 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0e0 | 2 | \"h\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.head(3)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.346 | 7 | \"m\" |\n", - "+-------+-----+-----+\n", - "| 0.397 | 8 | \"m\" |\n", - "+-------+-----+-----+\n", - "| 0.539 | 9 | \"m\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.tail(3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can sort by column." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.72 | 1 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.539 | 9 | \"m\" |\n", - "+-------+-----+-----+\n", - "| 0.417 | 0 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.397 | 8 | \"m\" |\n", - "+-------+-----+-----+\n", - "| 0.346 | 7 | \"m\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.sort(\"foo\", reverse=True).head(5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Selection\n", - "We can select a single column, which returns a Series." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'foo' [f64]\n", - "[\n", - "\t0.417\n", - "\t0.72\n", - "\t0e0\n", - "]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[\"foo\"].head(3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or select a column by index" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'foo' [f64]\n", - "[\n", - "\t0.417\n", - "\t0.72\n", - "\t0e0\n", - "\t0.302\n", - "\t0.147\n", - "\t0.092\n", - "\t0.186\n", - "\t0.346\n", - "\t0.397\n", - "\t0.539\n", - "]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When we select in two dimensions, we select by row, column order.\n", - "Here we slice until the third row of the first column." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'foo' [f64]\n", - "[\n", - "\t0.417\n", - "\t0.72\n", - "\t0e0\n", - "]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[:3, 0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or we can slice the whole `DataFrame` into a smaller sub `DataFrame`" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.417 | 0 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.72 | 1 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0e0 | 2 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.302 | 3 | \"a\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[:4]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or we slice both rows and columns" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-----+-----+\n", - "| bar | ham |\n", - "| --- | --- |\n", - "| i64 | str |\n", - "+=====+=====+\n", - "| 3 | \"a\" |\n", - "+-----+-----+\n", - "| 4 | \"a\" |\n", - "+-----+-----+" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[3:5,1:]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Boolean indexing\n", - "Boolean indexes can be used to filter data." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.72 | 1 | \"h\" |\n", - "+-------+-----+-----+\n", - "| 0.539 | 9 | \"m\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[df[\"foo\"] > 0.5]" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+\n", - "| foo | bar | ham |\n", - "| --- | --- | --- |\n", - "| f64 | i64 | str |\n", - "+=======+=====+=====+\n", - "| 0.302 | 3 | \"a\" |\n", - "+-------+-----+-----+\n", - "| 0.147 | 4 | \"a\" |\n", - "+-------+-----+-----+\n", - "| 0.092 | 5 | \"a\" |\n", - "+-------+-----+-----+" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[df[\"ham\"] == \"a\"]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Setting\n", - "Adding a new column to the `DataFrame` can be done with `hstack`." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+-----+-----+-----+\n", - "| foo | bar | ham | new |\n", - "| --- | --- | --- | --- |\n", - "| f64 | i64 | str | i64 |\n", - "+=======+=====+=====+=====+\n", - "| 0.417 | 0 | \"h\" | 0 |\n", - "+-------+-----+-----+-----+\n", - "| 0.72 | 1 | \"h\" | 1 |\n", - "+-------+-----+-----+-----+\n", - "| 0e0 | 2 | \"h\" | 2 |\n", - "+-------+-----+-----+-----+\n", - "| 0.302 | 3 | \"a\" | 3 |\n", - "+-------+-----+-----+-----+\n", - "| 0.147 | 4 | \"a\" | 4 |\n", - "+-------+-----+-----+-----+\n", - "| 0.092 | 5 | \"a\" | 5 |\n", - "+-------+-----+-----+-----+\n", - "| 0.186 | 6 | \"m\" | 6 |\n", - "+-------+-----+-----+-----+\n", - "| 0.346 | 7 | \"m\" | 7 |\n", - "+-------+-----+-----+-----+\n", - "| 0.397 | 8 | \"m\" | 8 |\n", - "+-------+-----+-----+-----+\n", - "| 0.539 | 9 | \"m\" | 9 |\n", - "+-------+-----+-----+-----+" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[\"new\"] = np.arange(10)\n", - "df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also define the column location by index" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-----+-----+-----+\n", - "| new_foo | bar | ham | new |\n", - "| --- | --- | --- | --- |\n", - "| f64 | i64 | str | i64 |\n", - "+=========+=====+=====+=====+\n", - "| 0.419 | 0 | \"h\" | 0 |\n", - "+---------+-----+-----+-----+\n", - "| 0.685 | 1 | \"h\" | 1 |\n", - "+---------+-----+-----+-----+\n", - "| 0.204 | 2 | \"h\" | 2 |\n", - "+---------+-----+-----+-----+" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df[0] = pl.Series(\"new_foo\", np.random.rand(10))\n", - "df.head(3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or use a boolean mask to assign new values.\n", - "\n", - "_Note that every mutable assignment alocates new memory. So isn't actually mutable with regard to the actual memory. This is a performance trade off. Due to the immutable memory, slices, clones, subsets of `Series`/`DataFrames` are zero copy. If you need to mutate a lot of values, it's faster to do this in numpy allocate a new `Series`._" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "# selection order is row column\n", - "df[df[\"new_foo\"] > 0.5, \"new_foo\"] = 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "we can also define the mutation location by passing an array of indexes" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'ham' [str]\n", - "[\n", - "\t\"h\"\n", - "\t\"c\"\n", - "\t\"c\"\n", - "\t\"c\"\n", - "\t\"a\"\n", - "\t\"a\"\n", - "\t\"m\"\n", - "\t\"m\"\n", - "\t\"m\"\n", - "\t\"m\"\n", - "]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s = df[\"ham\"]\n", - "s[[1, 2, 3]] = \"c\"\n", - "s" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Concat\n", - "Polars provide methods to cobine multiple `DataFrames` and `Series`.\n", - "We can concatenate a `DataFrame` with `hstack`." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-----+-----+-----+---------+-----+-----+-----+\n", - "| new_foo | bar | ham | new | new_foo | bar | ham | new |\n", - "| --- | --- | --- | --- | --- | --- | --- | --- |\n", - "| f64 | i64 | str | i64 | f64 | i64 | str | i64 |\n", - "+=========+=====+=====+=====+=========+=====+=====+=====+\n", - "| 0.419 | 0 | \"h\" | 0 | 0.419 | 0 | \"h\" | 0 |\n", - "+---------+-----+-----+-----+---------+-----+-----+-----+\n", - "| 1 | 1 | \"h\" | 1 | 1 | 1 | \"h\" | 1 |\n", - "+---------+-----+-----+-----+---------+-----+-----+-----+\n", - "| 0.204 | 2 | \"h\" | 2 | 0.204 | 2 | \"h\" | 2 |\n", - "+---------+-----+-----+-----+---------+-----+-----+-----+\n", - "| 1 | 3 | \"a\" | 3 | 1 | 3 | \"a\" | 3 |\n", - "+---------+-----+-----+-----+---------+-----+-----+-----+\n", - "| 0.027 | 4 | \"a\" | 4 | 0.027 | 4 | \"a\" | 4 |\n", - "+---------+-----+-----+-----+---------+-----+-----+-----+" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# clones are super cheap!\n", - "df1 = df.clone()\n", - "df1.hstack(df.get_columns())\n", - "df1.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Or append rows from another `DataFrame`." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 20\n" - ] - } - ], - "source": [ - "df1 = df.clone()\n", - "df1.vstack(df)\n", - "print(df.height, df1.height)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Joins\n", - "SQL-styel joins." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+------+------+\n", - "| key | lval | rval |\n", - "| --- | --- | --- |\n", - "| str | i64 | i64 |\n", - "+=======+======+======+\n", - "| \"foo\" | 1 | 4 |\n", - "+-------+------+------+\n", - "| \"foo\" | 2 | 4 |\n", - "+-------+------+------+\n", - "| \"foo\" | 1 | 5 |\n", - "+-------+------+------+\n", - "| \"foo\" | 2 | 5 |\n", - "+-------+------+------+" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "left = pl.DataFrame({'key': ['foo', 'foo'], 'lval': [1, 2]})\n", - "right = pl.DataFrame({'key': ['foo', 'foo'], 'rval': [4, 5]})\n", - "\n", - "left.join(right, left_on=\"key\", right_on=\"key\", how=\"inner\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Another example that can be given is:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+------+------+\n", - "| key | lval | rval |\n", - "| --- | --- | --- |\n", - "| str | i64 | i64 |\n", - "+=======+======+======+\n", - "| \"foo\" | 1 | 4 |\n", - "+-------+------+------+\n", - "| \"bar\" | 2 | 5 |\n", - "+-------+------+------+" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "left = pl.DataFrame({'key': ['foo', 'bar'], 'lval': [1, 2]})\n", - "right = pl.DataFrame({'key': ['foo', 'bar'], 'rval': [4, 5]})\n", - "\n", - "left.join(right, left_on=\"key\", right_on=\"key\", how=\"inner\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Grouping\n", - "\n", - "By \"group by\" we are referring to a process involving one or more of the following steps:\n", - "* **Splitting** the data into groups based on some criteria\n", - "* **Applying** a function to each group independently\n", - "* **Combining** the results into a data structure" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+---------+----------+----------+\n", - "| A | B | C | D |\n", - "| --- | --- | --- | --- |\n", - "| str | str | f64 | f64 |\n", - "+=======+=========+==========+==========+\n", - "| \"foo\" | \"one\" | 1.134 | 0.902 |\n", - "+-------+---------+----------+----------+\n", - "| \"bar\" | \"one\" | -1.1e0 | 0.502 |\n", - "+-------+---------+----------+----------+\n", - "| \"foo\" | \"two\" | -1.72e-1 | 0.901 |\n", - "+-------+---------+----------+----------+\n", - "| \"bar\" | \"three\" | -8.78e-1 | -6.84e-1 |\n", - "+-------+---------+----------+----------+\n", - "| \"foo\" | \"two\" | 0.042 | -1.23e-1 |\n", - "+-------+---------+----------+----------+\n", - "| \"bar\" | \"two\" | 0.583 | -9.36e-1 |\n", - "+-------+---------+----------+----------+\n", - "| \"foo\" | \"one\" | -1.101e0 | -2.68e-1 |\n", - "+-------+---------+----------+----------+\n", - "| \"foo\" | \"three\" | 1.145 | 0.53 |\n", - "+-------+---------+----------+----------+" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pl.DataFrame({'A': ['foo', 'bar', 'foo', 'bar',\n", - " 'foo', 'bar', 'foo', 'foo'],\n", - " 'B': ['one', 'one', 'two', 'three',\n", - " 'two', 'two', 'one', 'three'],\n", - " 'C': np.random.randn(8),\n", - " 'D': np.random.randn(8)})\n", - "df" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+----------+----------+\n", - "| A | C_sum | D_sum |\n", - "| --- | --- | --- |\n", - "| str | f64 | f64 |\n", - "+=======+==========+==========+\n", - "| \"foo\" | 1.048 | 1.942 |\n", - "+-------+----------+----------+\n", - "| \"bar\" | -1.395e0 | -1.117e0 |\n", - "+-------+----------+----------+" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.groupby(\"A\").select([\"C\", \"D\"]).sum()" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+-------+---------+---------+---------+----------+----------+\n", - "| A | B | A_first | B_first | C_first | D_first |\n", - "| --- | --- | --- | --- | --- | --- |\n", - "| str | str | str | str | f64 | f64 |\n", - "+=======+=========+=========+=========+==========+==========+\n", - "| \"bar\" | \"three\" | \"bar\" | \"three\" | -8.78e-1 | -6.84e-1 |\n", - "+-------+---------+---------+---------+----------+----------+\n", - "| \"foo\" | \"one\" | \"foo\" | \"one\" | 1.134 | 0.902 |\n", - "+-------+---------+---------+---------+----------+----------+\n", - "| \"bar\" | \"two\" | \"bar\" | \"two\" | 0.583 | -9.36e-1 |\n", - "+-------+---------+---------+---------+----------+----------+\n", - "| \"foo\" | \"three\" | \"foo\" | \"three\" | 1.145 | 0.53 |\n", - "+-------+---------+---------+---------+----------+----------+\n", - "| \"bar\" | \"one\" | \"bar\" | \"one\" | -1.1e0 | 0.502 |\n", - "+-------+---------+---------+---------+----------+----------+\n", - "| \"foo\" | \"two\" | \"foo\" | \"two\" | -1.72e-1 | 0.901 |\n", - "+-------+---------+---------+---------+----------+----------+" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.groupby([\"A\", \"B\"]).select_all().first()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pivot tables\n", - "\n", - "Pivots create a summary table by a applying a groupby and defining a pivot column and values to aggregate." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "df = pl.DataFrame({'A': ['one', 'one', 'two', 'three'] * 3,\n", - " 'B': ['A', 'B', 'C'] * 4,\n", - " 'C': ['foo', 'foo', 'foo', 'bar', 'bar', 'bar'] * 2,\n", - " 'D': np.random.randn(12),\n", - " 'E': np.random.randn(12)})" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-------+----------+\n", - "| A | foo | bar |\n", - "| --- | --- | --- |\n", - "| str | f64 | f64 |\n", - "+=========+=======+==========+\n", - "| \"one\" | 1.245 | 1.939 |\n", - "+---------+-------+----------+\n", - "| \"three\" | 0.617 | -9.86e-1 |\n", - "+---------+-------+----------+\n", - "| \"two\" | 0.171 | -1.143e0 |\n", - "+---------+-------+----------+" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df.groupby(\"A\").pivot(pivot_column=\"C\", values_column=\"E\").sum()" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "+---------+-----+----------+----------+\n", - "| A | B | foo | bar |\n", - "| --- | --- | --- | --- |\n", - "| str | str | f64 | f64 |\n", - "+=========+=====+==========+==========+\n", - "| \"three\" | \"A\" | null | -6.37e-1 |\n", - "+---------+-----+----------+----------+\n", - "| \"one\" | \"C\" | 0.3 | 2.1 |\n", - "+---------+-----+----------+----------+\n", - "| \"two\" | \"A\" | 0.12 | null |\n", - "+---------+-----+----------+----------+\n", - "| \"one\" | \"A\" | -7.47e-1 | -3.52e-1 |\n", - "+---------+-----+----------+----------+\n", - "| \"two\" | \"C\" | 0.051 | null |\n", - "+---------+-----+----------+----------+\n", - "| \"three\" | \"B\" | 0.617 | null |\n", - "+---------+-----+----------+----------+\n", - "| \"one\" | \"B\" | 1.692 | 0.191 |\n", - "+---------+-----+----------+----------+\n", - "| \"two\" | \"B\" | null | -1.143e0 |\n", - "+---------+-----+----------+----------+\n", - "| \"three\" | \"C\" | null | -3.49e-1 |\n", - "+---------+-----+----------+----------+" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pivotted = df.groupby([\"A\", \"B\"]).pivot(pivot_column=\"C\", values_column=\"E\").sum()\n", - "pivotted" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Numpy interops\n", - "\n", - "Polars has zero cost interaction with numpy's [ufunc](https://numpy.org/doc/stable/reference/ufuncs.html) functionality. This means that if a function/ method isn't supported by Polars, we can use numpy's without any overhead. Numpy will write the output to Polars/ arrow memory, and the null bitmask will keep null information." - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'bar' [f64]\n", - "[\n", - "\t0.804\n", - "\t-5.05e-1\n", - "\tnull\n", - "\t0.939\n", - "\tnull\n", - "\tnull\n", - "\t0.982\n", - "\t0.415\n", - "\t0.94\n", - "]" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s = pivotted[\"bar\"]\n", - "np.cos(s)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Series: 'bar' [f64]\n", - "[\n", - "\t0.529\n", - "\t8.168\n", - "\tnull\n", - "\t0.703\n", - "\tnull\n", - "\tnull\n", - "\t1.21\n", - "\t0.319\n", - "\t0.705\n", - "]" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.exp(s)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Clones\n", - "It was already mentioned previously, but clones are super cheap. This is due to the fact that the underlying memory backed by Polars is immutable. Slices and clones can be made with the guarantee that they will never be modified.\n", - "\n", - "Below we observe that cloning an array of 1e6 elements is almost 850x faster. The cost of clone a Polars Series is also constant and doesn't increase with memory size. Ideal for writing pure funcitons. The cost of cloning a DataFrame 10x the size is also very small." - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "a = np.arange(int(1e6))" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "504 µs ± 20.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "np.copy(a)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "s = pl.Series(\"a\", a)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "637 ns ± 22.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "s.clone()" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1000000, 10)\n" - ] - }, - { - "data": { - "text/plain": [ - "+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+\n", - "| a_0 | a_1 | a_2 | a_3 | a_4 | a_5 | a_6 | a_7 | a_8 | a_9 |\n", - "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n", - "| i64 | i64 | i64 | i64 | i64 | i64 | i64 | i64 | i64 | i64 |\n", - "+=====+=====+=====+=====+=====+=====+=====+=====+=====+=====+\n", - "| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |\n", - "+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+\n", - "| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |\n", - "+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+\n", - "| 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |\n", - "+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = pl.DataFrame({f\"a_{i}\": s.clone() for i in range(10)})\n", - "print(df.shape)\n", - "df.head(3)" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.52 µs ± 34.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "df.clone()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Performance\n", - "Let's check some performances with Pandas. There is probably quite some FFI overhead that can be reduced. I expect bett" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:2: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", - " \n" - ] - } - ], - "source": [ - "import pandas as pd\n", - "from pandas.util.testing import rands\n", - "import time\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [], - "source": [ - "def create_join_dfs(N = 10_000):\n", - " left_pivot = int(0.8 * N)\n", - " right_pivot = N - left_pivot\n", - "\n", - " indices = np.array([rands(10) for _ in range(N)], dtype=\"O\")\n", - " indices2 = np.array([rands(10) for _ in range(N)], dtype=\"O\")\n", - " key = np.tile(indices[:left_pivot], 10)\n", - " key2 = np.tile(indices2[:left_pivot], 10)\n", - "\n", - " left = pd.DataFrame({\"key\": key, \"key2\": key2, \"value\": np.random.randn(len(key))})\n", - " right = pd.DataFrame(\n", - " {\"key\": indices[right_pivot:], \"key2\": indices2[right_pivot:], \"value2\": np.random.randn(left_pivot)}\n", - " )\n", - " return left, right\n", - " \n", - "left, right = create_join_dfs()" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.0001862" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def time_lambda(f: \"fn() -> ()\") -> float:\n", - " \"\"\"\n", - " eval time in ms\n", - " \"\"\"\n", - " t0 = time.time_ns() \n", - " for _ in range(10):\n", - " f()\n", - " return (time.time_ns() - t0) / 10 / 1e6\n", - "\n", - "time_lambda(lambda : 1)\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10\n", - "\tpandas\n", - "\tpolars\n", - "20\n", - "\tpandas\n", - "\tpolars\n", - "30\n", - "\tpandas\n", - "\tpolars\n", - "40\n", - "\tpandas\n", - "\tpolars\n", - "50\n", - "\tpandas\n", - "\tpolars\n", - "70\n", - "\tpandas\n", - "\tpolars\n", - "100\n", - "\tpandas\n", - "\tpolars\n", - "200\n", - "\tpandas\n", - "\tpolars\n", - "500\n", - "\tpandas\n", - "\tpolars\n", - "1000\n", - "\tpandas\n", - "\tpolars\n", - "2000\n", - "\tpandas\n", - "\tpolars\n" - ] - } - ], - "source": [ - "left_pl = []\n", - "left_pd = []\n", - "inner_pd = []\n", - "inner_pl = []\n", - "outer_pd = []\n", - "outer_pl = []\n", - "par_inner = []\n", - "par_left = []\n", - "n_proxy = []\n", - "\n", - "for N in [10, 20, 30, 40, 50, 70, 100, 200, 500, 1000, 2000]:\n", - " print(N)\n", - " N *= 1000\n", - " left, right = create_join_dfs(N)\n", - " f_left_pd = lambda: left.merge(right, on=\"key\", how=\"left\")\n", - " f_inner_pd = lambda: left.merge(right, on=\"key\", how=\"inner\")\n", - " f_outer_pd = lambda: left.merge(right, on=\"key\", how=\"outer\")\n", - " pd_left_t = time_lambda(f_left_pd)\n", - " pd_inner_t = time_lambda(f_inner_pd)\n", - " pd_outer_t = time_lambda(f_outer_pd)\n", - " \n", - " # create polars dfs\n", - " left = pl.DataFrame(left.to_dict(orient=\"list\"))\n", - " right = pl.DataFrame(right.to_dict(orient=\"list\"))\n", - " f_left_pl = lambda: left.join(right, left_on=\"key\", right_on=\"key\", how=\"left\")\n", - " f_inner_pl = lambda: left.join(right, left_on=\"key\", right_on=\"key\", how=\"inner\")\n", - " f_outer_pl = lambda: left.join(right, left_on=\"key\", right_on=\"key\", how=\"outer\")\n", - " pl_left_t = time_lambda(f_left_pl)\n", - " pl_inner_t = time_lambda(f_inner_pl)\n", - " pl_outer_t = time_lambda(f_outer_pl)\n", - " \n", - " f_left_par = lambda: left.join(right, left_on=\"key\", right_on=\"key\", how=\"left\", parallel=True)\n", - " f_inner_par = lambda: left.join(right, left_on=\"key\", right_on=\"key\", how=\"inner\", parallel=True)\n", - " par_left_t = time_lambda(f_left_par)\n", - " par_inner_t = time_lambda(f_inner_par)\n", - " \n", - " # pandas\n", - " print(\"\\tpandas\")\n", - " left_pd.append(pd_left_t)\n", - " inner_pd.append(pd_inner_t)\n", - " outer_pd.append(pd_outer_t)\n", - " \n", - " # polars\n", - " print(\"\\tpolars\")\n", - " left_pl.append(pl_left_t)\n", - " inner_pl.append(pl_inner_t)\n", - " outer_pl.append(pl_outer_t)\n", - " \n", - " # parallel polars\n", - " par_left.append(par_left_t)\n", - " par_inner.append(par_inner_t)\n", - " n_proxy.append(N)\n", - " del left\n", - " del right\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "metadata": {}, - "outputs": [], - "source": [ - "def make_fig(how):\n", - " plt.figure(figsize=(24, 6))\n", - " if how == \"inner\":\n", - " plt.plot(n_proxy, left_pd, label=\"pandas left\")\n", - " plt.plot(n_proxy, left_pl, label=\"polars left\")\n", - " plt.plot(n_proxy, par_left, label=\"polars left parallel\")\n", - " elif how == \"outer\":\n", - " plt.plot(n_proxy, outer_pl, label=\"polars outer\")\n", - " plt.plot(n_proxy, outer_pd, label=\"pandas outer\")\n", - " else:\n", - " plt.plot(n_proxy, inner_pl, label=\"polars inner\")\n", - " plt.plot(n_proxy, inner_pd, label=\"pandas inner\")\n", - " plt.plot(n_proxy, par_inner, label=\"polars inner parallel\")\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABWoAAAFlCAYAAAByV9G0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzde3TV9Z3v/+c3F3LhmpBky9VwC9lBETSAFlAuidrWVs+0ndYlrda2Kqe22jnOaHumZ9rp8bd6zo817dS29ueattr+enO1xxl+ts6YDVK0ooKjg7o3CEoEFPYOl1y4JCQ7n98fybC0AqKiO4HnYy0XO5/9/X7z2kmI8OKT9zcKISBJkiRJkiRJyp28XAeQJEmSJEmSpDOdRa0kSZIkSZIk5ZhFrSRJkiRJkiTlmEWtJEmSJEmSJOWYRa0kSZIkSZIk5ZhFrSRJkiRJkiTlWEGuA7yVioqKUF1dnesYkiRJkiRJkvSuPP3003tCCJXHem7AF7XV1dVs2LAh1zEkSZIkSZIk6V2JouiV4z3n6ANJkiRJkiRJyjGLWkmSJEmSJEnKMYtaSZIkSZIkScqxAT+j9li6u7vZuXMnnZ2duY6i90lxcTHjx4+nsLAw11EkSZIkSZKkU25QFrU7d+5k+PDhVFdXE0VRruPoPRZCYO/evezcuZNJkyblOo4kSZIkSZJ0yg3K0QednZ2MHj3akvYMEUURo0ePdge1JEmSJEmSTluDsqgFLGnPMH6+JUmSJEmSdDobtEXt6WrRokVs2LDhXV3j3nvv5eabbz7hMV1dXTQ0NDBr1ix+85vf8N3vfpdDhw69q/crSZIkSZIk6Z0ZlDNq9e4988wzdHd38+yzzwJQXV3NsmXLKC0tzXEySZIkSZIk6czjjtp3oLm5mdraWq699lpmzpzJxz/+8aO7Uf/+7/+eOXPmcM4553DDDTcQQgD6dsrefvvtzJ07l5qaGh599FEADh8+zKc+9SlmzpzJJz/5SQ4fPnz0/Sxfvpz6+npmzJjB3/3d3x1dv+OOO6irq2PmzJncdtttJ8za0tLCxz72MebMmcOcOXP405/+RCaTYdmyZTz77LPMmjWLf/zHf+S1115j8eLFLF68+FR/uCRJkiRJkiS9hUG/o/ab/98LJF9rP6XXrBs7gr/7yIwTHrN582Z+/OMfM3/+fK6//np++MMfctttt3HzzTfzP/7H/wDg05/+NA8++CAf+chHAOjp6eGpp57iD3/4A9/85jdJJBLcfffdlJaWsnHjRjZu3Mj5559/9H3ceeedlJeXk81mWbp0KRs3bmT8+PE88MADbNq0iSiKaG1tPWHOW265ha985SssWLCA7du3c9lll5FKpfinf/onVqxYwYMPPgjAd77zHR555BEqKirezYdOkiRJkiRJ0jvgjtp3aMKECcyfPx+AZcuW8dhjjwHwyCOPMG/ePM4991xWr17NCy+8cPScv/iLvwDgggsuoLm5GYC1a9eybNkyAGbOnMnMmTOPHn///fdz/vnnM3v2bF544QWSySQjRoyguLiYz3/+8/yf//N/3nJUQSKR4Oabb2bWrFl89KMfpb29nY6OjlP2cZAkSZIkSZLejhACW9IdPPny3lxHGVAG/Y7at9r5+l6JouhNb3d2dvJf/+t/ZcOGDUyYMIFvfOMbdHZ2Hj2mqKgIgPz8fHp6eo57LYBt27axYsUK1q9fT1lZGddddx2dnZ0UFBTw1FNPsWrVKn7961/z/e9/n9WrVx83Z29vL+vWraOkpOTdvmRJkiRJkiTpHenJ9rLhlf0kkmkSqTTNew9Re9Zw/vXWi3MdbcBwR+07tH37dtatWwfAr371KxYsWHC0lK2oqODAgQP89re/fcvrXHzxxfziF78A4Pnnn2fjxo0AtLe3M3ToUEaOHEk6neahhx4C4MCBA7S1tfGhD32I7373u0dvBnY8l156Kd///vePvn2844cPH+5OW0mSJEmSJJ0yB7p6+MNzu/ir3zxL/Z0JPnXPE/xs3SucPXoo37rqHH762Tm5jjigDPodtbkSj8e57777uPHGG5k2bRrLly+ntLSUL3zhC5x77rlUV1czZ85bf7EtX76cz372s8ycOZNZs2Yxd+5cAM477zxmz57NjBkzmDx58tExCx0dHVx55ZV0dnYSQuA73/nOCa//ve99jy9+8YvMnDmTnp4eLr74Yn70ox+96bgbbriBD37wg4wZM4ZHHnnkHXxEJEmSJEmSdKbb1XaYRCpDUzLNEy/t5Ui2l1GlhSyZXkVDXYyF0yoYXlyY65gDUhRCyHWGE6qvrw8bNmx4w1oqlSIej+coETQ3N3PFFVfw/PPP5yzDmSjXn3dJkiRJkiS9UQiB5K52EskMTandPP9qOwBnjy6lMR6joS5G/dllFOT7g/0AURQ9HUKoP9Zz7qiVJEmSJEmSdNKO9PTy5La9NCXTJJJpXmvrJIpg9oRR/M3l02mMx5haNeyY92XS8VnUvgPV1dXuppUkSZIkSdIZo+1QN49sztCUSvPHzS0c6OqhuDCPBVMruaVhGktqY1QOL8p1zEHNolaSJEmSJEnSm2zfe4imVN+u2aea95HtDVQMG8KHzx1DY12M+VMrKBmSn+uYpw2LWkmSJEmSJEn09gb+Y2criVSaRDLD5nQHANOqhnHjxZNpqIsxa/wo8vIcafBesKiVJEmSJEmSzlCd3Vn+tHVPXzmbytDS0UV+XsSc6jL+9sNxGuIxqiuG5jrmGcGiVpIkSZIkSTqD7DnQxepNGZqSaR7d0kJndy9Dh+SzaHoVDXVVLKqpomzokFzHPONY1L5PFi1axIoVK6ivrz9l11yzZg0rVqzgwQcfPOFxV199NS+88AKf/exnKSsr49JLL2Xs2LGnLIckSZIkSZIGrhACL7UcJJFK05RM8+/b9xMCjBlZzCcumEBDXYwLJ5dTVOC82VyyqB2gstks+fnv/jfH7t27efzxx3nllVeAvsL4nHPOsaiVJEmSJEk6jfVke/n37a00JXeTSGXYtucgADPGjuDLS6bRWBdjxtgRRJHzZgcKi9p3oLm5mcsvv5x58+bxzDPPUFNTw89+9jNKS0tZtWoVt912Gz09PcyZM4e7776boqKiN5y/fPly1q9fz+HDh/n4xz/ON7/5TQCqq6u5/vrrefjhh7n55pvJZDL86Ec/oqCggLq6On79618fN9PBgwf50pe+xHPPPUdPTw/f+MY3uPLKK7n00kvJZDLMmjWLj33sY2zYsIFrrrmGkpIS1q1bR0lJyXv6sZIkSZIkSdL742BXD2tfbKEpleaRTRn2H+qmMD/iwsmj+ez8apbGY4wbZRc0UA3+ovahO2D3c6f2mmedCx/89gkP2bx5Mz/+8Y+ZP38+119/PT/84Q+5+eabue6661i1ahU1NTV85jOf4e677+bWW299w7l33nkn5eXlZLNZli5dysaNG5k5cyYAxcXFPPbYYwCMHTuWbdu2UVRURGtr6wnz3HnnnSxZsoSf/OQntLa2MnfuXBoaGli5ciVXXHEFzz77LACrVq065SMYJEmSJEmSlBu72zr7bwSW5vGtezmS7WVEcQFLaqtorDuLi2sqGF5cmOuYOgmDv6jNkQkTJjB//nwAli1bxve+9z0aGxuZNGkSNTU1AFx77bX84Ac/eFNRe//993PPPffQ09PDrl27SCaTR4vaT37yk0ePmzlzJtdccw1XXXUVV1111QnzPPzww6xcuZIVK1YA0NnZyfbt290xK0mSJEmSdBoJIZDa1XG0nN24sw2AieWlfPqis2mIx6ivLqMwPy/HSfV2Df6i9i12vr5X/nx+RxRFhBDe8rxt27axYsUK1q9fT1lZGddddx2dnZ1Hnx86dOjRx7///e9Zu3YtK1eu5Fvf+hYvvPACBQXH/pSFEPjd737H9OnT37De3Nz8Nl6VJEmSJEmSBpojPb08tW3f0ZuBvdp6GIDZE0fx15dNp7EuxrSqYc6bHeQGf1GbI9u3b2fdunVcdNFF/OpXv2LBggXU1tbS3NzM1q1bmTp1Kj//+c+55JJL3nBee3s7Q4cOZeTIkaTTaR566CEWLVr0puv39vayY8cOFi9ezIIFC/jlL3/JgQMHGDVq1DHzXHbZZdx1113cddddRFHEM888w+zZs9903PDhw+no6DglHwNJkiRJkiS9N9oOd7Nmc4ZEKsOazRk6OnsoKshj4bQKvrRkKkviVVQNL851TJ1CFrXvUDwe57777uPGG29k2rRpLF++nOLiYn7605/yiU984ujNxG666aY3nHfeeecxe/ZsZsyYweTJk4+OT/hz2WyWZcuW0dbWRgiBr3zlK8ctaQG+/vWvc+uttzJz5kxCCFRXV/Pggw++6bjrrruOm266yZuJSZIkSZIkDTA79h06OtLgyZf30dMbGD10CB885ywa4jEWTKugdIh13ukqOpkf18+l+vr6sGHDhjespVIp4vF4jhL1jRO44ooreP7553OW4UyU68+7JEmSJEnSqdTbG3ju1Taakn3l7KbdfT8FPbVqGA3xGI11VcyaUEZ+niMNThdRFD0dQqg/1nMnVcFHUTQK+CfgHCAA1wObgd8A1UAz8JchhP39x38V+ByQBb4cQvi3/vULgHuBEuAPwC1hoDfFkiRJkiRJ0inS2Z3l8Zf20JTMsCqVJtPRRV4E9dXl/PcPxWmoizGpYuhbX0innZPdK/2PwL+GED4eRdEQoBT4GrAqhPDtKIruAO4Abo+iqA74FDADGAskoiiqCSFkgbuBG4An6CtqLwceOqWv6H1QXV3tblpJkiRJkiSdlL0Huli9KUMilWbti3s43J2ldEg+l9RU0lgXY/H0KsqGDsl1TOXYWxa1URSNAC4GrgMIIRwBjkRRdCWwqP+w+4A1wO3AlcCvQwhdwLYoirYCc6MoagZGhBDW9V/3Z8BVDMKiVpIkSZIkSTqRl1oOkOgfafD0K/vpDXDWiGI+dsE4GuIxLpw8muLC/FzH1AByMjtqJwMtwE+jKDoPeBq4BYiFEHYBhBB2RVFU1X/8OPp2zP6nnf1r3f2P/3z9TaIouoG+nbdMnDjxpF+MJEmSJEmSlAvZ3sC/b99PIpmmKZXm5ZaDAMTHjODmJdNojMc4Z9wIosh5szq2kylqC4DzgS+FEJ6Mougf6RtzcDzH+moLJ1h/82II9wD3QN/NxE4ioyRJkiRJkvS+OtjVw6Nb9pBIpVm9KcO+g0coyIu4aMporr2omqXxKsaXleY6pgaJkylqdwI7QwhP9r/9W/qK2nQURWP6d9OOATKvO37C684fD7zWvz7+GOuSJEmSJEnSoJBu72RVqm/e7GNb93Ckp5cRxQUsrq2iIR7jkumVjCguzHVMDUJ5b3VACGE3sCOKoun9S0uBJLASuLZ/7VrgX/ofrwQ+FUVRURRFk4BpwFP9YxI6oii6MOrb4/2Z151z2lu0aBEbNmw4pddcs2YNV1xxxVsed/XVVzNz5ky+853vcO+99/Laa+9/Pz5s2DAAmpubOeecc0547MkcI0mSJEmS9H4IIbBpdzvfX72FK7//GPP+r1V87YHneDHdwTXzJvLLz8/j6a838o+fms1HzhtrSat37GR21AJ8CfhFFEVDgJeBz9JX8t4fRdHngO3AJwBCCC9EUXQ/fWVuD/DFEEK2/zrLgXuBEvpuIuaNxI4jm82Sn//uB0rv3r2bxx9/nFdeeQXoK4zPOeccxo4d+66v/Xo9PT0UFJzsl5MkSZIkSdLA1Z3tZf22fTzcfzOwnfsPA3DehFHcdmkNDXUxpseGO29Wp9RJNWshhGeB+mM8tfQ4x98J3HmM9Q3AoN8q2dzczOWXX868efN45plnqKmp4Wc/+xmlpaWsWrWK2267jZ6eHubMmcPdd99NUVHRG85fvnw569ev5/Dhw3z84x/nm9/8JgDV1dVcf/31PPzww9x8881kMhl+9KMfUVBQQF1dHb/+9a+Pm+ngwYN86Utf4rnnnqOnp4dvfOMbXHnllVx66aVkMhlmzZrFxz72MTZs2MA111xDSUkJ69ato6Sk5Og1Fi1axKxZs3jqqadob2/nJz/5CXPnzuWpp57i1ltv5fDhw5SUlPDTn/6U6dOnc++99/L73/+ezs5ODh48yMqVK7nyyivZv38/3d3d/M//+T+58sorj5s5m81yxx13sGbNGrq6uvjiF7/IjTfe+C4/O5IkSZIkSW9fe2c3aza3kEimeWRzho7OHoYU5LFgagVfXDyVpbVVVI0oznVMncYG/RbI//XU/2LTvk2n9Jq15bXcPvf2Ex6zefNmfvzjHzN//nyuv/56fvjDH3LzzTdz3XXXsWrVKmpqavjMZz7D3Xffza233vqGc++8807Ky8vJZrMsXbqUjRs3MnPmTACKi4t57LHHABg7dizbtm2jqKiI1tbWE+a58847WbJkCT/5yU9obW1l7ty5NDQ0sHLlSq644gqeffZZAFatWsWKFSuorz9W795X+D7++OOsXbuW66+/nueff57a2lrWrl1LQUEBiUSCr33ta/zud78DYN26dWzcuJHy8nJ6enp44IEHGDFiBHv27OHCCy/kox/96HH/denHP/4xI0eOZP369XR1dTF//nwuvfRS/zVKkiRJkiS9L3buP0QimSaRyvDEy3vp6Q2UDx3C5TPOoqEuxsJpFZQOGfT1mQYJv9LeoQkTJjB//nwAli1bxve+9z0aGxuZNGkSNTU1AFx77bX84Ac/eFNRe//993PPPffQ09PDrl27SCaTR4vaT37yk0ePmzlzJtdccw1XXXUVV1111QnzPPzww6xcuZIVK1YA0NnZyfbt29+wY/ZkXH311QBcfPHFtLe309raSkdHB9deey1btmwhiiK6u7uPHt/Y2Eh5eTnQN7Pla1/7GmvXriUvL49XX32VdDrNWWedddzMGzdu5Le//S0AbW1tbNmy5ejHT5IkSZIk6VQKIfDcq20kkmmaUhlSu9oBmFw5lM8tnERjPMbsiWXk57mJTO+/QV/UvtXO1/fKn+/6jKKIEMJbnrdt2zZWrFjB+vXrKSsr47rrrqOzs/Po80OHDj36+Pe//z1r165l5cqVfOtb3+KFF1447hzYEAK/+93vmD59+hvWm5ub38arOvbr+vrXv87ixYt54IEHaG5uZtGiRcfM+4tf/IKWlhaefvppCgsLqa6ufsNrO1bmu+66i8suu+xdZZYkSZIkSTqezu4s617e279zNk26vYu8COrPLudrH6plaTzGlMphuY4pkZfrAIPV9u3bWbduHQC/+tWvWLBgAbW1tTQ3N7N161YAfv7zn3PJJZe84bz29naGDh3KyJEjSafTPPTQse+n1tvby44dO1i8eDH/+3//b1pbWzlw4MBx81x22WXcddddR8viZ5555pjHDR8+nI6OjuNe5ze/+Q0Ajz32GCNHjmTkyJG0tbUxbtw4AO69997jntvW1kZVVRWFhYU88sgjR29gdqLMd99999Edui+++CIHDx484TmSJEmSJElvZd/BI/zu6Z3c9POnOf9bTXz2p+t54JlXmT2hjBWfOI/1/72B+2+6iBsunmJJqwFj0O+ozZV4PM59993HjTfeyLRp01i+fDnFxcX89Kc/5ROf+MTRm4nddNNNbzjvvPPOY/bs2cyYMYPJkycfHZ/w57LZLMuWLaOtrY0QAl/5ylcYNWrUcfN8/etf59Zbb2XmzJmEEKiurubBBx9803HXXXcdN9100zFvJgZQVlbGBz7wgaM3EwP4m7/5G6699lr+4R/+gSVLlhw3wzXXXMNHPvIR6uvrmTVrFrW1tcc9FuDzn/88zc3NnH/++YQQqKys5J//+Z9PeI4kSZIkSdKxbNtzsG+kQTLNhlf20RugangRV80eR2M8xkVTRlNcmJ/rmNJxRSfz4/q5VF9fHzZs2PCGtVQqRTwez1Givh/Nv+KKK3j++edzluG9sGjRohPeaCzXcv15lyRJkiRJA0e2N/Dsjv08nEyTSKZ5qaXvp3RrzxpOY12MhniMc8eNJM95sxpAoih6OoRwzPLNHbWSJEmSJEkaFA4d6eHRLXtIJNOs3pRh78EjFORFzJtczrILz6YhHmNCeWmuY0rviEXtO1BdXX3a7aYFWLNmTa4jSJIkSZIkvUGmo5NVqQyJZJrHtu6hq6eX4UUFLKqtorEuxiU1lYwsKcx1TOlds6iVJEmSJEnSgBFC4MX0ARKpvnmzz+5oBWDcqBKunjuRxroYc6rLGVKQl+Ok0qk1aIvaEAJR5IyRM8VAn6UsSZIkSZLeue5sL+ub95FIZkik0mzfdwiA88aP5L811tBQF6P2rOF2QTqtDcqitri4mL179zJ69Gh/g54BQgjs3buX4uLiXEeRJEmSJEmnSEdnN398sYVEMs0jm1toO9zNkII85k8ZzY2XTGZpbYyzRtoF6MwxKIva8ePHs3PnTlpaWnIdRe+T4uJixo8fn+sYkiRJkiTpXXi19TCr+kcaPPHyXrqzgbLSQhriMRrrqlg4rZKhRYOyrpLetUH5lV9YWMikSZNyHUOSJEmSJEknEELghdfaaUr2lbPJXe0ATKoYymfnT6IhHuOCs8vIz/MnpqVBWdRKkiRJkiRpYOrqyfLEy/toSu5mVSrDrrZOoggumFjGHR+spbEuxpTKYbmOKQ04FrWSJEmSJEl6V1oPHWH1pr4bgf1xcwsHj2QpKcxn4bQK/qqxhsW1VVQMK8p1TGlAs6iVJEmSJEnS29a85yCJ/nmzG17ZT7Y3UDm8iI/OGkdjXRUfmFJBcWF+rmNKg4ZFrSRJkiRJkt5Sb2/gmR2tJFJpEsk0WzIHAKg9azjLL5lCQ12MmeNGkue8WekdsaiVJEmSJEnSMR0+kuWxrXtIJNOs2pRmz4Ej5OdFzJtUztVzJ9IQjzFxdGmuY0qnBYtaSZIkSZIkHdXS0cXqTWmakhke29pCZ3cvw4sKuGR6JY11MRbVVDGytDDXMaXTjkWtJEmSJEnSGSyEwNbMAZr6580+u6OVEGDcqBI+WT+BhroY8yaNZkhBXq6jSqc1i1pJkiRJkqQzTE+2lw2v7KcpmSaRSvPK3kMAnDtuJLcuraGxLkZ8zHCiyHmz0vvFolaSJEmSJOkMcKCrh7UvttCUTLN6U4a2w90Myc/joimj+cLCySyNVzFmZEmuY0pnLItaSZIkSZKk09SutsMkkmmaUhmeeGkvR7K9jCotZGm8isZ4jIU1lQwrsh6SBgJ/J0qSJEmSJJ0mQgi88Fo7iVTfSIPnX20HoHp0Kdd+4Gwa4jEuOLuMgnznzUoDjUWtJEmSJEnSIHakp5cnXt7bV84m07zW1kkUwfkTy7j98loa66qYUjnMebPSAGdRK0mSJEmSNMi0HjrCms0tNKXS/HFzCwe6eiguzGPhtEpubahhcW0VlcOLch1T0ttgUStJkiRJkjQIbN97iKZUmqbkbtY37yfbG6gYVsQVM8fQEI+xYFoFxYX5uY4p6R2yqJUkSZIkSRqAensD/7GzlUQqTVMyzYvpAwDUxIZx48WTaaiLMWv8KPLyHGkgnQ4saiVJkiRJkgaIzu4sj23Z038zsAx7DnSRnxcxp7qMv/1wnMa6GGePHprrmJLeAxa1kiRJkiRJObTnQBerUxmaUmke3dJCZ3cvw4oKuGR6JY3xGIumVzKqdEiuY0p6j1nUSpIkSZIkvY9CCLzUcoCmZIZEKs2/b99PCDB2ZDF/WT+BhniMeZPLKSpw3qx0JrGolSRJkiRJeo/1ZHt5+pX9R0cabNtzEIBzxo3glqXTaIjHmDF2BFHkvFnpTGVRK0mSJEmS9B440NXDoy+20JRK88imDPsPdVOYH3HRlAqun1/N0niMsaNKch1T0gBhUStJkiRJknSK7G7r7N81m+bxrXs5ku1lZEkhS2qraIjHuLimguHFhbmOKWkAsqiVJEmSJEl6h0IIpHZ1kEilaUqmee7VNgAmlpfy6YvOpiEeY051GQX5eTlOKmmgs6iVJEmSJEl6G4709PLUtn00JXeTSGV4tfUwUQSzJoziry+bzqV1MaZWDXPerKS35aSK2iiKmoEOIAv0hBDqoygqB34DVAPNwF+GEPb3H/9V4HP9x385hPBv/esXAPcCJcAfgFtCCOHUvRxJkiRJkqRTr+1wN2s2Z2hKpvnj5hY6unooKshj4bQKvrx0Kotrq6gaXpzrmJIGsbezo3ZxCGHP696+A1gVQvh2FEV39L99exRFdcCngBnAWCARRVFNCCEL3A3cADxBX1F7OfDQKXgdkiRJkiRJp9SOfYdoSvbNm31q2z56egMVw4bwoXPH0FAXY8HUCkqG5Oc6pqTTxLsZfXAlsKj/8X3AGuD2/vVfhxC6gG1RFG0F5vbvyh0RQlgHEEXRz4CrsKiVJEmSJEkDQG9vYOOrbST6y9lNuzsAmFY1jC9cPJmGeIxZE0aRn+dIA0mn3skWtQF4OIqiAPw/IYR7gFgIYRdACGFXFEVV/ceOo2/H7H/a2b/W3f/4z9clSZIkSZJyorM7y+Mv7aEpmWFVKk2mo4u8COZUl/O3H46zNB5jUsXQXMeUdAY42aJ2fgjhtf4ytimKok0nOPZY/6wUTrD+5gtE0Q30jUhg4sSJJxlRkiRJkiTpre090MXqTRkSqTRrX9zD4e4sQ4fkc8n0ShriMRZPr6Js6JBcx5R0hjmpojaE8Fr/r5koih4A5gLpKIrG9O+mHQNk+g/fCUx43enjgdf618cfY/1Y7+8e4B6A+vp6bzYmSZIkSZLelZdaDpBIpmlKpnl6+35CgLNGFPOxC8bREI9x0ZTRFBU4b1ZS7rxlURtF0VAgL4TQ0f/4UuDvgZXAtcC3+3/9l/5TVgK/jKLoH+i7mdg04KkQQjaKoo4oii4EngQ+A9x1ql+QJEmSJElStjfw79v3990MLJnm5T0HAagbM4IvLZnGpXUxZowdQRQ5b1bSwHAyO2pjwAP937gKgF+GEP41iqL1wP1RFH0O2A58AiCE8EIURfcDSaAH+GIIIdt/reXAvUAJfTcR80ZikiRJkiTplDjY1cOjW/bQlEzzyOYM+w4eoTA/4sLJo7lufjVL4zHGjSrJdUxJOqYohIE9WaC+vvYpx98AACAASURBVD5s2LAh1zEkSZIkSdIAlG7vJJHq2zX7p5f2cqSnlxHFBSyuraKxLsbFNZWMKC7MdUxJAiCKoqdDCPXHeu5kbyYmSZIkSZKUcyEENu3uIJFMk0il+Y+dbQBMKC9h2byzaairYk51OYX5eTlOKklvj0WtJEmSJEka0LqzvTy1bV/fvNlUmp37DwMwa8Io/vqy6TTEY9TEhjlvVtKgZlErSZIkSZIGnLbD3fzxxRYS/fNmOzp7KCrIY8HUCr64eCpLa6uoGlGc65iSdMpY1EqSJEmSpAFhx75DrEqlaUqlefLlffT0BkYPHcLlM86ioS7GwmkVlA6xypB0evK7myRJkiRJyone3sDzr7WRSKZ5OJlm0+4OAKZUDuVzCydxaV2MWRPKyM9zpIGk059FrSRJkiRJet90dmdZ9/JempJpVqXSpNu7yIug/uxyvvahWhriMSZXDst1TEl631nUSpIkSZKk99S+g0dYvSlDIplm7ZYWDh3JUjokn4unVdJYF2NxbRXlQ4fkOqYk5ZRFrSRJkiRJOuVebjlAIpUmkcyw4ZV99AaIjSjiv8weR0NdjIsmj6a4MD/XMSVpwLColSRJkiRJ71q2N/DM9v00pdIkkmleajkIQHzMCG5ePJWGuhjnjB1JnvNmJemYLGolSZIkSdI7cuhID49u2UMimWb1pgx7Dx6hIC/iwsmj+fSFZ7M0HmNCeWmuY0rSoGBRK0mSJEmSTlqmvZNV/fNmH9u6h66eXoYXF7B4ehUNdTEWTa9kRHFhrmNK0qBjUStJkiRJko4rhMCL6b55sw8n0/zHjlYAxpeVcPXciTTWxZg7qZzC/LwcJ5Wkwc2iVpIkSZIkvUF3tpf1zftoSqZJpNLs2HcYgPPGj+S/NdbQOCPG9Nhwosh5s5J0qljUSpIkSZIk2ju7+ePmFhKpNI9sytDe2cOQgjwWTK1g+SVTWRqvIjaiONcxJem0ZVErSZIkSdIZ6tXWwyT6d80+8fJeurOB8qFDuHTGWTTEYyycVsHQIqsDSXo/+N1WkiRJkqQzRAiB519tpymVJpFMk9zVDsDkiqFcP38SDXUxzp9YRn6eIw0k6f1mUStJkiRJ0mmsqyfLupf2kkilSSQz7G7vJC+CC84u46sfrKWhLsaUymG5jilJZzyLWkmSJEmSTjP7Dx7hkc0ZEqk0f9zcwsEjWUoK87m4poL/Fq9hSW0Vo4cV5TqmJOl1LGolSZIkSToNNO85SCKVpimZZsMr+8n2BqqGF/HRWeNorKviA1MqKC7Mz3VMSdJxWNRKkiRJkjQIZXsDz+5opan/ZmBbMwcAqD1rOMsvmUJjXYxzx40kz3mzkjQoWNRKkiRJkjRIHD6S5bGte2hK7mb1pgx7DhwhPy9i3qRyrpk3kYZ4jAnlpbmOKUl6ByxqJUmSJEkawFo6uliV6ts1++iWPXT19DK8qIBFtVU0xKtYVFPFyNLCXMeUJL1LFrWSJEmSJA0gIQS2ZA4cHWnw7I5WQoBxo0q4em7frtm5k8oZUpCX66iSpFPIolaSJEmSpBzryfayvnk/if6ds6/sPQTAzPEj+UpDDQ3xGPExw4ki581K0unKolaSJEmSpBzo6Oxm7Yt7SKTSrN6Uoe1wN0Py8/jA1NF8YeFklsarGDOyJNcxJUnvE4taSZIkSZLeJ6+1HmZVKk1TKsO6l/bQnQ2MKi1kabyKxniMhTWVDCvyr+qSdCbyu78kSZIkSe+REAIvvNZOIpWmKZnmhdfaAageXcp1H6imse4szp84ioJ8581K0pnOolaSJEmSpFOoqyfLky/vO3ozsF1tnUQRnD+xjDs+WEtDPMaUyqHOm5UkvYFFrSRJkiRJ71LroSOs2dxCUzLNH19s4UBXD8WFeSycVslXGmtYUltFxbCiXMeUJA1gFrWSJEmSJL0Dr+w9eHTX7Prm/WR7A5XDi/jIeWNoiMeYP7WC4sL8XMeUJA0SFrWSJEmSJJ2E3t7AsztbSfSXsy+mDwAwPTacmy6ZTEM8xnnjR5GX50gDSdLbZ1ErSZIkSdJxdHZneWzLHhKpNIlUhj0HusjPi5hbXc7Xr5hIQ7yKs0cPzXVMSdJpwKJWkiRJkqTXaeno4pFNGZpSaR7d0kJndy/Digq4ZHoljfEYi6dXMbK0MNcxJUmnGYtaSZIkSdIZLYTASy0HaEpmaEru5pkdrYQAY0cW85f1E2isizFv0miGFOTlOqok6TRmUStJkiRJOuP0ZHt5+pX9R28G1rz3EADnjBvBLUun0VgXo27MCKLIebOSpPeHRa0kSZIk6YxwoKuHtS+2kEimWb05Q+uhbgrzIy6aUsHnFk5maW0VY0eV5DqmJOkMZVErSZIkSTpt7W7rpCmVJpFMs+6lvRzJ9jKqtJAl06toqIuxcFoFw4udNytJyr2TLmqjKMoHNgCvhhCuiKKoHPgNUA00A38ZQtjff+xXgc8BWeDLIYR/61+/ALgXKAH+ANwSQgin6sVIkiRJks5sIQSSu9pJJDMkUmmee7UNgLNHl/KZi86moS5G/dllFOQ7b1aSNLC8nR21twApYET/23cAq0II346i6I7+t2+PoqgO+BQwAxgLJKIoqgkhZIG7gRuAJ+grai8HHjolr0SSJEmSdEY60tPLk9v2kkimSaQyvNp6mCiC2RNG8TeXT6cxHmNq1TDnzUqSBrSTKmqjKBoPfBi4E/ir/uUrgUX9j+8D1gC396//OoTQBWyLomgrMDeKomZgRAhhXf81fwZchUWtJEmSJOltajvUzZoXMzQl0/xxcwsdXT0UF+axYGolX146lSW1MSqHF+U6piRJJ+1kd9R+F/gbYPjr1mIhhF0AIYRdURRV9a+Po2/H7H/a2b/W3f/4z9clSZIkSXpLO/YdoimZJpFK89S2ffT0BiqGDeFD546hoS7GgqkVlAzJz3VMSZLekbcsaqMougLIhBCejqJo0Ulc81g/SxJOsH6s93kDfSMSmDhx4km8S0mSJEnS6aa3N7Dx1TYSyTRNyTSb0x0ATKsaxhcunkxjXYxZ40eRl+dIA0nS4HcyO2rnAx+NouhDQDEwIoqi/xdIR1E0pn837Rgg03/8TmDC684fD7zWvz7+GOtvEkK4B7gHoL6+3puNSZIkSdIZorM7y+Mv7enfOZuhpaOL/LyIOdVl/O2H4zTEY1RXDM11TEmSTrm3LGpDCF8FvgrQv6P2thDCsiiK/m/gWuDb/b/+S/8pK4FfRlH0D/TdTGwa8FQIIRtFUUcURRcCTwKfAe46xa9HkiRJkjTI7D3QxapNGRLJNI9u2cPh7ixDh+SzaHoVDXVVLKqpomzokFzHlCTpPXWyM2qP5dvA/VEUfQ7YDnwCIITwQhRF9wNJoAf4Yggh23/OcuBeoIS+m4h5IzFJkiRJOgO91HKgb9dsMs3T2/cTAowZWczHLxhPQ12MCyeXU1TgvFlJ0pkjCmFgTxaor68PGzZsyHUMSZIkSdK7kO0NPP3KfhKpvnL25T0HAZgxdgQN8RiNdTFmjB1BFDlvVpJ0+oqi6OkQQv2xnns3O2olSZIkSTqug109PLqlhaZkhtWb0uw/1E1hfsSFk0dz3fxqlsZjjBtVkuuYkiQNCBa1kiRJkqRTJt3eeXTX7J9e2suRnl5GFBewpLaKhroYl9RUMry4MNcxJUkacCxqJUmSJEnvWAiBTbs7SCTTNKXSbNzZBsCE8hKWzTubxroY9dVlFObn5TipJEkDm0WtJEmSJOlt6c728tS2fTQl0zQl07zaehiAWRNG8deXTaexLsa0qmHOm5Uk6W2wqJUkSZIkvaW2w9388cUWmpJp1mzO0NHZQ1FBHgunVfClJVNZEq+ianhxrmNKkjRoWdRKkiRJko5px75DffNmU2mefHkfPb2B0UOH8MFzzqIhHmPBtApKh/jXSkmSTgX/jypJkiRJAqC3N/Dcq20kUn0jDTbt7gBgatUwPr9wMo11VcyaUEZ+niMNJEk61SxqJUmSJOkM1tmdZd1Le2lKpVmVSpNu7yIvgvrqcv77h+I01MWYVDE01zElSTrtWdRKkiRJ0hlm38EjrN6UIZFMs3ZLC4eOZCkdks8lNZU0xGMsqa2ibOiQXMeUJOmMYlErSZIkSWeAl1sOHB1p8PQr++kNEBtRxH+ZPY7GuhgXTh5NcWF+rmNKknTGsqiVJEmSpNNQtjfwzPb9NPWXsy+3HAQgPmYENy+ZRmM8xjnjRhBFzpuVJGkgsKiVJEmSpNPEoSM9PLplD03JNKs3Zdh38AgFeREXTh7NtRdVszRexfiy0lzHlCRJx2BRK0mSJEmDWKa9k0QqQyKV5rGtezjS08uI4gIW11bREI9xyfRKRhQX5jqmJEl6Cxa1kiRJkjSIhBDYnO4gkUzTlMrwHztaARhfVsI18ybSGI8xZ1I5hfl5OU4qSZLeDotaSZIkSRrgurO9rN+2j6ZUmkQqzY59hwE4b8Iobru0hoa6GNNjw503K0nSIGZRK0mSJEkDUHtnN3/c3EIileaRTRnaO3sYUpDHgqkVLL9kKg3xKqpGFOc6piRJOkUsaiVJkiRpgNi5/xCr+ufNPvHyXrqzgfKhQ7h0xlk01sVYOK2C0iH+NU6SpNOR/4eXJEmSpBwJIfD8q+00pdI0JdOkdrUDMLlyKNfPn0RjXYzZE8vIz3OkgSRJpzuLWkmSJEl6H3X1ZFn30l6akmlWpTLsbu8kL4ILzi7jax+qZWk8xpTKYbmOKUmS3mcWtZIkSZL0Htt/8AiPbM7QlEyz9sUWDh7JUjokn4unVdJQF2Px9EpGDyvKdUxJkpRDFrWSJEmS9B5o3nOQpmSaplSaDc376A1QNbyIK2ePozEe46IpoykuzM91TEmSNEBY1EqSJEnSKZDtDTy7Yz9Nyb6bgW3NHACg9qzhfHHxVBriMc4dN5I8581KkqRjsKiVJEmSpHfo8JEsj25pIZFKs3pThj0HjlCQFzFvcjnXzJtIQzzGhPLSXMeUJEmDgEWtJEmSJL0NmY5OVqf6ds0+umUPXT29DC8qYFFtFQ3xKhZNr2JkSWGuY0qSpEHGolaSJEmSTiCEwJbMAZqSaRKpNM/uaCUEGDeqhKvnTqSxLsac6nKGFOTlOqokSRrELGolSZIk6c/0ZHtZ37z/aDm7fd8hAGaOH8lfNdTQUBej9qzhRJHzZiVJ0qlhUStJkiRJQEdnN2tf3ENTcjePbG6h7XA3QwrymD9lNDdeMpmltTHOGlmc65iSJOk0ZVErSZIk6Yz1WuthEqk0Tck0T7y8l+5soKy0kIZ4jMa6KhZOq2RokX9tkiRJ7z3/xCFJkiTpjBFC4IXX2o+ONHjhtXYAJlUM5bPzJ9EQj3HB2WXk5znSQJIkvb8saiVJkiSd1rp6sjzx8j4S/eXsrrZOoggumFjGHR+spbEuxpTKYbmOKUmSznAWtZIkSZJOO62HjvDI5gyJZIY/vtjCga4eSgrzWTitgq801rCktoqKYUW5jilJknSURa0kSZKk08Irew/SlOybN7vhlf1kewOVw4v4yHljaKyL8YEpFRQX5uc6piRJ0jFZ1EqSJEkalHp7A8/ubCXRX85uyRwAYHpsOMsvmUJDXYyZ40aS57xZSZI0CFjUSpIkSRo0Dh/J8qete2hKplm1KcOeA13k50XMm1TO1XMn0hCPMXF0aa5jSpIkvW0WtZIkSZIGtJaOLlZvStOUzPDY1hY6u3sZXlTAJdMraayLsaimipGlhbmOKUmS9K5Y1EqSJEkaUEIIbM0coCmVJpFM88yOVkKAcaNK+GT9BBrqYsybNJohBXm5jipJknTKWNRKkiRJyrmebC8bXtlPIpkmkUrTvPcQAOeOG8mtS2toqKuibswIosh5s5Ik6fT0lkVtFEXFwFqgqP/434YQ/i6KonLgN0A10Az8ZQhhf/85XwU+B2SBL4cQ/q1//QLgXqAE+ANwSwghnNqXJEmSJGkwONDVw9oXW0gk06zenKH1UDdD8vO4aMpoPrdwMg3xKsaMLMl1TEmSpPfFyeyo7QKWhBAORFFUCDwWRdFDwF8Aq0II346i6A7gDuD2KIrqgE8BM4CxQCKKopoQQha4G7gBeIK+ovZy4KFT/qokSZIkDUi72g6TSGVIJNOse2kvR7K9jCotZMn0KhrrYiysqWRYkT/4J0mSzjxv+Seg/h2vB/rfLOz/LwBXAov61+8D1gC396//OoTQBWyLomgrMDeKomZgRAhhHUAURT8DrsKiVpIkSTpthRBI7monkczQlNrN86+2A1A9upRrP3A2DfEYF5xdRkG+82YlSdKZ7aT+qTqKonzgaWAq8IMQwpNRFMVCCLsAQgi7oiiq6j98HH07Zv/Tzv617v7Hf74uSZIk6TRypKeXJ7ft7Z83m+HV1sNEEcyeMIrbL6+lsa6KKZXDnDcrSZL0OidV1PaPLZgVRdEo4IEois45weHH+tNWOMH6my8QRTfQNyKBiRMnnkxESZIkSTnUdqibNS9meDiZZu3mFjq6eiguzGPhtEpuWTqNxbVVVA4vynVMSZKkAettDX8KIbRGUbSGvtmy6SiKxvTvph0DZPoP2wlMeN1p44HX+tfHH2P9WO/nHuAegPr6em82JkmSJA1A2/ceoimVJpFM81TzPrK9gYphRXx45hga4jEWTKuguDA/1zElSZIGhbcsaqMoqgS6+0vaEqAB+F/ASuBa4Nv9v/5L/ykrgV9GUfQP9N1MbBrwVAghG0VRRxRFFwJPAp8B7jrVL0iSJEnSe6O3N7Dx1TaakrtJJDNsTncAUBMbxo0XT6ahLsas8aPIy3OkgSRJ0tt1MjtqxwD39c+pzQPuDyE8GEXROuD+KIo+B2wHPgEQQnghiqL7gSTQA3yxf3QCwHLgXqCEvpuIeSMxSZIkaQDr7M7yp617SKT65s22dHSRnxcxp7qMv/1wnMa6GGePHprrmJIkSYNeFMLAnixQX18fNmzYkOsYkiRJ0hljz4EuVm/KkEimeXTLHg53ZxlWVMAlNZU01sVYNL2SUaVDch1TkiRp0Imi6OkQQv2xnntbM2olSZIknX5CCLzUcrBv12wyzdPb9xMCjBlZzMcvGE9jXYx5k8spKnDerCRJ0nvFolaSJEk6A/Vke/n37a1982ZTGbbtOQjAjLEjuGXpNBriMWaMHUEUOW9WkiTp/WBRK0mSJJ0hDnb18OiWFh5OpnlkU4b9h7opzI+4aEoF18+vZmk8xthRJbmOKUmSdEayqJUkSZJOY7vbOvtvBJbm8a17OZLtZWRJIUtqq2iIx7i4poLhxYW5jilJknTGs6iVJEmSTiMhBFK7Oo6Wsxt3tgEwsbyUT190Ng3xGHOqyyjIz8txUkmSJL2eRa0kSZI0yHVne3ny5X0kUmmakmlebT1MFMGsCaP468umc2ldjKlVw5w3K0mSNIBZ1EqSJEmDUNvhbtZszpBIZVizOUNHZw9FBXksnFbBl5dOZXFtFVXDi3MdU5IkSSfJolaSJEkaJHbsO3R0pMGTL++jpzcweugQPnjO/8/enQZHct5ngn/euu/KOnCjCld3o092gWySoilRoklKlOyV5PFYI69nJY8VVqyPGDv2CMmzG7vz0bMb44jZ2Q07tLsOWRH2SNqdmRh/sMNDyTOj3VjJNqVukhJbvLqbDXTjaAB131X57ofMysqsymo0uwFUAXh+ERWVlajMriKz0cCDB/93Ei+dn8SHTyXh9ziH/TKJiIiI6CEwqCUiIiIiGlGqKvHGnbwx0uCnG0UAwKnxEH79uUW8eG4CmZQCp4MjDYiIiIiOOga1REREREQjpNZs4/vv7eCV65v47vVNbBbqcAjgyfk4/vufO4cXzk1gIRkc9sskIiIion3GoJaIiIiIaMh2yw389U+38MqbG/h/3tlGpdFG0OPER5fH8OK5CTy/PI5Y0DPsl0lERET06FoNoLiu3SCA9NPDfkUjg0EtEREREdEQ3LhXwitvavNmf/h+FqoEJiM+/L3HZ/DiuQk8s5SA18V5s0RERHRESAlUs1oAW1gHinft7yvb3WNmngB+/a+H95pHDINaIiIiIqJD0FYlfnQ7i++8uYlXrm/ixr0yAOD8VAS//bOn8fHzE7gwHYEQnDdLREREI6bVAEobAwLYdaBwFyhuAK1q/7GBJBCZAsLTWjAbnu4+VtKH/15GGINaIiIiIqIDUmm08L23t/Gd65v4659uYbfcgNsp8KHFBL74zDxePD+BGcU/7JdJREREJ5WUQC03IHg13Zfv9R/r9JoC2MeB8BQQmbbehycBl/fw39cRxaCWiIiIiGgfbRZq+O71LXzn+ib+33e30WipiPhceP7sOF46P4Hnzowh4nMP+2USERHRcddpwRY3+oNXczBr24JNdJuv0yv9AWxkGvDHAP4m0L5iUEtERERE9AiklHhrs6iNNHhzE6+t5QEAqbgf//DpObx4fhxPzsfhdjqG/EqJiIjoWHjgFuw2AGk91unVWq6RaWAqAyx/Sg9ep0wjCabYgh0SBrVERERERB9Qs63i727u4pXr2mJgq7taEyWTUvDffmIZL56bwJmJEOfNEhER0QfTbmoNWLvgtRPMFjeAZqX/WH+823idzlhnwXbuA3G2YEcYg1oiIiIiogdQqDXxn966h1fe3MR/eGsLxVoLXpcDHz6VxG9+7BReODuO8Yhv2C+TiIiIRpGUQC0/ePxA5758D/0tWE933MBUBlie1lqxllmwU4CbX4ccdQxqiYiIiIgGWMtW8J03N/Gd61v4wY0dtFSJRNCDly9M4sXzE/jI6SQCHn5JTUREdKK1m0Bp0z54NQeze7Vgpy6zBXvC8atKIiIiIiKdlBI/vlPAK29u4JXrW7i+XgAALI0F8aWPLOClcxNYScfgdPCbJSIiomNPSqBe2DuALW3BvgU7qQWtU48BZ17uzn9lC5YGYFBLRERERCdavdXG//feDr7z5ia+e30LG4UaHAK4MhfHP/nUWbx4bgKLY6Fhv0wiIiLaT+2W1oK1nQVrGk3QLPcf6491G6+Tl7rBq/k+kGALlj4wBrVEREREdKJIKbGer+H77+3gO9c38b2376HcaCPgceK502N46fwEnj87jnjQM+yXSkRERB/UoBZscaNnFuwWIFXrsQ63HrROARMXgdMf7wlgOy1Y/3DeGx17DGqJiIiI6Fgr11t4fS2Pa6s5XFvN4urtHLaKdQDARMSLz67M4MXzE3hmMQGf2znkV0tEREQDPUoL1qd0A9eJCwNmwSYAh+Pw3xeRjkEtERERER0bbVXina0irt3O6cFsDm9vFqHqY+PmEwH8zFICmZSCJ+biuDAdgYPzZomIiIavVhgcvD5KCzY8qW2zBUtHAINaIiIiIjqytgo1XNUD2Wu3c3h9LYdyow0AiPrduJxS8PELk1hJKbicUjjOgIiI6LC1W1rAarsgl2lhrkap/1i2YOmEYVBLREREREdCtdHGG3fyuLaaNYLZu/kaAMDlEDg/HcEvPjGLTEpBJqVgIRmE4CIeREREB6deHBy8dhqxpU2bFqxLb7tOARPngVMv9gSw+s0TGM77IhoSBrVERERENHJUVeLGdglX9REGV2/n8NZmEW19hsFszI/H52L4tZSClXQMF6YjnC9LRES0X9Q2UNqyD17NIwkaxf5jfdFu4Dp+vhu8GgtyTQOBJFuwRDYY1BIRERHR0G2X6pa5sq+t5VCstQAAYa8Ll1MKfuOjS8joIwzGwt4hv2IiIqIjyq4FW9ywBrClTUC2rcc5XEBoUgtex84CSz/bH8CyBUv0SBjUEhEREdGhqjXb+MndghHKXr2dxVq2CgBwOgSWJ8L4zy5PI5NS8HhawWIyxAW/iIiI9rJfLdixs93FucyzYINjbMESHTAGtURERER0YKSUuLldNkLZa6s5XF8voNnWRhhMR33IpBV84Zk5ZFIxXJqJwu/hCAMiIiKLemlA8LofLdhJwBMczvsiIgsGtURERES0b7LlBq6taQt9XV3N4bXVHPLVJgAg4HHisdkovvThRaykFaykFIxHfEN+xUREREOktoHyPfvg1dyMrRf6j/VGu/Nfl9iCJToOGNQSERER0UNptFS8uV7AtdtZoy17a6cCAHAI4MxEGJ+8OIlMSkEmreD0eBhOjjAgIqKTolHunwXbG8AWN/pbsMKptVzDU8DYGWDxY/0BbGSKLViiY4hBLRERERHtSUqJ1d0qrq52Q9mf3Cmg0VYBAONhLzIpBZ97MoVMSsFjswpCXn6pSUREx5Cqai1Yu+DV3Iit5/uP9Ua6zdfkR7uNWPMoguAY4OAYIKKTiF89ExEREVGffLWJ10xzZV9bzWGn3AAA+NwOPDaj4FefndfasikFU1EfhGBbloiIjji7Fmxxo2cW7AagtqzHCScQmtCC18QpYOE5m1mwU4A3NJz3RTRipJQoNUuotWoYC4wN++WMDAa1RERERCdcs63irY0irq52ZstmceNe2fj4qfEQnj87jkxKwUpawfJEGC4n590REdERsl8t2IWP9ASw+iiC0DhbsHTiNdUmcrUcdmu72KnuYKe2o23XdrBb3e1u13axW91FQ21gZXwF3/jkN4b90kcGg1oiIiKiE0RKiTu5qtaUva21Zd+4k0e9pY0wSIY8yKQU/L2VGWRSMTyWiiLicw/5VRMREd1Ho9ITuNoszGXbgnUAock9WrCTgDc8nPdFNGRSSlRaFexUu4FrZ7tzMz/O1XO253E73Ej4E4j74oj74jitnEbcH0fCl0A6nD7kdzXaGNQSERERHWOleguvr+a0tuxqDldv57BdqgMAPC4HLk5H8CtPzyGTVrCSUjAb83OEARERjQZVBSrb9sGruRlbs2nBesLd+a9swRIZWmoLuXrOEr7atV072/V23fY8EU8EcV8cCX8CS8oSnvQ9iYQ/gYQvYezvBLMhd4hfXz6gPYNaIUQKs7LwOAAAIABJREFUwDcATAJQAXxNSvkvhBBxAN8CMA/gFoDPSSmz+jG/B+BLANoA/rGU8q/0/U8A+DoAP4C/APA7Ukq5v2+JiIiI6GRqtVW8vVnS58pqi369s1VC56utxWQQz51OIpPW5sqenYzA4+IIAyIiGoJmdXDw2hlJUNwA1Kb1OOHQZsGGp4DEEjD/4W7war5nC5ZOkEqz0h0zYGq42m3n6jlI9EdxLodLC1h9CcT9cSwqi0bo2mm/doLXuC8Ot5O/cXUQHqRR2wLwX0spfySECAP4oRDiFQC/CuC7UsrfF0J8FcBXAXxFCHEewOcBXAAwDeA7QogzUso2gD8E8GUAP4AW1L4M4C/3+00RERERnQQb+RqurWZxVW/K/vhOHpVGGwCgBNzIpBR86tIUVtIxXJ6NQgl4hvyKiYjo2FNVoLJjE7z2zISt2fyKtCfUbbzOPWsfwAbHASd/OZiOt7ba1lqvpobroMbrbm0X1VbV9jxhd9hoti5EF3Bl8ooRtJobrwl/AmF3mK3XEbDnZzcp5TqAdX27KIS4DmAGwGcAfEx/2p8A+I8AvqLv/6aUsg7gphDiXQBPCSFuAYhIKb8PAEKIbwD4LBjUEhEREe2p0mjh9bW8ZbbsRqEGAHA7Bc5PR/G5KylkUlpbdi4R4BfbRES0v5pVUwN2wCzY4rp9CzY4rgWtsQVg7mdsZsFOAb7IcN4X0SGotqp7tl074Wu2lrVvvQqXpeE6H5nvb7yatj1O/pD+qPlAP4YSQswDWAHwNwAm9BAXUsp1IcS4/rQZaI3ZjjV9X1Pf7t1PRERERCaqKvHuvRKu3e7Oln1rowBV/3o9HQ/gqYU4MikFK2kF56cj8Lo4Y4+IiB7Sg7Rgi+tANdt/rKUF+zNswdKJoUoVuXqur+3aG7p2Hg9qvYbcIaPVOheZw8r4irXtqo8iSPgSiHgi/EH8MffAnymFECEA/xrA70opC/e5MOw+IO+z3+7P+jK0EQlIp7n6GxERER1vW8Wa0ZK9tprD62t5lOraytQRnwuXUwpeev4UMmkFl2cVJELeIb9iIiI6Mpq1/rEDD9KChdBmwbIFSydIrVXrb7vWdvoe71Z3ka1noUq17xxO4UTMFzNaranxlGW+q3nBrbg/Dq+TX9dR1wMFtUIIN7SQ9k+llP9G370phJjS27RTALb0/WsAUqbDZwHc1ffP2uzvI6X8GoCvAcCVK1e42BgREREdG7VmGz++o40wuKqHs3dyWsPC5RA4OxXGZ1emsZKKIZNWsJAIwuFgc4KIiHpIqbVg+4LXngW67Fqw7qDeeJ0C5p6xD2BDE2zB0pGnShWFeqHbbu0NXXu2K62K7XmC7qARtM6GZnF57LJt8JrwJRDxRuAQXKyVHs6en3WFVp39PwFcl1L+gelDfw7giwB+X7//d6b9fyaE+ANoi4mdBvC3Usq2EKIohPgQtNEJXwDwL/ftnRARERGNGFWVuLFd1puyWVxbzeGn60W09BkGM4ofmbSCf/TsPDIpBRdnovC5OcKAiOjEa9a64wZ6g9fOfXEDaDd6DhRAaFwLWmNzQPpD9qMIvBGAvz5NR1S9XbeMG7Bru3a2s7Us2rLddw6HcCDmjRkjBS4lL/WHrvr4gZgvBr/LP4R3SifRg/x47FkA/wWAN4QQ1/R9/wRaQPttIcSXANwG8EsAIKX8iRDi2wDeBNAC8FtSGn8rfgPA1wH4oS0ixoXEiIiI6NjYLTe0QFafLfvaag6FmjbCIOR14bHZKL783KK24FdawXjYN+RXTEREh0pKoLJ7/1mwhbtAdbf/WHeg23hNDQhgQxOA033474voEUgpUWgU+kLXQQtulZol2/P4XX4jYJ0KTeFi8qIlcDWPG4h6onA6+MNxGj1CytGeLHDlyhX56quvDvtlEBEREVnUW2385G7BMlv29q7263IOASxPRrTFvvRQdmksBCdHGBARHV+tuk3ztXckwQbQrvccKIDgmH3wapkFG2ULlo6MRrth23DtXXir87glW33nEBCI+WKWkNWyyJZpO+6LI+AODOGdEn1wQogfSimv2H2MA2eIiIiI9iClxPs7FSOQvbqaw5t382i2tR94T0Z8yKQU/MrTaWRSCi7NRhHw8MssIqJj4UFasMV1bV5sL0sL9un+MDY8qd3YgqURJ6VEsVm873xX83axWbQ9j8/pMwLWycAkzifOd4NXXwJxf3db8SpsvdKJw+8giIiIiHrkKg0jlL2mjzDIVrTVsAMeJy7NRPFrH17Q2rKpGCajHGFARHQk7UcLNpoCUk+xBUtHTrPdtM521Ruu5oW3zI9bqn3rVfEqRsP1bPzswOZrwpdg65VoDwxqiYiI6ERrtFT8dKOghbL6bNmb22UA2vfWp8dDeOn8BDKpGFbSCk6Ph+ByciVfIqKRJiVQzfYErjYLc9m1YF3+buM19ZQ1eO3cswVLI0hKiVKzZNtwNcJX0/5Co2B7Ho/DYyyqNRYYw3J82TLf1TyGQPEqcDkYLRHtF/5tIiIiohNDSom1bBVX9VD22moWP75bQKOlAgDGwl5kUgr+/hOzWNFHGIR9/EaciGiktBrdcQODAtjiBtCq9R8bHNOC1ugMMHvFFMCaRhL4FLZgaWQ01SZytVxfu7X3cacN21AbtueJeqNGwHomdqav8WoOXwOuAAT/DhANBYNaIiIiOrYKtSZeX83j2mrWGGOwXdK+gfG6HLg0E8UXPjSHlXQMmbSC6aiP35gQEQ1LpwXbN4qg576y3X+sy9dtvM5csV+YKzQJuDyH/76ITKSUKDfLtm3XviZsbQf5et72PG6H2zJS4LRy2tp21ee9JnwJKD4Fbgd/8Ex0FDCoJSIiomOh1Vbx042iZbbse/dKkNp6X1gaC+KjZ8aRSStYSSlYngzDzREGRESHo9UAShuDF+PqNGLv14KNzGghbG8DNjwF+GNswdLQtNQWcvWcNXTtabuaA9h638xjTcQTMVqtS8oSnvQ92dd27bRgQ+4Qf7hMdAwxqCUiIqIjR0qJ9XzNCGSv3s7ijTt51JraCIN40INMSsGnL09jJa3gsVkFUT+bJERE++5BWrDFDaB8r/9YSwv2CbZgaWRIKVFtVbFT2+mb9WpZfKuqbefqOUjIvvO4HC5Lu3VJWbKd9dq5uTn3mOjEY1BLREREI69Ub+H1tZyx4Ne11Ry2ilobxeN04MJMBL/8VBqZlIKVVAypuJ8tEyKiR/VALdgNoFXtPzaQ7AauM0/0B7BswdIha6ttZOtZ29DVbvGtWtum3Q0g7AkbAeuisogrviuWWa/m7bA7zK9HiOgDYVBLREREI6WtSryzVTQC2WurOby9WYSqF1XmEwH8zFJCmyubUnBuKgKPiyMMiIgemJRALTd4/EDn3q4F6/SaAtjHu41Y8314EnB5D/990YlTaVb2bLt2xg9ka1n71qtwWRqu85F529C1s+1xsuFNRAeHQS0REREN1WahhqtGKJvFG2t5lBttAEDU70YmpeATFyaRSSvIzCqIBfkNEhHRQO2m1nK1C17NwaxtCzbRbbxOr/QHsJFptmDpQLXVNvKN/MD5rp0AtvOxqt11DCDkDhkB61xkDivjK9bg1bTQVsQTYeuViEYGg1oiIiI6NNVGG2/cyePaalafLZvDel771UK3U+DcVAS/+MSsNsIgHcN8IsBvnoiIAL0Fmx8cvHbuy/eA3tag06u1XCPTwFQGWP6U/YJcbMHSAai1apaGaydwtYSv+uJb2XoWqlT7zuEUTsR8MWPkQCqSssx3NS+4FffH4XXyWiaio4lBLRERER0IVZV4714JV1e7s2Xf2iyirc8wSMX9uDIfRyalIJNScGE6Ap/bOeRXTUQ0BO0mUNrcYxbsOtCs9B9rbsFOZWzGEEwBgThbsLRvVKkiX89b2q6dsNVu4a1Ky+a6BRB0B42gNRVK4fLY5b62a+dxxBuBQ3DMEREdfwxqiYiIaF9sl+qWubKvreZQrLcAAGGvC5dTCn7zY0vIpBRcTilIhth2IaJj7n4t2OJGd19pC/0tWE83cJ26DCx/sr8FG5oE3L6hvDU6XurtumWmq13b1TzrtS3bfedwCAdi3pgRsj429lhf27UzfiDmi8Hv8g/hnRIRjTYGtURERPSB1Zpt/ORu3jRbNoe1rDYnzukQODsZxqcz0/oIAwWLyRAcDra5iOgYabeA0sYeLdgNoFnuP9Yf7zZepx4zjR8w3bMFS49ASolCo2Bpuw5qvO7WdlFqlmzP43f5jYB1OjSNS8lLlsDVPG5A8SpsvRIRPSIGtURERHRfUkrc3C4bgezV2zlcXy+gpY8wmI76kEkr+MIzc1hJx3BxOgq/hyMMiOiIkhKoF/YeQzCwBTupBa1TjwFnXu7OfzWPImALlh5Co92wXVDLbuGtbC2Llmz1nUNAIOaLGSHrhcQF6yJbpu24L46AOzCEd0pEdHIxqCUiIiKLbLmhBbKmEQb5ahMAEPQ48disgl9/blFry6YUjEcYOBDREdFuabNgLaMIeu/XB7RgY93G6+Sl/lmwkWltXixbsPSAOq3XvdqunWC22Czansfn9BkB62RgEucT5/varp1txavA6eAPU4mIRhWDWiIiohOs3mrj+noR125njcbsrR1t0Q+HAM5MhPHJi5NYSSvIpGI4NR6CkyMMiGgU1QqDg9dOI7a8BfSuKO9wd2e/TlwETn+8J4DttGA5T5P21mw3rQFrT+O1tw3bUu1br4pXMRqu5+Ln+tqu5tEDbL0SER0fDGqJiIhOCCklbu9WjPEF11ZzePNuAY22FlpMRLzIpBT8gyfTyKQUPDYbRdDLLxWIaMjaLS1gtR1FYBpJ0LCZsWluwU5ctG/B+uOAg3M1yZ6UEqVmybbtaoSvpsfFhn3r1ePwGItqjQXGsBxftm28JvwJKF4FLgf//SUiOon42Z+IiOiYylebeM2YK5vFa2t57JYbAAC/24lLM1H86rPzWEkpyKQVTEXZFiOiQ9ZpwfY2Xy2zYDf3aMFeAE6/xBYsPbCm2kS2lrVtu9oFsk21aXueTus17otjOb5sme3aCWU72wFXAIJjMYiIaA8MaomIiI6BZlvFT9eLuLaaNWbL3rinzVgUAjg1FsILZ8eRSSvIpBQsT4ThcrJBRkQHRG1rAevDtGB9SjdwnTjfbcSa7wMJtmDJIKVEuVm2BKy9bVdzAzZfz9uex+1wW0YKnFZOW9uuvoTxccWnwO1wH/I7JSKi445BLRER0REjpcSdXFWbKauPMHjjTh71ltY4S4Y8yKQU/OLjs8ikFFyajSLi4zeTRLRP6sXBwet9W7AuLXztBLCnXuwJYPWbh/M2CWipLeTqucGhq2l7t7aLertue56IJ2K0Wk8pp2zbrp0WbMgdYuuViIiGikEtERHRiCvWmnh9LW+ZLbtd0r4h9bocuDgTxT/80BwyKa0tOxvz8xtNIvrg1DZQ2rIPXs0Lc9nN4PRFu4Hr+Plu8GqeBRtIsgV7gkkpUW1VrcGrvrCW3cJbuXrO9jwuh8sy23VJWbKd9dq5uZ38QSURER0dDGqJiIhGSKut4u3NktaWXc3i2moO72yVIKX28cVkEM+dTiKTVrCSiuHsVBhujjAgor3USwOCV1MAW9oEZNt6nLkFO34OWHoBCE/2L8rFFuyJ1FbbyNazti1Xu+1au2Z7nrAnbASsi8oirviuGKMG4v64ZTvsDvOHkUREdGwxqCUiIhqi9XzVGF9wdTWHN9byqDa1oCQWcCOTUvBzl6a12bKzCqIBNoOIyERtA+V79sGruRlbL/Qfa27Bjp1jC5YAAJVmZc+2ayeAzdaykJB953AJl6XhOh+Z7xszYN72OD1DeKdERESjh0EtERHRISnXW3jjTt6YLXt1NYvNgjbCwON04Nx0BP/gyRRW9AW/0nGuEE10onVasH3BqymALW7Yt2BDk3oAuwwsPd8fwIYnAU9wOO+LDlVbbSPfyO/Zdu0EsNVW1fY8IXfICFjnInNYGV+xBq964zXhSyDiifDfLyIioofAoJaIiOgAtFWJd7dKxviCq7dzeHuzCFUvHs0lAvjQYsKYK3t+OgKvyzncF01Eh+NRWrDeaLf5OrasB6/mBbmmgeAYW7DHXLVVtW24mue/dh7n6jmovQu7AXAKp2WWazqStjw2L7gV98fhdXqH8E6JiIhOFga1RERE+2CrWDNGGFxbzeH1tTxK9RYAIOJz4XJKwcfPT2AlHcPllIJ4kL/mSXQsNcoD2q97tGCFU2u5hqeAsTPA4sf6A9jIFFuwx5QqVeTr+b6Gq6X5ahpFUGlVbM8TdAeNoDUVSuHy2OW+tmsnfI14I3AIBvpERESjhEEtERHRB1RttPHju3lLMHsnp/2qqMshcG4qgl9YmdHasmkFC4kgHA7+CijRkaaqWgvWLng1N2Lr+f5jvZFu8zX5UftZsMExwMFW/XFSb9ctjVe7tqt51mu7N7wH4BAOxLwxI2SdHZvta7t2xg/EfDH4Xf4hvFMiIiLaLwxqiYiI7kNVJW5sl/VAVhtjcH29iLY+w2BG8SOTVvCPnp3HSlrBhekofG6GLURHSqPSE7jajCQobQBqy3qcuQWbPA0sfFR7bJkFOwV4Q8N5X7SvVKmiUC9Y5rkOnPta20G5WbY9j9/lNwLW6dA0LiUvWQJX87gBxauw9UpERHSCMKglIiIy2SnVjZZs51asaeFMyOvC5VQU/+VHF5FJxZBJKRgLc2Yf0ciSEqjs6u1X0814rDdja3u0YBeeYwv2mGq0G/ZjBmy2s7UsWrLVdw4BgZgvZoSsF5IX+tqu5lvAHRjCOyUiIqKjgEEtERGdWLVmG2+uFywjDG7vanP/HAJYnozg5x+bxkpawUpKwdJYiCMMiEZFu6W1XAvrQOGO3ny9oz++2x1N0K5bjxMOIDShBa2JJWDhI/0BLFuwR5aUEoVGwbbh2pnvaowfqO6i2Czansfn9Bkh61RwCheSF/rarp1txavAycCeiIiI9gGDWiIiOhGklLi1U9HGF+jB7JvrBTTb2giDqagPmZSCX3k6jUxKwaXZKAIe/jNJNBSWBbnMTVhTGFveAnpXsnf59LB1Bph9Ug9d9UW4IjPax0ITgJN/t4+SZrtpme06qPHaeU6rd0QFtNar4lWMluu5+Lm+xqt59ABbr0RERDQM/CqViIiOpVyl0TfCIFdpAgACHicuzUTxpQ8vIpNSsJJWMBHxDfkVE50AUgLVrKn5eqc7C9YcxNqNIvApWvAamQYmLna3w9PdbX8MEGy9jzopJYrNorXh2tN2NYevxYZ969Xj8BiLao0FxrAcX7ZtvCb8CSheBS4Hv/UhIiKi0cavVoiI6MhrtFRcXy9YQtmb29oiLkIAZ8bD+MT5SWTSCjIpBWcmwnByhAHR/mq3gNKmdeyAXRDbqvUcKLqjCOKLwPyHu61YcxPWw4bjKGuqTWRrWdu2q93CW021aXueTus17otjOb5s23ZN+BKI++MIuAIQDOaJiIjoGGFQS0RER4qUEmvZKq6u5vQRBln8+G4BjZb2K9BjYS8yKQV//4lZrKQVPDarIOTlP3dEj6RR6QlcbUYSlDb7RxE4vd2wdeaJbvPVPJIgNAE43cN5XzSQlBLlZtnSeN2p7djPeq3tIl+3aUEDcDvcRsia9CdxJnbG2nb1JYyPKz4FbgevBSIiIjq5+J0rERGNtEKtideMUFa77ZQbAACf24FLM1F88Zk5ZFIxZNIKpqM+NqyIHpQxiqBn/mtvE7aW6z/WF+2OHZg4322+moPYQJyjCEZIS211W6+1Hds5r+ZgtqE2bM8T8USMlusp5ZSxbR410GnChtwhfk4mIiIiekAMaomIaGS02ip+ulE0Atmrt7N4717Z+PjSWBDPnx1HJqWNMFieDMPtdAzxFRONsHZLW3DL0nztbJvC2IGjCKaA2AIw96x1BEFnJIEnOJS3RV1SSlRaFWvbtbZrPO4NYXN1m8AdgMvhssx2XVKWbGe9dm5uNqCJiIiIDsSeQa0Q4o8B/DyALSnlRX1fHMC3AMwDuAXgc1LKrP6x3wPwJQBtAP9YSvlX+v4nAHwdgB/AXwD4HSml3N+3Q0RER4WUEnfzNWN8wbXVHN64k0etqf3qdCLoQSal4LOZGWT0EQZRP8MBIgBAs9q/AJelCbsOlDZsRhF4umHrzOP9i3FFpjmKYMhaagu5es521qvddq3dG7Rrwp6wEbAuKUt40vekMWog7o9btsPuMFuvRERERCPgQRq1XwfwvwL4hmnfVwF8V0r5+0KIr+qPvyKEOA/g8wAuAJgG8B0hxBkpZRvAHwL4MoAfQAtqXwbwl/v1RoiIaLSV6i28vqaPL9DHGGwV6wAAj8uBC9MR/PJTaaykY1hJKZiN+Rkc0MnTGUVgHjtgWZxL365m+4/1RvXm6zQwdk4PXnuasBxFMBSVZmXPtmtnO1fPQaK/y+ASLkvDdSG60DdmwLztcXqG8E6JiIiI6FHsGdRKKb8nhJjv2f0ZAB/Tt/8EwH8E8BV9/zellHUAN4UQ7wJ4SghxC0BESvl9ABBCfAPAZ8GglojoWGqrEm9vFi2h7DtbRah69rCQDOLZU0ljhMG5qQg8Lo4woGNObWsLbtnNgDW3YlvVngMFEBrXwtbYPDD3jHUxrk4Q6w0N412dSG21bbReB813NS+2Ve37f6oJuUNGwDoXmcPK+Io1eNUbrwlfAhFPhD+8IiIiIjrmHnZG7YSUch0ApJTrQohxff8MtMZsx5q+r6lv9+63JYT4MrT2LdLp9EO+RCIiOiybhRquGot9ZfH6Wh6VRhsAoATcuDyr4JOXJo1gVgmw6UXHTGcUQWfsgBHE3uk2YUubgGxbjzNGEUwDUxlg+VPWxbgi00B4kqMIDkG1VR3YeO08Ns96VXvHSgBwCqdllms6ku5ruyZ8CST8CcR8MXid3iG8UyIiIiIaVfu9mJjdj/nlffbbklJ+DcDXAODKlSucY0tENCKabRVr2SpubZe7jdnVHNbz2oxEt1Pg/FQEv/TELDJpBZlUDPOJAFtgdHRJCdRy1rEDlpEEehhrO4og0g1hl57XQ1fTYlyRGSCQ4CiCAyClRLFZRK6WM4LVbC2r3dez2nYth936rtF+rbQqtucKuoPd4DWcRmY809d27cyCjXgjcAj+dgARERERPZyHDWo3hRBTept2CsCWvn8NQMr0vFkAd/X9szb7iYhoxKiqxN18Fbe2K7i5XcJN/f7WTgWruxW01O7Pz1JxP67Mx7GSUpBJKzg/FYHP7Rziqyf6ANQ2UNoyha82IwkKd21GEQAIjmthq5IGUk9bF+PqjCTwhg//PR1TtVbNGDWQq2lhq93jbE0LYfP1PFqyZXsut8ONmC+GmDcGxafgsbHHjMZrJ3DttF9jvhj8Lv8hv1siIiIiOqkeNqj9cwBfBPD7+v2/M+3/MyHEH0BbTOw0gL+VUraFEEUhxIcA/A2ALwD4l4/0yomI6KFJKXGvWMfN7bJ22ynj5r0ybu2U8f5OBfVW91d6/W4n5pNBnJsK41OXJjGfCGJxLIiFZAjxIEcY0Ihq1voX4OptwhY3+kcRONxayBqeBqYuA8ufNDVh9SA2NAm4eO0/rJbaQr6e10LVnoC103jthK+dfYNmvAoIKF4Fik9BzBvDXGQOl8cuI+aLQfEqiPviULyK5bHfxYUKiYiIiGg07RnUCiH+FbSFw5JCiDUA/yO0gPbbQogvAbgN4JcAQEr5EyHEtwG8CaAF4LekNL4D+g0AXwfgh7aIGBcSIyI6YNlyAze2y7hlCmRv6Y/LjW5A5XE6kE4EsJAM4mPL45hPBLGQ1G4TES9DDRodUgK1fE/4atOEre72H+sJ62HrFJD8aHe7sxhXZxSBg7+6/qCklCg1S93Q1RSw2gWu2VoWhUZh4PmC7qARqCZ8CZxSThlBa6cB27mPe+MIe8JwOtjiJyIiIqLjQUg52iNgr1y5Il999dVhvwwiopFVrDW1MQWmVmwnnM1Xm8bznA6B2ZgfC8mg0YrtBLLTih9OB8NYGjK1DZTvWRfgKvbOg70LNG1miQbHeua/mhbj6jRifZHDf09HTGfEwP2CVvPjXC13/xED3pjWZu0ErJ3QtSd47TRePU42lYmIiIjoeBNC/FBKecXuY/u9mBgRER2AWrONW3ob1tKQ3a5gu1S3PHdG8WM+GcDPPzZltGIXkkHMxgLwuNgUpCFp1rpBq7n5am7FFtdtRhG4ujNfJy8Bpz/R34QNTwIu73De1wjrjBiwC1h7g9dOIHu/EQNRb9QIWNPhtGXEgHGvh64xXwwBFxcSJCIiIiL6IBjUEhGNiEZLxWq2YrRiO/Njb22XcTdfszx3LOzFQiKIF86OYz4ZxEIygIVkCHOJABfzosPVGUVgCV/X+1uxlZ3+Yz2hbtt14Tl9Nqy5FTsDBJIcRYDuiIH7jRboPO7MeS3UC5Cw/82pzoiBmDeGuC+OpeiSEbCaA9dO4zXiiXDEABERERHRAWNQS0R0iNqqxN1ctacVqwWza9kq2mo3VFECbswngvjQYkIPY7XbXCKAsM89xHdBJ4aqdkcRmIPY3iZss9x/bCCphbDRGWD2iv1IghM8iqDerhut1t3arhG43m9hrZZqP2LA5XAh7o0boepyfNkyTsAyakBfeMvrZAOZiIiIiGjUMKglItpnUkpsFGp6G7aCm9sl3NTvV3eraLRV47lBjxMLY0Fcmoni05entfmxySAWEkHEgpzVSAeoVbcJX3uasKUNoDccdLj0cQNTwMRF4PTHu63YThgbnjpRowjaahv5Rh65mh66mgLW3sed5uteIwY6AWsqlMJjyces4wV65rsG3UGOGCAiIiIiOgYY1BIRPQQpJXbKDcvM2Fs7Zdy4V8b7OxVUm905mx6XAwuJIE6Nh/DS+UksJAPaIl5jQYyFvAxYaH9JCdQLetg6qAm7DlS2+4/kfWONAAAbtElEQVR1B7vzXxc+ooev5gW5prVFu47xKAIpJcrNcneUgGmeqxG69sx3vd+IgYArYAlYF6OLlgW0zIGr4lMQ9UQ5YoCIiIiI6IRiUEtEdB/5atMyoqAzpuDmdhnFWrdp6HIIpOMBzCeDePZUEvPJIBb1duxUxAeHg2Es7YPOKILiXZsmrGm7Ueo/NpDohq0zV0wtWFMQ640Ax+wHB+YRA70Ba+dxrpbDbn33gUYMdALVuDeO5fiypeHat7CWL8YRA0RERERE9MAY1BLRiVdptIwxBZ1WbCeM3S03jOcJAcwofiwkg/iFlRmjFbuQCGI25ofLeXxbhnQIjFEE5iZsTyu2uN4/ikA49cB1Cpg4D5x6sacFq48icPuG8772kXnEgDFawBSwmhfW6gSylVZl4Pmi3qgRsM6GZnEpeQmKV0HcF7cdNcARA0REREREdJAY1BLRiVBvtXF7p9LXir25XcZmoW557kTEi4VkEJ+4MKHNjE1oi3il4gH43PyVZHoItYJ17IBlMS79ZjuKINANW+ee1RfjmrE2YYNjwBH8VXljxMCAgNWu+Zqv5weOGPC7/EbAqvgULEQX+kcLdEJYn4KIJwKXg18GERERERHR6OB3KER0bLTaKtayVdzcKePmPWsYezdXhWrKdxJBD+aTQXz41BgWx7ph7HwygICHnxrpAamqFrBa5r/e7W/C2o0i8Me7i2/NPK6NJOjMh+2Esb7okRlF0Gg3ugGrHq72Bq69zdem2rQ9l3nEQMwbw5nYGS107TRcO+MHOsGsV4HPdfQbw0REREREdLIxjSCiI0VVJdYLNdy8V8bNnbIxP/bWdhm3dytomdLYsNeFhbEgnpiL4Rcfn8VCshPGBhH1u4f4LuhIaDW0oLWoh659TVj9Y71ho3AC4UktdB07Cyy9oIWv5iA2PAW4/cN5Xw+grbZRaBS6oWrPwlp9wWs9h3KzPPB8nREDilfBdGgaF5MXbQPXzuOQO8QRA0REREREdOIwqCWikSOlxL1S3WjF3tCD2M4M2XpLNZ7rczswnwji7FQYL1+ctCzilQh6GPaQvXqxZzEumyZs+V7/cS5/d+TA3DP6CAK9FdtZqCs0PlKjCKSUqLQq1oC1d7xATxN2rxED5rbrfHTeMs+1d75r1BvliAEiIiIiIqIHwO+ciGhocpWGEcLeNN1ubZdRbrSN57mdAnMJbTzBc2eSWEiGMJ8MYCEZxETYB4eDYSzpVBWo7JjGDtzpNl/NrdhGsf9Yf6w7cmB6xbQYl6kJ61OGPoqg0W7YjhPoDVzNTdiBIwaEC4pPMQLW08rpvgW0OoFsZz9HDBARERERER0MBrVEdKBK9ZYliL21rTdkd8rIVbrhkUMAqXgA84kgnpyPGyMKFpNBTCt+OBnGUqsBlDZMLdh1m1as3SgCBxDqjCI4Ayw9b9OEHc4oAlWqKNQL2K3v9s1v7V1Ya7e2u+eIgYgnYgSsU6EpnE+c10YLeOPdBbVMwStHDBAREREREY0OBrVE9MhqzTbe36ng5nYJN7cr3WB2p4x7xbrludNRH+aTQfzcpSktjE0EsTAWRCoWgMflGNI7oKGrF/Xm692e8NXUhC1v9R/n8ncX30p9yNSCNQWxwXHAefD/3EkpUW1VjUC1N2C1W1gr38hDlart+fwuv6XZmo6kLc1W88JaMV+MIwaIiIiIiIiOOH5HR0QPpNlWsbpbsY4o2Cnj5r0y1gs1SNM4y2TIi4VkAM8vj1lmxs7Fg/B7Rmd2Jx0CKbujCMzzX3vD2Hqh/1h/TB87MAVMXTaNIDCFsf7YgY0iaLab1mZrZ9vUfO1twjbUhu25zCMGYr4YTimnjIC1N3DthK5+1+guNkZERERERET7j0EtERnaqsTdXNUIYW/oi3nd3C5jLVtFW+2msVG/GwvJIJ5eTBit2IVEEPPJAMI+9xDfBR2adhMobvQsxtXThC1uAO2e8FI4gNCEFrYmTwOLH+u2YsOmUQSewL691M6IAfNCWnYzXs3N11KzNPB8nREDilfBVLA7YiDmNc13Nc14DbvDHDFARERERERE98WgluiEkVJis1C3tmL17ds7FTTa3V/DDnicWEgGcXEmik9fnrYEsrGgZ4jvgvadqgL1PFDZ1Rqwtreej1VzAKT1PC5fd+xA6umexbj02yOOIuiMGDBarb2jBXoW1urc9hox0AlYOyMGLOMFTKFr1BuF28EfRhAREREREdH+YlBLdAxJKbFbbvS1Ym9uV/D+ThmVRtt4rsflwHwigMVkEC+cG8dCIoiFpHYbC3vZAjyKpAQa5cEB66D9sm1/PqcHCCT0WxyYvNR93BvEPsQogma7aQ1Y69bxAsY+UxN20IgBp3BaAtYlZalvAS3zY8WncMQAERERERERjQQGtURHWL7axK2eVuyt7TJubJdRrLWM57kcAql4AAvJIJ5ZTGAhGcBCMoT5ZADTUT8cDoaxI61Z2ztk7d3XrtufSzhMoWsCSJ4CAk9b93UC2c62J/TA4asqVRTrBSNctSygZQpczeMG7jdiIOwJG4HqZGASZ+Nn+4JXczAb9oThEFyUjoiIiIiIiI4eBrVEI67SaOHWdqUvjL25XcZOudsqFAKYjvqxOBbEZzMzRit2PhnEbMwPt5Ph1UhoN4Fq9sEarp19jcFBJnxKN1CNzmqLbplD1t7w1acAjr2vhc54gUKjgHz5rnZfzxv3+Xoe+Ua+b77r/UYMeJ1eo9Ea88UwG541Ata4L943aoAjBoiIiIiIiOgkYVBLNALqrTZWdyu4uV3Bze2ScX9ru4KNQs3y3ImIF/OJIF46P2EEsYvJIFLxAHxu55DewQmlqkAtpwes2w8WvNbyg8/nCZtC1iSQXO5vt5pv/ties15bagvFRlEPXNeQ3/1JN2xt5FGoF2xD2EKjgKbaHHhel3Ah4o0YAeuSstQ3y7VzH/fGOWKAiIiIiIiIaA8MaokOSaut4k6uiht6I7YzouDWThl3slWopjWZ4kEP5hMBPHsqiYVkAPOddmwiiKCXf20PhJRAvbjHTNeej1V3gQHtUTi9QDDZDVmV9ICWq6nt6vIOeGmmdqseohZ27iDfyNsGrMZz6gUUm8X7vu2gO4ioJ4qIN4KoJ4olZQkRTwRRbxRRb7S7bXpO1BuF3+Xn/GIiIiIiIiKifcTEh2ifVRotvLNZwlubRbyzWcSNe2Xc3CljdbeCZrubxoa9Lswng1hJxfALK7NY1NuxC4kgogH+uvcja1a1MLW8vUf4ato/qEHqcFlD1fGz95/pGkgA7kDfXNe22kaxUbQGrJX3kc++3hew9oawD9Ju7YSqY4ExLClLloC18zFzCBv2hDlagIiIiIiIiGhEMKgleki1Zhs37pXx9mbRuL21WcTqbtV4jtflwEIyiOWJMD5xYdKYG7uQDCIR9LCR+KBaDa29+kEW1GpWBpxMaCMDOoFqfAGYfWJwyzWQALwRI3SVUqLWrlkarJ1gtbC7hvyGfcv1oNqtEW8EAVeA1xIRERERERHREceglmgPzbaK93fKeGuj25J9a7OI93cqaOvzClwOgcWxIC7PKvjcEymcnghjeTKMdDwAp4MBmoXaBqo5U7C6vXfwWi8MPp832g1UQ5PA+IX7LKaVAPwK4HBa2q19DdbqTeRz16whrOk5bLcSERERERER0X5jUEukU1WJ1WwFb210GrIlvL1ZxHv3SsbIAocA5hJBnJkI4ecvTRmB7HwiCI/LMeR3MARSaotjPehogcoOUM0CkPbncwetTdb40uDRAoEEpE9BTaj27dZ6AflGDvnc+yhsPVy71dxgXVQWbQPW3pYr261ERERERERE9DAY1NKJI6XEer6GtzaLeHujG8i+s1VErdldGGo25seZiTA+tjyO5ckQTo+HcWo8BJ/bOcRXf4Ck1MYF9Ias5e3BwWt1F1Bb9udzeqwh6+TFgaMF2n4FRacHedmwnc+ar+dRqN9AoXDtkdqtSX/SaLf2hq7m+4g3wnYrERERERERER0qBrV0bEkpca9U1xb22tCC2Lc2inhns4RivRsuTkS8ODMRxq88PYfliTBOT4RweiKMkPeI//Vo1R9wpqtpf6tmfy7hsAasyVNA4Glru9UfR80XQt7lQcHpQl5tau3VTthqBKy7yOdvonBPXzzrAdqtAVfA0mDttFs7c1oHtVzZbiUiIiIiIiKio+KIJ1FEmlylgbc3S6aWrHbLVrpty1jAjTMTYfzC4zM4MxHWbyEoAc8QX/kDare0kQG2IeuA4LVRGnw+n9INWSOzwORlU7s1hqI3iILbi7zThbxDoCDbyDcLlpZroV5AofEe8sUfGSFsQ20M/COdwmlprCZ8CSxGF20DVnMIy3YrEREREREREZ0EDGrpSCnVW3hHD2Hf2igZLdmtYt14TtjrwumJEF6+OGkKZMNIhjyj0a5UVaCWe/CZrpUd7fmDeMKm+a1JILmstVv9URS8AeTdPuRdbhScThQEkJctFJqlnpbrNvLF91DYKaDYeLB2aydU7W23Dmq5st1KRERERERERDQYg1oaSbVmG+9uabNj39osGuML7uSqxnN8bgfOTITxkdNjWJ4MGYHsVNR3uIFgJ3gtbQHle/ptGyhvdbd757pK1f5cTi8QTOrBaxLtaAolfwR5bwh5jxcFlxd5p1NruUJFXjaRb5a6i2jV8yg03kF+O38w7VZPBG4n261ERERERERERPuNQS0NVaOl4uZ22RhVoM2SLeH9nTJUqT3H43RgcSyIJ+Zi+M+fThsjC1KxAByOAwpkm7XBoWtpq2f/NiDb/efozHUNjqEWiCE/toi89zIKHj/ybi8KLjcKDoG8kMjLNgpqA/lm2dRy3UKx8h5QGfwy2W4lIiIiIiIiIjoeGNTSoag127iTq+KdTXNLtogb98po6Yms0yEwnwjg7GQYn748jeVJLZCdSwThdjoe7QVIqc14HRi63rMGs/VC3ynqAih4QigE4ygEYyhEoigkp1DojBZwCBQEUJAtFNQGCq0KCo0i8vU8GuoqUIN262HXbl2ILgxcJIvtViIiIiIiIiKi44dBLe2LUr2FO9kq1rIV3MlVte1cFWtZbXu7VLc8Px0P4MxECC+em8DyZBinx8NYHAvC53Y++B/aqlvDVdvQtdN8vQeoLTQB5B0OFJwOFBwOFBxO5P0RFHxhFLwBFBIJFCYmtNAVUg9d6yi0Kqi1zaMEskArC7S6e0LukCVIXQxNGY+NsNWm5Rp0B9luJSIiIiIiIiI64RjU0p6klMhVmriT04LYtWy1G8bq2/lq03KMx+nATMyPGcWPF86OYybmx2zMj6WxEE6NhxD02lx65tbrfea9NkubKFR3UWiVtNDVfHM6UHB5UPAEkHd7UAi4UAgqKCCComyiKlv9fy4AoAGggaAIIuKMGAHrvB6sRjymmyl47ewLeUJwOfjXiYiIiIiIiIiIHg6TJQKghbFbxTpubpfx/k4ZN7cr+n0Zq7sVlBvWGawBjxMziha+Pj6nYEYJYDbm1wJZxY9kyKvNj21UgMq2vqDWLWBnG7i9hWZpE8XSJgrlTRSqO/oiWAUUhDS1XbVb3qGHr04nCg6gGpVANAIgYvteAq6AJVxN9yyKNShwDXvCDFuJiIiIiIiIiGgomEqdIK22is1iHau7FSOMvbVdxq2dMt7fqaDa7IaxbqdAKh7AfCKIDy0mMBvzY1bxIRVsIenMwtHYRLl0C6XyBkqVeyjnd1HazOF6PY9Xm0WUmlWU2lWUoKLscKDkECiJ7siBisM0c9YHwOcEEDN2+R0eRDwhRDxRRHwxzPbOZx0QuIY9YbgdnNtKRERERERERERHy6EHtUKIlwH8CwBOAP+HlPL3D/s1HEdtVWK7VMfdXBUb+Rru5mtYz1Wxnq/hbr6K9WwV+VIOAVFA0JmH31FAwF1EMljHRX8NTy3U4XLXAEcNbdRRk3WU1DpKahNv7DTxg+02ipAoOwTae81TdQFelwNBRxQhhxchVwAhTxDTngjO+mKIBMYR8Se0Jqtd4OqJcpEsIiIiIiIiIiI6UQ41qBVCOAH8bwBeArAG4O+EEH8upXzzMF/HqGu1VWQrTeyWG7hXLGMrv43d/DpK5R1UqznU6jk0GgU0myW0WiXUWgXUZRGqo4a2s46ms4mas4WGQ0XDIVH3STSnBVwOoU9iBbL6n/Wu+Q9WAVdbIiyBkBQICSfCwoUZZxBhlx9BTwghTwQhXwwhfwLBQBKh4ASCoUmEAkmE3CGE3CEE3UEGrURERERERERERB/AYTdqnwLwrpTyBgAIIb4J4DMAGNQC+N6P/hL//NWvoCZU1BwSVYdA1TGgverWbyZ+VSIqgSicGBNuRJxe+J1eBFw+BFwBBNwBBDxhBDxhhHwxhPXANRQYRzg4gXBoAl5vBGKvxiwRERERERERERHtq8MOamcArJoerwF4uvdJQogvA/gyAKTT6cN5ZSMgGVUQFW5MCjcCwougw4+QO4CwN4SoN4JIIIKgLwq/JwyvNwqfNwqfL4pwaApKeAYet3/Yb4GIiIiIiIiIiIgewmEHtXZVTdm3Q8qvAfgaAFy5cqXv48fV+aVn8I2lHw77ZRAREREREREREdEhcxzyn7cGIGV6PAvg7iG/BiIiIiIiIiIiIqKRcthB7d8BOC2EWBBCeAB8HsCfH/JrICIiIiIiIiIiIhophzr6QErZEkL8NoC/AuAE8MdSyp8c5msgIiIiIiIiIiIiGjWHPaMWUsq/APAXh/3nEhEREREREREREY2qwx59QEREREREREREREQ9GNQSERERERERERERDRmDWiIiIiIiIiIiIqIhY1BLRERERERERERENGQMaomIiIiIiIiIiIiGjEEtERERERERERER0ZAxqCUiIiIiIiIiIiIaMga1REREREREREREREPGoJaIiIiIiIiIiIhoyISUctiv4b6EEPcAvD/s13EAkgC2h/0i6FjjNUYHjdcYHTReY3TQeI3RQeM1RgeN1xgdNF5jdNBO4jU2J6Ucs/vAyAe1x5UQ4lUp5ZVhvw46vniN0UHjNUYHjdcYHTReY3TQeI3RQeM1RgeN1xgdNF5jVhx9QERERERERERERDRkDGqJiIiIiIiIiIiIhoxB7fB8bdgvgI49XmN00HiN0UHjNUYHjdcYHTReY3TQeI3RQeM1RgeN15gJZ9QSERERERERERERDRkbtURERERERERERERDxqB2CIQQLwsh3hJCvCuE+OqwXw+NFiFESgjxH4QQ14UQPxFC/I6+/58KIe4IIa7pt0+Zjvk9/Xp6SwjxCdP+J4QQb+gf+1+EEELf7xVCfEvf/zdCiHnTMV8UQryj3754eO+cDpMQ4pZ+bVwTQryq74sLIV7R/9+/IoSImZ7Pa4wemBBi2fS56poQoiCE+F1+HqNHIYT4YyHElhDix6Z9Q/28JYRY0J/7jn6s56D/O9DBGXCN/c9CiJ8KIV4XQvxbIYSi758XQlRNn8/+yHQMrzGyNeAaG+q/jbzGjpcB19i3TNfXLSHENX0/P4/RByIGZxX8emw/SSl5O8QbACeA9wAsAvAAeA3A+WG/Lt5G5wZgCsDj+nYYwNsAzgP4pwD+G5vnn9evIy+ABf36cuof+1sAzwAQAP4SwCf1/b8J4I/07c8D+Ja+HQdwQ7+P6duxYf834e1ArrNbAJI9+/4nAF/Vt78K4J/xGuNtH641J4ANAHP8PMbbI15LzwF4HMCPTfuG+nkLwLcBfF7f/iMAvzHs/0687fs19nEALn37n5musXnz83rOw2uMtw9yjQ3130ZeY8frZneN9Xz8nwP4H/Rtfh7j7YNeX4OyCn49to83NmoP31MA3pVS3pBSNgB8E8BnhvyaaIRIKdellD/St4sArgOYuc8hnwHwTSllXUp5E8C7AJ4SQkwBiEgpvy+1z1jfAPBZ0zF/om//3wBe0H+C9QkAr0gpd6WUWQCvAHh5n98ijS7zdfEnsF4vvMboYb0A4D0p5fv3eQ6vMdqTlPJ7AHZ7dg/t85b+sZ/Vn9v759MRZHeNSSn/vZSypT/8AYDZ+52D1xjdz4DPY4Pw8xh9YPe7xvT/358D8K/udw5eYzTIfbIKfj22jxjUHr4ZAKumx2u4fwhHJ5he818B8Df6rt8W2q/e/bHp1wkGXVMz+nbvfssx+jcfeQCJ+5yLjh8J4N8LIX4ohPiyvm9CSrkOaP8IAxjX9/Mao0fxeVi/IeDnMdpPw/y8lQCQM4V4vNaOv1+D1vrpWBBCXBVC/CchxEf0fbzG6GEM699GXmMny0cAbEop3zHt4+cxeig9WQW/HttHDGoPn7DZJw/9VdDIE0KEAPxrAL8rpSwA+EMASwAyANah/doKMPiaut+19jDH0PHyrJTycQCfBPBbQojn7vNcXmP0UPT5UJ8G8H/pu/h5jA7LYVxTvNZOECHEfwegBeBP9V3rANJSyhUA/xWAPxNCRMBrjD64Yf7byGvsZPllWH94zs9j9FBssoqBT7XZx89je2BQe/jWAKRMj2cB3B3Sa6ERJYRwQ/vE96dSyn8DAFLKTSllW0qpAvjfoY3RAAZfU2uw/nqe+VozjhFCuABEof2KDK/PE0JKeVe/3wLwb6FdT5v/f3t3zxpFFAVg+L0IKor4ARYpXYi/wMLCUoIGFfwoBEFRG/+BRf6Dla0gWIhYuZ2F2kcMxg9U3NgrphJsLK7FPSOzS1LsZte77L4PHDLczGyW3cOZyWXunFiG0ix5+hG7m2Ma1VlgLef8HaxjmoiadesncCj2HXwtzZBoWHIOuBZLNIllnJux/Yby3L3jmGMaUuVzozk2J+I7vgQ8acasYxrFVnMVeD02Vk7U/n+vgcXoSrebsiS0W/k9aYrEM1YeAJ9yzvda4wut3S4CTSfPLnA1uiMeAxaB1Vhy8CuldDJe8zrwrHVM0yXxCvAy/vF4DiyllA7HsqulGNMMSSntTykdaLYp3/MH+vPiBv35Yo5pFH13bljHNAHV6lb87lXsO/j3NSNSSmeAu8CFnPPv1vjRlNKu2O5QcuybOaZh1Tw3mmNz5TTwOef8b7m5dUzD2m6uAq/HxitPQUezeQtgmdIdbwNYqf1+jOkK4BTlVv13wNuIZeAR8D7Gu8BC65iVyKcvRLfEGD9BudjbAO4DKcb3UpYi9yjdFjutY27FeA+4WfvzMCaSYx1K98114GNThyjP93kBfI2fR8wxYwd5tg/YBA62xqxjxk5y6jFlmeYfyl0Vt2vXrainqzH+FNhT+3Myxp5jPcoz8ZprsqYT9eU4h64Da8B5c8wYMceqnhvNsdmKrXIsxh8Cdwb2tY4Zw+bXdnMVXo+NMZoPQpIkSZIkSZJUiY8+kCRJkiRJkqTKnKiVJEmSJEmSpMqcqJUkSZIkSZKkypyolSRJkiRJkqTKnKiVJEmSJEmSpMqcqJUkSZIkSZKkypyolSRJkiRJkqTKnKiVJEmSJEmSpMr+ArnCPso86nMVAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABWoAAAFlCAYAAAByV9G0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzde3hW9Zn/+/dKCDlwCmcJB8MxCYEgEjyBAiKogNZppTojrQwdtTjW2o79aTuz1Vr9/Vq3uzPjtLXbXm11Oh1btbtTG6AqKFotVrG1VCGcI0c5h2MScvjuP9ZjEEUOGlgJeb+uK1eefJ+11nOv50pr+OTO/Y1CCEiSJEmSJEmSkpOWdAGSJEmSJEmS1NoZ1EqSJEmSJElSwgxqJUmSJEmSJClhBrWSJEmSJEmSlDCDWkmSJEmSJElKmEGtJEmSJEmSJCWsTdIFHEu3bt1Cfn5+0mVIkiRJkiRJ0ifyxhtvbA8hdD/Sc80+qM3Pz2fx4sVJlyFJkiRJkiRJn0gURe981HOOPpAkSZIkSZKkhBnUSpIkSZIkSVLCDGolSZIkSZIkKWHNfkbtkdTW1rJhwwaqq6uTLkXNTFZWFn369CEjIyPpUiRJkiRJkqTj1iKD2g0bNtChQwfy8/OJoijpctRMhBDYsWMHGzZsoH///kmXI0mSJEmSJB23Fjn6oLq6mq5duxrS6jBRFNG1a1c7rSVJkiRJktTitMigFjCk1RH5fSFJkiRJkqSWqMUGtS3N+PHjWbx4cZNec/Hixdx6661Nek1JkiRJkiRJp16LnFHbGtTX15Oenn7UY0pLSyktLT2pddTV1dGmjd8mkiRJkiRJ0slkR+3HUFFRQWFhIddffz0lJSVcffXVHDhwAIAFCxYwcuRIhg8fzqxZs6ipqfnQ+bNnz6a0tJTi4mLuvvvuxvX8/Hzuvfdexo4dy5NPPslDDz3E0KFDKSkp4dprr/3QdRYuXMi0adMAuOeee5g1axbjx49nwIABPPTQQ421FhUVccMNN1BcXMzkyZOpqqoCYPXq1Vx22WWMGjWKCy+8kPLycgBmzpzJV7/6VSZMmMAdd9zRtG+eJEmSJEmSpA9p8a2S3/zt2yzdtKdJrzk0ryN3X1F81GOWL1/Oj3/8Y8aMGcOsWbP4wQ9+wC233MLMmTNZsGABQ4YM4fOf/zwPP/wwt91222Hn3n///XTp0oX6+nomTpzIkiVLKCkpASArK4uXX34ZgLy8PNauXUtmZiaVlZXHrLu8vJwXXniBvXv3UlBQwOzZswFYuXIljz/+OD/60Y/47Gc/y69+9StmzJjBjTfeyA9/+EMGDx7MH//4R26++Waef/55AFasWMH8+fOP2dUrSZIkSZIk6ZOzo/Zj6tu3L2PGjAFgxowZvPzyyyxfvpz+/fszZMgQAK6//npeeumlD537xBNPcPbZZzNy5Ejefvttli5d2vjcNddc0/i4pKSE6667jv/6r/86rvEDU6dOJTMzk27dutGjRw+2bNkCQP/+/TnrrLMAGDVqFBUVFezbt48//OEPTJ8+nbPOOoubbrqJzZs3N15r+vTphrSSJEmSJEk6ObYth4pXkq6iWWnxHbXH6nw9WaIo+tDXIYRjnrd27VoefPBBXn/9dTp37szMmTOprq5ufL5du3aNj+fMmcNLL73E008/zbe+9S3efvvtowa2mZmZjY/T09Opq6s74npVVRUNDQ3k5uby5ptvHvFa769DkiRJkiRJ+kQaGmDjG1BeBuVzYMdK6DEUbl6UdGXNhh21H9O6detYtCj+Rnr88ccZO3YshYWFVFRUsGrVKgB+9rOfMW7cuMPO27NnD+3ataNTp05s2bKFefPmHfH6DQ0NrF+/ngkTJvDAAw9QWVnJvn37mqz+jh070r9/f5588kkAQgj85S9/abLrS5IkSZIkqZWrOwirFkDZV+C7RfDjS2DR96BTb5jyIFz3VNIVNistvqM2KUVFRTz22GPcdNNNDB48mNmzZ5OVlcVPf/pTpk+fTl1dHaNHj+aLX/ziYeeNGDGCkSNHUlxczIABAxrHJ3xQfX09M2bMYPfu3YQQ+MpXvkJubm6T3sPPf/5zZs+ezX333UdtbS3XXnstI0aMaNLXkCRJkiRJUitSsxdWzYdlZbDyWajZAxk5MOgSKJwGQyZDduekq2yWouP5c/0klZaWhsWLFx+2tmzZMoqKihKqCCoqKpg2bRpvvfVWYjXooyX9/SFJkiRJktSq7NsKy+fFIw3WLIT6GsjpCgWXx+HsgPGQkZ1wkc1DFEVvhBBKj/ScHbWSJEmSJEmSTszONXEwWz4H1r0KBMjtB6O/EIezfc+FdKPHE+G79THk5+fbTStJkiRJkqTWIwR4d0k80qB8Dmx9O17vORzG3QFF06DnMIiiZOtswQxqJUmSJEmSJH1YfR2sWwTlqXB293qI0qDf+XDp/4aCKdClf9JVnjYMaiVJkiRJkiTFDh6ANS/EwezyeVC1E9IzYeDFcedsweXQrlvSVZ6WDGolSZIkSZKk1uzATljxTNw5u/p5qD0AmZ1gyKXxSIOBEyGzfdJVnvYMaiVJkiRJkqTWZveG1GZgZVDxCoR66NALzvq7eDOw/LGQnpF0la2KQW0zM378eB588EFKS0s/9jWefvppli5dyp133tmElUmSJEmSJKnFCgG2lac2AyuDzW/G692GwJgvx+Fs3khIS0u2zlbMoPY0dOWVV3LllVee1Neoq6ujTRu/fSRJkiRJkpqthgbY8HpqM7Ay2LkmXu9dCpfcAwVTofuQJCvU+xiRfwwVFRUUFhZy/fXXU1JSwtVXX82BAwcAuPfeexk9ejTDhg3jxhtvJIQAxJ2yd9xxB+eccw5Dhgzh97//PQBVVVVce+21lJSUcM0111BVVdX4OrNnz6a0tJTi4mLuvvvuxvU777yToUOHUlJSwu233/6h+h599FFuueUWAGbOnMmtt97KBRdcwIABA3jqqacAWLhwIePHj+fqq6+msLCQ6667rrHWN954g3HjxjFq1CguvfRSNm/e3HgP3/jGNxg3bhz//u//3tRvqyRJkiRJkj6puhpY+Rz89svw/xTATybDqz+Azvkw9bvw1XK4YQGM/YohbTPT8lsi590J7/61aa95xnC4/NtHPWT58uX8+Mc/ZsyYMcyaNYsf/OAH3H777dxyyy3cddddAHzuc5+jrKyMK664Aoi7UF977TXmzp3LN7/5TebPn8/DDz9MTk4OS5YsYcmSJZx99tmNr3H//ffTpUsX6uvrmThxIkuWLKFPnz78+te/pry8nCiKqKysPObtbN68mZdffpny8nKuvPJKrr76agD+/Oc/8/bbb5OXl8eYMWN45ZVXOPfcc/nSl77Eb37zG7p3784vf/lL/vmf/5mf/OQnAFRWVvLiiy9+rLdVkiRJkiRJJ0H1Hlj1XDzWYOVzcHAvtG0Pgy6JRxoMngTZuUlXqWNo+UFtQvr27cuYMWMAmDFjBg899BC33347L7zwAg888AAHDhxg586dFBcXNwa1n/70pwEYNWoUFRUVALz00kvceuutAJSUlFBSUtL4Gk888QSPPPIIdXV1bN68maVLlzJ06FCysrL4h3/4B6ZOncq0adOOWetVV11FWloaQ4cOZcuWLY3r55xzDn369AHgrLPOoqKigtzcXN566y0mTZoEQH19Pb169Wo855prrvm4b5kkSZIkSZKayt4tsHxuPNJgzYvQUAs53WDY38ThbP9xkJGVdJU6AS0/qD1G5+vJEkXRh76urq7m5ptvZvHixfTt25d77rmH6urqxmMyMzMBSE9Pp66u7iOvBbB27VoefPBBXn/9dTp37szMmTOprq6mTZs2vPbaayxYsIBf/OIXfO973+P5558/aq3vvS7QON7gg+vv1RRCoLi4mEWLFh3xWu3atTvqa0mSJEmSJOkk2bE6NW92Dqx/DQjxSINzb4rD2b7nQFp60lXqY3JG7ce0bt26xjDz8ccfZ+zYsY2hbLdu3di3b1/jPNijueiii/j5z38OwFtvvcWSJUsA2LNnD+3ataNTp05s2bKFefPmAbBv3z52797NlClT+Ld/+zfefPPNJr2vgoICtm3b1nhvtbW1vP322036GpIkSZIkSToOIcDGP8GCb8H3z4P/OBueuwtqq2D812H2H+DWN+HS++HM8w1pW7iW31GbkKKiIh577DFuuukmBg8ezOzZs8nJyeGGG25g+PDh5OfnM3r06GNeZ/bs2fz93/89JSUlnHXWWZxzzjkAjBgxgpEjR1JcXMyAAQMaxyzs3buXT33qU1RXVxNC4F//9V+b9L7atm3LU089xa233sru3bupq6vjtttuo7i4uElfR5IkSZIkSUdQXwvv/OFQ5+yejRClwZljYNS3oXAq5PZLukqdBNH7/xS+OSotLQ2LFy8+bG3ZsmUUFRUlVBFUVFQwbdo03nrrrcRq0EdL+vtDkiRJkiTphBzcD6ufjzcDW/E7qK6ENlkwcGIczA65DNp1TbpKNYEoit4IIZQe6bnj6qiNoqgC2AvUA3UhhNIoiroAvwTygQrgsyGEXanjvw58IXX8rSGEZ1Lro4BHgWxgLvDl0NyTYkmSJEmSJKmpHdgJy+fFXbOrn4e6KsjKhYLL43B24MXQ1r2CWpMTGX0wIYSw/X1f3wksCCF8O4qiO1Nf3xFF0VDgWqAYyAPmR1E0JIRQDzwM3Ai8ShzUXgbMa4L7OKXy8/PtppUkSZIkSdKJqVwXB7Plc+CdVyA0QMfecPbn4s3AzrwA0jOSrlIJ+SQzaj8FjE89fgxYCNyRWv9FCKEGWBtF0SrgnFRXbscQwiKAKIr+E7iKFhjUSpIkSZIkSccUAmxdGo80KC+Dd+NN5OleBGO/GnfO5o2EKEq2TjULxxvUBuDZKIoC8P+GEB4BeoYQNgOEEDZHUdQjdWxv4o7Z92xIrdWmHn9wXZIkSZIkSTo9NNTD+tdSm4GVwa4KIII+o2HSvXHnbNeBSVepZuh4g9oxIYRNqTD2uSiKyo9y7JF+BRCOsv7hC0TRjcQjEujXz13sJEmSJEmS1IzVVsPaF+Ngdvk82L8N0jJgwDgYc1s8d7bDGUlXqWbuuILaEMKm1OetURT9GjgH2BJFUa9UN20vYGvq8A1A3/ed3gfYlFrvc4T1I73eI8AjAKWlpW42JkmSJEmSpOalejeseDYOZ1fNh4P7oG0HGDwJiqbBoEmQ1THpKtWCpB3rgCiK2kVR1OG9x8Bk4C3gaeD61GHXA79JPX4auDaKoswoivoDg4HXUmMS9kZRdF4URRHw+fedc9obP348ixcvbtJrLl68mFtvvbVJr5mkmTNn8tRTTwHH936djPdUkiRJkiTpI+3ZDK//GH72N/DAQPj//gHe+QMMvxquewr+12qY/lMY9hlDWp2w4+mo7Qn8Os5WaQP8dwjhd1EUvQ48EUXRF4B1wHSAEMLbURQ9ASwF6oB/DCHUp641G3gUyCbeRMyNxD5CfX096enpRz2mtLSU0tLSk1pHXV0dbdp8kj3nDnc89yVJkiRJktRsbF8Zd80uK4ONqYaxLgPgvNnxvNk+oyHtmL2Q0jEd87sohLAmhDAi9VEcQrg/tb4jhDAxhDA49Xnn+865P4QwMIRQEEKY9771xSGEYannbgkhtMixBhUVFRQWFnL99ddTUlLC1VdfzYEDBwBYsGABI0eOZPjw4cyaNYuampoPnT979mxKS0spLi7m7rvvblzPz8/n3nvvZezYsTz55JM89NBDDB06lJKSEq699toPXWfhwoVMmzYNgHvuuYdZs2Yxfvx4BgwYwEMPPdRYa1FRETfccAPFxcVMnjyZqqoqAFavXs1ll13GqFGjuPDCCykvj0cPz5w5k69+9atMmDCBO+6447DXfPTRR/nUpz7FZZddRkFBAd/85jcbn7vqqqsYNWoUxcXFPPLII43r7du356677uLcc89l0aJF3HvvvYwePZphw4Zx4403cqxvg2effZbzzz+fs88+m+nTp7Nv376jHi9JkiRJkvSxNTTAhjdg/jfhe+fA90ph/j3QUAcX/wvc/Cp86U8w+VvQ71xDWjWZpmuVTMh3XvsO5TuPtrfZiSvsUsgd59xx1GOWL1/Oj3/8Y8aMGcOsWbP4wQ9+wC233MLMmTNZsGABQ4YM4fOf/zwPP/wwt91222Hn3n///XTp0oX6+nomTpzIkiVLKCkpASArK4uXX34ZgLy8PNauXUtmZiaVlZXHrLu8vJwXXniBvXv3UlBQwOzZswFYuXIljz/+OD/60Y/47Gc/y69+9StmzJjBjTfeyA9/+EMGDx7MH//4R26++Waef/55AFasWMH8+fOP2P362muv8dZbb5GTk8Po0aOZOnUqpaWl/OQnP6FLly5UVVUxevRoPvOZz9C1a1f279/PsGHDuPfeewEYOnQod911FwCf+9znKCsr44orrjjiPW3fvp377ruP+fPn065dO77zne/w3e9+t/F8SZIkSZKkT6y+Fip+D+VzoHwu7N0EUTrkj4HRX4CCKZDb99jXkT6BFh/UJqVv376MGTMGgBkzZvDQQw8xadIk+vfvz5AhQwC4/vrr+f73v/+hoPaJJ57gkUceoa6ujs2bN7N06dLGoPaaa65pPK6kpITrrruOq666iquuuuqYNU2dOpXMzEwyMzPp0aMHW7ZsAaB///6cddZZAIwaNYqKigr27dvHH/7wB6ZPn954/vu7f6dPn/6RIwomTZpE165dAfj0pz/Nyy+/TGlpKQ899BC//vWvAVi/fj0rV66ka9eupKen85nPfKbx/BdeeIEHHniAAwcOsHPnToqLiz8yqH311VdZunRp43t98OBBzj///GO+F5IkSZIkSUdVsw9WL4hHGqx8Jt4crE02DJoIhXfBkEshp0vSVaoVafFB7bE6X0+W1Mzew74+nkkOa9eu5cEHH+T111+nc+fOzJw5k+rq6sbn27Vr1/h4zpw5vPTSSzz99NN861vf4u233z7qvNjMzMzGx+np6dTV1R1xvaqqioaGBnJzc3nzzTePeK331/FBR7r3hQsXMn/+fBYtWkROTg7jx49vvK+srKzG0Le6upqbb76ZxYsX07dvX+65557D7v+DQghMmjSJxx9//COPkSRJkiRJOi77t8PyeXHn7Ornob4GsjvHs2YLp8KACdA2J+kq1Uo5RONjWrduHYsWLQLg8ccfZ+zYsRQWFlJRUcGqVasA+NnPfsa4ceMOO2/Pnj20a9eOTp06sWXLFubNO/J+ag0NDaxfv54JEybwwAMPUFlZ2aSzWTt27Ej//v158skngTgQ/ctf/nJc5z733HPs3LmTqqoq/ud//ocxY8awe/duOnfuTE5ODuXl5bz66qtHPPe9ULZbt27s27ePp5566qivdd555/HKK680vqcHDhxgxYoVx3ubkiRJkiSptdtVAYu+Dz+dAg8OhqdvgS1vQenfw/VlcPsquOoHcVBrSKsEtfiO2qQUFRXx2GOPcdNNNzF48GBmz55NVlYWP/3pT5k+fTp1dXWMHj2aL37xi4edN2LECEaOHElxcTEDBgxo/JP+D6qvr2fGjBns3r2bEAJf+cpXyM3NbdJ7+PnPf87s2bO57777qK2t5dprr2XEiBHHPG/s2LF87nOfY9WqVfzd3/0dpaWlDB8+nB/+8IeUlJRQUFDAeeedd8Rzc3NzueGGGxg+fDj5+fmMHj36qK/VvXt3Hn30Uf72b/+2cTTDfffd1zheQpIkSZIk6TAhxEHssrK4c3bLX+P1HsVw4e1QNA3OKIEP/MWwlLToeP5cP0mlpaVh8eLFh60tW7aMoqKihCqCiooKpk2bxltvvZVYDUl59NFHWbx4Md/73veSLuUjJf39IUmSJEmSTrGGelj3amozsDKofAeIoN95cadswRToOjDpKiWiKHojhFB6pOfsqJUkSZIkSVLLU1sFaxbGwezyeXBgB6S3jefMXvhPUHA5tO+RdJXScTOo/Rjy8/NbZTctwMyZM5k5c2bSZUiSJEmSpNaoaheseDYOZ1ctgNr9kNkRBk+ORxoMugQyOyRdpfSxGNRKkiRJkiSp+dqz6dBIg4qXoaEO2p8BI66BwmmQfyG0aZt0ldIn1mKD2hACkUOf9QHNfeayJEmSJEk6DtuWx8HssjLY9Kd4resgOP+WOJztPQrS0pKtUWpiLTKozcrKYseOHXTt2tWwVo1CCOzYsYOsrKykS5EkSZIkSSeioQE2vhGHs+VzYMfKeD3vbJh4VxzOdi9ItkbpJGuRQW2fPn3YsGED27ZtS7oUNTNZWVn06dMn6TIkSZIkSdKx1B2EipdSYw3mwr53Ia0N5I+Fc2+CginQqXfSVUqnTIsMajMyMujfv3/SZUiSJEmSJOlE1OyFlc/F4ezKZ6FmD2TkxJuAFV0BgydBduekq5QS0SKDWkmSJEmSJLUQ+7bC8nnxWIM1C6H+IOR0haFXxiMNBoyHjOyEi5SSZ1ArSZIkSZKkprVzTWqkwRxY9yoQILcfjL4BCqdCv/MgLT3pKqVmxaBWkiRJkiRJn0wI8O4SWJbaDGzr2/F6z+Ew7g4omgY9h4GbwksfyaBWkiRJkiRJJ66+DtYtikcalM+B3eshSoN+58Ol/zvunO2cn3SVUothUCtJkiRJkqTjc/AArHkhDmaXz4OqnZCeCQMvjjtnCy6Hdt2SrlJqkQxqJUmSJEmS9NEO7IQVz8Sds6sWQF0VZHWCIZfFXbMDJ0Jm+6SrlFo8g1pJkiRJkiQdbveG1GZgZVDxCoR66NALRl4HhdMgfyykZyRdpXRaMaiVJEmSJElq7UKAbeWpzcDKYPOb8Xq3Ahjz5XgzsF4jIS0t2Tql05hBrSRJkiRJUmvU0AAbXk9tBlYGO9fE631GwyX3xJ2z3QYnWaHUqhjUSpIkSZIktRZ1NbD2pVQ4Oxf2b4W0NtD/Ijj/FiiYAh17JV2l1CoZ1EqSJEmSJJ3OqvfAymfjmbMrn4ODe6Ftexh0CRRdEX/Ozk26SqnVM6iVJEmSJEk63ezdAsvnxp2za16EhlrI6QbD/gYKr4g7aDOykq5S0vsY1EqSJEmSJJ0OdqyOg9llZfHsWQJ0zodzb4rnzfY9B9LSk65S0kcwqJUkSZIkSWqJQoBNf45HGpTPgW3L4vUzSmDCN6BwKvQYClGUbJ2SjotBrSRJkiRJUktRXwvv/CG1Gdgc2LMRojQ4cwyM+nYczub2S7pKSR+DQa0kSZIkSVJzdnA/rH4+Hmmw4ndQXQltsmDgRJjwzzDkMmjXNekqJX1CBrWSJEmSJEnNzf4dcShbPicOaeuqICsXCi6Pu2YHXgxt2yVdpaQmZFArSZIkSZLUHFSuOzRv9p1XIDRAxz5w9ufjcPbMCyA9I+kqJZ0kBrWSJEmSJElJCAG2Lo1HGpSXwbtL4vXuRTD2q1A0DXqd5WZgUithUCtJkiRJknSqNNTD+tdSm4GVwa4KIIK+58Cke6FwGnQdmHSVkhJgUCtJkiRJknQy1VbD2hfjYHb5PNi/DdIyYMA4GHMbFEyBDj2TrlJSwgxqJUmSJEmSmlr1bljxbBzOrpoPB/dB2w4weFI80mDQJMjqmHSVkpoRg1pJkiRJkqSmsGczLJ8bh7Nrfw8NtdCuBwy/GgqvgP4XQpvMpKuU1EwZ1EqSJEmSJH1c21fGweyyMti4OF7rMgDOmx3Pm+0zGtLSkq1RUotgUCtJkiRJknS8Ghpg058PbQa2fUW83ussuPhf4nC2eyFEUbJ1SmpxDGolSZIkSZKOpr4WKn4P5XOgfC7s3QRROuSPgdE3QOEU6NQn6SoltXAGtZIkSZIkSR9Usw9WL4hHGqx8Jt4crE02DJoIRXfD4MmQ0yXpKiWdRgxqJUmSJEmSAPZvh+Xz4s7Z1c9DfQ1kd4nHGRROhQEToG1O0lVKOk0Z1EqSJEmSpNZrV0VqpMEcWLcIQgN06guls+Jwtt/5kG58Iunk8/9pJEmSJElS6xECvPvXQ+Hslr/G6z2K4aKvxeHsGSVuBibplDOolSRJkiRJp7eGelj3KpSXxR+V64AI+p0Hk++Lw9kuA5KuUlIrZ1ArSZIkSZJOP7VVsGZhHMwunwcHdkB623jO7IW3Q8Hl0L5H0lVKUiODWkmSJEmSdHqo2gUrno3D2VULoHY/ZHaEIZfGXbODLoHMDklXKUlHdNxBbRRF6cBiYGMIYVoURV2AXwL5QAXw2RDCrtSxXwe+ANQDt4YQnkmtjwIeBbKBucCXQwihqW5GkiRJkiS1Mns2pebNlkHFy9BQB+3PgBHXxuFs/oXQpm3SVUrSMZ1IR+2XgWVAx9TXdwILQgjfjqLoztTXd0RRNBS4FigG8oD5URQNCSHUAw8DNwKvEge1lwHzmuROJEmSJEnS6S8E2L4iDmaXlcGmP8XrXQfB+bdA0RWQdzakpSVbpySdoOMKaqMo6gNMBe4Hvppa/hQwPvX4MWAhcEdq/RchhBpgbRRFq4BzoiiqADqGEBalrvmfwFUY1EqSJEmSpKNpaICNbxzaDGzHqni99yiYeBcUToPuBcnWKEmf0PF21P4b8L+A9w9y6RlC2AwQQtgcRdF7E7h7E3fMvmdDaq029fiD65IkSZIkSYerOwgVL6XGGsyFfe9CWhvIHwvnfhEKpkAnYwVJp49jBrVRFE0DtoYQ3oiiaPxxXDM6wlo4yvqRXvNG4hEJ9OvX7zheUpIkSZIktXg1e2Hlc3E4u/JZqNkDGe1g0MR4pMHgSZDdOekqJemkOJ6O2jHAlVEUTQGygI5RFP0XsCWKol6pbtpewNbU8RuAvu87vw+wKbXe5wjrHxJCeAR4BKC0tNTNxiRJkiRJOl3t2wrL58UjDdYshPqDkNMVhl4JhVfAgHGQkZ10lZJ00h0zqA0hfB34OkCqo/b2EMKMKIr+b+B64Nupz79JnfI08N9RFH2XeDOxwcBrIYT6KIr2RlF0HvBH4PPAfzTx/UiSJEmSpOZu55rUSIM5sO5VIEBuPxh9AxROhX7nQVp60lVK0il1vDNqj+TbwBNRFH0BWAdMBwghvB1F0RPAUqAO+McQQn3qnNnAo0A28SZibiQmSZIkSdLpLgTY/NaPMdsAACAASURBVJdD4ezWt+P1nsNh/J1xONtzGERHmpooSa1DFELznixQWloaFi9enHQZkiRJkiTpRNTXwbpF8UiD8jmwez1EadDvfCicBoVToHN+0lVK0ikVRdEbIYTSIz33STpqJUmSJEmSDjl4ANa8AMvKYMU8qNoF6Zkw8GIYdwcUXA7tuiVdpSQ1Swa1kiRJkiTp4zuwE1Y8E3fOrloAdVWQ1QmGXBaPNBg4ETLbJ12lJDV7BrWSJEmSJOnE7N6QmjdbBhWvQKiHDnkwckYczuaPhfSMpKuUpBbFoFaSJEmSJB1dCLCtPB5pUF4Gm9+M17sVwJgvQ9E06DUS0tKSrVOSWjCDWkmSJEmS9GENDbDh9dRmYGWwc0283mc0XHJPvCFYt8FJVihJpxWDWkmSJEmSFKurgbUvpcLZubB/K6RlQP+L4PxboGAKdOyVdJWSdFoyqJUkSZIkqTWr3gMrn41nzq58Dg7uhbbtYdAlUHRF/Dk7N+kqJem0Z1ArSZIkSVJrs3cLLJ8bd86ueREaaqFddxj2N1B4RdxBm5GVdJWS1KoY1EqSJEmS1BrsWB0Hs8vK4tmzBOicD+feFHfO9hkNaelJVylJrZZBrSRJkiRJp6MQYNOf45EG5XNg27J4vdcImPANKJwKPYZCFCVbpyQJMKiVJEmSJOn0UV8L77xyKJzdsxGiNDhzDIz6DhROgdx+SVcpSToCg1pJkiRJklqyg/th9fPxSIMVv4PqSmiTBQMnwoR/hiGXQbuuSVcpSToGg1pJkiRJklqa/TviULa8LA5p66ohKxcKLo9HGgy8GNq2S7pKSdIJMKiVJEmSJKklqFx3aKTBO69AaICOfeDs6+Nw9swLID0j6SolSR+TQa0kSZIkSc1RCLB1aTzSoLwM3l0Sr3cvggv/KQ5ne53lZmCSdJowqJUkSZIkqbloqIf1r8XBbHkZ7KoAIuh7Dky6FwqnQdeBSVcpSToJDGolSZIkSUpSbTWsfTEOZpfPg/3bIL0t9B8HY26DginQoWfSVUqSTjKDWkmSJEmSTrWqSlj5XBzOrpoPB/dB2w4wZHI80mDQJMjqmHSVkqRTyKBWkiRJkqRTYc9mWJ7aDGzt76GhFtr3hOHT45EG/S+ENplJVylJSohBrSRJkiRJJ8v2lXHX7LIy2Lg4XusyAM6bDUVXQO9SSEtLtkZJUrNgUCtJkiRJUlNpaIBNfz60Gdj2FfF63ki4+F+g8AroXgBRlGydkqRmx6BWkiRJkqRPor4WKn4fjzQonwt7N0GUDvljYfQNUDgFOvVJukpJUjNnUCtJkiRJ0omq2QerF8QjDVY+A9W7oU02DJoIRXfD4MmQ0yXpKiVJLYhBrSRJkiRJx2P/dlg+Lx5psPoFqK+B7C7xRmCF02DAeGibk3SVkqQWyqBWkiRJkqSPsqsiHmmwrAzWvwqhATr1g9JZUDQN+p4H6f7TWpL0yflfE0mSJEmS3hMCvPvX1LzZObDlr/F6j2K46GtQOBXOKHEzMElSkzOolSRJkiS1bg31sO7VeKRBeRlUrgMi6HceTL4/3gysy4Ckq5QkneYMaiVJkiRJrU9tFaxZGAezy+fBgR2Q3hYGTIALb4eCKdC+e9JVSpJaEYNaSZIkSVLrULULVjwbh7OrFkDtfsjsCEMujUcaDLoEMjskXaUkqZUyqJUkSZIknb52b4Tlc+NwtuJlaKiD9mfAiGvjcDb/QmjTNukqJUkyqJUkSZIknUZCgO0rYNlv483ANv0pXu86GC74EhROg7yzIS0t2TolSfoAg1pJkiRJUsvW0AAb3zi0GdiOVfF671Ew8S4ovAK6D0m2RkmSjsGgVpIkSZLU8tQdhIqX4q7Z8rmw711IaxOPMjj3i/FYg455SVcpSdJxM6iVJEmSJLUMNXth5XNxOLvyWajZAxntYNBEKLoCBk+C7M5JVylJ0sdiUCtJkiRJar72bYXl8+KRBmsWQv1ByOkKQ6+MRxoMGAcZ2UlXKUnSJ2ZQK0mSJElqXnauibtml5XB+j8CAXL7wegboGga9D0X0tKTrlKSpCZlUCtJkiRJSlYIsPkvqXmzZbB1abx+xnAYf2c8b7bnMIiiZOuUJOkkMqiVJEmSJJ169XWwblEczJbPgd3rIUqDfufDpf8HCqdA5/ykq5Qk6ZQxqJUkSZIknRoHD8CaF+KRBivmQdUuSM+EgRfHnbNDLoN23ZKuUpKkRBjUSpIkSZJOngM7YcUzcefsqgVQVwVZneJQtnAqDJwIme2TrlKSpMQZ1EqSJEmSmlblelg+Nw5nK16BUA8d8mDkjDiczR8L6RlJVylJUrNiUCtJkiRJ+mRCgG3l8UiD8jLY/Ga83q0Axt4Wh7N5Z7sZmCRJR2FQK0mSJEk6cQ0NsOF1KP9tvBnYzjXxep/RcMk343C22+Bka5QkqQUxqJUkSZIkHZ+6Glj7Utw1Wz4X9m+FtAzofxGcf0scznY4I+kqJUlqkQxqJUmSJEkfrXoPrHw27ppd+Rwc3Att28PgSVA4Lf6c1SnpKiVJavEMaiVJkiRJh9u75dBmYGtehIZaaNcdhv0NFF4Rd9BmZCVdpSRJp5VjBrVRFGUBLwGZqeOfCiHcHUVRF+CXQD5QAXw2hLArdc7XgS8A9cCtIYRnUuujgEeBbGAu8OUQQmjaW5IkSZIknbAdq+NgdllZPHuWAJ3z4dyboOiKePZsWnrSVUqSdNo6no7aGuDiEMK+KIoygJejKJoHfBpYEEL4dhRFdwJ3AndEUTQUuBYoBvKA+VEUDQkh1AMPAzcCrxIHtZcB85r8riRJkiRJRxcCbPpzPNKgfA5sWxav9xoBE74RjzXoUQRRlGydkiS1EscMalMdr/tSX2akPgLwKWB8av0xYCFwR2r9FyGEGmBtFEWrgHOiKKoAOoYQFgFEUfSfwFUY1EqSJEnSqVFfC++8ciic3bMRonQ48wIY9R0onAK5/ZKuUpKkVum4ZtRGUZQOvAEMAr4fQvhjFEU9QwibAUIIm6Mo6pE6vDdxx+x7NqTWalOPP7guSZIkSTpZDu6H1c/HIw1W/A6qK6FNFgycCBf/Cwy5DHK6JF2lJEmt3nEFtamxBWdFUZQL/DqKomFHOfxIfxcTjrL+4QtE0Y3EIxLo18/f5kqSJEnSCdm/Iw5ly8vikLauGrJyoeDyeKTBwAnQtl3SVUqSpPc5rqD2PSGEyiiKFhLPlt0SRVGvVDdtL2Br6rANQN/3ndYH2JRa73OE9SO9ziPAIwClpaVuNiZJkiRJx1K57tBIg3degdAAHfvA2ddD4dR4vEF6RtJVSpKkj3DMoDaKou5AbSqkzQYuAb4DPA1cD3w79fk3qVOeBv47iqLvEm8mNhh4LYRQH0XR3iiKzgP+CHwe+I+mviFJkiRJahVCgK1L45EG5WXw7pJ4vcdQuPCf4nC211luBiZJUgtxPB21vYDHUnNq04AnQghlURQtAp6IougLwDpgOkAI4e0oip4AlgJ1wD+mRicAzAYeBbKJNxFzIzFJkiRJOl4N9bD+tTiYLS+DXRVABH3PgUnfisPZrgOTrlKSJH0MUQjNe7JAaWlpWLx4cdJlSJIkSVIyaqth7Yuw7LewfB4c2A7pbaH/OCiaBkMuhw49k65SkiQdhyiK3gghlB7puROaUStJkiRJOgWqKmHlc3HX7Kr5cHAftO0AQybHXbODJkFWx6SrlCRJTcigVpIkSZKagz2bYXlqM7C1v4eGWmjfE4ZPh8Jp0P9CaJOZdJWSJOkkMaiVJEmSpKRsXxl3zS4rg42pkW9dBsL5N8fhbO9SSEtLtkZJknRKGNRKkiRJ0qnS0ACb/nxoM7DtK+L1vJFw8b9A4RXQvQCiKNk6JUnSKWdQK0mSJEknU30tVPw+HmlQPhf2boIoHfLHwugboHAKdOqTdJWSJClhBrWSJEmS1NRq9sWbgJXPgRXPQM1uyMiBQROh8G4YPBlyuiRdpSRJakYMaiVJkiSpKezfDsvnxSMNVr8A9TWQ3QWKpsXzZgeMh7Y5SVcpSZKaKYNaSZIkSfq4dlXEXbPLymD9qxAaoFM/KJ0VB7R9z4N0/9klSZKOzZ8YJEmSJOl4hQDv/jU1b3YObPlrvN5zGFz0tbhz9ozhbgYmSZJOmEGtJEmSJB1NQz2sezUeaVBeBpXrgAj6nQ+T7483A+syIOkqJUlSC2dQK0mSJEkfVFsFaxbGIw1WzIMDOyA9M54ze9HXYMjl0L57wkVKkqTTiUGtJEmSJAFU7YIVz0L5b2HVAqg9AJmdYMhkKJwKgy6BzA5JVylJkk5TBrWSJEmSWq/dG2H53HikQcXL0FAH7c+AEX8bh7P5F0KbtklXKUmSWgGDWkmSJEmtRwiwfQUs+228GdimP8XrXQfDBV+KNwPLOxvS0pKtU5IktToGtZIkSZJObw0NsPGNQ5uB7VgVr/ceBRPvjsPZ7kOSrVGSJLV6BrWSJEmSTj91B6Hipbhrtnwu7HsX0trEowzO/WI81qBjXtJVSpIkNTKolSRJknR6qNkLK5+Lw9mVz0LNHshoB4MvibtmB0+G7Nykq5QkSToig1pJkiRJLde+ranNwObAmoVQfxByusHQT8Xh7IBxkJGddJWSJEnHZFArSZIkqWXZuSYOZpeVwfo/AgFy+8HoG6BoGvQ9F9LSk65SkiTphBjUSpIkSWreQoDNf0nNmy2DrUvj9TOGw/g7487ZnsUQRcnWKUmS9AkY1EqSJElqfurrYN2iOJgtnwO710OUBv0ugEv/T7wZWOczk65SkiSpyRjUSpIkSWoeDh6ANS/EIw1WzIOqXdAmCwZeHHfODrkM2nVLukpJkqSTwqBWkiRJUnIO7IQVz8Sds6sWQF0VZHWKQ9nCaXFIm9k+6SolSZJOOoNaSZIkSadW5XpYPjcOZytegVAPHfJg5Ix4pEH+WEjPSLpKSZKkU8qgVpIkSdLJFQJsK49HGpSXweY34/XuhTD2tjiczTvbzcAkSVKrZlArSZIkqek1NMCG16H8t/FmYDvXxOt9RsMl34zHGnQblGyNkiRJzYhBrSRJkqSmUVcDa1+Ku2bL58L+rZCWAf0vggu+BAVToMMZSVcpSZISsq+mjk2VVWysrGLjripy2qbz6bP7JF1Ws2FQK0mSJOnjq94DK5+Nu2ZXPgcH90Lb9jB4Utw1O3hSvDmYJEk6rTU0BLbtq2kMYTdVVh0KZSur2VRZxe6q2sPOGd67k0Ht+xjUSpIkSToxe7fA8jlxOLvmRWiohXbdYdin43B2wDhok5l0lZIkqQlVHaxn0+7DQ9gNle89rmbz7ipq68Nh53TMakNebja9c7MZnd+ZvNzsxq9752bTvYM/L7yfQa0kSZKkY9uxGpal5s1ueB0I0Lk/nPfFOJztMxrS0pOuUpIkfQwhBHbsP9gYwm5MfbwXwm6srGLn/oOHnZMWwRkds8jLzWZkv1ym5vYiLzebPqkwNi83iw5ZGQndUctkUCtJkiTpw0KATX+Og9nyMthWHq/3GgETvhGHsz2KIIqSrVOSJB1TTV09m1PjB44Uwm6qrKKmruGwc3Lapsedr52zGd6nU2MX7Hsh7Bkds2iTnpbQHZ2eDGolSZIkxepr4Z1XUuHsHNizEaJ0OPMCGPX3UDgVcvsmXaUkSXqfEAK7q2rZ8IG5sO+FsBsrq9i2t+ZD5/XokEnvztkMzevIpKE9Dwth++Tm0DG7DZG/kD2lDGolSZKk1uzgflj9PCwrgxW/g+pKaJMNgybCxf8CQy6DnC5JVylJUqtVW9/Au7vjbtj3ZsRufF937KbKKg4crD/snMw2aY3dsBcX9IjnwnaOQ9jeudmc0SmLzDaOLGpuDGolSZKk1mb/jjiULS+LQ9q6asjKhYLL45EGAy+GtjlJVylJUquwp7r2UCfsEULYLXuqaTh8jy66tmtL787ZDOrenosGd6d352x652Y1btTVpV1bu2FbIINaSZIkqTXY9Q4snxuPNHjnFQgN0LEPnH09FE2DfhdAuv88kCSpKdU3BLbujYPXeDTB4SHsxsoq9lbXHXZORnoUjyDolM0FA7t9KITNy80mK8Nu2NORP4lJkiRJp6MQYMvbhzYDe3dJvN5jKFz4T3HnbK8RbgYmSdIncOBg3RFD2PeC2Hd3V1P3gXbYTtkZ9M7Npk/nHM4b0JW894WwvXOz6dY+k7Q0//vcGhnUSpIkSaeLhnpY/1oczJaXwa4KIIK+58Kkb8WbgXUdmHSVkiS1CA0Nge37a+JNuXZVfSiE3VhZReWB2sPOSU+LOKNjPAe29MzOqbmw8Uef3Gx65WbTPtM4Tkfmd4YkSZLUktVWw9oXYdlvYfk8OLAd0ttC/3Ew9isw5HLo0DPpKiVJanaqa+vZvPvIIWz8Uc3B+obDzmmf2SY1fiCLkf1yD+uEzcvNpkeHTNqkpyV0R2rpDGolSZKklqaqElY+F3fNrpoPB/dB2w4wZHI80mDQJZDVMekqJUlKTAiBXQdqU5tzVR02F/a9x9v3HTzsnCiCnh2yyMvNYnifXC4dlnVYCJuXm02n7IyE7kitgUGtJEmS1BLs2QzL58QzZ9f+HhpqoX1PGD49Dmf7XwhtMpOuUpKkU+JgXQNb9lSnZsMeHsa+97i69vBu2OyMdPJys+jdOYeheR3J65TdOJqgd242PTtm0baN3bBKjkGtJEmS1FxtXxmPNCifAxsXx2tdBsL5N8fhbO9SSPMflJKk00sIgT3VdY0jCTbtrvpQZ+zWvTWEw/foolv7THrnZlF4RgcuLuhxWAjbOzeb3JwMIjfRVDNmUCtJkiQ1Fw0NsOnPUJ4KZ7eviNfzRsLF/1ccznYviP82U5KkFqquvoEte2saxxC81xV7qDO2mn01dYed0zY9LdUNm81Fg7vHAWznQ2MJenXKIisjPaE7kpqGQa0kSZKUpPpaqPh9HMyWz4W9myBKh/yxcM6NUHA5dOqTdJWSJB23fTV1h8YQHCGEfXdPNfUNh7fDds7JoHfnbPK7tuOCgd3o0/nQXNjeudl0bdeWtDR/UanTm0GtJEmSdKrV7Is3ASufAyuegZrdkJEDgyZC4d0weDLkdEm6SkmSPqShIbBtX80RQ9iNldVsqqxid1XtYee0SYvolZtFXqdszu3fpXEkQV7jRl1Z5LQ1opL8X4EkSZJ0KuzfDsvnQXkZrH4B6msguwsUXQGFU2HgBMjITrpKSVIrV3WwvnEmbONogsr3HlezeXcVtfWHd8N2yGrTOAe29MzO75sNm0Xv3By6d8gk3W5Y6ZgMaiVJkqSTZVdF3DW7rAzWvwqhATr1g9JZUDQN+p4H6f5ILkk6NUII7Nh/MO6A3XVoFMHGygOpz1Xs3H/wsHPSIjijYxZ5udmc1TeXqSW9Dgthe+Vm0TErI6E7kk4vx/ypMIqivsB/AmcADcAjIYR/j6KoC/BLIB+oAD4bQtiVOufrwBeAeuDWEMIzqfVRwKNANjAX+HIIH9yjT5IkSWqhQoB3/5qaNzsHtvw1Xu85DC76WrwZ2BnD3QxMknRS1NTV8+7u6iOGsO+NJ6ipazjsnJy26Y0bcg3v06lxFEHv3BzycrPo2TGLjPS0hO5Ial2O59f3dcA/hRD+FEVRB+CNKIqeA2YCC0II346i6E7gTuCOKIqGAtcCxUAeMD+KoiEhhHrgYeBG4FXioPYyYF5T35QkSZJ0yjTUw7pFqXC2DCrXARH0Ox8m3x+PNejSP+kqJUktXAiB3VW1bDjC5lwbU4+37a350Hk9OmSSl5tNUV5HLhnak7xOWfTunJMKY7PplJ1B5C8QpWbhmEFtCGEzsDn1eG8URcuA3sCngPGpwx4DFgJ3pNZ/EUKoAdZGUbQKOCeKogqgYwhhEUAURf8JXIVBrSRJklqa2ipYszAeabBiHhzYAemZ8ZzZi772/7d3r8GR5fd5359/X083+o47MDM7mNmZ2V3ucm+jXZKiSIqURIm2RMWxHMmKxMRysaLItpRUKpKjqtgvZafMqthJRcXEKkkpXUjHdolO5JIUx4krFVnUckVdeJP2xuUMZoCZAdDdAPre/7w4p0+f0xcMMAOgcfl+qrrQfbob05g924N55sHvJ13/PikzO+lXCQA4RVqdru6W3eZrb0bs7UATdnWrpt1mJ/ScZCzizoYtpvTRG3PuSIJiyg9hF/KOkrHohL4iAAd1oIFYxpjLkl6U9AeS5r0QV9baO8aYOe9hy3Ibsz23vGMt7/rgcQAAAODkq21Kf/670tf/lfTGv5Fau1IyL13/HnekwZMfk5LZSb9KAMAJVam3+k3YESHsWqWu7sBwyOmphJYKKT05m9GHrs1qqeDogr+oK6XSVII2LHCG7DuoNcZkJP1zST9jra3s8UYw6g67x/FRv9an5Y5I0KVLl/b7EgEAAIDDVb4tfeO33ZEG7/y/UrctZRak53/EXQb2xAelWGLSrxIAMGGdrtV6tRe8ujNigyHs7a2aqvV26DnxqNFi3g1cP3B1xl3O5YWwS4WUlvIppRK0YYHzZF9BrTEmLjek/TVr7b/wDq8ZYxa9Nu2ipHXv+C1JFwNPvyBp1Tt+YcTxIdbaz0r6rCTdvHmTZWMAAAA4HtZK9/9c+tq/cmfOrr7uHp++Jn3gb7vN2aWXpAhLVQDgPNlttodC2NWtmm55H++W62oP1GHzqbiWCyldKKb16kopFMJeKKQ0k0kqEqENC6DvoUGtcauz/1TS16y1nwnc9QVJn5L0C97H3woc/3VjzGfkLhO7JumL1tqOMaZqjHmf3NEJPy7pnxzaVwIAAAA8im5Xuv0ltzX79f9devCGe3z5pvSxv+eGs7PXJ/saAQBHptu1ur/TcJdyBZqwvTbs6lZNm7ut0HOiEaOFnDsH9uYTxaEQdrGQUiZ5oGmTALCvRu23S/oxSX9qjPmyd+y/kRvQft4Y8xOS3pX0Q5Jkrf2KMebzkr4qqS3pp6y1vWnXPynplyWl5C4RY5EYAAAAjl9zV3r3/3Nbs1//bWn7rhSJSZe/Q3rfT0o3PiHllib9KgEAh6De6uhOeXwIu1quq9nuhp6TSca0XHCXcr1wsaDlYsq77X6cyyYVi/LTFQAOl7H2ZE8WuHnzpn3ttdcm/TIAAABwWnU70r2vu63Z3mXtq5LtSPEp6dp3SU99v3Ttu6VUYdKvFgBwANZabe66S7puBUYS9GfD1nV/uxF6jjHSfNbRUsFxg9deCJtP+c3YnBNjSReAI2GM+ZK19uao++jhAwAA4OywVtp6NxDKvi7d+WOptePe7+TdGbMf/C+ki69KKx+S4s5kXzMAYKxmu6u1Sn0ohO03YuuqtTqh5zjxiN9+fWYpp6V8KhTIzuccJWK0YQGcPAS1AAAAOL12N9wwNtiW3b3v3hdNSovvlV76MWn5ZfdSXGERGACcENZaVert/nKuck23N4Nt2JrWqw0N/iDwTCap5YKjGwtZfeeNuXArtpBSMR2nDQvgVCKoBQAAwOnQ3JXu/km/KXv7S9Lm296dRpq9IV3/uLT8khvKzr1HiiUm+pIB4Dxrd7paqzb8JuzwaIK6thvt0HMS0Yg/kuBD12b9mbC9kQSLeUdOPDqhrwgAjhZBLQAAAE6evebKSlLughvIvvyfuKHs4vOSk5voSwaA82a70e6PIhgRwt6t1NXphuuwxXRcS4WULk9P6QNXZ0Ih7FLB0cxUUpEIbVgA5xNBLQAAACYrOFd29XW3Lbv65dFzZZdfdgPa7MJkXzMAnHHdrtW97UZ/DMFmfzlX71i51go9JxYxWsg7Wi6k9OpKyR9J4LZi3ZZsOkEMAQDj8A4JAACA47XfubJL3giD0hXmygLAIas1O1otjwphd7W6Vdedck2tTrgNm3VibgO2kNLNJ4qhEHa5kNZsNqkobVgAeGQEtQAAADg6rZp050/CoSxzZQHgSFlr9WCn6YewvVEEvRB2daumBzvN0HMiRprPuW3YFy4W9JfeuxgKYRcLjnJOfEJfEQCcDwS1AAAAOBz7miv7ovTyp7y5si8wVxYAHkGj3dHdcn1sCHt7q6ZGuxt6TjoR1XLBbcA+u5zXhaI7E3Yp744nmM85ikf56QUAmCSCWgAAAByctVL5W4FQdmCubDLvtmSZKwsAB2KtVbnW0q0Ry7lue9fvVRtDz5vLJrVUSOnpxZy+65l5LeUdf0bsciGlfCouYxhLAOD41do1bdQ39KD2QBv1Df/yoPZA06lp/c3n/uakX+KJQVALAACAh3vYXNmF56QX/2MvlGWuLACM0+p0tVZx27Cr5d5ognogkK1pt9kJPScZi/ht2I/emNNSwW3D9kLYhbyjZCw6oa8IwHnT6XZUbpa1UdvQg/qDUPC6Ue8f692utWsjP08mntHN+ZvH/OpPNoJaAAAAhDFXFgAeWbXe8gPX21t1f1FXL4hdq9TVDe/o0vRUQkuFlK7OTulD12bdENZrwy4VUpqeStCGBXCkeq3XUeHr4O2txpa6tjv0OaImqpJT8i8X5y5q2pn2b0+npv3bRacoJ+ZM4Cs92QhqAQAAzrOhubKvS2tfCcyVXXYDWebKAoA6Xav1an1sCHt7q6ZqvR16TjxqtJh3G7AfuDqj5UJ/JMFSIaWlfEqpBG1YAIcr2Hod1XL1w1cvmB3Xep2KT/nh6sXMRT0/+3w/fE2VNO30w9dcMqeI4SeqHgdBLQAAwHlxoLmyL0lLL0m5xcm+ZgA4RrvN9sgQ9pb38W65rvZAHTafimupkNKFYlqvrpRCIexyIaXZTFKRCG1YAI8v2HodDF+DrdeN2oY2G5tjW69Fp9gPX+cuum1X7zat18kiqAUAADirenNlVwOzZXfuufdFE9LCe5krC+Dc6Hat7u80tNqbB7tZC4wocD9u7rZCz4lGjBZy7hiCm08Uh0LYpUJKmSR/rQbwaLq2q63G1r5arxv1De22d0d+nqn4lB+20no93fgTBQAAaHB2DwAAIABJREFU4CzYz1zZa9/Tb8rOP8tcWQBnSr3V0Z3y+BB2tVxXsx1ul2WSMS9wdfTCxYK/nKsXxM5lk4pFCTQA7N9erdfB8PVxWq+927RezxaCWgAAgNPGnysbaMoyVxbAGWat1eZuyx1DMDAXtjeq4P52I/QcY6S5bFLLhZSeXc7r4+9ZcNuw+ZTfjM05MZZ0AdhT13ZVbpT9oLUXtg7e3m/rNTjr1Q9fU4EQ1pmm9XqOEdQCAACcZPuaK/ui9MGfcUNZ5soCOIWa7a7WKvWhENZvw27VVWt1Qs9x4hG//fr0Ys6/7s6LTWk+5ygRI+gAMKzerofC1b3C14e1Xv2RA17rtXe713rtXWi9Yj8IagEAAE6S3Q1vpixzZQGcDdZaVert/nKu8vBogvVqQza8o0szmaSWC46uz2f1nTfmQiHsUiGlYjpOGxaApP21XoPjBw7aeh1cslVySson87RecegIagEAACZlcK7s6uvSxlvenUaauS49+d3uGIPll5krC+BEane6Wq82QsFrL5S97bVhtxvt0HMS0YiWCo6WCil96NqsPxO2t6hrMe/IiUcn9BUBOAnq7XooXN0rfN1qbKljO0OfI9h6LTklPT/3PK1XnGgEtQAAAMeh25HufSO87Gv9q1LXCy96c2Vf+nHmygI4UXYa7dAYAr8Zu1XX7a2a7lbq6nTDddhiOq6lQkpPTE/pA1dn+gu6iu7irpmppCIR2rDAedJrvY5arBUKYw/Qel3OLOu9s+8dG77SesVpQ1ALAABw2PY7V/bbf5q5sgAmqtu1urfdGAphb3sh7OpWTeVaK/ScWMRoIe+2YV9dKfkjCZaLKS0XHC3mU5pK8ldN4DzotV73E77ut/X63tn3+mHrYPBadIpKxVIT+EqB48GfngAAAI/rQHNlX5JKV5krC+BY1JodrZZHhbC7Wt2q6065plYn3IbNOjF3DEEhpZtPFEMh7FIhpbmsoyhtWOBMCrZeewFrL3gdFb6Oa72mY2k/YF3OLOu5mefGhq+0XoE+gloAAICD6M2VXQ2EssyVBTAB1lo92Gn6IextfxyBG8KubtX0YKcZek7ESPM5N3B94WJBn3huMRTCLhVSyjnxCX1FAI5CsPU6GL4OznvdrG+ObL1GTETFZNEPWJ+beW5owVbvPlqvwKMjqAUAABhnv3NlX/wxb4TBC5KTn+xrBnBmNNod3S3XA8u5wiHs7a2aGu1u6DmpeNQLXlN6djmv5YLjzoXNuyHsQt5RPEpzDTjNhlqvAwu2BsPXnd7opQHpWNoPWHut13HhK61X4HgQ1AIAAEgDc2W9MQarf8RcWQBHwlqrcq0VCGFrWi3XA83YmtarjaHnzWaTWi6k9PRiTh97es5f0rVUSOlCMaV8Ki5jGEsAnDZ7tV4Hw9eHtV5LKXe8wKjgtXeb1itwMhHUAgCA82lfc2V/1Jsr+zJzZQEcSKvT1VrFDV7dGbF13eoFsl4bdrcZDlqSsYgfvH7kxqyWC2ktFRx3XmzRbcMmY9EJfUUADqJru6o0Kv0RAwOt143aRiiIfZzWa+82rVfg9COoBQAAZ1+rJt390/AIA+bKAngM1XrLb77e3qr3W7FeCLtWqasb3tGl6amElgopXZmd0ndcmw2FsEuFlKanErRhgROs0WmEAtbB1mvwvsdtvZZS7kdar8D5QlALAADOlofNlc0uMVcWwJ46Xat71YZub+2ODGFvb9VUrbdDz4lHjRbzKS0VHL3/6rQueM3YXgi7lE8plaANC5wko1qvoaZrrT8Ddj+t11KqtGfrtZQqqZAs0HoFMBZBLQAAOL2slcq3AqEsc2UBPNxus+03YVcDM2J7Iezdcl3tgTpsPhX358C+ulLy58L2FnfNZpKKRGjDApM2qvUaDFuD4et+Wq8lpxQKXktOqR++0noFcMgIagEAwOnx0LmyzzFXFjjnul2r+zsNrQZC2NuBNuzqVk2bu63Qc6IRo4Wco6WCo5tPFIdC2KVCSpkkf3UCJqHXevXnvHqt12AD1l+4VX+wr9brYmZRz848OzZ8zSfyikZowAM4fny3AQAATqY958pKmrkxMFf2PVIsObnXC+BY1Fsd3SmPD2FXy3U1293Qc6YSUT90feFiwW/G9gLZ+WxSsSj/qAMcl17rddSSrcEgdr+t1/fMvEfTzrQfvPrha6qkYrKodDw9ga8UAA6GoBYAAExecK7sqteWXfsKc2WBc8Zaq83dlla3aro1MBe2N6rg/nYj9BxjpLlsUkuFlJ5dzuvj71lw27C9VmwhpVwqxpIu4Ajt1XoNjh/oXbZb2yM/TyqW8sNVWq8AziOCWgAAcLwOMld26SU3oM0tTfY1AzgUzXZXa5W6Owt2YC6sG8rWVWuFm3NOPOIHrk8v5oZC2IW8o0SMNixw2PZqvQ6Gr5v1TbVte+hzRExEhWTBD1ppvQLA3ghqAQDA0QrNlfXasjvr7n3MlQXODGutKvV2fzlXuRYKZFe36lqr1mXDO7o0k0louZDS9fmsPnJjLhTCLhdTKqbjtGGBQzDYeh0MWw/Seu21XEe1XoPhK61XADgYgloAAHB4HjpX9rr05Me8UPYlaf5Z5soCp0S709V6tRGaCRsMYW9v1bTdCDfqEtGIlgqOlgopffDajBu+BhZ1LeYdOXFCHOBRDbZe9wpfH9Z67YWto1qvvftovQLA0SKoBQAAj6bbke7/eTiUZa4scGrtNNqhMQSDIezdSl2dbrgOW0zHtVRI6dJ0Wu+/Oh0KYZcKjmamkopEaMMC+9W1XVWbVT2oPRgZvA7efmjr1Rnfeu3dpvUKACcHQS0AAHi4UXNl73xZanp/QUzmpKUXpQ/8nX5blrmywInR7Vrd224MhbC3vRB2daumcq0Vek4sYrSQd9uwr6yU/BB2qeDoQjGlxXxKU0n+OgE8TKPT0GZ9c1/h60Fbr6PC15JTovUKAKcU31kBAIBhuxvugq/eTNlRc2Vf+OvMlQVOiHqrMyaE3dXqVl13yjW1OuE2bNaJ+eHrzSeKoRB2qZDSXNZRlDYsMMRvvQ4s2BoXvh609ToqfC0kC7ReAeAcIKgFAOC8Y64scKJZa/Vgp+mNIajp1qY7jmA1MKbgwU4z9JyIkeZzbhv2hYsFfeK5RS0XHG8kgXvJOfEJfUXAyTOq9ToufN1X69UJt15Hha+0XgEAgwhqAQA4Tw40V/Yld5wBc2WBI9Vod3S3XA8s5wqHsLe3amq0u6HnpOJRLRfdxVzPLuf7IWzeDWEX8o7iUVruOL+stao0K0Ot13Hh635arwvpBb1n+j1DM15pvQIADgtBLQAAZ9XgXNnVP3IvzJUFjo21VuVaKxDC1rRaruv2Zj+IvbfdkA1PJdBsNqmlQkpPL+b0safn3AVdvUVdhZQK6biMYSwBzpdmp+kGq4Pha20j1ILdqLkf99N6fab0TChspfUKAJgkgloAAM6K/c6VXXrJDWann2SuLPCYWp2u1ipu8Lpadtuwt3qBrHfZaXZCz0nEIlr2AteP3Jj1w9deELtYcJSM0crD2RdsvY4KWwdHEOy39frM9DNjRw7QegUAnGQEtQAAnEahubJeMLvxZv9+5soCh6Jab2nVW8p1uzeSYLM/kmCtUld3oA1bmkpouZDSldkpffDaTCiEXS6mND2VoA2LMyvYeg2FrV7LdfD2qNarkVHRKYZar6VUaWz4SusVAHBWENQCAHDS7Xuu7I+6wSxzZYF96XSt7lUbY0PY21s1VevhECkWMVosOFoupPT+q9O6UOgv5+rNiE0laOvh7BhsvY4KW4O3q63qyM/jRB1Np6ZDrddeEFtySqEgltYrAOC8IqgFAOAk6c2VXX2935ZlrixwINZaVeptrVXq3qWhtUpd673r1brWvWPtgTpszolpuZjWhWJKr6yU/CbsUiGlC8WUZjJJRSO0YXG6DbZeR14/hNZr8DatVwAAHo6gFgCASapteqMLRsyVjcTdubLP/4gXyjJXFthttrVWaehuua716mAQ64awa5W66q3u0HNzTkzzOUfzOUevXpnSfM7xxxIsF1NazDvKOvEJfFXA4xnVet0rfH2c1mvvNq1XAAAO30ODWmPML0n6y5LWrbXPesdKkj4n6bKkdyT9NWvtpnff35X0E5I6kv6OtfZ3vOMvS/plSSlJvy3pp60d3G8LAMAZ1qoH5sp+ibmyQECj3dF6paH1al13y27wGmy+9oLYamO42ZeKR7WQdzSXTer5CwXN55Kazzmayzmazyb9cJaRBDhNRrVee0u1RgWxB2m9jgtfab0CADBZ5mFZqTHmQ5K2Jf1qIKj9h5I2rLW/YIz5OUlFa+3PGmOekfQbkl6RtCTp/5R03VrbMcZ8UdJPS/r3coPaf2yt/dcPe4E3b960r7322qN/hQAATMJD58ou9gNZ5sriDGt3urq/3QyMIeg3YNeqDW8cQV2bu62h5yaiEc3lekFrUnNZx7/uH8s5yiZjLOfCiddrve4VtgZvP6z16oetXtA66nYxWaT1CgDACWOM+ZK19uao+x7aqLXW/jtjzOWBw5+U9BHv+q9I+r8l/ax3/DettQ1Jbxtj3pD0ijHmHUk5a+3vey/oVyX9oKSHBrUAAJx41kqV24FQlrmyOPu6XauN3abfdL0bCGHXvTbsWqWh+9sNDfYCIkaazSa1kHN0sZTWzctFzXshbD+YdVRMxwlgcaL1Wq9D4Wvgun+7saF2d+/Wa8kphVqvg+ErrVcAAM62R51RO2+tvSNJ1to7xpg57/iy3MZszy3vWMu7Pnh8JGPMpyV9WpIuXbr0iC8RAIAjMjhXdvV1aXvNvY+5sjjlrLWq1Nr+rNe1wOiBYBC7Xm0MLeKSpOmphOZyjhZyST27lHfHD+SSfhA7n0tqmoVcOKHGtV7H3d5P63UuPaenp5/uL9hySv59JaekQrKgWITVIQAA4PCXiY36jtvucXwka+1nJX1WckcfHM5LAwDgEexnruzVj0pL3giDBebK4uTaabT7YasXxN4tN7xZsP1QttHeexHX1aszQ+MH5nOOZjNJJWL8owROlnGt15EjCOr7a70+VXoqFLYGw1darwAA4FE9alC7ZoxZ9Nq0i5K89dS6Jeli4HEXJK16xy+MOA4AwMkRmivrtWXX/qw/VzazIF24Kb34o24ou/iClCpM9jUDkuqtju5VGwPzX+taK3u3vaVc2w9ZxPXipYI7fiCwgKs3G5ZFXDgpgq3X/YSv1eb+Wq+D4SutVwAAcNwe9buNL0j6lKRf8D7+VuD4rxtjPiN3mdg1SV/0lolVjTHvk/QHkn5c0j95rFcOAMDj2Ndc2RekD/zt/ggD5srimLU6Xd3fbvjh6/rAIq61shvIbj1kEddTC1l9+PpsfxFX1vFHEmRYxIUTINh63St8fVjrtZAs+AHrU6WnhgLXYOs1FUtx7gMAgBPloUGtMeY35C4OmzHG3JL09+QGtJ83xvyEpHcl/ZAkWWu/Yoz5vKSvSmpL+ilrbcf7VD8p6ZclpeQuEWORGADg+PTmyq4GZssyVxYT0u1aPdjxFnFVg3Ng+4u47pYberAzvIgrGjGazSQ1n0vq0nRa37ZS1ELO8ccP9ILYAou4MEGDrdeHha/jWq/JaFLTzrSmU9MjW6/B8JXWKwAAOO2MHfzu/4S5efOmfe211yb9MgAAp8nD5spOX+sHssyVxSGy1qpca4UWcK1XG/texDWTSWgu6/jzX92lXMHbSU1PsYgLk9HqtPxG69gZr7X9tV7HNV2Dt2m9AgCAs8gY8yVr7c1R9/FPzgCA063bke7/RTiUZa4sjsBOo627vfA10IDdzyKufCruh61XZ4OLuPoh7AyLuHDMrLWqtqrhpmsgaD1o6zU463Vc+ErrFQAAYDy+SwIAnB4PmyubyErLLzJXFgfSW8R1d6D1Gg5iRy/iSiei3tiB4UVcC3nHmwWblBNnEReOR6vTCs1yHQxfH6X12gteh1qw3kgCWq8AAACHg6AWAHBy1TbdILYXyo6dK/uSN1f2GnNl4est4rpb9sLXaiB87bVixy3iikX8Wa9PL+T04evJ0PzX+bzbhs0k+VYKR6vXeg01XQ+x9TpqBAGtVwAAgMngOzAAwMmwn7myV76TubIILeIKBa8DS7nGLeKayyY1l3P0xHRar6yUNJ9LhhZxLeQc5VMs4sLR6bVeQ83Xx2i9llKlUPBaSpX8UHbamVYpVVI6luacBgAAOOEIagEAx+8gc2WXXpKWXmSu7DkwuIjrbiU8+3Wt6o4kuDdiEZcx0vRU0p/3+t4LeW8pF4u4cPSCrddg0BoMYg/aep1Nz44PX2m9AgAAnEl8dwcAOFoj58p+WeoFFcyVPRe2G22/AbteaQwv5fLasM0xi7h6c2CvzfUXcblBrHt9NptUPMrYCxyeca3X3u1gGLtR31CrOzxCY7D1eqN0ww9bS6nAnFdarwAAABBBLQDgsD10ruyz0vP/UT+UZa7sqVZvdfxZr8ExBGsDQexOszP03KlE1J31mnX08qWi13rth68s4sJhanaa2mpsabO+qXKjrM3GprbqWyNbrxv1DVWalZGfZ7D1OjZ8pfUKAACAA+I7RwDAowvOlV31QtkHb/TvZ67sqdXqdHWv2p/3OriIq3e9XBu/iGsh5+jppZw+cmMuNH5gPsciLjyeWrumrfqWG7w2vODVC2BHHdtsbKrWro39fIVkwW+13ijdCM12pfUKAACA48LfkAAAe7NW2rknbbwtbb7tftx4S7r/DWntK4G5svPS8k3p+R9xQ1nmyp5Ina7Vg51Gf+TAwCKuu2X3+oOd5p6LuFZmpvS+K9PeCIJ++DqfS7KIC/tmrdVOa0dbjS3/EgxXe8HrYCjb6DTGfs5cIqdCsqBCsqDZ9KyuFa/5t/PJvIpO0b9dSBZUdIq0XgEAAHAi8F0pAMBd7lX+VjiM9T++IzW3Aw82Uv6CVLrSnyu79JI7V5ZwbmKstdrabfmzXtdGLOJaK9d1b7uhzh6LuBbyjp6/WOiPH8gl/aVc01MJRVjEhTG6tqtqs9oPXetbIwPY4LGtxpbavX/sGRAxEeUTeT9cXcws6pnpZ9yA1QkEr8mifyyXyBG6AgAA4NTiO1kAOC9aNWnzm14A+1Y4jN16VwouwokmpMITbhh7+YNScUUqrbgfi08wvuCYVestd/xAYOnWfhdxFdJxf9brtbkZLfTC10ADdibDIi6EtbttVZqVobB1VLu1d6zcLKtrh89BSYqZWChcvZy7HGq39gLX4LFsIquI4bwEAADA+UFQCwBnSW1zdCt2422puhp+bDInFS+7c2Of/n43lO2FsbklKcICp6NWb3WG5r6uVxv7WsSVScbcea+9RVzeUq75wDKu2SyLuCC1Oq3hsHVE4zV4bNwiLUlKRBJ+6FpMFnWtcE1FpxgKW3sjBXrHpuJTjMMAAAAAHoKgFgBOE2ul6t3RrdjNt92gNmhqzg1gr3w43IotXZHSJUYVHJFWp+sHrqHxAwNLuUYt4krGIn7Y2lvEtZD3FnFl+21YFnGdT/V2fd9ha++y09oZ+/lSsVQoXF2eWfYD2F671Q9evWOpWIrQFQAAADgC/C0PAE6aTssdRRCcEeuHsu9Iwc3lJiLlL7oB7DM/GG7FFi9LycyEvoizqbeIa60cHjmwXqmHgtj7282h58bGLOIKNmDns45yqRgh2DlgrVWtXdNmY3N04DoigC03yqoF//8fkI1n/XC16BS1kl8JLczqBa7+xSkoGWWMCQAAAHBSENQCwCQ0dwYC2EArdutbkg38qHvM6bdhr3400IpdkQqXpGh8Yl/GWRFcxHW3PDz7tdeKHbeIaybjLuJaHFjEtZBzZ8PO5xyV0iziOqustaq2qirXy/3gdVzbNXC81R1uVEuSkVEumfMbrPPped0o3ggt0Qq2YAuOO+M1HuG9AAAAADjNCGoB4ChYK+1uDMyKDYSy22vhxzsFN3hdfll69q+Gw9jMghRhoc6jsNZqu9H2w9a7gTEE6wNLuZqd4SVIxXTcHTmQc3R9Puu3X+e8EHY+52gmk1CMRVxnRqfbUbVZ9RdlbdY3h1qto4517PAcYUmKmmioyXoxc1HPzTwXbrYGAthCsqBcIqcoM6IBAACAc4egFgAeVbfrLuga1YrdeFtqDCzjyS65weuT3y2VLrtjCnphbKo4kS/hNKs1O6GwdS0wfiC4lGt3xCKubG8RV87Rt10u+Uu5WMR1trS6LZUbZT9cLTfKQ6MGesd6HyuNiqzsyM8Xj8RDoerVwtWxYWvvdiaeUcQQ5AMAAAB4OIJaANhLu+HOi93wGrHBMHbzm1Kn0X9sJOaOIiiuSBdeCbdii5eleGpiX8Zp0mx3dW87vIjrrhfCrgdC2Uq9PfTcZCyihbw76/U9Szl99Kk5P3idyzpayDuayyY1xSKuU6fRaYwOV8c0XsuNsqqt6tjP50SdULC6OLXoz3cdF76mY2nmBwMAAAA4MvxNFQDqleE27MZb7gzZ8i0p2K6LT7nB68x16frHvSDWW+CVuyBFeVsdp9O1erDd6DdgA/NfeyMJ1it1PdgZvYhr3pv3enU2ow9cndbc4CKunKOcwyKuk663RCvYbh3VeB08ttcSran4VChcvZS7pELSndtaTBZHtl2dmHOMXzUAAAAAPByJAoCzz1pp597oVuzG29Lu/fDj09NuAHvp/QOt2BUpM+duj4LPWqvN3ZbfdB1cxNU7fq/a0MAeLhkjzWbcoHW54OjFSwVvBEHSD2ZZxHVyWWu109oZarcONl7LjbK2Glv+sUawiT4gm8j64epselbXitf8cHVk4zVZUJyFegAAAADOAIJaAGdDpy1Vbo1uxW6+IzW3Aw82Uv6CO47gqU+EW7HFFcnJTeZrOGG6XavN3abWqw334s19XQ/Mf12rNHSvuvcirvmco6cWsv5SrvlsvwHLIq6To2u7qjarbqAaCFd7l2C7NXhfuzs8gkKSIiaifCLvh6uLmUU9M/1MKHAdbLzmEjnFInxrAgAAAOB84m9DAE6PVs2dCzuqFbv1TSkYGEUTbhBbXJEuf7Dfii1dcefIxpIT+zImrdO1erDT0LoXsvYWcq1XvTZstaF7lbrubTfU6gwvVco5Mc3l3Fmvr66UvBEESX8MwVzWbcImYyzimpR2t61Ks+LPdA2FqwPHegFsuVlW1w4H7pIUM7FQuHo5dznUbu0FrsFj2USWJVoAAAAAcAAEtQBOltrmQCs2cL26Gn5sMueGsQvPSs/8QHhEQW5JipyvoLDd6er+djMQuLof16sN3QuEsfe3m+oMziCQ24DthaxXZ6e9BVxu8BoMYJ34+fp9nbRWpxVqto5rvAYD2Gpz/BKtRCThz2stJou6Xrwearfmk3n3vsCxqfgUs38BAAAA4IgR1AI4XtZK1bujW7Ebb0n1rfDjM/Nu8Hrlw+FWbHFFSpfOxbzYZrure9vDowfWK8E2bEMPdhqyw/mrZjIJzWbd0LU/giDphrBeGDubpQF7HOrt+thwdbDx2rvstHbGfr5ULBWa1bqcXfYD2F7gGgxl88m8UrEUoSsAAAAAnEAEtQAOX6clbb07uhW7+Y4U3N5uIlL+ohvAPvtXwq3Y4mUpmZnUV3Hk6q2OP3qgt4DLnwcbCGQ3dppDz40YaSaT1FwuqYW8o+cv5v0wdi7bX8Q1k0kqzgzYQ2et1W57dyhwfVjjtd6pj/2c2Xg21GZdya/4AWyv3RpaouUUlIye3xEeAAAAAHDWENQCeDTNnYEA1mvEbrwtlW9JttN/bCzlhq6lFenqR71WrBfGFi5JZ2xj+26z7Y8cCM5/vecd6wWy5Vpr6LmxiNGsF7ZeKKb18hNFf+SAH8Bmk5rOJBWN0Io8DNZaVVvVobB16PbA8VZ3+L+fJBkZ5ZI5v8E6n57XjeKNULt1MHDNJ/OKR87W/wcAAAAAgIMhqAUwmrXS7kY4gA2Gsttr4cc7BTd8vXBTeu6H+kFsaUXKLEiR09/q3G603ZA1sHhrvdofQ7DmhbHVRnvouYloxA1gc0ldmZ3S+69O+/Nf5wLzX0vphCIEsI+s0+24S7T2ClsHjpUbZXWC/7AQEDXRUJP1Yuainpt5bmTY2rueS+QUPWfzkQEAAAAAj4+gFjjPul2pcnt0K3bzHalRCT8+u+QGr09+t1S63J8VW1qRUsVJfAWPzVqrSr3dn/vaa8COCGN3m8NhXjIW8VuuTy1k9aFrs/3gNdCALaTjzAU9oFa35Y4PqPcD1c3G5tiwdbOxqUqjIqsRg3olxSPxUKh6tXB1bNjau52JZxQxp/8fGQAAAAAAJx9BLXDWtRvuvNhRrdjNd6ROYP5pJOaOIihdkS6+Gm7FFi9L8dSkvooDs9Zqa7elNT9sdUcO3BsYR7BeaajR7g49P52Iaj7naDab1LPLeT9wncslNe+1X2ezjnJOjAB2Hxqdxuhwtb45su1abpRVbVXHfj4n6oSC1cWpReWTeRWd4ti2azqW5r8VAAAAAODEIqgFzoJ6ZXwrtnxLCjYM41Nu8DpzXbr+8XArNndBip7st4Vu12pjt+nPeb0XWsLlBrD3qu6l2RkOYLPJmN94felS0Q9gZwPt17mco0zyZP8+TErXdrXd2lalUemPGNhH47UWXCA3YCo+FQpWL+Uu+cuzisniUNs1n8wrFTs9/2gAAAAAAMB+kEQAp4G10s690a3Yjbek3Qfhx6en3QD20vvDrdjSFWlqVjqBrcJO1+rBdqPfdA0s3QqOIbi/3VC7O/yj7YV03J/5emVmSnPBBmzvetZRKnG+Z4daa1Vr11RpukFrtVlVtVn1r/sfG959rWro9nZre+xoAUnKJrJ+uDqbntW14jU/cB3VeM0n80pEE8f4OwAAAAAAwMlEUAucFJ22VLk1uhW78bbU2gk82Ej5C+44gqf+UrgVW1yRnNyEvohhrU7XGzfQ6M+B9Rt5z20fAAASKklEQVSw/TD2wXZDI/JXTU8lvCVcjq7PZzUfmP86F2jDOvHzE8A2Og0/PB0Vtvauj7tv3OKsnnQsrVwyp2wiq2w8q4X0gq4Xr7u3vWO9+4ON11wip1iEP1YAAAAAAHgU/I0aOE6tWj94HWzFbr0rddv9x0YTbhBbXJEufzDcii1ckmLJSX0VkqRGu+PPfr0XbMAOzIPd2G3KDgSwxkgzmaTXck3q2aW85nNJzebCC7hmMkklYmdvkVOr29J2c3u4xRoIU/0gtlUJ3a42q2p2m3t+/mQ0qWwiq1zCC1Odop7IPRE61vvYu967nUlkCFsBAAAAAJgA/jYOHLbGtteGfVN68GY4lK2uhh+bzLlh7MJz0jOfDLdic0tS5PhborVmxx890Bs50BtHcC/QgN3abQ09Nxoxms244wYuFFN68VIx1ICdz7lLuKanEopFT28A25vTGhoRMKLFOnKkQLOy57xWSYqZmBugJnPKxt0wdSG94LdYc4lcqNUaDFyziayS0cmG+AAAAAAA4OAIaoFH0dx1w9cHbwYC2bfcj9t3w4/NzLvB65UPh1uxxRUpXTq2ebE7jXZg6dbwGILefdV6e+i58ajRXNbRbDapy9NTemWlpPms4y/l6n0sTSUUjZy8+beDxs1pHTkyoFHx57T2bj9sTquRUSaRCTVVBxutg03W4LFULCVzAucIAwAAAACAo0NQC4zTqntjCt4MB7IP3hxuxk7NSqWr0pMfc0PY6avu7dIVKZk5spdorVW10XbDVm/kQL8BGw5jd5rDc0kTsYjfeL0+n9V3XJt158F681979xVScUVOWADrz2ltVkKt1nEt1uD91WZVbTscSAelY+lQeLqQXtC1wrWhpmsukRtqtmbiGUXM6W0MAwAAAACA40dQi/Ot3ZS2vjnQjH1TevCWVP6WFGxNpkpuALvyIS+IvdL/6OQP9WVZa1WutfyRA+uVhta8j/eq4XEE9VZ36PmpeNQPWZ9Zyuk7b8x5rVf3WO++XCo2seZmb07rwxZfjWy9Nir7ntPauxSdoi7lLo2c0ZpNZJVP5PtBayKjeCR+TL8TAAAAAAAABLU4Dzotd1FXbzxBMJDdeleygaDTybtN2EuvSqW/3m/GTl+RUsXHfindrtXmbrMfwIZar/0A9t52Q832cACbScb8wPWFi4X+/NdcUrOBJVyZ5NEHsME5rcFZrWMXYw0Errvt3T0/f9REh8LUhfTC+IVYyRxzWgEAAAAAwKlFUIuzodtxQ9deG3YwjO0Gfsw9kXWD16WXpOd+yAtivUD2EWfGdrtWG7vNQPO133gNjiG4V22o3R2ebZpzYv6irVdWSv3Zr94Ygt596cTh/S8bnNO6n5EBg9e3m3vPaZU0tPDqUvZS6PbgjNZgAMucVgAAAAAAcJ4Q1OL06Halyq3weIJeILv5jtRt9R8bn3JHEiw8Jz3zg4Fm7FV3nuw+A0BrrTZ3W1qr1PuLuLwQdm1gGdeoALaYjms+5y7henJ2xmvAuvNfewHsbDYpJx59pN+S4JzWYKu12qyq2nJvjwtb9zOnNRVLhQLV+fS8P6d1bNjqBbFTsSlFI4/2dQEAAAAAAJw3BLU4WayVKqvD82I33pQ23pY6jf5jY44bxs7ekJ76RLgZm13YM4wNzoAdGcIG5sE2O8MjCArpuB+0XvUC2N7Ygd4SrtlsUsnY3kFlu9tWtVnRem2PeayBY4P3NYK/HyPEI/HQsquCUxhqtY4aJZBL5JjTCgAAAAAAcIwIanH8rJW210Yv8Np4S2rX+o+NJqTiihvAPvld4WZsdkmKRAY+tVWl1tba+rY7hqBS90PX3hiCXig7agZsbwTBfM7RqytTfug6P6YB2+l2QnNaq811bTWrevd+RdXVfqv1cea0Doap8+n5oRbryLA1mWNOKwAAAAAAwClBUIvD16pL1TvupbLqXqp3pMptb5nX21Jzu//4SFwqXnbD1ysfcefH9sLY3LIUicpaq2qj3Z/9+lZda5W33dA1EL6uVepqjAhgs94Srvmco2+7XPKbr7PZmLLpjtLJphKJppp2xwtRN/xAda21rTfKVVXuhdus1WZV263toV9r6NeOZ0NjAS5lL4Vu+4FrPLwQizmtAAAAAAAA5wdBLfbPWqm2GQheAwFsJRDM1jaGnxufknKLbjv2iW/3glg3kN1OLWptux0OXdcaWqvc13rltt+IrbU6Q582k4xoJtdVKdvRkxc6eiHdVtpxQ9d4rC4TrctGdlVr7/iN1m+2qvqz3aqqW1XttHYe+mUHg9ZsIqvlzHKouTqq1dq7MKcVAAAAAAAA+0FQi7B6xWu9vul+fPCWtPVNN4yt3pXa9YEnGHc5V25Ryl+ULr4im11UPTWvSnxGG5FZrZuS7jWT2thtuiMIthpae7fXgP2adtuvyURrMtG6TMT9mIg3lJtqK+U0lVpo6uqFuiKxurqmprbdVaO7o932tnbbO3og6YEktSSVvYv/6owyiUwoQL2YuRgKVjOJDEErAAAAAAAAJoqg9jyql70Q1lvQ5Yeyb0q798OPzS6pW3hCjbkXtXNxTuX4jDYi01pXSavdkr7VTOtObUcPdivaelBWebWi7VZVXfOWFPlqKHw1ETdsjccbMoma7GxN0dm6smNe5o6kXRllYuGgNRufGxusDh6bik8pYiJjfgUAAAAAAADgZDj2oNYY872S/ntJUUn/i7X2F477NZxpzV1p5560c99twZa/Jbv1rtob78puvStbuaVWs6zdSEQ1Y1SLGN1PlHQ3PqO76ad0N5PTHZvWXZvSejemnU5bnWpdZqcuE31LJvIVyQtdTbQmE2m5v27Su0hKBF6OUUTp2JQyiaxyiawKzowfpGbiGYJWAAAAAAAAQMcc1BpjopL+R0nfLemWpD80xnzBWvvV43wdp0KnLdW3ZHc3tFO+p3J5VVvVu6ps31e19kC1+qZ2mxXV21XV2ttqdHfVUEO7pqvtSETViNGuiWg3YrRjotqJRFRLG7UyOUm5Eb9g1buEJRRXPJJSKprRVDyjXGJW+WRWpVRe0+n8yKDVHykQzygdTxO0AgAAAAAAAA9x3I3aVyS9Ya19S5KMMb8p6ZOSCGol/f4f/R/6zB/+12qarupGqkWMdo1RIzIm6Iwp8F8wqng3o4SNKWGTipu0YpGMovGc0vGcZuJpZeJp5ZJTyiWnlHemVExlVEpnVXSmlI6nlY6llYqlQhfmswIAAAAAAABH77iD2mVJ3wrcviXp1cEHGWM+LenTknTp0qXjeWUngDM1rYgyyiiholKKK6VEdEqJeFbJRF7pVElTqVll0jMqpgsqpXOaTmW1kC1oOp1TIpp4+C8CAAAAAAAA4MQ57qDWjDhmhw5Y+1lJn5WkmzdvDt1/Vr14/X363PU/mPTLAAAAAAAAAHDMjnt46C1JFwO3L0haPebXAAAAAAAAAAAnynEHtX8o6ZoxZsUYk5D0w5K+cMyvAQAAAAAAAABOlGMdfWCtbRtj/pak35EUlfRL1tqvHOdrAAAAAAAAAICT5rhn1Mpa+9uSfvu4f10AAAAAAAAAOKmOe/QBAAAAAAAAAGAAQS0AAAAAAAAATBhBLQAAAAAAAABMGEEtAAAAAAAAAEwYQS0AAAAAAAAATBhBLQAAAAAAAABMGEEtAAAAAAAAAEwYQS0AAAAAAAAATBhBLQAAAAAAAABMmLHWTvo17MkYc0/SNyf9Oo7AjKT7k34RONM4x3DUOMdw1DjHcNQ4x3DUOMdw1DjHcNQ4x3DUzuM59oS1dnbUHSc+qD2rjDGvWWtvTvp14OziHMNR4xzDUeMcw1HjHMNR4xzDUeMcw1HjHMNR4xwLY/QBAAAAAAAAAEwYQS0AAAAAAAAATBhB7eR8dtIvAGce5xiOGucYjhrnGI4a5xiOGucYjhrnGI4a5xiOGudYADNqAQAAAAAAAGDCaNQCAAAAAAAAwIQR1E6AMeZ7jTHfMMa8YYz5uUm/HpwsxpiLxph/a4z5mjHmK8aYn/aO/31jzG1jzJe9yycCz/m73vn0DWPMxwPHXzbG/Kl33z82xhjveNIY8znv+B8YYy4HnvMpY8xfeJdPHd9XjuNkjHnHOze+bIx5zTtWMsb8nvff/veMMcXA4znHsG/GmBuB96ovG2Mqxpif4X0Mj8MY80vGmHVjzJ8Fjk30fcsYs+I99i+85yaO+vcBR2fMOfbfGWO+boz5E2PMvzTGFLzjl40xtcD72S8GnsM5hpHGnGMT/bORc+xsGXOOfS5wfr1jjPmyd5z3MRyIGZ9V8P3YYbLWcjnGi6SopDclXZGUkPTHkp6Z9OvicnIukhYlveRdz0r6c0nPSPr7kv6rEY9/xjuPkpJWvPMr6t33RUnvl2Qk/WtJ3+cd/88l/aJ3/Yclfc67XpL0lvex6F0vTvr3hMuRnGfvSJoZOPYPJf2cd/3nJP0DzjEuh3CuRSXdlfQE72NcHvNc+pCklyT9WeDYRN+3JH1e0g97139R0k9O+veJy6GfY98jKeZd/weBc+xy8HEDn4dzjMtBzrGJ/tnIOXa2LqPOsYH7/5Gk/9a7zvsYl4OeX+OyCr4fO8QLjdrj94qkN6y1b1lrm5J+U9InJ/yacIJYa+9Ya1/3rlclfU3S8h5P+aSk37TWNqy1b0t6Q9IrxphFSTlr7e9b9x3rVyX9YOA5v+Jd/98kfcz7F6yPS/o9a+2GtXZT0u9J+t5D/hJxcgXPi19R+HzhHMOj+pikN62139zjMZxjeChr7b+TtDFweGLvW959H/UeO/jr4xQadY5Za3/XWtv2bv57SRf2+hycY9jLmPexcXgfw4HtdY55/73/mqTf2OtzcI5hnD2yCr4fO0QEtcdvWdK3Ardvae8QDueYV/N/UdIfeIf+lnF/9O6XAj9OMO6cWvauDx4PPcf7y0dZ0vQenwtnj5X0u8aYLxljPu0dm7fW3pHcP4QlzXnHOcfwOH5Y4b8Q8D6GwzTJ961pSVuBEI9z7ez7G3JbPz0rxpg/Msb8P8aY7/COcY7hUUzqz0bOsfPlOyStWWv/InCM9zE8koGsgu/HDhFB7fEzI47ZY38VOPGMMRlJ/1zSz1hrK5L+J0lXJb0g6Y7cH1uRxp9Te51rj/IcnC3fbq19SdL3SfopY8yH9ngs5xgeiTcf6gck/TPvEO9jOC7HcU5xrp0jxpifl9SW9GveoTuSLllrX5T0X0r6dWNMTpxjOLhJ/tnIOXa+/IjC/3jO+xgeyYisYuxDRxzjfewhCGqP3y1JFwO3L0handBrwQlljInLfeP7NWvtv5Aka+2atbZjre1K+p/ljtGQxp9TtxT+8bzgueY/xxgTk5SX+yMynJ/nhLV21fu4Lulfyj2f1rwfQ+n9yNO693DOMTyq75P0urV2TeJ9DEdiku9b9yUVvMcOfi6cId7Ckr8s6Ue9H9GU92OcD7zrX5I7d++6OMdwQBP+s5Fz7Jzw/hv/FUmf6x3jfQyPYlRWIb4fO1QEtcfvDyVd87bSJeT+SOgXJvyacIJ4M1b+qaSvWWs/Ezi+GHjYfyCpt8nzC5J+2NuOuCLpmqQvej9yUDXGvM/7nD8u6bcCz+ltSfyrkv4v7y8evyPpe4wxRe/Hrr7HO4YzxBgzZYzJ9q7L/e/8ZwqfF59S+HzhHMOjCDU3eB/DEZjY+5Z337/1Hjv46+OMMMZ8r6SflfQD1trdwPFZY0zUu35F7jn2FucYDmqSfzZyjp0r3yXp69Za/8fNeR/DQY3LKsT3Y4fLnoCNZuftIukTcrfjvSnp5yf9ericrIukD8qt6v+JpC97l09I+l8l/al3/AuSFgPP+XnvfPqGvG2J3vGbcr/Ze1PS/yDJeMcduT+K/IbcbYtXAs/5G97xNyT9p5P+/eByJOfYFbnbN/9Y0ld670Ny5/v8G0l/4X0scY5xeYzzLC3pgaR84BjvY1we55z6Dbk/ptmS26r4iUm/b3nvp1/0jv8zSclJ/z5xOfRz7A25M/F635P1NlH/h96foX8s6XVJ3885xuURz7GJ/tnIOXa2LqPOMe/4L0v6zwYey/sYl4OeX+OyCr4fO8RL7zcCAAAAAAAAADAhjD4AAAAAAAAAgAkjqAUAAAAAAACACSOoBQAAAAAAAIAJI6gFAAAAAAAAgAkjqAUAAAAAAACACSOoBQAAAAAAAIAJI6gFAAAAAAAAgAkjqAUAAAAAAACACfv/AS0eAZoXnGF+AAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABWoAAAFlCAYAAAByV9G0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzdd3Sd5Zn3+++tYslFcq+4yE0SLsI2MiU2wdiUYDtl0s8kAQITCIRJOSEDzJo3JKQMb8KZkknIOqyTCWRCCpM5mcmxbJqBEBICGEIIRXIVtsG4YlsuklXu88e9kW3sxAXbj8r3s5aXtu79PHtfW4YF/Lh8XSHGiCRJkiRJkiQpO3lZFyBJkiRJkiRJ3Z1BrSRJkiRJkiRlzKBWkiRJkiRJkjJmUCtJkiRJkiRJGTOolSRJkiRJkqSMGdRKkiRJkiRJUsYKsi7gSAYNGhTLysqyLkOSJEmSJEmS3pZnnnlmS4xx8OGe6/BBbVlZGcuWLcu6DEmSJEmSJEl6W0IIr/y55xx9IEmSJEmSJEkZM6iVJEmSJEmSpIwZ1EqSJEmSJElSxjr8jNrDaW5uZv369TQ2NmZdinKKi4sZOXIkhYWFWZciSZIkSZIkdTqdMqhdv349JSUllJWVEULIupxuL8bI1q1bWb9+PWPHjs26HEmSJEmSJKnT6ZSjDxobGxk4cKAhbQcRQmDgwIF2OEuSJEmSJEnHqVMGtYAhbQfj74ckSZIkSZJ0/DptUNvZzJkzh2XLlmVdBt/85jezLkGSJEmSJEnSWxjUdlCtra0n5XWPJ6g9WbVIkiRJkiRJSgxqj0N9fT2VlZVcfvnlVFVV8cEPfpA9e/YAsHTpUqZPn87UqVO58soraWpqOuT+a6+9lurqaiZPnswtt9zSfl5WVsatt97K7Nmz+c///E++853vMGnSJKqqqvjoRz96yOs0NjbyyU9+kqlTpzJ9+nQeeeQRAO666y6uv/769usWLlzIo48+yk033cTevXuZNm0aH/vYxwD48Y9/zFlnncW0adO45ppr2kPZPn368OUvf5mzzz6bJ5544sT98CRJkiRJkiQdoiDrAt6ur/5/L/LSaztP6GtOGlHKLe+e/Bevqaur4wc/+AGzZs3iyiuv5I477uD666/niiuuYOnSpZSXl3PZZZfx/e9/n89//vMH3fuNb3yDAQMG0Nrayrx583j++eepqqoCoLi4mMcffxyAESNGsGbNGoqKiti+ffshNXzve98D4E9/+hO1tbVcfPHFLF++/M/WfNttt/Hd736X5557DoCXX36Zn//85/z2t7+lsLCQ6667jnvuuYfLLruM3bt3M2XKFG699daj/8FJkiRJkiRJOi521B6nUaNGMWvWLAA+/vGP8/jjj1NXV8fYsWMpLy8H4PLLL+exxx475N57772XGTNmMH36dF588UVeeuml9uc+8pGPtD+uqqriYx/7GD/+8Y8pKDg0U3/88cf5xCc+AUBlZSVjxoz5i0HtWy1dupRnnnmGmTNnMm3aNJYuXcrq1asByM/P5wMf+MBRv5YkSZIkSZJ0VGKEzXXwyu+yrqRD6fQdtUfqfD1ZQgiHfB9jPOJ9a9as4fbbb+fpp5+mf//+XHHFFTQ2NrY/37t37/bHNTU1PPbYY/zqV7/ia1/7Gi+++OJBge2fe7+CggLa2travz/w9Q8UY+Tyyy/nH//xHw95rri4mPz8/CN+HkmSJEmSJOmI2lph/TKoXQR1i2HrShgyGa4zrH2THbXHae3ate2zW3/6058ye/ZsKisrqa+vZ+XKlQD8x3/8B+eff/5B9+3cuZPevXvTt29fNm7cyJIlSw77+m1tbaxbt44LLriAb33rW2zfvp1du3YddM073/lO7rnnHgCWL1/O2rVrqaiooKysjOeee679NZ566qn2ewoLC2lubgZg3rx5/OIXv2DTpk0AbNu2jVdeeeUE/HQkSZIkSZLU7TU3Qt198Ku/hf+rAv79Yvj9HdB3FMy/HT72n1lX2KF0+o7arJx++uncfffdXHPNNUycOJFrr72W4uJifvjDH/KhD32IlpYWZs6cyac//emD7jvjjDOYPn06kydPZty4ce3jE96qtbWVj3/84+zYsYMYI1/4whfo16/fQddcd911fPrTn2bq1KkUFBRw1113UVRUxKxZsxg7dixTp05lypQpzJgxo/2eq6++mqqqKmbMmME999zD17/+dS6++GLa2tooLCzke9/7HmPGjDnxPzBJkiRJkiR1fXu2wYoHUufsyoeheTf0KIGJF0HlAphwIfTsd+TX6YbC0fxx/SxVV1fHZcuWHXT28ssvc/rpp2dUEdTX17Nw4UJeeOGFzGroiLL+fZEkSZIkSVIGtq+F2sUpnH3ldxBboWQ4VFyawtmy86CgKOsqO4QQwjMxxurDPWdHrSRJkiRJkqSjFyO8/nwKZ+tq4PU/pfPBlTD781CxAEZMhzynrh4Lg9rjUFZWZjetJEmSJEmSuo/W5tQtW1uTloHtWAcEGH0OXPS11Dk7cHzWVXZqBrWSJEmSJEmSDtW0C1YtTeHs8vuhcTsUFMO4C+D8v4PyS6HP4Kyr7DIMaiVJkiRJkiQlDRth+ZI01mD1o9DaBD377583O34u9OiddZVdkkGtJEmSJEmS1J1tWZG6ZmtrYP3TQIR+o2HmVVAxH0afC/nGiCebP2FJkiRJkiSpO2lrg1efgdpFad7sluXpfPgZMOfm1Dk7dDKEkG2d3YxBbQczZ84cbr/9dqqrq0/Ze27fvp2f/OQnXHfddafsPSVJkiRJknQKtTTBmsdy4ewS2LUR8gpgzCyY+ak02qDfqKyr7NYMasX27du54447jimojTESYyQvL+8kViZJkiRJkqTjtvcNWPFgGmmw8iHYtwt69IEJF6au2YkXpfmz6hBM2Y5DfX09lZWVXH755VRVVfHBD36QPXv2AHDrrbcyc+ZMpkyZwtVXX02MEUidsjfeeCNnnXUW5eXl/OY3vwFg7969fPSjH6WqqoqPfOQj7N27t/19rr32Wqqrq5k8eTK33HJL+/lNN93EpEmTqKqq4oYbbjikvm3btvG+972PqqoqzjnnHJ5//nkAvvKVr3D77be3XzdlyhTq6+u56aabWLVqFdOmTeNLX/oSAN/+9reZOXMmVVVV7e9dX1/P6aefznXXXceMGTNYt27difyxSpIkSZIk6e3avg6evBPufg98ewL8v5+CtU/A1A/Cx34BX1oFH74bqj5sSNvBdP6O2iU3wet/OrGvOWwqXHrbX7ykrq6OH/zgB8yaNYsrr7ySO+64gxtuuIHrr7+eL3/5ywB84hOfYNGiRbz73e8GoKWlhaeeeorFixfz1a9+lYceeojvf//79OrVi+eff57nn3+eGTNmtL/HN77xDQYMGEBrayvz5s3j+eefZ+TIkfzyl7+ktraWEALbt28/pLZbbrmF6dOn89///d88/PDDXHbZZTz33HN/9rPcdtttvPDCC+3XPPDAA6xYsYKnnnqKGCPvec97eOyxxxg9ejR1dXX88Ic/5I477jjmH6skSZIkSZJOsBhh44upa7auBjb8MZ0PKodzr4fKhXDameCfiu7wOn9Qm5FRo0Yxa9YsAD7+8Y/zne98hxtuuIFHHnmEb33rW+zZs4dt27YxefLk9qD2/e9/PwBnnnkm9fX1ADz22GN89rOfBaCqqoqqqqr297j33nu58847aWlpYcOGDbz00ktMmjSJ4uJi/uZv/oYFCxawcOHCQ2p7/PHH+a//+i8A5s6dy9atW9mxY8dRf7YHHniABx54gOnTpwOwa9cuVqxYwejRoxkzZgznnHPOMf60JEmSJEmSdMK0tqQu2brFaebs9rVAgJEz4cKvprEGgyZmXaWOUecPao/Q+XqyhLdsvQsh0NjYyHXXXceyZcsYNWoUX/nKV2hsbGy/pqioCID8/HxaWlr+7GsBrFmzhttvv52nn36a/v37c8UVV9DY2EhBQQFPPfUUS5cu5Wc/+xnf/e53efjhhw+6981xC2+tr6CggLa2tvazA2t76/0333wz11xzzUHn9fX19O7d+8/9SCRJkiRJknSy7NsNqx5OnbPL70vzZ/OLYNwcOO+LUH4plAzNukq9DfY8H6e1a9fyxBNPAPDTn/6U2bNntwefgwYNYteuXfziF7844uu8853v5J577gHghRdeaJ8nu3PnTnr37k3fvn3ZuHEjS5YsAVJ3644dO5g/fz7/8i//ctiRBge+5qOPPsqgQYMoLS2lrKyMZ599FoBnn32WNWvWAFBSUkJDQ0P7/Zdccgn//u//zq5duwB49dVX2bRp07H/kCRJkiRJknT8dm2GZ38EP/kofGsc/PzjULcEJl4MH/4R/N1q+Ni9cOYVhrRdQOfvqM3I6aefzt13380111zDxIkTufbaa+nVqxef+tSnmDp1KmVlZcycOfOIr3PttdfyyU9+kqqqKqZNm8ZZZ50FwBlnnMH06dOZPHky48aNax+z0NDQwHvf+14aGxuJMfLP//zPh7zmV77ylfbX7NWrF3fffTcAH/jAB/jRj37EtGnTmDlzJuXl5QAMHDiQWbNmMWXKFC699FK+/e1v8/LLL3PuuecC0KdPH3784x+Tn59/Qn52kiRJkiRJ+jO2rkpds7U1sO5JIELfUSmMrZgPY94B+YVZV6mTIBzuj8kfclEI/YD/B5gCROBKoA74OVAG1AMfjjG+kbv+ZuAqoBX4bIzx/tz5mcBdQE9gMfC5eIQCqqur47Jlyw46e/nllzn99NOP8iOeePX19SxcuJAXXnghsxo6oqx/XyRJkiRJkjqdtjZ47Q9pEVhtDWyuTefDpkLFgjRvdthUOMzoTHU+IYRnYozVh3vuaDtq/xW4L8b4wRBCD6AX8PfA0hjjbSGEm4CbgBtDCJOAjwKTgRHAQyGE8hhjK/B94Grg96Sg9l3Akrfx2SRJkiRJkqTOpWUf1D+Wgtm6JdCwAUJ+6pY985NQcSn0H5N1lTrFjhjUhhBKgXcCVwDEGPcB+0II7wXm5C67G3gUuBF4L/CzGGMTsCaEsBI4K4RQD5TGGJ/Ive6PgPfRCYPasrIyu2klSZIkSZJ09Bp3wIoHUzi74kHY1wCFvWHCvNQ1O/Fi6DUg6yqVoaPpqB0HbAZ+GEI4A3gG+BwwNMa4ASDGuCGEMCR3/Wmkjtk3rc+dNecev/VckiRJkiRJ6np2vAp1i1M4W/84tDVD78Ew5a/SWINx50Nhz6yrVAdxNEFtATAD+NsY45MhhH8ljTn4cw43MCP+hfNDXyCEq0kjEhg9evRh3yTGSHA2R4dxNLOOJUmSJEmSurQYYdPLuZEGNWn2LMCA8XDOtVC5EEZWQ54L23Woowlq1wPrY4xP5r7/BSmo3RhCGJ7rph0ObDrg+lEH3D8SeC13PvIw54eIMd4J3Alpmdhbny8uLmbr1q0MHDjQsLYDiDGydetWiouLsy5FkiRJkiTp1GprhbW/z3XOLoI36tP5adUw75Y01mBQucvAdERHDGpjjK+HENaFECpijHXAPOCl3K/LgdtyX/8nd8uvgJ+EEP6JtExsIvBUjLE1hNAQQjgHeBK4DPi34yl65MiRrF+/ns2bNx/P7ToJiouLGTly5JEvlCRJkiRJ6uz27YHVj6TO2eX3wZ6tkN8Dxp4Psz4H5ZdC6fCsq1QnczQdtQB/C9wTQugBrAY+CeQB94YQrgLWAh8CiDG+GEK4lxTktgCfiTG25l7nWuAuoCdpidhxLRIrLCxk7Nixx3OrJEmSJEmSdOx2b02hbG0NrHoYWvZCUV8ovzh1zY6fB8WlWVepTix09Nmi1dXVcdmyZVmXIUmSJEmSpO5m22qoXZzGGqx9AmIblJ6WgtmK+TBmFhT0yLpKdSIhhGdijNWHe+5oO2olSZIkSZKkri3GtACsbnHqnN30UjofOgXOuyEFtMPPcN6sTgqDWkmSJEmSJHVfLfvglcdTMFu3BHa+CiEPRr8DLvlm6pwd4AhOnXwGtZIkSZIkSepeGnfCyodSOLviQWjaAQU9YcI8mPsPMPES6D0w6yrVzRjUSpIkSZIkqevbuWH/SIM1j0FbM/QaCJPeDRULYNwc6NEr6yrVjRnUSpIkSZIkqeuJETbXQe2iFNC++kw67z8Wzr4GKhfCqLMgLz/bOqUcg1pJkiRJkiR1DW2tsP7pFM7W1sC21el8xAyY+7/SMrDBlS4DU4dkUCtJkiRJkqTOq3kvrH50/zKwPVsgrxDGngfnfiYtAysdkXWV0hEZ1EqSJEmSJKlz2bMNlt+fOmdXPQzNe6CoFCZelILZiRdBcd+sq5SOiUGtJEmSJEmSOr436qF2cZo3+8rvILZCyQg44/9IIw3KzoOCHllXKR03g1pJkiRJkiR1PDHChj+mYLa2Bja+kM4Hnw6zvwCV82H4dMjLy7ZO6QQxqJUkSZIkSVLH0NoMr/x2f+fsjnUQ8mDUOXDx19NYg4Hjs65SOikMaiVJkiRJkpSdpgZYuTR1za64Hxp3QEExjJ8Lc26C8ndB70FZVymddAa1kiRJkiRJOrUaNqaO2brFsPpRaN0HPQdA5cLUNTv+AujRO+sqpVPKoFaSJEmSJEkn3+blUFeTOmfXLwMi9BsDMz+VloGNOhvyjarUfflXvyRJkiRJkk68tjZ4dRnULkozZ7euSOfDp8EFf5/C2SGTIIRs65Q6CINaSZIkSZIknRjNjbDm16lrtm4J7N4EeQVQNhvOvgYqLoW+I7OuUuqQDGolSZIkSZJ0/Pa+AcsfSJ2zK5dC827o0QcmXgQVC9LXnv2yrlLq8AxqJUmSJEmSdGy2r0uLwGoXQf1vIbZCn2FQ9eG0EGzseVBQlHWVUqdiUCtJkiRJkqS/LEbY+EIaaVBbA68/n84HVcCsz6V5syNmQF5etnVKnZhBrSRJkiRJkg7V2gJrf5cWgdXWwI61QIBRZ8NFt6axBoMmZF2l1GUY1EqSJEmSJClp2gWrHk7B7Ir70/zZ/CIYfwGc/yUofxf0GZJ1lVKXZFArSZIkSZLUne3aBHVL0szZVY9AaxMU90uhbOUCGD8XivpkXaXU5RnUSpIkSZIkdTdbVkJdbt7suqeACH1HQ/WVUDkfRr8D8o2NpFPJv+MkSZIkSZK6urY2eO1ZqF2UZs5uqUvnw6pgzk2pc3boFAgh2zqlbsygVpIkSZIkqStqaYI1v0nhbN0S2PU6hHwomwUzr4KKS6Hf6KyrlJRjUCtJkiRJktRV7N0OKx5MYw1WPAT7GqCwN0yYB5ULYeJF0GtA1lVKOgyDWkmSJEmSpM5sx/rUMVu7COofh7YW6D0Eprw/jTQYez4UFmddpaQjMKiVJEmSJEnqTGKETS+lRWC1NbDhuXQ+cCKce30KZ0+rhry8bOuUdEwMaiVJkiRJkjq61hZY9/u0CKx2EWx/JZ2PnAkXfgUqFsDg8iwrlPQ2GdRKkiRJkiR1RPv2wKqHU9fs8vtg7zbI7wHj5sDsL6RlYCXDsq5S0gliUCtJkiRJktRR7N6S5s3WLU4hbUsjFPeFiZekkQYT5kFRSdZVSjoJDGolSZIkSZKytHVVCmZra2DdkxDboHQkzLgcKufDmFmQX5h1lZJOMoNaSZIkSZKkU6mtDTb8ITdvtgY2v5zOh06Bd34pdc4Oq4IQsq1T0illUCtJkiRJknSyteyD+t+kYLZuCTS8BiEvdcvO+MfUOdu/LOsqJWXIoFaSJEmSJOlkaNwBKx5MYw1WPAhNO6GwF4yfC5VfhvJLoNeArKuU1EEY1EqSJEmSJJ0oO1/bP292zW+grRl6DYJJ700jDcbNgcKeWVcpqQMyqJUkSZIkSTpeMcLmWqhdlGbOvvZsOh8wDs75NFQuhJEzIS8/2zoldXhHFdSGEOqBBqAVaIkxVocQBgA/B8qAeuDDMcY3ctffDFyVu/6zMcb7c+dnAncBPYHFwOdijPHEfRxJkiRJkqSTrK0V1j2Vwtm6xbBtdTo/7UyY92WoWACDK1wGJumYHEtH7QUxxi0HfH8TsDTGeFsI4abc9zeGECYBHwUmAyOAh0II5THGVuD7wNXA70lB7buAJSfgc0iSJEmSJJ08zXth1SNQVwN198GeLZBXCOPOh3Ovh4r5UDo86yoldWJvZ/TBe4E5ucd3A48CN+bOfxZjbALWhBBWAmflunJLY4xPAIQQfgS8D4NaSZIkSZLUEe3eCivuT/NmVy6Flr1Q1BcmXpTmzU64EIpLs65SUhdxtEFtBB4IIUTg/44x3gkMjTFuAIgxbgghDMldexqpY/ZN63NnzbnHbz2XJEmSJEnqGLatyS0DWwxrfwexDUpGwPSPpXB2zGwo6JF1lZK6oKMNamfFGF/LhbEPhhBq/8K1hxvAEv/C+aEvEMLVpBEJjB49+ihLlCRJkiRJOkYxwobnUjBbWwObXkznQybBeV9MIw1GTHferKST7qiC2hjja7mvm0IIvwTOAjaGEIbnummHA5tyl68HRh1w+0jgtdz5yMOcH+797gTuBKiurnbZmCRJkiRJOnFam6H+8RTM1i2Bnesh5MHoc+Hib0DlfBgwLusqJXUzRwxqQwi9gbwYY0Pu8cXArcCvgMuB23Jf/yd3y6+An4QQ/om0TGwi8FSMsTWE0BBCOAd4ErgM+LcT/YEkSZIkSZIO0bgTVj6UxhosfwCadkBBTxg/Fy74eyi/BHoPyrpKSd3Y0XTUDgV+GVKLfwHwkxjjfSGEp4F7QwhXAWuBDwHEGF8MIdwLvAS0AJ+JMbbmXuta4C6gJ2mJmIvEJEmSJEnSydHw+v55s2t+Da37oNdAOP3dqWt23AXQo1fWVUoSACHGjj1ZoLq6Oi5btizrMiRJkiRJUkcXI2xZnkYa1NbAq7k8oX8ZVC5My8BGnQ15+ZmWKan7CiE8E2OsPtxzR7tMTJIkSZIkqeNpa4X1y6B2Ueqe3boynY+YDnP/ASoWwJDTXQYmqcMzqJUkSZIkSZ1L815Y/Wuoyy0D270Z8gqg7Dw4+9NQMR/6npZ1lZJ0TAxqJUmSJElSx7dnG6x4IHXOrnwYmndDjxKYeFEaaTDhQujZL+sqJem4GdRKkiRJkqSO6Y1XcsvAauCV30FshZLhcMZHUjhbdh4UFGVdpSSdEAa1kiRJkiSpY4gRXn8eanPh7MY/pfPBlTD782ne7IjpkJeXbZ2SdBIY1EqSJEmSpOy0Nqdu2dqa1D27Yx0QYPQ5cNHXUufswPFZVylJJ51BrSRJkiRJOrWadsGqpSmcXX4/NG6HgmIYdwGc/3dQfin0GZx1lZJ0ShnUSpIkSZKkk69hIyxfksYarH4UWpugZ3+ouDR1zY6fCz16Z12lJGXGoFaSJEmSJJ0cW1akrtnaGlj/NBCh32iYeVUKZ0edA/lGE5IEBrWSJEmSJOlEaWuDV5+B2kVp3uyW5el8+Bkw5+YUzg6dDCFkW6ckdUAGtZIkSZIk6fg1N8Kax6CuBuqWwK6NkFcAY2bBzE+l0Qb9RmVdpSR1eAa1kiRJkiTp2Ox9A1Y8mDpnVy6FfbugRx+YcGHqmp14UZo/K0k6aga1kiRJkiTpyLavS+MMamvgld9CWwv0GQpTPwiVC6HsPCgszrpKSeq0DGolSZIkSdKhYoSNL6Zgtq4GNvwxnQ8qh3OvT+HsaWdCXl62dUpSF2FQK0mSJEmSktYWWPtErnN2EWxfCwQYORMu/GoaazBoYtZVSlKXZFArSZIkSVJ3tm83rHo4dc4uvy/Nn80vgnFz4LwvQvmlUDI06yolqcszqJUkSZIkqbvZtRmWL4HaxbD6EWhphOJ+UH5J6podPw+K+mRdpSR1Kwa1kiRJkiR1B1tXpa7Z2hpY9yQQoe8oOPMKqJgPY94B+YVZVylJ3ZZBrSRJkiRJXVFbG7z2hzRrtm4xbK5N58Omwvk3ps7ZYVMhhGzrlCQBBrWSJEmSJHUdLU2w5jdQV5PGGux6HUJ+6pY985NQcSn0H5N1lZKkwzColSRJkiSpM2vcASseTCMNVjwI+xqgsDdMmJe6ZideDL0GZF2lJOkIDGolSZIkSepsdryaxhnU1kD949DWDL0Hw5S/gooFMO58KOyZdZWSpGNgUCtJkiRJUkcXI2x6OQWzdTVp9izAgPFwzrVQuRBGVkNefrZ1SpKOm0GtJEmSJEkdUVsrrP19rnN2EbxRn85Pq4Z5t6SxBoPKXQYmSV2EQa0kSZIkSR3Fvj2w+pHUObv8PtizFfJ7wNjzYdbnoPxSKB2edZWSpJPAoFaSJEmSpCzt3ppC2doaWPUwtOyFor5QfnHqmh0/D4pLs65SknSSGdRKkiRJknSqbVsNtYvTWIO1T0Bsg9LTYMYnoGI+lM2G/MKsq5QknUIGtZIkSZIknWwxpgVgdYtT5+yml9L50Clw3g2pc3b4Gc6blaRuzKBWkiRJkqSToWUfvPJ4CmbrlsDOVyHkweh3wCXfTJ2zA8ZmXaUkqYMwqJUkSZIk6URp3AkrH0xjDVY8AE07oaAnTJgHc/8BJl4CvQdmXaUkqQMyqJUkSZIk6e3YuWH/SIM1j0FbM/QaCJPeAxULYNwc6NEr6yolSR2cQa0kSZIkScciRthcB7WLUkD76jPpvP9YOPsaqFwIo86CvPxs65QkdSoGtZIkSZIkHUlbK6x/OoWztTWwbXU6HzED5v6vtAxscKXLwCRJx82gVpIkSZKkw2neC6sf3b8MbM8WyCuEsefBuZ9Jy8BKR2RdpSSpizColSRJkiTpTXu2wfL7U+fsqoeheQ8UlcLEi1IwO/EiKO6bdZWSpC7IoFaSJEmS1L29UQ+1i9O82Vd+B7EVSkbAtL9O4WzZeVDQI+sqJUldnEGtJEmSJKl7iRE2/DEFs7U1sPGFdD74dJj9BaicD8OnQ15etnVKkrqVow5qQwj5wDLg1RjjwhDCAODnQBlQD3w4xvhG7tqbgauAVuCzMcb7c+dnAncBPYHFwOdijPFEfRhJkiRJkg6rtRle+W0KZmsXw871EPJg1Dlw8ddT5+zA8VlXKUnqxo6lo/ZzwMtAae77m4ClMcbbQgg35b6/MYQwCfgoMBkYATwUQiiPMbYC3weuBn5PCmrfBSw5IZ9EkiRJkk/t8NoAACAASURBVKQDNTXAyodSMLvifmjcAQXFMH4uXHAzlL8Leg/KukpJkoCjDGpDCCOBBcA3gP8zd/xeYE7u8d3Ao8CNufOfxRibgDUhhJXAWSGEeqA0xvhE7jV/BLwPg1pJkiRJ0onSsHH/SIM1v4bWfdBzAFQuTF2z4y+AHr2zrlKSpEMcbUftvwB/B5QccDY0xrgBIMa4IYQwJHd+Gqlj9k3rc2fNucdvPZckSZIk6fhtXg51NSmcXb8MiNBvDMz8FFQugFFnQ74rWiRJHdsR/0kVQlgIbIoxPhNCmHMUrxkOcxb/wvnh3vNq0ogERo8efRRvKUmSJEnqNtra4NVlULsojTXYuiKdD58GF/x9CmeHTIJwuP8MlSSpYzqa/6U4C3hPCGE+UAyUhhB+DGwMIQzPddMOBzblrl8PjDrg/pHAa7nzkYc5P0SM8U7gToDq6mqXjUmSJElSd9fcmEYZ1NZA3RLYvQnyCqBsNpx9DVRcCn1HHvl1JEnqoI4Y1MYYbwZuBsh11N4QY/x4COHbwOXAbbmv/5O75VfAT0II/0RaJjYReCrG2BpCaAghnAM8CVwG/NsJ/jySJEmSpK5i7xuw/IHUObtyKTTvhh59YOJFULEgfe3ZL+sqJUk6Id7OkJ7bgHtDCFcBa4EPAcQYXwwh3Au8BLQAn4kxtubuuRa4C+hJWiLmIjFJkiRJ0n7b16ZxBnU1UP9biK3QZxhUfTgtBBt7HhQUZV2lJEknXIixY08WqK6ujsuWLcu6DEmSJEnSyRAjvP4nqFucOmdf/1M6H1SRZs1WLoARMyAvL9s6JUk6AUIIz8QYqw/3nGsvJUmSJEmnVmsLrP1dmjdbuxh2rAUCjDobLro1jTUYNCHrKiVJOqUMaiVJkiRJJ1/TLlj1cApnl98HjdshvwjGXwDnfwnK3wV9hmRdpSRJmTGolSRJkiSdHLs2Qd2SNNZg1SPQ2gTF/VIoW7kAxs+Foj5ZVylJUodgUCtJkiRJOnG2rEyLwGprYN1TQIS+o6H6SqicD6PfAfn+p6gkSW/lPx0lSZIkScevrQ1eezYtAqtdDFvq0vmwKphzU+qcHToFQsi2TkmSOjiDWkmSJEnSsWlpgjW/SeFs3RLY9TqEfCibBTOvgopLod/orKuUJKlTMaiVJEmSJB3Z3u2w4sE01mDFQ7CvAQp7w8QLoWIBTLwIeg3IukpJkjotg1pJkiRJ0uHtWJ/GGdTVQP3j0NYCvYfAlPdD5UIY+04oLM66SkmSugSDWkmSJElSEiNseiktAqtdBBv+mM4HToRzr0/zZk+rhry8bOuUJKkLMqiVJEmSpO6stQXW/T51ztYugu2vpPORM+HCr6SxBoPLs6xQkqRuwaBWkiRJkrqbfXtg1cOpc3b5fbB3G+T3gHFzYPYX0jKwkmFZVylJUrdiUCtJkiRJ3cHuLVC3BOoWp5C2pRGK+8LES9JIgwnzoKgk6yolSeq2DGolSZIkqavauioFs7U1sO5JiG1QOhJmXA6V82HMLMgvzLpKSZKEQa0kSZIkdR1tbbDhD7l5szWw+eV0PnQKvPNLqXN2WBWEkG2dkiTpEAa1kiRJktSZteyD+t+kYLZuCTS8BiEvdcvO+MfUOdu/LOsqJUnSERjUSpIkSVJn07gDVjyYxhqseBCadkJhLxg/Fyq/DOWXQK8BWVcpSZKOgUGtJEmSJHUGO1/bP292zW+grRl6DYJJ74XKhTDufCjsmXWVkiTpOBnUSpIkSVJHFCNsroXaRWnm7GvPpvMB4+CcT6dwduRMyMvPtk5JknRCGNRKkiRJUkfR1grrnkxds7U18MaadH7amTDvy1CxAAZXuAxMkqQuyKBWkiRJkrLUvBdWPZKC2eX3wZ4tkFeYRhm842+hYj6UDs+6SkmSdJIZ1EqSJEnSqbZ7K6y4P4WzK5dCy14o6gsTL4LKBTDhQiguzbpKSZJ0ChnUSpIkSdKpsG1NbhnYYlj7O4htUDICpn8shbNjZkNBj6yrlCRJGTGolSRJkqSTIUbY8FwKZmtrYNOL6XzIJDjvi2mkwYjpzpuVJEmAQa0kSZIknTitzVD/eApm65bAzvUQ8mD0uXDxN6ByPgwYl3WVkiSpAzKolSRJkqS3o3EnrHwojTVY/gA07YCCnjB+Llzw91B+CfQelHWVkiSpgzOolSRJkqRj1fB6bt5sDax5DFr3Qa+BcPq7U9fsuAugR6+sq5QkSZ2IQa0kSZIkHUmMsGU51C5KM2dfXZbO+5fBWVenZWCjzoa8/EzLlCRJnZdBrSRJkiQdTlsrrH86dc3W1sC2Vel8xHSY+w9QsQCGnO4yMEmSdEIY1EqSJEnSm5r3wupfQ11uGdjuzZBXAGXnwTnXQsV86Hta1lVKkqQuyKBWkiRJUve2ZxuseCCNNVj5MDTvhh4lMPGiNNJgwoXQs1/WVUqSpC7OoFaSJElS9/PGK/uXgb3yO4itUDIczvhICmfLzoOCoqyrlCRJ3YhBrSRJkqSuL0Z4/fm0CKy2Bjb+KZ0ProTZn0/zZkdMh7y8bOuUJEndlkGtJEmSpK6ptTl1y9bWpO7ZHeuAAKPPgYu+ljpnB47PukpJkiTAoFaSJElSV9K0C1Y+lILZ5fdD43YoKIZxF8D5N0L5u6DP4KyrlCRJOoRBrSRJkqTOrWEjLF+SOmdX/xpam6Bnf6iYD5XzYfxc6NE76yolSZL+IoNaSZIkSZ3PlhUpmK2tgfVPAxH6jYaZV6WRBqPOgXz/c0eSJHUe/puLJEmSpI6vrQ1efQZqF6WxBluWp/PhZ8Ccm1M4O3QyhJBtnZIkScfpiEFtCKEYeAwoyl3/ixjjLSGEAcDPgTKgHvhwjPGN3D03A1cBrcBnY4z3587PBO4CegKLgc/FGOOJ/UiSJEmSuoTmRljzGNTVQN0S2LUR8gpgzCyY+SmouBT6jcq6SkmSpBPiaDpqm4C5McZdIYRC4PEQwhLg/cDSGONtIYSbgJuAG0MIk4CPApOBEcBDIYTyGGMr8H3gauD3pKD2XcCSE/6pJEmSJHVOe9+AFQ+mztmVS2HfLujRByZcmLpmJ16U5s9KkiR1MUcManMdr7ty3xbmfkXgvcCc3PndwKPAjbnzn8UYm4A1IYSVwFkhhHqgNMb4BEAI4UfA+zColSRJkrq37evSOIPaGnjlt9DWAn2GwtQPQuVCKDsPCouzrlKSJOmkOqoZtSGEfOAZYALwvRjjkyGEoTHGDQAxxg0hhCG5y08jdcy+aX3urDn3+K3nkiRJkrqTGGHjiymYrauBDX9M54PK4dzrUzh72pmQl5dtnZIkSafQUQW1ubEF00II/YBfhhCm/IXLDze9P/6F80NfIISrSSMSGD169NGUKEmSJKkja22BtU/kOmcXwfa1QICRM+HCr6axBoMmZl2lJElSZo4qqH1TjHF7COFR0mzZjSGE4blu2uHAptxl64EDJ/qPBF7LnY88zPnh3udO4E6A6upql41JkiRJndG+3bDq4dQ5u/y+NH82vwjGzYHzvgjll0LJ0KyrlCRJ6hCOGNSGEAYDzbmQtidwIfC/gV8BlwO35b7+T+6WXwE/CSH8E2mZ2ETgqRhjawihIYRwDvAkcBnwbyf6A0mSJEnK0K7NsHwJ1C6G1Y9ASyMU94PyS1LX7Ph5UNQn6yolSZI6nKPpqB0O3J2bU5sH3BtjXBRCeAK4N4RwFbAW+BBAjPHFEMK9wEtAC/CZ3OgEgGuBu4CepCViLhKTJEmSOrutq1LXbG0NrHsSiNB3FJx5BVTMhzHvgPzCrKuUJEnq0EKMHXuyQHV1dVy2bFnWZUiSJEl6U1sbvPaHNGu2bjFsrk3nw6ZCxYLUOTtsKoTDramQJEnqvkIIz8QYqw/33DHNqJUkSZLUTbU0wZrfQF1NGmuw63UI+alb9sxPQsWl0H9M1lVKkiR1Wga1kiRJkg6vcQeseDCNNFjxIOxrgMLeMGFe6pqdeDH0GpB1lZIkSV2CQa0kSZKk/Xa8msYZ1NZA/ePQ1gy9B8OUv0pjDcadD4U9s65SkiSpyzGolSRJkrqzGGHTyymYratJs2cBBoyHc66FyoUwshry8rOtU5IkqYszqJUkSZK6m7ZWWPv7XOfsInijPp2fVg3zbkljDQaVuwxMkiTpFDKolSRJkrqDfXtg9SOpc3b5fbBnK+T3gLHnw6zPQfmlUDo86yolSZK6LYNaSZIkqavavSWFsrWLYdXD0LIXivpC+cWpa3bChVBUknWVkiRJwqBWkiRJ6lq2rU7BbG0NrPs9xDYoPQ1mfAIq5kPZbMgvzLpKSZIkvYVBrSRJktSZxZgWgNXWpJmzm15K50OnwHk3pM7Z4Wc4b1aSJKmDM6iVJEmSOpuWffDK47lwdgnsfBVCHox+B1zyzdQ5O2Bs1lVKkiTpGBjUSpIkSZ1B405Y+WAaa7DiAWjaCQU9YcI8mPsPMPES6D0w6yolSZJ0nAxqJUmSpI5q54Y0zqC2BtY8Bm3N0GsgTHoPVCyAcXOgR6+sq5QkSdIJYFArSZIkdRQxwuY6qF2UAtpXn0nn/cfC2ddA5UIYdRbk5WdbpyRJkk44g1pJkiQpS22tsO4pqKtJnbPbVqfzETNg7v9Ky8AGV7oMTJIkqYszqJUkSZJOtea9sPrRXOfsfbBnC+QVwth3wrmfScvASkdkXaUkSdIJ19jcyuaGJjbubASgumxAxhV1HAa1kiRJ0qmwZxssvy91za56GJr3QFEpTLwodc1OuBCK+2ZdpSRJ0nFpbm1rD2A37mxiU0Nj++ONOxvZtLOJjQ2NbN/T3H7PtFH9+O/PzMqw6o7FoFaSJEk6Wd6oh9rFad7sK7+D2AolI2DaX6eu2bLzoKBH1lVKkiT9Wa1tka27m1LQurOR198MYnc2HhTKbt29jxgPvjc/LzCkpIghpcWMGdiLs8YOYGhp+n5oaTGn9euZzYfqoAxqJUmSpBMlRtjwxxTM1tbAxhfS+eDTYfYXoHI+DJ8OeXnZ1ilJkrq9GCNv7GnOha2N7UHsxoYDg9gmNu9qorXt4AQ2BBjUp4ihpUUM71vMGaP6MbS0iKGlxSmILUlB7MDePcjLc87+0TKolSRJkt6O1mZ45bcpmK1dDDvXQ8iDUefAxV9PnbMDx2ddpSRJ6iZijDQ0tbBxx/6xAxsbDghi3wxgG5rY19p2yP39exUytLSYIaXFlA8tSeFr32KGlrwZxBYzqE8PCvL9H88nmkGtJEmSdKyaGmDlQymYXXE/NO6AgmIYPxcuuBnK3wW9B2VdpSRJ6mL27GvZH74e1AX75hzYFMLubW495N6S4oL2jtezxw7IjR84uAt2cEkRxYX5GXwygUGtJEmSdHQaNu4fabDm19C6D3oOgMqFqWt2/AXQo3fWVUqSpE6osbn1oEVch+uC3bSziYamlkPu7VmY3z73tWrk/hEEQ0r3d8EOKS2iVw9jwI7O3yFJkiTpz9m8HOpqUji7fhkQod8YmPkpqFwAo86GfP+VWpIkHV5zaxtbdjW1h6+bDgpim9J4goZGtu9pPuTeHvl5DMmFrhXDSjhv4uD27tf2LtjSYkqKCgjBObBdgf9WKUmSJL2prQ3WP50LZxfD1hXpfPg0uODvUzg7ZFLaoCFJkrqt1rbI1t1NB3S85oLYhgOC2J1NbN3dRDx4Dxf5eYEhJSlkHTOwF2eNHdAeurYHsSXF9OtVaADbzRjUSpIkqXtrbkyjDGoXQd19sHsT5BVA2Ww4+xqouBT6jsy6SkmSdArEGNm+p5mNBwauua7XjTub2jtiN+9qorXt4AQ2BBjYu6i947VqZF+GlBQf1AU7pLSIgb2LyM8zgNWhDGolSZLUvcQIO1+D+sdTOLtyKTTvhh59YOJFULEgfe3ZL+tKJUnSCRJjpKGp5eDRA4fpgt20s4l9rW2H3N+/V2H73NfyoSUHjR548/GgPkUU5udl8OnUVRjUSpIkqWvbuQE2PAev/SH367nUNQvQZxhUfTgtBBt7HhQUZVurJEk6Znv2tRwQvh6wgKvh4Lmwe5tbD7m3pKigfQ7szLIB6XFJMcP65oLYkmIGlxRRXJifwSdTd2NQK0mSpK5j16aDA9nX/gC7Xk/PhTwYXJm6ZYdPg5FnwvDpkGfniyRJHVFTSyubdjYdMvd1087Gg0YTNDS2HHJvcWEew3IdsFNH9uPCkv2jB4bmumCHlBTRu8hoTB2HfzVKkiSpc9q95eBA9rU/QMNruScDDK6AcXNgxPT0a9gU6NE7w4IlSRJAc2sbW3Y1HTBu4IAgtuHNObCNvLGn+ZB7C/NDbu5rEROH9GH2hEEHzYB9cxxBSVGBi7jU6RjUSpIkqePbs+2ATtk/wIY/wo51uScDDJyQln+NmA4jpsGwKijqk2nJkiR1N21tka279x0y9/WtXbBbdjURD97DRX5eYHCftIhr1IBeVJf1Z2huEdeBXbD9exUawKrLMqiVJElSx7L3jf1dsm/Olt2+dv/zA8bDqLPh7GtynbJVUFyaXb2SJHVxMUa272k+aNzA4bpgNzc00dIWD7l/UJ8e7UHr1NP65jpi93fBDiktYmDvIvLzDGDVvRnUSpIkKTuNO1J37IHdsm/U73++fxmcdibM/Jv9oWzPfllVK0lSlxJjZFdTy2Hnvm48IIjd1NDEvpa2Q+7v16uQoSUpaJ04ZND+4LVkfwg7uKSIwnznwUtHw6BWkiRJp0bjTnj9+YPnym5btf/5fqNTGDvj8vR1+BnQa0B29UqS1Int3de6P3A9YO7rgeHrxp2N7NnXesi9JUUF7eMGZpYNSI/f0gU7uKSI4sL8DD6Z1HUZ1EqSJOnEa9qVC2UPWPS1dSWQ++OQfUelIHbaX+dC2WnQe2CmJUuS1Bk0tbSyuaHpoND1cB2xDY0th9xbXJiXwtaSYiaPKGVu5ZCDumCH9S1mSEkRvYuMi6Qs+HeeJEmS3p59e+D1Px2w6Os52FxHeyhbMiKFsVUfSYu+hk+DPoMzLVmSpI6mpbWNLbv2HbELdtvufYfcW5gf2scNTBzSh9kTBh3SBTuktJjS4gIXcUkdmEGtJEmSjl7zXnj9hYMXfW2uhZibW9dnKIyYAZP/an+nbMnQbGuWJClDbW2Rrbv35YLWP98Fu2VXE/Ete7jyAgwuSR2vI/v34swx/RlaWsyw3AKuNxd09etZSJ6LuKROz6BWkiRJh9fcCBtfhA1/2D9XdtPLEHOz7HoPTmHs6e/eH8qWDs+2ZkmSTpEYIzv2Nh+0fGtTw6GLuDY3NNHSFg+5f1CfHu1dsFNG9GXIm/NfD+iCHdiniHwDWKnbMKiVJEkStOyDTS8evOhr00vQlptv12tgCmMrLk2B7IjpUDoC/OOTkqQuJsbIrqaWw8593fRmKJs729fSdsj9fXsWts99nTBk0EEzYN98PKhPET0K8jL4dJI6siMGtSGEUcCPgGFAG3BnjPFfQwgDgJ8DZUA98OEY4xu5e24GrgJagc/GGO/PnZ8J3AX0BBYDn4vxrY39kiRJOqlam1MIe+Cir00vQWtu5l3P/imMfcdnUyA7Ylpa/mUoK0nq5Pbua33L+IH9XbCv79j/eM++1kPu7VNU0D739czRaQRBexdsbkHXkNIiigvzM/hkkrqCo+mobQG+GGN8NoRQAjwTQngQuAJYGmO8LYRwE3ATcGMIYdL/3969Bsl1n3Ue//37dnpm+jKXnp4Z3W8jWfJFki1fkmBvNoHECRDDhoBZIIEkuAihFnZrqxKWql1eht0iVSxsQWWXVJItIAkLFHkBBYGF5Q1JcGLZsaPYku+yRpqbNN1z6ft/X5zTfU5fZizJM3Pm8v1UnZqecxmdlv/umfnp6eeR9LikOyXtkfR3xpjj1tq6pN+X9ISkb8gNah+V9Nfr/aQAAADgqdfcHrLBQV9Xn5XqZfe4k3WD2Id+2f2456w0eJBQFgCwrVRqjVYAOx0YxtVWBVsoqVCqdV3rxCIaz7pB6517MnrXHfmuKth8JqmUw5uSAWysN32VsdZOSZryHheNMRck7ZX0mKR3eqd9UdI/SvqUt//L1tqypJeNMZckPWCMeUVSxlr7z5JkjPmSpB8TQS0AAMD6qNek2RfaB31d/a5UK7nHnYw0cVp68AmvUvasNHSYUBYAsGXV6o3WIK5rhbKuFkp+ENtsR1Asa36p0nVtPGqU96pcj46m9PajI14FbHsVbKYvJsP3QgBbwC39c5Ax5pCks5K+KWnMC3FlrZ0yxuS90/bKrZhtuuztq3qPO/cDAADgVjXq0twlv1L2ynnp6jNSddk9nki5oez9H/cHfQ0fkSL0wwMAhK/RsJpfrnRUvJZ1rdgMYt19s4tldc7hihhpNO0GrfuG+nXfwaFW+Jr3wtexjKOh/oQiDOICsI3cdFBrjElJ+jNJv2atLazxr029Dtg19vf6s56Q2yJBBw4cuNlbBAAA2JkaDWn+xfZBX1NPS9Ul93i83w1l7/t5f9DXyDFCWQDAprPWamGl2rMHbDOMnfb21ToTWEkjA4lW39dTE5lW+DoeqIQdSTmKEsAC2IFuKqg1xsTlhrR/ZK39c2/3NWPMhFdNOyFp2tt/WdL+wOX7JF3x9u/rsb+LtfZzkj4nSefOnWPYGAAA2D0aDen6y+2VslNPS5WiezzWJ43fLZ39Wb99QW5SijC4BACwsRbLNT98XaMKtlxrdF2b7Yu32g0cHc35rQeaVbCZpEZTjhIx/pERwO71pkGtcUtn/1DSBWvtZwOHvibpI5I+4338y8D+PzbGfFbuMLFJSd+y1taNMUVjzENyWyd8WNLvrtszAQAA2G6sla6/0j7o68rTUnnBPR513FD29OP+oK/cCSnKMBMAwPopVettPV87q2CboexSpd51bX8iqvGM2wf23gOD7gCujh6w+YyjZJx/UASAN3MzP+W/Q9LPSfquMea8t+8/yQ1ov2qM+Zik1yR9SJKstc8ZY74q6XuSapI+aa1tvpp/QtIXJPXJHSLGIDEAALA7WCvdeK190NeV81Lphns8mpDG7pLu/qBfKTt6hxSNh3vfAIBtq1JraGbRG7jVFsSWNV10Q9irCyUVSrWua51YpFXxenJPRu88kW+Fr/lWNWxSKYd/PASA9WKs3dqdBc6dO2effPLJsG8DAADg5lkrLVwOBLJeKLsy7x6PxKWxU34gO3FGyp+SYolw7xsAsC3U6g3NLVXaqmCnA20Imn1g55YqXdfGIkb5tKOxrD90Kx/o/9qsgs30xbTGbBoAwG0yxnzbWnuu1zH+6QsAAOCtsFYqTrUP+rrylLQ86x6PxKT8Senkj/iDvsbulGJOuPcNANhyGg2r68uVrr6vVzsqYmcXy+qcwxUxUi7lBq17B5M6e2CwFcQGq2CH+xOKMIgLALYkgloAAIBbUbzaHspOnZcWr7nHTNQNZY8/6vWUvdcNZePJcO8ZABAqa60KKzWv2rV3Fey0146gWu9+1+vIQKLV9/XURKZnFezIQEKxKIO4AGA7I6gFAABYzeJ0e5Xs1Hm3elaSTMQd7HX03V4LgzNuj9lEf7j3DADYVEvlmtvrNTB0q7Mi9lqhpHKt0XVtJhlr9Xp98MiA13bAaQ3kGs8mNZpylIgRwALAbkBQCwAAIElLs24oOxWoli284R00Uu64dPgRv6/s+N1SYiDUWwYAbJxSte4Gr4EqWDd4ba+CXSx3D+LqT0Q17rUbOHtg0A1e0/4ArrGMo3w6qb5ENIRnBgDYqghqAQDA7rM83zHo62lp4TX/+Mgx6eDb/UFfE/dITjq8+wUArJtKraGZxY7WA80gNhDKLqxUu65NxCJuq4F0UifHM3rn8e4esGOZpFIOv2oDAG4d3z0AAMDOtnJdmnq6va/sjVf948NHpP33Sw/8ohfM3iMls+HdLwDgttQbVnOLZT94Lfaogi2UNLdU6bo2FjHKp92+r4dzA3royEjPKthsX1zGMIgLALAxCGoBAMDOUSoEQllvu/6yf3zokBvGnvuoF8qelvoGQ7tdAMCbazSsri9XevZ9DVbBzhTLanTM4TJGyqUcjWUcTWSTOnNgUGPp7irY4f6EIhECWABAuAhqAQDA9lQuSlPP+EO+rjwlzV3yj2cPuAO+7v2w+3HijNQ/HN79AgDaWGtVKNXag1ev7+vVBf/xdLGkat12XT88kGhVvN4xnm4N4BoLVMHmUgnFogziAgBsDwS1AABg66ssuaFssK/s7EVJ3i/umX1uGHv6ca9S9qw0MBLqLQPAbrZUrvXs++r2hfUHdJWqja5rM8lYK2h98MiA+zjt+EFsxtFo2pETYxAXAGBnIagFAABbS2VZuvZsoH3BeWn2ecl6v8ynJ9ww9u4PuVWye85IqXy49wwAu0SpWtdM0W87cLU1kKsUaE1Q1mK51nVtXzyq8azb9/X0vsFA+4H2Kti+BAEsAGB3IqgFAADhqZYCoaxXLTvzfcnW3eMDeWnvvdKpx9xwds8ZKT0e7j0DwA5UrTfaAtjVqmBvLFe7rk3EIm7omk7q5HhG/+q40xq+NZb2q2BTToxBXAAArIGgFgAAbI5a2Qtlz/vB7MwFqeFVXfXn3DD2jh8OhLIT7iQYAMBtqTes5hbLbT1grxXKbVWw08WS5pYqsh1tYKMRo3zaUT6T1MGRfj1weLhVBetvjrJ9cQJYAADWAUEtAABYf7WKNP299kFf174nNbxKrL5hN4w9/l43kN1zVsrsJZQFgJtkrdX15aoXtnoVr11BbFkzi2XVG+0JrDFSLuVoLONoIpvU6f2DgQDWUT7thrAjAwlFIrwuAwCwWQhqAQDAW1OvStMX2gd9XXtOqlfc48lBN4h9+694lbJnpex+QlkA6MFaq0Kp1gpam+FrK4htBrDFsir17kFcQ/3xVt/XE+Ppnj1gc6mEYtFICM8OAACshaAWAADcvHrNHewVHPR19btSvewed7LSntPSQ59wA9mJsXEV4wAAIABJREFUM9LQIUJZAJC0XKn54WtbFWxZ1xaa1bAllardAWw6GWtVvD54eLjV9zVYBZvPOHJiDOICAGC7IqgFAAC9NerS7Avtg76ufleqrbjHE2m3bcEDv+hXyg4dliJUaQHYXUrVetsgrl5VsNOFsorlWte1ffGoxrNJ5dOOTu/zWxAEq2DzGUf9CX51AwBgp+O7PQAAcEPZuUuBQV9PSVefkarL7vH4gDRxWjr3UX/Q1/BRQlkAO1q13tBsYBBXsx3B1WBFbLGkG8vVrmsT0YjyXuh6YjythydHW9WvrSrYTFJpJ8YgLgAAIImgFgCA3afRkOZf8gPZqfPS1NNSZdE9Hu+Xxu+R7v2IP+hr5JgU4e20AHaGesNqbqkcqHj1gthioCK2UNbcUlm2fQ6XohGjfNoNWQ+O9OuBw8Ot0LUVxKaTGuyPE8ACAIBbQlALAMBOZq0byrYGfXmhbLngHo8l3VD2zL/12xfkjhPKAtiWrLW6sVzVtWJJVxeCPWDd4LVZETuzWFa90Z7AGiONDDititd79mWVTyc7qmCTGh5IKBohgAUAAOuPoBYAgJ3CWunGq+2DvqbOS6UF93jUkcbvku75SX/Q1+gdUpQfBwBsbdZaFcu1VtC6WhXsdKGsSr17ENdQf7zV9/X4WLqt9UDzcS7lKB6lnQsAAAgPv5kBALAdWSstvN4+6GvqvLRy3T0eibuh7F0fdAPZPWel/EkpGg/3vgGgw3KlFghfg1Ww7X1hV6r1rmvTTqzVB/b+Q8PdPWDTSY2mHSXjvEsAAABsfQS1AABsddZKhTfaB31NnZeW59zjkZiUPyWd/IA/6Ct/Soo54d43gF2tXKtrulDu6vs67bUiaLYmKJZrXdcm4xGNexWwd+8b1A+mHa8i1m9BkE87GnD4dQYAAOwc/GQDAMBWU5hqD2SvPCUtzbjHTNQNYU+83x/0lb9TiifDvWcAu0a13tDsYjnQbiAQxBabfWBLur5c7bo2EY20wtYT42k9PDnaXQWbSSrtxBjEBQAAdh2CWgAAwlS81h7IXnlKWrzmHjMRafSkNPkef9DX2J1SvC/cewawIzUaVnNLla6+r8G2BNcKZc0tlWXb53ApGjEaTbmDuPYP9+vcoSGNeYO4glWwQ/1xAlgAAIBVENQCALBZFmfaA9kr56XiFe+gkUZPSEff5Q/6Gr9bSvSHessAtj9rrW4sV3WtbehWdxXsTLGsWqM9gTVGGhlwWhWv9+zLKp9OtlXB5jOORgYcRSMEsAAAAG8FQS0AABthaU6aCgz6unJeKlz2DhopNykdftgf9DV+t+SkQr1lANuLtVaL5Vqr7+vVQPg63RbKllWpN7quH+yPayztBq2T+ZwfvKb9EHY07SgejYTw7AAAAHYfgloAAN6q5Xlp6un2vrI3XvOPDx+VDjzkty+YuEdy0uHdL4Atb6VS99sNBPq++kGs+3G5Uu+6Nu3EWu0G7j807D7uqIIdTTtKxqMhPDMAAACshqAWAIBbsXLDD2WbbQyuv+IfHzos7T0n3f9xL5Q9LSWzod0ugK2lXKtrulDu6gE7XSi1tSYolmpd1ybjETdsTSd1556M3nVHPjCEy+sHm3Y04PAjPgAAwHbET3EAAKymVHBD2WBf2fmX/OODB6U9Z6T7ft4PZfuGQrtdAOGp1RuaXay8aRXs/FKl69p41LTaDUzmU/qBY7muKth8JqlMMsYgLgAAgB2MoBYAAEkqL0pXn2kf9DV30T+e3e+Gsmd/1u8r2z8c3v0C2BSNhtXcUqWr72tnFezsYlm2fQ6XIkYaTbsVr/uH+3Xu0FArfM0HKmEH++KKMIgLAABg1yOoBQDsPpUl6ep3A4O+npJmX5DkpSyZvW4Qe89PeX1lz0gDuVBvGcD6stZqYaUaCF79vq/XCiVd9YLYmWJZtYbtuj6XSrSqYO/ak1W+2f81UAU7knIUJYAFAADATSKoBQDsbNUV6eqz7YO+Zr4vWW8CenrCrZC964N+KJvKh3vPAG5brd7Q3FKl1Qd2uljWTNF7XCi3Pp8pllWpN7quz/bFNe5VvE7mc60esM1QdiyTVC7lKBGLhPDsAAAAsJMR1AIAdo5qSbr2nHTlO15f2fPS9AXJelPRB/JuGHvyA24gO3FGykyEe88AbspKpe4Hrl7/Vz+EbQawJc0tVbpaEEjSUH9c+XRSo2lHR3IDGs04beHrWNoNZ5Px6OY/OQAAAEAEtQCA7apWdkPZ4KCv6QtSw5uU3j/ihrIn3ucN+jojZfZIDOIBtoxm+4FeVa/N4HW6WNZMoaxiudZ1fSxilEs5ymcc7R1M6sz+rEbTSeXTjvJpR6NpdwhXLpWQEyOABQAAwNZGUAsA2PpqFWnmQvugr2vPSY2qe7xvyA1j3/Eef9BXdh+hLBCSzvYDftVrR/uBxbIqte72A33xqPIZN2w9OZ7RI5Ne6OoFr80Qdrg/wRAuAAAA7BgEtQCAraVedXvINgPZK09J156V6hX3eDLrBrFv+6TXU/asNHiAUBbYBKVqva3363ShpJnFckcVbFlzS+VV2w+4gWuy1X5gNOWHr80gdiARleH/aQAAAOwyBLUAgPDUa9LsC4FKWS+UrZXc405GmjgtPfhL/qCvocOEssA6arYfWK3qdfpN2g9EI0ajKbfCdU+2vf1AsAqW9gMAAADA2t40qDXGfF7Sj0iattbe5e0blvQVSYckvSLpJ621171jvy7pY5Lqkv6dtfZvvP33SfqCpD5JfyXpV63tVWsBANiRGnVp9qIbxjb7yk49I9VW3OOJlNu24P6P+5WyQ4elCJPVgdsRbD8wsxjs/RpoRVC4ufYDd4yn9cjkaFvwOur1hqX9AAAAALA+bqai9guSfk/SlwL7Pi3p7621nzHGfNr7/FPGmFOSHpd0p6Q9kv7OGHPcWluX9PuSnpD0DblB7aOS/nq9nggAYAtpNKS5S+2DvqaekapL7vH4gFspe+4X/FB2+CihLHATgu0HVq+CXb39wGB/3GszkNSDhwc0Ghi6FayCTTkx2g8AAAAAm+hNg1pr7T8ZYw517H5M0ju9x1+U9I+SPuXt/7K1tizpZWPMJUkPGGNekZSx1v6zJBljviTpx0RQCwDbX6MhXX+5fdDX1NNSpegej/VJE/dI9/6cP+grNylFeAs00GStVWGl5vd+bYawgSrYZghbLPVuP5BLJZRPJzWRTeq0136gVQFL+wEAAABgy7vdHrVj1topSbLWThlj8t7+vXIrZpsue/uq3uPO/T0ZY56QW32rAwcO3OYtAgDW3dKcWyk7d0maueCHsuWCezyWlMbvlk4/7lfK5o5LUVqiY3eq1RuaX6qs2fv1ZtoPjKa62w80h3LRfgAAAADYGdb7N+devyHYNfb3ZK39nKTPSdK5c+foYwsAm6myLM2/5AeywW3lun9eNCGN3SXd/SF/0NfoHVI0Ht69A5uk2X6gs/drM3Rt7ptfKqvxJu0HHjg84Aevgd6vtB8AAAAAdpfbDWqvGWMmvGraCUnT3v7LkvYHztsn6Yq3f1+P/QCAMDTq0o3XpLkXvRD2ovfxRWnh9fZz03uk3DHpzh+XRo5JI5PSyFFp8CCVsthROtsPzKxRBXvT7QdSjka93q/NMHY07dB+AAAAAECX2/0N+2uSPiLpM97Hvwzs/2NjzGflDhOblPQta23dGFM0xjwk6ZuSPizpd9/SnQMA1mattDQbqIi96Aez8y9J9Yp/rpN1w9iDb/fCWG8bPiI5qfCeA7AO6g2rucWOqtdid+/X6WLv9gPJeMRtMZB2dGI8rYe99gN+/1e3/cBQf0JR2g8AAAAAuE1vGtQaY/5E7uCwnDHmsqT/Ijeg/aox5mOSXpP0IUmy1j5njPmqpO9Jqkn6pLW27n2pT0j6gqQ+uUPEGCQGAOuhshSojH0xUB17SSot+OdFE27wOnJMOv7eQHXsMWkgJ/H2amwzpWp9zarXm2k/0GwzcP+h4baK12b4SvsBAAAAAJvFWLu1W8CeO3fOPvnkk2HfBgCEq16Tbrzau1VB4Y32c7P73dYErcrYZquCA1KEt1tjawu2H2irem3r/Xpz7Qf8ile//UBwEBftBwAAAABsNmPMt62153odo7kgAGwV1kqL072HeM2/LDWq/rnJQSk3KR1+xAtlJ/1WBYn+8J4DsIpg+4HOKthgKDtTLKt8E+0HfuBYzh28FWg/MJp2NDxA+wEAAAAA2xNBLQBstvJiIIR9sb06tlzwz4s6bgg7ekK644f9MHbkmNQ/TKsCbAmd7Qc6q16b++YWb7/9wGjaUZr2AwAAAAB2OIJaANgI9ap0/dXeg7yKU4ETjTS43w1fTz/ePsgru49WBQiFtVaFUk0zq1S9NoPYmWJZhTXaD4ymHY1nk7pnX7bVfqAZytJ+AAAAAADaEdQCwO2yVipe7d2q4PorUiMQYPUNu60Kjr6ro1XBYSneF9pTwO6yVvuB4BCum2k/cHysu/1AswqW9gMAAAAAcOsIagHgzZQKgTYFHYO8Kov+ebGkG76O3Smdeqx9kFf/cHj3jx0v2H6gV9XrtLet1n4g2xd3+7x67Qc6g1faDwAAAADAxiOoBQBJqlXcKtherQoWrwVONNLgAbc69sDb2lsVZPZKkUhYzwA7TK/2A8GqV78fbKln+4GIkXJem4GxTFJ3780GAthkq/1ALuUoGaf9AAAAAACEjaAWwO5hrVS40nuQ1/VXJVv3z+3PuWHs5A8FKmOPSUOHpHgytKeA7a/esJpb8oLWVapgb6b9wGja0WQ+pXccHXHbD6QcjXrhK+0HAAAAAGD7IagFsPOs3Ai0KehoVVBd9s+L97ttCSZOS3d9MBDIHpH6hsK7f2xLfvuBslsFW2wPY2+2/cBo2tG5g0PKZ5Ktz4MtCGg/AAAAAAA7E0EtgO2pVpbmX+7dqmBpxj/PRKTBg2517KGH2wd5pSdoVYA1tbUfaLYe6Kh6dQPZm2s/cNeebKvlQLP9wGjKDWJpPwAAAAAAuxtBLYCtq9GQCm/0blVw4zXJBt4WPpB3w9gT7wv0jZ10WxXEEqE9BWxNne0HOnu/ThdLXv/X3u0HnFjEC1yTrfYDzapX2g8AAAAAAG4HQS2A8K1cl2Yv9aiOfVGqrfjnxQfciti990n3/FQgkD0qJbPh3T+2jF7tB3pVwc6u0X6gOXDrvgO0HwAAAAAAbB6CWgCbo1qS5l8KhLGBbXnOP89E3SrY3KR05J0drQrGJcKxXcdvPxAYvFXoGMJ1k+0H8mmnq/1AM5il/QAAAAAAIEwEtQDWT6MhLbze3aZg7pJ043VJgRLG1Lgbxp780cAQr2PS0EEpGg/tKWDztLUfWCxrplfv12LpptsPvP3oSKvlwGjGD19HBhzaDwAAAAAAtjyCWgC3bmmuoyo20KqgXvbPS6Sl3DFp/4PSmZ9pb1XgpMO7f2yozvYD/sCt9t6vc0sV1Xv0H8gkY62WA832A6NeRWyrD2zaUSZJ+wEAAAAAwM5BUAugt+qKF772aFWwct0/LxKThg671bHH3t1eHZvK06pgh7DWqrBS08xiya+ADbQcmG4Fs2UtrFS7rm+2H2i2GbhzIth+wG1BQPsBAAAAAMBuRlAL7GaNunTjtY5A1quOXXi9/dz0Hrc69s4fD4SxR6XBg1KUl5LtqlSta9YLXWcCYetMYF9zq9TXbj9wbLSj/UBzAFeG9gMAAAAAALwZ0hVgp7PWHdY1e7G7Mnb+Jale8c91sm4Ye/DtfouCkUlp+IjkpMJ7DrgljYbVjZVqa9hWcOsMYntVvxojDfcnWkHrkdEB93HKabUhaB6j/QAAAAAAAOuDoBbYKSpLgcrYjkFepQX/vEjcDV5zk9Lx97a3KhjI0apgC1up1L2QtdQeunaEsLOLZdV69H7ti0fdPq8pR8fHUnrH0ZFW4Brs/To8kFA8GgnhGQIAAAAAsHsR1ALbSb0m3Xi1d6uCwhvt52b2udWxd38oMMTrmJTdT6uCLaTesJpfqrRXv3rDtppVr7Pe/mK51nV9sPfraNrRHePpVh/Y0WD7gbSjAYf/7gAAAAAAbFX81g5sNdZKSzOrtCp4WWoE3qqezLrVsIcf8doUHPNbFST6w3sOu5y1VkvN6tdiuasFQbD9wNxiWT2KX5V2Yq2Q9dSejF/52tF+YHggQe9XAAAAAAB2AIJaICzlxUAI+2J7dWy54J8XddzgdfSEdMcPt7cq6B+mVcEmqtUbml2stLcfKHQM3vKqYVeq9a7rYxHTClwnskndsy/rVb62tx/IpRz1JaIhPEMAAAAAABAWglpgI9Wr0o3XelfHFqcCJxq3JUHumHT68fZBXtl9UoTQbqNYa1Uo1bqrXxfLXYO45pcrsj2qX7N98Vbgemb/YKva1e0H67cfGOyLK0L1KwAAAAAA6IGgFlgPjbp0/RVp+oI0c8H9OH3BDWiDrQr6ht0Q9ui7OloVHJbifaHd/k5UqTW6wtbV+sBWao2u6xOxSCtw3T/cr/sODnW3H0g7yqUScmIE6QAAAAAA4K0hqAVuhbXSwmUviP2eNPN97+PzUq3knzd4QMqfkibf47YsaA7y6h8O7953AGutbixXAyFrqbvvqxfC3liu9vwawwOJVvXr4cMDbW0HgkO4MsmYDG0lAAAAAADAJiGoBXqxVlq85lfGtqpkvy9Viv556T1S/g7p/o9Lo3e44ezoCclJhXfv21CpWu8IW7vbD0wXy5pdLKta7+49kIxHlE+7Fa7H8im97ehIz/YDI6mE4tFICM8QAAAAAABgbQS1wPK8XyE7fcGvkl257p/TP+KGsGd+WsqflEZPugFt31B4973FNRpW88uV7mFbPfrAFku1rusjRhpJOa3AdXIs3T54K9B+YCARpfoVAAAAAABsawS12D1KBS+E7aiSXbzmn+Nk3SD21GNuMNsMZVOj4d33FrNUrrWFrNOFUlfl60yxrLmliuqN7urXlBNrha0nJzJ6ZDLYcsDfRgYcRRm8BQAAAAAAdgmCWuw8lWVp9nm3TUGwSnbhdf+ceL/bquDYD7phbDOQzeyRdmFlZq3e0PxSpavPa68QdrlS77o+FjHKeZWvY5mk7t6b7ah8ddsP5NIJ9Sd42QEAAAAAAOhEYoLtq1aR5i76FbLN9gXXX5HkVXJGE1LuhHTgbVL+F7wesndIgwelyM7uVWqtVbFZ/do5bCsQxM4uutWvtrv4VZlkzG0vkHJ0z77BnpWvoylHQ/0JRah+BQAAAAAAuG0Etdj66jXp+stedWygSnb+Ranh9TY1UWnkmDRxWjr9uFcle0oaOixFd9Yyr9Qaml3sDFzLmlkste2bKZZVqja6rk9EIxpNO8qlHe0f7te9B4f8wVuBADaXcpSMR0N4hgAAAAAAALvPzkqwsL01GtLCa4HBXl4/2dkXpHrZO8lIQ4fcEPbkj/ptC0aOSTEnzLt/S6y1Wlipdle+LnYM3yqWdX252vNrDA8kWoHruYMDXYO33CA2qUxfjMFbAAAAAAAAWwxBLTaftVLhij/MqzXc63mpuuSfl9nnhrBH/7UfyOZOSIn+8O79FlhrVap61a89hm21gthCSbOLFVXq3dWvTiyifMYNWI/kUnrw8Miqg7cSsZ3dygEAAAAAAGAnI6jFxmjUpeKUdON1d4jXjdfcbeZ5N5QtL/jnpsbcvrH3fjgw2OuElMyGd/+eRsNqsVLTwnJVCytVFUpVFVa8xys1Laz4+5uPm8cKK9We4asx0siAH7JO5lNtVa/BIDblUP0KAAAAAACwGxDU4vZUS1LhDTd8XXg9EMi+7rYvKFzx+8c29eek3HHp7p8IBLInpYGRjb3VesMPV0uBcDXwMRi0BgPYYqmqRo8hW03RiFEmGVO2L65MX1zZvrj2DPa5nyfdz0dSiVYQm087Gh5IKBal+hUAAAAAAAA+glr01mi4QezsC9L8S92B7OK19vNNRErvkQb3S/sfcj9m93sfD0jZfbfdsqDZQqCtanV59XC1veq1qqVKfc2vn4hFlPVC1mxfXLlUQkdHB9rC19ZHL3zN9seVScaoeAUAAAAAAMC6IKjd7SpL0twlafaiu81ddMPZ2UtSbcU/L+q4YevgfmnyPdLggUAQu1/K7JGi8VX/mFq9oaVgFWsgTF2rjcBaLQSCUk5MmWSsFageGO5vPXYD1piy/cHP/QA2GY+u198mAAAAAAAAcFsIanc6a6WV64Fq2MtuhezsC25Au/B64GTjBrC549KhR1QdOqobA4c05+zXQmRIS9WGlsp1LVdqWizXtTxX09JUXcuVohbLz2m5XNdSpaalck3LlebjupbKNZVrawetEaOuqtU92T5l+uLK9MW6AtZsoNI1k4zRSgAAAAAAAADb2qYHtcaYRyX9jqSopP9lrf3MZt/DjtJouG0IggO72toUXJaqS+2XxAe0nD6sG5kzms59QG9E9+pFu1cXq6O6umI0O1XW3MWKFss1SSVJF1f94xPRiAacqPoTsdbHlBNTLuVowImpPxFVyom1jmeS8a7wNdsX10AipkiEFgIAAAAAAADYnTY1qDXGRCX9D0k/JOmypH8xxnzNWvu9zbyPLctatxXByry0PCctz8suz6tanFV1cVa14owaS7Myy7OKrswrVpqXU72hiG3vwVqIZHXNjGpKOb1hJ/VqY1iv1If1RiOnN2xO86W0VPRD0YiRhgcc5VJ1jaQSOr1vUCOphHIpR7lUQoP9CaWdmPqdmAYSUQ04MQ0kYupLRJWIUckKAAAAAAAAvFWbXVH7gKRL1tqXJMkY82VJj0kiqJV0+cI3tO+rj7btM5ISkmLW6IYGdN2mNaeM5u2g5u1+zSutq3ZY8/ExLST2aDE5rnhfWulkTKlkXCmvyvW4E9NZL2TNeAOzcilHIwNuEBulmhUAAAAAAAAIzWYHtXslBZuiXpb0YOdJxpgnJD0hSQcOHNicO9sCzNABfTn7cVUSQ6olB2X7RqT+IUX6c4qnhpTqSyrtDcw6lIzp7mTcDWRpGwAAAAAAAABsa5sd1PZKE23XDms/J+lzknTu3Lmu4zvV3om9evzf/3bYtwEAAAAAAABgk212g9HLkvYHPt8n6com3wMAAAAAAAAAbCmbHdT+i6RJY8xhY0xC0uOSvrbJ9wAAAAAAAAAAW8qmtj6w1taMMb8i6W8kRSV93lr73GbeAwAAAAAAAABsNZvdo1bW2r+S9Feb/ecCAAAAAAAAwFa12a0PAAAAAAAAAAAdCGoBAAAAAAAAIGQEtQAAAAAAAAAQMoJaAAAAAAAAAAgZQS0AAAAAAAAAhIygFgAAAAAAAABCRlALAAAAAAAAACEjqAUAAAAAAACAkBHUAgAAAAAAAEDIjLU27HtYkzFmRtKrYd/HBshJmg37JrCjscaw0Vhj2GisMWw01hg2GmsMG401ho3GGsNG241r7KC1drTXgS0f1O5UxpgnrbXnwr4P7FysMWw01hg2GmsMG401ho3GGsNGY41ho7HGsNFYY+1ofQAAAAAAAAAAISOoBQAAAAAAAICQEdSG53Nh3wB2PNYYNhprDBuNNYaNxhrDRmONYaOxxrDRWGPYaKyxAHrUAgAAAAAAAEDIqKgFAAAAAAAAgJAR1IbAGPOoMeZ5Y8wlY8ynw74fbC3GmP3GmH8wxlwwxjxnjPlVb/9vGmPeMMac97b3B675dW89PW+MeW9g/33GmO96x/67McZ4+x1jzFe8/d80xhwKXPMRY8xFb/vI5j1zbCZjzCve2jhvjHnS2zdsjPm699/+68aYocD5rDHcNGPMicBr1XljTMEY82u8juGtMMZ83hgzbYx5NrAv1NctY8xh79yL3rWJjf57wMZZZY39N2PM940xzxhj/sIYM+jtP2SMWQm8nv1B4BrWGHpaZY2F+r2RNbazrLLGvhJYX68YY857+3kdwy0xq2cV/Dy2nqy1bJu4SYpKelHSEUkJSU9LOhX2fbFtnU3ShKR7vcdpSS9IOiXpNyX9xx7nn/LWkSPpsLe+ot6xb0l6myQj6a8lvc/b/8uS/sB7/Likr3iPhyW95H0c8h4Phf13wrYh6+wVSbmOff9V0qe9x5+W9FusMbZ1WGtRSVclHeR1jO0trqVHJN0r6dnAvlBftyR9VdLj3uM/kPSJsP+e2NZ9jb1HUsx7/FuBNXYoeF7H12GNsd3KGgv1eyNrbGdtvdZYx/HflvSfvce8jrHd6vpaLavg57F13Kio3XwPSLpkrX3JWluR9GVJj4V8T9hCrLVT1trveI+Lki5I2rvGJY9J+rK1tmytfVnSJUkPGGMmJGWstf9s3VesL0n6scA1X/Qe/x9J7/b+Beu9kr5urZ231l6X9HVJj67zU8TWFVwXX1T7emGN4Xa9W9KL1tpX1ziHNYY3Za39J0nzHbtDe93yjr3LO7fzz8c21GuNWWv/1lpb8z79hqR9a30N1hjWssrr2Gp4HcMtW2uNef+9f1LSn6z1NVhjWM0aWQU/j60jgtrNt1fS64HPL2vtEA67mFfmf1bSN71dv2Lct959PvB2gtXW1F7vcef+tmu8Xz4WJI2s8bWw81hJf2uM+bYx5glv35i1dkpyvwlLynv7WWN4Kx5X+y8EvI5hPYX5ujUi6UYgxGOt7XwflVv103TYGPOUMeb/GWMe9vaxxnA7wvreyBrbXR6WdM1aezGwj9cx3JaOrIKfx9YRQe3mMz322U2/C2x5xpiUpD+T9GvW2oKk35d0VNIZSVNy37Yirb6m1lprt3MNdpZ3WGvvlfQ+SZ80xjyyxrmsMdwWrz/UByT9qbeL1zFsls1YU6y1XcQY8xuSapL+yNs1JemAtfaspP8g6Y+NMRmxxnDrwvzeyBrbXX5a7f94zusYbkuPrGLVU3vs43XsTRDUbr7LkvYHPt8n6UpI94ItyhgTl/vC90fW2j+XJGvtNWtt3VrbkPQ/5bbRkFZfU5clU8iXAAACjElEQVTV/va84FprXWOMiUnKyn2LDOtzl7DWXvE+Tkv6C7nr6Zr3NpTmW56mvdNZY7hd75P0HWvtNYnXMWyIMF+3ZiUNeud2fi3sIN7Akh+R9DPeWzTlvY1zznv8bbl9946LNYZbFPL3RtbYLuH9N/43kr7S3MfrGG5Hr6xC/Dy2rghqN9+/SJr0ptIl5L4l9Gsh3xO2EK/Hyh9KumCt/Wxg/0TgtB+X1Jzk+TVJj3vTEQ9LmpT0Le8tB0VjzEPe1/ywpL8MXNOckvgTkv6v94vH30h6jzFmyHvb1Xu8fdhBjDEDxph087Hc/87Pqn1dfETt64U1htvRVrnB6xg2QGivW96xf/DO7fzzsUMYYx6V9ClJH7DWLgf2jxpjot7jI3LX2EusMdyqML83ssZ2lR+U9H1rbevt5ryO4VatllWIn8fWl90CE8122ybp/XKn470o6TfCvh+2rbVJ+gG5pfrPSDrvbe+X9L8lfdfb/zVJE4FrfsNbT8/Lm5bo7T8n94e9FyX9niTj7U/KfSvyJbnTFo8Ervmot/+SpF8I+++DbUPW2BG50zeflvRc83VIbn+fv5d00fs4zBpjewvrrF/SnKRsYB+vY2xvZU39idy3aVblVlV8LOzXLe/19Fve/j+V5IT998S27mvsktyeeM2fyZqTqD/ofQ99WtJ3JP0oa4ztNtdYqN8bWWM7a+u1xrz9X5D0Sx3n8jrGdqvra7Wsgp/H1nFr/kUAAAAAAAAAAEJC6wMAAAAAAAAACBlBLQAAAAAAAACEjKAWAAAAAAAAAEJGUAsAAAAAAAAAISOoBQAAAAAAAICQEdQCAAAAAAAAQMgIagEAAAAAAAAgZAS1AAAAAAAAABCy/w+tJEVi2ts/QgAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "for how in [\"inner\", \"left\", \"outer\"]:\n", - " make_fig(how)\n", - " plt.legend()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 106, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABWYAAAH0CAYAAAC+UfTcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzdeVTWZd7H8feloCi4L42miWaugICAoqikqTUulWlqZqltllaW9WQ1lk02zUyUpWlWM0ZNZqZmpmlRk7grApK7uZFr5hIqKMpyPX/ctwwoIBJwg35e53gGftf2/f3u+znnOZ8urp+x1iIiIiIiIiIiIiIiJaecqwsQERERERERERERudYomBUREREREREREREpYQpmRUREREREREREREqYglkRERERERERERGREqZgVkRERERERERERKSEKZgVERERERERERERKWEKZkVERERcxBiTbIxp4uo6AIwxdxpj9jtrCnB1PdcKY0ykMWaiq+soSsaYcGPMAVfXISIiIlLaKZgVERERKWLGmGhjzIOX62et9bLW7imJmgogAhjtrGmDq4sREREREbnaKZgVERERuYYZY9ycPzYCthRyjvJFV5Fcy7J9H0VERESuegpmRURERIqRMeYhY8wuY8wJY8zXxpj62dqsMaap8+dIY8xUY8w3xpjTxph1xpgb85jT2zn2YWPMIWPMYWPM2Gzt5Ywx44wxu40xx40xXxhjal409gFjzD5ghTEmGSgP/GSM2e3s19K58zfJGLPFGNM32/yRxpj3jDGLjTEpwM3Oa9OMMUucxyGsMsb8yRjztjHmd2PM9uxHJGSr77QxZqsx5s5sbcOMMSuNMRHOsXuNMbdla69pjPnIee+/G2O+ytbW2xiT4Kx7tTHGL1vbc8aYg841dxhjuuXxfCONMdONMd87+y4zxjTK1v6O89iHU8aYOGNMp2xtE5zP+xPn2C3GmKBs7QHGmHhn22zAI1tbDWPMImPMUed9LTLGNLjouexxjt1rjBmSS+31jTFnL3ze2dY8ZoxxN8Y0dd7PSee12Xk8g8t9xyo6P9tDzn9vG2Mq5jHX5T7rVcaYScaYE8CEgtYoIiIiUtYpmBUREREpJsaYrsDrwN1APeAX4PN8hgwGXgFqALuA1y6zxM3ATUAPYJwx5hbn9SeAO4AuQH3gd2DqRWO7AC2BrtZaL+e1NtbaG40x7sBCIAqoCzwOzDTGNM82/h5nfVWAlc5rdwN/AWoD54A1QLzz97nAW9nG7wY6AdWc9/ypMaZetvZ2wA7n2H8C/zbGGGfbf4DKQGtnfZMAjDGBwAzgEaAW8D7wtTNEbA6MBoKttVWAnkBibg/VaQjwqnP9BGBmtrb1gD9QE/gMmGOM8cjW3hfH51wd+Bp411lfBeArZ/01gTnAXdnGlQM+wrF7+QbgbLaxnsBk4DZn/R2cdeVgrT2E47lnn/ceYK61Ns15T1E4vmMNgCn5PAPI+zv2ItDe+RzaACE4PvvcFOSz3oPjs3ytEDWKiIiIlEkKZkVERESKzxBghrU23lp7DngeCDXGeOfR/0trbYy1Nh1HEOh/mflfsdamWGs34Qj0BjuvPwK8aK094Fx3AtDf5Pwz8QnOsWdzmbc94AX83Vp73lr7I7Ao2/wAC6y1q6y1mdbaVOe1+dbaOOfv84FUa+0n1toMYDaQtWPWWjvHWnvIOX42sBNHuHfBL9baD51jP8YRbF/nDPRuA0Zaa3+31qZZa5c5xzwEvG+tXWetzbDWfowjIG4PZAAVgVbGGHdrbaK1dnc+z/Yba+1y5/N7Ecfn1tBZ+6fW2uPW2nRr7ZvOebOH1iuttYudtf8HR3B54bm6A287656LI+S98EyOW2vnWWvPWGtP4wgpu2SbNxPwMcZUstYettbmdfTEZzg/K2eYPch5DSANR/Bb31qbaq1dmfsUWfL6jg0B/mqt/c1aexRH4Do0twkK8FkfstZOcT7Ps4WoUURERKRMUjArIiIiUnzq49glC4C1Nhk4DlyfR/9fs/18Bkc4mp/92X7+xbkeOEKt+c4/508CtuEIJq/LY2xude+31mZeNH/2unMbfyTbz2dz+T3rfowx92U7ciAJ8MGxO/WCrGdhrT3j/NELaAicsNb+nsv6jYCxF+Z0ztsQR8C3CxiDI6T+zRjzucl2rEQusu7P+bmdwPl8jTFjjTHbnH9qn4RjJ2iuteP4HD2coXh94KC11mZrz/p+GGMqG2PeN8b8Yow5BSwHqhtjyltrU4CBwEjgsHEcedEij9rn4giS6wOdAQuscLb9H2CAGOcxCyPyeQY5ngM5v2M5vtsXteVQgM/64u/SldYoIiIiUiYpmBUREREpPodwhIVA1p+j1wIOFtH8DbP9fINzPXAEXbdZa6tn++dhrc2+bvZwMLe6Gxpjsv//ijeQs+78xufLeV7rhziOFqhlra0ObMYRxl3OfqCmMaZ6Hm2vXXTfla21swCstZ9Za8NwfCYW+Ec+62Q9W2OMF46jBw45z5N9DsexDTWctZ8sYO2HgeuzHckAjud6wVgcO2/bWWur4ghVuTC3tfY7a213HLuHt+N4hpew1ibhOArgbhzHGMy6EAZba3+11j5kra2PY2f1NOM85zgPeX3Hcny3L2rLUsDPOsd3qRA1ioiIiJRJCmZFREREis9nwHBjjL/zxUh/A9ZZaxOLaP7xzl2WrYHhOI4LAJgOvOYMxTDG1DHG3H4F864DUoD/c74wKhzoQ/7n414JTxxh3FFnfcNx7KK8LGvtYWAJjrCuhrO+CwHmh8BIY0w74+BpjOlljKlijGlujOnq/BxScezgzchnqT8bY8Kc58K+iuNz24/jTN10Z+1uxpiXgKoFvO81zrFPGGPcjDH9yPkn/VWcdSUZx8u7Xr7QYIy5zhjT1xnunwOSL1P/Z8B9OM6avXCMAcaYAeZ/LxT7HcfnkN88eX3HZgF/cX63agMvAZ/mMv6KP+tC1CgiIiJSJimYFRERESke1lr7X2A8MA/HbskbcZz3WVSW4XhJ2H+BCGttlPP6OzheOhVljDkNrMXxgqWCFn4exwusbgOOAdOA+6y124uiaGvtVuBNHEHlEcAXWHUFUwzFcQ7pduA3HEcUYK2NxXHO7Ls4Ar1dwDDnmIrA33Hcz684XjT1Qj5rfIYjGD0BtMVxpirAdziC4Z9x/Pl+KvkfC5HF+Vz7OWv6HcfRBF9m6/I2UMlZ41rg22xt5XDsqD3krKkL8Fg+y32N46VdR6y1P2W7HgysM8YkO/s8aa3dm888eX3HJgKxwEZgE46XvE3M5Z4L81lfaY0iIiIiZZLJecSViIiIiPxRxph4HC9G+qqY5vcG9gLuzheFSREyxkQCB6y1f3F1La6i75iIiIhI8dOOWREREZEi5PyT75bABlfXIiIiIiIipZeCWREREZEiYoz5B46XLj1nrf3lcv1FREREROTapaMMREREREREREREREqYdsyKiIiIiIiIiIiIlDAFsyIiIiIiIiIiIiIlzM3VBeSmdu3a1tvb29VliIiIiIiIiIiIiPwhcXFxx6y1dS6+XiqDWW9vb2JjY11dhoiIiIiIiIiIiMgfYozJ9cXAOspAREREREREREREpIQpmBUREREREREREREpYQpmRUREREREREREREpYqTxjNjdpaWkcOHCA1NRUV5ciZZyHhwcNGjTA3d3d1aWIiIiIiIiIiMg1qswEswcOHKBKlSp4e3tjjHF1OVJGWWs5fvw4Bw4coHHjxq4uR0RERERERERErlFl5iiD1NRUatWqpVBW/hBjDLVq1dLOaxERERERERERcakyE8wCCmWlSOh7JCIiIiIiIiIirlamgllX8/LyumyfyZMn07JlS4YMGUJ0dDSrV68ugcryl5iYiI+Pj6vLEBEREREREREREScFs0Vs2rRpLF68mJkzZ5aaYFZERERERERERERKFwWzhfTGG28QHByMn58fL7/8MgAjR45kz5499O3bl0mTJjF9+nQmTZqEv78/K1asyDF+2bJl+Pv74+/vT0BAAKdPnyY6OprOnTtz55130qpVK0aOHElmZiYAUVFRhIaGEhgYyIABA0hOTgYgLi6OLl260LZtW3r27Mnhw4ezrrdp04bQ0FCmTp2atW5kZCSjR4/O+r13795ER0cDjh3BY8eOJTAwkG7dunH06NFie34iIiIiIiIiIiLXMgWzhRAVFcXOnTuJiYkhISGBuLg4li9fzvTp06lfvz5Lly7lqaeeYuTIkTz11FMkJCTQqVOnHHNEREQwdepUEhISWLFiBZUqVQIgJiaGN998k02bNrF7926+/PJLjh07xsSJE/nhhx+Ij48nKCiIt956i7S0NB5//HHmzp1LXFwcI0aM4MUXXwRg+PDhTJ48mTVr1hT4vlJSUggMDCQ+Pp4uXbrwyiuvFN1DExERERERERERkSxuri6gMF5ZuIWth04V6Zyt6lfl5T6tC9Q3KiqKqKgoAgICAEhOTmbnzp107ty5wOt17NiRp59+miFDhtCvXz8aNGgAQEhICE2aNAFg8ODBrFy5Eg8PD7Zu3UrHjh0BOH/+PKGhoezYsYPNmzfTvXt3ADIyMqhXrx4nT54kKSmJLl26ADB06FCWLFly2ZrKlSvHwIEDAbj33nvp169fge9HRERERERERERECq5MBrOuZq3l+eef55FHHin0HOPGjaNXr14sXryY9u3b88MPPwBgjMnRzxiDtZbu3bsza9asHG2bNm2idevWl+yKTUpKumSeC9zc3LKORwBITU3Ns8a85hAREREREREREZE/pkwGswXd2Vpcevbsyfjx4xkyZAheXl4cPHgQd3d36tatm6NflSpVOHUq9529u3fvxtfXF19fX9asWcP27dupXr06MTEx7N27l0aNGjF79mwefvhh2rdvz6hRo9i1axdNmzblzJkzHDhwgObNm3P06FHWrFlDaGgoaWlp/Pzzz7Ru3Zpq1aqxcuVKwsLCmDlzZta63t7eTJs2jczMTA4ePEhMTExWW2ZmJnPnzmXQoEF89tlnhIWFFc8DFBERERERERERucbpjNlC6NGjB/fccw+hoaH4+vrSv39/Tp8+fUm/Pn36MH/+/Fxf/vX222/j4+NDmzZtqFSpErfddhsAoaGhjBs3Dh8fHxo3bsydd95JnTp1iIyMZPDgwfj5+dG+fXu2b99OhQoVmDt3Ls899xxt2rTB39+f1atXA/DRRx8xatQoQkNDs86vBccRCo0bN8bX15dnnnmGwMDArDZPT0+2bNlC27Zt+fHHH3nppZeK4/GJiIiIiIiIiIhc84y11tU1XCIoKMjGxsbmuLZt2zZatmzpoopKRnR0NBERESxatMgl63t5eZGcnOyStUvatfB9EhERERERERER1zPGxFlrgy6+rh2zIiIiIiIiIiIiIiWsTJ4xe7UKDw8nPDzcZetfK7tlRUREREREREREXE07ZkVERERERERERKT4HNsFaamurqLUUTArIiIiIiIiIiIixePnKPggHL7XS+YvpmBWREREREREREREipa1sGoyfHY31GwMHZ9wdUWljs6YFRERERERERERkaKTlgqLxsBPs6DVHXDHNKjg6eqqSh3tmL0CXl5el+0zefJkWrZsyZAhQ4iOjmb16tUlUFlO4eHhxMbGAuDt7c2xY8fy7V+QPsVt2LBhzJ0716U1iIiIiIiIiIhIETh9CHYsgfAXYECkQtk8aMdsEZs2bRpLliyhcePGTJgwAS8vLzp06FCka1hrsdZSrpxydRERERERERERKSV+T4TqjaBmE3g8HjxrubqiUk3JXiG98cYbBAcH4+fnx8svvwzAyJEj2bNnD3379mXSpElMnz6dSZMm4e/vz4oVK3KMnzBhAkOHDqVr167cdNNNfPjhhwAkJyfTrVs3AgMD8fX1ZcGCBQAkJibSsmVLHnvsMQIDA9m/fz+PPvooQUFBtG7dOquG/Hz66aeEhITg7+/PI488QkZGRp59MzIyGDZsGD4+Pvj6+jJp0iTAsRt3zJgxdOjQAR8fH2JiYgBISUlhxIgRBAcHExAQkFV3RkYGzz77bNazev/99wFHuDx69GhatWpFr169+O2337LWzr6DNzY2lvDw8HyfmYiIiIiIiIiIuNiW+TC1Paz/l+N3hbKXpR2zhRAVFcXOnTuJiYnBWkvfvn1Zvnw506dP59tvv2Xp0qXUrl2bkydP4uXlxTPPPJPrPBs3bmTt2rWkpKQQEBBAr169qFu3LvPnz6dq1aocO3aM9u3b07dvXwB27NjBRx99xLRp0wB47bXXqFmzJhkZGXTr1o2NGzfi5+eX61rbtm1j9uzZrFq1Cnd3dx577DFmzpzJfffdl2v/hIQEDh48yObNmwFISkrKaktJSWH16tUsX76cESNGsHnzZl577TW6du3KjBkzSEpKIiQkhFtuuYWZM2dSrVo11q9fz7lz5+jYsSM9evRgw4YN7Nixg02bNnHkyBFatWrFiBEjLvvsc3tm9evXv+w4EREREREREREpBpmZsOwfsOzv0CAEWvZ1dUVlRpkNZge+v+aSa7396jE01Juz5zMY9lHMJe392zZgQFBDTqSc59FP43K0zX4ktMBrR0VFERUVRUBAAODY5bpz5046d+58Rfdw++23U6lSJSpVqsTNN99MTEwMvXr14oUXXmD58uWUK1eOgwcPcuTIEQAaNWpE+/bts8Z/8cUXfPDBB6Snp3P48GG2bt2aZzD73//+l7i4OIKDgwE4e/YsdevWzbO2Jk2asGfPHh5//HF69epFjx49stoGDx4MQOfOnTl16hRJSUlERUXx9ddfExERAUBqair79u0jKiqKjRs3Zp0fe/LkSXbu3Mny5csZPHgw5cuXp379+nTt2rXQz+yOO+4o0FgRERERERERESlC51Ng/kjY9jW0uQf6vA1uFV1dVZlRZoNZV7LW8vzzz/PII4/8oXmMMZf8PnPmTI4ePUpcXBzu7u54e3uTmpoKgKfn/w5K3rt3LxEREaxfv54aNWowbNiwrH551Xz//ffz+uuvF6i2GjVq8NNPP/Hdd98xdepUvvjiC2bMmJFn3dZa5s2bR/PmzS9Zd8qUKfTs2TPH9cWLF18yzwVubm5kZmYCXHJPua0tIiIiIiIiIiIucGiD4yVfPSZC6GhQTnNFyuwZs7MfCb3k39BQbwAqVSifa/uAoIYA1PSscEnblejZsyczZswgOTkZgIMHD+Y4I/WCKlWqcPr06TznWbBgAampqRw/fpzo6GiCg4M5efIkdevWxd3dnaVLl/LLL7/kOvbUqVN4enpSrVo1jhw5wpIlS/KtuVu3bsydOzerzhMnTuQ5N8CxY8fIzMzkrrvu4tVXXyU+Pj6rbfbs2QCsXLmSatWqUa1aNXr27MmUKVOw1gKwYcMGwPGs3nvvPdLS0gD4+eefSUlJoXPnznz++edkZGRw+PBhli5dmjW/t7c3cXGOHc3z5s277DMTEREREREREZESlOzMwbzD4IkN0OFxhbKFoB2zhdCjRw+2bdtGaKgj0PXy8uLTTz+95GiAPn360L9/fxYsWMCUKVPo1KlTjvaQkBB69erFvn37GD9+PPXr12fIkCH06dOHoKAg/P39adGiRa41tGnThoCAAFq3bk2TJk3o2LFjvjW3atWKiRMn0qNHDzIzM3F3d2fq1Kk0atQo1/4HDx5k+PDhWTtXs++0rVGjBh06dODUqVNZu2jHjx/PmDFj8PPzw1qLt7c3ixYt4sEHHyQxMZHAwECstdSpU4evvvqKO++8kx9//BFfX1+aNWtGly5dsuZ/+eWXeeCBB/jb3/5Gu3btLvvMRERERERERESkhCR8BouehkEzoWk3qN7Q1RWVWebCDsfSJCgoyMbGxua4tm3bNlq2bOmiiorehAkT8n0xWGkVHh5OREQEQUFBJb52UT6zq+37JCIiIiIiIiJSrDIz4PuXYM270LgLDIiEyjVdXVWZYIyJs9ZeEqZpx6yIiIiIiIiIiIjkLfUkzHsQdkZB8ENw6+tQ3t3VVZV5CmZdZMKECa4uoVCio6NdtnZZfWYiIiIiIiIiImXa1gWw+0foPQmCRri6mquGglkRERERERERERG5VOop8KgKAUOhQQjUzf1dSFI45VxdgIiIiIiIiIiIiJQyMR/CO23g2E4wRqFsMVAwKyIiIiIiIiIiIg4ZabDoKVj8DDQMAa/r/tB0mZmW6ct289eFW4uowKuHjjIQEREREREREREROHMCvrgPEldAxzHQ7SUoV77Q0x1PPsfYOT8RveMovXzrkZ6RiVt57RO9QE/iCnTo0MHVJRSr8PBwYmNjXV2GiIiIiIiIiIi4wqq3YX8M3PkBdH/lD4Wy6/Yc58+TV7B693FevcOHd+8JUCh7Ee2YvQKrV68u1vnT09Nxc9NHIiIiIiIiIiIiJSgtFdw9IPwF8LkL6rX5Q9OdPJvGgx/HUrtKRWYMC6Z1/WpFVOjV5bIxtTHGwxgTY4z5yRizxRjzSi59jDFmsjFmlzFmozEmMFvbrcaYHc62cUV9AyXJy8sLgOjoaMLDw+nfvz8tWrRgyJAhWGsB8Pb25uWXXyYwMBBfX1+2b98OQEpKCiNGjCA4OJiAgAAWLFgAQGRkJAMGDKBPnz706NEjx3opKSn06tWLNm3a4OPjw+zZs7PWeO655wgJCSEkJIRdu3YBcPToUe666y6Cg4MJDg5m1apV+a599uxZBg0ahJ+fHwMHDuTs2bOX3CvA3LlzGTZsGADDhg1j5MiRdOrUiWbNmrFo0aIifcYiIiIiIiIiIlJCrIWVb8P7neDs745w9g+EsifPpmGtpVold/51fxALHw9TKJuPgmzPPAd0tdYmG2PcgZXGmCXW2rXZ+twG3OT81w54D2hnjCkPTAW6AweA9caYr621Zf603w0bNrBlyxbq169Px44dWbVqFWFhYQDUrl2b+Ph4pk2bRkREBP/617947bXX6Nq1KzNmzCApKYmQkBBuueUWANasWcPGjRupWbNmjjW+/fZb6tevzzfffAPAyZMns9qqVq1KTEwMn3zyCWPGjGHRokU8+eSTPPXUU4SFhbFv3z569uzJtm3b8lz7/fffp3LlymzcuJGNGzcSGBhIQSQmJrJs2TJ2797NzTffzK5du/Dw8CiKxyoiIiIiIiIiIiUhLRUWPgkbP4fWd0L5in9oulW7jvHk5wk806MZg0JuoF2TWkVU6NXrssGsdWwFTXb+6u78Zy/qdjvwibPvWmNMdWNMPcAb2GWt3QNgjPnc2fePBbNLxsGvm/7QFJf4ky/c9vcCdw8JCaFBgwYA+Pv7k5iYmBXM9uvXD4C2bdvy5ZdfAhAVFcXXX39NREQEAKmpqezbtw+A7t27XxLKAvj6+vLMM8/w3HPP0bt3bzp16pTVNnjw4Kz/feqppwD44Ycf2Lr1f4/21KlTnD59Os+1ly9fzhNPPAGAn58ffn5+Bbr3u+++m3LlynHTTTfRpEkTtm/fjr+/f4HGioiIiIiIiIiIi53+FT4fAgdj4ea/QOdnwJhCTZWRaXnnvzuZ8uNObqzjRcANNYq42KtXgQ40de58jQOaAlOttesu6nI9sD/b7wec13K73q7Q1ZYiFSv+778ilC9fnvT09Evasl+31jJv3jyaN2+eY55169bh6emZ6xrNmjUjLi6OxYsX8/zzz9OjRw9eeuklAEy2/2O58HNmZiZr1qyhUqVKOebJa+2L58nrempqar5j8ppDRERERERERERKocXPwm9b4e7/QKu+hZ7myKlUnpi1gXV7T3BXYANevaM1lSvo/UkFVaAnZa3NAPyNMdWB+cYYH2vt5mxdckvmbD7XL2GMeRh4GOCGG27Iv6Ar2NlaWvTs2ZMpU6YwZcoUjDFs2LCBgICAfMccOnSImjVrcu+99+Ll5UVkZGRW2+zZsxk3bhyzZ88mNDQUgB49evDuu+/y7LPPApCQkIC/v3+ea3fu3JmZM2dy8803s3nzZjZu3Jg1/3XXXce2bdto3rw58+fPp0qVKlltc+bM4f7772fv3r3s2bMn18BXRERERERERERKmcwMKFce/hwByUegXsH+ejovmw+eZNPBk0QMaEP/tg2KqMhrxxVF2NbaJGNMNHArkD2YPQA0zPZ7A+AQUCGP67nN/QHwAUBQUFCu4W1ZNn78eMaMGYOfnx/WWry9vS/74qxNmzbx7LPPUq5cOdzd3Xnvvfey2s6dO0e7du3IzMxk1qxZAEyePJlRo0bh5+dHeno6nTt3Zvr06Xmu/eijjzJ8+HD8/Pzw9/cnJCQka/6///3v9O7dm4YNG+Lj40NycnJWW/PmzenSpQtHjhxh+vTpOl9WRERERERERKQ0y8yE6NfhUDwMng1VrnP8K4T0jEzi9yUR0rgm3Vpex4r/u5laXn/sfNprlXEcC5tPB2PqAGnOULYSEAX8w1q7KFufXsBo4M84jiqYbK0NMca4AT8D3YCDwHrgHmvtlvzWDAoKsrGxsTmubdu2jZYtW17p/V2VvL29iY2NpXbt2iW+9rBhw+jduzf9+/cv8bWLkr5PIiIiIiIiInJNOJ8C8x+BbQvB/17oPQncKhRqqkNJZ3li1gYS9ifx49hwbqhVuYiLvToZY+KstUEXXy/Ijtl6wMfOc2bLAV9YaxcZY0YCWGunA4txhLK7gDPAcGdbujFmNPAdUB6YcblQVkRERERERERERIpA0n6YNRh+2wI9/wbtHyv0S77+u+0IY+f8RFp6Jm/e3UahbBG4bDBrrd0IXHIYqjOQvfCzBUblMX4xjuBWikhiYqLL1s5+zq2IiIiIiIiIiJRS1sLsIZD0C9wzB266pdBTvb5kG+8v20OrelWZOiSQxrVzf5G9XBm9Jk1ERERERERERORqYq1jZ2yfyeBeGeo0+0PTeVZw477QRrzw55Z4uJcvoiJFwayIiIiIiIiIiMjVIDMDvn/J8XPP16C+f6Gn+nbzYTwrutHppjo83rUpppBHIEjeyrm6ABEREREREREREfmDUk/CZ3fDmnch47xj12whnEvP4OUFmxn5aTwzVu4FUChbTLRjVkREREREREREpCw7vhtmDYITe6D3JAgaUahpEjyzTDYAACAASURBVI+lMHpWPJsPnuKBsMY8d2uLIi5UstOO2SvQoUMHV5dQZCIjIxk9ejQAEyZMICIiIt/+BelT3KKjo+ndu7dLaxARERERERERKVXSzkJkL0g5BkO/KnQou/dYCr2nrGT/ibN8eF8Q43u3ooKbosPipB2zV2D16tXFOn96ejpubkX3kRT1fCIiIiIiIiIiUsq4V3Lskq3TAmo2LvQ03rUq80BYY+4Obsj11SsVYYGSF8XeV8DLywtw7NwMDw+nf//+tGjRgiFDhmCd53Z4e3vz8ssvExgYiK+vL9u3bwcgJSWFESNGEBwcTEBAAAsWLAAcO1cHDBhAnz596NGjR471EhMTadGiBffffz9+fn7079+fM2fOAPDXv/6V4OBgfHx8ePjhh7PWDw8P54UXXqBLly688847LFy4kHbt2hEQEMAtt9zCkSNH8r3H3bt3c+utt9K2bVs6deqUVX9e5syZg4+PD23atKFz585Z93T77bdz66230rx5c1555ZWs/p9++ikhISH4+/vzyCOPkJGRAUBUVBShoaEEBgYyYMAAkpOTAfj2229p0aIFYWFhfPnll1nzXLyD18fHh8TExHyfmYiIiIiIiIjIVSEjDRaOgYTPHL83v61Qoezuo8ncPX0NvxxPwRjDU92bKZQtQQpmC2nDhg28/fbbbN26lT179rBq1aqsttq1axMfH8+jjz6aFR6+9tprdO3alfXr17N06VKeffZZUlJSAFizZg0ff/wxP/744yXr7Nixg4cffpiNGzdStWpVpk2bBsDo0aNZv349mzdv5uzZsyxatChrTFJSEsuWLWPs2LGEhYWxdu1aNmzYwKBBg/jnP/+Z7309/PDDTJkyhbi4OCIiInjsscfy7f/Xv/6V7777jp9++omvv/4663pMTAwzZ84kISGBOXPmEBsby7Zt25g9ezarVq0iISGB8uXLM3PmTI4dO8bEiRP54YcfiI+PJygoiLfeeovU1FQeeughFi5cyIoVK/j1118v86nk/8xERERERERERMq8lOPwyR0Q9xH8nljoaeZvOECfKSvZdTSZX0+mFl19UmBl9+/cP+p16bXWd0DIQ3D+DMwccGm7/z0QMMTxBf7ivpxtw7+5ouVDQkJo0KCBY1p/fxITEwkLCwOgX79+ALRt2zZrl2dUVBRff/11VlCbmprKvn37AOjevTs1a9bMdZ2GDRvSsWNHAO69914mT57MM888w9KlS/nnP//JmTNnOHHiBK1bt6ZPnz4ADBw4MGv8gQMHGDhwIIcPH+b8+fM0bpz3fz1JTk5m9erVDBjwv2d37ty5fJ9Dx44dGTZsGHfffXfWfV+4p1q1amU9j5UrV+Lm5kZcXBzBwcEAnD17lrp167J27Vq2bt2adZ/nz58nNDSU7du307hxY2666aas+//ggw/yrSe/ZyYiIiIiIiIiUqYd2ep4ydfpX6Hfh+B39xVPceZ8Oi8v2MKcuAOEeNdk8uAA/lTNoxiKlcspu8Gsi1WsWDHr5/Lly5Oenn5JW/br1lrmzZtH8+bNc8yzbt06PD0981zHGHPJ76mpqTz22GPExsbSsGFDJkyYQGrq//7LRvb5Hn/8cZ5++mn69u1LdHQ0EyZMyHOtzMxMqlevTkJCQj53ntP06dNZt24d33zzDf7+/lljc6vbWsv999/P66+/nqNt4cKFdO/enVmzZuW4npCQcMk8F7i5uZGZmZn1e/b7z21tEREREREREZEy7fSv8O8eUKEyDF8CDdoWaprp0buZG3+Ax7s25cluN+FWXn9Q7yplN5jNb4drhcr5t3vWuuIdsn9Uz549mTJlClOmTMEYw4YNGwgICLjsuH379rFmzRpCQ0OZNWsWYWFhWSFk7dq1SU5OZu7cufTv3z/X8SdPnuT6668H4OOPP853rapVq9K4cWPmzJnDgAEDsNayceNG2rRpk+eY3bt3065dO9q1a8fChQvZv38/AN9//z0nTpygUqVKfPXVV8yYMYPKlStz++2389RTT1G3bl1OnDjB6dOnad++PaNGjWLXrl00bdqUM2fOcODAAVq0aMHevXvZvXs3N954Y47g1tvbO+v4hvj4ePbu3ZvvMxMRERERERERKdOq/Am6T4Bmt0G1669oqLWWk2fTqF65Ao+GN6VD09q0b1KreOqUAlMkXkLGjx9PWloafn5++Pj4MH78+AKNa9myJR9//DF+fn6cOHGCRx99lOrVq/PQQw/h6+vLHXfckXU0QG4mTJjAgAED6NSpE7Vr177sejNnzuTf//43bdq0oXXr1lkvKcvLs88+i6+vLz4+PnTu3DkrxA0LC2Po0KH4+/tz1113ERQURKtWrZg4cSI9evTAz8+P7t27c/jwYerUqUNkZCSDBw/Gz8+P9u3bs337djw8PPjggw/o1asXYWFhNGrUKGvdu+66ixMnTuDv7897771Hs2bN8n1mIiIiIiIiIiJlTloqLBgF+9c7fg9+8IpD2eRz6Tw1O4F+01aTci6dShXKK5QtJYy11tU1XCIoKMjGxsbmuLZt2zZatmzpoopcIzExkd69e7N582ZXl3JFIiMjiY2N5d133y3xtQv6zK7F75OIiIiIiIiIlCGnf4XPh8DBWLj1H9B+5BVPsfXQKUZ/Fk/i8RTG3NKMUTc3pXw5HflY0owxcdbaoIuvl92jDERERERERERERK5GhzbArHsg9SQM/BRa9rmi4dZaZq7bx18XbaV6JXc+e6i9dsmWQtoxK9ckfZ9EREREREREpFQ6vNHxki/P2jB4FvzJ94qnSM/IZMD7a6ji4c5bd7ehtlfFyw+SYqMdsyIiIiIiIiIiIqXdda0hdBS0Gwleda5o6OaDJ7m+eiVqeFbgo2HBVPVwp5yOLii19PIvERERERERERERVzqXDAufhJMHoVx56Db+ikJZay2Rq/bSb9pq/vHtdgCqV66gULaU045ZERERERERERERV0naB7MGw29bwbsT+Pa/ouEnz6Txf/N+4rstR+jWoi7P3dqimAqVoqZgVkRERERERERExBX2rYXPh0BGGgyZA01vuaLh2w6f4sGPYzlyKpW/9GrJA2GNMUa7ZMsKHWVQjCIjIzl06JCryyiwxMREfHx8XF2GiIiIiIiIiMjVb9d/IbI3eFSDh/57xaEsQG2vilxXtSJzRobyYKcmCmXLGAWzxagwwWx6enoxVSMiIiIiIiIiIqVGgyAIHOoIZWvfVOBhv6ecJ+K7HaRnZFKnSkXmPdqBgBtqFGOhUlwUzF6Bt956Cx8fH3x8fHj77beBS3eZRkREMGHCBObOnUtsbCxDhgzB39+fs2fPEhcXR5cuXWjbti09e/bk8OHDAISHh/PCCy/QpUsX3nnnnRxrLlu2DH9/f/z9/QkICOD06dNER0fTuXNn7rzzTlq1asXIkSPJzMwEICoqitDQUAIDAxkwYADJyckAea4dFxdHmzZtCA0NZerUqVnrRkZGMnr06Kzfe/fuTXR0NABeXl6MHTuWwMBAunXrxtGjR4v4SYuIiIiIiIiIXIXOJsF3L0LaWcdO2d6ToFLBQ9X1iSf48+QVfLB8DxsPngTQLtkyTMFsAcXFxfHRRx+xbt061q5dy4cffsiGDRvy7N+/f3+CgoKYOXMmCQkJuLm58fjjjzN37lzi4uIYMWIEL774Ylb/pKQkli1bxtixY3PMExERwdSpU0lISGDFihVUqlQJgJiYGN588002bdrE7t27+fLLLzl27BgTJ07khx9+ID4+nqCgIN566y3S0tLyXHv48OFMnjyZNWvWFPhZpKSkEBgYSHx8PF26dOGVV165kkcpIiIiIiIiInLtOb4b/nULrJsO+9dd0dDMTMvUpbsY9MFaKriVY96jHQjULtkyr0y+/OsfMf9g+4ntRTpni5oteC7kuTzbV65cyZ133omnpycA/fr1Y8WKFfTt27dA8+/YsYPNmzfTvXt3ADIyMqhXr15W+8CBA3Md17FjR55++mmGDBlCv379aNCgAQAhISE0adIEgMGDB7Ny5Uo8PDzYunUrHTt2BOD8+fOEhobmufbJkydJSkqiS5cuAAwdOpQlS5Zc9l7KlSuXVe+9995Lv379CvQMRERERERERESuSbt/hDnDoJwb3Pc1eHe8ouEvfrWZWTH76O1Xj9f7+VLFw7146pQSVSaDWVew1uZ63c3NLesYAYDU1NQ8x7du3TrPnakXAt+LjRs3jl69erF48WLat2/PDz/8AFy6Td0Yg7WW7t27M2vWrBxtmzZtynXtpKSkPLe7F/S+cqtFREREREREREScNn4B80dCneYweBbU8C7wUGstxhgGhzTE5/qq3BNyg3KYq0iZDGbz29laXDp37sywYcMYN24c1lrmz5/Pf/7zH6677jp+++03jh8/jpeXF4sWLeLWW28FoEqVKpw+fRqA5s2bc/ToUdasWUNoaChpaWn8/PPPtG7dOt91d+/eja+vL76+vqxZs4bt27dTvXp1YmJi2Lt3L40aNWL27Nk8/PDDtG/fnlGjRrFr1y6aNm3KmTNnOHDgQL5rV6tWjZUrVxIWFsbMmTOz1vX29mbatGlkZmZy8OBBYmJistoyMzOZO3cugwYN4rPPPiMsLKwYnriIiIiIiIiIyFXg+rbgOwB6RUDFKgUakpFpmfLjTpLOpDGhb2v8GlTHr0H1Yi5USlqZDGZdITAwkGHDhhESEgLAgw8+SEBAAAAvvfQS7dq1o3HjxrRo0SJrzLBhwxg5ciSVKlVizZo1zJ07lyeeeIKTJ0+Snp7OmDFjLhvMvv322yxdupTy5cvTqlUrbrvttqyAddy4cWzatCnrRWDlypUjMjKSwYMHc+7cOQAmTpxIs2bN8lz7o48+YsSIEVSuXJmePXtmrduxY0caN26Mr68vPj4+BAYGZrV5enqyZcsW2rZtS7Vq1Zg9e3aRPWcRERERERERkTIv5Ths+AQ6joFaN0K/9ws89LdTqYyZncDq3cfpF3A9GZmW8uW0S/ZqZPL6E31XCgoKsrGxsTmubdu2jZYtW7qootIlOjqaiIgIFi1a5JL1vby8SE5OdsnaRUXfJxEREREREREpFke2wKxBcPoIPLIM6hY8f1ix8yhPzU4g+Vw6f73dhwFtG+jogquAMSbOWht08XXtmBURERERERERESkK2xfDlw9BBS8YvuSKQtmkM+d59NN46lXzYNZD7bnpuoIdeyBll4LZMig8PJzw8HCXrV/Wd8uKiIiIiIiIiBS5te/Bt89DfX8Y9BlUrV+gYb+nnKd6ZXeqV67AR8ODaV2/KpUrKLK7FpRzdQEiIiIiIiIiIiJlXp3m4He3Y6dsAUPZpdt/o+ub0XwRux+AYO+aCmWvIWXqk7bW6lwN+cNK47nKIiIiIiIiIlIGnToMe5dDm4FwY1fHvwJIy8gk4rsdvL98Dy3+VIUg75rFXKiURmUmmPXw8OD48ePUqlVL4awUmrWW48eP4+Hh4epSRERERERERKQsOxgPn98D51Og6S3gWatAww78fobHZ21gw74khrS7gfG9W+HhXr6Yi5XSqMwEsw0aNODAgQMcPXrU1aVIGefh4UGDBg1cXYaIiIiIiIiIlFWb5sKCUeBZ13F0QQFDWYAdv55m15Fk3r0ngN5+BTvyQK5OpjT+WXdQUJCNjY11dRkiIiIiIiIiIiI5LX0dlv0dbgiFgZ+CZ+3LDjmfnkls4gk6NHX0TTpznuqVKxR3pVJKGGPirLVBF1/Xy79EREREREREREQKqlJ1CBgK931doFB23/Ez9J++mvtmxHDg9zMACmUFKENHGYiIiIiIiIiIiLjE779A0i/QuDO0G+m4VoB3IC3edJjn5m7EGHj3nkAa1KhczIVKWaJgVkREREREREREJC+/rIbZ94J7ZXg8Htwuv9vVWssrC7cSuTqRNg2r8+7gABrWVCgrOSmYFRERERERERERyU38J7DoaajRCAZ/XqBQFsAYQy3PCjzUqTHP9mxBBTedJiqXUjArIiIiIiIiIiKSXWYmRL0Ia6fBjV2h/wyoVOOywxYkHKRG5Qp0blaH0V2bYgpw3IFcuxTXi4iIiIiIiIiIZGcMpJ6C9o/BPXMuG8qePZ/BuHkbefLzBGau+8U5hUJZyZ92zIqIiIiIiIiIiAAc2wU2E+o0g75ToNzl9zTu+u00o2ZuYMeR0zwWfiNPd29WAoXK1UDBrIiIiIiIiIiIyO4fYc4wqN0MHvi+gKFsMn2mrKJyhfJ8PCKELs3qFH+dctVQMCsiIiIiIiIiItcua2Hd+/DdC1CnBdz1b8dRBvkOsRhjuLGOJ4+G38jA4IZcV9WjhAqWq8Vlo39jTENjzFJjzDZjzBZjzJO59HnWGJPg/LfZGJNhjKnpbEs0xmxytsUWx02IiIiIiIiIiIhcsfTzsPBJ+PY5aHYrPPAd1GiU75Dtv57izmmr+eV4CsYYnuh2k0JZKZSC7JhNB8Zaa+ONMVWAOGPM99barRc6WGvfAN4AMMb0AZ6y1p7INsfN1tpjRVm4iIiIiIiIiIjIH2Ph6HboNBZu/ku+xxdYa/l8/X4mfL2FqpXcOZZ8nka1PEuwVrnaXDaYtdYeBg47fz5tjNkGXA9szWPIYGBWkVUoIiIiIiIiIiJSlI5shar1oFINuH8RuFXIt/vp1DRemL+ZhT8dIqxpbSYN9KdOlYolVKxcrS5/inE2xhhvIABYl0d7ZeBWYF62yxaIMsbEGWMeLlyZIiIiIiIiIiIiRWD7N/Dv7rDkOcfvlwllAaYv2803Gw/xTI9mfDIiRKGsFIkCv/zLGOOFI3AdY609lUe3PsCqi44x6GitPWSMqQt8b4zZbq1dnsv8DwMPA9xwww0FvgEREREREREREZHLshZWvAk/ToT6AXDLK5fpbvn9TBo1PSsw+uab6NriOto2qlFCxcq1oEA7Zo0x7jhC2ZnW2i/z6TqIi44xsNYecv7vb8B8ICS3gdbaD6y1QdbaoDp16hSkLBERERERERERkctLOwvzHoQfXwXf/jB8seMogzycPJvGYzPj6f/eas6cT6dShfIKZaXIXTaYNcYY4N/ANmvtW/n0qwZ0ARZku+bpfGEYxhhPoAew+Y8WLSIiIiIiIiIiUmCpp2D/Ouj2EvT7ENwr5dn1p/1J9J6ygu+3HmFgcEM83MqXYKFyLSnIUQYdgaHAJmNMgvPaC8ANANba6c5rdwJR1tqUbGOvA+Y7sl3cgM+std8WReEiIiIiIiIiIiL5+m0b1LoJqlwHj62Fil55drXWMmNVIn9fso26VTyY/UiodslKsbpsMGutXQmYAvSLBCIvurYHaFPI2kRERERERERERApn01xYMArCnobw5/INZQHSMy2LNh4ivHld3ujvR/XKl38pmMgfUeCXf4mIiIiIiIiIiJR6mZmwdKLjRV+NOkLwA/l2j9/3O41reVLDswKRw0Oo6uGG86+/RYpVgV7+JSIiIiIiIiIiUuqdOw2z73WEsoH3wdCvwLN2rl0zMy3Tl+1mwPQ1vBG1A4BqldwVykqJ0Y5ZERERERERERG5OpzYC3uXwa3/gHaPQB4h6/Hkc4yd8xPRO47yZ98/Me62FiVcqIiCWRERERERERERKetO7IWajaGeHzy5ETxr5dl188GTPPDxen4/k8ard/hwb7sbtEtWXEJHGYiIiIiIiIiISNkV9zG8G+x42RfkG8oC/KmaB41qeTL/sQ4Mbd9Ioay4jIJZEREREREREREpezLSYck4WPgENO4MTW/Js+vR0+d4fck20jMyqe1VkS8eCaV1/WolWKzIpXSUgYiIiIiIiIiIlC1nf4c5w2HPUmj/GHR/FcrnHnOt2nWMJz9P4HRqGn/2qUebhtVLuFiR3CmYFRERERERERGRsiVxFfyyCvq+C4FDc+2SkWl55787mfLjTprU9uTTB0No8aeqJVyoSN4UzIqIiIiIiIiISNmQ/Bt41YWWveHxeKjeMM+uz83byNy4A9wV2IBX72hN5QqKwaR00TdSRERERERERERKN2th3XT44RUY9g00aJtnKGutxRjDfaGNaN+kFv3bNijhYkUKRsGsiIiIiIiIiIiUXunn4ZunYcN/oEVvqNM8924Zmbz5/c+cPZ/BhL6t8WtQHb8GOk9WSq9yri5AREREREREREQkVynH4JPbHaFs52fh7v9ARa9Luh1KOsugD9byXvRuzqVnkJlpXVCsyJXRjlkRERERERERESmdNvwHDsXDXf8G3/65dvlh6xGemfsTaemZvDPIn9v9ry/hIkUKR8GsiIiIiIiIiIiULqmnwKMqdHgSmveCOs1y7XYi5TxPfr6BRrU8mTokkMa1PUu4UJHCUzArIiIiIiIiIiKlg7WwIgJi/gUP/QjVrs81lD2WfI5anhWo6VmB/zzYjlb1quLhXt4FBYsUns6YFRERERERERER10s7C/MegB8nQuPOULlmrt2+3XyYmyOi+SJ2PwCBN9RQKCtlknbMioiIiIiIiIiIa506BJ/fA4cSoNvLEPYUGJOjy7n0DP72zTY+XvMLbRpUI7RJbRcVK1I0FMyKiIiIiIiIiIhrLX0Nju2EwbOg+W2XNCceS2H0rHg2HzzFiI6NGXdbCyq46Q/BpWxTMCsiIiIiIiIiIq6RlgruHtDzdQgdDXVb5tpt12/JHPj9LB/eF0T3VteVcJEixUP/aUFEREREREREREpWZib8MAFm9ITzZ8Cj6iWhbGpaBit2HgXgllbXsfz/blYoK1cVBbMiIiIiIiIiIlJyzp2G2UNg5SSo7w/lLv2D7t1Hk7lj6ipGRK7n8MmzAFT1cC/pSkWKlY4yEBERERERERGRkvF7IswaDEd3wG1vQMhDl7zka/6GA7w4fzMe7uX54L4g6lWr5JpaRYqZglkRERERERERESkZX42CUwfh3nlw4805mqy1PP/lJj5fv58Q75pMHhzAn6p5uKhQkeKnYFZERERERERERIpXZgaUKw+3T3GcL1u76SVdjDE0qFGJ0Tc3ZcwtN+FWXidwytVNwayIiIiIiIiIiPw/e3cdHtW5tXH4tyfuAkQgCe6QIMEpUqRA3Z1TOy1QB+o9PXWl9FCHfpRihWIVKlSgeLFSgiVIkJAQosRlMjP7+2MoUlwnIc99XbmSzN7vzhqCzZO113t+2G3w87NQkg3Xj4fQBkccNk2TmWtSiQjypkeTWjx0aWMXFSpy4elHDyIiIiIiIiIicu6V7oep18OqsRAQCabjiMNF5TYe/2odT85ez4w1e1xUpIjrqGNWRERERERERETOraytMO0WyEuBqz+CtncccXjz3gIe+nItu3KKebxvEx669OjRBiIXOwWzIiIiIiIiIiJy7tgrYOoNYC2Gf82Ful2OOLw1o5BrPl5GsI8HU+/rTJeGNVxUqIhrKZgVEREREREREZGzZ5rO924ecN04CKwNwTGHHTYxDIPGYf481rcxN8VHU9Pfy0XFirieZsyKiIiIiIiIiMjZsZXDtw/B0tHOz2M6HxHKbkjN54oPlrI7pxjDMBjWq5FCWan2FMyKiIiIiIiIiMiZK8qCiVfBuinOgPYwpmnyxbKdXP/JcnKLreSXVrioSJHKR6MMRERERERERETkzOzbANNuheIsuOFzaHX9wUP5JRU8OTuBnzdl0KdZGKNujCPEz9OFxYpULgpmRURERERERETk9JXuhwmXg6cf3DMParc94vDHi7YzPzGT5y9vzr3d62MYhosKFamcFMyKiIiIiIiIiMjp8wmBq8ZATBcIiACcowuyi6zUCvDisT5NuLx1JLFRwS4uVKRy0oxZERERERERERE5NdYSmHUvbPnJ+XnLaw+GsvuLrdw3cQ03j/uDUqsdH083hbIiJ6COWRERERERERERObn8NJh+G6QnQJ32RxxasyuXh6f9RU6RlWcHNcPbQ72AIiejYFZERERERERERE4sdY0zlLUWw63ToOlAABwOk08WJTP6161Ehfgwe2hXWkcFubhYkapBwayIiIiIiIiIiBxf9jaYMMg5suDObyC8xcFDdtPkt8QMBraK4I3rWhPg7eHCQkWqFgWzIiIiIiIiIiJyfDUaQd8XIfZm8KsBwModOTQJDyDEz5NJ93TE38sdwzBcWqZUbg7Twbj146gbWJeB9Qe6upxKQQM/RERERERERETkSOWFMPs+yEwEw4Auw8CvBnaHyf9+28qtn63gvd+2AhDg7aFQVk4ovzyfh+Y/xEfrPmL1vtWuLqfSUMesiIiIiIiIiIgckrsTpt0K2VuhUV8Iaw5AZkEZj321juXJOVzXtg5PDWjm4kKlKtics5nhC4eTUZLB852e56amN7m6pEpDwayIiIiIiIiIiDjtXAIzBoPpgDvnQINeAKzbk8d9E1dTVG7j7RtiubF9lLpk5aTmbJvDayteI9QnlEkDJtG6VmtXl1SpKJgVERERERERERFnKDv5GghtALdOhxoNDx6KCvGhWUQgL1zZgibhAS4sUqqCMlsZr698na+3f02XyC681eMtQrxDXF1WpaNgVkREREREREREILojdH0Euj8G3kGk55fyf0t28szAZtT092LKfZ1cXaFUAXsK9zBi4QgScxO5P/Z+hsUNw83i5uqyKqWTbv5lGEa0YRi/G4aRaBjGJsMwHj3GOb0Mw8g3DGPdgbcXDjs2wDCMLYZhbDcM4+lz/QREREREREREROQMleTCtw8537t7Qd//gncQvydlMmjMEqatSiFpX6Grq5QqYnHqYm7+/mZSi1L5qM9HPNz2YYWyJ3AqHbM2YIRpmmsNwwgA/jQM41fTNDf/47wlpmlecfgDhmG4AR8B/YBUYLVhGN8dY62IiIiIiIiIiFxIWVth2s2QnwotrobG/aiwOxj18xbGLt5Bs4gAPrq9HQ1r+bu6Uqnk7A47nyR8wtj1Y2ke2px3e71LdEC0q8uq9E4azJqmmQ6kH/i4mwv83AAAIABJREFU0DCMRKAOcCrhakdgu2maOwAMw5gOXH2Ka0VERERERERE5HzY9ivMusfZJfuv7yHGOaZg5MwEvl23l9s7xfCfK1rg7aFuRzmx/WX7eWrxU/yR/gfXNrqWZzs9i7e7t6vLqhJOa8asYRj1gLbAymMc7mIYRgKwFxhpmuYmnAHunsPOSQU0kERERERERERExFU2zobZ90FYS7h1GgRHY5omhmFwb/f69G0ezpVxtV1dpVQBG7I2MHzRcHJLc3mp60tc1/g6V5dUpZxyMGsYhj8wG3jMNM2CfxxeC9Q1TbPIMIxBwDdAY8A4xqXM41z/fuB+gJiYmFMtS0RERERERERETsZaDEUZENoA6vWA+Huh30tYLT68OXczdoeDl65uRWxUMLFRwa6uVio50zSZsWUGb65+k3DfcCYNmkTLGi1dXVaVc9LNvwAMw/DAGcpONU1zzj+Pm6ZZYJpm0YGPfwQ8DMOoibND9vCBElE4O2qPYprmONM0403TjK9Vq9ZpPg0RERERERERETlKURYseBXeawlzHnA+5l8LLh9FSqHBDZ8u5/NlOzEMA4fjmL10IkcotZXy3NLneHXlq3SJ7MJXV3ylUPYMnbRj1jAMAxgPJJqmOfo450QAGaZpmoZhdMQZ+OYAeUBjwzDqA2nALcBt56p4ERERERERERE5hpxkWP4BrPsS7FZodjl0feTg4R83pPPUrPUYBnx6R3sGtIpwYbFSVewu2M3jCx9n+/7tPNjmQe6PvR+LcUp9n3IMpzLKoBtwJ7DBMIx1Bx57FogBME3zU+AGYKhhGDagFLjFNE0TsBmG8RDwM+AGfH5g9qyIiIiIiIiIiJxrpgmGAckLnKFsm1uhy0NQs/HBU7KLyhk5M4HG4QF8eGtbokN9XViwVBXzU+bz/NLncbe482nfT+lap6urS6ryDGd+WrnEx8eba9ascXUZIiIiIiIiIiKVn8MB236BZWMg9iaIvxsqSqG8EPzDsNocLEjKYEFSJm9eF4vFYpCwJ4/mkYF4uqvbUU7M5rDx/l/vM2HjBFrVaMW7vd6ltr82hzsdhmH8aZpm/D8fP+XNv0REREREREREpBKxlcP6Gc6RBdlbICgaPP2cxzx8SM6zM2NxIrPXppJdZCUi0Ju9+aVEhfgSF60NvuTkskuzeWrxU6zat4qbmtzEUx2fwtPN09VlXTQUzIqIiIiIiIiIVEUzBsPWeRDeGq77DFpeC24eAPyRnMOtn63AzWLQp1kYt3SMpkfjWri7qUNWTs26zHWMWDiCfGs+r3V/jasaXuXqki46CmZFRERERERERKqC/DRY+Sl0ewz8ajg38+r0ADTozeb0Qr76fgvRob7cd0kD4uuF8Nyg5lzdtjZhAd6urlyqENM0+TLpS0atHkWkfyRT+06laWjTc3LtCrsDm93Ex9PtnFyvqlMwKyIiIiIiIiJSmWVsco4r2DDTublXVDy0uJrCiI58l7CXr35axvrUfDzdLdzVtR4AHm4W/t2jgWvrliqnpKKE/y7/L/N2zaN3dG9e7f4qgZ6B5+TaiekFjJiRQJuYYF6/tvU5uWZVp2BWRERERERERKQyslnhq9udG3t5+EKHf2N2HoIRUg+Ap+ds4If16TSLCODFK1twbdsognw9XFuzVFk78nfw+O+Ps6tgF4+1e4y7W92NxTj70Rc2u4NPFyUzZv42gnw86Nmk1jmo9uKgYFZEREREREREpLKw2yDtT4jpBO6e4BcGvZ9nf8s7mZ1YwlcTdvPZ4FrUq+nHg70acf8lDYiNCsIwDFdXLlXYz7t+5oVlL+Dt7s24fuPoFNnpnFx3W0YhI2YmsD41nytiI3n56laE+mnzsL8pmBURERERERERcTVrCaybCn98CHkp8PBaHMH1WN7yJaatTuGXX9ZSYTdpGxNMfmkFAC1qn5tbzKX6qnBU8N6f7zF582TiasXxbs93CfcLP+vr2h0m/7dkB+/+uhU/Tzc+uq0dl8dGnoOKLy4KZkVEREREREREXKU0z7mh18qxUJoLdeKx930Ft+AY8ksruOeL1fh6uXFH57rc0iGGphEBrq5YLhJZJVmMXDSStZlrua3ZbYyMH4mH29mPwtiRVcTImQmsTcnjspbhvHpNa2oFeJ2Dii8+CmZFRERERERERC40uw3c3KGiFJaMxtGgN2vq3MnYnWEULrEzo6UbIX5ufPnvTrSqE4S3h3axl3Nn9b7VPLHoCUpsJbx1yVsMajDorK/pcJhMWL6Lt+cl4e3hxphb2nBVXG2N2TgBBbMiIiIiIiIiIhdK2p+w7H1nd+y/5rLHFsS3cV8zcUMpWRvKCQso4Mb4KOwOEzeLQXy9UFdXLBcR0zSZuGki/1v7P6IDovm//v9Ho5BGZ33d3TnFPDFzPat25dKnWRhvXNeasEDvc1DxxU3BrIiIiIiIiIjI+WSasO1XWP4+7FqC6RWIvd3duNttLNyaxeg/8ri0WRg3d4ihd9NauLtZXF2xXISKrEX8Z9l/+C3lN/rV7cfLXV/G39P/rK7pcJhMXbmb139Mwt3NYNSNcVzfro66ZE+RglkRERERERERkfPprynw3UNU+EWyoM7DvJQWz7Cgttzh5s51bevQr3k4EUHqLpTzZ9v+bTy+8HFSC1MZGT+SwS0Gn3V4uie3hKdmr2d5cg49mtTiretbExnkc44qrh4UzIqIiIiIiIiInEtlBbB2IgTHYDa/itllHUgNGMFHWXEYeZ70bxlOi9qBAPh5uePnpXhGzp8fdvzAS3+8hJ+HH+MvG0/78PZndT3TNJm+eg+vfr8ZgDeua80tHaLVJXsG9CdfRERERERERORcKEiHlZ9grpmAUV4A7e/CaHE1X67LodDoydNXxHBt2zqE+nm6ulKpBirsFbyz5h2mJU2jXVg7RvUcRS3fWmd1zfT8Up6avYHFW7Po2rAGb98QS1SI7zmquPpRMCsiIiIiIiIicrYWj8Jc+Camw84Sj66MsQ1k/KX3EwJ8flcHgnw81FEoF8y+4n2MWDiC9dnruavlXTzS7hE8LB5nfD3TNJn1Zyovf78Zm93klatbcnunulgs+j19NhTMioiIiIiIiIicLtOE3cshvCU7iz34PckNT1tvxlYMJLRmE27uG4O3hxsAwb7qkJULZ0X6Cp5c9CRWh5XRvUbTr26/s7peZkEZz8zZwPykTDrWC+WdG2OpW8PvHFVbvSmYFRERERERERE5VQ47JP1AxeL38Ni3Fvq9gtH0Xsakt+Sa9v0Y2yHm4PxYkQvJYToYv2E8H677kPqB9Xmv93vUD6p/xtczTZPvEvbywrebKKuw858rWnB313rqkj2HFMyKiIiIiIiIiJyMaWJf/Tnli8fgW7SbvWYYS2s9yu0d7qOepy+rn+uLp7vF1VVKNVVgLeC5Jc+xMHUhA+sP5MUuL+LrceazX7OLynnu6w38vCmDdjHBjLoxjga1/M9hxQIKZkVEREREREREjq+iFDx8+GL5Lpr9NglvuxvT3IcTEn89N3WoB57O8EuhrLhKUm4Sj//+OPuK9/FMx2e4tdmtZzXP+If16fzn240Uldt4ZmAz7rukAW7qkj0vFMyKiIiIiIiIiPxDefZO0n8aRd3U7zCGLqOwzMbEqBe5umMzXm0RjoebglhxvW+2f8OrK14lyCuICQMm0CaszRlfK7fYygvfbuT79enERQUx6sY4GocHnMNq5Z8UzIqIiIiIiIiIHJCycTmFC0bTNHc+tU2DvfWuog4GD13aCMNo7OryRAAot5fzxso3mL1tNp0iOvFWj7eo4VPjjK/386Z9PPf1BvJLK3jisqY80KMB7vrhw3mnYFZEREREREREqr2conJGTpzPuMw7CMWT30NuwL/Xw3SMbQ0WA93ILZVFWlEawxcOZ3POZu5rfR8PtXkIN4vbGV0rr8TKS3M38/VfabSIDGTyvZ1oHqnN6y4UBbMiIiIiIiIiUi1tSMnGvn42bSzbCR34NvjWZH7cu3ToeQX9atRydXkiR1matpSnlzyNw+FgTO8xXBpz6Rlfa0FSBk/P3kBusZVH+zTmoUsbaUTHBaZgVkRERERERESqjfzSCn5cs5WiPz5nUPHX1DFyMGs1w7AWMeHujkBHV5cochSH6WBswlg+SfiExiGNea/Xe8QExpzRtQrKKnhl7mZm/plK0/AAPr+rA63qBJ3jiuVUKJgVERERERERkWrhq9UpfPftLD52e4cgo4SM0PYU9/kAvxYDwaJOQamc8sryeHrp0yxLW8ZVDa/i+c7P4+Puc0bXWrw1i6dmryejoIwHezfkkT6N8XI/szEIcvYUzIqIiIiIiIjIRSmrsJw5a1PpUyufRgE2mkY0oXmbzpil/aH3w4RHxbu6RJET2pS9ieELh5NVmsULXV7ghsY3YBinP/G4qNzGaz8kMm1VCg1r+TFnWDfaRAefh4rldCiYFREREREREZGLht1hsmRbFl+t3kNO4mLus8ylkdufUCeeNv+eT5vorkBXV5cpckKmaTJr2yzeWPkGNX1qMmngJFrVbHVG11q+PZsnZq1nb34p9/dowPB+TfD2UJdsZaBgVkREREREREQuCqZpcvn7SwjJXMFTnrNo47EFu1cwdHoSOt7v6vJETkmZrYxXV7zKt8nf0q12N9685E2CvU+/u7XEauPNn5KY9Mdu6tf0Y9aQLrSvG3oeKpYzpWBWRERERERERKokq83BgqQMFm3N5vUrG2EAt3eKITZ7A7HbS6Hr27i1vQM8/Vxdqsgp2VOwh8cXPs6W/VsYEjeEIbFDcLOcfnfrqp25jJyZQEpuCXd3q8eTlzXDx1NdspWNglkRERERERERqVKSs4qYsXoPs9emYi3KZajfQhzbf8Wt+2Pc2fUhsD8APABuij2k6li4ZyHPLnkWwzD4qM9H9IjqcdrXKKuw887PW/h82U6iQnyYfn9nOjeocR6qlXNBf0OJiIiIiIiISJWxZFsWd45fRZQlhzdrLqQXP+FuK4HIPlCnvfMkBbJShdgddj5a9xGfbfiMFjVaMLrXaOr41znt66xN2c/IGQnsyC7mzs51eXpgM/y89GehMtN3R0REREREREQqrU178/lq9R4a1PTjrm716Vg/lGcHNeOu7Y/hmbocWl0PXR+GiNauLlXktOWW5fLk4idZmb6S6xtfzzOdnsHLzeu0rlFWYee937by2eIdRAb5MOXeTnRvXPM8VSznkoJZEREREREREalUCssq+C5hL9NX7WFDWj6e7gYvtcqGqU/ideX73N+jITR9C7wCIDja1eWKnJGErARGLBxBXnkeL3d9mWsbX3va11ifmseIGQlsyyzilg7RPHd5cwK8Pc5DtXI+KJgVERERERERkUpl5MwEft6UQYtwXybG76Zb5pe4J20AvzDI2QaBkRDewtVlipwR0zSZvmU6b69+m3DfcCYPnEzzGs1P6xpWm4MPFmzj44XJ1PL34ou7O9Cradh5qljOFwWzIiIiIiIiIuIyucVW5qxNZcaaPYz/VweiQ315+NLGPNg9itbfDcDYuBNqNIYr34fYm8HD29Uli5yxkooSXvrjJX7c+SM9onrwevfXCfIKOq1rbNqbz4gZCSTtK+T6dlG8cGULgnzUJVsVKZgVEREREZFTtqdgD98kf0N0QDSXN7gcD4teCIrI6XM4TJYlZzN99R5+2bSPCrtJu5hgCnPSIWUVrdrcCgRBm9sgvBU0GQAWi6vLFjkru/J38fjCx9mRv4NH2j7Cva3vxWKc+u/rCruDj39P5oMF2wjx8+SzwfH0axF+HiuW803BrIiIiIiInJBpmvyV+ReTNk9iQcoCTEwAxq0fx/2x93NFgytwt+ilhYicnM3uwN3NQk6xlbsnrMbf2507O9fjjsYVNNg2Ab6aBnYr1OvunB3b80lXlyxyTvy2+zeeX/Y8nhZPPu37KV1qdzmt9Vv2FTJi5jo2phVwdZvavHhlS0L8PM9TtXKhGKZpurqGo8THx5tr1qxxdRkiIiIiItWazWHjt92/MXHTRDbmbCTQM5Cbm97MzU1vJjE3kY/XfUxibiIxATE8EPcAg+oPUkArIkepsDv4PSmT6av3UG6zM/W+zgCs2plLbEAB3vOfh8Tvwc0T4m6Brg9DzcYurlrk3LA5bIxZO4YvNn1B65qtGd1rNBF+Eae+3u5g7OIdjPltGwHe7rx6TSsGto48jxXL+WAYxp+macb/83H9r0lERERERI5QaC1kzrY5TE2cSnpxOnUD6/Jsx+cIMbswa3UWbfwt9GrSiwBHLMO+nkyG42eeW/ocb/7xId1q3MLDnW8iJsSfCrsDi2HgZjFc/ZRExAVSckqYtjqFWX+mklVYTliAFzfFR+Ow27GUZNOxfjiUAGlr4ZLh0PEBCNBt2XLxyC7NZuSikfyZ8Sc3N72ZJzs8iafbqXe5bs8sZMTM9STsyWNQ6wheuboVNfy9zmPFcqEpmBUREREREQDSitKYmjiVOdvmUFxRTHx4PMNajyQ1rR4ffZdK6v6N1PT35IpYZ6eOr6c73SJ7sje/Iyllq8j3/oF59vf469dZPBY/DFtBHCNmbCAswIuIIG8ig7yJCPTh3z3qExnkQ26xlRKrjfBAbzzcNDtS5GJQVmHHMMDL3Y35SRmMW7yD3k3DuKVDNL0aBuK+aRZ8cid4+sK/fwffUHhsA1jcXF26yDm1NmMtIxeNpNBayOvdX+fKhlee8lq7w+TzpTt555ct+Hq68cGtbbkiNhLD0A86LzYaZSAiIiIiUs0lZCUwadMkfkv5DQsW+tfrz+CWg2kR2oIe7/zOntxSujaswW2dYujfIgJP92OHqHaHnbnbf2Vi4ji2522jtl9dGnlci3d5OzIKrKTnl5KeX8a8R3sQU8OXcYuTef3HJAwDavp7HQhuvXn7hliCfT3ZnllITpGVyCAfwoO88HJXcCNSWSXtK2D6qj18/Vcazw1qzk0doiksq6DEaifcsxzWTIAVn0DRPghvDd0ehVbXa0MvueiYpsnkzZMZ/edoogKiGN1rNE1Cmpzy+p3ZxTwxM4E1u/fTt3k4r1/XirAA7/NYsVwIGmUgIiIiIiIH2R125qfMZ9LmSSRkJRDgEcAtTe7Eu/QSViQ6aNKtOYZh8No1rakT4kPDWv4nvaabxY1rmgzgqsb9mZ8yn4/XfczivP/RIKgBQ/oMoX/d/kfsPt27aRiB3h6k55exL7+M9IIyduUU4+PpDGCnrEjhi+W7Dp5fw8+TyGBvvn2wO24WgyXbssgsKCcy2JvIIB8iAr0PrhWR88/hMJmxZg/TVu8hYU8enm4W+rcMp0lEAAAB3h4EeHvAn9Pht/9Cg15w7SfQoDeo808uQsUVxbyw7AV+2f0LfWL68Eq3VwjwDDiltQ6HycQ/dvHWvCQ83SyMvimOa9vWUZfsRU4dsyIiIiIi1UhxRTFfb/uaKYlTSCtKI8o/ip4R17M3tRXzNuzHanPQLiaYD29rR+1gn7P6Wg7Twa+7f+XThE/ZnredhkENGdLm6ID2ePbmlbIjq/hgp216fhlF5TY+uLUtAEMm/8m8TfuOWNOwlh/zR/QCYMqK3eQWWw+OUXC++eDnpf4UkTNlmiap+0uJDvUF4KoPl1JqtXNzh2iuaxdFqJ8nZGyC5R9AdEeIvwcqyiArCWq3cXH1IudPcl4yjy98nN0Fu3ms3WPc1fKuUw5VU3JKeGJWAit35tKraS3evC6WiCB1yV5Mjtcxe9Jg1jCMaGASEAE4gHGmaY75xzm3A08d+LQIGGqaZsKBY7uAQsAO2I5VxD8pmBURERERObf2Fe9jauJUZm2dRVFFEW3D2jK4xWACHHHcPHYV/l7uXNu2Drd1iqF5ZOA5/doO08Evu37hk4RP2JG/g0bBjRgSN4R+dfudUkB7PKVWO/sKykjPL3V23OaXAfBg70YA3Dl+JUu2ZR+xJjYqiO8e6g7Ai99totRqPzT/NsibejX8qFfT74xrErlY5ZVY+fqvNL5avYddOcWsfLYvQT4e7C+2EuzrgQGwayksGwPbfwUPX+jxhHNTL5GL3Lyd83hh+Qv4uPswqucoOkR0OKV1pmkydWUKr/+YiMUweOGKFtwYH6Uu2YvQ2QSzkUCkaZprDcMIAP4ErjFNc/Nh53QFEk3T3G8YxkDgRdM0Ox04tguIN00z+xiXPyYFsyIiIiIi58am7E1M3DyRX3b9AkDHsF5YCnvSJKgFw/s3xeEw+TYhjf4tIs57J6ndYeeX3c6Admf+ThqHNGZo3FD6xPQ5q4D2RMoq7GQWlDvD24IyvNwtDGjl3Lzs7gmr2LS3gKyicv5+WdS/RTjjBjtfN9346XK83N2OCG5b1wkiNioYcL6g1otnudglZxXx/vxt/LRxH1abg7ioIG7uEMO1bescOTpk7mPw5wTwqwUdH4AO9zo39hK5iFXYKxj952imJE6hbVhbRvUcRZhv2CmtTcsr5alZ61m6PZvujWry1g2x1DnLO1Wk8jrjGbOmaaYD6Qc+LjQMIxGoA2w+7Jzlhy1ZAUSddcUiIiIiInJG7A47C1MXMmnTJNZmrsXP3Y8OoVeRlhLPL5vd8PawUKeTHQCLxeDathfmv+9uFjcG1h9I/7r9mbdrHp8mfMrwhcNpEtKEYXHD6B3T+5wHtN4ebsTU8CWmhu9Rxybc3RGACruDzMJy9uWXHtxgzDRNwgO92ZtXyrLt2WQUlOEw4a6u9YiNCnYGVC/9Qnig14Hg1oeIIG96Nw2jY/1Q7A6T3GIrNfw8sVgU3krVkllQRmmFnbo1/DBNWLgli9s6xnBTfDQtah/oqLeWwKqp0OIa8K8FLa+ByFiIuxU8FC7JxS+jOIORi0ayLmsddzS/g+Hxw/GweJx0nWk6ZzO/8n0iDtPktWtbcVvHGP2gr5o6rRmzhmHUAxYDrUzTLDjOOSOBZqZp3nfg853AfsAExpqmOe5kX0cdsyIiIiIip6+kooRvk79lyuYppBSmUNuvNrc3v531SU2ZtTqbxmH+3N4phmvbRRHkc/IXj+eb3WHnp10/MTZhLLsKdtEstBlD4oZwafSlle4Fqs3uILvIisWAsEBvisttjJm/7cDGZc4ZuBkFZYzs35QHejZkT24Jl7z9Ox5uBuGBf3fc+nB7pxg6N6hBcbmNbZlFRAZ5U9PfCzeFt+JiNruDRVuzmL56DwuSMhnQKoKPbmsHgNXmwNP9wA9NirNh1WewahyU5sLl70KH+1xYuciFtyp9FU8sfoJSWykvd3uZAfUGnNK6ffllPD1nPQu3ZNG5QSjv3BB3cF5ztVFRBvZy8A5ydSUX1BmPMjjsAv7AIuA10zTnHOec3sDHQHfTNHMOPFbbNM29hmGEAb8CD5umufgYa+8H7geIiYlpv3v37lN7ZiIiIiIi1VxGcQbTkqYxc+tMCqwFRPk2xczryVsDbiMuugbbMwvJLa6gQ72QShd4AtgcNn7a+ROfJnxKSmEKzUObMzRuKL2ie1XKeo/H4TCxOUw83S3sL7byXcLeI4LbfQVlPDWgGYNaR7JyRw43j1sBgJvFIDzA2Xn77KDmxNcLZW9eKX+l5B0coxAW4IW72/kZ9yAyfulOPlu8g30FZdT09+T69lHcHB9Ng1r+h05yOGDeU7B2MthKoclA6PYoxHSGKvTnVORsmKbJ5xs/5/2/3qduYF3+1+t/NAhucErr5qxN48W5m6iwO3h6QDMGd6lX/e6o2P0HfPcQ1ImH68a6upoL6oxHGRxY7AHMBqaeIJSNBf4PGPh3KAtgmubeA+8zDcP4GuiIs+v2CAc6aceBs2P2VOoSEREREanOEnMSmbx5Mj/t/AmH6aCOZweK93UkcX9t6tbwI7/UAUCjsAAXV3pi7hZ3rmx4JQPrD+SHHT8wdv1YHvn9EZqHNmdYm2H0jOpZJQJai8XA88CL7BA/T/7Vtd5xz20aEcD4f8Wz9/DgNr/s4CiF1btyeXT6ukPXNqBWgBcT7upIi9qBrE/N44/knINjFCKDvAkP9D7U1ShyAuU2O78nZdK3eTjubhbyS6w0jwzgxata0qd5GB6H/xAgdweENgCLBUpyoPX10PURqNXUdU9AxAUKrYU8v/R5FuxZwGX1LuOlri/h53HyzSIzC8t4ds5GfkvMIL5uCKNujKt+m0yWF8JvL8HqzyA4BuJudnVFlcapbP5lABOBXNM0HzvOOTHAAmDw4fNmDcPwAywHZtP64eyYfdk0zXkn+poaZSAiIiIicmwO08GS1CVM2jyJVftW4ePuw9UNr2H6b/UoKQmmX/Nwbu8cQ7eGNatsJ47NYeP7Hd8zNmEsqUWptKzRkmFthnFJnUuqREB7LpRYbaTklhwMbNPznOHtE5c1JSzQm3GLk3n9x6Sj1i1/+lJqB/swb2M6yw8Gt95EBDrD25hQ3yr7+0LO3raMQr5avYc5f6WRW2xlwt0d6N007OiN7EwTtv8Gy8bArqXw0Bqo2cj5eDX5MyhyuC25Wxi+cDh7i/YyPH44dzS/46T/Hpmmydz16bzw7UZKrHaevKwpd3erX/1G12z71bk5YEEadB4KvZ8DL/+Tr7vInPEoA8MwugNLgA2A48DDzwIxAKZpfmoYxv8B1wN/zx+wmaYZbxhGA+DrA4+5A1+apvnayYpVMCsiIiIicqRSWylzk+cyefNkdhXswtdSg+CK3sy4/TGCvIL4PSmTFrUDCQ/0dnWp50yFo4Lvk79n7PqxpBWl0apGK4a2GVqtAtoTKSyrcIa2f4e3+WUM7dUQT3cLHy7YxrjFOygosx083zBg66sD8XCz8MnCZNbsyj0U3Ab5UDvYm64Na7rwGcn5kllYxtApa/lz93483Az6tQjn5g4xdG9U88iQyGaFjbNg+QeQuRkCakOXYdD+LvCq3J33IufL3OS5vPzHywR4BjCq5yjahbc76ZqconKe/2YjP23cR1x0MO/eGEehdbyFAAAgAElEQVSjsGoWRhbnwM/PwPqvoFYzuOoDiO7o6qpc5qxnzF5ICmZFRERERJyyS7OZljSNGVtmkFeehx/1yEnrjK2wNb2aRPC/W9pWio28zqcKRwVzk+cybv040orSiK0Zy9A2Q+lWu5sC2pMosdrYdyC4zSm2cmVcbQA+XLCNnzbuIz2/jNxiKwBhAV6seq4vAI9N/4vN6QUHxyREBHnTsJb/wfVlFXa83C369a+kTNNkQ1o+aftLGdg6EofD5F8TVnFJ45pc1y6Kmv5e/1zgTO6LsuC9llCjoXNcQavrwd3TNU9CxMWsditvrXqLGVtn0CGiA2/3eJuaPif/4dVPG9J5/puNFJbZeKxfY+6/pEH1mhFumrBpDvz4JJTlwSUjnG/uXidfexFTMCsiIiIiUoVs3b+VyZsn88OOH7A5bDQP6syahNYEuzXjlg7R3NIhptrt5Fxhr+Db5G8Zt34c6cXpxNaK5cG4B+lSu4sCwrNQVmEno6CMglIbraOcu2SPXZTMn7v3s6/A2YmbVVhOXHQw3z7YDYDL31/CruzigzNuI4K8aV83hFs7xgCwI6uIUD9Pgnw89L25gPJLK/h2XRrTV+1hc3oBdYJ9WPJk7+OPryhIh5WfQMZmuGOW87HMJOf8WH3fpBpLL0pnxKIRbMjewN2t7uaRto/gbjnxNk37i63897tNfJewl1Z1Ann3xjY0jahmneYFe+GHEbDlR6jdFq76ECJaubqqSkHBrIiIiIhIJWeaJsv2LmPSpkn8kf4HFjxpFdiX1/sMI8I3ivmJzs16qvsGTxX2Cr5J/obP1n9GenE6bWq1YWiboXSJVEB7vlhtDorKbYT6Obsnp6zYTXJW0RGjFNrXC+Gj25y3+LZ/5Vdyiq14e1icwW2gNwNaRRzcEO33LZnU8vciMsibUD9Pfd/OgS9XpvDS3E2U2xy0rB3ILR1juCqu9rE76jOTnOMK1n8Fph1aXA1Xfwye1euHPSLHsjxtOU8teYoKRwWvdXuNPnX7nHTNr5szeGbOBvJKrDzSpzFDezU8chO9i51pwtqJ8Mt/wF4Blz4HnYaC24nD7OpEwayIiIiISCVVbi/n++Tv+WLTJHYV7MBiD6Q0pwveZd15sEdrHujZ0NUlVkpWu5Vvtn/DuPXjyCjJoG1YW4a1GUaniE4K+lzg8A2kftyQzt68UmdwW+AMbns2qcUjfRpTVmGn2X8O7Qft6W4hItCbe7rV465u9Sm32Zm2MoWIA2MUIoO9qennpU3L/iGrsJw5a1O5tFkYjcMDWLMrl2/WpXFLhxha1Qk6/sKkH2H6reDuA23vgC4PQmj9C1e4SCXlMB18tv4zPlr3EQ2DG/Jer/eoF1TvhGvySyp46ftNzFmbRrOIAN69KY6WtU/w5+9ilJMMcx+FXUug3iVw5RjnOBQ5goJZEREREZFKJqc0hxlbZjB9y3Ryy3LxJZqctC7EhvTkjk4NGNQ6Em8PN1eXWelZ7VbmbJvDZxs+I7Mkk3Zh7RjWZhgdIzoqoK2EbHYHm/YWHOi0LSX9QNdtvxbhXBlXm905xfR8Z+ERa9wtBi9f3YrbOsWQWVDGZ0t2HAxu/97ALCzA+6Lf7dzuMFmyLYvpq/bwW2IGNofJ85c3575LGhx/kcMOST+AxR2aDQJrMaz4GNrfA341LlzxIpVYfnk+zyx5hiVpS7i8weW80PkFfD1O3EH++5ZMnp69nuwiKw/2ashDlzauXne02G3OUSgLXgM3D+j/CrT7l8agHIeCWRERERGRSiI5L5nxGyby487vsZsVdAzvxv1xd+PnaIqbxULzyEBXl1glldvLmb11NuM3jCezNJP24e15sM2DdIjo4OrS5DQ4HCa5JdbDxiSUHgxu28aEsDZlP7d9toKyCscR6z66rR2Xx0ayITWfTxZtJyLQh4ggr4MBbovIQPy8qu5ttQ6HyWX/W8y2TOf83uvb1eHmDjHH3+m9ohQSpjlHFuTugIZ94M45F7ZokSpgc85mhi8cTkZJBk93eJqbmt50wh/qFZZV8Or3iXy1Zg+Nw/x596Y4YqOCL2DFlUDGJvj2Idi7FpoOgsvfhcDarq6qUlMwKyIiIiLiQqZpsiJ9BR+v/Zx1OSswHe5U5LenrvtlvHN1P+Kiq9mLuvOo3F7OrK2zGL9hPFmlWXSI6MCwuGHERxz1ekiqKNM0yS+tODjfNj2/jB5NahIV4svSbdm8OHcT6XmlFFvtB9fMGdaVdjEhzNuYzscLk4kI9D7QcesMbvs0DyPA2+OIkQyuZLU5mJ+YwbLkbF65uhWGYfD50p2EB3rTr8VJZk2vmwa//geKs6B2O+j2CDS/CizqwBc53Jxtc3htxWuEeIcwutdoYmvFnvD8pduyeXJWAvsKyri/R0Me69u4et3ZYiuHxaNg6WjwDoZB70DLa9UlewoUzIqIiIiIuIDVbuXHnT8yafMktu3fhmnzx5Hfjf7R13B351bERgVVihDoYlRmK3MGtBvHk12aTaeITgxtM5T24e1dXZpcAKZpUlhuOxjctq8bgr+XO78nZfLF8l3syy9jb34phWU2AP545lIig3z46PftjF+682BwGxnsTWSQD/d2r4+3hxv5JRV4uBv4ep6f7tvkrCJmrN7D7LWpZBdZiQj05ruHuxEW4H3ihft3g3cg+ITApm9g3VTo+gjU667QROQfyu3lvL7ydeZsm0PnyM681eMtQr1Dj3t+cbmNN35KZMqKFBrU8mPUjXG0iwm5gBVXAntWObtks7dA7C0w4A3wPf6vmRxJwayIiIiIyAWUV5bHJ2snM2f7DMrMPBoFN2Jwi8H4VsTTuX7EsXdKl/OizFbGzK0zGb9hPDllOXSK7MSDbR6kbVhbV5cmlUDRgfC2fk0/3CwGv2/J5NfNGUeMUigos7H11YG4WQye+3oDU1emEOTjcXDGbVSID69e0xqA7ZmFAEQE+eB/mqMTft+Syd0TVuNuMejTPIxbOsTQo0mtE8/O3bsOlr/vDGN7PQ09nzzjXwuR6iC1MJXhC4eTmJvIv1v/mwfbPIjbCbrJV+zI4YlZCaTuL+XebvUZeVnT6tUlW14EC16FlZ9CYB248n/QuJ+rq6pyFMyKiIiIiFwAW3KTeXv5Z6zO/gXTqMBR3IR2wdcw/qbb8HSvRi/kKqFSWykztszg842fk1uWS5fILgxrM4w2YW1cXZpUcmUV9oNBzPLkbP5KyTsU3BaUYprwwyOXAHDPF6tZkJQJQICXOxFB3rSuE8Tom52/zxZuycQEIoO8sdoczPozlSbhAdzRuS6lVjuT/tjFte3qnLxDdvt8WDYGdi4CzwCIvws6DYWgOufpV0Gk6lucupinlzwNwBvd36BndM/jnltqtfPWvCS+WL6LujV8GXVjHB3qVbMO0eQFMPdRyEuBDv+Gvv8FrwBXV1UlKZgVERERETlPTNNkTcYaJm6ayKLURZgON7zLO3Bdw1t5oEs3avh7ubpEOcw/A9qutbsyrM0w4mrFubo0uQhsTMsnOavoYHCbnl9KkI8Hb9/g/P01cMwSEtMLDp7v6W7h35fU54nLmp384g4HWA7Mlp16E+xbD52GQPzd4B10Pp6OyEXB7rDzScInjF0/lmahzRjdazTRAdHHPX/NrlxGzkxgV04Jd3Wtx5MDmp638SWVUul++OV5+GsK1GgEV30Adbu6uqoqTcGsiIiIiMg5VmIt570/ZvDtzumUGimEeIVwTcMbaerbn4HNm2A50e3H4nIlFSV8teUrJmycwP7y/XSr041hccNOuvmLyNlIzy9lb14p6flllFc46NM8jGBfzxMvKi+CtZOctxIP/hZC60PhPvAJBfeTrBWp5vaX7efpJU+zfO9yrml0Dc91eg5v92N3pJdV2Hn3ly3839Kd1An24Z0b4ujSsMYFrtjFNn8HP46E4mzo9ij0fAo8TtLBLyelYFZERERE5BxJyszgzaUT+HP/9+Cej1ERxuV1b+G/ve887os9qbxKKkqYljSNLzZ9QV55Ht3rdOfBNg/SqmYrV5cm1V1hhjOMXTMeyvIhpisMehsiWru6MpEqYUPWBoYvGk5uaS7PdnqW65tcf9xz/0rZz8iZCSRnFXN7pxieGdT8tOdEV2mFGc5ANvE7598xV38EkbqT5FxRMCsiIiIichZM02RP4R6mJE5heuJsTMNKgNmCW5vcwQMdB+LpXo1evF2kSipK+DLpS77Y9AX55fn0iOrBsLhhtKzZ0tWlSXVUXgTvNgNrETS/0tm5FnXUa3oROQbTNJm5dSZvrnqTMN8w3u31Li1rHPvv8nKbnTG/bePTRclEBHrz1g2xXNK41gWu2IVME9Z9CT8/CxWlzk0Euz4Mbtqk9FxSMCsiIiIicgYyCkr5YNmvzEv9CqvnBtwsbnQJ68vNTW6nZ31tGnUxKq4o5svEL5m4eSL55fn0iurFkDZDjvuiXuScSVkB23+DS593fp4wHaI6QI2Grq1LpAoptZXyyh+vMHfHXLrX6c6bl7xJkNexZzBvSM1n5MwEtmQUclN8FM9f0YJA72oUSO7f7dzca8fvENPFOUu2ZmNXV3VRUjArIiIiInKKTNNk6bYM3l85h83Fc7H4pOJm+nF94xsZ0nYwtXyrUSdNNVZkLeLLpC+ZuGkiBdYCekX3YljcMJrXaO7q0uRi4nDA1p9g2RjYsxJ8QmDYSggId3VlIlXO7oLdPL7wcbbv387QNkN5IPYBLIblqPOsNgcf/r6dj37fTk1/T968LpbezcJcULGLOOyw6jOY/zIYBvR9EeLvPbS5oJxzCmZFRERERE7CNE2KKooY99eXfL5hMhaPfPwtkdze/A7ubXMjPu4+ri5RXKDQWsjUxKlM2jyJQmshl0ZfytA2Q2kW2szVpUllZJrON4vF+b44C+zWA28Vzvc+IRAUBRmbYMa/IGcbBMdAl4eg7R3g6efqZyFS5SxIWcBzS5/DzeLGW5e8Rbc63Y553ua9BYycmcDm9AKua1uH/17ZkiDfatQlm5kE3z0MqaugUT+44j0IjnZ1VRc9BbMiIiIiIsdgmiZrU/bzf8vXkGL/lWxjMSW2EhoFxDG07T30rdfrmN02Uv0UWAuYunkqkzdPprCikD4xfRgaN5SmoU1dXVr1YJrgsDnDTYcNvAOdjxfshfLCQ+GnzQrunlCnvfP4joVQ9I9w1K8mtLrOefyPj6Aw/VBoardCrWbOGYsAcx6Aon2HjtvKoUEv6P+K8/j77aB0/5Hr293pvCXYNOGl4KOfS+dhMOANZ93TboX4u6H51eCmWdUip8vmsPHBXx/w+cbPaVmjJaN7jaa2f+2jzquwO/h0YTLvL9hGkI8Hr1/bmv4tI1xQsYvYrM7O/MVvg6c/DHgTYm9ydszKeXe8YFZ/64uIiIhItVRYVsE3f6UxYc0i9vIz7gEbsRgWBtUfwJ0t79Q8UTlKoGcgQ9sM5fYWtzNl8xQmb57M/JT59KvbjyFxQ2gS0sTVJZ6Zg4HngdDx74AxINIZFBZmQH7q0V2fjfo6A9C0PyE94chg0m6DHiPB4gYb58CuJYfW2soBE26a5Pz6i9+BLfOOvLaXPzyw2Hl81j2QONf5+N+C68Jj650ffz0Edi468jmFtYRhy50fz38F0v7R+BPd6VAwu24a5CY7N7px83S+OeyHzi3Lc26I4+YJXgHgW8MZ7P6tyWXOut08D12j9oH504YBl48Gizu4ex06HnpgZqxXANz1/Rl920QEskuzeWrxU6zat4obm9zI0x2fxtPN86jztmYUMmJGAhvS8rkyrjYvXdWSUL+jz7topa11dslmbIRW18OAt8BfY5kqA3XMioiIiEi1YpomdtPOY3OnsGDvTNx8U/C2+HFj0xsZ3PJ2IvyqUfeMnJX88nwmb57MlMQpFFcU079uf4bEDaFxyD82TnHYoSz/H8FlBfiHgW8olBXA3r/+cdzqDA9D6kJeCmz+9lA36MGuzMFQq6nzBfeKT44MNu1WZzdUeAtI+hHmv3T08bt/grDmsHIs/PTk0U/w0QQIqQdLRjvX/9MTO8CvhjP4XDLq6OPP7QMPH/jtRfhrypHBpbsXDFnqPG/ZGGdX68HjXuAdBFeMdh5fNw2ytxy53ifE+fzBubYk51Co6uYBXkEQdaBjNnen83twePDq7uUMf0WkylqXuY4RC0eQb83nP53/w9WNrj7qHLvDZNziHbz361b8vd155epWXB4b6YJqXcRaAgtfd94Z4B/u/EFRs0Gurqpa0igDEREREam2Sqw2vk9IZ/LKLbRtuZVVud+RVpRGmHdt7o39F9c0ugZfD19XlymnyjSdAaPpAA9v52P5aWArOzJ49A46tLv0lnlgKz3yeI3GUP8S5+ZLi946uiO0QU9oeS1Yi2H2v/9xvNwZDLa/i/ycbUyaeS1TvaDEgP6lVobmFdCw1wvQZRhkbYGPOh79PK4cA+3vcnacfnbp0cevHw+tb4Adi2DSVYcet7g7w8WbJkPjvpC8AL5//EDoeFhX5qB3IDIWdi11hq+HB5duntD9Meec07S1kDz/6OMtrnGOC8hJhpzth0LTv88Jb+l8X5rn/LU/fK3FQ5vIiMh5YZomXyZ9yajVo4j0j+S9Xu8dc6RMclYRI2cm8FdKHgNaRvDqta2o6e/lgopdZOcSZ5fs/p3Of2v6vez8d1FcQsGsiIiIiFQ7WzMK+XJlCrMTNmD1W4JXyCpMSxntwtoxuMVgekX3ws3i5uoyqxbThMxEKM4EDGd4CbBzMezffWSw6ekHHe51Hl/1GWRvOzLYDKztfKEIMPcxyEo6MvyMiIXrP3MeH9cLsrY6jzkqnI81vgxun+H8eFRT5xzQw7W8Dm6c4Pz49SiwFh55vN3gI+eAWjyODBfj74HezzhvYx/f71Bw+XcwGneLcz5fWT7Me4Y8AyZZ05hauptS086AsA4M6fI8DbxCIGH60cFm7bYQWt85Z3TfhiOv7eYJAeHOW93tFYeCTwWeIlKNlVSU8OLyF/lp10/0iu7Fa91fI9Az8Ihz7A6TCct28s7PW/D2cOPlq1tyVVxtjOoyS7UsH359Af78AkLqw1XvQ/0erq6q2tOMWRERERGpFkzTxDAMHA6TwVNmUeA5H/eYDXhj0L9ePwa3GEzrWq1dXWbVs3u5M1zc9isU7nU+5lsTnkx2frxyLCT9Y1ZmcN1Dwez2+ZDyx5FdmWHNjzzfzcMZ5v59PLTBoWPNroC6+49cX6PhoeMDXnfONT38dvXAw25XvWeec97p4bfLex7okjYMeGH/8QNPD59Dt90fi3cQXPMxwcAjwJ1l+5m4aSJfJn3JvG+vYUD9AQyJG0KDoAbHXu8VAHW7Hv/6bh7ONxGRamxH/g6G/z6cnQU7ebTdo9zT6p6jNufclV3ME7MSWL1rP32ahfHGda0JC/R2UcUusOUn5x0URRnODQx7PXvo3zqplNQxKyIiIiIXhV3ZxUxblcKCLfsYeY2dL5MmszZzLX7ufv/P3n3HR13Yfxx/3V3GZV72npCdEJYEUVFcoCCiCCIb3Frrr4622mlrbavVqq2CSFVGGRVx4GQpiAhhSQIJJCEJCQnZe978/v74hkAUFSHJJeTzfDzyMHffu8sneURy977P9/NhWtw0ZiXOOuuWZvE9avIhbxMMn6vO4tz+HHz9bxh8NcSOV0NTBz2EjlBv31ShdrOe2XGqc1KXQw1Qde11LMtaxpqjazBajdwYfSP3pd5HtCHa3qUJIUS/svH4Rv6w8w/oHfQ8e+WzXBp8aZfjNpvCyt1F/P3TozjoNPxxcjK3jQgdOF2yzVXw2a/h8Hp1+eGUf0PoSHtXJc4gowyEEEIIIcRFx2y1sfVIBavSi9lxrBRn7wN4Bu6inUpC3EKYkzSHW2Nuxd1Jlvz8KIsJinaqYWzeJnWmKMCcdyHmWvV0e53zgA5az1dtey3LDi9jbc5ajFYjE6Mncv/Q+4n0jLR3aUII0aeZbWZe3P8iK7NXkuqfygtXvfCdJZ0nalv51TuZ7Cqo4co4f569bQjBBhc7VdzLFAUOrYNPf63+nb7qV3D5L+RvdR8kwawQQgghhLhonBpXkF5Qwx1vbsQ7aA9az3SMSjOpfqnMS57HtRHX4qCVyV0/qKlc7XL1ioCyTFgyVg1fo66AuAkdnbHS3dldatpqWJa1jLVH12Kymbhp0E3cl3ofEZ4R9i5NCCH6nKrWKh7f/jgHKg8wK2EWj1/yOI5njHVRFIU1e07wzMfZaDQafjcpkRmjwgdOl2xDiTq2IG8ThI2Cm1+BgAR7VyW+hwSzQgghhBCiX7PaFL7MrWJVehERPm5Mv0zDiuwVfFL4KaBwbcS1zEuax7CAYfYute+y2eDkAcjdCHkboSxDHVUw5RW16yZvkxrKOrnZu9KLWnVbNW8dfou3c97GbDN3BrThnuH2Lk0IIfqEfeX7eHz747RaWvnjmD8yadCkLsdP1rfx6/WZ7Mir5rLBvjw3LZUw7wEyS9Vmg31vwJanQLHBtX+AtHvVOeqiz5JgVgghhBBC9EuVTe2s21fC6vRiSutb8PYrICBsNyeNh3F1cGVq7FRmJ84mzCPM3qX2TRbT6VMal1wFZQdBo4WwNIgbD/ETv7uES/SK6rZq3jz8Jm/nvI3FZmHy4Mncm3ov4R4S0AohBiZFUVietZyXDrxEuEc4L457kRjvmC7H1+0v4ekPs7HYFH4zMYHZoyPRagdIl2x1Hmx4GIq/hkHjYPLL4B1l56LEuZBgVgghhBBC9BunRhUAPLE+k7X78kmIzcHoup1qYwmBroHMSZzD1LipeDp52rnaPkZRoCpH7YjN3QR1hfCLw6DVwoGV4OAMMdeBq4+9KxUdqlqrOgNam2Lj5pibuWfIPfJmgxBiQGk2NfOHr//A5qLNXB95PX++7M9dZsRXNLbz5LuH+PxoJWnRPjw/bSgRvgOkS9ZqVhdwbvs7OOphwl9h2GwYKGMbLgISzAohhBBCiD6vrsXE+gNqd+yLM4YR4mtmyTcr+aRoPU3mRpJ9k5mfPJ/rIq/DUev44w840Bx+F7b8EeqL1cuBKeqc2Ct/CU4D5MVrP1bZWskbh97gndx3sCk2psRM4Z7Uewh1D7V3aUII0aPy6vJ4dNujnGg6wSMjH2Fe0rzON2gVReH9g6X88YMsTFYbv5qQwILLogZOl2xZBnzwEJRnQuJkmPg8eAT9+P1EnyLBrBBCCCGE6JMUReFAcR2rdhfz0aEyTBYbKVEthEfvJb1yKxabhavDr2Ze8jxGBIwYOEs9fkz9idNdsVc/CSHDIf8LSF+ijiiIHQ8G6bjsjypaKnjjsBrQKorCLbG3cM+QewhxD7F3aUII0e0+LviYP+36E26Obvzjyn9wSdDp7Kqqychv3zvEpuwKRkR48fz0oQzyd/+BR7uImNth+7Ow82Vw9YVJz0PSFHtXJc6TBLNCCCGEEKJPsdkUtFoN7WYro57ZgqIoXJZSTbN+K4dq9+Hi4MKUwVOYmzRXttaf0lYHX72kLumqzFav84pUu2fixtu3NtHtylvK+c+h//Bu3rsoKNwacyv3DLmHYPdge5cmhBAXzGw18499/2DN0TWMCBjB81c9j7+rf+fxjzJP8vv3D9NisvL4+DjuumIQuoHSJVu0CzY8BDXHYNgcGP+0jCDq5ySYFUIIIYQQfcLh0gZWpRdxuLSRDQ9djslm4tU9/2Nb+XoKGwsIcAlgZuJMpsdNx+BssHe59tVSA8e2qMu7km9Vu2eej4PgVIibALETwC9WZsxd5E4FtOvz1gNwW+xt3D3kboLc5FRWIUT/VN5SzmPbHyOzKpP5SfP5v5H/1zmiqLbFxO/fP8zHh8oYGmbg+elDiQ30sHPFvcTYBFv+BHuXglcE3PQSxFxr76pEN5BgVgghhBBC2E2rycJHGWWsSi8io6QBvaOWCaluREdn8l7+Omrba0nwSWBe0jxuiLoBR90Anh9bfhhyPlW7Ykv2AgpEXwXzN6jHze3q4g8x4JQ1l7H00FLeO/YeGjSdAW2gW6C9SxNCiHO2u2w3v9r+K4xWI09f/jTjo06f8fHZ4XJ+9/4hGtrM/OK6OO67chAOOq0dq+1FeZvhw19AYymMvh+u+R04D5CxDQOABLNCCCGEEKLXnRpX8FHmSR5a/Q2xAe7cMBxqdFvZWPQxJpuJq8KuYl7SPEYFjRqY82ONzVCyBwZfo15eMwtyPoaQER1dseMheBhoB8gLU/GjTjafZOmhpbyf9z4ajYZpcdO4e8jdBLgG2Ls0IYT4XjbFxpuH3+Tf3/ybaM9oXrz6RaIN0QDUt5p4akMW7x88SXKIJy/cPpSEIE87V9xLWmvhsychcy34xcOUVyA8zd5ViW4mwawQQgghhOgVRouVzw6Xs2p3MVfG+fHQNbEYzVZWH9pCes177Dy5E2edM1MGT2FO0pzOF2UDSk0+5G5Uu2KLdoLVBI9kqcu6qo+B3hPcJWQTP6y0uZSlmUv54NgHaDVapsdP566Uu7rMaBRCiL6g0dTIb3f8lm0l27gx6kaeuuwpXB1dAdh6pIIn3j1EXYuJh66J4WdXx+A4ELpkFQWy3oNPfgnt9XDFo3Dl4+DgbO/KRA+QYFYIIYQQQvSo49UtrNlTzLr9JdS2mIj0deXeKyNw8z3EiuwV5NXl4av3ZWbCTG6Pvx1vvbe9S+49FhPYLODkCofegfV3qdf7xakdsXETIGIMDOQRDuK8nWg6wdLMpWzI34CD1oHpcdO5M+VOCWiFEH3C0dqjPPLFI5S3lPP4qMeZlTALjUZDQ5uZpz/K5p39JSQEefD89KGkhA6Q2fKNZfDxYx1nyAyHm1+BoBR7VyV6kASzQgghhBCi250aVQBw17K9bMut4vrEQKaMNFBk2sranDXUtNcQ6x3LvKR5TIyeiJPOyc5V95KmcrUjNncjFGyD6/8Eo+6G5kq1QyZ2PK1HlLgAACAASURBVPgMwG5h0WNONJ7g9UOv82H+hzhoHbg9/nbuTLkTPxc/e5cmhBigPjj2AU/vfhqDk4EXxr3AsIBhAGzPreKJ9ZlUNhl54KrB/PzaGJwddHauthcoChxYAZt+D1YjXP1buPRB0DnYuzLRwySYFUIIIYQQ3aa0vo3/7Snm7X0lrLt/DOE+ruRXNVNnKuHT4nVsyN9Au7Wdy0MvZ17SPMYEjxk482MtRnhzApz8Rr3sGaqGsMPnQNh3no8L0e2KG4tZkrmEjwo+wknrxIz4GSxMWYivi6+9SxNCDBBGq5G/7/k77+S+Q1pQGs9d+Ry+Lr40Gy0883E2a/acICbAnRemD2VouJe9y+0dtQWw4WE4vgOixsLkl8F3sL2rEr1EglkhhBBCCHFBrDaFL3OrWJVexOdHK1GAcXH+PHFjAg3KUVZkr2B7yXactE5MHjyZOYlziPGOsXfZPautHvI/VztjAW59Tf3vhofBOxJiJ0BgMgyUUFr0KUWNRbye+TofFXyEs86ZGfEzWJC8QAJaIUSPKm0u5bFtj5FVk8VdKXfx0PCHcNA68PWxan75TiYnG9q4d+wgHrk+Dr3jAOiStVlh9yL4/Bl1ZNH4p2H4PFnqOcBIMCuEEEIIIc6L1aag02qobjYy5m9bMbg4MWNUGLeNDCar4UtWZq/kSO0RfPQ+zIifwYz4GRd/8HPoHdj3FhTvAsUKLt6QcBPc/G8JYUWfc7zhOEsyl/BJ4Sc465y5I+EOFiQvwEfvY+/ShBAXma9Kv+KJHU9gtVl55opnuCbiGlqMFp797CgrdhUR7efG89NTGRk5QP79qciCDx6Ckwcg7ka46Z/gGWLvqoQdSDArhBBCCCHOmaIo7MqvYVV6MXWtJlbfcykA+4vqiPCDDwrWs+bIGirbKhlkGMS8pHlMGjQJvYPezpX3AHMbFO5Qu2Kvewqc3WH7c5C9AeLGq12xYZeAdgB0/Yh+rbChkNcyXuPTwk/RO+iZmTCTBckLBtYiPiFEj7ApNpZkLGFxxmJivWN5cdyLRHhGsKewlsfXZXCirpWFl0XzywnxuDgNgL+XFiPseEH90HvBxOcgeaq8eTuASTArhBBCCCF+VF2LifUHSlidXkxBdQtero5MGxHGkxMTKW0+wcrslXyQ/wFtljbGBI9hXvI8Lgu5DK3mIjsdr6UGst5Vw9jCL8HSDo6uMG8DhI8Cm01OQRT9VkF9Aa9lvsZnhZ/h4uDCrMRZzE+aj5d+gMx5FEJ0q/r2ep746gl2lu5k8qDJ/H7M78HmxD825vDW14WEe7vyj2mpjB50kZ9Nc8qJvbDhIag6CqkzYMLfwG2AfO/ie0kwK4QQQgghzkpRFGwK6LQaln99nD9uyGJkpDezR0dwY0oQ2XUZrMhawRcnvkCn1TEpehJzk+YS7xNv79K7j9UCJXtAb1BnwpZlwpKx4B2ldsTGjYfIK8DxIuwIFgNWfn0+r2W8xsbjG3FxcGF24mzmJ8/H4Gywd2lCiH4iqyaLR794lKq2Kp5Ie4LpcdM5UFzPL9dlUFDdwtxLI3nixgTcnB3sXWrPM7XA53+B3YvVcQU3vaQ+fxACCWaFEEIIIcS3NLWbef+bUlalFzP/sihmpkXQ1G6mpK6NmEAXthRtYXnWcrJqsjA4G5gRP4M74u/A39Xf3qV3j5YaOLYZcjdC/lZob4AR89Q5sYoCNcfAN0ZOOxQXvWN1x3gtUw1o3RzdmJ04m3lJ8ySgFUJ8L0VRWJ+3nr+m/xU/Fz/+Oe6fxBgSeXFLLku/LCDY4MJz01K5PMbP3qX2jvwv4MOHob4YRt0N1/4R9J72rkr0IecdzGo0mnBgBRAE2IDXFUV5+Vu30QAvAxOBVmCBoigHOo7d0HFMB/xHUZS//1ixEswKIYQQQvScw6UNrEov4oODJ2k1WUkJ9eTha2IZnxxEo6mRd3PfZdXRVZS3lBPlGcXcpLlMHjwZFwcXe5d+YRQFGkvBEKZ+/nKq+gLKLQBix6tdLYOulhdSYsDKq8tjccZiNhdtxt3RnTlJc5ibNBdPJ/l/QghxWrulnb/s/gsf5H/AZSGX8fexf6e4SsNj6zI4VtnMzLRwfjMxEQ+9o71L7XltdbDpd/DNf8FnsPrmbtTl9q5K9EEXEswGA8GKohzQaDQewH7gFkVRss+4zUTg56jB7GjgZUVRRms0Gh2QC1wPlAB7gZln3vdsJJgVQgghhOheFqsNB506E3XKqzvJKW9kytBQZl8aQWqYFyVNJaw6sop3896l1dJKWlAa85LmMTZsbP+eH2tsgoJtalds3mZQbPBYjjofNncjuPlD8DCZFyvEGXJqc1iSuYTNRZvxcPRgTtIc5iTNkYBWCMGJxhM8su0RcupyuH/o/SxMuodFXxSyeHs+/u7OPDstlaviLpIza35M9gb45HFoqYbLH4arfg2O/fxNbNFjum2UgUaj+QB4RVGUzWdctwTYpijKmo7LOcA4IAp4SlGUCR3XPwmgKMrffuhrSDArhBBCDAAnD6rzt9wD7F3JRS23oonV6cV8fKiMzY9ciZerE8cqm/D30GNwceRg5UFWZK9ga/FWtGi5IfoG5ibNJck3yd6lnz9FUccP7H5N7WKxmcHZEwZfrc6LTb0ddAOgi0eIC5RTm8PijMVsLd6Kh5MHc5PmMidxDh5OHvYuTQhhB9tObOM3O36DRqPhb2P/ho9mKI+vy+BoeRPTRobx+5uSMLgMgL+vTRVqIHtkAwQNgZtfgZBh9q5K9HHfF8z+pOnLGo0mChgOpH/rUChw4ozLJR3Xne360T/lawohhBDiIvXuvVCdA/4JEDUWoq+EqCvA1cfelfV7RouVzw6Xs2p3MXuO1+Kk0zJxSBCtJiterhDl58LW4q2syF5BZlUmHk4eLExeyMyEmQS6Bdq7/J/OYoKinZC3Se2CnboUwkZCcCpcer8axkZcKmGsED9RvE88L139Ekdrj7Lo4CIWHVzEyuyVzEuax5zEObg7udu7RCFEL7DarLx68FWWHlpKok8iz135PO/vbeeVz3fi7ebEf+ZdwnVJ/fD5w0+lKHBwNWz8DZjb4No/wGUPy/MLcUHOOZjVaDTuwHrgF4qiNH778FnuovzA9Wd7/HuBewEiIiLOtSwhhBBC9BeKAp89Cb6DIe0euHUxFH4JhTvUJ7l7l8KQ2+G2peptj22F8FGgl+Uz5+rUuIKSujb+b+1Bonxd+c3EBKaNDMfHzYlmUzMrs1ey6sgqSptLCfcI58m0J7kl5hZcHV3tXf5P11imdqwUbANTM+icIXosnU83Iy9TP4QQFyTBJ4F/XfMvsmuyWZyxmFcPvsrK7JXMT57PrIRZEtAKcRGrba/lV1/+ivSydG6LvY2pkT/jZ8uPknWykSnDQvjTzcl4uTrZu8yeV1cEH/4fFHwB4Zeqs2T94+xdlbgInNMoA41G4wh8BGxUFOWfZzkuowyEEEII8cO+egm2/BHGPAQTnul6zGKCkwfA0VXtcqzJh3+PAI0Wgod2dNNeCZFjwMnNPvX3UWarjS3ZFaxKL8bg4sirs0cAkHGiniGhBrRaDWXNZaw6sor1eetpNjczImAE85LnMS5sHDqtzs7fwTmy2dTfkdyN6giMSxaq3SqvjVU7reMmqL8n8vshRI/Lqsli8cHFbC/ZjsHZwILkBcxMmImbo/z/J8TFJKMqg8e2PUZdex1Ppv2GipNDeWlLLp56R565NYUbUoLtXWLPs1lhz1LY+md1TNJ1T8Eld8lsevGTXcjyLw2wHKhVFOUX33ObScBDnF7+9S9FUdI0Go0D6vKva4FS1OVfsxRFyfqhrynBrBBCCHGRyVwH794NKbfB1P/8+JNZixFO7IHjO9SO2pK96pzQ6csg+VZoKIWaPAgfPWCXLBwpa+TjzDLe3neCyiYjIQY9c8ZE8uC4mM7bHK4+zIqsFWwq2gTA+MjxzEueR4pfir3K/umOfqwu1zi2BVqr1bB+2GyY8oq9KxNiwDtcfZjFGYv5suRLvJy9Ojto+2UHvhCik6IorM1Zy3N7nyPQNZBHU//Cq5uMZJyoZ+KQIJ6ekoKvu7O9y+x5lUdhw8+hZA/EXAc3vQRe4fauSvRTFxLMXgHsAA4Bto6rfwNEACiK8lpHePsKcAPQCixUFGVfx/0nAi8BOuBNRVG+1SLzXRLMCiGEEBeRgm3w32nqjM8568HhPJ7Im1rgRDqEjAAXL/j6Fdj0W9A5QVia2ikZPVb9XPeTRuj3C20mKweK60gvrOXusdF46h3519Y8XtySy9XxAcweHcG4+AB0Wg1Wm5VtJ7axInsFByoP4O7ozrS4acxKmEWwex/vbFEUqDyiBvEj56vXrZkFxV9DzPVqV+zga2QOsRB9zKGqQyzKWMRXpV/h7ezNgpQF3BF/hwS0QvRDreZW/rz7z3xc8DFjQ68kQXcvr2w9iauTjqenpDB5aIi9S+x5FhPsfBm+fE49E+eGZ9XFoZqzTesU4tycdzBrDxLMCiGEEBeR3YvhwApY+KkaqnYHYxMU7YLC7WpXbVkmaHXw6yJwdleP6RwheFi/DWpP1LayKr2YPYU1ZJY0YLEpaDWw9t4xpEX7UNtiQgN4u6lz3VrNrbx37D3+m/1fSppLCHUPZXbibKbGTu3bpxeb29RZw7kbIW8zNBSr1z+WCx6B0FKj/t70l5ELQgxgGVUZLM5YzM7SnfjofViQvIAZ8TMkoBWinzjecJxHtj1Cfn0+s+LuYc/B4RwoauD6pECeuTWFAA+9vUvseaUH1C7ZisOQPBVufA7c/e1dlbgISDArhBBCiN6lKKc7C8zt4NiDT+Zba6Eiq2PxE/DWJCj6Cpw81Lm0UWNh8NUQNKTnargAtS0m9h6vJb2glnHx/lwZ50/WyQZueXUnqWFepEX7kBbtw8hIbzz1XTf/lreUs+boGtblrqPJ1MRQ/6HMS5rHNRHX4KDto6F0fTE4e4CLNxxYCRseUucLDxoHsePVD0OovasUQpyng5UHWZyxmK9Pfo2P3oc7U+7k9vjbcXEYmKNnhOgPthRt4Xc7f4ej1pFrfB7hfztccNJp+dOUZG4ZFormYu8WNbXCtr/CrlfBPRAmvQAJk+xdlbiISDArhBBCiN7TVger74Brfnc6LO1NzVVqJ+2pGbU1eeop8HPfU49nvg2ByeCfaLflDUaLlb98dIT0whpyK5oBcHbQ8ssJ8dw9dhA2m4LRYsPF6eydokdrj7IsaxkbCzdiw8Z1EdcxN2kuwwKG9ea3cW6sFnUURd5GyN0EVUfUFzyj7lY7Ysu+gcgreja8F0L0um8qv2HRwUXsLtstAa0QfZTFZuHlAy+zLGsZcV7J2Mrn8E2hhqvj/fn7bakEeg6Av82FO9Qu2bpCGDEfrv9z953lJUQHCWaFEEII0TssRlg5VV2UMGe9Ov/V3hrLoL0BAhKgrR6ejQIUcPVVu2mjx6pdml4R3f6lFUWhpK6NPYW17CmsxeDqyG8mJgIw/sXtBBlcGB3tw+hoH4aEGXB2+P5T9hVFYdfJXbyV9Ra7y3bj6uDK1NipzE6cTZhHWLfXfkGsFnWMhLEZXkpRw3qtA0ReBrETIHEyeEfau0ohRC84UHGARRmLSC9Lx8/FjztT7mR63HT0DgMg8BGiD6tuq+bx7Y+zv2I/w7wmsn//WHQaR34/OYnpI8Mu/i7Z9gbY/AfYvwy8o2Dyv2DQVfauSlykJJgVQgghRM+z2WD9nZD1Htz2BgyZZu+Kzq6+WO2OOL5DnW/aWAoTn4e0e6C5EnI+VcNa7+gLWvTwwqYc1u8v4WRDOwAGF0cmDgnib1NTATVoPZcXPWarmc+Of8ayrGXk1uXi7+LP7MTZTI+fjqeT53nX160UBcoyIG+TOi/W1Qdmr1OPbXsWAhLVUQX6PlKvEKLX7Svfx+KMxewp34O/iz93DbmLaXHTcNYNgO3uQvQx+8r38asvf0WjqZEA42yy8+IYG+vH329LJdRrAHS153wKHz0CzRVw6YNw9W/BSeZhi54jwawQQgghet7G38KuV+D6p+Hyh+1dzblRFKgtUOeduvrAoXdg/V3qMUP46Y7ahJvOGirabAo5FU2dHbFHyhrZ/OhV6LQa/rk5l/zKZkYPUmfExgV4oNWee9DbbGrmndx3WHlkJZWtlcR4xTA/eT6ToifhqHP88QfoLV+9CLtfg+Zy9XLICEi6Ga54xL51CSH6pL3le1l0cBH7KvYR4BLAnUPulIBWiB5U01ZDVk0WWTVZZFdnk1WTRVVbFd6OIdQW3oHNGMxvJyUyKy3i4u+SbamGT38Nh9+BgGSY8m8IHWnvqsQAIMGsEEIIIXqWzQrv3quOB7jx2QvqNLUrRYHqXLWTtvBLOP4VtNXCI1lgCMOa/yU0l6MbdBXvHTPz1IZsGtrMAIQY9Iwe5MtTk5MxuJ5/cFreUs6qI6t4J/cdms3NpAWlsSB5AVeEXmH/F0w1+WpHbP5WmL4cnN1h1yJ1hmzcBIi5DtwD7FujEKJf2Fu+l1cPvsr+iv0EuAZw95C7uS32Npx0TvYuTYh+q769nuya7M4gNqs6i/JW9Y1TDRp8nMLw1A6irSmYvIIELo0K4R/ThhLuc5F3iyoKHFqnhrLGJrjyl+obyA7y743oHRLMCiGEEKLn2GzqEi2bDVBA+/1zUvsbo9lMbtYBvqzzJb2wlmlFf+ZmzQ4A2gwxfKMbgmbQOMIvv50w7wt7UZNTm8PyrOV8WvgpNmxMiJzA/JT5JPsmX/g3ciFqC2DPUjWQrc1Xr/NPUIPZgAT71iaE6NcURWFP+R4WHVzEgcoDBLoGcs+Qe7g19lYJaIX4ETWtDewuzWB/WSZH67Ipas6l0VLReVxr8cfUGoqlLRRbWxhWYwjYnHHUaQg2uHDn5VHMGxP1k87m6ZcaStSxBXmbIPQSmPKKOmJJiF4kwawQQgghesbxnWr3wcw14BVu72ouWJvJSpvZio+bE8cqm5j0r68wWmwAxAd6MDrKwMKYJqIb96sdtUW7wHcw3K+Gtex9AzyCIPLyc9roqygK6eXpLDu8jJ0nd+Li4MLU2KnMTZpLqHtoT36r36+xDI5tBv9ECB8FZZnwn+s6lqRNgLjx6pIMIYToJoqisLtsN4sOLuJg1UGC3ILUgDbm1r41ukWIXtJmslLe2E5ZQxvlDe2cqK8jpy6H4uYcqs35tHIcxbGq8/Y2kw/W9lC0pnB8HAYR6hpLqKcPQQY9wQY9QQYXgjz1BBn0+Lo5XfxhLKgNA/vfhM1PgWKFa34Po++7qBoIRP8hwawQQgghul/lEXhzArgHwp0dC5/6mcZ2M/uL6jpnxGaW1DPn0kj+ODkZs9XGc58d5ZIoH9KifPB2O0v3ltUMTeVqKG2zwfMx0FoDGi0EpUL0lZA0BcK6Pg8z28xsOr6J5VnLOVJ7BF+9L7MTZ3N7/O0YnA299N13UBQo2Qd5G9Wu2PJM9foxD8GEZ9Tj5lZwcuvduoQQA46iKOwq28Wig4vIqMog2C2Ye1Lv4ZbBt0hAKy4KiqLQ2G6horGdsoZ2yhvaKGtoP+NyOycbm2hWitDpS9DpS9G6lKB1qkKjUfMbR8UbH4fBhLnGEeeVyNCAZGL8ggjy1OPp4mD/sUd9QfUx2PBzKP5aXT46+WV5U1nYlQSzQgghhOhejSfhP9eDzQx3bQbvSHtXdE7qWkyU1reREmpAURTG/O1zyhvbcdBqSA0zkBbty7WJAYyKOs+Q2WKEkr1QuAOO74ATe2DsY3D1k2BspuXLZ1nvrOG/Fbsoay0n2hDNguQFTBo0qXcX37TVQU0BhI1Ug9cXU6DpJISPhtjx6rzYgKT+OytYCNGvKYrC1ye/ZtHBRWRWZxLiFsI9qfcwJWYKjloJaEXfZLMp1LaaKO8IWMsazx68tpqsp++ksaB1LsdgKEfvfhKb0wnaKEVBPVvH09GbeO8khgUMYVjgEJJ8k/Bz8bPTd9gPWC2w69/wxd/AUQ8T/grDZsvzGWF3EswKIYQQovu0N8BbE6HuOCz8FIJT7V3R96psamdPYS3pBWpHbE5FE2HeLnz162sA+DDjJD5uTgyP8MLVyaH7CzC1gtVEpWJiVfrzrDv+MU1aLSPbTSx0CmFs5PVoR8zt+TEQiqJ2OOdthNxN6rIuV194LEedD1x6QO0k6Yddz0KIi5eiKOw8uZNFBxdxqPoQoe6h3Jt6L5MHT5aAVvQqi9VGVbOxM1wtb2jvGDWghq/lje1UNBgxWW1d7qfTagj0cCbIoCfA0wFXtxqsTsU0K4VUGPMpbcnHolgA8HL2ItkvmWTf0x8BrgHSAXuuyjJhw0NQlgEJN8GkF9TxUkL0ARLMCiGEEKL7NFfB2pkw7kmIudbe1XRRUtfKvuN13Dw0BK1Ww5PvZrJmzwncnHSMjPJhdLQPadE+XBLp3SsvdI7VHWN59nI+KvgIm2LjurBxLPAawpDqInVGbVkG3LsNQoap83qLv4boqyBkOFzoabvmNtA5q8HrlqfgqxfV6wOHqHNiYydA2Cj1uBBC9GGKorCjdAeLDi4iqyaLMPcw7k29l5sG3yQBrbhg7WYrlY1GdZ7rGZ2tZ3a9VjUZsX0rPnFy0KrzWz3VOa6BBj3Bnuo810BPR0zackrb8jhSk0V2TTZHa49ispkA8HD0IMkvqTOATfFLIdgtWELY82Fuhy+fg69eUt90nvS8OkZKiD5EglkhhBBCXDibDRQb6BzUz/tAoFfZ1M7nRyrVrtjCWkrr2wD47BdjSQjyJK+iiVaTleQQTxx0vVOvoijsLd/Lsqxl7CjdgV6n59bYW5mbOJdwz291xrbVg7On+rPc/hx88Yx6vaMbRI6BqLEw5mfnHtLWF6tzYvM2qcHvgo/V+bal+6H8EMRcDwY7LRUTQogLdCqgffXgq2TXZBPuEa4GtINuwkHbA2c9iH6v2WhRO1objJ2LtMoa26lo6AhgG9upbTF9537uzg6nF2edGbwa9AR5uhBs0OPl6ohGo8Gm2ChqLCKrJousajWEPVJ7hDaL+pzE1cGVJN+OELajIzbcI1xC2O5QtEudJVuTp44sGP8XOftH9EkSzAohhBDiwm15CiqyYMZ/waEX56F2sNkUciub2FNYy5hBvsQGerD1SAV3Ld+Hn7sTadE+jI72JS3ah/hAj17fOGyxWdhStIW3st4iuyYbH70PMxNmckf8HXjpvc7tQVpqoOgrNVQt3AGmZngkS52NtmuRepvosRCQ3DUYr8mHtbOh6oh62TtanRM76h7wi+neb1QIIexMURS2l2xn0cFFHKk9QoRHBPcNvY+J0RMloB0gFEWhvtX8rfmtbZ1h66mO1yaj5Tv39XFzIrAjbA3q6HI9FboGG/QEeurx0J/9DVFFUShpKlFD2I6P7JpsWswtAOh1ehJ8Ek6PJPBLJsozCq3G/m9mX1SMTbDlT7B3KRgiYPJLfe4sLiHOJMGsEEIIIS7MnqXwyeNwyZ0w6Z+9tkShxWhhzZ5i0gtr2Xu8lvpWMwC/m5TI3WMH0WqyUNbQziA/N7t1nrSaW3nv2HuszF5JaXMpUZ5RzEuex+RBk9E76C/swY3N4Oyufv7WJDW0BXDxgaAhMPgauOIX6ml8/5sDg69WRxRIGCuEGAAUReGLE1+wOGMxR2uPEukZyX2p93Fj9I0S0PZjVptCTcc817MFr6cuGy1d57lqNBDg4UyQwaVjpIC+S9drUEfoqnfUnVMdiqJQ3lLeGcAerj5Mdk02jaZGABy1jiT4JHTphh1kGCS/ez0tbzN8+AtoLIXR98E1vz/9XEmIPkqCWSGEEEKcvyMfqaFf3A1qt6zu/F9wWG0K6YU1xAd64OvetevWaLFyqKSB9MJa/N2duX1UOEaLldSnNhFs0JMW7UNatC+jo30I83ax+ymA1W3VrD6ymv/l/I9GUyPDA4azIHkB48LH9VxnTEOJ2kl7fIc6nzbpFrjqlz3ztYQQop9QFIXPT3zO4oOLyanLIcozivuG3seNUTei055bCCd6h8lio7KpY37rGUu01MvqqIHKJiOWbw10ddRpzuhydensbA0+I3z1d3e+oLFFla2VZFV37YStba8FwEHjQKx3bJflXDFeMThe6Dx4ce5aa+GzJyFzLfjFwc2vQMRoe1clxDmRYFYIIYQQ56c4HVbcDIEpMP9DcHI974eqbjbyi7UH+epYNU46LRNSgpiVFkHWyQa2HqnkQHFdZ/fLzUND+NfM4QDUtZjwdnPqlm+nOxQ0FLAiawUb8jdgsVm4NuJa5ifPZ1jAMHuXJoQQA5pNsfF58ecsylhEXl0e0YZo7k+9nwlREySg7QWtJsu3gtZvB6/tVDcbv3M/F0ddZ8DatcPVpfN6H1enbh1RVNNWczqArc4mqyaLqrYqALQaLYO9BncGsMm+ycT5xOGs6/0xTgJQFMh6Dz75JbTXwxWPwpWP22WslhDnS4JZIYQQQpyfk9+o3Qkz/gtufuf9MOkFNfx8zTfUt5pIi/altL6NmmYjje0WXJ10eOoduTYxgLGx/qRF++DTh4JYULuxDlQeYNnhZWwr2Yazzpkpg6cwL3kekZ6R9i5PCCHEGWyKja3FW1l0cBHH6o8xyDCI+4fez/jI8RLQngdFUWhst3R2tZ4eLdB11EBDm/k79zW4OJ61u1W97EKQQY+n3qFHz4JpMDZ0dsCe6ogtaykDQIOGaEN0l8Vc8T7xuDi49Fg94idoLIOPH4OcjyF4GEx5FYJS7F2VED+ZBLNCCCGE+GnMbeDY8aJEUc57pqzNprDkywL+sfEofu7OmK02GtrMDAv34q0Fo9hypJLV6UXsL67HSaflxiFqF21atI/dRxUAWG1WthZvZVnWMg5VH8LL2Utd6JVwBz562forhBB9mU2xsbloM69lvMax+mMMNgzm/mFqQCvLmFQ2m0Jtq+mMDtc2ys8IXk91vLaarN+5Mg41JwAAIABJREFUr5+78+lOV88zul3PuOzq1LvzVptMTRypOXJ6OVd1FiXNJZ3HIzwiOkPYJN8kEn0ScXeS+aR9jqLAgRWw6fdgNcLVv4VLH7ygcVpC2JMEs0IIIYQ4d8YmWDYJ4m6Eq58874epazHx6NsH+SKnikBPZyoajaSEevLsbakkhxi63DanvIk1e4pZf6CEpnYLMQHuzEyL4LYRoXi59n73bJuljfePvc+KrBWUNJcQ7hHO/KT53Bxzs3TRCCFEP2NTbGw6vonFGYspaCggxiuG+4fez/WR11/UAa3FaqOyyfit0QJtlDcauyzSMlu75gI6rYZAD+eOoNXlrMFrgIceJwf7/uxaza0crT3aJYQ93ni883ioe2iXxVyJPokYnA3f/4Cib6gtgA0Pq/P0I6+Am/8FvoPtXZUQF0SCWSGEEEKcG6sZVs+Agm0wcw3ETTivhzlQXMdDqw5Q3Wzi1zfE8/b+Em4bEcqdl0f/4GKONpOVjzJPsnpPMd8U1+PsoGXSkGBmjY5gZKR3j3fR1rTVsDZnLWuPrqXeWE+qfyoLkxdydfjVcvqrEEL0c1ablU1FakBb2FBIrHcsDwx9gGsjru13AW272do5QqDraIHTwWtVk5Fv7dDC2UH7rdECXRdpBRv0+Lo7o+vGea7dod3STk5dTucoguyabAoaCrAp6mz6ANcAkn2TSfFLIdlX7Yb11nvbuWrxk9issHsRfP4M6Bzh+j/DiPmg7V//bwpxNhLMCiGEEOLHKQp88DM4uAom/wtGzj+Ph1B4c+dx/vpxNs6OOpYvTGNUtA9Wm/KTX+QdKWtkdXox739TSpPRQlygO7PSIrh1eBgG1+7dgny84TgrstWFXiariXHh41iYspBh/sP6xEgFIYQQ3cdqs/LZ8c94LeM1jjceJ847jgeGPsA1Edf0iYC22Wjp7Ggta2inoqGdssauM11rW0zfuZ+Hs0PnAq2g7wlevVwd+/zfNbPVTG59bmcIm1WdxbH6Y1gVdZyCj96nM4A9FcL6u/rbuWpxQSqy4IOH4OQBiLsBJv0TDKH2rkqIbiPBrBBCCCF+3Bd/he3PwlW/hqt/85Pv3tBm5vG3D7L5SCUajfoCccVdoxkW7nVBZbWaLHyYcZLV6cVklDSgd9QyaUgIs0ZHMCLC64JeYB6sPMhbh9/iixNf4Kh15OaYm5mXNI9oQ/QF1SzExaix3cyxymbyK5tpM3933uSF6pGoqAcCqJ6KtHoiK9P0QLU9U2cPPOY5PKhNsZLV+CU7qtZQayolUD+IK/1mEedx6Vn/tnT3z9NktVF5qtv1jK7XZqPlO7f1cXM6I2w9c7SAS2cY6+7c/+Zvmm1mCuoLOgPYrJoscutyMdvURWIGZ0NnAHtqOVega2CfD5fFObIYYccL6ofeADc+Bym39cw/NELYkQSzQgghhPhxh96Bop1ql8JPfEJ8uLSBu5bvpaLRCMBNqcH8cXIy/h7O3Vri4dIG1uxRu2hbTFYSgjyYNTqCW4aH4qk/ty5aq83KthPbeCvrLTKqMjA4G5gRP4OZCTPxc/Hr1nqF6I/qWkzkVTaTV9lEXkUzxyrVj/LGdnuXJkQPseLgmYGz/1a0TjVY20MwVl2HtTmRnoviVVoNBHjoCTToCT7LAq1ggwsBns7oHfv/OB2rzUphQ+HpmbA1WeTU5mC0qs8dPBw9SPJNIskvqTOMDXUPlRD2YnViL2x4CKqOwpDb4Ya/g5uvvasSokdIMCuEEEKI79feoHYpnAdFUViVXsyfNmSpXbJ6R569LZXrkgK7uciuWowWNnR00R4qbcDFUcfkocHMGh3J0DDDWV/EtVva2ZC/gRXZKyhqLCLUPZR5SfO4JeYWXB1de7ReIfoaRVGoajKS1xG6nhnC1pxxirSrk46YAPfOj9gAD2IC3PHUd29nXk+8KumplzpKT1TbPx6yR36mPfHzPN86rTYLn5dsZFXuG5S1lBBjSGBuwj2MDryiRzr4HLQafN2cfnD2en9lU2wUNxZ3Wcx1pPYIbZY2AFwcXE4v5urohg33CO8ToyREDzO1wOd/gd2LwTMEbnrxvHcaCNFfSDArhBBCiLMr2Q//nQpTl0Lc+J9012ajhbuX7WV3YS3j4v355YR4Inxc8TjHztXucqikgdV7ivjg4ElaTVaSgj2ZOTqCW4aF4KF3pK69jrVH17Lm6BrqjHWk+KawIGUB10VcJwu9xEVPURRONrSTV9HU2fmaV9lMXkUTje2nT5f20DsQ2xG8xga6MzjAndgAd0IMLmj72BIgIXqaxWbho4KPWJKxhJLmEpJ9k3lw2IOMDR0r3ZtnoSgKJc0l6lKu6uzO5VzN5mYAnHXOJPgkdAawKb4pRHpGyt/ggSj/C/jwYagvhkvuguueAr2nvasSosdJMCuEEEKI76rJhzfGg5Mb3L0F3APO+a67C6q5e/k+mo1WLon05u37xtg9vGlqN/PBQbWLNrusEVfXOgbF7KfMugOTzchVYVcxP3k+lwReIi+sxUXHalMoqWslr6K5cwxBfkcQ22I6PQ/W183pjO5Xd2IDPYgNcMffw1n+vxDiW8w2Mx/lf8SSzCWUNpeS4pvCA8MeGNABraIoVLRWcLj6cJe5sI2mRgActY7Ee8d3zoNN8k1isNdgHLT9b/6t6EZtdbDpd/DNf8FnMNz8b4i63N5VCdFrJJgVQgghRFct1fDG9dBWD3dtBr+Yc7qboig8+e4h1u49AcBtI0L569QhODv0na6XjMoMXt63lL1VX6IoWswNw4l0uJEFl1zKzcNC+uVyFCFOMVttFNW0dI4dyOv4KKhqxmixdd4u0NO5c+zAqRA2JsAdX/funfssxEBgtpn5MP9DXs98ndLmUlL9Unlg2ANcHnL5RR/QVrVWdRlHkFWTRW17LQAOGgdivGO6LOaK9YrFUde7Z86IPi57A3zyuPrc87Kfw7gnwNHF3lUJ0askmBVCCCHEaeY2WDYJKrJg/ocQnnZOd2szWZm+5GsOlzbi4ezA0vmXcOmgvrGkwabY2H5iO8uylnGg8gAeTh7MiJ/B5Ojp7DxqYlV6MUfLm3Bz0jFleCiz0iJICT2/ubpC9IZ2s5XC6hZ1BmxFE8eqmsmraKawugWL7fRz+FAvF2ID3TvHEAzuCGANLhKMCNHdzFYzH+R/wOuZr1PWUkaqfyoPDn2Qy0IuuygC2tr22s7w9dRYgsq2SgC0Gi2DDIO6hLBx3nHoHfR2rlr0WU0VaiB7ZAMEDYGbX4GQYfauSgi7kGBWCCGEEKfZbLD1T2ogmzDpR29utSkcKK7jd+8dJqeiiXHx/rw+9xKcHOy/oMNoNfJh/ocsz1rO8cbjhLiFMDdpLlNjp3ZZ6KUoCt+cqGd1ejEfZZ6k3WwjNczArLQIJg8NwU26aIWdtJos5Fe2qMu3KtXwNb+qmaKaFk7lr1oNRPq6MdjfvUsIO8jfTX53hbADs9XM+/nvszRzKWUtZQzzH8YDwx5gTPCYfhPQNhgbyK7J7tIJW9ZSBoAGDVGGqC6LueK942VRpjg3igIZa+CzJ9VmgHG/hsseBumkFgOYBLNCCCGEUJ8ot9aAm9853+VIWSP3rdxPSV0rXi6OvDxzOGNj/XuwyHPTYGzgfzn/Y/WR1dS015Dok8jClIVcH3n9j86xa2gz8/43paxOLyanogl3ZwduGR7CrLRIkkJkAYXoGQ1tZo5VNpPfMf/1VAhbWt/WeRsHrYZoPzdiA92J6RhDEBvgTrSfG3rHvjMuRAihMllNvH/sfV7PfJ2K1gqGBwznwWEPMjpodJ8KaJtNzRypPdKlG/ZE04nO4+Ee4V1C2ESfRNyd3O1Ysei36orgo19A/ucQfqk6S9Y/zt5VCWF3EswKIYQQArb/A/a8Dvd+AYawH7xpu9nKS1tyWfJlAYoCg/zcWHX3aIK97DsTrKSphJXZK3nv2Hu0Wdq4IvQKFiYvZFTQqJ/8IlhR1E7gVenFfJxZhtFiY1i4F7NGR3BTajCuTtKJKH662hYTeRVq8Hqs8tQc2CYqGo2dt3Fy0Krdr50LuNTxA5G+bjjq7N+JLoT4aUxWE+/lvcfSQ0upaK1gRMAIHhz2IGlBab0e0LaaW8mpy+mynOt44/HO4yFuIST7qUu5Ti3nMjjLaB9xgWxW2Psf2PIn0GjguqfgkrtAK3/ThAAJZoUQQgjxzSr44EFInQG3LlGfNH+PwuoW5r2Rzok6tZNvwWVR/G5SIg52DIyyqrN4K+stNhdtRqvRMil6EvOT5xPrHdstj1/fauLdA6Ws3lPMscpmPPQO3Do8lFmjI0gIki5a0ZWiKFQ1GTu6Xps6F3DlVzZT02LqvJ2rk47YAHcGd4weOBXChnm7otP2nW46IUT3MFlNrM9bz38y/0NlWyUjA0fys2E/Y1TQqB75ekarkZzanC7jCAoaCrAp6iLAAJcAkvySunTD+uh9eqQWMYBV5cAHD0HJHoi5Dm56Ebwi7F2VEH2KBLNCCCHEQHZsK6y+HaKugFnrwMHpB2++4WApj/wvA2dHLf+eOZxrEwN7qdCubIqNr0q/YlnWMvaW78XD0YPp8dOZlTCLQLeeqUlRFPYer2N1ehGfHC7HZLExIsKLWaMjuSk1WE4nH2BsNoWTDW2doWtexekxBE3tls7beeodiA1Ug9eYjo/YQA+CPfVoJYAVYsAxWo28k/sObxx6g6q2KkYFjeKBoQ9cUEBrtprJrc8lqzqrczbssbpjWBT13yIfvU+XxVxJvkkEuAZ017ckxHdZzfDVS/Dlc+DkBjf8XW0A6ENjPIToKySYFUIIIQaqiix4Yzx4R8PCT0B/9u7PTVnlvP5lAYnBHqzcXczQcC9enTWcMO/eX/Rhspr4uOBjlmctJ78hnyC3IOYkzuG22Nt6deZdXYuJ9QdKWL2nmIKqFjz1DkwdEcbs0RHEBnr0Wh2i51ltCidqWzs6X5s6RxAcq2ym1WTtvJ2vm1NH6Kp2wJ6aAevv4dyn5kkKIfqGdku7GtAefoPqtmrSgtJ4cNiDjAwc+YP3s9gs5Nfnd1nOlVOXg9lmBsDTyZNk32RS/FI6w9hA10D5d0j0ntIDsOHnUHEYkm+FG58Dd3kjQIjvI8GsEEIIMVAZm2HjkzDuN+AZ/J3DlY3tPPVhFp8cKkfvqKXdbOPOy6N54sYEnBx6d3RBg7GBdbnrWH1kNVVtVcR7x7MgZQEToibgqLXfJl9FUUgvrGV1ejGfHS7HZLUxKsqbmWkRTBwiXbT9idlqo6impaPztblzDmx+VTMmi63zdkGeemID3dU5sGeEsD5uP9xtLoQQZ9NuaWdd7jreOPQGNe01jA4ezc+G/YzhAcOx2qwcbzzeZRzB0dqjGK3qXGp3R/fT82A7xhKEuYdJCCvsw9QK2/4Gu14BtwCY9AIk3mTvqoTo8ySYFUIIIQaa1lrQOYHz2TtMFUXh7X0neObjI7SarOi0Ghy1Gv4xfSg3DvlugNuTTjafZGX2St7Ne5dWSyuXhVzG/OT5jAke0+deeNa2mHhn/wnW7DlBYXULBhdHbhsRxqzREcQEyAbrvqLdbKWgqoW8yiZ1BEHHx/HqFiy2089/w7xdOua+ehDj705MxxIuT7393ggQQly82ixtvJ3zNm8efpPa9loGGwZzsuUkbRZ1pruLgwuJPomd4wiSfZOJ8IxAq5EFSqIPKNwBHz4MtQUwYj5c/2dw8bJ3VUL0CxLMCiGEED2guLGYnSd3okGDTqvDQeOATqtDp9Gdvnzm5x3HHLTq9Vqt9ntv87330Wh/PKw0t8GKKYAGFn561o24NpvC7Ut2UVrfRllDO8khniyaPYJIX7ee+WGdxZGaI7yV9Rabjm9Cg4Ybo29kfvJ84n3ie62G86UoCrsKalidXszGrHLMVoW0aB9mj47ghpQgnB2ki7Y3tBgt5Fc1d3bAquMHmiiubeVU/qrVQKSvW+fYgdhAd2L8PRgc4Iark4N9vwEhxIB0KqDdUbKDQV6DOkPYaEM0Oq38/RB9THsDbP4j7H8LvKNg8r9g0FX2rkqIfkWCWSGEEKIb1bTV8FrGa7yT+07n0o3e1CXA7QhxT4e5OnQtNehMLTh4hqJz8e4MebUaLVVNZgI9XNFqdGSfbKGhzUKEtzvJId446Rw6A+BvB8LfFxR/9+v/eDjdZGpibc5a0svScXN0Y1rsNOYkzSHILajXf5bdobrZyDv7S1izp5iimla8XR2ZNjKMO9IiGOwvXbTdoaHN3Bm6nhnClta3dd7GUach2s+tY/mWR2cIG+XrJuMmhBBCiPOR8yl89Cg0l8OlD8LVvwWn3t8/IER/J8GsEEII0Q1aza0sz17OssPLMFqNTI2dysKUhbg4uGC1WbEqHR8dn1tslu+9bFEsp68/4/NTx2w2W9frf+j+Z36dE3uw1ORiDUrF5juo81h9azs5FQ20mEz4ezpS39qOTbER7OWEu177vY99Zv2njnWHANcA5iTOYVrcNDycLo5FWjabwtf5NazeU8SmrAosNoVLB/kwa3QkE5IDpYv2HNQ0G8/ofFUXceVVNFPZZOy8jbOD9ozZr+6dQWykryuOOjndVwghhLhgLdXw6a/h8DsQkAQ3vwJhP7y0Tgjx/b4vmJVzt4QQQohzYLFZeDfvXRZnLKa6rZprI67l/0b8H9GGaHuX1lX6Eji6C8Y8BBOeAaDNZOXFLbl8+nUBfu7OXB3lw6eHy4gNcGfR7JE/eS6qoijYFNsPBsU/FO5abBY0Gg2pfqk46i6uOZ5arYYrYv24ItaPyqZ21u0rYe3eYh5e8w0+bk5M7+iijfbrvXERfZGiKFQ2GTs6X5u6BLG1LabO27k56YgJcGdsrH+XEDbM2xWdtm/NHhZCCCEuCooCh9apoayxCcY9CVc8Cg6y/FKIniAds0IIIcQPUBSFz098zssHXqawoZDhAcN5dOSjDAsYZu/Szq7uOOx7E659qnOu7J8+zOKtnce5dXgoJ+tbSS+sY9rIMJ6ekoKLk3Rw9jSbTeH/2bvv8Cbve/3jb00P2Zb3xAa8AZsRdnZSyIDsHTKaSTNoMzpOT3q6e5r+Tlea3ZLRLLJXk7ASshMICWEYA55gvJds2ZKt/fz+eGTZBjNjWx6f13X5siI9lr4ixki3P8/9/by8hVVf7ef93Y14fQonZsWxbH4GZ01NxqgfuxOePp9CnbVbDV4PCGE7Hb2T11GhenKTIv2Tr+pGXDmJEaSYQ0fc5m9CCCHEmGWtgXfvgbL1kDYHLnwYEqcEe1VCjAlSZSCEEEIco21N2/jrN39lW/M2JkVN4u7Zd3Nm+pkjMyhqrYCYyYEwtr3LRafDQ3psOC02J//ZVsdjn1TQ6XDzuwsLuGJOepAXPD41dTh45ZtqXtxcTW17N/ERRi6bnc7V89KHddO1web1Key3dFHW2El5c08Iqwaw3W5v4Lj4CKN/A67IwEZc2UkRJESEjMy/V0IIIcR44PPBlqfg/d+A4oUzfwnzfwCyEZ0Qg0aCWSGEEOIo7bXu5R/f/oMN+zcQHxbP7TNu55KcS9BrR2gDUGMxPHUOzFuOcub/8O6Oen77TjGZ8RG8eOt8Hv+0kr+uL2FSvIlHrzmB/OSoYK943PP6FD4ta2bVV/v5cE8TXp/CKTnxLJuXwaKpSSO2J9Xl8VHVaqesyUZZo43yZhtljZ1UtthxeXyB45KjQslJigiEsDlJEWQnRBBjktMghRBCiBGlpRz+80PY/yVMPg3O/wfEjrCqLiHGgOMOZjUazVPAeUCToigFA9z+U+Aa/3/qgSlAgqIoFo1Gsw/oBLyAZ6AFDESCWSGEEMHQ0t3Co9se5Y2yNwjRhXBjwY1cP/V6wg0jeOdZaw08sRiAxivf4Rcb2vhgdxPTJ5i579wpPP5pBR+XNHP+jFTuv6SQiJARGi6PYw1WBy9/Xc3LX++nzuogPiKEK+ZM4Op5GaTHBud7z+H2Utlsp6ypU92Ay19DUNXahcfX+9oxPTYsMP3aMwGblRhBVOjY6g4WQgghxhyvBzY+BB/dD/pQdW+CWdeCnMEixJD4LsHsqYANeHagYPaAY88H7lEU5Uz/f+8D5iiK0nIsi5VgVgghxHCyu+38u/jfPFP8DG6vm8tyL+O2GbcRFxYX7KUdXne7OinbUUvxOa9w5VsdeH0KPz4rl+kTorn7pa202Fz88vypXDs/Q04VH+G8PoVPSpsCU7QKcEpOAsvmZfC9KYlDMkVrd3oCm26p1QNqB+x+Sxc9LxG1GpgUZ/J3v/ZOwWYmmAg3StAvhBBCjCo+L+x5Fz79MzQUQf55sOQvEJUS7JUJMaYdKpg94qtpRVE+1Wg0k47yca4GXjy2pQkhhBDB4fa5eb30dR7b/hgWh4WzJp7Fj074EROjJgZ7aUemKPDajSit5WiufZ2s9HksrSjmzjOyWL+rkWUrN5ESHcrrt59I4QRzsFcrjoJOq+HM/CTOzE+irr3bP0VbzW3PbyExMoQr56Zz5dx0JsQc+xSttctNeXOnf/LVFghja9u7A8cYdBomx5soSDVz0cw0cpLUAHZSfDgheumYE0IIIUY1ZydsfR42PQbtVRA9ES7/N0y9SKZkhQiio+qY9Qez7x5uYlaj0YQDNUC2oigW/3V7gTZAAf6pKMq/jmZRMjErhBBiKCmKwvtV7/Pg1gep6qhidtJs7p19L9MTpgd7aUfN6fHyzlsvs7u8gp/cex9hRh3Wbjc/fXU763c1ctbUJP58+QzMYXJK+Wjm8fr4qKSZVV9V8XFpMwCn5yZw9bwMzsxPRN9nilZRFFrtrt7p10Z1+rWsyUZzpzNwXIhe2696INvfAZsRGz5iu22FEEIIcZzaq2HzP2HLs+C0QvoCWHgn5C+Vzb2EGEbHPTF7DM4HvugJZf1OUhSlTqPRJALvazSaPYqifHqIBS4HlgNkZGQM4rKEEEKIXlsat/C3LX9jR/MOssxZPHzmw5w64dRRdZr/rm8/466PfZQ1mblw5nm4PD7Km2zcsWoL9e0OfnneVG46adKoek5iYHqdlsVTk1g8NYmatq7AFO3y57aQHBXKBTNTsTk9lPs7YNu63IGvNRl1ZCdFclpugj+AVSdg02LC0Gnle0MIIYQY02q/hY2PQPGb6n9PvVANZCcc1dY/QohhMpgTs28CryqKsuoQt/8GsCmK8pcjPZ5MzAohhBhsFe0VPLDlAT6u+ZjEsETunHUnF2RdgF47ejoyHW4vHz3zW86t+Qc/Mvyaiy+9ltPzEnh+UxW/f3c38RFGHr7mBE7IiAn2UsUQ8nh9bNijdtF+WtZMVKiBnED/a2QghE0xh0o4L4QQQownPi+UrFED2f1fgjESZn8f5v8AomUATohgGtKJWY1GYwZOA67tc50J0CqK0um/fBbwu8F4PCGEEOJoNXU18ei2R3mz/E3C9eHcdcJdXDPlGsL0YcFe2jEzlrzD2TUPsiv6dP74gztBq+OHL27l3R31nJGXwN+umEmMyRjsZYohptdpOXtaMmdPS8bh9hKi10oAK4QQQoxnThtsWwWbHoW2vWDOgLP/CLOug9CoYK9OCHEYRwxmNRrNi8DpQLxGo6kBfg0YABRFedx/2MXAekVR7H2+NAl40/9GQQ+sUhRl7eAtXQghhDg0m8vGUzuf4rldz+FRPCzLX8by6cuJCR1d06QtNid/XlvCz6daiHlzOb70eUy9/iV2tzi584Vv2ddq52fn5HHbqVlo5fT0cSfUIN1wQgghxLhlrYXN/4ItT4PDCmlzYNGvIf980I2es8KEGM+OqspguI2nKgOLw8K2pm3kROeQFpmGViObbgghxHfh9rp5pfQV/rn9n7Q52zh38rn8cNYPSY9MD/bSjomiKLzxbS2/f28XIU4Ln4f/FENUEspN63h1Vxe/fHsn5jADD149iwWZccFerhBCCCGEGC5129Tp2J2vg+KDKefDwhWQPi/YKxNCHMJwbP4ljsOWxi3c+/G9AITpw5hsnkx2dHbgIycmh6TwJDlFUQghjkBRFNbtW8c/vv0HNbYa5ifP557Z9zAtflqwl3bMqi1d3PdmEZ+VtTB7Ygx/umQhhhon3ekn8z/v1fD6tzWclB3HA1fOIiEyJNjLFUIIIYQQQ83ng7J1an/svs/AGAHzlqv9sTGTgr06IcRxkonZIOv2dFPWVkZFewVl7WWUt5VT0V5BU3dT4JgIQwRZ0Vm9gW2M+jkuNE4CWyGEADbXb+ZvW/5GcWsxOTE53Dv7Xk5KPWnU/oz8yavbWVNUzy8XT+CKLC/a1BmUN3VyxwvfUtZk40dn5vCj7+Wgk+oCIYQQQoixzdUF21fBxkfBUgFRE9Qwdvb3IdQc7NUJIY7SoSZmJZgdoaxOK+Xtakhb1lZGhVX93O5sDxwTHRJNdnQ2WdFZ5ETnqJ9jcjCHyA9nIcT4UNpWygNbHuCz2s9INiWzYuYKzss8D5129PVu7qrrwKjXkJ0YicXuwuHoJvW966FuK++duY6fvltFmEHHA1fN5JSchGAvVwghhBBCDKWOevh6JXzzFHS3Qeosta5g6oWgMwR7dUKIYyTB7BigKAqtjlYq2isoby8PTNqWt5djc9sCxyWEJQQmbHNicgKXTQZTEFcvhBCDp8HewMNbH+Y/Ff8hwhDBrdNv5er8qwnVhwZ7acfM4fby4IYy/vlpJWfkJfDE9+eCosCbt8GOl3h1wi/4afk05k2K5aFls0iKGn3PUQghhBBCHKWGIrWuoOg18Hkgf6kayGYsgFF6NpgQQjpmxwSNRkN8WDzxYfHMT5kfuF5RFBq7GilvL6e8rVz93F7Oa6Wv4fA6AselmlLVkDamt8M205w5KoMMIcT41OHq4MmiJ3lh9wv4FB/XT72eW6ffOmrPFNhU2cp/v1HE3hY7l8+ewC+WTlFv2PA72PESz4Rex6/Lp3H76Vn8eHEuep1sECmEEEIIMeb4fFAICQUeAAAgAElEQVT+AWx8GPZ+AgYTzLkJFtwGsZnBXp0QYghJMDsGaDQakk3JJJuSOTnt5MD1PsVHra22X1hb3l7OpvpNuH1u9WvRkB6Z3luJEJNDdnQ2k6ImYZDTI4QQI4TL6+KlPS/xr6J/0eHsYGnmUlbMWkFaRFqwl3bc1u5s4Lbnt5ARG87zN8/n5Jx49YaSNfD533hFWcTfHOfz1A0zOTM/KbiLFUIIIYQQg8/dDdtfgk2PQkspRKbCot+q/bFhMcFenRBiGEgwO4ZpNVrSI9NJj0znjIwzAtd7fB72d+4PbDRW1l5GeXs5n9R8glfxAqDX6JkYNZHsmP4dtumR6ei18m0jhBgePsXH6r2reXjrw9Taajkx9UTumX0P+bH5wV7acbPYXcSajJyel8BPz87jppMmE2ZUO3FdHh9/2pOK230DxSmX8t41c5gQEx7kFQshhBBCiEHV2QhfPwHfPAldrZAyAy5ZCdMulv5YIcYZ6ZgVAS6vi73Wvb2bjrWrHbY1nTUoqN8nRq2RzOjM3g5bf2CbGpGKViOn2AohBs/Guo38fcvf2W3ZTX5sPvfMvocTU08M9rKOW2OHg1+9vZPiug7W33Mq4cb+v+Rq2LOR/3q/hU9qtdx00mR+fm4+Rr38XBVCCCGEGDMai2Hjo1D0CnjdkHcuLLwTJp4k/bFCjHHSMSuOyKgzkhebR15sXr/ru9xdgcC2vL2csvYytjRu4b3K9wLHhOnDyDL376/Njs4mMTwRjfwDI4Q4BiWWEv6+5e98UfcFqaZU7j/lfpZMXjJqf/nj8ym89HU196/Zjcvj4+5FuRgO6Ir98qtNTFlzOT9QJnP1ta9wTkFKkFYrhBBCCCEGlaJAxQb48mGo/Aj0YXDC9TD/dojPDvbqhBBBJhOz4rh1ujqpaK/o119b3lZOq6M1cEykMTIQ0vZUImTHZBMbGhvElQshRqI6Wx0Pb32YdyvfJdIYyfLpy7k6/2qMOmOwl3bcrF1ulj/3DV/ttbAwM477LylkUrwpcLvb6+Oxd77kom9vIErnwn7dGtIypwVxxUIIIYQQYlC4Hepk7MZHoHkPRCTD/OUw+0YIl/fDQow3MjErBl2kMZKZiTOZmTiz3/VtjrZAUFvRXkFZWxnr9q2jw9UROCY2NDYQ1mZHZ5MTo1YiRBmjhvtpCCGCzOq08kTRE6zavQqNRsONBTdyc+HNY+LnQWSonqgwA/936XQunzOh3xkE9dZufvr8l/ys8cck6ztRvv8u0RMllBVCCCGEGNVszWp37OaV0NUCSYVw8T9h2iWgH70DB0KIoSETs2JYKIpCS3eLutFYWzkV1grK29TwtsvTFTguMTwx0Fvbd9I23CCb3wgx1ji9TlbtXsXKopXYXDYuyLqAFbNWkGxKDvbSvpPt1e38cfVuHrp6FolRoQMe82lpM3e/vI0fe1ayTPM+mmUvQe7Zw7xSIYQQQggxaJr2wKZHYPvL4HVCztlqf+zkU6U/VgghE7MiuDQaDQnhCSSEJ/TbvEdRFOrt9f2qEMrby/mm5BucXmfguLSItN7uWn+P7WTzZEJ0IcF4OkKI78Dtc/Nuxbs8tv0x6u31nJx2MnefcPdB/dajTZfLw9/Wl/LUF3tJiAyhpr37oGDW61P4xwelPPRRObmJkSy89K9ounaMv1C2vRo2/wvisiHnLIiSTl0hhBBCjEKKovbGbnwEyj8AfSjMXAYL7oCE3GCvTggxCsjErBiRvD4vNbaafmFteXs5+6z78CgeALQaLRmRGYGwtqfDNiMqA4PWEORnIIQ4kMPj4I2yN/h38b+pt9czLW4a986+l3kp84K9tO/s09Jm7nuziJq2bq6Zn8F/nZtPVGj/n0NNnQ7uenEbGytb+V1OOZdffSth4ePsbABFga3Pwdr7wGUD/K9BUmZC7jlqQJ0yE7Sjc6M3IYQQQowTHicUvaYGsk3FYEqEecthzk1gigv26oQQI9ChJmYlmBWjitvrpqqjinJrb2Bb0V7B/s79+BQfAHqtnklRk3orEWKyyYnOIS0iDZ1WF+RnIMT40+nq5OWSl3lu13NYHBZOSDyBWwpv4eS0k/t1ro5md7+0lR21Vv50yXTmTT54M4eNFa386KWtdDrcPDdrD3N3/AYW/w5Oumv4Fxss1lp450fqNMmkU+CCh8DdBaVroXQ91GwGxae+sck9Sw1qM0+HkMhgr1wIIYQQQmVvhW+eUs/8sTdB4jS1rqDwMtDL2ZxCiEOTYFaMaQ6Pg30d+yhrKwuEteXt5dTaagPHhOhCyDRn9qtDyI7OJsWUMmbCISFGEovDwvO7nufFPS9ic9s4Ke0kbi28ldlJs4O9tEHxeVkLSVEh5CRF0uFwY9RpCTX0/+WPz6fw2CcV/HV9CZPiTTxzcjvpa29SA8dlL4NuHEz3KwpsfxHW/Bx8blj0W5h7y8FTsfZWNbQtXQvlG8BpBZ0RJp2sdrTlng2xk4PzHIQQQggxvjWXwqZH1dc0HgdkL1YD2czTpT9WCHFUJJgV41KXuysQ0vbtsW3qbgocYzKYDtpsLCc6h/iweAlshTgODfYGnil+htdKX8PpdbJo4iJuKbyFqXFTg720QWF3erh/zW6e37Sf82ek8tDVswY8zmJ3cc/L2/iktJkLZqTypwVuwlddCPE5cMNqCIkY5pUHQWcDvHOXGrZmLIQLH4G4rCN/ndcN+zdB2TooXQctper18XlqQJt7DqTPB51U5QshhBBiiCgK7P1UrSsoWwe6EJhxpdofmzgl2KsTQowyEswK0YfVaR0wsG1ztgWOMYeYyTJnkROT0y+wjQ6NDuLKhRi59ln38dTOp3in8h1QYGnmUm4qvIlMc2awlzZoNu+18JNXt1Pd1sXNJ03mJ2fnHTQlC7ClysKKVVtptbn49QVTWTYnDc2j88Hrgps/gMikIKx+GCkKFL0Kq3+qTpV879cw/7bj745trYCy9WrAu+8LdfI21AzZi9SQNnsRhB9cISGEEEIIccw8Ltj5uhrINhZBeDzMuxXm3AwRCcFenRBilJJgVoij0NrdelBYW9FeQae7M3BMXGhcoLe276RthHEcTL8JMYA9lj08UfQE6/etx6gzcmnOpdww7QZSIlKCvbRBtWF3I7c8+w3pMeH85fIZA3bJKorCE5/t5f+t3UNqdBiPXnMCBWlm9cam3aDVqxOzY5mtCd69B/a8CxPmwUWPQXz24N2/sxMqPlInacvWgb0ZNFp1gjb3bLX2IHGKnFYohBBCiGPTZYEtT8NX/wJbAyTk+/tjrwBDaLBXJ4QY5SSYFeI4KYpCY1djYMK2rK2MivYKKqwVdHu6A8clm5IDIW1Pj22mOZMwfVgQVy/E0NnatJWVO1byWe1nRBgiuCr/Kq6Zcg3xYfHBXtqgcri9hBp0ONxeHv+kgltPycQUcvAp9NYuNz95bTvv72rknGnJ/N/l04nSuqH4TZi5bOwHhYoCxW/Aez8Blx3O/B/1zcxQbrro80HdVnWStmwd1G9Xrzdn9FYeTDpZ3kwJIYQQ4tBayuGrx2DbKnVj0swz4MQVkPW9sf/6TQgxbCSYFWKQ+RQftbba/pUIbeVUWitx+9wAaNAwIXJCoAahpxJhsnkyRp0xyM9AiGOnKApf1n3JyqKVbGncQkxIDNdNvY4r868kyhgV7OUNKqfHy4MbynhvRz3v/ugUIgYIY3vsqGnnjhe+pcHq4L4lU7jxpElofF54+Vo1NPzBp5AyfRhXP8zsLfDevbDrbUibrU7JJuQN/zo66vyVB+ug8mP1zZUhXH2DlXuWOk0bNbYmuYUQQghxHBQFqr5Q6wpK1qgbshZeAQvvgKRpwV6dEGIMkmBWiGHi8Xmo7qwOBLU9oW1VRxVexQuATqNjYtTEQGCbFZ1Fdkw2GZEZ6LWymY0YeXyKjw37N7Byx0p2W3aTFJ7EjQU3cknOJWNyKnxXXQf3vrKNPQ2dXDZ7Ar8+fyqRoYaDjtvbYuf5TVU8t7GKhMgQHlo2ixMyYtQX++/eo54Ot+Qvai/ZWLXrbXj3XnB2wBn3wcIfjoxNudwO2Pe5GoyXrgVrtXp9ygx1kjb3bEiZdfy9t0IIIYQYfbz+s5k2PqyeaRMWC3NvUT/G+h4AQoigkmBWiCBzeV3s69jXL6wtby+nprMGBfXvoUFrYLJ5MtnR2eTE5JBlVgPbtIg0tBoJD8Twc/vcrK5czZM7n2SvdS8ToyZyc8HNnJd5HgbdwUHlaOfx+njs4woe/LAMc5iRP11SyKKp/V+ke30KH+5p4tmN+/isrAW9VsP5M1L51XlTiTH5J+E//TN8+Ac46W5Y/NvhfyLDocsCq3+ibo6RMhMufnzk7lCsKGrHb+ladZq2ZjMoPjAlQs5ZakibdQaERAZ7pUIIIYQYCt1tsOUZ+Oqf0FkH8bmw4A6YcRUYxt6QgRBi5JFgVogRqtvTTaW1MrDRWFl7GeXt5TTYGwLHhOnDyDRn9uuvzY7OJik8CY30Hokh4PA4eLP8TZ7e+TT19nryYvK4ZfotLM5YjG4oO0ODzOdTuHrlJhKjQvndBdN6g1agudPJK99Us+qr/dS2d5McFcqy+RlcNTedxKg+HaatFfDwXCi4FC7+59icyNzzHrxzt/om5/T/UgPo0RTUd1mg/AM1qC3/ABxW0BrUPtrcc9Tag9jMYK9SCCGEEN+VpRI2PQ5bnwe3HSafBgtXQPaisfkaTQgxYkkwK8Qo0+nqVDcZ69l0rF3ddKyluyVwTIQhItBbmxPT22EbFxonga04LjaXjZdLXubZXc9icViYmTCTW6ffyilpp4zZ7ymvT+GZL/exdHoKSVGhdLu8hBnV8FlRFLZUtfHcpipWF9Xj9iqclB3HdQsmsWhKInrdIV7Q7/sCJswF/Rjrku6ywNqfw46XIbkQLnockguCvarvxuuG6q/807TroaVEvT4+t3cDsfT5oyt4FkIIIcYzRYH9m9S6gj3vgVYPhZepm5ImFwZ7dUKIcUqCWSHGiDZHG+Xt5f02HStrK6PD1RE4JiYkRu2t7alE8F82h5iDuHIxklkcFl7Y/QIv7n6RTncnJ6WexC2FtzA7afaYDWQB9rXY+cmr2/mmqo2fnp3HnWdkA2B3enhrWy3PbaxiT0MnkaF6Lps9gWvmTyQ7MWLgO6vfDp2N6rTlWFSyFt65C7pa4NSfwik/HpthpaVSDWhL16odtT43hJrVyZqcsyFnMYTHBnuVQgghhDiQ16123298BOq+hbAYmHMTzL1VNv8UQgSdBLNCjGGKotDS3dKvu7Zn87EuT1fguMSwxMBGY4FahOhswg3hQVy9CKYGewPPFD/Da6Wv4fQ6WTRxETcX3sy0uLG9G63Pp/D8V1Xcv3oPep2G314wjYtnpVHRbOP5Tft5fUsNnU4PU1KiuH7hRC6cmUq48TAbWrVVwZOLQR8KK74GfcjwPZmh1t0O6+6DbS9A4jS4+DF1A63xwNkJlR/3TtPam0CjhQnzeqdpE6fAGP7lhRBCCDHiOay9/bEdNRCbBQvvgBlXg9EU7NUJIQQgwawQ45KiKDTYGwK9tRXtFZS1lVFprcTpdQKgQUNGVAa5MbnkxeSRF5tHXkweyabkMT0pOd5VdVTx9M6nebvibRRFYWnmUm4uuJnM6PHRq/nIR+X8eV0Jp+Ym8IeLplFc28Fzm6r4sqIVo07LksJkrls4iRMyoo/896DLAk+dDbZGuGndyN0A63iUfQD/+aH63E65F0792dirZzhaPh/Ub1U3DytdB/Xb1OvNGeqUdO45MOkUMIQe/n6EEEIIMTja9qlh7LfPgsum/ju88E71DBfpjxVCjDASzAohArw+L7W2WsrayihtK6WkrYQSSwk1tprAMVHGKDWs9Qe1ubG5ZEdnE6IbQ5OA41CJpYQnip5gfdV6DFoDF2dfzA0FN5AWkRbspQ05RVHo6PZgDjfQ3uVSp2IdHl78ej+NHU7SosO4ZkEGV8xJJz7iKL/PO+rh+UuhtQyuewsmnTS0T2K4OKyw7hew9TlIyIeLHoO0E4K9qpGlox7K1qshbeVH4O4CQzhknq5O0+acBVGpwV6lEEIIMfZUb1b7Y3e/o57JUnApLLgDUmcGe2VCCHFIEswKIY7I5rJR1l5GiaWEkrYSSi2llLWX0e3pBkCn0TEpapIa1voD27zYPOLD4oO8cnEk25q2sbJoJZ/WfIrJYOKqvKu4duq14+b/XWOHg5+/voMWm5P/OjefF7+qZl1xAx6fwmm5CVy3YCJn5Cei0x7jlPjnD8Cnf4Yrn4esM4Zm8cOt4kN4+4fQWQcn3QWn/VymQI/E7VD7aMvWqV281v3q9cnT1Una3HMgdZZM7wghhBDHy+uBPe+o/bE1X6v977NvhPk/kF+ECiFGBQlmhRDHxevzUt1ZHZiq7fnc2NUYOCY2NDYQ0vZM2U42T8agHYMbA40iiqKwsW4jK4tW8k3jN0SHRHPd1Ou4Kv8qooxRwV7esFAUhf9sr+OXb+2k2+0lOsxAs82FOczAFXPUzbwmxR9H95jXAzq9uutv2z6InTzoax92zk5Y/0vY8jTE56pTshMOet0gjkRRoHmPv5d2HVR/BYoPTAnqqZW5Z0HmGRA6Pv4OCiGEEN+Jo0M9g2fT4+ovPmMmq9OxM5dByCE2ZBVCiBFIglkhxKCyOq39gtrStlLK28tx+9wAGLQGsqKzDuqujQ6NDvLKxz6f4uPD/R+ysmglu1p3kRieyI3TbuSSnEvG1UZv1m43d7zwLV+Ut6DVgE+BwjQz1y2cyAUzUgk16I7vjvd+Bu/cBde8CnFZg7voYKn8BN5eAdZqOHEFnPELMIQFe1VjQ5cFyjeoQW35+2pNhNag1l7knqPWHsSOj25nIYQQ4qi171f7Y7c8A65OyDhR7Y/NOxe0x/kaTgghgkiCWSHEkHP73Oyz7gvUIPSEtq2O1sAxieGJ5MfmB3pr82LyyIjMQCcvsL4zt8/Nmr1reLLoSSqtlWREZnBz4c2cl3keRt342bDJ5fGxrriBZ77Yxzf729BpNVw0M5XrF05iRvp3/MXArv/A6zeru/1e98boP3XOaYMPfgNfr1Sf00WPQcb8YK9q7PJ61Ana0rVqP23zHvX6+Fy1kzb3HMhYADo520AIIcQ4VbNF7Y/d9bb639MuhoV3QNrs4K5LCCG+IwlmhRBB09Ld0hvU+sPavda9eBUvAGH6MLKjs/tvNhaTS4RRTk86Gg6Pg7fK3+LpnU9TZ68jNyaXWwtvZfHExeMq8K5r7+bpL/by3Kb9ONxeMmLDWTYvnSvnZhBjGoRgesu/4d17IG0OLHsZwmO/+30G074v4O07oK0KFtwOZ/4SjONnonpEsOz1byC2Vu2o9bogxAzZ31ND2uxFYIoL9iqFEEKIoeXzwp731P7Y6k3qv4Wzv6/2x5onBHt1QggxKCSYFUKMKE6vk4r2ikANQk9g2+HqCByTFpHWrwYhNzaXCRET0GiOcYOmMcrmsvFK6Ss8W/wsrY5WZiTMYPn05ZySdsq4+TNSFIUvylt5btM+1hc3ogAa4K5FOfzozBy0x7qZ16HsfANeuxGyF8MVz4DxOHppRwpXF2z4HXz1OMRMhAsfVU+rF8HltEHlx73TtLZGdafpCXPVuoPccyBxKoyTv9tCCCHGAWcnbH0BNj0K7VUQnaH2x866FkIig706IYQYVBLMCiFGPEVRaOxqPKi7tqqjCgX1Z1WEIYLcmNx+07XZMdmE6cdPH2abo40Xdr/Aqj2r6HR1cmLqidxSeAtzkuaMm0DW2u3m9S01PL+pisoWOyF6LU6Pj8nxJh66ehYFaebBfUCXHTY+CiffPbpPM9+/Cd66AywVMG85LPrN6A6ZxyqfD+q3qZuHla5VLwOY03tD2kknSw+wEEKI0cla09sf67RC+ny1Pzb/POmPFUKMWRLMCiFGrS53F+Xt5f3C2tK2UuxuOwBajZaMyIxAUJsXq1YhJIUnjamgstHeyDO7nuG10tfo9nSzKGMRtxTewrT4acFe2rAprrPy3MYq3tpWi8Pt44QMtTN2W3U7Pzgti7sX5RCiH6QX9F43fPoX9Y1CaNTg3GewuLvhwz+opwhGp8OFj8DkU4O9KnG0Ohv8lQfroOIjcNtBHwaZp0POInWH6ogkiEiE8Dh5UyuEEGJkqv1WfS1S/CagwNQLYcGdkD432CsTQoghJ8GsEGJM8Sk+ajtr+/XWlraVUmurDRxjDjEH+mrzY/PJi80j05w56jbC2t+xn6d2PsXbFW+jKApLM5dyU8FNZEVnBXtpw8Lh9rJmZz3Pbazi2/3thBq0nFeYwsUnTOCk7HiqLV0025yckBEzeA/qssMr34fy9+GSJ2D65YN338Ot+mt463ZoLYM5N8Hi38npgaOZ2wFVn0Ppeihdo+5a3ZdGC+HxakhrSvAHtglgSlSvi0jsvSwhrhBCiKHm86pnf2x8BKq+AGOk2h87b7laqSSEEOOEBLNCiHGh09Wpdtb2dNdaSihrL8PpdQKg1+iZHD1Znaz199bmxeQRFzbyNtgpsZTw5M4nWbdvHXqNnotzLubGghtJi0gL9tKGRbWli1Wb9/Py19VY7C4y401cu2AimQkmfvOfYrITI3ji+0MwYdFlgVVXQO0WOO8B9c3DaOR2wMf3w5cPQlQaXPAQZJ0R7FWJwaQoaidfRz3Ym8Dm/7A3ga1Z7antud7jOPjr+4a4gcDWH+YeeDk8VkJcIYQQR89lh22r1P5YSyWYM2DBbTDrutF/JpIQQhwHCWaFEOOW1+elqrOKUkvvJmMlbSU0dTUFjokPi+8X1ObF5DHJPAm9Vj/s693WtI0nip7gk5pPMBlMXJF3BddPvZ74sPhhX8tw8/kUPi1r5rmNVXxY0oQGWDw1iesWTGJ2RjR/31DGys8qSYsO48+XzWBh1iAH6tZaeP4S9Q3EpU/C1AsG9/6HS+0WtUu2eQ+c8H046w/yJmg8UxR1gxV7sz+8bTzEZX+oe9gQ9zATuD2Xw+NAqx3+5ymEECL4Oupg87/gm6fB0Q5ps2HhCphyAeiG/3W1EEKMFIcKZuUnoxBizNNpdWSaM8k0Z3LO5HMC17c52vrVIJRYSvhq11d4fB4AjFojWdFZB3XXmkMGeWMp1I3PNtZv5ImiJ/i64WuiQ6JZMXMFV+VfNSSPN9K02V28uqWa5zftZ7+li/iIEFackc3V8zJIjQ6jotnG+Y98QXmTjWXzM7hvyRQiQobgnzCfRw2xrn19dHawepzwyf/B539XA7RrX4fsRcFelQg2jUYN5kOjIO4IFSj9QtxGf1g7wOWW8sOEuDowxR8wdZvgD28PuBwWKyGuEEKMBfXb1bqCna+D4lM38lq4AtLnqf8OCSGEGJBMzAohRB9ur5tKa2UgqC1pU0Nbi8MSOCbFlNJ/ujY2j/TIdLSaYw8XfIqPj/Z/xMqilRS3FpMYlsj3p32fy3IvI9wQPphPbUTaXt3Oc5uqeGd7HU6Pj3mTYrl24UTOmZaMUd/759ne5eL7T23mnsW5nJ6XOPgLaa1QN1DSatUutNF4ynbdNnVKtqkYZl4LZ/8vhEUHe1ViLFMUcHaotQl2//TtgJf9H/5KmX56QtyBJm/79eQmSogrhBAjjc+nbk658WHY9xkYI9Sqgvk/gNjJwV6dEEKMKFJlIIQQx0lRFFq6W/rVIJRaStnXsQ+v4gUgTB9GTkxOoAYhLzaPnJgcTAbTgPfp9rlZu3ctTxY9SYW1gvTIdG4uuJnzs84fdZuTHatWm5N1xY289PV+dtRYCTfquHhWGtctnEh+cu/p9nsaOnjis73cf0khBp0WRVHQDMXERcWH8NK1cPLdcNrPBv/+h5rHBZ/9FT77i3qq+QUPQu7ZwV6VEP31DXED3bcDXW4+Qoib0KdKIenQl8NiJMQVQoih4uqC7S+q/bGt5WqX/fzb4ITr5ZfCQghxCBLMCiHEIHN4HFRYKyi1lLLHsicQ2Ha6OwPHpEemB6Zr82PyyY7O5su6L3m6+GlqbbXkxORwa+GtLJ64OCh9tsOlxeZkXXEDq4vq2VRpwetTyEmM4LqFE7l4VhqRoYbAsR6vj39+WskDH5RiDjPy0vL5ZCdGDs3Cdr4BbyyHhDz1tP/I5KF5nKHSUARv3a5+nn4VnPsnNZASYjRTFHBYj74T1+s6+D76hrgDbWbW97KEuEIIcXQ6G2DzSvjmSehug9RZal3B1AtBZzjy1wshxDh23MGsRqN5CjgPaFIUpWCA208H3gb2+q96Q1GU3/lvOwf4B6ADnlAU5U9Hs1gJZoUQo5WiKNTb6/vVIJRYStjfub/fcdMTprO8cDmnTjh1aKZAR4DmTidrixtYU1TPpspWfApMjjexpDCZJYUpTE2JOui5lzfZ+PGr29le3c7S6Sn8/sICYk1DNEG8eSWs/ilkLISrXxxdEx5eN3z+AHzy/9RQ6fwHIH9psFclxPDrF+IephO3p1ZhoBBXq1dD3IF6cA+sV5AQVwgxHjUUwcZHoehVtY8/fyksvFN9DTVGX8cKIcRg+y7B7KmADXj2MMHsTxRFOe+A63VAKbAYqAG+Bq5WFGXXkRYrwawQYqzpcndR2lZKWXsZk6ImMSdpzpgMZJs6Hazb2cB7RfVs3mvBp0BmgomlhSksKUwhPznykM9bURQuePgLqtu6+P2FBZw/I3XoFtpeDQ/Nhqwz4fKnwRA2dI812Bp3qVOy9dug4DJY8mcIjw32qoQY+RRF3SH8aDtxfe6D76NfiJt0cA9uv07cGAkshBCjl88H5R+o/bF7PwFDOMy6Vq0sONJGkkIIIQ5yqGD2iOfNKoryqUajmXQcjzkPKFcUpdK/gJeAC4EjBrNCiOPj8fpo73bTZnfRanf1+2zpcmGx93602V20d7tJMYdSmGamwP8xLTWq32nlYnCEG8KZmTiTmYkzg72UQdfU4WDNTrWmYGyS67MAACAASURBVPM+C4oCWQkmVpyRzZLpKeQlHTqMBdjf2kVshJGIED1/u2IG5nADiZGhQ7NYRVGDkuh0uGkNJM8A3SipkPB64MsH4eP7ISQKrnhWPXVQCHF0NBo1LA2LgYTcwx/bN8Q9XCdu064jh7iH2tis72UJcYUQI4W7G3a8rE7ItpRAZAos+g3MvkHqkoQQYggM1rvRhRqNZjtQhzo9WwykAdV9jqkB5g/S4wkx5imKgt3lVUPVAz+6Dghd/ddZu90cagg+MkRPjMlIjMlIUlQo+clRmMMM7Ld0sanSwlvb6gLHTo43MS01qjewTTVjDpewVvRq7HCwpqie1UUNfF2lhrE5iRH86Mwclk5PITfpyJ2wPp/CC19V8cfVe7hybjq/uWAaOUfxdcfN44Q3b4OcxTBzGaTNHrrHGmzNJeqUbO0WmHoRLP2rupO9EGJoHGuI2912+B5cWxM0FqvXDxjiGgbuxJUQVwgxXGxN8PUT6kdXKyRPh0tWqq879GN7Y1ohhAimwQhmvwUmKopi02g0S4C3gBxgoFeMh+xN0Gg0y4HlABkZGYOwLCFGFrfXR1uXiza7u1/AarG5aOsaOHx1eXwD3pdeqyHWZCTWZCQm3MiU1Chiw42B6/reFhdhJDrcQIhed9j1NXc62VlnpbjWSlGtla3723l3R33g9vTYMArTzExLNQcC2yHr/hQjUoPVweqietbsrOebqjYUBXKTIrjrezksLUw5KFRVFAW3V8Hp8eL0+IgOM6DXaWnudLK3xc6DG8r4vLyFU3Li+cFpmUO7eKcNXr4GKj9WN6oYLXxe9RTCD/8XjCa47GkouCTYqxJC9KXRqHUi4bHqRoKH0y/EPUQPrq0RGnaql32eg+8jEOIONIF7QE+uhLhCiCNp3AWbHoEdr6g93LnnwokrYOJJ8vNDCCGGwRE7ZgH8VQbvDtQxO8Cx+4A5qOHsbxRFOdt//X8DKIpy/5HuQzpmxUinKAo2p6e3FqDLRas/YO2dYnVjsTtp63LTanPS4RjgzZVfZKieOP8064EBa4zJ2P+2CCORIfph6Se12F0U16lBbXFtB0W1VvZbugK3p5pDAxUIhWlmpqVFDd0p6OK4eH3+YNTtI8yoI9Sgo8vloaLJHghMe26flRFDsjmUaksXa3bW43T7aLU72dPQSUWTnWabE4CM2DB0Wi3RYQa0Wg0Ot3o/Dy+bRX5yFK9+U80v396J0+PrN8G94cenkZUQwROfVfKH93YTbtTxi6VTWDYvY2i/n+2t8MJlUL8dLnhQ7UcbDVrK4K07oGYz5J8H5/1dDVyEEOODz+evU2jq3307UK3C4ULcfoHtYWoVQqMlhBFivFAUqNgAGx+Big9BH6aeTbTgDojPDvbqhBBiTDrujtmjuONkoFFRFEWj0cwDtEAr0A7kaDSayUAtcBWw7Ls+nhBDweXx0d516D7WA8PXNrsbl3fgaVaDrmeaNYRYk4HU6LBAsBoIWHs+wo1Ehxsx6kfmDs+xJiOn5CRwSk5C4Dprl5viOis766zsrO1gZ62V9bsaA7cnRYVQkGruF9gmRYWMyY2ujpWiKGg0Gnw+hT0NnX2CUR9Ot5eMuHDyk6NwuL288NX+QGDq9PhwuL2ckZ/IabkJNHc6+e83ig76+ttPz+LCmWmUNHRy6WNf4vR4cXt7k9G/XD6Dy2ZPYHd9B5c+tvGg9T12zQmcW5jC11UW/rh6T7/bdBoNl8+ewG2nZ1Hf7uD/1u1Br9MQotcRHWYgxKDFqFO/j3OSIrl+4SRC9VpCDDpC9FpC9Fri/BPWZ01NJisxgqkpUSRFDXGQ77TBU2eDtRqufB7ylwzt4w0Gnxe+ehw2/A70oXDJE1B4mQQmQow3Wm3vJC75hz+2b4jbr0qhz+XOBnVndXvzwCGuznhAJ25PlULSwZclxBVidHI7oOhVNZBt3q3+nT7zlzDnJtlIVAghguSIwaxGo3kROB2I12g0NcCvAQOAoiiPA5cBt2s0Gg/QDVylqGO4Ho1GswJYB+iAp/zds0IMKUVR6HB4Btz46lCbYXUeZprVHGYIBKkTYsKZMSHaH64aAuFrrCmE2HAjMSYDEcM0zRos5nADJ2bHc2J2b79lp8PNrroOdtapQe3OWisflTTh82eC8RHGQFdtQVoUBWlm0qLDRv2fU7fLS6vdicX/fWUOM3BChropwn1vFtHU4cTS5/ZLZqXx2wsL8CkKSx787KD7+8Gpmfz3kihcXh+/f7d3n0SjP9hMjQ7ltFw1JK9t7w4EntFhBkIiQ4gMVX+kx5qMXDk33X+7jhCDetyMCWYAshIiWHn9HEL0WkL9wWl7t4uiaisXPfIF26rbAZiSHMl5M1I5tyCZzISIwHqyEiI4OefkQ/65zEyPZmZ69CFvz4gLJyMu/Gj/mL+bkAiYdQ2kz4eJJw7PY34XrRXw9p2wf6N6KuH5D0BkcrBXJYQY6fqGuIlHEeJ2tw0whdunVqGzHhp2qNcp3oPvoyfEDYtRNyMMjVI/h0QecNk88PUhUaNn40UhxgJ7C3z9JHy9Uv17nlQAFz2u1iPpQ4K9OiGEGNeOqspguEmVgejL6fEe1Mt60MZX9v41Ah7fwN/XRv/UXk/3akz4AJUBfW6LDjdg0I3MadaRrsvlYXd9B0U11kBgW9Zkw+v/fxMTbqCgX2dtFBmx4UELaxVFwenxEWpQu3i/LG+hpq2bVrsLi91Jq91Fekw49yxWN4E58y8fU9li73cfZ01N4l/Xq2cmnPPAp2g0GuL6TEgvyIzlnIIUANburO8TmqrhaGJkCIlRoYFfLoTo1SlUrXZo/kx6KgveK2pguz+MnZYaxZLCFJYUpjA53jQkjzssqjeDRgcTRskGXz6f+mbp/V+rgce5/w9mXCUTaUKI4OoX4h7Qg2trVqd0HR3g9H/0XB5oIvdABpM/rI3sE+72XDYf4vqecNd/WR8qPyeFOJymPbDpUdj+EnidkHMWLFwBk0+VvztCCDHMDlVlIMGsGFY+n0Knw0Or3dmvGqCnjzXwucsdCF1tzkO/uI8ONwQqAQba+Co2ov9t4UbdqJ/SHM0cbi+76/2TtTVqHUJpY2fgdPvIUD0FqWYKJ5iZlhpFYZqZSXGm4womfT4Fa7fbH6yq4arbq3D+jFQAHtxQxua9ln4Bf1ZiBGvuOgWg3+RomEFHrMnI/MxY/nbFTACe+KwSl9fnD15DiDUZSTGHkhodNhh/VEOm2tLFe0X1rC6qZ0eNFYDCNDNLClM4tyCZSaM5jO1R9j68cj0kTYOb3x/5bzwse+HtFVD1OWQvVntwo1KDvSohhDg+igIehz+k7QSntc/ljgMuW/2hbmefgNd/2W0/8mNpDQdM4x7F9O6BQa8xUp04FmKsUBR1s9ONj0D5++ovMGZcpfbHHmmDQiGEEENmyDpmxfjmcHsPCFgPqAw4IHxt63IFJiYP1NNBGeufVp0cF37wxld9glezf4d3MXqEGnTMyohhlv90f1AnoksbbOwMbDJm5d9f7sPlUTt8I0L0TE2NYmpKFJPiwjkxK46sxEiKaq3sqGmn1db7fdfl8vD0jfMAuOvlbbyzva7f48eajIFgtsXmxO7ykBYdSmFaFLGmECb2Ob3+gStnotdpiDOFEGbUHfRcbjklc9D/fIbK/tbeMLaoVg1jp08w8/Nz81lSkDJ8tQLDYccr8NbtkDgFrlo1skNZnw+2PAXrfwVaHVzwsLox2UhesxBCHIlGA4Yw9SMy6fjvx+sB1wCBbSDQHSjo7YD26v5h8EBVDAcyRh5mevdIoW/P9K6cDi6CzOOEotfUQLapWK0bOeMXan+sKf7IXy+EECIoZGJWBPRMGA648VWfGoG+19ldA7/Y1WggJtxITLiBOFMIMX36WPtWBfTcdqjwS4xdLo+PVruzX7Daandx1dx0jHot//ykkte2VGPxf5/1DfTDDDoiQ/U0dTrRoE7aJkSGEBcRwou3LkCn1bBhdyNVrV3ERfSG+fERIUO/2dQIsa/FHghji+s6AJgxwRyoKUiPHUNhbI9Nj8Han8OkU+CqF9TpqJGqfb86Jbv3E8g8Ay54CKLTg70qIYQYWxQF3F1HP6XrtA58vaf7yI+lCzn6Kd2QSAgxH3y9wSTTu+LY2VvVX/RuXqnWjCROhYV3QsFlYBgfr3uFEGI0kInZccjh9gZO0z7ws6XLhcXWf1Osti4XhxhmJdyo69fHmpUQ0b8ywHTwNKtuiHoxxcijKApdLi8Wu4u4CCPhRj2VzTY+2N2oVgnYeoPXv14xg6yECF7cvJ9f/+fg/QBPz0sgKyGCtJhQUsxhTEszE2cyEh1mwKdAWnQYJY2dbK9uo9Phptvto8PhwenxERFq4Nf/2enfZMzMKTkJGPXj5w3O3hY7q4vqeW9HPbvq1TB2Zno0v1gyhXMLk5kQMwbD2B4+H+z9DPLPg0ufHLlvRBQFvn0G1v0PoMB5D8DsG2RKVgghhoJGA0aT+kHK8d+PxwUuW2+oe1A1Q8fA11v29j+GIw3EaA4zsXs0lQ3+63WG43+uYvRoKVP7Y7e9qP7yIHsRLHxc/YWvvK4QQohRQyZmRwlvzzRr3x7WPp/7bnzVM33Y7R54mlWrIRCo9q0KiDtEyBoTbpRp1nHI5fFR296tbnzVJ1hdNCWJvORItlW38z9vFWGxqdc7/dUDz9w0j9NyE1i7s4Hbnt9y0IZrvzpvKjlJkZQ1dvJNVRuxB3wPmsMMR90D7PUp7G2xU1xn9W8yZqW4toNOfy+xQachLzmSwj6bjOUlRwY2+BoLKpptrN5Rz3tF9exp6ARgVkY0SwtTOLcwhbQR3nn7nfm86hvl8FhwO0CrH7k7fVtr4D8/gooN6qYbFz4C0RnBXpUQQojh4POpvblHO6V74IZqPaGv13nkx9KHHf2Ubr8N1/oEvYZwCfdGIkWBfZ+pdQWla9VJ7elXqBOyiVOCvTohhBCHIZt/jVCtNidFtdb+tQEDVAm0d7s51P8qk1EX2OQqpidQ7bPx1YHha1SoYch2eRcjk8+n0O4P9iNDDSRFhWLtdvPvL/apwWuf77Xlp2ZyyQkTKK6zsvTBzw+6rz9fNp3L56RT1tjJH1fvJtYU0lsXEG7klNx4UsxhOD1e3F4F0zBvuObzKey3dFFUqwa1O2ut7KztwNrtBkCv1ZCTFElBapR/kzEzU1OiRtUvH8qbbKz21xT0hLGzJ8YENvAa6RuQDRq3A964RZ1IumXDyJ6S3fYCrP1vNUg+63cw+yY5XVUIIcSx8zgPP6Xbb8O1Q1Q2uDqP/DgaXZ9pXPMANQ1HubHaSP1l6WjjcUHxG7DxYWgogvB4mHsLzL0ZIhKDvTohhBBHQYLZEer9XY3c+mzvc9VpNf7uVaO/l7VP0GrqE7z2mWYdS9N/4tiUNnbSYnP2drTaXExNjeLsacl0u7xc+Mjngdt6aipWnJHNT87Oo9XmZPYfPiAqVE9cREjg++mqueksmppEp8PNB7sb1W5gf9AfZxqd32+KolDT1s3OWqs/sO1gp/8XIqBOkWcnRlCQZqYg1UzhBDWsNYWMnDcTZY2dgc7Y0kYbAHN6wtjCZFLM4ySM7eHogJeWqVMjZ98PC+8I9ooG1lEH79wFZeth4slw4cMQOznYqxJCCDGe+bz+aoajnNIdsLKhA3yeIz+WwXSEKd1DVTb0CYP1oeN3erfLAluehq/+BbYGiM9Tp2OnX6FusCeEEGLUkGB2hGrvclHRbAuEX1Fh+mGdLhQjg8frC/T9alBPvwd4/JMKqi1dgRqBVpuTE7Pi+f1FBQBM/dVaug7YgO26BRP5/UUFKIrCilVbMYcbiOsT5k9LjSI7MRJFUXB7lXHVwdqXoijUWx3+iVo1rC2qtdLcqZ4iqNHA5HgThf6wtiDNzLS0KKJCh6+3rbSxk/d2qGFsWZMNjQbmToxlSWEy5xSkkGweoROiQ83WBM9fCk274MJHYcaVwV7RwRQFdrwMa36mTrks/i3MvVWmZIUQQowNigIexwBTugcGuZ2H33DNbT/yY2kNh+7YHbCP9xDTu6Pp3+DWCnVT020vqBvYZZ4BC1dA1pmj63kIIYQIkGBWiCCobLZR1+6gtU9PqylEz+2nZwFwyzNf801VG+1d7sDXLMiM5aXlCwH43l8/pq3LHQhV40xG5k+O5YaT1Im7D3Y1Eu6vsogzhRATbkCvkxdr30VTh0Odqq1Vg9riOiv1Vkfg9klx4UxLM/cJbKOIDjcOymMrikJJYyerd9SzemcD5T1h7KRYlhamcE5BMklR4zSM7evFq6HiI7jiWcg9K9irOVhnI7x7N5SshvQFcNGjEJcV7FUJIYQQI4/Xo1YrDDilax2gpmGgyoZOUAbeW6Mf4+HqGI5U2eAPePUhQ/dnoShQ9aXaH1uyWt3ErfAK9aygpGlD97hCCCGGhQSzQnwHTo83UBXQ6fCwMCsOgDe31vBVpSXQ0dpqcxJq0LH27lMBuP6pzXxa2hy4H60Gpk+I5q07TwLgHx+U0Wp3BkLXWFMIE2LCmJEeDah9qdIHHHwtNmfvZK0/sK1t7w7cPiEmTA1qez5So4iLOLoX7oqisKehk9VF6gZelc12tBqYN1kNY8+elkyihLH9WWvVioD0ucFeSX+KAjtfh9U/AXc3fO9XMP820I6++g8hhBBi1FAUdar0aKd0D9xwrSfc9XQf+bF0IUc/pXuoDdeMEf2rGbxuKH5L7Y+t3wZhsWp37NxbITJp6P7chBBCDCsJZoUYQG17N2WNnb0drXYXFpuL/724AL1Oy1/Xl/D0F/uwOXs7tAw6DaV/OBeNRsN9bxbx/q7GflUBqdFh3LdE3RV1e3U7DreXOP9EqzlMNl4bK9rsLv/mYh2BTcaqWrsCt6eYQ/t01kZRkGoOBKyKorCrvoPVRfWsKWqgskUNY+dPjmPJ9BTOmZZMQuQQTmSMRlUb1dP5zv/HyAw6bc3w3j2w+x2YMBcuegzic4K9KiGEEEIcLY+rdyr3iB27A1U2+C9zpPfXmv6hbler2h8bl6NOx06/Cozhw/GMhRBCDCMJZsWY1uXyBKoCLHYXcybFEBlqYFNlK69tqQlMs/ZMtm748WmkmMN4aEMZf32/NHA/Bp2GWJORdXefSnS4kdVF9Wzea1GDV3+4GhdhZHZGjASs4iDWbjfFdVaK/VO1O+us7G2x0/NjNjEyhKmpUexrsbOvtQutBhZmxXFugToZK2HsIZSsgVdvAHM63LgGIhKCvaJeHhdsex4+/AM4bXDmL9QOuJEYHgshhBBiaPl8am/ukaZ0+17WaGHWtZC9WPpjhRBiDDtUMDtythwXwk9RFBQFtFoNbXYX3+5vCwSqFruLFpuTH5yaRV5yJGt3NnD3y1txuH397uPNO05kVkYM9dZuvihvCUyzTo43EWsKQe9/0XPRrDROzI4jzhRCbISRyJD+m68tKUxhSWHKsD5/MXqZwwycmBXPiVnxgetsTg+7/BuLFddaKa7rYEJMOMtPzeLsaUlHXXkwbm1bBW+vgJTpcM1rYIo/8tcMB48Ttj4Pn/8drNWQPh/OfxAS84O9MiGEEEIEi1brry2IBNKCvRohhBCjgASzYsgpivL/27vz+LrqOv/jr0+Spm2Sli4pbWlpS9m3lqUssoO4wKD8WHRkXHBQYX6u48wojDq/cfk5Px39geI24IagIu7bKIoosi9lKyiLyA4tpXvaps32nT++J81NmjRtSe5N0tfz8biPe+5Z7vnck5s+Tt/5ns+haWMbK9a2bApY99i5gd0a63l25Xou/u2jLFvXwop1G1mxtoVl61r47Ovm89r5u/Dwkibe9q2u0dOja6pobBjN2YfMBMYxp7GONx0xuxjNWrspYN1r6jgAzjh4JmccPLPP2nadVMeuk7xUSIOnYXQNh+82icN3m1TpUoafOy6HX38Adjse3vCd4j85Fda2Ee69Cm66BNY8m9sWvOZzsPvLu/eLkyRJkiSpHwaz2mYpJVrbE7U1VbS2d/D7h5d2G826Yl0LJ+87ldfM34XFq5s5/j9voKW9+4jWD5+6L+84bi4dHXDnk7lVwJSG0ew9dTyTG2rZbXI9AAfMGM9P33X0ph6udbXV3Ua07jNtPB85bb+yfn5JZTJ9Hsw/J/eVHcy7IG+Nto1wz5V5hOya52Dm4fDaS2H3kwxkJUmSJEnbxWBWdHQkVje3lrQL2MjEulqOmDsZgPdfcx9LmzZs6uG6cn0L5xw+i4+ffgApwQVX3b3pveprq5nUUMu8mRMAmFhXy98fM6fbaNbJ9bXMKkapzppcx80XntRnbePGjOKgXScM4qeXNKS0t8Ffr4e9XgWzjsyPSmrdkEfIdgayux4Bp38R5p5oICtJkiRJekkMZkegjo7EupY2xo0ZBcAfH32Rp1esL1oJ5BtgzZ5UxwdfnXshHvuff+C5Vc3d3uOV+03dFMw+UdwxfubEOubPnMDkhloWzJkIQG1NFf/93mOYWJdHtI4Z1f2GN2NGVfOvp+w72B9Z0kjQ2gw/PA8e+RWcfwPscnAFaykC2ZsuhqbnYdcj4fQvwdwTDGQlSZIkSQPCYHYYaO9IrFyfR6suL8LV9o7E6QflhvIXX/codzy+fFM7gZXrW9h72nh+/b5jAbjkuke575lVAIwfU8PkhtGMG931oz//uLm0dyQmN9RuuknW1PFjNi3/6buO3mJ9+++y00B/ZA2mDavh+XthyQOw9KF8ifi4XeDQc6Fh57y8ox3GTjSAUvk0r4Krz4Gnb4NTP1u5ULZ1Q9Gy4GJoWgyzXgZnfCX3ufX3QZIkSZI0gAxmh4B7nl7JomdWsXxdcXOstS2sb23nyvMOB+Afr7mPX9z/fLdtJtXXbgpm1zS30pESc6fUc9huk5hcX9vthlZfOOdgRtdUMbG+llHVVZvt/9yj5gzeh1PlpASrns4B7JJFcODroXEPeORa+Mn5eZ2GqdDeCs0r4IAzczB797fgun+D6tEwbhqM3yU/n3ZJDmtffBTWvdi1bNTYyn5ODX9NS+DbZ8GLj8BZX4MDzy5/Da0b4J5v5ZYFTYth1lFwxmWw23EGspIkSZKkQWEwOwT8atFivnbzE0TknqyT62uZ3FBLe0eiuio48+AZHDZn4qbRrJPrRzOpvnbT9h997f5bfP/SkFYjVHtrvjnR6AZY/lf4xftyGLthdbFCwJS9czC7+4nw5p/CtAOhvjEvbt0A1cV3au4J8Kr/ly/fblqSH4sXQU0RwC78Btzxla59j9kJxs+AC26C6poc/K5+BsZNh/HT83P9znmZ1JsnboQVT8DfXQN7vLy8+25tzn+MuPkSWLsEZh8NZ14Oc441kJUkSZIkDapIKVW6hs0sWLAgLVy4sNJllM3KdS10pMSEulqqqwwC1I+Odnj2rhyWLikeSx+C4z4Ax38Q1q+A75wN0+bl8HX6fNh5X6itH5j9r3oGlj+WRxU2LYY1i2FjE5x5WV7+/bfAn3/WfZuGafAvj+Tpmy6G1c+WBLfTYKdZMGWvgalPw0drc9eI66YXYNzU8u777iuKQPYFmH0MnHAR7HZs+WqQJEmSJO0QIuLulNKCnvMdwjYETCwZ/Sp10/RCDl4X35/bDBzyltyi4MrToW0DjJ0E0+fBEf8Ac47J29RNgnf8fvBqmrBrfvTl7G/CumV5xO2aIrxNHV3Llz4Ej/0ut0/oNP0guOCPefqaN8G65V2jbcdNh6n7we4n5eVtLVDj78yw09GevwvrlsH6ZbnNxg2fgtdfBbOOKF8o29oMC78Jt3wuB7JzjoWzvm4gK0mSJEkqO4NZaSjo6IB1S/PoUYCfvQse/W2e12nf1+RgtroG3vwTmDgnh5ZD7XLrquocso2b2vsNnM76an5u3ZAvHV+zGCgZuT9ueh71+/y9sOZX0NYMe/9NVzD7uQNzKN3Z+3bcLrk9Q2df0sX359YJDTvnWjTwUoKNa2D98tyLeKcZOTC//UtF8Lq8K4Cdfw4ccQGsXQqX9Gi7MmWfvG05tKyHu78JN38u/17NORbO/kbXHzQkSZIkSSozg1mpEpY+DM/cUbQieACWPAj1k+EfH8jLx0yAPV+RWxFMOxCmHgBjJ3RtP/uoytQ9kEaNyeHyxDnd55/6ma7plHKf3LaNXfOOuADWPJd73655Ph/L0Q05mG1vhcuOy+tFVb652bhpcOhb86OtBR74QffRuGN2GnrhdqUs+0sOUNcv6wpYJ86Bea/Py796Uj7m65dDe0ued+hb4TWfh6oauP7jOaitb4S6yfl59Pi8Xn0jvObSYlljfp4wC6pHDe5nalmf+yLf8vkcyO52HBx/Bcw5enD3K0mSJElSPwxmpcG0fkURvD4AS/8Mr/1CHsV552U5LKodl4PXQ96ce8KmlEPCV32y0pUPDRHdA2mAY/9py9u84equG5d1tlKoKv6pa1oMP3tn9/VrxsKr/wMWnJdbR9x6aUn/22JU7vgZw6t9QltLV7ja0QYzDsnz77gst5JYvyx/N9cty20iXndFXv7ts2DVU93fa7//1RXMNu6d+xV3Bqt1jfk1QFUV/OtzUNvHzQarR8Gh5w74R+1Ty7qSQPZF2O14OOFbI+OPGpIkSZKkEcFgVhoIKcHqZ6B+Sr6Z0YM/guv+Pc/rNG56Dgt3mgFHvw+Oeg9MmJMDLQ2M6lGwz6l9Lx8/A957b9do26YlOaydUoSLa56Fu76WWyWUOuvreUTu8/fC9Z8oGXFbtFKYdWTu7TvYlv0FVjzevVUAAa/4WF7+4wvgkV/lNgOdpuwD77ojT//5Z/DiIyWh6j75DwOdTrsYorpred3k7oH0GV/Zcn19hbLl1LIO7vp6DmTXL4O5J8DxF8Hsl1W6MkmSJEmSujGYlbbHuuXw2HV5JOzi+/PzhlXwlp/D3ONzj9NdD4fD3l60I5gHDVO6tu95+b7Ko7oGJs3Nj97MOBQ+vCT/i5tS6gAAEl1JREFULDtH2zYtzj9LyDeOal6RRz+vfaHrpmZ//+s8EvPBH8FvPtK9VUJnK4X6RtiwJm+TOrqHq/uclkcHP/gjeOTaklYCK6B1HVz4ZN7PjZ+BRdd01Vs1Kn+WzmB2xiF5hHFdYw6K6xtzcNzprf+95bYNe5y8PUd1aGhZl0P1Wy4tAtkT4YSLcmguSZIkSdIQFCml/tcqswULFqSFCxdWugwJNq6FF/6Ue8Euvj9f1r3nyfDcPfDVE6FmDEzdvyt83fuUfFMqjXztbfkS+abF0LhX7nP71K1w77e7j8bdsAretwgmzoabL4HffXTz97rwSRg7EW74NNz3na7+rHWNuffwyR/LLTCWPgwbm/K8ukYYPc7+uBvX5kD21ktz2L37SXmE7KwjKl2ZJEmSJEkARMTdKaUFm803mJUKTS/kS9gnzobmVfC1l8PyvwLF78jYiXDSv8Fhb8s9PFc8DpP3yKMwpb60rM8BflVVboXw5M25XUDd5K6Ader+g38TrJFm41q466tw6xeKQPbleYRs5+hmSZIkSZKGiL6CWRMl7bj+/PMclC1ZlFsRrH0BDnwdnPU1GLNTvqz9wNfD9Hl5ROz4GV2jE2tqc39OqT+lfVd3OTg/tP02NsGdRSDbvCK3Xzj+Itj1sEpXJkmSJEnSNjGY1cjWtjH3A13yACxelC8HP+XTedkfPw0vPpxv/LT7y3MA2znaLgLOvLxydUvqbmMT3Hl5EciuhD1ekUfIztzsD46SJEmSJA0LBrMaOZpXwrLHukbO/fL9cM+V0NGWX9c2wJxju9Y/53vQsDPUjC5/rZK2zoY1OZC97Yv5d3zPV+YRsjMPrXRlkiRJkiS9JAazGr4W3w+P/LprNOzqp4GADz0HtfUw87DcF3Za0Ypg4m65z2enCbtWrHRJ/diwBu68DG79Yr6B2p6vguMvNJCVJEmSJI0YBrMa2tpbYdlfuvrALlkEp38JJsyCp26FGz4FjXvmUbKHnZdD2KriJkoH/V1la5e07TashjuKEbIbVsFer4bjP5h7PkuSJEmSNIIYzGro2Lg294OdMAvGTYPHroerz4H2jXl5zRjYeb98OfOEWXDwm+CQt+TRsdJwkBK0twAB1aO6bianIpC9rAhkV8NepxSB7CGVrkySJEmSpEFhMKvKaV4FC7/RNRJ2+V+BBKd+Fg5/BzTulZ+nzcs35pq8J1SXfGVHj6tY6RqBUoK2DdDaDK3roWV9fu583dpcMr2+l3ml2zX3vW1q79pn9ej8B4ea2mK6eFTXlkyP7jFdm7fpbZ3q2h7v19d7j+k+XcmQuHlVDmRv/1IOZPc+NQeyuxxcmXokSZIkSSoTg1kNro4OWPlEDl4XF+0I5h4PR70Hogqu/3ju9TptHhz4+twLdmZx864Ju8KrPlnZ+jU0dLT3HoJ2C0r7CENb1vURrvay7TaLPGJ71NjiUdf1XDd583mjxsKoMXm7to15NHhbSw6E21tK5hWP9hbY2JTX6Ta/ZJo0MMd4WwLg0nU2m7c1IXFtXv7QL+C2L8PG1bD33xSB7EED83kkSZIkSRriDGY1cNo2wtKHcsg1+2V5BOIl+0PT83l5VMOUfXIgCzBmPFz0dH7W8NXetpUhaX+jSbcQoLZt2Pa6orpHaFoy3TC1e1haW9dLgFr63HNeMV0zurLtCFKCjrbeA9tuoW8vAXDp+u3Fsi0FwJuFxL289/aExPuclgPZ6fMH/PBIkiRJkjSUGczqpbn/e/D4DXkk7IsP55Bo+ny44MYcWB317txyYNq8HMqOGtN9e0PZ7jo68qXuHe09njvysd1sWUfJ67Ze5pUu6+29O7ZxJGovAWtH67Z/zuravgPPsZOKoHRrQ9KSbTcFsXX58vyRLopetdWjYHRDZWvZmpC4fWNXANy2AabsDVP3r2zdkiRJkiRViMGstiwlWP1sVx/YJQ9A0xJ4x/V5+aPXwlO35RYEe7wihyw775dH1nW0wbw3dIWA65f1EzSWBIfdlvUWNPYWXpasu1Xv3VcI2kcw2i387GP7Xmva2u3bKvuz7lTT87L8Ynr0uJKRplsbltZtHrLWjO3eK1gjw1AKiSVJkiRJGgZMRypt5VPw1+v7CRrbepnXvpVBY3/blyxrb4e2YhTkqLq83vrl0LK2q96oyo//O71r+9QBjy2Bx66r3HHcWlEFVTX5Mveq6uK5qsfrml7mVXdtWzqvqqa4nL3kdVT12G4L23dbt2YLNfV8r97q39L2Nb3sr3geNaZ7aFpVVemfkiRJkiRJ0ohnMFtpL/wJfvn+/teLrQwKewZvWwwKR+XL0dctgY1roaUph6wAu58M9Y3QvCLfNb2+Eeoac4i3xaCwt6Cxt/33FX4WoeJWhad9vfcWaqtkP1BJkiRJkiSpYDBbabufCP/8SN/B6qbnlxAorn0Rltyf2xAsLtoRnPIp2ONkeOx6+OF5sMv83Ad22rzclqBxLy83lyRJkiRJkgaJyVuldfbrHAgdHbDyiRy8TpwNuxwMSx+CLx/Ztc5Os3LwOqo+v557Ilz4pCNJJUmSJEmSpDLqN5iNiG8ApwFLU0oH9LL8jcCFxcu1wP9OKd1fLHsSaALagbaU0oIBqlsp5TC1vQ1+86HixlwP5nYEAIefn4PZyXvAKz8J0+fB1AOgblL397GfqCRJkiRJklR2WzNi9grgi8CVfSx/Ajg+pbQyIk4BLgeOKFl+Ykpp2UuqckfXvCqPgt30WART9oGzv57bDTz+Bxg7Cea/IQew0w6EKfvmbatHwVHvrmz9kiRJkiRJkrrpN5hNKd0YEXO2sPzWkpe3AzNfelk7qJRgzfM5eF23DA55c55/1Rnw/D15un7nHL7OOKRru3fdaSsCSZIkSZIkaRgZ6B6zbwN+XfI6Ab+NiARcllK6fID3NzLc911YdE2+MVfzijxv9Hg46I251cCJHwIij4QdN3Xz7Q1lJUmSJEmSpGFlwILZiDiRHMweUzL76JTS8xGxM3BdRDycUrqxj+3PB84HmDVr1kCVNTyseR42rIZ9T4Np8/Jj6n5d/V/3fEVl65MkSZIkSZI0oCKl1P9KuZXBL3u7+VexfB7wE+CUlNKjfazzUWBtSumz/e1vwYIFaeHChf3WJUmSJEmSJElDWUTcnVJa0HN+1QC88Szgx8CbS0PZiKiPiHGd08ArgQdf6v4kSZIkSZIkabjrt5VBRFwNnAA0RsSzwL8DowBSSv8F/B9gMvDlyL1O24oEeCrwk2JeDfDdlNK1g/AZJEmSJEmSJGlY6TeYTSmd08/ytwNv72X+48D87S9NkiRJkiRJkkaml9zKQJIkSZIkSZK0bQxmJUmSJEmSJKnMDGYlSZIkSZIkqcwMZiVJkiRJkiSpzAxmJUmSJEmSJKnMDGYlSZIkSZIkqcwMZiVJkiRJkiSpzAxmJUmSJEmSJKnMDGYlSZIkSZIkqcwMZiVJkiRJkiSpzAxmJUmSJEmSJKnMDGYlSZIkSZIkqcwMZiVJkiRJkiSpzAxmJUmSJEmSJKnMDGYlSZIkSZIkqcwipVTpGjYTES8CT1Vg1zsBq3eA/Q72/hqBZYP4/lJfKvU7vCPzmGcj7TgMh88z1GqsZD3l3LfnEBqphtq/KTsCj3k20o7DcPk8Q6lOzyEGhucQqpSt/W7PTilN6TlzSAazlRIRl6eUzh/p+x3s/UXEwpTSgsF6f6kvlfod3pF5zLORdhyGw+cZajVWsp5y7ttzCI1UQ+3flB2BxzwbacdhuHyeoVSn5xAD9v6eQ6giXup321YG3f1iB9lvpT6nNNj8bpefxzwbacdhOHyeoVZjJesp576H2nGXBorf7fLzmGcj7TgMl88zlOr0HEIa3l7Sd9sRsxpw/qVKkiRtD88hJEnS9vAcQsOVI2Y1GC6vdAGSJGlY8hxCkiRtD88hNCw5YlaSJEmSJEmSyswRs5IkSZIkSZJUZgazkiRJkiRJklRmBrOSJEmSJEmSVGYGsxp0EVEfEXdHxGmVrkWSJA0PEXFCRNwUEf8VESdUuh5JkjQ8RERVRHwyIr4QEedWuh5pSwxmtc0i4hsRsTQiHuwx/9UR8UhEPBYRF5UsuhD4fnmrlCRJQ802nkMkYC0wBni23LVKkqShYxvPIU4HZgCteA6hIS5SSpWuQcNMRBxH/o/SlSmlA4p51cCjwCvI//DdBZwD7AI0kv9TtSyl9MuKFC1JkipuG88hHk4pdUTEVODilNIbK1S2JEmqsG08h3gtsDKldFlE/DCldHaFypb6VVPpAjT8pJRujIg5PWYfDjyWUnocICK+R/4rVQNQD+wHNEfEr1JKHWUsV5IkDRHbcg6RUvpzsXwlMLpsRUqSpCFnG3OIZ4CWYp32ctUobQ+DWQ2UGeR//Do9CxyRUno3QES8lTxi1lBWkiSV6vUcIiLOBF4FTAC+WInCJEnSkNbrOQTweeALEXEscGMlCpO2lsGsBkr0Mm9Tn4yU0hXlK0WSJA0jvZ5DpJR+DPy43MVIkqRho69ziPXA28pdjLQ9vPmXBsqzwK4lr2cCz1eoFkmSNHx4DiFJkraH5xAa9gxmNVDuAvaMiN0iohZ4A/DzCtckSZKGPs8hJEnS9vAcQsOeway2WURcDdwG7B0Rz0bE21JKbcC7gd8ADwHfTyn9qZJ1SpKkocVzCEmStD08h9BIFSml/teSJEmSJEmSJA0YR8xKkiRJkiRJUpkZzEqSJEmSJElSmRnMSpIkSZIkSVKZGcxKkiRJkiRJUpkZzEqSJEmSJElSmRnMSpIkSZIkSVKZGcxKkiRpRIiIJyOisQL7nRAR79zWZT3WmxMRD/ax7IaIWPBS65QkSdLQYjArSZKksoiImkrXUCoiqgforSYAfYWvW1omSZKkHZjBrCRJkrZZMcLzoYj4akT8KSJ+GxFje1nvioi4OCL+AHw6IiZFxE8jYlFE3B4R84r1HihGl0ZELI+ItxTzr4qIkyNi/4i4MyLuK7bds5/6fhoRdxe1nV8yf21EfDwi7gBeFhGnRsTDEXFzRFwaEb8s1quPiG9ExF0RcW9EnF7M762OTwG7F/M+06OUbssioiEiro+Ie4rPfHrJujUR8a3ifX8YEXW9fK5XRsRtxfY/iIiGrfhxSZIkaQgymJUkSdL22hP4Ukppf2AVcFYf6+0FnJxS+mfgY8C9KaV5wIeAK4t1bgGOBvYHHgeOLeYfCdwO/APw+ZTSQcAC4Nl+ajsvpXRose57I2JyMb8eeDCldASwELgMOCWldAwwpWT7DwO/TykdBpwIfCYi6vuo4yLgrymlg1JKH+hRR89lG4AzUkqHFO/7/yMiinX3Bi4vjs0aeoy0Ldo0fKQ4locU9f9TP8dBkiRJQ5TBrCRJkrbXEyml+4rpu4E5faz3g5RSezF9DHAVQErp98DkiNgJuAk4rnh8BTgwImYAK1JKa4HbgA9FxIXA7JRScz+1vTci7ieHuruSQ2SAduBHxfQ+wOMppSeK11eXbP9K4KKIuA+4ARgDzNqOOnoK4D8iYhHwO2AGMLVY9kxK6ZZi+tvkY1XqSGA/4JairnOB2du4f0mSJA0RBrOSJEnaXhtLptuBvnrIriuZjl6WJ+BG8ijZY8lB6IvA2eTAlpTSd4HXAs3AbyLipL6KiogTgJOBl6WU5gP3koNVgA0lIXFvtZTWeVYx0vWglNKslNJD21JHH95IHpl7aDHq9oWS2lKPdXu+DuC6kpr2Sym9bRv3L0mSpCHCYFaSJEnldCM5nOwMUJellNaklJ4BGoE9U0qPAzcD/0IRzEbEXPLo1kuBnwPztrCPnYCVKaX1EbEPeaRpbx4G5kbEnOL135Ys+w3wns42AxFx8BbqaALG9bGPnst2ApamlFoj4kS6j3idFREvK6bPIR+DUrcDR0fEHkUtdRGxVx/7lSRJ0hBnMCtJkqRy+iiwoLiU/1Pky/E73QE8WkzfRL7MvzOc/FvgweIS/n3o6k3bm2vJN9JaBHyCHGhupmhD8E7g2oi4mTx6dXWx+BPAKGBRRDxYvO61jpTScnJ7gQd73vyrl2XfKT7/QnJA/XDJ6g8B5xZ1TyK3dCh9rxeBtwJXF+vcXtQgSZKkYShS6nmFlCRJkrRjiIiGlNLaYmTsl4C/pJQuqXRdkiRJGvkcMStJkqQd2TuK0a9/IrcZuKzC9UiSJGkH4YhZSZIkSZIkSSozR8xKkiRJkiRJUpkZzEqSJEmSJElSmRnMSpIkSZIkSVKZGcxKkiRJkiRJUpkZzEqSJEmSJElSmRnMSpIkSZIkSVKZ/Q9xcNEVZtQ5XQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.figure(figsize=(24, 8));\n", - "plt.title(\"Join performances pandas vs polars\")\n", - "x = np.array(n_proxy) * 0.8\n", - "plt.plot(x, np.array(left_pd) / np.array(left_pl), label=\"left speedup\", c=\"C0\")\n", - "plt.plot(x, np.array(left_pd) / np.array(par_left), label=\"left parallel speedup\", c=\"C0\", ls=\"--\")\n", - "plt.plot(x, np.array(inner_pd) / np.array(inner_pl), label=\"inner speedup\", c=\"C1\")\n", - "plt.plot(x, np.array(inner_pd) / np.array(par_inner), label=\"inner parallel speedup\", c=\"C1\", ls=\"--\")\n", - "plt.plot(x, np.array(outer_pd) / np.array(outer_pl), label=\"outer speedup\", c=\"C2\")\n", - "plt.xscale(\"log\")\n", - "plt.xlabel(\"n rows largest table\")\n", - "plt.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [], - "source": [ - "str_groups = np.array(list(\"0123456789\"))\n", - "groups = np.arange(10)\n", - "\n", - "size = int(1e7)\n", - "g = np.random.choice(groups, size)\n", - "sg = np.random.choice(str_groups, size)\n", - "v = np.random.randn(size)\n", - "df = pd.DataFrame({\"groups\": g, \"values\": v, \"str\": sg})" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "360 ms ± 6.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "df.groupby(\"groups\").agg(\"str\").count()" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [], - "source": [ - "df = pl.DataFrame(df.to_dict(orient=\"list\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "118 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "%%timeit\n", - "df.groupby(\"groups\").select(\"str\").count()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/Dockerfile b/examples/Dockerfile deleted file mode 100644 index 298b6335f291..000000000000 --- a/examples/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM rustlang/rust:nightly-slim - -RUN apt-get update \ -&& apt-get install \ - libssl-dev \ - lld \ - cmake \ - jupyter-notebook \ - pkg-config \ - -y \ -&& cargo install evcxr_jupyter \ -&& cargo install sccache \ -&& evcxr_jupyter --install \ -&& rm -rf /var/lib/apt/lists/* \ No newline at end of file diff --git a/examples/datasets/.gitignore b/examples/datasets/.gitignore new file mode 100644 index 000000000000..ad6cd5df3039 --- /dev/null +++ b/examples/datasets/.gitignore @@ -0,0 +1,3 @@ +*.parquet +*.ipc +*.ndjson diff --git a/examples/datasets/foods1.csv b/examples/datasets/foods1.csv new file mode 100644 index 000000000000..66f1ef9a1d5f --- /dev/null +++ b/examples/datasets/foods1.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +vegetables,45,0.5,2 +seafood,150,5,0 +meat,100,5,0 +fruit,60,0,11 +seafood,140,5,1 +meat,120,10,1 +vegetables,20,0,2 +fruit,30,0,5 +seafood,130,5,0 +fruit,50,4.5,0 +meat,110,7,0 +vegetables,25,0,2 +fruit,30,0,3 +vegetables,22,0,3 +vegetables,25,0,4 +seafood,100,5,0 +seafood,200,10,0 +seafood,200,7,2 +fruit,60,0,11 +meat,110,7,0 +vegetables,25,0,3 +seafood,200,7,2 +seafood,130,1.5,0 +fruit,130,0,25 +meat,100,7,0 +vegetables,30,0,5 +fruit,50,0,11 diff --git a/examples/datasets/foods2.csv b/examples/datasets/foods2.csv new file mode 100644 index 000000000000..2faf5d6d04fa --- /dev/null +++ b/examples/datasets/foods2.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +meat,101,8,0 +vegetables,44,0.3,3 +seafood,129,4,1 +fruit,49,4.4,0 +meat,95,4,0 +vegetables,24,0,1 +fruit,123,0,24 +fruit,59,0,9 +fruit,27,0,2 +vegetables,21,0,3 +fruit,61,0,12 +vegetables,26,0,4 +seafood,146,6,2 +vegetables,23,0,2 +fruit,51,1,13 +meat,111,8,2 +seafood,104,4,1 +meat,123,12,0 +seafood,201,8,4 +vegetables,35,1,7 +vegetables,21,1,3 +seafood,133,1.6,1 +meat,110,5,0 +seafood,192,6,0 +seafood,142,4,0 +fruit,31,0,4 +seafood,194,12,1 diff --git a/examples/datasets/foods3.csv b/examples/datasets/foods3.csv new file mode 100644 index 000000000000..80150ad15c26 --- /dev/null +++ b/examples/datasets/foods3.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +seafood,117,9,0 +seafood,201,6,1 +fruit,59,1,14 +meat,97,6,0 +meat,124,12,1 +meat,113,11,1 +vegetables,30,1,1 +seafood,191,6,1 +vegetables,35,0.4,0 +vegetables,21,0,2 +seafood,121,1.5,0 +seafood,125,5,1 +vegetables,21,0,3 +seafood,142,5,0 +meat,118,7,1 +fruit,61,0,12 +fruit,33,1,4 +vegetables,31,0,6 +meat,109,7,2 +vegetables,22,0,1 +fruit,31,0,2 +vegetables,22,0,2 +seafood,155,5,0 +fruit,133,0,27 +seafood,205,9,0 +fruit,72,4.5,7 +fruit,60,1,7 diff --git a/examples/datasets/foods4.csv b/examples/datasets/foods4.csv new file mode 100644 index 000000000000..c8e33560dd88 --- /dev/null +++ b/examples/datasets/foods4.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +vegetables,21,0,3 +seafood,158,5,4 +fruit,40,3,3 +vegetables,30,1,6 +meat,124,14,3 +fruit,63,3,11 +meat,114,5,0 +vegetables,32,0,5 +fruit,66,0,13 +seafood,200,14,0 +vegetables,26,0,4 +vegetables,25,0,5 +seafood,135,1.6,1 +seafood,144,5,2 +seafood,130,7,2 +vegetables,48,0.7,5 +fruit,136,2,28 +vegetables,32,0,8 +fruit,54,0,14 +seafood,105,5,3 +meat,103,5,3 +fruit,40,0,9 +seafood,205,7,4 +meat,100,5,0 +meat,110,5,0 +fruit,50,2.5,4 +seafood,213,7,5 diff --git a/examples/datasets/foods5.csv b/examples/datasets/foods5.csv new file mode 100644 index 000000000000..bc6efdfd9106 --- /dev/null +++ b/examples/datasets/foods5.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +seafood,142,6,3 +meat,99,4,0 +fruit,127,0,23 +vegetables,23,0,1 +vegetables,30,0,2 +meat,88,5,1 +fruit,55,0,8 +fruit,27,0,4 +seafood,127,1.3,1 +meat,123,11,2 +seafood,124,4,2 +seafood,204,4,4 +fruit,52,0,10 +fruit,58,0,14 +vegetables,18,0,4 +seafood,102,6,0 +seafood,210,11,0 +vegetables,23,0,5 +vegetables,26,0,0 +seafood,145,4,0 +meat,95,5,0 +fruit,34,0,2 +meat,48,2,0 +vegetables,37,0.4,1 +fruit,56,4.2,0 +vegetables,34,0,3 +seafood,180,6,1 diff --git a/examples/datasets/null_nutriscore.csv b/examples/datasets/null_nutriscore.csv new file mode 100644 index 000000000000..0f7922502bc2 --- /dev/null +++ b/examples/datasets/null_nutriscore.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g,nutri_score,proteins_g +seafood,117,9,0,,10 +seafood,201,6,1,,10 +fruit,59,1,14,,10 +meat,97,6,0,,10 +meat,124,12,1,,10 +meat,113,11,1,,10 +vegetables,30,1,1,,10 +seafood,191,6,1,,10 +vegetables,35,0.4,0,,10 +vegetables,21,0,2,,10 +seafood,121,1.5,0,,10 +seafood,125,5,1,,10 +vegetables,21,0,3,,10 +seafood,142,5,0,,10 +meat,118,7,1,,10 +fruit,61,0,12,,10 +fruit,33,1,4,,10 +vegetables,31,0,6,,10 +meat,109,7,2,,10 +vegetables,22,0,1,,10 +fruit,31,0,2,,10 +vegetables,22,0,2,,10 +seafood,155,5,0,,10 +fruit,133,0,27,,10 +seafood,205,9,0,,10 +fruit,72,4.5,7,,10 +fruit,60,1,7,,10 diff --git a/examples/datasets/pds_heads/customer.feather b/examples/datasets/pds_heads/customer.feather new file mode 100644 index 000000000000..1803394202dc Binary files /dev/null and b/examples/datasets/pds_heads/customer.feather differ diff --git a/examples/datasets/pds_heads/lineitem.feather b/examples/datasets/pds_heads/lineitem.feather new file mode 100644 index 000000000000..0b435926be21 Binary files /dev/null and b/examples/datasets/pds_heads/lineitem.feather differ diff --git a/examples/datasets/pds_heads/nation.feather b/examples/datasets/pds_heads/nation.feather new file mode 100644 index 000000000000..2dea971bcc60 Binary files /dev/null and b/examples/datasets/pds_heads/nation.feather differ diff --git a/examples/datasets/pds_heads/orders.feather b/examples/datasets/pds_heads/orders.feather new file mode 100644 index 000000000000..bd8544fdb579 Binary files /dev/null and b/examples/datasets/pds_heads/orders.feather differ diff --git a/examples/datasets/pds_heads/part.feather b/examples/datasets/pds_heads/part.feather new file mode 100644 index 000000000000..30f5aae4acf1 Binary files /dev/null and b/examples/datasets/pds_heads/part.feather differ diff --git a/examples/datasets/pds_heads/partsupp.feather b/examples/datasets/pds_heads/partsupp.feather new file mode 100644 index 000000000000..7c8026b36beb Binary files /dev/null and b/examples/datasets/pds_heads/partsupp.feather differ diff --git a/examples/datasets/pds_heads/region.feather b/examples/datasets/pds_heads/region.feather new file mode 100644 index 000000000000..0e3b85be007f Binary files /dev/null and b/examples/datasets/pds_heads/region.feather differ diff --git a/examples/datasets/pds_heads/supplier.feather b/examples/datasets/pds_heads/supplier.feather new file mode 100644 index 000000000000..48bac61d3972 Binary files /dev/null and b/examples/datasets/pds_heads/supplier.feather differ diff --git a/examples/iris_classifier/.gitignore b/examples/iris_classifier/.gitignore deleted file mode 100644 index f86308f07d59..000000000000 --- a/examples/iris_classifier/.gitignore +++ /dev/null @@ -1 +0,0 @@ -iris.csv \ No newline at end of file diff --git a/examples/iris_classifier/Cargo.toml b/examples/iris_classifier/Cargo.toml deleted file mode 100644 index 7a5abad16c30..000000000000 --- a/examples/iris_classifier/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "iris_classifier" -version = "0.1.0" -authors = ["ritchie46 "] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -polars = {path = "../../polars", features = ["random", "ndarray"]} -reqwest = {version = "0.10.8", features = ["blocking"]} -ndarray = "0.13" -itertools = "0.9" \ No newline at end of file diff --git a/examples/iris_classifier/src/main.rs b/examples/iris_classifier/src/main.rs deleted file mode 100644 index 12e64a87a58f..000000000000 --- a/examples/iris_classifier/src/main.rs +++ /dev/null @@ -1,222 +0,0 @@ -//! Running this program outputs: -//! -//! +----------+----------+----------+----------+---------------+ -//! | column_1 | column_2 | column_3 | column_4 | column_5 | -//! | --- | --- | --- | --- | --- | -//! | f64 | f64 | f64 | f64 | str | -//! +==========+==========+==========+==========+===============+ -//! | 5.1 | 3.5 | 1.4 | 0.2 | "Iris-setosa" | -//! +----------+----------+----------+----------+---------------+ -//! | 4.9 | 3 | 1.4 | 0.2 | "Iris-setosa" | -//! +----------+----------+----------+----------+---------------+ -//! | 4.7 | 3.2 | 1.3 | 0.2 | "Iris-setosa" | -//! +----------+----------+----------+----------+---------------+ -//! -//! +--------------+-------------+-------------+--------------+---------------+ -//! | sepal.length | sepal.width | petal.width | petal.length | class | -//! | --- | --- | --- | --- | --- | -//! | f64 | f64 | f64 | f64 | str | -//! +==============+=============+=============+==============+===============+ -//! | 5.1 | 3.5 | 1.4 | 0.2 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! | 4.9 | 3 | 1.4 | 0.2 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! | 4.7 | 3.2 | 1.3 | 0.2 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! -//! +--------------+-------------+-------------+--------------+---------------+ -//! | sepal.length | sepal.width | petal.width | petal.length | class | -//! | --- | --- | --- | --- | --- | -//! | f64 | f64 | f64 | f64 | str | -//! +==============+=============+=============+==============+===============+ -//! | 5.1 | 3.5 | 1.4 | 0.2 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! | 4.9 | 3 | 1.4 | 0.2 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! | 4.7 | 3.2 | 1.3 | 0.2 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! -//! +--------------+-------------+-------------+--------------+---------------+ -//! | sepal.length | sepal.width | petal.width | petal.length | class | -//! | --- | --- | --- | --- | --- | -//! | f64 | f64 | f64 | f64 | str | -//! +==============+=============+=============+==============+===============+ -//! | 0.006 | 0.008 | 0.002 | 0.001 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! | 0.006 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! | 0.005 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | -//! +--------------+-------------+-------------+--------------+---------------+ -//! -//! +--------------+-------------+-------------+--------------+---------------+-------------+ -//! | sepal.length | sepal.width | petal.width | petal.length | class | ohe | -//! | --- | --- | --- | --- | --- | --- | -//! | f64 | f64 | f64 | f64 | str | list [u32] | -//! +==============+=============+=============+==============+===============+=============+ -//! | 0.006 | 0.008 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" | -//! +--------------+-------------+-------------+--------------+---------------+-------------+ -//! | 0.006 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" | -//! +--------------+-------------+-------------+--------------+---------------+-------------+ -//! | 0.005 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" | -//! +--------------+-------------+-------------+--------------+---------------+-------------+ -//! -use itertools::Itertools; -use ndarray::prelude::*; -use polars::prelude::*; -use reqwest; -use std::fs::File; -use std::io::Write; -use std::path::Path; - -const FEATURES: [&str; 4] = ["sepal.length", "sepal.width", "petal.width", "petal.length"]; -const LEARNING_RATE: f64 = 0.01; - -fn download_iris() -> std::io::Result<()> { - let r = reqwest::blocking::get( - "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data", - ) - .expect("could not download iris"); - let mut f = File::create("iris.csv")?; - f.write_all(r.text().unwrap().as_bytes()) -} - -fn read_csv() -> Result { - let file = File::open("iris.csv").expect("could not read iris file"); - CsvReader::new(file) - .infer_schema(Some(100)) - .has_header(false) - .with_batch_size(100) - .finish() -} - -fn rename_cols(mut df: DataFrame) -> Result { - (0..5) - .zip(&[ - "sepal.length", - "sepal.width", - "petal.width", - "petal.length", - "class", - ]) - .for_each(|(idx, name)| { - df[idx].rename(name); - }); - - Ok(df) -} - -fn enforce_schema(mut df: DataFrame) -> Result { - let dtypes = &[ - ArrowDataType::Float64, - ArrowDataType::Float64, - ArrowDataType::Float64, - ArrowDataType::Float64, - ArrowDataType::Utf8, - ]; - - df.schema() - .clone() - .fields() - .iter() - .zip(dtypes) - .map(|(field, dtype)| { - if field.data_type() != dtype { - df.may_apply(field.name(), |col| match dtype { - ArrowDataType::Float64 => col.cast::(), - ArrowDataType::Utf8 => col.cast::(), - _ => return Err(PolarsError::Other("unexpected type".to_string())), - })?; - } - Ok(()) - }) - .collect::>()?; - Ok(df) -} - -fn normalize(mut df: DataFrame) -> Result { - let cols = &FEATURES; - - for &col in cols { - df.may_apply(col, |s| { - let ca = s.f64().unwrap(); - - match ca.sum() { - Some(sum) => Ok(ca / sum), - None => Err(PolarsError::Other("Nulls in column".to_string())), - } - })?; - } - Ok(df) -} - -fn one_hot_encode(mut df: DataFrame) -> Result { - let y = df["class"].utf8().unwrap(); - - let unique = y.unique(); - let n_unique = unique.len(); - - let mut ohe = y - .into_iter() - .map(|opt_s| { - let mut ohe = vec![0; n_unique]; - let mut idx = 0; - for i in 0..n_unique { - if unique.get(i) == opt_s { - idx = i; - break; - } - } - ohe[idx] = 1; - match opt_s { - Some(s) => UInt32Chunked::new_from_slice(s, &ohe).into_series(), - None => UInt32Chunked::new_from_slice("null", &ohe).into_series(), - } - }) - .collect::(); - ohe.rename("ohe"); - df.add_column(ohe)?; - - Ok(df) -} - -fn print_state(df: DataFrame) -> Result { - println!("{:?}", df.head(Some(3))); - Ok(df) -} - -fn pipe() -> Result { - read_csv()? - .pipe(print_state) - .unwrap() - .pipe(rename_cols) - .expect("could not rename columns") - .pipe(print_state) - .unwrap() - .pipe(enforce_schema) - .expect("could not enforce schema") - .pipe(print_state) - .unwrap() - .pipe(normalize)? - .pipe(print_state) - .unwrap() - .pipe(one_hot_encode) - .expect("could not ohe") - .pipe(print_state) -} -fn train(mut df: DataFrame) -> Result<()> { - let feat = df.select(&FEATURES)?.to_ndarray::()?; - - let target = df - .column("ohe")? - .large_list()? - .to_ndarray::()?; - todo!() -} - -fn main() { - if !Path::new("iris.csv").exists() { - download_iris().expect("could not create file") - } - - let df = pipe().expect("could not prepare DataFrame"); -} diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000000..9fef5abdd30e --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,228 @@ +# https://www.mkdocs.org/user-guide/configuration/ + +# Project information +site_name: Polars user guide +site_url: https://docs.pola.rs/ +repo_url: https://github.com/pola-rs/polars +repo_name: pola-rs/polars + +docs_dir: docs/source + +# Documentation layout +nav: + - Polars: + - User guide: + - index.md + - user-guide/getting-started.md + - user-guide/installation.md + - Concepts: + - user-guide/concepts/index.md + - user-guide/concepts/data-types-and-structures.md + - user-guide/concepts/expressions-and-contexts.md + - user-guide/concepts/lazy-api.md + - Expressions: + - user-guide/expressions/index.md + - user-guide/expressions/basic-operations.md + - user-guide/expressions/expression-expansion.md + - user-guide/expressions/casting.md + - user-guide/expressions/strings.md + - user-guide/expressions/lists-and-arrays.md + - user-guide/expressions/categorical-data-and-enums.md + - user-guide/expressions/structs.md + - user-guide/expressions/missing-data.md + - user-guide/expressions/aggregation.md + - user-guide/expressions/window-functions.md + - user-guide/expressions/folds.md + - user-guide/expressions/user-defined-python-functions.md + - user-guide/expressions/numpy-functions.md + - Transformations: + - user-guide/transformations/index.md + - user-guide/transformations/joins.md + - user-guide/transformations/concatenation.md + - user-guide/transformations/pivot.md + - user-guide/transformations/unpivot.md + - Time series: + - user-guide/transformations/time-series/parsing.md + - user-guide/transformations/time-series/filter.md + - user-guide/transformations/time-series/rolling.md + - user-guide/transformations/time-series/resampling.md + - user-guide/transformations/time-series/timezones.md + - Lazy API: + - user-guide/lazy/index.md + - user-guide/lazy/using.md + - user-guide/lazy/optimizations.md + - user-guide/lazy/schemas.md + - user-guide/lazy/query-plan.md + - user-guide/lazy/execution.md + - user-guide/lazy/sources_sinks.md + - user-guide/lazy/multiplexing.md + - user-guide/lazy/gpu.md + - IO: + - user-guide/io/index.md + - user-guide/io/csv.md + - user-guide/io/excel.md + - user-guide/io/parquet.md + - user-guide/io/json.md + - user-guide/io/multiple.md + - user-guide/io/hive.md + - user-guide/io/database.md + - user-guide/io/cloud-storage.md + - user-guide/io/bigquery.md + - user-guide/io/hugging-face.md + - user-guide/io/sheets_colab.md + - Plugins: + - user-guide/plugins/index.md + - user-guide/plugins/expr_plugins.md + - user-guide/plugins/io_plugins.md + - SQL: + - user-guide/sql/intro.md + - user-guide/sql/show.md + - user-guide/sql/select.md + - user-guide/sql/create.md + - user-guide/sql/cte.md + - Migrating: + - user-guide/migration/pandas.md + - user-guide/migration/spark.md + - user-guide/ecosystem.md + - Misc: + - user-guide/misc/multiprocessing.md + - user-guide/misc/visualization.md + - user-guide/misc/styling.md + - user-guide/misc/comparison.md + - user-guide/misc/arrow.md + - user-guide/misc/polars_llms.md + - user-guide/gpu-support.md + - API: + - api/reference.md + - Development: + - Contributing: + - development/contributing/index.md + - development/contributing/ide.md + - development/contributing/test.md + - development/contributing/ci.md + - development/contributing/code-style.md + - development/versioning.md + - Releases: + - releases/changelog.md + - Upgrade guides: + - releases/upgrade/index.md + - releases/upgrade/1.md + - releases/upgrade/0.20.md + - releases/upgrade/0.19.md + - Polars Cloud: + - polars-cloud/index.md + - polars-cloud/quickstart.md + - polars-cloud/connect-cloud.md + - User guide: + - polars-cloud/run/compute-context.md + - polars-cloud/run/workflow.md + - polars-cloud/run/interactive-batch.md + - polars-cloud/run/distributed-engine.md + - polars-cloud/run/service-accounts.md + - Manage workspace: + - polars-cloud/workspace/settings.md + - polars-cloud/workspace/team.md + - Authentication: + - polars-cloud/explain/authentication.md + - Providers: + - AWS: + - polars-cloud/providers/aws/infra.md + - polars-cloud/providers/aws/permissions.md + # - Integrations: + # - polars-cloud/integrations/airflow.md + - Misc: + - polars-cloud/cli.md + - polars-cloud/faq.md + - API Reference: https://docs.cloud.pola.rs + - polars-cloud/release-notes.md +not_in_nav: | + /_build/ +validation: + links: + # Allow an absolute link to the features page for our code snippets + absolute_links: ignore + +# Build directories +theme: + name: material + locale: en + custom_dir: docs/source/_build/overrides + palette: + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + primary: custom + toggle: + icon: material/brightness-7 + name: Switch to dark mode + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: custom + toggle: + icon: material/brightness-4 + name: Switch to light mode + logo: _build/assets/logo.png + features: + - navigation.tracking + - navigation.sections + - navigation.instant + - navigation.tabs + - navigation.tabs.sticky + - navigation.footer + - navigation.indexes + - content.tabs.link + - content.code.copy + icon: + repo: fontawesome/brands/github + +extra_javascript: + - _build/js/mathjax.js + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js +extra_css: + - _build/css/extra.css +extra: + analytics: + provider: plausible + domain: guide.pola.rs,combined.pola.rs + +# Preview controls +strict: true + +# Formatting options +markdown_extensions: + - admonition + - md_in_html + - pymdownx.details + - attr_list + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - pymdownx.superfences + - pymdownx.tabbed: + alternate_style: true + - pymdownx.snippets: + base_path: [".", "docs/source/src/"] + check_paths: true + dedent_subsections: true + - footnotes + - pymdownx.arithmatex: + generic: true + +hooks: + - docs/source/_build/scripts/people.py + +plugins: + - search: + lang: en + - markdown-exec + - material-plausible + - macros: + module_name: docs/source/_build/scripts/macro + - redirects: + redirect_maps: + "user-guide/index.md": "index.md" + "user-guide/basics/index.md": "user-guide/getting-started.md" + "user-guide/basics/reading-writing.md": "user-guide/getting-started.md" + "user-guide/basics/expressions.md": "user-guide/getting-started.md" + "user-guide/basics/joins.md": "user-guide/getting-started.md" diff --git a/pandas_cmp/Cargo.toml b/pandas_cmp/Cargo.toml deleted file mode 100644 index daa33db926f0..000000000000 --- a/pandas_cmp/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "pandas_cmp" -version = "0.1.0" -authors = ["ritchie46 "] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -polars = {path = "../polars"} -glob = "0.3.0" - diff --git a/pandas_cmp/bench_groupby.py b/pandas_cmp/bench_groupby.py deleted file mode 100644 index 22fd320de3b0..000000000000 --- a/pandas_cmp/bench_groupby.py +++ /dev/null @@ -1,22 +0,0 @@ -import datetime -import pandas as pd -import glob - -files = glob.glob("../data/1*.csv") -files.sort() -with open("../data/python_bench.txt", "w") as f: - with open("../data/python_bench_str.txt", "w") as f_str: - - for fn in files: - df = pd.read_csv(fn) - df = df.astype({"str": "str"}) - t0 = datetime.datetime.now() - res = df.groupby("groups").sum() - duration = datetime.datetime.now() - t0 - print(fn, duration.microseconds, res) - f.write(f"{duration.microseconds}\n") - - t0 = datetime.datetime.now() - res = df.groupby("str").sum() - duration = datetime.datetime.now() - t0 - f_str.write(f"{duration.microseconds}\n") diff --git a/pandas_cmp/bench_join.py b/pandas_cmp/bench_join.py deleted file mode 100644 index cd46d87cecb3..000000000000 --- a/pandas_cmp/bench_join.py +++ /dev/null @@ -1,42 +0,0 @@ -import pandas as pd -import datetime -import numpy as np - -left = pd.read_csv("../data/join_left_80000.csv") -right = pd.read_csv("../data/join_right_80000.csv") - -fw = open("../data/python_bench_join.txt", "w") - -durations = [] -for _ in range(10): - t0 = datetime.datetime.now() - joined = left.merge(right, on="key", how="inner") - duration = datetime.datetime.now() - t0 - durations.append(duration.microseconds) -mean = np.mean(durations) -fw.write(f"{mean}\n") -print("inner join {} μs".format(mean)) -print("shape:", joined.shape) - - -durations = [] -for _ in range(10): - t0 = datetime.datetime.now() - joined = left.merge(right, on="key", how="left") - duration = datetime.datetime.now() - t0 - durations.append(duration.microseconds) -mean = np.mean(durations) -fw.write(f"{mean}\n") -print("left join {} μs".format(mean)) -print("shape:", joined.shape) - -durations = [] -for _ in range(10): - t0 = datetime.datetime.now() - joined = left.merge(right, on="key", how="outer") - duration = datetime.datetime.now() - t0 - durations.append(duration.microseconds) -mean = np.mean(durations) -fw.write(f"{mean}\n") -print("outer join {} μs".format(mean)) -print("shape:", joined.shape) diff --git a/pandas_cmp/create_data.py b/pandas_cmp/create_data.py deleted file mode 100644 index 791598e8e8c9..000000000000 --- a/pandas_cmp/create_data.py +++ /dev/null @@ -1,35 +0,0 @@ -import pandas as pd -import numpy as np -from pandas.util.testing import rands - -groups = np.arange(10) -str_groups = np.array(list("0123456789")) -np.random.seed(1) - -for size in [1e2, 1e3, 1e4, 1e5, 1e6]: - size = int(size) - g = np.random.choice(groups, size) - sg = np.random.choice(str_groups, size) - v = np.random.randn(size) - df = pd.DataFrame({"groups": g, "values": v, "str": sg}) - df.to_csv(f"../data/{size}.csv", index=False) - -print("data created") - -# Join benchmark data -# https://wesmckinney.com/blog/high-performance-database-joins-with-pandas-dataframe-more-benchmarks/ -# https://github.com/wesm/pandas/blob/23669822819808bbaeb6ea36a6b2ef98026884db/bench/bench_merge_sqlite.py -N = 10000 -indices = np.array([rands(10) for _ in range(N)], dtype="O") -indices2 = np.array([rands(10) for _ in range(N)], dtype="O") -key = np.tile(indices[:8000], 10) -key2 = np.tile(indices2[:8000], 10) - -left = pd.DataFrame({"key": key, "key2": key2, "value": np.random.randn(80000)}) -right = pd.DataFrame( - {"key": indices[2000:], "key2": indices2[2000:], "value2": np.random.randn(8000)} -) - -left.to_csv("../data/join_left_80000.csv", index=False) -right.to_csv("../data/join_right_80000.csv", index=False) - diff --git a/pandas_cmp/img/groupby10_.png b/pandas_cmp/img/groupby10_.png deleted file mode 100644 index 81810ae60746..000000000000 Binary files a/pandas_cmp/img/groupby10_.png and /dev/null differ diff --git a/pandas_cmp/img/join_80_000.png b/pandas_cmp/img/join_80_000.png deleted file mode 100644 index e15af9ce16e0..000000000000 Binary files a/pandas_cmp/img/join_80_000.png and /dev/null differ diff --git a/pandas_cmp/plot_join_results.py b/pandas_cmp/plot_join_results.py deleted file mode 100644 index 87f8444c6038..000000000000 --- a/pandas_cmp/plot_join_results.py +++ /dev/null @@ -1,19 +0,0 @@ -import matplotlib.pyplot as plt -import pandas as pd - -with open("../data/python_bench_join.txt") as f: - pandas = [float(a) for a in f.read().split("\n")[:-1]] - -with open("../data/rust_bench_join.txt") as f: - polars = [float(a) for a in f.read().split("\n")[:-1]] - -x = ["inner", "left", "outer"] - -df = pd.DataFrame({"join_type": x, - "pandas": pandas, - "polars": polars}) - -df.plot.bar(x="join_type", figsize=(14, 6)) -plt.title("join on 80,000 rows") -plt.ylabel("time [μs]") -plt.savefig("img/join_80_000.png") diff --git a/pandas_cmp/plot_results.py b/pandas_cmp/plot_results.py deleted file mode 100644 index 0043938abeea..000000000000 --- a/pandas_cmp/plot_results.py +++ /dev/null @@ -1,25 +0,0 @@ -import matplotlib.pyplot as plt - -with open("../data/python_bench.txt") as f: - pandas = [int(a) for a in f.read().split("\n")[:-1]] -with open("../data/python_bench_str.txt") as f: - pandas_str = [int(a) for a in f.read().split("\n")[:-1]] - -with open("../data/rust_bench.txt") as f: - polars = [int(a) for a in f.read().split("\n")[:-1]] -with open("../data/rust_bench_str.txt") as f: - polars_str = [int(a) for a in f.read().split("\n")[:-1]] - -sizes = [1e2, 1e3, 1e4, 1e5, 1e6] - -plt.figure(figsize=(14, 6)) -plt.plot(sizes, pandas, label="pandas", color="C0") -plt.plot(sizes, polars, label="polars", color="C1") -plt.plot(sizes, pandas_str, label="pandas_str", color="C0", ls="-.") -plt.plot(sizes, polars_str, label="polars_str", color="C1", ls="-.") -plt.legend() -plt.title("Group by on 10 groups") -plt.ylabel("time [microseconds]") -plt.xscale("log") -plt.xlabel("dataset size") -plt.savefig("img/groupby10_.png") diff --git a/pandas_cmp/requirements.txt b/pandas_cmp/requirements.txt deleted file mode 100644 index 9bf5de4b52ae..000000000000 --- a/pandas_cmp/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -numpy==1.16.4 -pandas==1.0.3 -matplotlib==3.1.2 \ No newline at end of file diff --git a/pandas_cmp/src/main.rs b/pandas_cmp/src/main.rs deleted file mode 100644 index 8fcb71760f57..000000000000 --- a/pandas_cmp/src/main.rs +++ /dev/null @@ -1,136 +0,0 @@ -use glob::glob; -use polars::prelude::*; -use std::env; -use std::fs::{canonicalize, File}; -use std::io::Write; -use std::process::exit; -use std::time::Instant; - -fn read_df(f: &File) -> DataFrame { - let mut df = CsvReader::new(f) - .infer_schema(Some(100)) - .has_header(true) - .finish() - .expect("dataframe"); - - // for groupby we need to cast a column to a string - if let Ok(s) = df.column("str") { - let s = s - .i64() - .expect("i64") - .into_iter() - .map(|v| v.map(|v| format!("{}", v))); - let s: Series = Series::new("str", &s.collect::>()); - df.replace("str", s).expect("replaced"); - } - df -} - -fn bench_groupby() { - let paths = glob("../data/1*.csv") - .expect("valid glob") - .map(|v| v.expect("path")); - let mut paths = paths.collect::>(); - paths.sort(); - - let mut wrt_file = File::create("../data/rust_bench.txt").expect("file"); - let mut wrt_file_str = File::create("../data/rust_bench_str.txt").expect("file"); - - for p in &paths { - let f = File::open(p).expect("a csv file"); - let df = read_df(&f); - - let now = Instant::now(); - let sum = df.groupby("groups").expect("gb").select("values").sum(); - let duration = now.elapsed().as_micros(); - - let now = Instant::now(); - let sum_str = df.groupby("str").expect("gb").select("values").sum(); - let duration_str = now.elapsed().as_micros(); - println!("{:?}", (sum, sum_str)); - println!("{:?}", (p, duration)); - wrt_file - .write(&format!("{}\n", duration).as_bytes()) - .expect("write to file"); - wrt_file_str - .write(&format!("{}\n", duration_str).as_bytes()) - .expect("write to file"); - } -} - -fn bench_join() { - let f = File::open("../data/join_left_80000.csv").expect("file"); - let left = read_df(&f); - let f = File::open("../data/join_right_80000.csv").expect("file"); - let right = read_df(&f); - let mut wrt_file = File::create("../data/rust_bench_join.txt").expect("file"); - - let mut mean = 0.0; - for _ in 0..10 { - let now = Instant::now(); - let joined = left - .inner_join(&right, "key", "key") - .expect("could not join"); - let duration = now.elapsed().as_micros(); - mean += duration as f32 - } - mean /= 10.; - println!("inner join: {} μs", mean); - writeln!(wrt_file, "{}", mean).expect("could not write"); - - let mut mean = 0.0; - for _ in 0..10 { - let now = Instant::now(); - let joined = left - .left_join(&right, "key", "key") - .expect("could not join"); - let duration = now.elapsed().as_micros(); - mean += duration as f32 - } - mean /= 10.; - println!("left join: {} μs", mean); - writeln!(wrt_file, "{}", mean).expect("could not write"); - - let mut mean = 0.0; - for _ in 0..10 { - let now = Instant::now(); - let joined = left - .outer_join(&right, "key", "key") - .expect("could not join"); - let duration = now.elapsed().as_micros(); - mean += duration as f32 - } - mean /= 10.; - println!("outer join: {} μs", mean); - writeln!(wrt_file, "{}", mean).expect("could not write"); -} - -fn print_cli() { - println!( - " -cargo run [args] - -args: - groupby - join - " - ); -} - -fn main() { - let args = env::args().collect::>(); - if args.len() == 1 { - print_cli(); - exit(0) - } - println!("Current directory: {:?}", canonicalize(".")); - match &args[1][..] { - "groupby" => bench_groupby(), - "join" => bench_join(), - other => { - println!("got {}. expected:", other); - print_cli(); - exit(1) - } - } -} diff --git a/polars/Cargo.toml b/polars/Cargo.toml deleted file mode 100644 index c04e04dae700..000000000000 --- a/polars/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -[package] -name = "polars" -version = "0.6.0" -authors = ["ritchie46 "] -edition = "2018" -license = "MIT" -description = "DataFrame library" -repository = "https://github.com/ritchie46/polars" -readme = "../README.md" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[features] -pretty = ["prettytable-rs"] -simd = ["arrow/simd"] -docs = [] -temporal = ["chrono"] -random = ["rand", "rand_distr"] -parallel = [] -default = ["pretty", "docs", "temporal"] - -[dependencies] -arrow = {version = "1.0.1", default_features = false} -thiserror = "^1.0.16" -num = "^0.2.1" -fnv = "^1.0.7" -itertools = "^0.9.0" -unsafe_unwrap = "^0.1.0" -rayon = "^1.3.1" -prettytable-rs = { version="^0.8.0", features=["win_crlf"], optional = true, default_features = false} -chrono = {version = "^0.4.13", optional = true} -enum_dispatch = "^0.3.2" -parquet = {version = "1", optional = true} -rand = {version = "0.7", optional = true} -rand_distr = {version = "0.3", optional = true} -ndarray = {version = "0.13", optional = true} \ No newline at end of file diff --git a/polars/benches/bench.rs b/polars/benches/bench.rs deleted file mode 100644 index 8da93cb08416..000000000000 --- a/polars/benches/bench.rs +++ /dev/null @@ -1,52 +0,0 @@ -#![feature(test)] -extern crate test; -use polars::prelude::*; -use test::Bencher; - -#[bench] -fn bench_std_iter(b: &mut Bencher) { - let v: Vec = (0..1000).collect(); - let mut sum = 0; - b.iter(|| sum = v.iter().sum::()); - println!("{}", sum) -} - -#[bench] -fn bench_warmup(b: &mut Bencher) { - let s: Series = (0u32..1000).collect(); - b.iter(|| { - s.u32().unwrap().into_iter(); - }); -} - -#[bench] -fn bench_num_iter(b: &mut Bencher) { - let s: Series = (0u32..1000).collect(); - let mut sum = 0; - b.iter(|| { - sum = s - .u32() - .unwrap() - .into_iter() - .map(|opt| opt.unwrap()) - .sum::() - }); - println!("{}", sum) -} - -#[bench] -fn bench_num_2_chunks(b: &mut Bencher) { - let mut s: Series = (0u32..500).collect(); - let s2: Series = (500u32..1000).collect(); - s.append(&s2).unwrap(); - let mut sum = 0; - b.iter(|| { - sum = s - .u32() - .unwrap() - .into_iter() - .map(|opt| opt.unwrap()) - .sum::() - }); - println!("{}", sum) -} diff --git a/polars/src/chunked_array/aggregate.rs b/polars/src/chunked_array/aggregate.rs deleted file mode 100644 index 13d572980f65..000000000000 --- a/polars/src/chunked_array/aggregate.rs +++ /dev/null @@ -1,211 +0,0 @@ -//! Implementations of the ChunkAgg trait. -use crate::chunked_array::ChunkedArray; -use crate::datatypes::BooleanChunked; -use crate::{datatypes::PolarsNumericType, prelude::*}; -use arrow::compute; -use num::{Num, NumCast, ToPrimitive}; -use std::cmp::{Ordering, PartialOrd}; -use std::ops::{Add, Div}; - -macro_rules! cmp_float_with_nans { - ($a:expr, $b:expr, $precision:ty) => {{ - let a: $precision = NumCast::from($a).unwrap(); - let b: $precision = NumCast::from($b).unwrap(); - match (a.is_nan(), b.is_nan()) { - (true, true) => Ordering::Equal, - (true, false) => Ordering::Greater, - (false, true) => Ordering::Less, - (false, false) => a.partial_cmp(&b).unwrap(), - } - }}; -} - -macro_rules! agg_float_with_nans { - ($self:ident, $agg_method:ident, $precision:ty) => {{ - if $self.null_count() == 0 { - $self - .into_no_null_iter() - .$agg_method(|&a, &b| cmp_float_with_nans!(a, b, $precision)) - } else { - $self - .into_iter() - .filter(|opt| opt.is_some()) - .map(|opt| opt.unwrap()) - .$agg_method(|&a, &b| cmp_float_with_nans!(a, b, $precision)) - } - }}; -} - -impl ChunkAgg for ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + PartialOrd - + Div - + Num - + NumCast - + ToPrimitive, -{ - fn sum(&self) -> Option { - self.downcast_chunks() - .iter() - .map(|&a| compute::sum(a)) - .fold(None, |acc, v| match v { - Some(v) => match acc { - None => Some(v), - Some(acc) => Some(acc + v), - }, - None => acc, - }) - } - - fn min(&self) -> Option { - match T::get_data_type() { - ArrowDataType::Float32 => agg_float_with_nans!(self, min_by, f32), - ArrowDataType::Float64 => agg_float_with_nans!(self, min_by, f64), - _ => self - .downcast_chunks() - .iter() - .filter_map(|&a| compute::min(a)) - .fold_first(|acc, v| if acc > v { acc } else { v }), - } - } - - fn max(&self) -> Option { - match T::get_data_type() { - ArrowDataType::Float32 => agg_float_with_nans!(self, max_by, f32), - ArrowDataType::Float64 => agg_float_with_nans!(self, max_by, f64), - _ => self - .downcast_chunks() - .iter() - .filter_map(|&a| compute::max(a)) - .fold_first(|acc, v| if acc > v { acc } else { v }), - } - } - - fn mean(&self) -> Option { - let len = (self.len() - self.null_count()) as f64; - self.sum() - .map(|v| NumCast::from(v.to_f64().unwrap() / len).unwrap()) - } - - fn median(&self) -> Option { - let null_count = self.null_count(); - self.sort(false) - .slice((self.len() - null_count) / 2 + null_count, 1) - .unwrap() - .into_iter() - .next() - .unwrap() - } -} - -fn min_max_helper(ca: &BooleanChunked, min: bool) -> Option { - let min_max = ca.into_iter().fold(0, |acc: u8, x| match x { - Some(v) => { - let v = v as u8; - if min { - if acc < v { - acc - } else { - v - } - } else { - if acc > v { - acc - } else { - v - } - } - } - None => acc, - }); - Some(min_max) -} - -/// Booleans are casted to 1 or 0. -impl ChunkAgg for BooleanChunked { - /// Returns `None` if the array is empty or only contains null values. - fn sum(&self) -> Option { - if self.len() == 0 { - return None; - } - let sum = self.into_iter().fold(0, |acc: u8, x| match x { - Some(v) => acc + v as u8, - None => acc, - }); - Some(sum) - } - - fn min(&self) -> Option { - if self.len() == 0 { - return None; - } - min_max_helper(self, true) - } - - fn max(&self) -> Option { - if self.len() == 0 { - return None; - } - min_max_helper(self, false) - } - - fn mean(&self) -> Option { - let len = self.len() - self.null_count(); - self.sum().map(|v| (v as usize / len) as u8) - } - - fn median(&self) -> Option { - let null_count = self.null_count(); - let opt_v = self - .sort(false) - .slice((self.len() - null_count) / 2 + null_count, 1) - .unwrap() - .into_iter() - .next() - .unwrap(); - opt_v.map(|v| v as u8) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_agg_float() { - let ca1 = Float32Chunked::new_from_slice("a", &[1.0, f32::NAN]); - let ca2 = Float32Chunked::new_from_slice("b", &[f32::NAN, 1.0]); - assert_eq!(ca1.min(), ca2.min()); - let ca1 = Float64Chunked::new_from_slice("a", &[1.0, f64::NAN]); - let ca2 = Float64Chunked::new_from_slice("b", &[f64::NAN, 1.0]); - assert_eq!(ca1.min(), ca2.min()); - println!("{:?}", (ca1.min(), ca2.min())) - } - - #[test] - fn test_median() { - let ca = UInt32Chunked::new_from_opt_slice( - "a", - &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)], - ); - assert_eq!(ca.median(), Some(3)); - let ca = UInt32Chunked::new_from_opt_slice( - "a", - &[ - None, - Some(7), - Some(6), - Some(2), - Some(1), - None, - Some(3), - Some(5), - None, - Some(4), - ], - ); - assert_eq!(ca.median(), Some(4)); - } -} diff --git a/polars/src/chunked_array/apply.rs b/polars/src/chunked_array/apply.rs deleted file mode 100644 index c4e1187db0e0..000000000000 --- a/polars/src/chunked_array/apply.rs +++ /dev/null @@ -1,58 +0,0 @@ -//! Implementations of the ChunkApply Trait. -use crate::prelude::*; - -impl<'a, T> ChunkApply<'a, T::Native, T::Native> for ChunkedArray -where - T: PolarsNumericType, -{ - /// Chooses the fastest path for closure application. - /// Null values remain null. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn double(ca: &UInt32Chunked) -> UInt32Chunked { - /// ca.apply(|v| v * 2) - /// } - /// ``` - fn apply(&'a self, f: F) -> Self - where - F: Fn(T::Native) -> T::Native + Copy, - { - if let Ok(slice) = self.cont_slice() { - slice.iter().copied().map(f).map(Some).collect() - } else { - let mut ca: ChunkedArray = self - .data_views() - .iter() - .copied() - .zip(self.null_bits()) - .map(|(slice, (null_count, opt_buffer))| { - let vec: AlignedVec<_> = slice.iter().copied().map(f).collect(); - (vec, (null_count, opt_buffer)) - }) - .collect(); - ca.rename(self.name()); - ca - } - } -} - -impl<'a> ChunkApply<'a, bool, bool> for BooleanChunked { - fn apply(&self, f: F) -> Self - where - F: Fn(bool) -> bool + Copy, - { - self.into_iter().map(|opt_v| opt_v.map(|v| f(v))).collect() - } -} - -impl<'a> ChunkApply<'a, &'a str, String> for Utf8Chunked { - fn apply(&'a self, f: F) -> Self - where - F: Fn(&'a str) -> String, - { - self.into_iter().map(|opt_v| opt_v.map(|v| f(v))).collect() - } -} diff --git a/polars/src/chunked_array/arithmetic.rs b/polars/src/chunked_array/arithmetic.rs deleted file mode 100644 index 778257761c85..000000000000 --- a/polars/src/chunked_array/arithmetic.rs +++ /dev/null @@ -1,406 +0,0 @@ -//! Implementations of arithmetic operations on ChunkedArray's. -use crate::prelude::*; -use crate::utils::Xob; -use arrow::{array::ArrayRef, compute}; -use num::{Num, NumCast, ToPrimitive}; -use std::ops::{Add, Div, Mul, Sub}; -use std::sync::Arc; - -// TODO: Add Boolean arithmetic - -macro_rules! operand_on_primitive_arr { - ($_self:expr, $rhs:tt, $operator:expr, $expect:expr) => {{ - let mut new_chunks = Vec::with_capacity($_self.chunks.len()); - $_self - .downcast_chunks() - .iter() - .zip($rhs.downcast_chunks()) - .for_each(|(left, right)| { - let res = Arc::new($operator(left, right).expect($expect)) as ArrayRef; - new_chunks.push(res); - }); - $_self.copy_with_chunks(new_chunks) - }}; -} - -#[macro_export] -macro_rules! apply_operand_on_chunkedarray_by_iter { - - ($self:ident, $rhs:ident, $operand:tt) => { - { - match ($self.null_count(), $rhs.null_count()) { - (0, 0) => { - let a: Xob> = $self - .into_no_null_iter() - .zip($rhs.into_no_null_iter()) - .map(|(left, right)| left $operand right) - .collect(); - a.into_inner() - }, - (0, _) => { - $self - .into_no_null_iter() - .zip($rhs.into_iter()) - .map(|(left, opt_right)| opt_right.map(|right| left $operand right)) - .collect() - }, - (_, 0) => { - $self - .into_iter() - .zip($rhs.into_no_null_iter()) - .map(|(opt_left, right)| opt_left.map(|left| left $operand right)) - .collect() - }, - (_, _) => { - $self.into_iter() - .zip($rhs.into_iter()) - .map(|(opt_left, opt_right)| match (opt_left, opt_right) { - (None, None) => None, - (None, Some(_)) => None, - (Some(_), None) => None, - (Some(left), Some(right)) => Some(left $operand right), - }) - .collect() - - } - } - } - } -} - -// Operands on ChunkedArray & ChunkedArray - -impl Add for &ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn add(self, rhs: Self) -> Self::Output { - if self.chunk_id == rhs.chunk_id { - let expect_str = "Could not add, check data types and length"; - operand_on_primitive_arr![self, rhs, compute::add, expect_str] - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, +) - } - } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero - + num::One, -{ - type Output = ChunkedArray; - - fn div(self, rhs: Self) -> Self::Output { - if self.chunk_id == rhs.chunk_id { - let expect_str = "Could not divide, check data types and length"; - operand_on_primitive_arr!(self, rhs, compute::divide, expect_str) - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, /) - } - } -} - -impl Mul for &ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: Self) -> Self::Output { - if self.chunk_id == rhs.chunk_id { - let expect_str = "Could not multiply, check data types and length"; - operand_on_primitive_arr!(self, rhs, compute::multiply, expect_str) - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, *) - } - } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: Self) -> Self::Output { - if self.chunk_id == rhs.chunk_id { - let expect_str = "Could not subtract, check data types and length"; - operand_on_primitive_arr![self, rhs, compute::subtract, expect_str] - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, -) - } - } -} - -impl Add for ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - if self.chunk_id == rhs.chunk_id { - (&self).add(&rhs) - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, +) - } - } -} - -impl Div for ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero - + num::One, -{ - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - (&self).div(&rhs) - } -} - -impl Mul for ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - (&self).mul(&rhs) - } -} - -impl Sub for ChunkedArray -where - T: PolarsNumericType, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - (&self).sub(&rhs) - } -} - -// Operands on ChunkedArray & Num - -impl Add for &ChunkedArray -where - T: PolarsNumericType, - T::Native: NumCast, - N: Num + ToPrimitive, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn add(self, rhs: N) -> Self::Output { - let adder: T::Native = NumCast::from(rhs).unwrap(); - if self.is_optimal_aligned() { - let intermed: Xob<_> = self.into_no_null_iter().map(|val| val + adder).collect(); - intermed.into_inner() - } else { - self.into_iter() - .map(|opt_val| opt_val.map(|val| val + adder)) - .collect() - } - } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, - T::Native: NumCast, - N: Num + ToPrimitive, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: N) -> Self::Output { - let subber: T::Native = NumCast::from(rhs).unwrap(); - if self.is_optimal_aligned() { - let intermed: Xob<_> = self.into_no_null_iter().map(|val| val - subber).collect(); - intermed.into_inner() - } else { - self.into_iter() - .map(|opt_val| opt_val.map(|val| val - subber)) - .collect() - } - } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, - T::Native: NumCast, - N: Num + ToPrimitive, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn div(self, rhs: N) -> Self::Output { - let divider: T::Native = NumCast::from(rhs).unwrap(); - if self.is_optimal_aligned() { - let intermed: Xob<_> = self.into_no_null_iter().map(|val| val / divider).collect(); - intermed.into_inner() - } else { - self.into_iter() - .map(|opt_val| opt_val.map(|val| val / divider)) - .collect() - } - } -} - -impl Mul for &ChunkedArray -where - T: PolarsNumericType, - T::Native: NumCast, - N: Num + ToPrimitive, - T::Native: Add - + Sub - + Mul - + Div - + num::Zero, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: N) -> Self::Output { - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - if self.is_optimal_aligned() { - let intermed: Xob<_> = self - .into_no_null_iter() - .map(|val| val * multiplier) - .collect(); - intermed.into_inner() - } else { - self.into_iter() - .map(|opt_val| opt_val.map(|val| val * multiplier)) - .collect() - } - } -} - -pub trait Pow { - fn pow_f32(&self, exp: f32) -> Float32Chunked; - fn pow_f64(&self, exp: f64) -> Float64Chunked; -} - -macro_rules! power { - ($self:ident, $exp:ident, $to_primitive:ident, $return:ident) => {{ - if let Ok(slice) = $self.cont_slice() { - slice - .iter() - .map(|&val| val.$to_primitive().unwrap().powf($exp)) - .collect::>() - .into_inner() - } else { - $self - .into_iter() - .map(|val| val.map(|val| val.$to_primitive().unwrap().powf($exp))) - .collect() - } - }}; -} - -impl Pow for ChunkedArray -where - T: PolarsNumericType, - T::Native: ToPrimitive, -{ - fn pow_f32(&self, exp: f32) -> Float32Chunked { - power!(self, exp, to_f32, Float32Chunked) - } - - fn pow_f64(&self, exp: f64) -> Float64Chunked { - power!(self, exp, to_f64, Float64Chunked) - } -} - -#[cfg(test)] -pub(crate) mod test { - use crate::prelude::*; - - pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { - let mut a1 = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - let a2 = Int32Chunked::new_from_slice("a", &[4, 5, 6]); - let a3 = Int32Chunked::new_from_slice("a", &[1, 2, 3, 4, 5, 6]); - a1.append(&a2); - (a1, a3) - } - - #[test] - fn test_chunk_mismatch() { - let (a1, a2) = create_two_chunked(); - // with different chunks - let _ = &a1 + &a2; - let _ = &a1 - &a2; - let _ = &a1 / &a2; - let _ = &a1 * &a2; - - // with same chunks - let _ = &a1 + &a1; - let _ = &a1 - &a1; - let _ = &a1 / &a1; - let _ = &a1 * &a1; - } - - #[test] - fn test_power() { - let a = UInt32Chunked::new_from_slice("", &[1, 2, 3]); - let b = a.pow_f64(2.); - println!("{:?}", b); - } -} diff --git a/polars/src/chunked_array/builder.rs b/polars/src/chunked_array/builder.rs deleted file mode 100644 index 7900f24cf2f7..000000000000 --- a/polars/src/chunked_array/builder.rs +++ /dev/null @@ -1,644 +0,0 @@ -use crate::prelude::*; -use crate::utils::get_iter_capacity; -use arrow::array::{ArrayBuilder, ArrayDataBuilder, ArrayRef}; -use arrow::datatypes::{ArrowPrimitiveType, Field, ToByteSlice}; -use arrow::{ - array::{Array, ArrayData, LargeListBuilder, PrimitiveArray, PrimitiveBuilder, StringBuilder}, - buffer::Buffer, - memory, - util::bit_util, -}; -use std::iter::FromIterator; -use std::marker::PhantomData; -use std::mem::ManuallyDrop; -use std::ops::{Deref, DerefMut}; -use std::sync::Arc; - -pub struct PrimitiveChunkedBuilder -where - T: ArrowPrimitiveType, -{ - pub builder: PrimitiveBuilder, - capacity: usize, - field: Field, -} - -impl PrimitiveChunkedBuilder -where - T: ArrowPrimitiveType, -{ - pub fn new(name: &str, capacity: usize) -> Self { - PrimitiveChunkedBuilder { - builder: PrimitiveBuilder::::new(capacity), - capacity, - field: Field::new(name, T::get_data_type(), true), - } - } - - /// Appends a value of type `T` into the builder - pub fn append_value(&mut self, v: T::Native) { - self.builder.append_value(v).expect("could not append"); - } - - /// Appends a null slot into the builder - pub fn append_null(&mut self) { - self.builder.append_null().expect("could not append"); - } - - /// Appends an `Option` into the builder - pub fn append_option(&mut self, v: Option) { - self.builder.append_option(v).expect("could not append"); - } - - pub fn finish(mut self) -> ChunkedArray { - let arr = Arc::new(self.builder.finish()); - let len = arr.len(); - ChunkedArray { - field: Arc::new(self.field), - chunks: vec![arr], - chunk_id: vec![len], - phantom: PhantomData, - } - } -} - -impl Deref for PrimitiveChunkedBuilder { - type Target = PrimitiveBuilder; - - fn deref(&self) -> &Self::Target { - &self.builder - } -} - -impl DerefMut for PrimitiveChunkedBuilder { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.builder - } -} - -pub type BooleanChunkedBuilder = PrimitiveChunkedBuilder; - -pub struct Utf8ChunkedBuilder { - pub builder: StringBuilder, - capacity: usize, - field: Field, -} - -impl Utf8ChunkedBuilder { - pub fn new(name: &str, capacity: usize) -> Self { - Utf8ChunkedBuilder { - builder: StringBuilder::new(capacity), - capacity, - field: Field::new(name, ArrowDataType::Utf8, true), - } - } - - /// Appends a value of type `T` into the builder - pub fn append_value>(&mut self, v: S) { - self.builder - .append_value(v.as_ref()) - .expect("could not append"); - } - - /// Appends a null slot into the builder - pub fn append_null(&mut self) { - self.builder.append_null().expect("could not append"); - } - - pub fn append_option>(&mut self, opt: Option) { - match opt { - Some(s) => self.append_value(s.as_ref()), - None => self.append_null(), - } - } - - pub fn finish(mut self) -> Utf8Chunked { - let arr = Arc::new(self.builder.finish()); - let len = arr.len(); - ChunkedArray { - field: Arc::new(self.field), - chunks: vec![arr], - chunk_id: vec![len], - phantom: PhantomData, - } - } -} - -impl Deref for Utf8ChunkedBuilder { - type Target = StringBuilder; - - fn deref(&self) -> &Self::Target { - &self.builder - } -} - -impl DerefMut for Utf8ChunkedBuilder { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.builder - } -} - -pub fn build_primitive_ca_with_opt(s: &[Option], name: &str) -> ChunkedArray -where - T: ArrowPrimitiveType, - T::Native: Copy, -{ - let mut builder = PrimitiveChunkedBuilder::new(name, s.len()); - for opt in s { - builder.append_option(*opt); - } - let ca = builder.finish(); - ca -} - -fn set_null_bits( - mut builder: ArrayDataBuilder, - null_bit_buffer: Option, - null_count: usize, - len: usize, -) -> ArrayDataBuilder { - if null_count > 0 { - let null_bit_buffer = - null_bit_buffer.expect("implementation error. Should not be None if null_count > 0"); - debug_assert!(null_count == len - bit_util::count_set_bits(null_bit_buffer.data())); - builder = builder - .null_count(null_count) - .null_bit_buffer(null_bit_buffer); - } - builder -} - -/// Take an existing slice and a null bitmap and construct an arrow array. -pub fn build_with_existing_null_bitmap_and_slice( - null_bit_buffer: Option, - null_count: usize, - values: &[T::Native], -) -> PrimitiveArray -where - T: ArrowPrimitiveType, -{ - let len = values.len(); - // See: - // https://docs.rs/arrow/0.16.0/src/arrow/array/builder.rs.html#314 - // TODO: make implementation for aligned owned vector for zero copy creation. - let builder = ArrayData::builder(T::get_data_type()) - .len(len) - .add_buffer(Buffer::from(values.to_byte_slice())); - - let builder = set_null_bits(builder, null_bit_buffer, null_count, len); - let data = builder.build(); - PrimitiveArray::::from(data) -} - -/// Get the null count and the null bitmap of the arrow array -pub fn get_bitmap(arr: &T) -> (usize, Option) { - let data = arr.data(); - ( - data.null_count(), - data.null_bitmap().as_ref().map(|bitmap| { - let buff = bitmap.buffer_ref(); - buff.clone() - }), - ) -} - -// Used in polars/src/chunked_array/apply.rs:24 to collect from aligned vecs and null bitmaps -impl FromIterator<(AlignedVec, (usize, Option))> for ChunkedArray -where - T: PolarsNumericType, -{ - fn from_iter, (usize, Option))>>( - iter: I, - ) -> Self { - let mut chunks = vec![]; - - for (values, (null_count, opt_buffer)) in iter { - let arr = aligned_vec_to_primitive_array::(values, opt_buffer, null_count); - chunks.push(Arc::new(arr) as ArrayRef) - } - ChunkedArray::new_from_chunks("from_iter", chunks) - } -} - -/// Returns the nearest number that is `>=` than `num` and is a multiple of 64 -#[inline] -pub fn round_upto_multiple_of_64(num: usize) -> usize { - round_upto_power_of_2(num, 64) -} - -/// Returns the nearest multiple of `factor` that is `>=` than `num`. Here `factor` must -/// be a power of 2. -fn round_upto_power_of_2(num: usize, factor: usize) -> usize { - debug_assert!(factor > 0 && (factor & (factor - 1)) == 0); - (num + (factor - 1)) & !(factor - 1) -} - -/// Take an owned Vec that is 64 byte aligned and create a zero copy PrimitiveArray -/// Can also take a null bit buffer into account. -pub fn aligned_vec_to_primitive_array( - values: AlignedVec, - null_bit_buffer: Option, - null_count: usize, -) -> PrimitiveArray { - let values = values.into_inner(); - let vec_len = values.len(); - - let me = ManuallyDrop::new(values); - let ptr = me.as_ptr() as *const u8; - let len = me.len() * std::mem::size_of::(); - let capacity = me.capacity() * std::mem::size_of::(); - debug_assert_eq!((ptr as usize) % 64, 0); - - let buffer = unsafe { Buffer::from_raw_parts(ptr, len, capacity) }; - - let builder = ArrayData::builder(T::get_data_type()) - .len(vec_len) - .add_buffer(buffer); - - let builder = set_null_bits(builder, null_bit_buffer, null_count, vec_len); - let data = builder.build(); - - PrimitiveArray::::from(data) -} - -pub trait AlignedAlloc { - fn with_capacity_aligned(size: usize) -> Vec; -} - -impl AlignedAlloc for Vec { - /// Create a new Vec where first bytes memory address has an alignment of 64 bytes, as described - /// by arrow spec. - /// Read more: - /// https://github.com/rust-ndarray/ndarray/issues/771 - fn with_capacity_aligned(size: usize) -> Vec { - // Can only have a zero copy to arrow memory if address of first byte % 64 == 0 - let t_size = std::mem::size_of::(); - let capacity = size * t_size; - let ptr = memory::allocate_aligned(capacity) as *mut T; - unsafe { Vec::from_raw_parts(ptr, 0, capacity) } - } -} - -pub struct AlignedVec(pub Vec); - -impl FromIterator for AlignedVec { - fn from_iter>(iter: I) -> Self { - let mut iter = iter.into_iter(); - let sh = iter.size_hint(); - let size = sh.1.unwrap_or(sh.0); - - let mut inner = Vec::with_capacity_aligned(size); - - while let Some(v) = iter.next() { - inner.push(v) - } - - // Iterator size hint wasn't correct and reallocation has occurred - assert!(inner.len() <= size); - AlignedVec(inner) - } -} - -impl AlignedVec { - pub fn new(v: Vec) -> Result { - if v.as_ptr() as usize % 64 != 0 { - Err(PolarsError::MemoryNotAligned) - } else { - Ok(AlignedVec(v)) - } - } - pub fn into_inner(self) -> Vec { - self.0 - } -} - -pub trait NewChunkedArray { - fn new_from_slice(name: &str, v: &[N]) -> Self; - fn new_from_opt_slice(name: &str, opt_v: &[Option]) -> Self; - - /// Create a new ChunkedArray from an iterator. - fn new_from_opt_iter(name: &str, it: impl Iterator>) -> Self; - - /// Create a new ChunkedArray from an iterator. - fn new_from_iter(name: &str, it: impl Iterator) -> Self; -} - -impl NewChunkedArray for ChunkedArray -where - T: ArrowPrimitiveType, -{ - fn new_from_slice(name: &str, v: &[T::Native]) -> Self { - let mut builder = PrimitiveChunkedBuilder::::new(name, v.len()); - v.iter().for_each(|&v| builder.append_value(v)); - builder.finish() - } - - fn new_from_opt_slice(name: &str, opt_v: &[Option]) -> Self { - let mut builder = PrimitiveChunkedBuilder::::new(name, opt_v.len()); - opt_v.iter().for_each(|&opt| builder.append_option(opt)); - builder.finish() - } - - fn new_from_opt_iter( - name: &str, - it: impl Iterator>, - ) -> ChunkedArray { - let mut builder = PrimitiveChunkedBuilder::new(name, get_iter_capacity(&it)); - it.for_each(|opt| builder.append_option(opt)); - builder.finish() - } - - /// Create a new ChunkedArray from an iterator. - fn new_from_iter(name: &str, it: impl Iterator) -> ChunkedArray { - let mut builder = PrimitiveChunkedBuilder::new(name, get_iter_capacity(&it)); - it.for_each(|opt| builder.append_value(opt)); - builder.finish() - } -} - -impl NewChunkedArray for Utf8Chunked -where - S: AsRef, -{ - fn new_from_slice(name: &str, v: &[S]) -> Self { - let mut builder = StringBuilder::new(v.len()); - v.into_iter().for_each(|val| { - builder - .append_value(val.as_ref()) - .expect("Could not append value"); - }); - - let field = Arc::new(Field::new(name, ArrowDataType::Utf8, true)); - - ChunkedArray { - field, - chunks: vec![Arc::new(builder.finish())], - chunk_id: vec![v.len()], - phantom: PhantomData, - } - } - - fn new_from_opt_slice(name: &str, opt_v: &[Option]) -> Self { - let mut builder = Utf8ChunkedBuilder::new(name, opt_v.len()); - - opt_v.iter().for_each(|opt| match opt { - Some(v) => builder.append_value(v.as_ref()), - None => builder.append_null(), - }); - builder.finish() - } - - fn new_from_opt_iter(name: &str, it: impl Iterator>) -> Self { - let mut builder = Utf8ChunkedBuilder::new(name, get_iter_capacity(&it)); - it.for_each(|opt| builder.append_option(opt)); - builder.finish() - } - - /// Create a new ChunkedArray from an iterator. - fn new_from_iter(name: &str, it: impl Iterator) -> Self { - let mut builder = Utf8ChunkedBuilder::new(name, get_iter_capacity(&it)); - it.for_each(|v| builder.append_value(v)); - builder.finish() - } -} - -pub trait LargListBuilderTrait { - fn append_opt_series(&mut self, opt_s: &Option); - fn append_series(&mut self, s: &Series); - fn finish(&mut self) -> LargeListChunked; -} - -pub struct LargeListPrimitiveChunkedBuilder -where - T: ArrowPrimitiveType, -{ - pub builder: LargeListBuilder>, - field: Field, -} - -macro_rules! append_opt_series { - ($self:ident, $opt_s: ident) => {{ - match $opt_s { - Some(s) => { - let data = s.array_data(); - $self - .builder - .values() - .append_data(&data) - .expect("should not fail"); - $self.builder.append(true).expect("should not fail"); - } - None => { - $self.builder.append(false).expect("should not fail"); - } - } - }}; -} - -macro_rules! append_series { - ($self:ident, $s: ident) => {{ - let data = $s.array_data(); - $self - .builder - .values() - .append_data(&data) - .expect("should not fail"); - $self.builder.append(true).expect("should not fail"); - }}; -} - -macro_rules! finish_largelist_builder { - ($self:ident) => {{ - let arr = Arc::new($self.builder.finish()); - let len = arr.len(); - LargeListChunked { - field: Arc::new($self.field.clone()), - chunks: vec![arr], - chunk_id: vec![len], - phantom: PhantomData, - } - }}; -} - -impl LargeListPrimitiveChunkedBuilder -where - T: ArrowPrimitiveType, -{ - pub fn new(name: &str, values_builder: PrimitiveBuilder, capacity: usize) -> Self { - let builder = LargeListBuilder::with_capacity(values_builder, capacity); - let field = Field::new( - name, - ArrowDataType::LargeList(Box::new(T::get_data_type())), - true, - ); - - LargeListPrimitiveChunkedBuilder { builder, field } - } - - pub fn append_slice(&mut self, opt_v: Option<&[T::Native]>) { - match opt_v { - Some(v) => { - self.builder - .values() - .append_slice(v) - .expect("could not append"); - self.builder.append(true).expect("should not fail"); - } - None => { - self.builder.append(false).expect("should not fail"); - } - } - } - pub fn append_opt_slice(&mut self, opt_v: Option<&[Option]>) { - match opt_v { - Some(v) => { - v.iter().for_each(|opt| { - self.builder - .values() - .append_option(*opt) - .expect("could not append") - }); - self.builder.append(true).expect("should not fail"); - } - None => { - self.builder.append(false).expect("should not fail"); - } - } - } - - pub fn append_null(&mut self) { - self.builder.append(false).expect("should not fail"); - } -} - -impl LargListBuilderTrait for LargeListPrimitiveChunkedBuilder -where - T: ArrowPrimitiveType, -{ - fn append_opt_series(&mut self, opt_s: &Option) { - append_opt_series!(self, opt_s) - } - - fn append_series(&mut self, s: &Series) { - append_series!(self, s); - } - - fn finish(&mut self) -> LargeListChunked { - finish_largelist_builder!(self) - } -} - -pub struct LargeListUtf8ChunkedBuilder { - builder: LargeListBuilder, - field: Field, -} - -impl LargeListUtf8ChunkedBuilder { - pub fn new(name: &str, values_builder: StringBuilder, capacity: usize) -> Self { - let builder = LargeListBuilder::with_capacity(values_builder, capacity); - let field = Field::new( - name, - ArrowDataType::LargeList(Box::new(ArrowDataType::Utf8)), - true, - ); - - LargeListUtf8ChunkedBuilder { builder, field } - } -} - -impl LargListBuilderTrait for LargeListUtf8ChunkedBuilder { - fn append_opt_series(&mut self, opt_s: &Option) { - append_opt_series!(self, opt_s) - } - - fn append_series(&mut self, s: &Series) { - append_series!(self, s); - } - - fn finish(&mut self) -> LargeListChunked { - finish_largelist_builder!(self) - } -} - -pub fn get_large_list_builder( - dt: &ArrowDataType, - capacity: usize, - name: &str, -) -> Box { - macro_rules! get_primitive_builder { - ($type:ty) => {{ - let values_builder = PrimitiveBuilder::<$type>::new(capacity); - let builder = LargeListPrimitiveChunkedBuilder::new(&name, values_builder, capacity); - Box::new(builder) - }}; - } - macro_rules! get_utf8_builder { - () => {{ - let values_builder = StringBuilder::new(capacity); - let builder = LargeListUtf8ChunkedBuilder::new(&name, values_builder, capacity); - Box::new(builder) - }}; - } - match_arrow_data_type_apply_macro!(dt, get_primitive_builder, get_utf8_builder) -} - -#[cfg(test)] -mod test { - use super::*; - use arrow::array::Int32Array; - - #[test] - fn test_existing_null_bitmap() { - let mut builder = PrimitiveBuilder::::new(3); - for val in &[Some(1), None, Some(2)] { - builder.append_option(*val).unwrap(); - } - let arr = builder.finish(); - let (null_count, buf) = get_bitmap(&arr); - - let new_arr = - build_with_existing_null_bitmap_and_slice::(buf, null_count, &[7, 8, 9]); - assert!(new_arr.is_valid(0)); - assert!(new_arr.is_null(1)); - assert!(new_arr.is_valid(2)); - } - - #[test] - fn from_vec() { - // Can only have a zero copy to arrow memory if address of first byte % 64 == 0 - let mut v = Vec::with_capacity_aligned(2); - v.push(1); - v.push(2); - - let ptr = v.as_ptr(); - assert_eq!((ptr as usize) % 64, 0); - let a = aligned_vec_to_primitive_array::(AlignedVec::new(v).unwrap(), None, 0); - assert_eq!(a.value_slice(0, 2), &[1, 2]) - } - - #[test] - fn test_list_builder() { - let values_builder = Int32Array::builder(10); - let mut builder = LargeListPrimitiveChunkedBuilder::new("a", values_builder, 10); - - // create a series containing two chunks - let mut s1 = Int32Chunked::new_from_slice("a", &[1, 2, 3]).into_series(); - let s2 = Int32Chunked::new_from_slice("b", &[4, 5, 6]).into_series(); - s1.append(&s2).unwrap(); - - builder.append_series(&s1); - builder.append_series(&s2); - let ls = builder.finish(); - if let AnyType::LargeList(s) = ls.get_any(0) { - // many chunks are aggregated to one in the ListArray - assert_eq!(s.len(), 6) - } else { - assert!(false) - } - if let AnyType::LargeList(s) = ls.get_any(1) { - assert_eq!(s.len(), 3) - } else { - assert!(false) - } - } -} diff --git a/polars/src/chunked_array/cast.rs b/polars/src/chunked_array/cast.rs deleted file mode 100644 index edaa88b3ecf5..000000000000 --- a/polars/src/chunked_array/cast.rs +++ /dev/null @@ -1,56 +0,0 @@ -//! Implementations of the ChunkCast Trait. -use crate::prelude::*; -use arrow::compute; - -fn cast_ca(ca: &ChunkedArray) -> Result> -where - N: PolarsDataType, - T: PolarsDataType, -{ - match N::get_data_type() { - // only i32 can be cast to Date32 - ArrowDataType::Date32(DateUnit::Day) => { - if T::get_data_type() != ArrowDataType::Int32 { - let casted_i32 = cast_ca::(ca)?; - return cast_ca(&casted_i32); - } - } - _ => (), - }; - let chunks = ca - .chunks - .iter() - .map(|arr| compute::cast(arr, &N::get_data_type())) - .collect::>>()?; - - Ok(ChunkedArray::::new_from_chunks(ca.field.name(), chunks)) -} - -impl ChunkCast for ChunkedArray -where - T: PolarsNumericType, -{ - fn cast(&self) -> Result> - where - N: PolarsDataType, - { - cast_ca(self) - } -} - -macro_rules! impl_chunkcast { - ($ca_type:ident) => { - impl ChunkCast for $ca_type { - fn cast(&self) -> Result> - where - N: PolarsDataType, - { - cast_ca(self) - } - } - }; -} - -impl_chunkcast!(Utf8Chunked); -impl_chunkcast!(BooleanChunked); -impl_chunkcast!(LargeListChunked); diff --git a/polars/src/chunked_array/chunkops.rs b/polars/src/chunked_array/chunkops.rs deleted file mode 100644 index 0d20af07b2cd..000000000000 --- a/polars/src/chunked_array/chunkops.rs +++ /dev/null @@ -1,179 +0,0 @@ -use crate::chunked_array::builder::get_large_list_builder; -use crate::prelude::*; -use arrow::array::{Array, ArrayRef, PrimitiveBuilder, StringBuilder}; -use std::sync::Arc; - -// TODO: Test rechunking properly - -pub trait ChunkOps { - /// Aggregate to chunk id. - /// A chunk id is a vector of the chunk lengths. - fn rechunk(&self, chunk_lengths: Option<&[usize]>) -> Result - where - Self: std::marker::Sized; - /// Only rechunk if lhs and rhs don't match - fn optional_rechunk(&self, rhs: &ChunkedArray) -> Result> - where - Self: std::marker::Sized; -} - -macro_rules! optional_rechunk { - ($self:tt, $rhs:tt) => { - if $self.chunk_id != $rhs.chunk_id { - // we can rechunk ourselves to match - $self.rechunk(Some(&$rhs.chunk_id)).map(|v| Some(v)) - } else { - Ok(None) - } - }; -} - -#[inline] -fn mimic_chunks(arr: &ArrayRef, chunk_lengths: &[usize], name: &str) -> Result> -where - T: PolarsDataType, - ChunkedArray: ChunkOps, -{ - let mut chunks = Vec::with_capacity(chunk_lengths.len()); - let mut offset = 0; - for chunk_length in chunk_lengths { - chunks.push(arr.slice(offset, *chunk_length)); - offset += *chunk_length - } - Ok(ChunkedArray::new_from_chunks(name, chunks)) -} - -impl ChunkOps for ChunkedArray -where - T: PolarsNumericType, -{ - fn rechunk(&self, chunk_lengths: Option<&[usize]>) -> Result { - // we aggregate to 1 or chunk_id - match (self.chunks.len(), chunk_lengths.map(|v| v.len())) { - // No rechunking needed. - (1, Some(1)) | (1, None) => Ok(self.clone()), - // Left contains a single chunk. We can cheaply mimic right as arrow slices are zero copy - (1, Some(_)) => mimic_chunks(&self.chunks[0], chunk_lengths.unwrap(), self.name()), - // Left will be aggregated to match right - (_, Some(_)) | (_, None) => { - let default = &[self.len()]; - let chunk_id = chunk_lengths.unwrap_or(default); - let mut iter = self.into_iter(); - let mut chunks: Vec> = Vec::with_capacity(chunk_id.len()); - - for &chunk_length in chunk_id { - let mut builder = PrimitiveBuilder::::new(chunk_length); - - for _ in 0..chunk_length { - builder.append_option( - iter.next() - .expect("the first option is the iterator bounds"), - )?; - } - chunks.push(Arc::new(builder.finish())) - } - Ok(ChunkedArray::new_from_chunks(self.name(), chunks)) - } - } - } - - fn optional_rechunk(&self, rhs: &ChunkedArray) -> Result> { - optional_rechunk!(self, rhs) - } -} - -impl ChunkOps for BooleanChunked { - fn rechunk(&self, chunk_lengths: Option<&[usize]>) -> Result { - match (self.chunks.len(), chunk_lengths.map(|v| v.len())) { - (1, Some(1)) | (1, None) => Ok(self.clone()), - (1, Some(_)) => mimic_chunks(&self.chunks[0], chunk_lengths.unwrap(), self.name()), - (_, Some(_)) | (_, None) => { - let default = &[self.len()]; - let chunk_id = chunk_lengths.unwrap_or(default); - - let mut iter = self.into_iter(); - let mut chunks: Vec> = Vec::with_capacity(chunk_id.len()); - - for &chunk_length in chunk_id { - let mut builder = PrimitiveBuilder::::new(chunk_length); - - for _ in 0..chunk_length { - builder.append_option( - iter.next() - .expect("the first option is the iterator bounds"), - )?; - } - chunks.push(Arc::new(builder.finish())) - } - Ok(ChunkedArray::new_from_chunks(self.name(), chunks)) - } - } - } - - fn optional_rechunk(&self, rhs: &ChunkedArray) -> Result> { - optional_rechunk!(self, rhs) - } -} - -impl ChunkOps for Utf8Chunked { - fn rechunk(&self, chunk_lengths: Option<&[usize]>) -> Result { - match (self.chunks.len(), chunk_lengths.map(|v| v.len())) { - (1, Some(1)) | (1, None) => Ok(self.clone()), - (1, Some(_)) => mimic_chunks(&self.chunks[0], chunk_lengths.unwrap(), self.name()), - (_, Some(_)) | (_, None) => { - let default = &[self.len()]; - let chunk_id = chunk_lengths.unwrap_or(default); - let mut iter = self.into_iter(); - let mut chunks: Vec> = Vec::with_capacity(chunk_id.len()); - - for &chunk_length in chunk_id { - let mut builder = StringBuilder::new(chunk_length); - - for _ in 0..chunk_length { - let opt_val = iter.next().expect("first option is iterator bounds"); - match opt_val { - None => builder.append_null().expect("should not fail"), - Some(val) => builder.append_value(val).expect("should not fail"), - } - } - chunks.push(Arc::new(builder.finish())) - } - Ok(ChunkedArray::new_from_chunks(self.name(), chunks)) - } - } - } - - fn optional_rechunk(&self, rhs: &ChunkedArray) -> Result> { - optional_rechunk!(self, rhs) - } -} -impl ChunkOps for LargeListChunked { - fn rechunk(&self, chunk_lengths: Option<&[usize]>) -> Result { - match (self.chunks.len(), chunk_lengths.map(|v| v.len())) { - (1, Some(1)) | (1, None) => Ok(self.clone()), - (1, Some(_)) => mimic_chunks(&self.chunks[0], chunk_lengths.unwrap(), self.name()), - (_, Some(_)) | (_, None) => { - let default = &[self.len()]; - let chunk_id = chunk_lengths.unwrap_or(default); - let mut iter = self.into_iter(); - let mut chunks: Vec> = Vec::with_capacity(chunk_id.len()); - - for &chunk_length in chunk_id { - let mut builder = - get_large_list_builder(self.dtype(), chunk_length, self.name()); - while let Some(v) = iter.next() { - builder.append_opt_series(&v) - } - let list = builder.finish(); - // cheap clone of Arc - chunks.push(list.chunks[0].clone()) - } - Ok(ChunkedArray::new_from_chunks(self.name(), chunks)) - } - } - } - - fn optional_rechunk(&self, rhs: &ChunkedArray) -> Result> { - optional_rechunk!(self, rhs) - } -} diff --git a/polars/src/chunked_array/comparison.rs b/polars/src/chunked_array/comparison.rs deleted file mode 100644 index b55e42001d0f..000000000000 --- a/polars/src/chunked_array/comparison.rs +++ /dev/null @@ -1,730 +0,0 @@ -use crate::{prelude::*, utils::Xob}; -use arrow::{ - array::{ArrayRef, BooleanArray, PrimitiveArray, StringArray}, - compute, -}; -use num::{Num, NumCast, ToPrimitive}; -use std::ops::{BitAnd, BitOr, Not}; -use std::sync::Arc; - -impl ChunkedArray -where - T: PolarsNumericType, -{ - /// First ensure that the chunks of lhs and rhs match and then iterates over the chunks and applies - /// the comparison operator. - fn comparison( - &self, - rhs: &ChunkedArray, - operator: impl Fn(&PrimitiveArray, &PrimitiveArray) -> arrow::error::Result, - ) -> Result { - let chunks = self - .downcast_chunks() - .iter() - .zip(rhs.downcast_chunks()) - .map(|(left, right)| { - let arr_res = operator(left, right); - let arr = match arr_res { - Ok(arr) => arr, - Err(e) => return Err(PolarsError::ArrowError(e)), - }; - Ok(Arc::new(arr) as ArrayRef) - }) - .collect::>>()?; - - Ok(ChunkedArray::new_from_chunks("", chunks)) - } -} - -macro_rules! impl_eq_missing { - ($self:ident, $rhs:ident) => {{ - match ($self.null_count(), $rhs.null_count()) { - (0, 0) => $self - .into_no_null_iter() - .zip($rhs.into_no_null_iter()) - .map(|(opt_a, opt_b)| opt_a == opt_b) - .collect(), - (_, _) => $self - .into_iter() - .zip($rhs) - .map(|(opt_a, opt_b)| opt_a == opt_b) - .collect(), - } - }}; -} - -impl ChunkCompare<&ChunkedArray> for ChunkedArray -where - T: PolarsNumericType, -{ - fn eq_missing(&self, rhs: &ChunkedArray) -> BooleanChunked { - impl_eq_missing!(self, rhs) - } - - fn eq(&self, rhs: &ChunkedArray) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - // should not fail if arrays are equal - self.comparison(rhs, compute::eq).expect("should not fail.") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, ==) - } - } - - fn neq(&self, rhs: &ChunkedArray) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::neq) - .expect("should not fail.") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, !=) - } - } - - fn gt(&self, rhs: &ChunkedArray) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::gt).expect("should not fail.") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, >) - } - } - - fn gt_eq(&self, rhs: &ChunkedArray) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::gt_eq) - .expect("should not fail.") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, >=) - } - } - - fn lt(&self, rhs: &ChunkedArray) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::lt).expect("should not fail.") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, <) - } - } - - fn lt_eq(&self, rhs: &ChunkedArray) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::lt_eq) - .expect("should not fail.") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, <=) - } - } -} - -impl ChunkCompare<&BooleanChunked> for BooleanChunked { - fn eq_missing(&self, rhs: &BooleanChunked) -> BooleanChunked { - impl_eq_missing!(self, rhs) - } - - fn eq(&self, rhs: &BooleanChunked) -> BooleanChunked { - apply_operand_on_chunkedarray_by_iter!(self, rhs, ==) - } - - fn neq(&self, rhs: &BooleanChunked) -> BooleanChunked { - apply_operand_on_chunkedarray_by_iter!(self, rhs, !=) - } - - fn gt(&self, rhs: &BooleanChunked) -> BooleanChunked { - apply_operand_on_chunkedarray_by_iter!(self, rhs, >) - } - - fn gt_eq(&self, rhs: &BooleanChunked) -> BooleanChunked { - apply_operand_on_chunkedarray_by_iter!(self, rhs, >=) - } - - fn lt(&self, rhs: &BooleanChunked) -> BooleanChunked { - apply_operand_on_chunkedarray_by_iter!(self, rhs, <) - } - - fn lt_eq(&self, rhs: &BooleanChunked) -> BooleanChunked { - apply_operand_on_chunkedarray_by_iter!(self, rhs, <=) - } -} - -impl Utf8Chunked { - fn comparison( - &self, - rhs: &Utf8Chunked, - operator: impl Fn(&StringArray, &StringArray) -> arrow::error::Result, - ) -> Result { - let chunks = self - .chunks - .iter() - .zip(&rhs.chunks) - .map(|(left, right)| { - let left = left - .as_any() - .downcast_ref::() - .expect("could not downcast one of the chunks"); - let right = right - .as_any() - .downcast_ref::() - .expect("could not downcast one of the chunks"); - let arr_res = operator(left, right); - let arr = match arr_res { - Ok(arr) => arr, - Err(e) => return Err(PolarsError::ArrowError(e)), - }; - Ok(Arc::new(arr) as ArrayRef) - }) - .collect::>>()?; - - Ok(ChunkedArray::new_from_chunks("", chunks)) - } -} - -impl ChunkCompare<&Utf8Chunked> for Utf8Chunked { - fn eq_missing(&self, rhs: &Utf8Chunked) -> BooleanChunked { - impl_eq_missing!(self, rhs) - } - - fn eq(&self, rhs: &Utf8Chunked) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::eq_utf8) - .expect("should not fail") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, ==) - } - } - - fn neq(&self, rhs: &Utf8Chunked) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::neq_utf8) - .expect("should not fail") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, !=) - } - } - - fn gt(&self, rhs: &Utf8Chunked) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::gt_utf8) - .expect("should not fail") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, >) - } - } - - fn gt_eq(&self, rhs: &Utf8Chunked) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::gt_eq_utf8) - .expect("should not fail") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, >=) - } - } - - fn lt(&self, rhs: &Utf8Chunked) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::lt_utf8) - .expect("should not fail") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, <) - } - } - - fn lt_eq(&self, rhs: &Utf8Chunked) -> BooleanChunked { - if self.chunk_id == rhs.chunk_id { - self.comparison(rhs, compute::lt_eq_utf8) - .expect("should not fail") - } else { - apply_operand_on_chunkedarray_by_iter!(self, rhs, <=) - } - } -} - -fn cmp_chunked_array_to_num( - ca: &ChunkedArray, - cmp_fn: &dyn Fn(Option) -> bool, -) -> BooleanChunked -where - T: PolarsNumericType, -{ - ca.into_iter().map(cmp_fn).collect() -} - -pub trait NumComp: Num + NumCast + PartialOrd {} - -impl NumComp for f32 {} -impl NumComp for f64 {} -impl NumComp for i8 {} -impl NumComp for i16 {} -impl NumComp for i32 {} -impl NumComp for i64 {} -impl NumComp for u8 {} -impl NumComp for u16 {} -impl NumComp for u32 {} -impl NumComp for u64 {} - -impl ChunkCompare for ChunkedArray -where - T: PolarsNumericType, - T::Native: NumCast, - Rhs: NumComp + ToPrimitive, -{ - fn eq_missing(&self, rhs: Rhs) -> BooleanChunked { - self.eq(rhs) - } - - fn eq(&self, rhs: Rhs) -> BooleanChunked { - let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type"); - cmp_chunked_array_to_num(self, &|lhs: Option| lhs == Some(rhs)) - } - - fn neq(&self, rhs: Rhs) -> BooleanChunked { - let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type"); - cmp_chunked_array_to_num(self, &|lhs: Option| lhs != Some(rhs)) - } - - fn gt(&self, rhs: Rhs) -> BooleanChunked { - let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type"); - cmp_chunked_array_to_num(self, &|lhs: Option| lhs > Some(rhs)) - } - - fn gt_eq(&self, rhs: Rhs) -> BooleanChunked { - let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type"); - cmp_chunked_array_to_num(self, &|lhs: Option| lhs >= Some(rhs)) - } - - fn lt(&self, rhs: Rhs) -> BooleanChunked { - let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type"); - cmp_chunked_array_to_num(self, &|lhs: Option| lhs < Some(rhs)) - } - - fn lt_eq(&self, rhs: Rhs) -> BooleanChunked { - let rhs = NumCast::from(rhs).expect("could not cast to underlying chunkedarray type"); - cmp_chunked_array_to_num(self, &|lhs: Option| lhs <= Some(rhs)) - } -} - -fn cmp_utf8chunked_to_str(ca: &Utf8Chunked, cmp_fn: &dyn Fn(&str) -> bool) -> BooleanChunked { - ca.into_iter() - .map(|opt_s| match opt_s { - None => false, - Some(s) => cmp_fn(s), - }) - .collect() -} - -impl ChunkCompare<&str> for Utf8Chunked { - fn eq_missing(&self, rhs: &str) -> BooleanChunked { - self.eq(rhs) - } - - fn eq(&self, rhs: &str) -> BooleanChunked { - cmp_utf8chunked_to_str(self, &|lhs| lhs == rhs) - } - fn neq(&self, rhs: &str) -> BooleanChunked { - cmp_utf8chunked_to_str(self, &|lhs| lhs != rhs) - } - - fn gt(&self, rhs: &str) -> BooleanChunked { - cmp_utf8chunked_to_str(self, &|lhs| lhs > rhs) - } - - fn gt_eq(&self, rhs: &str) -> BooleanChunked { - cmp_utf8chunked_to_str(self, &|lhs| lhs >= rhs) - } - - fn lt(&self, rhs: &str) -> BooleanChunked { - cmp_utf8chunked_to_str(self, &|lhs| lhs < rhs) - } - - fn lt_eq(&self, rhs: &str) -> BooleanChunked { - cmp_utf8chunked_to_str(self, &|lhs| lhs <= rhs) - } -} - -macro_rules! impl_cmp_largelist { - ($self:ident, $rhs:ident, $cmp_method:ident) => {{ - match ($self.null_count(), $rhs.null_count()) { - (0, 0) => { - $self - .into_no_null_iter() - .zip($rhs.into_no_null_iter()) - // TODO: use Xob to get rid of redundant Some - .map(|(left, right)| Some(left.$cmp_method(&right))) - .collect() - } - (0, _) => $self - .into_no_null_iter() - .zip($rhs.into_iter()) - .map(|(left, opt_right)| opt_right.map(|right| left.$cmp_method(&right))) - .collect(), - (_, 0) => $self - .into_iter() - .zip($rhs.into_no_null_iter()) - .map(|(opt_left, right)| opt_left.map(|left| left.$cmp_method(&right))) - .collect(), - (_, _) => $self - .into_iter() - .zip($rhs.into_iter()) - .map(|(opt_left, opt_right)| match (opt_left, opt_right) { - (None, None) => None, - (None, Some(_)) => None, - (Some(_), None) => None, - (Some(left), Some(right)) => Some(left.$cmp_method(&right)), - }) - .collect(), - } - }}; -} - -impl ChunkCompare<&LargeListChunked> for LargeListChunked { - fn eq_missing(&self, rhs: &LargeListChunked) -> BooleanChunked { - impl_cmp_largelist!(self, rhs, series_equal_missing) - } - - fn eq(&self, rhs: &LargeListChunked) -> BooleanChunked { - impl_cmp_largelist!(self, rhs, series_equal) - } - - fn neq(&self, rhs: &LargeListChunked) -> BooleanChunked { - self.eq(rhs).not() - } - - // following are not implemented because gt, lt comparison of series don't make sense - fn gt(&self, _rhs: &LargeListChunked) -> BooleanChunked { - unimplemented!() - } - - fn gt_eq(&self, _rhs: &LargeListChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt(&self, _rhs: &LargeListChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt_eq(&self, _rhs: &LargeListChunked) -> BooleanChunked { - unimplemented!() - } -} - -impl BooleanChunked { - /// First ensure that the chunks of lhs and rhs match and then iterates over the chunks and applies - /// the comparison operator. - fn bit_operation( - &self, - rhs: &BooleanChunked, - operator: impl Fn(&BooleanArray, &BooleanArray) -> arrow::error::Result, - ) -> Result { - let chunks = self - .downcast_chunks() - .iter() - .zip(rhs.downcast_chunks()) - .map(|(left, right)| { - let arr_res = operator(left, right); - let arr = match arr_res { - Ok(arr) => arr, - Err(e) => return Err(PolarsError::ArrowError(e)), - }; - Ok(Arc::new(arr) as ArrayRef) - }) - .collect::>>()?; - - Ok(ChunkedArray::new_from_chunks("", chunks)) - } -} - -macro_rules! impl_bitwise_op { - ($self:ident, $rhs:ident, $arrow_method:ident, $op:tt) => {{ - if $self.chunk_id == $rhs.chunk_id { - let result = $self.bit_operation($rhs, compute::$arrow_method); - match result { - Ok(v) => return Ok(v), - Err(_) => (), - }; - }; - let ca = $self - .into_iter() - .zip($rhs.into_iter()) - .map(|(opt_left, opt_right)| match (opt_left, opt_right) { - (Some(left), Some(right)) => Some(left $op right), - _ => None, - }) - .collect(); - Ok(ca) - }} - -} - -impl BitOr for &BooleanChunked { - type Output = Result; - - fn bitor(self, rhs: Self) -> Self::Output { - impl_bitwise_op!(self, rhs, or, |) - } -} - -impl BitOr for BooleanChunked { - type Output = Result; - - fn bitor(self, rhs: Self) -> Self::Output { - (&self).bitor(&rhs) - } -} - -impl BitAnd for &BooleanChunked { - type Output = Result; - - fn bitand(self, rhs: Self) -> Self::Output { - impl_bitwise_op!(self, rhs, and, &) - } -} - -impl BitAnd for BooleanChunked { - type Output = Result; - - fn bitand(self, rhs: Self) -> Self::Output { - (&self).bitand(&rhs) - } -} - -impl Not for &BooleanChunked { - type Output = BooleanChunked; - - fn not(self) -> Self::Output { - let chunks = self - .downcast_chunks() - .iter() - .map(|a| { - let arr = compute::not(a).expect("should not fail"); - Arc::new(arr) as ArrayRef - }) - .collect::>(); - ChunkedArray::new_from_chunks(self.name(), chunks) - } -} - -impl Not for BooleanChunked { - type Output = BooleanChunked; - - fn not(self) -> Self::Output { - (&self).not() - } -} - -#[cfg(test)] -mod test { - use super::super::{arithmetic::test::create_two_chunked, test::get_chunked_array}; - use crate::prelude::*; - use itertools::Itertools; - use std::iter::repeat; - - #[test] - fn test_bitwise_ops() { - let a = BooleanChunked::new_from_slice("a", &[true, false, false]); - let b = BooleanChunked::new_from_opt_slice("b", &[Some(true), Some(true), None]); - assert_eq!( - Vec::from((&a | &b).unwrap()), - &[Some(true), Some(true), None] - ); - assert_eq!( - Vec::from((&a & &b).unwrap()), - &[Some(true), Some(false), None] - ); - assert_eq!(Vec::from(!b), &[Some(false), Some(false), None]); - } - - #[test] - fn test_compare_chunk_diff() { - let (a1, a2) = create_two_chunked(); - - assert_eq!( - a1.eq(&a2).into_iter().collect_vec(), - repeat(Some(true)).take(6).collect_vec() - ); - assert_eq!( - a2.eq(&a1).into_iter().collect_vec(), - repeat(Some(true)).take(6).collect_vec() - ); - assert_eq!( - a1.neq(&a2).into_iter().collect_vec(), - repeat(Some(false)).take(6).collect_vec() - ); - assert_eq!( - a2.neq(&a1).into_iter().collect_vec(), - repeat(Some(false)).take(6).collect_vec() - ); - assert_eq!( - a1.gt(&a2).into_iter().collect_vec(), - repeat(Some(false)).take(6).collect_vec() - ); - assert_eq!( - a2.gt(&a1).into_iter().collect_vec(), - repeat(Some(false)).take(6).collect_vec() - ); - assert_eq!( - a1.gt_eq(&a2).into_iter().collect_vec(), - repeat(Some(true)).take(6).collect_vec() - ); - assert_eq!( - a2.gt_eq(&a1).into_iter().collect_vec(), - repeat(Some(true)).take(6).collect_vec() - ); - assert_eq!( - a1.lt_eq(&a2).into_iter().collect_vec(), - repeat(Some(true)).take(6).collect_vec() - ); - assert_eq!( - a2.lt_eq(&a1).into_iter().collect_vec(), - repeat(Some(true)).take(6).collect_vec() - ); - assert_eq!( - a1.lt(&a2).into_iter().collect_vec(), - repeat(Some(false)).take(6).collect_vec() - ); - assert_eq!( - a2.lt(&a1).into_iter().collect_vec(), - repeat(Some(false)).take(6).collect_vec() - ); - } - - #[test] - fn test_equal_chunks() { - let a1 = get_chunked_array(); - let a2 = get_chunked_array(); - - assert_eq!( - a1.eq(&a2).into_iter().collect_vec(), - repeat(Some(true)).take(3).collect_vec() - ); - assert_eq!( - a2.eq(&a1).into_iter().collect_vec(), - repeat(Some(true)).take(3).collect_vec() - ); - assert_eq!( - a1.neq(&a2).into_iter().collect_vec(), - repeat(Some(false)).take(3).collect_vec() - ); - assert_eq!( - a2.neq(&a1).into_iter().collect_vec(), - repeat(Some(false)).take(3).collect_vec() - ); - assert_eq!( - a1.gt(&a2).into_iter().collect_vec(), - repeat(Some(false)).take(3).collect_vec() - ); - assert_eq!( - a2.gt(&a1).into_iter().collect_vec(), - repeat(Some(false)).take(3).collect_vec() - ); - assert_eq!( - a1.gt_eq(&a2).into_iter().collect_vec(), - repeat(Some(true)).take(3).collect_vec() - ); - assert_eq!( - a2.gt_eq(&a1).into_iter().collect_vec(), - repeat(Some(true)).take(3).collect_vec() - ); - assert_eq!( - a1.lt_eq(&a2).into_iter().collect_vec(), - repeat(Some(true)).take(3).collect_vec() - ); - assert_eq!( - a2.lt_eq(&a1).into_iter().collect_vec(), - repeat(Some(true)).take(3).collect_vec() - ); - assert_eq!( - a1.lt(&a2).into_iter().collect_vec(), - repeat(Some(false)).take(3).collect_vec() - ); - assert_eq!( - a2.lt(&a1).into_iter().collect_vec(), - repeat(Some(false)).take(3).collect_vec() - ); - } - - #[test] - fn test_null_handling() { - // assert we comply with arrows way of handling null data - // we check comparison on two arrays with one chunk and verify it is equal to a differently - // chunked array comparison. - - // two same chunked arrays - let a1: Int32Chunked = (&[Some(1), None, Some(3)]).iter().copied().collect(); - let a2: Int32Chunked = (&[Some(1), Some(2), Some(3)]).iter().copied().collect(); - - let mut a2_2chunks: Int32Chunked = (&[Some(1), Some(2)]).iter().copied().collect(); - a2_2chunks.append(&(&[Some(3)]).iter().copied().collect()); - - assert_eq!( - a1.eq(&a2).into_iter().collect_vec(), - a1.eq(&a2_2chunks).into_iter().collect_vec() - ); - - assert_eq!( - a1.neq(&a2).into_iter().collect_vec(), - a1.neq(&a2_2chunks).into_iter().collect_vec() - ); - assert_eq!( - a1.neq(&a2).into_iter().collect_vec(), - a2_2chunks.neq(&a1).into_iter().collect_vec() - ); - - assert_eq!( - a1.gt(&a2).into_iter().collect_vec(), - a1.gt(&a2_2chunks).into_iter().collect_vec() - ); - assert_eq!( - a1.gt(&a2).into_iter().collect_vec(), - a2_2chunks.gt(&a1).into_iter().collect_vec() - ); - - assert_eq!( - a1.gt_eq(&a2).into_iter().collect_vec(), - a1.gt_eq(&a2_2chunks).into_iter().collect_vec() - ); - assert_eq!( - a1.gt_eq(&a2).into_iter().collect_vec(), - a2_2chunks.gt_eq(&a1).into_iter().collect_vec() - ); - - assert_eq!( - a1.lt_eq(&a2).into_iter().collect_vec(), - a1.lt_eq(&a2_2chunks).into_iter().collect_vec() - ); - assert_eq!( - a1.lt_eq(&a2).into_iter().collect_vec(), - a2_2chunks.lt_eq(&a1).into_iter().collect_vec() - ); - - assert_eq!( - a1.lt(&a2).into_iter().collect_vec(), - a1.lt(&a2_2chunks).into_iter().collect_vec() - ); - assert_eq!( - a1.lt(&a2).into_iter().collect_vec(), - a2_2chunks.lt(&a1).into_iter().collect_vec() - ); - } - - #[test] - fn test_left_right() { - // This failed with arrow comparisons. TODO: check minimal arrow example with one array being - // sliced - let a1: Int32Chunked = (&[Some(1), Some(2)]).iter().copied().collect(); - let a1 = a1.slice(1, 1).unwrap(); - let a2: Int32Chunked = (&[Some(2)]).iter().copied().collect(); - assert_eq!(a1.eq(&a2).sum(), a2.eq(&a1).sum()); - assert_eq!(a1.neq(&a2).sum(), a2.neq(&a1).sum()); - assert_eq!(a1.gt(&a2).sum(), a2.gt(&a1).sum()); - assert_eq!(a1.lt(&a2).sum(), a2.lt(&a1).sum()); - assert_eq!(a1.lt_eq(&a2).sum(), a2.lt_eq(&a1).sum()); - assert_eq!(a1.gt_eq(&a2).sum(), a2.gt_eq(&a1).sum()); - - let a1: Utf8Chunked = (&["a", "b"]).iter().copied().collect(); - let a1 = a1.slice(1, 1).unwrap(); - let a2: Utf8Chunked = (&["b"]).iter().copied().collect(); - assert_eq!(a1.eq(&a2).sum(), a2.eq(&a1).sum()); - assert_eq!(a1.neq(&a2).sum(), a2.neq(&a1).sum()); - assert_eq!(a1.gt(&a2).sum(), a2.gt(&a1).sum()); - assert_eq!(a1.lt(&a2).sum(), a2.lt(&a1).sum()); - assert_eq!(a1.lt_eq(&a2).sum(), a2.lt_eq(&a1).sum()); - assert_eq!(a1.gt_eq(&a2).sum(), a2.gt_eq(&a1).sum()); - } -} diff --git a/polars/src/chunked_array/iterator.rs b/polars/src/chunked_array/iterator.rs deleted file mode 100644 index 5f247f026e70..000000000000 --- a/polars/src/chunked_array/iterator.rs +++ /dev/null @@ -1,1995 +0,0 @@ -use crate::prelude::*; -use arrow::array::{ - Array, ArrayDataRef, ArrayRef, BooleanArray, LargeListArray, PrimitiveArray, StringArray, -}; -use std::iter::Copied; -use std::slice::Iter; - -// ExactSizeIterator trait implementations for all Iterator structs in this file -impl<'a, T> ExactSizeIterator for NumIterSingleChunkNullCheck<'a, T> where T: PolarsNumericType {} -impl<'a, T> ExactSizeIterator for NumIterSingleChunk<'a, T> -where - T: PolarsNumericType, - T::Native: Copy, -{ -} -impl<'a, T> ExactSizeIterator for NumIterManyChunkNullCheck<'a, T> where T: PolarsNumericType {} -impl<'a, T> ExactSizeIterator for NumIterManyChunk<'a, T> where T: PolarsNumericType {} -impl<'a> ExactSizeIterator for Utf8IterSingleChunk<'a> {} -impl<'a> ExactSizeIterator for Utf8IterSingleChunkNullCheck<'a> {} -impl<'a> ExactSizeIterator for Utf8IterManyChunk<'a> {} -impl<'a> ExactSizeIterator for Utf8IterManyChunkNullCheck<'a> {} - -/// Trait for ChunkedArrays that don't have null values. -/// TODO: implement for faster paths -pub trait IntoNoNullIterator { - type Item; - type IntoIter: Iterator; - - fn into_no_null_iter(self) -> Self::IntoIter; -} - -impl<'a, T> IntoNoNullIterator for &'a ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - type IntoIter = Box + 'a>; - - fn into_no_null_iter(self) -> Self::IntoIter { - match self.chunks.len() { - 1 => Box::new( - self.downcast_chunks()[0] - .value_slice(0, self.len()) - .into_iter() - .copied(), - ), - _ => Box::new(NumIterManyChunk::new(self)), - } - } -} - -/// Single chunk with null values -pub struct NumIterSingleChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - arr: &'a PrimitiveArray, - idx: usize, - back_idx: usize, -} - -impl<'a, T> NumIterSingleChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - fn return_opt_val(&self, index: usize) -> Option> { - if self.arr.is_null(index) { - Some(None) - } else { - Some(Some(self.arr.value(index))) - } - } -} - -impl<'a, T> Iterator for NumIterSingleChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - type Item = Option; - - fn next(&mut self) -> Option { - if self.idx == self.back_idx { - None - } else { - self.idx += 1; - self.return_opt_val(self.idx - 1) - } - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.arr.len(); - (len, Some(len)) - } -} - -impl<'a, T> DoubleEndedIterator for NumIterSingleChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - fn next_back(&mut self) -> Option { - if self.idx == self.back_idx { - None - } else { - self.back_idx -= 1; - self.return_opt_val(self.back_idx) - } - } -} - -/// Single chunk no null values -pub struct NumIterSingleChunk<'a, T> -where - T: PolarsNumericType, - T::Native: Copy, -{ - iter: Copied>, -} - -impl<'a, T> Iterator for NumIterSingleChunk<'a, T> -where - T: PolarsNumericType, - T::Native: Copy, -{ - type Item = Option; - - fn next(&mut self) -> Option { - self.iter.next().map(Some) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl<'a, T> DoubleEndedIterator for NumIterSingleChunk<'a, T> -where - T: PolarsNumericType, - T::Native: Copy, -{ - fn next_back(&mut self) -> Option { - self.iter.next_back().map(Some) - } -} - -/// Many chunks no null checks -/// Both used as iterator with null checks and without. We later map Some on it for the iter -/// with null checks -pub struct NumIterManyChunk<'a, T> -where - T: PolarsNumericType, -{ - ca: &'a ChunkedArray, - chunks: Vec<&'a PrimitiveArray>, - // current iterator if we iterate from the left - current_iter_left: Copied>, - // If iter_left and iter_right are the same, this is None and we need to use iter left - // This is done because we can only have one owner - current_iter_right: Option>>, - chunk_idx_left: usize, - idx_left: usize, - idx_right: usize, - chunk_idx_right: usize, -} - -impl<'a, T> NumIterManyChunk<'a, T> -where - T: PolarsNumericType, -{ - fn set_current_iter_left(&mut self) { - let current_chunk = unsafe { self.chunks.get_unchecked(self.chunk_idx_left) }; - self.current_iter_left = current_chunk - .value_slice(0, current_chunk.len()) - .iter() - .copied(); - } - - fn set_current_iter_right(&mut self) { - if self.chunk_idx_left == self.chunk_idx_left { - // from left and right we use the same iterator - self.current_iter_right = None - } else { - let current_chunk = unsafe { self.chunks.get_unchecked(self.chunk_idx_right) }; - self.current_iter_right = Some( - current_chunk - .value_slice(0, current_chunk.len()) - .iter() - .copied(), - ); - } - } - fn new(ca: &'a ChunkedArray) -> Self { - let chunks = ca.downcast_chunks(); - let current_len_left = chunks[0].len(); - let current_iter_left = chunks[0].value_slice(0, current_len_left).iter().copied(); - - let idx_left = 0; - let chunk_idx_left = 0; - let idx_right = ca.len(); - - let chunk_idx_right = chunks.len() - 1; - let current_iter_right; - if chunk_idx_left == chunk_idx_right { - current_iter_right = None - } else { - let arr = chunks[chunk_idx_right]; - current_iter_right = Some(arr.value_slice(0, arr.len()).iter().copied()) - } - - NumIterManyChunk { - ca, - current_iter_left, - chunks, - current_iter_right, - idx_left, - chunk_idx_left, - idx_right, - chunk_idx_right, - } - } -} - -impl<'a, T> Iterator for NumIterManyChunk<'a, T> -where - T: PolarsNumericType, -{ - type Item = T::Native; - - fn next(&mut self) -> Option { - let opt_val = self.current_iter_left.next(); - - let opt_val = if opt_val.is_none() { - // iterators have met in the middle or at the end - if self.idx_left == self.idx_right { - return None; - // one chunk is finished but there are still more chunks - } else { - self.chunk_idx_left += 1; - self.set_current_iter_left(); - } - // so we return the first value of the next chunk - self.current_iter_left.next() - } else { - // we got a value - opt_val - }; - self.idx_left += 1; - // opt_val.map(Some) - opt_val - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } -} - -impl<'a, T> DoubleEndedIterator for NumIterManyChunk<'a, T> -where - T: PolarsNumericType, -{ - fn next_back(&mut self) -> Option { - let opt_val = match &mut self.current_iter_right { - Some(it) => it.next_back(), - None => self.current_iter_left.next_back(), - }; - - let opt_val = if opt_val.is_none() { - // iterators have met in the middle or at the beginning - if self.idx_left == self.idx_right { - return None; - // one chunk is finished but there are still more chunks - } else { - self.chunk_idx_right -= 1; - self.set_current_iter_right(); - } - // so we return the first value of the next chunk from the back - self.current_iter_left.next_back() - } else { - // we got a value - opt_val - }; - self.idx_right -= 1; - opt_val - } -} - -/// Many chunks with null checks -pub struct NumIterManyChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - ca: &'a ChunkedArray, - chunks: Vec<&'a PrimitiveArray>, - // current iterator if we iterate from the left - current_iter_left: Copied>, - current_data_left: ArrayDataRef, - // index in the current iterator from left - current_array_i_left: usize, - // If iter_left and iter_right are the same, this is None and we need to use iter left - // This is done because we can only have one owner - current_iter_right: Option>>, - current_data_right: ArrayDataRef, - current_array_i_right: usize, - chunk_idx_left: usize, - idx_left: usize, - idx_right: usize, - chunk_idx_right: usize, -} - -impl<'a, T> NumIterManyChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - fn set_current_iter_left(&mut self) { - let current_chunk = unsafe { self.chunks.get_unchecked(self.chunk_idx_left) }; - self.current_data_left = current_chunk.data(); - self.current_iter_left = current_chunk - .value_slice(0, current_chunk.len()) - .iter() - .copied(); - } - - fn set_current_iter_right(&mut self) { - let current_chunk = unsafe { self.chunks.get_unchecked(self.chunk_idx_right) }; - self.current_data_right = current_chunk.data(); - if self.chunk_idx_left == self.chunk_idx_left { - // from left and right we use the same iterator - self.current_iter_right = None - } else { - self.current_iter_right = Some( - current_chunk - .value_slice(0, current_chunk.len()) - .iter() - .copied(), - ); - } - } - fn new(ca: &'a ChunkedArray) -> Self { - let chunks = ca.downcast_chunks(); - let arr_left = chunks[0]; - let current_len_left = arr_left.len(); - let current_iter_left = arr_left.value_slice(0, current_len_left).iter().copied(); - let current_data_left = arr_left.data(); - - let idx_left = 0; - let chunk_idx_left = 0; - let idx_right = ca.len(); - - let chunk_idx_right = chunks.len() - 1; - let current_iter_right; - let arr = chunks[chunk_idx_right]; - let current_data_right = arr.data(); - if chunk_idx_left == chunk_idx_right { - current_iter_right = None - } else { - current_iter_right = Some(arr.value_slice(0, arr.len()).iter().copied()) - } - let current_array_i_right = arr.len(); - - NumIterManyChunkNullCheck { - ca, - current_iter_left, - current_data_left, - current_array_i_left: 0, - chunks, - current_iter_right, - current_data_right, - current_array_i_right, - idx_left, - chunk_idx_left, - idx_right, - chunk_idx_right, - } - } -} - -impl<'a, T> Iterator for NumIterManyChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - type Item = Option; - - fn next(&mut self) -> Option { - let opt_val = self.current_iter_left.next(); - - let opt_val = if opt_val.is_none() { - // iterators have met in the middle or at the end - if self.idx_left == self.idx_right { - return None; - // one chunk is finished but there are still more chunks - } else { - self.chunk_idx_left += 1; - // reset the index - self.current_array_i_left = 0; - self.set_current_iter_left(); - } - // so we return the first value of the next chunk - self.current_iter_left.next() - } else { - // we got a value - opt_val - }; - self.idx_left += 1; - self.current_array_i_left += 1; - if self - .current_data_left - .is_null(self.current_array_i_left - 1) - { - Some(None) - } else { - opt_val.map(Some) - } - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } -} - -impl<'a, T> DoubleEndedIterator for NumIterManyChunkNullCheck<'a, T> -where - T: PolarsNumericType, -{ - fn next_back(&mut self) -> Option { - let opt_val = match &mut self.current_iter_right { - Some(it) => it.next_back(), - None => self.current_iter_left.next_back(), - }; - - let opt_val = if opt_val.is_none() { - // iterators have met in the middle or at the beginning - if self.idx_left == self.idx_right { - return None; - // one chunk is finished but there are still more chunks - } else { - self.chunk_idx_right -= 1; - self.set_current_iter_right(); - // reset the index accumulator - self.current_array_i_right = self.current_data_right.len() - } - // so we return the first value of the next chunk from the back - self.current_iter_left.next_back() - } else { - // we got a value - opt_val - }; - self.idx_right -= 1; - self.current_array_i_right -= 1; - - if self.current_data_right.is_null(self.current_array_i_right) { - Some(None) - } else { - opt_val.map(Some) - } - } -} - -pub enum NumericChunkIterDispatch<'a, T> -where - T: PolarsNumericType, -{ - SingleChunk(NumIterSingleChunk<'a, T>), - SingleChunkNullCheck(NumIterSingleChunkNullCheck<'a, T>), - ManyChunk(NumIterManyChunk<'a, T>), - ManyChunkNullCheck(NumIterManyChunkNullCheck<'a, T>), -} - -impl<'a, T> Iterator for NumericChunkIterDispatch<'a, T> -where - T: PolarsNumericType, -{ - type Item = Option; - - fn next(&mut self) -> Option { - match self { - NumericChunkIterDispatch::SingleChunk(a) => a.next(), - NumericChunkIterDispatch::SingleChunkNullCheck(a) => a.next(), - NumericChunkIterDispatch::ManyChunk(a) => a.next().map(Some), - NumericChunkIterDispatch::ManyChunkNullCheck(a) => a.next(), - } - } - - fn size_hint(&self) -> (usize, Option) { - match self { - NumericChunkIterDispatch::SingleChunk(a) => a.size_hint(), - NumericChunkIterDispatch::SingleChunkNullCheck(a) => a.size_hint(), - NumericChunkIterDispatch::ManyChunk(a) => a.size_hint(), - NumericChunkIterDispatch::ManyChunkNullCheck(a) => a.size_hint(), - } - } -} - -impl<'a, T> DoubleEndedIterator for NumericChunkIterDispatch<'a, T> -where - T: PolarsNumericType, -{ - fn next_back(&mut self) -> Option { - match self { - NumericChunkIterDispatch::SingleChunk(a) => a.next_back(), - NumericChunkIterDispatch::SingleChunkNullCheck(a) => a.next_back(), - NumericChunkIterDispatch::ManyChunk(a) => a.next_back().map(Some), - NumericChunkIterDispatch::ManyChunkNullCheck(a) => a.next_back(), - } - } -} - -impl<'a, T> ExactSizeIterator for NumericChunkIterDispatch<'a, T> where T: PolarsNumericType {} - -impl<'a, T> IntoIterator for &'a ChunkedArray -where - T: PolarsNumericType, -{ - type Item = Option; - type IntoIter = NumericChunkIterDispatch<'a, T>; - - fn into_iter(self) -> Self::IntoIter { - match self.cont_slice() { - Ok(slice) => { - // Compile could not infer T. - let a: NumIterSingleChunk<'_, T> = NumIterSingleChunk { - iter: slice.iter().copied(), - }; - NumericChunkIterDispatch::SingleChunk(a) - } - Err(_) => { - let chunks = self.downcast_chunks(); - match chunks.len() { - 1 => { - let arr = chunks[0]; - let len = arr.len(); - - NumericChunkIterDispatch::SingleChunkNullCheck( - NumIterSingleChunkNullCheck { - arr, - idx: 0, - back_idx: len, - }, - ) - } - _ => { - if self.null_count() == 0 { - NumericChunkIterDispatch::ManyChunk(NumIterManyChunk::new(self)) - } else { - NumericChunkIterDispatch::ManyChunkNullCheck( - NumIterManyChunkNullCheck::new(self), - ) - } - } - } - } - } - } -} - -/// No null checks and dont return Option but T directly. So this struct is not return by the -/// IntoIterator trait -pub struct Utf8IterCont<'a> { - current_array: &'a StringArray, - idx_left: usize, - idx_right: usize, -} - -impl<'a> Utf8IterCont<'a> { - fn new(ca: &'a Utf8Chunked) -> Self { - let chunks = ca.downcast_chunks(); - let current_array = chunks[0]; - let idx_left = 0; - let idx_right = current_array.len(); - - Utf8IterCont { - current_array, - idx_left, - idx_right, - } - } -} - -impl<'a> Iterator for Utf8IterCont<'a> { - type Item = &'a str; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - let v = self.current_array.value(self.idx_left); - self.idx_left += 1; - Some(v) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.current_array.len(); - (len, Some(len)) - } -} - -impl<'a> IntoNoNullIterator for &'a Utf8Chunked { - type Item = &'a str; - type IntoIter = Utf8IterCont<'a>; - - fn into_no_null_iter(self) -> Self::IntoIter { - Utf8IterCont::new(self) - } -} - -/// No null checks -pub struct Utf8IterSingleChunk<'a> { - current_array: &'a StringArray, - idx_left: usize, - idx_right: usize, -} - -impl<'a> Utf8IterSingleChunk<'a> { - fn new(ca: &'a Utf8Chunked) -> Self { - let chunks = ca.downcast_chunks(); - let current_array = chunks[0]; - let idx_left = 0; - let idx_right = current_array.len(); - - Utf8IterSingleChunk { - current_array, - idx_left, - idx_right, - } - } -} - -impl<'a> Iterator for Utf8IterSingleChunk<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - let v = self.current_array.value(self.idx_left); - self.idx_left += 1; - Some(Some(v)) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.current_array.len(); - (len, Some(len)) - } -} - -impl<'a> DoubleEndedIterator for Utf8IterSingleChunk<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - Some(Some(self.current_array.value(self.idx_right))) - } -} - -pub struct Utf8IterSingleChunkNullCheck<'a> { - current_data: ArrayDataRef, - current_array: &'a StringArray, - idx_left: usize, - idx_right: usize, -} - -impl<'a> Utf8IterSingleChunkNullCheck<'a> { - fn new(ca: &'a Utf8Chunked) -> Self { - let chunks = ca.downcast_chunks(); - let current_array = chunks[0]; - let current_data = current_array.data(); - let idx_left = 0; - let idx_right = current_array.len(); - - Utf8IterSingleChunkNullCheck { - current_data, - current_array, - idx_left, - idx_right, - } - } -} - -impl<'a> Iterator for Utf8IterSingleChunkNullCheck<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - let ret; - if self.current_data.is_null(self.idx_left) { - ret = Some(None) - } else { - let v = self.current_array.value(self.idx_left); - ret = Some(Some(v)) - } - self.idx_left += 1; - ret - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.current_array.len(); - (len, Some(len)) - } -} - -impl<'a> DoubleEndedIterator for Utf8IterSingleChunkNullCheck<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - if self.current_data.is_null(self.idx_right) { - Some(None) - } else { - Some(Some(self.current_array.value(self.idx_right))) - } - } -} - -/// Many chunks no nulls -pub struct Utf8IterManyChunk<'a> { - ca: &'a Utf8Chunked, - chunks: Vec<&'a StringArray>, - current_array_left: &'a StringArray, - current_array_right: &'a StringArray, - current_array_idx_left: usize, - current_array_idx_right: usize, - current_array_left_len: usize, - idx_left: usize, - idx_right: usize, - chunk_idx_left: usize, - chunk_idx_right: usize, -} - -impl<'a> Utf8IterManyChunk<'a> { - fn new(ca: &'a Utf8Chunked) -> Self { - let chunks = ca.downcast_chunks(); - let current_array_left = chunks[0]; - let idx_left = 0; - let chunk_idx_left = 0; - let chunk_idx_right = chunks.len() - 1; - let current_array_right = chunks[chunk_idx_right]; - let idx_right = ca.len(); - let current_array_idx_left = 0; - let current_array_idx_right = current_array_right.len(); - let current_array_left_len = current_array_left.len(); - - Utf8IterManyChunk { - ca, - chunks, - current_array_left, - current_array_right, - current_array_idx_left, - current_array_idx_right, - current_array_left_len, - idx_left, - idx_right, - chunk_idx_left, - chunk_idx_right, - } - } -} - -impl<'a> Iterator for Utf8IterManyChunk<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - // return value - let ret = self.current_array_left.value(self.current_array_idx_left); - - // increment index pointers - self.idx_left += 1; - self.current_array_idx_left += 1; - - // we've reached the end of the chunk - if self.current_array_idx_left == self.current_array_left_len { - // Set a new chunk as current data - self.chunk_idx_left += 1; - - // if this evaluates to False, next call will be end of iterator - if self.chunk_idx_left < self.chunks.len() { - // reset to new array - self.current_array_idx_left = 0; - self.current_array_left = self.chunks[self.chunk_idx_left]; - self.current_array_left_len = self.current_array_left.len(); - } - } - Some(Some(ret)) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } -} - -impl<'a> DoubleEndedIterator for Utf8IterManyChunk<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - self.current_array_idx_right -= 1; - - let ret = self.current_array_right.value(self.current_array_idx_right); - - // we've reached the end of the chunk from the right - if self.current_array_idx_right == 0 && self.idx_right > 0 { - // set a new chunk as current data - self.chunk_idx_right -= 1; - // reset to new array - self.current_array_right = self.chunks[self.chunk_idx_right]; - self.current_array_idx_right = self.current_array_right.len(); - } - Some(Some(ret)) - } -} - -/// Many chunks with nulls -pub struct Utf8IterManyChunkNullCheck<'a> { - ca: &'a Utf8Chunked, - chunks: Vec<&'a StringArray>, - current_data_left: ArrayDataRef, - current_array_left: &'a StringArray, - current_data_right: ArrayDataRef, - current_array_right: &'a StringArray, - current_array_idx_left: usize, - current_array_idx_right: usize, - current_array_left_len: usize, - idx_left: usize, - idx_right: usize, - chunk_idx_left: usize, - chunk_idx_right: usize, -} - -impl<'a> Utf8IterManyChunkNullCheck<'a> { - fn new(ca: &'a Utf8Chunked) -> Self { - let chunks = ca.downcast_chunks(); - let current_array_left = chunks[0]; - let current_data_left = current_array_left.data(); - let idx_left = 0; - let chunk_idx_left = 0; - let chunk_idx_right = chunks.len() - 1; - let current_array_right = chunks[chunk_idx_right]; - let current_data_right = current_array_right.data(); - let idx_right = ca.len(); - let current_array_idx_left = 0; - let current_array_idx_right = current_data_right.len(); - let current_array_left_len = current_array_left.len(); - - Utf8IterManyChunkNullCheck { - ca, - chunks, - current_data_left, - current_array_left, - current_data_right, - current_array_right, - current_array_idx_left, - current_array_idx_right, - current_array_left_len, - idx_left, - idx_right, - chunk_idx_left, - chunk_idx_right, - } - } -} - -impl<'a> Iterator for Utf8IterManyChunkNullCheck<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - // return value - let ret; - if self.current_array_left.is_null(self.current_array_idx_left) { - ret = None - } else { - ret = Some(self.current_array_left.value(self.current_array_idx_left)); - } - - // increment index pointers - self.idx_left += 1; - self.current_array_idx_left += 1; - - // we've reached the end of the chunk - if self.current_array_idx_left == self.current_array_left_len { - // Set a new chunk as current data - self.chunk_idx_left += 1; - - // if this evaluates to False, next call will be end of iterator - if self.chunk_idx_left < self.chunks.len() { - // reset to new array - self.current_array_idx_left = 0; - self.current_array_left = self.chunks[self.chunk_idx_left]; - self.current_data_left = self.current_array_left.data(); - self.current_array_left_len = self.current_array_left.len(); - } - } - Some(ret) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } -} - -impl<'a> DoubleEndedIterator for Utf8IterManyChunkNullCheck<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - self.current_array_idx_right -= 1; - - let ret = if self - .current_data_right - .is_null(self.current_array_idx_right) - { - Some(None) - } else { - Some(Some( - self.current_array_right.value(self.current_array_idx_right), - )) - }; - - // we've reached the end of the chunk from the right - if self.current_array_idx_right == 0 && self.idx_right > 0 { - // set a new chunk as current data - self.chunk_idx_right -= 1; - // reset to new array - self.current_array_right = self.chunks[self.chunk_idx_right]; - self.current_data_right = self.current_array_right.data(); - self.current_array_idx_right = self.current_array_right.len(); - } - ret - } -} - -pub enum Utf8ChunkIterDispatch<'a> { - SingleChunk(Utf8IterSingleChunk<'a>), - SingleChunkNullCheck(Utf8IterSingleChunkNullCheck<'a>), - ManyChunk(Utf8IterManyChunk<'a>), - ManyChunkNullCheck(Utf8IterManyChunkNullCheck<'a>), -} - -impl<'a> Iterator for Utf8ChunkIterDispatch<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - match self { - Utf8ChunkIterDispatch::SingleChunk(a) => a.next(), - Utf8ChunkIterDispatch::SingleChunkNullCheck(a) => a.next(), - Utf8ChunkIterDispatch::ManyChunk(a) => a.next(), - Utf8ChunkIterDispatch::ManyChunkNullCheck(a) => a.next(), - } - } - - fn size_hint(&self) -> (usize, Option) { - match self { - Utf8ChunkIterDispatch::SingleChunk(a) => a.size_hint(), - Utf8ChunkIterDispatch::SingleChunkNullCheck(a) => a.size_hint(), - Utf8ChunkIterDispatch::ManyChunk(a) => a.size_hint(), - Utf8ChunkIterDispatch::ManyChunkNullCheck(a) => a.size_hint(), - } - } -} - -impl<'a> DoubleEndedIterator for Utf8ChunkIterDispatch<'a> { - fn next_back(&mut self) -> Option { - match self { - Utf8ChunkIterDispatch::SingleChunk(a) => a.next_back(), - Utf8ChunkIterDispatch::SingleChunkNullCheck(a) => a.next_back(), - Utf8ChunkIterDispatch::ManyChunk(a) => a.next_back(), - Utf8ChunkIterDispatch::ManyChunkNullCheck(a) => a.next_back(), - } - } -} - -impl<'a> ExactSizeIterator for Utf8ChunkIterDispatch<'a> {} - -impl<'a> IntoIterator for &'a Utf8Chunked { - type Item = Option<&'a str>; - type IntoIter = Utf8ChunkIterDispatch<'a>; - - fn into_iter(self) -> Self::IntoIter { - let chunks = self.downcast_chunks(); - match chunks.len() { - 1 => { - if self.null_count() == 0 { - Utf8ChunkIterDispatch::SingleChunk(Utf8IterSingleChunk::new(self)) - } else { - Utf8ChunkIterDispatch::SingleChunkNullCheck(Utf8IterSingleChunkNullCheck::new( - self, - )) - } - } - _ => { - if self.null_count() == 0 { - Utf8ChunkIterDispatch::ManyChunk(Utf8IterManyChunk::new(self)) - } else { - Utf8ChunkIterDispatch::ManyChunkNullCheck(Utf8IterManyChunkNullCheck::new(self)) - } - } - } - } -} - -macro_rules! impl_iterator_traits { - ($ca_type:ident, // ChunkedArray type - $arrow_array:ident, // Arrow array Type - $no_null_iter_struct:ident, // Name of Iterator in case of no null checks and return type is T instead of Option - $single_chunk_ident:ident, // Name of Iterator in case of single chunk - $single_chunk_null_ident:ident, // Name of Iterator in case of single chunk and nulls - $many_chunk_ident:ident, // Name of Iterator in case of many chunks - $many_chunk_null_ident:ident, // Name of Iterator in case of many chunks and null - $chunkdispatch:ident, // Name of Dispatch struct - $iter_item:ty, // Item returned by Iterator e.g. Option - $iter_item_no_null:ty, // Iterm return by Iterator that doesn't do null checks. e.g. bool - $return_function: ident, // function that is called upon returning from the iterator with the inner type - // So in case of both Option and T function is called with T - // Fn(method_name: &str, type: T) -> ? - ) => { - impl<'a> ExactSizeIterator for $single_chunk_ident<'a> {} - impl<'a> ExactSizeIterator for $single_chunk_null_ident<'a> {} - impl<'a> ExactSizeIterator for $many_chunk_ident<'a> {} - impl<'a> ExactSizeIterator for $many_chunk_null_ident<'a> {} - - /// No null checks and dont return Option but T directly. So this struct is not return by the - /// IntoIterator trait - pub struct $no_null_iter_struct<'a> { - current_array: &'a $arrow_array, - idx_left: usize, - idx_right: usize, - } - - impl<'a> $no_null_iter_struct<'a> { - fn new(ca: &'a $ca_type) -> Self { - let chunks = ca.downcast_chunks(); - let current_array = chunks[0]; - let idx_left = 0; - let idx_right = current_array.len(); - - $no_null_iter_struct { - current_array, - idx_left, - idx_right, - } - } - } - - impl<'a> Iterator for $no_null_iter_struct<'a> { - type Item = $iter_item_no_null; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - let v = self.current_array.value(self.idx_left); - self.idx_left += 1; - Some($return_function("next", v)) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.current_array.len(); - (len, Some(len)) - } - } - - impl<'a> IntoNoNullIterator for &'a $ca_type { - type Item = $iter_item_no_null; - type IntoIter = $no_null_iter_struct<'a>; - - fn into_no_null_iter(self) -> Self::IntoIter { - $no_null_iter_struct::new(self) - } - } - - /// No null checks - pub struct $single_chunk_ident<'a> { - current_array: &'a $arrow_array, - idx_left: usize, - idx_right: usize, - } - - impl<'a> $single_chunk_ident<'a> { - fn new(ca: &'a $ca_type) -> Self { - let chunks = ca.downcast_chunks(); - let current_array = chunks[0]; - let idx_left = 0; - let idx_right = current_array.len(); - - $single_chunk_ident { - current_array, - idx_left, - idx_right, - } - } - } - - impl<'a> Iterator for $single_chunk_ident<'a> { - type Item = $iter_item; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - let v = self.current_array.value(self.idx_left); - self.idx_left += 1; - Some(Some($return_function("next", v))) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.current_array.len(); - (len, Some(len)) - } - } - - impl<'a> DoubleEndedIterator for $single_chunk_ident<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - let v = self.current_array.value(self.idx_right); - Some(Some($return_function("next_back", v))) - } - } - - pub struct $single_chunk_null_ident<'a> { - current_data: ArrayDataRef, - current_array: &'a $arrow_array, - idx_left: usize, - idx_right: usize, - } - - impl<'a> $single_chunk_null_ident<'a> { - fn new(ca: &'a $ca_type) -> Self { - let chunks = ca.downcast_chunks(); - let current_array = chunks[0]; - let current_data = current_array.data(); - let idx_left = 0; - let idx_right = current_array.len(); - - $single_chunk_null_ident { - current_data, - current_array, - idx_left, - idx_right, - } - } - } - - impl<'a> Iterator for $single_chunk_null_ident<'a> { - type Item = $iter_item; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - let ret; - if self.current_data.is_null(self.idx_left) { - ret = None - } else { - let v = self.current_array.value(self.idx_left); - ret = Some($return_function("next", v)) - } - self.idx_left += 1; - Some(ret) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.current_array.len(); - (len, Some(len)) - } - } - - impl<'a> DoubleEndedIterator for $single_chunk_null_ident<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - if self.current_data.is_null(self.idx_right) { - Some(None) - } else { - let v = self.current_array.value(self.idx_right); - Some(Some($return_function("next_back", v))) - } - } - } - - /// Many chunks no nulls - pub struct $many_chunk_ident<'a> { - ca: &'a $ca_type, - chunks: Vec<&'a $arrow_array>, - current_array_left: &'a $arrow_array, - current_array_right: &'a $arrow_array, - current_array_idx_left: usize, - current_array_idx_right: usize, - current_array_left_len: usize, - idx_left: usize, - idx_right: usize, - chunk_idx_left: usize, - chunk_idx_right: usize, - } - - impl<'a> $many_chunk_ident<'a> { - fn new(ca: &'a $ca_type) -> Self { - let chunks = ca.downcast_chunks(); - let current_array_left = chunks[0]; - let idx_left = 0; - let chunk_idx_left = 0; - let chunk_idx_right = chunks.len() - 1; - let current_array_right = chunks[chunk_idx_right]; - let idx_right = ca.len(); - let current_array_idx_left = 0; - let current_array_idx_right = current_array_right.len(); - let current_array_left_len = current_array_left.len(); - - $many_chunk_ident { - ca, - chunks, - current_array_left, - current_array_right, - current_array_idx_left, - current_array_idx_right, - current_array_left_len, - idx_left, - idx_right, - chunk_idx_left, - chunk_idx_right, - } - } - } - - impl<'a> Iterator for $many_chunk_ident<'a> { - type Item = $iter_item; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - // return value - let v = self.current_array_left.value(self.current_array_idx_left); - let ret = $return_function("next", v); - - // increment index pointers - self.idx_left += 1; - self.current_array_idx_left += 1; - - // we've reached the end of the chunk - if self.current_array_idx_left == self.current_array_left_len { - // Set a new chunk as current data - self.chunk_idx_left += 1; - - // if this evaluates to False, next call will be end of iterator - if self.chunk_idx_left < self.chunks.len() { - // reset to new array - self.current_array_idx_left = 0; - self.current_array_left = self.chunks[self.chunk_idx_left]; - self.current_array_left_len = self.current_array_left.len(); - } - } - Some(Some(ret)) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } - } - - impl<'a> DoubleEndedIterator for $many_chunk_ident<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - self.current_array_idx_right -= 1; - - let ret = self.current_array_right.value(self.current_array_idx_right); - - // we've reached the end of the chunk from the right - if self.current_array_idx_right == 0 && self.idx_right > 0 { - // set a new chunk as current data - self.chunk_idx_right -= 1; - // reset to new array - self.current_array_right = self.chunks[self.chunk_idx_right]; - self.current_array_idx_right = self.current_array_right.len(); - } - Some(Some($return_function("next_back", ret))) - } - } - - /// Many chunks no nulls - pub struct $many_chunk_null_ident<'a> { - ca: &'a $ca_type, - chunks: Vec<&'a $arrow_array>, - current_data_left: ArrayDataRef, - current_array_left: &'a $arrow_array, - current_data_right: ArrayDataRef, - current_array_right: &'a $arrow_array, - current_array_idx_left: usize, - current_array_idx_right: usize, - current_array_left_len: usize, - idx_left: usize, - idx_right: usize, - chunk_idx_left: usize, - chunk_idx_right: usize, - } - - impl<'a> $many_chunk_null_ident<'a> { - fn new(ca: &'a $ca_type) -> Self { - let chunks = ca.downcast_chunks(); - let current_array_left = chunks[0]; - let current_data_left = current_array_left.data(); - let idx_left = 0; - let chunk_idx_left = 0; - let chunk_idx_right = chunks.len() - 1; - let current_array_right = chunks[chunk_idx_right]; - let current_data_right = current_array_right.data(); - let idx_right = ca.len(); - let current_array_idx_left = 0; - let current_array_idx_right = current_data_right.len(); - let current_array_left_len = current_array_left.len(); - - $many_chunk_null_ident { - ca, - chunks, - current_data_left, - current_array_left, - current_data_right, - current_array_right, - current_array_idx_left, - current_array_idx_right, - current_array_left_len, - idx_left, - idx_right, - chunk_idx_left, - chunk_idx_right, - } - } - } - - impl<'a> Iterator for $many_chunk_null_ident<'a> { - type Item = $iter_item; - - fn next(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - - // return value - let ret; - if self.current_array_left.is_null(self.current_array_idx_left) { - ret = None - } else { - let v = self.current_array_left.value(self.current_array_idx_left); - ret = Some($return_function("next", v)); - } - - // increment index pointers - self.idx_left += 1; - self.current_array_idx_left += 1; - - // we've reached the end of the chunk - if self.current_array_idx_left == self.current_array_left_len { - // Set a new chunk as current data - self.chunk_idx_left += 1; - - // if this evaluates to False, next call will be end of iterator - if self.chunk_idx_left < self.chunks.len() { - // reset to new array - self.current_array_idx_left = 0; - self.current_array_left = self.chunks[self.chunk_idx_left]; - self.current_data_left = self.current_array_left.data(); - self.current_array_left_len = self.current_array_left.len(); - } - } - Some(ret) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } - } - - impl<'a> DoubleEndedIterator for $many_chunk_null_ident<'a> { - fn next_back(&mut self) -> Option { - // end of iterator or meet reversed iterator in the middle - if self.idx_left == self.idx_right { - return None; - } - self.idx_right -= 1; - self.current_array_idx_right -= 1; - - let ret = if self - .current_data_right - .is_null(self.current_array_idx_right) - { - Some(None) - } else { - let v = self.current_array_right.value(self.current_array_idx_right); - Some(Some($return_function("next_back", v))) - }; - - // we've reached the end of the chunk from the right - if self.current_array_idx_right == 0 && self.idx_right > 0 { - // set a new chunk as current data - self.chunk_idx_right -= 1; - // reset to new array - self.current_array_right = self.chunks[self.chunk_idx_right]; - self.current_data_right = self.current_array_right.data(); - self.current_array_idx_right = self.current_array_right.len(); - } - ret - } - } - - pub enum $chunkdispatch<'a> { - SingleChunk($single_chunk_ident<'a>), - SingleChunkNullCheck($single_chunk_null_ident<'a>), - ManyChunk($many_chunk_ident<'a>), - ManyChunkNullCheck($many_chunk_null_ident<'a>), - } - - impl<'a> Iterator for $chunkdispatch<'a> { - type Item = $iter_item; - - fn next(&mut self) -> Option { - match self { - $chunkdispatch::SingleChunk(a) => a.next(), - $chunkdispatch::SingleChunkNullCheck(a) => a.next(), - $chunkdispatch::ManyChunk(a) => a.next(), - $chunkdispatch::ManyChunkNullCheck(a) => a.next(), - } - } - - fn size_hint(&self) -> (usize, Option) { - match self { - $chunkdispatch::SingleChunk(a) => a.size_hint(), - $chunkdispatch::SingleChunkNullCheck(a) => a.size_hint(), - $chunkdispatch::ManyChunk(a) => a.size_hint(), - $chunkdispatch::ManyChunkNullCheck(a) => a.size_hint(), - } - } - } - - impl<'a> DoubleEndedIterator for $chunkdispatch<'a> { - fn next_back(&mut self) -> Option { - match self { - $chunkdispatch::SingleChunk(a) => a.next_back(), - $chunkdispatch::SingleChunkNullCheck(a) => a.next_back(), - $chunkdispatch::ManyChunk(a) => a.next_back(), - $chunkdispatch::ManyChunkNullCheck(a) => a.next_back(), - } - } - } - - impl<'a> ExactSizeIterator for $chunkdispatch<'a> {} - - impl<'a> IntoIterator for &'a $ca_type { - type Item = $iter_item; - type IntoIter = $chunkdispatch<'a>; - - fn into_iter(self) -> Self::IntoIter { - let chunks = self.downcast_chunks(); - match chunks.len() { - 1 => { - if self.null_count() == 0 { - $chunkdispatch::SingleChunk($single_chunk_ident::new(self)) - } else { - $chunkdispatch::SingleChunkNullCheck($single_chunk_null_ident::new( - self, - )) - } - } - _ => { - if self.null_count() == 0 { - $chunkdispatch::ManyChunk($many_chunk_ident::new(self)) - } else { - $chunkdispatch::ManyChunkNullCheck($many_chunk_null_ident::new(self)) - } - } - } - } - } - }; -} - -// Used for macro. method_name is ignored -fn return_from_bool_iter(_method_name: &str, v: bool) -> bool { - v -} - -impl_iterator_traits!( - BooleanChunked, - BooleanArray, - BooleanIterCont, - BooleanIterSingleChunk, - BooleanIterSingleChunkNullCheck, - BooleanIterManyChunk, - BooleanIterManyChunkNullCheck, - BooleanIterDispatch, - Option, - bool, - return_from_bool_iter, -); - -// used for macro -fn return_from_list_iter(method_name: &str, v: ArrayRef) -> Series { - (method_name, v).into() -} - -impl_iterator_traits!( - LargeListChunked, - LargeListArray, - ListIterCont, - ListIterSingleChunk, - ListIterSingleChunkNullCheck, - ListIterManyChunk, - ListIterManyChunkNullCheck, - ListIterDispatch, - Option, - Series, - return_from_list_iter, -); - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn out_of_bounds() { - let mut a = UInt32Chunked::new_from_slice("a", &[1, 2, 3]); - let b = UInt32Chunked::new_from_slice("a", &[1, 2, 3]); - a.append(&b); - - let v = a.into_iter().collect::>(); - assert_eq!( - vec![Some(1u32), Some(2), Some(3), Some(1), Some(2), Some(3)], - v - ) - } - - #[test] - fn test_iter_numitersinglechunknullcheck() { - let a = UInt32Chunked::new_from_opt_slice("a", &[Some(1), None, Some(3)]); - let mut it = a.into_iter(); - - // normal iterator - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(None)); - assert_eq!(it.next(), Some(Some(3))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), Some(Some(1))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(None)); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_numitersinglechunk() { - let a = UInt32Chunked::new_from_slice("a", &[1, 2, 3]); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(Some(2))); - assert_eq!(it.next(), Some(Some(3))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(Some(2))); - assert_eq!(it.next_back(), Some(Some(1))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(Some(2))); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(Some(2))); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_numitermanychunk() { - let mut a = UInt32Chunked::new_from_slice("a", &[1, 2]); - let a_b = UInt32Chunked::new_from_slice("", &[3]); - a.append(&a_b); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(Some(2))); - assert_eq!(it.next(), Some(Some(3))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(Some(2))); - assert_eq!(it.next_back(), Some(Some(1))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(Some(2))); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(Some(2))); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_numitermanychunknullcheck() { - let mut a = UInt32Chunked::new_from_opt_slice("a", &[Some(1), None]); - let a_b = UInt32Chunked::new_from_opt_slice("", &[Some(3)]); - a.append(&a_b); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(None)); - assert_eq!(it.next(), Some(Some(3))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), Some(Some(1))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next(), Some(None)); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(1))); - assert_eq!(it.next_back(), Some(Some(3))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_utf8itersinglechunknullcheck() { - let a = Utf8Chunked::new_from_opt_slice("a", &[Some("a"), None, Some("c")]); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(None)); - assert_eq!(it.next(), Some(Some("c"))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), Some(Some("a"))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(None)); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_utf8itersinglechunk() { - let a = Utf8Chunked::new_from_slice("a", &["a", "b", "c"]); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(Some("b"))); - assert_eq!(it.next(), Some(Some("c"))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(Some("b"))); - assert_eq!(it.next_back(), Some(Some("a"))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(Some("b"))); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(Some("b"))); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_utf8itermanychunk() { - let mut a = Utf8Chunked::new_from_slice("a", &["a", "b"]); - let a_b = Utf8Chunked::new_from_slice("", &["c"]); - a.append(&a_b); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(Some("b"))); - assert_eq!(it.next(), Some(Some("c"))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(Some("b"))); - assert_eq!(it.next_back(), Some(Some("a"))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(Some("b"))); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(Some("b"))); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_utf8itermanychunknullcheck() { - let mut a = Utf8Chunked::new_from_opt_slice("a", &[Some("a"), None]); - let a_b = Utf8Chunked::new_from_opt_slice("", &[Some("c")]); - a.append(&a_b); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(None)); - assert_eq!(it.next(), Some(Some("c"))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), Some(Some("a"))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next(), Some(None)); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some("a"))); - assert_eq!(it.next_back(), Some(Some("c"))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_boolitersinglechunknullcheck() { - let a = BooleanChunked::new_from_opt_slice("", &[Some(true), None, Some(false)]); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(None)); - assert_eq!(it.next(), Some(Some(false))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(None)); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_boolitersinglechunk() { - let a = BooleanChunked::new_from_slice("", &[true, true, false]); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(Some(false))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(Some(true))); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_boolitermanychunk() { - let mut a = BooleanChunked::new_from_slice("", &[true, true]); - let a_b = BooleanChunked::new_from_slice("", &[false]); - a.append(&a_b); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(Some(false))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(Some(true))); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), None); - } - - #[test] - fn test_iter_boolitermanychunknullcheck() { - let mut a = BooleanChunked::new_from_opt_slice("a", &[Some(true), None]); - let a_b = BooleanChunked::new_from_opt_slice("", &[Some(false)]); - a.append(&a_b); - - // normal iterator - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(None)); - assert_eq!(it.next(), Some(Some(false))); - assert_eq!(it.next(), None); - - // reverse iterator - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), Some(Some(true))); - assert_eq!(it.next_back(), None); - - // iterators should not cross - let mut it = a.into_iter(); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next(), Some(None)); - // should stop here as we took this one from the back - assert_eq!(it.next(), None); - - // do the same from the right side - let mut it = a.into_iter(); - assert_eq!(it.next(), Some(Some(true))); - assert_eq!(it.next_back(), Some(Some(false))); - assert_eq!(it.next_back(), Some(None)); - assert_eq!(it.next_back(), None); - } -} diff --git a/polars/src/chunked_array/mod.rs b/polars/src/chunked_array/mod.rs deleted file mode 100644 index 6bdb5082651b..000000000000 --- a/polars/src/chunked_array/mod.rs +++ /dev/null @@ -1,965 +0,0 @@ -//! The typed heart of every Series column. -use crate::chunked_array::builder::{ - aligned_vec_to_primitive_array, build_with_existing_null_bitmap_and_slice, get_bitmap, -}; -use crate::prelude::*; -use arrow::{ - array::{ - ArrayRef, BooleanArray, Date64Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, PrimitiveArray, PrimitiveBuilder, StringArray, - Time64NanosecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - buffer::Buffer, - datatypes::{ArrowPrimitiveType, DateUnit, Field, TimeUnit}, -}; -use itertools::Itertools; -use std::iter::{Copied, Map}; -use std::marker::PhantomData; -use std::sync::Arc; - -pub mod aggregate; -pub mod apply; -pub mod ops; -#[macro_use] -pub mod arithmetic; -pub mod builder; -pub mod cast; -pub mod chunkops; -pub mod comparison; -pub mod iterator; -#[cfg(feature = "ndarray")] -#[doc(cfg(feature = "ndarray"))] -mod ndarray; -#[cfg(feature = "parallel")] -#[doc(cfg(feature = "parallel"))] -pub mod par; -#[cfg(feature = "random")] -#[doc(cfg(feature = "random"))] -mod random; -pub mod set; -pub mod take; -#[cfg(feature = "temporal")] -#[doc(cfg(feature = "temporal"))] -pub mod temporal; -pub mod unique; -pub mod upstream_traits; -use arrow::array::{ - Array, ArrayDataRef, Date32Array, DurationMicrosecondArray, DurationMillisecondArray, - DurationNanosecondArray, DurationSecondArray, IntervalDayTimeArray, IntervalYearMonthArray, - LargeListArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, -}; -use std::mem; - -/// Get a 'hash' of the chunks in order to compare chunk sizes quickly. -fn create_chunk_id(chunks: &Vec) -> Vec { - let mut chunk_id = Vec::with_capacity(chunks.len()); - for a in chunks { - chunk_id.push(a.len()) - } - chunk_id -} - -/// # ChunkedArray -/// -/// Every Series contains a `ChunkedArray`. Unlike Series, ChunkedArray's are typed. This allows -/// us to apply closures to the data and collect the results to a `ChunkedArray` of te same type `T`. -/// Below we use an apply to use the cosine function to the values of a `ChunkedArray`. -/// -/// ```rust -/// # use polars::prelude::*; -/// fn apply_cosine(ca: &Float32Chunked) -> Float32Chunked { -/// ca.apply(|v| v.cos()) -/// } -/// ``` -/// -/// If we would like to cast the result we could use a Rust Iterator instead of an `apply` method. -/// Note that Iterators are slightly slower as the null values aren't ignored implicitly. -/// -/// ```rust -/// # use polars::prelude::*; -/// fn apply_cosine_and_cast(ca: &Float32Chunked) -> Float64Chunked { -/// ca.into_iter() -/// .map(|opt_v| { -/// opt_v.map(|v| v.cos() as f64) -/// }).collect() -/// } -/// ``` -/// -/// Another option is to first cast and then use an apply. -/// -/// ```rust -/// # use polars::prelude::*; -/// fn apply_cosine_and_cast(ca: &Float32Chunked) -> Float64Chunked { -/// ca.cast::() -/// .unwrap() -/// .apply(|v| v.cos()) -/// } -/// ``` -/// -/// ## Conversion between Series and ChunkedArray's -/// Conversion from a `Series` to a `ChunkedArray` is effortless. -/// -/// ```rust -/// # use polars::prelude::*; -/// fn to_chunked_array(series: &Series) -> Result<&Int32Chunked>{ -/// series.i32() -/// } -/// -/// fn to_series(ca: Int32Chunked) -> Series { -/// ca.into_series() -/// } -/// ``` -/// -/// # Iterators -/// -/// `ChunkedArrays` fully support Rust native [Iterator](https://doc.rust-lang.org/std/iter/trait.Iterator.html) -/// and [DoubleEndedIterator](https://doc.rust-lang.org/std/iter/trait.DoubleEndedIterator.html) traits, thereby -/// giving access to all the excelent methods available for [Iterators](https://doc.rust-lang.org/std/iter/trait.Iterator.html). -/// -/// ```rust -/// # use polars::prelude::*; -/// -/// fn iter_forward(ca: &Float32Chunked) { -/// ca.into_iter() -/// .for_each(|opt_v| println!("{:?}", opt_v)) -/// } -/// -/// fn iter_backward(ca: &Float32Chunked) { -/// ca.into_iter() -/// .rev() -/// .for_each(|opt_v| println!("{:?}", opt_v)) -/// } -/// ``` -/// -/// # Memory layout -/// -/// `ChunkedArray`'s use [Apache Arrow](https://github.com/apache/arrow) as backend for the memory layout. -/// Arrows memory is immutable which makes it possible to make mutliple zero copy (sub)-views from a single array. -/// -/// To be able to append data, Polars uses chunks to append new memory locations, hence the `ChunkedArray` data structure. -/// Appends are cheap, because it will not lead to a full reallocation of the whole array (as could be the case with a Rust Vec). -/// -/// However, multiple chunks in a `ChunkArray` will slow down the Iterators, arithmetic and other operations. -/// When multiplying two `ChunkArray'`s with different chunk sizes they cannot utilize [SIMD](https://en.wikipedia.org/wiki/SIMD) for instance. -/// However, when chunk size don't match, Iterators will be used to do the operation (instead of arrows upstream implementation, which may utilize SIMD) and -/// the result will be a single chunked array. -/// -/// **The key takeaway is that by applying operations on a `ChunkArray` of multiple chunks, the results will converge to -/// a `ChunkArray` of a single chunk!** It is recommended to leave them as is. If you want to have predictable performance -/// (no unexpected re-allocation of memory), it is adviced to call the [rechunk](chunked_array/chunkops/trait.ChunkOps.html) after -/// multiple append operations. -pub struct ChunkedArray { - pub(crate) field: Arc, - // For now settle with dynamic generics until we are more confident about the api - pub(crate) chunks: Vec, - // chunk lengths - chunk_id: Vec, - phantom: PhantomData, -} - -impl ChunkedArray { - /// Get Arrow ArrayData - pub fn array_data(&self) -> Vec { - self.chunks.iter().map(|arr| arr.data()).collect() - } - - /// Get the null count and the buffer of bits representing null values - pub fn null_bits(&self) -> Vec<(usize, Option)> { - self.chunks - .iter() - .map(|arr| get_bitmap(arr.as_ref())) - .collect() - } - - /// Series to ChunkedArray - pub fn unpack_series_matching_type(&self, series: &Series) -> Result<&ChunkedArray> { - macro_rules! unpack { - ($variant:ident) => {{ - if let Series::$variant(ca) = series { - let ca = unsafe { mem::transmute::<_, &ChunkedArray>(ca) }; - Ok(ca) - } else { - Err(PolarsError::DataTypeMisMatch) - } - }}; - } - match self.field.data_type() { - ArrowDataType::Utf8 => unpack!(Utf8), - ArrowDataType::Boolean => unpack!(Bool), - ArrowDataType::UInt8 => unpack!(UInt8), - ArrowDataType::UInt16 => unpack!(UInt16), - ArrowDataType::UInt32 => unpack!(UInt32), - ArrowDataType::UInt64 => unpack!(UInt64), - ArrowDataType::Int8 => unpack!(Int8), - ArrowDataType::Int16 => unpack!(Int16), - ArrowDataType::Int32 => unpack!(Int32), - ArrowDataType::Int64 => unpack!(Int64), - ArrowDataType::Float32 => unpack!(Float32), - ArrowDataType::Float64 => unpack!(Float64), - ArrowDataType::Date32(DateUnit::Day) => unpack!(Date32), - ArrowDataType::Date64(DateUnit::Millisecond) => unpack!(Date64), - ArrowDataType::Time32(TimeUnit::Millisecond) => unpack!(Time32Millisecond), - ArrowDataType::Time32(TimeUnit::Second) => unpack!(Time32Second), - ArrowDataType::Time64(TimeUnit::Nanosecond) => unpack!(Time64Nanosecond), - ArrowDataType::Time64(TimeUnit::Microsecond) => unpack!(Time64Microsecond), - ArrowDataType::Interval(IntervalUnit::DayTime) => unpack!(IntervalDayTime), - ArrowDataType::Interval(IntervalUnit::YearMonth) => unpack!(IntervalYearMonth), - ArrowDataType::Duration(TimeUnit::Nanosecond) => unpack!(DurationNanosecond), - ArrowDataType::Duration(TimeUnit::Microsecond) => unpack!(DurationMicrosecond), - ArrowDataType::Duration(TimeUnit::Millisecond) => unpack!(DurationMillisecond), - ArrowDataType::Duration(TimeUnit::Second) => unpack!(DurationSecond), - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => unpack!(TimestampNanosecond), - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => unpack!(TimestampMicrosecond), - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => unpack!(Time32Millisecond), - ArrowDataType::Timestamp(TimeUnit::Second, _) => unpack!(TimestampSecond), - _ => unimplemented!(), - } - } - - /// Combined length of all the chunks. - pub fn len(&self) -> usize { - self.chunks.iter().fold(0, |acc, arr| acc + arr.len()) - } - - /// Unique id representing the number of chunks - pub fn chunk_id(&self) -> &Vec { - &self.chunk_id - } - - /// A reference to the chunks - pub fn chunks(&self) -> &Vec { - &self.chunks - } - - /// Returns true if contains a single chunk and has no null values - pub fn is_optimal_aligned(&self) -> bool { - self.chunks.len() == 1 && self.null_count() == 0 - } - - /// Count the null values. - pub fn null_count(&self) -> usize { - self.chunks.iter().map(|arr| arr.null_count()).sum() - } - - /// Take a view of top n elements - pub fn limit(&self, num_elements: usize) -> Result { - self.slice(0, num_elements) - } - - /// Append arrow array in place. - /// - /// ```rust - /// # use polars::prelude::*; - /// let mut array = Int32Chunked::new_from_slice("array", &[1, 2]); - /// let array_2 = Int32Chunked::new_from_slice("2nd", &[3]); - /// - /// array.append(&array_2); - /// assert_eq!(Vec::from(&array), [Some(1), Some(2), Some(3)]) - /// ``` - pub fn append_array(&mut self, other: ArrayRef) -> Result<()> { - if other.data_type() == self.field.data_type() { - self.chunks.push(other); - self.chunk_id = create_chunk_id(&self.chunks); - Ok(()) - } else { - Err(PolarsError::DataTypeMisMatch) - } - } - - /// Create a new ChunkedArray from self, where the chunks are replaced. - fn copy_with_chunks(&self, chunks: Vec) -> Self { - let chunk_id = create_chunk_id(&chunks); - ChunkedArray { - field: self.field.clone(), - chunks, - chunk_id, - phantom: PhantomData, - } - } - - /// Recompute the chunk_id / chunk_lengths. - fn set_chunk_id(&mut self) { - self.chunk_id = create_chunk_id(&self.chunks) - } - - /// Slice the array. The chunks are reallocated the underlying data slices are zero copy. - pub fn slice(&self, offset: usize, length: usize) -> Result { - if offset + length > self.len() { - return Err(PolarsError::OutOfBounds); - } - let mut remaining_length = length; - let mut remaining_offset = offset; - let mut new_chunks = vec![]; - - for chunk in &self.chunks { - let chunk_len = chunk.len(); - if remaining_offset >= chunk_len { - remaining_offset -= chunk_len; - continue; - } - let take_len; - if remaining_length + remaining_offset > chunk_len { - take_len = chunk_len - remaining_offset; - } else { - take_len = remaining_length; - } - - new_chunks.push(chunk.slice(remaining_offset, take_len)); - remaining_length -= take_len; - remaining_offset = 0; - if remaining_length == 0 { - break; - } - } - Ok(self.copy_with_chunks(new_chunks)) - } - - /// Get a mask of the null values. - pub fn is_null(&self) -> BooleanChunked { - if self.null_count() == 0 { - return BooleanChunked::full("is_null", false, self.len()); - } - let chunks = self - .chunks - .iter() - .map(|arr| { - let mut builder = PrimitiveBuilder::::new(arr.len()); - for i in 0..arr.len() { - builder - .append_value(arr.is_null(i)) - .expect("could not append"); - } - let chunk: ArrayRef = Arc::new(builder.finish()); - chunk - }) - .collect_vec(); - BooleanChunked::new_from_chunks("is_null", chunks) - } - - /// Get data type of ChunkedArray. - pub fn dtype(&self) -> &ArrowDataType { - self.field.data_type() - } - - /// Get the index of the chunk and the index of the value in that chunk - #[inline] - pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) { - if self.chunk_id().len() == 1 { - return (0, index); - } - let mut index_remainder = index; - let mut current_chunk_idx = 0; - - for chunk in &self.chunks { - let chunk_len = chunk.len(); - if chunk_len - 1 >= index_remainder { - break; - } else { - index_remainder -= chunk_len; - current_chunk_idx += 1; - } - } - (current_chunk_idx, index_remainder) - } - - /// Get the head of the ChunkedArray - pub fn head(&self, length: Option) -> Self { - let res_ca = match length { - Some(len) => self.slice(0, std::cmp::min(len, self.len())), - None => self.slice(0, std::cmp::min(10, self.len())), - }; - res_ca.unwrap() - } - - /// Get the tail of the ChunkedArray - pub fn tail(&self, length: Option) -> Self { - let len = match length { - Some(len) => std::cmp::min(len, self.len()), - None => std::cmp::min(10, self.len()), - }; - self.slice(self.len() - len, len).unwrap() - } - - /// Append in place. - pub fn append(&mut self, other: &Self) - where - Self: std::marker::Sized, - { - self.chunks.extend(other.chunks.clone()) - } - - /// Name of the ChunkedArray. - pub fn name(&self) -> &str { - self.field.name() - } - - /// Get a reference to the field. - pub fn ref_field(&self) -> &Field { - &self.field - } - - /// Rename this ChunkedArray. - pub fn rename(&mut self, name: &str) { - self.field = Arc::new(Field::new( - name, - self.field.data_type().clone(), - self.field.is_nullable(), - )) - } -} - -impl ChunkedArray -where - T: PolarsDataType, -{ - /// Create a new ChunkedArray from existing chunks. - pub fn new_from_chunks(name: &str, chunks: Vec) -> Self { - let field = Arc::new(Field::new(name, T::get_data_type(), true)); - let chunk_id = create_chunk_id(&chunks); - ChunkedArray { - field, - chunks, - chunk_id, - phantom: PhantomData, - } - } - - /// Get a single value. Beware this is slow. (only used for formatting) - pub(crate) fn get_any(&self, index: usize) -> AnyType { - let (chunk_idx, idx) = self.index_to_chunked_index(index); - let arr = &self.chunks[chunk_idx]; - - if arr.is_null(idx) { - return AnyType::Null; - } - - macro_rules! downcast_and_pack { - ($casttype:ident, $variant:ident) => {{ - let arr = arr - .as_any() - .downcast_ref::<$casttype>() - .expect("could not downcast one of the chunks"); - let v = arr.value(idx); - AnyType::$variant(v) - }}; - } - macro_rules! downcast { - ($casttype:ident) => {{ - let arr = arr - .as_any() - .downcast_ref::<$casttype>() - .expect("could not downcast one of the chunks"); - arr.value(idx) - }}; - } - // TODO: insert types - match T::get_data_type() { - ArrowDataType::Utf8 => downcast_and_pack!(StringArray, Utf8), - ArrowDataType::Boolean => downcast_and_pack!(BooleanArray, Boolean), - ArrowDataType::UInt8 => downcast_and_pack!(UInt8Array, UInt8), - ArrowDataType::UInt16 => downcast_and_pack!(UInt16Array, UInt16), - ArrowDataType::UInt32 => downcast_and_pack!(UInt32Array, UInt32), - ArrowDataType::UInt64 => downcast_and_pack!(UInt64Array, UInt64), - ArrowDataType::Int8 => downcast_and_pack!(Int8Array, Int8), - ArrowDataType::Int16 => downcast_and_pack!(Int16Array, Int16), - ArrowDataType::Int32 => downcast_and_pack!(Int32Array, Int32), - ArrowDataType::Int64 => downcast_and_pack!(Int64Array, Int64), - ArrowDataType::Float32 => downcast_and_pack!(Float32Array, Float32), - ArrowDataType::Float64 => downcast_and_pack!(Float64Array, Float64), - ArrowDataType::Date32(DateUnit::Day) => downcast_and_pack!(Date32Array, Date32), - ArrowDataType::Date64(DateUnit::Millisecond) => downcast_and_pack!(Date64Array, Date64), - ArrowDataType::Time32(TimeUnit::Millisecond) => { - let v = downcast!(Time32MillisecondArray); - AnyType::Time32(v, TimeUnit::Millisecond) - } - ArrowDataType::Time32(TimeUnit::Second) => { - let v = downcast!(Time32SecondArray); - AnyType::Time32(v, TimeUnit::Second) - } - ArrowDataType::Time64(TimeUnit::Nanosecond) => { - let v = downcast!(Time64NanosecondArray); - AnyType::Time64(v, TimeUnit::Nanosecond) - } - ArrowDataType::Time64(TimeUnit::Microsecond) => { - let v = downcast!(Time64MicrosecondArray); - AnyType::Time64(v, TimeUnit::Microsecond) - } - ArrowDataType::Interval(IntervalUnit::DayTime) => { - downcast_and_pack!(IntervalDayTimeArray, IntervalDayTime) - } - ArrowDataType::Interval(IntervalUnit::YearMonth) => { - downcast_and_pack!(IntervalYearMonthArray, IntervalYearMonth) - } - ArrowDataType::Duration(TimeUnit::Nanosecond) => { - let v = downcast!(DurationNanosecondArray); - AnyType::Duration(v, TimeUnit::Nanosecond) - } - ArrowDataType::Duration(TimeUnit::Microsecond) => { - let v = downcast!(DurationMicrosecondArray); - AnyType::Duration(v, TimeUnit::Microsecond) - } - ArrowDataType::Duration(TimeUnit::Millisecond) => { - let v = downcast!(DurationMillisecondArray); - AnyType::Duration(v, TimeUnit::Millisecond) - } - ArrowDataType::Duration(TimeUnit::Second) => { - let v = downcast!(DurationSecondArray); - AnyType::Duration(v, TimeUnit::Second) - } - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => { - let v = downcast!(TimestampNanosecondArray); - AnyType::TimeStamp(v, TimeUnit::Nanosecond) - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => { - let v = downcast!(TimestampMicrosecondArray); - AnyType::TimeStamp(v, TimeUnit::Microsecond) - } - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => { - let v = downcast!(TimestampMillisecondArray); - AnyType::TimeStamp(v, TimeUnit::Millisecond) - } - ArrowDataType::Timestamp(TimeUnit::Second, _) => { - let v = downcast!(TimestampSecondArray); - AnyType::TimeStamp(v, TimeUnit::Second) - } - ArrowDataType::LargeList(_) => { - let v = downcast!(LargeListArray); - AnyType::LargeList(("", v).into()) - } - _ => unimplemented!(), - } - } -} - -impl Utf8Chunked { - #[deprecated(since = "3.1", note = "Use `new_from_slice`")] - pub fn new_utf8_from_slice>(name: &str, v: &[S]) -> Self { - Utf8Chunked::new_from_slice(name, v) - } - - #[deprecated(since = "3.1", note = "Use `new_from_opt_slice`")] - pub fn new_utf8_from_opt_slice>(name: &str, opt_v: &[Option]) -> Self { - Utf8Chunked::new_from_opt_slice(name, opt_v) - } -} - -impl ChunkedArray -where - T: ArrowPrimitiveType, -{ - /// Create a new ChunkedArray by taking ownershipt of the AlignedVec. This operation is zero copy. - pub fn new_from_aligned_vec(name: &str, v: AlignedVec) -> Self { - let arr = aligned_vec_to_primitive_array::(v, None, 0); - Self::new_from_chunks(name, vec![Arc::new(arr)]) - } - - /// Nullify values in slice with an existing null bitmap - pub fn new_with_null_bitmap( - name: &str, - values: &[T::Native], - buffer: Option, - null_count: usize, - ) -> Self { - let len = values.len(); - let arr = Arc::new(build_with_existing_null_bitmap_and_slice::( - buffer, null_count, values, - )); - ChunkedArray { - field: Arc::new(Field::new(name, T::get_data_type(), true)), - chunks: vec![arr], - chunk_id: vec![len], - phantom: PhantomData, - } - } - - /// Nullify values in slice with an existing null bitmap - pub fn new_from_owned_with_null_bitmap( - name: &str, - values: AlignedVec, - buffer: Option, - null_count: usize, - ) -> Self { - let len = values.0.len(); - let arr = Arc::new(aligned_vec_to_primitive_array::( - values, buffer, null_count, - )); - ChunkedArray { - field: Arc::new(Field::new(name, T::get_data_type(), true)), - chunks: vec![arr], - chunk_id: vec![len], - phantom: PhantomData, - } - } -} - -impl ChunkedArray -where - T: PolarsNumericType, -{ - /// Contiguous slice - pub fn cont_slice(&self) -> Result<&[T::Native]> { - if self.chunks.len() == 1 && self.chunks[0].null_count() == 0 { - Ok(self.downcast_chunks()[0].value_slice(0, self.len())) - } else { - Err(PolarsError::NoSlice) - } - } - - /// Get slices of the underlying arrow data. - /// NOTE: null values should be taken into account by the user of these slices as they are handled - /// separately - pub fn data_views(&self) -> Vec<&[T::Native]> { - self.downcast_chunks() - .iter() - .map(|arr| arr.value_slice(0, arr.len())) - .collect() - } - - /// Rechunk and return a ptr to the start of the array - pub fn as_single_ptr(&mut self) -> usize { - let mut ca = self.rechunk(None).expect("should not fail"); - mem::swap(&mut ca, self); - let a = self.data_views()[0]; - let ptr = a.as_ptr(); - ptr as usize - } - - /// If [cont_slice](#method.cont_slice) is successful a closure is mapped over the elements. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn multiply(ca: &UInt32Chunked) -> Result { - /// let mapped = ca.map(|v| v * 2)?; - /// Ok(mapped.collect()) - /// } - /// ``` - pub fn map(&self, f: F) -> Result>, F>> - where - F: Fn(T::Native) -> B, - { - let slice = self.cont_slice()?; - Ok(slice.iter().copied().map(f)) - } - - /// If [cont_slice](#method.cont_slice) fails we can fallback on an iterator with null checks - /// and map a closure over the elements. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// use itertools::Itertools; - /// fn multiply(ca: &UInt32Chunked) -> Series { - /// let mapped_result = ca.map(|v| v * 2); - /// - /// if let Ok(mapped) = mapped_result { - /// mapped.collect() - /// } else { - /// ca - /// .map_null_checks(|opt_v| opt_v.map(|v |v * 2)).collect() - /// } - /// } - /// ``` - pub fn map_null_checks(&self, f: F) -> Map, F> - where - F: Fn(Option) -> B, - { - self.into_iter().map(f) - } - - /// If [cont_slice](#method.cont_slice) is successful a closure can be applied as aggregation - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn compute_sum(ca: &UInt32Chunked) -> Result { - /// ca.fold(0, |acc, value| acc + value) - /// } - /// ``` - pub fn fold(&self, init: B, f: F) -> Result - where - F: Fn(B, T::Native) -> B, - { - let slice = self.cont_slice()?; - Ok(slice.iter().copied().fold(init, f)) - } - - /// If [cont_slice](#method.cont_slice) fails we can fallback on an iterator with null checks - /// and a closure for aggregation - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn compute_sum(ca: &UInt32Chunked) -> u32 { - /// match ca.fold(0, |acc, value| acc + value) { - /// // faster sum without null checks was successful - /// Ok(sum) => sum, - /// // Null values or multiple chunks in ChunkedArray, we need to do more bounds checking - /// Err(_) => ca.fold_null_checks(0, |acc, opt_value| { - /// match opt_value { - /// Some(v) => acc + v, - /// None => acc - /// } - /// }) - /// } - /// } - /// ``` - pub fn fold_null_checks(&self, init: B, f: F) -> B - where - F: Fn(B, Option) -> B, - { - self.into_iter().fold(init, f) - } -} - -impl LargeListChunked { - pub(crate) fn get_inner_dtype(&self) -> &Box { - match self.dtype() { - ArrowDataType::LargeList(dt) => dt, - _ => panic!("should not happen"), - } - } -} - -impl Clone for ChunkedArray { - fn clone(&self) -> Self { - ChunkedArray { - field: self.field.clone(), - chunks: self.chunks.clone(), - chunk_id: self.chunk_id.clone(), - phantom: PhantomData, - } - } -} - -pub trait Downcast { - fn downcast_chunks(&self) -> Vec<&T>; -} - -impl Downcast> for ChunkedArray -where - T: PolarsNumericType, -{ - fn downcast_chunks(&self) -> Vec<&PrimitiveArray> { - self.chunks - .iter() - .map(|arr| { - let arr = &**arr; - unsafe { &*(arr as *const dyn Array as *const PrimitiveArray) } - }) - .collect::>() - } -} - -impl Downcast for Utf8Chunked { - fn downcast_chunks(&self) -> Vec<&StringArray> { - self.chunks - .iter() - .map(|arr| { - let arr = &**arr; - unsafe { &*(arr as *const dyn Array as *const StringArray) } - }) - .collect::>() - } -} - -impl Downcast for BooleanChunked { - fn downcast_chunks(&self) -> Vec<&BooleanArray> { - self.chunks - .iter() - .map(|arr| { - let arr = &**arr; - unsafe { &*(arr as *const dyn Array as *const BooleanArray) } - }) - .collect::>() - } -} - -impl Downcast for LargeListChunked { - fn downcast_chunks(&self) -> Vec<&LargeListArray> { - self.chunks - .iter() - .map(|arr| { - let arr = &**arr; - unsafe { &*(arr as *const dyn Array as *const LargeListArray) } - }) - .collect::>() - } -} - -impl AsRef> for ChunkedArray { - fn as_ref(&self) -> &ChunkedArray { - self - } -} - -pub struct NoNull(pub T); - -#[cfg(test)] -pub(crate) mod test { - use crate::prelude::*; - - pub(crate) fn get_chunked_array() -> Int32Chunked { - ChunkedArray::new_from_slice("a", &[1, 2, 3]) - } - - #[test] - fn test_sort() { - let a = Int32Chunked::new_from_slice("a", &[1, 9, 3, 2]); - let b = a - .sort(false) - .into_iter() - .map(|opt| opt.unwrap()) - .collect::>(); - assert_eq!(b, [1, 2, 3, 9]); - let a = Utf8Chunked::new_from_slice("a", &["b", "a", "c"]); - let a = a.sort(false); - let b = a.into_iter().collect::>(); - assert_eq!(b, [Some("a"), Some("b"), Some("c")]); - } - - #[test] - fn arithmetic() { - let s1 = get_chunked_array(); - println!("{:?}", s1.chunks); - let s2 = &s1.clone(); - let s1 = &s1; - println!("{:?}", s1 + s2); - println!("{:?}", s1 - s2); - println!("{:?}", s1 * s2); - } - - #[test] - fn iter() { - let s1 = get_chunked_array(); - // sum - assert_eq!(s1.into_iter().fold(0, |acc, val| { acc + val.unwrap() }), 6) - } - - #[test] - fn limit() { - let a = get_chunked_array(); - let b = a.limit(2).unwrap(); - println!("{:?}", b); - assert_eq!(b.len(), 2) - } - - #[test] - fn filter() { - let a = get_chunked_array(); - let b = a - .filter(&BooleanChunked::new_from_slice( - "filter", - &[true, false, false], - )) - .unwrap(); - assert_eq!(b.len(), 1); - assert_eq!(b.into_iter().next(), Some(Some(1))); - } - - #[test] - fn aggregates_numeric() { - let a = get_chunked_array(); - assert_eq!(a.max(), Some(3)); - assert_eq!(a.min(), Some(1)); - assert_eq!(a.sum(), Some(6)) - } - - #[test] - fn take() { - let a = get_chunked_array(); - let new = a.take([0u32, 1].as_ref().as_take_iter(), None).unwrap(); - assert_eq!(new.len(), 2) - } - - #[test] - fn get() { - let mut a = get_chunked_array(); - assert_eq!(AnyType::Int32(2), a.get_any(1)); - // check if chunks indexes are properly determined - a.append_array(a.chunks[0].clone()).unwrap(); - assert_eq!(AnyType::Int32(1), a.get_any(3)); - } - - #[test] - fn cast() { - let a = get_chunked_array(); - let b = a.cast::().unwrap(); - assert_eq!(b.field.data_type(), &ArrowDataType::Int64) - } - - fn assert_slice_equal(ca: &ChunkedArray, eq: &[T::Native]) - where - ChunkedArray: ChunkOps, - T: PolarsNumericType, - { - assert_eq!( - ca.into_iter().map(|opt| opt.unwrap()).collect::>(), - eq - ) - } - - #[test] - fn slice() { - let mut first = UInt32Chunked::new_from_slice("first", &[0, 1, 2]); - let second = UInt32Chunked::new_from_slice("second", &[3, 4, 5]); - first.append(&second); - assert_slice_equal(&first.slice(0, 3).unwrap(), &[0, 1, 2]); - assert_slice_equal(&first.slice(0, 4).unwrap(), &[0, 1, 2, 3]); - assert_slice_equal(&first.slice(1, 4).unwrap(), &[1, 2, 3, 4]); - assert_slice_equal(&first.slice(3, 2).unwrap(), &[3, 4]); - assert_slice_equal(&first.slice(3, 3).unwrap(), &[3, 4, 5]); - assert!(first.slice(3, 4).is_err()); - } - - #[test] - fn sorting() { - let s = UInt32Chunked::new_from_slice("", &[9, 2, 4]); - let sorted = s.sort(false); - assert_slice_equal(&sorted, &[2, 4, 9]); - let sorted = s.sort(true); - assert_slice_equal(&sorted, &[9, 4, 2]); - - let s: Utf8Chunked = ["b", "a", "z"].iter().collect(); - let sorted = s.sort(false); - assert_eq!( - sorted.into_iter().collect::>(), - &[Some("a"), Some("b"), Some("z")] - ); - let sorted = s.sort(true); - assert_eq!( - sorted.into_iter().collect::>(), - &[Some("z"), Some("b"), Some("a")] - ); - let s: Utf8Chunked = [Some("b"), None, Some("z")].iter().collect(); - let sorted = s.sort(false); - assert_eq!( - sorted.into_iter().collect::>(), - &[None, Some("b"), Some("z")] - ); - } - - #[test] - fn reverse() { - let s = UInt32Chunked::new_from_slice("", &[1, 2, 3]); - // path with continuous slice - assert_slice_equal(&s.reverse(), &[3, 2, 1]); - // path with options - let s = UInt32Chunked::new_from_opt_slice("", &[Some(1), None, Some(3)]); - assert_eq!(Vec::from(&s.reverse()), &[Some(3), None, Some(1)]); - let s = BooleanChunked::new_from_slice("", &[true, false]); - assert_eq!(Vec::from(&s.reverse()), &[Some(false), Some(true)]); - - let s = Utf8Chunked::new_from_slice("", &["a", "b", "c"]); - assert_eq!(Vec::from(&s.reverse()), &[Some("c"), Some("b"), Some("a")]); - - let s = Utf8Chunked::new_from_opt_slice("", &[Some("a"), None, Some("c")]); - assert_eq!(Vec::from(&s.reverse()), &[Some("c"), None, Some("a")]); - } -} diff --git a/polars/src/chunked_array/ndarray.rs b/polars/src/chunked_array/ndarray.rs deleted file mode 100644 index 868c6a4ebe4c..000000000000 --- a/polars/src/chunked_array/ndarray.rs +++ /dev/null @@ -1,102 +0,0 @@ -use crate::prelude::*; -use ndarray::prelude::*; - -impl ChunkedArray -where - T: PolarsNumericType, -{ - /// If data is aligned in a single chunk and has no Null values a zero copy view is returned - /// as an `ndarray` - pub fn to_ndarray(&self) -> Result> { - let slice = self.cont_slice()?; - Ok(aview1(slice)) - } -} - -impl LargeListChunked { - /// If all nested `Series` have the same length, a 2 dimensional `ndarray::Array` is returned. - pub fn to_ndarray(&self) -> Result> - where - N: PolarsNumericType, - { - if self.null_count() != 0 { - return Err(PolarsError::HasNullValues); - } else { - let mut iter = self.into_no_null_iter(); - - let mut ndarray; - let width; - - // first iteration determine the size - if let Some(series) = iter.next() { - width = series.len(); - - ndarray = unsafe { Array::uninitialized((self.len(), series.len())) }; - - let series = series.cast::()?; - let ca = series.unpack::()?; - let a = ca.to_ndarray()?; - let mut row = ndarray.slice_mut(s![0, ..]); - row.assign(&a); - - while let Some(series) = iter.next() { - if series.len() != width { - return Err(PolarsError::ShapeMisMatch); - } - let series = series.cast::()?; - let ca = series.unpack::()?; - let a = ca.to_ndarray()?; - let mut row = ndarray.slice_mut(s![0, ..]); - row.assign(&a) - } - Ok(ndarray) - } else { - Err(PolarsError::NoData) - } - } - } -} - -impl DataFrame { - /// Create a 2D `ndarray::Array` from this `DataFrame`. This requires all columns in the - /// `DataFrame` to be non-null and numeric. They will be casted to the same data type - /// (if they aren't already). - /// - /// ```rust - /// use polars::prelude::*; - /// let a = UInt32Chunked::new_from_slice("a", &[1, 2, 3]).into_series(); - /// let b = Float64Chunked::new_from_slice("b", &[10., 8., 6.]).into_series(); - /// - /// let df = DataFrame::new(vec![a, b]).unwrap(); - /// let ndarray = df.to_ndarray::().unwrap(); - /// println!("{:?}", ndarray); - /// ``` - /// Outputs: - /// ```text - /// [[1.0, 10.0], - /// [2.0, 8.0], - /// [3.0, 6.0]], shape=[3, 2], strides=[2, 1], layout=C (0x1), const ndim=2/ - /// ``` - pub fn to_ndarray(&self) -> Result> - where - N: PolarsNumericType, - N::Native: num::Zero + Copy, - { - let mut ndarr = Array2::zeros(self.shape()); - for (col_idx, series) in self.get_columns().iter().enumerate() { - if series.null_count() != 0 { - return Err(PolarsError::HasNullValues); - } - // this is an Arc clone if already of type N - let series = series.cast::()?; - let ca = series.unpack::()?; - - ca.into_no_null_iter() - .enumerate() - .for_each(|(row_idx, val)| { - *&mut ndarr[[row_idx, col_idx]] = val; - }) - } - Ok(ndarr) - } -} diff --git a/polars/src/chunked_array/ops.rs b/polars/src/chunked_array/ops.rs deleted file mode 100644 index 4ad6efaa4287..000000000000 --- a/polars/src/chunked_array/ops.rs +++ /dev/null @@ -1,879 +0,0 @@ -//! Traits for miscellaneous operations on ChunkedArray -use crate::chunked_array::builder::get_large_list_builder; -use crate::prelude::*; -use crate::utils::Xob; -use arrow::compute; -use itertools::Itertools; -use num::{Num, NumCast}; -use std::cmp::Ordering; -use std::marker::Sized; -use std::ops::{Add, Div}; - -/// Random access -pub trait TakeRandom { - type Item; - - /// Get a nullable value by index. - fn get(&self, index: usize) -> Option; - - /// Get a value by index and ignore the null bit. - unsafe fn get_unchecked(&self, index: usize) -> Self::Item; -} -// Utility trait because associated type needs a lifetime -pub trait TakeRandomUtf8 { - type Item; - - /// Get a nullable value by index. - fn get(self, index: usize) -> Option; - - /// Get a value by index and ignore the null bit. - unsafe fn get_unchecked(self, index: usize) -> Self::Item; -} - -/// Fast access by index. -pub trait ChunkTake { - /// Take values from ChunkedArray by index. - fn take(&self, indices: impl Iterator, capacity: Option) -> Result - where - Self: std::marker::Sized; - - /// Take values from ChunkedArray by index without checking bounds. - unsafe fn take_unchecked( - &self, - indices: impl Iterator, - capacity: Option, - ) -> Self - where - Self: std::marker::Sized; - - /// Take values from ChunkedArray by Option. - fn take_opt( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Result - where - Self: std::marker::Sized; - - /// Take values from ChunkedArray by Option. - unsafe fn take_opt_unchecked( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Self - where - Self: std::marker::Sized; -} - -/// Create a `ChunkedArray` with new values by index or by boolean mask. -/// Note that these operations clone data. This is however the only way we can modify at mask or -/// index level as the underlying Arrow arrays are immutable. -pub trait ChunkSet<'a, A, B> { - /// Set the values at indexes `idx` to some optional value `Option`. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - /// let new = ca.set_at_idx(&[0, 1], Some(10)).unwrap(); - /// - /// assert_eq!(Vec::from(&new), &[Some(10), Some(10), Some(3)]); - /// ``` - fn set_at_idx(&'a self, idx: &T, opt_value: Option) -> Result - where - Self: Sized; - - /// Set the values at indexes `idx` by applying a closure to these values. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - /// let new = ca.set_at_idx_with(&[0, 1], |opt_v| opt_v.map(|v| v - 5)).unwrap(); - /// - /// assert_eq!(Vec::from(&new), &[Some(-4), Some(-3), Some(3)]); - /// ``` - fn set_at_idx_with(&'a self, idx: &T, f: F) -> Result - where - Self: Sized, - F: Fn(Option) -> Option; - /// Set the values where the mask evaluates to `true` to some optional value `Option`. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - /// let mask = BooleanChunked::new_from_slice("mask", &[false, true, false]); - /// let new = ca.set(&mask, Some(5)).unwrap(); - /// assert_eq!(Vec::from(&new), &[Some(1), Some(5), Some(3)]); - /// ``` - fn set(&'a self, mask: &BooleanChunked, opt_value: Option) -> Result - where - Self: Sized; - - /// Set the values where the mask evaluates to `true` by applying a closure to these values. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - /// let mask = BooleanChunked::new_from_slice("mask", &[false, true, false]); - /// let new = ca.set_with(&mask, |opt_v| opt_v.map( - /// |v| v * 2 - /// )).unwrap(); - /// assert_eq!(Vec::from(&new), &[Some(1), Some(4), Some(3)]); - /// ``` - fn set_with(&'a self, mask: &BooleanChunked, f: F) -> Result - where - Self: Sized, - F: Fn(Option) -> Option; -} - -/// Cast `ChunkedArray` to `ChunkedArray` -pub trait ChunkCast { - /// Cast `ChunkedArray` to `ChunkedArray` - fn cast(&self) -> Result> - where - N: PolarsDataType; -} - -/// Fastest way to do elementwise operations on a ChunkedArray -pub trait ChunkApply<'a, A, B> { - /// Apply a closure `F` elementwise. - fn apply(&'a self, f: F) -> Self - where - F: Fn(A) -> B + Copy; -} - -/// Aggregation operations -pub trait ChunkAgg { - /// Returns `None` if the array is empty or only contains null values. - fn sum(&self) -> Option; - fn min(&self) -> Option; - /// Returns the maximum value in the array, according to the natural order. - /// Returns an option because the array is nullable. - fn max(&self) -> Option; - - /// Returns the mean value in the array. - /// Returns an option because the array is nullable. - fn mean(&self) -> Option; - - /// Returns the mean value in the array. - /// Returns an option because the array is nullable. - fn median(&self) -> Option; -} - -/// Compare [Series](series/series/enum.Series.html) -/// and [ChunkedArray](series/chunked_array/struct.ChunkedArray.html)'s and get a `boolean` mask that -/// can be used to filter rows. -/// -/// # Example -/// -/// ``` -/// use polars::prelude::*; -/// fn filter_all_ones(df: &DataFrame) -> Result { -/// let mask = df -/// .column("column_a")? -/// .eq(1); -/// -/// df.filter(&mask) -/// } -/// ``` -pub trait ChunkCompare { - /// Check for equality and regard missing values as equal. - fn eq_missing(&self, rhs: Rhs) -> BooleanChunked; - - /// Check for equality. - fn eq(&self, rhs: Rhs) -> BooleanChunked; - - /// Check for inequality. - fn neq(&self, rhs: Rhs) -> BooleanChunked; - - /// Greater than comparison. - fn gt(&self, rhs: Rhs) -> BooleanChunked; - - /// Greater than or equal comparison. - fn gt_eq(&self, rhs: Rhs) -> BooleanChunked; - - /// Less than comparison. - fn lt(&self, rhs: Rhs) -> BooleanChunked; - - /// Less than or equal comparison - fn lt_eq(&self, rhs: Rhs) -> BooleanChunked; -} - -/// Get unique values in a `ChunkedArray` -pub trait ChunkUnique { - // We don't return Self to be able to use AutoRef specialization - /// Get unique values of a ChunkedArray - fn unique(&self) -> ChunkedArray; - - /// Get first index of the unique values in a `ChunkedArray`. - fn arg_unique(&self) -> Vec; - - /// Number of unique values in the `ChunkedArray` - fn n_unique(&self) -> usize { - self.arg_unique().len() - } -} - -/// Sort operations on `ChunkedArray`. -pub trait ChunkSort { - /// Returned a sorted `ChunkedArray`. - fn sort(&self, reverse: bool) -> ChunkedArray; - - /// Sort this array in place. - fn sort_in_place(&mut self, reverse: bool); - - /// Retrieve the indexes needed to sort this array. - fn argsort(&self, reverse: bool) -> Vec; -} - -fn sort_partial(a: &Option, b: &Option) -> Ordering { - match (a, b) { - (Some(a), Some(b)) => a.partial_cmp(b).expect("could not compare"), - (None, Some(_)) => Ordering::Less, - (Some(_), None) => Ordering::Greater, - (None, None) => Ordering::Equal, - } -} - -impl ChunkSort for ChunkedArray -where - T: PolarsNumericType, - T::Native: std::cmp::PartialOrd, -{ - fn sort(&self, reverse: bool) -> ChunkedArray { - if reverse { - self.into_iter() - .sorted_by(|a, b| sort_partial(b, a)) - .collect() - } else { - self.into_iter() - .sorted_by(|a, b| sort_partial(a, b)) - .collect() - } - } - - fn sort_in_place(&mut self, reverse: bool) { - let sorted = self.sort(reverse); - self.chunks = sorted.chunks; - } - - fn argsort(&self, reverse: bool) -> Vec { - if reverse { - self.into_iter() - .enumerate() - .sorted_by(|(_idx_a, a), (_idx_b, b)| sort_partial(b, a)) - .map(|(idx, _v)| idx) - .collect::>() - .0 - } else { - self.into_iter() - .enumerate() - .sorted_by(|(_idx_a, a), (_idx_b, b)| sort_partial(a, b)) - .map(|(idx, _v)| idx) - .collect::>() - .0 - } - } -} - -macro_rules! argsort { - ($self:ident, $closure:expr) => {{ - $self - .into_iter() - .enumerate() - .sorted_by($closure) - .map(|(idx, _v)| idx) - .collect::>() - .0 - }}; -} - -macro_rules! sort { - ($self:ident, $reverse:ident) => {{ - if $reverse { - $self.into_iter().sorted_by(|a, b| b.cmp(a)).collect() - } else { - $self.into_iter().sorted_by(|a, b| a.cmp(b)).collect() - } - }}; -} - -impl ChunkSort for Utf8Chunked { - fn sort(&self, reverse: bool) -> Utf8Chunked { - sort!(self, reverse) - } - - fn sort_in_place(&mut self, reverse: bool) { - let sorted = self.sort(reverse); - self.chunks = sorted.chunks; - } - - fn argsort(&self, reverse: bool) -> Vec { - if reverse { - argsort!(self, |(_idx_a, a), (_idx_b, b)| b.cmp(a)) - } else { - argsort!(self, |(_idx_a, a), (_idx_b, b)| a.cmp(b)) - } - } -} - -impl ChunkSort for LargeListChunked { - fn sort(&self, _reverse: bool) -> Self { - println!("A ListChunked cannot be sorted. Doing nothing"); - self.clone() - } - - fn sort_in_place(&mut self, _reverse: bool) { - println!("A ListChunked cannot be sorted. Doing nothing"); - } - - fn argsort(&self, _reverse: bool) -> Vec { - println!("A ListChunked cannot be sorted. Doing nothing"); - (0..self.len()).collect() - } -} - -impl ChunkSort for BooleanChunked { - fn sort(&self, reverse: bool) -> BooleanChunked { - sort!(self, reverse) - } - - fn sort_in_place(&mut self, reverse: bool) { - let sorted = self.sort(reverse); - self.chunks = sorted.chunks; - } - - fn argsort(&self, reverse: bool) -> Vec { - if reverse { - argsort!(self, |(_idx_a, a), (_idx_b, b)| b.cmp(a)) - } else { - argsort!(self, |(_idx_a, a), (_idx_b, b)| a.cmp(b)) - } - } -} - -#[derive(Copy, Clone)] -pub enum FillNoneStrategy { - Backward, - Forward, - Mean, - Min, - Max, -} - -/// Replace None values with various strategies -pub trait ChunkFillNone { - /// Replace None values with one of the following strategies: - /// * Forward fill (replace None with the previous value) - /// * Backward fill (replace None with the next value) - /// * Mean fill (replace None with the mean of the whole array) - /// * Min fill (replace None with the minimum of the whole array) - /// * Max fill (replace None with the maximum of the whole array) - fn fill_none(&self, strategy: FillNoneStrategy) -> Result - where - Self: Sized; - - /// Replace None values with a give value `T`. - fn fill_none_with_value(&self, value: T) -> Result - where - Self: Sized; -} - -fn fill_forward(ca: &ChunkedArray) -> ChunkedArray -where - T: PolarsNumericType, -{ - ca.into_iter() - .scan(None, |previous, opt_v| { - let val = match opt_v { - Some(_) => Some(opt_v), - None => Some(*previous), - }; - *previous = opt_v; - val - }) - .collect() -} - -macro_rules! impl_fill_forward { - ($ca:ident) => {{ - let ca = $ca - .into_iter() - .scan(None, |previous, opt_v| { - let val = match opt_v { - Some(_) => Some(opt_v), - None => Some(*previous), - }; - *previous = opt_v; - val - }) - .collect(); - Ok(ca) - }}; -} - -fn fill_backward(ca: &ChunkedArray) -> ChunkedArray -where - T: PolarsNumericType, -{ - let mut iter = ca.into_iter().peekable(); - - let mut builder = PrimitiveChunkedBuilder::::new(ca.name(), ca.len()); - while let Some(opt_v) = iter.next() { - match opt_v { - Some(v) => builder.append_value(v), - None => { - match iter.peek() { - // end of iterator - None => builder.append_null(), - Some(opt_v) => builder.append_option(*opt_v), - } - } - } - } - builder.finish() -} - -macro_rules! impl_fill_backward { - ($ca:ident, $builder:ident) => {{ - let mut iter = $ca.into_iter().peekable(); - - while let Some(opt_v) = iter.next() { - match opt_v { - Some(v) => $builder.append_value(v), - None => { - match iter.peek() { - // end of iterator - None => $builder.append_null(), - Some(opt_v) => $builder.append_option(*opt_v), - } - } - } - } - Ok($builder.finish()) - }}; -} - -fn fill_value(ca: &ChunkedArray, value: Option) -> ChunkedArray -where - T: PolarsNumericType, -{ - ca.into_iter() - .map(|opt_v| match opt_v { - Some(_) => opt_v, - None => value, - }) - .collect() -} - -macro_rules! impl_fill_value { - ($ca:ident, $value:expr) => {{ - $ca.into_iter() - .map(|opt_v| match opt_v { - Some(_) => opt_v, - None => $value, - }) - .collect() - }}; -} - -impl ChunkFillNone for ChunkedArray -where - T: PolarsNumericType, - T::Native: Add + PartialOrd + Div + Num + NumCast, -{ - fn fill_none(&self, strategy: FillNoneStrategy) -> Result { - // nothing to fill - if self.null_count() == 0 { - return Ok(self.clone()); - } - let ca = match strategy { - FillNoneStrategy::Forward => fill_forward(self), - FillNoneStrategy::Backward => fill_backward(self), - FillNoneStrategy::Min => impl_fill_value!(self, self.min()), - FillNoneStrategy::Max => impl_fill_value!(self, self.max()), - FillNoneStrategy::Mean => impl_fill_value!(self, self.mean()), - }; - Ok(ca) - } - fn fill_none_with_value(&self, value: T::Native) -> Result { - Ok(impl_fill_value!(self, Some(value))) - } -} - -impl ChunkFillNone for BooleanChunked { - fn fill_none(&self, strategy: FillNoneStrategy) -> Result { - // nothing to fill - if self.null_count() == 0 { - return Ok(self.clone()); - } - let mut builder = PrimitiveChunkedBuilder::::new(self.name(), self.len()); - match strategy { - FillNoneStrategy::Forward => impl_fill_forward!(self), - FillNoneStrategy::Backward => impl_fill_backward!(self, builder), - FillNoneStrategy::Min => Ok(impl_fill_value!(self, self.min().map(|v| v != 0))), - FillNoneStrategy::Max => Ok(impl_fill_value!(self, self.max().map(|v| v != 0))), - FillNoneStrategy::Mean => Ok(impl_fill_value!(self, self.mean().map(|v| v != 0))), - } - } - - fn fill_none_with_value(&self, value: bool) -> Result { - Ok(impl_fill_value!(self, Some(value))) - } -} - -impl ChunkFillNone<&str> for Utf8Chunked { - fn fill_none(&self, strategy: FillNoneStrategy) -> Result { - // nothing to fill - if self.null_count() == 0 { - return Ok(self.clone()); - } - let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len()); - match strategy { - FillNoneStrategy::Forward => impl_fill_forward!(self), - FillNoneStrategy::Backward => impl_fill_backward!(self, builder), - _ => Err(PolarsError::InvalidOperation), - } - } - - fn fill_none_with_value(&self, value: &str) -> Result { - Ok(impl_fill_value!(self, Some(value))) - } -} - -impl ChunkFillNone<&Series> for LargeListChunked { - fn fill_none(&self, _strategy: FillNoneStrategy) -> Result { - Err(PolarsError::InvalidOperation) - } - fn fill_none_with_value(&self, _value: &Series) -> Result { - Err(PolarsError::InvalidOperation) - } -} - -/// Fill a ChunkedArray with one value. -pub trait ChunkFull { - /// Create a ChunkedArray with a single value. - fn full(name: &str, value: T, length: usize) -> Self - where - Self: std::marker::Sized; -} - -impl ChunkFull for ChunkedArray -where - T: ArrowPrimitiveType, -{ - fn full(name: &str, value: T::Native, length: usize) -> Self - where - T::Native: Copy, - { - let mut builder = PrimitiveChunkedBuilder::new(name, length); - - for _ in 0..length { - builder.append_value(value) - } - builder.finish() - } -} - -impl<'a> ChunkFull<&'a str> for Utf8Chunked { - fn full(name: &str, value: &'a str, length: usize) -> Self { - let mut builder = Utf8ChunkedBuilder::new(name, length); - - for _ in 0..length { - builder.append_value(value); - } - builder.finish() - } -} - -/// Reverse a ChunkedArray -pub trait ChunkReverse { - /// Return a reversed version of this array. - fn reverse(&self) -> ChunkedArray; -} - -impl ChunkReverse for ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: ChunkOps, -{ - fn reverse(&self) -> ChunkedArray { - if let Ok(slice) = self.cont_slice() { - let ca: Xob> = slice.iter().rev().copied().collect(); - let mut ca = ca.into_inner(); - ca.rename(self.name()); - ca - } else { - self.take((0..self.len()).rev(), None) - .expect("implementation error, should not fail") - } - } -} - -macro_rules! impl_reverse { - ($arrow_type:ident, $ca_type:ident) => { - impl ChunkReverse<$arrow_type> for $ca_type { - fn reverse(&self) -> Self { - self.take((0..self.len()).rev(), None) - .expect("implementation error, should not fail") - } - } - }; -} - -impl_reverse!(BooleanType, BooleanChunked); -impl_reverse!(Utf8Type, Utf8Chunked); -impl_reverse!(LargeListType, LargeListChunked); - -/// Filter values by a boolean mask. -pub trait ChunkFilter { - /// Filter values in the ChunkedArray with a boolean mask. - /// - /// ```rust - /// # use polars::prelude::*; - /// let array = Int32Chunked::new_from_slice("array", &[1, 2, 3]); - /// let mask = BooleanChunked::new_from_slice("mask", &[true, false, true]); - /// - /// let filtered = array.filter(&mask).unwrap(); - /// assert_eq!(Vec::from(&filtered), [Some(1), Some(3)]) - /// ``` - fn filter(&self, filter: &BooleanChunked) -> Result> - where - Self: Sized; -} - -impl ChunkFilter for ChunkedArray -where - T: PolarsSingleType, - ChunkedArray: ChunkOps, -{ - fn filter(&self, filter: &BooleanChunked) -> Result> { - let opt = self.optional_rechunk(filter)?; - let left = opt.as_ref().or(Some(self)).unwrap(); - let chunks = left - .chunks - .iter() - .zip(&filter.downcast_chunks()) - .map(|(arr, &fil)| compute::filter(&*(arr.clone()), fil)) - .collect::, arrow::error::ArrowError>>()?; - - Ok(self.copy_with_chunks(chunks)) - } -} - -impl ChunkFilter for LargeListChunked { - fn filter(&self, filter: &BooleanChunked) -> Result { - let dt = self.get_inner_dtype(); - let mut builder = get_large_list_builder(dt, self.len(), self.name()); - filter - .into_iter() - .zip(self.into_iter()) - .for_each(|(opt_bool_val, opt_series)| { - let bool_val = opt_bool_val.unwrap_or(false); - let opt_val = match bool_val { - true => opt_series, - false => None, - }; - builder.append_opt_series(&opt_val) - }); - Ok(builder.finish()) - } -} - -/// Shift the values of a ChunkedArray by a number of periods. -pub trait ChunkShift { - /// Shift the values by a given period and fill the parts that will be empty due to this operation - /// with `fill_value`. - fn shift(&self, periods: i32, fill_value: &Option) -> Result>; -} - -fn chunk_shift_helper( - ca: &ChunkedArray, - builder: &mut PrimitiveChunkedBuilder, - amount: usize, -) where - T: PolarsNumericType, - T::Native: Copy, -{ - match ca.cont_slice() { - // fast path - Ok(slice) => slice - .iter() - .take(amount) - .for_each(|v| builder.append_value(*v)), - // slower path - _ => { - ca.into_iter() - .take(amount) - .for_each(|opt| builder.append_option(opt)); - } - } -} - -impl ChunkShift for ChunkedArray -where - T: PolarsNumericType, - T::Native: Copy, -{ - fn shift(&self, periods: i32, fill_value: &Option) -> Result> { - if periods.abs() >= self.len() as i32 { - return Err(PolarsError::OutOfBounds); - } - let mut builder = PrimitiveChunkedBuilder::::new(self.name(), self.len()); - let amount = self.len() - periods.abs() as usize; - - // Fill the front of the array - if periods > 0 { - for _ in 0..periods { - builder.append_option(*fill_value) - } - chunk_shift_helper(self, &mut builder, amount); - // Fill the back of the array - } else { - chunk_shift_helper(self, &mut builder, amount); - for _ in 0..periods.abs() { - builder.append_option(*fill_value) - } - } - Ok(builder.finish()) - } -} - -macro_rules! impl_shift { - // append_method and append_fn do almost the same. Only for largelist type, the closure - // accepts an owned value, while fill_value is a reference. That's why we have two options. - ($self:ident, $builder:ident, $periods:ident, $fill_value:ident, - $append_method:ident, $append_fn:expr) => {{ - let amount = $self.len() - $periods.abs() as usize; - - // Fill the front of the array - if $periods > 0 { - for _ in 0..$periods { - $builder.$append_method($fill_value) - } - $self - .into_iter() - .take(amount) - .for_each(|opt| $append_fn(&mut $builder, opt)); - // Fill the back of the array - } else { - $self - .into_iter() - .take(amount) - .for_each(|opt| $append_fn(&mut $builder, opt)); - for _ in 0..$periods.abs() { - $builder.$append_method($fill_value) - } - } - Ok($builder.finish()) - }}; -} - -impl ChunkShift for BooleanChunked { - fn shift(&self, periods: i32, fill_value: &Option) -> Result { - if periods.abs() >= self.len() as i32 { - return Err(PolarsError::OutOfBounds); - } - let mut builder = PrimitiveChunkedBuilder::::new(self.name(), self.len()); - - fn append_fn(builder: &mut PrimitiveChunkedBuilder, v: Option) { - builder.append_option(v); - } - let fill_value = *fill_value; - - impl_shift!(self, builder, periods, fill_value, append_option, append_fn) - } -} - -impl ChunkShift for Utf8Chunked { - fn shift(&self, periods: i32, fill_value: &Option<&str>) -> Result { - if periods.abs() >= self.len() as i32 { - return Err(PolarsError::OutOfBounds); - } - let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len()); - fn append_fn(builder: &mut Utf8ChunkedBuilder, v: Option<&str>) { - builder.append_option(v); - } - let fill_value = *fill_value; - - impl_shift!(self, builder, periods, fill_value, append_option, append_fn) - } -} - -impl ChunkShift for LargeListChunked { - fn shift(&self, periods: i32, fill_value: &Option) -> Result { - if periods.abs() >= self.len() as i32 { - return Err(PolarsError::OutOfBounds); - } - let dt = self.get_inner_dtype(); - let mut builder = get_large_list_builder(dt, self.len(), self.name()); - fn append_fn(builder: &mut Box, v: Option) { - builder.append_opt_series(&v); - } - - impl_shift!( - self, - builder, - periods, - fill_value, - append_opt_series, - append_fn - ) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_shift() { - let ca = Int32Chunked::new_from_slice("", &[1, 2, 3]); - let shifted = ca.shift(1, &Some(0)).unwrap(); - assert_eq!(shifted.cont_slice().unwrap(), &[0, 1, 2]); - let shifted = ca.shift(1, &None).unwrap(); - assert_eq!(Vec::from(&shifted), &[None, Some(1), Some(2)]); - let shifted = ca.shift(-1, &None).unwrap(); - assert_eq!(Vec::from(&shifted), &[Some(1), Some(2), None]); - assert!(ca.shift(3, &None).is_err()); - } - - #[test] - fn test_fill_none() { - let ca = - Int32Chunked::new_from_opt_slice("", &[None, Some(2), Some(3), None, Some(4), None]); - let filled = ca.fill_none(FillNoneStrategy::Forward).unwrap(); - assert_eq!( - Vec::from(&filled), - &[None, Some(2), Some(3), Some(3), Some(4), Some(4)] - ); - let filled = ca.fill_none(FillNoneStrategy::Backward).unwrap(); - assert_eq!( - Vec::from(&filled), - &[Some(2), Some(2), Some(3), Some(4), Some(4), None] - ); - let filled = ca.fill_none(FillNoneStrategy::Min).unwrap(); - assert_eq!( - Vec::from(&filled), - &[Some(2), Some(2), Some(3), Some(2), Some(4), Some(2)] - ); - let filled = ca.fill_none_with_value(10).unwrap(); - assert_eq!( - Vec::from(&filled), - &[Some(10), Some(2), Some(3), Some(10), Some(4), Some(10)] - ); - let filled = ca.fill_none(FillNoneStrategy::Mean).unwrap(); - assert_eq!( - Vec::from(&filled), - &[Some(3), Some(2), Some(3), Some(3), Some(4), Some(3)] - ); - println!("{:?}", filled); - } -} diff --git a/polars/src/chunked_array/par/mod.rs b/polars/src/chunked_array/par/mod.rs deleted file mode 100644 index ec02418e128a..000000000000 --- a/polars/src/chunked_array/par/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod utf8; diff --git a/polars/src/chunked_array/par/utf8.rs b/polars/src/chunked_array/par/utf8.rs deleted file mode 100644 index 6d6d273f2ba9..000000000000 --- a/polars/src/chunked_array/par/utf8.rs +++ /dev/null @@ -1,236 +0,0 @@ -use crate::prelude::*; -use arrow::array::StringArray; -use rayon::iter::plumbing::*; -use rayon::iter::plumbing::{Consumer, ProducerCallback}; -use rayon::prelude::*; -use std::{mem, ops::Range}; - -#[derive(Debug, Clone)] -pub struct Utf8IntoIter<'a> { - ca: &'a Utf8Chunked, -} - -impl<'a> IntoParallelIterator for &'a Utf8Chunked { - type Iter = Utf8IntoIter<'a>; - type Item = Option<&'a str>; - - fn into_par_iter(self) -> Self::Iter { - Utf8IntoIter { ca: self } - } -} -impl<'a> ParallelIterator for Utf8IntoIter<'a> { - type Item = Option<&'a str>; - - fn drive_unindexed(self, consumer: C) -> C::Result - where - C: UnindexedConsumer, - { - bridge(self, consumer) - } - - fn opt_len(&self) -> Option { - Some(self.ca.len()) - } -} -impl<'a> IndexedParallelIterator for Utf8IntoIter<'a> { - fn len(&self) -> usize { - self.ca.len() - } - - fn drive(self, consumer: C) -> C::Result - where - C: Consumer, - { - bridge(self, consumer) - } - - fn with_producer(self, callback: CB) -> CB::Output - where - CB: ProducerCallback, - { - callback.callback(Utf8Producer { - ca: &self.ca, - offset: 0, - len: self.ca.len(), - }) - } -} - -struct Utf8Producer<'a> { - ca: &'a Utf8Chunked, - offset: usize, - len: usize, -} - -impl<'a> Producer for Utf8Producer<'a> { - type Item = Option<&'a str>; - type IntoIter = Utf8Iter<'a>; - - fn into_iter(self) -> Self::IntoIter { - let iter = (0..self.len).into_iter(); - Utf8Iter { ca: self.ca, iter } - } - - fn split_at(self, index: usize) -> (Self, Self) { - ( - Utf8Producer { - ca: self.ca, - offset: self.offset, - len: index + 1, - }, - Utf8Producer { - ca: self.ca, - offset: self.offset + index, - len: self.len - index, - }, - ) - } -} - -struct Utf8Iter<'a> { - ca: &'a Utf8Chunked, - iter: Range, -} - -impl<'a> Iterator for Utf8Iter<'a> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - self.iter.next().map(|idx| unsafe { - mem::transmute::, Option<&'a str>>(self.ca.get(idx)) - }) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.ca.len(); - (len, Some(len)) - } -} - -impl<'a> DoubleEndedIterator for Utf8Iter<'a> { - fn next_back(&mut self) -> Option { - self.iter.next_back().map(|idx| unsafe { - mem::transmute::, Option<&'a str>>(self.ca.get(idx)) - }) - } -} - -impl<'a> ExactSizeIterator for Utf8Iter<'a> {} - -/// No null Iterators - -#[derive(Debug, Clone)] -pub struct Utf8IntoIterCont<'a> { - ca: &'a Utf8Chunked, -} - -impl<'a> IntoParallelIterator for NoNull<&'a Utf8Chunked> { - type Iter = Utf8IntoIterCont<'a>; - type Item = &'a str; - - fn into_par_iter(self) -> Self::Iter { - Utf8IntoIterCont { ca: self.0 } - } -} -impl<'a> ParallelIterator for Utf8IntoIterCont<'a> { - type Item = &'a str; - - fn drive_unindexed(self, consumer: C) -> C::Result - where - C: UnindexedConsumer, - { - bridge(self, consumer) - } - - fn opt_len(&self) -> Option { - Some(self.ca.len()) - } -} -impl<'a> IndexedParallelIterator for Utf8IntoIterCont<'a> { - fn len(&self) -> usize { - self.ca.len() - } - - fn drive(self, consumer: C) -> C::Result - where - C: Consumer, - { - bridge(self, consumer) - } - - fn with_producer(self, callback: CB) -> CB::Output - where - CB: ProducerCallback, - { - callback.callback(Utf8ProducerCont { - arr: self.ca.downcast_chunks()[0], - offset: 0, - len: self.ca.len(), - }) - } -} - -struct Utf8ProducerCont<'a> { - arr: &'a StringArray, - offset: usize, - len: usize, -} - -impl<'a> Producer for Utf8ProducerCont<'a> { - type Item = &'a str; - type IntoIter = Utf8IterCont<'a>; - - fn into_iter(self) -> Self::IntoIter { - let iter = (0..self.len).into_iter(); - Utf8IterCont { - arr: self.arr, - iter, - offset: self.offset, - } - } - - fn split_at(self, index: usize) -> (Self, Self) { - ( - Utf8ProducerCont { - arr: self.arr, - offset: self.offset, - len: index + 1, - }, - Utf8ProducerCont { - arr: self.arr, - offset: self.offset + index, - len: self.len - index, - }, - ) - } -} - -struct Utf8IterCont<'a> { - arr: &'a StringArray, - iter: Range, - offset: usize, -} - -impl<'a> Iterator for Utf8IterCont<'a> { - type Item = &'a str; - - fn next(&mut self) -> Option { - self.iter.next().map(|idx| unsafe { - mem::transmute::<&'_ str, &'a str>(self.arr.value(idx + self.offset)) - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl<'a> DoubleEndedIterator for Utf8IterCont<'a> { - fn next_back(&mut self) -> Option { - self.iter.next_back().map(|idx| unsafe { - mem::transmute::<&'_ str, &'a str>(self.arr.value(idx + self.offset)) - }) - } -} - -impl<'a> ExactSizeIterator for Utf8IterCont<'a> {} diff --git a/polars/src/chunked_array/random.rs b/polars/src/chunked_array/random.rs deleted file mode 100644 index 29664b6ab215..000000000000 --- a/polars/src/chunked_array/random.rs +++ /dev/null @@ -1,65 +0,0 @@ -use crate::prelude::*; -use num::{Float, NumCast}; -use rand::distributions::Bernoulli; -use rand::prelude::*; -use rand_distr::{Distribution, Normal, StandardNormal, Uniform}; - -impl ChunkedArray -where - T: PolarsNumericType, - T::Native: Float + NumCast, -{ - /// Create `ChunkedArray` with samples from a Normal distribution. - pub fn rand_normal(name: &str, length: usize, mean: f64, std_dev: f64) -> Result { - let normal = match Normal::new(mean, std_dev) { - Ok(dist) => dist, - Err(e) => return Err(PolarsError::RandError(format!("{:?}", e))), - }; - let mut builder = PrimitiveChunkedBuilder::::new(name, length); - for _ in 0..length { - let smpl = normal.sample(&mut rand::thread_rng()); - let smpl = NumCast::from(smpl).unwrap(); - builder.append_value(smpl) - } - Ok(builder.finish()) - } - - /// Create `ChunkedArray` with samples from a Standard Normal distribution. - pub fn rand_standard_normal(name: &str, length: usize) -> Self { - let mut builder = PrimitiveChunkedBuilder::::new(name, length); - for _ in 0..length { - let smpl: f64 = thread_rng().sample(StandardNormal); - let smpl = NumCast::from(smpl).unwrap(); - builder.append_value(smpl) - } - builder.finish() - } - - /// Create `ChunkedArray` with samples from a Uniform distribution. - pub fn rand_uniform(name: &str, length: usize, low: f64, high: f64) -> Self { - let uniform = Uniform::new(low, high); - let mut builder = PrimitiveChunkedBuilder::::new(name, length); - for _ in 0..length { - let smpl = uniform.sample(&mut rand::thread_rng()); - let smpl = NumCast::from(smpl).unwrap(); - builder.append_value(smpl) - } - builder.finish() - } -} - -impl BooleanChunked { - /// Create `ChunkedArray` with samples from a Bernoulli distribution. - pub fn rand_bernoulli(name: &str, length: usize, p: f64) -> Result { - let dist = match Bernoulli::new(p) { - Ok(dist) => dist, - Err(e) => return Err(PolarsError::RandError(format!("{:?}", e))), - }; - let mut builder = BooleanChunkedBuilder::new(name, length); - for _ in 0..length { - let smpl = dist.sample(&mut rand::thread_rng()); - builder.append_value(smpl) - } - Ok(builder.finish()) - } -} diff --git a/polars/src/chunked_array/set.rs b/polars/src/chunked_array/set.rs deleted file mode 100644 index 31526b4dcb6a..000000000000 --- a/polars/src/chunked_array/set.rs +++ /dev/null @@ -1,284 +0,0 @@ -use crate::prelude::*; -use crate::utils::Xob; - -macro_rules! impl_set_at_idx_with { - ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{ - let mut idx_iter = $idx.as_take_iter(); - let mut ca_iter = $self.into_iter().enumerate(); - - while let Some(current_idx) = idx_iter.next() { - if current_idx > $self.len() { - return Err(PolarsError::OutOfBounds); - } - while let Some((cnt_idx, opt_val)) = ca_iter.next() { - if cnt_idx == current_idx { - $builder.append_option($f(opt_val)); - break; - } else { - $builder.append_option(opt_val); - } - } - } - // the last idx is probably not the last value so we finish the iterator - while let Some((_, opt_val)) = ca_iter.next() { - $builder.append_option(opt_val); - } - - let ca = $builder.finish(); - Ok(ca) - }}; -} - -macro_rules! check_bounds { - ($self:ident, $mask:ident) => {{ - if $self.len() != $mask.len() { - return Err(PolarsError::ShapeMisMatch); - } - }}; -} - -impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray -where - T: ArrowPrimitiveType, - &'a ChunkedArray: - IntoNoNullIterator + IntoIterator>, - T::Native: Copy, -{ - fn set_at_idx(&'a self, idx: &I, value: Option) -> Result { - self.set_at_idx_with(idx, |_| value) - } - - fn set_at_idx_with(&'a self, idx: &I, f: F) -> Result - where - F: Fn(Option) -> Option, - { - // TODO: implement fast path - let mut builder = PrimitiveChunkedBuilder::::new(self.name(), self.len()); - impl_set_at_idx_with!(self, builder, idx, f) - } - - fn set(&'a self, mask: &BooleanChunked, value: Option) -> Result { - // all the code does practically the same but has different fast paths. We do this - // again because we don't need the return option in all cases. Otherwise we would have called - // set_with - match value { - Some(new_value) => { - // fast path because self has no nulls - let mut ca = if self.is_optimal_aligned() { - // fast path because mask has no nulls - if mask.is_optimal_aligned() { - let ca: Xob<_> = self - .into_no_null_iter() - .zip(mask.into_no_null_iter()) - .map(|(val, mask)| match mask { - true => new_value, - false => val, - }) - .collect(); - ca.into_inner() - // slower path, mask has null values - } else { - let ca: Xob<_> = self - .into_no_null_iter() - .zip(mask) - .map(|(val, opt_mask)| match opt_mask { - None => val, - Some(true) => new_value, - Some(false) => val, - }) - .collect(); - ca.into_inner() - } - } else { - // mask has no nulls, self has - if mask.is_optimal_aligned() { - self.into_iter() - .zip(mask.into_no_null_iter()) - .map(|(opt_val, mask)| match mask { - true => Some(new_value), - false => opt_val, - }) - .collect() - } else { - // slowest path, mask and self have null values - self.into_iter() - .zip(mask) - .map(|(opt_val, opt_mask)| match opt_mask { - None => opt_val, - Some(true) => Some(new_value), - Some(false) => opt_val, - }) - .collect() - } - }; - ca.rename(self.name()); - Ok(ca) - } - None => self.set_with(mask, |_| value), - } - } - - fn set_with(&'a self, mask: &BooleanChunked, f: F) -> Result - where - F: Fn(Option) -> Option, - { - check_bounds!(self, mask); - - // fast path because self has no nulls - let mut ca: ChunkedArray = if self.is_optimal_aligned() { - // fast path because mask has no nulls - if mask.is_optimal_aligned() { - self.into_no_null_iter() - .zip(mask.into_no_null_iter()) - .map(|(val, mask)| match mask { - true => f(Some(val)), - false => Some(val), - }) - .collect() - - // slower path, mask has null values - } else { - self.into_no_null_iter() - .zip(mask) - .map(|(val, opt_mask)| match opt_mask { - None => Some(val), - Some(true) => f(Some(val)), - Some(false) => Some(val), - }) - .collect() - } - } else { - // mask has no nulls, self has - if mask.is_optimal_aligned() { - self.into_iter() - .zip(mask.into_no_null_iter()) - .map(|(opt_val, mask)| match mask { - true => f(opt_val), - false => opt_val, - }) - .collect() - } else { - // slowest path, mask and self have null values - self.into_iter() - .zip(mask) - .map(|(opt_val, opt_mask)| match opt_mask { - None => opt_val, - Some(true) => f(opt_val), - Some(false) => opt_val, - }) - .collect() - } - }; - - ca.rename(self.name()); - Ok(ca) - } -} - -impl<'a> ChunkSet<'a, &'a str, String> for Utf8Chunked { - fn set_at_idx(&'a self, idx: &T, opt_value: Option<&'a str>) -> Result - where - Self: Sized, - { - let mut idx_iter = idx.as_take_iter(); - let mut ca_iter = self.into_iter().enumerate(); - let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len()); - - while let Some(current_idx) = idx_iter.next() { - if current_idx > self.len() { - return Err(PolarsError::OutOfBounds); - } - while let Some((cnt_idx, opt_val_self)) = ca_iter.next() { - if cnt_idx == current_idx { - builder.append_option(opt_value); - break; - } else { - builder.append_option(opt_val_self); - } - } - } - // the last idx is probably not the last value so we finish the iterator - while let Some((_, opt_val_self)) = ca_iter.next() { - builder.append_option(opt_val_self); - } - - let ca = builder.finish(); - Ok(ca) - } - - fn set_at_idx_with(&'a self, idx: &T, f: F) -> Result - where - Self: Sized, - F: Fn(Option<&'a str>) -> Option, - { - let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len()); - impl_set_at_idx_with!(self, builder, idx, f) - } - - fn set(&'a self, mask: &BooleanChunked, opt_value: Option<&'a str>) -> Result - where - Self: Sized, - { - check_bounds!(self, mask); - let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len()); - self.into_iter() - .zip(mask) - .for_each(|(opt_val_self, opt_mask)| match opt_mask { - None => builder.append_option(opt_val_self), - Some(true) => builder.append_option(opt_value), - Some(false) => builder.append_option(opt_val_self), - }); - Ok(builder.finish()) - } - - fn set_with(&'a self, mask: &BooleanChunked, f: F) -> Result - where - Self: Sized, - F: Fn(Option<&'a str>) -> Option, - { - check_bounds!(self, mask); - let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len()); - self.into_iter() - .zip(mask) - .for_each(|(opt_val, opt_mask)| match opt_mask { - None => builder.append_option(opt_val), - Some(true) => builder.append_option(f(opt_val)), - Some(false) => builder.append_option(opt_val), - }); - Ok(builder.finish()) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_set() { - let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - let mask = BooleanChunked::new_from_slice("mask", &[false, true, false]); - let ca = ca.set(&mask, Some(5)).unwrap(); - assert_eq!(Vec::from(&ca), &[Some(1), Some(5), Some(3)]); - let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - let mask = BooleanChunked::new_from_opt_slice("mask", &[None, Some(true), None]); - let ca = ca.set(&mask, Some(5)).unwrap(); - assert_eq!(Vec::from(&ca), &[Some(1), Some(5), Some(3)]); - - let ca = ca.set_at_idx(&[0, 1], Some(10)).unwrap(); - assert_eq!(Vec::from(&ca), &[Some(10), Some(10), Some(3)]); - - assert!(ca.set_at_idx(&[0, 10], Some(0)).is_err()); - - // test booleans - let ca = BooleanChunked::new_from_slice("a", &[true, true, true]); - let mask = BooleanChunked::new_from_slice("mask", &[false, true, false]); - let ca = ca.set(&mask, None).unwrap(); - assert_eq!(Vec::from(&ca), &[Some(true), None, Some(true)]); - - // test utf8 - let ca = Utf8Chunked::new_from_slice("a", &["foo", "foo", "foo"]); - let mask = BooleanChunked::new_from_slice("mask", &[false, true, false]); - let ca = ca.set(&mask, Some("bar")).unwrap(); - assert_eq!(Vec::from(&ca), &[Some("foo"), Some("bar"), Some("foo")]); - } -} diff --git a/polars/src/chunked_array/take.rs b/polars/src/chunked_array/take.rs deleted file mode 100644 index 46ff7c06f73c..000000000000 --- a/polars/src/chunked_array/take.rs +++ /dev/null @@ -1,736 +0,0 @@ -//! Traits to provide fast Random access to ChunkedArrays data. -//! This prevents downcasting every iteration. -//! IntoTakeRandom provides structs that implement the TakeRandom trait. -//! There are several structs that implement the fastest path for random access. -//! -use crate::chunked_array::builder::{ - get_large_list_builder, PrimitiveChunkedBuilder, Utf8ChunkedBuilder, -}; -use crate::prelude::*; -use arrow::array::{ - Array, ArrayRef, BooleanArray, LargeListArray, PrimitiveArray, PrimitiveArrayOps, StringArray, -}; -use std::sync::Arc; - -macro_rules! impl_take_random_get { - ($self:ident, $index:ident, $array_type:ty) => {{ - let (chunk_idx, idx) = $self.index_to_chunked_index($index); - let arr = unsafe { - let arr = $self.chunks.get_unchecked(chunk_idx); - &*(arr as *const ArrayRef as *const Arc<$array_type>) - }; - if arr.is_valid(idx) { - Some(arr.value(idx)) - } else { - None - } - }}; -} - -macro_rules! impl_take_random_get_unchecked { - ($self:ident, $index:ident, $array_type:ty) => {{ - let (chunk_idx, idx) = $self.index_to_chunked_index($index); - let arr = { - let arr = $self.chunks.get_unchecked(chunk_idx); - &*(arr as *const ArrayRef as *const Arc<$array_type>) - }; - arr.value(idx) - }}; -} - -impl TakeRandom for ChunkedArray -where - T: ArrowPrimitiveType, -{ - type Item = T::Native; - - fn get(&self, index: usize) -> Option { - impl_take_random_get!(self, index, PrimitiveArray) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - impl_take_random_get_unchecked!(self, index, PrimitiveArray) - } -} -// extra trait such that it also works without extra reference. -impl<'a> TakeRandomUtf8 for &'a Utf8Chunked { - type Item = &'a str; - - fn get(self, index: usize) -> Option { - impl_take_random_get!(self, index, StringArray) - } - - unsafe fn get_unchecked(self, index: usize) -> Self::Item { - impl_take_random_get_unchecked!(self, index, StringArray) - } -} - -impl TakeRandom for LargeListChunked { - type Item = Series; - - fn get(&self, index: usize) -> Option { - let opt_arr = impl_take_random_get!(self, index, LargeListArray); - opt_arr.map(|arr| (self.name(), arr).into()) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - let arr = impl_take_random_get_unchecked!(self, index, LargeListArray); - (self.name(), arr).into() - } -} - -macro_rules! impl_take { - ($self:ident, $indices:ident, $capacity:ident, $builder:ident) => {{ - let capacity = $capacity.unwrap_or($indices.size_hint().0); - let mut builder = $builder::new($self.name(), capacity); - - let taker = $self.take_rand(); - for idx in $indices { - match taker.get(idx) { - Some(v) => builder.append_value(v), - None => builder.append_null(), - } - } - Ok(builder.finish()) - }}; -} - -macro_rules! impl_take_opt { - ($self:ident, $indices:ident, $capacity:ident, $builder:ident) => {{ - let capacity = $capacity.unwrap_or($indices.size_hint().0); - let mut builder = $builder::new($self.name(), capacity); - let taker = $self.take_rand(); - - for opt_idx in $indices { - match opt_idx { - Some(idx) => match taker.get(idx) { - Some(v) => builder.append_value(v), - None => builder.append_null(), - }, - None => builder.append_null(), - }; - } - Ok(builder.finish()) - }}; -} - -macro_rules! impl_take_opt_unchecked { - ($self:ident, $indices:ident, $capacity:ident, $builder:ident) => {{ - let capacity = $capacity.unwrap_or($indices.size_hint().0); - let mut builder = $builder::new($self.name(), capacity); - let taker = $self.take_rand(); - - for opt_idx in $indices { - match opt_idx { - Some(idx) => { - let v = taker.get_unchecked(idx); - builder.append_value(v); - } - None => builder.append_null(), - }; - } - builder.finish() - }}; -} - -macro_rules! impl_take_unchecked { - ($self:ident, $indices:ident, $capacity:ident, $builder:ident) => {{ - let capacity = $capacity.unwrap_or($indices.size_hint().0); - let mut builder = $builder::new($self.name(), capacity); - - let taker = $self.take_rand(); - for idx in $indices { - let v = taker.get_unchecked(idx); - builder.append_value(v); - } - builder.finish() - }}; -} - -impl ChunkTake for ChunkedArray -where - T: PolarsNumericType, -{ - fn take(&self, indices: impl Iterator, capacity: Option) -> Result { - impl_take!(self, indices, capacity, PrimitiveChunkedBuilder) - } - - unsafe fn take_unchecked( - &self, - indices: impl Iterator, - capacity: Option, - ) -> Self { - impl_take_unchecked!(self, indices, capacity, PrimitiveChunkedBuilder) - } - - fn take_opt( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Result { - impl_take_opt!(self, indices, capacity, PrimitiveChunkedBuilder) - } - - unsafe fn take_opt_unchecked( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Self { - impl_take_opt_unchecked!(self, indices, capacity, PrimitiveChunkedBuilder) - } -} - -impl ChunkTake for BooleanChunked { - fn take(&self, indices: impl Iterator, capacity: Option) -> Result - where - Self: std::marker::Sized, - { - impl_take!(self, indices, capacity, PrimitiveChunkedBuilder) - } - - unsafe fn take_unchecked( - &self, - indices: impl Iterator, - capacity: Option, - ) -> Self { - impl_take_unchecked!(self, indices, capacity, PrimitiveChunkedBuilder) - } - - fn take_opt( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Result { - impl_take_opt!(self, indices, capacity, PrimitiveChunkedBuilder) - } - - unsafe fn take_opt_unchecked( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Self { - impl_take_opt_unchecked!(self, indices, capacity, PrimitiveChunkedBuilder) - } -} - -impl ChunkTake for Utf8Chunked { - fn take(&self, indices: impl Iterator, capacity: Option) -> Result - where - Self: std::marker::Sized, - { - impl_take!(self, indices, capacity, Utf8ChunkedBuilder) - } - - unsafe fn take_unchecked( - &self, - indices: impl Iterator, - capacity: Option, - ) -> Self { - impl_take_unchecked!(self, indices, capacity, Utf8ChunkedBuilder) - } - - fn take_opt( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Result - where - Self: std::marker::Sized, - { - impl_take_opt!(self, indices, capacity, Utf8ChunkedBuilder) - } - - unsafe fn take_opt_unchecked( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Self { - impl_take_opt_unchecked!(self, indices, capacity, Utf8ChunkedBuilder) - } -} - -impl ChunkTake for LargeListChunked { - fn take(&self, indices: impl Iterator, capacity: Option) -> Result { - let capacity = capacity.unwrap_or(indices.size_hint().0); - - match self.dtype() { - ArrowDataType::LargeList(dt) => { - let mut builder = get_large_list_builder(&**dt, capacity, self.name()); - let taker = self.take_rand(); - - for idx in indices { - builder.append_opt_series(&taker.get(idx)); - } - Ok(builder.finish()) - } - _ => unimplemented!(), - } - } - - unsafe fn take_unchecked( - &self, - indices: impl Iterator, - capacity: Option, - ) -> Self { - let capacity = capacity.unwrap_or(indices.size_hint().0); - match self.dtype() { - ArrowDataType::LargeList(dt) => { - let mut builder = get_large_list_builder(&**dt, capacity, self.name()); - let taker = self.take_rand(); - for idx in indices { - let v = taker.get_unchecked(idx); - builder.append_opt_series(&Some(v)); - } - builder.finish() - } - _ => unimplemented!(), - } - } - - fn take_opt( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Result { - let capacity = capacity.unwrap_or(indices.size_hint().0); - - match self.dtype() { - ArrowDataType::LargeList(dt) => { - let mut builder = get_large_list_builder(&**dt, capacity, self.name()); - - let taker = self.take_rand(); - - for opt_idx in indices { - match opt_idx { - Some(idx) => { - let opt_s = taker.get(idx); - builder.append_opt_series(&opt_s) - } - None => builder.append_opt_series(&None), - }; - } - Ok(builder.finish()) - } - _ => unimplemented!(), - } - } - - unsafe fn take_opt_unchecked( - &self, - indices: impl Iterator>, - capacity: Option, - ) -> Self { - let capacity = capacity.unwrap_or(indices.size_hint().0); - - match self.dtype() { - ArrowDataType::LargeList(dt) => { - let mut builder = get_large_list_builder(&**dt, capacity, self.name()); - let taker = self.take_rand(); - - for opt_idx in indices { - match opt_idx { - Some(idx) => { - let s = taker.get_unchecked(idx); - builder.append_opt_series(&Some(s)) - } - None => builder.append_opt_series(&None), - }; - } - builder.finish() - } - _ => unimplemented!(), - } - } -} - -pub trait AsTakeIndex { - fn as_take_iter<'a>(&'a self) -> Box + 'a>; - - fn as_opt_take_iter<'a>(&'a self) -> Box> + 'a> { - unimplemented!() - } - - fn take_index_len(&self) -> usize; -} - -impl AsTakeIndex for &UInt32Chunked { - fn as_take_iter<'a>(&'a self) -> Box + 'a> { - match self.cont_slice() { - Ok(slice) => Box::new(slice.into_iter().map(|&val| val as usize)), - Err(_) => Box::new( - self.into_iter() - .filter_map(|opt_val| opt_val.map(|val| val as usize)), - ), - } - } - fn as_opt_take_iter<'a>(&'a self) -> Box> + 'a> { - Box::new( - self.into_iter() - .map(|opt_val| opt_val.map(|val| val as usize)), - ) - } - fn take_index_len(&self) -> usize { - self.len() - } -} - -impl AsTakeIndex for T -where - T: AsRef<[usize]>, -{ - fn as_take_iter<'a>(&'a self) -> Box + 'a> { - Box::new(self.as_ref().iter().copied()) - } - fn take_index_len(&self) -> usize { - self.as_ref().len() - } -} - -impl AsTakeIndex for [u32] { - fn as_take_iter<'a>(&'a self) -> Box + 'a> { - Box::new(self.iter().map(|&v| v as usize)) - } - fn take_index_len(&self) -> usize { - self.len() - } -} - -/// Create a type that implements a faster `TakeRandom`. -pub trait IntoTakeRandom<'a> { - type Item; - type TakeRandom; - /// Create a type that implements `TakeRandom`. - fn take_rand(&self) -> Self::TakeRandom; -} - -/// Choose the Struct for multiple chunks or the struct for a single chunk. -macro_rules! many_or_single { - ($self:ident, $StructSingle:ident, $StructMany:ident) => {{ - let chunks = $self.downcast_chunks(); - if chunks.len() == 1 { - Box::new($StructSingle { arr: chunks[0] }) - } else { - Box::new($StructMany { - ca: $self, - chunks: chunks, - }) - } - }}; -} - -pub enum NumTakeRandomDispatch<'a, T> -where - T: PolarsNumericType, - T::Native: Copy, -{ - Cont(NumTakeRandomCont<'a, T::Native>), - Single(NumTakeRandomSingleChunk<'a, T>), - Many(NumTakeRandomChunked<'a, T>), -} - -impl<'a, T> TakeRandom for NumTakeRandomDispatch<'a, T> -where - T: PolarsNumericType, - T::Native: Copy, -{ - type Item = T::Native; - - fn get(&self, index: usize) -> Option { - use NumTakeRandomDispatch::*; - match self { - Cont(a) => a.get(index), - Single(a) => a.get(index), - Many(a) => a.get(index), - } - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - use NumTakeRandomDispatch::*; - match self { - Cont(a) => a.get_unchecked(index), - Single(a) => a.get_unchecked(index), - Many(a) => a.get_unchecked(index), - } - } -} - -impl<'a, T> IntoTakeRandom<'a> for &'a ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - type TakeRandom = NumTakeRandomDispatch<'a, T>; - - fn take_rand(&self) -> Self::TakeRandom { - match self.cont_slice() { - Ok(slice) => NumTakeRandomDispatch::Cont(NumTakeRandomCont { slice }), - _ => { - let chunks = self.downcast_chunks(); - if chunks.len() == 1 { - NumTakeRandomDispatch::Single(NumTakeRandomSingleChunk { arr: chunks[0] }) - } else { - NumTakeRandomDispatch::Many(NumTakeRandomChunked { - ca: self, - chunks: chunks, - }) - } - } - } - } -} - -impl<'a> IntoTakeRandom<'a> for &'a Utf8Chunked { - type Item = &'a str; - type TakeRandom = Box + 'a>; - - fn take_rand(&self) -> Self::TakeRandom { - many_or_single!(self, Utf8TakeRandomSingleChunk, Utf8TakeRandom) - } -} - -impl<'a> IntoTakeRandom<'a> for &'a BooleanChunked { - type Item = bool; - type TakeRandom = Box + 'a>; - - fn take_rand(&self) -> Self::TakeRandom { - many_or_single!(self, BoolTakeRandomSingleChunk, BoolTakeRandom) - } -} - -impl<'a> IntoTakeRandom<'a> for &'a LargeListChunked { - type Item = Series; - type TakeRandom = Box + 'a>; - - fn take_rand(&self) -> Self::TakeRandom { - let chunks = self.downcast_chunks(); - if chunks.len() == 1 { - Box::new(ListTakeRandomSingleChunk { - arr: chunks[0], - name: self.name(), - }) - } else { - Box::new(ListTakeRandom { - ca: self, - chunks: chunks, - }) - } - } -} - -pub struct NumTakeRandomChunked<'a, T> -where - T: PolarsNumericType, -{ - ca: &'a ChunkedArray, - chunks: Vec<&'a PrimitiveArray>, -} - -macro_rules! take_random_get { - ($self:ident, $index:ident) => {{ - let (chunk_idx, arr_idx) = $self.ca.index_to_chunked_index($index); - let arr = $self.chunks.get(chunk_idx); - match arr { - Some(arr) => { - if arr.is_null(arr_idx) { - None - } else { - Some(arr.value(arr_idx)) - } - } - None => None, - } - }}; -} - -macro_rules! take_random_get_unchecked { - ($self:ident, $index:ident) => {{ - let (chunk_idx, arr_idx) = $self.ca.index_to_chunked_index($index); - $self.chunks.get_unchecked(chunk_idx).value(arr_idx) - }}; -} - -macro_rules! take_random_get_single { - ($self:ident, $index:ident) => {{ - if $self.arr.is_null($index) { - None - } else { - Some($self.arr.value($index)) - } - }}; -} - -impl<'a, T> TakeRandom for NumTakeRandomChunked<'a, T> -where - T: PolarsNumericType, -{ - type Item = T::Native; - - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - take_random_get_unchecked!(self, index) - } -} - -pub struct NumTakeRandomCont<'a, T> { - slice: &'a [T], -} - -impl<'a, T> TakeRandom for NumTakeRandomCont<'a, T> -where - T: Copy, -{ - type Item = T; - - fn get(&self, index: usize) -> Option { - self.slice.get(index).map(|v| *v) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - *self.slice.get_unchecked(index) - } -} - -pub struct NumTakeRandomSingleChunk<'a, T> -where - T: PolarsNumericType, -{ - arr: &'a PrimitiveArray, -} - -impl<'a, T> TakeRandom for NumTakeRandomSingleChunk<'a, T> -where - T: PolarsNumericType, -{ - type Item = T::Native; - - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - self.arr.value(index) - } -} - -pub struct Utf8TakeRandom<'a> { - ca: &'a Utf8Chunked, - chunks: Vec<&'a StringArray>, -} - -impl<'a> TakeRandom for Utf8TakeRandom<'a> { - type Item = &'a str; - - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - take_random_get_unchecked!(self, index) - } -} - -pub struct Utf8TakeRandomSingleChunk<'a> { - arr: &'a StringArray, -} - -impl<'a> TakeRandom for Utf8TakeRandomSingleChunk<'a> { - type Item = &'a str; - - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - self.arr.value(index) - } -} - -pub struct BoolTakeRandom<'a> { - ca: &'a BooleanChunked, - chunks: Vec<&'a BooleanArray>, -} - -impl<'a> TakeRandom for BoolTakeRandom<'a> { - type Item = bool; - - fn get(&self, index: usize) -> Option { - take_random_get!(self, index) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - take_random_get_unchecked!(self, index) - } -} - -pub struct BoolTakeRandomSingleChunk<'a> { - arr: &'a BooleanArray, -} - -impl<'a> TakeRandom for BoolTakeRandomSingleChunk<'a> { - type Item = bool; - - fn get(&self, index: usize) -> Option { - take_random_get_single!(self, index) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - self.arr.value(index) - } -} -pub struct ListTakeRandom<'a> { - ca: &'a LargeListChunked, - chunks: Vec<&'a LargeListArray>, -} - -impl<'a> TakeRandom for ListTakeRandom<'a> { - type Item = Series; - - fn get(&self, index: usize) -> Option { - let v = take_random_get!(self, index); - v.map(|v| (self.ca.name(), v).into()) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - let v = take_random_get_unchecked!(self, index); - (self.ca.name(), v).into() - } -} - -pub struct ListTakeRandomSingleChunk<'a> { - arr: &'a LargeListArray, - name: &'a str, -} - -impl<'a> TakeRandom for ListTakeRandomSingleChunk<'a> { - type Item = Series; - - fn get(&self, index: usize) -> Option { - let v = take_random_get_single!(self, index); - v.map(|v| (self.name, v).into()) - } - - unsafe fn get_unchecked(&self, index: usize) -> Self::Item { - (self.name, self.arr.value(index)).into() - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_take_random() { - let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - assert_eq!(ca.get(0), Some(1)); - assert_eq!(ca.get(1), Some(2)); - assert_eq!(ca.get(2), Some(3)); - - let ca = Utf8Chunked::new_from_slice("a", &["a", "b", "c"]); - assert_eq!(ca.get(0), Some("a")); - assert_eq!(ca.get(1), Some("b")); - assert_eq!(ca.get(2), Some("c")); - } -} diff --git a/polars/src/chunked_array/temporal.rs b/polars/src/chunked_array/temporal.rs deleted file mode 100644 index 5a79396cebb1..000000000000 --- a/polars/src/chunked_array/temporal.rs +++ /dev/null @@ -1,431 +0,0 @@ -//! Traits and utilities for temporal data. -use crate::prelude::*; -use chrono::{Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; - -// Conversion extracted from: -// https://docs.rs/arrow/1.0.0/src/arrow/array/array.rs.html#589 - -/// Number of seconds in a day -const SECONDS_IN_DAY: i64 = 86_400; -/// Number of milliseconds in a second -const MILLISECONDS_IN_SECOND: i64 = 1_000; -/// Number of microseconds in a second -const MICROSECONDS_IN_SECOND: i64 = 1_000_000; -/// Number of nanoseconds in a second -const NANOSECONDS_IN_SECOND: i64 = 1_000_000_000; - -pub(crate) fn date32_as_datetime(v: i32) -> NaiveDateTime { - NaiveDateTime::from_timestamp(v as i64 * SECONDS_IN_DAY, 0) -} - -pub(crate) fn date64_as_datetime(v: i64) -> NaiveDateTime { - NaiveDateTime::from_timestamp( - // extract seconds from milliseconds - v / MILLISECONDS_IN_SECOND, - // discard extracted seconds and convert milliseconds to nanoseconds - (v % MILLISECONDS_IN_SECOND * MICROSECONDS_IN_SECOND) as u32, - ) -} - -pub(crate) fn timestamp_nanoseconds_as_datetime(v: i64) -> NaiveDateTime { - // some nanoseconds will be truncated down as integer division rounds downwards - let seconds = v / 1_000_000_000; - // we can use that to compute the remaining nanoseconds - let nanoseconds = (v - (seconds * 1_000_000_000)) as u32; - - NaiveDateTime::from_timestamp(seconds, nanoseconds) -} - -pub(crate) fn timestamp_microseconds_as_datetime(v: i64) -> NaiveDateTime { - // see nanoseconds for the logic - let seconds = v / 1_000_000; - let microseconds = (v - (seconds * 1_000_000)) as u32; - - NaiveDateTime::from_timestamp(seconds, microseconds) -} - -pub(crate) fn timestamp_milliseconds_as_datetime(v: i64) -> NaiveDateTime { - // see nanoseconds for the logic - let seconds = v / 1000; - let milliseconds = (v - (seconds * 1000)) as u32; - - NaiveDateTime::from_timestamp(seconds, milliseconds) -} - -pub(crate) fn timestamp_seconds_as_datetime(seconds: i64) -> NaiveDateTime { - NaiveDateTime::from_timestamp(seconds, 0) -} - -// date64 is number of milliseconds since the Unix Epoch -pub(crate) fn naive_datetime_to_date64(v: &NaiveDateTime) -> i64 { - v.timestamp_millis() -} - -pub(crate) fn naive_datetime_to_timestamp_nanoseconds(v: &NaiveDateTime) -> i64 { - v.timestamp_nanos() -} - -pub(crate) fn naive_datetime_to_timestamp_microseconds(v: &NaiveDateTime) -> i64 { - v.timestamp() * 1_000_000 + v.timestamp_subsec_micros() as i64 -} - -pub(crate) fn naive_datetime_to_timestamp_milliseconds(v: &NaiveDateTime) -> i64 { - v.timestamp_millis() -} - -pub(crate) fn naive_datetime_to_timestamp_seconds(v: &NaiveDateTime) -> i64 { - v.timestamp() -} - -pub(crate) fn naive_time_to_time64_nanoseconds(v: &NaiveTime) -> i64 { - // 3600 seconds in an hour - v.hour() as i64 * 3600 * NANOSECONDS_IN_SECOND - // 60 seconds in a minute - + v.minute() as i64 * 60 * NANOSECONDS_IN_SECOND - + v.second() as i64 * NANOSECONDS_IN_SECOND - + v.nanosecond() as i64 -} - -pub(crate) fn naive_time_to_time64_microseconds(v: &NaiveTime) -> i64 { - v.hour() as i64 * 3600 * MICROSECONDS_IN_SECOND - + v.minute() as i64 * 60 * MICROSECONDS_IN_SECOND - + v.second() as i64 * MICROSECONDS_IN_SECOND - + v.nanosecond() as i64 / 1000 -} - -pub(crate) fn naive_time_to_time32_milliseconds(v: &NaiveTime) -> i32 { - v.hour() as i32 * 3600 * MILLISECONDS_IN_SECOND as i32 - + v.minute() as i32 * 60 * MILLISECONDS_IN_SECOND as i32 - + v.second() as i32 * MILLISECONDS_IN_SECOND as i32 - + v.nanosecond() as i32 / 1000_000 -} - -pub(crate) fn naive_time_to_time32_seconds(v: &NaiveTime) -> i32 { - v.hour() as i32 * 3600 + v.minute() as i32 * 60 + v.second() as i32 + v.nanosecond() as i32 -} -pub(crate) fn time64_nanosecond_as_time(v: i64) -> NaiveTime { - NaiveTime::from_num_seconds_from_midnight( - // extract seconds from nanoseconds - (v / NANOSECONDS_IN_SECOND) as u32, - // discard extracted seconds - (v % NANOSECONDS_IN_SECOND) as u32, - ) -} - -pub(crate) fn time64_microsecond_as_time(v: i64) -> NaiveTime { - NaiveTime::from_num_seconds_from_midnight( - // extract seconds from microseconds - (v / MICROSECONDS_IN_SECOND) as u32, - // discard extracted seconds and convert microseconds to - // nanoseconds - (v % MICROSECONDS_IN_SECOND * MILLISECONDS_IN_SECOND) as u32, - ) -} - -pub(crate) fn time32_second_as_time(v: i32) -> NaiveTime { - NaiveTime::from_num_seconds_from_midnight(v as u32, 0) -} - -pub(crate) fn time32_millisecond_as_time(v: i32) -> NaiveTime { - let v = v as u32; - NaiveTime::from_num_seconds_from_midnight( - // extract seconds from milliseconds - v / MILLISECONDS_IN_SECOND as u32, - // discard extracted seconds and convert milliseconds to - // nanoseconds - v % MILLISECONDS_IN_SECOND as u32 * MICROSECONDS_IN_SECOND as u32, - ) -} - -pub fn unix_time() -> NaiveDateTime { - NaiveDateTime::from_timestamp(0, 0) -} - -pub trait FromNaiveTime { - fn new_from_naive_time(name: &str, v: &[N]) -> Self; - - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> Self; -} - -fn parse_naive_time_from_str(s: &str, fmt: &str) -> Option { - NaiveTime::parse_from_str(s, fmt).ok() -} - -macro_rules! impl_from_naive_time { - ($arrowtype:ident, $chunkedtype:ident, $func:ident) => { - impl FromNaiveTime<$arrowtype, NaiveTime> for $chunkedtype { - fn new_from_naive_time(name: &str, v: &[NaiveTime]) -> Self { - let unit = v.iter().map($func).collect::>(); - ChunkedArray::new_from_aligned_vec(name, unit) - } - - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> Self { - ChunkedArray::new_from_opt_iter( - name, - v.iter() - .map(|s| parse_naive_time_from_str(s, fmt).as_ref().map($func)), - ) - } - } - }; -} - -impl_from_naive_time!( - Time64NanosecondType, - Time64NanosecondChunked, - naive_time_to_time64_nanoseconds -); -impl_from_naive_time!( - Time64MicrosecondType, - Time64MicrosecondChunked, - naive_time_to_time64_microseconds -); -impl_from_naive_time!( - Time32MillisecondType, - Time32MillisecondChunked, - naive_time_to_time32_milliseconds -); -impl_from_naive_time!( - Time32SecondType, - Time32SecondChunked, - naive_time_to_time32_seconds -); - -pub trait AsNaiveTime { - fn as_naive_time(&self) -> Vec>; -} - -macro_rules! impl_as_naivetime { - ($ca:ty, $fun:ident) => { - impl AsNaiveTime for $ca { - fn as_naive_time(&self) -> Vec> { - self.into_iter().map(|opt_t| opt_t.map($fun)).collect() - } - } - }; -} - -impl_as_naivetime!(Time32SecondChunked, time32_second_as_time); -impl_as_naivetime!(Time32MillisecondChunked, time32_millisecond_as_time); -impl_as_naivetime!(Time64NanosecondChunked, time64_nanosecond_as_time); -impl_as_naivetime!(Time64MicrosecondChunked, time64_microsecond_as_time); - -fn parse_naive_datetime_from_str(s: &str, fmt: &str) -> Option { - NaiveDateTime::parse_from_str(s, fmt).ok() -} - -pub trait FromNaiveDateTime { - fn new_from_naive_datetime(name: &str, v: &[N]) -> Self; - - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> Self; -} - -macro_rules! impl_from_naive_datetime { - ($arrowtype:ident, $chunkedtype:ident, $func:ident) => { - impl FromNaiveDateTime<$arrowtype, NaiveDateTime> for $chunkedtype { - fn new_from_naive_datetime(name: &str, v: &[NaiveDateTime]) -> Self { - let unit = v.iter().map($func).collect::>(); - ChunkedArray::new_from_aligned_vec(name, unit) - } - - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> Self { - ChunkedArray::new_from_opt_iter( - name, - v.iter() - .map(|s| parse_naive_datetime_from_str(s, fmt).as_ref().map($func)), - ) - } - } - }; -} - -impl_from_naive_datetime!(Date64Type, Date64Chunked, naive_datetime_to_date64); -impl_from_naive_datetime!( - TimestampNanosecondType, - TimestampNanosecondChunked, - naive_datetime_to_timestamp_nanoseconds -); -impl_from_naive_datetime!( - TimestampMicrosecondType, - TimestampMicrosecondChunked, - naive_datetime_to_timestamp_microseconds -); -impl_from_naive_datetime!( - TimestampMillisecondType, - TimestampMillisecondChunked, - naive_datetime_to_timestamp_milliseconds -); -impl_from_naive_datetime!( - TimestampSecondType, - TimestampSecondChunked, - naive_datetime_to_timestamp_seconds -); - -pub trait FromNaiveDate { - fn new_from_naive_date(name: &str, v: &[N]) -> Self; - - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> Self; -} - -fn naive_date_to_date32(nd: NaiveDate, unix_time: NaiveDate) -> i32 { - nd.signed_duration_since(unix_time).num_days() as i32 -} - -fn parse_naive_date_from_str(s: &str, fmt: &str) -> Option { - NaiveDate::parse_from_str(s, fmt).ok() -} - -fn unix_time_naive_date() -> NaiveDate { - NaiveDate::from_ymd(1970, 1, 1) -} - -impl FromNaiveDate for Date32Chunked { - fn new_from_naive_date(name: &str, v: &[NaiveDate]) -> Self { - let unix_date = unix_time_naive_date(); - - let unit = v - .iter() - .map(|v| naive_date_to_date32(*v, unix_date)) - .collect::>(); - ChunkedArray::new_from_aligned_vec(name, unit) - } - - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> Self { - let unix_date = unix_time_naive_date(); - - ChunkedArray::new_from_opt_iter( - name, - v.iter().map(|s| { - parse_naive_date_from_str(s, fmt) - .as_ref() - .map(|v| naive_date_to_date32(*v, unix_date)) - }), - ) - } -} - -pub trait AsNaiveDateTime { - fn as_naive_datetime(&self) -> Vec>; -} - -macro_rules! impl_as_naive_datetime { - ($ca:ty, $fun:ident) => { - impl AsNaiveDateTime for $ca { - fn as_naive_datetime(&self) -> Vec> { - self.into_iter().map(|opt_t| opt_t.map($fun)).collect() - } - } - }; -} - -impl_as_naive_datetime!(Date32Chunked, date32_as_datetime); -impl_as_naive_datetime!(Date64Chunked, date64_as_datetime); -impl_as_naive_datetime!( - TimestampNanosecondChunked, - timestamp_nanoseconds_as_datetime -); -impl_as_naive_datetime!( - TimestampMicrosecondChunked, - timestamp_microseconds_as_datetime -); -impl_as_naive_datetime!( - TimestampMillisecondChunked, - timestamp_milliseconds_as_datetime -); -impl_as_naive_datetime!(TimestampSecondChunked, timestamp_seconds_as_datetime); - -pub trait AsNaiveDate { - fn as_naive_date(&self) -> Vec>; -} - -impl AsNaiveDate for Date32Chunked { - fn as_naive_date(&self) -> Vec> { - self.into_iter() - .map(|opt_t| { - opt_t.map(|v| { - let dt = date32_as_datetime(v); - NaiveDate::from_ymd(dt.year(), dt.month(), dt.day()) - }) - }) - .collect() - } -} - -#[cfg(all(test, feature = "temporal"))] -mod test { - use crate::prelude::*; - use chrono::{NaiveDateTime, NaiveTime}; - - #[test] - fn from_time() { - let times: Vec<_> = ["23:56:04", "00:00:00"] - .iter() - .map(|s| NaiveTime::parse_from_str(s, "%H:%M:%S").unwrap()) - .collect(); - let t = Time64NanosecondChunked::new_from_naive_time("times", ×); - // NOTE: the values are checked and correct. - assert_eq!([86164000000000, 0], t.cont_slice().unwrap()); - let t = Time64MicrosecondChunked::new_from_naive_time("times", ×); - assert_eq!([86164000000, 0], t.cont_slice().unwrap()); - let t = Time32MillisecondChunked::new_from_naive_time("times", ×); - assert_eq!([86164000, 0], t.cont_slice().unwrap()); - let t = Time32SecondChunked::new_from_naive_time("times", ×); - assert_eq!([86164, 0], t.cont_slice().unwrap()); - } - - #[test] - fn from_datetime() { - let datetimes: Vec<_> = [ - "1988-08-25 00:00:16", - "2015-09-05 23:56:04", - "2012-12-21 00:00:00", - ] - .iter() - .map(|s| NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").unwrap()) - .collect(); - - // NOTE: the values are checked and correct. - let dt = Date64Chunked::new_from_naive_datetime("name", &datetimes); - assert_eq!( - [588470416000, 1441497364000, 1356048000000], - dt.cont_slice().unwrap() - ); - let dt = TimestampNanosecondChunked::new_from_naive_datetime("name", &datetimes); - assert_eq!( - [588470416000000000, 1441497364000000000, 1356048000000000000], - dt.cont_slice().unwrap() - ); - let dt = TimestampMicrosecondChunked::new_from_naive_datetime("name", &datetimes); - assert_eq!( - [588470416000000, 1441497364000000, 1356048000000000], - dt.cont_slice().unwrap() - ); - let dt = TimestampMillisecondChunked::new_from_naive_datetime("name", &datetimes); - assert_eq!( - [588470416000, 1441497364000, 1356048000000], - dt.cont_slice().unwrap() - ); - let dt = TimestampSecondChunked::new_from_naive_datetime("name", &datetimes); - assert_eq!( - [588470416, 1441497364, 1356048000], - dt.cont_slice().unwrap() - ); - } - - #[test] - fn from_date() { - let dates = &[ - "2020-08-21", - "2020-08-21", - "2020-08-22", - "2020-08-23", - "2020-08-22", - ]; - let fmt = "%Y-%m-%d"; - let ca = Date32Chunked::parse_from_str_slice("dates", dates, fmt); - assert_eq!( - [18495, 18495, 18496, 18497, 18496], - ca.cont_slice().unwrap() - ); - } -} diff --git a/polars/src/chunked_array/unique.rs b/polars/src/chunked_array/unique.rs deleted file mode 100644 index 1fb870c4c7aa..000000000000 --- a/polars/src/chunked_array/unique.rs +++ /dev/null @@ -1,233 +0,0 @@ -use crate::prelude::*; -use crate::utils::{floating_encode_f64, integer_decode}; -use fnv::{FnvBuildHasher, FnvHasher}; -use num::{NumCast, ToPrimitive}; -use std::collections::{HashMap, HashSet}; -use std::hash::{BuildHasherDefault, Hash}; -use unsafe_unwrap::UnsafeUnwrap; - -impl ChunkUnique for LargeListChunked { - fn unique(&self) -> ChunkedArray { - unimplemented!() - } - - fn arg_unique(&self) -> Vec { - unimplemented!() - } -} - -fn fill_set( - a: impl Iterator, - capacity: usize, -) -> HashSet> -where - A: Hash + Eq, -{ - let mut set = HashSet::with_capacity_and_hasher(capacity, FnvBuildHasher::default()); - - for val in a { - set.insert(val); - } - - set -} - -fn arg_unique(a: impl Iterator, capacity: usize) -> Vec -where - T: Hash + Eq, -{ - let mut set = HashSet::with_capacity_and_hasher(capacity, FnvBuildHasher::default()); - let mut unique = Vec::with_capacity_aligned(capacity); - a.enumerate().for_each(|(idx, val)| { - if set.insert(val) { - unique.push(idx) - } - }); - - unique -} - -impl ChunkUnique for ChunkedArray -where - T: PolarsIntegerType, - T::Native: Hash + Eq, - ChunkedArray: ChunkOps, -{ - fn unique(&self) -> Self { - let set = match self.cont_slice() { - Ok(slice) => fill_set(slice.iter().map(|v| Some(*v)), self.len()), - Err(_) => fill_set(self.into_iter(), self.len()), - }; - - Self::new_from_opt_iter(self.name(), set.iter().copied()) - } - - fn arg_unique(&self) -> Vec { - match self.cont_slice() { - Ok(slice) => arg_unique(slice.iter(), self.len()), - Err(_) => arg_unique(self.into_iter(), self.len()), - } - } -} - -impl ChunkUnique for Utf8Chunked { - fn unique(&self) -> Self { - let set = fill_set(self.into_iter(), self.len()); - Utf8Chunked::new_from_opt_iter(self.name(), set.iter().copied()) - } - - fn arg_unique(&self) -> Vec { - arg_unique(self.into_iter(), self.len()) - } -} - -impl ChunkUnique for BooleanChunked { - fn unique(&self) -> Self { - // can be None, Some(true), Some(false) - let mut unique = Vec::with_capacity(3); - for v in self { - if unique.len() == 3 { - break; - } - if !unique.contains(&v) { - unique.push(v) - } - } - ChunkedArray::new_from_opt_slice(self.name(), &unique) - } - - fn arg_unique(&self) -> Vec { - arg_unique(self.into_iter(), self.len()) - } -} - -// Use stable form of specialization using autoref -// https://github.com/dtolnay/case-studies/blob/master/autoref-specialization/README.md -impl ChunkUnique for &ChunkedArray -where - T: PolarsNumericType, - T::Native: NumCast + ToPrimitive, - ChunkedArray: ChunkOps, -{ - fn unique(&self) -> ChunkedArray { - let set = match self.cont_slice() { - Ok(slice) => fill_set( - slice - .iter() - .map(|v| Some(integer_decode(v.to_f64().unwrap()))), - self.len(), - ), - Err(_) => fill_set( - self.into_iter() - .map(|opt_v| opt_v.map(|v| integer_decode(v.to_f64().unwrap()))), - self.len(), - ), - }; - - // let builder = PrimitiveChunkedBuilder::new(self.name(), set.len()); - ChunkedArray::new_from_opt_iter( - self.name(), - set.iter().copied().map(|opt| match opt { - Some((mantissa, exponent, sign)) => { - let flt = floating_encode_f64(mantissa, exponent, sign); - let val: T::Native = NumCast::from(flt).unwrap(); - Some(val) - } - None => None, - }), - ) - } - - fn arg_unique(&self) -> Vec { - match self.cont_slice() { - Ok(slice) => arg_unique( - slice.iter().map(|v| { - let v = v.to_f64(); - debug_assert!(v.is_some()); - let v = unsafe { v.unsafe_unwrap() }; - integer_decode(v) - }), - self.len(), - ), - Err(_) => arg_unique( - self.into_iter().map(|opt_v| { - opt_v.map(|v| { - let v = v.to_f64(); - debug_assert!(v.is_some()); - let v = unsafe { v.unsafe_unwrap() }; - integer_decode(v) - }) - }), - self.len(), - ), - } - } -} - -pub trait ValueCounts -where - T: ArrowPrimitiveType, -{ - fn value_counts(&self) -> HashMap, u32, BuildHasherDefault>; -} - -fn fill_set_value_count( - a: impl Iterator, - capacity: usize, -) -> HashMap> -where - K: Hash + Eq, -{ - let mut kv_store = HashMap::with_capacity_and_hasher(capacity, FnvBuildHasher::default()); - - for key in a { - let count = kv_store.entry(key).or_insert(0); - *count += 1; - } - - kv_store -} - -impl ValueCounts for ChunkedArray -where - T: PolarsIntegerType, - T::Native: Hash + Eq, - ChunkedArray: ChunkOps, -{ - fn value_counts(&self) -> HashMap, u32, BuildHasherDefault> { - match self.cont_slice() { - Ok(slice) => fill_set_value_count(slice.iter().map(|v| Some(*v)), self.len()), - Err(_) => fill_set_value_count(self.into_iter(), self.len()), - } - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - use itertools::Itertools; - - #[test] - fn unique() { - let ca = ChunkedArray::::new_from_slice("a", &[1, 2, 3, 2, 1]); - assert_eq!( - ca.unique().into_iter().collect_vec(), - vec![Some(1), Some(2), Some(3)] - ); - let ca = BooleanChunked::new_from_slice("a", &[true, false, true]); - assert_eq!( - ca.unique().into_iter().collect_vec(), - vec![Some(true), Some(false)] - ); - - let ca = - Utf8Chunked::new_from_opt_slice("", &[Some("a"), None, Some("a"), Some("b"), None]); - assert_eq!(Vec::from(&ca.unique()), &[Some("a"), None, Some("b")]); - } - - #[test] - fn arg_unique() { - let ca = ChunkedArray::::new_from_slice("a", &[1, 2, 1, 1, 3]); - assert_eq!(ca.arg_unique().into_iter().collect_vec(), vec![0, 1, 4]); - } -} diff --git a/polars/src/chunked_array/upstream_traits.rs b/polars/src/chunked_array/upstream_traits.rs deleted file mode 100644 index fae47dd4f962..000000000000 --- a/polars/src/chunked_array/upstream_traits.rs +++ /dev/null @@ -1,437 +0,0 @@ -//! Implementations of upstream traits for ChunkedArray -use crate::chunked_array::builder::get_large_list_builder; -use crate::prelude::*; -use crate::utils::get_iter_capacity; -use crate::utils::Xob; -use rayon::iter::{FromParallelIterator, IntoParallelIterator}; -use rayon::prelude::*; -use std::collections::LinkedList; -use std::iter::FromIterator; - -/// FromIterator trait - -impl FromIterator> for ChunkedArray -where - T: ArrowPrimitiveType, -{ - fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = PrimitiveChunkedBuilder::new("", get_iter_capacity(&iter)); - - for opt_val in iter { - builder.append_option(opt_val); - } - builder.finish() - } -} - -// Xob is only a wrapper needed for specialization -impl FromIterator for Xob> -where - T: ArrowPrimitiveType, -{ - fn from_iter>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = PrimitiveChunkedBuilder::new("", get_iter_capacity(&iter)); - - for val in iter { - builder.append_value(val); - } - Xob::new(builder.finish()) - } -} - -impl FromIterator for BooleanChunked { - fn from_iter>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = PrimitiveChunkedBuilder::new("", get_iter_capacity(&iter)); - - for val in iter { - builder.append_value(val); - } - builder.finish() - } -} - -// FromIterator for Utf8Chunked variants. - -impl<'a> FromIterator<&'a str> for Utf8Chunked { - fn from_iter>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = Utf8ChunkedBuilder::new("", get_iter_capacity(&iter)); - - for val in iter { - builder.append_value(val); - } - builder.finish() - } -} - -impl<'a> FromIterator<&'a &'a str> for Utf8Chunked { - fn from_iter>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = Utf8ChunkedBuilder::new("", get_iter_capacity(&iter)); - - for val in iter { - builder.append_value(val); - } - builder.finish() - } -} - -macro_rules! impl_from_iter_utf8 { - ($iter:ident) => {{ - let iter = $iter.into_iter(); - let mut builder = Utf8ChunkedBuilder::new("", get_iter_capacity(&iter)); - - for opt_val in iter { - builder.append_option(opt_val.as_ref()) - } - builder.finish() - }}; -} - -impl<'a> FromIterator> for Utf8Chunked { - fn from_iter>>(iter: I) -> Self { - impl_from_iter_utf8!(iter) - } -} - -impl<'a> FromIterator<&'a Option<&'a str>> for Utf8Chunked { - fn from_iter>>(iter: I) -> Self { - impl_from_iter_utf8!(iter) - } -} - -impl FromIterator for Utf8Chunked { - fn from_iter>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = Utf8ChunkedBuilder::new("", get_iter_capacity(&iter)); - - for val in iter { - builder.append_value(val.as_str()); - } - builder.finish() - } -} - -impl FromIterator> for Utf8Chunked { - fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - let mut builder = Utf8ChunkedBuilder::new("", get_iter_capacity(&iter)); - - for opt_val in iter { - match opt_val { - None => builder.append_null(), - Some(val) => builder.append_value(val.as_str()), - } - } - builder.finish() - } -} - -impl FromIterator for LargeListChunked { - fn from_iter>(iter: I) -> Self { - let mut it = iter.into_iter(); - let capacity = get_iter_capacity(&it); - - // first take one to get the dtype. We panic if we have an empty iterator - let v = it.next().unwrap(); - let mut builder = get_large_list_builder(v.dtype(), capacity, "collected"); - - builder.append_opt_series(&Some(v)); - while let Some(s) = it.next() { - builder.append_opt_series(&Some(s)); - } - builder.finish() - } -} - -impl<'a> FromIterator<&'a Series> for LargeListChunked { - fn from_iter>(iter: I) -> Self { - let mut it = iter.into_iter(); - let capacity = get_iter_capacity(&it); - - // first take one to get the dtype. We panic if we have an empty iterator - let v = it.next().unwrap(); - let mut builder = get_large_list_builder(v.dtype(), capacity, "collected"); - - builder.append_series(v); - while let Some(s) = it.next() { - builder.append_series(s); - } - - builder.finish() - } -} - -macro_rules! impl_from_iter_opt_series { - ($iter:ident) => {{ - // we don't know the type of the series until we get Some(Series) from the iterator. - // until that happens we count the number of None's so that we can first fill the None's until - // we know the type - - let mut it = $iter.into_iter(); - - let v; - let mut cnt = 0; - - loop { - let opt_v = it.next(); - - match opt_v { - Some(opt_v) => match opt_v { - Some(val) => { - v = val; - break; - } - None => cnt += 1, - }, - // end of iterator - None => { - // type is not known - panic!("Type of Series cannot be determined as they are all null") - } - } - } - let capacity = get_iter_capacity(&it); - let mut builder = get_large_list_builder(v.dtype(), capacity, "collected"); - - // first fill all None's we encountered - while cnt > 0 { - builder.append_opt_series(&None); - cnt -= 1; - } - - // now the first non None - builder.append_series(&v); - - // now we have added all Nones, we can consume the rest of the iterator. - while let Some(opt_s) = it.next() { - builder.append_opt_series(&opt_s); - } - - builder.finish() - - }} -} - -impl FromIterator> for LargeListChunked { - fn from_iter>>(iter: I) -> Self { - impl_from_iter_opt_series!(iter) - } -} - -impl<'a> FromIterator<&'a Option> for LargeListChunked { - fn from_iter>>(iter: I) -> Self { - impl_from_iter_opt_series!(iter) - } -} - -/// FromParallelIterator trait - -// Code taken from https://docs.rs/rayon/1.3.1/src/rayon/iter/extend.rs.html#356-366 -fn vec_push(mut vec: Vec, elem: T) -> Vec { - vec.push(elem); - vec -} - -fn as_list(item: T) -> LinkedList { - let mut list = LinkedList::new(); - list.push_back(item); - list -} - -fn list_append(mut list1: LinkedList, mut list2: LinkedList) -> LinkedList { - list1.append(&mut list2); - list1 -} - -fn collect_into_linked_list(par_iter: I) -> LinkedList> -where - I: IntoParallelIterator, -{ - par_iter - .into_par_iter() - .fold(Vec::new, vec_push) - .map(as_list) - .reduce(LinkedList::new, list_append) -} - -fn get_capacity_from_par_results(ll: &LinkedList>) -> usize { - ll.iter().map(|list| list.len()).sum() -} - -impl FromParallelIterator> for ChunkedArray -where - T: ArrowPrimitiveType, -{ - fn from_par_iter>>(iter: I) -> Self { - // Get linkedlist filled with different vec result from different threads - let vectors = collect_into_linked_list(iter); - let capacity: usize = get_capacity_from_par_results(&vectors); - - let mut builder = PrimitiveChunkedBuilder::new("", capacity); - // Unpack all these results and append them single threaded - vectors.iter().for_each(|vec| { - for opt_val in vec { - builder.append_option(*opt_val); - } - }); - - builder.finish() - } -} - -impl FromParallelIterator for BooleanChunked { - fn from_par_iter>(iter: I) -> Self { - let vectors = collect_into_linked_list(iter); - let capacity: usize = get_capacity_from_par_results(&vectors); - - let mut builder = PrimitiveChunkedBuilder::new("", capacity); - // Unpack all these results and append them single threaded - vectors.iter().for_each(|vec| { - for val in vec { - builder.append_value(*val); - } - }); - - builder.finish() - } -} - -impl FromParallelIterator for Utf8Chunked { - fn from_par_iter>(iter: I) -> Self { - let vectors = collect_into_linked_list(iter); - let capacity: usize = get_capacity_from_par_results(&vectors); - - let mut builder = Utf8ChunkedBuilder::new("", capacity); - // Unpack all these results and append them single threaded - vectors.iter().for_each(|vec| { - for val in vec { - builder.append_value(val.as_str()); - } - }); - - builder.finish() - } -} - -impl FromParallelIterator> for Utf8Chunked { - fn from_par_iter>>(iter: I) -> Self { - let vectors = collect_into_linked_list(iter); - let capacity: usize = get_capacity_from_par_results(&vectors); - - let mut builder = Utf8ChunkedBuilder::new("", capacity); - // Unpack all these results and append them single threaded - vectors.iter().for_each(|vec| { - for val in vec { - builder.append_option(val.as_ref()); - } - }); - builder.finish() - } -} - -impl<'a> FromParallelIterator> for Utf8Chunked { - fn from_par_iter>>(iter: I) -> Self { - let vectors = collect_into_linked_list(iter); - let capacity: usize = get_capacity_from_par_results(&vectors); - - let mut builder = Utf8ChunkedBuilder::new("", capacity); - // Unpack all these results and append them single threaded - vectors.iter().for_each(|vec| { - for val in vec { - builder.append_option(val.as_ref()); - } - }); - builder.finish() - } -} - -impl<'a> FromParallelIterator<&'a str> for Utf8Chunked { - fn from_par_iter>(iter: I) -> Self { - let vectors = collect_into_linked_list(iter); - let capacity: usize = get_capacity_from_par_results(&vectors); - - let mut builder = Utf8ChunkedBuilder::new("", capacity); - // Unpack all these results and append them single threaded - vectors.iter().for_each(|vec| { - for val in vec { - builder.append_value(val); - } - }); - builder.finish() - } -} - -/// From trait - -// TODO: use macro -// Only the one which takes Utf8Chunked by reference is implemented. -// We cannot return a & str owned by this function. -impl<'a> From<&'a Utf8Chunked> for Vec> { - fn from(ca: &'a Utf8Chunked) -> Self { - let mut vec = Vec::with_capacity_aligned(ca.len()); - ca.into_iter().for_each(|opt| vec.push(opt)); - vec - } -} - -impl From for Vec> { - fn from(ca: Utf8Chunked) -> Self { - let mut vec = Vec::with_capacity_aligned(ca.len()); - ca.into_iter() - .for_each(|opt| vec.push(opt.map(|s| s.to_string()))); - vec - } -} - -impl<'a> From<&'a BooleanChunked> for Vec> { - fn from(ca: &'a BooleanChunked) -> Self { - let mut vec = Vec::with_capacity_aligned(ca.len()); - ca.into_iter().for_each(|opt| vec.push(opt)); - vec - } -} - -impl From for Vec> { - fn from(ca: BooleanChunked) -> Self { - let mut vec = Vec::with_capacity_aligned(ca.len()); - ca.into_iter().for_each(|opt| vec.push(opt)); - vec - } -} - -impl<'a, T> From<&'a ChunkedArray> for Vec> -where - T: PolarsNumericType, - &'a ChunkedArray: IntoIterator>, - ChunkedArray: ChunkOps, -{ - fn from(ca: &'a ChunkedArray) -> Self { - let mut vec = Vec::with_capacity_aligned(ca.len()); - ca.into_iter().for_each(|opt| vec.push(opt)); - vec - } -} -// TODO: macro implementation of Vec From for all types. ChunkedArray (no reference) doesn't implement -// &'a ChunkedArray: IntoIterator>, - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_collect_into_large_listt() { - let s1 = Series::new("", &[true, false, true]); - let s2 = Series::new("", &[true, false, true]); - - let ll: LargeListChunked = [&s1, &s2].iter().map(|&s| s).collect(); - assert_eq!(ll.len(), 2); - assert_eq!(ll.null_count(), 0); - let ll: LargeListChunked = [None, Some(s2)].iter().collect(); - assert_eq!(ll.len(), 2); - assert_eq!(ll.null_count(), 1); - } -} diff --git a/polars/src/datatypes.rs b/polars/src/datatypes.rs deleted file mode 100644 index 99507678b370..000000000000 --- a/polars/src/datatypes.rs +++ /dev/null @@ -1,245 +0,0 @@ -//! # Data types supported by Polars. -//! -//! At the moment Polars doesn't include all data types available by Arrow. The goal is to -//! incrementally support more data types and prioritize these by usability. -//! -//! [See the AnyType variants](enum.AnyType.html#variants) for the data types that -//! are currently supported. -//! -use crate::chunked_array::ChunkedArray; -use crate::series::Series; -pub use arrow::datatypes::DataType as ArrowDataType; -pub use arrow::datatypes::{ - ArrowNumericType, ArrowPrimitiveType, BooleanType, Date32Type, Date64Type, DateUnit, - DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalUnit, IntervalYearMonthType, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, -}; - -pub struct Utf8Type {} - -pub struct LargeListType {} - -pub trait PolarsDataType { - fn get_data_type() -> ArrowDataType; -} - -impl PolarsDataType for T -where - T: ArrowPrimitiveType, -{ - fn get_data_type() -> ArrowDataType { - T::get_data_type() - } -} - -impl PolarsDataType for Utf8Type { - fn get_data_type() -> ArrowDataType { - ArrowDataType::Utf8 - } -} - -impl PolarsDataType for LargeListType { - fn get_data_type() -> ArrowDataType { - // null as we cannot no anything without self. - ArrowDataType::LargeList(Box::new(ArrowDataType::Null)) - } -} - -/// Any type that is not nested -pub trait PolarsSingleType: PolarsDataType {} - -impl PolarsSingleType for T where T: ArrowPrimitiveType + PolarsDataType {} - -impl PolarsSingleType for Utf8Type {} - -pub type LargeListChunked = ChunkedArray; -pub type BooleanChunked = ChunkedArray; -pub type UInt8Chunked = ChunkedArray; -pub type UInt16Chunked = ChunkedArray; -pub type UInt32Chunked = ChunkedArray; -pub type UInt64Chunked = ChunkedArray; -pub type Int8Chunked = ChunkedArray; -pub type Int16Chunked = ChunkedArray; -pub type Int32Chunked = ChunkedArray; -pub type Int64Chunked = ChunkedArray; -pub type Float32Chunked = ChunkedArray; -pub type Float64Chunked = ChunkedArray; -pub type Utf8Chunked = ChunkedArray; -pub type Date32Chunked = ChunkedArray; -pub type Date64Chunked = ChunkedArray; -pub type DurationNanosecondChunked = ChunkedArray; -pub type DurationMicrosecondChunked = ChunkedArray; -pub type DurationMillisecondChunked = ChunkedArray; -pub type DurationSecondChunked = ChunkedArray; - -pub type Time64NanosecondChunked = ChunkedArray; -pub type Time64MicrosecondChunked = ChunkedArray; -pub type Time32MillisecondChunked = ChunkedArray; -pub type Time32SecondChunked = ChunkedArray; -pub type IntervalDayTimeChunked = ChunkedArray; -pub type IntervalYearMonthChunked = ChunkedArray; - -pub type TimestampNanosecondChunked = ChunkedArray; -pub type TimestampMicrosecondChunked = ChunkedArray; -pub type TimestampMillisecondChunked = ChunkedArray; -pub type TimestampSecondChunked = ChunkedArray; - -pub trait PolarsNumericType: ArrowNumericType {} - -impl PolarsNumericType for UInt8Type {} -impl PolarsNumericType for UInt16Type {} -impl PolarsNumericType for UInt32Type {} -impl PolarsNumericType for UInt64Type {} -impl PolarsNumericType for Int8Type {} -impl PolarsNumericType for Int16Type {} -impl PolarsNumericType for Int32Type {} -impl PolarsNumericType for Int64Type {} -impl PolarsNumericType for Float32Type {} -impl PolarsNumericType for Float64Type {} -impl PolarsNumericType for Date32Type {} -impl PolarsNumericType for Date64Type {} -impl PolarsNumericType for Time64NanosecondType {} -impl PolarsNumericType for Time64MicrosecondType {} -impl PolarsNumericType for Time32MillisecondType {} -impl PolarsNumericType for Time32SecondType {} -impl PolarsNumericType for DurationNanosecondType {} -impl PolarsNumericType for DurationMicrosecondType {} -impl PolarsNumericType for DurationMillisecondType {} -impl PolarsNumericType for DurationSecondType {} -impl PolarsNumericType for IntervalYearMonthType {} -impl PolarsNumericType for IntervalDayTimeType {} -impl PolarsNumericType for TimestampNanosecondType {} -impl PolarsNumericType for TimestampMicrosecondType {} -impl PolarsNumericType for TimestampMillisecondType {} -impl PolarsNumericType for TimestampSecondType {} - -pub trait PolarsIntegerType: PolarsNumericType {} -impl PolarsIntegerType for UInt8Type {} -impl PolarsIntegerType for UInt16Type {} -impl PolarsIntegerType for UInt32Type {} -impl PolarsIntegerType for UInt64Type {} -impl PolarsIntegerType for Int8Type {} -impl PolarsIntegerType for Int16Type {} -impl PolarsIntegerType for Int32Type {} -impl PolarsIntegerType for Int64Type {} -impl PolarsIntegerType for Date32Type {} -impl PolarsIntegerType for Date64Type {} -impl PolarsIntegerType for Time64NanosecondType {} -impl PolarsIntegerType for Time64MicrosecondType {} -impl PolarsIntegerType for Time32MillisecondType {} -impl PolarsIntegerType for Time32SecondType {} -impl PolarsIntegerType for DurationNanosecondType {} -impl PolarsIntegerType for DurationMicrosecondType {} -impl PolarsIntegerType for DurationMillisecondType {} -impl PolarsIntegerType for DurationSecondType {} -impl PolarsIntegerType for IntervalYearMonthType {} -impl PolarsIntegerType for IntervalDayTimeType {} -impl PolarsIntegerType for TimestampNanosecondType {} -impl PolarsIntegerType for TimestampMicrosecondType {} -impl PolarsIntegerType for TimestampMillisecondType {} -impl PolarsIntegerType for TimestampSecondType {} - -#[derive(Debug)] -pub enum AnyType<'a> { - Null, - /// A binary true or false. - Boolean(bool), - /// A UTF8 encoded string type. - Utf8(&'a str), - /// An unsigned 8-bit integer number. - UInt8(u8), - /// An unsigned 16-bit integer number. - UInt16(u16), - /// An unsigned 32-bit integer number. - UInt32(u32), - /// An unsigned 64-bit integer number. - UInt64(u64), - /// An 8-bit integer number. - Int8(i8), - /// A 16-bit integer number. - Int16(i16), - /// A 32-bit integer number. - Int32(i32), - /// A 64-bit integer number. - Int64(i64), - /// A 32-bit floating point number. - Float32(f32), - /// A 64-bit floating point number. - Float64(f64), - /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) - /// in days (32 bits). - Date32(i32), - /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) - /// in milliseconds (64 bits). - Date64(i64), - /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. - Time64(i64, TimeUnit), - /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. - Time32(i32, TimeUnit), - /// Measure of elapsed time in either seconds, milliseconds, microseconds or nanoseconds. - Duration(i64, TimeUnit), - /// Naive Time elapsed from the Unix epoch, 00:00:00.000 on 1 January 1970, excluding leap seconds, as a 64-bit integer. - /// Note that UNIX time does not include leap seconds. - TimeStamp(i64, TimeUnit), - /// A "calendar" interval which models types that don't necessarily have a precise duration without the context of a base timestamp - /// (e.g. days can differ in length during day light savings time transitions). - IntervalDayTime(i64), - IntervalYearMonth(i32), - LargeList(Series), -} - -pub trait ToStr { - fn to_str(&self) -> String; -} - -impl ToStr for ArrowDataType { - fn to_str(&self) -> String { - // TODO: add types here - let s = match self { - ArrowDataType::Null => "null", - ArrowDataType::Boolean => "bool", - ArrowDataType::UInt8 => "u8", - ArrowDataType::UInt16 => "u16", - ArrowDataType::UInt32 => "u32", - ArrowDataType::UInt64 => "u64", - ArrowDataType::Int8 => "i8", - ArrowDataType::Int16 => "i16", - ArrowDataType::Int32 => "i32", - ArrowDataType::Int64 => "i64", - ArrowDataType::Float32 => "f32", - ArrowDataType::Float64 => "f64", - ArrowDataType::Utf8 => "str", - ArrowDataType::Date32(DateUnit::Day) => "date32", - ArrowDataType::Date64(DateUnit::Millisecond) => "date64", - ArrowDataType::Time32(TimeUnit::Second) => "time64(s)", - ArrowDataType::Time32(TimeUnit::Millisecond) => "time64(ms)", - ArrowDataType::Time64(TimeUnit::Nanosecond) => "time64(ns)", - ArrowDataType::Time64(TimeUnit::Microsecond) => "time64(μs)", - // Note: Polars doesn't support the optional TimeZone in the timestamps. - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => "timestamp(ns)", - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => "timestamp(μs)", - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => "timestamp(ms)", - ArrowDataType::Timestamp(TimeUnit::Second, _) => "timestamp(s)", - ArrowDataType::Duration(TimeUnit::Nanosecond) => "duration(ns)", - ArrowDataType::Duration(TimeUnit::Microsecond) => "duration(μs)", - ArrowDataType::Duration(TimeUnit::Millisecond) => "duration(ms)", - ArrowDataType::Duration(TimeUnit::Second) => "duration(s)", - ArrowDataType::Interval(IntervalUnit::DayTime) => "interval(daytime)", - ArrowDataType::Interval(IntervalUnit::YearMonth) => "interval(year-month)", - ArrowDataType::LargeList(tp) => return format!("list [{}]", tp.to_str()), - _ => panic!(format!("{:?} not implemented", self)), - }; - s.into() - } -} - -impl<'a> PartialEq for AnyType<'a> { - // Everything of Any is slow. Don't use. - fn eq(&self, other: &Self) -> bool { - format!("{}", self) == format!("{}", other) - } -} diff --git a/polars/src/doc/changelog/mod.rs b/polars/src/doc/changelog/mod.rs deleted file mode 100644 index 497ec214fd81..000000000000 --- a/polars/src/doc/changelog/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod v0_3; -pub mod v0_4; -pub mod v0_5; -pub mod v0_6; diff --git a/polars/src/doc/changelog/v0_3.rs b/polars/src/doc/changelog/v0_3.rs deleted file mode 100644 index 0c3bd98ca220..000000000000 --- a/polars/src/doc/changelog/v0_3.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! # Changelog v0.3 -//! -//! * Utf8 type is nullable [#37](https://github.com/ritchie46/polars/issues/37) -//! * Support all ARROW numeric types [#40](https://github.com/ritchie46/polars/issues/40) -//! * Support all ARROW temporal types [#46](https://github.com/ritchie46/polars/issues/46) -//! * ARROW IPC Reader/ Writer [#50](https://github.com/ritchie46/polars/issues/50) -//! * Implement DoubleEndedIterator trait for ChunkedArray's [#34](https://github.com/ritchie46/polars/issues/34) -//! diff --git a/polars/src/doc/changelog/v0_4.rs b/polars/src/doc/changelog/v0_4.rs deleted file mode 100644 index c4f00cf1b50b..000000000000 --- a/polars/src/doc/changelog/v0_4.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! # Changelog v0.4 -//! -//! * median aggregation added to `ChunkedArray` -//! * Arrow LargeList datatype support (and groupby aggregation into LargeList). -//! * Shift operation. -//! * Fill None operation. -//! * Buffered serialization (less memory requirements) -//! * Temporal utilities -//! diff --git a/polars/src/doc/changelog/v0_5.rs b/polars/src/doc/changelog/v0_5.rs deleted file mode 100644 index b0e4981d55cb..000000000000 --- a/polars/src/doc/changelog/v0_5.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! # Changelog v0.5 -//! -//! * `DataFrame.column` returns `Result<_>` **breaking change**. -//! * Define idiomatic way to do inplace operations on a `DataFrame` with `apply`, `may_apply` and `ChunkSet` -//! * `ChunkSet` Trait. -//! * `Groupby` aggregations can be done on a selection of multiple columns. -//! * `Groupby` operation can be done on multiple keys. -//! * `Groupby` `first` operation. -//! * `Pivot` operation. -//! * Random access to `ChunkedArray` types via `.get` and `.get_unchecked`. -//! diff --git a/polars/src/doc/changelog/v0_6.rs b/polars/src/doc/changelog/v0_6.rs deleted file mode 100644 index 23e38f3d2369..000000000000 --- a/polars/src/doc/changelog/v0_6.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! # Changelog v0.6 -//! -//! * Add more distributions for random sampling. -//! * Fix float aggregations with NaNs. -//! * Comparisons are more performant. -//! * Outer join is more performant. -//! * Start with parallel iterator support for ChunkedArrays. -//! * Remove crossbeam dependency. diff --git a/polars/src/doc/mod.rs b/polars/src/doc/mod.rs deleted file mode 100644 index 3ef925cf2fe7..000000000000 --- a/polars/src/doc/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -//! Other documentation -pub mod changelog; -#[cfg(feature = "temporal")] -pub mod time; diff --git a/polars/src/doc/time.rs b/polars/src/doc/time.rs deleted file mode 100644 index 1644a27e093a..000000000000 --- a/polars/src/doc/time.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! # DateTime related functionality -//! -//! Polars supports all data types in Arrow related to time and dates in any kind. -//! In Arrow times and dates are stored as i32 or i64 integers. This can represent for instance -//! a duration in seconds since the *Unix Epoch: 00:00:00 1 january 1970*. -//! -//! ## Chrono -//! I can imagine that integer values aren't the most intuitive when dealing with dates and times. -//! For this reason, Polars supports conversion from **chrono's** [NaiveTime](https://docs.rs/chrono/0.4.13/chrono/naive/struct.NaiveTime.html) -//! and [NaiveDate](https://docs.rs/chrono/0.4.13/chrono/naive/struct.NaiveDate.html) structs to -//! Polars and vice versa. -//! -//! `ChunkedArray`'s initialization is represented by the the following traits: -//! * [FromNaiveTime](../../chunked_array/temporal/trait.FromNaiveTime.html) -//! * [FromNaiveDateTime](../../chunked_array/temporal/trait.FromNaiveDateTime.html) -//! -//! To cast a `ChunkedArray` to a `Vec` or a `Vec` see: -//! * [AsNaiveTime](../../chunked_array/temporal/trait.AsNaiveTime.html) -//! -//! ### Example -//! -//! ```rust -//! use chrono::NaiveTime; -//! use polars::prelude::*;; -//! -//! // We can create a ChunkedArray from NaiveTime objects -//! fn from_naive_time_to_time32(time_values: &[NaiveTime]) -> Time32SecondChunked { -//! Time32SecondChunked::new_from_naive_time("name", time_values) -//! } -//! -//! // Or from a ChunkedArray to NaiveTime objects -//! fn from_time32_to_naive_time(ca: &Time32SecondChunked) -> Vec> { -//! ca.as_naive_time() -//! } -//! ``` -//! -//! ## String formatting -//! -//! We can also directly parse strings given a predifined `fmt: &str`. This uses **chrono's** -//! [NaiveTime::parse_from_str](https://docs.rs/chrono/0.4.15/chrono/naive/struct.NaiveTime.html#method.parse_from_str) -//! under the hood. So look there for the different formatting options. If the string parsing is not -//! succesful, the value will be `None`. -//! -//! ### Examples -//! -//! #### NaiveTime -//! -//! ```rust -//! use polars::prelude::*; -//! use chrono::NaiveTime; -//! -//! // String values to parse, Note that the 2nd value is not a correct time value. -//! let time_values = &["23:56:04", "26:00:99", "12:39:08"]; -//! // Parsing fmt -//! let fmt = "%H:%M:%S"; -//! // Create the ChunkedArray -//! let ca = Time64NanosecondChunked::parse_from_str_slice("Time as ns since midnight", time_values, fmt); -//! -//! // Assert that we've got a ChunkedArray with a single None value. -//! assert_eq!(ca.as_naive_time(), -//! &[NaiveTime::parse_from_str(time_values[0], fmt).ok(), -//! None, -//! NaiveTime::parse_from_str(time_values[2], fmt).ok()]); -//! ``` -//! #### NaiveDateTime -//! -//! ```rust -//! use polars::prelude::*; -//! use chrono::NaiveDateTime; -//! -//! // String values to parse, Note that the 2nd value is not a correct time value. -//! let datetime_values = &[ -//! "1988-08-25 00:00:16", -//! "2015-09-05 23:56:04", -//! "2012-12-21 00:00:00", -//! ]; -//! // Parsing fmt -//! let fmt = "%Y-%m-%d %H:%M:%S"; -//! -//! // Create the ChunkedArray -//! let ca = TimestampSecondChunked::parse_from_str_slice("datetime as s since Epoch", datetime_values, fmt); -//! -//! // Or collect into a Vec -//! let vec = ca.as_naive_datetime(); -//! -//! // We could also parse these datetime strings as the following data types: -//! -//! // date and time timestamps in different precisions -//! let ca = TimestampNanosecondChunked::parse_from_str_slice("datetime as ns since Epoch", datetime_values, fmt); -//! let ca = TimestampMicrosecondChunked::parse_from_str_slice("datetime as μs since Epoch", datetime_values, fmt); -//! let ca = TimestampMillisecondChunked::parse_from_str_slice("datetime as ms since Epoch", datetime_values, fmt); -//! -//! // or dates in different precisions (days and milliseconds) -//! let ca = Date32Chunked::parse_from_str_slice("date as days since Epoch", datetime_values, fmt); -//! let ca = Date64Chunked::parse_from_str_slice("date as ms since Epoch", datetime_values, fmt); -//! ``` -//! -//! ## Temporal Data Types -//! Polars supports all the time datatypes supported by Apache Arrow. These store the time values in -//! different precisions and bytes: -//! -//! ### Time -//! * **Time64NanosecondChunked** -//! - A ChunkedArray which stores time with nanosecond precision as 64 bit signed integer. -//! * **Time64MicrosecondChunked** -//! - A ChunkedArray which stores time with microsecond precision as 64 bit signed integer. -//! * **Time32MillisecondChunked** -//! - A ChunkedArray which stores time with millisecond precision as 32 bit signed integer. -//! * **Time32SecondChunked** -//! - A ChunkedArray which stores time with second precision as 32 bit signed integer. -//! -//! ### Date -//! * **Date32Chunked** -//! - A ChunkedArray storing the date as elapsed days since the Unix Epoch as 32 bit signed integer. -//! * **Date64Chunked** -//! - A ChunkedArray storing the date as elapsed milliseconds since the Unix Epoch as 64 bit signed integer. -//! -//! ### DateTime -//! * **TimestampNanosecondChunked** -//! - A ChunkedArray storing the date and time as elapsed nanoseconds since the Unix Epoch as 64 bit signed integer. -//! * **TimestampMicrosecondChunked** -//! - A ChunkedArray storing the date and time as elapsed microseconds since the Unix Epoch as 64 bit signed integer. -//! * **TimestampMillisecondChunked** -//! - A ChunkedArray storing the date and time as elapsed milliseconds since the Unix Epoch as 64 bit signed integer. -//! * **TimestampSecondChunked** -//! - A ChunkedArray storing the date and time as elapsed seconds since the Unix Epoch as 64 bit signed integer. diff --git a/polars/src/error.rs b/polars/src/error.rs deleted file mode 100644 index 4a494a828292..000000000000 --- a/polars/src/error.rs +++ /dev/null @@ -1,39 +0,0 @@ -use thiserror::Error as ThisError; - -#[derive(Debug, ThisError)] -pub enum PolarsError { - #[error(transparent)] - ArrowError(#[from] arrow::error::ArrowError), - #[error("Invalid operation")] - InvalidOperation, - #[error("Chunks don't match")] - ChunkMisMatch, - #[error("Data types don't match")] - DataTypeMisMatch, - #[error("Not found")] - NotFound, - #[error("Lengths don't match")] - ShapeMisMatch, - #[error("{0}")] - Other(String), - #[error("No selection was made")] - NoSelection, - #[error("Out of bounds")] - OutOfBounds, - #[error("Not contiguous or null values")] - NoSlice, - #[error("Such empty...")] - NoData, - #[error("Memory should be 64 byte aligned")] - MemoryNotAligned, - #[cfg(feature = "parquet")] - #[error(transparent)] - ParquetError(#[from] parquet::errors::ParquetError), - #[cfg(feature = "random")] - #[error("{0}")] - RandError(String), - #[error("This operation requires data without None values")] - HasNullValues, -} - -pub type Result = std::result::Result; diff --git a/polars/src/fmt.rs b/polars/src/fmt.rs deleted file mode 100644 index 9e8ce34c2c76..000000000000 --- a/polars/src/fmt.rs +++ /dev/null @@ -1,482 +0,0 @@ -use crate::datatypes::{AnyType, ToStr}; -use crate::prelude::*; - -#[cfg(feature = "temporal")] -use crate::chunked_array::temporal::{ - date32_as_datetime, date64_as_datetime, time32_millisecond_as_time, time32_second_as_time, - time64_microsecond_as_time, time64_nanosecond_as_time, timestamp_microseconds_as_datetime, - timestamp_milliseconds_as_datetime, timestamp_nanoseconds_as_datetime, - timestamp_seconds_as_datetime, -}; -use num::{Num, NumCast}; -#[cfg(feature = "pretty")] -use prettytable::Table; -use std::{ - fmt, - fmt::{Debug, Display, Formatter}, -}; -const LIMIT: usize = 10; - -/// Some unit functions that just pass the integer values if we don't want all chrono functionality -#[cfg(not(feature = "temporal"))] -mod temporal { - pub struct DateTime(T) - where - T: Copy; - - impl DateTime - where - T: Copy, - { - pub fn date(&self) -> T { - self.0 - } - } - - pub fn date32_as_datetime(v: i32) -> DateTime { - DateTime(v) - } - pub fn date64_as_datetime(v: i64) -> DateTime { - DateTime(v) - } - pub fn time32_millisecond_as_time(v: i32) -> i32 { - v - } - pub fn time32_second_as_time(v: i32) -> i32 { - v - } - pub fn time64_nanosecond_as_time(v: i64) -> i64 { - v - } - pub fn time64_microsecond_as_time(v: i64) -> i64 { - v - } - pub fn timestamp_nanoseconds_as_datetime(v: i64) -> i64 { - v - } - pub fn timestamp_microseconds_as_datetime(v: i64) -> i64 { - v - } - pub fn timestamp_milliseconds_as_datetime(v: i64) -> i64 { - v - } - pub fn timestamp_seconds_as_datetime(v: i64) -> i64 { - v - } -} -#[cfg(not(feature = "temporal"))] -use temporal::*; - -macro_rules! format_array { - ($limit:ident, $f:ident, $a:ident, $dtype:expr, $name:expr, $array_type:expr) => {{ - write![$f, "{}: '{}' [{}]\n[\n", $array_type, $name, $dtype]?; - - for i in 0..$limit { - let v = $a.get_any(i); - write!($f, "\t{}\n", v)?; - } - - write![$f, "]"] - }}; -} - -macro_rules! format_utf8_array { - ($limit:ident, $f:ident, $a:ident, $name:expr, $array_type:expr) => {{ - write![$f, "{}: '{}' [str]\n[\n", $array_type, $name]?; - $a.into_iter().take($limit).for_each(|opt_s| match opt_s { - None => { - write!($f, "\tnull\n").ok(); - } - Some(s) => { - write!($f, "\t\"{}\"\n", &s[..std::cmp::min($limit, s.len())]).ok(); - } - }); - write![$f, "]"] - }}; -} -macro_rules! format_list_array { - ($limit:ident, $f:ident, $a:ident, $name:expr, $array_type:expr) => {{ - write![$f, "{}: '{}' [list]\n[\n", $array_type, $name]?; - - for i in 0..$limit { - let opt_v = $a.get(i); - match opt_v { - Some(v) => write!($f, "\t{}\n", v.fmt_largelist())?, - None => write!($f, "\tnull\n")?, - } - } - - write![$f, "]"] - }}; -} - -macro_rules! set_limit { - ($self:ident) => { - std::cmp::min($self.len(), LIMIT) - }; -} - -impl Debug for ChunkedArray -where - T: ArrowPrimitiveType, -{ - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let limit = set_limit!(self); - let dtype = format!("{:?}", T::get_data_type()); - format_array!(limit, f, self, dtype, self.name(), "ChunkedArray") - } -} - -impl Debug for Utf8Chunked { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let limit = set_limit!(self); - format_utf8_array!(limit, f, self, self.name(), "ChunkedArray") - } -} - -impl Debug for LargeListChunked { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let limit = set_limit!(self); - format_list_array!(limit, f, self, self.name(), "ChunkedArray") - } -} - -impl Debug for Series { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let limit = set_limit!(self); - - match self { - Series::UInt8(a) => format_array!(limit, f, a, "u8", a.name(), "Series"), - Series::UInt16(a) => format_array!(limit, f, a, "u16", a.name(), "Series"), - Series::UInt32(a) => format_array!(limit, f, a, "u32", a.name(), "Series"), - Series::UInt64(a) => format_array!(limit, f, a, "u64", a.name(), "Series"), - Series::Int8(a) => format_array!(limit, f, a, "i8", a.name(), "Series"), - Series::Int16(a) => format_array!(limit, f, a, "i16", a.name(), "Series"), - Series::Int32(a) => format_array!(limit, f, a, "i32", a.name(), "Series"), - Series::Int64(a) => format_array!(limit, f, a, "i64", a.name(), "Series"), - Series::Bool(a) => format_array!(limit, f, a, "bool", a.name(), "Series"), - Series::Float32(a) => format_array!(limit, f, a, "f32", a.name(), "Series"), - Series::Float64(a) => format_array!(limit, f, a, "f64", a.name(), "Series"), - Series::Date32(a) => format_array!(limit, f, a, "date32(day)", a.name(), "Series"), - Series::Date64(a) => format_array!(limit, f, a, "date64(ms)", a.name(), "Series"), - Series::Time32Millisecond(a) => { - format_array!(limit, f, a, "time32(ms)", a.name(), "Series") - } - Series::Time32Second(a) => format_array!(limit, f, a, "time32(s)", a.name(), "Series"), - Series::Time64Nanosecond(a) => { - format_array!(limit, f, a, "time64(ns)", a.name(), "Series") - } - Series::Time64Microsecond(a) => { - format_array!(limit, f, a, "time64(μs)", a.name(), "Series") - } - Series::DurationNanosecond(a) => { - format_array!(limit, f, a, "duration(ns)", a.name(), "Series") - } - Series::DurationMicrosecond(a) => { - format_array!(limit, f, a, "duration(μs)", a.name(), "Series") - } - Series::DurationMillisecond(a) => { - format_array!(limit, f, a, "duration(ms)", a.name(), "Series") - } - Series::DurationSecond(a) => { - format_array!(limit, f, a, "duration(s)", a.name(), "Series") - } - Series::IntervalDayTime(a) => { - format_array!(limit, f, a, "interval(daytime)", a.name(), "Series") - } - Series::IntervalYearMonth(a) => { - format_array!(limit, f, a, "interval(year-month)", a.name(), "Series") - } - Series::TimestampNanosecond(a) => { - format_array!(limit, f, a, "timestamp(ns)", a.name(), "Series") - } - Series::TimestampMicrosecond(a) => { - format_array!(limit, f, a, "timestamp(μs)", a.name(), "Series") - } - Series::TimestampMillisecond(a) => { - format_array!(limit, f, a, "timestamp(ms)", a.name(), "Series") - } - Series::TimestampSecond(a) => { - format_array!(limit, f, a, "timestamp(s)", a.name(), "Series") - } - Series::Utf8(a) => format_utf8_array!(limit, f, a, a.name(), "Series"), - Series::LargeList(a) => format_list_array!(limit, f, a, a.name(), "Series"), - } - } -} - -impl Display for Series { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(self, f) - } -} - -impl Debug for DataFrame { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(self, f) - } -} - -#[cfg(feature = "pretty")] -impl Display for DataFrame { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let mut table = Table::new(); - let names = self - .schema() - .fields() - .iter() - .map(|f| format!("{}\n---\n{}", f.name(), f.data_type().to_str())) - .collect(); - table.set_titles(names); - for i in 0..10 { - let opt = self.get(i); - if let Some(row) = opt { - let mut row_str = Vec::with_capacity(row.len()); - for v in &row { - row_str.push(format!("{}", v)); - } - table.add_row(row.iter().map(|v| format!("{}", v)).collect()); - } else { - break; - } - } - write!(f, "{}", table)?; - Ok(()) - } -} - -#[cfg(not(feature = "pretty"))] -impl Display for DataFrame { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "DataFrame. NOTE: compile with the feature 'pretty' for pretty printing." - ) - } -} - -fn fmt_integer(f: &mut Formatter<'_>, width: usize, v: T) -> fmt::Result { - let v: i64 = NumCast::from(v).unwrap(); - if v > 9999 { - write!(f, "{:>width$e}", v, width = width) - } else { - write!(f, "{:>width$}", v, width = width) - } -} - -fn fmt_float(f: &mut Formatter<'_>, width: usize, v: T) -> fmt::Result { - let v: f64 = NumCast::from(v).unwrap(); - let v = (v * 1000.).round() / 1000.; - if v > 9999. || v < 0.001 { - write!(f, "{:>width$e}", v, width = width) - } else { - write!(f, "{:>width$}", v, width = width) - } -} - -#[cfg(not(feature = "pretty"))] -impl Display for AnyType<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self) - } -} - -#[cfg(feature = "pretty")] -impl Display for AnyType<'_> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let width = 0; - match self { - AnyType::Null => write!(f, "{}", "null"), - AnyType::UInt8(v) => write!(f, "{}", v), - AnyType::UInt16(v) => write!(f, "{}", v), - AnyType::UInt32(v) => write!(f, "{}", v), - AnyType::UInt64(v) => write!(f, "{}", v), - AnyType::Int8(v) => fmt_integer(f, width, *v), - AnyType::Int16(v) => fmt_integer(f, width, *v), - AnyType::Int32(v) => fmt_integer(f, width, *v), - AnyType::Int64(v) => fmt_integer(f, width, *v), - AnyType::Float32(v) => fmt_float(f, width, *v), - AnyType::Float64(v) => fmt_float(f, width, *v), - AnyType::Boolean(v) => write!(f, "{}", *v), - AnyType::Utf8(v) => write!(f, "{}", format!("\"{}\"", v)), - AnyType::Date32(v) => write!(f, "{}", date32_as_datetime(*v).date()), - AnyType::Date64(v) => write!(f, "{}", date64_as_datetime(*v).date()), - AnyType::Time32(v, TimeUnit::Millisecond) => { - write!(f, "{}", time32_millisecond_as_time(*v)) - } - AnyType::Time32(v, TimeUnit::Second) => write!(f, "{}", time32_second_as_time(*v)), - AnyType::Time64(v, TimeUnit::Nanosecond) => { - write!(f, "{}", time64_nanosecond_as_time(*v)) - } - AnyType::Time64(v, TimeUnit::Microsecond) => { - write!(f, "{}", time64_microsecond_as_time(*v)) - } - AnyType::Duration(v, TimeUnit::Nanosecond) => write!(f, "{}", v), - AnyType::Duration(v, TimeUnit::Microsecond) => write!(f, "{}", v), - AnyType::Duration(v, TimeUnit::Millisecond) => write!(f, "{}", v), - AnyType::Duration(v, TimeUnit::Second) => write!(f, "{}", v), - AnyType::TimeStamp(v, TimeUnit::Nanosecond) => { - write!(f, "{}", timestamp_nanoseconds_as_datetime(*v)) - } - AnyType::TimeStamp(v, TimeUnit::Microsecond) => { - write!(f, "{}", timestamp_microseconds_as_datetime(*v)) - } - AnyType::TimeStamp(v, TimeUnit::Millisecond) => { - write!(f, "{}", timestamp_milliseconds_as_datetime(*v)) - } - AnyType::TimeStamp(v, TimeUnit::Second) => { - write!(f, "{}", timestamp_seconds_as_datetime(*v)) - } - AnyType::IntervalDayTime(v) => write!(f, "{}", v), - AnyType::IntervalYearMonth(v) => write!(f, "{}", v), - AnyType::LargeList(s) => write!(f, "{:?}", s.fmt_largelist()), - _ => unimplemented!(), - } - } -} - -macro_rules! fmt_option { - ($opt:expr) => {{ - match $opt { - Some(v) => format!("{:?}", v), - None => "null".to_string(), - } - }}; -} - -macro_rules! impl_fmt_largelist { - ($self:ident) => {{ - match $self.len() { - 1 => format!("[{}]", fmt_option!($self.get(0))), - 2 => format!( - "[{}, {}]", - fmt_option!($self.get(0)), - fmt_option!($self.get(1)) - ), - 3 => format!( - "[{}, {}, {}]", - fmt_option!($self.get(0)), - fmt_option!($self.get(1)), - fmt_option!($self.get(2)) - ), - _ => format!( - "[{}, {}, ... {}]", - fmt_option!($self.get(0)), - fmt_option!($self.get(1)), - fmt_option!($self.get($self.len() - 1)) - ), - } - }}; -} - -pub(crate) trait FmtLargeList { - fn fmt_largelist(&self) -> String; -} - -impl FmtLargeList for ChunkedArray -where - T: ArrowPrimitiveType, -{ - fn fmt_largelist(&self) -> String { - impl_fmt_largelist!(self) - } -} - -impl FmtLargeList for Utf8Chunked { - fn fmt_largelist(&self) -> String { - impl_fmt_largelist!(self) - } -} - -impl FmtLargeList for LargeListChunked { - fn fmt_largelist(&self) -> String { - impl_fmt_largelist!(self) - } -} - -#[cfg(all(test, feature = "temporal"))] -mod test { - use crate::prelude::*; - - #[test] - fn list() { - use arrow::array::Int32Array; - let values_builder = Int32Array::builder(10); - let mut builder = LargeListPrimitiveChunkedBuilder::new("a", values_builder, 10); - builder.append_slice(Some(&[1, 2, 3])); - builder.append_slice(None); - let list = builder.finish().into_series(); - - println!("{:?}", list); - assert_eq!( - r#"Series: 'a' [list] -[ - [1, 2, 3] - null -]"#, - format!("{:?}", list) - ); - } - - #[test] - fn temporal() { - let s = Date32Chunked::new_from_opt_slice("date32", &[Some(1), None, Some(3)]); - assert_eq!( - r#"Series: 'date32' [date32(day)] -[ - 1970-01-02 - null - 1970-01-04 -]"#, - format!("{:?}", s.into_series()) - ); - - let s = Date64Chunked::new_from_opt_slice("", &[Some(1), None, Some(1000_000_000_000)]); - assert_eq!( - r#"Series: '' [date64(ms)] -[ - 1970-01-01 - null - 2001-09-09 -]"#, - format!("{:?}", s.into_series()) - ); - let s = Time64NanosecondChunked::new_from_slice( - "", - &[1_000_000, 37_800_005_000_000, 86_399_210_000_000], - ); - assert_eq!( - r#"Series: '' [time64(ns)] -[ - 00:00:00.001 - 10:30:00.005 - 23:59:59.210 -]"#, - format!("{:?}", s.into_series()) - ) - } - #[test] - fn test_fmt_chunkedarray() { - let ca = Int32Chunked::new_from_opt_slice("date32", &[Some(1), None, Some(3)]); - println!("{:?}", ca); - assert_eq!( - r#"ChunkedArray: 'date32' [Int32] -[ - 1 - null - 3 -]"#, - format!("{:?}", ca) - ); - let ca = Utf8Chunked::new_from_slice("name", &["a", "b"]); - println!("{:?}", ca); - assert_eq!( - r#"ChunkedArray: 'name' [str] -[ - "a" - "b" -]"#, - format!("{:?}", ca) - ); - } -} diff --git a/polars/src/frame/group_by.rs b/polars/src/frame/group_by.rs deleted file mode 100644 index 3cd8f8934445..000000000000 --- a/polars/src/frame/group_by.rs +++ /dev/null @@ -1,1233 +0,0 @@ -use super::hash_join::prepare_hashed_relation; -use crate::chunked_array::builder::PrimitiveChunkedBuilder; -use crate::frame::select::Selection; -use crate::prelude::*; -use arrow::array::{PrimitiveBuilder, StringBuilder}; -use enum_dispatch::enum_dispatch; -use fnv::FnvBuildHasher; -use itertools::Itertools; -use num::{Num, NumCast, ToPrimitive, Zero}; -use rayon::prelude::*; -use std::collections::HashMap; -use std::hash::Hash; -use std::{ - fmt::{Debug, Formatter}, - ops::Add, -}; - -fn groupby(a: impl Iterator) -> Vec<(usize, Vec)> -where - T: Hash + Eq + Copy, -{ - let hash_tbl = prepare_hashed_relation(a); - - hash_tbl - .into_iter() - .map(|(_, indexes)| { - let first = unsafe { *indexes.get_unchecked(0) }; - (first, indexes) - }) - .collect() -} - -#[enum_dispatch(Series)] -trait IntoGroupTuples { - fn group_tuples(&self) -> Vec<(usize, Vec)> { - unimplemented!() - } -} - -impl IntoGroupTuples for ChunkedArray -where - T: PolarsIntegerType, - T::Native: Eq + Hash, -{ - fn group_tuples(&self) -> Vec<(usize, Vec)> { - if let Ok(slice) = self.cont_slice() { - groupby(slice.iter()) - } else { - groupby(self.into_iter()) - } - } -} -impl IntoGroupTuples for BooleanChunked { - fn group_tuples(&self) -> Vec<(usize, Vec)> { - if self.is_optimal_aligned() { - groupby(self.into_no_null_iter()) - } else { - groupby(self.into_iter()) - } - } -} - -impl IntoGroupTuples for Utf8Chunked { - fn group_tuples(&self) -> Vec<(usize, Vec)> { - if self.is_optimal_aligned() { - groupby(self.into_no_null_iter()) - } else { - groupby(self.into_iter()) - } - } -} - -impl IntoGroupTuples for Float64Chunked {} -impl IntoGroupTuples for Float32Chunked {} -impl IntoGroupTuples for LargeListChunked {} - -/// Utility enum used for grouping on multiple columns -#[derive(Copy, Clone, Hash, Eq, PartialEq)] -enum Groupable<'a> { - Boolean(bool), - Utf8(&'a str), - UInt8(u8), - UInt16(u16), - UInt32(u32), - UInt64(u64), - Int8(i8), - Int16(i16), - Int32(i32), - Int64(i64), -} - -impl<'a> Debug for Groupable<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use Groupable::*; - match self { - Boolean(v) => write!(f, "{}", v), - Utf8(v) => write!(f, "{}", v), - UInt8(v) => write!(f, "{}", v), - UInt16(v) => write!(f, "{}", v), - UInt32(v) => write!(f, "{}", v), - UInt64(v) => write!(f, "{}", v), - Int8(v) => write!(f, "{}", v), - Int16(v) => write!(f, "{}", v), - Int32(v) => write!(f, "{}", v), - Int64(v) => write!(f, "{}", v), - } - } -} - -impl Series { - fn as_groupable_iter<'a>(&'a self) -> Box> + 'a> { - macro_rules! as_groupable_iter { - ($ca:expr, $variant:ident ) => {{ - Box::new( - $ca.into_iter() - .map(|opt_b| opt_b.map(|b| Groupable::$variant(b))), - ) - }}; - } - match self { - Series::Bool(ca) => as_groupable_iter!(ca, Boolean), - Series::UInt8(ca) => as_groupable_iter!(ca, UInt8), - Series::UInt16(ca) => as_groupable_iter!(ca, UInt16), - Series::UInt32(ca) => as_groupable_iter!(ca, UInt32), - Series::UInt64(ca) => as_groupable_iter!(ca, UInt64), - Series::Int8(ca) => as_groupable_iter!(ca, Int8), - Series::Int16(ca) => as_groupable_iter!(ca, Int16), - Series::Int32(ca) => as_groupable_iter!(ca, Int32), - Series::Int64(ca) => as_groupable_iter!(ca, Int64), - Series::Date32(ca) => as_groupable_iter!(ca, Int32), - Series::Date64(ca) => as_groupable_iter!(ca, Int64), - Series::TimestampSecond(ca) => as_groupable_iter!(ca, Int64), - Series::TimestampMillisecond(ca) => as_groupable_iter!(ca, Int64), - Series::TimestampNanosecond(ca) => as_groupable_iter!(ca, Int64), - Series::TimestampMicrosecond(ca) => as_groupable_iter!(ca, Int64), - Series::Time32Second(ca) => as_groupable_iter!(ca, Int32), - Series::Time32Millisecond(ca) => as_groupable_iter!(ca, Int32), - Series::Time64Nanosecond(ca) => as_groupable_iter!(ca, Int64), - Series::Time64Microsecond(ca) => as_groupable_iter!(ca, Int64), - Series::DurationNanosecond(ca) => as_groupable_iter!(ca, Int64), - Series::DurationMicrosecond(ca) => as_groupable_iter!(ca, Int64), - Series::DurationMillisecond(ca) => as_groupable_iter!(ca, Int64), - Series::DurationSecond(ca) => as_groupable_iter!(ca, Int64), - Series::IntervalDayTime(ca) => as_groupable_iter!(ca, Int64), - Series::IntervalYearMonth(ca) => as_groupable_iter!(ca, Int32), - Series::Utf8(ca) => as_groupable_iter!(ca, Utf8), - _ => unimplemented!(), - } - } -} - -impl DataFrame { - /// Group DataFrame using a Series column. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn groupby_sum(df: &DataFrame) -> Result { - /// df.groupby("column_name")? - /// .select("agg_column_name") - /// .sum() - /// } - /// ``` - pub fn groupby<'g, J, S: Selection<'g, J>>(&self, by: S) -> Result { - let selected_keys = self.select_series(by)?; - - let groups = match selected_keys.len() { - 1 => selected_keys[0].group_tuples(), - 2 => { - let iter = selected_keys[0] - .as_groupable_iter() - .zip(selected_keys[1].as_groupable_iter()); - groupby(iter) - } - 3 => { - let iter = selected_keys[0] - .as_groupable_iter() - .zip(selected_keys[1].as_groupable_iter()) - .zip(selected_keys[2].as_groupable_iter()); - groupby(iter) - } - 4 => { - let iter = selected_keys[0] - .as_groupable_iter() - .zip(selected_keys[1].as_groupable_iter()) - .zip(selected_keys[2].as_groupable_iter()) - .zip(selected_keys[3].as_groupable_iter()); - groupby(iter) - } - 5 => { - let iter = selected_keys[0] - .as_groupable_iter() - .zip(selected_keys[1].as_groupable_iter()) - .zip(selected_keys[2].as_groupable_iter()) - .zip(selected_keys[3].as_groupable_iter()) - .zip(selected_keys[4].as_groupable_iter()); - groupby(iter) - } - _ => { - return Err(PolarsError::Other( - "more than 5 combined keys are currently not supported".to_string(), - )); - } - }; - - Ok(GroupBy { - df: self, - selected_keys, - groups, - selected_agg: None, - }) - } -} - -/// Returned by a groupby operation on a DataFrame. This struct supports -/// several aggregations. -/// -/// Until described otherwise, the examples in this struct are performed on the following DataFrame: -/// -/// ```rust -/// use polars::prelude::*; -/// -/// let dates = &[ -/// "2020-08-21", -/// "2020-08-21", -/// "2020-08-22", -/// "2020-08-23", -/// "2020-08-22", -/// ]; -/// // date format -/// let fmt = "%Y-%m-%d"; -/// // create date series -/// let s0 = Date32Chunked::parse_from_str_slice("date", dates, fmt) -/// .into_series(); -/// // create temperature series -/// let s1 = Series::new("temp", [20, 10, 7, 9, 1].as_ref()); -/// // create rain series -/// let s2 = Series::new("rain", [0.2, 0.1, 0.3, 0.1, 0.01].as_ref()); -/// // create a new DataFrame -/// let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); -/// println!("{:?}", df); -/// ``` -/// -/// Outputs: -/// -/// ```text -/// +------------+------+------+ -/// | date | temp | rain | -/// | --- | --- | --- | -/// | date32 | i32 | f64 | -/// +============+======+======+ -/// | 2020-08-21 | 20 | 0.2 | -/// +------------+------+------+ -/// | 2020-08-21 | 10 | 0.1 | -/// +------------+------+------+ -/// | 2020-08-22 | 7 | 0.3 | -/// +------------+------+------+ -/// | 2020-08-23 | 9 | 0.1 | -/// +------------+------+------+ -/// | 2020-08-22 | 1 | 0.01 | -/// +------------+------+------+ -/// ``` -/// -#[derive(Debug, Clone)] -pub struct GroupBy<'df, 'selection_str> { - df: &'df DataFrame, - selected_keys: Vec, - // [first idx, [other idx]] - groups: Vec<(usize, Vec)>, - // columns selected for aggregation - selected_agg: Option>, -} - -#[enum_dispatch(Series)] -trait NumericAggSync { - fn agg_mean(&self, _groups: &Vec<(usize, Vec)>) -> Series { - unimplemented!() - } - fn agg_min(&self, _groups: &Vec<(usize, Vec)>) -> Series { - unimplemented!() - } - fn agg_max(&self, _groups: &Vec<(usize, Vec)>) -> Series { - unimplemented!() - } - fn agg_sum(&self, _groups: &Vec<(usize, Vec)>) -> Series { - unimplemented!() - } -} - -impl NumericAggSync for BooleanChunked {} -impl NumericAggSync for Utf8Chunked {} -impl NumericAggSync for LargeListChunked {} - -impl NumericAggSync for ChunkedArray -where - T: PolarsNumericType + Sync, - T::Native: std::ops::Add + Num + NumCast, -{ - fn agg_mean(&self, groups: &Vec<(usize, Vec)>) -> Series { - Series::Float64( - groups - .par_iter() - .map(|(_first, idx)| { - // Fast path - if let Ok(slice) = self.cont_slice() { - let mut sum = 0.; - for i in idx { - sum = sum + slice[*i].to_f64().unwrap() - } - Some(sum / idx.len() as f64) - } else { - let take = unsafe { - self.take_unchecked(idx.into_iter().copied(), Some(self.len())) - }; - let opt_sum: Option = take.sum(); - opt_sum.map(|sum| sum.to_f64().unwrap() / idx.len() as f64) - } - }) - .collect(), - ) - } - - fn agg_min(&self, groups: &Vec<(usize, Vec)>) -> Series { - groups - .par_iter() - .map(|(_first, idx)| { - if let Ok(slice) = self.cont_slice() { - let mut min = None; - for i in idx { - let v = slice[*i]; - - min = match min { - Some(min) => { - if min < v { - Some(min) - } else { - Some(v) - } - } - None => Some(v), - }; - } - min - } else { - let take = - unsafe { self.take_unchecked(idx.into_iter().copied(), Some(self.len())) }; - take.min() - } - }) - .collect::>() - .into_series() - } - - fn agg_max(&self, groups: &Vec<(usize, Vec)>) -> Series { - groups - .par_iter() - .map(|(_first, idx)| { - if let Ok(slice) = self.cont_slice() { - let mut max = None; - for i in idx { - let v = slice[*i]; - - max = match max { - Some(max) => { - if max > v { - Some(max) - } else { - Some(v) - } - } - None => Some(v), - }; - } - max - } else { - let take = - unsafe { self.take_unchecked(idx.into_iter().copied(), Some(self.len())) }; - take.max() - } - }) - .collect::>() - .into_series() - } - - fn agg_sum(&self, groups: &Vec<(usize, Vec)>) -> Series { - groups - .par_iter() - .map(|(_first, idx)| { - if let Ok(slice) = self.cont_slice() { - let mut sum = Zero::zero(); - for i in idx { - sum = sum + slice[*i] - } - Some(sum) - } else { - let take = - unsafe { self.take_unchecked(idx.into_iter().copied(), Some(self.len())) }; - take.sum() - } - }) - .collect::>() - .into_series() - } -} - -#[enum_dispatch(Series)] -trait AggFirst { - fn agg_first(&self, _groups: &Vec<(usize, Vec)>) -> Series { - unimplemented!() - } -} - -macro_rules! impl_agg_first { - ($self:ident, $groups:ident, $ca_type:ty) => {{ - $groups - .iter() - .map(|(first, _idx)| $self.get(*first)) - .collect::<$ca_type>() - .into_series() - }}; -} - -impl AggFirst for ChunkedArray -where - T: PolarsNumericType + std::marker::Sync, -{ - fn agg_first(&self, groups: &Vec<(usize, Vec)>) -> Series { - impl_agg_first!(self, groups, ChunkedArray) - } -} - -impl AggFirst for BooleanChunked { - fn agg_first(&self, groups: &Vec<(usize, Vec)>) -> Series { - impl_agg_first!(self, groups, BooleanChunked) - } -} - -impl AggFirst for Utf8Chunked { - fn agg_first(&self, groups: &Vec<(usize, Vec)>) -> Series { - groups - .iter() - .map(|(first, _idx)| self.get(*first)) - .collect::() - .into_series() - } -} - -impl AggFirst for LargeListChunked {} - -impl<'df, 'selection_str> GroupBy<'df, 'selection_str> { - /// Select the column by which the determine the groups. - /// You can select a single column or a slice of columns. - pub fn select(mut self, selection: S) -> Self - where - S: Selection<'selection_str, J>, - { - self.selected_agg = Some(selection.to_selection_vec()); - self - } - - fn keys(&self) -> Vec { - // Keys will later be appended with the aggregation columns, so we already allocate extra space - let size; - if let Some(sel) = &self.selected_agg { - size = sel.len() + self.selected_keys.len(); - } else { - size = self.selected_keys.len(); - } - let mut keys = Vec::with_capacity(size); - unsafe { - self.selected_keys.iter().for_each(|s| { - let key = s.take_iter_unchecked( - self.groups.iter().map(|(idx, _)| *idx), - Some(self.groups.len()), - ); - keys.push(key) - }); - } - keys - } - - fn prepare_agg(&self) -> Result<(Vec, Vec)> { - let selection = match &self.selected_agg { - Some(selection) => selection, - None => return Err(PolarsError::NoSelection), - }; - - let keys = self.keys(); - let agg_col = self.df.select_series(selection)?; - Ok((keys, agg_col)) - } - - /// Aggregate grouped series and compute the mean per group. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// df.groupby("date")?.select(&["temp", "rain"]).mean() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+-----------+-----------+ - /// | date | temp_mean | rain_mean | - /// | --- | --- | --- | - /// | date32 | f64 | f64 | - /// +============+===========+===========+ - /// | 2020-08-23 | 9 | 0.1 | - /// +------------+-----------+-----------+ - /// | 2020-08-22 | 4 | 0.155 | - /// +------------+-----------+-----------+ - /// | 2020-08-21 | 15 | 0.15 | - /// +------------+-----------+-----------+ - /// ``` - pub fn mean(&self) -> Result { - let (mut cols, agg_cols) = self.prepare_agg()?; - - for agg_col in agg_cols { - let new_name = format!["{}_mean", agg_col.name()]; - let mut agg = agg_col.agg_mean(&self.groups); - agg.rename(&new_name); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Aggregate grouped series and compute the sum per group. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// df.groupby("date")?.select("temp").sum() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+----------+ - /// | date | temp_sum | - /// | --- | --- | - /// | date32 | i32 | - /// +============+==========+ - /// | 2020-08-23 | 9 | - /// +------------+----------+ - /// | 2020-08-22 | 8 | - /// +------------+----------+ - /// | 2020-08-21 | 30 | - /// +------------+----------+ - /// ``` - pub fn sum(&self) -> Result { - let (mut cols, agg_cols) = self.prepare_agg()?; - - for agg_col in agg_cols { - let new_name = format!["{}_sum", agg_col.name()]; - let mut agg = agg_col.agg_sum(&self.groups); - agg.rename(&new_name); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Aggregate grouped series and compute the minimal value per group. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// df.groupby("date")?.select("temp").min() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+----------+ - /// | date | temp_min | - /// | --- | --- | - /// | date32 | i32 | - /// +============+==========+ - /// | 2020-08-23 | 9 | - /// +------------+----------+ - /// | 2020-08-22 | 1 | - /// +------------+----------+ - /// | 2020-08-21 | 10 | - /// +------------+----------+ - /// ``` - pub fn min(&self) -> Result { - let (mut cols, agg_cols) = self.prepare_agg()?; - for agg_col in agg_cols { - let new_name = format!["{}_min", agg_col.name()]; - let mut agg = agg_col.agg_min(&self.groups); - agg.rename(&new_name); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Aggregate grouped series and compute the maximum value per group. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// df.groupby("date")?.select("temp").max() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+----------+ - /// | date | temp_max | - /// | --- | --- | - /// | date32 | i32 | - /// +============+==========+ - /// | 2020-08-23 | 9 | - /// +------------+----------+ - /// | 2020-08-22 | 7 | - /// +------------+----------+ - /// | 2020-08-21 | 20 | - /// +------------+----------+ - /// ``` - pub fn max(&self) -> Result { - let (mut cols, agg_cols) = self.prepare_agg()?; - for agg_col in agg_cols { - let new_name = format!["{}_max", agg_col.name()]; - let mut agg = agg_col.agg_max(&self.groups); - agg.rename(&new_name); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Aggregate grouped series and find the first value per group. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// df.groupby("date")?.select("temp").first() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+------------+ - /// | date | temp_first | - /// | --- | --- | - /// | date32 | i32 | - /// +============+============+ - /// | 2020-08-23 | 9 | - /// +------------+------------+ - /// | 2020-08-22 | 7 | - /// +------------+------------+ - /// | 2020-08-21 | 20 | - /// +------------+------------+ - /// ``` - pub fn first(&self) -> Result { - let (mut cols, agg_cols) = self.prepare_agg()?; - for agg_col in agg_cols { - let new_name = format!["{}_first", agg_col.name()]; - let mut agg = agg_col.agg_first(&self.groups); - agg.rename(&new_name); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Aggregate grouped series and compute the number of values per group. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// df.groupby("date")?.select("temp").count() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+------------+ - /// | date | temp_count | - /// | --- | --- | - /// | date32 | u32 | - /// +============+============+ - /// | 2020-08-23 | 1 | - /// +------------+------------+ - /// | 2020-08-22 | 2 | - /// +------------+------------+ - /// | 2020-08-21 | 2 | - /// +------------+------------+ - /// ``` - pub fn count(&self) -> Result { - let (mut cols, agg_cols) = self.prepare_agg()?; - for agg_col in agg_cols { - let new_name = format!["{}_count", agg_col.name()]; - let mut builder = PrimitiveChunkedBuilder::new(&new_name, self.groups.len()); - for (_first, idx) in &self.groups { - builder.append_value(idx.len() as u32); - } - let ca = builder.finish(); - let agg = Series::UInt32(ca); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Aggregate the groups of the groupby operation into lists. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example(df: DataFrame) -> Result { - /// // GroupBy and aggregate to Lists - /// df.groupby("date")?.select("temp").agg_list() - /// } - /// ``` - /// Returns: - /// - /// ```text - /// +------------+------------------------+ - /// | date | temp_agg_list | - /// | --- | --- | - /// | date32 | list [i32] | - /// +============+========================+ - /// | 2020-08-23 | "[Some(9)]" | - /// +------------+------------------------+ - /// | 2020-08-22 | "[Some(7), Some(1)]" | - /// +------------+------------------------+ - /// | 2020-08-21 | "[Some(20), Some(10)]" | - /// +------------+------------------------+ - /// ``` - pub fn agg_list(&self) -> Result { - macro_rules! impl_gb { - ($type:ty, $agg_col:expr) => {{ - let values_builder = PrimitiveBuilder::<$type>::new(self.groups.len()); - let mut builder = - LargeListPrimitiveChunkedBuilder::new("", values_builder, self.groups.len()); - for (_first, idx) in &self.groups { - let s = unsafe { - $agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len())) - }; - builder.append_opt_series(&Some(s)) - } - builder.finish().into_series() - }}; - } - - macro_rules! impl_gb_utf8 { - ($agg_col:expr) => {{ - let values_builder = StringBuilder::new(self.groups.len()); - let mut builder = - LargeListUtf8ChunkedBuilder::new("", values_builder, self.groups.len()); - for (_first, idx) in &self.groups { - let s = unsafe { - $agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len())) - }; - builder.append_series(&s) - } - builder.finish().into_series() - }}; - } - - let (mut cols, agg_cols) = self.prepare_agg()?; - for agg_col in agg_cols { - let new_name = format!["{}_agg_list", agg_col.name()]; - let mut agg = - match_arrow_data_type_apply_macro!(agg_col.dtype(), impl_gb, impl_gb_utf8, agg_col); - agg.rename(&new_name); - cols.push(agg); - } - DataFrame::new(cols) - } - - /// Pivot a column of the current `DataFrame` and perform one of the following aggregations: - /// * first - /// * sum - /// * min - /// * max - /// * mean - /// * median - /// - /// The pivot operation consists of a group by one, or multiple collumns (these will be the new - /// y-axis), column that will be pivoted (this will be the new x-axis) and an aggregation. - /// - /// # Panics - /// If the values column is not a numerical type, the code will panic. - /// - /// # Example - /// - /// ```rust - /// use polars::prelude::*; - /// let s0 = Series::new("foo", ["A", "A", "B", "B", "C"].as_ref()); - /// let s1 = Series::new("N", [1, 2, 2, 4, 2].as_ref()); - /// let s2 = Series::new("bar", ["k", "l", "m", "n", "o"].as_ref()); - /// // create a new DataFrame - /// let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); - /// - /// fn example(df: DataFrame) -> Result { - /// df.groupby("foo")? - /// .pivot("bar", "N") - /// .first() - /// } - /// ``` - /// Transforms: - /// - /// ```text - /// +-----+-----+-----+ - /// | foo | N | bar | - /// | --- | --- | --- | - /// | str | i32 | str | - /// +=====+=====+=====+ - /// | "A" | 1 | "k" | - /// +-----+-----+-----+ - /// | "A" | 2 | "l" | - /// +-----+-----+-----+ - /// | "B" | 2 | "m" | - /// +-----+-----+-----+ - /// | "B" | 4 | "n" | - /// +-----+-----+-----+ - /// | "C" | 2 | "o" | - /// +-----+-----+-----+ - /// ``` - /// - /// Into: - /// - /// ```text - /// +-----+------+------+------+------+------+ - /// | foo | o | n | m | l | k | - /// | --- | --- | --- | --- | --- | --- | - /// | str | i32 | i32 | i32 | i32 | i32 | - /// +=====+======+======+======+======+======+ - /// | "A" | null | null | null | 2 | 1 | - /// +-----+------+------+------+------+------+ - /// | "B" | null | 4 | 2 | null | null | - /// +-----+------+------+------+------+------+ - /// | "C" | 2 | null | null | null | null | - /// +-----+------+------+------+------+------+ - /// ``` - pub fn pivot( - &mut self, - pivot_column: &'selection_str str, - values_column: &'selection_str str, - ) -> Pivot { - // same as select method - self.selected_agg = Some(vec![pivot_column, values_column]); - - let pivot = Pivot { - gb: self, - pivot_column, - values_column, - }; - pivot - } -} - -/// Intermediate structure when a `pivot` operation is applied. -/// See [the pivot method for more information.](../group_by/struct.GroupBy.html#method.pivot) -pub struct Pivot<'df, 'selection_str> { - gb: &'df GroupBy<'df, 'selection_str>, - pivot_column: &'selection_str str, - values_column: &'selection_str str, -} - -#[enum_dispatch(Series)] -trait ChunkPivot { - fn pivot( - &self, - _pivot_series: &Series, - _keys: Vec, - _groups: &Vec<(usize, Vec)>, - _agg_type: PivotAgg, - ) -> DataFrame { - unimplemented!() - } -} - -impl ChunkPivot for ChunkedArray -where - T: PolarsNumericType, - T::Native: Copy + Num + NumCast, -{ - fn pivot( - &self, - pivot_series: &Series, - keys: Vec, - groups: &Vec<(usize, Vec)>, - agg_type: PivotAgg, - ) -> DataFrame { - // TODO: save an allocation by creating a random access struct for the Groupable utility type. - let pivot_vec: Vec<_> = pivot_series.as_groupable_iter().collect(); - - let values_taker = self.take_rand(); - - let new_column_map = |size| { - // create a new hashmap that will be filled with new Vecs that later will be aggegrated - let mut columns_agg_map = - HashMap::with_capacity_and_hasher(size, FnvBuildHasher::default()); - for opt_column_name in &pivot_vec { - if let Some(column_name) = opt_column_name { - columns_agg_map - .entry(column_name) - .or_insert_with(|| Vec::new()); - } - } - - columns_agg_map - }; - - // create a hash map that will be filled with the results of the aggregation. - // let mut columns_agg_map_main = new_column_map(groups.len()); - let mut columns_agg_map_main = - HashMap::with_capacity_and_hasher(pivot_vec.len(), FnvBuildHasher::default()); - for opt_column_name in &pivot_vec { - if let Some(column_name) = opt_column_name { - columns_agg_map_main.entry(column_name).or_insert_with(|| { - PrimitiveChunkedBuilder::::new(&format!("{:?}", column_name), groups.len()) - }); - } - } - - // iterate over the groups that need to be aggregated - // idxes are the indexes of the groups in the keys, pivot, and values columns - for (_first, idx) in groups { - // for every group do the aggregation by adding them to the vector belonging by that column - // the columns are hashed with the pivot values - let mut columns_agg_map_group = new_column_map(idx.len()); - for &i in idx { - let opt_pivot_val = unsafe { pivot_vec.get_unchecked(i) }; - - if let Some(pivot_val) = opt_pivot_val { - let values_val = values_taker.get(i); - columns_agg_map_group - .get_mut(&pivot_val) - .map(|v| v.push(values_val)); - } - } - - // After the vectors are filled we really do the aggregation and add the result to the main - // hash map, mapping pivot values as column to aggregate result. - for (k, v) in &mut columns_agg_map_group { - let main_builder = columns_agg_map_main.get_mut(k).unwrap(); - - match v.len() { - 0 => main_builder.append_null(), - // NOTE: now we take first, but this is the place where all aggregations happen - _ => match agg_type { - PivotAgg::First => pivot_agg_first(main_builder, v), - PivotAgg::Sum => pivot_agg_sum(main_builder, v), - PivotAgg::Min => pivot_agg_min(main_builder, v), - PivotAgg::Max => pivot_agg_max(main_builder, v), - PivotAgg::Mean => pivot_agg_mean(main_builder, v), - PivotAgg::Median => pivot_agg_median(main_builder, v), - }, - } - } - } - // Finalize the pivot by creating a vec of all the columns and creating a DataFrame - let mut cols = keys; - cols.reserve_exact(columns_agg_map_main.len()); - - for (_, builder) in columns_agg_map_main { - let ca = builder.finish(); - cols.push(ca.into_series()); - } - - let df = DataFrame::new(cols).unwrap(); - df - } -} - -impl ChunkPivot for BooleanChunked {} -impl ChunkPivot for Utf8Chunked {} -impl ChunkPivot for LargeListChunked {} - -enum PivotAgg { - First, - Sum, - Min, - Max, - Mean, - Median, -} - -fn pivot_agg_first(builder: &mut PrimitiveChunkedBuilder, v: &Vec>) -where - T: PolarsNumericType, -{ - builder.append_option(v[0]); -} - -fn pivot_agg_median(builder: &mut PrimitiveChunkedBuilder, v: &mut Vec>) -where - T: PolarsNumericType, - T::Native: PartialOrd, -{ - v.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); - builder.append_option(v[v.len() / 2]); -} - -fn pivot_agg_sum(builder: &mut PrimitiveChunkedBuilder, v: &Vec>) -where - T: PolarsNumericType, - T::Native: Num + Zero, -{ - builder.append_option(v.iter().copied().fold_options(Zero::zero(), Add::add)); -} - -fn pivot_agg_mean(builder: &mut PrimitiveChunkedBuilder, v: &Vec>) -where - T: PolarsNumericType, - T::Native: Num + Zero + NumCast, -{ - builder.append_option( - v.iter() - .copied() - .fold_options::(Zero::zero(), Add::add) - .map(|sum_val| sum_val / NumCast::from(v.len()).unwrap()), - ); -} - -fn pivot_agg_min(builder: &mut PrimitiveChunkedBuilder, v: &Vec>) -where - T: PolarsNumericType, -{ - let mut min = None; - - for opt_val in v { - if let Some(val) = opt_val { - match min { - None => min = Some(*val), - Some(minimum) => { - if val < &minimum { - min = Some(*val) - } - } - } - } - } - - builder.append_option(min); -} - -fn pivot_agg_max(builder: &mut PrimitiveChunkedBuilder, v: &Vec>) -where - T: PolarsNumericType, -{ - let mut max = None; - - for opt_val in v { - if let Some(val) = opt_val { - match max { - None => max = Some(*val), - Some(maximum) => { - if val > &maximum { - max = Some(*val) - } - } - } - } - } - - builder.append_option(max); -} - -impl<'df, 'sel_str> Pivot<'df, 'sel_str> { - /// Aggregate the pivot results by taking the first occurring value. - pub fn first(&self) -> Result { - let pivot_series = self.gb.df.column(self.pivot_column)?; - let values_series = self.gb.df.column(self.values_column)?; - Ok(values_series.pivot( - pivot_series, - self.gb.keys(), - &self.gb.groups, - PivotAgg::First, - )) - } - - /// Aggregate the pivot results by taking the sum of all duplicates. - pub fn sum(&self) -> Result { - let pivot_series = self.gb.df.column(self.pivot_column)?; - let values_series = self.gb.df.column(self.values_column)?; - Ok(values_series.pivot(pivot_series, self.gb.keys(), &self.gb.groups, PivotAgg::Sum)) - } - - /// Aggregate the pivot results by taking the minimal value of all duplicates. - pub fn min(&self) -> Result { - let pivot_series = self.gb.df.column(self.pivot_column)?; - let values_series = self.gb.df.column(self.values_column)?; - Ok(values_series.pivot(pivot_series, self.gb.keys(), &self.gb.groups, PivotAgg::Min)) - } - - /// Aggregate the pivot results by taking the maximum value of all duplicates. - pub fn max(&self) -> Result { - let pivot_series = self.gb.df.column(self.pivot_column)?; - let values_series = self.gb.df.column(self.values_column)?; - Ok(values_series.pivot(pivot_series, self.gb.keys(), &self.gb.groups, PivotAgg::Max)) - } - - /// Aggregate the pivot results by taking the mean value of all duplicates. - pub fn mean(&self) -> Result { - let pivot_series = self.gb.df.column(self.pivot_column)?; - let values_series = self.gb.df.column(self.values_column)?; - Ok(values_series.pivot( - pivot_series, - self.gb.keys(), - &self.gb.groups, - PivotAgg::Mean, - )) - } - /// Aggregate the pivot results by taking the median value of all duplicates. - pub fn median(&self) -> Result { - let pivot_series = self.gb.df.column(self.pivot_column)?; - let values_series = self.gb.df.column(self.values_column)?; - Ok(values_series.pivot( - pivot_series, - self.gb.keys(), - &self.gb.groups, - PivotAgg::Median, - )) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_group_by() { - let s0 = Date32Chunked::parse_from_str_slice( - "date", - &[ - "2020-08-21", - "2020-08-21", - "2020-08-22", - "2020-08-23", - "2020-08-22", - ], - "%Y-%m-%d", - ) - .into_series(); - let s1 = Series::new("temp", [20, 10, 7, 9, 1].as_ref()); - let s2 = Series::new("rain", [0.2, 0.1, 0.3, 0.1, 0.01].as_ref()); - let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); - println!("{:?}", df); - - println!( - "{:?}", - df.groupby("date").unwrap().select("temp").count().unwrap() - ); - // Select multiple - println!( - "{:?}", - df.groupby("date") - .unwrap() - .select(&["temp", "rain"]) - .mean() - .unwrap() - ); - // Group by multiple - println!( - "multiple keys {:?}", - df.groupby(&["date", "temp"]) - .unwrap() - .select("rain") - .mean() - .unwrap() - ); - println!( - "{:?}", - df.groupby("date").unwrap().select("temp").sum().unwrap() - ); - println!( - "{:?}", - df.groupby("date").unwrap().select("temp").min().unwrap() - ); - println!( - "{:?}", - df.groupby("date").unwrap().select("temp").max().unwrap() - ); - println!( - "{:?}", - df.groupby("date") - .unwrap() - .select("temp") - .agg_list() - .unwrap() - ); - println!( - "{:?}", - df.groupby("date").unwrap().select("temp").first().unwrap() - ); - } - - #[test] - fn test_pivot() { - let s0 = Series::new("foo", ["A", "A", "B", "B", "C"].as_ref()); - let s1 = Series::new("N", [1, 2, 2, 4, 2].as_ref()); - let s2 = Series::new("bar", ["k", "l", "m", "m", "l"].as_ref()); - let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); - println!("{:?}", df); - - let pvt = df.groupby("foo").unwrap().pivot("bar", "N").sum().unwrap(); - assert_eq!( - Vec::from(pvt.column("m").unwrap().i32().unwrap()), - &[None, Some(6), None] - ); - let pvt = df.groupby("foo").unwrap().pivot("bar", "N").min().unwrap(); - assert_eq!( - Vec::from(pvt.column("m").unwrap().i32().unwrap()), - &[None, Some(2), None] - ); - let pvt = df.groupby("foo").unwrap().pivot("bar", "N").max().unwrap(); - assert_eq!( - Vec::from(pvt.column("m").unwrap().i32().unwrap()), - &[None, Some(4), None] - ); - let pvt = df.groupby("foo").unwrap().pivot("bar", "N").mean().unwrap(); - assert_eq!( - Vec::from(pvt.column("m").unwrap().i32().unwrap()), - &[None, Some(3), None] - ); - } -} diff --git a/polars/src/frame/hash_join.rs b/polars/src/frame/hash_join.rs deleted file mode 100644 index d5ab4b9b233c..000000000000 --- a/polars/src/frame/hash_join.rs +++ /dev/null @@ -1,860 +0,0 @@ -use crate::prelude::*; -use crate::utils::Xob; -use enum_dispatch::enum_dispatch; -use fnv::{FnvBuildHasher, FnvHashMap}; -#[cfg(feature = "parallel")] -use rayon::prelude::*; -use std::collections::{HashMap, HashSet}; -use std::hash::Hash; -use unsafe_unwrap::UnsafeUnwrap; - -macro_rules! hash_join_inner { - ($s_right:ident, $ca_left:ident, $type_:ident) => {{ - // call the type method series.i32() - let ca_right = $s_right.$type_()?; - $ca_left.hash_join_inner(ca_right) - }}; -} - -macro_rules! par_hash_join_inner { - ($s_right:ident, $ca_left:ident, $type_:ident) => {{ - // call the type method series.i32() - let ca_right = $s_right.$type_()?; - $ca_left.par_hash_join_inner(ca_right) - }}; -} - -macro_rules! hash_join_left { - ($s_right:ident, $ca_left:ident, $type_:ident) => {{ - // call the type method series.i32() - let ca_right = $s_right.$type_()?; - $ca_left.hash_join_left(ca_right) - }}; -} - -macro_rules! par_hash_join_left { - ($s_right:ident, $ca_left:ident, $type_:ident) => {{ - // call the type method series.i32() - let ca_right = $s_right.$type_()?; - $ca_left.par_hash_join_left(ca_right) - }}; -} - -macro_rules! hash_join_outer { - ($s_right:ident, $ca_left:ident, $type_:ident) => {{ - // call the type method series.i32() - let ca_right = $s_right.$type_()?; - $ca_left.hash_join_outer(ca_right) - }}; -} - -macro_rules! apply_hash_join_on_series { - ($s_left:ident, $s_right:ident, $join_macro:ident) => {{ - match $s_left { - Series::UInt8(ca_left) => $join_macro!($s_right, ca_left, u8), - Series::UInt16(ca_left) => $join_macro!($s_right, ca_left, u16), - Series::UInt32(ca_left) => $join_macro!($s_right, ca_left, u32), - Series::UInt64(ca_left) => $join_macro!($s_right, ca_left, u64), - Series::Int8(ca_left) => $join_macro!($s_right, ca_left, i8), - Series::Int16(ca_left) => $join_macro!($s_right, ca_left, i16), - Series::Int32(ca_left) => $join_macro!($s_right, ca_left, i32), - Series::Int64(ca_left) => $join_macro!($s_right, ca_left, i64), - Series::Bool(ca_left) => $join_macro!($s_right, ca_left, bool), - Series::Utf8(ca_left) => $join_macro!($s_right, ca_left, utf8), - Series::Date32(ca_left) => $join_macro!($s_right, ca_left, date32), - Series::Date64(ca_left) => $join_macro!($s_right, ca_left, date64), - Series::Time32Millisecond(ca_left) => { - $join_macro!($s_right, ca_left, time32_millisecond) - } - Series::Time32Second(ca_left) => $join_macro!($s_right, ca_left, time32_second), - Series::Time64Nanosecond(ca_left) => $join_macro!($s_right, ca_left, time64_nanosecond), - Series::Time64Microsecond(ca_left) => { - $join_macro!($s_right, ca_left, time64_microsecond) - } - Series::DurationMillisecond(ca_left) => { - $join_macro!($s_right, ca_left, duration_millisecond) - } - Series::DurationSecond(ca_left) => $join_macro!($s_right, ca_left, duration_second), - Series::DurationNanosecond(ca_left) => { - $join_macro!($s_right, ca_left, duration_nanosecond) - } - Series::DurationMicrosecond(ca_left) => { - $join_macro!($s_right, ca_left, duration_microsecond) - } - Series::TimestampMillisecond(ca_left) => { - $join_macro!($s_right, ca_left, timestamp_millisecond) - } - Series::TimestampSecond(ca_left) => $join_macro!($s_right, ca_left, timestamp_second), - Series::TimestampNanosecond(ca_left) => { - $join_macro!($s_right, ca_left, timestamp_nanosecond) - } - Series::TimestampMicrosecond(ca_left) => { - $join_macro!($s_right, ca_left, timestamp_microsecond) - } - Series::IntervalDayTime(ca_left) => $join_macro!($s_right, ca_left, interval_daytime), - Series::IntervalYearMonth(ca_left) => { - $join_macro!($s_right, ca_left, interval_year_month) - } - _ => unimplemented!(), - } - }}; -} - -pub(crate) fn prepare_hashed_relation( - b: impl Iterator, -) -> HashMap, FnvBuildHasher> -where - T: Hash + Eq + Copy, -{ - let mut hash_tbl = FnvHashMap::default(); - - b.enumerate() - .for_each(|(idx, key)| hash_tbl.entry(key).or_insert_with(Vec::new).push(idx)); - hash_tbl -} - -/// Hash join a and b. -/// b should be the shorter relation. -/// NOTE that T also can be an Option. Nulls are seen as equal. -fn hash_join_tuples_inner( - a: impl Iterator, - b: impl Iterator, - // Because b should be the shorter relation we could need to swap to keep left left and right right. - swap: bool, -) -> Vec<(usize, usize)> -where - T: Hash + Eq + Copy, -{ - let mut results = Vec::new(); - // First we hash one relation - let hash_tbl = prepare_hashed_relation(b); - - // Next we probe the other relation in the hash table - // code duplication is because we want to only do the swap check once - if swap { - a.enumerate().for_each(|(idx_a, key)| { - if let Some(indexes_b) = hash_tbl.get(&key) { - let tuples = indexes_b.iter().map(|&idx_b| (idx_b, idx_a)); - results.extend(tuples) - } - }); - } else { - a.enumerate().for_each(|(idx_a, key)| { - if let Some(indexes_b) = hash_tbl.get(&key) { - let tuples = indexes_b.iter().map(|&idx_b| (idx_a, idx_b)); - results.extend(tuples) - } - }); - } - results -} - -#[cfg(feature = "parallel")] -macro_rules! par_prepare_hashed_relation { - ($b:expr) => {{ - $b - // acc is the hashmap - .fold(FnvHashMap::default, |mut acc, (idx, key)| { - acc.entry(key).or_insert_with(Vec::new).push(idx); - acc - }) - .reduce(FnvHashMap::default, |mut map1, map2| { - map1.extend(map2.into_iter()); - map1 - }) - }}; -} - -#[cfg(feature = "parallel")] -macro_rules! par_hash_join_tuples_inner { - ($a:expr, $b:expr, $swap:expr) => {{ - // First we hash one relation - let hash_tbl = par_prepare_hashed_relation!($b); - - // Next we probe the other relation in the hash table - // code duplication is because we want to only do the swap check once - if $swap { - $a.map(|(idx_a, key)| { - if let Some(indexes_b) = hash_tbl.get(&key) { - let tuples: Vec<_> = indexes_b.iter().map(|&idx_b| (idx_b, idx_a)).collect(); - tuples - } else { - Vec::with_capacity(0) - } - }) - .reduce(Vec::new, |mut v1, v2| { - v1.extend(v2); - v1 - }) - } else { - $a.map(|(idx_a, key)| { - if let Some(indexes_b) = hash_tbl.get(&key) { - let tuples: Vec<_> = indexes_b.iter().map(|&idx_b| (idx_a, idx_b)).collect(); - tuples - } else { - Vec::with_capacity(0) - } - }) - .reduce(Vec::new, |mut v1, v2| { - v1.extend(v2); - v1 - }) - } - }}; -} - -#[cfg(feature = "parallel")] -macro_rules! par_hash_join_tuples_left { - ($a:expr, $b:expr) => {{ - // First we hash one relation - let hash_tbl = par_prepare_hashed_relation!($b); - - // Next we probe the other relation in the hash table - $a.map(|(idx_a, key)| { - match hash_tbl.get(&key) { - // left and right matches - Some(indexes_b) => { - let tuples: Vec<_> = indexes_b - .iter() - .map(|&idx_b| (idx_a, Some(idx_b))) - .collect(); - tuples - } - // only left values, right = null - None => vec![(idx_a, None)], - } - }) - .reduce(Vec::new, |mut v1, v2| { - v1.extend(v2); - v1 - }) - }}; -} - -/// Hash join left. None/ Nulls are regarded as Equal -/// All left values are joined so no Option there. -fn hash_join_tuples_left( - a: impl Iterator, - b: impl Iterator, -) -> Vec<(usize, Option)> -where - T: Hash + Eq + Copy, -{ - let mut results = Vec::new(); - // First we hash one relation - let hash_tbl = prepare_hashed_relation(b); - - // Next we probe the other relation in the hash table - a.enumerate().for_each(|(idx_a, key)| { - match hash_tbl.get(&key) { - // left and right matches - Some(indexes_b) => results.extend(indexes_b.iter().map(|&idx_b| (idx_a, Some(idx_b)))), - // only left values, right = null - None => results.push((idx_a, None)), - } - }); - results -} - -/// Hash join outer. Both left and right can have no match so Options -/// We accept a closure as we need to do two passes over the same iterators. -fn hash_join_tuples_outer<'a, T, I, J>( - a: I, - b: J, - swap: bool, -) -> Vec<(Option, Option)> -where - I: Iterator, - J: Iterator, - T: Hash + Eq + Copy + Sync, -{ - let mut results = Vec::with_capacity(a.size_hint().0 + b.size_hint().0); - - // prepare hash table - let mut hash_tbl = prepare_hashed_relation(b); - - // probe the hash table. - // Note: indexes from b that are not matched will be None, Some(idx_b) - // Therefore we remove the matches and the remaining will be joined from the right - - // code duplication is because we want to only do the swap check once - if swap { - a.enumerate().for_each(|(idx_a, key)| { - match hash_tbl.remove(&key) { - // left and right matches - Some(indexes_b) => { - results.extend(indexes_b.iter().map(|&idx_b| (Some(idx_b), Some(idx_a)))) - } - // only left values, right = null - None => { - results.push((None, Some(idx_a))); - } - } - }); - hash_tbl.iter().for_each(|(_k, indexes_b)| { - // remaining joined values from the right table - results.extend(indexes_b.iter().map(|&idx_b| (Some(idx_b), None))) - }); - } else { - a.enumerate().for_each(|(idx_a, key)| { - match hash_tbl.remove(&key) { - // left and right matches - Some(indexes_b) => { - results.extend(indexes_b.iter().map(|&idx_b| (Some(idx_a), Some(idx_b)))) - } - // only left values, right = null - None => { - results.push((Some(idx_a), None)); - } - } - }); - hash_tbl.iter().for_each(|(_k, indexes_b)| { - // remaining joined values from the right table - results.extend(indexes_b.iter().map(|&idx_b| (None, Some(idx_b)))) - }); - }; - - results -} - -pub trait HashJoin { - fn hash_join_inner(&self, other: &ChunkedArray) -> Vec<(usize, usize)>; - fn par_hash_join_inner(&self, _other: &ChunkedArray) -> Vec<(usize, usize)> { - panic!("set parallel feature to use parallel joins") - } - fn hash_join_left(&self, other: &ChunkedArray) -> Vec<(usize, Option)>; - fn par_hash_join_left(&self, _other: &ChunkedArray) -> Vec<(usize, Option)> { - panic!("set parallel feature to use parallel joins") - } - fn hash_join_outer(&self, other: &ChunkedArray) -> Vec<(Option, Option)>; -} - -macro_rules! det_hash_prone_order { - ($self:expr, $other:expr) => {{ - // The shortest relation will be used to create a hash table. - let left_first = $self.len() > $other.len(); - let a; - let b; - if left_first { - a = $self; - b = $other; - } else { - b = $self; - a = $other; - } - - (a, b, !left_first) - }}; -} - -impl HashJoin for ChunkedArray -where - T: PolarsNumericType + Sync, - T::Native: Eq + Hash, -{ - fn hash_join_inner(&self, other: &ChunkedArray) -> Vec<(usize, usize)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - - match (a.cont_slice(), b.cont_slice()) { - (Ok(a_slice), Ok(b_slice)) => { - hash_join_tuples_inner(a_slice.iter(), b_slice.iter(), swap) - } - (Ok(a_slice), Err(_)) => { - hash_join_tuples_inner( - a_slice.iter().map(|v| Some(*v)), // take ownership - b.into_iter(), - swap, - ) - } - (Err(_), Ok(b_slice)) => { - hash_join_tuples_inner(a.into_iter(), b_slice.iter().map(|v| Some(*v)), swap) - } - (Err(_), Err(_)) => hash_join_tuples_inner(a.into_iter(), b.into_iter(), swap), - } - } - - #[cfg(feature = "parallel")] - fn par_hash_join_inner(&self, other: &ChunkedArray) -> Vec<(usize, usize)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - - match (a.cont_slice(), b.cont_slice()) { - (Ok(a_slice), Ok(b_slice)) => { - par_hash_join_tuples_inner!(a_slice.into_par_iter().enumerate(), b_slice.into_par_iter().enumerate(), swap) - } - (Ok(a_slice), Err(_)) => { - hash_join_tuples_inner( - a_slice.iter().map(|v| Some(*v)), // take ownership - b.into_iter(), - swap, - ) - } - (Err(_), Ok(b_slice)) => { - hash_join_tuples_inner(a.into_iter(), b_slice.iter().map(|v| Some(*v)), swap) - } - (Err(_), Err(_)) => hash_join_tuples_inner(a.into_iter(), b.into_iter(), swap), - } - } - - #[cfg(feature = "parallel")] - fn par_hash_join_left(&self, other: &ChunkedArray) -> Vec<(usize, Option)> { - match (self.cont_slice(), other.cont_slice()) { - (Ok(a_slice), Ok(b_slice)) => { - par_hash_join_tuples_left!(a_slice.par_iter().enumerate(), b_slice.into_par_iter().enumerate()) - } - (Ok(a_slice), Err(_)) => { - hash_join_tuples_left( - a_slice.iter().map(|v| Some(*v)), // take ownership - other.into_iter(), - ) - } - (Err(_), Ok(b_slice)) => { - hash_join_tuples_left(self.into_iter(), b_slice.iter().map(|v| Some(*v))) - } - (Err(_), Err(_)) => hash_join_tuples_left(self.into_iter(), other.into_iter()), - } - } - - fn hash_join_left(&self, other: &ChunkedArray) -> Vec<(usize, Option)> { - match (self.cont_slice(), other.cont_slice()) { - (Ok(a_slice), Ok(b_slice)) => hash_join_tuples_left(a_slice.iter(), b_slice.iter()), - (Ok(a_slice), Err(_)) => { - hash_join_tuples_left( - a_slice.iter().map(|v| Some(*v)), // take ownership - other.into_iter(), - ) - } - (Err(_), Ok(b_slice)) => { - hash_join_tuples_left(self.into_iter(), b_slice.iter().map(|v| Some(*v))) - } - (Err(_), Err(_)) => hash_join_tuples_left(self.into_iter(), other.into_iter()), - } - } - - fn hash_join_outer(&self, other: &ChunkedArray) -> Vec<(Option, Option)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - match (a.cont_slice(), b.cont_slice()) { - (Ok(a_slice), Ok(b_slice)) => { - hash_join_tuples_outer(a_slice.iter(), b_slice.iter(), swap) - } - (Ok(a_slice), Err(_)) => { - hash_join_tuples_outer( - a_slice.iter().map(|v| Some(*v)), // take ownership - b.into_iter(), - swap, - ) - } - (Err(_), Ok(b_slice)) => hash_join_tuples_outer( - a.into_iter(), - b_slice.iter().map(|v: &T::Native| Some(*v)), - swap, - ), - (Err(_), Err(_)) => hash_join_tuples_outer(a.into_iter(), b.into_iter(), swap), - } - } -} - -impl HashJoin for BooleanChunked { - fn hash_join_inner(&self, other: &BooleanChunked) -> Vec<(usize, usize)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - - // Create the join tuples - match (a.is_optimal_aligned(), b.is_optimal_aligned()) { - (true, true) => { - hash_join_tuples_inner(a.into_no_null_iter(), b.into_no_null_iter(), swap) - } - _ => hash_join_tuples_inner(a.into_iter(), b.into_iter(), swap), - } - } - #[cfg(feature = "parallel")] - fn par_hash_join_inner(&self, other: &BooleanChunked) -> Vec<(usize, usize)> { - self.hash_join_inner(other) - } - - fn hash_join_left(&self, other: &BooleanChunked) -> Vec<(usize, Option)> { - match (self.is_optimal_aligned(), other.is_optimal_aligned()) { - (true, true) => { - hash_join_tuples_left(self.into_no_null_iter(), other.into_no_null_iter()) - } - _ => hash_join_tuples_left(self.into_iter(), other.into_iter()), - } - } - - #[cfg(feature = "parallel")] - fn par_hash_join_left(&self, other: &BooleanChunked) -> Vec<(usize, Option)> { - self.hash_join_left(other) - } - - fn hash_join_outer(&self, other: &BooleanChunked) -> Vec<(Option, Option)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - match (a.is_optimal_aligned(), b.is_optimal_aligned()) { - (true, true) => { - hash_join_tuples_outer(a.into_no_null_iter(), b.into_no_null_iter(), swap) - } - _ => hash_join_tuples_outer(a.into_iter(), b.into_iter(), swap), - } - } -} - -impl HashJoin for Utf8Chunked { - fn hash_join_inner(&self, other: &Utf8Chunked) -> Vec<(usize, usize)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - - // Create the join tuples - match (a.is_optimal_aligned(), b.is_optimal_aligned()) { - (true, true) => { - hash_join_tuples_inner(a.into_no_null_iter(), b.into_no_null_iter(), swap) - } - _ => hash_join_tuples_inner(a.into_iter(), b.into_iter(), swap), - } - } - - #[cfg(feature = "parallel")] - fn par_hash_join_inner(&self, other: &Utf8Chunked) -> Vec<(usize, usize)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - - // Create the join tuples - match (a.is_optimal_aligned(), b.is_optimal_aligned()) { - (true, true) => par_hash_join_tuples_inner!( - NoNull(a).into_par_iter().enumerate(), - NoNull(b).into_par_iter().enumerate(), - swap - ), - _ => par_hash_join_tuples_inner!(a.into_par_iter().enumerate(), b.into_par_iter().enumerate(), swap), - } - } - - fn hash_join_left(&self, other: &Utf8Chunked) -> Vec<(usize, Option)> { - match (self.is_optimal_aligned(), other.is_optimal_aligned()) { - (true, true) => { - hash_join_tuples_left(self.into_no_null_iter(), other.into_no_null_iter()) - } - _ => hash_join_tuples_left(self.into_iter(), other.into_iter()), - } - } - - #[cfg(feature = "parallel")] - fn par_hash_join_left(&self, other: &Utf8Chunked) -> Vec<(usize, Option)> { - match (self.is_optimal_aligned(), other.is_optimal_aligned()) { - (true, true) => par_hash_join_tuples_left!( - NoNull(self).into_par_iter().enumerate(), - NoNull(other).into_par_iter().enumerate() - ), - _ => par_hash_join_tuples_left!(self.into_par_iter().enumerate(), other.into_par_iter().enumerate()), - } - } - - fn hash_join_outer(&self, other: &Utf8Chunked) -> Vec<(Option, Option)> { - let (a, b, swap) = det_hash_prone_order!(self, other); - match (a.is_optimal_aligned(), b.is_optimal_aligned()) { - (true, true) => { - hash_join_tuples_outer(a.into_no_null_iter(), b.into_no_null_iter(), swap) - } - _ => hash_join_tuples_outer(a.into_iter(), b.into_iter(), swap), - } - } -} - -#[enum_dispatch(Series)] -trait ZipOuterJoinColumn { - fn zip_outer_join_column( - &self, - _right_column: &Series, - _opt_join_tuples: &Vec<(Option, Option)>, - ) -> Series { - unimplemented!() - } -} - -impl ZipOuterJoinColumn for ChunkedArray -where - T: PolarsIntegerType, -{ - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &Vec<(Option, Option)>, - ) -> Series { - let right_ca = self.unpack_series_matching_type(right_column).unwrap(); - - let left_rand_access = self.take_rand(); - let right_rand_access = right_ca.take_rand(); - - opt_join_tuples - .iter() - .map(|(opt_left_idx, opt_right_idx)| { - if let Some(left_idx) = opt_left_idx { - unsafe { left_rand_access.get_unchecked(*left_idx) } - } else { - unsafe { - let right_idx = opt_right_idx.unsafe_unwrap(); - right_rand_access.get_unchecked(right_idx) - } - } - }) - .collect::>>() - .into_inner() - .into_series() - } -} - -impl ZipOuterJoinColumn for Float32Chunked {} -impl ZipOuterJoinColumn for Float64Chunked {} -impl ZipOuterJoinColumn for LargeListChunked {} - -macro_rules! impl_zip_outer_join { - ($chunkedtype:ident) => { - impl ZipOuterJoinColumn for $chunkedtype { - fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &Vec<(Option, Option)>, - ) -> Series { - let right_ca = self.unpack_series_matching_type(right_column).unwrap(); - - let left_rand_access = self.take_rand(); - let right_rand_access = right_ca.take_rand(); - - opt_join_tuples - .iter() - .map(|(opt_left_idx, opt_right_idx)| { - if let Some(left_idx) = opt_left_idx { - unsafe { left_rand_access.get_unchecked(*left_idx) } - } else { - unsafe { - let right_idx = opt_right_idx.unsafe_unwrap(); - right_rand_access.get_unchecked(right_idx) - } - } - }) - .collect::<$chunkedtype>() - .into_series() - } - } - }; -} -impl_zip_outer_join!(BooleanChunked); -impl_zip_outer_join!(Utf8Chunked); - -impl DataFrame { - /// Utility method to finish a join. - fn finish_join(&self, mut df_left: DataFrame, mut df_right: DataFrame) -> Result { - let mut left_names = - HashSet::with_capacity_and_hasher(df_left.width(), FnvBuildHasher::default()); - - df_left.columns.iter().for_each(|series| { - left_names.insert(series.name()); - }); - - let mut rename_strs = Vec::with_capacity(df_right.width()); - - df_right.columns.iter().for_each(|series| { - if left_names.contains(series.name()) { - rename_strs.push(series.name().to_owned()) - } - }); - - for name in rename_strs { - df_right.rename(&name, &format!("{}_right", name))?; - } - - df_left.hstack(&df_right.columns)?; - Ok(df_left) - } - - fn create_left_df(&self, join_tuples: &[(usize, B)]) -> DataFrame { - unsafe { - self.take_iter_unchecked( - join_tuples.iter().map(|(left, _right)| *left), - Some(join_tuples.len()), - ) - } - } - - /// Perform an inner join on two DataFrames. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> Result { - /// left.inner_join(right, "join_column_left", "join_column_right") - /// } - /// ``` - pub fn inner_join( - &self, - other: &DataFrame, - left_on: &str, - right_on: &str, - ) -> Result { - let s_left = self.column(left_on)?; - let s_right = other.column(right_on)?; - let join_tuples = match self.parallel { - true => apply_hash_join_on_series!(s_left, s_right, par_hash_join_inner), - false => apply_hash_join_on_series!(s_left, s_right, hash_join_inner), - }; - - let (df_left, df_right) = rayon::join( - || self.create_left_df(&join_tuples), - || unsafe { - other.drop(right_on).unwrap().take_iter_unchecked( - join_tuples.iter().map(|(_left, right)| *right), - Some(join_tuples.len()), - ) - }, - ); - self.finish_join(df_left, df_right) - } - - /// Perform a left join on two DataFrames - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> Result { - /// left.left_join(right, "join_column_left", "join_column_right") - /// } - /// ``` - pub fn left_join(&self, other: &DataFrame, left_on: &str, right_on: &str) -> Result { - let s_left = self.column(left_on)?; - let s_right = other.column(right_on)?; - let opt_join_tuples = match self.parallel { - true => apply_hash_join_on_series!(s_left, s_right, par_hash_join_left), - false => apply_hash_join_on_series!(s_left, s_right, hash_join_left), - }; - - let (df_left, df_right) = rayon::join( - || self.create_left_df(&opt_join_tuples), - || unsafe { - other.drop(right_on).unwrap().take_opt_iter_unchecked( - opt_join_tuples.iter().map(|(_left, right)| *right), - Some(opt_join_tuples.len()), - ) - }, - ); - self.finish_join(df_left, df_right) - } - - /// Perform an outer join on two DataFrames - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> Result { - /// left.outer_join(right, "join_column_left", "join_column_right") - /// } - /// ``` - pub fn outer_join( - &self, - other: &DataFrame, - left_on: &str, - right_on: &str, - ) -> Result { - let s_left = self.column(left_on)?; - let s_right = other.column(right_on)?; - - // Get the indexes of the joined relations - let opt_join_tuples: Vec<(Option, Option)> = - apply_hash_join_on_series!(s_left, s_right, hash_join_outer); - - // Take the left and right dataframes by join tuples - let (mut df_left, df_right) = rayon::join( - || unsafe { - self.drop(left_on).unwrap().take_opt_iter_unchecked( - opt_join_tuples.iter().map(|(left, _right)| *left), - Some(opt_join_tuples.len()), - ) - }, - || unsafe { - other.drop(right_on).unwrap().take_opt_iter_unchecked( - opt_join_tuples.iter().map(|(_left, right)| *right), - Some(opt_join_tuples.len()), - ) - }, - ); - let mut s = s_left.zip_outer_join_column(s_right, &opt_join_tuples); - s.rename(left_on); - df_left.hstack(&[s])?; - self.finish_join(df_left, df_right) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - fn create_frames() -> (DataFrame, DataFrame) { - let s0 = Series::new("days", &[0, 1, 2]); - let s1 = Series::new("temp", &[22.1, 19.9, 7.]); - let s2 = Series::new("rain", &[0.2, 0.1, 0.3]); - let temp = DataFrame::new(vec![s0, s1, s2]).unwrap(); - - let s0 = Series::new("days", &[1, 2, 3, 1]); - let s1 = Series::new("rain", &[0.1, 0.2, 0.3, 0.4]); - let rain = DataFrame::new(vec![s0, s1]).unwrap(); - (temp, rain) - } - - #[test] - fn test_inner_join() { - let (temp, rain) = create_frames(); - let joined = temp.inner_join(&rain, "days", "days").unwrap(); - - let join_col_days = Series::new("days", &[1, 2, 1]); - let join_col_temp = Series::new("temp", &[19.9, 7., 19.9]); - let join_col_rain = Series::new("rain", &[0.1, 0.3, 0.1]); - let join_col_rain_right = Series::new("rain_right", [0.1, 0.2, 0.4].as_ref()); - let true_df = DataFrame::new(vec![ - join_col_days, - join_col_temp, - join_col_rain, - join_col_rain_right, - ]) - .unwrap(); - - println!("{}", joined); - assert!(joined.frame_equal(&true_df)); - } - - #[test] - fn test_left_join() { - let s0 = Series::new("days", &[0, 1, 2, 3, 4]); - let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); - let temp = DataFrame::new(vec![s0, s1]).unwrap(); - - let s0 = Series::new("days", &[1, 2]); - let s1 = Series::new("rain", &[0.1, 0.2]); - let rain = DataFrame::new(vec![s0, s1]).unwrap(); - let joined = temp.left_join(&rain, "days", "days").unwrap(); - println!("{}", &joined); - assert_eq!( - (joined.column("rain").unwrap().sum::().unwrap() * 10.).round(), - 3. - ); - assert_eq!(joined.column("rain").unwrap().null_count(), 3); - - // test join on utf8 - let s0 = Series::new("days", &["mo", "tue", "wed", "thu", "fri"]); - let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); - let temp = DataFrame::new(vec![s0, s1]).unwrap(); - - let s0 = Series::new("days", &["tue", "wed"]); - let s1 = Series::new("rain", &[0.1, 0.2]); - let rain = DataFrame::new(vec![s0, s1]).unwrap(); - let joined = temp.left_join(&rain, "days", "days").unwrap(); - println!("{}", &joined); - assert_eq!( - (joined.column("rain").unwrap().sum::().unwrap() * 10.).round(), - 3. - ); - assert_eq!(joined.column("rain").unwrap().null_count(), 3); - } - - #[test] - fn test_outer_join() { - let (temp, rain) = create_frames(); - let joined = temp.outer_join(&rain, "days", "days").unwrap(); - println!("{:?}", &joined); - assert_eq!(joined.height(), 5); - assert_eq!(joined.column("days").unwrap().sum::(), Some(7)); - } -} diff --git a/polars/src/frame/mod.rs b/polars/src/frame/mod.rs deleted file mode 100644 index f51cf51888bf..000000000000 --- a/polars/src/frame/mod.rs +++ /dev/null @@ -1,1067 +0,0 @@ -//! DataFrame module. -use crate::frame::select::Selection; -use crate::prelude::*; -use arrow::datatypes::{Field, Schema}; -use arrow::record_batch::RecordBatch; -use itertools::Itertools; -use rayon::prelude::*; -use std::marker::Sized; -use std::mem; -use std::sync::Arc; - -pub mod group_by; -pub mod hash_join; -pub mod select; -pub mod ser; -mod upstream_traits; - -pub trait IntoSeries { - fn into_series(self) -> Series - where - Self: Sized; -} - -impl IntoSeries for Series { - fn into_series(self) -> Series { - self - } -} - -impl IntoSeries for ChunkedArray { - fn into_series(self) -> Series { - Series::from_chunked_array(self) - } -} - -type DfSchema = Arc; -type DfSeries = Series; -type DfColumns = Vec; - -#[derive(Clone)] -pub struct DataFrame { - columns: DfColumns, - parallel: bool, -} - -impl DataFrame { - /// Get the index of the column. - fn name_to_idx(&self, name: &str) -> Result { - let mut idx = 0; - for column in &self.columns { - if column.name() == name { - break; - } - idx += 1; - } - if idx == self.columns.len() { - Err(PolarsError::NotFound) - } else { - Ok(idx) - } - } - - /// Create a DataFrame from a Vector of Series. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// let s0 = Series::new("days", [0, 1, 2].as_ref()); - /// let s1 = Series::new("temp", [22.1, 19.9, 7.].as_ref()); - /// let df = DataFrame::new(vec![s0, s1]).unwrap(); - /// ``` - pub fn new(columns: Vec) -> Result { - let mut first_len = None; - let mut series_cols = Vec::with_capacity(columns.len()); - - // check for series length equality and convert into series in one pass - for s in columns { - let series = s.into_series(); - match first_len { - Some(len) => { - if series.len() != len { - return Err(PolarsError::ShapeMisMatch); - } - } - None => first_len = Some(series.len()), - } - series_cols.push(series) - } - let mut df = DataFrame { - columns: series_cols, - parallel: false, - }; - df.rechunk()?; - Ok(df) - } - - /// Opt in parallel operations. - #[cfg(feature = "parallel")] - pub fn with_parallel(&mut self, parallel: bool) -> &mut Self { - self.parallel = parallel; - self - } - - // doesn't check Series sizes. - fn new_no_checks(columns: Vec) -> DataFrame { - DataFrame { - columns, - parallel: false, - } - } - - /// Ensure all the chunks in the DataFrame are aligned. - fn rechunk(&mut self) -> Result<&mut Self> { - let mut chunk_lens = Vec::with_capacity(self.columns.len()); - let mut all_equal = true; - for series in &self.columns { - let current_len = series.len(); - if chunk_lens.len() > 1 { - if current_len != chunk_lens[0] { - all_equal = false - } - } - chunk_lens.push(series.len()) - } - - // fast path - if all_equal { - Ok(self) - } else { - let argmin = chunk_lens - .iter() - .position_min() - .ok_or(PolarsError::NoData)?; - let min_chunks = chunk_lens[argmin]; - - let to_rechunk = chunk_lens - .into_iter() - .enumerate() - .filter_map(|(idx, len)| if len > min_chunks { Some(idx) } else { None }) - .collect::>(); - - // clone shouldn't be too expensive as we expect the nr. of chunks to be close to 1. - let chunk_id = self.columns[argmin].chunk_lengths().clone(); - - for idx in to_rechunk { - let col = &self.columns[idx]; - let new_col = col.rechunk(Some(&chunk_id))?; - self.columns[idx] = new_col; - } - Ok(self) - } - } - - /// Get a reference to the DataFrame schema. - pub fn schema(&self) -> Schema { - let fields = Self::create_fields(&self.columns); - Schema::new(fields) - } - - /// Get a reference to the DataFrame columns. - #[inline] - pub fn get_columns(&self) -> &DfColumns { - &self.columns - } - - /// Get the column labels of the DataFrame. - pub fn columns(&self) -> Vec<&str> { - self.columns.iter().map(|s| s.name()).collect() - } - - /// Get the data types of the columns in the DataFrame. - pub fn dtypes(&self) -> Vec { - self.columns.iter().map(|s| s.dtype().clone()).collect() - } - - /// The number of chunks per column - pub fn n_chunks(&self) -> Result { - Ok(self - .columns - .get(0) - .ok_or(PolarsError::NoData)? - .chunks() - .len()) - } - - /// Get fields from the columns. - fn create_fields(columns: &DfColumns) -> Vec { - columns.iter().map(|s| s.field().clone()).collect() - } - - /// This method should be called after every mutable addition/ deletion of columns - fn register_mutation(&mut self) -> Result<()> { - self.rechunk()?; - Ok(()) - } - - /// Get a reference to the schema fields of the DataFrame. - pub fn fields(&self) -> Vec { - self.columns.iter().map(|s| s.field().clone()).collect() - } - - /// Get (width x height) - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn assert_shape(df: &DataFrame, shape: (usize, usize)) { - /// assert_eq!(df.shape(), shape) - /// } - /// ``` - pub fn shape(&self) -> (usize, usize) { - let rows = self.columns.len(); - if rows > 0 { - (self.columns[0].len(), rows) - } else { - (0, 0) - } - } - - /// Get width of DataFrame - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn assert_width(df: &DataFrame, width: usize) { - /// assert_eq!(df.width(), width) - /// } - /// ``` - pub fn width(&self) -> usize { - self.shape().1 - } - - /// Get height of DataFrame - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn assert_height(df: &DataFrame, height: usize) { - /// assert_eq!(df.height(), height) - /// } - /// ``` - pub fn height(&self) -> usize { - self.shape().0 - } - - /// Add multiple Series to a DataFrame - /// This expects the Series to have the same length. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn stack(df: &mut DataFrame, columns: &[Series]) { - /// df.hstack(columns); - /// } - /// ``` - pub fn hstack(&mut self, columns: &[DfSeries]) -> Result<&mut Self> { - let height = self.height(); - for col in columns { - if col.len() != height { - return Err(PolarsError::ShapeMisMatch); - } else { - self.columns.push(col.clone()); - } - } - self.register_mutation()?; - Ok(self) - } - - /// Concatenate a DataFrame to this DataFrame - pub fn vstack(&mut self, df: &DataFrame) -> Result<&mut Self> { - if self.width() != df.width() { - return Err(PolarsError::ShapeMisMatch); - } - - if self.dtypes() != df.dtypes() { - return Err(PolarsError::DataTypeMisMatch); - } - self.columns - .iter_mut() - .zip(df.columns.iter()) - .for_each(|(left, right)| { - left.append(right).expect("should not fail"); - }); - self.register_mutation()?; - Ok(self) - } - - /// Remove column by name - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn drop_column(df: &mut DataFrame, name: &str) -> Result { - /// df.drop_in_place(name) - /// } - /// ``` - pub fn drop_in_place(&mut self, name: &str) -> Result { - let idx = self.name_to_idx(name)?; - let result = Ok(self.columns.remove(idx)); - self.register_mutation()?; - result - } - - /// Drop a column by name. - /// This is a pure method and will return a new DataFrame instead of modifying - /// the current one in place. - pub fn drop(&self, name: &str) -> Result { - let idx = self.name_to_idx(name)?; - let mut new_cols = Vec::with_capacity(self.columns.len() - 1); - - self.columns.iter().enumerate().for_each(|(i, s)| { - if i != idx { - new_cols.push(s.clone()) - } - }); - - Ok(DataFrame::new_no_checks(new_cols)) - } - - /// Add a new column to this `DataFrame`. - pub fn add_column(&mut self, column: S) -> Result<&mut Self> { - let series = column.into_series(); - if series.len() == self.height() { - self.columns.push(series); - Ok(self) - } else { - Err(PolarsError::ShapeMisMatch) - } - } - - /// Create a new `DataFrame` with the column added. - pub fn with_column(&self, column: S) -> Result { - let mut df = self.clone(); - df.add_column(column)?; - Ok(df) - } - - /// Get a row in the `DataFrame` Beware this is slow. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn example(df: &mut DataFrame, idx: usize) -> Option> { - /// df.get(idx) - /// } - /// ``` - pub fn get(&self, idx: usize) -> Option> { - match self.columns.get(0) { - Some(s) => { - if s.len() <= idx { - return None; - } - } - None => return None, - } - Some(self.columns.iter().map(|s| s.get(idx)).collect()) - } - - /// Select a series by index. - pub fn select_at_idx(&self, idx: usize) -> Option<&Series> { - self.columns.get(idx) - } - - /// Select a mutable series by index. - /// - /// *Note: the length of the Series should remain the same otherwise the DataFrame is invalid.* - /// For this reason the method is not public - fn select_idx_mut(&mut self, idx: usize) -> Option<&mut Series> { - self.columns.get_mut(idx) - } - - /// Get column index of a series by name. - pub fn find_idx_by_name(&self, name: &str) -> Option { - self.columns - .iter() - .enumerate() - .filter(|(_idx, series)| series.name() == name) - .map(|(idx, _)| idx) - .next() - } - - /// Select a single column by name. - pub fn column(&self, name: &str) -> Result<&Series> { - let idx = self.find_idx_by_name(name).ok_or(PolarsError::NotFound)?; - Ok(self.select_at_idx(idx).unwrap()) - } - - /// Select column(s) from this DataFrame and return a new DataFrame. - /// - /// # Examples - /// - /// ``` - /// use polars::prelude::*; - /// - /// fn example(df: &DataFrame, possible: &str) -> Result { - /// match possible { - /// "by_str" => df.select("my-column"), - /// "by_tuple" => df.select(("col_1", "col_2")), - /// "by_vec" => df.select(vec!["col_a", "col_b"]), - /// _ => unimplemented!() - /// } - /// } - /// ``` - pub fn select<'a, S, J>(&self, selection: S) -> Result - where - S: Selection<'a, J>, - { - let selected = self.select_series(selection)?; - let df = DataFrame::new_no_checks(selected); - Ok(df) - } - - /// Select column(s) from this DataFrame and return them into a Vector. - pub fn select_series<'a, S, J>(&self, selection: S) -> Result> - where - S: Selection<'a, J>, - { - let cols = selection.to_selection_vec(); - let selected = cols - .iter() - .map(|c| self.column(c).map(|s| s.clone())) - .collect::>>()?; - Ok(selected) - } - - /// Select a mutable series by name. - /// *Note: the length of the Series should remain the same otherwise the DataFrame is invalid.* - /// For this reason the method is not public - fn select_mut(&mut self, name: &str) -> Option<&mut Series> { - let opt_idx = self.find_idx_by_name(name); - - match opt_idx { - Some(idx) => self.select_idx_mut(idx), - None => None, - } - } - - /// Take DataFrame rows by a boolean mask. - pub fn filter(&self, mask: &BooleanChunked) -> Result { - let new_col = self - .columns - .par_iter() - .map(|col| col.filter(mask)) - .collect::>>()?; - Ok(DataFrame::new_no_checks(new_col)) - } - - /// Take DataFrame value by indexes from an iterator. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn example(df: &DataFrame) -> Result { - /// let iterator = (0..9).into_iter(); - /// df.take_iter(iterator, None) - /// } - /// ``` - pub fn take_iter(&self, iter: I, capacity: Option) -> Result - where - I: Iterator + Clone + Sync, - { - let new_col = self - .columns - .par_iter() - .map(|s| { - let mut i = iter.clone(); - s.take_iter(&mut i, capacity) - }) - .collect::>>()?; - Ok(DataFrame::new_no_checks(new_col)) - } - - /// Take DataFrame values by indexes from an iterator. This doesn't do any bound checking. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// unsafe fn example(df: &DataFrame) -> DataFrame { - /// let iterator = (0..9).into_iter(); - /// df.take_iter_unchecked(iterator, None) - /// } - /// ``` - pub unsafe fn take_iter_unchecked(&self, iter: I, capacity: Option) -> Self - where - I: Iterator + Clone + Sync, - { - let new_col = self - .columns - .par_iter() - .map(|s| { - let mut i = iter.clone(); - s.take_iter_unchecked(&mut i, capacity) - }) - .collect::>(); - DataFrame::new_no_checks(new_col) - } - - /// Take DataFrame values by indexes from an iterator that may contain None values. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn example(df: &DataFrame) -> Result { - /// let iterator = (0..9).into_iter().map(Some); - /// df.take_opt_iter(iterator, None) - /// } - /// ``` - pub fn take_opt_iter(&self, iter: I, capacity: Option) -> Result - where - I: Iterator> + Clone + Sync, - { - let new_col = self - .columns - .par_iter() - .map(|s| { - let mut i = iter.clone(); - s.take_opt_iter(&mut i, capacity) - }) - .collect::>>()?; - Ok(DataFrame::new_no_checks(new_col)) - } - - /// Take DataFrame values by indexes from an iterator that may contain None values. - /// This doesn't do any bound checking. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// unsafe fn example(df: &DataFrame) -> DataFrame { - /// let iterator = (0..9).into_iter().map(Some); - /// df.take_opt_iter_unchecked(iterator, None) - /// } - /// ``` - pub unsafe fn take_opt_iter_unchecked(&self, iter: I, capacity: Option) -> Self - where - I: Iterator> + Clone + Sync, - { - let new_col = self - .columns - .par_iter() - .map(|s| { - let mut i = iter.clone(); - s.take_opt_iter_unchecked(&mut i, capacity) - }) - .collect::>(); - DataFrame::new_no_checks(new_col) - } - - /// Take DataFrame rows by index values. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn example(df: &DataFrame) -> Result { - /// let idx = vec![0, 1, 9]; - /// df.take(&idx) - /// } - /// ``` - pub fn take(&self, indices: &T) -> Result { - let new_col = self - .columns - .par_iter() - .map(|s| s.take(indices)) - .collect::>>()?; - - Ok(DataFrame::new_no_checks(new_col)) - } - - /// Rename a column in the DataFrame - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// fn example(df: &mut DataFrame) -> Result<&mut DataFrame> { - /// let original_name = "foo"; - /// let new_name = "bar"; - /// df.rename(original_name, new_name) - /// } - /// ``` - pub fn rename(&mut self, column: &str, name: &str) -> Result<&mut Self> { - self.select_mut(column) - .ok_or(PolarsError::NotFound) - .map(|s| s.rename(name))?; - Ok(self) - } - - /// Sort DataFrame in place by a column. - pub fn sort_in_place(&mut self, by_column: &str, reverse: bool) -> Result<&mut Self> { - let s = self.column(by_column)?; - - let take = s.argsort(reverse); - - self.columns = self - .columns - .par_iter() - .map(|s| s.take(&take)) - .collect::>>()?; - Ok(self) - } - - /// Return a sorted clone of this DataFrame. - pub fn sort(&self, by_column: &str, reverse: bool) -> Result { - let s = self.column(by_column)?; - - let take = s.argsort(reverse); - self.take(&take) - } - - /// Replace a column with a series. - pub fn replace(&mut self, column: &str, new_col: S) -> Result<&mut Self> { - self.apply(column, |_| new_col.into_series()) - } - - /// Replace column at index `idx` with a series. - /// - /// # Example - /// - /// ```rust - /// use polars::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg"]); - /// let s1 = Series::new("ascii", &[70, 79, 79]); - /// let mut df = DataFrame::new(vec![s0, s1]).unwrap(); - /// - /// // Add 32 to get lowercase ascii values - /// df.replace_at_idx(1, df.select_at_idx(1).unwrap() + 32); - /// ``` - pub fn replace_at_idx(&mut self, idx: usize, new_col: S) -> Result<&mut Self> { - let mut new_column = new_col.into_series(); - if new_column.len() != self.height() { - return Err(PolarsError::ShapeMisMatch); - }; - if idx >= self.width() { - return Err(PolarsError::OutOfBounds); - } - let old_col = &mut self.columns[idx]; - mem::swap(old_col, &mut new_column); - Ok(self) - } - - /// Apply a closure to a column. This is the recommended way to do in place modification. - /// - /// # Example - /// - /// ```rust - /// use polars::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg"]); - /// let s1 = Series::new("names", &["Jean", "Claude", "van"]); - /// let mut df = DataFrame::new(vec![s0, s1]).unwrap(); - /// - /// fn str_to_len(str_val: &Series) -> Series { - /// str_val.utf8() - /// .unwrap() - /// .into_iter() - /// .map(|opt_name: Option<&str>| { - /// opt_name.map(|name: &str| name.len() as u32) - /// }) - /// .collect::() - /// .into_series() - /// } - /// - /// // Replace the names column by the length of the names. - /// df.apply("names", str_to_len); - /// ``` - /// Results in: - /// - /// ```text - /// +--------+-------+ - /// | foo | | - /// | --- | names | - /// | str | u32 | - /// +========+=======+ - /// | "ham" | 4 | - /// +--------+-------+ - /// | "spam" | 6 | - /// +--------+-------+ - /// | "egg" | 3 | - /// +--------+-------+ - /// ``` - pub fn apply(&mut self, column: &str, f: F) -> Result<&mut Self> - where - F: FnOnce(&Series) -> S, - S: IntoSeries, - { - let idx = self.find_idx_by_name(column).ok_or(PolarsError::NotFound)?; - self.apply_at_idx(idx, f) - } - - /// Apply a closure to a column at index `idx`. This is the recommended way to do in place - /// modification. - /// - /// # Example - /// - /// ```rust - /// use polars::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg"]); - /// let s1 = Series::new("ascii", &[70, 79, 79]); - /// let mut df = DataFrame::new(vec![s0, s1]).unwrap(); - /// - /// // Add 32 to get lowercase ascii values - /// df.apply_at_idx(1, |s| s + 32); - /// ``` - /// Results in: - /// - /// ```text - /// +--------+-------+ - /// | foo | ascii | - /// | --- | --- | - /// | str | i32 | - /// +========+=======+ - /// | "ham" | 102 | - /// +--------+-------+ - /// | "spam" | 111 | - /// +--------+-------+ - /// | "egg" | 111 | - /// +--------+-------+ - /// ``` - pub fn apply_at_idx(&mut self, idx: usize, f: F) -> Result<&mut Self> - where - F: FnOnce(&Series) -> S, - S: IntoSeries, - { - let col = self.columns.get_mut(idx).ok_or(PolarsError::OutOfBounds)?; - let name = col.name().to_string(); - let _ = mem::replace(col, f(col).into_series()); - - // make sure the name remains the same after applying the closure - unsafe { - let col = self.columns.get_unchecked_mut(idx); - col.rename(&name); - } - self.register_mutation()?; - Ok(self) - } - - /// Apply a closure that may fail to a column at index `idx`. This is the recommended way to do in place - /// modification. - /// - /// # Example - /// - /// This is the idomatic way to replace some values a column of a `DataFrame` given range of indexes. - /// - /// ``` - /// # use polars::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg", "bacon", "quack"]); - /// let s1 = Series::new("values", &[1, 2, 3, 4, 5]); - /// let mut df = DataFrame::new(vec![s0, s1]).unwrap(); - /// - /// let idx = &[0, 1, 4]; - /// - /// df.may_apply("foo", |s| { - /// s.utf8()? - /// .set_at_idx_with(idx, |opt_val| opt_val.map(|string| format!("{}-is-modified", string))) - /// }); - /// ``` - /// Results in: - /// - /// ```text - /// +---------------------+--------+ - /// | foo | values | - /// | --- | --- | - /// | str | i32 | - /// +=====================+========+ - /// | "ham-is-modified" | 1 | - /// +---------------------+--------+ - /// | "spam-is-modified" | 2 | - /// +---------------------+--------+ - /// | "egg" | 3 | - /// +---------------------+--------+ - /// | "bacon" | 4 | - /// +---------------------+--------+ - /// | "quack-is-modified" | 5 | - /// +---------------------+--------+ - /// ``` - pub fn may_apply_at_idx(&mut self, idx: usize, f: F) -> Result<&mut Self> - where - F: FnOnce(&Series) -> Result, - S: IntoSeries, - { - let col = self.columns.get_mut(idx).ok_or(PolarsError::OutOfBounds)?; - let name = col.name().to_string(); - - let _ = mem::replace(col, f(col).map(|s| s.into_series())?); - - // make sure the name remains the same after applying the closure - unsafe { - let col = self.columns.get_unchecked_mut(idx); - col.rename(&name); - } - self.register_mutation()?; - Ok(self) - } - - /// Apply a closure that may fail to a column. This is the recommended way to do in place - /// modification. - /// - /// # Example - /// - /// This is the idomatic way to replace some values a column of a `DataFrame` given a boolean mask. - /// - /// ``` - /// # use polars::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg", "bacon", "quack"]); - /// let s1 = Series::new("values", &[1, 2, 3, 4, 5]); - /// let mut df = DataFrame::new(vec![s0, s1]).unwrap(); - /// - /// // create a mask - /// let mask = || { - /// (df.column("values")?.lt_eq(1) | df.column("values")?.gt_eq(5)) - /// }; - /// let mask = mask().unwrap(); - /// - /// df.may_apply("foo", |s| { - /// s.utf8()? - /// .set(&mask, Some("not_within_bounds")) - /// }); - /// ``` - /// Results in: - /// - /// ```text - /// +---------------------+--------+ - /// | foo | values | - /// | --- | --- | - /// | str | i32 | - /// +=====================+========+ - /// | "not_within_bounds" | 1 | - /// +---------------------+--------+ - /// | "spam" | 2 | - /// +---------------------+--------+ - /// | "egg" | 3 | - /// +---------------------+--------+ - /// | "bacon" | 4 | - /// +---------------------+--------+ - /// | "not_within_bounds" | 5 | - /// +---------------------+--------+ - /// ``` - pub fn may_apply(&mut self, column: &str, f: F) -> Result<&mut Self> - where - F: FnOnce(&Series) -> Result, - S: IntoSeries, - { - let idx = self.find_idx_by_name(column).ok_or(PolarsError::NotFound)?; - self.may_apply_at_idx(idx, f) - } - - /// Slice the DataFrame along the rows. - pub fn slice(&self, offset: usize, length: usize) -> Result { - let col = self - .columns - .par_iter() - .map(|s| s.slice(offset, length)) - .collect::>>()?; - Ok(DataFrame::new_no_checks(col)) - } - - /// Get the head of the DataFrame - pub fn head(&self, length: Option) -> Self { - let col = self - .columns - .iter() - .map(|s| s.head(length)) - .collect::>(); - DataFrame::new_no_checks(col) - } - - /// Get the tail of the DataFrame - pub fn tail(&self, length: Option) -> Self { - let col = self - .columns - .iter() - .map(|s| s.tail(length)) - .collect::>(); - DataFrame::new_no_checks(col) - } - - /// Transform the underlying chunks in the DataFrame to Arrow RecordBatches - pub fn as_record_batches(&self) -> Result> { - let n_chunks = self.n_chunks()?; - let width = self.width(); - - let schema = Arc::new(self.schema()); - - let mut record_batches = Vec::with_capacity(n_chunks); - for i in 0..n_chunks { - // the columns of a single recorbatch - let mut rb_cols = Vec::with_capacity(width); - - for col in &self.columns { - rb_cols.push(Arc::clone(&col.chunks()[i])) - } - let rb = RecordBatch::try_new(Arc::clone(&schema), rb_cols)?; - record_batches.push(rb) - } - Ok(record_batches) - } - - pub fn iter_record_batches( - &mut self, - buffer_size: usize, - ) -> impl Iterator + '_ { - match self.n_chunks() { - Ok(1) => {} - Ok(_) => { - self.columns = self - .columns - .iter() - .map(|s| s.rechunk(None).unwrap()) - .collect(); - } - Err(_) => {} // no data. So iterator will be empty - } - RecordBatchIter { - columns: &self.columns, - schema: Arc::new(self.schema()), - buffer_size, - idx: 0, - len: self.height(), - } - } - - /// Get a DataFrame with all the columns in reversed order - pub fn reverse(&self) -> Self { - let col = self.columns.iter().map(|s| s.reverse()).collect::>(); - DataFrame::new_no_checks(col) - } - - /// Shift the values by a given period and fill the parts that will be empty due to this operation - /// with `Nones`. - /// - /// See the method on [Series](../series/enum.Series.html#method.shift) for more info on the `shift` operation. - pub fn shift(&self, periods: i32) -> Result { - let col = self - .columns - .iter() - .map(|s| s.shift(periods)) - .collect::>>()?; - Ok(DataFrame::new_no_checks(col)) - } - - /// Replace None values with one of the following strategies: - /// * Forward fill (replace None with the previous value) - /// * Backward fill (replace None with the next value) - /// * Mean fill (replace None with the mean of the whole array) - /// * Min fill (replace None with the minimum of the whole array) - /// * Max fill (replace None with the maximum of the whole array) - /// - /// See the method on [Series](../series/enum.Series.html#method.fill_none) for more info on the `fill_none` operation. - pub fn fill_none(&self, strategy: FillNoneStrategy) -> Result { - let col = self - .columns - .iter() - .map(|s| s.fill_none(strategy)) - .collect::>>()?; - Ok(DataFrame::new_no_checks(col)) - } - - /// Pipe different functions/ closure operations that work on a DataFrame together. - pub fn pipe(self, f: F) -> Result - where - F: Fn(DataFrame) -> Result, - { - f(self) - } - - /// Pipe different functions/ closure operations that work on a DataFrame together. - pub fn pipe_mut(&mut self, f: F) -> Result - where - F: Fn(&mut DataFrame) -> Result, - { - f(self) - } - - /// Pipe different functions/ closure operations that work on a DataFrame together. - pub fn pipe_with_args(self, f: F, args: Args) -> Result - where - F: Fn(DataFrame, Args) -> Result, - { - f(self, args) - } -} - -pub struct RecordBatchIter<'a> { - columns: &'a Vec, - schema: Arc, - buffer_size: usize, - idx: usize, - len: usize, -} - -impl<'a> Iterator for RecordBatchIter<'a> { - type Item = RecordBatch; - - fn next(&mut self) -> Option { - if self.idx >= self.len { - return None; - } - // most iterations the slice length will be buffer_size, except for the last. That one - // may be shorter - let length = if self.idx + self.buffer_size < self.len { - self.buffer_size - } else { - self.len - self.idx - }; - - let mut rb_cols = Vec::with_capacity(self.columns.len()); - // take a slice from all columns and add the the current RecordBatch - self.columns.iter().for_each(|s| { - let slice = s.slice(self.idx, length).unwrap(); - rb_cols.push(Arc::clone(&slice.chunks()[0])) - }); - let rb = RecordBatch::try_new(Arc::clone(&self.schema), rb_cols).unwrap(); - self.idx += length; - Some(rb) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - fn create_frame() -> DataFrame { - let s0 = Series::new("days", [0, 1, 2].as_ref()); - let s1 = Series::new("temp", [22.1, 19.9, 7.].as_ref()); - DataFrame::new(vec![s0, s1]).unwrap() - } - - #[test] - fn test_select() { - let df = create_frame(); - assert_eq!(df.column("days").unwrap().eq(1).sum(), Some(1)); - } - - #[test] - fn test_filter() { - let df = create_frame(); - println!("{}", df.column("days").unwrap()); - println!("{:?}", df); - println!("{:?}", df.filter(&df.column("days").unwrap().eq(0))) - } - - #[test] - fn test_sort() { - let mut df = create_frame(); - df.sort_in_place("temp", false).unwrap(); - println!("{:?}", df); - } - - #[test] - fn slice() { - let df = create_frame(); - let sliced_df = df.slice(0, 2).expect("slice"); - assert_eq!(sliced_df.shape(), (2, 2)); - println!("{:?}", df) - } -} diff --git a/polars/src/frame/select.rs b/polars/src/frame/select.rs deleted file mode 100644 index 17c4a6052c7b..000000000000 --- a/polars/src/frame/select.rs +++ /dev/null @@ -1,54 +0,0 @@ -pub trait Selection<'a, S> { - fn to_selection_vec(self) -> Vec<&'a str>; -} - -impl<'a> Selection<'a, &str> for &'a str { - fn to_selection_vec(self) -> Vec<&'a str> { - vec![self] - } -} - -impl<'a> Selection<'a, &str> for Vec<&'a str> { - fn to_selection_vec(self) -> Vec<&'a str> { - self - } -} - -impl<'a, T, S: 'a> Selection<'a, S> for &'a T -where - T: AsRef<[S]>, - S: AsRef, -{ - fn to_selection_vec(self) -> Vec<&'a str> { - self.as_ref().iter().map(|s| s.as_ref()).collect() - } -} - -impl<'a> Selection<'a, &str> for (&'a str, &'a str) { - fn to_selection_vec(self) -> Vec<&'a str> { - vec![self.0, self.1] - } -} -impl<'a> Selection<'a, &str> for (&'a str, &'a str, &'a str) { - fn to_selection_vec(self) -> Vec<&'a str> { - vec![self.0, self.1, self.2] - } -} - -impl<'a> Selection<'a, &str> for (&'a str, &'a str, &'a str, &'a str) { - fn to_selection_vec(self) -> Vec<&'a str> { - vec![self.0, self.1, self.2, self.3] - } -} - -impl<'a> Selection<'a, &str> for (&'a str, &'a str, &'a str, &'a str, &'a str) { - fn to_selection_vec(self) -> Vec<&'a str> { - vec![self.0, self.1, self.2, self.3, self.4] - } -} - -impl<'a> Selection<'a, &str> for (&'a str, &'a str, &'a str, &'a str, &'a str, &'a str) { - fn to_selection_vec(self) -> Vec<&'a str> { - vec![self.0, self.1, self.2, self.3, self.4, self.5] - } -} diff --git a/polars/src/frame/ser/csv.rs b/polars/src/frame/ser/csv.rs deleted file mode 100644 index edb53b6ef20f..000000000000 --- a/polars/src/frame/ser/csv.rs +++ /dev/null @@ -1,308 +0,0 @@ -//! # (De)serializing CSV files -//! -//! ## Write a DataFrame to a csv file. -//! -//! ## Example -//! -//! ``` -//! use polars::prelude::*; -//! use std::fs::File; -//! -//! fn example(df: &mut DataFrame) -> Result<()> { -//! let mut file = File::create("example.csv").expect("could not create file"); -//! -//! CsvWriter::new(&mut file) -//! .has_headers(true) -//! .with_delimiter(b',') -//! .finish(df) -//! } -//! ``` -//! -//! ## Read a csv file to a DataFrame -//! -//! ## Example -//! -//! ``` -//! use polars::prelude::*; -//! use std::io::Cursor; -//! -//! let s = r#" -//! "sepal.length","sepal.width","petal.length","petal.width","variety" -//! 5.1,3.5,1.4,.2,"Setosa" -//! 4.9,3,1.4,.2,"Setosa" -//! 4.7,3.2,1.3,.2,"Setosa" -//! 4.6,3.1,1.5,.2,"Setosa" -//! 5,3.6,1.4,.2,"Setosa" -//! 5.4,3.9,1.7,.4,"Setosa" -//! 4.6,3.4,1.4,.3,"Setosa" -//! "#; -//! -//! let file = Cursor::new(s); -//! let df = CsvReader::new(file) -//! .infer_schema(Some(100)) -//! .has_header(true) -//! .with_batch_size(100) -//! .finish() -//! .unwrap(); -//! -//! assert_eq!("sepal.length", df.get_columns()[0].name()); -//! # assert_eq!(1, df.column("sepal.length").unwrap().chunks().len()); -//! ``` -use crate::prelude::*; -use crate::{frame::ser::finish_reader, utils::clone}; -pub use arrow::csv::{ReaderBuilder, WriterBuilder}; -use arrow::error::ArrowError; -use std::io::{Read, Seek, Write}; -use std::sync::Arc; - -/// Write a DataFrame to csv. -pub struct CsvWriter<'a, W: Write> { - /// File or Stream handler - buffer: &'a mut W, - /// Builds an Arrow CSV Writer - writer_builder: WriterBuilder, - buffer_size: usize, -} - -impl<'a, W> SerWriter<'a, W> for CsvWriter<'a, W> -where - W: Write, -{ - fn new(buffer: &'a mut W) -> Self { - CsvWriter { - buffer, - writer_builder: WriterBuilder::new(), - buffer_size: 1000, - } - } - - fn finish(self, df: &mut DataFrame) -> Result<()> { - let mut csv_writer = self.writer_builder.build(self.buffer); - - let iter = df.iter_record_batches(self.buffer_size); - for batch in iter { - csv_writer.write(&batch)? - } - Ok(()) - } -} - -impl<'a, W> CsvWriter<'a, W> -where - W: Write, -{ - /// Set whether to write headers - pub fn has_headers(mut self, has_headers: bool) -> Self { - self.writer_builder = self.writer_builder.has_headers(has_headers); - self - } - - /// Set the CSV file's column delimiter as a byte character - pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.writer_builder = self.writer_builder.with_delimiter(delimiter); - self - } - - /// Set the CSV file's date format - pub fn with_date_format(mut self, format: String) -> Self { - self.writer_builder = self.writer_builder.with_date_format(format); - self - } - - /// Set the CSV file's time format - pub fn with_time_format(mut self, format: String) -> Self { - self.writer_builder = self.writer_builder.with_time_format(format); - self - } - - /// Set the CSV file's timestamp formatch array in - pub fn with_timestamp_format(mut self, format: String) -> Self { - self.writer_builder = self.writer_builder.with_timestamp_format(format); - self - } - - /// Set the size of the write buffers. Batch size is the amount of rows written at once. - pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.buffer_size = batch_size; - self - } -} - -/// Creates a DataFrame after reading a csv. -pub struct CsvReader -where - R: Read + Seek, -{ - /// File or Stream object - reader: R, - /// Builds an Arrow csv reader - reader_builder: ReaderBuilder, - /// Aggregates chunk afterwards to a single chunk. - rechunk: bool, - /// Continue with next batch when a ParserError is encountered. - ignore_parser_error: bool, - // use by error ignore logic - max_records: Option, -} - -impl SerReader for CsvReader -where - R: Read + Seek, -{ - /// Create a new CsvReader from a file/ stream - fn new(reader: R) -> Self { - CsvReader { - reader, - reader_builder: ReaderBuilder::new(), - rechunk: true, - ignore_parser_error: false, - max_records: None, - } - } - - /// Rechunk to one contiguous chunk of memory after all data is read - fn set_rechunk(mut self, rechunk: bool) -> Self { - self.rechunk = rechunk; - self - } - - /// Continue with next batch when a ParserError is encountered. - fn with_ignore_parser_error(mut self) -> Self { - self.ignore_parser_error = true; - self - } - - /// Read the file and create the DataFrame. - fn finish(mut self) -> Result { - // It could be that we could not infer schema due to invalid lines. - // If we have a CsvError or ParserError we half the number of lines we use for - // schema inference - let reader = if self.ignore_parser_error && self.max_records.is_some() { - if self.max_records < Some(1) { - return Err(PolarsError::Other("Could not infer schema".to_string())); - } - let reader_val; - loop { - let rb = clone(&self.reader_builder); - match rb.build(clone(&self.reader)) { - Err(ArrowError::CsvError(_)) | Err(ArrowError::ParseError(_)) => { - self.max_records = self.max_records.map(|v| v / 2); - self.reader_builder = self.reader_builder.infer_schema(self.max_records); - } - Err(e) => return Err(PolarsError::ArrowError(e)), - Ok(reader) => { - reader_val = reader; - break; - } - } - } - reader_val - } else { - self.reader_builder.build(self.reader)? - }; - finish_reader(reader, self.rechunk, self.ignore_parser_error) - } -} - -impl CsvReader -where - R: Read + Seek, -{ - /// Create a new DataFrame by reading a csv file. - /// - /// # Example - /// - /// ``` - /// use polars::prelude::*; - /// use std::fs::File; - /// - /// fn example() -> Result { - /// let file = File::open("iris.csv").expect("could not open file"); - /// - /// CsvReader::new(file) - /// .infer_schema(None) - /// .has_header(true) - /// .finish() - /// } - /// ``` - - /// Set the CSV file's schema - pub fn with_schema(mut self, schema: Arc) -> Self { - self.reader_builder = self.reader_builder.with_schema(schema); - self - } - - /// Set whether the CSV file has headers - pub fn has_header(mut self, has_header: bool) -> Self { - self.reader_builder = self.reader_builder.has_header(has_header); - self - } - - /// Set the CSV file's column delimiter as a byte character - pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.reader_builder = self.reader_builder.with_delimiter(delimiter); - self - } - - /// Set the CSV reader to infer the schema of the file - pub fn infer_schema(mut self, max_records: Option) -> Self { - // used by error ignore logic - self.max_records = max_records; - self.reader_builder = self.reader_builder.infer_schema(max_records); - self - } - - /// Set the batch size (number of records to load at one time) - pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.reader_builder = self.reader_builder.with_batch_size(batch_size); - self - } - - /// Set the reader's column projection - pub fn with_projection(mut self, projection: Vec) -> Self { - self.reader_builder = self.reader_builder.with_projection(projection); - self - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn write_csv() { - let mut buf: Vec = Vec::new(); - let mut df = create_df(); - - CsvWriter::new(&mut buf) - .has_headers(true) - .finish(&mut df) - .expect("csv written"); - let csv = std::str::from_utf8(&buf).unwrap(); - assert_eq!("days,temp\n0,22.1\n1,19.9\n2,7\n3,2\n4,3\n", csv); - } - - #[test] - fn test_parser_error_ignore() { - use std::io::Cursor; - - let s = r#" - "sepal.length","sepal.width","petal.length","petal.width","variety" - 5.1,3.5,1.4,.2,"Setosa" - 5.1,3.5,1.4,.2,"Setosa" - 4.9,3,1.4,.2,"Setosa", "extra-column" - "#; - - let file = Cursor::new(s); - - // just checks if unwrap doesn't panic - CsvReader::new(file) - // we also check if infer schema ignores errors - .infer_schema(Some(10)) - .has_header(true) - .with_batch_size(2) - .with_ignore_parser_error() - .finish() - .unwrap(); - } -} diff --git a/polars/src/frame/ser/ipc.rs b/polars/src/frame/ser/ipc.rs deleted file mode 100644 index a609c2f3a001..000000000000 --- a/polars/src/frame/ser/ipc.rs +++ /dev/null @@ -1,151 +0,0 @@ -//! # (De)serializing Arrows IPC format. -//! -//! Arrow IPC is a [binary format format](https://arrow.apache.org/docs/python/ipc.html). -//! It is the recommended way to serialize and deserialize Polars DataFrames as this is most true -//! to the data schema. -//! -//! ## Example -//! -//! ```rust -//! use polars::prelude::*; -//! use std::io::Cursor; -//! -//! -//! let s0 = Series::new("days", &[0, 1, 2, 3, 4]); -//! let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); -//! let mut df = DataFrame::new(vec![s0, s1]).unwrap(); -//! -//! // Create an in memory file handler. -//! // Vec: Read + Write -//! // Cursor: Seek -//! -//! let mut buf: Cursor> = Cursor::new(Vec::new()); -//! -//! // write to the in memory buffer -//! IPCWriter::new(&mut buf).finish(&mut df).expect("ipc writer"); -//! -//! // reset the buffers index after writing to the beginning of the buffer -//! buf.set_position(0); -//! -//! // read the buffer into a DataFrame -//! let df_read = IPCReader::new(buf).finish().unwrap(); -//! assert!(df.frame_equal(&df_read)); -//! ``` -use super::{finish_reader, ArrowReader, ArrowResult, RecordBatch}; -use crate::prelude::*; -use arrow::ipc::{ - reader::FileReader as ArrowIPCFileReader, writer::FileWriter as ArrowIPCFileWriter, -}; -use arrow::record_batch::RecordBatchReader; -use std::io::{Read, Seek, Write}; -use std::sync::Arc; - -/// Read Arrows IPC format into a DataFrame -pub struct IPCReader { - /// File or Stream object - reader: R, - /// Aggregates chunks afterwards to a single chunk. - rechunk: bool, - ignore_parser_error: bool, -} - -impl ArrowReader for ArrowIPCFileReader -where - R: Read + Seek, -{ - fn next(&mut self) -> ArrowResult> { - self.next_batch() - } - - fn schema(&self) -> Arc { - self.schema() - } -} - -impl SerReader for IPCReader -where - R: Read + Seek, -{ - fn new(reader: R) -> Self { - IPCReader { - reader, - rechunk: true, - ignore_parser_error: false, - } - } - fn set_rechunk(mut self, rechunk: bool) -> Self { - self.rechunk = rechunk; - self - } - - fn with_ignore_parser_error(mut self) -> Self { - self.ignore_parser_error = true; - self - } - - fn finish(self) -> Result { - let rechunk = self.rechunk; - let ipc_reader = ArrowIPCFileReader::try_new(self.reader)?; - finish_reader(ipc_reader, rechunk, self.ignore_parser_error) - } -} - -/// Write a DataFrame to Arrow's IPC format -pub struct IPCWriter<'a, W> { - writer: &'a mut W, - batch_size: usize, -} - -impl<'a, W> IPCWriter<'a, W> { - /// Set the size of the write buffer. Batch size is the amount of rows written at once. - pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.batch_size = batch_size; - self - } -} - -impl<'a, W> SerWriter<'a, W> for IPCWriter<'a, W> -where - W: Write, -{ - fn new(writer: &'a mut W) -> Self { - IPCWriter { - writer, - batch_size: 200_000, - } - } - - fn finish(self, df: &mut DataFrame) -> Result<()> { - let mut ipc_writer = ArrowIPCFileWriter::try_new(self.writer, &df.schema())?; - - let iter = df.iter_record_batches(self.batch_size); - - for batch in iter { - ipc_writer.write(&batch)? - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - use std::io::Cursor; - - #[test] - fn write_and_read_ipc() { - // Vec : Write + Read - // Cursor>: Seek - let mut buf: Cursor> = Cursor::new(Vec::new()); - let mut df = create_df(); - - IPCWriter::new(&mut buf) - .finish(&mut df) - .expect("ipc writer"); - - buf.set_position(0); - - let df_read = IPCReader::new(buf).finish().unwrap(); - assert!(df.frame_equal(&df_read)); - } -} diff --git a/polars/src/frame/ser/json.rs b/polars/src/frame/ser/json.rs deleted file mode 100644 index e3e6f129e1ef..000000000000 --- a/polars/src/frame/ser/json.rs +++ /dev/null @@ -1,172 +0,0 @@ -//! # (De)serialize JSON files. -//! -//! ## Read JSON to a DataFrame -//! -//! ## Example -//! -//! ``` -//! use polars::prelude::*; -//! use std::io::Cursor; -//! -//! let basic_json = r#"{"a":1, "b":2.0, "c":false, "d":"4"} -//! {"a":-10, "b":-3.5, "c":true, "d":"4"} -//! {"a":2, "b":0.6, "c":false, "d":"text"} -//! {"a":1, "b":2.0, "c":false, "d":"4"} -//! {"a":7, "b":-3.5, "c":true, "d":"4"} -//! {"a":1, "b":0.6, "c":false, "d":"text"} -//! {"a":1, "b":2.0, "c":false, "d":"4"} -//! {"a":5, "b":-3.5, "c":true, "d":"4"} -//! {"a":1, "b":0.6, "c":false, "d":"text"} -//! {"a":1, "b":2.0, "c":false, "d":"4"} -//! {"a":1, "b":-3.5, "c":true, "d":"4"} -//! {"a":100000000000000, "b":0.6, "c":false, "d":"text"}"#; -//! let file = Cursor::new(basic_json); -//! let df = JsonReader::new(file) -//! .infer_schema(Some(3)) -//! .with_batch_size(3) -//! .finish() -//! .unwrap(); -//! -//! println!("{:?}", df); -//! ``` -//! >>> Outputs: -//! -//! ```text -//! +-----+--------+-------+--------+ -//! | a | b | c | d | -//! | --- | --- | --- | --- | -//! | i64 | f64 | bool | str | -//! +=====+========+=======+========+ -//! | 1 | 2 | false | "4" | -//! +-----+--------+-------+--------+ -//! | -10 | -3.5e0 | true | "4" | -//! +-----+--------+-------+--------+ -//! | 2 | 0.6 | false | "text" | -//! +-----+--------+-------+--------+ -//! | 1 | 2 | false | "4" | -//! +-----+--------+-------+--------+ -//! | 7 | -3.5e0 | true | "4" | -//! +-----+--------+-------+--------+ -//! | 1 | 0.6 | false | "text" | -//! +-----+--------+-------+--------+ -//! | 1 | 2 | false | "4" | -//! +-----+--------+-------+--------+ -//! | 5 | -3.5e0 | true | "4" | -//! +-----+--------+-------+--------+ -//! | 1 | 0.6 | false | "text" | -//! +-----+--------+-------+--------+ -//! | 1 | 2 | false | "4" | -//! +-----+--------+-------+--------+ -//! ``` -//! -use crate::frame::ser::finish_reader; -use crate::prelude::*; -pub use arrow::json::ReaderBuilder; -use std::io::{Read, Seek}; -use std::sync::Arc; - -pub struct JsonReader -where - R: Read + Seek, -{ - reader: R, - reader_builder: ReaderBuilder, - rechunk: bool, - ignore_parser_error: bool, -} - -impl SerReader for JsonReader -where - R: Read + Seek, -{ - fn new(reader: R) -> Self { - JsonReader { - reader, - reader_builder: ReaderBuilder::new(), - rechunk: true, - ignore_parser_error: false, - } - } - - fn with_ignore_parser_error(mut self) -> Self { - self.ignore_parser_error = true; - self - } - - fn set_rechunk(mut self, rechunk: bool) -> Self { - self.rechunk = rechunk; - self - } - - fn finish(self) -> Result { - let rechunk = self.rechunk; - finish_reader( - self.reader_builder.build(self.reader)?, - rechunk, - self.ignore_parser_error, - ) - } -} - -impl JsonReader -where - R: Read + Seek, -{ - /// Set the JSON file's schema - pub fn with_schema(mut self, schema: Arc) -> Self { - self.reader_builder = self.reader_builder.with_schema(schema); - self - } - - /// Set the JSON reader to infer the schema of the file - pub fn infer_schema(mut self, max_records: Option) -> Self { - self.reader_builder = self.reader_builder.infer_schema(max_records); - self - } - - /// Set the batch size (number of records to load at one time) - /// This heavily influences loading time. - pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.reader_builder = self.reader_builder.with_batch_size(batch_size); - self - } - - /// Set the reader's column projection - pub fn with_projection(mut self, projection: Vec) -> Self { - self.reader_builder = self.reader_builder.with_projection(projection); - self - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - use std::io::Cursor; - - #[test] - fn read_json() { - let basic_json = r#"{"a":1, "b":2.0, "c":false, "d":"4"} -{"a":-10, "b":-3.5, "c":true, "d":"4"} -{"a":2, "b":0.6, "c":false, "d":"text"} -{"a":1, "b":2.0, "c":false, "d":"4"} -{"a":7, "b":-3.5, "c":true, "d":"4"} -{"a":1, "b":0.6, "c":false, "d":"text"} -{"a":1, "b":2.0, "c":false, "d":"4"} -{"a":5, "b":-3.5, "c":true, "d":"4"} -{"a":1, "b":0.6, "c":false, "d":"text"} -{"a":1, "b":2.0, "c":false, "d":"4"} -{"a":1, "b":-3.5, "c":true, "d":"4"} -{"a":100000000000000, "b":0.6, "c":false, "d":"text"}"#; - let file = Cursor::new(basic_json); - let df = JsonReader::new(file) - .infer_schema(Some(3)) - .with_batch_size(3) - .finish() - .unwrap(); - - println!("{:?}", df); - assert_eq!("a", df.columns[0].name()); - assert_eq!("d", df.columns[3].name()); - assert_eq!((12, 4), df.shape()); - } -} diff --git a/polars/src/frame/ser/mod.rs b/polars/src/frame/ser/mod.rs deleted file mode 100644 index 58d4d332e31e..000000000000 --- a/polars/src/frame/ser/mod.rs +++ /dev/null @@ -1,165 +0,0 @@ -pub mod csv; -pub mod ipc; -pub mod json; -#[cfg(feature = "parquet")] -#[doc(cfg(feature = "parquet"))] -pub mod parquet; -use crate::prelude::*; -use arrow::{ - csv::Reader as ArrowCsvReader, error::ArrowError, error::Result as ArrowResult, - json::Reader as ArrowJsonReader, record_batch::RecordBatch, -}; -use std::io::{Read, Seek, Write}; -use std::sync::Arc; - -pub trait SerReader -where - R: Read + Seek, -{ - fn new(reader: R) -> Self; - - /// Rechunk to a single chunk after Reading file. - fn set_rechunk(self, rechunk: bool) -> Self; - - /// Continue with next batch when a ParserError is encountered. - fn with_ignore_parser_error(self) -> Self; - - /// Take the SerReader and return a parsed DataFrame. - fn finish(self) -> Result; -} - -pub trait SerWriter<'a, W> -where - W: Write, -{ - fn new(writer: &'a mut W) -> Self; - fn finish(self, df: &mut DataFrame) -> Result<()>; -} - -pub trait ArrowReader { - fn next(&mut self) -> ArrowResult>; - - fn schema(&self) -> Arc; -} - -impl ArrowReader for ArrowCsvReader { - fn next(&mut self) -> ArrowResult> { - self.next() - } - - fn schema(&self) -> Arc { - self.schema() - } -} - -impl ArrowReader for ArrowJsonReader { - fn next(&mut self) -> ArrowResult> { - self.next() - } - - fn schema(&self) -> Arc { - self.schema() - } -} - -pub fn finish_reader( - mut reader: R, - rechunk: bool, - ignore_parser_error: bool, -) -> Result { - fn init_ca(field: &Field) -> ChunkedArray - where - T: PolarsDataType, - { - ChunkedArray::new_from_chunks(field.name(), vec![]) - } - - let mut columns = reader - .schema() - .fields() - .iter() - .map(|field| match field.data_type() { - ArrowDataType::UInt8 => Series::UInt8(init_ca(field)), - ArrowDataType::UInt16 => Series::UInt16(init_ca(field)), - ArrowDataType::UInt32 => Series::UInt32(init_ca(field)), - ArrowDataType::UInt64 => Series::UInt64(init_ca(field)), - ArrowDataType::Int8 => Series::Int8(init_ca(field)), - ArrowDataType::Int16 => Series::Int16(init_ca(field)), - ArrowDataType::Int32 => Series::Int32(init_ca(field)), - ArrowDataType::Int64 => Series::Int64(init_ca(field)), - ArrowDataType::Float32 => Series::Float32(init_ca(field)), - ArrowDataType::Float64 => Series::Float64(init_ca(field)), - ArrowDataType::Utf8 => Series::Utf8(init_ca(field)), - ArrowDataType::Boolean => Series::Bool(init_ca(field)), - ArrowDataType::Date32(DateUnit::Millisecond) => Series::Date32(init_ca(field)), - ArrowDataType::Date64(DateUnit::Millisecond) => Series::Date64(init_ca(field)), - ArrowDataType::Duration(TimeUnit::Nanosecond) => { - Series::DurationNanosecond(init_ca(field)) - } - ArrowDataType::Duration(TimeUnit::Microsecond) => { - Series::DurationMicrosecond(init_ca(field)) - } - ArrowDataType::Duration(TimeUnit::Millisecond) => { - Series::DurationMillisecond(init_ca(field)) - } - ArrowDataType::Duration(TimeUnit::Second) => Series::DurationSecond(init_ca(field)), - ArrowDataType::Time64(TimeUnit::Nanosecond) => Series::Time64Nanosecond(init_ca(field)), - ArrowDataType::Time64(TimeUnit::Microsecond) => { - Series::Time64Microsecond(init_ca(field)) - } - ArrowDataType::Time32(TimeUnit::Millisecond) => { - Series::Time32Millisecond(init_ca(field)) - } - ArrowDataType::Time32(TimeUnit::Second) => Series::Time32Second(init_ca(field)), - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => { - Series::TimestampNanosecond(init_ca(field)) - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => { - Series::TimestampMicrosecond(init_ca(field)) - } - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => { - Series::TimestampMillisecond(init_ca(field)) - } - ArrowDataType::Timestamp(TimeUnit::Second, _) => { - Series::TimestampSecond(init_ca(field)) - } - ArrowDataType::LargeList(_) => Series::LargeList(init_ca(field)), - t => panic!(format!("Arrow datatype {:?} is not supported", t)), - }) - .collect::>(); - - loop { - let batch = match reader.next() { - Err(ArrowError::ParseError(s)) => { - if ignore_parser_error { - continue; - } else { - return Err(PolarsError::ArrowError(ArrowError::ParseError(s))); - } - } - Err(e) => return Err(PolarsError::ArrowError(e)), - Ok(None) => break, - Ok(Some(batch)) => batch, - }; - batch - .columns() - .into_iter() - .zip(&mut columns) - .map(|(arr, ser)| ser.append_array(arr.clone())) - .collect::>>()?; - } - if rechunk { - columns = columns - .into_iter() - .map(|s| { - let s = s.rechunk(None)?; - Ok(s) - }) - .collect::>>()?; - } - - Ok(DataFrame { - columns, - parallel: false, - }) -} diff --git a/polars/src/frame/ser/parquet.rs b/polars/src/frame/ser/parquet.rs deleted file mode 100644 index 6e25ae9c57d3..000000000000 --- a/polars/src/frame/ser/parquet.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! # Reading Apache parquet files. -//! -//! ## Example -//! -//! ```rust -//! use polars::prelude::*; -//! use std::fs::File; -//! -//! fn example() -> Result { -//! let r = File::open("some_file.parquet").unwrap(); -//! let reader = ParquetReader::new(r); -//! reader.finish() -//! } -//! ``` -//! -use super::{finish_reader, ArrowReader, ArrowResult, RecordBatch}; -use crate::prelude::*; -use arrow::record_batch::RecordBatchReader; -use parquet::arrow::{ - arrow_reader::ParquetRecordBatchReader, ArrowReader as ParquetArrowReader, - ParquetFileArrowReader, -}; -use parquet::file::reader::SerializedFileReader; -use std::io::{Read, Seek}; -use std::rc::Rc; -use std::sync::Arc; - -/// Read Apache parquet format into a DataFrame. -pub struct ParquetReader { - reader: R, - rechunk: bool, - batch_size: usize, - ignore_parser_error: bool, -} - -impl ArrowReader for ParquetRecordBatchReader { - fn next(&mut self) -> ArrowResult> { - self.next_batch() - } - - fn schema(&self) -> Arc { - ::schema(self) - } -} - -impl ParquetReader { - /// Set the size of the read buffer. Batch size is the amount of rows read at once. - /// This heavily influences loading time. - pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.batch_size = batch_size; - self - } -} - -impl SerReader for ParquetReader -where - R: 'static + Read + Seek + parquet::file::reader::Length + parquet::file::reader::TryClone, -{ - fn new(reader: R) -> Self { - ParquetReader { - reader, - rechunk: true, - // parquets are often large, so use a large batch size - batch_size: 524288, - ignore_parser_error: false, - } - } - - fn set_rechunk(mut self, rechunk: bool) -> Self { - self.rechunk = rechunk; - self - } - - fn with_ignore_parser_error(mut self) -> Self { - self.ignore_parser_error = true; - self - } - - fn finish(self) -> Result { - let rechunk = self.rechunk; - - let file_reader = Rc::new(SerializedFileReader::new(self.reader)?); - let mut arrow_reader = ParquetFileArrowReader::new(file_reader); - let record_reader = arrow_reader.get_record_reader(self.batch_size)?; - finish_reader(record_reader, rechunk, self.ignore_parser_error) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - use std::fs::File; - - #[test] - fn test_parquet() { - let r = File::open("data/simple.parquet").unwrap(); - let reader = ParquetReader::new(r); - let df = reader.finish().unwrap(); - assert_eq!(df.columns(), ["a", "b"]); - assert_eq!(df.shape(), (3, 2)); - } -} diff --git a/polars/src/frame/upstream_traits.rs b/polars/src/frame/upstream_traits.rs deleted file mode 100644 index 4d947e2f1634..000000000000 --- a/polars/src/frame/upstream_traits.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::prelude::*; -use std::iter::FromIterator; -use std::ops::{ - Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, -}; - -impl FromIterator for DataFrame { - /// # Panics - /// - /// Panics if Series have different lengths. - fn from_iter>(iter: T) -> Self { - let v = iter.into_iter().collect(); - DataFrame::new(v).expect("could not create DataFrame from iterator") - } -} - -impl Index for DataFrame { - type Output = Series; - - fn index(&self, index: usize) -> &Self::Output { - &self.columns[index] - } -} - -/// Gives mutable access to a DataFrame. -/// Warning: Use with care, if you modify the Series by replacing it with a different lengthed -/// Series you've invalidated the DataFrame. -impl IndexMut for DataFrame { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.columns[index] - } -} - -macro_rules! impl_ranges { - ($range_type:ty) => { - impl Index<$range_type> for DataFrame { - type Output = [Series]; - - fn index(&self, index: $range_type) -> &Self::Output { - &self.columns[index] - } - } - }; -} - -impl_ranges!(Range); -impl_ranges!(RangeInclusive); -impl_ranges!(RangeFrom); -impl_ranges!(RangeTo); -impl_ranges!(RangeToInclusive); -impl_ranges!(RangeFull); - -// we don't implement Borrow or AsRef as upstream crates may add impl of trait for usize. -impl Index<&str> for DataFrame { - type Output = Series; - - fn index(&self, index: &str) -> &Self::Output { - let idx = self.name_to_idx(index).unwrap(); - &self.columns[idx] - } -} diff --git a/polars/src/lib.rs b/polars/src/lib.rs deleted file mode 100644 index 8566115d0fea..000000000000 --- a/polars/src/lib.rs +++ /dev/null @@ -1,202 +0,0 @@ -//! # Polars: *DataFrames in Rust* -//! -//! Polars is a DataFrame library for Rust. It is based on [Apache Arrows](https://arrow.apache.org/) memory model. -//! This means that operations on Polars array's *(called `Series` or `ChunkedArray` {if the type `T` is known})* are -//! optimally aligned cache friendly operations and SIMD. Sadly, Apache Arrow needs **nightly Rust**, -//! which means that Polars cannot run on stable. -//! -//! Read more in the pages of the [DataFrame](frame/struct.DataFrame.html), [Series](series/enum.Series.html), and -//! [ChunkedArray](chunked_array/struct.ChunkedArray.html) data structures. -//! -//! ## Read and write CSV/ JSON -//! -//! ``` -//! use polars::prelude::*; -//! use std::fs::File; -//! -//! fn example() -> Result { -//! let file = File::open("iris.csv") -//! .expect("could not read file"); -//! -//! CsvReader::new(file) -//! .infer_schema(None) -//! .has_header(true) -//! .finish() -//! } -//! ``` -//! -//! For more IO examples see: -//! -//! * [the csv module](frame/ser/csv/index.html) -//! * [the json module](frame/ser/json/index.html) -//! * [the IPC module](frame/ser/ipc/index.html) -//! * [the parquet module](frame/ser/parquet/index.html) -//! -//! ## Joins -//! -//! ``` -//! use polars::prelude::*; -//! -//! fn join() -> Result { -//! // Create first df. -//! let s0 = Series::new("days", &[0, 1, 2, 3, 4]); -//! let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); -//! let temp = DataFrame::new(vec![s0, s1])?; -//! -//! // Create second df. -//! let s0 = Series::new("days", &[1, 2]); -//! let s1 = Series::new("rain", &[0.1, 0.2]); -//! let rain = DataFrame::new(vec![s0, s1])?; -//! -//! // Left join on days column. -//! temp.left_join(&rain, "days", "days") -//! } -//! -//! println!("{}", join().unwrap()) -//! ``` -//! -//! ```text -//! +------+------+------+ -//! | days | temp | rain | -//! | --- | --- | --- | -//! | i32 | f64 | f64 | -//! +======+======+======+ -//! | 0 | 22.1 | null | -//! +------+------+------+ -//! | 1 | 19.9 | 0.1 | -//! +------+------+------+ -//! | 2 | 7 | 0.2 | -//! +------+------+------+ -//! | 3 | 2 | null | -//! +------+------+------+ -//! | 4 | 3 | null | -//! +------+------+------+ -//! ``` -//! -//! ## Groupby's | aggregations | pivots -//! -//! ``` -//! use polars::prelude::*; -//! fn groupby_sum(df: &DataFrame) -> Result { -//! df.groupby("column_name")? -//! .select("agg_column_name") -//! .sum() -//! } -//! ``` -//! -//! ## Arithmetic -//! ``` -//! use polars::prelude::*; -//! let s: Series = [1, 2, 3].iter().collect(); -//! let s_squared = &s * &s; -//! ``` -//! -//! ## Rust iterators -//! -//! ``` -//! use polars::prelude::*; -//! -//! let s: Series = [1, 2, 3].iter().collect(); -//! let s_squared: Series = s.i32() -//! .expect("datatype mismatch") -//! .into_iter() -//! .map(|optional_v| { -//! match optional_v { -//! Some(v) => Some(v * v), -//! None => None, // null value -//! } -//! }).collect(); -//! ``` -//! -//! ## Apply custom closures -//! -//! Besides running custom iterators, custom closures can be applied on the values of [ChunkedArray](chunked_array/struct.ChunkedArray.html) -//! by using the [apply](chunked_array/apply/trait.Apply.html) method. This method accepts -//! a closure that will be applied on all values of `Option` that are non null. Note that this is the -//! **fastest** way to apply a custom closure on `ChunkedArray`'s. -//! ``` -//! # use polars::prelude::*; -//! let s: Series = Series::new("values", [Some(1.0), None, Some(3.0)]); -//! // null values are ignored automatically -//! let squared = s.f64() -//! .unwrap() -//! .apply(|value| value.powf(2.0)) -//! .into_series(); -//! -//! assert_eq!(Vec::from(squared.f64().unwrap()), &[Some(1.0), None, Some(9.0)]) -//! ``` -//! -//! ## Comparisons -//! -//! ``` -//! use polars::prelude::*; -//! use itertools::Itertools; -//! let s = Series::new("dollars", &[1, 2, 3]); -//! let mask = s.eq(1); -//! let valid = [true, false, false].iter(); -//! -//! assert_eq!(Vec::from(mask), &[Some(true), Some(false), Some(false)]); -//! ``` -//! -//! ## Temporal data types -//! -//! ```rust -//! # use polars::prelude::*; -//! let dates = &[ -//! "2020-08-21", -//! "2020-08-21", -//! "2020-08-22", -//! "2020-08-23", -//! "2020-08-22", -//! ]; -//! // date format -//! let fmt = "%Y-%m-%d"; -//! // create date series -//! let s0 = Date32Chunked::parse_from_str_slice("date", dates, fmt) -//! .into_series(); -//! ``` -//! -//! -//! ## And more... -//! -//! * [DataFrame](frame/struct.DataFrame.html) -//! * [Series](series/enum.Series.html) -//! * [ChunkedArray](chunked_array/struct.ChunkedArray.html) -//! - [Operations implemented by Traits](chunked_array/ops/index.html) -//! * [Time/ DateTime utilities](doc/time/index.html) -//! * [Groupby, aggregations and pivots](frame/group_by/struct.GroupBy.html) -//! -//! ## Features -//! -//! Additional cargo features: -//! -//! * `pretty` (default) -//! - pretty printing of DataFrames -//! * `temporal (default)` -//! - Conversions between Chrono and Polars for temporal data -//! * `simd` -//! - SIMD operations -//! * `paquet` -//! - Read Apache Parquet format -//! * `random` -//! - Generate array's with randomly sampled values -//! * `ndarray` -//! - Convert from `DataFrame` to `ndarray` -//! * `parallel` -//! - Parallel variants of operation -#![allow(dead_code)] -#![feature(iterator_fold_self)] -#![feature(doc_cfg)] -#[macro_use] -pub mod series; -#[macro_use] -pub(crate) mod utils; -pub mod chunked_array; -pub mod datatypes; -#[cfg(feature = "docs")] -pub mod doc; -pub mod error; -mod fmt; -pub mod frame; -pub mod prelude; -pub mod testing; diff --git a/polars/src/prelude.rs b/polars/src/prelude.rs deleted file mode 100644 index 8f4dc064764c..000000000000 --- a/polars/src/prelude.rs +++ /dev/null @@ -1,58 +0,0 @@ -//! Everything you need to get started with Polars. -pub use crate::{ - chunked_array::{ - arithmetic::Pow, - builder::{ - AlignedAlloc, AlignedVec, BooleanChunkedBuilder, LargListBuilderTrait, - LargeListPrimitiveChunkedBuilder, LargeListUtf8ChunkedBuilder, NewChunkedArray, - PrimitiveChunkedBuilder, Utf8ChunkedBuilder, - }, - chunkops::ChunkOps, - comparison::NumComp, - iterator::{IntoNoNullIterator, NumericChunkIterDispatch}, - ops::{ - ChunkAgg, ChunkApply, ChunkCast, ChunkCompare, ChunkFillNone, ChunkFilter, ChunkFull, - ChunkReverse, ChunkSet, ChunkShift, ChunkSort, ChunkTake, ChunkUnique, - FillNoneStrategy, TakeRandom, TakeRandomUtf8, - }, - take::{AsTakeIndex, IntoTakeRandom, NumTakeRandomChunked, NumTakeRandomCont}, - ChunkedArray, Downcast, NoNull, - }, - datatypes, - datatypes::*, - error::{PolarsError, Result}, - frame::{ - ser::{ - csv::{CsvReader, CsvWriter}, - ipc::{IPCReader, IPCWriter}, - json::JsonReader, - SerReader, SerWriter, - }, - DataFrame, IntoSeries, - }, - series::{arithmetic::LhsNumOps, NamedFrom, Series}, - testing::*, -}; -pub use arrow::datatypes::{ArrowPrimitiveType, Field, Schema}; - -#[cfg(feature = "temporal")] -pub use crate::chunked_array::temporal::{ - AsNaiveDateTime, AsNaiveTime, FromNaiveDate, FromNaiveDateTime, FromNaiveTime, -}; - -#[cfg(test)] -pub(crate) fn create_df() -> DataFrame { - let s0 = Series::new("days", [0, 1, 2, 3, 4].as_ref()); - let s1 = Series::new("temp", [22.1, 19.9, 7., 2., 3.].as_ref()); - DataFrame::new(vec![s0, s1]).unwrap() -} -#[cfg(feature = "parquet")] -pub use crate::frame::ser::parquet::ParquetReader; - -#[macro_export] -macro_rules! as_result { - ($block:block) => {{ - let res: Result<_> = $block; - res - }}; -} diff --git a/polars/src/series/aggregate.rs b/polars/src/series/aggregate.rs deleted file mode 100644 index f9d7f761450f..000000000000 --- a/polars/src/series/aggregate.rs +++ /dev/null @@ -1,156 +0,0 @@ -use crate::prelude::*; -use num::{NumCast, ToPrimitive, Zero}; -use std::ops::Div; - -// TODO: implement types -macro_rules! apply_agg_fn { - ($self:ident, $agg:ident) => { - match $self { - Series::Bool(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast bool to T")), - Series::UInt8(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast u8 to T")), - Series::UInt16(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast u16 to T")), - Series::UInt32(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast u32 to T")), - Series::UInt64(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast u64 to T")), - Series::Int8(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast i8 to T")), - Series::Int16(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast i16 to T")), - Series::Int32(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast i32 to T")), - Series::Int64(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast i64 to T")), - Series::Float32(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast f32 to T")), - Series::Float64(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast f64 to T")), - Series::Date32(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast Date64 to T")), - Series::Date64(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast Date64 to T")), - Series::Time32Millisecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast Time32Millisecond to T")), - Series::Time32Second(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast Time32Second to T")), - Series::Time64Nanosecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast Time64Nanosecond to T")), - Series::Time64Microsecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast Time64Microsecond to T")), - Series::DurationNanosecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast DurationNanosecond to T")), - Series::DurationMicrosecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast DurationMicrosecond to T")), - Series::DurationMillisecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast DurationMillisecond to T")), - Series::DurationSecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast DurationSecond to T")), - Series::TimestampNanosecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast TimestampNanosecond to T")), - Series::TimestampMicrosecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast TimestampMicrosecond to T")), - Series::TimestampMillisecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast TimestampMillisecond to T")), - Series::TimestampSecond(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast TimestampSecond to T")), - Series::IntervalDayTime(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast IntervalDayTime to T")), - Series::IntervalYearMonth(a) => a - .$agg() - .map(|v| T::from(v).expect("could not cast IntervalYearMonth to T")), - Series::Utf8(_a) => unimplemented!(), - Series::LargeList(_a) => unimplemented!(), - } - }; -} - -impl Series { - /// Returns `None` if the array is empty or only contains null values. - /// ``` - /// # use polars::prelude::*; - /// let s = Series::new("days", [1, 2, 3].as_ref()); - /// assert_eq!(s.sum(), Some(6)); - /// ``` - pub fn sum(&self) -> Option - where - T: NumCast + Zero + ToPrimitive, - { - apply_agg_fn!(self, sum) - } - - /// Returns the minimum value in the array, according to the natural order. - /// Returns an option because the array is nullable. - /// ``` - /// # use polars::prelude::*; - /// let s = Series::new("days", [1, 2, 3].as_ref()); - /// assert_eq!(s.min(), Some(1)); - /// ``` - pub fn min(&self) -> Option - where - T: NumCast + Zero + ToPrimitive, - { - apply_agg_fn!(self, min) - } - - /// Returns the maximum value in the array, according to the natural order. - /// Returns an option because the array is nullable. - /// ``` - /// # use polars::prelude::*; - /// let s = Series::new("days", [1, 2, 3].as_ref()); - /// assert_eq!(s.max(), Some(3)); - /// ``` - pub fn max(&self) -> Option - where - T: NumCast + Zero + ToPrimitive, - { - apply_agg_fn!(self, max) - } - - pub fn mean(&self) -> Option - where - T: NumCast + Zero + ToPrimitive + Div, - { - apply_agg_fn!(self, sum).map(|v| v / T::from(self.len()).unwrap()) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_agg_bool() { - let s = Series::new("", vec![true, false, true].as_slice()); - assert_eq!(s.max::(), Some(1)); - assert_eq!(s.min::(), Some(0)); - } -} diff --git a/polars/src/series/arithmetic.rs b/polars/src/series/arithmetic.rs deleted file mode 100644 index b466142f4df3..000000000000 --- a/polars/src/series/arithmetic.rs +++ /dev/null @@ -1,480 +0,0 @@ -use crate::prelude::*; -use enum_dispatch::enum_dispatch; -use num::{Num, NumCast, ToPrimitive}; -use std::ops; - -#[enum_dispatch(Series)] -pub(super) trait NumOpsDispatch { - fn subtract(&self, _rhs: &Series) -> Result { - Err(PolarsError::InvalidOperation) - } - fn add_to(&self, _rhs: &Series) -> Result { - Err(PolarsError::InvalidOperation) - } - fn multiply(&self, _rhs: &Series) -> Result { - Err(PolarsError::InvalidOperation) - } - fn divide(&self, _rhs: &Series) -> Result { - Err(PolarsError::InvalidOperation) - } -} - -impl NumOpsDispatch for ChunkedArray -where - T: PolarsNumericType, - T::Native: ops::Add - + ops::Sub - + ops::Mul - + ops::Div - + num::Zero - + num::One, -{ - fn subtract(&self, rhs: &Series) -> Result { - let rhs = self.unpack_series_matching_type(rhs)?; - let out = self - rhs; - Ok(out.into_series()) - } - fn add_to(&self, rhs: &Series) -> Result { - let rhs = self.unpack_series_matching_type(rhs)?; - let out = self + rhs; - Ok(out.into_series()) - } - fn multiply(&self, rhs: &Series) -> Result { - let rhs = self.unpack_series_matching_type(rhs)?; - let out = self * rhs; - Ok(out.into_series()) - } - fn divide(&self, rhs: &Series) -> Result { - let rhs = self.unpack_series_matching_type(rhs)?; - let out = self / rhs; - Ok(out.into_series()) - } -} - -impl NumOpsDispatch for Utf8Chunked {} -impl NumOpsDispatch for BooleanChunked {} -impl NumOpsDispatch for LargeListChunked {} - -impl ops::Sub for Series { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - (&self).subtract(&rhs).expect("data types don't match") - } -} - -impl ops::Add for Series { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add_to(&rhs).expect("data types don't match") - } -} - -impl std::ops::Mul for Series { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - (&self).multiply(&rhs).expect("data types don't match") - } -} - -impl std::ops::Div for Series { - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - (&self).divide(&rhs).expect("data types don't match") - } -} - -// Same only now for referenced data types - -impl ops::Sub for &Series { - type Output = Series; - - fn sub(self, rhs: Self) -> Self::Output { - (&self).subtract(rhs).expect("data types don't match") - } -} - -impl ops::Add for &Series { - type Output = Series; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add_to(rhs).expect("data types don't match") - } -} - -impl std::ops::Mul for &Series { - type Output = Series; - - /// ``` - /// # use polars::prelude::*; - /// let s: Series = [1, 2, 3].iter().collect(); - /// let out = &s * &s; - /// ``` - fn mul(self, rhs: Self) -> Self::Output { - (&self).multiply(rhs).expect("data types don't match") - } -} - -impl std::ops::Div for &Series { - type Output = Series; - - /// ``` - /// # use polars::prelude::*; - /// let s: Series = [1, 2, 3].iter().collect(); - /// let out = &s / &s; - /// ``` - fn div(self, rhs: Self) -> Self::Output { - (&self).divide(rhs).expect("data types don't match") - } -} - -// Series +-/* numbers instead of Series - -#[enum_dispatch(Series)] -pub(super) trait NumOpsDispatchSeriesSingleNumber { - fn subtract_number(&self, _rhs: N) -> Series { - unimplemented!() - } - fn add_number(&self, _rhs: N) -> Series { - unimplemented!() - } - fn multiply_number(&self, _rhs: N) -> Series { - unimplemented!() - } - fn divide_number(&self, _rhs: N) -> Series { - unimplemented!() - } -} - -impl NumOpsDispatchSeriesSingleNumber for BooleanChunked {} -impl NumOpsDispatchSeriesSingleNumber for Utf8Chunked {} -impl NumOpsDispatchSeriesSingleNumber for LargeListChunked {} - -impl NumOpsDispatchSeriesSingleNumber for ChunkedArray -where - T: PolarsNumericType, - T::Native: Num - + NumCast - + ops::Add - + ops::Sub - + ops::Mul - + ops::Div, -{ - fn subtract_number(&self, rhs: N) -> Series { - let rhs: T::Native = NumCast::from(rhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| v - rhs)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } - - fn add_number(&self, rhs: N) -> Series { - let rhs: T::Native = NumCast::from(rhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| v + rhs)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } - fn multiply_number(&self, rhs: N) -> Series { - let rhs: T::Native = NumCast::from(rhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| v * rhs)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } - fn divide_number(&self, rhs: N) -> Series { - let rhs: T::Native = NumCast::from(rhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| v / rhs)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } -} - -impl ops::Sub for &Series -where - T: Num + NumCast, -{ - type Output = Series; - - fn sub(self, rhs: T) -> Self::Output { - self.subtract_number(rhs) - } -} - -impl ops::Sub for Series -where - T: Num + NumCast, -{ - type Output = Self; - - fn sub(self, rhs: T) -> Self::Output { - (&self).sub(rhs) - } -} - -impl ops::Add for &Series -where - T: Num + NumCast, -{ - type Output = Series; - - fn add(self, rhs: T) -> Self::Output { - self.add_number(rhs) - } -} - -impl ops::Add for Series -where - T: Num + NumCast, -{ - type Output = Self; - - fn add(self, rhs: T) -> Self::Output { - (&self).add(rhs) - } -} - -impl ops::Div for &Series -where - T: Num + NumCast, -{ - type Output = Series; - - fn div(self, rhs: T) -> Self::Output { - self.divide_number(rhs) - } -} - -impl ops::Div for Series -where - T: Num + NumCast, -{ - type Output = Self; - - fn div(self, rhs: T) -> Self::Output { - (&self).div(rhs) - } -} - -impl ops::Mul for &Series -where - T: Num + NumCast, -{ - type Output = Series; - - fn mul(self, rhs: T) -> Self::Output { - self.multiply_number(rhs) - } -} - -impl ops::Mul for Series -where - T: Num + NumCast, -{ - type Output = Self; - - fn mul(self, rhs: T) -> Self::Output { - (&self).mul(rhs) - } -} - -/// We cannot override the left hand side behaviour. So we create a trait Lhs num ops. -/// This allows for 1.add(&Series) - -#[enum_dispatch(Series)] -pub(super) trait LhsNumOpsDispatch { - fn lhs_subtract_number(&self, _lhs: N) -> Series { - unimplemented!() - } - fn lhs_add_number(&self, _lhs: N) -> Series { - unimplemented!() - } - fn lhs_multiply_number(&self, _lhs: N) -> Series { - unimplemented!() - } - fn lhs_divide_number(&self, _lhs: N) -> Series { - unimplemented!() - } -} - -impl LhsNumOpsDispatch for BooleanChunked {} -impl LhsNumOpsDispatch for Utf8Chunked {} -impl LhsNumOpsDispatch for LargeListChunked {} - -impl LhsNumOpsDispatch for ChunkedArray -where - T: PolarsNumericType, - T::Native: Num - + NumCast - + ops::Add - + ops::Sub - + ops::Mul - + ops::Div, -{ - fn lhs_subtract_number(&self, lhs: N) -> Series { - let lhs: T::Native = NumCast::from(lhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| lhs - v)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } - - fn lhs_add_number(&self, lhs: N) -> Series { - let lhs: T::Native = NumCast::from(lhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| lhs + v)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } - fn lhs_multiply_number(&self, lhs: N) -> Series { - let lhs: T::Native = NumCast::from(lhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| lhs * v)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } - fn lhs_divide_number(&self, lhs: N) -> Series { - let lhs: T::Native = NumCast::from(lhs).expect(&format!("could not cast")); - let mut ca: ChunkedArray = self - .into_iter() - .map(|opt_v| opt_v.map(|v| lhs / v)) - .collect(); - ca.rename(self.name()); - ca.into_series() - } -} - -pub trait LhsNumOps { - type Output; - - fn add(self, rhs: &Series) -> Self::Output; - fn sub(self, rhs: &Series) -> Self::Output; - fn div(self, rhs: &Series) -> Self::Output; - fn mul(self, rhs: &Series) -> Self::Output; -} - -impl LhsNumOps for T -where - T: Num + NumCast, -{ - type Output = Series; - - fn add(self, rhs: &Series) -> Self::Output { - rhs.lhs_add_number(self) - } - fn sub(self, rhs: &Series) -> Self::Output { - rhs.lhs_subtract_number(self) - } - fn div(self, rhs: &Series) -> Self::Output { - rhs.lhs_divide_number(self) - } - fn mul(self, rhs: &Series) -> Self::Output { - rhs.lhs_multiply_number(self) - } -} - -// TODO: use enum dispatch -impl Series { - fn pow(&self, exp: E) -> Series - where - E: ToPrimitive, - { - match self { - Series::UInt8(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::UInt16(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::UInt32(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::UInt64(ca) => Series::Float64(ca.pow_f64(exp.to_f64().unwrap())), - Series::Int8(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::Int16(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::Int32(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::Int64(ca) => Series::Float64(ca.pow_f64(exp.to_f64().unwrap())), - Series::Float32(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::Float64(ca) => Series::Float64(ca.pow_f64(exp.to_f64().unwrap())), - Series::Date32(ca) => Series::Float32(ca.pow_f32(exp.to_f32().unwrap())), - Series::Date64(ca) => Series::Float64(ca.pow_f64(exp.to_f64().unwrap())), - Series::Time64Nanosecond(ca) => Series::Float64(ca.pow_f64(exp.to_f64().unwrap())), - Series::DurationNanosecond(ca) => Series::Float64(ca.pow_f64(exp.to_f64().unwrap())), - _ => unimplemented!(), - } - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_arithmetic_series() { - // Series +-/* Series - let s: Series = [1, 2, 3].iter().collect(); - assert_eq!( - Vec::from((&s * &s).i32().unwrap()), - [Some(1), Some(4), Some(9)] - ); - assert_eq!( - Vec::from((&s / &s).i32().unwrap()), - [Some(1), Some(1), Some(1)] - ); - assert_eq!( - Vec::from((&s - &s).i32().unwrap()), - [Some(0), Some(0), Some(0)] - ); - assert_eq!( - Vec::from((&s + &s).i32().unwrap()), - [Some(2), Some(4), Some(6)] - ); - // Series +-/* Number - assert_eq!( - Vec::from((&s + 1).i32().unwrap()), - [Some(2), Some(3), Some(4)] - ); - assert_eq!( - Vec::from((&s - 1).i32().unwrap()), - [Some(0), Some(1), Some(2)] - ); - assert_eq!( - Vec::from((&s * 2).i32().unwrap()), - [Some(2), Some(4), Some(6)] - ); - assert_eq!( - Vec::from((&s / 2).i32().unwrap()), - [Some(0), Some(1), Some(1)] - ); - - // Lhs operations - assert_eq!( - Vec::from((1.add(&s)).i32().unwrap()), - [Some(2), Some(3), Some(4)] - ); - assert_eq!( - Vec::from((1.sub(&s)).i32().unwrap()), - [Some(0), Some(-1), Some(-2)] - ); - assert_eq!( - Vec::from((1.div(&s)).i32().unwrap()), - [Some(1), Some(0), Some(0)] - ); - assert_eq!( - Vec::from((1.mul(&s)).i32().unwrap()), - [Some(1), Some(2), Some(3)] - ); - } -} diff --git a/polars/src/series/comparison.rs b/polars/src/series/comparison.rs deleted file mode 100644 index dbbdc019f204..000000000000 --- a/polars/src/series/comparison.rs +++ /dev/null @@ -1,182 +0,0 @@ -//! Comparison operations on Series. - -use super::Series; -use crate::apply_method_numeric_series; -use crate::prelude::*; - -fn fill_bool(val: bool, len: usize) -> BooleanChunked { - std::iter::repeat(val).take(len).collect() -} - -macro_rules! compare { - ($variant:path, $lhs:ident, $rhs:ident, $cmp_method:ident) => {{ - if let $variant(rhs_) = $rhs { - $lhs.$cmp_method(rhs_) - } else { - fill_bool(false, $lhs.len()) - } - }}; -} - -macro_rules! impl_compare { - ($self:ident, $rhs:ident, $method:ident) => {{ - match $self { - Series::Bool(a) => compare!(Series::Bool, a, $rhs, $method), - Series::UInt8(a) => compare!(Series::UInt8, a, $rhs, $method), - Series::UInt16(a) => compare!(Series::UInt16, a, $rhs, $method), - Series::UInt32(a) => compare!(Series::UInt32, a, $rhs, $method), - Series::UInt64(a) => compare!(Series::UInt64, a, $rhs, $method), - Series::Int8(a) => compare!(Series::Int8, a, $rhs, $method), - Series::Int16(a) => compare!(Series::Int16, a, $rhs, $method), - Series::Int32(a) => compare!(Series::Int32, a, $rhs, $method), - Series::Int64(a) => compare!(Series::Int64, a, $rhs, $method), - Series::Float32(a) => compare!(Series::Float32, a, $rhs, $method), - Series::Float64(a) => compare!(Series::Float64, a, $rhs, $method), - Series::Utf8(a) => compare!(Series::Utf8, a, $rhs, $method), - Series::Date32(a) => compare!(Series::Date32, a, $rhs, $method), - Series::Date64(a) => compare!(Series::Date64, a, $rhs, $method), - Series::Time32Millisecond(a) => compare!(Series::Time32Millisecond, a, $rhs, $method), - Series::Time32Second(a) => compare!(Series::Time32Second, a, $rhs, $method), - Series::Time64Nanosecond(a) => compare!(Series::Time64Nanosecond, a, $rhs, $method), - Series::Time64Microsecond(a) => compare!(Series::Time64Microsecond, a, $rhs, $method), - Series::DurationNanosecond(a) => compare!(Series::DurationNanosecond, a, $rhs, $method), - Series::DurationMicrosecond(a) => { - compare!(Series::DurationMicrosecond, a, $rhs, $method) - } - Series::DurationMillisecond(a) => { - compare!(Series::DurationMillisecond, a, $rhs, $method) - } - Series::DurationSecond(a) => compare!(Series::DurationSecond, a, $rhs, $method), - Series::TimestampNanosecond(a) => { - compare!(Series::TimestampNanosecond, a, $rhs, $method) - } - Series::TimestampMicrosecond(a) => { - compare!(Series::TimestampMicrosecond, a, $rhs, $method) - } - Series::TimestampMillisecond(a) => { - compare!(Series::TimestampMillisecond, a, $rhs, $method) - } - Series::TimestampSecond(a) => compare!(Series::TimestampSecond, a, $rhs, $method), - Series::IntervalDayTime(a) => compare!(Series::IntervalDayTime, a, $rhs, $method), - Series::IntervalYearMonth(a) => compare!(Series::IntervalYearMonth, a, $rhs, $method), - Series::LargeList(a) => compare!(Series::LargeList, a, $rhs, $method), - } - }}; -} - -impl ChunkCompare<&Series> for Series { - fn eq_missing(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, eq_missing) - } - - /// Create a boolean mask by checking for equality. - fn eq(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, eq) - } - - /// Create a boolean mask by checking for inequality. - fn neq(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, neq) - } - - /// Create a boolean mask by checking if lhs > rhs. - fn gt(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, gt) - } - - /// Create a boolean mask by checking if lhs >= rhs. - fn gt_eq(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, gt_eq) - } - - /// Create a boolean mask by checking if lhs < rhs. - fn lt(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, lt) - } - - /// Create a boolean mask by checking if lhs <= rhs. - fn lt_eq(&self, rhs: &Series) -> BooleanChunked { - impl_compare!(self, rhs, lt_eq) - } -} - -impl ChunkCompare for Series -where - Rhs: NumComp, -{ - fn eq_missing(&self, rhs: Rhs) -> BooleanChunked { - self.eq(rhs) - } - - fn eq(&self, rhs: Rhs) -> BooleanChunked { - apply_method_numeric_series!(self, eq, rhs) - } - - fn neq(&self, rhs: Rhs) -> BooleanChunked { - apply_method_numeric_series!(self, neq, rhs) - } - - fn gt(&self, rhs: Rhs) -> BooleanChunked { - apply_method_numeric_series!(self, gt, rhs) - } - - fn gt_eq(&self, rhs: Rhs) -> BooleanChunked { - apply_method_numeric_series!(self, gt_eq, rhs) - } - - fn lt(&self, rhs: Rhs) -> BooleanChunked { - apply_method_numeric_series!(self, lt, rhs) - } - - fn lt_eq(&self, rhs: Rhs) -> BooleanChunked { - apply_method_numeric_series!(self, lt_eq, rhs) - } -} - -impl ChunkCompare<&str> for Series { - fn eq_missing(&self, rhs: &str) -> BooleanChunked { - self.eq(rhs) - } - - fn eq(&self, rhs: &str) -> BooleanChunked { - match self { - Series::Utf8(a) => a.eq(rhs), - _ => std::iter::repeat(false).take(self.len()).collect(), - } - } - - fn neq(&self, rhs: &str) -> BooleanChunked { - match self { - Series::Utf8(a) => a.neq(rhs), - _ => std::iter::repeat(false).take(self.len()).collect(), - } - } - - fn gt(&self, rhs: &str) -> BooleanChunked { - match self { - Series::Utf8(a) => a.gt(rhs), - _ => std::iter::repeat(false).take(self.len()).collect(), - } - } - - fn gt_eq(&self, rhs: &str) -> BooleanChunked { - match self { - Series::Utf8(a) => a.gt_eq(rhs), - _ => std::iter::repeat(false).take(self.len()).collect(), - } - } - - fn lt(&self, rhs: &str) -> BooleanChunked { - match self { - Series::Utf8(a) => a.lt(rhs), - _ => std::iter::repeat(false).take(self.len()).collect(), - } - } - - fn lt_eq(&self, rhs: &str) -> BooleanChunked { - match self { - Series::Utf8(a) => a.lt_eq(rhs), - _ => std::iter::repeat(false).take(self.len()).collect(), - } - } -} diff --git a/polars/src/series/iterator.rs b/polars/src/series/iterator.rs deleted file mode 100644 index 4f2b0e7f8013..000000000000 --- a/polars/src/series/iterator.rs +++ /dev/null @@ -1,82 +0,0 @@ -use crate::prelude::*; -use std::iter::FromIterator; - -macro_rules! from_iterator { - ($native:ty, $variant:ident) => { - impl FromIterator> for Series { - fn from_iter>>(iter: I) -> Self { - let ca = iter.into_iter().collect(); - Series::$variant(ca) - } - } - - impl FromIterator<$native> for Series { - fn from_iter>(iter: I) -> Self { - let ca = iter.into_iter().map(|v| Some(v)).collect(); - Series::$variant(ca) - } - } - - impl<'a> FromIterator<&'a $native> for Series { - fn from_iter>(iter: I) -> Self { - let ca = iter.into_iter().map(|v| Some(*v)).collect(); - Series::$variant(ca) - } - } - }; -} - -from_iterator!(u8, UInt8); -from_iterator!(u16, UInt16); -from_iterator!(u32, UInt32); -from_iterator!(u64, UInt64); -from_iterator!(i8, Int8); -from_iterator!(i16, Int16); -from_iterator!(i32, Int32); -from_iterator!(i64, Int64); -from_iterator!(f32, Float32); -from_iterator!(f64, Float64); -from_iterator!(bool, Bool); - -impl<'a> FromIterator<&'a str> for Series { - fn from_iter>(iter: I) -> Self { - let ca = iter.into_iter().collect(); - Series::Utf8(ca) - } -} - -impl<'a> FromIterator<&'a Series> for Series { - fn from_iter>(iter: I) -> Self { - let ca = iter.into_iter().collect(); - Series::LargeList(ca) - } -} - -impl FromIterator for Series { - fn from_iter>(iter: I) -> Self { - let ca = iter.into_iter().collect(); - Series::LargeList(ca) - } -} - -impl FromIterator> for Series { - fn from_iter>>(iter: I) -> Self { - let ca = iter.into_iter().collect(); - Series::LargeList(ca) - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_iter() { - let a = Series::new("age", [23, 71, 9].as_ref()); - let _b = a - .i32() - .unwrap() - .into_iter() - .map(|opt_v| opt_v.map(|v| v * 2)); - } -} diff --git a/polars/src/series/mod.rs b/polars/src/series/mod.rs deleted file mode 100644 index 50cb6c5acaf4..000000000000 --- a/polars/src/series/mod.rs +++ /dev/null @@ -1,1168 +0,0 @@ -//! Type agnostic columnar data structure. -pub use crate::prelude::ChunkCompare; -use crate::prelude::*; -use arrow::{array::ArrayRef, buffer::Buffer}; -use std::mem; - -pub(crate) mod aggregate; -pub(crate) mod arithmetic; -mod comparison; -pub(crate) mod iterator; -use arithmetic::{LhsNumOpsDispatch, NumOpsDispatch, NumOpsDispatchSeriesSingleNumber}; -// needed for enum dispatch -use crate::fmt::FmtLargeList; -use arrow::array::ArrayDataRef; -use enum_dispatch::enum_dispatch; -use num::{Num, NumCast}; - -/// # Series -/// The columnar data type for a DataFrame. The [Series enum](enum.Series.html) consists -/// of typed [ChunkedArray](../chunked_array/struct.ChunkedArray.html)'s. To quickly cast -/// a `Series` to a `ChunkedArray` you can call the method with the name of the type: -/// -/// ``` -/// # use polars::prelude::*; -/// let s: Series = [1, 2, 3].iter().collect(); -/// // Quickly obtain the ChunkedArray wrapped by the Series. -/// let chunked_array = s.i32().unwrap(); -/// ``` -/// -/// ## Arithmetic -/// -/// You can do standard arithmetic on series. -/// ``` -/// # use polars::prelude::*; -/// let s: Series = [1, 2, 3].iter().collect(); -/// let out_add = &s + &s; -/// let out_sub = &s - &s; -/// let out_div = &s / &s; -/// let out_mul = &s * &s; -/// ``` -/// -/// Or with series and numbers. -/// -/// ``` -/// # use polars::prelude::*; -/// let s: Series = (1..3).collect(); -/// let out_add_one = &s + 1; -/// let out_multiply = &s * 10; -/// -/// // Could not overload left hand side operator. -/// let out_divide = 1.div(&s); -/// let out_add = 1.add(&s); -/// let out_subtract = 1.sub(&s); -/// let out_multiply = 1.mul(&s); -/// ``` -/// -/// ## Comparison -/// You can obtain boolean mask by comparing series. -/// -/// ``` -/// # use polars::prelude::*; -/// use itertools::Itertools; -/// let s = Series::new("dollars", &[1, 2, 3]); -/// let mask = s.eq(1); -/// let valid = [true, false, false].iter(); -/// assert!(mask -/// .into_iter() -/// .map(|opt_bool| opt_bool.unwrap()) // option, because series can be null -/// .zip(valid) -/// .all(|(a, b)| a == *b)) -/// ``` -/// -/// See all the comparison operators in the [CmpOps trait](../chunked_array/comparison/trait.CmpOps.html) -/// -/// ## Iterators -/// The Series variants contain differently typed [ChunkedArray's](../chunked_array/struct.ChunkedArray.html). -/// These structs can be turned into iterators, making it possible to use any function/ closure you want -/// on a Series. -/// -/// These iterators return an `Option` because the values of a series may be null. -/// -/// ``` -/// use polars::prelude::*; -/// let pi = 3.14; -/// let s = Series::new("angle", [2f32 * pi, pi, 1.5 * pi].as_ref()); -/// let s_cos: Series = s.f32() -/// .expect("series was not an f32 dtype") -/// .into_iter() -/// .map(|opt_angle| opt_angle.map(|angle| angle.cos())) -/// .collect(); -/// ``` -/// -/// ## Creation -/// Series can be create from different data structures. Below we'll show a few ways we can create -/// a Series object. -/// -/// ``` -/// # use polars::prelude::*; -/// // Series van be created from Vec's, slices and arrays -/// Series::new("boolean series", &vec![true, false, true]); -/// Series::new("int series", &[1, 2, 3]); -/// // And can be nullable -/// Series::new("got nulls", &[Some(1), None, Some(2)]); -/// -/// // Series can also be collected from iterators -/// let from_iter: Series = (0..10) -/// .into_iter() -/// .collect(); -/// -/// ``` -#[enum_dispatch] -#[derive(Clone)] -pub enum Series { - UInt8(ChunkedArray), - UInt16(ChunkedArray), - UInt32(ChunkedArray), - UInt64(ChunkedArray), - Int8(ChunkedArray), - Int16(ChunkedArray), - Int32(ChunkedArray), - Int64(ChunkedArray), - Float32(ChunkedArray), - Float64(ChunkedArray), - Utf8(ChunkedArray), - Bool(ChunkedArray), - Date32(ChunkedArray), - Date64(ChunkedArray), - Time32Millisecond(Time32MillisecondChunked), - Time32Second(Time32SecondChunked), - Time64Nanosecond(ChunkedArray), - Time64Microsecond(ChunkedArray), - DurationNanosecond(ChunkedArray), - DurationMicrosecond(DurationMicrosecondChunked), - DurationMillisecond(DurationMillisecondChunked), - DurationSecond(DurationSecondChunked), - IntervalDayTime(IntervalDayTimeChunked), - IntervalYearMonth(IntervalYearMonthChunked), - TimestampNanosecond(TimestampNanosecondChunked), - TimestampMicrosecond(TimestampMicrosecondChunked), - TimestampMillisecond(TimestampMillisecondChunked), - TimestampSecond(TimestampSecondChunked), - LargeList(LargeListChunked), -} - -#[macro_export] -macro_rules! apply_method_all_series { - ($self:ident, $method:ident, $($args:expr),*) => { - match $self { - Series::Utf8(a) => a.$method($($args),*), - Series::Bool(a) => a.$method($($args),*), - Series::UInt8(a) => a.$method($($args),*), - Series::UInt16(a) => a.$method($($args),*), - Series::UInt32(a) => a.$method($($args),*), - Series::UInt64(a) => a.$method($($args),*), - Series::Int8(a) => a.$method($($args),*), - Series::Int16(a) => a.$method($($args),*), - Series::Int32(a) => a.$method($($args),*), - Series::Int64(a) => a.$method($($args),*), - Series::Float32(a) => a.$method($($args),*), - Series::Float64(a) => a.$method($($args),*), - Series::Date32(a) => a.$method($($args),*), - Series::Date64(a) => a.$method($($args),*), - Series::Time32Millisecond(a) => a.$method($($args),*), - Series::Time32Second(a) => a.$method($($args),*), - Series::Time64Nanosecond(a) => a.$method($($args),*), - Series::Time64Microsecond(a) => a.$method($($args),*), - Series::DurationNanosecond(a) => a.$method($($args),*), - Series::DurationMicrosecond(a) => a.$method($($args),*), - Series::DurationMillisecond(a) => a.$method($($args),*), - Series::DurationSecond(a) => a.$method($($args),*), - Series::TimestampNanosecond(a) => a.$method($($args),*), - Series::TimestampMicrosecond(a) => a.$method($($args),*), - Series::TimestampMillisecond(a) => a.$method($($args),*), - Series::TimestampSecond(a) => a.$method($($args),*), - Series::IntervalDayTime(a) => a.$method($($args),*), - Series::IntervalYearMonth(a) => a.$method($($args),*), - Series::LargeList(a) => a.$method($($args),*), - } - } -} - -// doesn't include Bool and Utf8 -#[macro_export] -macro_rules! apply_method_numeric_series { - ($self:ident, $method:ident, $($args:expr),*) => { - match $self { - Series::UInt8(a) => a.$method($($args),*), - Series::UInt16(a) => a.$method($($args),*), - Series::UInt32(a) => a.$method($($args),*), - Series::UInt64(a) => a.$method($($args),*), - Series::Int8(a) => a.$method($($args),*), - Series::Int16(a) => a.$method($($args),*), - Series::Int32(a) => a.$method($($args),*), - Series::Int64(a) => a.$method($($args),*), - Series::Float32(a) => a.$method($($args),*), - Series::Float64(a) => a.$method($($args),*), - Series::Date32(a) => a.$method($($args),*), - Series::Date64(a) => a.$method($($args),*), - Series::Time32Millisecond(a) => a.$method($($args),*), - Series::Time32Second(a) => a.$method($($args),*), - Series::Time64Nanosecond(a) => a.$method($($args),*), - Series::Time64Microsecond(a) => a.$method($($args),*), - Series::DurationNanosecond(a) => a.$method($($args),*), - Series::DurationMicrosecond(a) => a.$method($($args),*), - Series::DurationMillisecond(a) => a.$method($($args),*), - Series::DurationSecond(a) => a.$method($($args),*), - Series::TimestampNanosecond(a) => a.$method($($args),*), - Series::TimestampMicrosecond(a) => a.$method($($args),*), - Series::TimestampMillisecond(a) => a.$method($($args),*), - Series::TimestampSecond(a) => a.$method($($args),*), - Series::IntervalDayTime(a) => a.$method($($args),*), - Series::IntervalYearMonth(a) => a.$method($($args),*), - _ => unimplemented!(), - } - } -} - -#[macro_export] -macro_rules! apply_method_numeric_series_and_return { - ($self:ident, $method:ident, [$($args:expr),*], $($opt_question_mark:tt)*) => { - match $self { - Series::UInt8(a) => Series::UInt8(a.$method($($args),*)$($opt_question_mark)*), - Series::UInt16(a) => Series::UInt16(a.$method($($args),*)$($opt_question_mark)*), - Series::UInt32(a) => Series::UInt32(a.$method($($args),*)$($opt_question_mark)*), - Series::UInt64(a) => Series::UInt64(a.$method($($args),*)$($opt_question_mark)*), - Series::Int8(a) => Series::Int8(a.$method($($args),*)$($opt_question_mark)*), - Series::Int16(a) => Series::Int16(a.$method($($args),*)$($opt_question_mark)*), - Series::Int32(a) => Series::Int32(a.$method($($args),*)$($opt_question_mark)*), - Series::Int64(a) => Series::Int64(a.$method($($args),*)$($opt_question_mark)*), - Series::Float32(a) => Series::Float32(a.$method($($args),*)$($opt_question_mark)*), - Series::Float64(a) => Series::Float64(a.$method($($args),*)$($opt_question_mark)*), - Series::Date32(a) => Series::Date32(a.$method($($args),*)$($opt_question_mark)*), - Series::Date64(a) => Series::Date64(a.$method($($args),*)$($opt_question_mark)*), - Series::Time32Millisecond(a) => Series::Time32Millisecond(a.$method($($args),*)$($opt_question_mark)*), - Series::Time32Second(a) => Series::Time32Second(a.$method($($args),*)$($opt_question_mark)*), - Series::Time64Nanosecond(a) => Series::Time64Nanosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::Time64Microsecond(a) => Series::Time64Microsecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationNanosecond(a) => Series::DurationNanosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationMicrosecond(a) => Series::DurationMicrosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationMillisecond(a) => Series::DurationMillisecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationSecond(a) => Series::DurationSecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampNanosecond(a) => Series::TimestampNanosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampMicrosecond(a) => Series::TimestampMicrosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampMillisecond(a) => Series::TimestampMillisecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampSecond(a) => Series::TimestampSecond(a.$method($($args),*)$($opt_question_mark)*), - Series::IntervalDayTime(a) => Series::IntervalDayTime(a.$method($($args),*)$($opt_question_mark)*), - Series::IntervalYearMonth(a) => Series::IntervalYearMonth(a.$method($($args),*)$($opt_question_mark)*), - _ => unimplemented!() - } - } -} - -macro_rules! apply_method_all_series_and_return { - ($self:ident, $method:ident, [$($args:expr),*], $($opt_question_mark:tt)*) => { - match $self { - Series::UInt8(a) => Series::UInt8(a.$method($($args),*)$($opt_question_mark)*), - Series::UInt16(a) => Series::UInt16(a.$method($($args),*)$($opt_question_mark)*), - Series::UInt32(a) => Series::UInt32(a.$method($($args),*)$($opt_question_mark)*), - Series::UInt64(a) => Series::UInt64(a.$method($($args),*)$($opt_question_mark)*), - Series::Int8(a) => Series::Int8(a.$method($($args),*)$($opt_question_mark)*), - Series::Int16(a) => Series::Int16(a.$method($($args),*)$($opt_question_mark)*), - Series::Int32(a) => Series::Int32(a.$method($($args),*)$($opt_question_mark)*), - Series::Int64(a) => Series::Int64(a.$method($($args),*)$($opt_question_mark)*), - Series::Float32(a) => Series::Float32(a.$method($($args),*)$($opt_question_mark)*), - Series::Float64(a) => Series::Float64(a.$method($($args),*)$($opt_question_mark)*), - Series::Utf8(a) => Series::Utf8(a.$method($($args),*)$($opt_question_mark)*), - Series::Bool(a) => Series::Bool(a.$method($($args),*)$($opt_question_mark)*), - Series::Date32(a) => Series::Date32(a.$method($($args),*)$($opt_question_mark)*), - Series::Date64(a) => Series::Date64(a.$method($($args),*)$($opt_question_mark)*), - Series::Time32Millisecond(a) => Series::Time32Millisecond(a.$method($($args),*)$($opt_question_mark)*), - Series::Time32Second(a) => Series::Time32Second(a.$method($($args),*)$($opt_question_mark)*), - Series::Time64Nanosecond(a) => Series::Time64Nanosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::Time64Microsecond(a) => Series::Time64Microsecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationNanosecond(a) => Series::DurationNanosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationMicrosecond(a) => Series::DurationMicrosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationMillisecond(a) => Series::DurationMillisecond(a.$method($($args),*)$($opt_question_mark)*), - Series::DurationSecond(a) => Series::DurationSecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampNanosecond(a) => Series::TimestampNanosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampMicrosecond(a) => Series::TimestampMicrosecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampMillisecond(a) => Series::TimestampMillisecond(a.$method($($args),*)$($opt_question_mark)*), - Series::TimestampSecond(a) => Series::TimestampSecond(a.$method($($args),*)$($opt_question_mark)*), - Series::IntervalDayTime(a) => Series::IntervalDayTime(a.$method($($args),*)$($opt_question_mark)*), - Series::IntervalYearMonth(a) => Series::IntervalYearMonth(a.$method($($args),*)$($opt_question_mark)*), - Series::LargeList(a) => Series::LargeList(a.$method($($args),*)$($opt_question_mark)*), - } - } -} - -macro_rules! unpack_series { - ($self:ident, $variant:ident) => { - if let Series::$variant(ca) = $self { - Ok(ca) - } else { - Err(PolarsError::DataTypeMisMatch) - } - }; -} - -impl Series { - /// Get Arrow ArrayData - pub fn array_data(&self) -> Vec { - apply_method_all_series!(self, array_data,) - } - - pub fn from_chunked_array(ca: ChunkedArray) -> Self { - pack_ca_to_series(ca) - } - - /// Get the lengths of the underlying chunks - pub fn chunk_lengths(&self) -> &Vec { - apply_method_all_series!(self, chunk_id,) - } - /// Name of series. - pub fn name(&self) -> &str { - apply_method_all_series!(self, name,) - } - - /// Rename series. - pub fn rename(&mut self, name: &str) -> &mut Self { - apply_method_all_series!(self, rename, name); - self - } - - /// Get field (used in schema) - pub fn field(&self) -> &Field { - apply_method_all_series!(self, ref_field,) - } - - /// Get datatype of series. - pub fn dtype(&self) -> &ArrowDataType { - self.field().data_type() - } - - /// Underlying chunks. - pub fn chunks(&self) -> &Vec { - apply_method_all_series!(self, chunks,) - } - - /// No. of chunks - pub fn n_chunks(&self) -> usize { - self.chunks().len() - } - - pub fn i8(&self) -> Result<&Int8Chunked> { - unpack_series!(self, Int8) - } - - pub fn i16(&self) -> Result<&Int16Chunked> { - unpack_series!(self, Int16) - } - - /// Unpack to ChunkedArray - /// ``` - /// # use polars::prelude::*; - /// let s: Series = [1, 2, 3].iter().collect(); - /// let s_squared: Series = s.i32() - /// .unwrap() - /// .into_iter() - /// .map(|opt_v| { - /// match opt_v { - /// Some(v) => Some(v * v), - /// None => None, // null value - /// } - /// }).collect(); - /// ``` - pub fn i32(&self) -> Result<&Int32Chunked> { - unpack_series!(self, Int32) - } - - /// Unpack to ChunkedArray - pub fn i64(&self) -> Result<&Int64Chunked> { - unpack_series!(self, Int64) - } - - /// Unpack to ChunkedArray - pub fn f32(&self) -> Result<&Float32Chunked> { - unpack_series!(self, Float32) - } - - /// Unpack to ChunkedArray - pub fn f64(&self) -> Result<&Float64Chunked> { - unpack_series!(self, Float64) - } - - /// Unpack to ChunkedArray - pub fn u8(&self) -> Result<&UInt8Chunked> { - unpack_series!(self, UInt8) - } - - /// Unpack to ChunkedArray - pub fn u16(&self) -> Result<&UInt16Chunked> { - unpack_series!(self, UInt16) - } - - /// Unpack to ChunkedArray - pub fn u32(&self) -> Result<&UInt32Chunked> { - unpack_series!(self, UInt32) - } - - /// Unpack to ChunkedArray - pub fn u64(&self) -> Result<&UInt64Chunked> { - unpack_series!(self, UInt64) - } - - /// Unpack to ChunkedArray - pub fn bool(&self) -> Result<&BooleanChunked> { - unpack_series!(self, Bool) - } - - /// Unpack to ChunkedArray - pub fn utf8(&self) -> Result<&Utf8Chunked> { - unpack_series!(self, Utf8) - } - - /// Unpack to ChunkedArray - pub fn date32(&self) -> Result<&Date32Chunked> { - unpack_series!(self, Date32) - } - - /// Unpack to ChunkedArray - pub fn date64(&self) -> Result<&Date64Chunked> { - unpack_series!(self, Date64) - } - - /// Unpack to ChunkedArray - pub fn time32_millisecond(&self) -> Result<&Time32MillisecondChunked> { - unpack_series!(self, Time32Millisecond) - } - - /// Unpack to ChunkedArray - pub fn time32_second(&self) -> Result<&Time32SecondChunked> { - unpack_series!(self, Time32Second) - } - - /// Unpack to ChunkedArray - pub fn time64_nanosecond(&self) -> Result<&Time64NanosecondChunked> { - unpack_series!(self, Time64Nanosecond) - } - - /// Unpack to ChunkedArray - pub fn time64_microsecond(&self) -> Result<&Time64MicrosecondChunked> { - unpack_series!(self, Time64Microsecond) - } - - /// Unpack to ChunkedArray - pub fn duration_nanosecond(&self) -> Result<&DurationNanosecondChunked> { - unpack_series!(self, DurationNanosecond) - } - - /// Unpack to ChunkedArray - pub fn duration_microsecond(&self) -> Result<&DurationMicrosecondChunked> { - unpack_series!(self, DurationMicrosecond) - } - - /// Unpack to ChunkedArray - pub fn duration_millisecond(&self) -> Result<&DurationMillisecondChunked> { - unpack_series!(self, DurationMillisecond) - } - - /// Unpack to ChunkedArray - pub fn duration_second(&self) -> Result<&DurationSecondChunked> { - unpack_series!(self, DurationSecond) - } - - /// Unpack to ChunkedArray - pub fn timestamp_nanosecond(&self) -> Result<&TimestampNanosecondChunked> { - unpack_series!(self, TimestampNanosecond) - } - - /// Unpack to ChunkedArray - pub fn timestamp_microsecond(&self) -> Result<&TimestampMicrosecondChunked> { - unpack_series!(self, TimestampMicrosecond) - } - - /// Unpack to ChunkedArray - pub fn timestamp_millisecond(&self) -> Result<&TimestampMillisecondChunked> { - unpack_series!(self, TimestampMillisecond) - } - - /// Unpack to ChunkedArray - pub fn timestamp_second(&self) -> Result<&TimestampSecondChunked> { - unpack_series!(self, TimestampSecond) - } - - /// Unpack to ChunkedArray - pub fn interval_daytime(&self) -> Result<&IntervalDayTimeChunked> { - unpack_series!(self, IntervalDayTime) - } - - /// Unpack to ChunkedArray - pub fn interval_year_month(&self) -> Result<&IntervalYearMonthChunked> { - unpack_series!(self, IntervalYearMonth) - } - - /// Unpack to ChunkedArray - pub fn large_list(&self) -> Result<&LargeListChunked> { - unpack_series!(self, LargeList) - } - - pub fn append_array(&mut self, other: ArrayRef) -> Result<&mut Self> { - apply_method_all_series!(self, append_array, other)?; - Ok(self) - } - - /// Take `num_elements` from the top as a zero copy view. - pub fn limit(&self, num_elements: usize) -> Result { - Ok(apply_method_all_series_and_return!(self, limit, [num_elements], ?)) - } - - /// Get a zero copy view of the data. - pub fn slice(&self, offset: usize, length: usize) -> Result { - Ok(apply_method_all_series_and_return!(self, slice, [offset, length], ?)) - } - - /// Append a Series of the same type in place. - pub fn append(&mut self, other: &Self) -> Result<&mut Self> { - if self.dtype() == other.dtype() { - apply_method_all_series!(self, append, other.as_ref()); - Ok(self) - } else { - Err(PolarsError::DataTypeMisMatch) - } - } - - /// Filter by boolean mask. This operation clones data. - pub fn filter>(&self, filter: T) -> Result { - Ok(apply_method_all_series_and_return!(self, filter, [filter.as_ref()], ?)) - } - - /// Take by index from an iterator. This operation clones the data. - pub fn take_iter( - &self, - iter: impl Iterator, - capacity: Option, - ) -> Result { - Ok(apply_method_all_series_and_return!(self, take, [iter, capacity], ?)) - } - - /// Take by index from an iterator. This operation clones the data. - pub unsafe fn take_iter_unchecked( - &self, - iter: impl Iterator, - capacity: Option, - ) -> Self { - apply_method_all_series_and_return!(self, take_unchecked, [iter, capacity],) - } - - /// Take by index from an iterator. This operation clones the data. - pub unsafe fn take_opt_iter_unchecked( - &self, - iter: impl Iterator>, - capacity: Option, - ) -> Self { - apply_method_all_series_and_return!(self, take_opt_unchecked, [iter, capacity],) - } - - /// Take by index from an iterator. This operation clones the data. - pub fn take_opt_iter( - &self, - iter: impl Iterator>, - capacity: Option, - ) -> Result { - Ok(apply_method_all_series_and_return!(self, take_opt, [iter, capacity], ?)) - } - - /// Take by index. This operation is clone. - pub fn take(&self, indices: &T) -> Result { - let mut iter = indices.as_take_iter(); - let capacity = indices.take_index_len(); - self.take_iter(&mut iter, Some(capacity)) - } - - /// Get length of series. - pub fn len(&self) -> usize { - apply_method_all_series!(self, len,) - } - - /// Aggregate all chunks to a contiguous array of memory. - pub fn rechunk(&self, chunk_lengths: Option<&[usize]>) -> Result { - Ok(apply_method_all_series_and_return!(self, rechunk, [chunk_lengths], ?)) - } - - /// Get the head of the Series. - pub fn head(&self, length: Option) -> Self { - apply_method_all_series_and_return!(self, head, [length],) - } - - /// Get the tail of the Series. - pub fn tail(&self, length: Option) -> Self { - apply_method_all_series_and_return!(self, tail, [length],) - } - - /// Cast to some primitive type. - pub fn cast(&self) -> Result - where - N: PolarsDataType, - { - let s = match self { - Series::Bool(arr) => pack_ca_to_series(arr.cast::()?), - Series::Utf8(arr) => pack_ca_to_series(arr.cast::()?), - Series::UInt8(arr) => pack_ca_to_series(arr.cast::()?), - Series::UInt16(arr) => pack_ca_to_series(arr.cast::()?), - Series::UInt32(arr) => pack_ca_to_series(arr.cast::()?), - Series::UInt64(arr) => pack_ca_to_series(arr.cast::()?), - Series::Int8(arr) => pack_ca_to_series(arr.cast::()?), - Series::Int16(arr) => pack_ca_to_series(arr.cast::()?), - Series::Int32(arr) => pack_ca_to_series(arr.cast::()?), - Series::Int64(arr) => pack_ca_to_series(arr.cast::()?), - Series::Float32(arr) => pack_ca_to_series(arr.cast::()?), - Series::Float64(arr) => pack_ca_to_series(arr.cast::()?), - Series::Date32(arr) => pack_ca_to_series(arr.cast::()?), - Series::Date64(arr) => pack_ca_to_series(arr.cast::()?), - Series::Time32Millisecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::Time32Second(arr) => pack_ca_to_series(arr.cast::()?), - Series::Time64Nanosecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::Time64Microsecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::DurationNanosecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::DurationMicrosecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::DurationMillisecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::DurationSecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::TimestampNanosecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::TimestampMicrosecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::TimestampMillisecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::TimestampSecond(arr) => pack_ca_to_series(arr.cast::()?), - Series::IntervalDayTime(arr) => pack_ca_to_series(arr.cast::()?), - Series::IntervalYearMonth(arr) => pack_ca_to_series(arr.cast::()?), - Series::LargeList(arr) => pack_ca_to_series(arr.cast::()?), - }; - Ok(s) - } - - /// Get the `ChunkedArray` for some `PolarsDataType` - pub fn unpack(&self) -> Result<&ChunkedArray> - where - N: PolarsDataType, - { - macro_rules! unpack_if_match { - ($ca:ident) => {{ - if *$ca.dtype() == N::get_data_type() { - unsafe { Ok(mem::transmute::<_, &ChunkedArray>($ca)) } - } else { - Err(PolarsError::DataTypeMisMatch) - } - }}; - } - match self { - Series::Bool(arr) => unpack_if_match!(arr), - Series::Utf8(arr) => unpack_if_match!(arr), - Series::UInt8(arr) => unpack_if_match!(arr), - Series::UInt16(arr) => unpack_if_match!(arr), - Series::UInt32(arr) => unpack_if_match!(arr), - Series::UInt64(arr) => unpack_if_match!(arr), - Series::Int8(arr) => unpack_if_match!(arr), - Series::Int16(arr) => unpack_if_match!(arr), - Series::Int32(arr) => unpack_if_match!(arr), - Series::Int64(arr) => unpack_if_match!(arr), - Series::Float32(arr) => unpack_if_match!(arr), - Series::Float64(arr) => unpack_if_match!(arr), - Series::Date32(arr) => unpack_if_match!(arr), - Series::Date64(arr) => unpack_if_match!(arr), - Series::Time32Millisecond(arr) => unpack_if_match!(arr), - Series::Time32Second(arr) => unpack_if_match!(arr), - Series::Time64Nanosecond(arr) => unpack_if_match!(arr), - Series::Time64Microsecond(arr) => unpack_if_match!(arr), - Series::DurationNanosecond(arr) => unpack_if_match!(arr), - Series::DurationMicrosecond(arr) => unpack_if_match!(arr), - Series::DurationMillisecond(arr) => unpack_if_match!(arr), - Series::DurationSecond(arr) => unpack_if_match!(arr), - Series::TimestampNanosecond(arr) => unpack_if_match!(arr), - Series::TimestampMicrosecond(arr) => unpack_if_match!(arr), - Series::TimestampMillisecond(arr) => unpack_if_match!(arr), - Series::TimestampSecond(arr) => unpack_if_match!(arr), - Series::IntervalDayTime(arr) => unpack_if_match!(arr), - Series::IntervalYearMonth(arr) => unpack_if_match!(arr), - Series::LargeList(arr) => unpack_if_match!(arr), - } - } - - /// Get a single value by index. Don't use this operation for loops as a runtime cast is - /// needed for every iteration. - pub fn get(&self, index: usize) -> AnyType { - apply_method_all_series!(self, get_any, index) - } - - /// Sort in place. - pub fn sort_in_place(&mut self, reverse: bool) -> &mut Self { - apply_method_all_series!(self, sort_in_place, reverse); - self - } - - pub fn sort(&self, reverse: bool) -> Self { - apply_method_all_series_and_return!(self, sort, [reverse],) - } - - /// Retrieve the indexes needed for a sort. - pub fn argsort(&self, reverse: bool) -> Vec { - apply_method_all_series!(self, argsort, reverse) - } - - /// Count the null values. - pub fn null_count(&self) -> usize { - apply_method_all_series!(self, null_count,) - } - - /// Get unique values in the Series. - pub fn unique(&self) -> Self { - apply_method_all_series_and_return!(self, unique, [],) - } - - /// Get first indexes of unique values. - pub fn arg_unique(&self) -> Vec { - apply_method_all_series!(self, arg_unique,) - } - - /// Get a mask of the null values. - pub fn is_null(&self) -> BooleanChunked { - apply_method_all_series!(self, is_null,) - } - - /// Get the bits that represent the null values of the underlying ChunkedArray - pub fn null_bits(&self) -> Vec<(usize, Option)> { - apply_method_all_series!(self, null_bits,) - } - - /// return a Series in reversed order - pub fn reverse(&self) -> Self { - apply_method_all_series_and_return!(self, reverse, [],) - } - - /// Rechunk and return a pointer to the start of the Series. - /// Only implemented for numeric types - pub fn as_single_ptr(&mut self) -> usize { - apply_method_numeric_series!(self, as_single_ptr,) - } - - /// Shift the values by a given period and fill the parts that will be empty due to this operation - /// with `Nones`. - /// - /// *NOTE: If you want to fill the Nones with a value use the - /// [`shift` operation on `ChunkedArray`](../chunked_array/ops/trait.ChunkShift.html).* - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example() -> Result<()> { - /// let s = Series::new("series", &[1, 2, 3]); - /// - /// let shifted = s.shift(1)?; - /// assert_eq!(Vec::from(shifted.i32()?), &[None, Some(1), Some(2)]); - /// - /// let shifted = s.shift(-1)?; - /// assert_eq!(Vec::from(shifted.i32()?), &[Some(1), Some(2), None]); - /// - /// let shifted = s.shift(2)?; - /// assert_eq!(Vec::from(shifted.i32()?), &[None, None, Some(1)]); - /// - /// Ok(()) - /// } - /// example(); - /// ``` - pub fn shift(&self, periods: i32) -> Result { - Ok(apply_method_all_series_and_return!(self, shift, [periods, &None],?)) - } - - /// Replace None values with one of the following strategies: - /// * Forward fill (replace None with the previous value) - /// * Backward fill (replace None with the next value) - /// * Mean fill (replace None with the mean of the whole array) - /// * Min fill (replace None with the minimum of the whole array) - /// * Max fill (replace None with the maximum of the whole array) - /// - /// *NOTE: If you want to fill the Nones with a value use the - /// [`fill_none` operation on `ChunkedArray`](../chunked_array/ops/trait.ChunkFillNone.html)*. - /// - /// # Example - /// - /// ```rust - /// # use polars::prelude::*; - /// fn example() -> Result<()> { - /// let s = Series::new("some_missing", &[Some(1), None, Some(2)]); - /// - /// let filled = s.fill_none(FillNoneStrategy::Forward)?; - /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); - /// - /// let filled = s.fill_none(FillNoneStrategy::Backward)?; - /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(2), Some(2)]); - /// - /// let filled = s.fill_none(FillNoneStrategy::Min)?; - /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); - /// - /// let filled = s.fill_none(FillNoneStrategy::Max)?; - /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(2), Some(2)]); - /// - /// let filled = s.fill_none(FillNoneStrategy::Mean)?; - /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); - /// - /// Ok(()) - /// } - /// example(); - /// ``` - pub fn fill_none(&self, strategy: FillNoneStrategy) -> Result { - Ok(apply_method_all_series_and_return!(self, fill_none, [strategy],?)) - } - - pub(crate) fn fmt_largelist(&self) -> String { - apply_method_all_series!(self, fmt_largelist,) - } -} - -fn pack_ca_to_series(ca: ChunkedArray) -> Series { - unsafe { - match N::get_data_type() { - ArrowDataType::Boolean => Series::Bool(mem::transmute(ca)), - ArrowDataType::Utf8 => Series::Utf8(mem::transmute(ca)), - ArrowDataType::UInt8 => Series::UInt8(mem::transmute(ca)), - ArrowDataType::UInt16 => Series::UInt16(mem::transmute(ca)), - ArrowDataType::UInt32 => Series::UInt32(mem::transmute(ca)), - ArrowDataType::UInt64 => Series::UInt64(mem::transmute(ca)), - ArrowDataType::Int8 => Series::Int8(mem::transmute(ca)), - ArrowDataType::Int16 => Series::Int16(mem::transmute(ca)), - ArrowDataType::Int32 => Series::Int32(mem::transmute(ca)), - ArrowDataType::Int64 => Series::Int64(mem::transmute(ca)), - ArrowDataType::Float32 => Series::Float32(mem::transmute(ca)), - ArrowDataType::Float64 => Series::Float64(mem::transmute(ca)), - ArrowDataType::Date32(DateUnit::Day) => Series::Date32(mem::transmute(ca)), - ArrowDataType::Date64(DateUnit::Millisecond) => Series::Date64(mem::transmute(ca)), - ArrowDataType::Time64(datatypes::TimeUnit::Microsecond) => { - Series::Time64Microsecond(mem::transmute(ca)) - } - ArrowDataType::Time64(datatypes::TimeUnit::Nanosecond) => { - Series::Time64Nanosecond(mem::transmute(ca)) - } - ArrowDataType::Time32(datatypes::TimeUnit::Millisecond) => { - Series::Time32Millisecond(mem::transmute(ca)) - } - ArrowDataType::Time32(datatypes::TimeUnit::Second) => { - Series::Time32Second(mem::transmute(ca)) - } - ArrowDataType::Duration(datatypes::TimeUnit::Nanosecond) => { - Series::DurationNanosecond(mem::transmute(ca)) - } - ArrowDataType::Duration(datatypes::TimeUnit::Microsecond) => { - Series::DurationMicrosecond(mem::transmute(ca)) - } - ArrowDataType::Duration(datatypes::TimeUnit::Millisecond) => { - Series::DurationMillisecond(mem::transmute(ca)) - } - ArrowDataType::Duration(datatypes::TimeUnit::Second) => { - Series::DurationSecond(mem::transmute(ca)) - } - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => { - Series::TimestampNanosecond(mem::transmute(ca)) - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => { - Series::TimestampMicrosecond(mem::transmute(ca)) - } - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => { - Series::TimestampMillisecond(mem::transmute(ca)) - } - ArrowDataType::Timestamp(TimeUnit::Second, _) => { - Series::TimestampSecond(mem::transmute(ca)) - } - ArrowDataType::Interval(IntervalUnit::YearMonth) => { - Series::IntervalYearMonth(mem::transmute(ca)) - } - ArrowDataType::Interval(IntervalUnit::DayTime) => { - Series::IntervalDayTime(mem::transmute(ca)) - } - ArrowDataType::LargeList(_) => Series::LargeList(mem::transmute(ca)), - _ => panic!("Not implemented: {:?}", N::get_data_type()), - } - } -} - -pub trait NamedFrom { - /// Initialize by name and values. - fn new(name: &str, _: T) -> Self; -} - -macro_rules! impl_named_from { - ($type:ty, $series_var:ident, $method:ident) => { - impl> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { - Series::$series_var(ChunkedArray::$method(name, v.as_ref())) - } - } - }; -} - -impl<'a, T: AsRef<[&'a str]>> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { - Series::Utf8(ChunkedArray::new_from_slice(name, v.as_ref())) - } -} -impl<'a, T: AsRef<[Option<&'a str>]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { - Series::Utf8(ChunkedArray::new_from_opt_slice(name, v.as_ref())) - } -} - -impl_named_from!([String], Utf8, new_from_slice); -impl_named_from!([bool], Bool, new_from_slice); -impl_named_from!([u8], UInt8, new_from_slice); -impl_named_from!([u16], UInt16, new_from_slice); -impl_named_from!([u32], UInt32, new_from_slice); -impl_named_from!([u64], UInt64, new_from_slice); -impl_named_from!([i8], Int8, new_from_slice); -impl_named_from!([i16], Int16, new_from_slice); -impl_named_from!([i32], Int32, new_from_slice); -impl_named_from!([i64], Int64, new_from_slice); -impl_named_from!([f32], Float32, new_from_slice); -impl_named_from!([f64], Float64, new_from_slice); -impl_named_from!([Option], Utf8, new_from_opt_slice); -impl_named_from!([Option], Bool, new_from_opt_slice); -impl_named_from!([Option], UInt8, new_from_opt_slice); -impl_named_from!([Option], UInt16, new_from_opt_slice); -impl_named_from!([Option], UInt32, new_from_opt_slice); -impl_named_from!([Option], UInt64, new_from_opt_slice); -impl_named_from!([Option], Int8, new_from_opt_slice); -impl_named_from!([Option], Int16, new_from_opt_slice); -impl_named_from!([Option], Int32, new_from_opt_slice); -impl_named_from!([Option], Int64, new_from_opt_slice); -impl_named_from!([Option], Float32, new_from_opt_slice); -impl_named_from!([Option], Float64, new_from_opt_slice); - -macro_rules! impl_as_ref_ca { - ($type:ident, $series_var:ident) => { - impl AsRef> for Series { - fn as_ref(&self) -> &ChunkedArray { - match self { - Series::$series_var(a) => a, - _ => unimplemented!(), - } - } - } - }; -} - -impl_as_ref_ca!(UInt8Type, UInt8); -impl_as_ref_ca!(UInt16Type, UInt16); -impl_as_ref_ca!(UInt32Type, UInt32); -impl_as_ref_ca!(UInt64Type, UInt64); -impl_as_ref_ca!(Int8Type, Int8); -impl_as_ref_ca!(Int16Type, Int16); -impl_as_ref_ca!(Int32Type, Int32); -impl_as_ref_ca!(Int64Type, Int64); -impl_as_ref_ca!(Float32Type, Float32); -impl_as_ref_ca!(Float64Type, Float64); -impl_as_ref_ca!(BooleanType, Bool); -impl_as_ref_ca!(Utf8Type, Utf8); -impl_as_ref_ca!(Date32Type, Date32); -impl_as_ref_ca!(Date64Type, Date64); -impl_as_ref_ca!(Time64NanosecondType, Time64Nanosecond); -impl_as_ref_ca!(Time64MicrosecondType, Time64Microsecond); -impl_as_ref_ca!(Time32MillisecondType, Time32Millisecond); -impl_as_ref_ca!(Time32SecondType, Time32Second); -impl_as_ref_ca!(DurationNanosecondType, DurationNanosecond); -impl_as_ref_ca!(DurationMicrosecondType, DurationMicrosecond); -impl_as_ref_ca!(DurationMillisecondType, DurationMillisecond); -impl_as_ref_ca!(DurationSecondType, DurationSecond); -impl_as_ref_ca!(TimestampNanosecondType, TimestampNanosecond); -impl_as_ref_ca!(TimestampMicrosecondType, TimestampMicrosecond); -impl_as_ref_ca!(TimestampMillisecondType, TimestampMillisecond); -impl_as_ref_ca!(TimestampSecondType, TimestampSecond); -impl_as_ref_ca!(IntervalDayTimeType, IntervalDayTime); -impl_as_ref_ca!(IntervalYearMonthType, IntervalYearMonth); -impl_as_ref_ca!(LargeListType, LargeList); - -macro_rules! impl_as_mut_ca { - ($type:ident, $series_var:ident) => { - impl AsMut> for Series { - fn as_mut(&mut self) -> &mut ChunkedArray { - match self { - Series::$series_var(a) => a, - _ => unimplemented!(), - } - } - } - }; -} - -impl_as_mut_ca!(UInt8Type, UInt8); -impl_as_mut_ca!(UInt16Type, UInt16); -impl_as_mut_ca!(UInt32Type, UInt32); -impl_as_mut_ca!(UInt64Type, UInt64); -impl_as_mut_ca!(Int8Type, Int8); -impl_as_mut_ca!(Int16Type, Int16); -impl_as_mut_ca!(Int32Type, Int32); -impl_as_mut_ca!(Int64Type, Int64); -impl_as_mut_ca!(Float32Type, Float32); -impl_as_mut_ca!(Float64Type, Float64); -impl_as_mut_ca!(BooleanType, Bool); -impl_as_mut_ca!(Utf8Type, Utf8); -impl_as_mut_ca!(Date32Type, Date32); -impl_as_mut_ca!(Date64Type, Date64); -impl_as_mut_ca!(Time64NanosecondType, Time64Nanosecond); -impl_as_mut_ca!(Time64MicrosecondType, Time64Microsecond); -impl_as_mut_ca!(Time32MillisecondType, Time32Millisecond); -impl_as_mut_ca!(Time32SecondType, Time32Second); -impl_as_mut_ca!(DurationNanosecondType, DurationNanosecond); -impl_as_mut_ca!(DurationMicrosecondType, DurationMicrosecond); -impl_as_mut_ca!(DurationMillisecondType, DurationMillisecond); -impl_as_mut_ca!(DurationSecondType, DurationSecond); -impl_as_mut_ca!(TimestampNanosecondType, TimestampNanosecond); -impl_as_mut_ca!(TimestampMicrosecondType, TimestampMicrosecond); -impl_as_mut_ca!(TimestampMillisecondType, TimestampMillisecond); -impl_as_mut_ca!(TimestampSecondType, TimestampSecond); -impl_as_mut_ca!(IntervalDayTimeType, IntervalDayTime); -impl_as_mut_ca!(IntervalYearMonthType, IntervalYearMonth); -impl_as_mut_ca!(LargeListType, LargeList); - -macro_rules! from_series_to_ca { - ($variant:ident, $ca:ident) => { - impl<'a> From<&'a Series> for &'a $ca { - fn from(s: &'a Series) -> Self { - match s { - Series::$variant(ca) => ca, - _ => unimplemented!(), - } - } - } - }; -} -from_series_to_ca!(UInt8, UInt8Chunked); -from_series_to_ca!(UInt16, UInt16Chunked); -from_series_to_ca!(UInt32, UInt32Chunked); -from_series_to_ca!(UInt64, UInt64Chunked); -from_series_to_ca!(Int8, Int8Chunked); -from_series_to_ca!(Int16, Int16Chunked); -from_series_to_ca!(Int32, Int32Chunked); -from_series_to_ca!(Int64, Int64Chunked); -from_series_to_ca!(Float32, Float32Chunked); -from_series_to_ca!(Float64, Float64Chunked); -from_series_to_ca!(Bool, BooleanChunked); -from_series_to_ca!(Utf8, Utf8Chunked); -from_series_to_ca!(Date32, Date32Chunked); -from_series_to_ca!(Date64, Date64Chunked); -from_series_to_ca!(Time32Millisecond, Time32MillisecondChunked); -from_series_to_ca!(Time32Second, Time32SecondChunked); -from_series_to_ca!(Time64Microsecond, Time64MicrosecondChunked); -from_series_to_ca!(Time64Nanosecond, Time64NanosecondChunked); -from_series_to_ca!(DurationMillisecond, DurationMillisecondChunked); -from_series_to_ca!(DurationSecond, DurationSecondChunked); -from_series_to_ca!(DurationMicrosecond, DurationMicrosecondChunked); -from_series_to_ca!(DurationNanosecond, DurationNanosecondChunked); -from_series_to_ca!(TimestampMillisecond, TimestampMillisecondChunked); -from_series_to_ca!(TimestampSecond, TimestampSecondChunked); -from_series_to_ca!(TimestampMicrosecond, TimestampMicrosecondChunked); -from_series_to_ca!(TimestampNanosecond, TimestampNanosecondChunked); -from_series_to_ca!(IntervalDayTime, IntervalDayTimeChunked); -from_series_to_ca!(IntervalYearMonth, IntervalYearMonthChunked); -from_series_to_ca!(LargeList, LargeListChunked); - -// TODO: add types -impl From<(&str, ArrayRef)> for Series { - fn from(name_arr: (&str, ArrayRef)) -> Self { - let (name, arr) = name_arr; - let chunk = vec![arr]; - match chunk[0].data_type() { - ArrowDataType::Utf8 => Utf8Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Boolean => BooleanChunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::UInt8 => UInt8Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::UInt16 => UInt16Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::UInt32 => UInt32Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::UInt64 => UInt64Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Int8 => Int8Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Int16 => Int16Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Int32 => Int32Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Int64 => Int64Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Float32 => Float32Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Float64 => Float64Chunked::new_from_chunks(name, chunk).into_series(), - ArrowDataType::Date32(DateUnit::Day) => { - Date32Chunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Date64(DateUnit::Millisecond) => { - Date64Chunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Time32(TimeUnit::Millisecond) => { - Time32MillisecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Time32(TimeUnit::Second) => { - Time32SecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Time64(TimeUnit::Nanosecond) => { - Time64NanosecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Time64(TimeUnit::Microsecond) => { - Time64MicrosecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Interval(IntervalUnit::DayTime) => { - IntervalDayTimeChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Interval(IntervalUnit::YearMonth) => { - IntervalYearMonthChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Duration(TimeUnit::Nanosecond) => { - DurationNanosecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Duration(TimeUnit::Microsecond) => { - DurationMicrosecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Duration(TimeUnit::Millisecond) => { - DurationMillisecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Duration(TimeUnit::Second) => { - DurationSecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => { - TimestampNanosecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => { - TimestampMicrosecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => { - TimestampMillisecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::Timestamp(TimeUnit::Second, _) => { - TimestampSecondChunked::new_from_chunks(name, chunk).into_series() - } - ArrowDataType::LargeList(_) => { - LargeListChunked::new_from_chunks(name, chunk).into_series() - } - _ => unimplemented!(), - } - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn cast() { - let ar = ChunkedArray::::new_from_slice("a", &[1, 2]); - let s = Series::Int32(ar); - let s2 = s.cast::().unwrap(); - match s2 { - Series::Int64(_) => assert!(true), - _ => assert!(false), - } - let s2 = s.cast::().unwrap(); - match s2 { - Series::Float32(_) => assert!(true), - _ => assert!(false), - } - } - - #[test] - fn new_series() { - Series::new("boolean series", &vec![true, false, true]); - Series::new("int series", &[1, 2, 3]); - let ca = Int32Chunked::new_from_slice("a", &[1, 2, 3]); - Series::from(ca); - } - - #[test] - fn series_append() { - let mut s1 = Series::new("a", &[1, 2]); - let s2 = Series::new("b", &[3]); - s1.append(&s2).unwrap(); - assert_eq!(s1.len(), 3); - - // add wrong type - let s2 = Series::new("b", &[3.0]); - assert!(s1.append(&s2).is_err()) - } -} diff --git a/polars/src/testing.rs b/polars/src/testing.rs deleted file mode 100644 index a0e240cf354c..000000000000 --- a/polars/src/testing.rs +++ /dev/null @@ -1,89 +0,0 @@ -//! Testing utilities. -use crate::prelude::*; - -impl Series { - /// Check if series are equal. Note that `None == None` evaluates to `false` - pub fn series_equal(&self, other: &Series) -> bool { - if self.len() != other.len() { - return false; - } - if self.null_count() != other.null_count() { - return false; - } - match self.eq(other).sum() { - None => false, - Some(sum) => sum as usize == self.len(), - } - } - - /// Check if all values in series are equal where `None == None` evaluates to `true`. - pub fn series_equal_missing(&self, other: &Series) -> bool { - if self.len() != other.len() { - return false; - } - if self.null_count() != other.null_count() { - return false; - } - // if all null and previous check did not return (so other is also all null) - if self.null_count() == self.len() { - return true; - } - match self.eq_missing(other).sum() { - None => false, - Some(sum) => sum as usize == self.len(), - } - } -} - -impl DataFrame { - /// Check if `DataFrames` are equal. Note that `None == None` evaluates to `false` - pub fn frame_equal(&self, other: &DataFrame) -> bool { - if self.shape() != other.shape() { - return false; - } - for (left, right) in self.get_columns().iter().zip(other.get_columns()) { - if !left.series_equal(right) { - return false; - } - } - return true; - } - - /// Check if all values in `DataFrames` are equal where `None == None` evaluates to `true`. - pub fn frame_equal_missing(&self, other: &DataFrame) -> bool { - if self.shape() != other.shape() { - return false; - } - for (left, right) in self.get_columns().iter().zip(other.get_columns()) { - if !left.series_equal_missing(right) { - return false; - } - } - return true; - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_series_equal() { - let a = Series::new("a", &[1, 2, 3]); - let b = Series::new("b", &[1, 2, 3]); - assert!(a.series_equal(&b)); - - let s = Series::new("foo", &[None, Some(1i64)]); - assert!(s.series_equal_missing(&s)); - } - - #[test] - fn test_df_equal() { - let a = Series::new("a", [1, 2, 3].as_ref()); - let b = Series::new("b", [1, 2, 3].as_ref()); - - let df1 = DataFrame::new(vec![a, b]).unwrap(); - let df2 = df1.clone(); - assert!(df1.frame_equal(&df2)) - } -} diff --git a/polars/src/utils.rs b/polars/src/utils.rs deleted file mode 100644 index 24f604abd296..000000000000 --- a/polars/src/utils.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::mem; -use std::ops::{Deref, DerefMut}; - -/// Used to split the mantissa and exponent of floating point numbers -/// https://stackoverflow.com/questions/39638363/how-can-i-use-a-hashmap-with-f64-as-key-in-rust -pub(crate) fn integer_decode(val: f64) -> (u64, i16, i8) { - let bits: u64 = unsafe { mem::transmute(val) }; - let sign: i8 = if bits >> 63 == 0 { 1 } else { -1 }; - let mut exponent: i16 = ((bits >> 52) & 0x7ff) as i16; - let mantissa = if exponent == 0 { - (bits & 0xfffffffffffff) << 1 - } else { - (bits & 0xfffffffffffff) | 0x10000000000000 - }; - - exponent -= 1023 + 52; - (mantissa, exponent, sign) -} - -pub(crate) fn floating_encode_f64(mantissa: u64, exponent: i16, sign: i8) -> f64 { - sign as f64 * mantissa as f64 * (2.0f64).powf(exponent as f64) -} - -/// Just a wrapper structure. Useful for certain impl specializations -/// This is for instance use to implement -/// `impl FromIterator for Xob>` -/// as `Option` was alrady implemented: -/// `impl FromIterator> for ChunkedArray` -pub struct Xob { - inner: T, -} - -impl Xob { - pub fn new(inner: T) -> Self { - Xob { inner } - } - - pub fn into_inner(self) -> T { - self.inner - } -} - -impl Deref for Xob { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for Xob { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -pub fn get_iter_capacity>(iter: &I) -> usize { - match iter.size_hint() { - (_lower, Some(upper)) => upper, - (0, None) => 1024, - (lower, None) => lower, - } -} - -#[macro_export] -macro_rules! match_arrow_data_type_apply_macro { - ($obj:expr, $macro:ident, $macro_utf8:ident $(, $opt_args:expr)*) => {{ - match $obj { - ArrowDataType::Utf8 => $macro_utf8!($($opt_args)*), - ArrowDataType::Boolean => $macro!(BooleanType $(, $opt_args)*), - ArrowDataType::UInt8 => $macro!(UInt8Type $(, $opt_args)*), - ArrowDataType::UInt16 => $macro!(UInt16Type $(, $opt_args)*), - ArrowDataType::UInt32 => $macro!(UInt32Type $(, $opt_args)*), - ArrowDataType::UInt64 => $macro!(UInt64Type $(, $opt_args)*), - ArrowDataType::Int8 => $macro!(Int8Type $(, $opt_args)*), - ArrowDataType::Int16 => $macro!(Int16Type $(, $opt_args)*), - ArrowDataType::Int32 => $macro!(Int32Type $(, $opt_args)*), - ArrowDataType::Int64 => $macro!(Int64Type $(, $opt_args)*), - ArrowDataType::Float32 => $macro!(Float32Type $(, $opt_args)*), - ArrowDataType::Float64 => $macro!(Float64Type $(, $opt_args)*), - ArrowDataType::Date32(DateUnit::Day) => $macro!(Date32Type $(, $opt_args)*), - ArrowDataType::Date64(DateUnit::Millisecond) => $macro!(Date64Type $(, $opt_args)*), - ArrowDataType::Time32(TimeUnit::Millisecond) => $macro!(Time32MillisecondType $(, $opt_args)*), - ArrowDataType::Time32(TimeUnit::Second) => $macro!(Time32SecondType $(, $opt_args)*), - ArrowDataType::Time64(TimeUnit::Nanosecond) => $macro!(Time64NanosecondType $(, $opt_args)*), - ArrowDataType::Time64(TimeUnit::Microsecond) => $macro!(Time64MicrosecondType $(, $opt_args)*), - ArrowDataType::Interval(IntervalUnit::DayTime) => $macro!(IntervalDayTimeType $(, $opt_args)*), - ArrowDataType::Interval(IntervalUnit::YearMonth) => $macro!(IntervalYearMonthType $(, $opt_args)*), - ArrowDataType::Duration(TimeUnit::Nanosecond) => $macro!(DurationNanosecondType $(, $opt_args)*), - ArrowDataType::Duration(TimeUnit::Microsecond) => $macro!(DurationMicrosecondType $(, $opt_args)*), - ArrowDataType::Duration(TimeUnit::Millisecond) => $macro!(DurationMillisecondType $(, $opt_args)*), - ArrowDataType::Duration(TimeUnit::Second) => $macro!(DurationSecondType $(, $opt_args)*), - ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => $macro!(TimestampNanosecondType $(, $opt_args)*), - ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => $macro!(TimestampMicrosecondType $(, $opt_args)*), - ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => $macro!(Time32MillisecondType $(, $opt_args)*), - ArrowDataType::Timestamp(TimeUnit::Second, _) => $macro!(TimestampSecondType $(, $opt_args)*), - _ => unimplemented!(), - } - }}; -} - -/// Clone if upstream hasn't implemented clone -pub(crate) fn clone(t: &T) -> T { - unsafe { mem::transmute_copy(t) } -} diff --git a/py-polars/.gitignore b/py-polars/.gitignore deleted file mode 100644 index 56cb858b8012..000000000000 --- a/py-polars/.gitignore +++ /dev/null @@ -1 +0,0 @@ -wheels/ \ No newline at end of file diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 4e2729d059d8..7d0c459291e5 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,23 +1,113 @@ [package] -name = "py_polars" -version = "0.1.0" -authors = ["ritchie46 "] -edition = "2018" +name = "py-polars" +version = "1.28.1" +edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "polars" +crate-type = ["cdylib"] [dependencies] -polars = {path = "../polars", features = ["parquet", "simd", "parallel"]} -pyo3 = {version = "0.11", features = ["extension-module"] } -thiserror = "1.0.20" -numpy = "0.11" -ndarray = "0.13.1" +libc = { workspace = true } +# Explicit dependency is needed to add bigidx in CI during release +polars = { workspace = true } +polars-python = { workspace = true, features = ["pymethods", "iejoin"] } +pyo3 = { workspace = true, features = ["abi3-py39", "chrono", "extension-module", "multiple-pymethods"] } +[target.'cfg(all(any(not(target_family = "unix"), target_os = "emscripten", allocator = "mimalloc"), not(allocator = "default")))'.dependencies] +mimalloc = { version = "0.1", default-features = false } -[lib] -name = "pypolars" -crate-type = ["cdylib"] +# Feature background_threads is unsupported on MacOS (https://github.com/jemalloc/jemalloc/issues/843). +[target.'cfg(all(target_family = "unix", not(target_os = "macos"), not(target_os = "emscripten"), not(allocator = "mimalloc"), not(allocator = "default")))'.dependencies] +tikv-jemallocator = { version = "0.6.0", features = ["disable_initial_exec_tls", "background_threads"] } + +[target.'cfg(all(target_family = "unix", target_os = "macos", not(allocator = "mimalloc"), not(allocator = "default")))'.dependencies] +tikv-jemallocator = { version = "0.6.0", features = ["disable_initial_exec_tls"] } +either = { workspace = true } + +[features] +# Features used in this crate +ffi_plugin = ["polars-python/ffi_plugin"] +csv = ["polars-python/csv"] +polars_cloud = ["polars-python/polars_cloud"] +object = ["polars-python/object"] +clipboard = ["polars-python/clipboard"] +sql = ["polars-python/sql"] +trigonometry = ["polars-python/trigonometry"] +parquet = ["polars-python/parquet"] +ipc = ["polars-python/ipc"] +catalog = ["polars-python/catalog"] + +# Features passed through to the polars-python crate +avro = ["polars-python/avro"] +ipc_streaming = ["polars-python/ipc_streaming"] +is_in = ["polars-python/is_in"] +json = ["polars-python/json"] +sign = ["polars-python/sign"] +asof_join = ["polars-python/asof_join"] +cross_join = ["polars-python/cross_join"] +pct_change = ["polars-python/pct_change"] +repeat_by = ["polars-python/repeat_by"] +# also includes simd +nightly = ["polars-python/nightly"] +streaming = ["polars-python/streaming"] +meta = ["polars-python/meta"] +search_sorted = ["polars-python/search_sorted"] +decompress = ["polars-python/decompress"] +regex = ["polars-python/regex"] +extract_jsonpath = ["polars-python/extract_jsonpath"] +pivot = ["polars-python/pivot"] +top_k = ["polars-python/top_k"] +propagate_nans = ["polars-python/propagate_nans"] +performant = ["polars-python/performant"] +timezones = ["polars-python/timezones"] +cse = ["polars-python/cse"] +merge_sorted = ["polars-python/merge_sorted"] +list_gather = ["polars-python/list_gather"] +list_count = ["polars-python/list_count"] +array_count = ["polars-python/array_count"] +binary_encoding = ["polars-python/binary_encoding"] +list_sets = ["polars-python/list_sets"] +list_any_all = ["polars-python/list_any_all"] +array_any_all = ["polars-python/array_any_all"] +list_drop_nulls = ["polars-python/list_drop_nulls"] +list_sample = ["polars-python/list_sample"] +cutqcut = ["polars-python/cutqcut"] +rle = ["polars-python/rle"] +extract_groups = ["polars-python/extract_groups"] +cloud = ["polars-python/cloud"] +peaks = ["polars-python/peaks"] +hist = ["polars-python/hist"] +find_many = ["polars-python/find_many"] +new_streaming = ["polars-python/new_streaming"] + +dtype-i8 = ["polars-python/dtype-i8"] +dtype-i16 = ["polars-python/dtype-i16"] +dtype-u8 = ["polars-python/dtype-u8"] +dtype-u16 = ["polars-python/dtype-u16"] +dtype-array = ["polars-python/dtype-array"] + +dtypes = ["polars-python/dtypes"] + +operations = ["polars-python/operations"] + +io = ["polars-python/io"] + +optimizations = ["polars-python/optimizations"] +all = [ + "ffi_plugin", + "csv", + "polars_cloud", + "object", + "clipboard", + "sql", + "trigonometry", + "parquet", + "ipc", + "catalog", + "polars-python/all", + "performant", +] -[package.metadata.maturin] -requires-dist = ["numpy"] \ No newline at end of file +default = ["all", "nightly"] diff --git a/py-polars/Dockerfile b/py-polars/Dockerfile deleted file mode 100644 index 09256c0cd675..000000000000 --- a/py-polars/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM ritchie46/py-polars-base - -RUN pip install pandas jupyterlab matplotlib -RUN mkdir /notebooks - -WORKDIR /notebooks -COPY examples/* /notebooks/ - -CMD ["jupyter", "lab", "--ip=0.0.0.0", "--port=8890", "--no-browser", "--allow-root", "--NotebookApp.token=''"] \ No newline at end of file diff --git a/py-polars/Dockerfile_base b/py-polars/Dockerfile_base deleted file mode 100644 index 70c9af7f00af..000000000000 --- a/py-polars/Dockerfile_base +++ /dev/null @@ -1,15 +0,0 @@ -FROM konstin2/maturin as build - -COPY . . - -RUN cd py-polars \ -&& rustup default nightly \ -&& echo we are in $(pwd) \ -&& mkdir wheels \ -&& maturin build --release -o wheels -i /opt/python/cp37-cp37m/bin/python - -FROM python:3.7-slim-buster - -COPY --from=build /io/py-polars/wheels /wheels - -RUN pip install /wheels/* diff --git a/py-polars/LICENSE b/py-polars/LICENSE new file mode 120000 index 000000000000..ea5b60640b01 --- /dev/null +++ b/py-polars/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/py-polars/Makefile b/py-polars/Makefile new file mode 100644 index 000000000000..a56d547cb387 --- /dev/null +++ b/py-polars/Makefile @@ -0,0 +1,113 @@ +.DEFAULT_GOAL := help + +PYTHONPATH= +SHELL=bash +VENV=../.venv + +ifeq ($(OS),Windows_NT) + VENV_BIN=$(VENV)/Scripts +else + VENV_BIN=$(VENV)/bin +endif + +.PHONY: .venv +.venv: ## Set up virtual environment and install requirements + @$(MAKE) -s -C .. $@ + +.PHONY: requirements +requirements: .venv ## Install/refresh Python project requirements + @$(MAKE) -s -C .. $@ + +.PHONY: requirements-all +requirements-all: .venv ## Install/refresh all Python requirements (including those needed for CI tests) + @$(MAKE) -s -C .. $@ + +.PHONY: build +build: .venv ## Compile and install Python Polars for development + @$(MAKE) -s -C .. $@ + +.PHONY: build-release +build-release: .venv ## Compile and install Python Polars binary with optimizations, with minimal debug symbols + @$(MAKE) -s -C .. $@ + +.PHONY: build-nodebug-release +build-nodebug-release: .venv ## Same as build-release, but without any debug symbols at all (a bit faster to build) + @$(MAKE) -s -C .. $@ + +.PHONY: build-debug-release +build-debug-release: .venv ## Same as build-release, but with full debug symbols turned on (a bit slower to build) + @$(MAKE) -s -C .. $@ + +.PHONY: build-dist-release +build-dist-release: .venv ## Compile and install Python Polars binary with super slow extra optimization turned on, for distribution + @$(MAKE) -s -C .. $@ + +.PHONY: fix +fix: + @$(MAKE) -s -C .. $@ + +.PHONY: lint +lint: .venv ## Run lint checks (only) + $(VENV_BIN)/ruff check + -$(VENV_BIN)/mypy + +.PHONY: fmt +fmt: .venv ## Run autoformatting (and lint) + $(VENV_BIN)/ruff check + $(VENV_BIN)/ruff format + $(VENV_BIN)/typos .. + cargo fmt --all + -dprint fmt + -$(VENV_BIN)/mypy + +.PHONY: clippy +clippy: ## Run clippy + cargo clippy --locked -- -D warnings -D clippy::dbg_macro + +.PHONY: pre-commit +pre-commit: fmt clippy ## Run all code formatting and lint/quality checks + +.PHONY: test +test: .venv build ## Run fast unittests + POLARS_TIMEOUT_MS=60000 $(VENV_BIN)/pytest -n auto $(PYTEST_ARGS) + +.PHONY: test-all +test-all: .venv build ## Run all tests + POLARS_TIMEOUT_MS=60000 $(VENV_BIN)/pytest -n auto -m "slow or not slow" + $(VENV_BIN)/python tests/docs/run_doctest.py + +.PHONY: doctest +doctest: .venv build ## Run doctests + $(VENV_BIN)/python tests/docs/run_doctest.py + $(VENV_BIN)/pytest tests/docs/test_user_guide.py -m docs + +.PHONY: docs +docs: .venv ## Build Python docs (incremental) + @$(MAKE) -s -C docs html + +.PHONY: docs-clean +docs-clean: .venv ## Build Python docs (full rebuild) + @$(MAKE) -s -C docs clean + @$(MAKE) docs + +.PHONY: coverage +coverage: .venv build ## Run tests and report coverage + POLARS_TIMEOUT_MS=60000 $(VENV_BIN)/pytest --cov -n auto -m "not release and not benchmark" + +.PHONY: clean +clean: ## Clean up caches and build artifacts + @$(MAKE) -s -C docs clean + @rm -rf .hypothesis/ + @rm -rf .mypy_cache/ + @rm -rf .pytest_cache/ + @$(VENV_BIN)/ruff clean + @rm -rf tests/data/pdsh/sf* + @rm -f .coverage + @rm -f coverage.xml + @rm -f polars/polars.abi3.so + @find . -type f -name '*.py[co]' -delete -or -type d -name __pycache__ -exec rm -r {} + + +.PHONY: help +help: ## Display this help screen + @echo -e "\033[1mAvailable commands:\033[0m" + @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-22s\033[0m %s\n", $$1, $$2}' | sort diff --git a/py-polars/README.md b/py-polars/README.md new file mode 120000 index 000000000000..32d46ee883b5 --- /dev/null +++ b/py-polars/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/py-polars/build.requirements.txt b/py-polars/build.requirements.txt deleted file mode 100644 index fc2bbc8c9df0..000000000000 --- a/py-polars/build.requirements.txt +++ /dev/null @@ -1 +0,0 @@ -maturin==0.8.1 \ No newline at end of file diff --git a/py-polars/build.rs b/py-polars/build.rs new file mode 100644 index 000000000000..5024e53b6889 --- /dev/null +++ b/py-polars/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo::rustc-check-cfg=cfg(allocator, values(\"default\", \"mimalloc\"))"); + println!( + "cargo:rustc-env=TARGET={}", + std::env::var("TARGET").unwrap() + ); +} diff --git a/py-polars/debug/launch.py b/py-polars/debug/launch.py new file mode 100644 index 000000000000..d824702b2da5 --- /dev/null +++ b/py-polars/debug/launch.py @@ -0,0 +1,82 @@ +import os +import re +import sys +import time +from pathlib import Path + +""" +The following parameter determines the sleep time of the Python process after a signal +is sent that attaches the Rust LLDB debugger. If the Rust LLDB debugger attaches to the +current session too late, it might miss any set breakpoints. If this happens +consistently, it is recommended to increase this value. +""" +LLDB_DEBUG_WAIT_TIME_SECONDS = 1 + + +def launch_debugging() -> None: + """ + Debug Rust files via Python. + + Determine the pID for the current debugging session, attach the Rust LLDB launcher, + and execute the originally-requested script. + """ + if len(sys.argv) == 1: + msg = ( + "launch.py is not meant to be executed directly; please use the `Python: " + "Debug Rust` debugging configuration to run a python script that uses the " + "polars library." + ) + raise RuntimeError(msg) + + # Get the current process ID. + pID = os.getpid() + + # Print to the debug console to allow VSCode to pick up on the signal and start the + # Rust LLDB configuration automatically. + launch_file = Path(__file__).parents[2] / ".vscode/launch.json" + if not launch_file.exists(): + msg = f"Cannot locate {launch_file}" + raise RuntimeError(msg) + with launch_file.open("r") as f: + launch_info = f.read() + + # Overwrite the pid found in launch.json with the pid for the current process. + # Match the initial "Rust LLDB" definition with the pid defined immediately after. + pattern = re.compile('("Rust LLDB",\\s*"pid":\\s*")\\d+(")') + found = pattern.search(launch_info) + if not found: + msg = ( + "Cannot locate pid definition in launch.json for Rust LLDB configuration. " + "Please follow the instructions in the debugging section of the " + "contributing guide (https://docs.pola.rs/development/contributing/ide/#debugging) " + "for creating the launch configuration." + ) + raise RuntimeError(msg) + + launch_info_with_new_pid = pattern.sub(rf"\g<1>{pID}\g<2>", launch_info) + with launch_file.open("w") as f: + f.write(launch_info_with_new_pid) + + # Print pID to the debug console. This auto-triggers the Rust LLDB configurations. + print(f"pID = {pID}") + + # Give the LLDB time to connect. Depending on how long it takes for your LLDB + # debugging session to initialize, you may have to adjust this setting. + time.sleep(LLDB_DEBUG_WAIT_TIME_SECONDS) + + # Update sys.argv so that when exec() is called, the first argument is the script + # name itself, and the remaining are the input arguments. + sys.argv.pop(0) + with Path(sys.argv[0]).open() as fh: + script_contents = fh.read() + + # Run the originally requested file by reading in the script, compiling, and + # executing the code. + file_to_execute = Path(sys.argv[0]) + exec( + compile(script_contents, file_to_execute, mode="exec"), {"__name__": "__main__"} + ) + + +if __name__ == "__main__": + launch_debugging() diff --git a/py-polars/docs/.gitignore b/py-polars/docs/.gitignore new file mode 100644 index 000000000000..6d5180b5702b --- /dev/null +++ b/py-polars/docs/.gitignore @@ -0,0 +1,2 @@ +build/ +source/reference/**/api/ diff --git a/py-polars/docs/Makefile b/py-polars/docs/Makefile new file mode 100644 index 000000000000..30eb5bceb596 --- /dev/null +++ b/py-polars/docs/Makefile @@ -0,0 +1,27 @@ +# Minimal makefile for Sphinx documentation +# + +export BUILDING_SPHINX_DOCS = 1 + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= -j auto -W +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +clean: + @rm -rf source/reference/*/api/ + @rm -rf source/reference/api/ + @rm -rf "$(BUILDDIR)" + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/py-polars/docs/_templates/api_redirect.html b/py-polars/docs/_templates/api_redirect.html new file mode 100644 index 000000000000..c04a8b58ce54 --- /dev/null +++ b/py-polars/docs/_templates/api_redirect.html @@ -0,0 +1,10 @@ +{% set redirect = redirects[pagename.split("/")[-1]] %} + + + + This API page has moved + + + + + diff --git a/py-polars/docs/_templates/autosummary/accessor.rst b/py-polars/docs/_templates/autosummary/accessor.rst new file mode 100644 index 000000000000..4ba745cd6fdb --- /dev/null +++ b/py-polars/docs/_templates/autosummary/accessor.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessor:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/py-polars/docs/_templates/autosummary/accessor_attribute.rst b/py-polars/docs/_templates/autosummary/accessor_attribute.rst new file mode 100644 index 000000000000..b5ad65d6a736 --- /dev/null +++ b/py-polars/docs/_templates/autosummary/accessor_attribute.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorattribute:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/py-polars/docs/_templates/autosummary/accessor_callable.rst b/py-polars/docs/_templates/autosummary/accessor_callable.rst new file mode 100644 index 000000000000..7a3301814f5f --- /dev/null +++ b/py-polars/docs/_templates/autosummary/accessor_callable.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorcallable:: {{ (module.split('.')[1:] + [objname]) | join('.') }}.__call__ diff --git a/py-polars/docs/_templates/autosummary/accessor_method.rst b/py-polars/docs/_templates/autosummary/accessor_method.rst new file mode 100644 index 000000000000..aefbba6ef1bb --- /dev/null +++ b/py-polars/docs/_templates/autosummary/accessor_method.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessormethod:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/py-polars/docs/_templates/autosummary/class.rst b/py-polars/docs/_templates/autosummary/class.rst new file mode 100644 index 000000000000..a9c9bd2b6507 --- /dev/null +++ b/py-polars/docs/_templates/autosummary/class.rst @@ -0,0 +1,33 @@ +{% extends "!autosummary/class.rst" %} + +{% block methods %} +{% if methods %} + +.. + HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. + .. autosummary:: + :toctree: + {% for item in all_methods %} + {%- if not item.startswith('_') or item in ['__call__'] %} + {{ name }}.{{ item }} + {%- endif -%} + {%- endfor %} + +{% endif %} +{% endblock %} + +{% block attributes %} +{% if attributes %} + +.. + HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. + .. autosummary:: + :toctree: + {% for item in all_attributes %} + {%- if not item.startswith('_') %} + {{ name }}.{{ item }} + {%- endif -%} + {%- endfor %} + +{% endif %} +{% endblock %} diff --git a/py-polars/docs/_templates/autosummary/class_without_autosummary.rst b/py-polars/docs/_templates/autosummary/class_without_autosummary.rst new file mode 100644 index 000000000000..6676c672b206 --- /dev/null +++ b/py-polars/docs/_templates/autosummary/class_without_autosummary.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} diff --git a/py-polars/docs/_templates/sidebar-nav-bs.html b/py-polars/docs/_templates/sidebar-nav-bs.html new file mode 100644 index 000000000000..7e0043e771e7 --- /dev/null +++ b/py-polars/docs/_templates/sidebar-nav-bs.html @@ -0,0 +1,9 @@ + diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt new file mode 100644 index 000000000000..f618c3724a71 --- /dev/null +++ b/py-polars/docs/requirements-docs.txt @@ -0,0 +1,19 @@ +hypothesis +numpy +pandas +pyarrow + +sphinx==8.1.3 + +# Third-party Sphinx extensions +autodocsumm==0.2.14 +numpydoc==1.8.0 +pydata-sphinx-theme==0.16.0 +sphinx-autosummary-accessors==2023.4.0 +sphinx-copybutton==0.5.2 +sphinx-design==0.6.1 +sphinx-favicon==1.0.1 +sphinx-reredirects==0.1.5 +sphinx-toolbox==3.8.1 + +livereload==2.7.0 diff --git a/py-polars/docs/run_live_docs_server.py b/py-polars/docs/run_live_docs_server.py new file mode 100644 index 000000000000..140e9e261d5c --- /dev/null +++ b/py-polars/docs/run_live_docs_server.py @@ -0,0 +1,27 @@ +from livereload import Server, shell +from source.conf import html_static_path, templates_path + +# ------------------------------------------------------------------------- +# To use, just execute `python run_live_docs_server.py` in a terminal +# and a local server will run the docs in your browser, automatically +# refreshing/reloading the pages you're working on as they are modified. +# Extremely helpful to see the real output before it gets uploaded, and +# a much smoother experience than constantly running `make html` yourself. +# ------------------------------------------------------------------------- + +if __name__ == "__main__": + # establish a local docs server + svr = Server() + + # command to rebuild the docs + refresh_docs = shell("make html") + + # watch for source file changes and trigger rebuild/refresh + svr.watch("*.rst", refresh_docs, delay=1) + svr.watch("*.md", refresh_docs, delay=1) + svr.watch("source/reference/*", refresh_docs, delay=1) + for path in html_static_path + templates_path: + svr.watch(f"source/{path}/*", refresh_docs, delay=1) + + # path from which to serve the docs + svr.serve(root="build/html") diff --git a/py-polars/docs/source/_static/css/custom.css b/py-polars/docs/source/_static/css/custom.css new file mode 100644 index 000000000000..966f1a86d21e --- /dev/null +++ b/py-polars/docs/source/_static/css/custom.css @@ -0,0 +1,66 @@ +/* To have blue background of width of the block (instead of width of content) */ +dl.class>dt:first-of-type { + display: block !important; +} + +/* Display long method names over multiple lines in navbar. */ +.bd-toc-item { + overflow-wrap: break-word; +} + +/* Dark & light theme tweaks */ +html[data-theme="light"] { + --pst-color-on-background: #ecf4ff; + --pst-gradient-sidebar-left: #ffffff; + --pst-gradient-sidebar-right: #fbfbfb; + --pst-color-border: #cccccc; +} + +html[data-theme="dark"] { + --pst-color-on-background: #333333; + --pst-gradient-sidebar-left: #121212; + --pst-gradient-sidebar-right: #181818; + --pst-color-sidebar-nav: #181818; + --pst-color-border: #444444; +} + +/* add subtle gradients to sidebar and card elements */ +div.bd-sidebar-primary { + background-image: linear-gradient(90deg, var(--pst-gradient-sidebar-left) 0%, var(--pst-gradient-sidebar-right) 100%); +} + +div.sd-card { + background-image: linear-gradient(0deg, var(--pst-gradient-sidebar-left) 0%, var(--pst-gradient-sidebar-right) 100%); +} + +/* match docs footer colour to the header */ +footer.bd-footer { + background-color: var(--pst-color-on-background); +} + +/* + we're not currently doing anything meaningful with the + right toc, so hide until there's something to put there + */ +div.bd-sidebar-secondary { + display: none; +} + +label.sidebar-toggle.secondary-toggle { + display: none !important; +} + +/* fix visited link colour */ +a:visited { + color: var(--pst-color-link); +} + +/* fix ugly navbar scrollbar display */ +.sidebar-primary-items__end { + margin: 0 !important; +} + +/* give code examples the same faint drop-shadow as admonitions */ +pre { + box-shadow: 0 .2rem .5rem var(--pst-color-shadow), 0 0 .0625rem var(--pst-color-shadow) !important; +} diff --git a/py-polars/docs/source/_static/version_switcher.json b/py-polars/docs/source/_static/version_switcher.json new file mode 100644 index 000000000000..53c85e7d5005 --- /dev/null +++ b/py-polars/docs/source/_static/version_switcher.json @@ -0,0 +1,28 @@ +[ + { + "name": "dev", + "version": "dev", + "url": "https://docs.pola.rs/api/python/dev/" + }, + { + "name": "1 (stable)", + "version": "1", + "url": "https://docs.pola.rs/api/python/stable/", + "preferred": true + }, + { + "name": "0.20", + "version": "0.20", + "url": "https://docs.pola.rs/api/python/version/0.20/" + }, + { + "name": "0.19", + "version": "0.19", + "url": "https://docs.pola.rs/api/python/version/0.19/" + }, + { + "name": "0.18", + "version": "0.18", + "url": "https://docs.pola.rs/api/python/version/0.18/" + } +] diff --git a/py-polars/docs/source/_templates/layout.html b/py-polars/docs/source/_templates/layout.html new file mode 100644 index 000000000000..ab46cbf65bd2 --- /dev/null +++ b/py-polars/docs/source/_templates/layout.html @@ -0,0 +1,13 @@ +{% extends "!layout.html" %} + +{% block extrahead %} + +{% endblock %} diff --git a/py-polars/docs/source/conf.py b/py-polars/docs/source/conf.py new file mode 100644 index 000000000000..472438ffeb69 --- /dev/null +++ b/py-polars/docs/source/conf.py @@ -0,0 +1,295 @@ +# Configuration file for the Sphinx documentation builder. +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +from __future__ import annotations + +import inspect +import os +import re +import sys +import warnings +from pathlib import Path +from typing import Any + +import sphinx_autosummary_accessors + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. + +# Add py-polars directory +sys.path.insert(0, str(Path("../..").resolve())) + + +# -- Project information ----------------------------------------------------- + +project = "Polars" +author = "Ritchie Vink" +copyright = f"2025, {author}" + + +# -- General configuration --------------------------------------------------- + +extensions = [ + # Sphinx extensions + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.githubpages", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", + "sphinx.ext.mathjax", + # Third-party extensions + "autodocsumm", + "numpydoc", + "sphinx_autosummary_accessors", + "sphinx_copybutton", + "sphinx_design", + "sphinx_favicon", + "sphinx_reredirects", + "sphinx_toolbox.more_autodoc.overloads", +] + +# Render docstring text in `single backticks` as code. +default_role = "code" + +maximum_signature_line_length = 88 + + +# Below setting is used by +# sphinx-autosummary-accessors - build docs for namespace accessors like `Series.str` +# https://sphinx-autosummary-accessors.readthedocs.io/en/stable/ +templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["Thumbs.db", ".DS_Store"] + +# Hide overload type signatures +# sphinx_toolbox - Box of handy tools for Sphinx +# https://sphinx-toolbox.readthedocs.io/en/latest/ +overloads_location = ["bottom"] + + +# -- Extension settings ------------------------------------------------------ + +# sphinx.ext.intersphinx - link to other projects' documentation +# https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html +intersphinx_mapping = { + "numpy": ("https://numpy.org/doc/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), + "pyarrow": ("https://arrow.apache.org/docs/", None), + "python": ("https://docs.python.org/3", None), +} + +# numpydoc - parse numpy docstrings +# https://numpydoc.readthedocs.io/en/latest/ +# Used in favor of sphinx.ext.napoleon for nicer render of docstring sections +numpydoc_show_class_members = False + +# Sphinx-copybutton - add copy button to code blocks +# https://sphinx-copybutton.readthedocs.io/en/latest/index.html +# strip the '>>>' and '...' prompt/continuation prefixes. +copybutton_prompt_text = r">>> |\.\.\. " +copybutton_prompt_is_regexp = True + +# redirect empty root to the actual landing page +redirects = {"index": "reference/index.html"} + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. +html_theme = "pydata_sphinx_theme" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] +html_css_files = ["css/custom.css"] # relative to html_static_path +html_show_sourcelink = False + +# key site root paths +static_assets_root = "https://raw.githubusercontent.com/pola-rs/polars-static/master" +github_root = "https://github.com/pola-rs/polars" +web_root = "https://docs.pola.rs" + +# Specify version for version switcher dropdown menu +git_ref = os.environ.get("POLARS_VERSION", "main") +version_match = re.fullmatch(r"py-(\d+)\.\d+\.\d+.*", git_ref) +switcher_version = version_match.group(1) if version_match is not None else "dev" + +html_js_files = [ + ( + "https://plausible.io/js/script.js", + {"data-domain": "docs.pola.rs,combined.pola.rs", "defer": "defer"}, + ), +] + +html_theme_options = { + "external_links": [ + { + "name": "User guide", + "url": f"{web_root}/", + }, + { + "name": "Polars Cloud API reference", + "url": "https://docs.cloud.pola.rs/reference/index.html", + }, + ], + "icon_links": [ + { + "name": "GitHub", + "url": github_root, + "icon": "fa-brands fa-github", + }, + { + "name": "Discord", + "url": "https://discord.gg/4UfP5cfBE7", + "icon": "fa-brands fa-discord", + }, + { + "name": "X/Twitter", + "url": "https://x.com/datapolars", + "icon": "fa-brands fa-x-twitter", + }, + { + "name": "Bluesky", + "url": "https://bsky.app/profile/pola.rs", + "icon": "fa-brands fa-bluesky", + }, + ], + "logo": { + "image_light": f"{static_assets_root}/logos/polars-logo-dark-medium.png", + "image_dark": f"{static_assets_root}/logos/polars-logo-dimmed-medium.png", + }, + "switcher": { + "json_url": f"{web_root}/api/python/dev/_static/version_switcher.json", + "version_match": switcher_version, + }, + "show_version_warning_banner": False, + "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"], + "check_switcher": False, +} + +# sphinx-favicon - Add support for custom favicons +# https://github.com/tcmetzger/sphinx-favicon +favicons = [ + { + "rel": "icon", + "sizes": "32x32", + "href": f"{static_assets_root}/icons/favicon-32x32.png", + }, + { + "rel": "apple-touch-icon", + "sizes": "180x180", + "href": f"{static_assets_root}/icons/touchicon-180x180.png", + }, +] + + +# sphinx-ext-linkcode - Add external links to source code +# https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html +def linkcode_resolve(domain: str, info: dict[str, Any]) -> str | None: + """ + Determine the URL corresponding to Python object. + + Based on pandas equivalent: + https://github.com/pandas-dev/pandas/blob/main/doc/source/conf.py#L629-L686 + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + with warnings.catch_warnings(): + # Accessing deprecated objects will generate noisy warnings + warnings.simplefilter("ignore", FutureWarning) + obj = getattr(obj, part) + except AttributeError: # noqa: PERF203 + return None + + try: + fn = inspect.getsourcefile(inspect.unwrap(obj)) + except TypeError: + try: # property + fn = inspect.getsourcefile(inspect.unwrap(obj.fget)) + except (AttributeError, TypeError): + fn = None + if not fn: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except TypeError: + try: # property + source, lineno = inspect.getsourcelines(obj.fget) + except (AttributeError, TypeError): + lineno = None + except OSError: + lineno = None + + linespec = f"#L{lineno}-L{lineno + len(source) - 1}" if lineno else "" + + conf_dir_path = Path(__file__).absolute().parent + polars_root = (conf_dir_path.parent.parent / "polars").absolute() + + fn = os.path.relpath(fn, start=polars_root) + return f"{github_root}/blob/{git_ref}/py-polars/polars/{fn}{linespec}" + + +def _minify_classpaths(s: str) -> str: + # strip private polars classpaths, leaving the classname: + # * "pl.Expr" -> "Expr" + # * "polars.expr.expr.Expr" -> "Expr" + # * "polars.lazyframe.frame.LazyFrame" -> "LazyFrame" + # also: + # * "datetime.date" => "date" + s = s.replace("datetime.", "") + return re.sub( + pattern=r""" + ~? + ( + (?:pl| + (?:polars\. + (?:_reexport|datatypes) + ) + ) + (?:\.[a-z.]+)?\. + ([A-Z][\w.]+) + ) + """, + repl=r"\2", + string=s, + flags=re.VERBOSE, + ) + + +def process_signature( # noqa: D103 + app: object, + what: object, + name: object, + obj: object, + opts: object, + sig: str, + ret: str, +) -> tuple[str, str]: + return ( + _minify_classpaths(sig) if sig else sig, + _minify_classpaths(ret) if ret else ret, + ) + + +def setup(app: Any) -> None: # noqa: D103 + # TODO: a handful of methods do not seem to trigger the event for + # some reason (possibly @overloads?) - investigate further... + app.connect("autodoc-process-signature", process_signature) diff --git a/py-polars/docs/source/index.rst b/py-polars/docs/source/index.rst new file mode 100644 index 000000000000..1c482f7fd0ba --- /dev/null +++ b/py-polars/docs/source/index.rst @@ -0,0 +1,15 @@ +.. raw:: html + +
+ +===== +Index +===== +.. raw:: html + +
+ +.. toctree:: + :maxdepth: 2 + + reference/index diff --git a/py-polars/docs/source/reference/api.rst b/py-polars/docs/source/reference/api.rst new file mode 100644 index 000000000000..20eb4e0fa5cb --- /dev/null +++ b/py-polars/docs/source/reference/api.rst @@ -0,0 +1,171 @@ +================= +Extending the API +================= +.. currentmodule:: polars + +Providing new functionality +--------------------------- + +These functions allow you to register custom functionality in a dedicated +namespace on the underlying Polars classes without requiring subclassing +or mixins. Expr, DataFrame, LazyFrame, and Series are all supported targets. + +This feature is primarily intended for use by library authors providing +domain-specific capabilities which may not exist (or belong) in the +core library. + + +Available registrations +----------------------- + +.. currentmodule:: polars.api +.. autosummary:: + :toctree: api/ + + register_expr_namespace + register_dataframe_namespace + register_lazyframe_namespace + register_series_namespace + +.. note:: + + You cannot override existing Polars namespaces (such as ``.str`` or ``.dt``), and attempting to do so + will raise an `AttributeError `_. + However, you *can* override other custom namespaces (which will only generate a + `UserWarning `_). + + +Examples +-------- + +.. tab-set:: + + .. tab-item:: Expr + + .. code-block:: python + + @pl.api.register_expr_namespace("greetings") + class Greetings: + def __init__(self, expr: pl.Expr) -> None: + self._expr = expr + + def hello(self) -> pl.Expr: + return (pl.lit("Hello ") + self._expr).alias("hi there") + + def goodbye(self) -> pl.Expr: + return (pl.lit("Sayōnara ") + self._expr).alias("bye") + + + pl.DataFrame(data=["world", "world!", "world!!"]).select( + [ + pl.all().greetings.hello(), + pl.all().greetings.goodbye(), + ] + ) + + # shape: (3, 1) shape: (3, 2) + # ┌──────────┐ ┌───────────────┬──────────────────┐ + # │ column_0 │ │ hi there ┆ bye │ + # │ --- │ │ --- ┆ --- │ + # │ str │ │ str ┆ str │ + # ╞══════════╡ >> ╞═══════════════╪══════════════════╡ + # │ world │ │ Hello world ┆ Sayōnara world │ + # │ world! │ │ Hello world! ┆ Sayōnara world! │ + # │ world!! │ │ Hello world!! ┆ Sayōnara world!! │ + # └──────────┘ └───────────────┴──────────────────┘ + + .. tab-item:: DataFrame + + .. code-block:: python + + @pl.api.register_dataframe_namespace("split") + class SplitFrame: + def __init__(self, df: pl.DataFrame) -> None: + self._df = df + + def by_alternate_rows(self) -> list[pl.DataFrame]: + df = self._df.with_row_index(name="n") + return [ + df.filter((pl.col("n") % 2) == 0).drop("n"), + df.filter((pl.col("n") % 2) != 0).drop("n"), + ] + + + pl.DataFrame( + data=["aaa", "bbb", "ccc", "ddd", "eee", "fff"], + schema=[("txt", pl.String)], + ).split.by_alternate_rows() + + # [┌─────┐ ┌─────┐ + # │ txt │ │ txt │ + # │ --- │ │ --- │ + # │ str │ │ str │ + # ╞═════╡ ╞═════╡ + # │ aaa │ │ bbb │ + # │ ccc │ │ ddd │ + # │ eee │ │ fff │ + # └─────┘, └─────┘] + + .. tab-item:: LazyFrame + + .. code-block:: python + + @pl.api.register_lazyframe_namespace("types") + class DTypeOperations: + def __init__(self, ldf: pl.LazyFrame) -> None: + self._ldf = ldf + + def upcast_integer_types(self) -> pl.LazyFrame: + return self._ldf.with_columns( + pl.col(tp).cast(pl.Int64) + for tp in (pl.Int8, pl.Int16, pl.Int32) + ) + + + ldf = pl.DataFrame( + data={"a": [1, 2], "b": [3, 4], "c": [5.6, 6.7]}, + schema=[("a", pl.Int16), ("b", pl.Int32), ("c", pl.Float32)], + ).lazy() + + ldf.types.upcast_integer_types() + + # shape: (2, 3) shape: (2, 3) + # ┌─────┬─────┬─────┐ ┌─────┬─────┬─────┐ + # │ a ┆ b ┆ c │ │ a ┆ b ┆ c │ + # │ --- ┆ --- ┆ --- │ │ --- ┆ --- ┆ --- │ + # │ i16 ┆ i32 ┆ f32 │ >> │ i64 ┆ i64 ┆ f32 │ + # ╞═════╪═════╪═════╡ ╞═════╪═════╪═════╡ + # │ 1 ┆ 3 ┆ 5.6 │ │ 1 ┆ 3 ┆ 5.6 │ + # │ 2 ┆ 4 ┆ 6.7 │ │ 2 ┆ 4 ┆ 6.7 │ + # └─────┴─────┴─────┘ └─────┴─────┴─────┘ + + .. tab-item:: Series + + .. code-block:: python + + @pl.api.register_series_namespace("math") + class MathShortcuts: + def __init__(self, s: pl.Series) -> None: + self._s = s + + def square(self) -> pl.Series: + return self._s * self._s + + def cube(self) -> pl.Series: + return self._s * self._s * self._s + + + s = pl.Series("n", [1, 2, 3, 4, 5]) + + s2 = s.math.square().rename("n2") + s3 = s.math.cube().rename("n3") + + # shape: (5,) shape: (5,) shape: (5,) + # Series: 'n' [i64] Series: 'n2' [i64] Series: 'n3' [i64] + # [ [ [ + # 1 1 1 + # 2 4 8 + # 3 9 27 + # 4 16 64 + # 5 25 125 + # ] ] ] diff --git a/py-polars/docs/source/reference/catalog/index.rst b/py-polars/docs/source/reference/catalog/index.rst new file mode 100644 index 000000000000..52fdd97030f6 --- /dev/null +++ b/py-polars/docs/source/reference/catalog/index.rst @@ -0,0 +1,23 @@ +======= +Catalog +======= + +Interface with data catalogs. + +.. raw:: html + + + +.. grid:: + + .. grid-item-card:: + + **Unity Catalog** + ^^^^^^^^^^^^^^ + + .. toctree:: + :maxdepth: 3 + + unity diff --git a/py-polars/docs/source/reference/catalog/unity.rst b/py-polars/docs/source/reference/catalog/unity.rst new file mode 100644 index 000000000000..ea5720426aa7 --- /dev/null +++ b/py-polars/docs/source/reference/catalog/unity.rst @@ -0,0 +1,22 @@ +Unity Catalog +~~~~~~~~~~~~~ +Interface with Unity catalogs. + +.. currentmodule:: polars + +.. autosummary:: + :toctree: api/ + + Catalog + Catalog.list_catalogs + Catalog.list_namespaces + Catalog.list_tables + Catalog.get_table_info + Catalog.scan_table + catalog.unity.CatalogInfo + catalog.unity.ColumnInfo + catalog.unity.DataSourceFormat + catalog.unity.NamespaceInfo + catalog.unity.TableInfo + catalog.unity.TableInfo.get_polars_schema + catalog.unity.TableType diff --git a/py-polars/docs/source/reference/config.rst b/py-polars/docs/source/reference/config.rst new file mode 100644 index 000000000000..869a10f7d9bc --- /dev/null +++ b/py-polars/docs/source/reference/config.rst @@ -0,0 +1,120 @@ +====== +Config +====== +.. currentmodule:: polars + +Config options +-------------- + +.. autosummary:: + :toctree: api/ + + Config.set_ascii_tables + Config.set_auto_structify + Config.set_decimal_separator + Config.set_engine_affinity + Config.set_float_precision + Config.set_fmt_float + Config.set_fmt_str_lengths + Config.set_fmt_table_cell_list_len + Config.set_streaming_chunk_size + Config.set_tbl_cell_alignment + Config.set_tbl_cell_numeric_alignment + Config.set_tbl_cols + Config.set_tbl_column_data_type_inline + Config.set_tbl_dataframe_shape_below + Config.set_tbl_formatting + Config.set_tbl_hide_column_data_types + Config.set_tbl_hide_column_names + Config.set_tbl_hide_dataframe_shape + Config.set_tbl_hide_dtype_separator + Config.set_tbl_rows + Config.set_tbl_width_chars + Config.set_thousands_separator + Config.set_trim_decimal_zeros + Config.set_verbose + +Config load, save, state +------------------------ +.. autosummary:: + :toctree: api/ + + Config.load + Config.load_from_file + Config.save + Config.save_to_file + Config.state + Config.restore_defaults + +While it is easy to restore *all* configuration options to their default +value using ``restore_defaults``, it can also be useful to reset *individual* +options. This can be done by setting the related value to ``None``, eg: + +.. code-block:: python + + pl.Config.set_tbl_rows(None) + + +Use as a context manager +------------------------ + +Note that ``Config`` supports setting context-scoped options. These options +are valid *only* during scope lifetime, and are reset to their initial values +(whatever they were before entering the new context) on scope exit. + +You can take advantage of this by initialising a ``Config`` instance and then +explicitly calling one or more of the available "set\_" methods on it... + +.. code-block:: python + + with pl.Config() as cfg: + cfg.set_verbose(True) + do_various_things() + + # on scope exit any modified settings are restored to their previous state + +...or, often cleaner, by setting the options in the ``Config`` init directly +(optionally omitting the "set\_" prefix for brevity): + +.. code-block:: python + + with pl.Config(verbose=True): + do_various_things() + +Use as a decorator +------------------ + +In the same vein, you can also use a ``Config`` instance as a function decorator +to temporarily set options for the duration of the function call: + +.. code-block:: python + + cfg_ascii_frames = pl.Config(ascii_tables=True, apply_on_context_enter=True) + + @cfg_ascii_frames + def write_markdown_frame_to_stdout(df: pl.DataFrame) -> None: + sys.stdout.write(str(df)) + +Multiple Config instances +------------------------- +You may want to establish related bundles of `Config` options for use in different +parts of your code. Usually options are set immediately on `Config` init, meaning +the `Config` instance cannot be reused; however, you can defer this so that options +are only invoked when entering context scope (which includes function entry if used +as a decorator)._ + +This allows you to create multiple *reusable* `Config` instances in one place, update +and modify them centrally, and apply them as needed throughout your codebase. + +.. code-block:: python + + cfg_verbose = pl.Config(verbose=True, apply_on_context_enter=True) + cfg_markdown = pl.Config(tbl_formatting="MARKDOWN", apply_on_context_enter=True) + + @cfg_markdown + def write_markdown_frame_to_stdout(df: pl.DataFrame) -> None: + sys.stdout.write(str(df)) + + @cfg_verbose + def do_various_things(): + ... diff --git a/py-polars/docs/source/reference/dataframe/aggregation.rst b/py-polars/docs/source/reference/dataframe/aggregation.rst new file mode 100644 index 000000000000..71a30b800a65 --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/aggregation.rst @@ -0,0 +1,22 @@ +=========== +Aggregation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.count + DataFrame.max + DataFrame.max_horizontal + DataFrame.mean + DataFrame.mean_horizontal + DataFrame.median + DataFrame.min + DataFrame.min_horizontal + DataFrame.product + DataFrame.quantile + DataFrame.std + DataFrame.sum + DataFrame.sum_horizontal + DataFrame.var diff --git a/py-polars/docs/source/reference/dataframe/attributes.rst b/py-polars/docs/source/reference/dataframe/attributes.rst new file mode 100644 index 000000000000..086cc41597eb --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/attributes.rst @@ -0,0 +1,15 @@ +========== +Attributes +========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.columns + DataFrame.dtypes + DataFrame.flags + DataFrame.height + DataFrame.schema + DataFrame.shape + DataFrame.width diff --git a/py-polars/docs/source/reference/dataframe/computation.rst b/py-polars/docs/source/reference/dataframe/computation.rst new file mode 100644 index 000000000000..ff45f7f1bfac --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/computation.rst @@ -0,0 +1,10 @@ +=========== +Computation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.fold + DataFrame.hash_rows diff --git a/py-polars/docs/source/reference/dataframe/descriptive.rst b/py-polars/docs/source/reference/dataframe/descriptive.rst new file mode 100644 index 000000000000..00fba0d0cad1 --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/descriptive.rst @@ -0,0 +1,18 @@ +=========== +Descriptive +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.approx_n_unique + DataFrame.describe + DataFrame.estimated_size + DataFrame.glimpse + DataFrame.is_duplicated + DataFrame.is_empty + DataFrame.is_unique + DataFrame.n_chunks + DataFrame.n_unique + DataFrame.null_count diff --git a/py-polars/docs/source/reference/dataframe/export.rst b/py-polars/docs/source/reference/dataframe/export.rst new file mode 100644 index 000000000000..8ebb005221eb --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/export.rst @@ -0,0 +1,22 @@ +====== +Export +====== + +Export DataFrame data to other formats: + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.__array__ + DataFrame.__arrow_c_stream__ + DataFrame.__dataframe__ + DataFrame.to_arrow + DataFrame.to_dict + DataFrame.to_dicts + DataFrame.to_init_repr + DataFrame.to_jax + DataFrame.to_numpy + DataFrame.to_pandas + DataFrame.to_struct + DataFrame.to_torch diff --git a/py-polars/docs/source/reference/dataframe/group_by.rst b/py-polars/docs/source/reference/dataframe/group_by.rst new file mode 100644 index 000000000000..3b201ee3bf9b --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/group_by.rst @@ -0,0 +1,27 @@ +======= +GroupBy +======= + +This namespace becomes available by calling `DataFrame.group_by(...)`. + +.. currentmodule:: polars.dataframe.group_by +.. autosummary:: + :toctree: api/ + + GroupBy.__iter__ + GroupBy.agg + GroupBy.all + GroupBy.count + GroupBy.first + GroupBy.head + GroupBy.last + GroupBy.len + GroupBy.map_groups + GroupBy.max + GroupBy.mean + GroupBy.median + GroupBy.min + GroupBy.n_unique + GroupBy.quantile + GroupBy.sum + GroupBy.tail diff --git a/py-polars/docs/source/reference/dataframe/index.rst b/py-polars/docs/source/reference/dataframe/index.rst new file mode 100644 index 000000000000..97570b06ca52 --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/index.rst @@ -0,0 +1,30 @@ +========= +DataFrame +========= + +This page gives an overview of all public DataFrame methods. + +.. toctree:: + :maxdepth: 2 + :hidden: + + aggregation + attributes + computation + descriptive + export + group_by + modify_select + miscellaneous + plot + style + +.. _dataframe: + +.. currentmodule:: polars + +.. autoclass:: DataFrame + :members: + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/dataframe/miscellaneous.rst b/py-polars/docs/source/reference/dataframe/miscellaneous.rst new file mode 100644 index 000000000000..127cbfaab5a6 --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/miscellaneous.rst @@ -0,0 +1,22 @@ +============= +Miscellaneous +============= + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.collect_schema + DataFrame.corr + DataFrame.equals + DataFrame.lazy + DataFrame.map_rows + +Serialization +------------- + +.. autosummary:: + :toctree: api/ + + DataFrame.deserialize + DataFrame.serialize diff --git a/py-polars/docs/source/reference/dataframe/modify_select.rst b/py-polars/docs/source/reference/dataframe/modify_select.rst new file mode 100644 index 000000000000..50855df23ba1 --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/modify_select.rst @@ -0,0 +1,79 @@ +====================== +Manipulation/selection +====================== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + DataFrame.__getitem__ + DataFrame.bottom_k + DataFrame.cast + DataFrame.clear + DataFrame.clone + DataFrame.drop + DataFrame.drop_in_place + DataFrame.drop_nans + DataFrame.drop_nulls + DataFrame.explode + DataFrame.extend + DataFrame.fill_nan + DataFrame.fill_null + DataFrame.filter + DataFrame.gather_every + DataFrame.get_column + DataFrame.get_column_index + DataFrame.get_columns + DataFrame.group_by + DataFrame.group_by_dynamic + DataFrame.head + DataFrame.hstack + DataFrame.insert_column + DataFrame.interpolate + DataFrame.item + DataFrame.iter_columns + DataFrame.iter_rows + DataFrame.iter_slices + DataFrame.join + DataFrame.join_asof + DataFrame.join_where + DataFrame.limit + DataFrame.melt + DataFrame.merge_sorted + DataFrame.partition_by + DataFrame.pipe + DataFrame.pivot + DataFrame.rechunk + DataFrame.remove + DataFrame.rename + DataFrame.replace_column + DataFrame.reverse + DataFrame.rolling + DataFrame.row + DataFrame.rows + DataFrame.rows_by_key + DataFrame.sample + DataFrame.select + DataFrame.select_seq + DataFrame.set_sorted + DataFrame.shift + DataFrame.shrink_to_fit + DataFrame.slice + DataFrame.sort + DataFrame.sql + DataFrame.tail + DataFrame.to_dummies + DataFrame.to_series + DataFrame.top_k + DataFrame.transpose + DataFrame.unique + DataFrame.unnest + DataFrame.unpivot + DataFrame.unstack + DataFrame.update + DataFrame.upsample + DataFrame.vstack + DataFrame.with_columns + DataFrame.with_columns_seq + DataFrame.with_row_count + DataFrame.with_row_index diff --git a/py-polars/docs/source/reference/dataframe/plot.rst b/py-polars/docs/source/reference/dataframe/plot.rst new file mode 100644 index 000000000000..85f1cc9395a8 --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/plot.rst @@ -0,0 +1,7 @@ +==== +Plot +==== + +.. currentmodule:: polars + +.. autoproperty:: DataFrame.plot \ No newline at end of file diff --git a/py-polars/docs/source/reference/dataframe/style.rst b/py-polars/docs/source/reference/dataframe/style.rst new file mode 100644 index 000000000000..ebe9a3fabf2e --- /dev/null +++ b/py-polars/docs/source/reference/dataframe/style.rst @@ -0,0 +1,7 @@ +===== +Style +===== + +.. currentmodule:: polars + +.. autoproperty:: DataFrame.style \ No newline at end of file diff --git a/py-polars/docs/source/reference/datatypes.rst b/py-polars/docs/source/reference/datatypes.rst new file mode 100644 index 000000000000..1299130b7613 --- /dev/null +++ b/py-polars/docs/source/reference/datatypes.rst @@ -0,0 +1,75 @@ +========== +Data types +========== +.. currentmodule:: polars.datatypes + +DataType +~~~~~~~~ +.. autosummary:: + :toctree: api/ + :nosignatures: + + DataType + +Numeric +~~~~~~~ +.. autosummary:: + :toctree: api/ + :nosignatures: + + Decimal + Float32 + Float64 + Int8 + Int16 + Int32 + Int64 + Int128 + UInt8 + UInt16 + UInt32 + UInt64 + +Temporal +~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + :nosignatures: + + Date + Datetime + Duration + Time + +Nested +~~~~~~ +.. autosummary:: + :toctree: api/ + + Array + List + Field + Struct + +String +~~~~~~ +.. autosummary:: + :toctree: api/ + :nosignatures: + + String + Categorical + Enum + Utf8 + +Other +~~~~~ +.. autosummary:: + :toctree: api/ + :nosignatures: + + Binary + Boolean + Null + Object + Unknown diff --git a/py-polars/docs/source/reference/exceptions.rst b/py-polars/docs/source/reference/exceptions.rst new file mode 100644 index 000000000000..1921ffae8d0c --- /dev/null +++ b/py-polars/docs/source/reference/exceptions.rst @@ -0,0 +1,58 @@ +========== +Exceptions +========== +.. currentmodule:: polars.exceptions + +Errors +~~~~~~ + +.. autosummary:: + :toctree: api/ + :nosignatures: + + PolarsError + ColumnNotFoundError + ComputeError + DuplicateError + InvalidOperationError + ModuleUpgradeRequiredError + NoDataError + NoRowsReturnedError + OutOfBoundsError + ParameterCollisionError + RowsError + SQLInterfaceError + SQLSyntaxError + SchemaError + SchemaFieldNotFoundError + ShapeError + StringCacheMismatchError + StructFieldNotFoundError + TooManyRowsReturnedError + UnsuitableSQLError + +Warnings +~~~~~~~~ + +.. autosummary:: + :toctree: api/ + :nosignatures: + + PolarsWarning + CategoricalRemappingWarning + ChronoFormatWarning + CustomUFuncWarning + DataOrientationWarning + MapWithoutReturnDtypeWarning + PerformanceWarning + PolarsInefficientMapWarning + UnstableWarning + +Panic +~~~~~ + +.. autosummary:: + :toctree: api/ + :nosignatures: + + PanicException diff --git a/py-polars/docs/source/reference/expressions/aggregation.rst b/py-polars/docs/source/reference/expressions/aggregation.rst new file mode 100644 index 000000000000..123111120a73 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/aggregation.rst @@ -0,0 +1,35 @@ +=========== +Aggregation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.agg_groups + Expr.all + Expr.any + Expr.approx_n_unique + Expr.arg_max + Expr.arg_min + Expr.bitwise_and + Expr.bitwise_or + Expr.bitwise_xor + Expr.count + Expr.first + Expr.implode + Expr.last + Expr.len + Expr.max + Expr.mean + Expr.median + Expr.min + Expr.n_unique + Expr.nan_max + Expr.nan_min + Expr.null_count + Expr.product + Expr.quantile + Expr.std + Expr.sum + Expr.var diff --git a/py-polars/docs/source/reference/expressions/array.rst b/py-polars/docs/source/reference/expressions/array.rst new file mode 100644 index 000000000000..1573c478920c --- /dev/null +++ b/py-polars/docs/source/reference/expressions/array.rst @@ -0,0 +1,35 @@ +===== +Array +===== + +The following methods are available under the `expr.arr` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.arr.all + Expr.arr.any + Expr.arr.arg_max + Expr.arr.arg_min + Expr.arr.contains + Expr.arr.count_matches + Expr.arr.explode + Expr.arr.first + Expr.arr.get + Expr.arr.join + Expr.arr.last + Expr.arr.max + Expr.arr.median + Expr.arr.min + Expr.arr.n_unique + Expr.arr.reverse + Expr.arr.shift + Expr.arr.sort + Expr.arr.std + Expr.arr.sum + Expr.arr.to_list + Expr.arr.to_struct + Expr.arr.unique + Expr.arr.var diff --git a/py-polars/docs/source/reference/expressions/binary.rst b/py-polars/docs/source/reference/expressions/binary.rst new file mode 100644 index 000000000000..ed13c8dd3596 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/binary.rst @@ -0,0 +1,18 @@ +====== +Binary +====== + +The following methods are available under the `expr.bin` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.bin.contains + Expr.bin.decode + Expr.bin.encode + Expr.bin.ends_with + Expr.bin.reinterpret + Expr.bin.size + Expr.bin.starts_with diff --git a/py-polars/docs/source/reference/expressions/boolean.rst b/py-polars/docs/source/reference/expressions/boolean.rst new file mode 100644 index 000000000000..f70e93d588f9 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/boolean.rst @@ -0,0 +1,24 @@ +======= +Boolean +======= + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.all + Expr.any + Expr.has_nulls + Expr.is_between + Expr.is_duplicated + Expr.is_finite + Expr.is_first_distinct + Expr.is_in + Expr.is_infinite + Expr.is_last_distinct + Expr.is_nan + Expr.is_not_nan + Expr.is_not_null + Expr.is_null + Expr.is_unique + Expr.not_ diff --git a/py-polars/docs/source/reference/expressions/categories.rst b/py-polars/docs/source/reference/expressions/categories.rst new file mode 100644 index 000000000000..a437704f6094 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/categories.rst @@ -0,0 +1,16 @@ +========== +Categories +========== + +The following methods are available under the `expr.cat` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.cat.ends_with + Expr.cat.get_categories + Expr.cat.len_bytes + Expr.cat.len_chars + Expr.cat.starts_with diff --git a/py-polars/docs/source/reference/expressions/col.rst b/py-polars/docs/source/reference/expressions/col.rst new file mode 100644 index 000000000000..612e56e4cd63 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/col.rst @@ -0,0 +1,18 @@ +========== +polars.col +========== + +Create an expression representing column(s) in a DataFrame. + +``col`` is technically not a function, but it can be used like one. + +See the class documentation below for examples and further documentation. + +----- + +.. currentmodule:: polars.functions.col +.. autoclass:: Col + :members: __call__, __getattr__ + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/expressions/columns.rst b/py-polars/docs/source/reference/expressions/columns.rst new file mode 100644 index 000000000000..1fdf2ccbc86b --- /dev/null +++ b/py-polars/docs/source/reference/expressions/columns.rst @@ -0,0 +1,16 @@ +=============== +Columns / names +=============== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.alias + Expr.exclude + +.. toctree:: + :maxdepth: 2 + :hidden: + + col diff --git a/py-polars/docs/source/reference/expressions/computation.rst b/py-polars/docs/source/reference/expressions/computation.rst new file mode 100644 index 000000000000..a11ab57815e9 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/computation.rst @@ -0,0 +1,86 @@ +=========== +Computation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.abs + Expr.approx_n_unique + Expr.arccos + Expr.arccosh + Expr.arcsin + Expr.arcsinh + Expr.arctan + Expr.arctanh + Expr.arg_unique + Expr.bitwise_count_ones + Expr.bitwise_count_zeros + Expr.bitwise_leading_ones + Expr.bitwise_leading_zeros + Expr.bitwise_trailing_ones + Expr.bitwise_trailing_zeros + Expr.cbrt + Expr.cos + Expr.cosh + Expr.cot + Expr.cum_count + Expr.cum_max + Expr.cum_min + Expr.cum_prod + Expr.cum_sum + Expr.cumulative_eval + Expr.degrees + Expr.diff + Expr.dot + Expr.entropy + Expr.ewm_mean + Expr.ewm_mean_by + Expr.ewm_std + Expr.ewm_var + Expr.exp + Expr.hash + Expr.hist + Expr.index_of + Expr.kurtosis + Expr.log + Expr.log10 + Expr.log1p + Expr.mode + Expr.n_unique + Expr.pct_change + Expr.peak_max + Expr.peak_min + Expr.radians + Expr.rank + Expr.rolling_kurtosis + Expr.rolling_map + Expr.rolling_max + Expr.rolling_max_by + Expr.rolling_mean + Expr.rolling_mean_by + Expr.rolling_median + Expr.rolling_median_by + Expr.rolling_min + Expr.rolling_min_by + Expr.rolling_quantile + Expr.rolling_quantile_by + Expr.rolling_skew + Expr.rolling_std + Expr.rolling_std_by + Expr.rolling_sum + Expr.rolling_sum_by + Expr.rolling_var + Expr.rolling_var_by + Expr.search_sorted + Expr.sign + Expr.sin + Expr.sinh + Expr.skew + Expr.sqrt + Expr.tan + Expr.tanh + Expr.unique + Expr.unique_counts + Expr.value_counts diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst new file mode 100644 index 000000000000..dd5243500192 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -0,0 +1,117 @@ +========= +Functions +========= + +These functions are available from the Polars module root and can be used as expressions, and sometimes also in eager contexts. + +---- + +**Available in module namespace:** + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + all + all_horizontal + any + any_horizontal + approx_n_unique + arange + arctan2 + arctan2d + arg_sort_by + arg_where + business_day_count + coalesce + concat_arr + concat_list + concat_str + corr + count + cov + cum_count + cum_fold + cum_reduce + cum_sum + cum_sum_horizontal + date + date_range + date_ranges + datetime + datetime_range + datetime_ranges + duration + element + exclude + first + fold + format + from_epoch + groups + head + implode + int_range + int_ranges + last + len + linear_space + linear_spaces + lit + map_batches + map_groups + max + max_horizontal + mean + mean_horizontal + median + min + min_horizontal + n_unique + nth + ones + quantile + reduce + repeat + rolling_corr + rolling_cov + select + sql + sql_expr + std + struct + sum + sum_horizontal + tail + time + time_range + time_ranges + var + when + zeros + + +**Available in expression namespace:** + +.. autosummary:: + :toctree: api/ + + Expr.all + Expr.any + Expr.approx_n_unique + Expr.count + Expr.first + Expr.head + Expr.implode + Expr.map_batches + Expr.map_elements + Expr.max + Expr.mean + Expr.median + Expr.min + Expr.n_unique + Expr.quantile + Expr.std + Expr.sum + Expr.tail + Expr.var diff --git a/py-polars/docs/source/reference/expressions/index.rst b/py-polars/docs/source/reference/expressions/index.rst new file mode 100644 index 000000000000..5f9b8b541dad --- /dev/null +++ b/py-polars/docs/source/reference/expressions/index.rst @@ -0,0 +1,36 @@ +=========== +Expressions +=========== + +This page gives an overview of all public Polars expressions. + +.. toctree:: + :maxdepth: 2 + :hidden: + + aggregation + array + binary + boolean + categories + columns + computation + functions + list + modify_select + meta + miscellaneous + name + operators + string + struct + temporal + window + +.. currentmodule:: polars + +.. autoclass:: Expr + :members: + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst new file mode 100644 index 000000000000..18b9dd9c4867 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -0,0 +1,51 @@ +==== +List +==== + +The following methods are available under the `expr.list` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.list.all + Expr.list.any + Expr.list.arg_max + Expr.list.arg_min + Expr.list.concat + Expr.list.contains + Expr.list.count_matches + Expr.list.diff + Expr.list.drop_nulls + Expr.list.eval + Expr.list.explode + Expr.list.first + Expr.list.gather + Expr.list.gather_every + Expr.list.get + Expr.list.head + Expr.list.join + Expr.list.last + Expr.list.len + Expr.list.max + Expr.list.mean + Expr.list.median + Expr.list.min + Expr.list.n_unique + Expr.list.reverse + Expr.list.sample + Expr.list.set_difference + Expr.list.set_intersection + Expr.list.set_symmetric_difference + Expr.list.set_union + Expr.list.shift + Expr.list.slice + Expr.list.sort + Expr.list.std + Expr.list.sum + Expr.list.tail + Expr.list.to_array + Expr.list.to_struct + Expr.list.unique + Expr.list.var diff --git a/py-polars/docs/source/reference/expressions/meta.rst b/py-polars/docs/source/reference/expressions/meta.rst new file mode 100644 index 000000000000..6e4428381a34 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/meta.rst @@ -0,0 +1,26 @@ +==== +Meta +==== + +The following methods are available under the `expr.meta` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.meta.eq + Expr.meta.has_multiple_outputs + Expr.meta.is_column + Expr.meta.is_column_selection + Expr.meta.is_literal + Expr.meta.is_regex_projection + Expr.meta.ne + Expr.meta.output_name + Expr.meta.pop + Expr.meta.root_names + Expr.meta.serialize + Expr.meta.show_graph + Expr.meta.tree_format + Expr.meta.undo_aliases + Expr.meta.write_json diff --git a/py-polars/docs/source/reference/expressions/miscellaneous.rst b/py-polars/docs/source/reference/expressions/miscellaneous.rst new file mode 100644 index 000000000000..c0ea4d2caf1b --- /dev/null +++ b/py-polars/docs/source/reference/expressions/miscellaneous.rst @@ -0,0 +1,11 @@ +============= +Miscellaneous +============= + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.deserialize + Expr.from_json + Expr.set_sorted diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst new file mode 100644 index 000000000000..73b9aaee9b5f --- /dev/null +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -0,0 +1,63 @@ +====================== +Manipulation/selection +====================== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.append + Expr.arg_sort + Expr.arg_true + Expr.backward_fill + Expr.bottom_k + Expr.bottom_k_by + Expr.cast + Expr.ceil + Expr.clip + Expr.cut + Expr.drop_nans + Expr.drop_nulls + Expr.explode + Expr.extend_constant + Expr.fill_nan + Expr.fill_null + Expr.filter + Expr.flatten + Expr.floor + Expr.forward_fill + Expr.gather + Expr.gather_every + Expr.get + Expr.head + Expr.inspect + Expr.interpolate + Expr.interpolate_by + Expr.limit + Expr.lower_bound + Expr.pipe + Expr.qcut + Expr.rechunk + Expr.reinterpret + Expr.repeat_by + Expr.replace + Expr.replace_strict + Expr.reshape + Expr.reverse + Expr.rle + Expr.rle_id + Expr.round + Expr.round_sig_figs + Expr.sample + Expr.shift + Expr.shrink_dtype + Expr.shuffle + Expr.slice + Expr.sort + Expr.sort_by + Expr.tail + Expr.to_physical + Expr.top_k + Expr.top_k_by + Expr.upper_bound + Expr.where diff --git a/py-polars/docs/source/reference/expressions/name.rst b/py-polars/docs/source/reference/expressions/name.rst new file mode 100644 index 000000000000..80693496c350 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/name.rst @@ -0,0 +1,20 @@ +==== +Name +==== + +The following methods are available under the `expr.name` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.name.keep + Expr.name.map + Expr.name.map_fields + Expr.name.prefix + Expr.name.prefix_fields + Expr.name.suffix + Expr.name.suffix_fields + Expr.name.to_lowercase + Expr.name.to_uppercase diff --git a/py-polars/docs/source/reference/expressions/operators.rst b/py-polars/docs/source/reference/expressions/operators.rst new file mode 100644 index 000000000000..c4ce55a2b144 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/operators.rst @@ -0,0 +1,56 @@ +========= +Operators +========= + +Polars supports native Python operators for all common operations; +these operators are also available as methods on the :class:`Expr` +class. + +Conjunction +~~~~~~~~~~~ + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.and_ + Expr.or_ + +Comparison +~~~~~~~~~~ + +.. autosummary:: + :toctree: api/ + + Expr.eq + Expr.eq_missing + Expr.ge + Expr.gt + Expr.le + Expr.lt + Expr.ne + Expr.ne_missing + +Numeric +~~~~~~~ + +.. autosummary:: + :toctree: api/ + + Expr.add + Expr.floordiv + Expr.mod + Expr.mul + Expr.neg + Expr.pow + Expr.sub + Expr.truediv + + +Binary +~~~~~~ + +.. autosummary:: + :toctree: api/ + + Expr.xor diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst new file mode 100644 index 000000000000..c0265eb04bf3 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -0,0 +1,60 @@ +====== +String +====== + +The following methods are available under the `expr.str` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.str.concat + Expr.str.contains + Expr.str.contains_any + Expr.str.count_matches + Expr.str.decode + Expr.str.encode + Expr.str.ends_with + Expr.str.escape_regex + Expr.str.explode + Expr.str.extract + Expr.str.extract_all + Expr.str.extract_groups + Expr.str.extract_many + Expr.str.find + Expr.str.find_many + Expr.str.head + Expr.str.join + Expr.str.json_decode + Expr.str.json_path_match + Expr.str.len_bytes + Expr.str.len_chars + Expr.str.normalize + Expr.str.pad_end + Expr.str.pad_start + Expr.str.replace + Expr.str.replace_all + Expr.str.replace_many + Expr.str.reverse + Expr.str.slice + Expr.str.split + Expr.str.split_exact + Expr.str.splitn + Expr.str.starts_with + Expr.str.strip_chars + Expr.str.strip_chars_start + Expr.str.strip_chars_end + Expr.str.strip_prefix + Expr.str.strip_suffix + Expr.str.strptime + Expr.str.tail + Expr.str.to_date + Expr.str.to_datetime + Expr.str.to_decimal + Expr.str.to_integer + Expr.str.to_lowercase + Expr.str.to_time + Expr.str.to_titlecase + Expr.str.to_uppercase + Expr.str.zfill diff --git a/py-polars/docs/source/reference/expressions/struct.rst b/py-polars/docs/source/reference/expressions/struct.rst new file mode 100644 index 000000000000..cd081477b23b --- /dev/null +++ b/py-polars/docs/source/reference/expressions/struct.rst @@ -0,0 +1,16 @@ +====== +Struct +====== + +The following methods are available under the `expr.struct` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.struct.field + Expr.struct.unnest + Expr.struct.json_encode + Expr.struct.rename_fields + Expr.struct.with_fields diff --git a/py-polars/docs/source/reference/expressions/temporal.rst b/py-polars/docs/source/reference/expressions/temporal.rst new file mode 100644 index 000000000000..2e0534b05f3c --- /dev/null +++ b/py-polars/docs/source/reference/expressions/temporal.rst @@ -0,0 +1,57 @@ +======== +Temporal +======== + +The following methods are available under the `expr.dt` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Expr.dt.add_business_days + Expr.dt.base_utc_offset + Expr.dt.cast_time_unit + Expr.dt.century + Expr.dt.combine + Expr.dt.convert_time_zone + Expr.dt.date + Expr.dt.datetime + Expr.dt.day + Expr.dt.dst_offset + Expr.dt.epoch + Expr.dt.hour + Expr.dt.is_business_day + Expr.dt.is_leap_year + Expr.dt.iso_year + Expr.dt.microsecond + Expr.dt.millennium + Expr.dt.millisecond + Expr.dt.minute + Expr.dt.month + Expr.dt.month_end + Expr.dt.month_start + Expr.dt.nanosecond + Expr.dt.offset_by + Expr.dt.ordinal_day + Expr.dt.quarter + Expr.dt.replace + Expr.dt.replace_time_zone + Expr.dt.round + Expr.dt.second + Expr.dt.strftime + Expr.dt.time + Expr.dt.timestamp + Expr.dt.to_string + Expr.dt.total_days + Expr.dt.total_hours + Expr.dt.total_microseconds + Expr.dt.total_milliseconds + Expr.dt.total_minutes + Expr.dt.total_nanoseconds + Expr.dt.total_seconds + Expr.dt.truncate + Expr.dt.week + Expr.dt.weekday + Expr.dt.with_time_unit + Expr.dt.year diff --git a/py-polars/docs/source/reference/expressions/window.rst b/py-polars/docs/source/reference/expressions/window.rst new file mode 100644 index 000000000000..7c6c045b9e27 --- /dev/null +++ b/py-polars/docs/source/reference/expressions/window.rst @@ -0,0 +1,10 @@ +====== +Window +====== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Expr.over + Expr.rolling diff --git a/py-polars/docs/source/reference/functions.rst b/py-polars/docs/source/reference/functions.rst new file mode 100644 index 000000000000..cfed93e38e85 --- /dev/null +++ b/py-polars/docs/source/reference/functions.rst @@ -0,0 +1,59 @@ +========= +Functions +========= +.. currentmodule:: polars + +Conversion +~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + from_arrow + from_dataframe + from_dict + from_dicts + from_numpy + from_pandas + from_records + from_repr + json_normalize + +Miscellaneous +~~~~~~~~~~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + align_frames + concat + defer + escape_regex + +Multiple queries +~~~~~~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + collect_all + collect_all_async + explain_all + +Random +~~~~~~ +.. autosummary:: + :toctree: api/ + + set_random_seed + +StringCache +~~~~~~~~~~~ + +Note that the `StringCache` can be used as both a context manager +and a decorator, in order to explicitly scope cache lifetime. + +.. autosummary:: + :toctree: api/ + + StringCache + enable_string_cache + disable_string_cache + using_string_cache diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst new file mode 100644 index 000000000000..e281cf1423f5 --- /dev/null +++ b/py-polars/docs/source/reference/index.rst @@ -0,0 +1,121 @@ +==================== +Python API reference +==================== + +This page gives a high-level overview of all public Polars objects, functions and +methods. All classes and functions exposed in the ``polars.*`` namespace are public. + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + dataframe/index + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + lazyframe/index + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + series/index + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + expressions/index + selectors + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + functions + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + datatypes + schema/index + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + io + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + catalog/index + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + config + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + api + + .. toctree:: + :maxdepth: 1 + + plugins + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + sql/index + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 1 + + exceptions + + .. toctree:: + :maxdepth: 2 + + testing + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 1 + + metadata diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst new file mode 100644 index 000000000000..d902aa864634 --- /dev/null +++ b/py-polars/docs/source/reference/io.rst @@ -0,0 +1,155 @@ +============ +Input/output +============ +.. currentmodule:: polars + +Avro +~~~~ +.. autosummary:: + :toctree: api/ + + read_avro + DataFrame.write_avro + +Clipboard +~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_clipboard + DataFrame.write_clipboard + +CSV +~~~ +.. autosummary:: + :toctree: api/ + + read_csv + read_csv_batched + scan_csv + DataFrame.write_csv + LazyFrame.sink_csv + +.. currentmodule:: polars.io.csv.batched_reader + +.. autosummary:: + :toctree: api/ + + BatchedCsvReader.next_batches + +.. currentmodule:: polars + +Database +~~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_database + read_database_uri + DataFrame.write_database + +Delta Lake +~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_delta + scan_delta + DataFrame.write_delta + +Excel / ODS +~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_excel + read_ods + DataFrame.write_excel + +Feather / IPC +~~~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_ipc + read_ipc_schema + read_ipc_stream + scan_ipc + DataFrame.write_ipc + DataFrame.write_ipc_stream + LazyFrame.sink_ipc + +Iceberg +~~~~~~~ +.. autosummary:: + :toctree: api/ + + scan_iceberg + DataFrame.write_iceberg + +JSON +~~~~ +.. autosummary:: + :toctree: api/ + + read_json + read_ndjson + scan_ndjson + DataFrame.write_json + DataFrame.write_ndjson + LazyFrame.sink_ndjson + + +Partition +~~~~~~~~~ +Sink to disk with differing partitioning strategies. + +.. autosummary:: + :toctree: api/ + + PartitionByKey + PartitionMaxSize + PartitionParted + +.. currentmodule:: polars.io.partition + +.. autosummary:: + :toctree: api/ + + KeyedPartition + BasePartitionContext + KeyedPartitionContext + +.. currentmodule:: polars + +Parquet +~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_parquet + read_parquet_schema + scan_parquet + DataFrame.write_parquet + LazyFrame.sink_parquet + +PyArrow Datasets +~~~~~~~~~~~~~~~~ +Connect to pyarrow datasets. + +.. autosummary:: + :toctree: api/ + + scan_pyarrow_dataset + +Cloud Credentials +~~~~~~~~~~~~~~~~~ +Configuration for cloud credential provisioning. + +.. autosummary:: + :toctree: api/ + + CredentialProvider + CredentialProviderAWS + CredentialProviderAzure + CredentialProviderGCP diff --git a/py-polars/docs/source/reference/lazyframe/aggregation.rst b/py-polars/docs/source/reference/lazyframe/aggregation.rst new file mode 100644 index 000000000000..a7ada9045f8f --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/aggregation.rst @@ -0,0 +1,18 @@ +=========== +Aggregation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + LazyFrame.count + LazyFrame.max + LazyFrame.mean + LazyFrame.median + LazyFrame.min + LazyFrame.null_count + LazyFrame.quantile + LazyFrame.std + LazyFrame.sum + LazyFrame.var diff --git a/py-polars/docs/source/reference/lazyframe/attributes.rst b/py-polars/docs/source/reference/lazyframe/attributes.rst new file mode 100644 index 000000000000..23617421b8bf --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/attributes.rst @@ -0,0 +1,12 @@ +========== +Attributes +========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + LazyFrame.columns + LazyFrame.dtypes + LazyFrame.schema + LazyFrame.width diff --git a/py-polars/docs/source/reference/lazyframe/descriptive.rst b/py-polars/docs/source/reference/lazyframe/descriptive.rst new file mode 100644 index 000000000000..0f05afae8960 --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/descriptive.rst @@ -0,0 +1,11 @@ +=========== +Descriptive +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + LazyFrame.describe + LazyFrame.explain + LazyFrame.show_graph diff --git a/py-polars/docs/source/reference/lazyframe/gpu_engine.rst b/py-polars/docs/source/reference/lazyframe/gpu_engine.rst new file mode 100644 index 000000000000..3372f72566bf --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/gpu_engine.rst @@ -0,0 +1,14 @@ +========= +GPUEngine +========= + +This object provides fine-grained control over the behavior of the +GPU engine when calling `LazyFrame.collect()` with an `engine` +argument. + +.. currentmodule:: polars.lazyframe.engine_config + +.. autosummary:: + :toctree: api/ + + GPUEngine diff --git a/py-polars/docs/source/reference/lazyframe/group_by.rst b/py-polars/docs/source/reference/lazyframe/group_by.rst new file mode 100644 index 000000000000..3cb0abd2fd89 --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/group_by.rst @@ -0,0 +1,27 @@ +======= +GroupBy +======= + +This namespace becomes available by calling `LazyFrame.group_by(...)`. + +.. currentmodule:: polars.lazyframe.group_by + +.. autosummary:: + :toctree: api/ + + LazyGroupBy.agg + LazyGroupBy.all + LazyGroupBy.count + LazyGroupBy.first + LazyGroupBy.head + LazyGroupBy.last + LazyGroupBy.len + LazyGroupBy.map_groups + LazyGroupBy.max + LazyGroupBy.mean + LazyGroupBy.median + LazyGroupBy.min + LazyGroupBy.n_unique + LazyGroupBy.quantile + LazyGroupBy.sum + LazyGroupBy.tail diff --git a/py-polars/docs/source/reference/lazyframe/in_process.rst b/py-polars/docs/source/reference/lazyframe/in_process.rst new file mode 100644 index 000000000000..e3a2f7110994 --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/in_process.rst @@ -0,0 +1,14 @@ +============== +InProcessQuery +============== + +This namespace becomes available by calling `LazyFrame.collect(background=True)`. + +.. currentmodule:: polars.lazyframe.in_process + +.. autosummary:: + :toctree: api/ + + InProcessQuery.cancel + InProcessQuery.fetch + InProcessQuery.fetch_blocking diff --git a/py-polars/docs/source/reference/lazyframe/index.rst b/py-polars/docs/source/reference/lazyframe/index.rst new file mode 100644 index 000000000000..ab37eb776fb0 --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/index.rst @@ -0,0 +1,28 @@ +========= +LazyFrame +========= + +This page gives an overview of all public LazyFrame methods. + +.. toctree:: + :maxdepth: 2 + :hidden: + + aggregation + attributes + descriptive + group_by + modify_select + miscellaneous + in_process + gpu_engine + +.. _lazyframe: + +.. currentmodule:: polars + +.. autoclass:: LazyFrame + :members: + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/lazyframe/miscellaneous.rst b/py-polars/docs/source/reference/lazyframe/miscellaneous.rst new file mode 100644 index 000000000000..40cc8aaec335 --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/miscellaneous.rst @@ -0,0 +1,26 @@ +============= +Miscellaneous +============= + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + LazyFrame.cache + LazyFrame.collect + LazyFrame.collect_async + LazyFrame.collect_schema + LazyFrame.lazy + LazyFrame.map_batches + LazyFrame.pipe + LazyFrame.profile + LazyFrame.remote + +Serialization +------------- + +.. autosummary:: + :toctree: api/ + + LazyFrame.deserialize + LazyFrame.serialize diff --git a/py-polars/docs/source/reference/lazyframe/modify_select.rst b/py-polars/docs/source/reference/lazyframe/modify_select.rst new file mode 100644 index 000000000000..1e7f66c46eea --- /dev/null +++ b/py-polars/docs/source/reference/lazyframe/modify_select.rst @@ -0,0 +1,56 @@ +====================== +Manipulation/selection +====================== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + LazyFrame.approx_n_unique + LazyFrame.bottom_k + LazyFrame.cast + LazyFrame.clear + LazyFrame.clone + LazyFrame.drop + LazyFrame.drop_nans + LazyFrame.drop_nulls + LazyFrame.explode + LazyFrame.fill_nan + LazyFrame.fill_null + LazyFrame.filter + LazyFrame.first + LazyFrame.gather_every + LazyFrame.group_by + LazyFrame.group_by_dynamic + LazyFrame.head + LazyFrame.inspect + LazyFrame.interpolate + LazyFrame.join + LazyFrame.join_asof + LazyFrame.join_where + LazyFrame.last + LazyFrame.limit + LazyFrame.melt + LazyFrame.merge_sorted + LazyFrame.remove + LazyFrame.rename + LazyFrame.reverse + LazyFrame.rolling + LazyFrame.select + LazyFrame.select_seq + LazyFrame.set_sorted + LazyFrame.shift + LazyFrame.slice + LazyFrame.sort + LazyFrame.sql + LazyFrame.tail + LazyFrame.top_k + LazyFrame.unique + LazyFrame.unnest + LazyFrame.unpivot + LazyFrame.update + LazyFrame.with_columns + LazyFrame.with_columns_seq + LazyFrame.with_context + LazyFrame.with_row_count + LazyFrame.with_row_index diff --git a/py-polars/docs/source/reference/metadata.rst b/py-polars/docs/source/reference/metadata.rst new file mode 100644 index 000000000000..4d9c0dbf9c60 --- /dev/null +++ b/py-polars/docs/source/reference/metadata.rst @@ -0,0 +1,13 @@ +======== +Metadata +======== +.. currentmodule:: polars + +.. autosummary:: + :toctree: api/ + + build_info + get_index_type + show_versions + thread_pool_size + threadpool_size diff --git a/py-polars/docs/source/reference/plugins.rst b/py-polars/docs/source/reference/plugins.rst new file mode 100644 index 000000000000..bab812ec5b56 --- /dev/null +++ b/py-polars/docs/source/reference/plugins.rst @@ -0,0 +1,47 @@ +======= +Plugins +======= +.. currentmodule:: polars + +Polars allows you to extend its functionality with either Expression plugins or IO plugins. +See the `user guide `_ for more information and resources. + +Expression plugins +------------------ + +Expression plugins are the preferred way to create user defined functions. They allow you to compile +a Rust function and register that as an expression into the Polars library. The Polars engine will +dynamically link your function at runtime and your expression will run almost as fast as native +expressions. Note that this works without any interference of Python and thus no GIL contention. + +See the `expression plugins section of the user guide `_ +for more information. + +.. autosummary:: + :toctree: api/ + + plugins.register_plugin_function + + +IO plugins +------------------ + +IO plugins allow you to register different file formats as sources to the Polars engines. + +See the `IO plugins section of the user guide `_ +for more information. + +.. note:: + + The ``io.plugins`` module is not imported by default in order to optimise import speed of + the primary ``polars`` module. Either import ``polars.io.plugins`` and *then* use that + namespace, or import ``register_io_source`` from the full module path, e.g.: + + .. code-block:: python + + from polars.io.plugins import register_io_source + +.. autosummary:: + :toctree: api/ + + io.plugins.register_io_source diff --git a/py-polars/docs/source/reference/schema/index.rst b/py-polars/docs/source/reference/schema/index.rst new file mode 100644 index 000000000000..e383fe899862 --- /dev/null +++ b/py-polars/docs/source/reference/schema/index.rst @@ -0,0 +1,11 @@ +====== +Schema +====== + +.. currentmodule:: polars + +.. autoclass:: Schema + :members: + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/selectors.rst b/py-polars/docs/source/reference/selectors.rst new file mode 100644 index 000000000000..13c22cb0843e --- /dev/null +++ b/py-polars/docs/source/reference/selectors.rst @@ -0,0 +1,140 @@ +========= +Selectors +========= +.. currentmodule:: polars + +Selectors allow for more intuitive selection of columns from :class:`DataFrame` +or :class:`LazyFrame` objects based on their name, dtype or other properties. +They unify and build on the related functionality that is available through +the :meth:`col` expression and can also broadcast expressions over the selected +columns. + +Importing +--------- + +* Selectors are available as functions imported from ``polars.selectors`` +* Typical/recommended usage is to import the module as ``cs`` and employ selectors from there. + + .. code-block:: python + + import polars.selectors as cs + import polars as pl + + df = pl.DataFrame( + { + "w": ["xx", "yy", "xx", "yy", "xx"], + "x": [1, 2, 1, 4, -2], + "y": [3.0, 4.5, 1.0, 2.5, -2.0], + "z": ["a", "b", "a", "b", "b"], + }, + ) + df.group_by(cs.string()).agg(cs.numeric().sum()) + +Set operations +-------------- + +Selectors support the following ``set`` operations: + +.. table:: + :widths: 20 60 + + +------------------------+------------+ + | Operation | Expression | + +========================+============+ + | `UNION` | ``A | B`` | + +------------------------+------------+ + | `INTERSECTION` | ``A & B`` | + +------------------------+------------+ + | `DIFFERENCE` | ``A - B`` | + +------------------------+------------+ + | `SYMMETRIC DIFFERENCE` | ``A ^ B`` | + +------------------------+------------+ + | `COMPLEMENT` | ``~A`` | + +------------------------+------------+ + +Note that both individual selector results and selector set operations will always return +matching columns in the same order as the underlying frame schema. + +Examples +======== + +.. code-block:: python + + import polars.selectors as cs + import polars as pl + + # set up an empty dataframe with plenty of columns of various dtypes + df = pl.DataFrame( + schema={ + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + "eee": pl.Boolean, + "fgg": pl.Boolean, + "ghi": pl.Time, + "JJK": pl.Date, + "Lmn": pl.Duration, + "opp": pl.Datetime("ms"), + "qqR": pl.String, + }, + ) + + # Select the UNION of temporal, strings and columns that start with "e" + assert df.select(cs.temporal() | cs.string() | cs.starts_with("e")).schema == { + "eee": pl.Boolean, + "ghi": pl.Time, + "JJK": pl.Date, + "Lmn": pl.Duration, + "opp": pl.Datetime("ms"), + "qqR": pl.String, + } + + # Select the INTERSECTION of temporal and column names that match "opp" OR "JJK" + assert df.select(cs.temporal() & cs.matches("opp|JJK")).schema == { + "JJK": pl.Date, + "opp": pl.Datetime("ms"), + } + + # Select the DIFFERENCE of temporal columns and columns that contain the name "opp" OR "JJK" + assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == { + "ghi": pl.Time, + "Lmn": pl.Duration, + } + + # Select the SYMMETRIC DIFFERENCE of numeric columns and columns that contain an "e" + assert df.select(cs.contains("e") ^ cs.numeric()).schema == { + "abc": UInt16, + "bbb": UInt32, + "eee": Boolean, + } + + # Select the COMPLEMENT of all columns of dtypes Duration and Time + assert df.select(~cs.by_dtype([pl.Duration, pl.Time])).schema == { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + "eee": pl.Boolean, + "fgg": pl.Boolean, + "JJK": pl.Date, + "opp": pl.Datetime("ms"), + "qqR": pl.String, + } + + +.. note:: + + If you don't want to use the set operations on the selectors, you can materialize them as ``expressions`` + by calling ``as_expr``. This ensures the operations ``OR, AND, etc`` are dispatched to the underlying + expressions instead. + +Functions +--------- + +Available selector functions: + +.. automodule:: polars.selectors + :members: + :autosummary: + :autosummary-no-titles: diff --git a/py-polars/docs/source/reference/series/aggregation.rst b/py-polars/docs/source/reference/series/aggregation.rst new file mode 100644 index 000000000000..fe74d9eb4fd0 --- /dev/null +++ b/py-polars/docs/source/reference/series/aggregation.rst @@ -0,0 +1,24 @@ +=========== +Aggregation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.arg_max + Series.arg_min + Series.count + Series.implode + Series.max + Series.mean + Series.median + Series.min + Series.mode + Series.nan_max + Series.nan_min + Series.product + Series.quantile + Series.std + Series.sum + Series.var diff --git a/py-polars/docs/source/reference/series/array.rst b/py-polars/docs/source/reference/series/array.rst new file mode 100644 index 000000000000..92effb371544 --- /dev/null +++ b/py-polars/docs/source/reference/series/array.rst @@ -0,0 +1,35 @@ +===== +Array +===== + +The following methods are available under the `Series.arr` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.arr.all + Series.arr.any + Series.arr.arg_max + Series.arr.arg_min + Series.arr.contains + Series.arr.count_matches + Series.arr.explode + Series.arr.first + Series.arr.get + Series.arr.join + Series.arr.last + Series.arr.max + Series.arr.median + Series.arr.min + Series.arr.n_unique + Series.arr.reverse + Series.arr.shift + Series.arr.sort + Series.arr.std + Series.arr.sum + Series.arr.to_list + Series.arr.to_struct + Series.arr.unique + Series.arr.var diff --git a/py-polars/docs/source/reference/series/attributes.rst b/py-polars/docs/source/reference/series/attributes.rst new file mode 100644 index 000000000000..2e2deb52f890 --- /dev/null +++ b/py-polars/docs/source/reference/series/attributes.rst @@ -0,0 +1,12 @@ +========== +Attributes +========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.dtype + Series.flags + Series.name + Series.shape diff --git a/py-polars/docs/source/reference/series/binary.rst b/py-polars/docs/source/reference/series/binary.rst new file mode 100644 index 000000000000..25e622d420b2 --- /dev/null +++ b/py-polars/docs/source/reference/series/binary.rst @@ -0,0 +1,18 @@ +====== +Binary +====== + +The following methods are available under the `Series.bin` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.bin.contains + Series.bin.decode + Series.bin.encode + Series.bin.ends_with + Series.bin.reinterpret + Series.bin.size + Series.bin.starts_with diff --git a/py-polars/docs/source/reference/series/boolean.rst b/py-polars/docs/source/reference/series/boolean.rst new file mode 100644 index 000000000000..33da829f1e77 --- /dev/null +++ b/py-polars/docs/source/reference/series/boolean.rst @@ -0,0 +1,11 @@ +======= +Boolean +======= + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.all + Series.any + Series.not_ diff --git a/py-polars/docs/source/reference/series/categories.rst b/py-polars/docs/source/reference/series/categories.rst new file mode 100644 index 000000000000..46db6491f100 --- /dev/null +++ b/py-polars/docs/source/reference/series/categories.rst @@ -0,0 +1,19 @@ +========== +Categories +========== + +The following methods are available under the `Series.cat` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.cat.ends_with + Series.cat.get_categories + Series.cat.is_local + Series.cat.len_bytes + Series.cat.len_chars + Series.cat.starts_with + Series.cat.to_local + Series.cat.uses_lexical_ordering diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst new file mode 100644 index 000000000000..5a3318dfbda6 --- /dev/null +++ b/py-polars/docs/source/reference/series/computation.rst @@ -0,0 +1,80 @@ +=========== +Computation +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.abs + Series.arccos + Series.arccosh + Series.arcsin + Series.arcsinh + Series.arctan + Series.arctanh + Series.arg_true + Series.arg_unique + Series.approx_n_unique + Series.bitwise_count_ones + Series.bitwise_count_zeros + Series.bitwise_leading_ones + Series.bitwise_leading_zeros + Series.bitwise_trailing_ones + Series.bitwise_trailing_zeros + Series.bitwise_and + Series.bitwise_or + Series.bitwise_xor + Series.cbrt + Series.cos + Series.cosh + Series.cot + Series.cum_count + Series.cum_max + Series.cum_min + Series.cum_prod + Series.cum_sum + Series.cumulative_eval + Series.diff + Series.dot + Series.entropy + Series.ewm_mean + Series.ewm_mean_by + Series.ewm_std + Series.ewm_var + Series.exp + Series.first + Series.hash + Series.hist + Series.index_of + Series.is_between + Series.kurtosis + Series.last + Series.log + Series.log10 + Series.log1p + Series.pct_change + Series.peak_max + Series.peak_min + Series.rank + Series.replace + Series.replace_strict + Series.rolling_kurtosis + Series.rolling_map + Series.rolling_max + Series.rolling_mean + Series.rolling_median + Series.rolling_min + Series.rolling_quantile + Series.rolling_skew + Series.rolling_std + Series.rolling_sum + Series.rolling_var + Series.search_sorted + Series.sign + Series.sin + Series.sinh + Series.skew + Series.sqrt + Series.tan + Series.tanh diff --git a/py-polars/docs/source/reference/series/descriptive.rst b/py-polars/docs/source/reference/series/descriptive.rst new file mode 100644 index 000000000000..4f7eb6ec011f --- /dev/null +++ b/py-polars/docs/source/reference/series/descriptive.rst @@ -0,0 +1,34 @@ +=========== +Descriptive +=========== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.chunk_lengths + Series.describe + Series.estimated_size + Series.has_nulls + Series.has_validity + Series.is_duplicated + Series.is_empty + Series.is_finite + Series.is_first_distinct + Series.is_in + Series.is_infinite + Series.is_last_distinct + Series.is_nan + Series.is_not_nan + Series.is_not_null + Series.is_null + Series.is_sorted + Series.is_unique + Series.len + Series.lower_bound + Series.n_chunks + Series.n_unique + Series.null_count + Series.unique_counts + Series.upper_bound + Series.value_counts diff --git a/py-polars/docs/source/reference/series/export.rst b/py-polars/docs/source/reference/series/export.rst new file mode 100644 index 000000000000..268ef25113fd --- /dev/null +++ b/py-polars/docs/source/reference/series/export.rst @@ -0,0 +1,20 @@ +====== +Export +====== + +Export Series data to other formats: + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.__array__ + Series.__arrow_c_stream__ + Series.to_arrow + Series.to_frame + Series.to_init_repr + Series.to_jax + Series.to_list + Series.to_numpy + Series.to_pandas + Series.to_torch diff --git a/py-polars/docs/source/reference/series/index.rst b/py-polars/docs/source/reference/series/index.rst new file mode 100644 index 000000000000..d507312498c9 --- /dev/null +++ b/py-polars/docs/source/reference/series/index.rst @@ -0,0 +1,37 @@ +====== +Series +====== + +This page gives an overview of all public Series methods. + +.. toctree:: + :maxdepth: 2 + :hidden: + + aggregation + array + attributes + binary + boolean + categories + computation + descriptive + export + list + modify_select + miscellaneous + operators + plot + string + struct + temporal + +.. _series: + +.. currentmodule:: polars + +.. autoclass:: Series + :members: + :noindex: + :autosummary: + :autosummary-nosignatures: diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst new file mode 100644 index 000000000000..ab857679c059 --- /dev/null +++ b/py-polars/docs/source/reference/series/list.rst @@ -0,0 +1,51 @@ +==== +List +==== + +The following methods are available under the `Series.list` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.list.all + Series.list.any + Series.list.arg_max + Series.list.arg_min + Series.list.concat + Series.list.contains + Series.list.count_matches + Series.list.diff + Series.list.drop_nulls + Series.list.eval + Series.list.explode + Series.list.first + Series.list.gather + Series.list.gather_every + Series.list.get + Series.list.head + Series.list.join + Series.list.last + Series.list.len + Series.list.max + Series.list.mean + Series.list.median + Series.list.min + Series.list.n_unique + Series.list.reverse + Series.list.sample + Series.list.set_difference + Series.list.set_intersection + Series.list.set_symmetric_difference + Series.list.set_union + Series.list.shift + Series.list.slice + Series.list.sort + Series.list.std + Series.list.sum + Series.list.tail + Series.list.to_array + Series.list.to_struct + Series.list.unique + Series.list.var diff --git a/py-polars/docs/source/reference/series/miscellaneous.rst b/py-polars/docs/source/reference/series/miscellaneous.rst new file mode 100644 index 000000000000..729071b69b95 --- /dev/null +++ b/py-polars/docs/source/reference/series/miscellaneous.rst @@ -0,0 +1,14 @@ +============= +Miscellaneous +============= + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.equals + Series.get_chunks + Series.map_elements + Series.reinterpret + Series.set_sorted + Series.to_physical diff --git a/py-polars/docs/source/reference/series/modify_select.rst b/py-polars/docs/source/reference/series/modify_select.rst new file mode 100644 index 000000000000..ca471c789a9a --- /dev/null +++ b/py-polars/docs/source/reference/series/modify_select.rst @@ -0,0 +1,61 @@ +====================== +Manipulation/selection +====================== + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.__getitem__ + Series.alias + Series.append + Series.arg_sort + Series.backward_fill + Series.bottom_k + Series.cast + Series.ceil + Series.clear + Series.clip + Series.clone + Series.cut + Series.drop_nans + Series.drop_nulls + Series.explode + Series.extend + Series.extend_constant + Series.fill_nan + Series.fill_null + Series.filter + Series.floor + Series.forward_fill + Series.gather + Series.gather_every + Series.head + Series.interpolate + Series.interpolate_by + Series.item + Series.limit + Series.new_from_index + Series.qcut + Series.rechunk + Series.rename + Series.reshape + Series.reverse + Series.rle + Series.rle_id + Series.round + Series.round_sig_figs + Series.sample + Series.scatter + Series.set + Series.shift + Series.shrink_dtype + Series.shrink_to_fit + Series.shuffle + Series.slice + Series.sort + Series.tail + Series.to_dummies + Series.top_k + Series.unique + Series.zip_with diff --git a/py-polars/docs/source/reference/series/operators.rst b/py-polars/docs/source/reference/series/operators.rst new file mode 100644 index 000000000000..e01c1b39e9de --- /dev/null +++ b/py-polars/docs/source/reference/series/operators.rst @@ -0,0 +1,31 @@ +========= +Operators +========= + +Polars supports native Python operators for all common operations; +many of these operators are also available as methods on the :class:`Series` +class. + +Comparison +~~~~~~~~~~ + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.eq + Series.eq_missing + Series.ge + Series.gt + Series.le + Series.lt + Series.ne + Series.ne_missing + +Numeric +~~~~~~~ + +.. autosummary:: + :toctree: api/ + + Series.pow diff --git a/py-polars/docs/source/reference/series/plot.rst b/py-polars/docs/source/reference/series/plot.rst new file mode 100644 index 000000000000..f7f719b8472e --- /dev/null +++ b/py-polars/docs/source/reference/series/plot.rst @@ -0,0 +1,7 @@ +==== +Plot +==== + +.. currentmodule:: polars + +.. autoproperty:: Series.plot \ No newline at end of file diff --git a/py-polars/docs/source/reference/series/string.rst b/py-polars/docs/source/reference/series/string.rst new file mode 100644 index 000000000000..a3f51279975f --- /dev/null +++ b/py-polars/docs/source/reference/series/string.rst @@ -0,0 +1,60 @@ +====== +String +====== + +The following methods are available under the `Series.str` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.str.concat + Series.str.contains + Series.str.contains_any + Series.str.count_matches + Series.str.decode + Series.str.encode + Series.str.ends_with + Series.str.escape_regex + Series.str.explode + Series.str.extract + Series.str.extract_all + Series.str.extract_groups + Series.str.extract_many + Series.str.find + Series.str.find_many + Series.str.head + Series.str.join + Series.str.json_decode + Series.str.json_path_match + Series.str.len_bytes + Series.str.len_chars + Series.str.normalize + Series.str.pad_end + Series.str.pad_start + Series.str.replace + Series.str.replace_all + Series.str.replace_many + Series.str.reverse + Series.str.slice + Series.str.split + Series.str.split_exact + Series.str.splitn + Series.str.starts_with + Series.str.strip_chars + Series.str.strip_chars_start + Series.str.strip_chars_end + Series.str.strip_prefix + Series.str.strip_suffix + Series.str.strptime + Series.str.tail + Series.str.to_date + Series.str.to_datetime + Series.str.to_decimal + Series.str.to_integer + Series.str.to_lowercase + Series.str.to_time + Series.str.to_titlecase + Series.str.to_uppercase + Series.str.zfill diff --git a/py-polars/docs/source/reference/series/struct.rst b/py-polars/docs/source/reference/series/struct.rst new file mode 100644 index 000000000000..af753cb1389b --- /dev/null +++ b/py-polars/docs/source/reference/series/struct.rst @@ -0,0 +1,22 @@ +====== +Struct +====== + +The following methods are available under the `Series.struct` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.struct.field + Series.struct.json_encode + Series.struct.rename_fields + Series.struct.unnest + +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_attribute.rst + + Series.struct.fields + Series.struct.schema diff --git a/py-polars/docs/source/reference/series/temporal.rst b/py-polars/docs/source/reference/series/temporal.rst new file mode 100644 index 000000000000..ce869600d7f6 --- /dev/null +++ b/py-polars/docs/source/reference/series/temporal.rst @@ -0,0 +1,61 @@ +======== +Temporal +======== + +The following methods are available under the `Series.dt` attribute. + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Series.dt.add_business_days + Series.dt.base_utc_offset + Series.dt.cast_time_unit + Series.dt.century + Series.dt.combine + Series.dt.convert_time_zone + Series.dt.date + Series.dt.datetime + Series.dt.day + Series.dt.dst_offset + Series.dt.epoch + Series.dt.hour + Series.dt.is_business_day + Series.dt.is_leap_year + Series.dt.iso_year + Series.dt.max + Series.dt.mean + Series.dt.median + Series.dt.microsecond + Series.dt.millennium + Series.dt.millisecond + Series.dt.min + Series.dt.minute + Series.dt.month + Series.dt.month_end + Series.dt.month_start + Series.dt.nanosecond + Series.dt.offset_by + Series.dt.ordinal_day + Series.dt.quarter + Series.dt.replace + Series.dt.replace_time_zone + Series.dt.round + Series.dt.second + Series.dt.strftime + Series.dt.time + Series.dt.timestamp + Series.dt.to_string + Series.dt.total_days + Series.dt.total_hours + Series.dt.total_microseconds + Series.dt.total_milliseconds + Series.dt.total_minutes + Series.dt.total_nanoseconds + Series.dt.total_seconds + Series.dt.truncate + Series.dt.week + Series.dt.weekday + Series.dt.with_time_unit + Series.dt.year diff --git a/py-polars/docs/source/reference/sql/clauses.rst b/py-polars/docs/source/reference/sql/clauses.rst new file mode 100644 index 000000000000..1bf2ea6ff9bd --- /dev/null +++ b/py-polars/docs/source/reference/sql/clauses.rst @@ -0,0 +1,353 @@ +SQL Clauses +=========== + +.. list-table:: + :header-rows: 1 + :widths: 20 60 + + * - Function + - Description + * - :ref:`SELECT ().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + Ok(()) } -create_aligned_buffer!(aligned_array_f32, f32); -create_aligned_buffer!(aligned_array_f64, f64); -create_aligned_buffer!(aligned_array_i8, i8); -create_aligned_buffer!(aligned_array_i16, i16); -create_aligned_buffer!(aligned_array_i32, i32); -create_aligned_buffer!(aligned_array_i64, i64); -create_aligned_buffer!(aligned_array_u8, u8); -create_aligned_buffer!(aligned_array_u16, u16); -create_aligned_buffer!(aligned_array_u32, u32); -create_aligned_buffer!(aligned_array_u64, u64); - -#[pyfunction] -pub fn series_from_ptr_f64(name: &str, ptr: usize, len: usize) -> PySeries { - let v: Vec = unsafe { npy::vec_from_ptr(ptr, len) }; - let av = AlignedVec::new(v).unwrap(); - let ca = ChunkedArray::new_from_aligned_vec(name, av); - PySeries::new(Series::Float64(ca)) +#[pymodule] +fn _expr_nodes(_py: Python, m: &Bound) -> PyResult<()> { + use polars_python::lazyframe::visit::PyExprIR; + use polars_python::lazyframe::visitor::expr_nodes::*; + // Expressions + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + // Options + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + Ok(()) } #[pymodule] -fn pypolars(_py: Python, m: &PyModule) -> PyResult<()> { +fn polars(py: Python, m: &Bound) -> PyResult<()> { + // Classes m.add_class::().unwrap(); m.add_class::().unwrap(); - m.add_wrapped(wrap_pyfunction!(aligned_array_f32)).unwrap(); - m.add_wrapped(wrap_pyfunction!(aligned_array_f64)).unwrap(); - m.add_wrapped(wrap_pyfunction!(aligned_array_i32)).unwrap(); - m.add_wrapped(wrap_pyfunction!(aligned_array_i64)).unwrap(); - m.add_wrapped(wrap_pyfunction!(series_from_ptr_f64)) + m.add_class::().unwrap(); + m.add_class::().unwrap(); + #[cfg(not(target_arch = "wasm32"))] + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + m.add_class::().unwrap(); + #[cfg(feature = "csv")] + m.add_class::().unwrap(); + #[cfg(feature = "sql")] + m.add_class::().unwrap(); + + // Submodules + // LogicalPlan objects + m.add_wrapped(wrap_pymodule!(_ir_nodes))?; + // Expr objects + m.add_wrapped(wrap_pymodule!(_expr_nodes))?; + + // Functions - eager + m.add_wrapped(wrap_pyfunction!(functions::concat_df)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_series)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_df_diagonal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_df_horizontal)) + .unwrap(); + + // Functions - range + m.add_wrapped(wrap_pyfunction!(functions::int_range)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::eager_int_range)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::int_ranges)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::linear_space)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::linear_spaces)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::date_range)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::date_ranges)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::datetime_range)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::datetime_ranges)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::time_range)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::time_ranges)) + .unwrap(); + + // Functions - business + m.add_wrapped(wrap_pyfunction!(functions::business_day_count)) + .unwrap(); + + // Functions - aggregation + m.add_wrapped(wrap_pyfunction!(functions::all_horizontal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::any_horizontal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::max_horizontal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::min_horizontal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::sum_horizontal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::mean_horizontal)) + .unwrap(); + + // Functions - lazy + m.add_wrapped(wrap_pyfunction!(functions::arg_sort_by)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::arg_where)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::as_struct)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::coalesce)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::field)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::col)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::collect_all)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::explain_all)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::collect_all_with_callback)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::cols)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_lf)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_arr)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_list)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_str)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::len)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::cov)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::cum_fold)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::cum_reduce)) + .unwrap(); + #[cfg(feature = "trigonometry")] + m.add_wrapped(wrap_pyfunction!(functions::arctan2)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::datetime)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_expr)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_lf_diagonal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::concat_lf_horizontal)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::dtype_cols)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::index_cols)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::duration)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::first)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::fold)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::last)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::lit)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::map_mul)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::nth)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::pearson_corr)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::rolling_corr)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::rolling_cov)) .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::reduce)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::repeat)).unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::spearman_rank_corr)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::when)).unwrap(); + + // Functions: other + m.add_wrapped(wrap_pyfunction!(functions::check_length)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::py_get_engine_affinity)) + .unwrap(); + + #[cfg(feature = "sql")] + m.add_wrapped(wrap_pyfunction!(functions::sql_expr)) + .unwrap(); + + // Functions - I/O + #[cfg(feature = "ipc")] + m.add_wrapped(wrap_pyfunction!(functions::read_ipc_schema)) + .unwrap(); + #[cfg(feature = "parquet")] + m.add_wrapped(wrap_pyfunction!(functions::read_parquet_schema)) + .unwrap(); + #[cfg(feature = "clipboard")] + m.add_wrapped(wrap_pyfunction!(functions::read_clipboard_string)) + .unwrap(); + #[cfg(feature = "clipboard")] + m.add_wrapped(wrap_pyfunction!(functions::write_clipboard_string)) + .unwrap(); + #[cfg(feature = "catalog")] + m.add_class::().unwrap(); + + // Functions - meta + m.add_wrapped(wrap_pyfunction!(functions::get_index_type)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::thread_pool_size)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::enable_string_cache)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::disable_string_cache)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::using_string_cache)) + .unwrap(); + + // Numeric formatting + m.add_wrapped(wrap_pyfunction!(functions::get_thousands_separator)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::set_thousands_separator)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::get_float_fmt)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::get_float_precision)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::get_decimal_separator)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::get_trim_decimal_zeros)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::set_float_fmt)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::set_float_precision)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::set_decimal_separator)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::set_trim_decimal_zeros)) + .unwrap(); + + // Functions - misc + m.add_wrapped(wrap_pyfunction!(functions::dtype_str_repr)) + .unwrap(); + #[cfg(feature = "object")] + m.add_wrapped(wrap_pyfunction!(functions::__register_startup_deps)) + .unwrap(); + + // Functions - random + m.add_wrapped(wrap_pyfunction!(functions::set_random_seed)) + .unwrap(); + + // Functions - escape_regex + m.add_wrapped(wrap_pyfunction!(functions::escape_regex)) + .unwrap(); + + // Dtype helpers + m.add_wrapped(wrap_pyfunction!(datatypes::_get_dtype_max)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(datatypes::_get_dtype_min)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(datatypes::_known_timezones)) + .unwrap(); + + // Exceptions - Errors + m.add("PolarsError", py.get_type::()) + .unwrap(); + m.add( + "ColumnNotFoundError", + py.get_type::(), + ) + .unwrap(); + m.add("ComputeError", py.get_type::()) + .unwrap(); + m.add( + "DuplicateError", + py.get_type::(), + ) + .unwrap(); + m.add( + "InvalidOperationError", + py.get_type::(), + ) + .unwrap(); + m.add("NoDataError", py.get_type::()) + .unwrap(); + m.add( + "OutOfBoundsError", + py.get_type::(), + ) + .unwrap(); + m.add( + "SQLInterfaceError", + py.get_type::(), + ) + .unwrap(); + m.add( + "SQLSyntaxError", + py.get_type::(), + ) + .unwrap(); + m.add("SchemaError", py.get_type::()) + .unwrap(); + m.add( + "SchemaFieldNotFoundError", + py.get_type::(), + ) + .unwrap(); + m.add("ShapeError", py.get_type::()) + .unwrap(); + m.add( + "StringCacheMismatchError", + py.get_type::(), + ) + .unwrap(); + m.add( + "StructFieldNotFoundError", + py.get_type::(), + ) + .unwrap(); + + // Exceptions - Warnings + m.add("PolarsWarning", py.get_type::()) + .unwrap(); + m.add( + "PerformanceWarning", + py.get_type::(), + ) + .unwrap(); + m.add( + "CategoricalRemappingWarning", + py.get_type::(), + ) + .unwrap(); + m.add( + "MapWithoutReturnDtypeWarning", + py.get_type::(), + ) + .unwrap(); + + // Exceptions - Panic + m.add( + "PanicException", + py.get_type::(), + ) + .unwrap(); + + // Cloud + #[cfg(feature = "polars_cloud")] + m.add_wrapped(wrap_pyfunction!(cloud::prepare_cloud_plan)) + .unwrap(); + #[cfg(feature = "polars_cloud")] + m.add_wrapped(wrap_pyfunction!(cloud::_execute_ir_plan_with_gpu)) + .unwrap(); + + // Build info + m.add("__version__", env!("CARGO_PKG_VERSION"))?; + + // Plugins + #[cfg(feature = "ffi_plugin")] + m.add_wrapped(wrap_pyfunction!(functions::register_plugin_function)) + .unwrap(); + + // Capsules + m.add("_allocator", create_allocator_capsule(py)?)?; + + m.add("_debug", cfg!(debug_assertions))?; + Ok(()) } diff --git a/py-polars/src/npy.rs b/py-polars/src/npy.rs deleted file mode 100644 index d84e2171c366..000000000000 --- a/py-polars/src/npy.rs +++ /dev/null @@ -1,45 +0,0 @@ -use ndarray::IntoDimension; -use numpy::{ - npyffi::{self, flags, types::npy_intp}, - ToNpyDims, PY_ARRAY_API, -}; -use numpy::{Element, PyArray1}; -use polars::prelude::*; -use pyo3::prelude::*; -use std::{mem, ptr}; - -/// Create an empty numpy array arrows 64 byte alignment -pub fn aligned_array(size: usize) -> (Py>, *mut T) { - let mut buf: Vec = Vec::with_capacity_aligned(size); - unsafe { buf.set_len(size) } - let gil = Python::acquire_gil(); - let py = gil.python(); - // modified from - // numpy-0.10.0/src/array.rs:375 - - let len = buf.len(); - let buffer_ptr = buf.as_mut_ptr(); - - let dims = [len].into_dimension(); - let strides = [mem::size_of::() as npy_intp]; - unsafe { - let ptr = PY_ARRAY_API.PyArray_New( - PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type), - dims.ndim_cint(), - dims.as_dims_ptr(), - T::ffi_dtype() as i32, - strides.as_ptr() as *mut _, // strides - buffer_ptr as _, // data - mem::size_of::() as i32, // itemsize - flags::NPY_ARRAY_OUT_ARRAY, // flag - ptr::null_mut(), //obj - ); - mem::forget(buf); - (PyArray1::from_owned_ptr(py, ptr).to_owned(), buffer_ptr) - } -} - -pub unsafe fn vec_from_ptr(ptr: usize, len: usize) -> Vec { - let ptr = ptr as *mut T; - Vec::from_raw_parts(ptr, len, len) -} diff --git a/py-polars/src/series.rs b/py-polars/src/series.rs deleted file mode 100644 index 44cc310dc558..000000000000 --- a/py-polars/src/series.rs +++ /dev/null @@ -1,882 +0,0 @@ -use crate::error::PyPolarsEr; -use crate::npy; -use numpy::PyArray1; -use polars::chunked_array::builder::get_bitmap; -use polars::prelude::*; -use pyo3::types::PyList; -use pyo3::{exceptions::RuntimeError, prelude::*, Python}; - -#[pyclass] -#[repr(transparent)] -#[derive(Clone)] -pub struct PySeries { - pub series: Series, -} - -impl PySeries { - pub(crate) fn new(series: Series) -> Self { - PySeries { series } - } -} - -// Init with numpy arrays -macro_rules! init_method { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - #[staticmethod] - pub fn $name(name: &str, val: &PyArray1<$type>) -> PySeries { - unsafe { - PySeries { - series: Series::new(name, val.as_slice().unwrap()), - } - } - } - } - }; -} - -init_method!(new_i8, i8); -init_method!(new_i16, i16); -init_method!(new_i32, i32); -init_method!(new_i64, i64); -init_method!(new_f32, f32); -init_method!(new_f64, f64); -init_method!(new_bool, bool); -init_method!(new_u8, u8); -init_method!(new_u16, u16); -init_method!(new_u32, u32); -init_method!(new_u64, u64); -init_method!(new_date32, i32); -init_method!(new_date64, i64); -init_method!(new_duration_ns, i64); -init_method!(new_time_ns, i64); - -// Init with lists that can contain Nones -macro_rules! init_method_opt { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - #[staticmethod] - pub fn $name(name: &str, val: Vec>) -> PySeries { - PySeries { - series: Series::new(name, &val), - } - } - } - }; -} - -init_method_opt!(new_opt_u8, u8); -init_method_opt!(new_opt_u16, u16); -init_method_opt!(new_opt_u32, u32); -init_method_opt!(new_opt_u64, u64); -init_method_opt!(new_opt_i8, i8); -init_method_opt!(new_opt_i16, i16); -init_method_opt!(new_opt_i32, i32); -init_method_opt!(new_opt_i64, i64); -init_method_opt!(new_opt_f32, f32); -init_method_opt!(new_opt_f64, f64); -init_method_opt!(new_opt_bool, bool); -init_method_opt!(new_opt_date32, i32); -init_method_opt!(new_opt_date64, i64); -init_method_opt!(new_opt_duration_ns, i64); -init_method_opt!(new_opt_time_ns, i64); - -#[pymethods] -impl PySeries { - #[staticmethod] - pub fn new_str(name: &str, val: Vec<&str>) -> Self { - PySeries::new(Series::new(name, &val)) - } - - #[staticmethod] - pub fn new_opt_str(name: &str, val: Vec>) -> Self { - PySeries::new(Series::new(name, &val)) - } - - pub fn rechunk(&mut self, in_place: bool) -> Option { - let series = self.series.rechunk(None).expect("should not fail"); - if in_place { - self.series = series; - None - } else { - Some(PySeries::new(series)) - } - } - - pub fn name(&self) -> &str { - self.series.name() - } - - pub fn rename(&mut self, name: &str) { - self.series.rename(name); - } - - pub fn dtype(&self) -> String { - self.series.dtype().to_str() - } - - pub fn n_chunks(&self) -> usize { - self.series.n_chunks() - } - - pub fn limit(&self, num_elements: usize) -> PyResult { - let series = self.series.limit(num_elements).map_err(PyPolarsEr::from)?; - Ok(PySeries { series }) - } - - pub fn slice(&self, offset: usize, length: usize) -> PyResult { - let series = self - .series - .slice(offset, length) - .map_err(PyPolarsEr::from)?; - Ok(PySeries { series }) - } - - pub fn append(&mut self, other: &PySeries) -> PyResult<()> { - self.series - .append(&other.series) - .map_err(PyPolarsEr::from)?; - Ok(()) - } - - pub fn filter(&self, filter: &PySeries) -> PyResult { - let filter_series = &filter.series; - if let Series::Bool(ca) = filter_series { - let series = self.series.filter(ca).map_err(PyPolarsEr::from)?; - Ok(PySeries { series }) - } else { - Err(RuntimeError::py_err("Expected a boolean mask")) - } - } - - pub fn add(&self, other: &PySeries) -> PyResult { - Ok(PySeries::new(&self.series + &other.series)) - } - - pub fn sub(&self, other: &PySeries) -> PyResult { - Ok(PySeries::new(&self.series - &other.series)) - } - - pub fn mul(&self, other: &PySeries) -> PyResult { - Ok(PySeries::new(&self.series * &other.series)) - } - - pub fn div(&self, other: &PySeries) -> PyResult { - Ok(PySeries::new(&self.series / &other.series)) - } - - pub fn head(&self, length: Option) -> PyResult { - Ok(PySeries::new(self.series.head(length))) - } - - pub fn tail(&self, length: Option) -> PyResult { - Ok(PySeries::new(self.series.tail(length))) - } - - pub fn sort_in_place(&mut self, reverse: bool) { - self.series.sort_in_place(reverse); - } - - pub fn sort(&mut self, reverse: bool) -> Self { - PySeries::new(self.series.sort(reverse)) - } - - pub fn argsort(&self, reverse: bool) -> Py> { - let gil = pyo3::Python::acquire_gil(); - let pyarray = PyArray1::from_vec(gil.python(), self.series.argsort(reverse)); - pyarray.to_owned() - } - - pub fn arg_unique(&self) -> Py> { - let gil = pyo3::Python::acquire_gil(); - let pyarray = PyArray1::from_vec(gil.python(), self.series.arg_unique()); - pyarray.to_owned() - } - - pub fn take(&self, indices: Vec) -> PyResult { - let take = self.series.take(&indices).map_err(PyPolarsEr::from)?; - Ok(PySeries::new(take)) - } - - pub fn take_with_series(&self, indices: &PySeries) -> PyResult { - let idx = indices.series.u32().map_err(PyPolarsEr::from)?; - let take = self.series.take(&idx).map_err(PyPolarsEr::from)?; - Ok(PySeries::new(take)) - } - - pub fn null_count(&self) -> PyResult { - Ok(self.series.null_count()) - } - - pub fn is_null(&self) -> PySeries { - Self::new(Series::Bool(self.series.is_null())) - } - - pub fn series_equal(&self, other: &PySeries) -> PyResult { - Ok(self.series.series_equal(&other.series)) - } - pub fn eq(&self, rhs: &PySeries) -> PyResult { - Ok(Self::new(Series::Bool(self.series.eq(&rhs.series)))) - } - - pub fn neq(&self, rhs: &PySeries) -> PyResult { - Ok(Self::new(Series::Bool(self.series.neq(&rhs.series)))) - } - - pub fn gt(&self, rhs: &PySeries) -> PyResult { - Ok(Self::new(Series::Bool(self.series.gt(&rhs.series)))) - } - - pub fn gt_eq(&self, rhs: &PySeries) -> PyResult { - Ok(Self::new(Series::Bool(self.series.gt_eq(&rhs.series)))) - } - - pub fn lt(&self, rhs: &PySeries) -> PyResult { - Ok(Self::new(Series::Bool(self.series.lt(&rhs.series)))) - } - - pub fn lt_eq(&self, rhs: &PySeries) -> PyResult { - Ok(Self::new(Series::Bool(self.series.lt_eq(&rhs.series)))) - } - - pub fn as_str(&self) -> PyResult { - Ok(format!("{:?}", self.series)) - } - - pub fn len(&self) -> usize { - self.series.len() - } - - pub fn to_list(&self) -> PyObject { - let gil = Python::acquire_gil(); - let python = gil.python(); - - let pylist = match &self.series { - Series::UInt8(ca) => PyList::new(python, ca), - Series::UInt16(ca) => PyList::new(python, ca), - Series::UInt32(ca) => PyList::new(python, ca), - Series::UInt64(ca) => PyList::new(python, ca), - Series::Int8(ca) => PyList::new(python, ca), - Series::Int16(ca) => PyList::new(python, ca), - Series::Int32(ca) => PyList::new(python, ca), - Series::Int64(ca) => PyList::new(python, ca), - Series::Float32(ca) => PyList::new(python, ca), - Series::Float64(ca) => PyList::new(python, ca), - Series::Date32(ca) => PyList::new(python, ca), - Series::Date64(ca) => PyList::new(python, ca), - Series::Time64Nanosecond(ca) => PyList::new(python, ca), - Series::DurationNanosecond(ca) => PyList::new(python, ca), - Series::Bool(ca) => PyList::new(python, ca), - Series::Utf8(ca) => PyList::new(python, ca), - _ => todo!(), - }; - pylist.to_object(python) - } - - /// Rechunk and return a pointer to the start of the Series. - /// Only implemented for numeric types - pub fn as_single_ptr(&mut self) -> usize { - self.series.as_single_ptr() - } - - pub fn fill_none(&self, strategy: &str) -> PyResult { - let strat = match strategy { - "backward" => FillNoneStrategy::Backward, - "forward" => FillNoneStrategy::Forward, - "min" => FillNoneStrategy::Min, - "max" => FillNoneStrategy::Max, - "mean" => FillNoneStrategy::Mean, - s => return Err(PyPolarsEr::Other(format!("Strategy {} not supported", s)).into()), - }; - let series = self.series.fill_none(strat).map_err(PyPolarsEr::from)?; - Ok(PySeries::new(series)) - } - - /// Attempts to copy data to numpy arrays. If integer types have missing values - /// they will be casted to floating point values, where NaNs are used to represent missing. - /// Strings will be converted to python lists and booleans will be a numpy array if there are no - /// missing values, otherwise a python list is made. - pub fn to_numpy(&self) -> PyObject { - let gil = Python::acquire_gil(); - let py = gil.python(); - let series = &self.series; - - // if has null values we use floats and np.nan to represent missing values - macro_rules! impl_to_np_array { - ($ca:ident, $float_type:ty) => {{ - match $ca.cont_slice() { - Ok(slice) => PyArray1::from_slice(py, slice).to_object(py), - Err(_) => { - if $ca.null_count() == 0 { - let v = $ca.into_no_null_iter().collect::>(); - PyArray1::from_vec(py, v).to_object(py) - } else { - let v = $ca - .into_iter() - .map(|opt_v| match opt_v { - Some(v) => v as $float_type, - None => <$float_type>::NAN, - }) - .collect::>(); - PyArray1::from_vec(py, v).to_object(py) - } - } - } - }}; - } - - match series { - Series::UInt8(ca) => impl_to_np_array!(ca, f32), - Series::UInt16(ca) => impl_to_np_array!(ca, f32), - Series::UInt32(ca) => impl_to_np_array!(ca, f32), - Series::UInt64(ca) => impl_to_np_array!(ca, f64), - Series::Int8(ca) => impl_to_np_array!(ca, f32), - Series::Int16(ca) => impl_to_np_array!(ca, f32), - Series::Int32(ca) => impl_to_np_array!(ca, f32), - Series::Int64(ca) => impl_to_np_array!(ca, f64), - Series::Float32(ca) => impl_to_np_array!(ca, f32), - Series::Float64(ca) => impl_to_np_array!(ca, f64), - Series::Date32(ca) => impl_to_np_array!(ca, f32), - Series::Date64(ca) => impl_to_np_array!(ca, f64), - Series::Bool(ca) => { - if ca.null_count() == 0 { - let v = ca.into_no_null_iter().collect::>(); - PyArray1::from_vec(py, v).to_object(py) - } else { - self.to_list() - } - } - Series::Utf8(_) => self.to_list(), - _ => todo!(), - } - } - - pub fn clone(&self) -> Self { - PySeries::new(self.series.clone()) - } -} - -macro_rules! impl_set_with_mask { - ($name:ident, $native:ty, $cast:ident, $variant:ident) => { - fn $name(series: &Series, filter: &PySeries, value: Option<$native>) -> Result { - let mask = filter.series.bool()?; - let ca = series.$cast()?; - let new = ca.set(mask, value)?; - Ok(Series::$variant(new)) - } - - #[pymethods] - impl PySeries { - pub fn $name(&self, filter: &PySeries, value: Option<$native>) -> PyResult { - let series = $name(&self.series, filter, value).map_err(PyPolarsEr::from)?; - Ok(Self::new(series)) - } - } - }; -} - -impl_set_with_mask!(set_with_mask_str, &str, utf8, Utf8); -impl_set_with_mask!(set_with_mask_f64, f64, f64, Float64); -impl_set_with_mask!(set_with_mask_f32, f32, f32, Float32); -impl_set_with_mask!(set_with_mask_u8, u8, u8, UInt8); -impl_set_with_mask!(set_with_mask_u16, u16, u16, UInt16); -impl_set_with_mask!(set_with_mask_u32, u32, u32, UInt32); -impl_set_with_mask!(set_with_mask_u64, u64, u64, UInt64); -impl_set_with_mask!(set_with_mask_i8, i8, i8, Int8); -impl_set_with_mask!(set_with_mask_i16, i16, i16, Int16); -impl_set_with_mask!(set_with_mask_i32, i32, i32, Int32); -impl_set_with_mask!(set_with_mask_i64, i64, i64, Int64); - -macro_rules! impl_set_at_idx { - ($name:ident, $native:ty, $cast:ident, $variant:ident) => { - fn $name(series: &Series, idx: &[usize], value: Option<$native>) -> Result { - let ca = series.$cast()?; - let new = ca.set_at_idx(&idx, value)?; - Ok(Series::$variant(new)) - } - - #[pymethods] - impl PySeries { - pub fn $name(&self, idx: &PyArray1, value: Option<$native>) -> PyResult { - let idx = unsafe { idx.as_slice().unwrap() }; - let series = $name(&self.series, &idx, value).map_err(PyPolarsEr::from)?; - Ok(Self::new(series)) - } - } - }; -} - -impl_set_at_idx!(set_at_idx_str, &str, utf8, Utf8); -impl_set_at_idx!(set_at_idx_f64, f64, f64, Float64); -impl_set_at_idx!(set_at_idx_f32, f32, f32, Float32); -impl_set_at_idx!(set_at_idx_u8, u8, u8, UInt8); -impl_set_at_idx!(set_at_idx_u16, u16, u16, UInt16); -impl_set_at_idx!(set_at_idx_u32, u32, u32, UInt32); -impl_set_at_idx!(set_at_idx_u64, u64, u64, UInt64); -impl_set_at_idx!(set_at_idx_i8, i8, i8, Int8); -impl_set_at_idx!(set_at_idx_i16, i16, i16, Int16); -impl_set_at_idx!(set_at_idx_i32, i32, i32, Int32); -impl_set_at_idx!(set_at_idx_i64, i64, i64, Int64); - -macro_rules! impl_get { - ($name:ident, $series_variant:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, index: usize) -> Option<$type> { - if let Series::$series_variant(ca) = &self.series { - ca.get(index) - } else { - None - } - } - } - }; -} - -impl_get!(get_f32, Float32, f32); -impl_get!(get_f64, Float64, f64); -impl_get!(get_u8, UInt8, u8); -impl_get!(get_u16, UInt16, u16); -impl_get!(get_u32, UInt32, u32); -impl_get!(get_u64, UInt64, u64); -impl_get!(get_i8, Int8, i8); -impl_get!(get_i16, Int16, i16); -impl_get!(get_i32, Int32, i32); -impl_get!(get_i64, Int64, i64); - -macro_rules! impl_unsafe_from_ptr { - ($name:ident, $series_variant:ident) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, ptr: usize, len: usize) -> Self { - let v = unsafe { npy::vec_from_ptr(ptr, len) }; - let av = AlignedVec::new(v).unwrap(); - let (null_count, null_bitmap) = get_bitmap(self.series.chunks()[0].as_ref()); - let ca = ChunkedArray::new_from_owned_with_null_bitmap( - self.name(), - av, - null_bitmap, - null_count, - ); - Self::new(Series::$series_variant(ca)) - } - } - }; -} - -impl_unsafe_from_ptr!(unsafe_from_ptr_f32, Float32); -impl_unsafe_from_ptr!(unsafe_from_ptr_f64, Float64); -impl_unsafe_from_ptr!(unsafe_from_ptr_u8, UInt8); -impl_unsafe_from_ptr!(unsafe_from_ptr_u16, UInt16); -impl_unsafe_from_ptr!(unsafe_from_ptr_u32, UInt32); -impl_unsafe_from_ptr!(unsafe_from_ptr_u64, UInt64); -impl_unsafe_from_ptr!(unsafe_from_ptr_i8, Int8); -impl_unsafe_from_ptr!(unsafe_from_ptr_i16, Int16); -impl_unsafe_from_ptr!(unsafe_from_ptr_i32, Int32); -impl_unsafe_from_ptr!(unsafe_from_ptr_i64, Int64); - -macro_rules! impl_cast { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self) -> PyResult { - let s = self.series.cast::<$type>().map_err(PyPolarsEr::from)?; - Ok(PySeries::new(s)) - } - } - }; -} - -impl_cast!(cast_u8, UInt8Type); -impl_cast!(cast_u16, UInt16Type); -impl_cast!(cast_u32, UInt32Type); -impl_cast!(cast_u64, UInt64Type); -impl_cast!(cast_i8, Int8Type); -impl_cast!(cast_i16, Int16Type); -impl_cast!(cast_i32, Int32Type); -impl_cast!(cast_i64, Int64Type); -impl_cast!(cast_f32, Float32Type); -impl_cast!(cast_f64, Float64Type); -impl_cast!(cast_date32, Date32Type); -impl_cast!(cast_date64, Date64Type); -impl_cast!(cast_time64ns, Time64NanosecondType); -impl_cast!(cast_duration_ns, DurationNanosecondType); - -macro_rules! impl_arithmetic { - ($name:ident, $type:ty, $operand:tt) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, other: $type) -> PyResult { - Ok(PySeries::new(&self.series $operand other)) - } - } - }; -} - -impl_arithmetic!(add_u8, u8, +); -impl_arithmetic!(add_u16, u16, +); -impl_arithmetic!(add_u32, u32, +); -impl_arithmetic!(add_u64, u64, +); -impl_arithmetic!(add_i8, i8, +); -impl_arithmetic!(add_i16, i16, +); -impl_arithmetic!(add_i32, i32, +); -impl_arithmetic!(add_i64, i64, +); -impl_arithmetic!(add_f32, f32, +); -impl_arithmetic!(add_f64, f64, +); -impl_arithmetic!(sub_u8, u8, -); -impl_arithmetic!(sub_u16, u16, -); -impl_arithmetic!(sub_u32, u32, -); -impl_arithmetic!(sub_u64, u64, -); -impl_arithmetic!(sub_i8, i8, -); -impl_arithmetic!(sub_i16, i16, -); -impl_arithmetic!(sub_i32, i32, -); -impl_arithmetic!(sub_i64, i64, -); -impl_arithmetic!(sub_f32, f32, -); -impl_arithmetic!(sub_f64, f64, -); -impl_arithmetic!(div_u8, u8, /); -impl_arithmetic!(div_u16, u16, /); -impl_arithmetic!(div_u32, u32, /); -impl_arithmetic!(div_u64, u64, /); -impl_arithmetic!(div_i8, i8, /); -impl_arithmetic!(div_i16, i16, /); -impl_arithmetic!(div_i32, i32, /); -impl_arithmetic!(div_i64, i64, /); -impl_arithmetic!(div_f32, f32, /); -impl_arithmetic!(div_f64, f64, /); -impl_arithmetic!(mul_u8, u8, *); -impl_arithmetic!(mul_u16, u16, *); -impl_arithmetic!(mul_u32, u32, *); -impl_arithmetic!(mul_u64, u64, *); -impl_arithmetic!(mul_i8, i8, *); -impl_arithmetic!(mul_i16, i16, *); -impl_arithmetic!(mul_i32, i32, *); -impl_arithmetic!(mul_i64, i64, *); -impl_arithmetic!(mul_f32, f32, *); -impl_arithmetic!(mul_f64, f64, *); - -macro_rules! impl_rhs_arithmetic { - ($name:ident, $type:ty, $operand:ident) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, other: $type) -> PyResult { - Ok(PySeries::new(other.$operand(&self.series))) - } - } - }; -} - -impl_rhs_arithmetic!(add_u8_rhs, u8, add); -impl_rhs_arithmetic!(add_u16_rhs, u16, add); -impl_rhs_arithmetic!(add_u32_rhs, u32, add); -impl_rhs_arithmetic!(add_u64_rhs, u64, add); -impl_rhs_arithmetic!(add_i8_rhs, i8, add); -impl_rhs_arithmetic!(add_i16_rhs, i16, add); -impl_rhs_arithmetic!(add_i32_rhs, i32, add); -impl_rhs_arithmetic!(add_i64_rhs, i64, add); -impl_rhs_arithmetic!(add_f32_rhs, f32, add); -impl_rhs_arithmetic!(add_f64_rhs, f64, add); -impl_rhs_arithmetic!(sub_u8_rhs, u8, sub); -impl_rhs_arithmetic!(sub_u16_rhs, u16, sub); -impl_rhs_arithmetic!(sub_u32_rhs, u32, sub); -impl_rhs_arithmetic!(sub_u64_rhs, u64, sub); -impl_rhs_arithmetic!(sub_i8_rhs, i8, sub); -impl_rhs_arithmetic!(sub_i16_rhs, i16, sub); -impl_rhs_arithmetic!(sub_i32_rhs, i32, sub); -impl_rhs_arithmetic!(sub_i64_rhs, i64, sub); -impl_rhs_arithmetic!(sub_f32_rhs, f32, sub); -impl_rhs_arithmetic!(sub_f64_rhs, f64, sub); -impl_rhs_arithmetic!(div_u8_rhs, u8, div); -impl_rhs_arithmetic!(div_u16_rhs, u16, div); -impl_rhs_arithmetic!(div_u32_rhs, u32, div); -impl_rhs_arithmetic!(div_u64_rhs, u64, div); -impl_rhs_arithmetic!(div_i8_rhs, i8, div); -impl_rhs_arithmetic!(div_i16_rhs, i16, div); -impl_rhs_arithmetic!(div_i32_rhs, i32, div); -impl_rhs_arithmetic!(div_i64_rhs, i64, div); -impl_rhs_arithmetic!(div_f32_rhs, f32, div); -impl_rhs_arithmetic!(div_f64_rhs, f64, div); -impl_rhs_arithmetic!(mul_u8_rhs, u8, mul); -impl_rhs_arithmetic!(mul_u16_rhs, u16, mul); -impl_rhs_arithmetic!(mul_u32_rhs, u32, mul); -impl_rhs_arithmetic!(mul_u64_rhs, u64, mul); -impl_rhs_arithmetic!(mul_i8_rhs, i8, mul); -impl_rhs_arithmetic!(mul_i16_rhs, i16, mul); -impl_rhs_arithmetic!(mul_i32_rhs, i32, mul); -impl_rhs_arithmetic!(mul_i64_rhs, i64, mul); -impl_rhs_arithmetic!(mul_f32_rhs, f32, mul); -impl_rhs_arithmetic!(mul_f64_rhs, f64, mul); - -macro_rules! impl_sum { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self) -> PyResult> { - Ok(self.series.sum()) - } - } - }; -} - -impl_sum!(sum_u8, u8); -impl_sum!(sum_u16, u16); -impl_sum!(sum_u32, u32); -impl_sum!(sum_u64, u64); -impl_sum!(sum_i8, i8); -impl_sum!(sum_i16, i16); -impl_sum!(sum_i32, i32); -impl_sum!(sum_i64, i64); -impl_sum!(sum_f32, f32); -impl_sum!(sum_f64, f64); - -macro_rules! impl_min { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self) -> PyResult> { - Ok(self.series.min()) - } - } - }; -} - -impl_min!(min_u8, u8); -impl_min!(min_u16, u16); -impl_min!(min_u32, u32); -impl_min!(min_u64, u64); -impl_min!(min_i8, i8); -impl_min!(min_i16, i16); -impl_min!(min_i32, i32); -impl_min!(min_i64, i64); -impl_min!(min_f32, f32); -impl_min!(min_f64, f64); - -macro_rules! impl_max { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self) -> PyResult> { - Ok(self.series.max()) - } - } - }; -} - -impl_max!(max_u8, u8); -impl_max!(max_u16, u16); -impl_max!(max_u32, u32); -impl_max!(max_u64, u64); -impl_max!(max_i8, i8); -impl_max!(max_i16, i16); -impl_max!(max_i32, i32); -impl_max!(max_i64, i64); -impl_max!(max_f32, f32); -impl_max!(max_f64, f64); - -macro_rules! impl_mean { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self) -> PyResult> { - Ok(self.series.mean()) - } - } - }; -} - -impl_mean!(mean_u8, u8); -impl_mean!(mean_u16, u16); -impl_mean!(mean_u32, u32); -impl_mean!(mean_u64, u64); -impl_mean!(mean_i8, i8); -impl_mean!(mean_i16, i16); -impl_mean!(mean_i32, i32); -impl_mean!(mean_i64, i64); -impl_mean!(mean_f32, f32); -impl_mean!(mean_f64, f64); - -macro_rules! impl_eq_num { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, rhs: $type) -> PyResult { - Ok(PySeries::new(Series::Bool(self.series.eq(rhs)))) - } - } - }; -} - -impl_eq_num!(eq_u8, u8); -impl_eq_num!(eq_u16, u16); -impl_eq_num!(eq_u32, u32); -impl_eq_num!(eq_u64, u64); -impl_eq_num!(eq_i8, i8); -impl_eq_num!(eq_i16, i16); -impl_eq_num!(eq_i32, i32); -impl_eq_num!(eq_i64, i64); -impl_eq_num!(eq_f32, f32); -impl_eq_num!(eq_f64, f64); -impl_eq_num!(eq_str, &str); - -macro_rules! impl_neq_num { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, rhs: $type) -> PyResult { - Ok(PySeries::new(Series::Bool(self.series.neq(rhs)))) - } - } - }; -} - -impl_neq_num!(neq_u8, u8); -impl_neq_num!(neq_u16, u16); -impl_neq_num!(neq_u32, u32); -impl_neq_num!(neq_u64, u64); -impl_neq_num!(neq_i8, i8); -impl_neq_num!(neq_i16, i16); -impl_neq_num!(neq_i32, i32); -impl_neq_num!(neq_i64, i64); -impl_neq_num!(neq_f32, f32); -impl_neq_num!(neq_f64, f64); -impl_neq_num!(neq_str, &str); - -macro_rules! impl_gt_num { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, rhs: $type) -> PyResult { - Ok(PySeries::new(Series::Bool(self.series.gt(rhs)))) - } - } - }; -} - -impl_gt_num!(gt_u8, u8); -impl_gt_num!(gt_u16, u16); -impl_gt_num!(gt_u32, u32); -impl_gt_num!(gt_u64, u64); -impl_gt_num!(gt_i8, i8); -impl_gt_num!(gt_i16, i16); -impl_gt_num!(gt_i32, i32); -impl_gt_num!(gt_i64, i64); -impl_gt_num!(gt_f32, f32); -impl_gt_num!(gt_f64, f64); -impl_gt_num!(gt_str, &str); - -macro_rules! impl_gt_eq_num { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, rhs: $type) -> PyResult { - Ok(PySeries::new(Series::Bool(self.series.gt_eq(rhs)))) - } - } - }; -} - -impl_gt_eq_num!(gt_eq_u8, u8); -impl_gt_eq_num!(gt_eq_u16, u16); -impl_gt_eq_num!(gt_eq_u32, u32); -impl_gt_eq_num!(gt_eq_u64, u64); -impl_gt_eq_num!(gt_eq_i8, i8); -impl_gt_eq_num!(gt_eq_i16, i16); -impl_gt_eq_num!(gt_eq_i32, i32); -impl_gt_eq_num!(gt_eq_i64, i64); -impl_gt_eq_num!(gt_eq_f32, f32); -impl_gt_eq_num!(gt_eq_f64, f64); -impl_gt_eq_num!(gt_eq_str, &str); - -macro_rules! impl_lt_num { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, rhs: $type) -> PyResult { - Ok(PySeries::new(Series::Bool(self.series.lt(rhs)))) - } - } - }; -} - -impl_lt_num!(lt_u8, u8); -impl_lt_num!(lt_u16, u16); -impl_lt_num!(lt_u32, u32); -impl_lt_num!(lt_u64, u64); -impl_lt_num!(lt_i8, i8); -impl_lt_num!(lt_i16, i16); -impl_lt_num!(lt_i32, i32); -impl_lt_num!(lt_i64, i64); -impl_lt_num!(lt_f32, f32); -impl_lt_num!(lt_f64, f64); -impl_lt_num!(lt_str, &str); - -macro_rules! impl_lt_eq_num { - ($name:ident, $type:ty) => { - #[pymethods] - impl PySeries { - pub fn $name(&self, rhs: $type) -> PyResult { - Ok(PySeries::new(Series::Bool(self.series.lt_eq(rhs)))) - } - } - }; -} - -impl_lt_eq_num!(lt_eq_u8, u8); -impl_lt_eq_num!(lt_eq_u16, u16); -impl_lt_eq_num!(lt_eq_u32, u32); -impl_lt_eq_num!(lt_eq_u64, u64); -impl_lt_eq_num!(lt_eq_i8, i8); -impl_lt_eq_num!(lt_eq_i16, i16); -impl_lt_eq_num!(lt_eq_i32, i32); -impl_lt_eq_num!(lt_eq_i64, i64); -impl_lt_eq_num!(lt_eq_f32, f32); -impl_lt_eq_num!(lt_eq_f64, f64); -impl_lt_eq_num!(lt_eq_str, &str); - -pub(crate) fn to_series_collection(ps: Vec) -> Vec { - // prevent destruction of ps - let mut ps = std::mem::ManuallyDrop::new(ps); - - // get mutable pointer and reinterpret as Series - let p = ps.as_mut_ptr() as *mut Series; - let len = ps.len(); - let cap = ps.capacity(); - - // The pointer ownership will be transferred to Vec and this will be responsible for dealoc - unsafe { Vec::from_raw_parts(p, len, cap) } -} - -pub(crate) fn to_pyseries_collection(s: Vec) -> Vec { - let mut s = std::mem::ManuallyDrop::new(s); - - let p = s.as_mut_ptr() as *mut PySeries; - let len = s.len(); - let cap = s.capacity(); - - unsafe { Vec::from_raw_parts(p, len, cap) } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn transmute_to_series() { - // NOTE: This is only possible because PySeries is #[repr(transparent)] - // https://doc.rust-lang.org/reference/type-layout.html - let ps = PySeries { - series: [1i32, 2, 3].iter().collect(), - }; - - let s = unsafe { std::mem::transmute::(ps.clone()) }; - - assert_eq!(s.sum::(), Some(6)); - let collection = vec![ps]; - let s = to_series_collection(collection); - assert_eq!( - s.iter().map(|s| s.sum::()).collect::>(), - vec![Some(6)] - ); - } -} diff --git a/py-polars/tasks.sh b/py-polars/tasks.sh deleted file mode 100755 index 3cd87e3fbb90..000000000000 --- a/py-polars/tasks.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env bash - -function test-build { - maturin build -o wheels -i $(which python) -} - -function test-install { - test-build - cd wheels && pip install -U py* && cd .. -} - -function release-build { - maturin build --release -o wheels -i $(which python) -} - -function release-install { - release-build - cd wheels && pip install -U py* && cd .. -} - -function build-run-tests { - test-install - pytest tests -} - -function build-docker-base { - cd .. - docker build -f py-polars/Dockerfile_base -t ritchie46/py-polars-base . -} - -function build-and-push-base { - build-docker-base - docker push ritchie46/py-polars-base -} - -function build-docker { - docker build -t ritchie46/py-polars . -} - -function build-and-push { - build-docker - docker push ritchie46/py-polars -} - -"$@" diff --git a/py-polars/tests/__init__.py b/py-polars/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/benchmark/__init__.py b/py-polars/tests/benchmark/__init__.py new file mode 100644 index 000000000000..643565f66680 --- /dev/null +++ b/py-polars/tests/benchmark/__init__.py @@ -0,0 +1,9 @@ +""" +Benchmark tests. + +These tests are skipped by default as a relatively large dataset must be generated +first. + +See the documentation on how to run these tests: +https://docs.pola.rs/development/contributing/test/#benchmark-tests +""" diff --git a/py-polars/tests/benchmark/conftest.py b/py-polars/tests/benchmark/conftest.py new file mode 100644 index 000000000000..da210cb64125 --- /dev/null +++ b/py-polars/tests/benchmark/conftest.py @@ -0,0 +1,9 @@ +import pytest + +import polars as pl +from tests.benchmark.data import generate_group_by_data + + +@pytest.fixture(scope="session") +def groupby_data() -> pl.DataFrame: + return generate_group_by_data(10_000, 100, null_ratio=0.05) diff --git a/py-polars/tests/benchmark/data/__init__.py b/py-polars/tests/benchmark/data/__init__.py new file mode 100644 index 000000000000..84b7ef3683ce --- /dev/null +++ b/py-polars/tests/benchmark/data/__init__.py @@ -0,0 +1,9 @@ +"""Data generation functionality for use in the benchmarking suite.""" + +from tests.benchmark.data.h2oai import generate_group_by_data +from tests.benchmark.data.pdsh import load_pdsh_table + +__all__ = [ + "generate_group_by_data", + "load_pdsh_table", +] diff --git a/py-polars/tests/benchmark/data/h2oai/__init__.py b/py-polars/tests/benchmark/data/h2oai/__init__.py new file mode 100644 index 000000000000..3fd84ef8314e --- /dev/null +++ b/py-polars/tests/benchmark/data/h2oai/__init__.py @@ -0,0 +1,5 @@ +"""Data generation functionality for use in the test suite.""" + +from tests.benchmark.data.h2oai.datagen_groupby import generate_group_by_data + +__all__ = ["generate_group_by_data"] diff --git a/py-polars/tests/benchmark/data/h2oai/datagen_groupby.py b/py-polars/tests/benchmark/data/h2oai/datagen_groupby.py new file mode 100644 index 000000000000..782885cf5ec2 --- /dev/null +++ b/py-polars/tests/benchmark/data/h2oai/datagen_groupby.py @@ -0,0 +1,193 @@ +""" +Script to generate data for benchmarking group-by operations. + +Data generation logic was adapted from the H2O.ai benchmark. +The original R script is located here: +https://github.com/h2oai/db-benchmark/blob/master/_data/groupby-datagen.R + +Examples +-------- +10 million rows, 100 groups, no nulls, random order: +$ python datagen_groupby.py 1e7 1e2 --null-percentage 0 + +100 million rows, 10 groups, 5% nulls, sorted: +$ python datagen_groupby.py 1e8 1e1 --null-percentage 5 --sorted +""" + +import argparse +import logging + +import numpy as np +from numpy.random import default_rng + +import polars as pl + +logging.basicConfig(level=logging.INFO) + +SEED = 0 +rng = default_rng(seed=SEED) + +__all__ = ["generate_group_by_data"] + + +def generate_group_by_data( + n_rows: int, n_groups: int, null_ratio: float = 0.0, *, sort: bool = False +) -> pl.DataFrame: + """Generate data for benchmarking group-by operations.""" + logging.info("Generating data...") + df = _generate_data(n_rows, n_groups) + + if null_ratio > 0.0: + logging.info("Setting nulls...") + df = _set_nulls(df, null_ratio) + + if sort: + logging.info("Sorting data...") + df = df.sort(c for c in df.columns if c.startswith("id")) + + logging.info("Done generating data.") + return df + + +def _generate_data(n_rows: int, n_groups: int) -> pl.DataFrame: + N = n_rows + K = n_groups + + group_str_small = [f"id{str(i).zfill(3)}" for i in range(1, K + 1)] + group_str_large = [f"id{str(i).zfill(10)}" for i in range(1, int(N / K) + 1)] + group_int_small = range(1, K + 1) + group_int_large = range(1, int(N / K) + 1) + + var_int_small = range(1, 6) + var_int_large = range(1, 16) + var_float = rng.uniform(0, 100, N) + + return pl.DataFrame( + { + "id1": rng.choice(group_str_small, N), + "id2": rng.choice(group_str_small, N), + "id3": rng.choice(group_str_large, N), + "id4": rng.choice(group_int_small, N), + "id5": rng.choice(group_int_small, N), + "id6": rng.choice(group_int_large, N), + "v1": rng.choice(var_int_small, N), + "v2": rng.choice(var_int_large, N), + "v3": np.round(var_float, 6), + }, + schema={ + "id1": pl.String, + "id2": pl.String, + "id3": pl.String, + "id4": pl.Int32, + "id5": pl.Int32, + "id6": pl.Int32, + "v1": pl.Int32, + "v2": pl.Int32, + "v3": pl.Float64, + }, + ) + + +def _set_nulls(df: pl.DataFrame, null_ratio: float) -> pl.DataFrame: + """Set null values according to the given ratio.""" + + def set_nulls_var(s: pl.Series, ratio: float) -> pl.Series: + """Set Series values to null according to the given ratio.""" + len = s.len() + n_null = int(ratio * len) + if n_null == 0: + return s + + indices = rng.choice(len, size=n_null, replace=False) + return s.scatter(indices, None) + + def set_nulls_group(s: pl.Series, ratio: float) -> pl.Series: + """Set Series unique values to null according to the given ratio.""" + uniques = s.unique() + n_null = int(ratio * uniques.len()) + if n_null == 0: + return s + + to_replace = rng.choice(uniques, size=n_null, replace=False) + return ( + s.to_frame() + .select( + pl.when(pl.col(s.name).is_in(to_replace)) + .then(None) + .otherwise(pl.col(s.name)) + .alias(s.name) + ) + .to_series() + ) + + return df.with_columns( + set_nulls_group(s, null_ratio) + if s.name.startswith("id") + else set_nulls_var(s, null_ratio) + for s in df.get_columns() + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate data for benchmarking group-by operations" + ) + + parser.add_argument("rows", type=float, help="Number of rows") + parser.add_argument("groups", type=float, help="Number of groups") + parser.add_argument( + "-n", + "--null-percentage", + type=int, + default=0, + choices=range(1, 101), + metavar="[0-100]", + help="Percentage of null values", + ) + parser.add_argument( + "-s", + "--sort", + action="store_true", + help="Sort the data by group", + ) + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Output filename", + ) + + args = parser.parse_args() + + # Convert arguments to appropriate types + n_rows = int(args.rows) + n_groups = int(args.groups) + null_ratio = args.null_percentage / 100 + sort = args.sort + + logging.info( + f"Generating data: {n_rows} rows, {n_groups} groups, {null_ratio} null ratio, sorted: {args.sort}" + ) + + df = generate_group_by_data(n_rows, n_groups, null_ratio=null_ratio, sort=sort) + write_data(df, args) + + +def write_data(df: pl.DataFrame, args: argparse.Namespace) -> None: + def format_int(i: int) -> str: + base, exp = f"{i:e}".split("e") + return f"{float(base):g}e{int(exp)}" + + if args.output is not None: + filename = args.output + else: + filename = f"G1_{format_int(args.rows)}_{format_int(args.groups)}_{args.null_percentage}_{int(args.sort)}.csv" + + logging.info(f"Writing data to {filename}") + df.write_csv(filename) + logging.info("Done writing data.") + + +if __name__ == "__main__": + main() diff --git a/py-polars/tests/benchmark/data/pdsh/__init__.py b/py-polars/tests/benchmark/data/pdsh/__init__.py new file mode 100644 index 000000000000..ef007f5ed8d9 --- /dev/null +++ b/py-polars/tests/benchmark/data/pdsh/__init__.py @@ -0,0 +1,5 @@ +"""Generate data for the PDS-H benchmark tests.""" + +from tests.benchmark.data.pdsh.generate_data import load_pdsh_table + +__all__ = ["load_pdsh_table"] diff --git a/py-polars/tests/benchmark/data/pdsh/dbgen/dbgen b/py-polars/tests/benchmark/data/pdsh/dbgen/dbgen new file mode 100755 index 000000000000..63ecbb8616e5 Binary files /dev/null and b/py-polars/tests/benchmark/data/pdsh/dbgen/dbgen differ diff --git a/py-polars/tests/benchmark/data/pdsh/dbgen/dists.dss b/py-polars/tests/benchmark/data/pdsh/dbgen/dists.dss new file mode 100644 index 000000000000..895a2ae9a12f --- /dev/null +++ b/py-polars/tests/benchmark/data/pdsh/dbgen/dists.dss @@ -0,0 +1,836 @@ +# +# $Id: dists.dss,v 1.2 2005/01/03 20:08:58 jms Exp $ +# +# Revision History +# =================== +# $Log: dists.dss,v $ +# Revision 1.2 2005/01/03 20:08:58 jms +# change line terminations +# +# Revision 1.1.1.1 2004/11/24 23:31:46 jms +# re-establish external server +# +# Revision 1.1.1.1 2003/04/03 18:54:21 jms +# recreation after CVS crash +# +# Revision 1.1.1.1 2003/04/03 18:54:21 jms +# initial checkin +# +# +# +# +# distributions have the following format: +# +# | # comment +# +# Distributions are used to bias the selection of a token +# based on its associated weight. The list of tokens and values +# between the keywords BEGIN and END define the distribution named after +# the BEGIN. A uniformly random value from [0, sum(weights)] +# will be chosen and the first token whose cumulative weight is greater than +# or equal to the result will be returned. In essence, the weights for each +# token represent its relative weight within a distribution. +# +# one special token is defined: count (number of data points in the +# distribution). It MUST be defined for each named distribution. +#----------------------------------------------------------------------- +# currently defined distributions and their use: +# NAME FIELD/NOTES +# ======== ============== +# category parts.category +# container parts.container +# instruct shipping instructions +# msegmnt market segment +# names parts.name +# nations must be ordered along with regions +# nations2 stand alone nations set for use with qgen +# o_prio order priority +# regions must be ordered along with nations +# rflag lineitems.returnflag +# types parts.type +# colors embedded string creation; CANNOT BE USED FOR pick_str(), agg_str() perturbs order +# articles comment generation +# nouns +# verbs +# adverbs +# auxillaries +# prepositions +# terminators +# grammar sentence formation +# np +# vp +### +# category +### +BEGIN category +COUNT|5 +FURNITURE|1 +STORAGE EQUIP|1 +TOOLS|1 +MACHINE TOOLS|1 +OTHER|1 +END category +### +# container +### +begin p_cntr +count|40 +SM CASE|1 +SM BOX|1 +SM BAG|1 +SM JAR|1 +SM PACK|1 +SM PKG|1 +SM CAN|1 +SM DRUM|1 +LG CASE|1 +LG BOX|1 +LG BAG|1 +LG JAR|1 +LG PACK|1 +LG PKG|1 +LG CAN|1 +LG DRUM|1 +MED CASE|1 +MED BOX|1 +MED BAG|1 +MED JAR|1 +MED PACK|1 +MED PKG|1 +MED CAN|1 +MED DRUM|1 +JUMBO CASE|1 +JUMBO BOX|1 +JUMBO BAG|1 +JUMBO JAR|1 +JUMBO PACK|1 +JUMBO PKG|1 +JUMBO CAN|1 +JUMBO DRUM|1 +WRAP CASE|1 +WRAP BOX|1 +WRAP BAG|1 +WRAP JAR|1 +WRAP PACK|1 +WRAP PKG|1 +WRAP CAN|1 +WRAP DRUM|1 +end p_cntr +### +# instruct +### +begin instruct +count|4 +DELIVER IN PERSON|1 +COLLECT COD|1 +TAKE BACK RETURN|1 +NONE|1 +end instruct +### +# msegmnt +### +begin msegmnt +count|5 +AUTOMOBILE|1 +BUILDING|1 +FURNITURE|1 +HOUSEHOLD|1 +MACHINERY|1 +end msegmnt +### +# names +### +begin p_names +COUNT|4 +CLEANER|1 +SOAP|1 +DETERGENT|1 +EXTRA|1 +end p_names +### +# nations +# NOTE: this is a special case; the weights here are adjustments to +# map correctly into the regions table, and are *NOT* cummulative +# values to mimic a distribution +### +begin nations +count|25 +ALGERIA|0 +ARGENTINA|1 +BRAZIL|0 +CANADA|0 +EGYPT|3 +ETHIOPIA|-4 +FRANCE|3 +GERMANY|0 +INDIA|-1 +INDONESIA|0 +IRAN|2 +IRAQ|0 +JAPAN|-2 +JORDAN|2 +KENYA|-4 +MOROCCO|0 +MOZAMBIQUE|0 +PERU|1 +CHINA|1 +ROMANIA|1 +SAUDI ARABIA|1 +VIETNAM|-2 +RUSSIA|1 +UNITED KINGDOM|0 +UNITED STATES|-2 +end nations +### +# nations2 +### +begin nations2 +count|25 +ALGERIA|1 +ARGENTINA|1 +BRAZIL|1 +CANADA|1 +EGYPT|1 +ETHIOPIA|1 +FRANCE|1 +GERMANY|1 +INDIA|1 +INDONESIA|1 +IRAN|1 +IRAQ|1 +JAPAN|1 +JORDAN|1 +KENYA|1 +MOROCCO|1 +MOZAMBIQUE|1 +PERU|1 +CHINA|1 +ROMANIA|1 +SAUDI ARABIA|1 +VIETNAM|1 +RUSSIA|1 +UNITED KINGDOM|1 +UNITED STATES|1 +end nations2 +### +# regions +### +begin regions +count|5 +AFRICA|1 +AMERICA|1 +ASIA|1 +EUROPE|1 +MIDDLE EAST|1 +end regions +### +# o_prio +### +begin o_oprio +count|5 +1-URGENT|1 +2-HIGH|1 +3-MEDIUM|1 +4-NOT SPECIFIED|1 +5-LOW|1 +end o_oprio +### +# rflag +### +begin rflag +count|2 +R|1 +A|1 +end rflag +### +# smode +### +begin smode +count|7 +REG AIR|1 +AIR|1 +RAIL|1 +TRUCK|1 +MAIL|1 +FOB|1 +SHIP|1 +end smode +### +# types +### +begin p_types +COUNT|150 +STANDARD ANODIZED TIN|1 +STANDARD ANODIZED NICKEL|1 +STANDARD ANODIZED BRASS|1 +STANDARD ANODIZED STEEL|1 +STANDARD ANODIZED COPPER|1 +STANDARD BURNISHED TIN|1 +STANDARD BURNISHED NICKEL|1 +STANDARD BURNISHED BRASS|1 +STANDARD BURNISHED STEEL|1 +STANDARD BURNISHED COPPER|1 +STANDARD PLATED TIN|1 +STANDARD PLATED NICKEL|1 +STANDARD PLATED BRASS|1 +STANDARD PLATED STEEL|1 +STANDARD PLATED COPPER|1 +STANDARD POLISHED TIN|1 +STANDARD POLISHED NICKEL|1 +STANDARD POLISHED BRASS|1 +STANDARD POLISHED STEEL|1 +STANDARD POLISHED COPPER|1 +STANDARD BRUSHED TIN|1 +STANDARD BRUSHED NICKEL|1 +STANDARD BRUSHED BRASS|1 +STANDARD BRUSHED STEEL|1 +STANDARD BRUSHED COPPER|1 +SMALL ANODIZED TIN|1 +SMALL ANODIZED NICKEL|1 +SMALL ANODIZED BRASS|1 +SMALL ANODIZED STEEL|1 +SMALL ANODIZED COPPER|1 +SMALL BURNISHED TIN|1 +SMALL BURNISHED NICKEL|1 +SMALL BURNISHED BRASS|1 +SMALL BURNISHED STEEL|1 +SMALL BURNISHED COPPER|1 +SMALL PLATED TIN|1 +SMALL PLATED NICKEL|1 +SMALL PLATED BRASS|1 +SMALL PLATED STEEL|1 +SMALL PLATED COPPER|1 +SMALL POLISHED TIN|1 +SMALL POLISHED NICKEL|1 +SMALL POLISHED BRASS|1 +SMALL POLISHED STEEL|1 +SMALL POLISHED COPPER|1 +SMALL BRUSHED TIN|1 +SMALL BRUSHED NICKEL|1 +SMALL BRUSHED BRASS|1 +SMALL BRUSHED STEEL|1 +SMALL BRUSHED COPPER|1 +MEDIUM ANODIZED TIN|1 +MEDIUM ANODIZED NICKEL|1 +MEDIUM ANODIZED BRASS|1 +MEDIUM ANODIZED STEEL|1 +MEDIUM ANODIZED COPPER|1 +MEDIUM BURNISHED TIN|1 +MEDIUM BURNISHED NICKEL|1 +MEDIUM BURNISHED BRASS|1 +MEDIUM BURNISHED STEEL|1 +MEDIUM BURNISHED COPPER|1 +MEDIUM PLATED TIN|1 +MEDIUM PLATED NICKEL|1 +MEDIUM PLATED BRASS|1 +MEDIUM PLATED STEEL|1 +MEDIUM PLATED COPPER|1 +MEDIUM POLISHED TIN|1 +MEDIUM POLISHED NICKEL|1 +MEDIUM POLISHED BRASS|1 +MEDIUM POLISHED STEEL|1 +MEDIUM POLISHED COPPER|1 +MEDIUM BRUSHED TIN|1 +MEDIUM BRUSHED NICKEL|1 +MEDIUM BRUSHED BRASS|1 +MEDIUM BRUSHED STEEL|1 +MEDIUM BRUSHED COPPER|1 +LARGE ANODIZED TIN|1 +LARGE ANODIZED NICKEL|1 +LARGE ANODIZED BRASS|1 +LARGE ANODIZED STEEL|1 +LARGE ANODIZED COPPER|1 +LARGE BURNISHED TIN|1 +LARGE BURNISHED NICKEL|1 +LARGE BURNISHED BRASS|1 +LARGE BURNISHED STEEL|1 +LARGE BURNISHED COPPER|1 +LARGE PLATED TIN|1 +LARGE PLATED NICKEL|1 +LARGE PLATED BRASS|1 +LARGE PLATED STEEL|1 +LARGE PLATED COPPER|1 +LARGE POLISHED TIN|1 +LARGE POLISHED NICKEL|1 +LARGE POLISHED BRASS|1 +LARGE POLISHED STEEL|1 +LARGE POLISHED COPPER|1 +LARGE BRUSHED TIN|1 +LARGE BRUSHED NICKEL|1 +LARGE BRUSHED BRASS|1 +LARGE BRUSHED STEEL|1 +LARGE BRUSHED COPPER|1 +ECONOMY ANODIZED TIN|1 +ECONOMY ANODIZED NICKEL|1 +ECONOMY ANODIZED BRASS|1 +ECONOMY ANODIZED STEEL|1 +ECONOMY ANODIZED COPPER|1 +ECONOMY BURNISHED TIN|1 +ECONOMY BURNISHED NICKEL|1 +ECONOMY BURNISHED BRASS|1 +ECONOMY BURNISHED STEEL|1 +ECONOMY BURNISHED COPPER|1 +ECONOMY PLATED TIN|1 +ECONOMY PLATED NICKEL|1 +ECONOMY PLATED BRASS|1 +ECONOMY PLATED STEEL|1 +ECONOMY PLATED COPPER|1 +ECONOMY POLISHED TIN|1 +ECONOMY POLISHED NICKEL|1 +ECONOMY POLISHED BRASS|1 +ECONOMY POLISHED STEEL|1 +ECONOMY POLISHED COPPER|1 +ECONOMY BRUSHED TIN|1 +ECONOMY BRUSHED NICKEL|1 +ECONOMY BRUSHED BRASS|1 +ECONOMY BRUSHED STEEL|1 +ECONOMY BRUSHED COPPER|1 +PROMO ANODIZED TIN|1 +PROMO ANODIZED NICKEL|1 +PROMO ANODIZED BRASS|1 +PROMO ANODIZED STEEL|1 +PROMO ANODIZED COPPER|1 +PROMO BURNISHED TIN|1 +PROMO BURNISHED NICKEL|1 +PROMO BURNISHED BRASS|1 +PROMO BURNISHED STEEL|1 +PROMO BURNISHED COPPER|1 +PROMO PLATED TIN|1 +PROMO PLATED NICKEL|1 +PROMO PLATED BRASS|1 +PROMO PLATED STEEL|1 +PROMO PLATED COPPER|1 +PROMO POLISHED TIN|1 +PROMO POLISHED NICKEL|1 +PROMO POLISHED BRASS|1 +PROMO POLISHED STEEL|1 +PROMO POLISHED COPPER|1 +PROMO BRUSHED TIN|1 +PROMO BRUSHED NICKEL|1 +PROMO BRUSHED BRASS|1 +PROMO BRUSHED STEEL|1 +PROMO BRUSHED COPPER|1 +end p_types +### +# colors +# NOTE: This distribution CANNOT be used by pick_str(), since agg_str() perturbs its order +### +begin colors +COUNT|92 +almond|1 +antique|1 +aquamarine|1 +azure|1 +beige|1 +bisque|1 +black|1 +blanched|1 +blue|1 +blush|1 +brown|1 +burlywood|1 +burnished|1 +chartreuse|1 +chiffon|1 +chocolate|1 +coral|1 +cornflower|1 +cornsilk|1 +cream|1 +cyan|1 +dark|1 +deep|1 +dim|1 +dodger|1 +drab|1 +firebrick|1 +floral|1 +forest|1 +frosted|1 +gainsboro|1 +ghost|1 +goldenrod|1 +green|1 +grey|1 +honeydew|1 +hot|1 +indian|1 +ivory|1 +khaki|1 +lace|1 +lavender|1 +lawn|1 +lemon|1 +light|1 +lime|1 +linen|1 +magenta|1 +maroon|1 +medium|1 +metallic|1 +midnight|1 +mint|1 +misty|1 +moccasin|1 +navajo|1 +navy|1 +olive|1 +orange|1 +orchid|1 +pale|1 +papaya|1 +peach|1 +peru|1 +pink|1 +plum|1 +powder|1 +puff|1 +purple|1 +red|1 +rose|1 +rosy|1 +royal|1 +saddle|1 +salmon|1 +sandy|1 +seashell|1 +sienna|1 +sky|1 +slate|1 +smoke|1 +snow|1 +spring|1 +steel|1 +tan|1 +thistle|1 +tomato|1 +turquoise|1 +violet|1 +wheat|1 +white|1 +yellow|1 +end colors +################ +################ +## psuedo text distributions +################ +################ +### +# nouns +### +BEGIN nouns +COUNT|45 +packages|40 +requests|40 +accounts|40 +deposits|40 +foxes|20 +ideas|20 +theodolites|20 +pinto beans|20 +instructions|20 +dependencies|10 +excuses|10 +platelets|10 +asymptotes|10 +courts|5 +dolphins|5 +multipliers|1 +sauternes|1 +warthogs|1 +frets|1 +dinos|1 +attainments|1 +somas|1 +Tiresias|1 +patterns|1 +forges|1 +braids|1 +frays|1 +warhorses|1 +dugouts|1 +notornis|1 +epitaphs|1 +pearls|1 +tithes|1 +waters|1 +orbits|1 +gifts|1 +sheaves|1 +depths|1 +sentiments|1 +decoys|1 +realms|1 +pains|1 +grouches|1 +escapades|1 +hockey players|1 +END nouns +### +# verbs +### +BEGIN verbs +COUNT|40 +sleep|20 +wake|20 +are|20 +cajole|20 +haggle|20 +nag|10 +use|10 +boost|10 +affix|5 +detect|5 +integrate|5 +maintain|1 +nod|1 +was|1 +lose|1 +sublate|1 +solve|1 +thrash|1 +promise|1 +engage|1 +hinder|1 +print|1 +x-ray|1 +breach|1 +eat|1 +grow|1 +impress|1 +mold|1 +poach|1 +serve|1 +run|1 +dazzle|1 +snooze|1 +doze|1 +unwind|1 +kindle|1 +play|1 +hang|1 +believe|1 +doubt|1 +END verbs +### +# adverbs +## +BEGIN adverbs +COUNT|28 +sometimes|1 +always|1 +never|1 +furiously|50 +slyly|50 +carefully|50 +blithely|40 +quickly|30 +fluffily|20 +slowly|1 +quietly|1 +ruthlessly|1 +thinly|1 +closely|1 +doggedly|1 +daringly|1 +bravely|1 +stealthily|1 +permanently|1 +enticingly|1 +idly|1 +busily|1 +regularly|1 +finally|1 +ironically|1 +evenly|1 +boldly|1 +silently|1 +END adverbs +### +# articles +## +BEGIN articles +COUNT|3 +the|50 +a|20 +an|5 +END articles +### +# prepositions +## +BEGIN prepositions +COUNT|47 +about|50 +above|50 +according to|50 +across|50 +after|50 +against|40 +along|40 +alongside of|30 +among|30 +around|20 +at|10 +atop|1 +before|1 +behind|1 +beneath|1 +beside|1 +besides|1 +between|1 +beyond|1 +by|1 +despite|1 +during|1 +except|1 +for|1 +from|1 +in place of|1 +inside|1 +instead of|1 +into|1 +near|1 +of|1 +on|1 +outside|1 +over|1 +past|1 +since|1 +through|1 +throughout|1 +to|1 +toward|1 +under|1 +until|1 +up|1 +upon|1 +whithout|1 +with|1 +within|1 +END prepositions +### +# auxillaries +## +BEGIN auxillaries +COUNT|18 +do|1 +may|1 +might|1 +shall|1 +will|1 +would|1 +can|1 +could|1 +should|1 +ought to|1 +must|1 +will have to|1 +shall have to|1 +could have to|1 +should have to|1 +must have to|1 +need to|1 +try to|1 +END auxiallaries +### +# terminators +## +BEGIN terminators +COUNT|6 +.|50 +;|1 +:|1 +?|1 +!|1 +--|1 +END terminators +### +# adjectives +## +BEGIN adjectives +COUNT|29 +special|20 +pending|20 +unusual|20 +express|20 +furious|1 +sly|1 +careful|1 +blithe|1 +quick|1 +fluffy|1 +slow|1 +quiet|1 +ruthless|1 +thin|1 +close|1 +dogged|1 +daring|1 +brave|1 +stealthy|1 +permanent|1 +enticing|1 +idle|1 +busy|1 +regular|50 +final|40 +ironic|40 +even|30 +bold|20 +silent|10 +END adjectives +### +# grammar +# first level grammar. N=noun phrase, V=verb phrase, +# P=prepositional phrase, T=setence termination +## +BEGIN grammar +COUNT|5 +N V T|3 +N V P T|3 +N V N T|3 +N P V N T|1 +N P V P T|1 +END grammar +### +# NP +# second level grammar. Noun phrases. N=noun, A=article, +# J=adjective, D=adverb +## +BEGIN np +COUNT|4 +N|10 +J N|20 +J, J N|10 +D J N|50 +END np +### +# VP +# second level grammar. Verb phrases. V=verb, X=auxiallary, +# D=adverb +## +BEGIN vp +COUNT|4 +V|30 +X V|1 +V D|40 +X V D|1 +END vp +### +# Q13 +# Substitution parameters for Q13 +## +BEGIN Q13a +COUNT|4 +special|20 +pending|20 +unusual|20 +express|20 +END Q13a +BEGIN Q13b +COUNT|4 +packages|40 +requests|40 +accounts|40 +deposits|40 +END Q13b diff --git a/py-polars/tests/benchmark/data/pdsh/generate_data.py b/py-polars/tests/benchmark/data/pdsh/generate_data.py new file mode 100644 index 000000000000..0f8866311ada --- /dev/null +++ b/py-polars/tests/benchmark/data/pdsh/generate_data.py @@ -0,0 +1,165 @@ +""" +Disclaimer. + +Certain portions of the contents of this file are derived from TPC-H version 3.0.1 +(retrieved from +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp). +Such portions are subject to copyrights held by Transaction Processing +Performance Council (“TPC”) and licensed under the TPC EULA is available at +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) +(the “TPC EULA”). + +You may not use this file except in compliance with the TPC EULA. +DISCLAIMER: Portions of this file is derived from the TPC-H benchmark and as +such any result obtained using this file are not comparable to published TPC-H +Benchmark results, as the results obtained from using this file do not comply with +the TPC-H Benchmark. +""" + +from __future__ import annotations + +import logging +import subprocess +import sys +from pathlib import Path + +import polars as pl + +logging.basicConfig(level=logging.INFO) + +CURRENT_DIR = Path(__file__).parent +DBGEN_DIR = CURRENT_DIR / "dbgen" + +__all__ = ["load_pdsh_table"] + + +def load_pdsh_table(table_name: str, scale_factor: float = 0.01) -> pl.DataFrame: + """ + Load a PDS-H table from disk. + + If the file does not exist, it is generated along with all other tables. + """ + folder = CURRENT_DIR / f"sf-{scale_factor:g}" + file_path = folder / f"{table_name}.parquet" + + if not file_path.exists(): + _generate_pdsh_data(scale_factor) + + return pl.read_parquet(file_path) + + +def _generate_pdsh_data(scale_factor: float = 0.01) -> None: + """Generate all PDS-H datasets with the given scale factor.""" + # TODO: Can we make this work under Windows? + if sys.platform == "win32": + msg = "cannot generate PDS-H data under Windows" + raise RuntimeError(msg) + + subprocess.run(["./dbgen", "-f", "-v", "-s", str(scale_factor)], cwd=DBGEN_DIR) + + _process_data(scale_factor) + + +def _process_data(scale_factor: float = 0.01) -> None: + """Process the data into Parquet files with the correct schema.""" + dest = CURRENT_DIR / f"sf-{scale_factor:g}" + dest.mkdir(exist_ok=True) + + for table_name, columns in TABLE_COLUMN_NAMES.items(): + logging.info(f"Processing table: {table_name}") + + table_path = DBGEN_DIR / f"{table_name}.tbl" + lf = pl.scan_csv( + table_path, + has_header=False, + separator="|", + try_parse_dates=True, + new_columns=columns, + ) + + # Drop empty last column because CSV ends with a separator + lf = lf.select(columns) + + lf.sink_parquet(dest / f"{table_name}.parquet") + table_path.unlink() + + +TABLE_COLUMN_NAMES = { + "customer": [ + "c_custkey", + "c_name", + "c_address", + "c_nationkey", + "c_phone", + "c_acctbal", + "c_mktsegment", + "c_comment", + ], + "lineitem": [ + "l_orderkey", + "l_partkey", + "l_suppkey", + "l_linenumber", + "l_quantity", + "l_extendedprice", + "l_discount", + "l_tax", + "l_returnflag", + "l_linestatus", + "l_shipdate", + "l_commitdate", + "l_receiptdate", + "l_shipinstruct", + "l_shipmode", + "comments", + ], + "nation": [ + "n_nationkey", + "n_name", + "n_regionkey", + "n_comment", + ], + "orders": [ + "o_orderkey", + "o_custkey", + "o_orderstatus", + "o_totalprice", + "o_orderdate", + "o_orderpriority", + "o_clerk", + "o_shippriority", + "o_comment", + ], + "part": [ + "p_partkey", + "p_name", + "p_mfgr", + "p_brand", + "p_type", + "p_size", + "p_container", + "p_retailprice", + "p_comment", + ], + "partsupp": [ + "ps_partkey", + "ps_suppkey", + "ps_availqty", + "ps_supplycost", + "ps_comment", + ], + "region": [ + "r_regionkey", + "r_name", + "r_comment", + ], + "supplier": [ + "s_suppkey", + "s_name", + "s_address", + "s_nationkey", + "s_phone", + "s_acctbal", + "s_comment", + ], +} diff --git a/py-polars/tests/benchmark/interop/__init__.py b/py-polars/tests/benchmark/interop/__init__.py new file mode 100644 index 000000000000..2b2fa00648fd --- /dev/null +++ b/py-polars/tests/benchmark/interop/__init__.py @@ -0,0 +1 @@ +"""Benchmark tests for conversions from/to other data formats.""" diff --git a/py-polars/tests/benchmark/interop/test_numpy.py b/py-polars/tests/benchmark/interop/test_numpy.py new file mode 100644 index 000000000000..5b516b0ecb19 --- /dev/null +++ b/py-polars/tests/benchmark/interop/test_numpy.py @@ -0,0 +1,53 @@ +"""Benchmark tests for conversions from/to NumPy.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +import polars as pl + +pytestmark = pytest.mark.benchmark() + + +@pytest.fixture(scope="module") +def floats_array() -> np.ndarray[Any, Any]: + n_rows = 10_000 + return np.random.randn(n_rows) + + +@pytest.fixture +def floats(floats_array: np.ndarray[Any, Any]) -> pl.Series: + return pl.Series(floats_array) + + +@pytest.fixture +def floats_with_nulls(floats: pl.Series) -> pl.Series: + null_probability = 0.1 + validity = pl.Series(np.random.uniform(size=floats.len())) > null_probability + return pl.select(pl.when(validity).then(floats)).to_series() + + +@pytest.fixture +def floats_chunked(floats_array: np.ndarray[Any, Any]) -> pl.Series: + n_chunks = 5 + chunk_len = len(floats_array) // n_chunks + chunks = [ + floats_array[i * chunk_len : (i + 1) * chunk_len] for i in range(n_chunks) + ] + chunks_copy = [pl.Series(c.copy()) for c in chunks] + return pl.concat(chunks_copy, rechunk=False) + + +def test_to_numpy_series_zero_copy(floats: pl.Series) -> None: + floats.to_numpy() + + +def test_to_numpy_series_with_nulls(floats_with_nulls: pl.Series) -> None: + floats_with_nulls.to_numpy() + + +def test_to_numpy_series_chunked(floats_chunked: pl.Series) -> None: + floats_chunked.to_numpy() diff --git a/py-polars/tests/benchmark/test_filter.py b/py-polars/tests/benchmark/test_filter.py new file mode 100644 index 000000000000..047404f01bb1 --- /dev/null +++ b/py-polars/tests/benchmark/test_filter.py @@ -0,0 +1,33 @@ +"""Benchmark tests for the filter operation.""" + +from __future__ import annotations + +import pytest + +import polars as pl + +pytestmark = pytest.mark.benchmark() + + +def test_filter1(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .filter(pl.col("id1").eq_missing(pl.lit("id046"))) + .select( + pl.col("id6").cast(pl.Int64).sum(), + pl.col("v3").sum(), + ) + .collect() + ) + + +def test_filter2(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .filter(~(pl.col("id1").eq_missing(pl.lit("id046")))) + .select( + pl.col("id6").cast(pl.Int64).sum(), + pl.col("v3").sum(), + ) + .collect() + ) diff --git a/py-polars/tests/benchmark/test_group_by.py b/py-polars/tests/benchmark/test_group_by.py new file mode 100644 index 000000000000..7547b13b5e62 --- /dev/null +++ b/py-polars/tests/benchmark/test_group_by.py @@ -0,0 +1,129 @@ +""" +Benchmark tests for the group-by operation. + +These tests are based on the H2O.ai database benchmark. + +See: +https://h2oai.github.io/db-benchmark/ +""" + +from __future__ import annotations + +import pytest + +import polars as pl + +pytestmark = pytest.mark.benchmark() + + +def test_groupby_h2oai_q1(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id1") + .agg( + pl.sum("v1").alias("v1_sum"), + ) + .collect() + ) + + +def test_groupby_h2oai_q2(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id1", "id2") + .agg( + pl.sum("v1").alias("v1_sum"), + ) + .collect() + ) + + +def test_groupby_h2oai_q3(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id3") + .agg( + pl.sum("v1").alias("v1_sum"), + pl.mean("v3").alias("v3_mean"), + ) + .collect() + ) + + +def test_groupby_h2oai_q4(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id4") + .agg( + pl.mean("v1").alias("v1_mean"), + pl.mean("v2").alias("v2_mean"), + pl.mean("v3").alias("v3_mean"), + ) + .collect() + ) + + +def test_groupby_h2oai_q5(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id6") + .agg( + pl.sum("v1").alias("v1_sum"), + pl.sum("v2").alias("v2_sum"), + pl.sum("v3").alias("v3_sum"), + ) + .collect() + ) + + +def test_groupby_h2oai_q6(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id4", "id5") + .agg( + pl.median("v3").alias("v3_median"), + pl.std("v3").alias("v3_std"), + ) + .collect() + ) + + +def test_groupby_h2oai_q7(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id3") + .agg((pl.max("v1") - pl.min("v2")).alias("range_v1_v2")) + .collect() + ) + + +def test_groupby_h2oai_q8(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .drop_nulls("v3") + .group_by("id6") + .agg(pl.col("v3").top_k(2).alias("largest2_v3")) + .explode("largest2_v3") + .collect() + ) + + +def test_groupby_h2oai_q9(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id2", "id4") + .agg((pl.corr("v1", "v2") ** 2).alias("r2")) + .collect() + ) + + +def test_groupby_h2oai_q10(groupby_data: pl.DataFrame) -> None: + ( + groupby_data.lazy() + .group_by("id1", "id2", "id3", "id4", "id5", "id6") + .agg( + pl.sum("v3").alias("v3_sum"), + pl.count("v1").alias("v1_count"), + ) + .collect() + ) diff --git a/py-polars/tests/benchmark/test_io.py b/py-polars/tests/benchmark/test_io.py new file mode 100644 index 000000000000..44e165a61a03 --- /dev/null +++ b/py-polars/tests/benchmark/test_io.py @@ -0,0 +1,23 @@ +"""Benchmark tests for the I/O operations.""" + +from pathlib import Path + +import pytest + +import polars as pl + +pytestmark = pytest.mark.benchmark() + + +def test_write_read_scan_large_csv(groupby_data: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + data_path = tmp_path / "data.csv" + groupby_data.write_csv(data_path) + + predicate = pl.col("v2") < 5 + + shape_eager = pl.read_csv(data_path).filter(predicate).shape + shape_lazy = pl.scan_csv(data_path).filter(predicate).collect().shape + + assert shape_lazy == shape_eager diff --git a/py-polars/tests/benchmark/test_join_where.py b/py-polars/tests/benchmark/test_join_where.py new file mode 100644 index 000000000000..c5b77da4116e --- /dev/null +++ b/py-polars/tests/benchmark/test_join_where.py @@ -0,0 +1,111 @@ +"""Benchmark tests for join_where with inequality conditions.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ColumnNotFoundError + +pytestmark = pytest.mark.benchmark() + + +def test_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None: + east, west = east_west + result = ( + east.lazy() + .join_where( + west.lazy(), + [pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost")], + ) + .collect() + ) + + assert len(result) > 0 + + +def test_non_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None: + east, west = east_west + result = ( + east.lazy() + .join_where( + west.lazy(), + [pl.col("dur") <= pl.col("time"), pl.col("rev") >= pl.col("cost")], + ) + .collect() + ) + + assert len(result) > 0 + + +def test_single_inequality(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None: + east, west = east_west + result = ( + east.lazy() + # Reduce the number of results by scaling LHS dur column up + .with_columns((pl.col("dur") * 30).alias("scaled_dur")) + .join_where( + west.lazy(), + pl.col("scaled_dur") < pl.col("time"), + ) + .collect() + ) + + assert len(result) > 0 + + +@pytest.fixture(scope="module") +def east_west() -> tuple[pl.DataFrame, pl.DataFrame]: + num_rows_left, num_rows_right = 50_000, 5_000 + rng = np.random.default_rng(42) + + # Generate two separate datasets where revenue/cost are linearly related to + # duration/time, but add some noise to the west table so that there are some + # rows where the cost for the same or greater time will be less than the east table. + east_dur = rng.integers(1_000, 50_000, num_rows_left) + east_rev = (east_dur * 0.123).astype(np.int32) + west_time = rng.integers(1_000, 50_000, num_rows_right) + west_cost = west_time * 0.123 + west_cost += rng.normal(0.0, 1.0, num_rows_right) + west_cost = west_cost.astype(np.int32) + + east = pl.DataFrame( + { + "id": np.arange(0, num_rows_left), + "dur": east_dur, + "rev": east_rev, + "cores": rng.integers(1, 10, num_rows_left), + } + ) + west = pl.DataFrame( + { + "t_id": np.arange(0, num_rows_right), + "time": west_time, + "cost": west_cost, + "cores": rng.integers(1, 10, num_rows_right), + } + ) + + return east, west + + +def test_join_where_invalid_column() -> None: + df = pl.DataFrame({"x": 1}) + with pytest.raises(ColumnNotFoundError, match="y"): + df.join_where(df, pl.col("x") < pl.col("y")) + + # Nested column + df1 = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True]}) + df2 = pl.DataFrame( + { + "a": [2, 3, 4], + "c": ["a", "b", "c"], + } + ) + with pytest.raises(ColumnNotFoundError, match="d"): + df = df1.join_where( + df2, + ((pl.col("a") - pl.col("b")) > (pl.col("c") == "a").cast(pl.Int32)) + > (pl.col("a") - pl.col("d")), + ) diff --git a/py-polars/tests/benchmark/test_pdsh.py b/py-polars/tests/benchmark/test_pdsh.py new file mode 100644 index 000000000000..2ee601b5b895 --- /dev/null +++ b/py-polars/tests/benchmark/test_pdsh.py @@ -0,0 +1,774 @@ +""" +Disclaimer. + +Certain portions of the contents of this file are derived from TPC-H version 3.0.1 +(retrieved from +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp). +Such portions are subject to copyrights held by Transaction Processing +Performance Council (“TPC”) and licensed under the TPC EULA is available at +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) +(the “TPC EULA”). + +You may not use this file except in compliance with the TPC EULA. +DISCLAIMER: Portions of this file is derived from the TPC-H benchmark and as +such any result obtained using this file are not comparable to published TPC-H +Benchmark results, as the results obtained from using this file do not comply with +the TPC-H Benchmark. +""" + +import sys +from datetime import date + +import pytest + +import polars as pl +from tests.benchmark.data import load_pdsh_table + +if sys.platform == "win32": + pytest.skip("PDS-H data cannot be generated under Windows", allow_module_level=True) + +pytestmark = pytest.mark.benchmark() + + +@pytest.fixture(scope="module") +def customer() -> pl.LazyFrame: + return load_pdsh_table("customer").lazy() + + +@pytest.fixture(scope="module") +def lineitem() -> pl.LazyFrame: + return load_pdsh_table("lineitem").lazy() + + +@pytest.fixture(scope="module") +def nation() -> pl.LazyFrame: + return load_pdsh_table("nation").lazy() + + +@pytest.fixture(scope="module") +def orders() -> pl.LazyFrame: + return load_pdsh_table("orders").lazy() + + +@pytest.fixture(scope="module") +def part() -> pl.LazyFrame: + return load_pdsh_table("part").lazy() + + +@pytest.fixture(scope="module") +def partsupp() -> pl.LazyFrame: + return load_pdsh_table("partsupp").lazy() + + +@pytest.fixture(scope="module") +def region() -> pl.LazyFrame: + return load_pdsh_table("region").lazy() + + +@pytest.fixture(scope="module") +def supplier() -> pl.LazyFrame: + return load_pdsh_table("supplier").lazy() + + +def test_pdsh_q1(lineitem: pl.LazyFrame) -> None: + var1 = date(1998, 9, 2) + + q_final = ( + lineitem.filter(pl.col("l_shipdate") <= var1) + .group_by("l_returnflag", "l_linestatus") + .agg( + pl.sum("l_quantity").alias("sum_qty"), + pl.sum("l_extendedprice").alias("sum_base_price"), + (pl.col("l_extendedprice") * (1.0 - pl.col("l_discount"))) + .sum() + .alias("sum_disc_price"), + ( + pl.col("l_extendedprice") + * (1.0 - pl.col("l_discount")) + * (1.0 + pl.col("l_tax")) + ) + .sum() + .alias("sum_charge"), + pl.mean("l_quantity").alias("avg_qty"), + pl.mean("l_extendedprice").alias("avg_price"), + pl.mean("l_discount").alias("avg_disc"), + pl.len().alias("count_order"), + ) + .sort("l_returnflag", "l_linestatus") + ) + q_final.collect() + + +def test_pdsh_q2( + nation: pl.LazyFrame, + part: pl.LazyFrame, + partsupp: pl.LazyFrame, + region: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + var1 = 15 + var2 = "BRASS" + var3 = "EUROPE" + + result_q1 = ( + part.join(partsupp, left_on="p_partkey", right_on="ps_partkey") + .join(supplier, left_on="ps_suppkey", right_on="s_suppkey") + .join(nation, left_on="s_nationkey", right_on="n_nationkey") + .join(region, left_on="n_regionkey", right_on="r_regionkey") + .filter(pl.col("p_size") == var1) + .filter(pl.col("p_type").str.ends_with(var2)) + .filter(pl.col("r_name") == var3) + ) + + q_final = ( + result_q1.group_by("p_partkey") + .agg(pl.min("ps_supplycost")) + .join(result_q1, on=["p_partkey", "ps_supplycost"]) + .select( + "s_acctbal", + "s_name", + "n_name", + "p_partkey", + "p_mfgr", + "s_address", + "s_phone", + "s_comment", + ) + .sort( + by=["s_acctbal", "n_name", "s_name", "p_partkey"], + descending=[True, False, False, False], + ) + .head(100) + ) + q_final.collect() + + +def test_pdsh_q3( + customer: pl.LazyFrame, lineitem: pl.LazyFrame, orders: pl.LazyFrame +) -> None: + var1 = "BUILDING" + var2 = date(1995, 3, 15) + + q_final = ( + customer.filter(pl.col("c_mktsegment") == var1) + .join(orders, left_on="c_custkey", right_on="o_custkey") + .join(lineitem, left_on="o_orderkey", right_on="l_orderkey") + .filter(pl.col("o_orderdate") < var2) + .filter(pl.col("l_shipdate") > var2) + .with_columns( + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue") + ) + .group_by("o_orderkey", "o_orderdate", "o_shippriority") + .agg(pl.sum("revenue")) + .select( + pl.col("o_orderkey").alias("l_orderkey"), + "revenue", + "o_orderdate", + "o_shippriority", + ) + .sort(by=["revenue", "o_orderdate"], descending=[True, False]) + .head(10) + ) + q_final.collect() + + +def test_pdsh_q4(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: + var1 = date(1993, 7, 1) + var2 = date(1993, 10, 1) + + q_final = ( + lineitem.join(orders, left_on="l_orderkey", right_on="o_orderkey") + .filter(pl.col("o_orderdate").is_between(var1, var2, closed="left")) + .filter(pl.col("l_commitdate") < pl.col("l_receiptdate")) + .unique(subset=["o_orderpriority", "l_orderkey"]) + .group_by("o_orderpriority") + .agg(pl.len().alias("order_count")) + .sort("o_orderpriority") + ) + q_final.collect() + + +def test_pdsh_q5( + customer: pl.LazyFrame, + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + orders: pl.LazyFrame, + region: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + var1 = "ASIA" + var2 = date(1994, 1, 1) + var3 = date(1995, 1, 1) + + q_final = ( + region.join(nation, left_on="r_regionkey", right_on="n_regionkey") + .join(customer, left_on="n_nationkey", right_on="c_nationkey") + .join(orders, left_on="c_custkey", right_on="o_custkey") + .join(lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join( + supplier, + left_on=["l_suppkey", "n_nationkey"], + right_on=["s_suppkey", "s_nationkey"], + ) + .filter(pl.col("r_name") == var1) + .filter(pl.col("o_orderdate").is_between(var2, var3, closed="left")) + .with_columns( + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue") + ) + .group_by("n_name") + .agg(pl.sum("revenue")) + .sort(by="revenue", descending=True) + ) + + q_final.collect() + + +def test_pdsh_q6(lineitem: pl.LazyFrame) -> None: + var1 = date(1994, 1, 1) + var2 = date(1995, 1, 1) + var3 = 0.05 + var4 = 0.07 + var5 = 24 + + q_final = ( + lineitem.filter(pl.col("l_shipdate").is_between(var1, var2, closed="left")) + .filter(pl.col("l_discount").is_between(var3, var4)) + .filter(pl.col("l_quantity") < var5) + .with_columns( + (pl.col("l_extendedprice") * pl.col("l_discount")).alias("revenue") + ) + .select(pl.sum("revenue")) + ) + + q_final.collect() + + +def test_pdsh_q7( + customer: pl.LazyFrame, + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + orders: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + var1 = "FRANCE" + var2 = "GERMANY" + var3 = date(1995, 1, 1) + var4 = date(1996, 12, 31) + + n1 = nation.filter(pl.col("n_name") == var1) + n2 = nation.filter(pl.col("n_name") == var2) + + q1 = ( + customer.join(n1, left_on="c_nationkey", right_on="n_nationkey") + .join(orders, left_on="c_custkey", right_on="o_custkey") + .rename({"n_name": "cust_nation"}) + .join(lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(n2, left_on="s_nationkey", right_on="n_nationkey") + .rename({"n_name": "supp_nation"}) + ) + + q2 = ( + customer.join(n2, left_on="c_nationkey", right_on="n_nationkey") + .join(orders, left_on="c_custkey", right_on="o_custkey") + .rename({"n_name": "cust_nation"}) + .join(lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(n1, left_on="s_nationkey", right_on="n_nationkey") + .rename({"n_name": "supp_nation"}) + ) + + q_final = ( + pl.concat([q1, q2]) + .filter(pl.col("l_shipdate").is_between(var3, var4)) + .with_columns( + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("volume"), + pl.col("l_shipdate").dt.year().alias("l_year"), + ) + .group_by("supp_nation", "cust_nation", "l_year") + .agg(pl.sum("volume").alias("revenue")) + .sort(by=["supp_nation", "cust_nation", "l_year"]) + ) + q_final.collect() + + +def test_pdsh_q8( + customer: pl.LazyFrame, + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + orders: pl.LazyFrame, + part: pl.LazyFrame, + region: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + var1 = "BRAZIL" + var2 = "AMERICA" + var3 = "ECONOMY ANODIZED STEEL" + var4 = date(1995, 1, 1) + var5 = date(1996, 12, 31) + + n1 = nation.select("n_nationkey", "n_regionkey") + n2 = nation.select("n_nationkey", "n_name") + + q_final = ( + part.join(lineitem, left_on="p_partkey", right_on="l_partkey") + .join(supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(orders, left_on="l_orderkey", right_on="o_orderkey") + .join(customer, left_on="o_custkey", right_on="c_custkey") + .join(n1, left_on="c_nationkey", right_on="n_nationkey") + .join(region, left_on="n_regionkey", right_on="r_regionkey") + .filter(pl.col("r_name") == var2) + .join(n2, left_on="s_nationkey", right_on="n_nationkey") + .filter(pl.col("o_orderdate").is_between(var4, var5)) + .filter(pl.col("p_type") == var3) + .select( + pl.col("o_orderdate").dt.year().alias("o_year"), + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("volume"), + pl.col("n_name").alias("nation"), + ) + .with_columns( + pl.when(pl.col("nation") == var1) + .then(pl.col("volume")) + .otherwise(0) + .alias("_tmp") + ) + .group_by("o_year") + .agg((pl.sum("_tmp") / pl.sum("volume")).round(2).alias("mkt_share")) + .sort("o_year") + ) + + q_final.collect() + + +def test_pdsh_q9( + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + orders: pl.LazyFrame, + part: pl.LazyFrame, + partsupp: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + q_final = ( + lineitem.join(supplier, left_on="l_suppkey", right_on="s_suppkey") + .join( + partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + ) + .join(part, left_on="l_partkey", right_on="p_partkey") + .join(orders, left_on="l_orderkey", right_on="o_orderkey") + .join(nation, left_on="s_nationkey", right_on="n_nationkey") + .filter(pl.col("p_name").str.contains("green")) + .select( + pl.col("n_name").alias("nation"), + pl.col("o_orderdate").dt.year().alias("o_year"), + ( + pl.col("l_extendedprice") * (1 - pl.col("l_discount")) + - pl.col("ps_supplycost") * pl.col("l_quantity") + ).alias("amount"), + ) + .group_by("nation", "o_year") + .agg(pl.sum("amount").round(2).alias("sum_profit")) + .sort(by=["nation", "o_year"], descending=[False, True]) + ) + + q_final.collect() + + +def test_pdsh_q10( + customer: pl.LazyFrame, + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + orders: pl.LazyFrame, +) -> None: + var1 = date(1993, 10, 1) + var2 = date(1994, 1, 1) + + q_final = ( + customer.join(orders, left_on="c_custkey", right_on="o_custkey") + .join(lineitem, left_on="o_orderkey", right_on="l_orderkey") + .join(nation, left_on="c_nationkey", right_on="n_nationkey") + .filter(pl.col("o_orderdate").is_between(var1, var2, closed="left")) + .filter(pl.col("l_returnflag") == "R") + .group_by( + "c_custkey", + "c_name", + "c_acctbal", + "c_phone", + "n_name", + "c_address", + "c_comment", + ) + .agg( + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + .sum() + .round(2) + .alias("revenue") + ) + .select( + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", + ) + .sort(by="revenue", descending=True) + .head(20) + ) + + q_final.collect() + + +def test_pdsh_q11( + nation: pl.LazyFrame, partsupp: pl.LazyFrame, supplier: pl.LazyFrame +) -> None: + var1 = "GERMANY" + var2 = 0.0001 + + q1 = ( + partsupp.join(supplier, left_on="ps_suppkey", right_on="s_suppkey") + .join(nation, left_on="s_nationkey", right_on="n_nationkey") + .filter(pl.col("n_name") == var1) + ) + q2 = q1.select( + (pl.col("ps_supplycost") * pl.col("ps_availqty")).sum().round(2).alias("tmp") + * var2 + ).with_columns(pl.lit(1).alias("lit")) + + q_final = ( + q1.group_by("ps_partkey") + .agg( + (pl.col("ps_supplycost") * pl.col("ps_availqty")) + .sum() + .round(2) + .alias("value") + ) + .with_columns(pl.lit(1).alias("lit")) + .join(q2, on="lit") + .filter(pl.col("value") > pl.col("tmp")) + .select("ps_partkey", "value") + .sort("value", descending=True) + ) + + q_final.collect() + + +def test_pdsh_q12(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: + var1 = "MAIL" + var2 = "SHIP" + var3 = date(1994, 1, 1) + var4 = date(1995, 1, 1) + + q_final = ( + orders.join(lineitem, left_on="o_orderkey", right_on="l_orderkey") + .filter(pl.col("l_shipmode").is_in([var1, var2])) + .filter(pl.col("l_commitdate") < pl.col("l_receiptdate")) + .filter(pl.col("l_shipdate") < pl.col("l_commitdate")) + .filter(pl.col("l_receiptdate").is_between(var3, var4, closed="left")) + .with_columns( + pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"])) + .then(1) + .otherwise(0) + .alias("high_line_count"), + pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"]).not_()) + .then(1) + .otherwise(0) + .alias("low_line_count"), + ) + .group_by("l_shipmode") + .agg(pl.col("high_line_count").sum(), pl.col("low_line_count").sum()) + .sort("l_shipmode") + ) + q_final.collect() + + +def test_pdsh_q13(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: + var1 = "special" + var2 = "requests" + + orders = orders.filter(pl.col("o_comment").str.contains(f"{var1}.*{var2}").not_()) + q_final = ( + customer.join(orders, left_on="c_custkey", right_on="o_custkey", how="left") + .group_by("c_custkey") + .agg(pl.col("o_orderkey").count().alias("c_count")) + .group_by("c_count") + .len() + .select(pl.col("c_count"), pl.col("len").alias("custdist")) + .sort(by=["custdist", "c_count"], descending=[True, True]) + ) + q_final.collect() + + +def test_pdsh_q14(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: + var1 = date(1995, 9, 1) + var2 = date(1995, 10, 1) + + q_final = ( + lineitem.join(part, left_on="l_partkey", right_on="p_partkey") + .filter(pl.col("l_shipdate").is_between(var1, var2, closed="left")) + .select( + ( + 100.00 + * pl.when(pl.col("p_type").str.contains("PROMO*")) + .then(pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + .otherwise(0) + .sum() + / (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).sum() + ) + .round(2) + .alias("promo_revenue") + ) + ) + q_final.collect() + + +def test_pdsh_q15(lineitem: pl.LazyFrame, supplier: pl.LazyFrame) -> None: + var1 = date(1996, 1, 1) + var2 = date(1996, 4, 1) + + revenue = ( + lineitem.filter(pl.col("l_shipdate").is_between(var1, var2, closed="left")) + .group_by("l_suppkey") + .agg( + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + .sum() + .alias("total_revenue") + ) + .select(pl.col("l_suppkey").alias("supplier_no"), pl.col("total_revenue")) + ) + + q_final = ( + supplier.join(revenue, left_on="s_suppkey", right_on="supplier_no") + .filter(pl.col("total_revenue") == pl.col("total_revenue").max()) + .with_columns(pl.col("total_revenue").round(2)) + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") + .sort("s_suppkey") + ) + q_final.collect() + + +def test_pdsh_q16( + part: pl.LazyFrame, partsupp: pl.LazyFrame, supplier: pl.LazyFrame +) -> None: + var1 = "Brand#45" + + supplier = supplier.filter( + pl.col("s_comment").str.contains(".*Customer.*Complaints.*") + ).select(pl.col("s_suppkey"), pl.col("s_suppkey").alias("ps_suppkey")) + + q_final = ( + part.join(partsupp, left_on="p_partkey", right_on="ps_partkey") + .filter(pl.col("p_brand") != var1) + .filter(pl.col("p_type").str.contains("MEDIUM POLISHED*").not_()) + .filter(pl.col("p_size").is_in([49, 14, 23, 45, 19, 3, 36, 9])) + .join(supplier, left_on="ps_suppkey", right_on="s_suppkey", how="left") + .filter(pl.col("ps_suppkey_right").is_null()) + .group_by("p_brand", "p_type", "p_size") + .agg(pl.col("ps_suppkey").n_unique().alias("supplier_cnt")) + .sort( + by=["supplier_cnt", "p_brand", "p_type", "p_size"], + descending=[True, False, False, False], + ) + ) + q_final.collect() + + +def test_pdsh_q17(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: + var1 = "Brand#23" + var2 = "MED BOX" + + q1 = ( + part.filter(pl.col("p_brand") == var1) + .filter(pl.col("p_container") == var2) + .join(lineitem, how="left", left_on="p_partkey", right_on="l_partkey") + ) + + q_final = ( + q1.group_by("p_partkey") + .agg((0.2 * pl.col("l_quantity").mean()).alias("avg_quantity")) + .select(pl.col("p_partkey").alias("key"), pl.col("avg_quantity")) + .join(q1, left_on="key", right_on="p_partkey") + .filter(pl.col("l_quantity") < pl.col("avg_quantity")) + .select((pl.col("l_extendedprice").sum() / 7.0).round(2).alias("avg_yearly")) + ) + q_final.collect() + + +def test_pdsh_q18( + customer: pl.LazyFrame, lineitem: pl.LazyFrame, orders: pl.LazyFrame +) -> None: + var1 = 300 + + q_final = ( + lineitem.group_by("l_orderkey") + .agg(pl.col("l_quantity").sum().alias("sum_quantity")) + .filter(pl.col("sum_quantity") > var1) + .select(pl.col("l_orderkey").alias("key"), pl.col("sum_quantity")) + .join(orders, left_on="key", right_on="o_orderkey") + .join(lineitem, left_on="key", right_on="l_orderkey") + .join(customer, left_on="o_custkey", right_on="c_custkey") + .group_by("c_name", "o_custkey", "key", "o_orderdate", "o_totalprice") + .agg(pl.col("l_quantity").sum().alias("col6")) + .select( + pl.col("c_name"), + pl.col("o_custkey").alias("c_custkey"), + pl.col("key").alias("o_orderkey"), + pl.col("o_orderdate").alias("o_orderdat"), + pl.col("o_totalprice"), + pl.col("col6"), + ) + .sort(by=["o_totalprice", "o_orderdat"], descending=[True, False]) + .head(100) + ) + q_final.collect() + + +def test_pdsh_q19(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: + q_final = ( + part.join(lineitem, left_on="p_partkey", right_on="l_partkey") + .filter(pl.col("l_shipmode").is_in(["AIR", "AIR REG"])) + .filter(pl.col("l_shipinstruct") == "DELIVER IN PERSON") + .filter( + ( + (pl.col("p_brand") == "Brand#12") + & pl.col("p_container").is_in( + ["SM CASE", "SM BOX", "SM PACK", "SM PKG"] + ) + & (pl.col("l_quantity").is_between(1, 11)) + & (pl.col("p_size").is_between(1, 5)) + ) + | ( + (pl.col("p_brand") == "Brand#23") + & pl.col("p_container").is_in( + ["MED BAG", "MED BOX", "MED PKG", "MED PACK"] + ) + & (pl.col("l_quantity").is_between(10, 20)) + & (pl.col("p_size").is_between(1, 10)) + ) + | ( + (pl.col("p_brand") == "Brand#34") + & pl.col("p_container").is_in( + ["LG CASE", "LG BOX", "LG PACK", "LG PKG"] + ) + & (pl.col("l_quantity").is_between(20, 30)) + & (pl.col("p_size").is_between(1, 15)) + ) + ) + .select( + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + .sum() + .round(2) + .alias("revenue") + ) + ) + q_final.collect() + + +def test_pdsh_q20( + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + part: pl.LazyFrame, + partsupp: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + var1 = date(1994, 1, 1) + var2 = date(1995, 1, 1) + var3 = "CANADA" + var4 = "forest" + + q1 = ( + lineitem.filter(pl.col("l_shipdate").is_between(var1, var2, closed="left")) + .group_by("l_partkey", "l_suppkey") + .agg((pl.col("l_quantity").sum() * 0.5).alias("sum_quantity")) + ) + q2 = nation.filter(pl.col("n_name") == var3) + q3 = supplier.join(q2, left_on="s_nationkey", right_on="n_nationkey") + + q_final = ( + part.filter(pl.col("p_name").str.starts_with(var4)) + .select(pl.col("p_partkey").unique()) + .join(partsupp, left_on="p_partkey", right_on="ps_partkey") + .join( + q1, + left_on=["ps_suppkey", "p_partkey"], + right_on=["l_suppkey", "l_partkey"], + ) + .filter(pl.col("ps_availqty") > pl.col("sum_quantity")) + .select(pl.col("ps_suppkey").unique()) + .join(q3, left_on="ps_suppkey", right_on="s_suppkey") + .select("s_name", "s_address") + .sort("s_name") + ) + q_final.collect() + + +def test_pdsh_q21( + lineitem: pl.LazyFrame, + nation: pl.LazyFrame, + orders: pl.LazyFrame, + supplier: pl.LazyFrame, +) -> None: + var1 = "SAUDI ARABIA" + + q1 = ( + lineitem.group_by("l_orderkey") + .agg(pl.col("l_suppkey").n_unique().alias("nunique_col")) + .filter(pl.col("nunique_col") > 1) + .join( + lineitem.filter(pl.col("l_receiptdate") > pl.col("l_commitdate")), + on="l_orderkey", + ) + ) + + q_final = ( + q1.group_by("l_orderkey") + .agg(pl.col("l_suppkey").n_unique().alias("nunique_col")) + .join(q1, on="l_orderkey") + .join(supplier, left_on="l_suppkey", right_on="s_suppkey") + .join(nation, left_on="s_nationkey", right_on="n_nationkey") + .join(orders, left_on="l_orderkey", right_on="o_orderkey") + .filter(pl.col("nunique_col") == 1) + .filter(pl.col("n_name") == var1) + .filter(pl.col("o_orderstatus") == "F") + .group_by("s_name") + .agg(pl.len().alias("numwait")) + .sort(by=["numwait", "s_name"], descending=[True, False]) + .head(100) + ) + q_final.collect() + + +def test_pdsh_q22(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: + q1 = ( + customer.with_columns(pl.col("c_phone").str.slice(0, 2).alias("cntrycode")) + .filter(pl.col("cntrycode").str.contains("13|31|23|29|30|18|17")) + .select("c_acctbal", "c_custkey", "cntrycode") + ) + + q2 = ( + q1.filter(pl.col("c_acctbal") > 0.0) + .select(pl.col("c_acctbal").mean().alias("avg_acctbal")) + .with_columns(pl.lit(1).alias("lit")) + ) + + q3 = orders.select(pl.col("o_custkey").unique()).with_columns( + pl.col("o_custkey").alias("c_custkey") + ) + + q_final = ( + q1.join(q3, on="c_custkey", how="left") + .filter(pl.col("o_custkey").is_null()) + .with_columns(pl.lit(1).alias("lit")) + .join(q2, on="lit") + .filter(pl.col("c_acctbal") > pl.col("avg_acctbal")) + .group_by("cntrycode") + .agg( + pl.col("c_acctbal").count().alias("numcust"), + pl.col("c_acctbal").sum().round(2).alias("totacctbal"), + ) + .sort("cntrycode") + ) + q_final.collect() diff --git a/py-polars/tests/benchmark/test_with_columns.py b/py-polars/tests/benchmark/test_with_columns.py new file mode 100644 index 000000000000..ca1018640592 --- /dev/null +++ b/py-polars/tests/benchmark/test_with_columns.py @@ -0,0 +1,44 @@ +from time import perf_counter + +import pytest + +import polars as pl +import polars.selectors as cs + + +# TODO: this is slow in streaming +@pytest.mark.may_fail_auto_streaming +@pytest.mark.slow +def test_with_columns_quadratic_19503() -> None: + num_columns = 10_000 + data1 = {f"col_{i}": [0] for i in range(num_columns)} + df1 = pl.DataFrame(data1) + + data2 = {f"feature_{i}": [0] for i in range(num_columns)} + df2 = pl.DataFrame(data2) + + times = [] # [slow, fast] + + class _: + rhs = df2 + t = perf_counter() + df1.with_columns(rhs) + times.append(perf_counter() - t) + + class _: # type: ignore[no-redef] + rhs = df2.select(cs.by_index(range(num_columns // 1_000))) + t = perf_counter() + df1.with_columns(rhs) + times.append(perf_counter() - t) + + ratio = times[0] / times[1] + + # Assert the relative rather than exact runtime to avoid flakiness in CI + # We pick a threshold just low enough to pass CI without any false + # negatives. + # 1.12.0 | 1.14.0 + # M3 Pro 11-core | 200x | 20x + # EC2 c7i.4xlarge | 150x | 13x + # GitHub CI runner | | 50x + if ratio > 100: + raise AssertionError(ratio) diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py new file mode 100644 index 000000000000..404070deae85 --- /dev/null +++ b/py-polars/tests/docs/run_doctest.py @@ -0,0 +1,176 @@ +""" +Run all doctest examples of the `polars` module using Python's built-in doctest module. + +How to check examples: run this script, if exits with code 0, all is good. Otherwise, +the errors will be reported. + +How to modify behaviour for doctests: +1. if you would like code to be run and output checked: add the output below the code + block +2. if you would like code to be run (and thus checked whether it actually not fails), + but output not be checked: add `# doctest: +IGNORE_RESULT` to the code block. You may + still add example output. +3. if you would not like code to run: add `#doctest: +SKIP`. You may still add example + output. + +Notes +----- +* Doctest does not have a built-in IGNORE_RESULT directive. We have a number of tests + where we want to ensure that the code runs, but the output may be random by design, or + not interesting for us to check. To allow for this behaviour, a custom output checker + has been created, see below. +* The doctests depend on the exact string representation staying the same. This may not + be true in the future. For instance, in the past, the printout of DataFrames has + changed from rounded corners to less rounded corners. To facilitate such a change, + whilst not immediately having to add IGNORE_RESULT directives everywhere or changing + all outputs, set `IGNORE_RESULT_ALL=True` below. Do note that this does mean no output + is being checked anymore. +""" + +from __future__ import annotations + +import doctest +import importlib +import re +import sys +import unittest +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any + +import polars as pl + +if TYPE_CHECKING: + from collections.abc import Iterator + from types import ModuleType + + +if sys.version_info < (3, 12): + # Tests that print an OrderedDict fail (e.g. DataFrame.schema) as the repr + # has changed in Python 3.12 + warnings.warn( + "Certain doctests may fail when running on a Python version below 3.12." + " Update your Python version to 3.12 or later to make sure all tests pass.", + stacklevel=2, + ) + +# associate specific doctest method names with optional modules. +# if the module is found in the environment those doctests will +# run; if the module is not found, their doctests are skipped. +OPTIONAL_MODULES_AND_METHODS: dict[str, set[str]] = { + "jax": {"to_jax"}, + "torch": {"to_torch"}, +} +OPTIONAL_MODULES: set[str] = set() +SKIP_METHODS: set[str] = set() + +for mod, methods in OPTIONAL_MODULES_AND_METHODS.items(): + try: + importlib.import_module(mod) + except ImportError: # noqa: PERF203 + SKIP_METHODS.update(methods) + OPTIONAL_MODULES.add(mod) + + +def doctest_teardown(d: doctest.DocTest) -> None: + # don't let config changes or string cache state leak between tests + pl.Config.restore_defaults() + pl.disable_string_cache() + + +def modules_in_path(p: Path) -> Iterator[ModuleType]: + for file in p.rglob("*.py"): + # Construct path as string for import, for instance "dataframe.frame". + # (The -3 drops the ".py") + try: + file_name_import = ".".join(file.relative_to(p).parts)[:-3] + temp_module = importlib.import_module(p.name + "." + file_name_import) + yield temp_module + except ImportError as err: # noqa: PERF203 + if not any(re.search(rf"\b{mod}\b", str(err)) for mod in OPTIONAL_MODULES): + raise + + +class FilteredTestSuite(unittest.TestSuite): # noqa: D101 + def __iter__(self) -> Iterator[Any]: + for suite in self._tests: + suite._tests = [ # type: ignore[attr-defined] + test + for test in suite._tests # type: ignore[attr-defined] + if test.id().rsplit(".", 1)[-1] not in SKIP_METHODS + ] + yield suite + + +if __name__ == "__main__": + # set to True to just run the code, and do not check any output. + # Will still report errors if the code is invalid + IGNORE_RESULT_ALL = False + + # Below the implementation of the IGNORE_RESULT directive + # You can ignore the result of a doctest by adding "doctest: +IGNORE_RESULT" into + # the code block. The difference with SKIP is that if the code errors on running, + # that will still be reported. + IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT") + + # Set doctests to fail on warnings + warnings.simplefilter("error", Warning) + warnings.filterwarnings( + "ignore", + message="datetime.datetime.utcfromtimestamp\\(\\) is deprecated.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message="datetime.datetime.utcnow\\(\\) is deprecated.*", + category=DeprecationWarning, + ) + + OutputChecker = doctest.OutputChecker + + class IgnoreResultOutputChecker(OutputChecker): + """Python doctest output checker with support for IGNORE_RESULT.""" + + def check_output(self, want: str, got: str, optionflags: Any) -> bool: + """Return True iff the actual output from an example matches the output.""" + if IGNORE_RESULT_ALL: + return True + if IGNORE_RESULT & optionflags: + return True + else: + return OutputChecker.check_output(self, want, got, optionflags) + + doctest.OutputChecker = IgnoreResultOutputChecker # type: ignore[misc] + + # We want to be relaxed about whitespace, but strict on True vs 1 + doctest.NORMALIZE_WHITESPACE = True + doctest.DONT_ACCEPT_TRUE_FOR_1 = True + + # If REPORT_NDIFF is turned on, it will report on line by line, character by + # character, differences. The disadvantage is that you cannot just copy the output + # directly into the docstring. + # doctest.REPORT_NDIFF = True + + # __file__ returns the __init__.py, so grab the parent + src_dir = Path(pl.__file__).parent + + with TemporaryDirectory() as tmpdir: + # collect all tests + tests = [ + doctest.DocTestSuite( + m, + extraglobs={"pl": pl, "dirpath": Path(tmpdir)}, + tearDown=doctest_teardown, + optionflags=1, + ) + for m in modules_in_path(src_dir) + ] + test_suite = FilteredTestSuite(tests) + + # Ensure that we clean up any artifacts produced by the doctests + # with patch(pl.DataFrame.write_csv): + # run doctests and report + result = unittest.TextTestRunner().run(test_suite) + success_flag = (result.testsRun > 0) & (len(result.failures) == 0) + sys.exit(int(not success_flag)) diff --git a/py-polars/tests/docs/test_user_guide.py b/py-polars/tests/docs/test_user_guide.py new file mode 100644 index 000000000000..98657c5ad89d --- /dev/null +++ b/py-polars/tests/docs/test_user_guide.py @@ -0,0 +1,41 @@ +"""Run all Python code snippets.""" + +import os +import runpy +import sys +from collections.abc import Iterator +from pathlib import Path + +import matplotlib as mpl +import pytest + +# Do not show plots +mpl.use("Agg") + +# Get paths to Python code snippets +repo_root = Path(__file__).parent.parent.parent.parent +python_snippets_dir = repo_root / "docs" / "source" / "src" / "python" +snippet_paths = list(python_snippets_dir.rglob("*.py")) + +# Skip visualization snippets +snippet_paths = [p for p in snippet_paths if "visualization" not in str(p)] + +# Skip UDF section on Python 3.13 as numba does not support it yet +if sys.version_info >= (3, 13): + snippet_paths = [p for p in snippet_paths if "user-defined-functions" not in str(p)] + + +@pytest.fixture(scope="module") +def _change_test_dir() -> Iterator[None]: + """Change path to repo root to accommodate data paths in code snippets.""" + current_path = Path().resolve() + os.chdir(repo_root) + yield + os.chdir(current_path) + + +@pytest.mark.docs +@pytest.mark.parametrize("path", snippet_paths) +@pytest.mark.usefixtures("_change_test_dir") +def test_run_python_snippets(path: Path) -> None: + runpy.run_path(str(path)) diff --git a/py-polars/tests/test_df.py b/py-polars/tests/test_df.py deleted file mode 100644 index 8f948bd16091..000000000000 --- a/py-polars/tests/test_df.py +++ /dev/null @@ -1,140 +0,0 @@ -from pypolars import DataFrame, Series -import pytest - - -def test_init(): - df = DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) - - # length mismatch - with pytest.raises(RuntimeError): - df = DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0, 4.0]}) - - -def test_selection(): - df = DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]}) - - assert df["a"].dtype == "i64" - assert df["b"].dtype == "f64" - assert df["c"].dtype == "str" - - assert df[["a", "b"]].columns == ["a", "b"] - assert df[[True, False, True]].height == 2 - - assert df[[True, False, True], "b"].shape == (2, 1) - assert df[[True, False, False], ["a", "b"]].shape == (1, 2) - - assert df[[0, 1], "b"].shape == (2, 1) - assert df[[2], ["a", "b"]].shape == (1, 2) - assert df.select_at_idx(0).name == "a" - assert (df.a == df["a"]).sum() == 3 - assert (df.c == df["a"]).sum() == 0 - - -def test_sort(): - df = DataFrame({"a": [2, 1, 3], "b": [1, 2, 3]}) - df.sort("a", in_place=True) - assert df.frame_equal(DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]})) - - -def test_replace(): - df = DataFrame({"a": [2, 1, 3], "b": [1, 2, 3]}) - s = Series("c", [True, False, True]) - df.replace("a", s) - assert df.frame_equal(DataFrame({"c": [True, False, True], "b": [1, 2, 3]})) - - -def test_slice(): - df = DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"]}) - df = df.slice(1, 2) - assert df.frame_equal(DataFrame({"a": [1, 3], "b": ["b", "c"]})) - - -def test_head_tail(): - df = DataFrame({"a": range(10), "b": range(10)}) - assert df.head(5).height == 5 - assert df.tail(5).height == 5 - - assert not df.head(5).frame_equal(df.tail(5)) - # check if it doesn't fail when out of bounds - assert df.head(100).height == 10 - assert df.tail(100).height == 10 - - -def test_groupby(): - df = DataFrame( - { - "a": ["a", "b", "a", "b", "b", "c"], - "b": [1, 2, 3, 4, 5, 6], - "c": [6, 5, 4, 3, 2, 1], - } - ) - assert ( - df.groupby("a") - .select("b") - .sum() - .frame_equal(DataFrame({"a": ["a", "b", "c"], "": [4, 11, 6]})) - ) - assert ( - df.groupby("a") - .select("c") - .sum() - .frame_equal(DataFrame({"a": ["a", "b", "c"], "": [10, 10, 1]})) - ) - assert ( - df.groupby("a") - .select("b") - .min() - .frame_equal(DataFrame({"a": ["a", "b", "c"], "": [1, 2, 6]})) - ) - assert ( - df.groupby("a") - .select("b") - .max() - .frame_equal(DataFrame({"a": ["a", "b", "c"], "": [3, 5, 6]})) - ) - assert ( - df.groupby("a") - .select("b") - .mean() - .frame_equal(DataFrame({"a": ["a", "b", "c"], "": [2.0, (2 + 4 + 5) / 3, 6.0]})) - ) - # - # # TODO: is false because count is u32 - # df.groupby(by="a", select="b", agg="count").frame_equal( - # DataFrame({"a": ["a", "b", "c"], "": [2, 3, 1]}) - # ) - - -def test_join(): - df_left = DataFrame( - {"a": ["a", "b", "a", "z"], "b": [1, 2, 3, 4], "c": [6, 5, 4, 3],} - ) - df_right = DataFrame( - {"a": ["b", "c", "b", "a"], "k": [0, 3, 9, 6], "c": [1, 0, 2, 1],} - ) - - joined = df_left.join(df_right, left_on="a", right_on="a").sort("a") - assert joined["b"].series_equal(Series("", [1, 3, 2, 2])) - joined = df_left.join(df_right, left_on="a", right_on="a", how="left").sort("a") - assert joined["c_right"].is_null().sum() == 1 - assert joined["b"].series_equal(Series("", [1, 3, 2, 2, 4])) - joined = df_left.join(df_right, left_on="a", right_on="a", how="outer").sort("a") - assert joined["c_right"].null_count() == 1 - assert joined["c"].null_count() == 2 - assert joined["b"].null_count() == 2 - - -def test_hstack(): - df = DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"]}) - df.hstack([Series("stacked", [-1, -1, -1])]) - assert df.shape == (3, 3) - assert df.columns == ["a", "b", "stacked"] - - -def test_drop(): - df = DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"], "c": [1, 2, 3]}) - df = df.drop("a") - assert df.shape == (3, 2) - df = DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"], "c": [1, 2, 3]}) - s = df.drop_in_place("a") - assert s.name == "a" diff --git a/py-polars/tests/test_series.py b/py-polars/tests/test_series.py deleted file mode 100644 index 7500852e04a0..000000000000 --- a/py-polars/tests/test_series.py +++ /dev/null @@ -1,177 +0,0 @@ -from pypolars import Series -from pypolars.ffi import aligned_array_f32 -import numpy as np - - -def create_series(): - return Series("a", [1, 2]) - - -def test_equality(): - a = create_series() - b = a - - cmp = a == b - assert isinstance(cmp, Series) - assert cmp.sum() == 2 - assert (a != b).sum() == 0 - assert (a >= b).sum() == 2 - assert (a <= b).sum() == 2 - assert (a > b).sum() == 0 - assert (a < b).sum() == 0 - assert a.sum() == 3 - assert a.series_equal(b) - - a = Series("name", ["ham", "foo", "bar"]) - assert (a == "ham").to_list() == [True, False, False] - - -def test_agg(): - a = create_series() - assert a.mean() == 1.5 - assert a.min() == 1 - assert a.max() == 2 - - -def test_arithmetic(): - a = create_series() - b = a - - assert ((a * b) == [1, 4]).sum() == 2 - assert ((a / b) == [1, 1]).sum() == 2 - assert ((a + b) == [2, 4]).sum() == 2 - assert ((a - b) == [0, 0]).sum() == 2 - assert ((a + 1) == [2, 3]).sum() == 2 - assert ((a - 1) == [0, 1]).sum() == 2 - assert ((a / 1) == [1, 2]).sum() == 2 - assert ((a * 2) == [2, 4]).sum() == 2 - assert ((1 + a) == [2, 3]).sum() == 2 - assert ((1 - a) == [0, -1]).sum() == 2 - assert ((1 * a) == [1, 2]).sum() == 2 - # integer division - assert ((1 / a) == [1, 0]).sum() == 2 - - -def test_various(): - a = create_series() - - assert a.is_null().sum() == 0 - assert a.name == "a" - a.rename("b") - assert a.name == "b" - assert a.len() == 2 - assert len(a) == 2 - b = a.slice(1, 1) - assert b.len() == 1 - assert b.series_equal(Series("", [2])) - a.append(b) - assert a.series_equal(Series("", [1, 2, 2])) - - a = Series("a", range(20)) - assert a.head(5).len() == 5 - assert a.tail(5).len() == 5 - assert a.head(5) != a.tail(5) - - a = Series("a", [2, 1, 4]) - a.sort(in_place=True) - assert a.series_equal(Series("", [1, 2, 4])) - a = Series("a", [2, 1, 1, 4, 4, 4]) - assert list(a.arg_unique()) == [0, 1, 3] - - assert a.take([2, 3]).series_equal(Series("", [1, 4])) - - -def test_filter(): - a = Series("a", range(20)) - assert a[a > 1].len() == 18 - assert a[a < 1].len() == 1 - assert a[a <= 1].len() == 2 - assert a[a >= 1].len() == 19 - assert a[a == 1].len() == 1 - assert a[a != 1].len() == 19 - - -def test_cast(): - a = Series("a", range(20)) - - assert a.cast_f32().dtype == "f32" - assert a.cast_f64().dtype == "f64" - assert a.cast_i32().dtype == "i32" - assert a.cast_u32().dtype == "u32" - assert a.cast_date64().dtype == "date64" - assert a.cast_time64ns().dtype == "time64(ns)" - assert a.cast_date32().dtype == "date32" - - -def test_to_python(): - a = Series("a", range(20)) - b = a.to_list() - assert isinstance(b, list) - assert len(b) == 20 - - a = Series("a", [1, None, 2], nullable=True) - assert a.null_count() == 1 - assert a.to_list() == [1, None, 2] - - -def test_sort(): - a = Series("a", [2, 1, 3]) - assert a.sort().to_list() == [1, 2, 3] - assert a.sort(reverse=True) == [3, 2, 1] - - -def test_rechunk(): - a = Series("a", [1, 2, 3]) - b = Series("b", [4, 5, 6]) - a.append(b) - assert a.n_chunks() == 2 - assert a.rechunk(in_place=False).n_chunks() == 1 - a.rechunk(in_place=True) - assert a.n_chunks() == 1 - - -def test_view(): - a = Series("a", [1.0, 2.0, 3.0]) - assert isinstance(a.view(), np.ndarray) - assert np.all(a.view() == np.array([1, 2, 3])) - - -def test_numpy_interface(): - # this isn't used anymore. - a, ptr = aligned_array_f32(10) - assert a.dtype == np.float32 - assert a.shape == (10,) - pointer, read_only_flag = a.__array_interface__["data"] - # set read only flag to False - a.__array_interface__["data"] = (pointer, False) - # the __array_interface is used to create a new array (pointing to the same memory) - b = np.array(a) - # now the memory is writeable - b[0] = 1 - - # TODO: sent pointer to Rust and take ownership of array. - # https://stackoverflow.com/questions/37988849/safer-way-to-expose-a-c-allocated-memory-buffer-using-numpy-ctypes - - -def test_ufunc(): - a = Series("a", [1.0, 2.0, 3.0, 4.0]) - b = np.multiply(a, 4) - assert isinstance(b, Series) - assert b == [4, 8, 12, 16] - - # test if null bitmask is preserved - a = Series("a", [1.0, None, 3.0], nullable=True) - b = np.exp(a) - assert b.null_count() == 1 - - -def test_get(): - a = Series("a", [1, 2, 3]) - assert a[0] == 1 - assert a[:2] == [1, 2] - - -def test_fill_none(): - a = Series("a", [1, 2, None], nullable=True) - b = a.fill_none("forward") - assert b == [1, 2, 2] diff --git a/py-polars/tests/unit/__init__.py b/py-polars/tests/unit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/cloud/__init__.py b/py-polars/tests/unit/cloud/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py b/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py new file mode 100644 index 000000000000..b0dd40ee8086 --- /dev/null +++ b/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py @@ -0,0 +1,100 @@ +from io import BytesIO +from pathlib import Path + +import pytest + +import polars as pl +from polars._utils.cloud import prepare_cloud_plan +from polars.exceptions import ComputeError, InvalidOperationError + +CLOUD_SOURCE = "s3://my-nonexistent-bucket/dataset" + + +@pytest.mark.parametrize( + "lf", + [ + pl.scan_parquet(CLOUD_SOURCE).select("c", pl.lit(2)).with_row_index(), + pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + .select("a", "b") + .filter(pl.col("a") == pl.lit(1)), + ], +) +def test_prepare_cloud_plan(lf: pl.LazyFrame) -> None: + result = prepare_cloud_plan(lf) + assert isinstance(result, bytes) + + deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) + + +@pytest.mark.parametrize( + "lf", + [ + pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).select( + pl.col("a").map_elements(lambda x: sum(x)) + ), + pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).select( + pl.col("b").map_batches(lambda x: sum(x)) + ), + pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).map_batches(lambda x: x), + pl.scan_parquet(CLOUD_SOURCE).filter( + pl.col("a") < pl.lit(1).map_elements(lambda x: x + 1) + ), + pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select( + pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int64) + ), + ], +) +def test_prepare_cloud_plan_udf(lf: pl.LazyFrame) -> None: + result = prepare_cloud_plan(lf) + assert isinstance(result, bytes) + + deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) + + +def test_prepare_cloud_plan_optimization_toggle() -> None: + lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + + with pytest.raises(TypeError, match="unexpected keyword argument"): + prepare_cloud_plan(lf, nonexistent_optimization=False) + + result = prepare_cloud_plan(lf, projection_pushdown=False) + assert isinstance(result, bytes) + + # TODO: How to check that this optimization was toggled correctly? + deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) + + +@pytest.mark.parametrize( + "lf", + [ + pl.scan_parquet("data.parquet"), + pl.scan_ndjson(Path("data.ndjson")), + pl.scan_csv("data-*.csv"), + pl.scan_ipc(["data-1.feather", "data-2.feather"]), + ], +) +def test_prepare_cloud_plan_fail_on_local_data_source(lf: pl.LazyFrame) -> None: + with pytest.raises( + InvalidOperationError, + match="logical plan ineligible for execution on Polars Cloud", + ): + prepare_cloud_plan(lf) + + +@pytest.mark.parametrize( + "lf", + [ + pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( + pl.col("a").name.map(lambda x: x.upper()) + ), + pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( + pl.col("a").name.map_fields(lambda x: x.upper()) + ), + ], +) +def test_prepare_cloud_plan_fail_on_serialization(lf: pl.LazyFrame) -> None: + with pytest.raises(ComputeError, match="serialization not supported"): + prepare_cloud_plan(lf) diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py new file mode 100644 index 000000000000..5c56ecf55e13 --- /dev/null +++ b/py-polars/tests/unit/conftest.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import gc +import os +import random +import string +import sys +from contextlib import contextmanager +from functools import wraps +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +import pytest + +import polars as pl +from polars.testing.parametric import load_profile + +if TYPE_CHECKING: + from collections.abc import Generator + from types import ModuleType + from typing import Any + + FixtureRequest = Any + +load_profile( + profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] +) + +# Data type groups +SIGNED_INTEGER_DTYPES = [pl.Int8(), pl.Int16(), pl.Int32(), pl.Int64(), pl.Int128()] +UNSIGNED_INTEGER_DTYPES = [pl.UInt8(), pl.UInt16(), pl.UInt32(), pl.UInt64()] +INTEGER_DTYPES = SIGNED_INTEGER_DTYPES + UNSIGNED_INTEGER_DTYPES +FLOAT_DTYPES = [pl.Float32(), pl.Float64()] +NUMERIC_DTYPES = INTEGER_DTYPES + FLOAT_DTYPES + +DATETIME_DTYPES = [pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")] +DURATION_DTYPES = [pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")] +TEMPORAL_DTYPES = [*DATETIME_DTYPES, *DURATION_DTYPES, pl.Date(), pl.Time()] + +NESTED_DTYPES = [pl.List, pl.Struct, pl.Array] + + +@pytest.fixture +def partition_limit() -> int: + """The limit at which Polars will start partitioning in debug builds.""" + return 15 + + +@pytest.fixture +def df() -> pl.DataFrame: + df = pl.DataFrame( + { + "bools": [False, True, False], + "bools_nulls": [None, True, False], + "int": [1, 2, 3], + "int_nulls": [1, None, 3], + "floats": [1.0, 2.0, 3.0], + "floats_nulls": [1.0, None, 3.0], + "strings": ["foo", "bar", "ham"], + "strings_nulls": ["foo", None, "ham"], + "date": [1324, 123, 1234], + "datetime": [13241324, 12341256, 12341234], + "time": [13241324, 12341256, 12341234], + "list_str": [["a", "b", None], ["a"], []], + "list_bool": [[True, False, None], [None], []], + "list_int": [[1, None, 3], [None], []], + "list_flt": [[1.0, None, 3.0], [None], []], + } + ) + return df.with_columns( + pl.col("date").cast(pl.Date), + pl.col("datetime").cast(pl.Datetime), + pl.col("strings").cast(pl.Categorical).alias("cat"), + pl.col("strings").cast(pl.Enum(["foo", "ham", "bar"])).alias("enum"), + pl.col("time").cast(pl.Time), + ) + + +@pytest.fixture +def df_no_lists(df: pl.DataFrame) -> pl.DataFrame: + return df.select( + pl.all().exclude(["list_str", "list_int", "list_bool", "list_int", "list_flt"]) + ) + + +@pytest.fixture +def fruits_cars() -> pl.DataFrame: + return pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + }, + schema_overrides={"A": pl.Int64, "B": pl.Int64}, + ) + + +@pytest.fixture +def str_ints_df() -> pl.DataFrame: + n = 1000 + + strs = pl.Series("strs", random.choices(string.ascii_lowercase, k=n)) + strs = pl.select( + pl.when(strs == "a") + .then(pl.lit("")) + .when(strs == "b") + .then(None) + .otherwise(strs) + .alias("strs") + ).to_series() + + vals = pl.Series("vals", np.random.rand(n)) + + return pl.DataFrame([vals, strs]) + + +ISO8601_FORMATS_DATETIME = [] + +for T in ["T", " "]: + for hms in ( + [ + f"{T}%H:%M:%S", + f"{T}%H%M%S", + f"{T}%H:%M", + f"{T}%H%M", + ] + + [f"{T}%H:%M:%S.{fraction}" for fraction in ["%9f", "%6f", "%3f"]] + + [f"{T}%H%M%S.{fraction}" for fraction in ["%9f", "%6f", "%3f"]] + + [""] + ): + for date_sep in ("/", "-"): + fmt = f"%Y{date_sep}%m{date_sep}%d{hms}" + ISO8601_FORMATS_DATETIME.append(fmt) + + +@pytest.fixture(params=ISO8601_FORMATS_DATETIME) +def iso8601_format_datetime(request: pytest.FixtureRequest) -> list[str]: + return cast(list[str], request.param) + + +ISO8601_TZ_AWARE_FORMATS_DATETIME = [] + +for T in ["T", " "]: + for hms in ( + [ + f"{T}%H:%M:%S", + f"{T}%H%M%S", + f"{T}%H:%M", + f"{T}%H%M", + ] + + [f"{T}%H:%M:%S.{fraction}" for fraction in ["%9f", "%6f", "%3f"]] + + [f"{T}%H%M%S.{fraction}" for fraction in ["%9f", "%6f", "%3f"]] + ): + for date_sep in ("/", "-"): + fmt = f"%Y{date_sep}%m{date_sep}%d{hms}%#z" + ISO8601_TZ_AWARE_FORMATS_DATETIME.append(fmt) + + +@pytest.fixture(params=ISO8601_TZ_AWARE_FORMATS_DATETIME) +def iso8601_tz_aware_format_datetime(request: pytest.FixtureRequest) -> list[str]: + return cast(list[str], request.param) + + +ISO8601_FORMATS_DATE = [] + +for date_sep in ("/", "-"): + fmt = f"%Y{date_sep}%m{date_sep}%d" + ISO8601_FORMATS_DATE.append(fmt) + + +@pytest.fixture(params=ISO8601_FORMATS_DATE) +def iso8601_format_date(request: pytest.FixtureRequest) -> list[str]: + return cast(list[str], request.param) + + +class MemoryUsage: + """ + Provide an API for measuring peak memory usage. + + Memory from PyArrow is not tracked at the moment. + """ + + def reset_tracking(self) -> None: + """Reset tracking to zero.""" + # gc.collect() + # tracemalloc.stop() + # tracemalloc.start() + # assert self.get_peak() < 100_000 + + def get_current(self) -> int: + """ + Return currently allocated memory, in bytes. + + This only tracks allocations since this object was created or + ``reset_tracking()`` was called, whichever is later. + """ + return 0 + # tracemalloc.get_traced_memory()[0] + + def get_peak(self) -> int: + """ + Return peak allocated memory, in bytes. + + This returns peak allocations since this object was created or + ``reset_tracking()`` was called, whichever is later. + """ + return 0 + # tracemalloc.get_traced_memory()[1] + + +# The bizarre syntax is from +# https://github.com/pytest-dev/pytest/issues/1368#issuecomment-2344450259 - we +# need to mark any test using this fixture as slow because we have a sleep +# added to work around a CPython bug, see the end of the function. +@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)]) +def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]: + """ + Provide an API for measuring peak memory usage. + + Not thread-safe: there should only be one instance of MemoryUsage at any + given time. + + Memory usage from PyArrow is not tracked. + """ + if not pl.polars._debug: # type: ignore[attr-defined] + pytest.skip("Memory usage only available in debug/dev builds.") + + if os.getenv("POLARS_FORCE_ASYNC", "0") == "1": + pytest.skip("Hangs when combined with async glob") + + if sys.platform == "win32": + # abi3 wheels don't have the tracemalloc C APIs, which breaks linking + # on Windows. + pytest.skip("Windows not supported at the moment.") + + gc.collect() + try: + yield MemoryUsage() + finally: + gc.collect() + # gc.collect() + # tracemalloc.start() + # try: + # yield MemoryUsage() + # finally: + # # Workaround for https://github.com/python/cpython/issues/128679 + # time.sleep(1) + # gc.collect() + # + # tracemalloc.stop() + + +@pytest.fixture(params=[True, False]) +def test_global_and_local( + request: FixtureRequest, +) -> Generator[Any, Any, Any]: + """ + Setup fixture which runs each test with and without global string cache. + + Usage: @pytest.mark.usefixtures("test_global_and_local") + """ + use_global = request.param + if use_global: + with pl.StringCache(): + # Pre-fill some global items to ensure physical repr isn't 0..n. + pl.Series(["eapioejf", "2m4lmv", "3v3v9dlf"], dtype=pl.Categorical) + yield + else: + yield + + +@contextmanager +def mock_module_import( + name: str, module: ModuleType, *, replace_if_exists: bool = False +) -> Generator[None, None, None]: + """ + Mock an optional module import for the duration of a context. + + Parameters + ---------- + name + The name of the module to mock. + module + A ModuleType instance representing the mocked module. + replace_if_exists + Whether to replace the module if it already exists in `sys.modules` (defaults to + False, meaning that if the module is already imported, it will not be replaced). + """ + if (original := sys.modules.get(name, None)) is not None and not replace_if_exists: + yield + else: + sys.modules[name] = module + try: + yield + finally: + if original is not None: + sys.modules[name] = original + else: + del sys.modules[name] + + +# The new streaming engine currently only works if you keep the same string cache +# alive the entire time. +def with_string_cache_if_auto_streaming(f: Any) -> Any: + if ( + os.getenv("POLARS_AUTO_NEW_STREAMING", os.getenv("POLARS_FORCE_NEW_STREAMING")) + != "1" + ): + return f + + @wraps(f) + def with_cache(*args: Any, **kwargs: Any) -> Any: + with pl.StringCache(): + return f(*args, **kwargs) + + return with_cache diff --git a/py-polars/tests/unit/constructors/__init__.py b/py-polars/tests/unit/constructors/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py new file mode 100644 index 000000000000..a197f3705443 --- /dev/null +++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py @@ -0,0 +1,414 @@ +# TODO: Replace direct calls to fallback constructors with calls to the Series +# constructor once the Python-side logic has been updated +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal as D +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars._utils.wrap import wrap_s +from polars.polars import PySeries +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + ("dtype", "values"), + [ + (pl.Int64, [-1, 0, 100_000, None]), + (pl.Float64, [-1.5, 0.0, 10.0, None]), + (pl.Boolean, [True, False, None]), + (pl.Binary, [b"123", b"xyz", None]), + (pl.String, ["123", "xyz", None]), + (pl.Date, [date(1970, 1, 1), date(2020, 12, 31), None]), + (pl.Time, [time(0, 0), time(23, 59, 59), None]), + (pl.Datetime, [datetime(1970, 1, 1), datetime(2020, 12, 31, 23, 59, 59), None]), + (pl.Duration, [timedelta(hours=0), timedelta(seconds=100), None]), + (pl.Categorical, ["a", "b", "a", None]), + (pl.Enum(["a", "b"]), ["a", "b", "a", None]), + (pl.Decimal(10, 3), [D("12.345"), D("0.789"), None]), + ( + pl.Struct({"a": pl.Int8, "b": pl.String}), + [{"a": 1, "b": "foo"}, {"a": -1, "b": "bar"}], + ), + ], +) +@pytest.mark.parametrize("strict", [True, False]) +def test_fallback_with_dtype_strict( + dtype: PolarsDataType, values: list[Any], strict: bool +) -> None: + result = wrap_s( + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=strict) + ) + assert result.to_list() == values + + +@pytest.mark.parametrize( + ("dtype", "values"), + [ + (pl.Int64, [1.0, 2.0]), + (pl.Float64, [1, 2]), + (pl.Boolean, [0, 1]), + (pl.Binary, ["123", "xyz"]), + (pl.String, [b"123", b"xyz"]), + (pl.Date, [datetime(1970, 1, 1), datetime(2020, 12, 31, 23, 59, 59)]), + (pl.Time, [datetime(1970, 1, 1), datetime(2020, 12, 31, 23, 59, 59)]), + (pl.Datetime, [date(1970, 1, 1), date(2020, 12, 31)]), + (pl.Datetime("ms"), [datetime(1970, 1, 1), datetime(2020, 12, 31, 23, 59, 59)]), + (pl.Datetime("ns"), [datetime(1970, 1, 1), datetime(2020, 12, 31, 23, 59, 59)]), + (pl.Duration, [0, 1200]), + (pl.Duration("ms"), [timedelta(hours=0), timedelta(seconds=100)]), + (pl.Duration("ns"), [timedelta(hours=0), timedelta(seconds=100)]), + (pl.Categorical, [0, 1, 0]), + (pl.Enum(["a", "b"]), [0, 1, 0]), + (pl.Decimal(10, 3), [100, 200]), + (pl.Decimal(5, 3), [D("1.2345")]), + ( + pl.Struct({"a": pl.Int8, "b": pl.String}), + [{"a": 1, "b": "foo"}, {"a": 2.0, "b": "bar"}], + ), + ], +) +def test_fallback_with_dtype_strict_failure( + dtype: PolarsDataType, values: list[Any] +) -> None: + with pytest.raises(TypeError, match="unexpected value"): + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) + + +@pytest.mark.parametrize( + ("dtype", "values", "expected"), + [ + ( + pl.Int64, + [False, True, 0, -1, 0.0, 2.5, date(1970, 1, 2), "5", "xyz"], + [0, 1, 0, -1, 0, 2, 1, 5, None], + ), + ( + pl.Float64, + [False, True, 0, -1, 0.0, 2.5, date(1970, 1, 2), "5", "xyz"], + [0.0, 1.0, 0.0, -1.0, 0.0, 2.5, 1.0, 5.0, None], + ), + ( + pl.Boolean, + [False, True, 0, -1, 0.0, 2.5, date(1970, 1, 1), "true"], + [False, True, False, True, False, True, None, None], + ), + ( + pl.Binary, + [b"123", "xyz", 100, True, None], + [b"123", b"xyz", None, None, None], + ), + ( + pl.String, + ["xyz", 1, 2.5, date(1970, 1, 1), True, b"123", None], + ["xyz", "1", "2.5", "1970-01-01", "true", None, None], + ), + ( + pl.Date, + ["xyz", 1, 2.5, date(1970, 1, 1), datetime(2000, 1, 1, 12), True, None], + [ + None, + date(1970, 1, 2), + date(1970, 1, 3), + date(1970, 1, 1), + date(2000, 1, 1), + None, + None, + ], + ), + ( + pl.Time, + [ + "xyz", + 1, + 2.5, + date(1970, 1, 1), + time(12, 0), + datetime(2000, 1, 1, 12), + timedelta(hours=5), + True, + None, + ], + [ + None, + time(0, 0), + time(0, 0), + None, + time(12, 0), + time(12, 0), + None, + None, + None, + ], + ), + ( + pl.Datetime, + [ + "xyz", + 1, + 2.5, + date(1970, 1, 1), + time(12, 0), + datetime(2000, 1, 1, 12), + timedelta(hours=5), + True, + None, + ], + [ + None, + datetime(1970, 1, 1, microsecond=1), + datetime(1970, 1, 1, microsecond=2), + datetime(1970, 1, 1), + None, + datetime(2000, 1, 1, 12, 0), + None, + None, + None, + ], + ), + ( + pl.Duration, + [ + "xyz", + 1, + 2.5, + date(1970, 1, 1), + time(12, 0), + datetime(2000, 1, 1, 12), + timedelta(hours=5), + True, + None, + ], + [ + None, + timedelta(microseconds=1), + timedelta(microseconds=2), + None, + timedelta(hours=12), + None, + timedelta(hours=5), + None, + None, + ], + ), + ( + pl.Categorical, + ["xyz", 1, 2.5, date(1970, 1, 1), True, b"123", None], + ["xyz", "1", "2.5", "1970-01-01", "true", None, None], + ), + ( + pl.Enum(["a", "b"]), + ["a", "b", "c", 1, 2, None], + ["a", "b", None, None, None, None], + ), + ( + pl.Decimal(5, 3), + [ + D("12"), + D("1.2345"), + # D("123456"), + False, + True, + 0, + -1, + 0.0, + 2.5, + date(1970, 1, 2), + "5", + "xyz", + ], + [ + D("12.000"), + None, + # None, + None, + None, + D("0.000"), + D("-1.000"), + None, + None, + None, + None, + None, + ], + ), + ( + pl.Struct({"a": pl.Int8, "b": pl.String}), + [{"a": 1, "b": "foo"}, {"a": 1_000, "b": 2.0}], + [{"a": 1, "b": "foo"}, {"a": None, "b": "2.0"}], + ), + ], +) +def test_fallback_with_dtype_nonstrict( + dtype: PolarsDataType, values: list[Any], expected: list[Any] +) -> None: + result = wrap_s( + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=False) + ) + assert result.to_list() == expected + + +@pytest.mark.parametrize( + ("expected_dtype", "values"), + [ + (pl.Int64, [-1, 0, 100_000, None]), + (pl.Float64, [-1.5, 0.0, 10.0, None]), + (pl.Boolean, [True, False, None]), + (pl.Binary, [b"123", b"xyz", None]), + (pl.String, ["123", "xyz", None]), + (pl.Date, [date(1970, 1, 1), date(2020, 12, 31), None]), + (pl.Time, [time(0, 0), time(23, 59, 59), None]), + ( + pl.Datetime("us"), + [datetime(1970, 1, 1), datetime(2020, 12, 31, 23, 59, 59), None], + ), + (pl.Duration("us"), [timedelta(hours=0), timedelta(seconds=100), None]), + (pl.Decimal(None, 3), [D("12.345"), D("0.789"), None]), + (pl.Decimal(None, 0), [D("12"), D("56789"), None]), + ( + pl.Struct({"a": pl.Int64, "b": pl.String, "c": pl.Float64}), + [{"a": 1, "b": "foo", "c": None}, {"a": -1, "b": "bar", "c": 3.0}], + ), + ], +) +@pytest.mark.parametrize("strict", [True, False]) +def test_fallback_without_dtype( + expected_dtype: PolarsDataType, values: list[Any], strict: bool +) -> None: + result = wrap_s(PySeries.new_from_any_values("", values, strict=strict)) + assert result.to_list() == values + assert result.dtype == expected_dtype + + +@pytest.mark.parametrize( + "values", + [ + [1.0, 2], + [1, 2.0], + [False, 1], + [b"123", "xyz"], + ["123", b"xyz"], + [date(1970, 1, 1), datetime(2020, 12, 31)], + [time(0, 0), 1_000], + [datetime(1970, 1, 1), date(2020, 12, 31)], + [timedelta(hours=0), 1_000], + [D("12.345"), 100], + [D("12.345"), 3.14], + [{"a": 1, "b": "foo"}, {"a": -1, "b": date(2020, 12, 31)}], + [{"a": None}, {"a": 1.0}, {"a": 1}], + ], +) +def test_fallback_without_dtype_strict_failure(values: list[Any]) -> None: + with pytest.raises(TypeError, match="unexpected value"): + PySeries.new_from_any_values("", values, strict=True) + + +@pytest.mark.parametrize( + ("values", "expected", "expected_dtype"), + [ + ([True, 2], [1, 2], pl.Int64), + ([1, 2.0], [1.0, 2.0], pl.Float64), + ([2.0, "c"], ["2.0", "c"], pl.String), + ( + [date(1970, 1, 1), datetime(2022, 12, 31)], + [datetime(1970, 1, 1), datetime(2022, 12, 31)], + pl.Datetime("us"), + ), + ([D("3.1415"), 2.51], [3.1415, 2.51], pl.Float64), + ([D("3.1415"), 100], [D("3.1415"), D("100")], pl.Decimal(None, 4)), + ([1, 2.0, b"d", date(2022, 1, 1)], [1, 2.0, b"d", date(2022, 1, 1)], pl.Object), + ( + [ + {"a": 1, "b": "foo", "c": None}, + {"a": 2.0, "b": date(2020, 12, 31), "c": None}, + ], + [ + {"a": 1.0, "b": "foo", "c": None}, + {"a": 2.0, "b": "2020-12-31", "c": None}, + ], + pl.Struct({"a": pl.Float64, "b": pl.String, "c": pl.Null}), + ), + ( + [{"a": None}, {"a": 1.0}, {"a": 1}], + [{"a": None}, {"a": 1.0}, {"a": 1.0}], + pl.Struct({"a": pl.Float64}), + ), + ], +) +def test_fallback_without_dtype_nonstrict_mixed_types( + values: list[Any], + expected_dtype: PolarsDataType, + expected: list[Any], +) -> None: + result = wrap_s(PySeries.new_from_any_values("", values, strict=False)) + assert result.dtype == expected_dtype + assert result.to_list() == expected + + +def test_fallback_without_dtype_large_int() -> None: + values = [1, 2**128, None] + with pytest.raises( + OverflowError, + match="int value too large for Polars integer types", + ): + PySeries.new_from_any_values("", values, strict=True) + + result = wrap_s(PySeries.new_from_any_values("", values, strict=False)) + assert result.dtype == pl.Float64 + assert result.to_list() == [1.0, 340282366920938500000000000000000000000.0, None] + + +def test_fallback_with_dtype_large_int() -> None: + values = [1, 2**128, None] + with pytest.raises(OverflowError): + PySeries.new_from_any_values_and_dtype("", values, dtype=pl.Int128, strict=True) + + result = wrap_s( + PySeries.new_from_any_values_and_dtype( + "", values, dtype=pl.Int128, strict=False + ) + ) + assert result.dtype == pl.Int128 + assert result.to_list() == [1, None, None] + + +def test_fallback_with_dtype_strict_failure_enum_casting() -> None: + dtype = pl.Enum(["a", "b"]) + values = ["a", "b", "c", None] + + with pytest.raises( + TypeError, match="cannot append 'c' to enum without that variant" + ): + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) + + +def test_fallback_with_dtype_strict_failure_decimal_precision() -> None: + dtype = pl.Decimal(3, 0) + values = [D("12345")] + + with pytest.raises( + TypeError, match="decimal precision 3 can't fit values with 5 digits" + ): + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.may_fail_auto_streaming +def test_categorical_lit_18874() -> None: + assert_frame_equal( + pl.DataFrame( + {"a": [1, 2, 3]}, + ).with_columns(b=pl.lit("foo").cast(pl.Categorical)), + pl.DataFrame( + [ + pl.Series("a", [1, 2, 3]), + pl.Series("b", ["foo"] * 3, pl.Categorical), + ] + ), + ) diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py new file mode 100644 index 000000000000..f6c913e8f2f5 --- /dev/null +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -0,0 +1,1857 @@ +from __future__ import annotations + +from collections import OrderedDict, namedtuple +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from random import shuffle +from typing import TYPE_CHECKING, Any, Literal, NamedTuple +from zoneinfo import ZoneInfo + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +from packaging.version import parse as parse_version +from pydantic import BaseModel, Field, TypeAdapter + +import polars as pl +import polars.selectors as cs +from polars._utils.construction.utils import try_get_type_hints +from polars.datatypes import numpy_char_code_to_dtype +from polars.dependencies import dataclasses, pydantic +from polars.exceptions import DuplicateError, ShapeError +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.utils.pycapsule_utils import PyCapsuleArrayHolder, PyCapsuleStreamHolder + +if TYPE_CHECKING: + import sys + from collections.abc import Callable + + from polars._typing import PolarsDataType + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + +# ----------------------------------------------------------------------------------- +# nested dataclasses, models, namedtuple classes (can't be defined inside test func) +# ----------------------------------------------------------------------------------- +@dataclasses.dataclass +class _TestBazDC: + d: datetime + e: float + f: str + + +@dataclasses.dataclass +class _TestBarDC: + a: str + b: int + c: _TestBazDC + + +@dataclasses.dataclass +class _TestFooDC: + x: int + y: _TestBarDC + + +class _TestBazPD(pydantic.BaseModel): + d: datetime + e: float + f: str + + +class _TestBarPD(pydantic.BaseModel): + a: str + b: int + c: _TestBazPD + + +class _TestFooPD(pydantic.BaseModel): + x: int + y: _TestBarPD + + +class _TestBazNT(NamedTuple): + d: datetime + e: float + f: str + + +class _TestBarNT(NamedTuple): + a: str + b: int + c: _TestBazNT + + +class _TestFooNT(NamedTuple): + x: int + y: _TestBarNT + + +# -------------------------------------------------------------------------------- + + +def test_init_dict() -> None: + # Empty dictionary + df = pl.DataFrame({}) + assert df.shape == (0, 0) + + # Empty dictionary/values + df = pl.DataFrame({"a": [], "b": []}) + assert df.shape == (0, 2) + assert df.schema == {"a": pl.Null, "b": pl.Null} + + for df in ( + pl.DataFrame({}, schema={"a": pl.Date, "b": pl.String}), + pl.DataFrame({"a": [], "b": []}, schema={"a": pl.Date, "b": pl.String}), + ): + assert df.shape == (0, 2) + assert df.schema == {"a": pl.Date, "b": pl.String} + + # List of empty list + df = pl.DataFrame({"a": [[]], "b": [[]]}) + expected = {"a": pl.List(pl.Null), "b": pl.List(pl.Null)} + assert df.schema == expected + assert df.rows() == [([], [])] + + # Mixed dtypes + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + assert df.shape == (3, 2) + assert df.columns == ["a", "b"] + assert df.dtypes == [pl.Int64, pl.Float64] + + df = pl.DataFrame( + data={"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}, + schema=[("a", pl.Int8), ("b", pl.Float32)], + ) + assert df.schema == {"a": pl.Int8, "b": pl.Float32} + + # Values contained in tuples + df = pl.DataFrame({"a": (1, 2, 3), "b": [1.0, 2.0, 3.0]}) + assert df.shape == (3, 2) + + # Datetime/Date types (from both python and integer values) + py_datetimes = ( + datetime(2022, 12, 31, 23, 59, 59), + datetime(2022, 12, 31, 23, 59, 59), + ) + py_dates = (date(2022, 12, 31), date(2022, 12, 31)) + int_datetimes = [1672531199000000, 1672531199000000] + int_dates = [19357, 19357] + + for dates, datetimes, coldefs in ( + # test inferred and explicit (given both py/polars dtypes) + (py_dates, py_datetimes, None), + (py_dates, py_datetimes, [("dt", date), ("dtm", datetime)]), + (py_dates, py_datetimes, [("dt", pl.Date), ("dtm", pl.Datetime)]), + (int_dates, int_datetimes, [("dt", date), ("dtm", datetime)]), + (int_dates, int_datetimes, [("dt", pl.Date), ("dtm", pl.Datetime)]), + ): + df = pl.DataFrame( + data={"dt": dates, "dtm": datetimes}, + schema=coldefs, + ) + assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")} + assert df.rows() == list(zip(py_dates, py_datetimes)) + + # Overriding dict column names/types + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, schema=["c", "d"]) + assert df.columns == ["c", "d"] + + df = pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6]}, + schema=["c", ("d", pl.Int8)], + ) # partial type info (allowed, but mypy doesn't like it ;p) + assert df.schema == {"c": pl.Int64, "d": pl.Int8} + + df = pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6]}, schema=[("c", pl.Int8), ("d", pl.Int16)] + ) + assert df.schema == {"c": pl.Int8, "d": pl.Int16} + + # empty nested objects + for empty_val in [None, "", {}, []]: # type: ignore[var-annotated] + test = [{"field": {"sub_field": empty_val, "sub_field_2": 2}}] + df = pl.DataFrame(test, schema={"field": pl.Object}) + assert df["field"][0] == test[0]["field"] + + +def test_error_string_dtypes() -> None: + with pytest.raises(TypeError, match="cannot parse input"): + pl.DataFrame( + data={"x": [1, 2], "y": [3, 4], "z": [5, 6]}, + schema={"x": "i16", "y": "i32", "z": "f32"}, # type: ignore[dict-item] + ) + + with pytest.raises(TypeError, match="cannot parse input"): + pl.Series("n", [1, 2, 3], dtype="f32") # type: ignore[arg-type] + + +def test_init_structured_objects() -> None: + # validate init from dataclass, namedtuple, and pydantic model objects + @dataclasses.dataclass + class TradeDC: + timestamp: datetime + ticker: str + price: Decimal + size: int | None = None + + class TradePD(pydantic.BaseModel): + timestamp: datetime + ticker: str + price: Decimal + size: int + + class TradeNT(NamedTuple): + timestamp: datetime + ticker: str + price: Decimal + size: int | None = None + + raw_data = [ + (datetime(2022, 9, 8, 14, 30, 45), "AAPL", Decimal("157.5"), 125), + (datetime(2022, 9, 9, 10, 15, 12), "FLSY", Decimal("10.0"), 1500), + (datetime(2022, 9, 7, 15, 30), "MU", Decimal("55.5"), 400), + ] + columns = ["timestamp", "ticker", "price", "size"] + + for TradeClass in (TradeDC, TradeNT, TradePD): + trades = [TradeClass(**dict(zip(columns, values))) for values in raw_data] # type: ignore[arg-type] + + for DF in (pl.DataFrame, pl.from_records): + df = DF(data=trades) + assert df.schema == { + "timestamp": pl.Datetime("us"), + "ticker": pl.String, + "price": pl.Decimal(scale=1), + "size": pl.Int64, + } + assert df.rows() == raw_data + + # partial dtypes override + df = DF( + data=trades, + schema_overrides={"timestamp": pl.Datetime("ms"), "size": pl.Int32}, + ) + assert df.schema == { + "timestamp": pl.Datetime("ms"), + "ticker": pl.String, + "price": pl.Decimal(scale=1), + "size": pl.Int32, + } + + # in conjunction with full 'columns' override (rename/downcast) + df = pl.DataFrame( + data=trades, + schema=[ + ("ts", pl.Datetime("ms")), + ("tk", pl.Categorical), + ("pc", pl.Decimal(scale=1)), + ("sz", pl.UInt16), + ], + ) + assert df.schema == { + "ts": pl.Datetime("ms"), + "tk": pl.Categorical(ordering="physical"), + "pc": pl.Decimal(scale=1), + "sz": pl.UInt16, + } + assert df.rows() == raw_data + + # cover a miscellaneous edge-case when detecting the annotations + assert try_get_type_hints(obj=type(None)) == {} + + +def test_init_pydantic_2x() -> None: + class PageView(BaseModel): + user_id: str + ts: datetime = Field(alias=["ts", "$date"]) # type: ignore[literal-required, call-overload] + path: str = Field("?", alias=["url", "path"]) # type: ignore[literal-required, call-overload] + referer: str = Field("?", alias="referer") + event: Literal["leave", "enter"] = Field("enter") + time_on_page: int = Field(0, serialization_alias="top") + + data_json = """ + [{ + "user_id": "x", + "ts": {"$date": "2021-01-01T00:00:00.000Z"}, + "url": "/latest/foobar", + "referer": "https://google.com", + "event": "enter", + "top": 123 + }] + """ + adapter: TypeAdapter[Any] = TypeAdapter(list[PageView]) + models = adapter.validate_json(data_json) + + result = pl.DataFrame(models) + expected = pl.DataFrame( + { + "user_id": ["x"], + "ts": [datetime(2021, 1, 1, 0, 0)], + "path": ["?"], + "referer": ["https://google.com"], + "event": ["enter"], + "time_on_page": [0], + } + ) + assert_frame_equal(result, expected) + + +def test_init_structured_objects_unhashable() -> None: + # cover an edge-case with namedtuple fields that aren't hashable + + class Test(NamedTuple): + dt: datetime + info: dict[str, int] + + test_data = [ + Test(datetime(2017, 1, 1), {"a": 1, "b": 2}), + Test(datetime(2017, 1, 2), {"a": 2, "b": 2}), + ] + df = pl.DataFrame(test_data) + # shape: (2, 2) + # ┌─────────────────────┬───────────┐ + # │ dt ┆ info │ + # │ --- ┆ --- │ + # │ datetime[μs] ┆ struct[2] │ + # ╞═════════════════════╪═══════════╡ + # │ 2017-01-01 00:00:00 ┆ {1,2} │ + # │ 2017-01-02 00:00:00 ┆ {2,2} │ + # └─────────────────────┴───────────┘ + assert df.schema == { + "dt": pl.Datetime(time_unit="us", time_zone=None), + "info": pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.Int64)]), + } + assert df.rows() == test_data + + +@pytest.mark.parametrize( + ("foo", "bar", "baz"), + [ + (_TestFooDC, _TestBarDC, _TestBazDC), + (_TestFooPD, _TestBarPD, _TestBazPD), + (_TestFooNT, _TestBarNT, _TestBazNT), + ], +) +def test_init_structured_objects_nested(foo: Any, bar: Any, baz: Any) -> None: + data = [ + foo( + x=100, + y=bar( + a="hello", + b=800, + c=baz(d=datetime(2023, 4, 12, 10, 30), e=-10.5, f="world"), + ), + ) + ] + df = pl.DataFrame(data) + # shape: (1, 2) + # ┌─────┬───────────────────────────────────┐ + # │ x ┆ y │ + # │ --- ┆ --- │ + # │ i64 ┆ struct[3] │ + # ╞═════╪═══════════════════════════════════╡ + # │ 100 ┆ {"hello",800,{2023-04-12 10:30:0… │ + # └─────┴───────────────────────────────────┘ + + assert df.schema == { + "x": pl.Int64, + "y": pl.Struct( + [ + pl.Field("a", pl.String), + pl.Field("b", pl.Int64), + pl.Field( + "c", + pl.Struct( + [ + pl.Field("d", pl.Datetime("us")), + pl.Field("e", pl.Float64), + pl.Field("f", pl.String), + ] + ), + ), + ] + ), + } + assert df.row(0) == ( + 100, + { + "a": "hello", + "b": 800, + "c": { + "d": datetime(2023, 4, 12, 10, 30), + "e": -10.5, + "f": "world", + }, + }, + ) + + # validate nested schema override + override_struct_schema: dict[str, PolarsDataType] = { + "x": pl.Int16, + "y": pl.Struct( + [ + pl.Field("a", pl.String), + pl.Field("b", pl.Int32), + pl.Field( + name="c", + dtype=pl.Struct( + [ + pl.Field("d", pl.Datetime("ms")), + pl.Field("e", pl.Float32), + pl.Field("f", pl.String), + ] + ), + ), + ] + ), + } + for schema, schema_overrides in ( + (None, override_struct_schema), + (override_struct_schema, None), + ): + df = ( + pl.DataFrame(data, schema=schema, schema_overrides=schema_overrides) + .unnest("y") + .unnest("c") + ) + # shape: (1, 6) + # ┌─────┬───────┬─────┬─────────────────────┬───────┬───────┐ + # │ x ┆ a ┆ b ┆ d ┆ e ┆ f │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ i16 ┆ str ┆ i32 ┆ datetime[ms] ┆ f32 ┆ str │ + # ╞═════╪═══════╪═════╪═════════════════════╪═══════╪═══════╡ + # │ 100 ┆ hello ┆ 800 ┆ 2023-04-12 10:30:00 ┆ -10.5 ┆ world │ + # └─────┴───────┴─────┴─────────────────────┴───────┴───────┘ + assert df.schema == { + "x": pl.Int16, + "a": pl.String, + "b": pl.Int32, + "d": pl.Datetime("ms"), + "e": pl.Float32, + "f": pl.String, + } + assert df.row(0) == ( + 100, + "hello", + 800, + datetime(2023, 4, 12, 10, 30), + -10.5, + "world", + ) + + +def test_dataclasses_initvar_typing() -> None: + @dataclasses.dataclass + class ABC: + x: date + y: float + z: dataclasses.InitVar[list[str]] = None + + # should be able to parse the initvar typing... + abc = ABC(x=date(1999, 12, 31), y=100.0) + df = pl.DataFrame([abc]) + + # ...but should not load the initvar field into the DataFrame + assert dataclasses.asdict(abc) == df.rows(named=True)[0] + + +@pytest.mark.parametrize( + "nt", + [ + namedtuple("TestData", ["id", "info"]), # noqa: PYI024 + NamedTuple("TestData", [("id", int), ("info", str)]), + ], +) +def test_collections_namedtuple(nt: type) -> None: + nt_data = [nt(1, "a"), nt(2, "b"), nt(3, "c")] + + result = pl.DataFrame(nt_data) + expected = pl.DataFrame({"id": [1, 2, 3], "info": ["a", "b", "c"]}) + assert_frame_equal(result, expected) + + result = pl.DataFrame({"data": nt_data, "misc": ["x", "y", "z"]}) + expected = pl.DataFrame( + { + "data": [ + {"id": 1, "info": "a"}, + {"id": 2, "info": "b"}, + {"id": 3, "info": "c"}, + ], + "misc": ["x", "y", "z"], + } + ) + assert_frame_equal(result, expected) + + +def test_init_ndarray() -> None: + # Empty array + df = pl.DataFrame(np.array([])) + assert_frame_equal(df, pl.DataFrame()) + + # 1D array + df = pl.DataFrame(np.array([1, 2, 3], dtype=np.int64), schema=["a"]) + expected = pl.DataFrame({"a": [1, 2, 3]}) + assert_frame_equal(df, expected) + + df = pl.DataFrame(np.array([1, 2, 3]), schema=[("a", pl.Int32)]) + expected = pl.DataFrame({"a": [1, 2, 3]}).with_columns(pl.col("a").cast(pl.Int32)) + assert_frame_equal(df, expected) + + # 2D array (or 2x 1D array) - should default to column orientation (if C-contiguous) + for data in ( + np.array([[1, 2], [3, 4]], dtype=np.int64), + [np.array([1, 2], dtype=np.int64), np.array([3, 4], dtype=np.int64)], + ): + df = pl.DataFrame(data, orient="col") + expected = pl.DataFrame({"column_0": [1, 2], "column_1": [3, 4]}) + assert_frame_equal(df, expected) + + df = pl.DataFrame([[1, 2.0, "a"], [None, None, None]], orient="row") + expected = pl.DataFrame( + {"column_0": [1, None], "column_1": [2.0, None], "column_2": ["a", None]} + ) + assert_frame_equal(df, expected) + + df = pl.DataFrame( + data=[[1, 2.0, "a"], [None, None, None]], + schema=[("x", pl.Boolean), ("y", pl.Int32), "z"], + orient="row", + ) + assert df.rows() == [(True, 2, "a"), (None, None, None)] + assert df.schema == {"x": pl.Boolean, "y": pl.Int32, "z": pl.String} + + # 2D array - default to column orientation + df = pl.DataFrame(np.array([[1, 2], [3, 4]], dtype=np.int64)) + expected = pl.DataFrame({"column_0": [1, 3], "column_1": [2, 4]}) + assert_frame_equal(df, expected) + + # no orientation, numpy convention + df = pl.DataFrame(np.ones((3, 1), dtype=np.int64)) + assert df.shape == (3, 1) + + # 2D array - row orientation inferred + df = pl.DataFrame( + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64), schema=["a", "b", "c"] + ) + expected = pl.DataFrame({"a": [1, 4], "b": [2, 5], "c": [3, 6]}) + assert_frame_equal(df, expected) + + # 2D array - column orientation inferred + df = pl.DataFrame( + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64), schema=["a", "b"] + ) + expected = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + assert_frame_equal(df, expected) + + # List column from 2D array with single-column schema + df = pl.DataFrame(np.arange(4).reshape(-1, 1).astype(np.int64), schema=["a"]) + assert_frame_equal(df, pl.DataFrame({"a": [0, 1, 2, 3]})) + assert np.array_equal(df.to_numpy(), np.arange(4).reshape(-1, 1).astype(np.int64)) + + df = pl.DataFrame(np.arange(4).reshape(-1, 2).astype(np.int64), schema=["a"]) + assert_frame_equal( + df, + pl.DataFrame( + {"a": [[0, 1], [2, 3]]}, schema={"a": pl.Array(pl.Int64, shape=2)} + ), + ) + + # 2D numpy arrays + df = pl.DataFrame({"a": np.arange(5, dtype=np.int64).reshape(1, -1)}) + assert df.dtypes == [pl.Array(pl.Int64, shape=5)] + assert df.shape == (1, 1) + + df = pl.DataFrame({"a": np.arange(10, dtype=np.int64).reshape(2, -1)}) + assert df.dtypes == [pl.Array(pl.Int64, shape=5)] + assert df.shape == (2, 1) + assert df.rows() == [([0, 1, 2, 3, 4],), ([5, 6, 7, 8, 9],)] + + test_rows = [(1, 2), (3, 4)] + df = pl.DataFrame([np.array(test_rows[0]), np.array(test_rows[1])], orient="row") + expected = pl.DataFrame(test_rows, orient="row") + assert_frame_equal(df, expected) + + # round trip export/init + for shape in ((4, 4), (4, 8), (8, 4)): + np_ones = np.ones(shape=shape, dtype=np.float64) + names = [f"c{i}" for i in range(shape[1])] + + df = pl.DataFrame(np_ones, schema=names) + assert_frame_equal(df, pl.DataFrame(np.asarray(df), schema=names)) + + +def test_init_ndarray_errors() -> None: + # 2D array: orientation conflicts with columns + with pytest.raises(ValueError): + pl.DataFrame(np.array([[1, 2, 3], [4, 5, 6]]), schema=["a", "b"], orient="row") + + with pytest.raises(ValueError): + pl.DataFrame( + np.array([[1, 2, 3], [4, 5, 6]]), + schema=[("a", pl.UInt32), ("b", pl.UInt32)], + orient="row", + ) + + # Invalid orient value + with pytest.raises(ValueError): + pl.DataFrame( + np.array([[1, 2, 3], [4, 5, 6]]), + orient="wrong", # type: ignore[arg-type] + ) + + # Dimensions mismatch + with pytest.raises(ValueError): + _ = pl.DataFrame(np.array([1, 2, 3]), schema=[]) + + # Cannot init with 3D array + with pytest.raises(ValueError): + _ = pl.DataFrame(np.random.randn(2, 2, 2)) + + +def test_init_ndarray_nan() -> None: + # numpy arrays containing NaN + df0 = pl.DataFrame( + data={"x": [1.0, 2.5, float("nan")], "y": [4.0, float("nan"), 6.5]}, + ) + df1 = pl.DataFrame( + data={"x": np.array([1.0, 2.5, np.nan]), "y": np.array([4.0, np.nan, 6.5])}, + ) + df2 = pl.DataFrame( + data={"x": np.array([1.0, 2.5, np.nan]), "y": np.array([4.0, np.nan, 6.5])}, + nan_to_null=True, + ) + assert_frame_equal(df0, df1) + assert df2.rows() == [(1.0, 4.0), (2.5, None), (None, 6.5)] + + s0 = pl.Series("n", [1.0, 2.5, float("nan")]) + s1 = pl.Series("n", np.array([1.0, 2.5, float("nan")])) + s2 = pl.Series("n", np.array([1.0, 2.5, float("nan")]), nan_to_null=True) + + assert_series_equal(s0, s1) + assert s2.to_list() == [1.0, 2.5, None] + + +def test_init_ndarray_square() -> None: + # 2D square array; ensure that we maintain convention + # (first axis = rows) with/without an explicit schema + arr = np.arange(4).reshape(2, 2) + assert ( + [(0, 1), (2, 3)] + == pl.DataFrame(arr).rows() + == pl.DataFrame(arr, schema=["a", "b"]).rows() + ) + # check that we tie-break square arrays using fortran vs c-contiguous row/col major + df_c = pl.DataFrame( + data=np.array([[1, 2], [3, 4]], dtype=np.int64, order="C"), + schema=["x", "y"], + ) + assert_frame_equal(df_c, pl.DataFrame({"x": [1, 3], "y": [2, 4]})) + + df_f = pl.DataFrame( + data=np.array([[1, 2], [3, 4]], dtype=np.int64, order="F"), + schema=["x", "y"], + ) + assert_frame_equal(df_f, pl.DataFrame({"x": [1, 2], "y": [3, 4]})) + + +def test_init_numpy_unavailable(monkeypatch: Any) -> None: + monkeypatch.setattr(pl.dataframe.frame, "_check_for_numpy", lambda x: False) + with pytest.raises(TypeError): + pl.DataFrame(np.array([1, 2, 3]), schema=["a"]) + + +def test_init_numpy_scalars() -> None: + df = pl.DataFrame( + { + "bool": [np.bool_(True), np.bool_(False)], + "i8": [np.int8(16), np.int8(64)], + "u32": [np.uint32(1234), np.uint32(9876)], + } + ) + df_expected = pl.from_records( + data=[(True, 16, 1234), (False, 64, 9876)], + schema=OrderedDict([("bool", pl.Boolean), ("i8", pl.Int8), ("u32", pl.UInt32)]), + orient="row", + ) + assert_frame_equal(df, df_expected) + + +def test_null_array_print_format() -> None: + pa_tbl_null = pa.table({"a": [None, None]}) + df_null = pl.from_arrow(pa_tbl_null) + assert df_null.shape == (2, 1) + assert df_null.dtypes == [pl.Null] # type: ignore[union-attr] + assert df_null.rows() == [(None,), (None,)] # type: ignore[union-attr] + + assert ( + str(df_null) == "shape: (2, 1)\n" + "┌──────┐\n" + "│ a │\n" + "│ --- │\n" + "│ null │\n" + "╞══════╡\n" + "│ null │\n" + "│ null │\n" + "└──────┘" + ) + + +def test_init_arrow() -> None: + # Handle unnamed column + df = pl.DataFrame(pa.table({"a": [1, 2], None: [3, 4]})) + expected = pl.DataFrame({"a": [1, 2], "None": [3, 4]}) + assert_frame_equal(df, expected) + + # Rename columns + df = pl.DataFrame(pa.table({"a": [1, 2], "b": [3, 4]}), schema=["c", "d"]) + expected = pl.DataFrame({"c": [1, 2], "d": [3, 4]}) + assert_frame_equal(df, expected) + + df = pl.DataFrame( + pa.table({"a": [1, 2], None: [3, 4]}), + schema=[("c", pl.Int32), ("d", pl.Float32)], + ) + assert df.schema == {"c": pl.Int32, "d": pl.Float32} + assert df.rows() == [(1, 3.0), (2, 4.0)] + + # Bad columns argument + with pytest.raises(ValueError): + pl.DataFrame(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), schema=["c", "d", "e"]) + + +def test_init_arrow_dupes() -> None: + tbl = pa.Table.from_arrays( + arrays=[ + pa.array([1, 2, 3], type=pa.int32()), + pa.array([4, 5, 6], type=pa.int32()), + pa.array( + [7, 8, 9], type=pa.decimal128(38, 10) + ), # included as this triggers a panic during construction alongside duplicate fields + ], + schema=pa.schema( + [("col", pa.int32()), ("col", pa.int32()), ("col3", pa.decimal128(38, 10))] + ), + ) + with pytest.raises( + DuplicateError, + match=r"""column appears more than once; names must be unique: \["col"\]""", + ): + pl.DataFrame(tbl) + + +def test_init_from_frame() -> None: + df1 = pl.DataFrame({"id": [0, 1], "misc": ["a", "b"], "val": [-10, 10]}) + assert_frame_equal(df1, pl.DataFrame(df1)) + + df2 = pl.DataFrame(df1, schema=["a", "b", "c"]) + assert_frame_equal(df2, pl.DataFrame(df2)) + + df3 = pl.DataFrame(df1, schema=["a", "b", "c"], schema_overrides={"val": pl.Int8}) + assert_frame_equal(df3, pl.DataFrame(df3)) + + assert df1.schema == {"id": pl.Int64, "misc": pl.String, "val": pl.Int64} + assert df2.schema == {"a": pl.Int64, "b": pl.String, "c": pl.Int64} + assert df3.schema == {"a": pl.Int64, "b": pl.String, "c": pl.Int8} + assert df1.rows() == df2.rows() == df3.rows() + + s1 = pl.Series("s", df3) + s2 = pl.Series(df3) + + assert s1.name == "s" + assert s2.name == "" + + +def test_init_series() -> None: + # List of Series + df = pl.DataFrame([pl.Series("a", [1, 2, 3]), pl.Series("b", [4, 5, 6])]) + expected = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + assert_frame_equal(df, expected) + + # Tuple of Series + df = pl.DataFrame((pl.Series("a", (1, 2, 3)), pl.Series("b", (4, 5, 6)))) + assert_frame_equal(df, expected) + + df = pl.DataFrame( + (pl.Series("a", (1, 2, 3)), pl.Series("b", (4, 5, 6))), + schema=[("x", pl.Float64), ("y", pl.Float64)], + ) + assert df.schema == {"x": pl.Float64, "y": pl.Float64} + assert df.rows() == [(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)] + + # List of unnamed Series + df = pl.DataFrame([pl.Series([1, 2, 3]), pl.Series([4, 5, 6])]) + col0 = pl.Series("column_0", [1, 2, 3]) + col1 = pl.Series("column_1", [4, 5, 6]) + expected = pl.DataFrame([col0, col1]) + assert_frame_equal(df, expected) + + df = pl.DataFrame([pl.Series([0.0]), pl.Series([1.0])]) + assert df.schema == {"column_0": pl.Float64, "column_1": pl.Float64} + assert df.rows() == [(0.0, 1.0)] + + df = pl.DataFrame( + [pl.Series([None]), pl.Series([1.0])], + schema=[("x", pl.Date), ("y", pl.Boolean)], + ) + assert df.schema == {"x": pl.Date, "y": pl.Boolean} + assert df.rows() == [(None, True)] + + # Single Series + df = pl.DataFrame(pl.Series("a", [1, 2, 3])) + expected = pl.DataFrame({"a": [1, 2, 3]}) + assert df.schema == {"a": pl.Int64} + assert_frame_equal(df, expected) + + df = pl.DataFrame(pl.Series("a", [1, 2, 3]), schema=[("a", pl.UInt32)]) + assert df.rows() == [(1,), (2,), (3,)] + assert df.schema == {"a": pl.UInt32} + + # nested list, with/without explicit dtype + s1 = pl.Series([[[2, 2]]]) + assert s1.dtype == pl.List(pl.List(pl.Int64)) + + s2 = pl.Series([[[2, 2]]], dtype=pl.List(pl.List(pl.UInt8))) + assert s2.dtype == pl.List(pl.List(pl.UInt8)) + + nested_dtype = pl.List(pl.List(pl.UInt8)) + s3 = pl.Series("x", dtype=nested_dtype) + s4 = pl.Series(s3) + for s in (s3, s4): + assert s.dtype == nested_dtype + assert s.to_list() == [] + assert s.name == "x" + + s5 = pl.Series("", df, dtype=pl.Int8) + assert_series_equal(s5, pl.Series("", [1, 2, 3], dtype=pl.Int8)) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (int, pl.Int64), + (bytes, pl.Binary), + (float, pl.Float64), + (str, pl.String), + (date, pl.Date), + (time, pl.Time), + (datetime, pl.Datetime("us")), + (timedelta, pl.Duration("us")), + (Decimal, pl.Decimal(precision=None, scale=0)), + ], +) +def test_init_py_dtype(dtype: Any, expected_dtype: PolarsDataType) -> None: + for s in ( + pl.Series("s", [None], dtype=dtype), + pl.Series("s", [], dtype=dtype), + ): + assert s.dtype == expected_dtype + + for df in ( + pl.DataFrame({"col": [None]}, schema={"col": dtype}), + pl.DataFrame({"col": []}, schema={"col": dtype}), + ): + assert df.schema == {"col": expected_dtype} + + +def test_init_py_dtype_misc_float() -> None: + assert pl.Series([100], dtype=float).dtype == pl.Float64 # type: ignore[arg-type] + + df = pl.DataFrame( + {"x": [100.0], "y": [200], "z": [None]}, + schema={"x": float, "y": float, "z": float}, + ) + assert df.schema == {"x": pl.Float64, "y": pl.Float64, "z": pl.Float64} + assert df.rows() == [(100.0, 200.0, None)] + + +def test_init_seq_of_seq() -> None: + # List of lists + df = pl.DataFrame([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"], orient="row") + expected = pl.DataFrame({"a": [1, 4], "b": [2, 5], "c": [3, 6]}) + assert_frame_equal(df, expected) + + df = pl.DataFrame( + [[1, 2, 3], [4, 5, 6]], + schema=[("a", pl.Int8), ("b", pl.Int16), ("c", pl.Int32)], + orient="row", + ) + assert df.schema == {"a": pl.Int8, "b": pl.Int16, "c": pl.Int32} + assert df.rows() == [(1, 2, 3), (4, 5, 6)] + + # Tuple of tuples, default to column orientation + df = pl.DataFrame(((1, 2, 3), (4, 5, 6))) + expected = pl.DataFrame({"column_0": [1, 2, 3], "column_1": [4, 5, 6]}) + assert_frame_equal(df, expected) + + # Row orientation + df = pl.DataFrame(((1, 2), (3, 4)), schema=("a", "b"), orient="row") + expected = pl.DataFrame({"a": [1, 3], "b": [2, 4]}) + assert_frame_equal(df, expected) + + df = pl.DataFrame( + ((1, 2), (3, 4)), schema=(("a", pl.Float32), ("b", pl.Float32)), orient="row" + ) + assert df.schema == {"a": pl.Float32, "b": pl.Float32} + assert df.rows() == [(1.0, 2.0), (3.0, 4.0)] + + # Wrong orient value + with pytest.raises(ValueError): + df = pl.DataFrame(((1, 2), (3, 4)), orient="wrong") # type: ignore[arg-type] + + +def test_init_1d_sequence() -> None: + # Empty list + df = pl.DataFrame([]) + assert_frame_equal(df, pl.DataFrame()) + + # List/array of strings + data = ["a", "b", "c"] + for a in (data, np.array(data)): + df = pl.DataFrame(a, schema=["s"]) + expected = pl.DataFrame({"s": data}) + assert_frame_equal(df, expected) + + df = pl.DataFrame([None, True, False], schema=[("xx", pl.Int8)]) + assert df.schema == {"xx": pl.Int8} + assert df.rows() == [(None,), (1,), (0,)] + + # String sequence + result = pl.DataFrame("abc", schema=["s"]) + expected = pl.DataFrame({"s": ["a", "b", "c"]}) + assert_frame_equal(result, expected) + + # datetimes sequence + df = pl.DataFrame([datetime(2020, 1, 1)], schema={"ts": pl.Datetime("ms")}) + assert df.schema == {"ts": pl.Datetime("ms")} + df = pl.DataFrame( + [datetime(2020, 1, 1, tzinfo=timezone.utc)], schema={"ts": pl.Datetime("ms")} + ) + assert df.schema == {"ts": pl.Datetime("ms", "UTC")} + df = pl.DataFrame( + [datetime(2020, 1, 1, tzinfo=timezone(timedelta(hours=1)))], + schema={"ts": pl.Datetime("ms")}, + ) + assert df.schema == {"ts": pl.Datetime("ms", "UTC")} + df = pl.DataFrame( + [datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu"))], + schema={"ts": pl.Datetime("ms")}, + ) + assert df.schema == {"ts": pl.Datetime("ms", "Asia/Kathmandu")} + + +def test_init_pandas(monkeypatch: Any) -> None: + pandas_df = pd.DataFrame([[1, 2], [3, 4]], columns=[1, 2]) + + # integer column names + df = pl.DataFrame(pandas_df) + expected = pl.DataFrame({"1": [1, 3], "2": [2, 4]}) + assert_frame_equal(df, expected) + assert df.schema == {"1": pl.Int64, "2": pl.Int64} + + # override column names, types + df = pl.DataFrame(pandas_df, schema=[("x", pl.Float64), ("y", pl.Float64)]) + assert df.schema == {"x": pl.Float64, "y": pl.Float64} + assert df.rows() == [(1.0, 2.0), (3.0, 4.0)] + + # subclassed pandas object, with/without data & overrides + class XSeries(pd.Series): # type: ignore[type-arg] + @property + def _constructor(self) -> type: + return XSeries + + df = pl.DataFrame( + data=[ + XSeries(name="x", data=[], dtype=np.dtype("= parse_version("2.2.0"): + df = pl.DataFrame(pandas_df) + assert_frame_equal(df, expected) + + else: + with pytest.raises(TypeError): + pl.DataFrame(pandas_df) + + +def test_init_errors() -> None: + # Length mismatch + with pytest.raises(ShapeError): + pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0, 4.0]}) + + # Columns don't match data dimensions + with pytest.raises(ShapeError): + pl.DataFrame([[1, 2], [3, 4]], schema=["a", "b", "c"]) + + # Unmatched input + with pytest.raises(TypeError): + pl.DataFrame(0) + + +def test_init_records() -> None: + dicts = [ + {"a": 1, "b": 2}, + {"b": 1, "a": 2}, + {"a": 1, "b": 2}, + ] + df = pl.DataFrame(dicts) + expected = pl.DataFrame({"a": [1, 2, 1], "b": [2, 1, 2]}) + assert_frame_equal(df, expected) + assert df.to_dicts() == dicts + + df_cd = pl.DataFrame(dicts, schema=["a", "c", "d"]) + expected_values = { + "a": [1, 2, 1], + "c": [None, None, None], + "d": [None, None, None], + } + assert df_cd.to_dict(as_series=False) == expected_values + + data = {"a": 1, "b": 2, "c": 3} + + df1 = pl.from_dicts([data]) + assert df1.columns == ["a", "b", "c"] + + df1.columns = ["x", "y", "z"] + assert df1.columns == ["x", "y", "z"] + + df2 = pl.from_dicts([data], schema=["c", "b", "a"]) + assert df2.columns == ["c", "b", "a"] + + for colname in ("c", "b", "a"): + result = pl.from_dicts([data], schema=[colname]) + expected_values = {colname: [data[colname]]} + assert result.to_dict(as_series=False) == expected_values + + +def test_init_records_schema_order() -> None: + cols: list[str] = ["a", "b", "c", "d"] + data: list[dict[str, int]] = [ + {"c": 3, "b": 2, "a": 1}, + {"b": 2, "d": 4}, + {}, + {"a": 1, "b": 2, "c": 3}, + {"d": 4, "b": 2, "a": 1}, + {"c": 3, "b": 2}, + ] + lookup = {"a": 1, "b": 2, "c": 3, "d": 4, "e": None} + + for constructor in (pl.from_dicts, pl.DataFrame): + # ensure field values are loaded according to the declared schema order + for _ in range(8): + shuffle(data) + shuffle(cols) + + df = constructor(data, schema=cols) + for col in df.columns: + assert all(value in (None, lookup[col]) for value in df[col].to_list()) + + # have schema override inferred types, omit some columns, add a new one + schema = {"a": pl.Int8, "c": pl.Int16, "e": pl.Int32} + df = constructor(data, schema=schema) + + assert df.schema == schema + for col in df.columns: + assert all(value in (None, lookup[col]) for value in df[col].to_list()) + + +def test_init_only_columns() -> None: + df = pl.DataFrame(schema=["a", "b", "c"]) + expected = pl.DataFrame({"a": [], "b": [], "c": []}) + assert_frame_equal(df, expected) + + # Validate construction with various flavours of no/empty data + no_data: Any + for no_data in (None, {}, []): + df = pl.DataFrame( + data=no_data, + schema=[ + ("a", pl.Date), + ("b", pl.UInt64), + ("c", pl.Int8), + ("d", pl.List(pl.UInt8)), + ], + ) + expected = pl.DataFrame({"a": [], "b": [], "c": []}).with_columns( + pl.col("a").cast(pl.Date), + pl.col("b").cast(pl.UInt64), + pl.col("c").cast(pl.Int8), + ) + expected.insert_column(3, pl.Series("d", [], pl.List(pl.UInt8))) + + assert df.shape == (0, 4) + assert_frame_equal(df, expected) + assert df.dtypes == [pl.Date, pl.UInt64, pl.Int8, pl.List] + assert pl.List(pl.UInt8).is_(df.schema["d"]) + + dfe = df.clear() + assert len(dfe) == 0 + assert df.schema == dfe.schema + assert dfe.shape == df.shape + + +def test_from_dicts_list_without_dtype() -> None: + result = pl.from_dicts( + [{"id": 1, "hint": ["some_text_here"]}, {"id": 2, "hint": [None]}] + ) + expected = pl.DataFrame({"id": [1, 2], "hint": [["some_text_here"], [None]]}) + assert_frame_equal(result, expected) + + +def test_from_dicts_list_struct_without_inner_dtype() -> None: + df = pl.DataFrame( + { + "users": [ + [{"category": "A"}, {"category": "B"}], + [{"category": None}, {"category": None}], + ], + "days_of_week": [1, 2], + } + ) + expected = { + "users": [ + [{"category": "A"}, {"category": "B"}], + [{"category": None}, {"category": None}], + ], + "days_of_week": [1, 2], + } + assert df.to_dict(as_series=False) == expected + + +def test_from_dicts_list_struct_without_inner_dtype_5611() -> None: + result = pl.from_dicts( + [ + {"a": []}, + {"a": [{"b": 1}]}, + ] + ) + expected = pl.DataFrame({"a": [[], [{"b": 1}]]}) + assert_frame_equal(result, expected) + + +def test_from_dict_upcast_primitive() -> None: + df = pl.from_dict({"a": [1, 2.1, 3], "b": [4, 5, 6.4]}, strict=False) + assert df.dtypes == [pl.Float64, pl.Float64] + + +def test_u64_lit_5031() -> None: + df = pl.DataFrame({"foo": [1, 2, 3]}).with_columns(pl.col("foo").cast(pl.UInt64)) + assert df.filter(pl.col("foo") < (1 << 64) - 20).shape == (3, 1) + assert df["foo"].to_list() == [1, 2, 3] + + +def test_from_dicts_missing_columns() -> None: + # missing columns from some of the data dicts + data = [{"a": 1}, {"b": 2}] + result = pl.from_dicts(data) + expected = pl.DataFrame({"a": [1, None], "b": [None, 2]}) + assert_frame_equal(result, expected) + + # partial schema with some columns missing; only load the declared keys + data = [{"a": 1, "b": 2}] + result = pl.from_dicts(data, schema=["a"]) + expected = pl.DataFrame({"a": [1]}) + assert_frame_equal(result, expected) + + +def test_from_dicts_schema_columns_do_not_match() -> None: + data = [{"a": 1, "b": 2}] + result = pl.from_dicts(data, schema=["x"]) + expected = pl.DataFrame({"x": [None]}) + assert_frame_equal(result, expected) + + +def test_from_dicts_infer_integer_types() -> None: + data = [ + { + "a": 2**7 - 1, + "b": 2**15 - 1, + "c": 2**31 - 1, + "d": 2**63 - 1, + "e": 2**127 - 1, + } + ] + result = pl.from_dicts(data).schema + # all values inferred as i64 except for values too large for i64 + expected = { + "a": pl.Int64, + "b": pl.Int64, + "c": pl.Int64, + "d": pl.Int64, + "e": pl.Int128, + } + assert result == expected + + with pytest.raises(OverflowError): + pl.from_dicts([{"too_big": 2**127}]) + + +def test_from_dicts_list_large_int_17006() -> None: + data = [{"x": [2**64 - 1]}] + + result = pl.from_dicts(data, schema={"x": pl.List(pl.UInt64)}) + expected = pl.DataFrame({"x": [[2**64 - 1]]}, schema={"x": pl.List(pl.UInt64)}) + assert_frame_equal(result, expected) + + result = pl.from_dicts(data, schema={"x": pl.Array(pl.UInt64, 1)}) + expected = pl.DataFrame({"x": [[2**64 - 1]]}, schema={"x": pl.Array(pl.UInt64, 1)}) + assert_frame_equal(result, expected) + + +def test_from_rows_dtype() -> None: + # 50 is the default inference length + # 5182 + df = pl.DataFrame( + data=[(None, None)] * 50 + [("1.23", None)], + schema=[("foo", pl.String), ("bar", pl.String)], + orient="row", + ) + assert df.dtypes == [pl.String, pl.String] + assert df.null_count().row(0) == (50, 51) + + type1 = [{"c1": 206, "c2": "type1", "c3": {"x1": "abcd", "x2": "jkl;"}}] + type2 = [ + {"c1": 208, "c2": "type2", "c3": {"a1": "abcd", "a2": "jkl;", "a3": "qwerty"}} + ] + + df = pl.DataFrame( + data=type1 * 50 + type2, + schema=[("c1", pl.Int32), ("c2", pl.Object), ("c3", pl.Object)], + ) + assert df.dtypes == [pl.Int32, pl.Object, pl.Object] + + # 50 is the default inference length + # 5266 + type1 = [{"c1": 206, "c2": "type1", "c3": {"x1": "abcd", "x2": "jkl;"}}] + type2 = [ + {"c1": 208, "c2": "type2", "c3": {"a1": "abcd", "a2": "jkl;", "a3": "qwerty"}} + ] + + df = pl.DataFrame( + data=type1 * 50 + type2, + schema=[("c1", pl.Int32), ("c2", pl.Object), ("c3", pl.Object)], + ) + assert df.dtypes == [pl.Int32, pl.Object, pl.Object] + assert df.null_count().row(0) == (0, 0, 0) + + dc = _TestBazDC(d=datetime(2020, 2, 22), e=42.0, f="xyz") + df = pl.DataFrame([[dc]], schema={"d": pl.Object}) + assert df.schema == {"d": pl.Object} + assert df.item() == dc + + +def test_from_dicts_schema() -> None: + data = [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}] + + # let polars infer the dtypes, but inform it about a 3rd column. + for schema, overrides in ( + ({"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32}, None), + ({"a": None, "b": None, "c": None}, {"c": pl.Int32}), + (["a", "b", ("c", pl.Int32)], None), + ): + df = pl.from_dicts( + data, + schema=schema, # type: ignore[arg-type] + schema_overrides=overrides, + ) + assert df.dtypes == [pl.Int64, pl.Int64, pl.Int32] + assert df.to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [None, None, None], + } + + # provide data that resolves to an empty frame (ref: scalar + # expansion shortcut), with schema/override hints + schema = {"colx": pl.String, "coly": pl.Int32} + + for param in ("schema", "schema_overrides"): + df = pl.DataFrame({"colx": [], "coly": 0}, **{param: schema}) # type: ignore[arg-type] + assert df.schema == schema + + +def test_nested_read_dicts_4143() -> None: + result = pl.from_dicts( + [ + { + "id": 1, + "hint": [ + {"some_text_here": "text", "list_": [1, 2, 4]}, + {"some_text_here": "text", "list_": [1, 2, 4]}, + ], + }, + { + "id": 2, + "hint": [ + {"some_text_here": None, "list_": [1]}, + {"some_text_here": None, "list_": [2]}, + ], + }, + ] + ) + expected = { + "hint": [ + [ + {"some_text_here": "text", "list_": [1, 2, 4]}, + {"some_text_here": "text", "list_": [1, 2, 4]}, + ], + [ + {"some_text_here": None, "list_": [1]}, + {"some_text_here": None, "list_": [2]}, + ], + ], + "id": [1, 2], + } + assert result.to_dict(as_series=False) == expected + + +def test_nested_read_dicts_4143_2() -> None: + result = pl.from_dicts( + [ + { + "id": 1, + "hint": [ + {"some_text_here": "text", "list_": [1, 2, 4]}, + {"some_text_here": "text", "list_": [1, 2, 4]}, + ], + }, + { + "id": 2, + "hint": [ + {"some_text_here": "text", "list_": []}, + {"some_text_here": "text", "list_": []}, + ], + }, + ] + ) + + assert result.dtypes == [ + pl.Int64, + pl.List(pl.Struct({"some_text_here": pl.String, "list_": pl.List(pl.Int64)})), + ] + expected = { + "id": [1, 2], + "hint": [ + [ + {"some_text_here": "text", "list_": [1, 2, 4]}, + {"some_text_here": "text", "list_": [1, 2, 4]}, + ], + [ + {"some_text_here": "text", "list_": []}, + {"some_text_here": "text", "list_": []}, + ], + ], + } + assert result.to_dict(as_series=False) == expected + + +def test_from_records_nullable_structs() -> None: + records = [ + {"id": 1, "items": [{"item_id": 100, "description": None}]}, + {"id": 1, "items": [{"item_id": 100, "description": "hi"}]}, + ] + + schema: list[tuple[str, PolarsDataType]] = [ + ("id", pl.UInt16), + ( + "items", + pl.List( + pl.Struct( + [pl.Field("item_id", pl.UInt32), pl.Field("description", pl.String)] + ) + ), + ), + ] + + schema_options: list[list[tuple[str, PolarsDataType]] | None] = [schema, None] + for s in schema_options: + result = pl.DataFrame(records, schema=s, orient="row") + expected = { + "id": [1, 1], + "items": [ + [{"item_id": 100, "description": None}], + [{"item_id": 100, "description": "hi"}], + ], + } + assert result.to_dict(as_series=False) == expected + + # check initialisation without any records + df = pl.DataFrame(schema=schema) + dict_schema = dict(schema) + assert df.to_dict(as_series=False) == {"id": [], "items": []} + assert df.schema == dict_schema + + dtype: PolarsDataType = dict_schema["items"] + series = pl.Series("items", dtype=dtype) + assert series.to_frame().to_dict(as_series=False) == {"items": []} + assert series.dtype == dict_schema["items"] + assert series.to_list() == [] + + +@pytest.mark.parametrize("unnest_column", ["a", pl.col("a"), cs.by_name("a")]) +def test_from_categorical_in_struct_defined_by_schema(unnest_column: Any) -> None: + df = pl.DataFrame( + {"a": [{"value": "foo", "counts": 1}, {"value": "bar", "counts": 2}]}, + schema={"a": pl.Struct({"value": pl.Categorical, "counts": pl.UInt32})}, + ) + + expected = pl.DataFrame( + {"value": ["foo", "bar"], "counts": [1, 2]}, + schema={"value": pl.Categorical, "counts": pl.UInt32}, + ) + + res_eager = df.unnest(unnest_column) + assert_frame_equal(res_eager, expected, categorical_as_str=True) + + res_lazy = df.lazy().unnest(unnest_column) + assert_frame_equal(res_lazy.collect(), expected, categorical_as_str=True) + + +def test_nested_schema_construction() -> None: + schema = { + "node_groups": pl.List( + pl.Struct( + [ + pl.Field("parent_node_group_id", pl.UInt8), + pl.Field( + "nodes", + pl.List( + pl.Struct( + [ + pl.Field("name", pl.String), + pl.Field( + "sub_nodes", + pl.List( + pl.Struct( + [ + pl.Field("internal_id", pl.UInt64), + pl.Field("value", pl.UInt32), + ] + ) + ), + ), + ] + ) + ), + ), + ] + ) + ) + } + df = pl.DataFrame( + { + "node_groups": [ + [{"nodes": []}, {"nodes": [{"name": "", "sub_nodes": []}]}], + ] + }, + schema=schema, + ) + + assert df.schema == schema + assert df.to_dict(as_series=False) == { + "node_groups": [ + [ + {"parent_node_group_id": None, "nodes": []}, + { + "parent_node_group_id": None, + "nodes": [{"name": "", "sub_nodes": []}], + }, + ] + ] + } + + +def test_nested_schema_construction2() -> None: + schema = { + "node_groups": pl.List( + pl.Struct( + [ + pl.Field( + "nodes", + pl.List( + pl.Struct( + [ + pl.Field("name", pl.String), + pl.Field("time", pl.UInt32), + ] + ) + ), + ) + ] + ) + ) + } + df = pl.DataFrame( + [ + {"node_groups": [{"nodes": [{"name": "a", "time": 0}]}]}, + {"node_groups": [{"nodes": []}]}, + ], + schema=schema, + ) + assert df.schema == schema + assert df.to_dict(as_series=False) == { + "node_groups": [[{"nodes": [{"name": "a", "time": 0}]}], [{"nodes": []}]] + } + + +def test_arrow_to_pyseries_with_one_chunk_does_not_copy_data() -> None: + from polars._utils.construction import arrow_to_pyseries + + original_array = pa.chunked_array([[1, 2, 3]], type=pa.int64()) + pyseries = arrow_to_pyseries("", original_array) + assert ( + pyseries.get_chunks()[0]._get_buffer_info()[0] + == original_array.chunks[0].buffers()[1].address + ) + + +def test_init_with_explicit_binary_schema() -> None: + df = pl.DataFrame({"a": [b"hello", b"world"]}, schema={"a": pl.Binary}) + assert df.schema == {"a": pl.Binary} + assert df["a"].to_list() == [b"hello", b"world"] + + s = pl.Series("a", [b"hello", b"world"], dtype=pl.Binary) + assert s.dtype == pl.Binary + assert s.to_list() == [b"hello", b"world"] + + +def test_nested_categorical() -> None: + s = pl.Series([["a"]], dtype=pl.List(pl.Categorical)) + assert s.to_list() == [["a"]] + assert s.dtype == pl.List(pl.Categorical) + + +def test_datetime_date_subclasses() -> None: + class FakeDate(date): ... + + class FakeDateChild(FakeDate): ... + + class FakeDatetime(FakeDate, datetime): ... + + result = pl.Series([FakeDate(2020, 1, 1)]) + expected = pl.Series([date(2020, 1, 1)]) + assert_series_equal(result, expected) + + result = pl.Series([FakeDateChild(2020, 1, 1)]) + expected = pl.Series([date(2020, 1, 1)]) + assert_series_equal(result, expected) + + result = pl.Series([FakeDatetime(2020, 1, 1, 3)]) + expected = pl.Series([datetime(2020, 1, 1, 3)]) + assert_series_equal(result, expected) + + +def test_list_null_constructor() -> None: + s = pl.Series("a", [[None], [None]], dtype=pl.List(pl.Null)) + assert s.dtype == pl.List(pl.Null) + assert s.to_list() == [[None], [None]] + + # nested + dtype = pl.List(pl.List(pl.Int8)) + values = [ + [], + [[], []], + [[33, 112]], + ] + s = pl.Series( + name="colx", + values=values, + dtype=dtype, + ) + assert s.dtype == dtype + assert s.to_list() == values + + # nested + # small order change has influence + dtype = pl.List(pl.List(pl.Int8)) + values = [ + [[], []], + [], + [[33, 112]], + ] + s = pl.Series( + name="colx", + values=values, + dtype=dtype, + ) + assert s.dtype == dtype + assert s.to_list() == values + + +def test_numpy_float_construction_av() -> None: + np_dict = {"a": np.float64(1)} + assert_frame_equal(pl.DataFrame(np_dict), pl.DataFrame({"a": 1.0})) + + +def test_df_init_dict_raise_on_expression_input() -> None: + with pytest.raises( + TypeError, + match="passing Expr objects to the DataFrame constructor is not supported", + ): + pl.DataFrame({"a": pl.int_range(0, 3)}) + with pytest.raises(TypeError): + pl.DataFrame({"a": pl.int_range(0, 3), "b": [3, 4, 5]}) + + # Passing a list of expressions is allowed + df = pl.DataFrame({"a": [pl.int_range(0, 3)]}) + assert df.get_column("a").dtype.is_object() + + +def test_df_schema_sequences() -> None: + schema = [ + ["address", pl.String], + ["key", pl.Int64], + ["value", pl.Float32], + ] + df = pl.DataFrame(schema=schema) # type: ignore[arg-type] + assert df.schema == {"address": pl.String, "key": pl.Int64, "value": pl.Float32} + + +def test_df_schema_sequences_incorrect_length() -> None: + schema = [ + ["address", pl.String, pl.Int8], + ["key", pl.Int64], + ["value", pl.Float32], + ] + with pytest.raises(ValueError): + pl.DataFrame(schema=schema) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("input", "infer_func", "expected_dtype"), + [ + ("f8", numpy_char_code_to_dtype, pl.Float64), + ("f4", numpy_char_code_to_dtype, pl.Float32), + ("i4", numpy_char_code_to_dtype, pl.Int32), + ("u1", numpy_char_code_to_dtype, pl.UInt8), + ("?", numpy_char_code_to_dtype, pl.Boolean), + ("m8", numpy_char_code_to_dtype, pl.Duration("us")), + ("M8", numpy_char_code_to_dtype, pl.Datetime("us")), + ], +) +def test_numpy_inference( + input: Any, + infer_func: Callable[[Any], PolarsDataType], + expected_dtype: PolarsDataType, +) -> None: + result = infer_func(input) + assert result == expected_dtype + + +def test_array_construction() -> None: + payload = [[1, 2, 3], None, [4, 2, 3]] + + dtype = pl.Array(pl.Int64, 3) + s = pl.Series(payload, dtype=dtype) + assert s.dtype == dtype + assert s.to_list() == payload + + # inner type + dtype = pl.Array(pl.UInt8, 2) + payload = [[1, 2], None, [3, 4]] + s = pl.Series(payload, dtype=dtype) + assert s.dtype == dtype + assert s.to_list() == payload + + # create using schema + df = pl.DataFrame( + schema={ + "a": pl.Array(pl.Float32, 3), + "b": pl.Array(pl.Datetime("ms"), 5), + } + ) + assert df.dtypes == [ + pl.Array(pl.Float32, 3), + pl.Array(pl.Datetime("ms"), 5), + ] + assert df.rows() == [] + + # from dicts + rows = [ + {"row_id": "a", "data": [1, 2, 3]}, + {"row_id": "b", "data": [2, 3, 4]}, + ] + schema = {"row_id": pl.String(), "data": pl.Array(inner=pl.Int64, shape=3)} + df = pl.from_dicts(rows, schema=schema) + assert df.schema == schema + assert df.rows() == [("a", [1, 2, 3]), ("b", [2, 3, 4])] + + +@pytest.mark.may_fail_auto_streaming +def test_pycapsule_interface(df: pl.DataFrame) -> None: + df = df.rechunk() + pyarrow_table = df.to_arrow() + + # Array via C data interface + pyarrow_array = pyarrow_table["bools"].chunk(0) + round_trip_series = pl.Series(PyCapsuleArrayHolder(pyarrow_array)) + assert df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False) + + # empty Array via C data interface + empty_pyarrow_array = pa.array([], type=pyarrow_array.type) + round_trip_series = pl.Series(PyCapsuleArrayHolder(empty_pyarrow_array)) + assert df["bools"].dtype == round_trip_series.dtype + + # RecordBatch via C array interface + pyarrow_record_batch = pyarrow_table.to_batches()[0] + round_trip_df = pl.DataFrame(PyCapsuleArrayHolder(pyarrow_record_batch)) + assert df.equals(round_trip_df) + + # ChunkedArray via C stream interface + pyarrow_chunked_array = pyarrow_table["bools"] + round_trip_series = pl.Series(PyCapsuleStreamHolder(pyarrow_chunked_array)) + assert df["bools"].equals(round_trip_series, check_dtypes=True, check_names=False) + + # empty ChunkedArray via C stream interface + empty_chunked_array = pa.chunked_array([], type=pyarrow_chunked_array.type) + round_trip_series = pl.Series(PyCapsuleStreamHolder(empty_chunked_array)) + assert df["bools"].dtype == round_trip_series.dtype + + # Table via C stream interface + round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_table)) + assert df.equals(round_trip_df) + + # empty Table via C stream interface + empty_df = df[:0].to_arrow() + round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(empty_df)) + orig_schema = df.schema + round_trip_schema = round_trip_df.schema + + # The "enum" schema is not preserved because categories are lost via C data + # interface + orig_schema.pop("enum") + round_trip_schema.pop("enum") + + assert orig_schema == round_trip_schema + + # RecordBatchReader via C stream interface + pyarrow_reader = pa.RecordBatchReader.from_batches( + pyarrow_table.schema, pyarrow_table.to_batches() + ) + round_trip_df = pl.DataFrame(PyCapsuleStreamHolder(pyarrow_reader)) + assert df.equals(round_trip_df) + + +@pytest.mark.parametrize( + "tz", + [ + None, + ZoneInfo("Asia/Tokyo"), + ZoneInfo("Europe/Amsterdam"), + ZoneInfo("UTC"), + timezone.utc, + ], +) +def test_init_list_of_dicts_with_timezone(tz: Any) -> None: + dt = datetime(2023, 1, 1, 0, 0, 0, 0, tzinfo=tz) + + df = pl.DataFrame([{"dt": dt}, {"dt": dt}]) + expected = pl.DataFrame({"dt": [dt, dt]}) + assert_frame_equal(df, expected) + + assert df.schema == {"dt": pl.Datetime("us", time_zone=tz)} + + +@pytest.mark.parametrize( + "tz", + [ + None, + ZoneInfo("Asia/Tokyo"), + ZoneInfo("Europe/Amsterdam"), + ZoneInfo("UTC"), + timezone.utc, + ], +) +def test_init_list_of_nested_dicts_with_timezone(tz: Any) -> None: + dt = datetime(2021, 1, 1, 0, 0, 0, 0, tzinfo=tz) + data = [{"timestamp": {"content": datetime(2021, 1, 1, 0, 0, tzinfo=tz)}}] + + df = pl.DataFrame(data).unnest("timestamp") + expected = pl.DataFrame({"content": [dt]}) + assert_frame_equal(df, expected) + + assert df.schema == {"content": pl.Datetime("us", time_zone=tz)} + + +def test_init_from_subclassed_types() -> None: + # more detailed test of one custom subclass... + import codecs + + class SuperSecretString(str): + def __new__(cls, value: str) -> Self: + return super().__new__(cls, value) + + def __repr__(self) -> str: + return codecs.encode(self, "rot_13") + + w = "windmolen" + sstr = SuperSecretString(w) + + assert sstr == w + assert isinstance(sstr, str) + assert repr(sstr) == "jvaqzbyra" + assert_series_equal(pl.Series([w, w]), pl.Series([sstr, sstr])) + + # ...then validate across other basic types + for BaseType, value in ( + (int, 42), + (float, 5.5), + (bytes, b"value"), + (str, "value"), + ): + + class SubclassedType(BaseType): # type: ignore[misc,valid-type] + def __new__(cls, value: Any) -> Self: + return super().__new__(cls, value) # type: ignore[no-any-return] + + assert ( + pl.Series([value]).to_list() == pl.Series([SubclassedType(value)]).to_list() + ) + + +def test_series_init_with_python_type_7737() -> None: + assert pl.Series([], dtype=int).dtype == pl.Int64 # type: ignore[arg-type] + assert pl.Series([], dtype=float).dtype == pl.Float64 # type: ignore[arg-type] + assert pl.Series([], dtype=bool).dtype == pl.Boolean # type: ignore[arg-type] + assert pl.Series([], dtype=str).dtype == pl.Utf8 # type: ignore[arg-type] + + with pytest.raises(TypeError): + pl.Series(["a"], dtype=int) # type: ignore[arg-type] + + with pytest.raises(TypeError): + pl.Series([True], dtype=str) # type: ignore[arg-type] + + +def test_init_from_list_shape_6968() -> None: + df1 = pl.DataFrame([[1, None], [2, None], [3, None]]) + df2 = pl.DataFrame([[None, None], [2, None], [3, None]]) + assert df1.shape == (2, 3) + assert df2.shape == (2, 3) diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py new file mode 100644 index 000000000000..5d97983a966e --- /dev/null +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import enum +import sys +from collections import OrderedDict +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import DataOrientationWarning, InvalidOperationError + +if TYPE_CHECKING: + from collections.abc import Iterator + + +def test_df_mixed_dtypes_string() -> None: + data = {"x": [["abc", 12, 34.5]], "y": [1]} + + with pytest.raises(TypeError, match="unexpected value"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + assert df.schema == {"x": pl.List(pl.String), "y": pl.Int64} + assert df.rows() == [(["abc", "12", "34.5"], 1)] + + +def test_df_mixed_dtypes_object() -> None: + data = {"x": [[b"abc", 12, 34.5]], "y": [1]} + + with pytest.raises(TypeError, match="failed to determine supertype"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + assert df.schema == {"x": pl.Object, "y": pl.Int64} + assert df.rows() == [([b"abc", 12, 34.5], 1)] + + +def test_df_object() -> None: + class Foo: + def __init__(self, value: int) -> None: + self._value = value + + def __eq__(self, other: object) -> bool: + return issubclass(other.__class__, self.__class__) and ( + self._value == other._value # type: ignore[attr-defined] + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._value})" + + df = pl.DataFrame({"a": [Foo(1), Foo(2)]}) + assert df["a"].dtype.is_object() + assert df.rows() == [(Foo(1),), (Foo(2),)] + + +def test_df_init_from_generator_dict_view() -> None: + d = {0: "x", 1: "y", 2: "z"} + data = { + "keys": d.keys(), + "vals": d.values(), + "items": d.items(), + } + with pytest.raises(TypeError, match="unexpected value"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + assert df.schema == { + "keys": pl.Int64, + "vals": pl.String, + "items": pl.List(pl.String), + } + assert df.to_dict(as_series=False) == { + "keys": [0, 1, 2], + "vals": ["x", "y", "z"], + "items": [["0", "x"], ["1", "y"], ["2", "z"]], + } + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="reversed dict views not supported before Python 3.11", +) +def test_df_init_from_generator_reversed_dict_view() -> None: + d = {0: "x", 1: "y", 2: "z"} + data = { + "rev_keys": reversed(d.keys()), + "rev_vals": reversed(d.values()), + "rev_items": reversed(d.items()), + } + df = pl.DataFrame(data, schema_overrides={"rev_items": pl.Object}) + + assert df.schema == { + "rev_keys": pl.Int64, + "rev_vals": pl.String, + "rev_items": pl.Object, + } + assert df.to_dict(as_series=False) == { + "rev_keys": [2, 1, 0], + "rev_vals": ["z", "y", "x"], + "rev_items": [(2, "z"), (1, "y"), (0, "x")], + } + + +def test_df_init_strict() -> None: + data = {"a": [1, 2, 3.0]} + schema = {"a": pl.Int8} + with pytest.raises(TypeError): + pl.DataFrame(data, schema=schema, strict=True) + + df = pl.DataFrame(data, schema=schema, strict=False) + + assert df["a"].to_list() == [1, 2, 3] + assert df["a"].dtype == pl.Int8 + + +def test_df_init_from_series_strict() -> None: + s = pl.Series("a", [-1, 0, 1]) + schema = {"a": pl.UInt8} + with pytest.raises(InvalidOperationError): + pl.DataFrame(s, schema=schema, strict=True) + + df = pl.DataFrame(s, schema=schema, strict=False) + + assert df["a"].to_list() == [None, 0, 1] + assert df["a"].dtype == pl.UInt8 + + +# https://github.com/pola-rs/polars/issues/15471 +def test_df_init_rows_overrides_non_existing() -> None: + df = pl.DataFrame([{"a": 1}], schema_overrides={"a": pl.Int8(), "b": pl.Boolean()}) + assert df.schema == OrderedDict({"a": pl.Int8}) + + df = pl.DataFrame( + [{"a": 3, "b": 1.0}], + schema_overrides={"a": pl.Int8, "c": pl.Utf8}, + ) + assert df.schema == OrderedDict({"a": pl.Int8, "b": pl.Float64}) + + +# https://github.com/pola-rs/polars/issues/15245 +def test_df_init_nested_mixed_types() -> None: + data = [{"key": [{"value": 1}, {"value": 1.0}]}] + + with pytest.raises(TypeError, match="unexpected value"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + + assert df.schema == {"key": pl.List(pl.Struct({"value": pl.Float64}))} + assert df.to_dicts() == [{"key": [{"value": 1.0}, {"value": 1.0}]}] + + +class CustomSchema(Mapping[str, Any]): + """Dummy schema object for testing compatibility with Mapping.""" + + _entries: dict[str, Any] + + def __init__(self, **named_entries: Any) -> None: + self._items = OrderedDict(named_entries.items()) + + def __getitem__(self, key: str) -> Any: + return self._items[key] + + def __len__(self) -> int: + return len(self._items) + + def __iter__(self) -> Iterator[str]: + yield from self._items + + +def test_custom_schema() -> None: + df = pl.DataFrame(schema=CustomSchema(bool=pl.Boolean, misc=pl.UInt8)) + assert df.schema == OrderedDict([("bool", pl.Boolean), ("misc", pl.UInt8)]) + + with pytest.raises(TypeError): + pl.DataFrame(schema=CustomSchema(bool="boolean", misc="unsigned int")) + + +def test_list_null_constructor_schema() -> None: + expected = pl.List(pl.Null) + assert pl.DataFrame({"a": [[]]}).dtypes[0] == expected + assert pl.DataFrame(schema={"a": pl.List}).dtypes[0] == expected + + +def test_df_init_schema_object() -> None: + schema = pl.Schema({"a": pl.Int8(), "b": pl.String()}) + df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}, schema=schema) + + assert df.columns == schema.names() + assert df.dtypes == schema.dtypes() + + +def test_df_init_data_orientation_inference_warning() -> None: + with pytest.warns(DataOrientationWarning): + pl.from_records([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"]) + + +def test_df_init_enum_dtype() -> None: + class PythonEnum(str, enum.Enum): + A = "A" + B = "B" + C = "C" + + df = pl.DataFrame({"Col 1": ["A", "B", "C"]}, schema={"Col 1": PythonEnum}) + assert df.dtypes[0] == pl.Enum(["A", "B", "C"]) diff --git a/py-polars/tests/unit/constructors/test_series.py b/py-polars/tests/unit/constructors/test_series.py new file mode 100644 index 000000000000..87a4860d082d --- /dev/null +++ b/py-polars/tests/unit/constructors/test_series.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.testing.asserts.series import assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_series_mixed_dtypes_list() -> None: + values = [[0.1, 1]] + + with pytest.raises(TypeError, match="unexpected value"): + pl.Series(values) + + s = pl.Series(values, strict=False) + assert s.dtype == pl.List(pl.Float64) + assert s.to_list() == [[0.1, 1.0]] + + +def test_series_mixed_dtypes_string() -> None: + values = [[12], "foo", 9] + + with pytest.raises(TypeError, match="unexpected value"): + pl.Series(values) + + s = pl.Series(values, strict=False) + assert s.dtype == pl.String + assert s.to_list() == ["[12]", "foo", "9"] + assert s[1] == "foo" + + +def test_series_mixed_dtypes_object() -> None: + values = [[12], b"foo", 9] + + with pytest.raises(TypeError, match="unexpected value"): + pl.Series(values) + + s = pl.Series(values, strict=False) + assert s.dtype.is_object() + assert s.to_list() == values + assert s[1] == b"foo" + + +# https://github.com/pola-rs/polars/issues/15139 +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), None]) +def test_sequence_of_series_with_dtype(dtype: PolarsDataType | None) -> None: + values = [1, 2, 3] + int_series = pl.Series(values) + list_series = pl.Series([int_series], dtype=dtype) + + assert list_series.to_list() == [values] + assert list_series.dtype == pl.List(pl.Int64) + + +@pytest.mark.parametrize( + ("values", "dtype", "expected_dtype"), + [ + ([1, 1.0, 1], None, pl.Float64), + ([1, 1, "1.0"], None, pl.String), + ([1, 1.0, "1.0"], None, pl.String), + ([True, 1], None, pl.Int64), + ([True, 1.0], None, pl.Float64), + ([True, 1], pl.Boolean, pl.Boolean), + ([True, 1.0], pl.Boolean, pl.Boolean), + ([False, "1.0"], None, pl.String), + ], +) +def test_upcast_primitive_and_strings( + values: list[Any], dtype: PolarsDataType, expected_dtype: PolarsDataType +) -> None: + with pytest.raises(TypeError): + pl.Series(values, dtype=dtype, strict=True) + + assert pl.Series(values, dtype=dtype, strict=False).dtype == expected_dtype + + +def test_preserve_decimal_precision() -> None: + dtype = pl.Decimal(None, 1) + s = pl.Series(dtype=dtype) + assert s.dtype == dtype + + +@pytest.mark.parametrize("dtype", [None, pl.Duration("ms")]) +def test_large_timedelta(dtype: pl.DataType | None) -> None: + values = [timedelta.min, timedelta.max] + s = pl.Series(values, dtype=dtype) + assert s.dtype == pl.Duration("ms") + + # Microsecond precision is lost + expected = [timedelta.min, timedelta.max - timedelta(microseconds=999)] + assert s.to_list() == expected + + +def test_array_large_u64() -> None: + u64_max = 2**64 - 1 + values = [[u64_max]] + dtype = pl.Array(pl.UInt64, 1) + s = pl.Series(values, dtype=dtype) + assert s.dtype == dtype + assert s.to_list() == values + + +def test_series_init_ambiguous_datetime() -> None: + value = datetime(2001, 10, 28, 2) + dtype = pl.Datetime(time_zone="Europe/Belgrade") + + result = pl.Series([value], dtype=dtype, strict=True) + expected = pl.Series([datetime(2001, 10, 28, 3)]).dt.replace_time_zone( + "Europe/Belgrade" + ) + assert_series_equal(result, expected) + + result = pl.Series([value], dtype=dtype, strict=False) + assert_series_equal(result, expected) + + +def test_series_init_nonexistent_datetime() -> None: + value = datetime(2024, 3, 31, 2, 30) + dtype = pl.Datetime(time_zone="Europe/Amsterdam") + + result = pl.Series([value], dtype=dtype, strict=True) + expected = pl.Series([datetime(2024, 3, 31, 4, 30)]).dt.replace_time_zone( + "Europe/Amsterdam" + ) + assert_series_equal(result, expected) + + result = pl.Series([value], dtype=dtype, strict=False) + assert_series_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/15518 +def test_series_init_np_temporal_with_nat_15518() -> None: + arr = np.array(["2020-01-01", "2020-01-02", "2020-01-03"], "datetime64[D]") + arr[1] = np.datetime64("NaT") + + result = pl.Series(arr) + + expected = pl.Series([date(2020, 1, 1), None, date(2020, 1, 3)]) + assert_series_equal(result, expected) + + +def test_series_init_pandas_timestamp_18127() -> None: + result = pl.Series([pd.Timestamp("2000-01-01T00:00:00.123456789", tz="UTC")]) + # Note: time unit is not (yet) respected, it should be Datetime('ns', 'UTC'). + assert result.dtype == pl.Datetime("us", "UTC") + + +def test_series_init_np_2d_zero_zero_shape() -> None: + arr = np.array([]).reshape(0, 0) + assert_series_equal( + pl.Series("a", arr), + pl.Series("a", [], pl.Array(pl.Float64, 0)), + ) + + +def test_series_init_np_2d_empty() -> None: + arr = np.array([]).reshape(0, 2) + assert_series_equal( + pl.Series("a", arr), + pl.Series("a", [], pl.Array(pl.Float64, 2)), + ) + + +def test_list_null_constructor_schema() -> None: + expected = pl.List(pl.Null) + assert pl.Series([[]]).dtype == expected + assert pl.Series([[]], dtype=pl.List).dtype == expected diff --git a/py-polars/tests/unit/constructors/test_strictness.py b/py-polars/tests/unit/constructors/test_strictness.py new file mode 100644 index 000000000000..f218206f3be8 --- /dev/null +++ b/py-polars/tests/unit/constructors/test_strictness.py @@ -0,0 +1,8 @@ +import pytest + +import polars as pl + + +def test_list_constructor_strictness() -> None: + with pytest.raises(TypeError, match="setting `strict=False`"): + pl.Series([[1], ["two"]], strict=True) diff --git a/py-polars/tests/unit/constructors/test_structs.py b/py-polars/tests/unit/constructors/test_structs.py new file mode 100644 index 000000000000..bd4f2f725ce6 --- /dev/null +++ b/py-polars/tests/unit/constructors/test_structs.py @@ -0,0 +1,149 @@ +import polars as pl + + +def test_constructor_non_strict_schema_17956() -> None: + schema = { + "logged_event": pl.Struct( + [ + pl.Field( + "completetask", + pl.Struct( + [ + pl.Field( + "parameters", + pl.List( + pl.Struct( + [ + pl.Field( + "numericarray", + pl.Struct( + [ + pl.Field( + "value", pl.List(pl.Float64) + ), + ] + ), + ), + ] + ) + ), + ), + ] + ), + ), + ] + ), + } + + data = { + "logged_event": { + "completetask": { + "parameters": [ + { + "numericarray": { + "value": [ + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 431, + 430.5, + 431, + 431, + 431, + ] + } + } + ] + } + } + } + + lazyframe = pl.LazyFrame( + [data], + schema=schema, + strict=False, + ) + assert lazyframe.collect().to_dict(as_series=False) == { + "logged_event": [ + { + "completetask": { + "parameters": [ + { + "numericarray": { + "value": [ + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 431.0, + 430.5, + 431.0, + 431.0, + 431.0, + ] + } + } + ] + } + } + ] + } diff --git a/py-polars/tests/unit/dataframe/__init__.py b/py-polars/tests/unit/dataframe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/dataframe/test_describe.py b/py-polars/tests/unit/dataframe/test_describe.py new file mode 100644 index 000000000000..a96ea5c192df --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_describe.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +from datetime import date, datetime, time + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_df_describe(lazy: bool) -> None: + df = pl.DataFrame( + { + "a": [1.0, 2.8, 3.0], + "b": [4, 5, None], + "c": [True, False, True], + "d": [None, "b", "c"], + "e": ["usd", "eur", None], + "f": [ + datetime(2020, 1, 1, 10, 30), + datetime(2021, 7, 5, 15, 0), + datetime(2022, 12, 31, 20, 30), + ], + "g": [date(2020, 1, 1), date(2021, 7, 5), date(2022, 12, 31)], + "h": [time(10, 30), time(15, 0), time(20, 30)], + "i": [1_000_000, 2_000_000, 3_000_000], + }, + schema_overrides={"e": pl.Categorical, "i": pl.Duration}, + ) + + frame: pl.DataFrame | pl.LazyFrame = df.lazy() if lazy else df + result = frame.describe() + print(result) + + expected = pl.DataFrame( + { + "statistic": [ + "count", + "null_count", + "mean", + "std", + "min", + "25%", + "50%", + "75%", + "max", + ], + "a": [ + 3.0, + 0.0, + 2.2666666666666666, + 1.1015141094572205, + 1.0, + 2.8, + 2.8, + 3.0, + 3.0, + ], + "b": [2.0, 1.0, 4.5, 0.7071067811865476, 4.0, 4.0, 5.0, 5.0, 5.0], + "c": [3.0, 0.0, 2 / 3, None, False, None, None, None, True], + "d": ["2", "1", None, None, "b", None, None, None, "c"], + "e": ["2", "1", None, None, None, None, None, None, None], + "f": [ + "3", + "0", + "2021-07-03 07:20:00", + None, + "2020-01-01 10:30:00", + "2021-07-05 15:00:00", + "2021-07-05 15:00:00", + "2022-12-31 20:30:00", + "2022-12-31 20:30:00", + ], + "g": [ + "3", + "0", + "2021-07-02 16:00:00", + None, + "2020-01-01", + "2021-07-05", + "2021-07-05", + "2022-12-31", + "2022-12-31", + ], + "h": [ + "3", + "0", + "15:20:00", + None, + "10:30:00", + "15:00:00", + "15:00:00", + "20:30:00", + "20:30:00", + ], + "i": [ + "3", + "0", + "0:00:02", + None, + "0:00:01", + "0:00:02", + "0:00:02", + "0:00:03", + "0:00:03", + ], + } + ) + assert_frame_equal(result, expected) + + +def test_df_describe_nested() -> None: + df = pl.DataFrame( + { + "struct": [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 1, "y": 2}, None], + "list": [[1, 2], [3, 4], [1, 2], None], + } + ) + result = df.describe() + expected = pl.DataFrame( + [ + ("count", 3, 3), + ("null_count", 1, 1), + ("mean", None, None), + ("std", None, None), + ("min", None, None), + ("25%", None, None), + ("50%", None, None), + ("75%", None, None), + ("max", None, None), + ], + schema=["statistic"] + df.columns, + schema_overrides={"struct": pl.Float64, "list": pl.Float64}, + orient="row", + ) + assert_frame_equal(result, expected) + + +def test_df_describe_custom_percentiles() -> None: + df = pl.DataFrame({"numeric": [1, 2, 1, None]}) + result = df.describe(percentiles=(0.2, 0.4, 0.5, 0.6, 0.8)) + expected = pl.DataFrame( + [ + ("count", 3.0), + ("null_count", 1.0), + ("mean", 1.3333333333333333), + ("std", 0.5773502691896257), + ("min", 1.0), + ("20%", 1.0), + ("40%", 1.0), + ("50%", 1.0), + ("60%", 1.0), + ("80%", 2.0), + ("max", 2.0), + ], + schema=["statistic"] + df.columns, + orient="row", + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("pcts", [None, []]) +def test_df_describe_no_percentiles(pcts: list[float] | None) -> None: + df = pl.DataFrame({"numeric": [1, 2, 1, None]}) + result = df.describe(percentiles=pcts) + expected = pl.DataFrame( + [ + ("count", 3.0), + ("null_count", 1.0), + ("mean", 1.3333333333333333), + ("std", 0.5773502691896257), + ("min", 1.0), + ("max", 2.0), + ], + schema=["statistic"] + df.columns, + orient="row", + ) + assert_frame_equal(result, expected) + + +def test_df_describe_empty_column() -> None: + df = pl.DataFrame(schema={"a": pl.Int64}) + result = df.describe() + expected = pl.DataFrame( + [ + ("count", 0.0), + ("null_count", 0.0), + ("mean", None), + ("std", None), + ("min", None), + ("25%", None), + ("50%", None), + ("75%", None), + ("max", None), + ], + schema=["statistic"] + df.columns, + orient="row", + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_df_describe_empty(lazy: bool) -> None: + frame: pl.DataFrame | pl.LazyFrame = pl.LazyFrame() if lazy else pl.DataFrame() + cls_name = "LazyFrame" if lazy else "DataFrame" + with pytest.raises( + TypeError, match=f"cannot describe a {cls_name} that has no columns" + ): + frame.describe() + + +def test_df_describe_quantile_precision() -> None: + df = pl.DataFrame({"a": range(10)}) + result = df.describe(percentiles=[0.99, 0.999, 0.9999]) + result_metrics = result.get_column("statistic").to_list() + expected_metrics = ["99%", "99.9%", "99.99%"] + for m in expected_metrics: + assert m in result_metrics + + +# https://github.com/pola-rs/polars/issues/9830 +def test_df_describe_object() -> None: + df = pl.Series( + "object", + [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], + dtype=pl.Object, + ).to_frame() + + result = df.describe(percentiles=(0.05, 0.25, 0.5, 0.75, 0.95)) + + expected = pl.DataFrame( + {"statistic": ["count", "null_count"], "object": ["3", "0"]} + ) + assert_frame_equal(result.head(2), expected) diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py new file mode 100644 index 000000000000..9cc80a672ff6 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -0,0 +1,3059 @@ +from __future__ import annotations + +import sys +import typing +from collections import OrderedDict +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from io import BytesIO +from operator import floordiv, truediv +from typing import TYPE_CHECKING, Any, Callable, cast +from zoneinfo import ZoneInfo + +import numpy as np +import pyarrow as pa +import pytest + +import polars as pl +import polars.selectors as cs +from polars._utils.construction import iterable_to_pydf +from polars.datatypes import DTYPE_TEMPORAL_UNITS +from polars.exceptions import ( + ColumnNotFoundError, + ComputeError, + DuplicateError, + InvalidOperationError, + OutOfBoundsError, + ShapeError, +) +from polars.testing import ( + assert_frame_equal, + assert_frame_not_equal, + assert_series_equal, +) +from tests.unit.conftest import INTEGER_DTYPES + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from polars import Expr + from polars._typing import JoinStrategy, UniqueKeepStrategy + + +def test_version() -> None: + pl.__version__ + + +def test_null_count() -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": ["a", "b", None]}) + assert df.null_count().shape == (1, 2) + assert df.null_count().row(0) == (0, 1) + assert df.null_count().row(np.int64(0)) == (0, 1) # type: ignore[call-overload] + + +@pytest.mark.parametrize("input", [None, (), [], {}, pa.Table.from_arrays([])]) +def test_init_empty(input: Any) -> None: + # test various flavours of empty init + df = pl.DataFrame(input) + assert df.shape == (0, 0) + assert df.is_empty() + + +def test_df_bool_ambiguous() -> None: + empty_df = pl.DataFrame() + with pytest.raises(TypeError, match="ambiguous"): + not empty_df + + +def test_special_char_colname_init() -> None: + from string import punctuation + + cols = [(c, pl.Int8) for c in punctuation] + df = pl.DataFrame(schema=cols) + + assert len(cols) == df.width + assert len(df.rows()) == 0 + assert df.is_empty() + + +def test_comparisons() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + + # Constants + assert_frame_equal(df == 2, pl.DataFrame({"a": [False, True], "b": [False, False]})) + assert_frame_equal(df != 2, pl.DataFrame({"a": [True, False], "b": [True, True]})) + assert_frame_equal(df < 3.0, pl.DataFrame({"a": [True, True], "b": [False, False]})) + assert_frame_equal(df >= 2, pl.DataFrame({"a": [False, True], "b": [True, True]})) + assert_frame_equal(df <= 2, pl.DataFrame({"a": [True, True], "b": [False, False]})) + + with pytest.raises(ComputeError): + df > "2" # noqa: B015 + + # Series + s = pl.Series([3, 1]) + assert_frame_equal(df >= s, pl.DataFrame({"a": [False, True], "b": [True, True]})) + + # DataFrame + other = pl.DataFrame({"a": [1, 2], "b": [2, 3]}) + assert_frame_equal( + df == other, pl.DataFrame({"a": [True, True], "b": [False, False]}) + ) + assert_frame_equal( + df != other, pl.DataFrame({"a": [False, False], "b": [True, True]}) + ) + assert_frame_equal( + df > other, pl.DataFrame({"a": [False, False], "b": [True, True]}) + ) + assert_frame_equal( + df < other, pl.DataFrame({"a": [False, False], "b": [False, False]}) + ) + assert_frame_equal( + df >= other, pl.DataFrame({"a": [True, True], "b": [True, True]}) + ) + assert_frame_equal( + df <= other, pl.DataFrame({"a": [True, True], "b": [False, False]}) + ) + + # DataFrame columns mismatch + with pytest.raises(ValueError): + df == pl.DataFrame({"a": [1, 2], "c": [3, 4]}) # noqa: B015 + with pytest.raises(ValueError): + df == pl.DataFrame({"b": [3, 4], "a": [1, 2]}) # noqa: B015 + + # DataFrame shape mismatch + with pytest.raises(ValueError): + df == pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) # noqa: B015 + + # Type mismatch + with pytest.raises(ComputeError): + df == pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) # noqa: B015 + + +def test_column_selection() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]}) + + # get column by name + b = pl.Series("b", [1.0, 2.0, 3.0]) + assert_series_equal(df["b"], b) + assert_series_equal(df.get_column("b"), b) + + with pytest.raises(ColumnNotFoundError, match="x"): + df.get_column("x") + + default_series = pl.Series("x", ["?", "?", "?"]) + assert_series_equal(df.get_column("x", default=default_series), default_series) + + assert df.get_column("x", default=None) is None + + # get column by index + assert_series_equal(df.to_series(1), pl.Series("b", [1.0, 2.0, 3.0])) + assert_series_equal(df.to_series(-1), pl.Series("c", ["a", "b", "c"])) + + +def test_mixed_sequence_selection() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df.select(["a", pl.col("b"), pl.lit("c")]) + expected = pl.DataFrame({"a": [1, 2], "b": [3, 4], "literal": ["c", "c"]}) + assert_frame_equal(result, expected) + + +def test_from_arrow(monkeypatch: Any) -> None: + tbl = pa.table( + { + "a": pa.array([1, 2], pa.timestamp("s")), + "b": pa.array([1, 2], pa.timestamp("ms")), + "c": pa.array([1, 2], pa.timestamp("us")), + "d": pa.array([1, 2], pa.timestamp("ns")), + "e": pa.array([1, 2], pa.int32()), + "decimal1": pa.array([1, 2], pa.decimal128(2, 1)), + "struct": pa.array( + [{"a": 1}, {"a": 2}], pa.struct([pa.field("a", pa.int32())]) + ), + } + ) + record_batches = tbl.to_batches(max_chunksize=1) + expected_schema = { + "a": pl.Datetime("ms"), + "b": pl.Datetime("ms"), + "c": pl.Datetime("us"), + "d": pl.Datetime("ns"), + "e": pl.Int32, + "decimal1": pl.Decimal(2, 1), + "struct": pl.Struct({"a": pl.Int32()}), + } + expected_data = [ + ( + datetime(1970, 1, 1, 0, 0, 1), + datetime(1970, 1, 1, 0, 0, 0, 1000), + datetime(1970, 1, 1, 0, 0, 0, 1), + datetime(1970, 1, 1, 0, 0), + 1, + Decimal("1.0"), + {"a": 1}, + ), + ( + datetime(1970, 1, 1, 0, 0, 2), + datetime(1970, 1, 1, 0, 0, 0, 2000), + datetime(1970, 1, 1, 0, 0, 0, 2), + datetime(1970, 1, 1, 0, 0), + 2, + Decimal("2.0"), + {"a": 2}, + ), + ] + for arrow_data in (tbl, record_batches, (rb for rb in record_batches)): + df = cast(pl.DataFrame, pl.from_arrow(arrow_data)) + assert df.schema == expected_schema + assert df.rows() == expected_data + + # record batches (inc. empty) + for b, n_expected in ( + (record_batches[0], 1), + (record_batches[0][:0], 0), + ): + df = cast(pl.DataFrame, pl.from_arrow(b)) + assert df.schema == expected_schema + assert df.rows() == expected_data[:n_expected] + + empty_tbl = tbl[:0] # no rows + df = cast(pl.DataFrame, pl.from_arrow(empty_tbl)) + assert df.schema == expected_schema + assert df.rows() == [] + + # try a single column dtype override + for t in (tbl, empty_tbl): + df = pl.DataFrame(t, schema_overrides={"e": pl.Int8}) + override_schema = expected_schema.copy() + override_schema["e"] = pl.Int8 + assert df.schema == override_schema + assert df.rows() == expected_data[: (df.height)] + + # init from record batches with overrides + df = pl.DataFrame( + { + "id": ["a123", "b345", "c567", "d789", "e101"], + "points": [99, 45, 50, 85, 35], + } + ) + tbl = df.to_arrow() + batches = tbl.to_batches(max_chunksize=3) + + df0: pl.DataFrame = pl.from_arrow(batches) # type: ignore[assignment] + df1: pl.DataFrame = pl.from_arrow( # type: ignore[assignment] + data=batches, + schema=["x", "y"], + schema_overrides={"y": pl.Int32}, + ) + df2: pl.DataFrame = pl.from_arrow( # type: ignore[assignment] + data=batches[0], + schema=["x", "y"], + schema_overrides={"y": pl.Int32}, + ) + + assert df0.rows() == df.rows() + assert df1.rows() == df.rows() + assert df2.rows() == df.rows()[:3] + + assert df0.schema == {"id": pl.String, "points": pl.Int64} + print(df1.schema) + assert df1.schema == {"x": pl.String, "y": pl.Int32} + assert df2.schema == {"x": pl.String, "y": pl.Int32} + + with pytest.raises(TypeError, match="Cannot convert str"): + pl.from_arrow(data="xyz") + + with pytest.raises(TypeError, match="Cannot convert int"): + pl.from_arrow(data=(x for x in (1, 2, 3))) + + +@pytest.mark.parametrize( + "data", + [ + pa.Table.from_pydict( + { + "struct": pa.array( + [{"a": 1}, {"a": 2}], pa.struct([pa.field("a", pa.int32())]) + ), + } + ), + pa.Table.from_pydict( + { + "struct": pa.chunked_array( + [[{"a": 1}], [{"a": 2}]], pa.struct([pa.field("a", pa.int32())]) + ), + } + ), + ], +) +def test_from_arrow_struct_column(data: pa.Table) -> None: + df = cast(pl.DataFrame, pl.from_arrow(data=data)) + expected_schema = pl.Schema({"struct": pl.Struct({"a": pl.Int32()})}) + expected_data = [({"a": 1},), ({"a": 2},)] + assert df.schema == expected_schema + assert df.rows() == expected_data + + +def test_dataframe_membership_operator() -> None: + # cf. issue #4032 + df = pl.DataFrame({"name": ["Jane", "John"], "age": [20, 30]}) + assert "name" in df + assert "phone" not in df + assert df._ipython_key_completions_() == ["name", "age"] + + +def test_sort() -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3]}) + expected = pl.DataFrame({"a": [1, 2, 3], "b": [2, 1, 3]}) + assert_frame_equal(df.sort("a"), expected) + assert_frame_equal(df.sort(["a", "b"]), expected) + + +def test_sort_multi_output_exprs_01() -> None: + df = pl.DataFrame( + { + "dts": [date(2077, 10, 3), date(2077, 10, 2), date(2077, 10, 2)], + "strs": ["abc", "def", "ghi"], + "vals": [10.5, 20.3, 15.7], + } + ) + + expected = pl.DataFrame( + { + "dts": [date(2077, 10, 2), date(2077, 10, 2), date(2077, 10, 3)], + "strs": ["ghi", "def", "abc"], + "vals": [15.7, 20.3, 10.5], + } + ) + assert_frame_equal(expected, df.sort(pl.col("^(d|v).*$"))) + assert_frame_equal(expected, df.sort(cs.temporal() | cs.numeric())) + assert_frame_equal(expected, df.sort(cs.temporal(), cs.numeric(), cs.binary())) + + expected = pl.DataFrame( + { + "dts": [date(2077, 10, 3), date(2077, 10, 2), date(2077, 10, 2)], + "strs": ["abc", "def", "ghi"], + "vals": [10.5, 20.3, 15.7], + } + ) + assert_frame_equal( + expected, + df.sort(pl.col("^(d|v).*$"), descending=[True]), + ) + assert_frame_equal( + expected, + df.sort(cs.temporal() | cs.numeric(), descending=[True]), + ) + assert_frame_equal( + expected, + df.sort(cs.temporal(), cs.numeric(), descending=[True, True]), + ) + + with pytest.raises( + ValueError, + match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", + ): + df.sort(by=[cs.temporal()], descending=[True, False]) + + with pytest.raises( + ValueError, + match=r"the length of `nulls_last` \(3\) does not match the length of `by` \(2\)", + ): + df.sort("dts", "strs", nulls_last=[True, False, True]) + + # No columns selected - return original input. + assert_frame_equal(df, df.sort(pl.col("^xxx$"))) + + +@pytest.mark.parametrize( + ("by_explicit", "desc_explicit", "by_multi", "desc_multi"), + [ + ( + ["w", "x", "y", "z"], + [False, False, True, True], + [cs.integer(), cs.string()], + [False, True], + ), + ( + ["w", "y", "z"], + [True, True, False], + [pl.col("^(w|y)$"), pl.col("^z.*$")], + [True, False], + ), + ( + ["z", "w", "x"], + [True, False, False], + [pl.col("z"), cs.numeric()], + [True, False], + ), + ], +) +def test_sort_multi_output_exprs_02( + by_explicit: list[str], + desc_explicit: list[bool], + by_multi: list[Expr], + desc_multi: list[bool], +) -> None: + df = pl.DataFrame( + { + "w": [100, 100, 100, 100, 200, 200, 200, 200], + "x": [888, 888, 444, 444, 888, 888, 444, 888], + "y": ["b", "b", "a", "a", "b", "b", "a", "a"], + "z": ["x", "y", "x", "y", "x", "y", "x", "y"], + } + ) + res1 = df.sort(*by_explicit, descending=desc_explicit) + res2 = df.sort(*by_multi, descending=desc_multi) + assert_frame_equal(res1, res2) + + +def test_sort_maintain_order() -> None: + l1 = ( + pl.LazyFrame({"A": [1] * 4, "B": ["A", "B", "C", "D"]}) + .sort("A", maintain_order=True) + .slice(0, 3) + .collect()["B"] + .to_list() + ) + l2 = ( + pl.LazyFrame({"A": [1] * 4, "B": ["A", "B", "C", "D"]}) + .sort("A") + .collect() + .slice(0, 3)["B"] + .to_list() + ) + assert l1 == l2 == ["A", "B", "C"] + + +@pytest.mark.parametrize("nulls_last", [False, True], ids=["nulls_first", "nulls_last"]) +def test_sort_maintain_order_descending_repeated_nulls(nulls_last: bool) -> None: + got = ( + pl.LazyFrame({"A": [None, -1, 1, 1, None], "B": [1, 2, 3, 4, 5]}) + .sort("A", descending=True, maintain_order=True, nulls_last=nulls_last) + .collect() + ) + if nulls_last: + expect = pl.DataFrame({"A": [1, 1, -1, None, None], "B": [3, 4, 2, 1, 5]}) + else: + expect = pl.DataFrame({"A": [None, None, 1, 1, -1], "B": [1, 5, 3, 4, 2]}) + assert_frame_equal(got, expect) + + +def test_replace() -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3]}) + s = pl.Series("c", [True, False, True]) + df._replace("a", s) + assert_frame_equal(df, pl.DataFrame({"a": [True, False, True], "b": [1, 2, 3]})) + + +def test_assignment() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [2, 3, 4]}) + df = df.with_columns(pl.col("foo").alias("foo")) + # make sure that assignment does not change column order + assert df.columns == ["foo", "bar"] + df = df.with_columns( + pl.when(pl.col("foo") > 1).then(9).otherwise(pl.col("foo")).alias("foo") + ) + assert df["foo"].to_list() == [1, 9, 9] + + +def test_insert_column() -> None: + # insert series + df = ( + pl.DataFrame({"z": [3, 4, 5]}) + .insert_column(0, pl.Series("x", [1, 2, 3])) + .insert_column(-1, pl.Series("y", [2, 3, 4])) + ) + expected_df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]}) + assert_frame_equal(expected_df, df) + + # insert expressions + df = pl.DataFrame( + { + "id": ["xx", "yy", "zz"], + "v1": [5, 4, 6], + "v2": [7, 3, 3], + } + ) + df.insert_column(3, (pl.col("v1") * pl.col("v2")).alias("v3")) + df.insert_column(1, (pl.col("v2") - pl.col("v1")).alias("v0")) + + expected = pl.DataFrame( + { + "id": ["xx", "yy", "zz"], + "v0": [2, -1, -3], + "v1": [5, 4, 6], + "v2": [7, 3, 3], + "v3": [35, 12, 18], + } + ) + assert_frame_equal(df, expected) + + # check that we raise suitable index errors + for idx, column in ( + (10, pl.col("v1").sqrt().alias("v1_sqrt")), + (-10, pl.Series("foo", [1, 2, 3])), + ): + with pytest.raises( + IndexError, + match=rf"column index {idx} is out of range \(frame has 5 columns\)", + ): + df.insert_column(idx, column) + + +def test_replace_column() -> None: + df = ( + pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]}) + .replace_column(0, pl.Series("a", [4, 5, 6])) + .replace_column(-2, pl.Series("b", [5, 6, 7])) + .replace_column(-1, pl.Series("c", [6, 7, 8])) + ) + expected_df = pl.DataFrame({"a": [4, 5, 6], "b": [5, 6, 7], "c": [6, 7, 8]}) + assert_frame_equal(expected_df, df) + + +def test_to_series() -> None: + df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]}) + + assert_series_equal(df.to_series(), df["x"]) + assert_series_equal(df.to_series(0), df["x"]) + assert_series_equal(df.to_series(-3), df["x"]) + + assert_series_equal(df.to_series(1), df["y"]) + assert_series_equal(df.to_series(-2), df["y"]) + + assert_series_equal(df.to_series(2), df["z"]) + assert_series_equal(df.to_series(-1), df["z"]) + + +def test_to_series_bad_inputs() -> None: + df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]}) + + with pytest.raises(IndexError, match="index 5 is out of bounds"): + df.to_series(5) + + with pytest.raises(IndexError, match="index -100 is out of bounds"): + df.to_series(-100) + + with pytest.raises( + TypeError, match="'str' object cannot be interpreted as an integer" + ): + df.to_series("x") # type: ignore[arg-type] + + +def test_gather_every() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": ["w", "x", "y", "z"]}) + expected_df = pl.DataFrame({"a": [1, 3], "b": ["w", "y"]}) + assert_frame_equal(expected_df, df.gather_every(2)) + + expected_df = pl.DataFrame({"a": [2, 4], "b": ["x", "z"]}) + assert_frame_equal(expected_df, df.gather_every(2, offset=1)) + + +def test_gather_every_agg() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 2, 2, 2], + "a": ["a", "b", "c", "d", "e", "f"], + } + ) + out = df.group_by(pl.col("g")).agg(pl.col("a").gather_every(2)).sort("g") + expected = pl.DataFrame( + { + "g": [1, 2], + "a": [["a", "c"], ["d", "f"]], + } + ) + assert_frame_equal(out, expected) + + +def test_take_misc(fruits_cars: pl.DataFrame) -> None: + df = fruits_cars + + # Out of bounds error. + with pytest.raises(OutOfBoundsError): + df.sort("fruits").select( + pl.col("B").reverse().gather([1, 2]).implode().over("fruits"), + "fruits", + ) + + # Null indices. + assert_frame_equal( + df.select(pl.col("fruits").gather(pl.Series([0, None]))), + pl.DataFrame({"fruits": ["banana", None]}), + ) + + for index in [[0, 1], pl.Series([0, 1]), np.array([0, 1])]: + out = df.sort("fruits").select( + pl.col("B") + .reverse() + .gather(index) # type: ignore[arg-type] + .over("fruits", mapping_strategy="join"), + "fruits", + ) + + assert out[0, "B"].to_list() == [2, 3] + assert out[4, "B"].to_list() == [1, 4] + + out = df.sort("fruits").select( + pl.col("B").reverse().get(pl.lit(1)).over("fruits"), + "fruits", + ) + assert out[0, "B"] == 3 + assert out[4, "B"] == 4 + + +def test_pipe() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, None, 8]}) + + def _multiply(data: pl.DataFrame, mul: int) -> pl.DataFrame: + return data * mul + + result = df.pipe(_multiply, mul=3) + + assert_frame_equal(result, df * 3) + + +def test_explode() -> None: + df = pl.DataFrame({"letters": ["c", "a"], "nrs": [[1, 2], [1, 3]]}) + out = df.explode("nrs") + assert out["letters"].to_list() == ["c", "c", "a", "a"] + assert out["nrs"].to_list() == [1, 2, 1, 3] + + +@pytest.mark.parametrize( + ("stack", "exp_shape", "exp_columns"), + [ + ([pl.Series("stacked", [-1, -1, -1])], (3, 3), ["a", "b", "stacked"]), + ( + [pl.Series("stacked2", [-1, -1, -1]), pl.Series("stacked3", [-1, -1, -1])], + (3, 4), + ["a", "b", "stacked2", "stacked3"], + ), + ], +) +@pytest.mark.parametrize("in_place", [True, False]) +def test_hstack_list_of_series( + stack: list[pl.Series], + exp_shape: tuple[int, int], + exp_columns: list[str], + in_place: bool, +) -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"]}) + if in_place: + df.hstack(stack, in_place=True) + assert df.shape == exp_shape + assert df.columns == exp_columns + else: + df_out = df.hstack(stack, in_place=False) + assert df_out.shape == exp_shape + assert df_out.columns == exp_columns + + +@pytest.mark.parametrize("in_place", [True, False]) +def test_hstack_dataframe(in_place: bool) -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"]}) + df2 = pl.DataFrame({"c": [2, 1, 3], "d": ["a", "b", "c"]}) + expected = pl.DataFrame( + {"a": [2, 1, 3], "b": ["a", "b", "c"], "c": [2, 1, 3], "d": ["a", "b", "c"]} + ) + if in_place: + df.hstack(df2, in_place=True) + assert_frame_equal(df, expected) + else: + df_out = df.hstack(df2, in_place=False) + assert_frame_equal(df_out, expected) + + +def test_file_buffer() -> None: + f = BytesIO() + f.write(b"1,2,3,4,5,6\n7,8,9,10,11,12") + f.seek(0) + df = pl.read_csv(f, has_header=False) + assert df.shape == (2, 6) + + f = BytesIO() + f.write(b"1,2,3,4,5,6\n7,8,9,10,11,12") + f.seek(0) + # check if not fails on TryClone and Length impl in file.rs + with pytest.raises(ComputeError): + pl.read_parquet(f) + + +def test_shift() -> None: + df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5]}) + a = df.shift(1) + b = pl.DataFrame( + {"A": [None, "a", "b"], "B": [None, 1, 3]}, + ) + assert_frame_equal(a, b) + + +def test_custom_group_by() -> None: + df = pl.DataFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]}) + out = df.group_by("b", maintain_order=True).agg( + [pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64)] + ) + assert out.rows() == [("a", 1), ("b", 2), ("c", 2)] + + +def test_multiple_columns_drop() -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) + # List input + out = df.drop(["a", "b"]) + assert out.columns == ["c"] + # Positional input + out = df.drop("b", "c") + assert out.columns == ["a"] + + +def test_arg_where() -> None: + s = pl.Series([True, False, True, False]) + assert_series_equal( + pl.arg_where(s, eager=True).cast(int), + pl.Series([0, 2]), + ) + + +def test_to_dummies() -> None: + df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5]}) + dummies = df.to_dummies() + + assert dummies["A_a"].to_list() == [1, 0, 0] + assert dummies["A_b"].to_list() == [0, 1, 0] + assert dummies["A_c"].to_list() == [0, 0, 1] + + df = pl.DataFrame({"a": [1, 2, 3]}) + res = df.to_dummies() + + expected = pl.DataFrame( + {"a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]} + ).with_columns(pl.all().cast(pl.UInt8)) + assert_frame_equal(res, expected) + + df = pl.DataFrame( + { + "i": [1, 2, 3], + "category": ["dog", "cat", "cat"], + }, + schema={"i": pl.Int32, "category": pl.Categorical("lexical")}, + ) + expected = pl.DataFrame( + { + "i": [1, 2, 3], + "category|cat": [0, 1, 1], + "category|dog": [1, 0, 0], + }, + schema={"i": pl.Int32, "category|cat": pl.UInt8, "category|dog": pl.UInt8}, + ) + for _cols in ("category", cs.string()): + result = df.to_dummies(columns=["category"], separator="|") + assert_frame_equal(result, expected) + + # test sorted fast path + result = pl.DataFrame({"x": pl.arange(0, 3, eager=True)}).to_dummies("x") + expected = pl.DataFrame( + {"x_0": [1, 0, 0], "x_1": [0, 1, 0], "x_2": [0, 0, 1]} + ).with_columns(pl.all().cast(pl.UInt8)) + assert_frame_equal(result, expected) + + +def test_to_dummies_drop_first() -> None: + df = pl.DataFrame( + { + "foo": [0, 1, 2], + "bar": [3, 4, 5], + "baz": ["x", "y", "z"], + } + ) + dm = df.to_dummies() + dd = df.to_dummies(drop_first=True) + + assert dd.columns == ["foo_1", "foo_2", "bar_4", "bar_5", "baz_y", "baz_z"] + assert set(dm.columns) - set(dd.columns) == {"foo_0", "bar_3", "baz_x"} + assert_frame_equal(dm.select(dd.columns), dd) + assert dd.rows() == [ + (0, 0, 0, 0, 0, 0), + (1, 0, 1, 0, 1, 0), + (0, 1, 0, 1, 0, 1), + ] + + +def test_to_pandas(df: pl.DataFrame) -> None: + # pyarrow cannot deal with unsigned dictionary integer yet. + # pyarrow cannot convert a time64 w/ non-zero nanoseconds + df = df.drop(["cat", "time", "enum"]) + df.to_arrow() + df.to_pandas() + # test shifted df + df.shift(2).to_pandas() + df = pl.DataFrame({"col": pl.Series([True, False, True])}) + df.shift(2).to_pandas() + + +def test_from_arrow_table() -> None: + data = {"a": [1, 2], "b": [1, 2]} + tbl = pa.table(data) + + df = cast(pl.DataFrame, pl.from_arrow(tbl)) + assert_frame_equal(df, pl.DataFrame(data)) + + +def test_df_stats(df: pl.DataFrame) -> None: + df.var() + df.std() + df.min() + df.max() + df.sum() + df.mean() + df.median() + df.quantile(0.4, "nearest") + + +def test_df_fold() -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]}) + + assert_series_equal( + df.fold(lambda s1, s2: s1 + s2), pl.Series("a", [4.0, 5.0, 9.0]) + ) + assert_series_equal( + df.fold(lambda s1, s2: s1.zip_with(s1 < s2, s2)), + pl.Series("a", [1.0, 1.0, 3.0]), + ) + + df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]}) + out = df.fold(lambda s1, s2: s1 + s2) + assert_series_equal(out, pl.Series("a", ["foo11.0", "bar22.0", "233.0"])) + + df = pl.DataFrame({"a": [3, 2, 1], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]}) + # just check dispatch. values are tested on rust side. + assert len(df.sum_horizontal()) == 3 + assert len(df.mean_horizontal()) == 3 + assert len(df.min_horizontal()) == 3 + assert len(df.max_horizontal()) == 3 + + df_width_one = df[["a"]] + assert_series_equal(df_width_one.fold(lambda s1, s2: s1), df["a"]) + + +def test_fold_filter() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + + out = df.filter( + pl.fold( + acc=pl.lit(True), + function=lambda a, b: a & b, + exprs=[pl.col(c) > 1 for c in df.columns], + ) + ) + + assert out.shape == (1, 2) + assert out.rows() == [(3, 2)] + + out = df.filter( + pl.fold( + acc=pl.lit(True), + function=lambda a, b: a | b, + exprs=[pl.col(c) > 1 for c in df.columns], + ) + ) + + assert out.shape == (3, 2) + assert out.rows() == [(1, 0), (2, 1), (3, 2)] + + +def test_column_names() -> None: + tbl = pa.table( + { + "a": pa.array([1, 2, 3, 4, 5], pa.decimal128(38, 2)), + "b": pa.array([1, 2, 3, 4, 5], pa.int64()), + } + ) + for a in (tbl, tbl[:0]): + df = cast(pl.DataFrame, pl.from_arrow(a)) + assert df.columns == ["a", "b"] + + +def test_init_series_edge_cases() -> None: + # confirm that we don't modify the name of the input series in-place + s1 = pl.Series("X", [1, 2, 3]) + df1 = pl.DataFrame({"A": s1}, schema_overrides={"A": pl.UInt8}) + assert s1.name == "X" + assert df1["A"].name == "A" + + # init same series object under different names + df2 = pl.DataFrame({"A": s1, "B": s1}) + assert df2.rows(named=True) == [ + {"A": 1, "B": 1}, + {"A": 2, "B": 2}, + {"A": 3, "B": 3}, + ] + + # empty series names should not be overwritten + s2 = pl.Series([1, 2, 3]) + s3 = pl.Series([2, 3, 4]) + df3 = pl.DataFrame([s2, s3]) + assert s2.name == s3.name == "" + assert df3.columns == ["column_0", "column_1"] + + +def test_head_group_by() -> None: + commodity_prices = { + "commodity": [ + "Wheat", + "Wheat", + "Wheat", + "Wheat", + "Corn", + "Corn", + "Corn", + "Corn", + "Corn", + ], + "location": [ + "StPaul", + "StPaul", + "StPaul", + "Chicago", + "Chicago", + "Chicago", + "Chicago", + "Chicago", + "Chicago", + ], + "seller": [ + "Bob", + "Charlie", + "Susan", + "Paul", + "Ed", + "Mary", + "Paul", + "Charlie", + "Norman", + ], + "price": [1.0, 0.7, 0.8, 0.55, 2.0, 3.0, 2.4, 1.8, 2.1], + } + df = pl.DataFrame(commodity_prices) + + # this query flexes the wildcard exclusion quite a bit. + keys = ["commodity", "location"] + out = ( + df.sort(by="price", descending=True) + .group_by(keys, maintain_order=True) + .agg([pl.col("*").exclude(keys).head(2).name.keep()]) + .explode(pl.col("*").exclude(keys)) + ) + + assert out.shape == (5, 4) + assert out.rows() == [ + ("Corn", "Chicago", "Mary", 3.0), + ("Corn", "Chicago", "Paul", 2.4), + ("Wheat", "StPaul", "Bob", 1.0), + ("Wheat", "StPaul", "Susan", 0.8), + ("Wheat", "Chicago", "Paul", 0.55), + ] + + df = pl.DataFrame( + {"letters": ["c", "c", "a", "c", "a", "b"], "nrs": [1, 2, 3, 4, 5, 6]} + ) + out = df.group_by("letters").tail(2).sort("letters") + assert_frame_equal( + out, + pl.DataFrame({"letters": ["a", "a", "b", "c", "c"], "nrs": [3, 5, 6, 2, 4]}), + ) + out = df.group_by("letters").head(2).sort("letters") + assert_frame_equal( + out, + pl.DataFrame({"letters": ["a", "a", "b", "c", "c"], "nrs": [3, 5, 6, 1, 2]}), + ) + + +def test_is_null_is_not_null() -> None: + df = pl.DataFrame({"nrs": [1, 2, None]}) + assert df.select(pl.col("nrs").is_null())["nrs"].to_list() == [False, False, True] + assert df.select(pl.col("nrs").is_not_null())["nrs"].to_list() == [ + True, + True, + False, + ] + + +def test_is_nan_is_not_nan() -> None: + df = pl.DataFrame({"nrs": np.array([1, 2, np.nan])}) + assert df.select(pl.col("nrs").is_nan())["nrs"].to_list() == [False, False, True] + assert df.select(pl.col("nrs").is_not_nan())["nrs"].to_list() == [True, True, False] + + +def test_is_finite_is_infinite() -> None: + df = pl.DataFrame({"nrs": np.array([1, 2, np.inf])}) + assert df.select(pl.col("nrs").is_infinite())["nrs"].to_list() == [ + False, + False, + True, + ] + assert df.select(pl.col("nrs").is_finite())["nrs"].to_list() == [True, True, False] + + +def test_is_finite_is_infinite_null_series() -> None: + df = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Null)}) + result = df.select( + pl.col("a").is_finite().alias("finite"), + pl.col("a").is_infinite().alias("infinite"), + ) + expected = pl.DataFrame( + { + "finite": pl.Series([None, None, None], dtype=pl.Boolean), + "infinite": pl.Series([None, None, None], dtype=pl.Boolean), + } + ) + assert_frame_equal(result, expected) + + +def test_is_nan_null_series() -> None: + df = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Null)}) + result = df.select(pl.col("a").is_nan()) + expected = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Boolean)}) + assert_frame_equal(result, expected) + + +def test_len() -> None: + df = pl.DataFrame({"nrs": [1, 2, 3]}) + assert cast(int, df.select(pl.col("nrs").len()).item()) == 3 + assert len(pl.DataFrame()) == 0 + + +def test_multiple_column_sort() -> None: + df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [2, 2, 3], "c": [1.0, 2.0, 3.0]}) + out = df.sort([pl.col("b"), pl.col("c").reverse()]) + assert list(out["c"]) == [2.0, 1.0, 3.0] + assert list(out["b"]) == [2, 2, 3] + + # Explicitly specify numpy dtype because of different defaults on Windows + df = pl.DataFrame({"a": np.arange(1, 4, dtype=np.int64), "b": ["a", "a", "b"]}) + + assert_frame_equal( + df.sort("a", descending=True), + pl.DataFrame({"a": [3, 2, 1], "b": ["b", "a", "a"]}), + ) + assert_frame_equal( + df.sort("b", descending=True, maintain_order=True), + pl.DataFrame({"a": [3, 1, 2], "b": ["b", "a", "a"]}), + ) + assert_frame_equal( + df.sort(["b", "a"], descending=[False, True]), + pl.DataFrame({"a": [2, 1, 3], "b": ["a", "a", "b"]}), + ) + + +def test_cast_frame() -> None: + df = pl.DataFrame( + { + "a": [1.0, 2.5, 3.0], + "b": [4, 5, None], + "c": [True, False, True], + "d": [date(2020, 1, 2), date(2021, 3, 4), date(2022, 5, 6)], + } + ) + + # cast via col:dtype map + assert df.cast( + dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")}, + ).schema == { + "a": pl.Float64, + "b": pl.Float32, + "c": pl.String, + "d": pl.Datetime("ms"), + } + + # cast via col:pytype map + assert df.cast( + dtypes={"b": float, "c": str, "d": datetime}, + ).schema == { + "a": pl.Float64, + "b": pl.Float64, + "c": pl.String, + "d": pl.Datetime("us"), + } + + # cast via selector:dtype map + assert df.cast( + { + cs.numeric(): pl.UInt8, + cs.temporal(): pl.String, + } + ).rows() == [ + (1, 4, True, "2020-01-02"), + (2, 5, False, "2021-03-04"), + (3, None, True, "2022-05-06"), + ] + + # cast all fields to a single type + assert df.cast(pl.String).to_dict(as_series=False) == { + "a": ["1.0", "2.5", "3.0"], + "b": ["4", "5", None], + "c": ["true", "false", "true"], + "d": ["2020-01-02", "2021-03-04", "2022-05-06"], + } + + +def test_duration_arithmetic() -> None: + df = pl.DataFrame( + {"a": [datetime(2022, 1, 1, 0, 0, 0), datetime(2022, 1, 2, 0, 0, 0)]} + ) + d1 = pl.duration(days=3, microseconds=987000) + d2 = pl.duration(days=6, milliseconds=987) + + assert_frame_equal( + df.with_columns( + b=(df["a"] + d1), + c=(pl.col("a") + d2), + ), + pl.DataFrame( + { + "a": [ + datetime(2022, 1, 1, 0, 0, 0), + datetime(2022, 1, 2, 0, 0, 0), + ], + "b": [ + datetime(2022, 1, 4, 0, 0, 0, 987000), + datetime(2022, 1, 5, 0, 0, 0, 987000), + ], + "c": [ + datetime(2022, 1, 7, 0, 0, 0, 987000), + datetime(2022, 1, 8, 0, 0, 0, 987000), + ], + } + ), + ) + + +def test_assign() -> None: + # check if can assign in case of a single column + df = pl.DataFrame({"a": [1, 2, 3]}) + # test if we can assign in case of single column + df = df.with_columns(pl.col("a") * 2) + assert list(df["a"]) == [2, 4, 6] + + +def test_arg_sort_by(df: pl.DataFrame) -> None: + idx_df = df.select( + pl.arg_sort_by(["int_nulls", "floats"], descending=[False, True]).alias("idx") + ) + assert (idx_df["idx"] == [1, 0, 2]).all() + + idx_df = df.select( + pl.arg_sort_by(["int_nulls", "floats"], descending=False).alias("idx") + ) + assert (idx_df["idx"] == [1, 0, 2]).all() + + df = pl.DataFrame({"x": [0, 0, 0, 1, 1, 2], "y": [9, 9, 8, 7, 6, 6]}) + for expr, expected in ( + (pl.arg_sort_by(["x", "y"]), [2, 0, 1, 4, 3, 5]), + (pl.arg_sort_by(["x", "y"], descending=[True, True]), [5, 3, 4, 0, 1, 2]), + (pl.arg_sort_by(["x", "y"], descending=[True, False]), [5, 4, 3, 2, 0, 1]), + (pl.arg_sort_by(["x", "y"], descending=[False, True]), [0, 1, 2, 3, 4, 5]), + ): + assert (df.select(expr.alias("idx"))["idx"] == expected).all() + + +def test_literal_series() -> None: + df = pl.DataFrame( + { + "a": np.array([21.7, 21.8, 21], dtype=np.float32), + "b": np.array([1, 3, 2], dtype=np.int8), + "c": ["reg1", "reg2", "reg3"], + "d": np.array( + [datetime(2022, 8, 16), datetime(2022, 8, 17), datetime(2022, 8, 18)], + dtype=" None: + df = pl.DataFrame( + { + "foo": [1, 2, 3, 4, 5], + "bar": [6, 7, 8, 9, 10], + "ham": ["a", "b", "c", "d", "e"], + } + ) + expected = "foo,bar,ham\n1,6,a\n2,7,b\n3,8,c\n4,9,d\n5,10,e\n" + + # if no file argument is supplied, write_csv() will return the string + s = df.write_csv() + assert s == expected + + # otherwise it will write to the file/iobuffer + file = BytesIO() + df.write_csv(file) + file.seek(0) + s = file.read().decode("utf8") + assert s == expected + + +def test_from_generator_or_iterable() -> None: + # generator function + def gen(n: int, *, strkey: bool = True) -> Iterator[Any]: + for i in range(n): + yield (str(i) if strkey else i), 1 * i, 2**i, 3**i + + # iterable object + class Rows: + def __init__(self, n: int, *, strkey: bool = True) -> None: + self._n = n + self._strkey = strkey + + def __iter__(self) -> Iterator[Any]: + yield from gen(self._n, strkey=self._strkey) + + # check init from column-oriented generator + assert_frame_equal( + pl.DataFrame(data=gen(4, strkey=False), orient="col"), + pl.DataFrame( + data=[(0, 0, 1, 1), (1, 1, 2, 3), (2, 2, 4, 9), (3, 3, 8, 27)], orient="col" + ), + ) + # check init from row-oriented generators (more common) + expected = pl.DataFrame( + data=list(gen(4)), schema=["a", "b", "c", "d"], orient="row" + ) + for generated_frame in ( + pl.DataFrame(data=gen(4), schema=["a", "b", "c", "d"]), + pl.DataFrame(data=Rows(4), schema=["a", "b", "c", "d"]), + pl.DataFrame(data=(x for x in Rows(4)), schema=["a", "b", "c", "d"]), + ): + assert_frame_equal(expected, generated_frame) + assert generated_frame.schema == { + "a": pl.String, + "b": pl.Int64, + "c": pl.Int64, + "d": pl.Int64, + } + + # test 'iterable_to_pydf' directly to validate 'chunk_size' behaviour + cols = ["a", "b", ("c", pl.Int8), "d"] + + expected_data = [("0", 0, 1, 1), ("1", 1, 2, 3), ("2", 2, 4, 9), ("3", 3, 8, 27)] + expected_schema = [ + ("a", pl.String), + ("b", pl.Int64), + ("c", pl.Int8), + ("d", pl.Int64), + ] + + for params in ( + {"data": Rows(4)}, + {"data": gen(4), "chunk_size": 2}, + {"data": Rows(4), "chunk_size": 3}, + {"data": gen(4), "infer_schema_length": None}, + {"data": Rows(4), "infer_schema_length": 1}, + {"data": gen(4), "chunk_size": 2}, + {"data": Rows(4), "infer_schema_length": 5}, + {"data": gen(4), "infer_schema_length": 3, "chunk_size": 2}, + {"data": gen(4), "infer_schema_length": None, "chunk_size": 3}, + ): + d = iterable_to_pydf(schema=cols, **params) # type: ignore[arg-type] + assert expected_data == d.row_tuples() + assert expected_schema == list(zip(d.columns(), d.dtypes())) + + # ref: issue #6489 (initial chunk_size cannot be smaller than 'infer_schema_length') + df = pl.DataFrame( + data=iter(([{"col": None}] * 1000) + [{"col": ["a", "b", "c"]}]), + infer_schema_length=1001, + ) + assert df.schema == {"col": pl.List(pl.String)} + assert df[-2:]["col"].to_list() == [None, ["a", "b", "c"]] + + # empty iterator + assert_frame_equal( + pl.DataFrame(data=gen(0), schema=["a", "b", "c", "d"]), + pl.DataFrame(schema=["a", "b", "c", "d"]), + ) + + +def test_from_rows() -> None: + df = pl.from_records([[1, 2, "foo"], [2, 3, "bar"]], orient="row") + assert_frame_equal( + df, + pl.DataFrame( + {"column_0": [1, 2], "column_1": [2, 3], "column_2": ["foo", "bar"]} + ), + ) + df = pl.from_records( + [[1, datetime.fromtimestamp(100)], [2, datetime.fromtimestamp(2398754908)]], + schema_overrides={"column_0": pl.UInt32}, + orient="row", + ) + assert df.dtypes == [pl.UInt32, pl.Datetime] + + # auto-inference with same num rows/cols + data = [(1, 2, "foo"), (2, 3, "bar"), (3, 4, "baz")] + df = pl.from_records(data, orient="row") + assert data == df.rows() + + +def test_from_rows_of_dicts() -> None: + records = [ + {"id": 1, "value": 100, "_meta": "a"}, + {"id": 2, "value": 101, "_meta": "b"}, + ] + df_init: Callable[..., Any] + for df_init in (pl.from_dicts, pl.DataFrame): + df1 = df_init(records) + assert df1.rows() == [(1, 100, "a"), (2, 101, "b")] + + overrides = { + "id": pl.Int16, + "value": pl.Int32, + } + df2 = df_init(records, schema_overrides=overrides) + assert df2.rows() == [(1, 100, "a"), (2, 101, "b")] + assert df2.schema == {"id": pl.Int16, "value": pl.Int32, "_meta": pl.String} + + df3 = df_init(records, schema=overrides) + assert df3.rows() == [(1, 100), (2, 101)] + assert df3.schema == {"id": pl.Int16, "value": pl.Int32} + + +def test_from_records_with_schema_overrides_12032() -> None: + # the 'id' fields contains an int value that exceeds Int64 and doesn't have an exact + # Float64 representation; confirm that the override is applied *during* inference, + # not as a post-inference cast, so we maintain the accuracy of the original value. + rec = [ + {"id": 9187643043065364490, "x": 333, "y": None}, + {"id": 9223671840084328467, "x": 666.5, "y": 1698177261953686}, + {"id": 9187643043065364505, "x": 999, "y": 9223372036854775807}, + ] + df = pl.from_records(rec, schema_overrides={"x": pl.Float32, "id": pl.UInt64}) + assert df.schema == OrderedDict( + [ + ("id", pl.UInt64), + ("x", pl.Float32), + ("y", pl.Int64), + ] + ) + assert rec == df.rows(named=True) + + +def test_from_large_uint64_misc() -> None: + uint_data = [[9187643043065364490, 9223671840084328467, 9187643043065364505]] + + df = pl.DataFrame(uint_data, orient="col", schema_overrides={"column_0": pl.UInt64}) + assert df["column_0"].dtype == pl.UInt64 + assert df["column_0"].to_list() == uint_data[0] + + for overrides in ({}, {"column_1": pl.UInt64}): + df = pl.DataFrame( + uint_data, + orient="row", + schema_overrides=overrides, + ) + assert df.schema == OrderedDict( + [ + ("column_0", pl.Int64), + ("column_1", pl.Int128 if overrides == {} else pl.UInt64), + ("column_2", pl.Int64), + ] + ) + assert df.row(0) == tuple(uint_data[0]) + + +def test_repeat_by_unequal_lengths_panic() -> None: + df = pl.DataFrame( + { + "a": ["x", "y", "z"], + } + ) + with pytest.raises(ShapeError): + df.select(pl.col("a").repeat_by(pl.Series([2, 2]))) + + +@pytest.mark.parametrize( + ("value", "values_expect"), + [ + (1.2, [[1.2], [1.2, 1.2], [1.2, 1.2, 1.2]]), + (True, [[True], [True, True], [True, True, True]]), + ("x", [["x"], ["x", "x"], ["x", "x", "x"]]), + (b"a", [[b"a"], [b"a", b"a"], [b"a", b"a", b"a"]]), + ], +) +def test_repeat_by_broadcast_left( + value: float | bool | str, values_expect: list[list[float | bool | str]] +) -> None: + df = pl.DataFrame( + { + "n": [1, 2, 3], + } + ) + expected = pl.DataFrame({"values": values_expect}) + result = df.select(pl.lit(value).repeat_by(pl.col("n")).alias("values")) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("a", "a_expected"), + [ + ([1.2, 2.2, 3.3], [[1.2, 1.2, 1.2], [2.2, 2.2, 2.2], [3.3, 3.3, 3.3]]), + ([True, False], [[True, True, True], [False, False, False]]), + (["x", "y", "z"], [["x", "x", "x"], ["y", "y", "y"], ["z", "z", "z"]]), + ( + [b"a", b"b", b"c"], + [[b"a", b"a", b"a"], [b"b", b"b", b"b"], [b"c", b"c", b"c"]], + ), + ], +) +def test_repeat_by_broadcast_right( + a: list[float | bool | str], a_expected: list[list[float | bool | str]] +) -> None: + df = pl.DataFrame( + { + "a": a, + } + ) + expected = pl.DataFrame({"a": a_expected}) + result = df.select(pl.col("a").repeat_by(3)) + assert_frame_equal(result, expected) + result = df.select(pl.col("a").repeat_by(pl.lit(3))) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("a", "a_expected"), + [ + (["foo", "bar"], [["foo", "foo"], ["bar", "bar", "bar"]]), + ([1, 2], [[1, 1], [2, 2, 2]]), + ([True, False], [[True, True], [False, False, False]]), + ( + [b"a", b"b"], + [[b"a", b"a"], [b"b", b"b", b"b"]], + ), + ], +) +def test_repeat_by( + a: list[float | bool | str], a_expected: list[list[float | bool | str]] +) -> None: + df = pl.DataFrame({"a": a, "n": [2, 3]}) + expected = pl.DataFrame({"a": a_expected}) + result = df.select(pl.col("a").repeat_by("n")) + assert_frame_equal(result, expected) + + +def test_join_dates() -> None: + dts_in = pl.datetime_range( + datetime(2021, 6, 24), + datetime(2021, 6, 24, 10, 0, 0), + interval=timedelta(hours=1), + closed="left", + eager=True, + ) + dts = ( + dts_in.cast(int) + .map_elements(lambda x: x + np.random.randint(1_000 * 60, 60_000 * 60)) + .cast(pl.Datetime) + ) + + # some df with sensor id, (randomish) datetime and some value + df = pl.DataFrame( + { + "sensor": ["a"] * 5 + ["b"] * 5, + "datetime": dts, + "value": [2, 3, 4, 1, 2, 3, 5, 1, 2, 3], + } + ) + out = df.join(df, on="datetime") + assert out.height == df.height + + +def test_asof_cross_join() -> None: + left = pl.DataFrame({"a": [-10, 5, 10], "left_val": ["a", "b", "c"]}).with_columns( + pl.col("a").set_sorted() + ) + right = pl.DataFrame( + {"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]} + ).with_columns(pl.col("a").set_sorted()) + + # only test dispatch of asof join + out = left.join_asof(right, on="a") + assert out.shape == (3, 3) + + left.lazy().join_asof(right.lazy(), on="a").collect() + assert out.shape == (3, 3) + + # only test dispatch of cross join + out = left.join(right, how="cross") + assert out.shape == (15, 4) + + left.lazy().join(right.lazy(), how="cross").collect() + assert out.shape == (15, 4) + + +def test_str_concat() -> None: + df = pl.DataFrame( + { + "nrs": [1, 2, 3, 4], + "name": ["ham", "spam", "foo", None], + } + ) + out = df.with_columns((pl.lit("Dr. ") + pl.col("name")).alias("graduated_name")) + assert out["graduated_name"][0] == "Dr. ham" + assert out["graduated_name"][1] == "Dr. spam" + + +def test_dot_product() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]}) + + assert df["a"].dot(df["b"]) == 20 + assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20 + + result = pl.Series([1, 2, 3]) @ pl.Series([4, 5, 6]) + assert isinstance(result, int) + assert result == 32 + + result = pl.Series([1, 2, 3]) @ pl.Series([4.0, 5.0, 6.0]) + assert isinstance(result, float) + assert result == 32.0 + + result = pl.Series([1.0, 2.0, 3.0]) @ pl.Series([4.0, 5.0, 6.0]) + assert isinstance(result, float) + assert result == 32.0 + + with pytest.raises( + InvalidOperationError, match="`dot` operation not supported for dtype `bool`" + ): + pl.Series([True, False, False, True]) @ pl.Series([4, 5, 6, 7]) + + with pytest.raises( + InvalidOperationError, match="`dot` operation not supported for dtype `str`" + ): + pl.Series([1, 2, 3, 4]) @ pl.Series(["True", "False", "False", "True"]) + + +def test_hash_rows() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]}) + assert df.hash_rows().dtype == pl.UInt64 + assert df["a"].hash().dtype == pl.UInt64 + assert df.select([pl.col("a").hash().alias("foo")])["foo"].dtype == pl.UInt64 + + +def test_reproducible_hash_with_seeds() -> None: + """ + Test the reproducibility of DataFrame.hash_rows, Series.hash, and Expr.hash. + + cf. issue #3966, hashes must always be reproducible across sessions when using + the same seeds. + """ + df = pl.DataFrame({"s": [1234, None, 5678]}) + seeds = (11, 22, 33, 44) + expected = pl.Series( + "s", + [10832467230526607564, 3044502640115867787, 17228373233104406792], + dtype=pl.UInt64, + ) + result = df.hash_rows(*seeds) + assert_series_equal(expected, result, check_names=False, check_exact=True) + result = df["s"].hash(*seeds) + assert_series_equal(expected, result, check_names=False, check_exact=True) + result = df.select([pl.col("s").hash(*seeds)])["s"] + assert_series_equal(expected, result, check_names=False, check_exact=True) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "e", + [ + pl.int_range(1_000_000), + # Test code path for null_count > 0 + pl.when(pl.int_range(1_000_000) != 0).then(pl.int_range(1_000_000)), + ], +) +def test_hash_collision_multiple_columns_equal_values_15390(e: pl.Expr) -> None: + df = pl.select(e.alias("a")) + + for n_columns in (1, 2, 3, 4): + s = df.select(pl.col("a").alias(f"x{i}") for i in range(n_columns)).hash_rows() + + vc = s.sort().value_counts(sort=True) + max_bucket_size = vc["count"][0] + + assert max_bucket_size == 1 + + +@pytest.mark.may_fail_auto_streaming # Python objects not yet supported in row encoding +def test_hashing_on_python_objects() -> None: + # see if we can do a group_by, drop_duplicates on a DataFrame with objects. + # this requires that the hashing and aggregations are done on python objects + + df = pl.DataFrame({"a": [1, 1, 3, 4], "b": [1, 1, 2, 2]}) + + class Foo: + def __hash__(self) -> int: + return 0 + + def __eq__(self, other: object) -> bool: + return True + + df = df.with_columns(pl.col("a").map_elements(lambda x: Foo()).alias("foo")) + assert df.group_by(["foo"]).first().shape == (1, 3) + assert df.unique().shape == (3, 3) + + +def test_unique_unit_rows() -> None: + df = pl.DataFrame({"a": [1], "b": [None]}, schema={"a": pl.Int64, "b": pl.Float32}) + + # 'unique' one-row frame should be equal to the original frame + assert_frame_equal(df, df.unique(subset="a")) + for col in df.columns: + assert df.n_unique(subset=[col]) == 1 + + +def test_panic() -> None: + # may contain some tests that yielded a panic in polars or pl_arrow + # https://github.com/pola-rs/polars/issues/1110 + a = pl.DataFrame( + { + "col1": ["a"] * 500 + ["b"] * 500, + } + ) + a.filter(pl.col("col1") != "b") + + +def test_horizontal_agg() -> None: + df = pl.DataFrame({"a": [1, None, 3], "b": [1, 2, 3]}) + + assert_series_equal(df.sum_horizontal(), pl.Series("sum", [2, 2, 6])) + assert_series_equal( + df.sum_horizontal(ignore_nulls=False), pl.Series("sum", [2, None, 6]) + ) + assert_series_equal( + df.mean_horizontal(ignore_nulls=False), pl.Series("mean", [1.0, None, 3.0]) + ) + + +def test_slicing() -> None: + # https://github.com/pola-rs/polars/issues/1322 + n = 20 + + df = pl.DataFrame( + { + "d": ["u", "u", "d", "c", "c", "d", "d"] * n, + "v1": [None, "help", None, None, None, None, None] * n, + } + ) + + assert (df.filter(pl.col("d") != "d").select([pl.col("v1").unique()])).shape == ( + 2, + 1, + ) + + +def test_group_by_cat_list() -> None: + grouped = ( + pl.DataFrame( + [ + pl.Series("str_column", ["a", "b", "b", "a", "b"]), + pl.Series("int_column", [1, 1, 2, 2, 3]), + ] + ) + .with_columns(pl.col("str_column").cast(pl.Categorical).alias("cat_column")) + .group_by("int_column", maintain_order=True) + .agg([pl.col("cat_column")])["cat_column"] + ) + + out = grouped.explode() + assert out.dtype == pl.Categorical + assert out[0] == "a" + + +def test_group_by_agg_n_unique_floats() -> None: + # tests proper dispatch + df = pl.DataFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 2.0]}) + + for dtype in [pl.Float32, pl.Float64]: + out = df.group_by("a", maintain_order=True).agg( + [pl.col("b").cast(dtype).n_unique()] + ) + assert out["b"].to_list() == [2, 1] + + +def test_group_by_agg_n_unique_empty_group_idx_path() -> None: + df = pl.DataFrame( + { + "key": [1, 1, 1, 2, 2, 2], + "value": [1, 2, 3, 4, 5, 6], + "filt": [True, True, True, False, False, False], + } + ) + out = df.group_by("key", maintain_order=True).agg( + pl.col("value").filter("filt").n_unique().alias("n_unique") + ) + expected = pl.DataFrame( + { + "key": [1, 2], + "n_unique": pl.Series([3, 0], dtype=pl.UInt32), + } + ) + assert_frame_equal(out, expected) + + +def test_group_by_agg_n_unique_empty_group_slice_path() -> None: + df = pl.DataFrame( + { + "key": [1, 1, 1, 2, 2, 2], + "value": [1, 2, 3, 4, 5, 6], + "filt": [False, False, False, False, False, False], + } + ) + out = df.group_by("key", maintain_order=True).agg( + pl.col("value").filter("filt").n_unique().alias("n_unique") + ) + expected = pl.DataFrame( + { + "key": [1, 2], + "n_unique": pl.Series([0, 0], dtype=pl.UInt32), + } + ) + assert_frame_equal(out, expected) + + +def test_select_by_dtype(df: pl.DataFrame) -> None: + out = df.select(pl.col(pl.String)) + assert out.columns == ["strings", "strings_nulls"] + out = df.select(pl.col([pl.String, pl.Boolean])) + assert out.columns == ["bools", "bools_nulls", "strings", "strings_nulls"] + out = df.select(pl.col(INTEGER_DTYPES)) + assert out.columns == ["int", "int_nulls"] + + with pl.Config() as cfg: + cfg.set_auto_structify(True) + out = df.select(ints=pl.col(INTEGER_DTYPES)) + assert out.schema == { + "ints": pl.Struct( + [pl.Field("int", pl.Int64), pl.Field("int_nulls", pl.Int64)] + ) + } + + +def test_with_row_index() -> None: + df = pl.DataFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 2.0]}) + + out = df.with_row_index() + assert out["index"].to_list() == [0, 1, 2] + + out = df.lazy().with_row_index().collect() + assert out["index"].to_list() == [0, 1, 2] + + +def test_with_row_index_bad_offset() -> None: + df = pl.DataFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 2.0]}) + + with pytest.raises(ValueError, match="cannot be negative"): + df.with_row_index(offset=-1) + with pytest.raises( + ValueError, match="cannot be greater than the maximum index value" + ): + df.with_row_index(offset=2**32) + + +def test_with_row_index_bad_offset_lazy() -> None: + lf = pl.LazyFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 2.0]}) + + with pytest.raises(ValueError, match="cannot be negative"): + lf.with_row_index(offset=-1) + with pytest.raises( + ValueError, match="cannot be greater than the maximum index value" + ): + lf.with_row_index(offset=2**32) + + +def test_with_row_count_deprecated() -> None: + df = pl.DataFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 2.0]}) + + with pytest.deprecated_call(): + out = df.with_row_count() + assert out["row_nr"].to_list() == [0, 1, 2] + + with pytest.deprecated_call(): + out = df.lazy().with_row_count().collect() + assert out["row_nr"].to_list() == [0, 1, 2] + + +def test_filter_with_all_expansion() -> None: + df = pl.DataFrame( + { + "b": [1, 2, None], + "c": [1, 2, None], + "a": [None, None, None], + } + ) + out = df.filter(~pl.fold(True, lambda acc, s: acc & s.is_null(), pl.all())) + assert out.shape == (2, 3) + + +# TODO: investigate this discrepancy in auto streaming +@pytest.mark.may_fail_auto_streaming +def test_extension() -> None: + class Foo: + def __init__(self, value: Any) -> None: + self.value = value + + def __repr__(self) -> str: + return f"foo({self.value})" + + foos = [Foo(1), Foo(2), Foo(3)] + + # foos and sys.getrefcount both have a reference. + base_count = 2 + + # We compute the refcount on a separate line otherwise pytest's assert magic + # might add reference counts. + rc = sys.getrefcount(foos[0]) + assert rc == base_count + + df = pl.DataFrame({"groups": [1, 1, 2], "a": foos}) + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 1 + del df + rc = sys.getrefcount(foos[0]) + assert rc == base_count + + df = pl.DataFrame({"groups": [1, 1, 2], "a": foos}) + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 1 + + out = df.group_by("groups", maintain_order=True).agg(pl.col("a").alias("a")) + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 2 + s = out["a"].list.explode() + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 3 + del s + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 2 + + assert out["a"].list.explode().to_list() == foos + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 2 + del out + rc = sys.getrefcount(foos[0]) + assert rc == base_count + 1 + del df + rc = sys.getrefcount(foos[0]) + assert rc == base_count + + +@pytest.mark.parametrize("name", [None, "n", ""]) +def test_group_by_order_dispatch(name: str | None) -> None: + df = pl.DataFrame({"x": list("bab"), "y": range(3)}) + lf = df.lazy() + + result = df.group_by("x", maintain_order=True).len(name=name) + lazy_result = lf.group_by("x").len(name=name).sort(by="x", descending=True) + + name = "len" if name is None else name + expected = pl.DataFrame( + data={"x": ["b", "a"], name: [2, 1]}, + schema_overrides={name: pl.UInt32}, + ) + assert_frame_equal(result, expected) + assert_frame_equal(lazy_result.collect(), expected) + + result = df.group_by("x", maintain_order=True).all() + expected = pl.DataFrame({"x": ["b", "a"], "y": [[0, 2], [1]]}) + assert_frame_equal(result, expected) + + +def test_partitioned_group_by_order() -> None: + # check if group ordering is maintained. + # we only have 30 groups, so this triggers a partitioned group by + df = pl.DataFrame({"x": [chr(v) for v in range(33, 63)], "y": range(30)}) + out = df.group_by("x", maintain_order=True).agg(pl.all().implode()) + assert_series_equal(out["x"], df["x"]) + + +def test_schema() -> None: + df = pl.DataFrame( + {"foo": [1, 2, 3], "bar": [6.0, 7.0, 8.0], "ham": ["a", "b", "c"]} + ) + expected = {"foo": pl.Int64, "bar": pl.Float64, "ham": pl.String} + assert df.schema == expected + + +def test_schema_equality() -> None: + lf = pl.LazyFrame({"foo": [1, 2, 3], "bar": [6.0, 7.0, 8.0]}) + lf_rev = lf.select("bar", "foo") + + assert lf.collect_schema() != lf_rev.collect_schema() + assert lf.collect().schema != lf_rev.collect().schema + + +def test_df_schema_unique() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + with pytest.raises(DuplicateError): + df.columns = ["a", "a"] + + with pytest.raises(DuplicateError): + df.rename({"b": "a"}) + + +def test_empty_projection() -> None: + empty_df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}).select([]) + assert empty_df.rows() == [] + assert empty_df.schema == {} + assert empty_df.shape == (0, 0) + + +@pytest.mark.may_fail_auto_streaming +def test_fill_null() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, None]}) + assert_frame_equal(df.fill_null(4), pl.DataFrame({"a": [1, 2], "b": [3, 4]})) + assert_frame_equal( + df.fill_null(strategy="max"), pl.DataFrame({"a": [1, 2], "b": [3, 3]}) + ) + + # string and list data + # string goes via binary + df = pl.DataFrame( + { + "c": [ + ["Apple", "Orange"], + ["Apple", "Orange"], + None, + ["Carrot"], + None, + None, + ], + "b": ["Apple", "Orange", None, "Carrot", None, None], + } + ) + + assert df.select( + pl.all().fill_null(strategy="forward").name.suffix("_forward"), + pl.all().fill_null(strategy="backward").name.suffix("_backward"), + ).to_dict(as_series=False) == { + "c_forward": [ + ["Apple", "Orange"], + ["Apple", "Orange"], + ["Apple", "Orange"], + ["Carrot"], + ["Carrot"], + ["Carrot"], + ], + "b_forward": ["Apple", "Orange", "Orange", "Carrot", "Carrot", "Carrot"], + "c_backward": [ + ["Apple", "Orange"], + ["Apple", "Orange"], + ["Carrot"], + ["Carrot"], + None, + None, + ], + "b_backward": ["Apple", "Orange", "Carrot", "Carrot", None, None], + } + # categoricals + df = pl.DataFrame(pl.Series("cat", ["a", None], dtype=pl.Categorical)) + s = df.select(pl.col("cat").fill_null(strategy="forward"))["cat"] + assert s.dtype == pl.Categorical + assert s.to_list() == ["a", "a"] + + +def test_fill_nan() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, float("nan")]}) + assert_frame_equal( + df.fill_nan(4), + pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}), + ) + assert_frame_equal( + df.fill_nan(None), + pl.DataFrame({"a": [1, 2], "b": [3.0, None]}), + ) + assert df["b"].fill_nan(5.0).to_list() == [3.0, 5.0] + df = pl.DataFrame( + { + "a": [1.0, np.nan, 3.0], + "b": [datetime(1, 2, 2), datetime(2, 2, 2), datetime(3, 2, 2)], + } + ) + assert df.fill_nan(2.0).dtypes == [pl.Float64, pl.Datetime] + + +# +def test_forward_fill() -> None: + df = pl.DataFrame({"a": [1.0, None, 3.0]}) + fill = df.select(pl.col("a").forward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [1, 1, 3]).cast(pl.Float64)) + + df = pl.DataFrame({"a": [None, 1, None]}) + fill = df.select(pl.col("a").forward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [None, 1, 1]).cast(pl.Int64)) + + +def test_backward_fill() -> None: + df = pl.DataFrame({"a": [1.0, None, 3.0]}) + fill = df.select(pl.col("a").backward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [1, 3, 3]).cast(pl.Float64)) + + df = pl.DataFrame({"a": [None, 1, None]}) + fill = df.select(pl.col("a").backward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [1, 1, None]).cast(pl.Int64)) + + +def test_shrink_to_fit() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]}) + + assert df.shrink_to_fit(in_place=True) is df + assert df.shrink_to_fit(in_place=False) is not df + assert_frame_equal(df.shrink_to_fit(in_place=False), df) + + +def test_add_string() -> None: + df = pl.DataFrame({"a": ["hi", "there"], "b": ["hello", "world"]}) + expected = pl.DataFrame( + {"a": ["hi hello", "there hello"], "b": ["hello hello", "world hello"]} + ) + assert_frame_equal((df + " hello"), expected) + + expected = pl.DataFrame( + {"a": ["hello hi", "hello there"], "b": ["hello hello", "hello world"]} + ) + assert_frame_equal(("hello " + df), expected) + + +def test_df_broadcast() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}, schema_overrides={"a": pl.UInt8}) + out = df.with_columns(pl.lit(pl.Series("s", [[1, 2]])).first()) + assert out.shape == (3, 2) + assert out.schema == {"a": pl.UInt8, "s": pl.List(pl.Int64)} + assert out.rows() == [(1, [1, 2]), (2, [1, 2]), (3, [1, 2])] + + +def test_product() -> None: + df = pl.DataFrame( + { + "int": [1, 2, 3], + "flt": [-1.0, 12.0, 9.0], + "bool_0": [True, False, True], + "bool_1": [True, True, True], + "str": ["a", "b", "c"], + }, + schema_overrides={ + "int": pl.UInt16, + "flt": pl.Float32, + }, + ) + out = df.product() + expected = pl.DataFrame( + {"int": [6], "flt": [-108.0], "bool_0": [0], "bool_1": [1], "str": [None]} + ) + assert_frame_not_equal(out, expected, check_dtypes=True) + assert_frame_equal(out, expected, check_dtypes=False) + + +def test_first_last_nth_expressions(fruits_cars: pl.DataFrame) -> None: + df = fruits_cars + out = df.select(pl.first()) + assert out.columns == ["A"] + + out = df.select(pl.last()) + assert out.columns == ["cars"] + + out = df.select(pl.nth(0)) + assert out.columns == ["A"] + + out = df.select(pl.nth(1)) + assert out.columns == ["fruits"] + + out = df.select(pl.nth(-2)) + assert out.columns == ["B"] + + +def test_is_between(fruits_cars: pl.DataFrame) -> None: + result = fruits_cars.select(pl.col("A").is_between(2, 4)).to_series() + assert_series_equal(result, pl.Series("A", [False, True, True, True, False])) + + result = fruits_cars.select(pl.col("A").is_between(2, 4, closed="none")).to_series() + assert_series_equal(result, pl.Series("A", [False, False, True, False, False])) + + result = fruits_cars.select(pl.col("A").is_between(2, 4, closed="both")).to_series() + assert_series_equal(result, pl.Series("A", [False, True, True, True, False])) + + result = fruits_cars.select( + pl.col("A").is_between(2, 4, closed="right") + ).to_series() + assert_series_equal(result, pl.Series("A", [False, False, True, True, False])) + + result = fruits_cars.select(pl.col("A").is_between(2, 4, closed="left")).to_series() + assert_series_equal(result, pl.Series("A", [False, True, True, False, False])) + + +def test_is_between_data_types() -> None: + df = pl.DataFrame( + { + "flt": [1.4, 1.2, 2.5], + "int": [2, 3, 4], + "str": ["xyz", "str", "abc"], + "date": [date(2020, 1, 1), date(2020, 2, 2), date(2020, 3, 3)], + "datetime": [ + datetime(2020, 1, 1, 0, 0, 0), + datetime(2020, 1, 1, 10, 0, 0), + datetime(2020, 1, 1, 12, 0, 0), + ], + "tm": [time(10, 30), time(0, 45), time(15, 15)], + } + ) + + # on purpose, for float and int, we pass in a mixture of bound data types + assert_series_equal( + df.select(pl.col("flt").is_between(1, 2.3))[:, 0], + pl.Series("flt", [True, True, False]), + ) + assert_series_equal( + df.select(pl.col("int").is_between(1.5, 3))[:, 0], + pl.Series("int", [True, True, False]), + ) + assert_series_equal( + df.select(pl.col("date").is_between(date(2019, 1, 1), date(2020, 2, 5)))[:, 0], + pl.Series("date", [True, True, False]), + ) + assert_series_equal( + df.select( + pl.col("datetime").is_between( + datetime(2020, 1, 1, 5, 0, 0), datetime(2020, 1, 1, 11, 0, 0) + ) + )[:, 0], + pl.Series("datetime", [False, True, False]), + ) + assert_series_equal( + df.select( + pl.col("str").is_between(pl.lit("str"), pl.lit("zzz"), closed="left") + )[:, 0], + pl.Series("str", [True, True, False]), + ) + assert_series_equal( + df.select( + pl.col("tm") + .is_between(time(0, 45), time(10, 30), closed="right") + .alias("tm_between") + )[:, 0], + pl.Series("tm_between", [True, False, False]), + ) + + +def test_empty_is_in() -> None: + df_empty_isin = pl.DataFrame({"foo": ["a", "b", "c", "d"]}).filter( + pl.col("foo").is_in([]) + ) + assert df_empty_isin.shape == (0, 1) + assert df_empty_isin.rows() == [] + assert df_empty_isin.schema == {"foo": pl.String} + + +def test_group_by_slice_expression_args() -> None: + df = pl.DataFrame({"groups": ["a"] * 10 + ["b"] * 20, "vals": range(30)}) + + out = ( + df.group_by("groups", maintain_order=True) + .agg([pl.col("vals").slice((pl.len() * 0.1).cast(int), (pl.len() // 5))]) + .explode("vals") + ) + + expected = pl.DataFrame( + {"groups": ["a", "a", "b", "b", "b", "b"], "vals": [1, 2, 12, 13, 14, 15]} + ) + assert_frame_equal(out, expected) + + +def test_join_suffixes() -> None: + df_a = pl.DataFrame({"A": [1], "B": [1]}) + df_b = pl.DataFrame({"A": [1], "B": [1]}) + + join_strategies: list[JoinStrategy] = ["left", "inner", "full", "cross"] + for how in join_strategies: + # no need for an assert, we error if wrong + df_a.join(df_b, on="A" if how != "cross" else None, suffix="_y", how=how)["B_y"] + + df_a.join_asof(df_b, on=pl.col("A").set_sorted(), suffix="_y")["B_y"] + + +def test_explode_empty() -> None: + df = ( + pl.DataFrame({"x": ["a", "a", "b", "b"], "y": [1, 1, 2, 2]}) + .group_by("x", maintain_order=True) + .agg(pl.col("y").gather([])) + ) + assert df.explode("y").to_dict(as_series=False) == { + "x": ["a", "b"], + "y": [None, None], + } + + df = pl.DataFrame({"x": ["1", "2", "4"], "y": [["a", "b", "c"], ["d"], []]}) + assert_frame_equal( + df.explode("y"), + pl.DataFrame({"x": ["1", "1", "1", "2", "4"], "y": ["a", "b", "c", "d", None]}), + ) + + df = pl.DataFrame( + { + "letters": ["a"], + "numbers": [[]], + } + ) + assert df.explode("numbers").to_dict(as_series=False) == { + "letters": ["a"], + "numbers": [None], + } + + +def test_asof_by_multiple_keys() -> None: + lhs = pl.DataFrame( + { + "a": [-20, -19, 8, 12, 14], + "by": [1, 1, 2, 2, 2], + "by2": [1, 1, 2, 2, 2], + } + ) + + rhs = pl.DataFrame( + { + "a": [-19, -15, 3, 5, 13], + "by": [1, 1, 2, 2, 2], + "by2": [1, 1, 2, 2, 2], + } + ) + + result = lhs.join_asof( + rhs, on=pl.col("a").set_sorted(), by=["by", "by2"], strategy="backward" + ).select(["a", "by"]) + expected = pl.DataFrame({"a": [-20, -19, 8, 12, 14], "by": [1, 1, 2, 2, 2]}) + assert_frame_equal(result, expected) + + +def test_list_of_list_of_struct() -> None: + expected = [{"list_of_list_of_struct": [[{"a": 1}, {"a": 2}]]}] + pa_df = pa.Table.from_pylist(expected) + + df = pl.from_arrow(pa_df) + assert df.rows() == [([[{"a": 1}, {"a": 2}]],)] # type: ignore[union-attr] + assert df.to_dicts() == expected # type: ignore[union-attr] + + df = pl.from_arrow(pa_df[:0]) + assert df.to_dicts() == [] # type: ignore[union-attr] + + +def test_fill_null_limits() -> None: + assert pl.DataFrame( + { + "a": [1, None, None, None, 5, 6, None, None, None, 10], + "b": ["a", None, None, None, "b", "c", None, None, None, "d"], + "c": [True, None, None, None, False, True, None, None, None, False], + } + ).select( + pl.all().fill_null(strategy="forward", limit=2), + pl.all().fill_null(strategy="backward", limit=2).name.suffix("_backward"), + ).to_dict(as_series=False) == { + "a": [1, 1, 1, None, 5, 6, 6, 6, None, 10], + "b": ["a", "a", "a", None, "b", "c", "c", "c", None, "d"], + "c": [True, True, True, None, False, True, True, True, None, False], + "a_backward": [1, None, 5, 5, 5, 6, None, 10, 10, 10], + "b_backward": ["a", None, "b", "b", "b", "c", None, "d", "d", "d"], + "c_backward": [ + True, + None, + False, + False, + False, + True, + None, + False, + False, + False, + ], + } + + +def test_lower_bound_upper_bound(fruits_cars: pl.DataFrame) -> None: + res_expr = fruits_cars.select(pl.col("A").lower_bound()) + assert res_expr.item() == -9223372036854775808 + + res_expr = fruits_cars.select(pl.col("B").upper_bound()) + assert res_expr.item() == 9223372036854775807 + + with pytest.raises(ComputeError): + fruits_cars.select(pl.col("fruits").upper_bound()) + + +def test_selection_misc() -> None: + df = pl.DataFrame({"x": "abc"}, schema={"x": pl.String}) + + # literal values (as scalar/list) + for zero in (0, [0]): + assert df.select(zero)["literal"].to_list() == [0] + assert df.select(literal=0)["literal"].to_list() == [0] + + # expect string values to be interpreted as cols + for x in ("x", ["x"], pl.col("x")): + assert df.select(x).rows() == [("abc",)] + + # string col + lit + assert df.with_columns(["x", 0]).to_dicts() == [{"x": "abc", "literal": 0}] + + +def test_selection_regex_and_multicol() -> None: + test_df = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [5, 6, 7, 8], + "c": [9, 10, 11, 12], + "foo": [13, 14, 15, 16], + }, + schema_overrides={"foo": pl.UInt8}, + ) + + # Selection only + test_df.select( + pl.col(["a", "b", "c"]).name.suffix("_list"), + pl.all().exclude("foo").name.suffix("_wild"), + pl.col("^\\w$").name.suffix("_regex"), + ) + + # Multi * Single + assert test_df.select(pl.col(["a", "b", "c"]) * pl.col("foo")).to_dict( + as_series=False + ) == { + "a": [13, 28, 45, 64], + "b": [65, 84, 105, 128], + "c": [117, 140, 165, 192], + } + assert test_df.select(pl.all().exclude("foo") * pl.col("foo")).to_dict( + as_series=False + ) == { + "a": [13, 28, 45, 64], + "b": [65, 84, 105, 128], + "c": [117, 140, 165, 192], + } + + assert test_df.select(pl.col("^\\w$") * pl.col("foo")).to_dict(as_series=False) == { + "a": [13, 28, 45, 64], + "b": [65, 84, 105, 128], + "c": [117, 140, 165, 192], + } + + # Multi * Multi + result = test_df.select(pl.col(["a", "b", "c"]) * pl.col(["a", "b", "c"])) + expected = {"a": [1, 4, 9, 16], "b": [25, 36, 49, 64], "c": [81, 100, 121, 144]} + + assert result.to_dict(as_series=False) == expected + assert test_df.select(pl.exclude("foo") * pl.exclude("foo")).to_dict( + as_series=False + ) == { + "a": [1, 4, 9, 16], + "b": [25, 36, 49, 64], + "c": [81, 100, 121, 144], + } + assert test_df.select(pl.col("^\\w$") * pl.col("^\\w$")).to_dict( + as_series=False + ) == { + "a": [1, 4, 9, 16], + "b": [25, 36, 49, 64], + "c": [81, 100, 121, 144], + } + + # kwargs + with pl.Config() as cfg: + cfg.set_auto_structify(True) + + df = test_df.select( + pl.col("^\\w$").alias("re"), + odd=(pl.col(INTEGER_DTYPES) % 2).name.suffix("_is_odd"), + maxes=pl.all().max().name.suffix("_max"), + ).head(2) + # ┌───────────┬───────────┬─────────────┐ + # │ re ┆ odd ┆ maxes │ + # │ --- ┆ --- ┆ --- │ + # │ struct[3] ┆ struct[4] ┆ struct[4] │ + # ╞═══════════╪═══════════╪═════════════╡ + # │ {1,5,9} ┆ {1,1,1,1} ┆ {4,8,12,16} │ + # │ {2,6,10} ┆ {0,0,0,0} ┆ {4,8,12,16} │ + # └───────────┴───────────┴─────────────┘ + assert df.rows() == [ + ( + {"a": 1, "b": 5, "c": 9}, + {"a_is_odd": 1, "b_is_odd": 1, "c_is_odd": 1, "foo_is_odd": 1}, + {"a_max": 4, "b_max": 8, "c_max": 12, "foo_max": 16}, + ), + ( + {"a": 2, "b": 6, "c": 10}, + {"a_is_odd": 0, "b_is_odd": 0, "c_is_odd": 0, "foo_is_odd": 0}, + {"a_max": 4, "b_max": 8, "c_max": 12, "foo_max": 16}, + ), + ] + + +@pytest.mark.parametrize("subset", ["a", cs.starts_with("x", "a")]) +@pytest.mark.may_fail_auto_streaming # Flaky in CI, see https://github.com/pola-rs/polars/issues/20943 +def test_unique_on_sorted(subset: Any) -> None: + df = pl.DataFrame(data={"a": [1, 1, 3], "b": [1, 2, 3]}) + + result = df.with_columns([pl.col("a").set_sorted()]).unique( + subset=subset, + keep="last", + ) + + expected = pl.DataFrame({"a": [1, 3], "b": [2, 3]}) + assert_frame_equal(result, expected) + + +def test_len_compute(df: pl.DataFrame) -> None: + df = df.with_columns(pl.struct(["list_bool", "cat"]).alias("struct")) + filtered = df.filter(pl.col("bools")) + for col in filtered.columns: + assert len(filtered[col]) == 1 + + taken = df[[1, 2], :] + for col in taken.columns: + assert len(taken[col]) == 2 + + +def test_filter_sequence() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + assert df.filter([True, False, True])["a"].to_list() == [1, 3] + assert df.filter(np.array([True, False, True]))["a"].to_list() == [1, 3] + + +def test_filter_multiple_predicates() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 1, 2, 2], + "b": [1, 1, 2, 2, 2], + "c": [1, 1, 2, 3, 4], + } + ) + + # multiple predicates + expected = pl.DataFrame({"a": [1, 1, 1], "b": [1, 1, 2], "c": [1, 1, 2]}) + for out in ( + df.filter(pl.col("a") == 1, pl.col("b") <= 2), # positional/splat + df.filter([pl.col("a") == 1, pl.col("b") <= 2]), # as list + ): + assert_frame_equal(out, expected) + + # multiple kwargs + assert_frame_equal( + df.filter(a=1, b=2), + pl.DataFrame({"a": [1], "b": [2], "c": [2]}), + ) + + # both positional and keyword args + assert_frame_equal( + pl.DataFrame({"a": [2], "b": [2], "c": [3]}), + df.filter(pl.col("c") < 4, a=2, b=2), + ) + + # boolean mask + out = df.filter([True, False, False, False, True]) + expected = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": [1, 4]}) + assert_frame_equal(out, expected) + + # multiple boolean masks + out = df.filter( + np.array([True, True, False, True, False]), + np.array([True, False, True, True, False]), + ) + expected = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": [1, 3]}) + assert_frame_equal(out, expected) + + +def test_indexing_set() -> None: + df = pl.DataFrame({"bool": [True, True], "str": ["N/A", "N/A"], "nr": [1, 2]}) + + df[0, "bool"] = False + df[0, "nr"] = 100 + df[0, "str"] = "foo" + + assert df.to_dict(as_series=False) == { + "bool": [False, True], + "str": ["foo", "N/A"], + "nr": [100, 2], + } + + +def test_set() -> None: + # Setting a dataframe using indices is deprecated. + # We keep these tests because we only generate a warning. + np.random.seed(1) + df = pl.DataFrame( + {"foo": np.random.rand(10), "bar": np.arange(10), "ham": ["h"] * 10} + ) + with pytest.raises( + TypeError, + match=r"DataFrame object does not support `Series` assignment by index" + r"\n\nUse `DataFrame.with_columns`.", + ): + df["new"] = np.random.rand(10) + + with pytest.raises( + TypeError, + match=r"not allowed to set DataFrame by boolean mask in the row position" + r"\n\nConsider using `DataFrame.with_columns`.", + ): + df[df["ham"] > 0.5, "ham"] = "a" + with pytest.raises( + TypeError, + match=r"not allowed to set DataFrame by boolean mask in the row position" + r"\n\nConsider using `DataFrame.with_columns`.", + ): + df[[True, False], "ham"] = "a" + + # set 2D + df = pl.DataFrame({"b": [0, 0]}) + df[["A", "B"]] = [[1, 2], [1, 2]] + + with pytest.raises(ValueError): + df[["C", "D"]] = 1 + with pytest.raises(ValueError): + df[["C", "D"]] = [1, 1] + with pytest.raises(ValueError): + df[["C", "D"]] = [[1, 2, 3], [1, 2, 3]] + + # set tuple + df = pl.DataFrame({"b": [0, 0]}) + df[0, "b"] = 1 + assert df[0, "b"] == 1 + + df[0, 0] = 2 + assert df[0, "b"] == 2 + + # row and col selection have to be int or str + with pytest.raises(TypeError): + df[:, [1]] = 1 # type: ignore[index] + with pytest.raises(TypeError): + df[True, :] = 1 # type: ignore[index] + + # needs to be a 2 element tuple + with pytest.raises(ValueError): + df[1, 2, 3] = 1 + + # we cannot index with any type, such as bool + with pytest.raises(TypeError): + df[True] = 1 # type: ignore[index] + + +def test_series_iter_over_frame() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 3, 4], "c": [3, 4, 5]}) + + expected = { + 0: pl.Series("a", [1, 2, 3]), + 1: pl.Series("b", [2, 3, 4]), + 2: pl.Series("c", [3, 4, 5]), + } + for idx, s in enumerate(df): + assert_series_equal(s, expected[idx]) + + expected = { + 0: pl.Series("c", [3, 4, 5]), + 1: pl.Series("b", [2, 3, 4]), + 2: pl.Series("a", [1, 2, 3]), + } + for idx, s in enumerate(reversed(df)): + assert_series_equal(s, expected[idx]) + + +def test_union_with_aliases_4770() -> None: + lf = pl.DataFrame( + { + "a": [1, None], + "b": [3, 4], + } + ).lazy() + + lf = pl.concat( + [ + lf.select([pl.col("a").alias("x")]), + lf.select([pl.col("b").alias("x")]), + ] + ).filter(pl.col("x").is_not_null()) + + assert lf.collect()["x"].to_list() == [1, 3, 4] + + +def test_init_datetimes_with_timezone() -> None: + tz_us = "America/New_York" + tz_europe = "Europe/Amsterdam" + + dtm = datetime(2022, 10, 12, 12, 30) + for time_unit in DTYPE_TEMPORAL_UNITS: + for type_overrides in ( + { + "schema": [ + ("d1", pl.Datetime(time_unit, tz_us)), + ("d2", pl.Datetime(time_unit, tz_europe)), + ] + }, + { + "schema_overrides": { + "d1": pl.Datetime(time_unit, tz_us), + "d2": pl.Datetime(time_unit, tz_europe), + } + }, + ): + result = pl.DataFrame( + data={ + "d1": [dtm.replace(tzinfo=ZoneInfo(tz_us))], + "d2": [dtm.replace(tzinfo=ZoneInfo(tz_europe))], + }, + **type_overrides, + ) + expected = pl.DataFrame( + {"d1": ["2022-10-12 12:30"], "d2": ["2022-10-12 12:30"]} + ).with_columns( + pl.col("d1").str.to_datetime(time_unit=time_unit, time_zone=tz_us), + pl.col("d2").str.to_datetime(time_unit=time_unit, time_zone=tz_europe), + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ( + "tzinfo", + "offset", + "dtype_time_zone", + "expected_time_zone", + "expected_item", + ), + [ + (None, "", None, None, datetime(2020, 1, 1)), + ( + timezone(timedelta(hours=-8)), + "-08:00", + "UTC", + "UTC", + datetime(2020, 1, 1, 8, tzinfo=timezone.utc), + ), + ( + timezone(timedelta(hours=-8)), + "-08:00", + None, + "UTC", + datetime(2020, 1, 1, 8, tzinfo=timezone.utc), + ), + ], +) +def test_init_vs_strptime_consistency( + tzinfo: timezone | None, + offset: str, + dtype_time_zone: str | None, + expected_time_zone: str, + expected_item: datetime, +) -> None: + result_init = pl.Series( + [datetime(2020, 1, 1, tzinfo=tzinfo)], + dtype=pl.Datetime("us", dtype_time_zone), + ) + result_strptime = pl.Series([f"2020-01-01 00:00{offset}"]).str.strptime( + pl.Datetime("us", dtype_time_zone) + ) + assert result_init.dtype == pl.Datetime("us", expected_time_zone) + assert result_init.item() == expected_item + assert_series_equal(result_init, result_strptime) + + +def test_init_vs_strptime_consistency_converts() -> None: + result = pl.Series( + [datetime(2020, 1, 1, tzinfo=timezone(timedelta(hours=-8)))], + dtype=pl.Datetime("us", "US/Pacific"), + ).item() + assert result == datetime(2020, 1, 1, 0, 0, tzinfo=ZoneInfo(key="US/Pacific")) + result = ( + pl.Series(["2020-01-01 00:00-08:00"]) + .str.strptime(pl.Datetime("us", "US/Pacific")) + .item() + ) + assert result == datetime(2020, 1, 1, 0, 0, tzinfo=ZoneInfo(key="US/Pacific")) + + +def test_init_physical_with_timezone() -> None: + tz_uae = "Asia/Dubai" + tz_asia = "Asia/Tokyo" + + dtm_us = 1665577800000000 + for time_unit in DTYPE_TEMPORAL_UNITS: + dtm = {"ms": dtm_us // 1_000, "ns": dtm_us * 1_000}.get(str(time_unit), dtm_us) + df = pl.DataFrame( + data={"d1": [dtm], "d2": [dtm]}, + schema=[ + ("d1", pl.Datetime(time_unit, tz_uae)), + ("d2", pl.Datetime(time_unit, tz_asia)), + ], + ) + assert (df["d1"].to_physical() == df["d2"].to_physical()).all() + assert df.rows() == [ + ( + datetime(2022, 10, 12, 16, 30, tzinfo=ZoneInfo(tz_uae)), + datetime(2022, 10, 12, 21, 30, tzinfo=ZoneInfo(tz_asia)), + ) + ] + + +@pytest.mark.parametrize("divop", [floordiv, truediv]) +def test_floordiv_truediv(divop: Callable[..., Any]) -> None: + # validate truediv/floordiv dataframe ops against python + df1 = pl.DataFrame( + data={ + "x": [0, -1, -2, -3], + "y": [-0.0, -3.0, 5.0, -7.0], + "z": [10, 3, -5, 7], + } + ) + + # scalar + for n in (3, 3.0, -3, -3.0): + py_div = [tuple(divop(elem, n) for elem in row) for row in df1.rows()] + df_div = divop(df1, n).rows() + assert py_div == df_div + + # series + xdf, s = df1["x"].to_frame(), pl.Series([2] * 4) + assert list(divop(xdf, s)["x"]) == [divop(x, 2) for x in list(df1["x"])] + + # frame + df2 = pl.DataFrame( + data={ + "x": [2, -2, 2, 3], + "y": [4, 4, -4, 8], + "z": [0.5, 2.0, -2.0, -3], + } + ) + df_div = divop(df1, df2).rows() + for i, (row1, row2) in enumerate(zip(df1.rows(), df2.rows())): + for j, (elem1, elem2) in enumerate(zip(row1, row2)): + assert divop(elem1, elem2) == df_div[i][j] + + +@pytest.mark.parametrize( + ("subset", "keep", "expected_mask"), + [ + (None, "first", [True, True, True, False]), + ("a", "first", [True, True, False, False]), + (["a", "b"], "first", [True, True, False, False]), + (("a", "b"), "last", [True, False, False, True]), + (("a", "b"), "none", [True, False, False, False]), + ], +) +def test_unique( + subset: str | Sequence[str], keep: UniqueKeepStrategy, expected_mask: list[bool] +) -> None: + df = pl.DataFrame({"a": [1, 2, 2, 2], "b": [3, 4, 4, 4], "c": [5, 6, 7, 7]}) + + result = df.unique(maintain_order=True, subset=subset, keep=keep) + expected = df.filter(expected_mask) + assert_frame_equal(result, expected) + + +def test_iter_slices() -> None: + df = pl.DataFrame( + { + "a": range(95), + "b": date(2023, 1, 1), + "c": "klmnopqrstuvwxyz", + } + ) + batches = list(df.iter_slices(n_rows=50)) + + assert len(batches[0]) == 50 + assert len(batches[1]) == 45 + assert batches[1].rows() == df[50:].rows() + + +def test_format_empty_df() -> None: + df = pl.DataFrame( + [ + pl.Series("val1", [], dtype=pl.Categorical), + pl.Series("val2", [], dtype=pl.Categorical), + ] + ).select( + pl.format("{}:{}", pl.col("val1"), pl.col("val2")).alias("cat"), + ) + assert df.shape == (0, 1) + assert df.dtypes == [pl.String] + + +def test_deadlocks_3409() -> None: + assert ( + pl.DataFrame({"col1": [[1, 2, 3]]}) + .with_columns( + pl.col("col1").list.eval( + pl.element().map_elements(lambda x: x, return_dtype=pl.Int64) + ) + ) + .to_dict(as_series=False) + ) == {"col1": [[1, 2, 3]]} + + assert ( + pl.DataFrame({"col1": [1, 2, 3]}) + .with_columns( + pl.col("col1").cumulative_eval(pl.element().map_batches(lambda x: 0)) + ) + .to_dict(as_series=False) + ) == {"col1": [0, 0, 0]} + + +def test_ceil() -> None: + df = pl.DataFrame({"a": [1.8, 1.2, 3.0]}) + result = df.select(pl.col("a").ceil()) + assert_frame_equal(result, pl.DataFrame({"a": [2.0, 2.0, 3.0]})) + + df = pl.DataFrame({"a": [1, 2, 3]}) + result = df.select(pl.col("a").ceil()) + assert_frame_equal(df, result) + + +def test_floor() -> None: + df = pl.DataFrame({"a": [1.8, 1.2, 3.0]}) + result = df.select(pl.col("a").floor()) + assert_frame_equal(result, pl.DataFrame({"a": [1.0, 1.0, 3.0]})) + + df = pl.DataFrame({"a": [1, 2, 3]}) + result = df.select(pl.col("a").floor()) + assert_frame_equal(df, result) + + +def test_floor_divide() -> None: + x = 10.4 + step = 0.5 + df = pl.DataFrame({"x": [x]}) + assert df.with_columns(pl.col("x") // step)[0, 0] == x // step + + +def test_round() -> None: + df = pl.DataFrame({"a": [1.8, 1.2, 3.0]}) + col_a_rounded = df.select(pl.col("a").round(decimals=0))["a"] + assert_series_equal(col_a_rounded, pl.Series("a", [2, 1, 3]).cast(pl.Float64)) + + +def test_dot() -> None: + df = pl.DataFrame({"a": [1.8, 1.2, 3.0], "b": [3.2, 1, 2]}) + assert df.select(pl.col("a").dot(pl.col("b"))).item() == 12.96 + + +def test_unstack() -> None: + from string import ascii_uppercase + + df = pl.DataFrame( + { + "col1": list(ascii_uppercase[0:9]), + "col2": pl.int_range(0, 9, eager=True), + "col3": pl.int_range(-9, 0, eager=True), + } + ) + assert df.unstack(step=3, how="vertical").to_dict(as_series=False) == { + "col1_0": ["A", "B", "C"], + "col1_1": ["D", "E", "F"], + "col1_2": ["G", "H", "I"], + "col2_0": [0, 1, 2], + "col2_1": [3, 4, 5], + "col2_2": [6, 7, 8], + "col3_0": [-9, -8, -7], + "col3_1": [-6, -5, -4], + "col3_2": [-3, -2, -1], + } + + assert df.unstack(step=3, how="horizontal").to_dict(as_series=False) == { + "col1_0": ["A", "D", "G"], + "col1_1": ["B", "E", "H"], + "col1_2": ["C", "F", "I"], + "col2_0": [0, 3, 6], + "col2_1": [1, 4, 7], + "col2_2": [2, 5, 8], + "col3_0": [-9, -6, -3], + "col3_1": [-8, -5, -2], + "col3_2": [-7, -4, -1], + } + + for column_subset in (("col2", "col3"), cs.integer()): + assert df.unstack( + step=3, + how="horizontal", + columns=column_subset, + ).to_dict(as_series=False) == { + "col2_0": [0, 3, 6], + "col2_1": [1, 4, 7], + "col2_2": [2, 5, 8], + "col3_0": [-9, -6, -3], + "col3_1": [-8, -5, -2], + "col3_2": [-7, -4, -1], + } + + +def test_window_deadlock() -> None: + np.random.seed(12) + + df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", None], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } + ) + + _df = df.select( + pl.col("*"), # select all + pl.col("random").sum().over("groups").alias("sum[random]/groups"), + pl.col("random").implode().over("names").alias("random/name"), + ) + + +def test_sum_empty_column_names() -> None: + df = pl.DataFrame({"x": [], "y": []}, schema={"x": pl.Boolean, "y": pl.Boolean}) + expected = pl.DataFrame( + {"x": [0], "y": [0]}, schema={"x": pl.UInt32, "y": pl.UInt32} + ) + assert_frame_equal(df.sum(), expected) + + +def test_flags() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}) + assert df.flags == { + "a": {"SORTED_ASC": False, "SORTED_DESC": False}, + "b": {"SORTED_ASC": False, "SORTED_DESC": False}, + } + assert df.set_sorted("a").flags == { + "a": {"SORTED_ASC": True, "SORTED_DESC": False}, + "b": {"SORTED_ASC": False, "SORTED_DESC": False}, + } + + +def test_interchange() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]}) + dfi = df.__dataframe__() + + # Testing some random properties to make sure conversion happened correctly + assert dfi.num_rows() == 2 + assert dfi.get_column(0).dtype[1] == 64 + assert dfi.get_column_by_name("c").get_buffers()["data"][0].bufsize == 6 + + +def test_from_dicts_undeclared_column_dtype() -> None: + data = [{"a": 1, "b": 2}] + result = pl.from_dicts(data, schema=["x"]) + assert result.schema == {"x": pl.Null} + + +def test_from_dicts_with_override() -> None: + data = [ + {"a": "1", "b": str(2**64 - 1), "c": "1"}, + {"a": "1", "b": "1", "c": "-5.0"}, + ] + override = {"a": pl.Int32, "b": pl.UInt64, "c": pl.Float32} + result = pl.from_dicts(data, schema_overrides=override) + assert_frame_equal( + result, + pl.DataFrame( + { + "a": pl.Series([1, 1], dtype=pl.Int32), + "b": pl.Series([2**64 - 1, 1], dtype=pl.UInt64), + "c": pl.Series([1.0, -5.0], dtype=pl.Float32), + } + ), + ) + + +def test_from_records_u64_12329() -> None: + s = pl.from_records([{"a": 9908227375760408577}]) + assert s.dtypes == [pl.Int128] + assert s["a"][0] == 9908227375760408577 + + +def test_negative_slice_12642() -> None: + df = pl.DataFrame({"x": range(5)}) + assert_frame_equal(df.slice(-2, 1), df.tail(2).head(1)) + + +def test_iter_columns() -> None: + df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + iter_columns = df.iter_columns() + assert_series_equal(next(iter_columns), pl.Series("a", [1, 1, 2])) + assert_series_equal(next(iter_columns), pl.Series("b", [4, 5, 6])) + + +def test_get_column_index() -> None: + df = pl.DataFrame({"actual": [1001], "expected": [1000]}) + + assert df.get_column_index("actual") == 0 + assert df.get_column_index("expected") == 1 + + with pytest.raises(ColumnNotFoundError, match="missing"): + df.get_column_index("missing") + + +def test_dataframe_creation_with_different_series_lengths_19795() -> None: + with pytest.raises( + ShapeError, + match=r"could not create a new DataFrame: height of column 'b' \(1\) does not match height of column 'a' \(2\)", + ): + pl.DataFrame({"a": [1, 2], "b": [1]}) + + +def test_get_column_after_drop_20119() -> None: + df = pl.DataFrame({"a": ["A"], "b": ["B"], "c": ["C"]}) + df.drop_in_place("a") + c = df.get_column("c") + assert_series_equal(c, pl.Series("c", ["C"])) + + +def test_select_oob_row_20775() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + with pytest.raises( + IndexError, + match="index 99 is out of bounds for DataFrame of height 3", + ): + df[99] + + +@pytest.mark.parametrize("idx", [3, 99, -4, -99]) +def test_select_oob_element_20775_too_large(idx: int) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + with pytest.raises( + IndexError, + match=f"index {idx} is out of bounds for sequence of length 3", + ): + df[idx, "a"] + + +def test_nan_to_null() -> None: + a = np.array([np.nan, 1]) + + df1 = pl.DataFrame(a, nan_to_null=True) + df2 = pl.DataFrame( + (a,), + nan_to_null=True, + ) + + assert_frame_equal(df1, df2) diff --git a/py-polars/tests/unit/dataframe/test_equals.py b/py-polars/tests/unit/dataframe/test_equals.py new file mode 100644 index 000000000000..c8c5ec1c2e64 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_equals.py @@ -0,0 +1,47 @@ +import polars as pl + + +def test_equals() -> None: + # Values are checked + df1 = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + df2 = pl.DataFrame( + { + "foo": [3, 2, 1], + "bar": [8.0, 7.0, 6.0], + "ham": ["c", "b", "a"], + } + ) + + assert df1.equals(df1) is True + assert df1.equals(df2) is False + + # Column names are checked + df3 = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [6.0, 7.0, 8.0], + "c": ["a", "b", "c"], + } + ) + assert df1.equals(df3) is False + + # Datatypes are NOT checked + df = pl.DataFrame( + { + "foo": [1, 2, None], + "bar": [6.0, 7.0, None], + "ham": ["a", "b", None], + } + ) + assert df.equals(df.with_columns(pl.col("foo").cast(pl.Int8))) is True + assert df.equals(df.with_columns(pl.col("ham").cast(pl.Categorical))) is True + + # The null_equal parameter determines if None values are considered equal + assert df.equals(df) is True + assert df.equals(df, null_equal=False) is False diff --git a/py-polars/tests/unit/dataframe/test_extend.py b/py-polars/tests/unit/dataframe/test_extend.py new file mode 100644 index 000000000000..9d4f46c95e2d --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_extend.py @@ -0,0 +1,97 @@ +from datetime import datetime + +import pytest + +import polars as pl +from polars.exceptions import ShapeError +from polars.testing import assert_frame_equal + + +def test_extend_various_dtypes() -> None: + with pl.StringCache(): + df1 = pl.DataFrame( + { + "foo": [1, 2], + "bar": [True, False], + "ham": ["a", "b"], + "cat": ["A", "B"], + "dates": [datetime(2021, 1, 1), datetime(2021, 2, 1)], + }, + schema_overrides={"cat": pl.Categorical}, + ) + df2 = pl.DataFrame( + { + "foo": [3, 4], + "bar": [True, None], + "ham": ["c", "d"], + "cat": ["C", "B"], + "dates": [datetime(2022, 9, 1), datetime(2021, 2, 1)], + }, + schema_overrides={"cat": pl.Categorical}, + ) + + df1.extend(df2) + + expected = pl.DataFrame( + { + "foo": [1, 2, 3, 4], + "bar": [True, False, True, None], + "ham": ["a", "b", "c", "d"], + "cat": ["A", "B", "C", "B"], + "dates": [ + datetime(2021, 1, 1), + datetime(2021, 2, 1), + datetime(2022, 9, 1), + datetime(2021, 2, 1), + ], + }, + schema_overrides={"cat": pl.Categorical}, + ) + assert_frame_equal(df1, expected) + + +def test_extend_slice_offset_8745() -> None: + df = pl.DataFrame([{"age": 1}, {"age": 2}, {"age": 3}]) + + df = df[:-1] + tail = pl.DataFrame([{"age": 8}]) + result = df.extend(tail) + + expected = pl.DataFrame({"age": [1, 2, 8]}) + assert_frame_equal(result, expected) + + +def test_extend_self() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [True, False]}) + + df.extend(df) + + expected = pl.DataFrame({"a": [1, 2, 1, 2], "b": [True, False, True, False]}) + assert_frame_equal(df, expected) + + +def test_extend_column_number_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [True, False]}) + df2 = df1.drop("a") + + with pytest.raises(ShapeError): + df1.extend(df2) + + +def test_extend_column_name_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [True, False]}) + df2 = df1.with_columns(pl.col("a").alias("c")) + + with pytest.raises(ShapeError): + df1.extend(df2) + + +def test_initialize_df_18736() -> None: + # Completely empty initialization + df = pl.DataFrame() + s_0 = pl.Series([]) + s_1 = pl.Series([None]) + s_2 = pl.Series([None, None]) + assert df.with_columns(s_0).shape == (0, 1) + assert df.with_columns(s_1).shape == (1, 1) + assert df.with_columns(s_2).shape == (2, 1) diff --git a/py-polars/tests/unit/dataframe/test_from_dict.py b/py-polars/tests/unit/dataframe/test_from_dict.py new file mode 100644 index 000000000000..d151b0e36112 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_from_dict.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import Any + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_from_dict_with_column_order() -> None: + # expect schema/columns order to take precedence + schema = {"a": pl.UInt8, "b": pl.UInt32} + data = {"b": [3, 4], "a": [1, 2]} + for df in ( + pl.DataFrame(data, schema=schema), + pl.DataFrame(data, schema=["a", "b"], schema_overrides=schema), + ): + # ┌─────┬─────┐ + # │ a ┆ b │ + # │ --- ┆ --- │ + # │ u8 ┆ u32 │ + # ╞═════╪═════╡ + # │ 1 ┆ 3 │ + # │ 2 ┆ 4 │ + # └─────┴─────┘ + assert df.columns == ["a", "b"] + assert df.schema == {"a": pl.UInt8, "b": pl.UInt32} + assert df.rows() == [(1, 3), (2, 4)] + + # expect an error + mismatched_schema = {"x": pl.UInt8, "b": pl.UInt32} + with pytest.raises(ValueError): + pl.DataFrame({"b": [3, 4], "a": [1, 2]}, schema=mismatched_schema) + + +def test_from_dict_with_scalars() -> None: + # one or more valid arrays, with some scalars (inc. None) + df1 = pl.DataFrame( + {"key": ["aa", "bb", "cc"], "misc": "xyz", "other": None, "value": 0} + ) + assert df1.to_dict(as_series=False) == { + "key": ["aa", "bb", "cc"], + "misc": ["xyz", "xyz", "xyz"], + "other": [None, None, None], + "value": [0, 0, 0], + } + + # edge-case: all scalars + df2 = pl.DataFrame({"key": "aa", "misc": "xyz", "other": None, "value": 0}) + assert df2.to_dict(as_series=False) == { + "key": ["aa"], + "misc": ["xyz"], + "other": [None], + "value": [0], + } + + # edge-case: single unsized generator + df3 = pl.DataFrame({"vals": map(float, [1, 2, 3])}) + assert df3.to_dict(as_series=False) == {"vals": [1.0, 2.0, 3.0]} + + # ensure we don't accidentally consume or expand map/range/generator + # cols, and can properly apply schema dtype/ordering directives + df4 = pl.DataFrame( + { + "key": range(1, 4), + "misc": (x for x in [4, 5, 6]), + "other": map(float, [7, 8, 9]), + "value": {0: "x", 1: "y", 2: "z"}.values(), + }, + schema={ + "value": pl.String, + "other": pl.Float32, + "misc": pl.Int32, + "key": pl.Int8, + }, + ) + assert df4.columns == ["value", "other", "misc", "key"] + assert df4.to_dict(as_series=False) == { + "value": ["x", "y", "z"], + "other": [7.0, 8.0, 9.0], + "misc": [4, 5, 6], + "key": [1, 2, 3], + } + assert df4.schema == { + "value": pl.String, + "other": pl.Float32, + "misc": pl.Int32, + "key": pl.Int8, + } + + # mixed with struct cols + for df5 in ( + pl.from_dict( + {"x": {"b": [1, 3], "c": [2, 4]}, "y": [5, 6], "z": "x"}, + schema_overrides={"y": pl.Int8}, + ), + pl.from_dict( + {"x": {"b": [1, 3], "c": [2, 4]}, "y": [5, 6], "z": "x"}, + schema=["x", ("y", pl.Int8), "z"], + ), + ): + assert df5.rows() == [({"b": 1, "c": 2}, 5, "x"), ({"b": 3, "c": 4}, 6, "x")] + assert df5.schema == { + "x": pl.Struct([pl.Field("b", pl.Int64), pl.Field("c", pl.Int64)]), + "y": pl.Int8, + "z": pl.String, + } + + # mixed with numpy cols... + df6 = pl.DataFrame( + {"x": np.ones(3), "y": np.zeros(3), "z": 1.0}, + ) + assert df6.rows() == [(1.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 0.0, 1.0)] + + # ...and trigger multithreaded load codepath + df7 = pl.DataFrame( + { + "w": np.zeros(1001, dtype=np.uint8), + "x": np.ones(1001, dtype=np.uint8), + "y": np.zeros(1001, dtype=np.uint8), + "z": 1, + }, + schema_overrides={"z": pl.UInt8}, + ) + assert df7[999:].rows() == [(0, 1, 0, 1), (0, 1, 0, 1)] + assert df7.schema == { + "w": pl.UInt8, + "x": pl.UInt8, + "y": pl.UInt8, + "z": pl.UInt8, + } + + # misc generators/iterables + df9 = pl.DataFrame( + { + "a": iter([0, 1, 2]), + "b": (2, 1, 0).__iter__(), + "c": (v for v in (0, 0, 0)), + "d": "x", + } + ) + assert df9.rows() == [(0, 2, 0, "x"), (1, 1, 0, "x"), (2, 0, 0, "x")] + + +@pytest.mark.slow +def test_from_dict_with_values_mixed() -> None: + # a bit of everything + mixed_dtype_data: dict[str, Any] = { + "a": 0, + "b": 8, + "c": 9.5, + "d": None, + "e": True, + "f": False, + "g": time(0, 1, 2), + "h": date(2023, 3, 14), + "i": timedelta(seconds=3601), + "j": datetime(2111, 11, 11, 11, 11, 11, 11), + "k": "「趣味でヒーローをやっている者だ」", + } + # note: deliberately set this value large; if all dtypes are + # on the fast-path it'll only take ~0.03secs. if it becomes + # even remotely noticeable that will indicate a regression. + n_range = 1_000_000 + index_and_data: dict[str, Any] = {"idx": range(n_range)} + index_and_data.update(mixed_dtype_data.items()) + df = pl.DataFrame( + data=index_and_data, + schema={ + "idx": pl.Int32, + "a": pl.UInt16, + "b": pl.UInt32, + "c": pl.Float64, + "d": pl.Float32, + "e": pl.Boolean, + "f": pl.Boolean, + "g": pl.Time, + "h": pl.Date, + "i": pl.Duration, + "j": pl.Datetime, + "k": pl.String, + }, + ) + dfx = df.select(pl.exclude("idx")) + + assert df.height == n_range + assert dfx[:5].rows() == dfx[5:10].rows() + assert dfx[-10:-5].rows() == dfx[-5:].rows() + assert dfx.row(n_range // 2, named=True) == mixed_dtype_data + + +def test_from_dict_expand_nested_struct() -> None: + # confirm consistent init of nested struct from dict data + dt = date(2077, 10, 10) + expected = pl.DataFrame( + [ + pl.Series("x", [dt]), + pl.Series("nested", [{"y": -1, "z": 1}]), + ] + ) + for df in ( + pl.DataFrame({"x": dt, "nested": {"y": -1, "z": 1}}), + pl.DataFrame({"x": dt, "nested": [{"y": -1, "z": 1}]}), + pl.DataFrame({"x": [dt], "nested": {"y": -1, "z": 1}}), + pl.DataFrame({"x": [dt], "nested": [{"y": -1, "z": 1}]}), + ): + assert_frame_equal(expected, df) + + # confirm expansion to 'n' nested values + nested_values = [{"y": -1, "z": 1}, {"y": -1, "z": 1}, {"y": -1, "z": 1}] + expected = pl.DataFrame( + [ + pl.Series("x", [0, 1, 2]), + pl.Series("nested", nested_values), + ] + ) + for df in ( + pl.DataFrame({"x": range(3), "nested": {"y": -1, "z": 1}}), + pl.DataFrame({"x": [0, 1, 2], "nested": {"y": -1, "z": 1}}), + ): + assert_frame_equal(expected, df) + + +def test_from_dict_duration_subseconds() -> None: + d = {"duration": [timedelta(seconds=1, microseconds=1000)]} + result = pl.from_dict(d) + expected = pl.select(duration=pl.duration(seconds=1, microseconds=1000)) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("dtype", "data"), + [ + (pl.Date, date(2099, 12, 31)), + (pl.Datetime("ms"), datetime(1998, 10, 1, 10, 30)), + (pl.Duration("us"), timedelta(days=1)), + (pl.Time, time(2, 30, 10)), + ], +) +def test_from_dict_cast_logical_type(dtype: pl.DataType, data: Any) -> None: + schema = {"data": dtype} + df = pl.DataFrame({"data": [data]}, schema=schema) + physical_dict = df.cast(pl.Int64).to_dict() + + df_from_dicts = pl.from_dicts( + [ + { + "data": physical_dict["data"][0], + } + ], + schema=schema, + ) + + assert_frame_equal(df_from_dicts, df) diff --git a/py-polars/tests/unit/dataframe/test_getitem.py b/py-polars/tests/unit/dataframe/test_getitem.py new file mode 100644 index 000000000000..d5a48a74879e --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_getitem.py @@ -0,0 +1,489 @@ +from __future__ import annotations + +from typing import Any + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes +from tests.unit.conftest import INTEGER_DTYPES, SIGNED_INTEGER_DTYPES + + +@given( + df=dataframes( + max_size=10, + cols=[ + column( + "start", + dtype=pl.Int8, + allow_null=True, + strategy=st.integers(min_value=-8, max_value=8), + ), + column( + "stop", + dtype=pl.Int8, + allow_null=True, + strategy=st.integers(min_value=-6, max_value=6), + ), + column( + "step", + dtype=pl.Int8, + allow_null=True, + strategy=st.integers(min_value=-4, max_value=4).filter( + lambda x: x != 0 + ), + ), + column("misc", dtype=pl.Int32), + ], + ) + # generated dataframe example - + # ┌───────┬──────┬──────┬───────┐ + # │ start ┆ stop ┆ step ┆ misc │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i8 ┆ i8 ┆ i8 ┆ i32 │ + # ╞═══════╪══════╪══════╪═══════╡ + # │ 2 ┆ -1 ┆ null ┆ -55 │ + # │ -3 ┆ 0 ┆ -2 ┆ 61582 │ + # │ null ┆ 1 ┆ 2 ┆ 5865 │ + # └───────┴──────┴──────┴───────┘ +) +def test_df_getitem_row_slice(df: pl.DataFrame) -> None: + # take strategy-generated integer values from the frame as slice bounds. + # use these bounds to slice the same frame, and then validate the result + # against a py-native slice of the same data using the same bounds. + # + # given the average number of rows in the frames, and the value of + # max_examples, this will result in close to 5000 test permutations, + # running in around ~1.5 secs (depending on hardware/etc). + py_data = df.rows() + + for start, stop, step, _ in py_data: + s = slice(start, stop, step) + sliced_py_data = py_data[s] + sliced_df_data = df[s].rows() + + assert sliced_py_data == sliced_df_data, ( + f"slice [{start}:{stop}:{step}] failed on df w/len={df.height}" + ) + + +def test_df_getitem_col_single_name() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df[:, "a"] + expected = df.select("a").to_series() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected_cols"), + [ + (["a"], ["a"]), + (["a", "d"], ["a", "d"]), + (slice("b", "d"), ["b", "c", "d"]), + (pl.Series(["a", "b"]), ["a", "b"]), + (np.array(["c", "d"]), ["c", "d"]), + ], +) +def test_df_getitem_col_multiple_names(input: Any, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}) + result = df[:, input] + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +def test_df_getitem_col_single_index() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df[:, 1] + expected = df.select("b").to_series() + assert_series_equal(result, expected) + + +def test_df_getitem_col_two_entries() -> None: + df = pl.DataFrame({"x": [1.0], "y": [1.0]}) + + assert_frame_equal(df["x", "y"], df) + assert_frame_equal(df[True, True], df) + + +@pytest.mark.parametrize( + ("input", "expected_cols"), + [ + ([0], ["a"]), + ([0, 3], ["a", "d"]), + (slice(1, 4), ["b", "c", "d"]), + (pl.Series([0, 1]), ["a", "b"]), + (np.array([2, 3]), ["c", "d"]), + ], +) +def test_df_getitem_col_multiple_indices(input: Any, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}) + result = df[:, input] + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "mask", + [ + [True, False, True], + pl.Series([True, False, True]), + np.array([True, False, True]), + ], +) +def test_df_getitem_col_boolean_mask(mask: Any) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + result = df[:, mask] + expected = df.select("a", "c") + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("rng", "expected_cols"), + [ + (range(2), ["a", "b"]), + (range(1, 4), ["b", "c", "d"]), + (range(3, 0, -2), ["d", "b"]), + ], +) +def test_df_getitem_col_range(rng: range, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}) + result = df[:, rng] + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input", [[], (), pl.Series(dtype=pl.Int64), np.array([], dtype=np.uint32)] +) +def test_df_getitem_col_empty_inputs(input: Any) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + result = df[:, input] + expected = pl.DataFrame() + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "match"), + [ + ( + [0.0, 1.0], + "cannot select columns using Sequence with elements of type 'float'", + ), + ( + pl.Series([[1, 2], [3, 4]]), + "cannot select columns using Series of type List\\(Int64\\)", + ), + ( + np.array([0.0, 1.0]), + "cannot select columns using NumPy array of type float64", + ), + (object(), "cannot select columns using key of type 'object'"), + ], +) +def test_df_getitem_col_invalid_inputs(input: Any, match: str) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + with pytest.raises(TypeError, match=match): + df[:, input] + + +@pytest.mark.parametrize( + ("input", "match"), + [ + (["a", 2], "'int' object cannot be converted to 'PyString'"), + ([1, "c"], "'str' object cannot be interpreted as an integer"), + ], +) +def test_df_getitem_col_mixed_inputs(input: list[Any], match: str) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + with pytest.raises(TypeError, match=match): + df[:, input] + + +@pytest.mark.parametrize( + ("input", "match"), + [ + ([0.0, 1.0], "unexpected value while building Series of type Int64"), + ( + pl.Series([[1, 2], [3, 4]]), + "cannot treat Series of type List\\(Int64\\) as indices", + ), + (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"), + (object(), "cannot select rows using key of type 'object'"), + ], +) +def test_df_getitem_row_invalid_inputs(input: Any, match: str) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + with pytest.raises(TypeError, match=match): + df[input, :] + + +def test_df_getitem_row_range() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5.0, 6.0, 7.0, 8.0]}) + result = df[range(3, 0, -2), :] + expected = pl.DataFrame({"a": [4, 2], "b": [8.0, 6.0]}) + assert_frame_equal(result, expected) + + +def test_df_getitem_row_range_single_input() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5.0, 6.0, 7.0, 8.0]}) + result = df[range(1, 3)] + expected = pl.DataFrame({"a": [2, 3], "b": [6.0, 7.0]}) + assert_frame_equal(result, expected) + + +def test_df_getitem_row_empty_list_single_input() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [5.0, 6.0]}) + result = df[[]] + expected = df.clear() + assert_frame_equal(result, expected) + + +def test_df_getitem() -> None: + """Test all the methods to use [] on a dataframe.""" + df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]}) + + # multiple slices. + # The first element refers to the rows, the second element to columns + assert_frame_equal(df[:, :], df) + + # str, always refers to a column name + assert_series_equal(df["a"], pl.Series("a", [1.0, 2.0, 3.0, 4.0])) + + # int, always refers to a row index (zero-based): index=1 => second row + assert_frame_equal(df[1], pl.DataFrame({"a": [2.0], "b": [4]})) + + # int, int. + # The first element refers to the rows, the second element to columns + assert df[2, 1] == 5 + assert df[2, -2] == 3.0 + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, 2] + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, -3] + + # int, list[int]. + # The first element refers to the rows, the second element to columns + assert_frame_equal(df[2, [1, 0]], pl.DataFrame({"b": [5], "a": [3.0]})) + assert_frame_equal(df[2, [-1, -2]], pl.DataFrame({"b": [5], "a": [3.0]})) + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, [2, 0]] + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, [2, -3]] + + # slice. Below an example of taking every second row + assert_frame_equal(df[1::2], pl.DataFrame({"a": [2.0, 4.0], "b": [4, 6]})) + + # slice, empty slice + assert df[:0].columns == ["a", "b"] + assert len(df[:0]) == 0 + + # make mypy happy + empty: list[int] = [] + + # empty list with column selector drops rows but keeps columns + assert_frame_equal(df[empty, :], df[:0]) + + # sequences (lists or tuples; tuple only if length != 2) + # if strings or list of expressions, assumed to be column names + # if bools, assumed to be a row mask + # if integers, assumed to be row indices + assert_frame_equal(df[["a", "b"]], df) + assert_frame_equal(df.select([pl.col("a"), pl.col("b")]), df) + assert_frame_equal( + df[[1, -4, -1, 2, 1]], + pl.DataFrame({"a": [2.0, 1.0, 4.0, 3.0, 2.0], "b": [4, 3, 6, 5, 4]}), + ) + + # pl.Series: strings for column selections. + assert_frame_equal(df[pl.Series("", ["a", "b"])], df) + + # pl.Series: positive idxs or empty idxs for row selection. + for pl_dtype in INTEGER_DTYPES: + assert_frame_equal( + df[pl.Series("", [1, 0, 3, 2, 3, 0], dtype=pl_dtype)], + pl.DataFrame( + {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]} + ), + ) + assert df[pl.Series("", [], dtype=pl_dtype)].columns == ["a", "b"] + + # pl.Series: positive and negative idxs for row selection. + for pl_dtype in SIGNED_INTEGER_DTYPES: + assert_frame_equal( + df[pl.Series("", [-1, 0, -3, -2, 3, -4], dtype=pl_dtype)], + pl.DataFrame( + {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]} + ), + ) + + # Boolean masks for rows not supported + with pytest.raises(TypeError): + df[[True, False, True], [False, True]] + with pytest.raises(TypeError): + df[pl.Series([True, False, True]), "b"] + + assert_frame_equal(df[np.array([True, False])], df[:, :1]) + + # wrong length boolean mask for column selection + with pytest.raises( + ValueError, + match=f"expected {df.width} values when selecting columns by boolean mask", + ): + df[:, [True, False, True]] + + +def test_df_getitem_numpy() -> None: + # nupmy getitem: assumed to be row indices if integers, or columns if strings + df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]}) + + # numpy array: positive idxs and empty idx + for np_dtype in ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ): + assert_frame_equal( + df[np.array([1, 0, 3, 2, 3, 0], dtype=np_dtype)], + pl.DataFrame( + {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]} + ), + ) + assert df[np.array([], dtype=np_dtype)].columns == ["a", "b"] + + # numpy array: positive and negative idxs. + for np_dtype in (np.int8, np.int16, np.int32, np.int64): + assert_frame_equal( + df[np.array([-1, 0, -3, -2, 3, -4], dtype=np_dtype)], + pl.DataFrame( + {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]} + ), + ) + + # zero-dimensional array indexing is equivalent to int row selection + assert_frame_equal(df[np.array(0)], pl.DataFrame({"a": [1.0], "b": [3]})) + assert_frame_equal(df[np.array(1)], pl.DataFrame({"a": [2.0], "b": [4]})) + + # note that we cannot use floats (even if they could be cast to int without loss) + with pytest.raises( + TypeError, + match="cannot select columns using NumPy array of type float", + ): + _ = df[np.array([1.0])] + + with pytest.raises( + TypeError, + match="multi-dimensional NumPy arrays not supported as index", + ): + df[np.array([[0], [1]])] + + +def test_df_getitem_extended() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]}) + + # select columns by mask + assert df[:2, :1].rows() == [(1,), (2,)] + assert df[:2, ["a"]].rows() == [(1,), (2,)] + + # column selection by string(s) in first dimension + assert df["a"].to_list() == [1, 2, 3] + assert df["b"].to_list() == [1.0, 2.0, 3.0] + assert df["c"].to_list() == ["a", "b", "c"] + + # row selection by integers(s) in first dimension + assert_frame_equal(df[0], pl.DataFrame({"a": [1], "b": [1.0], "c": ["a"]})) + assert_frame_equal(df[-1], pl.DataFrame({"a": [3], "b": [3.0], "c": ["c"]})) + + # row, column selection when using two dimensions + assert df[:, "a"].to_list() == [1, 2, 3] + assert df[:, 1].to_list() == [1.0, 2.0, 3.0] + assert df[:2, 2].to_list() == ["a", "b"] + + assert_frame_equal( + df[[1, 2]], pl.DataFrame({"a": [2, 3], "b": [2.0, 3.0], "c": ["b", "c"]}) + ) + assert_frame_equal( + df[[-1, -2]], pl.DataFrame({"a": [3, 2], "b": [3.0, 2.0], "c": ["c", "b"]}) + ) + + assert df[["a", "b"]].columns == ["a", "b"] + assert_frame_equal( + df[[1, 2], [1, 2]], pl.DataFrame({"b": [2.0, 3.0], "c": ["b", "c"]}) + ) + assert df[1, 2] == "b" + assert df[1, 1] == 2.0 + assert df[2, 0] == 3 + + assert df[[2], ["a", "b"]].rows() == [(3, 3.0)] + assert df.to_series(0).name == "a" + assert (df["a"] == df["a"]).sum() == 3 + assert (df["c"] == df["a"].cast(str)).sum() == 0 + assert df[:, "a":"b"].rows() == [(1, 1.0), (2, 2.0), (3, 3.0)] # type: ignore[index, misc] + assert df[:, "a":"c"].columns == ["a", "b", "c"] # type: ignore[index, misc] + assert df[:, []].shape == (0, 0) + expect = pl.DataFrame({"c": ["b"]}) + assert_frame_equal(df[1, [2]], expect) + expect = pl.DataFrame({"b": [1.0, 3.0]}) + assert_frame_equal(df[[0, 2], [1]], expect) + assert df[0, "c"] == "a" + assert df[1, "c"] == "b" + assert df[2, "c"] == "c" + assert df[0, "a"] == 1 + + # more slicing + expect = pl.DataFrame({"a": [3, 2, 1], "b": [3.0, 2.0, 1.0], "c": ["c", "b", "a"]}) + assert_frame_equal(df[::-1], expect) + expect = pl.DataFrame({"a": [1, 2], "b": [1.0, 2.0], "c": ["a", "b"]}) + assert_frame_equal(df[:-1], expect) + + expect = pl.DataFrame({"a": [1, 3], "b": [1.0, 3.0], "c": ["a", "c"]}) + assert_frame_equal(df[::2], expect) + + # only allow boolean values in column position + df = pl.DataFrame( + { + "a": [1, 2], + "b": [2, 3], + "c": [3, 4], + } + ) + + assert df[:, [False, True, True]].columns == ["b", "c"] + assert df[:, pl.Series([False, True, True])].columns == ["b", "c"] + assert df[:, pl.Series([False, False, False])].columns == [] + + +def test_df_getitem_5343() -> None: + # https://github.com/pola-rs/polars/issues/5343 + df = pl.DataFrame( + { + f"foo{col}": [n**col for n in range(5)] # 5 rows + for col in range(12) # 12 columns + } + ) + assert df[4, 4] == 256 + assert df[4, 5] == 1024 + assert_frame_equal(df[4, [2]], pl.DataFrame({"foo2": [16]})) + assert_frame_equal(df[4, [5]], pl.DataFrame({"foo5": [1024]})) + + +def test_no_deadlock_19358() -> None: + s = pl.Series(["text"] * 100 + [1] * 100, dtype=pl.Object) + result = s.to_frame()[[0, -1]] + assert result[""].to_list() == ["text", 1] diff --git a/py-polars/tests/unit/dataframe/test_glimpse.py b/py-polars/tests/unit/dataframe/test_glimpse.py new file mode 100644 index 000000000000..022bf7205d76 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_glimpse.py @@ -0,0 +1,88 @@ +import textwrap +from datetime import datetime +from typing import Any + +import polars as pl + + +def test_glimpse(capsys: Any) -> None: + df = pl.DataFrame( + { + "a": [1.0, 2.8, 3.0], + "b": [4, 5, None], + "c": [True, False, True], + "d": [None, "b", "c"], + "e": ["usd", "eur", None], + "f": pl.datetime_range( + datetime(2023, 1, 1), + datetime(2023, 1, 3), + "1d", + time_unit="us", + eager=True, + ), + "g": pl.datetime_range( + datetime(2023, 1, 1), + datetime(2023, 1, 3), + "1d", + time_unit="ms", + eager=True, + ), + "h": pl.datetime_range( + datetime(2023, 1, 1), + datetime(2023, 1, 3), + "1d", + time_unit="ns", + eager=True, + ), + "i": [[5, 6], [3, 4], [9, 8]], + "j": [[5.0, 6.0], [3.0, 4.0], [9.0, 8.0]], + "k": [["A", "a"], ["B", "b"], ["C", "c"]], + } + ) + result = df.glimpse(return_as_string=True) + + expected = textwrap.dedent( + """\ + Rows: 3 + Columns: 11 + $ a 1.0, 2.8, 3.0 + $ b 4, 5, None + $ c True, False, True + $ d None, 'b', 'c' + $ e 'usd', 'eur', None + $ f 2023-01-01 00:00:00, 2023-01-02 00:00:00, 2023-01-03 00:00:00 + $ g 2023-01-01 00:00:00, 2023-01-02 00:00:00, 2023-01-03 00:00:00 + $ h 2023-01-01 00:00:00, 2023-01-02 00:00:00, 2023-01-03 00:00:00 + $ i [5, 6], [3, 4], [9, 8] + $ j [5.0, 6.0], [3.0, 4.0], [9.0, 8.0] + $ k ['A', 'a'], ['B', 'b'], ['C', 'c'] + """ + ) + assert result == expected + + # the default is to print to the console + df.glimpse() + # remove the last newline on the capsys + assert capsys.readouterr().out[:-1] == expected + + colc = "a" * 96 + df = pl.DataFrame({colc: [11, 22, 33, 44, 55, 66]}) + result = df.glimpse( + return_as_string=True, max_colname_length=20, max_items_per_column=4 + ) + expected = textwrap.dedent( + """\ + Rows: 6 + Columns: 1 + $ aaaaaaaaaaaaaaaaaaa… 11, 22, 33, 44 + """ + ) + assert result == expected + + +def test_glimpse_colname_length() -> None: + df = pl.DataFrame({"a" * 100: [1, 2, 3]}) + result = df.glimpse(max_colname_length=96, return_as_string=True) + + expected = f"$ {'a' * 95}… 1, 2, 3" + assert result.strip().split("\n")[-1] == expected diff --git a/py-polars/tests/unit/dataframe/test_item.py b/py-polars/tests/unit/dataframe/test_item.py new file mode 100644 index 000000000000..12f9d87c913f --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_item.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_df_item() -> None: + df = pl.DataFrame({"a": [1]}) + assert df.item() == 1 + + +def test_df_item_empty() -> None: + df = pl.DataFrame() + with pytest.raises(ValueError, match=r".* frame has shape \(0, 0\)"): + df.item() + + +def test_df_item_incorrect_shape_rows() -> None: + df = pl.DataFrame({"a": [1, 2]}) + with pytest.raises(ValueError, match=r".* frame has shape \(2, 1\)"): + df.item() + + +def test_df_item_incorrect_shape_columns() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + with pytest.raises(ValueError, match=r".* frame has shape \(1, 2\)"): + df.item() + + +@pytest.fixture(scope="module") +def df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@pytest.mark.parametrize( + ("row", "col", "expected"), + [ + (0, 0, 1), + (1, "a", 2), + (-1, 1, 6), + (-2, "b", 5), + ], +) +def test_df_item_with_indices( + row: int, col: int | str, expected: int, df: pl.DataFrame +) -> None: + assert df.item(row, col) == expected + + +def test_df_item_with_single_index(df: pl.DataFrame) -> None: + with pytest.raises(ValueError): + df.item(0) + with pytest.raises(ValueError): + df.item(column="b") + with pytest.raises(ValueError): + df.item(None, 0) + + +@pytest.mark.parametrize( + ("row", "col"), [(0, 10), (10, 0), (10, 10), (-10, 0), (-10, 10)] +) +def test_df_item_out_of_bounds(row: int, col: int, df: pl.DataFrame) -> None: + with pytest.raises(IndexError, match="out of bounds"): + df.item(row, col) diff --git a/py-polars/tests/unit/dataframe/test_null_count.py b/py-polars/tests/unit/dataframe/test_null_count.py new file mode 100644 index 000000000000..3f9d5484327e --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_null_count.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from hypothesis import example, given + +import polars as pl +from polars.testing.parametric import dataframes + + +@given( + df=dataframes( + min_size=1, + min_cols=1, + allow_null=True, + excluded_dtypes=[ + pl.String, + pl.List, + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ], + ) +) +@example(df=pl.DataFrame(schema=["x", "y", "z"])) +@example(df=pl.DataFrame()) +def test_null_count(df: pl.DataFrame) -> None: + # note: the zero-row and zero-col cases are always passed as explicit examples + null_count, ncols = df.null_count(), df.width + assert null_count.shape == (1, ncols) + for idx, count in enumerate(null_count.rows()[0]): + assert count == sum(v is None for v in df.to_series(idx).to_list()) diff --git a/py-polars/tests/unit/dataframe/test_partition_by.py b/py-polars/tests/unit/dataframe/test_partition_by.py new file mode 100644 index 000000000000..5a0f145d7805 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_partition_by.py @@ -0,0 +1,91 @@ +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "m", "l"], + } + ) + + +@pytest.mark.parametrize("input", [["foo", "bar"], cs.string()]) +def test_partition_by(df: pl.DataFrame, input: Any) -> None: + result = df.partition_by(input, maintain_order=True) + expected = [ + {"foo": ["A"], "N": [1], "bar": ["k"]}, + {"foo": ["A"], "N": [2], "bar": ["l"]}, + {"foo": ["B", "B"], "N": [2, 4], "bar": ["m", "m"]}, + {"foo": ["C"], "N": [2], "bar": ["l"]}, + ] + assert [a.to_dict(as_series=False) for a in result] == expected + + +def test_partition_by_include_key_false(df: pl.DataFrame) -> None: + result = df.partition_by("foo", "bar", maintain_order=True, include_key=False) + expected = [ + {"N": [1]}, + {"N": [2]}, + {"N": [2, 4]}, + {"N": [2]}, + ] + assert [a.to_dict(as_series=False) for a in result] == expected + + +def test_partition_by_single(df: pl.DataFrame) -> None: + result = df.partition_by("foo", maintain_order=True) + expected = [ + {"foo": ["A", "A"], "N": [1, 2], "bar": ["k", "l"]}, + {"foo": ["B", "B"], "N": [2, 4], "bar": ["m", "m"]}, + {"foo": ["C"], "N": [2], "bar": ["l"]}, + ] + assert [a.to_dict(as_series=False) for a in result] == expected + + +def test_partition_by_as_dict() -> None: + df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]}) + result = df.partition_by(cs.all(), as_dict=True) + result_first = result["one", 1] + assert result_first.to_dict(as_series=False) == {"a": ["one"], "b": [1]} + + result = df.partition_by("a", as_dict=True) + result_first = result["one",] + assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]} + + +def test_partition_by_as_dict_include_keys_false() -> None: + df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]}) + + result = df.partition_by("a", include_key=False, as_dict=True) + result_first = result["one",] + assert result_first.to_dict(as_series=False) == {"b": [1, 3]} + + +def test_partition_by_as_dict_include_keys_false_maintain_order_false() -> None: + df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]}) + with pytest.raises(ValueError): + df.partition_by(["a"], maintain_order=False, include_key=False, as_dict=True) + + +@pytest.mark.slow +def test_partition_by_as_dict_include_keys_false_large() -> None: + # test with both as_dict and include_key=False + df = pl.DataFrame( + { + "a": pl.int_range(0, 100, dtype=pl.UInt8, eager=True), + "b": pl.int_range(0, 100, dtype=pl.UInt8, eager=True), + "c": pl.int_range(0, 100, dtype=pl.UInt8, eager=True), + "d": pl.int_range(0, 100, dtype=pl.UInt8, eager=True), + } + ).sample(n=100_000, with_replacement=True, shuffle=True) + + partitions = df.partition_by(["a", "b"], as_dict=True, include_key=False) + assert all(key == value.row(0) for key, value in partitions.items()) diff --git a/py-polars/tests/unit/dataframe/test_repr.py b/py-polars/tests/unit/dataframe/test_repr.py new file mode 100644 index 000000000000..e0a137c718ef --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_repr.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from hypothesis import given + +from polars.testing.parametric import dataframes + +if TYPE_CHECKING: + import polars as pl + + +@given(df=dataframes()) +def test_repr(df: pl.DataFrame) -> None: + assert isinstance(repr(df), str) diff --git a/py-polars/tests/unit/dataframe/test_repr_html.py b/py-polars/tests/unit/dataframe/test_repr_html.py new file mode 100644 index 000000000000..95e4fa3da958 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_repr_html.py @@ -0,0 +1,98 @@ +import pytest + +import polars as pl + + +def test_repr_html() -> None: + # check it does not panic/error, and appears to contain + # a reasonable table with suitably escaped html entities. + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "": ["a", "b", "c"], + "": ["a", "b", "c"], + } + ) + html = df._repr_html_() + for match in ( + "foo", + "<bar>", + "<baz", + "spam>", + "1", + "2", + "3", + ): + assert match in html, f"Expected to find {match!r} in html repr" + + +def test_html_tables() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + + # default: header contains names/dtypes + header = "abci64i64i64" + assert header in df._repr_html_() + + # validate that relevant config options are respected + with pl.Config(tbl_hide_column_names=True): + header = "i64i64i64" + assert header in df._repr_html_() + + with pl.Config(tbl_hide_column_data_types=True): + header = "abc" + assert header in df._repr_html_() + + with pl.Config( + tbl_hide_column_data_types=True, + tbl_hide_column_names=True, + ): + header = "" + assert header in df._repr_html_() + + +def test_df_repr_html_max_rows_default() -> None: + df = pl.DataFrame({"a": range(50)}) + + html = df._repr_html_() + + expected_rows = 10 + assert html.count("") - 2 == expected_rows + + +def test_df_repr_html_max_rows_odd() -> None: + df = pl.DataFrame({"a": range(50)}) + + with pl.Config(tbl_rows=9): + html = df._repr_html_() + + expected_rows = 9 + assert html.count("") - 2 == expected_rows + + +def test_series_repr_html_max_rows_default() -> None: + s = pl.Series("a", range(50)) + + html = s._repr_html_() + + expected_rows = 10 + assert html.count("") - 2 == expected_rows + + +@pytest.mark.parametrize( + ("text", "expected"), + [ + ("single space", "single space"), + ("multiple spaces", "multiple   spaces"), + ( + " trailing & leading spaces ", + "  trailing & leading spaces  ", + ), + ], +) +def test_html_representation_multiple_spaces(text: str, expected: str) -> None: + with pl.Config(fmt_str_lengths=100): + html_repr = pl.DataFrame({"s": [text]})._repr_html_() + assert f""{expected}"" in html_repr diff --git a/py-polars/tests/unit/dataframe/test_serde.py b/py-polars/tests/unit/dataframe/test_serde.py new file mode 100644 index 000000000000..9fa5751d8098 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_serde.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import io +from datetime import date, datetime, timedelta +from decimal import Decimal as D +from typing import TYPE_CHECKING, Any + +import pytest +from hypothesis import example, given + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import SerializationFormat + + +def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None: + serialized = df.serialize() + result = pl.DataFrame.deserialize(io.BytesIO(serialized), format="binary") + assert_frame_equal(result, df, categorical_as_str=True) + + +@given(df=dataframes()) +@example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null})) +@example(df=pl.DataFrame(schema={"a": pl.List(pl.String)})) +def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None: + serialized = df.serialize(format="json") + result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json") + + if isinstance(dt := df.to_series(0).dtype, pl.Decimal): + if dt.precision is None: + # This gets converted to precision 38 upon `to_arrow()` + pytest.skip("precision None") + + assert_frame_equal(result, df, categorical_as_str=True) + + +@pytest.mark.may_fail_auto_streaming +def test_df_serde(df: pl.DataFrame) -> None: + serialized = df.serialize() + assert isinstance(serialized, bytes) + result = pl.DataFrame.deserialize(io.BytesIO(serialized)) + assert_frame_equal(result, df) + + +@pytest.mark.may_fail_auto_streaming +def test_df_serde_json_stringio(df: pl.DataFrame) -> None: + serialized = df.serialize(format="json") + assert isinstance(serialized, str) + result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json") + assert_frame_equal(result, df) + + +def test_df_serialize_json() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a") + result = df.serialize(format="json") + + assert isinstance(result, str) + + f = io.StringIO(result) + + assert_frame_equal(pl.DataFrame.deserialize(f, format="json"), df) + + +@pytest.mark.parametrize( + ("format", "buf"), + [ + ("binary", io.BytesIO()), + ("json", io.StringIO()), + ("json", io.BytesIO()), + ], +) +def test_df_serde_to_from_buffer( + df: pl.DataFrame, format: SerializationFormat, buf: io.IOBase +) -> None: + df.serialize(buf, format=format) + buf.seek(0) + read_df = pl.DataFrame.deserialize(buf, format=format) + assert_frame_equal(df, read_df, categorical_as_str=True) + + +@pytest.mark.write_disk +def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.bin" + df.serialize(file_path) + out = pl.DataFrame.deserialize(file_path) + + assert_frame_equal(df, out, categorical_as_str=True) + + +def test_df_serde2(df: pl.DataFrame) -> None: + # Text-based conversion loses time info + df = df.select(pl.all().exclude(["cat", "time"])) + s = df.serialize() + f = io.BytesIO() + f.write(s) + f.seek(0) + out = pl.DataFrame.deserialize(f) + assert_frame_equal(out, df) + + file = io.BytesIO() + df.serialize(file) + file.seek(0) + out = pl.DataFrame.deserialize(file) + assert_frame_equal(out, df) + + +def test_df_serde_enum() -> None: + dtype = pl.Enum(["foo", "bar", "ham"]) + df = pl.DataFrame([pl.Series("e", ["foo", "bar", "ham"], dtype=dtype)]) + buf = io.BytesIO() + df.serialize(buf) + buf.seek(0) + df_in = pl.DataFrame.deserialize(buf) + assert df_in.schema["e"] == dtype + + +@pytest.mark.parametrize( + ("data", "dtype"), + [ + ([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), shape=3)), + ([["a", "b"], [None, None]], pl.Array(pl.Utf8, shape=2)), + ([[True, False, None], [None, None, None]], pl.Array(pl.Boolean, shape=3)), + ( + [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]], + pl.List(pl.Array(pl.Int32(), shape=3)), + ), + ( + [ + [datetime(1991, 1, 1), datetime(1991, 1, 1), None], + [None, None, None], + ], + pl.Array(pl.Datetime, shape=3), + ), + ( + [[D("1.0"), D("2.0"), D("3.0")], [None, None, None]], + # we have to specify precision, because `AnonymousListBuilder::finish` + # use `ArrowDataType` which will remap `None` precision to `38` + pl.Array(pl.Decimal(precision=38, scale=1), shape=3), + ), + ], +) +def test_df_serde_array(data: Any, dtype: pl.DataType) -> None: + df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) + buf = io.BytesIO() + df.serialize(buf) + buf.seek(0) + deserialized_df = pl.DataFrame.deserialize(buf) + assert_frame_equal(deserialized_df, df) + + +@pytest.mark.parametrize( + ("data", "dtype"), + [ + ( + [ + [ + datetime(1997, 10, 1), + datetime(2000, 1, 2, 10, 30, 1), + ], + [None, None], + ], + pl.Array(pl.Datetime, shape=2), + ), + ( + [[date(1997, 10, 1), date(2000, 1, 1)], [None, None]], + pl.Array(pl.Date, shape=2), + ), + ( + [ + [timedelta(seconds=1), timedelta(seconds=10)], + [None, None], + ], + pl.Array(pl.Duration, shape=2), + ), + ], +) +def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None: + df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) + buf = io.BytesIO() + df.serialize(buf) + buf.seek(0) + result = pl.DataFrame.deserialize(buf) + assert_frame_equal(result, df) + + +def test_df_serde_float_inf_nan() -> None: + df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]}) + ser = df.serialize(format="json") + result = pl.DataFrame.deserialize(io.StringIO(ser), format="json") + assert_frame_equal(result, df) + + +def test_df_serialize_invalid_type() -> None: + df = pl.DataFrame({"a": [object()]}) + with pytest.raises( + ComputeError, match="serializing data of type Object is not supported" + ): + df.serialize() + + +def test_df_serde_list_of_null_17230() -> None: + df = pl.Series([[]], dtype=pl.List(pl.Null)).to_frame() + ser = df.serialize(format="json") + result = pl.DataFrame.deserialize(io.StringIO(ser), format="json") + assert_frame_equal(result, df) diff --git a/py-polars/tests/unit/dataframe/test_shape.py b/py-polars/tests/unit/dataframe/test_shape.py new file mode 100644 index 000000000000..2409ee0c2f3f --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_shape.py @@ -0,0 +1,11 @@ +import pytest + +import polars as pl + + +# TODO: remove this skip when streaming raises +@pytest.mark.may_fail_auto_streaming +def test_raise_invalid_shape_19108() -> None: + df = pl.DataFrame({"foo": [1, 2], "bar": [3, 4]}) + with pytest.raises(pl.exceptions.ShapeError): + df.select(pl.col.foo.head(0), pl.col.bar.head(1)) diff --git a/py-polars/tests/unit/dataframe/test_to_dict.py b/py-polars/tests/unit/dataframe/test_to_dict.py new file mode 100644 index 000000000000..fe50b832a8d4 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_to_dict.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes + + +@given( + df=dataframes( + excluded_dtypes=[ + pl.Categorical, # Bug: https://github.com/pola-rs/polars/issues/16196 + pl.Struct, + ], + # Roundtrip doesn't work with time zones: + # https://github.com/pola-rs/polars/issues/16297 + allow_time_zones=False, + ) +) +def test_to_dict(df: pl.DataFrame) -> None: + d = df.to_dict(as_series=False) + result = pl.from_dict(d, schema=df.schema) + assert_frame_equal(df, result, categorical_as_str=True) + + +@pytest.mark.parametrize( + ("as_series", "inner_dtype"), + [ + (True, pl.Series), + (False, list), + ], +) +def test_to_dict_misc(as_series: bool, inner_dtype: Any) -> None: + df = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + "optional": [28, 300, None, 2, -30], + } + ) + s = df.to_dict(as_series=as_series) + assert isinstance(s, dict) + for v in s.values(): + assert isinstance(v, inner_dtype) + assert len(v) == df.height diff --git a/py-polars/tests/unit/dataframe/test_upsample.py b/py-polars/tests/unit/dataframe/test_upsample.py new file mode 100644 index 000000000000..46a0a8dd6b44 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_upsample.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from datetime import timezone + + from polars._typing import FillNullStrategy, PolarsIntegerType + + +@pytest.mark.parametrize( + ("time_zone", "tzinfo"), + [ + (None, None), + ("Europe/Warsaw", ZoneInfo("Europe/Warsaw")), + ], +) +def test_upsample(time_zone: str | None, tzinfo: ZoneInfo | timezone | None) -> None: + df = pl.DataFrame( + { + "time": [ + datetime(2021, 2, 1), + datetime(2021, 4, 1), + datetime(2021, 5, 1), + datetime(2021, 6, 1), + ], + "admin": ["Åland", "Netherlands", "Åland", "Netherlands"], + "test2": [0, 1, 2, 3], + } + ).with_columns(pl.col("time").dt.replace_time_zone(time_zone).set_sorted()) + + up = df.upsample( + time_column="time", + every="1mo", + group_by="admin", + maintain_order=True, + ).select(pl.all().fill_null(strategy="forward")) + + # this print will panic if timezones feature is not activated + # don't remove + print(up) + + expected = pl.DataFrame( + { + "time": [ + datetime(2021, 2, 1, 0, 0), + datetime(2021, 3, 1, 0, 0), + datetime(2021, 4, 1, 0, 0), + datetime(2021, 5, 1, 0, 0), + datetime(2021, 4, 1, 0, 0), + datetime(2021, 5, 1, 0, 0), + datetime(2021, 6, 1, 0, 0), + ], + "admin": [ + "Åland", + "Åland", + "Åland", + "Åland", + "Netherlands", + "Netherlands", + "Netherlands", + ], + "test2": [0, 0, 0, 2, 1, 1, 3], + } + ) + expected = expected.with_columns(pl.col("time").dt.replace_time_zone(time_zone)) + + assert_frame_equal(up, expected) + + +@pytest.mark.parametrize("time_zone", [None, "US/Central"]) +def test_upsample_crossing_dst(time_zone: str | None) -> None: + df = pl.DataFrame( + { + "time": pl.datetime_range( + datetime(2021, 11, 6), + datetime(2021, 11, 8), + time_zone=time_zone, + eager=True, + ), + "values": [1, 2, 3], + } + ) + + result = df.upsample(time_column="time", every="1d") + + expected = pl.DataFrame( + { + "time": [ + datetime(2021, 11, 6), + datetime(2021, 11, 7), + datetime(2021, 11, 8), + ], + "values": [1, 2, 3], + } + ).with_columns(pl.col("time").dt.replace_time_zone(time_zone)) + + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("time_zone", "tzinfo"), + [ + (None, None), + ("Pacific/Rarotonga", ZoneInfo("Pacific/Rarotonga")), + ], +) +def test_upsample_time_zones( + time_zone: str | None, tzinfo: timezone | ZoneInfo | None +) -> None: + df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + "values": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + } + ) + expected = pl.DataFrame( + { + "time": [ + datetime(2021, 12, 16, 0, 0), + datetime(2021, 12, 16, 1, 0), + datetime(2021, 12, 16, 2, 0), + datetime(2021, 12, 16, 3, 0), + ], + "groups": ["a", "a", "b", "a"], + "values": [1.0, 3.0, 5.0, 7.0], + } + ) + df = df.with_columns(pl.col("time").dt.replace_time_zone(time_zone)) + expected = expected.with_columns(pl.col("time").dt.replace_time_zone(time_zone)) + result = df.upsample(time_column="time", every="60m").fill_null(strategy="forward") + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("every", "fill", "expected_index", "expected_groups"), + [ + ( + "1i", + "forward", + [1, 2, 3, 4] + [5, 6, 7], + ["a"] * 4 + ["b"] * 3, + ), + ( + "1i", + "backward", + [1, 2, 3, 4] + [5, 6, 7], + ["a"] * 4 + ["b"] * 3, + ), + ], +) +@pytest.mark.parametrize("dtype", [pl.Int32, pl.Int64, pl.UInt32, pl.UInt64]) +def test_upsample_index( + every: str, + fill: FillNullStrategy | None, + expected_index: list[int], + expected_groups: list[str], + dtype: PolarsIntegerType, +) -> None: + df = ( + pl.DataFrame( + { + "index": [1, 2, 4] + [5, 7], + "groups": ["a"] * 3 + ["b"] * 2, + } + ) + .with_columns(pl.col("index").cast(dtype)) + .set_sorted("index") + ) + expected = pl.DataFrame( + { + "index": expected_index, + "groups": expected_groups, + } + ).with_columns(pl.col("index").cast(dtype)) + result = ( + df.upsample(time_column="index", group_by="groups", every=every) + .fill_null(strategy=fill) + .sort(["groups", "index"]) + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("maintain_order", [True, False]) +def test_upsample_index_invalid( + df: pl.DataFrame, + maintain_order: bool, +) -> None: + df = pl.DataFrame( + { + "index": [1, 2, 4, 5, 7], + "groups": ["a"] * 3 + ["b"] * 2, + } + ).set_sorted("index") + + with pytest.raises(InvalidOperationError, match=r"must be a parsed integer"): + df.upsample( + time_column="index", + every="1h", + maintain_order=maintain_order, + ) + + +def test_upsample_sorted_only_within_group() -> None: + df = pl.DataFrame( + { + "time": [ + datetime(2021, 4, 1), + datetime(2021, 2, 1), + datetime(2021, 5, 1), + datetime(2021, 6, 1), + ], + "admin": ["Netherlands", "Åland", "Åland", "Netherlands"], + "test2": [1, 0, 2, 3], + } + ) + + up = df.upsample( + time_column="time", + every="1mo", + group_by="admin", + maintain_order=True, + ).select(pl.all().fill_null(strategy="forward")) + + expected = pl.DataFrame( + { + "time": [ + datetime(2021, 4, 1, 0, 0), + datetime(2021, 5, 1, 0, 0), + datetime(2021, 6, 1, 0, 0), + datetime(2021, 2, 1, 0, 0), + datetime(2021, 3, 1, 0, 0), + datetime(2021, 4, 1, 0, 0), + datetime(2021, 5, 1, 0, 0), + ], + "admin": [ + "Netherlands", + "Netherlands", + "Netherlands", + "Åland", + "Åland", + "Åland", + "Åland", + ], + "test2": [1, 1, 3, 0, 0, 0, 2], + } + ) + + assert_frame_equal(up, expected) + + +def test_upsample_sorted_only_within_group_but_no_group_by_provided() -> None: + df = pl.DataFrame( + { + "time": [ + datetime(2021, 4, 1), + datetime(2021, 2, 1), + datetime(2021, 5, 1), + datetime(2021, 6, 1), + ], + "admin": ["Netherlands", "Åland", "Åland", "Netherlands"], + "test2": [1, 0, 2, 3], + } + ) + with pytest.raises( + InvalidOperationError, + match=r"argument in operation 'upsample' is not sorted, please sort the 'expr/series/column' first", + ): + df.upsample(time_column="time", every="1mo") diff --git a/py-polars/tests/unit/dataframe/test_vstack.py b/py-polars/tests/unit/dataframe/test_vstack.py new file mode 100644 index 000000000000..8649471f6099 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_vstack.py @@ -0,0 +1,83 @@ +import pytest + +import polars as pl +from polars.exceptions import SchemaError, ShapeError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def df1() -> pl.DataFrame: + return pl.DataFrame({"foo": [1, 2], "bar": [6, 7], "ham": ["a", "b"]}) + + +@pytest.fixture +def df2() -> pl.DataFrame: + return pl.DataFrame({"foo": [3, 4], "bar": [8, 9], "ham": ["c", "d"]}) + + +def test_vstack(df1: pl.DataFrame, df2: pl.DataFrame) -> None: + result = df1.vstack(df2) + expected = pl.DataFrame( + {"foo": [1, 2, 3, 4], "bar": [6, 7, 8, 9], "ham": ["a", "b", "c", "d"]} + ) + assert_frame_equal(result, expected) + + +def test_vstack_in_place(df1: pl.DataFrame, df2: pl.DataFrame) -> None: + df1.vstack(df2, in_place=True) + expected = pl.DataFrame( + {"foo": [1, 2, 3, 4], "bar": [6, 7, 8, 9], "ham": ["a", "b", "c", "d"]} + ) + assert_frame_equal(df1, expected) + + +def test_vstack_self(df1: pl.DataFrame) -> None: + result = df1.vstack(df1) + expected = pl.DataFrame( + {"foo": [1, 2, 1, 2], "bar": [6, 7, 6, 7], "ham": ["a", "b", "a", "b"]} + ) + assert_frame_equal(result, expected) + + +def test_vstack_self_in_place(df1: pl.DataFrame) -> None: + df1.vstack(df1, in_place=True) + expected = pl.DataFrame( + {"foo": [1, 2, 1, 2], "bar": [6, 7, 6, 7], "ham": ["a", "b", "a", "b"]} + ) + assert_frame_equal(df1, expected) + + +def test_vstack_column_number_mismatch(df1: pl.DataFrame) -> None: + df2 = df1.drop("ham") + + with pytest.raises(ShapeError): + df1.vstack(df2) + + +def test_vstack_column_name_mismatch(df1: pl.DataFrame) -> None: + df2 = df1.with_columns(pl.col("foo").alias("oof")) + + with pytest.raises(ShapeError): + df1.vstack(df2) + + +def test_vstack_with_null_column() -> None: + df1 = pl.DataFrame({"x": [3.5]}, schema={"x": pl.Float64}) + df2 = pl.DataFrame({"x": [None]}, schema={"x": pl.Null}) + + result = df1.vstack(df2) + expected = pl.DataFrame({"x": [3.5, None]}, schema={"x": pl.Float64}) + + assert_frame_equal(result, expected) + + with pytest.raises(SchemaError): + df2.vstack(df1) + + +def test_vstack_with_nested_nulls() -> None: + a = pl.DataFrame({"x": [[3.5]]}, schema={"x": pl.List(pl.Float32)}) + b = pl.DataFrame({"x": [[None]]}, schema={"x": pl.List(pl.Null)}) + + out = a.vstack(b) + expected = pl.DataFrame({"x": [[3.5], [None]]}, schema={"x": pl.List(pl.Float32)}) + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/datatypes/__init__.py b/py-polars/tests/unit/datatypes/__init__.py new file mode 100644 index 000000000000..1a7dfc2793d2 --- /dev/null +++ b/py-polars/tests/unit/datatypes/__init__.py @@ -0,0 +1 @@ +"""Test module for testing behaviour of specific data types in various operations.""" diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py new file mode 100644 index 000000000000..b55a5d253287 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -0,0 +1,404 @@ +import datetime +from datetime import timedelta +from typing import Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_cast_list_array() -> None: + payload = [[1, 2, 3], [4, 2, 3]] + s = pl.Series(payload) + + dtype = pl.Array(pl.Int64, 3) + out = s.cast(dtype) + assert out.dtype == dtype + assert out.to_list() == payload + assert_series_equal(out.cast(pl.List(pl.Int64)), s) + + # width is incorrect + with pytest.raises( + ComputeError, match=r"not all elements have the specified width" + ): + s.cast(pl.Array(pl.Int64, 2)) + + +def test_array_in_group_by() -> None: + df = pl.DataFrame( + [ + pl.Series("id", [1, 2]), + pl.Series("list", [[1, 2], [5, 5]], dtype=pl.Array(pl.UInt8, 2)), + ] + ) + + result = next(iter(df.group_by(["id"], maintain_order=True)))[1]["list"] + assert result.to_list() == [[1, 2]] + + df = pl.DataFrame( + {"a": [[1, 2], [2, 2], [1, 4]], "g": [1, 1, 2]}, + schema={"a": pl.Array(pl.Int64, 2), "g": pl.Int64}, + ) + + out0 = df.group_by("g").agg(pl.col("a")).sort("g") + out1 = df.set_sorted("g").group_by("g").agg(pl.col("a")) + + for out in [out0, out1]: + assert out.schema == { + "g": pl.Int64, + "a": pl.List(pl.Array(pl.Int64, 2)), + } + assert out.to_dict(as_series=False) == { + "g": [1, 2], + "a": [[[1, 2], [2, 2]], [[1, 4]]], + } + + +def test_array_invalid_operation() -> None: + s = pl.Series( + [[1, 2], [8, 9]], + dtype=pl.Array(pl.Int32, 2), + ) + with pytest.raises( + InvalidOperationError, + match=r"`sign` operation not supported for dtype `array\[", + ): + s.sign() + + +def test_array_concat() -> None: + a_df = pl.DataFrame({"a": [[0, 1], [1, 0]]}).select( + pl.col("a").cast(pl.Array(pl.Int32, 2)) + ) + b_df = pl.DataFrame({"a": [[1, 1], [0, 0]]}).select( + pl.col("a").cast(pl.Array(pl.Int32, 2)) + ) + assert pl.concat([a_df, b_df]).to_dict(as_series=False) == { + "a": [[0, 1], [1, 0], [1, 1], [0, 0]] + } + + +def test_array_equal_and_not_equal() -> None: + left = pl.Series([[1, 2], [3, 5]], dtype=pl.Array(pl.Int64, 2)) + right = pl.Series([[1, 2], [3, 1]], dtype=pl.Array(pl.Int64, 2)) + assert_series_equal(left == right, pl.Series([True, False])) + assert_series_equal(left.eq_missing(right), pl.Series([True, False])) + assert_series_equal(left != right, pl.Series([False, True])) + assert_series_equal(left.ne_missing(right), pl.Series([False, True])) + + left = pl.Series([[1, None], [3, None]], dtype=pl.Array(pl.Int64, 2)) + right = pl.Series([[1, None], [3, 4]], dtype=pl.Array(pl.Int64, 2)) + assert_series_equal(left == right, pl.Series([True, False])) + assert_series_equal(left.eq_missing(right), pl.Series([True, False])) + assert_series_equal(left != right, pl.Series([False, True])) + assert_series_equal(left.ne_missing(right), pl.Series([False, True])) + + # TODO: test eq_missing with nulled arrays, rather than null elements. + + +def test_array_list_supertype() -> None: + s1 = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int64, 2)) + s2 = pl.Series([[1.0, 2.0], [3.0, 4.5]], dtype=pl.List(inner=pl.Float64)) + + result = s1 == s2 + + expected = pl.Series([True, False]) + assert_series_equal(result, expected) + + +def test_array_in_list() -> None: + s = pl.Series( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + dtype=pl.List(pl.Array(pl.Int8, 2)), + ) + assert s.dtype == pl.List(pl.Array(pl.Int8, 2)) + + +def test_array_data_type_equality() -> None: + assert pl.Array(pl.Int64, 2) == pl.Array + assert pl.Array(pl.Int64, 2) == pl.Array(pl.Int64, 2) + assert pl.Array(pl.Int64, 2) != pl.Array(pl.Int64, 3) + assert pl.Array(pl.Int64, 2) != pl.Array(pl.String, 2) + assert pl.Array(pl.Int64, 2) != pl.List(pl.Int64) + + assert pl.Array(pl.Int64, (4, 2)) == pl.Array + assert pl.Array(pl.Array(pl.Int64, 2), 4) == pl.Array(pl.Int64, (4, 2)) + assert pl.Array(pl.Int64, (4, 2)) == pl.Array(pl.Int64, (4, 2)) + assert pl.Array(pl.Int64, (4, 2)) != pl.Array(pl.String, (4, 2)) + assert pl.Array(pl.Int64, (4, 2)) != pl.Array(pl.Int64, 4) + assert pl.Array(pl.Int64, (4,)) != pl.Array(pl.Int64, (4, 2)) + + +@pytest.mark.parametrize( + ("data", "inner_type"), + [ + ([[1, 2], None, [3, None], [None, None]], pl.Int64), + ([[True, False], None, [True, None], [None, None]], pl.Boolean), + ([[1.0, 2.0], None, [3.0, None], [None, None]], pl.Float32), + ([["a", "b"], None, ["c", None], [None, None]], pl.String), + ( + [ + [datetime.datetime(2021, 1, 1), datetime.datetime(2022, 1, 1, 10, 30)], + None, + [datetime.datetime(2023, 12, 25), None], + [None, None], + ], + pl.Datetime, + ), + ( + [ + [datetime.date(2021, 1, 1), datetime.date(2022, 1, 15)], + None, + [datetime.date(2023, 12, 25), None], + [None, None], + ], + pl.Date, + ), + ( + [ + [datetime.timedelta(10), datetime.timedelta(1, 22)], + None, + [datetime.timedelta(20), None], + [None, None], + ], + pl.Duration, + ), + ([[[1, 2], None], None, [[3], None], [None, None]], pl.List(pl.Int32)), + ], +) +def test_cast_list_to_array(data: Any, inner_type: pl.DataType) -> None: + s = pl.Series(data, dtype=pl.List(inner_type)) + s = s.cast(pl.Array(inner_type, 2)) + assert s.dtype == pl.Array(inner_type, shape=2) + assert s.to_list() == data + + +@pytest.fixture +def data_dispersion() -> pl.DataFrame: + return pl.DataFrame( + { + "int": [[1, 2, 3, 4, 5]], + "float": [[1.0, 2.0, 3.0, 4.0, 5.0]], + "duration": [[1000, 2000, 3000, 4000, 5000]], + }, + schema={ + "int": pl.Array(pl.Int64, 5), + "float": pl.Array(pl.Float64, 5), + "duration": pl.Array(pl.Duration, 5), + }, + ) + + +def test_arr_var(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").arr.var().name.suffix("_var"), + pl.col("float").arr.var().name.suffix("_var"), + pl.col("duration").arr.var().name.suffix("_var"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_var", [2.5], dtype=pl.Float64), + pl.Series("float_var", [2.5], dtype=pl.Float64), + pl.Series( + "duration_var", + [timedelta(microseconds=2000)], + dtype=pl.Duration(time_unit="ms"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_arr_std(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").arr.std().name.suffix("_std"), + pl.col("float").arr.std().name.suffix("_std"), + pl.col("duration").arr.std().name.suffix("_std"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series("float_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series( + "duration_std", + [timedelta(microseconds=1581)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_arr_median(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").arr.median().name.suffix("_median"), + pl.col("float").arr.median().name.suffix("_median"), + pl.col("duration").arr.median().name.suffix("_median"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_median", [3.0], dtype=pl.Float64), + pl.Series("float_median", [3.0], dtype=pl.Float64), + pl.Series( + "duration_median", + [timedelta(microseconds=3000)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_array_repeat() -> None: + dtype = pl.Array(pl.UInt8, shape=1) + s = pl.repeat([42], n=3, dtype=dtype, eager=True) + expected = pl.Series("repeat", [[42], [42], [42]], dtype=dtype) + assert s.dtype == dtype + assert_series_equal(s, expected) + + +def test_create_nested_array() -> None: + data = [[[1, 2], [3]], [[], [4, None]], None] + s1 = pl.Series(data, dtype=pl.Array(pl.List(pl.Int64), 2)) + assert s1.to_list() == data + data = [[[1, 2], [3, None]], [[None, None], [4, None]], None] + s2 = pl.Series( + [[[1, 2], [3, None]], [[None, None], [4, None]], None], + dtype=pl.Array(pl.Array(pl.Int64, 2), 2), + ) + assert s2.to_list() == data + + +def test_recursive_array_dtype() -> None: + assert str(pl.Array(pl.Int64, (2, 3))) == "Array(Int64, shape=(2, 3))" + assert str(pl.Array(pl.Int64, 3)) == "Array(Int64, shape=(3,))" + dtype = pl.Array(pl.Int64, 3) + s = pl.Series(np.arange(6).reshape((2, 3)), dtype=dtype) + assert s.dtype == dtype + assert s.len() == 2 + dtype = pl.Array(pl.List(pl.Array(pl.Int8, (2, 2))), 2) + s = pl.Series(dtype=dtype) + assert s.dtype == dtype + assert str(s) == "shape: (0,)\nSeries: '' [array[list[array[i8, (2, 2)]], 2]]\n[\n]" + + +def test_ndarray_construction() -> None: + a = np.arange(16, dtype=np.int64).reshape((2, 4, -1)) + s = pl.Series(a) + assert s.dtype == pl.Array(pl.Int64, (4, 2)) + assert (s.to_numpy() == a).all() + + +def test_array_width_deprecated() -> None: + with pytest.deprecated_call(): + dtype = pl.Array(pl.Int8, width=2) + with pytest.deprecated_call(): + assert dtype.width == 2 + + +def test_array_inner_recursive() -> None: + shape = (2, 3, 4, 5) + dtype = pl.Array(int, shape=shape) + for dim in shape: + assert dtype.size == dim + dtype = dtype.inner # type: ignore[assignment] + + +def test_array_inner_recursive_python_dtype() -> None: + dtype = pl.Array(int, shape=(2, 3)) + assert dtype.inner.inner == pl.Int64 # type: ignore[union-attr] + + +def test_array_missing_shape() -> None: + with pytest.raises(TypeError): + pl.Array(pl.Int8) + + +def test_array_invalid_shape_type() -> None: + with pytest.raises(TypeError, match="invalid input for shape"): + pl.Array(pl.Int8, shape=("x",)) # type: ignore[arg-type] + + +def test_array_invalid_physical_type_18920() -> None: + s1 = pl.Series("x", [[1000, 2000]], pl.List(pl.Datetime)) + s2 = pl.Series("x", [None], pl.List(pl.Datetime)) + + df1 = s1.to_frame().with_columns(pl.col.x.list.to_array(2)) + df2 = s2.to_frame().with_columns(pl.col.x.list.to_array(2)) + + df = pl.concat([df1, df2]) + + expected_s = pl.Series("x", [[1000, 2000], None], pl.List(pl.Datetime)) + + expected = expected_s.to_frame().with_columns(pl.col.x.list.to_array(2)) + assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + "fn", + [ + "__add__", + "__sub__", + "__mul__", + "__truediv__", + "__mod__", + "__eq__", + "__ne__", + ], +) +def test_zero_width_array(fn: str) -> None: + series_f = getattr(pl.Series, fn) + expr_f = getattr(pl.Expr, fn) + + values = [ + [ + [[]], + [None], + ], + [ + [[], []], + [None, []], + [[], None], + [None, None], + ], + ] + + for vs in values: + for lhs in vs: + for rhs in vs: + a = pl.Series("a", lhs, pl.Array(pl.Int8, 0)) + b = pl.Series("b", rhs, pl.Array(pl.Int8, 0)) + + series_f(a, b) + + df = pl.concat([a.to_frame(), b.to_frame()], how="horizontal") + df.select(c=expr_f(pl.col.a, pl.col.b)) + + +def test_sort() -> None: + def tc(a: list[Any], b: list[Any], w: int) -> None: + a_s = pl.Series("l", a, pl.Array(pl.Int64, w)) + b_s = pl.Series("l", b, pl.Array(pl.Int64, w)) + + assert_series_equal(a_s.sort(), b_s) + + tc([], [], 1) + tc([[1]], [[1]], 1) + tc([[2], [1]], [[1], [2]], 1) + tc([[2, 1]], [[2, 1]], 2) + tc([[2, 1], [1, 2]], [[1, 2], [2, 1]], 2) diff --git a/py-polars/tests/unit/datatypes/test_binary.py b/py-polars/tests/unit/datatypes/test_binary.py new file mode 100644 index 000000000000..a19f1368ddba --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_binary.py @@ -0,0 +1,39 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_binary_filter() -> None: + df = pl.DataFrame( + { + "name": ["a", "b", "c", "d"], + "content": [b"aa", b"aaabbb", b"aa", b"\xc6i\xea"], + } + ) + assert df.filter(pl.col("content") == b"\xc6i\xea").to_dict(as_series=False) == { + "name": ["d"], + "content": [b"\xc6i\xea"], + } + + +def test_binary_to_list() -> None: + data = {"binary": [b"\xfd\x00\xfe\x00\xff\x00", b"\x10\x00\x20\x00\x30\x00"]} + schema = {"binary": pl.Binary} + + print(pl.DataFrame(data, schema)) + df = pl.DataFrame(data, schema).with_columns( + pl.col("binary").cast(pl.List(pl.UInt8)) + ) + + expected = pl.DataFrame( + {"binary": [[253, 0, 254, 0, 255, 0], [16, 0, 32, 0, 48, 0]]}, + schema={"binary": pl.List(pl.UInt8)}, + ) + print(df) + assert_frame_equal(df, expected) + + +def test_string_to_binary() -> None: + s = pl.Series("data", ["", None, "\x01\x02"]) + + assert s.cast(pl.Binary).to_list() == [b"", None, b"\x01\x02"] + assert s.cast(pl.Binary).cast(pl.Utf8).to_list() == ["", None, "\x01\x02"] diff --git a/py-polars/tests/unit/datatypes/test_bool.py b/py-polars/tests/unit/datatypes/test_bool.py new file mode 100644 index 000000000000..981167838e56 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_bool.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl + + +@pytest.mark.slow +def test_bool_arg_min_max() -> None: + # masks that ensures we take more than u64 chunks + # and slicing and dicing to ensure the offsets work + for _ in range(100): + offset = np.random.randint(0, 100) + sample = np.random.rand(1000) + a = sample > 0.99 + idx = a[offset:].argmax() + assert idx == pl.Series(a)[offset:].arg_max() + idx = a[offset:].argmin() + assert idx == pl.Series(a)[offset:].arg_min() + + a = sample > 0.01 + idx = a[offset:].argmax() + assert idx == pl.Series(a)[offset:].arg_max() + idx = a[offset:].argmin() + assert idx == pl.Series(a)[offset:].arg_min() + + +def test_bool_sum_empty() -> None: + assert pl.Series([], dtype=pl.Boolean).sum() == 0 + + +def test_bool_min_max() -> None: + assert pl.Series([None, True]).min() + assert not pl.Series([None, True, False]).min() + assert not pl.Series([False, True]).min() + assert pl.Series([True, True]).min() + assert not pl.Series([False, False]).min() + assert pl.Series([None, True]).max() + assert pl.Series([None, True, False]).max() + assert pl.Series([False, True]).max() + assert pl.Series([True, True]).max() + assert not pl.Series([False, False]).max() + + +def test_bool_literal_expressions() -> None: + df = pl.DataFrame({"x": [False, True]}) + + def val(expr: pl.Expr) -> dict[str, list[bool]]: + return df.select(expr).to_dict(as_series=False) + + assert val(pl.col("x") & False) == {"x": [False, False]} + assert val(pl.col("x") & True) == {"x": [False, True]} + assert val(pl.col("x") | False) == {"x": [False, True]} + assert val(pl.col("x") | True) == {"x": [True, True]} + assert val(pl.col("x") ^ False) == {"x": [False, True]} + assert val(pl.col("x") ^ True) == {"x": [True, False]} + assert val(False & pl.col("x")) == {"literal": [False, False]} + assert val(True & pl.col("x")) == {"literal": [False, True]} + assert val(False | pl.col("x")) == {"literal": [False, True]} + assert val(True | pl.col("x")) == {"literal": [True, True]} + assert val(False ^ pl.col("x")) == {"literal": [False, True]} + assert val(True ^ pl.col("x")) == {"literal": [True, False]} diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py new file mode 100644 index 000000000000..f92f361df801 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -0,0 +1,1023 @@ +from __future__ import annotations + +import contextlib +import io +import operator +from typing import TYPE_CHECKING, Any, Callable, Literal + +import pytest + +import polars as pl +from polars import StringCache +from polars.exceptions import ( + CategoricalRemappingWarning, + ComputeError, + StringCacheMismatchError, +) +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import with_string_cache_if_auto_streaming + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@StringCache() +def test_categorical_full_outer_join() -> None: + df1 = pl.DataFrame( + [ + pl.Series("key1", [42]), + pl.Series("key2", ["bar"], dtype=pl.Categorical), + pl.Series("val1", [1]), + ] + ).lazy() + + df2 = pl.DataFrame( + [ + pl.Series("key1", [42]), + pl.Series("key2", ["bar"], dtype=pl.Categorical), + pl.Series("val2", [2]), + ] + ).lazy() + + expected = pl.DataFrame( + { + "key1": [42], + "key2": ["bar"], + "val1": [1], + "key1_right": [42], + "key2_right": ["bar"], + "val2": [2], + }, + schema_overrides={"key2": pl.Categorical, "key2_right": pl.Categorical}, + ) + + out = df1.join(df2, on=["key1", "key2"], how="full").collect() + assert_frame_equal(out, expected) + + dfa = pl.DataFrame( + [ + pl.Series("key", ["foo", "bar"], dtype=pl.Categorical), + pl.Series("val1", [3, 1]), + ] + ) + dfb = pl.DataFrame( + [ + pl.Series("key", ["bar", "baz"], dtype=pl.Categorical), + pl.Series("val2", [6, 8]), + ] + ) + + df = dfa.join(dfb, on="key", how="full", maintain_order="right_left") + # the cast is important to test the rev map + assert df["key"].cast(pl.String).to_list() == ["bar", None, "foo"] + assert df["key_right"].cast(pl.String).to_list() == ["bar", "baz", None] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_read_csv_categorical() -> None: + f = io.BytesIO() + f.write(b"col1,col2,col3,col4,col5,col6\n'foo',2,3,4,5,6\n'bar',8,9,10,11,12") + f.seek(0) + df = pl.read_csv(f, has_header=True, schema_overrides={"col1": pl.Categorical}) + assert df["col1"].dtype == pl.Categorical + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cat_to_dummies() -> None: + df = pl.DataFrame({"foo": [1, 2, 3, 4], "bar": ["a", "b", "a", "c"]}) + df = df.with_columns(pl.col("bar").cast(pl.Categorical)) + assert df.to_dummies().to_dict(as_series=False) == { + "foo_1": [1, 0, 0, 0], + "foo_2": [0, 1, 0, 0], + "foo_3": [0, 0, 1, 0], + "foo_4": [0, 0, 0, 1], + "bar_a": [1, 0, 1, 0], + "bar_b": [0, 1, 0, 0], + "bar_c": [0, 0, 0, 1], + } + + +@pytest.mark.may_fail_auto_streaming +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_is_in_list() -> None: + # this requires type coercion to cast. + # we should not cast within the function as this would be expensive within a + # group by context that would be a cast per group + df = pl.DataFrame( + {"a": [1, 2, 3, 1, 2], "b": ["a", "b", "c", "d", "e"]} + ).with_columns(pl.col("b").cast(pl.Categorical)) + + cat_list = ("a", "b", "c") + assert df.filter(pl.col("b").is_in(cat_list)).to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + } + + +@pytest.mark.usefixtures("test_global_and_local") +@with_string_cache_if_auto_streaming +def test_unset_sorted_on_append() -> None: + df1 = pl.DataFrame( + [ + pl.Series("key", ["a", "b", "a", "b"], dtype=pl.Categorical), + pl.Series("val", [1, 2, 3, 4]), + ] + ).sort("key") + df2 = pl.DataFrame( + [ + pl.Series("key", ["a", "b", "a", "b"], dtype=pl.Categorical), + pl.Series("val", [5, 6, 7, 8]), + ] + ).sort("key") + df = pl.concat([df1, df2], rechunk=False) + assert df.group_by("key").len()["len"].to_list() == [4, 4] + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.eq, pl.Series([True, True, True, False, None, None])), + (operator.ne, pl.Series([False, False, False, True, None, None])), + (pl.Series.ne_missing, pl.Series([False, False, False, True, True, True])), + (pl.Series.eq_missing, pl.Series([True, True, True, False, False, False])), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_equality( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series(["a", "b", "c", "c", None, None], dtype=pl.Categorical) + s2 = pl.Series("b_cat", ["a", "b", "c", "a", "b", "c"], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected) + assert_series_equal(op(s, s2.cast(pl.String)), expected) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.eq, pl.Series([False, False, False, False, None, None])), + (operator.ne, pl.Series([True, True, True, True, None, None])), + (pl.Series.eq_missing, pl.Series([False, False, False, False, False, False])), + (pl.Series.ne_missing, pl.Series([True, True, True, True, True, True])), + ], +) +@StringCache() +def test_categorical_equality_global_fastpath( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series(["a", "b", "c", "c", None, None], dtype=pl.Categorical) + s2 = pl.Series(["d"], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected) + assert_series_equal(op(s, s2.cast(pl.String)), expected) + + +@pytest.mark.parametrize( + ("op", "expected_phys", "expected_lexical"), + [ + ( + operator.le, + pl.Series([True, True, True, True, False]), + pl.Series([False, True, True, False, True]), + ), + ( + operator.lt, + pl.Series([True, False, False, True, False]), + pl.Series([False, False, False, False, True]), + ), + ( + operator.ge, + pl.Series([False, True, True, False, True]), + pl.Series([True, True, True, True, False]), + ), + ( + operator.gt, + pl.Series([False, False, False, False, True]), + pl.Series([True, False, False, True, False]), + ), + ], +) +@StringCache() +def test_categorical_global_ordering( + op: Callable[[pl.Series, pl.Series], pl.Series], + expected_phys: pl.Series, + expected_lexical: pl.Series, +) -> None: + s = pl.Series(["z", "b", "c", "c", "a"], dtype=pl.Categorical) + s2 = pl.Series("b_cat", ["a", "b", "c", "a", "c"], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected_phys) + + s = s.cast(pl.Categorical("lexical")) + s2 = s2.cast(pl.Categorical("lexical")) + assert_series_equal(op(s, s2), expected_lexical) + + +@pytest.mark.parametrize( + ("op", "expected_phys", "expected_lexical"), + [ + (operator.le, pl.Series([True, True, False]), pl.Series([False, True, False])), + ( + operator.lt, + pl.Series([True, False, False]), + pl.Series([False, False, False]), + ), + (operator.ge, pl.Series([False, True, True]), pl.Series([True, True, True])), + (operator.gt, pl.Series([False, False, True]), pl.Series([True, False, True])), + ], +) +@StringCache() +def test_categorical_global_ordering_broadcast_rhs( + op: Callable[[pl.Series, pl.Series], pl.Series], + expected_phys: pl.Series, + expected_lexical: pl.Series, +) -> None: + s = pl.Series(["c", "a", "b"], dtype=pl.Categorical) + s2 = pl.Series("b_cat", ["a"], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected_phys) + + s = s.cast(pl.Categorical("lexical")) + s2 = s2.cast(pl.Categorical("lexical")) + assert_series_equal(op(s, s2), expected_lexical) + assert_series_equal(op(s, s2.cast(pl.String)), expected_lexical) + + +@pytest.mark.parametrize( + ("op", "expected_phys", "expected_lexical"), + [ + (operator.le, pl.Series([True, True, True]), pl.Series([True, False, True])), + (operator.lt, pl.Series([True, True, False]), pl.Series([True, False, False])), + (operator.ge, pl.Series([False, False, True]), pl.Series([False, True, True])), + ( + operator.gt, + pl.Series([False, False, False]), + pl.Series([False, True, False]), + ), + ], +) +@StringCache() +def test_categorical_global_ordering_broadcast_lhs( + op: Callable[[pl.Series, pl.Series], pl.Series], + expected_phys: pl.Series, + expected_lexical: pl.Series, +) -> None: + s = pl.Series(["b"], dtype=pl.Categorical) + s2 = pl.Series(["c", "a", "b"], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected_phys) + + s = s.cast(pl.Categorical("lexical")) + s2 = s2.cast(pl.Categorical("lexical")) + assert_series_equal(op(s, s2), expected_lexical) + assert_series_equal(op(s, s2.cast(pl.String)), expected_lexical) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([True, True, True, False, True, True])), + (operator.lt, pl.Series([False, False, False, False, True, False])), + (operator.ge, pl.Series([True, True, True, True, False, True])), + (operator.gt, pl.Series([False, False, False, True, False, False])), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_ordering( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series(["a", "b", "c", "c", "a", "b"], dtype=pl.Categorical) + s2 = pl.Series("b_cat", ["a", "b", "c", "a", "c", "b"], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, True, True, True, True, True])), + (operator.lt, pl.Series([None, False, False, False, True, True])), + (operator.ge, pl.Series([None, True, True, True, False, False])), + (operator.gt, pl.Series([None, False, False, False, False, False])), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_compare_categorical( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series([None, "a", "b", "c", "b", "a"], dtype=pl.Categorical) + s2 = pl.Series([None, "a", "b", "c", "c", "b"]) + + assert_series_equal(op(s, s2), expected) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, True, True, False, True, True])), + (operator.lt, pl.Series([None, True, False, False, False, True])), + (operator.ge, pl.Series([None, False, True, True, True, False])), + (operator.gt, pl.Series([None, False, False, True, False, False])), + (operator.eq, pl.Series([None, False, True, False, True, False])), + (operator.ne, pl.Series([None, True, False, True, False, True])), + (pl.Series.eq_missing, pl.Series([False, False, True, False, True, False])), + (pl.Series.ne_missing, pl.Series([True, True, False, True, False, True])), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_compare_categorical_single( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series([None, "a", "b", "c", "b", "a"], dtype=pl.Categorical) + s2 = "b" + + assert_series_equal(op(s, s2), expected) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, True, True, True, True, True])), + (operator.lt, pl.Series([None, True, True, True, True, True])), + (operator.ge, pl.Series([None, False, False, False, False, False])), + (operator.gt, pl.Series([None, False, False, False, False, False])), + (operator.eq, pl.Series([None, False, False, False, False, False])), + (operator.ne, pl.Series([None, True, True, True, True, True])), + (pl.Series.ne_missing, pl.Series([True, True, True, True, True, True])), + (pl.Series.eq_missing, pl.Series([False, False, False, False, False, False])), + ], +) +@StringCache() +def test_compare_categorical_single_non_existent( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series([None, "a", "b", "c", "b", "a"], dtype=pl.Categorical) + s2 = "d" + assert_series_equal(op(s, s2), expected) # type: ignore[arg-type] + s_cat = pl.Series(["d"], dtype=pl.Categorical) + assert_series_equal(op(s, s_cat), expected) + assert_series_equal(op(s, s_cat.cast(pl.String)), expected) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + ( + operator.le, + pl.Series([None, None, None, None, None, None], dtype=pl.Boolean), + ), + ( + operator.lt, + pl.Series([None, None, None, None, None, None], dtype=pl.Boolean), + ), + ( + operator.ge, + pl.Series([None, None, None, None, None, None], dtype=pl.Boolean), + ), + ( + operator.gt, + pl.Series([None, None, None, None, None, None], dtype=pl.Boolean), + ), + ( + operator.eq, + pl.Series([None, None, None, None, None, None], dtype=pl.Boolean), + ), + ( + operator.ne, + pl.Series([None, None, None, None, None, None], dtype=pl.Boolean), + ), + (pl.Series.ne_missing, pl.Series([False, True, True, True, True, True])), + (pl.Series.eq_missing, pl.Series([True, False, False, False, False, False])), + ], +) +@StringCache() +def test_compare_categorical_single_none( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series([None, "a", "b", "c", "b", "a"], dtype=pl.Categorical) + s2 = pl.Series([None], dtype=pl.Categorical) + assert_series_equal(op(s, s2), expected) + assert_series_equal(op(s, s2.cast(pl.String)), expected) + + +def test_categorical_error_on_local_cmp() -> None: + df_cat = pl.DataFrame( + [ + pl.Series("a_cat", ["c", "a", "b", "c", "b"], dtype=pl.Categorical), + pl.Series("b_cat", ["F", "G", "E", "G", "G"], dtype=pl.Categorical), + ] + ) + with pytest.raises( + StringCacheMismatchError, + match="cannot compare categoricals coming from different sources", + ): + df_cat.filter(pl.col("a_cat") == pl.col("b_cat")) + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cast_null_to_categorical() -> None: + assert pl.DataFrame().with_columns( + pl.lit(None).cast(pl.Categorical).alias("nullable_enum") + ).dtypes == [pl.Categorical] + + +@StringCache() +def test_merge_lit_under_global_cache_4491() -> None: + df = pl.DataFrame( + [ + pl.Series("label", ["foo", "bar"], dtype=pl.Categorical), + pl.Series("value", [3, 9]), + ] + ) + assert df.with_columns( + pl.when(pl.col("value") > 5) + .then(pl.col("label")) + .otherwise(pl.lit(None, pl.Categorical)) + ).to_dict(as_series=False) == {"label": [None, "bar"], "value": [3, 9]} + + +def test_nested_cache_composition() -> None: + # very artificial example/test, but validates the behaviour + # of nested StringCache scopes, which we want to play well + # with each other when composing more complex pipelines. + + assert pl.using_string_cache() is False + + # function representing a composable stage of a pipeline; it implements + # an inner scope for the case where it is called by itself, but when + # called as part of a larger series of ops it should not invalidate + # the string cache (eg: the outermost scope should be respected). + def create_lazy(data: dict) -> pl.LazyFrame: # type: ignore[type-arg] + with pl.StringCache(): + df = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) + lf = df.with_columns(pl.col("a").cast(pl.Categorical)).lazy() + + # confirm that scope-exit does NOT invalidate the + # cache yet, as an outer context is still active + assert pl.using_string_cache() is True + return lf + + # this outer scope should be respected + with pl.StringCache(): + lf1 = create_lazy({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) + lf2 = create_lazy({"a": ["spam", "foo", "eggs"], "c": [3, 2, 2]}) + + res = lf1.join(lf2, on="a", how="inner").collect().rows() + assert sorted(res) == [("bar", 2, 2), ("foo", 1, 1), ("ham", 3, 3)] + + # no other scope active; NOW we expect the cache to have been invalidated + assert pl.using_string_cache() is False + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_in_struct_nulls() -> None: + s = pl.Series( + "job", ["doctor", "waiter", None, None, None, "doctor"], pl.Categorical + ) + df = pl.DataFrame([s]) + s = (df.select(pl.col("job").value_counts(sort=True)))["job"] + + assert s[0] == {"job": None, "count": 3} + assert s[1] == {"job": "doctor", "count": 2} + assert s[2] == {"job": "waiter", "count": 1} + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cast_inner_categorical() -> None: + dtype = pl.List(pl.Categorical) + out = pl.Series("foo", [["a"], ["a", "b"]]).cast(dtype) + assert out.dtype == dtype + assert out.to_list() == [["a"], ["a", "b"]] + + with pytest.raises( + ComputeError, match=r"casting to categorical not allowed in `list.eval`" + ): + pl.Series("foo", [["a", "b"], ["a", "b"]]).list.eval( + pl.element().cast(pl.Categorical) + ) + + +@pytest.mark.slow +def test_stringcache() -> None: + N = 1_500 + with pl.StringCache(): + # create a large enough column that the categorical map is reallocated + df = pl.DataFrame({"cats": pl.arange(0, N, eager=True)}).select( + pl.col("cats").cast(pl.String).cast(pl.Categorical) + ) + assert df.filter(pl.col("cats").is_in(["1", "2"])).to_dict(as_series=False) == { + "cats": ["1", "2"] + } + + +@pytest.mark.parametrize( + ("dtype", "outcome"), + [ + (pl.Categorical, ["foo", "bar", "baz"]), + (pl.Categorical("physical"), ["foo", "bar", "baz"]), + (pl.Categorical("lexical"), ["bar", "baz", "foo"]), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_sort_order_by_parameter( + dtype: PolarsDataType, outcome: list[str] +) -> None: + s = pl.Series(["foo", "bar", "baz"], dtype=dtype) + df = pl.DataFrame({"cat": s}) + assert df.sort(["cat"])["cat"].to_list() == outcome + + +@StringCache() +@pytest.mark.parametrize("row_fmt_sort_enabled", [False, True]) +def test_categorical_sort_order(row_fmt_sort_enabled: bool, monkeypatch: Any) -> None: + # create the categorical ordering first + pl.Series(["foo", "bar", "baz"], dtype=pl.Categorical) + df = pl.DataFrame( + { + "n": [0, 0, 0], + # use same categories in different order + "x": pl.Series(["baz", "bar", "foo"], dtype=pl.Categorical), + } + ) + + if row_fmt_sort_enabled: + monkeypatch.setenv("POLARS_ROW_FMT_SORT", "1") + + result = df.sort(["n", "x"]) + assert result["x"].to_list() == ["foo", "bar", "baz"] + + result = df.with_columns(pl.col("x").cast(pl.Categorical("lexical"))).sort("n", "x") + assert result["x"].to_list() == ["bar", "baz", "foo"] + + +def test_err_on_categorical_asof_join_by_arg() -> None: + df1 = pl.DataFrame( + [ + pl.Series("cat", ["a", "foo", "bar", "foo", "bar"], dtype=pl.Categorical), + pl.Series("time", [-10, 0, 10, 20, 30], dtype=pl.Int32), + ] + ) + df2 = pl.DataFrame( + [ + pl.Series( + "cat", + ["bar", "bar", "bar", "bar", "foo", "foo", "foo", "foo"], + dtype=pl.Categorical, + ), + pl.Series("time", [-5, 5, 15, 25] * 2, dtype=pl.Int32), + pl.Series("x", [1, 2, 3, 4] * 2, dtype=pl.Int32), + ] + ) + with pytest.raises( + StringCacheMismatchError, + match="cannot compare categoricals coming from different sources", + ): + df1.join_asof(df2, on=pl.col("time").set_sorted(), by="cat") + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_list_get_item() -> None: + out = pl.Series([["a"]]).cast(pl.List(pl.Categorical)).item() + assert isinstance(out, pl.Series) + assert out.dtype == pl.Categorical + + +@pytest.mark.usefixtures("test_global_and_local") +def test_nested_categorical_aggregation_7848() -> None: + # a double categorical aggregation + assert pl.DataFrame( + { + "group": [1, 1, 2, 2, 2, 3, 3], + "letter": ["a", "b", "c", "d", "e", "f", "g"], + } + ).with_columns([pl.col("letter").cast(pl.Categorical)]).group_by( + "group", maintain_order=True + ).all().with_columns(pl.col("letter").list.len().alias("c_group")).group_by( + ["c_group"], maintain_order=True + ).agg(pl.col("letter")).to_dict(as_series=False) == { + "c_group": [2, 3], + "letter": [[["a", "b"], ["f", "g"]], [["c", "d", "e"]]], + } + + +@pytest.mark.usefixtures("test_global_and_local") +def test_nested_categorical_cast() -> None: + values = [["x"], ["y"], ["x"]] + dtype = pl.List(pl.Categorical) + s = pl.Series(values).cast(dtype) + assert s.dtype == dtype + assert s.to_list() == values + + +@pytest.mark.usefixtures("test_global_and_local") +def test_struct_categorical_nesting() -> None: + # this triggers a lot of materialization + df = pl.DataFrame( + {"cats": ["Value1", "Value2", "Value1"]}, + schema_overrides={"cats": pl.Categorical}, + ) + s = df.select(pl.struct(pl.col("cats")))["cats"].implode() + assert s.dtype == pl.List(pl.Struct([pl.Field("cats", pl.Categorical)])) + # triggers recursive conversion + assert s.to_list() == [[{"cats": "Value1"}, {"cats": "Value2"}, {"cats": "Value1"}]] + # triggers different recursive conversion + assert len(s.to_arrow()) == 1 + + +def test_categorical_fill_null_existing_category() -> None: + # ensure physical types align + df = pl.DataFrame({"col": ["a", None, "a"]}, schema={"col": pl.Categorical}) + result = df.fill_null("a").with_columns(pl.col("col").to_physical().alias("code")) + expected = {"col": ["a", "a", "a"], "code": [0, 0, 0]} + assert result.to_dict(as_series=False) == expected + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.may_fail_auto_streaming +def test_categorical_fill_null_stringcache() -> None: + df = pl.LazyFrame( + {"index": [1, 2, 3], "cat": ["a", "b", None]}, + schema={"index": pl.Int64(), "cat": pl.Categorical()}, + ) + a = df.select(pl.col("cat").fill_null("hi")).collect() + + assert a.to_dict(as_series=False) == {"cat": ["a", "b", "hi"]} + assert a.dtypes == [pl.Categorical] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_fast_unique_flag_from_arrow() -> None: + with pl.StringCache(): + df = pl.DataFrame( + { + "colB": ["1", "2", "3", "4", "5", "5", "5", "5"], + } + ).with_columns([pl.col("colB").cast(pl.Categorical)]) + + filtered = df.to_arrow().filter( + [True, False, True, True, False, True, True, True] + ) + assert pl.from_arrow(filtered).select(pl.col("colB").n_unique()).item() == 4 # type: ignore[union-attr] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_construct_with_null() -> None: + # Example from https://github.com/pola-rs/polars/issues/7188 + df = pl.from_dicts([{"A": None}, {"A": "foo"}], schema={"A": pl.Categorical}) + assert df.to_series().to_list() == [None, "foo"] + + s = pl.Series([{"struct_A": None}], dtype=pl.Struct({"struct_A": pl.Categorical})) + assert s.to_list() == [{"struct_A": None}] + + +def test_categorical_concat_string_cached() -> None: + with pl.StringCache(): + df1 = pl.DataFrame({"x": ["A"]}).with_columns(pl.col("x").cast(pl.Categorical)) + df2 = pl.DataFrame({"x": ["B"]}).with_columns(pl.col("x").cast(pl.Categorical)) + + out = pl.concat([df1, df2]) + assert out.dtypes == [pl.Categorical] + assert out["x"].to_list() == ["A", "B"] + + +def test_list_builder_different_categorical_rev_maps() -> None: + with pl.StringCache(): + # built with different values, so different rev-map + s1 = pl.Series(["a", "b"], dtype=pl.Categorical) + s2 = pl.Series(["c", "d"], dtype=pl.Categorical) + + assert pl.DataFrame({"c": [s1, s2]}).to_dict(as_series=False) == { + "c": [["a", "b"], ["c", "d"]] + } + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_collect_11408() -> None: + df = pl.DataFrame( + data={"groups": ["a", "b", "c"], "cats": ["a", "b", "c"], "amount": [1, 2, 3]}, + schema={"groups": pl.String, "cats": pl.Categorical, "amount": pl.Int8}, + ) + + assert df.group_by("groups").agg( + pl.col("cats").filter(pl.col("amount") == pl.col("amount").min()).first() + ).sort("groups").to_dict(as_series=False) == { + "groups": ["a", "b", "c"], + "cats": ["a", "b", "c"], + } + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_nested_cast_unchecked() -> None: + s = pl.Series("cat", [["cat"]]).cast(pl.List(pl.Categorical)) + assert pl.Series([s]).to_list() == [[["cat"]]] + + +def test_categorical_update_lengths() -> None: + with pl.StringCache(): + s1 = pl.Series(["", ""], dtype=pl.Categorical) + s2 = pl.Series([None, "", ""], dtype=pl.Categorical) + + s = pl.concat([s1, s2], rechunk=False) + assert s.null_count() == 1 + assert s.len() == 5 + + +def test_categorical_zip_append_local_different_rev_map() -> None: + s1 = pl.Series(["cat1", "cat2", "cat1"], dtype=pl.Categorical) + s2 = pl.Series(["cat2", "cat2", "cat3"], dtype=pl.Categorical) + with pytest.warns( + CategoricalRemappingWarning, + match="Local categoricals have different encodings", + ): + s3 = s1.append(s2) + categories = s3.cat.get_categories() + assert len(categories) == 3 + assert set(categories) == {"cat1", "cat2", "cat3"} + + +def test_categorical_zip_extend_local_different_rev_map() -> None: + s1 = pl.Series(["cat1", "cat2", "cat1"], dtype=pl.Categorical) + s2 = pl.Series(["cat2", "cat2", "cat3"], dtype=pl.Categorical) + with pytest.warns( + CategoricalRemappingWarning, + match="Local categoricals have different encodings", + ): + s3 = s1.extend(s2) + categories = s3.cat.get_categories() + assert len(categories) == 3 + assert set(categories) == {"cat1", "cat2", "cat3"} + + +def test_categorical_zip_with_local_different_rev_map() -> None: + s1 = pl.Series(["cat1", "cat2", "cat1"], dtype=pl.Categorical) + mask = pl.Series([True, False, False]) + s2 = pl.Series(["cat2", "cat2", "cat3"], dtype=pl.Categorical) + with pytest.warns( + CategoricalRemappingWarning, + match="Local categoricals have different encodings", + ): + s3 = s1.zip_with(mask, s2) + categories = s3.cat.get_categories() + assert len(categories) == 3 + assert set(categories) == {"cat1", "cat2", "cat3"} + + +def test_categorical_vstack_with_local_different_rev_map() -> None: + df1 = pl.DataFrame({"a": pl.Series(["a", "b", "c"], dtype=pl.Categorical)}) + df2 = pl.DataFrame({"a": pl.Series(["d", "e", "f"], dtype=pl.Categorical)}) + with pytest.warns( + CategoricalRemappingWarning, + match="Local categoricals have different encodings", + ): + df3 = df1.vstack(df2) + assert df3.get_column("a").cat.get_categories().to_list() == [ + "a", + "b", + "c", + "d", + "e", + "f", + ] + assert df3.get_column("a").cast(pl.UInt32).to_list() == [0, 1, 2, 3, 4, 5] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_shift_over_13041() -> None: + df = pl.DataFrame( + { + "id": [0, 0, 0, 1, 1, 1], + "cat_col": pl.Series(["a", "b", "c", "d", "e", "f"], dtype=pl.Categorical), + } + ) + result = df.with_columns(pl.col("cat_col").shift(2).over("id")) + + assert result.to_dict(as_series=False) == { + "id": [0, 0, 0, 1, 1, 1], + "cat_col": [None, None, "a", None, None, "d"], + } + + +@pytest.mark.parametrize("context", [pl.StringCache(), contextlib.nullcontext()]) +@pytest.mark.parametrize("ordering", ["physical", "lexical"]) +@pytest.mark.usefixtures("test_global_and_local") +def test_sort_categorical_retain_none( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] + ordering: Literal["physical", "lexical"], +) -> None: + with context: + df = pl.DataFrame( + [ + pl.Series( + "e", + ["foo", None, "bar", "ham", None], + dtype=pl.Categorical(ordering=ordering), + ) + ] + ) + + df_sorted = df.with_columns(pl.col("e").sort()) + assert ( + df_sorted.get_column("e").null_count() + == df.get_column("e").null_count() + == 2 + ) + if ordering == "lexical": + assert df_sorted.get_column("e").to_list() == [ + None, + None, + "bar", + "foo", + "ham", + ] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cast_from_cat_to_numeric() -> None: + cat_series = pl.Series( + "cat_series", + ["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"], + ).cast(pl.Categorical) + maximum = cat_series.cast(pl.Float32).max() + assert abs(maximum - 2.43642724) < 1e-6 # type: ignore[operator] + + s = pl.Series(["1", "2", "3"], dtype=pl.Categorical) + assert s.cast(pl.UInt8).sum() == 6 + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cat_preserve_lexical_ordering_on_clear() -> None: + s = pl.Series("a", ["a", "b"], dtype=pl.Categorical(ordering="lexical")) + s2 = s.clear() + assert s.dtype == s2.dtype + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cat_preserve_lexical_ordering_on_concat() -> None: + dtype = pl.Categorical(ordering="lexical") + + df = pl.DataFrame({"x": ["b", "a", "c"]}).with_columns(pl.col("x").cast(dtype)) + df2 = pl.concat([df, df]) + assert df2["x"].dtype == dtype + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.may_fail_auto_streaming +def test_cat_append_lexical_sorted_flag() -> None: + df = pl.DataFrame({"x": [0, 1, 1], "y": ["B", "B", "A"]}).with_columns( + pl.col("y").cast(pl.Categorical(ordering="lexical")) + ) + df2 = pl.concat([part.sort("y") for part in df.partition_by("x")]) + + assert not (df2["y"].is_sorted()) + + s = pl.Series("a", ["z", "k", "a"], pl.Categorical("lexical")) + s1 = s[[0]] + s2 = s[[1]] + s3 = s[[2]] + s1.append(s2) + s1.append(s3) + + assert not (s1.is_sorted()) + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cast_physical_lexical_sorted_flag_20864() -> None: + df = pl.DataFrame({"s": ["b", "a"], "v": [1, 2]}) + sorted_physically = df.cast({"s": pl.Categorical("physical")}).sort("s") + sorted_lexically = sorted_physically.cast({"s": pl.Categorical("lexical")}).sort( + "s" + ) + assert sorted_lexically["s"].to_list() == ["a", "b"] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_get_cat_categories_multiple_chunks() -> None: + df = pl.DataFrame( + [ + pl.Series("e", ["a", "b"], pl.Categorical), + ] + ) + df = pl.concat( + [df for _ in range(100)], how="vertical", rechunk=False, parallel=True + ) + df_cat = df.lazy().select(pl.col("e").cat.get_categories()).collect() + assert len(df_cat) == 2 + + +@pytest.mark.parametrize( + "f", + [ + lambda x: (pl.List(pl.Categorical), [x]), + lambda x: (pl.Struct({"a": pl.Categorical}), {"a": x}), + ], +) +def test_nested_categorical_concat( + f: Callable[[str], tuple[pl.DataType, list[str] | dict[str, str]]], +) -> None: + dt, va = f("a") + _, vb = f("b") + a = pl.DataFrame({"x": [va]}, schema={"x": dt}) + b = pl.DataFrame({"x": [vb]}, schema={"x": dt}) + + with pytest.raises(pl.exceptions.StringCacheMismatchError): + pl.concat([a, b]) + + +@with_string_cache_if_auto_streaming +@pytest.mark.usefixtures("test_global_and_local") +def test_perfect_group_by_19452() -> None: + n = 40 + df2 = pl.DataFrame( + { + "a": pl.int_range(n, eager=True).cast(pl.String).cast(pl.Categorical), + "b": pl.int_range(n, eager=True), + } + ) + + assert df2.with_columns(a=(pl.col("b")).over(pl.col("a")))["a"].is_sorted() + + +@pytest.mark.usefixtures("test_global_and_local") +def test_perfect_group_by_19950() -> None: + dtype = pl.Enum(categories=["a", "b", "c"]) + + left = pl.DataFrame({"x": "a"}).cast(dtype) + right = pl.DataFrame({"x": "a", "y": "b"}).cast(dtype) + assert left.join(right, on="x").group_by("y").first().to_dict(as_series=False) == { + "y": ["b"], + "x": ["a"], + } + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_unique() -> None: + with pl.StringCache(): + s = pl.Series(["a", "b", None], dtype=pl.Categorical) + assert s.n_unique() == 3 + assert s.unique().sort().to_list() == [None, "a", "b"] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_unique_20539() -> None: + with pl.StringCache(): + df = pl.DataFrame( + {"number": [1, 1, 2, 2, 3], "letter": ["a", "b", "b", "c", "c"]} + ) + + result = ( + df.cast({"letter": pl.Categorical}) + .group_by("number") + .agg( + unique=pl.col("letter").unique(maintain_order=True), + unique_with_order=pl.col("letter").unique(maintain_order=True), + ) + ) + + assert result.sort("number").to_dict(as_series=False) == { + "number": [1, 2, 3], + "unique": [["a", "b"], ["b", "c"], ["c"]], + "unique_with_order": [["a", "b"], ["b", "c"], ["c"]], + } + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_prefill() -> None: + with pl.StringCache(): + # https://github.com/pola-rs/polars/pull/20547#issuecomment-2569473443 + # test_compare_categorical_single + assert (pl.Series(["a"], dtype=pl.Categorical) < "a").to_list() == [False] + + # test_unique_categorical + a = pl.Series(["a"], dtype=pl.Categorical) + assert a.unique().to_list() == ["a"] + + s = pl.Series(["1", "2", "3"], dtype=pl.Categorical) + s = s.filter([True, False, True]) + assert s.n_unique() == 2 + + +@pytest.mark.may_fail_auto_streaming # not implemented +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_min_max() -> None: + schema = pl.Schema( + { + "a": pl.Categorical("physical"), + "b": pl.Categorical("lexical"), + "c": pl.Enum(["foo", "bar"]), + } + ) + lf = pl.LazyFrame( + { + "a": ["foo", "bar"], + "b": ["foo", "bar"], + "c": ["foo", "bar"], + }, + schema=schema, + ) + + q = lf.select(pl.all().min()) + result = q.collect() + assert q.collect_schema() == schema + assert result.schema == schema + assert result.to_dict(as_series=False) == {"a": ["foo"], "b": ["bar"], "c": ["foo"]} + + q = lf.select(pl.all().max()) + result = q.collect() + assert q.collect_schema() == schema + assert result.schema == schema + assert result.to_dict(as_series=False) == {"a": ["bar"], "b": ["foo"], "c": ["bar"]} diff --git a/py-polars/tests/unit/datatypes/test_datatype.py b/py-polars/tests/unit/datatypes/test_datatype.py new file mode 100644 index 000000000000..804fb52829b6 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_datatype.py @@ -0,0 +1,11 @@ +import copy + +import polars as pl + + +# https://github.com/pola-rs/polars/issues/14771 +def test_datatype_copy() -> None: + dtype = pl.Int64() + result = copy.deepcopy(dtype) + assert dtype == result + assert isinstance(result, pl.Int64) diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py new file mode 100644 index 000000000000..544ffe5f7db6 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -0,0 +1,722 @@ +from __future__ import annotations + +import io +import itertools +import operator +from dataclasses import dataclass +from decimal import Decimal as D +from math import ceil, floor +from random import choice, randrange, seed +from typing import Any, Callable, NamedTuple + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.fixture(scope="module") +def permutations_int_dec_none() -> list[tuple[D | int | None, ...]]: + return list( + itertools.permutations( + [ + D("-0.01"), + D("1.2345678"), + D("500"), + -1, + None, + ] + ) + ) + + +@pytest.mark.slow +def test_series_from_pydecimal_and_ints( + permutations_int_dec_none: list[tuple[D | int | None, ...]], +) -> None: + # TODO: check what happens if there are strings, floats arrow scalars in the list + for data in permutations_int_dec_none: + s = pl.Series("name", data, strict=False) + assert s.dtype == pl.Decimal(scale=7) # inferred scale = 7, precision = None + assert s.dtype.is_decimal() + assert s.name == "name" + assert s.null_count() == 1 + for i, d in enumerate(data): + assert s[i] == d + assert s.to_list() == [D(x) if x is not None else None for x in data] + + +@pytest.mark.slow +def test_frame_from_pydecimal_and_ints( + permutations_int_dec_none: list[tuple[D | int | None, ...]], monkeypatch: Any +) -> None: + class X(NamedTuple): + a: int | D | None + + @dataclass + class Y: + a: int | D | None + + for data in permutations_int_dec_none: + row_data = [(d,) for d in data] + for cls in (X, Y): + for ctor in (pl.DataFrame, pl.from_records): + df = ctor(data=list(map(cls, data))) + assert df.schema == { + "a": pl.Decimal(scale=7), + } + assert df.rows() == row_data + + +@pytest.mark.parametrize( + ("input", "trim_zeros", "expected"), + [ + ("0.00", True, "0"), + ("0.00", False, "0.00"), + ("-1", True, "-1"), + ("-1.000000000000000000000000000", False, "-1.000000000000000000000000000"), + ("0.0100", True, "0.01"), + ("0.0100", False, "0.0100"), + ("0.010000000000000000000000000", False, "0.010000000000000000000000000"), + ("-1.123801239123981293891283123", True, "-1.123801239123981293891283123"), + ( + "12345678901.234567890123458390192857685", + True, + "12345678901.234567890123458390192857685", + ), + ( + "-99999999999.999999999999999999999999999", + True, + "-99999999999.999999999999999999999999999", + ), + ], +) +def test_decimal_format(input: str, trim_zeros: bool, expected: str) -> None: + with pl.Config(trim_decimal_zeros=trim_zeros): + series = pl.Series([input]).str.to_decimal() + formatted = str(series).split("\n")[-2].strip() + assert formatted == expected + + +def test_init_decimal_dtype() -> None: + s = pl.Series( + "a", [D("-0.01"), D("1.2345678"), D("500")], dtype=pl.Decimal, strict=False + ) + assert s.dtype.is_numeric() + + df = pl.DataFrame( + {"a": [D("-0.01"), D("1.2345678"), D("500")]}, + schema={"a": pl.Decimal}, + strict=False, + ) + assert df["a"].dtype.is_numeric() + + +def test_decimal_convert_to_float_by_schema() -> None: + # Column Based + df = pl.DataFrame( + {"a": [D("1"), D("2.55"), D("45.000"), D("10.0")]}, schema={"a": pl.Float64} + ) + expected = pl.DataFrame({"a": [1.0, 2.55, 45.0, 10.0]}) + assert_frame_equal(df, expected) + + # Row Based + df = pl.DataFrame( + [[D("1"), D("2.55"), D("45.000"), D("10.0")]], schema={"a": pl.Float64} + ) + expected = pl.DataFrame({"a": [1.0, 2.55, 45.0, 10.0]}) + assert_frame_equal(df, expected) + + +def test_df_constructor_convert_decimal_to_float_9873() -> None: + result = pl.DataFrame( + [[D("45.0000")], [D("45.0000")]], schema={"a": pl.Float64}, orient="row" + ) + expected = pl.DataFrame({"a": [45.0, 45.0]}) + assert_frame_equal(result, expected) + + +def test_decimal_cast() -> None: + df = pl.DataFrame( + { + "decimals": [D("2"), D("2"), D("-1.5")], + }, + strict=False, + ) + + result = df.with_columns(pl.col("decimals").cast(pl.Float32).alias("b2")) + expected = {"decimals": [D("2"), D("2"), D("-1.5")], "b2": [2.0, 2.0, -1.5]} + assert result.to_dict(as_series=False) == expected + + +def test_decimal_cast_no_scale() -> None: + s = pl.Series().cast(pl.Decimal) + assert s.dtype == pl.Decimal(precision=None, scale=0) + + s = pl.Series([D("10.0")]).cast(pl.Decimal) + assert s.dtype == pl.Decimal(precision=None, scale=1) + + +def test_decimal_scale_precision_roundtrip(monkeypatch: Any) -> None: + assert pl.from_arrow(pl.Series("dec", [D("10.0")]).to_arrow()).item() == D("10.0") + + +def test_string_to_decimal() -> None: + values = [ + "40.12", + "3420.13", + "120134.19", + "3212.98", + "12.90", + "143.09", + "143.9", + "-62.44", + ] + + s = pl.Series(values).str.to_decimal() + assert s.dtype == pl.Decimal(scale=2) + + assert s.to_list() == [D(v) for v in values] + + +def test_read_csv_decimal(monkeypatch: Any) -> None: + csv = """a,b +123.12,a +1.1,a +0.01,a""" + + df = pl.read_csv(csv.encode(), schema_overrides={"a": pl.Decimal(scale=2)}) + assert df.dtypes == [pl.Decimal(scale=2), pl.String] + assert df["a"].to_list() == [ + D("123.12"), + D("1.10"), + D("0.01"), + ] + + +def test_decimal_eq_number() -> None: + a = pl.Series([D("1.5"), D("22.25"), D("10.0")], dtype=pl.Decimal, strict=False) + assert_series_equal(a == 1, pl.Series([False, False, False])) + assert_series_equal(a == 1.5, pl.Series([True, False, False])) + assert_series_equal(a == D("1.5"), pl.Series([True, False, False])) + assert_series_equal(a == pl.Series([D("1.5")]), pl.Series([True, False, False])) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, True, True, True, True, True])), + (operator.lt, pl.Series([None, False, False, False, True, True])), + (operator.ge, pl.Series([None, True, True, True, False, False])), + (operator.gt, pl.Series([None, False, False, False, False, False])), + ], +) +def test_decimal_compare( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series( + [None, D("1.2"), D("2.13"), D("4.99"), D("2.13"), D("1.2")], + dtype=pl.Decimal, + strict=False, + ) + s2 = pl.Series( + [None, D("1.200"), D("2.13"), D("4.99"), D("4.99"), D("2.13")], strict=False + ) + + assert_series_equal(op(s, s2), expected) + + +def test_decimal_arithmetic() -> None: + df = pl.DataFrame( + { + "a": [D("0.1"), D("10.1"), D("100.01")], + "b": [D("20.1"), D("10.19"), D("39.21")], + }, + strict=False, + ) + dt = pl.Decimal(20, 10) + + out = df.select( + out1=pl.col("a") * pl.col("b"), + out2=pl.col("a") + pl.col("b"), + out3=pl.col("a") / pl.col("b"), + out4=pl.col("a") - pl.col("b"), + out5=pl.col("a").cast(dt) / pl.col("b").cast(dt), + ) + assert out.dtypes == [ + pl.Decimal(precision=None, scale=4), + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=6), + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=14), + ] + + assert out.to_dict(as_series=False) == { + "out1": [D("2.0100"), D("102.9190"), D("3921.3921")], + "out2": [D("20.20"), D("20.29"), D("139.22")], + "out3": [D("0.004975"), D("0.991167"), D("2.550624")], + "out4": [D("-20.00"), D("-0.09"), D("60.80")], + "out5": [D("0.00497512437810"), D("0.99116781157998"), D("2.55062484060188")], + } + + +def test_decimal_series_value_arithmetic() -> None: + s = pl.Series([D("0.10"), D("10.10"), D("100.01")]) + + out1 = s + 10 + out2 = s + D("10") + out3 = s + D("10.0001") + out4 = s * 2 / 3 + out5 = s / D("1.5") + out6 = s - 5 + + assert out1.dtype == pl.Decimal(precision=None, scale=2) + assert out2.dtype == pl.Decimal(precision=None, scale=2) + assert out3.dtype == pl.Decimal(precision=None, scale=4) + assert out4.dtype == pl.Decimal(precision=None, scale=8) + assert out5.dtype == pl.Decimal(precision=None, scale=6) + assert out6.dtype == pl.Decimal(precision=None, scale=2) + + assert out1.to_list() == [D("10.1"), D("20.1"), D("110.01")] + assert out2.to_list() == [D("10.1"), D("20.1"), D("110.01")] + assert out3.to_list() == [D("10.1001"), D("20.1001"), D("110.0101")] + assert out4.to_list() == [ + D("0.06666666"), + D("6.73333333"), + D("66.67333333"), + ] # TODO: do we want floor instead of round? + assert out5.to_list() == [D("0.066666"), D("6.733333"), D("66.673333")] + assert out6.to_list() == [D("-4.9"), D("5.1"), D("95.01")] + + +def test_decimal_aggregations() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 2, 2], + "a": [D("0.1"), D("10.1"), D("100.01"), D("9000.12")], + }, + strict=False, + ) + + assert df.group_by("g").agg("a").sort("g").to_dict(as_series=False) == { + "g": [1, 2], + "a": [[D("0.1"), D("10.1")], [D("100.01"), D("9000.12")]], + } + + assert df.group_by("g", maintain_order=True).agg( + sum=pl.sum("a"), + min=pl.min("a"), + max=pl.max("a"), + ).to_dict(as_series=False) == { + "g": [1, 2], + "sum": [D("10.20"), D("9100.13")], + "min": [D("0.10"), D("100.01")], + "max": [D("10.10"), D("9000.12")], + } + + res = df.select( + sum=pl.sum("a"), + min=pl.min("a"), + max=pl.max("a"), + mean=pl.mean("a"), + median=pl.median("a"), + ) + expected = pl.DataFrame( + { + "sum": [D("9110.33")], + "min": [D("0.10")], + "max": [D("9000.12")], + "mean": [2277.5825], + "median": [55.055], + } + ) + assert_frame_equal(res, expected) + + description = pl.DataFrame( + { + "statistic": [ + "count", + "null_count", + "mean", + "std", + "min", + "25%", + "50%", + "75%", + "max", + ], + "g": [4.0, 0.0, 1.5, 0.5773502691896257, 1.0, 1.0, 2.0, 2.0, 2.0], + "a": [ + 4.0, + 0.0, + 2277.5825, + 4481.916846516863, + 0.1, + 10.1, + 100.01, + 100.01, + 9000.12, + ], + } + ) + assert_frame_equal(df.describe(), description) + + +def test_decimal_cumulative_aggregations() -> None: + df = pl.Series("a", [D("2.2"), D("1.1"), D("3.3")]).to_frame() + result = df.select( + pl.col("a").cum_sum().alias("cum_sum"), + pl.col("a").cum_min().alias("cum_min"), + pl.col("a").cum_max().alias("cum_max"), + ) + expected = pl.DataFrame( + { + "cum_sum": [D("2.2"), D("3.3"), D("6.6")], + "cum_min": [D("2.2"), D("1.1"), D("1.1")], + "cum_max": [D("2.2"), D("2.2"), D("3.3")], + } + ) + assert_frame_equal(result, expected) + + +def test_decimal_df_vertical_sum() -> None: + df = pl.DataFrame({"a": [D("1.1"), D("2.2")]}) + expected = pl.DataFrame({"a": [D("3.3")]}) + + assert_frame_equal(df.sum(), expected) + + +def test_decimal_df_vertical_agg() -> None: + df = pl.DataFrame({"a": [D("1.0"), D("2.0"), D("3.0")]}) + expected_min = pl.DataFrame({"a": [D("1.0")]}) + expected_max = pl.DataFrame({"a": [D("3.0")]}) + assert_frame_equal(df.min(), expected_min) + assert_frame_equal(df.max(), expected_max) + + +def test_decimal_in_filter() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": ["6", "7", "8"], + } + ) + df = df.with_columns(pl.col("bar").cast(pl.Decimal)) + assert df.filter(pl.col("foo") > 1).to_dict(as_series=False) == { + "foo": [2, 3], + "bar": [D("7"), D("8")], + } + + +def test_decimal_sort() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [D("3.4"), D("2.1"), D("4.5")], + "baz": [1, 1, 2], + } + ) + assert df.sort("bar").to_dict(as_series=False) == { + "foo": [2, 1, 3], + "bar": [D("2.1"), D("3.4"), D("4.5")], + "baz": [1, 1, 2], + } + assert df.sort(["bar", "foo"]).to_dict(as_series=False) == { + "foo": [2, 1, 3], + "bar": [D("2.1"), D("3.4"), D("4.5")], + "baz": [1, 1, 2], + } + assert df.sort(["foo", "bar"]).to_dict(as_series=False) == { + "foo": [1, 2, 3], + "bar": [D("3.4"), D("2.1"), D("4.5")], + "baz": [1, 1, 2], + } + + assert df.select([pl.col("foo").sort_by("bar", descending=True).alias("s1")])[ + "s1" + ].to_list() == [3, 1, 2] + assert df.select([pl.col("foo").sort_by(["baz", "bar"]).alias("s2")])[ + "s2" + ].to_list() == [2, 1, 3] + + +def test_decimal_unique() -> None: + df = pl.DataFrame( + { + "foo": [1, 1, 2], + "bar": [D("3.4"), D("3.4"), D("4.5")], + } + ) + assert df.unique().sort("bar").to_dict(as_series=False) == { + "foo": [1, 2], + "bar": [D("3.4"), D("4.5")], + } + + +def test_decimal_write_parquet_12375() -> None: + df = pl.DataFrame( + { + "hi": [True, False, True, False], + "bye": [D(1), D(2), D(3), D(47283957238957239875)], + }, + ) + assert df["bye"].dtype == pl.Decimal + + f = io.BytesIO() + df.write_parquet(f) + + +def test_decimal_list_get_13847() -> None: + df = pl.DataFrame({"a": [[D("1.1"), D("1.2")], [D("2.1")]]}) + out = df.select(pl.col("a").list.get(0)) + print(out) + expected = pl.DataFrame({"a": [D("1.1"), D("2.1")]}) + assert_frame_equal(out, expected) + + +def test_decimal_explode() -> None: + nested_decimal_df = pl.DataFrame( + { + "bar": [[D("3.4"), D("3.4")], [D("4.5")]], + } + ) + df = nested_decimal_df.explode("bar") + expected_df = pl.DataFrame( + { + "bar": [D("3.4"), D("3.4"), D("4.5")], + } + ) + assert_frame_equal(df, expected_df) + + # test group-by head #15330 + df = pl.DataFrame( + { + "foo": [1, 1, 2], + "bar": [D("3.4"), D("3.4"), D("4.5")], + } + ) + head_df = df.group_by("foo", maintain_order=True).head(1) + expected_df = pl.DataFrame({"foo": [1, 2], "bar": [D("3.4"), D("4.5")]}) + assert_frame_equal(head_df, expected_df) + + +def test_decimal_streaming() -> None: + seed(1) + scale = D("1e18") + data = [ + {"group": choice("abc"), "value": randrange(10**32) / scale} for _ in range(20) + ] + lf = pl.LazyFrame(data, schema_overrides={"value": pl.Decimal(scale=18)}) + assert lf.group_by("group").agg(pl.sum("value")).collect(engine="streaming").sort( + "group" + ).to_dict(as_series=False) == { + "group": ["a", "b", "c"], + "value": [ + D("244215083629512.120161049441284000"), + D("510640422312378.070344831471216000"), + D("161102921617598.363263936811563000"), + ], + } + + +def test_decimal_supertype() -> None: + q = pl.LazyFrame([0.12345678]).select( + pl.col("column_0").cast(pl.Decimal(scale=6)) * 1 + ) + assert q.collect().dtypes[0].is_decimal() + + +def test_decimal_raise_oob_precision() -> None: + df = pl.DataFrame({"a": [1.0]}) + # max precision is 38. + with pytest.raises(pl.exceptions.InvalidOperationError): + df.select(b=pl.col("a").cast(pl.Decimal(76, 38))) + + +def test_decimal_dynamic_float_st() -> None: + assert pl.LazyFrame({"a": [D("2.0"), D("0.5")]}).filter( + pl.col("a").is_between(0.45, 0.9) + ).collect().to_dict(as_series=False) == {"a": [D("0.5")]} + + +def test_decimal_strict_scale_inference_17770() -> None: + values = [D("0.1"), D("0.10"), D("1.0121")] + s = pl.Series(values, strict=True) + assert s.dtype == pl.Decimal(precision=None, scale=4) + assert s.to_list() == values + + +def test_decimal_round() -> None: + dtype = pl.Decimal(3, 2) + values = [D(f"{float(v) / 100.0:.02f}") for v in range(-150, 250, 1)] + i_s = pl.Series("a", values, dtype) + + floor_s = pl.Series("a", [floor(v) for v in values], dtype) + ceil_s = pl.Series("a", [ceil(v) for v in values], dtype) + + assert_series_equal(i_s.floor(), floor_s) + assert_series_equal(i_s.ceil(), ceil_s) + + for decimals in range(10): + got_s = i_s.round(decimals) + expected_s = pl.Series("a", [round(v, decimals) for v in values], dtype) + + assert_series_equal(got_s, expected_s) + + +def test_decimal_arithmetic_schema() -> None: + q = pl.LazyFrame({"x": [1.0]}, schema={"x": pl.Decimal(15, 2)}) + + q1 = q.select(pl.col.x * pl.col.x) + assert q1.collect_schema() == q1.collect().schema + q1 = q.select(pl.col.x / pl.col.x) + assert q1.collect_schema() == q1.collect().schema + q1 = q.select(pl.col.x - pl.col.x) + assert q1.collect_schema() == q1.collect().schema + q1 = q.select(pl.col.x + pl.col.x) + assert q1.collect_schema() == q1.collect().schema + + +def test_decimal_arithmetic_schema_float_20369() -> None: + s = pl.Series("x", [1.0], dtype=pl.Decimal(15, 2)) + assert_series_equal((s - 1.0), pl.Series("x", [0.0], dtype=pl.Decimal(None, 2))) + assert_series_equal( + (3.0 - s), pl.Series("literal", [2.0], dtype=pl.Decimal(None, 2)) + ) + assert_series_equal( + (3.0 / s), pl.Series("literal", [3.0], dtype=pl.Decimal(None, 6)) + ) + assert_series_equal( + (s / 3.0), pl.Series("x", [0.333333], dtype=pl.Decimal(None, 6)) + ) + + assert_series_equal((s + 1.0), pl.Series("x", [2.0], dtype=pl.Decimal(None, 2))) + assert_series_equal( + (1.0 + s), pl.Series("literal", [2.0], dtype=pl.Decimal(None, 2)) + ) + assert_series_equal((s * 1.0), pl.Series("x", [1.0], dtype=pl.Decimal(None, 4))) + assert_series_equal( + (1.0 * s), pl.Series("literal", [1.0], dtype=pl.Decimal(None, 4)) + ) + + +def test_decimal_horizontal_20482() -> None: + b = pl.LazyFrame( + { + "a": [D("123.000000"), D("234.000000")], + "b": [D("123.000000"), D("234.000000")], + }, + schema={ + "a": pl.Decimal(18, 6), + "b": pl.Decimal(18, 6), + }, + ) + + assert ( + b.select( + min=pl.min_horizontal(pl.col("a"), pl.col("b")), + max=pl.max_horizontal(pl.col("a"), pl.col("b")), + sum=pl.sum_horizontal(pl.col("a"), pl.col("b")), + ).collect() + ).to_dict(as_series=False) == { + "min": [D("123.000000"), D("234.000000")], + "max": [D("123.000000"), D("234.000000")], + "sum": [D("246.000000"), D("468.000000")], + } + + +def test_decimal_horizontal_different_scales_16296() -> None: + df = pl.DataFrame( + { + "a": [D("1.111")], + "b": [D("2.22")], + "c": [D("3.3")], + }, + schema={ + "a": pl.Decimal(18, 3), + "b": pl.Decimal(18, 2), + "c": pl.Decimal(18, 1), + }, + ) + + assert ( + df.select( + min=pl.min_horizontal(pl.col("a", "b", "c")), + max=pl.max_horizontal(pl.col("a", "b", "c")), + sum=pl.sum_horizontal(pl.col("a", "b", "c")), + ) + ).to_dict(as_series=False) == { + "min": [D("1.111")], + "max": [D("3.300")], + "sum": [D("6.631")], + } + + +def test_shift_over_12957() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 2], + "b": [D("1.1"), D("1.1"), D("2.2"), D("2.2")], + } + ) + result = df.select( + x=pl.col("b").shift(1).over("a"), + y=pl.col("a").shift(1).over("b"), + ) + assert result["x"].to_list() == [None, D("1.1"), None, D("2.2")] + assert result["y"].to_list() == [None, 1, None, 2] + + +def test_fill_null() -> None: + s = pl.Series("a", [D("1.2"), None, D("1.4")]) + + assert s.fill_null(D("0.0")).to_list() == [D("1.2"), D("0.0"), D("1.4")] + assert s.fill_null(strategy="zero").to_list() == [D("1.2"), D("0.0"), D("1.4")] + assert s.fill_null(strategy="max").to_list() == [D("1.2"), D("1.4"), D("1.4")] + assert s.fill_null(strategy="min").to_list() == [D("1.2"), D("1.2"), D("1.4")] + assert s.fill_null(strategy="one").to_list() == [D("1.2"), D("1.0"), D("1.4")] + assert s.fill_null(strategy="forward").to_list() == [D("1.2"), D("1.2"), D("1.4")] + assert s.fill_null(strategy="backward").to_list() == [D("1.2"), D("1.4"), D("1.4")] + assert s.fill_null(strategy="mean").to_list() == [D("1.2"), D("1.3"), D("1.4")] + + +def test_unique() -> None: + ser = pl.Series([D("1.1"), D("1.1"), D("2.2")]) + uniq = pl.Series([D("1.1"), D("2.2")]) + + assert_series_equal(ser.unique(maintain_order=False), uniq, check_order=False) + assert_series_equal(ser.unique(maintain_order=True), uniq) + assert ser.n_unique() == 2 + assert ser.arg_unique().to_list() == [0, 2] + + +def test_groupby_agg_single_element_11232() -> None: + data = {"g": [-1], "decimal": [-1]} + schema = {"g": pl.Int64(), "decimal": pl.Decimal(38, 0)} + result = ( + pl.LazyFrame(data, schema=schema) + .group_by("g", maintain_order=True) + .agg(pl.col("decimal").min()) + .collect() + ) + expected = pl.DataFrame(data, schema=schema) + assert_frame_equal(result, expected) + + +def test_decimal_from_large_ints_9084() -> None: + numbers = [2963091539321097135000000000, 25658709114149718824803874] + s = pl.Series(numbers, dtype=pl.Decimal) + assert s.to_list() == [D(n) for n in numbers] + + +def test_cast_float_to_decimal_12775() -> None: + s = pl.Series([1.5]) + # default scale = 0 + assert s.cast(pl.Decimal).to_list() == [D("1")] + assert s.cast(pl.Decimal(scale=1)).to_list() == [D("1.5")] + + +def test_decimal_min_over_21096() -> None: + df = pl.Series("x", [1, 2], pl.Decimal(scale=2)).to_frame() + result = df.select(pl.col("x").min().over("x")) + assert result["x"].to_list() == [D("1.00"), D("2.00")] diff --git a/py-polars/tests/unit/datatypes/test_duration.py b/py-polars/tests/unit/datatypes/test_duration.py new file mode 100644 index 000000000000..675885d5c0b7 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_duration.py @@ -0,0 +1,224 @@ +from datetime import timedelta +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_duration_cum_sum() -> None: + df = pl.DataFrame({"A": [timedelta(days=1), timedelta(days=2)]}) + + assert df.select(pl.col("A").cum_sum()).to_dict(as_series=False) == { + "A": [timedelta(days=1), timedelta(days=3)] + } + assert df.schema["A"].is_(pl.Duration(time_unit="us")) + for duration_dtype in ( + pl.Duration, + pl.Duration(time_unit="ms"), + pl.Duration(time_unit="ns"), + ): + assert df.schema["A"].is_(duration_dtype) is False + + +def test_duration_to_string() -> None: + df = pl.DataFrame( + { + "td": [ + timedelta(days=180, seconds=56789, microseconds=987654), + timedelta(days=0, seconds=64875, microseconds=8884), + timedelta(days=2, hours=23, seconds=4975, milliseconds=1), + timedelta(hours=1, seconds=1, milliseconds=1, microseconds=1), + timedelta(seconds=-42, milliseconds=-42), + None, + ] + }, + schema={"td": pl.Duration("us")}, + ) + + df_str = df.select( + td_ms=pl.col("td").cast(pl.Duration("ms")), + td_int=pl.col("td").cast(pl.Int64), + td_str_iso=pl.col("td").dt.to_string(), + td_str_pl=pl.col("td").dt.to_string("polars"), + ) + assert df_str.schema == { + "td_ms": pl.Duration(time_unit="ms"), + "td_int": pl.Int64, + "td_str_iso": pl.String, + "td_str_pl": pl.String, + } + + expected = pl.DataFrame( + { + "td_ms": [ + timedelta(days=180, seconds=56789, milliseconds=987), + timedelta(days=0, seconds=64875, milliseconds=8), + timedelta(days=2, hours=23, seconds=4975, milliseconds=1), + timedelta(hours=1, seconds=1, milliseconds=1), + timedelta(seconds=-42, milliseconds=-42), + None, + ], + "td_int": [ + 15608789987654, + 64875008884, + 260575001000, + 3601001001, + -42042000, + None, + ], + "td_str_iso": [ + "P180DT15H46M29.987654S", + "PT18H1M15.008884S", + "P3DT22M55.001S", + "PT1H1.001001S", + "-PT42.042S", + None, + ], + "td_str_pl": [ + "180d 15h 46m 29s 987654µs", + "18h 1m 15s 8884µs", + "3d 22m 55s 1ms", + "1h 1s 1001µs", + "-42s -42ms", + None, + ], + }, + schema_overrides={"td_ms": pl.Duration(time_unit="ms")}, + ) + assert_frame_equal(expected, df_str) + + # individual +/- parts + df = pl.DataFrame( + { + "td_ns": [ + timedelta(weeks=1), + timedelta(days=1), + timedelta(hours=1), + timedelta(minutes=1), + timedelta(seconds=1), + timedelta(milliseconds=1), + timedelta(microseconds=1), + timedelta(seconds=0), + timedelta(microseconds=-1), + timedelta(milliseconds=-1), + timedelta(seconds=-1), + timedelta(minutes=-1), + timedelta(hours=-1), + timedelta(days=-1), + timedelta(weeks=-1), + ] + }, + schema={"td_ns": pl.Duration("ns")}, + ) + df_str = df.select(pl.col("td_ns").dt.to_string("iso")) + assert df_str["td_ns"].to_list() == [ + "P7D", + "P1D", + "PT1H", + "PT1M", + "PT1S", + "PT0.001S", + "PT0.000001S", + "PT0S", + "-PT0.000001S", + "-PT0.001S", + "-PT1S", + "-PT1M", + "-PT1H", + "-P1D", + "-P7D", + ] + + +def test_duration_std_var() -> None: + df = pl.DataFrame( + {"duration": [1000, 5000, 3000]}, schema={"duration": pl.Duration} + ) + + result = df.select( + pl.col("duration").var().name.suffix("_var"), + pl.col("duration").std().name.suffix("_std"), + ) + + expected = pl.DataFrame( + [ + pl.Series( + "duration_var", + [timedelta(microseconds=4000)], + dtype=pl.Duration(time_unit="ms"), + ), + pl.Series( + "duration_std", + [timedelta(microseconds=2000)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_series_duration_std_var() -> None: + s = pl.Series([timedelta(days=1), timedelta(days=2), timedelta(days=4)]) + assert s.std() == timedelta(days=1, seconds=45578, microseconds=180014) + assert s.var() == timedelta(days=201600000) + + +def test_series_duration_var_overflow() -> None: + s = pl.Series([timedelta(days=10), timedelta(days=20), timedelta(days=40)]) + with pytest.raises(OverflowError): + s.var() + + +@pytest.mark.parametrize("other", [24, pl.Series([24])]) +def test_series_duration_div_multiply(other: Any) -> None: + s = pl.Series([timedelta(hours=1)]) + assert (s * other).to_list() == [timedelta(days=1)] + assert (other * s).to_list() == [timedelta(days=1)] + assert (s / other).to_list() == [timedelta(minutes=2, seconds=30)] + + +def test_series_duration_units() -> None: + td = timedelta + + assert_frame_equal( + pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(weeks=pl.col("x"))), + pl.DataFrame({"x": [td(weeks=i) for i in range(4)]}), + ) + assert_frame_equal( + pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(days=pl.col("x"))), + pl.DataFrame({"x": [td(days=i) for i in range(4)]}), + ) + assert_frame_equal( + pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(hours=pl.col("x"))), + pl.DataFrame({"x": [td(hours=i) for i in range(4)]}), + ) + assert_frame_equal( + pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(minutes=pl.col("x"))), + pl.DataFrame({"x": [td(minutes=i) for i in range(4)]}), + ) + assert_frame_equal( + pl.DataFrame({"x": [0, 1, 2, 3]}).select( + x=pl.duration(milliseconds=pl.col("x")) + ), + pl.DataFrame({"x": [td(milliseconds=i) for i in range(4)]}), + ) + assert_frame_equal( + pl.DataFrame({"x": [0, 1, 2, 3]}).select( + x=pl.duration(microseconds=pl.col("x")) + ), + pl.DataFrame({"x": [td(microseconds=i) for i in range(4)]}), + ) + + +def test_comparison_with_string_raises_9461() -> None: + df = pl.DataFrame({"duration": [timedelta(hours=2)]}) + with pytest.raises(pl.exceptions.InvalidOperationError): + df.filter(pl.col("duration") > "1h") + + +def test_duration_invalid_cast_22258() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.select(a=pl.duration(days=[1, 2, 3, 4])) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py new file mode 100644 index 000000000000..d19b3c8945fd --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -0,0 +1,725 @@ +# mypy: disable-error-code="redundant-expr" +from __future__ import annotations + +import enum +import io +import operator +import re +import sys +from datetime import date +from textwrap import dedent +from typing import Any, Callable + +import pytest + +import polars as pl +from polars import StringCache +from polars.exceptions import ( + ComputeError, + InvalidOperationError, + OutOfBoundsError, + SchemaError, +) +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES + +if sys.version_info >= (3, 11): + from enum import StrEnum + + PyStrEnum: type[enum.Enum] | None = StrEnum +else: + PyStrEnum = None + + +def test_enum_creation() -> None: + dtype = pl.Enum(["a", "b"]) + s = pl.Series([None, "a", "b"], dtype=dtype) + assert s.null_count() == 1 + assert s.len() == 3 + assert s.dtype == dtype + + # from iterables + e = pl.Enum(f"x{i}" for i in range(5)) + assert e.categories.to_list() == ["x0", "x1", "x2", "x3", "x4"] + + e = pl.Enum("abcde") + assert e.categories.to_list() == ["a", "b", "c", "d", "e"] + + +@pytest.mark.parametrize("categories", [[], pl.Series("foo", dtype=pl.Int16), None]) +def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None: + dtype = pl.Enum(categories) # type: ignore[arg-type] + expected = pl.Series("category", dtype=pl.String) + assert_series_equal(dtype.categories, expected) + + +def test_enum_init_from_python() -> None: + # standard string enum + class Color1(str, enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + dtype = pl.Enum(Color1) + assert dtype == pl.Enum(["red", "green", "blue"]) + + # standard generic enum + class Color2(enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + dtype = pl.Enum(Color2) + assert dtype == pl.Enum(["red", "green", "blue"]) + + # specialised string enum + if sys.version_info >= (3, 11): + + class Color3(enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + dtype = pl.Enum(Color3) + assert dtype == pl.Enum(["red", "green", "blue"]) + + +def test_enum_init_from_python_invalid() -> None: + class Color(int, enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + with pytest.raises( + TypeError, + match="Enum categories must be strings", + ): + pl.Enum(Color) + + # flag/int enums + for EnumBase in (enum.Flag, enum.IntFlag, enum.IntEnum): + + class Color(EnumBase): # type: ignore[no-redef,misc,valid-type] + RED = enum.auto() + GREEN = enum.auto() + BLUE = enum.auto() + + with pytest.raises( + TypeError, + match="Enum categories must be strings; `Color` values are integers", + ): + pl.Enum(Color) + + +def test_enum_non_existent() -> None: + with pytest.raises( + InvalidOperationError, + match="conversion from `str` to `enum` failed in column '' for 1 out of 4 values: \\[\"c\"\\]", + ): + pl.Series([None, "a", "b", "c"], dtype=pl.Enum(categories=["a", "b"])) + + +def test_enum_non_existent_non_strict() -> None: + s = pl.Series( + [None, "a", "b", "c"], dtype=pl.Enum(categories=["a", "b"]), strict=False + ) + expected = pl.Series([None, "a", "b", None], dtype=pl.Enum(categories=["a", "b"])) + assert_series_equal(s, expected) + + +def test_enum_from_schema_argument() -> None: + df = pl.DataFrame( + {"col1": ["a", "b", "c"]}, schema={"col1": pl.Enum(["a", "b", "c"])} + ) + assert df.get_column("col1").dtype == pl.Enum + assert dedent( + """ + │ col1 │ + │ --- │ + │ enum │ + ╞══════╡ + """ + ) in str(df) + + +def test_equality_of_two_separately_constructed_enums() -> None: + s = pl.Series([None, "a", "b"], dtype=pl.Enum(categories=["a", "b"])) + s2 = pl.Series([None, "a", "b"], dtype=pl.Enum(categories=["a", "b"])) + assert_series_equal(s, s2) + + +def test_nested_enum_creation() -> None: + dtype = pl.List(pl.Enum(["a", "b", "c"])) + s = pl.Series([[None, "a"], ["b", "c"]], dtype=dtype) + assert s.len() == 2 + assert s.dtype == dtype + + +def test_enum_union() -> None: + e1 = pl.Enum(["a", "b"]) + e2 = pl.Enum(["b", "c"]) + assert e1 | e2 == pl.Enum(["a", "b", "c"]) + assert e1.union(e2) == pl.Enum(["a", "b", "c"]) + + +def test_nested_enum_concat() -> None: + dtype = pl.List(pl.Enum(["a", "b", "c", "d"])) + s1 = pl.Series([[None, "a"], ["b", "c"]], dtype=dtype) + s2 = pl.Series([["c", "d"], ["a", None]], dtype=dtype) + expected = pl.Series( + [ + [None, "a"], + ["b", "c"], + ["c", "d"], + ["a", None], + ], + dtype=dtype, + ) + + assert_series_equal(pl.concat((s1, s2)), expected) + assert_series_equal(s1.extend(s2), expected) + + +def test_nested_enum_agg_sort_18026() -> None: + df = ( + pl.DataFrame({"a": [1, 1, 2, 2], "b": ["Y", "Z", "Z", "Y"]}) + .cast({"b": pl.Enum(["Z", "Y"])}) + .with_columns(pl.struct("b", "a").alias("c")) + ) + result = df.group_by("a").agg("c").sort("a") + expected = pl.DataFrame( + { + "a": [1, 2], + "c": [ + [{"b": "Y", "a": 1}, {"b": "Z", "a": 1}], + [{"b": "Z", "a": 2}, {"b": "Y", "a": 2}], + ], + }, + schema={ + "a": pl.Int64, + "c": pl.List( + pl.Struct([pl.Field("b", pl.Enum(["Z", "Y"])), pl.Field("a", pl.Int64)]) + ), + }, + ) + assert_frame_equal(result, expected) + + +def test_casting_to_an_enum_from_utf() -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, "a", "b", "c"]) + s2 = s.cast(dtype) + assert s2.dtype == dtype + assert s2.null_count() == 1 + + +def test_casting_to_an_enum_from_categorical() -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical) + s2 = s.cast(dtype) + assert s2.dtype == dtype + assert s2.null_count() == 1 + expected = pl.Series([None, "a", "b", "c"], dtype=dtype) + assert_series_equal(s2, expected) + + +def test_casting_to_an_enum_from_categorical_nonstrict() -> None: + dtype = pl.Enum(["a", "b"]) + s = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical) + s2 = s.cast(dtype, strict=False) + assert s2.dtype == dtype + assert s2.null_count() == 2 # "c" mapped to null + expected = pl.Series([None, "a", "b", None], dtype=dtype) + assert_series_equal(s2, expected) + + +def test_casting_to_an_enum_from_enum_nonstrict() -> None: + dtype = pl.Enum(["a", "b"]) + s = pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s2 = s.cast(dtype, strict=False) + assert s2.dtype == dtype + assert s2.null_count() == 2 # "c" mapped to null + expected = pl.Series([None, "a", "b", None], dtype=dtype) + assert_series_equal(s2, expected) + + +def test_casting_to_an_enum_from_integer() -> None: + dtype = pl.Enum(["a", "b", "c"]) + expected = pl.Series([None, "b", "a", "c"], dtype=dtype) + s = pl.Series([None, 1, 0, 2], dtype=pl.UInt32) + s_enum = s.cast(dtype) + assert s_enum.dtype == dtype + assert s_enum.null_count() == 1 + assert_series_equal(s_enum, expected) + + +def test_casting_to_an_enum_oob_from_integer() -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, 1, 0, 5], dtype=pl.UInt32) + with pytest.raises( + OutOfBoundsError, match=("index 5 is bigger than the number of categories 3") + ): + s.cast(dtype) + + +def test_casting_to_an_enum_from_categorical_nonexistent() -> None: + with pytest.raises( + InvalidOperationError, + match=( + r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]" + ), + ): + pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"])) + + +@StringCache() +def test_casting_to_an_enum_from_global_categorical() -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical) + s2 = s.cast(dtype) + assert s2.dtype == dtype + assert s2.null_count() == 1 + expected = pl.Series([None, "a", "b", "c"], dtype=dtype) + assert_series_equal(s2, expected) + + +@StringCache() +def test_casting_to_an_enum_from_global_categorical_nonexistent() -> None: + with pytest.raises( + InvalidOperationError, + match=( + r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]" + ), + ): + pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"])) + + +def test_casting_from_an_enum_to_local() -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, "a", "b", "c"], dtype=dtype) + s2 = s.cast(pl.Categorical) + expected = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical) + assert_series_equal(s2, expected) + + +@StringCache() +def test_casting_from_an_enum_to_global() -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, "a", "b", "c"], dtype=dtype) + s2 = s.cast(pl.Categorical) + expected = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical) + assert_series_equal(s2, expected) + + +def test_append_to_an_enum() -> None: + s = pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s2 = pl.Series(["c", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s.append(s2) + assert s.len() == 8 + + +def test_append_to_an_enum_with_new_category() -> None: + with pytest.raises( + SchemaError, + match=("type Enum.*is incompatible with expected type Enum.*"), + ): + pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])).append( + pl.Series(["d", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) + ) + + +def test_extend_to_an_enum() -> None: + s = pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s2 = pl.Series(["c", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s.extend(s2) + assert s.len() == 8 + assert s.null_count() == 1 + + +def test_series_init_uninstantiated_enum() -> None: + with pytest.raises( + InvalidOperationError, + match="cannot cast / initialize Enum without categories present", + ): + pl.Series(["a", "b", "a"], dtype=pl.Enum) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, True, True, True])), + (operator.lt, pl.Series([None, True, False, False])), + (operator.ge, pl.Series([None, False, True, True])), + (operator.gt, pl.Series([None, False, False, False])), + (operator.eq, pl.Series([None, False, True, True])), + (operator.ne, pl.Series([None, True, False, False])), + (pl.Series.ne_missing, pl.Series([False, True, False, False])), + (pl.Series.eq_missing, pl.Series([True, False, True, True])), + ], +) +def test_equality_enum( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + dtype = pl.Enum(["a", "b", "c"]) + s = pl.Series([None, "a", "b", "c"], dtype=dtype) + s2 = pl.Series([None, "c", "b", "c"], dtype=dtype) + + assert_series_equal(op(s, s2), expected) + assert_series_equal(op(s, s2.cast(pl.String)), expected) + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (operator.le, pl.Series([None, False, True, True])), + (operator.lt, pl.Series([None, False, False, True])), + (operator.ge, pl.Series([None, True, True, False])), + (operator.gt, pl.Series([None, True, False, False])), + (operator.eq, pl.Series([None, False, True, False])), + (operator.ne, pl.Series([None, True, False, True])), + (pl.Series.ne_missing, pl.Series([True, True, False, True])), + (pl.Series.eq_missing, pl.Series([False, False, True, False])), + ], +) +def test_compare_enum_str_single( + op: Callable[[pl.Series, pl.Series], pl.Series], expected: pl.Series +) -> None: + s = pl.Series( + [None, "HIGH", "MEDIUM", "LOW"], dtype=pl.Enum(["LOW", "MEDIUM", "HIGH"]) + ) + s2 = "MEDIUM" + + assert_series_equal(op(s, s2), expected) # type: ignore[arg-type] + + +def test_equality_missing_enum_scalar() -> None: + dtype = pl.Enum(["a", "b", "c"]) + df = pl.DataFrame({"a": pl.Series([None, "a", "b", "c"], dtype=dtype)}) + + out = df.select( + pl.col("a").eq_missing(pl.lit("c", dtype=dtype)).alias("cmp") + ).get_column("cmp") + expected = pl.Series("cmp", [False, False, False, True], dtype=pl.Boolean) + assert_series_equal(out, expected) + + out_str = df.select(pl.col("a").eq_missing(pl.lit("c")).alias("cmp")).get_column( + "cmp" + ) + assert_series_equal(out_str, expected) + + out = df.select( + pl.col("a").ne_missing(pl.lit("c", dtype=dtype)).alias("cmp") + ).get_column("cmp") + expected = pl.Series("cmp", [True, True, True, False], dtype=pl.Boolean) + assert_series_equal(out, expected) + + out_str = df.select(pl.col("a").ne_missing(pl.lit("c")).alias("cmp")).get_column( + "cmp" + ) + assert_series_equal(out_str, expected) + + +def test_equality_missing_enum_none_scalar() -> None: + dtype = pl.Enum(["a", "b", "c"]) + df = pl.DataFrame({"a": pl.Series([None, "a", "b", "c"], dtype=dtype)}) + + out = df.select( + pl.col("a").eq_missing(pl.lit(None, dtype=dtype)).alias("cmp") + ).get_column("cmp") + expected = pl.Series("cmp", [True, False, False, False], dtype=pl.Boolean) + assert_series_equal(out, expected) + + out = df.select( + pl.col("a").ne_missing(pl.lit(None, dtype=dtype)).alias("cmp") + ).get_column("cmp") + expected = pl.Series("cmp", [False, True, True, True], dtype=pl.Boolean) + assert_series_equal(out, expected) + + +@pytest.mark.parametrize(("op"), [operator.le, operator.lt, operator.ge, operator.gt]) +def test_compare_enum_str_single_raise( + op: Callable[[pl.Series, pl.Series], pl.Series], +) -> None: + s = pl.Series( + [None, "HIGH", "MEDIUM", "LOW"], dtype=pl.Enum(["LOW", "MEDIUM", "HIGH"]) + ) + s2 = "NOTEXIST" + + with pytest.raises( + InvalidOperationError, + match=re.escape( + "conversion from `str` to `enum` failed in column '' for 1 out of 1 values: [\"NOTEXIST\"]" + ), + ): + op(s, s2) # type: ignore[arg-type] + + +def test_compare_enum_str_raise() -> None: + s = pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s2 = pl.Series([None, "d", "d", "d"]) + s_broadcast = pl.Series(["d"]) + + for s_compare in [s2, s_broadcast]: + for op in [operator.le, operator.gt, operator.ge, operator.lt]: + with pytest.raises( + InvalidOperationError, + match="conversion from `str` to `enum` failed in column", + ): + op(s, s_compare) + + +def test_different_enum_comparison_order() -> None: + df_enum = pl.DataFrame( + [ + pl.Series( + "a_cat", ["c", "a", "b", "c", "b"], dtype=pl.Enum(["a", "b", "c"]) + ), + pl.Series( + "b_cat", ["F", "G", "E", "G", "G"], dtype=pl.Enum(["F", "G", "E"]) + ), + ] + ) + for op in [operator.gt, operator.ge, operator.lt, operator.le]: + with pytest.raises( + ComputeError, + match="can only compare categoricals of the same type", + ): + df_enum.filter(op(pl.col("a_cat"), pl.col("b_cat"))) + + +@pytest.mark.parametrize("categories", [[None], ["x", "y", None]]) +def test_enum_categories_null(categories: list[str | None]) -> None: + with pytest.raises(TypeError, match="Enum categories must not contain null values"): + pl.Enum(categories) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("categories", "type"), [([date.today()], "Date"), ([-10, 10], "Int64")] +) +def test_valid_enum_category_types(categories: Any, type: str) -> None: + with pytest.raises( + TypeError, match=f"Enum categories must be strings; found data of type {type}" + ): + pl.Enum(categories) + + +def test_enum_categories_unique() -> None: + with pytest.raises(ValueError, match="must be unique; found duplicate 'a'"): + pl.Enum(["a", "a", "b", "b", "b", "c"]) + + +def test_enum_categories_series_input() -> None: + categories = pl.Series("a", ["a", "b", "c"]) + dtype = pl.Enum(categories) + assert_series_equal(dtype.categories, categories.alias("category")) + + +def test_enum_categories_series_zero_copy() -> None: + categories = pl.Series(["a", "b"]) + dtype = pl.Enum(categories) + + s = pl.Series([None, "a", "b"], dtype=dtype) + result_dtype = s.dtype + + assert result_dtype == dtype + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_enum_cast_from_other_integer_dtype(dtype: pl.DataType) -> None: + enum_dtype = pl.Enum(["a", "b", "c", "d"]) + series = pl.Series([1, 2, 3, 3, 2, 1], dtype=dtype) + series.cast(enum_dtype) + + +def test_enum_cast_from_other_integer_dtype_oob() -> None: + enum_dtype = pl.Enum(["a", "b", "c", "d"]) + series = pl.Series([-1, 2, 3, 3, 2, 1], dtype=pl.Int8) + with pytest.raises( + InvalidOperationError, match="conversion from `i8` to `u32` failed in column" + ): + series.cast(enum_dtype) + + series = pl.Series([2**34, 2, 3, 3, 2, 1], dtype=pl.UInt64) + with pytest.raises( + InvalidOperationError, + match="conversion from `u64` to `u32` failed in column", + ): + series.cast(enum_dtype) + + +def test_enum_creating_col_expr() -> None: + df = pl.DataFrame( + { + "col1": ["a", "b", "c"], + "col2": ["d", "e", "f"], + "col3": ["g", "h", "i"], + }, + schema={ + "col1": pl.Enum(["a", "b", "c"]), + "col2": pl.Categorical(), + "col3": pl.Enum(["g", "h", "i"]), + }, + ) + + out = df.select(pl.col(pl.Enum)) + expected = df.select("col1", "col3") + assert_frame_equal(out, expected) + + +def test_enum_cse_eq() -> None: + df = pl.DataFrame({"a": [1]}) + + # these both share the value "a", which is used in both expressions + dt1 = pl.Enum(["a", "b"]) + dt2 = pl.Enum(["a", "c"]) + + out = ( + df.lazy() + .select( + pl.when(True).then(pl.lit("a", dtype=dt1)).alias("dt1"), + pl.when(True).then(pl.lit("a", dtype=dt2)).alias("dt2"), + ) + .collect() + ) + + assert out["dt1"].item() == "a" + assert out["dt2"].item() == "a" + assert out["dt1"].dtype == pl.Enum(["a", "b"]) + assert out["dt2"].dtype == pl.Enum(["a", "c"]) + assert out["dt1"].dtype != out["dt2"].dtype + + +def test_category_comparison_subset() -> None: + dt1 = pl.Enum(["a"]) + dt2 = pl.Enum(["a", "b"]) + out = ( + pl.LazyFrame() + .select( + pl.lit("a", dtype=dt1).alias("dt1"), + pl.lit("a", dtype=dt2).alias("dt2"), + ) + .collect() + ) + + assert out["dt1"].item() == "a" + assert out["dt2"].item() == "a" + assert out["dt1"].dtype == pl.Enum(["a"]) + assert out["dt2"].dtype == pl.Enum(["a", "b"]) + assert out["dt1"].dtype != out["dt2"].dtype + + +@pytest.mark.parametrize("dt", INTEGER_DTYPES) +def test_integer_cast_to_enum_15738(dt: pl.DataType) -> None: + s = pl.Series([0, 1, 2], dtype=dt).cast(pl.Enum(["a", "b", "c"])) + assert s.to_list() == ["a", "b", "c"] + expected_s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + assert_series_equal(s, expected_s) + + +def test_enum_19269() -> None: + en = pl.Enum(["X", "Z", "Y"]) + df = pl.DataFrame( + {"test": pl.Series(["X", "Y", "Z"], dtype=en), "group": [1, 2, 2]} + ) + out = ( + df.group_by("group", maintain_order=True) + .agg(pl.col("test").mode()) + .select( + a=pl.col("test").list.max(), + b=pl.col("test").list.min(), + ) + ) + + assert out.to_dict(as_series=False) == {"a": ["X", "Y"], "b": ["X", "Z"]} + assert out.dtypes == [en, en] + + +def test_roundtrip_enum_parquet() -> None: + dtype = pl.Enum(["foo", "bar", "ham"]) + df = pl.DataFrame(pl.Series("d", ["foo", "bar"]), schema=pl.Schema({"d": dtype})) + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + assert pl.scan_parquet(f).collect_schema()["d"] == dtype + + +@pytest.mark.parametrize( + "EnumBase", + [ + (enum.Enum,), + (str, enum.Enum), + *([(PyStrEnum,)] if PyStrEnum is not None else []), + ], +) +def test_init_frame_from_enums(EnumBase: tuple[type, ...]) -> None: + class Portfolio(*EnumBase): # type: ignore[misc] + TECH = "Technology" + RETAIL = "Retail" + OTHER = "Other" + + # confirm that we can infer the enum dtype from various enum bases + df = pl.DataFrame( + {"trade_id": [123, 456], "portfolio": [Portfolio.OTHER, Portfolio.TECH]} + ) + expected = pl.DataFrame( + {"trade_id": [123, 456], "portfolio": ["Other", "Technology"]}, + schema={ + "trade_id": pl.Int64, + "portfolio": pl.Enum(["Technology", "Retail", "Other"]), + }, + ) + assert_frame_equal(expected, df) + + # if schema indicates string, ensure we do *not* convert to enum + df = pl.DataFrame( + { + "trade_id": [123, 456, 789], + "portfolio": [Portfolio.OTHER, Portfolio.TECH, Portfolio.RETAIL], + }, + schema_overrides={"portfolio": pl.String}, + ) + assert df.schema == {"trade_id": pl.Int64, "portfolio": pl.String} + + +@pytest.mark.parametrize( + "EnumBase", + [ + (enum.Enum,), + (enum.Flag,), + (enum.IntEnum,), + (enum.IntFlag,), + (int, enum.Enum), + ], +) +def test_init_series_from_int_enum(EnumBase: tuple[type, ...]) -> None: + # note: we do not support integer enums as polars enums, + # but we should be able to load the values + + class Number(*EnumBase): # type: ignore[misc] + ONE = 1 + TWO = 2 + FOUR = 4 + EIGHT = 8 + + s = pl.Series(values=[Number.EIGHT, Number.TWO, Number.FOUR]) + + expected = pl.Series(values=[8, 2, 4], dtype=pl.Int64) + assert_series_equal(expected, s) + + +@pytest.mark.may_fail_auto_streaming +def test_read_enum_from_csv() -> None: + df = pl.DataFrame( + { + "foo": ["ham", "spam", None, "and", "such"], + "bar": ["ham", "spam", None, "and", "such"], + } + ) + f = io.BytesIO() + df.write_csv(f) + f.seek(0) + + schema = {"foo": pl.Enum(["ham", "and", "such", "spam"]), "bar": pl.String()} + read = pl.read_csv(f, schema=schema) + assert read.schema == schema + assert_frame_equal(df.cast(schema), read) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/datatypes/test_float.py b/py-polars/tests/unit/datatypes/test_float.py new file mode 100644 index 000000000000..6d9b06372ddc --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_float.py @@ -0,0 +1,311 @@ +import pyarrow as pa +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_nan_in_group_by_agg() -> None: + df = pl.DataFrame( + { + "key": ["a", "a", "a", "a"], + "value": [18.58, 18.78, float("nan"), 18.63], + "bar": [0, 0, 0, 0], + } + ) + + assert df.group_by("bar", "key").agg(pl.col("value").max())["value"].item() == 18.78 + assert df.group_by("bar", "key").agg(pl.col("value").min())["value"].item() == 18.58 + + +def test_nan_aggregations() -> None: + df = pl.DataFrame({"a": [1.0, float("nan"), 2.0, 3.0], "b": [1, 1, 1, 1]}) + + aggs = [ + pl.col("a").max().alias("max"), + pl.col("a").min().alias("min"), + pl.col("a").nan_max().alias("nan_max"), + pl.col("a").nan_min().alias("nan_min"), + ] + + assert ( + str(df.select(aggs).to_dict(as_series=False)) + == "{'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}" + ) + assert ( + str(df.group_by("b").agg(aggs).to_dict(as_series=False)) + == "{'b': [1], 'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}" + ) + + +@pytest.mark.parametrize("descending", [True, False]) +def test_sorted_nan_max_12931(descending: bool) -> None: + s = pl.Series("x", [1.0, 2.0, float("nan")]).sort(descending=descending) + + assert s.max() == 2.0 + assert s.arg_max() == 1 + + # Test full-nan + s = pl.Series("x", [float("nan"), float("nan"), float("nan")]).sort( + descending=descending + ) + + out = s.max() + assert out != out + + # This is flipped because float arg_max calculates the index as + # * sorted ascending: (index of left-most NaN) - 1, saturating subtraction at 0 + # * sorted descending: (index of right-most NaN) + 1, saturating addition at s.len() + assert s.arg_max() == (0, 2)[descending] + + s = pl.Series("x", [1.0, 2.0, 3.0]).sort(descending=descending) + + assert s.max() == 3.0 + assert s.arg_max() == (2, 0)[descending] + + +@pytest.mark.parametrize( + ("s", "expect"), + [ + ( + pl.Series( + "x", + [ + -0.0, + float("-nan"), + 0.0, + None, + 1.0, + float("nan"), + ], + ), + pl.Series("x", [0.0, float("nan"), None, 1.0]), + ), + ( + # No nulls + pl.Series( + "x", + [ + -0.0, + float("-nan"), + 0.0, + 1.0, + float("nan"), + ], + ), + pl.Series("x", [0.0, float("nan"), 1.0]), + ), + ], +) +def test_unique(s: pl.Series, expect: pl.Series) -> None: + out = s.unique(maintain_order=False) + assert_series_equal(expect, out, check_order=False) + + out = s.unique(maintain_order=True) + assert_series_equal(expect, out) + + out = s.n_unique() # type: ignore[assignment] + assert expect.len() == out + + out = s.gather(s.arg_unique()) + assert_series_equal(expect, out) + + +def test_unique_counts() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + expect = pl.Series("x", [2, 2, 1, 1], dtype=pl.UInt32) + out = s.unique_counts() + assert_series_equal(expect, out) + + +def test_hash() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ).hash() + + # check them against each other since hash is not stable + assert s.item(0) == s.item(1) # hash(-0.0) == hash(0.0) + assert s.item(2) == s.item(3) # hash(float('-nan')) == hash(float('nan')) + + +def test_group_by_float() -> None: + # Test num_groups_proxy + # * -0.0 and 0.0 in same groups + # * -nan and nan in same groups + df = ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + .to_frame() + .with_row_index() + .with_columns(a=pl.lit("a")) + ) + + expect = pl.Series("index", [[0, 1], [2, 3], [4], [5]], dtype=pl.List(pl.UInt32)) + expect_no_null = expect.head(3) + + for group_keys in (("x",), ("x", "a")): + for maintain_order in (True, False): + for drop_nulls in (True, False): + out = df + if drop_nulls: + out = out.drop_nulls() + + out = ( + out.group_by(group_keys, maintain_order=maintain_order) # type: ignore[assignment] + .agg("index") + .sort(pl.col("index").list.get(0)) + .select("index") + .to_series() + ) + + if drop_nulls: + assert_series_equal(expect_no_null, out) # type: ignore[arg-type] + else: + assert_series_equal(expect, out) # type: ignore[arg-type] + + +def test_joins() -> None: + # Test that -0.0 joins with 0.0 and nan joins with nan + df = ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + .to_frame() + .with_row_index() + .with_columns(a=pl.lit("a")) + ) + + rhs = ( + pl.Series("x", [0.0, float("nan"), 3.0]) + .to_frame() + .with_columns(a=pl.lit("a"), rhs=True) + ) + + for join_on in ( + # Single and multiple keys + ("x",), + ( + "x", + "a", + ), + ): + how = "left" + expect = pl.Series("rhs", [True, True, True, True, None, None]) + out = df.join(rhs, on=join_on, how=how).sort("index").select("rhs").to_series() # type: ignore[arg-type] + assert_series_equal(expect, out) + + how = "inner" + expect = pl.Series("index", [0, 1, 2, 3], dtype=pl.UInt32) + out = ( + df.join(rhs, on=join_on, how=how).sort("index").select("index").to_series() # type: ignore[arg-type] + ) + assert_series_equal(expect, out) + + how = "full" + expect = pl.Series("rhs", [True, True, True, True, None, None, True]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("rhs") + .to_series() + ) + assert_series_equal(expect, out) + + how = "semi" + expect = pl.Series("x", [-0.0, 0.0, float("-nan"), float("nan")]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("x") + .to_series() + ) + assert_series_equal(expect, out) + + how = "anti" + expect = pl.Series("x", [1.0, None]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("x") + .to_series() + ) + assert_series_equal(expect, out) + + # test asof + # note that nans never join because nans are always greater than the other + # side of the comparison (i.e. NaN > tolerance) + expect = pl.Series("rhs", [True, True, None, None, None, None]) + out = ( + df.sort("x") + .join_asof(rhs.sort("x"), on="x", tolerance=0) + .sort("index") + .select("rhs") + .to_series() + ) + assert_series_equal(expect, out) + + +def test_first_last_distinct() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + + assert_series_equal( + pl.Series("x", [True, False, True, False, True, True]), s.is_first_distinct() + ) + + assert_series_equal( + pl.Series("x", [False, True, False, True, True, True]), s.is_last_distinct() + ) + + +def test_arrow_float16_read_empty_20946() -> None: + schema = pa.schema([("float_column", pa.float16())]) + table = pa.table([[]], schema=schema) + + df = pl.from_arrow(table) + assert df.shape == (0, 1) + assert df.schema == pl.Schema([("float_column", pl.Float32)]) # type: ignore[union-attr] diff --git a/py-polars/tests/unit/datatypes/test_integer.py b/py-polars/tests/unit/datatypes/test_integer.py new file mode 100644 index 000000000000..ec649dd0a87e --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_integer.py @@ -0,0 +1,31 @@ +import polars as pl + + +def test_integer_float_functions() -> None: + assert pl.DataFrame({"a": [1, 2]}).select( + finite=pl.all().is_finite(), + infinite=pl.all().is_infinite(), + nan=pl.all().is_nan(), + not_na=pl.all().is_not_nan(), + ).to_dict(as_series=False) == { + "finite": [True, True], + "infinite": [False, False], + "nan": [False, False], + "not_na": [True, True], + } + + +def test_int_negate_operation() -> None: + assert pl.Series([1, 2, 3, 4, 50912341409]).not_().to_list() == [ + -2, + -3, + -4, + -5, + -50912341410, + ] + + +def test_compare_zero_with_uint64_16798() -> None: + df = pl.Series("a", [(1 << 63), 0], dtype=pl.UInt64).to_frame() + assert df.select(pl.col("a") >= 0).item(0, 0) + assert df.select(pl.col("a") == 0).item(0, 0) is False diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py new file mode 100644 index 000000000000..910a1a068a82 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -0,0 +1,885 @@ +from __future__ import annotations + +import pickle +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +import pandas as pd +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import NUMERIC_DTYPES, TEMPORAL_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_dtype() -> None: + # inferred + a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]]) + assert a.dtype == pl.List + assert a.dtype.inner == pl.Int64 # type: ignore[attr-defined] + assert a.dtype.is_(pl.List(pl.Int64)) + + # explicit + u64_max = (2**64) - 1 + df = pl.DataFrame( + data={ + "i": [[1, 2, 3]], + "li": [[[1, 2, 3]]], + "u": [[u64_max]], + "tm": [[time(10, 30, 45)]], + "dt": [[date(2022, 12, 31)]], + "dtm": [[datetime(2022, 12, 31, 1, 2, 3)]], + }, + schema=[ + ("i", pl.List(pl.Int8)), + ("li", pl.List(pl.List(pl.Int8))), + ("u", pl.List(pl.UInt64)), + ("tm", pl.List(pl.Time)), + ("dt", pl.List(pl.Date)), + ("dtm", pl.List(pl.Datetime)), + ], + ) + assert df.schema == { + "i": pl.List(pl.Int8), + "li": pl.List(pl.List(pl.Int8)), + "u": pl.List(pl.UInt64), + "tm": pl.List(pl.Time), + "dt": pl.List(pl.Date), + "dtm": pl.List(pl.Datetime("us")), + } + assert all(tp.is_nested() for tp in df.dtypes) + assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] + assert df.rows() == [ + ( + [1, 2, 3], + [[1, 2, 3]], + [u64_max], + [time(10, 30, 45)], + [date(2022, 12, 31)], + [datetime(2022, 12, 31, 1, 2, 3)], + ) + ] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical() -> None: + # https://github.com/pola-rs/polars/issues/2038 + df = pl.DataFrame( + [ + pl.Series("a", [1, 1, 1, 1, 1, 1, 1, 1]), + pl.Series("b", [8, 2, 3, 6, 3, 6, 2, 2]), + pl.Series("c", ["a", "b", "c", "a", "b", "c", "a", "b"]).cast( + pl.Categorical + ), + ] + ) + out = ( + df.group_by(["a", "b"]) + .agg( + pl.col("c").count().alias("num_different_c"), + pl.col("c").alias("c_values"), + ) + .filter(pl.col("num_different_c") >= 2) + .to_series(3) + ) + + assert out.dtype.inner == pl.Categorical # type: ignore[attr-defined] + assert out.dtype.inner.is_nested() is False # type: ignore[attr-defined] + + +def test_decimal() -> None: + input = [[Decimal("1.23"), Decimal("4.56")], [Decimal("7.89"), Decimal("10.11")]] + s = pl.Series(input) + assert s.dtype == pl.List(pl.Decimal) + assert s.dtype.inner == pl.Decimal # type: ignore[attr-defined] + assert s.dtype.inner.is_nested() is False # type: ignore[attr-defined] + assert s.to_list() == input + + +def test_cast_inner() -> None: + a = pl.Series([[1, 2]]) + for t in [bool, pl.Boolean]: + b = a.cast(pl.List(t)) + assert b.dtype == pl.List(pl.Boolean) + assert b.to_list() == [[True, True]] + + # this creates an inner null type + df = pl.from_pandas(pd.DataFrame(data=[[[]], [[]]], columns=["A"])) + assert ( + df["A"].cast(pl.List(int)).dtype.inner == pl.Int64 # type: ignore[attr-defined] + ) + + +def test_list_empty_group_by_result_3521() -> None: + # Create a left relation where the join column contains a null value. + left = pl.DataFrame( + {"group_by_column": [1], "join_column": [None]}, + schema_overrides={"join_column": pl.Int64}, + ) + + # Create a right relation where there is a column to count distinct on. + right = pl.DataFrame({"join_column": [1], "n_unique_column": [1]}) + + # Calculate n_unique after dropping nulls. + result = ( + left.join(right, on="join_column", how="left") + .group_by("group_by_column") + .agg(pl.col("n_unique_column").drop_nulls()) + ) + expected = {"group_by_column": [1], "n_unique_column": [[]]} + assert result.to_dict(as_series=False) == expected + + +def test_list_fill_null() -> None: + df = pl.DataFrame({"C": [["a", "b", "c"], [], [], ["d", "e"]]}) + assert df.with_columns( + pl.when(pl.col("C").list.len() == 0) + .then(None) + .otherwise(pl.col("C")) + .alias("C") + ).to_series().to_list() == [["a", "b", "c"], None, None, ["d", "e"]] + + +def test_list_fill_select_null() -> None: + assert pl.DataFrame({"a": [None, []]}).select( + pl.when(pl.col("a").list.len() == 0) + .then(None) + .otherwise(pl.col("a")) + .alias("a") + ).to_series().to_list() == [None, None] + + +def test_list_fill_list() -> None: + assert pl.DataFrame({"a": [[1, 2, 3], []]}).select( + pl.when(pl.col("a").list.len() == 0) + .then([5]) + .otherwise(pl.col("a")) + .alias("filled") + ).to_dict(as_series=False) == {"filled": [[1, 2, 3], [5]]} + + +def test_empty_list_construction() -> None: + assert pl.Series([[]]).to_list() == [[]] + + df = pl.DataFrame([{"array": [], "not_array": 1234}], orient="row") + expected = {"array": [[]], "not_array": [1234]} + assert df.to_dict(as_series=False) == expected + + df = pl.DataFrame(schema=[("col", pl.List)]) + assert df.schema == {"col": pl.List(pl.Null)} + assert df.rows() == [] + + +def test_list_hash() -> None: + out = pl.DataFrame({"a": [[1, 2, 3], [3, 4], [1, 2, 3]]}).with_columns( + pl.col("a").hash().alias("b") + ) + assert out.dtypes == [pl.List(pl.Int64), pl.UInt64] + assert out[0, "b"] == out[2, "b"] + + +def test_list_diagonal_concat() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + + df2 = pl.DataFrame({"b": [[1]]}) + + assert pl.concat([df1, df2], how="diagonal").to_dict(as_series=False) == { + "a": [1, 2, None], + "b": [None, None, [1]], + } + + +def test_inner_type_categorical_on_rechunk() -> None: + df = pl.DataFrame({"cats": ["foo", "bar"]}).select( + pl.col(pl.String).cast(pl.Categorical).implode() + ) + + assert pl.concat([df, df], rechunk=True).dtypes == [pl.List(pl.Categorical)] + + +def test_local_categorical_list() -> None: + values = [["a", "b"], ["c"], ["a", "d", "d"]] + s = pl.Series(values, dtype=pl.List(pl.Categorical)) + assert s.dtype == pl.List + assert s.dtype.inner == pl.Categorical # type: ignore[attr-defined] + assert s.to_list() == values + + # Check that underlying physicals match + idx_df = pl.Series([[0, 1], [2], [0, 3, 3]], dtype=pl.List(pl.UInt32)) + assert_series_equal(s.cast(pl.List(pl.UInt32)), idx_df) + + # Check if the categories array does not overlap + assert s.list.explode().cat.get_categories().to_list() == ["a", "b", "c", "d"] + + +def test_group_by_list_column() -> None: + df = ( + pl.DataFrame({"a": ["a", "b", "a"]}) + .with_columns(pl.col("a").cast(pl.Categorical)) + .group_by("a", maintain_order=True) + .agg(pl.col("a").alias("a_list")) + ) + + assert df.group_by("a_list", maintain_order=True).first().to_dict( + as_series=False + ) == { + "a_list": [["a", "a"], ["b"]], + "a": ["a", "b"], + } + + +def test_group_by_multiple_keys_contains_list_column() -> None: + df = ( + pl.DataFrame( + { + "a": ["x", "x", "y", "y"], + "b": [[1, 2], [1, 2], [3, 4, 5], [6]], + "c": [3, 2, 1, 0], + } + ) + .group_by(["a", "b"], maintain_order=True) + .agg(pl.all()) + ) + assert df.to_dict(as_series=False) == { + "a": ["x", "y", "y"], + "b": [[1, 2], [3, 4, 5], [6]], + "c": [[3, 2], [1], [0]], + } + + +def test_fast_explode_flag() -> None: + df1 = pl.DataFrame({"values": [[[1, 2]]]}) + assert df1.clone().vstack(df1)["values"].flags["FAST_EXPLODE"] + + # test take that produces a null in list + df = pl.DataFrame({"a": [1, 2, 1, 3]}) + df_b = pl.DataFrame({"a": [1, 2], "c": [["1", "2", "c"], ["1", "2", "c"]]}) + assert df_b["c"].flags["FAST_EXPLODE"] + + # join produces a null + assert not (df.join(df_b, on=["a"], how="left").select(["c"]))["c"].flags[ + "FAST_EXPLODE" + ] + + +def test_fast_explode_on_list_struct_6208() -> None: + data = [ + { + "label": "l", + "tag": "t", + "ref": 1, + "parents": [{"ref": 1, "tag": "t", "ratio": 62.3}], + }, + {"label": "l", "tag": "t", "ref": 1, "parents": None}, + ] + + df = pl.DataFrame( + data, + schema={ + "label": pl.String, + "tag": pl.String, + "ref": pl.Int64, + "parents": pl.List( + pl.Struct({"ref": pl.Int64, "tag": pl.String, "ratio": pl.Float64}) + ), + }, + ) + + assert not df["parents"].flags["FAST_EXPLODE"] + assert df.explode("parents").to_dict(as_series=False) == { + "label": ["l", "l"], + "tag": ["t", "t"], + "ref": [1, 1], + "parents": [{"ref": 1, "tag": "t", "ratio": 62.3}, None], + } + + +def test_flat_aggregation_to_list_conversion_6918() -> None: + df = pl.DataFrame({"a": [1, 2, 2], "b": [[0, 1], [2, 3], [4, 5]]}) + + assert df.group_by("a", maintain_order=True).agg( + pl.concat_list([pl.col("b").list.get(i).mean().implode() for i in range(2)]) + ).to_dict(as_series=False) == {"a": [1, 2], "b": [[[0.0, 1.0]], [[3.0, 4.0]]]} + + +def test_list_count_matches() -> None: + assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select( + pl.col("listcol").list.count_matches(2).alias("number_of_twos") + ).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]} + + +def test_list_sum_and_dtypes() -> None: + # ensure the dtypes of sum align with normal sum + for dt_in, dt_out in [ + (pl.Int8, pl.Int64), + (pl.Int16, pl.Int64), + (pl.Int32, pl.Int32), + (pl.Int64, pl.Int64), + (pl.UInt8, pl.Int64), + (pl.UInt16, pl.Int64), + (pl.UInt32, pl.UInt32), + (pl.UInt64, pl.UInt64), + ]: + df = pl.DataFrame( + {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, + schema={"a": pl.List(dt_in)}, + ) + + summed = df.explode("a").sum() + assert summed.dtypes == [dt_out] + assert summed.item() == 32 + assert df.select(pl.col("a").list.sum()).dtypes == [dt_out] + + assert df.select(pl.col("a").list.sum()).to_dict(as_series=False) == { + "a": [1, 6, 10, 15] + } + + # include nulls + assert pl.DataFrame( + {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]} + ).select(pl.col("a").list.sum()).to_dict(as_series=False) == { + "a": [1, 6, 10, 15, None] + } + + # Booleans + assert pl.DataFrame( + {"a": [[True], [True, True], [True, False, True], [True, True, True, None]]}, + ).select(pl.col("a").list.sum()).to_dict(as_series=False) == {"a": [1, 2, 2, 3]} + + assert pl.DataFrame( + {"a": [[False], [False, False], [False, False, False]]}, + ).select(pl.col("a").list.sum()).to_dict(as_series=False) == {"a": [0, 0, 0]} + + assert pl.DataFrame( + {"a": [[True], [True, True], [True, True, True]]}, + ).select(pl.col("a").list.sum()).to_dict(as_series=False) == {"a": [1, 2, 3]} + + +def test_list_mean() -> None: + assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}).select( + pl.col("a").list.mean() + ).to_dict(as_series=False) == {"a": [1.0, 2.0, 2.5, 3.0]} + + assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], None]}).select( + pl.col("a").list.mean() + ).to_dict(as_series=False) == {"a": [1.0, 2.0, 2.5, None]} + + +def test_list_all() -> None: + assert pl.DataFrame( + { + "a": [ + [True], + [False], + [True, True], + [True, False], + [False, False], + [None], + [], + ] + } + ).select(pl.col("a").list.all()).to_dict(as_series=False) == { + "a": [True, False, True, False, False, True, True] + } + + +def test_list_any() -> None: + assert pl.DataFrame( + { + "a": [ + [True], + [False], + [True, True], + [True, False], + [False, False], + [None], + [], + ] + } + ).select(pl.col("a").list.any()).to_dict(as_series=False) == { + "a": [True, False, True, True, False, False, False] + } + + +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_list_min_max(dtype: pl.DataType) -> None: + df = pl.DataFrame( + {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, + schema={"a": pl.List(dtype)}, + ) + result = df.select(pl.col("a").list.min()) + expected = df.select(pl.col("a").list.first()) + assert_frame_equal(result, expected) + + result = df.select(pl.col("a").list.max()) + expected = df.select(pl.col("a").list.last()) + assert_frame_equal(result, expected) + + +def test_list_min_max2() -> None: + df = pl.DataFrame( + {"a": [[1], [1, 5, -1, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]}, + ) + assert df.select(pl.col("a").list.min()).to_dict(as_series=False) == { + "a": [1, -1, 1, 1, None] + } + assert df.select(pl.col("a").list.max()).to_dict(as_series=False) == { + "a": [1, 5, 4, 5, None] + } + + +def test_list_mean_fast_path_empty() -> None: + df = pl.DataFrame( + { + "a": [[], [1, 2, 3]], + } + ) + output = df.select(pl.col("a").list.mean()) + assert output.to_dict(as_series=False) == {"a": [None, 2.0]} + + +def test_list_min_max_13978() -> None: + df = pl.DataFrame( + { + "a": [[], [1, 2, 3]], + "b": [[1, 2], None], + "c": [[], [None, 1, 2]], + } + ) + out = df.select( + min_a=pl.col("a").list.min(), + max_a=pl.col("a").list.max(), + min_b=pl.col("b").list.min(), + max_b=pl.col("b").list.max(), + min_c=pl.col("c").list.min(), + max_c=pl.col("c").list.max(), + ) + expected = pl.DataFrame( + { + "min_a": [None, 1], + "max_a": [None, 3], + "min_b": [1, None], + "max_b": [2, None], + "min_c": [None, 1], + "max_c": [None, 2], + } + ) + assert_frame_equal(out, expected) + + +def test_fill_null_empty_list() -> None: + assert pl.Series([["a"], None]).fill_null([]).to_list() == [["a"], []] + + +def test_nested_logical() -> None: + assert pl.select( + pl.lit(pl.Series(["a", "b"], dtype=pl.Categorical)).implode().implode() + ).to_dict(as_series=False) == {"": [[["a", "b"]]]} + + +def test_null_list_construction_and_materialization() -> None: + s = pl.Series([None, []]) + assert s.dtype == pl.List(pl.Null) + assert s.to_list() == [None, []] + + +def test_logical_type_struct_agg_list() -> None: + df = pl.DataFrame( + {"cats": ["Value1", "Value2", "Value1"]}, + schema_overrides={"cats": pl.Categorical}, + ) + out = df.group_by(1).agg(pl.struct("cats")) + assert out.dtypes == [ + pl.Int32, + pl.List(pl.Struct([pl.Field("cats", pl.Categorical)])), + ] + assert out["cats"].to_list() == [ + [{"cats": "Value1"}, {"cats": "Value2"}, {"cats": "Value1"}] + ] + + +def test_logical_parallel_list_collect() -> None: + # this triggers the anonymous builder in par collect + out = ( + pl.DataFrame( + { + "Group": ["GroupA", "GroupA", "GroupA"], + "Values": ["Value1", "Value2", "Value1"], + }, + schema_overrides={"Values": pl.Categorical}, + ) + .group_by("Group") + .agg(pl.col("Values").value_counts(sort=True)) + .explode("Values") + .unnest("Values") + ) + assert out.dtypes == [pl.String, pl.Categorical, pl.UInt32] + assert out.to_dict(as_series=False) == { + "Group": ["GroupA", "GroupA"], + "Values": ["Value1", "Value2"], + "count": [2, 1], + } + + +def test_list_recursive_categorical_cast() -> None: + # go 3 deep, just to show off + dtype = pl.List(pl.List(pl.List(pl.Categorical))) + values = [[[["x"], ["y"]]], [[["x"]]]] + s = pl.Series(values).cast(dtype) + assert s.dtype == dtype + assert s.to_list() == values + + +@pytest.mark.parametrize( + ("data", "expected_data", "dtype"), + [ + ([None, 1, 2], [None, [1], [2]], pl.Int64), + ([None, 1.0, 2.0], [None, [1.0], [2.0]], pl.Float64), + ([None, "x", "y"], [None, ["x"], ["y"]], pl.String), + ([None, True, False], [None, [True], [False]], pl.Boolean), + ], +) +def test_non_nested_cast_to_list( + data: list[Any], expected_data: list[Any], dtype: PolarsDataType +) -> None: + s = pl.Series(data, dtype=dtype) + casted_s = s.cast(pl.List(dtype)) + expected = pl.Series(expected_data, dtype=pl.List(dtype)) + assert_series_equal(casted_s, expected) + + +def test_list_new_from_index_logical() -> None: + s = ( + pl.select(pl.struct(pl.Series("a", [date(2001, 1, 1)])).implode()) + .to_series() + .new_from_index(0, 1) + ) + assert s.dtype == pl.List(pl.Struct([pl.Field("a", pl.Date)])) + assert s.to_list() == [[{"a": date(2001, 1, 1)}]] + + # empty new_from_index # 8420 + dtype = pl.List(pl.Struct({"c": pl.Boolean})) + s = pl.Series("b", values=[[]], dtype=dtype) + s = s.new_from_index(0, 2) + assert s.dtype == dtype + assert s.to_list() == [[], []] + + +def test_list_recursive_time_unit_cast() -> None: + values = [[datetime(2000, 1, 1, 0, 0, 0)]] + dtype = pl.List(pl.Datetime("ns")) + s = pl.Series(values) + out = s.cast(dtype) + assert out.dtype == dtype + assert out.to_list() == values + + +def test_list_null_list_categorical_cast() -> None: + expected = pl.List(pl.Categorical) + s = pl.Series([[]], dtype=pl.List(pl.Null)).cast(expected) + assert s.dtype == expected + assert s.to_list() == [[]] + + +def test_list_null_pickle() -> None: + df = pl.DataFrame([{"a": None}], schema={"a": pl.List(pl.Int64)}) + assert_frame_equal(df, pickle.loads(pickle.dumps(df))) + + +def test_struct_with_nulls_as_list() -> None: + df = pl.DataFrame([[{"a": 1, "b": 2}], [{"c": 3, "d": None}]]) + result = df.select(pl.concat_list(pl.all()).alias("as_list")) + assert result.to_dict(as_series=False) == { + "as_list": [ + [ + {"a": 1, "b": 2, "c": None, "d": None}, + {"a": None, "b": None, "c": 3, "d": None}, + ] + ] + } + + +def test_list_amortized_iter_clear_settings_10126() -> None: + out = ( + pl.DataFrame({"a": [[1], [1], [2]], "b": [[1, 2], [1, 3], [4]]}) + .explode("a") + .group_by("a") + .agg(pl.col("b").flatten()) + .with_columns(pl.col("b").list.unique()) + .sort("a") + ) + + assert out.to_dict(as_series=False) == {"a": [1, 2], "b": [[1, 2, 3], [4]]} + + +def test_list_inner_cast_physical_11513() -> None: + df = pl.DataFrame( + { + "date": ["foo"], + "struct": [[]], + }, + schema_overrides={ + "struct": pl.List( + pl.Struct( + { + "field": pl.Struct( + {"subfield": pl.List(pl.Struct({"subsubfield": pl.Date}))} + ) + } + ) + ) + }, + ) + assert df.select(pl.col("struct").gather(0)).to_dict(as_series=False) == { + "struct": [[]] + } + + +@pytest.mark.parametrize( + ("dtype", "expected"), [(pl.List, True), (pl.Struct, True), (pl.String, False)] +) +def test_datatype_is_nested(dtype: PolarsDataType, expected: bool) -> None: + assert dtype.is_nested() is expected + + +def test_list_series_construction_with_dtype_11849_11878() -> None: + s = pl.Series([[1, 2], [3.3, 4.9]], dtype=pl.List(pl.Float64)) + assert s.to_list() == [[1, 2], [3.3, 4.9]] + + s1 = pl.Series([[1, 2], [3.0, 4.0]], dtype=pl.List(pl.Float64)) + s2 = pl.Series([[1, 2], [3.0, 4.9]], dtype=pl.List(pl.Float64)) + assert_series_equal(s1 == s2, pl.Series([True, False])) + + s = pl.Series( + "groups", + [[{"1": "A", "2": None}], [{"1": "B", "2": "C"}, {"1": "D", "2": "E"}]], + dtype=pl.List(pl.Struct([pl.Field("1", pl.String), pl.Field("2", pl.String)])), + ) + + assert s.to_list() == [ + [{"1": "A", "2": None}], + [{"1": "B", "2": "C"}, {"1": "D", "2": "E"}], + ] + + +def test_as_list_logical_type() -> None: + df = pl.select(timestamp=pl.date(2000, 1, 1), value=0) + assert df.group_by(True).agg( + pl.col("timestamp").gather(pl.col("value").arg_max()) + ).to_dict(as_series=False) == {"literal": [True], "timestamp": [[date(2000, 1, 1)]]} + + +@pytest.fixture +def data_dispersion() -> pl.DataFrame: + return pl.DataFrame( + { + "int": [[1, 2, 3, 4, 5]], + "float": [[1.0, 2.0, 3.0, 4.0, 5.0]], + "duration": [[1000, 2000, 3000, 4000, 5000]], + }, + schema={ + "int": pl.List(pl.Int64), + "float": pl.List(pl.Float64), + "duration": pl.List(pl.Duration), + }, + ) + + +def test_list_var(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.var().name.suffix("_var"), + pl.col("float").list.var().name.suffix("_var"), + pl.col("duration").list.var().name.suffix("_var"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_var", [2.5], dtype=pl.Float64), + pl.Series("float_var", [2.5], dtype=pl.Float64), + pl.Series( + "duration_var", + [timedelta(microseconds=2000)], + dtype=pl.Duration(time_unit="ms"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_list_std(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.std().name.suffix("_std"), + pl.col("float").list.std().name.suffix("_std"), + pl.col("duration").list.std().name.suffix("_std"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series("float_std", [1.5811388300841898], dtype=pl.Float64), + pl.Series( + "duration_std", + [timedelta(microseconds=1581)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_list_median(data_dispersion: pl.DataFrame) -> None: + df = data_dispersion + + result = df.select( + pl.col("int").list.median().name.suffix("_median"), + pl.col("float").list.median().name.suffix("_median"), + pl.col("duration").list.median().name.suffix("_median"), + ) + + expected = pl.DataFrame( + [ + pl.Series("int_median", [3.0], dtype=pl.Float64), + pl.Series("float_median", [3.0], dtype=pl.Float64), + pl.Series( + "duration_median", + [timedelta(microseconds=3000)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + + assert_frame_equal(result, expected) + + +def test_list_gather_null_struct_14927() -> None: + df = pl.DataFrame( + [ + { + "index": 0, + "col_0": [{"field_0": 1.0}], + }, + { + "index": 1, + "col_0": None, + }, + ] + ) + + expected = pl.DataFrame( + {"index": [1], "col_0": [None], "field_0": [None]}, + schema={**df.schema, "field_0": pl.Float64}, + ) + expr = pl.col("col_0").list.get(0, null_on_oob=True).struct.field("field_0") + out = df.filter(pl.col("index") > 0).with_columns(expr) + assert_frame_equal(out, expected) + + +def test_list_of_series_with_nulls() -> None: + inner_series = pl.Series("inner", [1, 2, 3]) + s = pl.Series("a", [inner_series, None]) + assert_series_equal(s, pl.Series("a", [[1, 2, 3], None])) + + +def test_take_list_15719() -> None: + schema = pl.List(pl.List(pl.Int64)) + df = pl.DataFrame( + {"a": [None, None], "b": [None, [[1, 2]]]}, schema={"a": schema, "b": schema} + ) + df = df.select( + a_explode=pl.col("a").explode(), + a_get=pl.col("a").list.get(0, null_on_oob=True), + b_explode=pl.col("b").explode(), + b_get=pl.col("b").list.get(0, null_on_oob=True), + ) + + expected_schema = pl.List(pl.Int64) + expected = pl.DataFrame( + { + "a_explode": [None, None], + "a_get": [None, None], + "b_explode": [None, [1, 2]], + "b_get": [None, [1, 2]], + }, + schema={ + "a_explode": expected_schema, + "a_get": expected_schema, + "b_explode": expected_schema, + "b_get": expected_schema, + }, + ) + + assert_frame_equal(df, expected) + + +def test_list_str_sum_exception_12935() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series(["foo", "bar"]).sum() + + +def test_list_list_sum_exception_12935() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series([[1], [2]]).sum() + + +@pytest.mark.may_fail_auto_streaming +def test_null_list_categorical_16405() -> None: + df = pl.DataFrame( + [(None, "foo")], + schema={ + "match": pl.List(pl.Categorical), + "what": pl.Categorical, + }, + orient="row", + ) + + df = df.select( + pl.col("match") + .list.set_intersection(pl.concat_list(pl.col("what"))) + .alias("result") + ) + + expected = pl.DataFrame([None], schema={"result": pl.List(pl.Categorical)}) + assert_frame_equal(df, expected) + + +def test_list_get_literal_broadcast_21463() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + df = df.with_columns(x=pl.lit([1, 2, 3, 4])) + expected = df.with_columns(b=pl.col("x").list.get(pl.col("a"))).drop("x") + df = pl.DataFrame({"a": [1, 2, 3]}) + actual = df.with_columns(b=pl.lit([1, 2, 3, 4]).list.get(pl.col("a"))) + assert expected.equals(actual) + + +def test_sort() -> None: + def tc(a: list[Any], b: list[Any]) -> None: + a_s = pl.Series("l", a, pl.List(pl.Int64)) + b_s = pl.Series("l", b, pl.List(pl.Int64)) + + assert_series_equal(a_s.sort(), b_s) + + tc([], []) + tc([[1]], [[1]]) + tc([[1], []], [[], [1]]) + tc([[2, 1]], [[2, 1]]) + tc([[2, 1], [1, 2]], [[1, 2], [2, 1]]) + + +@pytest.mark.parametrize("inner_dtype", TEMPORAL_DTYPES) +@pytest.mark.parametrize("agg", ["min", "max", "mean", "median"]) +def test_list_agg_temporal(inner_dtype: PolarsDataType, agg: str) -> None: + lf = pl.LazyFrame({"a": [[1, 3]]}, schema={"a": pl.List(inner_dtype)}) + result = lf.select(getattr(pl.col("a").list, agg)()) + expected = lf.select(getattr(pl.col("a").explode(), agg)()) + assert result.collect_schema() == expected.collect_schema() + assert_frame_equal(result.collect(), expected.collect()) diff --git a/py-polars/tests/unit/datatypes/test_null.py b/py-polars/tests/unit/datatypes/test_null.py new file mode 100644 index 000000000000..9afcac0755b8 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_null.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_null_index() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4], [5, 6]], "b": [[1, 2], [1, 2], [4, 5]]}) + + result = df.with_columns(pl.lit(None).alias("null_col"))[-1] + + expected = pl.DataFrame( + {"a": [[5, 6]], "b": [[4, 5]], "null_col": [None]}, + schema_overrides={"null_col": pl.Null}, + ) + assert_frame_equal(result, expected) + + +def test_null_grouping_12950() -> None: + assert pl.DataFrame({"x": None}).unique().to_dict(as_series=False) == {"x": [None]} + assert pl.DataFrame({"x": [None, None]}).unique().to_dict(as_series=False) == { + "x": [None] + } + assert pl.DataFrame({"x": None}).slice(0, 0).unique().to_dict(as_series=False) == { + "x": [] + } + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + (pl.Expr.gt, [None, None]), + (pl.Expr.lt, [None, None]), + (pl.Expr.ge, [None, None]), + (pl.Expr.le, [None, None]), + (pl.Expr.eq, [None, None]), + (pl.Expr.eq_missing, [True, True]), + (pl.Expr.ne, [None, None]), + (pl.Expr.ne_missing, [False, False]), + ], +) +def test_null_comp_14118(op: Any, expected: list[None | bool]) -> None: + df = pl.DataFrame( + { + "a": [None, None], + "b": [None, None], + } + ) + + output_df = df.select( + cmp=op(pl.col("a"), pl.col("b")), + broadcast_lhs=op(pl.lit(None), pl.col("b")), + broadcast_rhs=op(pl.col("a"), pl.lit(None)), + ) + + expected_df = pl.DataFrame( + { + "cmp": expected, + "broadcast_lhs": expected, + "broadcast_rhs": expected, + }, + schema={ + "cmp": pl.Boolean, + "broadcast_lhs": pl.Boolean, + "broadcast_rhs": pl.Boolean, + }, + ) + assert_frame_equal(output_df, expected_df) + + +def test_null_hash_rows_14100() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [None, None, None, None]}) + assert df.hash_rows().dtype == pl.UInt64 + assert df["b"].hash().dtype == pl.UInt64 + assert df.select([pl.col("b").hash().alias("foo")])["foo"].dtype == pl.UInt64 + + +def test_null_lit_filter_16664() -> None: + assert pl.DataFrame({"x": []}).filter(pl.lit(True)).shape == (0, 1) diff --git a/py-polars/tests/unit/datatypes/test_object.py b/py-polars/tests/unit/datatypes/test_object.py new file mode 100644 index 000000000000..37b005c259c5 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_object.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import io +from pathlib import Path +from uuid import uuid4 + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_series_equal + + +def test_series_init_instantiated_object() -> None: + s = pl.Series([object(), object()], dtype=pl.Object()) + assert isinstance(s, pl.Series) + assert isinstance(s.dtype, pl.Object) + + +def test_object_empty_filter_5911() -> None: + df = pl.DataFrame( + data=[ + (1, "dog", {}), + ], + schema=[ + ("pet_id", pl.Int64), + ("pet_type", pl.Categorical), + ("pet_obj", pl.Object), + ], + orient="row", + ) + + empty_df = df.filter(pl.col("pet_type") == "cat") + out = empty_df.select(["pet_obj"]) + assert out.dtypes == [pl.Object] + assert out.shape == (0, 1) + + +def test_object_in_struct() -> None: + np_a = np.array([1, 2, 3]) + np_b = np.array([4, 5, 6]) + df = pl.DataFrame({"A": [1, 2], "B": pl.Series([np_a, np_b], dtype=pl.Object)}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.select([pl.struct(["B"])]) + + +def test_nullable_object_13538() -> None: + df = pl.DataFrame( + data=[ + ({"a": 1},), + ({"b": 3},), + (None,), + ], + schema=[ + ("blob", pl.Object), + ], + orient="row", + ) + + df = df.select( + is_null=pl.col("blob").is_null(), is_not_null=pl.col("blob").is_not_null() + ) + assert df.to_dict(as_series=False) == { + "is_null": [False, False, True], + "is_not_null": [True, True, False], + } + + df = pl.DataFrame({"col": pl.Series([0, 1, 2, None], dtype=pl.Object)}) + df = df.select( + is_null=pl.col("col").is_null(), is_not_null=pl.col("col").is_not_null() + ) + assert df.to_dict(as_series=False) == { + "is_null": [False, False, False, True], + "is_not_null": [True, True, True, False], + } + + +def test_nullable_object_17936() -> None: + class Custom: + value: int + + def __init__(self, value: int) -> None: + self.value = value + + def mapper(value: int) -> Custom | None: + if value == 2: + return None + return Custom(value) + + df = pl.DataFrame({"a": [1, 2, 3]}) + + assert df.select( + pl.col("a").map_elements(mapper, return_dtype=pl.Object).alias("with_dtype"), + pl.col("a").map_elements(mapper).alias("without_dtype"), + ).null_count().row(0) == (1, 1) + + +def test_empty_sort() -> None: + df = pl.DataFrame( + data=[ + ({"name": "bar", "sort_key": 2},), + ({"name": "foo", "sort_key": 1},), + ], + schema=[ + ("blob", pl.Object), + ], + orient="row", + ) + df_filtered = df.filter( + pl.col("blob").map_elements( + lambda blob: blob["name"] == "baz", return_dtype=pl.Boolean + ) + ) + df_filtered.sort( + pl.col("blob").map_elements( + lambda blob: blob["sort_key"], return_dtype=pl.Int64 + ) + ) + + +def test_object_to_dicts() -> None: + df = pl.DataFrame({"d": [{"a": 1, "b": 2, "c": 3}]}, schema={"d": pl.Object}) + assert df.to_dicts() == [{"d": {"a": 1, "b": 2, "c": 3}}] + + +def test_object_concat() -> None: + df1 = pl.DataFrame( + {"a": [1, 2, 3]}, + schema={"a": pl.Object}, + ) + + df2 = pl.DataFrame( + {"a": [1, 4, 3]}, + schema={"a": pl.Object}, + ) + + catted = pl.concat([df1, df2]) + assert catted.shape == (6, 1) + assert catted.dtypes == [pl.Object] + assert catted.to_dict(as_series=False) == {"a": [1, 2, 3, 1, 4, 3]} + + +def test_object_concat_diagonal_14651() -> None: + df1 = pl.DataFrame({"a": ["abc"]}, schema={"a": pl.Object}) + df2 = pl.DataFrame({"b": ["def"]}, schema={"b": pl.Object}) + result = pl.concat([df1, df2], how="diagonal") + assert result.schema == pl.Schema({"a": pl.Object, "b": pl.Object}) + assert result["a"].to_list() == ["abc", None] + assert result["b"].to_list() == [None, "def"] + + +def test_object_row_construction() -> None: + data = [ + [uuid4()], + [uuid4()], + [uuid4()], + ] + df = pl.DataFrame( + data, + orient="row", + ) + assert df.dtypes == [pl.Object] + assert df["column_0"].to_list() == [value[0] for value in data] + + +def test_object_apply_to_struct() -> None: + s = pl.Series([0, 1, 2], dtype=pl.Object) + out = s.map_elements(lambda x: {"a": str(x), "b": x}) + assert out.dtype == pl.Struct([pl.Field("a", pl.String), pl.Field("b", pl.Int64)]) + + +def test_null_obj_str_13512() -> None: + # https://github.com/pola-rs/polars/issues/13512 + + df1 = pl.DataFrame( + { + "key": [1], + } + ) + df2 = pl.DataFrame({"key": [2], "a": pl.Series([1], dtype=pl.Object)}) + + out = df1.join(df2, on="key", how="left") + s = str(out) + assert s == ( + "shape: (1, 2)\n" + "┌─────┬────────┐\n" + "│ key ┆ a │\n" + "│ --- ┆ --- │\n" + "│ i64 ┆ object │\n" + "╞═════╪════════╡\n" + "│ 1 ┆ null │\n" + "└─────┴────────┘" + ) + + +def test_format_object_series_14267() -> None: + s = pl.Series([Path(), Path("abc")]) + expected = "shape: (2,)\nSeries: '' [o][object]\n[\n\t.\n\tabc\n]" + assert str(s) == expected + + +def test_object_raise_writers() -> None: + df = pl.DataFrame({"a": object()}) + + buf = io.BytesIO() + + with pytest.raises(ComputeError): + df.write_parquet(buf) + with pytest.raises(ComputeError): + df.write_ipc(buf) + with pytest.raises(ComputeError): + df.write_json(buf) + with pytest.raises(ComputeError): + df.write_csv(buf) + with pytest.raises(ComputeError): + df.write_avro(buf) + + +def test_raise_list_object() -> None: + # We don't want to support this. Unsafe enough as it is already. + with pytest.raises(ValueError): + pl.Series([[object()]], dtype=pl.List(pl.Object())) + + +def test_object_null_slice() -> None: + s = pl.Series("x", [1, None, 42], dtype=pl.Object) + assert_series_equal(s.is_null(), pl.Series("x", [False, True, False])) + assert_series_equal(s.slice(0, 2).is_null(), pl.Series("x", [False, True])) + assert_series_equal(s.slice(1, 1).is_null(), pl.Series("x", [True])) + assert_series_equal(s.slice(2, 1).is_null(), pl.Series("x", [False])) + + +def test_object_sort_scalar_19925() -> None: + a = object() + assert pl.DataFrame({"a": [0], "obj": [a]}).sort("a")["obj"].item() == a + + +def test_object_estimated_size() -> None: + df = pl.DataFrame( + [ + ["3", "random python object, not a string"], + ], + schema={"name": pl.String, "ob": pl.Object}, + orient="row", + ) + + # is a huge underestimation + assert df.estimated_size() == 9 + + +def test_object_polars_dtypes_20572() -> None: + df = pl.DataFrame( + { + "a": pl.Date(), + "b": pl.Decimal(5, 1), + "c": pl.Int64(), + "d": pl.Object(), + "e": pl.String(), + } + ) + assert all(dt.is_object() for dt in df.schema.dtypes()) diff --git a/py-polars/tests/unit/datatypes/test_parse.py b/py-polars/tests/unit/datatypes/test_parse.py new file mode 100644 index 000000000000..017e7dd82b93 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_parse.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import enum +from datetime import date, datetime +from typing import ( + TYPE_CHECKING, + Any, + ForwardRef, + NamedTuple, + Optional, + Union, +) + +import pytest + +import polars as pl +from polars.datatypes._parse import ( + _parse_forward_ref_into_dtype, + _parse_generic_into_dtype, + _parse_union_type_into_dtype, + parse_into_dtype, + parse_py_type_into_dtype, +) + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def assert_dtype_equal(left: PolarsDataType, right: PolarsDataType) -> None: + assert left == right + assert type(left) is type(right) + assert hash(left) == hash(right) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (pl.Int8(), pl.Int8()), + (list, pl.List), + ], +) +def test_parse_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = parse_into_dtype(input) + assert_dtype_equal(result, expected) + + +def test_parse_into_dtype_enum_19724() -> None: + class PythonEnum(str, enum.Enum): + CAT1 = "A" + CAT2 = "B" + CAT3 = "C" + + result = parse_into_dtype(PythonEnum) + expected = pl.Enum(["A", "B", "C"]) + assert_dtype_equal(result, expected) + + +def test_parse_into_dtype_enum_ints_19724() -> None: + class PythonEnum(int, enum.Enum): + CAT1 = 1 + CAT2 = 2 + CAT3 = 3 + + with pytest.raises(TypeError, match="Enum categories must be strings"): + parse_into_dtype(PythonEnum) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (datetime, pl.Datetime("us")), + (date, pl.Date()), + (type(None), pl.Null()), + (object, pl.Object()), + ], +) +def test_parse_py_type_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = parse_py_type_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (list[int], pl.List(pl.Int64())), + (tuple[str, ...], pl.List(pl.String())), + (tuple[datetime, datetime], pl.List(pl.Datetime("us"))), + ], +) +def test_parse_generic_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = _parse_generic_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + "input", + [ + dict[str, float], + tuple[int, str], + tuple[int, float, float], + ], +) +def test_parse_generic_into_dtype_invalid(input: Any) -> None: + with pytest.raises(TypeError): + _parse_generic_into_dtype(input) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (ForwardRef("date"), pl.Date()), + (ForwardRef("int | None"), pl.Int64()), + (ForwardRef("None | float"), pl.Float64()), + ], +) +def test_parse_forward_ref_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = _parse_forward_ref_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (Optional[int], pl.Int64()), + (Optional[pl.String], pl.String), + (Union[float, None], pl.Float64()), + ], +) +def test_parse_union_type_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = _parse_union_type_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + "input", + [ + Union[int, float], + Optional[Union[int, str]], + ], +) +def test_parse_union_type_into_dtype_invalid(input: Any) -> None: + with pytest.raises(TypeError): + _parse_union_type_into_dtype(input) + + +def test_parse_dtype_namedtuple_fields() -> None: + # Utilizes ForwardRef parsing + + class MyTuple(NamedTuple): + a: str + b: int + c: str | None = None + + schema = {c: parse_into_dtype(a) for c, a in MyTuple.__annotations__.items()} + + expected = pl.Schema({"a": pl.String(), "b": pl.Int64(), "c": pl.String()}) + assert schema == expected diff --git a/py-polars/tests/unit/datatypes/test_string.py b/py-polars/tests/unit/datatypes/test_string.py new file mode 100644 index 000000000000..4250b2c23e99 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_string.py @@ -0,0 +1,46 @@ +import json + +import polars as pl +from polars.testing import assert_series_equal + + +def test_series_init_string() -> None: + s = pl.Series(["a", "b"]) + assert s.dtype == pl.String + + +def test_utf8_alias_eq() -> None: + assert pl.Utf8 == pl.String + assert pl.Utf8 == pl.String() + assert pl.Utf8() == pl.String + assert pl.Utf8() == pl.String() + + +def test_utf8_alias_hash() -> None: + assert hash(pl.Utf8) == hash(pl.String) + assert hash(pl.Utf8()) == hash(pl.String()) + + +def test_utf8_alias_series_init() -> None: + s = pl.Series(["a", "b"], dtype=pl.Utf8) + assert s.dtype == pl.String + + +def test_utf8_alias_lit() -> None: + result = pl.select(a=pl.lit(5, dtype=pl.Utf8)).to_series() + expected = pl.Series("a", ["5"], dtype=pl.String) + assert_series_equal(result, expected) + + +def test_json_decode_multiple_chunks() -> None: + a = json.dumps({"x": None}) + b = json.dumps({"x": True}) + + df_1 = pl.Series([a]).to_frame("s") + df_2 = pl.Series([b]).to_frame("s") + + df = pl.concat([df_1, df_2]) + + assert df.with_columns(pl.col("s").str.json_decode()).to_dict(as_series=False) == { + "s": [{"x": None}, {"x": True}] + } diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py new file mode 100644 index 000000000000..4f062f01b0cb --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -0,0 +1,1290 @@ +from __future__ import annotations + +import io +from dataclasses import dataclass +from datetime import datetime, time +from typing import TYPE_CHECKING, Any, Callable + +import pandas as pd +import pyarrow as pa +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_struct_to_list() -> None: + assert pl.DataFrame( + {"int": [1, 2], "str": ["a", "b"], "bool": [True, None], "list": [[1, 2], [3]]} + ).select([pl.struct(pl.all()).alias("my_struct")]).to_series().to_list() == [ + {"int": 1, "str": "a", "bool": True, "list": [1, 2]}, + {"int": 2, "str": "b", "bool": None, "list": [3]}, + ] + + +def test_apply_unnest() -> None: + df = ( + pl.Series([None, 2, 3, 4]) + .map_elements( + lambda x: {"a": x, "b": x * 2, "c": True, "d": [1, 2], "e": "foo"} + ) + .struct.unnest() + ) + + expected = pl.DataFrame( + { + "a": [None, 2, 3, 4], + "b": [None, 4, 6, 8], + "c": [None, True, True, True], + "d": [None, [1, 2], [1, 2], [1, 2]], + "e": [None, "foo", "foo", "foo"], + } + ) + + assert_frame_equal(df, expected) + + +def test_struct_equality() -> None: + # equal struct dimensions, equal values + s1 = pl.Series("misc", [{"x": "a", "y": 0}, {"x": "b", "y": 0}]) + s2 = pl.Series("misc", [{"x": "a", "y": 0}, {"x": "b", "y": 0}]) + assert (s1 == s2).all() + assert (~(s1 != s2)).all() + + # equal struct dimensions, unequal values + s3 = pl.Series("misc", [{"x": "a", "y": 0}, {"x": "c", "y": 2}]) + s4 = pl.Series("misc", [{"x": "b", "y": 1}, {"x": "d", "y": 3}]) + assert (s3 != s4).all() + assert (~(s3 == s4)).all() + + # unequal struct dimensions, equal values (where fields overlap) + s5 = pl.Series("misc", [{"x": "a", "y": 0}, {"x": "b", "y": 0}]) + s6 = pl.Series("misc", [{"x": "a", "y": 0, "z": 0}, {"x": "b", "y": 0, "z": 0}]) + assert (s5 != s6).all() + assert (~(s5 == s6)).all() + + +def test_struct_equality_strict() -> None: + s1 = pl.Struct( + [ + pl.Field("a", pl.Int64), + pl.Field("b", pl.Boolean), + pl.Field("c", pl.List(pl.Int32)), + ] + ) + s2 = pl.Struct( + [pl.Field("a", pl.Int64), pl.Field("b", pl.Boolean), pl.Field("c", pl.List)] + ) + + # strict + assert s1.is_(s2) is False + + # permissive (default) + assert s1 == s2 + assert s1 == s2 + + +def test_struct_hashes() -> None: + dtypes = ( + pl.Struct, + pl.Struct([pl.Field("a", pl.Int64)]), + pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.List(pl.Int64))]), + ) + assert len({hash(tp) for tp in (dtypes)}) == 3 + + +def test_struct_unnesting() -> None: + df_base = pl.DataFrame({"a": [1, 2]}) + df = df_base.select( + pl.all().alias("a_original"), + pl.col("a") + .map_elements(lambda x: {"a": x, "b": x * 2, "c": x % 2 == 0}) + .struct.rename_fields(["a", "a_squared", "mod2eq0"]) + .alias("foo"), + ) + expected = pl.DataFrame( + { + "a_original": [1, 2], + "a": [1, 2], + "a_squared": [2, 4], + "mod2eq0": [False, True], + } + ) + for cols in ("foo", cs.ends_with("oo")): + out_eager = df.unnest(cols) + assert_frame_equal(out_eager, expected) + + out_lazy = df.lazy().unnest(cols) + assert_frame_equal(out_lazy, expected.lazy()) + + out = ( + df_base.lazy() + .select( + pl.all().alias("a_original"), + pl.col("a") + .map_elements(lambda x: {"a": x, "b": x * 2, "c": x % 2 == 0}) + .struct.rename_fields(["a", "a_squared", "mod2eq0"]) + .alias("foo"), + ) + .unnest("foo") + .collect() + ) + assert_frame_equal(out, expected) + + +def test_struct_unnest_multiple() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [1.0, 2.0], "d": ["a", "b"]}) + df_structs = df.select(s1=pl.struct(["a", "b"]), s2=pl.struct(["c", "d"])) + + # List input + result = df_structs.unnest(["s1", "s2"]) + assert_frame_equal(result, df) + assert all(tp.is_nested() for tp in df_structs.dtypes) + + # Positional input + result = df_structs.unnest("s1", "s2") + assert_frame_equal(result, df) + + +def test_struct_function_expansion() -> None: + df = pl.DataFrame( + {"a": [1, 2, 3, 4], "b": ["one", "two", "three", "four"], "c": [9, 8, 7, 6]} + ) + struct_schema = {"a": pl.UInt32, "b": pl.String} + dfs = df.with_columns(pl.struct(pl.col(["a", "b"]), schema=struct_schema)) + s = dfs["a"] + + assert isinstance(s, pl.Series) + assert s.struct.fields == ["a", "b"] + assert pl.Struct(struct_schema) == s.to_frame().schema["a"] + + assert_series_equal(s, pl.Series(dfs.select("a"))) + assert_frame_equal(dfs, pl.DataFrame(dfs)) + + +def test_nested_struct() -> None: + df = pl.DataFrame({"d": [1, 2, 3], "e": ["foo", "bar", "biz"]}) + # Nest the dataframe + nest_l1 = df.to_struct("c").to_frame() + # Add another column on the same level + nest_l1 = nest_l1.with_columns(pl.col("c").is_null().alias("b")) + # Nest the dataframe again + nest_l2 = nest_l1.to_struct("a").to_frame() + + assert isinstance(nest_l2.dtypes[0], pl.datatypes.Struct) + assert [f.dtype for f in nest_l2.dtypes[0].fields] == nest_l1.dtypes + assert isinstance(nest_l1.dtypes[0], pl.datatypes.Struct) + + +def test_struct_to_pandas() -> None: + pdf = pd.DataFrame([{"a": {"b": {"c": 2}}}]) + df = pl.from_pandas(pdf) + + assert isinstance(df.dtypes[0], pl.datatypes.Struct) + assert df.to_pandas().equals(pdf) + + +def test_struct_logical_types_to_pandas() -> None: + timestamp = datetime(2022, 1, 1) + df = pd.DataFrame([{"struct": {"timestamp": timestamp}}]) + assert pl.from_pandas(df).dtypes == [pl.Struct] + + +def test_struct_cols() -> None: + """Test that struct columns can be imported and work as expected.""" + + def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame: + """ + Build Polars df from list of dicts. + + Can't import directly because of issue #3145. + """ + arrow_df = pa.Table.from_pylist(data) + polars_df = pl.from_arrow(arrow_df) + assert isinstance(polars_df, pl.DataFrame) + return polars_df + + # struct column + df = build_struct_df([{"struct_col": {"inner": 1}}]) + assert df.columns == ["struct_col"] + assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})} + assert df["struct_col"].struct.field("inner").to_list() == [1] + + # struct in struct + df = build_struct_df([{"nested_struct_col": {"struct_col": {"inner": 1}}}]) + assert df["nested_struct_col"].struct.field("struct_col").struct.field( + "inner" + ).to_list() == [1] + + # struct in list + df = build_struct_df([{"list_of_struct_col": [{"inner": 1}]}]) + assert df["list_of_struct_col"][0].struct.field("inner").to_list() == [1] + + # struct in list in struct + df = build_struct_df( + [{"struct_list_struct_col": {"list_struct_col": [{"inner": 1}]}}] + ) + assert df["struct_list_struct_col"].struct.field("list_struct_col")[0].struct.field( + "inner" + ).to_list() == [1] + + +def test_struct_with_validity() -> None: + data = [{"a": {"b": 1}}, {"a": None}] + tbl = pa.Table.from_pylist(data) + df = pl.from_arrow(tbl) + assert isinstance(df, pl.DataFrame) + assert df["a"].to_list() == [{"b": 1}, None] + + +def test_from_dicts_struct() -> None: + assert pl.from_dicts([{"a": 1, "b": {"a": 1, "b": 2}}]).to_series(1).to_list() == [ + {"a": 1, "b": 2} + ] + + assert pl.from_dicts( + [{"a": 1, "b": {"a_deep": 1, "b_deep": {"a_deeper": [1, 2, 4]}}}] + ).to_series(1).to_list() == [{"a_deep": 1, "b_deep": {"a_deeper": [1, 2, 4]}}] + + data = [ + {"a": [{"b": 0, "c": 1}]}, + {"a": [{"b": 1, "c": 2}]}, + ] + + assert pl.from_dicts(data).to_series().to_list() == [ + [{"b": 0, "c": 1}], + [{"b": 1, "c": 2}], + ] + + +@pytest.mark.may_fail_auto_streaming +def test_list_to_struct() -> None: + df = pl.DataFrame({"a": [[1, 2, 3], [1, 2]]}) + assert df.to_series().list.to_struct().to_list() == [ + {"field_0": 1, "field_1": 2, "field_2": 3}, + {"field_0": 1, "field_1": 2, "field_2": None}, + ] + + df = pl.DataFrame({"a": [[1, 2], [1, 2, 3]]}) + assert df.to_series().list.to_struct( + fields=lambda idx: f"col_name_{idx}" + ).to_list() == [ + {"col_name_0": 1, "col_name_1": 2}, + {"col_name_0": 1, "col_name_1": 2}, + ] + + df = pl.DataFrame({"a": [[1, 2], [1, 2, 3]]}) + assert df.to_series().list.to_struct(n_field_strategy="max_width").to_list() == [ + {"field_0": 1, "field_1": 2, "field_2": None}, + {"field_0": 1, "field_1": 2, "field_2": 3}, + ] + + # set upper bound + df = pl.DataFrame({"lists": [[1, 1, 1], [0, 1, 0], [1, 0, 0]]}) + assert df.lazy().select( + pl.col("lists").list.to_struct(upper_bound=3, _eager=True) + ).unnest("lists").sum().collect().columns == ["field_0", "field_1", "field_2"] + + +def test_sort_df_with_list_struct() -> None: + assert pl.DataFrame([{"a": 1, "b": [{"c": 1}]}]).sort("a").to_dict( + as_series=False + ) == { + "a": [1], + "b": [[{"c": 1}]], + } + + +def test_struct_list_head_tail() -> None: + assert pl.DataFrame( + { + "list_of_struct": [ + [{"a": 1, "b": 4}, {"a": 3, "b": 6}], + [{"a": 10, "b": 40}, {"a": 20, "b": 50}, {"a": 30, "b": 60}], + ] + } + ).with_columns( + pl.col("list_of_struct").list.head(1).alias("head"), + pl.col("list_of_struct").list.tail(1).alias("tail"), + ).to_dict(as_series=False) == { + "list_of_struct": [ + [{"a": 1, "b": 4}, {"a": 3, "b": 6}], + [{"a": 10, "b": 40}, {"a": 20, "b": 50}, {"a": 30, "b": 60}], + ], + "head": [[{"a": 1, "b": 4}], [{"a": 10, "b": 40}]], + "tail": [[{"a": 3, "b": 6}], [{"a": 30, "b": 60}]], + } + + +def test_struct_agg_all() -> None: + df = pl.DataFrame( + { + "group": ["a", "a", "b", "b", "b"], + "col1": [ + {"x": 1, "y": 100}, + {"x": 2, "y": 200}, + {"x": 3, "y": 300}, + {"x": 4, "y": 400}, + {"x": 5, "y": 500}, + ], + } + ) + + assert df.group_by("group", maintain_order=True).all().to_dict(as_series=False) == { + "group": ["a", "b"], + "col1": [ + [{"x": 1, "y": 100}, {"x": 2, "y": 200}], + [{"x": 3, "y": 300}, {"x": 4, "y": 400}, {"x": 5, "y": 500}], + ], + } + + +def test_struct_empty_list_creation() -> None: + payload = [[], [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], []] + + assert pl.DataFrame({"list_struct": payload}).to_dict(as_series=False) == { + "list_struct": [[], [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], []] + } + + # pop first + payload = payload[1:] + assert pl.DataFrame({"list_struct": payload}).to_dict(as_series=False) == { + "list_struct": [[{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], []] + } + + +def test_struct_arr_methods() -> None: + df = pl.DataFrame( + { + "list_struct": [ + [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], + [{"a": 1, "b": 2}, {"a": 3, "b": 4}], + [{"a": 1, "b": 2}], + ], + } + ) + assert df.select([pl.col("list_struct").list.first()]).to_dict(as_series=False) == { + "list_struct": [{"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}] + } + assert df.select([pl.col("list_struct").list.last()]).to_dict(as_series=False) == { + "list_struct": [{"a": 5, "b": 6}, {"a": 3, "b": 4}, {"a": 1, "b": 2}] + } + assert df.select([pl.col("list_struct").list.get(0)]).to_dict(as_series=False) == { + "list_struct": [{"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}] + } + + +def test_struct_concat_list() -> None: + assert pl.DataFrame( + { + "list_struct1": [ + [{"a": 1, "b": 2}, {"a": 3, "b": 4}], + [{"a": 1, "b": 2}], + ], + "list_struct2": [ + [{"a": 6, "b": 7}, {"a": 8, "b": 9}], + [{"a": 6, "b": 7}], + ], + } + ).with_columns(pl.col("list_struct1").list.concat("list_struct2").alias("result"))[ + "result" + ].to_list() == [ + [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 6, "b": 7}, {"a": 8, "b": 9}], + [{"a": 1, "b": 2}, {"a": 6, "b": 7}], + ] + + +def test_struct_arr_reverse() -> None: + assert pl.DataFrame( + { + "list_struct": [ + [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], + [{"a": 30, "b": 40}, {"a": 10, "b": 20}, {"a": 50, "b": 60}], + ], + } + ).with_columns([pl.col("list_struct").list.reverse()]).to_dict(as_series=False) == { + "list_struct": [ + [{"a": 5, "b": 6}, {"a": 3, "b": 4}, {"a": 1, "b": 2}], + [{"a": 50, "b": 60}, {"a": 10, "b": 20}, {"a": 30, "b": 40}], + ] + } + + +def test_struct_comparison() -> None: + df = pl.DataFrame( + { + "col1": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], + "col2": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], + } + ) + assert df.filter(pl.col("col1") == pl.col("col2")).rows() == [ + ({"a": 1, "b": 2}, {"a": 1, "b": 2}), + ({"a": 3, "b": 4}, {"a": 3, "b": 4}), + ] + # floats w/ ints + df = pl.DataFrame( + { + "col1": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], + "col2": [{"a": 1.0, "b": 2}, {"a": 3.0, "b": 4}], + } + ) + assert df.filter(pl.col("col1") == pl.col("col2")).rows() == [ + ({"a": 1, "b": 2}, {"a": 1.0, "b": 2}), + ({"a": 3, "b": 4}, {"a": 3.0, "b": 4}), + ] + + df = pl.DataFrame( + { + "col1": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], + "col2": [{"a": 2, "b": 2}, {"a": 3, "b": 4}], + } + ) + assert df.filter(pl.col("col1") == pl.col("col2")).to_dict(as_series=False) == { + "col1": [{"a": 3, "b": 4}], + "col2": [{"a": 3, "b": 4}], + } + + +def test_struct_order() -> None: + df = pl.DataFrame({"col1": [{"a": 1, "b": 2}, {"b": 4, "a": 3}]}) + expected = {"col1": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]} + assert df.to_dict(as_series=False) == expected + + # null values should not trigger this + assert ( + pl.Series( + values=[ + {"a": 1, "b": None}, + {"a": 2, "b": 20}, + ], + ).to_list() + ) == [{"a": 1, "b": None}, {"a": 2, "b": 20}] + + assert ( + pl.Series( + values=[ + {"a": 1, "b": 10}, + {"a": 2, "b": None}, + ], + ).to_list() + ) == [{"a": 1, "b": 10}, {"a": 2, "b": None}] + + +def test_struct_arr_eval() -> None: + df = pl.DataFrame( + {"col_struct": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 1, "b": 11}]]} + ) + assert df.with_columns( + pl.col("col_struct").list.eval(pl.element().first()).alias("first") + ).to_dict(as_series=False) == { + "col_struct": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 1, "b": 11}]], + "first": [[{"a": 1, "b": 11}]], + } + + +def test_list_of_struct_unique() -> None: + df = pl.DataFrame( + {"col_struct": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 1, "b": 11}]]} + ) + # the order is unpredictable + unique = df.with_columns(pl.col("col_struct").list.unique().alias("unique"))[ + "unique" + ].to_list() + assert len(unique) == 1 + unique_el = unique[0] + assert len(unique_el) == 2 + assert {"a": 2, "b": 12} in unique_el + assert {"a": 1, "b": 11} in unique_el + + +def test_nested_explode_4026() -> None: + df = pl.DataFrame( + { + "data": [ + [ + {"account_id": 10, "values": [1, 2]}, + {"account_id": 11, "values": [10, 20]}, + ] + ], + "day": ["monday"], + } + ) + + assert df.explode("data").to_dict(as_series=False) == { + "data": [ + {"account_id": 10, "values": [1, 2]}, + {"account_id": 11, "values": [10, 20]}, + ], + "day": ["monday", "monday"], + } + + +def test_nested_struct_sliced_append() -> None: + s = pl.Series( + [ + { + "_experience": { + "aaid": { + "id": "A", + "namespace": {"code": "alpha"}, + } + } + }, + { + "_experience": { + "aaid": { + "id": "B", + "namespace": {"code": "bravo"}, + }, + } + }, + { + "_experience": { + "aaid": { + "id": "D", + "namespace": {"code": "delta"}, + } + } + }, + ] + ) + s2 = s[1:] + s.append(s2) + + assert s.to_list() == [ + {"_experience": {"aaid": {"id": "A", "namespace": {"code": "alpha"}}}}, + {"_experience": {"aaid": {"id": "B", "namespace": {"code": "bravo"}}}}, + {"_experience": {"aaid": {"id": "D", "namespace": {"code": "delta"}}}}, + {"_experience": {"aaid": {"id": "B", "namespace": {"code": "bravo"}}}}, + {"_experience": {"aaid": {"id": "D", "namespace": {"code": "delta"}}}}, + ] + + +def test_struct_group_by_field_agg_4216() -> None: + df = pl.DataFrame([{"a": {"b": 1}, "c": 0}]) + + result = df.group_by("c").agg(pl.col("a").struct.field("b").count()) + expected = {"c": [0], "b": [1]} + assert result.to_dict(as_series=False) == expected + + +def test_struct_getitem() -> None: + assert pl.Series([{"a": 1, "b": 2}]).struct["b"].name == "b" + assert pl.Series([{"a": 1, "b": 2}]).struct[0].name == "a" + assert pl.Series([{"a": 1, "b": 2}]).struct[1].name == "b" + assert pl.Series([{"a": 1, "b": 2}]).struct[-1].name == "b" + assert pl.Series([{"a": 1, "b": 2}]).to_frame().select( + pl.col("").struct[0] + ).to_dict(as_series=False) == {"a": [1]} + + +def test_struct_supertype() -> None: + assert pl.from_dicts( + [{"vehicle": {"auto": "car"}}, {"vehicle": {"auto": None}}] + ).to_dict(as_series=False) == {"vehicle": [{"auto": "car"}, {"auto": None}]} + + +def test_struct_any_value_get_after_append() -> None: + schema = {"a": pl.Int8, "b": pl.Int32} + struct_def = pl.Struct(schema) + + a = pl.Series("s", [{"a": 1, "b": 2}], dtype=struct_def) + b = pl.Series("s", [{"a": 2, "b": 3}], dtype=struct_def) + a = a.append(b) + + assert a[0] == {"a": 1, "b": 2} + assert a[1] == {"a": 2, "b": 3} + assert schema == a.to_frame().unnest("s").schema + + +def test_struct_categorical_5843() -> None: + df = pl.DataFrame({"foo": ["a", "b", "c", "a"]}).with_columns( + pl.col("foo").cast(pl.Categorical) + ) + result = df.select(pl.col("foo").value_counts(sort=True)) + assert result.to_dict(as_series=False) == { + "foo": [ + {"foo": "a", "count": 2}, + {"foo": "b", "count": 1}, + {"foo": "c", "count": 1}, + ] + } + + +def test_empty_struct() -> None: + # List + df = pl.DataFrame({"a": [[{}]]}) + assert df.to_dict(as_series=False) == {"a": [[{}]]} + + # Struct one not empty + df = pl.DataFrame({"a": [[{}, {"a": 10}]]}) + assert df.to_dict(as_series=False) == {"a": [[{"a": None}, {"a": 10}]]} + + # Empty struct + df = pl.DataFrame({"a": [{}]}) + assert df.to_dict(as_series=False) == {"a": [{}]} + + +@pytest.mark.parametrize( + "dtype", + [ + pl.List, + pl.List(pl.Null), + pl.List(pl.String), + pl.Array(pl.Null, 32), + pl.Array(pl.UInt8, 16), + pl.Struct([pl.Field("", pl.Null)]), + pl.Struct([pl.Field("x", pl.UInt32), pl.Field("y", pl.Float64)]), + ], +) +def test_empty_series_nested_dtype(dtype: PolarsDataType) -> None: + # various flavours of empty nested dtype + s = pl.Series("nested", dtype=dtype) + assert s.dtype.base_type() == dtype.base_type() + assert s.to_list() == [] + + +@pytest.mark.parametrize( + "data", + [ + [{}, {}], + [{}, None], + [None, {}], + [None, None], + ], +) +def test_empty_with_schema_struct(data: list[dict[str, object] | None]) -> None: + # Empty structs, with schema + struct_schema = {"a": pl.Date, "b": pl.Boolean, "c": pl.Float64} + frame_schema = {"x": pl.Int8, "y": pl.Struct(struct_schema)} + + @dataclass + class TestData: + x: int + y: dict[str, object] | None + + # test init from rows, dicts, and dataclasses + dict_data = {"x": [10, 20], "y": data} + dataclass_data = [ + TestData(10, data[0]), + TestData(20, data[1]), + ] + for frame_data in (dict_data, dataclass_data): + df = pl.DataFrame( + data=frame_data, + schema=frame_schema, # type: ignore[arg-type] + ) + assert df.schema == frame_schema + assert df.unnest("y").columns == ["x", "a", "b", "c"] + assert df.rows() == [ + ( + 10, + {"a": None, "b": None, "c": None} if data[0] is not None else None, + ), + ( + 20, + {"a": None, "b": None, "c": None} if data[1] is not None else None, + ), + ] + + +def test_struct_null_cast() -> None: + dtype = pl.Struct( + [ + pl.Field("a", pl.Int64), + pl.Field("b", pl.String), + pl.Field("c", pl.List(pl.Float64)), + ] + ) + assert ( + pl.DataFrame() + .lazy() + .select([pl.lit(None, dtype=pl.Null).cast(dtype, strict=True)]) + .collect() + ).to_dict(as_series=False) == {"literal": [None]} + + +def test_nested_struct_in_lists_cast() -> None: + assert pl.DataFrame( + { + "node_groups": [ + [{"nodes": [{"id": 1, "is_started": True}]}], + [{"nodes": []}], + ] + } + ).to_dict(as_series=False) == { + "node_groups": [[{"nodes": [{"id": 1, "is_started": True}]}], [{"nodes": []}]] + } + + +def test_struct_concat_self_no_rechunk() -> None: + df = pl.DataFrame([{"A": {"a": 1}}]) + out = pl.concat([df, df], rechunk=False) + assert out.dtypes == [pl.Struct([pl.Field("a", pl.Int64)])] + assert out.to_dict(as_series=False) == {"A": [{"a": 1}, {"a": 1}]} + + +def test_sort_structs() -> None: + df = pl.DataFrame( + { + "sex": ["m", "f", "f", "f", "m", "m", "f"], + "age": [22, 38, 26, 24, 21, 46, 22], + }, + ) + df_sorted_as_struct = df.select(pl.struct(["sex", "age"]).sort()).unnest("sex") + df_expected = df.sort(by=["sex", "age"]) + + assert_frame_equal(df_expected, df_sorted_as_struct) + assert df_sorted_as_struct.to_dict(as_series=False) == { + "sex": ["f", "f", "f", "f", "m", "m", "m"], + "age": [22, 24, 26, 38, 21, 22, 46], + } + + +def test_struct_applies_as_map() -> None: + df = pl.DataFrame({"id": [1, 1, 2], "x": ["a", "b", "c"], "y": ["d", "e", "f"]}) + + # the window function doesn't really make sense + # but it runs the test: #7286 + assert df.select( + pl.struct([pl.col("x"), pl.col("y") + pl.col("y")]).over("id") + ).to_dict(as_series=False) == { + "x": [{"x": "a", "y": "dd"}, {"x": "b", "y": "ee"}, {"x": "c", "y": "ff"}] + } + + +def test_struct_is_in() -> None: + # The dtype casts below test that struct is_in upcasts dtypes. + s1 = ( + pl.DataFrame({"x": [4, 3, 4, 9], "y": [0, 4, 6, 2]}) + .select(pl.struct(schema={"x": pl.Int64, "y": pl.Int64})) + .to_series() + ) + s2 = ( + pl.DataFrame({"x": [4, 3, 5, 9], "y": [0, 7, 6, 2]}) + .select(pl.struct(["x", "y"])) + .to_series() + ) + assert s1.is_in(s2).to_list() == [True, False, False, True] + + +def test_nested_struct_logicals() -> None: + # single nested + payload1 = [[{"a": time(10)}], [{"a": time(10)}]] + assert pl.Series(payload1).to_list() == payload1 + # double nested + payload2 = [[[{"a": time(10)}]], [[{"a": time(10)}]]] + assert pl.Series(payload2).to_list() == payload2 + + +def test_struct_name_passed_in_agg_apply() -> None: + struct_expr = pl.struct( + [ + pl.col("A").min(), + pl.col("B").search_sorted(pl.Series([3, 4])), + ] + ).alias("index") + + assert pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [1, 2, 2]}).group_by( + "C" + ).agg(struct_expr).sort("C", descending=True).to_dict(as_series=False) == { + "C": [2, 1], + "index": [ + [{"A": 2, "B": 0}, {"A": 2, "B": 0}], + [{"A": 1, "B": 0}, {"A": 1, "B": 0}], + ], + } + + df = pl.DataFrame({"val": [-3, -2, -1, 0, 1, 2, 3], "k": [0] * 7}) + + assert df.group_by("k").agg( + pl.struct( + pl.col("val").value_counts(sort=True).struct.field("val").alias("val"), + pl.col("val").value_counts(sort=True).struct.field("count").alias("count"), + ) + ).to_dict(as_series=False) == { + "k": [0], + "val": [ + [ + {"val": -3, "count": 1}, + {"val": -2, "count": 1}, + {"val": -1, "count": 1}, + {"val": 0, "count": 1}, + {"val": 1, "count": 1}, + {"val": 2, "count": 1}, + {"val": 3, "count": 1}, + ] + ], + } + + +def test_struct_null_count_strict_cast() -> None: + s = pl.Series([{"a": None}]).cast(pl.Struct({"a": pl.Categorical})) + assert s.dtype == pl.Struct([pl.Field("a", pl.Categorical)]) + assert s.to_list() == [{"a": None}] + + +def test_struct_get_field_by_index() -> None: + df = pl.DataFrame({"val": [{"a": 1, "b": 2}]}) + expected = {"b": [2]} + assert df.select(pl.all().struct[1]).to_dict(as_series=False) == expected + + +def test_struct_arithmetic_schema() -> None: + q = pl.LazyFrame({"A": [1], "B": [2]}) + + assert q.select(pl.struct("A") - pl.struct("B")).collect_schema()["A"] == pl.Struct( + {"A": pl.Int64} + ) + + +def test_struct_field() -> None: + df = pl.DataFrame( + { + "item": [ + {"name": "John", "age": 30, "car": None}, + {"name": "Alice", "age": 65, "car": "Volvo"}, + ] + } + ) + + assert df.select( + pl.col("item").struct.with_fields( + pl.field("name").str.to_uppercase(), pl.field("car").fill_null("Mazda") + ) + ).to_dict(as_series=False) == { + "item": [ + {"name": "JOHN", "age": 30, "car": "Mazda"}, + {"name": "ALICE", "age": 65, "car": "Volvo"}, + ] + } + + +def test_struct_field_recognized_as_renaming_expr_16480() -> None: + q = pl.LazyFrame( + { + "foo": "bar", + "my_struct": [{"x": 1, "y": 2}], + } + ).select(pl.col("my_struct").struct.field("x")) + + q = q.select("x") + assert q.collect().to_dict(as_series=False) == {"x": [1]} + + +def test_struct_filter_chunked_16498() -> None: + with pl.StringCache(): + N = 5 + df_orig1 = pl.DataFrame({"cat_a": ["remove"] * N, "cat_b": ["b"] * N}) + + df_orig2 = pl.DataFrame({"cat_a": ["a"] * N, "cat_b": ["b"] * N}) + + df = pl.concat([df_orig1, df_orig2], rechunk=False).cast(pl.Categorical) + df = df.select(pl.struct(pl.all()).alias("s")) + df = df.filter(pl.col("s").struct.field("cat_a") != pl.lit("remove")) + assert df.shape == (5, 1) + + +def test_struct_field_dynint_nullable_16243() -> None: + pl.select(pl.lit(None).fill_null(pl.struct(42))) + + +def test_struct_split_16536() -> None: + df = pl.DataFrame({"struct": [{"a": {"a": {"a": 1}}}], "list": [[1]], "int": [1]}) + + df = pl.concat([df, df, df, df], rechunk=False) + assert df.filter(pl.col("int") == 1).shape == (4, 3) + + +def test_struct_wildcard_expansion_and_exclude() -> None: + df = pl.DataFrame( + { + "id": [1, 2], + "meta_data": [ + {"system_data": "to_remove", "user_data": "keep"}, + {"user_data": "keep_"}, + ], + } + ) + + # ensure wildcard expansion is on input + assert df.lazy().select( + pl.col("meta_data").struct.with_fields("*") + ).collect().schema["meta_data"].fields == [ # type: ignore[attr-defined] + pl.Field("system_data", pl.String), + pl.Field("user_data", pl.String), + pl.Field("id", pl.Int64), + pl.Field( + "meta_data", pl.Struct({"system_data": pl.String, "user_data": pl.String}) + ), + ] + + with pytest.raises(InvalidOperationError): + df.lazy().select( + pl.col("meta_data").struct.with_fields(pl.field("*").exclude("user_data")) + ).collect() + + +def test_struct_chunked_gather_17603() -> None: + df = pl.DataFrame( + { + "id": [0, 0, 1, 1], + "a": [0, 1, 2, 3], + } + ).select("id", pl.struct("a")) + df = pl.concat((df, df)) + + assert df.select(pl.col("a").map_batches(lambda s: s).over("id")).to_dict( + as_series=False + ) == { + "a": [ + {"a": 0}, + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 0}, + {"a": 1}, + {"a": 2}, + {"a": 3}, + ] + } + + +def test_struct_out_nullability_from_arrow() -> None: + df = pl.DataFrame(pd.DataFrame({"abc": [{"a": 1.0, "b": pd.NA}, pd.NA]})) + res = df.select(a=pl.col("abc").struct.field("a")) + assert res.to_dicts() == [{"a": 1.0}, {"a": None}] + + +def test_empty_struct_raise() -> None: + with pytest.raises(ValueError): + pl.struct() + + +def test_named_exprs() -> None: + df = pl.DataFrame({"a": 1}) + schema = {"b": pl.Int64} + res = df.select(pl.struct(schema=schema, b=pl.col("a"))) + assert res.to_dict(as_series=False) == {"b": [{"b": 1}]} + assert res.schema["b"] == pl.Struct(schema) + + +def test_struct_outer_nullability_zip_18119() -> None: + df = pl.Series("int", [0, 1, 2, 3], dtype=pl.Int64).to_frame() + assert df.lazy().with_columns( + result=pl.when(pl.col("int") >= 1).then( + pl.struct( + a=pl.when(pl.col("int") % 2 == 1).then(True), + b=pl.when(pl.col("int") >= 2).then(False), + ) + ) + ).collect().to_dict(as_series=False) == { + "int": [0, 1, 2, 3], + "result": [ + None, + {"a": True, "b": None}, + {"a": None, "b": False}, + {"a": True, "b": False}, + ], + } + + +def test_struct_group_by_shift_18107() -> None: + df_in = pl.DataFrame( + { + "group": [1, 1, 1, 2, 2, 2], + "id": [1, 2, 3, 4, 5, 6], + "value": [ + {"lon": 20, "lat": 10}, + {"lon": 30, "lat": 20}, + {"lon": 40, "lat": 30}, + {"lon": 50, "lat": 40}, + {"lon": 60, "lat": 50}, + {"lon": 70, "lat": 60}, + ], + } + ) + + assert df_in.group_by("group", maintain_order=True).agg( + pl.col("value").shift(-1) + ).to_dict(as_series=False) == { + "group": [1, 2], + "value": [ + [{"lon": 30, "lat": 20}, {"lon": 40, "lat": 30}, None], + [{"lon": 60, "lat": 50}, {"lon": 70, "lat": 60}, None], + ], + } + + +def test_struct_chunked_zip_18119() -> None: + dtype = pl.Struct({"x": pl.Null}) + + a_dfs = [pl.DataFrame([pl.Series("a", [None] * i, dtype)]) for i in range(5)] + b_dfs = [pl.DataFrame([pl.Series("b", [None] * i, dtype)]) for i in range(5)] + mask_dfs = [ + pl.DataFrame([pl.Series("f", [None] * i, pl.Boolean)]) for i in range(5) + ] + + a = pl.concat([a_dfs[2], a_dfs[2], a_dfs[1]]) + b = pl.concat([b_dfs[4], b_dfs[1]]) + mask = pl.concat([mask_dfs[3], mask_dfs[2]]) + + df = pl.concat([a, b, mask], how="horizontal") + + assert_frame_equal( + df.select(pl.when(pl.col.f).then(pl.col.a).otherwise(pl.col.b)), + pl.DataFrame([pl.Series("a", [None] * 5, dtype)]), + ) + + +def test_struct_null_zip() -> None: + df = pl.Series("int", [], dtype=pl.Struct({"x": pl.Int64})).to_frame() + assert_frame_equal( + df.select(pl.when(pl.Series([True])).then(pl.col.int).otherwise(pl.col.int)), + pl.Series("int", [], dtype=pl.Struct({"x": pl.Int64})).to_frame(), + ) + + +@pytest.mark.parametrize("size", [0, 1, 2, 5, 9, 13, 42]) +def test_zfs_construction(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])) + assert a.len() == size + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_unnest(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).struct.unnest() + assert a.height == size + assert a.width == 0 + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_equality(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])) + b = pl.Series("a", [{}] * size, pl.Struct([])) + + assert_series_equal(a, b) + + assert_frame_equal( + a.to_frame(), + b.to_frame(), + ) + + +def test_zfs_nullable_when_otherwise() -> None: + a = pl.Series("a", [{}, None, {}, {}, None], pl.Struct([])) + b = pl.Series("b", [None, {}, None, {}, None], pl.Struct([])) + + df = pl.DataFrame([a, b]) + + df = df.select( + x=pl.when(pl.col.a.is_not_null()).then(pl.col.a).otherwise(pl.col.b), + y=pl.when(pl.col.a.is_null()).then(pl.col.a).otherwise(pl.col.b), + ) + + assert_series_equal(df["x"], pl.Series("x", [{}, {}, {}, {}, None], pl.Struct([]))) + assert_series_equal( + df["y"], pl.Series("y", [None, None, None, {}, None], pl.Struct([])) + ) + + +def test_zfs_struct_fns() -> None: + a = pl.Series("a", [{}], pl.Struct([])) + + assert a.struct.fields == [] + + # @TODO: This should really throw an error as per #19132 + assert a.struct.rename_fields(["a"]).struct.unnest().shape == (1, 0) + assert a.struct.rename_fields([]).struct.unnest().shape == (1, 0) + + assert_series_equal(a.struct.json_encode(), pl.Series("a", ["{}"], pl.String)) + + +@pytest.mark.parametrize("format", ["binary", "json"]) +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_serialization_roundtrip(format: pl.SerializationFormat, size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + f = io.BytesIO() + a.serialize(f, format=format) + + f.seek(0) + assert_frame_equal( + a, + pl.DataFrame.deserialize(f, format=format), + ) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_row_encoding(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])) + + df = pl.DataFrame([a, pl.Series("x", list(range(size)), pl.Int8)]) + + gb = df.lazy().group_by(["a", "x"]).agg(pl.all().min()).collect(engine="streaming") + + # We need to ignore the order because the group_by is non-deterministic + assert_frame_equal(gb, df, check_row_order=False) + + +@pytest.mark.may_fail_auto_streaming +def test_list_to_struct_19208() -> None: + df = pl.DataFrame( + { + "nested": [ + [{"a": 1}], + [], + [{"a": 3}], + ] + } + ) + assert pl.concat([df[0], df[1], df[2]]).select( + pl.col("nested").list.to_struct(_eager=True) + ).to_dict(as_series=False) == { + "nested": [{"field_0": {"a": 1}}, {"field_0": None}, {"field_0": {"a": 3}}] + } + + +def test_struct_reverse_outer_validity_19445() -> None: + assert_series_equal( + pl.Series([{"a": 1}, None]).reverse(), + pl.Series([None, {"a": 1}]), + ) + + +@pytest.mark.parametrize("maybe_swap", [lambda a, b: (a, b), lambda a, b: (b, a)]) +def test_struct_eq_missing_outer_validity_19156( + maybe_swap: Callable[[pl.Series, pl.Series], tuple[pl.Series, pl.Series]], +) -> None: + # Ensure that lit({'x': NULL}).eq_missing(lit(NULL)) => False + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([None, {"a": None, "b": None}]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([False, False])) + assert_series_equal(l.ne_missing(r), pl.Series([True, True])) + + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([None]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([False, True])) + assert_series_equal(l.ne_missing(r), pl.Series([True, False])) + + l, r = maybe_swap( # noqa: E741 + pl.Series([{"a": None, "b": None}, None]), + pl.Series([{"a": None, "b": None}]), + ) + + assert_series_equal(l.eq_missing(r), pl.Series([True, False])) + assert_series_equal(l.ne_missing(r), pl.Series([False, True])) + + +def test_struct_field_list_eval_17356() -> None: + df = pl.DataFrame( + { + "items": [ + [ + {"name": "John", "age": 30, "car": None}, + ], + [ + {"name": "Alice", "age": 65, "car": "Volvo"}, + ], + ] + } + ) + + assert df.select( + pl.col("items").list.eval( + pl.element().struct.with_fields( + pl.field("name").str.to_uppercase(), pl.field("car").fill_null("Mazda") + ) + ) + ).to_dict(as_series=False) == { + "items": [ + [{"name": "JOHN", "age": 30, "car": "Mazda"}], + [{"name": "ALICE", "age": 65, "car": "Mazda"}], + ], + } + + +@pytest.mark.parametrize("data", [[1], [[1]], {"a": 1}, [{"a": 1}]]) +def test_leaf_list_eq_19613(data: Any) -> None: + assert not pl.DataFrame([data]).equals(pl.DataFrame([[data]])) + + +def test_nested_object_raises_15237() -> None: + obj = object() + df = pl.DataFrame({"a": [obj]}) + with pytest.raises(InvalidOperationError, match="nested objects are not allowed"): + df.select(pl.struct("a")) + + +def test_empty_struct_with_fields_21095() -> None: + df = pl.DataFrame({"a": [{}, {}]}) + assert_frame_equal( + df.select(pl.col("a").struct.with_fields(a=pl.lit(42, pl.Int64))), + pl.DataFrame({"a": [{"a": 42}, {"a": 42}]}), + ) + assert_frame_equal( + df.select(pl.col("a").struct.with_fields(a=None)), + pl.DataFrame({"a": [{"a": None}, {"a": None}]}), + ) + + +def test_cast_to_struct_needs_field_14083() -> None: + with pytest.raises( + InvalidOperationError, match="must specify one field in the struct" + ): + pl.Series([1], dtype=pl.Int32).cast(pl.Struct) + + with pytest.raises( + InvalidOperationError, match="must specify one field in the struct" + ): + pl.Series([1], dtype=pl.Int32).cast(pl.Struct({"a": pl.UInt8, "b": pl.UInt8})) + + +@pytest.mark.filterwarnings("ignore:Comparisons with None always result in null.") +def test_zip_outer_validity_infinite_recursion_21267() -> None: + s = pl.Series("x", [None, None], pl.Struct({"f": pl.Null})) + assert_series_equal( + s.to_frame().select(pl.col.x.__eq__(None)).to_series(), + pl.Series("x", [None, None], pl.Boolean), + ) + + +def test_struct_arithmetic_broadcast_21376() -> None: + df = pl.DataFrame( + { + "struct1": [{"low": 1, "mid": 2, "up": 3}], + "list_struct": [ + [{"low": 1, "mid": 2, "up": 3}, {"low": 1, "mid": 2, "up": 3}] + ], + } + ) + expected = pl.DataFrame( + { + "add_struct": [{"low": 2, "mid": 4, "up": 6}] * 2, + } + ) + out = ( + df.with_row_index() + .explode("list_struct") + .select((pl.col("struct1") + pl.col("list_struct")).alias("add_struct")) + ) + assert_frame_equal(out, expected) + + +def test_struct_cast_string_multiple_chunks_21650() -> None: + df = pl.DataFrame({"a": [{"a": 1, "b": 2}]}) + df = pl.concat([df, df], rechunk=False) + result = df.select(pl.col("a").cast(pl.String)) + expected = pl.DataFrame({"a": ["{1,2}", "{1,2}"]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py new file mode 100644 index 000000000000..bb72e65bec2c --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -0,0 +1,2471 @@ +from __future__ import annotations + +import io +from datetime import date, datetime, time, timedelta, timezone +from typing import TYPE_CHECKING, Any, cast +from zoneinfo import ZoneInfo + +import hypothesis.strategies as st +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +import pytz +from hypothesis import given + +import polars as pl +import polars.selectors as cs +from polars.datatypes import DTYPE_TEMPORAL_UNITS +from polars.exceptions import ( + ComputeError, + InvalidOperationError, + PolarsInefficientMapWarning, +) +from polars.testing import ( + assert_frame_equal, + assert_series_equal, + assert_series_not_equal, +) +from tests.unit.conftest import DATETIME_DTYPES, TEMPORAL_DTYPES + +if TYPE_CHECKING: + from polars._typing import ( + Ambiguous, + PolarsTemporalType, + TimeUnit, + ) + + +def test_fill_null() -> None: + dtm = datetime.strptime("2021-01-01", "%Y-%m-%d") + s = pl.Series("A", [dtm, None]) + + for fill_val_datetime in (dtm, pl.lit(dtm)): + out = s.fill_null(fill_val_datetime) + + assert out.null_count() == 0 + assert out.dt[0] == dtm + assert out.dt[1] == dtm + + dt1 = date(2001, 1, 1) + dt2 = date(2001, 1, 2) + dt3 = date(2001, 1, 3) + + s = pl.Series("a", [dt1, dt2, dt3, None]) + dt_2 = date(2001, 1, 4) + + for fill_val_date in (dt_2, pl.lit(dt_2)): + out = s.fill_null(fill_val_date) + + assert out.null_count() == 0 + assert out.dt[0] == dt1 + assert out.dt[1] == dt2 + assert out.dt[-1] == dt_2 + + +def test_fill_null_temporal() -> None: + # test filling nulls with temporal literals across cols that use various timeunits + dtm = datetime.now() + dtm_ms = dtm.replace(microsecond=(dtm.microsecond // 1000) * 1000) + td = timedelta(days=7, seconds=45045) + tm = dtm.time() + dt = dtm.date() + + df = pl.DataFrame( + [ + [dtm, dtm_ms, dtm, dtm, dt, tm, td, td, td, td], + [None] * 10, + ], + schema=[ + ("a", pl.Datetime), + ("b", pl.Datetime("ms")), + ("c", pl.Datetime("us")), + ("d", pl.Datetime("ns")), + ("e", pl.Date), + ("f", pl.Time), + ("g", pl.Duration), + ("h", pl.Duration("ms")), + ("i", pl.Duration("us")), + ("j", pl.Duration("ns")), + ], + orient="row", + ) + + # fill literals + dtm_us_fill = dtm_ns_fill = datetime(2023, 12, 31, 23, 59, 59, 999999) + dtm_ms_fill = datetime(2023, 12, 31, 23, 59, 59, 999000) + td_us_fill = timedelta(days=7, seconds=45045, microseconds=123456) + td_ms_fill = timedelta(days=7, seconds=45045, microseconds=123000) + dt_fill = date(2023, 12, 31) + tm_fill = time(23, 59, 59) + + # apply literals via fill_null + ldf = df.lazy() + for temporal_literal in (dtm_ns_fill, td_us_fill, dt_fill, tm_fill): + ldf = ldf.fill_null(temporal_literal) + + # validate + assert ldf.collect().rows() == [ + (dtm, dtm_ms, dtm, dtm, dt, tm, td, td, td, td), # first row (no null values) + ( # second row (was composed entirely of nulls, now filled-in with literals) + dtm_us_fill, + dtm_ms_fill, + dtm_us_fill, + dtm_ns_fill, + dt_fill, + tm_fill, + td_us_fill, + td_ms_fill, + td_us_fill, + td_us_fill, + ), + ] + + +def test_filter_date() -> None: + dtcol = pl.col("date") + df = pl.DataFrame( + {"date": ["2020-01-02", "2020-01-03", "2020-01-04"], "index": [1, 2, 3]} + ).with_columns(dtcol.str.strptime(pl.Date, "%Y-%m-%d")) + assert df.rows() == [ + (date(2020, 1, 2), 1), + (date(2020, 1, 3), 2), + (date(2020, 1, 4), 3), + ] + + # filter by datetime + assert df.filter(dtcol <= pl.lit(datetime(2019, 1, 3))).is_empty() + assert df.filter(dtcol < pl.lit(datetime(2020, 1, 4))).rows() == df.rows()[:2] + assert df.filter(dtcol < pl.lit(datetime(2020, 1, 5))).rows() == df.rows() + + # filter by date + assert df.filter(dtcol <= pl.lit(date(2019, 1, 3))).is_empty() + assert df.filter(dtcol < pl.lit(date(2020, 1, 4))).rows() == df.rows()[:2] + assert df.filter(dtcol < pl.lit(date(2020, 1, 5))).rows() == df.rows() + + +def test_filter_time() -> None: + times = [time(8, 0), time(9, 0), time(10, 0)] + df = pl.DataFrame({"t": times}) + + assert df.filter(pl.col("t") <= pl.lit(time(7, 0))).is_empty() + assert df.filter(pl.col("t") < pl.lit(time(11, 0))).rows() == [(t,) for t in times] + assert df.filter(pl.col("t") < pl.lit(time(10, 0))).to_series().to_list() == [ + time(8, 0), + time(9, 0), + ] + + +def test_series_add_timedelta() -> None: + dates = pl.Series( + [datetime(2000, 1, 1), datetime(2027, 5, 19), datetime(2054, 10, 4)] + ) + out = pl.Series( + [datetime(2027, 5, 19), datetime(2054, 10, 4), datetime(2082, 2, 19)] + ) + assert_series_equal((dates + timedelta(days=10_000)), out) + + +def test_series_add_datetime() -> None: + deltas = pl.Series([timedelta(10_000), timedelta(20_000), timedelta(30_000)]) + out = pl.Series( + [datetime(2027, 5, 19), datetime(2054, 10, 4), datetime(2082, 2, 19)] + ) + assert_series_equal(deltas + pl.Series([datetime(2000, 1, 1)]), out) + + +def test_diff_datetime() -> None: + df = pl.DataFrame( + { + "timestamp": ["2021-02-01", "2021-03-1", "2850-04-1"], + "guild": [1, 2, 3], + "char": ["a", "a", "b"], + } + ) + out = ( + df.with_columns( + pl.col("timestamp").str.strptime(pl.Date, format="%Y-%m-%d"), + ).with_columns(pl.col("timestamp").diff().over("char", mapping_strategy="join")) + )["timestamp"] + assert_series_equal(out[0], out[1]) + + +def test_from_pydatetime() -> None: + datetimes = [ + datetime(2021, 1, 1), + datetime(2021, 1, 2), + datetime(2021, 1, 3), + datetime(2021, 1, 4, 12, 12), + None, + ] + s = pl.Series("name", datetimes) + assert s.dtype == pl.Datetime + assert s.name == "name" + assert s.null_count() == 1 + assert s.dt[0] == datetimes[0] + + dates = [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3), None] + s = pl.Series("name", dates) + assert s.dtype == pl.Date + assert s.name == "name" + assert s.null_count() == 1 + assert s.dt[0] == dates[0] + + +def test_int_to_python_datetime() -> None: + df = pl.DataFrame({"a": [100_000_000, 200_000_000]}).with_columns( + pl.col("a").cast(pl.Datetime).alias("b"), + pl.col("a").cast(pl.Datetime("ms")).alias("c"), + pl.col("a").cast(pl.Datetime("us")).alias("d"), + pl.col("a").cast(pl.Datetime("ns")).alias("e"), + ) + assert df.rows() == [ + ( + 100000000, + datetime(1970, 1, 1, 0, 1, 40), + datetime(1970, 1, 2, 3, 46, 40), + datetime(1970, 1, 1, 0, 1, 40), + datetime(1970, 1, 1, 0, 0, 0, 100000), + ), + ( + 200000000, + datetime(1970, 1, 1, 0, 3, 20), + datetime(1970, 1, 3, 7, 33, 20), + datetime(1970, 1, 1, 0, 3, 20), + datetime(1970, 1, 1, 0, 0, 0, 200000), + ), + ] + + assert df.select(pl.col(col).dt.timestamp() for col in ("c", "d", "e")).rows() == [ + (100000000000, 100000000, 100000), + (200000000000, 200000000, 200000), + ] + + assert df.select( + getattr(pl.col("a").cast(pl.Duration).dt, f"total_{unit}")().alias(f"u[{unit}]") + for unit in ("milliseconds", "microseconds", "nanoseconds") + ).rows() == [ + (100000, 100000000, 100000000000), + (200000, 200000000, 200000000000), + ] + + +def test_int_to_python_timedelta() -> None: + df = pl.DataFrame({"a": [100_001, 200_002]}).with_columns( + pl.col("a").cast(pl.Duration).alias("b"), + pl.col("a").cast(pl.Duration("ms")).alias("c"), + pl.col("a").cast(pl.Duration("us")).alias("d"), + pl.col("a").cast(pl.Duration("ns")).alias("e"), + ) + assert df.rows() == [ + ( + 100001, + timedelta(microseconds=100001), + timedelta(seconds=100, microseconds=1000), + timedelta(microseconds=100001), + timedelta(microseconds=100), + ), + ( + 200002, + timedelta(microseconds=200002), + timedelta(seconds=200, microseconds=2000), + timedelta(microseconds=200002), + timedelta(microseconds=200), + ), + ] + + assert df.select(pl.col(col).cast(pl.Int64) for col in ("c", "d", "e")).rows() == [ + (100001, 100001, 100001), + (200002, 200002, 200002), + ] + + +def test_datetime_consistency() -> None: + dt = datetime(2022, 7, 5, 10, 30, 45, 123455) + df = pl.DataFrame({"date": [dt]}) + + assert df["date"].dt[0] == dt + + for date_literal in ( + dt, + np.datetime64(dt, "us"), + np.datetime64(dt, "ns"), + ): + assert df.select(pl.lit(date_literal))["literal"].dt[0] == dt + assert df.filter(pl.col("date") == date_literal).rows() == [(dt,)] + + ddf = df.select( + pl.col("date"), + pl.lit(dt).alias("dt"), + pl.lit(dt).cast(pl.Datetime("ms")).alias("dt_ms"), + pl.lit(dt).cast(pl.Datetime("us")).alias("dt_us"), + pl.lit(dt).cast(pl.Datetime("ns")).alias("dt_ns"), + ) + assert ddf.schema == { + "date": pl.Datetime("us"), + "dt": pl.Datetime("us"), + "dt_ms": pl.Datetime("ms"), + "dt_us": pl.Datetime("us"), + "dt_ns": pl.Datetime("ns"), + } + assert ddf.select([pl.col(c).cast(int) for c in ddf.schema]).rows() == [ + ( + 1657017045123455, + 1657017045123455, + 1657017045123, + 1657017045123455, + 1657017045123455000, + ) + ] + + test_data = [ + datetime(2000, 1, 1, 1, 1, 1, 555555), + datetime(2514, 5, 30, 1, 53, 4, 986754), + datetime(3099, 12, 31, 23, 59, 59, 123456), + datetime(9999, 12, 31, 23, 59, 59, 999999), + ] + ddf = pl.DataFrame({"dtm": test_data}).with_columns( + pl.col("dtm").dt.nanosecond().alias("ns") + ) + assert ddf.rows() == [ + (test_data[0], 555555000), + (test_data[1], 986754000), + (test_data[2], 123456000), + (test_data[3], 999999000), + ] + # Same as above, but for tz-aware + test_data = [ + datetime(2000, 1, 1, 1, 1, 1, 555555, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(2514, 5, 30, 1, 53, 4, 986754, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(3099, 12, 31, 23, 59, 59, 123456, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=ZoneInfo("Asia/Kathmandu")), + ] + ddf = pl.DataFrame({"dtm": test_data}).with_columns( + pl.col("dtm").dt.nanosecond().alias("ns") + ) + assert ddf.rows() == [ + (test_data[0], 555555000), + (test_data[1], 986754000), + (test_data[2], 123456000), + (test_data[3], 999999000), + ] + # Similar to above, but check for no error when crossing DST + test_data = [ + datetime(2021, 11, 7, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 7, 1, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 7, 1, 0, fold=1, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 7, 2, 0, tzinfo=ZoneInfo("US/Central")), + ] + ddf = pl.DataFrame({"dtm": test_data}).select( + pl.col("dtm").dt.convert_time_zone("US/Central") + ) + assert ddf.rows() == [ + (test_data[0],), + (test_data[1],), + (test_data[2],), + (test_data[3],), + ] + + +def test_timezone() -> None: + ts = pa.timestamp("s") + data = pa.array([1000, 2000], type=ts) + s = cast(pl.Series, pl.from_arrow(data)) + + tz_ts = pa.timestamp("s", tz="America/New_York") + tz_data = pa.array([1000, 2000], type=tz_ts) + tz_s = cast(pl.Series, pl.from_arrow(tz_data)) + + # different timezones are not considered equal + # we check both `null_equal=True` and `null_equal=False` + # https://github.com/pola-rs/polars/issues/5023 + assert s.equals(tz_s, null_equal=False) is False + assert s.equals(tz_s, null_equal=True) is False + assert_series_not_equal(tz_s, s) + assert_series_equal(s.cast(int), tz_s.cast(int)) + + +def test_to_dicts() -> None: + now = datetime.now() + data = { + "a": now, + "b": now.date(), + "c": now.time(), + "d": timedelta(days=1, seconds=43200), + } + df = pl.DataFrame( + data, schema_overrides={"a": pl.Datetime("ns"), "d": pl.Duration("ns")} + ) + assert df.height == 1 + + d = df.to_dicts()[0] + for col in data: + assert d[col] == data[col] + assert isinstance(d[col], type(data[col])) + + +def test_to_list() -> None: + s = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date) + + out = s.to_list() + assert out[0] == date(2308, 4, 2) + + s = pl.Series("datetime", [a * 1_000_000 for a in [123543, 283478, 1243]]).cast( + pl.Datetime + ) + out = s.to_list() + assert out[0] == datetime(1970, 1, 2, 10, 19, 3) + + +def test_rows() -> None: + s0 = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date) + with pytest.deprecated_call( + match="`ExprDateTimeNameSpace.with_time_unit` is deprecated" + ): + s1 = ( + pl.Series("datetime", [a * 1_000_000 for a in [123543, 283478, 1243]]) + .cast(pl.Datetime) + .dt.with_time_unit("ns") + ) + df = pl.DataFrame([s0, s1]) + + rows = df.rows() + assert rows[0][0] == date(2308, 4, 2) + assert rows[0][1] == datetime(1970, 1, 1, 0, 2, 3, 543000) + + +@pytest.mark.parametrize( + ("one", "two"), + [ + (date(2001, 1, 1), date(2001, 1, 2)), + (datetime(2001, 1, 1), datetime(2001, 1, 2)), + (time(20, 10, 0), time(20, 10, 1)), + # also test if the conversion stays correct with wide date ranges + (date(201, 1, 1), date(201, 1, 2)), + (date(5001, 1, 1), date(5001, 1, 2)), + ], +) +def test_date_comp(one: PolarsTemporalType, two: PolarsTemporalType) -> None: + a = pl.Series("a", [one, two]) + assert (a == one).to_list() == [True, False] + assert (a == two).to_list() == [False, True] + assert (a != one).to_list() == [False, True] + assert (a > one).to_list() == [False, True] + assert (a >= one).to_list() == [True, True] + assert (a < one).to_list() == [False, False] + assert (a <= one).to_list() == [True, False] + + +def test_datetime_comp_tz_aware() -> None: + a = pl.Series( + "a", [datetime(2020, 1, 1), datetime(2020, 1, 2)] + ).dt.replace_time_zone("Asia/Kathmandu") + other = datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")) + result = a > other + expected = pl.Series("a", [False, True]) + assert_series_equal(result, expected) + + +def test_datetime_comp_tz_aware_invalid() -> None: + a = pl.Series( + "a", [datetime(2020, 1, 1), datetime(2020, 1, 2)] + ).dt.replace_time_zone("Asia/Kathmandu") + other = datetime(2020, 1, 1) + with pytest.raises( + TypeError, + match="Datetime time zone None does not match Series timezone 'Asia/Kathmandu'", + ): + _ = a > other + + +def test_to_arrow() -> None: + date_series = pl.Series("dates", ["2022-01-16", "2022-01-17"]).str.strptime( + pl.Date, "%Y-%m-%d" + ) + arr = date_series.to_arrow() + assert arr.type == pa.date32() + + +def test_explode_date() -> None: + datetimes = [ + datetime(2021, 12, 1, 0, 0), + datetime(2021, 12, 1, 0, 0), + datetime(2021, 12, 1, 0, 0), + datetime(2021, 12, 1, 0, 0), + ] + dates = [ + date(2021, 12, 1), + date(2021, 12, 1), + date(2021, 12, 1), + date(2021, 12, 1), + ] + for dclass, values in ((date, dates), (datetime, datetimes)): + df = pl.DataFrame( + { + "a": values, + "b": ["a", "b", "a", "b"], + "c": [1.0, 2.0, 1.5, 2.5], + } + ) + out = ( + df.group_by("b", maintain_order=True) + .agg([pl.col("a"), pl.col("c").pct_change()]) + .explode(["a", "c"]) + ) + assert out.shape == (4, 3) + assert out.rows() == [ + ("a", dclass(2021, 12, 1), None), + ("a", dclass(2021, 12, 1), 0.5), + ("b", dclass(2021, 12, 1), None), + ("b", dclass(2021, 12, 1), 0.25), + ] + + +def test_microseconds_accuracy() -> None: + timestamps = [ + datetime(2600, 1, 1, 0, 0, 0, 123456), + datetime(2800, 1, 1, 0, 0, 0, 456789), + ] + a = pa.Table.from_arrays( + arrays=[timestamps, [128, 256]], + schema=pa.schema( + [ + ("timestamp", pa.timestamp("us")), + ("value", pa.int16()), + ] + ), + ) + df = cast(pl.DataFrame, pl.from_arrow(a)) + assert df["timestamp"].to_list() == timestamps + + +def test_read_utc_times_parquet() -> None: + df = pd.DataFrame( + data={ + "Timestamp": pd.date_range( + "2022-01-01T00:00+00:00", "2022-01-01T10:00+00:00", freq="h" + ) + } + ) + f = io.BytesIO() + df.to_parquet(f) + f.seek(0) + df_in = pl.read_parquet(f) + tz = ZoneInfo("UTC") + assert df_in["Timestamp"][0] == datetime(2022, 1, 1, 0, 0, tzinfo=tz) + + +@pytest.mark.parametrize( + "ts", + [ + # Pandas support all of the following timezone and the pandas parsing behavior + # is + # - timezone name: pytz.timezone on <=2.x and ZoneInfo on 3.x. + # - timezone offset: datetime.timedelta. + # pytz.timezone. + # ZoneInfo. + # pytz.FixedOffset + # datetime.timedelta. + pd.Timestamp("20200101 00:00", tz=pytz.timezone("America/New_York")), + pd.Timestamp("20200101 00:00", tz=ZoneInfo("America/New_York")), + # TODO: polars currently doesn't retain FixedOffset timezone. They will be + # converted to UTC. Uncomment the following two tests once we support + # FixedOffset timezone. + # pd.Timestamp("20200101 00:00", tz=pytz.FixedOffset(-300)), + # pd.Timestamp("20200101 00:00", tz=timezone(timedelta(days=-1, seconds=68400))) + ], +) +def test_convert_pandas_timezone_info(ts: pd.Timestamp) -> None: + df1 = pl.DataFrame({"date": [ts]}) + df2 = pl.select(date=pl.lit(ts)) + + for df in (df1, df2): + df_ts = df["date"][0] + assert df_ts == ts, (df_ts, ts) + assert df_ts.tzinfo is not None + assert ts.tzinfo is not None + assert df_ts.tzinfo.utcoffset(df_ts) == ts.tzinfo.utcoffset(ts), ( + df_ts, + ts, + ) + + +@pytest.mark.parametrize( + "ts", + [ + pd.Timestamp("20200101 00:00", tz=pytz.FixedOffset(-300)), + pd.Timestamp("20200101 00:00", tz=timezone(timedelta(days=-1, seconds=68400))), + ], +) +def test_convert_pandas_timezone_info_fixed_offset(ts: pd.Timestamp) -> None: + # TODO: polars currently doesn't retain FixedOffset timezone. They will + # be converted to UTC. Remove this test once we support FixedOffset + # timezone. See test_convert_pandas_timezone_info for more details. + df1 = pl.DataFrame({"date": [ts]}) + df2 = pl.select(date=pl.lit(ts)) + + for df in (df1, df2): + df_ts = df["date"][0] + assert df_ts == ts, (df_ts, ts) + assert df_ts.tzinfo == ZoneInfo("UTC") + + +def test_asof_join_tolerance_grouper() -> None: + from datetime import date + + df1 = pl.DataFrame({"date": [date(2020, 1, 5), date(2020, 1, 10)], "by": [1, 1]}) + df2 = pl.DataFrame( + { + "date": [date(2020, 1, 5), date(2020, 1, 6)], + "by": [1, 1], + "values": [100, 200], + } + ) + + out = df1.join_asof(df2, by="by", on=pl.col("date").set_sorted(), tolerance="3d") + + expected = pl.DataFrame( + { + "date": [date(2020, 1, 5), date(2020, 1, 10)], + "by": [1, 1], + "date_right": [date(2020, 1, 5), None], + "values": [100, None], + } + ) + + assert_frame_equal(out, expected) + + +def test_rolling_mean_3020() -> None: + df = pl.DataFrame( + { + "Date": [ + "1998-04-12", + "1998-04-19", + "1998-04-26", + "1998-05-03", + "1998-05-10", + "1998-05-17", + "1998-05-24", + ], + "val": range(7), + } + ).with_columns(pl.col("Date").str.strptime(pl.Date).set_sorted()) + + period: str | timedelta + for period in ("1w", timedelta(days=7)): + result = df.rolling(index_column="Date", period=period).agg( + pl.col("val").mean().alias("val_mean") + ) + expected = pl.DataFrame( + { + "Date": [ + date(1998, 4, 12), + date(1998, 4, 19), + date(1998, 4, 26), + date(1998, 5, 3), + date(1998, 5, 10), + date(1998, 5, 17), + date(1998, 5, 24), + ], + "val_mean": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + ) + assert_frame_equal(result, expected) + + +def test_asof_join() -> None: + format = "%F %T%.3f" + dates = [ + "2016-05-25 13:30:00.023", + "2016-05-25 13:30:00.023", + "2016-05-25 13:30:00.030", + "2016-05-25 13:30:00.041", + "2016-05-25 13:30:00.048", + "2016-05-25 13:30:00.049", + "2016-05-25 13:30:00.072", + "2016-05-25 13:30:00.075", + ] + ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"] + quotes = pl.DataFrame( + { + "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), + "ticker": ticker, + "bid": [720.5, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + } + ).set_sorted("dates") + dates = [ + "2016-05-25 13:30:00.023", + "2016-05-25 13:30:00.038", + "2016-05-25 13:30:00.048", + "2016-05-25 13:30:00.048", + "2016-05-25 13:30:00.048", + ] + ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"] + trades = pl.DataFrame( + { + "dates": pl.Series(dates).str.strptime(pl.Datetime, format=format), + "ticker": ticker, + "bid": [51.95, 51.95, 720.77, 720.92, 98.0], + } + ).set_sorted("dates") + assert trades.schema == { + "dates": pl.Datetime("ms"), + "ticker": pl.String, + "bid": pl.Float64, + } + out = trades.join_asof(quotes, on="dates", strategy="backward") + + assert out.schema == { + "dates": pl.Datetime("ms"), + "ticker": pl.String, + "bid": pl.Float64, + "ticker_right": pl.String, + "bid_right": pl.Float64, + } + assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"] + assert (out["dates"].cast(int)).to_list() == [ + 1464183000023, + 1464183000038, + 1464183000048, + 1464183000048, + 1464183000048, + ] + assert trades.join_asof(quotes, on="dates", strategy="forward")[ + "bid_right" + ].to_list() == [720.5, 51.99, 720.5, 720.5, 720.5] + + out = trades.join_asof(quotes, on="dates", by="ticker") + assert out["bid_right"].to_list() == [51.95, 51.97, 720.5, 720.5, None] + + out = quotes.join_asof(trades, on="dates", by="ticker") + assert out["bid_right"].to_list() == [ + None, + 51.95, + 51.95, + 51.95, + 720.92, + 98.0, + 720.92, + 51.95, + ] + assert quotes.join_asof(trades, on="dates", strategy="backward", tolerance="5ms")[ + "bid_right" + ].to_list() == [51.95, 51.95, None, 51.95, 98.0, 98.0, None, None] + assert quotes.join_asof(trades, on="dates", strategy="forward", tolerance="5ms")[ + "bid_right" + ].to_list() == [51.95, 51.95, None, None, 720.77, None, None, None] + + assert trades.join_asof(quotes, on="dates", strategy="nearest")[ + "bid_right" + ].to_list() == [51.95, 51.99, 720.5, 720.5, 720.5] + assert quotes.join_asof(trades, on="dates", strategy="nearest")[ + "bid_right" + ].to_list() == [51.95, 51.95, 51.95, 51.95, 98.0, 98.0, 98.0, 98.0] + + assert trades.sort(by=["ticker", "dates"]).join_asof( + quotes.sort(by=["ticker", "dates"]), on="dates", by="ticker", strategy="nearest" + )["bid_right"].to_list() == [97.99, 720.5, 720.5, 51.95, 51.99] + assert quotes.sort(by=["ticker", "dates"]).join_asof( + trades.sort(by=["ticker", "dates"]), on="dates", by="ticker", strategy="nearest" + )["bid_right"].to_list() == [ + 98.0, + 720.92, + 720.92, + 720.92, + 51.95, + 51.95, + 51.95, + 51.95, + ] + + +@pytest.mark.parametrize( + ("skip_nulls", "expected_value"), + [ + (True, None), + (False, datetime(2010, 9, 12)), + ], +) +def test_temporal_dtypes_map_elements( + skip_nulls: bool, expected_value: datetime | None +) -> None: + df = pl.DataFrame( + {"timestamp": [1284286794000, None, 1234567890000]}, + schema=[("timestamp", pl.Datetime("ms"))], + ) + const_dtm = datetime(2010, 9, 12) + + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Replace this expression.*lambda x:", + ): + result = df.with_columns( + # don't actually do this; native expressions are MUCH faster ;) + pl.col("timestamp") + .map_elements( + lambda x: const_dtm, + skip_nulls=skip_nulls, + return_dtype=pl.Datetime, + ) + .alias("const_dtm"), + # note: the below now trigger a PolarsInefficientMapWarning + pl.col("timestamp") + .map_elements( + lambda x: x and x.date(), + skip_nulls=skip_nulls, + return_dtype=pl.Date, + ) + .alias("date"), + pl.col("timestamp") + .map_elements( + lambda x: x and x.time(), + skip_nulls=skip_nulls, + return_dtype=pl.Time, + ) + .alias("time"), + ) + expected = pl.DataFrame( + [ + ( + datetime(2010, 9, 12, 10, 19, 54), + datetime(2010, 9, 12, 0, 0), + date(2010, 9, 12), + time(10, 19, 54), + ), + (None, expected_value, None, None), + ( + datetime(2009, 2, 13, 23, 31, 30), + datetime(2010, 9, 12, 0, 0), + date(2009, 2, 13), + time(23, 31, 30), + ), + ], + schema={ + "timestamp": pl.Datetime("ms"), + "const_dtm": pl.Datetime("us"), + "date": pl.Date, + "time": pl.Time, + }, + orient="row", + ) + assert_frame_equal(result, expected) + + +def test_timelike_init() -> None: + durations = [timedelta(days=1), timedelta(days=2)] + dates = [date(2022, 1, 1), date(2022, 1, 2)] + datetimes = [datetime(2022, 1, 1), datetime(2022, 1, 2)] + + for ts in [durations, dates, datetimes]: + s = pl.Series(ts) + assert s.to_list() == ts + + +def test_timedelta_timeunit_init() -> None: + td_us = timedelta(days=7, seconds=45045, microseconds=123456) + td_ms = timedelta(days=7, seconds=45045, microseconds=123000) + + df = pl.DataFrame( + [[td_us, td_us, td_us]], + schema=[ + ("x", pl.Duration("ms")), + ("y", pl.Duration("us")), + ("z", pl.Duration("ns")), + ], + orient="row", + ) + assert df.rows() == [(td_ms, td_us, td_us)] + + +def test_duration_filter() -> None: + df = pl.DataFrame( + { + "start_date": [date(2022, 1, 1), date(2022, 1, 1), date(2022, 1, 1)], + "end_date": [date(2022, 1, 7), date(2022, 2, 20), date(2023, 1, 1)], + } + ).with_columns((pl.col("end_date") - pl.col("start_date")).alias("time_passed")) + + assert df.filter(pl.col("time_passed") < timedelta(days=30)).rows() == [ + (date(2022, 1, 1), date(2022, 1, 7), timedelta(days=6)) + ] + assert df.filter(pl.col("time_passed") >= timedelta(days=30)).rows() == [ + (date(2022, 1, 1), date(2022, 2, 20), timedelta(days=50)), + (date(2022, 1, 1), date(2023, 1, 1), timedelta(days=365)), + ] + + +def test_agg_logical() -> None: + dates = [date(2001, 1, 1), date(2002, 1, 1)] + s = pl.Series(dates) + assert s.max() == dates[1] + assert s.min() == dates[0] + + +def test_from_time_arrow() -> None: + pa_times = pa.table([pa.array([10, 20, 30], type=pa.time32("s"))], names=["times"]) + + result: pl.DataFrame = pl.from_arrow(pa_times) # type: ignore[assignment] + + assert result.to_series().to_list() == [ + time(0, 0, 10), + time(0, 0, 20), + time(0, 0, 30), + ] + assert result.rows() == [ + (time(0, 0, 10),), + (time(0, 0, 20),), + (time(0, 0, 30),), + ] + + +def test_timedelta_from() -> None: + as_dict = { + "A": [1, 2], + "B": [timedelta(seconds=4633), timedelta(seconds=50)], + } + as_rows = [ + { + "A": 1, + "B": timedelta(seconds=4633), + }, + { + "A": 2, + "B": timedelta(seconds=50), + }, + ] + assert_frame_equal(pl.DataFrame(as_dict), pl.DataFrame(as_rows)) + + +def test_duration_aggregations() -> None: + df = pl.DataFrame( + { + "group": ["A", "B", "A", "B"], + "start": [ + datetime(2022, 1, 1), + datetime(2022, 1, 2), + datetime(2022, 1, 3), + datetime(2022, 1, 4), + ], + "end": [ + datetime(2022, 1, 2), + datetime(2022, 1, 4), + datetime(2022, 1, 6), + datetime(2022, 1, 6), + ], + } + ) + df = df.with_columns((pl.col("end") - pl.col("start")).alias("duration")) + assert df.group_by("group", maintain_order=True).agg( + [ + pl.col("duration").mean().alias("mean"), + pl.col("duration").sum().alias("sum"), + pl.col("duration").min().alias("min"), + pl.col("duration").max().alias("max"), + pl.col("duration").quantile(0.1).alias("quantile"), + pl.col("duration").median().alias("median"), + pl.col("duration").alias("list"), + ] + ).to_dict(as_series=False) == { + "group": ["A", "B"], + "mean": [timedelta(days=2), timedelta(days=2)], + "sum": [timedelta(days=4), timedelta(days=4)], + "min": [timedelta(days=1), timedelta(days=2)], + "max": [timedelta(days=3), timedelta(days=2)], + "quantile": [timedelta(days=1), timedelta(days=2)], + "median": [timedelta(days=2), timedelta(days=2)], + "list": [ + [timedelta(days=1), timedelta(days=3)], + [timedelta(days=2), timedelta(days=2)], + ], + } + + +def test_datetime_units() -> None: + df = pl.DataFrame( + { + "ns": pl.datetime_range( + datetime(2020, 1, 1), + datetime(2020, 5, 1), + "1mo", + time_unit="ns", + eager=True, + ), + "us": pl.datetime_range( + datetime(2020, 1, 1), + datetime(2020, 5, 1), + "1mo", + time_unit="us", + eager=True, + ), + "ms": pl.datetime_range( + datetime(2020, 1, 1), + datetime(2020, 5, 1), + "1mo", + time_unit="ms", + eager=True, + ), + } + ) + names = set(df.columns) + + for unit in DTYPE_TEMPORAL_UNITS: + subset = names - {unit} + + assert ( + len(set(df.select([pl.all().exclude(pl.Datetime(unit))]).columns) - subset) + == 0 + ) + + +def test_datetime_instance_selection() -> None: + test_data = { + "ns": [datetime(2022, 12, 31, 1, 2, 3)], + "us": [datetime(2022, 12, 31, 4, 5, 6)], + "ms": [datetime(2022, 12, 31, 7, 8, 9)], + } + df = pl.DataFrame( + data=test_data, + schema=[ + ("ns", pl.Datetime("ns")), + ("us", pl.Datetime("us")), + ("ms", pl.Datetime("ms")), + ], + ) + for time_unit in DTYPE_TEMPORAL_UNITS: + res = df.select(pl.col([pl.Datetime(time_unit)])).dtypes + assert res == [pl.Datetime(time_unit)] + assert len(df.filter(pl.col(time_unit) == test_data[time_unit][0])) == 1 + + assert list(df.select(pl.exclude(DATETIME_DTYPES))) == [] + + +def test_sum_duration() -> None: + assert pl.DataFrame( + [ + {"name": "Jen", "duration": timedelta(seconds=60)}, + {"name": "Mike", "duration": timedelta(seconds=30)}, + {"name": "Jen", "duration": timedelta(seconds=60)}, + ] + ).select( + pl.col("duration").sum(), + pl.col("duration").dt.total_seconds().alias("sec").sum(), + ).to_dict(as_series=False) == { + "duration": [timedelta(seconds=150)], + "sec": [150], + } + + +def test_supertype_timezones_4174() -> None: + df = pl.DataFrame( + { + "dt": pl.datetime_range( + datetime(2020, 3, 1), datetime(2020, 5, 1), "1mo", eager=True + ), + } + ).with_columns( + pl.col("dt").dt.replace_time_zone("Europe/London").name.suffix("_London") + ) + + # test if this runs without error + date_to_fill = df["dt_London"][0] + df.with_columns(df["dt_London"].shift(fill_value=date_to_fill)) + + +def test_from_dict_tu_consistency() -> None: + tz = ZoneInfo("PRC") + dt = datetime(2020, 8, 1, 12, 0, 0, tzinfo=tz) + from_dict = pl.from_dict({"dt": [dt]}) + from_dicts = pl.from_dicts([{"dt": dt}]) + + assert from_dict.dtypes == from_dicts.dtypes + + +def test_date_arr_concat() -> None: + expected = {"d": [[date(2000, 1, 1), date(2000, 1, 1)]]} + + # type date + df = pl.DataFrame({"d": [date(2000, 1, 1)]}) + assert ( + df.select(pl.col("d").list.concat(pl.col("d"))).to_dict(as_series=False) + == expected + ) + # type list[date] + df = pl.DataFrame({"d": [[date(2000, 1, 1)]]}) + assert ( + df.select(pl.col("d").list.concat(pl.col("d"))).to_dict(as_series=False) + == expected + ) + + +def test_date_timedelta() -> None: + df = pl.DataFrame( + {"date": pl.date_range(date(2001, 1, 1), date(2001, 1, 3), "1d", eager=True)} + ) + assert df.with_columns( + (pl.col("date") + timedelta(days=1)).alias("date_plus_one"), + (pl.col("date") - timedelta(days=1)).alias("date_min_one"), + ).to_dict(as_series=False) == { + "date": [date(2001, 1, 1), date(2001, 1, 2), date(2001, 1, 3)], + "date_plus_one": [date(2001, 1, 2), date(2001, 1, 3), date(2001, 1, 4)], + "date_min_one": [date(2000, 12, 31), date(2001, 1, 1), date(2001, 1, 2)], + } + + +def test_datetime_hashes() -> None: + dtypes = ( + pl.Datetime, + pl.Datetime("us"), + pl.Datetime("us", "UTC"), + pl.Datetime("us", "Zulu"), + ) + assert len({hash(tp) for tp in (dtypes)}) == 4 + + +def test_datetime_string_casts() -> None: + df = pl.DataFrame( + { + "x": [1661855445123], + "y": [1661855445123456], + "z": [1661855445123456789], + }, + schema=[ + ("x", pl.Datetime("ms")), + ("y", pl.Datetime("us")), + ("z", pl.Datetime("ns")), + ], + ) + assert df.select( + [pl.col("x").dt.to_string("%F %T").alias("w")] + + [pl.col(d).cast(str) for d in df.columns] + ).rows() == [ + ( + "2022-08-30 10:30:45", + "2022-08-30 10:30:45.123", + "2022-08-30 10:30:45.123456", + "2022-08-30 10:30:45.123456789", + ) + ] + assert df.select( + pl.col("x").dt.to_string("iso").name.suffix(":iso"), + pl.col("y").dt.to_string("iso").name.suffix(":iso"), + pl.col("z").dt.to_string("iso").name.suffix(":iso"), + ).rows() == [ + ( + "2022-08-30 10:30:45.123", + "2022-08-30 10:30:45.123456", + "2022-08-30 10:30:45.123456789", + ) + ] + assert df.select( + pl.col("x").dt.to_string("iso:strict").name.suffix(":iso:strict"), + pl.col("y").dt.to_string("iso:strict").name.suffix(":iso:strict"), + pl.col("z").dt.to_string("iso:strict").name.suffix(":iso:strict"), + ).rows() == [ + ( + "2022-08-30T10:30:45.123", + "2022-08-30T10:30:45.123456", + "2022-08-30T10:30:45.123456789", + ) + ] + + +def test_temporal_to_string_iso_default() -> None: + df = pl.DataFrame( + { + "td": [ + timedelta(days=-1, seconds=-42), + timedelta(days=14, hours=-10, microseconds=1001), + timedelta(seconds=0), + ], + "tm": [ + time(1, 2, 3, 456789), + time(23, 59, 9, 101), + time(0), + ], + "dt": [ + date(1999, 3, 1), + date(2020, 5, 3), + date(2077, 7, 5), + ], + "dtm": [ + datetime(1980, 8, 10, 0, 10, 20), + datetime(2010, 10, 20, 8, 25, 35), + datetime(2040, 12, 30, 16, 40, 50), + ], + } + ).with_columns(dtm_tz=pl.col("dtm").dt.replace_time_zone("Asia/Kathmandu")) + + df_stringified = df.select( + pl.col("td").dt.to_string("polars").alias("td_pl"), + cs.temporal().dt.to_string().name.suffix(":iso"), + cs.datetime().dt.to_string("iso:strict").name.suffix(":iso:strict"), + ) + assert df_stringified.to_dict(as_series=False) == { + "td_pl": [ + "-1d -42s", + "13d 14h 1001µs", + "0µs", + ], + "td:iso": [ + "-P1DT42S", + "P13DT14H0.001001S", + "PT0S", + ], + "tm:iso": [ + "01:02:03.456789", + "23:59:09.000101", + "00:00:00", + ], + "dt:iso": [ + "1999-03-01", + "2020-05-03", + "2077-07-05", + ], + "dtm:iso": [ + "1980-08-10 00:10:20.000000", + "2010-10-20 08:25:35.000000", + "2040-12-30 16:40:50.000000", + ], + "dtm_tz:iso": [ + "1980-08-10 00:10:20.000000+05:30", + "2010-10-20 08:25:35.000000+05:45", + "2040-12-30 16:40:50.000000+05:45", + ], + "dtm:iso:strict": [ + "1980-08-10T00:10:20.000000", + "2010-10-20T08:25:35.000000", + "2040-12-30T16:40:50.000000", + ], + "dtm_tz:iso:strict": [ + "1980-08-10T00:10:20.000000+05:30", + "2010-10-20T08:25:35.000000+05:45", + "2040-12-30T16:40:50.000000+05:45", + ], + } + + +def test_temporal_to_string_error() -> None: + df = pl.DataFrame({"td": [timedelta(days=1)], "dt": [date(2024, 11, 25)]}) + with pytest.raises( + InvalidOperationError, + match="'polars' is not a valid `to_string` format for date dtype expressions", + ): + df.select(cs.temporal().dt.to_string("polars")) + + +def test_iso_year() -> None: + assert pl.Series([datetime(2022, 1, 1, 7, 8, 40)]).dt.iso_year()[0] == 2021 + assert pl.Series([date(2022, 1, 1)]).dt.iso_year()[0] == 2021 + + +def test_replace_time_zone() -> None: + ny = ZoneInfo("America/New_York") + assert pl.DataFrame({"a": [datetime(2022, 9, 25, 14)]}).with_columns( + pl.col("a").dt.replace_time_zone("America/New_York").alias("b") + ).to_dict(as_series=False) == { + "a": [datetime(2022, 9, 25, 14, 0)], + "b": [datetime(2022, 9, 25, 14, 0, tzinfo=ny)], + } + + +def test_replace_time_zone_non_existent_null() -> None: + result = ( + pl.Series(["2021-03-28 02:30", "2021-03-28 03:30"]) + .str.to_datetime() + .dt.replace_time_zone("Europe/Warsaw", non_existent="null") + ) + expected = pl.Series([None, datetime(2021, 3, 28, 3, 30)]).dt.replace_time_zone( + "Europe/Warsaw" + ) + assert_series_equal(result, expected) + + +def test_invalid_non_existent() -> None: + with pytest.raises( + ValueError, match="`non_existent` must be one of {'null', 'raise'}, got cabbage" + ): + ( + pl.Series([datetime(2020, 1, 1)]).dt.replace_time_zone( + "Europe/Warsaw", + non_existent="cabbage", # type: ignore[arg-type] + ) + ) + + +@pytest.mark.parametrize( + ("to_tz", "tzinfo"), + [ + ("America/Barbados", ZoneInfo("America/Barbados")), + (None, None), + ], +) +@pytest.mark.parametrize("from_tz", ["Asia/Seoul", None]) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_replace_time_zone_from_to( + from_tz: str, + to_tz: str, + tzinfo: timezone | ZoneInfo, + time_unit: TimeUnit, +) -> None: + ts = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime(time_unit)) + result = ts.dt.replace_time_zone(from_tz).dt.replace_time_zone(to_tz).item() + expected = datetime(2020, 1, 1, 0, 0, tzinfo=tzinfo) + assert result == expected + + +def test_strptime_with_tz() -> None: + result = ( + pl.Series(["2020-01-01 03:00:00"]) + .str.strptime(pl.Datetime("us", "Africa/Monrovia")) + .item() + ) + assert result == datetime(2020, 1, 1, 3, tzinfo=ZoneInfo("Africa/Monrovia")) + + +@pytest.mark.parametrize( + ("time_unit", "time_zone"), + [ + ("us", "Europe/London"), + ("ms", None), + ("ns", "Africa/Lagos"), + ], +) +def test_strptime_empty(time_unit: TimeUnit, time_zone: str | None) -> None: + ts = ( + pl.Series([None]) + .cast(pl.String) + .str.strptime(pl.Datetime(time_unit, time_zone)) + ) + assert ts.dtype == pl.Datetime(time_unit, time_zone) + + +def test_strptime_with_invalid_tz() -> None: + with pytest.raises(ComputeError, match="unable to parse time zone: 'foo'"): + pl.Series(["2020-01-01 03:00:00"]).str.strptime(pl.Datetime("us", "foo")) + with pytest.raises( + ComputeError, + match="unable to parse time zone: 'foo'", + ): + pl.Series(["2020-01-01 03:00:00+01:00"]).str.strptime( + pl.Datetime("us", "foo"), "%Y-%m-%d %H:%M:%S%z" + ) + + +def test_strptime_unguessable_format() -> None: + with pytest.raises( + ComputeError, + match="could not find an appropriate format to parse dates, please define a format", + ): + pl.Series(["foobar"]).str.strptime(pl.Datetime) + + +def test_convert_time_zone_invalid() -> None: + ts = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime) + with pytest.raises(ComputeError, match="unable to parse time zone: 'foo'"): + ts.dt.replace_time_zone("UTC").dt.convert_time_zone("foo") + + +def test_convert_time_zone_lazy_schema() -> None: + ts_us = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime("us", "UTC")) + ts_ms = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime("ms", "UTC")) + ldf = pl.DataFrame({"ts_us": ts_us, "ts_ms": ts_ms}).lazy() + result = ldf.with_columns( + pl.col("ts_us").dt.convert_time_zone("America/New_York").alias("ts_us_ny"), + pl.col("ts_ms").dt.convert_time_zone("America/New_York").alias("ts_us_kt"), + ).collect_schema() + expected = { + "ts_us": pl.Datetime("us", "UTC"), + "ts_ms": pl.Datetime("ms", "UTC"), + "ts_us_ny": pl.Datetime("us", "America/New_York"), + "ts_us_kt": pl.Datetime("ms", "America/New_York"), + } + assert result == expected + + +def test_convert_time_zone_on_tz_naive() -> None: + ts = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime) + result = ts.dt.convert_time_zone("Asia/Kathmandu").item() + expected = datetime(2020, 1, 1, 5, 45, tzinfo=ZoneInfo("Asia/Kathmandu")) + assert result == expected + result = ( + ts.dt.replace_time_zone("UTC").dt.convert_time_zone("Asia/Kathmandu").item() + ) + assert result == expected + + +def test_tz_aware_get_idx_5010() -> None: + when = int(datetime(2022, 1, 1, 12, tzinfo=ZoneInfo("Asia/Shanghai")).timestamp()) + a = pa.array([when]).cast(pa.timestamp("s", tz="Asia/Shanghai")) + assert int(pl.from_arrow(a)[0].timestamp()) == when # type: ignore[union-attr] + + +def test_tz_datetime_duration_arithm_5221() -> None: + run_datetimes = [ + datetime.fromisoformat("2022-01-01T00:00:00+00:00"), + datetime.fromisoformat("2022-01-02T00:00:00+00:00"), + ] + out = pl.DataFrame( + data={"run_datetime": run_datetimes}, + schema=[("run_datetime", pl.Datetime(time_zone="UTC"))], + ) + out = out.with_columns(pl.col("run_datetime") + pl.duration(days=1)) + utc = ZoneInfo("UTC") + assert out.to_dict(as_series=False) == { + "run_datetime": [ + datetime(2022, 1, 2, 0, 0, tzinfo=utc), + datetime(2022, 1, 3, 0, 0, tzinfo=utc), + ] + } + + +def test_auto_infer_time_zone() -> None: + dt = datetime(2022, 10, 17, 10, tzinfo=ZoneInfo("Asia/Shanghai")) + s = pl.Series([dt]) + assert s.dtype == pl.Datetime("us", "Asia/Shanghai") + assert s[0] == dt + + +def test_logical_nested_take() -> None: + frame = pl.DataFrame( + { + "ix": [2, 1], + "dt": [[datetime(2001, 1, 1)], [datetime(2001, 1, 2)]], + "d": [[date(2001, 1, 1)], [date(2001, 1, 1)]], + "t": [[time(10)], [time(10)]], + "del": [[timedelta(10)], [timedelta(10)]], + "str": [[{"a": time(10)}], [{"a": time(10)}]], + } + ) + out = frame.sort(by="ix") + + assert out.dtypes[:-1] == [ + pl.Int64, + pl.List(pl.Datetime("us")), + pl.List(pl.Date), + pl.List(pl.Time), + pl.List(pl.Duration("us")), + ] + assert out.to_dict(as_series=False) == { + "ix": [1, 2], + "dt": [[datetime(2001, 1, 2, 0, 0)], [datetime(2001, 1, 1, 0, 0)]], + "d": [[date(2001, 1, 1)], [date(2001, 1, 1)]], + "t": [[time(10, 0)], [time(10, 0)]], + "del": [[timedelta(days=10)], [timedelta(days=10)]], + "str": [[{"a": time(10, 0)}], [{"a": time(10, 0)}]], + } + + +def test_replace_time_zone_from_naive() -> None: + df = pl.DataFrame( + { + "date": pl.Series(["2022-01-01", "2022-01-02"]).str.strptime( + pl.Date, "%Y-%m-%d" + ) + } + ) + + assert df.select( + pl.col("date").cast(pl.Datetime).dt.replace_time_zone("America/New_York") + ).to_dict(as_series=False) == { + "date": [ + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 1, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + ] + } + + +@pytest.mark.parametrize( + ("ambiguous", "expected"), + [ + ( + "latest", + datetime(2018, 10, 28, 2, 30, fold=0, tzinfo=ZoneInfo("Europe/Brussels")), + ), + ( + "earliest", + datetime(2018, 10, 28, 2, 30, fold=1, tzinfo=ZoneInfo("Europe/Brussels")), + ), + ], +) +def test_replace_time_zone_ambiguous_with_ambiguous( + ambiguous: Ambiguous, expected: datetime +) -> None: + ts = pl.Series(["2018-10-28 02:30:00"]).str.strptime(pl.Datetime) + result = ts.dt.replace_time_zone("Europe/Brussels", ambiguous=ambiguous).item() + assert result == expected + + +def test_replace_time_zone_ambiguous_raises() -> None: + ts = pl.Series(["2018-10-28 02:30:00"]).str.strptime(pl.Datetime) + with pytest.raises( + ComputeError, + match="Please use `ambiguous` to tell how it should be localized", + ): + ts.dt.replace_time_zone("Europe/Brussels") + + +@pytest.mark.parametrize( + ("from_tz", "expected_sortedness", "ambiguous"), + [ + ("Europe/London", False, "earliest"), + ("UTC", True, "earliest"), + ("UTC", True, "raise"), + (None, True, "earliest"), + (None, True, "raise"), + ], +) +def test_replace_time_zone_sortedness_series( + from_tz: str | None, expected_sortedness: bool, ambiguous: Ambiguous +) -> None: + ser = ( + pl.Series("ts", [1603584000000001, 1603587600000000]) + .cast(pl.Datetime("us", from_tz)) + .sort() + ) + assert ser.flags["SORTED_ASC"] + result = ser.dt.replace_time_zone("UTC", ambiguous=ambiguous) + assert result.flags["SORTED_ASC"] == expected_sortedness + assert result.flags["SORTED_ASC"] == result.is_sorted() + + +@pytest.mark.parametrize( + ("from_tz", "expected_sortedness", "ambiguous"), + [ + ("Europe/London", False, "earliest"), + ("UTC", True, "earliest"), + ("UTC", True, "raise"), + (None, True, "earliest"), + (None, True, "raise"), + ], +) +def test_replace_time_zone_sortedness_expressions( + from_tz: str | None, expected_sortedness: bool, ambiguous: str +) -> None: + df = ( + pl.Series("ts", [1603584000000001, 1603584060000000, 1603587600000000]) + .cast(pl.Datetime("us", from_tz)) + .sort() + .to_frame() + ) + df = df.with_columns(ambiguous=pl.Series([ambiguous] * 3)) + assert df["ts"].flags["SORTED_ASC"] + result = df.select( + pl.col("ts").dt.replace_time_zone("UTC", ambiguous=pl.col("ambiguous")) + ) + assert result["ts"].flags["SORTED_ASC"] == expected_sortedness + assert result["ts"].is_sorted() == expected_sortedness + + +def test_invalid_ambiguous_value_in_expression() -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 10, 25, 1)] * 2, "b": ["earliest", "cabbage"]} + ) + with pytest.raises(InvalidOperationError, match="Invalid argument cabbage"): + df.select( + pl.col("a").dt.replace_time_zone("Europe/London", ambiguous=pl.col("b")) + ) + + +def test_replace_time_zone_ambiguous_null() -> None: + df = pl.DataFrame( + { + "a": [datetime(2020, 10, 25, 1)] * 3 + [None], + "b": ["earliest", "latest", "null", "raise"], + } + ) + # expression containing 'null' + result = df.select( + pl.col("a").dt.replace_time_zone("Europe/London", ambiguous=pl.col("b")) + )["a"] + expected = [ + datetime(2020, 10, 25, 1, fold=0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, fold=1, tzinfo=ZoneInfo("Europe/London")), + None, + None, + ] + assert result[0] == expected[0] + assert result[1] == expected[1] + assert result[2] == expected[2] + assert result[3] == expected[3] + + # single 'null' value + result = df.select( + pl.col("a").dt.replace_time_zone("Europe/London", ambiguous="null") + )["a"] + assert result[0] is None + assert result[1] is None + assert result[2] is None + assert result[3] is None + + +def test_ambiguous_expressions() -> None: + # strptime + df = pl.DataFrame( + { + "ts": ["2020-10-25 01:00"] * 2, + "ambiguous": ["earliest", "latest"], + } + ) + result = df.select( + pl.col("ts").str.strptime( + pl.Datetime("us", "Europe/London"), ambiguous=pl.col("ambiguous") + ) + )["ts"] + expected = pl.Series("ts", [1603584000000000, 1603587600000000]).cast( + pl.Datetime("us", "Europe/London") + ) + assert_series_equal(result, expected) + + # truncate + df = pl.DataFrame( + { + "ts": [datetime(2020, 10, 25, 1), datetime(2020, 10, 25, 1)], + "ambiguous": ["earliest", "latest"], + } + ) + df = df.with_columns( + pl.col("ts").dt.replace_time_zone( + "Europe/London", ambiguous=pl.col("ambiguous") + ) + ) + result = df.select(pl.col("ts").dt.truncate("1h"))["ts"] + expected = pl.Series("ts", [1603584000000000, 1603587600000000]).cast( + pl.Datetime("us", "Europe/London") + ) + assert_series_equal(result, expected) + + # replace_time_zone + df = pl.DataFrame( + { + "ts": [datetime(2020, 10, 25, 1), datetime(2020, 10, 25, 1)], + "ambiguous": ["earliest", "latest"], + } + ) + result = df.select( + pl.col("ts").dt.replace_time_zone( + "Europe/London", ambiguous=pl.col("ambiguous") + ) + )["ts"] + expected = pl.Series("ts", [1603584000000000, 1603587600000000]).cast( + pl.Datetime("us", "Europe/London") + ) + assert_series_equal(result, expected) + + # pl.datetime + df = pl.DataFrame( + { + "year": [2020] * 2, + "month": [10] * 2, + "day": [25] * 2, + "hour": [1] * 2, + "minute": [0] * 2, + "ambiguous": ["earliest", "latest"], + } + ) + result = df.select( + pl.datetime( + "year", + "month", + "day", + "hour", + "minute", + time_zone="Europe/London", + ambiguous=pl.col("ambiguous"), + ) + )["datetime"] + expected = pl.DataFrame( + {"datetime": [1603584000000000, 1603587600000000]}, + schema={"datetime": pl.Datetime("us", "Europe/London")}, + )["datetime"] + assert_series_equal(result, expected) + + +def test_single_ambiguous_null() -> None: + df = pl.DataFrame( + {"ts": [datetime(2020, 10, 2, 1, 1)], "ambiguous": [None]}, + schema_overrides={"ambiguous": pl.String}, + ) + result = df.select( + pl.col("ts").dt.replace_time_zone( + "Europe/London", ambiguous=pl.col("ambiguous") + ) + )["ts"].item() + assert result is None + + +def test_unlocalize() -> None: + tz_naive = pl.Series(["2020-01-01 03:00:00"]).str.strptime(pl.Datetime) + tz_aware = tz_naive.dt.replace_time_zone("UTC").dt.convert_time_zone( + "Europe/Brussels" + ) + result = tz_aware.dt.replace_time_zone(None).item() + assert result == datetime(2020, 1, 1, 4) + + +def test_tz_aware_truncate() -> None: + df = pl.DataFrame( + { + "dt": pl.datetime_range( + start=datetime(2022, 11, 1), + end=datetime(2022, 11, 4), + interval="12h", + eager=True, + ).dt.replace_time_zone("America/New_York") + } + ) + result = df.with_columns(pl.col("dt").dt.truncate("1d").alias("trunced")) + expected = { + "dt": [ + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 1, 12, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 12, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 12, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo("America/New_York")), + ], + "trunced": [ + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo("America/New_York")), + ], + } + assert result.to_dict(as_series=False) == expected + + # 5507 + lf = pl.DataFrame( + { + "naive": pl.datetime_range( + start=datetime(2021, 12, 31, 23), + end=datetime(2022, 1, 1, 6), + interval="1h", + eager=True, + ) + } + ).lazy() + lf = lf.with_columns(pl.col("naive").dt.replace_time_zone("UTC").alias("UTC")) + lf = lf.with_columns(pl.col("UTC").dt.convert_time_zone("US/Central").alias("CST")) + lf = lf.with_columns(pl.col("CST").dt.truncate("1d").alias("CST truncated")) + assert lf.collect().to_dict(as_series=False) == { + "naive": [ + datetime(2021, 12, 31, 23, 0), + datetime(2022, 1, 1, 0, 0), + datetime(2022, 1, 1, 1, 0), + datetime(2022, 1, 1, 2, 0), + datetime(2022, 1, 1, 3, 0), + datetime(2022, 1, 1, 4, 0), + datetime(2022, 1, 1, 5, 0), + datetime(2022, 1, 1, 6, 0), + ], + "UTC": [ + datetime(2021, 12, 31, 23, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 1, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 2, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 3, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 4, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 5, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 6, 0, tzinfo=ZoneInfo("UTC")), + ], + "CST": [ + datetime(2021, 12, 31, 17, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 18, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 19, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 20, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 21, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 22, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 23, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("US/Central")), + ], + "CST truncated": [ + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("US/Central")), + ], + } + + +def test_to_string_invalid_format() -> None: + tz_naive = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime) + with pytest.raises( + ComputeError, match="cannot format timezone-naive Datetime with format '%z'" + ): + tz_naive.dt.to_string("%z") + + +def test_tz_aware_to_string() -> None: + df = pl.DataFrame( + { + "dt": pl.datetime_range( + start=datetime(2022, 11, 1), + end=datetime(2022, 11, 4), + interval="24h", + eager=True, + ).dt.replace_time_zone("America/New_York") + } + ) + result = df.with_columns(pl.col("dt").dt.to_string("%c").alias("fmt")) + expected = { + "dt": [ + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo("America/New_York")), + ], + "fmt": [ + "Tue Nov 1 00:00:00 2022", + "Wed Nov 2 00:00:00 2022", + "Thu Nov 3 00:00:00 2022", + "Fri Nov 4 00:00:00 2022", + ], + } + assert result.to_dict(as_series=False) == expected + + +@pytest.mark.parametrize( + ("time_zone", "directive", "expected"), + [ + ("Pacific/Pohnpei", "%z", "+1100"), + ("Pacific/Pohnpei", "%Z", "+11"), + ], +) +def test_tz_aware_with_timezone_directive( + time_zone: str, directive: str, expected: str +) -> None: + tz_naive = pl.Series(["2020-01-01 03:00:00"]).str.strptime(pl.Datetime) + tz_aware = tz_naive.dt.replace_time_zone(time_zone) + result = tz_aware.dt.to_string(directive).item() + assert result == expected + + +def test_local_time_zone_name() -> None: + ser = pl.Series(["2020-01-01 03:00ACST"]).str.strptime( + pl.Datetime, "%Y-%m-%d %H:%M%Z" + ) + result = ser[0] + expected = datetime(2020, 1, 1, 3) + assert result == expected + + +def test_tz_aware_filter_lit() -> None: + start = datetime(1970, 1, 1) + stop = datetime(1970, 1, 1, 7) + dt = datetime(1970, 1, 1, 6, tzinfo=ZoneInfo("America/New_York")) + + assert ( + pl.DataFrame({"date": pl.datetime_range(start, stop, "1h", eager=True)}) + .with_columns( + pl.col("date").dt.replace_time_zone("America/New_York").alias("nyc") + ) + .filter(pl.col("nyc") < dt) + ).to_dict(as_series=False) == { + "date": [ + datetime(1970, 1, 1, 0, 0), + datetime(1970, 1, 1, 1, 0), + datetime(1970, 1, 1, 2, 0), + datetime(1970, 1, 1, 3, 0), + datetime(1970, 1, 1, 4, 0), + datetime(1970, 1, 1, 5, 0), + ], + "nyc": [ + datetime(1970, 1, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 1, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 2, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 3, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 4, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 5, 0, tzinfo=ZoneInfo("America/New_York")), + ], + } + + +def test_asof_join_by_forward() -> None: + dfa = pl.DataFrame( + {"category": ["a", "a", "a", "a", "a"], "value_one": [1, 2, 3, 5, 12]} + ).set_sorted("value_one") + + dfb = pl.DataFrame({"category": ["a"], "value_two": [3]}).set_sorted("value_two") + + assert dfa.join_asof( + dfb, + left_on="value_one", + right_on="value_two", + by="category", + strategy="forward", + ).to_dict(as_series=False) == { + "category": ["a", "a", "a", "a", "a"], + "value_one": [1, 2, 3, 5, 12], + "value_two": [3, 3, 3, None, None], + } + + +def test_truncate_broadcast_left() -> None: + df = pl.DataFrame({"every": [None, "1y", "1mo", "1d", "1h"]}) + out = df.select( + date=pl.lit(date(2024, 4, 19)).dt.truncate(pl.col("every")), + datetime=pl.lit(datetime(2024, 4, 19, 10, 30, 20)).dt.truncate(pl.col("every")), + ) + expected = pl.DataFrame( + { + "date": [ + None, + date(2024, 1, 1), + date(2024, 4, 1), + date(2024, 4, 19), + date(2024, 4, 19), + ], + "datetime": [ + None, + datetime(2024, 1, 1), + datetime(2024, 4, 1), + datetime(2024, 4, 19), + datetime(2024, 4, 19, 10), + ], + } + ) + assert_frame_equal(out, expected) + + +def test_truncate_expr() -> None: + df = pl.DataFrame( + { + "date": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20, 5, 7, 18), + datetime(2022, 4, 3, 13, 30, 32), + ], + "every": ["1y", "1mo", "1m", "1s"], + } + ) + + every_expr = df.select(pl.col("date").dt.truncate(every=pl.col("every"))) + assert every_expr.to_dict(as_series=False) == { + "date": [ + datetime(2022, 1, 1), + datetime(2023, 10, 1), + datetime(2022, 3, 20, 5, 7), + datetime(2022, 4, 3, 13, 30, 32), + ] + } + + all_lit = df.select(pl.col("date").dt.truncate(every=pl.lit("1mo"))) + assert all_lit.to_dict(as_series=False) == { + "date": [ + datetime(2022, 11, 1), + datetime(2023, 10, 1), + datetime(2022, 3, 1), + datetime(2022, 4, 1), + ] + } + + df = pl.DataFrame( + { + "date": pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ).dt.offset_by("15m"), + "every": ["30m", "15m", "30m", "15m", "30m", "15m", "30m"], + } + ) + + # How to deal with ambiguousness is auto-inferred + ambiguous_expr = df.select(pl.col("date").dt.truncate(every=pl.lit("30m"))) + assert ambiguous_expr.to_dict(as_series=False) == { + "date": [ + datetime(2020, 10, 25, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo("Europe/London")), + ] + } + + all_expr = df.select(pl.col("date").dt.truncate(every=pl.col("every"))) + assert all_expr.to_dict(as_series=False) == { + "date": [ + datetime(2020, 10, 25, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 0, 45, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo("Europe/London")), + ] + } + + +def test_truncate_propagate_null() -> None: + df = pl.DataFrame( + { + "date": [ + None, + datetime(2022, 11, 14), + datetime(2022, 3, 20, 5, 7, 18), + ], + "every": ["1y", None, "1m"], + } + ) + assert df.select(pl.col("date").dt.truncate(every=pl.col("every"))).to_dict( + as_series=False + ) == {"date": [None, None, datetime(2022, 3, 20, 5, 7, 0)]} + assert df.select( + pl.col("date").dt.truncate( + every=pl.lit(None, dtype=pl.String), + ) + ).to_dict(as_series=False) == {"date": [None, None, None]} + + +def test_truncate_by_calendar_weeks() -> None: + # 5557 + start = datetime(2022, 11, 14, 0, 0, 0) + end = datetime(2022, 11, 20, 0, 0, 0) + + assert ( + pl.datetime_range(start, end, timedelta(days=1), eager=True) + .alias("date") + .to_frame() + .select([pl.col("date").dt.truncate("1w")]) + ).to_dict(as_series=False) == { + "date": [ + datetime(2022, 11, 14), + datetime(2022, 11, 14), + datetime(2022, 11, 14), + datetime(2022, 11, 14), + datetime(2022, 11, 14), + datetime(2022, 11, 14), + datetime(2022, 11, 14), + ], + } + + df = pl.DataFrame( + { + "date": pl.Series(["1768-03-01", "2023-01-01"]).str.strptime( + pl.Date, "%Y-%m-%d" + ) + } + ) + + assert df.select(pl.col("date").dt.truncate("1w")).to_dict(as_series=False) == { + "date": [ + date(1768, 2, 29), + date(2022, 12, 26), + ], + } + + +def test_truncate_by_multiple_weeks() -> None: + df = pl.DataFrame( + { + "date": pl.Series( + [ + # Wednesday and Monday + "2022-04-20", + "2022-11-28", + ] + ).str.strptime(pl.Date, "%Y-%m-%d") + } + ) + + assert ( + df.select( + pl.col("date").dt.truncate("2w").alias("2w"), + pl.col("date").dt.truncate("3w").alias("3w"), + pl.col("date").dt.truncate("4w").alias("4w"), + pl.col("date").dt.truncate("5w").alias("5w"), + pl.col("date").dt.truncate("17w").alias("17w"), + ) + ).to_dict(as_series=False) == { + "2w": [date(2022, 4, 18), date(2022, 11, 28)], + "3w": [date(2022, 4, 11), date(2022, 11, 28)], + "4w": [date(2022, 4, 18), date(2022, 11, 28)], + "5w": [date(2022, 3, 28), date(2022, 11, 28)], + "17w": [date(2022, 2, 21), date(2022, 10, 17)], + } + + +def test_truncate_by_multiple_weeks_diffs() -> None: + df = pl.DataFrame( + { + "ts": pl.date_range(date(2020, 1, 1), date(2020, 2, 1), eager=True), + } + ) + result = df.select( + pl.col("ts").dt.truncate("1w").alias("1w"), + pl.col("ts").dt.truncate("2w").alias("2w"), + pl.col("ts").dt.truncate("3w").alias("3w"), + ).select(pl.all().diff().drop_nulls().unique(maintain_order=True)) + expected = pl.DataFrame( + { + "1w": [timedelta(0), timedelta(days=7)], + "2w": [timedelta(0), timedelta(days=14)], + "3w": [timedelta(0), timedelta(days=21)], + } + ).select(pl.all().cast(pl.Duration("ms"))) + assert_frame_equal(result, expected) + + +def test_truncate_ambiguous() -> None: + ser = ( + pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ) + .alias("datetime") + .dt.offset_by("15m") + ) + df = ser.to_frame() + result = df.select(pl.col("datetime").dt.truncate("30m")) + expected = ( + pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ) + .alias("datetime") + .to_frame() + ) + assert_frame_equal(result, expected) + + +def test_truncate_ambiguous2() -> None: + ser = ( + pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ) + .alias("datetime") + .dt.offset_by("15m") + ) + result = ser.dt.truncate("30m") + expected = ( + pl.Series( + [ + "2020-10-25T00:00:00+0100", + "2020-10-25T00:30:00+0100", + "2020-10-25T01:00:00+0100", + "2020-10-25T01:30:00+0100", + "2020-10-25T01:00:00+0000", + "2020-10-25T01:30:00+0000", + "2020-10-25T02:00:00+0000", + ] + ) + .str.to_datetime() + .dt.convert_time_zone("Europe/London") + .rename("datetime") + ) + assert_series_equal(result, expected) + + +def test_truncate_non_existent_14957() -> None: + with pytest.raises(ComputeError, match="non-existent"): + pl.Series([datetime(2020, 3, 29, 2, 1)]).dt.replace_time_zone( + "Europe/London" + ).dt.truncate("46m") + + +def test_cast_time_to_duration() -> None: + assert pl.Series([time(hour=0, minute=0, second=2)]).cast( + pl.Duration + ).item() == timedelta(seconds=2) + + +def test_tz_aware_day_weekday() -> None: + start = datetime(2001, 1, 1) + stop = datetime(2001, 1, 9) + df = pl.DataFrame( + { + "date": pl.datetime_range( + start, stop, timedelta(days=3), time_zone="UTC", eager=True + ) + } + ) + + df = df.with_columns( + pl.col("date").dt.convert_time_zone("Asia/Tokyo").alias("tk_date"), + pl.col("date").dt.convert_time_zone("America/New_York").alias("ny_date"), + ) + + assert df.select( + pl.col("date").dt.day().alias("day"), + pl.col("tk_date").dt.day().alias("tk_day"), + pl.col("ny_date").dt.day().alias("ny_day"), + pl.col("date").dt.weekday().alias("weekday"), + pl.col("tk_date").dt.weekday().alias("tk_weekday"), + pl.col("ny_date").dt.weekday().alias("ny_weekday"), + ).to_dict(as_series=False) == { + "day": [1, 4, 7], + "tk_day": [1, 4, 7], + "ny_day": [31, 3, 6], + "weekday": [1, 4, 7], + "tk_weekday": [1, 4, 7], + "ny_weekday": [7, 3, 6], + } + + +def test_datetime_cum_agg_schema() -> None: + df = pl.DataFrame( + { + "timestamp": [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 3), + ] + } + ) + # Exactly the same as above but with lazy() and collect() later + assert ( + df.lazy() + .with_columns( + (pl.col("timestamp").cum_min()).alias("cum_min"), + (pl.col("timestamp").cum_max()).alias("cum_max"), + ) + .with_columns( + (pl.col("cum_min") + pl.duration(hours=24)).alias("cum_min+24"), + (pl.col("cum_max") + pl.duration(hours=24)).alias("cum_max+24"), + ) + .collect() + ).to_dict(as_series=False) == { + "timestamp": [ + datetime(2023, 1, 1, 0, 0), + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 3, 0, 0), + ], + "cum_min": [ + datetime(2023, 1, 1, 0, 0), + datetime(2023, 1, 1, 0, 0), + datetime(2023, 1, 1, 0, 0), + ], + "cum_max": [ + datetime(2023, 1, 1, 0, 0), + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 3, 0, 0), + ], + "cum_min+24": [ + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 2, 0, 0), + ], + "cum_max+24": [ + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 3, 0, 0), + datetime(2023, 1, 4, 0, 0), + ], + } + + +def test_infer_iso8601_datetime(iso8601_format_datetime: str) -> None: + # construct an example time string + time_string = ( + iso8601_format_datetime.replace("%Y", "2134") + .replace("%m", "12") + .replace("%d", "13") + .replace("%H", "01") + .replace("%M", "12") + .replace("%S", "34") + .replace("%3f", "123") + .replace("%6f", "123456") + .replace("%9f", "123456789") + ) + parsed = pl.Series([time_string]).str.strptime(pl.Datetime("ns")) + assert parsed.dt.year().item() == 2134 + assert parsed.dt.month().item() == 12 + assert parsed.dt.day().item() == 13 + if "%H" in iso8601_format_datetime: + assert parsed.dt.hour().item() == 1 + if "%M" in iso8601_format_datetime: + assert parsed.dt.minute().item() == 12 + if "%S" in iso8601_format_datetime: + assert parsed.dt.second().item() == 34 + if "%9f" in iso8601_format_datetime: + assert parsed.dt.nanosecond().item() == 123456789 + if "%6f" in iso8601_format_datetime: + assert parsed.dt.nanosecond().item() == 123456000 + if "%3f" in iso8601_format_datetime: + assert parsed.dt.nanosecond().item() == 123000000 + + +def test_infer_iso8601_tz_aware_datetime(iso8601_tz_aware_format_datetime: str) -> None: + # construct an example time string + time_string = ( + iso8601_tz_aware_format_datetime.replace("%Y", "2134") + .replace("%m", "12") + .replace("%d", "13") + .replace("%H", "02") + .replace("%M", "12") + .replace("%S", "34") + .replace("%3f", "123") + .replace("%6f", "123456") + .replace("%9f", "123456789") + .replace("%#z", "+01:00") + ) + parsed = pl.Series([time_string]).str.strptime(pl.Datetime("ns")) + assert parsed.dt.year().item() == 2134 + assert parsed.dt.month().item() == 12 + assert parsed.dt.day().item() == 13 + if "%H" in iso8601_tz_aware_format_datetime: + assert parsed.dt.hour().item() == 1 + if "%M" in iso8601_tz_aware_format_datetime: + assert parsed.dt.minute().item() == 12 + if "%S" in iso8601_tz_aware_format_datetime: + assert parsed.dt.second().item() == 34 + if "%9f" in iso8601_tz_aware_format_datetime: + assert parsed.dt.nanosecond().item() == 123456789 + if "%6f" in iso8601_tz_aware_format_datetime: + assert parsed.dt.nanosecond().item() == 123456000 + if "%3f" in iso8601_tz_aware_format_datetime: + assert parsed.dt.nanosecond().item() == 123000000 + assert parsed.dtype == pl.Datetime("ns", "UTC") + + +def test_infer_iso8601_date(iso8601_format_date: str) -> None: + # construct an example date string + time_string = ( + iso8601_format_date.replace("%Y", "2134") + .replace("%m", "12") + .replace("%d", "13") + ) + parsed = pl.Series([time_string]).str.strptime(pl.Date) + assert parsed.dt.year().item() == 2134 + assert parsed.dt.month().item() == 12 + assert parsed.dt.day().item() == 13 + + +def test_year_null_backed_by_out_of_range_15313() -> None: + # Create a Series where the null value is backed by a value which would + # be out-of-range for Datetime('us') + s = pl.Series([None, 2**63 - 1]) + s -= 2**63 - 1 + result = s.cast(pl.Datetime).dt.year() + expected = pl.Series([None, 1970], dtype=pl.Int32) + assert_series_equal(result, expected) + result = s.cast(pl.Date).dt.year() + expected = pl.Series([None, 1970], dtype=pl.Int32) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [*TEMPORAL_DTYPES, pl.Datetime("ms", "UTC"), pl.Datetime("ns", "Europe/Amsterdam")], +) +def test_series_is_temporal(dtype: pl.DataType) -> None: + s = pl.Series([None], dtype=dtype) + assert s.dtype.is_temporal() is True + + +@pytest.mark.parametrize( + "time_zone", + [ + None, + timezone.utc, + "America/Caracas", + "Asia/Kathmandu", + "Asia/Taipei", + "Europe/Amsterdam", + "Europe/Lisbon", + "Indian/Maldives", + "Pacific/Norfolk", + "Pacific/Samoa", + "Turkey", + "US/Eastern", + "UTC", + "Zulu", + ], +) +def test_misc_precision_any_value_conversion(time_zone: Any) -> None: + tz = ZoneInfo(time_zone) if isinstance(time_zone, str) else time_zone + # default precision (μs) + dt = datetime(2514, 5, 30, 1, 53, 4, 986754, tzinfo=tz) + assert pl.Series([dt]).to_list() == [dt] + + # ms precision + dt = datetime(2243, 1, 1, 0, 0, 0, 1000, tzinfo=tz) + assert pl.Series([dt]).cast(pl.Datetime("ms", time_zone)).to_list() == [dt] + + # ns precision + dt = datetime(2256, 1, 1, 0, 0, 0, 1, tzinfo=tz) + assert pl.Series([dt]).cast(pl.Datetime("ns", time_zone)).to_list() == [dt] + + +@pytest.mark.parametrize( + "tm", + [ + time(0, 20, 30, 1), + time(8, 40, 15, 8888), + time(15, 10, 20, 123456), + time(23, 59, 59, 999999), + ], +) +def test_pytime_conversion(tm: time) -> None: + s = pl.Series("tm", [tm]) + assert s.to_list() == [tm] + + +@given( + value=st.datetimes(min_value=datetime(1800, 1, 1), max_value=datetime(2100, 1, 1)), + time_zone=st.sampled_from(["UTC", "Asia/Kathmandu", "Europe/Amsterdam", None]), + time_unit=st.sampled_from(["ms", "us", "ns"]), +) +def test_weekday_vs_stdlib_datetime( + value: datetime, time_zone: str, time_unit: TimeUnit +) -> None: + result = ( + pl.Series([value], dtype=pl.Datetime(time_unit)) + .dt.replace_time_zone(time_zone, non_existent="null", ambiguous="null") + .dt.weekday() + .item() + ) + if result is not None: + expected = value.isoweekday() + assert result == expected + + +@given( + value=st.dates(), +) +def test_weekday_vs_stdlib_date(value: date) -> None: + result = pl.Series([value]).dt.weekday().item() + expected = value.isoweekday() + assert result == expected + + +def test_temporal_downcast_construction_19309() -> None: + # implicit cast from us to ms upon construction + s = pl.Series( + [ + datetime(1969, 1, 1, 0, 0, 0, 1), + datetime(1969, 12, 31, 23, 59, 59, 999999), + datetime(1970, 1, 1, 0, 0, 0, 0), + datetime(1970, 1, 1, 0, 0, 0, 1), + ], + dtype=pl.Datetime("ms"), + ) + + assert s.to_list() == [ + datetime(1969, 1, 1), + datetime(1969, 12, 31, 23, 59, 59, 999000), + datetime(1970, 1, 1), + datetime(1970, 1, 1), + ] diff --git a/py-polars/tests/unit/datatypes/test_time.py b/py-polars/tests/unit/datatypes/test_time.py new file mode 100644 index 000000000000..a7eb5c095964 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_time.py @@ -0,0 +1,33 @@ +from datetime import time + +import pytest + +import polars as pl + + +def test_time_to_string_cast() -> None: + assert pl.Series([time(12, 1, 1)]).cast(str).to_list() == ["12:01:01"] + + +def test_time_zero_3828() -> None: + assert pl.Series(values=[time(0)], dtype=pl.Time).to_list() == [time(0)] + + +def test_time_microseconds_3843() -> None: + in_val = [time(0, 9, 11, 558332)] + s = pl.Series(in_val) + assert s.to_list() == in_val + + +def test_invalid_casts() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.DataFrame({"a": []}).with_columns(a=pl.lit(-1).cast(pl.Time)) + + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series([-1]).cast(pl.Time) + + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series([24 * 60 * 60 * 1_000_000_000]).cast(pl.Time) + + largest_value = pl.Series([24 * 60 * 60 * 1_000_000_000 - 1]).cast(pl.Time) + assert "23:59:59.999999999" in str(largest_value) diff --git a/py-polars/tests/unit/datatypes/test_utils.py b/py-polars/tests/unit/datatypes/test_utils.py new file mode 100644 index 000000000000..c8bb113bd5a1 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_utils.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import pytest +from hypothesis import given + +import polars as pl +from polars.datatypes._utils import dtype_to_init_repr +from polars.testing.parametric import dtypes + + +@given(dtype=dtypes()) +def test_dtype_to_init_repr_parametric(dtype: pl.DataType) -> None: + assert repr(dtype) == dtype_to_init_repr(dtype, prefix="") + + +@pytest.mark.parametrize( + ("dtype", "expected"), + [ + (pl.Struct, "pl.Struct"), + (pl.Array(pl.Int8, 2), "pl.Array(pl.Int8, shape=(2,))"), + (pl.List(pl.Int32), "pl.List(pl.Int32)"), + (pl.List(pl.List(pl.Int8)), "pl.List(pl.List(pl.Int8))"), + ( + pl.Struct({"x": pl.String, "y": pl.List(pl.Int8)}), + "pl.Struct({'x': pl.String, 'y': pl.List(pl.Int8)})", + ), + ], +) +def test_dtype_to_init_repr(dtype: pl.DataType, expected: str) -> None: + assert dtype_to_init_repr(dtype) == expected diff --git a/py-polars/tests/unit/expr/__init__.py b/py-polars/tests/unit/expr/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/expr/test_dunders.py b/py-polars/tests/unit/expr/test_dunders.py new file mode 100644 index 000000000000..3ab2810f5e7d --- /dev/null +++ b/py-polars/tests/unit/expr/test_dunders.py @@ -0,0 +1,16 @@ +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_add_parse_str_input_as_literal() -> None: + df = pl.DataFrame({"a": ["x", "y"]}) + result = df.select(pl.col("a") + "b") + expected = pl.DataFrame({"a": ["xb", "yb"]}) + assert_frame_equal(result, expected) + + +def test_truediv_parse_str_input_as_col_name() -> None: + df = pl.DataFrame({"a": [10, 12], "b": [5, 4]}) + result = df.select(pl.col("a") / "b") + expected = pl.DataFrame({"a": [2, 3]}, schema={"a": pl.Float64}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/expr/test_expr_apply_eval.py b/py-polars/tests/unit/expr/test_expr_apply_eval.py new file mode 100644 index 000000000000..73e501f705bd --- /dev/null +++ b/py-polars/tests/unit/expr/test_expr_apply_eval.py @@ -0,0 +1,15 @@ +import polars as pl + + +def test_expression_15183() -> None: + assert ( + pl.DataFrame( + {"a": [1, 2, 3, 4, 5, 2, 3, 5, 1], "b": [1, 2, 3, 1, 2, 3, 1, 2, 3]} + ) + .group_by("a") + .agg(pl.col.b.unique().sort().str.join("-").str.split("-")) + .sort("a") + ).to_dict(as_series=False) == { + "a": [1, 2, 3, 4, 5], + "b": [["1", "3"], ["2", "3"], ["1", "3"], ["1"], ["2"]], + } diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py new file mode 100644 index 000000000000..ff515688c3ef --- /dev/null +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -0,0 +1,659 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta, timezone +from itertools import permutations +from typing import TYPE_CHECKING, Any, cast +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import ( + DATETIME_DTYPES, + DURATION_DTYPES, + FLOAT_DTYPES, + INTEGER_DTYPES, + NUMERIC_DTYPES, + TEMPORAL_DTYPES, +) + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_arg_true() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 1]}) + res = df.select((pl.col("a") == 1).arg_true()) + expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=pl.UInt32)]) + assert_frame_equal(res, expected) + + +def test_suffix(fruits_cars: pl.DataFrame) -> None: + df = fruits_cars + out = df.select([pl.all().name.suffix("_reverse")]) + assert out.columns == ["A_reverse", "fruits_reverse", "B_reverse", "cars_reverse"] + + +def test_pipe() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, None, 8]}) + + def _multiply(expr: pl.Expr, mul: int) -> pl.Expr: + return expr * mul + + result = df.select( + pl.col("foo").pipe(_multiply, mul=2), + pl.col("bar").pipe(_multiply, mul=3), + ) + + expected = pl.DataFrame({"foo": [2, 4, 6], "bar": [18, None, 24]}) + assert_frame_equal(result, expected) + + +def test_prefix(fruits_cars: pl.DataFrame) -> None: + df = fruits_cars + out = df.select([pl.all().name.prefix("reverse_")]) + assert out.columns == ["reverse_A", "reverse_fruits", "reverse_B", "reverse_cars"] + + +def test_filter_where() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 1, 2, 3], "b": [4, 5, 6, 7, 8, 9]}) + result_filter = df.group_by("a", maintain_order=True).agg( + pl.col("b").filter(pl.col("b") > 4).alias("c") + ) + expected = pl.DataFrame({"a": [1, 2, 3], "c": [[7], [5, 8], [6, 9]]}) + assert_frame_equal(result_filter, expected) + + with pytest.deprecated_call(): + result_where = df.group_by("a", maintain_order=True).agg( + pl.col("b").where(pl.col("b") > 4).alias("c") + ) + assert_frame_equal(result_where, expected) + + # apply filter constraints using kwargs + df = pl.DataFrame( + { + "key": ["a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b"], + "n": [1, 4, 4, 2, 2, 3, 1, 3, 0, 2, 3, 4], + }, + schema_overrides={"n": pl.UInt8}, + ) + res = ( + df.group_by("key") + .agg( + n_0=pl.col("n").filter(n=0), + n_1=pl.col("n").filter(n=1), + n_2=pl.col("n").filter(n=2), + n_3=pl.col("n").filter(n=3), + n_4=pl.col("n").filter(n=4), + ) + .sort(by="key") + ) + assert res.rows() == [ + ("a", [], [1], [2, 2], [3], [4, 4]), + ("b", [0], [1], [2], [3, 3], [4]), + ] + + +def test_len_expr() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]}) + + out = df.select(pl.len()) + assert out.shape == (1, 1) + assert cast(int, out.item()) == 5 + + out = df.group_by("b", maintain_order=True).agg(pl.len()) + assert out["b"].to_list() == ["a", "b"] + assert out["len"].to_list() == [4, 1] + + +def test_map_alias() -> None: + out = pl.DataFrame({"foo": [1, 2, 3]}).select( + (pl.col("foo") * 2).name.map(lambda name: f"{name}{name}") + ) + expected = pl.DataFrame({"foofoo": [2, 4, 6]}) + assert_frame_equal(out, expected) + + +def test_unique_stable() -> None: + s = pl.Series("a", [1, 1, 1, 1, 2, 2, 2, 3, 3]) + expected = pl.Series("a", [1, 2, 3]) + assert_series_equal(s.unique(maintain_order=True), expected) + + +def test_entropy() -> None: + df = pl.DataFrame( + { + "group": ["A", "A", "A", "B", "B", "B", "B", "C"], + "id": [1, 2, 1, 4, 5, 4, 6, 7], + } + ) + result = df.group_by("group", maintain_order=True).agg( + pl.col("id").entropy(normalize=True) + ) + expected = pl.DataFrame( + {"group": ["A", "B", "C"], "id": [1.0397207708399179, 1.371381017771811, 0.0]} + ) + assert_frame_equal(result, expected) + + +def test_dot_in_group_by() -> None: + df = pl.DataFrame( + { + "group": ["a", "a", "a", "b", "b", "b"], + "x": [1, 1, 1, 1, 1, 1], + "y": [1, 2, 3, 4, 5, 6], + } + ) + + result = df.group_by("group", maintain_order=True).agg( + pl.col("x").dot("y").alias("dot") + ) + expected = pl.DataFrame({"group": ["a", "b"], "dot": [6, 15]}) + assert_frame_equal(result, expected) + + +def test_dtype_col_selection() -> None: + df = pl.DataFrame( + data=[], + schema={ + "a1": pl.Datetime, + "a2": pl.Datetime("ms"), + "a3": pl.Datetime("ms"), + "a4": pl.Datetime("ns"), + "b": pl.Date, + "c": pl.Time, + "d1": pl.Duration, + "d2": pl.Duration("ms"), + "d3": pl.Duration("us"), + "d4": pl.Duration("ns"), + "e": pl.Int8, + "f": pl.Int16, + "g": pl.Int32, + "h": pl.Int64, + "i": pl.Float32, + "j": pl.Float64, + "k": pl.UInt8, + "l": pl.UInt16, + "m": pl.UInt32, + "n": pl.UInt64, + }, + ) + assert df.select(pl.col(INTEGER_DTYPES)).columns == [ + "e", + "f", + "g", + "h", + "k", + "l", + "m", + "n", + ] + assert df.select(pl.col(FLOAT_DTYPES)).columns == ["i", "j"] + assert df.select(pl.col(NUMERIC_DTYPES)).columns == [ + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + ] + assert df.select(pl.col(TEMPORAL_DTYPES)).columns == [ + "a1", + "a2", + "a3", + "a4", + "b", + "c", + "d1", + "d2", + "d3", + "d4", + ] + assert df.select(pl.col(DATETIME_DTYPES)).columns == [ + "a1", + "a2", + "a3", + "a4", + ] + assert df.select(pl.col(DURATION_DTYPES)).columns == [ + "d1", + "d2", + "d3", + "d4", + ] + + +def test_list_eval_expression() -> None: + df = pl.DataFrame({"a": [1, 8, 3], "b": [4, 5, 2]}) + + for parallel in [True, False]: + assert df.with_columns( + pl.concat_list(["a", "b"]) + .list.eval(pl.first().rank(), parallel=parallel) + .alias("rank") + ).to_dict(as_series=False) == { + "a": [1, 8, 3], + "b": [4, 5, 2], + "rank": [[1.0, 2.0], [2.0, 1.0], [2.0, 1.0]], + } + + assert df["a"].reshape((1, -1)).arr.to_list().list.eval( + pl.first(), parallel=parallel + ).to_list() == [[1, 8, 3]] + + +def test_null_count_expr() -> None: + df = pl.DataFrame({"key": ["a", "b", "b", "a"], "val": [1, 2, None, 1]}) + + assert df.select([pl.all().null_count()]).to_dict(as_series=False) == { + "key": [0], + "val": [1], + } + + +def test_pos_neg() -> None: + df = pl.DataFrame( + { + "x": [3, 2, 1], + "y": [6, 7, 8], + } + ).with_columns(-pl.col("x"), +pl.col("y"), -pl.lit(1)) + + # #11149: ensure that we preserve the output name (where available) + assert df.to_dict(as_series=False) == { + "x": [-3, -2, -1], + "y": [6, 7, 8], + "literal": [-1, -1, -1], + } + + +def test_power_by_expression() -> None: + out = pl.DataFrame( + {"a": [1, None, None, 4, 5, 6], "b": [1, 2, None, 4, None, 6]} + ).select( + pl.col("a").pow(pl.col("b")).alias("pow_expr"), + (pl.col("a") ** pl.col("b")).alias("pow_op"), + (2 ** pl.col("b")).alias("pow_op_left"), + ) + + for pow_col in ("pow_expr", "pow_op"): + assert out[pow_col].to_list() == [1.0, None, None, 256.0, None, 46656.0] + assert out["pow_op_left"].to_list() == [2.0, 4.0, None, 16.0, None, 64.0] + + +@pytest.mark.may_fail_auto_streaming +def test_expression_appends() -> None: + df = pl.DataFrame({"a": [1, 1, 2]}) + + assert df.select(pl.repeat(None, 3).append(pl.col("a"))).n_chunks() == 2 + assert df.select(pl.repeat(None, 3).append(pl.col("a")).rechunk()).n_chunks() == 1 + + out = df.select(pl.concat([pl.repeat(None, 3), pl.col("a")], rechunk=True)) + + assert out.n_chunks() == 1 + assert out.to_series().to_list() == [None, None, None, 1, 1, 2] + + +def test_arr_contains() -> None: + df_groups = pl.DataFrame( + { + "animals": [ + ["cat", "mouse", "dog"], + ["dog", "hedgehog", "mouse", "cat"], + ["peacock", "mouse", "aardvark"], + ], + } + ) + # string array contains + assert df_groups.lazy().filter( + pl.col("animals").list.contains("mouse"), + ).collect().to_dict(as_series=False) == { + "animals": [ + ["cat", "mouse", "dog"], + ["dog", "hedgehog", "mouse", "cat"], + ["peacock", "mouse", "aardvark"], + ] + } + # string array contains and *not* contains + assert df_groups.filter( + pl.col("animals").list.contains("mouse"), + ~pl.col("animals").list.contains("hedgehog"), + ).to_dict(as_series=False) == { + "animals": [ + ["cat", "mouse", "dog"], + ["peacock", "mouse", "aardvark"], + ], + } + + +def test_logical_boolean() -> None: + # note, cannot use expressions in logical + # boolean context (eg: and/or/not operators) + with pytest.raises(TypeError, match="ambiguous"): + pl.col("colx") and pl.col("coly") # type: ignore[redundant-expr] + + with pytest.raises(TypeError, match="ambiguous"): + pl.col("colx") or pl.col("coly") # type: ignore[redundant-expr] + + df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]}) + + with pytest.raises(TypeError, match="ambiguous"): + df.select([(pl.col("a") > pl.col("b")) and (pl.col("b") > pl.col("b"))]) + + with pytest.raises(TypeError, match="ambiguous"): + df.select([(pl.col("a") > pl.col("b")) or (pl.col("b") > pl.col("b"))]) + + +def test_lit_dtypes() -> None: + def lit_series(value: Any, dtype: PolarsDataType | None) -> pl.Series: + return pl.select(pl.lit(value, dtype=dtype)).to_series() + + d = datetime(2049, 10, 5, 1, 2, 3, 987654) + d_ms = datetime(2049, 10, 5, 1, 2, 3, 987000) + d_tz = datetime(2049, 10, 5, 1, 2, 3, 987654, tzinfo=ZoneInfo("Asia/Kathmandu")) + + td = timedelta(days=942, hours=6, microseconds=123456) + td_ms = timedelta(days=942, seconds=21600, microseconds=123000) + + df = pl.DataFrame( + { + "dtm_ms": lit_series(d, pl.Datetime("ms")), + "dtm_us": lit_series(d, pl.Datetime("us")), + "dtm_ns": lit_series(d, pl.Datetime("ns")), + "dtm_aware_0": lit_series(d, pl.Datetime("us", "Asia/Kathmandu")), + "dtm_aware_1": lit_series(d_tz, pl.Datetime("us")), + "dtm_aware_2": lit_series(d_tz, None), + "dtm_aware_3": lit_series(d, pl.Datetime(time_zone="Asia/Kathmandu")), + "dur_ms": lit_series(td, pl.Duration("ms")), + "dur_us": lit_series(td, pl.Duration("us")), + "dur_ns": lit_series(td, pl.Duration("ns")), + "f32": lit_series(0, pl.Float32), + "u16": lit_series(0, pl.UInt16), + "i16": lit_series(0, pl.Int16), + "i64": lit_series(pl.Series([8]), None), + "list_i64": lit_series(pl.Series([[1, 2, 3]]), None), + } + ) + assert df.dtypes == [ + pl.Datetime("ms"), + pl.Datetime("us"), + pl.Datetime("ns"), + pl.Datetime("us", "Asia/Kathmandu"), + pl.Datetime("us", "Asia/Kathmandu"), + pl.Datetime("us", "Asia/Kathmandu"), + pl.Datetime("us", "Asia/Kathmandu"), + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + pl.Float32, + pl.UInt16, + pl.Int16, + pl.Int64, + pl.List(pl.Int64), + ] + assert df.row(0) == ( + d_ms, + d, + d, + d_tz, + d_tz, + d_tz, + d_tz, + td_ms, + td, + td, + 0, + 0, + 0, + 8, + [1, 2, 3], + ) + + +def test_lit_empty_tu() -> None: + td = timedelta(1) + assert pl.select(pl.lit(td, dtype=pl.Duration)).item() == td + assert pl.select(pl.lit(td, dtype=pl.Duration)).dtypes[0].time_unit == "us" # type: ignore[attr-defined] + + t = datetime(2023, 1, 1) + assert pl.select(pl.lit(t, dtype=pl.Datetime)).item() == t + assert pl.select(pl.lit(t, dtype=pl.Datetime)).dtypes[0].time_unit == "us" # type: ignore[attr-defined] + + +def test_incompatible_lit_dtype() -> None: + with pytest.raises( + TypeError, + match=r"time zone of dtype \('Asia/Kathmandu'\) differs from time zone of value \(datetime.timezone.utc\)", + ): + pl.lit( + datetime(2020, 1, 1, tzinfo=timezone.utc), + dtype=pl.Datetime("us", "Asia/Kathmandu"), + ) + + +def test_lit_dtype_utc() -> None: + result = pl.select( + pl.lit( + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")), + dtype=pl.Datetime("us", "Asia/Kathmandu"), + ) + ) + expected = pl.DataFrame( + {"literal": [datetime(2019, 12, 31, 18, 15, tzinfo=timezone.utc)]} + ).select(pl.col("literal").dt.convert_time_zone("Asia/Kathmandu")) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (("a",), ["b", "c"]), + (("a", "b"), ["c"]), + ((["a", "b"],), ["c"]), + ((pl.Int64,), ["c"]), + ((pl.String, pl.Float32), ["a", "b"]), + (([pl.String, pl.Float32],), ["a", "b"]), + ], +) +def test_exclude(input: tuple[Any, ...], expected: list[str]) -> None: + df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String}) + assert df.select(pl.all().exclude(*input)).columns == expected + + +@pytest.mark.parametrize("input", [(5,), (["a"], date.today()), (pl.Int64, "a")]) +def test_exclude_invalid_input(input: tuple[Any, ...]) -> None: + df = pl.DataFrame(schema=["a", "b", "c"]) + with pytest.raises(TypeError): + df.select(pl.all().exclude(*input)) + + +def test_operators_vs_expressions() -> None: + df = pl.DataFrame( + data={ + "x": [5, 6, 7, 4, 8], + "y": [1.5, 2.5, 1.0, 4.0, -5.75], + "z": [-9, 2, -1, 4, 8], + } + ) + for c1, c2 in permutations("xyz", r=2): + df_op = df.select( + a=pl.col(c1) == pl.col(c2), + b=pl.col(c1) // pl.col(c2), + c=pl.col(c1) > pl.col(c2), + d=pl.col(c1) >= pl.col(c2), + e=pl.col(c1) < pl.col(c2), + f=pl.col(c1) <= pl.col(c2), + g=pl.col(c1) % pl.col(c2), + h=pl.col(c1) != pl.col(c2), + i=pl.col(c1) - pl.col(c2), + j=pl.col(c1) / pl.col(c2), + k=pl.col(c1) * pl.col(c2), + l=pl.col(c1) + pl.col(c2), + ) + df_expr = df.select( + a=pl.col(c1).eq(pl.col(c2)), + b=pl.col(c1).floordiv(pl.col(c2)), + c=pl.col(c1).gt(pl.col(c2)), + d=pl.col(c1).ge(pl.col(c2)), + e=pl.col(c1).lt(pl.col(c2)), + f=pl.col(c1).le(pl.col(c2)), + g=pl.col(c1).mod(pl.col(c2)), + h=pl.col(c1).ne(pl.col(c2)), + i=pl.col(c1).sub(pl.col(c2)), + j=pl.col(c1).truediv(pl.col(c2)), + k=pl.col(c1).mul(pl.col(c2)), + l=pl.col(c1).add(pl.col(c2)), + ) + assert_frame_equal(df_op, df_expr) + + # xor - only int cols + assert_frame_equal( + df.select(pl.col("x") ^ pl.col("z")), + df.select(pl.col("x").xor(pl.col("z"))), + ) + + # and (&) or (|) chains + assert_frame_equal( + df.select( + all=(pl.col("x") >= pl.col("z")).and_( + pl.col("y") >= pl.col("z"), + pl.col("y") == pl.col("y"), + pl.col("z") <= pl.col("x"), + pl.col("y") != pl.col("x"), + ) + ), + df.select( + all=( + (pl.col("x") >= pl.col("z")) + & (pl.col("y") >= pl.col("z")) + & (pl.col("y") == pl.col("y")) + & (pl.col("z") <= pl.col("x")) + & (pl.col("y") != pl.col("x")) + ) + ), + ) + + assert_frame_equal( + df.select( + any=(pl.col("x") == pl.col("y")).or_( + pl.col("x") == pl.col("y"), + pl.col("y") == pl.col("z"), + pl.col("y").cast(int) == pl.col("z"), + ) + ), + df.select( + any=(pl.col("x") == pl.col("y")) + | (pl.col("x") == pl.col("y")) + | (pl.col("y") == pl.col("z")) + | (pl.col("y").cast(int) == pl.col("z")) + ), + ) + + +def test_head() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5]}) + assert df.select(pl.col("a").head(0)).to_dict(as_series=False) == {"a": []} + assert df.select(pl.col("a").head(3)).to_dict(as_series=False) == {"a": [1, 2, 3]} + assert df.select(pl.col("a").head(10)).to_dict(as_series=False) == { + "a": [1, 2, 3, 4, 5] + } + assert df.select(pl.col("a").head(pl.len() // 2)).to_dict(as_series=False) == { + "a": [1, 2] + } + + +def test_tail() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5]}) + assert df.select(pl.col("a").tail(0)).to_dict(as_series=False) == {"a": []} + assert df.select(pl.col("a").tail(3)).to_dict(as_series=False) == {"a": [3, 4, 5]} + assert df.select(pl.col("a").tail(10)).to_dict(as_series=False) == { + "a": [1, 2, 3, 4, 5] + } + assert df.select(pl.col("a").tail(pl.len() // 2)).to_dict(as_series=False) == { + "a": [4, 5] + } + + +def test_repr_short_expression() -> None: + expr = pl.functions.all().len().name.prefix("length:") + # we cut off the last ten characters because that includes the + # memory location which will vary between runs + result = repr(expr).split("0x")[0] + + expected = " None: + expr = pl.functions.col(pl.String).str.count_matches("") + + # we cut off the last ten characters because that includes the + # memory location which will vary between runs + result = repr(expr).split("0x")[0] + + # note the … denoting that there was truncated text + expected = "") + + +def test_repr_gather() -> None: + result = repr(pl.col("a").gather(0)) + assert 'col("a").gather(dyn int: 0)' in result + result = repr(pl.col("a").get(0)) + assert 'col("a").get(dyn int: 0)' in result + + +def test_replace_no_cse() -> None: + plan = ( + pl.LazyFrame({"a": [1], "b": [2]}) + .select([(pl.col("a") * pl.col("a")).sum().replace(1, None)]) + .explain() + ) + assert "POLARS_CSER" not in plan + + +def test_slice_rejects_non_integral() -> None: + df = pl.LazyFrame({"a": [0, 1, 2, 3], "b": [1.5, 2, 3, 4]}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.select(pl.col("a").slice(pl.col("b").slice(0, 1), None)).collect() + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.select(pl.col("a").slice(0, pl.col("b").slice(1, 2))).collect() + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.select(pl.col("a").slice(pl.lit("1"), None)).collect() + + +def test_slice() -> None: + data = {"a": [0, 1, 2, 3], "b": [1, 2, 3, 4]} + df = pl.DataFrame(data) + + result = df.select(pl.col("a").slice(1)) + expected = pl.DataFrame({"a": data["a"][1:]}) + assert_frame_equal(result, expected) + + result = df.select(pl.all().slice(1, 1)) + expected = pl.DataFrame({"a": data["a"][1:2], "b": data["b"][1:2]}) + assert_frame_equal(result, expected) + + +def test_function_expr_scalar_identification_18755() -> None: + # The function uses `ApplyOptions::GroupWise`, however the input is scalar. + assert_frame_equal( + pl.DataFrame({"a": [1, 2]}).with_columns(pl.lit(5).shrink_dtype().alias("b")), + pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int8)}), + ) + + +def test_concat_deprecation() -> None: + with pytest.deprecated_call(match="`ExprStringNameSpace.concat` is deprecated."): + pl.Series(["foo"]).str.concat() + with pytest.deprecated_call(match="`ExprStringNameSpace.concat` is deprecated."): + pl.DataFrame({"foo": ["bar"]}).select(pl.all().str.concat()) diff --git a/py-polars/tests/unit/expr/test_literal.py b/py-polars/tests/unit/expr/test_literal.py new file mode 100644 index 000000000000..3e6f5e59b4d8 --- /dev/null +++ b/py-polars/tests/unit/expr/test_literal.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any +from zoneinfo import ZoneInfo + +import pytest +from dateutil.tz import tzoffset + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_literal_scalar_list_18686() -> None: + df = pl.DataFrame({"column1": [1, 2], "column2": ["A", "B"]}) + out = df.with_columns(lit1=pl.lit([]).cast(pl.List(pl.String)), lit2=pl.lit([])) + + assert out.to_dict(as_series=False) == { + "column1": [1, 2], + "column2": ["A", "B"], + "lit1": [[], []], + "lit2": [[], []], + } + assert out.schema == pl.Schema( + [ + ("column1", pl.Int64), + ("column2", pl.String), + ("lit1", pl.List(pl.String)), + ("lit2", pl.List(pl.Null)), + ] + ) + + +def test_literal_integer_20807() -> None: + for i in range(100): + value = 2**i + assert pl.select(pl.lit(value)).item() == value + assert pl.select(pl.lit(-value)).item() == -value + assert pl.select(pl.lit(value, dtype=pl.Int128)).item() == value + assert pl.select(pl.lit(-value, dtype=pl.Int128)).item() == -value + + +@pytest.mark.parametrize( + ("tz", "lit_dtype"), + [ + (ZoneInfo("Asia/Kabul"), None), + (ZoneInfo("Asia/Kabul"), pl.Datetime("us", "Asia/Kabul")), + (ZoneInfo("Europe/Paris"), pl.Datetime("us", "Europe/Paris")), + (timezone.utc, pl.Datetime("us", "UTC")), + ], +) +def test_literal_datetime_timezone(tz: Any, lit_dtype: pl.DataType | None) -> None: + expected_dtype = pl.Datetime("us", time_zone=str(tz)) + value = datetime(2020, 1, 1, tzinfo=tz) + + df1 = pl.DataFrame({"dt": [value]}) + df2 = pl.select(dt=pl.lit(value, dtype=lit_dtype)) + + assert_frame_equal(df1, df2) + assert df1.schema["dt"] == expected_dtype + assert df1.item() == value + + +@pytest.mark.parametrize( + ("tz", "lit_dtype", "expected_item"), + [ + ( + # fixed offset from UTC + tzoffset(None, 16200), + None, + datetime(2019, 12, 31, 19, 30, tzinfo=timezone.utc), + ), + ( + # fixed offset from UTC + tzoffset("Kabul", 16200), + None, + datetime(2019, 12, 31, 19, 30, tzinfo=ZoneInfo("UTC")), + ), + ( + # fixed offset from UTC with matching timezone + tzoffset(None, 16200), + pl.Datetime("us", "Asia/Kabul"), + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kabul")), + ), + ( + # fixed offset from UTC with matching timezone + tzoffset("Kabul", 16200), + pl.Datetime("us", "Asia/Kabul"), + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kabul")), + ), + ], +) +def test_literal_datetime_timezone_utc_offset( + tz: Any, lit_dtype: pl.DataType | None, expected_item: datetime +) -> None: + overrides = {"schema_overrides": {"dt": lit_dtype}} if lit_dtype else {} + value = datetime(2020, 1, 1, tzinfo=tz) + + # validate both frame and lit constructors + df1 = pl.DataFrame({"dt": [value]}, **overrides) # type: ignore[arg-type] + df2 = pl.select(dt=pl.lit(value, dtype=lit_dtype)) + + assert_frame_equal(df1, df2) + + expected_tz = "UTC" if lit_dtype is None else getattr(lit_dtype, "time_zone", None) + expected_dtype = pl.Datetime("us", time_zone=expected_tz) + + for df in (df1, df2): + assert df.schema["dt"] == expected_dtype + assert df.item() == expected_item + + +def test_literal_datetime_timezone_utc_error() -> None: + value = datetime(2020, 1, 1, tzinfo=tzoffset("Somewhere", offset=3600)) + + with pytest.raises( + TypeError, + match=( + r"time zone of dtype \('Pacific/Galapagos'\) differs from" + r" time zone of value \(tzoffset\('Somewhere', 3600\)\)" + ), + ): + # the offset does not correspond to the offset of the declared timezone + pl.select(dt=pl.lit(value, dtype=pl.Datetime(time_zone="Pacific/Galapagos"))) diff --git a/py-polars/tests/unit/expr/test_serde.py b/py-polars/tests/unit/expr/test_serde.py new file mode 100644 index 000000000000..415abb5db415 --- /dev/null +++ b/py-polars/tests/unit/expr/test_serde.py @@ -0,0 +1,71 @@ +import io + +import pytest + +import polars as pl +from polars.exceptions import ComputeError + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("foo").sum().over("bar"), + pl.col("foo").rolling_quantile(0.25, window_size=5), + pl.col("foo").rolling_var(window_size=4, ddof=2), + pl.col("foo").rolling_min(window_size=2), + pl.col("foo").rolling_quantile_by("bar", window_size="1mo", quantile=0.75), + ], +) +def test_expr_serde_roundtrip_binary(expr: pl.Expr) -> None: + json = expr.meta.serialize(format="binary") + round_tripped = pl.Expr.deserialize(io.BytesIO(json), format="binary") + assert round_tripped.meta == expr + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("foo").sum().over("bar"), + pl.col("foo").rolling_quantile(0.25, window_size=5), + pl.col("foo").rolling_var(window_size=4, ddof=2), + pl.col("foo").rolling_min(window_size=2), + pl.col("foo").rolling_quantile_by("bar", window_size="1mo", quantile=0.75), + ], +) +def test_expr_serde_roundtrip_json(expr: pl.Expr) -> None: + expr = pl.col("foo").sum().over("bar") + json = expr.meta.serialize(format="json") + round_tripped = pl.Expr.deserialize(io.StringIO(json), format="json") + assert round_tripped.meta == expr + + +def test_expr_deserialize_file_not_found() -> None: + with pytest.raises(FileNotFoundError): + pl.Expr.deserialize("abcdef") + + +def test_expr_deserialize_invalid_json() -> None: + with pytest.raises( + ComputeError, match="could not deserialize input into an expression" + ): + pl.Expr.deserialize(io.StringIO("abcdef"), format="json") + + +def test_expression_json_13991() -> None: + expr = pl.col("foo").cast(pl.Decimal) + json = expr.meta.serialize(format="json") + + round_tripped = pl.Expr.deserialize(io.StringIO(json), format="json") + assert round_tripped.meta == expr + + +def test_expr_write_json_from_json_deprecated() -> None: + expr = pl.col("foo").sum().over("bar") + + with pytest.deprecated_call(): + json = expr.meta.write_json() + + with pytest.deprecated_call(): + round_tripped = pl.Expr.from_json(json) + + assert round_tripped.meta == expr diff --git a/py-polars/tests/unit/expr/test_udfs.py b/py-polars/tests/unit/expr/test_udfs.py new file mode 100644 index 000000000000..9b34a6d3dd47 --- /dev/null +++ b/py-polars/tests/unit/expr/test_udfs.py @@ -0,0 +1,45 @@ +from typing import Any + +import pytest + +import polars as pl + + +def test_pass_name_alias_18914() -> None: + df = pl.DataFrame({"id": [1], "value": [2]}) + + assert df.select( + pl.all() + .map_elements( + lambda x: x, + skip_nulls=False, + pass_name=True, + return_dtype=pl.List(pl.Int64), + ) + .over("id") + ).to_dict(as_series=False) == {"id": [1], "value": [2]} + + +@pytest.mark.parametrize( + "dtype", + [ + pl.String, + pl.Int64, + pl.Boolean, + pl.List(pl.Int32), + pl.Array(pl.Boolean, 2), + pl.Struct({"a": pl.Int8}), + pl.Enum(["a"]), + ], +) +def test_raises_udf(dtype: pl.DataType) -> None: + def raise_f(item: Any) -> None: + raise ValueError + + with pytest.raises(pl.exceptions.ComputeError): + pl.select( + pl.lit(1).map_elements( + raise_f, + return_dtype=dtype, + ) + ) diff --git a/py-polars/tests/unit/functions/__init__.py b/py-polars/tests/unit/functions/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/functions/as_datatype/__init__.py b/py-polars/tests/unit/functions/as_datatype/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py b/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py new file mode 100644 index 000000000000..d555bff7e034 --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py @@ -0,0 +1,187 @@ +import pytest + +import polars as pl +from polars.exceptions import ShapeError +from polars.testing import assert_series_equal + + +def test_concat_arr() -> None: + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([1, 3, 5]), + pl.Series([2, 4, 6]), + ) + ).to_series(), + pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(pl.Int64, 2)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([1, 3, 5]), + pl.Series([2, 4, None]), + ) + ).to_series(), + pl.Series([[1, 2], [3, 4], [5, None]], dtype=pl.Array(pl.Int64, 2)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([1, 3, 5]), + pl.Series([[2], [None], None], dtype=pl.Array(pl.Int64, 1)), + ) + ).to_series(), + pl.Series([[1, 2], [3, None], None], dtype=pl.Array(pl.Int64, 2)), + ) + + +def test_concat_arr_broadcast() -> None: + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([1, 3, 5]), + pl.lit(None, dtype=pl.Int64), + ) + ).to_series(), + pl.Series([[1, None], [3, None], [5, None]], dtype=pl.Array(pl.Int64, 2)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([1, 3, 5]), + pl.lit(None, dtype=pl.Array(pl.Int64, 2)), + ) + ).to_series(), + pl.Series([None, None, None], dtype=pl.Array(pl.Int64, 3)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([1, 3, 5]), + pl.lit([0, None], dtype=pl.Array(pl.Int64, 2)), + ) + ).to_series(), + pl.Series( + [[1, 0, None], [3, 0, None], [5, 0, None]], dtype=pl.Array(pl.Int64, 3) + ), + ) + + assert_series_equal( + pl.select( + pl.concat_arr(pl.lit(1, dtype=pl.Int64).alias(""), pl.Series([1, 2, 3])) + ).to_series(), + pl.Series([[1, 1], [1, 2], [1, 3]], dtype=pl.Array(pl.Int64, 2)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr(pl.Series([1, 2, 3]), pl.lit(1, dtype=pl.Int64)) + ).to_series(), + pl.Series([[1, 1], [2, 1], [3, 1]], dtype=pl.Array(pl.Int64, 2)), + ) + + with pytest.raises(ShapeError, match="length of column.*did not match"): + assert_series_equal( + pl.select( + pl.concat_arr(pl.Series([1, 3, 5]), pl.Series([1, 1])) + ).to_series(), + pl.Series([None, None, None], dtype=pl.Array(pl.Int64, 3)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series( + [{"x": [1], "y": [2]}, {"x": [3], "y": None}], + dtype=pl.Struct({"x": pl.Array(pl.Int64, 1)}), + ), + pl.lit( + {"x": [9], "y": [11]}, dtype=pl.Struct({"x": pl.Array(pl.Int64, 1)}) + ), + ) + ).to_series(), + pl.Series( + [ + [{"x": [1], "y": [2]}, {"x": [9], "y": [11]}], + [{"x": [3], "y": [4]}, {"x": [9], "y": [11]}], + ], + dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 1)}), 2), + ), + ) + + +@pytest.mark.parametrize("inner_dtype", [pl.Int64(), pl.Null()]) +def test_concat_arr_validity_combination(inner_dtype: pl.DataType) -> None: + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([[], [], None, None], dtype=pl.Array(inner_dtype, 0)), + pl.Series([[], [], None, None], dtype=pl.Array(inner_dtype, 0)), + pl.Series([[None], None, [None], None], dtype=pl.Array(inner_dtype, 1)), + ), + ).to_series(), + pl.Series([[None], None, None, None], dtype=pl.Array(inner_dtype, 1)), + ) + + assert_series_equal( + pl.select( + pl.concat_arr( + pl.Series([None, None], dtype=inner_dtype), + pl.Series([[None], None], dtype=pl.Array(inner_dtype, 1)), + ), + ).to_series(), + pl.Series([[None, None], None], dtype=pl.Array(inner_dtype, 2)), + ) + + +def test_concat_arr_zero_fields() -> None: + assert_series_equal( + ( + pl.Series([[[]], [None]], dtype=pl.Array(pl.Array(pl.Int64, 0), 1)) + .to_frame() + .select(pl.concat_arr(pl.first(), pl.first())) + .to_series() + ), + pl.Series([[[], []], [None, None]], dtype=pl.Array(pl.Array(pl.Int64, 0), 2)), + ) + + assert_series_equal( + ( + pl.Series([[{}], [None]], dtype=pl.Array(pl.Struct({}), 1)) + .to_frame() + .select(pl.concat_arr(pl.first(), pl.first())) + .to_series() + ), + pl.Series([[{}, {}], [None, None]], dtype=pl.Array(pl.Struct({}), 2)), + ) + + assert_series_equal( + ( + pl.Series( + [[{"x": []}], [{"x": None}], [None]], + dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 0)}), 1), + ) + .to_frame() + .select(pl.concat_arr(pl.first(), pl.first())) + .to_series() + ), + pl.Series( + [[{"x": []}, {"x": []}], [{"x": None}, {"x": None}], [None, None]], + dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 0)}), 2), + ), + ) + + +@pytest.mark.may_fail_auto_streaming +def test_concat_arr_scalar() -> None: + lit = pl.lit([b"A"], dtype=pl.Array(pl.Binary, 1)) + df = pl.select(pl.repeat(lit, 10)) + + assert df._to_metadata()["repr"].to_list() == ["scalar"] + + out = df.with_columns(out=pl.concat_arr(pl.first(), pl.first())) + assert out._to_metadata()["repr"].to_list() == ["scalar", "scalar"] diff --git a/py-polars/tests/unit/functions/as_datatype/test_concat_list.py b/py-polars/tests/unit/functions/as_datatype/test_concat_list.py new file mode 100644 index 000000000000..9c02ee0dfd74 --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_concat_list.py @@ -0,0 +1,197 @@ +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_list_concat() -> None: + s0 = pl.Series("a", [[1, 2]]) + s1 = pl.Series("b", [[3, 4, 5]]) + expected = pl.Series("a", [[1, 2, 3, 4, 5]]) + + out = s0.list.concat([s1]) + assert_series_equal(out, expected) + + out = s0.list.concat(s1) + assert_series_equal(out, expected) + + df = pl.DataFrame([s0, s1]) + assert_series_equal(df.select(pl.concat_list(["a", "b"]).alias("a"))["a"], expected) + assert_series_equal( + df.select(pl.col("a").list.concat("b").alias("a"))["a"], expected + ) + assert_series_equal( + df.select(pl.col("a").list.concat(["b"]).alias("a"))["a"], expected + ) + + +def test_concat_list_with_lit() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + result = df.select(pl.concat_list([pl.col("a"), pl.lit(1)]).alias("a")) + expected = {"a": [[1, 1], [2, 1], [3, 1]]} + assert result.to_dict(as_series=False) == expected + + result = df.select(pl.concat_list([pl.lit(1), pl.col("a")]).alias("a")) + expected = {"a": [[1, 1], [1, 2], [1, 3]]} + assert result.to_dict(as_series=False) == expected + + +def test_concat_list_empty_raises() -> None: + with pytest.raises(ComputeError): + pl.DataFrame({"a": [1, 2, 3]}).with_columns(pl.concat_list([])) + + +def test_list_concat_nulls() -> None: + assert pl.DataFrame( + { + "a": [["a", "b"], None, ["c", "d", "e"], None], + "t": [["x"], ["y"], None, None], + } + ).with_columns(pl.concat_list(["a", "t"]).alias("concat"))["concat"].to_list() == [ + ["a", "b", "x"], + None, + None, + None, + ] + + +def test_concat_list_in_agg_6397() -> None: + df = pl.DataFrame({"group": [1, 2, 2, 3], "value": ["a", "b", "c", "d"]}) + + # single list + # TODO: this shouldn't be allowed and raise + # Currently this do a cast to list in the expression and + # therefore leads to different nesting + assert df.group_by("group").agg( + [ + # this casts every element to a list + pl.concat_list(pl.col("value")), + ] + ).sort("group").to_dict(as_series=False) == { + "group": [1, 2, 3], + "value": [[["a"]], [["b"], ["c"]], [["d"]]], + } + + # nested list + assert df.group_by("group").agg( + [ + pl.concat_list(pl.col("value").implode()).alias("result"), + ] + ).sort("group").to_dict(as_series=False) == { + "group": [1, 2, 3], + "result": [["a"], ["b", "c"], ["d"]], + } + + +def test_list_concat_supertype() -> None: + df = pl.DataFrame( + [pl.Series("a", [1, 2], pl.UInt8), pl.Series("b", [10000, 20000], pl.UInt16)] + ) + assert df.with_columns(pl.concat_list(pl.col(["a", "b"])).alias("concat_list"))[ + "concat_list" + ].to_list() == [[1, 10000], [2, 20000]] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_list_concat_4762() -> None: + df = pl.DataFrame({"x": "a"}) + expected = {"x": [["a", "a"]]} + + q = df.lazy().select([pl.concat_list([pl.col("x").cast(pl.Categorical)] * 2)]) + with pl.StringCache(): + assert q.collect().to_dict(as_series=False) == expected + + +def test_list_concat_rolling_window() -> None: + # inspired by: + # https://stackoverflow.com/questions/70377100/use-the-rolling-function-of-polars-to-get-a-list-of-all-values-in-the-rolling-wi + # this tests if it works without specifically creating list dtype upfront. note that + # the given answer is preferred over this snippet as that reuses the list array when + # shifting + df = pl.DataFrame( + { + "A": [1.0, 2.0, 9.0, 2.0, 13.0], + } + ) + out = df.with_columns( + pl.col("A").shift(i).alias(f"A_lag_{i}") for i in range(3) + ).select(pl.concat_list([f"A_lag_{i}" for i in range(3)][::-1]).alias("A_rolling")) + assert out.shape == (5, 1) + + s = out.to_series() + assert s.dtype == pl.List + assert s.to_list() == [ + [None, None, 1.0], + [None, 1.0, 2.0], + [1.0, 2.0, 9.0], + [2.0, 9.0, 2.0], + [9.0, 2.0, 13.0], + ] + + # this test proper null behavior of concat list + out = ( + df.with_columns( + pl.col("A").reshape((-1, 1)).arr.to_list() # first turn into a list + ) + .with_columns( + pl.col("A").shift(i).alias(f"A_lag_{i}") + for i in range(3) # slice the lists to a lag + ) + .select( + pl.all(), + pl.concat_list([f"A_lag_{i}" for i in range(3)][::-1]).alias("A_rolling"), + ) + ) + assert out.shape == (5, 5) + + l64 = pl.List(pl.Float64) + assert out.schema == { + "A": l64, + "A_lag_0": l64, + "A_lag_1": l64, + "A_lag_2": l64, + "A_rolling": l64, + } + + +def test_concat_list_reverse_struct_fields() -> None: + df = pl.DataFrame({"nums": [1, 2, 3, 4], "letters": ["a", "b", "c", "d"]}).select( + pl.col("nums"), + pl.struct(["letters", "nums"]).alias("combo"), + pl.struct(["nums", "letters"]).alias("reverse_combo"), + ) + result1 = df.select(pl.concat_list(["combo", "reverse_combo"])) + result2 = df.select(pl.concat_list(["combo", "combo"])) + assert_frame_equal(result1, result2) + + +def test_concat_list_empty() -> None: + df = pl.DataFrame({"a": []}) + df.select(pl.concat_list("a")) + + +def test_concat_list_empty_struct() -> None: + df = pl.DataFrame({"a": []}, schema={"a": pl.Struct({"b": pl.Boolean})}) + df.select(pl.concat_list("a")) + + +def test_cross_join_concat_list_18587() -> None: + lf = pl.LazyFrame({"u32": [0, 1], "str": ["a", "b"]}) + + lf1 = lf.select(pl.struct(pl.all()).alias("1")) + lf2 = lf.select(pl.struct(pl.all()).alias("2")) + lf3 = lf.select(pl.struct(pl.all()).alias("3")) + + result = ( + lf1.join(lf2, how="cross") + .join(lf3, how="cross") + .select(pl.concat_list("1", "2", "3")) + .collect() + ) + + vals = [{"u32": 0, "str": "a"}, {"u32": 1, "str": "b"}] + expected = [[a, b, c] for a in vals for b in vals for c in vals] + + assert result["1"].to_list() == expected diff --git a/py-polars/tests/unit/functions/as_datatype/test_concat_str.py b/py-polars/tests/unit/functions/as_datatype/test_concat_str.py new file mode 100644 index 000000000000..85b76ffe2535 --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_concat_str.py @@ -0,0 +1,73 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_concat_str_wildcard_expansion() -> None: + # one function requires wildcard expansion the other need + # this tests the nested behavior + # see: #2867 + + df = pl.DataFrame({"a": ["x", "Y", "z"], "b": ["S", "o", "S"]}) + assert df.select( + pl.concat_str(pl.all()).str.to_lowercase() + ).to_series().to_list() == ["xs", "yo", "zs"] + + +def test_concat_str_with_non_utf8_col() -> None: + out = ( + pl.LazyFrame({"a": [0], "b": ["x"]}) + .select(pl.concat_str(["a", "b"], separator="-").fill_null(pl.col("a"))) + .collect() + ) + expected = pl.Series("a", ["0-x"], dtype=pl.String) + assert_series_equal(out.to_series(), expected) + + +def test_empty_df_concat_str_11701() -> None: + df = pl.DataFrame({"a": []}) + out = df.select(pl.concat_str([pl.col("a").cast(pl.String), pl.lit("x")])) + assert_frame_equal(out, pl.DataFrame({"a": []}, schema={"a": pl.String})) + + +def test_concat_str_ignore_nulls() -> None: + df = pl.DataFrame({"a": ["a", None, "c"], "b": [None, 2, 3], "c": ["x", "y", "z"]}) + + # ignore nulls + out = df.select([pl.concat_str(["a", "b", "c"], separator="-", ignore_nulls=True)]) + assert out["a"].to_list() == ["a-x", "2-y", "c-3-z"] + # propagate nulls + out = df.select([pl.concat_str(["a", "b", "c"], separator="-", ignore_nulls=False)]) + assert out["a"].to_list() == [None, None, "c-3-z"] + + +@pytest.mark.parametrize( + "expr", + [ + "a" + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=True), + "a" + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=False), + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=True) + "a", + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=False) + "a", + pl.lit(None, dtype=pl.String) + + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=True), + pl.lit(None, dtype=pl.String) + + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=False), + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=True) + + pl.lit(None, dtype=pl.String), + pl.concat_str(pl.lit("b"), pl.lit("c"), ignore_nulls=False) + + pl.lit(None, dtype=pl.String), + pl.lit(None, dtype=pl.String) + "a", + "a" + pl.lit(None, dtype=pl.String), + pl.concat_str(None, ignore_nulls=False) + + pl.concat_str(pl.lit("b"), ignore_nulls=False), + pl.concat_str(None, ignore_nulls=True) + + pl.concat_str(pl.lit("b"), ignore_nulls=True), + ], +) +def test_simplify_str_addition_concat_str(expr: pl.Expr) -> None: + ldf = pl.LazyFrame({}).select(expr) + print(ldf.collect(simplify_expression=True)) + assert_frame_equal( + ldf.collect(simplify_expression=True), ldf.collect(simplify_expression=False) + ) diff --git a/py-polars/tests/unit/functions/as_datatype/test_datetime.py b/py-polars/tests/unit/functions/as_datatype/test_datetime.py new file mode 100644 index 000000000000..688ecc2ec7ce --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_datetime.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars._typing import TimeUnit + + +def test_date_datetime() -> None: + df = pl.DataFrame( + { + "year": [2001, 2002, 2003], + "month": [1, 2, 3], + "day": [1, 2, 3], + "hour": [23, 12, 8], + } + ) + out = df.select( + pl.all(), + pl.datetime("year", "month", "day", "hour").dt.hour().cast(int).alias("h2"), + pl.date("year", "month", "day").dt.day().cast(int).alias("date"), + ) + assert_series_equal(out["date"], df["day"].rename("date")) + assert_series_equal(out["h2"], df["hour"].rename("h2")) + + +@pytest.mark.parametrize( + "components", + [ + [2025, 13, 1], + [2025, 1, 32], + [2025, 2, 29], + ], +) +def test_date_invalid_component(components: list[int]) -> None: + y, m, d = components + msg = rf"Invalid date components \({y}, {m}, {d}\) supplied" + with pytest.raises(ComputeError, match=msg): + pl.select(pl.date(*components)) + + +@pytest.mark.parametrize( + "components", + [ + [2025, 13, 1, 0, 0, 0, 0], + [2025, 1, 32, 0, 0, 0, 0], + [2025, 2, 29, 0, 0, 0, 0], + ], +) +def test_datetime_invalid_date_component(components: list[int]) -> None: + y, m, d = components[0:3] + msg = rf"Invalid date components \({y}, {m}, {d}\) supplied" + with pytest.raises(ComputeError, match=msg): + pl.select(pl.datetime(*components)) + + +@pytest.mark.parametrize( + "components", + [ + [2025, 1, 1, 25, 0, 0, 0], + [2025, 1, 1, 0, 60, 0, 0], + [2025, 1, 1, 0, 0, 60, 0], + [2025, 1, 1, 0, 0, 0, 2_000_000], + ], +) +def test_datetime_invalid_time_component(components: list[int]) -> None: + h, mnt, s, us = components[3:] + ns = us * 1_000 + msg = rf"Invalid time components \({h}, {mnt}, {s}, {ns}\) supplied" + with pytest.raises(ComputeError, match=msg): + pl.select(pl.datetime(*components)) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_time_unit(time_unit: TimeUnit) -> None: + result = pl.datetime(2022, 1, 2, time_unit=time_unit) + + assert pl.select(result.dt.year()).item() == 2022 + assert pl.select(result.dt.month()).item() == 1 + assert pl.select(result.dt.day()).item() == 2 + + +@pytest.mark.parametrize("time_zone", [None, "Europe/Amsterdam", "UTC"]) +def test_datetime_time_zone(time_zone: str | None) -> None: + result = pl.datetime(2022, 1, 2, 10, time_zone=time_zone) + + assert pl.select(result.dt.year()).item() == 2022 + assert pl.select(result.dt.month()).item() == 1 + assert pl.select(result.dt.day()).item() == 2 + assert pl.select(result.dt.hour()).item() == 10 + + +def test_datetime_ambiguous_time_zone() -> None: + expr = pl.datetime(2018, 10, 28, 2, 30, time_zone="Europe/Brussels") + + with pytest.raises(ComputeError): + pl.select(expr) + + +def test_datetime_ambiguous_time_zone_earliest() -> None: + expr = pl.datetime( + 2018, 10, 28, 2, 30, time_zone="Europe/Brussels", ambiguous="earliest" + ) + + result = pl.select(expr).item() + + expected = datetime(2018, 10, 28, 2, 30, tzinfo=ZoneInfo("Europe/Brussels")) + assert result == expected + assert result.fold == 0 + + +def test_datetime_wildcard_expansion() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + assert df.select( + pl.datetime(year=pl.all(), month=pl.all(), day=pl.all()).name.keep() + ).to_dict(as_series=False) == { + "a": [datetime(1, 1, 1, 0, 0)], + "b": [datetime(2, 2, 2, 0, 0)], + } + + +def test_datetime_invalid_time_zone() -> None: + df1 = pl.DataFrame({"year": []}, schema={"year": pl.Int32}) + df2 = pl.DataFrame({"year": [2024]}) + + with pytest.raises(ComputeError, match="unable to parse time zone: 'foo'"): + df1.select(pl.datetime("year", 1, 1, time_zone="foo")) + + with pytest.raises(ComputeError, match="unable to parse time zone: 'foo'"): + df2.select(pl.datetime("year", 1, 1, time_zone="foo")) + + +def test_datetime_from_empty_column() -> None: + df = pl.DataFrame({"year": []}, schema={"year": pl.Int32}) + + assert df.select(datetime=pl.datetime("year", 1, 1)).shape == (0, 1) + assert df.with_columns(datetime=pl.datetime("year", 1, 1)).shape == (0, 2) diff --git a/py-polars/tests/unit/functions/as_datatype/test_duration.py b/py-polars/tests/unit/functions/as_datatype/test_duration.py new file mode 100644 index 000000000000..a3fc5dc658ef --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_duration.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import TimeUnit + + +def test_empty_duration() -> None: + s = pl.DataFrame([], {"days": pl.Int32}).select(pl.duration(days="days")) + assert s.dtypes == [pl.Duration("us")] + assert s.shape == (0, 1) + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("ms", timedelta(days=1, minutes=2, seconds=3, milliseconds=4)), + ("us", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)), + ("ns", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)), + ], +) +def test_duration_time_units(time_unit: TimeUnit, expected: timedelta) -> None: + result = pl.LazyFrame().select( + duration=pl.duration( + days=1, + minutes=2, + seconds=3, + milliseconds=4, + microseconds=5, + nanoseconds=6, + time_unit=time_unit, + ) + ) + assert result.collect_schema()["duration"] == pl.Duration(time_unit) + assert result.collect()["duration"].item() == expected + if time_unit == "ns": + assert ( + result.collect()["duration"].dt.total_nanoseconds().item() == 86523004005006 + ) + + +def test_datetime_duration_offset() -> None: + df = pl.DataFrame( + { + "datetime": [ + datetime(1999, 1, 1, 7), + datetime(2022, 1, 2, 14), + datetime(3000, 12, 31, 21), + ], + "add": [1, 2, -1], + } + ) + out = df.select( + (pl.col("datetime") + pl.duration(weeks="add")).alias("add_weeks"), + (pl.col("datetime") + pl.duration(days="add")).alias("add_days"), + (pl.col("datetime") + pl.duration(hours="add")).alias("add_hours"), + (pl.col("datetime") + pl.duration(seconds="add")).alias("add_seconds"), + (pl.col("datetime") + pl.duration(microseconds=pl.col("add") * 1000)).alias( + "add_usecs" + ), + ) + expected = pl.DataFrame( + { + "add_weeks": [ + datetime(1999, 1, 8, 7), + datetime(2022, 1, 16, 14), + datetime(3000, 12, 24, 21), + ], + "add_days": [ + datetime(1999, 1, 2, 7), + datetime(2022, 1, 4, 14), + datetime(3000, 12, 30, 21), + ], + "add_hours": [ + datetime(1999, 1, 1, hour=8), + datetime(2022, 1, 2, hour=16), + datetime(3000, 12, 31, hour=20), + ], + "add_seconds": [ + datetime(1999, 1, 1, 7, second=1), + datetime(2022, 1, 2, 14, second=2), + datetime(3000, 12, 31, 20, 59, 59), + ], + "add_usecs": [ + datetime(1999, 1, 1, 7, microsecond=1000), + datetime(2022, 1, 2, 14, microsecond=2000), + datetime(3000, 12, 31, 20, 59, 59, 999000), + ], + } + ) + assert_frame_equal(out, expected) + + +def test_date_duration_offset() -> None: + df = pl.DataFrame( + { + "date": [date(10, 1, 1), date(2000, 7, 5), date(9990, 12, 31)], + "offset": [365, 7, -31], + } + ) + out = df.select( + (pl.col("date") + pl.duration(days="offset")).alias("add_days"), + (pl.col("date") - pl.duration(days="offset")).alias("sub_days"), + (pl.col("date") + pl.duration(weeks="offset")).alias("add_weeks"), + (pl.col("date") - pl.duration(weeks="offset")).alias("sub_weeks"), + ) + assert out.to_dict(as_series=False) == { + "add_days": [date(11, 1, 1), date(2000, 7, 12), date(9990, 11, 30)], + "sub_days": [date(9, 1, 1), date(2000, 6, 28), date(9991, 1, 31)], + "add_weeks": [date(16, 12, 30), date(2000, 8, 23), date(9990, 5, 28)], + "sub_weeks": [date(3, 1, 3), date(2000, 5, 17), date(9991, 8, 5)], + } + + +def test_add_duration_3786() -> None: + df = pl.DataFrame( + { + "datetime": [datetime(2022, 1, 1), datetime(2022, 1, 2)], + "add": [1, 2], + } + ) + assert df.slice(0, 1).with_columns( + (pl.col("datetime") + pl.duration(weeks="add")).alias("add_weeks"), + (pl.col("datetime") + pl.duration(days="add")).alias("add_days"), + (pl.col("datetime") + pl.duration(seconds="add")).alias("add_seconds"), + (pl.col("datetime") + pl.duration(milliseconds="add")).alias( + "add_milliseconds" + ), + (pl.col("datetime") + pl.duration(hours="add")).alias("add_hours"), + ).to_dict(as_series=False) == { + "datetime": [datetime(2022, 1, 1, 0, 0)], + "add": [1], + "add_weeks": [datetime(2022, 1, 8, 0, 0)], + "add_days": [datetime(2022, 1, 2, 0, 0)], + "add_seconds": [datetime(2022, 1, 1, 0, 0, 1)], + "add_milliseconds": [datetime(2022, 1, 1, 0, 0, 0, 1000)], + "add_hours": [datetime(2022, 1, 1, 1, 0)], + } + + +@pytest.mark.parametrize( + ("time_unit", "ms", "us", "ns"), + [ + ("ms", 11, 0, 0), + ("us", 0, 11_007, 0), + ("ns", 0, 0, 11_007_003), + ], +) +def test_duration_subseconds_us(time_unit: TimeUnit, ms: int, us: int, ns: int) -> None: + result = pl.duration( + milliseconds=6, microseconds=4_005, nanoseconds=1_002_003, time_unit=time_unit + ) + expected = pl.duration( + milliseconds=ms, microseconds=us, nanoseconds=ns, time_unit=time_unit + ) + assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_time_unit_ns() -> None: + result = pl.duration(milliseconds=4, microseconds=3_000, nanoseconds=10) + expected = pl.duration( + milliseconds=4, microseconds=3_000, nanoseconds=10, time_unit="ns" + ) + assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_time_unit_us() -> None: + result = pl.duration(milliseconds=4, microseconds=3_000) + expected = pl.duration(milliseconds=4, microseconds=3_000, time_unit="us") + assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_time_unit_ms() -> None: + result = pl.duration(milliseconds=4) + expected = pl.duration(milliseconds=4, time_unit="us") + assert_frame_equal(pl.select(result), pl.select(expected)) + + +def test_duration_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in pl.duration + # https://github.com/pola-rs/polars/issues/19007 + df = df = pl.DataFrame({"a": [1], "b": [2]}) + assert df.select(pl.duration(hours=pl.all()).name.keep()).to_dict( + as_series=False + ) == {"a": [timedelta(seconds=3600)], "b": [timedelta(seconds=7200)]} diff --git a/py-polars/tests/unit/functions/as_datatype/test_format.py b/py-polars/tests/unit/functions/as_datatype/test_format.py new file mode 100644 index 000000000000..074efa8cc374 --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_format.py @@ -0,0 +1,8 @@ +import polars as pl + + +def test_format() -> None: + df = pl.DataFrame({"a": ["a", "b", "c"], "b": [1, 2, 3]}) + + out = df.select([pl.format("foo_{}_bar_{}", pl.col("a"), "b").alias("fmt")]) + assert out["fmt"].to_list() == ["foo_a_bar_1", "foo_b_bar_2", "foo_c_bar_3"] diff --git a/py-polars/tests/unit/functions/as_datatype/test_struct.py b/py-polars/tests/unit/functions/as_datatype/test_struct.py new file mode 100644 index 000000000000..2073723f2d04 --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_struct.py @@ -0,0 +1,267 @@ +from datetime import date, datetime + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_struct_args_kwargs() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": ["a", "b"]}) + + # Single input + result = df.select(r=pl.struct((pl.col("a") + pl.col("b")).alias("p"))) + expected = pl.DataFrame({"r": [{"p": 4}, {"p": 6}]}) + assert_frame_equal(result, expected) + + # List input + result = df.select(r=pl.struct([pl.col("a").alias("p"), pl.col("b").alias("q")])) + expected = pl.DataFrame({"r": [{"p": 1, "q": 3}, {"p": 2, "q": 4}]}) + assert_frame_equal(result, expected) + + # Positional input + result = df.select(r=pl.struct(pl.col("a").alias("p"), pl.col("b").alias("q"))) + assert_frame_equal(result, expected) + + # Keyword input + result = df.select(r=pl.struct(p="a", q="b")) + assert_frame_equal(result, expected) + + +def test_struct_with_lit() -> None: + expr = pl.struct([pl.col("a"), pl.lit(1).alias("b")]) + + assert ( + pl.DataFrame({"a": pl.Series([], dtype=pl.Int64)}) + .select(expr) + .to_dict(as_series=False) + ) == {"a": []} + + assert ( + pl.DataFrame({"a": pl.Series([1], dtype=pl.Int64)}) + .select(expr) + .to_dict(as_series=False) + ) == {"a": [{"a": 1, "b": 1}]} + + assert ( + pl.DataFrame({"a": pl.Series([1, 2], dtype=pl.Int64)}) + .select(expr) + .to_dict(as_series=False) + ) == {"a": [{"a": 1, "b": 1}, {"a": 2, "b": 1}]} + + +def test_eager_struct() -> None: + with pytest.raises(DuplicateError, match="multiple fields with name '' found"): + s = pl.struct([pl.Series([1, 2, 3]), pl.Series(["a", "b", "c"])], eager=True) + + s = pl.struct( + [pl.Series("a", [1, 2, 3]), pl.Series("b", ["a", "b", "c"])], eager=True + ) + assert s.dtype == pl.Struct + + +def test_struct_from_schema_only() -> None: + # Workaround for new streaming engine. + with pl.StringCache(): + # we create a dataframe with default types + df = pl.DataFrame( + { + "str": ["a", "b", "c", "d", "e"], + "u8": [1, 2, 3, 4, 5], + "i32": [1, 2, 3, 4, 5], + "f64": [1, 2, 3, 4, 5], + "cat": ["a", "b", "c", "d", "e"], + "datetime": pl.Series( + [ + date(2023, 1, 1), + date(2023, 1, 2), + date(2023, 1, 3), + date(2023, 1, 4), + date(2023, 1, 5), + ] + ), + "bool": [1, 0, 1, 1, 0], + "list[u8]": [[1], [2], [3], [4], [5]], + } + ) + + # specify a schema with specific dtypes + s = df.select( + pl.struct( + schema={ + "str": pl.String, + "u8": pl.UInt8, + "i32": pl.Int32, + "f64": pl.Float64, + "cat": pl.Categorical, + "datetime": pl.Datetime("ms"), + "bool": pl.Boolean, + "list[u8]": pl.List(pl.UInt8), + } + ).alias("s") + )["s"] + + # check dtypes + assert s.dtype == pl.Struct( + [ + pl.Field("str", pl.String), + pl.Field("u8", pl.UInt8), + pl.Field("i32", pl.Int32), + pl.Field("f64", pl.Float64), + pl.Field("cat", pl.Categorical), + pl.Field("datetime", pl.Datetime("ms")), + pl.Field("bool", pl.Boolean), + pl.Field("list[u8]", pl.List(pl.UInt8)), + ] + ) + + # check values + assert s.to_list() == [ + { + "str": "a", + "u8": 1, + "i32": 1, + "f64": 1.0, + "cat": "a", + "datetime": datetime(2023, 1, 1, 0, 0), + "bool": True, + "list[u8]": [1], + }, + { + "str": "b", + "u8": 2, + "i32": 2, + "f64": 2.0, + "cat": "b", + "datetime": datetime(2023, 1, 2, 0, 0), + "bool": False, + "list[u8]": [2], + }, + { + "str": "c", + "u8": 3, + "i32": 3, + "f64": 3.0, + "cat": "c", + "datetime": datetime(2023, 1, 3, 0, 0), + "bool": True, + "list[u8]": [3], + }, + { + "str": "d", + "u8": 4, + "i32": 4, + "f64": 4.0, + "cat": "d", + "datetime": datetime(2023, 1, 4, 0, 0), + "bool": True, + "list[u8]": [4], + }, + { + "str": "e", + "u8": 5, + "i32": 5, + "f64": 5.0, + "cat": "e", + "datetime": datetime(2023, 1, 5, 0, 0), + "bool": False, + "list[u8]": [5], + }, + ] + + +def test_struct_broadcasting() -> None: + df = pl.DataFrame( + { + "col1": [1, 2], + "col2": [10, 20], + } + ) + + assert ( + df.select( + pl.struct( + [ + pl.lit("a").alias("a"), + pl.col("col1").alias("col1"), + ] + ).alias("my_struct") + ) + ).to_dict(as_series=False) == { + "my_struct": [{"a": "a", "col1": 1}, {"a": "a", "col1": 2}] + } + + +def test_struct_list_cat_8235() -> None: + df = pl.DataFrame( + {"values": [["a", "b", "c"]]}, schema={"values": pl.List(pl.Categorical)} + ) + assert df.select(pl.struct("values")).to_dict(as_series=False) == { + "values": [{"values": ["a", "b", "c"]}] + } + + +def test_struct_lit_cast() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + out = df.select( + pl.struct(pl.col("a"), pl.lit(None).alias("b"), schema=schema) # type: ignore[arg-type] + ).get_column("a") + + expected = pl.Series( + "a", + [ + {"a": 1, "b": None}, + {"a": 2, "b": None}, + {"a": 3, "b": None}, + ], + dtype=pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.List(pl.Int64))]), + ) + assert_series_equal(out, expected) + + out = df.select( + pl.struct([pl.col("a"), pl.lit(pl.Series([[]])).alias("b")], schema=schema) # type: ignore[arg-type] + ).get_column("a") + + expected = pl.Series( + "a", + [ + {"a": 1, "b": []}, + {"a": 2, "b": []}, + {"a": 3, "b": []}, + ], + dtype=pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.List(pl.Int64))]), + ) + assert_series_equal(out, expected) + + +def test_suffix_in_struct_creation() -> None: + assert ( + pl.DataFrame( + { + "a": [1, 2], + "b": [3, 4], + "c": [5, 6], + } + ).select(pl.struct(pl.col(["a", "c"]).name.suffix("_foo")).alias("bar")) + ).unnest("bar").to_dict(as_series=False) == {"a_foo": [1, 2], "c_foo": [5, 6]} + + +def test_resolved_names_15442() -> None: + df = pl.DataFrame( + { + "x": [206.0], + "y": [225.0], + } + ) + center = pl.struct( + x=pl.col("x"), + y=pl.col("y"), + ) + + left = 0 + right = 1000 + in_x = (left < center.struct.field("x")) & (center.struct.field("x") <= right) + assert df.lazy().filter(in_x).collect().shape == (1, 2) diff --git a/py-polars/tests/unit/functions/as_datatype/test_time.py b/py-polars/tests/unit/functions/as_datatype/test_time.py new file mode 100644 index 000000000000..668b87b2941d --- /dev/null +++ b/py-polars/tests/unit/functions/as_datatype/test_time.py @@ -0,0 +1,24 @@ +import polars as pl +from polars.testing import assert_series_equal + + +def test_time() -> None: + df = pl.DataFrame( + { + "hour": [7, 14, 21], + "min": [10, 20, 30], + "sec": [15, 30, 45], + "micro": [123456, 555555, 987654], + } + ) + out = df.select( + pl.all(), + pl.time("hour", "min", "sec", "micro").dt.hour().cast(int).alias("h2"), + pl.time("hour", "min", "sec", "micro").dt.minute().cast(int).alias("m2"), + pl.time("hour", "min", "sec", "micro").dt.second().cast(int).alias("s2"), + pl.time("hour", "min", "sec", "micro").dt.microsecond().cast(int).alias("ms2"), + ) + assert_series_equal(out["h2"], df["hour"].rename("h2")) + assert_series_equal(out["m2"], df["min"].rename("m2")) + assert_series_equal(out["s2"], df["sec"].rename("s2")) + assert_series_equal(out["ms2"], df["micro"].rename("ms2")) diff --git a/py-polars/tests/unit/functions/range/__init__.py b/py-polars/tests/unit/functions/range/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/functions/range/test_date_range.py b/py-polars/tests/unit/functions/range/test_date_range.py new file mode 100644 index 000000000000..758b604aeba5 --- /dev/null +++ b/py-polars/tests/unit/functions/range/test_date_range.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING + +import pandas as pd +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import ClosedInterval + + +def test_date_range() -> None: + # if low/high are both date, range is also be date _iff_ the granularity is >= 1d + result = pl.date_range(date(2022, 1, 1), date(2022, 3, 1), "1mo", eager=True) + assert result.to_list() == [date(2022, 1, 1), date(2022, 2, 1), date(2022, 3, 1)] + + +def test_date_range_invalid_time_unit() -> None: + with pytest.raises(InvalidOperationError, match="'x' not supported"): + pl.date_range( + start=date(2021, 12, 16), + end=date(2021, 12, 18), + interval="1X", + eager=True, + ) + + +def test_date_range_lazy_with_literals() -> None: + df = pl.DataFrame({"misc": ["x"]}).with_columns( + pl.date_ranges( + date(2000, 1, 1), + date(2023, 8, 31), + interval="987d", + eager=False, + ).alias("dts") + ) + assert df.rows() == [ + ( + "x", + [ + date(2000, 1, 1), + date(2002, 9, 14), + date(2005, 5, 28), + date(2008, 2, 9), + date(2010, 10, 23), + date(2013, 7, 6), + date(2016, 3, 19), + date(2018, 12, 1), + date(2021, 8, 14), + ], + ) + ] + assert ( + df.rows()[0][1] + == pd.date_range( + date(2000, 1, 1), date(2023, 12, 31), freq="987d" + ).date.tolist() + ) + + +@pytest.mark.parametrize("low", ["start", pl.col("start")]) +@pytest.mark.parametrize("high", ["stop", pl.col("stop")]) +def test_date_range_lazy_with_expressions( + low: str | pl.Expr, high: str | pl.Expr +) -> None: + lf = pl.LazyFrame( + { + "start": [date(2015, 6, 30)], + "stop": [date(2022, 12, 31)], + } + ) + + result = lf.with_columns( + pl.date_ranges(low, high, interval="678d", eager=False).alias("dts") + ) + + assert result.collect().rows() == [ + ( + date(2015, 6, 30), + date(2022, 12, 31), + [ + date(2015, 6, 30), + date(2017, 5, 8), + date(2019, 3, 17), + date(2021, 1, 23), + date(2022, 12, 2), + ], + ) + ] + + df = pl.DataFrame( + { + "start": [date(2000, 1, 1), date(2022, 6, 1)], + "stop": [date(2000, 1, 2), date(2022, 6, 2)], + } + ) + + result_df = df.with_columns(pl.date_ranges(low, high, interval="1d").alias("dts")) + + assert result_df.to_dict(as_series=False) == { + "start": [date(2000, 1, 1), date(2022, 6, 1)], + "stop": [date(2000, 1, 2), date(2022, 6, 2)], + "dts": [ + [date(2000, 1, 1), date(2000, 1, 2)], + [date(2022, 6, 1), date(2022, 6, 2)], + ], + } + + +def test_date_ranges_single_row_lazy_7110() -> None: + df = pl.DataFrame( + { + "name": ["A"], + "from": [date(2020, 1, 1)], + "to": [date(2020, 1, 2)], + } + ) + result = df.with_columns( + pl.date_ranges( + start=pl.col("from"), + end=pl.col("to"), + interval="1d", + eager=False, + ).alias("date_range") + ) + expected = pl.DataFrame( + { + "name": ["A"], + "from": [date(2020, 1, 1)], + "to": [date(2020, 1, 2)], + "date_range": [[date(2020, 1, 1), date(2020, 1, 2)]], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("closed", "expected_values"), + [ + ("right", [date(2020, 2, 29), date(2020, 3, 31)]), + ("left", [date(2020, 1, 31), date(2020, 2, 29)]), + ("none", [date(2020, 2, 29)]), + ("both", [date(2020, 1, 31), date(2020, 2, 29), date(2020, 3, 31)]), + ], +) +def test_date_range_end_of_month_5441( + closed: ClosedInterval, expected_values: list[date] +) -> None: + start = date(2020, 1, 31) + stop = date(2020, 3, 31) + result = pl.date_range(start, stop, interval="1mo", closed=closed, eager=True) + expected = pl.Series("literal", expected_values) + assert_series_equal(result, expected) + + +def test_date_range_name() -> None: + result_eager = pl.date_range(date(2020, 1, 1), date(2020, 1, 3), eager=True) + assert result_eager.name == "literal" + + start = pl.Series("left", [date(2020, 1, 1)]) + result_lazy = pl.select( + pl.date_range(start, date(2020, 1, 3), eager=False) + ).to_series() + assert result_lazy.name == "left" + + +def test_date_ranges_eager() -> None: + start = pl.Series("start", [date(2022, 1, 1), date(2022, 1, 2)]) + end = pl.Series("end", [date(2022, 1, 4), date(2022, 1, 3)]) + + result = pl.date_ranges(start, end, eager=True) + + expected = pl.Series( + "start", + [ + [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), date(2022, 1, 4)], + [date(2022, 1, 2), date(2022, 1, 3)], + ], + ) + assert_series_equal(result, expected) + + +def test_date_range_eager() -> None: + start = pl.Series("start", [date(2022, 1, 1)]) + end = pl.Series("end", [date(2022, 1, 3)]) + + result = pl.date_range(start, end, eager=True) + + expected = pl.Series( + "start", [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)] + ) + assert_series_equal(result, expected) + + +def test_date_range_input_shape_empty() -> None: + empty = pl.Series(dtype=pl.Datetime) + single = pl.Series([datetime(2022, 1, 2)]) + + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 0 values" + ): + pl.date_range(empty, single, eager=True) + with pytest.raises( + ComputeError, match="`end` must contain exactly one value, got 0 values" + ): + pl.date_range(single, empty, eager=True) + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 0 values" + ): + pl.date_range(empty, empty, eager=True) + + +def test_date_range_input_shape_multiple_values() -> None: + single = pl.Series([datetime(2022, 1, 2)]) + multiple = pl.Series([datetime(2022, 1, 3), datetime(2022, 1, 4)]) + + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 2 values" + ): + pl.date_range(multiple, single, eager=True) + with pytest.raises( + ComputeError, match="`end` must contain exactly one value, got 2 values" + ): + pl.date_range(single, multiple, eager=True) + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 2 values" + ): + pl.date_range(multiple, multiple, eager=True) + + +def test_date_range_start_later_than_end() -> None: + result = pl.date_range(date(2000, 3, 20), date(2000, 3, 5), eager=True) + expected = pl.Series("literal", dtype=pl.Date) + assert_series_equal(result, expected) + + +def test_date_range_24h_interval_raises() -> None: + with pytest.raises( + ComputeError, + match="`interval` input for `date_range` must consist of full days", + ): + pl.date_range(date(2022, 1, 1), date(2022, 1, 3), interval="24h", eager=True) + + +def test_long_date_range_12461() -> None: + result = pl.date_range(date(1900, 1, 1), date(2300, 1, 1), "1d", eager=True) + assert result[0] == date(1900, 1, 1) + assert result[-1] == date(2300, 1, 1) + assert (result.diff()[1:].dt.total_days() == 1).all() + + +def test_date_ranges_broadcasting() -> None: + df = pl.DataFrame({"dates": [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3)]}) + result = df.select( + pl.date_ranges(start="dates", end=date(2021, 1, 3)).alias("end"), + pl.date_ranges(start=date(2021, 1, 1), end="dates").alias("start"), + ) + expected = pl.DataFrame( + { + "end": [ + [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3)], + [date(2021, 1, 2), date(2021, 1, 3)], + [date(2021, 1, 3)], + ], + "start": [ + [date(2021, 1, 1)], + [date(2021, 1, 1), date(2021, 1, 2)], + [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3)], + ], + } + ) + assert_frame_equal(result, expected) + + +def test_date_ranges_broadcasting_fail() -> None: + start = pl.Series([date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3)]) + end = pl.Series([date(2021, 1, 2), date(2021, 1, 3)]) + + with pytest.raises( + ComputeError, match=r"lengths of `start` \(3\) and `end` \(2\) do not match" + ): + pl.date_ranges(start, end, eager=True) + + +def test_date_range_datetime_input() -> None: + result = pl.date_range( + datetime(2022, 1, 1, 12), datetime(2022, 1, 3), interval="1d", eager=True + ) + expected = pl.Series( + "literal", [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)] + ) + assert_series_equal(result, expected) + + +def test_date_ranges_datetime_input() -> None: + result = pl.date_ranges( + datetime(2022, 1, 1, 12), datetime(2022, 1, 3), interval="1d", eager=True + ) + expected = pl.Series( + "literal", [[date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)]] + ) + assert_series_equal(result, expected) + + +def test_date_range_with_subclass_18470_18447() -> None: + class MyAmazingDate(date): + pass + + class MyAmazingDatetime(datetime): + pass + + result = pl.datetime_range( + MyAmazingDate(2020, 1, 1), MyAmazingDatetime(2020, 1, 2), eager=True + ) + expected = pl.Series("literal", [datetime(2020, 1, 1), datetime(2020, 1, 2)]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/functions/range/test_datetime_range.py b/py-polars/tests/unit/functions/range/test_datetime_range.py new file mode 100644 index 000000000000..d052900b1e45 --- /dev/null +++ b/py-polars/tests/unit/functions/range/test_datetime_range.py @@ -0,0 +1,643 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings + +import polars as pl +from polars.datatypes import DTYPE_TEMPORAL_UNITS +from polars.exceptions import ComputeError, InvalidOperationError, SchemaError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import ClosedInterval, PolarsDataType, TimeUnit + + +def test_datetime_range() -> None: + result = pl.datetime_range( + date(1985, 1, 1), date(2015, 7, 1), timedelta(days=1, hours=12), eager=True + ) + assert len(result) == 7426 + assert result.dt[0] == datetime(1985, 1, 1) + assert result.dt[1] == datetime(1985, 1, 2, 12, 0) + assert result.dt[2] == datetime(1985, 1, 4, 0, 0) + assert result.dt[-1] == datetime(2015, 6, 30, 12, 0) + + for time_unit in DTYPE_TEMPORAL_UNITS: + rng = pl.datetime_range( + datetime(2020, 1, 1), + date(2020, 1, 2), + "2h", + time_unit=time_unit, + eager=True, + ) + assert rng.dtype.time_unit == time_unit # type: ignore[attr-defined] + assert rng.shape == (13,) + assert rng.dt[0] == datetime(2020, 1, 1) + assert rng.dt[-1] == datetime(2020, 1, 2) + + result = pl.datetime_range(date(2022, 1, 1), date(2022, 1, 2), "1h30m", eager=True) + assert list(result) == [ + datetime(2022, 1, 1, 0, 0), + datetime(2022, 1, 1, 1, 30), + datetime(2022, 1, 1, 3, 0), + datetime(2022, 1, 1, 4, 30), + datetime(2022, 1, 1, 6, 0), + datetime(2022, 1, 1, 7, 30), + datetime(2022, 1, 1, 9, 0), + datetime(2022, 1, 1, 10, 30), + datetime(2022, 1, 1, 12, 0), + datetime(2022, 1, 1, 13, 30), + datetime(2022, 1, 1, 15, 0), + datetime(2022, 1, 1, 16, 30), + datetime(2022, 1, 1, 18, 0), + datetime(2022, 1, 1, 19, 30), + datetime(2022, 1, 1, 21, 0), + datetime(2022, 1, 1, 22, 30), + datetime(2022, 1, 2, 0, 0), + ] + + result = pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 1, 1, 0, 1), "987456321ns", eager=True + ) + assert len(result) == 61 + assert result.dtype.time_unit == "ns" # type: ignore[attr-defined] + assert result.dt.second()[-1] == 59 + assert result.cast(pl.String)[-1] == "2022-01-01 00:00:59.247379260" + + +@pytest.mark.parametrize( + ("time_unit", "expected_micros"), + [ + ("ms", 986000), + ("us", 986759), + ("ns", 986759), + (None, 986759), + ], +) +def test_datetime_range_precision( + time_unit: TimeUnit | None, expected_micros: int +) -> None: + micros = 986759 + start = datetime(2000, 5, 30, 1, 53, 4, micros) + stop = datetime(2000, 5, 31, 1, 53, 4, micros) + result = pl.datetime_range(start, stop, time_unit=time_unit, eager=True) + expected_start = start.replace(microsecond=expected_micros) + expected_stop = stop.replace(microsecond=expected_micros) + assert result[0] == expected_start + assert result[1] == expected_stop + + +def test_datetime_range_invalid_time_unit() -> None: + with pytest.raises(InvalidOperationError, match="'x' not supported"): + pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="1X", + eager=True, + ) + + +def test_datetime_range_lazy_time_zones() -> None: + start = datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")) + stop = datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")) + result = ( + pl.DataFrame({"start": [start], "stop": [stop]}) + .with_columns( + pl.datetime_range( + start, + stop, + interval="678d", + eager=False, + time_zone="Pacific/Tarawa", + ) + ) + .lazy() + ) + expected = pl.DataFrame( + { + "start": [ + datetime(2020, 1, 1, 00, 00, tzinfo=ZoneInfo(key="Asia/Kathmandu")) + ], + "stop": [ + datetime(2020, 1, 2, 00, 00, tzinfo=ZoneInfo(key="Asia/Kathmandu")) + ], + "literal": [ + datetime(2020, 1, 1, 6, 15, tzinfo=ZoneInfo(key="Pacific/Tarawa")) + ], + } + ).with_columns(pl.col("literal").dt.convert_time_zone("Pacific/Tarawa")) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize("low", ["start", pl.col("start")]) +@pytest.mark.parametrize("high", ["stop", pl.col("stop")]) +def test_datetime_range_lazy_with_expressions( + low: str | pl.Expr, high: str | pl.Expr +) -> None: + df = pl.DataFrame( + { + "start": [datetime(2000, 1, 1), datetime(2022, 6, 1)], + "stop": [datetime(2000, 1, 2), datetime(2022, 6, 2)], + } + ) + + result_df = df.with_columns( + pl.datetime_ranges(low, high, interval="1d").alias("dts") + ) + + assert result_df.to_dict(as_series=False) == { + "start": [datetime(2000, 1, 1, 0, 0), datetime(2022, 6, 1, 0, 0)], + "stop": [datetime(2000, 1, 2, 0, 0), datetime(2022, 6, 2, 0, 0)], + "dts": [ + [datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 2, 0, 0)], + [datetime(2022, 6, 1, 0, 0), datetime(2022, 6, 2, 0, 0)], + ], + } + + +def test_datetime_range_invalid_time_zone() -> None: + with pytest.raises(ComputeError, match="unable to parse time zone: 'foo'"): + pl.datetime_range( + datetime(2001, 1, 1), + datetime(2001, 1, 3), + time_zone="foo", + eager=True, + ) + + +def test_timezone_aware_datetime_range() -> None: + low = datetime(2022, 10, 17, 10, tzinfo=ZoneInfo("Asia/Shanghai")) + high = datetime(2022, 11, 17, 10, tzinfo=ZoneInfo("Asia/Shanghai")) + + assert pl.datetime_range( + low, high, interval=timedelta(days=5), eager=True + ).to_list() == [ + datetime(2022, 10, 17, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 10, 22, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 10, 27, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 1, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 6, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 11, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 16, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + ] + + with pytest.raises( + SchemaError, + match="failed to determine supertype", + ): + pl.datetime_range( + low, + high.replace(tzinfo=None), + interval=timedelta(days=5), + time_zone="UTC", + eager=True, + ) + + +def test_tzaware_datetime_range_crossing_dst_hourly() -> None: + result = pl.datetime_range( + datetime(2021, 11, 7), + datetime(2021, 11, 7, 2), + "1h", + time_zone="US/Central", + eager=True, + ) + assert result.to_list() == [ + datetime(2021, 11, 7, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 7, 1, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 7, 1, 0, fold=1, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 7, 2, 0, tzinfo=ZoneInfo("US/Central")), + ] + + +def test_tzaware_datetime_range_crossing_dst_daily() -> None: + result = pl.datetime_range( + datetime(2021, 11, 7), + datetime(2021, 11, 11), + "2d", + time_zone="US/Central", + eager=True, + ) + assert result.to_list() == [ + datetime(2021, 11, 7, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 9, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 11, 0, 0, tzinfo=ZoneInfo("US/Central")), + ] + + +def test_tzaware_datetime_range_crossing_dst_weekly() -> None: + result = pl.datetime_range( + datetime(2021, 11, 7), + datetime(2021, 11, 20), + "1w", + time_zone="US/Central", + eager=True, + ) + assert result.to_list() == [ + datetime(2021, 11, 7, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 11, 14, 0, 0, tzinfo=ZoneInfo("US/Central")), + ] + + +def test_tzaware_datetime_range_crossing_dst_monthly() -> None: + result = pl.datetime_range( + datetime(2021, 11, 7), + datetime(2021, 12, 20), + "1mo", + time_zone="US/Central", + eager=True, + ) + assert result.to_list() == [ + datetime(2021, 11, 7, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 7, 0, 0, tzinfo=ZoneInfo("US/Central")), + ] + + +def test_datetime_range_with_unsupported_datetimes() -> None: + with pytest.raises( + ComputeError, + match=r"datetime '2021-11-07 01:00:00' is ambiguous in time zone 'US/Central'", + ): + pl.datetime_range( + datetime(2021, 11, 7, 1), + datetime(2021, 11, 7, 2), + "1h", + time_zone="US/Central", + eager=True, + ) + with pytest.raises( + ComputeError, + match=r"datetime '2021-03-28 02:30:00' is non-existent in time zone 'Europe/Vienna'", + ): + pl.datetime_range( + datetime(2021, 3, 28, 2, 30), + datetime(2021, 3, 28, 4), + "1h", + time_zone="Europe/Vienna", + eager=True, + ) + + +@pytest.mark.parametrize( + ("values_time_zone", "input_time_zone", "output_time_zone"), + [ + ("Asia/Kathmandu", "Asia/Kathmandu", "Asia/Kathmandu"), + ("Asia/Kathmandu", None, "Asia/Kathmandu"), + (None, "Asia/Kathmandu", "Asia/Kathmandu"), + (None, None, None), + ], +) +@pytest.mark.parametrize( + ("values_time_unit", "input_time_unit", "output_time_unit"), + [ + ("ms", None, "ms"), + ("us", None, "us"), + ("ns", None, "ns"), + ("ms", "ms", "ms"), + ("us", "ms", "ms"), + ("ns", "ms", "ms"), + ("ms", "us", "us"), + ("us", "us", "us"), + ("ns", "us", "us"), + ("ms", "ns", "ns"), + ("us", "ns", "ns"), + ("ns", "ns", "ns"), + ], +) +def test_datetime_ranges_schema( + values_time_zone: str | None, + input_time_zone: str | None, + output_time_zone: str | None, + values_time_unit: TimeUnit, + input_time_unit: TimeUnit | None, + output_time_unit: TimeUnit, +) -> None: + df = ( + pl.DataFrame({"start": [datetime(2020, 1, 1)], "end": [datetime(2020, 1, 2)]}) + .with_columns( + pl.col("*") + .dt.replace_time_zone(values_time_zone) + .dt.cast_time_unit(values_time_unit) + ) + .lazy() + ) + result = df.with_columns( + datetime_range=pl.datetime_ranges( + pl.col("start"), + pl.col("end"), + time_zone=input_time_zone, + time_unit=input_time_unit, + ) + ) + expected_schema = { + "start": pl.Datetime(time_unit=values_time_unit, time_zone=values_time_zone), + "end": pl.Datetime(time_unit=values_time_unit, time_zone=values_time_zone), + "datetime_range": pl.List( + pl.Datetime(time_unit=output_time_unit, time_zone=output_time_zone) + ), + } + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + expected = pl.DataFrame( + { + "start": [datetime(2020, 1, 1)], + "end": [datetime(2020, 1, 2)], + "datetime_range": [[datetime(2020, 1, 1), datetime(2020, 1, 2)]], + } + ).with_columns( + pl.col("start") + .dt.replace_time_zone(values_time_zone) + .dt.cast_time_unit(values_time_unit), + pl.col("end") + .dt.replace_time_zone(values_time_zone) + .dt.cast_time_unit(values_time_unit), + pl.col("datetime_range") + .explode() + .dt.replace_time_zone(output_time_zone) + .dt.cast_time_unit(output_time_unit) + .implode(), + ) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize( + ( + "input_time_unit", + "input_time_zone", + "output_dtype", + "interval", + "expected_datetime_range", + ), + [ + (None, None, pl.Datetime("us"), "1s1d", ["2020-01-01", "2020-01-02 00:00:01"]), + (None, None, pl.Datetime("us"), "1d1s", ["2020-01-01", "2020-01-02 00:00:01"]), + ( + None, + None, + pl.Datetime("ns"), + "1d1ns", + ["2020-01-01", "2020-01-02 00:00:00.000000001"], + ), + ("ms", None, pl.Datetime("ms"), "1s1d", ["2020-01-01", "2020-01-02 00:00:01"]), + ("ms", None, pl.Datetime("ms"), "1d1s", ["2020-01-01", "2020-01-02 00:00:01"]), + ( + None, + "Asia/Kathmandu", + pl.Datetime("us", "Asia/Kathmandu"), + "1s1d", + ["2020-01-01", "2020-01-02 00:00:01"], + ), + ( + None, + "Asia/Kathmandu", + pl.Datetime("us", "Asia/Kathmandu"), + "1d1s", + ["2020-01-01", "2020-01-02 00:00:01"], + ), + ( + None, + "Asia/Kathmandu", + pl.Datetime("ns", "Asia/Kathmandu"), + "1d1ns", + ["2020-01-01", "2020-01-02 00:00:00.000000001"], + ), + ( + "ms", + "Asia/Kathmandu", + pl.Datetime("ms", "Asia/Kathmandu"), + "1s1d", + ["2020-01-01", "2020-01-02 00:00:01"], + ), + ( + "ms", + "Asia/Kathmandu", + pl.Datetime("ms", "Asia/Kathmandu"), + "1d1s", + ["2020-01-01", "2020-01-02 00:00:01"], + ), + ], +) +def test_datetime_range_schema_upcasts_to_datetime( + input_time_unit: TimeUnit | None, + input_time_zone: str | None, + output_dtype: PolarsDataType, + interval: str, + expected_datetime_range: list[str], +) -> None: + df = pl.DataFrame({"start": [date(2020, 1, 1)], "end": [date(2020, 1, 3)]}).lazy() + result = df.with_columns( + datetime_range=pl.datetime_ranges( + pl.col("start"), + pl.col("end"), + interval=interval, + time_unit=input_time_unit, + time_zone=input_time_zone, + ) + ) + expected_schema = { + "start": pl.Date, + "end": pl.Date, + "datetime_range": pl.List(output_dtype), + } + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + expected = pl.DataFrame( + { + "start": [date(2020, 1, 1)], + "end": [date(2020, 1, 3)], + "datetime_range": pl.Series(expected_datetime_range) + .str.to_datetime(time_unit="ns") + .implode(), + } + ).with_columns( + pl.col("datetime_range") + .explode() + .dt.cast_time_unit(output_dtype.time_unit) # type: ignore[union-attr] + .dt.replace_time_zone(output_dtype.time_zone) # type: ignore[union-attr] + .implode(), + ) + assert_frame_equal(result.collect(), expected) + + # check datetime_range too + result_single = pl.datetime_range( + date(2020, 1, 1), + date(2020, 1, 3), + interval=interval, + time_unit=input_time_unit, + time_zone=input_time_zone, + eager=True, + ).alias("datetime") + assert_series_equal( + result_single, expected["datetime_range"].explode().rename("datetime") + ) + + +def test_datetime_ranges_no_alias_schema_9037() -> None: + df = pl.DataFrame( + {"start": [datetime(2020, 1, 1)], "end": [datetime(2020, 1, 2)]} + ).lazy() + result = df.with_columns(pl.datetime_ranges(pl.col("start"), pl.col("end"))) + expected_schema = { + "start": pl.List(pl.Datetime(time_unit="us", time_zone=None)), + "end": pl.Datetime(time_unit="us", time_zone=None), + } + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + +@pytest.mark.parametrize("interval", [timedelta(0), timedelta(minutes=-10)]) +def test_datetime_range_invalid_interval(interval: timedelta) -> None: + with pytest.raises(ComputeError, match="`interval` must be positive"): + pl.datetime_range( + datetime(2000, 3, 20), datetime(2000, 3, 21), interval="-1h", eager=True + ) + + +@pytest.mark.parametrize( + ("closed", "expected_values"), + [ + ("right", [datetime(2020, 2, 29), datetime(2020, 3, 31)]), + ("left", [datetime(2020, 1, 31), datetime(2020, 2, 29)]), + ("none", [datetime(2020, 2, 29)]), + ("both", [datetime(2020, 1, 31), datetime(2020, 2, 29), datetime(2020, 3, 31)]), + ], +) +def test_datetime_range_end_of_month_5441( + closed: ClosedInterval, expected_values: list[datetime] +) -> None: + start = date(2020, 1, 31) + stop = date(2020, 3, 31) + result = pl.datetime_range(start, stop, interval="1mo", closed=closed, eager=True) + expected = pl.Series("literal", expected_values) + assert_series_equal(result, expected) + + +def test_datetime_ranges_broadcasting() -> None: + df = pl.DataFrame( + { + "datetimes": [ + datetime(2021, 1, 1), + datetime(2021, 1, 2), + datetime(2021, 1, 3), + ] + } + ) + result = df.select( + pl.datetime_ranges(start="datetimes", end=datetime(2021, 1, 3)).alias("end"), + pl.datetime_ranges(start=datetime(2021, 1, 1), end="datetimes").alias("start"), + ) + expected = pl.DataFrame( + { + "end": [ + [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)], + [datetime(2021, 1, 2), datetime(2021, 1, 3)], + [datetime(2021, 1, 3)], + ], + "start": [ + [datetime(2021, 1, 1)], + [datetime(2021, 1, 1), datetime(2021, 1, 2)], + [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)], + ], + } + ) + assert_frame_equal(result, expected) + + +def test_datetime_range_specifying_ambiguous_11713() -> None: + result = pl.datetime_range( + pl.datetime(2023, 10, 29, 2, 0).dt.replace_time_zone( + "Europe/Madrid", ambiguous="earliest" + ), + pl.datetime(2023, 10, 29, 3, 0).dt.replace_time_zone("Europe/Madrid"), + "1h", + eager=True, + ) + expected = pl.Series( + "datetime", + [ + datetime(2023, 10, 29, 2), + datetime(2023, 10, 29, 2), + datetime(2023, 10, 29, 3), + ], + ).dt.replace_time_zone( + "Europe/Madrid", ambiguous=pl.Series(["earliest", "latest", "raise"]) + ) + assert_series_equal(result, expected) + result = pl.datetime_range( + pl.datetime(2023, 10, 29, 2, 0).dt.replace_time_zone( + "Europe/Madrid", ambiguous="latest" + ), + pl.datetime(2023, 10, 29, 3, 0).dt.replace_time_zone("Europe/Madrid"), + "1h", + eager=True, + ) + expected = pl.Series( + "datetime", [datetime(2023, 10, 29, 2), datetime(2023, 10, 29, 3)] + ).dt.replace_time_zone("Europe/Madrid", ambiguous=pl.Series(["latest", "raise"])) + assert_series_equal(result, expected) + + +@given( + closed=st.sampled_from(["none", "left", "right", "both"]), + time_unit=st.sampled_from(["ms", "us", "ns"]), + n=st.integers(1, 10), + size=st.integers(8, 10), + unit=st.sampled_from(["s", "m", "h", "d", "mo"]), + start=st.datetimes(datetime(1965, 1, 1), datetime(2100, 1, 1)), +) +@settings(max_examples=50) +@pytest.mark.benchmark +def test_datetime_range_fast_slow_paths( + closed: ClosedInterval, + time_unit: TimeUnit, + n: int, + size: int, + unit: str, + start: datetime, +) -> None: + end = pl.select(pl.lit(start).dt.offset_by(f"{n * size}{unit}")).item() + result_slow = pl.datetime_range( + start, + end, + closed=closed, + time_unit=time_unit, + interval=f"{n}{unit}", + time_zone="Asia/Kathmandu", + eager=True, + ).dt.replace_time_zone(None) + result_fast = pl.datetime_range( + start, + end, + closed=closed, + time_unit=time_unit, + interval=f"{n}{unit}", + eager=True, + ) + assert_series_equal(result_slow, result_fast) + + +def test_dt_range_with_nanosecond_interval_19931() -> None: + with pytest.raises( + InvalidOperationError, match="interval 1ns is too small for time unit ms" + ): + pl.datetime_range( + pl.date(2022, 1, 1), + pl.date(2022, 1, 1), + time_zone="Asia/Kathmandu", + interval="1ns", + time_unit="ms", + eager=True, + ) + + +def test_datetime_range_with_nanoseconds_overflow_15735() -> None: + s = pl.datetime_range(date(2000, 1, 1), date(2300, 1, 1), "24h", eager=True) + assert s.dtype == pl.Datetime("us") + assert s.shape == (109574,) diff --git a/py-polars/tests/unit/functions/range/test_int_range.py b/py-polars/tests/unit/functions/range/test_int_range.py new file mode 100644 index 000000000000..f1d47c3000a6 --- /dev/null +++ b/py-polars/tests/unit/functions/range/test_int_range.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_int_range() -> None: + result = pl.int_range(0, 3) + expected = pl.Series("int_range", [0, 1, 2]) + assert_series_equal(pl.select(int_range=result).to_series(), expected) + + +def test_int_range_alias() -> None: + # note: `arange` is an alias for `int_range` + ldf = pl.LazyFrame({"a": [1, 1, 1]}) + result = ldf.filter(pl.col("a") >= pl.arange(0, 3)).collect() + expected = pl.DataFrame({"a": [1, 1]}) + assert_frame_equal(result, expected) + + +def test_int_range_decreasing() -> None: + assert pl.int_range(10, 1, -2, eager=True).to_list() == list(range(10, 1, -2)) + assert pl.int_range(10, -1, -1, eager=True).to_list() == list(range(10, -1, -1)) + + +def test_int_range_expr() -> None: + df = pl.DataFrame({"a": ["foobar", "barfoo"]}) + out = df.select(pl.int_range(0, pl.col("a").count() * 10)) + assert out.shape == (20, 1) + assert out.to_series(0)[-1] == 19 + + # eager arange + out2 = pl.arange(0, 10, 2, eager=True) + assert out2.to_list() == [0, 2, 4, 6, 8] + + +def test_int_range_short_syntax() -> None: + result = pl.int_range(3) + expected = pl.Series("int", [0, 1, 2]) + assert_series_equal(pl.select(int=result).to_series(), expected) + + +def test_int_ranges_short_syntax() -> None: + result = pl.int_ranges(3) + expected = pl.Series("int", [[0, 1, 2]]) + assert_series_equal(pl.select(int=result).to_series(), expected) + + +def test_int_range_start_default() -> None: + result = pl.int_range(end=3) + expected = pl.Series("int", [0, 1, 2]) + assert_series_equal(pl.select(int=result).to_series(), expected) + + +def test_int_ranges_start_default() -> None: + df = pl.DataFrame({"end": [3, 2]}) + result = df.select(int_range=pl.int_ranges(end="end")) + expected = pl.DataFrame({"int_range": [[0, 1, 2], [0, 1]]}) + assert_frame_equal(result, expected) + + +def test_int_range_eager() -> None: + result = pl.int_range(0, 3, eager=True) + expected = pl.Series("literal", [0, 1, 2]) + assert_series_equal(result, expected) + + +def test_int_range_lazy() -> None: + lf = pl.select(n=pl.int_range(8, 0, -2), eager=False) + expected = pl.LazyFrame({"n": [8, 6, 4, 2]}) + assert_frame_equal(lf, expected) + + +def test_int_range_schema() -> None: + result = pl.LazyFrame().select(int=pl.int_range(-3, 3)) + + expected_schema = {"int": pl.Int64} + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + +@pytest.mark.parametrize( + ("start", "end", "expected"), + [ + ("a", "b", pl.Series("a", [[1, 2], [2, 3]])), + (-1, "a", pl.Series("literal", [[-1, 0], [-1, 0, 1]])), + ("b", 4, pl.Series("b", [[3], []])), + ], +) +def test_int_ranges(start: Any, end: Any, expected: pl.Series) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + + result = df.select(pl.int_ranges(start, end)) + assert_series_equal(result.to_series(), expected) + + +def test_int_ranges_decreasing() -> None: + expected = pl.Series("literal", [[5, 4, 3, 2, 1]], dtype=pl.List(pl.Int64)) + assert_series_equal(pl.int_ranges(5, 0, -1, eager=True), expected) + assert_series_equal(pl.select(pl.int_ranges(5, 0, -1)).to_series(), expected) + + +@pytest.mark.parametrize( + ("start", "end", "step"), + [ + (0, -5, 1), + (5, 0, 1), + (0, 5, -1), + ], +) +def test_int_ranges_empty(start: int, end: int, step: int) -> None: + assert_series_equal( + pl.int_range(start, end, step, eager=True), + pl.Series("literal", [], dtype=pl.Int64), + ) + assert_series_equal( + pl.int_ranges(start, end, step, eager=True), + pl.Series("literal", [[]], dtype=pl.List(pl.Int64)), + ) + assert_series_equal( + pl.Series("int", [], dtype=pl.Int64), + pl.select(int=pl.int_range(start, end, step)).to_series(), + ) + assert_series_equal( + pl.Series("int_range", [[]], dtype=pl.List(pl.Int64)), + pl.select(int_range=pl.int_ranges(start, end, step)).to_series(), + ) + + +def test_int_ranges_eager() -> None: + start = pl.Series("s", [1, 2]) + result = pl.int_ranges(start, 4, eager=True) + + expected = pl.Series("s", [[1, 2, 3], [2, 3]]) + assert_series_equal(result, expected) + + +def test_int_ranges_schema_dtype_default() -> None: + lf = pl.LazyFrame({"start": [1, 2], "end": [3, 4]}) + + result = lf.select(pl.int_ranges("start", "end")) + + expected_schema = {"start": pl.List(pl.Int64)} + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + +def test_int_ranges_schema_dtype_arg() -> None: + lf = pl.LazyFrame({"start": [1, 2], "end": [3, 4]}) + + result = lf.select(pl.int_ranges("start", "end", dtype=pl.UInt16)) + + expected_schema = {"start": pl.List(pl.UInt16)} + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + +def test_int_range_input_shape_empty() -> None: + empty = pl.Series(dtype=pl.Time) + single = pl.Series([5]) + + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 0 values" + ): + pl.int_range(empty, single, eager=True) + with pytest.raises( + ComputeError, match="`end` must contain exactly one value, got 0 values" + ): + pl.int_range(single, empty, eager=True) + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 0 values" + ): + pl.int_range(empty, empty, eager=True) + + +def test_int_range_input_shape_multiple_values() -> None: + single = pl.Series([5]) + multiple = pl.Series([10, 15]) + + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 2 values" + ): + pl.int_range(multiple, single, eager=True) + with pytest.raises( + ComputeError, match="`end` must contain exactly one value, got 2 values" + ): + pl.int_range(single, multiple, eager=True) + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 2 values" + ): + pl.int_range(multiple, multiple, eager=True) + + +# https://github.com/pola-rs/polars/issues/10867 +def test_int_range_index_type_negative() -> None: + result = pl.select(pl.int_range(pl.lit(3).cast(pl.UInt32).alias("start"), -1, -1)) + expected = pl.DataFrame({"start": [3, 2, 1, 0]}) + assert_frame_equal(result, expected) + + +def test_int_range_null_input() -> None: + with pytest.raises(ComputeError, match="invalid null input for `int_range`"): + pl.select(pl.int_range(3, pl.lit(None), -1, dtype=pl.UInt32)) + + +def test_int_range_invalid_conversion() -> None: + with pytest.raises( + InvalidOperationError, match="conversion from `i32` to `u32` failed" + ): + pl.select(pl.int_range(3, -1, -1, dtype=pl.UInt32)) + + +def test_int_range_non_integer_dtype() -> None: + with pytest.raises( + ComputeError, match="non-integer `dtype` passed to `int_range`: Float64" + ): + pl.select(pl.int_range(3, -1, -1, dtype=pl.Float64)) # type: ignore[arg-type] + + +def test_int_ranges_broadcasting() -> None: + df = pl.DataFrame({"int": [1, 2, 3]}) + result = df.select( + # result column name means these columns will be broadcast + pl.int_ranges(1, pl.Series([2, 4, 6]), "int").alias("start"), + pl.int_ranges("int", 6, "int").alias("end"), + pl.int_ranges("int", pl.col("int") + 2, 1).alias("step"), + pl.int_ranges("int", 3, 1).alias("end_step"), + pl.int_ranges(1, "int", 1).alias("start_step"), + pl.int_ranges(1, 6, "int").alias("start_end"), + pl.int_ranges("int", pl.Series([4, 5, 10]), "int").alias("no_broadcast"), + ) + expected = pl.DataFrame( + { + "start": [[1], [1, 3], [1, 4]], + "end": [ + [1, 2, 3, 4, 5], + [2, 4], + [3], + ], + "step": [[1, 2], [2, 3], [3, 4]], + "end_step": [ + [1, 2], + [2], + [], + ], + "start_step": [ + [], + [1], + [1, 2], + ], + "start_end": [ + [1, 2, 3, 4, 5], + [1, 3, 5], + [1, 4], + ], + "no_broadcast": [[1, 2, 3], [2, 4], [3, 6, 9]], + } + ) + assert_frame_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/15307 +def test_int_range_non_int_dtype() -> None: + with pytest.raises( + ComputeError, match="non-integer `dtype` passed to `int_range`: String" + ): + pl.int_range(0, 3, dtype=pl.String, eager=True) # type: ignore[arg-type] + + +# https://github.com/pola-rs/polars/issues/15307 +def test_int_ranges_non_int_dtype() -> None: + with pytest.raises( + ComputeError, match="non-integer `dtype` passed to `int_ranges`: String" + ): + pl.int_ranges(0, 3, dtype=pl.String, eager=True) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/functions/range/test_linear_space.py b/py-polars/tests/unit/functions/range/test_linear_space.py new file mode 100644 index 000000000000..a30127047d21 --- /dev/null +++ b/py-polars/tests/unit/functions/range/test_linear_space.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +import re +from datetime import date, datetime +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError, ShapeError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars import Expr + from polars._typing import ClosedInterval, PolarsDataType + + +@pytest.mark.parametrize( + ("start", "end"), + [ + (0, 0), + (0, 1), + (-1, 0), + (-2.1, 3.4), + ], +) +@pytest.mark.parametrize("num_samples", [0, 1, 2, 5, 1_000]) +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +@pytest.mark.parametrize("eager", [True, False]) +def test_linear_space_values( + start: int | float, + end: int | float, + num_samples: int, + interval: ClosedInterval, + eager: bool, +) -> None: + if eager: + result = pl.linear_space( + start, end, num_samples, closed=interval, eager=True + ).rename("ls") + else: + result = pl.select( + ls=pl.linear_space(start, end, num_samples, closed=interval) + ).to_series() + + if interval == "both": + expected = pl.Series("ls", np.linspace(start, end, num_samples)) + elif interval == "left": + expected = pl.Series("ls", np.linspace(start, end, num_samples, endpoint=False)) + elif interval == "right": + expected = pl.Series("ls", np.linspace(start, end, num_samples + 1)[1:]) + elif interval == "none": + expected = pl.Series("ls", np.linspace(start, end, num_samples + 2)[1:-1]) + + assert_series_equal(result, expected) + + +def test_linear_space_expr() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + + result = lf.select(pl.linear_space(0, pl.col("a").len(), 3)) + expected = lf.select(literal=pl.Series([0.0, 2.5, 5.0], dtype=pl.Float64)) + assert_frame_equal(result, expected) + + result = lf.select(pl.linear_space(pl.col("a").len(), 0, 3)) + expected = lf.select(a=pl.Series([5.0, 2.5, 0.0], dtype=pl.Float64)) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("dtype_start", "dtype_end", "dtype_expected"), + [ + (pl.Float32, pl.Float32, pl.Float32), + (pl.Float32, pl.Float64, pl.Float64), + (pl.Float64, pl.Float32, pl.Float64), + (pl.Float64, pl.Float64, pl.Float64), + (pl.UInt8, pl.UInt32, pl.Float64), + (pl.Int16, pl.Int128, pl.Float64), + (pl.Int8, pl.Float64, pl.Float64), + ], +) +def test_linear_space_numeric_dtype( + dtype_start: PolarsDataType, + dtype_end: PolarsDataType, + dtype_expected: PolarsDataType, +) -> None: + lf = pl.LazyFrame() + result = lf.select( + ls=pl.linear_space(pl.lit(0, dtype=dtype_start), pl.lit(1, dtype=dtype_end), 6) + ) + expected = lf.select( + ls=pl.Series([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], dtype=dtype_expected) + ) + assert_frame_equal(result, expected) + + +def test_linear_space_date() -> None: + d1 = date(2025, 1, 1) + d2 = date(2025, 2, 1) + out_values = [ + datetime(2025, 1, 1), + datetime(2025, 1, 11, 8), + datetime(2025, 1, 21, 16), + datetime(2025, 2, 1), + ] + lf = pl.LazyFrame() + + result = lf.select(ls=pl.linear_space(d1, d2, 4, closed="both")) + expected = lf.select(ls=pl.Series(out_values, dtype=pl.Datetime("ms"))) + assert_frame_equal(result, expected) + + result = lf.select(ls=pl.linear_space(d1, d2, 3, closed="left")) + expected = lf.select(ls=pl.Series(out_values[:-1], dtype=pl.Datetime("ms"))) + assert_frame_equal(result, expected) + + result = lf.select(ls=pl.linear_space(d1, d2, 3, closed="right")) + expected = lf.select(ls=pl.Series(out_values[1:], dtype=pl.Datetime("ms"))) + assert_frame_equal(result, expected) + + result = lf.select(ls=pl.linear_space(d1, d2, 2, closed="none")) + expected = lf.select(ls=pl.Series(out_values[1:-1], dtype=pl.Datetime("ms"))) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Datetime("ms", None), + pl.Datetime("ms", time_zone="Asia/Tokyo"), + pl.Datetime("us", None), + pl.Datetime("us", time_zone="Asia/Tokyo"), + pl.Datetime("ns", time_zone="Asia/Tokyo"), + pl.Time, + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + ], +) +def test_linear_space_temporal(dtype: PolarsDataType) -> None: + # All temporal types except for Date, which is tested above. + start = 0 + end = 1_000_000_000 + + lf = pl.LazyFrame() + + result_int = lf.select( + ls=pl.linear_space(start, end, 11).cast(pl.Int64).cast(dtype) + ) + result_dt = lf.select( + ls=pl.linear_space(pl.lit(start, dtype=dtype), pl.lit(end, dtype=dtype), 11) + ) + + assert_frame_equal(result_int, result_dt) + + +@pytest.mark.parametrize( + ("dtype1", "dtype2", "str1", "str2"), + [ + (pl.Date, pl.Datetime("ms"), "Date", "Datetime(Milliseconds, None)"), + ( + pl.Datetime("ms"), + pl.Datetime("ns"), + "Datetime(Milliseconds, None)", + "Datetime(Nanoseconds, None)", + ), + (pl.Datetime("us"), pl.Time, "Datetime(Microseconds, None)", "Time"), + ( + pl.Duration("us"), + pl.Duration("ms"), + "Duration(Microseconds)", + "Duration(Milliseconds)", + ), + (pl.Int32, pl.String, "Int32", "String"), + ], +) +def test_linear_space_incompatible_dtypes( + dtype1: PolarsDataType, + dtype2: PolarsDataType, + str1: str, + str2: str, +) -> None: + value1 = pl.lit(0, dtype1) + value2 = pl.lit(1, dtype2) + with pytest.raises( + ComputeError, + match=re.escape( + f"'start' and 'end' have incompatible dtypes, got {str1} and {str2}" + ), + ): + pl.linear_space(value1, value2, 11, eager=True) + + +def test_linear_space_expr_wrong_length() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5]}) + msg = "unable to add a column of length 6 to a DataFrame of height 5" + streaming_msg = "zip node received non-equal length inputs" + with pytest.raises(ShapeError, match=rf"({msg})|({streaming_msg})"): + df.with_columns(pl.linear_space(0, 1, 6)) + + +def test_linear_space_num_samples_expr() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = lf.with_columns(ls=pl.linear_space(0, 1, pl.len(), closed="left")) + expected = lf.with_columns(ls=pl.Series([0, 0.2, 0.4, 0.6, 0.8], dtype=pl.Float64)) + assert_frame_equal(result, expected) + + +def test_linear_space_invalid_num_samples_expr() -> None: + lf = pl.LazyFrame({"x": [1, 2, 3]}) + with pytest.raises( + ComputeError, match="`num_samples` must contain exactly one value, got 3 values" + ): + lf.select(pl.linear_space(0, 1, pl.col("x"))).collect() + + +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def test_linear_spaces_values(interval: ClosedInterval) -> None: + starts = [ + None, 0.0, 0.0, 0.0, 0.0, + 0.0, None, 0.0, 0.0, 0.0, + -1.0, -1.0, None, -1.0, -1.0, + -2.1, -2.1, -2.1, None, -2.1, + ] # fmt: skip + + ends = [ + 0.0, None, 0.0, 0.0, 0.0, + 1.0, 1.0, None, 1.0, 1.0, + 0.0, 0.0, 0.0, None, 0.0, + 3.4, 3.4, 3.4, 3.4, None, + ] # fmt: skip + + num_samples = [ + 0, 1, None, 5, 1_1000, + 0, 1, 2, 5, None, + 0, 1, 2, 5, 1_1000, + 0, 1, 2, 5, 1_1000, + ] # fmt: skip + + df = pl.DataFrame( + { + "start": starts, + "end": ends, + "num_samples": num_samples, + } + ) + + out = df.select(pl.linear_spaces("start", "end", "num_samples", closed=interval))[ + "start" + ] + + # We check each element against the output from pl.linear_space(), which is + # validated above. + for row, start, end, ns in zip(out, starts, ends, num_samples): + if start is None or end is None or ns is None: + assert row is None + else: + expected = pl.linear_space( + start, end, ns, eager=True, closed=interval + ).rename("") + assert_series_equal(row, expected) + + +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def test_linear_spaces_one_numeric(interval: ClosedInterval) -> None: + # Two expressions, one numeric input + starts = [1, 2] + ends = [5, 6] + num_samples = [3, 4] + lf = pl.LazyFrame( + { + "start": starts, + "end": ends, + "num_samples": num_samples, + } + ) + result = lf.select( + pl.linear_spaces(starts[0], "end", "num_samples", closed=interval).alias( + "start" + ), + pl.linear_spaces("start", ends[0], "num_samples", closed=interval).alias("end"), + pl.linear_spaces("start", "end", num_samples[0], closed=interval).alias( + "num_samples" + ), + ) + expected_start0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_start1 = pl.linear_space( + starts[0], ends[1], num_samples[1], closed=interval, eager=True + ) + expected_end0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_end1 = pl.linear_space( + starts[1], ends[0], num_samples[1], closed=interval, eager=True + ) + expected_ns0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_ns1 = pl.linear_space( + starts[1], ends[1], num_samples[0], closed=interval, eager=True + ) + expected = pl.LazyFrame( + { + "start": [expected_start0, expected_start1], + "end": [expected_end0, expected_end1], + "num_samples": [expected_ns0, expected_ns1], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def test_linear_spaces_two_numeric(interval: ClosedInterval) -> None: + # One expression, two numeric inputs + starts = [1, 2] + ends = [5, 6] + num_samples = [3, 4] + lf = pl.LazyFrame( + { + "start": starts, + "end": ends, + "num_samples": num_samples, + } + ) + result = lf.select( + pl.linear_spaces("start", ends[0], num_samples[0], closed=interval).alias( + "start" + ), + pl.linear_spaces(starts[0], "end", num_samples[0], closed=interval).alias( + "end" + ), + pl.linear_spaces(starts[0], ends[0], "num_samples", closed=interval).alias( + "num_samples" + ), + ) + expected_start0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_start1 = pl.linear_space( + starts[1], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_end0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_end1 = pl.linear_space( + starts[0], ends[1], num_samples[0], closed=interval, eager=True + ) + expected_ns0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_ns1 = pl.linear_space( + starts[0], ends[0], num_samples[1], closed=interval, eager=True + ) + expected = pl.LazyFrame( + { + "start": [expected_start0, expected_start1], + "end": [expected_end0, expected_end1], + "num_samples": [expected_ns0, expected_ns1], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "num_samples", + [ + 5, + pl.lit(5), + pl.lit(5, dtype=pl.UInt8), + pl.lit(5, dtype=pl.UInt16), + pl.lit(5, dtype=pl.UInt32), + pl.lit(5, dtype=pl.UInt64), + pl.lit(5, dtype=pl.Int8), + pl.lit(5, dtype=pl.Int16), + pl.lit(5, dtype=pl.Int32), + pl.lit(5, dtype=pl.Int64), + ], +) +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +@pytest.mark.parametrize( + "dtype", + [ + pl.Float32, + pl.Float64, + pl.Datetime, + ], +) +def test_linear_spaces_as_array( + interval: ClosedInterval, + num_samples: int | Expr, + dtype: PolarsDataType, +) -> None: + starts = [1, 2] + ends = [5, 6] + lf = pl.LazyFrame( + { + "start": pl.Series(starts, dtype=dtype), + "end": pl.Series(ends, dtype=dtype), + } + ) + result = lf.select( + a=pl.linear_spaces("start", "end", num_samples, closed=interval, as_array=True) + ) + expected_0 = pl.linear_space( + pl.lit(starts[0], dtype=dtype), + pl.lit(ends[0], dtype=dtype), + num_samples, + closed=interval, + eager=True, + ) + expected_1 = pl.linear_space( + pl.lit(starts[1], dtype=dtype), + pl.lit(ends[1], dtype=dtype), + num_samples, + closed=interval, + eager=True, + ) + expected = pl.LazyFrame( + {"a": pl.Series([expected_0, expected_1], dtype=pl.Array(dtype, 5))} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("bad_num_samples", [pl.lit("a"), 1.0, "num_samples"]) +def test_linear_space_invalid_as_array(bad_num_samples: Any) -> None: + lf = pl.LazyFrame( + { + "start": [1, 2], + "end": [5, 6], + "num_samples": [2, 4], + } + ) + with pytest.raises( + InvalidOperationError, + match="'as_array' is only valid when 'num_samples' is a constant integer", + ): + lf.select(pl.linear_spaces("starts", "ends", bad_num_samples, as_array=True)) + + +@pytest.mark.parametrize("interval", ["both", "left", "right", "none"]) +def test_linear_spaces_numeric_input(interval: ClosedInterval) -> None: + starts = [1, 2] + ends = [5, 6] + num_samples = [3, 4] + lf = pl.LazyFrame( + { + "start": starts, + "end": ends, + "num_samples": num_samples, + } + ) + result = lf.select( + pl.linear_spaces("start", "end", "num_samples", closed=interval).alias("all"), + pl.linear_spaces(0, "end", "num_samples", closed=interval).alias("start"), + pl.linear_spaces("start", 10, "num_samples", closed=interval).alias("end"), + pl.linear_spaces("start", "end", 8, closed=interval).alias("num_samples"), + ) + expected_all0 = pl.linear_space( + starts[0], ends[0], num_samples[0], closed=interval, eager=True + ) + expected_all1 = pl.linear_space( + starts[1], ends[1], num_samples[1], closed=interval, eager=True + ) + expected_start0 = pl.linear_space( + 0, ends[0], num_samples[0], closed=interval, eager=True + ) + expected_start1 = pl.linear_space( + 0, ends[1], num_samples[1], closed=interval, eager=True + ) + expected_end0 = pl.linear_space( + starts[0], 10, num_samples[0], closed=interval, eager=True + ) + expected_end1 = pl.linear_space( + starts[1], 10, num_samples[1], closed=interval, eager=True + ) + expected_ns0 = pl.linear_space(starts[0], ends[0], 8, closed=interval, eager=True) + expected_ns1 = pl.linear_space(starts[1], ends[1], 8, closed=interval, eager=True) + expected = pl.LazyFrame( + { + "all": [expected_all0, expected_all1], + "start": [expected_start0, expected_start1], + "end": [expected_end0, expected_end1], + "num_samples": [expected_ns0, expected_ns1], + } + ) + assert_frame_equal(result, expected) + + +def test_linear_spaces_date() -> None: + d1 = date(2025, 1, 1) + d2 = date(2025, 2, 1) + + lf = pl.LazyFrame( + { + "start": [None, d1, d1, d1, None, d1, d1, d1], + "end": [d2, None, d2, d2, d2, None, d2, d2], + "num_samples": [3, 3, None, 3, 4, 4, None, 4], + } + ) + + result = lf.select(pl.linear_spaces("start", "end", "num_samples")) + expected = pl.LazyFrame( + { + "start": pl.Series( + [ + None, + None, + None, + [ + datetime(2025, 1, 1), + datetime(2025, 1, 16, 12), + datetime(2025, 2, 1), + ], + None, + None, + None, + [ + datetime(2025, 1, 1), + datetime(2025, 1, 11, 8), + datetime(2025, 1, 21, 16), + datetime(2025, 2, 1), + ], + ], + dtype=pl.List(pl.Datetime(time_unit="ms")), + ) + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Datetime("ms", None), + pl.Datetime("ms", time_zone="Asia/Tokyo"), + pl.Datetime("us", None), + pl.Datetime("us", time_zone="Asia/Tokyo"), + pl.Datetime("ns", time_zone="Asia/Tokyo"), + pl.Time, + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + ], +) +def test_linear_spaces_temporal(dtype: PolarsDataType) -> None: + # All temporal types except for Date, which is tested above. + start = 0 + end = 1_000_000_000 + + lf = pl.LazyFrame( + { + "start": [start, start], + "end": [end, end], + "num_samples": [10, 15], + } + ) + lf_temporal = lf.select(pl.col("start", "end").cast(dtype), "num_samples") + result_int = lf.select(pl.linear_spaces("start", "end", "num_samples")).select( + pl.col("start").cast(pl.List(dtype)) + ) + result_dt = lf_temporal.select(pl.linear_spaces("start", "end", "num_samples")) + + assert_frame_equal(result_int, result_dt) + + +@pytest.mark.parametrize( + ("dtype1", "dtype2", "str1", "str2"), + [ + (pl.Date, pl.Datetime("ms"), "Date", "Datetime(Milliseconds, None)"), + ( + pl.Datetime("ms"), + pl.Datetime("ns"), + "Datetime(Milliseconds, None)", + "Datetime(Nanoseconds, None)", + ), + (pl.Datetime("us"), pl.Time, "Datetime(Microseconds, None)", "Time"), + ( + pl.Duration("us"), + pl.Duration("ms"), + "Duration(Microseconds)", + "Duration(Milliseconds)", + ), + (pl.Int32, pl.String, "Int32", "String"), + ], +) +def test_linear_spaces_incompatible_dtypes( + dtype1: PolarsDataType, + dtype2: PolarsDataType, + str1: str, + str2: str, +) -> None: + df = pl.LazyFrame( + { + "start": pl.Series([0]).cast(dtype1), + "end": pl.Series([1]).cast(dtype2), + "num_samples": 3, + } + ) + with pytest.raises( + ComputeError, + match=re.escape( + f"'start' and 'end' have incompatible dtypes, got {str1} and {str2}" + ), + ): + df.select(pl.linear_spaces("start", "end", "num_samples")).collect() + + +def test_linear_spaces_f32() -> None: + df = pl.LazyFrame( + { + "start": pl.Series([0.0, 1.0], dtype=pl.Float32), + "end": pl.Series([10.0, 11.0], dtype=pl.Float32), + } + ) + result = df.select(pl.linear_spaces("start", "end", 6)) + expected = pl.LazyFrame( + { + "start": pl.Series( + [ + [0.0, 2.0, 4.0, 6.0, 8.0, 10.0], + [1.0, 3.0, 5.0, 7.0, 9.0, 11.0], + ], + dtype=pl.List(pl.Float32), + ) + } + ) + assert_frame_equal(result, expected) + + +def test_linear_spaces_eager() -> None: + start = pl.Series("s", [1, 2]) + result = pl.linear_spaces(start, 6, 3, eager=True) + + expected = pl.Series("s", [[1.0, 3.5, 6.0], [2.0, 4.0, 6.0]]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/functions/range/test_time_range.py b/py-polars/tests/unit/functions/range/test_time_range.py new file mode 100644 index 000000000000..63bf5e0ec73f --- /dev/null +++ b/py-polars/tests/unit/functions/range/test_time_range.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from datetime import time, timedelta +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import ClosedInterval + + +def test_time_range_schema() -> None: + df = pl.DataFrame({"start": [time(1)], "end": [time(1, 30)]}).lazy() + result = df.with_columns(time_range=pl.time_ranges(pl.col("start"), pl.col("end"))) + expected_schema = {"start": pl.Time, "end": pl.Time, "time_range": pl.List(pl.Time)} + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + +def test_time_ranges_eager() -> None: + start = pl.Series("start", [time(9, 0), time(10, 0)]) + end = pl.Series("end", [time(12, 0), time(11, 0)]) + + result = pl.time_ranges(start, end, eager=True) + + expected = pl.Series( + "start", + [ + [time(9, 0), time(10, 0), time(11, 0), time(12, 0)], + [time(10, 0), time(11, 0)], + ], + ) + assert_series_equal(result, expected) + + +def test_time_range_eager_explode() -> None: + start = pl.Series("start", [time(9, 0)]) + end = pl.Series("end", [time(11, 0)]) + + result = pl.time_range(start, end, eager=True) + + expected = pl.Series("start", [time(9, 0), time(10, 0), time(11, 0)]) + assert_series_equal(result, expected) + + +def test_time_range_input_shape_empty() -> None: + empty = pl.Series(dtype=pl.Time) + single = pl.Series([time(12, 0)]) + + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 0 values" + ): + pl.time_range(empty, single, eager=True) + with pytest.raises( + ComputeError, match="`end` must contain exactly one value, got 0 values" + ): + pl.time_range(single, empty, eager=True) + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 0 values" + ): + pl.time_range(empty, empty, eager=True) + + +def test_time_range_input_shape_multiple_values() -> None: + single = pl.Series([time(12, 0)]) + multiple = pl.Series([time(11, 0), time(12, 0)]) + + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 2 values" + ): + pl.time_range(multiple, single, eager=True) + with pytest.raises( + ComputeError, match="`end` must contain exactly one value, got 2 values" + ): + pl.time_range(single, multiple, eager=True) + with pytest.raises( + ComputeError, match="`start` must contain exactly one value, got 2 values" + ): + pl.time_range(multiple, multiple, eager=True) + + +def test_time_range_start_equals_end() -> None: + t = time(12, 0) + + result = pl.time_range(t, t, closed="both", eager=True) + + expected = pl.Series("literal", [t]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("closed", ["left", "right", "none"]) +def test_time_range_start_equals_end_open(closed: ClosedInterval) -> None: + t = time(12, 0) + + result = pl.time_range(t, t, closed=closed, eager=True) + + expected = pl.Series("literal", dtype=pl.Time) + assert_series_equal(result, expected) + + +def test_time_range_start_later_than_end() -> None: + result = pl.time_range(time(12), time(11), eager=True) + expected = pl.Series("literal", dtype=pl.Time) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("interval", [timedelta(0), timedelta(minutes=-10)]) +def test_time_range_invalid_step(interval: timedelta) -> None: + with pytest.raises(ComputeError, match="`interval` must be positive"): + pl.time_range(time(11), time(12), interval=interval, eager=True) + + +def test_time_range_lit_lazy() -> None: + tm = pl.select( + pl.time_range( + start=time(1, 2, 3), + end=time(23, 59, 59), + interval="5h45m10s333ms", + closed="right", + ).alias("tm") + ) + + assert tm["tm"].to_list() == [ + time(6, 47, 13, 333000), + time(12, 32, 23, 666000), + time(18, 17, 33, 999000), + ] + + # validate unset start/end + tm = pl.select(pl.time_range(interval="5h45m10s333ms").alias("tm")) + assert tm["tm"].to_list() == [ + time(0, 0), + time(5, 45, 10, 333000), + time(11, 30, 20, 666000), + time(17, 15, 30, 999000), + time(23, 0, 41, 332000), + ] + + tm = pl.select( + pl.time_range(start=pl.lit(time(23, 59, 59, 999980)), interval="10000ns").alias( + "tm" + ) + ) + assert tm["tm"].to_list() == [ + time(23, 59, 59, 999980), + time(23, 59, 59, 999990), + ] + + +def test_time_range_lit_eager() -> None: + eager = True + tm = pl.select( + pl.time_range( + start=time(1, 2, 3), + end=time(23, 59, 59), + interval="5h45m10s333ms", + closed="right", + eager=eager, + ).alias("tm") + ) + if not eager: + tm = tm.select(pl.col("tm").explode()) + assert tm["tm"].to_list() == [ + time(6, 47, 13, 333000), + time(12, 32, 23, 666000), + time(18, 17, 33, 999000), + ] + + # validate unset start/end + tm = pl.select( + pl.time_range( + interval="5h45m10s333ms", + eager=eager, + ).alias("tm") + ) + if not eager: + tm = tm.select(pl.col("tm").explode()) + assert tm["tm"].to_list() == [ + time(0, 0), + time(5, 45, 10, 333000), + time(11, 30, 20, 666000), + time(17, 15, 30, 999000), + time(23, 0, 41, 332000), + ] + + tm = pl.select( + pl.time_range( + start=pl.lit(time(23, 59, 59, 999980)), + interval="10000ns", + eager=eager, + ).alias("tm") + ) + tm = tm.select(pl.col("tm").explode()) + assert tm["tm"].to_list() == [ + time(23, 59, 59, 999980), + time(23, 59, 59, 999990), + ] + + +def test_time_range_expr() -> None: + df = pl.DataFrame( + { + "start": pl.time_range(interval="6h", eager=True), + "stop": pl.time_range(start=time(2, 59), interval="5h59m", eager=True), + } + ).with_columns(intervals=pl.time_ranges("start", pl.col("stop"), interval="1h29m")) + # shape: (4, 3) + # ┌──────────┬──────────┬────────────────────────────────┐ + # │ start ┆ stop ┆ intervals │ + # │ --- ┆ --- ┆ --- │ + # │ time ┆ time ┆ list[time] │ + # ╞══════════╪══════════╪════════════════════════════════╡ + # │ 00:00:00 ┆ 02:59:00 ┆ [00:00:00, 01:29:00, 02:58:00] │ + # │ 06:00:00 ┆ 08:58:00 ┆ [06:00:00, 07:29:00, 08:58:00] │ + # │ 12:00:00 ┆ 14:57:00 ┆ [12:00:00, 13:29:00] │ + # │ 18:00:00 ┆ 20:56:00 ┆ [18:00:00, 19:29:00] │ + # └──────────┴──────────┴────────────────────────────────┘ + assert df.rows() == [ + (time(0, 0), time(2, 59), [time(0, 0), time(1, 29), time(2, 58)]), + (time(6, 0), time(8, 58), [time(6, 0), time(7, 29), time(8, 58)]), + (time(12, 0), time(14, 57), [time(12, 0), time(13, 29)]), + (time(18, 0), time(20, 56), [time(18, 0), time(19, 29)]), + ] + + +def test_time_range_name() -> None: + expected_name = "literal" + result_eager = pl.time_range(time(10), time(12), eager=True) + assert result_eager.name == expected_name + + expected_name = "s1" + result_lazy = pl.select( + pl.time_range( + pl.Series("s1", [time(10)]), pl.Series("s2", [time(12)]), eager=False + ) + ).to_series() + assert result_lazy.name == expected_name + + +def test_time_ranges_broadcasting() -> None: + df = pl.DataFrame({"time": [time(10, 0), time(11, 0), time(12, 0)]}) + result = df.select( + pl.time_ranges(start="time", end=time(12, 0)).alias("end"), + pl.time_ranges(start=time(10, 0), end="time").alias("start"), + ) + expected = pl.DataFrame( + { + "end": [ + [time(10, 0), time(11, 0), time(12, 0)], + [time(11, 0), time(12, 0)], + [time(12, 0)], + ], + "start": [ + [time(10, 0)], + [time(10, 0), time(11, 0)], + [time(10, 0), time(11, 0), time(12, 0)], + ], + } + ) + assert_frame_equal(result, expected) + + +def test_time_ranges_mismatched_chunks() -> None: + s1 = pl.Series("s1", [time(10), time(11)]) + s1.append(pl.Series([time(12)])) + + s2 = pl.Series("s2", [time(12)]) + s2.append(pl.Series([time(12), time(12)])) + + result = pl.time_ranges(s1, s2, eager=True) + expected = pl.Series( + "s1", + [ + [time(10, 0), time(11, 0), time(12, 0)], + [time(11, 0), time(12, 0)], + [time(12, 0)], + ], + ) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_business_day_count.py b/py-polars/tests/unit/functions/test_business_day_count.py new file mode 100644 index 000000000000..7a3c9236f62d --- /dev/null +++ b/py-polars/tests/unit/functions/test_business_day_count.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import datetime as dt +from datetime import date + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import assume, given, reject + +import polars as pl +from polars._utils.various import parse_version +from polars.exceptions import ComputeError +from polars.testing import assert_series_equal + + +def test_business_day_count() -> None: + # (Expression, expression) + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + result = df.select( + business_day_count=pl.business_day_count("start", "end"), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 6], pl.Int32) + assert_series_equal(result, expected) + + # (Expression, scalar) + result = df.select( + business_day_count=pl.business_day_count("start", date(2020, 1, 10)), + )["business_day_count"] + expected = pl.Series("business_day_count", [7, 6], pl.Int32) + assert_series_equal(result, expected) + result = df.select( + business_day_count=pl.business_day_count("start", pl.lit(None, dtype=pl.Date)), + )["business_day_count"] + expected = pl.Series("business_day_count", [None, None], pl.Int32) + assert_series_equal(result, expected) + + # (Scalar, expression) + result = df.select( + business_day_count=pl.business_day_count(date(2020, 1, 1), "end"), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 7], pl.Int32) + assert_series_equal(result, expected) + result = df.select( + business_day_count=pl.business_day_count(pl.lit(None, dtype=pl.Date), "end"), + )["business_day_count"] + expected = pl.Series("business_day_count", [None, None], pl.Int32) + assert_series_equal(result, expected) + + # (Scalar, scalar) + result = df.select( + business_day_count=pl.business_day_count(date(2020, 1, 1), date(2020, 1, 10)), + )["business_day_count"] + expected = pl.Series("business_day_count", [7], pl.Int32) + assert_series_equal(result, expected) + + +def test_business_day_count_w_week_mask() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + result = df.select( + business_day_count=pl.business_day_count( + "start", "end", week_mask=(True, True, True, True, True, True, False) + ), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 7], pl.Int32) + assert_series_equal(result, expected) + + result = df.select( + business_day_count=pl.business_day_count( + "start", "end", week_mask=(True, True, True, False, False, False, True) + ), + )["business_day_count"] + expected = pl.Series("business_day_count", [1, 4], pl.Int32) + assert_series_equal(result, expected) + + +def test_business_day_count_w_week_mask_invalid() -> None: + with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"): + pl.business_day_count("start", "end", week_mask=(False, 0)) # type: ignore[arg-type] + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + with pytest.raises( + ComputeError, match="`week_mask` must have at least one business day" + ): + df.select(pl.business_day_count("start", "end", week_mask=[False] * 7)) + + +def test_business_day_count_schema() -> None: + lf = pl.LazyFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10)], + } + ) + result = lf.select( + business_day_count=pl.business_day_count("start", "end"), + ) + assert result.collect_schema()["business_day_count"] == pl.Int32 + assert result.collect().schema["business_day_count"] == pl.Int32 + assert 'col("start").business_day_count([col("end")])' in result.explain() + + +def test_business_day_count_w_holidays() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)], + "end": [date(2020, 1, 2), date(2020, 1, 10), date(2020, 1, 9)], + } + ) + result = df.select( + business_day_count=pl.business_day_count( + "start", "end", holidays=[date(2020, 1, 1), date(2020, 1, 9)] + ), + )["business_day_count"] + expected = pl.Series("business_day_count", [0, 5, 5], pl.Int32) + assert_series_equal(result, expected) + + +@given( + start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + week_mask=st.lists( + st.sampled_from([True, False]), + min_size=7, + max_size=7, + ), + holidays=st.lists( + st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + min_size=0, + max_size=100, + ), +) +def test_against_np_busday_count( + start: dt.date, end: dt.date, week_mask: tuple[bool, ...], holidays: list[dt.date] +) -> None: + assume(any(week_mask)) + result = ( + pl.DataFrame({"start": [start], "end": [end]}) + .select( + n=pl.business_day_count( + "start", "end", week_mask=week_mask, holidays=holidays + ) + )["n"] + .item() + ) + expected = np.busday_count(start, end, weekmask=week_mask, holidays=holidays) + if start > end and parse_version(np.__version__) < (1, 25): + # Bug in old versions of numpy + reject() + assert result == expected + + +def test_unequal_length_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.select( + pl.business_day_count( + pl.Series([date(2020, 1, 1)] * 2), + pl.Series([date(2020, 1, 1)] * 3), + ) + ) diff --git a/py-polars/tests/unit/functions/test_col.py b/py-polars/tests/unit/functions/test_col.py new file mode 100644 index 000000000000..6361bb08daae --- /dev/null +++ b/py-polars/tests/unit/functions/test_col.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import polars as pl +from polars import col +from polars.testing import assert_frame_equal + + +def test_col_select() -> None: + df = pl.DataFrame( + { + "ham": [1, 2, 3], + "hamburger": [11, 22, 33], + "foo": [3, 2, 1], + "bar": ["a", "b", "c"], + } + ) + + # Single column + assert df.select(pl.col("foo")).columns == ["foo"] + # Regex + assert df.select(pl.col("*")).columns == ["ham", "hamburger", "foo", "bar"] + assert df.select(pl.col("^ham.*$")).columns == ["ham", "hamburger"] + assert df.select(pl.col("*").exclude("ham")).columns == ["hamburger", "foo", "bar"] + # Multiple inputs + assert df.select(pl.col(["hamburger", "foo"])).columns == ["hamburger", "foo"] + assert df.select(pl.col("hamburger", "foo")).columns == ["hamburger", "foo"] + assert df.select(pl.col(pl.Series(["ham", "foo"]))).columns == ["ham", "foo"] + # Dtypes + assert df.select(pl.col(pl.String)).columns == ["bar"] + assert df.select(pl.col(pl.Int64, pl.Float64)).columns == [ + "ham", + "hamburger", + "foo", + ] + + +def test_col_series_selection() -> None: + ldf = pl.LazyFrame({"a": [1], "b": [1], "c": [1]}) + srs = pl.Series(["b", "c"]) + + assert ldf.select(pl.col(srs)).collect_schema().names() == ["b", "c"] + + +def test_col_dot_style() -> None: + df = pl.DataFrame({"lower": 1, "UPPER": 2, "_underscored": 3}) + + result = df.select( + col.lower, + col.UPPER, + col._underscored, + ) + + expected = df.select("lower", "UPPER", "_underscored") + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_concat.py b/py-polars/tests/unit/functions/test_concat.py new file mode 100644 index 000000000000..75dbf608ed1a --- /dev/null +++ b/py-polars/tests/unit/functions/test_concat.py @@ -0,0 +1,96 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.mark.slow +def test_concat_expressions_stack_overflow() -> None: + n = 10000 + e = pl.concat([pl.lit(x) for x in range(n)]) + + df = pl.select(e) + assert df.shape == (n, 1) + + +@pytest.mark.slow +def test_concat_lf_stack_overflow() -> None: + n = 1000 + bar = pl.DataFrame({"a": 0}).lazy() + + for i in range(n): + bar = pl.concat([bar, pl.DataFrame({"a": i}).lazy()]) + assert bar.collect().shape == (1001, 1) + + +def test_concat_vertically_relaxed() -> None: + a = pl.DataFrame( + data={"a": [1, 2, 3], "b": [True, False, None]}, + schema={"a": pl.Int8, "b": pl.Boolean}, + ) + b = pl.DataFrame( + data={"a": [43, 2, 3], "b": [32, 1, None]}, + schema={"a": pl.Int16, "b": pl.Int64}, + ) + out = pl.concat([a, b], how="vertical_relaxed") + assert out.schema == {"a": pl.Int16, "b": pl.Int64} + assert out.to_dict(as_series=False) == { + "a": [1, 2, 3, 43, 2, 3], + "b": [1, 0, None, 32, 1, None], + } + out = pl.concat([b, a], how="vertical_relaxed") + assert out.schema == {"a": pl.Int16, "b": pl.Int64} + assert out.to_dict(as_series=False) == { + "a": [43, 2, 3, 1, 2, 3], + "b": [32, 1, None, 1, 0, None], + } + + c = pl.DataFrame({"a": [1, 2], "b": [2, 1]}) + d = pl.DataFrame({"a": [1.0, 0.2], "b": [None, 0.1]}) + + out = pl.concat([c, d], how="vertical_relaxed") + assert out.schema == {"a": pl.Float64, "b": pl.Float64} + assert out.to_dict(as_series=False) == { + "a": [1.0, 2.0, 1.0, 0.2], + "b": [2.0, 1.0, None, 0.1], + } + out = pl.concat([d, c], how="vertical_relaxed") + assert out.schema == {"a": pl.Float64, "b": pl.Float64} + assert out.to_dict(as_series=False) == { + "a": [1.0, 0.2, 1.0, 2.0], + "b": [None, 0.1, 2.0, 1.0], + } + + +def test_concat_group_by() -> None: + df = pl.DataFrame( + { + "g": [0, 0, 0, 0, 1, 1, 1, 1], + "a": [0, 1, 2, 3, 4, 5, 6, 7], + "b": [8, 9, 10, 11, 12, 13, 14, 15], + } + ) + out = df.group_by("g").agg(pl.concat([pl.col.a, pl.col.b])) + + assert_frame_equal( + out, + pl.DataFrame( + { + "g": [0, 1], + "a": [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]], + } + ), + check_row_order=False, + ) + + +def test_concat_19877() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + out = df.select(pl.concat([pl.col("a"), pl.col("b")])) + assert_frame_equal(out, pl.DataFrame({"a": [1, 2, 3, 4]})) + + +def test_concat_zip_series_21980() -> None: + df = pl.DataFrame({"x": 1, "y": 2}) + out = df.select(pl.concat([pl.col.x, pl.col.y]), pl.Series([3, 4])) + assert_frame_equal(out, pl.DataFrame({"x": [1, 2], "": [3, 4]})) diff --git a/py-polars/tests/unit/functions/test_cum_count.py b/py-polars/tests/unit/functions/test_cum_count.py new file mode 100644 index 000000000000..4526da3a18f7 --- /dev/null +++ b/py-polars/tests/unit/functions/test_cum_count.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.mark.parametrize(("reverse", "output"), [(False, [1, 2, 2]), (True, [2, 1, 0])]) +def test_cum_count_single_arg(reverse: bool, output: list[int]) -> None: + df = pl.DataFrame({"a": [5, 5, None]}) + result = df.select(pl.cum_count("a", reverse=reverse)) + expected = pl.Series("a", output, dtype=pl.UInt32).to_frame() + assert_frame_equal(result, expected) + assert result.to_series().flags[("SORTED_ASC", "SORTED_DESC")[reverse]] + + +def test_cum_count_multi_arg() -> None: + df = pl.DataFrame( + { + "a": [5, 5, 5], + "b": [None, 5, 5], + "c": [5, None, 5], + "d": [5, 5, None], + "e": [None, None, None], + } + ) + result = df.select(pl.cum_count("a", "b", "c", "d", "e")) + expected = pl.DataFrame( + [ + pl.Series("a", [1, 2, 3], dtype=pl.UInt32), + pl.Series("b", [0, 1, 2], dtype=pl.UInt32), + pl.Series("c", [1, 1, 2], dtype=pl.UInt32), + pl.Series("d", [1, 2, 2], dtype=pl.UInt32), + pl.Series("e", [0, 0, 0], dtype=pl.UInt32), + ] + ) + assert_frame_equal(result, expected) + + +def test_cum_count_multi_arg_reverse() -> None: + df = pl.DataFrame( + { + "a": [5, 5, 5], + "b": [None, 5, 5], + "c": [5, None, 5], + "d": [5, 5, None], + "e": [None, None, None], + } + ) + result = df.select(pl.cum_count("a", "b", "c", "d", "e", reverse=True)) + expected = pl.DataFrame( + [ + pl.Series("a", [3, 2, 1], dtype=pl.UInt32), + pl.Series("b", [2, 2, 1], dtype=pl.UInt32), + pl.Series("c", [2, 1, 1], dtype=pl.UInt32), + pl.Series("d", [2, 1, 0], dtype=pl.UInt32), + pl.Series("e", [0, 0, 0], dtype=pl.UInt32), + ] + ) + assert_frame_equal(result, expected) + + +def test_cum_count() -> None: + df = pl.DataFrame( + [["a"], ["a"], ["a"], ["b"], ["b"], ["a"]], schema=["A"], orient="row" + ) + + out = df.group_by("A", maintain_order=True).agg( + pl.col("A").cum_count().alias("foo") + ) + + assert out["foo"][0].to_list() == [1, 2, 3, 4] + assert out["foo"][1].to_list() == [1, 2] + + +def test_series_cum_count() -> None: + s = pl.Series(["x", "k", None, "d"]) + result = s.cum_count() + expected = pl.Series([1, 2, 2, 3], dtype=pl.UInt32) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_ewm_by.py b/py-polars/tests/unit/functions/test_ewm_by.py new file mode 100644 index 000000000000..6d2384e0aefa --- /dev/null +++ b/py-polars/tests/unit/functions/test_ewm_by.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from datetime import date + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric import column, dataframes + + +@given( + data=st.data(), + half_life=st.integers(min_value=1, max_value=1000), +) +def test_ewm_by(data: st.DataObject, half_life: int) -> None: + # For evenly spaced times, ewm_by and ewm should be equivalent + df = data.draw( + dataframes( + [ + column( + "values", + strategy=st.floats(min_value=-100, max_value=100), + dtype=pl.Float64, + ), + ], + min_size=1, + ) + ) + result = df.with_row_index().select( + pl.col("values").ewm_mean_by(by="index", half_life=f"{half_life}i") + ) + expected = df.select( + pl.col("values").ewm_mean(half_life=half_life, ignore_nulls=False, adjust=False) + ) + assert_frame_equal(result, expected) + result = ( + df.with_row_index() + .sort("values") + .with_columns( + pl.col("values").ewm_mean_by(by="index", half_life=f"{half_life}i") + ) + .sort("index") + .select("values") + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("length", [1, 3]) +def test_length_mismatch_22084(length: int) -> None: + s = pl.Series([0, None]) + by = pl.Series([date(2020, 1, 5)] * length) + with pytest.raises(pl.exceptions.ShapeError): + s.ewm_mean_by(by, half_life="4d") diff --git a/py-polars/tests/unit/functions/test_functions.py b/py-polars/tests/unit/functions/test_functions.py new file mode 100644 index 000000000000..b71e41f24f57 --- /dev/null +++ b/py-polars/tests/unit/functions/test_functions.py @@ -0,0 +1,666 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import DuplicateError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import ConcatMethod + + +def test_concat_align() -> None: + a = pl.DataFrame({"a": ["a", "b", "d", "e", "e"], "b": [1, 2, 4, 5, 6]}) + b = pl.DataFrame({"a": ["a", "b", "c"], "c": [5.5, 6.0, 7.5]}) + c = pl.DataFrame({"a": ["a", "b", "c", "d", "e"], "d": ["w", "x", "y", "z", None]}) + + for align_full in ("align", "align_full"): + result = pl.concat([a, b, c], how=align_full) + expected = pl.DataFrame( + { + "a": ["a", "b", "c", "d", "e", "e"], + "b": [1, 2, None, 4, 5, 6], + "c": [5.5, 6.0, 7.5, None, None, None], + "d": ["w", "x", "y", "z", None, None], + } + ) + assert_frame_equal(result, expected) + + result = pl.concat([a, b, c], how="align_left") + expected = pl.DataFrame( + { + "a": ["a", "b", "d", "e", "e"], + "b": [1, 2, 4, 5, 6], + "c": [5.5, 6.0, None, None, None], + "d": ["w", "x", "z", None, None], + } + ) + assert_frame_equal(result, expected) + + result = pl.concat([a, b, c], how="align_right") + expected = pl.DataFrame( + { + "a": ["a", "b", "c", "d", "e"], + "b": [1, 2, None, None, None], + "c": [5.5, 6.0, 7.5, None, None], + "d": ["w", "x", "y", "z", None], + } + ) + assert_frame_equal(result, expected) + + result = pl.concat([a, b, c], how="align_inner") + expected = pl.DataFrame( + { + "a": ["a", "b"], + "b": [1, 2], + "c": [5.5, 6.0], + "d": ["w", "x"], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "strategy", ["align", "align_full", "align_left", "align_right"] +) +def test_concat_align_no_common_cols(strategy: ConcatMethod) -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [1, 2]}) + df2 = pl.DataFrame({"c": [3, 4], "d": [3, 4]}) + + with pytest.raises( + InvalidOperationError, + match=f"{strategy!r} strategy requires at least one common column", + ): + pl.concat((df1, df2), how=strategy) + + +@pytest.mark.parametrize( + ("a", "b", "c", "strategy"), + [ + ( + pl.DataFrame({"a": [1, 2]}), + pl.DataFrame({"b": ["a", "b"], "c": [3, 4]}), + pl.DataFrame({"a": [5, 6], "c": [5, 6], "d": [5, 6], "b": ["x", "y"]}), + "diagonal", + ), + ( + pl.DataFrame( + {"a": [1, 2]}, + schema_overrides={"a": pl.Int32}, + ), + pl.DataFrame( + {"b": ["a", "b"], "c": [3, 4]}, + schema_overrides={"c": pl.UInt8}, + ), + pl.DataFrame( + {"a": [5, 6], "c": [5, 6], "d": [5, 6], "b": ["x", "y"]}, + schema_overrides={"b": pl.Categorical}, + ), + "diagonal_relaxed", + ), + ], +) +def test_concat_diagonal( + a: pl.DataFrame, b: pl.DataFrame, c: pl.DataFrame, strategy: ConcatMethod +) -> None: + for out in [ + pl.concat([a, b, c], how=strategy), + pl.concat([a.lazy(), b.lazy(), c.lazy()], how=strategy).collect(), + ]: + expected = pl.DataFrame( + { + "a": [1, 2, None, None, 5, 6], + "b": [None, None, "a", "b", "x", "y"], + "c": [None, None, 3, 4, 5, 6], + "d": [None, None, None, None, 5, 6], + } + ) + assert_frame_equal(out, expected) + + +def test_concat_diagonal_relaxed_with_empty_frame() -> None: + df1 = pl.DataFrame() + df2 = pl.DataFrame( + { + "a": ["a", "b"], + "b": [1, 2], + } + ) + out = pl.concat((df1, df2), how="diagonal_relaxed") + expected = df2 + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_concat_horizontal(lazy: bool) -> None: + a = pl.DataFrame({"a": ["a", "b"], "b": [1, 2]}) + b = pl.DataFrame({"c": [5, 7, 8, 9], "d": [1, 2, 1, 2], "e": [1, 2, 1, 2]}) + + if lazy: + out = pl.concat([a.lazy(), b.lazy()], how="horizontal").collect() + else: + out = pl.concat([a, b], how="horizontal") + + expected = pl.DataFrame( + { + "a": ["a", "b", None, None], + "b": [1, 2, None, None], + "c": [5, 7, 8, 9], + "d": [1, 2, 1, 2], + "e": [1, 2, 1, 2], + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_concat_horizontal_three_dfs(lazy: bool) -> None: + a = pl.DataFrame({"a1": [1, 2, 3], "a2": ["a", "b", "c"]}) + b = pl.DataFrame({"b1": [0.25, 0.5]}) + c = pl.DataFrame({"c1": [1, 2, 3, 4], "c2": [5, 6, 7, 8], "c3": [9, 10, 11, 12]}) + + if lazy: + out = pl.concat([a.lazy(), b.lazy(), c.lazy()], how="horizontal").collect() + else: + out = pl.concat([a, b, c], how="horizontal") + + expected = pl.DataFrame( + { + "a1": [1, 2, 3, None], + "a2": ["a", "b", "c", None], + "b1": [0.25, 0.5, None, None], + "c1": [1, 2, 3, 4], + "c2": [5, 6, 7, 8], + "c3": [9, 10, 11, 12], + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("lazy", [False, True]) +def test_concat_horizontal_single_df(lazy: bool) -> None: + a = pl.DataFrame({"a": ["a", "b"], "b": [1, 2]}) + + if lazy: + out = pl.concat([a.lazy()], how="horizontal").collect() + else: + out = pl.concat([a], how="horizontal") + + expected = a + assert_frame_equal(out, expected) + + +def test_concat_horizontal_duplicate_col() -> None: + a = pl.LazyFrame({"a": ["a", "b"], "b": [1, 2]}) + b = pl.LazyFrame({"c": [5, 7, 8, 9], "d": [1, 2, 1, 2], "a": [1, 2, 1, 2]}) + + with pytest.raises(DuplicateError): + pl.concat([a, b], how="horizontal").collect() + + +def test_concat_vertical() -> None: + a = pl.DataFrame({"a": ["a", "b"], "b": [1, 2]}) + b = pl.DataFrame({"a": ["c", "d", "e"], "b": [3, 4, 5]}) + + result = pl.concat([a, b], how="vertical") + expected = pl.DataFrame( + { + "a": ["a", "b", "c", "d", "e"], + "b": [1, 2, 3, 4, 5], + } + ) + assert_frame_equal(result, expected) + + +def test_cov() -> None: + s1 = pl.Series("a", [10, 37, -40]) + s2 = pl.Series("b", [70, -10, 35]) + + # lazy/expression + lf = pl.LazyFrame([s1, s2]) + res1 = lf.select( + x=pl.cov("a", "b"), + y=pl.cov("a", "b", ddof=2), + ).collect() + + # eager/series + res2 = ( + pl.cov(s1, s2, eager=True).alias("x"), + pl.cov(s1, s2, eager=True, ddof=2).alias("y"), + ) + + # expect same result from both approaches + for idx, (r1, r2) in enumerate(zip(res1, res2)): + expected_value = -645.8333333333 if idx == 0 else -1291.6666666666 + assert pytest.approx(expected_value) == r1.item() + assert_series_equal(r1, r2) + + +def test_corr() -> None: + s1 = pl.Series("a", [10, 37, -40]) + s2 = pl.Series("b", [70, -10, 35]) + + # lazy/expression + lf = pl.LazyFrame([s1, s2]) + res1 = lf.select( + x=pl.corr("a", "b"), + y=pl.corr("a", "b", method="spearman"), + ).collect() + + # eager/series + res2 = ( + pl.corr(s1, s2, eager=True).alias("x"), + pl.corr(s1, s2, method="spearman", eager=True).alias("y"), + ) + + # expect same result from both approaches + for idx, (r1, r2) in enumerate(zip(res1, res2)): + assert pytest.approx(-0.412199756 if idx == 0 else -0.5) == r1.item() + assert_series_equal(r1, r2) + + +def test_extend_ints() -> None: + a = pl.DataFrame({"a": [1 for _ in range(1)]}, schema={"a": pl.Int64}) + with pytest.raises(pl.exceptions.SchemaError): + a.extend(a.select(pl.lit(0, dtype=pl.Int32).alias("a"))) + + +def test_null_handling_correlation() -> None: + df = pl.DataFrame({"a": [1, 2, 3, None, 4], "b": [1, 2, 3, 10, 4]}) + + out = df.select( + pl.corr("a", "b").alias("pearson"), + pl.corr("a", "b", method="spearman").alias("spearman"), + ) + assert out["pearson"][0] == pytest.approx(1.0) + assert out["spearman"][0] == pytest.approx(1.0) + + # see #4930 + df1 = pl.DataFrame({"a": [None, 1, 2], "b": [None, 2, 1]}) + df2 = pl.DataFrame({"a": [np.nan, 1, 2], "b": [np.nan, 2, 1]}) + + assert np.isclose(df1.select(pl.corr("a", "b", method="spearman")).item(), -1.0) + assert ( + str( + df2.select(pl.corr("a", "b", method="spearman", propagate_nans=True)).item() + ) + == "nan" + ) + + +def test_align_frames() -> None: + import numpy as np + import pandas as pd + + # setup some test frames + pdf1 = pd.DataFrame( + { + "date": pd.date_range(start="2019-01-02", periods=9), + "a": np.array([0, 1, 2, np.nan, 4, 5, 6, 7, 8], dtype=np.float64), + "b": np.arange(9, 18, dtype=np.float64), + } + ).set_index("date") + + pdf2 = pd.DataFrame( + { + "date": pd.date_range(start="2019-01-04", periods=7), + "a": np.arange(9, 16, dtype=np.float64), + "b": np.arange(10, 17, dtype=np.float64), + } + ).set_index("date") + + # calculate dot-product in pandas + pd_dot = (pdf1 * pdf2).sum(axis="columns").to_frame("dot").reset_index() + + # use "align_frames" to calculate dot-product from disjoint rows. pandas uses an + # index to automatically infer the correct frame-alignment for the calculation; + # we need to do it explicitly (which also makes it clearer what is happening) + pf1, pf2 = pl.align_frames( + pl.from_pandas(pdf1.reset_index()), + pl.from_pandas(pdf2.reset_index()), + on="date", + ) + pl_dot = ( + (pf1[["a", "b"]] * pf2[["a", "b"]]) + .fill_null(0) + .select(pl.sum_horizontal("*").alias("dot")) + .insert_column(0, pf1["date"]) + ) + # confirm we match the same operation in pandas + assert_frame_equal(pl_dot, pl.from_pandas(pd_dot)) + pd.testing.assert_frame_equal(pd_dot, pl_dot.to_pandas()) + + # confirm alignment function works with lazy frames + lf1, lf2 = pl.align_frames( + pl.from_pandas(pdf1.reset_index()).lazy(), + pl.from_pandas(pdf2.reset_index()).lazy(), + on="date", + ) + assert isinstance(lf1, pl.LazyFrame) + assert_frame_equal(lf1.collect(), pf1) + assert_frame_equal(lf2.collect(), pf2) + + # misc: no frames results in an empty list + assert pl.align_frames(on="date") == [] + + # expected error condition + with pytest.raises(TypeError): + pl.align_frames( # type: ignore[type-var] + pl.from_pandas(pdf1.reset_index()).lazy(), + pl.from_pandas(pdf2.reset_index()), + on="date", + ) + + +def test_align_frames_misc() -> None: + df1 = pl.DataFrame([[3, 5, 6], [5, 8, 9]], orient="row") + df2 = pl.DataFrame([[2, 5, 6], [3, 8, 9], [4, 2, 0]], orient="row") + + # descending result + pf1, pf2 = pl.align_frames( + [df1, df2], # list input + on="column_0", + descending=True, + ) + assert pf1.rows() == [(5, 8, 9), (4, None, None), (3, 5, 6), (2, None, None)] + assert pf2.rows() == [(5, None, None), (4, 2, 0), (3, 8, 9), (2, 5, 6)] + + # handle identical frames + pf1, pf2, pf3 = pl.align_frames( + (df for df in (df1, df2, df2)), # generator input + on="column_0", + descending=True, + ) + assert pf1.rows() == [(5, 8, 9), (4, None, None), (3, 5, 6), (2, None, None)] + for pf in (pf2, pf3): + assert pf.rows() == [(5, None, None), (4, 2, 0), (3, 8, 9), (2, 5, 6)] + + +def test_align_frames_with_nulls() -> None: + df1 = pl.DataFrame({"key": ["x", "y", None], "value": [1, 2, 0]}) + df2 = pl.DataFrame({"key": ["x", None, "z", "y"], "value": [4, 3, 6, 5]}) + + a1, a2 = pl.align_frames(df1, df2, on="key") + + aligned_frame_data = a1.to_dict(as_series=False), a2.to_dict(as_series=False) + assert aligned_frame_data == ( + {"key": [None, "x", "y", "z"], "value": [0, 1, 2, None]}, + {"key": [None, "x", "y", "z"], "value": [3, 4, 5, 6]}, + ) + + +def test_align_frames_duplicate_key() -> None: + # setup some test frames with duplicate key/alignment values + df1 = pl.DataFrame({"x": ["a", "a", "a", "e"], "y": [1, 2, 4, 5]}) + df2 = pl.DataFrame({"y": [0, 0, -1], "z": [5.5, 6.0, 7.5], "x": ["a", "b", "b"]}) + + # align rows, confirming correctness and original column order + af1, af2 = pl.align_frames(df1, df2, on="x") + + # shape: (6, 2) shape: (6, 3) + # ┌─────┬──────┐ ┌──────┬──────┬─────┐ + # │ x ┆ y │ │ y ┆ z ┆ x │ + # │ --- ┆ --- │ │ --- ┆ --- ┆ --- │ + # │ str ┆ i64 │ │ i64 ┆ f64 ┆ str │ + # ╞═════╪══════╡ ╞══════╪══════╪═════╡ + # │ a ┆ 1 │ │ 0 ┆ 5.5 ┆ a │ + # │ a ┆ 2 │ │ 0 ┆ 5.5 ┆ a │ + # │ a ┆ 4 │ │ 0 ┆ 5.5 ┆ a │ + # │ b ┆ null │ │ 0 ┆ 6.0 ┆ b │ + # │ b ┆ null │ │ -1 ┆ 7.5 ┆ b │ + # │ e ┆ 5 │ │ null ┆ null ┆ e │ + # └─────┴──────┘ └──────┴──────┴─────┘ + assert af1.rows() == [ + ("a", 1), + ("a", 2), + ("a", 4), + ("b", None), + ("b", None), + ("e", 5), + ] + assert af2.rows() == [ + (0, 5.5, "a"), + (0, 5.5, "a"), + (0, 5.5, "a"), + (0, 6.0, "b"), + (-1, 7.5, "b"), + (None, None, "e"), + ] + + # align frames the other way round, using "left" alignment strategy + af1, af2 = pl.align_frames(df2, df1, on="x", how="left") + + # shape: (5, 3) shape: (5, 2) + # ┌─────┬─────┬─────┐ ┌─────┬──────┐ + # │ y ┆ z ┆ x │ │ x ┆ y │ + # │ --- ┆ --- ┆ --- │ │ --- ┆ --- │ + # │ i64 ┆ f64 ┆ str │ │ str ┆ i64 │ + # ╞═════╪═════╪═════╡ ╞═════╪══════╡ + # │ 0 ┆ 5.5 ┆ a │ │ a ┆ 1 │ + # │ 0 ┆ 5.5 ┆ a │ │ a ┆ 2 │ + # │ 0 ┆ 5.5 ┆ a │ │ a ┆ 4 │ + # │ 0 ┆ 6.0 ┆ b │ │ b ┆ null │ + # │ -1 ┆ 7.5 ┆ b │ │ b ┆ null │ + # └─────┴─────┴─────┘ └─────┴──────┘ + assert af1.rows() == [ + (0, 5.5, "a"), + (0, 5.5, "a"), + (0, 5.5, "a"), + (0, 6.0, "b"), + (-1, 7.5, "b"), + ] + assert af2.rows() == [ + ("a", 1), + ("a", 2), + ("a", 4), + ("b", None), + ("b", None), + ] + + +def test_align_frames_single_row_20445() -> None: + left = pl.DataFrame({"a": [1], "b": [2]}) + right = pl.DataFrame({"a": [1], "c": [3]}) + result = pl.align_frames(left, right, how="left", on="a") + assert_frame_equal(result[0], left) + assert_frame_equal(result[1], right) + + +def test_coalesce() -> None: + df = pl.DataFrame( + { + "a": [1, None, None, None], + "b": [1, 2, None, None], + "c": [5, None, 3, None], + } + ) + # list inputs + expected = pl.Series("d", [1, 2, 3, 10]).to_frame() + result = df.select(pl.coalesce(["a", "b", "c", 10]).alias("d")) + assert_frame_equal(expected, result) + + # positional inputs + expected = pl.Series("d", [1.0, 2.0, 3.0, 10.0]).to_frame() + result = df.select(pl.coalesce(pl.col(["a", "b", "c"]), 10.0).alias("d")) + assert_frame_equal(result, expected) + + +def test_coalesce_eager() -> None: + # eager/series inputs + s1 = pl.Series("colx", [None, 2, None]) + s2 = pl.Series("coly", [1, None, None]) + s3 = pl.Series("colz", [None, None, 3]) + + res = pl.coalesce(s1, s2, s3, eager=True) + expected = pl.Series("colx", [1, 2, 3]) + assert_series_equal(expected, res) + + for zero in (0, pl.lit(0)): + res = pl.coalesce(s1, zero, eager=True) + expected = pl.Series("colx", [0, 2, 0]) + assert_series_equal(expected, res) + + res = pl.coalesce(zero, s1, eager=True) + expected = pl.Series("literal", [0, 0, 0]) + assert_series_equal(expected, res) + + with pytest.raises( + ValueError, + match="expected at least one Series in 'coalesce' if 'eager=True'", + ): + pl.coalesce("x", "y", eager=True) + + +def test_overflow_diff() -> None: + df = pl.DataFrame({"a": [20, 10, 30]}) + assert df.select(pl.col("a").cast(pl.UInt64).diff()).to_dict(as_series=False) == { + "a": [None, -10, 20] + } + + +def test_fill_null_unknown_output_type() -> None: + df = pl.DataFrame({"a": [None, 2, 3, 4, 5]}) + assert df.with_columns( + np.exp(pl.col("a")).fill_null(pl.lit(1, pl.Float64)) + ).to_dict(as_series=False) == { + "a": [ + 1.0, + 7.38905609893065, + 20.085536923187668, + 54.598150033144236, + 148.4131591025766, + ] + } + + +def test_approx_n_unique() -> None: + df1 = pl.DataFrame({"a": [None, 1, 2], "b": [None, 2, 1]}) + + assert_frame_equal( + df1.select(pl.approx_n_unique("b")), + pl.DataFrame({"b": pl.Series(values=[3], dtype=pl.UInt32)}), + ) + + assert_frame_equal( + df1.select(pl.col("b").approx_n_unique()), + pl.DataFrame({"b": pl.Series(values=[3], dtype=pl.UInt32)}), + ) + + +def test_lazy_functions() -> None: + df = pl.DataFrame( + { + "a": ["foo", "bar", "foo"], + "b": [1, 2, 3], + "c": [-1.0, 2.0, 4.0], + } + ) + + # test function expressions against frame + out = df.select( + pl.var("b").name.suffix("_var"), + pl.std("b").name.suffix("_std"), + pl.max("a", "b").name.suffix("_max"), + pl.min("a", "b").name.suffix("_min"), + pl.sum("b", "c").name.suffix("_sum"), + pl.mean("b", "c").name.suffix("_mean"), + pl.median("c", "b").name.suffix("_median"), + pl.n_unique("b", "a").name.suffix("_n_unique"), + pl.first("a").name.suffix("_first"), + pl.first("b", "c").name.suffix("_first"), + pl.last("c", "b", "a").name.suffix("_last"), + ) + expected: dict[str, list[Any]] = { + "b_var": [1.0], + "b_std": [1.0], + "a_max": ["foo"], + "b_max": [3], + "a_min": ["bar"], + "b_min": [1], + "b_sum": [6], + "c_sum": [5.0], + "b_mean": [2.0], + "c_mean": [5 / 3], + "c_median": [2.0], + "b_median": [2.0], + "b_n_unique": [3], + "a_n_unique": [2], + "a_first": ["foo"], + "b_first": [1], + "c_first": [-1.0], + "c_last": [4.0], + "b_last": [3], + "a_last": ["foo"], + } + assert_frame_equal( + out, + pl.DataFrame( + data=expected, + schema_overrides={ + "a_n_unique": pl.UInt32, + "b_n_unique": pl.UInt32, + }, + ), + ) + + # test function expressions against series + for name, value in expected.items(): + col, fn = name.split("_", 1) + if series_fn := getattr(df[col], fn, None): + assert series_fn() == value[0] + + # regex selection + out = df.select( + pl.struct(pl.max("^a|b$")).alias("x"), + pl.struct(pl.min("^.*[bc]$")).alias("y"), + pl.struct(pl.sum("^[^a]$")).alias("z"), + ) + assert out.rows() == [ + ({"a": "foo", "b": 3}, {"b": 1, "c": -1.0}, {"b": 6, "c": 5.0}) + ] + + +def test_count() -> None: + df = pl.DataFrame({"a": [1, 1, 1], "b": [None, "xx", "yy"]}) + out = df.select(pl.count("a")) + assert list(out["a"]) == [3] + + for count_expr in ( + pl.count("b", "a"), + [pl.count("b"), pl.count("a")], + ): + out = df.select(count_expr) + assert out.rows() == [(2, 3)] + + +def test_head_tail(fruits_cars: pl.DataFrame) -> None: + res_expr = fruits_cars.select(pl.head("A", 2)) + expected = pl.Series("A", [1, 2]) + assert_series_equal(res_expr.to_series(), expected) + + res_expr = fruits_cars.select(pl.tail("A", 2)) + expected = pl.Series("A", [4, 5]) + assert_series_equal(res_expr.to_series(), expected) + + +def test_escape_regex() -> None: + result = pl.escape_regex("abc(\\w+)") + expected = "abc\\(\\\\w\\+\\)" + assert result == expected + + df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + with pytest.raises( + TypeError, + match="escape_regex function is unsupported for `Expr`, you may want use `Expr.str.escape_regex` instead", + ): + df.with_columns(escaped=pl.escape_regex(pl.col("text"))) # type: ignore[arg-type] + + with pytest.raises( + TypeError, + match="escape_regex function supports only `str` type, got `int`", + ): + pl.escape_regex(3) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/functions/test_horizontal.py b/py-polars/tests/unit/functions/test_horizontal.py new file mode 100644 index 000000000000..688e2bfcf58a --- /dev/null +++ b/py-polars/tests/unit/functions/test_horizontal.py @@ -0,0 +1,21 @@ +import pytest + +import polars as pl + + +@pytest.mark.parametrize( + "f", + [ + "min", + "max", + "sum", + "mean", + ], +) +def test_shape_mismatch_19336(f: str) -> None: + a = pl.Series([1, 2, 3]) + b = pl.Series([1, 2]) + fn = getattr(pl, f"{f}_horizontal") + + with pytest.raises(pl.exceptions.ShapeError): + pl.select((fn)(a, b)) diff --git a/py-polars/tests/unit/functions/test_lit.py b/py-polars/tests/unit/functions/test_lit.py new file mode 100644 index 000000000000..ef2ca9712818 --- /dev/null +++ b/py-polars/tests/unit/functions/test_lit.py @@ -0,0 +1,245 @@ +# mypy: disable-error-code="redundant-expr" +from __future__ import annotations + +import enum +import sys +from datetime import datetime, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric.strategies import series +from polars.testing.parametric.strategies.data import datetimes + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +if sys.version_info >= (3, 11): + from enum import StrEnum + + PyStrEnum: type[enum.Enum] | None = StrEnum +else: + PyStrEnum = None + + +@pytest.mark.parametrize( + "input", + [ + [[1, 2], [3, 4, 5]], + [1, 2, 3], + ], +) +def test_lit_list_input(input: list[Any]) -> None: + df = pl.DataFrame({"a": [1, 2]}) + result = df.with_columns(pl.lit(input).first()) + expected = pl.DataFrame({"a": [1, 2], "literal": [input, input]}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input", + [ + ([1, 2], [3, 4, 5]), + (1, 2, 3), + ], +) +def test_lit_tuple_input(input: tuple[Any, ...]) -> None: + df = pl.DataFrame({"a": [1, 2]}) + result = df.with_columns(pl.lit(input).first()) + + expected = pl.DataFrame({"a": [1, 2], "literal": [list(input), list(input)]}) + assert_frame_equal(result, expected) + + +def test_lit_numpy_array_input() -> None: + df = pl.DataFrame({"a": [1, 2]}) + input = np.array([3, 4]) + + result = df.with_columns(pl.lit(input, dtype=pl.Int64)) + + expected = pl.DataFrame({"a": [1, 2], "literal": [3, 4]}) + assert_frame_equal(result, expected) + + +def test_lit_ambiguous_datetimes_11379() -> None: + df = pl.DataFrame( + { + "ts": pl.datetime_range( + datetime(2020, 10, 25), + datetime(2020, 10, 25, 2), + "1h", + time_zone="Europe/London", + eager=True, + ) + } + ) + for i in range(df.height): + result = df.filter(pl.col("ts") >= df["ts"][i]) + expected = df[i:] + assert_frame_equal(result, expected) + + +def test_list_datetime_11571() -> None: + sec_np_ns = np.timedelta64(1_000_000_000, "ns") + sec_np_us = np.timedelta64(1_000_000, "us") + assert pl.select(pl.lit(sec_np_ns))[0, 0] == timedelta(seconds=1) + assert pl.select(pl.lit(sec_np_us))[0, 0] == timedelta(seconds=1) + + +@pytest.mark.parametrize( + ("input", "dtype"), + [ + pytest.param(-(2**31), pl.Int32, id="i32 min"), + pytest.param(-(2**31) - 1, pl.Int64, id="below i32 min"), + pytest.param(2**31 - 1, pl.Int32, id="i32 max"), + pytest.param(2**31, pl.Int64, id="above i32 max"), + pytest.param(2**63 - 1, pl.Int64, id="i64 max"), + pytest.param(2**63, pl.UInt64, id="above i64 max"), + ], +) +def test_lit_int_return_type(input: int, dtype: PolarsDataType) -> None: + assert pl.select(pl.lit(input)).to_series().dtype == dtype + + +def test_lit_unsupported_type() -> None: + with pytest.raises( + TypeError, + match="cannot create expression literal for value of type LazyFrame", + ): + pl.lit(pl.LazyFrame({"a": [1, 2, 3]})) + + +@pytest.mark.parametrize( + "EnumBase", + [ + (enum.Enum,), + (str, enum.Enum), + *([(PyStrEnum,)] if PyStrEnum is not None else []), + ], +) +def test_lit_enum_input_16668(EnumBase: tuple[type, ...]) -> None: + # https://github.com/pola-rs/polars/issues/16668 + + class State(*EnumBase): # type: ignore[misc] + NSW = "New South Wales" + QLD = "Queensland" + VIC = "Victoria" + + # validate that frame schema has inferred the enum + df = pl.DataFrame({"state": [State.NSW, State.VIC]}) + assert df.schema == { + "state": pl.Enum(["New South Wales", "Queensland", "Victoria"]) + } + + # check use of enum as lit/constraint + value = State.VIC + expected = "Victoria" + + for lit_value in ( + pl.lit(value), + pl.lit(value.value), # type: ignore[attr-defined] + ): + assert pl.select(lit_value).item() == expected + assert df.filter(state=value).item() == expected + assert df.filter(state=lit_value).item() == expected + + assert df.filter(pl.col("state") == State.QLD).is_empty() + assert df.filter(pl.col("state") != State.QLD).height == 2 + + +@pytest.mark.parametrize( + "EnumBase", + [ + (enum.Enum,), + (enum.Flag,), + (enum.IntEnum,), + (enum.IntFlag,), + (int, enum.Enum), + ], +) +def test_lit_enum_input_non_string(EnumBase: tuple[type, ...]) -> None: + # https://github.com/pola-rs/polars/issues/16668 + + class Number(*EnumBase): # type: ignore[misc] + ONE = 1 + TWO = 2 + + value = Number.ONE + + result = pl.lit(value) + assert pl.select(result).dtypes[0] == pl.Int32 + assert pl.select(result).item() == 1 + + result = pl.lit(value, dtype=pl.Int8) + assert pl.select(result).dtypes[0] == pl.Int8 + assert pl.select(result).item() == 1 + + +@given(value=datetimes("ns")) +def test_datetime_ns(value: datetime) -> None: + result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0] + assert result == value + + +@given(value=datetimes("us")) +def test_datetime_us(value: datetime) -> None: + result = pl.select(pl.lit(value, dtype=pl.Datetime("us")))["literal"][0] + assert result == value + result = pl.select(pl.lit(value, dtype=pl.Datetime))["literal"][0] + assert result == value + + +@given(value=datetimes("ms")) +def test_datetime_ms(value: datetime) -> None: + result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0] + expected_microsecond = value.microsecond // 1000 * 1000 + assert result == value.replace(microsecond=expected_microsecond) + + +def test_lit_decimal() -> None: + value = Decimal("0.1") + + expr = pl.lit(value) + df = pl.select(expr) + result = df.item() + + assert df.dtypes[0] == pl.Decimal(None, 1) + assert result == value + + +def test_lit_string_float() -> None: + value = 3.2 + + expr = pl.lit(value, dtype=pl.Utf8) + df = pl.select(expr) + result = df.item() + + assert df.dtypes[0] == pl.String + assert result == str(value) + + +@given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal)) +def test_lit_decimal_parametric(s: pl.Series) -> None: + scale = s.dtype.scale # type: ignore[attr-defined] + value = s.item() + + expr = pl.lit(value) + df = pl.select(expr) + result = df.item() + + assert df.dtypes[0] == pl.Decimal(None, scale) + assert result == value + + +@pytest.mark.parametrize( + "item", + [{}, {"foo": 1}], +) +def test_lit_structs(item: Any) -> None: + assert pl.select(pl.lit(item)).to_dict(as_series=False) == {"literal": [item]} diff --git a/py-polars/tests/unit/functions/test_nth.py b/py-polars/tests/unit/functions/test_nth.py new file mode 100644 index 000000000000..f13d28bb58a6 --- /dev/null +++ b/py-polars/tests/unit/functions/test_nth.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize( + ("expr", "expected_cols"), + [ + (pl.nth(0), "a"), + (pl.nth(-1), "c"), + (pl.nth(2, 1), ["c", "b"]), + (pl.nth([2, -2, 0]), ["c", "b", "a"]), + ], +) +def test_nth(expr: pl.Expr, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + result = df.select(expr) + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +def test_nth_duplicate() -> None: + df = pl.DataFrame({"a": [1, 2]}) + with pytest.raises(DuplicateError, match="a"): + df.select(pl.nth(0, 0)) diff --git a/py-polars/tests/unit/functions/test_repeat.py b/py-polars/tests/unit/functions/test_repeat.py new file mode 100644 index 000000000000..5b66ebd8c020 --- /dev/null +++ b/py-polars/tests/unit/functions/test_repeat.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError, SchemaError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + ("value", "n", "dtype", "expected_dtype"), + [ + (2**31, 5, None, pl.Int64), + (2**31 - 1, 5, None, pl.Int32), + (-(2**31) - 1, 3, None, pl.Int64), + (-(2**31), 3, None, pl.Int32), + ("foo", 2, None, pl.String), + (1.0, 5, None, pl.Float64), + (True, 4, None, pl.Boolean), + (None, 7, None, pl.Null), + (0, 0, None, pl.Int32), + (datetime(2023, 2, 2), 3, None, pl.Datetime), + (date(2023, 2, 2), 3, None, pl.Date), + (time(10, 15), 1, None, pl.Time), + (timedelta(hours=3), 10, None, pl.Duration), + (8, 2, pl.UInt8, pl.UInt8), + (date(2023, 2, 2), 3, pl.Datetime, pl.Datetime), + (7.5, 5, pl.UInt16, pl.UInt16), + ([1, 2, 3], 2, pl.List(pl.Int64), pl.List(pl.Int64)), + (b"ab12", 3, pl.Binary, pl.Binary), + ], +) +def test_repeat( + value: Any, + n: int, + dtype: PolarsDataType, + expected_dtype: PolarsDataType, +) -> None: + expected = pl.Series("repeat", [value] * n).cast(expected_dtype) + + result_eager = pl.repeat(value, n=n, dtype=dtype, eager=True) + assert_series_equal(result_eager, expected) + + result_lazy = pl.select(pl.repeat(value, n=n, dtype=dtype, eager=False)).to_series() + assert_series_equal(result_lazy, expected) + + +def test_repeat_expr_input_eager() -> None: + result = pl.select(pl.repeat(1, n=pl.lit(3), eager=True)).to_series() + expected = pl.Series("repeat", [1, 1, 1], dtype=pl.Int32) + assert_series_equal(result, expected) + + +def test_repeat_expr_input_lazy() -> None: + df = pl.DataFrame({"a": [3, 2, 1]}) + result = df.select(pl.repeat(1, n=pl.col("a").first())).to_series() + expected = pl.Series("repeat", [1, 1, 1], dtype=pl.Int32) + assert_series_equal(result, expected) + + df = pl.DataFrame({"a": [3, 2, 1]}) + assert df.select(pl.repeat(pl.sum("a"), n=2)).to_series().to_list() == [6, 6] + + +def test_repeat_n_zero() -> None: + assert pl.repeat(1, n=0, eager=True).len() == 0 + + +@pytest.mark.parametrize( + "n", + [1.5, 2.0, date(1971, 1, 2), "hello"], +) +def test_repeat_n_non_integer(n: Any) -> None: + with pytest.raises(SchemaError, match="expected expression of dtype 'integer'"): + pl.repeat(1, n=pl.lit(n), eager=True) + + +def test_repeat_n_empty() -> None: + df = pl.DataFrame(schema={"a": pl.Int32}) + with pytest.raises(ComputeError, match="'n' must be scalar value"): + df.select(pl.repeat(1, n=pl.col("a"))) + + +def test_repeat_n_negative() -> None: + with pytest.raises(ComputeError, match="could not parse value '-1' as a size"): + pl.repeat(1, n=-1, eager=True) + + +@pytest.mark.parametrize( + ("n", "value", "dtype"), + [ + (2, 1, pl.UInt32), + (0, 1, pl.Int16), + (3, 1, pl.Float32), + (1, "1", pl.Utf8), + (2, ["1"], pl.List(pl.Utf8)), + (4, True, pl.Boolean), + (2, [True], pl.List(pl.Boolean)), + (2, [1], pl.Array(pl.Int16, shape=1)), + (2, [1, 1, 1], pl.Array(pl.Int8, shape=3)), + (1, [1], pl.List(pl.UInt32)), + ], +) +def test_ones( + n: int, + value: Any, + dtype: PolarsDataType, +) -> None: + expected = pl.Series("ones", [value] * n, dtype=dtype) + + result_eager = pl.ones(n=n, dtype=dtype, eager=True) + assert_series_equal(result_eager, expected) + + result_lazy = pl.select(pl.ones(n=n, dtype=dtype, eager=False)).to_series() + assert_series_equal(result_lazy, expected) + + +@pytest.mark.parametrize( + ("n", "value", "dtype"), + [ + (2, 0, pl.UInt8), + (0, 0, pl.Int32), + (3, 0, pl.Float32), + (1, "0", pl.Utf8), + (2, ["0"], pl.List(pl.Utf8)), + (4, False, pl.Boolean), + (2, [False], pl.List(pl.Boolean)), + (3, [0], pl.Array(pl.UInt32, shape=1)), + (2, [0, 0, 0], pl.Array(pl.UInt32, shape=3)), + (1, [0], pl.List(pl.UInt32)), + ], +) +def test_zeros( + n: int, + value: Any, + dtype: PolarsDataType, +) -> None: + expected = pl.Series("zeros", [value] * n, dtype=dtype) + + result_eager = pl.zeros(n=n, dtype=dtype, eager=True) + assert_series_equal(result_eager, expected) + + result_lazy = pl.select(pl.zeros(n=n, dtype=dtype, eager=False)).to_series() + assert_series_equal(result_lazy, expected) + + +def test_ones_zeros_misc() -> None: + # check we default to f64 if dtype is unspecified + s_ones = pl.ones(n=2, eager=True) + s_zeros = pl.zeros(n=2, eager=True) + + assert s_ones.dtype == s_zeros.dtype == pl.Float64 + + # confirm that we raise a suitable error if dtype is invalid + with pytest.raises(TypeError, match="invalid dtype for `ones`"): + pl.ones(n=2, dtype=pl.Struct({"x": pl.Date, "y": pl.Duration}), eager=True) + + with pytest.raises(TypeError, match="invalid dtype for `zeros`"): + pl.zeros(n=2, dtype=pl.Struct({"x": pl.Date, "y": pl.Duration}), eager=True) + + +def test_repeat_by_logical_dtype() -> None: + with pl.StringCache(): + df = pl.DataFrame( + { + "repeat": [1, 2, 3], + "date": [date(2021, 1, 1)] * 3, + "cat": ["a", "b", "c"], + }, + schema={"repeat": pl.Int32, "date": pl.Date, "cat": pl.Categorical}, + ) + out = df.select( + pl.col("date").repeat_by("repeat"), pl.col("cat").repeat_by("repeat") + ) + + expected_df = pl.DataFrame( + { + "date": [ + [date(2021, 1, 1)], + [date(2021, 1, 1), date(2021, 1, 1)], + [date(2021, 1, 1), date(2021, 1, 1), date(2021, 1, 1)], + ], + "cat": [["a"], ["b", "b"], ["c", "c", "c"]], + }, + schema={"date": pl.List(pl.Date), "cat": pl.List(pl.Categorical)}, + ) + + assert_frame_equal(out, expected_df) + + +def test_repeat_by_list() -> None: + df = pl.DataFrame( + { + "repeat": [1, 2, 3, None], + "value": [None, [1, 2, 3], [4, None], [1, 2]], + }, + schema={"repeat": pl.UInt32, "value": pl.List(pl.UInt8)}, + ) + out = df.select(pl.col("value").repeat_by("repeat")) + + expected_df = pl.DataFrame( + { + "value": [ + [None], + [[1, 2, 3], [1, 2, 3]], + [[4, None], [4, None], [4, None]], + None, + ], + }, + schema={"value": pl.List(pl.List(pl.UInt8))}, + ) + + assert_frame_equal(out, expected_df) + + +def test_repeat_by_nested_list() -> None: + df = pl.DataFrame( + { + "repeat": [1, 2, 3], + "value": [None, [[1], [2, 2]], [[3, 3], None, [4, None]]], + }, + schema={"repeat": pl.UInt32, "value": pl.List(pl.List(pl.Int16))}, + ) + out = df.select(pl.col("value").repeat_by("repeat")) + + expected_df = pl.DataFrame( + { + "value": [ + [None], + [[[1], [2, 2]], [[1], [2, 2]]], + [ + [[3, 3], None, [4, None]], + [[3, 3], None, [4, None]], + [[3, 3], None, [4, None]], + ], + ], + }, + schema={"value": pl.List(pl.List(pl.List(pl.Int16)))}, + ) + + assert_frame_equal(out, expected_df) + + +def test_repeat_by_struct() -> None: + df = pl.DataFrame( + { + "repeat": [1, 2, 3], + "value": [None, {"a": 1, "b": 2}, {"a": 3, "b": None}], + }, + schema={"repeat": pl.UInt32, "value": pl.Struct({"a": pl.Int8, "b": pl.Int32})}, + ) + out = df.select(pl.col("value").repeat_by("repeat")) + + expected_df = pl.DataFrame( + { + "value": [ + [None], + [{"a": 1, "b": 2}, {"a": 1, "b": 2}], + [{"a": 3, "b": None}, {"a": 3, "b": None}, {"a": 3, "b": None}], + ], + }, + schema={"value": pl.List(pl.Struct({"a": pl.Int8, "b": pl.Int32}))}, + ) + + assert_frame_equal(out, expected_df) + + +def test_repeat_by_nested_struct() -> None: + df = pl.DataFrame( + { + "repeat": [1, 2, 3], + "value": [ + None, + {"a": {"x": 1, "y": 1}, "b": 2}, + {"a": {"x": None, "y": 3}, "b": None}, + ], + }, + schema={ + "repeat": pl.UInt32, + "value": pl.Struct( + {"a": pl.Struct({"x": pl.Int64, "y": pl.Int128}), "b": pl.Int32} + ), + }, + ) + out = df.select(pl.col("value").repeat_by("repeat")) + + expected_df = pl.DataFrame( + { + "value": [ + [None], + [{"a": {"x": 1, "y": 1}, "b": 2}, {"a": {"x": 1, "y": 1}, "b": 2}], + [ + {"a": {"x": None, "y": 3}, "b": None}, + {"a": {"x": None, "y": 3}, "b": None}, + {"a": {"x": None, "y": 3}, "b": None}, + ], + ], + }, + schema={ + "value": pl.List( + pl.Struct( + {"a": pl.Struct({"x": pl.Int64, "y": pl.Int128}), "b": pl.Int32} + ) + ) + }, + ) + + assert_frame_equal(out, expected_df) + + +def test_repeat_by_struct_in_list() -> None: + df = pl.DataFrame( + { + "repeat": [1, 2, 3], + "value": [ + None, + [{"a": "foo", "b": "A"}, None], + [{"a": None, "b": "B"}, {"a": "test", "b": "B"}], + ], + }, + schema={ + "repeat": pl.UInt32, + "value": pl.List(pl.Struct({"a": pl.String, "b": pl.Enum(["A", "B"])})), + }, + ) + out = df.select(pl.col("value").repeat_by("repeat")) + + expected_df = pl.DataFrame( + { + "value": [ + [None], + [[{"a": "foo", "b": "A"}, None], [{"a": "foo", "b": "A"}, None]], + [ + [{"a": None, "b": "B"}, {"a": "test", "b": "B"}], + [{"a": None, "b": "B"}, {"a": "test", "b": "B"}], + [{"a": None, "b": "B"}, {"a": "test", "b": "B"}], + ], + ], + }, + schema={ + "value": pl.List( + pl.List(pl.Struct({"a": pl.String, "b": pl.Enum(["A", "B"])})) + ) + }, + ) + + assert_frame_equal(out, expected_df) + + +def test_repeat_by_list_in_struct() -> None: + df = pl.DataFrame( + { + "repeat": [1, 2, 3], + "value": [ + None, + {"a": [1, 2, 3], "b": ["x", "y", None]}, + {"a": [None, 5, 6], "b": None}, + ], + }, + schema={ + "repeat": pl.UInt32, + "value": pl.Struct({"a": pl.List(pl.Int8), "b": pl.List(pl.String)}), + }, + ) + out = df.select(pl.col("value").repeat_by("repeat")) + + expected_df = pl.DataFrame( + { + "value": [ + [None], + [ + {"a": [1, 2, 3], "b": ["x", "y", None]}, + {"a": [1, 2, 3], "b": ["x", "y", None]}, + ], + [ + {"a": [None, 5, 6], "b": None}, + {"a": [None, 5, 6], "b": None}, + {"a": [None, 5, 6], "b": None}, + ], + ], + }, + schema={ + "value": pl.List( + pl.Struct({"a": pl.List(pl.Int8), "b": pl.List(pl.String)}) + ) + }, + ) + + assert_frame_equal(out, expected_df) + + +@pytest.mark.parametrize( + ("data", "expected_data"), + [ + (["a", "b", None], [["a", "a"], None, [None, None, None]]), + ([1, 2, None], [[1, 1], None, [None, None, None]]), + ([1.1, 2.2, None], [[1.1, 1.1], None, [None, None, None]]), + ([True, False, None], [[True, True], None, [None, None, None]]), + ], +) +def test_repeat_by_none_13053(data: list[Any], expected_data: list[list[Any]]) -> None: + df = pl.DataFrame({"x": data, "by": [2, None, 3]}) + res = df.select(repeat=pl.col("x").repeat_by("by")) + expected = pl.Series("repeat", expected_data) + assert_series_equal(res.to_series(), expected) + + +def test_repeat_by_literal_none_20268() -> None: + df = pl.DataFrame({"x": ["a", "b"]}) + expected = pl.Series("repeat", [None, None], dtype=pl.List(pl.String)) + + res = df.select(repeat=pl.col("x").repeat_by(pl.lit(None))) + assert_series_equal(res.to_series(), expected) + + res = df.select(repeat=pl.col("x").repeat_by(None)) # type: ignore[arg-type] + assert_series_equal(res.to_series(), expected) + + +@pytest.mark.parametrize("value", [pl.Series([]), pl.Series([1, 2])]) +def test_repeat_nonscalar_value(value: pl.Series) -> None: + with pytest.raises(ComputeError, match="'value' must be scalar value"): + pl.select(pl.repeat(pl.Series(value), n=1)) + + +@pytest.mark.parametrize("n", [[], [1, 2]]) +def test_repeat_nonscalar_n(n: list[int]) -> None: + df = pl.DataFrame({"n": n}) + with pytest.raises(ComputeError, match="'n' must be scalar value"): + df.select(pl.repeat("a", pl.col("n"))) + + +def test_repeat_value_first() -> None: + df = pl.DataFrame({"a": ["a", "b", "c"], "n": [4, 5, 6]}) + result = df.select(rep=pl.repeat(pl.col("a").first(), n=pl.col("n").first())) + expected = pl.DataFrame({"rep": ["a", "a", "a", "a"]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py new file mode 100644 index 000000000000..936de52b3dcf --- /dev/null +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -0,0 +1,757 @@ +from __future__ import annotations + +import itertools +import random +from datetime import datetime +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError, ShapeError +from polars.testing import assert_frame_equal + + +def test_when_then() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5]}) + + expr = pl.when(pl.col("a") < 3).then(pl.lit("x")) + + result = df.select( + expr.otherwise(pl.lit("y")).alias("a"), + expr.alias("b"), + ) + + expected = pl.DataFrame( + { + "a": ["x", "x", "y", "y", "y"], + "b": ["x", "x", None, None, None], + } + ) + assert_frame_equal(result, expected) + + +def test_when_then_chained() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5]}) + + expr = ( + pl.when(pl.col("a") < 3) + .then(pl.lit("x")) + .when(pl.col("a") > 4) + .then(pl.lit("z")) + ) + + result = df.select( + expr.otherwise(pl.lit("y")).alias("a"), + expr.alias("b"), + ) + + expected = pl.DataFrame( + { + "a": ["x", "x", "y", "y", "z"], + "b": ["x", "x", None, None, "z"], + } + ) + assert_frame_equal(result, expected) + + +def test_when_then_invalid_chains() -> None: + with pytest.raises(AttributeError): + pl.when("a").when("b") # type: ignore[attr-defined] + with pytest.raises(AttributeError): + pl.when("a").otherwise(2) # type: ignore[attr-defined] + with pytest.raises(AttributeError): + pl.when("a").then(1).then(2) # type: ignore[attr-defined] + with pytest.raises(AttributeError): + pl.when("a").then(1).otherwise(2).otherwise(3) # type: ignore[attr-defined] + with pytest.raises(AttributeError): + pl.when("a").then(1).when("b").when("c") # type: ignore[attr-defined] + with pytest.raises(AttributeError): + pl.when("a").then(1).when("b").otherwise("2") # type: ignore[attr-defined] + with pytest.raises(AttributeError): + pl.when("a").then(1).when("b").then(2).when("c").when("d") # type: ignore[attr-defined] + + +def test_when_then_implicit_none() -> None: + df = pl.DataFrame( + { + "team": ["A", "A", "A", "B", "B", "C"], + "points": [11, 8, 10, 6, 6, 5], + } + ) + + result = df.select( + pl.when(pl.col("points") > 7).then(pl.lit("Foo")), + pl.when(pl.col("points") > 7).then(pl.lit("Foo")).alias("bar"), + ) + + expected = pl.DataFrame( + { + "literal": ["Foo", "Foo", "Foo", None, None, None], + "bar": ["Foo", "Foo", "Foo", None, None, None], + } + ) + assert_frame_equal(result, expected) + + +def test_when_then_empty_list_5547() -> None: + out = pl.DataFrame({"a": []}).select([pl.when(pl.col("a") > 1).then([1])]) + assert out.shape == (0, 1) + assert out.dtypes == [pl.List(pl.Int64)] + + +def test_nested_when_then_and_wildcard_expansion_6284() -> None: + df = pl.DataFrame( + { + "1": ["a", "b"], + "2": ["c", "d"], + } + ) + + out0 = df.with_columns( + pl.when(pl.any_horizontal(pl.all() == "a")) + .then(pl.lit("a")) + .otherwise( + pl.when(pl.any_horizontal(pl.all() == "d")) + .then(pl.lit("d")) + .otherwise(None) + ) + .alias("result") + ) + + out1 = df.with_columns( + pl.when(pl.any_horizontal(pl.all() == "a")) + .then(pl.lit("a")) + .when(pl.any_horizontal(pl.all() == "d")) + .then(pl.lit("d")) + .otherwise(None) + .alias("result") + ) + + assert_frame_equal(out0, out1) + assert out0.to_dict(as_series=False) == { + "1": ["a", "b"], + "2": ["c", "d"], + "result": ["a", "d"], + } + + +def test_list_zip_with_logical_type() -> None: + df = pl.DataFrame( + { + "start": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 1, 1, 1, 1, 1)], + "stop": [datetime(2023, 1, 1, 1, 3, 1), datetime(2023, 1, 1, 1, 4, 1)], + "use": [1, 0], + } + ) + + df = df.with_columns( + pl.datetime_ranges( + pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left" + ).alias("interval_1"), + pl.datetime_ranges( + pl.col("start"), pl.col("stop"), interval="1h", eager=False, closed="left" + ).alias("interval_2"), + ) + + out = df.select( + pl.when(pl.col("use") == 1) + .then(pl.col("interval_2")) + .otherwise(pl.col("interval_1")) + .alias("interval_new") + ) + assert out.dtypes == [pl.List(pl.Datetime(time_unit="us", time_zone=None))] + + +def test_type_coercion_when_then_otherwise_2806() -> None: + out = ( + pl.DataFrame({"names": ["foo", "spam", "spam"], "nrs": [1, 2, 3]}) + .select( + pl.when(pl.col("names") == "spam") + .then(pl.col("nrs") * 2) + .otherwise(pl.lit("other")) + .alias("new_col"), + ) + .to_series() + ) + expected = pl.Series("new_col", ["other", "4", "6"]) + assert out.to_list() == expected.to_list() + + # test it remains float32 + assert ( + pl.Series("a", [1.0, 2.0, 3.0], dtype=pl.Float32) + .to_frame() + .select(pl.when(pl.col("a") > 2.0).then(pl.col("a")).otherwise(0.0)) + ).to_series().dtype == pl.Float32 + + +def test_when_then_edge_cases_3994() -> None: + df = pl.DataFrame(data={"id": [1, 1], "type": [2, 2]}) + + # this tests if lazy correctly assigns the list schema to the column aggregation + assert ( + df.lazy() + .group_by(["id"]) + .agg(pl.col("type")) + .with_columns( + pl.when(pl.col("type").list.len() == 0) + .then(pl.lit(None)) + .otherwise(pl.col("type")) + .name.keep() + ) + .collect() + ).to_dict(as_series=False) == {"id": [1], "type": [[2, 2]]} + + # this tests ternary with an empty argument + assert ( + df.filter(pl.col("id") == 42) + .group_by(["id"]) + .agg(pl.col("type")) + .with_columns( + pl.when(pl.col("type").list.len() == 0) + .then(pl.lit(None)) + .otherwise(pl.col("type")) + .name.keep() + ) + ).to_dict(as_series=False) == {"id": [], "type": []} + + +def test_object_when_then_4702() -> None: + # please don't ever do this + x = pl.DataFrame({"Row": [1, 2], "Type": [pl.Date, pl.UInt8]}) + + assert x.with_columns( + pl.when(pl.col("Row") == 1) + .then(pl.lit(pl.UInt16, allow_object=True)) + .otherwise(pl.lit(pl.UInt8, allow_object=True)) + .alias("New_Type") + ).to_dict(as_series=False) == { + "Row": [1, 2], + "Type": [pl.Date, pl.UInt8], + "New_Type": [pl.UInt16, pl.UInt8], + } + + +@pytest.mark.may_fail_auto_streaming +def test_comp_categorical_lit_dtype() -> None: + df = pl.DataFrame( + data={"column": ["a", "b", "e"], "values": [1, 5, 9]}, + schema=[("column", pl.Categorical), ("more", pl.Int32)], + ) + + assert df.with_columns( + pl.when(pl.col("column") == "e") + .then(pl.lit("d")) + .otherwise(pl.col("column")) + .alias("column") + ).dtypes == [pl.Categorical, pl.Int32] + + +def test_comp_incompatible_enum_dtype() -> None: + df = pl.DataFrame({"a": pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"]))}) + + with pytest.raises( + InvalidOperationError, + match="conversion from `str` to `enum` failed in column 'scalar'", + ): + df.with_columns( + pl.when(pl.col("a") == "a").then(pl.col("a")).otherwise(pl.lit("c")) + ) + + +def test_predicate_broadcast() -> None: + df = pl.DataFrame( + { + "key": ["a", "a", "b", "b", "c", "c"], + "val": [1, 2, 3, 4, 5, 6], + } + ) + out = df.group_by("key", maintain_order=True).agg( + agg=pl.when(pl.col("val").min() >= 3).then(pl.col("val")), + ) + assert out.to_dict(as_series=False) == { + "key": ["a", "b", "c"], + "agg": [[None, None], [3, 4], [5, 6]], + } + + +@pytest.mark.parametrize( + "mask_expr", + [ + pl.lit(True), + pl.first("true"), + pl.lit(False), + pl.first("false"), + pl.lit(None, dtype=pl.Boolean), + pl.col("null_bool"), + pl.col("true"), + pl.col("false"), + ], +) +@pytest.mark.parametrize( + "truthy_expr", + [ + pl.lit(1), + pl.first("x"), + pl.col("x"), + ], +) +@pytest.mark.parametrize( + "falsy_expr", + [ + pl.lit(1), + pl.first("x"), + pl.col("x"), + ], +) +def test_single_element_broadcast( + mask_expr: pl.Expr, + truthy_expr: pl.Expr, + falsy_expr: pl.Expr, +) -> None: + df = ( + pl.Series("x", 5 * [1], dtype=pl.Int32) + .to_frame() + .with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean)) + ) + + # Given that the lengths of the mask, truthy and falsy are all either: + # - Length 1 + # - Equal length to the maximum length of the 3. + # This test checks that all length-1 exprs are broadcast to the max length. + result = df.select( + pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr) + ) + expected = df.select("x").head( + df.select( + pl.max_horizontal(mask_expr.len(), truthy_expr.len(), falsy_expr.len()) + ).item() + ) + assert_frame_equal(result, expected) + + result = ( + df.group_by(pl.lit(True).alias("key")) + .agg(pl.when(mask_expr).then(truthy_expr.alias("x")).otherwise(falsy_expr)) + .drop("key") + ) + if expected.height > 1: + result = result.explode(pl.all()) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "df", + [pl.DataFrame({"x": range(5)}), pl.DataFrame({"x": 5 * [[*range(5)]]})], +) +@pytest.mark.parametrize( + "ternary_expr", + [ + pl.when(True).then(pl.col("x").head(2)).otherwise(pl.col("x")), + pl.when(False).then(pl.col("x").head(2)).otherwise(pl.col("x")), + ], +) +def test_mismatched_height_should_raise( + df: pl.DataFrame, ternary_expr: pl.Expr +) -> None: + with pytest.raises(ShapeError): + df.select(ternary_expr) + + with pytest.raises(ShapeError): + df.group_by(pl.lit(True).alias("key")).agg(ternary_expr) + + +def test_when_then_output_name_12380() -> None: + df = pl.DataFrame( + {"x": range(5), "y": range(5, 10)}, schema={"x": pl.Int8, "y": pl.Int64} + ).with_columns(true=True, false=False, null_bool=pl.lit(None, dtype=pl.Boolean)) + + expect = df.select(pl.col("x").cast(pl.Int64)) + for true_expr in (pl.first("true"), pl.col("true"), pl.lit(True)): + ternary_expr = pl.when(true_expr).then(pl.col("x")).otherwise(pl.col("y")) + + actual = df.select(ternary_expr) + assert_frame_equal( + expect, + actual, + ) + actual = ( + df.group_by(pl.lit(True).alias("key")) + .agg(ternary_expr) + .drop("key") + .explode(pl.all()) + ) + assert_frame_equal( + expect, + actual, + ) + + expect = df.select(pl.col("y").alias("x")) + for false_expr in ( + pl.first("false"), + pl.col("false"), + pl.lit(False), + pl.first("null_bool"), + pl.col("null_bool"), + pl.lit(None, dtype=pl.Boolean), + ): + ternary_expr = pl.when(false_expr).then(pl.col("x")).otherwise(pl.col("y")) + + actual = df.select(ternary_expr) + assert_frame_equal( + expect, + actual, + ) + actual = ( + df.group_by(pl.lit(True).alias("key")) + .agg(ternary_expr) + .drop("key") + .explode(pl.all()) + ) + assert_frame_equal( + expect, + actual, + ) + + +def test_when_then_nested_non_unit_literal_predicate_agg_broadcast_12242() -> None: + df = pl.DataFrame( + { + "array_name": ["A", "A", "A", "B", "B"], + "array_idx": [5, 0, 3, 7, 2], + "array_val": [1, 2, 3, 4, 5], + } + ) + + int_range = pl.int_range(pl.min("array_idx"), pl.max("array_idx") + 1) + + is_valid_idx = int_range.is_in("array_idx") + + idxs = is_valid_idx.cum_sum() - 1 + + ternary_expr = pl.when(is_valid_idx).then(pl.col("array_val").gather(idxs)) + + expect = pl.DataFrame( + [ + pl.Series("array_name", ["A", "B"], dtype=pl.String), + pl.Series( + "array_val", + [[1, None, None, 2, None, 3], [4, None, None, None, None, 5]], + dtype=pl.List(pl.Int64), + ), + ] + ) + + assert_frame_equal( + expect, df.group_by("array_name").agg(ternary_expr).sort("array_name") + ) + + +def test_when_then_non_unit_literal_predicate_agg_broadcast_12382() -> None: + df = pl.DataFrame({"id": [1, 1], "value": [0, 3]}) + + expect = pl.DataFrame({"id": [1], "literal": [["yes", None, None, "yes", None]]}) + actual = df.group_by("id").agg( + pl.when(pl.int_range(0, 5).is_in("value")).then(pl.lit("yes")) + ) + + assert_frame_equal(expect, actual) + + +def test_when_then_binary_op_predicate_agg_12526() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 1], + "b": [1, 2, 5], + } + ) + + expect = pl.DataFrame( + {"a": [1], "col": [None]}, schema={"a": pl.Int64, "col": pl.String} + ) + + actual = df.group_by("a").agg( + col=( + pl.when( + pl.col("a").shift(1) > 2, + pl.col("b").is_not_null(), + ) + .then(pl.lit("abc")) + .when( + pl.col("a").shift(1) > 1, + pl.col("b").is_not_null(), + ) + .then(pl.lit("def")) + .otherwise(pl.lit(None)) + .first() + ) + ) + + assert_frame_equal(expect, actual) + + +def test_when_predicates_kwargs() -> None: + df = pl.DataFrame( + { + "x": [10, 20, 30, 40], + "y": [15, -20, None, 1], + "z": ["a", "b", "c", "d"], + } + ) + assert_frame_equal( # kwargs only + df.select(matched=pl.when(x=30, z="c").then(True).otherwise(False)), + pl.DataFrame({"matched": [False, False, True, False]}), + ) + assert_frame_equal( # mixed predicates & kwargs + df.select(matched=pl.when(pl.col("x") < 30, z="b").then(True).otherwise(False)), + pl.DataFrame({"matched": [False, True, False, False]}), + ) + assert_frame_equal( # chained when/then with mixed predicates/kwargs + df.select( + misc=pl.when(pl.col("x") > 50) + .then(pl.lit("x>50")) + .when(y=1) + .then(pl.lit("y=1")) + .when(pl.col("z").is_in(["a", "b"]), pl.col("y") < 0) + .then(pl.lit("z in (a|b), y<0")) + .otherwise(pl.lit("?")) + ), + pl.DataFrame({"misc": ["?", "z in (a|b), y<0", "?", "y=1"]}), + ) + + +def test_when_then_null_broadcast() -> None: + assert ( + pl.select( + pl.when(pl.repeat(True, 2, dtype=pl.Boolean)).then( + pl.repeat(None, 1, dtype=pl.Null) + ) + ).height + == 2 + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("len", [1, 10, 100, 500]) +@pytest.mark.parametrize( + ("dtype", "vals"), + [ + pytest.param(pl.Boolean, [False, True], id="Boolean"), + pytest.param(pl.UInt8, [0, 1], id="UInt8"), + pytest.param(pl.UInt16, [0, 1], id="UInt16"), + pytest.param(pl.UInt32, [0, 1], id="UInt32"), + pytest.param(pl.UInt64, [0, 1], id="UInt64"), + pytest.param(pl.Float32, [0.0, 1.0], id="Float32"), + pytest.param(pl.Float64, [0.0, 1.0], id="Float64"), + pytest.param(pl.String, ["0", "12"], id="String"), + pytest.param(pl.Array(pl.String, 2), [["0", "1"], ["3", "4"]], id="StrArray"), + pytest.param(pl.Array(pl.Int64, 2), [[0, 1], [3, 4]], id="IntArray"), + pytest.param(pl.List(pl.String), [["0"], ["1", "2"]], id="List"), + pytest.param( + pl.Struct({"foo": pl.Int32, "bar": pl.String}), + [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}], + id="Struct", + ), + pytest.param(pl.Object, ["x", "y"], id="Object"), + ], +) +@pytest.mark.parametrize("broadcast", list(itertools.product([False, True], repeat=3))) +def test_when_then_parametric( + len: int, dtype: pl.DataType, vals: list[Any], broadcast: list[bool] +) -> None: + # Makes no sense to broadcast all columns. + if all(broadcast): + return + + rng = random.Random(42) + + for _ in range(10): + mask = rng.choices([False, True, None], k=len) + if_true = rng.choices(vals + [None], k=len) + if_false = rng.choices(vals + [None], k=len) + + py_mask, py_true, py_false = ( + [c[0]] * len if b else c + for b, c in zip(broadcast, [mask, if_true, if_false]) + ) + pl_mask, pl_true, pl_false = ( + c.first() if b else c + for b, c in zip(broadcast, [pl.col.mask, pl.col.if_true, pl.col.if_false]) + ) + + ref = pl.DataFrame( + {"if_true": [t if m else f for m, t, f in zip(py_mask, py_true, py_false)]}, + schema={"if_true": dtype}, + ) + df = pl.DataFrame( + { + "mask": mask, + "if_true": if_true, + "if_false": if_false, + }, + schema={"mask": pl.Boolean, "if_true": dtype, "if_false": dtype}, + ) + + ans = df.select(pl.when(pl_mask).then(pl_true).otherwise(pl_false)) + if dtype != pl.Object: + assert_frame_equal(ref, ans) + else: + assert ref["if_true"].to_list() == ans["if_true"].to_list() + + +def test_when_then_else_struct_18961() -> None: + v1 = [None, {"foo": 0, "bar": "1"}] + v2 = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}] + + df = pl.DataFrame({"left": v1, "right": v2, "mask": [False, True]}) + + expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}] + ans = ( + df.select( + pl.when(pl.col.mask).then(pl.col.left).otherwise(pl.col.right.first()) + ) + .get_column("left") + .to_list() + ) + assert expected == ans + + df = pl.DataFrame({"left": v2, "right": v1, "mask": [True, False]}) + + expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}] + ans = ( + df.select( + pl.when(pl.col.mask).then(pl.col.left.first()).otherwise(pl.col.right) + ) + .get_column("left") + .to_list() + ) + assert expected == ans + + df = pl.DataFrame({"left": v1, "right": v2, "mask": [True, False]}) + + expected2 = [None, {"foo": 0, "bar": "1"}] + ans = ( + df.select( + pl.when(pl.col.mask) + .then(pl.col.left.first()) + .otherwise(pl.col.right.first()) + ) + .get_column("left") + .to_list() + ) + assert expected2 == ans + + +def test_when_then_supertype_15975() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + assert df.with_columns( + pl.when(True).then(1 ** pl.col("a") + 1.0 * pl.col("a")) + ).to_dict(as_series=False) == {"a": [1, 2, 3], "literal": [2.0, 3.0, 4.0]} + + +def test_when_then_supertype_15975_comment() -> None: + df = pl.LazyFrame({"foo": [1, 3, 4], "bar": [3, 4, 0]}) + + q = df.with_columns( + pl.when(pl.col("foo") == 1) + .then(1) + .when(pl.col("foo") == 2) + .then(4) + .when(pl.col("foo") == 3) + .then(1.5) + .when(pl.col("foo") == 4) + .then(16) + .otherwise(0) + .alias("val") + ) + + assert q.collect()["val"].to_list() == [1.0, 1.5, 16.0] + + +def test_chained_when_no_subclass_17142() -> None: + # https://github.com/pola-rs/polars/pull/17142 + when = pl.when(True).then(1).when(True) + + assert not isinstance(when, pl.Expr) + assert " None: + df = pl.DataFrame( + [ + pl.Series("x", [{"a": 1}]), + pl.Series("b", [False]), + ] + ) + + df = df.vstack(df) + + # This used to panic + assert_frame_equal( + df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))), + pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}), + ) + + +some_scalar = pl.Series("a", [{"x": 2}], pl.Struct) +none_scalar = pl.Series("a", [None], pl.Struct({"x": pl.Int64})) +column = pl.Series("a", [{"x": 2}, {"x": 2}], pl.Struct) + + +@pytest.mark.parametrize( + "values", + [ + (some_scalar, some_scalar), + (some_scalar, pl.col.a), + (some_scalar, none_scalar), + (some_scalar, column), + (none_scalar, pl.col.a), + (none_scalar, none_scalar), + (none_scalar, column), + (pl.col.a, pl.col.a), + (pl.col.a, column), + (column, column), + ], +) +def test_struct_when_then_broadcasting_combinations_19122( + values: tuple[Any, Any], +) -> None: + lv, rv = values + + df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame() + + assert_frame_equal( + df.select( + pl.when(pl.col.a.struct.field("x") == 0).then(lv).otherwise(rv).alias("a") + ), + df.select( + pl.when(pl.col.a.struct.field("x") == 0).then(None).otherwise(rv).alias("a") + ), + ) + + assert_frame_equal( + df.select( + pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(lv).alias("a") + ), + df.select( + pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(None).alias("a") + ), + ) + + +def test_when_then_to_decimal_18375() -> None: + df = pl.DataFrame({"a": ["1.23", "4.56"]}) + + result = df.with_columns( + b=pl.when(False).then(None).otherwise(pl.col("a").str.to_decimal()), + c=pl.when(True).then(pl.col("a").str.to_decimal()), + ) + expected = pl.DataFrame( + { + "a": ["1.23", "4.56"], + "b": ["1.23", "4.56"], + "c": ["1.23", "4.56"], + }, + schema={"a": pl.String, "b": pl.Decimal, "c": pl.Decimal}, + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/interchange/__init__.py b/py-polars/tests/unit/interchange/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/interchange/test_buffer.py b/py-polars/tests/unit/interchange/test_buffer.py new file mode 100644 index 000000000000..048532d4de76 --- /dev/null +++ b/py-polars/tests/unit/interchange/test_buffer.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.interchange.buffer import PolarsBuffer +from polars.interchange.protocol import CopyNotAllowedError, DlpackDeviceType + + +@pytest.mark.parametrize( + ("data", "allow_copy"), + [ + (pl.Series([1, 2]), True), + (pl.Series([1, 2]), False), + (pl.concat([pl.Series([1, 2]), pl.Series([1, 2])], rechunk=False), True), + ], +) +def test_init(data: pl.Series, allow_copy: bool) -> None: + buffer = PolarsBuffer(data, allow_copy=allow_copy) + assert buffer._data.n_chunks() == 1 + + +def test_init_invalid_input() -> None: + s = pl.Series([1, 2]) + data = pl.concat([s, s], rechunk=False) + + with pytest.raises( + CopyNotAllowedError, match="non-contiguous buffer must be made contiguous" + ): + PolarsBuffer(data, allow_copy=False) + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + (pl.Series([1, 2], dtype=pl.Int8), 2), + (pl.Series([1, 2], dtype=pl.Int64), 16), + (pl.Series([1.4, 2.9, 3.0], dtype=pl.Float32), 12), + (pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8), 9), + (pl.Series([0, 1, 0, 2, 0], dtype=pl.UInt32), 20), + (pl.Series([True, False], dtype=pl.Boolean), 1), + (pl.Series([True] * 8, dtype=pl.Boolean), 1), + (pl.Series([True] * 9, dtype=pl.Boolean), 2), + (pl.Series([True] * 9, dtype=pl.Boolean)[5:], 2), + ], +) +def test_bufsize(data: pl.Series, expected: int) -> None: + buffer = PolarsBuffer(data) + assert buffer.bufsize == expected + + +@pytest.mark.parametrize( + "data", + [ + pl.Series([1, 2]), + pl.Series([1.2, 2.9, 3.0]), + pl.Series([True, False]), + pl.Series([True, False])[1:], + pl.Series([97, 98, 97], dtype=pl.UInt8), + pl.Series([], dtype=pl.Float32), + ], +) +def test_ptr(data: pl.Series) -> None: + buffer = PolarsBuffer(data) + result = buffer.ptr + # Memory address is unpredictable, so we just check if an integer is returned + assert isinstance(result, int) + + +def test__dlpack__() -> None: + data = pl.Series([1, 2]) + buffer = PolarsBuffer(data) + with pytest.raises(NotImplementedError): + buffer.__dlpack__() + + +def test__dlpack_device__() -> None: + data = pl.Series([1, 2]) + buffer = PolarsBuffer(data) + assert buffer.__dlpack_device__() == (DlpackDeviceType.CPU, None) + + +def test__repr__() -> None: + data = pl.Series([True, False, True]) + buffer = PolarsBuffer(data) + print(buffer.__repr__()) diff --git a/py-polars/tests/unit/interchange/test_column.py b/py-polars/tests/unit/interchange/test_column.py new file mode 100644 index 000000000000..12b9631f40e8 --- /dev/null +++ b/py-polars/tests/unit/interchange/test_column.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.interchange.column import PolarsColumn +from polars.interchange.protocol import ColumnNullType, CopyNotAllowedError, DtypeKind +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars.interchange.protocol import Dtype + + +def test_size() -> None: + s = pl.Series([1, 2, 3]) + col = PolarsColumn(s) + assert col.size() == 3 + + +def test_offset() -> None: + s = pl.Series([1, 2, 3]) + col = PolarsColumn(s) + assert col.offset == 0 + + +def test_dtype_int() -> None: + s = pl.Series([1, 2, 3], dtype=pl.Int32) + col = PolarsColumn(s) + assert col.dtype == (DtypeKind.INT, 32, "i", "=") + + +def test_dtype_categorical() -> None: + s = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + col = PolarsColumn(s) + assert col.dtype == (DtypeKind.CATEGORICAL, 32, "I", "=") + + +def test_describe_categorical() -> None: + s = pl.Series(["b", "a", "a", "c", None, "b"], dtype=pl.Categorical) + col = PolarsColumn(s) + + out = col.describe_categorical + + assert out["is_ordered"] is True + assert out["is_dictionary"] is True + + expected_categories = pl.Series(["b", "a", "c"]) + assert_series_equal(out["categories"]._col, expected_categories) + + +def test_describe_categorical_lexical_ordering() -> None: + s = pl.Series(["b", "a", "a", "c", None, "b"], dtype=pl.Categorical("lexical")) + col = PolarsColumn(s) + + out = col.describe_categorical + + assert out["is_ordered"] is False + + +def test_describe_categorical_enum() -> None: + s = pl.Series(["b", "a", "a", "c", None, "b"], dtype=pl.Enum(["a", "b", "c"])) + col = PolarsColumn(s) + + out = col.describe_categorical + + assert out["is_ordered"] is True + assert out["is_dictionary"] is True + + expected_categories = pl.Series("category", ["a", "b", "c"]) + assert_series_equal(out["categories"]._col, expected_categories) + + +def test_describe_categorical_other_dtype() -> None: + s = pl.Series(["a", "b", "a"], dtype=pl.String) + col = PolarsColumn(s) + with pytest.raises(TypeError): + col.describe_categorical + + +def test_describe_null() -> None: + s = pl.Series([1, 2, None]) + col = PolarsColumn(s) + assert col.describe_null == (ColumnNullType.USE_BITMASK, 0) + + +def test_describe_null_no_null_values() -> None: + s = pl.Series([1, 2, 3]) + col = PolarsColumn(s) + assert col.describe_null == (ColumnNullType.NON_NULLABLE, None) + + +def test_null_count() -> None: + s = pl.Series([None, 2, None]) + col = PolarsColumn(s) + assert col.null_count == 2 + + +def test_metadata() -> None: + s = pl.Series([1, 2]) + col = PolarsColumn(s) + assert col.metadata == {} + + +def test_num_chunks() -> None: + s = pl.Series([1, 2]) + col = PolarsColumn(s) + assert col.num_chunks() == 1 + + s2 = pl.concat([s, s], rechunk=False) + col2 = s2.to_frame().__dataframe__().get_column(0) + assert col2.num_chunks() == 2 + + +@pytest.mark.parametrize("n_chunks", [None, 2]) +def test_get_chunks(n_chunks: int | None) -> None: + s1 = pl.Series([1, 2, 3]) + s2 = pl.Series([4, 5]) + s = pl.concat([s1, s2], rechunk=False) + col = PolarsColumn(s) + + out = col.get_chunks(n_chunks) + + expected = [s1, s2] + for o, e in zip(out, expected): + assert_series_equal(o._col, e) + + +def test_get_chunks_invalid_input() -> None: + s1 = pl.Series([1, 2, 3]) + s2 = pl.Series([4, 5]) + s = pl.concat([s1, s2], rechunk=False) + col = PolarsColumn(s) + + with pytest.raises(ValueError): + next(col.get_chunks(0)) + + with pytest.raises(ValueError): + next(col.get_chunks(3)) + + +def test_get_chunks_subdivided_chunks() -> None: + s1 = pl.Series([1, 2, 3]) + s2 = pl.Series([4, 5]) + s = pl.concat([s1, s2], rechunk=False) + col = PolarsColumn(s) + + out = col.get_chunks(4) + + chunk1 = next(out) + expected1 = pl.Series([1, 2]) + assert_series_equal(chunk1._col, expected1) + + chunk2 = next(out) + expected2 = pl.Series([3]) + assert_series_equal(chunk2._col, expected2) + + chunk3 = next(out) + expected3 = pl.Series([4]) + assert_series_equal(chunk3._col, expected3) + + chunk4 = next(out) + expected4 = pl.Series([5]) + assert_series_equal(chunk4._col, expected4) + + with pytest.raises(StopIteration): + next(out) + + +@pytest.mark.parametrize( + ("series", "expected_data", "expected_dtype"), + [ + ( + pl.Series([1, None, 3], dtype=pl.Int16), + pl.Series([1, 0, 3], dtype=pl.Int16), + (DtypeKind.INT, 16, "s", "="), + ), + ( + pl.Series([-1.5, 3.0, None], dtype=pl.Float64), + pl.Series([-1.5, 3.0, 0.0], dtype=pl.Float64), + (DtypeKind.FLOAT, 64, "g", "="), + ), + ( + pl.Series(["a", "bc", None, "éâç"], dtype=pl.String), + pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8), + (DtypeKind.UINT, 8, "C", "="), + ), + ( + pl.Series( + [datetime(1988, 1, 2), None, datetime(2022, 12, 3)], dtype=pl.Datetime + ), + pl.Series([568080000000000, 0, 1670025600000000], dtype=pl.Int64), + (DtypeKind.INT, 64, "l", "="), + ), + ( + pl.Series(["a", "b", None, "a"], dtype=pl.Categorical), + pl.Series([0, 1, 0, 0], dtype=pl.UInt32), + (DtypeKind.UINT, 32, "I", "="), + ), + ], +) +def test_get_buffers_data( + series: pl.Series, + expected_data: pl.Series, + expected_dtype: Dtype, +) -> None: + col = PolarsColumn(series) + + out = col.get_buffers() + + data_buffer, data_dtype = out["data"] + assert_series_equal(data_buffer._data, expected_data) + assert data_dtype == expected_dtype + + +def test_get_buffers_int() -> None: + s = pl.Series([1, 2, 3], dtype=pl.Int8) + col = PolarsColumn(s) + + out = col.get_buffers() + + data_buffer, data_dtype = out["data"] + assert_series_equal(data_buffer._data, s) + assert data_dtype == (DtypeKind.INT, 8, "c", "=") + + assert out["validity"] is None + assert out["offsets"] is None + + +def test_get_buffers_with_validity_and_offsets() -> None: + s = pl.Series(["a", "bc", None, "éâç"]) + col = PolarsColumn(s) + + out = col.get_buffers() + + data_buffer, data_dtype = out["data"] + expected = pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8) + assert_series_equal(data_buffer._data, expected) + assert data_dtype == (DtypeKind.UINT, 8, "C", "=") + + validity = out["validity"] + assert validity is not None + val_buffer, val_dtype = validity + expected = pl.Series([True, True, False, True]) + assert_series_equal(val_buffer._data, expected) + assert val_dtype == (DtypeKind.BOOL, 1, "b", "=") + + offsets = out["offsets"] + assert offsets is not None + offsets_buffer, offsets_dtype = offsets + expected = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + assert_series_equal(offsets_buffer._data, expected) + assert offsets_dtype == (DtypeKind.INT, 64, "l", "=") + + +def test_get_buffers_chunked_bitmask() -> None: + s = pl.Series([True, False], dtype=pl.Boolean) + s_chunked = pl.concat([s[:1], s[1:]], rechunk=False) + col = PolarsColumn(s_chunked) + + chunks = list(col.get_chunks()) + assert chunks[0].get_buffers()["data"][0]._data.item() is True + assert chunks[1].get_buffers()["data"][0]._data.item() is False + + +def test_get_buffers_string_zero_copy_fails() -> None: + s = pl.Series("a", ["a", "bc"], dtype=pl.String) + + col = PolarsColumn(s, allow_copy=False) + + msg = "string buffers must be converted" + with pytest.raises(CopyNotAllowedError, match=msg): + col.get_buffers() + + +def test_get_buffers_global_categorical() -> None: + with pl.StringCache(): + _ = pl.Series("a", ["a", "b"], dtype=pl.Categorical) + s = pl.Series("a", ["c", "b"], dtype=pl.Categorical) + + # Converted to local categorical + col = PolarsColumn(s, allow_copy=True) + result = col.get_buffers() + + data_buffer, _ = result["data"] + expected = pl.Series("a", [0, 1], dtype=pl.UInt32) + assert_series_equal(data_buffer._data, expected) + + # Zero copy fails + col = PolarsColumn(s, allow_copy=False) + + msg = "column 'a' must be converted to a local categorical" + with pytest.raises(CopyNotAllowedError, match=msg): + col.get_buffers() + + +def test_get_buffers_chunked_zero_copy_fails() -> None: + s1 = pl.Series([1, 2, 3]) + s = pl.concat([s1, s1], rechunk=False) + col = PolarsColumn(s, allow_copy=False) + + with pytest.raises( + CopyNotAllowedError, match="non-contiguous buffer must be made contiguous" + ): + col.get_buffers() + + +def test_wrap_data_buffer() -> None: + values = pl.Series([1, 2, 3]) + col = PolarsColumn(pl.Series()) + + result_buffer, result_dtype = col._wrap_data_buffer(values) + + assert_series_equal(result_buffer._data, values) + assert result_dtype == (DtypeKind.INT, 64, "l", "=") + + +def test_wrap_validity_buffer() -> None: + validity = pl.Series([True, False, True]) + col = PolarsColumn(pl.Series()) + + result = col._wrap_validity_buffer(validity) + + assert result is not None + + result_buffer, result_dtype = result + assert_series_equal(result_buffer._data, validity) + assert result_dtype == (DtypeKind.BOOL, 1, "b", "=") + + +def test_wrap_validity_buffer_no_nulls() -> None: + col = PolarsColumn(pl.Series()) + assert col._wrap_validity_buffer(None) is None + + +def test_wrap_offsets_buffer() -> None: + offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + col = PolarsColumn(pl.Series()) + + result = col._wrap_offsets_buffer(offsets) + + assert result is not None + + result_buffer, result_dtype = result + assert_series_equal(result_buffer._data, offsets) + assert result_dtype == (DtypeKind.INT, 64, "l", "=") + + +def test_wrap_offsets_buffer_none() -> None: + col = PolarsColumn(pl.Series()) + assert col._wrap_validity_buffer(None) is None + + +def test_column_unsupported_type() -> None: + s = pl.Series("a", [[4], [5, 6]]) + col = PolarsColumn(s) + + # Certain column operations work + assert col.num_chunks() == 1 + assert col.null_count == 0 + + # Error is raised when unsupported operations are requested + with pytest.raises(ValueError, match="not supported"): + col.dtype diff --git a/py-polars/tests/unit/interchange/test_dataframe.py b/py-polars/tests/unit/interchange/test_dataframe.py new file mode 100644 index 000000000000..ee26f9effd64 --- /dev/null +++ b/py-polars/tests/unit/interchange/test_dataframe.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.interchange.dataframe import PolarsDataFrame +from polars.interchange.protocol import CopyNotAllowedError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_dataframe_dunder() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + dfi = PolarsDataFrame(df) + + assert_frame_equal(dfi._df, df) + assert dfi._allow_copy is True + + dfi_new = dfi.__dataframe__(allow_copy=False) + + assert_frame_equal(dfi_new._df, df) + assert dfi_new._allow_copy is False + + +def test_dataframe_dunder_nan_as_null_not_implemented() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + dfi = PolarsDataFrame(df) + + with pytest.raises(NotImplementedError, match="has not been implemented"): + df.__dataframe__(nan_as_null=True) + + with pytest.raises(NotImplementedError, match="has not been implemented"): + dfi.__dataframe__(nan_as_null=True) + + +def test_metadata() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + dfi = PolarsDataFrame(df) + assert dfi.metadata == {} + + +def test_num_columns() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + dfi = PolarsDataFrame(df) + assert dfi.num_columns() == 2 + + +def test_num_rows() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + dfi = PolarsDataFrame(df) + assert dfi.num_rows() == 2 + + +def test_num_chunks() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + dfi = PolarsDataFrame(df) + assert dfi.num_chunks() == 1 + + df2 = pl.concat([df, df], rechunk=False) + dfi2 = df2.__dataframe__() + assert dfi2.num_chunks() == 2 + + +def test_column_names() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + dfi = PolarsDataFrame(df) + assert dfi.column_names() == ["a", "b"] + + +def test_get_column() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + dfi = PolarsDataFrame(df) + + out = dfi.get_column(1) + + expected = pl.Series("b", [3, 4]) + assert_series_equal(out._col, expected) + + +def test_get_column_by_name() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + dfi = PolarsDataFrame(df) + + out = dfi.get_column_by_name("b") + + expected = pl.Series("b", [3, 4]) + assert_series_equal(out._col, expected) + + +def test_get_columns() -> None: + s1 = pl.Series("a", [1, 2]) + s2 = pl.Series("b", [3, 4]) + df = pl.DataFrame([s1, s2]) + dfi = PolarsDataFrame(df) + + out = dfi.get_columns() + + expected = [s1, s2] + for o, e in zip(out, expected): + assert_series_equal(o._col, e) + + +def test_select_columns() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + dfi = PolarsDataFrame(df) + + out = dfi.select_columns([0, 2]) + + expected = pl.DataFrame({"a": [1, 2], "c": [5, 6]}) + assert_frame_equal(out._df, expected) + + +def test_select_columns_nonlist_input() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + dfi = PolarsDataFrame(df) + + out = dfi.select_columns((2,)) + + expected = pl.DataFrame({"c": [5, 6]}) + assert_frame_equal(out._df, expected) + + +def test_select_columns_invalid_input() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + dfi = PolarsDataFrame(df) + + with pytest.raises(TypeError): + dfi.select_columns(1) # type: ignore[arg-type] + + +def test_select_columns_by_name() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + dfi = PolarsDataFrame(df) + + out = dfi.select_columns_by_name(["a", "c"]) + + expected = pl.DataFrame({"a": [1, 2], "c": [5, 6]}) + assert_frame_equal(out._df, expected) + + +def test_select_columns_by_name_invalid_input() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + dfi = PolarsDataFrame(df) + + with pytest.raises(TypeError): + dfi.select_columns_by_name(1) # type: ignore[arg-type] + + +@pytest.mark.parametrize("n_chunks", [None, 2]) +def test_get_chunks(n_chunks: int | None) -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 5]}) + df2 = pl.DataFrame({"a": [3], "b": [6]}) + df = pl.concat([df1, df2], rechunk=False) + dfi = PolarsDataFrame(df) + + out = dfi.get_chunks(n_chunks) + + expected = dfi._get_chunks_from_col_chunks() + for o, e in zip(out, expected): + assert_frame_equal(o._df, e) + + +def test_get_chunks_invalid_input() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 5]}) + df2 = pl.DataFrame({"a": [3], "b": [6]}) + df = pl.concat([df1, df2], rechunk=False) + + dfi = PolarsDataFrame(df) + + with pytest.raises(ValueError): + next(dfi.get_chunks(0)) + + with pytest.raises(ValueError): + next(dfi.get_chunks(3)) + + +def test_get_chunks_subdivided_chunks() -> None: + df1 = pl.DataFrame({"a": [1, 2, 3], "b": [6, 7, 8]}) + df2 = pl.DataFrame({"a": [4, 5], "b": [9, 0]}) + df = pl.concat([df1, df2], rechunk=False) + + dfi = PolarsDataFrame(df) + out = dfi.get_chunks(4) + + chunk1 = next(out) + expected1 = pl.DataFrame({"a": [1, 2], "b": [6, 7]}) + assert_frame_equal(chunk1._df, expected1) + + chunk2 = next(out) + expected2 = pl.DataFrame({"a": [3], "b": [8]}) + assert_frame_equal(chunk2._df, expected2) + + chunk3 = next(out) + expected3 = pl.DataFrame({"a": [4], "b": [9]}) + assert_frame_equal(chunk3._df, expected3) + + chunk4 = next(out) + expected4 = pl.DataFrame({"a": [5], "b": [0]}) + assert_frame_equal(chunk4._df, expected4) + + with pytest.raises(StopIteration): + next(out) + + +def test_get_chunks_zero_copy_fail() -> None: + col1 = pl.Series([1, 2]) + col2 = pl.concat([pl.Series([3]), pl.Series([4])], rechunk=False) + df = pl.DataFrame({"a": col1, "b": col2}) + + dfi = PolarsDataFrame(df, allow_copy=False) + + with pytest.raises( + CopyNotAllowedError, match="unevenly chunked columns must be rechunked" + ): + next(dfi.get_chunks()) + + +@pytest.mark.parametrize("allow_copy", [True, False]) +def test_get_chunks_from_col_chunks_single_chunk(allow_copy: bool) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + dfi = PolarsDataFrame(df, allow_copy=allow_copy) + out = dfi._get_chunks_from_col_chunks() + + chunk1 = next(out) + assert_frame_equal(chunk1, df) + + with pytest.raises(StopIteration): + next(out) + + +@pytest.mark.parametrize("allow_copy", [True, False]) +def test_get_chunks_from_col_chunks_even_chunks(allow_copy: bool) -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 5]}) + df2 = pl.DataFrame({"a": [3], "b": [6]}) + df = pl.concat([df1, df2], rechunk=False) + + dfi = PolarsDataFrame(df, allow_copy=allow_copy) + out = dfi._get_chunks_from_col_chunks() + + chunk1 = next(out) + assert_frame_equal(chunk1, df1) + + chunk2 = next(out) + assert_frame_equal(chunk2, df2) + + with pytest.raises(StopIteration): + next(out) + + +def test_get_chunks_from_col_chunks_uneven_chunks_allow_copy() -> None: + col1 = pl.concat([pl.Series([1, 2]), pl.Series([3, 4, 5])], rechunk=False) + col2 = pl.concat( + [pl.Series([6, 7]), pl.Series([8]), pl.Series([9, 0])], rechunk=False + ) + df = pl.DataFrame({"a": col1, "b": col2}) + + dfi = PolarsDataFrame(df, allow_copy=True) + out = dfi._get_chunks_from_col_chunks() + + expected1 = pl.DataFrame({"a": [1, 2], "b": [6, 7]}) + chunk1 = next(out) + assert_frame_equal(chunk1, expected1) + + expected2 = pl.DataFrame({"a": [3, 4, 5], "b": [8, 9, 0]}) + chunk2 = next(out) + assert_frame_equal(chunk2, expected2) + + with pytest.raises(StopIteration): + next(out) + + +def test_get_chunks_from_col_chunks_uneven_chunks_zero_copy_fails() -> None: + col1 = pl.concat([pl.Series([1, 2]), pl.Series([3, 4, 5])], rechunk=False) + col2 = pl.concat( + [pl.Series([6, 7]), pl.Series([8]), pl.Series([9, 0])], rechunk=False + ) + df = pl.DataFrame({"a": col1, "b": col2}) + + dfi = PolarsDataFrame(df, allow_copy=False) + out = dfi._get_chunks_from_col_chunks() + + # First chunk can be yielded zero copy + expected1 = pl.DataFrame({"a": [1, 2], "b": [6, 7]}) + chunk1 = next(out) + assert_frame_equal(chunk1, expected1) + + # Second chunk requires a rechunk of the second column + with pytest.raises(CopyNotAllowedError, match="columns must be rechunked"): + next(out) + + +def test_dataframe_unsupported_types() -> None: + df = pl.DataFrame({"a": [[4], [5, 6]]}) + dfi = PolarsDataFrame(df) + + # Generic dataframe operations work fine + assert dfi.num_rows() == 2 + + # Certain column operations also work + col = dfi.get_column_by_name("a") + assert col.num_chunks() == 1 + + # Error is raised when unsupported operations are requested + with pytest.raises(ValueError, match="not supported"): + col.dtype diff --git a/py-polars/tests/unit/interchange/test_from_dataframe.py b/py-polars/tests/unit/interchange/test_from_dataframe.py new file mode 100644 index 000000000000..e8b276939987 --- /dev/null +++ b/py-polars/tests/unit/interchange/test_from_dataframe.py @@ -0,0 +1,615 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import Any + +import pandas as pd +import pyarrow as pa +import pytest + +import polars as pl +from polars.interchange.buffer import PolarsBuffer +from polars.interchange.column import PolarsColumn +from polars.interchange.from_dataframe import ( + _categorical_column_to_series, + _column_to_series, + _construct_data_buffer, + _construct_offsets_buffer, + _construct_validity_buffer, + _construct_validity_buffer_from_bitmask, + _construct_validity_buffer_from_bytemask, + _string_column_to_series, +) +from polars.interchange.protocol import ( + ColumnNullType, + CopyNotAllowedError, + DtypeKind, + Endianness, +) +from polars.testing import assert_frame_equal, assert_series_equal + +NE = Endianness.NATIVE + + +def test_from_dataframe_polars() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]}) + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_polars_interchange_fast_path() -> None: + df = pl.DataFrame( + {"a": [1, 2], "b": [3.0, 4.0], "c": ["foo", "bar"]}, + schema_overrides={"c": pl.Categorical}, + ) + dfi = df.__dataframe__() + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(dfi, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_categorical() -> None: + df = pl.DataFrame({"a": ["foo", "bar"]}, schema={"a": pl.Categorical}) + df_pa = df.to_arrow() + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=True) + expected = pl.DataFrame({"a": ["foo", "bar"]}, schema={"a": pl.Categorical}) + assert_frame_equal(result, expected) + + +def test_from_dataframe_empty_string_zero_copy() -> None: + df = pl.DataFrame({"a": []}, schema={"a": pl.String}) + df_pa = df.to_arrow() + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_empty_bool_zero_copy() -> None: + df = pl.DataFrame(schema={"a": pl.Boolean}) + df_pd = df.to_pandas() + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pd, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_empty_categories_zero_copy() -> None: + df = pl.DataFrame(schema={"a": pl.Enum([])}) + df_pa = df.to_arrow() + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_pandas_zero_copy() -> None: + data = {"a": [1, 2], "b": [3.0, 4.0]} + + df = pd.DataFrame(data) + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df, allow_copy=False) + expected = pl.DataFrame(data) + assert_frame_equal(result, expected) + + +def test_from_dataframe_pyarrow_table_zero_copy() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [3.0, 4.0], + } + ) + df_pa = df.to_arrow() + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_pyarrow_empty_table() -> None: + df = pl.Series("a", dtype=pl.Int8).to_frame() + df_pa = df.to_arrow() + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_pyarrow_recordbatch_zero_copy() -> None: + a = pa.array([1, 2]) + b = pa.array([3.0, 4.0]) + + batch = pa.record_batch([a, b], names=["a", "b"]) + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(batch, allow_copy=False) + + expected = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + assert_frame_equal(result, expected) + + +def test_from_dataframe_invalid_type() -> None: + df = [[1, 2], [3, 4]] + with pytest.raises(TypeError): + pl.from_dataframe(df) # type: ignore[arg-type] + + +def test_from_dataframe_pyarrow_boolean() -> None: + df = pl.Series("a", [True, False]).to_frame() + df_pa = df.to_arrow() + + result = pl.from_dataframe(df_pa) + assert_frame_equal(result, df) + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, df) + + +def test_from_dataframe_chunked() -> None: + df = pl.Series("a", [0, 1], dtype=pl.Int8).to_frame() + df_chunked = pl.concat([df[:1], df[1:]], rechunk=False) + + df_pa = df_chunked.to_arrow() + result = pl.from_dataframe(df_pa, rechunk=False) + + assert_frame_equal(result, df_chunked) + assert result.n_chunks() == 2 + + +@pytest.mark.may_fail_auto_streaming +def test_from_dataframe_chunked_string() -> None: + df = pl.Series("a", ["a", None, "bc", "d", None, "efg"]).to_frame() + df_chunked = pl.concat([df[:1], df[1:3], df[3:]], rechunk=False) + + df_pa = df_chunked.to_arrow() + result = pl.from_dataframe(df_pa, rechunk=False) + + assert_frame_equal(result, df_chunked) + assert result.n_chunks() == 3 + + +def test_from_dataframe_pandas_nan_as_null() -> None: + df = pd.Series([1.0, float("nan"), float("inf")], name="a").to_frame() + result = pl.from_dataframe(df) + expected = pl.Series("a", [1.0, None, float("inf")]).to_frame() + assert_frame_equal(result, expected) + assert result.n_chunks() == 1 + + +def test_from_dataframe_pandas_boolean_bytes() -> None: + df = pd.Series([True, False], name="a").to_frame() + result = pl.from_dataframe(df) + + expected = pl.Series("a", [True, False]).to_frame() + assert_frame_equal(result, expected) + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df, allow_copy=False) + expected = pl.Series("a", [True, False]).to_frame() + assert_frame_equal(result, expected) + + +def test_from_dataframe_categorical_pandas() -> None: + values = ["a", "b", None, "a"] + + df_pd = pd.Series(values, dtype="category", name="a").to_frame() + + result = pl.from_dataframe(df_pd) + expected = pl.Series("a", values, dtype=pl.Categorical).to_frame() + assert_frame_equal(result, expected) + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pd, allow_copy=False) + expected = pl.Series("a", values, dtype=pl.Categorical).to_frame() + assert_frame_equal(result, expected) + + +def test_from_dataframe_categorical_pyarrow() -> None: + values = ["a", "b", None, "a"] + + dtype = pa.dictionary(pa.int32(), pa.utf8()) + arr = pa.array(values, dtype) + df_pa = pa.Table.from_arrays([arr], names=["a"]) + + result = pl.from_dataframe(df_pa) + expected = pl.Series("a", values, dtype=pl.Categorical).to_frame() + assert_frame_equal(result, expected) + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, expected) + + +def test_from_dataframe_categorical_non_string_keys() -> None: + values = [1, 2, None, 1] + + dtype = pa.dictionary(pa.uint32(), pa.int32()) + arr = pa.array(values, dtype) + df_pa = pa.Table.from_arrays([arr], names=["a"]) + result = pl.from_dataframe(df_pa) + expected = pl.DataFrame({"a": [1, 2, None, 1]}, schema={"a": pl.Int32}) + assert_frame_equal(result, expected) + + +def test_from_dataframe_categorical_non_u32_values() -> None: + values = [None, None] + + dtype = pa.dictionary(pa.int8(), pa.utf8()) + arr = pa.array(values, dtype) + df_pa = pa.Table.from_arrays([arr], names=["a"]) + + result = pl.from_dataframe(df_pa) + expected = pl.Series("a", values, dtype=pl.Categorical).to_frame() + assert_frame_equal(result, expected) + + with pytest.deprecated_call(match="`allow_copy` is deprecated"): + result = pl.from_dataframe(df_pa, allow_copy=False) + assert_frame_equal(result, expected) + + +class PatchableColumn(PolarsColumn): + """Helper class that allows patching certain PolarsColumn properties.""" + + describe_null: tuple[ColumnNullType, Any] = (ColumnNullType.USE_BITMASK, 0) + describe_categorical: dict[str, Any] = {} # type: ignore[assignment] # noqa: RUF012 + null_count = 0 + + +def test_column_to_series_use_sentinel_i64_min() -> None: + I64_MIN = -9223372036854775808 + dtype = pl.Datetime("us") + physical = pl.Series([0, I64_MIN]) + logical = physical.cast(dtype) + + col = PatchableColumn(logical) + col.describe_null = (ColumnNullType.USE_SENTINEL, I64_MIN) + col.null_count = 1 + + result = _column_to_series(col, dtype, allow_copy=True) + expected = pl.Series([datetime(1970, 1, 1), None]) + assert_series_equal(result, expected) + + +def test_column_to_series_duration() -> None: + s = pl.Series([timedelta(seconds=10), timedelta(days=5), None]) + col = PolarsColumn(s) + result = _column_to_series(col, s.dtype, allow_copy=True) + assert_series_equal(result, s) + + +def test_column_to_series_time() -> None: + s = pl.Series([time(10, 0), time(23, 59, 59), None]) + col = PolarsColumn(s) + result = _column_to_series(col, s.dtype, allow_copy=True) + assert_series_equal(result, s) + + +def test_column_to_series_use_sentinel_date() -> None: + mask_value = date(1900, 1, 1) + + s = pl.Series([date(1970, 1, 1), mask_value, date(2000, 1, 1)]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value) + col.null_count = 1 + + result = _column_to_series(col, pl.Date, allow_copy=True) + expected = pl.Series([date(1970, 1, 1), None, date(2000, 1, 1)]) + assert_series_equal(result, expected) + + +def test_column_to_series_use_sentinel_datetime() -> None: + dtype = pl.Datetime("ns") + mask_value = datetime(1900, 1, 1) + + s = pl.Series([datetime(1970, 1, 1), mask_value, datetime(2000, 1, 1)], dtype=dtype) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value) + col.null_count = 1 + + result = _column_to_series(col, dtype, allow_copy=True) + expected = pl.Series( + [datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype + ) + assert_series_equal(result, expected) + + +def test_column_to_series_use_sentinel_invalid_value() -> None: + dtype = pl.Datetime("ns") + mask_value = "invalid" + + s = pl.Series([datetime(1970, 1, 1), None, datetime(2000, 1, 1)], dtype=dtype) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_SENTINEL, mask_value) + col.null_count = 1 + + with pytest.raises( + TypeError, + match="invalid sentinel value for column of type Datetime\\(time_unit='ns', time_zone=None\\): 'invalid'", + ): + _column_to_series(col, dtype, allow_copy=True) + + +def test_string_column_to_series_no_offsets() -> None: + s = pl.Series([97, 98, 99]) + col = PolarsColumn(s) + with pytest.raises( + RuntimeError, + match="cannot create String column without an offsets buffer", + ): + _string_column_to_series(col, allow_copy=True) + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_column_to_series_non_dictionary() -> None: + s = pl.Series(["a", "b", None, "a"], dtype=pl.Categorical) + + col = PatchableColumn(s) + col.describe_categorical = {"is_dictionary": False} + + with pytest.raises( + NotImplementedError, match="non-dictionary categoricals are not yet supported" + ): + _categorical_column_to_series(col, allow_copy=True) + + +def test_construct_data_buffer() -> None: + data = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + buffer = PolarsBuffer(data) + dtype = (DtypeKind.INT, 64, "l", NE) + + result = _construct_data_buffer(buffer, dtype, length=5, allow_copy=True) + assert_series_equal(result, data) + + +def test_construct_data_buffer_boolean_sliced() -> None: + data = pl.Series([False, True, True, False]) + data_sliced = data[2:] + buffer = PolarsBuffer(data_sliced) + dtype = (DtypeKind.BOOL, 1, "b", NE) + + result = _construct_data_buffer(buffer, dtype, length=2, offset=2, allow_copy=True) + assert_series_equal(result, data_sliced) + + +def test_construct_data_buffer_logical_dtype() -> None: + data = pl.Series([100, 200, 300], dtype=pl.Int32) + buffer = PolarsBuffer(data) + dtype = (DtypeKind.DATETIME, 32, "tdD", NE) + + result = _construct_data_buffer(buffer, dtype, length=3, allow_copy=True) + assert_series_equal(result, data) + + +def test_construct_offsets_buffer() -> None: + data = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + buffer = PolarsBuffer(data) + dtype = (DtypeKind.INT, 64, "l", NE) + + result = _construct_offsets_buffer(buffer, dtype, offset=0, allow_copy=True) + assert_series_equal(result, data) + + +def test_construct_offsets_buffer_offset() -> None: + data = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + buffer = PolarsBuffer(data) + dtype = (DtypeKind.INT, 64, "l", NE) + offset = 2 + + result = _construct_offsets_buffer(buffer, dtype, offset=offset, allow_copy=True) + assert_series_equal(result, data[offset:]) + + +def test_construct_offsets_buffer_copy() -> None: + data = pl.Series([0, 1, 3, 3, 9], dtype=pl.UInt32) + buffer = PolarsBuffer(data) + dtype = (DtypeKind.UINT, 32, "I", NE) + + with pytest.raises(CopyNotAllowedError): + _construct_offsets_buffer(buffer, dtype, offset=0, allow_copy=False) + + result = _construct_offsets_buffer(buffer, dtype, offset=0, allow_copy=True) + expected = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + assert_series_equal(result, expected) + + +@pytest.fixture +def bitmask() -> PolarsBuffer: + data = pl.Series([False, True, True, False]) + return PolarsBuffer(data) + + +@pytest.fixture +def bytemask() -> PolarsBuffer: + data = pl.Series([0, 1, 1, 0], dtype=pl.UInt8) + return PolarsBuffer(data) + + +def test_construct_validity_buffer_non_nullable() -> None: + s = pl.Series([1, 2, 3]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.NON_NULLABLE, None) + col.null_count = 1 + + result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + assert result is None + + +def test_construct_validity_buffer_null_count() -> None: + s = pl.Series([1, 2, 3]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_SENTINEL, -1) + col.null_count = 0 + + result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + assert result is None + + +def test_construct_validity_buffer_use_bitmask(bitmask: PolarsBuffer) -> None: + s = pl.Series([1, 2, 3, 4]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_BITMASK, 0) + col.null_count = 2 + + dtype = (DtypeKind.BOOL, 1, "b", NE) + validity_buffer_info = (bitmask, dtype) + + result = _construct_validity_buffer( + validity_buffer_info, col, s.dtype, s, allow_copy=True + ) + expected = pl.Series([False, True, True, False]) + assert_series_equal(result, expected) # type: ignore[arg-type] + + result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + assert result is None + + +def test_construct_validity_buffer_use_bytemask(bytemask: PolarsBuffer) -> None: + s = pl.Series([1, 2, 3, 4]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_BYTEMASK, 0) + col.null_count = 2 + + dtype = (DtypeKind.UINT, 8, "C", NE) + validity_buffer_info = (bytemask, dtype) + + result = _construct_validity_buffer( + validity_buffer_info, col, s.dtype, s, allow_copy=True + ) + expected = pl.Series([False, True, True, False]) + assert_series_equal(result, expected) # type: ignore[arg-type] + + result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + assert result is None + + +def test_construct_validity_buffer_use_nan() -> None: + s = pl.Series([1.0, 2.0, float("nan")]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_NAN, None) + col.null_count = 1 + + result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + expected = pl.Series([True, True, False]) + assert_series_equal(result, expected) # type: ignore[arg-type] + + with pytest.raises(CopyNotAllowedError, match="bitmask must be constructed"): + _construct_validity_buffer(None, col, s.dtype, s, allow_copy=False) + + +def test_construct_validity_buffer_use_sentinel() -> None: + s = pl.Series(["a", "bc", "NULL"]) + + col = PatchableColumn(s) + col.describe_null = (ColumnNullType.USE_SENTINEL, "NULL") + col.null_count = 1 + + result = _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + expected = pl.Series([True, True, False]) + assert_series_equal(result, expected) # type: ignore[arg-type] + + with pytest.raises(CopyNotAllowedError, match="bitmask must be constructed"): + _construct_validity_buffer(None, col, s.dtype, s, allow_copy=False) + + +def test_construct_validity_buffer_unsupported() -> None: + s = pl.Series([1, 2, 3]) + + col = PatchableColumn(s) + col.describe_null = (100, None) # type: ignore[assignment] + col.null_count = 1 + + with pytest.raises(NotImplementedError, match="unsupported null type: 100"): + _construct_validity_buffer(None, col, s.dtype, s, allow_copy=True) + + +@pytest.mark.parametrize("allow_copy", [True, False]) +def test_construct_validity_buffer_from_bitmask( + allow_copy: bool, bitmask: PolarsBuffer +) -> None: + result = _construct_validity_buffer_from_bitmask( + bitmask, null_value=0, offset=0, length=4, allow_copy=allow_copy + ) + expected = pl.Series([False, True, True, False]) + assert_series_equal(result, expected) + + +def test_construct_validity_buffer_from_bitmask_inverted(bitmask: PolarsBuffer) -> None: + result = _construct_validity_buffer_from_bitmask( + bitmask, null_value=1, offset=0, length=4, allow_copy=True + ) + expected = pl.Series([True, False, False, True]) + assert_series_equal(result, expected) + + +def test_construct_validity_buffer_from_bitmask_zero_copy_fails( + bitmask: PolarsBuffer, +) -> None: + with pytest.raises(CopyNotAllowedError): + _construct_validity_buffer_from_bitmask( + bitmask, null_value=1, offset=0, length=4, allow_copy=False + ) + + +def test_construct_validity_buffer_from_bitmask_sliced() -> None: + data = pl.Series([False, True, True, False]) + data_sliced = data[2:] + bitmask = PolarsBuffer(data_sliced) + + result = _construct_validity_buffer_from_bitmask( + bitmask, null_value=0, offset=2, length=2, allow_copy=True + ) + assert_series_equal(result, data_sliced) + + +def test_construct_validity_buffer_from_bytemask(bytemask: PolarsBuffer) -> None: + result = _construct_validity_buffer_from_bytemask( + bytemask, null_value=0, allow_copy=True + ) + expected = pl.Series([False, True, True, False]) + assert_series_equal(result, expected) + + +def test_construct_validity_buffer_from_bytemask_inverted( + bytemask: PolarsBuffer, +) -> None: + result = _construct_validity_buffer_from_bytemask( + bytemask, null_value=1, allow_copy=True + ) + expected = pl.Series([True, False, False, True]) + assert_series_equal(result, expected) + + +def test_construct_validity_buffer_from_bytemask_zero_copy_fails( + bytemask: PolarsBuffer, +) -> None: + with pytest.raises(CopyNotAllowedError): + _construct_validity_buffer_from_bytemask( + bytemask, null_value=0, allow_copy=False + ) + + +def test_interchange_protocol_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + df_pd = pd.DataFrame({"a": [1, 2, 3]}) + monkeypatch.setattr(df_pd, "__arrow_c_stream__", lambda *args, **kwargs: 1 / 0) + with pytest.warns( + UserWarning, match="Falling back to Dataframe Interchange Protocol" + ): + result = pl.from_dataframe(df_pd) + expected = pl.DataFrame({"a": [1, 2, 3]}) + assert_frame_equal(result, expected) + + +def test_to_pandas_int8_20316() -> None: + df = pl.Series("a", [None], pl.Int8).to_frame() + df_pd = df.to_pandas(use_pyarrow_extension_array=True) + result = pl.from_dataframe(df_pd) + assert_frame_equal(result, df) diff --git a/py-polars/tests/unit/interchange/test_roundtrip.py b/py-polars/tests/unit/interchange/test_roundtrip.py new file mode 100644 index 000000000000..131173a343ac --- /dev/null +++ b/py-polars/tests/unit/interchange/test_roundtrip.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +import pandas as pd +import pyarrow as pa +import pyarrow.interchange +import pytest +from hypothesis import given + +import polars as pl +from polars._utils.various import parse_version +from polars.interchange.from_dataframe import ( + from_dataframe as from_dataframe_interchange_protocol, +) +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import dataframes + +skip_if_broken_pandas_version = pytest.mark.skipif( + pd.__version__.startswith("2"), reason="bug. see #20316" +) + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + +protocol_dtypes: list[PolarsDataType] = [ + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Float32, + pl.Float64, + pl.Boolean, + pl.String, + pl.Datetime, + # This is broken for empty dataframes + # TODO: Enable lexically ordered categoricals + # pl.Categorical("physical"), + # TODO: Add Enum + # pl.Enum, +] + + +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) +def test_to_dataframe_pyarrow_parametric(df: pl.DataFrame) -> None: + dfi = df.__dataframe__() + df_pa = pa.interchange.from_dataframe(dfi) + + with pl.StringCache(): + result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment] + assert_frame_equal(result, df, categorical_as_str=True) + + +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.String, # Polars String type does not match protocol spec + pl.Categorical, + ], + allow_chunks=False, + ) +) +def test_to_dataframe_pyarrow_zero_copy_parametric(df: pl.DataFrame) -> None: + dfi = df.__dataframe__(allow_copy=False) + df_pa = pa.interchange.from_dataframe(dfi, allow_copy=False) + + result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment] + assert_frame_equal(result, df, categorical_as_str=True) + + +@pytest.mark.filterwarnings( + "ignore:.*PEP3118 format string that does not match its itemsize:RuntimeWarning" +) +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) +def test_to_dataframe_pandas_parametric(df: pl.DataFrame) -> None: + dfi = df.__dataframe__() + df_pd = pd.api.interchange.from_dataframe(dfi) + result = pl.from_pandas(df_pd, nan_to_null=False) + assert_frame_equal(result, df, categorical_as_str=True) + + +@pytest.mark.filterwarnings( + "ignore:.*PEP3118 format string that does not match its itemsize:RuntimeWarning" +) +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.String, # Polars String type does not match protocol spec + pl.Categorical, + ], + allow_chunks=False, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) +def test_to_dataframe_pandas_zero_copy_parametric(df: pl.DataFrame) -> None: + dfi = df.__dataframe__(allow_copy=False) + df_pd = pd.api.interchange.from_dataframe(dfi, allow_copy=False) + result = pl.from_pandas(df_pd, nan_to_null=False) + assert_frame_equal(result, df, categorical_as_str=True) + + +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.Categorical, # Categoricals read back as Enum types + ], + ) +) +def test_from_dataframe_pyarrow_parametric(df: pl.DataFrame) -> None: + df_pa = df.to_arrow() + result = from_dataframe_interchange_protocol(df_pa) + assert_frame_equal(result, df, categorical_as_str=True) + + +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.String, # Polars String type does not match protocol spec + pl.Categorical, # Polars copies the categories to construct a mapping + pl.Boolean, # pyarrow exports boolean buffers as byte-packed: https://github.com/apache/arrow/issues/37991 + ], + allow_chunks=False, + ) +) +def test_from_dataframe_pyarrow_zero_copy_parametric(df: pl.DataFrame) -> None: + df_pa = df.to_arrow() + result = from_dataframe_interchange_protocol(df_pa, allow_copy=False) + assert_frame_equal(result, df) + + +@skip_if_broken_pandas_version +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.Categorical, # Categoricals come back as Enums + pl.Float32, # NaN values come back as nulls + pl.Float64, # NaN values come back as nulls + ], + ) +) +def test_from_dataframe_pandas_parametric(df: pl.DataFrame) -> None: + df_pd = df.to_pandas(use_pyarrow_extension_array=True) + result = from_dataframe_interchange_protocol(df_pd) + assert_frame_equal(result, df, categorical_as_str=True) + + +@skip_if_broken_pandas_version +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.String, # Polars String type does not match protocol spec + pl.Categorical, # Categoricals come back as Enums + pl.Float32, # NaN values come back as nulls + pl.Float64, # NaN values come back as nulls + pl.Boolean, # pandas exports boolean buffers as byte-packed + ], + # Empty dataframes cause an error due to a bug in pandas. + # https://github.com/pandas-dev/pandas/issues/56700 + min_size=1, + allow_chunks=False, + ) +) +def test_from_dataframe_pandas_zero_copy_parametric(df: pl.DataFrame) -> None: + df_pd = df.to_pandas(use_pyarrow_extension_array=True) + result = from_dataframe_interchange_protocol(df_pd, allow_copy=False) + assert_frame_equal(result, df) + + +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.Categorical, # Categoricals come back as Enums + pl.Float32, # NaN values come back as nulls + pl.Float64, # NaN values come back as nulls + ], + # Empty string columns cause an error due to a bug in pandas. + # https://github.com/pandas-dev/pandas/issues/56703 + min_size=1, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) +def test_from_dataframe_pandas_native_parametric(df: pl.DataFrame) -> None: + df_pd = df.to_pandas() + result = from_dataframe_interchange_protocol(df_pd) + assert_frame_equal(result, df, categorical_as_str=True) + + +@given( + dataframes( + allowed_dtypes=protocol_dtypes, + excluded_dtypes=[ + pl.String, # Polars String type does not match protocol spec + pl.Categorical, # Categoricals come back as Enums + pl.Float32, # NaN values come back as nulls + pl.Float64, # NaN values come back as nulls + pl.Boolean, # pandas exports boolean buffers as byte-packed + ], + # Empty dataframes cause an error due to a bug in pandas. + # https://github.com/pandas-dev/pandas/issues/56700 + min_size=1, + allow_chunks=False, + allow_null=False, # Bug: https://github.com/pola-rs/polars/issues/16190 + ) +) +def test_from_dataframe_pandas_native_zero_copy_parametric(df: pl.DataFrame) -> None: + df_pd = df.to_pandas() + result = from_dataframe_interchange_protocol(df_pd, allow_copy=False) + assert_frame_equal(result, df) + + +def test_to_dataframe_pandas_boolean_subchunks() -> None: + df = pl.Series("a", [False, False]).to_frame() + df_chunked = pl.concat([df[0, :], df[1, :]], rechunk=False) + dfi = df_chunked.__dataframe__() + + df_pd = pd.api.interchange.from_dataframe(dfi) + result = pl.from_pandas(df_pd, nan_to_null=False) + + assert_frame_equal(result, df) + + +def test_to_dataframe_pyarrow_boolean() -> None: + df = pl.Series("a", [True, False], dtype=pl.Boolean).to_frame() + dfi = df.__dataframe__() + + df_pa = pa.interchange.from_dataframe(dfi) + result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment] + + assert_frame_equal(result, df) + + +def test_to_dataframe_pyarrow_boolean_midbyte_slice() -> None: + s = pl.Series("a", [False] * 9)[3:] + df = s.to_frame() + dfi = df.__dataframe__() + + df_pa = pa.interchange.from_dataframe(dfi) + result: pl.DataFrame = pl.from_arrow(df_pa) # type: ignore[assignment] + + assert_frame_equal(result, df) + + +@pytest.mark.skipif( + parse_version(pd.__version__) < (2, 2), + reason="Pandas versions < 2.2 do not implement the required conversions", +) +def test_from_dataframe_pandas_timestamp_ns() -> None: + df = pl.Series("a", [datetime(2000, 1, 1)], dtype=pl.Datetime("ns")).to_frame() + df_pd = df.to_pandas(use_pyarrow_extension_array=True) + result = pl.from_dataframe(df_pd) + assert_frame_equal(result, df) + + +def test_from_pyarrow_str_dict_with_null_values_20270() -> None: + tb = pa.table( + { + "col1": pa.DictionaryArray.from_arrays( + [0, 0, None, 1, 2], ["A", None, "B"] + ), + }, + schema=pa.schema({"col1": pa.dictionary(pa.uint32(), pa.string())}), + ) + df = pl.from_arrow(tb) + assert isinstance(df, pl.DataFrame) + + assert_series_equal( + df.to_series(), pl.Series("col1", ["A", "A", None, None, "B"], pl.Categorical) + ) + assert_series_equal( + df.select(pl.col.col1.cat.get_categories()).to_series(), + pl.Series(["A", "B"]), + check_names=False, + ) diff --git a/py-polars/tests/unit/interchange/test_utils.py b/py-polars/tests/unit/interchange/test_utils.py new file mode 100644 index 000000000000..527de10ffc1e --- /dev/null +++ b/py-polars/tests/unit/interchange/test_utils.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.interchange.protocol import DtypeKind, Endianness +from polars.interchange.utils import ( + dtype_to_polars_dtype, + get_buffer_length_in_elements, + polars_dtype_to_data_buffer_dtype, + polars_dtype_to_dtype, +) + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + from polars.interchange.protocol import Dtype + +NE = Endianness.NATIVE + + +@pytest.mark.parametrize( + ("polars_dtype", "dtype"), + [ + (pl.Int8, (DtypeKind.INT, 8, "c", NE)), + (pl.Int16, (DtypeKind.INT, 16, "s", NE)), + (pl.Int32, (DtypeKind.INT, 32, "i", NE)), + (pl.Int64, (DtypeKind.INT, 64, "l", NE)), + (pl.UInt8, (DtypeKind.UINT, 8, "C", NE)), + (pl.UInt16, (DtypeKind.UINT, 16, "S", NE)), + (pl.UInt32, (DtypeKind.UINT, 32, "I", NE)), + (pl.UInt64, (DtypeKind.UINT, 64, "L", NE)), + (pl.Float32, (DtypeKind.FLOAT, 32, "f", NE)), + (pl.Float64, (DtypeKind.FLOAT, 64, "g", NE)), + (pl.Boolean, (DtypeKind.BOOL, 1, "b", NE)), + (pl.String, (DtypeKind.STRING, 8, "U", NE)), + (pl.Date, (DtypeKind.DATETIME, 32, "tdD", NE)), + (pl.Time, (DtypeKind.DATETIME, 64, "ttu", NE)), + (pl.Duration, (DtypeKind.DATETIME, 64, "tDu", NE)), + (pl.Duration(time_unit="ns"), (DtypeKind.DATETIME, 64, "tDn", NE)), + (pl.Datetime, (DtypeKind.DATETIME, 64, "tsu:", NE)), + (pl.Datetime(time_unit="ms"), (DtypeKind.DATETIME, 64, "tsm:", NE)), + ( + pl.Datetime(time_zone="Amsterdam/Europe"), + (DtypeKind.DATETIME, 64, "tsu:Amsterdam/Europe", NE), + ), + ( + pl.Datetime(time_unit="ns", time_zone="Asia/Seoul"), + (DtypeKind.DATETIME, 64, "tsn:Asia/Seoul", NE), + ), + ], +) +def test_dtype_conversions(polars_dtype: PolarsDataType, dtype: Dtype) -> None: + assert polars_dtype_to_dtype(polars_dtype) == dtype + assert dtype_to_polars_dtype(dtype) == polars_dtype + + +@pytest.mark.parametrize( + "dtype", + [ + (DtypeKind.CATEGORICAL, 32, "I", NE), + (DtypeKind.CATEGORICAL, 8, "C", NE), + ], +) +def test_dtype_to_polars_dtype_categorical(dtype: Dtype) -> None: + assert dtype_to_polars_dtype(dtype) == pl.Enum + + +@pytest.mark.parametrize( + "polars_dtype", + [ + pl.Categorical, + pl.Categorical("lexical"), + pl.Enum, + pl.Enum(["a", "b"]), + ], +) +def test_polars_dtype_to_dtype_categorical(polars_dtype: PolarsDataType) -> None: + assert polars_dtype_to_dtype(polars_dtype) == (DtypeKind.CATEGORICAL, 32, "I", NE) + + +def test_polars_dtype_to_dtype_unsupported_type() -> None: + polars_dtype = pl.List(pl.Int8) + with pytest.raises(ValueError, match="not supported"): + polars_dtype_to_dtype(polars_dtype) + + +def test_dtype_to_polars_dtype_unsupported_type() -> None: + dtype = (DtypeKind.FLOAT, 16, "e", NE) + with pytest.raises( + NotImplementedError, + match="unsupported data type: \\(, 16, 'e', '='\\)", + ): + dtype_to_polars_dtype(dtype) + + +def test_dtype_to_polars_dtype_unsupported_temporal_type() -> None: + dtype = (DtypeKind.DATETIME, 64, "tss:", NE) + with pytest.raises( + NotImplementedError, + match="unsupported temporal data type: \\(, 64, 'tss:', '='\\)", + ): + dtype_to_polars_dtype(dtype) + + +@pytest.mark.parametrize( + ("dtype", "expected"), + [ + ((DtypeKind.INT, 64, "l", NE), 3), + ((DtypeKind.UINT, 32, "I", NE), 6), + ], +) +def test_get_buffer_length_in_elements(dtype: Dtype, expected: int) -> None: + assert get_buffer_length_in_elements(24, dtype) == expected + + +def test_get_buffer_length_in_elements_unsupported_dtype() -> None: + dtype = (DtypeKind.BOOL, 1, "b", NE) + with pytest.raises( + ValueError, + match="cannot get buffer length for buffer with dtype \\(, 1, 'b', '='\\)", + ): + get_buffer_length_in_elements(24, dtype) + + +@pytest.mark.parametrize( + ("dtype", "expected"), + [ + (pl.Int8, pl.Int8), + (pl.Date, pl.Int32), + (pl.Time, pl.Int64), + (pl.String, pl.UInt8), + (pl.Enum, pl.UInt32), + ], +) +def test_polars_dtype_to_data_buffer_dtype( + dtype: PolarsDataType, expected: PolarsDataType +) -> None: + assert polars_dtype_to_data_buffer_dtype(dtype) == expected + + +def test_polars_dtype_to_data_buffer_dtype_unsupported_dtype() -> None: + dtype = pl.List(pl.Int8) + with pytest.raises(NotImplementedError): + polars_dtype_to_data_buffer_dtype(dtype) diff --git a/py-polars/tests/unit/interop/__init__.py b/py-polars/tests/unit/interop/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/interop/numpy/__init__.py b/py-polars/tests/unit/interop/numpy/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/interop/numpy/test_array_method.py b/py-polars/tests/unit/interop/numpy/test_array_method.py new file mode 100644 index 000000000000..d7cbef24a01b --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_array_method.py @@ -0,0 +1,37 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import polars as pl + + +def test_series_array_method_copy_false() -> None: + s = pl.Series([1, 2, None]) + with pytest.raises(RuntimeError, match="copy not allowed"): + s.__array__(copy=False) + + result = s.__array__(copy=None) + expected = np.array([1.0, 2.0, np.nan]) + assert_array_equal(result, expected) + + +@pytest.mark.parametrize("copy", [True, False]) +def test_series_array_method_copy_zero_copy(copy: bool) -> None: + s = pl.Series([1, 2, 3]) + result = s.__array__(copy=copy) + + assert result.flags.writeable is copy + + +def test_df_array_method() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + + out_array = np.asarray(df, order="F") + expected_array = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float64) + assert_array_equal(out_array, expected_array) + assert out_array.flags["F_CONTIGUOUS"] is True + + out_array = np.asarray(df, dtype=np.uint8, order="C") + expected_array = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.uint8) + assert_array_equal(out_array, expected_array) + assert out_array.flags["C_CONTIGUOUS"] is True diff --git a/py-polars/tests/unit/interop/numpy/test_from_numpy_df.py b/py-polars/tests/unit/interop/numpy/test_from_numpy_df.py new file mode 100644 index 000000000000..d7003cb9fcf6 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_from_numpy_df.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + import numpy.typing as npt + + from polars._typing import PolarsDataType, PolarsTemporalType + + +def test_from_numpy() -> None: + data = np.array([[1, 2, 3], [4, 5, 6]]) + df = pl.from_numpy( + data, + schema=["a", "b"], + orient="col", + schema_overrides={"a": pl.UInt32, "b": pl.UInt32}, + ) + assert df.shape == (3, 2) + assert df.rows() == [(1, 4), (2, 5), (3, 6)] + assert df.schema == {"a": pl.UInt32, "b": pl.UInt32} + data2 = np.array(["foo", "bar"], dtype=object) + df2 = pl.from_numpy(data2) + assert df2.shape == (2, 1) + assert df2.rows() == [("foo",), ("bar",)] + assert df2.schema == {"column_0": pl.String} + with pytest.raises( + ValueError, + match="cannot create DataFrame from array with more than two dimensions", + ): + _ = pl.from_numpy(np.array([[[1]]])) + with pytest.raises( + ValueError, match="cannot create DataFrame from zero-dimensional array" + ): + _ = pl.from_numpy(np.array(1)) + + +def test_from_numpy_array_value() -> None: + df = pl.DataFrame({"A": [[2, 3]]}) + assert df.rows() == [([2, 3],)] + assert df.schema == {"A": pl.List(pl.Int64)} + + +def test_construct_from_ndarray_value() -> None: + array_cell = np.array([2, 3]) + df = pl.DataFrame(np.array([[array_cell, 4]], dtype=object)) + assert df.dtypes == [pl.Object, pl.Object] + to_numpy = df.to_numpy() + assert to_numpy.shape == (1, 2) + assert_array_equal(to_numpy[0][0], array_cell) + assert to_numpy[0][1] == 4 + + +def test_from_numpy_nparray_value() -> None: + array_cell = np.array([2, 3]) + df = pl.from_numpy(np.array([[array_cell, 4]], dtype=object)) + assert df.dtypes == [pl.Object, pl.Object] + to_numpy = df.to_numpy() + assert to_numpy.shape == (1, 2) + assert_array_equal(to_numpy[0][0], array_cell) + assert to_numpy[0][1] == 4 + + +def test_from_numpy_structured() -> None: + test_data = [ + ("Google Pixel 7", 521.90, True), + ("Apple iPhone 14 Pro", 999.00, True), + ("Samsung Galaxy S23 Ultra", 1199.99, False), + ("OnePlus 11", 699.00, True), + ] + # create a numpy structured array... + arr_structured = np.array( + test_data, + dtype=np.dtype( + [ + ("product", "U32"), + ("price_usd", "float64"), + ("in_stock", "bool"), + ] + ), + ) + # ...and also establish as a record array view + arr_records = arr_structured.view(np.recarray) + + # confirm that we can cleanly initialise a DataFrame from both, + # respecting the native dtypes and any schema overrides, etc. + for arr in (arr_structured, arr_records): + df = pl.DataFrame(data=arr).sort(by="price_usd", descending=True) + + assert df.schema == { + "product": pl.String, + "price_usd": pl.Float64, + "in_stock": pl.Boolean, + } + assert df.rows() == sorted(test_data, key=lambda row: -row[1]) + + for df in ( + pl.DataFrame( + data=arr, schema=["phone", ("price_usd", pl.Float32), "available"] + ), + pl.DataFrame( + data=arr, + schema=["phone", "price_usd", "available"], + schema_overrides={"price_usd": pl.Float32}, + ), + ): + assert df.schema == { + "phone": pl.String, + "price_usd": pl.Float32, + "available": pl.Boolean, + } + + +def test_from_numpy2() -> None: + # note: numpy timeunit support is limited to those supported by polars. + # as a result, datetime64[s] raises + x = np.asarray(range(100_000, 200_000, 10_000), dtype="datetime64[s]") + with pytest.raises(ValueError, match="Please cast to the closest supported unit"): + pl.Series(x) + + +@pytest.mark.parametrize( + ("numpy_time_unit", "expected_values", "expected_dtype"), + [ + ("ns", ["1970-01-02T01:12:34.123456789"], pl.Datetime("ns")), + ("us", ["1970-01-02T01:12:34.123456"], pl.Datetime("us")), + ("ms", ["1970-01-02T01:12:34.123"], pl.Datetime("ms")), + ("D", ["1970-01-02"], pl.Date), + ], +) +def test_from_numpy_supported_units( + numpy_time_unit: str, + expected_values: list[str], + expected_dtype: PolarsTemporalType, +) -> None: + values = np.array( + ["1970-01-02T01:12:34.123456789123456789"], + dtype=f"datetime64[{numpy_time_unit}]", + ) + result = pl.from_numpy(values) + expected = ( + pl.Series("column_0", expected_values).str.strptime(expected_dtype).to_frame() + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("np_dtype", "dtype"), + [ + (np.float64, pl.Float64), + (np.int32, pl.Int32), + ], +) +def test_from_numpy_empty(np_dtype: npt.DTypeLike, dtype: PolarsDataType) -> None: + data = np.array([], dtype=np_dtype) + result = pl.from_numpy(data, schema=["a"]) + expected = pl.Series("a", [], dtype=dtype).to_frame() + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py new file mode 100644 index 000000000000..4c7f3d8c14d2 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_from_numpy_series.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl + +if TYPE_CHECKING: + from polars._typing import TimeUnit + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_from_numpy_timedelta(time_unit: TimeUnit) -> None: + s = pl.Series( + "name", + np.array( + [timedelta(days=1), timedelta(seconds=1)], dtype=f"timedelta64[{time_unit}]" + ), + ) + assert s.dtype == pl.Duration(time_unit) + assert s.name == "name" + assert s.dt[0] == timedelta(days=1) + assert s.dt[1] == timedelta(seconds=1) diff --git a/py-polars/tests/unit/interop/numpy/test_numpy.py b/py-polars/tests/unit/interop/numpy/test_numpy.py new file mode 100644 index 000000000000..1b5cf0aac019 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_numpy.py @@ -0,0 +1,102 @@ +from typing import Any + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import polars as pl + + +@pytest.fixture( + params=[ + ("int8", [1, 3, 2], pl.Int8, np.int8), + ("int16", [1, 3, 2], pl.Int16, np.int16), + ("int32", [1, 3, 2], pl.Int32, np.int32), + ("int64", [1, 3, 2], pl.Int64, np.int64), + ("uint8", [1, 3, 2], pl.UInt8, np.uint8), + ("uint16", [1, 3, 2], pl.UInt16, np.uint16), + ("uint32", [1, 3, 2], pl.UInt32, np.uint32), + ("uint64", [1, 3, 2], pl.UInt64, np.uint64), + ("float16", [-123.0, 0.0, 456.0], pl.Float32, np.float16), + ("float32", [21.7, 21.8, 21], pl.Float32, np.float32), + ("float64", [21.7, 21.8, 21], pl.Float64, np.float64), + ("bool", [True, False, False], pl.Boolean, np.bool_), + ("object", [21.7, "string1", object()], pl.Object, np.object_), + ("str", ["string1", "string2", "string3"], pl.String, np.str_), + ("intc", [1, 3, 2], pl.Int32, np.intc), + ("uintc", [1, 3, 2], pl.UInt32, np.uintc), + ("str_fixed", ["string1", "string2", "string3"], pl.String, np.str_), + ( + "bytes", + [b"byte_string1", b"byte_string2", b"byte_string3"], + pl.Binary, + np.bytes_, + ), + ] +) +def numpy_interop_test_data(request: Any) -> Any: + return request.param + + +def test_df_from_numpy(numpy_interop_test_data: Any) -> None: + name, values, pl_dtype, np_dtype = numpy_interop_test_data + df = pl.DataFrame({name: np.array(values, dtype=np_dtype)}) + assert [pl_dtype] == df.dtypes + + +def test_asarray(numpy_interop_test_data: Any) -> None: + name, values, pl_dtype, np_dtype = numpy_interop_test_data + pl_series_to_numpy_array = np.asarray(pl.Series(name, values, pl_dtype)) + numpy_array = np.asarray(values, dtype=np_dtype) + assert_array_equal(pl_series_to_numpy_array, numpy_array) + + +def test_to_numpy(numpy_interop_test_data: Any) -> None: + name, values, pl_dtype, np_dtype = numpy_interop_test_data + pl_series_to_numpy_array = pl.Series(name, values, pl_dtype).to_numpy() + numpy_array = np.asarray(values, dtype=np_dtype) + assert_array_equal(pl_series_to_numpy_array, numpy_array) + + +def test_numpy_to_lit() -> None: + out = pl.select(pl.lit(np.array([1, 2, 3]))).to_series().to_list() + assert out == [1, 2, 3] + out = pl.select(pl.lit(np.float32(0))).to_series().to_list() + assert out == [0.0] + + +def test_numpy_disambiguation() -> None: + a = np.array([1, 2]) + df = pl.DataFrame({"a": a}) + result = df.with_columns(b=a).to_dict(as_series=False) # type: ignore[arg-type] + expected = { + "a": [1, 2], + "b": [1, 2], + } + assert result == expected + + +def test_respect_dtype_with_series_from_numpy() -> None: + assert pl.Series("foo", np.array([1, 2, 3]), dtype=pl.UInt32).dtype == pl.UInt32 + + +@pytest.mark.parametrize( + ("np_dtype_cls", "expected_pl_dtype"), + [ + (np.int8, pl.Int8), + (np.int16, pl.Int16), + (np.int32, pl.Int32), + (np.int64, pl.Int64), + (np.uint8, pl.UInt8), + (np.uint16, pl.UInt16), + (np.uint32, pl.UInt32), + (np.uint64, pl.UInt64), + (np.float16, pl.Float32), # << note: we don't currently have a native f16 + (np.float32, pl.Float32), + (np.float64, pl.Float64), + ], +) +def test_init_from_numpy_values(np_dtype_cls: Any, expected_pl_dtype: Any) -> None: + # test init from raw numpy values (vs arrays) + s = pl.Series("n", [np_dtype_cls(0), np_dtype_cls(4), np_dtype_cls(8)]) + assert s.dtype == expected_pl_dtype diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py new file mode 100644 index 000000000000..f45edc05cd0b --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal as D +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from hypothesis import given +from numpy.testing import assert_array_equal, assert_equal + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric import series + +if TYPE_CHECKING: + import numpy.typing as npt + + from polars._typing import IndexOrder, PolarsDataType + + +def assert_zero_copy(s: pl.Series, arr: np.ndarray[Any, Any]) -> None: + if s.len() == 0: + return + s_ptr = s._get_buffers()["values"]._get_buffer_info()[0] + arr_ptr = arr.__array_interface__["data"][0] + assert s_ptr == arr_ptr + + +@given( + s=series( + min_size=6, + max_size=6, + allowed_dtypes=[pl.Datetime, pl.Duration], + allow_null=False, + allow_chunks=False, + ) +) +@pytest.mark.may_fail_auto_streaming +def test_df_to_numpy_zero_copy(s: pl.Series) -> None: + df = pl.DataFrame({"a": s[:3], "b": s[3:]}) + + result = df.to_numpy(allow_copy=False) + + assert_zero_copy(s, result) + assert result.flags.writeable is False + + +@pytest.mark.parametrize( + ("order", "f_contiguous", "c_contiguous"), + [ + ("fortran", True, False), + ("c", False, True), + ], +) +def test_to_numpy(order: IndexOrder, f_contiguous: bool, c_contiguous: bool) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + + out_array = df.to_numpy(order=order) + expected_array = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float64) + assert_array_equal(out_array, expected_array) + assert out_array.flags["F_CONTIGUOUS"] == f_contiguous + assert out_array.flags["C_CONTIGUOUS"] == c_contiguous + + structured_array = df.to_numpy(structured=True, order=order) + expected_array = np.array( + [(1, 1.0), (2, 2.0), (3, 3.0)], dtype=[("a", " None: + # round-trip structured array: validate init/export + structured_array = np.array( + [ + ("Google Pixel 7", 521.90, True), + ("Apple iPhone 14 Pro", 999.00, True), + ("OnePlus 11", 699.00, True), + ("Samsung Galaxy S23 Ultra", 1199.99, False), + ], + dtype=np.dtype( + [ + ("product", "U24"), + ("price_usd", "float64"), + ("in_stock", "bool"), + ] + ), + ) + df = pl.from_numpy(structured_array) + assert df.schema == { + "product": pl.String, + "price_usd": pl.Float64, + "in_stock": pl.Boolean, + } + exported_array = df.to_numpy(structured=True) + assert exported_array["product"].dtype == np.dtype("U24") + assert_array_equal(exported_array, structured_array) + + # none/nan values + df = pl.DataFrame({"x": ["a", None, "b"], "y": [5.5, None, -5.5]}) + exported_array = df.to_numpy(structured=True) + + assert exported_array.dtype == np.dtype([("x", object), ("y", float)]) + for name in df.columns: + assert_equal( + list(exported_array[name]), + ( + df[name].fill_null(float("nan")) + if df.schema[name].is_float() + else df[name] + ).to_list(), + ) + + +def test_numpy_preserve_uint64_4112() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}).with_columns(pl.col("a").hash()) + assert df.to_numpy().dtype == np.dtype("uint64") + assert df.to_numpy(structured=True).dtype == np.dtype([("a", "uint64")]) + + +def test_df_to_numpy_decimal() -> None: + decimal_data = [D("1.234"), D("2.345"), D("-3.456")] + df = pl.Series("n", decimal_data).to_frame() + + result = df.to_numpy() + + expected = np.array(decimal_data).reshape((-1, 1)) + assert_array_equal(result, expected) + + +def test_df_to_numpy_zero_copy_path() -> None: + rows = 10 + cols = 5 + x = np.ones((rows, cols), order="F") + x[:, 1] = 2.0 + df = pl.DataFrame(x) + x = df.to_numpy(allow_copy=False) + assert x.flags.f_contiguous is True + assert x.flags.writeable is False + assert str(x[0, :]) == "[1. 2. 1. 1. 1.]" + + +@pytest.mark.may_fail_auto_streaming +def test_df_to_numpy_zero_copy_path_temporal() -> None: + values = [datetime(1970 + i, 1, 1) for i in range(12)] + s = pl.Series(values) + df = pl.DataFrame({"a": s[:4], "b": s[4:8], "c": s[8:]}) + + result: npt.NDArray[np.generic] = df.to_numpy(allow_copy=False) + assert result.flags.f_contiguous is True + assert result.flags.writeable is False + assert result.tolist() == [list(row) for row in df.iter_rows()] + + +def test_to_numpy_zero_copy_path_writable() -> None: + rows = 10 + cols = 5 + x = np.ones((rows, cols), order="F") + x[:, 1] = 2.0 + df = pl.DataFrame(x) + x = df.to_numpy(writable=True) + assert x.flags["WRITEABLE"] + + +def test_df_to_numpy_structured_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2]}) + msg = "cannot create structured array without copying data" + with pytest.raises(RuntimeError, match=msg): + df.to_numpy(structured=True, allow_copy=False) + + +def test_df_to_numpy_writable_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2]}) + msg = "copy not allowed: cannot create a writable array without copying data" + with pytest.raises(RuntimeError, match=msg): + df.to_numpy(allow_copy=False, writable=True) + + +def test_df_to_numpy_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2, None]}) + with pytest.raises(RuntimeError): + df.to_numpy(allow_copy=False) + + +@pytest.mark.parametrize( + ("schema", "expected_dtype"), + [ + ({"a": pl.Int8, "b": pl.Int8}, np.int8), + ({"a": pl.Int8, "b": pl.UInt16}, np.int32), + ({"a": pl.Int8, "b": pl.String}, np.object_), + ], +) +def test_df_to_numpy_empty_dtype_viewable( + schema: dict[str, PolarsDataType], expected_dtype: npt.DTypeLike +) -> None: + df = pl.DataFrame(schema=schema) + result = df.to_numpy(allow_copy=False) + assert result.shape == (0, 2) + assert result.dtype == expected_dtype + assert result.flags.writeable is True + + +def test_df_to_numpy_structured_nested() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [3.0, 4.0], + "c": [{"x": "a", "y": 1.0}, {"x": "b", "y": 2.0}], + } + ) + result = df.to_numpy(structured=True) + + expected = np.array( + [ + (1, 3.0, ("a", 1.0)), + (2, 4.0, ("b", 2.0)), + ], + dtype=[ + ("a", " None: + df = pl.DataFrame( + {"a": [[1, 2]], "b": 1}, + schema={"a": pl.Array(pl.Int64, 2), "b": pl.Int32}, + ) + result = df.to_numpy() + + expected = np.array([[np.array([1, 2]), 1]], dtype=np.object_) + + assert result.shape == (1, 2) + assert result[0].shape == (2,) + assert_array_equal(result[0][0], expected[0][0]) + + +@pytest.mark.parametrize("order", ["c", "fortran"]) +def test_df_to_numpy_stacking_string(order: IndexOrder) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + result = df.to_numpy(order=order) + + expected = np.array([[1, "x"], [2, "y"], [3, "z"]], dtype=np.object_) + + assert_array_equal(result, expected) + if order == "c": + assert result.flags.c_contiguous is True + else: + assert result.flags.f_contiguous is True + + +def test_to_numpy_chunked_16375() -> None: + assert ( + pl.concat( + [ + pl.DataFrame({"a": [1, 1, 2], "b": [2, 3, 4]}), + pl.DataFrame({"a": [1, 1, 2], "b": [2, 3, 4]}), + ], + rechunk=False, + ).to_numpy() + == np.array([[1, 2], [1, 3], [2, 4], [1, 2], [1, 3], [2, 4]]) + ).all() + + +def test_to_numpy_c_order_1700() -> None: + rng = np.random.default_rng() + df = pl.DataFrame({f"col_{i}": rng.normal(size=20) for i in range(3)}) + df_chunked = pl.concat([df.slice(i * 10, 10) for i in range(3)]) + assert_frame_equal( + df_chunked, + pl.from_numpy(df_chunked.to_numpy(order="c"), schema=df_chunked.schema), + ) diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py new file mode 100644 index 000000000000..663640641059 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -0,0 +1,486 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal as D +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from hypothesis import given, settings +from numpy.testing import assert_array_equal + +import polars as pl +from polars.testing import assert_series_equal +from polars.testing.parametric import series + +if TYPE_CHECKING: + import numpy.typing as npt + + from polars._typing import PolarsDataType + + +def assert_zero_copy(s: pl.Series, arr: np.ndarray[Any, Any]) -> None: + if s.len() == 0: + return + s_ptr = s._get_buffers()["values"]._get_buffer_info()[0] + arr_ptr = arr.__array_interface__["data"][0] + assert s_ptr == arr_ptr + + +def assert_allow_copy_false_raises(s: pl.Series) -> None: + with pytest.raises(RuntimeError, match="copy not allowed"): + s.to_numpy(allow_copy=False) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Int8, np.int8), + (pl.Int16, np.int16), + (pl.Int32, np.int32), + (pl.Int64, np.int64), + (pl.UInt8, np.uint8), + (pl.UInt16, np.uint16), + (pl.UInt32, np.uint32), + (pl.UInt64, np.uint64), + (pl.Float32, np.float32), + (pl.Float64, np.float64), + ], +) +def test_series_to_numpy_numeric_zero_copy( + dtype: PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + s = pl.Series([1, 2, 3]).cast(dtype) + result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False) + + assert_zero_copy(s, result) + assert result.tolist() == s.to_list() + assert result.dtype == expected_dtype + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Int8, np.float32), + (pl.Int16, np.float32), + (pl.Int32, np.float64), + (pl.Int64, np.float64), + (pl.UInt8, np.float32), + (pl.UInt16, np.float32), + (pl.UInt32, np.float64), + (pl.UInt64, np.float64), + (pl.Float32, np.float32), + (pl.Float64, np.float64), + ], +) +def test_series_to_numpy_numeric_with_nulls( + dtype: PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + s = pl.Series([1, 2, None], dtype=dtype, strict=False) + result: npt.NDArray[np.generic] = s.to_numpy() + + assert result.tolist()[:-1] == s.to_list()[:-1] + assert np.isnan(result[-1]) + assert result.dtype == expected_dtype + assert_allow_copy_false_raises(s) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Duration, np.dtype("timedelta64[us]")), + (pl.Duration("ms"), np.dtype("timedelta64[ms]")), + (pl.Duration("us"), np.dtype("timedelta64[us]")), + (pl.Duration("ns"), np.dtype("timedelta64[ns]")), + (pl.Datetime, np.dtype("datetime64[us]")), + (pl.Datetime("ms"), np.dtype("datetime64[ms]")), + (pl.Datetime("us"), np.dtype("datetime64[us]")), + (pl.Datetime("ns"), np.dtype("datetime64[ns]")), + ], +) +def test_series_to_numpy_temporal_zero_copy( + dtype: PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + values = [0, 2_000, 1_000_000] + s = pl.Series(values, dtype=dtype, strict=False) + result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False) + + assert_zero_copy(s, result) + # NumPy tolist returns integers for ns precision + if s.dtype.time_unit == "ns": # type: ignore[attr-defined] + assert result.tolist() == values + else: + assert result.tolist() == s.to_list() + assert result.dtype == expected_dtype + + +def test_series_to_numpy_datetime_with_tz_zero_copy() -> None: + values = [datetime(1970, 1, 1), datetime(2024, 2, 28)] + s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam").rechunk() + result: npt.NDArray[np.generic] = s.to_numpy(allow_copy=False) + + assert_zero_copy(s, result) + assert result.tolist() == values + assert result.dtype == np.dtype("datetime64[us]") + + +def test_series_to_numpy_date() -> None: + values = [date(1970, 1, 1), date(2024, 2, 28)] + s = pl.Series(values) + + result: npt.NDArray[np.generic] = s.to_numpy() + + assert s.to_list() == result.tolist() + assert result.dtype == np.dtype("datetime64[D]") + assert result.flags.writeable is True + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_multi_dimensional_init() -> None: + s = pl.Series(np.atleast_3d(np.array([-10.5, 0.0, 10.5]))) + assert_series_equal( + s, + pl.Series( + [[[-10.5], [0.0], [10.5]]], + dtype=pl.Array(pl.Float64, shape=(3, 1)), + ), + ) + s = pl.Series(np.array(0), dtype=pl.Int32) + assert_series_equal(s, pl.Series([0], dtype=pl.Int32)) + + +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (pl.Date, np.dtype("datetime64[D]")), + (pl.Duration("ms"), np.dtype("timedelta64[ms]")), + (pl.Duration("us"), np.dtype("timedelta64[us]")), + (pl.Duration("ns"), np.dtype("timedelta64[ns]")), + (pl.Datetime, np.dtype("datetime64[us]")), + (pl.Datetime("ms"), np.dtype("datetime64[ms]")), + (pl.Datetime("us"), np.dtype("datetime64[us]")), + (pl.Datetime("ns"), np.dtype("datetime64[ns]")), + ], +) +def test_series_to_numpy_temporal_with_nulls( + dtype: PolarsDataType, expected_dtype: npt.DTypeLike +) -> None: + values = [0, 2_000, 1_000_000, None] + s = pl.Series(values, dtype=dtype, strict=False) + result: npt.NDArray[np.generic] = s.to_numpy() + + # NumPy tolist returns integers for ns precision + if getattr(s.dtype, "time_unit", None) == "ns": + assert result.tolist() == values + else: + assert result.tolist() == s.to_list() + assert result.dtype == expected_dtype + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_datetime_with_tz_with_nulls() -> None: + values = [datetime(1970, 1, 1), datetime(2024, 2, 28), None] + s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam") + result: npt.NDArray[np.generic] = s.to_numpy() + + assert result.tolist() == values + assert result.dtype == np.dtype("datetime64[us]") + assert_allow_copy_false_raises(s) + + +@pytest.mark.parametrize( + ("dtype", "values"), + [ + (pl.Time, [time(10, 30, 45), time(23, 59, 59)]), + (pl.Categorical, ["a", "b", "a"]), + (pl.Enum(["a", "b", "c"]), ["a", "b", "a"]), + (pl.String, ["a", "bc", "def"]), + (pl.Binary, [b"a", b"bc", b"def"]), + (pl.Decimal, [D("1.234"), D("2.345"), D("-3.456")]), + (pl.Object, [Path(), Path("abc")]), + ], +) +@pytest.mark.parametrize("with_nulls", [False, True]) +def test_to_numpy_object_dtypes( + dtype: PolarsDataType, values: list[Any], with_nulls: bool +) -> None: + if with_nulls: + values.append(None) + + s = pl.Series(values, dtype=dtype) + result: npt.NDArray[np.generic] = s.to_numpy() + + assert result.tolist() == values + assert result.dtype == np.object_ + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_bool() -> None: + s = pl.Series([True, False]) + result: npt.NDArray[np.generic] = s.to_numpy() + + assert s.to_list() == result.tolist() + assert result.dtype == np.bool_ + assert result.flags.writeable is True + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_bool_with_nulls() -> None: + s = pl.Series([True, False, None]) + result: npt.NDArray[np.generic] = s.to_numpy() + + assert s.to_list() == result.tolist() + assert result.dtype == np.object_ + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_array_of_int() -> None: + values = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]] + s = pl.Series(values, dtype=pl.Array(pl.Array(pl.Int8, 3), 2)) + result = s.to_numpy(allow_copy=False) + + expected = np.array(values) + assert_array_equal(result, expected) + assert result.dtype == np.int8 + assert result.shape == (2, 2, 3) + + +def test_series_to_numpy_array_of_str() -> None: + values = [["1", "2", "3"], ["4", "5", "10000"]] + s = pl.Series(values, dtype=pl.Array(pl.String, 3)) + result: npt.NDArray[np.generic] = s.to_numpy() + assert result.tolist() == values + assert result.dtype == np.object_ + + +def test_series_to_numpy_array_with_nulls() -> None: + values = [[1, 2], [3, 4], None] + s = pl.Series(values, dtype=pl.Array(pl.Int64, 2)) + result = s.to_numpy() + + expected = np.array([[1.0, 2.0], [3.0, 4.0], [np.nan, np.nan]]) + assert_array_equal(result, expected) + assert result.dtype == np.float64 + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_array_with_nested_nulls() -> None: + values = [[None, 2], [3, 4], [5, None]] + s = pl.Series(values, dtype=pl.Array(pl.Int64, 2)) + result = s.to_numpy() + + expected = np.array([[np.nan, 2.0], [3.0, 4.0], [5.0, np.nan]]) + assert_array_equal(result, expected) + assert result.dtype == np.float64 + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_array_of_arrays() -> None: + values = [[[None, 2], [3, 4]], [None, [7, 8]]] + s = pl.Series(values, dtype=pl.Array(pl.Array(pl.Int64, 2), 2)) + result = s.to_numpy() + + expected = np.array([[[np.nan, 2], [3, 4]], [[np.nan, np.nan], [7, 8]]]) + assert_array_equal(result, expected) + assert result.dtype == np.float64 + assert result.shape == (2, 2, 2) + assert_allow_copy_false_raises(s) + + +@pytest.mark.parametrize("chunked", [True, False]) +def test_series_to_numpy_list(chunked: bool) -> None: + values = [[1, 2], [3, 4, 5], [6], []] + s = pl.Series(values) + if chunked: + s = pl.concat([s[:2], s[2:]]) + result = s.to_numpy() + + expected = np.array([np.array(v, dtype=np.int64) for v in values], dtype=np.object_) + for res, exp in zip(result, expected): + assert_array_equal(res, exp) + assert result.dtype == expected.dtype + assert_allow_copy_false_raises(s) + + +def test_series_to_numpy_struct_numeric_supertype() -> None: + values = [{"a": 1, "b": 2.0}, {"a": 3, "b": 4.0}, {"a": 5, "b": None}] + s = pl.Series(values) + result = s.to_numpy() + + expected = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, np.nan]]) + assert_array_equal(result, expected) + assert result.dtype == np.float64 + assert_allow_copy_false_raises(s) + + +def test_to_numpy_null() -> None: + s = pl.Series([None, None], dtype=pl.Null) + result = s.to_numpy() + expected = np.array([np.nan, np.nan], dtype=np.float32) + assert_array_equal(result, expected) + assert result.dtype == np.float32 + assert_allow_copy_false_raises(s) + + +def test_to_numpy_empty() -> None: + s = pl.Series(dtype=pl.String) + result = s.to_numpy(allow_copy=False) + assert result.dtype == np.object_ + assert result.shape == (0,) + + +def test_to_numpy_empty_writable() -> None: + s = pl.Series(dtype=pl.Int64) + result = s.to_numpy(allow_copy=False, writable=True) + assert result.dtype == np.int64 + assert result.shape == (0,) + assert result.flags.writeable is True + + +def test_to_numpy_chunked() -> None: + s1 = pl.Series([1, 2]) + s2 = pl.Series([3, 4]) + s = pl.concat([s1, s2], rechunk=False) + + result: npt.NDArray[np.generic] = s.to_numpy() + + assert result.tolist() == s.to_list() + assert result.dtype == np.int64 + assert result.flags.writeable is True + assert_allow_copy_false_raises(s) + + # Check that writing to the array doesn't change the original data + result[0] = 10 + assert result.tolist() == [10, 2, 3, 4] + assert s.to_list() == [1, 2, 3, 4] + + +def test_to_numpy_chunked_temporal_nested() -> None: + dtype = pl.Array(pl.Datetime("us"), 1) + s1 = pl.Series([[datetime(2020, 1, 1)], [datetime(2021, 1, 1)]], dtype=dtype) + s2 = pl.Series([[datetime(2022, 1, 1)], [datetime(2023, 1, 1)]], dtype=dtype) + s = pl.concat([s1, s2], rechunk=False) + + result: npt.NDArray[np.generic] = s.to_numpy() + + assert result.tolist() == s.to_list() + assert result.dtype == np.dtype("datetime64[us]") + assert result.shape == (4, 1) + assert result.flags.writeable is True + assert_allow_copy_false_raises(s) + + +def test_zero_copy_only_deprecated() -> None: + values = [1, 2] + s = pl.Series([1, 2]) + with pytest.deprecated_call(): + result: npt.NDArray[np.generic] = s.to_numpy(zero_copy_only=True) + assert result.tolist() == values + + +def test_series_to_numpy_temporal() -> None: + s0 = pl.Series("date", [123543, 283478, 1243]).cast(pl.Date) + s1 = pl.Series( + "datetime", [datetime(2021, 1, 2, 3, 4, 5), datetime(2021, 2, 3, 4, 5, 6)] + ) + s2 = pl.datetime_range( + datetime(2021, 1, 1, 0), + datetime(2021, 1, 1, 1), + interval="1h", + time_unit="ms", + eager=True, + ) + assert str(s0.to_numpy()) == "['2308-04-02' '2746-02-20' '1973-05-28']" + assert ( + str(s1.to_numpy()[:2]) + == "['2021-01-02T03:04:05.000000' '2021-02-03T04:05:06.000000']" + ) + assert ( + str(s2.to_numpy()[:2]) + == "['2021-01-01T00:00:00.000' '2021-01-01T01:00:00.000']" + ) + s3 = pl.Series([timedelta(hours=1), timedelta(hours=-2)]) + out = np.array([3_600_000_000_000, -7_200_000_000_000], dtype="timedelta64[ns]") + assert (s3.to_numpy() == out).all() + + +@given( + s=series( + min_size=1, + max_size=10, + excluded_dtypes=[ + pl.Categorical, + pl.List, + pl.Struct, + pl.Datetime("ms"), + pl.Duration("ms"), + ], + allow_null=False, + allow_time_zones=False, # NumPy does not support parsing time zone aware data + ).filter( + lambda s: ( + not (s.dtype == pl.String and s.str.contains("\x00").any()) + and not (s.dtype == pl.Binary and s.bin.contains(b"\x00").any()) + ) + ), +) +@settings(max_examples=250) +def test_series_to_numpy(s: pl.Series) -> None: + result = s.to_numpy() + + values = s.to_list() + dtype_map = { + pl.Datetime("ns"): "datetime64[ns]", + pl.Datetime("us"): "datetime64[us]", + pl.Duration("ns"): "timedelta64[ns]", + pl.Duration("us"): "timedelta64[us]", + pl.Null(): "float32", + } + np_dtype = dtype_map.get(s.dtype) + expected = np.array(values, dtype=np_dtype) + + assert_array_equal(result, expected) + + +@pytest.mark.parametrize("writable", [False, True]) +@pytest.mark.parametrize("pyarrow_available", [False, True]) +def test_to_numpy2( + writable: bool, pyarrow_available: bool, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", pyarrow_available) + + np_array = pl.Series("a", [1, 2, 3], pl.UInt8).to_numpy(writable=writable) + + np.testing.assert_array_equal(np_array, np.array([1, 2, 3], dtype=np.uint8)) + # Test if numpy array is readonly or writable. + assert np_array.flags.writeable == writable + + if writable: + np_array[1] += 10 + np.testing.assert_array_equal(np_array, np.array([1, 12, 3], dtype=np.uint8)) + + np_array_with_missing_values = pl.Series("a", [None, 2, 3], pl.UInt8).to_numpy( + writable=writable + ) + + np.testing.assert_array_equal( + np_array_with_missing_values, + np.array( + [np.nan, 2.0, 3.0], + dtype=(np.float64 if pyarrow_available else np.float32), + ), + ) + + if writable: + # As Null values can't be encoded natively in a numpy array, + # this array will never be a view. + assert np_array_with_missing_values.flags.writeable == writable + + +def test_to_numpy_series_indexed_18986() -> None: + df = pl.DataFrame({"a": [[4, 5, 6], [7, 8, 9, 10], None]}) + assert (df[1].to_numpy()[0, 0] == np.array([7, 8, 9, 10])).all() + assert ( + df.to_numpy()[2] == np.array([None]) + ).all() # this one is strange, but only option in numpy? diff --git a/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py new file mode 100644 index 000000000000..fda0c5a5f722 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from typing import Any, Callable, cast + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_ufunc() -> None: + df = pl.DataFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)]) + out = df.select( + np.power(pl.col("a"), 2).alias("power_uint8"), # type: ignore[call-overload] + np.power(pl.col("a"), 2.0).alias("power_float64"), # type: ignore[call-overload] + np.power(pl.col("a"), 2, dtype=np.uint16).alias("power_uint16"), # type: ignore[call-overload] + ) + expected = pl.DataFrame( + [ + pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8), + pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16), + ] + ) + assert_frame_equal(out, expected) + assert out.dtypes == expected.dtypes + + +def test_ufunc_expr_not_first() -> None: + """Check numpy ufunc expressions also work if expression not the first argument.""" + df = pl.DataFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) + out = df.select( + np.power(2.0, cast(Any, pl.col("a"))).alias("power"), + (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), + ) + expected = pl.DataFrame( + [ + pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), + pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + ] + ) + assert_frame_equal(out, expected) + + +def test_lazy_ufunc() -> None: + ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8)]) + out = ldf.select( + np.power(cast(Any, pl.col("a")), 2).alias("power_uint8"), + np.power(cast(Any, pl.col("a")), 2.0).alias("power_float64"), + np.power(cast(Any, pl.col("a")), 2, dtype=np.uint16).alias("power_uint16"), + ) + expected = pl.DataFrame( + [ + pl.Series("power_uint8", [1, 4, 9, 16], dtype=pl.UInt8), + pl.Series("power_float64", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + pl.Series("power_uint16", [1, 4, 9, 16], dtype=pl.UInt16), + ] + ) + assert_frame_equal(out.collect(), expected) + + +def test_lazy_ufunc_expr_not_first() -> None: + """Check numpy ufunc expressions also work if expression not the first argument.""" + ldf = pl.LazyFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) + out = ldf.select( + np.power(2.0, cast(Any, pl.col("a"))).alias("power"), + (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), + ) + expected = pl.DataFrame( + [ + pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), + pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + ] + ) + assert_frame_equal(out.collect(), expected) + + +def test_ufunc_recognition() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [1.1, 2.2, 3.3, 4.4]}) + assert_frame_equal(df.select(np.exp(pl.col("b"))), df.select(pl.col("b").exp())) + + +# https://github.com/pola-rs/polars/issues/6770 +def test_ufunc_multiple_expressions() -> None: + df = pl.DataFrame( + { + "v": [ + -4.293, + -2.4659, + -1.8378, + -0.2821, + -4.5649, + -3.8128, + -7.4274, + 3.3443, + 3.8604, + -4.2200, + ], + "u": [ + -11.2268, + 6.3478, + 7.1681, + 3.4986, + 2.7320, + -1.0695, + -10.1408, + 11.2327, + 6.6623, + -8.1412, + ], + } + ) + expected = np.arctan2(df.get_column("v"), df.get_column("u")) + result = df.select(np.arctan2(pl.col("v"), pl.col("u")))[:, 0] # type: ignore[call-overload] + assert_series_equal(expected, result) # type: ignore[arg-type] + + +def test_repeated_name_ufunc_17472() -> None: + """If a ufunc takes multiple inputs has a repeating name, this works.""" + df = pl.DataFrame({"a": [6.0]}) + result = df.select(np.divide(pl.col("a"), pl.col("a"))) # type: ignore[call-overload] + expected = pl.DataFrame({"a": [1.0]}) + assert_frame_equal(expected, result) + + +def test_grouped_ufunc() -> None: + df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [0.1, 0.1, -0.1, -0.1]}) + df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1)) + + +def test_generalized_ufunc_scalar() -> None: + numba = pytest.importorskip("numba") + + @numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()") # type: ignore[misc] + def my_custom_sum(arr, result) -> None: # type: ignore[no-untyped-def] # noqa: ANN001 + total = 0 + for value in arr: + total += value + result[0] = total + + # Make type checkers happy: + custom_sum = cast(Callable[[object], object], my_custom_sum) + + # Demonstrate NumPy as the canonical expected behavior: + assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15 + + # Direct call of the gufunc: + df = pl.DataFrame({"values": [10, 2, 3]}) + assert custom_sum(df.get_column("values")) == 15 + + # Indirect call of the gufunc: + indirect = df.select(pl.col("values").map_batches(custom_sum, returns_scalar=True)) + assert_frame_equal(indirect, pl.DataFrame({"values": 15})) + indirect = df.select(pl.col("values").map_batches(custom_sum, returns_scalar=False)) + assert_frame_equal(indirect, pl.DataFrame({"values": [15]})) + + # group_by() + df = pl.DataFrame({"labels": ["a", "b", "a", "b"], "values": [10, 2, 3, 30]}) + indirect = ( + df.group_by("labels") + .agg(pl.col("values").map_batches(custom_sum, returns_scalar=True)) + .sort("labels") + ) + assert_frame_equal( + indirect, pl.DataFrame({"labels": ["a", "b"], "values": [13, 32]}) + ) + + +def make_gufunc_mean() -> Callable[[pl.Series], pl.Series]: + numba = pytest.importorskip("numba") + + @numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)") # type: ignore[misc] + def gufunc_mean(arr: Any, result: Any) -> None: + mean = arr.mean() + for i in range(len(arr)): + result[i] = mean + i + + return gufunc_mean # type: ignore[no-any-return] + + +def test_generalized_ufunc() -> None: + gufunc_mean = make_gufunc_mean() + df = pl.DataFrame({"s": [1.0, 2.0, 3.0]}) + result = df.select([pl.col("s").map_batches(gufunc_mean).alias("result")]) + expected = pl.DataFrame({"result": [2.0, 3.0, 4.0]}) + assert_frame_equal(result, expected) + + +def test_grouped_generalized_ufunc() -> None: + gufunc_mean = make_gufunc_mean() + df = pl.DataFrame({"id": ["a", "a", "b", "b"], "values": [1.0, 2.0, 3.0, 4.0]}) + result = df.group_by("id").agg(pl.col("values").map_batches(gufunc_mean)).sort("id") + expected = pl.DataFrame({"id": ["a", "b"], "values": [[1.5, 2.5], [3.5, 4.5]]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/interop/numpy/test_ufunc_series.py b/py-polars/tests/unit/interop/numpy/test_ufunc_series.py new file mode 100644 index 000000000000..73007949d0c4 --- /dev/null +++ b/py-polars/tests/unit/interop/numpy/test_ufunc_series.py @@ -0,0 +1,194 @@ +from typing import Any, Callable, cast + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_series_equal + + +def test_ufunc() -> None: + # test if output dtype is calculated correctly. + s_float32 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float32) + assert_series_equal( + cast(pl.Series, np.multiply(s_float32, 4)), + pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float32), + ) + + s_float64 = pl.Series("a", [1.0, 2.0, 3.0, 4.0], dtype=pl.Float64) + assert_series_equal( + cast(pl.Series, np.multiply(s_float64, 4)), + pl.Series("a", [4.0, 8.0, 12.0, 16.0], dtype=pl.Float64), + ) + + s_uint8 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt8) + assert_series_equal( + cast(pl.Series, np.power(s_uint8, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt8), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint8, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint8, 2, dtype=np.uint16)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt16), + ) + + s_int8 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int8) + assert_series_equal( + cast(pl.Series, np.power(s_int8, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int8), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int8, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int8, 2, dtype=np.int16)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int16), + ) + + s_uint32 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt32) + assert_series_equal( + cast(pl.Series, np.power(s_uint32, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt32), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint32, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + s_int32 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int32) + assert_series_equal( + cast(pl.Series, np.power(s_int32, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int32), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int32, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + s_uint64 = pl.Series("a", [1, 2, 3, 4], dtype=pl.UInt64) + assert_series_equal( + cast(pl.Series, np.power(s_uint64, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.UInt64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_uint64, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + s_int64 = pl.Series("a", [1, -2, 3, -4], dtype=pl.Int64) + assert_series_equal( + cast(pl.Series, np.power(s_int64, 2)), + pl.Series("a", [1, 4, 9, 16], dtype=pl.Int64), + ) + assert_series_equal( + cast(pl.Series, np.power(s_int64, 2.0)), + pl.Series("a", [1.0, 4.0, 9.0, 16.0], dtype=pl.Float64), + ) + + # test if null bitmask is preserved + a1 = pl.Series("a", [1.0, None, 3.0]) + b1 = cast(pl.Series, np.exp(a1)) + assert b1.null_count() == 1 + + # test if it works with chunked series. + a2 = pl.Series("a", [1.0, None, 3.0]) + b2 = pl.Series("b", [4.0, 5.0, None]) + a2.append(b2) + assert a2.n_chunks() == 2 + c2 = np.multiply(a2, 3) + assert_series_equal( + cast(pl.Series, c2), + pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]), + ) + + # Test if nulls propagate through ufuncs + a3 = pl.Series("a", [None, None, 3, 3]) + b3 = pl.Series("b", [None, 3, None, 3]) + assert_series_equal( + cast(pl.Series, np.maximum(a3, b3)), pl.Series("a", [None, None, None, 3]) + ) + + +def test_numpy_string_array() -> None: + s_str = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.String) + assert_array_equal( + np.char.capitalize(s_str), + np.array(["Aa", "Bb", "Cc", "Dd"], dtype=" Callable[[pl.Series], pl.Series]: + numba = pytest.importorskip("numba") + + @numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)") # type: ignore[misc] + def add_one(arr: Any, result: Any) -> None: + for i in range(len(arr)): + result[i] = arr[i] + 1.0 + + return add_one # type: ignore[no-any-return] + + +def test_generalized_ufunc() -> None: + """A generalized ufunc can be called on a pl.Series.""" + add_one = make_add_one() + s_float = pl.Series("f", [1.0, 2.0, 3.0]) + result = add_one(s_float) + assert_series_equal(result, pl.Series("f", [2.0, 3.0, 4.0])) + + +def test_generalized_ufunc_missing_data() -> None: + """ + If a pl.Series is missing data, using a generalized ufunc is not allowed. + + While this particular example isn't necessarily a semantic issue, consider + a mean() function running on integers: it will give wrong results if the + input is missing data, since NumPy has no way to model missing slots. In + the general case, we can't assume the function will handle missing data + correctly. + """ + add_one = make_add_one() + s_float = pl.Series("f", [1.0, 2.0, 3.0, None], dtype=pl.Float64) + with pytest.raises( + ComputeError, + match="Can't pass a Series with missing data to a generalized ufunc", + ): + add_one(s_float) + + +def make_divide_by_sum() -> Callable[[pl.Series, pl.Series], pl.Series]: + numba = pytest.importorskip("numba") + float64 = numba.float64 + + @numba.guvectorize([(float64[:], float64[:], float64[:])], "(n),(m)->(m)") # type: ignore[misc] + def divide_by_sum(arr: Any, arr2: Any, result: Any) -> None: + total = arr.sum() + for i in range(len(arr2)): + result[i] = arr2[i] / total + + return divide_by_sum # type: ignore[no-any-return] + + +def test_generalized_ufunc_different_output_size() -> None: + """ + It's possible to call a generalized ufunc that takes pl.Series of different sizes. + + The result has the correct size. + """ + divide_by_sum = make_divide_by_sum() + + series = pl.Series("s", [1.0, 3.0], dtype=pl.Float64) + series2 = pl.Series("s2", [8.0, 16.0, 32.0], dtype=pl.Float64) + assert_series_equal( + divide_by_sum(series, series2), + pl.Series("s", [2.0, 4.0, 8.0], dtype=pl.Float64), + ) + assert_series_equal( + divide_by_sum(series2, series), + pl.Series("s2", [1.0 / 56, 3.0 / 56], dtype=pl.Float64), + ) diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py new file mode 100644 index 000000000000..1c2c2eb47641 --- /dev/null +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.asserts.series import assert_series_equal +from tests.unit.conftest import with_string_cache_if_auto_streaming + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_index_not_silently_excluded() -> None: + ddict = {"a": [1, 2, 3], "b": [4, 5, 6]} + df = pd.DataFrame(ddict, index=pd.Index([7, 8, 9], name="a")) + with pytest.raises(ValueError, match="indices and column names must not overlap"): + pl.from_pandas(df, include_index=True) + + +def test_nameless_multiindex_doesnt_raise_with_include_index_false_18130() -> None: + df = pd.DataFrame( + range(4), + columns=["A"], + index=pd.MultiIndex.from_product((["C", "D"], [3, 4])), + ) + result = pl.from_pandas(df) + expected = pl.DataFrame({"A": [0, 1, 2, 3]}) + assert_frame_equal(result, expected) + + +def test_from_pandas() -> None: + df = pd.DataFrame( + { + "bools": [False, True, False], + "bools_nulls": [None, True, False], + "int": [1, 2, 3], + "int_nulls": [1, None, 3], + "floats": [1.0, 2.0, 3.0], + "floats_nulls": [1.0, None, 3.0], + "strings": ["foo", "bar", "ham"], + "strings_nulls": ["foo", None, "ham"], + "strings-cat": ["foo", "bar", "ham"], + } + ) + df["strings-cat"] = df["strings-cat"].astype("category") + + out = pl.from_pandas(df) + assert out.shape == (3, 9) + assert out.schema == { + "bools": pl.Boolean, + "bools_nulls": pl.Boolean, + "int": pl.Int64, + "int_nulls": pl.Float64, + "floats": pl.Float64, + "floats_nulls": pl.Float64, + "strings": pl.String, + "strings_nulls": pl.String, + "strings-cat": pl.Categorical(ordering="physical"), + } + assert out.rows() == [ + (False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"), + (True, True, 2, None, 2.0, None, "bar", None, "bar"), + (False, False, 3, 3.0, 3.0, 3.0, "ham", "ham", "ham"), + ] + + # partial dtype overrides from pandas + overrides = {"int": pl.Int8, "int_nulls": pl.Int32, "floats": pl.Float32} + out = pl.from_pandas(df, schema_overrides=overrides) + for col, dtype in overrides.items(): + assert out.schema[col] == dtype + + +@pytest.mark.parametrize( + "nulls", + [ + [], + [None], + [None, None], + [None, None, None], + ], +) +def test_from_pandas_nulls(nulls: list[None]) -> None: + # empty and/or all null values, no pandas dtype + ps = pd.Series(nulls) + s = pl.from_pandas(ps) + assert nulls == s.to_list() + + +def test_from_pandas_nan_to_null() -> None: + df = pd.DataFrame( + { + "bools_nulls": [None, True, False], + "int_nulls": [1, None, 3], + "floats_nulls": [1.0, None, 3.0], + "strings_nulls": ["foo", None, "ham"], + "nulls": [None, np.nan, np.nan], + } + ) + out_true = pl.from_pandas(df) + out_false = pl.from_pandas(df, nan_to_null=False) + assert all(val is None for val in out_true["nulls"]) + assert all(np.isnan(val) for val in out_false["nulls"][1:]) + + df = pd.Series([2, np.nan, None], name="pd") # type: ignore[assignment] + out_true = pl.from_pandas(df) + out_false = pl.from_pandas(df, nan_to_null=False) + assert [val is None for val in out_true] + assert [np.isnan(val) for val in out_false[1:]] + + +def test_from_pandas_datetime() -> None: + ts = datetime(2021, 1, 1, 20, 20, 20, 20) + pd_s = pd.Series([ts, ts]) + tmp = pl.from_pandas(pd_s.to_frame("a")) + s = tmp["a"] + assert s.dt.hour()[0] == 20 + assert s.dt.minute()[0] == 20 + assert s.dt.second()[0] == 20 + + date_times = pd.date_range("2021-06-24 00:00:00", "2021-06-24 09:00:00", freq="1h") + s = pl.from_pandas(date_times) + assert s[0] == datetime(2021, 6, 24, 0, 0) + assert s[-1] == datetime(2021, 6, 24, 9, 0) + + +@pytest.mark.parametrize( + ("index_class", "index_data", "index_params", "expected_data", "expected_dtype"), + [ + (pd.Index, [100, 200, 300], {}, None, pl.Int64), + (pd.Index, [1, 2, 3], {"dtype": "uint32"}, None, pl.UInt32), + (pd.RangeIndex, 5, {}, [0, 1, 2, 3, 4], pl.Int64), + (pd.CategoricalIndex, ["N", "E", "S", "W"], {}, None, pl.Categorical), + ( + pd.DatetimeIndex, + [datetime(1960, 12, 31), datetime(2077, 10, 20)], + {"dtype": "datetime64[ms]"}, + None, + pl.Datetime("ms"), + ), + ( + pd.TimedeltaIndex, + ["24 hours", "2 days 8 hours", "3 days 42 seconds"], + {}, + [timedelta(1), timedelta(days=2, hours=8), timedelta(days=3, seconds=42)], + pl.Duration("ns"), + ), + ], +) +def test_from_pandas_index( + index_class: Any, + index_data: Any, + index_params: dict[str, Any], + expected_data: list[Any] | None, + expected_dtype: PolarsDataType, +) -> None: + if expected_data is None: + expected_data = index_data + + s = pl.from_pandas(index_class(index_data, **index_params)) + assert s.to_list() == expected_data + assert s.dtype == expected_dtype + + +def test_from_pandas_include_indexes() -> None: + data = { + "dtm": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)], + "val": [100, 200, 300], + "misc": ["x", "y", "z"], + } + pd_df = pd.DataFrame(data) + + df = pl.from_pandas(pd_df.set_index(["dtm"])) + assert df.to_dict(as_series=False) == { + "val": [100, 200, 300], + "misc": ["x", "y", "z"], + } + + df = pl.from_pandas(pd_df.set_index(["dtm", "val"])) + assert df.to_dict(as_series=False) == {"misc": ["x", "y", "z"]} + + df = pl.from_pandas(pd_df.set_index(["dtm"]), include_index=True) + assert df.to_dict(as_series=False) == data + + df = pl.from_pandas(pd_df.set_index(["dtm", "val"]), include_index=True) + assert df.to_dict(as_series=False) == data + + +def test_from_pandas_series_include_indexes() -> None: + # no default index + pd_series = pd.Series({"a": 1, "b": 2}, name="number").rename_axis(["letter"]) + df = pl.from_pandas(pd_series, include_index=True) + assert df.to_dict(as_series=False) == {"letter": ["a", "b"], "number": [1, 2]} + + # default index + pd_series = pd.Series(range(2)) + df = pl.from_pandas(pd_series, include_index=True) + assert df.to_dict(as_series=False) == {"index": [0, 1], "0": [0, 1]} + + +def test_duplicate_cols_diff_types() -> None: + df = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=["0", 0, "1", 1]) + with pytest.raises( + ValueError, + match="Pandas dataframe contains non-unique indices and/or column names", + ): + pl.from_pandas(df) + + +def test_from_pandas_duplicated_columns() -> None: + df = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=["a", "b", "c", "b"]) + with pytest.raises( + ValueError, + match="Pandas dataframe contains non-unique indices and/or column names", + ): + pl.from_pandas(df) + + +def test_from_pandas_null() -> None: + # null column is an object dtype, so pl.Utf8 is most close + df = pd.DataFrame([{"a": None}, {"a": None}]) + out = pl.DataFrame(df) + assert out.dtypes == [pl.String] + assert out["a"][0] is None + + df = pd.DataFrame([{"a": None, "b": 1}, {"a": None, "b": 2}]) + out = pl.DataFrame(df) + assert out.dtypes == [pl.String, pl.Int64] + + +def test_from_pandas_nested_list() -> None: + # this panicked in https://github.com/pola-rs/polars/issues/1615 + pddf = pd.DataFrame( + {"a": [1, 2, 3, 4], "b": [["x", "y"], ["x", "y", "z"], ["x"], ["x", "y"]]} + ) + pldf = pl.from_pandas(pddf) + assert pldf.shape == (4, 2) + assert pldf.rows() == [ + (1, ["x", "y"]), + (2, ["x", "y", "z"]), + (3, ["x"]), + (4, ["x", "y"]), + ] + + +def test_from_pandas_categorical_none() -> None: + s = pd.Series(["a", "b", "c", pd.NA], dtype="category") + out = pl.from_pandas(s) + assert out.dtype == pl.Categorical + assert out.to_list() == ["a", "b", "c", None] + + +def test_from_pandas_dataframe() -> None: + pd_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "b", "c"]) + df = pl.from_pandas(pd_df) + assert df.shape == (2, 3) + assert df.rows() == [(1, 2, 3), (4, 5, 6)] + + # if not a pandas dataframe, raise a ValueError + with pytest.raises(TypeError): + _ = pl.from_pandas([1, 2]) # type: ignore[call-overload] + + +def test_from_pandas_series() -> None: + pd_series = pd.Series([1, 2, 3], name="pd") + s = pl.from_pandas(pd_series) + assert s.shape == (3,) + assert list(s) == [1, 2, 3] + + +def test_from_empty_pandas() -> None: + pandas_df = pd.DataFrame( + { + "A": [], + "fruits": [], + } + ) + polars_df = pl.from_pandas(pandas_df) + assert polars_df.columns == ["A", "fruits"] + assert polars_df.dtypes == [pl.Float64, pl.Float64] + + +def test_from_null_column() -> None: + df = pl.from_pandas(pd.DataFrame(data=[pd.NA, pd.NA], columns=["n/a"])) + + assert df.shape == (2, 1) + assert df.columns == ["n/a"] + assert df.dtypes[0] == pl.Null + + +def test_from_pandas_ns_resolution() -> None: + df = pd.DataFrame( + [pd.Timestamp(year=2021, month=1, day=1, hour=1, second=1, nanosecond=1)], + columns=["date"], + ) + assert pl.from_pandas(df)[0, 0] == datetime(2021, 1, 1, 1, 0, 1) + + +def test_pandas_string_none_conversion_3298() -> None: + data: dict[str, list[str | None]] = {"col_1": ["a", "b", "c", "d"]} + data["col_1"][0] = None + df_pd = pd.DataFrame(data) + df_pl = pl.DataFrame(df_pd) + assert df_pl.to_series().to_list() == [None, "b", "c", "d"] + + +def test_from_pandas_null_struct_6412() -> None: + data = [ + { + "a": { + "b": None, + }, + }, + {"a": None}, + ] + df_pandas = pd.DataFrame(data) + assert pl.from_pandas(df_pandas).to_dict(as_series=False) == { + "a": [{"b": None}, None] + } + + +@with_string_cache_if_auto_streaming +def test_untrusted_categorical_input() -> None: + df_pd = pd.DataFrame({"x": pd.Categorical(["x"], ["x", "y"])}) + df = pl.from_pandas(df_pd) + result = df.group_by("x").len() + expected = pl.DataFrame( + {"x": ["x"], "len": [1]}, schema={"x": pl.Categorical, "len": pl.UInt32} + ) + assert_frame_equal(result, expected, categorical_as_str=True) + + +@pytest.fixture +def _set_pyarrow_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "polars._utils.construction.dataframe._PYARROW_AVAILABLE", False + ) + monkeypatch.setattr("polars._utils.construction.series._PYARROW_AVAILABLE", False) + + +@pytest.mark.usefixtures("_set_pyarrow_unavailable") +def test_from_pandas_pyarrow_not_available_succeeds() -> None: + data: dict[str, Any] = { + "a": [1, 2], + "b": ["one", "two"], + "c": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[ns]"), + "d": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[us]"), + "e": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[ms]"), + "f": np.array([1, 2], dtype="timedelta64[ns]"), + "g": np.array([1, 2], dtype="timedelta64[us]"), + "h": np.array([1, 2], dtype="timedelta64[ms]"), + "i": [True, False], + } + + # DataFrame + result = pl.from_pandas(pd.DataFrame(data)) + expected = pl.DataFrame(data) + assert_frame_equal(result, expected) + + # Series + for col in data: + s_pd = pd.Series(data[col]) + result_s = pl.from_pandas(s_pd) + expected_s = pl.Series(data[col]) + assert_series_equal(result_s, expected_s) + + +@pytest.mark.usefixtures("_set_pyarrow_unavailable") +def test_from_pandas_pyarrow_not_available_fails() -> None: + with pytest.raises(ImportError, match="pyarrow is required"): + pl.from_pandas(pd.DataFrame({"a": [1, 2, 3]}, dtype="Int64")) + with pytest.raises(ImportError, match="pyarrow is required"): + pl.from_pandas(pd.Series([1, 2, 3], dtype="Int64")) + with pytest.raises(ImportError, match="pyarrow is required"): + pl.from_pandas( + pd.DataFrame({"a": pd.to_datetime(["2020-01-01T00:00+01:00"]).to_series()}) + ) + with pytest.raises(ImportError, match="pyarrow is required"): + pl.from_pandas(pd.DataFrame({"a": [None, "foo"]})) + + +def test_from_pandas_nan_to_null_16453(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "polars._utils.construction.dataframe._MIN_NUMPY_SIZE_FOR_MULTITHREADING", 2 + ) + df = pd.DataFrame( + {"a": [np.nan, 1.0, 2], "b": [1.0, 2.0, 3.0], "c": [4.0, 5.0, 6.0]} + ) + result = pl.from_pandas(df, nan_to_null=True) + expected = pl.DataFrame( + {"a": [None, 1.0, 2], "b": [1.0, 2.0, 3.0], "c": [4.0, 5.0, 6.0]} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("null", [pd.NA, np.nan, None, float("nan")]) +def test_from_pandas_string_with_natype_17355(null: Any) -> None: + # https://github.com/pola-rs/polars/issues/17355 + + pd_df = pd.DataFrame({"col": ["a", null]}) + result = pl.from_pandas(pd_df) + expected = pl.DataFrame({"col": ["a", None]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py new file mode 100644 index 000000000000..fcec4d508f42 --- /dev/null +++ b/py-polars/tests/unit/interop/test_interop.py @@ -0,0 +1,902 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timezone +from typing import Any, cast + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +import polars as pl +from polars.exceptions import ComputeError, UnstableWarning +from polars.interchange.protocol import CompatLevel +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder + + +def test_arrow_list_roundtrip() -> None: + # https://github.com/pola-rs/polars/issues/1064 + tbl = pa.table({"a": [1], "b": [[1, 2]]}) + arw = pl.from_arrow(tbl).to_arrow() + + assert arw.shape == tbl.shape + assert arw.schema.names == tbl.schema.names + for c1, c2 in zip(arw.columns, tbl.columns): + assert c1.to_pylist() == c2.to_pylist() + + +def test_arrow_null_roundtrip() -> None: + tbl = pa.table({"a": [None, None], "b": [[None, None], [None, None]]}) + df = pl.from_arrow(tbl) + + if isinstance(df, pl.DataFrame): + assert df.dtypes == [pl.Null, pl.List(pl.Null)] + + arw = df.to_arrow() + + assert arw.shape == tbl.shape + assert arw.schema.names == tbl.schema.names + for c1, c2 in zip(arw.columns, tbl.columns): + assert c1.to_pylist() == c2.to_pylist() + + +def test_arrow_empty_dataframe() -> None: + # 0x0 dataframe + df = pl.DataFrame({}) + tbl = pa.table({}) + assert df.to_arrow() == tbl + df2 = cast(pl.DataFrame, pl.from_arrow(df.to_arrow())) + assert_frame_equal(df2, df) + + # 0 row dataframe + df = pl.DataFrame({}, schema={"a": pl.Int32}) + tbl = pa.Table.from_batches([], pa.schema([pa.field("a", pa.int32())])) + assert df.to_arrow() == tbl + df2 = cast(pl.DataFrame, pl.from_arrow(df.to_arrow())) + assert df2.schema == {"a": pl.Int32} + assert df2.shape == (0, 1) + + +def test_arrow_dict_to_polars() -> None: + pa_dict = pa.DictionaryArray.from_arrays( + indices=np.array([0, 1, 2, 3, 1, 0, 2, 3, 3, 2]), + dictionary=np.array(["AAA", "BBB", "CCC", "DDD"]), + ).cast(pa.large_utf8()) + + s = pl.Series( + name="pa_dict", + values=["AAA", "BBB", "CCC", "DDD", "BBB", "AAA", "CCC", "DDD", "DDD", "CCC"], + ) + assert_series_equal(s, pl.Series("pa_dict", pa_dict)) + + +def test_arrow_list_chunked_array() -> None: + a = pa.array([[1, 2], [3, 4]]) + ca = pa.chunked_array([a, a, a]) + s = cast(pl.Series, pl.from_arrow(ca)) + assert s.dtype == pl.List + + +# Test that polars convert Arrays of logical types correctly to arrow +def test_arrow_array_logical() -> None: + # cast to large string and uint32 indices because polars converts to those + pa_data1 = ( + pa.array(["a", "b", "c", "d"]) + .dictionary_encode() + .cast(pa.dictionary(pa.uint32(), pa.large_string())) + ) + pa_array_logical1 = pa.FixedSizeListArray.from_arrays(pa_data1, 2) + + s1 = pl.Series( + values=[["a", "b"], ["c", "d"]], + dtype=pl.Array(pl.Enum(["a", "b", "c", "d"]), shape=2), + ) + assert s1.to_arrow() == pa_array_logical1 + + pa_data2 = pa.array([date(2024, 1, 1), date(2024, 1, 2)]) + pa_array_logical2 = pa.FixedSizeListArray.from_arrays(pa_data2, 1) + + s2 = pl.Series( + values=[[date(2024, 1, 1)], [date(2024, 1, 2)]], + dtype=pl.Array(pl.Date, shape=1), + ) + assert s2.to_arrow() == pa_array_logical2 + + +def test_from_dict() -> None: + data = {"a": [1, 2], "b": [3, 4]} + df = pl.from_dict(data) + assert df.shape == (2, 2) + for s1, s2 in zip(list(df), [pl.Series("a", [1, 2]), pl.Series("b", [3, 4])]): + assert_series_equal(s1, s2) + + +def test_from_dict_struct() -> None: + data: dict[str, dict[str, list[int]] | list[int]] = { + "a": {"b": [1, 3], "c": [2, 4]}, + "d": [5, 6], + } + df = pl.from_dict(data) + assert df.shape == (2, 2) + assert df["a"][0] == {"b": 1, "c": 2} + assert df["a"][1] == {"b": 3, "c": 4} + assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64} + + +def test_from_dicts() -> None: + data = [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": None}] + df = pl.from_dicts(data) # type: ignore[arg-type] + assert df.shape == (3, 2) + assert df.rows() == [(1, 4), (2, 5), (3, None)] + assert df.schema == {"a": pl.Int64, "b": pl.Int64} + + +def test_from_dict_no_inference() -> None: + schema = {"a": pl.String} + data = [{"a": "aa"}] + df = pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0) + assert df.schema == schema + assert df.to_dicts() == data + + +def test_from_dicts_schema_override() -> None: + schema = { + "a": pl.String, + "b": pl.Int64, + "c": pl.List(pl.Struct({"x": pl.Int64, "y": pl.String, "z": pl.Float64})), + } + + # initial data matches the expected schema + data1 = [ + { + "a": "l", + "b": i, + "c": [{"x": (j + 2), "y": "?", "z": (j % 2)} for j in range(2)], + } + for i in range(5) + ] + + # extend with a mix of fields that are/not in the schema + data2 = [{"b": i + 5, "d": "ABC", "e": "DEF"} for i in range(5)] + + for n_infer in (0, 3, 5, 8, 10, 100): + df = pl.DataFrame( + data=(data1 + data2), + schema=schema, # type: ignore[arg-type] + infer_schema_length=n_infer, + ) + assert df.schema == schema + assert df.rows() == [ + ("l", 0, [{"x": 2, "y": "?", "z": 0.0}, {"x": 3, "y": "?", "z": 1.0}]), + ("l", 1, [{"x": 2, "y": "?", "z": 0.0}, {"x": 3, "y": "?", "z": 1.0}]), + ("l", 2, [{"x": 2, "y": "?", "z": 0.0}, {"x": 3, "y": "?", "z": 1.0}]), + ("l", 3, [{"x": 2, "y": "?", "z": 0.0}, {"x": 3, "y": "?", "z": 1.0}]), + ("l", 4, [{"x": 2, "y": "?", "z": 0.0}, {"x": 3, "y": "?", "z": 1.0}]), + (None, 5, None), + (None, 6, None), + (None, 7, None), + (None, 8, None), + (None, 9, None), + ] + + +def test_from_dicts_struct() -> None: + data = [{"a": {"b": 1, "c": 2}, "d": 5}, {"a": {"b": 3, "c": 4}, "d": 6}] + df = pl.from_dicts(data) + assert df.shape == (2, 2) + assert df["a"][0] == {"b": 1, "c": 2} + assert df["a"][1] == {"b": 3, "c": 4} + + # 5649 + assert pl.from_dicts([{"a": [{"x": 1}]}, {"a": [{"y": 1}]}]).to_dict( + as_series=False + ) == {"a": [[{"y": None, "x": 1}], [{"y": 1, "x": None}]]} + assert pl.from_dicts([{"a": [{"x": 1}, {"y": 2}]}, {"a": [{"y": 1}]}]).to_dict( + as_series=False + ) == {"a": [[{"y": None, "x": 1}, {"y": 2, "x": None}], [{"y": 1, "x": None}]]} + + +def test_from_records() -> None: + data = [[1, 2, 3], [4, 5, 6]] + df = pl.from_records(data, schema=["a", "b"]) + assert df.shape == (3, 2) + assert df.rows() == [(1, 4), (2, 5), (3, 6)] + + +# https://github.com/pola-rs/polars/issues/15195 +@pytest.mark.parametrize( + "input", + [ + pl.Series([1, 2]), + pl.Series([{"a": 1, "b": 2}]), + pl.DataFrame({"a": [1, 2], "b": [3, 4]}), + ], +) +def test_from_records_non_sequence_input(input: Any) -> None: + with pytest.raises(TypeError, match="expected data of type Sequence"): + pl.from_records(input) + + +def test_from_arrow() -> None: + data = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = pl.from_arrow(data) + assert df.shape == (3, 2) + assert df.rows() == [(1, 4), (2, 5), (3, 6)] # type: ignore[union-attr] + + # if not a PyArrow type, raise a TypeError + with pytest.raises(TypeError): + _ = pl.from_arrow([1, 2]) + + df = pl.from_arrow( + data, schema=["a", "b"], schema_overrides={"a": pl.UInt32, "b": pl.UInt64} + ) + assert df.rows() == [(1, 4), (2, 5), (3, 6)] # type: ignore[union-attr] + assert df.schema == {"a": pl.UInt32, "b": pl.UInt64} # type: ignore[union-attr] + + +def test_from_arrow_with_bigquery_metadata() -> None: + arrow_schema = pa.schema( + [ + pa.field("id", pa.int64()).with_metadata( + {"ARROW:extension:name": "google:sqlType:integer"} + ), + pa.field( + "misc", + pa.struct([("num", pa.int32()), ("val", pa.string())]), + ).with_metadata({"ARROW:extension:name": "google:sqlType:struct"}), + ] + ) + arrow_tbl = pa.Table.from_pylist( + [{"id": 1, "misc": None}, {"id": 2, "misc": None}], + schema=arrow_schema, + ) + + expected_data = {"id": [1, 2], "num": [None, None], "val": [None, None]} + expected_schema = {"id": pl.Int64, "num": pl.Int32, "val": pl.String} + assert_frame_equal( + pl.DataFrame(expected_data, schema=expected_schema), + pl.from_arrow(arrow_tbl).unnest("misc"), # type: ignore[union-attr] + ) + + +def test_from_optional_not_available() -> None: + from polars.dependencies import _LazyModule + + # proxy module is created dynamically if the required module is not available + # (see the polars.dependencies source code for additional detail/comments) + + np = _LazyModule("numpy", module_available=False) + with pytest.raises(ImportError, match=r"np\.array requires 'numpy'"): + pl.from_numpy(np.array([[1, 2], [3, 4]]), schema=["a", "b"]) + + pa = _LazyModule("pyarrow", module_available=False) + with pytest.raises(ImportError, match=r"pa\.table requires 'pyarrow'"): + pl.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})) + + pd = _LazyModule("pandas", module_available=False) + with pytest.raises(ImportError, match=r"pd\.Series requires 'pandas'"): + pl.from_pandas(pd.Series([1, 2, 3])) + + +def test_upcast_pyarrow_dicts() -> None: + # https://github.com/pola-rs/polars/issues/1752 + tbls = [ + pa.table( + { + "col_name": pa.array( + [f"value_{i}"], pa.dictionary(pa.int8(), pa.string()) + ) + } + ) + for i in range(128) + ] + + tbl = pa.concat_tables(tbls, promote_options="default") + out = cast(pl.DataFrame, pl.from_arrow(tbl)) + assert out.shape == (128, 1) + assert out["col_name"][0] == "value_0" + assert out["col_name"][127] == "value_127" + + +def test_no_rechunk() -> None: + table = pa.Table.from_pydict({"x": pa.chunked_array([list("ab"), list("cd")])}) + # table + assert pl.from_arrow(table, rechunk=False).n_chunks() == 2 + # chunked array + assert pl.from_arrow(table["x"], rechunk=False).n_chunks() == 2 + + +def test_from_empty_arrow() -> None: + df = cast(pl.DataFrame, pl.from_arrow(pa.table(pd.DataFrame({"a": [], "b": []})))) + assert df.columns == ["a", "b"] + assert df.dtypes == [pl.Float64, pl.Float64] + + # 2705 + df1 = pd.DataFrame(columns=["b"], dtype=float, index=pd.Index([])) + tbl = pa.Table.from_pandas(df1) + out = cast(pl.DataFrame, pl.from_arrow(tbl)) + assert out.columns == ["b", "__index_level_0__"] + assert out.dtypes == [pl.Float64, pl.Null] + tbl = pa.Table.from_pandas(df1, preserve_index=False) + out = cast(pl.DataFrame, pl.from_arrow(tbl)) + assert out.columns == ["b"] + assert out.dtypes == [pl.Float64] + + # 4568 + tbl = pa.table({"l": []}, schema=pa.schema([("l", pa.large_list(pa.uint8()))])) + + df = cast(pl.DataFrame, pl.from_arrow(tbl)) + assert df.schema["l"] == pl.List(pl.UInt8) + + +def test_cat_int_types_3500() -> None: + with pl.StringCache(): + # Create an enum / categorical / dictionary typed pyarrow array + # Most simply done by creating a pandas categorical series first + categorical_s = pd.Series(["a", "a", "b"], dtype="category") + pyarrow_array = pa.Array.from_pandas(categorical_s) + + # The in-memory representation of each category can either be a signed or + # unsigned 8-bit integer. Pandas uses Int8... + int_dict_type = pa.dictionary(index_type=pa.int8(), value_type=pa.utf8()) + # ... while DuckDB uses UInt8 + uint_dict_type = pa.dictionary(index_type=pa.uint8(), value_type=pa.utf8()) + + for t in [int_dict_type, uint_dict_type]: + s = cast(pl.Series, pl.from_arrow(pyarrow_array.cast(t))) + assert_series_equal( + s, pl.Series(["a", "a", "b"]).cast(pl.Categorical), check_names=False + ) + + +def test_from_pyarrow_chunked_array() -> None: + column = pa.chunked_array([[1], [2]]) + series = pl.Series("column", column) + assert series.to_list() == [1, 2] + + +def test_arrow_list_null_5697() -> None: + # Create a pyarrow table with a list[null] column. + pa_table = pa.table([[[None]]], names=["mycol"]) + df = pl.from_arrow(pa_table) + pa_table = df.to_arrow() + # again to polars to test the schema + assert pl.from_arrow(pa_table).schema == {"mycol": pl.List(pl.Null)} # type: ignore[union-attr] + + +def test_from_pyarrow_map() -> None: + pa_table = pa.table( + [[1, 2], [[("a", "something")], [("a", "else"), ("b", "another key")]]], + schema=pa.schema( + [("idx", pa.int16()), ("mapping", pa.map_(pa.string(), pa.string()))] + ), + ) + + # Convert from an empty table to trigger an ArrowSchema -> native schema + # conversion (checks that ArrowDataType::Map is handled in Rust). + pl.DataFrame(pa_table.slice(0, 0)) + + result = pl.DataFrame(pa_table) + assert result.to_dict(as_series=False) == { + "idx": [1, 2], + "mapping": [ + [{"key": "a", "value": "something"}], + [{"key": "a", "value": "else"}, {"key": "b", "value": "another key"}], + ], + } + + +def test_from_fixed_size_binary_list() -> None: + val = [[b"63A0B1C66575DD5708E1EB2B"]] + arrow_array = pa.array(val, type=pa.list_(pa.binary(24))) + s = cast(pl.Series, pl.from_arrow(arrow_array)) + assert s.dtype == pl.List(pl.Binary) + assert s.to_list() == val + + +def test_dataframe_from_repr() -> None: + # round-trip various types + with pl.StringCache(): + frame = ( + pl.LazyFrame( + { + "a": [1, 2, None], + "b": [4.5, 5.5, 6.5], + "c": ["x", "y", "z"], + "d": [True, False, True], + "e": [None, "", None], + "f": [date(2022, 7, 5), date(2023, 2, 5), date(2023, 8, 5)], + "g": [time(0, 0, 0, 1), time(12, 30, 45), time(23, 59, 59, 999000)], + "h": [ + datetime(2022, 7, 5, 10, 30, 45, 4560), + datetime(2023, 10, 12, 20, 3, 8, 11), + None, + ], + }, + ) + .with_columns( + pl.col("c").cast(pl.Categorical), + pl.col("h").cast(pl.Datetime("ns")), + ) + .collect() + ) + + assert frame.schema == { + "a": pl.Int64, + "b": pl.Float64, + "c": pl.Categorical(ordering="physical"), + "d": pl.Boolean, + "e": pl.String, + "f": pl.Date, + "g": pl.Time, + "h": pl.Datetime("ns"), + } + df = cast(pl.DataFrame, pl.from_repr(repr(frame))) + assert_frame_equal(frame, df) + + # empty frame; confirm schema is inferred + df = cast( + pl.DataFrame, + pl.from_repr( + """ + ┌─────┬─────┬─────┬─────┬─────┬───────┐ + │ id ┆ q1 ┆ q2 ┆ q3 ┆ q4 ┆ total │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ i8 ┆ i16 ┆ i32 ┆ i64 ┆ f64 │ + ╞═════╪═════╪═════╪═════╪═════╪═══════╡ + └─────┴─────┴─────┴─────┴─────┴───────┘ + """ + ), + ) + assert df.shape == (0, 6) + assert df.rows() == [] + assert df.schema == { + "id": pl.String, + "q1": pl.Int8, + "q2": pl.Int16, + "q3": pl.Int32, + "q4": pl.Int64, + "total": pl.Float64, + } + + # empty frame with no dtypes + df = cast( + pl.DataFrame, + pl.from_repr( + """ + ┌──────┬───────┐ + │ misc ┆ other │ + ╞══════╪═══════╡ + └──────┴───────┘ + """ + ), + ) + assert_frame_equal(df, pl.DataFrame(schema={"misc": pl.String, "other": pl.String})) + + # empty frame with non-standard/blank 'null' + df = cast( + pl.DataFrame, + pl.from_repr( + """ + ┌─────┬─────┐ + │ c1 ┆ c2 │ + │ --- ┆ --- │ + │ i32 ┆ f64 │ + ╞═════╪═════╡ + │ │ │ + └─────┴─────┘ + """ + ), + ) + assert_frame_equal( + df, + pl.DataFrame( + data=[(None, None)], schema={"c1": pl.Int32, "c2": pl.Float64}, orient="row" + ), + ) + + df = cast( + pl.DataFrame, + pl.from_repr( + """ + # >>> Missing cols with old-style ellipsis, nulls, commented out + # ┌────────────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬──────┐ + # │ dt ┆ c1 ┆ c2 ┆ c3 ┆ ... ┆ c96 ┆ c97 ┆ c98 ┆ c99 │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ date ┆ i32 ┆ i32 ┆ i32 ┆ ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + # ╞════════════╪═════╪═════╪═════╪═════╪═════╪═════╪═════╪══════╡ + # │ 2023-03-25 ┆ 1 ┆ 2 ┆ 3 ┆ ... ┆ 96 ┆ 97 ┆ 98 ┆ 99 │ + # │ 1999-12-31 ┆ 3 ┆ 6 ┆ 9 ┆ ... ┆ 288 ┆ 291 ┆ 294 ┆ null │ + # │ null ┆ 9 ┆ 18 ┆ 27 ┆ ... ┆ 864 ┆ 873 ┆ 882 ┆ 891 │ + # └────────────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴──────┘ + """ + ), + ) + assert df.schema == { + "dt": pl.Date, + "c1": pl.Int32, + "c2": pl.Int32, + "c3": pl.Int32, + "c96": pl.Int64, + "c97": pl.Int64, + "c98": pl.Int64, + "c99": pl.Int64, + } + assert df.rows() == [ + (date(2023, 3, 25), 1, 2, 3, 96, 97, 98, 99), + (date(1999, 12, 31), 3, 6, 9, 288, 291, 294, None), + (None, 9, 18, 27, 864, 873, 882, 891), + ] + + df = cast( + pl.DataFrame, + pl.from_repr( + """ + # >>> no dtypes: + # ┌────────────┬──────┐ + # │ dt ┆ c99 │ + # ╞════════════╪══════╡ + # │ 2023-03-25 ┆ 99 │ + # │ 1999-12-31 ┆ null │ + # │ null ┆ 891 │ + # └────────────┴──────┘ + """ + ), + ) + assert df.schema == {"dt": pl.Date, "c99": pl.Int64} + assert df.rows() == [ + (date(2023, 3, 25), 99), + (date(1999, 12, 31), None), + (None, 891), + ] + + df = cast( + pl.DataFrame, + pl.from_repr( + """ + In [2]: with pl.Config() as cfg: + ...: pl.Config.set_tbl_formatting("UTF8_FULL", rounded_corners=True) + ...: print(df) + ...: + shape: (1, 5) + ╭───────────┬────────────┬───┬───────┬────────────────────────────────╮ + │ source_ac ┆ source_cha ┆ … ┆ ident ┆ timestamp │ + │ tor_id ┆ nnel_id ┆ ┆ --- ┆ --- │ + │ --- ┆ --- ┆ ┆ str ┆ datetime[μs, Asia/Tokyo] │ + │ i32 ┆ i64 ┆ ┆ ┆ │ + ╞═══════════╪════════════╪═══╪═══════╪════════════════════════════════╡ + │ 123456780 ┆ 9876543210 ┆ … ┆ a:b:c ┆ 2023-03-25 10:56:59.663053 JST │ + ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ … ┆ … ┆ … ┆ … ┆ … │ + ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 803065983 ┆ 2055938745 ┆ … ┆ x:y:z ┆ 2023-03-25 12:38:18.050545 JST │ + ╰───────────┴────────────┴───┴───────┴────────────────────────────────╯ + # "Een fluitje van een cent..." :) + """ + ), + ) + assert df.shape == (2, 4) + assert df.schema == { + "source_actor_id": pl.Int32, + "source_channel_id": pl.Int64, + "ident": pl.String, + "timestamp": pl.Datetime("us", "Asia/Tokyo"), + } + + +def test_series_from_repr() -> None: + with pl.StringCache(): + frame = ( + pl.LazyFrame( + { + "a": [1, 2, None], + "b": [4.5, 5.5, 6.5], + "c": ["x", "y", "z"], + "d": [True, False, True], + "e": [None, "", None], + "f": [date(2022, 7, 5), date(2023, 2, 5), date(2023, 8, 5)], + "g": [time(0, 0, 0, 1), time(12, 30, 45), time(23, 59, 59, 999000)], + "h": [ + datetime(2022, 7, 5, 10, 30, 45, 4560), + datetime(2023, 10, 12, 20, 3, 8, 11), + None, + ], + }, + ) + .with_columns( + pl.col("c").cast(pl.Categorical), + pl.col("h").cast(pl.Datetime("ns")), + ) + .collect() + ) + + for col in frame.columns: + s = cast(pl.Series, pl.from_repr(repr(frame[col]))) + assert_series_equal(s, frame[col]) + + s = cast( + pl.Series, + pl.from_repr( + """ + Out[3]: + shape: (3,) + Series: 's' [str] + [ + "a" + … + "c" + ] + """ + ), + ) + assert_series_equal(s, pl.Series("s", ["a", "c"])) + + s = cast( + pl.Series, + pl.from_repr( + """ + Series: 'flt' [f32] + [ + ] + """ + ), + ) + assert_series_equal(s, pl.Series("flt", [], dtype=pl.Float32)) + + s = cast( + pl.Series, + pl.from_repr( + """ + Series: 'flt' [f64] + [ + null + +inf + -inf + inf + 0.0 + NaN + ] + >>> print("stuff") + """ + ), + ) + inf, nan = float("inf"), float("nan") + assert_series_equal( + s, + pl.Series( + name="flt", + dtype=pl.Float64, + values=[None, inf, -inf, inf, 0.0, nan], + ), + ) + + +def test_dataframe_from_repr_custom_separators() -> None: + # repr created with custom digit-grouping + # and non-default group/decimal separators + df = cast( + pl.DataFrame, + pl.from_repr( + """ + ┌───────────┬────────────┐ + │ x ┆ y │ + │ --- ┆ --- │ + │ i32 ┆ f64 │ + ╞═══════════╪════════════╡ + │ 123.456 ┆ -10.000,55 │ + │ -9.876 ┆ 10,0 │ + │ 9.999.999 ┆ 8,5e8 │ + └───────────┴────────────┘ + """ + ), + ) + assert_frame_equal( + df, + pl.DataFrame( + { + "x": [123456, -9876, 9999999], + "y": [-10000.55, 10.0, 850000000.0], + }, + schema={"x": pl.Int32, "y": pl.Float64}, + ), + ) + + +def test_sliced_struct_from_arrow() -> None: + # Create a dataset with 3 rows + tbl = pa.Table.from_arrays( + arrays=[ + pa.StructArray.from_arrays( + arrays=[ + pa.array([1, 2, 3], pa.int32()), + pa.array(["foo", "bar", "baz"], pa.utf8()), + ], + names=["a", "b"], + ) + ], + names=["struct_col"], + ) + + # slice the table + # check if FFI correctly reads sliced + result = cast(pl.DataFrame, pl.from_arrow(tbl.slice(1, 2))) + assert result.to_dict(as_series=False) == { + "struct_col": [{"a": 2, "b": "bar"}, {"a": 3, "b": "baz"}] + } + + result = cast(pl.DataFrame, pl.from_arrow(tbl.slice(1, 1))) + assert result.to_dict(as_series=False) == {"struct_col": [{"a": 2, "b": "bar"}]} + + +def test_from_arrow_invalid_time_zone() -> None: + arr = pa.array( + [datetime(2021, 1, 1, 0, 0, 0, 0)], + type=pa.timestamp("ns", tz="this-is-not-a-time-zone"), + ) + with pytest.raises( + ComputeError, match=r"unable to parse time zone: 'this-is-not-a-time-zone'" + ): + pl.from_arrow(arr) + + +@pytest.mark.parametrize( + ("fixed_offset", "etc_tz"), + [ + ("+10:00", "Etc/GMT-10"), + ("10:00", "Etc/GMT-10"), + ("-10:00", "Etc/GMT+10"), + ("+05:00", "Etc/GMT-5"), + ("05:00", "Etc/GMT-5"), + ("-05:00", "Etc/GMT+5"), + ], +) +def test_from_arrow_fixed_offset(fixed_offset: str, etc_tz: str) -> None: + arr = pa.array( + [datetime(2021, 1, 1, 0, 0, 0, 0)], + type=pa.timestamp("us", tz=fixed_offset), + ) + result = cast(pl.Series, pl.from_arrow(arr)) + expected = pl.Series( + [datetime(2021, 1, 1, tzinfo=timezone.utc)] + ).dt.convert_time_zone(etc_tz) + assert_series_equal(result, expected) + + +def test_from_avro_valid_time_zone_13032() -> None: + arr = pa.array( + [datetime(2021, 1, 1, 0, 0, 0, 0)], type=pa.timestamp("ns", tz="00:00") + ) + result = cast(pl.Series, pl.from_arrow(arr)) + expected = pl.Series([datetime(2021, 1, 1)], dtype=pl.Datetime("ns", "UTC")) + assert_series_equal(result, expected) + + +def test_from_numpy_different_resolution_15991() -> None: + result = pl.Series( + np.array(["2020-01-01"], dtype="datetime64[ns]"), dtype=pl.Datetime("us") + ) + expected = pl.Series([datetime(2020, 1, 1)], dtype=pl.Datetime("us")) + assert_series_equal(result, expected) + + +def test_from_numpy_different_resolution_invalid() -> None: + with pytest.raises(ValueError, match="Please cast"): + pl.Series( + np.array(["2020-01-01"], dtype="datetime64[s]"), dtype=pl.Datetime("us") + ) + + +def test_compat_level(monkeypatch: pytest.MonkeyPatch) -> None: + # change these if compat level bumped + monkeypatch.setenv("POLARS_WARN_UNSTABLE", "1") + oldest = CompatLevel.oldest() + assert oldest is CompatLevel.oldest() # test singleton + assert oldest._version == 0 # type: ignore[attr-defined] + with pytest.warns(UnstableWarning): + newest = CompatLevel.newest() + assert newest is CompatLevel.newest() + assert newest._version == 1 # type: ignore[attr-defined] + + str_col = pl.Series(["awd"]) + bin_col = pl.Series([b"dwa"]) + assert str_col._newest_compat_level() == newest._version # type: ignore[attr-defined] + assert isinstance(str_col.to_arrow(), pa.LargeStringArray) + assert isinstance(str_col.to_arrow(compat_level=oldest), pa.LargeStringArray) + assert isinstance(str_col.to_arrow(compat_level=newest), pa.StringViewArray) + assert isinstance(bin_col.to_arrow(), pa.LargeBinaryArray) + assert isinstance(bin_col.to_arrow(compat_level=oldest), pa.LargeBinaryArray) + assert isinstance(bin_col.to_arrow(compat_level=newest), pa.BinaryViewArray) + + df = pl.DataFrame({"str_col": str_col, "bin_col": bin_col}) + assert isinstance(df.to_arrow()["str_col"][0], pa.LargeStringScalar) + assert isinstance( + df.to_arrow(compat_level=oldest)["str_col"][0], pa.LargeStringScalar + ) + assert isinstance( + df.to_arrow(compat_level=newest)["str_col"][0], pa.StringViewScalar + ) + assert isinstance(df.to_arrow()["bin_col"][0], pa.LargeBinaryScalar) + assert isinstance( + df.to_arrow(compat_level=oldest)["bin_col"][0], pa.LargeBinaryScalar + ) + assert isinstance( + df.to_arrow(compat_level=newest)["bin_col"][0], pa.BinaryViewScalar + ) + + assert len(df.write_ipc(None).getbuffer()) == 786 + assert len(df.write_ipc(None, compat_level=oldest).getbuffer()) == 914 + assert len(df.write_ipc(None, compat_level=newest).getbuffer()) == 786 + assert len(df.write_ipc_stream(None).getbuffer()) == 544 + assert len(df.write_ipc_stream(None, compat_level=oldest).getbuffer()) == 672 + assert len(df.write_ipc_stream(None, compat_level=newest).getbuffer()) == 544 + + +def test_df_pycapsule_interface() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": ["fooooooooooooooooooooo", "bar", "looooooooooooooooong string"], + } + ) + + capsule_df = PyCapsuleStreamHolder(df) + out = pa.table(capsule_df) + assert df.shape == out.shape + assert df.schema.names() == out.schema.names + + schema_overrides = {"a": pl.Int128} + expected_schema = pl.Schema([("a", pl.Int128), ("b", pl.String), ("c", pl.String)]) + + for arrow_obj in ( + pl.from_arrow(capsule_df), # capsule + out, # table loaded from capsule + ): + df_res = pl.from_arrow(arrow_obj, schema_overrides=schema_overrides) + assert expected_schema == df_res.schema # type: ignore[union-attr] + assert isinstance(df_res, pl.DataFrame) + assert df.equals(df_res) + + +def test_misaligned_nested_arrow_19097() -> None: + a = pl.Series("a", [1, 2, 3]) + a = a.slice(1, 2) # by slicing we offset=1 the values + a = a.replace(2, None) # then we add a validity mask with offset=0 + a = a.reshape((2, 1)) # then we make it nested + assert_series_equal(pl.Series("a", a.to_arrow()), a) + + +def test_arrow_roundtrip_lex_cat_20288() -> None: + tb = ( + pl.Series("a", ["A", "B"], pl.Categorical(ordering="lexical")) + .to_frame() + .to_arrow() + ) + df = pl.from_arrow(tb) + assert isinstance(df, pl.DataFrame) + dt = df.schema["a"] + assert isinstance(dt, pl.Categorical) + assert dt.ordering == "lexical" + + +def test_from_arrow_string_cache_20271() -> None: + with pl.StringCache(): + s = pl.Series("a", ["A", "B", "C"], pl.Categorical) + df = pl.from_arrow( + pa.table({"b": pa.DictionaryArray.from_arrays([0, 1], ["D", "E"])}) + ) + assert isinstance(df, pl.DataFrame) + + assert_series_equal( + s.to_physical(), pl.Series("a", [0, 1, 2]), check_dtypes=False + ) + assert_series_equal(df.to_series(), pl.Series("b", ["D", "E"], pl.Categorical)) + assert_series_equal( + df.to_series().to_physical(), pl.Series("b", [3, 4]), check_dtypes=False + ) + + +def test_to_arrow_empty_chunks_20627() -> None: + df = pl.concat(2 * [pl.Series([1])]).filter(pl.Series([False, True])).to_frame() + assert df.to_arrow().shape == (1, 1) diff --git a/py-polars/tests/unit/interop/test_to_init_repr.py b/py-polars/tests/unit/interop/test_to_init_repr.py new file mode 100644 index 000000000000..3740b3c68b73 --- /dev/null +++ b/py-polars/tests/unit/interop/test_to_init_repr.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import zoneinfo # noqa: F401 +from datetime import date, datetime, time, timedelta, timezone + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_to_init_repr() -> None: + # round-trip various types + with pl.StringCache(): + df = ( + pl.LazyFrame( + { + "a": [1, 2, None], + "b": [4.5, 5.5, 6.5], + "c": ["x", "y", "z"], + "d": [True, False, True], + "e": [None, "", None], + "f": [date(2022, 7, 5), date(2023, 2, 5), date(2023, 8, 5)], + "g": [time(0, 0, 0, 1), time(12, 30, 45), time(23, 59, 59, 999000)], + "h": [ + datetime(2022, 7, 5, 10, 30, 45, 4560), + datetime(2023, 10, 12, 20, 3, 8, 11), + None, + ], + "i": [ + datetime(2022, 7, 5, 10, 30, 45, 4560), + datetime(2023, 10, 12, 20, 3, 8, 11), + None, + ], + "null": [None, None, None], + "enum": ["a", "b", "c"], + "duration": [timedelta(days=1), timedelta(days=2), None], + "binary": [bytes([0]), bytes([0, 1]), bytes([0, 1, 2])], + "object": [timezone.utc, timezone.utc, timezone.utc], + }, + ) + .with_columns( + pl.col("c").cast(pl.Categorical), + pl.col("h").cast(pl.Datetime("ns")), + pl.col("i").dt.replace_time_zone("Australia/Melbourne"), + pl.col("enum").cast(pl.Enum(["a", "b", "c"])), + ) + .collect() + ) + + result = eval(df.to_init_repr().replace("datetime.", "")) + expected = df + # drop "object" because it can not be compared by assert_frame_equal + assert_frame_equal(result.drop("object"), expected.drop("object")) + + +def test_to_init_repr_nested_dtype() -> None: + # round-trip nested types + df = pl.LazyFrame( + { + "list": pl.Series(values=[[1], [2], [3]], dtype=pl.List(pl.Int32)), + "list_list": pl.Series( + values=[[[1]], [[2]], [[3]]], dtype=pl.List(pl.List(pl.Int8)) + ), + "array": pl.Series( + values=[[1.0], [2.0], [3.0]], + dtype=pl.Array(pl.Float32, 1), + ), + "struct": pl.Series( + values=[ + {"x": "foo", "y": [1, 2]}, + {"x": "bar", "y": [3, 4, 5]}, + {"x": "foobar", "y": []}, + ], + dtype=pl.Struct({"x": pl.String, "y": pl.List(pl.Int8)}), + ), + }, + ).collect() + + assert_frame_equal(eval(df.to_init_repr()), df) + + +def test_to_init_repr_nested_dtype_roundtrip() -> None: + # round-trip nested types + df = pl.LazyFrame( + { + "list": pl.Series(values=[[1], [2], [3]], dtype=pl.List(pl.Int32)), + "list_list": pl.Series( + values=[[[1]], [[2]], [[3]]], dtype=pl.List(pl.List(pl.Int8)) + ), + "array": pl.Series( + values=[[1.0], [2.0], [3.0]], + dtype=pl.Array(pl.Float32, 1), + ), + "struct": pl.Series( + values=[ + {"x": "foo", "y": [1, 2]}, + {"x": "bar", "y": [3, 4, 5]}, + {"x": "foobar", "y": []}, + ], + dtype=pl.Struct({"x": pl.String, "y": pl.List(pl.Int8)}), + ), + }, + ).collect() + + assert_frame_equal(eval(df.to_init_repr()), df) diff --git a/py-polars/tests/unit/interop/test_to_pandas.py b/py-polars/tests/unit/interop/test_to_pandas.py new file mode 100644 index 000000000000..c803df572aa4 --- /dev/null +++ b/py-polars/tests/unit/interop/test_to_pandas.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING, Literal + +import hypothesis.strategies as st +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +from hypothesis import given + +import polars as pl + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_df_to_pandas_empty() -> None: + df = pl.DataFrame() + result = df.to_pandas() + expected = pd.DataFrame() + pd.testing.assert_frame_equal(result, expected) + + +def test_to_pandas() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [6, None, 8], + "c": [10.0, 25.0, 50.5], + "d": [date(2023, 7, 5), None, date(1999, 12, 13)], + "e": ["a", "b", "c"], + "f": [None, "e", "f"], + "g": [datetime.now(), datetime.now(), None], + }, + schema_overrides={"a": pl.UInt8}, + ).with_columns( + pl.col("e").cast(pl.Categorical).alias("h"), + pl.col("f").cast(pl.Categorical).alias("i"), + ) + + pd_out = df.to_pandas() + + pd_out_dtypes_expected = [ + np.dtype(np.uint8), + np.dtype(np.float64), + np.dtype(np.float64), + np.dtype("datetime64[ms]"), + np.dtype(np.object_), + np.dtype(np.object_), + np.dtype("datetime64[us]"), + pd.CategoricalDtype(categories=["a", "b", "c"], ordered=False), + pd.CategoricalDtype(categories=["e", "f"], ordered=False), + ] + assert pd_out_dtypes_expected == pd_out.dtypes.to_list() + + pd_out_dtypes_expected[3] = np.dtype("O") + pd_out = df.to_pandas(date_as_object=True) + assert pd_out_dtypes_expected == pd_out.dtypes.to_list() + + pd_pa_out = df.to_pandas(use_pyarrow_extension_array=True) + pd_pa_dtypes_names = [dtype.name for dtype in pd_pa_out.dtypes] + pd_pa_dtypes_names_expected = [ + "uint8[pyarrow]", + "int64[pyarrow]", + "double[pyarrow]", + "date32[day][pyarrow]", + "large_string[pyarrow]", + "large_string[pyarrow]", + "timestamp[us][pyarrow]", + "dictionary[pyarrow]", + "dictionary[pyarrow]", + ] + assert pd_pa_dtypes_names == pd_pa_dtypes_names_expected + + +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["best", "test"])]) +def test_cat_to_pandas(dtype: pl.DataType) -> None: + df = pl.DataFrame({"a": ["best", "test"]}) + df = df.with_columns(pl.all().cast(dtype)) + + pd_out = df.to_pandas() + assert isinstance(pd_out["a"].dtype, pd.CategoricalDtype) + + pd_pa_out = df.to_pandas(use_pyarrow_extension_array=True) + assert pd_pa_out["a"].dtype == pd.ArrowDtype( + pa.dictionary(pa.int64(), pa.large_string()) + ) + + +@given( + column_type_names=st.lists( + st.one_of(st.just("Object"), st.just("Int32")), min_size=1, max_size=8 + ) +) +def test_object_to_pandas(column_type_names: list[Literal["Object", "Int32"]]) -> None: + """ + Converting ``pl.Object`` dtype columns to Pandas is handled correctly. + + This edge case is handled with a separate code path than other data types, + so we test it more thoroughly. + """ + column_types = [getattr(pl, name) for name in column_type_names] + data = { + f"col_{i}": [object()] if dtype.is_object() else [-i] + for i, dtype in enumerate(column_types) + } + df = pl.DataFrame( + data, schema={f"col_{i}": column_types[i] for i in range(len(column_types))} + ) + for pyarrow in [True, False]: + pandas_df = df.to_pandas(use_pyarrow_extension_array=pyarrow) + assert isinstance(pandas_df, pd.DataFrame) + assert pandas_df.to_dict(orient="list") == data + + +def test_from_empty_pandas_with_dtypes() -> None: + df = pd.DataFrame(columns=["a", "b"]) + df["a"] = df["a"].astype(str) + df["b"] = df["b"].astype(float) + assert pl.from_pandas(df).dtypes == [pl.String, pl.Float64] + + df = pl.DataFrame( + data=[], + schema={ + "a": pl.Int32, + "b": pl.Datetime, + "c": pl.Float32, + "d": pl.Duration, + "e": pl.String, + }, + ).to_pandas() + + assert pl.from_pandas(df).dtypes == [ + pl.Int32, + pl.Datetime, + pl.Float32, + pl.Duration, + pl.String, + ] + + +def test_to_pandas_series() -> None: + assert (pl.Series("a", [1, 2, 3]).to_pandas() == pd.Series([1, 2, 3])).all() + + +def test_to_pandas_date() -> None: + data = [date(1990, 1, 1), date(2024, 12, 31)] + s = pl.Series("a", data) + + result_series = s.to_pandas() + expected_series = pd.Series(data, dtype="datetime64[ms]", name="a") + pd.testing.assert_series_equal(result_series, expected_series) + + result_df = s.to_frame().to_pandas() + expected_df = expected_series.to_frame() + pd.testing.assert_frame_equal(result_df, expected_df) + + +def test_to_pandas_datetime() -> None: + data = [datetime(1990, 1, 1, 0, 0, 0), datetime(2024, 12, 31, 23, 59, 59)] + s = pl.Series("a", data) + + result_series = s.to_pandas() + expected_series = pd.Series(data, dtype="datetime64[us]", name="a") + pd.testing.assert_series_equal(result_series, expected_series) + + result_df = s.to_frame().to_pandas() + expected_df = expected_series.to_frame() + pd.testing.assert_frame_equal(result_df, expected_df) + + +@pytest.mark.parametrize("use_pyarrow_extension_array", [True, False]) +def test_object_to_pandas_series(use_pyarrow_extension_array: bool) -> None: + values = [object(), [1, 2, 3]] + pd.testing.assert_series_equal( + pl.Series("a", values, dtype=pl.Object).to_pandas( + use_pyarrow_extension_array=use_pyarrow_extension_array + ), + pd.Series(values, dtype=object, name="a"), + ) + + +@pytest.mark.parametrize("polars_dtype", [pl.Categorical, pl.Enum(["a", "b"])]) +def test_series_to_pandas_categorical(polars_dtype: PolarsDataType) -> None: + s = pl.Series("x", ["a", "b", "a"], dtype=polars_dtype) + result = s.to_pandas() + expected = pd.Series(["a", "b", "a"], name="x", dtype="category") + pd.testing.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("polars_dtype", [pl.Categorical, pl.Enum(["a", "b"])]) +def test_series_to_pandas_categorical_pyarrow(polars_dtype: PolarsDataType) -> None: + s = pl.Series("x", ["a", "b", "a"], dtype=polars_dtype) + result = s.to_pandas(use_pyarrow_extension_array=True) + assert s.to_list() == result.to_list() diff --git a/py-polars/tests/unit/io/cloud/__init__.py b/py-polars/tests/unit/io/cloud/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/io/cloud/test_aws.py b/py-polars/tests/unit/io/cloud/test_aws.py new file mode 100644 index 000000000000..6f2116421822 --- /dev/null +++ b/py-polars/tests/unit/io/cloud/test_aws.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import multiprocessing +from typing import TYPE_CHECKING, Any, Callable + +import boto3 +import pytest +from moto.server import ThreadedMotoServer + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + +pytestmark = [ + pytest.mark.skip( + reason="Causes intermittent failures in CI. See: " + "https://github.com/pola-rs/polars/issues/16910" + ), + pytest.mark.xdist_group("aws"), + pytest.mark.slow(), +] + + +@pytest.fixture(scope="module") +def monkeypatch_module() -> Any: + """Allow module-scoped monkeypatching.""" + with pytest.MonkeyPatch.context() as mp: + yield mp + + +@pytest.fixture(scope="module") +def s3_base(monkeypatch_module: Any) -> Iterator[str]: + monkeypatch_module.setenv("AWS_ACCESS_KEY_ID", "accesskey") + monkeypatch_module.setenv("AWS_SECRET_ACCESS_KEY", "secretkey") + monkeypatch_module.setenv("AWS_DEFAULT_REGION", "us-east-1") + + host = "127.0.0.1" + port = 5000 + moto_server = ThreadedMotoServer(host, port) + # Start in a separate process to avoid deadlocks + mp = multiprocessing.get_context("spawn") + p = mp.Process(target=moto_server._server_entry, daemon=True) + p.start() + print("server up") + yield f"http://{host}:{port}" + print("moto done") + p.kill() + + +@pytest.fixture +def s3(s3_base: str, io_files_path: Path) -> str: + region = "us-east-1" + client = boto3.client("s3", region_name=region, endpoint_url=s3_base) + client.create_bucket(Bucket="bucket") + + files = ["foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"] + for file in files: + client.upload_file(io_files_path / file, Bucket="bucket", Key=file) + return s3_base + + +@pytest.mark.parametrize( + ("function", "extension"), + [ + (pl.read_csv, "csv"), + (pl.read_ipc, "ipc"), + ], +) +def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None: + storage_options = {"endpoint_url": s3} + df = function( + f"s3://bucket/foods1.{extension}", + storage_options=storage_options, + ) + assert df.columns == ["category", "calories", "fats_g", "sugars_g"] + assert df.shape == (27, 4) + + # ensure we aren't modifying the original user dictionary (ref #15859) + assert storage_options == {"endpoint_url": s3} + + +@pytest.mark.parametrize( + ("function", "extension"), + [(pl.scan_ipc, "ipc"), (pl.scan_parquet, "parquet")], +) +def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None: + lf = function( + f"s3://bucket/foods1.{extension}", + storage_options={"endpoint_url": s3}, + ) + assert lf.collect_schema().names() == ["category", "calories", "fats_g", "sugars_g"] + assert lf.collect().shape == (27, 4) + + +def test_lazy_count_s3(s3: str) -> None: + lf = pl.scan_parquet( + "s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3} + ).select(pl.len()) + + assert "FAST_COUNT" in lf.explain() + expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32}) + assert_frame_equal(lf.collect(), expected) diff --git a/py-polars/tests/unit/io/cloud/test_catalog.py b/py-polars/tests/unit/io/cloud/test_catalog.py new file mode 100644 index 000000000000..c521172678d0 --- /dev/null +++ b/py-polars/tests/unit/io/cloud/test_catalog.py @@ -0,0 +1,11 @@ +import pytest + +import polars as pl + + +def test_catalog_require_https() -> None: + with pytest.raises(ValueError): + pl.Catalog("http://") + + pl.Catalog("https://") + pl.Catalog("http://", require_https=False) diff --git a/py-polars/tests/unit/io/cloud/test_cloud.py b/py-polars/tests/unit/io/cloud/test_cloud.py new file mode 100644 index 000000000000..949b520cbd63 --- /dev/null +++ b/py-polars/tests/unit/io/cloud/test_cloud.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import contextlib +from functools import partial + +import pytest + +import polars as pl +from polars.exceptions import ComputeError + + +@pytest.mark.slow +@pytest.mark.parametrize("format", ["parquet", "csv", "ndjson", "ipc"]) +def test_scan_nonexistent_cloud_path_17444(format: str) -> None: + # https://github.com/pola-rs/polars/issues/17444 + + path_str = f"s3://my-nonexistent-bucket/data.{format}" + scan_function = getattr(pl, f"scan_{format}") + # Prevent automatic credential provideder instantiation, otherwise CI may fail with + # * pytest.PytestUnraisableExceptionWarning: + # * Exception ignored: + # * ResourceWarning: unclosed socket + scan_function = partial(scan_function, credential_provider=None) + + # Just calling the scan function should not raise any errors + if format == "ndjson": + # NDJSON does not have a `retries` parameter yet - so use the default + result = scan_function(path_str) + else: + result = scan_function(path_str, retries=0) + assert isinstance(result, pl.LazyFrame) + + # Upon collection, it should fail + with pytest.raises(ComputeError): + result.collect() + + +def test_scan_err_rebuild_store_19933() -> None: + call_count = 0 + + def f() -> None: + nonlocal call_count + call_count += 1 + raise AssertionError + + q = pl.scan_parquet( + "s3://.../...", + storage_options={"aws_region": "eu-west-1"}, + credential_provider=f, # type: ignore[arg-type] + ) + + with contextlib.suppress(Exception): + q.collect() + + # Note: We get called 2 times per attempt + if call_count != 4: + raise AssertionError(call_count) diff --git a/py-polars/tests/unit/io/cloud/test_credential_provider.py b/py-polars/tests/unit/io/cloud/test_credential_provider.py new file mode 100644 index 000000000000..087646f8ba27 --- /dev/null +++ b/py-polars/tests/unit/io/cloud/test_credential_provider.py @@ -0,0 +1,145 @@ +import io +import pickle +from typing import Any + +import pytest + +import polars as pl +import polars.io.cloud.credential_provider +from polars.exceptions import ComputeError + + +@pytest.mark.parametrize( + "io_func", + [ + *[pl.scan_parquet, pl.read_parquet], + pl.scan_csv, + *[pl.scan_ndjson, pl.read_ndjson], + pl.scan_ipc, + ], +) +def test_credential_provider_scan( + io_func: Any, monkeypatch: pytest.MonkeyPatch +) -> None: + err_magic = "err_magic_3" + + def raises(*_: None, **__: None) -> None: + raise AssertionError(err_magic) + + from polars.io.cloud.credential_provider._builder import CredentialProviderBuilder + + monkeypatch.setattr(CredentialProviderBuilder, "__init__", raises) + + with pytest.raises(AssertionError, match=err_magic): + io_func("s3://bucket/path", credential_provider="auto") + + with pytest.raises(AssertionError, match=err_magic): + io_func( + "s3://bucket/path", + credential_provider="auto", + storage_options={"aws_region": "eu-west-1"}, + ) + + # We can't test these with the `read_` functions as they end up executing + # the query + if io_func.__name__.startswith("scan_"): + # Passing `None` should disable the automatic instantiation of + # `CredentialProviderAWS` + io_func("s3://bucket/path", credential_provider=None) + # Passing `storage_options` should disable the automatic instantiation of + # `CredentialProviderAWS` + io_func( + "s3://bucket/path", + credential_provider="auto", + storage_options={"aws_access_key_id": "polars"}, + ) + + err_magic = "err_magic_7" + + def raises_2() -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + with pytest.raises(AssertionError, match=err_magic): + io_func("s3://bucket/path", credential_provider=raises_2).collect() + + +@pytest.mark.parametrize( + ("provider_class", "path"), + [ + (polars.io.cloud.credential_provider.CredentialProviderAWS, "s3://.../..."), + (polars.io.cloud.credential_provider.CredentialProviderGCP, "gs://.../..."), + (polars.io.cloud.credential_provider.CredentialProviderAzure, "az://.../..."), + ], +) +def test_credential_provider_serialization_auto_init( + provider_class: polars.io.cloud.credential_provider.CredentialProvider, + path: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + def raises_1(*a: Any, **kw: Any) -> None: + msg = "err_magic_1" + raise AssertionError(msg) + + monkeypatch.setattr(provider_class, "__init__", raises_1) + + # If this is not set we will get an error before hitting the credential + # provider logic when polars attempts to retrieve the region from AWS. + monkeypatch.setenv("AWS_REGION", "eu-west-1") + + # Credential provider should not be initialized during query plan construction. + q = pl.scan_parquet(path) + + # Check baseline - query plan is configured to auto-initialize the credential + # provider. + with pytest.raises(pl.exceptions.ComputeError, match="err_magic_1"): + q.collect() + + q = pickle.loads(pickle.dumps(q)) + + def raises_2(*a: Any, **kw: Any) -> None: + msg = "err_magic_2" + raise AssertionError(msg) + + monkeypatch.setattr(provider_class, "__init__", raises_2) + + # Check that auto-initialization happens upon executing the deserialized + # query. + with pytest.raises(pl.exceptions.ComputeError, match="err_magic_2"): + q.collect() + + +def test_credential_provider_serialization_custom_provider() -> None: + err_magic = "err_magic_3" + + class ErrCredentialProvider(pl.CredentialProvider): + def __call__(self) -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=ErrCredentialProvider() + ) + + serialized = lf.serialize() + + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + + with pytest.raises(ComputeError, match=err_magic): + lf.collect() + + +def test_credential_provider_skips_google_config_autoload( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("GOOGLE_SERVICE_ACCOUNT_PATH", "__non_existent") + + with pytest.raises(ComputeError, match="__non_existent"): + pl.scan_parquet("gs://.../...", credential_provider=None).collect() + + err_magic = "err_magic_3" + + def raises() -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + # We should get a different error raised by our `raises()` function. + with pytest.raises(ComputeError, match=err_magic): + pl.scan_parquet("gs://.../...", credential_provider=raises).collect() diff --git a/py-polars/tests/unit/io/conftest.py b/py-polars/tests/unit/io/conftest.py new file mode 100644 index 000000000000..df245c097be4 --- /dev/null +++ b/py-polars/tests/unit/io/conftest.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + + +@pytest.fixture +def io_files_path() -> Path: + return Path(__file__).parent / "files" diff --git a/py-polars/tests/unit/io/database/conftest.py b/py-polars/tests/unit/io/database/conftest.py new file mode 100644 index 000000000000..8fcdc27547a3 --- /dev/null +++ b/py-polars/tests/unit/io/database/conftest.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import sqlite3 +from datetime import date +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.fixture +def tmp_sqlite_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test.db" + test_db.unlink(missing_ok=True) + + def convert_date(val: bytes) -> date: + """Convert ISO 8601 date to datetime.date object.""" + return date.fromisoformat(val.decode()) + + # NOTE: at the time of writing adcb/connectorx have weak SQLite support (poor or + # no bool/date/datetime dtypes, for example) and there is a bug in connectorx that + # causes float rounding < py 3.11, hence we are only testing/storing simple values + # in this test db for now. as support improves, we can add/test additional dtypes). + sqlite3.register_converter("date", convert_date) + conn = sqlite3.connect(test_db) + + # ┌─────┬───────┬───────┬────────────┐ + # │ id ┆ name ┆ value ┆ date │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ str ┆ f64 ┆ date │ + # ╞═════╪═══════╪═══════╪════════════╡ + # │ 1 ┆ misc ┆ 100.0 ┆ 2020-01-01 │ + # │ 2 ┆ other ┆ -99.0 ┆ 2021-12-31 │ + # └─────┴───────┴───────┴────────────┘ + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS test_data ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + value FLOAT, + date DATE + ); + REPLACE INTO test_data(name,value,date) + VALUES ('misc',100.0,'2020-01-01'), + ('other',-99.5,'2021-12-31'); + """ + ) + conn.close() + return test_db + + +@pytest.fixture +def tmp_sqlite_inference_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test_inference.db" + test_db.unlink(missing_ok=True) + conn = sqlite3.connect(test_db) + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS test_data (name TEXT, value FLOAT); + REPLACE INTO test_data(name,value) VALUES (NULL,NULL), ('foo',0); + """ + ) + conn.close() + return test_db diff --git a/py-polars/tests/unit/io/database/test_async.py b/py-polars/tests/unit/io/database/test_async.py new file mode 100644 index 000000000000..48cb06cc0414 --- /dev/null +++ b/py-polars/tests/unit/io/database/test_async.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import asyncio +from math import ceil +from types import ModuleType +from typing import TYPE_CHECKING, Any, overload + +import pytest +import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine + +import polars as pl +from polars._utils.various import parse_version +from polars.io.database._utils import _run_async +from polars.testing import assert_frame_equal +from tests.unit.conftest import mock_module_import + +if TYPE_CHECKING: + from collections.abc import Iterable + from pathlib import Path + +SURREAL_MOCK_DATA: list[dict[str, Any]] = [ + { + "id": "item:8xj31jfpdkf9gvmxdxpi", + "name": "abc", + "tags": ["polars"], + "checked": False, + }, + { + "id": "item:l59k19swv2adsv4q04cj", + "name": "mno", + "tags": ["async"], + "checked": None, + }, + { + "id": "item:w831f1oyqnwztv5q03em", + "name": "xyz", + "tags": ["stroop", "wafel"], + "checked": True, + }, +] + + +class MockSurrealConnection: + """Mock SurrealDB connection/client object.""" + + __module__ = "surrealdb" + + def __init__(self, url: str, mock_data: list[dict[str, Any]]) -> None: + self._mock_data = mock_data.copy() + self.url = url + + async def __aenter__(self) -> Any: + await self.connect() + return self + + async def __aexit__(self, *args: object, **kwargs: Any) -> None: + await self.close() + + async def close(self) -> None: + pass + + async def connect(self) -> None: + pass + + async def use(self, namespace: str, database: str) -> None: + pass + + async def query( + self, query: str, variables: dict[str, Any] | None = None + ) -> list[dict[str, Any]]: + return [{"result": self._mock_data, "status": "OK", "time": "32.083µs"}] + + +class MockedSurrealModule(ModuleType): + """Mock SurrealDB module; enables internal `isinstance` check for AsyncSurrealDB.""" + + AsyncSurrealDB = MockSurrealConnection + + +@pytest.mark.skipif( + parse_version(sqlalchemy.__version__) < (2, 0), + reason="SQLAlchemy 2.0+ required for async tests", +) +def test_read_async(tmp_sqlite_db: Path) -> None: + # confirm that we can load frame data from the core sqlalchemy async + # primitives: AsyncConnection, AsyncEngine, and async_sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + + async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") + async_connection = async_engine.connect() + async_session = async_sessionmaker(async_engine) + async_session_inst = async_session() + + expected_frame = pl.DataFrame( + {"id": [2, 1], "name": ["other", "misc"], "value": [-99.5, 100.0]} + ) + async_conn: Any + for async_conn in ( + async_engine, + async_connection, + async_session, + async_session_inst, + ): + if async_conn in (async_session, async_session_inst): + constraint, execute_opts = "", {} + else: + constraint = "WHERE value > :n" + execute_opts = {"parameters": {"n": -1000}} + + df = pl.read_database( + query=f""" + SELECT id, name, value + FROM test_data {constraint} + ORDER BY id DESC + """, + connection=async_conn, + execute_options=execute_opts, + ) + assert_frame_equal(expected_frame, df) + + +async def _nested_async_test(tmp_sqlite_db: Path) -> pl.DataFrame: + async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}") + return pl.read_database( + query="SELECT id, name FROM test_data ORDER BY id", + connection=async_engine.connect(), + ) + + +@pytest.mark.skipif( + parse_version(sqlalchemy.__version__) < (2, 0), + reason="SQLAlchemy 2.0+ required for async tests", +) +def test_read_async_nested(tmp_sqlite_db: Path) -> None: + # This tests validates that we can handle nested async calls + expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]}) + df = asyncio.run(_nested_async_test(tmp_sqlite_db)) + assert_frame_equal(expected_frame, df) + + +@overload +async def _surreal_query_as_frame( + url: str, query: str, batch_size: None +) -> pl.DataFrame: ... + + +@overload +async def _surreal_query_as_frame( + url: str, query: str, batch_size: int +) -> Iterable[pl.DataFrame]: ... + + +async def _surreal_query_as_frame( + url: str, query: str, batch_size: int | None +) -> pl.DataFrame | Iterable[pl.DataFrame]: + batch_params = ( + {"iter_batches": True, "batch_size": batch_size} if batch_size else {} + ) + async with MockSurrealConnection(url=url, mock_data=SURREAL_MOCK_DATA) as client: + await client.use(namespace="test", database="test") + return pl.read_database( # type: ignore[no-any-return,call-overload] + query=query, + connection=client, + **batch_params, + ) + + +@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 4]) +def test_surrealdb_fetchall(batch_size: int | None) -> None: + with mock_module_import("surrealdb", MockedSurrealModule("surrealdb")): + df_expected = pl.DataFrame(SURREAL_MOCK_DATA) + res = asyncio.run( + _surreal_query_as_frame( + url="ws://localhost:8000/rpc", + query="SELECT * FROM item", + batch_size=batch_size, + ) + ) + if batch_size: + frames = list(res) # type: ignore[call-overload] + n_mock_rows = len(SURREAL_MOCK_DATA) + assert len(frames) == ceil(n_mock_rows / batch_size) + assert_frame_equal(df_expected[:batch_size], frames[0]) + else: + assert_frame_equal(df_expected, res) # type: ignore[arg-type] + + +def test_async_nested_captured_loop_21263() -> None: + # Tests awaiting a future that has "captured" the original event loop from + # within a `_run_async` context. + async def test_impl() -> None: + loop = asyncio.get_running_loop() + task = loop.create_task(asyncio.sleep(0)) + + _run_async(await_task(task)) + + async def await_task(task: Any) -> None: + await task + + asyncio.run(test_impl()) diff --git a/py-polars/tests/unit/io/database/test_inference.py b/py-polars/tests/unit/io/database/test_inference.py new file mode 100644 index 000000000000..d4140a576f4c --- /dev/null +++ b/py-polars/tests/unit/io/database/test_inference.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import sqlite3 +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.io.database._inference import _infer_dtype_from_database_typename + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + ("value", "expected_dtype"), + [ + # string types + ("UTF16", pl.String), + ("char(8)", pl.String), + ("BPCHAR", pl.String), + ("nchar[128]", pl.String), + ("varchar", pl.String), + ("CHARACTER VARYING(64)", pl.String), + ("nvarchar(32)", pl.String), + ("TEXT", pl.String), + # array types + ("float32[]", pl.List(pl.Float32)), + ("double array", pl.List(pl.Float64)), + ("array[bool]", pl.List(pl.Boolean)), + ("array of nchar(8)", pl.List(pl.String)), + ("array[array[int8]]", pl.List(pl.List(pl.Int64))), + # numeric types + ("numeric[10,5]", pl.Decimal(10, 5)), + ("bigdecimal", pl.Decimal), + ("decimal128(10,5)", pl.Decimal(10, 5)), + ("double precision", pl.Float64), + ("floating point", pl.Float64), + ("numeric", pl.Float64), + ("real", pl.Float64), + ("boolean", pl.Boolean), + ("tinyint", pl.Int8), + ("smallint", pl.Int16), + ("int", pl.Int64), + ("int4", pl.Int32), + ("int2", pl.Int16), + ("int(16)", pl.Int16), + ("ROWID", pl.UInt64), + ("mediumint", pl.Int32), + ("unsigned mediumint", pl.UInt32), + ("cardinal_number", pl.UInt64), + ("smallserial", pl.Int16), + ("serial", pl.Int32), + ("bigserial", pl.Int64), + # temporal types + ("timestamp(3)", pl.Datetime("ms")), + ("timestamp(5)", pl.Datetime("us")), + ("timestamp(7)", pl.Datetime("ns")), + ("datetime without tz", pl.Datetime("us")), + ("duration(2)", pl.Duration("ms")), + ("interval", pl.Duration("us")), + ("date", pl.Date), + ("time", pl.Time), + ("date32", pl.Date), + ("time64", pl.Time), + # binary types + ("BYTEA", pl.Binary), + ("BLOB", pl.Binary), + # miscellaneous + ("NULL", pl.Null), + ], +) +def test_dtype_inference_from_string( + value: str, + expected_dtype: PolarsDataType, +) -> None: + inferred_dtype = _infer_dtype_from_database_typename(value) + assert inferred_dtype == expected_dtype # type: ignore[operator] + + +@pytest.mark.parametrize( + "value", + [ + "FooType", + "Unknown", + "MISSING", + "XML", # note: we deliberately exclude "number" as it is ambiguous. + "Number", # (could refer to any size of int, float, or decimal dtype) + ], +) +def test_dtype_inference_from_invalid_string(value: str) -> None: + with pytest.raises(ValueError, match="cannot infer dtype"): + _infer_dtype_from_database_typename(value) + + inferred_dtype = _infer_dtype_from_database_typename( + value=value, + raise_unmatched=False, + ) + assert inferred_dtype is None + + +def test_infer_schema_length(tmp_sqlite_inference_db: Path) -> None: + # note: first row of this test database contains only NULL values + conn = sqlite3.connect(tmp_sqlite_inference_db) + for infer_len in (2, 100, None): + df = pl.read_database( + connection=conn, + query="SELECT * FROM test_data", + infer_schema_length=infer_len, + ) + assert df.schema == {"name": pl.String, "value": pl.Float64} + + with pytest.raises( + ComputeError, + match='could not append value: "foo" of type: str.*`infer_schema_length`', + ): + pl.read_database( + connection=conn, + query="SELECT * FROM test_data", + infer_schema_length=1, + ) diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py new file mode 100644 index 000000000000..176ba2b3e857 --- /dev/null +++ b/py-polars/tests/unit/io/database/test_read.py @@ -0,0 +1,836 @@ +from __future__ import annotations + +import os +import sqlite3 +import sys +from contextlib import suppress +from datetime import date +from pathlib import Path +from types import GeneratorType +from typing import TYPE_CHECKING, Any, NamedTuple, cast + +import pyarrow as pa +import pytest +import sqlalchemy +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql.expression import cast as alchemy_cast + +import polars as pl +from polars._utils.various import parse_version +from polars.exceptions import DuplicateError, UnsuitableSQLError +from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import ( + ConnectionOrCursor, + DbReadEngine, + SchemaDefinition, + SchemaDict, + ) + + +def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: + with suppress(ModuleNotFoundError): # not available on windows + from adbc_driver_sqlite.dbapi import connect + + args = tuple(str(a) if isinstance(a, Path) else a for a in args) + return connect(*args, **kwargs) + + +class MockConnection: + """Mock connection class for databases we can't test in CI.""" + + def __init__( + self, + driver: str, + batch_size: int | None, + exact_batch_size: bool, + test_data: pa.Table, + repeat_batch_calls: bool, + ) -> None: + self.__class__.__module__ = driver + self._cursor = MockCursor( + repeat_batch_calls=repeat_batch_calls, + exact_batch_size=exact_batch_size, + batched=(batch_size is not None), + test_data=test_data, + ) + + def close(self) -> None: + pass + + def cursor(self) -> Any: + return self._cursor + + +class MockCursor: + """Mock cursor class for databases we can't test in CI.""" + + def __init__( + self, + batched: bool, + exact_batch_size: bool, + test_data: pa.Table, + repeat_batch_calls: bool, + ) -> None: + self.resultset = MockResultSet( + test_data=test_data, + batched=batched, + exact_batch_size=exact_batch_size, + repeat_batch_calls=repeat_batch_calls, + ) + self.exact_batch_size = exact_batch_size + self.called: list[str] = [] + self.batched = batched + self.n_calls = 1 + + def __getattr__(self, name: str) -> Any: + if "fetch" in name: + self.called.append(name) + return self.resultset + super().__getattr__(name) # type: ignore[misc] + + def close(self) -> Any: + pass + + def execute(self, query: str) -> Any: + return self + + +class MockResultSet: + """Mock resultset class for databases we can't test in CI.""" + + def __init__( + self, + test_data: pa.Table, + batched: bool, + exact_batch_size: bool, + repeat_batch_calls: bool = False, + ) -> None: + self.test_data = test_data + self.repeat_batched_calls = repeat_batch_calls + self.exact_batch_size = exact_batch_size + self.batched = batched + self.n_calls = 1 + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if not self.exact_batch_size: + assert len(args) == 0 + if self.repeat_batched_calls: + res = self.test_data[: None if self.n_calls else 0] + self.n_calls -= 1 + else: + res = iter((self.test_data,)) + return res + + +class DatabaseReadTestParams(NamedTuple): + """Clarify read test params.""" + + read_method: str + connect_using: Any + expected_dtypes: SchemaDefinition + expected_dates: list[date | str] + schema_overrides: SchemaDict | None = None + batch_size: int | None = None + + +class ExceptionTestParams(NamedTuple): + """Clarify exception test params.""" + + read_method: str + query: str | list[str] + protocol: Any + errclass: type[Exception] + errmsg: str + engine: str | None = None + execute_options: dict[str, Any] | None = None + kwargs: dict[str, Any] | None = None + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ( + "read_method", + "connect_using", + "expected_dtypes", + "expected_dates", + "schema_overrides", + "batch_size", + ), + [ + pytest.param( + *DatabaseReadTestParams( + read_method="read_database_uri", + connect_using="connectorx", + expected_dtypes={ + "id": pl.UInt8, + "name": pl.String, + "value": pl.Float64, + "date": pl.Date, + }, + expected_dates=[date(2020, 1, 1), date(2021, 12, 31)], + schema_overrides={"id": pl.UInt8}, + ), + id="uri: connectorx", + ), + pytest.param( + *DatabaseReadTestParams( + read_method="read_database_uri", + connect_using="adbc", + expected_dtypes={ + "id": pl.UInt8, + "name": pl.String, + "value": pl.Float64, + "date": pl.String, + }, + expected_dates=["2020-01-01", "2021-12-31"], + schema_overrides={"id": pl.UInt8}, + ), + marks=pytest.mark.skipif( + sys.platform == "win32", + reason="adbc_driver_sqlite not available on Windows", + ), + id="uri: adbc", + ), + pytest.param( + *DatabaseReadTestParams( + read_method="read_database", + connect_using=lambda path: sqlite3.connect(path, detect_types=True), + expected_dtypes={ + "id": pl.UInt8, + "name": pl.String, + "value": pl.Float32, + "date": pl.Date, + }, + expected_dates=[date(2020, 1, 1), date(2021, 12, 31)], + schema_overrides={"id": pl.UInt8, "value": pl.Float32}, + ), + id="conn: sqlite3", + ), + pytest.param( + *DatabaseReadTestParams( + read_method="read_database", + connect_using=lambda path: sqlite3.connect(path, detect_types=True), + expected_dtypes={ + "id": pl.Int32, + "name": pl.String, + "value": pl.Float32, + "date": pl.Date, + }, + expected_dates=[date(2020, 1, 1), date(2021, 12, 31)], + schema_overrides={"id": pl.Int32, "value": pl.Float32}, + batch_size=1, + ), + id="conn: sqlite3", + ), + pytest.param( + *DatabaseReadTestParams( + read_method="read_database", + connect_using=lambda path: create_engine( + f"sqlite:///{path}", + connect_args={"detect_types": sqlite3.PARSE_DECLTYPES}, + ).connect(), + expected_dtypes={ + "id": pl.Int64, + "name": pl.String, + "value": pl.Float64, + "date": pl.Date, + }, + expected_dates=[date(2020, 1, 1), date(2021, 12, 31)], + ), + id="conn: sqlalchemy", + ), + pytest.param( + *DatabaseReadTestParams( + read_method="read_database", + connect_using=adbc_sqlite_connect, + expected_dtypes={ + "id": pl.Int64, + "name": pl.String, + "value": pl.Float64, + "date": pl.String, + }, + expected_dates=["2020-01-01", "2021-12-31"], + ), + marks=pytest.mark.skipif( + sys.platform == "win32", + reason="adbc_driver_sqlite not available on Windows", + ), + id="conn: adbc (fetchall)", + ), + pytest.param( + *DatabaseReadTestParams( + read_method="read_database", + connect_using=adbc_sqlite_connect, + expected_dtypes={ + "id": pl.Int64, + "name": pl.String, + "value": pl.Float64, + "date": pl.String, + }, + expected_dates=["2020-01-01", "2021-12-31"], + batch_size=1, + ), + marks=pytest.mark.skipif( + sys.platform == "win32", + reason="adbc_driver_sqlite not available on Windows", + ), + id="conn: adbc (batched)", + ), + ], +) +def test_read_database( + read_method: str, + connect_using: Any, + expected_dtypes: dict[str, pl.DataType], + expected_dates: list[date | str], + schema_overrides: SchemaDict | None, + batch_size: int | None, + tmp_sqlite_db: Path, +) -> None: + if read_method == "read_database_uri": + connect_using = cast("DbReadEngine", connect_using) + # instantiate the connection ourselves, using connectorx/adbc + df = pl.read_database_uri( + uri=f"sqlite:///{tmp_sqlite_db}", + query="SELECT * FROM test_data", + engine=connect_using, + schema_overrides=schema_overrides, + ) + df_empty = pl.read_database_uri( + uri=f"sqlite:///{tmp_sqlite_db}", + query="SELECT * FROM test_data WHERE name LIKE '%polars%'", + engine=connect_using, + schema_overrides=schema_overrides, + ) + elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: + # externally instantiated adbc connections + with connect_using(tmp_sqlite_db) as conn, conn.cursor(): + df = pl.read_database( + connection=conn, + query="SELECT * FROM test_data", + schema_overrides=schema_overrides, + batch_size=batch_size, + ) + df_empty = pl.read_database( + connection=conn, + query="SELECT * FROM test_data WHERE name LIKE '%polars%'", + schema_overrides=schema_overrides, + batch_size=batch_size, + ) + else: + # other user-supplied connections + df = pl.read_database( + connection=connect_using(tmp_sqlite_db), + query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'", + schema_overrides=schema_overrides, + batch_size=batch_size, + ) + df_empty = pl.read_database( + connection=connect_using(tmp_sqlite_db), + query="SELECT * FROM test_data WHERE name LIKE '%polars%'", + schema_overrides=schema_overrides, + batch_size=batch_size, + ) + + # validate the expected query return (data and schema) + assert df.schema == expected_dtypes + assert df.shape == (2, 4) + assert df["date"].to_list() == expected_dates + + # note: 'cursor.description' is not reliable when no query + # data is returned, so no point comparing expected dtypes + assert df_empty.columns == ["id", "name", "value", "date"] + assert df_empty.shape == (0, 4) + assert df_empty["date"].to_list() == [] + + +def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + t = Table("test_data", MetaData(), autoload_with=alchemy_engine) + + # establish sqlalchemy "selectable" and validate usage + selectable_query = select( + alchemy_cast(func.strftime("%Y", t.c.date), Integer).label("year"), + t.c.name, + t.c.value, + ).where(t.c.value < 0) + + expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(selectable_query, connection=conn), + expected, + ) + + batches = list( + pl.read_database( + selectable_query, + connection=conn, + iter_batches=True, + batch_size=1, + ) + ) + assert len(batches) == 1 + assert_frame_equal(batches[0], expected) + + +def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None: + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + # establish sqlalchemy "textclause" and validate usage + textclause_query = text( + """ + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < 0 + """ + ) + + expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(textclause_query, connection=conn), + expected, + ) + + batches = list( + pl.read_database( + textclause_query, + connection=conn, + iter_batches=True, + batch_size=1, + ) + ) + assert len(batches) == 1 + assert_frame_equal(batches[0], expected) + + +def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: + # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db) + + # establish parameterised queries and validate usage + query = """ + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < {n} + """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for param, param_value in ( + (":n", {"n": 0}), + ("?", (0,)), + ("?", [0]), + ): + for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn): + if alchemy_session is conn and param == "?": + continue # alchemy session.execute() doesn't support positional params + if parse_version(sqlalchemy.__version__) < (2, 0) and param == ":n": + continue # skip for older sqlalchemy versions + + assert_frame_equal( + expected_frame, + pl.read_database( + query.format(n=param), + connection=conn, + execute_options={"parameters": param_value}, + ), + ) + + +@pytest.mark.parametrize( + ("param", "param_value"), + [ + (":n", {"n": 0}), + ("?", (0,)), + ("?", [0]), + ], +) +@pytest.mark.skipif( + sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows" +) +def test_read_database_parameterised_uri( + param: str, param_value: Any, tmp_sqlite_db: Path +) -> None: + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + uri = alchemy_engine.url.render_as_string(hide_password=False) + query = """ + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < {n} + """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for param, param_value in ( + (":n", pa.Table.from_pydict({"n": [0]})), + ("?", (0,)), + ("?", [0]), + ): + # test URI read method (adbc only) + assert_frame_equal( + expected_frame, + pl.read_database_uri( + query.format(n=param), + uri=uri, + engine="adbc", + execute_options={"parameters": param_value}, + ), + ) + + # no connectorx support for execute_options + with pytest.raises( + ValueError, + match="connectorx.*does not support.*execute_options", + ): + pl.read_database_uri( + query.format(n=":n"), + uri=uri, + engine="connectorx", + execute_options={"parameters": (":n", {"n": 0})}, + ) + + +@pytest.mark.parametrize( + ("driver", "batch_size", "iter_batches", "expected_call"), + [ + ("snowflake", None, False, "fetch_arrow_all"), + ("snowflake", 10_000, False, "fetch_arrow_all"), + ("snowflake", 10_000, True, "fetch_arrow_batches"), + ("databricks", None, False, "fetchall_arrow"), + ("databricks", 25_000, False, "fetchall_arrow"), + ("databricks", 25_000, True, "fetchmany_arrow"), + ("turbodbc", None, False, "fetchallarrow"), + ("turbodbc", 50_000, False, "fetchallarrow"), + ("turbodbc", 50_000, True, "fetcharrowbatches"), + ("adbc_driver_postgresql", None, False, "fetch_arrow_table"), + ("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"), + ("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"), + ], +) +def test_read_database_mocked( + driver: str, batch_size: int | None, iter_batches: bool, expected_call: str +) -> None: + # since we don't have access to snowflake/databricks/etc from CI we + # mock them so we can check that we're calling the expected methods + arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() + + reg = ARROW_DRIVER_REGISTRY.get(driver, {}) # type: ignore[var-annotated] + exact_batch_size = reg.get("exact_batch_size", False) + repeat_batch_calls = reg.get("repeat_batch_calls", False) + + mc = MockConnection( + driver, + batch_size, + test_data=arrow, + repeat_batch_calls=repeat_batch_calls, + exact_batch_size=exact_batch_size, # type: ignore[arg-type] + ) + res = pl.read_database( + query="SELECT * FROM test_data", + connection=mc, + iter_batches=iter_batches, + batch_size=batch_size, + ) + if iter_batches: + assert isinstance(res, GeneratorType) + res = pl.concat(res) + + res = cast(pl.DataFrame, res) + assert expected_call in mc.cursor().called + assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")] + + +@pytest.mark.parametrize( + ( + "read_method", + "query", + "protocol", + "errclass", + "errmsg", + "engine", + "execute_options", + "kwargs", + ), + [ + pytest.param( + *ExceptionTestParams( + read_method="read_database_uri", + query="SELECT * FROM test_data", + protocol="sqlite", + errclass=ValueError, + errmsg="engine must be one of {'connectorx', 'adbc'}, got 'not_an_engine'", + engine="not_an_engine", + ), + id="Not an available sql engine", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database_uri", + query=["SELECT * FROM test_data", "SELECT * FROM test_data"], + protocol="sqlite", + errclass=ValueError, + errmsg="only a single SQL query string is accepted for adbc", + engine="adbc", + ), + id="Unavailable list of queries for adbc", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database_uri", + query="SELECT * FROM test_data", + protocol="mysql", + errclass=ModuleNotFoundError, + errmsg="ADBC 'adbc_driver_mysql.dbapi' driver not detected.", + engine="adbc", + ), + id="Unavailable adbc driver", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database_uri", + query="SELECT * FROM test_data", + protocol=sqlite3.connect(":memory:"), + errclass=TypeError, + errmsg="expected connection to be a URI string", + engine="adbc", + ), + id="Invalid connection URI", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + query="SELECT * FROM imaginary_table", + protocol=sqlite3.connect(":memory:"), + errclass=sqlite3.OperationalError, + errmsg="no such table: imaginary_table", + ), + id="Invalid query (unrecognised table name)", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + query="SELECT * FROM imaginary_table", + protocol=sys.getsizeof, # not a connection + errclass=TypeError, + errmsg="Unrecognised connection .* no 'execute' or 'cursor' method", + ), + id="Invalid read DB kwargs", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + query="/* tag: misc */ INSERT INTO xyz VALUES ('polars')", + protocol=sqlite3.connect(":memory:"), + errclass=UnsuitableSQLError, + errmsg="INSERT statements are not valid 'read' queries", + ), + id="Invalid statement type", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + query="DELETE FROM xyz WHERE id = 'polars'", + protocol=sqlite3.connect(":memory:"), + errclass=UnsuitableSQLError, + errmsg="DELETE statements are not valid 'read' queries", + ), + id="Invalid statement type", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + query="SELECT * FROM sqlite_master", + protocol=sqlite3.connect(":memory:"), + errclass=ValueError, + kwargs={"iter_batches": True}, + errmsg="Cannot set `iter_batches` without also setting a non-zero `batch_size`", + ), + id="Invalid batch_size", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + engine="adbc", + query="SELECT * FROM test_data", + protocol=sqlite3.connect(":memory:"), + errclass=TypeError, + errmsg=r"unexpected keyword argument 'partition_on'", + kwargs={"partition_on": "id"}, + ), + id="Invalid kwargs", + ), + pytest.param( + *ExceptionTestParams( + read_method="read_database", + engine="adbc", + query="SELECT * FROM test_data", + protocol="{not:a, valid:odbc_string}", + errclass=ValueError, + errmsg=r"unable to identify string connection as valid ODBC", + ), + id="Invalid ODBC string", + ), + ], +) +def test_read_database_exceptions( + read_method: str, + query: str, + protocol: Any, + errclass: type[Exception], + errmsg: str, + engine: DbReadEngine | None, + execute_options: dict[str, Any] | None, + kwargs: dict[str, Any] | None, +) -> None: + if read_method == "read_database_uri": + conn = f"{protocol}://test" if isinstance(protocol, str) else protocol + params = {"uri": conn, "query": query, "engine": engine} + else: + params = {"connection": protocol, "query": query} + if execute_options: + params["execute_options"] = execute_options + if kwargs is not None: + params.update(kwargs) + + read_database = getattr(pl, read_method) + with pytest.raises(errclass, match=errmsg): + read_database(**params) + + +@pytest.mark.parametrize( + "query", + [ + "SELECT 1, 1 FROM test_data", + 'SELECT 1 AS "n", 2 AS "n" FROM test_data', + 'SELECT name, value AS "name" FROM test_data', + ], +) +def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None: + alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect() + with pytest.raises( + DuplicateError, + match="column .+ appears more than once in the query/result cursor", + ): + pl.read_database(query, connection=alchemy_conn) + + +@pytest.mark.parametrize( + "uri", + [ + "fakedb://123:456@account/database/schema?warehouse=warehouse&role=role", + "fakedb://my#%us3r:p433w0rd@not_a_real_host:9999/database", + ], +) +def test_read_database_cx_credentials(uri: str) -> None: + with pytest.raises(RuntimeError, match=r"Source.*not supported"): + pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx") + + +@pytest.mark.skipif( + sys.platform == "win32", + reason="kuzu segfaults on windows: https://github.com/pola-rs/polars/actions/runs/12502055945/job/34880479875?pr=20462", +) +@pytest.mark.write_disk +def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: + import kuzu + + tmp_path.mkdir(exist_ok=True) + if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): + kuzu_test_db.unlink() + + test_db = str(kuzu_test_db).replace("\\", "/") + + db = kuzu.Database(test_db) + conn = kuzu.Connection(db) + conn.execute("CREATE NODE TABLE User(name STRING, age UINT64, PRIMARY KEY (name))") + conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") + + users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/") + follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/") + + conn.execute(f'COPY User FROM "{users}"') + conn.execute(f'COPY Follows FROM "{follows}"') + + # basic: single relation + df1 = pl.read_database( + query="MATCH (u:User) RETURN u.name, u.age", + connection=conn, + ) + assert_frame_equal( + df1, + pl.DataFrame( + { + "u.name": ["Adam", "Karissa", "Zhang", "Noura"], + "u.age": [30, 40, 50, 25], + }, + schema={"u.name": pl.Utf8, "u.age": pl.UInt64}, + ), + ) + + # join: connected edges/relations + df2 = pl.read_database( + query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name", + connection=conn, + schema_overrides={"f.since": pl.Int16}, + ) + assert_frame_equal( + df2, + pl.DataFrame( + { + "a.name": ["Adam", "Adam", "Karissa", "Zhang"], + "f.since": [2020, 2020, 2021, 2022], + "b.name": ["Karissa", "Zhang", "Zhang", "Noura"], + }, + schema={"a.name": pl.Utf8, "f.since": pl.Int16, "b.name": pl.Utf8}, + ), + ) + + # empty: no results for the given query + df3 = pl.read_database( + query="MATCH (a:User)-[f:Follows]->(b:User) WHERE a.name = '🔎️' RETURN a.name, f.since, b.name", + connection=conn, + ) + assert_frame_equal( + df3, + pl.DataFrame( + schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8} + ), + ) + + +def test_sqlalchemy_row_init(tmp_sqlite_db: Path) -> None: + expected_frame = pl.DataFrame( + { + "id": [1, 2], + "name": ["misc", "other"], + "value": [100.0, -99.5], + "date": ["2020-01-01", "2021-12-31"], + } + ) + expected_series = expected_frame.to_struct() + + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + with alchemy_engine.connect() as conn: + query_result = conn.execute(text("SELECT * FROM test_data ORDER BY name")) + df = pl.DataFrame(list(query_result)) + assert_frame_equal(expected_frame, df) + + query_result = conn.execute(text("SELECT * FROM test_data ORDER BY name")) + s = pl.Series(list(query_result)) + assert_series_equal(expected_series, s) diff --git a/py-polars/tests/unit/io/database/test_write.py b/py-polars/tests/unit/io/database/test_write.py new file mode 100644 index 000000000000..0a4fa7218cff --- /dev/null +++ b/py-polars/tests/unit/io/database/test_write.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session +from sqlalchemy.pool import NullPool + +import polars as pl +from polars.io.database._utils import _open_adbc_connection +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import DbWriteEngine + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("engine", "uri_connection"), + [ + ("sqlalchemy", True), + ("sqlalchemy", False), + pytest.param( + "adbc", + True, + marks=pytest.mark.skipif( + sys.platform == "win32", + reason="adbc not available on Windows", + ), + ), + pytest.param( + "adbc", + False, + marks=pytest.mark.skipif( + sys.platform == "win32", + reason="adbc not available on Windows", + ), + ), + ], +) +class TestWriteDatabase: + """Database write tests that share common pytest/parametrize options.""" + + @staticmethod + def _get_connection(uri: str, engine: DbWriteEngine, uri_connection: bool) -> Any: + if uri_connection: + return uri + elif engine == "sqlalchemy": + return create_engine(uri) + else: + return _open_adbc_connection(uri) + + def test_write_database_create( + self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path + ) -> None: + """Test basic database table creation.""" + df = pl.DataFrame( + { + "id": [1234, 5678], + "name": ["misc", "other"], + "value": [1000.0, -9999.0], + } + ) + tmp_path.mkdir(exist_ok=True) + test_db_uri = f"sqlite:///{tmp_path}/test_create_{int(uri_connection)}.db" + + table_name = "test_create" + conn = self._get_connection(test_db_uri, engine, uri_connection) + + assert ( + df.write_database( + table_name=table_name, + connection=conn, + engine=engine, + ) + == 2 + ) + result = pl.read_database( + query=f"SELECT * FROM {table_name}", + connection=create_engine(test_db_uri), + ) + assert_frame_equal(result, df) + + if hasattr(conn, "close"): + conn.close() + + def test_write_database_append_replace( + self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path + ) -> None: + """Test append/replace ops against existing database table.""" + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + tmp_path.mkdir(exist_ok=True) + test_db_uri = f"sqlite:///{tmp_path}/test_append_{int(uri_connection)}.db" + + table_name = "test_append" + conn = self._get_connection(test_db_uri, engine, uri_connection) + + assert ( + df.write_database( + table_name=table_name, + connection=conn, + engine=engine, + ) + == 3 + ) + with pytest.raises(Exception): # noqa: B017 + df.write_database( + table_name=table_name, + connection=conn, + if_table_exists="fail", + engine=engine, + ) + + assert ( + df.write_database( + table_name=table_name, + connection=conn, + if_table_exists="replace", + engine=engine, + ) + == 3 + ) + result = pl.read_database( + query=f"SELECT * FROM {table_name}", + connection=create_engine(test_db_uri), + ) + assert_frame_equal(result, df) + + assert ( + df[:2].write_database( + table_name=table_name, + connection=conn, + if_table_exists="append", + engine=engine, + ) + == 2 + ) + result = pl.read_database( + query=f"SELECT * FROM {table_name}", + connection=create_engine(test_db_uri), + ) + assert_frame_equal(result, pl.concat([df, df[:2]])) + + if engine == "adbc" and not uri_connection: + assert conn._closed is False + + if hasattr(conn, "close"): + conn.close() + + def test_write_database_create_quoted_tablename( + self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path + ) -> None: + """Test parsing/handling of quoted database table names.""" + df = pl.DataFrame( + { + "col x": [100, 200, 300], + "col y": ["a", "b", "c"], + } + ) + tmp_path.mkdir(exist_ok=True) + test_db_uri = f"sqlite:///{tmp_path}/test_create_quoted.db" + + # table name has some special chars, so requires quoting, and + # is explicitly qualified with the sqlite 'main' schema + qualified_table_name = f'main."test-append-{engine}-{int(uri_connection)}"' + conn = self._get_connection(test_db_uri, engine, uri_connection) + + assert ( + df.write_database( + table_name=qualified_table_name, + connection=conn, + engine=engine, + ) + == 3 + ) + assert ( + df.write_database( + table_name=qualified_table_name, + connection=conn, + if_table_exists="replace", + engine=engine, + ) + == 3 + ) + result = pl.read_database( + query=f"SELECT * FROM {qualified_table_name}", + connection=create_engine(test_db_uri), + ) + assert_frame_equal(result, df) + + if engine == "adbc" and not uri_connection: + assert conn._closed is False + + if hasattr(conn, "close"): + conn.close() + + def test_write_database_errors( + self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path + ) -> None: + """Confirm that expected errors are raised.""" + df = pl.DataFrame({"colx": [1, 2, 3]}) + + with pytest.raises( + ValueError, match="`table_name` appears to be invalid: 'w.x.y.z'" + ): + df.write_database( + connection="sqlite:///:memory:", + table_name="w.x.y.z", + engine=engine, + ) + + with pytest.raises( + ValueError, + match="`if_table_exists` must be one of .* got 'do_something'", + ): + df.write_database( + connection="sqlite:///:memory:", + table_name="main.test_errs", + if_table_exists="do_something", # type: ignore[arg-type] + engine=engine, + ) + + with pytest.raises( + TypeError, + match="unrecognised connection type", + ): + df.write_database(connection=True, table_name="misc") # type: ignore[arg-type] + + +@pytest.mark.write_disk +def test_write_database_using_sa_session(tmp_path: str) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + table_name = "test_sa_session" + test_db_uri = f"sqlite:///{tmp_path}/test_sa_session.db" + engine = create_engine(test_db_uri, poolclass=NullPool) + with Session(engine) as session: + df.write_database(table_name, session) + session.commit() + + with Session(engine) as session: + result = pl.read_database( + query=f"select * from {table_name}", connection=session + ) + + assert_frame_equal(result, df) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("pass_connection", [True, False]) +def test_write_database_sa_rollback(tmp_path: str, pass_connection: bool) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + table_name = "test_sa_rollback" + test_db_uri = f"sqlite:///{tmp_path}/test_sa_rollback.db" + engine = create_engine(test_db_uri, poolclass=NullPool) + with Session(engine) as session: + if pass_connection: + conn = session.connection() + df.write_database(table_name, conn) + else: + df.write_database(table_name, session) + session.rollback() + + with Session(engine) as session: + count = pl.read_database( + query=f"select count(*) from {table_name}", connection=session + ).item(0, 0) + + assert isinstance(count, int) + assert count == 0 + + +@pytest.mark.write_disk +@pytest.mark.parametrize("pass_connection", [True, False]) +def test_write_database_sa_commit(tmp_path: str, pass_connection: bool) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + table_name = "test_sa_commit" + test_db_uri = f"sqlite:///{tmp_path}/test_sa_commit.db" + engine = create_engine(test_db_uri, poolclass=NullPool) + with Session(engine) as session: + if pass_connection: + conn = session.connection() + df.write_database(table_name, conn) + else: + df.write_database(table_name, session) + session.commit() + + with Session(engine) as session: + result = pl.read_database( + query=f"select * from {table_name}", connection=session + ) + + assert_frame_equal(result, df) + + +@pytest.mark.skipif(sys.platform == "win32", reason="adbc not available on Windows") +def test_write_database_adbc_temporary_table() -> None: + """Confirm that execution_options are passed along to create temporary tables.""" + df = pl.DataFrame({"colx": [1, 2, 3]}) + temp_tbl_name = "should_be_temptable" + expected_temp_table_create_sql = ( + """CREATE TABLE "should_be_temptable" ("colx" INTEGER)""" + ) + + # test with sqlite in memory + conn = _open_adbc_connection("sqlite:///:memory:") + assert ( + df.write_database( + temp_tbl_name, + connection=conn, + if_table_exists="fail", + engine_options={"temporary": True}, + ) + == 3 + ) + temp_tbl_sql_df = pl.read_database( + "select sql from sqlite_temp_master where type='table' and tbl_name = ?", + connection=conn, + execute_options={"parameters": [temp_tbl_name]}, + ) + assert temp_tbl_sql_df.shape[0] == 1, "no temp table created" + actual_temp_table_create_sql = temp_tbl_sql_df["sql"][0] + assert expected_temp_table_create_sql == actual_temp_table_create_sql + + if hasattr(conn, "close"): + conn.close() diff --git a/py-polars/tests/unit/io/files/delta-table/.part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet.crc b/py-polars/tests/unit/io/files/delta-table/.part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet.crc new file mode 100644 index 000000000000..adb35a378bbb Binary files /dev/null and b/py-polars/tests/unit/io/files/delta-table/.part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet.crc differ diff --git a/py-polars/tests/unit/io/files/delta-table/.part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet.crc b/py-polars/tests/unit/io/files/delta-table/.part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet.crc new file mode 100644 index 000000000000..42b582e1098d Binary files /dev/null and b/py-polars/tests/unit/io/files/delta-table/.part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet.crc differ diff --git a/py-polars/tests/unit/io/files/delta-table/_delta_log/.00000000000000000000.json.crc b/py-polars/tests/unit/io/files/delta-table/_delta_log/.00000000000000000000.json.crc new file mode 100644 index 000000000000..caef6fd161e1 Binary files /dev/null and b/py-polars/tests/unit/io/files/delta-table/_delta_log/.00000000000000000000.json.crc differ diff --git a/py-polars/tests/unit/io/files/delta-table/_delta_log/.00000000000000000001.json.crc b/py-polars/tests/unit/io/files/delta-table/_delta_log/.00000000000000000001.json.crc new file mode 100644 index 000000000000..c89884908750 Binary files /dev/null and b/py-polars/tests/unit/io/files/delta-table/_delta_log/.00000000000000000001.json.crc differ diff --git a/py-polars/tests/unit/io/files/delta-table/_delta_log/00000000000000000000.json b/py-polars/tests/unit/io/files/delta-table/_delta_log/00000000000000000000.json new file mode 100644 index 000000000000..bab81a810a63 --- /dev/null +++ b/py-polars/tests/unit/io/files/delta-table/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"metaData":{"id":"3b7eaff9-6f0c-4d58-914c-737aa94700eb","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"name\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"age\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1664141086433}} +{"add":{"path":"part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet","partitionValues":{},"size":690,"modificationTime":1664141088000,"dataChange":true}} +{"commitInfo":{"timestamp":1664141088939,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"2","numOutputBytes":"690"},"engineInfo":"Apache-Spark/3.2.0 Delta-Lake/1.1.0"}} diff --git a/py-polars/tests/unit/io/files/delta-table/_delta_log/00000000000000000001.json b/py-polars/tests/unit/io/files/delta-table/_delta_log/00000000000000000001.json new file mode 100644 index 000000000000..06679f9c54fc --- /dev/null +++ b/py-polars/tests/unit/io/files/delta-table/_delta_log/00000000000000000001.json @@ -0,0 +1,3 @@ +{"metaData":{"id":"3b7eaff9-6f0c-4d58-914c-737aa94700eb","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"name\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"age\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"test\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1664141086433}} +{"add":{"path":"part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet","partitionValues":{},"size":972,"modificationTime":1664141960000,"dataChange":true}} +{"commitInfo":{"timestamp":1664141960403,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"2","numOutputBytes":"972"},"engineInfo":"Apache-Spark/3.2.0 Delta-Lake/1.1.0"}} diff --git a/py-polars/tests/unit/io/files/delta-table/part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet b/py-polars/tests/unit/io/files/delta-table/part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet new file mode 100644 index 000000000000..6e73fab43bdf Binary files /dev/null and b/py-polars/tests/unit/io/files/delta-table/part-00000-e42312d7-60e5-454d-acbc-db192d220e73-c000.snappy.parquet differ diff --git a/py-polars/tests/unit/io/files/delta-table/part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet b/py-polars/tests/unit/io/files/delta-table/part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet new file mode 100644 index 000000000000..42290a7d2e10 Binary files /dev/null and b/py-polars/tests/unit/io/files/delta-table/part-00000-e4a999da-df45-4fb0-bdc4-d999fc0f58aa-c000.snappy.parquet differ diff --git a/py-polars/tests/unit/io/files/empty.csv b/py-polars/tests/unit/io/files/empty.csv new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/io/files/empty.ods b/py-polars/tests/unit/io/files/empty.ods new file mode 100644 index 000000000000..58ef4927bcfd Binary files /dev/null and b/py-polars/tests/unit/io/files/empty.ods differ diff --git a/py-polars/tests/unit/io/files/empty.xls b/py-polars/tests/unit/io/files/empty.xls new file mode 100644 index 000000000000..86a54ce8204d Binary files /dev/null and b/py-polars/tests/unit/io/files/empty.xls differ diff --git a/py-polars/tests/unit/io/files/empty.xlsb b/py-polars/tests/unit/io/files/empty.xlsb new file mode 100644 index 000000000000..ad8d776ed40b Binary files /dev/null and b/py-polars/tests/unit/io/files/empty.xlsb differ diff --git a/py-polars/tests/unit/io/files/empty.xlsx b/py-polars/tests/unit/io/files/empty.xlsx new file mode 100644 index 000000000000..0e800b7a082e Binary files /dev/null and b/py-polars/tests/unit/io/files/empty.xlsx differ diff --git a/py-polars/tests/unit/io/files/empty_datapage_v2.snappy.parquet b/py-polars/tests/unit/io/files/empty_datapage_v2.snappy.parquet new file mode 100644 index 000000000000..30d6fa7a687b Binary files /dev/null and b/py-polars/tests/unit/io/files/empty_datapage_v2.snappy.parquet differ diff --git a/py-polars/tests/unit/io/files/example.ods b/py-polars/tests/unit/io/files/example.ods new file mode 100644 index 000000000000..eff940590cd8 Binary files /dev/null and b/py-polars/tests/unit/io/files/example.ods differ diff --git a/py-polars/tests/unit/io/files/example.xls b/py-polars/tests/unit/io/files/example.xls new file mode 100644 index 000000000000..d07d32fa899c Binary files /dev/null and b/py-polars/tests/unit/io/files/example.xls differ diff --git a/py-polars/tests/unit/io/files/example.xlsb b/py-polars/tests/unit/io/files/example.xlsb new file mode 100644 index 000000000000..0185629399f0 Binary files /dev/null and b/py-polars/tests/unit/io/files/example.xlsb differ diff --git a/py-polars/tests/unit/io/files/example.xlsx b/py-polars/tests/unit/io/files/example.xlsx new file mode 100644 index 000000000000..4bbaa9ebe9e1 Binary files /dev/null and b/py-polars/tests/unit/io/files/example.xlsx differ diff --git a/py-polars/tests/unit/io/files/foods1.csv b/py-polars/tests/unit/io/files/foods1.csv new file mode 100644 index 000000000000..66f1ef9a1d5f --- /dev/null +++ b/py-polars/tests/unit/io/files/foods1.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +vegetables,45,0.5,2 +seafood,150,5,0 +meat,100,5,0 +fruit,60,0,11 +seafood,140,5,1 +meat,120,10,1 +vegetables,20,0,2 +fruit,30,0,5 +seafood,130,5,0 +fruit,50,4.5,0 +meat,110,7,0 +vegetables,25,0,2 +fruit,30,0,3 +vegetables,22,0,3 +vegetables,25,0,4 +seafood,100,5,0 +seafood,200,10,0 +seafood,200,7,2 +fruit,60,0,11 +meat,110,7,0 +vegetables,25,0,3 +seafood,200,7,2 +seafood,130,1.5,0 +fruit,130,0,25 +meat,100,7,0 +vegetables,30,0,5 +fruit,50,0,11 diff --git a/py-polars/tests/unit/io/files/foods1.ipc b/py-polars/tests/unit/io/files/foods1.ipc new file mode 100644 index 000000000000..d68513b43566 Binary files /dev/null and b/py-polars/tests/unit/io/files/foods1.ipc differ diff --git a/py-polars/tests/unit/io/files/foods1.json b/py-polars/tests/unit/io/files/foods1.json new file mode 100644 index 000000000000..327566874aa1 --- /dev/null +++ b/py-polars/tests/unit/io/files/foods1.json @@ -0,0 +1,27 @@ +[{"category":"vegetables","calories":45,"fats_g":0.5,"sugars_g":2}, +{"category":"seafood","calories":150,"fats_g":5.0,"sugars_g":0}, +{"category":"meat","calories":100,"fats_g":5.0,"sugars_g":0}, +{"category":"fruit","calories":60,"fats_g":0.0,"sugars_g":11}, +{"category":"seafood","calories":140,"fats_g":5.0,"sugars_g":1}, +{"category":"meat","calories":120,"fats_g":10.0,"sugars_g":1}, +{"category":"vegetables","calories":20,"fats_g":0.0,"sugars_g":2}, +{"category":"fruit","calories":30,"fats_g":0.0,"sugars_g":5}, +{"category":"seafood","calories":130,"fats_g":5.0,"sugars_g":0}, +{"category":"fruit","calories":50,"fats_g":4.5,"sugars_g":0}, +{"category":"meat","calories":110,"fats_g":7.0,"sugars_g":0}, +{"category":"vegetables","calories":25,"fats_g":0.0,"sugars_g":2}, +{"category":"fruit","calories":30,"fats_g":0.0,"sugars_g":3}, +{"category":"vegetables","calories":22,"fats_g":0.0,"sugars_g":3}, +{"category":"vegetables","calories":25,"fats_g":0.0,"sugars_g":4}, +{"category":"seafood","calories":100,"fats_g":5.0,"sugars_g":0}, +{"category":"seafood","calories":200,"fats_g":10.0,"sugars_g":0}, +{"category":"seafood","calories":200,"fats_g":7.0,"sugars_g":2}, +{"category":"fruit","calories":60,"fats_g":0.0,"sugars_g":11}, +{"category":"meat","calories":110,"fats_g":7.0,"sugars_g":0}, +{"category":"vegetables","calories":25,"fats_g":0.0,"sugars_g":3}, +{"category":"seafood","calories":200,"fats_g":7.0,"sugars_g":2}, +{"category":"seafood","calories":130,"fats_g":1.5,"sugars_g":0}, +{"category":"fruit","calories":130,"fats_g":0.0,"sugars_g":25}, +{"category":"meat","calories":100,"fats_g":7.0,"sugars_g":0}, +{"category":"vegetables","calories":30,"fats_g":0.0,"sugars_g":5}, +{"category":"fruit","calories":50,"fats_g":0.0,"sugars_g":11}] diff --git a/py-polars/tests/unit/io/files/foods1.ndjson b/py-polars/tests/unit/io/files/foods1.ndjson new file mode 100644 index 000000000000..7bcc87ed1d3d --- /dev/null +++ b/py-polars/tests/unit/io/files/foods1.ndjson @@ -0,0 +1,27 @@ +{"category":"vegetables","calories":45,"fats_g":0.5,"sugars_g":2} +{"category":"seafood","calories":150,"fats_g":5.0,"sugars_g":0} +{"category":"meat","calories":100,"fats_g":5.0,"sugars_g":0} +{"category":"fruit","calories":60,"fats_g":0.0,"sugars_g":11} +{"category":"seafood","calories":140,"fats_g":5.0,"sugars_g":1} +{"category":"meat","calories":120,"fats_g":10.0,"sugars_g":1} +{"category":"vegetables","calories":20,"fats_g":0.0,"sugars_g":2} +{"category":"fruit","calories":30,"fats_g":0.0,"sugars_g":5} +{"category":"seafood","calories":130,"fats_g":5.0,"sugars_g":0} +{"category":"fruit","calories":50,"fats_g":4.5,"sugars_g":0} +{"category":"meat","calories":110,"fats_g":7.0,"sugars_g":0} +{"category":"vegetables","calories":25,"fats_g":0.0,"sugars_g":2} +{"category":"fruit","calories":30,"fats_g":0.0,"sugars_g":3} +{"category":"vegetables","calories":22,"fats_g":0.0,"sugars_g":3} +{"category":"vegetables","calories":25,"fats_g":0.0,"sugars_g":4} +{"category":"seafood","calories":100,"fats_g":5.0,"sugars_g":0} +{"category":"seafood","calories":200,"fats_g":10.0,"sugars_g":0} +{"category":"seafood","calories":200,"fats_g":7.0,"sugars_g":2} +{"category":"fruit","calories":60,"fats_g":0.0,"sugars_g":11} +{"category":"meat","calories":110,"fats_g":7.0,"sugars_g":0} +{"category":"vegetables","calories":25,"fats_g":0.0,"sugars_g":3} +{"category":"seafood","calories":200,"fats_g":7.0,"sugars_g":2} +{"category":"seafood","calories":130,"fats_g":1.5,"sugars_g":0} +{"category":"fruit","calories":130,"fats_g":0.0,"sugars_g":25} +{"category":"meat","calories":100,"fats_g":7.0,"sugars_g":0} +{"category":"vegetables","calories":30,"fats_g":0.0,"sugars_g":5} +{"category":"fruit","calories":50,"fats_g":0.0,"sugars_g":11} diff --git a/py-polars/tests/unit/io/files/foods1.parquet b/py-polars/tests/unit/io/files/foods1.parquet new file mode 100644 index 000000000000..e2e837a81133 Binary files /dev/null and b/py-polars/tests/unit/io/files/foods1.parquet differ diff --git a/py-polars/tests/unit/io/files/foods2.csv b/py-polars/tests/unit/io/files/foods2.csv new file mode 100644 index 000000000000..2faf5d6d04fa --- /dev/null +++ b/py-polars/tests/unit/io/files/foods2.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +meat,101,8,0 +vegetables,44,0.3,3 +seafood,129,4,1 +fruit,49,4.4,0 +meat,95,4,0 +vegetables,24,0,1 +fruit,123,0,24 +fruit,59,0,9 +fruit,27,0,2 +vegetables,21,0,3 +fruit,61,0,12 +vegetables,26,0,4 +seafood,146,6,2 +vegetables,23,0,2 +fruit,51,1,13 +meat,111,8,2 +seafood,104,4,1 +meat,123,12,0 +seafood,201,8,4 +vegetables,35,1,7 +vegetables,21,1,3 +seafood,133,1.6,1 +meat,110,5,0 +seafood,192,6,0 +seafood,142,4,0 +fruit,31,0,4 +seafood,194,12,1 diff --git a/py-polars/tests/unit/io/files/foods2.ipc b/py-polars/tests/unit/io/files/foods2.ipc new file mode 100644 index 000000000000..e9c55ded5fb6 Binary files /dev/null and b/py-polars/tests/unit/io/files/foods2.ipc differ diff --git a/py-polars/tests/unit/io/files/foods2.ndjson b/py-polars/tests/unit/io/files/foods2.ndjson new file mode 100644 index 000000000000..b98bc2d91c4a --- /dev/null +++ b/py-polars/tests/unit/io/files/foods2.ndjson @@ -0,0 +1,27 @@ +{"category":"meat","calories":101,"fats_g":8.0,"sugars_g":0} +{"category":"vegetables","calories":44,"fats_g":0.3,"sugars_g":3} +{"category":"seafood","calories":129,"fats_g":4.0,"sugars_g":1} +{"category":"fruit","calories":49,"fats_g":4.4,"sugars_g":0} +{"category":"meat","calories":95,"fats_g":4.0,"sugars_g":0} +{"category":"vegetables","calories":24,"fats_g":0.0,"sugars_g":1} +{"category":"fruit","calories":123,"fats_g":0.0,"sugars_g":24} +{"category":"fruit","calories":59,"fats_g":0.0,"sugars_g":9} +{"category":"fruit","calories":27,"fats_g":0.0,"sugars_g":2} +{"category":"vegetables","calories":21,"fats_g":0.0,"sugars_g":3} +{"category":"fruit","calories":61,"fats_g":0.0,"sugars_g":12} +{"category":"vegetables","calories":26,"fats_g":0.0,"sugars_g":4} +{"category":"seafood","calories":146,"fats_g":6.0,"sugars_g":2} +{"category":"vegetables","calories":23,"fats_g":0.0,"sugars_g":2} +{"category":"fruit","calories":51,"fats_g":1.0,"sugars_g":13} +{"category":"meat","calories":111,"fats_g":8.0,"sugars_g":2} +{"category":"seafood","calories":104,"fats_g":4.0,"sugars_g":1} +{"category":"meat","calories":123,"fats_g":12.0,"sugars_g":0} +{"category":"seafood","calories":201,"fats_g":8.0,"sugars_g":4} +{"category":"vegetables","calories":35,"fats_g":1.0,"sugars_g":7} +{"category":"vegetables","calories":21,"fats_g":1.0,"sugars_g":3} +{"category":"seafood","calories":133,"fats_g":1.6,"sugars_g":1} +{"category":"meat","calories":110,"fats_g":5.0,"sugars_g":0} +{"category":"seafood","calories":192,"fats_g":6.0,"sugars_g":0} +{"category":"seafood","calories":142,"fats_g":4.0,"sugars_g":0} +{"category":"fruit","calories":31,"fats_g":0.0,"sugars_g":4} +{"category":"seafood","calories":194,"fats_g":12.0,"sugars_g":1} diff --git a/py-polars/tests/unit/io/files/foods2.parquet b/py-polars/tests/unit/io/files/foods2.parquet new file mode 100644 index 000000000000..99d36e724733 Binary files /dev/null and b/py-polars/tests/unit/io/files/foods2.parquet differ diff --git a/py-polars/tests/unit/io/files/foods3.csv b/py-polars/tests/unit/io/files/foods3.csv new file mode 100644 index 000000000000..80150ad15c26 --- /dev/null +++ b/py-polars/tests/unit/io/files/foods3.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +seafood,117,9,0 +seafood,201,6,1 +fruit,59,1,14 +meat,97,6,0 +meat,124,12,1 +meat,113,11,1 +vegetables,30,1,1 +seafood,191,6,1 +vegetables,35,0.4,0 +vegetables,21,0,2 +seafood,121,1.5,0 +seafood,125,5,1 +vegetables,21,0,3 +seafood,142,5,0 +meat,118,7,1 +fruit,61,0,12 +fruit,33,1,4 +vegetables,31,0,6 +meat,109,7,2 +vegetables,22,0,1 +fruit,31,0,2 +vegetables,22,0,2 +seafood,155,5,0 +fruit,133,0,27 +seafood,205,9,0 +fruit,72,4.5,7 +fruit,60,1,7 diff --git a/py-polars/tests/unit/io/files/foods4.csv b/py-polars/tests/unit/io/files/foods4.csv new file mode 100644 index 000000000000..c8e33560dd88 --- /dev/null +++ b/py-polars/tests/unit/io/files/foods4.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +vegetables,21,0,3 +seafood,158,5,4 +fruit,40,3,3 +vegetables,30,1,6 +meat,124,14,3 +fruit,63,3,11 +meat,114,5,0 +vegetables,32,0,5 +fruit,66,0,13 +seafood,200,14,0 +vegetables,26,0,4 +vegetables,25,0,5 +seafood,135,1.6,1 +seafood,144,5,2 +seafood,130,7,2 +vegetables,48,0.7,5 +fruit,136,2,28 +vegetables,32,0,8 +fruit,54,0,14 +seafood,105,5,3 +meat,103,5,3 +fruit,40,0,9 +seafood,205,7,4 +meat,100,5,0 +meat,110,5,0 +fruit,50,2.5,4 +seafood,213,7,5 diff --git a/py-polars/tests/unit/io/files/foods5.csv b/py-polars/tests/unit/io/files/foods5.csv new file mode 100644 index 000000000000..bc6efdfd9106 --- /dev/null +++ b/py-polars/tests/unit/io/files/foods5.csv @@ -0,0 +1,28 @@ +category,calories,fats_g,sugars_g +seafood,142,6,3 +meat,99,4,0 +fruit,127,0,23 +vegetables,23,0,1 +vegetables,30,0,2 +meat,88,5,1 +fruit,55,0,8 +fruit,27,0,4 +seafood,127,1.3,1 +meat,123,11,2 +seafood,124,4,2 +seafood,204,4,4 +fruit,52,0,10 +fruit,58,0,14 +vegetables,18,0,4 +seafood,102,6,0 +seafood,210,11,0 +vegetables,23,0,5 +vegetables,26,0,0 +seafood,145,4,0 +meat,95,5,0 +fruit,34,0,2 +meat,48,2,0 +vegetables,37,0.4,1 +fruit,56,4.2,0 +vegetables,34,0,3 +seafood,180,6,1 diff --git a/py-polars/tests/unit/io/files/graph-data/follows.csv b/py-polars/tests/unit/io/files/graph-data/follows.csv new file mode 100644 index 000000000000..5ec090c283cd --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/follows.csv @@ -0,0 +1,4 @@ +Adam,Karissa,2020 +Adam,Zhang,2020 +Karissa,Zhang,2021 +Zhang,Noura,2022 diff --git a/py-polars/tests/unit/io/files/graph-data/user.csv b/py-polars/tests/unit/io/files/graph-data/user.csv new file mode 100644 index 000000000000..0421e38ee559 --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/user.csv @@ -0,0 +1,4 @@ +Adam,30 +Karissa,40 +Zhang,50 +Noura,25 diff --git a/py-polars/tests/unit/io/files/gzipped.csv.gz b/py-polars/tests/unit/io/files/gzipped.csv.gz new file mode 100644 index 000000000000..5ea4d8785993 Binary files /dev/null and b/py-polars/tests/unit/io/files/gzipped.csv.gz differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet.crc b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet.crc new file mode 100644 index 000000000000..a9285317b405 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet.crc differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet new file mode 100644 index 000000000000..0bbb8ba707a5 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet.crc b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet.crc new file mode 100644 index 000000000000..258a5cd76cc1 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet.crc differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet new file mode 100644 index 000000000000..8da5caa43195 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/metadata/aef5d952-7e24-4764-9b30-3483be37240f-m0.avro b/py-polars/tests/unit/io/files/iceberg-table/metadata/aef5d952-7e24-4764-9b30-3483be37240f-m0.avro new file mode 100644 index 000000000000..a5e28a04bcff Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/metadata/aef5d952-7e24-4764-9b30-3483be37240f-m0.avro differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro b/py-polars/tests/unit/io/files/iceberg-table/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro new file mode 100644 index 000000000000..21f503266051 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/metadata/v2.metadata.json b/py-polars/tests/unit/io/files/iceberg-table/metadata/v2.metadata.json new file mode 100644 index 000000000000..04e078f8a63b --- /dev/null +++ b/py-polars/tests/unit/io/files/iceberg-table/metadata/v2.metadata.json @@ -0,0 +1,109 @@ +{ + "format-version" : 1, + "table-uuid" : "c470045a-5d75-48aa-9d4b-86a86a9b1fce", + "location" : "/tmp/iceberg/t1", + "last-updated-ms" : 1694547405299, + "last-column-id" : 3, + "schema" : { + "type" : "struct", + "schema-id" : 0, + "fields" : [ { + "id" : 1, + "name" : "id", + "required" : false, + "type" : "int" + }, { + "id" : 2, + "name" : "str", + "required" : false, + "type" : "string" + }, { + "id" : 3, + "name" : "ts", + "required" : false, + "type" : "timestamp" + } ] + }, + "current-schema-id" : 0, + "schemas" : [ { + "type" : "struct", + "schema-id" : 0, + "fields" : [ { + "id" : 1, + "name" : "id", + "required" : false, + "type" : "int" + }, { + "id" : 2, + "name" : "str", + "required" : false, + "type" : "string" + }, { + "id" : 3, + "name" : "ts", + "required" : false, + "type" : "timestamp" + } ] + } ], + "partition-spec" : [ { + "name" : "ts_day", + "transform" : "day", + "source-id" : 3, + "field-id" : 1000 + } ], + "default-spec-id" : 0, + "partition-specs" : [ { + "spec-id" : 0, + "fields" : [ { + "name" : "ts_day", + "transform" : "day", + "source-id" : 3, + "field-id" : 1000 + } ] + } ], + "last-partition-id" : 1000, + "default-sort-order-id" : 0, + "sort-orders" : [ { + "order-id" : 0, + "fields" : [ ] + } ], + "properties" : { + "owner" : "fokkodriesprong" + }, + "current-snapshot-id" : 7051579356916758811, + "refs" : { + "main" : { + "snapshot-id" : 7051579356916758811, + "type" : "branch" + } + }, + "snapshots" : [ { + "snapshot-id" : 7051579356916758811, + "timestamp-ms" : 1694547405299, + "summary" : { + "operation" : "append", + "spark.app.id" : "local-1694547283063", + "added-data-files" : "2", + "added-records" : "3", + "added-files-size" : "1788", + "changed-partition-count" : "2", + "total-records" : "3", + "total-files-size" : "1788", + "total-data-files" : "2", + "total-delete-files" : "0", + "total-position-deletes" : "0", + "total-equality-deletes" : "0" + }, + "manifest-list" : "/tmp/iceberg/t1/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro", + "schema-id" : 0 + } ], + "statistics" : [ ], + "snapshot-log" : [ { + "timestamp-ms" : 1694547405299, + "snapshot-id" : 7051579356916758811 + } ], + "metadata-log" : [ { + "timestamp-ms" : 1694547211303, + "metadata-file" : "/tmp/iceberg/t1/metadata/v1.metadata.json" + } ] +} \ No newline at end of file diff --git a/py-polars/tests/unit/io/files/mixed.ods b/py-polars/tests/unit/io/files/mixed.ods new file mode 100644 index 000000000000..87df7b8c6ecc Binary files /dev/null and b/py-polars/tests/unit/io/files/mixed.ods differ diff --git a/py-polars/tests/unit/io/files/mixed.xlsb b/py-polars/tests/unit/io/files/mixed.xlsb new file mode 100644 index 000000000000..2d5a738f2709 Binary files /dev/null and b/py-polars/tests/unit/io/files/mixed.xlsb differ diff --git a/py-polars/tests/unit/io/files/mixed.xlsx b/py-polars/tests/unit/io/files/mixed.xlsx new file mode 100644 index 000000000000..69879ddc3703 Binary files /dev/null and b/py-polars/tests/unit/io/files/mixed.xlsx differ diff --git a/py-polars/tests/unit/io/files/nan_test.xlsx b/py-polars/tests/unit/io/files/nan_test.xlsx new file mode 100644 index 000000000000..4b3edbcc6bdb Binary files /dev/null and b/py-polars/tests/unit/io/files/nan_test.xlsx differ diff --git a/py-polars/tests/unit/io/files/only_header.csv b/py-polars/tests/unit/io/files/only_header.csv new file mode 100644 index 000000000000..022dea2a48dc --- /dev/null +++ b/py-polars/tests/unit/io/files/only_header.csv @@ -0,0 +1 @@ +Name,Address \ No newline at end of file diff --git a/py-polars/tests/unit/io/files/small.csv b/py-polars/tests/unit/io/files/small.csv new file mode 100644 index 000000000000..98a471696f2b --- /dev/null +++ b/py-polars/tests/unit/io/files/small.csv @@ -0,0 +1,5 @@ +a,b,c +1,i,16200126 +2,j,16250130 +3,k,17220012 +4,l,17290009 diff --git a/py-polars/tests/unit/io/files/small.parquet b/py-polars/tests/unit/io/files/small.parquet new file mode 100644 index 000000000000..cd6f7d72b15a Binary files /dev/null and b/py-polars/tests/unit/io/files/small.parquet differ diff --git a/py-polars/tests/unit/io/files/test_empty_rows.xlsx b/py-polars/tests/unit/io/files/test_empty_rows.xlsx new file mode 100644 index 000000000000..1fc27ee4c3e4 Binary files /dev/null and b/py-polars/tests/unit/io/files/test_empty_rows.xlsx differ diff --git a/py-polars/tests/unit/io/files/tz_aware.parquet b/py-polars/tests/unit/io/files/tz_aware.parquet new file mode 100755 index 000000000000..8baad68da983 Binary files /dev/null and b/py-polars/tests/unit/io/files/tz_aware.parquet differ diff --git a/py-polars/tests/unit/io/files/zstd_compressed.csv.zst b/py-polars/tests/unit/io/files/zstd_compressed.csv.zst new file mode 100644 index 000000000000..ae05d77685f6 Binary files /dev/null and b/py-polars/tests/unit/io/files/zstd_compressed.csv.zst differ diff --git a/py-polars/tests/unit/io/test_avro.py b/py-polars/tests/unit/io/test_avro.py new file mode 100644 index 000000000000..bfd960a60228 --- /dev/null +++ b/py-polars/tests/unit/io/test_avro.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import io +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import AvroCompression + + +COMPRESSIONS = ["uncompressed", "snappy", "deflate"] + + +@pytest.fixture +def example_df() -> pl.DataFrame: + return pl.DataFrame({"i64": [1, 2], "f64": [0.1, 0.2], "str": ["a", "b"]}) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +def test_from_to_buffer(example_df: pl.DataFrame, compression: AvroCompression) -> None: + buf = io.BytesIO() + example_df.write_avro(buf, compression=compression) + buf.seek(0) + + read_df = pl.read_avro(buf) + assert_frame_equal(example_df, read_df) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("compression", COMPRESSIONS) +def test_from_to_file( + example_df: pl.DataFrame, compression: AvroCompression, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.avro" + example_df.write_avro(file_path, compression=compression) + df_read = pl.read_avro(file_path) + + assert_frame_equal(example_df, df_read) + + +def test_select_columns() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) + expected = pl.DataFrame({"b": [True, False, True], "c": ["a", "b", "c"]}) + + f = io.BytesIO() + df.write_avro(f) + f.seek(0) + + read_df = pl.read_avro(f, columns=["b", "c"]) + assert_frame_equal(expected, read_df) + + +def test_select_projection() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) + expected = pl.DataFrame({"b": [True, False, True], "c": ["a", "b", "c"]}) + + f = io.BytesIO() + df.write_avro(f) + f.seek(0) + + read_df = pl.read_avro(f, columns=[1, 2]) + assert_frame_equal(expected, read_df) + + +def test_with_name() -> None: + df = pl.DataFrame({"a": [1]}) + expected = pl.DataFrame( + { + "type": ["record"], + "name": ["my_schema_name"], + "fields": [[{"name": "a", "type": ["null", "long"]}]], + } + ) + + f = io.BytesIO() + df.write_avro(f, name="my_schema_name") + + f.seek(0) + raw = f.read() + + read_df = pl.read_json(raw[raw.find(b"{") : raw.rfind(b"}") + 1]) + + assert_frame_equal(expected, read_df) diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py new file mode 100644 index 000000000000..e35cf98754c9 --- /dev/null +++ b/py-polars/tests/unit/io/test_csv.py @@ -0,0 +1,2607 @@ +from __future__ import annotations + +import gzip +import io +import os +import sys +import textwrap +import zlib +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal as D +from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING, TypedDict + +import numpy as np +import pyarrow as pa +import pytest +import zstandard + +import polars as pl +from polars._utils.various import normalize_filepath +from polars.exceptions import ComputeError, InvalidOperationError, NoDataError +from polars.io.csv import BatchedCsvReader +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import TimeUnit + from tests.unit.conftest import MemoryUsage + + +@pytest.fixture +def foods_file_path(io_files_path: Path) -> Path: + return io_files_path / "foods1.csv" + + +def test_quoted_date() -> None: + csv = textwrap.dedent( + """\ + a,b + "2022-01-01",1 + "2022-01-02",2 + """ + ) + result = pl.read_csv(csv.encode(), try_parse_dates=True) + expected = pl.DataFrame({"a": [date(2022, 1, 1), date(2022, 1, 2)], "b": [1, 2]}) + assert_frame_equal(result, expected) + + +# Issue: https://github.com/pola-rs/polars/issues/10826 +def test_date_pattern_with_datetime_override_10826() -> None: + result = pl.read_csv( + source=io.StringIO("col\n2023-01-01\n2023-02-01\n2023-03-01"), + schema_overrides={"col": pl.Datetime}, + ) + expected = pl.Series( + "col", [datetime(2023, 1, 1), datetime(2023, 2, 1), datetime(2023, 3, 1)] + ).to_frame() + assert_frame_equal(result, expected) + + result = pl.read_csv( + source=io.StringIO("col\n2023-01-01T01:02:03\n2023-02-01\n2023-03-01"), + schema_overrides={"col": pl.Datetime}, + ) + expected = pl.Series( + "col", + [datetime(2023, 1, 1, 1, 2, 3), datetime(2023, 2, 1), datetime(2023, 3, 1)], + ).to_frame() + assert_frame_equal(result, expected) + + +def test_to_from_buffer(df_no_lists: pl.DataFrame) -> None: + df = df_no_lists + buf = io.BytesIO() + df.write_csv(buf) + buf.seek(0) + + read_df = pl.read_csv(buf, try_parse_dates=True) + read_df = read_df.with_columns( + pl.col("cat").cast(pl.Categorical), + pl.col("enum").cast(pl.Enum(["foo", "ham", "bar"])), + pl.col("time").cast(pl.Time), + ) + assert_frame_equal(df, read_df, categorical_as_str=True) + with pytest.raises(AssertionError): + assert_frame_equal(df.select("time", "cat"), read_df, categorical_as_str=True) + + +@pytest.mark.write_disk +def test_to_from_file(df_no_lists: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = df_no_lists.drop("strings_nulls") + + file_path = tmp_path / "small.csv" + df.write_csv(file_path) + read_df = pl.read_csv(file_path, try_parse_dates=True) + + read_df = read_df.with_columns( + pl.col("cat").cast(pl.Categorical), + pl.col("enum").cast(pl.Enum(["foo", "ham", "bar"])), + pl.col("time").cast(pl.Time), + ) + assert_frame_equal(df, read_df, categorical_as_str=True) + + +def test_normalize_filepath(io_files_path: Path) -> None: + with pytest.raises(IsADirectoryError): + normalize_filepath(io_files_path) + + assert normalize_filepath(str(io_files_path), check_not_directory=False) == str( + io_files_path + ) + + +def test_infer_schema_false() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 1,2,3 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, infer_schema=False) + assert df.dtypes == [pl.String, pl.String, pl.String] + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_csv_null_values() -> None: + csv = textwrap.dedent( + """\ + a,b,c + na,b,c + a,na,c + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, null_values="na") + assert df.rows() == [(None, "b", "c"), ("a", None, "c")] + + # note: after reading, the buffer position in StringIO will have been + # advanced; reading again will raise NoDataError, so we provide a hint + # in the error string about this, suggesting "seek(0)" as a possible fix... + with pytest.raises( + NoDataError, match=r"empty data .* position = 20; try seek\(0\)" + ): + pl.read_csv(f) + + # ... unless we explicitly tell read_csv not to raise an + # exception, in which case we expect an empty dataframe + assert_frame_equal(pl.read_csv(f, raise_if_empty=False), pl.DataFrame()) + + out = io.BytesIO() + df.write_csv(out, null_value="na") + assert csv == out.getvalue().decode("ascii") + + csv = textwrap.dedent( + """\ + a,b,c + na,b,c + a,n/a,c + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, null_values=["na", "n/a"]) + assert df.rows() == [(None, "b", "c"), ("a", None, "c")] + + csv = textwrap.dedent( + r""" + a,b,c + na,b,c + a,\N,c + ,b, + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, null_values={"a": "na", "b": r"\N"}) + assert df.rows() == [(None, "b", "c"), ("a", None, "c"), (None, "b", None)] + + +def test_csv_missing_utf8_is_empty_string() -> None: + # validate 'missing_utf8_is_empty_string' for missing fields that are... + # >> ...leading + # >> ...trailing (both EOL & EOF) + # >> ...in lines that have missing fields + # >> ...in cols containing no other strings + # >> ...interacting with other user-supplied null values + + csv = textwrap.dedent( + r""" + a,b,c + na,b,c + a,\N,c + ,b, + """ + ) + f = io.StringIO(csv) + df = pl.read_csv( + f, + null_values={"a": "na", "b": r"\N"}, + missing_utf8_is_empty_string=True, + ) + # ┌──────┬──────┬─────┐ + # │ a ┆ b ┆ c │ + # ╞══════╪══════╪═════╡ + # │ null ┆ b ┆ c │ + # │ a ┆ null ┆ c │ + # │ ┆ b ┆ │ + # └──────┴──────┴─────┘ + assert df.rows() == [(None, "b", "c"), ("a", None, "c"), ("", "b", "")] + + csv = textwrap.dedent( + r""" + a,b,c,d,e,f,g + na,,,,\N,, + a,\N,c,,,,g + ,,, + ,,,na,,, + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, null_values=["na", r"\N"]) + # ┌──────┬──────┬──────┬──────┬──────┬──────┬──────┐ + # │ a ┆ b ┆ c ┆ d ┆ e ┆ f ┆ g │ + # ╞══════╪══════╪══════╪══════╪══════╪══════╪══════╡ + # │ null ┆ null ┆ null ┆ null ┆ null ┆ null ┆ null │ + # │ a ┆ null ┆ c ┆ null ┆ null ┆ null ┆ g │ + # │ null ┆ null ┆ null ┆ null ┆ null ┆ null ┆ null │ + # │ null ┆ null ┆ null ┆ null ┆ null ┆ null ┆ null │ + # └──────┴──────┴──────┴──────┴──────┴──────┴──────┘ + assert df.rows() == [ + (None, None, None, None, None, None, None), + ("a", None, "c", None, None, None, "g"), + (None, None, None, None, None, None, None), + (None, None, None, None, None, None, None), + ] + + f.seek(0) + df = pl.read_csv( + f, + null_values=["na", r"\N"], + missing_utf8_is_empty_string=True, + ) + # ┌──────┬──────┬─────┬──────┬──────┬──────┬─────┐ + # │ a ┆ b ┆ c ┆ d ┆ e ┆ f ┆ g │ + # ╞══════╪══════╪═════╪══════╪══════╪══════╪═════╡ + # │ null ┆ ┆ ┆ ┆ null ┆ ┆ │ + # │ a ┆ null ┆ c ┆ ┆ ┆ ┆ g │ + # │ ┆ ┆ ┆ ┆ ┆ ┆ │ + # │ ┆ ┆ ┆ null ┆ ┆ ┆ │ + # └──────┴──────┴─────┴──────┴──────┴──────┴─────┘ + assert df.rows() == [ + (None, "", "", "", None, "", ""), + ("a", None, "c", "", "", "", "g"), + ("", "", "", "", "", "", ""), + ("", "", "", None, "", "", ""), + ] + + +def test_csv_int_types() -> None: + f = io.StringIO( + "u8,i8,u16,i16,u32,i32,u64,i64,i128\n" + "0,0,0,0,0,0,0,0,0\n" + "0,-128,0,-32768,0,-2147483648,0,-9223372036854775808,-170141183460469231731687303715884105728\n" + "255,127,65535,32767,4294967295,2147483647,18446744073709551615,9223372036854775807,170141183460469231731687303715884105727\n" + "01,01,01,01,01,01,01,01,01\n" + "01,-01,01,-01,01,-01,01,-01,01\n" + ) + df = pl.read_csv( + f, + schema={ + "u8": pl.UInt8, + "i8": pl.Int8, + "u16": pl.UInt16, + "i16": pl.Int16, + "u32": pl.UInt32, + "i32": pl.Int32, + "u64": pl.UInt64, + "i64": pl.Int64, + "i128": pl.Int128, + }, + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "u8": pl.Series([0, 0, 255, 1, 1], dtype=pl.UInt8), + "i8": pl.Series([0, -128, 127, 1, -1], dtype=pl.Int8), + "u16": pl.Series([0, 0, 65535, 1, 1], dtype=pl.UInt16), + "i16": pl.Series([0, -32768, 32767, 1, -1], dtype=pl.Int16), + "u32": pl.Series([0, 0, 4294967295, 1, 1], dtype=pl.UInt32), + "i32": pl.Series([0, -2147483648, 2147483647, 1, -1], dtype=pl.Int32), + "u64": pl.Series([0, 0, 18446744073709551615, 1, 1], dtype=pl.UInt64), + "i64": pl.Series( + [0, -9223372036854775808, 9223372036854775807, 1, -1], + dtype=pl.Int64, + ), + "i128": pl.Series( + [ + 0, + -170141183460469231731687303715884105728, + 170141183460469231731687303715884105727, + 1, + 1, + ], + dtype=pl.Int128, + ), + } + ), + ) + + +def test_csv_float_parsing() -> None: + lines_with_floats = [ + "123.86,+123.86,-123.86\n", + ".987,+.987,-.987\n", + "5.,+5.,-5.\n", + "inf,+inf,-inf\n", + "NaN,+NaN,-NaN\n", + ] + + for line_with_floats in lines_with_floats: + f = io.StringIO(line_with_floats) + df = pl.read_csv(f, has_header=False, new_columns=["a", "b", "c"]) + assert df.dtypes == [pl.Float64, pl.Float64, pl.Float64] + + lines_with_scientific_numbers = [ + "1e27,1E65,1e-28,1E-9\n", + "+1e27,+1E65,+1e-28,+1E-9\n", + "1e+27,1E+65,1e-28,1E-9\n", + "+1e+27,+1E+65,+1e-28,+1E-9\n", + "-1e+27,-1E+65,-1e-28,-1E-9\n", + # "e27,E65,e-28,E-9\n", + # "+e27,+E65,+e-28,+E-9\n", + # "-e27,-E65,-e-28,-E-9\n", + ] + + for line_with_scientific_numbers in lines_with_scientific_numbers: + f = io.StringIO(line_with_scientific_numbers) + df = pl.read_csv(f, has_header=False, new_columns=["a", "b", "c", "d"]) + assert df.dtypes == [pl.Float64, pl.Float64, pl.Float64, pl.Float64] + + +def test_datetime_parsing() -> None: + csv = textwrap.dedent( + """\ + timestamp,open,high + 2021-01-01 00:00:00,0.00305500,0.00306000 + 2021-01-01 00:15:00,0.00298800,0.00300400 + 2021-01-01 00:30:00,0.00298300,0.00300100 + 2021-01-01 00:45:00,0.00299400,0.00304000 + """ + ) + + f = io.StringIO(csv) + df = pl.read_csv(f, try_parse_dates=True) + assert df.dtypes == [pl.Datetime, pl.Float64, pl.Float64] + + +def test_datetime_parsing_default_formats() -> None: + csv = textwrap.dedent( + """\ + ts_dmy,ts_dmy_f,ts_dmy_p + 01/01/2021 00:00:00,31-01-2021T00:00:00.123,31-01-2021 11:00 + 01/01/2021 00:15:00,31-01-2021T00:15:00.123,31-01-2021 01:00 + 01/01/2021 00:30:00,31-01-2021T00:30:00.123,31-01-2021 01:15 + 01/01/2021 00:45:00,31-01-2021T00:45:00.123,31-01-2021 01:30 + """ + ) + + f = io.StringIO(csv) + df = pl.read_csv(f, try_parse_dates=True) + assert df.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_partial_schema_overrides() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 1,2,3 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, schema_overrides=[pl.String]) + assert df.dtypes == [pl.String, pl.Int64, pl.Int64] + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_schema_overrides_with_column_name_selection() -> None: + csv = textwrap.dedent( + """\ + a,b,c,d + 1,2,3,4 + 1,2,3,4 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, columns=["c", "b", "d"], schema_overrides=[pl.Int32, pl.String]) + assert df.dtypes == [pl.String, pl.Int32, pl.Int64] + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_schema_overrides_with_column_idx_selection() -> None: + csv = textwrap.dedent( + """\ + a,b,c,d + 1,2,3,4 + 1,2,3,4 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, columns=[2, 1, 3], schema_overrides=[pl.Int32, pl.String]) + # Columns without an explicit dtype set will get pl.String if dtypes is a list + # if the column selection is done with column indices instead of column names. + assert df.dtypes == [pl.String, pl.Int32, pl.String] + # Projections are sorted. + assert df.columns == ["b", "c", "d"] + + +def test_partial_column_rename() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 1,2,3 + """ + ) + f = io.StringIO(csv) + for use in [True, False]: + f.seek(0) + df = pl.read_csv(f, new_columns=["foo"], use_pyarrow=use) + assert df.columns == ["foo", "b", "c"] + + +@pytest.mark.parametrize( + ("col_input", "col_out"), + [([0, 1], ["a", "b"]), ([0, 2], ["a", "c"]), (["b"], ["b"])], +) +def test_read_csv_columns_argument( + col_input: list[int] | list[str], col_out: list[str] +) -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 1,2,3 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, columns=col_input) + assert df.shape[0] == 2 + assert df.columns == col_out + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_read_csv_buffer_ownership() -> None: + bts = b"\xf0\x9f\x98\x80,5.55,333\n\xf0\x9f\x98\x86,-5.0,666" + buf = io.BytesIO(bts) + df = pl.read_csv( + buf, + has_header=False, + new_columns=["emoji", "flt", "int"], + ) + # confirm that read_csv succeeded, and didn't close the input buffer (#2696) + assert df.shape == (2, 3) + assert df.rows() == [("😀", 5.55, 333), ("😆", -5.0, 666)] + assert not buf.closed + assert buf.read() == bts + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +@pytest.mark.write_disk +def test_read_csv_encoding(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + bts = ( + b"Value1,Value2,Value3,Value4,Region\n" + b"-30,7.5,2578,1,\xa5x\xa5_\n-32,7.97,3006,1,\xa5x\xa4\xa4\n" + b"-31,8,3242,2,\xb7s\xa6\xcb\n-33,7.97,3300,3,\xb0\xaa\xb6\xaf\n" + b"-20,7.91,3384,4,\xac\xfc\xb0\xea\n" + ) + + file_path = tmp_path / "encoding.csv" + file_path.write_bytes(bts) + + file_str = str(file_path) + bytesio = io.BytesIO(bts) + + for use_pyarrow in (False, True): + bytesio.seek(0) + for file in [file_path, file_str, bts, bytesio]: + assert_series_equal( + pl.read_csv( + file, # type: ignore[arg-type] + encoding="big5", + use_pyarrow=use_pyarrow, + ).get_column("Region"), + pl.Series("Region", ["台北", "台中", "新竹", "高雄", "美國"]), + ) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +@pytest.mark.write_disk +def test_read_csv_encoding_lossy(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + bts = ( + b"\xc8\xec\xff,\xc2\xee\xe7\xf0\xe0\xf1\xf2,\xc3\xee\xf0\xee\xe4\n" + b"\xc8\xe2\xe0\xed,25,\xcc\xee\xf1\xea\xe2\xe0\n" + # \x98 is not supported in "windows-1251". + b"\xce\xeb\xfc\xe3\xe0,30,\xd1\xe0\xed\xea\xf2-\x98\xcf\xe5\xf2\xe5\xf0\xe1\xf3\xf0\xe3\n" + ) + + file_path = tmp_path / "encoding_lossy.csv" + file_path.write_bytes(bts) + + file_str = str(file_path) + bytesio = io.BytesIO(bts) + bytesio.seek(0) + + for file in [file_path, file_str, bts, bytesio]: + assert_series_equal( + pl.read_csv( + file, # type: ignore[arg-type] + encoding="windows-1251-lossy", + use_pyarrow=False, + ).get_column("Город"), + pl.Series("Город", ["Москва", "Санкт-�Петербург"]), + ) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_column_rename_and_schema_overrides() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 1,2,3 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv( + f, + new_columns=["A", "B", "C"], + schema_overrides={"A": pl.String, "B": pl.Int64, "C": pl.Float32}, + ) + assert df.dtypes == [pl.String, pl.Int64, pl.Float32] + + f = io.StringIO(csv) + df = pl.read_csv( + f, + columns=["a", "c"], + new_columns=["A", "C"], + schema_overrides={"A": pl.String, "C": pl.Float32}, + ) + assert df.dtypes == [pl.String, pl.Float32] + + csv = textwrap.dedent( + """\ + 1,2,3 + 1,2,3 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv( + f, + new_columns=["A", "B", "C"], + schema_overrides={"A": pl.String, "C": pl.Float32}, + has_header=False, + ) + assert df.dtypes == [pl.String, pl.Int64, pl.Float32] + + +def test_compressed_csv(io_files_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "0") + + # gzip compression + csv = textwrap.dedent( + """\ + a,b,c + 1,a,1.0 + 2,b,2.0 + 3,c,3.0 + """ + ) + fout = io.BytesIO() + with gzip.GzipFile(fileobj=fout, mode="w") as f: + f.write(csv.encode()) + + csv_bytes = fout.getvalue() + out = pl.read_csv(csv_bytes) + expected = pl.DataFrame( + {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]} + ) + assert_frame_equal(out, expected) + + # now from disk + csv_file = io_files_path / "gzipped.csv.gz" + out = pl.read_csv(str(csv_file), truncate_ragged_lines=True) + assert_frame_equal(out, expected) + + # now with schema defined + schema = {"a": pl.Int64, "b": pl.Utf8, "c": pl.Float64} + out = pl.read_csv(str(csv_file), schema=schema, truncate_ragged_lines=True) + assert_frame_equal(out, expected) + + # now with column projection + out = pl.read_csv(csv_bytes, columns=["a", "b"]) + expected = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + assert_frame_equal(out, expected) + + # zlib compression + csv_bytes = zlib.compress(csv.encode()) + out = pl.read_csv(csv_bytes) + expected = pl.DataFrame( + {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]} + ) + assert_frame_equal(out, expected) + + # zstd compression + csv_bytes = zstandard.compress(csv.encode()) + out = pl.read_csv(csv_bytes) + assert_frame_equal(out, expected) + + # zstd compressed file + csv_file = io_files_path / "zstd_compressed.csv.zst" + out = pl.scan_csv(csv_file, truncate_ragged_lines=True).collect() + assert_frame_equal(out, expected) + out = pl.read_csv(str(csv_file), truncate_ragged_lines=True) + assert_frame_equal(out, expected) + + # no compression + f2 = io.BytesIO(b"a,b\n1,2\n") + out2 = pl.read_csv(f2) + expected = pl.DataFrame({"a": [1], "b": [2]}) + assert_frame_equal(out2, expected) + + +def test_partial_decompression(foods_file_path: Path) -> None: + f_out = io.BytesIO() + with gzip.GzipFile(fileobj=f_out, mode="w") as f: + f.write(foods_file_path.read_bytes()) + + csv_bytes = f_out.getvalue() + for n_rows in [1, 5, 26]: + out = pl.read_csv(csv_bytes, n_rows=n_rows) + assert out.shape == (n_rows, 4) + + # zstd compression + csv_bytes = zstandard.compress(foods_file_path.read_bytes()) + for n_rows in [1, 5, 26]: + out = pl.read_csv(csv_bytes, n_rows=n_rows) + assert out.shape == (n_rows, 4) + + +def test_empty_bytes() -> None: + b = b"" + with pytest.raises(NoDataError): + pl.read_csv(b) + + df = pl.read_csv(b, raise_if_empty=False) + assert_frame_equal(df, pl.DataFrame()) + + +def test_empty_line_with_single_column() -> None: + df = pl.read_csv( + b"a\n\nb\n", + new_columns=["A"], + has_header=False, + comment_prefix="#", + use_pyarrow=False, + ) + expected = pl.DataFrame({"A": ["a", None, "b"]}) + assert_frame_equal(df, expected) + + +def test_empty_line_with_multiple_columns() -> None: + df = pl.read_csv( + b"a,b\n\nc,d\n", + new_columns=["A", "B"], + has_header=False, + comment_prefix="#", + use_pyarrow=False, + ) + expected = pl.DataFrame({"A": ["a", None, "c"], "B": ["b", None, "d"]}) + assert_frame_equal(df, expected) + + +def test_preserve_whitespace_at_line_start() -> None: + df = pl.read_csv( + b" a\n b \n c\nd", + new_columns=["A"], + has_header=False, + use_pyarrow=False, + ) + expected = pl.DataFrame({"A": [" a", " b ", " c", "d"]}) + assert_frame_equal(df, expected) + + +def test_csv_multi_char_comment() -> None: + csv = textwrap.dedent( + """\ + #a,b + ##c,d + """ + ) + f = io.StringIO(csv) + df = pl.read_csv( + f, + new_columns=["A", "B"], + has_header=False, + comment_prefix="##", + use_pyarrow=False, + ) + expected = pl.DataFrame({"A": ["#a"], "B": ["b"]}) + assert_frame_equal(df, expected) + + # check comment interaction with headers/skip_rows + for skip_rows, b in ( + (1, io.BytesIO(b"\n#!skip\n#!skip\nCol1\tCol2\n")), + (0, io.BytesIO(b"\n#!skip\n#!skip\nCol1\tCol2")), + (0, io.BytesIO(b"#!skip\nCol1\tCol2\n#!skip\n")), + (0, io.BytesIO(b"#!skip\nCol1\tCol2")), + ): + df = pl.read_csv(b, separator="\t", comment_prefix="#!", skip_rows=skip_rows) + assert_frame_equal(df, pl.DataFrame(schema=["Col1", "Col2"]).cast(pl.Utf8)) + + +def test_csv_quote_char() -> None: + expected = pl.DataFrame( + [ + pl.Series("linenum", [1, 2, 3, 4, 5, 6, 7, 8, 9]), + pl.Series( + "last_name", + [ + "Jagger", + 'O"Brian', + "Richards", + 'L"Etoile', + "Watts", + "Smith", + '"Wyman"', + "Woods", + 'J"o"ne"s', + ], + ), + pl.Series( + "first_name", + [ + "Mick", + '"Mary"', + "Keith", + "Bennet", + "Charlie", + 'D"Shawn', + "Bill", + "Ron", + "Brian", + ], + ), + ] + ) + rolling_stones = textwrap.dedent( + """\ + linenum,last_name,first_name + 1,Jagger,Mick + 2,O"Brian,"Mary" + 3,Richards,Keith + 4,L"Etoile,Bennet + 5,Watts,Charlie + 6,Smith,D"Shawn + 7,"Wyman",Bill + 8,Woods,Ron + 9,J"o"ne"s,Brian + """ + ) + for use_pyarrow in (False, True): + out = pl.read_csv( + rolling_stones.encode(), quote_char=None, use_pyarrow=use_pyarrow + ) + assert out.shape == (9, 3) + assert_frame_equal(out, expected) + + # non-standard quote char + df = pl.DataFrame({"x": ["", "0*0", "xyz"]}) + csv_data = df.write_csv(quote_char="*") + + assert csv_data == "x\n**\n*0**0*\nxyz\n" + assert_frame_equal(df, pl.read_csv(io.StringIO(csv_data), quote_char="*")) + + +def test_csv_empty_quotes_char_1622() -> None: + pl.read_csv(b"a,b,c,d\nA1,B1,C1,1\nA2,B2,C2,2\n", quote_char="") + + +def test_ignore_try_parse_dates() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,i,16200126 + 2,j,16250130 + 3,k,17220012 + 4,l,17290009 + """ + ).encode() + + headers = ["a", "b", "c"] + dtypes: dict[str, type[pl.DataType]] = dict.fromkeys( + headers, pl.String + ) # Forces String type for every column + df = pl.read_csv(csv, columns=headers, schema_overrides=dtypes) + assert df.dtypes == [pl.String, pl.String, pl.String] + + +def test_csv_date_handling() -> None: + csv = textwrap.dedent( + """\ + date + 1745-04-02 + 1742-03-21 + 1743-06-16 + 1730-07-22 + + 1739-03-16 + """ + ) + expected = pl.DataFrame( + { + "date": [ + date(1745, 4, 2), + date(1742, 3, 21), + date(1743, 6, 16), + date(1730, 7, 22), + None, + date(1739, 3, 16), + ] + } + ) + out = pl.read_csv(csv.encode(), try_parse_dates=True) + assert_frame_equal(out, expected) + dtypes = {"date": pl.Date} + out = pl.read_csv(csv.encode(), schema_overrides=dtypes) + assert_frame_equal(out, expected) + + +def test_csv_no_date_dtype_because_string() -> None: + csv = textwrap.dedent( + """\ + date + 2024-01-01 + 2024-01-02 + hello + """ + ) + out = pl.read_csv(csv.encode(), try_parse_dates=True) + assert out.dtypes == [pl.String] + + +def test_csv_infer_date_dtype() -> None: + csv = textwrap.dedent( + """\ + date + 2024-01-01 + "2024-01-02" + + 2024-01-04 + """ + ) + out = pl.read_csv(csv.encode(), try_parse_dates=True) + expected = pl.DataFrame( + { + "date": [ + date(2024, 1, 1), + date(2024, 1, 2), + None, + date(2024, 1, 4), + ] + } + ) + assert_frame_equal(out, expected) + + +def test_csv_date_dtype_ignore_errors() -> None: + csv = textwrap.dedent( + """\ + date + hello + 2024-01-02 + world + !! + """ + ) + out = pl.read_csv( + csv.encode(), ignore_errors=True, schema_overrides={"date": pl.Date} + ) + expected = pl.DataFrame( + { + "date": [ + None, + date(2024, 1, 2), + None, + None, + ] + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_csv_globbing(io_files_path: Path) -> None: + path = io_files_path / "foods*.csv" + df = pl.read_csv(path) + assert df.shape == (135, 4) + + with pytest.MonkeyPatch.context() as mp: + mp.setenv("POLARS_FORCE_ASYNC", "0") + + with pytest.raises(ValueError): + _ = pl.read_csv(path, columns=[0, 1]) + + df = pl.read_csv(path, columns=["category", "sugars_g"]) + assert df.shape == (135, 2) + assert df.row(-1) == ("seafood", 1) + assert df.row(0) == ("vegetables", 2) + + with pytest.MonkeyPatch.context() as mp: + mp.setenv("POLARS_FORCE_ASYNC", "0") + + with pytest.raises(ValueError): + _ = pl.read_csv( + path, schema_overrides=[pl.String, pl.Int64, pl.Int64, pl.Int64] + ) + + dtypes = { + "category": pl.String, + "calories": pl.Int32, + "fats_g": pl.Float32, + "sugars_g": pl.Int32, + } + + df = pl.read_csv(path, schema_overrides=dtypes) + assert df.dtypes == list(dtypes.values()) + + +def test_csv_schema_offset(foods_file_path: Path) -> None: + csv = textwrap.dedent( + """\ + metadata + line + col1,col2,col3 + alpha,beta,gamma + 1,2.0,"A" + 3,4.0,"B" + 5,6.0,"C" + """ + ).encode() + + df = pl.read_csv(csv, skip_rows=3) + assert df.columns == ["alpha", "beta", "gamma"] + assert df.shape == (3, 3) + assert df.dtypes == [pl.Int64, pl.Float64, pl.String] + + df = pl.read_csv(csv, skip_rows=2, skip_rows_after_header=1) + assert df.columns == ["col1", "col2", "col3"] + assert df.shape == (3, 3) + assert df.dtypes == [pl.Int64, pl.Float64, pl.String] + + df = pl.scan_csv(foods_file_path, skip_rows=4).collect() + assert df.columns == ["fruit", "60", "0", "11"] + assert df.shape == (23, 4) + assert df.dtypes == [pl.String, pl.Int64, pl.Float64, pl.Int64] + + df = pl.scan_csv(foods_file_path, skip_rows_after_header=24).collect() + assert df.columns == ["category", "calories", "fats_g", "sugars_g"] + assert df.shape == (3, 4) + assert df.dtypes == [pl.String, pl.Int64, pl.Int64, pl.Int64] + + df = pl.scan_csv( + foods_file_path, skip_rows_after_header=24, infer_schema_length=1 + ).collect() + assert df.columns == ["category", "calories", "fats_g", "sugars_g"] + assert df.shape == (3, 4) + assert df.dtypes == [pl.String, pl.Int64, pl.Int64, pl.Int64] + + +def test_empty_string_missing_round_trip() -> None: + df = pl.DataFrame({"varA": ["A", "", None], "varB": ["B", "", None]}) + for null in (None, "NA", "NULL", r"\N"): + f = io.BytesIO() + df.write_csv(f, null_value=null) + f.seek(0) + df_read = pl.read_csv(f, null_values=null) + assert_frame_equal(df, df_read) + + +def test_write_csv_separator() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + f = io.BytesIO() + df.write_csv(f, separator="\t") + f.seek(0) + assert f.read() == b"a\tb\n1\t1\n2\t2\n3\t3\n" + f.seek(0) + assert_frame_equal(df, pl.read_csv(f, separator="\t")) + + +def test_write_csv_line_terminator() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + f = io.BytesIO() + df.write_csv(f, line_terminator="\r\n") + f.seek(0) + assert f.read() == b"a,b\r\n1,1\r\n2,2\r\n3,3\r\n" + f.seek(0) + assert_frame_equal(df, pl.read_csv(f, eol_char="\n")) + + +def test_escaped_null_values() -> None: + csv = textwrap.dedent( + """\ + "a","b","c" + "a","n/a","NA" + "None","2","3.0" + """ + ) + f = io.StringIO(csv) + df = pl.read_csv( + f, + null_values={"a": "None", "b": "n/a", "c": "NA"}, + schema_overrides={"a": pl.String, "b": pl.Int64, "c": pl.Float64}, + ) + assert df[1, "a"] is None + assert df[0, "b"] is None + assert df[0, "c"] is None + + +def test_quoting_round_trip() -> None: + f = io.BytesIO() + df = pl.DataFrame( + { + "a": [ + "tab,separated,field", + "newline\nseparated\nfield", + 'quote"separated"field', + ] + } + ) + df.write_csv(f) + f.seek(0) + read_df = pl.read_csv(f) + assert_frame_equal(read_df, df) + + +def test_csv_field_schema_inference_with_whitespace() -> None: + csv = """\ +bool,bool-,-bool,float,float-,-float,int,int-,-int +true,true , true,1.2,1.2 , 1.2,1,1 , 1 +""" + df = pl.read_csv(io.StringIO(csv), has_header=True) + expected = pl.DataFrame( + { + "bool": [True], + "bool-": ["true "], + "-bool": [" true"], + "float": [1.2], + "float-": ["1.2 "], + "-float": [" 1.2"], + "int": [1], + "int-": ["1 "], + "-int": [" 1"], + } + ) + assert_frame_equal(df, expected) + + +def test_fallback_chrono_parser() -> None: + data = textwrap.dedent( + """\ + date_1,date_2 + 2021-01-01,2021-1-1 + 2021-02-02,2021-2-2 + 2021-10-10,2021-10-10 + """ + ) + df = pl.read_csv(data.encode(), try_parse_dates=True) + assert df.null_count().row(0) == (0, 0) + + +def test_tz_aware_try_parse_dates() -> None: + data = ( + "a,b,c,d\n" + "2020-01-01T02:00:00+01:00,2021-04-28T00:00:00+02:00,2021-03-28T00:00:00+01:00,2\n" + "2020-01-01T03:00:00+01:00,2021-04-29T00:00:00+02:00,2021-03-29T00:00:00+02:00,3\n" + ) + result = pl.read_csv(io.StringIO(data), try_parse_dates=True) + expected = pl.DataFrame( + { + "a": [ + datetime(2020, 1, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 1, 2, tzinfo=timezone.utc), + ], + "b": [ + datetime(2021, 4, 27, 22, tzinfo=timezone.utc), + datetime(2021, 4, 28, 22, tzinfo=timezone.utc), + ], + "c": [ + datetime(2021, 3, 27, 23, tzinfo=timezone.utc), + datetime(2021, 3, 28, 22, tzinfo=timezone.utc), + ], + "d": [2, 3], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("try_parse_dates", [True, False]) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_csv_overwrite_datetime_dtype( + try_parse_dates: bool, time_unit: TimeUnit +) -> None: + data = """\ +a +2020-1-1T00:00:00.123456789 +2020-1-2T00:00:00.987654321 +2020-1-3T00:00:00.132547698 +""" + result = pl.read_csv( + io.StringIO(data), + try_parse_dates=try_parse_dates, + schema_overrides={"a": pl.Datetime(time_unit)}, + ) + expected = pl.DataFrame( + { + "a": pl.Series( + [ + "2020-01-01T00:00:00.123456789", + "2020-01-02T00:00:00.987654321", + "2020-01-03T00:00:00.132547698", + ] + ).str.to_datetime(time_unit=time_unit) + } + ) + assert_frame_equal(result, expected) + + +def test_csv_string_escaping() -> None: + df = pl.DataFrame({"a": ["Free trip to A,B", '''Special rate "1.79"''']}) + f = io.BytesIO() + df.write_csv(f) + f.seek(0) + df_read = pl.read_csv(f) + assert_frame_equal(df_read, df) + + +@pytest.mark.write_disk +def test_glob_csv(df_no_lists: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = df_no_lists.drop("strings_nulls") + file_path = tmp_path / "small.csv" + df.write_csv(file_path) + + path_glob = tmp_path / "small*.csv" + assert pl.scan_csv(path_glob).collect().shape == (3, 12) + assert pl.read_csv(path_glob).shape == (3, 12) + + +def test_csv_whitespace_separator_at_start_do_not_skip() -> None: + csv = "\t\t\t\t0\t1" + result = pl.read_csv(csv.encode(), separator="\t", has_header=False) + expected = { + "column_1": [None], + "column_2": [None], + "column_3": [None], + "column_4": [None], + "column_5": [0], + "column_6": [1], + } + assert result.to_dict(as_series=False) == expected + + +def test_csv_whitespace_separator_at_end_do_not_skip() -> None: + csv = "0\t1\t\t\t\t" + result = pl.read_csv(csv.encode(), separator="\t", has_header=False) + expected = { + "column_1": [0], + "column_2": [1], + "column_3": [None], + "column_4": [None], + "column_5": [None], + "column_6": [None], + } + assert result.to_dict(as_series=False) == expected + + +def test_csv_multiple_null_values() -> None: + df = pl.DataFrame( + { + "a": [1, 2, None, 4], + "b": ["2022-01-01", "__NA__", "", "NA"], + } + ) + f = io.BytesIO() + df.write_csv(f) + f.seek(0) + + df2 = pl.read_csv(f, null_values=["__NA__", "NA"]) + expected = pl.DataFrame( + { + "a": [1, 2, None, 4], + "b": ["2022-01-01", None, "", None], + } + ) + assert_frame_equal(df2, expected) + + +def test_different_eol_char() -> None: + csv = "a,1,10;b,2,20;c,3,30" + expected = pl.DataFrame( + {"column_1": ["a", "b", "c"], "column_2": [1, 2, 3], "column_3": [10, 20, 30]} + ) + assert_frame_equal( + pl.read_csv(csv.encode(), eol_char=";", has_header=False), expected + ) + + +def test_csv_write_escape_headers() -> None: + df0 = pl.DataFrame({"col,1": ["data,1"], 'col"2': ['data"2'], "col:3": ["data:3"]}) + out = io.BytesIO() + df0.write_csv(out) + assert out.getvalue() == b'"col,1","col""2",col:3\n"data,1","data""2",data:3\n' + + df1 = pl.DataFrame({"c,o,l,u,m,n": [123]}) + out = io.BytesIO() + df1.write_csv(out) + + out.seek(0) + df2 = pl.read_csv(out) + assert_frame_equal(df1, df2) + assert df2.schema == {"c,o,l,u,m,n": pl.Int64} + + +def test_csv_write_escape_newlines() -> None: + df = pl.DataFrame({"escape": ["n\nn"]}) + f = io.BytesIO() + df.write_csv(f) + f.seek(0) + read_df = pl.read_csv(f) + assert_frame_equal(df, read_df) + + +def test_skip_new_line_embedded_lines() -> None: + csv = r"""a,b,c,d,e\n +1,2,3,"\n Test",\n +4,5,6,"Test A",\n +7,8,,"Test B \n",\n""" + + for empty_string, missing_value in ((True, ""), (False, None)): + df = pl.read_csv( + csv.encode(), + skip_rows_after_header=1, + infer_schema_length=0, + missing_utf8_is_empty_string=empty_string, + ) + assert df.to_dict(as_series=False) == { + "a": ["4", "7"], + "b": ["5", "8"], + "c": ["6", missing_value], + "d": ["Test A", "Test B \\n"], + "e\\n": ["\\n", "\\n"], + } + + +def test_csv_schema_overrides_bool() -> None: + csv = "a, b\n" + ",false\n" + ",false\n" + ",false" + df = pl.read_csv( + csv.encode(), + schema_overrides={"a": pl.Boolean, "b": pl.Boolean}, + ) + assert df.dtypes == [pl.Boolean, pl.Boolean] + + +@pytest.mark.parametrize( + ("fmt", "expected"), + [ + (None, "dt\n2022-01-02T00:00:00.000000\n"), + ("%F %T%.3f", "dt\n2022-01-02 00:00:00.000\n"), + ("%Y", "dt\n2022\n"), + ("%m", "dt\n01\n"), + ("%m$%d", "dt\n01$02\n"), + ("%R", "dt\n00:00\n"), + ], +) +def test_datetime_format(fmt: str, expected: str) -> None: + df = pl.DataFrame({"dt": [datetime(2022, 1, 2)]}) + csv = df.write_csv(datetime_format=fmt) + assert csv == expected + + +@pytest.mark.parametrize( + ("fmt", "expected"), + [ + (None, "dt\n2022-01-02T00:00:00.000000+0000\n"), + ("%F %T%.3f%z", "dt\n2022-01-02 00:00:00.000+0000\n"), + ("%Y%z", "dt\n2022+0000\n"), + ("%m%z", "dt\n01+0000\n"), + ("%m$%d%z", "dt\n01$02+0000\n"), + ("%R%z", "dt\n00:00+0000\n"), + ], +) +@pytest.mark.parametrize("tzinfo", [timezone.utc, timezone(timedelta(hours=0))]) +def test_datetime_format_tz_aware(fmt: str, expected: str, tzinfo: timezone) -> None: + df = pl.DataFrame({"dt": [datetime(2022, 1, 2, tzinfo=tzinfo)]}) + csv = df.write_csv(datetime_format=fmt) + assert csv == expected + + +@pytest.mark.parametrize( + ("tu1", "tu2", "expected"), + [ + ( + "ns", + "ns", + "x,y\n2022-09-04T10:30:45.123000000,2022-09-04T10:30:45.123000000\n", + ), + ( + "ns", + "us", + "x,y\n2022-09-04T10:30:45.123000000,2022-09-04T10:30:45.123000\n", + ), + ( + "ns", + "ms", + "x,y\n2022-09-04T10:30:45.123000000,2022-09-04T10:30:45.123\n", + ), + ("us", "us", "x,y\n2022-09-04T10:30:45.123000,2022-09-04T10:30:45.123000\n"), + ("us", "ms", "x,y\n2022-09-04T10:30:45.123000,2022-09-04T10:30:45.123\n"), + ("ms", "us", "x,y\n2022-09-04T10:30:45.123,2022-09-04T10:30:45.123000\n"), + ("ms", "ms", "x,y\n2022-09-04T10:30:45.123,2022-09-04T10:30:45.123\n"), + ], +) +def test_datetime_format_inferred_precision( + tu1: TimeUnit, tu2: TimeUnit, expected: str +) -> None: + df = pl.DataFrame( + data={ + "x": [datetime(2022, 9, 4, 10, 30, 45, 123000)], + "y": [datetime(2022, 9, 4, 10, 30, 45, 123000)], + }, + schema=[ + ("x", pl.Datetime(tu1)), + ("y", pl.Datetime(tu2)), + ], + ) + assert expected == df.write_csv() + + +def test_inferred_datetime_format_mixed() -> None: + ts = pl.datetime_range(datetime(2000, 1, 1), datetime(2000, 1, 2), eager=True) + df = pl.DataFrame({"naive": ts, "aware": ts.dt.replace_time_zone("UTC")}) + result = df.write_csv() + expected = ( + "naive,aware\n" + "2000-01-01T00:00:00.000000,2000-01-01T00:00:00.000000+0000\n" + "2000-01-02T00:00:00.000000,2000-01-02T00:00:00.000000+0000\n" + ) + assert result == expected + + +@pytest.mark.parametrize( + ("fmt", "expected"), + [ + (None, "dt\n2022-01-02\n"), + ("%Y", "dt\n2022\n"), + ("%m", "dt\n01\n"), + ("%m$%d", "dt\n01$02\n"), + ], +) +def test_date_format(fmt: str, expected: str) -> None: + df = pl.DataFrame({"dt": [date(2022, 1, 2)]}) + csv = df.write_csv(date_format=fmt) + assert csv == expected + + +@pytest.mark.parametrize( + ("fmt", "expected"), + [ + (None, "dt\n16:15:30.000000000\n"), + ("%R", "dt\n16:15\n"), + ], +) +def test_time_format(fmt: str, expected: str) -> None: + df = pl.DataFrame({"dt": [time(16, 15, 30)]}) + csv = df.write_csv(time_format=fmt) + assert csv == expected + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) +def test_float_precision(dtype: pl.Float32 | pl.Float64) -> None: + df = pl.Series("col", [1.0, 2.2, 3.33], dtype=dtype).to_frame() + + assert df.write_csv(float_precision=None) == "col\n1.0\n2.2\n3.33\n" + assert df.write_csv(float_precision=0) == "col\n1\n2\n3\n" + assert df.write_csv(float_precision=1) == "col\n1.0\n2.2\n3.3\n" + assert df.write_csv(float_precision=2) == "col\n1.00\n2.20\n3.33\n" + assert df.write_csv(float_precision=3) == "col\n1.000\n2.200\n3.330\n" + + +def test_float_scientific() -> None: + df = ( + pl.Series( + "colf64", + [3.141592653589793 * mult for mult in (1e-8, 1e-3, 1e3, 1e17)], + dtype=pl.Float64, + ) + .to_frame() + .with_columns(pl.col("colf64").cast(pl.Float32).alias("colf32")) + ) + + assert ( + df.write_csv(float_precision=None, float_scientific=False) + == "colf64,colf32\n0.00000003141592653589793,0.00000003141592586075603\n0.0031415926535897933,0.0031415927223861217\n3141.592653589793,3141.5927734375\n314159265358979300,314159265516355600\n" + ) + assert ( + df.write_csv(float_precision=0, float_scientific=False) + == "colf64,colf32\n0,0\n0,0\n3142,3142\n314159265358979328,314159265516355584\n" + ) + assert ( + df.write_csv(float_precision=1, float_scientific=False) + == "colf64,colf32\n0.0,0.0\n0.0,0.0\n3141.6,3141.6\n314159265358979328.0,314159265516355584.0\n" + ) + assert ( + df.write_csv(float_precision=3, float_scientific=False) + == "colf64,colf32\n0.000,0.000\n0.003,0.003\n3141.593,3141.593\n314159265358979328.000,314159265516355584.000\n" + ) + + assert ( + df.write_csv(float_precision=None, float_scientific=True) + == "colf64,colf32\n3.141592653589793e-8,3.1415926e-8\n3.1415926535897933e-3,3.1415927e-3\n3.141592653589793e3,3.1415928e3\n3.141592653589793e17,3.1415927e17\n" + ) + assert ( + df.write_csv(float_precision=0, float_scientific=True) + == "colf64,colf32\n3e-8,3e-8\n3e-3,3e-3\n3e3,3e3\n3e17,3e17\n" + ) + assert ( + df.write_csv(float_precision=1, float_scientific=True) + == "colf64,colf32\n3.1e-8,3.1e-8\n3.1e-3,3.1e-3\n3.1e3,3.1e3\n3.1e17,3.1e17\n" + ) + assert ( + df.write_csv(float_precision=3, float_scientific=True) + == "colf64,colf32\n3.142e-8,3.142e-8\n3.142e-3,3.142e-3\n3.142e3,3.142e3\n3.142e17,3.142e17\n" + ) + + +def test_skip_rows_different_field_len() -> None: + csv = io.StringIO( + textwrap.dedent( + """\ + a,b + 1,A + 2, + 3,B + 4, + """ + ) + ) + for empty_string, missing_value in ((True, ""), (False, None)): + csv.seek(0) + assert pl.read_csv( + csv, skip_rows_after_header=2, missing_utf8_is_empty_string=empty_string + ).to_dict(as_series=False) == { + "a": [3, 4], + "b": ["B", missing_value], + } + + +def test_duplicated_columns() -> None: + csv = textwrap.dedent( + """a,a + 1,2 + """ + ) + assert pl.read_csv(csv.encode()).columns == ["a", "a_duplicated_0"] + new = ["c", "d"] + assert pl.read_csv(csv.encode(), new_columns=new).columns == new + + +def test_error_message() -> None: + data = io.StringIO("target,wind,energy,miso\n1,2,3,4\n1,2,1e5,1\n") + with pytest.raises( + ComputeError, + match=r"could not parse `1e5` as dtype `i64` at column 'energy' \(column number 3\)", + ): + pl.read_csv(data, infer_schema_length=1) + + +def test_csv_categorical_lifetime() -> None: + # escaped strings do some heap allocates in the builder + # this tests of the lifetimes remains valid + csv = textwrap.dedent( + r""" + a,b + "needs_escape",b + "" ""needs" escape" foo"",b + "" ""needs" escape" foo"", + """ + ) + + df = pl.read_csv( + csv.encode(), schema_overrides={"a": pl.Categorical, "b": pl.Categorical} + ) + assert df.dtypes == [pl.Categorical, pl.Categorical] + assert df.to_dict(as_series=False) == { + "a": ["needs_escape", ' "needs escape foo', ' "needs escape foo'], + "b": ["b", "b", None], + } + + assert (df["a"] == df["b"]).to_list() == [False, False, None] + + +def test_csv_categorical_categorical_merge() -> None: + N = 50 + f = io.BytesIO() + pl.DataFrame({"x": ["A"] * N + ["B"] * N}).write_csv(f) + f.seek(0) + assert pl.read_csv( + f, schema_overrides={"x": pl.Categorical}, sample_size=10 + ).unique(maintain_order=True)["x"].to_list() == ["A", "B"] + + +@pytest.mark.write_disk +def test_batched_csv_reader(foods_file_path: Path) -> None: + reader = pl.read_csv_batched(foods_file_path, batch_size=4) + assert isinstance(reader, BatchedCsvReader) + + batches = reader.next_batches(5) + assert batches is not None + out = pl.concat(batches) + assert_frame_equal(out, pl.read_csv(foods_file_path).head(out.height)) + + # the final batch of the low-memory variant is different + reader = pl.read_csv_batched(foods_file_path, batch_size=4, low_memory=True) + batches = reader.next_batches(10) + assert batches is not None + + assert_frame_equal(pl.concat(batches), pl.read_csv(foods_file_path)) + + reader = pl.read_csv_batched(foods_file_path, batch_size=4, low_memory=True) + batches = reader.next_batches(10) + assert_frame_equal(pl.concat(batches), pl.read_csv(foods_file_path)) # type: ignore[arg-type] + + # ragged lines + with NamedTemporaryFile() as tmp: + data = b"A\nB,ragged\nC" + tmp.write(data) + tmp.seek(0) + + expected = pl.DataFrame({"A": ["B", "C"]}) + batches = pl.read_csv_batched( + tmp.name, + has_header=True, + truncate_ragged_lines=True, + ).next_batches(1) + + assert batches is not None + assert_frame_equal(pl.concat(batches), expected) + + +def test_batched_csv_reader_empty(io_files_path: Path) -> None: + empty_csv = io_files_path / "empty.csv" + with pytest.raises(NoDataError, match="empty CSV"): + pl.read_csv_batched(source=empty_csv) + + reader = pl.read_csv_batched(source=empty_csv, raise_if_empty=False) + assert reader.next_batches(1) is None + + +def test_batched_csv_reader_all_batches(foods_file_path: Path) -> None: + for new_columns in [None, ["Category", "Calories", "Fats_g", "Sugars_g"]]: + out = pl.read_csv(foods_file_path, new_columns=new_columns) + reader = pl.read_csv_batched( + foods_file_path, new_columns=new_columns, batch_size=4 + ) + batches = reader.next_batches(5) + batched_dfs = [] + + while batches: + batched_dfs.extend(batches) + batches = reader.next_batches(5) + + assert all(x.height > 0 for x in batched_dfs) + + batched_concat_df = pl.concat(batched_dfs, rechunk=True) + assert_frame_equal(out, batched_concat_df) + + +def test_batched_csv_reader_no_batches(foods_file_path: Path) -> None: + reader = pl.read_csv_batched(foods_file_path, batch_size=4) + batches = reader.next_batches(0) + + assert batches is None + + +def test_read_csv_batched_invalid_source() -> None: + with pytest.raises(TypeError): + pl.read_csv_batched(source=5) # type: ignore[arg-type] + + +def test_csv_single_categorical_null() -> None: + f = io.BytesIO() + pl.DataFrame( + { + "x": ["A"], + "y": [None], + "z": ["A"], + } + ).write_csv(f) + f.seek(0) + + df = pl.read_csv( + f, + schema_overrides={"y": pl.Categorical}, + ) + + assert df.dtypes == [pl.String, pl.Categorical, pl.String] + assert df.to_dict(as_series=False) == {"x": ["A"], "y": [None], "z": ["A"]} + + +def test_csv_quoted_missing() -> None: + csv = ( + '"col1"|"col2"|"col3"|"col4"\n' + '"0"|"Free text with a line\nbreak"|"123"|"456"\n' + '"1"|"Free text without a linebreak"|""|"789"\n' + '"0"|"Free text with \ntwo \nlinebreaks"|"101112"|"131415"' + ) + result = pl.read_csv( + csv.encode(), separator="|", schema_overrides={"col3": pl.Int32} + ) + expected = pl.DataFrame( + { + "col1": [0, 1, 0], + "col2": [ + "Free text with a line\nbreak", + "Free text without a linebreak", + "Free text with \ntwo \nlinebreaks", + ], + "col3": [123, None, 101112], + "col4": [456, 789, 131415], + }, + schema_overrides={"col3": pl.Int32}, + ) + assert_frame_equal(result, expected) + + +def test_csv_write_tz_aware() -> None: + df = pl.DataFrame({"times": datetime(2021, 1, 1)}).with_columns( + pl.col("times") + .dt.replace_time_zone("UTC") + .dt.convert_time_zone("Europe/Zurich") + ) + assert df.write_csv() == "times\n2021-01-01T01:00:00.000000+0100\n" + + +def test_csv_statistics_offset() -> None: + # this would fail if the statistics sample did not also sample + # from the end of the file + # the lines at the end have larger rows as the numbers increase + N = 5_000 + csv = "\n".join(str(x) for x in range(N)) + assert pl.read_csv(io.StringIO(csv), n_rows=N).height == 4999 + + +@pytest.mark.write_disk +def test_csv_scan_categorical(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + N = 5_000 + df = pl.DataFrame({"x": ["A"] * N}) + + file_path = tmp_path / "test_csv_scan_categorical.csv" + df.write_csv(file_path) + result = pl.scan_csv(file_path, schema_overrides={"x": pl.Categorical}).collect() + + assert result["x"].dtype == pl.Categorical + + +@pytest.mark.write_disk +def test_csv_scan_new_columns_less_than_original_columns(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": ["A"], "y": ["A"], "z": "A"}) + + file_path = tmp_path / "test_csv_scan_new_columns.csv" + df.write_csv(file_path) + result = pl.scan_csv(file_path, new_columns=["x_new", "y_new"]).collect() + + assert result.columns == ["x_new", "y_new", "z"] + + +def test_read_csv_chunked() -> None: + """Check that row count is properly functioning.""" + N = 10_000 + csv = "1\n" * N + df = pl.read_csv(io.StringIO(csv), row_index_name="count") + + # The next value should always be higher if monotonically increasing. + assert df.filter(pl.col("count") < pl.col("count").shift(1)).is_empty() + + +def test_read_empty_csv(io_files_path: Path) -> None: + with pytest.raises(NoDataError) as err: + pl.read_csv(io_files_path / "empty.csv") + assert "empty CSV" in str(err.value) + + df = pl.read_csv(io_files_path / "empty.csv", raise_if_empty=False) + assert_frame_equal(df, pl.DataFrame()) + + with pytest.raises(pa.ArrowInvalid) as err: + pl.read_csv(io_files_path / "empty.csv", use_pyarrow=True) + assert "Empty CSV" in str(err.value) + + df = pl.read_csv( + io_files_path / "empty.csv", raise_if_empty=False, use_pyarrow=True + ) + assert_frame_equal(df, pl.DataFrame()) + + +@pytest.mark.slow +def test_read_web_file() -> None: + url = "https://raw.githubusercontent.com/pola-rs/polars/main/examples/datasets/foods1.csv" + df = pl.read_csv(url) + assert df.shape == (27, 4) + + +@pytest.mark.slow +def test_csv_multiline_splits() -> None: + # create a very unlikely csv file with many multilines in a + # single field (e.g. 5000). polars must reject multi-threading here + # as it cannot find proper file chunks without sequentially parsing. + + np.random.seed(0) + f = io.BytesIO() + + def some_multiline_str(n: int) -> str: + strs = [] + strs.append('"') + # sample between 0-5 so it is likely the multiline field also gets 3 separators. + strs.extend(f"{'xx,' * length}" for length in np.random.randint(0, 5, n)) + + strs.append('"') + return "\n".join(strs) + + for _ in range(4): + f.write(f"field1,field2,{some_multiline_str(5000)}\n".encode()) + + f.seek(0) + assert pl.read_csv(f, has_header=False).shape == (4, 3) + + +def test_read_csv_n_rows_outside_heuristic() -> None: + # create a fringe case csv file that breaks the heuristic determining how much of + # the file to read, and ensure n_rows is still adhered to + + f = io.StringIO() + + f.write(",,,?????????\n" * 1000) + f.write("?????????????????????????????????????????????????,,,\n") + f.write(",,,?????????\n" * 1048) + + f.seek(0) + assert pl.read_csv(f, n_rows=2048, has_header=False).shape == (2048, 4) + + +def test_read_csv_comments_on_top_with_schema_11667() -> None: + csv = """ +# This is a comment +A,B +1,Hello +2,World +""".strip() + + schema = { + "A": pl.Int32(), + "B": pl.Utf8(), + } + + df = pl.read_csv(io.StringIO(csv), comment_prefix="#", schema=schema) + assert df.height == 2 + assert df.schema == schema + + +def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: + df = pl.DataFrame( + { + "numbers": [1, 2, 3], + "strings": ["test", "csv", "stdout"], + "dates": [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3)], + } + ) + df.write_csv(sys.stdout) + captured = capsys.readouterr() + assert captured.out == ( + "numbers,strings,dates\n" + "1,test,2023-01-01\n" + "2,csv,2023-01-02\n" + "3,stdout,2023-01-03\n" + ) + + df.write_csv(sys.stderr) + captured = capsys.readouterr() + assert captured.err == ( + "numbers,strings,dates\n" + "1,test,2023-01-01\n" + "2,csv,2023-01-02\n" + "3,stdout,2023-01-03\n" + ) + + +def test_csv_9929() -> None: + df = pl.DataFrame({"nrs": [1, 2, 3]}) + f = io.BytesIO() + df.write_csv(f) + f.seek(0) + with pytest.raises(NoDataError): + pl.read_csv(f, skip_rows=10**6) + + +def test_csv_quote_styles() -> None: + class TemporalFormats(TypedDict): + datetime_format: str + time_format: str + + temporal_formats: TemporalFormats = { + "datetime_format": "%Y-%m-%dT%H:%M:%S", + "time_format": "%H:%M:%S", + } + + dtm = datetime(2077, 7, 5, 3, 1, 0) + dt = dtm.date() + tm = dtm.time() + + df = pl.DataFrame( + { + "float": [1.0, 2.0, None], + "string": ["a", "a,bc", '"hello'], + "int": [1, 2, 3], + "bool": [True, False, None], + "date": [dt, None, dt], + "datetime": [None, dtm, dtm], + "time": [tm, tm, None], + "decimal": [D("1.0"), D("2.0"), None], + } + ) + + assert df.write_csv(quote_style="always", **temporal_formats) == ( + '"float","string","int","bool","date","datetime","time","decimal"\n' + '"1.0","a","1","true","2077-07-05","","03:01:00","1.0"\n' + '"2.0","a,bc","2","false","","2077-07-05T03:01:00","03:01:00","2.0"\n' + '"","""hello","3","","2077-07-05","2077-07-05T03:01:00","",""\n' + ) + assert df.write_csv(quote_style="necessary", **temporal_formats) == ( + "float,string,int,bool,date,datetime,time,decimal\n" + "1.0,a,1,true,2077-07-05,,03:01:00,1.0\n" + '2.0,"a,bc",2,false,,2077-07-05T03:01:00,03:01:00,2.0\n' + ',"""hello",3,,2077-07-05,2077-07-05T03:01:00,,\n' + ) + assert df.write_csv(quote_style="never", **temporal_formats) == ( + "float,string,int,bool,date,datetime,time,decimal\n" + "1.0,a,1,true,2077-07-05,,03:01:00,1.0\n" + "2.0,a,bc,2,false,,2077-07-05T03:01:00,03:01:00,2.0\n" + ',"hello,3,,2077-07-05,2077-07-05T03:01:00,,\n' + ) + assert df.write_csv( + quote_style="non_numeric", quote_char="8", **temporal_formats + ) == ( + "8float8,8string8,8int8,8bool8,8date8,8datetime8,8time8,8decimal8\n" + "1.0,8a8,1,8true8,82077-07-058,,803:01:008,1.0\n" + "2.0,8a,bc8,2,8false8,,82077-07-05T03:01:008,803:01:008,2.0\n" + ',8"hello8,3,,82077-07-058,82077-07-05T03:01:008,,\n' + ) + + +def test_ignore_errors_casting_dtypes() -> None: + csv = """inventory +10 + +400 +90 +""" + + assert pl.read_csv( + source=io.StringIO(csv), + schema_overrides={"inventory": pl.Int8}, + ignore_errors=True, + ).to_dict(as_series=False) == {"inventory": [10, None, None, 90]} + + with pytest.raises(ComputeError): + pl.read_csv( + source=io.StringIO(csv), + schema_overrides={"inventory": pl.Int8}, + ignore_errors=False, + ) + + +def test_ignore_errors_date_parser() -> None: + data_invalid_date = "int,float,date\n3,3.4,X" + with pytest.raises(ComputeError): + pl.read_csv( + source=io.StringIO(data_invalid_date), + schema_overrides={"date": pl.Date}, + ignore_errors=False, + ) + + +def test_csv_ragged_lines() -> None: + expected = {"A": ["B", "C"]} + assert ( + pl.read_csv( + io.StringIO("A\nB,ragged\nC"), has_header=True, truncate_ragged_lines=True + ).to_dict(as_series=False) + == expected + ) + assert ( + pl.read_csv( + io.StringIO("A\nB\nC,ragged"), has_header=True, truncate_ragged_lines=True + ).to_dict(as_series=False) + == expected + ) + + for s in ["A\nB,ragged\nC", "A\nB\nC,ragged"]: + with pytest.raises(ComputeError, match=r"found more fields than defined"): + pl.read_csv(io.StringIO(s), has_header=True, truncate_ragged_lines=False) + with pytest.raises(ComputeError, match=r"found more fields than defined"): + pl.read_csv(io.StringIO(s), has_header=True, truncate_ragged_lines=False) + + +@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV +def test_provide_schema() -> None: + # can be used to overload schema with ragged csv files + assert pl.read_csv( + io.StringIO("A\nB,ragged\nC"), + has_header=False, + schema={"A": pl.String, "B": pl.String, "C": pl.String}, + ).to_dict(as_series=False) == { + "A": ["A", "B", "C"], + "B": [None, "ragged", None], + "C": [None, None, None], + } + + +def test_custom_writable_object() -> None: + df = pl.DataFrame({"a": [10, 20, 30], "b": ["x", "y", "z"]}) + + class CustomBuffer: + writes: list[bytes] + + def __init__(self) -> None: + self.writes = [] + + def write(self, data: bytes) -> int: + self.writes.append(data) + return len(data) + + buf = CustomBuffer() + df.write_csv(buf) # type: ignore[call-overload] + + assert b"".join(buf.writes) == b"a,b\n10,x\n20,y\n30,z\n" + + +@pytest.mark.parametrize( + ("csv", "expected"), + [ + (b"a,b\n1,2\n1,2\n", pl.DataFrame({"a": [1, 1], "b": [2, 2]})), + (b"a,b\n1,2\n1,2", pl.DataFrame({"a": [1, 1], "b": [2, 2]})), + (b"a\n1\n1\n", pl.DataFrame({"a": [1, 1]})), + (b"a\n1\n1", pl.DataFrame({"a": [1, 1]})), + ], + ids=[ + "multiple columns, ends with LF", + "multiple columns, ends with non-LF", + "single column, ends with LF", + "single column, ends with non-LF", + ], +) +def test_read_filelike_object_12266(csv: bytes, expected: pl.DataFrame) -> None: + buf = io.BufferedReader(io.BytesIO(csv)) # type: ignore[arg-type] + df = pl.read_csv(buf) + assert_frame_equal(df, expected) + + +def test_read_filelike_object_12404() -> None: + expected = pl.DataFrame({"a": [1, 1], "b": [2, 2]}) + csv = expected.write_csv(line_terminator=";").encode() + buf = io.BufferedReader(io.BytesIO(csv)) # type: ignore[arg-type] + df = pl.read_csv(buf, eol_char=";") + assert_frame_equal(df, expected) + + +def test_write_csv_bom() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + f = io.BytesIO() + df.write_csv(f, include_bom=True) + f.seek(0) + assert f.read() == b"\xef\xbb\xbfa,b\n1,1\n2,2\n3,3\n" + + +def test_write_csv_batch_size_zero() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + f = io.BytesIO() + with pytest.raises(ValueError, match="invalid zero value"): + df.write_csv(f, batch_size=0) + + +def test_empty_csv_no_raise() -> None: + assert pl.read_csv(io.StringIO(), raise_if_empty=False, has_header=False).shape == ( + 0, + 0, + ) + + +def test_csv_no_new_line_last() -> None: + csv = io.StringIO("a b\n1 1\n2 2\n3 2.1") + assert pl.read_csv(csv, separator=" ").to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [1.0, 2.0, 2.1], + } + + +def test_invalid_csv_raise() -> None: + with pytest.raises(ComputeError): + pl.read_csv( + b""" + "WellCompletionCWI","FacilityID","ProductionMonth","ReportedHoursProdInj","ProdAccountingProductType","ReportedVolume","VolumetricActivityType" + "SK0000608V001","SK BT B1H3780","202001","","GAS","1.700","PROD" + "SK0127960V000","SK BT 0018977","202001","","GAS","45.500","PROD" + "SK0127960V000","SK BT 0018977"," + """.strip() + ) + + +@pytest.mark.write_disk +def test_partial_read_compressed_file( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "0") + + df = pl.DataFrame( + {"idx": range(1_000), "dt": date(2025, 12, 31), "txt": "hello world"} + ) + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "large.csv.gz" + bytes_io = io.BytesIO() + df.write_csv(bytes_io) + bytes_io.seek(0) + with gzip.open(file_path, mode="wb") as f: + f.write(bytes_io.getvalue()) + df = pl.read_csv( + file_path, skip_rows=40, has_header=False, skip_rows_after_header=20, n_rows=30 + ) + assert df.shape == (30, 3) + + +def test_read_csv_invalid_schema_overrides() -> None: + csv = textwrap.dedent( + """\ + a,b + 1,foo + 2,bar + 3,baz + """ + ) + f = io.StringIO(csv) + with pytest.raises( + TypeError, match="`schema_overrides` should be of type list or dict" + ): + pl.read_csv(f, schema_overrides={pl.Int64, pl.String}) # type: ignore[arg-type] + + +def test_read_csv_invalid_schema_overrides_length() -> None: + csv = textwrap.dedent( + """\ + a,b + 1,foo + 2,bar + 3,baz + """ + ) + f = io.StringIO(csv) + + # streaming dispatches read_csv -> _scan_csv_impl which does not accept a list + if ( + os.getenv("POLARS_AUTO_NEW_STREAMING", os.getenv("POLARS_FORCE_NEW_STREAMING")) + == "1" + ): + err = TypeError + match = "expected 'schema_overrides' dict, found 'list'" + else: + err = InvalidOperationError + match = "The number of schema overrides must be less than or equal to the number of fields" + + with pytest.raises(err, match=match): + pl.read_csv(f, schema_overrides=[pl.Int64, pl.String, pl.Boolean]) + + +@pytest.mark.parametrize("columns", [["b"], "b"]) +def test_read_csv_single_column(columns: list[str] | str) -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 4,5,6 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, columns=columns) + expected = pl.DataFrame({"b": [2, 5]}) + assert_frame_equal(df, expected) + + +def test_csv_invalid_escape_utf8_14960() -> None: + with pytest.raises(ComputeError, match=r"Field .* is not properly escaped"): + pl.read_csv('col1\n""•'.encode()) + + +def test_csv_invalid_escape() -> None: + with pytest.raises(ComputeError): + pl.read_csv(b'col1,col2\n"a,b') + + +@pytest.mark.slow +@pytest.mark.write_disk +def test_read_csv_only_loads_selected_columns( + memory_usage_without_pyarrow: MemoryUsage, + tmp_path: Path, +) -> None: + """Only requested columns are loaded by ``read_csv()``.""" + tmp_path.mkdir(exist_ok=True) + + # Each column will be about 8MB of RAM + series = pl.arange(0, 1_000_000, dtype=pl.Int64, eager=True) + + file_path = tmp_path / "multicolumn.csv" + df = pl.DataFrame( + { + "a": series, + "b": series, + } + ) + df.write_csv(file_path) + del df, series + + memory_usage_without_pyarrow.reset_tracking() + + # Only load one column: + df = pl.read_csv(str(file_path), columns=["b"], rechunk=False) + del df + # Only one column's worth of memory should be used; 2 columns would be + # 16_000_000 at least, but there's some overhead. + # assert 8_000_000 < memory_usage_without_pyarrow.get_peak() < 13_000_000 + + # Globs use a different code path for reading + memory_usage_without_pyarrow.reset_tracking() + df = pl.read_csv(str(tmp_path / "*.csv"), columns=["b"], rechunk=False) + del df + # Only one column's worth of memory should be used; 2 columns would be + # 16_000_000 at least, but there's some overhead. + # assert 8_000_000 < memory_usage_without_pyarrow.get_peak() < 13_000_000 + + # read_csv_batched() test: + memory_usage_without_pyarrow.reset_tracking() + result: list[pl.DataFrame] = [] + batched = pl.read_csv_batched( + str(file_path), + columns=["b"], + rechunk=False, + n_threads=1, + low_memory=True, + batch_size=10_000, + ) + while sum(df.height for df in result) < 1_000_000: + next_batch = batched.next_batches(1) + if next_batch is None: + break + result += next_batch + del result + # assert 8_000_000 < memory_usage_without_pyarrow.get_peak() < 20_000_000 + + +def test_csv_escape_cf_15349() -> None: + f = io.BytesIO() + df = pl.DataFrame({"test": ["normal", "with\rcr"]}) + df.write_csv(f) + f.seek(0) + assert f.read() == b'test\nnormal\n"with\rcr"\n' + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_skip_rows_after_header(tmp_path: Path, streaming: bool) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.csv" + + df = pl.Series("a", [1, 2, 3, 4, 5], dtype=pl.Int64).to_frame() + df.write_csv(path) + + skip = 2 + expect = df.slice(skip) + out = pl.scan_csv(path, skip_rows_after_header=skip).collect( + engine="streaming" if streaming else "in-memory" + ) + + assert_frame_equal(out, expect) + + +@pytest.mark.parametrize("use_pyarrow", [True, False]) +def test_skip_rows_after_header_pyarrow(use_pyarrow: bool) -> None: + csv = textwrap.dedent( + """\ + foo,bar + 1,2 + 3,4 + 5,6 + """ + ) + f = io.StringIO(csv) + df = pl.read_csv(f, skip_rows_after_header=1, use_pyarrow=use_pyarrow) + expected = pl.DataFrame({"foo": [3, 5], "bar": [4, 6]}) + assert_frame_equal(df, expected) + + +def test_csv_float_decimal() -> None: + floats = b"a;b\n12,239;1,233\n13,908;87,32" + read = pl.read_csv(floats, decimal_comma=True, separator=";") + assert read.dtypes == [pl.Float64] * 2 + assert read.to_dict(as_series=False) == {"a": [12.239, 13.908], "b": [1.233, 87.32]} + + floats = b"a;b\n12,239;1,233\n13,908;87,32" + with pytest.raises( + InvalidOperationError, match=r"'decimal_comma' argument cannot be combined" + ): + pl.read_csv(floats, decimal_comma=True) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_fsspec_not_available() -> None: + with pytest.MonkeyPatch.context() as mp: + mp.setenv("POLARS_FORCE_ASYNC", "0") + mp.setattr("polars.io._utils._FSSPEC_AVAILABLE", False) + + with pytest.raises( + ImportError, match=r"`fsspec` is required for `storage_options` argument" + ): + pl.read_csv( + "s3://foods/cabbage.csv", + storage_options={"key": "key", "secret": "secret"}, + ) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_read_csv_dtypes_deprecated() -> None: + csv = textwrap.dedent( + """\ + a,b,c + 1,2,3 + 4,5,6 + """ + ) + f = io.StringIO(csv) + + with pytest.deprecated_call(): + df = pl.read_csv(f, dtypes=[pl.Int8, pl.Int8, pl.Int8]) # type: ignore[call-arg] + + expected = pl.DataFrame( + {"a": [1, 4], "b": [2, 5], "c": [3, 6]}, + schema={"a": pl.Int8, "b": pl.Int8, "c": pl.Int8}, + ) + assert_frame_equal(df, expected) + + +def test_projection_applied_on_file_with_no_rows_16606(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + path = tmp_path / "data.csv" + + data = """\ +a,b,c,d +""" + + with path.open("w") as f: + f.write(data) + + columns = ["a", "b"] + + out = pl.read_csv(path, columns=columns).columns + assert out == columns + + out = pl.scan_csv(path).select(columns).collect().columns + assert out == columns + + +@pytest.mark.write_disk +def test_write_csv_to_dangling_file_17328( + df_no_lists: pl.DataFrame, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + df_no_lists.write_csv((tmp_path / "dangling.csv").open("w")) + + +def test_write_csv_raise_on_non_utf8_17328( + df_no_lists: pl.DataFrame, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"): + df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk")) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +@pytest.mark.write_disk +def test_write_csv_appending_17543(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + df = pl.DataFrame({"col": ["value"]}) + with (tmp_path / "append.csv").open("w") as f: + f.write("# test\n") + df.write_csv(f) + with (tmp_path / "append.csv").open("r") as f: + assert f.readline() == "# test\n" + assert pl.read_csv(f).equals(df) + + +def test_write_csv_passing_params_18825() -> None: + df = pl.DataFrame({"c1": [1, 2], "c2": [3, 4]}) + buffer = io.StringIO() + df.write_csv(buffer, separator="\t", include_header=False) + + result_str = buffer.getvalue() + expected_str = "1\t3\n2\t4\n" + + assert result_str == expected_str + + +@pytest.mark.parametrize( + ("dtype", "df"), + [ + (pl.Decimal(scale=2), pl.DataFrame({"x": ["0.1"]}).cast(pl.Decimal(scale=2))), + (pl.Categorical, pl.DataFrame({"x": ["A"]})), + ( + pl.Time, + pl.DataFrame({"x": ["12:15:00"]}).with_columns( + pl.col("x").str.strptime(pl.Time) + ), + ), + ], +) +def test_read_csv_cast_unparsable_later( + dtype: pl.Decimal | pl.Categorical | pl.Time, df: pl.DataFrame +) -> None: + f = io.BytesIO() + df.write_csv(f) + f.seek(0) + assert df.equals(pl.read_csv(f, schema={"x": dtype})) + + +def test_csv_double_new_line() -> None: + assert pl.read_csv(b"a,b,c\n\n", has_header=False).to_dict(as_series=False) == { + "column_1": ["a", None], + "column_2": ["b", None], + "column_3": ["c", None], + } + + +def test_csv_quoted_newlines_skip_rows_19535() -> None: + assert_frame_equal( + pl.read_csv( + b"""\ +"a\nb" +0 +""", + has_header=False, + skip_rows=1, + new_columns=["x"], + ), + pl.DataFrame({"x": 0}), + ) + + +@pytest.mark.write_disk +def test_csv_read_time_dtype(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "1" + path.write_bytes(b"""\ +time +00:00:00.000000000 +""") + + df = pl.Series("time", [0]).cast(pl.Time()).to_frame() + + assert_frame_equal(pl.read_csv(path, try_parse_dates=True), df) + assert_frame_equal(pl.read_csv(path, schema_overrides={"time": pl.Time}), df) + assert_frame_equal(pl.scan_csv(path, try_parse_dates=True).collect(), df) + assert_frame_equal(pl.scan_csv(path, schema={"time": pl.Time}).collect(), df) + assert_frame_equal( + pl.scan_csv(path, schema={"time": pl.Time}).collect(engine="streaming"), df + ) + + +def test_csv_try_parse_dates_leading_zero_8_digits_22167() -> None: + result = pl.read_csv( + io.StringIO( + "a\n2025-04-06T18:56:42.617736974Z\n2025-04-06T18:57:42.77756192Z\n2025-04-06T18:58:44.56928733Z" + ), + try_parse_dates=True, + ) + expected = pl.DataFrame( + { + "a": [ + datetime(2025, 4, 6, 18, 56, 42, 617736, tzinfo=timezone.utc), + datetime(2025, 4, 6, 18, 57, 42, 777561, tzinfo=timezone.utc), + datetime(2025, 4, 6, 18, 58, 44, 569287, tzinfo=timezone.utc), + ] + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch +def test_csv_read_time_schema_overrides() -> None: + df = pl.Series("time", [0]).cast(pl.Time()).to_frame() + + assert_frame_equal( + pl.read_csv( + b"""\ +time +00:00:00.000000000 +""", + schema_overrides=[pl.Time], + ), + df, + ) + + +def test_batched_csv_schema_overrides(io_files_path: Path) -> None: + foods = io_files_path / "foods1.csv" + batched = pl.read_csv_batched(foods, schema_overrides={"calories": pl.String}) + res = batched.next_batches(1) + assert res is not None + b = res[0] + assert b["calories"].dtype == pl.String + assert b.width == 4 + + +def test_csv_ragged_lines_20062() -> None: + buf = io.StringIO("""A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V +,"B",,,,,,,,,A,,,,,,,, +a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,0.0,1.0,2.0,3.0 +""") + assert pl.read_csv(buf, truncate_ragged_lines=True).to_dict(as_series=False) == { + "A": [None, "a"], + "B": ["B", "a"], + "C": [None, "a"], + "D": [None, "a"], + "E": [None, "a"], + "F": [None, "a"], + "G": [None, "a"], + "H": [None, "a"], + "I": [None, "a"], + "J": [None, "a"], + "K": ["A", "a"], + "L": [None, "a"], + "M": [None, "a"], + "N": [None, "a"], + "O": [None, "a"], + "P": [None, "a"], + "Q": [None, "a"], + "R": [None, "a"], + "S": [None, "a"], + "T": [None, 0.0], + "U": [None, 1.0], + "V": [None, 2.0], + } + + +def test_csv_skip_lines() -> None: + fh = io.BytesIO() + fh.write(b'Header line "1" -> quote count 2\n') + fh.write(b'Header line "2"" -> quote count 3\n') + fh.write(b'Header line "3" -> quote count 2 => Total 7 quotes ERROR\n') + fh.write(b"column_01, column_02, column_03\n") + fh.write(b"123.12, 21, 99.9\n") + fh.write(b"65.84, 75, 64.7\n") + fh.seek(0) + + df = pl.read_csv(fh, has_header=True, skip_lines=3) + assert df.to_dict(as_series=False) == { + "column_01": [123.12, 65.84], + " column_02": [" 21", " 75"], + " column_03": [" 99.9", " 64.7"], + } + + fh.seek(0) + assert_frame_equal(pl.scan_csv(fh, has_header=True, skip_lines=3).collect(), df) + + +def test_csv_invalid_quoted_comment_line() -> None: + # Comment quotes should be ignored. + assert pl.read_csv( + b'#"Comment\nColA\tColB\n1\t2', separator="\t", comment_prefix="#" + ).to_dict(as_series=False) == {"ColA": [1], "ColB": [2]} + + +@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV +def test_csv_compressed_new_columns_19916() -> None: + n_rows = 100 + + df = pl.DataFrame( + { + "a": range(n_rows), + "b": range(n_rows), + "c": range(n_rows), + "d": range(n_rows), + "e": range(n_rows), + "f": range(n_rows), + } + ) + + b = zstandard.compress(df.write_csv(include_header=False).encode()) + + q = pl.scan_csv(b, has_header=False, new_columns=["a", "b", "c", "d", "e", "f"]) + assert_frame_equal(q.collect(), df) + + +def test_trailing_separator_8240() -> None: + csv = "A|B|" + + expected = pl.DataFrame( + {"column_1": ["A"], "column_2": ["B"], "column_3": [None]}, + schema={"column_1": pl.String, "column_2": pl.String, "column_3": pl.String}, + ) + + result = pl.read_csv(io.StringIO(csv), separator="|", has_header=False) + assert_frame_equal(result, expected) + + result = pl.scan_csv(io.StringIO(csv), separator="|", has_header=False).collect() + assert_frame_equal(result, expected) + + +def test_header_only_column_selection_17173() -> None: + csv = "A,B" + result = pl.read_csv(io.StringIO(csv), columns=["B"]) + expected = pl.Series("B", [], pl.String()).to_frame() + assert_frame_equal(result, expected) + + +def test_csv_enum_raise() -> None: + ENUM_DTYPE = pl.Enum(["foo", "bar"]) + with ( + io.StringIO("col\nfoo\nbaz\n") as csv, + pytest.raises( + pl.exceptions.ComputeError, match="category baz doesn't exist in Enum dtype" + ), + ): + pl.read_csv( + csv, + schema={"col": ENUM_DTYPE}, + ) + + +def test_csv_no_header_ragged_lines_1505() -> None: + # Test that the header schema will grow dynamically. + csv = io.StringIO("""a,b,c +a,b,c,d,e,f +g,h,i,j,k""") + + assert pl.read_csv(csv, has_header=False).to_dict(as_series=False) == { + "column_1": ["a", "a", "g"], + "column_2": ["b", "b", "h"], + "column_3": ["c", "c", "i"], + "column_4": [None, "d", "j"], + "column_5": [None, "e", "k"], + "column_6": [None, "f", None], + } + + +@pytest.mark.parametrize( + ("filter_value", "expected"), + [ + (10, "a,b,c\n10,20,99\n"), + (11, "a,b,c\n11,21,99\n"), + (12, "a,b,c\n12,22,99\n12,23,99\n"), + ], +) +def test_csv_write_scalar_empty_chunk_20273(filter_value: int, expected: str) -> None: + # df and filter expression are designed to test different + # Column variants (Series, Scalar) and different number of chunks: + # 10 > single row, ScalarColumn, multiple chunks, first is non-empty + # 11 > single row, ScalarColumn, multiple chunks, first is empty + # 12 > multiple rows, SeriesColumn, multiple chunks, some empty + df1 = pl.DataFrame( + { + "a": [10, 11, 12, 12], # (12, 12 is intentional) + "b": [20, 21, 22, 23], + }, + ) + df2 = pl.DataFrame({"c": [99]}) + df3 = df1.join(df2, how="cross").filter(pl.col("a").eq(filter_value)) + assert df3.write_csv() == expected diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py new file mode 100644 index 000000000000..26e680551cb4 --- /dev/null +++ b/py-polars/tests/unit/io/test_delta.py @@ -0,0 +1,528 @@ +from __future__ import annotations + +import os +from datetime import datetime, timezone +from pathlib import Path + +import pyarrow as pa +import pyarrow.fs +import pytest +from deltalake import DeltaTable, write_deltalake +from deltalake.exceptions import DeltaError, TableNotFoundError +from deltalake.table import TableMerger + +import polars as pl +from polars.testing import assert_frame_equal, assert_frame_not_equal + + +@pytest.fixture +def delta_table_path(io_files_path: Path) -> Path: + return io_files_path / "delta-table" + + +def test_scan_delta(delta_table_path: Path) -> None: + ldf = pl.scan_delta(str(delta_table_path), version=0) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) + + +def test_scan_delta_version(delta_table_path: Path) -> None: + df1 = pl.scan_delta(str(delta_table_path), version=0).collect() + df2 = pl.scan_delta(str(delta_table_path), version=1).collect() + + assert_frame_not_equal(df1, df2) + + +@pytest.mark.write_disk +def test_scan_delta_timestamp_version(tmp_path: Path) -> None: + df_sample = pl.DataFrame({"name": ["Joey"], "age": [14]}) + df_sample.write_delta(tmp_path, mode="append") + + df_sample2 = pl.DataFrame({"name": ["Ivan"], "age": [34]}) + df_sample2.write_delta(tmp_path, mode="append") + + log_dir = tmp_path / "_delta_log" + log_mtime_pair = [ + ("00000000000000000000.json", datetime(2010, 1, 1).timestamp()), + ("00000000000000000001.json", datetime(2024, 1, 1).timestamp()), + ] + for file_name, dt_epoch in log_mtime_pair: + file_path = log_dir / file_name + os.utime(str(file_path), (dt_epoch, dt_epoch)) + + df1 = pl.scan_delta( + str(tmp_path), version=datetime(2010, 1, 1, tzinfo=timezone.utc) + ).collect() + df2 = pl.scan_delta( + str(tmp_path), version=datetime(2024, 1, 1, tzinfo=timezone.utc) + ).collect() + + assert_frame_equal(df1, df_sample) + assert_frame_equal(df2, pl.concat([df_sample, df_sample2]), check_row_order=False) + + +def test_scan_delta_columns(delta_table_path: Path) -> None: + ldf = pl.scan_delta(str(delta_table_path), version=0).select("name") + + expected = pl.DataFrame({"name": ["Joey", "Ivan"]}) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) + + +def test_scan_delta_relative(delta_table_path: Path) -> None: + rel_delta_table_path = str(delta_table_path / ".." / "delta-table") + + ldf = pl.scan_delta(rel_delta_table_path, version=0) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) + + ldf = pl.scan_delta(rel_delta_table_path, version=1) + assert_frame_not_equal(expected, ldf.collect()) + + +def test_read_delta(delta_table_path: Path) -> None: + df = pl.read_delta(str(delta_table_path), version=0) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, df, check_dtypes=False) + + +def test_read_delta_version(delta_table_path: Path) -> None: + df1 = pl.read_delta(str(delta_table_path), version=0) + df2 = pl.read_delta(str(delta_table_path), version=1) + + assert_frame_not_equal(df1, df2) + + +@pytest.mark.write_disk +def test_read_delta_timestamp_version(tmp_path: Path) -> None: + df_sample = pl.DataFrame({"name": ["Joey"], "age": [14]}) + df_sample.write_delta(tmp_path, mode="append") + + df_sample2 = pl.DataFrame({"name": ["Ivan"], "age": [34]}) + df_sample2.write_delta(tmp_path, mode="append") + + log_dir = tmp_path / "_delta_log" + log_mtime_pair = [ + ("00000000000000000000.json", datetime(2010, 1, 1).timestamp()), + ("00000000000000000001.json", datetime(2024, 1, 1).timestamp()), + ] + for file_name, dt_epoch in log_mtime_pair: + file_path = log_dir / file_name + os.utime(str(file_path), (dt_epoch, dt_epoch)) + + df1 = pl.read_delta( + str(tmp_path), version=datetime(2010, 1, 1, tzinfo=timezone.utc) + ) + df2 = pl.read_delta( + str(tmp_path), version=datetime(2024, 1, 1, tzinfo=timezone.utc) + ) + + assert_frame_equal(df1, df_sample) + assert_frame_equal(df2, pl.concat([df_sample, df_sample2]), check_row_order=False) + + +def test_read_delta_columns(delta_table_path: Path) -> None: + df = pl.read_delta(str(delta_table_path), version=0, columns=["name"]) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"]}) + assert_frame_equal(expected, df, check_dtypes=False) + + +def test_read_delta_relative(delta_table_path: Path) -> None: + rel_delta_table_path = str(delta_table_path / ".." / "delta-table") + + df = pl.read_delta(rel_delta_table_path, version=0) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, df, check_dtypes=False) + + +@pytest.mark.write_disk +def test_write_delta(df: pl.DataFrame, tmp_path: Path) -> None: + v0 = df.select(pl.col(pl.String)) + v1 = df.select(pl.col(pl.Int64)) + df_supported = df.drop(["cat", "enum", "time"]) + + # Case: Success (version 0) + v0.write_delta(tmp_path) + + # Case: Error if table exists + with pytest.raises(DeltaError, match="A table already exists"): + v0.write_delta(tmp_path) + + # Case: Overwrite with new version (version 1) + v1.write_delta( + tmp_path, mode="overwrite", delta_write_options={"schema_mode": "overwrite"} + ) + + # Case: Error if schema contains unsupported columns + with pytest.raises(TypeError): + df.write_delta( + tmp_path, mode="overwrite", delta_write_options={"schema_mode": "overwrite"} + ) + + partitioned_tbl_uri = (tmp_path / ".." / "partitioned_table").resolve() + + # Case: Write new partitioned table (version 0) + df_supported.write_delta( + partitioned_tbl_uri, delta_write_options={"partition_by": "strings"} + ) + + # Case: Read back + tbl = DeltaTable(tmp_path) + partitioned_tbl = DeltaTable(partitioned_tbl_uri) + + pl_df_0 = pl.read_delta(tbl.table_uri, version=0) + pl_df_1 = pl.read_delta(tbl.table_uri, version=1) + pl_df_partitioned = pl.read_delta(str(partitioned_tbl_uri)) + + assert v0.shape == pl_df_0.shape + assert v0.columns == pl_df_0.columns + assert v1.shape == pl_df_1.shape + assert v1.columns == pl_df_1.columns + + assert df_supported.shape == pl_df_partitioned.shape + assert sorted(df_supported.columns) == sorted(pl_df_partitioned.columns) + + assert tbl.version() == 1 + assert partitioned_tbl.version() == 0 + assert Path(partitioned_tbl.table_uri) == partitioned_tbl_uri + assert partitioned_tbl.metadata().partition_columns == ["strings"] + + assert_frame_equal(v0, pl_df_0, check_row_order=False) + assert_frame_equal(v1, pl_df_1, check_row_order=False) + + cols = [c for c in df_supported.columns if not c.startswith("list_")] + assert_frame_equal( + df_supported.select(cols), + pl_df_partitioned.select(cols), + check_row_order=False, + ) + + # Case: Append to existing tables + v1.write_delta(tmp_path, mode="append") + tbl = DeltaTable(tmp_path) + pl_df_1 = pl.read_delta(tbl.table_uri, version=2) + + assert tbl.version() == 2 + assert pl_df_1.shape == (6, 2) # Rows are doubled + assert v1.columns == pl_df_1.columns + + df_supported.write_delta(partitioned_tbl_uri, mode="append") + partitioned_tbl = DeltaTable(partitioned_tbl_uri) + pl_df_partitioned = pl.read_delta(partitioned_tbl.table_uri, version=1) + + assert partitioned_tbl.version() == 1 + assert pl_df_partitioned.shape == (6, 14) # Rows are doubled + assert sorted(df_supported.columns) == sorted(pl_df_partitioned.columns) + + df_supported.write_delta(partitioned_tbl_uri, mode="overwrite") + + +@pytest.mark.write_disk +def test_write_delta_overwrite_schema_deprecated( + df: pl.DataFrame, tmp_path: Path +) -> None: + df = df.select(pl.col(pl.Int64)) + with pytest.deprecated_call(): + df.write_delta(tmp_path, mode="overwrite", overwrite_schema=True) + result = pl.read_delta(str(tmp_path)) + assert_frame_equal(df, result) + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + "series", + [ + pl.Series("string", ["test"], dtype=pl.String), + pl.Series("uint", [1], dtype=pl.UInt64), + pl.Series("int", [1], dtype=pl.Int64), + pl.Series( + "uint_list", + [[[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]]]], + dtype=pl.List(pl.List(pl.List(pl.List(pl.UInt16)))), + ), + pl.Series( + "date_ns", [datetime(2010, 1, 1, 0, 0)], dtype=pl.Datetime(time_unit="ns") + ).dt.replace_time_zone("Australia/Lord_Howe"), + pl.Series( + "date_us", + [datetime(2010, 1, 1, 0, 0)], + dtype=pl.Datetime(time_unit="us"), + ), + pl.Series( + "list_date", + [ + [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ] + ], + dtype=pl.List(pl.Datetime(time_unit="ns")), + ), + pl.Series( + "list_date_us", + [ + [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ] + ], + dtype=pl.List(pl.Datetime(time_unit="ms")), + ), + pl.Series( + "nested_list_date", + [ + [ + [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ] + ] + ], + dtype=pl.List(pl.List(pl.Datetime(time_unit="ns"))), + ), + pl.Series( + "struct_with_list", + [ + { + "date_range": [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ], + "date_us": [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ], + "date_range_nested": [ + [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ] + ], + "string": "test", + "int": 1, + } + ], + dtype=pl.Struct( + [ + pl.Field( + "date_range", + pl.List(pl.Datetime(time_unit="ms", time_zone="UTC")), + ), + pl.Field( + "date_us", pl.List(pl.Datetime(time_unit="ms", time_zone=None)) + ), + pl.Field( + "date_range_nested", + pl.List(pl.List(pl.Datetime(time_unit="ms", time_zone=None))), + ), + pl.Field("string", pl.String), + pl.Field("int", pl.UInt32), + ] + ), + ), + pl.Series( + "list_with_struct_with_list", + [ + [ + { + "date_range": [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ], + "date_ns": [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ], + "date_range_nested": [ + [ + datetime(2010, 1, 1, 0, 0), + datetime(2010, 1, 2, 0, 0), + ] + ], + "string": "test", + "int": 1, + } + ] + ], + dtype=pl.List( + pl.Struct( + [ + pl.Field( + "date_range", + pl.List(pl.Datetime(time_unit="ns", time_zone=None)), + ), + pl.Field( + "date_ns", + pl.List(pl.Datetime(time_unit="ns", time_zone=None)), + ), + pl.Field( + "date_range_nested", + pl.List( + pl.List(pl.Datetime(time_unit="ns", time_zone=None)) + ), + ), + pl.Field("string", pl.String), + pl.Field("int", pl.UInt32), + ] + ) + ), + ), + ], +) +def test_write_delta_w_compatible_schema(series: pl.Series, tmp_path: Path) -> None: + df = series.to_frame() + + # Create table + df.write_delta(tmp_path, mode="append") + + # Write to table again, should pass with reconstructed schema + df.write_delta(tmp_path, mode="append") + + tbl = DeltaTable(tmp_path) + assert tbl.version() == 1 + + +@pytest.mark.write_disk +def test_write_delta_with_schema_10540(tmp_path: Path) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + pa_schema = pa.schema([("a", pa.int64())]) + df.write_delta(tmp_path, delta_write_options={"schema": pa_schema}) + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + "expr", + [ + pl.datetime(2010, 1, 1, time_unit="us", time_zone="UTC"), + pl.datetime(2010, 1, 1, time_unit="ns", time_zone="EST"), + pl.datetime(2010, 1, 1, time_unit="ms", time_zone="Europe/Amsterdam"), + ], +) +def test_write_delta_with_tz_in_df(expr: pl.Expr, tmp_path: Path) -> None: + df = pl.select(expr) + + expected_dtype = pl.Datetime("us", "UTC") + expected = pl.select(expr.cast(expected_dtype)) + + df.write_delta(tmp_path, mode="append") + # write second time because delta-rs also casts timestamp with tz to timestamp no tz + df.write_delta(tmp_path, mode="append") + + # Check schema of DeltaTable object + tbl = DeltaTable(tmp_path) + assert tbl.schema().to_pyarrow() == expected.to_arrow().schema + + # Check result + result = pl.read_delta(str(tmp_path), version=0) + assert_frame_equal(result, expected) + + +def test_write_delta_with_merge_and_no_table(tmp_path: Path) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + with pytest.raises(TableNotFoundError): + df.write_delta( + tmp_path, mode="merge", delta_merge_options={"predicate": "a = a"} + ) + + +@pytest.mark.write_disk +def test_write_delta_with_merge(tmp_path: Path) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + df.write_delta(tmp_path) + + merger = df.write_delta( + tmp_path, + mode="merge", + delta_merge_options={ + "predicate": "s.a = t.a", + "source_alias": "s", + "target_alias": "t", + }, + ) + + assert isinstance(merger, TableMerger) + assert merger._builder.source_alias == "s" + assert merger._builder.target_alias == "t" + + merger.when_matched_delete(predicate="t.a > 2").execute() + + result = pl.read_delta(str(tmp_path)) + + expected = df.filter(pl.col("a") <= 2) + assert_frame_equal(result, expected, check_row_order=False) + + +@pytest.mark.write_disk +def test_unsupported_dtypes(tmp_path: Path) -> None: + df = pl.DataFrame({"a": [None]}, schema={"a": pl.Null}) + with pytest.raises(TypeError, match="unsupported data type"): + df.write_delta(tmp_path / "null") + + df = pl.DataFrame({"a": [123]}, schema={"a": pl.Time}) + with pytest.raises(TypeError, match="unsupported data type"): + df.write_delta(tmp_path / "time") + + +@pytest.mark.skip( + reason="upstream bug in delta-rs causing categorical to be written as categorical in parquet" +) +@pytest.mark.write_disk +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_becomes_string(tmp_path: Path) -> None: + df = pl.DataFrame({"a": ["A", "B", "A"]}, schema={"a": pl.Categorical}) + df.write_delta(tmp_path) + df2 = pl.read_delta(str(tmp_path)) + assert_frame_equal(df2, pl.DataFrame({"a": ["A", "B", "A"]}, schema={"a": pl.Utf8})) + + +def test_scan_delta_DT_input(delta_table_path: Path) -> None: + DT = DeltaTable(str(delta_table_path), version=0) + ldf = pl.scan_delta(DT) + + expected = pl.DataFrame({"name": ["Joey", "Ivan"], "age": [14, 32]}) + assert_frame_equal(expected, ldf.collect(), check_dtypes=False) + + +@pytest.mark.write_disk +def test_read_delta_empty(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + path = str(tmp_path) + + DeltaTable.create(path, pl.DataFrame(schema={"x": pl.Int64}).to_arrow().schema) + assert_frame_equal(pl.read_delta(path), pl.DataFrame(schema={"x": pl.Int64})) + + +@pytest.mark.write_disk +def test_read_delta_arrow_map_type(tmp_path: Path) -> None: + payload = [ + {"id": 1, "account_id": {17: "100.01.001 Cash"}}, + {"id": 2, "account_id": {18: "180.01.001 Cash", 19: "foo"}}, + ] + + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("account_id", pa.map_(pa.int32(), pa.string())), + ] + ) + table = pa.Table.from_pylist(payload, schema) + + expect = pl.DataFrame(table) + + table_path = str(tmp_path) + write_deltalake( + table_path, + table, + mode="overwrite", + engine="rust", + ) + + assert_frame_equal(pl.scan_delta(table_path).collect(), expect) + assert_frame_equal(pl.read_delta(table_path), expect) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py new file mode 100644 index 000000000000..9cf5d9349c6f --- /dev/null +++ b/py-polars/tests/unit/io/test_hive.py @@ -0,0 +1,976 @@ +from __future__ import annotations + +import sys +import urllib.parse +import warnings +from collections import OrderedDict +from datetime import date, datetime +from functools import partial +from pathlib import Path +from typing import Any, Callable + +import pyarrow.parquet as pq +import pytest + +import polars as pl +from polars.exceptions import ComputeError, SchemaFieldNotFoundError +from polars.testing import assert_frame_equal, assert_series_equal + + +def impl_test_hive_partitioned_predicate_pushdown( + io_files_path: Path, + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + df = pl.read_ipc(io_files_path / "*.ipc") + + root = tmp_path / "partitioned_data" + + pq.write_to_dataset( + df.to_arrow(), + root_path=root, + partition_cols=["category", "fats_g"], + ) + q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=False) + # checks schema + assert q.collect_schema().names() == ["calories", "sugars_g"] + # checks materialization + assert q.collect().columns == ["calories", "sugars_g"] + + q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True) + assert q.collect_schema().names() == ["calories", "sugars_g", "category", "fats_g"] + + # Partitioning changes the order + sort_by = ["fats_g", "category", "calories", "sugars_g"] + + # The hive partitioned columns are appended, + # so we must ensure we assert in the proper order. + df = df.select(["calories", "sugars_g", "category", "fats_g"]) + for streaming in [True, False]: + for pred in [ + pl.col("category") == "vegetables", + pl.col("category") != "vegetables", + pl.col("fats_g") > 0.5, + (pl.col("fats_g") == 0.5) & (pl.col("category") == "vegetables"), + ]: + assert_frame_equal( + q.filter(pred) + .sort(sort_by) + .collect(engine="streaming" if streaming else "in-memory"), + df.filter(pred).sort(sort_by), + ) + + # tests: 11536 + assert q.filter(pl.col("sugars_g") == 25).collect().shape == (1, 4) + + # tests: 12570 + assert q.filter(pl.col("fats_g") == 1225.0).select("category").collect().shape == ( + 0, + 1, + ) + + +@pytest.mark.xdist_group("streaming") +@pytest.mark.write_disk +def test_hive_partitioned_predicate_pushdown( + io_files_path: Path, + tmp_path: Path, + monkeypatch: Any, +) -> None: + impl_test_hive_partitioned_predicate_pushdown( + io_files_path, + tmp_path, + monkeypatch, + ) + + +@pytest.mark.xdist_group("streaming") +@pytest.mark.write_disk +def test_hive_partitioned_predicate_pushdown_single_threaded_async_17155( + io_files_path: Path, + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + monkeypatch.setenv("POLARS_PREFETCH_SIZE", "1") + + impl_test_hive_partitioned_predicate_pushdown( + io_files_path, + tmp_path, + monkeypatch, + ) + + +@pytest.mark.write_disk +@pytest.mark.may_fail_auto_streaming +def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files( + tmp_path: Path, monkeypatch: Any, capfd: Any +) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + df = pl.DataFrame({"d": pl.arange(0, 5, eager=True)}).with_columns( + a=pl.col("d") % 5 + ) + root = tmp_path / "test_int_partitions" + df.write_parquet( + root, + use_pyarrow=True, + pyarrow_options={"partition_cols": ["a"]}, + ) + + q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True) + assert q.filter(pl.col("a").is_in([1, 4])).collect().shape == (2, 2) + assert "allows skipping 3 / 5" in capfd.readouterr().err + + # Ensure the CSE can work with hive partitions. + q = q.filter(pl.col("a").gt(2)) + result = q.join(q, on="a", how="left").collect(comm_subplan_elim=True) + expected = { + "a": [3, 4], + "d": [3, 4], + "d_right": [3, 4], + } + assert result.to_dict(as_series=False) == expected + + +@pytest.mark.write_disk +def test_hive_streaming_pushdown_is_in_22212(tmp_path: Path) -> None: + ( + pl.DataFrame({"x": range(5)}).write_parquet( + tmp_path, + partition_by="x", + ) + ) + + lf = pl.scan_parquet(tmp_path, hive_partitioning=True).filter( + pl.col("x").is_in([1, 4]) + ) + + assert_frame_equal( + lf.collect(engine="streaming", predicate_pushdown=False), + lf.collect(engine="streaming", predicate_pushdown=True), + ) + + +@pytest.mark.xdist_group("streaming") +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_hive_partitioned_slice_pushdown( + io_files_path: Path, tmp_path: Path, streaming: bool +) -> None: + df = pl.read_ipc(io_files_path / "*.ipc") + + root = tmp_path / "partitioned_data" + + # Ignore the pyarrow legacy warning until we can write properly with new settings. + warnings.filterwarnings("ignore") + pq.write_to_dataset( + df.to_arrow(), + root_path=root, + partition_cols=["category", "fats_g"], + ) + + q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True) + schema = q.collect_schema() + expect_count = pl.select(pl.lit(1, dtype=pl.UInt32).alias(x) for x in schema) + + assert_frame_equal( + q.head(1) + .collect(engine="streaming" if streaming else "in-memory") + .select(pl.all().len()), + expect_count, + ) + assert q.head(0).collect( + engine="streaming" if streaming else "in-memory" + ).columns == [ + "calories", + "sugars_g", + "category", + "fats_g", + ] + + +@pytest.mark.xdist_group("streaming") +@pytest.mark.write_disk +def test_hive_partitioned_projection_pushdown( + io_files_path: Path, tmp_path: Path +) -> None: + df = pl.read_ipc(io_files_path / "*.ipc") + + root = tmp_path / "partitioned_data" + + # Ignore the pyarrow legacy warning until we can write properly with new settings. + warnings.filterwarnings("ignore") + pq.write_to_dataset( + df.to_arrow(), + root_path=root, + partition_cols=["category", "fats_g"], + ) + + q = pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True) + columns = ["sugars_g", "category"] + for streaming in [True, False]: + assert ( + q.select(columns) + .collect(engine="streaming" if streaming else "in-memory") + .columns + == columns + ) + + # test that hive partition columns are projected with the correct height when + # the projection contains only hive partition columns (11796) + for parallel in ("row_groups", "columns"): + q = pl.scan_parquet( + root / "**/*.parquet", + hive_partitioning=True, + parallel=parallel, + ) + + expected = q.collect().select("category") + result = q.select("category").collect() + + assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +def test_hive_partitioned_projection_skips_files(tmp_path: Path) -> None: + # ensure that it makes hive columns even when . in dir value + # and that it doesn't make hive columns from filename with = + df = pl.DataFrame( + {"sqlver": [10012.0, 10013.0], "namespace": ["eos", "fda"], "a": [1, 2]} + ) + root = tmp_path / "partitioned_data" + for dir_tuple, sub_df in df.partition_by( + ["sqlver", "namespace"], include_key=False, as_dict=True + ).items(): + new_path = root / f"sqlver={dir_tuple[0]}" / f"namespace={dir_tuple[1]}" + new_path.mkdir(parents=True, exist_ok=True) + sub_df.write_parquet(new_path / "file=8484.parquet") + test_df = ( + pl.scan_parquet(str(root) + "/**/**/*.parquet", hive_partitioning=True) + # don't care about column order + .select("sqlver", "namespace", "a", pl.exclude("sqlver", "namespace", "a")) + .collect() + ) + assert_frame_equal(df, test_df) + + +@pytest.fixture +def dataset_path(tmp_path: Path) -> Path: + tmp_path.mkdir(exist_ok=True) + + # Set up Hive partitioned Parquet file + root = tmp_path / "dataset" + part1 = root / "c=1" + part2 = root / "c=2" + root.mkdir() + part1.mkdir() + part2.mkdir() + df1 = pl.DataFrame({"a": [1, 2], "b": [11.0, 12.0]}) + df2 = pl.DataFrame({"a": [3, 4], "b": [13.0, 14.0]}) + df3 = pl.DataFrame({"a": [5, 6], "b": [15.0, 16.0]}) + df1.write_parquet(part1 / "one.parquet") + df2.write_parquet(part1 / "two.parquet") + df3.write_parquet(part2 / "three.parquet") + + return root + + +@pytest.mark.write_disk +def test_scan_parquet_hive_schema(dataset_path: Path) -> None: + result = pl.scan_parquet(dataset_path / "**/*.parquet", hive_partitioning=True) + assert result.collect_schema() == OrderedDict( + {"a": pl.Int64, "b": pl.Float64, "c": pl.Int64} + ) + + result = pl.scan_parquet( + dataset_path / "**/*.parquet", + hive_partitioning=True, + hive_schema={"c": pl.Int32}, + ) + + expected_schema = OrderedDict({"a": pl.Int64, "b": pl.Float64, "c": pl.Int32}) + assert result.collect_schema() == expected_schema + assert result.collect().schema == expected_schema + + +@pytest.mark.write_disk +def test_read_parquet_invalid_hive_schema(dataset_path: Path) -> None: + with pytest.raises( + SchemaFieldNotFoundError, + match='path contains column not present in the given Hive schema: "c"', + ): + pl.read_parquet( + dataset_path / "**/*.parquet", + hive_partitioning=True, + hive_schema={"nonexistent": pl.Int32}, + ) + + +def test_read_parquet_hive_schema_with_pyarrow() -> None: + with pytest.raises( + TypeError, + match="cannot use `hive_partitions` with `use_pyarrow=True`", + ): + pl.read_parquet("test.parquet", hive_schema={"c": pl.Int32}, use_pyarrow=True) + + +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ipc, pl.DataFrame.write_ipc), + ], +) +@pytest.mark.parametrize( + "glob", + [True, False], +) +def test_hive_partition_directory_scan( + tmp_path: Path, + scan_func: Callable[..., pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], + glob: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + + dfs = [ + pl.DataFrame({'x': 5 * [1], 'a': 1, 'b': 1}), + pl.DataFrame({'x': 5 * [2], 'a': 1, 'b': 2}), + pl.DataFrame({'x': 5 * [3], 'a': 22, 'b': 1}), + pl.DataFrame({'x': 5 * [4], 'a': 22, 'b': 2}), + ] # fmt: skip + + for df in dfs: + a = df.item(0, "a") + b = df.item(0, "b") + path = tmp_path / f"a={a}/b={b}/data.bin" + path.parent.mkdir(exist_ok=True, parents=True) + write_func(df.drop("a", "b"), path) + + df = pl.concat(dfs) + hive_schema = df.lazy().select("a", "b").collect_schema() + + scan = scan_func + + if scan_func is pl.scan_parquet: + scan = partial(scan, glob=glob) + + scan_with_hive_schema = partial(scan_func, hive_schema=hive_schema) + + out = scan_with_hive_schema( + tmp_path, + hive_partitioning=True, + ).collect() + assert_frame_equal(out, df) + + out = scan(tmp_path, hive_partitioning=False).collect() + assert_frame_equal(out, df.drop("a", "b")) + + out = scan_with_hive_schema( + tmp_path / "a=1", + hive_partitioning=True, + ).collect() + assert_frame_equal(out, df.filter(a=1).drop("a")) + + out = scan(tmp_path / "a=1", hive_partitioning=False).collect() + assert_frame_equal(out, df.filter(a=1).drop("a", "b")) + + path = tmp_path / "a=1/b=1/data.bin" + + out = scan_with_hive_schema(path, hive_partitioning=True).collect() + assert_frame_equal(out, dfs[0]) + + out = scan(path, hive_partitioning=False).collect() + assert_frame_equal(out, dfs[0].drop("a", "b")) + + # Test default behavior with `hive_partitioning=None`, which should only + # enable hive partitioning when a single directory is passed: + out = scan_with_hive_schema(tmp_path).collect() + assert_frame_equal(out, df) + + # Otherwise, hive partitioning is not enabled automatically: + out = scan(tmp_path / "a=1/b=1/data.bin").collect() + assert out.columns == ["x"] + + out = scan([tmp_path / "a=1/", tmp_path / "a=22/"]).collect() + assert out.columns == ["x"] + + out = scan([tmp_path / "a=1/", tmp_path / "a=22/b=1/data.bin"]).collect() + assert out.columns == ["x"] + + if glob: + out = scan(tmp_path / "a=1/**/*.bin").collect() + assert out.columns == ["x"] + + # Test `hive_partitioning=True` + out = scan_with_hive_schema(tmp_path, hive_partitioning=True).collect() + assert_frame_equal(out, df) + + # Accept multiple directories from the same level + out = scan_with_hive_schema( + [tmp_path / "a=1", tmp_path / "a=22"], hive_partitioning=True + ).collect() + assert_frame_equal(out, df.drop("a")) + + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="attempted to read from different directory levels with hive partitioning enabled:", + ): + scan_with_hive_schema( + [tmp_path / "a=1", tmp_path / "a=22/b=1"], hive_partitioning=True + ).collect() + + if glob: + out = scan_with_hive_schema( + tmp_path / "**/*.bin", hive_partitioning=True + ).collect() + assert_frame_equal(out, df) + + # Parse hive from full path for glob patterns + out = scan_with_hive_schema( + [tmp_path / "a=1/**/*.bin", tmp_path / "a=22/**/*.bin"], + hive_partitioning=True, + ).collect() + assert_frame_equal(out, df) + + # Parse hive from full path for files + out = scan_with_hive_schema( + tmp_path / "a=1/b=1/data.bin", hive_partitioning=True + ).collect() + assert_frame_equal(out, df.filter(a=1, b=1)) + + out = scan_with_hive_schema( + [tmp_path / "a=1/b=1/data.bin", tmp_path / "a=22/b=1/data.bin"], + hive_partitioning=True, + ).collect() + assert_frame_equal( + out, + df.filter( + ((pl.col("a") == 1) & (pl.col("b") == 1)) + | ((pl.col("a") == 22) & (pl.col("b") == 1)) + ), + ) + + # Test `hive_partitioning=False` + out = scan(tmp_path, hive_partitioning=False).collect() + assert_frame_equal(out, df.drop("a", "b")) + + if glob: + out = scan(tmp_path / "**/*.bin", hive_partitioning=False).collect() + assert_frame_equal(out, df.drop("a", "b")) + + out = scan(tmp_path / "a=1/b=1/data.bin", hive_partitioning=False).collect() + assert_frame_equal(out, df.filter(a=1, b=1).drop("a", "b")) + + +def test_hive_partition_schema_inference(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + dfs = [ + pl.DataFrame({"x": 1}), + pl.DataFrame({"x": 2}), + pl.DataFrame({"x": 3}), + ] + + paths = [ + tmp_path / "a=1/data.bin", + tmp_path / "a=1.5/data.bin", + tmp_path / "a=polars/data.bin", + ] + + expected = [ + pl.Series("a", [1], dtype=pl.Int64), + pl.Series("a", [1.0, 1.5], dtype=pl.Float64), + pl.Series("a", ["1", "1.5", "polars"], dtype=pl.String), + ] + + for i in range(3): + paths[i].parent.mkdir(exist_ok=True, parents=True) + dfs[i].write_parquet(paths[i]) + out = pl.scan_parquet(tmp_path).collect() + + assert_series_equal(out["a"], expected[i]) + + +@pytest.mark.write_disk +def test_hive_partition_force_async_17155(tmp_path: Path, monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + monkeypatch.setenv("POLARS_PREFETCH_SIZE", "1") + + dfs = [ + pl.DataFrame({"x": 1}), + pl.DataFrame({"x": 2}), + pl.DataFrame({"x": 3}), + ] + + paths = [ + tmp_path / "a=1/b=1/data.bin", + tmp_path / "a=2/b=2/data.bin", + tmp_path / "a=3/b=3/data.bin", + ] + + for i in range(3): + paths[i].parent.mkdir(exist_ok=True, parents=True) + dfs[i].write_parquet(paths[i]) + + lf = pl.scan_parquet(tmp_path) + + assert_frame_equal( + lf.collect(), pl.DataFrame({k: [1, 2, 3] for k in ["x", "a", "b"]}) + ) + + +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (partial(pl.scan_parquet, parallel="row_groups"), pl.DataFrame.write_parquet), + (partial(pl.scan_parquet, parallel="columns"), pl.DataFrame.write_parquet), + (partial(pl.scan_parquet, parallel="prefiltered"), pl.DataFrame.write_parquet), + ( + lambda *a, **kw: pl.scan_parquet(*a, parallel="prefiltered", **kw).filter( + pl.col("b") == pl.col("b") + ), + pl.DataFrame.write_parquet, + ), + (pl.scan_ipc, pl.DataFrame.write_ipc), + ], +) +@pytest.mark.write_disk +@pytest.mark.slow +@pytest.mark.parametrize("projection_pushdown", [True, False]) +def test_hive_partition_columns_contained_in_file( + tmp_path: Path, + scan_func: Callable[[Any], pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], + projection_pushdown: bool, +) -> None: + path = tmp_path / "a=1/b=2/data.bin" + path.parent.mkdir(exist_ok=True, parents=True) + df = pl.DataFrame( + {"x": 1, "a": 1, "b": 2, "y": 1}, + schema={"x": pl.Int32, "a": pl.Int8, "b": pl.Int16, "y": pl.Int32}, + ) + write_func(df, path) + + def assert_with_projections( + lf: pl.LazyFrame, df: pl.DataFrame, *, row_index: str | None = None + ) -> None: + row_index: list[str] = [row_index] if row_index is not None else [] # type: ignore[no-redef] + + from itertools import permutations + + cols = ["a", "b", "x", "y", *row_index] # type: ignore[misc] + + for projection in ( + x for i in range(len(cols)) for x in permutations(cols[: 1 + i]) + ): + assert_frame_equal( + lf.select(projection).collect(projection_pushdown=projection_pushdown), + df.select(projection), + ) + + lf = scan_func(path, hive_partitioning=True) # type: ignore[call-arg] + rhs = df + assert_frame_equal(lf.collect(projection_pushdown=projection_pushdown), rhs) + assert_with_projections(lf, rhs) + + lf = scan_func( # type: ignore[call-arg] + path, + hive_schema={"a": pl.String, "b": pl.String}, + hive_partitioning=True, + ) + rhs = df.with_columns(pl.col("a", "b").cast(pl.String)) + assert_frame_equal( + lf.collect(projection_pushdown=projection_pushdown), + rhs, + ) + assert_with_projections(lf, rhs) + + # partial cols in file + partial_path = tmp_path / "a=1/b=2/partial_data.bin" + df = pl.DataFrame( + {"x": 1, "b": 2, "y": 1}, + schema={"x": pl.Int32, "b": pl.Int16, "y": pl.Int32}, + ) + write_func(df, partial_path) + + rhs = rhs.select( + pl.col("x").cast(pl.Int32), + pl.col("b").cast(pl.Int16), + pl.col("y").cast(pl.Int32), + pl.col("a").cast(pl.Int64), + ) + + lf = scan_func(partial_path, hive_partitioning=True) # type: ignore[call-arg] + assert_frame_equal(lf.collect(projection_pushdown=projection_pushdown), rhs) + assert_with_projections(lf, rhs) + + assert_frame_equal( + lf.with_row_index().collect(projection_pushdown=projection_pushdown), + rhs.with_row_index(), + ) + assert_with_projections( + lf.with_row_index(), rhs.with_row_index(), row_index="index" + ) + + assert_frame_equal( + lf.with_row_index() + .select(pl.exclude("index"), "index") + .collect(projection_pushdown=projection_pushdown), + rhs.with_row_index().select(pl.exclude("index"), "index"), + ) + assert_with_projections( + lf.with_row_index().select(pl.exclude("index"), "index"), + rhs.with_row_index().select(pl.exclude("index"), "index"), + row_index="index", + ) + + lf = scan_func( # type: ignore[call-arg] + partial_path, + hive_schema={"a": pl.String, "b": pl.String}, + hive_partitioning=True, + ) + rhs = rhs.select( + pl.col("x").cast(pl.Int32), + pl.col("b").cast(pl.String), + pl.col("y").cast(pl.Int32), + pl.col("a").cast(pl.String), + ) + assert_frame_equal( + lf.collect(projection_pushdown=projection_pushdown), + rhs, + ) + assert_with_projections(lf, rhs) + + +@pytest.mark.write_disk +def test_hive_partition_dates(tmp_path: Path) -> None: + df = pl.DataFrame( + { + "date1": [ + datetime(2024, 1, 1), + datetime(2024, 2, 1), + datetime(2024, 3, 1), + None, + ], + "date2": [ + datetime(2023, 1, 1), + datetime(2023, 2, 1), + None, + datetime(2023, 3, 1), + ], + "x": [1, 2, 3, 4], + }, + schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32}, + ) + + root = tmp_path / "pyarrow" + pq.write_to_dataset( + df.to_arrow(), + root_path=root, + partition_cols=["date1", "date2"], + ) + + lf = pl.scan_parquet( + root, hive_schema=df.clear().select("date1", "date2").collect_schema() + ) + assert_frame_equal(lf.collect(), df.select("x", "date1", "date2")) + + lf = pl.scan_parquet(root) + assert_frame_equal(lf.collect(), df.select("x", "date1", "date2")) + + lf = pl.scan_parquet(root, try_parse_hive_dates=False) + assert_frame_equal( + lf.collect(), + df.select("x", "date1", "date2").with_columns( + pl.col("date1", "date2").cast(pl.String) + ), + ) + + for perc_escape in [True, False] if sys.platform != "win32" else [True]: + root = tmp_path / f"includes_hive_cols_in_file_{perc_escape}" + for (date1, date2), part_df in df.group_by( + pl.col("date1").cast(pl.String).fill_null("__HIVE_DEFAULT_PARTITION__"), + pl.col("date2").cast(pl.String).fill_null("__HIVE_DEFAULT_PARTITION__"), + ): + if perc_escape: + date2 = urllib.parse.quote(date2) # type: ignore[call-overload] + + path = root / f"date1={date1}/date2={date2}/data.bin" + path.parent.mkdir(exist_ok=True, parents=True) + part_df.write_parquet(path) + + # The schema for the hive columns is included in the file, so it should + # just work + lf = pl.scan_parquet(root) + assert_frame_equal(lf.collect(), df) + + lf = pl.scan_parquet(root, try_parse_hive_dates=False) + assert_frame_equal( + lf.collect(), + df.with_columns(pl.col("date1", "date2").cast(pl.String)), + ) + + +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ipc, pl.DataFrame.write_ipc), + ], +) +@pytest.mark.write_disk +def test_projection_only_hive_parts_gives_correct_number_of_rows( + tmp_path: Path, + scan_func: Callable[[Any], pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], +) -> None: + # Check the number of rows projected when projecting only hive parts, which + # should be the same as the number of rows in the file. + path = tmp_path / "a=3/data.bin" + path.parent.mkdir(exist_ok=True, parents=True) + + write_func(pl.DataFrame({"x": [1, 1, 1]}), path) + + assert_frame_equal( + scan_func(path, hive_partitioning=True).select("a").collect(), # type: ignore[call-arg] + pl.DataFrame({"a": [3, 3, 3]}), + ) + + +@pytest.mark.parametrize( + "df", + [ + pl.select( + pl.Series("a", [1, 2, 3, 4], dtype=pl.Int8), + pl.Series("b", [1, 2, 3, 4], dtype=pl.Int8), + pl.Series("x", [1, 2, 3, 4]), + ), + pl.select( + pl.Series( + "a", + [1.2981275, 2.385974035, 3.1231892749185718397510, 4.129387128949156], + dtype=pl.Float64, + ), + pl.Series("b", ["a", "b", " / c = : ", "d"]), + pl.Series("x", [1, 2, 3, 4]), + ), + ], +) +@pytest.mark.write_disk +def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None: + root = tmp_path + df.write_parquet(root, partition_by=["a", "b"]) + + lf = pl.scan_parquet(root) + assert_frame_equal(lf.collect(), df) + + lf = pl.scan_parquet(root, hive_schema={"a": pl.String, "b": pl.String}) + assert_frame_equal(lf.collect(), df.with_columns(pl.col("a", "b").cast(pl.String))) + + +@pytest.mark.slow +@pytest.mark.write_disk +def test_hive_write_multiple_files(tmp_path: Path) -> None: + chunk_size = 262_144 + n_rows = 100_000 + df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows)) + + n_files = int(df.estimated_size() / chunk_size) + + assert n_files > 1, "increase df size or decrease file size" + + root = tmp_path + df.write_parquet(root, partition_by="a", partition_chunk_size_bytes=chunk_size) + + assert sum(1 for _ in (root / "a=0").iterdir()) == n_files + assert_frame_equal(pl.scan_parquet(root).collect(), df) + + +@pytest.mark.write_disk +def test_hive_write_dates(tmp_path: Path) -> None: + df = pl.DataFrame( + { + "date1": [ + datetime(2024, 1, 1), + datetime(2024, 2, 1), + datetime(2024, 3, 1), + None, + ], + "date2": [ + datetime(2023, 1, 1), + datetime(2023, 2, 1), + None, + datetime(2023, 3, 1, 1, 1, 1, 1), + ], + "x": [1, 2, 3, 4], + }, + schema={"date1": pl.Date, "date2": pl.Datetime, "x": pl.Int32}, + ) + + root = tmp_path + df.write_parquet(root, partition_by=["date1", "date2"]) + + lf = pl.scan_parquet(root) + assert_frame_equal(lf.collect(), df) + + lf = pl.scan_parquet(root, try_parse_hive_dates=False) + assert_frame_equal( + lf.collect(), + df.with_columns(pl.col("date1", "date2").cast(pl.String)), + ) + + +@pytest.mark.write_disk +@pytest.mark.may_fail_auto_streaming +def test_hive_predicate_dates_14712( + tmp_path: Path, monkeypatch: Any, capfd: Any +) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + pl.DataFrame({"a": [datetime(2024, 1, 1)]}).write_parquet( + tmp_path, partition_by="a" + ) + pl.scan_parquet(tmp_path).filter(pl.col("a") != datetime(2024, 1, 1)).collect() + assert "allows skipping 1 / 1" in capfd.readouterr().err + + +@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows paths") +@pytest.mark.write_disk +def test_hive_windows_splits_on_forward_slashes(tmp_path: Path) -> None: + # Note: This needs to be an absolute path. + tmp_path = tmp_path.resolve() + path = f"{tmp_path}/a=1/b=1/c=1/d=1/e=1" + Path(path).mkdir(exist_ok=True, parents=True) + + df = pl.DataFrame({"x": "x"}) + df.write_parquet(f"{path}/data.parquet") + + expect = pl.DataFrame( + [ + s.new_from_index(0, 5) + for s in pl.DataFrame( + { + "x": "x", + "a": 1, + "b": 1, + "c": 1, + "d": 1, + "e": 1, + } + ) + ] + ) + + assert_frame_equal( + pl.scan_parquet( + [ + f"{tmp_path}/a=1/b=1/c=1/d=1/e=1/data.parquet", + f"{tmp_path}\\a=1\\b=1\\c=1\\d=1\\e=1\\data.parquet", + f"{tmp_path}\\a=1/b=1/c=1/d=1/**/*", + f"{tmp_path}/a=1/b=1\\c=1/d=1/**/*", + f"{tmp_path}/a=1/b=1/c=1/d=1\\e=1/*", + ], + hive_partitioning=True, + ).collect(), + expect, + ) + + +@pytest.mark.write_disk +def test_passing_hive_schema_with_hive_partitioning_disabled_raises( + tmp_path: Path, +) -> None: + with pytest.raises( + ComputeError, + match="a hive schema was given but hive_partitioning was disabled", + ): + pl.scan_parquet( + tmp_path, + schema={"x": pl.Int64}, + hive_schema={"h": pl.String}, + hive_partitioning=False, + ).collect() + + +@pytest.mark.write_disk +def test_hive_auto_enables_when_unspecified_and_hive_schema_passed( + tmp_path: Path, +) -> None: + tmp_path.mkdir(exist_ok=True) + (tmp_path / "a=1").mkdir(exist_ok=True) + + pl.DataFrame({"x": 1}).write_parquet(tmp_path / "a=1/1") + + for path in [tmp_path / "a=1/1", tmp_path / "**/*"]: + lf = pl.scan_parquet(path, hive_schema={"a": pl.UInt8}) + + assert_frame_equal( + lf.collect(), + pl.select( + pl.Series("x", [1]), + pl.Series("a", [1], dtype=pl.UInt8), + ), + ) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("force_single_thread", [True, False]) +def test_hive_parquet_prefiltered_20894_21327( + tmp_path: Path, force_single_thread: bool +) -> None: + n_threads = 1 if force_single_thread else pl.thread_pool_size() + + file_path = tmp_path / "date=2025-01-01/00000000.parquet" + file_path.parent.mkdir(exist_ok=True, parents=True) + + data = pl.DataFrame( + { + "date": [date(2025, 1, 1), date(2025, 1, 1)], + "value": ["1", "2"], + } + ) + + data.write_parquet(file_path) + + import base64 + import subprocess + + # For security, and for Windows backslashes. + scan_path_b64 = base64.b64encode(str(file_path.absolute()).encode()).decode() + + # This is, the easiest way to control the threadpool size so that it is stable. + out = subprocess.check_output( + [ + sys.executable, + "-c", + f"""\ +import os +os.environ["POLARS_MAX_THREADS"] = "{n_threads}" + +import polars as pl +import datetime +import base64 + +from polars.testing import assert_frame_equal + +assert pl.thread_pool_size() == {n_threads} + +tmp_path = base64.b64decode("{scan_path_b64}").decode() +df = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(pl.col("value") == "1").collect() +# We need the str() to trigger panic on invalid state +str(df) + +assert_frame_equal(df, pl.DataFrame( + [ + pl.Series('date', [datetime.date(2025, 1, 1)], dtype=pl.Date), + pl.Series('value', ['1'], dtype=pl.String), + ] +)) + +print("OK", end="") +""", + ], + ) + + assert out == b"OK" diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py new file mode 100644 index 000000000000..97bcf615e0a9 --- /dev/null +++ b/py-polars/tests/unit/io/test_iceberg.py @@ -0,0 +1,226 @@ +# mypy: disable-error-code="attr-defined" +from __future__ import annotations + +import contextlib +import os +from datetime import datetime +from pathlib import Path + +import pytest + +import polars as pl +from polars.io.iceberg import _convert_predicate, _to_ast +from polars.testing import assert_frame_equal + + +@pytest.fixture +def iceberg_path(io_files_path: Path) -> str: + # Iceberg requires absolute paths, so we'll symlink + # the test table into /tmp/iceberg/t1/ + Path("/tmp/iceberg").mkdir(parents=True, exist_ok=True) + current_path = Path(__file__).parent.resolve() + + with contextlib.suppress(FileExistsError): + os.symlink(f"{current_path}/files/iceberg-table", "/tmp/iceberg/t1") + + iceberg_path = io_files_path / "iceberg-table" / "metadata" / "v2.metadata.json" + return f"file://{iceberg_path.resolve()}" + + +@pytest.mark.slow +@pytest.mark.write_disk +@pytest.mark.filterwarnings( + "ignore:No preferred file implementation for scheme*:UserWarning" +) +@pytest.mark.ci_only +class TestIcebergScanIO: + """Test coverage for `iceberg` scan ops.""" + + def test_scan_iceberg_plain(self, iceberg_path: str) -> None: + df = pl.scan_iceberg(iceberg_path) + assert len(df.collect()) == 3 + assert df.collect_schema() == { + "id": pl.Int32, + "str": pl.String, + "ts": pl.Datetime(time_unit="us", time_zone=None), + } + + def test_scan_iceberg_snapshot_id(self, iceberg_path: str) -> None: + df = pl.scan_iceberg(iceberg_path, snapshot_id=7051579356916758811) + assert len(df.collect()) == 3 + assert df.collect_schema() == { + "id": pl.Int32, + "str": pl.String, + "ts": pl.Datetime(time_unit="us", time_zone=None), + } + + def test_scan_iceberg_snapshot_id_not_found(self, iceberg_path: str) -> None: + with pytest.raises(ValueError, match="Snapshot ID not found"): + pl.scan_iceberg(iceberg_path, snapshot_id=1234567890) + + def test_scan_iceberg_filter_on_partition(self, iceberg_path: str) -> None: + ts1 = datetime(2023, 3, 1, 18, 15) + ts2 = datetime(2023, 3, 1, 19, 25) + ts3 = datetime(2023, 3, 2, 22, 0) + + lf = pl.scan_iceberg(iceberg_path) + + res = lf.filter(pl.col("ts") >= ts2) + assert len(res.collect()) == 2 + + res = lf.filter(pl.col("ts") > ts2).select(pl.col("id")) + assert res.collect().rows() == [(3,)] + + res = lf.filter(pl.col("ts") <= ts2).select("id", "ts") + assert res.collect().rows(named=True) == [ + {"id": 1, "ts": ts1}, + {"id": 2, "ts": ts2}, + ] + + res = lf.filter(pl.col("ts") > ts3) + assert len(res.collect()) == 0 + + for constraint in ( + (pl.col("ts") == ts1) | (pl.col("ts") == ts3), + pl.col("ts").is_in([ts1, ts3]), + ): + res = lf.filter(constraint).select("id") + assert res.collect().rows() == [(1,), (3,)] + + def test_scan_iceberg_filter_on_column(self, iceberg_path: str) -> None: + lf = pl.scan_iceberg(iceberg_path) + res = lf.filter(pl.col("id") < 2) + assert res.collect().rows() == [(1, "1", datetime(2023, 3, 1, 18, 15))] + + res = lf.filter(pl.col("id") == 2) + assert res.collect().rows() == [(2, "2", datetime(2023, 3, 1, 19, 25))] + + res = lf.filter(pl.col("id").is_in([1, 3])) + assert res.collect().rows() == [ + (1, "1", datetime(2023, 3, 1, 18, 15)), + (3, "3", datetime(2023, 3, 2, 22, 0)), + ] + + +@pytest.mark.ci_only +class TestIcebergExpressions: + """Test coverage for `iceberg` expressions comprehension.""" + + def test_is_null_expression(self) -> None: + from pyiceberg.expressions import IsNull + + expr = _to_ast("(pa.compute.field('id')).is_null()") + assert _convert_predicate(expr) == IsNull("id") + + def test_is_not_null_expression(self) -> None: + from pyiceberg.expressions import IsNull, Not + + expr = _to_ast("~(pa.compute.field('id')).is_null()") + assert _convert_predicate(expr) == Not(IsNull("id")) + + def test_isin_expression(self) -> None: + from pyiceberg.expressions import In, literal + + expr = _to_ast("(pa.compute.field('id')).isin([1,2,3])") + assert _convert_predicate(expr) == In( + "id", {literal(1), literal(2), literal(3)} + ) + + def test_parse_combined_expression(self) -> None: + from pyiceberg.expressions import ( + And, + EqualTo, + GreaterThan, + In, + Or, + Reference, + literal, + ) + + expr = _to_ast( + "(((pa.compute.field('str') == '2') & (pa.compute.field('id') > 10)) | (pa.compute.field('id')).isin([1,2,3]))" + ) + assert _convert_predicate(expr) == Or( + left=And( + left=EqualTo(term=Reference(name="str"), literal=literal("2")), + right=GreaterThan(term="id", literal=literal(10)), + ), + right=In("id", {literal(1), literal(2), literal(3)}), + ) + + def test_parse_gt(self) -> None: + from pyiceberg.expressions import GreaterThan + + expr = _to_ast("(pa.compute.field('ts') > '2023-08-08')") + assert _convert_predicate(expr) == GreaterThan("ts", "2023-08-08") + + def test_parse_gteq(self) -> None: + from pyiceberg.expressions import GreaterThanOrEqual + + expr = _to_ast("(pa.compute.field('ts') >= '2023-08-08')") + assert _convert_predicate(expr) == GreaterThanOrEqual("ts", "2023-08-08") + + def test_parse_eq(self) -> None: + from pyiceberg.expressions import EqualTo + + expr = _to_ast("(pa.compute.field('ts') == '2023-08-08')") + assert _convert_predicate(expr) == EqualTo("ts", "2023-08-08") + + def test_parse_lt(self) -> None: + from pyiceberg.expressions import LessThan + + expr = _to_ast("(pa.compute.field('ts') < '2023-08-08')") + assert _convert_predicate(expr) == LessThan("ts", "2023-08-08") + + def test_parse_lteq(self) -> None: + from pyiceberg.expressions import LessThanOrEqual + + expr = _to_ast("(pa.compute.field('ts') <= '2023-08-08')") + assert _convert_predicate(expr) == LessThanOrEqual("ts", "2023-08-08") + + def test_compare_boolean(self) -> None: + from pyiceberg.expressions import EqualTo + + expr = _to_ast("(pa.compute.field('ts') == pa.compute.scalar(True))") + assert _convert_predicate(expr) == EqualTo("ts", True) + + expr = _to_ast("(pa.compute.field('ts') == pa.compute.scalar(False))") + assert _convert_predicate(expr) == EqualTo("ts", False) + + +@pytest.mark.slow +@pytest.mark.write_disk +@pytest.mark.filterwarnings("ignore:Delete operation did not match any records") +@pytest.mark.filterwarnings( + "ignore:Iceberg does not have a dictionary type. will be inferred as large_string on read." +) +def test_write_iceberg(df: pl.DataFrame, tmp_path: Path) -> None: + from pyiceberg.catalog.sql import SqlCatalog + + # time64[ns] type is currently not supported in pyiceberg. + # https://github.com/apache/iceberg-python/issues/1169 + df = df.drop("time", "cat", "enum") + + # in-memory catalog + catalog = SqlCatalog( + "default", uri="sqlite:///:memory:", warehouse=f"file://{tmp_path}" + ) + catalog.create_namespace("foo") + table = catalog.create_table( + "foo.bar", + schema=df.to_arrow().schema, + ) + + df.write_iceberg(table, mode="overwrite") + actual = pl.scan_iceberg(table).collect() + print(df, actual) + + assert_frame_equal(df, actual) + + # append on top of already written data, expecting twice the data + df.write_iceberg(table, mode="append") + # double the `df` by vertically stacking the dataframe on top of itself + expected = df.vstack(df) + actual = pl.scan_iceberg(table).collect() + print(expected, actual) + assert_frame_equal(expected, actual, check_dtypes=False) diff --git a/py-polars/tests/unit/io/test_io_plugin.py b/py-polars/tests/unit/io/test_io_plugin.py new file mode 100644 index 000000000000..738f236ccd6c --- /dev/null +++ b/py-polars/tests/unit/io/test_io_plugin.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import datetime +import io +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl +from polars.io.plugins import register_io_source +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from collections.abc import Iterator + + +def test_io_plugin_predicate_no_serialization_21130() -> None: + def custom_io() -> pl.LazyFrame: + def source_generator( + with_columns: list[str] | None, + predicate: pl.Expr | None, + n_rows: int | None, + batch_size: int | None, + ) -> Iterator[pl.DataFrame]: + df = pl.DataFrame( + {"json_val": ['{"a":"1"}', None, '{"a":2}', '{"a":2.1}', '{"a":true}']} + ) + if predicate is not None: + df = df.filter(predicate) + if batch_size and df.height > batch_size: + yield from df.iter_slices(n_rows=batch_size) + else: + yield df + + return register_io_source( + io_source=source_generator, schema={"json_val": pl.String} + ) + + lf = custom_io() + assert lf.filter( + pl.col("json_val").str.json_path_match("$.a").is_in(["1"]) + ).collect().to_dict(as_series=False) == {"json_val": ['{"a":"1"}']} + + +def test_defer() -> None: + lf = pl.defer( + lambda: pl.DataFrame({"a": np.ones(3)}), + schema={"a": pl.Boolean}, + validate_schema=False, + ) + assert lf.collect().to_dict(as_series=False) == {"a": [1.0, 1.0, 1.0]} + lf = pl.defer( + lambda: pl.DataFrame({"a": np.ones(3)}), + schema={"a": pl.Boolean}, + validate_schema=True, + ) + with pytest.raises(pl.exceptions.SchemaError): + lf.collect() + + +def test_empty_iterator_io_plugin() -> None: + def _io_source( + with_columns: list[str] | None, + predicate: pl.Expr | None, + n_rows: int | None, + batch_size: int | None, + ) -> Iterator[pl.DataFrame]: + yield from [] + + schema = pl.Schema([("a", pl.Int64)]) + df = register_io_source(_io_source, schema=schema) + assert df.collect().schema == schema + + +def test_scan_lines() -> None: + def scan_lines(f: io.BytesIO) -> pl.LazyFrame: + schema = pl.Schema({"lines": pl.String()}) + + def generator( + with_columns: list[str] | None, + predicate: pl.Expr | None, + n_rows: int | None, + batch_size: int | None, + ) -> Iterator[pl.DataFrame]: + x = f + if batch_size is None: + batch_size = 100_000 + + batch_lines: list[str] = [] + while n_rows != 0: + batch_lines.clear() + remaining_rows = batch_size + if n_rows is not None: + remaining_rows = min(remaining_rows, n_rows) + n_rows -= remaining_rows + + while remaining_rows != 0 and (line := x.readline().rstrip()): + if isinstance(line, str): + batch_lines += [batch_lines] + else: + batch_lines += [line.decode()] + remaining_rows -= 1 + + df = pl.Series("lines", batch_lines, pl.String()).to_frame() + + if with_columns is not None: + df = df.select(with_columns) + if predicate is not None: + df = df.filter(predicate) + + yield df + + if remaining_rows != 0: + break + + return register_io_source(io_source=generator, schema=schema) + + text = """ +Hello +This is some text +It is spread over multiple lines +This allows it to read into multiple rows. + """.strip() + f = io.BytesIO(bytes(text, encoding="utf-8")) + + assert_series_equal( + scan_lines(f).collect().to_series(), + pl.Series("lines", text.splitlines(), pl.String()), + ) + + +def test_datetime_io_predicate_pushdown_21790() -> None: + recorded: dict[str, pl.Expr | None] = {"predicate": None} + df = pl.DataFrame( + { + "timestamp": [ + datetime.datetime(2024, 1, 1, 0), + datetime.datetime(2024, 1, 3, 0), + ] + } + ) + + def _source( + with_columns: list[str] | None, + predicate: pl.Expr | None, + n_rows: int | None, + batch_size: int | None, + ) -> Iterator[pl.DataFrame]: + # capture the predicate passed in + recorded["predicate"] = predicate + inner_df = df.clone() + if with_columns is not None: + inner_df = inner_df.select(with_columns) + if predicate is not None: + inner_df = inner_df.filter(predicate) + + yield inner_df + + schema = {"timestamp": pl.Datetime(time_unit="ns")} + lf = register_io_source(io_source=_source, schema=schema) + + cutoff = datetime.datetime(2024, 1, 4) + expr = pl.col("timestamp") < cutoff + filtered_df = lf.filter(expr).collect() + + pushed_predicate = recorded["predicate"] + assert pushed_predicate is not None + assert_series_equal(filtered_df.to_series(), df.filter(expr).to_series()) + + # check the expression directly + dt_val, column_cast = pushed_predicate.meta.pop() + # Extract the datetime value from the expression + assert pl.DataFrame({}).select(dt_val).item() == cutoff + + column = column_cast.meta.pop()[0] + assert column.meta == pl.col("timestamp") diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py new file mode 100644 index 000000000000..3e3c8a06633f --- /dev/null +++ b/py-polars/tests/unit/io/test_ipc.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import io +from decimal import Decimal +from typing import TYPE_CHECKING, Any, no_type_check + +import pandas as pd +import pytest + +import polars as pl +from polars.interchange.protocol import CompatLevel +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import IpcCompression + from tests.unit.conftest import MemoryUsage + +COMPRESSIONS = ["uncompressed", "lz4", "zstd"] + + +def read_ipc(is_stream: bool, *args: Any, **kwargs: Any) -> pl.DataFrame: + if is_stream: + return pl.read_ipc_stream(*args, **kwargs) + else: + return pl.read_ipc(*args, **kwargs) + + +def write_ipc(df: pl.DataFrame, is_stream: bool, *args: Any, **kwargs: Any) -> Any: + if is_stream: + return df.write_ipc_stream(*args, **kwargs) + else: + return df.write_ipc(*args, **kwargs) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +@pytest.mark.parametrize("stream", [True, False]) +def test_from_to_buffer( + df: pl.DataFrame, compression: IpcCompression, stream: bool +) -> None: + # use an ad-hoc buffer (file=None) + buf1 = write_ipc(df, stream, None, compression=compression) + buf1.seek(0) + read_df = read_ipc(stream, buf1, use_pyarrow=False) + assert_frame_equal(df, read_df, categorical_as_str=True) + + # explicitly supply an existing buffer + buf2 = io.BytesIO() + buf2.seek(0) + write_ipc(df, stream, buf2, compression=compression) + buf2.seek(0) + read_df = read_ipc(stream, buf2, use_pyarrow=False) + assert_frame_equal(df, read_df, categorical_as_str=True) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +@pytest.mark.parametrize("path_as_string", [True, False]) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.write_disk +def test_from_to_file( + df: pl.DataFrame, + compression: IpcCompression, + path_as_string: bool, + tmp_path: Path, + stream: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "small.ipc" + if path_as_string: + file_path = str(file_path) # type: ignore[assignment] + write_ipc(df, stream, file_path, compression=compression) + df_read = read_ipc(stream, file_path, use_pyarrow=False) + + assert_frame_equal(df, df_read, categorical_as_str=True) + + +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.write_disk +def test_select_columns_from_file( + df: pl.DataFrame, tmp_path: Path, stream: bool +) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "small.ipc" + write_ipc(df, stream, file_path) + df_read = read_ipc(stream, file_path, columns=["bools"]) + + assert df_read.columns == ["bools"] + + +@pytest.mark.parametrize("stream", [True, False]) +def test_select_columns_from_buffer(stream: bool) -> None: + df = pl.DataFrame( + { + "a": [1], + "b": [2], + "c": [3], + }, + schema={"a": pl.Int64(), "b": pl.Int128(), "c": pl.UInt8()}, + ) + + f = io.BytesIO() + write_ipc(df, stream, f) + f.seek(0) + + actual = read_ipc(stream, f, columns=["b", "c", "a"], use_pyarrow=False) + + expected = pl.DataFrame( + { + "b": [2], + "c": [3], + "a": [1], + }, + schema={"b": pl.Int128(), "c": pl.UInt8(), "a": pl.Int64()}, + ) + assert_frame_equal(expected, actual) + + +@pytest.mark.parametrize("stream", [True, False]) +def test_select_columns_projection(stream: bool) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) + expected = pl.DataFrame({"b": [True, False, True], "c": ["a", "b", "c"]}) + + f = io.BytesIO() + write_ipc(df, stream, f) + f.seek(0) + + read_df = read_ipc(stream, f, columns=[1, 2], use_pyarrow=False) + assert_frame_equal(expected, read_df) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +@pytest.mark.parametrize("stream", [True, False]) +def test_compressed_simple(compression: IpcCompression, stream: bool) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) + + f = io.BytesIO() + write_ipc(df, stream, f, compression=compression) + f.seek(0) + + df_read = read_ipc(stream, f, use_pyarrow=False) + assert_frame_equal(df_read, df) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +def test_ipc_schema(compression: IpcCompression) -> None: + schema = { + "i64": pl.Int64(), + "i128": pl.Int128(), + "u8": pl.UInt8(), + "f32": pl.Float32(), + "f64": pl.Float64(), + "str": pl.String(), + "bool": pl.Boolean(), + } + df = pl.DataFrame( + { + "i64": [1, 2], + "i128": [1, 2], + "u8": [1, 2], + "f32": [1, 2], + "f64": [1, 2], + "str": ["a", None], + "bool": [True, False], + }, + schema=schema, + ) + + f = io.BytesIO() + df.write_ipc(f, compression=compression) + f.seek(0) + + assert pl.read_ipc_schema(f) == schema + + +@pytest.mark.write_disk +@pytest.mark.parametrize("compression", COMPRESSIONS) +@pytest.mark.parametrize("path_as_string", [True, False]) +def test_ipc_schema_from_file( + df_no_lists: pl.DataFrame, + compression: IpcCompression, + path_as_string: bool, + tmp_path: Path, +) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.ipc" + if path_as_string: + file_path = str(file_path) # type: ignore[assignment] + df_no_lists.write_ipc(file_path, compression=compression) + schema = pl.read_ipc_schema(file_path) + + expected = { + "bools": pl.Boolean(), + "bools_nulls": pl.Boolean(), + "int": pl.Int64(), + "int_nulls": pl.Int64(), + "floats": pl.Float64(), + "floats_nulls": pl.Float64(), + "strings": pl.String(), + "strings_nulls": pl.String(), + "date": pl.Date(), + "datetime": pl.Datetime(), + "time": pl.Time(), + "cat": pl.Categorical(), + "enum": pl.Enum( + [] + ), # at schema inference categories are not read an empty Enum is returned + } + assert schema == expected + + +@pytest.mark.parametrize("stream", [True, False]) +def test_ipc_column_order(stream: bool) -> None: + df = pl.DataFrame( + { + "cola": ["x", "y", "z"], + "colb": [1, 2, 3], + "colc": [4.5, 5.6, 6.7], + } + ) + f = io.BytesIO() + write_ipc(df, stream, f) + f.seek(0) + + columns = ["colc", "colb", "cola"] + # read file into polars; the specified column order is no longer respected + assert read_ipc(stream, f, columns=columns).columns == columns + + +@pytest.mark.write_disk +def test_glob_ipc(df: pl.DataFrame, tmp_path: Path) -> None: + file_path = tmp_path / "small.ipc" + df.write_ipc(file_path) + + file_path_glob = tmp_path / "small*.ipc" + + result_scan = pl.scan_ipc(file_path_glob).collect() + result_read = pl.read_ipc(file_path_glob, use_pyarrow=False) + + for result in [result_scan, result_read]: + assert_frame_equal(result, df, categorical_as_str=True) + + +def test_from_float16() -> None: + # Create a feather file with a 16-bit floating point column + pandas_df = pd.DataFrame({"column": [1.0]}, dtype="float16") + f = io.BytesIO() + pandas_df.to_feather(f) + f.seek(0) + assert pl.read_ipc(f, use_pyarrow=False).dtypes == [pl.Float32] + + +@pytest.mark.write_disk +def test_binview_ipc_mmap(tmp_path: Path) -> None: + df = pl.DataFrame({"foo": ["aa" * 10, "bb", None, "small", "big" * 20]}) + file_path = tmp_path / "dump.ipc" + df.write_ipc(file_path, compat_level=CompatLevel.newest()) + read = pl.read_ipc(file_path, memory_map=True) + assert_frame_equal(df, read) + + +def test_list_nested_enum() -> None: + dtype = pl.List(pl.Enum(["a", "b", "c"])) + df = pl.DataFrame(pl.Series("list_cat", [["a", "b", "c", None]], dtype=dtype)) + buffer = io.BytesIO() + df.write_ipc(buffer, compat_level=CompatLevel.newest()) + buffer.seek(0) + df = pl.read_ipc(buffer) + assert df.get_column("list_cat").dtype == dtype + + +def test_struct_nested_enum() -> None: + dtype = pl.Struct({"enum": pl.Enum(["a", "b", "c"])}) + df = pl.DataFrame( + pl.Series( + "struct_cat", [{"enum": "a"}, {"enum": "b"}, {"enum": None}], dtype=dtype + ) + ) + buffer = io.BytesIO() + df.write_ipc(buffer, compat_level=CompatLevel.newest()) + buffer.seek(0) + df = pl.read_ipc(buffer) + assert df.get_column("struct_cat").dtype == dtype + + +@pytest.mark.slow +def test_ipc_view_gc_14448() -> None: + f = io.BytesIO() + # This size was required to trigger the bug + df = pl.DataFrame( + pl.Series(["small"] * 10 + ["looooooong string......."] * 750).slice(20, 20) + ) + df.write_ipc(f, compat_level=CompatLevel.newest()) + f.seek(0) + assert_frame_equal(pl.read_ipc(f), df) + + +@pytest.mark.slow +@pytest.mark.write_disk +@pytest.mark.parametrize("stream", [True, False]) +def test_read_ipc_only_loads_selected_columns( + memory_usage_without_pyarrow: MemoryUsage, + tmp_path: Path, + stream: bool, +) -> None: + """Only requested columns are loaded by ``read_ipc()``/``read_ipc_stream()``.""" + tmp_path.mkdir(exist_ok=True) + + # Each column will be about 16MB of RAM. There's a fixed overhead tied to + # block size so smaller file sizes can be misleading in terms of memory + # usage. + series = pl.arange(0, 2_000_000, dtype=pl.Int64, eager=True) + + file_path = tmp_path / "multicolumn.ipc" + df = pl.DataFrame( + { + "a": series, + "b": series, + } + ) + write_ipc(df, stream, file_path) + del df, series + + memory_usage_without_pyarrow.reset_tracking() + + # Only load one column: + kwargs = {} + if not stream: + kwargs["memory_map"] = False + df = read_ipc(stream, str(file_path), columns=["b"], rechunk=False, **kwargs) + del df + # Only one column's worth of memory should be used; 2 columns would be + # 32_000_000 at least, but there's some overhead. + # assert 16_000_000 < memory_usage_without_pyarrow.get_peak() < 23_000_000 + + +@pytest.mark.write_disk +def test_ipc_decimal_15920( + tmp_path: Path, +) -> None: + tmp_path.mkdir(exist_ok=True) + + base_df = pl.Series( + "x", + [ + *[ + Decimal(x) + for x in [ + "10.1", "11.2", "12.3", "13.4", "14.5", "15.6", "16.7", "17.8", "18.9", "19.0", + "20.1", "21.2", "22.3", "23.4", "24.5", "25.6", "26.7", "27.8", "28.9", "29.0", + "30.1", "31.2", "32.3", "33.4", "34.5", "35.6", "36.7", "37.8", "38.9", "39.0" + ] + ], + *(50 * [None]) + ], + dtype=pl.Decimal(18, 2), + ).to_frame() # fmt: skip + + for df in [base_df, base_df.drop_nulls()]: + path = f"{tmp_path}/data" + df.write_ipc(path) + assert_frame_equal(pl.read_ipc(path), df) + + +def test_ipc_variadic_buffers_categorical_binview_18636() -> None: + df = pl.DataFrame( + { + "Test": pl.Series(["Value012"], dtype=pl.Categorical), + "Test2": pl.Series(["Value Two 20032"], dtype=pl.String), + } + ) + + b = io.BytesIO() + df.write_ipc(b) + b.seek(0) + assert_frame_equal(pl.read_ipc(b), df) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_ipc_chunked_roundtrip(size: int) -> None: + a = pl.Series("a", [{"x": 1}] * size, pl.Struct({"x": pl.Int8})).to_frame() + + c = pl.concat([a] * 2, how="vertical") + + f = io.BytesIO() + c.write_ipc(f) + + f.seek(0) + assert_frame_equal(c, pl.read_ipc(f)) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_ipc_roundtrip(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + f = io.BytesIO() + a.write_ipc(f) + + f.seek(0) + assert_frame_equal(a, pl.read_ipc(f)) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +def test_zfs_ipc_chunked_roundtrip(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + c = pl.concat([a] * 2, how="vertical") + + f = io.BytesIO() + c.write_ipc(f) + + f.seek(0) + assert_frame_equal(c, pl.read_ipc(f)) + + +@pytest.mark.parametrize("size", [0, 1, 2, 13]) +@pytest.mark.parametrize("value", [{}, {"x": 1}]) +@pytest.mark.write_disk +def test_memmap_ipc_chunked_structs( + size: int, value: dict[str, int], tmp_path: Path +) -> None: + a = pl.Series("a", [value] * size, pl.Struct).to_frame() + + c = pl.concat([a] * 2, how="vertical") + + f = tmp_path / "f.ipc" + c.write_ipc(f) + assert_frame_equal(c, pl.read_ipc(f)) + + +def test_categorical_lexical_sort_2732() -> None: + df = pl.DataFrame( + { + "a": ["foo", "bar", "baz"], + "b": [1, 3, 2], + }, + schema_overrides={"a": pl.Categorical("lexical")}, + ) + f = io.BytesIO() + df.write_ipc(f) + f.seek(0) + assert_frame_equal(df, pl.read_ipc(f)) + + +def test_enum_scan_21564() -> None: + s = pl.Series("a", ["A"], pl.Enum(["A"])) + + # DataFrame with a an enum field + f = io.BytesIO() + s.to_frame().write_ipc(f) + + f.seek(0) + assert_series_equal( + pl.scan_ipc(f).collect().to_series(), + s, + ) + + +@no_type_check +def test_roundtrip_empty_str_list_21163() -> None: + schema = { + "s": pl.Utf8, + "list": pl.List(pl.Utf8), + } + row1 = pl.DataFrame({"s": ["A"], "list": [[]]}, schema=schema) + row2 = pl.DataFrame({"s": ["B"], "list": [[]]}, schema=schema) + df = pl.concat([row1, row2]) + bytes = df.serialize() + deserialized = pl.DataFrame.deserialize(io.BytesIO(bytes)) + assert_frame_equal(df, deserialized) diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py new file mode 100644 index 000000000000..3d4efd0bcbd8 --- /dev/null +++ b/py-polars/tests/unit/io/test_json.py @@ -0,0 +1,615 @@ +from __future__ import annotations + +import gzip +import io +import json +import zlib +from collections import OrderedDict +from datetime import datetime +from decimal import Decimal as D +from io import BytesIO +from typing import TYPE_CHECKING + +import zstandard + +if TYPE_CHECKING: + from pathlib import Path + +import orjson +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + + +def test_write_json() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", None]}) + out = df.write_json() + assert out == '[{"a":1,"b":"a"},{"a":2,"b":"b"},{"a":3,"b":null}]' + + # Test round trip + f = io.BytesIO() + f.write(out.encode()) + f.seek(0) + result = pl.read_json(f) + assert_frame_equal(result, df) + + +def test_write_json_categoricals() -> None: + data = {"column": ["test1", "test2", "test3", "test4"]} + df = pl.DataFrame(data).with_columns(pl.col("column").cast(pl.Categorical)) + expected = ( + '[{"column":"test1"},{"column":"test2"},{"column":"test3"},{"column":"test4"}]' + ) + assert df.write_json() == expected + + +def test_write_json_duration() -> None: + df = pl.DataFrame( + { + "a": pl.Series( + [91762939, 91762890, 6020836], dtype=pl.Duration(time_unit="ms") + ) + } + ) + + # we don't guarantee a format, just round-circling + value = df.write_json() + expected = '[{"a":"PT91762.939S"},{"a":"PT91762.89S"},{"a":"PT6020.836S"}]' + assert value == expected + + +def test_write_json_time() -> None: + ns = 1_000_000_000 + df = pl.DataFrame( + { + "a": pl.Series( + [7291 * ns + 54321, 54321 * ns + 12345, 86399 * ns], + dtype=pl.Time, + ), + } + ) + + value = df.write_json() + expected = ( + '[{"a":"02:01:31.000054321"},{"a":"15:05:21.000012345"},{"a":"23:59:59"}]' + ) + assert value == expected + + +def test_write_json_decimal() -> None: + df = pl.DataFrame({"a": pl.Series([D("1.00"), D("2.00"), None])}) + + # we don't guarantee a format, just round-circling + value = df.write_json() + assert value == """[{"a":"1.00"},{"a":"2.00"},{"a":null}]""" + + +def test_json_infer_schema_length_11148() -> None: + response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1 + with pytest.raises( + pl.exceptions.ComputeError, match="extra field in struct data: col2" + ): + pl.read_json(json.dumps(response).encode(), infer_schema_length=2) + + response = [{"col1": 1}] * 2 + [{"col1": 1, "col2": 2}] * 1 + result = pl.read_json(json.dumps(response).encode(), infer_schema_length=3) + assert set(result.columns) == {"col1", "col2"} + + +def test_to_from_buffer_arraywise_schema() -> None: + buf = io.StringIO( + """ + [ + {"a": 5, "b": "foo", "c": null}, + {"a": 11.4, "b": null, "c": true, "d": 8}, + {"a": -25.8, "b": "bar", "c": false} + ]""" + ) + + read_df = pl.read_json(buf, schema={"b": pl.String, "e": pl.Int16}) + + assert_frame_equal( + read_df, + pl.DataFrame( + { + "b": pl.Series(["foo", None, "bar"], dtype=pl.String), + "e": pl.Series([None, None, None], dtype=pl.Int16), + } + ), + ) + + +def test_to_from_buffer_arraywise_schema_override() -> None: + buf = io.StringIO( + """ + [ + {"a": 5, "b": "foo", "c": null}, + {"a": 11.4, "b": null, "c": true, "d": 8}, + {"a": -25.8, "b": "bar", "c": false} + ]""" + ) + + read_df = pl.read_json(buf, schema_overrides={"c": pl.Int64, "d": pl.Float64}) + + assert_frame_equal( + read_df, + pl.DataFrame( + { + "a": pl.Series([5, 11.4, -25.8], dtype=pl.Float64), + "b": pl.Series(["foo", None, "bar"], dtype=pl.String), + "c": pl.Series([None, 1, 0], dtype=pl.Int64), + "d": pl.Series([None, 8, None], dtype=pl.Float64), + } + ), + check_column_order=False, + ) + + +def test_write_ndjson() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", None]}) + out = df.write_ndjson() + assert out == '{"a":1,"b":"a"}\n{"a":2,"b":"b"}\n{"a":3,"b":null}\n' + + # Test round trip + f = io.BytesIO() + f.write(out.encode()) + f.seek(0) + result = pl.read_ndjson(f) + assert_frame_equal(result, df) + + +def test_write_ndjson_with_trailing_newline() -> None: + input = """{"Column1":"Value1"}\n""" + df = pl.read_ndjson(io.StringIO(input)) + + expected = pl.DataFrame({"Column1": ["Value1"]}) + assert_frame_equal(df, expected) + + +def test_read_ndjson_empty_array() -> None: + assert pl.read_ndjson(io.StringIO("""{"foo": {"bar": []}}""")).to_dict( + as_series=False + ) == {"foo": [{"bar": []}]} + + +def test_ndjson_nested_null() -> None: + json_payload = """{"foo":{"bar":[{}]}}""" + df = pl.read_ndjson(io.StringIO(json_payload)) + + # 'bar' represents an empty list of structs; check the schema is correct (eg: picks + # up that it IS a list of structs), but confirm that list is empty (ref: #11301) + # We don't support empty structs yet. So Null is closest. + assert df.schema == {"foo": pl.Struct([pl.Field("bar", pl.List(pl.Struct({})))])} + assert df.to_dict(as_series=False) == {"foo": [{"bar": [{}]}]} + + +def test_ndjson_nested_string_int() -> None: + ndjson = """{"Accumulables":[{"Value":32395888},{"Value":"539454"}]}""" + assert pl.read_ndjson(io.StringIO(ndjson)).to_dict(as_series=False) == { + "Accumulables": [[{"Value": "32395888"}, {"Value": "539454"}]] + } + + +def test_json_supertype_infer() -> None: + json_string = """[ +{"c":[{"b": [], "a": "1"}]}, +{"c":[{"b":[]}]}, +{"c":[{"b":["1"], "a": "1"}]}] +""" + python_infer = pl.from_records(json.loads(json_string)) + polars_infer = pl.read_json(io.StringIO(json_string)) + assert_frame_equal(python_infer, polars_infer) + + +def test_ndjson_sliced_list_serialization() -> None: + data = {"col1": [0, 2], "col2": [[3, 4, 5], [6, 7, 8]]} + df = pl.DataFrame(data) + f = io.BytesIO() + sliced_df = df[1, :] + sliced_df.write_ndjson(f) + assert f.getvalue() == b'{"col1":2,"col2":[6,7,8]}\n' + + +def test_json_deserialize_9687() -> None: + response = { + "volume": [0.0, 0.0, 0.0], + "open": [1263.0, 1263.0, 1263.0], + "close": [1263.0, 1263.0, 1263.0], + "high": [1263.0, 1263.0, 1263.0], + "low": [1263.0, 1263.0, 1263.0], + } + + result = pl.read_json(json.dumps(response).encode()) + + assert result.to_dict(as_series=False) == {k: [v] for k, v in response.items()} + + +def test_ndjson_ignore_errors() -> None: + # this schema is inconsistent as "value" is string and object + jsonl = r"""{"Type":"insert","Key":[1],"SeqNo":1,"Timestamp":1,"Fields":[{"Name":"added_id","Value":2},{"Name":"body","Value":{"a": 1}}]} + {"Type":"insert","Key":[1],"SeqNo":1,"Timestamp":1,"Fields":[{"Name":"added_id","Value":2},{"Name":"body","Value":{"a": 1}}]}""" + + buf = io.BytesIO(jsonl.encode()) + + # check if we can replace with nulls + assert pl.read_ndjson(buf, ignore_errors=True).to_dict(as_series=False) == { + "Type": ["insert", "insert"], + "Key": [[1], [1]], + "SeqNo": [1, 1], + "Timestamp": [1, 1], + "Fields": [ + [{"Name": "added_id", "Value": "2"}, {"Name": "body", "Value": '{"a": 1}'}], + [{"Name": "added_id", "Value": "2"}, {"Name": "body", "Value": '{"a": 1}'}], + ], + } + + schema = { + "Fields": pl.List( + pl.Struct([pl.Field("Name", pl.String), pl.Field("Value", pl.Int64)]) + ) + } + # schema argument only parses Fields + assert pl.read_ndjson(buf, schema=schema, ignore_errors=True).to_dict( + as_series=False + ) == { + "Fields": [ + [{"Name": "added_id", "Value": 2}, {"Name": "body", "Value": None}], + [{"Name": "added_id", "Value": 2}, {"Name": "body", "Value": None}], + ] + } + + # schema_overrides argument does schema inference, but overrides Fields + result = pl.read_ndjson(buf, schema_overrides=schema, ignore_errors=True) + expected = { + "Type": ["insert", "insert"], + "Key": [[1], [1]], + "SeqNo": [1, 1], + "Timestamp": [1, 1], + "Fields": [ + [{"Name": "added_id", "Value": 2}, {"Name": "body", "Value": None}], + [{"Name": "added_id", "Value": 2}, {"Name": "body", "Value": None}], + ], + } + assert result.to_dict(as_series=False) == expected + + +def test_json_null_infer() -> None: + json = BytesIO( + bytes( + """ + [ + { + "a": 1, + "b": null + } + ] + """, + "UTF-8", + ) + ) + + assert pl.read_json(json).schema == OrderedDict({"a": pl.Int64, "b": pl.Null}) + + +def test_ndjson_null_buffer() -> None: + data = io.BytesIO( + b"""\ + {"id": 1, "zero_column": 0, "empty_array_column": [], "empty_object_column": {}, "null_column": null} + {"id": 2, "zero_column": 0, "empty_array_column": [], "empty_object_column": {}, "null_column": null} + {"id": 3, "zero_column": 0, "empty_array_column": [], "empty_object_column": {}, "null_column": null} + {"id": 4, "zero_column": 0, "empty_array_column": [], "empty_object_column": {}, "null_column": null} + """ + ) + + assert pl.read_ndjson(data).schema == OrderedDict( + [ + ("id", pl.Int64), + ("zero_column", pl.Int64), + ("empty_array_column", pl.List(pl.Null)), + ("empty_object_column", pl.Struct([])), + ("null_column", pl.Null), + ] + ) + + +def test_ndjson_null_inference_13183() -> None: + assert pl.read_ndjson( + b""" + {"map": "a", "start_time": 0.795, "end_time": 1.495} + {"map": "a", "start_time": 1.6239999999999999, "end_time": 2.0540000000000003} + {"map": "c", "start_time": 2.184, "end_time": 2.645} + {"map": "a", "start_time": null, "end_time": null} + """.strip() + ).to_dict(as_series=False) == { + "map": ["a", "a", "c", "a"], + "start_time": [0.795, 1.6239999999999999, 2.184, None], + "end_time": [1.495, 2.0540000000000003, 2.645, None], + } + + +@pytest.mark.write_disk +def test_json_wrong_input_handle_textio(tmp_path: Path) -> None: + # This shouldn't be passed, but still we test if we can handle it gracefully + df = pl.DataFrame( + { + "x": [1, 2, 3], + "y": ["a", "b", "c"], + } + ) + file_path = tmp_path / "test.ndjson" + df.write_ndjson(file_path) + + with file_path.open() as f: + result = pl.read_ndjson(f) + assert_frame_equal(result, df) + + +def test_json_normalize() -> None: + data = [ + {"id": 1, "name": {"first": "Coleen", "last": "Volk"}}, + {"name": {"given": "Mark", "family": "Regner"}}, + {"id": 2, "name": "Faye Raker"}, + ] + + assert pl.json_normalize(data, max_level=0).to_dict(as_series=False) == { + "id": [1, None, 2], + "name": [ + '{"first": "Coleen", "last": "Volk"}', + '{"given": "Mark", "family": "Regner"}', + "Faye Raker", + ], + } + + assert pl.json_normalize(data, max_level=1).to_dict(as_series=False) == { + "id": [1, None, 2], + "name.first": ["Coleen", None, None], + "name.last": ["Volk", None, None], + "name.given": [None, "Mark", None], + "name.family": [None, "Regner", None], + "name": [None, None, "Faye Raker"], + } + + data = [ + { + "id": 1, + "name": "Cole Volk", + "fitness": {"height": 130, "weight": 60}, + }, + {"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}}, + { + "id": 2, + "name": "Faye Raker", + "fitness": {"height": 130, "weight": 60}, + }, + ] + assert pl.json_normalize(data, max_level=1, separator=":").to_dict( + as_series=False, + ) == { + "id": [1, None, 2], + "name": ["Cole Volk", "Mark Reg", "Faye Raker"], + "fitness:height": [130, 130, 130], + "fitness:weight": [60, 60, 60], + } + assert pl.json_normalize(data, max_level=0).to_dict( + as_series=False, + ) == { + "id": [1, None, 2], + "name": ["Cole Volk", "Mark Reg", "Faye Raker"], + "fitness": [ + '{"height": 130, "weight": 60}', + '{"height": 130, "weight": 60}', + '{"height": 130, "weight": 60}', + ], + } + assert pl.json_normalize(data, max_level=0, encoder=orjson.dumps).to_dict( + as_series=False, + ) == { + "id": [1, None, 2], + "name": ["Cole Volk", "Mark Reg", "Faye Raker"], + "fitness": [ + b'{"height":130,"weight":60}', + b'{"height":130,"weight":60}', + b'{"height":130,"weight":60}', + ], + } + + +def test_empty_json() -> None: + df = pl.read_json(io.StringIO("{}")) + assert df.shape == (0, 0) + assert isinstance(df, pl.DataFrame) + + df = pl.read_json(b'{"j":{}}') + assert df.dtypes == [pl.Struct([])] + assert df.shape == (1, 1) + + +def test_compressed_json() -> None: + # shared setup + json_obj = [ + {"id": 1, "name": "Alice", "trusted": True}, + {"id": 2, "name": "Bob", "trusted": True}, + {"id": 3, "name": "Carol", "trusted": False}, + ] + expected = pl.DataFrame(json_obj, orient="row") + json_bytes = json.dumps(json_obj).encode() + + # gzip + compressed_bytes = gzip.compress(json_bytes) + out = pl.read_json(compressed_bytes) + assert_frame_equal(out, expected) + + # zlib + compressed_bytes = zlib.compress(json_bytes) + out = pl.read_json(compressed_bytes) + assert_frame_equal(out, expected) + + # zstd + compressed_bytes = zstandard.compress(json_bytes) + out = pl.read_json(compressed_bytes) + assert_frame_equal(out, expected) + + # no compression + uncompressed = io.BytesIO(json_bytes) + out = pl.read_json(uncompressed) + assert_frame_equal(out, expected) + + +def test_empty_list_json() -> None: + df = pl.read_json(io.StringIO("[]")) # + assert df.shape == (0, 0) + assert isinstance(df, pl.DataFrame) + + df = pl.read_json(b"[]") + assert df.shape == (0, 0) + assert isinstance(df, pl.DataFrame) + + +def test_json_infer_3_dtypes() -> None: + # would SO before + df = pl.DataFrame({"a": ["{}", "1", "[1, 2]"]}) + + with pytest.raises(pl.exceptions.ComputeError): + df.select(pl.col("a").str.json_decode()) + + df = pl.DataFrame({"a": [None, "1", "[1, 2]"]}) + out = df.select(pl.col("a").str.json_decode(dtype=pl.List(pl.String))) + assert out["a"].to_list() == [None, ["1"], ["1", "2"]] + assert out.dtypes[0] == pl.List(pl.String) + + +# NOTE: This doesn't work for 0, but that is normal +@pytest.mark.parametrize("size", [1, 2, 13]) +def test_zfs_json_roundtrip(size: int) -> None: + a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame() + + f = io.StringIO() + a.write_json(f) + + f.seek(0) + assert_frame_equal(a, pl.read_json(f)) + + +def test_read_json_raise_on_data_type_mismatch() -> None: + with pytest.raises(ComputeError): + pl.read_json( + b"""\ +[ + {"a": null}, + {"a": 1} +] +""", + infer_schema_length=1, + ) + + +def test_read_json_struct_schema() -> None: + with pytest.raises(ComputeError, match="extra field in struct data: b"): + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + infer_schema_length=1, + ) + + assert_frame_equal( + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + infer_schema_length=2, + ), + pl.DataFrame({"a": [1, 2], "b": [None, 2]}), + ) + + # If the schema was explicitly given, then we ignore extra fields. + # TODO: There should be a `columns=` parameter to this. + assert_frame_equal( + pl.read_json( + b"""\ +[ + {"a": 1}, + {"a": 2, "b": 2} +] +""", + schema={"a": pl.Int64}, + ), + pl.DataFrame({"a": [1, 2]}), + ) + + +def test_read_ndjson_inner_list_types_18244() -> None: + assert pl.read_ndjson( + io.StringIO("""{"a":null,"b":null,"c":null}"""), + schema={ + "a": pl.List(pl.String), + "b": pl.List(pl.Int32), + "c": pl.List(pl.Float64), + }, + ).schema == ( + {"a": pl.List(pl.String), "b": pl.List(pl.Int32), "c": pl.List(pl.Float64)} + ) + + +def test_read_json_utf_8_sig_encoding() -> None: + data = [{"a": [1, 2], "b": [1, 2]}] + result = pl.read_json(json.dumps(data).encode("utf-8-sig")) + expected = pl.DataFrame(data) + assert_frame_equal(result, expected) + + +def test_write_masked_out_list_22202() -> None: + df = pl.DataFrame({"x": [1, 2], "y": [None, 3]}) + + output_file = io.BytesIO() + + query = ( + df.group_by("x", maintain_order=True) + .all() + .select(pl.when(pl.col("y").list.sum() > 0).then("y")) + ) + + eager = query.write_ndjson().encode() + + query.lazy().sink_ndjson(output_file) + lazy = output_file.getvalue() + + assert eager == lazy + + +def test_nested_datetime_ndjson() -> None: + f = io.StringIO( + """{"start_date":"2025-03-14T09:30:27Z","steps":[{"id":1,"start_date":"2025-03-14T09:30:27Z"},{"id":2,"start_date":"2025-03-14T09:31:27Z"}]}""" + ) + + schema = { + "start_date": pl.Datetime, + "steps": pl.List(pl.Struct({"id": pl.Int64, "start_date": pl.Datetime})), + } + + assert pl.read_ndjson(f, schema=schema).to_dict(as_series=False) == { # type: ignore[arg-type] + "start_date": [datetime(2025, 3, 14, 9, 30, 27)], + "steps": [ + [ + {"id": 1, "start_date": datetime(2025, 3, 14, 9, 30, 27)}, + {"id": 2, "start_date": datetime(2025, 3, 14, 9, 31, 27)}, + ] + ], + } + + +def test_ndjson_22229() -> None: + li = [ + '{ "campaign": { "id": "123456" }, "metrics": { "conversions": 7}}', + '{ "campaign": { "id": "654321" }, "metrics": { "conversions": 3.5}}', + ] + + assert pl.read_ndjson(io.StringIO("\n".join(li))).to_dict(as_series=False) diff --git a/py-polars/tests/unit/io/test_lazy_count_star.py b/py-polars/tests/unit/io/test_lazy_count_star.py new file mode 100644 index 000000000000..34717c992c60 --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_count_star.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + + from polars.lazyframe.frame import LazyFrame + +import gzip +import re +from tempfile import NamedTemporaryFile + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +# Parameters +# * lf: COUNT(*) query +def assert_fast_count( + lf: LazyFrame, + expected_count: int, + *, + expected_name: str = "len", + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + + capfd.readouterr() # resets stderr + result = lf.collect() + capture = capfd.readouterr().err + project_logs = set(re.findall(r"project: \d+", capture)) + + # Logs current differ depending on file type / implementation dispatch + if "FAST COUNT" in lf.explain(): + # * Should be no projections when fast count is enabled + assert not project_logs + else: + # * Otherwise should have at least one `project: 0` (there is 1 per file). + assert project_logs == {"project: 0"} + + assert result.schema == {expected_name: pl.get_index_type()} + assert result.item() == expected_count + + # Test effect of the environment variable + monkeypatch.setenv("POLARS_FAST_FILE_COUNT_DISPATCH", "0") + + capfd.readouterr() + lf.collect() + capture = capfd.readouterr().err + project_logs = set(re.findall(r"project: \d+", capture)) + + assert "FAST COUNT" not in lf.explain() + assert project_logs == {"project: 0"} + + monkeypatch.setenv("POLARS_FAST_FILE_COUNT_DISPATCH", "1") + + capfd.readouterr() + lf.collect() + capture = capfd.readouterr().err + project_logs = set(re.findall(r"project: \d+", capture)) + + assert "FAST COUNT" in lf.explain() + assert not project_logs + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.csv", 27), ("foods*.csv", 27 * 5)] +) +def test_count_csv( + io_files_path: Path, + path: str, + n_rows: int, + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + lf = pl.scan_csv(io_files_path / path).select(pl.len()) + + assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch) + + +def test_count_csv_comment_char( + capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch +) -> None: + q = pl.scan_csv( + b""" +a,b +1,2 + +# +3,4 +""", + comment_prefix="#", + ) + + assert_frame_equal( + q.collect(), pl.DataFrame({"a": [1, None, 3], "b": [2, None, 4]}) + ) + + q = q.select(pl.len()) + assert_fast_count(q, 3, capfd=capfd, monkeypatch=monkeypatch) + + +@pytest.mark.write_disk +def test_commented_csv( + capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch +) -> None: + with NamedTemporaryFile() as csv_a: + csv_a.write(b"A,B\nGr1,A\nGr1,B\n# comment line\n") + csv_a.seek(0) + + lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len()) + assert_fast_count(lf, 2, capfd=capfd, monkeypatch=monkeypatch) + + +@pytest.mark.parametrize( + ("pattern", "n_rows"), [("small.parquet", 4), ("foods*.parquet", 54)] +) +def test_count_parquet( + io_files_path: Path, + pattern: str, + n_rows: int, + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + lf = pl.scan_parquet(io_files_path / pattern).select(pl.len()) + assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch) + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.ipc", 27), ("foods*.ipc", 27 * 2)] +) +def test_count_ipc( + io_files_path: Path, + path: str, + n_rows: int, + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + lf = pl.scan_ipc(io_files_path / path).select(pl.len()) + assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch) + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.ndjson", 27), ("foods*.ndjson", 27 * 2)] +) +def test_count_ndjson( + io_files_path: Path, + path: str, + n_rows: int, + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + lf = pl.scan_ndjson(io_files_path / path).select(pl.len()) + assert_fast_count(lf, n_rows, capfd=capfd, monkeypatch=monkeypatch) + + +def test_count_compressed_csv_18057( + io_files_path: Path, + capfd: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + csv_file = io_files_path / "gzipped.csv.gz" + + expected = pl.DataFrame( + {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]} + ) + lf = pl.scan_csv(csv_file, truncate_ragged_lines=True) + out = lf.collect() + assert_frame_equal(out, expected) + # This also tests: + # #18070 "CSV count_rows does not skip empty lines at file start" + # as the file has an empty line at the beginning. + + q = lf.select(pl.len()) + assert_fast_count(q, 3, capfd=capfd, monkeypatch=monkeypatch) + + +def test_count_compressed_ndjson( + tmp_path: Path, capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch +) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.jsonl.gz" + df = pl.DataFrame({"x": range(5)}) + + with gzip.open(path, "wb") as f: + df.write_ndjson(f) + + lf = pl.scan_ndjson(path).select(pl.len()) + assert_fast_count(lf, 5, capfd=capfd, monkeypatch=monkeypatch) + + +def test_count_projection_pd( + capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch +) -> None: + df = pl.DataFrame({"a": range(3), "b": range(3)}) + + q = ( + pl.scan_csv(df.write_csv().encode()) + .with_row_index() + .select(pl.all()) + .select(pl.len()) + ) + + # Manual assert, this is not converted to FAST COUNT but we will have + # 0-width projections. + + monkeypatch.setenv("POLARS_VERBOSE", "1") + capfd.readouterr() + result = q.collect() + capture = capfd.readouterr().err + project_logs = set(re.findall(r"project: \d+", capture)) + + assert project_logs == {"project: 0"} + assert result.item() == 3 diff --git a/py-polars/tests/unit/io/test_lazy_csv.py b/py-polars/tests/unit/io/test_lazy_csv.py new file mode 100644 index 000000000000..1bb294d3b17f --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_csv.py @@ -0,0 +1,463 @@ +from __future__ import annotations + +import io +import tempfile +from collections import OrderedDict +from pathlib import Path + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, ShapeError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def foods_file_path(io_files_path: Path) -> Path: + return io_files_path / "foods1.csv" + + +def test_scan_csv(io_files_path: Path) -> None: + df = pl.scan_csv(io_files_path / "small.csv") + assert df.collect().shape == (4, 3) + + +def test_scan_csv_no_cse_deadlock(io_files_path: Path) -> None: + dfs = [pl.scan_csv(io_files_path / "small.csv")] * (pl.thread_pool_size() + 1) + pl.concat(dfs, parallel=True).collect(comm_subplan_elim=False) + + +def test_scan_empty_csv(io_files_path: Path) -> None: + with pytest.raises(Exception) as excinfo: + pl.scan_csv(io_files_path / "empty.csv").collect() + assert "empty CSV" in str(excinfo.value) + + lf = pl.scan_csv(io_files_path / "empty.csv", raise_if_empty=False) + assert_frame_equal(lf, pl.LazyFrame()) + + +@pytest.mark.write_disk +def test_invalid_utf8(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + np.random.seed(1) + bts = bytes(np.random.randint(0, 255, 200)) + + file_path = tmp_path / "nonutf8.csv" + file_path.write_bytes(bts) + + a = pl.read_csv(file_path, has_header=False, encoding="utf8-lossy") + b = pl.scan_csv(file_path, has_header=False, encoding="utf8-lossy").collect() + + assert_frame_equal(a, b) + + +def test_row_index(foods_file_path: Path) -> None: + df = pl.read_csv(foods_file_path, row_index_name="row_index") + assert df["row_index"].to_list() == list(range(27)) + + df = ( + pl.scan_csv(foods_file_path, row_index_name="row_index") + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25] + + df = ( + pl.scan_csv(foods_file_path, row_index_name="row_index") + .with_row_index("foo", 10) + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35] + + +@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"]) +@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV +def test_scan_csv_schema_overwrite_and_dtypes_overwrite( + io_files_path: Path, file_name: str +) -> None: + file_path = io_files_path / file_name + df = pl.scan_csv( + file_path, + schema_overrides={"calories_foo": pl.String, "fats_g_foo": pl.Float32}, + with_column_names=lambda names: [f"{a}_foo" for a in names], + ).collect() + assert df.dtypes == [pl.String, pl.String, pl.Float32, pl.Int64] + assert df.columns == [ + "category_foo", + "calories_foo", + "fats_g_foo", + "sugars_g_foo", + ] + + +@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"]) +@pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]) +@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV +def test_scan_csv_schema_overwrite_and_small_dtypes_overwrite( + io_files_path: Path, file_name: str, dtype: pl.DataType +) -> None: + file_path = io_files_path / file_name + df = pl.scan_csv( + file_path, + schema_overrides={"calories_foo": pl.String, "sugars_g_foo": dtype}, + with_column_names=lambda names: [f"{a}_foo" for a in names], + ).collect() + assert df.dtypes == [pl.String, pl.String, pl.Float64, dtype] + assert df.columns == [ + "category_foo", + "calories_foo", + "fats_g_foo", + "sugars_g_foo", + ] + + +@pytest.mark.parametrize("file_name", ["foods1.csv", "foods*.csv"]) +@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV +def test_scan_csv_schema_new_columns_dtypes( + io_files_path: Path, file_name: str +) -> None: + file_path = io_files_path / file_name + + for dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]: + # assign 'new_columns', providing partial dtype overrides + df1 = pl.scan_csv( + file_path, + schema_overrides={"calories": pl.String, "sugars": dtype}, + new_columns=["category", "calories", "fats", "sugars"], + ).collect() + assert df1.dtypes == [pl.String, pl.String, pl.Float64, dtype] + assert df1.columns == ["category", "calories", "fats", "sugars"] + + # assign 'new_columns' with 'dtypes' list + df2 = pl.scan_csv( + file_path, + schema_overrides=[pl.String, pl.String, pl.Float64, dtype], + new_columns=["category", "calories", "fats", "sugars"], + ).collect() + assert df1.rows() == df2.rows() + + # rename existing columns, then lazy-select disjoint cols + lf = pl.scan_csv( + file_path, + new_columns=["colw", "colx", "coly", "colz"], + ) + schema = lf.collect_schema() + assert schema.dtypes() == [pl.String, pl.Int64, pl.Float64, pl.Int64] + assert schema.names() == ["colw", "colx", "coly", "colz"] + assert ( + lf.select("colz", "colx").collect().rows() + == df1.select("sugars", pl.col("calories").cast(pl.Int64)).rows() + ) + + # partially rename columns / overwrite dtypes + df4 = pl.scan_csv( + file_path, + schema_overrides=[pl.String, pl.String], + new_columns=["category", "calories"], + ).collect() + assert df4.dtypes == [pl.String, pl.String, pl.Float64, pl.Int64] + assert df4.columns == ["category", "calories", "fats_g", "sugars_g"] + + # cannot have len(new_columns) > len(actual columns) + with pytest.raises(ShapeError): + pl.scan_csv( + file_path, + schema_overrides=[pl.String, pl.String], + new_columns=["category", "calories", "c3", "c4", "c5"], + ).collect() + + # cannot set both 'new_columns' and 'with_column_names' + with pytest.raises(ValueError, match="mutually.exclusive"): + pl.scan_csv( + file_path, + schema_overrides=[pl.String, pl.String], + new_columns=["category", "calories", "fats", "sugars"], + with_column_names=lambda cols: [col.capitalize() for col in cols], + ).collect() + + +def test_lazy_n_rows(foods_file_path: Path) -> None: + df = ( + pl.scan_csv(foods_file_path, n_rows=4, row_index_name="idx") + .filter(pl.col("idx") > 2) + .collect() + ) + assert df.to_dict(as_series=False) == { + "idx": [3], + "category": ["fruit"], + "calories": [60], + "fats_g": [0.0], + "sugars_g": [11], + } + + +def test_lazy_row_index_no_push_down(foods_file_path: Path) -> None: + plan = ( + pl.scan_csv(foods_file_path) + .with_row_index() + .filter(pl.col("index") == 1) + .filter(pl.col("category") == pl.lit("vegetables")) + .explain(predicate_pushdown=True) + ) + # related to row count is not pushed. + assert 'FILTER [(col("index")) == (1)]\nFROM' in plan + # unrelated to row count is pushed. + assert 'SELECTION: [(col("category")) == ("vegetables")]' in plan + + +@pytest.mark.write_disk +def test_glob_skip_rows(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + for i in range(2): + file_path = tmp_path / f"test_{i}.csv" + file_path.write_text( + f""" +metadata goes here +file number {i} +foo,bar,baz +1,2,3 +4,5,6 +7,8,9 +""" + ) + file_path = tmp_path / "*.csv" + assert pl.read_csv(file_path, skip_rows=2).to_dict(as_series=False) == { + "foo": [1, 4, 7, 1, 4, 7], + "bar": [2, 5, 8, 2, 5, 8], + "baz": [3, 6, 9, 3, 6, 9], + } + + +def test_glob_n_rows(io_files_path: Path) -> None: + file_path = io_files_path / "foods*.csv" + df = pl.scan_csv(file_path, n_rows=40).collect() + + # 27 rows from foods1.csv and 13 from foods2.csv + assert df.shape == (40, 4) + + # take first and last rows + assert df[[0, 39]].to_dict(as_series=False) == { + "category": ["vegetables", "seafood"], + "calories": [45, 146], + "fats_g": [0.5, 6.0], + "sugars_g": [2, 2], + } + + +def test_scan_csv_schema_overwrite_not_projected_8483(foods_file_path: Path) -> None: + df = ( + pl.scan_csv( + foods_file_path, + schema_overrides={"calories": pl.String, "sugars_g": pl.Int8}, + ) + .select(pl.len()) + .collect() + ) + expected = pl.DataFrame({"len": 27}, schema={"len": pl.UInt32}) + assert_frame_equal(df, expected) + + +def test_csv_list_arg(io_files_path: Path) -> None: + first = io_files_path / "foods1.csv" + second = io_files_path / "foods2.csv" + + df = pl.scan_csv(source=[first, second]).collect() + assert df.shape == (54, 4) + assert df.row(-1) == ("seafood", 194, 12.0, 1) + assert df.row(0) == ("vegetables", 45, 0.5, 2) + + +# https://github.com/pola-rs/polars/issues/9887 +def test_scan_csv_slice_offset_zero(io_files_path: Path) -> None: + lf = pl.scan_csv(io_files_path / "small.csv") + result = lf.slice(0) + assert result.collect().height == 4 + + +@pytest.mark.write_disk +def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "small.parquet" + df = pl.DataFrame({"a": []}) + df.write_csv(file_path) + + read = pl.scan_csv(file_path).with_row_index("idx") + assert read.collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.String)]) + + +@pytest.mark.write_disk +def test_csv_null_values_with_projection_15515() -> None: + data = """IndCode,SireCode,BirthDate,Flag +ID00316,.,19940315, +""" + + with tempfile.NamedTemporaryFile() as f: + f.write(data.encode()) + f.seek(0) + + q = ( + pl.scan_csv(f.name, null_values={"SireCode": "."}) + .with_columns(pl.col("SireCode").alias("SireKey")) + .select("SireKey", "BirthDate") + ) + + assert q.collect().to_dict(as_series=False) == { + "SireKey": [None], + "BirthDate": [19940315], + } + + +@pytest.mark.write_disk +def test_csv_respect_user_schema_ragged_lines_15254() -> None: + with tempfile.NamedTemporaryFile() as f: + f.write( + b""" +A,B,C +1,2,3 +4,5,6,7,8 +9,10,11 +""".strip() + ) + f.seek(0) + + df = pl.scan_csv( + f.name, schema=dict.fromkeys("ABCDE", pl.String), truncate_ragged_lines=True + ).collect() + assert df.to_dict(as_series=False) == { + "A": ["1", "4", "9"], + "B": ["2", "5", "10"], + "C": ["3", "6", "11"], + "D": [None, "7", None], + "E": [None, "8", None], + } + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize( + "dfs", + [ + [pl.DataFrame({"a": [1, 2, 3]}), pl.DataFrame({"b": [4, 5, 6]})], + [ + pl.DataFrame({"a": [1, 2, 3]}), + pl.DataFrame({"b": [4, 5, 6], "c": [7, 8, 9]}), + ], + ], +) +@pytest.mark.may_fail_auto_streaming # missing_columns parameter for CSV +def test_file_list_schema_mismatch( + tmp_path: Path, dfs: list[pl.DataFrame], streaming: bool +) -> None: + tmp_path.mkdir(exist_ok=True) + + paths = [f"{tmp_path}/{i}.csv" for i in range(len(dfs))] + + for df, path in zip(dfs, paths): + df.write_csv(path) + + lf = pl.scan_csv(paths) + with pytest.raises((ComputeError, pl.exceptions.ColumnNotFoundError)): + lf.collect(engine="streaming" if streaming else "in-memory") + + if streaming: + pytest.xfail(reason="missing_columns parameter for CSV") + + if len({df.width for df in dfs}) == 1: + expect = pl.concat(df.select(x=pl.first().cast(pl.Int8)) for df in dfs) + out = pl.scan_csv(paths, schema={"x": pl.Int8}).collect( # type: ignore[call-overload] + engine="streaming" if streaming else "in-memory" # type: ignore[redundant-expr] + ) + + assert_frame_equal(out, expect) + + +@pytest.mark.may_fail_auto_streaming +@pytest.mark.parametrize("streaming", [True, False]) +def test_file_list_schema_supertype(tmp_path: Path, streaming: bool) -> None: + tmp_path.mkdir(exist_ok=True) + + data_lst = [ + """\ +a +1 +2 +""", + """\ +a +b +c +""", + ] + + paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))] + + for data, path in zip(data_lst, paths): + with Path(path).open("w") as f: + f.write(data) + + expect = pl.Series("a", ["1", "2", "b", "c"]).to_frame() + out = pl.scan_csv(paths).collect( # type: ignore[call-overload] + engine="old-streaming" if streaming else "in-memory" + ) + + assert_frame_equal(out, expect) + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_file_list_comment_skip_rows_16327(tmp_path: Path, streaming: bool) -> None: + tmp_path.mkdir(exist_ok=True) + + data_lst = [ + """\ +# comment +a +b +c +""", + """\ +a +b +c +""", + ] + + paths = [f"{tmp_path}/{i}.csv" for i in range(len(data_lst))] + + for data, path in zip(data_lst, paths): + with Path(path).open("w") as f: + f.write(data) + + expect = pl.Series("a", ["b", "c", "b", "c"]).to_frame() + out = pl.scan_csv(paths, comment_prefix="#").collect( + engine="streaming" if streaming else "in-memory" + ) + + assert_frame_equal(out, expect) + + +@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17634") +def test_scan_csv_with_column_names_nonexistent_file() -> None: + path_str = "my-nonexistent-data.csv" + path = Path(path_str) + assert not path.exists() + + # Just calling the scan function should not raise any errors + result = pl.scan_csv(path, with_column_names=lambda x: [c.upper() for c in x]) + assert isinstance(result, pl.LazyFrame) + + # Upon collection, it should fail + with pytest.raises(FileNotFoundError): + result.collect() + + +def test_select_nonexistent_column() -> None: + csv = "a\n1" + f = io.StringIO(csv) + + with pytest.raises(pl.exceptions.ColumnNotFoundError): + pl.scan_csv(f).select("b").collect() diff --git a/py-polars/tests/unit/io/test_lazy_ipc.py b/py-polars/tests/unit/io/test_lazy_ipc.py new file mode 100644 index 000000000000..daba3d5c45bc --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_ipc.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.fixture +def foods_ipc_path(io_files_path: Path) -> Path: + return io_files_path / "foods1.ipc" + + +def test_row_index(foods_ipc_path: Path) -> None: + df = pl.read_ipc(foods_ipc_path, row_index_name="row_index", use_pyarrow=False) + assert df["row_index"].to_list() == list(range(27)) + + df = ( + pl.scan_ipc(foods_ipc_path, row_index_name="row_index") + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25] + + df = ( + pl.scan_ipc(foods_ipc_path, row_index_name="row_index") + .with_row_index("foo", 10) + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35] + + +def test_is_in_type_coercion(foods_ipc_path: Path) -> None: + out = ( + pl.scan_ipc(foods_ipc_path) + .filter(pl.col("category").is_in(("vegetables", "ice cream"))) + .collect() + ) + assert out.shape == (7, 4) + out = ( + pl.scan_ipc(foods_ipc_path) + .select(pl.col("category").alias("cat")) + .filter(pl.col("cat").is_in(["vegetables"])) + .collect() + ) + assert out.shape == (7, 1) + + +def test_row_index_schema(foods_ipc_path: Path) -> None: + assert ( + pl.scan_ipc(foods_ipc_path, row_index_name="id") + .select(["id", "category"]) + .collect() + ).dtypes == [pl.UInt32, pl.String] + + +def test_glob_n_rows(io_files_path: Path) -> None: + file_path = io_files_path / "foods*.ipc" + df = pl.scan_ipc(file_path, n_rows=40).collect() + + # 27 rows from foods1.ipc and 13 from foods2.ipc + assert df.shape == (40, 4) + + # take first and last rows + assert df[[0, 39]].to_dict(as_series=False) == { + "category": ["vegetables", "seafood"], + "calories": [45, 146], + "fats_g": [0.5, 6.0], + "sugars_g": [2, 2], + } + + +def test_ipc_list_arg(io_files_path: Path) -> None: + first = io_files_path / "foods1.ipc" + second = io_files_path / "foods2.ipc" + + df = pl.scan_ipc(source=[first, second]).collect() + assert df.shape == (54, 4) + assert df.row(-1) == ("seafood", 194, 12.0, 1) + assert df.row(0) == ("vegetables", 45, 0.5, 2) + + +def test_scan_ipc_local_with_async( + monkeypatch: Any, + io_files_path: Path, +) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + + assert_frame_equal( + pl.scan_ipc(io_files_path / "foods1.ipc").head(1).collect(), + pl.DataFrame( + { + "category": ["vegetables"], + "calories": [45], + "fats_g": [0.5], + "sugars_g": [2], + } + ), + ) diff --git a/py-polars/tests/unit/io/test_lazy_json.py b/py-polars/tests/unit/io/test_lazy_json.py new file mode 100644 index 000000000000..b3561408f334 --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_json.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.fixture +def foods_ndjson_path(io_files_path: Path) -> Path: + return io_files_path / "foods1.ndjson" + + +def test_scan_ndjson(foods_ndjson_path: Path) -> None: + df = pl.scan_ndjson(foods_ndjson_path, row_index_name="row_index").collect() + assert df["row_index"].to_list() == list(range(27)) + + df = ( + pl.scan_ndjson(foods_ndjson_path, row_index_name="row_index") + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25] + + df = ( + pl.scan_ndjson(foods_ndjson_path, row_index_name="row_index") + .with_row_index("foo", 10) + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35] + + +def test_scan_ndjson_with_schema(foods_ndjson_path: Path) -> None: + schema = { + "category": pl.Categorical, + "calories": pl.Int64, + "fats_g": pl.Float64, + "sugars_g": pl.Int64, + } + df = pl.scan_ndjson(foods_ndjson_path, schema=schema).collect() + assert df["category"].dtype == pl.Categorical + assert df["calories"].dtype == pl.Int64 + assert df["fats_g"].dtype == pl.Float64 + assert df["sugars_g"].dtype == pl.Int64 + + schema["sugars_g"] = pl.Float64 + df = pl.scan_ndjson(foods_ndjson_path, schema=schema).collect() + assert df["sugars_g"].dtype == pl.Float64 + + +def test_scan_ndjson_infer_0(foods_ndjson_path: Path) -> None: + with pytest.raises(ValueError): + pl.scan_ndjson(foods_ndjson_path, infer_schema_length=0) + + +def test_scan_ndjson_batch_size_zero() -> None: + with pytest.raises(ValueError, match="invalid zero value"): + pl.scan_ndjson("test.ndjson", batch_size=0) + + +@pytest.mark.write_disk +def test_scan_with_projection(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + json = r""" +{"text": "\"hello", "id": 1} +{"text": "\n{\n\t\t\"inner\": \"json\n}\n", "id": 10} +{"id": 0, "text":"\"","date":"2013-08-03 15:17:23"} +{"id": 1, "text":"\"123\"","date":"2009-05-19 21:07:53"} +{"id": 2, "text":"/....","date":"2009-05-19 21:07:53"} +{"id": 3, "text":"\n\n..","date":"2"} +{"id": 4, "text":"\"'/\n...","date":"2009-05-19 21:07:53"} +{"id": 5, "text":".h\"h1hh\\21hi1e2emm...","date":"2009-05-19 21:07:53"} +{"id": 6, "text":"xxxx....","date":"2009-05-19 21:07:53"} +{"id": 7, "text":".\"quoted text\".","date":"2009-05-19 21:07:53"} +""" + json_bytes = bytes(json, "utf-8") + + file_path = tmp_path / "escape_chars.json" + file_path.write_bytes(json_bytes) + + actual = pl.scan_ndjson(file_path).select(["id", "text"]).collect() + + expected = pl.DataFrame( + { + "id": [1, 10, 0, 1, 2, 3, 4, 5, 6, 7], + "text": [ + '"hello', + '\n{\n\t\t"inner": "json\n}\n', + '"', + '"123"', + "/....", + "\n\n..", + "\"'/\n...", + '.h"h1hh\\21hi1e2emm...', + "xxxx....", + '."quoted text".', + ], + } + ) + assert_frame_equal(actual, expected) + + +def test_projection_pushdown_ndjson(io_files_path: Path) -> None: + file_path = io_files_path / "foods1.ndjson" + df = pl.scan_ndjson(file_path).select(pl.col.calories) + + explain = df.explain() + + assert "simple π" not in explain + assert "PROJECT 1/4 COLUMNS" in explain + + assert_frame_equal(df.collect(no_optimization=True), df.collect()) + + +def test_predicate_pushdown_ndjson(io_files_path: Path) -> None: + file_path = io_files_path / "foods1.ndjson" + df = pl.scan_ndjson(file_path).filter(pl.col.calories > 80) + + explain = df.explain() + + assert "FILTER" not in explain + assert """SELECTION: [(col("calories")) > (80)]""" in explain + + assert_frame_equal(df.collect(no_optimization=True), df.collect()) + + +def test_glob_n_rows(io_files_path: Path) -> None: + file_path = io_files_path / "foods*.ndjson" + df = pl.scan_ndjson(file_path, n_rows=40).collect() + + # 27 rows from foods1.ndjson and 13 from foods2.ndjson + assert df.shape == (40, 4) + + # take first and last rows + assert df[[0, 39]].to_dict(as_series=False) == { + "category": ["vegetables", "seafood"], + "calories": [45, 146], + "fats_g": [0.5, 6.0], + "sugars_g": [2, 2], + } + + +# See #10661. +def test_json_no_unicode_truncate() -> None: + assert pl.read_ndjson(rb'{"field": "\ufffd1234"}')[0, 0] == "\ufffd1234" + + +def test_ndjson_list_arg(io_files_path: Path) -> None: + first = io_files_path / "foods1.ndjson" + second = io_files_path / "foods2.ndjson" + + df = pl.scan_ndjson(source=[first, second]).collect() + assert df.shape == (54, 4) + assert df.row(-1) == ("seafood", 194, 12.0, 1) + assert df.row(0) == ("vegetables", 45, 0.5, 2) + + +def test_glob_single_scan(io_files_path: Path) -> None: + file_path = io_files_path / "foods*.ndjson" + df = pl.scan_ndjson(file_path, n_rows=40) + + explain = df.explain() + + assert explain.count("SCAN") == 1 + assert "UNION" not in explain + + +def test_scan_ndjson_empty_lines_in_middle() -> None: + assert_frame_equal( + pl.scan_ndjson( + f"""\ +{{"a": 1}} +{" "} +{{"a": 2}}{" "} +{" "} +{{"a": 3}} +""".encode() + ).collect(), + pl.DataFrame({"a": [1, 2, 3]}), + ) + + +@pytest.mark.parametrize("row_index_offset", [None, 0, 20]) +def test_scan_ndjson_slicing( + foods_ndjson_path: Path, row_index_offset: int | None +) -> None: + lf = pl.scan_ndjson(foods_ndjson_path) + + if row_index_offset is not None: + lf = lf.with_row_index(offset=row_index_offset) + + for q in [ + lf.head(5), + lf.tail(5), + lf.head(0), + lf.tail(0), + lf.slice(-999, 3), + lf.slice(999, 3), + lf.slice(-999, 0), + lf.slice(999, 0), + lf.slice(-999), + lf.slice(-3, 999), + ]: + assert_frame_equal(q.collect(), q.collect(no_optimization=True)) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Boolean, + pl.Int32, + pl.Int64, + pl.UInt64, + pl.UInt32, + pl.Float32, + pl.Float64, + pl.Datetime, + pl.Date, + pl.Null, + ], +) +def test_scan_ndjson_raises_on_parse_error(dtype: pl.DataType) -> None: + buf = b"""\ +{"a": "AAAA"} +""" + + cx = ( + pytest.raises( + pl.exceptions.ComputeError, + match="got non-null value for NULL-typed column: AAAA", + ) + if str(dtype) == "Null" + else pytest.raises(pl.exceptions.ComputeError, match="cannot parse 'AAAA' as ") + ) + + with cx: + pl.scan_ndjson( + buf, + schema={"a": dtype}, + ).collect() + + assert_frame_equal( + pl.scan_ndjson(buf, schema={"a": dtype}, ignore_errors=True).collect(), + pl.DataFrame({"a": [None]}, schema={"a": dtype}), + ) + + +def test_scan_ndjson_parse_string() -> None: + assert_frame_equal( + pl.scan_ndjson( + b"""\ +{"a": "123"} +""", + schema={"a": pl.String}, + ).collect(), + pl.DataFrame({"a": "123"}), + ) + + +def test_scan_ndjson_raises_on_parse_error_nested() -> None: + buf = b"""\ +{"a": {"b": "AAAA"}} +""" + q = pl.scan_ndjson( + buf, + schema={"a": pl.Struct({"b": pl.Int64})}, + ) + + with pytest.raises(pl.exceptions.ComputeError): + q.collect() + + q = pl.scan_ndjson( + buf, schema={"a": pl.Struct({"b": pl.Int64})}, ignore_errors=True + ) + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": [{"b": None}]}, schema={"a": pl.Struct({"b": pl.Int64})}), + ) + + +def test_scan_ndjson_nested_as_string() -> None: + buf = b"""\ +{"a": {"x": 1}, "b": [1,2,3], "c": {"y": null}, "d": [{"k": "abc"}, {"j": "123"}, {"l": 7}]} +""" + + df = pl.scan_ndjson( + buf, + schema={"a": pl.String, "b": pl.String, "c": pl.String, "d": pl.String}, + ).collect() + + assert_frame_equal( + df, + pl.DataFrame( + { + "a": '{"x": 1}', + "b": "[1, 2, 3]", + "c": '{"y": null}', + "d": '[{"k": "abc"}, {"j": "123"}, {"l": 7}]', + } + ), + ) diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py new file mode 100644 index 000000000000..d5a5b5489cc4 --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -0,0 +1,887 @@ +from __future__ import annotations + +from collections import OrderedDict +from pathlib import Path +from threading import Thread +from typing import TYPE_CHECKING, Any + +import pandas as pd +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import ParallelStrategy + + +@pytest.fixture +def parquet_file_path(io_files_path: Path) -> Path: + return io_files_path / "small.parquet" + + +@pytest.fixture +def foods_parquet_path(io_files_path: Path) -> Path: + return io_files_path / "foods1.parquet" + + +def test_scan_parquet(parquet_file_path: Path) -> None: + df = pl.scan_parquet(parquet_file_path) + assert df.collect().shape == (4, 3) + + +def test_scan_parquet_local_with_async( + monkeypatch: Any, foods_parquet_path: Path +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + pl.scan_parquet(foods_parquet_path.relative_to(Path.cwd())).head(1).collect() + + +def test_row_index(foods_parquet_path: Path) -> None: + df = pl.read_parquet(foods_parquet_path, row_index_name="row_index") + assert df["row_index"].to_list() == list(range(27)) + + df = ( + pl.scan_parquet(foods_parquet_path, row_index_name="row_index") + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["row_index"].to_list() == [0, 6, 11, 13, 14, 20, 25] + + df = ( + pl.scan_parquet(foods_parquet_path, row_index_name="row_index") + .with_row_index("foo", 10) + .filter(pl.col("category") == pl.lit("vegetables")) + .collect() + ) + + assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35] + + +def test_row_index_len_16543(foods_parquet_path: Path) -> None: + q = pl.scan_parquet(foods_parquet_path).with_row_index() + assert q.select(pl.all()).select(pl.len()).collect().item() == 27 + + +@pytest.mark.write_disk +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_parquet_statistics(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame( + { + "book": [ + "bookA", + "bookA", + "bookB", + "bookA", + "bookA", + "bookC", + "bookC", + "bookC", + ], + "transaction_id": [1, 2, 3, 4, 5, 6, 7, 8], + "user": ["bob", "bob", "bob", "tim", "lucy", "lucy", "lucy", "lucy"], + } + ).with_columns(pl.col("book").cast(pl.Categorical)) + + file_path = tmp_path / "books.parquet" + df.write_parquet(file_path, statistics=True) + + parallel_options: list[ParallelStrategy] = [ + "auto", + "columns", + "row_groups", + "none", + ] + for par in parallel_options: + df = ( + pl.scan_parquet(file_path, parallel=par) + .filter(pl.col("book") == "bookA") + .collect() + ) + assert df.shape == (4, 3) + + +@pytest.mark.write_disk +def test_parquet_eq_stats(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "stats.parquet" + + df1 = pd.DataFrame({"a": [None, 1, None, 2, 3, 3, 4, 4, 5, 5]}) + df1.to_parquet(file_path, engine="pyarrow") + df = pl.scan_parquet(file_path).filter(pl.col("a") == 4).collect() + assert df["a"].to_list() == [4.0, 4.0] + + assert ( + pl.scan_parquet(file_path).filter(pl.col("a") == 2).select(pl.col("a").sum()) + ).collect()[0, "a"] == 2.0 + + assert pl.scan_parquet(file_path).filter(pl.col("a") == 5).collect().shape == ( + 2, + 1, + ) + + +@pytest.mark.write_disk +def test_parquet_is_in_stats(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "stats.parquet" + + df1 = pd.DataFrame({"a": [None, 1, None, 2, 3, 3, 4, 4, 5, 5]}) + df1.to_parquet(file_path, engine="pyarrow") + df = pl.scan_parquet(file_path).filter(pl.col("a").is_in([5])).collect() + assert df["a"].to_list() == [5.0, 5.0] + + assert ( + pl.scan_parquet(file_path) + .filter(pl.col("a").is_in([5])) + .select(pl.col("a").sum()) + ).collect()[0, "a"] == 10.0 + + assert ( + pl.scan_parquet(file_path) + .filter(pl.col("a").is_in([1, 2, 3])) + .select(pl.col("a").sum()) + ).collect()[0, "a"] == 9.0 + + assert ( + pl.scan_parquet(file_path) + .filter(pl.col("a").is_in([1, 2, 3])) + .select(pl.col("a").sum()) + ).collect()[0, "a"] == 9.0 + + assert ( + pl.scan_parquet(file_path) + .filter(pl.col("a").is_in([5])) + .select(pl.col("a").sum()) + ).collect()[0, "a"] == 10.0 + + assert pl.scan_parquet(file_path).filter( + pl.col("a").is_in([1, 2, 3, 4, 5]) + ).collect().shape == (8, 1) + + +@pytest.mark.write_disk +def test_parquet_stats(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "binary_stats.parquet" + + df1 = pd.DataFrame({"a": [None, 1, None, 2, 3, 3, 4, 4, 5, 5]}) + df1.to_parquet(file_path, engine="pyarrow") + df = ( + pl.scan_parquet(file_path) + .filter(pl.col("a").is_not_null() & (pl.col("a") > 4)) + .collect() + ) + assert df["a"].to_list() == [5.0, 5.0] + + assert ( + pl.scan_parquet(file_path).filter(pl.col("a") > 4).select(pl.col("a").sum()) + ).collect()[0, "a"] == 10.0 + + assert ( + pl.scan_parquet(file_path).filter(pl.col("a") < 4).select(pl.col("a").sum()) + ).collect()[0, "a"] == 9.0 + + assert ( + pl.scan_parquet(file_path).filter(pl.col("a") < 4).select(pl.col("a").sum()) + ).collect()[0, "a"] == 9.0 + + assert ( + pl.scan_parquet(file_path).filter(pl.col("a") > 4).select(pl.col("a").sum()) + ).collect()[0, "a"] == 10.0 + assert pl.scan_parquet(file_path).filter( + (pl.col("a") * 10) > 5.0 + ).collect().shape == (8, 1) + + +def test_row_index_schema_parquet(parquet_file_path: Path) -> None: + assert ( + pl.scan_parquet(str(parquet_file_path), row_index_name="id") + .select(["id", "b"]) + .collect() + ).dtypes == [pl.UInt32, pl.String] + + +@pytest.mark.write_disk +def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.DataFrame({"idx": pl.arange(0, 100, eager=True)}).with_columns( + (pl.col("idx") // 25).alias("part") + ) + df = pl.concat(df.partition_by("part", as_dict=False), rechunk=False) + assert df.n_chunks("all") == [4, 4] + + file_path = tmp_path / "stats.parquet" + df.write_parquet(file_path, statistics=True, use_pyarrow=False) + + file_path = tmp_path / "stats.parquet" + df.write_parquet(file_path, statistics=True, use_pyarrow=False) + + for pred in [ + pl.col("idx").is_in([150, 200, 300]), + pl.col("idx").is_in([5, 250, 350]), + ]: + result = pl.scan_parquet(file_path).filter(pred).collect() + assert_frame_equal(result, df.filter(pred)) + + captured = capfd.readouterr().err + assert "Predicate pushdown: reading 1 / 1 row groups" in captured + assert "Predicate pushdown: reading 0 / 1 row groups" in captured + + +@pytest.mark.write_disk +def test_parquet_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.DataFrame({"idx": pl.arange(0, 100, eager=True)}).with_columns( + (pl.col("idx") // 25).alias("part") + ) + df = pl.concat(df.partition_by("part", as_dict=False), rechunk=False) + assert df.n_chunks("all") == [4, 4] + + file_path = tmp_path / "stats.parquet" + df.write_parquet(file_path, statistics=True, use_pyarrow=False, row_group_size=50) + + for pred in [ + pl.col("idx") < 50, + pl.col("idx") > 50, + pl.col("idx").null_count() != 0, + pl.col("idx").null_count() == 0, + pl.col("idx").min() == pl.col("part").null_count(), + ]: + result = pl.scan_parquet(file_path).filter(pred).collect() + assert_frame_equal(result, df.filter(pred)) + + captured = capfd.readouterr().err + + assert "Predicate pushdown: reading 1 / 2 row groups" in captured + + +@pytest.mark.write_disk +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame( + [ + pl.Series("name", ["Bob", "Alice", "Bob"], pl.Categorical), + pl.Series("amount", [100, 200, 300]), + ] + ) + + file_path = tmp_path / "categorical.parquet" + df.write_parquet(file_path) + + with pl.StringCache(): + result = ( + pl.scan_parquet(file_path) + .group_by("name") + .agg(pl.col("amount").sum()) + .collect() + .sort("name") + ) + expected = pl.DataFrame( + {"name": ["Bob", "Alice"], "amount": [400, 200]}, + schema_overrides={"name": pl.Categorical}, + ) + assert_frame_equal(result, expected) + + +def test_glob_n_rows(io_files_path: Path) -> None: + file_path = io_files_path / "foods*.parquet" + df = pl.scan_parquet(file_path, n_rows=40).collect() + + # 27 rows from foods1.parquet and 13 from foods2.parquet + assert df.shape == (40, 4) + + # take first and last rows + assert df[[0, 39]].to_dict(as_series=False) == { + "category": ["vegetables", "seafood"], + "calories": [45, 146], + "fats_g": [0.5, 6.0], + "sugars_g": [2, 2], + } + + +@pytest.mark.write_disk +def test_parquet_statistics_filter_9925(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "codes.parquet" + df = pl.DataFrame({"code": [300964, 300972, 500_000, 26]}) + df.write_parquet(file_path, statistics=True) + + q = pl.scan_parquet(file_path).filter( + (pl.col("code").floordiv(100_000)).is_in([0, 3]) + ) + assert q.collect().to_dict(as_series=False) == {"code": [300964, 300972, 26]} + + +@pytest.mark.write_disk +def test_parquet_statistics_filter_11069(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "foo.parquet" + pl.DataFrame({"x": [1, None]}).write_parquet(file_path, statistics=False) + + result = pl.scan_parquet(file_path).filter(pl.col("x").is_null()).collect() + expected = {"x": [None]} + assert result.to_dict(as_series=False) == expected + + +def test_parquet_list_arg(io_files_path: Path) -> None: + first = io_files_path / "foods1.parquet" + second = io_files_path / "foods2.parquet" + + df = pl.scan_parquet(source=[first, second]).collect() + assert df.shape == (54, 4) + assert df.row(-1) == ("seafood", 194, 12.0, 1) + assert df.row(0) == ("vegetables", 45, 0.5, 2) + + +@pytest.mark.write_disk +def test_parquet_many_row_groups_12297(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "foo.parquet" + df = pl.DataFrame({"x": range(100)}) + df.write_parquet(file_path, row_group_size=5, use_pyarrow=True) + assert_frame_equal(pl.scan_parquet(file_path).collect(), df) + + +@pytest.mark.write_disk +def test_row_index_empty_file(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "test.parquet" + df = pl.DataFrame({"a": []}, schema={"a": pl.Float32}) + df.write_parquet(file_path) + result = pl.scan_parquet(file_path).with_row_index("idx").collect() + assert result.schema == OrderedDict([("idx", pl.UInt32), ("a", pl.Float32)]) + + +@pytest.mark.write_disk +def test_io_struct_async_12500(tmp_path: Path) -> None: + file_path = tmp_path / "test.parquet" + pl.DataFrame( + [ + pl.Series("c1", [{"a": "foo", "b": "bar"}], dtype=pl.Struct), + pl.Series("c2", [18]), + ] + ).write_parquet(file_path) + assert pl.scan_parquet(file_path).select("c1").collect().to_dict( + as_series=False + ) == {"c1": [{"a": "foo", "b": "bar"}]} + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_parquet_different_schema(tmp_path: Path, streaming: bool) -> None: + # Schema is different but the projected columns are same dtype. + f1 = tmp_path / "a.parquet" + f2 = tmp_path / "b.parquet" + a = pl.DataFrame({"a": [1.0], "b": "a"}) + + b = pl.DataFrame({"a": [1], "b": "a"}) + + a.write_parquet(f1) + b.write_parquet(f2) + assert pl.scan_parquet([f1, f2]).select("b").collect( + engine="streaming" if streaming else "in-memory" + ).columns == ["b"] + + +@pytest.mark.write_disk +def test_nested_slice_12480(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + df = pl.select(pl.lit(1).repeat_by(10_000).explode().cast(pl.List(pl.Int32))) + + df.write_parquet(path, use_pyarrow=True, pyarrow_options={"data_page_size": 1}) + + assert pl.scan_parquet(path).slice(0, 1).collect().height == 1 + + +@pytest.mark.write_disk +def test_scan_deadlock_rayon_spawn_from_async_15172( + monkeypatch: Any, tmp_path: Path +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + monkeypatch.setenv("POLARS_MAX_THREADS", "1") + path = tmp_path / "data.parquet" + + df = pl.Series("x", [1]).to_frame() + df.write_parquet(path) + + results = [pl.DataFrame()] + + def scan_collect() -> None: + results[0] = pl.collect_all([pl.scan_parquet(path)])[0] + + # Make sure we don't sit there hanging forever on the broken case + t = Thread(target=scan_collect, daemon=True) + t.start() + t.join(5) + + assert results[0].equals(df) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_parquet_schema_mismatch_panic_17067(tmp_path: Path, streaming: bool) -> None: + pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).write_parquet(tmp_path / "1.parquet") + pl.DataFrame({"c": [1, 2, 3], "d": [4, 5, 6]}).write_parquet(tmp_path / "2.parquet") + + if streaming: + with pytest.raises(pl.exceptions.SchemaError): + pl.scan_parquet(tmp_path).collect(engine="streaming") + else: + with pytest.raises(pl.exceptions.SchemaError): + pl.scan_parquet(tmp_path).collect(engine="in-memory") + + +@pytest.mark.write_disk +def test_predicate_push_down_categorical_17744(tmp_path: Path) -> None: + path = tmp_path / "1" + + df = pl.DataFrame( + data={ + "n": [1, 2, 3], + "ccy": ["USD", "JPY", "EUR"], + }, + schema_overrides={"ccy": pl.Categorical("lexical")}, + ) + df.write_parquet(path) + expect = df.head(1).with_columns(pl.col(pl.Categorical).cast(pl.String)) + + lf = pl.scan_parquet(path) + + for predicate in [pl.col("ccy") == "USD", pl.col("ccy").is_in(["USD"])]: + assert_frame_equal( + lf.filter(predicate) + .with_columns(pl.col(pl.Categorical).cast(pl.String)) + .collect(), + expect, + ) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_parquet_slice_pushdown_non_zero_offset( + tmp_path: Path, streaming: bool +) -> None: + paths = [tmp_path / "1", tmp_path / "2", tmp_path / "3"] + dfs = [pl.DataFrame({"x": i}) for i in range(len(paths))] + + for df, p in zip(dfs, paths): + df.write_parquet(p) + + # Parquet files containing only the metadata - i.e. the data parts are removed. + # Used to test that a reader doesn't try to read any data. + def trim_to_metadata(path: str | Path) -> None: + path = Path(path) + v = path.read_bytes() + metadata_and_footer_len = 8 + int.from_bytes(v[-8:][:4], "little") + path.write_bytes(v[-metadata_and_footer_len:]) + + trim_to_metadata(paths[0]) + trim_to_metadata(paths[2]) + + # Check baseline: + # * Metadata can be read without error + assert pl.read_parquet_schema(paths[0]) == dfs[0].schema + # * Attempting to read any data will error + with pytest.raises(ComputeError): + pl.scan_parquet(paths[0]).collect( + engine="streaming" if streaming else "in-memory" + ) + + df = dfs[1] + assert_frame_equal( + pl.scan_parquet(paths) + .slice(1, 1) + .collect(engine="streaming" if streaming else "in-memory"), + df, + ) + assert_frame_equal( + pl.scan_parquet(paths[1:]) + .head(1) + .collect(engine="streaming" if streaming else "in-memory"), + df, + ) + assert_frame_equal( + ( + pl.scan_parquet([paths[1], paths[1], paths[1]]) + .with_row_index() + .slice(1, 1) + .collect(engine="streaming" if streaming else "in-memory") + ), + df.with_row_index(offset=1), + ) + assert_frame_equal( + ( + pl.scan_parquet([paths[1], paths[1], paths[1]]) + .with_row_index(offset=1) + .slice(1, 1) + .collect(engine="streaming" if streaming else "in-memory") + ), + df.with_row_index(offset=2), + ) + assert_frame_equal( + pl.scan_parquet(paths[1:]) + .head(1) + .collect(engine="streaming" if streaming else "in-memory"), + df, + ) + + # Negative slice unsupported in streaming + if not streaming: + assert_frame_equal(pl.scan_parquet(paths).slice(-2, 1).collect(), df) + assert_frame_equal(pl.scan_parquet(paths[:2]).tail(1).collect(), df) + assert_frame_equal( + pl.scan_parquet(paths[1:]).slice(-99, 1).collect(), df.clear() + ) + + path = tmp_path / "data" + df = pl.select(x=pl.int_range(0, 50)) + df.write_parquet(path) + assert_frame_equal(pl.scan_parquet(path).slice(-100, 75).collect(), df.head(25)) + assert_frame_equal( + pl.scan_parquet([path, path]).with_row_index().slice(-25, 100).collect(), + pl.concat([df, df]).with_row_index().slice(75), + ) + assert_frame_equal( + pl.scan_parquet([path, path]) + .with_row_index(offset=10) + .slice(-25, 100) + .collect(), + pl.concat([df, df]).with_row_index(offset=10).slice(75), + ) + assert_frame_equal( + pl.scan_parquet(path).slice(-1, (1 << 32) - 1).collect(), df.tail(1) + ) + + +@pytest.mark.write_disk +def test_predicate_slice_pushdown_row_index_20485(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "slice_pushdown.parquet" + row_group_size = 100000 + num_row_groups = 3 + + df = pl.select(ref=pl.int_range(num_row_groups * row_group_size)) + df.write_parquet(file_path, row_group_size=row_group_size) + + # Use a slice that starts near the end of one row group and extends into the next + # to test handling of slices that span multiple row groups. + slice_start = 199995 + slice_len = 10 + ldf = pl.scan_parquet(file_path) + sliced_df = ldf.with_row_index().slice(slice_start, slice_len).collect() + sliced_df_no_pushdown = ( + ldf.with_row_index().slice(slice_start, slice_len).collect(slice_pushdown=False) + ) + + expected_index = list(range(slice_start, slice_start + slice_len)) + actual_index = list(sliced_df["index"]) + assert actual_index == expected_index + + assert_frame_equal(sliced_df, sliced_df_no_pushdown) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_parquet_row_groups_shift_bug_18739(tmp_path: Path, streaming: bool) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.bin" + + df = pl.DataFrame({"id": range(100)}) + df.write_parquet(path, row_group_size=1) + + lf = pl.scan_parquet(path) + assert_frame_equal(df, lf.collect(engine="streaming" if streaming else "in-memory")) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_dsl2ir_cached_metadata(tmp_path: Path, streaming: bool) -> None: + df = pl.DataFrame({"x": 1}) + path = tmp_path / "1" + df.write_parquet(path) + + lf = pl.scan_parquet(path) + assert_frame_equal(lf.collect(), df) + + # Removes the metadata portion of the parquet file. + # Used to test that a reader doesn't try to read the metadata. + def remove_metadata(path: str | Path) -> None: + path = Path(path) + v = path.read_bytes() + metadata_and_footer_len = 8 + int.from_bytes(v[-8:][:4], "little") + path.write_bytes(v[:-metadata_and_footer_len] + b"PAR1") + + remove_metadata(path) + assert_frame_equal(lf.collect(engine="streaming" if streaming else "in-memory"), df) + + +@pytest.mark.write_disk +def test_parquet_unaligned_schema_read(tmp_path: Path) -> None: + dfs = [ + pl.DataFrame({"a": 1, "b": 10}), + pl.DataFrame({"b": 11, "a": 2}), + pl.DataFrame({"x": 3, "a": 3, "y": 3, "b": 12}), + ] + + paths = [tmp_path / "1", tmp_path / "2", tmp_path / "3"] + + for df, path in zip(dfs, paths): + df.write_parquet(path) + + lf = pl.scan_parquet(paths) + + assert_frame_equal( + lf.select("a").collect(engine="in-memory"), + pl.DataFrame({"a": [1, 2, 3]}), + ) + + assert_frame_equal( + lf.with_row_index().select("a").collect(engine="in-memory"), + pl.DataFrame({"a": [1, 2, 3]}), + ) + + assert_frame_equal( + lf.select("b", "a").collect(engine="in-memory"), + pl.DataFrame({"b": [10, 11, 12], "a": [1, 2, 3]}), + ) + + assert_frame_equal( + pl.scan_parquet(paths[:2]).collect(engine="in-memory"), + pl.DataFrame({"a": [1, 2], "b": [10, 11]}), + ) + + with pytest.raises(pl.exceptions.SchemaError): + lf.collect(engine="in-memory") + + with pytest.raises(pl.exceptions.SchemaError): + lf.with_row_index().collect(engine="in-memory") + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_parquet_unaligned_schema_read_dtype_mismatch( + tmp_path: Path, streaming: bool +) -> None: + dfs = [ + pl.DataFrame({"a": 1, "b": 10}), + pl.DataFrame({"b": "11", "a": "2"}), + ] + + paths = [tmp_path / "1", tmp_path / "2"] + + for df, path in zip(dfs, paths): + df.write_parquet(path) + + lf = pl.scan_parquet(paths) + + with pytest.raises(pl.exceptions.SchemaError, match="data type mismatch"): + lf.collect(engine="streaming" if streaming else "in-memory") + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [True, False]) +def test_parquet_unaligned_schema_read_missing_cols_from_first( + tmp_path: Path, streaming: bool +) -> None: + dfs = [ + pl.DataFrame({"a": 1, "b": 10}), + pl.DataFrame({"b": 11}), + ] + + paths = [tmp_path / "1", tmp_path / "2"] + + for df, path in zip(dfs, paths): + df.write_parquet(path) + + lf = pl.scan_parquet(paths) + + with pytest.raises( + (pl.exceptions.SchemaError, pl.exceptions.ColumnNotFoundError), + ): + lf.collect(engine="streaming" if streaming else "in-memory") + + +@pytest.mark.parametrize("parallel", ["columns", "row_groups", "prefiltered", "none"]) +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.write_disk +def test_parquet_schema_arg( + tmp_path: Path, + parallel: ParallelStrategy, + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "b": 2})] + paths = [tmp_path / "1", tmp_path / "2"] + + for df, path in zip(dfs, paths): + df.write_parquet(path) + + schema: dict[str, pl.DataType] = { + "1": pl.Datetime(time_unit="ms", time_zone="CET"), + "a": pl.Int64(), + "b": pl.Int64(), + } + + # Test `schema` containing an extra column. + + lf = pl.scan_parquet(paths, parallel=parallel, schema=schema) + + with pytest.raises((pl.exceptions.SchemaError, pl.exceptions.ColumnNotFoundError)): + lf.collect(engine="streaming" if streaming else "in-memory") + + lf = pl.scan_parquet( + paths, parallel=parallel, schema=schema, allow_missing_columns=True + ) + + assert_frame_equal( + lf.collect(engine="streaming" if streaming else "in-memory"), + pl.DataFrame({"1": None, "a": [1, 2], "b": [1, 2]}, schema=schema), + ) + + # Just one test that `read_parquet` is propagating this argument. + assert_frame_equal( + pl.read_parquet( + paths, parallel=parallel, schema=schema, allow_missing_columns=True + ), + pl.DataFrame({"1": None, "a": [1, 2], "b": [1, 2]}, schema=schema), + ) + + # Issue #19081: If a schema arg is passed, ensure its fields are propagated + # to the IR, otherwise even if `allow_missing_columns=True`, downstream + # `select()`s etc. will fail with ColumnNotFound if the column is not in + # the first file. + lf = pl.scan_parquet( + paths, parallel=parallel, schema=schema, allow_missing_columns=True + ).select("1") + + s = lf.collect(engine="streaming" if streaming else "in-memory").to_series() + assert s.len() == 2 + assert s.null_count() == 2 + + # Test files containing extra columns not in `schema` + + schema: dict[str, type[pl.DataType]] = {"a": pl.Int64} # type: ignore[no-redef] + + for allow_missing_columns in [True, False]: + lf = pl.scan_parquet( + paths, + parallel=parallel, + schema=schema, + allow_missing_columns=allow_missing_columns, + ) + + with pytest.raises(pl.exceptions.SchemaError): + lf.collect(engine="streaming" if streaming else "in-memory") + + lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a") + + assert_frame_equal( + lf.collect(engine="in-memory"), + pl.DataFrame({"a": [1, 2]}, schema=schema), + ) + + schema: dict[str, type[pl.DataType]] = {"a": pl.Int64, "b": pl.Int8} # type: ignore[no-redef] + + lf = pl.scan_parquet(paths, parallel=parallel, schema=schema) + + with pytest.raises( + pl.exceptions.SchemaError, + match="data type mismatch for column b: expected: i8, found: i64", + ): + lf.collect(engine="streaming" if streaming else "in-memory") + + +def test_scan_parquet_schema_specified_with_empty_files_list(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + assert_frame_equal( + pl.scan_parquet(tmp_path, schema={"x": pl.Int64}).collect(), + pl.DataFrame(schema={"x": pl.Int64}), + ) + + assert_frame_equal( + pl.scan_parquet(tmp_path, schema={"x": pl.Int64}).with_row_index().collect(), + pl.DataFrame(schema={"x": pl.Int64}).with_row_index(), + ) + + assert_frame_equal( + pl.scan_parquet( + tmp_path, schema={"x": pl.Int64}, hive_schema={"h": pl.String} + ).collect(), + pl.DataFrame(schema={"x": pl.Int64, "h": pl.String}), + ) + + assert_frame_equal( + ( + pl.scan_parquet( + tmp_path, schema={"x": pl.Int64}, hive_schema={"h": pl.String} + ) + .with_row_index() + .collect() + ), + pl.DataFrame(schema={"x": pl.Int64, "h": pl.String}).with_row_index(), + ) + + +@pytest.mark.parametrize("allow_missing_columns", [True, False]) +@pytest.mark.write_disk +def test_scan_parquet_ignores_dtype_mismatch_for_non_projected_columns_19249( + tmp_path: Path, + allow_missing_columns: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + paths = [tmp_path / "1", tmp_path / "2"] + + pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet( + paths[0] + ) + pl.DataFrame( + {"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64} + ).write_parquet(paths[1]) + + assert_frame_equal( + pl.scan_parquet(paths, allow_missing_columns=allow_missing_columns) + .select("a") + .collect(engine="in-memory"), + pl.DataFrame({"a": [1, 1]}, schema={"a": pl.Int32}), + ) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.write_disk +def test_scan_parquet_streaming_row_index_19606( + tmp_path: Path, streaming: bool +) -> None: + tmp_path.mkdir(exist_ok=True) + paths = [tmp_path / "1", tmp_path / "2"] + + dfs = [pl.DataFrame({"x": i}) for i in range(len(paths))] + + for df, p in zip(dfs, paths): + df.write_parquet(p) + + assert_frame_equal( + pl.scan_parquet(tmp_path) + .with_row_index() + .collect(engine="streaming" if streaming else "in-memory"), + pl.DataFrame( + {"index": [0, 1], "x": [0, 1]}, schema={"index": pl.UInt32, "x": pl.Int64} + ), + ) diff --git a/py-polars/tests/unit/io/test_multiscan.py b/py-polars/tests/unit/io/test_multiscan.py new file mode 100644 index 000000000000..09a8508f2e35 --- /dev/null +++ b/py-polars/tests/unit/io/test_multiscan.py @@ -0,0 +1,625 @@ +from __future__ import annotations + +import io +from typing import TYPE_CHECKING, Any, Callable + +import pytest +from hypothesis import given +from hypothesis import strategies as st + +import polars as pl +from polars.meta.index_type import get_index_type +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("scan", "write"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +def test_include_file_paths(tmp_path: Path, scan: Any, write: Any) -> None: + a_path = tmp_path / "a" + b_path = tmp_path / "b" + + write(pl.DataFrame({"a": [5, 10]}), a_path) + write(pl.DataFrame({"a": [1996]}), b_path) + + out = scan([a_path, b_path], include_file_paths="f") + + assert_frame_equal( + out.collect(), + pl.DataFrame( + { + "a": [5, 10, 1996], + "f": [str(a_path), str(a_path), str(b_path)], + } + ), + ) + + +@pytest.mark.parametrize( + ("scan", "write", "ext", "supports_missing_columns", "supports_hive_partitioning"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc", False, True), + (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet", True, True), + (pl.scan_csv, pl.DataFrame.write_csv, "csv", False, False), + (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl", False, False), + ], +) +@pytest.mark.parametrize("missing_column", [False, True]) +@pytest.mark.parametrize("row_index", [False, True]) +@pytest.mark.parametrize("include_file_paths", [False, True]) +@pytest.mark.parametrize("hive", [False, True]) +@pytest.mark.parametrize("col", [False, True]) +@pytest.mark.write_disk +def test_multiscan_projection( + tmp_path: Path, + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, Path], Any], + ext: str, + supports_missing_columns: bool, + supports_hive_partitioning: bool, + missing_column: bool, + row_index: bool, + include_file_paths: bool, + hive: bool, + col: bool, +) -> None: + a = pl.DataFrame({"col": [5, 10, 1996]}) + b = pl.DataFrame({"col": [13, 37]}) + + if missing_column and supports_missing_columns: + a = a.with_columns(missing=pl.Series([420, 2000, 9])) + + a_path: Path + b_path: Path + multiscan_path: Path + + if hive and supports_hive_partitioning: + (tmp_path / "hive_col=0").mkdir() + a_path = tmp_path / "hive_col=0" / f"a.{ext}" + (tmp_path / "hive_col=1").mkdir() + b_path = tmp_path / "hive_col=1" / f"b.{ext}" + + multiscan_path = tmp_path + + else: + a_path = tmp_path / f"a.{ext}" + b_path = tmp_path / f"b.{ext}" + + multiscan_path = tmp_path / f"*.{ext}" + + write(a, a_path) + write(b, b_path) + + base_projection = [] + if missing_column and supports_missing_columns: + base_projection += ["missing"] + if row_index: + base_projection += ["row_index"] + if include_file_paths: + base_projection += ["file_path"] + if hive and supports_hive_partitioning: + base_projection += ["hive_col"] + if col: + base_projection += ["col"] + + ifp = "file_path" if include_file_paths else None + ri = "row_index" if row_index else None + + args = { + "allow_missing_columns": missing_column, + "include_file_paths": ifp, + "row_index_name": ri, + "hive_partitioning": hive, + } + + if not supports_missing_columns: + del args["allow_missing_columns"] + if not supports_hive_partitioning: + del args["hive_partitioning"] + + for projection in [ + base_projection, + base_projection[::-1], + ]: + assert_frame_equal( + scan(multiscan_path, **args).collect(engine="streaming").select(projection), + scan(multiscan_path, **args).select(projection).collect(engine="streaming"), + ) + + for remove in range(len(base_projection)): + new_projection = base_projection.copy() + new_projection.pop(remove) + + for projection in [ + new_projection, + new_projection[::-1], + ]: + print(projection) + assert_frame_equal( + scan(multiscan_path, **args) + .collect(engine="streaming") + .select(projection), + scan(multiscan_path, **args) + .select(projection) + .collect(engine="streaming"), + ) + + +@pytest.mark.parametrize( + ("scan", "write", "ext"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), + (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), + ], +) +@pytest.mark.write_disk +def test_multiscan_hive_predicate( + tmp_path: Path, + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, Path], Any], + ext: str, +) -> None: + a = pl.DataFrame({"col": [5, 10, 1996]}) + b = pl.DataFrame({"col": [13, 37]}) + c = pl.DataFrame({"col": [3, 5, 2024]}) + + (tmp_path / "hive_col=0").mkdir() + a_path = tmp_path / "hive_col=0" / f"0.{ext}" + (tmp_path / "hive_col=1").mkdir() + b_path = tmp_path / "hive_col=1" / f"0.{ext}" + (tmp_path / "hive_col=2").mkdir() + c_path = tmp_path / "hive_col=2" / f"0.{ext}" + + multiscan_path = tmp_path + + write(a, a_path) + write(b, b_path) + write(c, c_path) + + full = scan(multiscan_path).collect(engine="streaming") + full_ri = full.with_row_index("ri", 42) + + last_pred = None + try: + for pred in [ + pl.col.hive_col == 0, + pl.col.hive_col == 1, + pl.col.hive_col == 2, + pl.col.hive_col < 2, + pl.col.hive_col > 0, + pl.col.hive_col != 1, + pl.col.hive_col != 3, + pl.col.col == 13, + pl.col.col != 13, + (pl.col.col != 13) & (pl.col.hive_col == 1), + (pl.col.col != 13) & (pl.col.hive_col != 1), + ]: + last_pred = pred + assert_frame_equal( + full.filter(pred), + scan(multiscan_path).filter(pred).collect(engine="streaming"), + ) + + assert_frame_equal( + full_ri.filter(pred), + scan(multiscan_path) + .with_row_index("ri", 42) + .filter(pred) + .collect(engine="streaming"), + ) + except Exception as _: + print(last_pred) + raise + + +@pytest.mark.parametrize( + ("scan", "write", "ext"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), + (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), + (pl.scan_csv, pl.DataFrame.write_csv, "csv"), + (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"), + ], +) +@pytest.mark.write_disk +def test_multiscan_row_index( + tmp_path: Path, + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, Path], Any], + ext: str, +) -> None: + a = pl.DataFrame({"col": [5, 10, 1996]}) + b = pl.DataFrame({"col": [42]}) + c = pl.DataFrame({"col": [13, 37]}) + + write(a, tmp_path / f"a.{ext}") + write(b, tmp_path / f"b.{ext}") + write(c, tmp_path / f"c.{ext}") + + col = pl.concat([a, b, c]).to_series() + g = tmp_path / f"*.{ext}" + + assert_frame_equal( + scan(g, row_index_name="ri").collect(), + pl.DataFrame( + [ + pl.Series("ri", range(6), get_index_type()), + col, + ] + ), + ) + + start = 42 + assert_frame_equal( + scan(g, row_index_name="ri", row_index_offset=start).collect(), + pl.DataFrame( + [ + pl.Series("ri", range(start, start + 6), get_index_type()), + col, + ] + ), + ) + + start = 42 + assert_frame_equal( + scan(g, row_index_name="ri", row_index_offset=start).slice(3, 3).collect(), + pl.DataFrame( + [ + pl.Series("ri", range(start + 3, start + 6), get_index_type()), + col.slice(3, 3), + ] + ), + ) + + start = 42 + assert_frame_equal( + scan(g, row_index_name="ri", row_index_offset=start) + .filter(pl.col("col") < 15) + .collect(), + pl.DataFrame( + [ + pl.Series("ri", [start + 0, start + 1, start + 4], get_index_type()), + pl.Series("col", [5, 10, 13]), + ] + ), + ) + + +@pytest.mark.parametrize( + ("scan", "write", "ext"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), + (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), + pytest.param( + pl.scan_csv, + pl.DataFrame.write_csv, + "csv", + marks=pytest.mark.xfail( + reason="See https://github.com/pola-rs/polars/issues/21211" + ), + ), + (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"), + ], +) +@pytest.mark.write_disk +def test_schema_mismatch_type_mismatch( + tmp_path: Path, + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, Path], Any], + ext: str, +) -> None: + a = pl.DataFrame({"xyz_col": [5, 10, 1996]}) + b = pl.DataFrame({"xyz_col": ["a", "b", "c"]}) + + a_path = tmp_path / f"a.{ext}" + b_path = tmp_path / f"b.{ext}" + + multiscan_path = tmp_path / f"*.{ext}" + + write(a, a_path) + write(b, b_path) + + q = scan(multiscan_path) + + # NDJSON will just parse according to `projected_schema` + cx = ( + pytest.raises(pl.exceptions.ComputeError, match="cannot parse 'a' as Int64") + if scan is pl.scan_ndjson + else pytest.raises( + pl.exceptions.SchemaError, + match="data type mismatch for column xyz_col: expected: i64, found: str", + ) + ) + + with cx: + q.collect(engine="streaming") + + +@pytest.mark.parametrize( + ("scan", "write", "ext"), + [ + # (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), # TODO: _ + # (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), # TODO: _ + pytest.param( + pl.scan_csv, + pl.DataFrame.write_csv, + "csv", + marks=pytest.mark.xfail( + reason="See https://github.com/pola-rs/polars/issues/21211" + ), + ), + # (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"), # TODO: _ + ], +) +@pytest.mark.write_disk +def test_schema_mismatch_order_mismatch( + tmp_path: Path, + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, Path], Any], + ext: str, +) -> None: + a = pl.DataFrame({"x": [5, 10, 1996], "y": ["a", "b", "c"]}) + b = pl.DataFrame({"y": ["x", "y"], "x": [1, 2]}) + + a_path = tmp_path / f"a.{ext}" + b_path = tmp_path / f"b.{ext}" + + multiscan_path = tmp_path / f"*.{ext}" + + write(a, a_path) + write(b, b_path) + + q = scan(multiscan_path) + + with pytest.raises(pl.exceptions.SchemaError): + q.collect(engine="streaming") + + +@pytest.mark.parametrize( + ("scan", "write"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_parquet, pl.DataFrame.write_parquet), + ( + pl.scan_csv, + pl.DataFrame.write_csv, + ), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +def test_multiscan_head( + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, io.BytesIO | Path], Any], +) -> None: + a = io.BytesIO() + b = io.BytesIO() + for f in [a, b]: + write(pl.Series("c1", range(10)).to_frame(), f) + f.seek(0) + + assert_frame_equal( + scan([a, b]).head(5).collect(engine="streaming"), + pl.Series("c1", range(5)).to_frame(), + ) + + +@pytest.mark.parametrize( + ("scan", "write"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ( + pl.scan_csv, + pl.DataFrame.write_csv, + ), + ], +) +def test_multiscan_tail( + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, io.BytesIO | Path], Any], +) -> None: + a = io.BytesIO() + b = io.BytesIO() + for f in [a, b]: + write(pl.Series("c1", range(10)).to_frame(), f) + f.seek(0) + + assert_frame_equal( + scan([a, b]).tail(5).collect(engine="streaming"), + pl.Series("c1", range(5, 10)).to_frame(), + ) + + +@pytest.mark.parametrize( + ("scan", "write"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ( + pl.scan_csv, + pl.DataFrame.write_csv, + ), + ], +) +def test_multiscan_slice_middle( + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, io.BytesIO | Path], Any], +) -> None: + fs = [io.BytesIO() for _ in range(13)] + for f in fs: + write(pl.Series("c1", range(7)).to_frame(), f) + f.seek(0) + + offset = 5 * 7 - 5 + expected = ( + list(range(2, 7)) # fs[4] + + list(range(7)) # fs[5] + + list(range(5)) # fs[6] + ) + expected_series = [pl.Series("c1", expected)] + ri_expected_series = [ + pl.Series("ri", range(offset, offset + 17), get_index_type()) + ] + expected_series + + assert_frame_equal( + scan(fs).slice(offset, 17).collect(engine="streaming"), + pl.DataFrame(expected_series), + ) + assert_frame_equal( + scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"), + pl.DataFrame(ri_expected_series), + ) + + # Negative slices + offset = -(13 * 7 - offset) + assert_frame_equal( + scan(fs).slice(offset, 17).collect(engine="streaming"), + pl.DataFrame(expected_series), + ) + assert_frame_equal( + scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"), + pl.DataFrame(ri_expected_series), + ) + + +@pytest.mark.parametrize( + ("scan", "write", "ext"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), + (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), + (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"), + (pl.scan_csv, pl.DataFrame.write_csv, "csv"), + ], +) +@given(offset=st.integers(-100, 100), length=st.integers(0, 101)) +def test_multiscan_slice_parametric( + scan: Callable[..., pl.LazyFrame], + write: Callable[[pl.DataFrame, io.BytesIO | Path], Any], + ext: str, + offset: int, + length: int, +) -> None: + ref = io.BytesIO() + write(pl.Series("c1", [i % 7 for i in range(13 * 7)]).to_frame(), ref) + ref.seek(0) + + fs = [io.BytesIO() for _ in range(13)] + for f in fs: + write(pl.Series("c1", range(7)).to_frame(), f) + f.seek(0) + + assert_frame_equal( + scan(ref).slice(offset, length).collect(), + scan(fs).slice(offset, length).collect(engine="streaming"), + ) + + ref.seek(0) + for f in fs: + f.seek(0) + + assert_frame_equal( + scan(ref, row_index_name="ri", row_index_offset=42) + .slice(offset, length) + .collect(), + scan(fs, row_index_name="ri", row_index_offset=42) + .slice(offset, length) + .collect(engine="streaming"), + ) + + assert_frame_equal( + scan(ref, row_index_name="ri", row_index_offset=42) + .slice(offset, length) + .select("ri") + .collect(), + scan(fs, row_index_name="ri", row_index_offset=42) + .slice(offset, length) + .select("ri") + .collect(engine="streaming"), + ) + + +@pytest.mark.parametrize( + ("scan", "write"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +def test_many_files(tmp_path: Path, scan: Any, write: Any) -> None: + f = io.BytesIO() + write(pl.DataFrame({"a": [5, 10, 1996]}), f) + bs = f.getvalue() + + out = scan([bs] * 1023) + + assert_frame_equal( + out.collect(), + pl.DataFrame( + { + "a": [5, 10, 1996] * 1023, + } + ), + ) + + +def test_deadlock_stop_requested(tmp_path: Path, monkeypatch: Any) -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + } + ) + + f = io.BytesIO() + df.write_parquet(f, row_group_size=1) + + monkeypatch.setenv("POLARS_MAX_THREADS", "2") + monkeypatch.setenv("POLARS_JOIN_SAMPLE_LIMIT", "1") + + left_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)] + right_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)] + + left = pl.scan_parquet(left_fs) # type: ignore[arg-type] + right = pl.scan_parquet(right_fs) # type: ignore[arg-type] + + left.join(right, pl.col.a == pl.col.a).collect(engine="streaming").height + + +@pytest.mark.parametrize( + ("scan", "write"), + [ + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +def test_deadlock_linearize(scan: Any, write: Any) -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + } + ) + + f = io.BytesIO() + write(df, f) + fs = [io.BytesIO(f.getbuffer()) for _ in range(10)] + lf = scan(fs).head(100) + + assert_frame_equal( + lf.collect(engine="streaming", slice_pushdown=False), + pl.concat([df] * 10), + ) diff --git a/py-polars/tests/unit/io/test_other.py b/py-polars/tests/unit/io/test_other.py new file mode 100644 index 000000000000..1b5baf0e2100 --- /dev/null +++ b/py-polars/tests/unit/io/test_other.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import copy +import sys +from pathlib import Path +from typing import Any, Callable, cast + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.mark.parametrize( + "read_function", + [ + pl.read_csv, + pl.read_ipc, + pl.read_json, + pl.read_parquet, + pl.read_avro, + pl.scan_csv, + pl.scan_ipc, + pl.scan_parquet, + ], +) +def test_read_missing_file(read_function: Callable[[Any], pl.DataFrame]) -> None: + match = "\\(os error 2\\): fake_file_path" + # The message associated with OS error 2 may differ per platform + if sys.platform == "linux": + match = "No such file or directory " + match + + if "scan" in read_function.__name__: + with pytest.raises(FileNotFoundError, match=match): + read_function("fake_file_path").collect() # type: ignore[attr-defined] + else: + with pytest.raises(FileNotFoundError, match=match): + read_function("fake_file_path") + + +@pytest.mark.parametrize( + "write_method_name", + [ + # "write_excel" not included + # because it already raises a FileCreateError + # from the underlying library dependency + "write_csv", + "write_ipc", + "write_ipc_stream", + "write_json", + "write_ndjson", + "write_parquet", + "write_avro", + ], +) +def test_write_missing_directory(write_method_name: str) -> None: + df = pl.DataFrame({"a": [1]}) + non_existing_path = Path("non", "existing", "path") + if non_existing_path.exists(): + pytest.fail( + "Testing on a non existing path failed because the path does exist." + ) + write_method = getattr(df, write_method_name) + with pytest.raises(FileNotFoundError): + write_method(non_existing_path) + + +def test_read_missing_file_path_truncated() -> None: + content = "lskdfj".join(str(i) for i in range(25)) + with pytest.raises( + FileNotFoundError, + match="\\.\\.\\.lskdfj14lskdfj15lskdfj16lskdfj17lskdfj18lskdfj19lskdfj20lskdfj21lskdfj22lskdfj23lskdfj24", + ): + pl.read_csv(content) + + +def test_copy() -> None: + df = pl.DataFrame({"a": [1, 2], "b": ["a", None], "c": [True, False]}) + assert_frame_equal(copy.copy(df), df) + assert_frame_equal(copy.deepcopy(df), df) + + a = pl.Series("a", [1, 2]) + assert_series_equal(copy.copy(a), a) + assert_series_equal(copy.deepcopy(a), a) + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_round_trip() -> None: + df = pl.DataFrame({"ints": [1, 2, 3], "cat": ["a", "b", "c"]}) + df = df.with_columns(pl.col("cat").cast(pl.Categorical)) + + tbl = df.to_arrow() + assert "dictionary" in str(tbl["cat"].type) + + df2 = cast(pl.DataFrame, pl.from_arrow(tbl)) + assert df2.dtypes == [pl.Int64, pl.Categorical] + + +def test_from_different_chunks() -> None: + s0 = pl.Series("a", [1, 2, 3, 4, None]) + s1 = pl.Series("b", [1, 2]) + s11 = pl.Series("b", [1, 2, 3]) + s1.append(s11) + + # check we don't panic + df = pl.DataFrame([s0, s1]) + df.to_arrow() + df = pl.DataFrame([s0, s1]) + out = df.to_pandas() + assert list(out.columns) == ["a", "b"] + assert out.shape == (5, 2) + + +def test_unit_io_subdir_has_no_init() -> None: + # -------------------------------------------------------------------------------- + # If this test fails it means an '__init__.py' was added to 'tests/unit/io'. + # See https://github.com/pola-rs/polars/pull/6889 for why this can cause issues. + # -------------------------------------------------------------------------------- + # TLDR: it can mask the builtin 'io' module, causing a fatal python error. + # -------------------------------------------------------------------------------- + io_dir = Path(__file__).parent + assert io_dir.parts[-2:] == ("unit", "io") + assert not (io_dir / "__init__.py").exists(), ( + "Found undesirable '__init__.py' in the 'unit.io' tests subdirectory" + ) + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("scan_funcs", "write_func"), + [ + ([pl.scan_parquet, pl.read_parquet], pl.DataFrame.write_parquet), + ([pl.scan_csv, pl.read_csv], pl.DataFrame.write_csv), + ], +) +@pytest.mark.parametrize("char", ["[", "*"]) +def test_no_glob( + scan_funcs: list[Callable[[Any], pl.LazyFrame | pl.DataFrame]], + write_func: Callable[[pl.DataFrame, Path], None], + char: str, + tmp_path: Path, +) -> None: + if sys.platform == "win32" and char == "*": + pytest.skip("unsupported glob char for windows") + + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": 1}) + + paths = [tmp_path / f"{char}", tmp_path / f"{char}1"] + + write_func(df, paths[0]) + write_func(df, paths[1]) + + for func in scan_funcs: + assert_frame_equal(func(paths[0], glob=False).lazy().collect(), df) # type: ignore[call-arg] diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py new file mode 100644 index 000000000000..2f6888b5f6dc --- /dev/null +++ b/py-polars/tests/unit/io/test_parquet.py @@ -0,0 +1,3203 @@ +from __future__ import annotations + +import decimal +import functools +import io +from datetime import date, datetime, time, timezone +from decimal import Decimal +from itertools import chain +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +import fsspec +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.dataset as ds +import pyarrow.parquet as pq +import pytest +from hypothesis import given +from hypothesis import strategies as st + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes +from polars.testing.parametric.strategies.core import series + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import ParallelStrategy, ParquetCompression + from tests.unit.conftest import MemoryUsage + + +@pytest.mark.may_fail_auto_streaming +def test_round_trip(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@pytest.mark.may_fail_auto_streaming +def test_scan_round_trip(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + assert_frame_equal(pl.scan_parquet(f).collect(), df) + f.seek(0) + assert_frame_equal(pl.scan_parquet(f).head().collect(), df.head()) + + +COMPRESSIONS = [ + "lz4", + "uncompressed", + "snappy", + "gzip", + # "lzo", # LZO compression currently not supported by Arrow backend + "brotli", + "zstd", +] + + +@pytest.mark.write_disk +def test_write_parquet_using_pyarrow_9753(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [1, 2, 3]}) + df.write_parquet( + tmp_path / "test.parquet", + compression="zstd", + statistics=True, + use_pyarrow=True, + pyarrow_options={"coerce_timestamps": "us"}, + ) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +def test_write_parquet_using_pyarrow_write_to_dataset_with_partitioning( + tmp_path: Path, + compression: ParquetCompression, +) -> None: + df = pl.DataFrame({"a": [1, 2, 3], "partition_col": ["one", "two", "two"]}) + path_to_write = tmp_path / "test_folder" + path_to_write.mkdir(exist_ok=True) + df.write_parquet( + file=path_to_write, + statistics=True, + use_pyarrow=True, + row_group_size=128, + pyarrow_options={ + "partition_cols": ["partition_col"], + "compression": compression, + }, + ) + + # cast is necessary as pyarrow writes partitions as categorical type + read_df = pl.read_parquet(path_to_write, use_pyarrow=True).with_columns( + pl.col("partition_col").cast(pl.String) + ) + assert_frame_equal(df, read_df) + + +@pytest.fixture +def small_parquet_path(io_files_path: Path) -> Path: + return io_files_path / "small.parquet" + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +@pytest.mark.parametrize("use_pyarrow", [True, False]) +def test_to_from_buffer( + df: pl.DataFrame, compression: ParquetCompression, use_pyarrow: bool +) -> None: + df = df[["list_str"]] + buf = io.BytesIO() + df.write_parquet(buf, compression=compression, use_pyarrow=use_pyarrow) + buf.seek(0) + read_df = pl.read_parquet(buf, use_pyarrow=use_pyarrow) + assert_frame_equal(df, read_df, categorical_as_str=True) + + +@pytest.mark.parametrize("use_pyarrow", [True, False]) +@pytest.mark.parametrize("rechunk_and_expected_chunks", [(True, 1), (False, 3)]) +@pytest.mark.may_fail_auto_streaming +def test_read_parquet_respects_rechunk_16416( + use_pyarrow: bool, rechunk_and_expected_chunks: tuple[bool, int] +) -> None: + # Create a dataframe with 3 chunks: + df = pl.DataFrame({"a": [1]}) + df = pl.concat([df, df, df]) + buf = io.BytesIO() + df.write_parquet(buf, row_group_size=1) + buf.seek(0) + + rechunk, expected_chunks = rechunk_and_expected_chunks + result = pl.read_parquet(buf, use_pyarrow=use_pyarrow, rechunk=rechunk) + assert result.n_chunks() == expected_chunks + + +def test_to_from_buffer_lzo(df: pl.DataFrame) -> None: + buf = io.BytesIO() + # Writing lzo compressed parquet files is not supported for now. + with pytest.raises(ComputeError): + df.write_parquet(buf, compression="lzo", use_pyarrow=False) + buf.seek(0) + # Invalid parquet file as writing failed. + with pytest.raises(ComputeError): + _ = pl.read_parquet(buf) + + buf = io.BytesIO() + with pytest.raises(OSError): + # Writing lzo compressed parquet files is not supported for now. + df.write_parquet(buf, compression="lzo", use_pyarrow=True) + buf.seek(0) + # Invalid parquet file as writing failed. + with pytest.raises(ComputeError): + _ = pl.read_parquet(buf) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("compression", COMPRESSIONS) +def test_to_from_file( + df: pl.DataFrame, compression: ParquetCompression, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.avro" + df.write_parquet(file_path, compression=compression) + read_df = pl.read_parquet(file_path) + assert_frame_equal(df, read_df, categorical_as_str=True) + + +@pytest.mark.write_disk +def test_to_from_file_lzo(df: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.avro" + + # Writing lzo compressed parquet files is not supported for now. + with pytest.raises(ComputeError): + df.write_parquet(file_path, compression="lzo", use_pyarrow=False) + # Invalid parquet file as writing failed. + with pytest.raises(ComputeError): + _ = pl.read_parquet(file_path) + + # Writing lzo compressed parquet files is not supported for now. + with pytest.raises(OSError): + df.write_parquet(file_path, compression="lzo", use_pyarrow=True) + # Invalid parquet file as writing failed. + with pytest.raises(FileNotFoundError): + _ = pl.read_parquet(file_path) + + +def test_select_columns() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) + expected = pl.DataFrame({"b": [True, False, True], "c": ["a", "b", "c"]}) + + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + + read_df = pl.read_parquet(f, columns=["b", "c"], use_pyarrow=False) + assert_frame_equal(expected, read_df) + + +def test_select_projection() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) + expected = pl.DataFrame({"b": [True, False, True], "c": ["a", "b", "c"]}) + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + + read_df = pl.read_parquet(f, columns=[1, 2], use_pyarrow=False) + assert_frame_equal(expected, read_df) + + +@pytest.mark.parametrize("compression", COMPRESSIONS) +@pytest.mark.parametrize("use_pyarrow", [True, False]) +def test_parquet_datetime(compression: ParquetCompression, use_pyarrow: bool) -> None: + # This failed because parquet writers cast datetime to Date + f = io.BytesIO() + data = { + "datetime": [ # unix timestamp in ms + 1618354800000, + 1618354740000, + 1618354680000, + 1618354620000, + 1618354560000, + ], + "value1": [73.1999969482, 71.0999984741, 74.5, 69.5999984741, 69.6999969482], + "value2": [59.5999984741, 61.0, 62.2999992371, 56.9000015259, 60.0], + } + df = pl.DataFrame(data) + df = df.with_columns(df["datetime"].cast(pl.Datetime)) + + df.write_parquet(f, use_pyarrow=use_pyarrow, compression=compression) + f.seek(0) + read = pl.read_parquet(f) + assert_frame_equal(read, df) + + +def test_nested_parquet() -> None: + f = io.BytesIO() + data = [ + {"a": [{"b": 0}]}, + {"a": [{"b": 1}, {"b": 2}]}, + ] + df = pd.DataFrame(data) + df.to_parquet(f) + + read = pl.read_parquet(f, use_pyarrow=True) + assert read.columns == ["a"] + assert isinstance(read.dtypes[0], pl.datatypes.List) + assert isinstance(read.dtypes[0].inner, pl.datatypes.Struct) + + +@pytest.mark.write_disk +def test_glob_parquet(df: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "small.parquet" + df.write_parquet(file_path) + + path_glob = tmp_path / "small*.parquet" + assert pl.read_parquet(path_glob).shape == (3, df.width) + assert pl.scan_parquet(path_glob).collect().shape == (3, df.width) + + +def test_chunked_round_trip() -> None: + df1 = pl.DataFrame( + { + "a": [1] * 2, + "l": [[1] for j in range(2)], + } + ) + df2 = pl.DataFrame( + { + "a": [2] * 3, + "l": [[2] for j in range(3)], + } + ) + + df = df1.vstack(df2) + + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@pytest.mark.write_disk +def test_lazy_self_join_file_cache_prop_3979(df: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.parquet" + df.write_parquet(file_path) + + a = pl.scan_parquet(file_path) + b = pl.DataFrame({"a": [1]}).lazy() + + expected_shape = (3, df.width + b.collect_schema().len()) + assert a.join(b, how="cross").collect().shape == expected_shape + assert b.join(a, how="cross").collect().shape == expected_shape + + +def test_recursive_logical_type() -> None: + df = pl.DataFrame({"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}) + df = df.with_columns(pl.col("str").cast(pl.Categorical)) + + df_groups = df.group_by("group").agg([pl.col("str").alias("cat_list")]) + f = io.BytesIO() + df_groups.write_parquet(f, use_pyarrow=True) + f.seek(0) + read = pl.read_parquet(f, use_pyarrow=True) + assert read.dtypes == [pl.Int64, pl.List(pl.Categorical)] + assert read.shape == (2, 2) + + +def test_nested_dictionary() -> None: + with pl.StringCache(): + df = ( + pl.DataFrame({"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]}) + .with_columns(pl.col("str").cast(pl.Categorical)) + .group_by("group") + .agg([pl.col("str").alias("cat_list")]) + ) + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + + read_df = pl.read_parquet(f) + assert_frame_equal(df, read_df) + + +def test_row_group_size_saturation() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + f = io.BytesIO() + + # request larger chunk than rows in df + df.write_parquet(f, row_group_size=1024) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +def test_nested_sliced() -> None: + for df in [ + pl.Series([[1, 2], [3, 4], [5, 6]]).slice(2, 2).to_frame(), + pl.Series([[None, 2], [3, 4], [5, 6]]).to_frame(), + pl.Series([[None, 2], [3, 4], [5, 6]]).slice(2, 2).to_frame(), + pl.Series([["a", "a"], ["", "a"], ["c", "de"]]).slice(3, 2).to_frame(), + pl.Series([[None, True], [False, False], [True, True]]).slice(2, 2).to_frame(), + ]: + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +def test_parquet_5795() -> None: + df_pd = pd.DataFrame( + { + "a": [ + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + "V", + None, + None, + None, + None, + None, + None, + ] + } + ) + f = io.BytesIO() + df_pd.to_parquet(f) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), pl.from_pandas(df_pd)) + + +def test_parquet_nesting_structs_list() -> None: + f = io.BytesIO() + df = pl.from_records( + [ + { + "id": 1, + "list_of_structs_col": [ + {"a": 10, "b": [10, 11, 12]}, + {"a": 11, "b": [13, 14, 15]}, + ], + }, + { + "id": 2, + "list_of_structs_col": [ + {"a": 44, "b": [12]}, + ], + }, + ] + ) + + df.write_parquet(f) + f.seek(0) + + assert_frame_equal(pl.read_parquet(f), df) + + +def test_parquet_nested_dictionaries_6217() -> None: + _type = pa.dictionary(pa.int64(), pa.string()) + + fields = [("a_type", _type)] + struct_type = pa.struct(fields) + + col1 = pa.StructArray.from_arrays( + [pa.DictionaryArray.from_arrays([0, 0, 1], ["A", "B"])], + fields=struct_type, + ) + + table = pa.table({"Col1": col1}) + + with pl.StringCache(): + df = pl.from_arrow(table) + + f = io.BytesIO() + import pyarrow.parquet as pq + + pq.write_table(table, f, compression="snappy") + f.seek(0) + read = pl.read_parquet(f) + assert_frame_equal(read, df) # type: ignore[arg-type] + + +@pytest.mark.write_disk +def test_head_union(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df1 = pl.DataFrame({"a": [0, 1, 2], "b": [1, 2, 3]}) + df2 = pl.DataFrame({"a": [3, 4, 5], "b": [4, 5, 6]}) + + file_path_1 = tmp_path / "df_fetch_1.parquet" + file_path_2 = tmp_path / "df_fetch_2.parquet" + file_path_glob = tmp_path / "df_fetch_*.parquet" + + df1.write_parquet(file_path_1) + df2.write_parquet(file_path_2) + + result_one = pl.scan_parquet(file_path_1).head(1).collect() + result_glob = pl.scan_parquet(file_path_glob).head(1).collect() + + expected = pl.DataFrame({"a": [0], "b": [1]}) + assert_frame_equal(result_one, expected) + + # Both fetch 1 per file or 1 per dataset would be ok, as we don't guarantee anything + # currently we have one per dataset. + expected = pl.DataFrame({"a": [0], "b": [1]}) + assert_frame_equal(result_glob, expected) + + +@pytest.mark.slow +def test_struct_pyarrow_dataset_5796(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + num_rows = 2**17 + 1 + + df = pl.from_records([{"id": i, "nested": {"a": i}} for i in range(num_rows)]) + file_path = tmp_path / "out.parquet" + df.write_parquet(file_path, use_pyarrow=True) + tbl = ds.dataset(file_path).to_table() + result = pl.from_arrow(tbl) + + assert_frame_equal(result, df) # type: ignore[arg-type] + + +@pytest.mark.slow +@pytest.mark.parametrize("case", [1048576, 1048577]) +def test_parquet_chunks_545(case: int) -> None: + f = io.BytesIO() + # repeat until it has case instances + df = pd.DataFrame( + np.tile([1.0, pd.to_datetime("2010-10-10")], [case, 1]), + columns=["floats", "dates"], + ) + + # write as parquet + df.to_parquet(f) + f.seek(0) + + # read it with polars + polars_df = pl.read_parquet(f) + assert_frame_equal(pl.DataFrame(df), polars_df) + + +def test_nested_null_roundtrip() -> None: + f = io.BytesIO() + df = pl.DataFrame( + { + "experiences": [ + [ + {"company": "Google", "years": None}, + {"company": "Facebook", "years": None}, + ], + ] + } + ) + + df.write_parquet(f) + f.seek(0) + df_read = pl.read_parquet(f) + assert_frame_equal(df_read, df) + + +def test_parquet_nested_list_pandas() -> None: + # pandas/pyarrow writes as nested null dict + df_pd = pd.DataFrame({"listcol": [[] * 10]}) + f = io.BytesIO() + df_pd.to_parquet(f) + f.seek(0) + df = pl.read_parquet(f) + assert df.dtypes == [pl.List(pl.Null)] + assert df.to_dict(as_series=False) == {"listcol": [[]]} + + +def test_parquet_string_cache() -> None: + f = io.BytesIO() + + df = pl.DataFrame({"a": ["a", "b", "c", "d"]}).with_columns( + pl.col("a").cast(pl.Categorical) + ) + + df.write_parquet(f, row_group_size=2) + + # this file has 2 row groups and a categorical column + # so polars should automatically set string cache + f.seek(0) + assert_series_equal(pl.read_parquet(f)["a"].cast(str), df["a"].cast(str)) + + +def test_tz_aware_parquet_9586(io_files_path: Path) -> None: + result = pl.read_parquet(io_files_path / "tz_aware.parquet") + expected = pl.DataFrame( + {"UTC_DATETIME_ID": [datetime(2023, 6, 26, 14, 15, 0, tzinfo=timezone.utc)]} + ).select(pl.col("*").cast(pl.Datetime("ns", "UTC"))) + assert_frame_equal(result, expected) + + +def test_nested_list_page_reads_to_end_11548() -> None: + df = pl.select( + pl.repeat(pl.arange(0, 2048, dtype=pl.UInt64).implode(), 2).alias("x"), + ) + + f = io.BytesIO() + + pq.write_table(df.to_arrow(), f, data_page_size=1) + + f.seek(0) + + result = pl.read_parquet(f).select(pl.col("x").list.len()) + assert result.to_series().to_list() == [2048, 2048] + + +def test_parquet_nano_second_schema() -> None: + value = time(9, 0, 0) + f = io.BytesIO() + df = pd.DataFrame({"Time": [value]}) + df.to_parquet(f) + f.seek(0) + assert pl.read_parquet(f).item() == value + + +def test_nested_struct_read_12610() -> None: + n = 1_025 + expect = pl.select(a=pl.int_range(0, n), b=pl.repeat(1, n)).with_columns( + struct=pl.struct(pl.all()) + ) + + f = io.BytesIO() + expect.write_parquet( + f, + use_pyarrow=True, + ) + f.seek(0) + + actual = pl.read_parquet(f) + assert_frame_equal(expect, actual) + + +@pytest.mark.write_disk +def test_decimal_parquet(tmp_path: Path) -> None: + path = tmp_path / "foo.parquet" + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": ["6", "7", "8"], + } + ) + + df = df.with_columns(pl.col("bar").cast(pl.Decimal)) + + df.write_parquet(path, statistics=True) + out = pl.scan_parquet(path).filter(foo=2).collect().to_dict(as_series=False) + assert out == {"foo": [2], "bar": [Decimal("7")]} + + +@pytest.mark.write_disk +def test_enum_parquet(tmp_path: Path) -> None: + path = tmp_path / "enum.parquet" + df = pl.DataFrame( + [pl.Series("e", ["foo", "bar", "ham"], dtype=pl.Enum(["foo", "bar", "ham"]))] + ) + df.write_parquet(path) + out = pl.read_parquet(path) + assert_frame_equal(df, out) + + +def test_parquet_rle_non_nullable_12814() -> None: + column = ( + pl.select(x=pl.arange(0, 1025, dtype=pl.Int64) // 10).to_series().to_arrow() + ) + schema = pa.schema([pa.field("foo", pa.int64(), nullable=False)]) + table = pa.Table.from_arrays([column], schema=schema) + + f = io.BytesIO() + pq.write_table(table, f, data_page_size=1) + f.seek(0) + + print(pq.read_table(f)) + + f.seek(0) + + expect = pl.DataFrame(table).tail(10) + actual = pl.read_parquet(f).tail(10) + + assert_frame_equal(expect, actual) + + +@pytest.mark.slow +def test_parquet_12831() -> None: + n = 70_000 + df = pl.DataFrame({"x": ["aaaaaa"] * n}) + f = io.BytesIO() + df.write_parquet(f, row_group_size=int(1e8), data_page_size=512) + f.seek(0) + assert_frame_equal(pl.from_arrow(pq.read_table(f)), df) # type: ignore[arg-type] + + +@pytest.mark.write_disk +def test_parquet_struct_categorical(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame( + [ + pl.Series("a", ["bob"], pl.Categorical), + pl.Series("b", ["foo"], pl.Categorical), + ] + ) + + file_path = tmp_path / "categorical.parquet" + df.write_parquet(file_path) + + with pl.StringCache(): + out = pl.read_parquet(file_path).select(pl.col("b").value_counts()) + assert out.to_dict(as_series=False) == {"b": [{"b": "foo", "count": 1}]} + + +@pytest.mark.write_disk +def test_null_parquet(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame([pl.Series("foo", [], dtype=pl.Int8)]) + file_path = tmp_path / "null.parquet" + df.write_parquet(file_path) + out = pl.read_parquet(file_path) + assert_frame_equal(out, df) + + +@pytest.mark.write_disk +def test_write_parquet_with_null_col(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df1 = pl.DataFrame({"nulls": [None] * 2, "ints": [1] * 2}) + df2 = pl.DataFrame({"nulls": [None] * 2, "ints": [1] * 2}) + df3 = pl.DataFrame({"nulls": [None] * 3, "ints": [1] * 3}) + df = df1.vstack(df2) + df = df.vstack(df3) + file_path = tmp_path / "with_null.parquet" + df.write_parquet(file_path, row_group_size=3) + out = pl.read_parquet(file_path) + assert_frame_equal(out, df) + + +@pytest.mark.write_disk +def test_scan_parquet_binary_buffered_reader(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [1, 2, 3]}) + file_path = tmp_path / "test.parquet" + df.write_parquet(file_path) + + with file_path.open("rb") as f: + out = pl.scan_parquet(f).collect() + assert_frame_equal(out, df) + + +@pytest.mark.write_disk +def test_read_parquet_binary_buffered_reader(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [1, 2, 3]}) + file_path = tmp_path / "test.parquet" + df.write_parquet(file_path) + + with file_path.open("rb") as f: + out = pl.read_parquet(f) + assert_frame_equal(out, df) + + +@pytest.mark.write_disk +def test_read_parquet_binary_file_io(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [1, 2, 3]}) + file_path = tmp_path / "test.parquet" + df.write_parquet(file_path) + + with file_path.open("rb", buffering=0) as f: + out = pl.read_parquet(f) + assert_frame_equal(out, df) + + +# https://github.com/pola-rs/polars/issues/15760 +@pytest.mark.write_disk +def test_read_parquet_binary_fsspec(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [1, 2, 3]}) + file_path = tmp_path / "test.parquet" + df.write_parquet(file_path) + + with fsspec.open(file_path) as f: + out = pl.read_parquet(f) + assert_frame_equal(out, df) + + +def test_read_parquet_binary_bytes_io() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + + out = pl.read_parquet(f) + assert_frame_equal(out, df) + + +def test_read_parquet_binary_bytes() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + f = io.BytesIO() + df.write_parquet(f) + bytes = f.getvalue() + + out = pl.read_parquet(bytes) + assert_frame_equal(out, df) + + +def test_utc_timezone_normalization_13670(tmp_path: Path) -> None: + """'+00:00' timezones becomes 'UTC' timezone.""" + utc_path = tmp_path / "utc.parquet" + zero_path = tmp_path / "00_00.parquet" + utc_lowercase_path = tmp_path / "utc_lowercase.parquet" + for tz, path in [ + ("+00:00", zero_path), + ("UTC", utc_path), + ("utc", utc_lowercase_path), + ]: + pq.write_table( + pa.table( + {"c1": [1234567890123] * 10}, + schema=pa.schema([pa.field("c1", pa.timestamp("ms", tz=tz))]), + ), + path, + ) + + df = pl.scan_parquet([utc_path, zero_path]).head(5).collect() + assert cast(pl.Datetime, df.schema["c1"]).time_zone == "UTC" + df = pl.scan_parquet([zero_path, utc_path]).head(5).collect() + assert cast(pl.Datetime, df.schema["c1"]).time_zone == "UTC" + df = pl.scan_parquet([zero_path, utc_lowercase_path]).head(5).collect() + assert cast(pl.Datetime, df.schema["c1"]).time_zone == "UTC" + + +def test_parquet_rle_14333() -> None: + vals = [True, False, True, False, True, False, True, False, True, False] + table = pa.table({"a": vals}) + + f = io.BytesIO() + pq.write_table(table, f, data_page_version="2.0") + f.seek(0) + assert pl.read_parquet(f)["a"].to_list() == vals + + +def test_parquet_rle_null_binary_read_14638() -> None: + df = pl.DataFrame({"x": [None]}, schema={"x": pl.String}) + + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=True) + f.seek(0) + assert "RLE_DICTIONARY" in pq.read_metadata(f).row_group(0).column(0).encodings + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) + + +def test_parquet_string_rle_encoding() -> None: + n = 3 + data = { + "id": ["abcdefgh"] * n, + } + + df = pl.DataFrame(data) + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=False) + f.seek(0) + + assert ( + "RLE_DICTIONARY" + in pq.ParquetFile(f).metadata.to_dict()["row_groups"][0]["columns"][0][ + "encodings" + ] + ) + + +@pytest.mark.may_fail_auto_streaming +def test_sliced_dict_with_nulls_14904() -> None: + df = ( + pl.DataFrame({"x": [None, None]}) + .cast(pl.Categorical) + .with_columns(y=pl.concat_list("x")) + .slice(0, 1) + ) + test_round_trip(df) + + +@pytest.fixture +def empty_compressed_datapage_v2_path(io_files_path: Path) -> Path: + return io_files_path / "empty_datapage_v2.snappy.parquet" + + +def test_read_empty_compressed_datapage_v2_22170( + empty_compressed_datapage_v2_path: Path, +) -> None: + df = pl.DataFrame({"value": [None]}, schema={"value": pl.Float32}) + assert_frame_equal(df, pl.read_parquet(empty_compressed_datapage_v2_path)) + + +def test_parquet_array_dtype() -> None: + df = pl.DataFrame({"x": []}) + df = df.cast({"x": pl.Array(pl.Int64, shape=3)}) + test_round_trip(df) + + +def test_parquet_array_dtype_nulls() -> None: + df = pl.DataFrame({"x": [[1, 2], None, [None, 3]]}) + df = df.cast({"x": pl.Array(pl.Int64, shape=2)}) + test_round_trip(df) + + +@pytest.mark.parametrize( + ("series", "dtype"), + [ + ([[1, 2, 3]], pl.List(pl.Int64)), + ([[1, None, 3]], pl.List(pl.Int64)), + ( + [{"a": []}, {"a": [1]}, {"a": [1, 2, 3]}], + pl.Struct({"a": pl.List(pl.Int64)}), + ), + ([{"a": None}, None, {"a": [1, 2, None]}], pl.Struct({"a": pl.List(pl.Int64)})), + ( + [[{"a": []}, {"a": [1]}, {"a": [1, 2, 3]}], None, [{"a": []}, {"a": [42]}]], + pl.List(pl.Struct({"a": pl.List(pl.Int64)})), + ), + ( + [ + [1, None, 3], + None, + [1, 3, 4], + None, + [9, None, 4], + [None, 42, 13], + [37, 511, None], + ], + pl.List(pl.Int64), + ), + ([[1, 2, 3]], pl.Array(pl.Int64, 3)), + ([[1, None, 3], None, [1, 2, None]], pl.Array(pl.Int64, 3)), + ([[1, 2], None, [None, 3]], pl.Array(pl.Int64, 2)), + ([[], [], []], pl.Array(pl.Int64, 0)), + ([[], None, []], pl.Array(pl.Int64, 0)), + ( + [[[1, 5, 2], [42, 13, 37]], [[1, 2, 3], [5, 2, 3]], [[1, 2, 1], [3, 1, 3]]], + pl.Array(pl.Array(pl.Int8, 3), 2), + ), + ( + [[[1, 5, 2], [42, 13, 37]], None, [None, [3, 1, 3]]], + pl.Array(pl.Array(pl.Int8, 3), 2), + ), + ( + [ + [[[2, 1], None, [4, 1], None], []], + None, + [None, [[4, 4], None, [1, 2]]], + ], + pl.Array(pl.List(pl.Array(pl.Int8, 2)), 2), + ), + ([[[], []]], pl.Array(pl.List(pl.Array(pl.Int8, 2)), 2)), + ( + [ + [ + [[[42, 13, 37, 15, 9, 20, 0, 0, 5, 10], None]], + [None, [None, [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], None], + ] + ], + pl.Array(pl.List(pl.Array(pl.Array(pl.Int8, 10), 2)), 2), + ), + ( + [ + None, + [None], + [[None]], + [[[None]]], + [[[[None]]]], + [[[[[None]]]]], + [[[[[1]]]]], + ], + pl.Array(pl.Array(pl.Array(pl.Array(pl.Array(pl.Int8, 1), 1), 1), 1), 1), + ), + ( + [ + None, + [None], + [[]], + [[None]], + [[[None], None]], + [[[None], [None]]], + [[[[None]], [[[1]]]]], + [[[[[None]]]]], + [[[[[1]]]]], + ], + pl.Array(pl.List(pl.Array(pl.List(pl.Array(pl.Int8, 1)), 1)), 1), + ), + ], +) +def test_complex_types(series: list[Any], dtype: pl.DataType) -> None: + xs = pl.Series(series, dtype=dtype) + df = pl.DataFrame({"x": xs}) + + test_round_trip(df) + + +@pytest.mark.write_disk +def test_parquet_array_statistics(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "b": [1, 2, 3]}) + file_path = tmp_path / "test.parquet" + + df.with_columns(a=pl.col("a").list.to_array(3)).lazy().filter( + pl.col("a") != [1, 2, 3] + ).collect() + df.with_columns(a=pl.col("a").list.to_array(3)).lazy().sink_parquet(file_path) + + result = pl.scan_parquet(file_path).filter(pl.col("a") != [1, 2, 3]).collect() + assert result.to_dict(as_series=False) == {"a": [[4, 5, 6], [7, 8, 9]], "b": [2, 3]} + + +@pytest.mark.slow +@pytest.mark.write_disk +def test_read_parquet_only_loads_selected_columns_15098( + memory_usage_without_pyarrow: MemoryUsage, tmp_path: Path +) -> None: + """Only requested columns are loaded by ``read_parquet()``.""" + tmp_path.mkdir(exist_ok=True) + + # Each column will be about 8MB of RAM + series = pl.arange(0, 1_000_000, dtype=pl.Int64, eager=True) + + file_path = tmp_path / "multicolumn.parquet" + df = pl.DataFrame( + { + "a": series, + "b": series, + } + ) + df.write_parquet(file_path) + del df, series + + memory_usage_without_pyarrow.reset_tracking() + + # Only load one column: + df = pl.read_parquet([file_path], columns=["b"], rechunk=False) + del df + # Only one column's worth of memory should be used; 2 columns would be + # 16_000_000 at least, but there's some overhead. + # assert 8_000_000 < memory_usage_without_pyarrow.get_peak() < 13_000_000 + + +@pytest.mark.release +def test_max_statistic_parquet_writer() -> None: + # this hits the maximal page size + # so the row group will be split into multiple pages + # the page statistics need to be correctly reduced + # for this query to make sense + n = 150_000 + + # int64 is important to hit the page size + df = pl.int_range(0, n, eager=True, dtype=pl.Int64).alias("int").to_frame() + f = io.BytesIO() + df.write_parquet(f, statistics=True, use_pyarrow=False, row_group_size=n) + f.seek(0) + result = pl.scan_parquet(f).filter(pl.col("int") > n - 3).collect() + expected = pl.DataFrame({"int": [149998, 149999]}) + assert_frame_equal(result, expected) + + +@pytest.mark.slow +def test_hybrid_rle() -> None: + # 10_007 elements to test if not a nice multiple of 8 + n = 10_007 + literal_literal = [] + literal_rle = [] + for i in range(500): + literal_literal.append(np.repeat(i, 5)) + literal_literal.append(np.repeat(i + 2, 11)) + literal_rle.append(np.repeat(i, 5)) + literal_rle.append(np.repeat(i + 2, 15)) + literal_literal.append(np.random.randint(0, 10, size=2007)) + literal_rle.append(np.random.randint(0, 10, size=7)) + literal_literal = np.concatenate(literal_literal) + literal_rle = np.concatenate(literal_rle) + df = pl.DataFrame( + { + # Primitive types + "i64": pl.Series([1, 2], dtype=pl.Int64).sample(n, with_replacement=True), + "u64": pl.Series([1, 2], dtype=pl.UInt64).sample(n, with_replacement=True), + "i8": pl.Series([1, 2], dtype=pl.Int8).sample(n, with_replacement=True), + "u8": pl.Series([1, 2], dtype=pl.UInt8).sample(n, with_replacement=True), + "string": pl.Series(["abc", "def"], dtype=pl.String).sample( + n, with_replacement=True + ), + "categorical": pl.Series(["aaa", "bbb"], dtype=pl.Categorical).sample( + n, with_replacement=True + ), + # Fill up bit-packing buffer in middle of consecutive run + "large_bit_pack": np.concatenate( + [np.repeat(i, 5) for i in range(2000)] + + [np.random.randint(0, 10, size=7)] + ), + # Literal run that is not a multiple of 8 followed by consecutive + # run initially long enough to RLE but not after padding literal + "literal_literal": literal_literal, + # Literal run that is not a multiple of 8 followed by consecutive + # run long enough to RLE even after padding literal + "literal_rle": literal_rle, + # Final run not long enough to RLE + "final_literal": np.concatenate( + [np.random.randint(0, 100, 10_000), np.repeat(-1, 7)] + ), + # Final run long enough to RLE + "final_rle": np.concatenate( + [np.random.randint(0, 100, 9_998), np.repeat(-1, 9)] + ), + # Test filling up bit-packing buffer for encode_bool, + # which is only used to encode validities + "large_bit_pack_validity": [0, None] * 4092 + + [0] * 9 + + [1] * 9 + + [2] * 10 + + [0] * 1795, + } + ) + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + for col in pq.ParquetFile(f).metadata.to_dict()["row_groups"][0]["columns"]: + assert "RLE_DICTIONARY" in col["encodings"] + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@given( + df=dataframes( + allowed_dtypes=[ + pl.Null, + pl.List, + pl.Array, + pl.Int8, + pl.UInt8, + pl.UInt32, + pl.Int64, + # pl.Date, # Turned off because of issue #17599 + # pl.Time, # Turned off because of issue #17599 + pl.Binary, + pl.Float32, + pl.Float64, + pl.String, + pl.Boolean, + ], + min_size=1, + max_size=500, + ) +) +@pytest.mark.slow +def test_roundtrip_parametric(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + result = pl.read_parquet(f) + + assert_frame_equal(df, result) + + +def test_parquet_statistics_uint64_16683() -> None: + u64_max = (1 << 64) - 1 + df = pl.Series("a", [u64_max, 0], dtype=pl.UInt64).to_frame() + file = io.BytesIO() + df.write_parquet(file, statistics=True) + file.seek(0) + statistics = pq.read_metadata(file).row_group(0).column(0).statistics + + assert statistics.min == 0 + assert statistics.max == u64_max + + +@pytest.mark.slow +@pytest.mark.parametrize("nullable", [True, False]) +def test_read_byte_stream_split(nullable: bool) -> None: + rng = np.random.default_rng(123) + num_rows = 1_000 + values = rng.uniform(-1.0e6, 1.0e6, num_rows) + if nullable: + validity_mask = rng.integers(0, 2, num_rows).astype(np.bool_) + else: + validity_mask = None + + schema = pa.schema( + [ + pa.field("floats", type=pa.float32(), nullable=nullable), + pa.field("doubles", type=pa.float64(), nullable=nullable), + ] + ) + arrays = [pa.array(values, type=field.type, mask=validity_mask) for field in schema] + table = pa.Table.from_arrays(arrays, schema=schema) + df = cast(pl.DataFrame, pl.from_arrow(table)) + + f = io.BytesIO() + pq.write_table( + table, f, compression="snappy", use_dictionary=False, use_byte_stream_split=True + ) + + f.seek(0) + read = pl.read_parquet(f) + + assert_frame_equal(read, df) + + +@pytest.mark.slow +@pytest.mark.parametrize("rows_nullable", [True, False]) +@pytest.mark.parametrize("item_nullable", [True, False]) +def test_read_byte_stream_split_arrays( + item_nullable: bool, rows_nullable: bool +) -> None: + rng = np.random.default_rng(123) + num_rows = 1_000 + max_array_len = 10 + array_lengths = rng.integers(0, max_array_len + 1, num_rows) + if rows_nullable: + row_validity_mask = rng.integers(0, 2, num_rows).astype(np.bool_) + array_lengths[row_validity_mask] = 0 + row_validity_mask = pa.array(row_validity_mask) + else: + row_validity_mask = None + + offsets = np.zeros(num_rows + 1, dtype=np.int64) + np.cumsum(array_lengths, out=offsets[1:]) + num_values = offsets[-1] + values = rng.uniform(-1.0e6, 1.0e6, num_values) + + if item_nullable: + element_validity_mask = rng.integers(0, 2, num_values).astype(np.bool_) + else: + element_validity_mask = None + + schema = pa.schema( + [ + pa.field( + "floats", + type=pa.list_(pa.field("", pa.float32(), nullable=item_nullable)), + nullable=rows_nullable, + ), + pa.field( + "doubles", + type=pa.list_(pa.field("", pa.float64(), nullable=item_nullable)), + nullable=rows_nullable, + ), + ] + ) + arrays = [ + pa.ListArray.from_arrays( + pa.array(offsets), + pa.array(values, type=field.type.field(0).type, mask=element_validity_mask), + mask=row_validity_mask, + ) + for field in schema + ] + table = pa.Table.from_arrays(arrays, schema=schema) + df = cast(pl.DataFrame, pl.from_arrow(table)) + + f = io.BytesIO() + pq.write_table( + table, f, compression="snappy", use_dictionary=False, use_byte_stream_split=True + ) + + f.seek(0) + read = pl.read_parquet(f) + + assert_frame_equal(read, df) + + +def test_parquet_nested_null_array_17795() -> None: + f = io.BytesIO() + pl.DataFrame([{"struct": {"field": None}}]).write_parquet(f) + f.seek(0) + pq.read_table(f) + + +def test_parquet_record_batches_pyarrow_fixed_size_list_16614() -> None: + # @NOTE: + # The minimum that I could get it to crash which was ~132000, but let's + # just do 150000 to be sure. + n = 150000 + x = pl.DataFrame( + {"x": np.linspace((1, 2), (2 * n, 2 * n * 1), n, dtype=np.float32)}, + schema={"x": pl.Array(pl.Float32, 2)}, + ) + + f = io.BytesIO() + x.write_parquet(f) + f.seek(0) + b = pl.read_parquet(f, use_pyarrow=True) + + assert b["x"].shape[0] == n + assert_frame_equal(b, x) + + +def test_parquet_list_element_field_name() -> None: + f = io.BytesIO() + ( + pl.DataFrame( + { + "a": [[1, 2], [1, 1, 1]], + }, + schema={"a": pl.List(pl.Int64)}, + ).write_parquet(f, use_pyarrow=False) + ) + + f.seek(0) + schema_str = str(pq.read_schema(f)) + assert "" in schema_str + assert "child 0, element: int64" in schema_str + + +def test_nested_decimal() -> None: + df = pl.DataFrame( + { + "a": [ + {"f0": None}, + None, + ] + }, + schema={"a": pl.Struct({"f0": pl.Decimal(precision=38, scale=8)})}, + ) + test_round_trip(df) + + +def test_nested_non_uniform_primitive() -> None: + df = pl.DataFrame( + {"a": [{"x": 0, "y": None}]}, + schema={ + "a": pl.Struct( + { + "x": pl.Int16, + "y": pl.Int64, + } + ) + }, + ) + test_round_trip(df) + + +def test_parquet_nested_struct_17933() -> None: + df = pl.DataFrame( + {"a": [{"x": {"u": None}, "y": True}]}, + schema={ + "a": pl.Struct( + { + "x": pl.Struct({"u": pl.String}), + "y": pl.Boolean(), + } + ) + }, + ) + test_round_trip(df) + + +# This is fixed with POLARS_FORCE_MULTISCAN=1. Without it we have +# first_metadata.unwrap() on None. +@pytest.mark.may_fail_auto_streaming +def test_parquet_pyarrow_map() -> None: + xs = [ + [ + (0, 5), + (1, 10), + (2, 19), + (3, 96), + ] + ] + + table = pa.table( + [xs], + schema=pa.schema( + [ + ("x", pa.map_(pa.int32(), pa.int32(), keys_sorted=True)), + ] + ), + ) + + f = io.BytesIO() + pq.write_table(table, f) + + expected = pl.DataFrame( + { + "x": [ + {"key": 0, "value": 5}, + {"key": 1, "value": 10}, + {"key": 2, "value": 19}, + {"key": 3, "value": 96}, + ] + }, + schema={"x": pl.Struct({"key": pl.Int32, "value": pl.Int32})}, + ) + f.seek(0) + assert_frame_equal(pl.read_parquet(f).explode(["x"]), expected) + + # Test for https://github.com/pola-rs/polars/issues/21317 + # Specifying schema/allow_missing_columns + for allow_missing_columns in [True, False]: + assert_frame_equal( + pl.read_parquet( + f, + schema={"x": pl.List(pl.Struct({"key": pl.Int32, "value": pl.Int32}))}, + allow_missing_columns=allow_missing_columns, + ).explode(["x"]), + expected, + ) + + +@pytest.mark.parametrize( + ("s", "elem"), + [ + (pl.Series(["", "hello", "hi", ""], dtype=pl.String), ""), + (pl.Series([0, 1, 2, 0], dtype=pl.Int64), 0), + (pl.Series([[0], [1], [2], [0]], dtype=pl.Array(pl.Int64, 1)), [0]), + ( + pl.Series([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=pl.Array(pl.Int64, 2)), + [0, 1], + ), + ], +) +def test_parquet_high_nested_null_17805( + s: pl.Series, elem: str | int | list[int] +) -> None: + test_round_trip( + pl.DataFrame({"a": s}).select( + pl.when(pl.col("a") == elem) + .then(pl.lit(None)) + .otherwise(pl.concat_list(pl.col("a").alias("b"))) + .alias("c") + ) + ) + + +def test_struct_plain_encoded_statistics() -> None: + df = pl.DataFrame( + { + "a": [None, None, None, None, {"x": None, "y": 0}], + }, + schema={"a": pl.Struct({"x": pl.Int8, "y": pl.Int8})}, + ) + + test_scan_round_trip(df) + + +@given( + df=dataframes( + min_size=5, + excluded_dtypes=[pl.Decimal, pl.Categorical], + allow_masked_out=False, # PyArrow does not support this + ) +) +def test_scan_round_trip_parametric(df: pl.DataFrame) -> None: + test_scan_round_trip(df) + + +def test_empty_rg_no_dict_page_18146() -> None: + df = pl.DataFrame( + { + "a": [], + }, + schema={"a": pl.String}, + ) + + f = io.BytesIO() + pq.write_table(df.to_arrow(), f, compression="NONE", use_dictionary=False) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +def test_write_sliced_lists_18069() -> None: + f = io.BytesIO() + a = pl.Series(3 * [None, ["$"] * 3], dtype=pl.List(pl.String)) + + before = pl.DataFrame({"a": a}).slice(4, 2) + before.write_parquet(f) + + f.seek(0) + after = pl.read_parquet(f) + + assert_frame_equal(before, after) + + +def test_null_array_dict_pages_18085() -> None: + test = pd.DataFrame( + [ + {"A": float("NaN"), "B": 3, "C": None}, + {"A": float("NaN"), "B": None, "C": None}, + ] + ) + + f = io.BytesIO() + test.to_parquet(f) + f.seek(0) + pl.read_parquet(f) + + +@given( + df=dataframes( + min_size=1, + max_size=1000, + allowed_dtypes=[ + pl.List, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + ], + allow_masked_out=False, # PyArrow does not support this + ), + row_group_size=st.integers(min_value=10, max_value=1000), +) +def test_delta_encoding_roundtrip(df: pl.DataFrame, row_group_size: int) -> None: + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=False, + column_encoding="DELTA_BINARY_PACKED", + write_statistics=False, + row_group_size=row_group_size, + ) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@given( + df=dataframes(min_size=1, max_size=1000, allowed_dtypes=[pl.String, pl.Binary]), + row_group_size=st.integers(min_value=10, max_value=1000), +) +def test_delta_length_byte_array_encoding_roundtrip( + df: pl.DataFrame, row_group_size: int +) -> None: + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=False, + column_encoding="DELTA_LENGTH_BYTE_ARRAY", + write_statistics=False, + row_group_size=row_group_size, + ) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@given( + df=dataframes(min_size=1, max_size=1000, allowed_dtypes=[pl.String, pl.Binary]), + row_group_size=st.integers(min_value=10, max_value=1000), +) +def test_delta_strings_encoding_roundtrip( + df: pl.DataFrame, row_group_size: int +) -> None: + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=False, + column_encoding="DELTA_BYTE_ARRAY", + write_statistics=False, + row_group_size=row_group_size, + ) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +EQUALITY_OPERATORS = ["__eq__", "__lt__", "__le__", "__gt__", "__ge__"] +BOOLEAN_OPERATORS = ["__or__", "__and__"] + + +@given( + df=dataframes( + min_size=0, max_size=100, min_cols=2, max_cols=5, allowed_dtypes=[pl.Int32] + ), + first_op=st.sampled_from(EQUALITY_OPERATORS), + second_op=st.sampled_from( + [None] + + [ + (booljoin, eq) + for booljoin in BOOLEAN_OPERATORS + for eq in EQUALITY_OPERATORS + ] + ), + l1=st.integers(min_value=0, max_value=1000), + l2=st.integers(min_value=0, max_value=1000), + r1=st.integers(min_value=0, max_value=1000), + r2=st.integers(min_value=0, max_value=1000), +) +@pytest.mark.parametrize("parallel_st", ["auto", "prefiltered"]) +def test_predicate_filtering( + df: pl.DataFrame, + first_op: str, + second_op: None | tuple[str, str], + l1: int, + l2: int, + r1: int, + r2: int, + parallel_st: Literal["auto", "prefiltered"], +) -> None: + f = io.BytesIO() + df.write_parquet(f, row_group_size=5) + + cols = df.columns + + l1s = cols[l1 % len(cols)] + l2s = cols[l2 % len(cols)] + expr = (getattr(pl.col(l1s), first_op))(pl.col(l2s)) + + if second_op is not None: + r1s = cols[r1 % len(cols)] + r2s = cols[r2 % len(cols)] + expr = getattr(expr, second_op[0])( + (getattr(pl.col(r1s), second_op[1]))(pl.col(r2s)) + ) + + f.seek(0) + result = pl.scan_parquet(f, parallel=parallel_st).filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@pytest.mark.parametrize( + "use_dictionary", + [False, True], +) +@pytest.mark.parametrize( + "data_page_size", + [1, None], +) +@given( + s=series( + min_size=1, + max_size=10, + excluded_dtypes=[ + pl.Decimal, + pl.Categorical, + pl.Enum, + pl.Struct, # See #19612. + ], + allow_masked_out=False, # PyArrow does not support this + ), + offset=st.integers(0, 10), + length=st.integers(0, 10), +) +def test_pyarrow_slice_roundtrip( + s: pl.Series, + use_dictionary: bool, + data_page_size: int | None, + offset: int, + length: int, +) -> None: + offset %= len(s) + 1 + length %= len(s) - offset + 1 + + f = io.BytesIO() + df = s.to_frame() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=use_dictionary, + data_page_size=data_page_size, + ) + + f.seek(0) + scanned = pl.scan_parquet(f).slice(offset, length).collect() + assert_frame_equal(scanned, df.slice(offset, length)) + + +@given( + df=dataframes( + min_size=1, + max_size=5, + min_cols=1, + max_cols=1, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum], + ), + offset=st.integers(0, 100), + length=st.integers(0, 100), +) +def test_slice_roundtrip(df: pl.DataFrame, offset: int, length: int) -> None: + offset %= df.height + 1 + length %= df.height - offset + 1 + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + scanned = pl.scan_parquet(f).slice(offset, length).collect() + assert_frame_equal(scanned, df.slice(offset, length)) + + +def test_struct_prefiltered() -> None: + df = pl.DataFrame({"a": {"x": 1, "y": 2}}) + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + ( + pl.scan_parquet(f, parallel="prefiltered") + .filter(pl.col("a").struct.field("x") == 1) + .collect() + ) + + +@pytest.mark.parametrize( + "data", + [ + ( + [{"x": ""}, {"x": "0"}], + pa.struct([pa.field("x", pa.string(), nullable=True)]), + ), + ( + [{"x": ""}, {"x": "0"}], + pa.struct([pa.field("x", pa.string(), nullable=False)]), + ), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=False))), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=True))), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=False), 1)), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=True), 1)), + ( + [["", "1"], ["0", "2"]], + pa.list_(pa.field("item", pa.string(), nullable=False), 2), + ), + ( + [["", "1"], ["0", "2"]], + pa.list_(pa.field("item", pa.string(), nullable=True), 2), + ), + ], +) +@pytest.mark.parametrize("nullable", [False, True]) +def test_nested_skip_18303( + data: tuple[list[dict[str, str] | list[str]], pa.DataType], + nullable: bool, +) -> None: + schema = pa.schema([pa.field("a", data[1], nullable=nullable)]) + tb = pa.table({"a": data[0]}, schema=schema) + + f = io.BytesIO() + pq.write_table(tb, f) + + f.seek(0) + scanned = pl.scan_parquet(f).slice(1, 1).collect() + + assert_frame_equal(scanned, pl.DataFrame(tb).slice(1, 1)) + + +def test_nested_span_multiple_pages_18400() -> None: + width = 4100 + df = pl.DataFrame( + [ + pl.Series( + "a", + [ + list(range(width)), + list(range(width)), + ], + pl.Array(pl.Int64, width), + ), + ] + ) + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + data_page_size=1024, + column_encoding={"a": "PLAIN"}, + ) + + f.seek(0) + assert_frame_equal(df.head(1), pl.read_parquet(f, n_rows=1)) + + +@given( + df=dataframes( + min_size=0, + max_size=1000, + min_cols=2, + max_cols=5, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum, pl.Array], + include_cols=[column("filter_col", pl.Boolean, allow_null=False)], + ), +) +def test_parametric_small_page_mask_filtering(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f, data_page_size=1024) + + expr = pl.col("filter_col") + f.seek(0) + result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@pytest.mark.parametrize( + "value", + [ + "abcd", + 0, + 0.0, + False, + ], +) +def test_different_page_validity_across_pages(value: str | int | float | bool) -> None: + df = pl.DataFrame( + { + "a": [None] + [value] * 4000, + } + ) + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + data_page_size=1024, + column_encoding={"a": "PLAIN"}, + ) + + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) + + +@given( + df=dataframes( + min_size=0, + max_size=100, + min_cols=2, + max_cols=5, + allowed_dtypes=[pl.String, pl.Binary], + include_cols=[ + column("filter_col", pl.Int8, st.integers(0, 1), allow_null=False) + ], + ), +) +def test_delta_length_byte_array_prefiltering(df: pl.DataFrame) -> None: + cols = df.columns + + encodings = dict.fromkeys(cols, "DELTA_LENGTH_BYTE_ARRAY") + encodings["filter_col"] = "PLAIN" + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + column_encoding=encodings, + ) + + f.seek(0) + expr = pl.col("filter_col") == 0 + result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@given( + df=dataframes( + min_size=0, + max_size=10, + min_cols=1, + max_cols=5, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum], + include_cols=[ + column("filter_col", pl.Int8, st.integers(0, 1), allow_null=False) + ], + ), +) +def test_general_prefiltering(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + + expr = pl.col("filter_col") == 0 + + f.seek(0) + result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@given( + df=dataframes( + min_size=0, + max_size=10, + min_cols=1, + max_cols=5, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum], + include_cols=[column("filter_col", pl.Boolean, allow_null=False)], + ), +) +def test_row_index_prefiltering(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + + expr = pl.col("filter_col") + + f.seek(0) + result = ( + pl.scan_parquet( + f, row_index_name="ri", row_index_offset=42, parallel="prefiltered" + ) + .filter(expr) + .collect() + ) + assert_frame_equal(result, df.with_row_index("ri", 42).filter(expr)) + + +def test_empty_parquet() -> None: + f_pd = io.BytesIO() + f_pl = io.BytesIO() + + pd.DataFrame().to_parquet(f_pd) + pl.DataFrame().write_parquet(f_pl) + + f_pd.seek(0) + f_pl.seek(0) + + empty_from_pd = pl.read_parquet(f_pd) + assert empty_from_pd.shape == (0, 0) + + empty_from_pl = pl.read_parquet(f_pl) + assert empty_from_pl.shape == (0, 0) + + +@pytest.mark.parametrize( + "strategy", + ["columns", "row_groups", "prefiltered"], +) +@pytest.mark.write_disk +def test_row_index_projection_pushdown_18463( + tmp_path: Path, strategy: pl.ParallelStrategy +) -> None: + tmp_path.mkdir(exist_ok=True) + f = tmp_path / "test.parquet" + + pl.DataFrame({"A": [1, 4], "B": [2, 5]}).write_parquet(f) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index() + + assert_frame_equal(df.select("index").collect(), df.collect().select("index")) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index("other_idx_name") + + assert_frame_equal( + df.select("other_idx_name").collect(), df.collect().select("other_idx_name") + ) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index(offset=42) + + assert_frame_equal(df.select("index").collect(), df.collect().select("index")) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index() + + assert_frame_equal( + df.select("index").slice(1, 1).collect(), + df.collect().select("index").slice(1, 1), + ) + + +@pytest.mark.write_disk +def test_write_binary_open_file(tmp_path: Path) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + path = tmp_path / "test.parquet" + + with path.open("wb") as f_write: + df.write_parquet(f_write) + + out = pl.read_parquet(path) + assert_frame_equal(out, df) + + +def test_prefilter_with_projection() -> None: + f = io.BytesIO() + pl.DataFrame({"a": [1], "b": [2]}).write_parquet(f) + + f.seek(0) + ( + pl.scan_parquet(f, parallel="prefiltered") + .filter(pl.col.a == 1) + .select(pl.col.a) + .collect() + ) + + +@pytest.mark.parametrize("parallel_strategy", ["prefiltered", "row_groups"]) +@pytest.mark.parametrize( + "df", + [ + pl.DataFrame({"x": 1, "y": 1}), + pl.DataFrame({"x": 1, "b": 1, "y": 1}), # hive columns in file + ], +) +@pytest.mark.write_disk +def test_prefilter_with_hive_19766( + tmp_path: Path, df: pl.DataFrame, parallel_strategy: str +) -> None: + tmp_path.mkdir(exist_ok=True) + (tmp_path / "a=1/b=1").mkdir(exist_ok=True, parents=True) + + df.write_parquet(tmp_path / "a=1/b=1/1") + expect = df.with_columns(a=pl.lit(1, dtype=pl.Int64), b=pl.lit(1, dtype=pl.Int64)) + + lf = pl.scan_parquet(tmp_path, parallel=parallel_strategy) # type: ignore[arg-type] + + for predicate in [ + pl.col("a") == 1, + pl.col("x") == 1, + (pl.col("a") == 1) & (pl.col("x") == 1), + pl.col("b") == 1, + pl.col("y") == 1, + (pl.col("a") == 1) & (pl.col("b") == 1), + ]: + assert_frame_equal( + lf.filter(predicate).collect(), + expect, + ) + + +@pytest.mark.parametrize("parallel", ["columns", "row_groups", "prefiltered", "none"]) +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("projection", [pl.all(), pl.col("b")]) +@pytest.mark.write_disk +def test_allow_missing_columns( + tmp_path: Path, + parallel: str, + streaming: bool, + projection: pl.Expr, +) -> None: + tmp_path.mkdir(exist_ok=True) + dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2})] + paths = [tmp_path / "1", tmp_path / "2"] + + for df, path in zip(dfs, paths): + df.write_parquet(path) + + expected = pl.DataFrame({"a": [1, 2], "b": [1, None]}).select(projection) + + with pytest.raises( + (pl.exceptions.ColumnNotFoundError, pl.exceptions.SchemaError), + match="enabling `allow_missing_columns`", + ): + pl.read_parquet(paths, parallel=parallel) # type: ignore[arg-type] + + with pytest.raises( + (pl.exceptions.ColumnNotFoundError, pl.exceptions.SchemaError), + match="enabling `allow_missing_columns`", + ): + pl.scan_parquet(paths, parallel=parallel).select(projection).collect( # type: ignore[arg-type] + engine="streaming" if streaming else "in-memory" + ) + + assert_frame_equal( + pl.read_parquet( + paths, + parallel=parallel, # type: ignore[arg-type] + allow_missing_columns=True, + ).select(projection), + expected, + ) + + assert_frame_equal( + pl.scan_parquet(paths, parallel=parallel, allow_missing_columns=True) # type: ignore[arg-type] + .select(projection) + .collect(engine="streaming" if streaming else "in-memory"), + expected, + ) + + +def test_nested_nonnullable_19158() -> None: + # Bug is based on the top-level struct being nullable and the inner list + # not being nullable. + tbl = pa.table( + { + "a": [{"x": [1]}, None, {"x": [1, 2]}, None], + }, + schema=pa.schema( + [ + pa.field( + "a", + pa.struct([pa.field("x", pa.list_(pa.int8()), nullable=False)]), + nullable=True, + ) + ] + ), + ) + + f = io.BytesIO() + pq.write_table(tbl, f) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), pl.DataFrame(tbl)) + + +D = Decimal + + +@pytest.mark.parametrize("precision", range(1, 37, 2)) +@pytest.mark.parametrize( + "nesting", + [ + # Struct + lambda t: ([{"x": None}, None], pl.Struct({"x": t})), + lambda t: ([None, {"x": None}], pl.Struct({"x": t})), + lambda t: ([{"x": D("1.5")}, None], pl.Struct({"x": t})), + lambda t: ([{"x": D("1.5")}, {"x": D("4.8")}], pl.Struct({"x": t})), + # Array + lambda t: ([[None, None, D("8.2")], None], pl.Array(t, 3)), + lambda t: ([None, [None, D("8.9"), None]], pl.Array(t, 3)), + lambda t: ([[D("1.5"), D("3.7"), D("4.1")], None], pl.Array(t, 3)), + lambda t: ( + [[D("1.5"), D("3.7"), D("4.1")], [D("2.8"), D("5.2"), D("8.9")]], + pl.Array(t, 3), + ), + # List + lambda t: ([[None, D("8.2")], None], pl.List(t)), + lambda t: ([None, [D("8.9"), None]], pl.List(t)), + lambda t: ([[D("1.5"), D("4.1")], None], pl.List(t)), + lambda t: ([[D("1.5"), D("3.7"), D("4.1")], [D("2.8"), D("8.9")]], pl.List(t)), + ], +) +def test_decimal_precision_nested_roundtrip( + nesting: Callable[[pl.DataType], tuple[list[Any], pl.DataType]], + precision: int, +) -> None: + # Limit the context as to not disturb any other tests + with decimal.localcontext() as ctx: + ctx.prec = precision + + decimal_dtype = pl.Decimal(precision=precision) + values, dtype = nesting(decimal_dtype) + + df = pl.Series("a", values, dtype).to_frame() + + test_round_trip(df) + + +@pytest.mark.parametrize("parallel", ["prefiltered", "columns", "row_groups", "auto"]) +def test_conserve_sortedness( + monkeypatch: Any, capfd: Any, parallel: pl.ParallelStrategy +) -> None: + f = io.BytesIO() + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, None], + "b": [1.0, 2.0, 3.0, 4.0, 5.0, None], + "c": [None, 5, 4, 3, 2, 1], + "d": [None, 5.0, 4.0, 3.0, 2.0, 1.0], + "a_nosort": [1, 2, 3, 4, 5, None], + "f": range(6), + } + ) + + pq.write_table( + df.to_arrow(), + f, + sorting_columns=[ + pq.SortingColumn(0, False, False), + pq.SortingColumn(1, False, False), + pq.SortingColumn(2, True, True), + pq.SortingColumn(3, True, True), + ], + ) + + f.seek(0) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.scan_parquet(f, parallel=parallel).filter(pl.col.f > 1).collect() + + captured = capfd.readouterr().err + + # @NOTE: We don't conserve sortedness for anything except integers at the + # moment. + assert captured.count("Parquet conserved SortingColumn for column chunk of") == 2 + assert ( + "Parquet conserved SortingColumn for column chunk of 'a' to Ascending" + in captured + ) + assert ( + "Parquet conserved SortingColumn for column chunk of 'c' to Descending" + in captured + ) + + +@pytest.mark.parametrize("use_dictionary", [True, False]) +@pytest.mark.parametrize( + "values", + [ + (size, x) + for size in [1, 2, 3, 4, 8, 12, 15, 16, 32] + for x in [ + [list(range(size)), list(range(7, 7 + size))], + [list(range(size)), None], + [list(range(i, i + size)) for i in range(13)], + [list(range(i, i + size)) if i % 3 < 2 else None for i in range(13)], + ] + ], +) +@pytest.mark.parametrize( + "filt", + [ + lambda _: None, + lambda _: pl.col.f > 0, + lambda _: pl.col.f > 1, + lambda _: pl.col.f < 5, + lambda _: pl.col.f % 2 == 0, + lambda _: pl.col.f % 5 < 4, + lambda values: (0, min(1, len(values))), + lambda _: (1, 1), + lambda _: (-2, 1), + lambda values: (1, len(values) - 2), + ], +) +def test_fixed_size_binary( + use_dictionary: bool, + values: tuple[int, list[None | list[int]]], + filt: Callable[[list[None | list[int]]], None | pl.Expr | tuple[int, int]], +) -> None: + size, elems = values + bs = [bytes(v) if v is not None else None for v in elems] + + tbl = pa.table( + { + "a": bs, + "f": range(len(bs)), + }, + schema=pa.schema( + [ + pa.field("a", pa.binary(length=size), nullable=True), + pa.field("f", pa.int32(), nullable=True), + ] + ), + ) + + df = pl.DataFrame(tbl) + + f = io.BytesIO() + pq.write_table(tbl, f, use_dictionary=use_dictionary) + + f.seek(0) + + loaded: pl.DataFrame + if isinstance(filt, pl.Expr): + loaded = pl.scan_parquet(f).filter(filt).collect() + df = df.filter(filt) + elif isinstance(filt, tuple): + loaded = pl.scan_parquet(f).slice(filt[0], filt[1]).collect() + df = df.slice(filt[0], filt[1]) + else: + loaded = pl.read_parquet(f) + + assert_frame_equal(loaded, df) + + +def test_decode_f16() -> None: + values = [float("nan"), 0.0, 0.5, 1.0, 1.5] + + table = pa.Table.from_pydict( + { + "x": pa.array(np.array(values, dtype=np.float16), type=pa.float16()), + } + ) + + df = pl.Series("x", values, pl.Float32).to_frame() + + f = io.BytesIO() + pq.write_table(table, f) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).filter(pl.col.x > 0.5).collect(), + df.filter(pl.col.x > 0.5), + ) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(1, 3).collect(), + df.slice(1, 3), + ) + + +def test_invalid_utf8_binary() -> None: + a = pl.Series("a", [b"\x80"], pl.Binary).to_frame() + f = io.BytesIO() + + a.write_parquet(f) + f.seek(0) + out = pl.read_parquet(f) + + assert_frame_equal(a, out) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Null, + pl.Int8, + pl.Int32, + pl.Datetime(), + pl.String, + pl.Binary, + pl.Boolean, + pl.Struct({"x": pl.Int32}), + pl.List(pl.Int32), + pl.Array(pl.Int32, 0), + pl.Array(pl.Int32, 2), + ], +) +@pytest.mark.parametrize( + "filt", + [ + pl.col.f == 0, + pl.col.f != 0, + pl.col.f == 1, + pl.col.f != 1, + pl.col.f == 2, + pl.col.f != 2, + pl.col.f == 3, + pl.col.f != 3, + ], +) +def test_filter_only_invalid(dtype: pl.DataType, filt: pl.Expr) -> None: + df = pl.DataFrame( + [ + pl.Series("a", [None, None, None], dtype), + pl.Series("f", range(3), pl.Int32), + ] + ) + + f = io.BytesIO() + + df.write_parquet(f) + f.seek(0) + out = pl.scan_parquet(f, parallel="prefiltered").filter(filt).collect() + + assert_frame_equal(df.filter(filt), out) + + +def test_nested_nulls() -> None: + df = pl.Series( + "a", + [ + [None, None], + None, + [None, 1], + [None, None], + [2, None], + ], + pl.Array(pl.Int32, 2), + ).to_frame() + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + out = pl.read_parquet(f) + assert_frame_equal(out, df) + + +@pytest.mark.parametrize("content", [[], [None], [None, 0.0]]) +def test_nested_dicts(content: list[float | None]) -> None: + df = pl.Series("a", [content], pl.List(pl.Float64)).to_frame() + + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=True) + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) + + +@pytest.mark.parametrize( + "leading_nulls", + [ + [], + [None] * 7, + ], +) +@pytest.mark.parametrize( + "trailing_nulls", + [ + [], + [None] * 7, + ], +) +@pytest.mark.parametrize( + "first_chunk", + # Create both RLE and Bitpacked chunks + [ + [1] * 57, + [1 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +@pytest.mark.parametrize( + "second_chunk", + # Create both RLE and Bitpacked chunks + [ + [2] * 57, + [2 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +@pytest.mark.slow +def test_dict_slices( + leading_nulls: list[None], + trailing_nulls: list[None], + first_chunk: list[None | int], + second_chunk: list[None | int], +) -> None: + df = pl.Series( + "a", leading_nulls + first_chunk + second_chunk + trailing_nulls, pl.Int64 + ).to_frame() + + f = io.BytesIO() + df.write_parquet(f) + + for offset in chain([0, 1, 2], range(3, df.height, 3)): + for length in chain([df.height, 1, 2], range(3, df.height - offset, 3)): + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(offset, length).collect(), + df.slice(offset, length), + ) + + +@pytest.mark.parametrize( + "mask", + [ + [i % 13 < 3 and i % 17 > 3 for i in range(57 * 2)], + [False] * 23 + [True] * 68 + [False] * 23, + [False] * 23 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 23, + [True] + [False] * 22 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 23, + [False] * 23 + [True] * 24 + [False] * 20 + [True] * 24 + [False] * 22 + [True], + [True] + + [False] * 22 + + [True] * 24 + + [False] * 20 + + [True] * 24 + + [False] * 22 + + [True], + [False] * 56 + [True] * 58, + [False] * 57 + [True] * 57, + [False] * 58 + [True] * 56, + [True] * 56 + [False] * 58, + [True] * 57 + [False] * 57, + [True] * 58 + [False] * 56, + ], +) +@pytest.mark.parametrize( + "first_chunk", + # Create both RLE and Bitpacked chunks + [ + [1] * 57, + [1 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +@pytest.mark.parametrize( + "second_chunk", + # Create both RLE and Bitpacked chunks + [ + [2] * 57, + [2 if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + list(range(57)), + [i if i % 7 < 3 and i % 5 > 3 else None for i in range(57)], + ], +) +def test_dict_masked( + mask: list[bool], + first_chunk: list[None | int], + second_chunk: list[None | int], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", first_chunk + second_chunk, pl.Int64), + pl.Series("f", mask, pl.Boolean), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.f).collect(), + df.filter(pl.col.f), + ) + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.may_fail_auto_streaming +def test_categorical_sliced_20017() -> None: + f = io.BytesIO() + df = ( + pl.Series("a", ["a", None]) + .to_frame() + .with_columns(pl.col.a.cast(pl.Categorical)) + ) + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + pl.read_parquet(f, n_rows=1), + df.head(1), + ) + + +@given( + s=series(name="a", dtype=pl.String, min_size=7, max_size=7), + mask=series( + name="mask", dtype=pl.Boolean, min_size=7, max_size=7, allow_null=False + ), +) +def test_categorical_parametric_masked(s: pl.Series, mask: pl.Series) -> None: + f = io.BytesIO() + + with pl.StringCache(): + df = pl.DataFrame([s, mask]).with_columns(pl.col.a.cast(pl.Categorical)) + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.mask).collect(), + df.filter(pl.col.mask), + ) + + +@given( + s=series(name="a", dtype=pl.String, min_size=7, max_size=7), + start=st.integers(0, 6), + length=st.integers(1, 7), +) +def test_categorical_parametric_sliced(s: pl.Series, start: int, length: int) -> None: + length = min(7 - start, length) + + f = io.BytesIO() + + with pl.StringCache(): + df = s.to_frame().with_columns(pl.col.a.cast(pl.Categorical)) + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).slice(start, length).collect(), + df.slice(start, length), + ) + + +@pytest.mark.write_disk +def test_prefilter_with_projection_column_order_20175(tmp_path: Path) -> None: + path = tmp_path / "1" + + pl.DataFrame({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}).write_parquet(path) + + q = ( + pl.scan_parquet(path, parallel="prefiltered") + .filter(pl.col("a") == 1) + .select("a", "d", "c") + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": 1, "d": 1, "c": 1})) + + f = io.BytesIO() + + pl.read_csv(b"""\ +c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10 +1,1,1,1,1,1,1,1,1,1,1 +1,1,1,1,1,1,1,1,1,1,1 +""").write_parquet(f) + + f.seek(0) + + q = ( + pl.scan_parquet( + f, + rechunk=True, + parallel="prefiltered", + ) + .filter( + pl.col("c0") == 1, + ) + .select("c0", "c9", "c3") + ) + + assert_frame_equal( + q.collect(), + pl.read_csv(b"""\ +c0,c9,c3 +1,1,1 +1,1,1 +"""), + ) + + +def test_utf8_verification_with_slice_20174() -> None: + f = io.BytesIO() + pq.write_table( + pl.Series("s", ["a", "a" * 128]).to_frame().to_arrow(), f, use_dictionary=False + ) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).head(1).collect(), + pl.Series("s", ["a"]).to_frame(), + ) + + +@pytest.mark.parametrize("parallel", ["prefiltered", "row_groups"]) +@pytest.mark.parametrize( + "projection", + [ + {"a": pl.Int64(), "b": pl.Int64()}, + {"b": pl.Int64(), "a": pl.Int64()}, + ], +) +def test_parquet_prefiltered_unordered_projection_20175( + parallel: str, projection: dict[str, pl.DataType] +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", [0], pl.Int64), + pl.Series("b", [0], pl.Int64), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + out = ( + pl.scan_parquet(f, parallel=parallel) # type: ignore[arg-type] + .filter(pl.col.a >= 0) + .select(*projection.keys()) + .collect() + ) + + assert out.schema == projection + + +def test_parquet_unsupported_dictionary_to_pl_17945() -> None: + t = pa.table( + { + "col1": pa.DictionaryArray.from_arrays([0, 0, None, 1], [42, 1337]), + }, + schema=pa.schema({"col1": pa.dictionary(pa.uint32(), pa.int64())}), + ) + + f = io.BytesIO() + pq.write_table(t, f, use_dictionary=False) + f.truncate() + + f.seek(0) + assert_series_equal( + pl.Series("col1", [42, 42, None, 1337], pl.Int64), + pl.read_parquet(f).to_series(), + ) + + f.seek(0) + pq.write_table(t, f) + f.truncate() + + f.seek(0) + assert_series_equal( + pl.Series("col1", [42, 42, None, 1337], pl.Int64), + pl.read_parquet(f).to_series(), + ) + + +@pytest.mark.may_fail_auto_streaming +def test_parquet_cast_to_cat() -> None: + t = pa.table( + { + "col1": pa.DictionaryArray.from_arrays([0, 0, None, 1], ["A", "B"]), + }, + schema=pa.schema({"col1": pa.dictionary(pa.uint32(), pa.string())}), + ) + + f = io.BytesIO() + pq.write_table(t, f, use_dictionary=False) + f.truncate() + + f.seek(0) + assert_series_equal( + pl.Series("col1", ["A", "A", None, "B"], pl.Categorical), + pl.read_parquet(f).to_series(), + ) + + f.seek(0) + pq.write_table(t, f) + f.truncate() + + f.seek(0) + assert_series_equal( + pl.Series("col1", ["A", "A", None, "B"], pl.Categorical), + pl.read_parquet(f).to_series(), + ) + + +def test_parquet_roundtrip_lex_cat_20288() -> None: + f = io.BytesIO() + df = pl.Series("a", ["A", "B"], pl.Categorical(ordering="lexical")).to_frame() + df.write_parquet(f) + f.seek(0) + dt = pl.scan_parquet(f).collect_schema()["a"] + assert isinstance(dt, pl.Categorical) + assert dt.ordering == "lexical" + + +def test_from_parquet_string_cache_20271() -> None: + with pl.StringCache(): + f = io.BytesIO() + s = pl.Series("a", ["A", "B", "C"], pl.Categorical) + df = pl.Series("b", ["D", "E"], pl.Categorical).to_frame() + df.write_parquet(f) + f.seek(0) + df = pl.read_parquet(f) + + assert_series_equal( + s.to_physical(), pl.Series("a", [0, 1, 2]), check_dtypes=False + ) + assert_series_equal(df.to_series(), pl.Series("b", ["D", "E"], pl.Categorical)) + assert_series_equal( + df.to_series().to_physical(), pl.Series("b", [3, 4]), check_dtypes=False + ) + + +def test_boolean_slice_pushdown_20314() -> None: + s = pl.Series("a", [None, False, True]) + f = io.BytesIO() + + s.to_frame().write_parquet(f) + + f.seek(0) + assert pl.scan_parquet(f).slice(2, 1).collect().item() + + +def test_load_pred_pushdown_fsl_19241() -> None: + f = io.BytesIO() + + fsl = pl.Series("a", [[[1, 2]]], pl.Array(pl.Array(pl.Int8, 2), 1)) + filt = pl.Series("f", [1]) + + pl.DataFrame([fsl, filt]).write_parquet(f) + + f.seek(0) + q = pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.f != 4) + + assert_frame_equal(q.collect(), pl.DataFrame([fsl, filt])) + + +def test_struct_list_statistics_20510() -> None: + # Test PyArrow - Utf8ViewArray + data = { + "name": ["a", "b"], + "data": [ + {"title": "Title", "data": [0, 1, 3]}, + {"title": "Title", "data": [0, 1, 3]}, + ], + } + df = pl.DataFrame( + data, + schema=pl.Schema( + { + "name": pl.String(), + "data": pl.Struct( + { + "title": pl.String, + "data": pl.List(pl.Int64), + } + ), + } + ), + ) + + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + result = pl.scan_parquet(f).filter(pl.col("name") == "b").collect() + + assert_frame_equal(result, df.filter(pl.col("name") == "b")) + + # Test PyArrow - Utf8Array + tb = pa.table( + data, + schema=pa.schema( + [ + ("name", pa.string()), + ( + "data", + pa.struct( + [ + ("title", pa.string()), + ("data", pa.list_(pa.int64())), + ] + ), + ), + ] + ), + ) + + f.seek(0) + pq.write_table(tb, f) + f.truncate() + f.seek(0) + result = pl.scan_parquet(f).filter(pl.col("name") == "b").collect() + + assert_frame_equal(result, df.filter(pl.col("name") == "b")) + + +def test_required_masked_skip_values_20809(monkeypatch: Any) -> None: + df = pl.DataFrame( + [pl.Series("a", list(range(20)) + [42] * 15), pl.Series("b", range(35))] + ) + needle = [16, 33] + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + monkeypatch.setenv("POLARS_PQ_PREFILTERED_MASK", "pre") + df1 = ( + pl.scan_parquet(f, parallel="prefiltered") + .filter(pl.col.b.is_in(needle)) + .collect() + ) + + f.seek(0) + df2 = pl.read_parquet(f, parallel="columns").filter(pl.col.b.is_in(needle)) + + assert_frame_equal(df1, df2) + + +def get_tests_from_dtype( + dtype: pl.DataType, f: Callable[[int], Any] +) -> list[tuple[pl.DataType, list[Any], list[Any]]]: + return [ + (dtype, [f(i) for i in range(10)], [f(i) for i in range(11)]), + ( + dtype, + [f(i) for i in range(1337)], + [f(i) for i in [0, 1, 5, 7, 101, 1023, 1336, 1337, 1338]], + ), + ( + dtype, + list( + functools.reduce( + lambda x, y: list(x) + y, + ([f(i)] * (i % 13) for i in range(1337)), + [], + ) + ), + [f(i) for i in [0, 1, 5, 7, 101, 1023, 1336, 1337, 1338]], + ), + ( + dtype, + [f(5)] * 37 + [f(10)] * 61 + [f(1996)] * 21, + [f(i) for i in [1, 5, 10, 1996]], + ), + ] + + +@pytest.mark.parametrize("strategy", ["columns", "prefiltered"]) +@pytest.mark.parametrize( + ("dtype", "values", "needles"), + get_tests_from_dtype(pl.Int8(), lambda x: (x % 256) - 128) + + get_tests_from_dtype(pl.Int32(), lambda x: x % 256) + + get_tests_from_dtype(pl.Date(), lambda x: date(year=1 + x, month=10, day=5)) + + get_tests_from_dtype(pl.String(), lambda x: str(x)) + + get_tests_from_dtype(pl.String(), lambda x: "i" * x) + + get_tests_from_dtype(pl.String(), lambda x: f"long_strings_with_the_number_{x}"), +) +def test_equality_filter( + strategy: ParallelStrategy, + dtype: pl.DataType, + values: list[Any], + needles: list[Any], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", values, dtype), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + + for needle in needles: + f.seek(0) + scan = pl.scan_parquet(f, parallel=strategy) + try: + assert_frame_equal( + df.filter(pl.col.a == pl.lit(needle, dtype)), + scan.filter(pl.col.a == pl.lit(needle, dtype)).collect(), + ) + except: + import sys + + print(f"needle: {needle}", file=sys.stderr) + raise + + pl.read_parquet(f) + + +def test_nested_string_slice_utf8_21202() -> None: + s = pl.Series( + "a", + [ + ["A" * 128], + ["A"], + ], + pl.List(pl.String()), + ) + + f = io.BytesIO() + s.to_frame().write_parquet(f) + + f.seek(0) + assert_series_equal( + pl.scan_parquet(f).slice(1, 1).collect().to_series(), + s.slice(1, 1), + ) + + +def test_filter_true_predicate_21204() -> None: + f = io.BytesIO() + + df = pl.DataFrame({"a": [1]}) + df.write_parquet(f) + f.seek(0) + lf = pl.scan_parquet(f).filter(pl.lit(True)) + assert_frame_equal(lf.collect(), df) + + +def test_nested_deprecated_int96_timestamps_21332() -> None: + f = io.BytesIO() + + df = pl.DataFrame({"a": [{"t": datetime(2025, 1, 1)}]}) + + pq.write_table( + df.to_arrow(), + f, + use_deprecated_int96_timestamps=True, + ) + + f.seek(0) + assert_frame_equal( + pl.read_parquet(f), + df, + ) + + +def test_final_masked_optional_iteration_21378() -> None: + # fmt: off + values = [ + 1, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 0, 1, 0, 0, + 1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 0, 1, 1, 1, 1, + 0, 1, 1, 1, 0, 1, 0, 1, + 0, 1, 1, 0, 1, 0, 1, 1, + 0, 0, 0, 0, 1, 0, 0, 0, + 0, 1, 1, 1, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 0, 1, + 1, 1, 0, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 1, 0, 0, + 1, 1, 1, 1, 0, 0, 1, 0, + 0, 1, 1, 0, 0, 1, 1, 1, + 1, 1, 1, 0, 1, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 0, 0, 0, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 0, 0, 0, 0, 1, + 0, 0, 1, 1, 0, 0, 1, 1, + 0, 1, 0, 0, 0, 1, 1, 1, + 1, 0, 1, 0, 1, 0, 1, 1, + 1, 0, 1, 0, 0, 1, 0, 1, + 0, 1, 1, 1, 0, 0, 0, 1, + 1, 1, 1, 1, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, + 1, 1, 1, 0, 0, 0, 0, 0, + 1, 1, 1, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 1, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 1, 0, 0, + 1, 0, 1, 0, 0, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 1, 1, + 1, 0, 1, 1, 0, 0, 0, 1, + 0, 0, 1, 1, 0, 1, 0, 1, + 0, 1, 1, 1, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 1, 1, + 1, 0, 1, 1, 1, 1, 1, 0, + 1, 0, 1, 0, 0, 0, 1, 1, + 0, 0, 0, 1, 0, 0, 1, 0, + 0, 1, 0, 0, 1, 0, 1, 1, + 1, 0, 0, 1, 0, 1, 1, 0, + 0, 1, 0, 1, 1, 0, 1, 0, + 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 0, 1, 1, 0, 1, 1, + 1, 1, 0, 1, 0, 1, 0, 1, + 1, 1, 0, 1, 0, 0, 1, 0, + 1, 1, 0, 1, 1, 0, 0, 1, + 0, 0, 0, 0, 0, 1, 0, 0, + 0, 1, 0, 0, 1, 1, 1, 1, + 1, 0, 1, 1, 1, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 1, 1, + ] + + df = pl.DataFrame( + [ + pl.Series("x", [None if x == 1 else 0.0 for x in values], pl.Float32), + pl.Series( + "f", + [False] * 164 + + [True] * 10 + + [False] * 264 + + [True] * 10, + pl.Boolean(), + ), + ] + ) + + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + + output = pl.scan_parquet(f, parallel="prefiltered").filter(pl.col.f).collect() + assert_frame_equal(df.filter(pl.col.f), output) + + +def test_predicate_empty_is_in_21450() -> None: + f = io.BytesIO() + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + df.write_parquet(f) + + f.seek(0) + assert_frame_equal( + df.clear(), + pl.scan_parquet(f).filter(pl.col("a").is_in([])).collect(), + ) + + +@pytest.mark.write_disk +def test_scan_parquet_filter_statistics_load_missing_column_21391( + tmp_path: Path, +) -> None: + root = tmp_path + dfs = [pl.DataFrame({"x": 1, "y": 1}), pl.DataFrame({"x": 2})] + + for i, df in enumerate(dfs): + df.write_parquet(root / f"{i}.parquet") + + assert_frame_equal( + ( + pl.scan_parquet(root, allow_missing_columns=True) + .filter(pl.col("y") == 1) + .collect() + ), + pl.DataFrame({"x": 1, "y": 1}), + ) + + +@pytest.mark.parametrize( + "ty", + [ + (lambda i: i, pl.Int8, True), + (lambda i: datetime(year=2025, month=9, day=i), pl.Datetime, True), + (lambda i: float(i), pl.Float32, True), + (lambda i: str(i), pl.String, True), + (lambda i: str(i) + "make it a bit longer", pl.String, True), + (lambda i: [i, i + 7] * (i % 3), pl.List(pl.Int32), True), + (lambda i: {"x": i}, pl.Struct({"x": pl.Int32}), True), + (lambda i: [i, i + 3, i + 7], pl.Array(pl.Int32, 3), False), + ], +) +def test_filter_nulls_21538(ty: tuple[Callable[[int], Any], pl.DataType, bool]) -> None: + i_to_value, dtype, do_no_dicts = ty + + patterns: list[list[int | None]] = [ + [None, None, None, None, None], + [1, None, None, 2, None], + [None, 1, 2, 3, 4], + [1, 2, 3, 4, None], + [None, 1, 2, 3, None], + [None, 1, None, 3, None], + [1, 2, 3, 4, 5], + ] + + df = pl.DataFrame( + [ + pl.Series( + f"p{i}", [None if v is None else i_to_value(v) for v in pattern], dtype + ) + for i, pattern in enumerate(patterns) + ] + ) + + fs = [] + + dicts_f = io.BytesIO() + df.write_parquet(dicts_f) + fs += [dicts_f] + + if do_no_dicts: + no_dicts_f = io.BytesIO() + pq.write_table(df.to_arrow(), no_dicts_f, use_dictionary=False) + fs += [no_dicts_f] + + for f in fs: + for i in range(len(patterns)): + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).filter(pl.col(f"p{i}").is_null()).collect(), + df.filter(pl.col(f"p{i}").is_null()), + ) + + f.seek(0) + assert_frame_equal( + pl.scan_parquet(f).filter(pl.col(f"p{i}").is_not_null()).collect(), + df.filter(pl.col(f"p{i}").is_not_null()), + ) + + +def test_unspecialized_decoding_prefiltering() -> None: + df = pl.DataFrame( + { + "a": [None, None, None, "abc"], + "b": [False, True, False, True], + } + ) + + cols = df.columns + + encodings = dict.fromkeys(cols, "DELTA_LENGTH_BYTE_ARRAY") + encodings["b"] = "PLAIN" + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + column_encoding=encodings, + ) + + f.seek(0) + expr = pl.col("b") + result = ( + pl.scan_parquet(f, parallel="prefiltered") + .filter(expr) + .collect(engine="streaming") + ) + assert_frame_equal(result, df.filter(expr)) + + +@pytest.mark.parametrize("parallel", ["columns", "row_groups"]) +def test_filtering_on_other_parallel_modes_with_statistics( + parallel: ParallelStrategy, +) -> None: + f = io.BytesIO() + + pl.DataFrame( + { + "a": [1, 4, 9, 2, 4, 8, 3, 4, 7], + } + ).write_parquet(f, row_group_size=3) + + f.seek(0) + assert_series_equal( + pl.scan_parquet(f, parallel=parallel) + .filter(pl.col.a == 4) + .collect() + .to_series(), + pl.Series("a", [4, 4, 4]), + ) + + +def test_filter_on_logical_dtype_22252() -> None: + f = io.BytesIO() + pl.Series("a", [datetime(1996, 10, 5)]).to_frame().write_parquet(f) + f.seek(0) + pl.scan_parquet(f).filter(pl.col.a.dt.weekday() == 6).collect() + + +def test_filter_nan_22289() -> None: + f = io.BytesIO() + pl.DataFrame( + {"a": [1, 2, float("nan")], "b": [float("nan"), 5, 6]}, strict=False + ).write_parquet(f) + + f.seek(0) + lf = pl.scan_parquet(f) + + assert_frame_equal( + lf.collect().filter(pl.col.a.is_not_nan()), + lf.filter(pl.col.a.is_not_nan()).collect(), + ) + + assert_frame_equal( + lf.collect().filter(pl.col.a.is_nan()), + lf.filter(pl.col.a.is_nan()).collect(), + ) + + +def test_reencode_categoricals_22385() -> None: + tbl = pl.Series("a", ["abc"], pl.Categorical()).to_frame().to_arrow() + tbl = tbl.cast( + pa.schema( + [ + pa.field( + "a", + pa.dictionary(pa.int32(), pa.large_string()), + metadata=tbl.schema[0].metadata, + ), + ] + ) + ) + + f = io.BytesIO() + pq.write_table(tbl, f) + + f.seek(0) + pl.scan_parquet(f).collect() diff --git a/py-polars/tests/unit/io/test_partition.py b/py-polars/tests/unit/io/test_partition.py new file mode 100644 index 000000000000..12639454e200 --- /dev/null +++ b/py-polars/tests/unit/io/test_partition.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import io +from typing import TYPE_CHECKING, Any, TypedDict + +import pytest +from hypothesis import given + +import polars as pl +from polars.io.partition import ( + PartitionByKey, + PartitionMaxSize, + PartitionParted, +) +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric.strategies import dataframes + +if TYPE_CHECKING: + from pathlib import Path + + from polars.io.partition import ( + BasePartitionContext, + ) + + +class IOType(TypedDict): + """A type of IO.""" + + ext: str + scan: Any + sink: Any + + +io_types: list[IOType] = [ + {"ext": "csv", "scan": pl.scan_csv, "sink": pl.LazyFrame.sink_csv}, + {"ext": "jsonl", "scan": pl.scan_ndjson, "sink": pl.LazyFrame.sink_ndjson}, + {"ext": "parquet", "scan": pl.scan_parquet, "sink": pl.LazyFrame.sink_parquet}, + {"ext": "ipc", "scan": pl.scan_ipc, "sink": pl.LazyFrame.sink_ipc}, +] + + +@pytest.mark.parametrize("io_type", io_types) +@pytest.mark.parametrize("length", [0, 1, 4, 5, 6, 7]) +@pytest.mark.parametrize("max_size", [1, 2, 3]) +@pytest.mark.write_disk +def test_max_size_partition( + tmp_path: Path, + io_type: IOType, + length: int, + max_size: int, +) -> None: + lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy() + + (io_type["sink"])( + lf, + PartitionMaxSize(tmp_path, max_size=max_size), + engine="streaming", + # We need to sync here because platforms do not guarantee that a close on + # one thread is immediately visible on another thread. + # + # "Multithreaded processes and close()" + # https://man7.org/linux/man-pages/man2/close.2.html + sync_on_close="data", + ) + + i = 0 + while length > 0: + assert (io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").select( + pl.len() + ).collect()[0, 0] == min(max_size, length) + + length -= max_size + i += 1 + + +@pytest.mark.parametrize("io_type", io_types) +def test_max_size_partition_lambda( + tmp_path: Path, + io_type: IOType, +) -> None: + length = 17 + max_size = 3 + lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy() + + (io_type["sink"])( + lf, + PartitionMaxSize( + tmp_path, + file_path=lambda ctx: ctx.file_path.with_name("abc-" + ctx.file_path.name), + max_size=max_size, + ), + engine="streaming", + # We need to sync here because platforms do not guarantee that a close on + # one thread is immediately visible on another thread. + # + # "Multithreaded processes and close()" + # https://man7.org/linux/man-pages/man2/close.2.html + sync_on_close="data", + ) + + i = 0 + while length > 0: + assert (io_type["scan"])(tmp_path / f"abc-{i}.{io_type['ext']}").select( + pl.len() + ).collect()[0, 0] == min(max_size, length) + + length -= max_size + i += 1 + + +@pytest.mark.parametrize("io_type", io_types) +@pytest.mark.write_disk +def test_partition_by_key( + tmp_path: Path, + io_type: IOType, +) -> None: + lf = pl.Series("a", [i % 4 for i in range(7)], pl.Int64).to_frame().lazy() + + (io_type["sink"])( + lf, + PartitionByKey( + tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by="a" + ), + engine="streaming", + # We need to sync here because platforms do not guarantee that a close on + # one thread is immediately visible on another thread. + # + # "Multithreaded processes and close()" + # https://man7.org/linux/man-pages/man2/close.2.html + sync_on_close="data", + ) + + assert_series_equal( + (io_type["scan"])(tmp_path / f"0.{io_type['ext']}").collect().to_series(), + pl.Series("a", [0, 0], pl.Int64), + ) + assert_series_equal( + (io_type["scan"])(tmp_path / f"1.{io_type['ext']}").collect().to_series(), + pl.Series("a", [1, 1], pl.Int64), + ) + assert_series_equal( + (io_type["scan"])(tmp_path / f"2.{io_type['ext']}").collect().to_series(), + pl.Series("a", [2, 2], pl.Int64), + ) + assert_series_equal( + (io_type["scan"])(tmp_path / f"3.{io_type['ext']}").collect().to_series(), + pl.Series("a", [3], pl.Int64), + ) + + scan_flags = ( + {"schema": pl.Schema({"a": pl.String()})} if io_type["ext"] == "csv" else {} + ) + + # Change the datatype. + (io_type["sink"])( + lf, + PartitionByKey( + tmp_path, + file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", + by=pl.col.a.cast(pl.String()), + ), + engine="streaming", + sync_on_close="data", + ) + + assert_series_equal( + (io_type["scan"])(tmp_path / f"0.{io_type['ext']}", **scan_flags) + .collect() + .to_series(), + pl.Series("a", ["0", "0"], pl.String), + ) + assert_series_equal( + (io_type["scan"])(tmp_path / f"1.{io_type['ext']}", **scan_flags) + .collect() + .to_series(), + pl.Series("a", ["1", "1"], pl.String), + ) + assert_series_equal( + (io_type["scan"])(tmp_path / f"2.{io_type['ext']}", **scan_flags) + .collect() + .to_series(), + pl.Series("a", ["2", "2"], pl.String), + ) + assert_series_equal( + (io_type["scan"])(tmp_path / f"3.{io_type['ext']}", **scan_flags) + .collect() + .to_series(), + pl.Series("a", ["3"], pl.String), + ) + + +@pytest.mark.parametrize("io_type", io_types) +@pytest.mark.write_disk +def test_partition_parted( + tmp_path: Path, + io_type: IOType, +) -> None: + s = pl.Series("a", [1, 1, 2, 3, 3, 4, 4, 4, 6], pl.Int64) + lf = s.to_frame().lazy() + + (io_type["sink"])( + lf, + PartitionParted( + tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by="a" + ), + engine="streaming", + # We need to sync here because platforms do not guarantee that a close on + # one thread is immediately visible on another thread. + # + # "Multithreaded processes and close()" + # https://man7.org/linux/man-pages/man2/close.2.html + sync_on_close="data", + ) + + rle = s.rle() + + for i, row in enumerate(rle.struct.unnest().rows(named=True)): + assert_series_equal( + (io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").collect().to_series(), + pl.Series("a", [row["value"]] * row["len"], pl.Int64), + ) + + scan_flags = ( + {"schema_overrides": pl.Schema({"a_str": pl.String()})} + if io_type["ext"] == "csv" + else {} + ) + + # Change the datatype. + (io_type["sink"])( + lf, + PartitionParted( + tmp_path, + file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", + by=[pl.col.a, pl.col.a.cast(pl.String()).alias("a_str")], + ), + engine="streaming", + sync_on_close="data", + ) + + for i, row in enumerate(rle.struct.unnest().rows(named=True)): + assert_frame_equal( + (io_type["scan"])( + tmp_path / f"{i}.{io_type['ext']}", **scan_flags + ).collect(), + pl.DataFrame( + [ + pl.Series("a", [row["value"]] * row["len"], pl.Int64), + pl.Series("a_str", [str(row["value"])] * row["len"], pl.String), + ] + ), + ) + + # No include key. + (io_type["sink"])( + lf, + PartitionParted( + tmp_path, + file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", + by=[pl.col.a.cast(pl.String()).alias("a_str")], + include_key=False, + ), + engine="streaming", + sync_on_close="data", + ) + + for i, row in enumerate(rle.struct.unnest().rows(named=True)): + assert_series_equal( + (io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").collect().to_series(), + pl.Series("a", [row["value"]] * row["len"], pl.Int64), + ) + + +# We only deal with self-describing formats +@pytest.mark.parametrize("io_type", [io_types[2], io_types[3]]) +@pytest.mark.write_disk +@given( + df=dataframes( + min_cols=1, + excluded_dtypes=[ + pl.Decimal, # Bug see: https://github.com/pola-rs/polars/issues/21684 + pl.Duration, # Bug see: https://github.com/pola-rs/polars/issues/21964 + pl.Categorical, # We cannot ensure the string cache is properly held. + # Generate invalid UTF-8 + pl.Binary, + pl.Struct, + pl.Array, + pl.List, + ], + ) +) +def test_partition_by_key_parametric( + tmp_path_factory: pytest.TempPathFactory, + io_type: IOType, + df: pl.DataFrame, +) -> None: + col1 = df.columns[0] + + tmp_path = tmp_path_factory.mktemp("data") + + dfs = df.partition_by(col1) + (io_type["sink"])( + df.lazy(), + PartitionByKey( + tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by=col1 + ), + engine="streaming", + # We need to sync here because platforms do not guarantee that a close on + # one thread is immediately visible on another thread. + # + # "Multithreaded processes and close()" + # https://man7.org/linux/man-pages/man2/close.2.html + sync_on_close="data", + ) + + for i, df in enumerate(dfs): + assert_frame_equal( + df, + (io_type["scan"])( + tmp_path / f"{i}.{io_type['ext']}", + ).collect(), + ) + + +def test_max_size_partition_collect_files(tmp_path: Path) -> None: + length = 17 + max_size = 3 + lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy() + + io_type = io_types[0] + output_files = [] + + def file_path_cb(ctx: BasePartitionContext) -> Path: + print(ctx) + print(ctx.full_path) + output_files.append(ctx.full_path) + print(ctx.file_path) + return ctx.file_path + + (io_type["sink"])( + lf, + PartitionMaxSize(tmp_path, file_path=file_path_cb, max_size=max_size), + engine="streaming", + # We need to sync here because platforms do not guarantee that a close on + # one thread is immediately visible on another thread. + # + # "Multithreaded processes and close()" + # https://man7.org/linux/man-pages/man2/close.2.html + sync_on_close="data", + ) + + assert output_files == [tmp_path / f"{i}.{io_type['ext']}" for i in range(6)] + + +@pytest.mark.parametrize(("io_type"), io_types) +def test_partition_to_memory(tmp_path: Path, io_type: IOType) -> None: + df = pl.DataFrame( + { + "a": [5, 10, 1996], + } + ) + + output_files = {} + + def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO: + f = io.BytesIO() + output_files[ctx.file_path] = f + return f + + io_type["sink"](df.lazy(), PartitionMaxSize("", file_path=file_path_cb, max_size=1)) + + assert len(output_files) == df.height + for i, (_, value) in enumerate(output_files.items()): + value.seek(0) + assert_frame_equal(io_type["scan"](value).collect(), df.slice(i, 1)) diff --git a/py-polars/tests/unit/io/test_pickle.py b/py-polars/tests/unit/io/test_pickle.py new file mode 100644 index 000000000000..5b798cbda2eb --- /dev/null +++ b/py-polars/tests/unit/io/test_pickle.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import io +import pickle + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_pickle() -> None: + a = pl.Series("a", [1, 2]) + b = pickle.dumps(a) + out = pickle.loads(b) + assert_series_equal(a, out) + df = pl.DataFrame({"a": [1, 2], "b": ["a", None], "c": [True, False]}) + b = pickle.dumps(df) + out = pickle.loads(b) + assert_frame_equal(df, out) + + +def test_pickle_expr() -> None: + for e in [ + pl.all(), + pl.len(), + pl.duration(weeks=10, days=20, hours=3), + pl.col("a").cast(pl.Int128), + ]: + f = io.BytesIO() + pickle.dump(e, f) + + f.seek(0) + pickle.load(f) diff --git a/py-polars/tests/unit/io/test_plugins.py b/py-polars/tests/unit/io/test_plugins.py new file mode 100644 index 000000000000..6303df166962 --- /dev/null +++ b/py-polars/tests/unit/io/test_plugins.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars as pl +from polars.io.plugins import register_io_source +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from collections.abc import Iterator + + +# A simple python source. But this can dispatch into a rust IO source as well. +def my_source( + with_columns: list[str] | None, + predicate: pl.Expr | None, + _n_rows: int | None, + _batch_size: int | None, +) -> Iterator[pl.DataFrame]: + for i in [1, 2, 3]: + df = pl.DataFrame({"a": [i], "b": [i]}) + + if predicate is not None: + df = df.filter(predicate) + + if with_columns is not None: + df = df.select(with_columns) + + yield df + + +def scan_my_source() -> pl.LazyFrame: + # schema inference logic + # TODO: make lazy via callable + schema = pl.Schema({"a": pl.Int64(), "b": pl.Int64()}) + + return register_io_source(my_source, schema=schema) + + +def test_my_source() -> None: + assert_frame_equal( + scan_my_source().collect(), pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + ) + assert_frame_equal( + scan_my_source().filter(pl.col("b") > 1).collect(), + pl.DataFrame({"a": [2, 3], "b": [2, 3]}), + ) + assert_frame_equal( + scan_my_source().filter(pl.col("b") > 1).select("a").collect(), + pl.DataFrame({"a": [2, 3]}), + ) + assert_frame_equal( + scan_my_source().select("a").collect(), pl.DataFrame({"a": [1, 2, 3]}) + ) diff --git a/py-polars/tests/unit/io/test_pyarrow_dataset.py b/py-polars/tests/unit/io/test_pyarrow_dataset.py new file mode 100644 index 000000000000..01dd9f208355 --- /dev/null +++ b/py-polars/tests/unit/io/test_pyarrow_dataset.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from datetime import date, datetime, time +from typing import TYPE_CHECKING, Callable + +import pyarrow.dataset as ds + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + +def helper_dataset_test( + file_path: Path, + query: Callable[[pl.LazyFrame], pl.LazyFrame], + batch_size: int | None = None, + n_expected: int | None = None, + check_predicate_pushdown: bool = False, +) -> None: + dset = ds.dataset(file_path, format="ipc") + q = pl.scan_ipc(file_path).pipe(query) + + expected = q.collect() + out = pl.scan_pyarrow_dataset(dset, batch_size=batch_size).pipe(query).collect() + assert_frame_equal(out, expected) + if n_expected is not None: + assert len(out) == n_expected + + if check_predicate_pushdown: + assert "FILTER" not in q.explain() + + +# @pytest.mark.write_disk() +def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: + file_path = tmp_path / "small.ipc" + df.write_ipc(file_path) + + helper_dataset_test( + file_path, + lambda lf: lf.filter("bools").select("bools", "floats", "date"), + n_expected=1, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"), + n_expected=2, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("int_nulls").is_null()).select( + "bools", "floats", "date" + ), + n_expected=1, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("int_nulls").is_not_null()).select( + "bools", "floats", "date" + ), + n_expected=2, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter( + pl.col("int_nulls").is_not_null() == pl.col("bools") + ).select("bools", "floats", "date"), + n_expected=0, + check_predicate_pushdown=True, + ) + # this equality on a column with nulls fails as pyarrow has different + # handling kleene logic. We leave it for now and document it in the function. + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("int") == 10).select( + "bools", "floats", "int_nulls" + ), + n_expected=0, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("int") != 10).select( + "bools", "floats", "int_nulls" + ), + n_expected=3, + check_predicate_pushdown=True, + ) + + for closed, n_expected in zip(["both", "left", "right", "none"], [3, 2, 2, 1]): + helper_dataset_test( + file_path, + lambda lf, closed=closed: lf.filter( # type: ignore[misc] + pl.col("int").is_between(1, 3, closed=closed) + ).select("bools", "floats", "date"), + n_expected=n_expected, + check_predicate_pushdown=True, + ) + # this predicate is not supported by pyarrow + # check if we still do it on our side + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10).select( + "bools", "floats", "date" + ), + n_expected=0, + ) + # temporal types + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)).select( + "bools", "floats", "date" + ), + n_expected=1, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter( + pl.col("datetime") > datetime(1970, 1, 1, second=13) + ).select("bools", "floats", "date"), + n_expected=1, + check_predicate_pushdown=True, + ) + # not yet supported in pyarrow + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)).select( + "bools", "time", "date" + ), + n_expected=3, + check_predicate_pushdown=True, + ) + # pushdown is_in + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])).select( + "bools", "floats", "date" + ), + n_expected=2, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter( + pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)]) + ).select("bools", "floats", "date"), + n_expected=2, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter( + pl.col("datetime").is_in( + [ + datetime(1970, 1, 1, 0, 0, 12, 341234), + datetime(1970, 1, 1, 0, 0, 13, 241324), + ] + ) + ).select("bools", "floats", "date"), + n_expected=2, + check_predicate_pushdown=True, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))).select( + "bools", "floats", "date" + ), + n_expected=3, + check_predicate_pushdown=True, + ) + # TODO: remove string cache + with pl.StringCache(): + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("cat").is_in([])).select( + "bools", "floats", "date" + ), + n_expected=0, + ) + helper_dataset_test( + file_path, + lambda lf: lf.select(pl.exclude("enum")), + batch_size=2, + n_expected=3, + ) + + # direct filter + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.Series([True, False, True])).select( + "bools", "floats", "date" + ), + n_expected=2, + ) + + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("bools") & pl.col("int").is_in([1, 2])).select( + "bools", "floats" + ), + n_expected=1, + check_predicate_pushdown=True, + ) + + +def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None: + df0 = pl.DataFrame({"a": [1, 2, 3]}) + + df1 = pl.DataFrame({"a": [1, 2]}) + + file_path_0 = tmp_path / "0.parquet" + file_path_1 = tmp_path / "1.parquet" + + df0.write_parquet(file_path_0) + df1.write_parquet(file_path_1) + + ds0 = ds.dataset(file_path_0, format="parquet") + ds1 = ds.dataset(file_path_1, format="parquet") + + lf0 = pl.scan_pyarrow_dataset(ds0) + lf1 = pl.scan_pyarrow_dataset(ds1) + + assert lf0.join(lf1, on="a", how="inner").collect().to_dict(as_series=False) == { + "a": [1, 2] + } diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py new file mode 100644 index 000000000000..7d0570f8eea7 --- /dev/null +++ b/py-polars/tests/unit/io/test_scan.py @@ -0,0 +1,1091 @@ +from __future__ import annotations + +import io +import sys +from dataclasses import dataclass +from datetime import datetime +from functools import partial +from math import ceil +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable + +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import SchemaDict + + +@dataclass +class _RowIndex: + name: str = "index" + offset: int = 0 + + +def _enable_force_async(monkeypatch: pytest.MonkeyPatch) -> None: + """Modifies the provided monkeypatch context.""" + monkeypatch.setenv("POLARS_VERBOSE", "1") + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + + +def _scan( + file_path: Path, + schema: SchemaDict | None = None, + row_index: _RowIndex | None = None, +) -> pl.LazyFrame: + suffix = file_path.suffix + row_index_name = None if row_index is None else row_index.name + row_index_offset = 0 if row_index is None else row_index.offset + + if ( + scan_func := { + ".ipc": pl.scan_ipc, + ".parquet": pl.scan_parquet, + ".csv": pl.scan_csv, + ".ndjson": pl.scan_ndjson, + }.get(suffix) + ) is not None: # fmt: skip + result = scan_func( + file_path, + row_index_name=row_index_name, + row_index_offset=row_index_offset, + ) # type: ignore[operator] + + else: + msg = f"Unknown suffix {suffix}" + raise NotImplementedError(msg) + + return result # type: ignore[no-any-return] + + +def _write(df: pl.DataFrame, file_path: Path) -> None: + suffix = file_path.suffix + + if ( + write_func := { + ".ipc": pl.DataFrame.write_ipc, + ".parquet": pl.DataFrame.write_parquet, + ".csv": pl.DataFrame.write_csv, + ".ndjson": pl.DataFrame.write_ndjson, + }.get(suffix) + ) is not None: # fmt: skip + return write_func(df, file_path) # type: ignore[operator, no-any-return] + + msg = f"Unknown suffix {suffix}" + raise NotImplementedError(msg) + + +@pytest.fixture( + scope="session", + params=["csv", "ipc", "parquet", "ndjson"], +) +def data_file_extension(request: pytest.FixtureRequest) -> str: + return f".{request.param}" + + +@pytest.fixture(scope="session") +def session_tmp_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + return tmp_path_factory.mktemp("polars-test") + + +@pytest.fixture( + params=[False, True], + ids=["sync", "async"], +) +def force_async( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +) -> bool: + value: bool = request.param + return value + + +@dataclass +class _DataFile: + path: Path + df: pl.DataFrame + + +def df_with_chunk_size_limit(df: pl.DataFrame, limit: int) -> pl.DataFrame: + return pl.concat( + ( + df.slice(i * limit, min(limit, df.height - i * limit)) + for i in range(ceil(df.height / limit)) + ), + rechunk=False, + ) + + +@pytest.fixture(scope="session") +def data_file_single(session_tmp_dir: Path, data_file_extension: str) -> _DataFile: + max_rows_per_batch = 727 + file_path = (session_tmp_dir / "data").with_suffix(data_file_extension) + df = pl.DataFrame( + { + "sequence": range(10000), + } + ) + assert max_rows_per_batch < df.height + _write(df_with_chunk_size_limit(df, max_rows_per_batch), file_path) + return _DataFile(path=file_path, df=df) + + +@pytest.fixture(scope="session") +def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile: + max_rows_per_batch = 200 + row_counts = [ + 100, 186, 95, 185, 90, 84, 115, 81, 87, 217, 126, 85, 98, 122, 129, 122, 1089, 82, + 234, 86, 93, 90, 91, 263, 87, 126, 86, 161, 191, 1368, 403, 192, 102, 98, 115, 81, + 111, 305, 92, 534, 431, 150, 90, 128, 152, 118, 127, 124, 229, 368, 81, + ] # fmt: skip + assert sum(row_counts) == 10000 + + # Make sure we pad file names with enough zeros to ensure correct + # lexicographical ordering. + assert len(row_counts) < 100 + + # Make sure that some of our data frames consist of multiple chunks which + # affects the output of certain file formats. + assert any(row_count > max_rows_per_batch for row_count in row_counts) + df = pl.DataFrame( + { + "sequence": range(10000), + } + ) + + row_offset = 0 + for index, row_count in enumerate(row_counts): + file_path = (session_tmp_dir / f"data_{index:02}").with_suffix( + data_file_extension + ) + _write( + df_with_chunk_size_limit( + df.slice(row_offset, row_count), max_rows_per_batch + ), + file_path, + ) + row_offset += row_count + return _DataFile( + path=(session_tmp_dir / "data_*").with_suffix(data_file_extension), df=df + ) + + +@pytest.fixture(scope="session", params=["single", "glob"]) +def data_file( + request: pytest.FixtureRequest, + data_file_single: _DataFile, + data_file_glob: _DataFile, +) -> _DataFile: + if request.param == "single": + return data_file_single + if request.param == "glob": + return data_file_glob + raise NotImplementedError() + + +@pytest.mark.write_disk +def test_scan( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = _scan(data_file.path, data_file.df.schema).collect() + + assert_frame_equal(df, data_file.df) + + +@pytest.mark.write_disk +def test_scan_with_limit( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = _scan(data_file.path, data_file.df.schema).limit(4483).collect() + + assert_frame_equal( + df, + pl.DataFrame( + { + "sequence": range(4483), + } + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_filter( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema) + .filter(pl.col("sequence") % 2 == 0) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "sequence": (2 * x for x in range(5000)), + } + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_filter_and_limit( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema) + .filter(pl.col("sequence") % 2 == 0) + .limit(4483) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "sequence": (2 * x for x in range(4483)), + }, + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_limit_and_filter( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema) + .limit(4483) + .filter(pl.col("sequence") % 2 == 0) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "sequence": (2 * x for x in range(2242)), + }, + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_row_index_and_limit( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()) + .limit(4483) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "index": range(4483), + "sequence": range(4483), + }, + schema_overrides={"index": pl.UInt32}, + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_row_index_and_filter( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()) + .filter(pl.col("sequence") % 2 == 0) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "index": (2 * x for x in range(5000)), + "sequence": (2 * x for x in range(5000)), + }, + schema_overrides={"index": pl.UInt32}, + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_row_index_limit_and_filter( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()) + .limit(4483) + .filter(pl.col("sequence") % 2 == 0) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "index": (2 * x for x in range(2242)), + "sequence": (2 * x for x in range(2242)), + }, + schema_overrides={"index": pl.UInt32}, + ), + ) + + +@pytest.mark.write_disk +def test_scan_with_row_index_projected_out( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if data_file.path.suffix == ".csv" and force_async: + pytest.skip(reason="async reading of .csv not yet implemented") + + if force_async: + _enable_force_async(monkeypatch) + + subset = next(iter(data_file.df.schema.keys())) + df = ( + _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()) + .select(subset) + .collect() + ) + + assert_frame_equal(df, data_file.df.select(subset)) + + +@pytest.mark.write_disk +def test_scan_with_row_index_filter_and_limit( + capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool +) -> None: + if data_file.path.suffix == ".csv" and force_async: + pytest.skip(reason="async reading of .csv not yet implemented") + + if force_async: + _enable_force_async(monkeypatch) + + df = ( + _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()) + .filter(pl.col("sequence") % 2 == 0) + .limit(4483) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "index": (2 * x for x in range(4483)), + "sequence": (2 * x for x in range(4483)), + }, + schema_overrides={"index": pl.UInt32}, + ), + ) + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +@pytest.mark.parametrize( + "streaming", + [True, False], +) +def test_scan_limit_0_does_not_panic( + tmp_path: Path, + scan_func: Callable[[Any], pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.bin" + df = pl.DataFrame({"x": 1}) + write_func(df, path) + assert_frame_equal( + scan_func(path) + .head(0) + .collect(engine="streaming" if streaming else "in-memory"), + df.clear(), + ) + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +@pytest.mark.parametrize( + "glob", + [True, False], +) +def test_scan_directory( + tmp_path: Path, + scan_func: Callable[..., pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], + glob: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + + dfs: list[pl.DataFrame] = [ + pl.DataFrame({"a": [0, 0, 0, 0, 0]}), + pl.DataFrame({"a": [1, 1, 1, 1, 1]}), + pl.DataFrame({"a": [2, 2, 2, 2, 2]}), + ] + + paths = [ + tmp_path / "0.bin", + tmp_path / "1.bin", + tmp_path / "dir/data.bin", + ] + + for df, path in zip(dfs, paths): + path.parent.mkdir(exist_ok=True) + write_func(df, path) + + df = pl.concat(dfs) + + scan = scan_func + + if scan_func in [pl.scan_csv, pl.scan_ndjson]: + scan = partial(scan, schema=df.schema) + + if scan_func is pl.scan_parquet: + scan = partial(scan, glob=glob) + + out = scan(tmp_path).collect() + assert_frame_equal(out, df) + + +@pytest.mark.write_disk +def test_scan_glob_excludes_directories(tmp_path: Path) -> None: + for dir in ["dir1", "dir2", "dir3"]: + (tmp_path / dir).mkdir() + + df = pl.DataFrame({"a": [1, 2, 3]}) + + df.write_parquet(tmp_path / "dir1/data.bin") + df.write_parquet(tmp_path / "dir2/data.parquet") + df.write_parquet(tmp_path / "data.parquet") + + assert_frame_equal(pl.scan_parquet(tmp_path / "**/*.bin").collect(), df) + assert_frame_equal(pl.scan_parquet(tmp_path / "**/data*.bin").collect(), df) + assert_frame_equal( + pl.scan_parquet(tmp_path / "**/*").collect(), pl.concat(3 * [df]) + ) + assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df) + + +@pytest.mark.parametrize("file_name", ["a b", "a %25 b"]) +@pytest.mark.write_disk +def test_scan_async_whitespace_in_path( + tmp_path: Path, monkeypatch: Any, file_name: str +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + tmp_path.mkdir(exist_ok=True) + + path = tmp_path / f"{file_name}.parquet" + df = pl.DataFrame({"x": 1}) + df.write_parquet(path) + assert_frame_equal(pl.scan_parquet(path).collect(), df) + assert_frame_equal(pl.scan_parquet(tmp_path).collect(), df) + assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df) + assert_frame_equal(pl.scan_parquet(tmp_path / "*.parquet").collect(), df) + path.unlink() + + +@pytest.mark.write_disk +def test_path_expansion_excludes_empty_files_17362(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": 1}) + df.write_parquet(tmp_path / "data.parquet") + (tmp_path / "empty").touch() + + assert_frame_equal(pl.scan_parquet(tmp_path).collect(), df) + assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df) + + +@pytest.mark.write_disk +def test_path_expansion_empty_directory_does_not_panic(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + with pytest.raises(pl.exceptions.ComputeError): + pl.scan_parquet(tmp_path).collect() + + with pytest.raises(pl.exceptions.ComputeError): + pl.scan_parquet(tmp_path / "**/*").collect() + + +@pytest.mark.write_disk +def test_scan_single_dir_differing_file_extensions_raises_17436(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": 1}) + df.write_parquet(tmp_path / "data.parquet") + df.write_ipc(tmp_path / "data.ipc") + + with pytest.raises( + pl.exceptions.InvalidOperationError, match="different file extensions" + ): + pl.scan_parquet(tmp_path).collect() + + for lf in [ + pl.scan_parquet(tmp_path / "*.parquet"), + pl.scan_ipc(tmp_path / "*.ipc"), + ]: + assert_frame_equal(lf.collect(), df) + + # Ensure passing a glob doesn't trigger file extension checking + with pytest.raises( + pl.exceptions.ComputeError, + match="parquet: File out of specification: The file must end with PAR1", + ): + pl.scan_parquet(tmp_path / "*").collect() + + +@pytest.mark.parametrize("format", ["parquet", "csv", "ndjson", "ipc"]) +def test_scan_nonexistent_path(format: str) -> None: + path_str = f"my-nonexistent-data.{format}" + path = Path(path_str) + assert not path.exists() + + scan_function = getattr(pl, f"scan_{format}") + + # Just calling the scan function should not raise any errors + result = scan_function(path) + assert isinstance(result, pl.LazyFrame) + + # Upon collection, it should fail + with pytest.raises(FileNotFoundError): + result.collect() + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("scan_func", "write_func"), + [ + (pl.scan_parquet, pl.DataFrame.write_parquet), + (pl.scan_ipc, pl.DataFrame.write_ipc), + (pl.scan_csv, pl.DataFrame.write_csv), + (pl.scan_ndjson, pl.DataFrame.write_ndjson), + ], +) +@pytest.mark.parametrize( + "streaming", + [True, False], +) +def test_scan_include_file_paths( + tmp_path: Path, + scan_func: Callable[..., pl.LazyFrame], + write_func: Callable[[pl.DataFrame, Path], None], + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + dfs: list[pl.DataFrame] = [] + + for x in ["1", "2"]: + path = Path(f"{tmp_path}/{x}.bin").absolute() + dfs.append(pl.DataFrame({"x": 10 * [x]}).with_columns(path=pl.lit(str(path)))) + write_func(dfs[-1].drop("path"), path) + + df = pl.concat(dfs) + assert df.columns == ["x", "path"] + + with pytest.raises( + pl.exceptions.DuplicateError, + match=r'column name for file paths "x" conflicts with column name from file', + ): + scan_func(tmp_path, include_file_paths="x").collect( + engine="streaming" if streaming else "in-memory" + ) + + f = scan_func + if scan_func in [pl.scan_csv, pl.scan_ndjson]: + f = partial(f, schema=df.drop("path").schema) + + lf: pl.LazyFrame = f(tmp_path, include_file_paths="path") + assert_frame_equal(lf.collect(engine="streaming" if streaming else "in-memory"), df) + + # Test projecting only the path column + q = lf.select("path") + assert q.collect_schema() == {"path": pl.String} + assert_frame_equal( + q.collect(engine="streaming" if streaming else "in-memory"), + df.select("path"), + ) + + q = q.select("path").head(3) + assert q.collect_schema() == {"path": pl.String} + assert_frame_equal( + q.collect(engine="streaming" if streaming else "in-memory"), + df.select("path").head(3), + ) + + # Test predicates + for predicate in [pl.col("path") != pl.col("x"), pl.col("path") != ""]: + assert_frame_equal( + lf.filter(predicate).collect( + engine="streaming" if streaming else "in-memory" + ), + df, + ) + + # Test codepaths that materialize empty DataFrames + assert_frame_equal( + lf.head(0).collect(engine="streaming" if streaming else "in-memory"), + df.head(0), + ) + + +@pytest.mark.write_disk +def test_async_path_expansion_bracket_17629(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + + df = pl.DataFrame({"x": 1}) + df.write_parquet(path) + + assert_frame_equal(pl.scan_parquet(tmp_path / "[d]ata.parquet").collect(), df) + + +@pytest.mark.parametrize( + "method", + ["parquet", "csv", "ipc", "ndjson"], +) +@pytest.mark.may_fail_auto_streaming # unsupported negative slice offset -1 for CSV source +def test_scan_in_memory(method: str) -> None: + f = io.BytesIO() + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + ) + + (getattr(df, f"write_{method}"))(f) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).collect() + assert_frame_equal(df, result) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).slice(1, 2).collect() + assert_frame_equal(df.slice(1, 2), result) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).slice(-1, 1).collect() + assert_frame_equal(df.slice(-1, 1), result) + + g = io.BytesIO() + (getattr(df, f"write_{method}"))(g) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).collect() + assert_frame_equal(df.vstack(df), result) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).slice(1, 2).collect() + assert_frame_equal(df.vstack(df).slice(1, 2), result) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).slice(-1, 1).collect() + assert_frame_equal(df.vstack(df).slice(-1, 1), result) + + +def test_scan_pyobject_zero_copy_buffer_mutate() -> None: + f = io.BytesIO() + + df = pl.DataFrame({"x": [1, 2, 3, 4, 5]}) + df.write_ipc(f) + f.seek(0) + + q = pl.scan_ipc(f) + assert_frame_equal(q.collect(), df) + + f.write(b"AAA") + assert_frame_equal(q.collect(), df) + + +@pytest.mark.parametrize( + "method", + ["csv", "ndjson"], +) +def test_scan_stringio(method: str) -> None: + f = io.StringIO() + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + ) + + (getattr(df, f"write_{method}"))(f) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).collect() + assert_frame_equal(df, result) + + g = io.StringIO() + (getattr(df, f"write_{method}"))(g) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).collect() + assert_frame_equal(df.vstack(df), result) + + +@pytest.mark.parametrize( + "method", + [pl.scan_parquet, pl.scan_csv, pl.scan_ipc, pl.scan_ndjson], +) +def test_empty_list(method: Callable[[list[str]], pl.LazyFrame]) -> None: + with pytest.raises(pl.exceptions.ComputeError, match="expected at least 1 source"): + _ = (method)([]).collect() + + +def test_scan_double_collect_row_index_invalidates_cached_ir_18892() -> None: + lf = pl.scan_csv(io.BytesIO(b"a\n1\n2\n3")) + + lf.collect() + + out = lf.with_row_index().collect() + + assert_frame_equal( + out, + pl.DataFrame( + {"index": [0, 1, 2], "a": [1, 2, 3]}, + schema={"index": pl.UInt32, "a": pl.Int64}, + ), + ) + + +def test_scan_include_file_paths_respects_projection_pushdown() -> None: + q = pl.scan_csv(b"a,b,c\na1,b1,c1", include_file_paths="path_name").select( + ["a", "b"] + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": "a1", "b": "b1"})) + + +def test_streaming_scan_csv_include_file_paths_18257(io_files_path: Path) -> None: + lf = pl.scan_csv( + io_files_path / "foods1.csv", + include_file_paths="path", + ).select("category", "path") + + assert lf.collect(engine="streaming").columns == ["category", "path"] + + +def test_streaming_scan_csv_with_row_index_19172(io_files_path: Path) -> None: + lf = ( + pl.scan_csv(io_files_path / "foods1.csv", infer_schema=False) + .with_row_index() + .select("calories", "index") + .head(1) + ) + + assert_frame_equal( + lf.collect(engine="streaming"), + pl.DataFrame( + {"calories": "45", "index": 0}, + schema={"calories": pl.String, "index": pl.UInt32}, + ), + ) + + +@pytest.mark.write_disk +def test_predicate_hive_pruning_with_cast(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": 1}) + + (p := (tmp_path / "date=2024-01-01")).mkdir() + + df.write_parquet(p / "1") + + (p := (tmp_path / "date=2024-01-02")).mkdir() + + # Write an invalid parquet file that will cause errors if polars attempts to + # read it. + # This works because `scan_parquet()` only looks at the first file during + # schema inference. + (p / "1").write_text("not a parquet file") + + expect = pl.DataFrame({"x": 1, "date": datetime(2024, 1, 1).date()}) + + lf = pl.scan_parquet(tmp_path) + + q = lf.filter(pl.col("date") < datetime(2024, 1, 2).date()) + + assert_frame_equal(q.collect(), expect) + + # This filter expr with stprtime is effectively what LazyFrame.sql() + # generates + q = lf.filter( + pl.col("date") + < pl.lit("2024-01-02").str.strptime( + dtype=pl.Date, format="%Y-%m-%d", ambiguous="latest" + ) + ) + + assert_frame_equal(q.collect(), expect) + + q = lf.sql("select * from self where date < '2024-01-02'") + print(q.explain()) + assert_frame_equal(q.collect(), expect) + + +def test_predicate_stats_eval_nested_binary() -> None: + bufs: list[bytes] = [] + + for i in range(10): + b = io.BytesIO() + pl.DataFrame({"x": i}).write_parquet(b) + b.seek(0) + bufs.append(b.read()) + + assert_frame_equal( + ( + pl.scan_parquet(bufs) + .filter(pl.col("x") % 2 == 0) + .collect(no_optimization=True) + ), + pl.DataFrame({"x": [0, 2, 4, 6, 8]}), + ) + + assert_frame_equal( + ( + pl.scan_parquet(bufs) + # The literal eval depth limit is 4 - + # * crates/polars-expr/src/expressions/mod.rs::PhysicalExpr::evaluate_inline + .filter(pl.col("x") == pl.lit("222").str.slice(0, 1).cast(pl.Int64)) + .collect() + ), + pl.DataFrame({"x": [2]}), + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("streaming", [True, False]) +def test_scan_csv_bytesio_memory_usage( + streaming: bool, + # memory_usage_without_pyarrow: MemoryUsage, +) -> None: + # memory_usage = memory_usage_without_pyarrow + + # Create CSV that is ~6-7 MB in size: + f = io.BytesIO() + df = pl.DataFrame({"mydata": pl.int_range(0, 1_000_000, eager=True)}) + df.write_csv(f) + # assert 6_000_000 < f.tell() < 7_000_000 + f.seek(0, 0) + + # A lazy scan shouldn't make a full copy of the data: + # starting_memory = memory_usage.get_current() + assert ( + pl.scan_csv(f) + .filter(pl.col("mydata") == 999_999) + .collect(engine="streaming" if streaming else "in-memory") + .item() + == 999_999 + ) + # assert memory_usage.get_peak() - starting_memory < 1_000_000 + + +@pytest.mark.parametrize( + "scan_type", + [ + (pl.DataFrame.write_parquet, pl.scan_parquet), + (pl.DataFrame.write_ipc, pl.scan_ipc), + (pl.DataFrame.write_csv, pl.scan_csv), + (pl.DataFrame.write_ndjson, pl.scan_ndjson), + ], +) +def test_only_project_row_index(scan_type: tuple[Any, Any]) -> None: + write, scan = scan_type + + f = io.BytesIO() + df = pl.DataFrame([pl.Series("a", [1, 2, 3], pl.UInt32)]) + write(df, f) + + f.seek(0) + s = scan(f, row_index_name="row_index", row_index_offset=42) + + assert_frame_equal( + s.select("row_index").collect(), + pl.DataFrame({"row_index": [42, 43, 44]}), + check_dtypes=False, + ) + + +@pytest.mark.parametrize( + "scan_type", + [ + (pl.DataFrame.write_parquet, pl.scan_parquet), + (pl.DataFrame.write_ipc, pl.scan_ipc), + (pl.DataFrame.write_csv, pl.scan_csv), + (pl.DataFrame.write_ndjson, pl.scan_ndjson), + ], +) +def test_only_project_include_file_paths(scan_type: tuple[Any, Any]) -> None: + write, scan = scan_type + + f = io.BytesIO() + df = pl.DataFrame([pl.Series("a", [1, 2, 3], pl.UInt32)]) + write(df, f) + + f.seek(0) + s = scan(f, include_file_paths="file_path") + + # The exact value for in-memory buffers is undefined + c = s.select("file_path").collect() + assert c.height == 3 + assert c.columns == ["file_path"] + + +@pytest.mark.parametrize( + "scan_type", + [ + (pl.DataFrame.write_parquet, pl.scan_parquet), + pytest.param( + (pl.DataFrame.write_ipc, pl.scan_ipc), + marks=pytest.mark.xfail( + reason="has no allow_missing_columns parameter. https://github.com/pola-rs/polars/issues/21166" + ), + ), + pytest.param( + (pl.DataFrame.write_csv, pl.scan_csv), + marks=pytest.mark.xfail( + reason="has no allow_missing_columns parameter. https://github.com/pola-rs/polars/issues/21166" + ), + ), + pytest.param( + (pl.DataFrame.write_ndjson, pl.scan_ndjson), + marks=pytest.mark.xfail( + reason="has no allow_missing_columns parameter. https://github.com/pola-rs/polars/issues/21166" + ), + ), + ], +) +def test_only_project_missing(scan_type: tuple[Any, Any]) -> None: + write, scan = scan_type + + f = io.BytesIO() + g = io.BytesIO() + write( + pl.DataFrame( + [pl.Series("a", [], pl.UInt32), pl.Series("missing", [], pl.Int32)] + ), + f, + ) + write(pl.DataFrame([pl.Series("a", [1, 2, 3], pl.UInt32)]), g) + + f.seek(0) + g.seek(0) + s = scan([f, g], allow_missing_columns=True) + + assert_frame_equal( + s.select("missing").collect(), + pl.DataFrame([pl.Series("missing", [None, None, None], pl.Int32)]), + ) + + +@pytest.mark.skipif(sys.platform == "win32", reason="windows paths are a mess") +@pytest.mark.write_disk +@pytest.mark.parametrize( + "scan_type", + [ + (pl.DataFrame.write_parquet, pl.scan_parquet), + (pl.DataFrame.write_ipc, pl.scan_ipc), + (pl.DataFrame.write_csv, pl.scan_csv), + (pl.DataFrame.write_ndjson, pl.scan_ndjson), + ], +) +def test_async_read_21945(tmp_path: Path, scan_type: tuple[Any, Any]) -> None: + f1 = tmp_path / "f1" + f2 = tmp_path / "f2" + + pl.DataFrame({"value": [1, 2]}).write_parquet(f1) + pl.DataFrame({"value": [3]}).write_parquet(f2) + + df = ( + pl.scan_parquet(["file://" + str(f1), str(f2)], include_file_paths="foo") + .filter(value=1) + .collect() + ) + + assert_frame_equal( + df, pl.DataFrame({"value": [1], "foo": ["file://" + f1.as_posix()]}) + ) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("with_str_contains", [False, True]) +def test_hive_pruning_str_contains_21706( + tmp_path: Path, capfd: Any, monkeypatch: Any, with_str_contains: bool +) -> None: + df = pl.DataFrame( + { + "pdate": [20250301, 20250301, 20250302, 20250302, 20250303, 20250303], + "prod_id": ["A1", "A2", "B1", "B2", "C1", "C2"], + "price": [11, 22, 33, 44, 55, 66], + } + ) + + df.write_parquet(tmp_path, partition_by=["pdate"]) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + f = pl.col("pdate") == 20250303 + if with_str_contains: + f = f & pl.col("prod_id").str.contains("1") + + df = pl.scan_parquet(tmp_path, hive_partitioning=True).filter(f).collect() + + captured = capfd.readouterr().err + assert "allows skipping 2 / 3" in captured + + assert_frame_equal( + df, + pl.scan_parquet(tmp_path, hive_partitioning=True).collect().filter(f), + ) diff --git a/py-polars/tests/unit/io/test_sink.py b/py-polars/tests/unit/io/test_sink.py new file mode 100644 index 000000000000..9f1663f98c11 --- /dev/null +++ b/py-polars/tests/unit/io/test_sink.py @@ -0,0 +1,108 @@ +import io +from pathlib import Path +from typing import Any + +import pytest + +import polars as pl +from polars._typing import EngineType +from polars.testing import assert_frame_equal + +SINKS = [ + (pl.scan_ipc, pl.LazyFrame.sink_ipc), + (pl.scan_parquet, pl.LazyFrame.sink_parquet), + (pl.scan_csv, pl.LazyFrame.sink_csv), + (pl.scan_ndjson, pl.LazyFrame.sink_ndjson), +] + + +@pytest.mark.parametrize(("scan", "sink"), SINKS) +@pytest.mark.parametrize("engine", ["in-memory", "streaming"]) +@pytest.mark.write_disk +def test_mkdir(tmp_path: Path, scan: Any, sink: Any, engine: EngineType) -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + } + ) + + with pytest.raises(FileNotFoundError): + sink(df.lazy(), tmp_path / "a" / "b" / "c" / "file", engine=engine) + + f = tmp_path / "a" / "b" / "c" / "file2" + sink(df.lazy(), f, mkdir=True) + + assert_frame_equal(scan(f).collect(), df) + + +@pytest.mark.parametrize(("scan", "sink"), SINKS) +@pytest.mark.parametrize("engine", ["in-memory", "streaming"]) +@pytest.mark.write_disk +def test_lazy_sinks(tmp_path: Path, scan: Any, sink: Any, engine: EngineType) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + lf1 = sink(df.lazy(), tmp_path / "a", lazy=True) + lf2 = sink(df.lazy(), tmp_path / "b", lazy=True) + + assert not Path(tmp_path / "a").exists() + assert not Path(tmp_path / "b").exists() + + pl.collect_all([lf1, lf2], engine=engine) + + assert_frame_equal(scan(tmp_path / "a").collect(), df) + assert_frame_equal(scan(tmp_path / "b").collect(), df) + + +@pytest.mark.parametrize( + "sink", + [ + pl.LazyFrame.sink_ipc, + pl.LazyFrame.sink_parquet, + pl.LazyFrame.sink_csv, + pl.LazyFrame.sink_ndjson, + ], +) +@pytest.mark.write_disk +def test_double_lazy_error(sink: Any) -> None: + df = pl.DataFrame({}) + + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="cannot create a sink on top of another sink", + ): + sink(sink(df.lazy(), "a", lazy=True), "b") + + +@pytest.mark.parametrize(("scan", "sink"), SINKS) +def test_sink_to_memory(sink: Any, scan: Any) -> None: + df = pl.DataFrame( + { + "a": [5, 10, 1996], + } + ) + + f = io.BytesIO() + sink(df.lazy(), f) + + f.seek(0) + assert_frame_equal( + scan(f).collect(), + df, + ) + + +@pytest.mark.parametrize(("scan", "sink"), SINKS) +@pytest.mark.write_disk +def test_sink_to_file(tmp_path: Path, sink: Any, scan: Any) -> None: + df = pl.DataFrame( + { + "a": [5, 10, 1996], + } + ) + + with (tmp_path / "f").open("w+") as f: + sink(df.lazy(), f, sync_on_close="all") + f.seek(0) + assert_frame_equal( + scan(f).collect(), + df, + ) diff --git a/py-polars/tests/unit/io/test_skip_batch_predicate.py b/py-polars/tests/unit/io/test_skip_batch_predicate.py new file mode 100644 index 000000000000..30ae14162467 --- /dev/null +++ b/py-polars/tests/unit/io/test_skip_batch_predicate.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import contextlib +import datetime +from typing import TYPE_CHECKING, Any, TypedDict + +from hypothesis import Phase, given, settings + +import polars as pl +from polars.meta import get_index_type +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric.strategies import series + +if TYPE_CHECKING: + from collections.abc import Sequence + + from polars._typing import PythonLiteral + + +class Case(TypedDict): + """A test case for Skip Batch Predicate.""" + + min: Any | None + max: Any | None + null_count: int | None + len: int | None + can_skip: bool + + +def assert_skp_series( + name: str, + dtype: pl.DataType, + expr: pl.Expr, + cases: Sequence[Case], +) -> None: + sbp = expr._skip_batch_predicate({name: dtype}) + + df = pl.DataFrame( + [ + pl.Series(f"{name}_min", [i["min"] for i in cases], dtype), + pl.Series(f"{name}_max", [i["max"] for i in cases], dtype), + pl.Series(f"{name}_nc", [i["null_count"] for i in cases], get_index_type()), + pl.Series("len", [i["len"] for i in cases], get_index_type()), + ] + ) + mask = pl.Series("can_skip", [i["can_skip"] for i in cases], pl.Boolean) + + out = df.select(can_skip=sbp).to_series() + out = out.replace(None, False) + + try: + assert_series_equal(out, mask) + except AssertionError: + print(sbp) + raise + + +def test_true_false_predicate() -> None: + true_sbp = pl.lit(True)._skip_batch_predicate({}) + false_sbp = pl.lit(False)._skip_batch_predicate({}) + null_sbp = pl.lit(None)._skip_batch_predicate({}) + + df = pl.DataFrame({"len": [1]}) + + out = df.select( + true=true_sbp, + false=false_sbp, + null=null_sbp, + ) + + assert_frame_equal( + out, + pl.DataFrame( + { + "true": [False], + "false": [True], + "null": [True], + } + ), + ) + + +def test_equality() -> None: + assert_skp_series( + "a", + pl.Int64(), + pl.col("a") == 5, + [ + {"min": 1, "max": 2, "null_count": 0, "len": 42, "can_skip": True}, + {"min": 6, "max": 7, "null_count": 0, "len": 42, "can_skip": True}, + {"min": 1, "max": 7, "null_count": 0, "len": 42, "can_skip": False}, + {"min": None, "max": None, "null_count": 42, "len": 42, "can_skip": True}, + ], + ) + + assert_skp_series( + "a", + pl.Int64(), + pl.col("a") != 0, + [ + {"min": 0, "max": 0, "null_count": 6, "len": 7, "can_skip": False}, + ], + ) + + +def test_datetimes() -> None: + d = datetime.datetime(2023, 4, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) + td = datetime.timedelta + + assert_skp_series( + "a", + pl.Datetime(time_zone=datetime.timezone.utc), + pl.col("a") == d, + [ + { + "min": d - td(days=2), + "max": d - td(days=1), + "null_count": 0, + "len": 42, + "can_skip": True, + }, + { + "min": d + td(days=1), + "max": d - td(days=2), + "null_count": 0, + "len": 42, + "can_skip": True, + }, + {"min": d, "max": d, "null_count": 42, "len": 42, "can_skip": True}, + {"min": d, "max": d, "null_count": 0, "len": 42, "can_skip": False}, + { + "min": d - td(days=2), + "max": d + td(days=2), + "null_count": 0, + "len": 42, + "can_skip": False, + }, + { + "min": d + td(days=1), + "max": None, + "null_count": None, + "len": None, + "can_skip": True, + }, + ], + ) + + +@given( + s=series( + name="x", + min_size=1, + ), +) +@settings( + report_multiple_bugs=False, + phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.target, Phase.explain), +) +def test_skip_batch_predicate_parametric(s: pl.Series) -> None: + name = "x" + dtype = s.dtype + + value_a = s.slice(0, 1) + + lit_a = pl.lit(value_a[0], dtype) + + exprs = [ + pl.col.x == lit_a, + pl.col.x != lit_a, + pl.col.x.eq_missing(lit_a), + pl.col.x.ne_missing(lit_a), + pl.col.x.is_null(), + pl.col.x.is_not_null(), + ] + + try: + _ = s > value_a + exprs += [ + pl.col.x > lit_a, + pl.col.x >= lit_a, + pl.col.x < lit_a, + pl.col.x <= lit_a, + pl.col.x.is_in(pl.Series([None, value_a[0]], dtype=dtype)), + ] + + if s.len() > 1: + value_b = s.slice(1, 1) + lit_b = pl.lit(value_b[0], dtype) + + exprs += [ + pl.col.x.is_between(lit_a, lit_b), + pl.col.x.is_in(pl.Series([value_a[0], value_b[0]], dtype=dtype)), + ] + except Exception as _: + pass + + for expr in exprs: + sbp = expr._skip_batch_predicate({name: dtype}) + + if sbp is None: + continue + + mins: list[PythonLiteral | None] = [None] + with contextlib.suppress(Exception): + mins = [s.min()] + + maxs: list[PythonLiteral | None] = [None] + with contextlib.suppress(Exception): + maxs = [s.max()] + + null_counts = [s.null_count()] + lengths = [s.len()] + + df = pl.DataFrame( + [ + pl.Series(f"{name}_min", mins, dtype), + pl.Series(f"{name}_max", maxs, dtype), + pl.Series(f"{name}_nc", null_counts, get_index_type()), + pl.Series("len", lengths, get_index_type()), + ] + ) + + can_skip = df.select(can_skip=sbp).fill_null(False).to_series()[0] + if can_skip: + try: + assert s.to_frame().filter(expr).height == 0 + except Exception as _: + print(expr) + print(sbp) + print(df) + print(s.to_frame().filter(expr)) + + raise diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py new file mode 100644 index 000000000000..3303fcac7787 --- /dev/null +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -0,0 +1,1446 @@ +from __future__ import annotations + +import warnings +from collections import OrderedDict +from datetime import date, datetime, time +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ( + NoDataError, + ParameterCollisionError, +) +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES, NUMERIC_DTYPES + +if TYPE_CHECKING: + from collections.abc import Sequence + + from polars._typing import ( + ExcelSpreadsheetEngine, + PolarsDataType, + SchemaDict, + SelectorType, + ) + + +# pytestmark = pytest.mark.slow() + + +@pytest.fixture +def path_xls(io_files_path: Path) -> Path: + # old excel 97-2004 format + return io_files_path / "example.xls" + + +@pytest.fixture +def path_xlsx(io_files_path: Path) -> Path: + # modern excel format + return io_files_path / "example.xlsx" + + +@pytest.fixture +def path_xlsb(io_files_path: Path) -> Path: + # excel binary format + return io_files_path / "example.xlsb" + + +@pytest.fixture +def path_ods(io_files_path: Path) -> Path: + # open document spreadsheet + return io_files_path / "example.ods" + + +@pytest.fixture +def path_xls_empty(io_files_path: Path) -> Path: + return io_files_path / "empty.xls" + + +@pytest.fixture +def path_xlsx_empty(io_files_path: Path) -> Path: + return io_files_path / "empty.xlsx" + + +@pytest.fixture +def path_xlsx_mixed(io_files_path: Path) -> Path: + return io_files_path / "mixed.xlsx" + + +@pytest.fixture +def path_xlsb_empty(io_files_path: Path) -> Path: + return io_files_path / "empty.xlsb" + + +@pytest.fixture +def path_xlsb_mixed(io_files_path: Path) -> Path: + return io_files_path / "mixed.xlsb" + + +@pytest.fixture +def path_ods_empty(io_files_path: Path) -> Path: + return io_files_path / "empty.ods" + + +@pytest.fixture +def path_ods_mixed(io_files_path: Path) -> Path: + return io_files_path / "mixed.ods" + + +@pytest.fixture +def path_empty_rows_excel(io_files_path: Path) -> Path: + return io_files_path / "test_empty_rows.xlsx" + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "engine_params"), + [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), + # xlsx file + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + # xlsb file (binary) + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), + # open document + (pl.read_ods, "path_ods", {}), + ], +) +def test_read_spreadsheet( + read_spreadsheet: Callable[..., pl.DataFrame], + source: str, + engine_params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + sheet_params: dict[str, Any] + + for sheet_params in ( + {"sheet_name": None, "sheet_id": None}, + {"sheet_name": "test1"}, + {"sheet_id": 1}, + ): + df = read_spreadsheet( + source=request.getfixturevalue(source), + **engine_params, + **sheet_params, + ) + expected = pl.DataFrame({"hello": ["Row 1", "Row 2"]}) + assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), + # xlsx file + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + # xlsb file (binary) + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), + # open document + (pl.read_ods, "path_ods", {}), + ], +) +def test_read_excel_multiple_worksheets( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + frames_by_id = read_spreadsheet( + spreadsheet_path, + sheet_id=[2, 1], + sheet_name=None, + **params, + ) + frames_by_name = read_spreadsheet( + spreadsheet_path, + sheet_id=None, + sheet_name=["test2", "test1"], + **params, + ) + for frames in (frames_by_id, frames_by_name): + assert list(frames_by_name) == ["test2", "test1"] + + expected1 = pl.DataFrame({"hello": ["Row 1", "Row 2"]}) + expected2 = pl.DataFrame({"world": ["Row 3", "Row 4"]}) + + assert_frame_equal(frames["test1"], expected1) + assert_frame_equal(frames["test2"], expected2) + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), + # xlsx file + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + # xlsb file (binary) + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), + # open document + (pl.read_ods, "path_ods", {}), + ], +) +def test_read_excel_multiple_workbooks( + read_spreadsheet: Callable[..., Any], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + + # multiple workbooks, single worksheet + df = read_spreadsheet( + [ + spreadsheet_path, + spreadsheet_path, + spreadsheet_path, + ], + sheet_id=None, + sheet_name="test1", + include_file_paths="path", + **params, + ) + expected = pl.DataFrame( + { + "hello": ["Row 1", "Row 2", "Row 1", "Row 2", "Row 1", "Row 2"], + "path": [str(spreadsheet_path)] * 6, + }, + ) + assert_frame_equal(df, expected) + + # multiple workbooks, multiple worksheets + res = read_spreadsheet( + [ + spreadsheet_path, + spreadsheet_path, + spreadsheet_path, + ], + sheet_id=None, + sheet_name=["test1", "test2"], + **params, + ) + expected_frames = { + "test1": pl.DataFrame( + {"hello": ["Row 1", "Row 2", "Row 1", "Row 2", "Row 1", "Row 2"]} + ), + "test2": pl.DataFrame( + {"world": ["Row 3", "Row 4", "Row 3", "Row 4", "Row 3", "Row 4"]} + ), + } + assert sorted(res) == sorted(expected_frames) + assert_frame_equal(res["test1"], expected_frames["test1"]) + assert_frame_equal(res["test2"], expected_frames["test2"]) + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), + # xlsx file + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + # xlsb file (binary) + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), + # open document + (pl.read_ods, "path_ods", {}), + ], +) +def test_read_excel_all_sheets( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + frames = read_spreadsheet( + spreadsheet_path, + sheet_id=0, + **params, + ) + assert len(frames) == (4 if str(spreadsheet_path).endswith("ods") else 6) + + expected1 = pl.DataFrame({"hello": ["Row 1", "Row 2"]}) + expected2 = pl.DataFrame({"world": ["Row 3", "Row 4"]}) + expected3 = pl.DataFrame( + { + "cardinality": [1, 3, 15, 30, 150, 300], + "rows_by_key": [0.05059, 0.04478, 0.04414, 0.05245, 0.05395, 0.05677], + "iter_groups": [0.04806, 0.04223, 0.04774, 0.04864, 0.0572, 0.06945], + } + ) + assert_frame_equal(frames["test1"], expected1) + assert_frame_equal(frames["test2"], expected2) + if params.get("engine") == "openpyxl": + # TODO: flag that trims trailing all-null rows? + assert_frame_equal(frames["test3"], expected3) + assert_frame_equal(frames["test4"].drop_nulls(), expected3) + + +@pytest.mark.parametrize( + "engine", + ["calamine", "openpyxl", "xlsx2csv"], +) +def test_read_excel_basic_datatypes(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "floats": [1.1, 1.2, 1.3, 1.4, 1.5], + "datetime": [datetime(2023, 1, x) for x in range(1, 6)], + "nulls": [1, None, None, None, 0], + }, + ) + xls = BytesIO() + df.write_excel(xls, position="C5") + + schema_overrides = {"datetime": pl.Datetime("us"), "nulls": pl.Boolean()} + df_compare = df.with_columns( + pl.col(nm).cast(tp) for nm, tp in schema_overrides.items() + ) + for sheet_id, sheet_name in ((None, None), (1, None), (None, "Sheet1")): + df_from_excel = pl.read_excel( + xls, + sheet_id=sheet_id, + sheet_name=sheet_name, + engine=engine, + schema_overrides=schema_overrides, + ) + assert_frame_equal(df_compare, df_from_excel) + + # check some additional overrides + # (note: xlsx2csv can't currently convert datetime with trailing '00:00:00' to date) + dt_override = {"datetime": pl.Date} if engine != "xlsx2csv" else {} + df = pl.read_excel( + xls, + sheet_id=sheet_id, + sheet_name=sheet_name, + engine=engine, + schema_overrides={"A": pl.Float32, **dt_override}, + ) + assert_series_equal( + df["A"], + pl.Series(name="A", values=[1.0, 2.0, 3.0, 4.0, 5.0], dtype=pl.Float32), + ) + if dt_override: + assert_series_equal( + df["datetime"], + pl.Series( + name="datetime", + values=[date(2023, 1, x) for x in range(1, 6)], + dtype=pl.Date, + ), + ) + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + # TODO: uncomment once fastexcel offers a suitable param + # (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + ], +) +def test_read_dropped_cols( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + + df1 = read_spreadsheet( + spreadsheet_path, + sheet_name="test4", + **params, + ) + df2 = read_spreadsheet( + spreadsheet_path, + sheet_name="test4", + drop_empty_cols=False, + **params, + ) + assert df1.to_dict(as_series=False) == { # type: ignore[attr-defined] + "cardinality": [1, 3, 15, 30, 150, 300], + "rows_by_key": [0.05059, 0.04478, 0.04414, 0.05245, 0.05395, 0.05677], + "iter_groups": [0.04806, 0.04223, 0.04774, 0.04864, 0.0572, 0.06945], + } + assert df2.to_dict(as_series=False) == { # type: ignore[attr-defined] + "": [None, None, None, None, None, None], + "cardinality": [1, 3, 15, 30, 150, 300], + "rows_by_key": [0.05059, 0.04478, 0.04414, 0.05245, 0.05395, 0.05677], + "iter_groups": [0.04806, 0.04223, 0.04774, 0.04864, 0.0572, 0.06945], + "0": [None, None, None, None, None, None], + "1": [None, None, None, None, None, None], + } + + +@pytest.mark.parametrize( + ("source", "params"), + [ + ("path_xls", {"engine": "calamine", "sheet_name": "temporal"}), + ("path_xlsx", {"engine": "calamine", "table_name": "TemporalData"}), + ("path_xlsx", {"engine": "openpyxl", "sheet_name": "temporal"}), + ("path_xlsb", {"engine": "calamine", "sheet_name": "temporal"}), + ], +) +def test_read_excel_temporal_data( + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + source_path = request.getfixturevalue(source) + + temporal_schema = { + "id": pl.UInt16(), + "dtm": pl.Datetime("ms"), + "dt": pl.Date(), + "dtm_str": pl.Datetime(time_zone="Asia/Tokyo"), + "dt_str": pl.Date(), + "tm_str": pl.Time(), + } + parsed_df = pl.read_excel( # type: ignore[call-overload] + source_path, + **params, + schema_overrides=temporal_schema, + ) + TK = ZoneInfo("Asia/Tokyo") + + expected = pl.DataFrame( + { + "id": [100, 200, 300, 400], + "dtm": [ + datetime(1999, 12, 31, 1, 2, 3), + None, + datetime(1969, 7, 5, 10, 30, 45), + datetime(2077, 10, 10, 5, 59, 44), + ], + "dt": [ + date(2000, 1, 18), + date(1965, 8, 8), + date(2027, 4, 22), + None, + ], + "dtm_str": [ + None, + datetime(1900, 1, 30, 14, 50, 20, tzinfo=TK), + datetime(2026, 5, 7, 23, 59, 59, tzinfo=TK), + datetime(2007, 6, 1, 0, 0, tzinfo=TK), + ], + "dt_str": [ + date(2000, 6, 14), + date(1978, 2, 28), + None, + date(2040, 12, 4), + ], + "tm_str": [ + time(23, 50, 22), + time(0, 0, 1), + time(10, 10, 33), + time(18, 30, 15), + ], + }, + schema=temporal_schema, + ) + assert_frame_equal(expected, parsed_df) + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + # xls file + (pl.read_excel, "path_xls", {"engine": "calamine"}), + # xlsx file + (pl.read_excel, "path_xlsx", {"engine": "calamine"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + # xlsb file (binary) + (pl.read_excel, "path_xlsb", {"engine": "calamine"}), + # open document + (pl.read_ods, "path_ods", {}), + ], +) +def test_read_invalid_worksheet( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + for param, sheet_id, sheet_name in ( + ("id", 999, None), + ("name", None, "not_a_sheet_name"), + ): + value = sheet_id if param == "id" else sheet_name + with pytest.raises( + ValueError, + match=f"no matching sheet found when `sheet_{param}` is {value!r}", + ): + read_spreadsheet( + spreadsheet_path, sheet_id=sheet_id, sheet_name=sheet_name, **params + ) + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "additional_params"), + [ + (pl.read_excel, "path_xlsx_mixed", {"engine": "openpyxl"}), + (pl.read_ods, "path_ods_mixed", {}), + ], +) +@pytest.mark.may_fail_auto_streaming +def test_read_mixed_dtype_columns( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + additional_params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + schema_overrides = { + "Employee ID": pl.Utf8(), + "Employee Name": pl.Utf8(), + "Date": pl.Date(), + "Details": pl.Categorical("lexical"), + "Asset ID": pl.Utf8(), + } + df = read_spreadsheet( + spreadsheet_path, + sheet_id=0, + schema_overrides=schema_overrides, + **additional_params, + )["Sheet1"] + + assert_frame_equal( + df, + pl.DataFrame( + { + "Employee ID": ["123456", "44333", "US00011", "135967", "IN86868"], + "Employee Name": ["Test1", "Test2", "Test4", "Test5", "Test6"], + "Date": [ + date(2023, 7, 21), + date(2023, 7, 21), + date(2023, 7, 21), + date(2023, 7, 21), + date(2023, 7, 21), + ], + "Details": [ + "Healthcare", + "Healthcare", + "Healthcare", + "Healthcare", + "Something", + ], + "Asset ID": ["84444", "84444", "84444", "84444", "ABC123"], + }, + schema_overrides=schema_overrides, + ), + ) + + +def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> None: + df1 = pl.read_excel( + path_xlsx, + sheet_name="test4", + schema_overrides={"cardinality": pl.UInt16}, + ).drop_nulls() + + assert df1.schema["cardinality"] == pl.UInt16 + assert df1.schema["rows_by_key"] == pl.Float64 + assert df1.schema["iter_groups"] == pl.Float64 + + df2 = pl.read_excel( + path_xlsx, + sheet_name="test4", + engine="xlsx2csv", + read_options={"schema_overrides": {"cardinality": pl.UInt16}}, + ).drop_nulls() + + assert df2.schema["cardinality"] == pl.UInt16 + assert df2.schema["rows_by_key"] == pl.Float64 + assert df2.schema["iter_groups"] == pl.Float64 + + df3 = pl.read_excel( + path_xlsx, + sheet_name="test4", + engine="xlsx2csv", + schema_overrides={"cardinality": pl.UInt16}, + read_options={ + "schema_overrides": { + "rows_by_key": pl.Float32, + "iter_groups": pl.Float32, + }, + }, + ).drop_nulls() + + assert df3.schema["cardinality"] == pl.UInt16 + assert df3.schema["rows_by_key"] == pl.Float32 + assert df3.schema["iter_groups"] == pl.Float32 + + for workbook_path in (path_xlsx, path_xlsb, path_ods): + read_spreadsheet = ( + pl.read_ods if workbook_path.suffix == ".ods" else pl.read_excel + ) + df4 = read_spreadsheet( # type: ignore[operator] + workbook_path, + sheet_name="test5", + schema_overrides={"dtm": pl.Datetime("ns"), "dt": pl.Date}, + ) + assert_frame_equal( + df4, + pl.DataFrame( + { + "dtm": [ + datetime(1999, 12, 31, 10, 30, 45), + datetime(2010, 10, 11, 12, 13, 14), + ], + "dt": [date(2024, 1, 1), date(2018, 8, 7)], + "val": [1.5, -0.5], + }, + schema={"dtm": pl.Datetime("ns"), "dt": pl.Date, "val": pl.Float64}, + ), + ) + + with pytest.raises(ParameterCollisionError): + # cannot specify 'cardinality' in both schema_overrides and read_options + pl.read_excel( + path_xlsx, + sheet_name="test4", + engine="xlsx2csv", + schema_overrides={"cardinality": pl.UInt16}, + read_options={"schema_overrides": {"cardinality": pl.Int32}}, + ) + + # read multiple sheets in conjunction with 'schema_overrides' + # (note: reading the same sheet twice simulates the issue in #11850) + overrides = OrderedDict( + [ + ("cardinality", pl.UInt32), + ("rows_by_key", pl.Float32), + ("iter_groups", pl.Float64), + ] + ) + df = pl.read_excel( # type: ignore[call-overload] + path_xlsx, + sheet_name=["test4", "test4"], + schema_overrides=overrides, + ) + for col, dtype in overrides.items(): + assert df["test4"].schema[col] == dtype + + +@pytest.mark.parametrize( + ("engine", "read_opts_param"), + [ + ("xlsx2csv", "infer_schema_length"), + ("calamine", "schema_sample_rows"), + ], +) +def test_invalid_parameter_combinations_infer_schema_len( + path_xlsx: Path, engine: str, read_opts_param: str +) -> None: + with pytest.raises( + ParameterCollisionError, + match=f"cannot specify both `infer_schema_length`.*{read_opts_param}", + ): + pl.read_excel( # type: ignore[call-overload] + path_xlsx, + sheet_id=1, + engine=engine, + read_options={read_opts_param: 512}, + infer_schema_length=1024, + ) + + +@pytest.mark.parametrize( + ("engine", "read_opts_param"), + [ + ("xlsx2csv", "columns"), + ("calamine", "use_columns"), + ], +) +def test_invalid_parameter_combinations_columns( + path_xlsx: Path, engine: str, read_opts_param: str +) -> None: + with pytest.raises( + ParameterCollisionError, + match=f"cannot specify both `columns`.*{read_opts_param}", + ): + pl.read_excel( # type: ignore[call-overload] + path_xlsx, + sheet_id=1, + engine=engine, + read_options={read_opts_param: ["B", "C", "D"]}, + columns=["A", "B", "C"], + ) + + +def test_unsupported_engine() -> None: + with pytest.raises(NotImplementedError): + pl.read_excel(None, engine="foo") # type: ignore[call-overload] + + +def test_unsupported_binary_workbook(path_xlsb: Path) -> None: + with pytest.raises(Exception, match="does not support binary format"): + pl.read_excel(path_xlsb, engine="openpyxl") + + +@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"]) +def test_read_excel_all_sheets_with_sheet_name(path_xlsx: Path, engine: str) -> None: + with pytest.raises( + ValueError, + match=r"cannot specify both `sheet_name` \('Sheet1'\) and `sheet_id` \(1\)", + ): + pl.read_excel( # type: ignore[call-overload] + path_xlsx, + sheet_id=1, + sheet_name="Sheet1", + engine=engine, + ) + + +# the parameters don't change the data, only the formatting, so we expect +# the same result each time. however, it's important to validate that the +# parameter permutations don't raise exceptions, or interfere with the +# values written to the worksheet, so test multiple variations. +@pytest.mark.parametrize( + "write_params", + [ + # default parameters + {}, + # basic formatting + { + "autofit": True, + "table_style": "Table Style Dark 2", + "column_totals": True, + "float_precision": 0, + }, + # slightly customized formatting, with some formulas + { + "position": (0, 0), + "table_style": { + "style": "Table Style Medium 23", + "first_column": True, + }, + "conditional_formats": {"val": "data_bar"}, + "column_formats": { + "val": "#,##0.000;[White]-#,##0.000", + ("day", "month", "year"): {"align": "left", "num_format": "0"}, + }, + "header_format": {"italic": True, "bg_color": "#d9d9d9"}, + "column_widths": {"val": 100}, + "row_heights": {0: 35}, + "formulas": { + # string: formula added to the end of the table (but before row_totals) + "day": "=DAY([@dtm])", + "month": "=MONTH([@dtm])", + "year": { + # dict: full control over formula positioning/dtype + "formula": "=YEAR([@dtm])", + "insert_after": "month", + "return_type": pl.Int16, + }, + }, + "column_totals": True, + "row_totals": True, + }, + # heavily customized formatting/definition + { + "position": "A1", + "table_name": "PolarsFrameData", + "table_style": "Table Style Light 9", + "conditional_formats": { + # single dict format + "str": { + "type": "duplicate", + "format": {"bg_color": "#ff0000", "font_color": "#ffffff"}, + }, + # multiple dict formats + "val": [ + { + "type": "3_color_scale", + "min_color": "#4bacc6", + "mid_color": "#ffffff", + "max_color": "#daeef3", + }, + { + "type": "cell", + "criteria": "<", + "value": -90, + "format": {"font_color": "white"}, + }, + ], + "dtm": [ + { + "type": "top", + "value": 1, + "format": {"bold": True, "font_color": "green"}, + }, + { + "type": "bottom", + "value": 1, + "format": {"bold": True, "font_color": "red"}, + }, + ], + }, + "dtype_formats": { + frozenset( + FLOAT_DTYPES + ): '_(£* #,##0.00_);_(£* (#,##0.00);_(£* "-"??_);_(@_)', + pl.Date: "dd-mm-yyyy", + }, + "column_formats": {"dtm": {"font_color": "#31869c", "bg_color": "#b7dee8"}}, + "column_totals": {"val": "average", "dtm": "min"}, + "column_widths": {("str", "val"): 60, "dtm": 80}, + "row_totals": {"tot": True}, + "hidden_columns": ["str"], + "hide_gridlines": True, + "include_header": False, + }, + ], +) +def test_excel_round_trip(write_params: dict[str, Any]) -> None: + df = pl.DataFrame( + { + "dtm": [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3)], + "str": ["xxx", "yyy", "xxx"], + "val": [100.5, 55.0, -99.5], + } + ) + + engine: ExcelSpreadsheetEngine + for engine in ("calamine", "xlsx2csv"): + read_options, has_header = ( + ({}, True) + if write_params.get("include_header", True) + else ( + {"new_columns": ["dtm", "str", "val"]} + if engine == "xlsx2csv" + else {"column_names": ["dtm", "str", "val"]}, + False, + ) + ) + + fmt_strptime = "%Y-%m-%d" + if write_params.get("dtype_formats", {}).get(pl.Date) == "dd-mm-yyyy": + fmt_strptime = "%d-%m-%Y" + + # write to xlsx using various parameters... + xls = BytesIO() + _wb = df.write_excel(workbook=xls, worksheet="data", **write_params) + + # ...and read it back again: + xldf = pl.read_excel( + xls, + sheet_name="data", + engine=engine, + read_options=read_options, + has_header=has_header, + )[:3].select(df.columns[:3]) + + if engine == "xlsx2csv": + xldf = xldf.with_columns(pl.col("dtm").str.strptime(pl.Date, fmt_strptime)) + + assert_frame_equal(df, xldf) + + +@pytest.mark.parametrize("engine", ["calamine", "xlsx2csv"]) +def test_excel_write_column_and_row_totals(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame( + { + "id": ["aaa", "bbb", "ccc", "ddd", "eee"], + # float cols + "q1": [100.0, 55.5, -20.0, 0.5, 35.0], + "q2": [30.5, -10.25, 15.0, 60.0, 20.5], + # int cols + "q3": [-50, 0, 40, 80, 80], + "q4": [75, 55, 25, -10, -55], + } + ) + for fn_sum in (True, "sum", "SUM"): + xls = BytesIO() + df.write_excel( + xls, + worksheet="misc", + sparklines={"trend": ["q1", "q2", "q3", "q4"]}, + row_totals={ + # add semiannual row total columns + "h1": ("q1", "q2"), + "h2": ("q3", "q4"), + }, + column_totals=fn_sum, + ) + + # note that the totals are written as formulae, so we + # won't have the calculated values in the dataframe + xldf = pl.read_excel(xls, sheet_name="misc", engine=engine) + + assert xldf.columns == ["id", "q1", "q2", "q3", "q4", "trend", "h1", "h2"] + assert xldf.row(-1) == (None, 0.0, 0.0, 0, 0, None, 0.0, 0) + + +@pytest.mark.parametrize( + ("engine", "list_dtype"), + [ + ("calamine", pl.List(pl.Int8)), + ("openpyxl", pl.List(pl.UInt16)), + ("xlsx2csv", pl.Array(pl.Int32, 3)), + ], +) +def test_excel_write_compound_types( + engine: ExcelSpreadsheetEngine, + list_dtype: PolarsDataType, +) -> None: + df = pl.DataFrame( + data={"x": [None, [1, 2, 3], [4, 5, 6]], "y": ["a", "b", "c"], "z": [9, 8, 7]}, + schema_overrides={"x": pl.Array(pl.Int32, 3)}, + ).select("x", pl.struct(["y", "z"])) + + xls = BytesIO() + df.write_excel(xls, worksheet="data") + + # also test reading from the various flavours of supported binary data + # across all backend engines (check bytesio, bytes, and memoryview) + for binary_data in ( + xls, + xls.getvalue(), + xls.getbuffer(), + ): + xldf = pl.read_excel( + binary_data, + sheet_name="data", + engine=engine, + include_file_paths="wbook", + ) + + # expect string conversion (only scalar values are supported) + assert xldf.rows() == [ + (None, "{'y': 'a', 'z': 9}", "in-mem"), + ("[1, 2, 3]", "{'y': 'b', 'z': 8}", "in-mem"), + ("[4, 5, 6]", "{'y': 'c', 'z': 7}", "in-mem"), + ] + + +def test_excel_read_named_table_with_total_row(tmp_path: Path) -> None: + df = pl.DataFrame( + { + "x": ["aa", "bb", "cc"], + "y": [100, 325, -250], + "z": [975, -444, 123], + } + ) + # when we read back a named table object with a total row we expect the read + # to automatically omit that row as it is *not* part of the actual table data + wb_path = Path(tmp_path).joinpath("test_named_table_read.xlsx") + df.write_excel( + wb_path, + worksheet="data", + table_name="PolarsFrameTable", + column_totals=True, + ) + for engine in ("calamine", "openpyxl"): + xldf = pl.read_excel(wb_path, table_name="PolarsFrameTable", engine=engine) + assert_frame_equal(df, xldf) + + # xlsx2csv doesn't support reading named tables, so we see the + # column total if we don't filter it out after reading the data + with pytest.raises( + ValueError, + match="the `table_name` parameter is not supported by the 'xlsx2csv' engine", + ): + pl.read_excel(wb_path, table_name="PolarsFrameTable", engine="xlsx2csv") + + xldf = pl.read_excel(wb_path, sheet_name="data", engine="xlsx2csv") + assert_frame_equal(df, xldf.head(3)) + assert xldf.height == 4 + assert xldf.row(3) == (None, 0, 0) + + +@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"]) +def test_excel_write_to_bytesio(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame({"colx": [1.5, -2, 0], "coly": ["a", None, "c"]}) + + excel_bytes = BytesIO() + df.write_excel(excel_bytes) + + df_read = pl.read_excel(excel_bytes, engine=engine) + assert_frame_equal(df, df_read) + + # also confirm consistent behaviour when 'infer_schema_length=0' + df_read = pl.read_excel(excel_bytes, engine=engine, infer_schema_length=0) + expected = pl.DataFrame({"colx": ["1.5", "-2", "0"], "coly": ["a", None, "c"]}) + assert_frame_equal(expected, df_read) + + +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) +def test_excel_write_to_file_object( + engine: ExcelSpreadsheetEngine, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": ["aaa", "bbb", "ccc"], "y": [123, 456, 789]}) + + # write to bytesio + xls = BytesIO() + df.write_excel(xls, worksheet="data") + assert_frame_equal(df, pl.read_excel(xls, engine=engine)) + + # write to file path + path = Path(tmp_path).joinpath("test_write_path.xlsx") + df.write_excel(path, worksheet="data") + assert_frame_equal(df, pl.read_excel(xls, engine=engine)) + + # write to file path (as string) + path = Path(tmp_path).joinpath("test_write_path_str.xlsx") + df.write_excel(str(path), worksheet="data") + assert_frame_equal(df, pl.read_excel(xls, engine=engine)) + + # write to file object + path = Path(tmp_path).joinpath("test_write_file_object.xlsx") + with path.open("wb") as tgt: + df.write_excel(tgt, worksheet="data") + with path.open("rb") as src: + assert_frame_equal(df, pl.read_excel(src, engine=engine)) + + +@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"]) +def test_excel_read_no_headers(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame( + {"colx": [1, 2, 3], "coly": ["aaa", "bbb", "ccc"], "colz": [0.5, 0.0, -1.0]} + ) + xls = BytesIO() + df.write_excel(xls, worksheet="data", include_header=False) + + xldf = pl.read_excel(xls, engine=engine, has_header=False) + expected = xldf.rename({"column_1": "colx", "column_2": "coly", "column_3": "colz"}) + assert_frame_equal(df, expected) + + +@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"]) +def test_excel_write_sparklines(engine: ExcelSpreadsheetEngine) -> None: + from xlsxwriter import Workbook + + # note that we don't (quite) expect sparkline export to round-trip as we + # inject additional empty columns to hold them (which will read as nulls) + df = pl.DataFrame( + { + "id": ["aaa", "bbb", "ccc", "ddd", "eee"], + "q1": [100, 55, -20, 0, 35], + "q2": [30, -10, 15, 60, 20], + "q3": [-50, 0, 40, 80, 80], + "q4": [75, 55, 25, -10, -55], + } + ).cast(dtypes={pl.Int64: pl.Float64}) + + # also: confirm that we can use a Workbook directly with "write_excel" + xls = BytesIO() + with Workbook(xls) as wb: + df.write_excel( + workbook=wb, + worksheet="frame_data", + table_style="Table Style Light 2", + dtype_formats={frozenset(NUMERIC_DTYPES): "#,##0_);(#,##0)"}, + column_formats={cs.starts_with("h"): "#,##0_);(#,##0)"}, + sparklines={ + "trend": ["q1", "q2", "q3", "q4"], + "+/-": { + "columns": ["q1", "q2", "q3", "q4"], + "insert_after": "id", + "type": "win_loss", + }, + }, + conditional_formats={ + cs.starts_with("q", "h"): { + "type": "2_color_scale", + "min_color": "#95b3d7", + "max_color": "#ffffff", + } + }, + column_widths={cs.starts_with("q", "h"): 40}, + row_totals={ + "h1": ("q1", "q2"), + "h2": ("q3", "q4"), + }, + hide_gridlines=True, + row_heights=35, + sheet_zoom=125, + ) + + tables = {tbl["name"] for tbl in wb.get_worksheet_by_name("frame_data").tables} + assert "Frame0" in tables + + with warnings.catch_warnings(): + # ignore an openpyxl user warning about sparklines + warnings.simplefilter("ignore", UserWarning) + xldf = pl.read_excel(xls, sheet_name="frame_data", engine=engine) + + # ┌─────┬──────┬─────┬─────┬─────┬─────┬───────┬─────┬─────┐ + # │ id ┆ +/- ┆ q1 ┆ q2 ┆ q3 ┆ q4 ┆ trend ┆ h1 ┆ h2 │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ str ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ str ┆ i64 ┆ i64 │ + # ╞═════╪══════╪═════╪═════╪═════╪═════╪═══════╪═════╪═════╡ + # │ aaa ┆ null ┆ 100 ┆ 30 ┆ -50 ┆ 75 ┆ null ┆ 0 ┆ 0 │ + # │ bbb ┆ null ┆ 55 ┆ -10 ┆ 0 ┆ 55 ┆ null ┆ 0 ┆ 0 │ + # │ ccc ┆ null ┆ -20 ┆ 15 ┆ 40 ┆ 25 ┆ null ┆ 0 ┆ 0 │ + # │ ddd ┆ null ┆ 0 ┆ 60 ┆ 80 ┆ -10 ┆ null ┆ 0 ┆ 0 │ + # │ eee ┆ null ┆ 35 ┆ 20 ┆ 80 ┆ -55 ┆ null ┆ 0 ┆ 0 │ + # └─────┴──────┴─────┴─────┴─────┴─────┴───────┴─────┴─────┘ + + for sparkline_col in ("+/-", "trend"): + assert set(xldf[sparkline_col]) in ({None}, {""}) + + assert xldf.columns == ["id", "+/-", "q1", "q2", "q3", "q4", "trend", "h1", "h2"] + assert_frame_equal( + df, xldf.drop("+/-", "trend", "h1", "h2").cast(dtypes={pl.Int64: pl.Float64}) + ) + + +def test_excel_write_multiple_tables() -> None: + from xlsxwriter import Workbook + + # note: also checks that empty tables don't error on write + df = pl.DataFrame(schema={"colx": pl.Date, "coly": pl.String, "colz": pl.Float64}) + + # write multiple frames to multiple worksheets + xls = BytesIO() + with Workbook(xls) as wb: + df.rename({"colx": "colx0", "coly": "coly0", "colz": "colz0"}).write_excel( + workbook=wb, worksheet="sheet1", position="A1" + ) + df.rename({"colx": "colx1", "coly": "coly1", "colz": "colz1"}).write_excel( + workbook=wb, worksheet="sheet1", position="X10" + ) + df.rename({"colx": "colx2", "coly": "coly2", "colz": "colz2"}).write_excel( + workbook=wb, worksheet="sheet2", position="C25" + ) + + # also validate integration of externally-added formats + fmt = wb.add_format({"bg_color": "#ffff00"}) + df.rename({"colx": "colx3", "coly": "coly3", "colz": "colz3"}).write_excel( + workbook=wb, + worksheet="sheet3", + position="D4", + conditional_formats={ + "colz3": { + "type": "formula", + "criteria": "=C2=B2", + "format": fmt, + } + }, + ) + + table_names = { + tbl["name"] + for sheet in wb.sheetnames + for tbl in wb.get_worksheet_by_name(sheet).tables + } + assert table_names == {f"Frame{n}" for n in range(4)} + assert pl.read_excel(xls, sheet_name="sheet3").rows() == [] + + # test loading one of the written tables by name + for engine in ("calamine", "openpyxl"): + df1 = pl.read_excel( + xls, + sheet_name="sheet2", + table_name="Frame2", + engine=engine, + ) + df2 = pl.read_excel( + xls, + table_name="Frame2", + engine=engine, + ) + assert df1.columns == ["colx2", "coly2", "colz2"] + assert_frame_equal(df1, df2) + + # if we supply a sheet name (which is optional when using `table_name`), + # then the table name must be present in *that* sheet, or we raise an error + with pytest.raises( + RuntimeError, + match="table named 'Frame3' not found in sheet 'sheet1'", + ): + pl.read_excel(xls, sheet_name="sheet1", table_name="Frame3") + + +def test_excel_write_worksheet_object() -> None: + # write to worksheet object + from xlsxwriter import Workbook + + df = pl.DataFrame({"colx": ["aaa", "bbb", "ccc"], "coly": [-1234, 0, 5678]}) + + with Workbook(xls := BytesIO()) as wb: + ws = wb.add_worksheet("frame_data") + df.write_excel(wb, worksheet=ws) + ws.hide_zero() + + assert_frame_equal(df, pl.read_excel(xls, sheet_name="frame_data")) + + with pytest.raises( # noqa: SIM117 + ValueError, + match="the given workbook object .* is not the parent of worksheet 'frame_data'", + ): + with Workbook(BytesIO()) as wb: + df.write_excel(wb, worksheet=ws) + + with pytest.raises( # noqa: SIM117 + TypeError, + match="worksheet object requires the parent workbook object; found workbook=None", + ): + with Workbook(BytesIO()) as wb: + df.write_excel(None, worksheet=ws) + + +def test_excel_write_beyond_max_rows_cols(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "test_max_dimensions.xlsx" + sheet = "mysheet" + + df = pl.DataFrame({"col1": range(10), "col2": range(10, 20)}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.write_excel(workbook=path, worksheet=sheet, position="A1048570") + + +def test_excel_freeze_panes() -> None: + from xlsxwriter import Workbook + + # note: checks that empty tables don't error on write + df1 = pl.DataFrame(schema={"colx": pl.Date, "coly": pl.String, "colz": pl.Float64}) + df2 = pl.DataFrame(schema={"colx": pl.Date, "coly": pl.String, "colz": pl.Float64}) + df3 = pl.DataFrame(schema={"colx": pl.Date, "coly": pl.String, "colz": pl.Float64}) + + xls = BytesIO() + + # use all three freeze_pane notations + with Workbook(xls) as wb: + df1.write_excel(workbook=wb, worksheet="sheet1", freeze_panes=(1, 0)) + df2.write_excel(workbook=wb, worksheet="sheet2", freeze_panes=(1, 0, 3, 4)) + df3.write_excel(workbook=wb, worksheet="sheet3", freeze_panes=("B2")) + + table_names: set[str] = set() + for sheet in ("sheet1", "sheet2", "sheet3"): + table_names.update( + tbl["name"] for tbl in wb.get_worksheet_by_name(sheet).tables + ) + assert table_names == {f"Frame{n}" for n in range(3)} + assert pl.read_excel(xls, sheet_name="sheet3").rows() == [] + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "schema_overrides"), + [ + (pl.read_excel, "path_xlsx_empty", None), + (pl.read_excel, "path_xlsb_empty", None), + (pl.read_excel, "path_xls_empty", None), + (pl.read_ods, "path_ods_empty", None), + # Test with schema overrides, to ensure they don't interfere with + # raising NoDataErrors. + (pl.read_excel, "path_xlsx_empty", {"a": pl.Int64}), + (pl.read_excel, "path_xlsb_empty", {"a": pl.Int64}), + (pl.read_excel, "path_xls_empty", {"a": pl.Int64}), + (pl.read_ods, "path_ods_empty", {"a": pl.Int64}), + ], +) +def test_excel_empty_sheet( + read_spreadsheet: Callable[..., pl.DataFrame], + source: str, + request: pytest.FixtureRequest, + schema_overrides: SchemaDict | None, +) -> None: + ods = (empty_spreadsheet_path := request.getfixturevalue(source)).suffix == ".ods" + read_spreadsheet = pl.read_ods if ods else pl.read_excel # type: ignore[assignment] + + with pytest.raises(NoDataError, match="empty Excel sheet"): + read_spreadsheet(empty_spreadsheet_path, schema_overrides=schema_overrides) + + engine_params = [{}] if ods else [{"engine": "calamine"}] + for params in engine_params: + df = read_spreadsheet( + empty_spreadsheet_path, + sheet_name="no_data", + raise_if_empty=False, + **params, + ) + expected = pl.DataFrame() + assert_frame_equal(df, expected) + + df = read_spreadsheet( + empty_spreadsheet_path, + sheet_name="no_rows", + raise_if_empty=False, + **params, + ) + expected = pl.DataFrame(schema={f"col{c}": pl.String for c in ("x", "y", "z")}) + assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + ("engine", "hidden_columns"), + [ + ("xlsx2csv", ["a"]), + ("openpyxl", ["a", "b"]), + ("calamine", ["a", "b"]), + ("xlsx2csv", cs.numeric()), + ("openpyxl", cs.last()), + ], +) +def test_excel_hidden_columns( + hidden_columns: list[str] | SelectorType, + engine: ExcelSpreadsheetEngine, +) -> None: + df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) + + xls = BytesIO() + df.write_excel(xls, hidden_columns=hidden_columns) + + read_df = pl.read_excel(xls) + assert_frame_equal(df, read_df) + + +def test_excel_mixed_calamine_float_data(io_files_path: Path) -> None: + df = pl.read_excel(io_files_path / "nan_test.xlsx", engine="calamine") + nan = float("nan") + assert_frame_equal( + pl.DataFrame({"float_col": [nan, nan, nan, 100.0, 200.0, 300.0]}), + df, + ) + + +@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"]) +@pytest.mark.may_fail_auto_streaming # read->scan_csv dispatch, _read_spreadsheet_xlsx2csv needs to be changed not to call `_reorder_columns` on the df +def test_excel_type_inference_with_nulls(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame( + { + "a": [1, 2, None], + "b": [1.0, None, 3.5], + "c": ["x", None, "z"], + "d": [True, False, None], + "e": [ + date(2023, 1, 1), + None, + date(2023, 1, 4), + ], + "f": [ + datetime(2023, 1, 1), + datetime(2000, 10, 10, 10, 10), + None, + ], + "g": [ + None, + "1920-08-08 00:00:00", + "2077-10-20 00:00:00.000000", + ], + } + ) + xls = BytesIO() + df.write_excel(xls) + + reversed_cols = list(reversed(df.columns)) + read_cols: Sequence[str] | Sequence[int] + expected = df.select(reversed_cols).with_columns( + pl.col("g").str.slice(0, 10).str.to_date() + ) + for read_cols in ( + reversed_cols, + [6, 5, 4, 3, 2, 1, 0], + ): + read_df = pl.read_excel( + xls, + engine=engine, + columns=read_cols, + schema_overrides={ + "e": pl.Date, + "f": pl.Datetime("us"), + "g": pl.Date, + }, + ) + assert_frame_equal(expected, read_df) + + +@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"]) +def test_drop_empty_rows( + path_empty_rows_excel: Path, engine: ExcelSpreadsheetEngine +) -> None: + df1 = pl.read_excel( + source=path_empty_rows_excel, + engine=engine, + ) # check default + assert df1.shape == (8, 4) + + df2 = pl.read_excel( + source=path_empty_rows_excel, + engine=engine, + drop_empty_rows=True, + ) + assert df2.shape == (8, 4) + + df3 = pl.read_excel( + source=path_empty_rows_excel, + engine=engine, + drop_empty_rows=False, + ) + assert df3.shape == (10, 4) + + +def test_excel_write_select_col_dtype() -> None: + from openpyxl import load_workbook + from xlsxwriter import Workbook + + def get_col_widths(wb_bytes: BytesIO) -> dict[str, int]: + return { + k: round(v.width) + for k, v in load_workbook(wb_bytes).active.column_dimensions.items() + } + + df = pl.DataFrame( + { + "name": [["Alice", "Ben"], ["Charlie", "Delta"]], + "col2": ["Hi", "Bye"], + } + ) + + # column_widths test: + # pl.List(pl.String)) datatype should not match column with no list + check = BytesIO() + with Workbook(check) as wb: + df.write_excel(wb, column_widths={cs.by_dtype(pl.List(pl.String)): 300}) + + assert get_col_widths(check) == {"A": 43} + + # column_widths test: + # pl.String datatype should not match column with list + check = BytesIO() + with Workbook(check) as wb: + df.write_excel(wb, column_widths={cs.by_dtype(pl.String): 300}) + + assert get_col_widths(check) == {"B": 43} + + # hidden_columns test: + # pl.List(pl.String)) datatype should not match column with no list + check = BytesIO() + with Workbook(check) as wb: + df.write_excel(wb, hidden_columns=cs.by_dtype(pl.List(pl.String))) + + assert get_col_widths(check) == {"A": 0} + + # hidden_columns test: + # pl.String datatype should not match column with list + check = BytesIO() + with Workbook(check) as wb: + df.write_excel(wb, hidden_columns=cs.by_dtype(pl.String)) + + assert get_col_widths(check) == {"B": 0} diff --git a/py-polars/tests/unit/io/test_utils.py b/py-polars/tests/unit/io/test_utils.py new file mode 100644 index 000000000000..5a1aad084b71 --- /dev/null +++ b/py-polars/tests/unit/io/test_utils.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.io._utils import looks_like_url, parse_columns_arg, parse_row_index_args + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@pytest.mark.parametrize( + ("columns", "expected"), + [ + (["a", "b"], (None, ["a", "b"])), + ((1, 2), ((1, 2), None)), + ("foo", (None, ["foo"])), + (3, ([3], None)), + (None, (None, None)), + ], +) +def test_parse_columns_arg( + columns: Sequence[str] | Sequence[int] | str | int | None, + expected: tuple[Sequence[int] | None, Sequence[str] | None], +) -> None: + assert parse_columns_arg(columns) == expected + + +def test_parse_columns_arg_mixed_types() -> None: + with pytest.raises(TypeError): + parse_columns_arg(["a", 1]) # type: ignore[arg-type] + + +@pytest.mark.parametrize("columns", [["a", "a"], [1, 1, 2]]) +def test_parse_columns_arg_duplicates(columns: Sequence[str] | Sequence[int]) -> None: + with pytest.raises(ValueError): + parse_columns_arg(columns) + + +def test_parse_row_index_args() -> None: + assert parse_row_index_args("idx", 5) == ("idx", 5) + assert parse_row_index_args(None, 5) is None + + +@pytest.mark.parametrize( + ("url", "result"), + [ + ("HTTPS://pola.rs/data.csv", True), + ("http://pola.rs/data.csv", True), + ("ftps://pola.rs/data.csv", True), + ("FTP://pola.rs/data.csv", True), + ("htp://pola.rs/data.csv", False), + ("fttp://pola.rs/data.csv", False), + ("http_not_a_url", False), + ("ftp_not_a_url", False), + ("/mnt/data.csv", False), + ("file://mnt/data.csv", False), + ], +) +def test_looks_like_url(url: str, result: bool) -> None: + assert looks_like_url(url) == result + + +@pytest.mark.parametrize( + "scan", [pl.scan_csv, pl.scan_parquet, pl.scan_ndjson, pl.scan_ipc] +) +def test_filename_in_err(scan: Any) -> None: + with pytest.raises(FileNotFoundError, match=r".*does not exist"): + scan("does not exist").collect() diff --git a/py-polars/tests/unit/io/test_write.py b/py-polars/tests/unit/io/test_write.py new file mode 100644 index 000000000000..e49577e08291 --- /dev/null +++ b/py-polars/tests/unit/io/test_write.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Callable + +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + +READ_WRITE_FUNC_PARAM = [ + (pl.read_parquet, pl.DataFrame.write_parquet), + (lambda *a: pl.scan_csv(*a).collect(), pl.DataFrame.write_csv), + (lambda *a: pl.scan_ipc(*a).collect(), pl.DataFrame.write_ipc), + # Sink + (pl.read_parquet, lambda df, path: pl.DataFrame.lazy(df).sink_parquet(path)), + ( + lambda *a: pl.scan_csv(*a).collect(), + lambda df, path: pl.DataFrame.lazy(df).sink_csv(path), + ), + ( + lambda *a: pl.scan_ipc(*a).collect(), + lambda df, path: pl.DataFrame.lazy(df).sink_ipc(path), + ), + ( + lambda *a: pl.scan_ndjson(*a).collect(), + lambda df, path: pl.DataFrame.lazy(df).sink_ndjson(path), + ), +] + + +@pytest.mark.parametrize( + ("read_func", "write_func"), + READ_WRITE_FUNC_PARAM, +) +@pytest.mark.write_disk +def test_write_async( + read_func: Callable[[Path], pl.DataFrame], + write_func: Callable[[pl.DataFrame, Path], None], + tmp_path: Path, +) -> None: + tmp_path.mkdir(exist_ok=True) + path = (tmp_path / "1").absolute() + path = f"file://{path}" # type: ignore[assignment] + + df = pl.DataFrame({"x": 1}) + + write_func(df, path) + + assert_frame_equal(read_func(path), df) + + +@pytest.mark.parametrize( + ("read_func", "write_func"), + READ_WRITE_FUNC_PARAM, +) +@pytest.mark.parametrize("opt_absolute_fn", [Path, Path.absolute]) +@pytest.mark.write_disk +def test_write_async_force_async( + read_func: Callable[[Path], pl.DataFrame], + write_func: Callable[[pl.DataFrame, Path], None], + opt_absolute_fn: Callable[[Path], Path], + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + tmp_path.mkdir(exist_ok=True) + path = opt_absolute_fn(tmp_path / "1") + + df = pl.DataFrame({"x": 1}) + + write_func(df, path) + + assert_frame_equal(read_func(path), df) diff --git a/py-polars/tests/unit/lazyframe/__init__.py b/py-polars/tests/unit/lazyframe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py b/py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py new file mode 100644 index 000000000000..8c9df01c71e4 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/cuda/test_node_visitor.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import time +from functools import lru_cache, partial +from typing import TYPE_CHECKING, Any, Callable + +import polars as pl +from polars._utils.wrap import wrap_df +from polars.polars import _ir_nodes + +if TYPE_CHECKING: + import pandas as pd + + +class Timer: + """Simple-minded timing of nodes.""" + + def __init__(self, start: int | None) -> None: + self.start = start + self.timings: list[tuple[int, int, str]] = [] + + def record(self, fn: Callable[[], pd.DataFrame], name: str) -> pd.DataFrame: + start = time.monotonic_ns() + result = fn() + end = time.monotonic_ns() + if self.start is not None: + self.timings.append((start - self.start, end - self.start, name)) + return result + + +def test_run_on_pandas() -> None: + # Simple join example, missing multiple columns, slices, etc. + def join( + inputs: list[Callable[[], pd.DataFrame]], + obj: Any, + _node_traverser: Any, + timer: Timer, + ) -> Callable[[], pd.DataFrame]: + assert len(obj.left_on) == 1 + assert len(obj.right_on) == 1 + left_on = obj.left_on[0].output_name + right_on = obj.right_on[0].output_name + + assert len(inputs) == 2 + + def run(inputs: list[Callable[[], pd.DataFrame]]) -> pd.DataFrame: + # materialize inputs + dataframes = [call() for call in inputs] + return timer.record( + lambda: dataframes[0].merge( + dataframes[1], left_on=left_on, right_on=right_on + ), + "pandas-join", + ) + + return partial(run, inputs) + + # Simple scan example, missing predicates, columns pruning, slices, etc. + def df_scan( + _inputs: None, obj: Any, _: Any, timer: Timer + ) -> Callable[[], pd.DataFrame]: + assert obj.selection is None + return lambda: timer.record(lambda: wrap_df(obj.df).to_pandas(), "pandas-scan") + + @lru_cache(1) + def get_node_converters() -> dict[ + type, Callable[[Any, Any, Any, Timer], Callable[[], pd.DataFrame]] + ]: + return { + _ir_nodes.Join: join, + _ir_nodes.DataFrameScan: df_scan, + } + + def get_input(node_traverser: Any, *, timer: Timer) -> Callable[[], pd.DataFrame]: + current_node = node_traverser.get_node() + + inputs_callable = [] + for inp in node_traverser.get_inputs(): + node_traverser.set_node(inp) + inputs_callable.append(get_input(node_traverser, timer=timer)) + + node_traverser.set_node(current_node) + ir_node = node_traverser.view_current_node() + return get_node_converters()[ir_node.__class__]( + inputs_callable, ir_node, node_traverser, timer + ) + + def run_on_pandas(node_traverser: Any, query_start: int | None) -> None: + timer = Timer( + time.monotonic_ns() - query_start if query_start is not None else None + ) + current_node = node_traverser.get_node() + + callback = get_input(node_traverser, timer=timer) + + def run_callback( + columns: list[str] | None, + _: Any, + n_rows: int | None, + should_time: bool, + ) -> pl.DataFrame | tuple[pl.DataFrame, list[tuple[int, int, str]]]: + assert n_rows is None + assert columns is None + + # produce a wrong result to ensure the callback has run. + result = pl.from_pandas(callback() * 2) + if should_time: + return result, timer.timings + else: + return result + + node_traverser.set_node(current_node) + node_traverser.set_udf(run_callback) + + # Polars query that will run on pandas + q1 = pl.LazyFrame({"foo": [1, 2, 3]}) + q2 = pl.LazyFrame({"foo": [1], "bar": [2]}) + q = q1.join(q2, on="foo") + assert q.collect( + post_opt_callback=run_on_pandas # type: ignore[call-overload] + ).to_dict(as_series=False) == { + "foo": [2], + "bar": [4], + } + + result, timings = q.profile(post_opt_callback=run_on_pandas) + assert result.to_dict(as_series=False) == { + "foo": [2], + "bar": [4], + } + assert timings["node"].to_list() == [ + "optimization", + "pandas-scan", + "pandas-scan", + "pandas-join", + ] diff --git a/py-polars/tests/unit/lazyframe/optimizations.py b/py-polars/tests/unit/lazyframe/optimizations.py new file mode 100644 index 000000000000..8b33cd6a8967 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/optimizations.py @@ -0,0 +1,15 @@ +import io + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_fast_count_alias_18581() -> None: + f = io.BytesIO() + f.write(b"a,b,c\n1,2,3\n4,5,6") + f.flush() + f.seek(0) + + df = pl.scan_csv(f).select(pl.len().alias("weird_name")).collect() + + assert_frame_equal(pl.DataFrame({"weird_name": 2}), df) diff --git a/py-polars/tests/unit/lazyframe/test_collect_all.py b/py-polars/tests/unit/lazyframe/test_collect_all.py new file mode 100644 index 000000000000..051a211ab7f7 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_collect_all.py @@ -0,0 +1,20 @@ +from typing import cast + +import pytest + +import polars as pl + + +def test_collect_all_type_coercion_21805() -> None: + df = pl.LazyFrame({"A": [1.0, 2.0]}) + df = df.with_columns(pl.col("A").shift().fill_null(2)) + assert pl.collect_all([df])[0]["A"].to_list() == [2.0, 1.0] + + +@pytest.mark.parametrize("no_optimization", [False, True]) +def test_collect_all(df: pl.DataFrame, no_optimization: bool) -> None: + lf1 = df.lazy().select(pl.col("int").sum()) + lf2 = df.lazy().select((pl.col("floats") * 2).sum()) + out = pl.collect_all([lf1, lf2], no_optimization=no_optimization) + assert cast(int, out[0].item()) == 6 + assert cast(float, out[1].item()) == 12.0 diff --git a/py-polars/tests/unit/lazyframe/test_collect_schema.py b/py-polars/tests/unit/lazyframe/test_collect_schema.py new file mode 100644 index 000000000000..2cc4222ccc00 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_collect_schema.py @@ -0,0 +1,22 @@ +from hypothesis import given + +import polars as pl +from polars.testing.parametric import dataframes + + +@given(lf=dataframes(lazy=True)) +def test_collect_schema_parametric(lf: pl.LazyFrame) -> None: + assert lf.collect_schema() == lf.collect().schema + + +def test_collect_schema() -> None: + lf = pl.LazyFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + result = lf.collect_schema() + expected = pl.Schema({"foo": pl.Int64(), "bar": pl.Float64(), "ham": pl.String()}) + assert result == expected diff --git a/py-polars/tests/unit/lazyframe/test_engine_selection.py b/py-polars/tests/unit/lazyframe/test_engine_selection.py new file mode 100644 index 000000000000..9d3f233438ba --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_engine_selection.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import EngineType + + +@pytest.fixture +def df() -> pl.LazyFrame: + return pl.LazyFrame({"a": [1, 2, 3]}) + + +@pytest.fixture(params=["gpu", pl.GPUEngine()]) +def engine(request: pytest.FixtureRequest) -> EngineType: + value: EngineType = request.param + return value + + +def test_engine_selection_invalid_raises(df: pl.LazyFrame) -> None: + with pytest.raises(ValueError): + df.collect(engine="unknown") # type: ignore[call-overload] + + +def test_engine_selection_background_warns( + df: pl.LazyFrame, engine: EngineType +) -> None: + expect = df.collect() + with pytest.warns( + UserWarning, match="GPU engine does not support streaming or background" + ): + got = df.collect(engine=engine, background=True) + assert_frame_equal(expect, got.fetch_blocking()) + + +def test_engine_selection_eager_quiet(df: pl.LazyFrame, engine: EngineType) -> None: + expect = df.collect() + # _eager collection turns off GPU engine quietly + got = df.collect(engine=engine, _eager=True) + assert_frame_equal(expect, got) + + +def test_engine_import_error_raises(df: pl.LazyFrame, engine: EngineType) -> None: + with pytest.raises(ImportError, match="GPU engine requested"): + df.collect(engine=engine) diff --git a/py-polars/tests/unit/lazyframe/test_explain.py b/py-polars/tests/unit/lazyframe/test_explain.py new file mode 100644 index 000000000000..806a7c6cad67 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_explain.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_lf_explain_format_tree() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + plan = lf.select("a").select(pl.col("a").sum() + pl.len()) + + result = plan.explain(format="tree") + print(result) + + expected = """\ + 0 1 + ┌───────────────────────────────────────────────────────── + │ + │ ╭────────╮ + 0 │ │ SELECT │ + │ ╰───┬┬───╯ + │ ││ + │ │╰───────────────────────────╮ + │ │ │ + │ ╭─────────┴──────────╮ │ + │ │ expression: │ ╭──────────────┴──────────────╮ + │ │ [(col("a") │ │ FROM: │ + 1 │ │ .sum()) + (len() │ │ DF ["a", "b"] │ + │ │ .cast(Int64))] │ │ PROJECT: ["a"]; 1/2 COLUMNS │ + │ ╰────────────────────╯ ╰─────────────────────────────╯ +\ +""" + assert result == expected + + +def test_lf_explain_tree_format_deprecated() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + + with pytest.deprecated_call(): + lf.explain(tree_format=True) diff --git a/py-polars/tests/unit/lazyframe/test_getitem.py b/py-polars/tests/unit/lazyframe/test_getitem.py new file mode 100644 index 000000000000..f7565e1fa3d7 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_getitem.py @@ -0,0 +1,66 @@ +# ---------------------------------------------------- +# Validate LazyFrame behaviour with parametric tests +# ---------------------------------------------------- +import hypothesis.strategies as st +from hypothesis import example, given + +import polars as pl +from polars.testing.parametric import column, dataframes + + +@given( + ldf=dataframes( + max_size=10, + lazy=True, + cols=[ + column( + "start", + dtype=pl.Int8, + allow_null=True, + strategy=st.integers(min_value=-3, max_value=4), + ), + column( + "stop", + dtype=pl.Int8, + allow_null=True, + strategy=st.integers(min_value=-2, max_value=6), + ), + column( + "step", + dtype=pl.Int8, + allow_null=True, + strategy=st.integers(min_value=-3, max_value=3).filter( + lambda x: x != 0 + ), + ), + column("misc", dtype=pl.Int32), + ], + ) +) +@example( + ldf=pl.LazyFrame( + { + "start": [-1, None, 1, None, 1, -1], + "stop": [None, 0, -1, -1, 2, 1], + "step": [-1, -1, 1, None, -1, 1], + "misc": [1, 2, 3, 4, 5, 6], + } + ) +) +def test_lazyframe_getitem(ldf: pl.LazyFrame) -> None: + py_data = ldf.collect().rows() + + for start, stop, step, _ in py_data: + s = slice(start, stop, step) + sliced_py_data = py_data[s] + try: + sliced_df_data = ldf[s].collect().rows() + assert sliced_py_data == sliced_df_data, ( + f"slice [{start}:{stop}:{step}] failed on lazy df w/len={len(py_data)}" + ) + + except ValueError as exc: + # test params will trigger some known + # unsupported cases; filter them here. + if "not supported" not in str(exc): + raise diff --git a/py-polars/tests/unit/lazyframe/test_lazyframe.py b/py-polars/tests/unit/lazyframe/test_lazyframe.py new file mode 100644 index 000000000000..31bfd847557f --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_lazyframe.py @@ -0,0 +1,1493 @@ +from __future__ import annotations + +from datetime import date, datetime +from functools import reduce +from inspect import signature +from operator import add +from string import ascii_letters +from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast + +import numpy as np +import pytest + +import polars as pl +import polars.selectors as cs +from polars import lit, when +from polars.exceptions import ( + InvalidOperationError, + PerformanceWarning, + PolarsInefficientMapWarning, +) +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES, NUMERIC_DTYPES + +if TYPE_CHECKING: + from _pytest.capture import CaptureFixture + + from polars._typing import PolarsDataType + + +def test_init_signature_match() -> None: + # eager/lazy init signatures are expected to match; if this test fails, it + # means a parameter was added to one but not the other, and that should be + # fixed (or an explicit exemption should be made here, with an explanation) + assert signature(pl.DataFrame.__init__) == signature(pl.LazyFrame.__init__) + + +def test_lazy_misc() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + _ = ldf.with_columns(pl.lit(1).alias("foo")).select([pl.col("a"), pl.col("foo")]) + + # test if it executes + _ = ldf.with_columns( + when(pl.col("a") > pl.lit(2)).then(pl.lit(10)).otherwise(pl.lit(1)).alias("new") + ).collect() + + +def test_implode() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + eager = ( + ldf.group_by(pl.col("a").alias("grp"), maintain_order=True) + .agg(pl.implode("a", "b").name.suffix("_imp")) + .collect() + ) + assert_frame_equal( + eager, + pl.DataFrame( + { + "grp": [1, 2, 3], + "a_imp": [[1], [2], [3]], + "b_imp": [[1.0], [2.0], [3.0]], + } + ), + ) + + +def test_lazyframe_membership_operator() -> None: + ldf = pl.LazyFrame({"name": ["Jane", "John"], "age": [20, 30]}) + assert "name" in ldf + assert "phone" not in ldf + + # note: cannot use lazyframe in boolean context + with pytest.raises(TypeError, match="ambiguous"): + not ldf + + +def test_apply() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + new = ldf.with_columns_seq( + pl.col("a").map_batches(lambda s: s * 2, return_dtype=pl.Int64).alias("foo") + ) + expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo")) + assert_frame_equal(new, expected) + assert_frame_equal(new.collect(), expected.collect()) + + with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"): + for strategy in ["thread_local", "threading"]: + ldf = pl.LazyFrame({"a": [1, 2, 3] * 20, "b": [1.0, 2.0, 3.0] * 20}) + new = ldf.with_columns( + pl.col("a") + .map_elements(lambda s: s * 2, strategy=strategy, return_dtype=pl.Int64) # type: ignore[arg-type] + .alias("foo") + ) + expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo")) + assert_frame_equal(new.collect(), expected.collect()) + + +def test_add_eager_column() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + assert lf.collect_schema().len() == 2 + + out = lf.with_columns(pl.lit(pl.Series("c", [1, 2, 3]))).collect() + assert out["c"].sum() == 6 + assert out.collect_schema().len() == 3 + + +def test_set_null() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + out = ldf.with_columns( + when(pl.col("a") > 1).then(lit(None)).otherwise(100).alias("foo") + ).collect() + s = out["foo"] + assert s[0] == 100 + assert s[1] is None + assert s[2] is None + + +def test_gather_every() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 4], "b": ["w", "x", "y", "z"]}) + expected_df = pl.DataFrame({"a": [1, 3], "b": ["w", "y"]}) + assert_frame_equal(expected_df, ldf.gather_every(2).collect()) + expected_df = pl.DataFrame({"a": [2, 4], "b": ["x", "z"]}) + assert_frame_equal(expected_df, ldf.gather_every(2, offset=1).collect()) + + +def test_agg() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + ldf = df.lazy().min() + res = ldf.collect() + assert res.shape == (1, 2) + assert res.row(0) == (1, 1.0) + + +def test_count_suffix_10783() -> None: + df = pl.DataFrame( + { + "a": [["a", "c", "b"], ["a", "b", "c"], ["a", "d", "c"], ["c", "a", "b"]], + "b": [["a", "c", "b"], ["a", "b", "c"], ["a", "d", "c"], ["c", "a", "b"]], + } + ) + df_with_cnt = df.with_columns( + pl.len() + .over(pl.col("a").list.sort().list.join("").hash()) + .name.suffix("_suffix") + ) + df_expect = df.with_columns(pl.Series("len_suffix", [3, 3, 1, 3])) + assert_frame_equal(df_with_cnt, df_expect, check_dtypes=False) + + +def test_or() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + out = ldf.filter((pl.col("a") == 1) | (pl.col("b") > 2)).collect() + assert out.rows() == [(1, 1.0), (3, 3.0)] + + +def test_filter_str() -> None: + # use a str instead of a column expr + ldf = pl.LazyFrame( + { + "time": ["11:11:00", "11:12:00", "11:13:00", "11:14:00"], + "bools": [True, False, True, False], + } + ) + + # last row based on a filter + result = ldf.filter(pl.col("bools")).select_seq(pl.last("*")).collect() + expected = pl.DataFrame({"time": ["11:13:00"], "bools": [True]}) + assert_frame_equal(result, expected) + + # last row based on a filter + result = ldf.filter("bools").select(pl.last("*")).collect() + assert_frame_equal(result, expected) + + +def test_filter_multiple_predicates() -> None: + ldf = pl.LazyFrame( + { + "a": [1, 1, 1, 2, 2], + "b": [1, 1, 2, 2, 2], + "c": [1, 1, 2, 3, 4], + } + ) + + # multiple predicates + expected = pl.DataFrame({"a": [1, 1, 1], "b": [1, 1, 2], "c": [1, 1, 2]}) + for out in ( + ldf.filter(pl.col("a") == 1, pl.col("b") <= 2), # positional/splat + ldf.filter([pl.col("a") == 1, pl.col("b") <= 2]), # as list + ): + assert_frame_equal(out.collect(), expected) + + # multiple kwargs + assert_frame_equal( + ldf.filter(a=1, b=2).collect(), + pl.DataFrame({"a": [1], "b": [2], "c": [2]}), + ) + + # both positional and keyword args + assert_frame_equal( + ldf.filter(pl.col("c") < 4, a=2, b=2).collect(), + pl.DataFrame({"a": [2], "b": [2], "c": [3]}), + ) + + ldf = pl.LazyFrame( + { + "description": ["eq", "gt", "ge"], + "predicate": ["==", ">", ">="], + }, + ) + assert ldf.filter(predicate="==").select("description").collect().item() == "eq" + + +@pytest.mark.parametrize( + "predicate", + [ + [pl.lit(True)], + iter([pl.lit(True)]), + [True, True, True], + iter([True, True, True]), + (p for p in (pl.col("c") < 9,)), + (p for p in (pl.col("a") > 0, pl.col("b") > 0)), + ], +) +def test_filter_seq_iterable_all_true(predicate: Any) -> None: + ldf = pl.LazyFrame( + { + "a": [1, 1, 1], + "b": [1, 1, 2], + "c": [3, 1, 2], + } + ) + assert_frame_equal(ldf, ldf.filter(predicate)) + + +def test_apply_custom_function() -> None: + ldf = pl.LazyFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + } + ) + + # two ways to determine the length groups. + df = ( + ldf.group_by("fruits") + .agg( + [ + pl.col("cars") + .map_elements(lambda groups: groups.len(), return_dtype=pl.Int64) + .alias("custom_1"), + pl.col("cars") + .map_elements(lambda groups: groups.len(), return_dtype=pl.Int64) + .alias("custom_2"), + pl.count("cars").alias("cars_count"), + ] + ) + .sort("custom_1", descending=True) + ).collect() + + expected = pl.DataFrame( + { + "fruits": ["banana", "apple"], + "custom_1": [3, 2], + "custom_2": [3, 2], + "cars_count": [3, 2], + } + ) + expected = expected.with_columns(pl.col("cars_count").cast(pl.UInt32)) + assert_frame_equal(df, expected) + + +def test_group_by() -> None: + ldf = pl.LazyFrame( + { + "a": [1.0, None, 3.0, 4.0], + "b": [5.0, 2.5, -3.0, 2.0], + "grp": ["a", "a", "b", "b"], + } + ) + expected_a = pl.DataFrame({"grp": ["a", "b"], "a": [1.0, 3.5]}) + expected_a_b = pl.DataFrame({"grp": ["a", "b"], "a": [1.0, 3.5], "b": [3.75, -0.5]}) + + for out in ( + ldf.group_by("grp").agg(pl.mean("a")).collect(), + ldf.group_by(pl.col("grp")).agg(pl.mean("a")).collect(), + ): + assert_frame_equal(out.sort(by="grp"), expected_a) + + out = ldf.group_by("grp").agg(pl.mean("a", "b")).collect() + assert_frame_equal(out.sort(by="grp"), expected_a_b) + + +def test_arg_unique() -> None: + ldf = pl.LazyFrame({"a": [4, 1, 4]}) + col_a_unique = ldf.select(pl.col("a").arg_unique()).collect()["a"] + assert_series_equal(col_a_unique, pl.Series("a", [0, 1]).cast(pl.UInt32)) + + +def test_arg_sort() -> None: + ldf = pl.LazyFrame({"a": [4, 1, 3]}).select(pl.col("a").arg_sort()) + assert ldf.collect()["a"].to_list() == [1, 2, 0] + + +def test_window_function() -> None: + lf = pl.LazyFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + } + ) + assert lf.collect_schema().len() == 4 + + q = lf.with_columns( + pl.sum("A").over("fruits").alias("fruit_sum_A"), + pl.first("B").over("fruits").alias("fruit_first_B"), + pl.max("B").over("cars").alias("cars_max_B"), + ) + assert q.collect_schema().len() == 7 + + assert q.collect()["cars_max_B"].to_list() == [5, 4, 5, 5, 5] + + out = lf.select([pl.first("B").over(["fruits", "cars"]).alias("B_first")]) + assert out.collect()["B_first"].to_list() == [5, 4, 3, 3, 5] + + +def test_when_then_flatten() -> None: + ldf = pl.LazyFrame({"foo": [1, 2, 3], "bar": [3, 4, 5]}) + + assert ldf.select( + when(pl.col("foo") > 1) + .then(pl.col("bar")) + .when(pl.col("bar") < 3) + .then(10) + .otherwise(30) + ).collect()["bar"].to_list() == [30, 4, 5] + + +def test_describe_plan() -> None: + assert isinstance(pl.LazyFrame({"a": [1]}).explain(optimized=True), str) + assert isinstance(pl.LazyFrame({"a": [1]}).explain(optimized=False), str) + + +def test_inspect(capsys: CaptureFixture[str]) -> None: + ldf = pl.LazyFrame({"a": [1]}) + ldf.inspect().collect() + captured = capsys.readouterr() + assert len(captured.out) > 0 + + ldf.select(pl.col("a").cum_sum().inspect().alias("bar")).collect() + res = capsys.readouterr() + assert len(res.out) > 0 + + +@pytest.mark.may_fail_auto_streaming +def test_fetch(fruits_cars: pl.DataFrame) -> None: + res = fruits_cars.lazy().select("*")._fetch(2) + assert_frame_equal(res, res[:2]) + + +def test_fold_filter() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + + out = lf.filter( + pl.fold( + acc=pl.lit(True), + function=lambda a, b: a & b, + exprs=[pl.col(c) > 1 for c in lf.collect_schema()], + ) + ).collect() + + assert out.shape == (1, 2) + assert out.rows() == [(3, 2)] + + out = lf.filter( + pl.fold( + acc=pl.lit(True), + function=lambda a, b: a | b, + exprs=[pl.col(c) > 1 for c in lf.collect_schema()], + ) + ).collect() + + assert out.rows() == [(1, 0), (2, 1), (3, 2)] + + +def test_head_group_by() -> None: + commodity_prices = { + "commodity": [ + "Wheat", + "Wheat", + "Wheat", + "Wheat", + "Corn", + "Corn", + "Corn", + "Corn", + "Corn", + ], + "location": [ + "StPaul", + "StPaul", + "StPaul", + "Chicago", + "Chicago", + "Chicago", + "Chicago", + "Chicago", + "Chicago", + ], + "seller": [ + "Bob", + "Charlie", + "Susan", + "Paul", + "Ed", + "Mary", + "Paul", + "Charlie", + "Norman", + ], + "price": [1.0, 0.7, 0.8, 0.55, 2.0, 3.0, 2.4, 1.8, 2.1], + } + ldf = pl.LazyFrame(commodity_prices) + + # this query flexes the wildcard exclusion quite a bit. + keys = ["commodity", "location"] + out = ( + ldf.sort(by="price", descending=True) + .group_by(keys, maintain_order=True) + .agg([pl.col("*").exclude(keys).head(2).name.keep()]) + .explode(pl.col("*").exclude(keys)) + ) + + assert out.collect().rows() == [ + ("Corn", "Chicago", "Mary", 3.0), + ("Corn", "Chicago", "Paul", 2.4), + ("Wheat", "StPaul", "Bob", 1.0), + ("Wheat", "StPaul", "Susan", 0.8), + ("Wheat", "Chicago", "Paul", 0.55), + ] + + ldf = pl.LazyFrame( + {"letters": ["c", "c", "a", "c", "a", "b"], "nrs": [1, 2, 3, 4, 5, 6]} + ) + out = ldf.group_by("letters").tail(2).sort("letters") + assert_frame_equal( + out.collect(), + pl.DataFrame({"letters": ["a", "a", "b", "c", "c"], "nrs": [3, 5, 6, 2, 4]}), + ) + out = ldf.group_by("letters").head(2).sort("letters") + assert_frame_equal( + out.collect(), + pl.DataFrame({"letters": ["a", "a", "b", "c", "c"], "nrs": [3, 5, 6, 1, 2]}), + ) + + +def test_is_null_is_not_null() -> None: + ldf = pl.LazyFrame({"nrs": [1, 2, None]}).select( + pl.col("nrs").is_null().alias("is_null"), + pl.col("nrs").is_not_null().alias("not_null"), + ) + assert ldf.collect()["is_null"].to_list() == [False, False, True] + assert ldf.collect()["not_null"].to_list() == [True, True, False] + + +def test_is_nan_is_not_nan() -> None: + ldf = pl.LazyFrame({"nrs": np.array([1, 2, np.nan])}).select( + pl.col("nrs").is_nan().alias("is_nan"), + pl.col("nrs").is_not_nan().alias("not_nan"), + ) + assert ldf.collect()["is_nan"].to_list() == [False, False, True] + assert ldf.collect()["not_nan"].to_list() == [True, True, False] + + +def test_is_finite_is_infinite() -> None: + ldf = pl.LazyFrame({"nrs": np.array([1, 2, np.inf])}).select( + pl.col("nrs").is_infinite().alias("is_inf"), + pl.col("nrs").is_finite().alias("not_inf"), + ) + assert ldf.collect()["is_inf"].to_list() == [False, False, True] + assert ldf.collect()["not_inf"].to_list() == [True, True, False] + + +def test_len() -> None: + ldf = pl.LazyFrame({"nrs": [1, 2, 3]}) + assert cast(int, ldf.select(pl.col("nrs").len()).collect().item()) == 3 + + +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_cum_agg(dtype: PolarsDataType) -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 2]}, schema={"a": dtype}) + assert_series_equal( + ldf.select(pl.col("a").cum_min()).collect()["a"], + pl.Series("a", [1, 1, 1, 1], dtype=dtype), + ) + assert_series_equal( + ldf.select(pl.col("a").cum_max()).collect()["a"], + pl.Series("a", [1, 2, 3, 3], dtype=dtype), + ) + + expected_dtype = ( + pl.Int64 if dtype in [pl.Int8, pl.Int16, pl.UInt8, pl.UInt16] else dtype + ) + assert_series_equal( + ldf.select(pl.col("a").cum_sum()).collect()["a"], + pl.Series("a", [1, 3, 6, 8], dtype=expected_dtype), + ) + + expected_dtype = ( + pl.Int64 + if dtype in [pl.Int8, pl.Int16, pl.Int32, pl.UInt8, pl.UInt16, pl.UInt32] + else dtype + ) + assert_series_equal( + ldf.select(pl.col("a").cum_prod()).collect()["a"], + pl.Series("a", [1, 2, 6, 12], dtype=expected_dtype), + ) + + +def test_ceil() -> None: + ldf = pl.LazyFrame({"a": [1.8, 1.2, 3.0]}) + result = ldf.select(pl.col("a").ceil()).collect() + assert_frame_equal(result, pl.DataFrame({"a": [2.0, 2.0, 3.0]})) + + ldf = pl.LazyFrame({"a": [1, 2, 3]}) + result = ldf.select(pl.col("a").ceil()).collect() + assert_frame_equal(ldf.collect(), result) + + +def test_floor() -> None: + ldf = pl.LazyFrame({"a": [1.8, 1.2, 3.0]}) + result = ldf.select(pl.col("a").floor()).collect() + assert_frame_equal(result, pl.DataFrame({"a": [1.0, 1.0, 3.0]})) + + ldf = pl.LazyFrame({"a": [1, 2, 3]}) + result = ldf.select(pl.col("a").floor()).collect() + assert_frame_equal(ldf.collect(), result) + + +@pytest.mark.parametrize( + ("n", "ndigits", "expected"), + [ + (1.005, 2, 1.0), + (1234.00000254495, 10, 1234.000002545), + (1835.665, 2, 1835.67), + (-1835.665, 2, -1835.67), + (1.27499, 2, 1.27), + (123.45678, 2, 123.46), + (1254, 2, 1254.0), + (1254, 0, 1254.0), + (123.55, 0, 124.0), + (123.55, 1, 123.6), + (-1.23456789, 6, -1.234568), + (1.0e-5, 5, 0.00001), + (1.0e-20, 20, 1e-20), + (1.0e20, 2, 100000000000000000000.0), + ], +) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_round(n: float, ndigits: int, expected: float, dtype: pl.DataType) -> None: + ldf = pl.LazyFrame({"value": [n]}, schema_overrides={"value": dtype}) + assert_series_equal( + ldf.select(pl.col("value").round(decimals=ndigits)).collect().to_series(), + pl.Series("value", [expected], dtype=dtype), + ) + + +def test_dot() -> None: + ldf = pl.LazyFrame({"a": [1.8, 1.2, 3.0], "b": [3.2, 1, 2]}).select( + pl.col("a").dot(pl.col("b")) + ) + assert cast(float, ldf.collect().item()) == 12.96 + + +def test_sort() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 2]}).select(pl.col("a").sort()) + assert_series_equal(ldf.collect()["a"], pl.Series("a", [1, 2, 2, 3])) + + +def test_custom_group_by() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]}) + out = ( + ldf.group_by("b", maintain_order=True) + .agg([pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64)]) + .collect() + ) + assert out.rows() == [("a", 1), ("b", 2), ("c", 2)] + + +def test_lazy_columns() -> None: + lf = pl.LazyFrame( + { + "a": [1], + "b": [1], + "c": [1], + } + ) + assert lf.select("a", "c").collect_schema().names() == ["a", "c"] + + +def test_cast_frame() -> None: + lf = pl.LazyFrame( + { + "a": [1.0, 2.5, 3.0], + "b": [4, 5, None], + "c": [True, False, True], + "d": [date(2020, 1, 2), date(2021, 3, 4), date(2022, 5, 6)], + } + ) + + # cast via col:dtype map + assert lf.cast( + dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")} + ).collect_schema() == { + "a": pl.Float64, + "b": pl.Float32, + "c": pl.String, + "d": pl.Datetime("ms"), + } + + # cast via selector:dtype map + lfc = lf.cast( + { + cs.float(): pl.UInt8, + cs.integer(): pl.Int32, + cs.temporal(): pl.String, + } + ) + assert lfc.collect_schema() == { + "a": pl.UInt8, + "b": pl.Int32, + "c": pl.Boolean, + "d": pl.String, + } + assert lfc.collect().rows() == [ + (1, 4, True, "2020-01-02"), + (2, 5, False, "2021-03-04"), + (3, None, True, "2022-05-06"), + ] + + # cast all fields to a single type + result = lf.cast(pl.String) + expected = pl.LazyFrame( + { + "a": ["1.0", "2.5", "3.0"], + "b": ["4", "5", None], + "c": ["true", "false", "true"], + "d": ["2020-01-02", "2021-03-04", "2022-05-06"], + } + ) + assert_frame_equal(result, expected) + + # test 'strict' mode + lf = pl.LazyFrame({"a": [1000, 2000, 3000]}) + + with pytest.raises(InvalidOperationError, match="conversion .* failed"): + lf.cast(pl.UInt8).collect() + + assert lf.cast(pl.UInt8, strict=False).collect().rows() == [ + (None,), + (None,), + (None,), + ] + + +def test_interpolate() -> None: + df = pl.DataFrame({"a": [1, None, 3]}) + assert df.select(pl.col("a").interpolate())["a"].to_list() == [1, 2, 3] + assert df["a"].interpolate().to_list() == [1, 2, 3] + assert df.interpolate()["a"].to_list() == [1, 2, 3] + assert df.lazy().interpolate().collect()["a"].to_list() == [1, 2, 3] + + +def test_fill_nan() -> None: + df = pl.DataFrame({"a": [1.0, np.nan, 3.0]}) + assert_series_equal(df.fill_nan(2.0)["a"], pl.Series("a", [1.0, 2.0, 3.0])) + assert_series_equal( + df.lazy().fill_nan(2.0).collect()["a"], pl.Series("a", [1.0, 2.0, 3.0]) + ) + assert_series_equal( + df.lazy().fill_nan(None).collect()["a"], pl.Series("a", [1.0, None, 3.0]) + ) + assert_series_equal( + df.select(pl.col("a").fill_nan(2))["a"], pl.Series("a", [1.0, 2.0, 3.0]) + ) + # nearest + assert pl.Series([None, 1, None, None, None, -8, None, None, 10]).interpolate( + method="nearest" + ).to_list() == [None, 1, 1, -8, -8, -8, -8, 10, 10] + + +def test_fill_null() -> None: + df = pl.DataFrame({"a": [1.0, None, 3.0]}) + + assert df.select([pl.col("a").fill_null(strategy="min")])["a"][1] == 1.0 + assert df.lazy().fill_null(2).collect()["a"].to_list() == [1.0, 2.0, 3.0] + + with pytest.raises(ValueError, match="must specify either"): + df.fill_null() + with pytest.raises(ValueError, match="cannot specify both"): + df.fill_null(value=3.0, strategy="max") + with pytest.raises(ValueError, match="can only specify `limit`"): + df.fill_null(strategy="max", limit=2) + + +def test_backward_fill() -> None: + ldf = pl.LazyFrame({"a": [1.0, None, 3.0]}) + col_a_backward_fill = ldf.select( + [pl.col("a").fill_null(strategy="backward")] + ).collect()["a"] + assert_series_equal(col_a_backward_fill, pl.Series("a", [1, 3, 3]).cast(pl.Float64)) + + +def test_rolling(fruits_cars: pl.DataFrame) -> None: + ldf = fruits_cars.lazy() + out = ldf.select( + pl.col("A").rolling_min(3, min_samples=1).alias("1"), + pl.col("A").rolling_min(3).alias("1b"), + pl.col("A").rolling_mean(3, min_samples=1).alias("2"), + pl.col("A").rolling_mean(3).alias("2b"), + pl.col("A").rolling_max(3, min_samples=1).alias("3"), + pl.col("A").rolling_max(3).alias("3b"), + pl.col("A").rolling_sum(3, min_samples=1).alias("4"), + pl.col("A").rolling_sum(3).alias("4b"), + # below we use .round purely for the ability to do assert frame equality + pl.col("A").rolling_std(3).round(1).alias("std"), + pl.col("A").rolling_var(3).round(1).alias("var"), + ) + + assert_frame_equal( + out.collect(), + pl.DataFrame( + { + "1": [1, 1, 1, 2, 3], + "1b": [None, None, 1, 2, 3], + "2": [1.0, 1.5, 2.0, 3.0, 4.0], + "2b": [None, None, 2.0, 3.0, 4.0], + "3": [1, 2, 3, 4, 5], + "3b": [None, None, 3, 4, 5], + "4": [1, 3, 6, 9, 12], + "4b": [None, None, 6, 9, 12], + "std": [None, None, 1.0, 1.0, 1.0], + "var": [None, None, 1.0, 1.0, 1.0], + } + ), + ) + + out_single_val_variance = ldf.select( + pl.col("A").rolling_std(3, min_samples=1).round(decimals=4).alias("std"), + pl.col("A").rolling_var(3, min_samples=1).round(decimals=1).alias("var"), + ).collect() + + assert cast(float, out_single_val_variance[0, "std"]) is None + assert cast(float, out_single_val_variance[0, "var"]) is None + + +def test_arr_namespace(fruits_cars: pl.DataFrame) -> None: + ldf = fruits_cars.lazy() + out = ldf.select( + "fruits", + pl.col("B") + .over("fruits", mapping_strategy="join") + .list.min() + .alias("B_by_fruits_min1"), + pl.col("B") + .min() + .over("fruits", mapping_strategy="join") + .alias("B_by_fruits_min2"), + pl.col("B") + .over("fruits", mapping_strategy="join") + .list.max() + .alias("B_by_fruits_max1"), + pl.col("B") + .max() + .over("fruits", mapping_strategy="join") + .alias("B_by_fruits_max2"), + pl.col("B") + .over("fruits", mapping_strategy="join") + .list.sum() + .alias("B_by_fruits_sum1"), + pl.col("B") + .sum() + .over("fruits", mapping_strategy="join") + .alias("B_by_fruits_sum2"), + pl.col("B") + .over("fruits", mapping_strategy="join") + .list.mean() + .alias("B_by_fruits_mean1"), + pl.col("B") + .mean() + .over("fruits", mapping_strategy="join") + .alias("B_by_fruits_mean2"), + ) + expected = pl.DataFrame( + { + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B_by_fruits_min1": [1, 1, 2, 2, 1], + "B_by_fruits_min2": [1, 1, 2, 2, 1], + "B_by_fruits_max1": [5, 5, 3, 3, 5], + "B_by_fruits_max2": [5, 5, 3, 3, 5], + "B_by_fruits_sum1": [10, 10, 5, 5, 10], + "B_by_fruits_sum2": [10, 10, 5, 5, 10], + "B_by_fruits_mean1": [ + 3.3333333333333335, + 3.3333333333333335, + 2.5, + 2.5, + 3.3333333333333335, + ], + "B_by_fruits_mean2": [ + 3.3333333333333335, + 3.3333333333333335, + 2.5, + 2.5, + 3.3333333333333335, + ], + } + ) + assert_frame_equal(out.collect(), expected) + + +def test_arithmetic() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3]}) + + out = ldf.select( + (pl.col("a") % 2).alias("1"), + (2 % pl.col("a")).alias("2"), + (1 // pl.col("a")).alias("3"), + (1 * pl.col("a")).alias("4"), + (1 + pl.col("a")).alias("5"), + (1 - pl.col("a")).alias("6"), + (pl.col("a") // 2).alias("7"), + (pl.col("a") * 2).alias("8"), + (pl.col("a") + 2).alias("9"), + (pl.col("a") - 2).alias("10"), + (-pl.col("a")).alias("11"), + ) + expected = pl.DataFrame( + { + "1": [1, 0, 1], + "2": [0, 0, 2], + "3": [1, 0, 0], + "4": [1, 2, 3], + "5": [2, 3, 4], + "6": [0, -1, -2], + "7": [0, 1, 1], + "8": [2, 4, 6], + "9": [3, 4, 5], + "10": [-1, 0, 1], + "11": [-1, -2, -3], + } + ) + assert_frame_equal(out.collect(), expected) + + +def test_float_floor_divide() -> None: + x = 10.4 + step = 0.5 + ldf = pl.LazyFrame({"x": [x]}) + ldf_res = ldf.with_columns(pl.col("x") // step).collect().item() + assert ldf_res == x // step + + +def test_argminmax() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": [1, 1, 2, 2, 2]}) + out = ldf.select( + pl.col("a").arg_min().alias("min"), + pl.col("a").arg_max().alias("max"), + ).collect() + assert out["max"][0] == 4 + assert out["min"][0] == 0 + + out = ( + ldf.group_by("b", maintain_order=True) + .agg([pl.col("a").arg_min().alias("min"), pl.col("a").arg_max().alias("max")]) + .collect() + ) + assert out["max"][0] == 1 + assert out["min"][0] == 0 + + +def test_reverse() -> None: + out = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).reverse() + expected = pl.DataFrame({"a": [2, 1], "b": [4, 3]}) + assert_frame_equal(out.collect(), expected) + + +def test_limit(fruits_cars: pl.DataFrame) -> None: + assert_frame_equal(fruits_cars.lazy().limit(1).collect(), fruits_cars[0, :]) + + +def test_head(fruits_cars: pl.DataFrame) -> None: + assert_frame_equal(fruits_cars.lazy().head(2).collect(), fruits_cars[:2, :]) + + +def test_tail(fruits_cars: pl.DataFrame) -> None: + assert_frame_equal(fruits_cars.lazy().tail(2).collect(), fruits_cars[3:, :]) + + +def test_last(fruits_cars: pl.DataFrame) -> None: + result = fruits_cars.lazy().last().collect() + expected = fruits_cars[(len(fruits_cars) - 1) :, :] + assert_frame_equal(result, expected) + + +def test_first(fruits_cars: pl.DataFrame) -> None: + assert_frame_equal(fruits_cars.lazy().first().collect(), fruits_cars[0, :]) + + +def test_join_suffix() -> None: + df_left = pl.DataFrame( + { + "a": ["a", "b", "a", "z"], + "b": [1, 2, 3, 4], + "c": [6, 5, 4, 3], + } + ) + df_right = pl.DataFrame( + { + "a": ["b", "c", "b", "a"], + "b": [0, 3, 9, 6], + "c": [1, 0, 2, 1], + } + ) + out = df_left.join(df_right, on="a", suffix="_bar") + assert out.columns == ["a", "b", "c", "b_bar", "c_bar"] + out = df_left.lazy().join(df_right.lazy(), on="a", suffix="_bar").collect() + assert out.columns == ["a", "b", "c", "b_bar", "c_bar"] + + +def test_collect_unexpected_kwargs(df: pl.DataFrame) -> None: + with pytest.raises(TypeError, match="unexpected keyword argument"): + df.lazy().collect(common_subexpr_elim=False) # type: ignore[call-overload] + + +def test_spearman_corr() -> None: + ldf = pl.LazyFrame( + { + "era": [1, 1, 1, 2, 2, 2], + "prediction": [2, 4, 5, 190, 1, 4], + "target": [1, 3, 2, 1, 43, 3], + } + ) + + out = ( + ldf.group_by("era", maintain_order=True).agg( + pl.corr(pl.col("prediction"), pl.col("target"), method="spearman").alias( + "c" + ), + ) + ).collect()["c"] + assert np.isclose(out[0], 0.5) + assert np.isclose(out[1], -1.0) + + # we can also pass in column names directly + out = ( + ldf.group_by("era", maintain_order=True).agg( + pl.corr("prediction", "target", method="spearman").alias("c"), + ) + ).collect()["c"] + assert np.isclose(out[0], 0.5) + assert np.isclose(out[1], -1.0) + + +def test_spearman_corr_ties() -> None: + """In Spearman correlation, ranks are computed using the average method .""" + df = pl.DataFrame({"a": [1, 1, 1, 2, 3, 7, 4], "b": [4, 3, 2, 2, 4, 3, 1]}) + + result = df.select( + pl.corr("a", "b", method="spearman").alias("a1"), + pl.corr(pl.col("a").rank("min"), pl.col("b").rank("min")).alias("a2"), + pl.corr(pl.col("a").rank(), pl.col("b").rank()).alias("a3"), + ) + expected = pl.DataFrame( + [ + pl.Series("a1", [-0.19048482943986483], dtype=pl.Float64), + pl.Series("a2", [-0.17223653586587362], dtype=pl.Float64), + pl.Series("a3", [-0.19048482943986483], dtype=pl.Float64), + ] + ) + assert_frame_equal(result, expected) + + +def test_pearson_corr() -> None: + ldf = pl.LazyFrame( + { + "era": [1, 1, 1, 2, 2, 2], + "prediction": [2, 4, 5, 190, 1, 4], + "target": [1, 3, 2, 1, 43, 3], + } + ) + + out = ( + ldf.group_by("era", maintain_order=True).agg( + pl.corr( + pl.col("prediction"), + pl.col("target"), + method="pearson", + ).alias("c"), + ) + ).collect()["c"] + assert out.to_list() == pytest.approx([0.6546536707079772, -5.477514993831792e-1]) + + # we can also pass in column names directly + out = ( + ldf.group_by("era", maintain_order=True).agg( + pl.corr("prediction", "target", method="pearson").alias("c"), + ) + ).collect()["c"] + assert out.to_list() == pytest.approx([0.6546536707079772, -5.477514993831792e-1]) + + +def test_null_count() -> None: + lf = pl.LazyFrame({"a": [1, 2, None, 2], "b": [None, 3, None, 3]}) + assert lf.null_count().collect().rows() == [(1, 2)] + + +def test_lazy_concat(df: pl.DataFrame) -> None: + shape = df.shape + shape = (shape[0] * 2, shape[1]) + + out = pl.concat([df.lazy(), df.lazy()]).collect() + assert out.shape == shape + assert_frame_equal(out, df.vstack(df)) + + +def test_self_join() -> None: + # 2720 + ldf = pl.from_dict( + data={ + "employee_id": [100, 101, 102], + "employee_name": ["James", "Alice", "Bob"], + "manager_id": [None, 100, 101], + } + ).lazy() + + out = ( + ldf.join(other=ldf, left_on="manager_id", right_on="employee_id", how="left") + .select( + pl.col("employee_id"), + pl.col("employee_name"), + pl.col("employee_name_right").alias("manager_name"), + ) + .collect() + ) + assert set(out.rows()) == { + (100, "James", None), + (101, "Alice", "James"), + (102, "Bob", "Alice"), + } + + +def test_group_lengths() -> None: + ldf = pl.LazyFrame( + { + "group": ["A", "A", "A", "B", "B", "B", "B"], + "id": ["1", "1", "2", "3", "4", "3", "5"], + } + ) + + result = ldf.group_by(["group"], maintain_order=True).agg( + [ + (pl.col("id").unique_counts() / pl.col("id").len()) + .sum() + .alias("unique_counts_sum"), + pl.col("id").unique().len().alias("unique_len"), + ] + ) + expected = pl.DataFrame( + { + "group": ["A", "B"], + "unique_counts_sum": [1.0, 1.0], + "unique_len": [2, 3], + }, + schema_overrides={"unique_len": pl.UInt32}, + ) + assert_frame_equal(result.collect(), expected) + + +def test_quantile_filtered_agg() -> None: + assert ( + pl.LazyFrame( + { + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "value": [1, 2, 3, 4, 1, 2, 3, 4], + } + ) + .group_by("group") + .agg(pl.col("value").filter(pl.col("value") < 2).quantile(0.5)) + .collect()["value"] + .to_list() + ) == [1.0, 1.0] + + +def test_predicate_count_vstack() -> None: + l1 = pl.LazyFrame( + { + "k": ["x", "y"], + "v": [3, 2], + } + ) + l2 = pl.LazyFrame( + { + "k": ["x", "y"], + "v": [5, 7], + } + ) + assert pl.concat([l1, l2]).filter(pl.len().over("k") == 2).collect()[ + "v" + ].to_list() == [3, 2, 5, 7] + + +def test_lazy_method() -> None: + # We want to support `.lazy()` on a Lazy DataFrame to allow more generic user code. + df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": [1, 2, 3, 4, 5, 6]}) + assert_frame_equal(df.lazy(), df.lazy().lazy()) + + +def test_update_schema_after_projection_pd_t4157() -> None: + ldf = pl.LazyFrame({"c0": [], "c1": [], "c2": []}).rename({"c2": "c2_"}) + assert ldf.drop("c2_").select(pl.col("c0")).collect().columns == ["c0"] + + +def test_type_coercion_unknown_4190() -> None: + df = ( + pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).with_columns( + pl.col("a") & pl.col("a").fill_null(True) + ) + ).collect() + assert df.shape == (3, 2) + assert df.rows() == [(1, 1), (2, 2), (3, 3)] + + +def test_lazy_cache_same_key() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": ["x", "y", "z"]}) + + # these have the same schema, but should not be used by cache as they are different + add_node = ldf.select([(pl.col("a") + pl.col("b")).alias("a"), pl.col("c")]).cache() + mult_node = ldf.select((pl.col("a") * pl.col("b")).alias("a"), pl.col("c")).cache() + + result = mult_node.join(add_node, on="c", suffix="_mult").select( + (pl.col("a") - pl.col("a_mult")).alias("a"), pl.col("c") + ) + expected = pl.LazyFrame({"a": [-1, 2, 7], "c": ["x", "y", "z"]}) + assert_frame_equal(result, expected, check_row_order=False) + + +@pytest.mark.may_fail_auto_streaming +def test_lazy_cache_hit(monkeypatch: Any, capfd: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": ["x", "y", "z"]}) + add_node = ldf.select([(pl.col("a") + pl.col("b")).alias("a"), pl.col("c")]).cache() + + result = add_node.join(add_node, on="c", suffix="_mult").select( + (pl.col("a") - pl.col("a_mult")).alias("a"), pl.col("c") + ) + expected = pl.LazyFrame({"a": [0, 0, 0], "c": ["x", "y", "z"]}) + assert_frame_equal(result, expected) + + (_, err) = capfd.readouterr() + assert "CACHE HIT" in err + + +def test_lazy_cache_parallel() -> None: + df_evaluated = 0 + + def map_df(df: pl.DataFrame) -> pl.DataFrame: + nonlocal df_evaluated + df_evaluated += 1 + return df + + df = pl.LazyFrame({"a": [1]}).map_batches(map_df).cache() + + df = pl.concat( + [ + df.select(pl.col("a") + 1), + df.select(pl.col("a") + 2), + df.select(pl.col("a") + 3), + ], + parallel=True, + ) + + assert df_evaluated == 0 + + df.collect() + assert df_evaluated == 1 + + +def test_lazy_cache_nested_parallel() -> None: + df_inner_evaluated = 0 + df_outer_evaluated = 0 + + def map_df_inner(df: pl.DataFrame) -> pl.DataFrame: + nonlocal df_inner_evaluated + df_inner_evaluated += 1 + return df + + def map_df_outer(df: pl.DataFrame) -> pl.DataFrame: + nonlocal df_outer_evaluated + df_outer_evaluated += 1 + return df + + df_inner = pl.LazyFrame({"a": [1]}).map_batches(map_df_inner).cache() + df_outer = df_inner.select(pl.col("a") + 1).map_batches(map_df_outer).cache() + + df = pl.concat( + [ + df_outer.select(pl.col("a") + 2), + df_outer.select(pl.col("a") + 3), + ], + parallel=True, + ) + + assert df_inner_evaluated == 0 + assert df_outer_evaluated == 0 + + df.collect() + assert df_inner_evaluated == 1 + assert df_outer_evaluated == 1 + + +def test_quadratic_behavior_4736() -> None: + # no assert; if this function does not stall our tests it has passed! + lf = pl.LazyFrame(schema=list(ascii_letters)) + lf.select(reduce(add, (pl.col(c) for c in lf.collect_schema()))) + + +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) +def test_from_epoch(input_dtype: PolarsDataType) -> None: + ldf = pl.LazyFrame( + [ + pl.Series("timestamp_d", [13285]).cast(input_dtype), + pl.Series("timestamp_s", [1147880044]).cast(input_dtype), + pl.Series("timestamp_ms", [1147880044 * 1_000]).cast(input_dtype), + pl.Series("timestamp_us", [1147880044 * 1_000_000]).cast(input_dtype), + pl.Series("timestamp_ns", [1147880044 * 1_000_000_000]).cast(input_dtype), + ] + ) + + exp_dt = datetime(2006, 5, 17, 15, 34, 4) + expected = pl.DataFrame( + [ + pl.Series("timestamp_d", [date(2006, 5, 17)]), + pl.Series("timestamp_s", [exp_dt]), # s is no Polars dtype, defaults to us + pl.Series("timestamp_ms", [exp_dt]).cast(pl.Datetime("ms")), + pl.Series("timestamp_us", [exp_dt]), # us is Polars Datetime default + pl.Series("timestamp_ns", [exp_dt]).cast(pl.Datetime("ns")), + ] + ) + + ldf_result = ldf.select( + pl.from_epoch(pl.col("timestamp_d"), time_unit="d"), + pl.from_epoch(pl.col("timestamp_s"), time_unit="s"), + pl.from_epoch(pl.col("timestamp_ms"), time_unit="ms"), + pl.from_epoch(pl.col("timestamp_us"), time_unit="us"), + pl.from_epoch(pl.col("timestamp_ns"), time_unit="ns"), + ).collect() + + assert_frame_equal(ldf_result, expected) + + ts_col = pl.col("timestamp_s") + with pytest.raises(ValueError): + _ = ldf.select(pl.from_epoch(ts_col, time_unit="s2")) # type: ignore[call-overload] + + +def test_from_epoch_str() -> None: + ldf = pl.LazyFrame( + [ + pl.Series("timestamp_ms", [1147880044 * 1_000]).cast(pl.String), + pl.Series("timestamp_us", [1147880044 * 1_000_000]).cast(pl.String), + ] + ) + + with pytest.raises(InvalidOperationError): + ldf.select( + pl.from_epoch(pl.col("timestamp_ms"), time_unit="ms"), + pl.from_epoch(pl.col("timestamp_us"), time_unit="us"), + ).collect() + + +def test_cum_agg_types() -> None: + ldf = pl.LazyFrame({"a": [1, 2], "b": [True, False], "c": [1.3, 2.4]}) + cum_sum_lf = ldf.select( + pl.col("a").cum_sum(), + pl.col("b").cum_sum(), + pl.col("c").cum_sum(), + ) + assert cum_sum_lf.collect_schema()["a"] == pl.Int64 + assert cum_sum_lf.collect_schema()["b"] == pl.UInt32 + assert cum_sum_lf.collect_schema()["c"] == pl.Float64 + collected_cumsum_lf = cum_sum_lf.collect() + assert collected_cumsum_lf.schema == cum_sum_lf.collect_schema() + + cum_prod_lf = ldf.select( + pl.col("a").cast(pl.UInt64).cum_prod(), + pl.col("b").cum_prod(), + pl.col("c").cum_prod(), + ) + assert cum_prod_lf.collect_schema()["a"] == pl.UInt64 + assert cum_prod_lf.collect_schema()["b"] == pl.Int64 + assert cum_prod_lf.collect_schema()["c"] == pl.Float64 + collected_cum_prod_lf = cum_prod_lf.collect() + assert collected_cum_prod_lf.schema == cum_prod_lf.collect_schema() + + +def test_compare_schema_between_lazy_and_eager_6904() -> None: + float32_df = pl.DataFrame({"x": pl.Series(values=[], dtype=pl.Float32)}) + eager_result = float32_df.select(pl.col("x").sqrt()).select(pl.col(pl.Float32)) + lazy_result = ( + float32_df.lazy() + .select(pl.col("x").sqrt()) + .select(pl.col(pl.Float32)) + .collect() + ) + assert eager_result.shape == lazy_result.shape + + eager_result = float32_df.select(pl.col("x").pow(2)).select(pl.col(pl.Float32)) + lazy_result = ( + float32_df.lazy() + .select(pl.col("x").pow(2)) + .select(pl.col(pl.Float32)) + .collect() + ) + assert eager_result.shape == lazy_result.shape + + int32_df = pl.DataFrame({"x": pl.Series(values=[], dtype=pl.Int32)}) + eager_result = int32_df.select(pl.col("x").pow(2)).select(pl.col(pl.Float64)) + lazy_result = ( + int32_df.lazy().select(pl.col("x").pow(2)).select(pl.col(pl.Float64)).collect() + ) + assert eager_result.shape == lazy_result.shape + + int8_df = pl.DataFrame({"x": pl.Series(values=[], dtype=pl.Int8)}) + eager_result = int8_df.select(pl.col("x").diff()).select(pl.col(pl.Int16)) + lazy_result = ( + int8_df.lazy().select(pl.col("x").diff()).select(pl.col(pl.Int16)).collect() + ) + assert eager_result.shape == lazy_result.shape + + +@pytest.mark.slow +@pytest.mark.parametrize( + "dtype", + [ + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.Float32, + pl.Float64, + ], +) +@pytest.mark.parametrize( + "func", + [ + pl.col("x").arg_max(), + pl.col("x").arg_min(), + pl.col("x").max(), + pl.col("x").mean(), + pl.col("x").median(), + pl.col("x").min(), + pl.col("x").nan_max(), + pl.col("x").nan_min(), + pl.col("x").product(), + pl.col("x").quantile(0.5), + pl.col("x").std(), + pl.col("x").sum(), + pl.col("x").var(), + ], +) +def test_compare_aggregation_between_lazy_and_eager_6904( + dtype: PolarsDataType, func: pl.Expr +) -> None: + df = pl.DataFrame( + { + "x": pl.Series(values=[1, 2, 3] * 2, dtype=dtype), + "y": pl.Series(values=["a"] * 3 + ["b"] * 3), + } + ) + result_eager = df.select(func.over("y")).select("x") + dtype_eager = result_eager["x"].dtype + result_lazy = df.lazy().select(func.over("y")).select(pl.col(dtype_eager)).collect() + assert_frame_equal(result_eager, result_lazy) + + +@pytest.mark.parametrize( + "comparators", + [ + ("==", pl.LazyFrame.__eq__), + ("!=", pl.LazyFrame.__ne__), + (">", pl.LazyFrame.__gt__), + ("<", pl.LazyFrame.__lt__), + (">=", pl.LazyFrame.__ge__), + ("<=", pl.LazyFrame.__le__), + ], +) +def test_lazy_comparison_operators( + comparators: tuple[str, Callable[[pl.LazyFrame, Any], NoReturn]], +) -> None: + # we cannot compare lazy frames, so all should raise a TypeError + with pytest.raises( + TypeError, + match=f'"{comparators[0]!r}" comparison not supported for LazyFrame objects', + ): + comparators[1](pl.LazyFrame(), pl.LazyFrame()) + + +def test_lf_properties() -> None: + lf = pl.LazyFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + with pytest.warns(PerformanceWarning): + assert lf.schema == {"foo": pl.Int64, "bar": pl.Float64, "ham": pl.String} + with pytest.warns(PerformanceWarning): + assert lf.columns == ["foo", "bar", "ham"] + with pytest.warns(PerformanceWarning): + assert lf.dtypes == [pl.Int64, pl.Float64, pl.String] + with pytest.warns(PerformanceWarning): + assert lf.width == 3 + + +def test_lf_unnest() -> None: + lf = pl.DataFrame( + [ + pl.Series( + "a", + [{"ab": [1, 2, 3], "ac": [3, 4, 5]}], + dtype=pl.Struct({"ab": pl.List(pl.Int64), "ac": pl.List(pl.Int64)}), + ), + pl.Series( + "b", + [{"ba": [5, 6, 7], "bb": [7, 8, 9]}], + dtype=pl.Struct({"ba": pl.List(pl.Int64), "bb": pl.List(pl.Int64)}), + ), + ] + ).lazy() + + expected = pl.DataFrame( + [ + pl.Series("ab", [[1, 2, 3]], dtype=pl.List(pl.Int64)), + pl.Series("ac", [[3, 4, 5]], dtype=pl.List(pl.Int64)), + pl.Series("ba", [[5, 6, 7]], dtype=pl.List(pl.Int64)), + pl.Series("bb", [[7, 8, 9]], dtype=pl.List(pl.Int64)), + ] + ) + assert_frame_equal(lf.unnest("a", "b").collect(), expected) + + +def test_type_coercion_cast_boolean_after_comparison() -> None: + import operator + + lf = pl.LazyFrame({"a": 1, "b": 2}) + + for op in [ + operator.eq, + operator.ne, + operator.lt, + operator.le, + operator.gt, + operator.ge, + pl.Expr.eq_missing, + pl.Expr.ne_missing, + ]: + e = op(pl.col("a"), pl.col("b")).cast(pl.Boolean).alias("o") + assert "cast" not in lf.with_columns(e).explain() + + e = op(pl.col("a"), pl.col("b")).cast(pl.Boolean).cast(pl.Boolean).alias("o") + assert "cast" not in lf.with_columns(e).explain() + + for op in [operator.and_, operator.or_, operator.xor]: + e = op(pl.col("a"), pl.col("b")).cast(pl.Boolean) + assert "cast" in lf.with_columns(e).explain() + + +def test_unique_length_multiple_columns() -> None: + lf = pl.LazyFrame( + { + "a": [1, 1, 1, 2, 3], + "b": [100, 100, 200, 100, 300], + } + ) + assert lf.unique().select(pl.len()).collect().item() == 4 diff --git a/py-polars/tests/unit/lazyframe/test_optimizations.py b/py-polars/tests/unit/lazyframe/test_optimizations.py new file mode 100644 index 000000000000..14556e48967c --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_optimizations.py @@ -0,0 +1,335 @@ +import itertools + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_is_null_followed_by_all() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]}) + + expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().all() + ) + + assert ( + r'[[(col("val").count()) == (col("val").null_count())]]' in result_lf.explain() + ) + assert "is_null" not in result_lf + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().is_null().all()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "is_null" in non_optimized_result_plan + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [True]}) + result_df = lf.select(pl.col("val").is_null().all()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_null_followed_by_any() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().any() + ) + assert_frame_equal(expected_df, result_lf.collect()) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [False]}) + result_df = lf.select(pl.col("val").is_null().any()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_all() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]}) + + expected_df = pl.DataFrame({"group": [0, 1], "val": [True, False]}) + result_df = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").is_not_null().all()) + .collect() + ) + + assert_frame_equal(expected_df, result_df) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [True]}) + result_df = lf.select(pl.col("val").is_not_null().all()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_any() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_not_null().any() + ) + + assert ( + r'[[(col("val").null_count()) < (col("val").count())]]' in result_lf.explain() + ) + assert "is_not_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().is_not_null().any()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "is_not_null" in non_optimized_result_plan + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [False]}) + result_df = lf.select(pl.col("val").is_not_null().any()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_null_followed_by_sum() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [1, 1, 0]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().sum() + ) + + assert r'[col("val").null_count()]' in result_lf.explain() + assert "is_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}) + result_df = lf.select(pl.col("val").is_null().sum()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_sum() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_not_null().sum() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "is_not_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").abs().is_not_null().sum() + ) + assert "null_count" not in non_optimized_result_lf.explain() + assert "is_not_null" in non_optimized_result_lf.explain() + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}) + result_df = lf.select(pl.col("val").is_not_null().sum()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_drop_nulls_followed_by_len() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").drop_nulls().len() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "drop_nulls" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().drop_nulls().len()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "drop_nulls" in non_optimized_result_plan + + +def test_drop_nulls_followed_by_count() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").drop_nulls().count() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "drop_nulls" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().drop_nulls().count()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "drop_nulls" in non_optimized_result_plan + + +def test_collapse_joins() -> None: + a = pl.LazyFrame({"a": [1, 2, 3], "b": [2, 2, 2]}) + b = pl.LazyFrame({"x": [7, 1, 2]}) + + cross = a.join(b, how="cross") + + inner_join = cross.filter(pl.col.a == pl.col.x) + e = inner_join.explain() + assert "INNER JOIN" in e + assert "FILTER" not in e + assert_frame_equal( + inner_join.collect(collapse_joins=False), + inner_join.collect(), + check_row_order=False, + ) + + inner_join = cross.filter(pl.col.x == pl.col.a) + e = inner_join.explain() + assert "INNER JOIN" in e + assert "FILTER" not in e + assert_frame_equal( + inner_join.collect(collapse_joins=False), + inner_join.collect(), + check_row_order=False, + ) + + double_inner_join = cross.filter(pl.col.x == pl.col.a).filter(pl.col.x == pl.col.b) + e = double_inner_join.explain() + assert "INNER JOIN" in e + assert "FILTER" not in e + assert_frame_equal( + double_inner_join.collect(collapse_joins=False), + double_inner_join.collect(), + check_row_order=False, + ) + + dont_mix = cross.filter(pl.col.x + pl.col.a != 0) + e = dont_mix.explain() + assert "NESTED LOOP JOIN" in e + assert "FILTER" not in e + assert_frame_equal( + dont_mix.collect(collapse_joins=False), + dont_mix.collect(), + check_row_order=False, + ) + + no_literals = cross.filter(pl.col.x == 2) + e = no_literals.explain() + assert "NESTED LOOP JOIN" in e + assert_frame_equal( + no_literals.collect(collapse_joins=False), + no_literals.collect(), + check_row_order=False, + ) + + iejoin = cross.filter(pl.col.x >= pl.col.a) + e = iejoin.explain() + assert "IEJOIN" in e + assert "NESTED LOOP JOIN" not in e + assert "CROSS JOIN" not in e + assert "FILTER" not in e + assert_frame_equal( + iejoin.collect(collapse_joins=False), + iejoin.collect(), + check_row_order=False, + ) + + iejoin = cross.filter(pl.col.x >= pl.col.a).filter(pl.col.x <= pl.col.b) + e = iejoin.explain() + assert "IEJOIN" in e + assert "CROSS JOIN" not in e + assert "NESTED LOOP JOIN" not in e + assert "FILTER" not in e + assert_frame_equal( + iejoin.collect(collapse_joins=False), iejoin.collect(), check_row_order=False + ) + + +@pytest.mark.slow +def test_collapse_joins_combinations() -> None: + # This just tests all possible combinations for expressions on a cross join. + + a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]}) + b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]}) + + cross = a.join(b, how="cross") + + exprs = [] + + for lhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a + pl.col.b]: + for rhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a * pl.col.x]: + for cmp in ["__eq__", "__ge__", "__lt__"]: + e = (getattr(lhs, cmp))(rhs) + exprs.append(e) + + for amount in range(3): + for merge in itertools.product(["__and__", "__or__"] * (amount - 1)): + for es in itertools.product(*([exprs] * amount)): + e = es[0] + for i in range(amount - 1): + e = (getattr(e, merge[i]))(es[i + 1]) + + # NOTE: We need to sort because the order of the cross-join & + # IE-join is unspecified. Therefore, this might not necessarily + # create the exact same dataframe. + optimized = cross.filter(e).sort(pl.all()).collect() + unoptimized = cross.filter(e).collect(collapse_joins=False) + + try: + assert_frame_equal(optimized, unoptimized, check_row_order=False) + except: + print(e) + print() + print("Optimized") + print(cross.filter(e).explain()) + print(optimized) + print() + print("Unoptimized") + print(cross.filter(e).explain(collapse_joins=False)) + print(unoptimized) + print() + + raise diff --git a/py-polars/tests/unit/lazyframe/test_order_observability.py b/py-polars/tests/unit/lazyframe/test_order_observability.py new file mode 100644 index 000000000000..425a4954b0da --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_order_observability.py @@ -0,0 +1,53 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_order_observability() -> None: + q = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a") + + assert "SORT" not in q.group_by("a").sum().explain(_check_order=True) + assert "SORT" not in q.group_by("a").min().explain(_check_order=True) + assert "SORT" not in q.group_by("a").max().explain(_check_order=True) + assert "SORT" in q.group_by("a").last().explain(_check_order=True) + assert "SORT" in q.group_by("a").first().explain(_check_order=True) + + +def test_order_observability_group_by_dynamic() -> None: + assert ( + pl.LazyFrame( + {"REGIONID": [1, 23, 4], "INTERVAL_END": [32, 43, 12], "POWER": [12, 3, 1]} + ) + .sort("REGIONID", "INTERVAL_END") + .group_by_dynamic(index_column="INTERVAL_END", every="1i", group_by="REGIONID") + .agg(pl.col("POWER").sum()) + .sort("POWER") + .head() + .explain() + ).count("SORT") == 2 + + +def test_remove_double_sort() -> None: + assert ( + pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a").explain().count("SORT") + == 1 + ) + + +def test_double_sort_maintain_order_18558() -> None: + df = pl.DataFrame( + { + "col1": [1, 2, 2, 4, 5, 6], + "col2": [2, 2, 0, 0, 2, None], + } + ) + + lf = df.lazy().sort("col2").sort("col1", maintain_order=True) + + expect = pl.DataFrame( + [ + pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64), + pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64), + ] + ) + + assert_frame_equal(lf.collect(), expect) diff --git a/py-polars/tests/unit/lazyframe/test_rename.py b/py-polars/tests/unit/lazyframe/test_rename.py new file mode 100644 index 000000000000..5f1d93756fac --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_rename.py @@ -0,0 +1,37 @@ +from types import MappingProxyType + +import pytest + +import polars as pl +from polars.exceptions import ColumnNotFoundError + + +def test_lazy_rename() -> None: + lf = pl.LazyFrame({"x": [1], "y": [2]}) + + result = lf.rename({"y": "x", "x": "y"}).select(["x", "y"]).collect() + assert result.to_dict(as_series=False) == {"x": [2], "y": [1]} + + # the `strict` param controls whether we fail on columns not found in the frame + remap_colnames = {"b": "a", "y": "x", "a": "b", "x": "y"} + with pytest.raises(ColumnNotFoundError, match="'b' is invalid"): + lf.rename(remap_colnames).collect() + + result = lf.rename(remap_colnames, strict=False).collect() + assert result.to_dict(as_series=False) == {"x": [2], "y": [1]} + + +def test_remove_redundant_mapping_4668() -> None: + lf = pl.LazyFrame([["a"]] * 2, ["A", "B "]).lazy() + clean_name_dict = {x: " ".join(x.split()) for x in lf.collect_schema()} + lf = lf.rename(clean_name_dict) + assert lf.collect_schema().names() == ["A", "B"] + + +def test_rename_mapping_19400() -> None: + # use a mapping type that is not a dict + mapping = MappingProxyType({"a": "b", "b": "c"}) + + assert pl.LazyFrame({"a": [1], "b": [2]}).rename(mapping).collect().to_dict( + as_series=False + ) == {"b": [1], "c": [2]} diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py new file mode 100644 index 000000000000..a3ffb47e01c7 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import io +from typing import TYPE_CHECKING + +import pytest +from hypothesis import example, given + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import SerializationFormat + + +@given( + lf=dataframes( + lazy=True, + excluded_dtypes=[pl.Struct], + ) +) +@example(lf=pl.LazyFrame({"foo": ["a", "b", "a"]}, schema={"foo": pl.Enum(["b", "a"])})) +def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None: + serialized = lf.serialize(format="binary") + result = pl.LazyFrame.deserialize(io.BytesIO(serialized), format="binary") + assert_frame_equal(result, lf, categorical_as_str=True) + + +@given( + lf=dataframes( + lazy=True, + excluded_dtypes=[ + pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211 + pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211 + pl.Struct, # Outer nullability not supported + ], + ) +) +@pytest.mark.filterwarnings("ignore") +def test_lf_serde_roundtrip_json(lf: pl.LazyFrame) -> None: + serialized = lf.serialize(format="json") + result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json") + assert_frame_equal(result, lf, categorical_as_str=True) + + +@pytest.fixture +def lf() -> pl.LazyFrame: + """Sample LazyFrame for testing serialization/deserialization.""" + return pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).select("a").sum() + + +@pytest.mark.filterwarnings("ignore") +def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None: + serialized = lf.serialize(format="json") + assert isinstance(serialized, str) + result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json") + assert_frame_equal(result, lf) + + +def test_lf_serde(lf: pl.LazyFrame) -> None: + serialized = lf.serialize() + assert isinstance(serialized, bytes) + result = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + assert_frame_equal(result, lf) + + +@pytest.mark.parametrize( + ("format", "buf"), + [ + ("binary", io.BytesIO()), + ("json", io.StringIO()), + ("json", io.BytesIO()), + ], +) +@pytest.mark.filterwarnings("ignore") +def test_lf_serde_to_from_buffer( + lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase +) -> None: + lf.serialize(buf, format=format) + buf.seek(0) + result = pl.LazyFrame.deserialize(buf, format=format) + assert_frame_equal(lf, result) + + +@pytest.mark.write_disk +def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.bin" + lf.serialize(file_path) + result = pl.LazyFrame.deserialize(file_path) + + assert_frame_equal(lf, result) + + +def test_lf_deserialize_validation() -> None: + f = io.BytesIO(b"hello world!") + with pytest.raises(ComputeError, match="expected value at line 1 column 1"): + pl.LazyFrame.deserialize(f, format="json") + + +@pytest.mark.write_disk +def test_lf_serde_scan(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "dataset.parquet" + + df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + df.write_parquet(path) + lf = pl.scan_parquet(path) + + ser = lf.serialize() + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + assert_frame_equal(result, lf) + assert_frame_equal(result.collect(), df) + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_version_specific_lambda() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).select( + pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64) + ) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected) + + +def custom_function(x: pl.Series) -> pl.Series: + return x + 1 + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_version_specific_named_function() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).select( + pl.col("a").map_batches(custom_function, return_dtype=pl.Int64) + ) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_map_batches_on_lazyframe() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).map_batches(lambda x: x + 1) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/lazyframe/test_with_context.py b/py-polars/tests/unit/lazyframe/test_with_context.py new file mode 100644 index 000000000000..10710bff4335 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_with_context.py @@ -0,0 +1,121 @@ +from datetime import datetime + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_with_context() -> None: + df_a = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "c", None]}).lazy() + df_b = pl.DataFrame({"c": ["foo", "ham"]}) + + with pytest.deprecated_call(): + result = df_a.with_context(df_b.lazy()).select( + pl.col("b") + pl.col("c").first() + ) + assert result.collect().to_dict(as_series=False) == {"b": ["afoo", "cfoo", None]} + + with pytest.deprecated_call(): + context = df_a.with_context(df_b.lazy()) + with pytest.raises(pl.exceptions.ShapeError): + context.select("a", "c").collect() + + +# https://github.com/pola-rs/polars/issues/5867 +def test_with_context_ignore_5867() -> None: + outer = pl.LazyFrame({"OtherCol": [1, 2, 3, 4]}) + with pytest.deprecated_call(): + lf = pl.LazyFrame( + {"Category": [1, 1, 2, 2], "Counts": [1, 2, 3, 4]} + ).with_context(outer) + + result = lf.group_by("Category", maintain_order=True).agg(pl.col("Counts").sum()) + + expected = pl.LazyFrame({"Category": [1, 2], "Counts": [3, 7]}) + assert_frame_equal(result, expected) + + +def test_predicate_pushdown_with_context_11014() -> None: + df1 = pl.LazyFrame( + { + "df1_c1": [1, 2, 3], + "df1_c2": [2, 3, 4], + } + ) + + df2 = pl.LazyFrame( + { + "df2_c1": [2, 3, 4], + "df2_c2": [3, 4, 5], + } + ) + + with pytest.deprecated_call(): + out = ( + df1.with_context(df2) + .filter(pl.col("df1_c1").is_in(pl.col("df2_c1"))) + .collect(predicate_pushdown=True) + ) + + assert out.to_dict(as_series=False) == {"df1_c1": [2, 3], "df1_c2": [3, 4]} + + +@pytest.mark.xdist_group("streaming") +def test_streaming_11219() -> None: + # https://github.com/pola-rs/polars/issues/11219 + + lf = pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "c", None]}) + lf_other = pl.LazyFrame({"c": ["foo", "ham"]}) + lf_other2 = pl.LazyFrame({"c": ["foo", "ham"]}) + + with pytest.deprecated_call(): + context = lf.with_context([lf_other, lf_other2]) + + assert context.select(pl.col("b") + pl.col("c").first()).collect( + engine="old-streaming" # type: ignore[call-overload] + ).to_dict(as_series=False) == {"b": ["afoo", "cfoo", None]} + + +def test_no_cse_in_with_context() -> None: + df1 = pl.DataFrame( + { + "timestamp": [ + datetime(2023, 1, 1, 0, 0), + datetime(2023, 5, 1, 0, 0), + datetime(2023, 10, 1, 0, 0), + ], + "value": [2, 5, 9], + } + ) + df2 = pl.DataFrame( + { + "date_start": [ + datetime(2022, 12, 31, 0, 0), + datetime(2023, 1, 2, 0, 0), + ], + "date_end": [ + datetime(2023, 4, 30, 0, 0), + datetime(2023, 5, 5, 0, 0), + ], + "label": [0, 1], + } + ) + + with pytest.deprecated_call(): + context = df1.lazy().with_context(df2.lazy()) + + assert ( + context.select( + pl.col("date_start", "label").gather( + pl.col("date_start").search_sorted(pl.col("timestamp")) - 1 + ), + ) + ).collect().to_dict(as_series=False) == { + "date_start": [ + datetime(2022, 12, 31, 0, 0), + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 2, 0, 0), + ], + "label": [0, 1, 1], + } diff --git a/py-polars/tests/unit/meta/__init__.py b/py-polars/tests/unit/meta/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/meta/test_build.py b/py-polars/tests/unit/meta/test_build.py new file mode 100644 index 000000000000..d6b711ef4ee2 --- /dev/null +++ b/py-polars/tests/unit/meta/test_build.py @@ -0,0 +1,6 @@ +import polars as pl + + +def test_build_info_version() -> None: + build_info = pl.build_info() + assert build_info["version"] == pl.__version__ diff --git a/py-polars/tests/unit/meta/test_index_type.py b/py-polars/tests/unit/meta/test_index_type.py new file mode 100644 index 000000000000..07bc112b3dcd --- /dev/null +++ b/py-polars/tests/unit/meta/test_index_type.py @@ -0,0 +1,5 @@ +import polars as pl + + +def test_get_index_type() -> None: + assert pl.get_index_type() == pl.UInt32() diff --git a/py-polars/tests/unit/meta/test_thread_pool.py b/py-polars/tests/unit/meta/test_thread_pool.py new file mode 100644 index 000000000000..159ab89cc946 --- /dev/null +++ b/py-polars/tests/unit/meta/test_thread_pool.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_thread_pool_size() -> None: + result = pl.thread_pool_size() + assert isinstance(result, int) + + +def test_threadpool_size_deprecated() -> None: + with pytest.deprecated_call(): + result = pl.threadpool_size() + assert isinstance(result, int) diff --git a/py-polars/tests/unit/meta/test_versions.py b/py-polars/tests/unit/meta/test_versions.py new file mode 100644 index 000000000000..36504c7fe4c5 --- /dev/null +++ b/py-polars/tests/unit/meta/test_versions.py @@ -0,0 +1,16 @@ +from typing import Any + +import pytest + +import polars as pl + + +@pytest.mark.slow +def test_show_versions(capsys: Any) -> None: + pl.show_versions() + + out, _ = capsys.readouterr() + assert "Python" in out + assert "Polars" in out + assert "LTS CPU" in out + assert "Optional dependencies" in out diff --git a/py-polars/tests/unit/ml/__init__.py b/py-polars/tests/unit/ml/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py new file mode 100644 index 000000000000..2d12dd0d37c2 --- /dev/null +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.dependencies import _lazy_import + +# don't import jax until an actual test is triggered (the decorator already +# ensures the tests aren't run locally; this avoids premature local import) +jx, _ = _lazy_import("jax") +jxn, _ = _lazy_import("jax.numpy") + +pytestmark = pytest.mark.ci_only + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [1, 2, 2, 3], + "y": [1, 0, 1, 0], + "z": [1.5, -0.5, 0.0, -2.0], + }, + schema_overrides={"x": pl.Int8, "z": pl.Float32}, + ) + + +def assert_array_equal(actual: Any, expected: Any, nans_equal: bool = True) -> None: + assert isinstance(actual, jx.Array) + jxn.array_equal(actual, expected, equal_nan=nans_equal) + + +@pytest.mark.parametrize( + ("dtype", "expected_jax_dtype"), + [ + (pl.Int8, "int8"), + (pl.Int16, "int16"), + (pl.Int32, "int32"), + (pl.Int64, "int32"), + (pl.UInt8, "uint8"), + (pl.UInt16, "uint16"), + (pl.UInt32, "uint32"), + (pl.UInt64, "uint32"), + ], +) +def test_to_jax_from_series( + dtype: PolarsDataType, + expected_jax_dtype: str, +) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) + for dvc in (None, "cpu", jx.devices("cpu")[0]): + assert_array_equal( + s.to_jax(device=dvc), + jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), + ) + + +def test_to_jax_array(df: pl.DataFrame) -> None: + a1 = df.to_jax() + a2 = df.to_jax("array") + a3 = df.to_jax("array", device="cpu") + a4 = df.to_jax("array", device=jx.devices("cpu")[0]) + + expected = jxn.array( + [ + [1.0, 1.0, 1.5], + [2.0, 0.0, -0.5], + [2.0, 1.0, 0.0], + [3.0, 0.0, -2.0], + ], + dtype=jxn.float32, + ) + for a in (a1, a2, a3, a4): + assert_array_equal(a, expected) + + +def test_2D_array_cols_to_jax() -> None: + # 2D array + df1 = pl.DataFrame( + {"data": [[1, 1], [1, 2], [2, 2]]}, + schema_overrides={"data": pl.Array(pl.Int32, shape=(2,))}, + ) + arr1 = df1.to_jax() + assert_array_equal( + arr1, + jxn.array([[1, 1], [1, 2], [2, 2]], dtype=jxn.int32), + ) + + # nested 2D array + df2 = pl.DataFrame( + {"data": [[[1, 1], [1, 2]], [[2, 2], [2, 3]]]}, + schema_overrides={"data": pl.Array(pl.Array(pl.Int32, shape=(2,)), shape=(2,))}, + ) + arr2 = df2.to_jax() + assert_array_equal( + arr2, + jxn.array([[[1, 1], [1, 2]], [[2, 2], [2, 3]]], dtype=jxn.int32), + ) + + # dict with 2D array + df3 = df2.insert_column(0, pl.Series("lbl", [0, 1])) + lbl_feat_dict = df3.to_jax("dict") + assert_array_equal( + lbl_feat_dict["lbl"], + jxn.array([0, 1], jxn.int32), + ) + assert_array_equal( + lbl_feat_dict["data"], + jxn.array([[[1, 1], [1, 2]], [[2, 2], [2, 3]]], jxn.int32), + ) + + # no support for list (yet? could add if ragged arrays are valid) + with pytest.raises( + TypeError, + match=r"cannot convert List column 'data' to Jax Array \(use Array dtype instead\)", + ): + pl.DataFrame({"data": [[1, 1], [1, 2], [2, 2]]}).to_jax() + + +def test_to_jax_dict(df: pl.DataFrame) -> None: + arr_dict = df.to_jax("dict") + assert list(arr_dict.keys()) == ["x", "y", "z"] + + assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) + assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) + assert_array_equal( + arr_dict["z"], + jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), + ) + + arr_dict = df.to_jax("dict", dtype=pl.Float32) + for a, expected_data in zip( + arr_dict.values(), + ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), + ): + assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) + + +def test_to_jax_feature_label_dict(df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + } + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + assert_array_equal( + lbl_feat_dict["label"], + jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), + ) + assert_array_equal( + lbl_feat_dict["features"], + jxn.array( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=jxn.int32, + ), + ) + + +def test_misc_errors(df: pl.DataFrame) -> None: + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res2 = df.to_jax("dict", features=cs.float()) + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dict'", + ): + _res3 = df.to_jax(label="stroopwafel") diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py new file mode 100644 index 000000000000..fc685178d732 --- /dev/null +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import sys +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.dependencies import _lazy_import + +# don't import torch until an actual test is triggered (the decorator already +# ensures the tests aren't run locally; this avoids premature local import) +torch, _ = _lazy_import("torch") + +pytestmark = [ + pytest.mark.ci_only, + pytest.mark.skipif( + sys.platform == "win32" and sys.version_info >= (3, 13), + reason="Torch does not ship wheels for Python 3.13 on Windows", + ), +] + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [1, 2, 2, 3], + "y": [True, False, True, False], + "z": [1.5, -0.5, 0.0, -2.0], + }, + schema_overrides={"x": pl.Int8, "z": pl.Float32}, + ) + + +def assert_tensor_equal(actual: Any, expected: Any) -> None: + torch.testing.assert_close(actual, expected) + + +def test_to_torch_from_series() -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) + t = s.to_torch() + + assert list(t.shape) == [4] + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) + + # note: torch doesn't natively support uint16/32/64. + # confirm that we export to a suitable signed integer type + s = s.cast(pl.UInt16) + t = s.to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + + for dtype in (pl.UInt32, pl.UInt64): + t = s.cast(dtype).to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + + +def test_to_torch_tensor(df: pl.DataFrame) -> None: + t1 = df.to_torch() + t2 = df.to_torch("tensor") + + assert list(t1.shape) == [4, 3] + assert (t1 == t2).all().item() is True + + +def test_to_torch_dict(df: pl.DataFrame) -> None: + td = df.to_torch("dict") + + assert list(td.keys()) == ["x", "y", "z"] + + assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) + assert_tensor_equal( + td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) + ) + assert_tensor_equal( + td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) + ) + + +def test_to_torch_feature_label_dict(df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + }, + schema_overrides={"age": pl.Int32, "income": pl.Int32}, + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + assert_tensor_equal( + lbl_feat_dict["label"], + torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), + ) + assert_tensor_equal( + lbl_feat_dict["features"], + torch.tensor( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=torch.int32, + ), + ) + + +def test_2D_array_cols_to_torch() -> None: + # 2D array + df1 = pl.DataFrame( + {"data": [[1, 1], [1, 2], [2, 2]]}, + schema_overrides={"data": pl.Array(pl.Int32, shape=(2,))}, + ) + arr1 = df1.to_torch() + assert_tensor_equal( + arr1, + torch.tensor([[1, 1], [1, 2], [2, 2]], dtype=torch.int32), + ) + + # nested 2D array + df2 = pl.DataFrame( + {"data": [[[1, 1], [1, 2]], [[2, 2], [2, 3]]]}, + schema_overrides={"data": pl.Array(pl.Array(pl.Int32, shape=(2,)), shape=(2,))}, + ) + arr2 = df2.to_torch() + assert_tensor_equal( + arr2, + torch.tensor([[[1, 1], [1, 2]], [[2, 2], [2, 3]]], dtype=torch.int32), + ) + + # dict with 2D array + df3 = df2.insert_column(0, pl.Series("lbl", [0, 1], dtype=pl.Int32)) + lbl_feat_dict = df3.to_torch("dict") + assert_tensor_equal( + lbl_feat_dict["lbl"], + torch.tensor([0, 1], dtype=torch.int32), + ) + assert_tensor_equal( + lbl_feat_dict["data"], + torch.tensor([[[1, 1], [1, 2]], [[2, 2], [2, 3]]], dtype=torch.int32), + ) + + # no support for list (yet? could add if ragged arrays are valid) + with pytest.raises( + TypeError, + match=r"cannot convert List column 'data' to Tensor \(use Array dtype instead\)", + ): + pl.DataFrame({"data": [[1, 1], [1, 2], [2, 2]]}).to_torch() + + +def test_to_torch_dataset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", dtype=pl.Float64) + + assert len(ds) == 4 + assert isinstance(ds, torch.utils.data.Dataset) + assert repr(ds).startswith(" None: + df = pl.DataFrame( + {"lbl": [0, 1], "data": [[[1, 1], [1, 2]], [[2, 2], [2, 3]]]}, + schema_overrides={"data": pl.Array(pl.Array(pl.Int32, shape=(2,)), shape=(2,))}, + ) + ds = df.to_torch("dataset", label="lbl") + + assert len(ds) == 2 + assert_tensor_equal(ds[0][1], torch.tensor(0, dtype=torch.int64)) + assert_tensor_equal(ds[1][1], torch.tensor(1, dtype=torch.int64)) + assert_tensor_equal(ds[0][0], torch.tensor([[1, 1], [1, 2]], dtype=torch.int32)) + assert_tensor_equal(ds[1][0], torch.tensor([[2, 2], [2, 3]], dtype=torch.int32)) + + +def test_to_torch_dataset_feature_reorder(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x", features=["z", "y"]) + assert_tensor_equal( + torch.tensor( + [ + [1.5000, 1.0000], + [-0.5000, 0.0000], + [0.0000, 1.0000], + [-2.0000, 0.0000], + ] + ), + ds.features, + ) + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + + +def test_to_torch_dataset_feature_subset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x", features=["z"]) + assert_tensor_equal( + torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), + ds.features, + ) + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + + +def test_to_torch_dataset_index_slice(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[1:3] + + expected = (torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]),) + assert_tensor_equal(expected, ts) + + ts = ds[::2] + expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) + assert_tensor_equal(expected, ts) + + +@pytest.mark.parametrize( + "index", + [ + [0, 3], + range(0, 4, 3), + slice(0, 4, 3), + ], +) +def test_to_torch_dataset_index_multi(index: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[index] + + expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) + assert_tensor_equal(expected, ts) + assert ds.schema == {"features": torch.float32, "labels": None} + + +def test_to_torch_dataset_index_range(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[range(3, 0, -1)] + + expected = (torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]),) + assert_tensor_equal(expected, ts) + + +def test_to_dataset_half_precision(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x") + assert ds.schema == {"features": torch.float32, "labels": torch.int8} + + dsf16 = ds.half() + assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} + + # half precision across all data + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # only apply half precision to the feature data + dsf16 = ds.half(labels=False) + assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} + + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1, 2], dtype=torch.int8), + ) + assert_tensor_equal(expected, ts) + + # only apply half precision to the label data + dsf16 = ds.half(features=False) + assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} + + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # no labels + dsf16 = df.to_torch("dataset").half() + assert dsf16.schema == {"features": torch.float16, "labels": None} + + ts = dsf16[:3:2] + expected = ( # type: ignore[assignment] + torch.tensor( + data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], + dtype=torch.float16, + ), + ) + assert_tensor_equal(expected, ts) + + +@pytest.mark.parametrize( + ("label", "features"), + [ + ("x", None), + ("x", ["y", "z"]), + (cs.integer(), ~cs.integer()), + ], +) +def test_to_torch_labelled_dataset(label: Any, features: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=label, features=features) + ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) + + expected = [ + torch.tensor([[1.0, 1.5], [0.0, -0.5]]), + torch.tensor([1, 2], dtype=torch.int8), + ] + assert len(ts) == len(expected) + for actual, exp in zip(ts, expected): + assert_tensor_equal(exp, actual) + + +def test_to_torch_labelled_dataset_expr(df: pl.DataFrame) -> None: + ds = df.to_torch( + "dataset", + dtype=pl.Float64, + label=(pl.col("x") * 8).cast(pl.Int16), + ) + dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) + for data in (tuple(ds[:2]), tuple(next(iter(dl)))): + expected = ( + torch.tensor([[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64), + torch.tensor([8, 16], dtype=torch.int16), + ) + assert len(data) == len(expected) + for actual, exp in zip(data, expected): + assert_tensor_equal(exp, actual) + + +def test_to_torch_labelled_dataset_multi(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=["x", "y"]) + dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) + ts = list(dl) + + expected = [ + [ + torch.tensor([[1.5000], [-0.5000], [0.0000]]), + torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), + ], + [ + torch.tensor([[-2.0]]), + torch.tensor([[3, 0]], dtype=torch.int8), + ], + ] + assert len(ts) == len(expected) + + for actual, exp in zip(ts, expected): + assert len(actual) == len(exp) + for a, e in zip(actual, exp): + assert_tensor_equal(e, a) + + +def test_misc_errors(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="does not support u16, u32, or u64 dtypes", + ): + _res1 = df.to_torch(dtype=pl.UInt16) + + with pytest.raises( + IndexError, + match="tensors used as indices must be long, int", + ): + _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", + ): + _res3 = df.to_torch(label="stroopwafel") + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res4 = df.to_torch("dict", features=cs.float()) diff --git a/py-polars/tests/unit/operations/__init__.py b/py-polars/tests/unit/operations/__init__.py new file mode 100644 index 000000000000..2560f5ab9e7a --- /dev/null +++ b/py-polars/tests/unit/operations/__init__.py @@ -0,0 +1 @@ +"""Test module for extensive testing of specific operations like join or explode.""" diff --git a/py-polars/tests/unit/operations/aggregation/__init__.py b/py-polars/tests/unit/operations/aggregation/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/aggregation/test_aggregations.py b/py-polars/tests/unit/operations/aggregation/test_aggregations.py new file mode 100644 index 000000000000..d56ffce6d5a7 --- /dev/null +++ b/py-polars/tests/unit/operations/aggregation/test_aggregations.py @@ -0,0 +1,776 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any, cast + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + import numpy.typing as npt + + from polars._typing import PolarsDataType + + +def test_quantile_expr_input() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0.0, 0.0, 0.3, 0.2, 0.0]}) + + assert_frame_equal( + df.select([pl.col("a").quantile(pl.col("b").sum() + 0.1)]), + df.select(pl.col("a").quantile(0.6)), + ) + + +def test_boolean_aggs() -> None: + df = pl.DataFrame({"bool": [True, False, None, True]}) + + aggs = [ + pl.mean("bool").alias("mean"), + pl.std("bool").alias("std"), + pl.var("bool").alias("var"), + ] + assert df.select(aggs).to_dict(as_series=False) == { + "mean": [0.6666666666666666], + "std": [0.5773502691896258], + "var": [0.33333333333333337], + } + + assert df.group_by(pl.lit(1)).agg(aggs).to_dict(as_series=False) == { + "literal": [1], + "mean": [0.6666666666666666], + "std": [0.5773502691896258], + "var": [0.33333333333333337], + } + + +def test_duration_aggs() -> None: + df = pl.DataFrame( + { + "time1": pl.datetime_range( + start=datetime(2022, 12, 12), + end=datetime(2022, 12, 18), + interval="1d", + eager=True, + ), + "time2": pl.datetime_range( + start=datetime(2023, 1, 12), + end=datetime(2023, 1, 18), + interval="1d", + eager=True, + ), + } + ) + + df = df.with_columns((pl.col("time2") - pl.col("time1")).alias("time_difference")) + + assert df.select("time_difference").mean().to_dict(as_series=False) == { + "time_difference": [timedelta(days=31)] + } + assert df.group_by(pl.lit(1)).agg(pl.mean("time_difference")).to_dict( + as_series=False + ) == { + "literal": [1], + "time_difference": [timedelta(days=31)], + } + + +def test_list_aggregation_that_filters_all_data_6017() -> None: + out = ( + pl.DataFrame({"col_to_group_by": [2], "flt": [1672740910.967138], "col3": [1]}) + .group_by("col_to_group_by") + .agg((pl.col("flt").filter(col3=0).diff() * 1000).diff().alias("calc")) + ) + + assert out.schema == {"col_to_group_by": pl.Int64, "calc": pl.List(pl.Float64)} + assert out.to_dict(as_series=False) == {"col_to_group_by": [2], "calc": [[]]} + + +def test_median() -> None: + s = pl.Series([1, 2, 3]) + assert s.median() == 2 + + +def test_single_element_std() -> None: + s = pl.Series([1]) + assert s.std(ddof=1) is None + assert s.std(ddof=0) == 0.0 + + +def test_quantile() -> None: + s = pl.Series([1, 2, 3]) + assert s.quantile(0.5, "nearest") == 2 + assert s.quantile(0.5, "lower") == 2 + assert s.quantile(0.5, "higher") == 2 + + +@pytest.mark.slow +@pytest.mark.parametrize("tp", [int, float]) +@pytest.mark.parametrize("n", [1, 2, 10, 100]) +def test_quantile_vs_numpy(tp: type, n: int) -> None: + a: np.ndarray[Any, Any] = np.random.randint(0, 50, n).astype(tp) + np_result: npt.ArrayLike | None = np.median(a) + # nan check + if np_result != np_result: + np_result = None + median = pl.Series(a).median() + if median is not None: + assert np.isclose(median, np_result) # type: ignore[arg-type] + else: + assert np_result is None + + q = np.random.sample() + try: + np_result = np.quantile(a, q) + except IndexError: + np_result = None + if np_result: + # nan check + if np_result != np_result: + np_result = None + assert np.isclose( + pl.Series(a).quantile(q, interpolation="linear"), # type: ignore[arg-type] + np_result, # type: ignore[arg-type] + ) + + +def test_mean_overflow() -> None: + assert np.isclose( + pl.Series([9_223_372_036_854_775_800, 100]).mean(), # type: ignore[arg-type] + 4.611686018427388e18, + ) + + +def test_mean_null_simd() -> None: + for dtype in [int, float]: + df = ( + pl.Series(np.random.randint(0, 100, 1000)) + .cast(dtype) + .to_frame("a") + .select(pl.when(pl.col("a") > 40).then(pl.col("a"))) + ) + + s = df["a"] + assert s.mean() == s.to_pandas().mean() + + +def test_literal_group_agg_chunked_7968() -> None: + df = pl.DataFrame({"A": [1, 1], "B": [1, 3]}) + ser = pl.concat([pl.Series([3]), pl.Series([4, 5])], rechunk=False) + + assert_frame_equal( + df.group_by("A").agg(pl.col("B").search_sorted(ser)), + pl.DataFrame( + [ + pl.Series("A", [1], dtype=pl.Int64), + pl.Series("B", [[1, 2, 2]], dtype=pl.List(pl.UInt32)), + ] + ), + ) + + +def test_duration_function_literal() -> None: + df = pl.DataFrame( + { + "A": ["x", "x", "y", "y", "y"], + "T": pl.datetime_range( + date(2022, 1, 1), date(2022, 5, 1), interval="1mo", eager=True + ), + "S": [1, 2, 4, 8, 16], + } + ) + + result = df.group_by("A", maintain_order=True).agg( + (pl.col("T").max() + pl.duration(seconds=1)) - pl.col("T") + ) + + # this checks if the `pl.duration` is flagged as AggState::Literal + expected = pl.DataFrame( + { + "A": ["x", "y"], + "T": [ + [timedelta(days=31, seconds=1), timedelta(seconds=1)], + [ + timedelta(days=61, seconds=1), + timedelta(days=30, seconds=1), + timedelta(seconds=1), + ], + ], + } + ) + assert_frame_equal(result, expected) + + +def test_string_par_materialize_8207() -> None: + df = pl.LazyFrame( + { + "a": ["a", "b", "d", "c", "e"], + "b": ["P", "L", "R", "T", "a long string"], + } + ) + + assert df.group_by(["a"]).agg(pl.min("b")).sort("a").collect().to_dict( + as_series=False + ) == { + "a": ["a", "b", "c", "d", "e"], + "b": ["P", "L", "T", "R", "a long string"], + } + + +def test_online_variance() -> None: + df = pl.DataFrame( + { + "id": [1] * 5, + "no_nulls": [1, 2, 3, 4, 5], + "nulls": [1, None, 3, None, 5], + } + ) + + assert_frame_equal( + df.group_by("id") + .agg(pl.all().exclude("id").std()) + .select(["no_nulls", "nulls"]), + df.select(pl.all().exclude("id").std()), + ) + + +def test_implode_and_agg() -> None: + df = pl.DataFrame({"type": ["water", "fire", "water", "earth"]}) + + # this would OOB + with pytest.raises( + InvalidOperationError, + match=r"'implode' followed by an aggregation is not allowed", + ): + df.group_by("type").agg(pl.col("type").implode().first().alias("foo")) + + # implode + function should be allowed in group_by + assert df.group_by("type", maintain_order=True).agg( + pl.col("type").implode().list.head().alias("foo") + ).to_dict(as_series=False) == { + "type": ["water", "fire", "earth"], + "foo": [["water", "water"], ["fire"], ["earth"]], + } + assert df.select(pl.col("type").implode().list.head(1).over("type")).to_dict( + as_series=False + ) == {"type": [["water"], ["fire"], ["water"], ["earth"]]} + + +def test_mapped_literal_to_literal_9217() -> None: + df = pl.DataFrame({"unique_id": ["a", "b"]}) + assert df.group_by(True).agg( + pl.struct(pl.lit("unique_id").alias("unique_id")) + ).to_dict(as_series=False) == { + "literal": [True], + "unique_id": [{"unique_id": "unique_id"}], + } + + +def test_sum_empty_and_null_set() -> None: + series = pl.Series("a", [], dtype=pl.Float32) + assert series.sum() == 0 + + series = pl.Series("a", [None], dtype=pl.Float32) + assert series.sum() == 0 + + df = pl.DataFrame( + {"a": [None, None, None], "b": [1, 1, 1]}, + schema={"a": pl.Float32, "b": pl.Int64}, + ) + assert df.select(pl.sum("a")).item() == 0.0 + assert df.group_by("b").agg(pl.sum("a"))["a"].item() == 0.0 + + +def test_horizontal_sum_null_to_identity() -> None: + assert pl.DataFrame({"a": [1, 5], "b": [10, None]}).select( + pl.sum_horizontal(["a", "b"]) + ).to_series().to_list() == [11, 5] + + +def test_horizontal_sum_bool_dtype() -> None: + out = pl.DataFrame({"a": [True, False]}).select(pl.sum_horizontal("a")) + assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1, 0], dtype=pl.UInt32)})) + + +def test_horizontal_sum_in_group_by_15102() -> None: + nbr_records = 1000 + out = ( + pl.LazyFrame( + { + "x": [None, "two", None] * nbr_records, + "y": ["one", "two", None] * nbr_records, + "z": [None, "two", None] * nbr_records, + } + ) + .select(pl.sum_horizontal(pl.all().is_null()).alias("num_null")) + .group_by("num_null") + .len() + .sort(by="num_null") + .collect() + ) + assert_frame_equal( + out, + pl.DataFrame( + { + "num_null": pl.Series([0, 2, 3], dtype=pl.UInt32), + "len": pl.Series([nbr_records] * 3, dtype=pl.UInt32), + } + ), + ) + + +def test_first_last_unit_length_12363() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [None, None], + } + ) + + assert df.select( + pl.all().drop_nulls().first().name.suffix("_first"), + pl.all().drop_nulls().last().name.suffix("_last"), + ).to_dict(as_series=False) == { + "a_first": [1], + "b_first": [None], + "a_last": [2], + "b_last": [None], + } + + +def test_binary_op_agg_context_no_simplify_expr_12423() -> None: + expect = pl.DataFrame({"x": [1], "y": [1]}, schema={"x": pl.Int64, "y": pl.Int32}) + + for simplify_expression in (True, False): + assert_frame_equal( + expect, + pl.LazyFrame({"x": [1]}) + .group_by("x") + .agg(y=pl.lit(1) * pl.lit(1)) + .collect(simplify_expression=simplify_expression), + ) + + +def test_nan_inf_aggregation() -> None: + df = pl.DataFrame( + [ + ("both nan", np.nan), + ("both nan", np.nan), + ("nan and 5", np.nan), + ("nan and 5", 5), + ("nan and null", np.nan), + ("nan and null", None), + ("both none", None), + ("both none", None), + ("both inf", np.inf), + ("both inf", np.inf), + ("inf and null", np.inf), + ("inf and null", None), + ], + schema=["group", "value"], + orient="row", + ) + + assert_frame_equal( + df.group_by("group", maintain_order=True).agg( + min=pl.col("value").min(), + max=pl.col("value").max(), + mean=pl.col("value").mean(), + ), + pl.DataFrame( + [ + ("both nan", np.nan, np.nan, np.nan), + ("nan and 5", 5, 5, np.nan), + ("nan and null", np.nan, np.nan, np.nan), + ("both none", None, None, None), + ("both inf", np.inf, np.inf, np.inf), + ("inf and null", np.inf, np.inf, np.inf), + ], + schema=["group", "min", "max", "mean"], + orient="row", + ), + ) + + +@pytest.mark.parametrize("dtype", [pl.Int16, pl.UInt16]) +def test_int16_max_12904(dtype: PolarsDataType) -> None: + s = pl.Series([None, 1], dtype=dtype) + + assert s.min() == 1 + assert s.max() == 1 + + +def test_agg_filter_over_empty_df_13610() -> None: + ldf = pl.LazyFrame( + { + "a": [1, 1, 1, 2, 3], + "b": [True, True, True, True, True], + "c": [None, None, None, None, None], + } + ) + + out = ( + ldf.drop_nulls() + .group_by(["a"], maintain_order=True) + .agg(pl.col("b").filter(pl.col("b").shift(1))) + .collect() + ) + expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)}) + assert_frame_equal(out, expected) + + df = pl.DataFrame(schema={"a": pl.Int64, "b": pl.Boolean}) + out = df.group_by("a").agg(pl.col("b").filter(pl.col("b").shift())) + expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)}) + assert_frame_equal(out, expected) + + +@pytest.mark.slow +def test_agg_empty_sum_after_filter_14734() -> None: + f = ( + pl.DataFrame({"a": [1, 2], "b": [1, 2]}) + .lazy() + .group_by("a") + .agg(pl.col("b").filter(pl.lit(False)).sum()) + .collect + ) + + last = f() + + # We need both possible output orders, which should happen within + # 1000 iterations (during testing it usually happens within 10). + limit = 1000 + i = 0 + while (curr := f()).equals(last): + i += 1 + assert i != limit + + expect = pl.Series("b", [0, 0]).to_frame() + assert_frame_equal(expect, last.select("b")) + assert_frame_equal(expect, curr.select("b")) + + +@pytest.mark.slow +def test_grouping_hash_14749() -> None: + n_groups = 251 + rows_per_group = 4 + assert ( + pl.DataFrame( + { + "grp": np.repeat(np.arange(n_groups), rows_per_group), + "x": np.tile(np.arange(rows_per_group), n_groups), + } + ) + .select(pl.col("x").max().over("grp"))["x"] + .value_counts() + ).to_dict(as_series=False) == {"x": [3], "count": [1004]} + + +@pytest.mark.parametrize( + ("in_dtype", "out_dtype"), + [ + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_horizontal_mean_single_column( + in_dtype: PolarsDataType, + out_dtype: PolarsDataType, +) -> None: + out = ( + pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)}) + .select(pl.mean_horizontal(pl.all())) + .collect() + ) + + assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0]).cast(out_dtype)})) + + +def test_horizontal_mean_in_group_by_15115() -> None: + nbr_records = 1000 + out = ( + pl.LazyFrame( + { + "w": [None, "one", "two", "three"] * nbr_records, + "x": [None, None, "two", "three"] * nbr_records, + "y": [None, None, None, "three"] * nbr_records, + "z": [None, None, None, None] * nbr_records, + } + ) + .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null")) + .group_by("mean_null") + .len() + .sort(by="mean_null") + .collect() + ) + assert_frame_equal( + out, + pl.DataFrame( + { + "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64), + "len": pl.Series([nbr_records] * 4, dtype=pl.UInt32), + } + ), + ) + + +def test_group_count_over_null_column_15705() -> None: + df = pl.DataFrame( + {"a": [1, 1, 2, 2, 3, 3], "c": [None, None, None, None, None, None]} + ) + out = df.group_by("a", maintain_order=True).agg(pl.col("c").count()) + assert out["c"].to_list() == [0, 0, 0] + + +@pytest.mark.release +def test_min_max_2850() -> None: + # https://github.com/pola-rs/polars/issues/2850 + df = pl.DataFrame( + { + "id": [ + 130352432, + 130352277, + 130352611, + 130352833, + 130352305, + 130352258, + 130352764, + 130352475, + 130352368, + 130352346, + ] + } + ) + + minimum = 130352258 + maximum = 130352833.0 + + for _ in range(10): + permuted = df.sample(fraction=1.0, seed=0) + computed = permuted.select( + pl.col("id").min().alias("min"), pl.col("id").max().alias("max") + ) + assert cast(int, computed[0, "min"]) == minimum + assert cast(float, computed[0, "max"]) == maximum + + +def test_multi_arg_structify_15834() -> None: + df = pl.DataFrame( + { + "group": [1, 2, 1, 2], + "value": [ + 0.1973209146402105, + 0.13380719982405365, + 0.6152394463707009, + 0.4558767896005155, + ], + } + ) + + assert df.lazy().group_by("group").agg( + pl.struct(a=1, value=pl.col("value").sum()) + ).collect().sort("group").to_dict(as_series=False) == { + "group": [1, 2], + "a": [ + {"a": 1, "value": 0.8125603610109114}, + {"a": 1, "value": 0.5896839894245691}, + ], + } + + +def test_filter_aggregation_16642() -> None: + df = pl.DataFrame( + { + "datetime": [ + datetime(2022, 1, 1, 11, 0), + datetime(2022, 1, 1, 11, 1), + datetime(2022, 1, 1, 11, 2), + datetime(2022, 1, 1, 11, 3), + datetime(2022, 1, 1, 11, 4), + datetime(2022, 1, 1, 11, 5), + datetime(2022, 1, 1, 11, 6), + datetime(2022, 1, 1, 11, 7), + datetime(2022, 1, 1, 11, 8), + datetime(2022, 1, 1, 11, 9, 1), + datetime(2022, 1, 2, 11, 0), + datetime(2022, 1, 2, 11, 1), + datetime(2022, 1, 2, 11, 2), + datetime(2022, 1, 2, 11, 3), + datetime(2022, 1, 2, 11, 4), + datetime(2022, 1, 2, 11, 5), + datetime(2022, 1, 2, 11, 6), + datetime(2022, 1, 2, 11, 7), + datetime(2022, 1, 2, 11, 8), + datetime(2022, 1, 2, 11, 9, 1), + ], + "alpha": [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + ], + "num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + ) + grouped = df.group_by(pl.col("datetime").dt.date()) + + ts_filter = pl.col("datetime").dt.time() <= pl.time(11, 3) + + report = grouped.agg(pl.col("num").filter(ts_filter).max()).sort("datetime") + assert report.to_dict(as_series=False) == { + "datetime": [date(2022, 1, 1), date(2022, 1, 2)], + "num": [3, 3], + } + + +def test_sort_by_over_single_nulls_first() -> None: + key = [0, 0, 0, 0, 1, 1, 1, 1] + df = pl.DataFrame( + { + "key": key, + "value": [2, None, 1, 0, 2, None, 1, 0], + } + ) + out = df.select( + pl.all().sort_by("value", nulls_last=False, maintain_order=True).over("key") + ) + expected = pl.DataFrame( + { + "key": key, + "value": [None, 0, 1, 2, None, 0, 1, 2], + } + ) + assert_frame_equal(out, expected) + + +def test_sort_by_over_single_nulls_last() -> None: + key = [0, 0, 0, 0, 1, 1, 1, 1] + df = pl.DataFrame( + { + "key": key, + "value": [2, None, 1, 0, 2, None, 1, 0], + } + ) + out = df.select( + pl.all().sort_by("value", nulls_last=True, maintain_order=True).over("key") + ) + expected = pl.DataFrame( + { + "key": key, + "value": [0, 1, 2, None, 0, 1, 2, None], + } + ) + assert_frame_equal(out, expected) + + +def test_sort_by_over_multiple_nulls_first() -> None: + key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1] + key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1] + df = pl.DataFrame( + { + "key1": key1, + "key2": key2, + "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0], + } + ) + out = df.select( + pl.all() + .sort_by("value", nulls_last=False, maintain_order=True) + .over("key1", "key2") + ) + expected = pl.DataFrame( + { + "key1": key1, + "key2": key2, + "value": [None, 0, 1, None, 0, 1, None, 0, 1, None, 0, 1], + } + ) + assert_frame_equal(out, expected) + + +def test_sort_by_over_multiple_nulls_last() -> None: + key1 = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1] + key2 = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1] + df = pl.DataFrame( + { + "key1": key1, + "key2": key2, + "value": [1, None, 0, 1, None, 0, 1, None, 0, None, 1, 0], + } + ) + out = df.select( + pl.all() + .sort_by("value", nulls_last=True, maintain_order=True) + .over("key1", "key2") + ) + expected = pl.DataFrame( + { + "key1": key1, + "key2": key2, + "value": [0, 1, None, 0, 1, None, 0, 1, None, 0, 1, None], + } + ) + assert_frame_equal(out, expected) + + +def test_slice_after_agg_raises() -> None: + with pytest.raises( + InvalidOperationError, match=r"cannot slice\(\) an aggregated scalar value" + ): + pl.select(a=1, b=1).group_by("a").agg(pl.col("b").first().slice(99, 0)) + + +def test_agg_scalar_empty_groups_20115() -> None: + assert_frame_equal( + ( + pl.DataFrame({"key": [123], "value": [456]}) + .group_by("key") + .agg(pl.col("value").slice(1, 1).first()) + ), + pl.select(key=pl.lit(123, pl.Int64), value=pl.lit(None, pl.Int64)), + ) + + +def test_agg_expr_returns_list_type_15574() -> None: + assert ( + pl.LazyFrame({"a": [1, None], "b": [1, 2]}) + .group_by("b") + .agg(pl.col("a").drop_nulls()) + .collect_schema() + ) == {"b": pl.Int64, "a": pl.List(pl.Int64)} + + +def test_empty_agg_22005() -> None: + out = ( + pl.concat([pl.LazyFrame({"a": [1, 2]}), pl.LazyFrame({"a": [1, 2]})]) + .limit(0) + .select(pl.col("a").sum()) + ) + assert_frame_equal(out.collect(), pl.DataFrame({"a": 0})) diff --git a/py-polars/tests/unit/operations/aggregation/test_folds.py b/py-polars/tests/unit/operations/aggregation/test_folds.py new file mode 100644 index 000000000000..f14bd7852347 --- /dev/null +++ b/py-polars/tests/unit/operations/aggregation/test_folds.py @@ -0,0 +1,80 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_fold_reduce() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + + out = df.select( + pl.fold(acc=pl.lit(0), function=lambda acc, x: acc + x, exprs=pl.all()).alias( + "foo" + ) + ) + assert out["foo"].to_list() == [2, 4, 6] + out = df.select( + pl.reduce(function=lambda acc, x: acc + x, exprs=pl.all()).alias("foo") + ) + assert out["foo"].to_list() == [2, 4, 6] + + +def test_cum_fold() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [5, 6, 7, 8], + "c": [10, 20, 30, 40], + } + ) + result = df.select(pl.cum_fold(pl.lit(0), lambda a, b: a + b, pl.all())) + expected = pl.DataFrame( + { + "cum_fold": [ + {"a": 1, "b": 6, "c": 16}, + {"a": 2, "b": 8, "c": 28}, + {"a": 3, "b": 10, "c": 40}, + {"a": 4, "b": 12, "c": 52}, + ] + } + ) + assert_frame_equal(result, expected) + + +def test_cum_reduce() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [5, 6, 7, 8], + "c": [10, 20, 30, 40], + } + ) + result = df.select(pl.cum_reduce(lambda a, b: a + b, pl.all())) + expected = pl.DataFrame( + { + "cum_reduce": [ + {"a": 1, "b": 6, "c": 16}, + {"a": 2, "b": 8, "c": 28}, + {"a": 3, "b": 10, "c": 40}, + {"a": 4, "b": 12, "c": 52}, + ] + } + ) + assert_frame_equal(result, expected) + + +def test_alias_prune_in_fold_15438() -> None: + df = pl.DataFrame({"x": [1, 2], "expected_result": ["first", "second"]}).select( + actual_result=pl.fold( + acc=pl.lit("other", dtype=pl.Utf8), + function=lambda acc, x: pl.when(x).then(pl.lit(x.name)).otherwise(acc), # type: ignore[arg-type, return-value] + exprs=[ + (pl.col("x") == 1).alias("first"), + (pl.col("x") == 2).alias("second"), + ], + ) + ) + expected = pl.DataFrame( + { + "actual_result": ["first", "second"], + } + ) + assert_frame_equal(df, expected) diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py new file mode 100644 index 000000000000..9c170da3aca7 --- /dev/null +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -0,0 +1,667 @@ +from __future__ import annotations + +import datetime +from collections import OrderedDict +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ComputeError, PolarsError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_any_expr(fruits_cars: pl.DataFrame) -> None: + assert fruits_cars.select(pl.any_horizontal("A", "B")).to_series()[0] is True + + +def test_all_any_horizontally() -> None: + df = pl.DataFrame( + [ + [False, False, True], + [False, False, True], + [True, False, False], + [False, None, True], + [None, None, False], + ], + schema=["var1", "var2", "var3"], + orient="row", + ) + result = df.select( + any=pl.any_horizontal(pl.col("var2"), pl.col("var3")), + all=pl.all_horizontal(pl.col("var2"), pl.col("var3")), + ) + expected = pl.DataFrame( + { + "any": [True, True, False, True, None], + "all": [False, False, False, None, False], + } + ) + assert_frame_equal(result, expected) + + # note: a kwargs filter will use an internal call to all_horizontal + dfltr = df.lazy().filter(var1=True, var3=False) + assert dfltr.collect().rows() == [(True, False, False)] + + # confirm that we reduced the horizontal filter components + # (eg: explain does not contain an "all_horizontal" node) + assert "horizontal" not in dfltr.explain().lower() + + +def test_empty_all_any_horizontally() -> None: + # any/all_horizontal don't allow empty input, but we can still trigger this + # by selecting an empty set of columns with pl.selectors. + df = pl.DataFrame({"x": [1, 2, 3]}) + assert_frame_equal( + df.select(pl.any_horizontal(cs.string().is_null())), + pl.DataFrame({"literal": False}), + ) + assert_frame_equal( + df.select(pl.all_horizontal(cs.string().is_null())), + pl.DataFrame({"literal": True}), + ) + + +def test_all_any_single_input() -> None: + df = pl.DataFrame({"a": [0, 1, None]}) + out = df.select( + all=pl.all_horizontal(pl.col("a")), any=pl.any_horizontal(pl.col("a")) + ) + + expected = pl.DataFrame( + { + "all": [False, True, None], + "any": [False, True, None], + } + ) + assert_frame_equal(out, expected) + + +def test_all_any_accept_expr() -> None: + lf = pl.LazyFrame( + { + "a": [1, None, 2, None], + "b": [1, 2, None, None], + } + ) + + result = lf.select( + pl.any_horizontal(pl.all().is_null()).alias("null_in_row"), + pl.all_horizontal(pl.all().is_null()).alias("all_null_in_row"), + ) + + expected = pl.LazyFrame( + { + "null_in_row": [False, True, True, True], + "all_null_in_row": [False, False, False, True], + } + ) + assert_frame_equal(result, expected) + + +def test_max_min_multiple_columns(fruits_cars: pl.DataFrame) -> None: + result = fruits_cars.select(max=pl.max_horizontal("A", "B")) + expected = pl.Series("max", [5, 4, 3, 4, 5]) + assert_series_equal(result.to_series(), expected) + + result = fruits_cars.select(min=pl.min_horizontal("A", "B")) + expected = pl.Series("min", [1, 2, 3, 2, 1]) + assert_series_equal(result.to_series(), expected) + + +def test_max_min_nulls_consistency() -> None: + df = pl.DataFrame({"a": [None, 2, 3], "b": [4, None, 6], "c": [7, 5, 0]}) + + result = df.select(max=pl.max_horizontal("a", "b", "c")).to_series() + expected = pl.Series("max", [7, 5, 6]) + assert_series_equal(result, expected) + + result = df.select(min=pl.min_horizontal("a", "b", "c")).to_series() + expected = pl.Series("min", [4, 2, 0]) + assert_series_equal(result, expected) + + +def test_nested_min_max() -> None: + df = pl.DataFrame({"a": [1], "b": [2], "c": [3], "d": [4]}) + + result = df.with_columns( + pl.max_horizontal( + pl.min_horizontal("a", "b"), pl.min_horizontal("c", "d") + ).alias("t") + ) + + expected = pl.DataFrame({"a": [1], "b": [2], "c": [3], "d": [4], "t": [3]}) + assert_frame_equal(result, expected) + + +def test_empty_inputs_raise() -> None: + with pytest.raises( + ComputeError, + match="cannot return empty fold because the number of output rows is unknown", + ): + pl.select(pl.any_horizontal()) + + with pytest.raises( + ComputeError, + match="cannot return empty fold because the number of output rows is unknown", + ): + pl.select(pl.all_horizontal()) + + +def test_max_min_wildcard_columns(fruits_cars: pl.DataFrame) -> None: + result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select( + min=pl.min_horizontal("*") + ) + expected = pl.Series("min", [1, 2, 3, 2, 1]) + assert_series_equal(result.to_series(), expected) + + result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select( + min=pl.min_horizontal(pl.all()) + ) + assert_series_equal(result.to_series(), expected) + + result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select( + max=pl.max_horizontal("*") + ) + expected = pl.Series("max", [5, 4, 3, 4, 5]) + assert_series_equal(result.to_series(), expected) + + result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select( + max=pl.max_horizontal(pl.all()) + ) + assert_series_equal(result.to_series(), expected) + + result = fruits_cars.select(pl.col(pl.datatypes.Int64)).select( + max=pl.max_horizontal(pl.all(), "A", "*") + ) + assert_series_equal(result.to_series(), expected) + + +@pytest.mark.parametrize( + ("input", "expected_data"), + [ + (pl.col("^a|b$"), [1, 2]), + (pl.col("a", "b"), [1, 2]), + (pl.col("a"), [1, 4]), + (pl.lit(5, dtype=pl.Int64), [5]), + (5.0, [5.0]), + ], +) +def test_min_horizontal_single_input(input: Any, expected_data: list[Any]) -> None: + df = pl.DataFrame({"a": [1, 4], "b": [3, 2]}) + result = df.select(min=pl.min_horizontal(input)).to_series() + expected = pl.Series("min", expected_data) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("inputs", "expected_data"), + [ + ((["a", "b"]), [1, 2]), + (("a", "b"), [1, 2]), + (("a", 3), [1, 3]), + ], +) +def test_min_horizontal_multi_input( + inputs: tuple[Any, ...], expected_data: list[Any] +) -> None: + df = pl.DataFrame({"a": [1, 4], "b": [3, 2]}) + result = df.select(min=pl.min_horizontal(*inputs)) + expected = pl.DataFrame({"min": expected_data}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected_data"), + [ + (pl.col("^a|b$"), [3, 4]), + (pl.col("a", "b"), [3, 4]), + (pl.col("a"), [1, 4]), + (pl.lit(5, dtype=pl.Int64), [5]), + (5.0, [5.0]), + ], +) +def test_max_horizontal_single_input(input: Any, expected_data: list[Any]) -> None: + df = pl.DataFrame({"a": [1, 4], "b": [3, 2]}) + result = df.select(max=pl.max_horizontal(input)) + expected = pl.DataFrame({"max": expected_data}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("inputs", "expected_data"), + [ + ((["a", "b"]), [3, 4]), + (("a", "b"), [3, 4]), + (("a", 3), [3, 4]), + ], +) +def test_max_horizontal_multi_input( + inputs: tuple[Any, ...], expected_data: list[Any] +) -> None: + df = pl.DataFrame({"a": [1, 4], "b": [3, 2]}) + result = df.select(max=pl.max_horizontal(*inputs)) + expected = pl.DataFrame({"max": expected_data}) + assert_frame_equal(result, expected) + + +def test_expanding_sum() -> None: + df = pl.DataFrame( + { + "x": [0, 1, 2], + "y_1": [1.1, 2.2, 3.3], + "y_2": [1.0, 2.5, 3.5], + } + ) + + result = df.with_columns(pl.sum_horizontal(pl.col(r"^y_.*$")).alias("y_sum"))[ + "y_sum" + ] + assert result.to_list() == [2.1, 4.7, 6.8] + + +def test_sum_max_min() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + out = df.select( + sum=pl.sum_horizontal("a", "b"), + max=pl.max_horizontal("a", pl.col("b") ** 2), + min=pl.min_horizontal("a", pl.col("b") ** 2), + ) + assert_series_equal(out["sum"], pl.Series("sum", [2.0, 4.0, 6.0])) + assert_series_equal(out["max"], pl.Series("max", [1.0, 4.0, 9.0])) + assert_series_equal(out["min"], pl.Series("min", [1.0, 2.0, 3.0])) + + +def test_str_sum_horizontal() -> None: + df = pl.DataFrame( + {"A": ["a", "b", None, "c", None], "B": ["f", "g", "h", None, None]} + ) + out = df.select(pl.sum_horizontal("A", "B")) + assert_series_equal(out["A"], pl.Series("A", ["af", "bg", "h", "c", ""])) + + +def test_sum_null_dtype() -> None: + df = pl.DataFrame( + { + "A": [5, None, 3, 2, 1], + "B": [5, 3, None, 2, 1], + "C": [None, None, None, None, None], + } + ) + + assert_series_equal( + df.select(pl.sum_horizontal("A", "B", "C")).to_series(), + pl.Series("A", [10, 3, 3, 4, 2]), + ) + assert_series_equal( + df.select(pl.sum_horizontal("C", "B")).to_series(), + pl.Series("C", [5, 3, 0, 2, 1]), + ) + assert_series_equal( + df.select(pl.sum_horizontal("C", "C")).to_series(), + pl.Series("C", [None, None, None, None, None]), + ) + + +def test_sum_single_col() -> None: + df = pl.DataFrame( + { + "A": [5, None, 3, None, 1], + } + ) + + assert_series_equal( + df.select(pl.sum_horizontal("A")).to_series(), pl.Series("A", [5, 0, 3, 0, 1]) + ) + + +@pytest.mark.parametrize("ignore_nulls", [False, True]) +def test_sum_correct_supertype(ignore_nulls: bool) -> None: + values = [1, 2] if ignore_nulls else [None, None] # type: ignore[list-item] + lf = pl.LazyFrame( + { + "null": [None, None], + "int": pl.Series(values, dtype=pl.Int32), + "float": pl.Series(values, dtype=pl.Float32), + } + ) + + # null + int32 should produce int32 + out = lf.select(pl.sum_horizontal("null", "int", ignore_nulls=ignore_nulls)) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Int32)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + # null + float32 should produce float32 + out = lf.select(pl.sum_horizontal("null", "float", ignore_nulls=ignore_nulls)) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float32)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + # null + int32 + float32 should produce float64 + values = [2, 4] if ignore_nulls else [None, None] # type: ignore[list-item] + out = lf.select( + pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls) + ) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float64)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + +def test_cum_sum_horizontal() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [3, 4], + "c": [5, 6], + } + ) + result = df.select(pl.cum_sum_horizontal("a", "c")) + expected = pl.DataFrame({"cum_sum": [{"a": 1, "c": 6}, {"a": 2, "c": 8}]}) + assert_frame_equal(result, expected) + + +def test_sum_dtype_12028() -> None: + result = pl.select( + pl.sum_horizontal([pl.duration(seconds=10)]).alias("sum_duration") + ) + expected = pl.DataFrame( + [ + pl.Series( + "sum_duration", + [datetime.timedelta(seconds=10)], + dtype=pl.Duration(time_unit="us"), + ), + ] + ) + assert_frame_equal(expected, result) + + +def test_horizontal_expr_use_left_name() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [3, 4], + } + ) + + assert df.select(pl.sum_horizontal("a", "b")).columns == ["a"] + assert df.select(pl.max_horizontal("*")).columns == ["a"] + assert df.select(pl.min_horizontal("b", "a")).columns == ["b"] + assert df.select(pl.any_horizontal("b", "a")).columns == ["b"] + assert df.select(pl.all_horizontal("a", "b")).columns == ["a"] + + +def test_horizontal_broadcasting() -> None: + df = pl.DataFrame( + { + "a": [1, 3], + "b": [3, 6], + } + ) + + assert_series_equal( + df.select(sum=pl.sum_horizontal(1, "a", "b")).to_series(), + pl.Series("sum", [5, 10]), + ) + assert_series_equal( + df.select(mean=pl.mean_horizontal(1, "a", "b")).to_series(), + pl.Series("mean", [1.66666, 3.33333]), + ) + assert_series_equal( + df.select(max=pl.max_horizontal(4, "*")).to_series(), pl.Series("max", [4, 6]) + ) + assert_series_equal( + df.select(min=pl.min_horizontal(2, "b", "a")).to_series(), + pl.Series("min", [1, 2]), + ) + assert_series_equal( + df.select(any=pl.any_horizontal(False, pl.Series([True, False]))).to_series(), + pl.Series("any", [True, False]), + ) + assert_series_equal( + df.select(all=pl.all_horizontal(True, pl.Series([True, False]))).to_series(), + pl.Series("all", [True, False]), + ) + + +def test_mean_horizontal() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]}) + result = lf.select(pl.mean_horizontal(pl.all()).alias("mean")) + + expected = pl.LazyFrame({"mean": [2.0, 3.0, 6.0]}, schema={"mean": pl.Float64}) + assert_frame_equal(result, expected) + + +def test_mean_horizontal_bool() -> None: + df = pl.DataFrame( + { + "a": [True, False, False], + "b": [None, True, False], + "c": [True, False, False], + } + ) + expected = pl.DataFrame({"mean": [1.0, 1 / 3, 0.0]}, schema={"mean": pl.Float64}) + result = df.select(mean=pl.mean_horizontal(pl.all())) + assert_frame_equal(result, expected) + + +def test_mean_horizontal_no_columns() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]}) + + with pytest.raises(ComputeError, match="number of output rows is unknown"): + lf.select(pl.mean_horizontal()) + + +def test_mean_horizontal_no_rows() -> None: + lf = pl.LazyFrame({"a": [], "b": [], "c": []}).with_columns(pl.all().cast(pl.Int64)) + + result = lf.select(pl.mean_horizontal(pl.all())) + + expected = pl.LazyFrame({"a": []}, schema={"a": pl.Float64}) + assert_frame_equal(result, expected) + + +def test_mean_horizontal_all_null() -> None: + lf = pl.LazyFrame({"a": [1, None], "b": [2, None], "c": [None, None]}) + + result = lf.select(pl.mean_horizontal(pl.all())) + + expected = pl.LazyFrame({"a": [1.5, None]}, schema={"a": pl.Float64}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("in_dtype", "out_dtype"), + [ + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_schema_mean_horizontal_single_column( + in_dtype: PolarsDataType, + out_dtype: PolarsDataType, +) -> None: + lf = pl.LazyFrame({"a": pl.Series([1, 0]).cast(in_dtype)}).select( + pl.mean_horizontal(pl.all()) + ) + + assert lf.collect_schema() == OrderedDict([("a", out_dtype)]) + + +def test_schema_boolean_sum_horizontal() -> None: + lf = pl.LazyFrame({"a": [True, False]}).select(pl.sum_horizontal("a")) + assert lf.collect_schema() == OrderedDict([("a", pl.UInt32)]) + + +def test_fold_all_schema() -> None: + df = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + "optional": [28, 300, None, 2, -30], + } + ) + # divide because of overflow + result = df.select(pl.sum_horizontal(pl.all().hash(seed=1) // int(1e8))) + assert result.dtypes == [pl.UInt64] + + +@pytest.mark.parametrize( + "horizontal_func", + [ + pl.all_horizontal, + pl.any_horizontal, + pl.max_horizontal, + pl.min_horizontal, + pl.mean_horizontal, + pl.sum_horizontal, + ], +) +def test_expected_horizontal_dtype_errors(horizontal_func: type[pl.Expr]) -> None: + from decimal import Decimal as D + + import polars as pl + + df = pl.DataFrame( + { + "cola": [D("1.5"), D("0.5"), D("5"), D("0"), D("-0.25")], + "colb": [[0, 1], [2], [3, 4], [5], [6]], + "colc": ["aa", "bb", "cc", "dd", "ee"], + "cold": ["bb", "cc", "dd", "ee", "ff"], + "cole": [1000, 2000, 3000, 4000, 5000], + } + ) + with pytest.raises(PolarsError): + df.select( + horizontal_func( # type: ignore[call-arg] + pl.col("cola"), + pl.col("colb"), + pl.col("colc"), + pl.col("cold"), + pl.col("cole"), + ) + ) + + +def test_horizontal_sum_boolean_with_null() -> None: + lf = pl.LazyFrame( + { + "null": [None, None], + "bool": [True, False], + } + ) + + out = lf.select( + pl.sum_horizontal("null", "bool").alias("null_first"), + pl.sum_horizontal("bool", "null").alias("bool_first"), + ) + + expected_schema = pl.Schema( + { + "null_first": pl.get_index_type(), + "bool_first": pl.get_index_type(), + } + ) + + assert out.collect_schema() == expected_schema + + expected_df = pl.DataFrame( + { + "null_first": pl.Series([1, 0], dtype=pl.get_index_type()), + "bool_first": pl.Series([1, 0], dtype=pl.get_index_type()), + } + ) + + assert_frame_equal(out.collect(), expected_df) + + +@pytest.mark.parametrize("ignore_nulls", [True, False]) +@pytest.mark.parametrize( + ("dtype_in", "dtype_out"), + [ + (pl.Null, pl.Null), + (pl.Boolean, pl.get_index_type()), + (pl.UInt8, pl.UInt8), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + (pl.Decimal(None, 5), pl.Decimal(None, 5)), + ], +) +def test_horizontal_sum_with_null_col_ignore_strategy( + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + ignore_nulls: bool, +) -> None: + lf = pl.LazyFrame( + { + "null": [None, None, None], + "s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False), + "s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False), + } + ) + result = lf.select(pl.sum_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls)) + if ignore_nulls and dtype_in != pl.Null: + values = [2, 0, 1] + else: + values = [None, None, None] # type: ignore[list-item] + expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out)) + assert_frame_equal(result, expected) + assert result.collect_schema() == expected.collect_schema() + + +@pytest.mark.parametrize("ignore_nulls", [True, False]) +@pytest.mark.parametrize( + ("dtype_in", "dtype_out"), + [ + (pl.Null, pl.Float64), + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_horizontal_mean_with_null_col_ignore_strategy( + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + ignore_nulls: bool, +) -> None: + lf = pl.LazyFrame( + { + "null": [None, None, None], + "s": pl.Series([1, 0, 1], dtype=dtype_in, strict=False), + "s2": pl.Series([1, 0, None], dtype=dtype_in, strict=False), + } + ) + result = lf.select(pl.mean_horizontal("null", "s", "s2", ignore_nulls=ignore_nulls)) + if ignore_nulls and dtype_in != pl.Null: + values = [1, 0, 1] + else: + values = [None, None, None] # type: ignore[list-item] + expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out)) + assert_frame_equal(result, expected) + + +def test_raise_invalid_types_21835() -> None: + df = pl.DataFrame({"x": [1, 2], "y": ["three", "four"]}) + + with pytest.raises( + ComputeError, + match=r"cannot compare string with numeric type \(i64\)", + ): + df.select(pl.min_horizontal("x", "y")) diff --git a/py-polars/tests/unit/operations/aggregation/test_implode.py b/py-polars/tests/unit/operations/aggregation/test_implode.py new file mode 100644 index 000000000000..843abf23b87f --- /dev/null +++ b/py-polars/tests/unit/operations/aggregation/test_implode.py @@ -0,0 +1,23 @@ +import polars as pl + + +def test_implode_22192_22191() -> None: + df = pl.DataFrame({"x": [5, 6, 7, 8, 9], "g": [1, 2, 3, 3, 3]}) + assert df.group_by("g").agg(pl.col.x.implode()).sort("x").to_dict( + as_series=False + ) == {"g": [1, 2, 3], "x": [[5], [6], [7, 8, 9]]} + assert df.select(pl.col.x.implode().over("g")).to_dict(as_series=False) == { + "x": [[5], [6], [7, 8, 9], [7, 8, 9], [7, 8, 9]] + } + + +def test_implode_agg_lit() -> None: + assert ( + pl.DataFrame() + .group_by(1) + .agg( + pl.lit(pl.Series("x", [[3]])).list.set_union( + pl.lit(pl.Series([1])).implode() + ) + ) + ).to_dict(as_series=False) == {"literal": [1], "x": [[[3, 1]]]} diff --git a/py-polars/tests/unit/operations/aggregation/test_vertical.py b/py-polars/tests/unit/operations/aggregation/test_vertical.py new file mode 100644 index 000000000000..fc74fdf59b65 --- /dev/null +++ b/py-polars/tests/unit/operations/aggregation/test_vertical.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def assert_expr_equal( + left: pl.Expr, + right: pl.Expr, + context: pl.DataFrame | pl.LazyFrame | None = None, +) -> None: + """ + Evaluate expressions in a context to determine equality. + + Parameters + ---------- + left + The expression to compare. + right + The other expression the compare. + context + The context in which the expressions will be evaluated. Defaults to an empty + context. + """ + if context is None: + context = pl.DataFrame() + assert_frame_equal(context.select(left), context.select(right)) + + +def test_all_expr() -> None: + df = pl.DataFrame({"nrs": [1, 2, 3, 4, 5, None]}) + assert_frame_equal(df.select(pl.all()), df) + + +def test_any_expr(fruits_cars: pl.DataFrame) -> None: + assert fruits_cars.with_columns(pl.col("A").cast(bool)).select(pl.any("A")).item() + + +@pytest.mark.parametrize("function", ["all", "any"]) +@pytest.mark.parametrize("input", ["a", "^a|b$"]) +def test_alias_for_col_agg_bool(function: str, input: str) -> None: + result = getattr(pl, function)(input) # e.g. pl.all(input) + expected = getattr(pl.col(input), function)() # e.g. pl.col(input).all() + context = pl.DataFrame({"a": [True, False], "b": [True, True]}) + assert_expr_equal(result, expected, context) + + +@pytest.mark.parametrize("function", ["min", "max", "sum", "cum_sum"]) +@pytest.mark.parametrize("input", ["a", "^a|b$"]) +def test_alias_for_col_agg(function: str, input: str) -> None: + result = getattr(pl, function)(input) # e.g. pl.min(input) + expected = getattr(pl.col(input), function)() # e.g. pl.col(input).min() + context = pl.DataFrame({"a": [1, 4], "b": [3, 2]}) + assert_expr_equal(result, expected, context) + + +@pytest.mark.release +def test_mean_overflow() -> None: + np.random.seed(1) + expected = 769.5607652 + + df = pl.DataFrame(np.random.randint(500, 1040, 5000000), schema=["value"]) + + result = df.with_columns(pl.mean("value"))[0, 0] + assert np.isclose(result, expected) + + result = df.with_columns(pl.col("value").cast(pl.Int32)).with_columns( + pl.mean("value") + )[0, 0] + assert np.isclose(result, expected) + + result = df.with_columns(pl.col("value").cast(pl.Int32)).get_column("value").mean() + assert np.isclose(result, expected) + + +def test_deep_subexpression_f32_schema_7129() -> None: + df = pl.DataFrame({"a": [1.1, 2.3, 3.4, 4.5]}, schema={"a": pl.Float32()}) + assert df.with_columns(pl.col("a") - pl.col("a").median()).dtypes == [pl.Float32] + assert df.with_columns( + (pl.col("a") - pl.col("a").mean()) / (pl.col("a").std() + 0.001) + ).dtypes == [pl.Float32] diff --git a/py-polars/tests/unit/operations/arithmetic/__init__.py b/py-polars/tests/unit/operations/arithmetic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py new file mode 100644 index 000000000000..0d6611b6a03f --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -0,0 +1,876 @@ +from __future__ import annotations + +import operator +from collections import OrderedDict +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any, Callable + +import numpy as np +import pytest + +import polars as pl +from polars import ( + Date, + Float64, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, +) +from polars.exceptions import ColumnNotFoundError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsIntegerType + + +def test_sqrt_neg_inf() -> None: + out = pl.DataFrame( + { + "val": [float("-Inf"), -9, 0, 9, float("Inf")], + } + ).with_columns(pl.col("val").sqrt().alias("sqrt")) + # comparing nans and infinities by string value as they are not cmp + assert str(out["sqrt"].to_list()) == str( + [float("nan"), float("nan"), 0.0, 3.0, float("Inf")] + ) + + +def test_arithmetic_with_logical_on_series_4920() -> None: + assert (pl.Series([date(2022, 6, 3)]) - date(2022, 1, 1)).dtype == pl.Duration("ms") + + +@pytest.mark.parametrize( + ("left", "right", "expected_value", "expected_dtype"), + [ + (date(2021, 1, 1), date(2020, 1, 1), timedelta(days=366), pl.Duration("ms")), + ( + datetime(2021, 1, 1), + datetime(2020, 1, 1), + timedelta(days=366), + pl.Duration("us"), + ), + (timedelta(days=1), timedelta(days=2), timedelta(days=-1), pl.Duration("us")), + (2.0, 3.0, -1.0, pl.Float64), + ], +) +def test_arithmetic_sub( + left: object, right: object, expected_value: object, expected_dtype: pl.DataType +) -> None: + result = left - pl.Series([right]) + expected = pl.Series("", [expected_value], dtype=expected_dtype) + assert_series_equal(result, expected) + result = pl.Series([left]) - right + assert_series_equal(result, expected) + + +def test_struct_arithmetic() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [3, 4], + "c": [5, 6], + } + ).select(pl.cum_sum_horizontal("a", "c")) + assert df.select(pl.col("cum_sum") * 2).to_dict(as_series=False) == { + "cum_sum": [{"a": 2, "c": 12}, {"a": 4, "c": 16}] + } + assert df.select(pl.col("cum_sum") - 2).to_dict(as_series=False) == { + "cum_sum": [{"a": -1, "c": 4}, {"a": 0, "c": 6}] + } + assert df.select(pl.col("cum_sum") + 2).to_dict(as_series=False) == { + "cum_sum": [{"a": 3, "c": 8}, {"a": 4, "c": 10}] + } + assert df.select(pl.col("cum_sum") / 2).to_dict(as_series=False) == { + "cum_sum": [{"a": 0.5, "c": 3.0}, {"a": 1.0, "c": 4.0}] + } + assert df.select(pl.col("cum_sum") // 2).to_dict(as_series=False) == { + "cum_sum": [{"a": 0, "c": 3}, {"a": 1, "c": 4}] + } + + # inline, this checks cum_sum reports the right output type + assert pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}).select( + pl.cum_sum_horizontal("a", "c") * 3 + ).to_dict(as_series=False) == {"cum_sum": [{"a": 3, "c": 18}, {"a": 6, "c": 24}]} + + +def test_simd_float_sum_determinism() -> None: + out = [] + for _ in range(10): + a = pl.Series( + [ + 0.021415853782953836, + 0.06234123511682772, + 0.016962384922753124, + 0.002595968402539279, + 0.007632765529696731, + 0.012105848332077212, + 0.021439787151032317, + 0.3223049133700719, + 0.10526670729539435, + 0.0859029285522487, + ] + ) + out.append(a.sum()) + + assert out == [ + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + 0.6579683924555951, + ] + + +def test_floor_division_float_int_consistency() -> None: + a = np.random.randn(10) * 10 + + assert (pl.Series(a) // 5).to_list() == list(a // 5) + assert (pl.Series(a, dtype=pl.Int32) // 5).to_list() == list( + (a.astype(int) // 5).astype(int) + ) + + +def test_series_expr_arithm() -> None: + s = pl.Series([1, 2, 3]) + assert (s + pl.col("a")).meta == pl.lit(s) + pl.col("a") + assert (s - pl.col("a")).meta == pl.lit(s) - pl.col("a") + assert (s / pl.col("a")).meta == pl.lit(s) / pl.col("a") + assert (s // pl.col("a")).meta == pl.lit(s) // pl.col("a") + assert (s * pl.col("a")).meta == pl.lit(s) * pl.col("a") + assert (s % pl.col("a")).meta == pl.lit(s) % pl.col("a") + + +def test_fused_arithm() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [10, 20, 30], + "c": [5, 5, 5], + } + ) + + q = df.lazy().select( + pl.col("a") * pl.col("b") + pl.col("c"), + (pl.col("a") + pl.col("b") * pl.col("c")).alias("2"), + ) + # the extra aliases are because the fma does operation reordering + assert ( + """col("c").fma([col("a"), col("b")]).alias("a"), col("a").fma([col("b"), col("c")]).alias("2")""" + in q.explain() + ) + assert q.collect().to_dict(as_series=False) == { + "a": [15, 45, 95], + "2": [51, 102, 153], + } + # fsm + q = df.lazy().select(pl.col("a") - pl.col("b") * pl.col("c")) + assert """col("a").fsm([col("b"), col("c")])""" in q.explain() + assert q.collect()["a"].to_list() == [-49, -98, -147] + # fms + q = df.lazy().select(pl.col("a") * pl.col("b") - pl.col("c")) + assert """col("a").fms([col("b"), col("c")])""" in q.explain() + assert q.collect()["a"].to_list() == [5, 35, 85] + + # check if we constant fold instead of fma + q = df.lazy().select(pl.lit(1) * pl.lit(2) - pl.col("c")) + assert """(2) - (col("c")""" in q.explain() + + # Check if fused is turned off for literals see: #9857 + for expr in [ + pl.col("c") * 2 + 5, + pl.col("c") * 2 + pl.col("c"), + pl.col("c") * 2 - 5, + pl.col("c") * 2 - pl.col("c"), + 5 - pl.col("c") * 2, + pl.col("c") - pl.col("c") * 2, + ]: + q = df.lazy().select(expr) + assert all(el not in q.explain() for el in ["fms", "fsm", "fma"]), ( + f"Fused Arithmetic applied on literal {expr}: {q.explain()}" + ) + + +def test_literal_no_upcast() -> None: + df = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Float32)}) + + q = ( + df.lazy() + .select( + (pl.col("a") * -5 + 2).alias("fma"), + (2 - pl.col("a") * 5).alias("fsm"), + (pl.col("a") * 5 - 2).alias("fms"), + ) + .collect() + ) + assert set(q.schema.values()) == {pl.Float32}, ( + "Literal * Column (Float32) should not lead upcast" + ) + + +def test_boolean_addition() -> None: + s = pl.DataFrame( + {"a": [True, False, False], "b": [True, False, True]} + ).sum_horizontal() + + assert s.dtype == pl.get_index_type() + assert s.to_list() == [2, 0, 1] + df = pl.DataFrame( + {"a": [True], "b": [False]}, + ).select(pl.sum_horizontal("a", "b")) + assert df.dtypes == [pl.get_index_type()] + + +def test_bitwise_6311() -> None: + df = pl.DataFrame({"col1": [0, 1, 2, 3], "flag": [0, 0, 0, 0]}) + + assert ( + df.with_columns( + pl.when((pl.col("col1") < 1) | (pl.col("col1") >= 3)) + .then(pl.col("flag") | 2) # set flag b0010 + .otherwise(pl.col("flag")) + ).with_columns( + pl.when(pl.col("col1") > -1) + .then(pl.col("flag") | 4) + .otherwise(pl.col("flag")) + ) + ).to_dict(as_series=False) == {"col1": [0, 1, 2, 3], "flag": [6, 4, 4, 6]} + + +def test_arithmetic_null_count() -> None: + df = pl.DataFrame({"a": [1, None, 2], "b": [None, 2, 1]}) + out = df.select( + no_broadcast=pl.col("a") + pl.col("b"), + broadcast_left=1 + pl.col("b"), + broadcast_right=pl.col("a") + 1, + ) + assert out.null_count().to_dict(as_series=False) == { + "no_broadcast": [2], + "broadcast_left": [1], + "broadcast_right": [1], + } + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.floordiv, + operator.mod, + operator.mul, + operator.sub, + ], +) +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_operator_arithmetic_with_nulls(op: Any, dtype: pl.DataType) -> None: + df = pl.DataFrame({"n": [2, 3]}, schema={"n": dtype}) + s = df.to_series() + + df_expected = pl.DataFrame({"n": [None, None]}, schema={"n": dtype}) + s_expected = df_expected.to_series() + + # validate expr, frame, and series behaviour with null value arithmetic + op_name = op.__name__ + for null_expr in (None, pl.lit(None)): + assert_frame_equal(df_expected, df.select(op(pl.col("n"), null_expr))) + assert_frame_equal( + df_expected, df.select(getattr(pl.col("n"), op_name)(null_expr)) + ) + + assert_frame_equal(op(df, None), df_expected) + assert_series_equal(op(s, None), s_expected) + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.mod, + operator.mul, + operator.sub, + ], +) +def test_null_column_arithmetic(op: Any) -> None: + df = pl.DataFrame({"a": [None, None], "b": [None, None]}) + expected_df = pl.DataFrame({"a": [None, None]}) + + output_df = df.select(op(pl.col("a"), pl.col("b"))) + assert_frame_equal(expected_df, output_df) + # test broadcast right + output_df = df.select(op(pl.col("a"), pl.Series([None]))) + assert_frame_equal(expected_df, output_df) + # test broadcast left + output_df = df.select(op(pl.Series("a", [None]), pl.col("a"))) + assert_frame_equal(expected_df, output_df) + + +def test_bool_floordiv() -> None: + df = pl.DataFrame({"x": [True]}) + + with pytest.raises( + InvalidOperationError, + match="floor_div operation not supported for dtype `bool`", + ): + df.with_columns(pl.col("x").floordiv(2)) + + +def test_arithmetic_in_aggregation_3739() -> None: + def demean_dot() -> pl.Expr: + x = pl.col("x") + y = pl.col("y") + x1 = x - x.mean() + y1 = y - y.mean() + return (x1 * y1).sum().alias("demean_dot") + + assert ( + pl.DataFrame( + { + "key": ["a", "a", "a", "a"], + "x": [4, 2, 2, 4], + "y": [2, 0, 2, 0], + } + ) + .group_by("key") + .agg( + [ + demean_dot(), + ] + ) + ).to_dict(as_series=False) == {"key": ["a"], "demean_dot": [0.0]} + + +def test_arithmetic_on_df() -> None: + df = pl.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}) + + for df_mul in (df * 2, 2 * df): + expected = pl.DataFrame({"a": [2.0, 4.0], "b": [6.0, 8.0]}) + assert_frame_equal(df_mul, expected) + + for df_plus in (df + 2, 2 + df): + expected = pl.DataFrame({"a": [3.0, 4.0], "b": [5.0, 6.0]}) + assert_frame_equal(df_plus, expected) + + df_div = df / 2 + expected = pl.DataFrame({"a": [0.5, 1.0], "b": [1.5, 2.0]}) + assert_frame_equal(df_div, expected) + + df_minus = df - 2 + expected = pl.DataFrame({"a": [-1.0, 0.0], "b": [1.0, 2.0]}) + assert_frame_equal(df_minus, expected) + + df_mod = df % 2 + expected = pl.DataFrame({"a": [1.0, 0.0], "b": [1.0, 0.0]}) + assert_frame_equal(df_mod, expected) + + df2 = pl.DataFrame({"c": [10]}) + + out = df + df2 + expected = pl.DataFrame({"a": [11.0, None], "b": [None, None]}).with_columns( + pl.col("b").cast(pl.Float64) + ) + assert_frame_equal(out, expected) + + out = df - df2 + expected = pl.DataFrame({"a": [-9.0, None], "b": [None, None]}).with_columns( + pl.col("b").cast(pl.Float64) + ) + assert_frame_equal(out, expected) + + out = df / df2 + expected = pl.DataFrame({"a": [0.1, None], "b": [None, None]}).with_columns( + pl.col("b").cast(pl.Float64) + ) + assert_frame_equal(out, expected) + + out = df * df2 + expected = pl.DataFrame({"a": [10.0, None], "b": [None, None]}).with_columns( + pl.col("b").cast(pl.Float64) + ) + assert_frame_equal(out, expected) + + out = df % df2 + expected = pl.DataFrame({"a": [1.0, None], "b": [None, None]}).with_columns( + pl.col("b").cast(pl.Float64) + ) + assert_frame_equal(out, expected) + + # cannot do arithmetic with a sequence + with pytest.raises(TypeError, match="operation not supported"): + _ = df + [1] # type: ignore[operator] + + +def test_df_series_division() -> None: + df = pl.DataFrame( + { + "a": [2, 2, 4, 4, 6, 6], + "b": [2, 2, 10, 5, 6, 6], + } + ) + s = pl.Series([2, 2, 2, 2, 2, 2]) + assert (df / s).to_dict(as_series=False) == { + "a": [1.0, 1.0, 2.0, 2.0, 3.0, 3.0], + "b": [1.0, 1.0, 5.0, 2.5, 3.0, 3.0], + } + assert (df // s).to_dict(as_series=False) == { + "a": [1, 1, 2, 2, 3, 3], + "b": [1, 1, 5, 2, 3, 3], + } + + +@pytest.mark.parametrize( + "s", [pl.Series([1, 2], dtype=Int64), pl.Series([1, 2], dtype=Float64)] +) +def test_arithmetic_series(s: pl.Series) -> None: + a = s + b = s + + assert ((a * b) == [1, 4]).sum() == 2 + assert ((a / b) == [1.0, 1.0]).sum() == 2 + assert ((a + b) == [2, 4]).sum() == 2 + assert ((a - b) == [0, 0]).sum() == 2 + assert ((a + 1) == [2, 3]).sum() == 2 + assert ((a - 1) == [0, 1]).sum() == 2 + assert ((a / 1) == [1.0, 2.0]).sum() == 2 + assert ((a // 2) == [0, 1]).sum() == 2 + assert ((a * 2) == [2, 4]).sum() == 2 + assert ((2 + a) == [3, 4]).sum() == 2 + assert ((1 - a) == [0, -1]).sum() == 2 + assert ((2 * a) == [2, 4]).sum() == 2 + + # integer division + assert_series_equal(1 / a, pl.Series([1.0, 0.5])) + expected = pl.Series([1, 0]) if s.dtype == Int64 else pl.Series([1.0, 0.5]) + assert_series_equal(1 // a, expected) + # modulo + assert ((1 % a) == [0, 1]).sum() == 2 + assert ((a % 1) == [0, 0]).sum() == 2 + # negate + assert (-a == [-1, -2]).sum() == 2 + # unary plus + assert (+a == a).all() + # wrong dtypes in rhs operands + assert ((1.0 - a) == [0.0, -1.0]).sum() == 2 + assert ((1.0 / a) == [1.0, 0.5]).sum() == 2 + assert ((1.0 * a) == [1, 2]).sum() == 2 + assert ((1.0 + a) == [2, 3]).sum() == 2 + assert ((1.0 % a) == [0, 1]).sum() == 2 + + +def test_arithmetic_datetime() -> None: + a = pl.Series("a", [datetime(2021, 1, 1)]) + with pytest.raises(TypeError): + a // 2 + with pytest.raises(TypeError): + a / 2 + with pytest.raises(TypeError): + a * 2 + with pytest.raises(TypeError): + a % 2 + with pytest.raises( + InvalidOperationError, + ): + a**2 + with pytest.raises(TypeError): + 2 / a + with pytest.raises(TypeError): + 2 // a + with pytest.raises(TypeError): + 2 * a + with pytest.raises(TypeError): + 2 % a + with pytest.raises( + InvalidOperationError, + ): + 2**a + + +def test_power_series() -> None: + a = pl.Series([1, 2], dtype=Int64) + b = pl.Series([None, 2.0], dtype=Float64) + c = pl.Series([date(2020, 2, 28), date(2020, 3, 1)], dtype=Date) + d = pl.Series([1, 2], dtype=UInt8) + e = pl.Series([1, 2], dtype=Int8) + f = pl.Series([1, 2], dtype=UInt16) + g = pl.Series([1, 2], dtype=Int16) + h = pl.Series([1, 2], dtype=UInt32) + i = pl.Series([1, 2], dtype=Int32) + j = pl.Series([1, 2], dtype=UInt64) + k = pl.Series([1, 2], dtype=Int64) + m = pl.Series([2**33, 2**33], dtype=UInt64) + + # pow + assert_series_equal(a**2, pl.Series([1, 4], dtype=Int64)) + assert_series_equal(b**3, pl.Series([None, 8.0], dtype=Float64)) + assert_series_equal(a**a, pl.Series([1, 4], dtype=Int64)) + assert_series_equal(b**b, pl.Series([None, 4.0], dtype=Float64)) + assert_series_equal(a**b, pl.Series([None, 4.0], dtype=Float64)) + assert_series_equal(d**d, pl.Series([1, 4], dtype=UInt8)) + assert_series_equal(e**d, pl.Series([1, 4], dtype=Int8)) + assert_series_equal(f**d, pl.Series([1, 4], dtype=UInt16)) + assert_series_equal(g**d, pl.Series([1, 4], dtype=Int16)) + assert_series_equal(h**d, pl.Series([1, 4], dtype=UInt32)) + assert_series_equal(i**d, pl.Series([1, 4], dtype=Int32)) + assert_series_equal(j**d, pl.Series([1, 4], dtype=UInt64)) + assert_series_equal(k**d, pl.Series([1, 4], dtype=Int64)) + + with pytest.raises( + InvalidOperationError, + match="`pow` operation not supported for dtype `null` as exponent", + ): + a ** pl.lit(None) + + with pytest.raises( + InvalidOperationError, + match="`pow` operation not supported for dtype `date` as base", + ): + c**2 + with pytest.raises( + InvalidOperationError, + match="`pow` operation not supported for dtype `date` as exponent", + ): + 2**c + + with pytest.raises(ColumnNotFoundError): + a ** "hi" # type: ignore[operator] + + # Raising to UInt64: raises if can't be downcast safely to UInt32... + with pytest.raises( + InvalidOperationError, match="conversion from `u64` to `u32` failed" + ): + a**m + # ... but succeeds otherwise. + assert_series_equal(a**j, pl.Series([1, 4], dtype=Int64)) + + # rpow + assert_series_equal(2.0**a, pl.Series(None, [2.0, 4.0], dtype=Float64)) + assert_series_equal(2**b, pl.Series(None, [None, 4.0], dtype=Float64)) + + with pytest.raises(ColumnNotFoundError): + "hi" ** a + + # Series.pow() method + assert_series_equal(a.pow(2), pl.Series([1, 4], dtype=Int64)) + + +def test_rpow_name_20071() -> None: + result = 1 ** pl.Series("a", [1, 2]) + expected = pl.Series("a", [1, 1], pl.Int32) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + (np.array([[2, 4], [6, 8]], dtype=np.int64), lambda a, b: a + b, ("a", "a")), + (np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a - b, ("a", "a")), + (np.array([[1, 4], [9, 16]], dtype=np.int64), lambda a, b: a * b, ("a", "a")), + ( + np.array([[1.0, 1.0], [1.0, 1.0]], dtype=np.float64), + lambda a, b: a / b, + ("a", "a"), + ), + (np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a % b, ("a", "a")), + ( + np.array([[3, 4], [7, 8]], dtype=np.int64), + lambda a, b: a + b, + ("a", "uint8"), + ), + # This fails because the code is buggy, see + # https://github.com/pola-rs/polars/issues/17820 + # + # ( + # np.array([[[2, 4]], [[6, 8]]], dtype=np.int64), + # lambda a, b: a + b, + # ("nested", "nested"), + # ), + ], +) +def test_array_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", np.array([[1, 2], [3, 4]], dtype=np.int64)), + pl.Series("uint8", np.array([[2, 2], [4, 4]], dtype=np.uint8)), + pl.Series("nested", np.array([[[1, 2]], [[3, 4]]], dtype=np.int64)), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +def test_schema_owned_arithmetic_5669() -> None: + df = ( + pl.LazyFrame({"A": [1, 2, 3]}) + .filter(pl.col("A") >= 3) + .with_columns(-pl.col("A").alias("B")) + .collect() + ) + assert df.columns == ["A", "B"] + assert df.rows() == [(3, -3)] + + +def test_schema_true_divide_6643() -> None: + df = pl.DataFrame({"a": [1]}) + a = pl.col("a") + assert df.lazy().select(a / 2).select(pl.col(pl.Int64)).collect().shape == (0, 0) + + +def test_literal_subtract_schema_13284() -> None: + assert ( + pl.LazyFrame({"a": [23, 30]}, schema={"a": pl.UInt8}) + .with_columns(pl.col("a") - pl.lit(1)) + .group_by("a") + .len() + ).collect_schema() == OrderedDict([("a", pl.UInt8), ("len", pl.UInt32)]) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_int_operator_stability(dtype: pl.DataType) -> None: + s = pl.Series(values=[10], dtype=dtype) + assert pl.select(pl.lit(s) // 2).dtypes == [dtype] + assert pl.select(pl.lit(s) + 2).dtypes == [dtype] + assert pl.select(pl.lit(s) - 2).dtypes == [dtype] + assert pl.select(pl.lit(s) * 2).dtypes == [dtype] + assert pl.select(pl.lit(s) / 2).dtypes == [pl.Float64] + + +def test_duration_division_schema() -> None: + df = pl.DataFrame({"a": [1]}) + q = ( + df.lazy() + .with_columns(pl.col("a").cast(pl.Duration)) + .select(pl.col("a") / pl.col("a")) + ) + + assert q.collect_schema() == {"a": pl.Float64} + assert q.collect().to_dict(as_series=False) == {"a": [1.0]} + + +@pytest.mark.parametrize( + ("a", "b", "op"), + [ + (pl.Duration, pl.Int32, "+"), + (pl.Int32, pl.Duration, "+"), + (pl.Time, pl.Int32, "+"), + (pl.Int32, pl.Time, "+"), + (pl.Date, pl.Int32, "+"), + (pl.Int32, pl.Date, "+"), + (pl.Datetime, pl.Duration, "*"), + (pl.Duration, pl.Datetime, "*"), + (pl.Date, pl.Duration, "*"), + (pl.Duration, pl.Date, "*"), + (pl.Time, pl.Duration, "*"), + (pl.Duration, pl.Time, "*"), + ], +) +def test_raise_invalid_temporal(a: pl.DataType, b: pl.DataType, op: str) -> None: + a = pl.Series("a", [], dtype=a) # type: ignore[assignment] + b = pl.Series("b", [], dtype=b) # type: ignore[assignment] + _df = pl.DataFrame([a, b]) + + with pytest.raises(InvalidOperationError): + eval(f"_df.select(pl.col('a') {op} pl.col('b'))") + + +def test_arithmetic_duration_div_multiply() -> None: + df = pl.DataFrame([pl.Series("a", [100, 200, 3000], dtype=pl.Duration)]) + + q = df.lazy().with_columns( + b=pl.col("a") / 2, + c=pl.col("a") / 2.5, + d=pl.col("a") * 2, + e=pl.col("a") * 2.5, + f=pl.col("a") / pl.col("a"), # a constant float + ) + assert q.collect_schema() == pl.Schema( + [ + ("a", pl.Duration(time_unit="us")), + ("b", pl.Duration(time_unit="us")), + ("c", pl.Duration(time_unit="us")), + ("d", pl.Duration(time_unit="us")), + ("e", pl.Duration(time_unit="us")), + ("f", pl.Float64()), + ] + ) + assert q.collect().to_dict(as_series=False) == { + "a": [ + timedelta(microseconds=100), + timedelta(microseconds=200), + timedelta(microseconds=3000), + ], + "b": [ + timedelta(microseconds=50), + timedelta(microseconds=100), + timedelta(microseconds=1500), + ], + "c": [ + timedelta(microseconds=40), + timedelta(microseconds=80), + timedelta(microseconds=1200), + ], + "d": [ + timedelta(microseconds=200), + timedelta(microseconds=400), + timedelta(microseconds=6000), + ], + "e": [ + timedelta(microseconds=250), + timedelta(microseconds=500), + timedelta(microseconds=7500), + ], + "f": [1.0, 1.0, 1.0], + } + + # rhs + + q = df.lazy().with_columns( + b=2 * pl.col("a"), + c=2.5 * pl.col("a"), + ) + assert q.collect_schema() == pl.Schema( + [ + ("a", pl.Duration(time_unit="us")), + ("b", pl.Duration(time_unit="us")), + ("c", pl.Duration(time_unit="us")), + ] + ) + assert q.collect().to_dict(as_series=False) == { + "a": [ + timedelta(microseconds=100), + timedelta(microseconds=200), + timedelta(microseconds=3000), + ], + "b": [ + timedelta(microseconds=200), + timedelta(microseconds=400), + timedelta(microseconds=6000), + ], + "c": [ + timedelta(microseconds=250), + timedelta(microseconds=500), + timedelta(microseconds=7500), + ], + } + + +def test_invalid_shapes_err() -> None: + with pytest.raises( + InvalidOperationError, + match=r"cannot do arithmetic operation on series of different lengths: got 2 and 3", + ): + pl.Series([1, 2]) + pl.Series([1, 2, 3]) + + +def test_date_datetime_sub() -> None: + df = pl.DataFrame({"foo": [date(2020, 1, 1)], "bar": [datetime(2020, 1, 5)]}) + + assert df.select( + pl.col("foo") - pl.col("bar"), + pl.col("bar") - pl.col("foo"), + ).to_dict(as_series=False) == { + "foo": [timedelta(days=-4)], + "bar": [timedelta(days=4)], + } + + +def test_time_time_sub() -> None: + df = pl.DataFrame( + { + "foo": pl.Series([-1, 0, 10]).cast(pl.Datetime("us")), + "bar": pl.Series([1, 0, 1]).cast(pl.Datetime("us")), + } + ) + + assert df.select( + pl.col("foo").dt.time() - pl.col("bar").dt.time(), + pl.col("bar").dt.time() - pl.col("foo").dt.time(), + ).to_dict(as_series=False) == { + "foo": [ + timedelta(days=1, microseconds=-2), + timedelta(0), + timedelta(microseconds=9), + ], + "bar": [ + timedelta(days=-1, microseconds=2), + timedelta(0), + timedelta(microseconds=-9), + ], + } + + +def test_raise_invalid_shape() -> None: + with pytest.raises(InvalidOperationError): + pl.DataFrame([[1, 2], [3, 4]]) * pl.DataFrame([1, 2, 3]) + + +def test_integer_divide_scalar_zero_lhs_19142() -> None: + assert_series_equal(pl.Series([0]) // pl.Series([1, 0]), pl.Series([0, None])) + assert_series_equal(pl.Series([0]) % pl.Series([1, 0]), pl.Series([0, None])) + + +def test_compound_duration_21389() -> None: + # test add + lf = pl.LazyFrame( + { + "ts": datetime(2024, 1, 1, 1, 2, 3), + "duration": timedelta(days=1), + } + ) + result = lf.select(pl.col("ts") + pl.col("duration") * 2) + expected_schema = pl.Schema({"ts": pl.Datetime(time_unit="us", time_zone=None)}) + expected = pl.DataFrame({"ts": datetime(2024, 1, 3, 1, 2, 3)}) + assert result.collect_schema() == expected_schema + assert_frame_equal(result.collect(), expected) + + # test subtract + result = lf.select(pl.col("ts") - pl.col("duration") * 2) + expected_schema = pl.Schema({"ts": pl.Datetime(time_unit="us", time_zone=None)}) + expected = pl.DataFrame({"ts": datetime(2023, 12, 30, 1, 2, 3)}) + assert result.collect_schema() == expected_schema + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_arithmetic_i128(dtype: PolarsIntegerType) -> None: + s = pl.Series("a", [0, 1, 127], dtype=dtype, strict=False) + s128 = pl.Series("a", [0, 0, 0], dtype=pl.Int128) + expected = pl.Series("a", [0, 1, 127], dtype=pl.Int128) + assert_series_equal(s + s128, expected) + assert_series_equal(s128 + s, expected) + + +def test_arithmetic_i128_nonint() -> None: + s128 = pl.Series("a", [0], dtype=pl.Int128) + + s = pl.Series("a", [1.0], dtype=pl.Float32) + assert_series_equal(s + s128, pl.Series("a", [1.0], dtype=pl.Float64)) + assert_series_equal(s128 + s, pl.Series("a", [1.0], dtype=pl.Float64)) + + s = pl.Series("a", [1.0], dtype=pl.Float64) + assert_series_equal(s + s128, s) + assert_series_equal(s128 + s, s) + + s = pl.Series("a", [True], dtype=pl.Boolean) + assert_series_equal(s + s128, pl.Series("a", [1], dtype=pl.Int128)) + assert_series_equal(s128 + s, pl.Series("a", [1], dtype=pl.Int128)) diff --git a/py-polars/tests/unit/operations/arithmetic/test_array.py b/py-polars/tests/unit/operations/arithmetic/test_array.py new file mode 100644 index 000000000000..e6f47e7bdacf --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_array.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_series_equal +from tests.unit.operations.arithmetic.utils import ( + BROADCAST_SERIES_COMBINATIONS, + EXEC_OP_COMBINATIONS, +) + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + "array_side", ["left", "left3", "both", "both3", "right3", "right", "none"] +) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_array_arithmetic_values( + array_side: str, + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + """ + Tests value correctness. + + This test checks for output value correctness (a + b == c) across different + codepaths, by wrapping the values (a, b, c) in different combinations of + list / primitive columns. + """ + import operator as op + + dtypes: list[Any] = [pl.Null, pl.Null, pl.Null] + dtype: Any = pl.Null + + def materialize_array(v: Any) -> pl.Series: + return pl.Series( + [[None, v, None]], + dtype=pl.Array(dtype, 3), + ) + + def materialize_array3(v: Any) -> pl.Series: + return pl.Series( + [[[[None, v], None], None]], + dtype=pl.Array(pl.Array(pl.Array(dtype, 2), 2), 2), + ) + + def materialize_primitive(v: Any) -> pl.Series: + return pl.Series([v], dtype=dtype) + + def materialize_series( + l: Any, # noqa: E741 + r: Any, + o: Any, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + nonlocal dtype + + dtype = dtypes[0] + l = { # noqa: E741 + "left": materialize_array, + "left3": materialize_array3, + "both": materialize_array, + "both3": materialize_array3, + "right": materialize_primitive, + "right3": materialize_primitive, + "none": materialize_primitive, + }[array_side](l) # fmt: skip + + dtype = dtypes[1] + r = { + "left": materialize_primitive, + "left3": materialize_primitive, + "both": materialize_array, + "both3": materialize_array3, + "right": materialize_array, + "right3": materialize_array3, + "none": materialize_primitive, + }[array_side](r) # fmt: skip + + dtype = dtypes[2] + o = { + "left": materialize_array, + "left3": materialize_array3, + "both": materialize_array, + "both3": materialize_array3, + "right": materialize_array, + "right3": materialize_array3, + "none": materialize_primitive, + }[array_side](o) # fmt: skip + + assert l.len() == 1 + assert r.len() == 1 + assert o.len() == 1 + + return broadcast_series(l, r, o) + + # Signed + dtypes = [pl.Int8, pl.Int8, pl.Int8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(-5, 127, 124) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(-5, 127, -123) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(-5, 3, -2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Unsigned + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(2, 3, 255) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(2, 128, 0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(5, 2, 2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(5, 2, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Floats. Note we pick Float32 to ensure there is no accidental upcasting + # to Float64. + dtypes = [pl.Float32, pl.Float32, pl.Float32] + l, r, o = materialize_series(1.7, 2.3, 4.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(7.0, 3.0, 2.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5.0, 3.0, 1.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(2.0, 128.0, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for zero behavior + # + + # Integer + + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(1, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Float + + dtypes = [pl.Float32, pl.Float32, pl.Float32] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(1, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for NULL behavior + # + + for dtype, truediv_dtype in [ # type: ignore[misc] + [pl.Int8, pl.Float64], + [pl.Float32, pl.Float32], + ]: + for vals in [ + [None, None, None], + [0, None, None], + [None, 0, None], + [0, None, None], + [None, 0, None], + [3, None, None], + [None, 3, None], + ]: + dtypes = 3 * [dtype] + + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + dtypes[2] = truediv_dtype # type: ignore[has-type] + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Type upcasting for Boolean and Null + + # Check boolean upcasting + dtypes = [pl.Boolean, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(True, 3, 4) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(True, 3, 254) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(True, 3, 3) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(True, 3, 0) # noqa: E741 + if array_side != "none": + # TODO: FIXME: We get an error on non-lists with this: + # "floor_div operation not supported for dtype `bool`" + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(True, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Boolean, pl.UInt8, pl.Float64] + l, r, o = materialize_series(True, 128, 0.0078125) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Check Null upcasting + dtypes = [pl.Null, pl.UInt8, pl.UInt8] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + if array_side != "none": + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Null, pl.UInt8, pl.Float64] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + +@pytest.mark.parametrize( + ("lhs_dtype", "rhs_dtype", "expected_dtype"), + [ + (pl.Array(pl.Int64, 2), pl.Int64, pl.Array(pl.Float64, 2)), + (pl.Array(pl.Float32, 2), pl.Float32, pl.Array(pl.Float32, 2)), + (pl.Array(pl.Duration("us"), 2), pl.Int64, pl.Array(pl.Duration("us"), 2)), + ], +) +def test_array_truediv_schema( + lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType +) -> None: + schema = {"lhs": lhs_dtype, "rhs": rhs_dtype} + df = pl.DataFrame({"lhs": [[None, 10]], "rhs": 2}, schema=schema) + result = df.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()["lhs"] + assert result == expected_dtype + + +def test_array_literal_broadcast() -> None: + df = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}).cast(pl.Array(float, 2)) + + lit = pl.lit([3, 5], pl.Array(float, 2)) + assert df.select( + mul=pl.all() * lit, + div=pl.all() / lit, + add=pl.all() + lit, + sub=pl.all() - lit, + div_=lit / pl.all(), + add_=lit + pl.all(), + sub_=lit - pl.all(), + mul_=lit * pl.all(), + ).to_dict(as_series=False) == { + "mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]], + "div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]], + "add": [[3.1, 5.2], [3.3, 5.4]], + "sub": [[-2.9, -4.8], [-2.7, -4.6]], + "div_": [[30.0, 25.0], [10.0, 12.5]], + "add_": [[3.1, 5.2], [3.3, 5.4]], + "sub_": [[2.9, 4.8], [2.7, 4.6]], + "mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]], + } + + +def test_array_arith_double_nested_shape() -> None: + # Ensure the implementation doesn't just naively add the leaf arrays without + # checking the dimension. In this example both arrays have the leaf stride as + # 6, however one is (3, 2) while the other is (2, 3). + a = pl.Series([[[1, 1], [1, 1], [1, 1]]], dtype=pl.Array(pl.Array(pl.Int64, 2), 3)) + b = pl.Series([[[1, 1, 1], [1, 1, 1]]], dtype=pl.Array(pl.Array(pl.Int64, 3), 2)) + + with pytest.raises(InvalidOperationError, match="differing dtypes"): + a + b + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.slow +def test_array_numeric_op_validity_combination( + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + array_dtype = pl.Array(pl.Int64, 1) + + a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=array_dtype) + b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=array_dtype) + # expected result + e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=array_dtype) + + assert_series_equal( + exec_op(a, b, op.add), + e, + ) + + a = pl.Series("a", [[1]], dtype=array_dtype) + b = pl.Series("b", [None], dtype=pl.Int64) + e = pl.Series("a", [[None]], dtype=array_dtype) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=array_dtype) + b = pl.Series("b", [1], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=array_dtype) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=array_dtype) + b = pl.Series("b", [0], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=array_dtype) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.floordiv), e) + + # >1 level nested array + a = pl.Series( + # row 1: [ [1, NULL], NULL ] + # row 2: NULL + [[[1, None], None], None], + dtype=pl.Array(pl.Array(pl.Int64, 2), 2), + ) + b = pl.Series( + [[[0, 0], [0, 0]], [[0, 0], [0, 0]]], + dtype=pl.Array(pl.Array(pl.Int64, 2), 2), + ) + e = a # added 0 + assert_series_equal(exec_op(a, b, op.add), e) + + +def test_array_elementwise_arithmetic_19682() -> None: + dt = pl.Array(pl.Int64, (2, 3)) + + a = pl.Series("a", [[[1, 2, 3], [4, 5, 6]]], dt) + sc = pl.Series("a", [1]) + zfa = pl.Series("a", [[]], pl.Array(pl.Int64, 0)) + + assert_series_equal(a + a, pl.Series("a", [[[2, 4, 6], [8, 10, 12]]], dt)) + assert_series_equal(a + sc, pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], dt)) + assert_series_equal(sc + a, pl.Series("a", [[[2, 3, 4], [5, 6, 7]]], dt)) + assert_series_equal(zfa + zfa, pl.Series("a", [[]], pl.Array(pl.Int64, 0))) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_array_add_supertype( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int8, 1)) + b = pl.Series("b", [[1], [999]], dtype=pl.Array(pl.Int64, 1)) + + assert_series_equal( + exec_op(a, b, op.add), + pl.Series("a", [[2], [1001]], dtype=pl.Array(pl.Int64, 1)), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_array_arithmetic_dtype_mismatch( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2]], dtype=pl.Array(pl.Int64, 1)) + b = pl.Series("b", [[1, 1], [999, 999]], dtype=pl.Array(pl.Int64, 2)) + + with pytest.raises(InvalidOperationError, match="differing dtypes"): + exec_op(a, b, op.add) + + a = pl.Series([[[1]], [[1]]], dtype=pl.Array(pl.List(pl.Int64), 1)) + b = pl.Series([1], dtype=pl.Int64) + + with pytest.raises( + InvalidOperationError, match="dtype was not array on all nesting levels" + ): + exec_op(a, a, op.add) + + with pytest.raises( + InvalidOperationError, match="dtype was not array on all nesting levels" + ): + exec_op(a, b, op.add) + + with pytest.raises( + InvalidOperationError, match="dtype was not array on all nesting levels" + ): + exec_op(b, a, op.add) diff --git a/py-polars/tests/unit/operations/arithmetic/test_list.py b/py-polars/tests/unit/operations/arithmetic/test_list.py new file mode 100644 index 000000000000..affd0e7aa6e1 --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_list.py @@ -0,0 +1,1026 @@ +from __future__ import annotations + +import operator +from typing import TYPE_CHECKING, Any, Callable + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError, ShapeError +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.operations.arithmetic.utils import ( + BROADCAST_SERIES_COMBINATIONS, + EXEC_OP_COMBINATIONS, +) + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + "list_side", ["left", "left3", "both", "right3", "right", "none"] +) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_list_arithmetic_values( + list_side: str, + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + """ + Tests value correctness. + + This test checks for output value correctness (a + b == c) across different + codepaths, by wrapping the values (a, b, c) in different combinations of + list / primitive columns. + """ + import operator as op + + dtypes: list[Any] = [pl.Null, pl.Null, pl.Null] + dtype: Any = pl.Null + + def materialize_list(v: Any) -> pl.Series: + return pl.Series( + [[None, v, None]], + dtype=pl.List(dtype), + ) + + def materialize_list3(v: Any) -> pl.Series: + return pl.Series( + [[[[None, v], None], None]], + dtype=pl.List(pl.List(pl.List(dtype))), + ) + + def materialize_primitive(v: Any) -> pl.Series: + return pl.Series([v], dtype=dtype) + + def materialize_series( + l: Any, # noqa: E741 + r: Any, + o: Any, + ) -> tuple[pl.Series, pl.Series, pl.Series]: + nonlocal dtype + + dtype = dtypes[0] + l = { # noqa: E741 + "left": materialize_list, + "left3": materialize_list3, + "both": materialize_list, + "right": materialize_primitive, + "right3": materialize_primitive, + "none": materialize_primitive, + }[list_side](l) # fmt: skip + + dtype = dtypes[1] + r = { + "left": materialize_primitive, + "left3": materialize_primitive, + "both": materialize_list, + "right": materialize_list, + "right3": materialize_list3, + "none": materialize_primitive, + }[list_side](r) # fmt: skip + + dtype = dtypes[2] + o = { + "left": materialize_list, + "left3": materialize_list3, + "both": materialize_list, + "right": materialize_list, + "right3": materialize_list3, + "none": materialize_primitive, + }[list_side](o) # fmt: skip + + assert l.len() == 1 + assert r.len() == 1 + assert o.len() == 1 + + return broadcast_series(l, r, o) + + # Signed + dtypes = [pl.Int8, pl.Int8, pl.Int8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(-5, 127, 124) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(-5, 127, -123) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(-5, 3, -2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Unsigned + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(2, 3, 5) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(2, 3, 255) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(2, 128, 0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(5, 2, 2) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(5, 2, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + l, r, o = materialize_series(2, 128, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Floats. Note we pick Float32 to ensure there is no accidental upcasting + # to Float64. + dtypes = [pl.Float32, pl.Float32, pl.Float32] + l, r, o = materialize_series(1.7, 2.3, 4.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(1.7, 2.3, -0.5999999999999999) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(1.7, 2.3, 3.9099999999999997) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(7.0, 3.0, 2.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(-5.0, 3.0, 1.0) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(2.0, 128.0, 0.015625) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for zero behavior + # + + # Integer + + dtypes = [pl.UInt8, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(1, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.UInt8, pl.UInt8, pl.Float64] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Float + + dtypes = [pl.Float32, pl.Float32, pl.Float32] + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(1, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(1, 0, float("inf")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + l, r, o = materialize_series(0, 0, float("nan")) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # + # Tests for NULL behavior + # + + for dtype, truediv_dtype in [ # type: ignore[misc] + [pl.Int8, pl.Float64], + [pl.Float32, pl.Float32], + ]: + for vals in [ + [None, None, None], + [0, None, None], + [None, 0, None], + [0, None, None], + [None, 0, None], + [3, None, None], + [None, 3, None], + ]: + dtypes = 3 * [dtype] + + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + dtypes[2] = truediv_dtype # type: ignore[has-type] + l, r, o = materialize_series(*vals) # type: ignore[misc] # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Type upcasting for Boolean and Null + + # Check boolean upcasting + dtypes = [pl.Boolean, pl.UInt8, pl.UInt8] + + l, r, o = materialize_series(True, 3, 4) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + + l, r, o = materialize_series(True, 3, 254) # noqa: E741 + assert_series_equal(exec_op(l, r, op.sub), o) + + l, r, o = materialize_series(True, 3, 3) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mul), o) + + l, r, o = materialize_series(True, 3, 0) # noqa: E741 + if list_side != "none": + # TODO: FIXME: We get an error on non-lists with this: + # "floor_div operation not supported for dtype `bool`" + assert_series_equal(exec_op(l, r, op.floordiv), o) + + l, r, o = materialize_series(True, 3, 1) # noqa: E741 + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Boolean, pl.UInt8, pl.Float64] + l, r, o = materialize_series(True, 128, 0.0078125) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + # Check Null upcasting + dtypes = [pl.Null, pl.UInt8, pl.UInt8] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.add), o) + assert_series_equal(exec_op(l, r, op.sub), o) + assert_series_equal(exec_op(l, r, op.mul), o) + if list_side != "none": + assert_series_equal(exec_op(l, r, op.floordiv), o) + assert_series_equal(exec_op(l, r, op.mod), o) + + dtypes = [pl.Null, pl.UInt8, pl.Float64] + l, r, o = materialize_series(None, 3, None) # noqa: E741 + assert_series_equal(exec_op(l, r, op.truediv), o) + + +@pytest.mark.parametrize( + ("lhs_dtype", "rhs_dtype", "expected_dtype"), + [ + (pl.List(pl.Int64), pl.Int64, pl.List(pl.Float64)), + (pl.List(pl.Float32), pl.Float32, pl.List(pl.Float32)), + (pl.List(pl.Duration("us")), pl.Int64, pl.List(pl.Duration("us"))), + ], +) +def test_list_truediv_schema( + lhs_dtype: PolarsDataType, rhs_dtype: PolarsDataType, expected_dtype: PolarsDataType +) -> None: + schema = {"lhs": lhs_dtype, "rhs": rhs_dtype} + df = pl.DataFrame({"lhs": [[None, 10]], "rhs": 2}, schema=schema) + result = df.lazy().select(pl.col("lhs").truediv("rhs")).collect_schema()["lhs"] + assert result == expected_dtype + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_supertype( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8)) + b = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64)) + + assert_series_equal( + exec_op(a, b, op.add), + pl.Series("a", [[2], [1001]], dtype=pl.List(pl.Int64)), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.parametrize( + "broadcast_series", + BROADCAST_SERIES_COMBINATIONS, +) +@pytest.mark.slow +def test_list_numeric_op_validity_combination( + broadcast_series: Callable[ + [pl.Series, pl.Series, pl.Series], tuple[pl.Series, pl.Series, pl.Series] + ], + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series("a", [[1], [2], None, [None], [11], [1111]], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [[1], [3], [11], [1111], None, [None]], dtype=pl.List(pl.Int64)) + # expected result + e = pl.Series("a", [[2], [5], None, [None], None, [None]], dtype=pl.List(pl.Int64)) + + assert_series_equal( + exec_op(a, b, op.add), + e, + ) + + a = pl.Series("a", [[1]], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [None], dtype=pl.Int64) + e = pl.Series("a", [[None]], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [1], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.add), e) + + a = pl.Series("a", [None], dtype=pl.List(pl.Int32)) + b = pl.Series("b", [0], dtype=pl.Int64) + e = pl.Series("a", [None], dtype=pl.List(pl.Int64)) + + a, b, e = broadcast_series(a, b, e) + assert_series_equal(exec_op(a, b, op.floordiv), e) + + +def test_list_add_alignment() -> None: + a = pl.Series("a", [[1, 1], [1, 1, 1]]) + b = pl.Series("b", [[1, 1, 1], [1, 1]]) + + df = pl.DataFrame([a, b]) + + with pytest.raises(ShapeError): + df.select(x=pl.col("a") + pl.col("b")) + + # Test masking and slicing + a = pl.Series("a", [[1, 1, 1], [1], [1, 1], [1, 1, 1]]) + b = pl.Series("b", [[1, 1], [1], [1, 1, 1], [1]]) + c = pl.Series("c", [1, 1, 1, 1]) + p = pl.Series("p", [True, True, False, False]) + + df = pl.DataFrame([a, b, c, p]).filter("p").slice(1) + + for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]: + assert_series_equal( + df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2]]) + ) + + df = df.vstack(df) + + for rhs in [pl.col("b"), pl.lit(1), pl.col("c"), pl.lit([1])]: + assert_series_equal( + df.select(x=pl.col("a") + rhs).to_series(), pl.Series("x", [[2], [2]]) + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_list_add_empty_lists( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + l = pl.Series( # noqa: E741 + "x", + [[[[]], []], []], + ) + r = pl.Series([1]) + + assert_series_equal( + exec_op(l, r, operator.add), + pl.Series("x", [[[[]], []], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))), + ) + + l = pl.Series( # noqa: E741 + "x", + [[[[]], None], []], + ) + r = pl.Series([1]) + + assert_series_equal( + exec_op(l, r, operator.add), + pl.Series("x", [[[[]], None], []], dtype=pl.List(pl.List(pl.List(pl.Int64)))), + ) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_to_list_arithmetic_double_nesting_raises_error( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + s = pl.Series(dtype=pl.List(pl.List(pl.Int32))) + + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): + exec_op(s, s, operator.add) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_add_height_mismatch( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + s = pl.Series([[1], [2], [3]], dtype=pl.List(pl.Int32)) + + # TODO: Make the error type consistently a ShapeError + with pytest.raises( + (ShapeError, InvalidOperationError), + match="length", + ): + exec_op(s, pl.Series([1, 1]), operator.add) + + +@pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.mod, + operator.truediv, + ], +) +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +@pytest.mark.slow +def test_list_date_to_numeric_arithmetic_raises_error( + op: Callable[[Any], Any], exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series] +) -> None: + l = pl.Series([1], dtype=pl.Date) # noqa: E741 + r = pl.Series([[1]], dtype=pl.List(pl.Int32)) + + exec_op(l.to_physical(), r, op) + + # TODO(_): Ideally this always raises InvalidOperationError. The TypeError + # is being raised by checks on the Python side that should be moved to Rust. + with pytest.raises((InvalidOperationError, TypeError)): + exec_op(l, r, op) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), + ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), + ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("a", "uint8"), + ), + ], +) +def test_list_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", [[1, 2], [3]]), + pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), + pl.Series("nested", [[[1, 2]], [[3]]]), + pl.Series( + "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) + ), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), + ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), + ], +) +def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + with pytest.raises(InvalidOperationError, match="different lengths"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None]) + + # Different list length: + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]]) + + with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"): + _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) + + +@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) +def test_list_arithmetic_invalid_dtypes( + exec_op: Callable[[pl.Series, pl.Series, Any], pl.Series], +) -> None: + import operator as op + + a = pl.Series([[1, 2]]) + b = pl.Series(["hello"]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): + exec_op(a, b, op.add) + + a = pl.Series("a", [[1]]) + b = pl.Series("b", [[[1]]]) + + # list<->list is restricted to 1 level of nesting + with pytest.raises( + InvalidOperationError, + match="cannot add two list columns with non-numeric inner types", + ): + exec_op(a, b, op.add) + + # Ensure dtype is validated to be `List` at all nesting levels instead of panicking. + a = pl.Series([[[1]], [[1]]], dtype=pl.List(pl.Array(pl.Int64, 1))) + b = pl.Series([1], dtype=pl.Int64) + + with pytest.raises( + InvalidOperationError, match="dtype was not list on all nesting levels" + ): + exec_op(a, b, op.add) + + with pytest.raises( + InvalidOperationError, match="dtype was not list on all nesting levels" + ): + exec_op(b, a, op.add) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + # All 5 arithmetic operations: + ([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")), + ([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")), + ([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")), + ([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")), + ([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")), + # Different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("list", "uint8"), + ), + # Extra nesting + different types: + ( + [[[3, 4]], [[8]]], + lambda a, b: a + b, + ("nested", "int64"), + ), + # Primitive numeric on the left; only addition and multiplication are + # supported: + ([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")), + ([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")), + # Primitive numeric on the left with different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("uint8", "list"), + ), + ( + [[2, 4], [12]], + lambda a, b: a * b, + ("uint8", "list"), + ), + ], +) +def test_list_and_numeric_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("list", [[1, 2], [3]]), + pl.Series("int64", [2, 3], dtype=pl.Int64()), + pl.Series("uint8", [2, 4], dtype=pl.UInt8()), + pl.Series("nested", [[[1, 2]], [[5]]]), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + # Null on numeric on the right: + ([[1, 2], [3]], [1, None], [[2, 3], [None]]), + # Null on list on the left: + ([[[1, 2]], [[3]]], [None, 1], [[[None, None]], [[4]]]), + # Extra nesting: + ([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]), + ], +) +def test_list_and_numeric_arithmetic_nulls( + a: list[Any], b: list[Any], expected: list[Any] +) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected, dtype=series_a.dtype) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + # Swap sides: + assert_series_equal(series_b + series_a, series_expected) + assert_series_equal( + series_b._recursive_cast_to_dtype(pl.Int32()) + + series_a._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_and_numeric_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2]) + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="add operation not supported for dtypes" + ): + _ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"]) + + +@pytest.mark.parametrize("broadcast", [True, False]) +@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Float64()]) +def test_list_arithmetic_div_ops_zero_denominator( + broadcast: bool, dtype: pl.DataType +) -> None: + # Notes + # * truediv (/) on integers upcasts to Float64 + # * Otherwise, we test floordiv (//) and module/rem (%) + # * On integers, 0-denominator is expected to output NULL + # * On floats, 0-denominator has different outputs, e.g. NaN, Inf, depending + # on a few factors (e.g. whether the numerator is also 0). + + s = pl.Series([[0], [1], [None], None]).cast(pl.List(dtype)) + + n = 1 if broadcast else s.len() + + # list<->primitive + + # truediv + assert_series_equal( + pl.Series([1]).new_from_index(0, n) / s, + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + assert_series_equal( + s / pl.Series([1]).new_from_index(0, n), + pl.Series([[0.0], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + # floordiv + assert_series_equal( + pl.Series([1]).new_from_index(0, n) // s, + ( + pl.Series([[None], [1], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s // pl.Series([0]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=s.dtype + ) + ), + ) + + # rem + assert_series_equal( + pl.Series([1]).new_from_index(0, n) % s, + ( + pl.Series([[None], [0], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s % pl.Series([0]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("nan")], [None], None], dtype=s.dtype + ) + ), + ) + + # list<->list + + # truediv + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) / s, + pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + ) + + assert_series_equal( + s / pl.Series([[0]]).new_from_index(0, n), + pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=pl.List(pl.Float64) + ), + ) + + # floordiv + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) // s, + ( + pl.Series([[None], [1], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("inf")], [1.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s // pl.Series([[0]]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("inf")], [None], None], dtype=s.dtype + ) + ), + ) + + # rem + assert_series_equal( + pl.Series([[1]]).new_from_index(0, n) % s, + ( + pl.Series([[None], [0], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series([[float("nan")], [0.0], [None], None], dtype=s.dtype) + ), + ) + + assert_series_equal( + s % pl.Series([[0]]).new_from_index(0, n), + ( + pl.Series([[None], [None], [None], None], dtype=s.dtype) + if not dtype.is_float() + else pl.Series( + [[float("nan")], [float("nan")], [None], None], dtype=s.dtype + ) + ), + ) + + +def test_list_to_primitive_arithmetic() -> None: + # Input data + # * List type: List(List(List(Int16))) (triple-nested) + # * Numeric type: Int32 + # + # Tests run + # Broadcast Operation + # | L | R | + # * list<->primitive | | | floor_div + # * primitive<->list | | | floor_div + # * list<->primitive | | * | subtract + # * primitive<->list | * | | subtract + # * list<->primitive | * | | subtract + # * primitive<->list | | * | subtract + # + # Notes + # * In floor_div, we check that results from a 0 denominator are masked out + # * We choose floor_div and subtract as they emit different results when + # sides are swapped + + # Create some non-zero start offsets and masked out rows. + lhs = ( + pl.Series( + [ + [[[None, None, None, None, None]]], # sliced out + # Nulls at every level XO + [[[3, 7]], [[-3], [None], [], [], None], [], None], + [[[1, 2, 3, 4, 5]]], # masked out + [[[3, 7]], [[0], [None], [], [], None]], + [[[3, 7]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ) + .slice(1) + .to_frame() + .select(pl.when(pl.int_range(pl.len()) != 1).then(pl.first())) + .to_series() + ) + + # Note to reader: This is what our LHS looks like + assert_series_equal( + lhs, + pl.Series( + [ + [[[3, 7]], [[-3], [None], [], [], None], [], None], + None, + [[[3, 7]], [[0], [None], [], [], None]], + [[[3, 7]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ), + ) + + class _: + # Floor div, no broadcasting + rhs = pl.Series([5, 1, 0, None], dtype=pl.Int32) + + assert len(lhs) == len(rhs) + + expect = pl.Series( + [ + [[[0, 1]], [[-1], [None], [], [], None], [], None], + None, + [[[None, None]], [[None], [None], [], [], None]], + [[[None, None]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = ( + pl.select(l=lhs, r=rhs) + .select(pl.col("l") // pl.col("r")) + .to_series() + .alias("") + ) + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[1, 0]], [[-2], [None], [], [], None], [], None], + None, + [[[0, 0]], [[None], [None], [], [], None]], + [[[None, None]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = ( # noqa: PIE794 + pl.select(l=lhs, r=rhs) + .select(pl.col("r") // pl.col("l")) + .to_series() + .alias("") + ) + + assert_series_equal(out, expect) + + class _: # type: ignore[no-redef] + # Subtraction with broadcasting + rhs = pl.Series([1], dtype=pl.Int32) + + expect = pl.Series( + [ + [[[2, 6]], [[-4], [None], [], [], None], [], None], + None, + [[[2, 6]], [[-1], [None], [], [], None]], + [[[2, 6]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(l=lhs).select(pl.col("l") - rhs).to_series().alias("") + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[-2, -6]], [[4], [None], [], [], None], [], None], + None, + [[[-2, -6]], [[1], [None], [], [], None]], + [[[-2, -6]]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(l=lhs).select(rhs - pl.col("l")).to_series().alias("") # noqa: PIE794 + + assert_series_equal(out, expect) + + # Test broadcasting of the list side + lhs = lhs.slice(2, 1) + # Note to reader: This is what our LHS looks like + assert_series_equal( + lhs, + pl.Series( + [ + [[[3, 7]], [[0], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int16))), + ), + ) + + assert len(lhs) == 1 + + class _: # type: ignore[no-redef] + rhs = pl.Series([1, 2, 3, None, 5], dtype=pl.Int32) + + expect = pl.Series( + [ + [[[2, 6]], [[-1], [None], [], [], None]], + [[[1, 5]], [[-2], [None], [], [], None]], + [[[0, 4]], [[-3], [None], [], [], None]], + [[[None, None]], [[None], [None], [], [], None]], + [[[-2, 2]], [[-5], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(r=rhs).select(lhs - pl.col("r")).to_series().alias("") + + assert_series_equal(out, expect) + + # Flipped + + expect = pl.Series( # noqa: PIE794 + [ + [[[-2, -6]], [[1], [None], [], [], None]], + [[[-1, -5]], [[2], [None], [], [], None]], + [[[0, -4]], [[3], [None], [], [], None]], + [[[None, None]], [[None], [None], [], [], None]], + [[[2, -2]], [[5], [None], [], [], None]], + ], + dtype=pl.List(pl.List(pl.List(pl.Int32))), + ) + + out = pl.select(r=rhs).select(pl.col("r") - lhs).to_series().alias("") # noqa: PIE794 + + assert_series_equal(out, expect) diff --git a/py-polars/tests/unit/operations/arithmetic/test_neg.py b/py-polars/tests/unit/operations/arithmetic/test_neg.py new file mode 100644 index 000000000000..90b0cac2795b --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_neg.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from datetime import timedelta +from decimal import Decimal as D +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal +from polars.testing.asserts.series import assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + "dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.Float32, pl.Float64] +) +def test_neg_operator(dtype: PolarsDataType) -> None: + lf = pl.LazyFrame({"a": [-1, 0, 1, None]}, schema={"a": dtype}) + result = lf.select(-pl.col("a")) + expected = pl.LazyFrame({"a": [1, 0, -1, None]}, schema={"a": dtype}) + assert_frame_equal(result, expected) + + +def test_neg_method() -> None: + lf = pl.LazyFrame({"a": [-1, 0, 1, None]}) + result_op = lf.select(-pl.col("a")) + result_method = lf.select(pl.col("a").neg()) + assert_frame_equal(result_op, result_method) + + +def test_neg_decimal() -> None: + lf = pl.LazyFrame({"a": [D("-1.5"), D("0.0"), D("5.0"), None]}) + result = lf.select(-pl.col("a")) + expected = pl.LazyFrame({"a": [D("1.5"), D("0.0"), D("-5.0"), None]}) + assert_frame_equal(result, expected) + + +def test_neg_duration() -> None: + lf = pl.LazyFrame({"a": [timedelta(hours=2), timedelta(days=-2), None]}) + result = lf.select(-pl.col("a")) + expected = pl.LazyFrame({"a": [timedelta(hours=-2), timedelta(days=2), None]}) + assert_frame_equal(result, expected) + + +def test_neg_overflow_wrapping() -> None: + df = pl.DataFrame({"a": [-128]}, schema={"a": pl.Int8}) + result = df.select(-pl.col("a")) + assert_frame_equal(result, df) + + +def test_neg_unsigned_int() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) + with pytest.raises( + InvalidOperationError, match="`neg` operation not supported for dtype `u8`" + ): + df.select(-pl.col("a")) + + +def test_neg_non_numeric() -> None: + df = pl.DataFrame({"a": ["p", "q", "r"]}) + with pytest.raises( + InvalidOperationError, match="`neg` operation not supported for dtype `str`" + ): + df.select(-pl.col("a")) + + +def test_neg_series_operator() -> None: + s = pl.Series("a", [-1, 0, 1, None]) + result = -s + expected = pl.Series("a", [1, 0, -1, None]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/arithmetic/test_pos.py b/py-polars/tests/unit/operations/arithmetic/test_pos.py new file mode 100644 index 000000000000..cbe92a6f1ab9 --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_pos.py @@ -0,0 +1,20 @@ +from datetime import datetime + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_pos() -> None: + df = pl.LazyFrame({"x": [1, 2]}) + result = df.select(+pl.col("x")) + assert_frame_equal(result, df) + + +def test_pos_string() -> None: + a = pl.Series("a", [""]) + assert_series_equal(+a, a) + + +def test_pos_datetime() -> None: + a = pl.Series("a", [datetime(2022, 1, 1)]) + assert_series_equal(+a, a) diff --git a/py-polars/tests/unit/operations/arithmetic/test_pow.py b/py-polars/tests/unit/operations/arithmetic/test_pow.py new file mode 100644 index 000000000000..581a544c17ff --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/test_pow.py @@ -0,0 +1,59 @@ +import polars as pl + + +def test_pow_dtype() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3, 4, 5], + "a": [1, 2, 3, 4, 5], + "b": [1, 2, 3, 4, 5], + "c": [1, 2, 3, 4, 5], + "d": [1, 2, 3, 4, 5], + "e": [1, 2, 3, 4, 5], + "f": [1, 2, 3, 4, 5], + "g": [1, 2, 1, 2, 1], + "h": [1, 2, 1, 2, 1], + }, + schema_overrides={ + "a": pl.Int64, + "b": pl.UInt64, + "c": pl.Int32, + "d": pl.UInt32, + "e": pl.Int16, + "f": pl.UInt16, + "g": pl.Int8, + "h": pl.UInt8, + }, + ).lazy() + + df = ( + df.with_columns([pl.col("foo").cast(pl.UInt32)]) + .with_columns( + (pl.col("foo") * 2**2).alias("scaled_foo"), + (pl.col("foo") * 2**2.1).alias("scaled_foo2"), + (pl.col("a") ** pl.col("h")).alias("a_pow_h"), + (pl.col("b") ** pl.col("h")).alias("b_pow_h"), + (pl.col("c") ** pl.col("h")).alias("c_pow_h"), + (pl.col("d") ** pl.col("h")).alias("d_pow_h"), + (pl.col("e") ** pl.col("h")).alias("e_pow_h"), + (pl.col("f") ** pl.col("h")).alias("f_pow_h"), + (pl.col("g") ** pl.col("h")).alias("g_pow_h"), + (pl.col("h") ** pl.col("h")).alias("h_pow_h"), + ) + .drop(["a", "b", "c", "d", "e", "f", "g", "h"]) + ) + expected = [ + pl.UInt32, + pl.UInt32, + pl.Float64, + pl.Int64, + pl.UInt64, + pl.Int32, + pl.UInt32, + pl.Int16, + pl.UInt16, + pl.Int8, + pl.UInt8, + ] + assert df.collect().dtypes == expected + assert df.collect_schema().dtypes() == expected diff --git a/py-polars/tests/unit/operations/arithmetic/utils.py b/py-polars/tests/unit/operations/arithmetic/utils.py new file mode 100644 index 000000000000..eb5c3d8acf3f --- /dev/null +++ b/py-polars/tests/unit/operations/arithmetic/utils.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import polars as pl + + +def exec_op_with_series(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: + v: pl.Series = op(lhs, rhs) + return v + + +def exec_op_with_expr(lhs: pl.Series, rhs: pl.Series, op: Any) -> pl.Series: + return pl.select(lhs).lazy().select(op(pl.first(), rhs)).collect().to_series() + + +def exec_op_with_expr_no_type_coercion( + lhs: pl.Series, rhs: pl.Series, op: Any +) -> pl.Series: + return ( + pl.select(lhs) + .lazy() + .select(op(pl.first(), rhs)) + .collect(type_coercion=False) + .to_series() + ) + + +BROADCAST_LEN = 3 + + +def broadcast_left( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, +) -> tuple[pl.Series, pl.Series, pl.Series]: + return l.new_from_index(0, BROADCAST_LEN), r, o.new_from_index(0, BROADCAST_LEN) + + +def broadcast_right( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, +) -> tuple[pl.Series, pl.Series, pl.Series]: + return l, r.new_from_index(0, BROADCAST_LEN), o.new_from_index(0, BROADCAST_LEN) + + +def broadcast_both( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, +) -> tuple[pl.Series, pl.Series, pl.Series]: + return ( + l.new_from_index(0, BROADCAST_LEN), + r.new_from_index(0, BROADCAST_LEN), + o.new_from_index(0, BROADCAST_LEN), + ) + + +def broadcast_none( + l: pl.Series, # noqa: E741 + r: pl.Series, + o: pl.Series, +) -> tuple[pl.Series, pl.Series, pl.Series]: + return l, r, o + + +BROADCAST_SERIES_COMBINATIONS = [ + broadcast_left, + broadcast_right, + broadcast_both, + broadcast_none, +] + +EXEC_OP_COMBINATIONS = [ + exec_op_with_series, + exec_op_with_expr, + exec_op_with_expr_no_type_coercion, +] diff --git a/py-polars/tests/unit/operations/map/__init__.py b/py-polars/tests/unit/operations/map/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py new file mode 100644 index 000000000000..61416054755d --- /dev/null +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -0,0 +1,528 @@ +from __future__ import annotations + +import datetime as dt +import json +import math +import re +from datetime import datetime +from functools import partial +from math import cosh +from typing import Any, Callable + +import numpy as np +import pytest + +import polars as pl +from polars._utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser +from polars._utils.various import in_terminal_that_supports_colour +from polars.exceptions import PolarsInefficientMapWarning +from polars.testing import assert_frame_equal, assert_series_equal + +MY_CONSTANT = 3 +MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} +MY_LIST = [1, 2, 3] + +# column_name, function, expected_suggestion +TEST_CASES = [ + # --------------------------------------------- + # numeric expr: math, comparison, logic ops + # --------------------------------------------- + ("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'), + ("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'), + ("a", "lambda x: x & True", 'pl.col("a") & True'), + ("a", "lambda x: x | False", 'pl.col("a") | False'), + ("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'), + ("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'), + ("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'), + ("a", "lambda x: x is None", 'pl.col("a") is None'), + ("a", "lambda x: x is not None", 'pl.col("a") is not None'), + ( + "a", + "lambda x: ((x * -x) ** x) * 1.0", + '((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0', + ), + ( + "a", + "lambda x: 1.0 * (x * (x**x))", + '1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))', + ), + ( + "a", + "lambda x: (x / x) + ((x * x) - x)", + '(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))', + ), + ( + "a", + "lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))", + '(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))', + ), + ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'), + ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'), + ( + "a", + "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0", + 'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)', + ), + ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'), + ( + "a", + "lambda x: (float(x) * int(x)) // 2", + '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', + ), + ( + "a", + "lambda x: 1 / (1 + np.exp(-x))", + '1 / (1 + (-pl.col("a")).exp())', + ), + # --------------------------------------------- + # math module + # --------------------------------------------- + ("e", "lambda x: math.asin(x)", 'pl.col("e").arcsin()'), + ("e", "lambda x: math.asinh(x)", 'pl.col("e").arcsinh()'), + ("e", "lambda x: math.atan(x)", 'pl.col("e").arctan()'), + ("e", "lambda x: math.atanh(x)", 'pl.col("e").arctanh()'), + ("e", "lambda x: math.cos(x)", 'pl.col("e").cos()'), + ("e", "lambda x: math.degrees(x)", 'pl.col("e").degrees()'), + ("e", "lambda x: math.exp(x)", 'pl.col("e").exp()'), + ("e", "lambda x: math.log(x)", 'pl.col("e").log()'), + ("e", "lambda x: math.log10(x)", 'pl.col("e").log10()'), + ("e", "lambda x: math.log1p(x)", 'pl.col("e").log1p()'), + ("e", "lambda x: math.radians(x)", 'pl.col("e").radians()'), + ("e", "lambda x: math.sin(x)", 'pl.col("e").sin()'), + ("e", "lambda x: math.sinh(x)", 'pl.col("e").sinh()'), + ("e", "lambda x: math.sqrt(x)", 'pl.col("e").sqrt()'), + ("e", "lambda x: math.tan(x)", 'pl.col("e").tan()'), + ("e", "lambda x: math.tanh(x)", 'pl.col("e").tanh()'), + # --------------------------------------------- + # numpy module + # --------------------------------------------- + ("e", "lambda x: np.arccos(x)", 'pl.col("e").arccos()'), + ("e", "lambda x: np.arccosh(x)", 'pl.col("e").arccosh()'), + ("e", "lambda x: np.arcsin(x)", 'pl.col("e").arcsin()'), + ("e", "lambda x: np.arcsinh(x)", 'pl.col("e").arcsinh()'), + ("e", "lambda x: np.arctan(x)", 'pl.col("e").arctan()'), + ("e", "lambda x: np.arctanh(x)", 'pl.col("e").arctanh()'), + ("a", "lambda x: 0 + np.cbrt(x)", '0 + pl.col("a").cbrt()'), + ("e", "lambda x: np.ceil(x)", 'pl.col("e").ceil()'), + ("e", "lambda x: np.cos(x)", 'pl.col("e").cos()'), + ("e", "lambda x: np.cosh(x)", 'pl.col("e").cosh()'), + ("e", "lambda x: np.degrees(x)", 'pl.col("e").degrees()'), + ("e", "lambda x: np.exp(x)", 'pl.col("e").exp()'), + ("e", "lambda x: np.floor(x)", 'pl.col("e").floor()'), + ("e", "lambda x: np.log(x)", 'pl.col("e").log()'), + ("e", "lambda x: np.log10(x)", 'pl.col("e").log10()'), + ("e", "lambda x: np.log1p(x)", 'pl.col("e").log1p()'), + ("e", "lambda x: np.radians(x)", 'pl.col("e").radians()'), + ("a", "lambda x: np.sign(x)", 'pl.col("a").sign()'), + ("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'), + ( + "a", # note: functions operate on consts + "lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)", + '(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)', + ), + ("a", "lambda x: np.sinh(x) + 1", 'pl.col("a").sinh() + 1'), + ("a", "lambda x: np.sqrt(x) + 1", 'pl.col("a").sqrt() + 1'), + ("a", "lambda x: np.tan(x) + 1", 'pl.col("a").tan() + 1'), + ("e", "lambda x: np.tanh(x)", 'pl.col("e").tanh()'), + # --------------------------------------------- + # logical 'and/or' (validate nesting levels) + # --------------------------------------------- + ( + "a", + "lambda x: x > 1 or (x == 1 and x == 2)", + '(pl.col("a") > 1) | ((pl.col("a") == 1) & (pl.col("a") == 2))', + ), + ( + "a", + "lambda x: (x > 1 or x == 1) and x == 2", + '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)', + ), + ( + "a", + "lambda x: x > 2 or x != 3 and x not in (0, 1, 4)", + '(pl.col("a") > 2) | ((pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4)))', + ), + ( + "a", + "lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3", + '((pl.col("a") > 1) & (pl.col("a") != 2)) | (((pl.col("a") % 2) == 0) & (pl.col("a") < 3))', + ), + ( + "a", + "lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3", + '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', + ), + # --------------------------------------------- + # string exprs + # --------------------------------------------- + ("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.String).str.to_titlecase()'), + ( + "b", + 'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()', + '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', + ), + ( + "b", + "lambda x: x.strip().startswith('#')", + """pl.col("b").str.strip_chars().str.starts_with('#')""", + ), + ( + "b", + """lambda x: x.rstrip().endswith(('!','#','?','"'))""", + """pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""", + ), + ( + "b", + """lambda x: x.lstrip().startswith(('!','#','?',"'"))""", + """pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""", + ), + ( + "b", + "lambda x: x.replace(':','')", + """pl.col("b").str.replace_all(':','',literal=True)""", + ), + ( + "b", + "lambda x: x.replace(':','',2)", + """pl.col("b").str.replace(':','',n=2,literal=True)""", + ), + ( + "b", + "lambda x: x.removeprefix('A').removesuffix('F')", + """pl.col("b").str.strip_prefix('A').str.strip_suffix('F')""", + ), + ( + "b", + "lambda x: x.zfill(8)", + """pl.col("b").str.zfill(8)""", + ), + # --------------------------------------------- + # json expr: load/extract + # --------------------------------------------- + ("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_decode()'), + # --------------------------------------------- + # replace + # --------------------------------------------- + ("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace_strict(MY_DICT)'), + ( + "a", + "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]", + '(pl.col("a") - 1).replace_strict(MY_DICT) + (1 + pl.col("a")).replace_strict(MY_DICT)', + ), + # --------------------------------------------- + # standard library datetime parsing + # --------------------------------------------- + ( + "d", + 'lambda x: datetime.strptime(x, "%Y-%m-%d")', + 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', + ), + ( + "d", + 'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")', + 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', + ), + # --------------------------------------------- + # temporal attributes/methods + # --------------------------------------------- + ( + "f", + "lambda x: x.isoweekday()", + 'pl.col("f").dt.weekday()', + ), + ( + "f", + "lambda x: x.hour + x.minute + x.second", + '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()', + ), + # --------------------------------------------- + # Bitwise shifts + # --------------------------------------------- + ( + "a", + "lambda x: (3 << (30-x)) & 3", + '(3 * 2**(30 - pl.col("a"))).cast(pl.Int64) & 3', + ), + ( + "a", + "lambda x: (x << 32) & 3", + '(pl.col("a") * 2**32).cast(pl.Int64) & 3', + ), + ( + "a", + "lambda x: ((32-x) >> (3)) & 3", + '((32 - pl.col("a")) / 2**3).cast(pl.Int64) & 3', + ), + ( + "a", + "lambda x: (32 >> (3-x)) & 3", + '(32 / 2**(3 - pl.col("a"))).cast(pl.Int64) & 3', + ), +] + +NOOP_TEST_CASES = [ + "lambda x: x", + "lambda x, y: x + y", + "lambda x: x[0] + 1", + "lambda x: MY_LIST[x]", + "lambda x: MY_DICT[1]", + 'lambda x: "first" if x == 1 else "not first"', + 'lambda x: np.sign(x, casting="unsafe")', +] + +EVAL_ENVIRONMENT = { + "MY_CONSTANT": MY_CONSTANT, + "MY_DICT": MY_DICT, + "MY_LIST": MY_LIST, + "cosh": cosh, + "datetime": datetime, + "dt": dt, + "math": math, + "np": np, + "pl": pl, +} + + +@pytest.mark.parametrize( + "func", + NOOP_TEST_CASES, +) +def test_parse_invalid_function(func: str) -> None: + # functions we don't (yet?) offer suggestions for + parser = BytecodeParser(eval(func), map_target="expr") + assert not parser.can_attempt_rewrite() or not parser.to_expression("x") + + +@pytest.mark.parametrize( + ("col", "func", "expr_repr"), + TEST_CASES, +) +@pytest.mark.filterwarnings( + "ignore:invalid value encountered:RuntimeWarning", + "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning", +) +def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Expr\.map_elements.*with this one instead", + ): + parser = BytecodeParser(eval(func), map_target="expr") + suggested_expression = parser.to_expression(col) + assert suggested_expression == expr_repr + + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["AB", "cd", "eF"], + "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'], + "d": ["2020-01-01", "2020-01-02", "2020-01-03"], + "e": [0.5, 0.4, 0.1], + "f": [ + datetime(1969, 12, 31), + datetime(2024, 5, 6), + datetime(2077, 10, 20), + ], + } + ) + + result_frame = df.select( + x=col, + y=eval(suggested_expression, EVAL_ENVIRONMENT), + ) + expected_frame = df.select( + x=pl.col(col), + y=pl.col(col).map_elements(eval(func)), + ) + assert_frame_equal( + result_frame, + expected_frame, + check_dtypes=(".dt." not in suggested_expression), + ) + + +@pytest.mark.filterwarnings( + "ignore:invalid value encountered:RuntimeWarning", + "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning", +) +def test_parse_apply_raw_functions() -> None: + lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]}) + + # test bare 'numpy' functions + for func_name in _NUMPY_FUNCTIONS: + func = getattr(np, func_name) + + # note: we can't parse/rewrite raw numpy functions... + parser = BytecodeParser(func, map_target="expr") + assert not parser.can_attempt_rewrite() + + # ...but we ARE still able to warn + with pytest.warns( + PolarsInefficientMapWarning, + match=rf"(?s)Expr\.map_elements.*Replace this expression.*np\.{func_name}", + ): + df1 = lf.select(pl.col("a").map_elements(func)).collect() + df2 = lf.select(getattr(pl.col("a"), func_name)()).collect() + assert_frame_equal(df1, df2) + + # test bare 'json.loads' + result_frames = [] + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Expr\.map_elements.*with this one instead:.*\.str\.json_decode", + ): + for expr in ( + pl.col("value").str.json_decode(), + pl.col("value").map_elements(json.loads), + ): + result_frames.append( # noqa: PERF401 + pl.LazyFrame({"value": ['{"a":1, "b": true, "c": "xx"}', None]}) + .select(extracted=expr) + .unnest("extracted") + .collect() + ) + + assert_frame_equal(*result_frames) + + # test primitive python casts + for py_cast, pl_dtype in ((str, pl.String), (int, pl.Int64), (float, pl.Float64)): + with pytest.warns( + PolarsInefficientMapWarning, + match=rf'(?s)with this one instead.*pl\.col\("a"\)\.cast\(pl\.{pl_dtype.__name__}\)', + ): + assert_frame_equal( + lf.select(pl.col("a").map_elements(py_cast)).collect(), + lf.select(pl.col("a").cast(pl_dtype)).collect(), + ) + + +def test_parse_apply_miscellaneous() -> None: + # note: can also identify inefficient functions and methods as well as lambdas + class Test: + def x10(self, x: float) -> float: + return x * 10 + + def mcosh(self, x: float) -> float: + return cosh(x) + + parser = BytecodeParser(Test().x10, map_target="expr") + suggested_expression = parser.to_expression(col="colx") + assert suggested_expression == 'pl.col("colx") * 10' + + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Series\.map_elements.*with this one instead.*s\.cosh\(\)", + ): + pl.Series("colx", [0.5, 0.25]).map_elements( + function=Test().mcosh, + return_dtype=pl.Float64, + ) + + # note: all constants - should not create a warning/suggestion + suggested_expression = BytecodeParser( + lambda x: MY_CONSTANT + 42, map_target="expr" + ).to_expression(col="colx") + assert suggested_expression is None + + # literals as method parameters + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Series\.map_elements.*with this one instead.*\(np\.cos\(3\) \+ s\) - abs\(-1\)", + ): + s = pl.Series("srs", [0, 1, 2, 3, 4]) + assert_series_equal( + s.map_elements(lambda x: np.cos(3) + x - abs(-1), return_dtype=pl.Float64), + np.cos(3) + s - 1, + ) + + # if 's' is already the name of a global variable then the series alias + # used in the user warning will fall back (in priority order) through + # various aliases until it finds one that is available. + s, srs, series = -1, 0, 1 # type: ignore[assignment] + expr1 = BytecodeParser(lambda x: x + s, map_target="series") + expr2 = BytecodeParser(lambda x: srs + x + s, map_target="series") + expr3 = BytecodeParser(lambda x: srs + x + s - x + series, map_target="series") + + assert expr1.to_expression(col="srs") == "srs + s" + assert expr2.to_expression(col="srs") == "(srs + series) + s" + assert expr3.to_expression(col="srs") == "(((srs + srs0) + s) - srs0) + series" + + +@pytest.mark.parametrize( + ("name", "data", "func", "expr_repr"), + [ + ( + "srs", + [1, 2, 3], + lambda x: str(x), + "s.cast(pl.String)", + ), + ( + "", + [-20, -12, -5, 0, 5, 12, 20], + lambda x: (abs(x) != 12) and (x > 10 or x < -10 or x == 0), + "(s.abs() != 12) & ((s > 10) | (s < -10) | (s == 0))", + ), + ], +) +@pytest.mark.filterwarnings( + "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning" +) +def test_parse_apply_series( + name: str, data: list[Any], func: Callable[[Any], Any], expr_repr: str +) -> None: + # expression/series generate same warning, with 's' as the series placeholder + with pytest.warns( + PolarsInefficientMapWarning, match=r"(?s)Series\.map_elements.*s\.\w+\(" + ): + s = pl.Series(name, data) + + parser = BytecodeParser(func, map_target="series") + suggested_expression = parser.to_expression(s.name) + assert suggested_expression == expr_repr + + expected_series = s.map_elements(func) + result_series = eval(suggested_expression) + assert_series_equal(expected_series, result_series) + + +def test_expr_exact_warning_message() -> None: + red, green, end_escape = ( + ("\x1b[31m", "\x1b[32m", "\x1b[0m") + if in_terminal_that_supports_colour() + else ("", "", "") + ) + msg = re.escape( + "\n" + "Expr.map_elements is significantly slower than the native expressions API.\n" + "Only use if you absolutely CANNOT implement your logic otherwise.\n" + "Replace this expression...\n" + f' {red}- pl.col("a").map_elements(lambda x: ...){end_escape}\n' + "with this one instead:\n" + f' {green}+ pl.col("a") + 1{end_escape}\n' + ) + # Check the EXACT warning message. If modifying the message in the future, + # make sure to keep the `^` and `$`, and keep the assertion on `len(warnings)`. + with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings: + df = pl.DataFrame({"a": [1, 2, 3]}) + df.select(pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)) + + assert len(warnings) == 1 + + +def test_omit_implicit_bool() -> None: + parser = BytecodeParser( + function=lambda x: x and x and x.date(), + map_target="expr", + ) + suggested_expression = parser.to_expression("d") + assert suggested_expression == 'pl.col("d").dt.date()' + + +def test_partial_functions_13523() -> None: + def plus(value: int, amount: int) -> int: + return value + amount + + data = {"a": [1, 2], "b": [3, 4]} + df = pl.DataFrame(data) + # should not warn + _ = df["a"].map_elements(partial(plus, amount=1)) diff --git a/py-polars/tests/unit/operations/map/test_map_batches.py b/py-polars/tests/unit/operations/map/test_map_batches.py new file mode 100644 index 000000000000..d0013eeb4f60 --- /dev/null +++ b/py-polars/tests/unit/operations/map/test_map_batches.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from functools import reduce + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal + + +def test_map_return_py_object() -> None: + df = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + result = df.select([pl.all().map_batches(lambda s: reduce(lambda a, b: a + b, s))]) + + expected = pl.DataFrame({"A": [6], "B": [15]}) + assert_frame_equal(result, expected) + + +def test_map_no_dtype_set_8531() -> None: + df = pl.DataFrame({"a": [1]}) + + result = df.with_columns( + pl.col("a").map_batches(lambda x: x * 2).shift(n=0, fill_value=0) + ) + + expected = pl.DataFrame({"a": [2]}) + assert_frame_equal(result, expected) + + +def test_error_on_reducing_map() -> None: + df = pl.DataFrame( + {"id": [0, 0, 0, 1, 1, 1], "t": [2, 4, 5, 10, 11, 14], "y": [0, 1, 1, 2, 3, 4]} + ) + with pytest.raises( + InvalidOperationError, + match=( + r"output length of `map` \(1\) must be equal to " + r"the input length \(6\); consider using `apply` instead" + ), + ): + df.group_by("id").agg(pl.map_batches(["t", "y"], np.mean)) + + df = pl.DataFrame({"x": [1, 2, 3, 4], "group": [1, 2, 1, 2]}) + + with pytest.raises( + InvalidOperationError, + match=( + r"output length of `map` \(1\) must be equal to " + r"the input length \(4\); consider using `apply` instead" + ), + ): + df.select( + pl.col("x") + .map_batches( + lambda x: x.cut(breaks=[1, 2, 3], include_breaks=True).struct.unnest(), + is_elementwise=True, + ) + .over("group") + ) + + +def test_map_batches_group() -> None: + df = pl.DataFrame( + {"id": [0, 0, 0, 1, 1, 1], "t": [2, 4, 5, 10, 11, 14], "y": [0, 1, 1, 2, 3, 4]} + ) + assert df.group_by("id").agg(pl.col("t").map_batches(lambda s: s.sum())).sort( + "id" + ).to_dict(as_series=False) == {"id": [0, 1], "t": [[11], [35]]} + # If returns_scalar is True, the result won't be wrapped in a list: + assert df.group_by("id").agg( + pl.col("t").map_batches(lambda s: s.sum(), returns_scalar=True) + ).sort("id").to_dict(as_series=False) == {"id": [0, 1], "t": [11, 35]} + + +def test_ufunc_args() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]}) + result = df.select( + z=np.add(pl.col("a"), pl.col("b")) # type: ignore[call-overload] + ) + expected = pl.DataFrame({"z": [3, 6, 9]}) + assert_frame_equal(result, expected) + result = df.select(z=np.add(2, pl.col("a"))) # type: ignore[call-overload] + expected = pl.DataFrame({"z": [3, 4, 5]}) + assert_frame_equal(result, expected) + + +def test_lazy_map_schema() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + # identity + assert_frame_equal(df.lazy().map_batches(lambda x: x).collect(), df) + + def custom(df: pl.DataFrame) -> pl.Series: + return df["a"] + + with pytest.raises( + ComputeError, + match="Expected 'LazyFrame.map' to return a 'DataFrame', got a", + ): + df.lazy().map_batches(custom).collect() # type: ignore[arg-type] + + def custom2( + df: pl.DataFrame, + ) -> pl.DataFrame: + # changes schema + return df.select(pl.all().cast(pl.String)) + + with pytest.raises( + ComputeError, + match="The output schema of 'LazyFrame.map' is incorrect. Expected", + ): + df.lazy().map_batches(custom2).collect() + + assert df.lazy().map_batches( + custom2, validate_output_schema=False + ).collect().to_dict(as_series=False) == {"a": ["1", "2", "3"], "b": ["a", "b", "c"]} + + +def test_map_batches_collect_schema_17327() -> None: + df = pl.LazyFrame({"a": [1, 1, 1], "b": [2, 3, 4]}) + q = df.group_by("a").agg(pl.col("b").map_batches(lambda s: s)) + expected = pl.Schema({"a": pl.Int64(), "b": pl.List(pl.Unknown)}) + assert q.collect_schema() == expected diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py new file mode 100644 index 000000000000..ea7e1b9405cb --- /dev/null +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +import json +from datetime import date, datetime, timedelta +from typing import Any, NamedTuple + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import PolarsInefficientMapWarning +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import NUMERIC_DTYPES, TEMPORAL_DTYPES + +pytestmark = pytest.mark.filterwarnings( + "ignore::polars.exceptions.PolarsInefficientMapWarning" +) + + +def test_map_elements_infer_list() -> None: + df = pl.DataFrame( + { + "int": [1, 2], + "str": ["a", "b"], + "bool": [True, None], + } + ) + assert df.select([pl.all().map_elements(lambda x: [x])]).dtypes == [pl.List] * 3 + + +def test_map_elements_upcast_null_dtype_empty_list() -> None: + df = pl.DataFrame({"a": [1, 2]}) + out = df.select( + pl.col("a").map_elements(lambda _: [], return_dtype=pl.List(pl.Int64)) + ) + assert_frame_equal( + out, pl.DataFrame({"a": [[], []]}, schema={"a": pl.List(pl.Int64)}) + ) + + +def test_map_elements_arithmetic_consistency() -> None: + df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]}) + with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"): + assert df.group_by("A").agg( + pl.col("B").map_elements( + lambda x: x + 1.0, return_dtype=pl.List(pl.Float64) + ) + )["B"].to_list() == [[3.0, 4.0]] + + +def test_map_elements_struct() -> None: + df = pl.DataFrame( + { + "A": ["a", "a", None], + "B": [2, 3, None], + "C": [True, False, None], + "D": [12.0, None, None], + "E": [None, [1], [2, 3]], + } + ) + out = df.with_columns(pl.struct(df.columns).alias("struct")).select( + pl.col("struct").map_elements(lambda x: x["A"]).alias("A_field"), + pl.col("struct").map_elements(lambda x: x["B"]).alias("B_field"), + pl.col("struct").map_elements(lambda x: x["C"]).alias("C_field"), + pl.col("struct").map_elements(lambda x: x["D"]).alias("D_field"), + pl.col("struct").map_elements(lambda x: x["E"]).alias("E_field"), + ) + expected = pl.DataFrame( + { + "A_field": ["a", "a", None], + "B_field": [2, 3, None], + "C_field": [True, False, None], + "D_field": [12.0, None, None], + "E_field": [None, [1], [2, 3]], + } + ) + + assert_frame_equal(out, expected) + + +def test_map_elements_numpy_int_out() -> None: + df = pl.DataFrame({"col1": [2, 4, 8, 16]}) + result = df.with_columns( + pl.col("col1").map_elements(lambda x: np.left_shift(x, 8)).alias("result") + ) + expected = pl.DataFrame({"col1": [2, 4, 8, 16], "result": [512, 1024, 2048, 4096]}) + assert_frame_equal(result, expected) + + df = pl.DataFrame({"col1": [2, 4, 8, 16], "shift": [1, 1, 2, 2]}) + result = df.select( + pl.struct(["col1", "shift"]) + .map_elements(lambda cols: np.left_shift(cols["col1"], cols["shift"])) + .alias("result") + ) + expected = pl.DataFrame({"result": [4, 8, 32, 64]}) + assert_frame_equal(result, expected) + + +def test_datelike_identity() -> None: + for s in [ + pl.Series([datetime(year=2000, month=1, day=1)]), + pl.Series([timedelta(hours=2)]), + pl.Series([date(year=2000, month=1, day=1)]), + ]: + assert s.map_elements(lambda x: x).to_list() == s.to_list() + + +def test_map_elements_list_any_value_fallback() -> None: + with pytest.warns( + PolarsInefficientMapWarning, + match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()', + ): + df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']}) + assert df.select( + pl.col("text").map_elements( + json.loads, + return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})), + ) + ).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]} + + # starts with empty list '[]' + df = pl.DataFrame( + { + "text": [ + "[]", + '[{"x": 1, "y": 2}, {"x": 3, "y": 4}]', + '[{"x": 1, "y": 2}]', + ] + } + ) + assert df.select( + pl.col("text").map_elements( + json.loads, + return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})), + ) + ).to_dict(as_series=False) == { + "text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]] + } + + +def test_map_elements_all_types() -> None: + # test we don't panic + dtypes = NUMERIC_DTYPES + TEMPORAL_DTYPES + [pl.Decimal(None, 2)] + for dtype in dtypes: + pl.Series([1, 2, 3, 4, 5], dtype=dtype).map_elements(lambda x: x) + + +def test_map_elements_type_propagation() -> None: + assert ( + pl.from_dict( + { + "a": [1, 2, 3], + "b": [{"c": 1, "d": 2}, {"c": 2, "d": 3}, {"c": None, "d": None}], + } + ) + .group_by("a", maintain_order=True) + .agg( + [ + pl.when(~pl.col("b").has_nulls()) + .then( + pl.col("b").map_elements( + lambda s: s[0]["c"], + return_dtype=pl.Float64, + ) + ) + .otherwise(None) + ] + ) + ).to_dict(as_series=False) == {"a": [1, 2, 3], "b": [1.0, 2.0, None]} + + +def test_empty_list_in_map_elements() -> None: + df = pl.DataFrame( + {"a": [[1], [1, 2], [3, 4], [5, 6]], "b": [[3], [1, 2], [1, 2], [4, 5]]} + ) + + assert df.select( + pl.struct(["a", "b"]).map_elements( + lambda row: list(set(row["a"]) & set(row["b"])) + ) + ).to_dict(as_series=False) == {"a": [[], [1, 2], [], [5]]} + + +@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}]) +@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}]) +def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None: + s = pl.Series([value, None]) + + result = s.map_elements(lambda x: return_value, skip_nulls=True).to_list() + assert result == [return_value, None] + + result = s.map_elements(lambda x: return_value, skip_nulls=False).to_list() + assert result == [return_value, return_value] + + +def test_map_elements_object_dtypes() -> None: + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Replace this expression.*lambda x:", + ): + assert pl.DataFrame( + {"a": pl.Series([1, 2, "a", 4, 5], dtype=pl.Object)} + ).with_columns( + pl.col("a").map_elements(lambda x: x * 2, return_dtype=pl.Object), + pl.col("a") + .map_elements( + lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean + ) + .alias("is_numeric1"), + pl.col("a") + .map_elements( + lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean + ) + .alias("is_numeric_infer"), + ).to_dict(as_series=False) == { + "a": [2, 4, "aa", 8, 10], + "is_numeric1": [True, True, False, True, True], + "is_numeric_infer": [True, True, False, True, True], + } + + +def test_map_elements_explicit_list_output_type() -> None: + out = pl.DataFrame({"str": ["a", "b"]}).with_columns( + pl.col("str").map_elements( + lambda _: pl.Series([1, 2, 3]), return_dtype=pl.List(pl.Int64) + ) + ) + + assert out.dtypes == [pl.List(pl.Int64)] + assert out.to_dict(as_series=False) == {"str": [[1, 2, 3], [1, 2, 3]]} + + +def test_map_elements_dict() -> None: + with pytest.warns( + PolarsInefficientMapWarning, + match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()', + ): + df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']}) + assert df.select( + pl.col("abc").map_elements( + json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String}) + ) + ).to_dict(as_series=False) == { + "abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}] + } + assert pl.DataFrame( + {"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']} + ).select( + pl.col("abc").map_elements( + json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String}) + ) + ).to_dict(as_series=False) == { + "abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}] + } + + +def test_map_elements_pass_name() -> None: + df = pl.DataFrame( + { + "bar": [1, 1, 2], + "foo": [1, 2, 3], + } + ) + + mapper = {"foo": "foo1"} + + def element_mapper(s: pl.Series) -> pl.Series: + return pl.Series([mapper[s.name]]) + + assert df.group_by("bar", maintain_order=True).agg( + pl.col("foo").map_elements(element_mapper, pass_name=True), + ).to_dict(as_series=False) == {"bar": [1, 2], "foo": [["foo1"], ["foo1"]]} + + +def test_map_elements_binary() -> None: + assert pl.DataFrame({"bin": [b"\x11" * 12, b"\x22" * 12, b"\xaa" * 12]}).select( + pl.col("bin").map_elements(bytes.hex) + ).to_dict(as_series=False) == { + "bin": [ + "111111111111111111111111", + "222222222222222222222222", + "aaaaaaaaaaaaaaaaaaaaaaaa", + ] + } + + +def test_map_elements_set_datetime_output_8984() -> None: + df = pl.DataFrame({"a": [""]}) + payload = datetime(2001, 1, 1) + assert df.select( + pl.col("a").map_elements(lambda _: payload, return_dtype=pl.Datetime), + )["a"].to_list() == [payload] + + +def test_map_elements_dict_order_10128() -> None: + df = pl.select(pl.lit("").map_elements(lambda x: {"c": 1, "b": 2, "a": 3})) + assert df.to_dict(as_series=False) == {"literal": [{"c": 1, "b": 2, "a": 3}]} + + +def test_map_elements_10237() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + assert ( + df.select(pl.all().map_elements(lambda x: x > 50))["a"].to_list() == [False] * 3 + ) + + +def test_map_elements_on_empty_col_10639() -> None: + df = pl.DataFrame({"A": [], "B": []}, schema={"A": pl.Float32, "B": pl.Float32}) + res = df.group_by("B").agg( + pl.col("A") + .map_elements(lambda x: x, return_dtype=pl.Int32, strategy="threading") + .alias("Foo") + ) + assert res.to_dict(as_series=False) == { + "B": [], + "Foo": [], + } + + res = df.group_by("B").agg( + pl.col("A") + .map_elements(lambda x: x, return_dtype=pl.Int32, strategy="thread_local") + .alias("Foo") + ) + assert res.to_dict(as_series=False) == { + "B": [], + "Foo": [], + } + + +def test_map_elements_chunked_14390() -> None: + s = pl.concat(2 * [pl.Series([1])], rechunk=False) + assert s.n_chunks() > 1 + with pytest.warns(PolarsInefficientMapWarning): + assert_series_equal( + s.map_elements(str, return_dtype=pl.String), + pl.Series(["1", "1"]), + check_names=False, + ) + + +def test_cabbage_strategy_14396() -> None: + df = pl.DataFrame({"x": [1, 2, 3]}) + with ( + pytest.raises(ValueError, match="strategy 'cabbage' is not supported"), + pytest.warns(PolarsInefficientMapWarning), + ): + df.select(pl.col("x").map_elements(lambda x: 2 * x, strategy="cabbage")) # type: ignore[arg-type] + + +def test_unknown_map_elements() -> None: + df = pl.DataFrame( + { + "Amount": [10, 1, 1, 5], + "Flour": ["1000g", "100g", "50g", "75g"], + } + ) + + q = df.lazy().select( + pl.col("Amount"), + pl.col("Flour").map_elements(lambda x: 100.0) / pl.col("Amount"), + ) + + assert q.collect().to_dict(as_series=False) == { + "Amount": [10, 1, 1, 5], + "Flour": [10.0, 100.0, 100.0, 20.0], + } + assert q.collect_schema().dtypes() == [pl.Int64, pl.Unknown] + + +def test_map_elements_list_dtype_18472() -> None: + s = pl.Series([[None], ["abc ", None]]) + result = s.map_elements(lambda s: [i.strip() if i else None for i in s]) + expected = pl.Series([[None], ["abc", None]]) + assert_series_equal(result, expected) + + +def test_map_elements_list_return_dtype() -> None: + s = pl.Series([[1], [2, 3]]) + return_dtype = pl.List(pl.UInt16) + + result = s.map_elements( + lambda s: [i + 1 for i in s], + return_dtype=return_dtype, + ) + expected = pl.Series([[2], [3, 4]], dtype=return_dtype) + assert_series_equal(result, expected) + + +def test_map_elements_list_of_named_tuple_15425() -> None: + class Foo(NamedTuple): + x: int + + df = pl.DataFrame({"a": [0, 1, 2]}) + result = df.select( + pl.col("a").map_elements( + lambda x: [Foo(i) for i in range(x)], + return_dtype=pl.List(pl.Struct({"x": pl.Int64})), + ) + ) + expected = pl.DataFrame({"a": [[], [{"x": 0}], [{"x": 0}, {"x": 1}]]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/map/test_map_groups.py b/py-polars/tests/unit/operations/map/test_map_groups.py new file mode 100644 index 000000000000..5286bcceb1cb --- /dev/null +++ b/py-polars/tests/unit/operations/map/test_map_groups.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def test_map_groups() -> None: + df = pl.DataFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + + result = df.group_by("a").map_groups(lambda df: df[["c"]].sum()) + + expected = pl.DataFrame({"c": [10, 10, 1]}) + assert_frame_equal(result, expected, check_row_order=False) + + +def test_map_groups_lazy() -> None: + lf = pl.LazyFrame({"a": [1, 1, 3], "b": [1.0, 2.0, 3.0]}) + + schema = {"a": pl.Float64, "b": pl.Float64} + result = lf.group_by("a").map_groups(lambda df: df * 2.0, schema=schema) + + expected = pl.LazyFrame({"a": [6.0, 2.0, 2.0], "b": [6.0, 2.0, 4.0]}) + assert_frame_equal(result, expected, check_row_order=False) + assert result.collect_schema() == expected.collect_schema() + + +def test_map_groups_rolling() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [1, 2, 3, 4, 5], + } + ).set_sorted("a") + + def function(df: pl.DataFrame) -> pl.DataFrame: + return df.select( + pl.col("a").min(), + pl.col("b").max(), + ) + + result = df.rolling("a", period="2i").map_groups(function, schema=df.schema) + + expected = pl.DataFrame( + [ + pl.Series("a", [1, 1, 2, 3, 4], dtype=pl.Int64), + pl.Series("b", [1, 2, 3, 4, 5], dtype=pl.Int64), + ] + ) + assert_frame_equal(result, expected) + + +def test_map_groups_empty() -> None: + df = pl.DataFrame(schema={"x": pl.Int64}) + with pytest.raises( + ComputeError, match=r"cannot group_by \+ apply on empty 'DataFrame'" + ): + df.group_by("x").map_groups(lambda x: x) + + +def test_map_groups_none() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 2, 2, 2, 5], + "a": [2, 4, 5, 190, 1, 4, 1], + "b": [1, 3, 2, 1, 43, 3, 1], + } + ) + + out = ( + df.group_by("g", maintain_order=True).agg( + pl.map_groups( + exprs=["a", pl.col("b") ** 4, pl.col("a") / 4], + function=lambda x: x[0] * x[1] + x[2].sum(), + return_dtype=pl.Float64, + returns_scalar=False, + ).alias("multiple") + ) + )["multiple"] + assert out[0].to_list() == [4.75, 326.75, 82.75] + assert out[1].to_list() == [238.75, 3418849.75, 372.75] + + out_df = df.select(pl.map_batches(exprs=["a", "b"], function=lambda s: s[0] * s[1])) + assert out_df["a"].to_list() == (df["a"] * df["b"]).to_list() + + # check if we can return None + def func(s: Sequence[pl.Series]) -> pl.Series | None: + if s[0][0] == 190: + return None + else: + return s[0] + + out = ( + df.group_by("g", maintain_order=True).agg( + pl.map_groups( + exprs=["a", pl.col("b") ** 4, pl.col("a") / 4], + function=func, + returns_scalar=False, + ).alias("multiple") + ) + )["multiple"] + assert out[1] is None + + +def test_map_groups_object_output() -> None: + df = pl.DataFrame( + { + "names": ["foo", "ham", "spam", "cheese", "egg", "foo"], + "dates": ["1", "1", "2", "3", "3", "4"], + "groups": ["A", "A", "B", "B", "B", "C"], + } + ) + + class Foo: + def __init__(self, payload: Any) -> None: + self.payload = payload + + result = df.group_by("groups").agg( + pl.map_groups( + [pl.col("dates"), pl.col("names")], + lambda s: Foo(dict(zip(s[0], s[1]))), + return_dtype=pl.Object, + returns_scalar=True, + ) + ) + + assert result.dtypes == [pl.String, pl.Object] + + +def test_map_groups_numpy_output_3057() -> None: + df = pl.DataFrame( + { + "id": [0, 0, 0, 1, 1, 1], + "t": [2.0, 4.3, 5, 10, 11, 14], + "y": [0.0, 1, 1.3, 2, 3, 4], + } + ) + + result = df.group_by("id", maintain_order=True).agg( + pl.map_groups( + ["y", "t"], lambda lst: np.mean([lst[0], lst[1]]), returns_scalar=True + ).alias("result") + ) + + expected = pl.DataFrame({"id": [0, 1], "result": [2.266666, 7.333333]}) + assert_frame_equal(result, expected) + + +def test_map_groups_return_all_null_15260() -> None: + def foo(x: pl.Series) -> pl.Series: + return pl.Series([x[0][0]], dtype=x[0].dtype) + + assert_frame_equal( + pl.DataFrame({"key": [0, 0, 1], "a": [None, None, None]}) + .group_by("key") + .agg(pl.map_groups(exprs=["a"], function=foo, returns_scalar=True)) # type: ignore[arg-type] + .sort("key"), + pl.DataFrame({"key": [0, 1], "a": [None, None]}), + ) diff --git a/py-polars/tests/unit/operations/map/test_map_rows.py b/py-polars/tests/unit/operations/map/test_map_rows.py new file mode 100644 index 000000000000..b890a42320e7 --- /dev/null +++ b/py-polars/tests/unit/operations/map/test_map_rows.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + + +def test_map_rows() -> None: + df = pl.DataFrame({"a": ["foo", "2"], "b": [1, 2], "c": [1.0, 2.0]}) + + result = df.map_rows(lambda x: len(x), None) + + expected = pl.DataFrame({"map": [3, 3]}) + assert_frame_equal(result, expected) + + +def test_map_rows_list_return() -> None: + df = pl.DataFrame({"start": [1, 2], "end": [3, 5]}) + + result = df.map_rows(lambda r: pl.Series(range(r[0], r[1] + 1))) + + expected = pl.DataFrame({"map": [[1, 2, 3], [2, 3, 4, 5]]}) + assert_frame_equal(result, expected) + + +def test_map_rows_dataframe_return() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["c", "d", None]}) + + result = df.map_rows(lambda row: (row[0] * 10, "foo", True, row[-1])) + + expected = pl.DataFrame( + { + "column_0": [10, 20, 30], + "column_1": ["foo", "foo", "foo"], + "column_2": [True, True, True], + "column_3": ["c", "d", None], + } + ) + assert_frame_equal(result, expected) + + +def test_map_rows_error_return_type() -> None: + df = pl.DataFrame({"a": [[1, 2], [2, 3]], "b": [[4, 5], [6, 7]]}) + + def combine(row: tuple[Any, ...]) -> list[Any]: + res = [x + y for x, y in zip(row[0], row[1])] + return [res] + + with pytest.raises(ComputeError, match="expected tuple, got list"): + df.map_rows(combine) + + +def test_map_rows_shifted_chunks() -> None: + df = pl.DataFrame(pl.Series("texts", ["test", "test123", "tests"])) + df = df.select(pl.col("texts"), pl.col("texts").shift(1).alias("texts_shifted")) + + result = df.map_rows(lambda x: x) + + expected = pl.DataFrame( + { + "column_0": ["test", "test123", "tests"], + "column_1": [None, "test", "test123"], + } + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/__init__.py b/py-polars/tests/unit/operations/namespaces/__init__.py new file mode 100644 index 000000000000..1c4fcfa24662 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/__init__.py @@ -0,0 +1,11 @@ +""" +Test module containing dedicated tests for all namespace methods. + +Namespace methods are methods that are available on Series and Expr classes for +operations that are only available for specific data types. For example, +`Series.str.to_lowercase()`. + +These methods are almost exclusively implemented as expressions, with the Series method +dispatching to this implementation through a decorator. This means we only need to test +the Series method, as this will indirectly test the Expr method as well. +""" diff --git a/py-polars/tests/unit/operations/namespaces/array/__init__.py b/py-polars/tests/unit/operations/namespaces/array/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/namespaces/array/test_array.py b/py-polars/tests/unit/operations/namespaces/array/test_array.py new file mode 100644 index 000000000000..0e20ffb1158f --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/array/test_array.py @@ -0,0 +1,551 @@ +from __future__ import annotations + +import datetime +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_arr_min_max() -> None: + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + assert s.arr.max().to_list() == [2, 4] + assert s.arr.min().to_list() == [1, 3] + + s_with_null = pl.Series("a", [[None, 2], None, [3, 4]], dtype=pl.Array(pl.Int64, 2)) + assert s_with_null.arr.max().to_list() == [2, None, 4] + assert s_with_null.arr.min().to_list() == [2, None, 3] + + +def test_array_min_max_dtype_12123() -> None: + df = pl.LazyFrame( + [pl.Series("a", [[1.0, 3.0], [2.0, 5.0]]), pl.Series("b", [1.0, 2.0])], + schema_overrides={ + "a": pl.Array(pl.Float64, 2), + }, + ) + + df = df.with_columns( + max=pl.col("a").arr.max().alias("max"), + min=pl.col("a").arr.min().alias("min"), + ) + + assert df.collect_schema() == { + "a": pl.Array(pl.Float64, 2), + "b": pl.Float64, + "max": pl.Float64, + "min": pl.Float64, + } + + out = df.select(pl.col("max") * pl.col("b"), pl.col("min") * pl.col("b")).collect() + + assert_frame_equal(out, pl.DataFrame({"max": [3.0, 10.0], "min": [1.0, 4.0]})) + + +@pytest.mark.parametrize( + ("data", "expected_sum", "dtype"), + [ + ([[1, 2], [4, 3]], [3, 7], pl.Int64), + ([[1, None], [None, 3], [None, None]], [1, 3, 0], pl.Int64), + ([[1.0, 2.0], [4.0, 3.0]], [3.0, 7.0], pl.Float32), + ([[1.0, None], [None, 3.0], [None, None]], [1.0, 3.0, 0], pl.Float32), + ([[True, False], [True, True], [False, False]], [1, 2, 0], pl.Boolean), + ([[True, None], [None, False], [None, None]], [1, 0, 0], pl.Boolean), + ], +) +def test_arr_sum( + data: list[list[Any]], expected_sum: list[Any], dtype: pl.DataType +) -> None: + s = pl.Series("a", data, dtype=pl.Array(dtype, 2)) + assert s.arr.sum().to_list() == expected_sum + + +def test_array_lengths() -> None: + df = pl.DataFrame( + [ + pl.Series("a", [[1, 2, 3]], dtype=pl.Array(pl.Int64, 3)), + pl.Series("b", [[4, 5]], dtype=pl.Array(pl.Int64, 2)), + ] + ) + out = df.select(pl.col("a").arr.len(), pl.col("b").arr.len()) + expected_df = pl.DataFrame( + {"a": [3], "b": [2]}, schema={"a": pl.UInt32, "b": pl.UInt32} + ) + assert_frame_equal(out, expected_df) + + assert pl.Series("a", [[], []], pl.Array(pl.Null, 0)).arr.len().to_list() == [0, 0] + assert pl.Series("a", [None, []], pl.Array(pl.Null, 0)).arr.len().to_list() == [ + None, + 0, + ] + assert pl.Series("a", [None], pl.Array(pl.Null, 0)).arr.len().to_list() == [None] + + assert pl.Series("a", [], pl.Array(pl.Null, 0)).arr.len().to_list() == [] + assert pl.Series("a", [], pl.Array(pl.Null, 1)).arr.len().to_list() == [] + assert pl.Series( + "a", [[1, 2, 3], None, [7, 8, 9]], pl.Array(pl.Int32, 3) + ).arr.len().to_list() == [3, None, 3] + + +def test_arr_unique() -> None: + df = pl.DataFrame( + {"a": pl.Series("a", [[1, 1], [4, 3]], dtype=pl.Array(pl.Int64, 2))} + ) + + out = df.select(pl.col("a").arr.unique(maintain_order=True)) + expected = pl.DataFrame({"a": [[1], [4, 3]]}) + assert_frame_equal(out, expected) + + +def test_array_any_all() -> None: + s = pl.Series( + [[True, True], [False, True], [False, False], [None, None], None], + dtype=pl.Array(pl.Boolean, 2), + ) + + expected_any = pl.Series([True, True, False, False, None]) + assert_series_equal(s.arr.any(), expected_any) + + expected_all = pl.Series([True, False, False, True, None]) + assert_series_equal(s.arr.all(), expected_all) + + s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(pl.Int64, 2)) + with pytest.raises(ComputeError, match="expected boolean elements in array"): + s.arr.any() + with pytest.raises(ComputeError, match="expected boolean elements in array"): + s.arr.all() + + +def test_array_sort() -> None: + s = pl.Series([[2, None, 1], [1, 3, 2]], dtype=pl.Array(pl.UInt32, 3)) + + desc = s.arr.sort(descending=True) + expected = pl.Series([[None, 2, 1], [3, 2, 1]], dtype=pl.Array(pl.UInt32, 3)) + assert_series_equal(desc, expected) + + asc = s.arr.sort(descending=False) + expected = pl.Series([[None, 1, 2], [1, 2, 3]], dtype=pl.Array(pl.UInt32, 3)) + assert_series_equal(asc, expected) + + # test nulls_last + s = pl.Series([[None, 1, 2], [-1, None, 9]], dtype=pl.Array(pl.Int8, 3)) + assert_series_equal( + s.arr.sort(nulls_last=True), + pl.Series([[1, 2, None], [-1, 9, None]], dtype=pl.Array(pl.Int8, 3)), + ) + assert_series_equal( + s.arr.sort(nulls_last=False), + pl.Series([[None, 1, 2], [None, -1, 9]], dtype=pl.Array(pl.Int8, 3)), + ) + + +def test_array_reverse() -> None: + s = pl.Series([[2, None, 1], [1, None, 2]], dtype=pl.Array(pl.UInt32, 3)) + + s = s.arr.reverse() + expected = pl.Series([[1, None, 2], [2, None, 1]], dtype=pl.Array(pl.UInt32, 3)) + assert_series_equal(s, expected) + + +def test_array_arg_min_max() -> None: + s = pl.Series("a", [[1, 2, 4], [3, 2, 1]], dtype=pl.Array(pl.UInt32, 3)) + expected = pl.Series("a", [0, 2], dtype=pl.UInt32) + assert_series_equal(s.arr.arg_min(), expected) + expected = pl.Series("a", [2, 0], dtype=pl.UInt32) + assert_series_equal(s.arr.arg_max(), expected) + + +def test_array_get() -> None: + s = pl.Series( + "a", + [[1, 2, 3, 4], [5, 6, None, None], [7, 8, 9, 10]], + dtype=pl.Array(pl.Int64, 4), + ) + + # Test index literal. + out = s.arr.get(1, null_on_oob=False) + expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Null index literal. + out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None), null_on_oob=False)) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + # Out-of-bounds index literal. + with pytest.raises(ComputeError, match="get index is out of bounds"): + out = s.arr.get(100, null_on_oob=False) + + # Negative index literal. + out = s.arr.get(-2, null_on_oob=False) + expected = pl.Series("a", [3, None, 9], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Test index expr. + with pytest.raises(ComputeError, match="get index is out of bounds"): + out = s.arr.get(pl.Series([1, -2, 100]), null_on_oob=False) + + out = s.arr.get(pl.Series([1, -2, 0]), null_on_oob=False) + expected = pl.Series("a", [2, None, 7], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Test logical type. + s = pl.Series( + "a", + [ + [datetime.date(1999, 1, 1), datetime.date(2000, 1, 1)], + [datetime.date(2001, 10, 1), None], + [None, None], + ], + dtype=pl.Array(pl.Date, 2), + ) + with pytest.raises(ComputeError, match="get index is out of bounds"): + out = s.arr.get(pl.Series([1, -2, 4]), null_on_oob=False) + + +def test_array_get_null_on_oob() -> None: + s = pl.Series( + "a", + [[1, 2, 3, 4], [5, 6, None, None], [7, 8, 9, 10]], + dtype=pl.Array(pl.Int64, 4), + ) + + # Test index literal. + out = s.arr.get(1, null_on_oob=True) + expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Null index literal. + out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None), null_on_oob=True)) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + # Out-of-bounds index literal. + out = s.arr.get(100, null_on_oob=True) + expected = pl.Series("a", [None, None, None], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Negative index literal. + out = s.arr.get(-2, null_on_oob=True) + expected = pl.Series("a", [3, None, 9], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Test index expr. + out = s.arr.get(pl.Series([1, -2, 100]), null_on_oob=True) + expected = pl.Series("a", [2, None, None], dtype=pl.Int64) + assert_series_equal(out, expected) + + # Test logical type. + s = pl.Series( + "a", + [ + [datetime.date(1999, 1, 1), datetime.date(2000, 1, 1)], + [datetime.date(2001, 10, 1), None], + [None, None], + ], + dtype=pl.Array(pl.Date, 2), + ) + out = s.arr.get(pl.Series([1, -2, 4]), null_on_oob=True) + expected = pl.Series( + "a", + [datetime.date(2000, 1, 1), datetime.date(2001, 10, 1), None], + dtype=pl.Date, + ) + assert_series_equal(out, expected) + + +def test_arr_first_last() -> None: + s = pl.Series( + "a", + [[1, 2, 3], [None, 5, 6], [None, None, None]], + dtype=pl.Array(pl.Int64, 3), + ) + + first = s.arr.first() + expected_first = pl.Series( + "a", + [1, None, None], + dtype=pl.Int64, + ) + assert_series_equal(first, expected_first) + + last = s.arr.last() + expected_last = pl.Series( + "a", + [3, 6, None], + dtype=pl.Int64, + ) + assert_series_equal(last, expected_last) + + +@pytest.mark.parametrize( + ("data", "set", "dtype"), + [ + ([1, 2], [[1, 2], [3, 4]], pl.Int64), + ([True, False], [[True, False], [True, True]], pl.Boolean), + (["a", "b"], [["a", "b"], ["c", "d"]], pl.String), + ([b"a", b"b"], [[b"a", b"b"], [b"c", b"d"]], pl.Binary), + ( + [{"a": 1}, {"a": 2}], + [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]], + pl.Struct([pl.Field("a", pl.Int64)]), + ), + ], +) +def test_is_in_array(data: list[Any], set: list[list[Any]], dtype: pl.DataType) -> None: + df = pl.DataFrame( + {"a": data, "arr": set}, + schema={"a": dtype, "arr": pl.Array(dtype, 2)}, + ) + out = df.select(is_in=pl.col("a").is_in(pl.col("arr"))).to_series() + expected = pl.Series("is_in", [True, False]) + assert_series_equal(out, expected) + + +def test_array_join() -> None: + df = pl.DataFrame( + { + "a": [["ab", "c", "d"], ["e", "f", "g"], [None, None, None], None], + "separator": ["&", None, "*", "_"], + }, + schema={ + "a": pl.Array(pl.String, 3), + "separator": pl.String, + }, + ) + out = df.select(pl.col("a").arr.join("-")) + assert out.to_dict(as_series=False) == {"a": ["ab-c-d", "e-f-g", "", None]} + out = df.select(pl.col("a").arr.join(pl.col("separator"))) + assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "", None]} + + # test ignore_nulls argument + df = pl.DataFrame( + { + "a": [ + ["a", None, "b", None], + None, + [None, None, None, None], + ["c", "d", "e", "f"], + ], + "separator": ["-", "&", " ", "@"], + }, + schema={ + "a": pl.Array(pl.String, 4), + "separator": pl.String, + }, + ) + # ignore nulls + out = df.select(pl.col("a").arr.join("-", ignore_nulls=True)) + assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d-e-f"]} + out = df.select(pl.col("a").arr.join(pl.col("separator"), ignore_nulls=True)) + assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d@e@f"]} + # propagate nulls + out = df.select(pl.col("a").arr.join("-", ignore_nulls=False)) + assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d-e-f"]} + out = df.select(pl.col("a").arr.join(pl.col("separator"), ignore_nulls=False)) + assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d@e@f"]} + + +def test_array_explode() -> None: + df = pl.DataFrame( + { + "str": [["a", "b"], ["c", None], None], + "nested": [[[1, 2], [3]], [[], [4, None]], None], + "logical": [ + [datetime.date(1998, 1, 1), datetime.date(2000, 10, 1)], + [datetime.date(2024, 1, 1), None], + None, + ], + }, + schema={ + "str": pl.Array(pl.String, 2), + "nested": pl.Array(pl.List(pl.Int64), 2), + "logical": pl.Array(pl.Date, 2), + }, + ) + out = df.select(pl.all().arr.explode()) + expected = pl.DataFrame( + { + "str": ["a", "b", "c", None, None], + "nested": [[1, 2], [3], [], [4, None], None], + "logical": [ + datetime.date(1998, 1, 1), + datetime.date(2000, 10, 1), + datetime.date(2024, 1, 1), + None, + None, + ], + } + ) + assert_frame_equal(out, expected) + + # test no-null fast path + s = pl.Series( + [ + [datetime.date(1998, 1, 1), datetime.date(1999, 1, 3)], + [datetime.date(2000, 1, 1), datetime.date(2023, 10, 1)], + ], + dtype=pl.Array(pl.Date, 2), + ) + out_s = s.arr.explode() + expected_s = pl.Series( + [ + datetime.date(1998, 1, 1), + datetime.date(1999, 1, 3), + datetime.date(2000, 1, 1), + datetime.date(2023, 10, 1), + ], + dtype=pl.Date, + ) + assert_series_equal(out_s, expected_s) + + +@pytest.mark.parametrize( + ("arr", "data", "expected", "dtype"), + [ + ([[1, 2], [3, None], None], 1, [1, 0, None], pl.Int64), + ([[True, False], [True, None], None], True, [1, 1, None], pl.Boolean), + ([["a", "b"], ["c", None], None], "a", [1, 0, None], pl.String), + ([[b"a", b"b"], [b"c", None], None], b"a", [1, 0, None], pl.Binary), + ], +) +def test_array_count_matches( + arr: list[list[Any] | None], data: Any, expected: list[Any], dtype: pl.DataType +) -> None: + df = pl.DataFrame({"arr": arr}, schema={"arr": pl.Array(dtype, 2)}) + out = df.select(count_matches=pl.col("arr").arr.count_matches(data)) + assert out.to_dict(as_series=False) == {"count_matches": expected} + + +def test_array_count_matches_wildcard_expansion() -> None: + df = pl.DataFrame( + {"a": [[1, 2]], "b": [[3, 4]]}, + schema={"a": pl.Array(pl.Int64, 2), "b": pl.Array(pl.Int64, 2)}, + ) + assert df.select(pl.all().arr.count_matches(3)).to_dict(as_series=False) == { + "a": [0], + "b": [1], + } + + +def test_array_to_struct() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [4, 5, None]]}, schema={"a": pl.Array(pl.Int8, 3)} + ) + assert df.select([pl.col("a").arr.to_struct()]).to_series().to_list() == [ + {"field_0": 1, "field_1": 2, "field_2": 3}, + {"field_0": 4, "field_1": 5, "field_2": None}, + ] + + df = pl.DataFrame( + {"a": [[1, 2, None], [1, 2, 3]]}, schema={"a": pl.Array(pl.Int8, 3)} + ) + assert df.select( + pl.col("a").arr.to_struct(fields=lambda idx: f"col_name_{idx}") + ).to_series().to_list() == [ + {"col_name_0": 1, "col_name_1": 2, "col_name_2": None}, + {"col_name_0": 1, "col_name_1": 2, "col_name_2": 3}, + ] + + assert df.lazy().select(pl.col("a").arr.to_struct()).unnest( + "a" + ).sum().collect().columns == ["field_0", "field_1", "field_2"] + + +def test_array_shift() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], None, [4, 5, 6], [7, 8, 9]], "n": [None, 1, 1, -2]}, + schema={"a": pl.Array(pl.Int64, 3), "n": pl.Int64}, + ) + + out = df.select( + lit=pl.col("a").arr.shift(1), expr=pl.col("a").arr.shift(pl.col("n")) + ) + expected = pl.DataFrame( + { + "lit": [[None, 1, 2], None, [None, 4, 5], [None, 7, 8]], + "expr": [None, None, [None, 4, 5], [9, None, None]], + }, + schema={"lit": pl.Array(pl.Int64, 3), "expr": pl.Array(pl.Int64, 3)}, + ) + assert_frame_equal(out, expected) + + +def test_array_n_unique() -> None: + df = pl.DataFrame( + { + "a": [[1, 1, 2], [3, 3, 3], [None, None, None], None], + }, + schema={"a": pl.Array(pl.Int64, 3)}, + ) + + out = df.select(n_unique=pl.col("a").arr.n_unique()) + expected = pl.DataFrame( + {"n_unique": [2, 1, 1, None]}, schema={"n_unique": pl.UInt32} + ) + assert_frame_equal(out, expected) + + +def test_explode_19049() -> None: + df = pl.DataFrame({"a": [[1, 2, 3]]}, schema={"a": pl.Array(pl.Int64, 3)}) + result_df = df.select(pl.col.a.arr.explode()) + expected_df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64}) + assert_frame_equal(result_df, expected_df) + + df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64}) + with pytest.raises(InvalidOperationError, match="expected Array type, got: i64"): + df.select(pl.col.a.arr.explode()) + + +def test_array_join_unequal_lengths_22018() -> None: + df = pl.DataFrame( + [ + pl.Series( + "a", + [ + ["a", "b", "d"], + ["ya", "x", "y"], + ["ya", "x", "y"], + ], + pl.Array(pl.String, 3), + ), + ] + ) + with pytest.raises(pl.exceptions.ShapeError): + df.select(pl.col.a.arr.join(pl.Series([",", "-"]))) + + +def test_array_shift_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series( + "a", + [ + ["a", "b", "d"], + ["a", "b", "d"], + ["a", "b", "d"], + ], + pl.Array(pl.String, 3), + ).arr.shift(pl.Series([1, 2])) + + +def test_array_shift_self_broadcast_22124() -> None: + assert_series_equal( + pl.Series( + "a", + [ + ["a", "b", "d"], + ], + pl.Array(pl.String, 3), + ).arr.shift(pl.Series([1, 2])), + pl.Series( + "a", + [ + [None, "a", "b"], + [None, None, "a"], + ], + pl.Array(pl.String, 3), + ), + ) diff --git a/py-polars/tests/unit/operations/namespaces/array/test_contains.py b/py-polars/tests/unit/operations/namespaces/array/test_contains.py new file mode 100644 index 000000000000..5c94fd986f3f --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/array/test_contains.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import SchemaError +from polars.testing import assert_series_equal + + +@pytest.mark.parametrize( + ("array", "data", "expected", "dtype"), + [ + ([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64), + ([[True, False], [True, True]], [True, False], [True, False], pl.Boolean), + ([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String), + ([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary), + ( + [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]], + [{"a": 1}, {"a": 2}], + [True, False], + pl.Struct([pl.Field("a", pl.Int64)]), + ), + ], +) +def test_array_contains_expr( + array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType +) -> None: + df = pl.DataFrame( + { + "array": array, + "data": data, + }, + schema={ + "array": pl.Array(dtype, 2), + "data": dtype, + }, + ) + out = df.select(contains=pl.col("array").arr.contains(pl.col("data"))).to_series() + expected_series = pl.Series("contains", expected) + assert_series_equal(out, expected_series) + + +@pytest.mark.parametrize( + ("array", "data", "expected", "dtype"), + [ + ([[1, 2], [3, 4]], 1, [True, False], pl.Int64), + ([[True, False], [True, True]], True, [True, True], pl.Boolean), + ([["a", "b"], ["c", "d"]], "a", [True, False], pl.String), + ([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary), + ], +) +def test_array_contains_literal( + array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType +) -> None: + df = pl.DataFrame( + { + "array": array, + }, + schema={ + "array": pl.Array(dtype, 2), + }, + ) + out = df.select(contains=pl.col("array").arr.contains(data)).to_series() + expected_series = pl.Series("contains", expected) + assert_series_equal(out, expected_series) + + +def test_array_contains_invalid_datatype() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.List(pl.Int8)}) + with pytest.raises(SchemaError, match="invalid series dtype: expected `Array`"): + df.select(pl.col("a").arr.contains(2)) diff --git a/py-polars/tests/unit/operations/namespaces/array/test_to_list.py b/py-polars/tests/unit/operations/namespaces/array/test_to_list.py new file mode 100644 index 000000000000..25b9d8f21bdf --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/array/test_to_list.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_arr_to_list() -> None: + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int8, 2)) + + result = s.arr.to_list() + + expected = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.List(pl.Int8)) + assert_series_equal(result, expected) + + # test logical type + data = {"duration": [[1000, 2000], None]} + df = pl.DataFrame( + data, + schema={ + "duration": pl.Array(pl.Datetime, shape=2), + }, + ).with_columns(pl.col("duration").arr.to_list()) + + expected_df = pl.DataFrame( + data, + schema={ + "duration": pl.List(pl.Datetime), + }, + ) + assert_frame_equal(df, expected_df) + + +def test_arr_to_list_lazy() -> None: + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int8, 2)) + lf = s.to_frame().lazy() + + result = lf.select(pl.col("a").arr.to_list()) + + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.List(pl.Int8)) + expected = s.to_frame().lazy() + assert_frame_equal(result, expected) + + +def test_arr_to_list_nested_array_preserved() -> None: + s = pl.Series( + "a", + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + dtype=pl.Array(pl.Array(pl.Int8, 2), 2), + ) + lf = s.to_frame().lazy() + + result = lf.select(pl.col("a").arr.to_list()) + + s = pl.Series( + "a", + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + ).cast(pl.List(pl.Array(pl.Int8, 2))) + expected = s.to_frame().lazy() + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/conftest.py b/py-polars/tests/unit/operations/namespaces/conftest.py new file mode 100644 index 000000000000..6c017aa02e6e --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/conftest.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + + +@pytest.fixture +def namespace_files_path() -> Path: + return Path(__file__).parent / "files" diff --git a/py-polars/tests/unit/operations/namespaces/files/test_tree_fmt.txt b/py-polars/tests/unit/operations/namespaces/files/test_tree_fmt.txt new file mode 100644 index 000000000000..b9ac79bcb4cf --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/files/test_tree_fmt.txt @@ -0,0 +1,96 @@ +(pl.col("foo") * pl.col("bar")).sum().over("ham", "ham2") / 2 + + 0 1 2 3 4 + ┌─────────────────────────────────────────────────────────────────────────────── + │ + │ ╭───────────╮ + 0 │ │ binary: / │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────────╮ + │ │ │ + │ ╭────────┴────────╮ ╭───┴────╮ + 1 │ │ lit(dyn int: 2) │ │ window │ + │ ╰─────────────────╯ ╰───┬┬───╯ + │ ││ + │ │╰────────────┬──────────────╮ + │ │ │ │ + │ ╭─────┴─────╮ ╭────┴─────╮ ╭──┴──╮ + 2 │ │ col(ham2) │ │ col(ham) │ │ sum │ + │ ╰───────────╯ ╰──────────╯ ╰──┬──╯ + │ │ + │ │ + │ │ + │ ╭─────┴─────╮ + 3 │ │ binary: * │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────╮ + │ │ │ + │ ╭────┴─────╮ ╭────┴─────╮ + 4 │ │ col(bar) │ │ col(foo) │ + │ ╰──────────╯ ╰──────────╯ + +--- +(pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2 + + 0 1 2 3 + ┌──────────────────────────────────────────────────────────────── + │ + │ ╭───────────╮ + 0 │ │ binary: / │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰───────────────╮ + │ │ │ + │ ╭────────┴────────╮ ╭───┴────╮ + 1 │ │ lit(dyn int: 2) │ │ window │ + │ ╰─────────────────╯ ╰───┬┬───╯ + │ ││ + │ │╰─────────────╮ + │ │ │ + │ ╭────┴─────╮ ╭──┴──╮ + 2 │ │ col(ham) │ │ sum │ + │ ╰──────────╯ ╰──┬──╯ + │ │ + │ │ + │ │ + │ ╭─────┴─────╮ + 3 │ │ binary: * │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰────────────╮ + │ │ │ + │ ╭────┴─────╮ ╭────┴─────╮ + 4 │ │ col(bar) │ │ col(foo) │ + │ ╰──────────╯ ╰──────────╯ + +--- +(pl.col("a") + pl.col("b"))**2 + pl.int_range(3) + + 0 1 2 3 4 + ┌────────────────────────────────────────────────────────────────────────────────────────────── + │ + │ ╭───────────╮ + 0 │ │ binary: + │ + │ ╰─────┬┬────╯ + │ ││ + │ │╰──────────────────────────────────────────╮ + │ │ │ + │ ╭──────────┴──────────╮ ╭───────┴───────╮ + 1 │ │ function: int_range │ │ function: pow │ + │ ╰──────────┬┬─────────╯ ╰───────┬┬──────╯ + │ ││ ││ + │ │╰─────────────────────╮ │╰────────────────╮ + │ │ │ │ │ + │ ╭────────┴────────╮ ╭────────┴────────╮ ╭────────┴────────╮ ╭─────┴─────╮ + 2 │ │ lit(dyn int: 3) │ │ lit(dyn int: 0) │ │ lit(dyn int: 2) │ │ binary: + │ + │ ╰─────────────────╯ ╰─────────────────╯ ╰─────────────────╯ ╰─────┬┬────╯ + │ ││ + │ │╰───────────╮ + │ │ │ + │ ╭───┴────╮ ╭───┴────╮ + 3 │ │ col(b) │ │ col(a) │ + │ ╰────────╯ ╰────────╯ + + diff --git a/py-polars/tests/unit/operations/namespaces/list/__init__.py b/py-polars/tests/unit/operations/namespaces/list/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/namespaces/list/test_eval.py b/py-polars/tests/unit/operations/namespaces/list/test_eval.py new file mode 100644 index 000000000000..7432fe413384 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/list/test_eval.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import ( + StructFieldNotFoundError, +) +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_list_eval_dtype_inference() -> None: + grades = pl.DataFrame( + { + "student": ["bas", "laura", "tim", "jenny"], + "arithmetic": [10, 5, 6, 8], + "biology": [4, 6, 2, 7], + "geography": [8, 4, 9, 7], + } + ) + + rank_pct = pl.col("").rank(descending=True) / pl.col("").count().cast(pl.UInt16) + + # the .list.first() would fail if .list.eval did not correctly infer the output type + assert grades.with_columns( + pl.concat_list(pl.all().exclude("student")).alias("all_grades") + ).select( + pl.col("all_grades") + .list.eval(rank_pct, parallel=True) + .alias("grades_rank") + .list.first() + ).to_series().to_list() == [ + 0.3333333333333333, + 0.6666666666666666, + 0.6666666666666666, + 0.3333333333333333, + ] + + +def test_list_eval_categorical() -> None: + df = pl.DataFrame({"test": [["a", None]]}, schema={"test": pl.List(pl.Categorical)}) + df = df.select( + pl.col("test").list.eval(pl.element().filter(pl.element().is_not_null())) + ) + assert_series_equal( + df.get_column("test"), pl.Series("test", [["a"]], dtype=pl.List(pl.Categorical)) + ) + + +def test_list_eval_type_coercion() -> None: + last_non_null_value = pl.element().fill_null(3).last() + df = pl.DataFrame({"array_cols": [[1, None]]}) + + assert df.select( + pl.col("array_cols") + .list.eval(last_non_null_value, parallel=False) + .alias("col_last") + ).to_dict(as_series=False) == {"col_last": [[3]]} + + +def test_list_eval_all_null() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]}).with_columns( + pl.col("bar").cast(pl.List(pl.String)) + ) + + assert df.select(pl.col("bar").list.eval(pl.element())).to_dict( + as_series=False + ) == {"bar": [None, None, None]} + + +def test_empty_eval_dtype_5546() -> None: + # https://github.com/pola-rs/polars/issues/5546 + df = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}]) + + dtype = df.dtypes[0] + + assert ( + df.limit(0).with_columns( + pl.col("a") + .list.eval(pl.element().filter(pl.first().struct.field("name") == 1)) + .alias("a_filtered") + ) + ).dtypes == [dtype, dtype] + + +def test_list_eval_gather_every_13410() -> None: + df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]}) + out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2))) + expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]}) + assert_frame_equal(out, expected) + + +def test_list_eval_err_raise_15653() -> None: + df = pl.DataFrame({"foo": [[]]}) + with pytest.raises(StructFieldNotFoundError): + df.with_columns(bar=pl.col("foo").list.eval(pl.element().struct.field("baz"))) + + +def test_list_eval_type_cast_11188() -> None: + df = pl.DataFrame( + [ + {"a": None}, + ], + schema={"a": pl.List(pl.Int64)}, + ) + assert df.select( + pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str") + ).schema == {"a_str": pl.List(pl.String)} + + +@pytest.mark.parametrize( + "data", + [ + {"a": [["0"], ["1"]]}, + {"a": [["0", "1"], ["2", "3"]]}, + {"a": [["0", "1"]]}, + {"a": [["0"]]}, + ], +) +@pytest.mark.parametrize( + "expr", + [ + pl.lit(""), + pl.format("test: {}", pl.element()), + ], +) +def test_list_eval_list_output_18510(data: dict[str, Any], expr: pl.Expr) -> None: + df = pl.DataFrame(data) + result = df.select(pl.col("a").list.eval(pl.lit(""))) + assert result.to_series().dtype == pl.List(pl.String) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py new file mode 100644 index 000000000000..87f6ee1a063b --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -0,0 +1,1078 @@ +from __future__ import annotations + +import re +from datetime import date, datetime +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ( + ComputeError, + OutOfBoundsError, + SchemaError, +) +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_list_arr_get() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + out = a.list.get(0, null_on_oob=False) + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list[0] + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list.first() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.first()).to_series() + assert_series_equal(out, expected) + + out = a.list.get(-1, null_on_oob=False) + expected = pl.Series("a", [3, 5, 9]) + assert_series_equal(out, expected) + out = a.list.last() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.last()).to_series() + assert_series_equal(out, expected) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + a.list.get(3, null_on_oob=False) + + # Null index. + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=False)) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + a.list.get(-3, null_on_oob=False) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + pl.DataFrame( + {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} + ).with_columns( + pl.col("a").list.get(i, null_on_oob=False).alias(f"get_{i}") + for i in range(4) + ) + + # get by indexes where some are out of bounds + df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + df.select([pl.col("cars").list.get("indexes", null_on_oob=False)]).to_dict( + as_series=False + ) + + # exact on oob boundary + df = pl.DataFrame( + { + "index": [3, 3, 3], + "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]], + } + ) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(3, null_on_oob=False)) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=False)) + + +def test_list_arr_get_null_on_oob() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + out = a.list.get(0, null_on_oob=True) + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list[0] + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list.first() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.first()).to_series() + assert_series_equal(out, expected) + + out = a.list.get(-1, null_on_oob=True) + expected = pl.Series("a", [3, 5, 9]) + assert_series_equal(out, expected) + out = a.list.last() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.last()).to_series() + assert_series_equal(out, expected) + + # Out of bounds index. + out = a.list.get(3, null_on_oob=True) + expected = pl.Series("a", [None, None, 9]) + assert_series_equal(out, expected) + + # Null index. + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=True)) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + out = a.list.get(-3, null_on_oob=True) + expected = pl.Series("a", [1, None, 7]) + assert_series_equal(out, expected) + + assert pl.DataFrame( + {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} + ).with_columns( + pl.col("a").list.get(i, null_on_oob=True).alias(f"get_{i}") for i in range(4) + ).to_dict(as_series=False) == { + "a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]], + "get_0": [1, 2, 3, 4, 7, None], + "get_1": [None, None, None, 5, 8, 11], + "get_2": [None, None, None, 6, 9, None], + "get_3": [None, None, None, None, None, None], + } + + # get by indexes where some are out of bounds + df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) + + assert df.select([pl.col("cars").list.get("indexes", null_on_oob=True)]).to_dict( + as_series=False + ) == {"cars": [2, 3, None, None]} + # exact on oob boundary + df = pl.DataFrame( + { + "index": [3, 3, 3], + "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]], + } + ) + + assert df.select(pl.col("lists").list.get(3, null_on_oob=True)).to_dict( + as_series=False + ) == {"lists": [None, None, 4]} + assert df.select( + pl.col("lists").list.get(pl.col("index"), null_on_oob=True) + ).to_dict(as_series=False) == {"lists": [None, None, 4]} + + +def test_list_categorical_get() -> None: + df = pl.DataFrame( + { + "actions": pl.Series( + [["a", "b"], ["c"], [None], None], dtype=pl.List(pl.Categorical) + ), + } + ) + expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical) + assert_series_equal( + df["actions"].list.get(0, null_on_oob=True), expected, categorical_as_str=True + ) + + +def test_list_gather_wrong_indices_list_type() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + expected = pl.Series("a", [[1, 2], [4], [6, 9]]) + + # int8 + indices_series = pl.Series("indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int8)) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # int16 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int16) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # int32 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int32) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # int64 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int64) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint8 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt8) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint16 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt16) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint32 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt32) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint64 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt64) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + df = pl.DataFrame( + { + "index": [["2"], ["2"], ["2"]], + "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]], + } + ) + with pytest.raises( + ComputeError, match=re.escape("cannot use dtype `list[str]` as an index") + ): + df.select(pl.col("lists").list.gather(pl.col("index"))) + + +def test_contains() -> None: + a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]]) + out = a.list.contains(2) + expected = pl.Series("a", [True, True, False]) + assert_series_equal(out, expected) + + out = pl.select(pl.lit(a).list.contains(2)).to_series() + assert_series_equal(out, expected) + + +def test_list_contains_invalid_datatype() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.Array(pl.Int8, shape=2)}) + with pytest.raises(SchemaError, match="invalid series dtype: expected `List`"): + df.select(pl.col("a").list.contains(2)) + + +def test_list_contains_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in list.contains + # https://github.com/pola-rs/polars/issues/18968 + df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]}) + assert df.select(pl.all().list.contains(3)).to_dict(as_series=False) == { + "a": [False], + "b": [True], + } + + +def test_list_concat() -> None: + df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]}) + + out = df.select([pl.col("a").list.concat(pl.Series([[1, 2]]))]) + assert out["a"][0].to_list() == [1, 2, 1, 2] + + out = df.select([pl.col("a").list.concat([1, 4])]) + assert out["a"][0].to_list() == [1, 2, 1, 4] + + out_s = df["a"].list.concat([4, 1]) + assert out_s[0].to_list() == [1, 2, 4, 1] + + +def test_list_join() -> None: + df = pl.DataFrame( + { + "a": [["ab", "c", "d"], ["e", "f"], ["g"], [], None], + "separator": ["&", None, "*", "_", "*"], + } + ) + out = df.select(pl.col("a").list.join("-")) + assert out.to_dict(as_series=False) == {"a": ["ab-c-d", "e-f", "g", "", None]} + out = df.select(pl.col("a").list.join(pl.col("separator"))) + assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "g", "", None]} + + # test ignore_nulls argument + df = pl.DataFrame( + { + "a": [["a", None, "b", None], None, [None, None], ["c", "d"], []], + "separator": ["-", "&", " ", "@", "/"], + } + ) + # ignore nulls + out = df.select(pl.col("a").list.join("-", ignore_nulls=True)) + assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d", ""]} + out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=True)) + assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d", ""]} + # propagate nulls + out = df.select(pl.col("a").list.join("-", ignore_nulls=False)) + assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d", ""]} + out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=False)) + assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d", ""]} + + +def test_list_arr_empty() -> None: + df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]}) + + out = df.select( + pl.col("cars").list.first().alias("cars_first"), + pl.when(pl.col("cars").list.first() == 2) + .then(1) + .when(pl.col("cars").list.contains(2)) + .then(2) + .otherwise(3) + .alias("cars_literal"), + ) + expected = pl.DataFrame( + {"cars_first": [1, 2, 4, None], "cars_literal": [2, 1, 3, 3]}, + schema_overrides={"cars_literal": pl.Int32}, # Literals default to Int32 + ) + assert_frame_equal(out, expected) + + +def test_list_argminmax() -> None: + s = pl.Series("a", [[1, 2], [3, 2, 1]]) + expected = pl.Series("a", [0, 2], dtype=pl.UInt32) + assert_series_equal(s.list.arg_min(), expected) + expected = pl.Series("a", [1, 0], dtype=pl.UInt32) + assert_series_equal(s.list.arg_max(), expected) + + +def test_list_shift() -> None: + s = pl.Series("a", [[1, 2], [3, 2, 1]]) + expected = pl.Series("a", [[None, 1], [None, 3, 2]]) + assert s.list.shift().to_list() == expected.to_list() + + df = pl.DataFrame( + { + "values": [ + [1, 2, None], + [1, 2, 3], + [None, 1, 2], + [None, None, None], + [1, 2], + ], + "shift": [1, -2, 3, 2, None], + } + ) + df = df.select(pl.col("values").list.shift(pl.col("shift"))) + expected_df = pl.DataFrame( + { + "values": [ + [None, 1, 2], + [3, None, None], + [None, None, None], + [None, None, None], + None, + ] + } + ) + assert_frame_equal(df, expected_df) + + +def test_list_drop_nulls() -> None: + s = pl.Series("values", [[1, None, 2, None], [None, None], [1, 2], None]) + expected = pl.Series("values", [[1, 2], [], [1, 2], None]) + assert_series_equal(s.list.drop_nulls(), expected) + + df = pl.DataFrame({"values": [[None, 1, None, 2], [None], [3, 4]]}) + df = df.select(pl.col("values").list.drop_nulls()) + expected_df = pl.DataFrame({"values": [[1, 2], [], [3, 4]]}) + assert_frame_equal(df, expected_df) + + +def test_list_sample() -> None: + s = pl.Series("values", [[1, 2, 3, None], [None, None], [1, 2], None]) + + expected_sample_n = pl.Series("values", [[3, 1], [None], [2], None]) + assert_series_equal( + s.list.sample(n=pl.Series([2, 1, 1, 1]), seed=1), expected_sample_n + ) + + expected_sample_frac = pl.Series("values", [[3, 1], [None], [1, 2], None]) + assert_series_equal( + s.list.sample(fraction=pl.Series([0.5, 0.5, 1.0, 0.3]), seed=1), + expected_sample_frac, + ) + + df = pl.DataFrame( + { + "values": [[1, 2, 3, None], [None, None], [3, 4]], + "n": [2, 1, 2], + "frac": [0.5, 0.5, 1.0], + } + ) + df = df.select( + sample_n=pl.col("values").list.sample(n=pl.col("n"), seed=1), + sample_frac=pl.col("values").list.sample(fraction=pl.col("frac"), seed=1), + ) + expected_df = pl.DataFrame( + {"sample_n": [[3, 1], [None], [3, 4]], "sample_frac": [[3, 1], [None], [3, 4]]} + ) + assert_frame_equal(df, expected_df) + + +def test_list_diff() -> None: + s = pl.Series("a", [[1, 2], [10, 2, 1]]) + expected = pl.Series("a", [[None, 1], [None, -8, -1]]) + assert s.list.diff().to_list() == expected.to_list() + + +def test_slice() -> None: + vals = [[1, 2, 3, 4], [10, 2, 1]] + s = pl.Series("a", vals) + assert s.list.head(2).to_list() == [[1, 2], [10, 2]] + assert s.list.tail(2).to_list() == [[3, 4], [2, 1]] + assert s.list.tail(200).to_list() == vals + assert s.list.head(200).to_list() == vals + assert s.list.slice(1, 2).to_list() == [[2, 3], [2, 1]] + assert s.list.slice(-5, 2).to_list() == [[1], []] + + +def test_list_ternary_concat() -> None: + df = pl.DataFrame( + { + "list1": [["123", "456"], None], + "list2": [["789"], ["zzz"]], + } + ) + + assert df.with_columns( + pl.when(pl.col("list1").is_null()) + .then(pl.col("list1").list.concat(pl.col("list2"))) + .otherwise(pl.col("list2")) + .alias("result") + ).to_dict(as_series=False) == { + "list1": [["123", "456"], None], + "list2": [["789"], ["zzz"]], + "result": [["789"], None], + } + + assert df.with_columns( + pl.when(pl.col("list1").is_null()) + .then(pl.col("list2")) + .otherwise(pl.col("list1").list.concat(pl.col("list2"))) + .alias("result") + ).to_dict(as_series=False) == { + "list1": [["123", "456"], None], + "list2": [["789"], ["zzz"]], + "result": [["123", "456", "789"], ["zzz"]], + } + + +def test_arr_contains_categorical() -> None: + df = pl.DataFrame( + {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]} + ).lazy() + df = df.with_columns(pl.col("str").cast(pl.Categorical)) + df_groups = df.group_by("group").agg([pl.col("str").alias("str_list")]) + + result = df_groups.filter(pl.col("str_list").list.contains("C")).collect() + expected = {"group": [2], "str_list": [["A", "C"]]} + assert result.to_dict(as_series=False) == expected + + +def test_list_slice() -> None: + df = pl.DataFrame( + { + "lst": [[1, 2, 3, 4], [10, 2, 1]], + "offset": [1, 2], + "len": [3, 2], + } + ) + + assert df.select([pl.col("lst").list.slice("offset", "len")]).to_dict( + as_series=False + ) == {"lst": [[2, 3, 4], [1]]} + assert df.select([pl.col("lst").list.slice("offset", 1)]).to_dict( + as_series=False + ) == {"lst": [[2], [1]]} + assert df.select([pl.col("lst").list.slice(-2, "len")]).to_dict( + as_series=False + ) == {"lst": [[3, 4], [2, 1]]} + + +def test_list_sliced_get_5186() -> None: + # https://github.com/pola-rs/polars/issues/5186 + n = 30 + df = pl.from_dict( + { + "ind": pl.arange(0, n, eager=True), + "inds": pl.Series( + np.stack([np.arange(n), -np.arange(n)], axis=-1), dtype=pl.List + ), + } + ) + + exprs = [ + "ind", + pl.col("inds").list.first().alias("first_element"), + pl.col("inds").list.last().alias("last_element"), + ] + out1 = df.select(exprs)[10:20] + out2 = df[10:20].select(exprs) + assert_frame_equal(out1, out2) + + +def test_list_amortized_apply_explode_5812() -> None: + s = pl.Series([None, [1, 3], [0, -3], [1, 2, 2]]) + assert s.list.sum().to_list() == [None, 4, -3, 5] + assert s.list.min().to_list() == [None, 1, -3, 1] + assert s.list.max().to_list() == [None, 3, 0, 2] + assert s.list.arg_min().to_list() == [None, 0, 1, 0] + assert s.list.arg_max().to_list() == [None, 1, 0, 1] + + +def test_list_slice_5866() -> None: + vals = [[1, 2, 3, 4], [10, 2, 1]] + s = pl.Series("a", vals) + assert s.list.slice(1).to_list() == [[2, 3, 4], [2, 1]] + + +def test_list_gather() -> None: + s = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8]]) + # mypy: we make it work, but idiomatic is `arr.get`. + assert s.list.gather(0).to_list() == [[1], [4], [6]] # type: ignore[arg-type] + assert s.list.gather([0, 1]).to_list() == [[1, 2], [4, 5], [6, 7]] + + assert s.list.gather([-1, 1]).to_list() == [[3, 2], [5, 5], [8, 7]] + + # use another list to make sure negative indices are respected + gatherer = pl.Series([[-1, 1], [-1, 1], [-1, -2]]) + assert s.list.gather(gatherer).to_list() == [[3, 2], [5, 5], [8, 7]] + with pytest.raises(OutOfBoundsError, match=r"gather indices are out of bounds"): + s.list.gather([1, 2]) + s = pl.Series( + [["A", "B", "C"], ["A"], ["B"], ["1", "2"], ["e"]], + ) + + assert s.list.gather([0, 2], null_on_oob=True).to_list() == [ + ["A", "C"], + ["A", None], + ["B", None], + ["1", None], + ["e", None], + ] + assert s.list.gather([0, 1, 2], null_on_oob=True).to_list() == [ + ["A", "B", "C"], + ["A", None, None], + ["B", None, None], + ["1", "2", None], + ["e", None, None], + ] + s = pl.Series([[42, 1, 2], [5, 6, 7]]) + + with pytest.raises(OutOfBoundsError, match=r"gather indices are out of bounds"): + s.list.gather([[0, 1, 2, 3], [0, 1, 2, 3]]) + + assert s.list.gather([0, 1, 2, 3], null_on_oob=True).to_list() == [ + [42, 1, 2, None], + [5, 6, 7, None], + ] + + +def test_list_function_group_awareness() -> None: + df = pl.DataFrame( + { + "a": [100, 103, 105, 106, 105, 104, 103, 106, 100, 102], + "group": [0, 0, 1, 1, 1, 1, 1, 1, 2, 2], + } + ) + + assert df.group_by("group").agg( + [ + pl.col("a").get(0).alias("get_scalar"), + pl.col("a").gather([0]).alias("take_no_implode"), + pl.col("a").implode().list.get(0).alias("implode_get"), + pl.col("a").implode().list.gather([0]).alias("implode_take"), + pl.col("a").implode().list.slice(0, 3).alias("implode_slice"), + ] + ).sort("group").to_dict(as_series=False) == { + "group": [0, 1, 2], + "get_scalar": [100, 105, 100], + "take_no_implode": [[100], [105], [100]], + "implode_get": [100, 105, 100], + "implode_take": [[[100]], [[105]], [[100]]], + "implode_slice": [[100, 103], [105, 106, 105], [100, 102]], + } + + +def test_list_get_logical_types() -> None: + df = pl.DataFrame( + { + "date_col": [[datetime(2023, 2, 1).date(), datetime(2023, 2, 2).date()]], + "datetime_col": [[datetime(2023, 2, 1), datetime(2023, 2, 2)]], + } + ) + + assert df.select(pl.all().list.get(1).name.suffix("_element_1")).to_dict( + as_series=False + ) == { + "date_col_element_1": [date(2023, 2, 2)], + "datetime_col_element_1": [datetime(2023, 2, 2, 0, 0)], + } + + +def test_list_gather_logical_type() -> None: + df = pl.DataFrame( + {"foo": [["foo", "foo", "bar"]], "bar": [[5.0, 10.0, 12.0]]} + ).with_columns(pl.col("foo").cast(pl.List(pl.Categorical))) + + df = pl.concat([df, df], rechunk=False) + assert df.n_chunks() == 2 + assert df.select(pl.all().gather([0, 1])).to_dict(as_series=False) == { + "foo": [["foo", "foo", "bar"], ["foo", "foo", "bar"]], + "bar": [[5.0, 10.0, 12.0], [5.0, 10.0, 12.0]], + } + + +def test_list_unique() -> None: + s = pl.Series([[1, 1, 2, 2, 3], [3, 3, 3, 2, 1, 2]]) + result = s.list.unique(maintain_order=True) + expected = pl.Series([[1, 2, 3], [3, 2, 1]]) + assert_series_equal(result, expected) + + +def test_list_unique2() -> None: + s = pl.Series("a", [[2, 1], [1, 2, 2]]) + result = s.list.unique() + assert len(result) == 2 + assert sorted(result[0]) == [1, 2] + assert sorted(result[1]) == [1, 2] + + +@pytest.mark.may_fail_auto_streaming +def test_list_to_struct() -> None: + df = pl.DataFrame({"n": [[0, 1, 2], [0, 1]]}) + + assert df.select(pl.col("n").list.to_struct(_eager=True)).rows(named=True) == [ + {"n": {"field_0": 0, "field_1": 1, "field_2": 2}}, + {"n": {"field_0": 0, "field_1": 1, "field_2": None}}, + ] + + assert df.select( + pl.col("n").list.to_struct(fields=lambda idx: f"n{idx}", _eager=True) + ).rows(named=True) == [ + {"n": {"n0": 0, "n1": 1, "n2": 2}}, + {"n": {"n0": 0, "n1": 1, "n2": None}}, + ] + + assert df.select(pl.col("n").list.to_struct(fields=["one", "two", "three"])).rows( + named=True + ) == [ + {"n": {"one": 0, "two": 1, "three": 2}}, + {"n": {"one": 0, "two": 1, "three": None}}, + ] + + q = df.lazy().select( + pl.col("n").list.to_struct(fields=["a", "b"]).struct.field("a") + ) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": [0, 0]})) + + # Check that: + # * Specifying an upper bound calls the field name getter function to + # retrieve the lazy schema + # * The upper bound is respected during execution + q = df.lazy().select( + pl.col("n") + .list.to_struct(fields=str, upper_bound=2, _eager=True) + .struct.unnest() + ) + assert q.collect_schema() == {"0": pl.Int64, "1": pl.Int64} + assert_frame_equal(q.collect(), pl.DataFrame({"0": [0, 0], "1": [1, 1]})) + + assert df.lazy().select( + pl.col("n").list.to_struct(_eager=True) + ).collect_schema() == {"n": pl.Unknown} + + +def test_list_to_struct_all_null_12119() -> None: + s = pl.Series([None], dtype=pl.List(pl.Int64)) + result = s.list.to_struct(fields=["a", "b", "c"]).to_list() + assert result == [{"a": None, "b": None, "c": None}] + + +def test_select_from_list_to_struct_11143() -> None: + ldf = pl.LazyFrame({"some_col": [[1.0, 2.0], [1.5, 3.0]]}) + ldf = ldf.select( + pl.col("some_col").list.to_struct(fields=["a", "b"], upper_bound=2) + ) + df = ldf.select(pl.col("some_col").struct.field("a")).collect() + assert df.equals(pl.DataFrame({"a": [1.0, 1.5]})) + + +def test_list_arr_get_8810() -> None: + assert pl.DataFrame(pl.Series("a", [None], pl.List(pl.Int64))).select( + pl.col("a").list.get(0, null_on_oob=True) + ).to_dict(as_series=False) == {"a": [None]} + + +def test_list_tail_underflow_9087() -> None: + assert pl.Series([["a", "b", "c"]]).list.tail(pl.lit(1, pl.UInt32)).to_list() == [ + ["c"] + ] + + +def test_list_count_match_boolean_nulls_9141() -> None: + a = pl.DataFrame({"a": [[True, None, False]]}) + assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1] + + +def test_list_count_match_categorical() -> None: + df = pl.DataFrame( + {"list": [["0"], ["1"], ["1", "2", "3", "2"], ["1", "2", "1"], ["4", "4"]]}, + schema={"list": pl.List(pl.Categorical)}, + ) + assert df.select(pl.col("list").list.count_matches("2").alias("number_of_twos"))[ + "number_of_twos" + ].to_list() == [0, 0, 2, 1, 0] + + +def test_list_count_matches_boolean_nulls_9141() -> None: + a = pl.DataFrame({"a": [[True, None, False]]}) + + assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1] + + +def test_list_count_matches_wildcard_expansion() -> None: + # Test that wildcard expansions occurs correctly in list.count_match + # https://github.com/pola-rs/polars/issues/18968 + df = pl.DataFrame({"a": [[1, 2]], "b": [[3, 4]]}) + assert df.select(pl.all().list.count_matches(3)).to_dict(as_series=False) == { + "a": [0], + "b": [1], + } + + +def test_list_gather_oob_10079() -> None: + df = pl.DataFrame( + { + "a": [[1, 2, 3], [], [None, 3], [5, 6, 7]], + "b": [["2"], ["3"], [None], ["3", "Hi"]], + } + ) + with pytest.raises(OutOfBoundsError, match="gather indices are out of bounds"): + df.select(pl.col("a").gather(999)) + + +def test_utf8_empty_series_arg_min_max_10703() -> None: + res = pl.select(pl.lit(pl.Series("list", [["a"], []]))).with_columns( + pl.all(), + pl.all().list.arg_min().alias("arg_min"), + pl.all().list.arg_max().alias("arg_max"), + ) + assert res.to_dict(as_series=False) == { + "list": [["a"], []], + "arg_min": [0, None], + "arg_max": [0, None], + } + + +def test_list_to_array() -> None: + data = [[1.0, 2.0], [3.0, 4.0]] + s = pl.Series(data, dtype=pl.List(pl.Float32)) + + result = s.list.to_array(2) + result_slice = s[1:].list.to_array(2) + + expected = pl.Series(data, dtype=pl.Array(pl.Float32, 2)) + assert_series_equal(result, expected) + + expected_slice = pl.Series([data[1]], dtype=pl.Array(pl.Float32, 2)) + assert_series_equal(result_slice, expected_slice) + + # test logical type + df = pl.DataFrame( + data={"duration": [[1000, 2000], None]}, + schema={ + "duration": pl.List(pl.Datetime), + }, + ).with_columns(pl.col("duration").list.to_array(2)) + + expected_df = pl.DataFrame( + data={"duration": [[1000, 2000], None]}, + schema={ + "duration": pl.Array(pl.Datetime, 2), + }, + ) + assert_frame_equal(df, expected_df) + + +def test_list_to_array_wrong_lengths() -> None: + s = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float32)) + with pytest.raises(ComputeError, match="not all elements have the specified width"): + s.list.to_array(3) + + +def test_list_to_array_wrong_dtype() -> None: + s = pl.Series([1.0, 2.0]) + with pytest.raises(ComputeError, match="expected List dtype"): + s.list.to_array(2) + + +def test_list_lengths() -> None: + s = pl.Series([[1, 2, None], [5]]) + result = s.list.len() + expected = pl.Series([3, 1], dtype=pl.UInt32) + assert_series_equal(result, expected) + + s = pl.Series("a", [[1, 2], [1, 2, 3]]) + assert_series_equal(s.list.len(), pl.Series("a", [2, 3], dtype=pl.UInt32)) + df = pl.DataFrame([s]) + assert_series_equal( + df.select(pl.col("a").list.len())["a"], pl.Series("a", [2, 3], dtype=pl.UInt32) + ) + + assert_series_equal( + pl.select( + pl.when(pl.Series([True, False])) + .then(pl.Series([[1, 1], [1, 1]])) + .list.len() + ).to_series(), + pl.Series([2, None], dtype=pl.UInt32), + ) + + assert_series_equal( + pl.select( + pl.when(pl.Series([False, False])) + .then(pl.Series([[1, 1], [1, 1]])) + .list.len() + ).to_series(), + pl.Series([None, None], dtype=pl.UInt32), + ) + + +def test_list_arithmetic() -> None: + s = pl.Series("a", [[1, 2], [1, 2, 3]]) + assert_series_equal(s.list.sum(), pl.Series("a", [3, 6])) + assert_series_equal(s.list.mean(), pl.Series("a", [1.5, 2.0])) + assert_series_equal(s.list.max(), pl.Series("a", [2, 3])) + assert_series_equal(s.list.min(), pl.Series("a", [1, 1])) + + +def test_list_ordering() -> None: + s = pl.Series("a", [[2, 1], [1, 3, 2]]) + assert_series_equal(s.list.sort(), pl.Series("a", [[1, 2], [1, 2, 3]])) + assert_series_equal(s.list.reverse(), pl.Series("a", [[1, 2], [2, 3, 1]])) + + # test nulls_last + s = pl.Series([[None, 1, 2], [-1, None, 9]]) + assert_series_equal( + s.list.sort(nulls_last=True), pl.Series([[1, 2, None], [-1, 9, None]]) + ) + assert_series_equal( + s.list.sort(nulls_last=False), pl.Series([[None, 1, 2], [None, -1, 9]]) + ) + + +def test_list_get_logical_type() -> None: + s = pl.Series( + "a", + [ + [date(1999, 1, 1), date(2000, 1, 1)], + [date(2001, 10, 1), None], + ], + dtype=pl.List(pl.Date), + ) + + out = s.list.get(0) + expected = pl.Series( + "a", + [date(1999, 1, 1), date(2001, 10, 1)], + dtype=pl.Date, + ) + assert_series_equal(out, expected) + + out = s.list.get(pl.Series([1, -2])) + expected = pl.Series( + "a", + [date(2000, 1, 1), date(2001, 10, 1)], + dtype=pl.Date, + ) + assert_series_equal(out, expected) + + +def test_list_gather_every() -> None: + df = pl.DataFrame( + { + "lst": [[1, 2, 3], [], [4, 5], None, [6, 7, 8], [9, 10, 11, 12]], + "n": [2, 2, 1, 3, None, 2], + "offset": [None, 1, 0, 1, 2, 2], + } + ) + + out = df.select( + n_expr=pl.col("lst").list.gather_every(pl.col("n"), 0), + offset_expr=pl.col("lst").list.gather_every(2, pl.col("offset")), + all_expr=pl.col("lst").list.gather_every(pl.col("n"), pl.col("offset")), + all_lit=pl.col("lst").list.gather_every(2, 0), + ) + + expected = pl.DataFrame( + { + "n_expr": [[1, 3], [], [4, 5], None, None, [9, 11]], + "offset_expr": [None, [], [4], None, [8], [11]], + "all_expr": [None, [], [4, 5], None, None, [11]], + "all_lit": [[1, 3], [], [4], None, [6, 8], [9, 11]], + } + ) + + assert_frame_equal(out, expected) + + +def test_list_n_unique() -> None: + df = pl.DataFrame( + { + "a": [[1, 1, 2], [3, 3], [None], None, []], + } + ) + + out = df.select(n_unique=pl.col("a").list.n_unique()) + expected = pl.DataFrame( + {"n_unique": [2, 1, 1, None, 0]}, schema={"n_unique": pl.UInt32} + ) + assert_frame_equal(out, expected) + + +def test_list_get_with_null() -> None: + df = pl.DataFrame({"a": [None, [1, 2]], "b": [False, True]}) + + # We allow two layouts of null in ListArray: + # 1. null element are stored as arbitrary values in `value` array. + # 2. null element are not stored in `value` array. + out = df.select( + # For performance reasons, when-then-otherwise produces the list with layout-1. + layout1=pl.when(pl.col("b")).then([1, 2]).list.get(0, null_on_oob=True), + layout2=pl.col("a").list.get(0, null_on_oob=True), + ) + + expected = pl.DataFrame( + { + "layout1": [None, 1], + "layout2": [None, 1], + } + ) + + assert_frame_equal(out, expected) + + +def test_list_sum_bool_schema() -> None: + q = pl.LazyFrame({"x": [[True, True, False]]}) + assert q.select(pl.col("x").list.sum()).collect_schema()["x"] == pl.UInt32 + + +def test_list_concat_struct_19279() -> None: + df = pl.select( + pl.struct(s=pl.lit("abcd").str.split("").explode(), i=pl.int_range(0, 4)) + ) + df = pl.concat([df[:2], df[-2:]]) + assert df.select(pl.concat_list("s")).to_dict(as_series=False) == { + "s": [ + [{"s": "a", "i": 0}], + [{"s": "b", "i": 1}], + [{"s": "c", "i": 2}], + [{"s": "d", "i": 3}], + ] + } + + +def test_list_eval_element_schema_19345() -> None: + assert_frame_equal( + ( + pl.LazyFrame({"a": [[{"a": 1}]]}) + .select(pl.col("a").list.eval(pl.element().struct.field("a"))) + .collect() + ), + pl.DataFrame({"a": [[1]]}), + ) + + +@pytest.mark.parametrize( + ("agg", "inner_dtype", "expected_dtype"), + [ + ("sum", pl.Int8, pl.Int64), + ("max", pl.Int8, pl.Int8), + ("sum", pl.Duration("us"), pl.Duration("us")), + ("min", pl.Duration("ms"), pl.Duration("ms")), + ("min", pl.String, pl.String), + ("max", pl.String, pl.String), + ], +) +def test_list_agg_all_null( + agg: str, inner_dtype: PolarsDataType, expected_dtype: PolarsDataType +) -> None: + s = pl.Series([None, None], dtype=pl.List(inner_dtype)) + assert getattr(s.list, agg)().dtype == expected_dtype + + +@pytest.mark.parametrize( + ("inner_dtype", "expected_inner_dtype"), + [ + (pl.Datetime("us"), pl.Duration("us")), + (pl.Date(), pl.Duration("ms")), + (pl.Time(), pl.Duration("ns")), + (pl.UInt64(), pl.Int64()), + (pl.UInt32(), pl.Int64()), + (pl.UInt8(), pl.Int16()), + (pl.Int8(), pl.Int8()), + (pl.Float32(), pl.Float32()), + ], +) +def test_list_diff_schema( + inner_dtype: PolarsDataType, expected_inner_dtype: PolarsDataType +) -> None: + lf = ( + pl.LazyFrame({"a": [[1, 2]]}) + .cast(pl.List(inner_dtype)) + .select(pl.col("a").list.diff(1)) + ) + expected = {"a": pl.List(expected_inner_dtype)} + assert lf.collect_schema() == expected + assert lf.collect().schema == expected + + +def test_gather_every_nzero_22027() -> None: + df = pl.DataFrame( + [ + pl.Series( + "a", + [ + ["a"], + ["eb", "d"], + ], + pl.List(pl.String), + ), + ] + ) + with pytest.raises(pl.exceptions.ComputeError): + df.select(pl.col.a.list.gather_every(pl.Series([0, 0]))) + + +def test_list_sample_n_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series("a", [[1, 2], [1, 2]]).list.sample(pl.Series([1, 2, 1])) + + +def test_list_sample_fraction_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series("a", [[1, 2], [1, 2]]).list.sample( + fraction=pl.Series([0.5, 0.2, 0.4]) + ) + + +def test_list_sample_n_self_broadcast() -> None: + assert pl.Series("a", [[1, 2]]).list.sample(pl.Series([1, 2, 1])).len() == 3 + + +def test_list_sample_fraction_self_broadcast() -> None: + assert ( + pl.Series("a", [[1, 2]]).list.sample(fraction=pl.Series([0.5, 0.2, 0.4])).len() + == 3 + ) + + +def test_list_shift_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series("a", [[1, 2], [1, 2]]).list.shift(pl.Series([1, 2, 3])) + + +def test_list_shift_self_broadcast() -> None: + assert pl.Series("a", [[1, 2]]).list.shift(pl.Series([1, 2, 1])).len() == 3 diff --git a/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py b/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py new file mode 100644 index 000000000000..67b91f149d69 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_list_set_oob() -> None: + df = pl.DataFrame({"a": [[42], [23]]}) + result = df.select(pl.col("a").list.set_intersection([])) + assert result.to_dict(as_series=False) == {"a": [[], []]} + + +def test_list_set_operations_float() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}, + schema={"a": pl.List(pl.Float32), "b": pl.List(pl.Float32)}, + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 12.0], + [4.0], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [1.0, 2.0], + [1.0], + [4.0], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [3.0], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [4.0], + [2.0, 12.0], + [], + ] + + +def test_list_set_operations() -> None: + df = pl.DataFrame( + { + "a": [[1, 2, 3], [1, 1, 1], [4]], + "b": [[4, 2, 1], [2, 1, 12], [4]], + "c": [[1, 2], [2, 1, 76], [8, 9]], + } + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [1, 2, 3, 4], + [1, 2, 12], + [4], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [1, 2], + [1], + [4], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [3], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [4], + [2, 12], + [], + ] + + # check expansion of columns + assert df.select(pl.col("a", "b").list.set_intersection("c")).to_dict( + as_series=False + ) == {"a": [[1, 2], [1], []], "b": [[2, 1], [2, 1], []]} + + assert df.select(pl.col("a", "b").list.set_union("c")).to_dict(as_series=False) == { + "a": [[1, 2, 3], [1, 2, 76], [4, 8, 9]], + "b": [[4, 2, 1], [2, 1, 12, 76], [4, 8, 9]], + } + + assert df.select(pl.col("a", "b").list.set_difference("c")).to_dict( + as_series=False + ) == {"a": [[3], [], [4]], "b": [[4], [12], [4]]} + + assert df.select(pl.col("a", "b").list.set_symmetric_difference("c")).to_dict( + as_series=False + ) == {"a": [[3], [2, 76], [4, 8, 9]], "b": [[4], [12, 76], [4, 8, 9]]} + + # check logical types + dtype = pl.List(pl.Date) + assert ( + df.select(pl.col("b").cast(dtype).list.set_difference(pl.col("a").cast(dtype)))[ + "b" + ].dtype + == dtype + ) + + df = pl.DataFrame( + { + "a": [["a", "b", "c"], ["b", "e", "z"]], + "b": [["b", "s", "a"], ["a", "e", "f"]], + } + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + ["a", "b", "c", "s"], + ["b", "e", "z", "a", "f"], + ] + + df = pl.DataFrame( + { + "a": [[2, 3, 3], [3, 1], [1, 2, 3]], + "b": [[2, 3, 4], [3, 3, 1], [3, 3]], + } + ) + r1 = df.with_columns(pl.col("a").list.set_intersection("b"))["a"].to_list() + r2 = df.with_columns(pl.col("b").list.set_intersection("a"))["b"].to_list() + exp = [[2, 3], [3, 1], [3]] + assert r1 == exp + assert r2 == exp + + +def test_list_set_operations_broadcast() -> None: + df = pl.DataFrame( + { + "a": [[2, 3, 3], [3, 1], [1, 2, 3]], + } + ) + + assert df.with_columns( + pl.col("a").list.set_intersection(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]} + assert df.with_columns( + pl.col("a").list.set_union(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]} + assert df.with_columns( + pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[3], [3], [3]]} + assert df.with_columns( + pl.lit(pl.Series("a", [[1, 2]])).list.set_difference("a") + ).to_dict(as_series=False) == {"a": [[1], [2], []]} + + +def test_list_set_operation_different_length_chunk_12734() -> None: + df = pl.DataFrame( + { + "a": [[2, 3, 3], [4, 1], [1, 2, 3]], + } + ) + + df = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)], rechunk=False) + assert df.with_columns( + pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[3], [4], [3]]} + + +def test_list_set_operations_binary() -> None: + df = pl.DataFrame( + { + "a": [[b"1", b"2", b"3"], [b"1", b"1", b"1"], [b"4"]], + "b": [[b"4", b"2", b"1"], [b"2", b"1", b"12"], [b"4"]], + }, + schema={"a": pl.List(pl.Binary), "b": pl.List(pl.Binary)}, + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [b"1", b"2", b"3", b"4"], + [b"1", b"2", b"12"], + [b"4"], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [b"1", b"2"], + [b"1"], + [b"4"], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [b"3"], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [b"4"], + [b"2", b"12"], + [], + ] + + +def test_list_set_operations_broadcast_binary() -> None: + df = pl.DataFrame( + { + "a": [["2", "3", "3"], ["3", "1"], ["1", "2", "3"]], + "b": [["1", "2"], ["4"], ["5"]], + } + ) + + assert df.select(pl.col("a").list.set_intersection(pl.col.b.first())).to_dict( + as_series=False + ) == {"a": [["2"], ["1"], ["1", "2"]]} + assert df.select(pl.col("a").list.set_union(pl.col.b.first())).to_dict( + as_series=False + ) == {"a": [["2", "3", "1"], ["3", "1", "2"], ["1", "2", "3"]]} + assert df.select(pl.col("a").list.set_difference(pl.col.b.first())).to_dict( + as_series=False + ) == {"a": [["3"], ["3"], ["3"]]} + assert df.select(pl.col.b.first().list.set_difference("a")).to_dict( + as_series=False + ) == {"b": [["1"], ["2"], []]} + + +def test_set_operations_14290() -> None: + df = pl.DataFrame( + { + "a": [[1, 2], [2, 3]], + "b": [None, [1, 2]], + } + ) + + out = df.with_columns(pl.col("a").shift(1).alias("shifted_a")).select( + b_dif_a=pl.col("b").list.set_difference("a"), + shifted_a_dif_a=pl.col("shifted_a").list.set_difference("a"), + ) + expected = pl.DataFrame({"b_dif_a": [None, [1]], "shifted_a_dif_a": [None, [1]]}) + assert_frame_equal(out, expected) + + +def test_broadcast_sliced() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]]}) + out = df.select( + pl.col("a").list.set_difference(pl.Series([[1], [2, 3, 4]]).slice(0, 1)) + ) + expected = pl.DataFrame({"a": [[2], [3, 4]]}) + + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/operations/namespaces/string/__init__.py b/py-polars/tests/unit/operations/namespaces/string/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/namespaces/string/test_concat.py b/py-polars/tests/unit/operations/namespaces/string/test_concat.py new file mode 100644 index 000000000000..13ee591cd3a8 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/string/test_concat.py @@ -0,0 +1,78 @@ +from datetime import datetime + +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_str_join() -> None: + s = pl.Series(["1", None, "2", None]) + # propagate null + assert_series_equal( + s.str.join("-", ignore_nulls=False), pl.Series([None], dtype=pl.String) + ) + # ignore null + assert_series_equal(s.str.join(), pl.Series(["12"])) + + # str None/null is ok + s = pl.Series(["1", "None", "2", "null"]) + assert_series_equal( + s.str.join("-", ignore_nulls=False), pl.Series(["1-None-2-null"]) + ) + assert_series_equal(s.str.join("-"), pl.Series(["1-None-2-null"])) + + +def test_str_join2() -> None: + df = pl.DataFrame({"foo": [1, None, 2, None]}) + + out = df.select(pl.col("foo").str.join(ignore_nulls=False)) + assert out.item() is None + + out = df.select(pl.col("foo").str.join()) + assert out.item() == "12" + + +def test_str_join_all_null() -> None: + s = pl.Series([None, None, None], dtype=pl.String) + assert_series_equal( + s.str.join(ignore_nulls=False), pl.Series([None], dtype=pl.String) + ) + assert_series_equal(s.str.join(ignore_nulls=True), pl.Series([""])) + + +def test_str_join_empty_list() -> None: + s = pl.Series([], dtype=pl.String) + assert_series_equal(s.str.join(ignore_nulls=False), pl.Series([""])) + assert_series_equal(s.str.join(ignore_nulls=True), pl.Series([""])) + + +def test_str_join_empty_list2() -> None: + s = pl.Series([], dtype=pl.String) + df = pl.DataFrame({"foo": s}) + result = df.select(pl.col("foo").str.join()).item() + expected = "" + assert result == expected + + +def test_str_join_empty_list_agg_context() -> None: + df = pl.DataFrame(data={"i": [1], "v": [None]}, schema_overrides={"v": pl.String}) + result = df.group_by("i").agg(pl.col("v").drop_nulls().str.join())["v"].item() + expected = "" + assert result == expected + + +def test_str_join_datetime() -> None: + df = pl.DataFrame({"d": [datetime(2020, 1, 1), None, datetime(2022, 1, 1)]}) + out = df.select(pl.col("d").str.join("|", ignore_nulls=True)) + assert out.item() == "2020-01-01 00:00:00.000000|2022-01-01 00:00:00.000000" + out = df.select(pl.col("d").str.join("|", ignore_nulls=False)) + assert out.item() is None + + +def test_str_concat_deprecated() -> None: + s = pl.Series(["1", None, "2", None]) + with pytest.deprecated_call(): + result = s.str.concat() + expected = pl.Series(["1-2"]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_pad.py b/py-polars/tests/unit/operations/namespaces/string/test_pad.py new file mode 100644 index 000000000000..62723de0f38c --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/string/test_pad.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import ShapeError +from polars.testing import assert_frame_equal + + +def test_str_pad_start() -> None: + df = pl.DataFrame({"a": ["foo", "longer_foo", "longest_fooooooo", "hi"]}) + + result = df.select( + pl.col("a").str.pad_start(10).alias("padded"), + pl.col("a").str.pad_start(10).str.len_bytes().alias("padded_len"), + ) + + expected = pl.DataFrame( + { + "padded": [" foo", "longer_foo", "longest_fooooooo", " hi"], + "padded_len": [10, 10, 16, 10], + }, + schema_overrides={"padded_len": pl.UInt32}, + ) + assert_frame_equal(result, expected) + + +def test_str_pad_end() -> None: + df = pl.DataFrame({"a": ["foo", "longer_foo", "longest_fooooooo", "hi"]}) + + result = df.select( + pl.col("a").str.pad_end(10).alias("padded"), + pl.col("a").str.pad_end(10).str.len_bytes().alias("padded_len"), + ) + + expected = pl.DataFrame( + { + "padded": ["foo ", "longer_foo", "longest_fooooooo", "hi "], + "padded_len": [10, 10, 16, 10], + }, + schema_overrides={"padded_len": pl.UInt32}, + ) + assert_frame_equal(result, expected) + + +def test_str_zfill() -> None: + df = pl.DataFrame( + { + "num": [-10, -1, 0, 1, 10, 100, 1000, 10000, 100000, 1000000, None], + } + ) + out = [ + "-0010", + "-0001", + "00000", + "00001", + "00010", + "00100", + "01000", + "10000", + "100000", + "1000000", + None, + ] + assert ( + df.with_columns(pl.col("num").cast(str).str.zfill(5)).to_series().to_list() + == out + ) + assert df["num"].cast(str).str.zfill(5).to_list() == out + + +def test_str_zfill_expr() -> None: + df = pl.DataFrame( + { + "num": ["-10", "-1", "0", "1", "10", None, "1"], + "len": [3, 4, 3, 2, 5, 3, None], + } + ) + out = df.select( + all_expr=pl.col("num").str.zfill(pl.col("len")), + str_lit=pl.lit("10").str.zfill(pl.col("len")), + len_lit=pl.col("num").str.zfill(5), + ) + expected = pl.DataFrame( + { + "all_expr": ["-10", "-001", "000", "01", "00010", None, None], + "str_lit": ["010", "0010", "010", "10", "00010", "010", None], + "len_lit": ["-0010", "-0001", "00000", "00001", "00010", None, "00001"], + } + ) + assert_frame_equal(out, expected) + + +def test_str_zfill_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.zfill(pl.Series([1, 2]))) + + +def test_pad_end_unicode() -> None: + lf = pl.LazyFrame({"a": ["Café", "345", "東京", None]}) + + result = lf.select(pl.col("a").str.pad_end(6, "日")) + + expected = pl.LazyFrame({"a": ["Café日日", "345日日日", "東京日日日日", None]}) + assert_frame_equal(result, expected) + + +def test_pad_start_unicode() -> None: + lf = pl.LazyFrame({"a": ["Café", "345", "東京", None]}) + + result = lf.select(pl.col("a").str.pad_start(6, "日")) + + expected = pl.LazyFrame({"a": ["日日Café", "日日日345", "日日日日東京", None]}) + assert_frame_equal(result, expected) + + +def test_str_zfill_unicode_not_respected() -> None: + lf = pl.LazyFrame({"a": ["Café", "345", "東京", None]}) + + result = lf.select(pl.col("a").str.zfill(6)) + + expected = pl.LazyFrame({"a": ["0Café", "000345", "東京", None]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py new file mode 100644 index 000000000000..b6b843fbd850 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -0,0 +1,2001 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ( + ColumnNotFoundError, + ComputeError, + InvalidOperationError, + SchemaError, + ShapeError, +) +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_str_slice() -> None: + df = pl.DataFrame({"a": ["foobar", "barfoo"]}) + assert df["a"].str.slice(-3).to_list() == ["bar", "foo"] + assert df.select([pl.col("a").str.slice(2, 4)])["a"].to_list() == ["obar", "rfoo"] + + +def test_str_slice_expr() -> None: + df = pl.DataFrame( + { + "a": ["foobar", None, "barfoo", "abcd", ""], + "offset": [1, 3, None, -3, 2], + "length": [3, 4, 2, None, 2], + } + ) + out = df.select( + all_expr=pl.col("a").str.slice("offset", "length"), + offset_expr=pl.col("a").str.slice("offset", 2), + length_expr=pl.col("a").str.slice(0, "length"), + length_none=pl.col("a").str.slice("offset", None), + offset_length_lit=pl.col("a").str.slice(-3, 3), + str_lit=pl.lit("qwert").str.slice("offset", "length"), + ) + expected = pl.DataFrame( + { + "all_expr": ["oob", None, None, "bcd", ""], + "offset_expr": ["oo", None, None, "bc", ""], + "length_expr": ["foo", None, "ba", "abcd", ""], + "length_none": ["oobar", None, None, "bcd", ""], + "offset_length_lit": ["bar", None, "foo", "bcd", ""], + "str_lit": ["wer", "rt", None, "ert", "er"], + } + ) + assert_frame_equal(out, expected) + + # negative length is not allowed + with pytest.raises(InvalidOperationError): + df.select(pl.col("a").str.slice(0, -1)) + + +def test_str_slice_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.slice(pl.Series([1, 2]))) + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + (["012345", "", None], 0, ["", "", None]), + (["012345", "", None], 2, ["01", "", None]), + (["012345", "", None], -2, ["0123", "", None]), + (["012345", "", None], 100, ["012345", "", None]), + (["012345", "", None], -100, ["", "", None]), + ], +) +def test_str_head(input: list[str], n: int, output: list[str]) -> None: + assert pl.Series(input).str.head(n).to_list() == output + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + ("你好世界", 0, ""), + ("你好世界", 2, "你好"), + ("你好世界", 999, "你好世界"), + ("你好世界", -1, "你好世"), + ("你好世界", -2, "你好"), + ("你好世界", -999, ""), + ], +) +def test_str_head_codepoints(input: str, n: int, output: str) -> None: + assert pl.Series([input]).str.head(n).to_list() == [output] + + +def test_str_head_expr() -> None: + s = "012345" + df = pl.DataFrame( + {"a": [s, s, s, s, s, s, "", None], "n": [0, 2, -2, 100, -100, None, 3, -2]} + ) + out = df.select( + n_expr=pl.col("a").str.head("n"), + n_pos2=pl.col("a").str.head(2), + n_neg2=pl.col("a").str.head(-2), + n_pos100=pl.col("a").str.head(100), + n_pos_neg100=pl.col("a").str.head(-100), + n_pos_0=pl.col("a").str.head(0), + str_lit=pl.col("a").str.head(pl.lit(2)), + lit_expr=pl.lit(s).str.head("n"), + lit_n=pl.lit(s).str.head(2), + ) + expected = pl.DataFrame( + { + "n_expr": ["", "01", "0123", "012345", "", None, "", None], + "n_pos2": ["01", "01", "01", "01", "01", "01", "", None], + "n_neg2": ["0123", "0123", "0123", "0123", "0123", "0123", "", None], + "n_pos100": [s, s, s, s, s, s, "", None], + "n_pos_neg100": ["", "", "", "", "", "", "", None], + "n_pos_0": ["", "", "", "", "", "", "", None], + "str_lit": ["01", "01", "01", "01", "01", "01", "", None], + "lit_expr": ["", "01", "0123", "012345", "", None, "012", "0123"], + "lit_n": ["01", "01", "01", "01", "01", "01", "01", "01"], + } + ) + assert_frame_equal(out, expected) + + +def test_str_head_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.head(pl.Series([1, 2]))) + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + (["012345", "", None], 0, ["", "", None]), + (["012345", "", None], 2, ["45", "", None]), + (["012345", "", None], -2, ["2345", "", None]), + (["012345", "", None], 100, ["012345", "", None]), + (["012345", "", None], -100, ["", "", None]), + ], +) +def test_str_tail(input: list[str], n: int, output: list[str]) -> None: + assert pl.Series(input).str.tail(n).to_list() == output + + +@pytest.mark.parametrize( + ("input", "n", "output"), + [ + ("你好世界", 0, ""), + ("你好世界", 2, "世界"), + ("你好世界", 999, "你好世界"), + ("你好世界", -1, "好世界"), + ("你好世界", -2, "世界"), + ("你好世界", -999, ""), + ], +) +def test_str_tail_codepoints(input: str, n: int, output: str) -> None: + assert pl.Series([input]).str.tail(n).to_list() == [output] + + +def test_str_tail_expr() -> None: + s = "012345" + df = pl.DataFrame( + {"a": [s, s, s, s, s, s, "", None], "n": [0, 2, -2, 100, -100, None, 3, -2]} + ) + out = df.select( + n_expr=pl.col("a").str.tail("n"), + n_pos2=pl.col("a").str.tail(2), + n_neg2=pl.col("a").str.tail(-2), + n_pos100=pl.col("a").str.tail(100), + n_pos_neg100=pl.col("a").str.tail(-100), + n_pos_0=pl.col("a").str.tail(0), + str_lit=pl.col("a").str.tail(pl.lit(2)), + lit_expr=pl.lit(s).str.tail("n"), + lit_n=pl.lit(s).str.tail(2), + ) + expected = pl.DataFrame( + { + "n_expr": ["", "45", "2345", "012345", "", None, "", None], + "n_pos2": ["45", "45", "45", "45", "45", "45", "", None], + "n_neg2": ["2345", "2345", "2345", "2345", "2345", "2345", "", None], + "n_pos100": [s, s, s, s, s, s, "", None], + "n_pos_neg100": ["", "", "", "", "", "", "", None], + "n_pos_0": ["", "", "", "", "", "", "", None], + "str_lit": ["45", "45", "45", "45", "45", "45", "", None], + "lit_expr": ["", "45", "2345", "012345", "", None, "345", "2345"], + "lit_n": ["45", "45", "45", "45", "45", "45", "45", "45"], + } + ) + assert_frame_equal(out, expected) + + +def test_str_tail_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.tail(pl.Series([1, 2]))) + + +def test_str_slice_multibyte() -> None: + ref = "你好世界" + s = pl.Series([ref]) + + # Pad the string to simplify (negative) offsets starting before/after the string. + npad = 20 + padref = "_" * npad + ref + "_" * npad + for start in range(-5, 6): + for length in range(6): + offset = npad + start if start >= 0 else npad + start + len(ref) + correct = padref[offset : offset + length].strip("_") + result = s.str.slice(start, length) + expected = pl.Series([correct]) + assert_series_equal(result, expected) + + +def test_str_len_bytes() -> None: + s = pl.Series(["Café", None, "345", "東京"]) + result = s.str.len_bytes() + expected = pl.Series([5, None, 3, 6], dtype=pl.UInt32) + assert_series_equal(result, expected) + + +def test_str_len_chars() -> None: + s = pl.Series(["Café", None, "345", "東京"]) + result = s.str.len_chars() + expected = pl.Series([4, None, 3, 2], dtype=pl.UInt32) + assert_series_equal(result, expected) + + +def test_str_contains() -> None: + s = pl.Series(["messi", "ronaldo", "ibrahimovic"]) + expected = pl.Series([True, False, False]) + assert_series_equal(s.str.contains("mes"), expected) + + +def test_str_contains_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.contains(pl.Series(["a", "b"]))) # type: ignore [arg-type] + + +def test_count_match_literal() -> None: + s = pl.Series(["12 dbc 3xy", "cat\\w", "1zy3\\d\\d", None]) + out = s.str.count_matches(r"\d", literal=True) + expected = pl.Series([0, 0, 2, None], dtype=pl.UInt32) + assert_series_equal(out, expected) + + out = s.str.count_matches(pl.Series([r"\w", r"\w", r"\d", r"\d"]), literal=True) + expected = pl.Series([0, 1, 2, None], dtype=pl.UInt32) + assert_series_equal(out, expected) + + +def test_str_encode() -> None: + s = pl.Series(["foo", "bar", None]) + hex_encoded = pl.Series(["666f6f", "626172", None]) + base64_encoded = pl.Series(["Zm9v", "YmFy", None]) + + assert_series_equal(s.str.encode("hex"), hex_encoded) + assert_series_equal(s.str.encode("base64"), base64_encoded) + with pytest.raises(ValueError): + s.str.encode("utf8") # type: ignore[arg-type] + + +def test_str_decode() -> None: + hex_encoded = pl.Series(["666f6f", "626172", None]) + base64_encoded = pl.Series(["Zm9v", "YmFy", None]) + expected = pl.Series([b"foo", b"bar", None]) + + assert_series_equal(hex_encoded.str.decode("hex"), expected) + assert_series_equal(base64_encoded.str.decode("base64"), expected) + + +def test_str_decode_exception() -> None: + s = pl.Series(["not a valid", "626172", None]) + with pytest.raises(ComputeError): + s.str.decode(encoding="hex") + with pytest.raises(ComputeError): + s.str.decode(encoding="base64") + with pytest.raises(ValueError): + s.str.decode("utf8") # type: ignore[arg-type] + + +@pytest.mark.parametrize("strict", [True, False]) +def test_str_find(strict: bool) -> None: + df = pl.DataFrame( + data=[ + ("Dubai", 3564931, "b[ai]", "ai"), + ("Abu Dhabi", 1807000, "b[ai]", " "), + ("Sharjah", 1405000, "[ai]n", "s"), + ("Al Ain", 846747, "[ai]n", ""), + ("Ajman", 490035, "[ai]n", "ma"), + ("Ras Al Khaimah", 191753, "a.+a", "Kha"), + ("Fujairah", 118933, "a.+a", None), + ("Umm Al Quwain", 59098, "a.+a", "wa"), + (None, None, None, "n/a"), + ], + schema={ + "city": pl.String, + "population": pl.Int32, + "pat": pl.String, + "lit": pl.String, + }, + orient="row", + ) + city, pop, pat, lit = (pl.col(c) for c in ("city", "population", "pat", "lit")) + + for match_lit in (True, False): + res = df.select( + find_a_regex=city.str.find("(?i)a", strict=strict), + find_a_lit=city.str.find("a", literal=match_lit), + find_00_lit=pop.cast(pl.String).str.find("00", literal=match_lit), + find_col_lit=city.str.find(lit, strict=strict, literal=match_lit), + find_col_pat=city.str.find(pat, strict=strict), + ) + assert res.to_dict(as_series=False) == { + "find_a_regex": [3, 0, 2, 0, 0, 1, 3, 4, None], + "find_a_lit": [3, 6, 2, None, 3, 1, 3, 10, None], + "find_00_lit": [None, 4, 4, None, 2, None, None, None, None], + "find_col_lit": [3, 3, None, 0, 2, 7, None, 9, None], + "find_col_pat": [2, 7, None, 4, 3, 1, 3, None, None], + } + + +def test_str_find_invalid_regex() -> None: + # test behaviour of 'strict' with invalid regular expressions + df = pl.DataFrame({"txt": ["AbCdEfG"]}) + rx_invalid = "(?i)AB.))" + + with pytest.raises(ComputeError): + df.with_columns(pl.col("txt").str.find(rx_invalid, strict=True)) + + res = df.with_columns(pl.col("txt").str.find(rx_invalid, strict=False)) + assert res.item() is None + + +def test_str_find_escaped_chars() -> None: + # test behaviour of 'literal=True' with special chars + df = pl.DataFrame({"txt": ["123.*465", "x(x?)x"]}) + + res = df.with_columns( + x1=pl.col("txt").str.find("(x?)", literal=True), + x2=pl.col("txt").str.find(".*4", literal=True), + x3=pl.col("txt").str.find("(x?)"), + x4=pl.col("txt").str.find(".*4"), + ) + # ┌──────────┬──────┬──────┬─────┬──────┐ + # │ txt ┆ x1 ┆ x2 ┆ x3 ┆ x4 │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ str ┆ u32 ┆ u32 ┆ u32 ┆ u32 │ + # ╞══════════╪══════╪══════╪═════╪══════╡ + # │ 123.*465 ┆ null ┆ 3 ┆ 0 ┆ 0 │ + # │ x(x?)x ┆ 1 ┆ null ┆ 0 ┆ null │ + # └──────────┴──────┴──────┴─────┴──────┘ + assert_frame_equal( + pl.DataFrame( + { + "txt": ["123.*465", "x(x?)x"], + "x1": [None, 1], + "x2": [3, None], + "x3": [0, 0], + "x4": [0, None], + } + ).cast({cs.signed_integer(): pl.UInt32}), + res, + ) + + +@pytest.mark.may_fail_auto_streaming +def test_str_find_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.find(pl.Series(["a", "b"]))) # type: ignore [arg-type] + + +def test_hex_decode_return_dtype() -> None: + data = {"a": ["68656c6c6f", "776f726c64"]} + expr = pl.col("a").str.decode("hex") + + df = pl.DataFrame(data).select(expr) + assert df.schema == {"a": pl.Binary} + + ldf = pl.LazyFrame(data).select(expr) + assert ldf.collect_schema() == {"a": pl.Binary} + + +def test_base64_decode_return_dtype() -> None: + data = {"a": ["Zm9v", "YmFy"]} + expr = pl.col("a").str.decode("base64") + + df = pl.DataFrame(data).select(expr) + assert df.schema == {"a": pl.Binary} + + ldf = pl.LazyFrame(data).select(expr) + assert ldf.collect_schema() == {"a": pl.Binary} + + +def test_str_replace_str_replace_all() -> None: + s = pl.Series(["hello", "world", "test", "rooted"]) + expected = pl.Series(["hell0", "w0rld", "test", "r0oted"]) + assert_series_equal(s.str.replace("o", "0"), expected) + + expected = pl.Series(["hell0", "w0rld", "test", "r00ted"]) + assert_series_equal(s.str.replace_all("o", "0"), expected) + + +def test_str_replace_n_single() -> None: + s = pl.Series(["aba", "abaa"]) + + assert s.str.replace("a", "b", n=1).to_list() == ["bba", "bbaa"] + assert s.str.replace("a", "b", n=2).to_list() == ["bbb", "bbba"] + assert s.str.replace("a", "b", n=3).to_list() == ["bbb", "bbbb"] + + +def test_str_replace_n_same_length() -> None: + # pat and val have the same length + # this triggers a fast path + s = pl.Series(["abfeab", "foobarabfooabab"]) + assert s.str.replace("ab", "AB", n=1).to_list() == ["ABfeab", "foobarABfooabab"] + assert s.str.replace("ab", "AB", n=2).to_list() == ["ABfeAB", "foobarABfooABab"] + assert s.str.replace("ab", "AB", n=3).to_list() == ["ABfeAB", "foobarABfooABAB"] + + +def test_str_to_lowercase() -> None: + s = pl.Series(["Hello", "WORLD"]) + expected = pl.Series(["hello", "world"]) + assert_series_equal(s.str.to_lowercase(), expected) + + +def test_str_to_uppercase() -> None: + s = pl.Series(["Hello", "WORLD"]) + expected = pl.Series(["HELLO", "WORLD"]) + assert_series_equal(s.str.to_uppercase(), expected) + + +def test_str_case_cyrillic() -> None: + vals = ["Biтpyк", "Iвaн"] + s = pl.Series(vals) + assert s.str.to_lowercase().to_list() == [a.lower() for a in vals] + assert s.str.to_uppercase().to_list() == [a.upper() for a in vals] + + +def test_str_to_integer() -> None: + bin = pl.Series(["110", "101", "010"]) + assert_series_equal(bin.str.to_integer(base=2), pl.Series([6, 5, 2]).cast(pl.Int64)) + + hex = pl.Series(["fa1e", "ff00", "cafe", "invalid", None]) + assert_series_equal( + hex.str.to_integer(base=16, strict=False), + pl.Series([64030, 65280, 51966, None, None]).cast(pl.Int64), + check_exact=True, + ) + + with pytest.raises(ComputeError): + hex.str.to_integer(base=16) + + +@pytest.mark.parametrize("strict", [False, True]) +def test_str_to_integer_invalid_base(strict: bool) -> None: + numbers = pl.Series(["1", "ZZZ", "-ABCZZZ", None]) + with pytest.raises(ComputeError): + numbers.str.to_integer(base=100, strict=strict) + + df = pl.DataFrame({"str": numbers, "base": [0, 1, 100, None]}) + with pytest.raises(ComputeError): + df.select(pl.col("str").str.to_integer(base=pl.col("base"), strict=strict)) + + +def test_str_to_integer_base_expr() -> None: + df = pl.DataFrame( + {"str": ["110", "ff00", "234", None, "130"], "base": [2, 16, 10, 8, None]} + ) + out = df.select(base_expr=pl.col("str").str.to_integer(base="base")) + expected = pl.DataFrame({"base_expr": [6, 65280, 234, None, None]}) + assert_frame_equal(out, expected) + + # test strict raise + df = pl.DataFrame({"str": ["110", "ff00", "cafe", None], "base": [2, 10, 10, 8]}) + + with pytest.raises(ComputeError): + df.select(pl.col("str").str.to_integer(base="base")) + + +def test_str_to_integer_base_literal() -> None: + df = pl.DataFrame( + { + "bin": ["110", "101", "-010", "invalid", None], + "hex": ["fa1e", "ff00", "cafe", "invalid", None], + } + ) + result = df.with_columns( + pl.col("bin").str.to_integer(base=2, strict=False), + pl.col("hex").str.to_integer(base=16, strict=False), + ) + + expected = pl.DataFrame( + { + "bin": [6, 5, -2, None, None], + "hex": [64030, 65280, 51966, None, None], + } + ) + assert_frame_equal(result, expected) + + with pytest.raises(ComputeError): + df.with_columns( + pl.col("bin").str.to_integer(base=2), + pl.col("hex").str.to_integer(base=16), + ) + + +def test_str_strip_chars_expr() -> None: + df = pl.DataFrame( + { + "s": [" hello ", "^^world^^", "&&hi&&", " polars ", None], + "pat": [" ", "^", "&", None, "anything"], + } + ) + + all_expr = df.select( + pl.col("s").str.strip_chars(pl.col("pat")).alias("strip_chars"), + pl.col("s").str.strip_chars_start(pl.col("pat")).alias("strip_chars_start"), + pl.col("s").str.strip_chars_end(pl.col("pat")).alias("strip_chars_end"), + ) + + expected = pl.DataFrame( + { + "strip_chars": ["hello", "world", "hi", "polars", None], + "strip_chars_start": ["hello ", "world^^", "hi&&", "polars ", None], + "strip_chars_end": [" hello", "^^world", "&&hi", " polars", None], + } + ) + + assert_frame_equal(all_expr, expected) + + strip_by_null = df.select( + pl.col("s").str.strip_chars(None).alias("strip_chars"), + pl.col("s").str.strip_chars_start(None).alias("strip_chars_start"), + pl.col("s").str.strip_chars_end(None).alias("strip_chars_end"), + ) + + # only whitespace are striped. + expected = pl.DataFrame( + { + "strip_chars": ["hello", "^^world^^", "&&hi&&", "polars", None], + "strip_chars_start": ["hello ", "^^world^^", "&&hi&&", "polars ", None], + "strip_chars_end": [" hello", "^^world^^", "&&hi&&", " polars", None], + } + ) + assert_frame_equal(strip_by_null, expected) + + +def test_str_strip_chars() -> None: + s = pl.Series([" hello ", "world\t "]) + expected = pl.Series(["hello", "world"]) + assert_series_equal(s.str.strip_chars(), expected) + + expected = pl.Series(["hell", "world"]) + assert_series_equal(s.str.strip_chars().str.strip_chars("o"), expected) + + expected = pl.Series(["ell", "rld\t"]) + assert_series_equal(s.str.strip_chars(" hwo"), expected) + + +def test_str_strip_chars_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.strip_chars(pl.Series(["a", "b"]))) + + +def test_str_strip_chars_start() -> None: + s = pl.Series([" hello ", "\t world"]) + expected = pl.Series(["hello ", "world"]) + assert_series_equal(s.str.strip_chars_start(), expected) + + expected = pl.Series(["ello ", "world"]) + assert_series_equal(s.str.strip_chars_start().str.strip_chars_start("h"), expected) + + expected = pl.Series(["ello ", "\t world"]) + assert_series_equal(s.str.strip_chars_start("hw "), expected) + + +def test_str_strip_chars_start_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.strip_chars_start(pl.Series(["a", "b"]))) + + +def test_str_strip_chars_end() -> None: + s = pl.Series([" hello ", "world\t "]) + expected = pl.Series([" hello", "world"]) + assert_series_equal(s.str.strip_chars_end(), expected) + + expected = pl.Series([" hell", "world"]) + assert_series_equal(s.str.strip_chars_end().str.strip_chars_end("o"), expected) + + expected = pl.Series([" he", "wor"]) + assert_series_equal(s.str.strip_chars_end("odl \t"), expected) + + +def test_str_strip_chars_end_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.strip_chars_end(pl.Series(["a", "b"]))) + + +def test_str_strip_whitespace() -> None: + s = pl.Series("a", ["trailing ", " leading", " both "]) + + expected = pl.Series("a", ["trailing", " leading", " both"]) + assert_series_equal(s.str.strip_chars_end(), expected) + + expected = pl.Series("a", ["trailing ", "leading", "both "]) + assert_series_equal(s.str.strip_chars_start(), expected) + + expected = pl.Series("a", ["trailing", "leading", "both"]) + assert_series_equal(s.str.strip_chars(), expected) + + +def test_str_strip_prefix_literal() -> None: + s = pl.Series(["foo:bar", "foofoo:bar", "bar:bar", "foo", "", None]) + expected = pl.Series([":bar", "foo:bar", "bar:bar", "", "", None]) + assert_series_equal(s.str.strip_prefix("foo"), expected) + # test null literal + expected = pl.Series([None, None, None, None, None, None], dtype=pl.String) + assert_series_equal(s.str.strip_prefix(pl.lit(None, dtype=pl.String)), expected) + + +def test_str_strip_prefix_suffix_expr() -> None: + df = pl.DataFrame( + { + "s": ["foo-bar", "foobarbar", "barfoo", "", "anything", None], + "prefix": ["foo", "foobar", "foo", "", None, "bar"], + "suffix": ["bar", "barbar", "bar", "", None, "foo"], + } + ) + out = df.select( + pl.col("s").str.strip_prefix(pl.col("prefix")).alias("strip_prefix"), + pl.col("s").str.strip_suffix(pl.col("suffix")).alias("strip_suffix"), + ) + assert out.to_dict(as_series=False) == { + "strip_prefix": ["-bar", "bar", "barfoo", "", None, None], + "strip_suffix": ["foo-", "foo", "barfoo", "", None, None], + } + + +def test_str_strip_prefix_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.strip_prefix(pl.Series(["a", "b"]))) + + +def test_str_strip_suffix() -> None: + s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", "", None]) + expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", "", None]) + assert_series_equal(s.str.strip_suffix("bar"), expected) + # test null literal + expected = pl.Series([None, None, None, None, None, None], dtype=pl.String) + assert_series_equal(s.str.strip_suffix(pl.lit(None, dtype=pl.String)), expected) + + +def test_str_strip_suffix_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ShapeError): + df.select(pl.col("num").str.strip_suffix(pl.Series(["a", "b"]))) + + +def test_str_split() -> None: + a = pl.Series("a", ["a, b", "a", "ab,c,de"]) + for out in [a.str.split(","), pl.select(pl.lit(a).str.split(",")).to_series()]: + assert out[0].to_list() == ["a", " b"] + assert out[1].to_list() == ["a"] + assert out[2].to_list() == ["ab", "c", "de"] + + for out in [ + a.str.split(",", inclusive=True), + pl.select(pl.lit(a).str.split(",", inclusive=True)).to_series(), + ]: + assert out[0].to_list() == ["a,", " b"] + assert out[1].to_list() == ["a"] + assert out[2].to_list() == ["ab,", "c,", "de"] + + +def test_json_decode_series() -> None: + s = pl.Series(["[1, 2, 3]", None, "[4, 5, 6]"]) + expected = pl.Series([[1, 2, 3], None, [4, 5, 6]]) + dtype = pl.List(pl.Int64) + assert_series_equal(s.str.json_decode(None), expected) + assert_series_equal(s.str.json_decode(dtype), expected) + + s = pl.Series(['{"a": 1, "b": true}', None, '{"a": 2, "b": false}']) + expected = pl.Series([{"a": 1, "b": True}, None, {"a": 2, "b": False}]) + dtype2 = pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.Boolean)]) + assert_series_equal(s.str.json_decode(None), expected) + assert_series_equal(s.str.json_decode(dtype2), expected) + + expected = pl.Series([{"a": 1}, None, {"a": 2}]) + dtype2 = pl.Struct([pl.Field("a", pl.Int64)]) + assert_series_equal(s.str.json_decode(dtype2), expected) + + s = pl.Series([], dtype=pl.String) + expected = pl.Series([], dtype=pl.List(pl.Int64)) + dtype = pl.List(pl.Int64) + assert_series_equal(s.str.json_decode(dtype), expected) + + +def test_json_decode_lazy_expr() -> None: + dtype = pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.Boolean)]) + ldf = ( + pl.DataFrame({"json": ['{"a": 1, "b": true}', None, '{"a": 2, "b": false}']}) + .lazy() + .select(pl.col("json").str.json_decode(dtype)) + ) + expected = pl.DataFrame( + {"json": [{"a": 1, "b": True}, None, {"a": 2, "b": False}]} + ).lazy() + assert ldf.collect_schema() == {"json": dtype} + assert_frame_equal(ldf, expected) + + +def test_json_decode_nested_struct() -> None: + json = [ + '[{"key_1": "a"}]', + '[{"key_1": "a2", "key_2": 2}]', + '[{"key_1": "a3", "key_2": 3, "key_3": "c"}]', + ] + df = pl.DataFrame({"json_str": json}) + df_parsed = df.with_columns( + pl.col("json_str").str.json_decode().alias("parsed_list_json") + ) + + expected_dtype = pl.List( + pl.Struct( + [ + pl.Field("key_1", pl.String), + pl.Field("key_2", pl.Int64), + pl.Field("key_3", pl.String), + ] + ) + ) + assert df_parsed.get_column("parsed_list_json").dtype == expected_dtype + + key_1_values = df_parsed.select( + pl.col("parsed_list_json") + .list.get(0) + .struct.field("key_1") + .alias("key_1_values") + ) + expected_values = pl.Series("key_1_values", ["a", "a2", "a3"]) + assert_series_equal(key_1_values.get_column("key_1_values"), expected_values) + + +def test_json_decode_primitive_to_list_11053() -> None: + df = pl.DataFrame( + { + "json": [ + '{"col1": ["123"], "col2": "123"}', + '{"col1": ["xyz"], "col2": null}', + ] + } + ) + schema = pl.Struct( + { + "col1": pl.List(pl.String), + "col2": pl.List(pl.String), + } + ) + + output = df.select( + pl.col("json").str.json_decode(schema).alias("decoded_json") + ).unnest("decoded_json") + expected = pl.DataFrame({"col1": [["123"], ["xyz"]], "col2": [["123"], None]}) + assert_frame_equal(output, expected) + + +def test_jsonpath_single() -> None: + s = pl.Series(['{"a":"1"}', None, '{"a":2}', '{"a":2.1}', '{"a":true}']) + expected = pl.Series(["1", None, "2", "2.1", "true"]) + assert_series_equal(s.str.json_path_match("$.a"), expected) + + +def test_json_path_match() -> None: + df = pl.DataFrame( + { + "str": [ + '{"a":"1"}', + None, + '{"b":2}', + '{"a":2.1, "b": "hello"}', + '{"a":true}', + ], + "pat": ["$.a", "$.a", "$.b", "$.b", None], + } + ) + out = df.select( + all_expr=pl.col("str").str.json_path_match(pl.col("pat")), + str_expr=pl.col("str").str.json_path_match("$.a"), + pat_expr=pl.lit('{"a": 1.1, "b": 10}').str.json_path_match(pl.col("pat")), + ) + expected = pl.DataFrame( + { + "all_expr": ["1", None, "2", "hello", None], + "str_expr": ["1", None, None, "2.1", "true"], + "pat_expr": ["1.1", "1.1", "10", "10", None], + } + ) + assert_frame_equal(out, expected) + + +def test_str_json_path_match_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises((ShapeError, ComputeError)): + df.select(pl.col("num").str.json_path_match(pl.Series(["a", "b"]))) + + +def test_extract_regex() -> None: + s = pl.Series( + [ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + ) + expected = pl.Series(["messi", None, "ronaldo"]) + assert_series_equal(s.str.extract(r"candidate=(\w+)", 1), expected) + + +def test_extract() -> None: + df = pl.DataFrame( + { + "s": ["aron123", "12butler", "charly*", "~david", None], + "pat": [r"^([a-zA-Z]+)", r"^(\d+)", None, "^(da)", r"(.*)"], + } + ) + + out = df.select( + all_expr=pl.col("s").str.extract(pl.col("pat"), 1), + str_expr=pl.col("s").str.extract("^([a-zA-Z]+)", 1), + pat_expr=pl.lit("aron123").str.extract(pl.col("pat")), + ) + expected = pl.DataFrame( + { + "all_expr": ["aron", "12", None, None, None], + "str_expr": ["aron", None, "charly", None, None], + "pat_expr": ["aron", None, None, None, "aron123"], + } + ) + assert_frame_equal(out, expected) + + +def test_extract_binary() -> None: + df = pl.DataFrame({"foo": ["aron", "butler", "charly", "david"]}) + out = df.filter(pl.col("foo").str.extract("^(a)", 1) == "a").to_series() + assert out[0] == "aron" + + +def test_str_join_returns_scalar() -> None: + df = pl.DataFrame( + [pl.Series("val", ["A", "B", "C", "D"]), pl.Series("id", [1, 1, 2, 2])] + ) + grouped = ( + df.group_by("id") + .agg(pl.col("val").str.join(delimiter=",").alias("grouped")) + .get_column("grouped") + ) + assert grouped.dtype == pl.String + + +def test_contains() -> None: + # test strict/non strict + s_txt = pl.Series(["123", "456", "789"]) + assert ( + pl.Series([None, None, None]).cast(pl.Boolean).to_list() + == s_txt.str.contains("(not_valid_regex", literal=False, strict=False).to_list() + ) + with pytest.raises(ComputeError): + s_txt.str.contains("(not_valid_regex", literal=False, strict=True) + assert ( + pl.Series([True, False, False]).cast(pl.Boolean).to_list() + == s_txt.str.contains("1", literal=False, strict=False).to_list() + ) + + df = pl.DataFrame( + data=[(1, "some * * text"), (2, "(with) special\n * chars"), (3, "**etc...?$")], + schema=["idx", "text"], + orient="row", + ) + for pattern, as_literal, expected in ( + (r"\* \*", False, [True, False, False]), + (r"* *", True, [True, False, False]), + (r"^\(", False, [False, True, False]), + (r"^\(", True, [False, False, False]), + (r"(", True, [False, True, False]), + (r"e", False, [True, True, True]), + (r"e", True, [True, True, True]), + (r"^\S+$", False, [False, False, True]), + (r"\?\$", False, [False, False, True]), + (r"?$", True, [False, False, True]), + ): + # series + assert ( + expected == df["text"].str.contains(pattern, literal=as_literal).to_list() + ) + # frame select + assert ( + expected + == df.select(pl.col("text").str.contains(pattern, literal=as_literal))[ + "text" + ].to_list() + ) + # frame filter + assert sum(expected) == len( + df.filter(pl.col("text").str.contains(pattern, literal=as_literal)) + ) + + +def test_contains_expr() -> None: + df = pl.DataFrame( + { + "text": [ + "some text", + "(with) special\n .* chars", + "**etc...?$", + None, + "b", + "invalid_regex", + ], + "pattern": [r"[me]", r".*", r"^\(", "a", None, "*"], + } + ) + + assert df.select( + pl.col("text") + .str.contains(pl.col("pattern"), literal=False, strict=False) + .alias("contains"), + pl.col("text") + .str.contains(pl.col("pattern"), literal=True) + .alias("contains_lit"), + ).to_dict(as_series=False) == { + "contains": [True, True, False, None, None, None], + "contains_lit": [False, True, False, None, None, False], + } + + with pytest.raises(ComputeError): + df.select( + pl.col("text").str.contains(pl.col("pattern"), literal=False, strict=True) + ) + + +@pytest.mark.parametrize( + ("pattern", "case_insensitive", "expected"), + [ + (["me"], False, True), + (["Me"], False, False), + (["Me"], True, True), + (pl.Series(["me", "they"]), False, True), + (pl.Series(["Me", "they"]), False, False), + (pl.Series(["Me", "they"]), True, True), + (["me", "they"], False, True), + (["Me", "they"], False, False), + (["Me", "they"], True, True), + ], +) +def test_contains_any( + pattern: pl.Series | list[str], + case_insensitive: bool, + expected: bool, +) -> None: + df = pl.DataFrame({"text": ["Tell me what you want"]}) + # series + assert ( + expected + == df["text"] + .str.contains_any(pattern, ascii_case_insensitive=case_insensitive) + .item() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").str.contains_any( + pattern, ascii_case_insensitive=case_insensitive + ) + )["text"].item() + ) + # frame filter + assert int(expected) == len( + df.filter( + pl.col("text").str.contains_any( + pattern, ascii_case_insensitive=case_insensitive + ) + ) + ) + + +def test_replace() -> None: + df = pl.DataFrame( + data=[(1, "* * text"), (2, "(with) special\n * chars **etc...?$")], + schema=["idx", "text"], + orient="row", + ) + for pattern, replacement, as_literal, expected in ( + (r"\*", "-", False, ["- * text", "(with) special\n - chars **etc...?$"]), + (r"*", "-", True, ["- * text", "(with) special\n - chars **etc...?$"]), + (r"^\(", "[", False, ["* * text", "[with) special\n * chars **etc...?$"]), + (r"^\(", "[", True, ["* * text", "(with) special\n * chars **etc...?$"]), + (r"t$", "an", False, ["* * texan", "(with) special\n * chars **etc...?$"]), + (r"t$", "an", True, ["* * text", "(with) special\n * chars **etc...?$"]), + (r"(with) special", "$1", True, ["* * text", "$1\n * chars **etc...?$"]), + ( + r"\((with)\) special", + ":$1:", + False, + ["* * text", ":with:\n * chars **etc...?$"], + ), + ): + # series + assert ( + expected + == df["text"] + .str.replace(pattern, replacement, literal=as_literal) + .to_list() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").str.replace(pattern, replacement, literal=as_literal) + )["text"].to_list() + ) + + assert pl.Series(["."]).str.replace(".", "$0", literal=True)[0] == "$0" + assert pl.Series(["(.)(?)"]).str.replace(".", "$1", literal=True)[0] == "($1)(?)" + + +def test_replace_all() -> None: + df = pl.DataFrame( + data=[(1, "* * text"), (2, "(with) special\n * chars **etc...?$")], + schema=["idx", "text"], + orient="row", + ) + for pattern, replacement, as_literal, expected in ( + (r"\*", "-", False, ["- - text", "(with) special\n - chars --etc...?$"]), + (r"*", "-", True, ["- - text", "(with) special\n - chars --etc...?$"]), + (r"\W", "", False, ["text", "withspecialcharsetc"]), + (r".?$", "", True, ["* * text", "(with) special\n * chars **etc.."]), + ( + r"(with) special", + "$1", + True, + ["* * text", "$1\n * chars **etc...?$"], + ), + ( + r"\((with)\) special", + ":$1:", + False, + ["* * text", ":with:\n * chars **etc...?$"], + ), + ( + r"(\b)[\w\s]{2,}(\b)", + "$1(blah)$3", + False, + ["* * (blah)", "((blah)) (blah)\n * (blah) **(blah)...?$"], + ), + ): + # series + assert ( + expected + == df["text"] + .str.replace_all(pattern, replacement, literal=as_literal) + .to_list() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").str.replace_all(pattern, replacement, literal=as_literal) + )["text"].to_list() + ) + # invalid regex (but valid literal - requires "literal=True") + with pytest.raises(ComputeError): + df["text"].str.replace_all("*", "") + + assert ( + pl.Series([r"(.)(\?)(\?)"]).str.replace_all("\\?", "$0", literal=True)[0] + == "(.)($0)($0)" + ) + assert ( + pl.Series([r"(.)(\?)(\?)"]).str.replace_all("\\?", "$0", literal=False)[0] + == "(.)(\\?)(\\?)" + ) + + +def test_replace_all_literal_no_caputures() -> None: + # When using literal = True, capture groups should be disabled + + # Single row code path in Rust + df = pl.DataFrame({"text": ["I found yesterday."], "amt": ["$1"]}) + df = df.with_columns( + pl.col("text") + .str.replace_all("", pl.col("amt"), literal=True) + .alias("text2") + ) + assert df.get_column("text2")[0] == "I found $1 yesterday." + + # Multi-row code path in Rust + df2 = pl.DataFrame( + { + "text": ["I found yesterday.", "I lost yesterday."], + "amt": ["$1", "$2"], + } + ) + df2 = df2.with_columns( + pl.col("text") + .str.replace_all("", pl.col("amt"), literal=True) + .alias("text2") + ) + assert df2.get_column("text2")[0] == "I found $1 yesterday." + assert df2.get_column("text2")[1] == "I lost $2 yesterday." + + +def test_replace_literal_no_caputures() -> None: + # When using literal = True, capture groups should be disabled + + # Single row code path in Rust + df = pl.DataFrame({"text": ["I found yesterday."], "amt": ["$1"]}) + df = df.with_columns( + pl.col("text").str.replace("", pl.col("amt"), literal=True).alias("text2") + ) + assert df.get_column("text2")[0] == "I found $1 yesterday." + + # Multi-row code path in Rust + # A string shorter than 32 chars, + # and one longer than 32 chars to test both sub-paths + df2 = pl.DataFrame( + { + "text": [ + "I found yesterday.", + "I lost yesterday and this string is longer than 32 characters.", + ], + "amt": ["$1", "$2"], + } + ) + df2 = df2.with_columns( + pl.col("text").str.replace("", pl.col("amt"), literal=True).alias("text2") + ) + assert df2.get_column("text2")[0] == "I found $1 yesterday." + assert ( + df2.get_column("text2")[1] + == "I lost $2 yesterday and this string is longer than 32 characters." + ) + + +def test_replace_expressions() -> None: + df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"], "value": ["A", "B"]}) + out = df.select([pl.col("foo").str.replace(pl.col("foo").first(), pl.col("value"))]) + assert out.to_dict(as_series=False) == {"foo": ["A", "xyz 678 910t"]} + out = df.select([pl.col("foo").str.replace(pl.col("foo").last(), "value")]) + assert out.to_dict(as_series=False) == {"foo": ["123 bla 45 asd", "value"]} + + df = pl.DataFrame( + {"foo": ["1 bla 45 asd", "xyz 6t"], "pat": [r"\d", r"\W"], "value": ["A", "B"]} + ) + out = df.select([pl.col("foo").str.replace_all(pl.col("pat").first(), "value")]) + assert out.to_dict(as_series=False) == { + "foo": ["value bla valuevalue asd", "xyz valuet"] + } + + +@pytest.mark.parametrize( + ("pattern", "replacement", "case_insensitive", "expected"), + [ + (["say"], "", False, "Tell me what you want"), + (["me"], ["them"], False, "Tell them what you want"), + (["who"], ["them"], False, "Tell me what you want"), + (["me", "you"], "it", False, "Tell it what it want"), + (["Me", "you"], "it", False, "Tell me what it want"), + (["me", "you"], ["it"], False, "Tell it what it want"), + (["me", "you"], ["you", "me"], False, "Tell you what me want"), + (["me", "You", "them"], "it", False, "Tell it what you want"), + (["Me", "you"], "it", True, "Tell it what it want"), + (["me", "YOU"], ["you", "me"], True, "Tell you what me want"), + (pl.Series(["me", "YOU"]), ["you", "me"], False, "Tell you what you want"), + (pl.Series(["me", "YOU"]), ["you", "me"], True, "Tell you what me want"), + ], +) +def test_replace_many( + pattern: pl.Series | list[str], + replacement: pl.Series | list[str] | str, + case_insensitive: bool, + expected: str, +) -> None: + df = pl.DataFrame({"text": ["Tell me what you want"]}) + # series + assert ( + expected + == df["text"] + .str.replace_many(pattern, replacement, ascii_case_insensitive=case_insensitive) + .item() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").str.replace_many( + pattern, + replacement, + ascii_case_insensitive=case_insensitive, + ) + ).item() + ) + + +def test_replace_many_groupby() -> None: + df = pl.DataFrame( + { + "x": ["a", "b", "c", "d", "e", "f", "g", "h", "i"], + "g": [0, 0, 0, 1, 1, 1, 2, 2, 2], + } + ) + out = df.group_by("g").agg(pl.col.x.str.replace_many(pl.col.x.head(2), "")) + expected = pl.DataFrame( + { + "g": [0, 1, 2], + "x": [["", "", "c"], ["", "", "f"], ["", "", "i"]], + } + ) + assert_frame_equal(out, expected, check_row_order=False) + + +@pytest.mark.parametrize( + ("mapping", "case_insensitive", "expected"), + [ + ({}, False, "Tell me what you want"), + ({"me": "them"}, False, "Tell them what you want"), + ({"who": "them"}, False, "Tell me what you want"), + ({"me": "it", "you": "it"}, False, "Tell it what it want"), + ({"Me": "it", "you": "it"}, False, "Tell me what it want"), + ({"me": "you", "you": "me"}, False, "Tell you what me want"), + ({}, True, "Tell me what you want"), + ({"Me": "it", "you": "it"}, True, "Tell it what it want"), + ({"me": "you", "YOU": "me"}, True, "Tell you what me want"), + ], +) +def test_replace_many_mapping( + mapping: dict[str, str], + case_insensitive: bool, + expected: str, +) -> None: + df = pl.DataFrame({"text": ["Tell me what you want"]}) + # series + assert ( + expected + == df["text"] + .str.replace_many(mapping, ascii_case_insensitive=case_insensitive) + .item() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").str.replace_many( + mapping, + ascii_case_insensitive=case_insensitive, + ) + ).item() + ) + + +def test_replace_many_invalid_inputs() -> None: + df = pl.DataFrame({"text": ["Tell me what you want"]}) + + # Ensure a string as the first argument is parsed as a column name. + with pytest.raises(ColumnNotFoundError, match="me"): + df.select(pl.col("text").str.replace_many("me", "you")) + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many(1, 2)) + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many([1], [2])) + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many(["me"], None)) + + with pytest.raises(TypeError): + df.select(pl.col("text").str.replace_many(["me"])) + + with pytest.raises( + InvalidOperationError, + match="expected the same amount of patterns as replacement strings", + ): + df.select(pl.col("text").str.replace_many(["a"], ["b", "c"])) + + s = df.to_series() + + with pytest.raises(ColumnNotFoundError, match="me"): + s.str.replace_many("me", "you") # type: ignore[arg-type] + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many(["me"], None)) + + with pytest.raises(TypeError): + df.select(pl.col("text").str.replace_many(["me"])) + + with pytest.raises( + InvalidOperationError, + match="expected the same amount of patterns as replacement strings", + ): + s.str.replace_many(["a"], ["b", "c"]) + + +def test_extract_all_count() -> None: + df = pl.DataFrame({"foo": ["123 bla 45 asd", "xaz 678 910t", "boo", None]}) + assert ( + df.select( + pl.col("foo").str.extract_all(r"a").alias("extract"), + pl.col("foo").str.count_matches(r"a").alias("count"), + ).to_dict(as_series=False) + ) == {"extract": [["a", "a"], ["a"], [], None], "count": [2, 1, 0, None]} + + assert df["foo"].str.extract_all(r"a").dtype == pl.List + assert df["foo"].str.count_matches(r"a").dtype == pl.UInt32 + + +def test_count_matches_many() -> None: + df = pl.DataFrame( + { + "foo": ["123 bla 45 asd", "xyz 678 910t", None, "boo"], + "bar": [r"\d", r"[a-z]", r"\d", None], + } + ) + assert ( + df.select( + pl.col("foo").str.count_matches(pl.col("bar")).alias("count") + ).to_dict(as_series=False) + ) == {"count": [5, 4, None, None]} + + assert df["foo"].str.count_matches(df["bar"]).dtype == pl.UInt32 + + # Test broadcast. + broad = df.select( + pl.col("foo").str.count_matches(pl.col("bar").first()).alias("count"), + pl.col("foo").str.count_matches(pl.col("bar").last()).alias("count_null"), + ) + assert broad.to_dict(as_series=False) == { + "count": [5, 6, None, 0], + "count_null": [None, None, None, None], + } + assert broad.schema == {"count": pl.UInt32, "count_null": pl.UInt32} + + +def test_extract_all_many() -> None: + df = pl.DataFrame( + { + "foo": ["ab", "abc", "abcd", "foo", None, "boo"], + "re": ["a", "bc", "a.c", "a", "a", None], + } + ) + assert df["foo"].str.extract_all(df["re"]).to_list() == [ + ["a"], + ["bc"], + ["abc"], + [], + None, + None, + ] + + # Test broadcast. + broad = df.select( + pl.col("foo").str.extract_all(pl.col("re").first()).alias("a"), + pl.col("foo").str.extract_all(pl.col("re").last()).alias("null"), + ) + assert broad.to_dict(as_series=False) == { + "a": [["a"], ["a"], ["a"], [], None, []], + "null": [None] * 6, + } + assert broad.schema == {"a": pl.List(pl.String), "null": pl.List(pl.String)} + + +def test_extract_groups() -> None: + def _named_groups_builder(pattern: str, groups: dict[str, str]) -> str: + return pattern.format( + **{name: f"(?<{name}>{value})" for name, value in groups.items()} + ) + + expected = { + "authority": ["ISO", "ISO/IEC/IEEE"], + "spec_num": ["80000", "29148"], + "part_num": ["1", None], + "revision_year": ["2009", "2018"], + } + + pattern = _named_groups_builder( + r"{authority}\s{spec_num}(?:-{part_num})?(?::{revision_year})", + { + "authority": r"^ISO(?:/[A-Z]+)*", + "spec_num": r"\d+", + "part_num": r"\d+", + "revision_year": r"\d{4}", + }, + ) + + df = pl.DataFrame({"iso_code": ["ISO 80000-1:2009", "ISO/IEC/IEEE 29148:2018"]}) + + assert ( + df.select(pl.col("iso_code").str.extract_groups(pattern)) + .unnest("iso_code") + .to_dict(as_series=False) + == expected + ) + + assert df.select(pl.col("iso_code").str.extract_groups("")).to_dict( + as_series=False + ) == {"iso_code": [{"iso_code": None}, {"iso_code": None}]} + + assert df.select( + pl.col("iso_code").str.extract_groups(r"\A(ISO\S*).*?(\d+)") + ).to_dict(as_series=False) == { + "iso_code": [{"1": "ISO", "2": "80000"}, {"1": "ISO/IEC/IEEE", "2": "29148"}] + } + + assert df.select( + pl.col("iso_code").str.extract_groups(r"\A(ISO\S*).*?(?\d+)\z") + ).to_dict(as_series=False) == { + "iso_code": [ + {"1": "ISO", "year": "2009"}, + {"1": "ISO/IEC/IEEE", "year": "2018"}, + ] + } + + assert pl.select( + pl.lit(r"foobar").str.extract_groups(r"(?.{3})|(?...)") + ).to_dict(as_series=False) == {"literal": [{"foo": "foo", "bar": None}]} + + +def test_starts_ends_with() -> None: + df = pl.DataFrame( + { + "a": ["hamburger_with_tomatoes", "nuts", "lollypop", None], + "sub": ["ham", "ts", None, "anything"], + } + ) + + assert df.select( + pl.col("a").str.ends_with("pop").alias("ends_pop"), + pl.col("a").str.ends_with(pl.lit(None)).alias("ends_None"), + pl.col("a").str.ends_with(pl.col("sub")).alias("ends_sub"), + pl.col("a").str.starts_with("ham").alias("starts_ham"), + pl.col("a").str.starts_with(pl.lit(None)).alias("starts_None"), + pl.col("a").str.starts_with(pl.col("sub")).alias("starts_sub"), + ).to_dict(as_series=False) == { + "ends_pop": [False, False, True, None], + "ends_None": [None, None, None, None], + "ends_sub": [False, True, None, None], + "starts_ham": [True, False, False, None], + "starts_None": [None, None, None, None], + "starts_sub": [True, False, None, None], + } + + +def test_json_path_match_type_4905() -> None: + df = pl.DataFrame({"json_val": ['{"a":"hello"}', None, '{"a":"world"}']}) + assert df.filter( + pl.col("json_val").str.json_path_match("$.a").is_in(["hello"]) + ).to_dict(as_series=False) == {"json_val": ['{"a":"hello"}']} + + +def test_decode_strict() -> None: + df = pl.DataFrame( + {"strings": ["0IbQvTc3", "0J%2FQldCf0JA%3D", "0J%2FRgNC%2B0YHRgtC%2B"]} + ) + result = df.select(pl.col("strings").str.decode("base64", strict=False)) + expected = {"strings": [b"\xd0\x86\xd0\xbd77", None, None]} + assert result.to_dict(as_series=False) == expected + + with pytest.raises(ComputeError): + df.select(pl.col("strings").str.decode("base64", strict=True)) + + +def test_split() -> None: + df = pl.DataFrame({"x": ["a_a", None, "b", "c_c_c", ""]}) + out = df.select([pl.col("x").str.split("_")]) + + expected = pl.DataFrame( + [ + {"x": ["a", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c", "c", "c"]}, + {"x": [""]}, + ] + ) + + assert_frame_equal(out, expected) + assert_frame_equal(df["x"].str.split("_").to_frame(), expected) + + out = df.select([pl.col("x").str.split("_", inclusive=True)]) + + expected = pl.DataFrame( + [ + {"x": ["a_", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c_", "c_", "c"]}, + {"x": []}, + ] + ) + + assert_frame_equal(out, expected) + assert_frame_equal(df["x"].str.split("_", inclusive=True).to_frame(), expected) + + out = df.select([pl.col("x").str.split("")]) + + expected = pl.DataFrame( + [ + {"x": ["a", "_", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c", "_", "c", "_", "c"]}, + {"x": []}, + ] + ) + + assert_frame_equal(out, expected) + assert_frame_equal(df["x"].str.split("").to_frame(), expected) + + out = df.select([pl.col("x").str.split("", inclusive=True)]) + + expected = pl.DataFrame( + [ + {"x": ["a", "_", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c", "_", "c", "_", "c"]}, + {"x": []}, + ] + ) + + assert_frame_equal(out, expected) + assert_frame_equal(df["x"].str.split("", inclusive=True).to_frame(), expected) + + plan = ( + df.lazy() + .select( + a=pl.col("x").str.split(" ", inclusive=False), + b=pl.col("x").str.split_exact(" ", 1, inclusive=False), + ) + .explain() + ) + + assert "str.split(" in plan + assert "str.split_exact(" in plan + + plan = ( + df.lazy() + .select( + a=pl.col("x").str.split(" ", inclusive=True), + b=pl.col("x").str.split_exact(" ", 1, inclusive=True), + ) + .explain() + ) + + assert "str.split_inclusive(" in plan + assert "str.split_exact_inclusive(" in plan + + +def test_split_expr() -> None: + df = pl.DataFrame( + { + "x": ["a_a", None, "b", "c*c*c", "dddd", ""], + "by": ["_", "#", "^", "*", "", ""], + } + ) + out = df.select([pl.col("x").str.split(pl.col("by"))]) + expected = pl.DataFrame( + [ + {"x": ["a", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c", "c", "c"]}, + {"x": ["d", "d", "d", "d"]}, + {"x": []}, + ] + ) + assert_frame_equal(out, expected) + + out = df.select([pl.col("x").str.split(pl.col("by"), inclusive=True)]) + expected = pl.DataFrame( + [ + {"x": ["a_", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c*", "c*", "c"]}, + {"x": ["d", "d", "d", "d"]}, + {"x": []}, + ] + ) + assert_frame_equal(out, expected) + + +def test_split_exact() -> None: + df = pl.DataFrame({"x": ["a_a", None, "b", "c_c", ""]}) + out = df.select([pl.col("x").str.split_exact("_", 2, inclusive=False)]).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", ""], + "field_1": ["a", None, None, "c", None], + "field_2": pl.Series([None, None, None, None, None], dtype=pl.String), + } + ) + + assert_frame_equal(out, expected) + out2 = df["x"].str.split_exact("_", 2, inclusive=False).to_frame().unnest("x") + assert_frame_equal(out2, expected) + + out = df.select([pl.col("x").str.split_exact("_", 1, inclusive=True)]).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a_", None, "b", "c_", None], + "field_1": ["a", None, None, "c", None], + } + ) + assert_frame_equal(out, expected) + assert df["x"].str.split_exact("_", 1).dtype == pl.Struct + assert df["x"].str.split_exact("_", 1, inclusive=False).dtype == pl.Struct + + out = df.select([pl.col("x").str.split_exact("", 1)]).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", None], + "field_1": ["_", None, None, "_", None], + } + ) + assert_frame_equal(out, expected) + + out = df.select([pl.col("x").str.split_exact("", 1, inclusive=True)]).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", None], + "field_1": ["_", None, None, "_", None], + } + ) + assert_frame_equal(out, expected) + + +def test_split_exact_expr() -> None: + df = pl.DataFrame( + { + "x": ["a_a", None, "b", "c^c^c", "d#d", "eeee", ""], + "by": ["_", "&", "$", "^", None, "", ""], + } + ) + + out = df.select( + pl.col("x").str.split_exact(pl.col("by"), 2, inclusive=False) + ).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", None, "e", None], + "field_1": ["a", None, None, "c", None, "e", None], + "field_2": pl.Series( + [None, None, None, "c", None, "e", None], dtype=pl.String + ), + } + ) + + assert_frame_equal(out, expected) + + out2 = df.select( + pl.col("x").str.split_exact(pl.col("by"), 2, inclusive=True) + ).unnest("x") + + expected2 = pl.DataFrame( + { + "field_0": ["a_", None, "b", "c^", None, "e", None], + "field_1": ["a", None, None, "c^", None, "e", None], + "field_2": pl.Series( + [None, None, None, "c", None, "e", None], dtype=pl.String + ), + } + ) + assert_frame_equal(out2, expected2) + + +def test_splitn() -> None: + df = pl.DataFrame({"x": ["a_a", None, "b", "c_c_c", ""]}) + out = df.select([pl.col("x").str.splitn("_", 2)]).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", ""], + "field_1": ["a", None, None, "c_c", None], + } + ) + + assert_frame_equal(out, expected) + assert_frame_equal(df["x"].str.splitn("_", 2).to_frame().unnest("x"), expected) + + out = df.select([pl.col("x").str.splitn("", 2)]).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", None], + "field_1": ["_a", None, None, "_c_c", None], + } + ) + + assert_frame_equal(out, expected) + assert_frame_equal(df["x"].str.splitn("", 2).to_frame().unnest("x"), expected) + + +def test_splitn_expr() -> None: + df = pl.DataFrame( + { + "x": ["a_a", None, "b", "c^c^c", "d#d", "eeee", ""], + "by": ["_", "&", "$", "^", None, "", ""], + } + ) + + out = df.select(pl.col("x").str.splitn(pl.col("by"), 2)).unnest("x") + + expected = pl.DataFrame( + { + "field_0": ["a", None, "b", "c", None, "e", None], + "field_1": ["a", None, None, "c^c", None, "eee", None], + } + ) + + assert_frame_equal(out, expected) + + +def test_titlecase() -> None: + df = pl.DataFrame( + { + "misc": [ + "welcome to my world", + "double space", + "and\ta\t tab", + "by jean-paul sartre, 'esq'", + "SOMETIMES/life/gives/you/a/2nd/chance", + ], + } + ) + expected = [ + "Welcome To My World", + "Double Space", + "And\tA\t Tab", + "By Jean-Paul Sartre, 'Esq'", + "Sometimes/Life/Gives/You/A/2nd/Chance", + ] + actual = df.select(pl.col("misc").str.to_titlecase()).to_series() + for ex, act in zip(expected, actual): + assert ex == act, f"{ex} != {act}" + + df = pl.DataFrame( + { + "quotes": [ + "'e.t. phone home'", + "you talkin' to me?", + "i feel the need--the need for speed", + "to infinity,and BEYOND!", + "say 'what' again!i dare you - I\u00a0double-dare you!", + "What.we.got.here... is#failure#to#communicate", + ] + } + ) + expected_str = [ + "'E.T. Phone Home'", + "You Talkin' To Me?", + "I Feel The Need--The Need For Speed", + "To Infinity,And Beyond!", + "Say 'What' Again!I Dare You - I\u00a0Double-Dare You!", + "What.We.Got.Here... Is#Failure#To#Communicate", + ] + expected_py = [s.title() for s in df["quotes"].to_list()] + for ex_str, ex_py, act in zip( + expected_str, expected_py, df["quotes"].str.to_titlecase() + ): + assert ex_str == act, f"{ex_str} != {act}" + assert ex_py == act, f"{ex_py} != {act}" + + +def test_string_replace_with_nulls_10124() -> None: + df = pl.DataFrame({"col1": ["S", "S", "S", None, "S", "S", "S", "S"]}) + + assert df.select( + pl.col("col1"), + pl.col("col1").str.replace("S", "O", n=1).alias("n_1"), + pl.col("col1").str.replace("S", "O", n=3).alias("n_3"), + ).to_dict(as_series=False) == { + "col1": ["S", "S", "S", None, "S", "S", "S", "S"], + "n_1": ["O", "O", "O", None, "O", "O", "O", "O"], + "n_3": ["O", "O", "O", None, "O", "O", "O", "O"], + } + + +def test_string_extract_groups_lazy_schema_10305() -> None: + df = pl.LazyFrame( + data={ + "url": [ + "http://vote.com/ballon_dor?candidate=messi&ref=python", + "http://vote.com/ballon_dor?candidate=weghorst&ref=polars", + "http://vote.com/ballon_dor?error=404&ref=rust", + ] + } + ) + pattern = r"candidate=(?\w+)&ref=(?\w+)" + df = df.select(captures=pl.col("url").str.extract_groups(pattern)).unnest( + "captures" + ) + + assert df.collect_schema() == {"candidate": pl.String, "ref": pl.String} + + +def test_string_reverse() -> None: + df = pl.DataFrame( + { + "text": [None, "foo", "bar", "i like pizza&#", None, "man\u0303ana"], + } + ) + expected = pl.DataFrame( + [ + pl.Series( + "text", + [None, "oof", "rab", "#&azzip ekil i", None, "anan\u0303am"], + dtype=pl.String, + ), + ] + ) + + result = df.select(pl.col("text").str.reverse()) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("data", "expected_data"), + [ + (["", None, "a"], ["", None, "b"]), + ([None, None, "a"], [None, None, "b"]), + (["", "", ""], ["", "", ""]), + ([None, None, None], [None, None, None]), + (["a", "", None], ["b", "", None]), + ], +) +def test_replace_lit_n_char_13385( + data: list[str | None], expected_data: list[str | None] +) -> None: + s = pl.Series(data, dtype=pl.String) + res = s.str.replace("a", "b", literal=True) + expected_s = pl.Series(expected_data, dtype=pl.String) + assert_series_equal(res, expected_s) + + +def test_extract_many() -> None: + df = pl.DataFrame({"values": ["discontent", "foobar"]}) + patterns = ["winter", "disco", "onte", "discontent"] + assert df.with_columns( + pl.col("values").str.extract_many(patterns, overlapping=False).alias("matches"), + pl.col("values") + .str.extract_many(patterns, overlapping=True) + .alias("matches_overlapping"), + ).to_dict(as_series=False) == { + "values": ["discontent", "foobar"], + "matches": [["disco"], []], + "matches_overlapping": [["disco", "onte", "discontent"], []], + } + + # many patterns + df = pl.DataFrame( + { + "values": ["discontent", "rhapsody"], + "patterns": [ + ["winter", "disco", "onte", "discontent"], + ["rhap", "ody", "coalesce"], + ], + } + ) + + # extract_many + assert df.select(pl.col("values").str.extract_many("patterns")).to_dict( + as_series=False + ) == {"values": [["disco"], ["rhap", "ody"]]} + + # find_many + f1 = df.select(pl.col("values").str.find_many("patterns")) + f2 = df["values"].str.find_many(df["patterns"]) + + assert_series_equal(f1["values"], f2) + assert f2.to_list() == [[0], [0, 5]] + + +def test_json_decode_raise_on_data_type_mismatch_13061() -> None: + assert_series_equal( + pl.Series(["null", "null"]).str.json_decode(infer_schema_length=1), + pl.Series([None, None]), + ) + + with pytest.raises(ComputeError): + pl.Series(["null", "1"]).str.json_decode(infer_schema_length=1) + + assert_series_equal( + pl.Series(["null", "1"]).str.json_decode(infer_schema_length=2), + pl.Series([None, 1]), + ) + + +def test_json_decode_struct_schema() -> None: + with pytest.raises(ComputeError, match="extra field in struct data: b"): + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + infer_schema_length=1 + ) + + assert_series_equal( + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + infer_schema_length=2 + ), + pl.Series([{"a": 1, "b": None}, {"a": 2, "b": 2}]), + ) + + # If the schema was explicitly given, then we ignore extra fields. + # TODO: There should be a `columns=` parameter to this. + assert_series_equal( + pl.Series([r'{"a": 1}', r'{"a": 2, "b": 2}']).str.json_decode( + dtype=pl.Struct({"a": pl.Int64}) + ), + pl.Series([{"a": 1}, {"a": 2}]), + ) + + +def test_escape_regex() -> None: + df = pl.DataFrame({"text": ["abc", "def", None, "abc(\\w+)"]}) + result_df = df.with_columns(pl.col("text").str.escape_regex().alias("escaped")) + expected_df = pl.DataFrame( + { + "text": ["abc", "def", None, "abc(\\w+)"], + "escaped": ["abc", "def", None, "abc\\(\\\\w\\+\\)"], + } + ) + + assert_frame_equal(result_df, expected_df) + assert_series_equal(result_df["escaped"], expected_df["escaped"]) + + +@pytest.mark.parametrize( + ("form", "expected_data"), + [ + ("NFC", ["01²", "KADOKAWA"]), # noqa: RUF001 + ("NFD", ["01²", "KADOKAWA"]), # noqa: RUF001 + ("NFKC", ["012", "KADOKAWA"]), + ("NFKD", ["012", "KADOKAWA"]), + ], +) +def test_string_normalize(form: Any, expected_data: list[str | None]) -> None: + s = pl.Series(["01²", "KADOKAWA"], dtype=pl.String) # noqa: RUF001 + res = s.str.normalize(form) + expected_s = pl.Series(expected_data, dtype=pl.String) + assert_series_equal(res, expected_s) + + +def test_string_normalize_wrong_input() -> None: + with pytest.raises(ValueError, match="`form` must be one of"): + pl.Series(["01²"], dtype=pl.String).str.normalize("foobar") # type: ignore[arg-type] + + +def test_to_integer_unequal_lengths_22034() -> None: + s = pl.Series("a", ["1", "2", "3"], pl.String) + with pytest.raises(pl.exceptions.ShapeError): + s.str.to_integer(base=pl.Series([4, 5, 5, 4])) + + +def test_broadcast_self() -> None: + s = pl.Series("a", ["3"], pl.String) + with pytest.raises( + pl.exceptions.ComputeError, match="strict integer parsing failed" + ): + s.str.to_integer(base=pl.Series([2, 2, 3, 4])) + + +def test_strptime_unequal_length_22018() -> None: + s = pl.Series(["2020-01-01 01:00Z", "2020-01-01 02:00Z"]) + with pytest.raises(pl.exceptions.ShapeError): + s.str.strptime( + pl.Datetime, "%Y-%m-%d %H:%M%#z", ambiguous=pl.Series(["a", "b", "d"]) + ) + + +@pytest.mark.parametrize("inclusive", [False, True]) +def test_str_split_unequal_length_22018(inclusive: bool) -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series(["a-c", "x-y"]).str.split( + pl.Series(["-", "/", "+"]), inclusive=inclusive + ) + + +def test_str_split_self_broadcast() -> None: + assert_series_equal( + pl.Series(["a-/c"]).str.split(pl.Series(["-", "/", "+"])), + pl.Series([["a", "/c"], ["a-", "c"], ["a-/c"]]), + ) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/__init__.py b/py-polars/tests/unit/operations/namespaces/temporal/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py b/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py new file mode 100644 index 000000000000..b7c77ef473bb --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import datetime as dt +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import assume, given + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars._typing import Roll, TimeUnit + +from zoneinfo import ZoneInfo + + +def test_add_business_days() -> None: + # (Expression, expression) + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [-1, 5], + } + ) + result = df.select(result=pl.col("start").dt.add_business_days("n"))["result"] + expected = pl.Series("result", [date(2019, 12, 31), date(2020, 1, 9)], pl.Date) + assert_series_equal(result, expected) + + # (Expression, scalar) + result = df.select(result=pl.col("start").dt.add_business_days(5))["result"] + expected = pl.Series("result", [date(2020, 1, 8), date(2020, 1, 9)], pl.Date) + assert_series_equal(result, expected) + + # (Scalar, expression) + result = df.select( + result=pl.lit(date(2020, 1, 1), dtype=pl.Date).dt.add_business_days(pl.col("n")) + )["result"] + expected = pl.Series("result", [date(2019, 12, 31), date(2020, 1, 8)], pl.Date) + assert_series_equal(result, expected) + + # (Scalar, scalar) + result = df.select( + result=pl.lit(date(2020, 1, 1), dtype=pl.Date).dt.add_business_days(5) + )["result"] + expected = pl.Series("result", [date(2020, 1, 8)], pl.Date) + assert_series_equal(result, expected) + + +def test_add_business_day_w_week_mask() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [1, 5], + } + ) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", week_mask=(True, True, True, True, True, True, False) + ) + )["result"] + expected = pl.Series("result", [date(2020, 1, 2), date(2020, 1, 8)]) + assert_series_equal(result, expected) + + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", week_mask=(True, True, True, True, False, False, True) + ) + )["result"] + expected = pl.Series("result", [date(2020, 1, 2), date(2020, 1, 9)]) + assert_series_equal(result, expected) + + +def test_add_business_day_w_week_mask_invalid() -> None: + with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"): + pl.col("start").dt.add_business_days("n", week_mask=(False, 0)) # type: ignore[arg-type] + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [1, 5], + } + ) + with pytest.raises( + ComputeError, match="`week_mask` must have at least one business day" + ): + df.select(pl.col("start").dt.add_business_days("n", week_mask=[False] * 7)) + + +def test_add_business_days_schema() -> None: + lf = pl.LazyFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2)], + "n": [1, 5], + } + ) + result = lf.select( + result=pl.col("start").dt.add_business_days("n"), + ) + assert result.collect_schema()["result"] == pl.Date + assert result.collect().schema["result"] == pl.Date + assert 'col("start").add_business_days([col("n")])' in result.explain() + + +def test_add_business_days_w_holidays() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)], + "n": [1, 5, 7], + } + ) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", holidays=[date(2020, 1, 3), date(2020, 1, 9)] + ), + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 2), date(2020, 1, 13), date(2020, 1, 15)] + ) + assert_series_equal(result, expected) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", holidays=[date(2020, 1, 1), date(2020, 1, 2)], roll="backward" + ), + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 3), date(2020, 1, 9), date(2020, 1, 13)] + ) + assert_series_equal(result, expected) + result = df.select( + result=pl.col("start").dt.add_business_days( + "n", + holidays=[ + date(2019, 1, 1), + date(2020, 1, 1), + date(2020, 1, 2), + date(2021, 1, 1), + ], + roll="backward", + ), + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 3), date(2020, 1, 9), date(2020, 1, 13)] + ) + assert_series_equal(result, expected) + + +def test_add_business_days_w_roll() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 4)], + "n": [1, 5, 7], + } + ) + with pytest.raises(ComputeError, match="is not a business date"): + df.select(result=pl.col("start").dt.add_business_days("n")) + result = df.select( + result=pl.col("start").dt.add_business_days("n", roll="forward") + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 2), date(2020, 1, 9), date(2020, 1, 15)] + ) + assert_series_equal(result, expected) + result = df.select( + result=pl.col("start").dt.add_business_days("n", roll="backward") + )["result"] + expected = pl.Series( + "result", [date(2020, 1, 2), date(2020, 1, 9), date(2020, 1, 14)] + ) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("time_zone", [None, "Europe/London", "Asia/Kathmandu"]) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_add_business_days_datetime(time_zone: str | None, time_unit: TimeUnit) -> None: + tzinfo = ZoneInfo(time_zone) if time_zone is not None else None + df = pl.DataFrame( + { + "start": [ + datetime(2020, 3, 28, 1, tzinfo=tzinfo), + datetime(2020, 1, 10, 4, tzinfo=tzinfo), + ] + }, + schema={"start": pl.Datetime(time_unit, time_zone)}, + ) + result = df.select( + result=pl.col("start").dt.add_business_days(2, week_mask=[True] * 7) + )["result"] + expected = pl.Series( + "result", + [datetime(2020, 3, 30, 1), datetime(2020, 1, 12, 4)], + pl.Datetime(time_unit), + ).dt.replace_time_zone(time_zone) + assert_series_equal(result, expected) + + with pytest.raises(ComputeError, match="is not a business date"): + df.select(result=pl.col("start").dt.add_business_days(2)) + + +def test_add_business_days_invalid() -> None: + df = pl.DataFrame({"start": [timedelta(1)]}) + with pytest.raises(InvalidOperationError, match="expected date or datetime"): + df.select(result=pl.col("start").dt.add_business_days(2, week_mask=[True] * 7)) + df = pl.DataFrame({"start": [date(2020, 1, 1)]}) + with pytest.raises( + InvalidOperationError, + match="expected Int64, Int32, UInt64, or UInt32, got f64", + ): + df.select( + result=pl.col("start").dt.add_business_days(1.5, week_mask=[True] * 7) + ) + with pytest.raises( + ValueError, + match="`roll` must be one of {'raise', 'forward', 'backward'}, got cabbage", + ): + df.select(result=pl.col("start").dt.add_business_days(1, roll="cabbage")) # type: ignore[arg-type] + + +def test_add_business_days_w_nulls() -> None: + df = pl.DataFrame( + { + "start": [date(2020, 3, 28), None], + "n": [None, 2], + }, + ) + result = df.select(result=pl.col("start").dt.add_business_days("n"))["result"] + expected = pl.Series("result", [None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + result = df.select( + result=pl.col("start").dt.add_business_days(pl.lit(None, dtype=pl.Int32)) + )["result"] + assert_series_equal(result, expected) + + result = df.select(result=pl.lit(None, dtype=pl.Date).dt.add_business_days("n"))[ + "result" + ] + assert_series_equal(result, expected) + + result = df.select( + result=pl.lit(None, dtype=pl.Date).dt.add_business_days( + pl.lit(None, dtype=pl.Int32) + ) + )["result"] + expected = pl.Series("result", [None], dtype=pl.Date) + assert_series_equal(result, expected) + + +@given( + start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + n=st.integers(min_value=-100, max_value=100), + week_mask=st.lists( + st.sampled_from([True, False]), + min_size=7, + max_size=7, + ), + holidays=st.lists( + st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)), + min_size=0, + max_size=100, + ), + roll=st.sampled_from(["forward", "backward"]), +) +def test_against_np_busday_offset( + start: dt.date, + n: int, + week_mask: tuple[bool, ...], + holidays: list[dt.date], + roll: Roll, +) -> None: + assume(any(week_mask)) + result = ( + pl.DataFrame({"start": [start]}) + .select( + res=pl.col("start").dt.add_business_days( + n, week_mask=week_mask, holidays=holidays, roll=roll + ) + )["res"] + .item() + ) + expected = np.busday_offset( + start, n, weekmask=week_mask, holidays=holidays, roll=roll + ) + assert result == expected + + +def test_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([date(2088, 8, 5)] * 2).dt.add_business_days(pl.Series([1] * 3)) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py new file mode 100644 index 000000000000..afc25c35ed8e --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -0,0 +1,1520 @@ +from __future__ import annotations + +from collections import OrderedDict +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import pytest +from hypothesis import given + +import polars as pl +from polars.datatypes import DTYPE_TEMPORAL_UNITS +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series + +if TYPE_CHECKING: + from polars._typing import PolarsDataType, TemporalLiteral, TimeUnit + + +@pytest.fixture +def series_of_int_dates() -> pl.Series: + return pl.Series([10000, 20000, 30000], dtype=pl.Date) + + +@pytest.fixture +def series_of_str_dates() -> pl.Series: + return pl.Series(["2020-01-01 00:00:00.000000000", "2020-02-02 03:20:10.987654321"]) + + +def test_dt_to_string(series_of_int_dates: pl.Series) -> None: + expected_str_dates = pl.Series(["1997-05-19", "2024-10-04", "2052-02-20"]) + + assert series_of_int_dates.dtype == pl.Date + assert_series_equal(series_of_int_dates.dt.to_string("%F"), expected_str_dates) + + # Check strftime alias as well + assert_series_equal(series_of_int_dates.dt.strftime("%F"), expected_str_dates) + + +@pytest.mark.parametrize( + ("unit_attr", "expected"), + [ + ("millennium", pl.Series(values=[2, 3, 3], dtype=pl.Int32)), + ("century", pl.Series(values=[20, 21, 21], dtype=pl.Int32)), + ("year", pl.Series(values=[1997, 2024, 2052], dtype=pl.Int32)), + ("iso_year", pl.Series(values=[1997, 2024, 2052], dtype=pl.Int32)), + ("quarter", pl.Series(values=[2, 4, 1], dtype=pl.Int8)), + ("month", pl.Series(values=[5, 10, 2], dtype=pl.Int8)), + ("week", pl.Series(values=[21, 40, 8], dtype=pl.Int8)), + ("day", pl.Series(values=[19, 4, 20], dtype=pl.Int8)), + ("weekday", pl.Series(values=[1, 5, 2], dtype=pl.Int8)), + ("ordinal_day", pl.Series(values=[139, 278, 51], dtype=pl.Int16)), + ], +) +@pytest.mark.parametrize("time_zone", ["Asia/Kathmandu", None]) +def test_dt_extract_datetime_component( + unit_attr: str, + expected: pl.Series, + series_of_int_dates: pl.Series, + time_zone: str | None, +) -> None: + assert_series_equal(getattr(series_of_int_dates.dt, unit_attr)(), expected) + assert_series_equal( + getattr( + series_of_int_dates.cast(pl.Datetime).dt.replace_time_zone(time_zone).dt, + unit_attr, + )(), + expected, + ) + + +@pytest.mark.parametrize( + ("unit_attr", "expected"), + [ + ("hour", pl.Series(values=[0, 3], dtype=pl.Int8)), + ("minute", pl.Series(values=[0, 20], dtype=pl.Int8)), + ("second", pl.Series(values=[0, 10], dtype=pl.Int8)), + ("millisecond", pl.Series(values=[0, 987], dtype=pl.Int32)), + ("microsecond", pl.Series(values=[0, 987654], dtype=pl.Int32)), + ("nanosecond", pl.Series(values=[0, 987654321], dtype=pl.Int32)), + ], +) +def test_strptime_extract_times( + unit_attr: str, + expected: pl.Series, + series_of_int_dates: pl.Series, + series_of_str_dates: pl.Series, +) -> None: + s = series_of_str_dates.str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S.%9f") + + assert_series_equal(getattr(s.dt, unit_attr)(), expected) + + +@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu"]) +@pytest.mark.parametrize( + ("attribute", "expected"), + [ + ("date", date(2022, 1, 1)), + ("time", time(23)), + ], +) +def test_dt_date_and_time( + attribute: str, time_zone: None | str, expected: date | time +) -> None: + ser = pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone) + result = getattr(ser.dt, attribute)().item() + assert result == expected + + +@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu"]) +@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) +def test_dt_replace_time_zone_none(time_zone: str | None, time_unit: TimeUnit) -> None: + ser = ( + pl.Series([datetime(2022, 1, 1, 23)]) + .dt.cast_time_unit(time_unit) + .dt.replace_time_zone(time_zone) + ) + result = ser.dt.replace_time_zone(None) + expected = datetime(2022, 1, 1, 23) + assert result.dtype == pl.Datetime(time_unit, None) + assert result.item() == expected + + +def test_dt_datetime_deprecated() -> None: + s = pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone("Asia/Kathmandu") + with pytest.deprecated_call(): + result = s.dt.datetime() + expected = datetime(2022, 1, 1, 23) + assert result.dtype == pl.Datetime(time_zone=None) + assert result.item() == expected + + +@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu", "UTC"]) +def test_local_date_sortedness(time_zone: str | None) -> None: + # singleton + ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort() + result = ser.dt.date() + assert result.flags["SORTED_ASC"] + + # 2 elements + ser = ( + pl.Series([datetime(2022, 1, 1, 23)] * 2).dt.replace_time_zone(time_zone) + ).sort() + result = ser.dt.date() + assert result.flags["SORTED_ASC"] + + +@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu", "UTC"]) +def test_local_time_sortedness(time_zone: str | None) -> None: + # singleton - always sorted + ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort() + result = ser.dt.time() + assert result.flags["SORTED_ASC"] + + # three elements - not sorted + ser = ( + pl.Series( + [ + datetime(2022, 1, 1, 23), + datetime(2022, 1, 2, 21), + datetime(2022, 1, 3, 22), + ] + ).dt.replace_time_zone(time_zone) + ).sort() + result = ser.dt.time() + assert not result.flags["SORTED_ASC"] + assert not result.flags["SORTED_DESC"] + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_local_time_before_epoch(time_unit: TimeUnit) -> None: + ser = pl.Series([datetime(1969, 7, 21, 2, 56, 2, 123000)]).dt.cast_time_unit( + time_unit + ) + result = ser.dt.time().item() + expected = time(2, 56, 2, 123000) + assert result == expected + + +@pytest.mark.parametrize( + ("time_zone", "offset", "expected"), + [ + (None, "1d", True), + ("Europe/London", "1d", False), + ("UTC", "1d", True), + (None, "1m", True), + ("Europe/London", "1m", True), + ("UTC", "1m", True), + (None, "1w", True), + ("Europe/London", "1w", False), + ("UTC", "1w", True), + (None, "1h", True), + ("Europe/London", "1h", True), + ("UTC", "1h", True), + ], +) +def test_offset_by_sortedness( + time_zone: str | None, offset: str, expected: bool +) -> None: + s = pl.datetime_range( + datetime(2020, 10, 25), + datetime(2020, 10, 25, 3), + "30m", + time_zone=time_zone, + eager=True, + ).sort() + assert s.flags["SORTED_ASC"] + assert not s.flags["SORTED_DESC"] + result = s.dt.offset_by(offset) + assert result.flags["SORTED_ASC"] == expected + assert not result.flags["SORTED_DESC"] + + +def test_offset_by_invalid_duration() -> None: + with pytest.raises( + InvalidOperationError, match="expected leading integer in the duration string" + ): + pl.Series([datetime(2022, 3, 20, 5, 7)]).dt.offset_by("P") + + +def test_offset_by_missing_unit() -> None: + with pytest.raises( + InvalidOperationError, + match="expected a unit to follow integer in the duration string '1'", + ): + pl.Series([datetime(2022, 3, 20, 5, 7)]).dt.offset_by("1") + + with pytest.raises( + InvalidOperationError, + match="expected a unit to follow integer in the duration string '1mo23d4'", + ): + pl.Series([datetime(2022, 3, 20, 5, 7)]).dt.offset_by("1mo23d4") + + with pytest.raises( + InvalidOperationError, + match="expected a unit to follow integer in the duration string '-2d1'", + ): + pl.Series([datetime(2022, 3, 20, 5, 7)]).dt.offset_by("-2d1") + + with pytest.raises( + InvalidOperationError, + match="expected a unit to follow integer in the duration string '1d2'", + ): + pl.DataFrame( + {"a": [datetime(2022, 3, 20, 5, 7)] * 2, "b": ["1d", "1d2"]} + ).select(pl.col("a").dt.offset_by(pl.col("b"))) + + +def test_dt_datetime_date_time_invalid() -> None: + with pytest.raises(ComputeError, match="expected Datetime or Date"): + pl.Series([time(23)]).dt.date() + with pytest.raises(ComputeError, match="expected Datetime or Date"): + pl.Series([timedelta(1)]).dt.date() + with pytest.raises(ComputeError, match="expected Datetime or Time"): + pl.Series([timedelta(1)]).dt.time() + with pytest.raises(ComputeError, match="expected Datetime or Time"): + pl.Series([date(2020, 1, 1)]).dt.time() + + +@pytest.mark.parametrize( + ("dt", "expected"), + [ + (datetime(2022, 3, 15, 3), datetime(2022, 3, 1, 3)), + (datetime(2022, 3, 15, 3, 2, 1, 123000), datetime(2022, 3, 1, 3, 2, 1, 123000)), + (datetime(2022, 3, 15), datetime(2022, 3, 1)), + (datetime(2022, 3, 1), datetime(2022, 3, 1)), + ], +) +@pytest.mark.parametrize( + ("tzinfo", "time_zone"), + [ + (None, None), + (ZoneInfo("Asia/Kathmandu"), "Asia/Kathmandu"), + ], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_month_start_datetime( + dt: datetime, + expected: datetime, + time_unit: TimeUnit, + tzinfo: ZoneInfo | None, + time_zone: str | None, +) -> None: + ser = pl.Series([dt]).dt.replace_time_zone(time_zone).dt.cast_time_unit(time_unit) + result = ser.dt.month_start().item() + assert result == expected.replace(tzinfo=tzinfo) + + +@pytest.mark.parametrize( + ("dt", "expected"), + [ + (date(2022, 3, 15), date(2022, 3, 1)), + (date(2022, 3, 31), date(2022, 3, 1)), + ], +) +def test_month_start_date(dt: date, expected: date) -> None: + ser = pl.Series([dt]) + result = ser.dt.month_start().item() + assert result == expected + + +@pytest.mark.parametrize( + ("dt", "expected"), + [ + (datetime(2022, 3, 15, 3), datetime(2022, 3, 31, 3)), + ( + datetime(2022, 3, 15, 3, 2, 1, 123000), + datetime(2022, 3, 31, 3, 2, 1, 123000), + ), + (datetime(2022, 3, 15), datetime(2022, 3, 31)), + (datetime(2022, 3, 31), datetime(2022, 3, 31)), + ], +) +@pytest.mark.parametrize( + ("tzinfo", "time_zone"), + [ + (None, None), + (ZoneInfo("Asia/Kathmandu"), "Asia/Kathmandu"), + ], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_month_end_datetime( + dt: datetime, + expected: datetime, + time_unit: TimeUnit, + tzinfo: ZoneInfo | None, + time_zone: str | None, +) -> None: + ser = pl.Series([dt]).dt.replace_time_zone(time_zone).dt.cast_time_unit(time_unit) + result = ser.dt.month_end().item() + assert result == expected.replace(tzinfo=tzinfo) + + +@pytest.mark.parametrize( + ("dt", "expected"), + [ + (date(2022, 3, 15), date(2022, 3, 31)), + (date(2022, 3, 31), date(2022, 3, 31)), + ], +) +def test_month_end_date(dt: date, expected: date) -> None: + ser = pl.Series([dt]) + result = ser.dt.month_end().item() + assert result == expected + + +def test_month_start_end_invalid() -> None: + ser = pl.Series([time(1, 2, 3)]) + with pytest.raises( + InvalidOperationError, + match=r"`month_start` operation not supported for dtype `time` \(expected: date/datetime\)", + ): + ser.dt.month_start() + with pytest.raises( + InvalidOperationError, + match=r"`month_end` operation not supported for dtype `time` \(expected: date/datetime\)", + ): + ser.dt.month_end() + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_base_utc_offset(time_unit: TimeUnit) -> None: + ser = pl.datetime_range( + datetime(2011, 12, 29), + datetime(2012, 1, 1), + "2d", + time_zone="Pacific/Apia", + eager=True, + ).dt.cast_time_unit(time_unit) + result = ser.dt.base_utc_offset().rename("base_utc_offset") + expected = pl.Series( + "base_utc_offset", + [-11 * 3600 * 1000, 13 * 3600 * 1000], + dtype=pl.Duration("ms"), + ) + assert_series_equal(result, expected) + + +def test_base_utc_offset_lazy_schema() -> None: + ser = pl.datetime_range( + datetime(2020, 10, 25), + datetime(2020, 10, 26), + time_zone="Europe/London", + eager=True, + ) + df = pl.DataFrame({"ts": ser}).lazy() + result = df.with_columns( + base_utc_offset=pl.col("ts").dt.base_utc_offset() + ).collect_schema() + expected = { + "ts": pl.Datetime(time_unit="us", time_zone="Europe/London"), + "base_utc_offset": pl.Duration(time_unit="ms"), + } + assert result == expected + + +def test_base_utc_offset_invalid() -> None: + ser = pl.datetime_range(datetime(2020, 10, 25), datetime(2020, 10, 26), eager=True) + with pytest.raises( + InvalidOperationError, + match=r"`base_utc_offset` operation not supported for dtype `datetime\[μs\]` \(expected: time-zone-aware datetime\)", + ): + ser.dt.base_utc_offset().rename("base_utc_offset") + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_dst_offset(time_unit: TimeUnit) -> None: + ser = pl.datetime_range( + datetime(2020, 10, 25), + datetime(2020, 10, 26), + time_zone="Europe/London", + eager=True, + ).dt.cast_time_unit(time_unit) + result = ser.dt.dst_offset().rename("dst_offset") + expected = pl.Series("dst_offset", [3_600 * 1_000, 0], dtype=pl.Duration("ms")) + assert_series_equal(result, expected) + + +def test_dst_offset_lazy_schema() -> None: + ser = pl.datetime_range( + datetime(2020, 10, 25), + datetime(2020, 10, 26), + time_zone="Europe/London", + eager=True, + ) + df = pl.DataFrame({"ts": ser}).lazy() + result = df.with_columns(dst_offset=pl.col("ts").dt.dst_offset()).collect_schema() + expected = { + "ts": pl.Datetime(time_unit="us", time_zone="Europe/London"), + "dst_offset": pl.Duration(time_unit="ms"), + } + assert result == expected + + +def test_dst_offset_invalid() -> None: + ser = pl.datetime_range(datetime(2020, 10, 25), datetime(2020, 10, 26), eager=True) + with pytest.raises( + InvalidOperationError, + match=r"`dst_offset` operation not supported for dtype `datetime\[μs\]` \(expected: time-zone-aware datetime\)", + ): + ser.dt.dst_offset().rename("dst_offset") + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("d", pl.Series(values=[18262, 18294], dtype=pl.Int32)), + ("s", pl.Series(values=[1_577_836_800, 1_580_613_610], dtype=pl.Int64)), + ( + "ms", + pl.Series(values=[1_577_836_800_000, 1_580_613_610_987], dtype=pl.Int64), + ), + ], +) +def test_strptime_epoch( + time_unit: TimeUnit, + expected: pl.Series, + series_of_str_dates: pl.Series, +) -> None: + s = series_of_str_dates.str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S.%9f") + + assert_series_equal(s.dt.epoch(time_unit=time_unit), expected) + + +def test_strptime_fractional_seconds(series_of_str_dates: pl.Series) -> None: + s = series_of_str_dates.str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S.%9f") + + assert_series_equal( + s.dt.second(fractional=True), + pl.Series([0.0, 10.987654321], dtype=pl.Float64), + ) + + +@pytest.mark.parametrize( + ("unit_attr", "expected"), + [ + ("total_days", pl.Series([1])), + ("total_hours", pl.Series([24])), + ("total_minutes", pl.Series([24 * 60])), + ("total_seconds", pl.Series([3600 * 24])), + ("total_milliseconds", pl.Series([3600 * 24 * int(1e3)])), + ("total_microseconds", pl.Series([3600 * 24 * int(1e6)])), + ("total_nanoseconds", pl.Series([3600 * 24 * int(1e9)])), + ], +) +def test_duration_extract_times( + unit_attr: str, + expected: pl.Series, +) -> None: + duration = pl.Series([datetime(2022, 1, 2)]) - pl.Series([datetime(2022, 1, 1)]) + + assert_series_equal(getattr(duration.dt, unit_attr)(), expected) + + +@pytest.mark.parametrize( + ("time_unit", "every"), + [ + ("ms", "1h"), + ("us", "1h0m0s"), + ("ns", timedelta(hours=1)), + ], + ids=["milliseconds", "microseconds", "nanoseconds"], +) +def test_truncate( + time_unit: TimeUnit, + every: str | timedelta, +) -> None: + start, stop = datetime(2022, 1, 1), datetime(2022, 1, 2) + s = pl.datetime_range( + start, + stop, + timedelta(minutes=30), + time_unit=time_unit, + eager=True, + ).alias(f"dates[{time_unit}]") + + # can pass strings and time-deltas + out = s.dt.truncate(every) + assert out.dt[0] == start + assert out.dt[1] == start + assert out.dt[2] == start + timedelta(hours=1) + assert out.dt[3] == start + timedelta(hours=1) + # ... + assert out.dt[-3] == stop - timedelta(hours=1) + assert out.dt[-2] == stop - timedelta(hours=1) + assert out.dt[-1] == stop + + +def test_truncate_negative() -> None: + """Test that truncating to a negative duration gives a helpful error message.""" + df = pl.DataFrame( + { + "date": [date(1895, 5, 7), date(1955, 11, 5)], + "datetime": [datetime(1895, 5, 7), datetime(1955, 11, 5)], + "duration": ["-1m", "1m"], + } + ) + + with pytest.raises( + ComputeError, match="cannot truncate a Date to a negative duration" + ): + df.select(pl.col("date").dt.truncate("-1m")) + + with pytest.raises( + ComputeError, match="cannot truncate a Datetime to a negative duration" + ): + df.select(pl.col("datetime").dt.truncate("-1m")) + + with pytest.raises( + ComputeError, match="cannot truncate a Date to a negative duration" + ): + df.select(pl.col("date").dt.truncate(pl.col("duration"))) + + with pytest.raises( + ComputeError, match="cannot truncate a Datetime to a negative duration" + ): + df.select(pl.col("datetime").dt.truncate(pl.col("duration"))) + + +@pytest.mark.parametrize( + ("time_unit", "every"), + [ + ("ms", "1h"), + ("us", "1h0m0s"), + ("ns", timedelta(hours=1)), + ], + ids=["milliseconds", "microseconds", "nanoseconds"], +) +def test_round( + time_unit: TimeUnit, + every: str | timedelta, +) -> None: + start, stop = datetime(2022, 1, 1), datetime(2022, 1, 2) + s = pl.datetime_range( + start, + stop, + timedelta(minutes=30), + time_unit=time_unit, + eager=True, + ).alias(f"dates[{time_unit}]") + + # can pass strings and time-deltas + out = s.dt.round(every) + assert out.dt[0] == start + assert out.dt[1] == start + timedelta(hours=1) + assert out.dt[2] == start + timedelta(hours=1) + assert out.dt[3] == start + timedelta(hours=2) + # ... + assert out.dt[-3] == stop - timedelta(hours=1) + assert out.dt[-2] == stop + assert out.dt[-1] == stop + + +def test_round_expr() -> None: + df = pl.DataFrame( + { + "date": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20, 5, 7, 18), + datetime(2022, 4, 3, 13, 30, 32), + None, + datetime(2022, 12, 1), + ], + "every": ["1y", "1mo", "1m", "1m", "1mo", None], + } + ) + + output = df.select( + all_expr=pl.col("date").dt.round(every=pl.col("every")), + date_lit=pl.lit(datetime(2022, 4, 3, 13, 30, 32)).dt.round( + every=pl.col("every") + ), + every_lit=pl.col("date").dt.round("1d"), + ) + + expected = pl.DataFrame( + { + "all_expr": [ + datetime(2023, 1, 1), + datetime(2023, 10, 1), + datetime(2022, 3, 20, 5, 7), + datetime(2022, 4, 3, 13, 31), + None, + None, + ], + "date_lit": [ + datetime(2022, 1, 1), + datetime(2022, 4, 1), + datetime(2022, 4, 3, 13, 31), + datetime(2022, 4, 3, 13, 31), + datetime(2022, 4, 1), + None, + ], + "every_lit": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20), + datetime(2022, 4, 4), + None, + datetime(2022, 12, 1), + ], + } + ) + + assert_frame_equal(output, expected) + + all_lit = pl.select(all_lit=pl.lit(datetime(2022, 3, 20, 5, 7)).dt.round("1h")) + assert all_lit.to_dict(as_series=False) == {"all_lit": [datetime(2022, 3, 20, 5)]} + + +def test_round_negative() -> None: + """Test that rounding to a negative duration gives a helpful error message.""" + with pytest.raises( + ComputeError, match="cannot round a Date to a negative duration" + ): + pl.Series([date(1895, 5, 7)]).dt.round("-1m") + + with pytest.raises( + ComputeError, match="cannot round a Datetime to a negative duration" + ): + pl.Series([datetime(1895, 5, 7)]).dt.round("-1m") + + +def test_round_invalid_duration() -> None: + with pytest.raises( + InvalidOperationError, match="expected leading integer in the duration string" + ): + pl.Series([datetime(2022, 3, 20, 5, 7)]).dt.round("P") + + +@pytest.mark.parametrize( + ("time_unit", "date_in_that_unit"), + [ + ("ns", [978307200000000000, 981022089000000000]), + ("us", [978307200000000, 981022089000000]), + ("ms", [978307200000, 981022089000]), + ], + ids=["nanoseconds", "microseconds", "milliseconds"], +) +def test_cast_time_units( + time_unit: TimeUnit, + date_in_that_unit: list[int], +) -> None: + dates = pl.Series([datetime(2001, 1, 1), datetime(2001, 2, 1, 10, 8, 9)]) + + assert dates.dt.cast_time_unit(time_unit).cast(int).to_list() == date_in_that_unit + + +def test_epoch_matches_timestamp() -> None: + dates = pl.Series([datetime(2001, 1, 1), datetime(2001, 2, 1, 10, 8, 9)]) + + for unit in DTYPE_TEMPORAL_UNITS: + assert_series_equal(dates.dt.epoch(unit), dates.dt.timestamp(unit)) + + assert_series_equal(dates.dt.epoch("s"), dates.dt.timestamp("ms") // 1000) + assert_series_equal( + dates.dt.epoch("d"), + (dates.dt.timestamp("ms") // (1000 * 3600 * 24)).cast(pl.Int32), + ) + + +@pytest.mark.parametrize( + ("tzinfo", "time_zone"), + [(None, None), (ZoneInfo("Asia/Kathmandu"), "Asia/Kathmandu")], +) +def test_date_time_combine(tzinfo: ZoneInfo | None, time_zone: str | None) -> None: + # Define a DataFrame with columns for datetime, date, and time + df = pl.DataFrame( + { + "dtm": [ + datetime(2022, 12, 31, 10, 30, 45), + datetime(2023, 7, 5, 23, 59, 59), + ], + "dt": [ + date(2022, 10, 10), + date(2022, 7, 5), + ], + "tm": [ + time(1, 2, 3, 456000), + time(7, 8, 9, 101000), + ], + } + ) + df = df.with_columns(pl.col("dtm").dt.replace_time_zone(time_zone)) + + # Combine datetime/date with time + df = df.select( + pl.col("dtm").dt.combine(pl.col("tm")).alias("d1"), # datetime & time + pl.col("dt").dt.combine(pl.col("tm")).alias("d2"), # date & time + pl.col("dt").dt.combine(time(4, 5, 6)).alias("d3"), # date & specified time + ) + + # Assert that the new columns have the expected values and datatypes + expected_dict = { + "d1": [ # Time component should be overwritten by `tm` values + datetime(2022, 12, 31, 1, 2, 3, 456000, tzinfo=tzinfo), + datetime(2023, 7, 5, 7, 8, 9, 101000, tzinfo=tzinfo), + ], + "d2": [ # Both date and time components combined "as-is" into new datetime + datetime(2022, 10, 10, 1, 2, 3, 456000), + datetime(2022, 7, 5, 7, 8, 9, 101000), + ], + "d3": [ # New datetime should use specified time component + datetime(2022, 10, 10, 4, 5, 6), + datetime(2022, 7, 5, 4, 5, 6), + ], + } + assert df.to_dict(as_series=False) == expected_dict + + expected_schema = { + "d1": pl.Datetime("us", time_zone), + "d2": pl.Datetime("us"), + "d3": pl.Datetime("us"), + } + assert df.schema == expected_schema + + +def test_combine_unsupported_types() -> None: + with pytest.raises(ComputeError, match="expected Date or Datetime, got time"): + pl.Series([time(1, 2)]).dt.combine(time(3, 4)) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +@pytest.mark.parametrize("time_zone", ["Asia/Kathmandu", None]) +def test_combine_lazy_schema_datetime( + time_zone: str | None, + time_unit: TimeUnit, +) -> None: + df = pl.DataFrame({"ts": pl.Series([datetime(2020, 1, 1)])}) + df = df.with_columns(pl.col("ts").dt.replace_time_zone(time_zone)) + result = df.lazy().select( + pl.col("ts").dt.combine(time(1, 2, 3), time_unit=time_unit) + ) + expected_dtypes = [pl.Datetime(time_unit, time_zone)] + assert result.collect_schema().dtypes() == expected_dtypes + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_combine_lazy_schema_date(time_unit: TimeUnit) -> None: + df = pl.DataFrame({"ts": pl.Series([date(2020, 1, 1)])}) + result = df.lazy().select( + pl.col("ts").dt.combine(time(1, 2, 3), time_unit=time_unit) + ) + expected_dtypes = [pl.Datetime(time_unit, None)] + assert result.collect_schema().dtypes() == expected_dtypes + + +def test_is_leap_year() -> None: + assert pl.datetime_range( + datetime(1990, 1, 1), datetime(2004, 1, 1), "1y", eager=True + ).dt.is_leap_year().to_list() == [ + False, + False, + True, # 1992 + False, + False, + False, + True, # 1996 + False, + False, + False, + True, # 2000 + False, + False, + False, + True, # 2004 + ] + + +def test_quarter() -> None: + assert pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 12, 1), "1mo", eager=True + ).dt.quarter().to_list() == [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4] + + +def test_offset_by() -> None: + df = pl.DataFrame( + { + "dates": pl.datetime_range( + datetime(2000, 1, 1), datetime(2020, 1, 1), "1y", eager=True + ) + } + ) + + # Add two new columns to the DataFrame using the offset_by() method + df = df.with_columns( + df["dates"].dt.offset_by("1y").alias("date_plus_1y"), + df["dates"].dt.offset_by("-1y2mo").alias("date_min"), + ) + + # Assert that the day of the month for all the dates in new columns is 1 + assert (df["date_plus_1y"].dt.day() == 1).all() + assert (df["date_min"].dt.day() == 1).all() + + # Assert that the 'date_min' column contains the expected list of dates + expected_dates = [datetime(year, 11, 1, 0, 0) for year in range(1998, 2019)] + assert df["date_min"].to_list() == expected_dates + + +@pytest.mark.parametrize("time_zone", ["US/Central", None]) +def test_offset_by_crossing_dst(time_zone: str | None) -> None: + ser = pl.Series([datetime(2021, 11, 7)]).dt.replace_time_zone(time_zone) + result = ser.dt.offset_by("1d") + expected = pl.Series([datetime(2021, 11, 8)]).dt.replace_time_zone(time_zone) + assert_series_equal(result, expected) + + +def test_negative_offset_by_err_msg_8464() -> None: + result = pl.Series([datetime(2022, 3, 30)]).dt.offset_by("-1mo") + expected = pl.Series([datetime(2022, 2, 28)]) + assert_series_equal(result, expected) + + +def test_offset_by_truncate_sorted_flag() -> None: + s = pl.Series([datetime(2001, 1, 1), datetime(2001, 1, 2)]) + s = s.set_sorted() + + assert s.flags["SORTED_ASC"] + s1 = s.dt.offset_by("1d") + assert s1.to_list() == [datetime(2001, 1, 2), datetime(2001, 1, 3)] + assert s1.flags["SORTED_ASC"] + s2 = s1.dt.truncate("1mo") + assert s2.flags["SORTED_ASC"] + + +def test_offset_by_broadcasting() -> None: + # test broadcast lhs + df = pl.DataFrame( + { + "offset": ["1d", "10d", "3d", None], + } + ) + result = df.select( + d1=pl.lit(datetime(2020, 10, 25)).dt.offset_by(pl.col("offset")), + d2=pl.lit(datetime(2020, 10, 25)) + .dt.cast_time_unit("ms") + .dt.offset_by(pl.col("offset")), + d3=pl.lit(datetime(2020, 10, 25)) + .dt.replace_time_zone("Europe/London") + .dt.offset_by(pl.col("offset")), + d4=pl.lit(datetime(2020, 10, 25)).dt.date().dt.offset_by(pl.col("offset")), + d5=pl.lit(None, dtype=pl.Datetime).dt.offset_by(pl.col("offset")), + ) + expected_dict = { + "d1": [ + datetime(2020, 10, 26), + datetime(2020, 11, 4), + datetime(2020, 10, 28), + None, + ], + "d2": [ + datetime(2020, 10, 26), + datetime(2020, 11, 4), + datetime(2020, 10, 28), + None, + ], + "d3": [ + datetime(2020, 10, 26, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 11, 4, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 28, tzinfo=ZoneInfo("Europe/London")), + None, + ], + "d4": [ + datetime(2020, 10, 26).date(), + datetime(2020, 11, 4).date(), + datetime(2020, 10, 28).date(), + None, + ], + "d5": [None, None, None, None], + } + assert result.to_dict(as_series=False) == expected_dict + + # test broadcast rhs + df = pl.DataFrame({"dt": [datetime(2020, 10, 25), datetime(2021, 1, 2), None]}) + result = df.select( + d1=pl.col("dt").dt.offset_by(pl.lit("1mo3d")), + d2=pl.col("dt").dt.cast_time_unit("ms").dt.offset_by(pl.lit("1y1mo")), + d3=pl.col("dt") + .dt.replace_time_zone("Europe/London") + .dt.offset_by(pl.lit("3d")), + d4=pl.col("dt").dt.date().dt.offset_by(pl.lit("1y1mo1d")), + ) + expected_dict = { + "d1": [datetime(2020, 11, 28), datetime(2021, 2, 5), None], + "d2": [datetime(2021, 11, 25), datetime(2022, 2, 2), None], + "d3": [ + datetime(2020, 10, 28, tzinfo=ZoneInfo("Europe/London")), + datetime(2021, 1, 5, tzinfo=ZoneInfo("Europe/London")), + None, + ], + "d4": [datetime(2021, 11, 26).date(), datetime(2022, 2, 3).date(), None], + } + assert result.to_dict(as_series=False) == expected_dict + + # test all literal + result = df.select(d=pl.lit(datetime(2021, 11, 26)).dt.offset_by("1mo1d")) + assert result.to_dict(as_series=False) == {"d": [datetime(2021, 12, 27)]} + + +def test_offset_by_expressions() -> None: + df = pl.DataFrame( + { + "a": [ + datetime(2020, 10, 25), + datetime(2021, 1, 2), + None, + datetime(2021, 1, 4), + None, + ], + "b": ["1d", "10d", "3d", None, None], + } + ) + df = df.sort("a") + result = df.select( + c=pl.col("a").dt.offset_by(pl.col("b")), + d=pl.col("a").dt.cast_time_unit("ms").dt.offset_by(pl.col("b")), + e=pl.col("a").dt.replace_time_zone("Europe/London").dt.offset_by(pl.col("b")), + f=pl.col("a").dt.date().dt.offset_by(pl.col("b")), + ) + + expected = pl.DataFrame( + { + "c": [None, None, datetime(2020, 10, 26), datetime(2021, 1, 12), None], + "d": [None, None, datetime(2020, 10, 26), datetime(2021, 1, 12), None], + "e": [ + None, + None, + datetime(2020, 10, 26, tzinfo=ZoneInfo("Europe/London")), + datetime(2021, 1, 12, tzinfo=ZoneInfo("Europe/London")), + None, + ], + "f": [None, None, date(2020, 10, 26), date(2021, 1, 12), None], + }, + schema_overrides={ + "d": pl.Datetime("ms"), + "e": pl.Datetime(time_zone="Europe/London"), + }, + ) + assert_frame_equal(result, expected) + assert result.flags == { + "c": {"SORTED_ASC": False, "SORTED_DESC": False}, + "d": {"SORTED_ASC": False, "SORTED_DESC": False}, + "e": {"SORTED_ASC": False, "SORTED_DESC": False}, + "f": {"SORTED_ASC": False, "SORTED_DESC": False}, + } + + # Check single-row cases + for i in range(df.height): + df_slice = df[i : i + 1] + result = df_slice.select( + c=pl.col("a").dt.offset_by(pl.col("b")), + d=pl.col("a").dt.cast_time_unit("ms").dt.offset_by(pl.col("b")), + e=pl.col("a") + .dt.replace_time_zone("Europe/London") + .dt.offset_by(pl.col("b")), + f=pl.col("a").dt.date().dt.offset_by(pl.col("b")), + ) + assert_frame_equal(result, expected[i : i + 1]) + # single-row Series are always sorted + assert result.flags == { + "c": {"SORTED_ASC": True, "SORTED_DESC": False}, + "d": {"SORTED_ASC": True, "SORTED_DESC": False}, + "e": {"SORTED_ASC": True, "SORTED_DESC": False}, + "f": {"SORTED_ASC": True, "SORTED_DESC": False}, + } + + +@pytest.mark.parametrize( + ("duration", "input_date", "expected"), + [ + ("1mo", date(2018, 1, 31), date(2018, 2, 28)), + ("1y", date(2024, 2, 29), date(2025, 2, 28)), + ("1y1mo", date(2024, 1, 30), date(2025, 2, 28)), + ], +) +def test_offset_by_saturating_8217_8474( + duration: str, input_date: date, expected: date +) -> None: + result = pl.Series([input_date]).dt.offset_by(duration).item() + assert result == expected + + +def test_year_empty_df() -> None: + df = pl.DataFrame(pl.Series(name="date", dtype=pl.Date)) + assert df.select(pl.col("date").dt.year()).dtypes == [pl.Int32] + + +def test_epoch_invalid() -> None: + with pytest.raises(InvalidOperationError, match="not supported for dtype"): + pl.Series([timedelta(1)]).dt.epoch() + + +@pytest.mark.parametrize( + "time_unit", + ["ms", "us", "ns"], + ids=["milliseconds", "microseconds", "nanoseconds"], +) +def test_weekday(time_unit: TimeUnit) -> None: + friday = pl.Series([datetime(2023, 2, 17)]) + + assert friday.dt.cast_time_unit(time_unit).dt.weekday()[0] == 5 + assert friday.cast(pl.Date).dt.weekday()[0] == 5 + + +@pytest.mark.parametrize( + ("values", "expected_median"), + [ + ([], None), + ([None, None], None), + ([date(2022, 1, 1)], datetime(2022, 1, 1)), + ([date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 4)], datetime(2022, 1, 2)), + ([date(2022, 1, 1), date(2022, 1, 2), date(2024, 5, 15)], datetime(2022, 1, 2)), + ([datetime(2022, 1, 1)], datetime(2022, 1, 1)), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)], + datetime(2022, 1, 2), + ), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)], + datetime(2022, 1, 2), + ), + ([timedelta(days=1)], timedelta(days=1)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=3)], timedelta(days=2)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=2)), + ([time(hour=1)], time(hour=1)), + ([time(hour=1), time(hour=2), time(hour=3)], time(hour=2)), + ([time(hour=1), time(hour=2), time(hour=15)], time(hour=2)), + ], + ids=[ + "empty", + "Nones", + "single_date", + "spread_even_date", + "spread_skewed_date", + "single_datetime", + "spread_even_datetime", + "spread_skewed_datetime", + "single_dur", + "spread_even_dur", + "spread_skewed_dur", + "single_time", + "spread_even_time", + "spread_skewed_time", + ], +) +def test_median( + values: list[TemporalLiteral | None], expected_median: TemporalLiteral | None +) -> None: + assert pl.Series(values).median() == expected_median + + +@pytest.mark.parametrize( + ("values", "expected_mean"), + [ + ([], None), + ([None, None], None), + ([date(2022, 1, 1)], datetime(2022, 1, 1)), + ( + [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 4)], + datetime(2022, 1, 2, 8), + ), + ( + [date(2022, 1, 1), date(2022, 1, 2), date(2024, 5, 15)], + datetime(2022, 10, 16, 16, 0), + ), + ([datetime(2022, 1, 1)], datetime(2022, 1, 1)), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)], + datetime(2022, 1, 2), + ), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)], + datetime(2022, 10, 16, 16, 0, 0), + ), + ([timedelta(days=1)], timedelta(days=1)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=3)], timedelta(days=2)), + ([timedelta(days=1), timedelta(days=2), timedelta(days=15)], timedelta(days=6)), + ([time(hour=1)], time(hour=1)), + ([time(hour=1), time(hour=2), time(hour=3)], time(hour=2)), + ([time(hour=1), time(hour=2), time(hour=15)], time(hour=6)), + ], + ids=[ + "empty", + "Nones", + "single_date", + "spread_even_date", + "spread_skewed_date", + "single_datetime", + "spread_even_datetime", + "spread_skewed_datetime", + "single_duration", + "spread_even_duration", + "spread_skewed_duration", + "single_time", + "spread_even_time", + "spread_skewed_time", + ], +) +def test_mean( + values: list[TemporalLiteral | None], expected_mean: TemporalLiteral | None +) -> None: + assert pl.Series(values).mean() == expected_mean + + +@pytest.mark.parametrize( + ("values", "expected_mean"), + [ + ([None], None), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)], + datetime(2022, 10, 16, 16, 0, 0), + ), + ], + ids=["None_dt", "spread_skewed_dt"], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_mean_with_tu( + values: list[datetime], expected_mean: datetime, time_unit: TimeUnit +) -> None: + assert pl.Series(values, dtype=pl.Duration(time_unit)).mean() == expected_mean + + +@pytest.mark.parametrize( + ("values", "expected_median"), + [ + ([None], None), + ( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2024, 5, 15)], + datetime(2022, 1, 2), + ), + ], + ids=["None_dt", "spread_skewed_dt"], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_median_with_tu( + values: list[datetime], expected_median: datetime, time_unit: TimeUnit +) -> None: + assert pl.Series(values, dtype=pl.Duration(time_unit)).median() == expected_median + + +@pytest.mark.parametrize( + ("values", "expected_mean"), + [ + ([None], None), + ( + [timedelta(days=1), timedelta(days=2), timedelta(days=15)], + timedelta(days=6), + ), + ], + ids=["None_dur", "spread_skewed_dur"], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_duration_mean_with_tu( + values: list[timedelta], expected_mean: timedelta, time_unit: TimeUnit +) -> None: + assert pl.Series(values, dtype=pl.Duration(time_unit)).mean() == expected_mean + + +@pytest.mark.parametrize( + ("values", "expected_median"), + [ + ([None], None), + ( + [timedelta(days=1), timedelta(days=2), timedelta(days=15)], + timedelta(days=2), + ), + ], + ids=["None_dur", "spread_skewed_dur"], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_duration_median_with_tu( + values: list[timedelta], expected_median: timedelta, time_unit: TimeUnit +) -> None: + assert pl.Series(values, dtype=pl.Duration(time_unit)).median() == expected_median + + +def test_agg_mean_expr() -> None: + df = pl.DataFrame( + { + "date": pl.Series( + [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4)], + dtype=pl.Date, + ), + "datetime_ms": pl.Series( + [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)], + dtype=pl.Datetime("ms"), + ), + "datetime_us": pl.Series( + [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)], + dtype=pl.Datetime("us"), + ), + "datetime_ns": pl.Series( + [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)], + dtype=pl.Datetime("ns"), + ), + "duration_ms": pl.Series( + [timedelta(days=1), timedelta(days=2), timedelta(days=4)], + dtype=pl.Duration("ms"), + ), + "duration_us": pl.Series( + [timedelta(days=1), timedelta(days=2), timedelta(days=4)], + dtype=pl.Duration("us"), + ), + "duration_ns": pl.Series( + [timedelta(days=1), timedelta(days=2), timedelta(days=4)], + dtype=pl.Duration("ns"), + ), + "time": pl.Series( + [time(hour=1), time(hour=2), time(hour=4)], + dtype=pl.Time, + ), + } + ) + + expected = pl.DataFrame( + { + "date": pl.Series([datetime(2023, 1, 2, 8, 0)], dtype=pl.Datetime("ms")), + "datetime_ms": pl.Series( + [datetime(2023, 1, 2, 8, 0, 0)], dtype=pl.Datetime("ms") + ), + "datetime_us": pl.Series( + [datetime(2023, 1, 2, 8, 0, 0)], dtype=pl.Datetime("us") + ), + "datetime_ns": pl.Series( + [datetime(2023, 1, 2, 8, 0, 0)], dtype=pl.Datetime("ns") + ), + "duration_ms": pl.Series( + [timedelta(days=2, hours=8)], dtype=pl.Duration("ms") + ), + "duration_us": pl.Series( + [timedelta(days=2, hours=8)], dtype=pl.Duration("us") + ), + "duration_ns": pl.Series( + [timedelta(days=2, hours=8)], dtype=pl.Duration("ns") + ), + "time": pl.Series([time(hour=2, minute=20)], dtype=pl.Time), + } + ) + + assert_frame_equal(df.select(pl.all().mean()), expected) + + +def test_agg_median_expr() -> None: + df = pl.DataFrame( + { + "date": pl.Series( + [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4)], + dtype=pl.Date, + ), + "datetime_ms": pl.Series( + [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)], + dtype=pl.Datetime("ms"), + ), + "datetime_us": pl.Series( + [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)], + dtype=pl.Datetime("us"), + ), + "datetime_ns": pl.Series( + [datetime(2023, 1, 1), datetime(2023, 1, 2), datetime(2023, 1, 4)], + dtype=pl.Datetime("ns"), + ), + "duration_ms": pl.Series( + [timedelta(days=1), timedelta(days=2), timedelta(days=4)], + dtype=pl.Duration("ms"), + ), + "duration_us": pl.Series( + [timedelta(days=1), timedelta(days=2), timedelta(days=4)], + dtype=pl.Duration("us"), + ), + "duration_ns": pl.Series( + [timedelta(days=1), timedelta(days=2), timedelta(days=4)], + dtype=pl.Duration("ns"), + ), + "time": pl.Series( + [time(hour=1), time(hour=2), time(hour=4)], + dtype=pl.Time, + ), + } + ) + + expected = pl.DataFrame( + { + "date": pl.Series([datetime(2023, 1, 2)], dtype=pl.Datetime("ms")), + "datetime_ms": pl.Series([datetime(2023, 1, 2)], dtype=pl.Datetime("ms")), + "datetime_us": pl.Series([datetime(2023, 1, 2)], dtype=pl.Datetime("us")), + "datetime_ns": pl.Series([datetime(2023, 1, 2)], dtype=pl.Datetime("ns")), + "duration_ms": pl.Series([timedelta(days=2)], dtype=pl.Duration("ms")), + "duration_us": pl.Series([timedelta(days=2)], dtype=pl.Duration("us")), + "duration_ns": pl.Series([timedelta(days=2)], dtype=pl.Duration("ns")), + "time": pl.Series([time(hour=2)], dtype=pl.Time), + } + ) + + assert_frame_equal(df.select(pl.all().median()), expected) + + +@given( + s=series(min_size=1, max_size=10, dtype=pl.Duration), +) +@pytest.mark.skip( + "These functions are currently bugged for large values: " + "https://github.com/pola-rs/polars/issues/16057" +) +def test_series_duration_timeunits( + s: pl.Series, +) -> None: + nanos = s.dt.total_nanoseconds().to_list() + micros = s.dt.total_microseconds().to_list() + millis = s.dt.total_milliseconds().to_list() + + scale = { + "ns": 1, + "us": 1_000, + "ms": 1_000_000, + } + assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[attr-defined] + assert micros == [int(v / 1_000) for v in nanos] + assert millis == [int(v / 1_000) for v in micros] + + # special handling for ns timeunit (as we may generate a microsecs-based + # timedelta that results in 64bit overflow on conversion to nanosecs) + lower_bound, upper_bound = -(2**63), (2**63) - 1 + if all( + (lower_bound <= (us * 1000) <= upper_bound) + for us in micros + if isinstance(us, int) + ): + for ns, us in zip(s.dt.total_nanoseconds(), micros): + assert ns == (us * 1000) + + +@given( + s=series(min_size=1, max_size=10, dtype=pl.Datetime, allow_null=False), +) +def test_series_datetime_timeunits( + s: pl.Series, +) -> None: + # datetime + assert s.to_list() == list(s) + assert list(s.dt.millisecond()) == [v.microsecond // 1000 for v in s] + assert list(s.dt.nanosecond()) == [v.microsecond * 1000 for v in s] + assert list(s.dt.microsecond()) == [v.microsecond for v in s] + + +def test_dt_median_deprecated() -> None: + values = [date(2022, 1, 1), date(2022, 1, 2), date(2024, 5, 15)] + s = pl.Series(values) + with pytest.deprecated_call(): + result = s.dt.median() + assert result == s.median() + + +def test_dt_mean_deprecated() -> None: + values = [date(2022, 1, 1), date(2022, 1, 2), date(2024, 5, 15)] + s = pl.Series(values) + with pytest.deprecated_call(): + result = s.dt.mean() + assert result == s.mean() + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Date, + pl.Datetime("ms"), + pl.Datetime("ms", "EST"), + pl.Datetime("us"), + pl.Datetime("us", "EST"), + pl.Datetime("ns"), + pl.Datetime("ns", "EST"), + ], +) +@pytest.mark.parametrize( + "value", + [ + # date(1677, 9, 22), # See test_literal_from_datetime. + date(1970, 1, 1), + date(2024, 2, 29), + date(2262, 4, 11), + ], +) +def test_literal_from_date( + value: date, + dtype: PolarsDataType, +) -> None: + out = pl.select(pl.lit(value, dtype=dtype)) + assert out.schema == OrderedDict({"literal": dtype}) + if dtype == pl.Datetime: + tz = ZoneInfo(dtype.time_zone) if dtype.time_zone is not None else None # type: ignore[union-attr] + value = datetime(value.year, value.month, value.day, tzinfo=tz) + assert out.item() == value + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Date, + pl.Datetime("ms"), + pl.Datetime("ms", "EST"), + pl.Datetime("us"), + pl.Datetime("us", "EST"), + pl.Datetime("ns"), + pl.Datetime("ns", "EST"), + ], +) +@pytest.mark.parametrize( + "value", + [ + # Very old dates with a timezone like EST caused problems for the CI due + # to the IANA timezone database updating their historical offset, so + # these have been disabled for now. A mismatch between the timezone + # database that chrono_tz crate uses vs. the one that Python uses (which + # differs from platform to platform) will cause this to fail. + # datetime(1677, 9, 22), + # datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")), + datetime(1970, 1, 1), + datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")), + datetime(2024, 2, 29), + datetime(2024, 2, 29, tzinfo=ZoneInfo("EST")), + datetime(2262, 4, 11), + datetime(2262, 4, 11, tzinfo=ZoneInfo("EST")), + ], +) +def test_literal_from_datetime( + value: datetime, + dtype: pl.Date | pl.Datetime, +) -> None: + out = pl.select(pl.lit(value, dtype=dtype)) + if dtype == pl.Date: + value = value.date() # type: ignore[assignment] + elif dtype.time_zone is None and value.tzinfo is not None: # type: ignore[union-attr] + # update the dtype with the supplied time zone in the value + dtype = pl.Datetime(dtype.time_unit, str(value.tzinfo)) # type: ignore[union-attr] + elif dtype.time_zone is not None and value.tzinfo is None: # type: ignore[union-attr] + # cast from dt without tz to dtype with tz + value = value.replace(tzinfo=ZoneInfo(dtype.time_zone)) # type: ignore[union-attr] + + assert out.schema == OrderedDict({"literal": dtype}) + assert out.item() == value + + +@pytest.mark.parametrize( + "value", + [ + time(0), + time(hour=1), + time(hour=16, minute=43, microsecond=500), + time(hour=23, minute=59, second=59, microsecond=999999), + ], +) +def test_literal_from_time(value: time) -> None: + out = pl.select(pl.lit(value)) + assert out.schema == OrderedDict({"literal": pl.Time}) + assert out.item() == value + + +@pytest.mark.parametrize( + "dtype", + [ + None, + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + ], +) +@pytest.mark.parametrize( + "value", + [ + timedelta(0), + timedelta(hours=1), + timedelta(days=-99999), + timedelta(days=99999), + ], +) +def test_literal_from_timedelta(value: time, dtype: pl.Duration | None) -> None: + out = pl.select(pl.lit(value, dtype=dtype)) + assert out.schema == OrderedDict({"literal": dtype or pl.Duration("us")}) + assert out.item() == value diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_is_business_day.py b/py-polars/tests/unit/operations/namespaces/temporal/test_is_business_day.py new file mode 100644 index 000000000000..9dd90a2cd8e9 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_is_business_day.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import date, datetime + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_series_equal + + +@pytest.mark.parametrize( + ("holidays", "week_mask", "expected_values"), + [ + ((), (True, True, True, True, True, False, False), [True, True]), + ( + (date(2020, 1, 1),), + (True, True, True, True, True, False, False), + [False, True], + ), + ( + (date(2020, 1, 2),), + (True, True, True, True, True, False, False), + [True, False], + ), + ( + (date(2020, 1, 2), date(2020, 1, 1)), + (True, True, True, True, True, False, False), + [False, False], + ), + ((), (True, True, False, True, True, True, True), [False, True]), + ((), (True, True, False, False, True, True, True), [False, False]), + ((), (True, True, True, False, True, True, True), [True, False]), + ( + (), + (True, True, True, True, True, True, True), + [True, True], + ), # no holidays :sob: + ( + (date(2020, 1, 1),), + (True, True, True, True, True, True, True), + [False, True], + ), + ( + (date(2020, 1, 1),), + (True, True, False, True, True, True, True), + [False, True], + ), + ], +) +def test_is_business_day( + holidays: tuple[date, ...], week_mask: tuple[bool, ...], expected_values: list[bool] +) -> None: + # Date + df = pl.DataFrame({"date": [date(2020, 1, 1), date(2020, 1, 2)]}) + result = df.select( + pl.col("date").dt.is_business_day(holidays=holidays, week_mask=week_mask) + )["date"] + expected = pl.Series("date", expected_values) + assert_series_equal(result, expected) + # Datetime + df = pl.DataFrame({"date": [datetime(2020, 1, 1, 3), datetime(2020, 1, 2, 1)]}) + result = df.select( + pl.col("date").dt.is_business_day(holidays=holidays, week_mask=week_mask) + )["date"] + assert_series_equal(result, expected) + # Datetime tz-aware + df = pl.DataFrame( + {"date": [datetime(2020, 1, 1), datetime(2020, 1, 2)]} + ).with_columns(pl.col("date").dt.replace_time_zone("Asia/Kathmandu")) + result = df.select( + pl.col("date").dt.is_business_day(holidays=holidays, week_mask=week_mask) + )["date"] + assert_series_equal(result, expected) + + +def test_is_business_day_invalid() -> None: + df = pl.DataFrame({"date": [date(2020, 1, 1), date(2020, 1, 2)]}) + with pytest.raises(ComputeError): + df.select(pl.col("date").dt.is_business_day(week_mask=[False] * 7)) + + +def test_is_business_day_repr() -> None: + assert "is_business_day" in repr(pl.col("date").dt.is_business_day()) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_offset_by.py b/py-polars/tests/unit/operations/namespaces/temporal/test_offset_by.py new file mode 100644 index 000000000000..4740a74b0a7e --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_offset_by.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import TimeUnit + + +@pytest.mark.parametrize( + ("inputs", "offset", "outputs"), + [ + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "1d", + [date(2020, 1, 2), date(2020, 1, 3)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "-1d", + [date(2019, 12, 31), date(2020, 1, 1)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "3d", + [date(2020, 1, 4), date(2020, 1, 5)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "72h", + [date(2020, 1, 4), date(2020, 1, 5)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "2d24h", + [date(2020, 1, 4), date(2020, 1, 5)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "-2mo", + [date(2019, 11, 1), date(2019, 11, 2)], + ), + ], +) +def test_date_offset_by(inputs: list[date], offset: str, outputs: list[date]) -> None: + result = pl.Series(inputs).dt.offset_by(offset) + expected = pl.Series(outputs) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("inputs", "offset", "outputs"), + [ + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "1d", + [date(2020, 1, 2), date(2020, 1, 3)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "-1d", + [date(2019, 12, 31), date(2020, 1, 1)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "3d", + [date(2020, 1, 4), date(2020, 1, 5)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "72h", + [date(2020, 1, 4), date(2020, 1, 5)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "2d24h", + [date(2020, 1, 4), date(2020, 1, 5)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "7m", + [datetime(2020, 1, 1, 0, 7), datetime(2020, 1, 2, 0, 7)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "-3m", + [datetime(2019, 12, 31, 23, 57), datetime(2020, 1, 1, 23, 57)], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2)], + "2mo", + [datetime(2020, 3, 1), datetime(2020, 3, 2)], + ), + ], +) +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +@pytest.mark.parametrize("time_zone", ["Europe/London", "Asia/Kathmandu", None]) +def test_datetime_offset_by( + inputs: list[date], + offset: str, + outputs: list[datetime], + time_unit: TimeUnit, + time_zone: str | None, +) -> None: + result = ( + pl.Series(inputs, dtype=pl.Datetime(time_unit)) + .dt.replace_time_zone(time_zone) + .dt.offset_by(offset) + ) + expected = pl.Series(outputs, dtype=pl.Datetime(time_unit)).dt.replace_time_zone( + time_zone + ) + assert_series_equal(result, expected) + + +def test_offset_by_unique_29_feb_19608() -> None: + df20 = pl.select( + t=pl.datetime_range( + pl.datetime(2020, 2, 28), + pl.datetime(2020, 3, 1), + closed="left", + time_unit="ms", + interval="8h", + time_zone="UTC", + ), + ).with_columns(x=pl.int_range(pl.len())) + df19 = df20.with_columns(pl.col("t").dt.offset_by("-1y")) + result = df19.unique("t", keep="first").sort("t") + expected = pl.DataFrame( + { + "t": [ + datetime(2019, 2, 28), + datetime(2019, 2, 28, 8), + datetime(2019, 2, 28, 16), + ], + "x": [0, 1, 2], + }, + schema_overrides={"t": pl.Datetime("ms", "UTC")}, + ) + assert_frame_equal(result, expected) + + +def test_month_then_day_21283() -> None: + series_vienna = pl.Series( + [datetime(2024, 5, 15, 8, 0)], dtype=pl.Datetime(time_zone="Europe/Vienna") + ) + result = series_vienna.dt.offset_by("2y1mo1q1h")[0] + expected = datetime.strptime("2026-09-15 11:00:00+02:00", "%Y-%m-%d %H:%M:%S%z") + assert result == expected + result = series_vienna.dt.offset_by("2y1mo1q1h1d")[0] + expected = datetime.strptime("2026-09-16 11:00:00+02:00", "%Y-%m-%d %H:%M:%S%z") + assert result == expected + series_utc = pl.Series( + [datetime(2024, 5, 15, 8, 0)], dtype=pl.Datetime(time_zone="UTC") + ) + result = series_utc.dt.offset_by("2y1mo1q1h")[0] + expected = datetime.strptime("2026-09-15 09:00:00+00:00", "%Y-%m-%d %H:%M:%S%z") + assert result == expected + result = series_utc.dt.offset_by("2y1mo1q1h1d")[0] + expected = datetime.strptime("2026-09-16 09:00:00+00:00", "%Y-%m-%d %H:%M:%S%z") + assert result == expected + + +def test_offset_by_unequal_length_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([datetime(2088, 8, 8, 8, 8, 8, 8)] * 2).dt.offset_by( + pl.Series([f"{h}y" for h in range(3)]) + ) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_replace.py b/py-polars/tests/unit/operations/namespaces/temporal/test_replace.py new file mode 100644 index 000000000000..31a26d7ca64a --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_replace.py @@ -0,0 +1,405 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import TimeUnit + + +def test_replace_expr_datetime() -> None: + df = pl.DataFrame( + { + "dates": [ + datetime(2088, 8, 8, 8, 8, 8, 8), + datetime(2088, 8, 8, 8, 8, 8, 8), + datetime(2088, 8, 8, 8, 8, 8, 8), + datetime(2088, 8, 8, 8, 8, 8, 8), + datetime(2088, 8, 8, 8, 8, 8, 8), + datetime(2088, 8, 8, 8, 8, 8, 8), + datetime(2088, 8, 8, 8, 8, 8, 8), + None, + ], + "year": [None, 2, 3, 4, 5, 6, 7, 8], + "month": [1, None, 3, 4, 5, 6, 7, 8], + "day": [1, 2, None, 4, 5, 6, 7, 8], + "hour": [1, 2, 3, None, 5, 6, 7, 8], + "minute": [1, 2, 3, 4, None, 6, 7, 8], + "second": [1, 2, 3, 4, 5, None, 7, 8], + "microsecond": [1, 2, 3, 4, 5, 6, None, 8], + } + ) + + result = df.select( + pl.col("dates").dt.replace( + year="year", + month="month", + day="day", + hour="hour", + minute="minute", + second="second", + microsecond="microsecond", + ) + ) + + expected = pl.DataFrame( + { + "dates": [ + datetime(2088, 1, 1, 1, 1, 1, 1), + datetime(2, 8, 2, 2, 2, 2, 2), + datetime(3, 3, 8, 3, 3, 3, 3), + datetime(4, 4, 4, 8, 4, 4, 4), + datetime(5, 5, 5, 5, 8, 5, 5), + datetime(6, 6, 6, 6, 6, 8, 6), + datetime(7, 7, 7, 7, 7, 7, 8), + None, + ] + } + ) + + assert_frame_equal(result, expected) + + +def test_replace_expr_date() -> None: + df = pl.DataFrame( + { + "dates": [date(2088, 8, 8), date(2088, 8, 8), date(2088, 8, 8), None], + "year": [None, 2, 3, 4], + "month": [1, None, 3, 4], + "day": [1, 2, None, 4], + } + ) + + result = df.select( + pl.col("dates").dt.replace(year="year", month="month", day="day") + ) + + expected = pl.DataFrame( + {"dates": [date(2088, 1, 1), date(2, 8, 2), date(3, 3, 8), None]} + ) + + assert_frame_equal(result, expected) + + +def test_replace_int_datetime() -> None: + df = pl.DataFrame( + { + "a": [ + datetime(1, 1, 1, 1, 1, 1, 1), + datetime(2, 2, 2, 2, 2, 2, 2), + datetime(3, 3, 3, 3, 3, 3, 3), + None, + ] + } + ) + result = df.select( + pl.col("a").dt.replace().alias("no_change"), + pl.col("a").dt.replace(year=9).alias("year"), + pl.col("a").dt.replace(month=9).alias("month"), + pl.col("a").dt.replace(day=9).alias("day"), + pl.col("a").dt.replace(hour=9).alias("hour"), + pl.col("a").dt.replace(minute=9).alias("minute"), + pl.col("a").dt.replace(second=9).alias("second"), + pl.col("a").dt.replace(microsecond=9).alias("microsecond"), + ) + expected = pl.DataFrame( + { + "no_change": [ + datetime(1, 1, 1, 1, 1, 1, 1), + datetime(2, 2, 2, 2, 2, 2, 2), + datetime(3, 3, 3, 3, 3, 3, 3), + None, + ], + "year": [ + datetime(9, 1, 1, 1, 1, 1, 1), + datetime(9, 2, 2, 2, 2, 2, 2), + datetime(9, 3, 3, 3, 3, 3, 3), + None, + ], + "month": [ + datetime(1, 9, 1, 1, 1, 1, 1), + datetime(2, 9, 2, 2, 2, 2, 2), + datetime(3, 9, 3, 3, 3, 3, 3), + None, + ], + "day": [ + datetime(1, 1, 9, 1, 1, 1, 1), + datetime(2, 2, 9, 2, 2, 2, 2), + datetime(3, 3, 9, 3, 3, 3, 3), + None, + ], + "hour": [ + datetime(1, 1, 1, 9, 1, 1, 1), + datetime(2, 2, 2, 9, 2, 2, 2), + datetime(3, 3, 3, 9, 3, 3, 3), + None, + ], + "minute": [ + datetime(1, 1, 1, 1, 9, 1, 1), + datetime(2, 2, 2, 2, 9, 2, 2), + datetime(3, 3, 3, 3, 9, 3, 3), + None, + ], + "second": [ + datetime(1, 1, 1, 1, 1, 9, 1), + datetime(2, 2, 2, 2, 2, 9, 2), + datetime(3, 3, 3, 3, 3, 9, 3), + None, + ], + "microsecond": [ + datetime(1, 1, 1, 1, 1, 1, 9), + datetime(2, 2, 2, 2, 2, 2, 9), + datetime(3, 3, 3, 3, 3, 3, 9), + None, + ], + } + ) + assert_frame_equal(result, expected) + + +def test_replace_int_date() -> None: + df = pl.DataFrame( + { + "a": [ + date(1, 1, 1), + date(2, 2, 2), + date(3, 3, 3), + None, + ] + } + ) + result = df.select( + pl.col("a").dt.replace().alias("no_change"), + pl.col("a").dt.replace(year=9).alias("year"), + pl.col("a").dt.replace(month=9).alias("month"), + pl.col("a").dt.replace(day=9).alias("day"), + ) + expected = pl.DataFrame( + { + "no_change": [ + date(1, 1, 1), + date(2, 2, 2), + date(3, 3, 3), + None, + ], + "year": [ + date(9, 1, 1), + date(9, 2, 2), + date(9, 3, 3), + None, + ], + "month": [ + date(1, 9, 1), + date(2, 9, 2), + date(3, 9, 3), + None, + ], + "day": [ + date(1, 1, 9), + date(2, 2, 9), + date(3, 3, 9), + None, + ], + } + ) + assert_frame_equal(result, expected) + + +def test_replace_ambiguous() -> None: + # Value to be replaced by an ambiguous hour. + value = pl.select( + pl.datetime(2020, 10, 25, 5, time_zone="Europe/London") + ).to_series() + + input = [2020, 10, 25, 1] + tz = "Europe/London" + + # earliest + expected = pl.select( + pl.datetime(*input, time_zone=tz, ambiguous="earliest") + ).to_series() + result = value.dt.replace(hour=1, ambiguous="earliest") + assert_series_equal(result, expected) + + # latest + expected = pl.select( + pl.datetime(*input, time_zone=tz, ambiguous="latest") + ).to_series() + result = value.dt.replace(hour=1, ambiguous="latest") + assert_series_equal(result, expected) + + # null + expected = pl.select( + pl.datetime(*input, time_zone=tz, ambiguous="null") + ).to_series() + result = value.dt.replace(hour=1, ambiguous="null") + assert_series_equal(result, expected) + + # raise + with pytest.raises( + ComputeError, + match=( + "datetime '2020-10-25 01:00:00' is ambiguous in time zone 'Europe/London'. " + "Please use `ambiguous` to tell how it should be localized." + ), + ): + value.dt.replace(hour=1, ambiguous="raise") + + +def test_replace_datetime_preserve_ns() -> None: + df = pl.DataFrame( + { + "a": pl.Series(["2020-01-01T00:00:00.123456789"] * 2).cast( + pl.Datetime("ns") + ), + "year": [2021, None], + "microsecond": [50, None], + } + ) + + result = df.select( + year=pl.col("a").dt.replace(year="year"), + us=pl.col("a").dt.replace(microsecond="microsecond"), + ) + + expected = pl.DataFrame( + { + "year": pl.Series( + [ + "2021-01-01T00:00:00.123456789", + "2020-01-01T00:00:00.123456789", + ] + ).cast(pl.Datetime("ns")), + "us": pl.Series( + [ + "2020-01-01T00:00:00.000050", + "2020-01-01T00:00:00.123456789", + ] + ).cast(pl.Datetime("ns")), + } + ) + + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("tu", ["ms", "us", "ns"]) +@pytest.mark.parametrize("tzinfo", [None, "Africa/Nairobi", "America/New_York"]) +def test_replace_preserve_tu_and_tz(tu: TimeUnit, tzinfo: str) -> None: + s = pl.Series( + [datetime(2024, 1, 1), datetime(2024, 1, 2)], + dtype=pl.Datetime(time_unit=tu, time_zone=tzinfo), + ) + result = s.dt.replace(year=2000) + assert result.dtype.time_unit == tu # type: ignore[attr-defined] + assert result.dtype.time_zone == tzinfo # type: ignore[attr-defined] + + +def test_replace_date_invalid_components() -> None: + df = pl.DataFrame({"a": [date(2025, 1, 1)]}) + + with pytest.raises( + ComputeError, match=r"Invalid date components \(2025, 13, 1\) supplied" + ): + df.select(pl.col("a").dt.replace(month=13)) + with pytest.raises( + ComputeError, match=r"Invalid date components \(2025, 1, 32\) supplied" + ): + df.select(pl.col("a").dt.replace(day=32)) + + +def test_replace_datetime_invalid_date_components() -> None: + df = pl.DataFrame({"a": [datetime(2025, 1, 1)]}) + + with pytest.raises( + ComputeError, match=r"Invalid date components \(2025, 13, 1\) supplied" + ): + df.select(pl.col("a").dt.replace(month=13)) + with pytest.raises( + ComputeError, match=r"Invalid date components \(2025, 1, 32\) supplied" + ): + df.select(pl.col("a").dt.replace(day=32)) + + +def test_replace_datetime_invalid_time_components() -> None: + df = pl.DataFrame({"a": [datetime(2025, 1, 1)]}) + + # hour + with pytest.raises( + ComputeError, match=r"Invalid time components \(25, 0, 0, 0\) supplied" + ): + df.select(pl.col("a").dt.replace(hour=25)) + + # minute + with pytest.raises( + ComputeError, match=r"Invalid time components \(0, 61, 0, 0\) supplied" + ): + df.select(pl.col("a").dt.replace(minute=61)) + + # second + with pytest.raises( + ComputeError, match=r"Invalid time components \(0, 0, 61, 0\) supplied" + ): + df.select(pl.col("a").dt.replace(second=61)) + + # microsecond + with pytest.raises( + ComputeError, + match=r"Invalid time components \(0, 0, 0, 2000000000\) supplied", + ): + df.select(pl.col("a").dt.replace(microsecond=2_000_000)) + + +def test_replace_unequal_length_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([datetime(2088, 8, 8, 8, 8, 8, 8)] * 2).dt.replace( + year=pl.Series([2000, 2001, 2002]) + ) + + +def test_replace_broadcast_self() -> None: + df = pl.DataFrame( + { + "year": [None, 2, 3, 4, 5, 6, 7, 8], + "month": [1, None, 3, 4, 5, 6, 7, 8], + "day": [1, 2, None, 4, 5, 6, 7, 8], + "hour": [1, 2, 3, None, 5, 6, 7, 8], + "minute": [1, 2, 3, 4, None, 6, 7, 8], + "second": [1, 2, 3, 4, 5, None, 7, 8], + "microsecond": [1, 2, 3, 4, 5, 6, None, 8], + } + ) + + result = df.select( + pl.lit(pl.Series("dates", [datetime(2088, 8, 8, 8, 8, 8, 8)])).dt.replace( + year="year", + month="month", + day="day", + hour="hour", + minute="minute", + second="second", + microsecond="microsecond", + ) + ) + + expected = pl.DataFrame( + { + "dates": [ + datetime(2088, 1, 1, 1, 1, 1, 1), + datetime(2, 8, 2, 2, 2, 2, 2), + datetime(3, 3, 8, 3, 3, 3, 3), + datetime(4, 4, 4, 8, 4, 4, 4), + datetime(5, 5, 5, 5, 8, 5, 5), + datetime(6, 6, 6, 6, 6, 8, 6), + datetime(7, 7, 7, 7, 7, 7, 8), + datetime(8, 8, 8, 8, 8, 8, 8), + ] + } + ) + + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_round.py b/py-polars/tests/unit/operations/namespaces/temporal/test_round.py new file mode 100644 index 000000000000..5fefcdaf5893 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_round.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import polars as pl +from polars._utils.convert import parse_as_duration_string +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars.type_aliases import TimeUnit + + +@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu"]) +def test_round_by_day_datetime(time_zone: str | None) -> None: + ser = pl.Series([datetime(2021, 11, 7, 3)]).dt.replace_time_zone(time_zone) + result = ser.dt.round("1d") + expected = pl.Series([datetime(2021, 11, 7)]).dt.replace_time_zone(time_zone) + assert_series_equal(result, expected) + + +def test_round_ambiguous() -> None: + t = ( + pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ) + .alias("datetime") + .dt.offset_by("15m") + ) + result = t.dt.round("30m") + expected = ( + pl.Series( + [ + "2020-10-25T00:30:00+0100", + "2020-10-25T01:00:00+0100", + "2020-10-25T01:30:00+0100", + "2020-10-25T01:00:00+0000", + "2020-10-25T01:30:00+0000", + "2020-10-25T02:00:00+0000", + "2020-10-25T02:30:00+0000", + ] + ) + .str.to_datetime() + .dt.convert_time_zone("Europe/London") + .rename("datetime") + ) + assert_series_equal(result, expected) + + df = pl.DataFrame( + { + "date": pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ).dt.offset_by("15m") + } + ) + + df = df.select(pl.col("date").dt.round("30m")) + assert df.to_dict(as_series=False) == { + "date": [ + datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, 30, tzinfo=ZoneInfo("Europe/London")), + ] + } + + +def test_round_by_week() -> None: + df = pl.DataFrame( + { + "date": pl.Series( + [ + # Sunday and Monday + "1998-04-12", + "2022-11-28", + ] + ).str.strptime(pl.Date, "%Y-%m-%d") + } + ) + + assert ( + df.select( + pl.col("date").dt.round("7d").alias("7d"), + pl.col("date").dt.round("1w").alias("1w"), + ) + ).to_dict(as_series=False) == { + "7d": [date(1998, 4, 9), date(2022, 12, 1)], + "1w": [date(1998, 4, 13), date(2022, 11, 28)], + } + + +@given( + datetimes=st.lists( + st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)), + min_size=1, + max_size=3, + ), + every=st.timedeltas( + min_value=timedelta(microseconds=1), max_value=timedelta(days=1) + ).map(parse_as_duration_string), +) +def test_dt_round_fast_path_vs_slow_path(datetimes: list[datetime], every: str) -> None: + s = pl.Series(datetimes) + # Might use fastpath: + result = s.dt.round(every) + # Definitely uses slowpath: + expected = s.dt.round(pl.Series([every] * len(datetimes))) + assert_series_equal(result, expected) + + +def test_round_date() -> None: + # n vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 19)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.round(pl.col("b")))["a"] + expected = pl.Series("a", [None, None, date(2020, 2, 1)]) + assert_series_equal(result, expected) + + # n vs 1 + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.round("1mo"))["a"] + expected = pl.Series("a", [date(2020, 1, 1), None, date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # n vs missing + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.round(pl.lit(None, dtype=pl.String)))["a"] + expected = pl.Series("a", [None, None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + # 1 vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(a=pl.date(2020, 1, 1).dt.round(pl.col("b")))["a"] + expected = pl.Series("a", [None, date(2020, 1, 1), date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # missing vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(a=pl.lit(None, dtype=pl.Date).dt.round(pl.col("b")))["a"] + expected = pl.Series("a", [None, None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_round_datetime_simple(time_unit: TimeUnit) -> None: + s = pl.Series([datetime(2020, 1, 2, 6)], dtype=pl.Datetime(time_unit)) + result = s.dt.round("1mo").item() + assert result == datetime(2020, 1, 1) + result = s.dt.round("1d").item() + assert result == datetime(2020, 1, 2) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_round_datetime_w_expression(time_unit: TimeUnit) -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 1, 2, 6), datetime(2020, 1, 20, 21)], "b": ["1mo", "1d"]}, + schema_overrides={"a": pl.Datetime(time_unit)}, + ) + result = df.select(pl.col("a").dt.round(pl.col("b")))["a"] + assert result[0] == datetime(2020, 1, 1) + assert result[1] == datetime(2020, 1, 21) + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("ms", 0), + ("us", 0), + ("ns", 0), + ], +) +def test_round_negative_towards_epoch_18239(time_unit: TimeUnit, expected: int) -> None: + s = pl.Series([datetime(1970, 1, 1)], dtype=pl.Datetime(time_unit)) + s = s.dt.offset_by(f"-1{time_unit}") + result = s.dt.round(f"2{time_unit}").dt.timestamp(time_unit="ns").item() + assert result == expected + result = ( + s.dt.replace_time_zone("Europe/London") + .dt.round(f"2{time_unit}") + .dt.replace_time_zone(None) + .dt.timestamp(time_unit="ns") + .item() + ) + assert result == expected + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("ms", 2_000_000), + ("us", 2_000), + ("ns", 2), + ], +) +def test_round_positive_away_from_epoch_18239( + time_unit: TimeUnit, expected: int +) -> None: + s = pl.Series([datetime(1970, 1, 1)], dtype=pl.Datetime(time_unit)) + s = s.dt.offset_by(f"1{time_unit}") + result = s.dt.round(f"2{time_unit}").dt.timestamp(time_unit="ns").item() + assert result == expected + result = ( + s.dt.replace_time_zone("Europe/London") + .dt.round(f"2{time_unit}") + .dt.replace_time_zone(None) + .dt.timestamp(time_unit="ns") + .item() + ) + assert result == expected + + +@pytest.mark.parametrize("as_date", [False, True]) +def test_round_unequal_length_22018(as_date: bool) -> None: + start = datetime(2001, 1, 1) + stop = datetime(2001, 1, 1, 1) + s = pl.datetime_range(start, stop, "10m", eager=True).alias("datetime") + if as_date: + s = s.dt.date() + + with pytest.raises(pl.exceptions.ShapeError): + s.dt.round(pl.Series(["30m", "20m"])) + + +def test_round_small() -> None: + small = 1.234e-320 + small_s = pl.Series([small]) + assert small_s.round().item() == 0.0 + assert small_s.round(320).item() == 1e-320 + assert small_s.round(321).item() == 1.2e-320 + assert small_s.round(322).item() == 1.23e-320 + assert small_s.round(323).item() == 1.234e-320 + assert small_s.round(324).item() == small + assert small_s.round(1000).item() == small + + assert small_s.round_sig_figs(1).item() == 1e-320 + assert small_s.round_sig_figs(2).item() == 1.2e-320 + assert small_s.round_sig_figs(3).item() == 1.23e-320 + assert small_s.round_sig_figs(4).item() == 1.234e-320 + assert small_s.round_sig_figs(5).item() == small + assert small_s.round_sig_figs(1000).item() == small + + +def test_round_big() -> None: + big = 1.234e308 + max_err = big / 10**10 + big_s = pl.Series([big]) + assert big_s.round().item() == big + assert big_s.round(1).item() == big + assert big_s.round(100).item() == big + + assert abs(big_s.round_sig_figs(1).item() - 1e308) <= max_err + assert abs(big_s.round_sig_figs(2).item() - 1.2e308) <= max_err + assert abs(big_s.round_sig_figs(3).item() - 1.23e308) <= max_err + assert abs(big_s.round_sig_figs(4).item() - 1.234e308) <= max_err + assert abs(big_s.round_sig_figs(4).item() - big) <= max_err + assert big_s.round_sig_figs(100).item() == big diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py new file mode 100644 index 000000000000..f07106ee76ac --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from hypothesis.strategies import DrawFn + + from polars._typing import TimeUnit + + +DATE_FORMATS = ["%Y{}%m{}%d", "%d{}%m{}%Y"] +SEPARATORS = ["-", "/", "."] +TIME_FORMATS = [ + "T%H:%M:%S", + "T%H%M%S", + "T%H:%M", + "T%H%M", + " %H:%M:%S", + " %H%M%S", + " %H:%M", + " %H%M", + "", # allow no time part (plain date) +] +FRACTIONS = [ + "%.9f", + "%.6f", + "%.3f", + # "%.f", # alternative which allows any number of digits + "", +] +TIMEZONES = ["%#z", ""] +DATETIME_PATTERNS = [ + date_format.format(separator, separator) + time_format + fraction + tz + for separator in SEPARATORS + for date_format in DATE_FORMATS + for time_format in TIME_FORMATS + for fraction in FRACTIONS + if time_format.endswith("%S") or fraction == "" + for tz in TIMEZONES + if (date_format.startswith("%Y") and time_format != "") or tz == "" +] + + +@pytest.mark.parametrize("fmt", DATETIME_PATTERNS) +def test_to_datetime_inferable_formats(fmt: str) -> None: + time_string = ( + fmt.replace("%Y", "2024") + .replace("%m", "12") + .replace("%d", "13") + .replace("%H", "23") + .replace("%M", "34") + .replace("%S", "45") + .replace("%.3f", ".123") + .replace("%.6f", ".123456") + .replace("%.9f", ".123456789") + .replace("%#z", "+0100") + ) + + pl.Series([time_string]).str.to_datetime(strict=True) + + +@st.composite +def datetime_formats(draw: DrawFn) -> str: + """Returns a strategy which generates datetime format strings.""" + parts = [ + "%m", + "%b", + "%B", + "%d", + "%j", + "%a", + "%A", + "%w", + "%H", + "%I", + "%p", + "%M", + "%S", + "%U", + "%W", + "%%", + ] + fmt = draw(st.sets(st.sampled_from(parts))) + fmt.add("%Y") # Make sure year is always present + return " ".join(fmt) + + +@given( + datetimes=st.datetimes( + min_value=datetime(1699, 1, 1), + max_value=datetime(9999, 12, 31), + ), + fmt=datetime_formats(), +) +def test_to_datetime(datetimes: datetime, fmt: str) -> None: + input = datetimes.strftime(fmt) + expected = datetime.strptime(input, fmt) + try: + result = pl.Series([input]).str.to_datetime(format=fmt).item() + # If there's an exception, check that it's either: + # - something which polars can't parse at all: missing day or month + # - something on which polars intentionally raises + except InvalidOperationError as exc: + assert "failed in column" in str(exc) # noqa: PT017 + assert not any(day in fmt for day in ("%d", "%j")) or not any( + month in fmt for month in ("%b", "%B", "%m") + ) + except ComputeError as exc: + assert "Invalid format string" in str(exc) # noqa: PT017 + assert ( + (("%H" in fmt) ^ ("%M" in fmt)) + or (("%I" in fmt) ^ ("%M" in fmt)) + or ("%S" in fmt and "%H" not in fmt) + or ("%S" in fmt and "%I" not in fmt) + or (("%I" in fmt) ^ ("%p" in fmt)) + or (("%H" in fmt) ^ ("%p" in fmt)) + ) + else: + assert result == expected + + +@given( + d=st.datetimes( + min_value=datetime(1699, 1, 1), + max_value=datetime(9999, 12, 31), + ), + tu=st.sampled_from(["ms", "us"]), +) +def test_cast_to_time_and_combine(d: datetime, tu: TimeUnit) -> None: + # round-trip date/time extraction + recombining + df = pl.DataFrame({"d": [d]}, schema={"d": pl.Datetime(tu)}) + res = df.select( + d=pl.col("d"), + dt=pl.col("d").dt.date(), + tm=pl.col("d").cast(pl.Time), + ).with_columns( + dtm=pl.col("dt").dt.combine(pl.col("tm")), + ) + + datetimes = res["d"].to_list() + assert [d.date() for d in datetimes] == res["dt"].to_list() + assert [d.time() for d in datetimes] == res["tm"].to_list() + assert datetimes == res["dtm"].to_list() + + +def test_to_datetime_aware_values_aware_dtype() -> None: + s = pl.Series(["2020-01-01T01:12:34+01:00"]) + expected = pl.Series([datetime(2020, 1, 1, 5, 57, 34)]).dt.replace_time_zone( + "Asia/Kathmandu" + ) + + # When Polars infers the format + result = s.str.to_datetime(time_zone="Asia/Kathmandu") + assert_series_equal(result, expected) + + # When the format is provided + result = s.str.to_datetime(format="%Y-%m-%dT%H:%M:%S%z", time_zone="Asia/Kathmandu") + assert_series_equal(result, expected) + + # With `exact=False` + result = s.str.to_datetime( + format="%Y-%m-%dT%H:%M:%S%z", time_zone="Asia/Kathmandu", exact=False + ) + assert_series_equal(result, expected) + + # Check consistency with Series constructor + result = pl.Series( + [datetime(2020, 1, 1, 5, 57, 34, tzinfo=ZoneInfo("Asia/Kathmandu"))], + dtype=pl.Datetime("us", "Asia/Kathmandu"), + ) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("inputs", "format", "expected"), + [ + ("01-01-69", "%d-%m-%y", date(2069, 1, 1)), # Polars' parser + ("01-01-70", "%d-%m-%y", date(1970, 1, 1)), # Polars' parser + ("01-January-69", "%d-%B-%y", date(2069, 1, 1)), # Chrono + ("01-January-70", "%d-%B-%y", date(1970, 1, 1)), # Chrono + ], +) +def test_to_datetime_two_digit_year_17213( + inputs: str, format: str, expected: date +) -> None: + result = pl.Series([inputs]).str.to_date(format=format).item() + assert result == expected diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py new file mode 100644 index 000000000000..9ce4e5cb4f53 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import polars as pl +from polars._utils.convert import parse_as_duration_string +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars._typing import TimeUnit + + +@given( + value=st.datetimes( + min_value=datetime(1000, 1, 1), + max_value=datetime(3000, 1, 1), + ), + n=st.integers(min_value=1, max_value=100), +) +def test_truncate_monthly(value: date, n: int) -> None: + result = pl.Series([value]).dt.truncate(f"{n}mo").item() + # manual calculation + total = value.year * 12 + value.month - 1 + remainder = total % n + total -= remainder + year, month = (total // 12), ((total % 12) + 1) + expected = datetime(year, month, 1) + assert result == expected + + +def test_truncate_date() -> None: + # n vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"] + expected = pl.Series("a", [None, None, date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # n vs 1 + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.truncate("1mo"))["a"] + expected = pl.Series("a", [date(2020, 1, 1), None, date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # n vs missing + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(pl.col("a").dt.truncate(pl.lit(None, dtype=pl.String)))["a"] + expected = pl.Series("a", [None, None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + # 1 vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(a=pl.date(2020, 1, 1).dt.truncate(pl.col("b")))["a"] + expected = pl.Series("a", [None, date(2020, 1, 1), date(2020, 1, 1)]) + assert_series_equal(result, expected) + + # missing vs n + df = pl.DataFrame( + {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]} + ) + result = df.select(a=pl.lit(None, dtype=pl.Date).dt.truncate(pl.col("b")))["a"] + expected = pl.Series("a", [None, None, None], dtype=pl.Date) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_truncate_datetime_simple(time_unit: TimeUnit) -> None: + s = pl.Series([datetime(2020, 1, 2, 6)], dtype=pl.Datetime(time_unit)) + result = s.dt.truncate("1mo").item() + assert result == datetime(2020, 1, 1) + result = s.dt.truncate("1d").item() + assert result == datetime(2020, 1, 2) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_truncate_datetime_w_expression(time_unit: TimeUnit) -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 1, 2, 6), datetime(2020, 1, 3, 7)], "b": ["1mo", "1d"]}, + schema_overrides={"a": pl.Datetime(time_unit)}, + ) + result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"] + assert result[0] == datetime(2020, 1, 1) + assert result[1] == datetime(2020, 1, 3) + + +def test_pre_epoch_truncate_17581() -> None: + s = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1, 1)]) + result = s.dt.truncate("1d") + expected = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1)]) + assert_series_equal(result, expected) + + +@given( + datetimes=st.lists( + st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)), + min_size=1, + max_size=3, + ), + every=st.timedeltas( + min_value=timedelta(microseconds=1), max_value=timedelta(days=1) + ).map(parse_as_duration_string), +) +def test_fast_path_vs_slow_path(datetimes: list[datetime], every: str) -> None: + s = pl.Series(datetimes) + # Might use fastpath: + result = s.dt.truncate(every) + # Definitely uses slowpath: + expected = s.dt.truncate(pl.Series([every] * len(datetimes))) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("as_date", [False, True]) +def test_truncate_unequal_length_22018(as_date: bool) -> None: + s = pl.Series([datetime(2088, 8, 8, 8, 8, 8, 8)] * 2) + if as_date: + s = s.dt.date() + with pytest.raises(pl.exceptions.ShapeError): + s.dt.truncate(pl.Series(["1y"] * 3)) diff --git a/py-polars/tests/unit/operations/namespaces/test_binary.py b/py-polars/tests/unit/operations/namespaces/test_binary.py new file mode 100644 index 000000000000..9c5a0111185c --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_binary.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import random +import struct +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import SizeUnit, TransferEncoding + + +def test_binary_conversions() -> None: + df = pl.DataFrame({"blob": [b"abc", None, b"cde"]}).with_columns( + pl.col("blob").cast(pl.String).alias("decoded_blob") + ) + + assert df.to_dict(as_series=False) == { + "blob": [b"abc", None, b"cde"], + "decoded_blob": ["abc", None, "cde"], + } + assert df[0, 0] == b"abc" + assert df[1, 0] is None + assert df.dtypes == [pl.Binary, pl.String] + + +def test_contains() -> None: + df = pl.DataFrame( + data=[ + (1, b"some * * text"), + (2, b"(with) special\n * chars"), + (3, b"**etc...?$"), + (4, None), + ], + schema=["idx", "bin"], + orient="row", + ) + for pattern, expected in ( + (b"e * ", [True, False, False, None]), + (b"text", [True, False, False, None]), + (b"special", [False, True, False, None]), + (b"", [True, True, True, None]), + (b"qwe", [False, False, False, None]), + ): + # series + assert expected == df["bin"].bin.contains(pattern).to_list() + # frame select + assert ( + expected == df.select(pl.col("bin").bin.contains(pattern))["bin"].to_list() + ) + # frame filter + assert sum(e for e in expected if e is True) == len( + df.filter(pl.col("bin").bin.contains(pattern)) + ) + + +def test_contains_with_expr() -> None: + df = pl.DataFrame( + { + "bin": [b"some * * text", b"(with) special\n * chars", b"**etc...?$", None], + "lit1": [b"e * ", b"", b"qwe", b"None"], + "lit2": [None, b"special\n", b"?!", None], + } + ) + + assert df.select( + pl.col("bin").bin.contains(pl.col("lit1")).alias("contains_1"), + pl.col("bin").bin.contains(pl.col("lit2")).alias("contains_2"), + pl.col("bin").bin.contains(pl.lit(None)).alias("contains_3"), + ).to_dict(as_series=False) == { + "contains_1": [True, True, False, None], + "contains_2": [None, True, False, None], + "contains_3": [None, None, None, None], + } + + +def test_starts_ends_with() -> None: + assert pl.DataFrame( + { + "a": [b"hamburger", b"nuts", b"lollypop", None], + "end": [b"ger", b"tg", None, b"anything"], + "start": [b"ha", b"nga", None, b"anything"], + } + ).select( + pl.col("a").bin.ends_with(b"pop").alias("end_lit"), + pl.col("a").bin.ends_with(pl.lit(None)).alias("end_none"), + pl.col("a").bin.ends_with(pl.col("end")).alias("end_expr"), + pl.col("a").bin.starts_with(b"ham").alias("start_lit"), + pl.col("a").bin.ends_with(pl.lit(None)).alias("start_none"), + pl.col("a").bin.starts_with(pl.col("start")).alias("start_expr"), + ).to_dict(as_series=False) == { + "end_lit": [False, False, True, None], + "end_none": [None, None, None, None], + "end_expr": [True, False, None, None], + "start_lit": [True, False, False, None], + "start_none": [None, None, None, None], + "start_expr": [True, False, None, None], + } + + +def test_base64_encode() -> None: + df = pl.DataFrame({"data": [b"asd", b"qwe"]}) + + assert df["data"].bin.encode("base64").to_list() == ["YXNk", "cXdl"] + + +def test_base64_decode() -> None: + df = pl.DataFrame({"data": [b"YXNk", b"cXdl"]}) + + assert df["data"].bin.decode("base64").to_list() == [b"asd", b"qwe"] + + +def test_hex_encode() -> None: + df = pl.DataFrame({"data": [b"asd", b"qwe"]}) + + assert df["data"].bin.encode("hex").to_list() == ["617364", "717765"] + + +def test_hex_decode() -> None: + df = pl.DataFrame({"data": [b"617364", b"717765"]}) + + assert df["data"].bin.decode("hex").to_list() == [b"asd", b"qwe"] + + +@pytest.mark.parametrize( + "encoding", + ["hex", "base64"], +) +def test_compare_encode_between_lazy_and_eager_6814(encoding: TransferEncoding) -> None: + df = pl.DataFrame({"x": [b"aa", b"bb", b"cc"]}) + expr = pl.col("x").bin.encode(encoding) + + result_eager = df.select(expr) + dtype = result_eager["x"].dtype + + result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect() + assert_frame_equal(result_eager, result_lazy) + + +@pytest.mark.parametrize( + "encoding", + ["hex", "base64"], +) +def test_compare_decode_between_lazy_and_eager_6814(encoding: TransferEncoding) -> None: + df = pl.DataFrame({"x": [b"d3d3", b"abcd", b"1234"]}) + expr = pl.col("x").bin.decode(encoding) + + result_eager = df.select(expr) + dtype = result_eager["x"].dtype + + result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect() + assert_frame_equal(result_eager, result_lazy) + + +@pytest.mark.parametrize( + ("sz", "unit", "expected"), + [(128, "b", 128), (512, "kb", 0.5), (131072, "mb", 0.125)], +) +def test_binary_size(sz: int, unit: SizeUnit, expected: int | float) -> None: + df = pl.DataFrame({"data": [b"\x00" * sz]}, schema={"data": pl.Binary}) + for sz in ( + df.select(sz=pl.col("data").bin.size(unit)).item(), # expr + df["data"].bin.size(unit).item(), # series + ): + assert sz == expected + + +@pytest.mark.parametrize( + ("dtype", "type_size", "struct_type"), + [ + (pl.Int8, 1, "b"), + (pl.UInt8, 1, "B"), + (pl.Int16, 2, "h"), + (pl.UInt16, 2, "H"), + (pl.Int32, 4, "i"), + (pl.UInt32, 4, "I"), + (pl.Int64, 8, "q"), + (pl.UInt64, 8, "Q"), + (pl.Float32, 4, "f"), + (pl.Float64, 8, "d"), + ], +) +def test_reinterpret( + dtype: pl.DataType, + type_size: int, + struct_type: str, +) -> None: + # Make test reproducible + random.seed(42) + + byte_arr = [random.randbytes(type_size) for _ in range(3)] + df = pl.DataFrame({"x": byte_arr}) + + for endianness in ["little", "big"]: + # So that mypy doesn't complain + struct_endianness = "<" if endianness == "little" else ">" + expected = [ + struct.unpack_from(f"{struct_endianness}{struct_type}", elem_bytes)[0] + for elem_bytes in byte_arr + ] + expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype}) + + result = df.select( + pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness) # type: ignore[arg-type] + ) + + assert_frame_equal(result, expected_df) + + +@pytest.mark.parametrize( + ("dtype", "type_size"), + [ + (pl.Int128, 16), + ], +) +def test_reinterpret_int( + dtype: pl.DataType, + type_size: int, +) -> None: + # Function used for testing integers that `struct` or `numpy` + # doesn't support parsing from bytes. + # Rather than creating bytes directly, create integer and view it as bytes + is_signed = dtype.is_signed_integer() + + if is_signed: + min_val = -(2 ** (type_size - 1)) + max_val = 2 ** (type_size - 1) - 1 + else: + min_val = 0 + max_val = 2**type_size - 1 + + # Make test reproducible + random.seed(42) + + expected = [random.randint(min_val, max_val) for _ in range(3)] + expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype}) + + for endianness in ["little", "big"]: + byte_arr = [ + val.to_bytes(type_size, byteorder=endianness, signed=is_signed) # type: ignore[arg-type] + for val in expected + ] + df = pl.DataFrame({"x": byte_arr}) + + result = df.select( + pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness) # type: ignore[arg-type] + ) + + assert_frame_equal(result, expected_df) + + +def test_reinterpret_invalid() -> None: + # Fails because buffer has more than 4 bytes + df = pl.DataFrame({"x": [b"d3d3a"]}) + print(struct.unpack_from(" None: + s = pl.Series("a", [b"a", b"xyz"], pl.Binary).bin + f = getattr(s, func) + with pytest.raises(pl.exceptions.ShapeError): + f(pl.Series([b"x", b"y", b"z"])) diff --git a/py-polars/tests/unit/operations/namespaces/test_categorical.py b/py-polars/tests/unit/operations/namespaces/test_categorical.py new file mode 100644 index 000000000000..3ca4df90a109 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_categorical.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_lexical_sort() -> None: + df = pl.DataFrame( + {"cats": ["z", "z", "k", "a", "b"], "vals": [3, 1, 2, 2, 3]} + ).with_columns( + pl.col("cats").cast(pl.Categorical("lexical")), + ) + + out = df.sort(["cats"]) + assert out["cats"].dtype == pl.Categorical + expected = pl.DataFrame( + {"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 3, 1]} + ) + assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected) + out = df.sort(["cats", "vals"]) + expected = pl.DataFrame( + {"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 1, 3]} + ) + assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected) + out = df.sort(["vals", "cats"]) + + expected = pl.DataFrame( + {"cats": ["z", "a", "k", "b", "z"], "vals": [1, 2, 2, 3, 3]} + ) + assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected) + + s = pl.Series(["a", "c", "a", "b", "a"], dtype=pl.Categorical("lexical")) + assert s.sort().cast(pl.String).to_list() == [ + "a", + "a", + "a", + "b", + "c", + ] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_lexical_ordering_after_concat() -> None: + ldf1 = ( + pl.DataFrame([pl.Series("key1", [8, 5]), pl.Series("key2", ["fox", "baz"])]) + .lazy() + .with_columns(pl.col("key2").cast(pl.Categorical("lexical"))) + ) + ldf2 = ( + pl.DataFrame( + [pl.Series("key1", [6, 8, 6]), pl.Series("key2", ["fox", "foo", "bar"])] + ) + .lazy() + .with_columns(pl.col("key2").cast(pl.Categorical("lexical"))) + ) + df = pl.concat([ldf1, ldf2]).select(pl.col("key2")).collect() + + assert df.sort("key2").to_dict(as_series=False) == { + "key2": ["bar", "baz", "foo", "fox", "fox"] + } + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.may_fail_auto_streaming +def test_sort_categoricals_6014_internal() -> None: + # create basic categorical + df = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( + pl.col("key").cast(pl.Categorical) + ) + + out = df.sort("key") + assert out.to_dict(as_series=False) == {"key": ["bbb", "aaa", "ccc"]} + + +@pytest.mark.usefixtures("test_global_and_local") +def test_sort_categoricals_6014_lexical() -> None: + # create lexically-ordered categorical + df = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( + pl.col("key").cast(pl.Categorical("lexical")) + ) + + out = df.sort("key") + assert out.to_dict(as_series=False) == {"key": ["aaa", "bbb", "ccc"]} + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_get_categories() -> None: + assert pl.Series( + "cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical + ).cat.get_categories().to_list() == ["foo", "bar", "ham"] + + +def test_cat_to_local() -> None: + with pl.StringCache(): + s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + s2 = pl.Series(["c", "b", "d"], dtype=pl.Categorical) + + # s2 physical starts after s1 + assert s1.to_physical().to_list() == [0, 1, 0] + assert s2.to_physical().to_list() == [2, 1, 3] + + out = s2.cat.to_local() + + # Physical has changed and now starts at 0, string values are the same + assert out.cat.is_local() + assert out.to_physical().to_list() == [0, 1, 2] + assert out.to_list() == s2.to_list() + + # s2 should be unchanged after the operation + assert not s2.cat.is_local() + assert s2.to_physical().to_list() == [2, 1, 3] + assert s2.to_list() == ["c", "b", "d"] + + +def test_cat_to_local_missing_values() -> None: + with pl.StringCache(): + _ = pl.Series(["a", "b"], dtype=pl.Categorical) + s = pl.Series(["c", "b", None, "d"], dtype=pl.Categorical) + + out = s.cat.to_local() + assert out.to_physical().to_list() == [0, 1, None, 2] + + +def test_cat_to_local_already_local() -> None: + s = pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical) + + assert s.cat.is_local() + out = s.cat.to_local() + + assert out.to_physical().to_list() == [0, 1, 0, 2] + assert out.to_list() == ["a", "c", "a", "b"] + + +def test_cat_is_local() -> None: + s = pl.Series(["a", "c", "a", "b"], dtype=pl.Categorical) + assert s.cat.is_local() + + with pl.StringCache(): + s2 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + assert not s2.cat.is_local() + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cat_uses_lexical_ordering() -> None: + s = pl.Series(["a", "b", None, "b"]).cast(pl.Categorical) + assert s.cat.uses_lexical_ordering() is False + + s = s.cast(pl.Categorical("lexical")) + assert s.cat.uses_lexical_ordering() is True + + s = s.cast(pl.Categorical("physical")) + assert s.cat.uses_lexical_ordering() is False + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cat_len_bytes() -> None: + # test Series + s = pl.Series("a", ["Café", None, "Café", "345", "東京"], dtype=pl.Categorical) + result = s.cat.len_bytes() + expected = pl.Series("a", [5, None, 5, 3, 6], dtype=pl.UInt32) + assert_series_equal(result, expected) + + # test DataFrame expr + df = pl.DataFrame(s) + result_df = df.select(pl.col("a").cat.len_bytes()) + expected_df = pl.DataFrame(expected) + assert_frame_equal(result_df, expected_df) + + # test LazyFrame expr + result_lf = df.lazy().select(pl.col("a").cat.len_bytes()).collect() + assert_frame_equal(result_lf, expected_df) + + # test GroupBy + result_df = ( + pl.LazyFrame({"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": s.extend(s)}) + .group_by("key", maintain_order=True) + .agg(pl.col("value").cat.len_bytes().alias("len_bytes")) + .explode("len_bytes") + .collect() + ) + expected_df = pl.DataFrame( + { + "key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "len_bytes": pl.Series( + [5, None, 5, 3, 6, 5, None, 5, 3, 6], dtype=pl.get_index_type() + ), + } + ) + assert_frame_equal(result_df, expected_df) + + +@pytest.mark.usefixtures("test_global_and_local") +def test_cat_len_chars() -> None: + # test Series + s = pl.Series("a", ["Café", None, "Café", "345", "東京"], dtype=pl.Categorical) + result = s.cat.len_chars() + expected = pl.Series("a", [4, None, 4, 3, 2], dtype=pl.UInt32) + assert_series_equal(result, expected) + + # test DataFrame expr + df = pl.DataFrame(s) + result_df = df.select(pl.col("a").cat.len_chars()) + expected_df = pl.DataFrame(expected) + assert_frame_equal(result_df, expected_df) + + # test LazyFrame expr + result_lf = df.lazy().select(pl.col("a").cat.len_chars()).collect() + assert_frame_equal(result_lf, expected_df) + + # test GroupBy + result_df = ( + pl.LazyFrame({"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": s.extend(s)}) + .group_by("key", maintain_order=True) + .agg(pl.col("value").cat.len_chars().alias("len_bytes")) + .explode("len_bytes") + .collect() + ) + expected_df = pl.DataFrame( + { + "key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "len_bytes": pl.Series( + [4, None, 4, 3, 2, 4, None, 4, 3, 2], dtype=pl.get_index_type() + ), + } + ) + assert_frame_equal(result_df, expected_df) + + +@pytest.mark.usefixtures("test_global_and_local") +def test_starts_ends_with() -> None: + s = pl.Series( + "a", + ["hamburger_with_tomatoes", "nuts", "nuts", "lollypop", None], + dtype=pl.Categorical, + ) + assert_series_equal( + s.cat.ends_with("pop"), pl.Series("a", [False, False, False, True, None]) + ) + assert_series_equal( + s.cat.starts_with("nu"), pl.Series("a", [False, True, True, False, None]) + ) + + with pytest.raises(TypeError, match="'prefix' must be a string; found"): + s.cat.starts_with(None) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="'suffix' must be a string; found"): + s.cat.ends_with(None) # type: ignore[arg-type] + + df = pl.DataFrame( + { + "a": pl.Series( + ["hamburger_with_tomatoes", "nuts", "nuts", "lollypop", None], + dtype=pl.Categorical, + ), + } + ) + + expected = { + "ends_pop": [False, False, False, True, None], + "starts_ham": [True, False, False, False, None], + } + + assert ( + df.select( + pl.col("a").cat.ends_with("pop").alias("ends_pop"), + pl.col("a").cat.starts_with("ham").alias("starts_ham"), + ).to_dict(as_series=False) + == expected + ) + + with pytest.raises(TypeError, match="'prefix' must be a string; found"): + df.select(pl.col("a").cat.starts_with(None)) # type: ignore[arg-type] + + with pytest.raises(TypeError, match="'suffix' must be a string; found"): + df.select(pl.col("a").cat.ends_with(None)) # type: ignore[arg-type] + + +def test_cat_slice() -> None: + df = pl.DataFrame( + { + "a": pl.Series( + [ + "foobar", + "barfoo", + "foobar", + "x", + None, + ], + dtype=pl.Categorical, + ) + } + ) + assert df["a"].cat.slice(-3).to_list() == ["bar", "foo", "bar", "x", None] + assert df.select([pl.col("a").cat.slice(2, 4)])["a"].to_list() == [ + "obar", + "rfoo", + "obar", + "", + None, + ] diff --git a/py-polars/tests/unit/operations/namespaces/test_meta.py b/py-polars/tests/unit/operations/namespaces/test_meta.py new file mode 100644 index 000000000000..5a0c253fbfed --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_meta.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ComputeError +from tests.unit.conftest import NUMERIC_DTYPES + +if TYPE_CHECKING: + from pathlib import Path + + +def test_meta_pop_and_cmp() -> None: + e = pl.col("foo").alias("bar") + + first = e.meta.pop()[0] + assert first.meta == pl.col("foo") + assert first.meta != pl.col("bar") + + assert first.meta.eq(pl.col("foo")) + assert first.meta.ne(pl.col("bar")) + + +def test_root_and_output_names() -> None: + e = pl.col("foo") * pl.col("bar") + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "bar"] + + e = pl.col("foo").filter(bar=13) + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "bar"] + + e = pl.sum("foo").over("groups") + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "groups"] + + e = pl.sum("foo").slice(pl.len() - 10, pl.col("bar")) + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "bar"] + + e = pl.len() + assert e.meta.output_name() == "len" + + with pytest.raises( + ComputeError, + match="cannot determine output column without a context for this expression", + ): + pl.all().name.suffix("_").meta.output_name() + + assert ( + pl.all().name.suffix("_").meta.output_name(raise_if_undetermined=False) is None + ) + + +def test_undo_aliases() -> None: + e = pl.col("foo").alias("bar") + assert e.meta.undo_aliases().meta == pl.col("foo") + + e = pl.col("foo").sum().over("bar") + assert e.name.keep().meta.undo_aliases().meta == e + + e.alias("bar").alias("foo") + assert e.meta.undo_aliases().meta == e + assert e.name.suffix("ham").meta.undo_aliases().meta == e + + +def test_meta_has_multiple_outputs() -> None: + e = pl.col(["a", "b"]).name.suffix("_foo") + assert e.meta.has_multiple_outputs() + + +def test_is_column() -> None: + e = pl.col("foo") + assert e.meta.is_column() + + e = pl.col("foo").alias("bar") + assert not e.meta.is_column() + + e = pl.col("foo") * pl.col("bar") + assert not e.meta.is_column() + + +@pytest.mark.parametrize( + ("expr", "is_column_selection"), + [ + # columns + (pl.col("foo"), True), + (pl.col("foo", "bar"), True), + (pl.col(NUMERIC_DTYPES), True), + # column expressions + (pl.col("foo") + 100, False), + (pl.col("foo").floordiv(10), False), + (pl.col("foo") * pl.col("bar"), False), + # selectors / expressions + (cs.numeric() * 100, False), + (cs.temporal() - cs.time(), True), + (cs.numeric().exclude("value"), True), + ((cs.temporal() - cs.time()).exclude("dt"), True), + # top-level selection funcs + (pl.nth(2), True), + (pl.first(), True), + (pl.last(), True), + ], +) +def test_is_column_selection( + expr: pl.Expr, + is_column_selection: bool, +) -> None: + if is_column_selection: + assert expr.meta.is_column_selection() + assert expr.meta.is_column_selection(allow_aliasing=True) + expr = ( + expr.name.suffix("!") + if expr.meta.has_multiple_outputs() + else expr.alias("!") + ) + assert not expr.meta.is_column_selection() + assert expr.meta.is_column_selection(allow_aliasing=True) + else: + assert not expr.meta.is_column_selection() + + +@pytest.mark.parametrize( + "value", + [ + None, + 1234, + 567.89, + float("inf"), + date.today(), + datetime.now(), + time(10, 30, 45), + timedelta(hours=-24), + ["x", "y", "z"], + pl.Series([None, None]), + [[10, 20], [30, 40]], + "this is the way", + ], +) +def test_is_literal(value: Any) -> None: + e = pl.lit(value) + assert e.meta.is_literal() + + e = pl.lit(value).alias("foo") + assert not e.meta.is_literal() + + e = pl.lit(value).alias("foo") + assert e.meta.is_literal(allow_aliasing=True) + + +def test_meta_is_regex_projection() -> None: + e = pl.col("^.*$").name.suffix("_foo") + assert e.meta.is_regex_projection() + assert e.meta.has_multiple_outputs() + + e = pl.col("^.*") # no trailing '$' + assert not e.meta.is_regex_projection() + assert not e.meta.has_multiple_outputs() + assert e.meta.is_column() + + +def test_meta_tree_format(namespace_files_path: Path) -> None: + with (namespace_files_path / "test_tree_fmt.txt").open("r", encoding="utf-8") as f: + test_sets = f.read().split("---") + for test_set in test_sets: + expression = test_set.strip().split("\n")[0] + tree_fmt = "\n".join(test_set.strip().split("\n")[1:]) + e = eval(expression) + result = e.meta.tree_format(return_as_string=True) + result = "\n".join(s.rstrip() for s in result.split("\n")) + assert result.strip() == tree_fmt.strip() + + +def test_meta_show_graph(namespace_files_path: Path) -> None: + e = (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2 + dot = e.meta.show_graph(show=False, raw_output=True) + assert dot is not None + assert len(dot) > 0 + # Don't check output contents since this creates a maintenance burden + # Assume output check in test_meta_tree_format is enough + + +def test_literal_output_name() -> None: + e = pl.lit(1) + assert e.meta.output_name() == "literal" + + e = pl.lit(pl.Series("abc", [1, 2, 3])) + assert e.meta.output_name() == "abc" + + e = pl.lit(pl.Series([1, 2, 3])) + assert e.meta.output_name() == "" diff --git a/py-polars/tests/unit/operations/namespaces/test_name.py b/py-polars/tests/unit/operations/namespaces/test_name.py new file mode 100644 index 000000000000..eac08e537a88 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_name.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from collections import OrderedDict + +import polars as pl + + +def test_name_change_case() -> None: + df = pl.DataFrame( + schema={"ColX": pl.Int32, "ColY": pl.String}, + ).with_columns( + pl.all().name.to_uppercase(), + pl.all().name.to_lowercase(), + ) + assert df.schema == OrderedDict( + [ + ("ColX", pl.Int32), + ("ColY", pl.String), + ("COLX", pl.Int32), + ("COLY", pl.String), + ("colx", pl.Int32), + ("coly", pl.String), + ] + ) + + +def test_name_prefix_suffix() -> None: + df = pl.DataFrame( + schema={"ColX": pl.Int32, "ColY": pl.String}, + ).with_columns( + pl.all().name.prefix("#"), + pl.all().name.suffix("!!"), + ) + assert df.schema == OrderedDict( + [ + ("ColX", pl.Int32), + ("ColY", pl.String), + ("#ColX", pl.Int32), + ("#ColY", pl.String), + ("ColX!!", pl.Int32), + ("ColY!!", pl.String), + ] + ) + + +def test_name_update_all() -> None: + df = pl.DataFrame( + schema={ + "col1": pl.UInt32, + "col2": pl.Float64, + "other": pl.UInt64, + } + ) + assert ( + df.select( + pl.col("col2").append(pl.col("other")), + pl.col("col1").append(pl.col("other")).name.keep(), + pl.col("col1").append(pl.col("other")).name.prefix("prefix_"), + pl.col("col1").append(pl.col("other")).name.suffix("_suffix"), + ) + ).schema == OrderedDict( + [ + ("col2", pl.Float64), + ("col1", pl.UInt64), + ("prefix_col1", pl.UInt64), + ("col1_suffix", pl.UInt64), + ] + ) diff --git a/py-polars/tests/unit/operations/namespaces/test_plot.py b/py-polars/tests/unit/operations/namespaces/test_plot.py new file mode 100644 index 000000000000..f4889dcc65e9 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_plot.py @@ -0,0 +1,77 @@ +import altair as alt + +import polars as pl + + +def test_dataframe_plot() -> None: + # dry-run, check nothing errors + df = pl.DataFrame( + { + "length": [1, 4, 6], + "width": [4, 5, 6], + "species": ["setosa", "setosa", "versicolor"], + } + ) + df.plot.line(x="length", y="width", color="species").to_json() + df.plot.point(x="length", y="width", size="species").to_json() + df.plot.scatter(x="length", y="width", size="species").to_json() + df.plot.bar(x="length", y="width", color="species").to_json() + df.plot.area(x="length", y="width", color="species").to_json() + + +def test_dataframe_plot_tooltip() -> None: + df = pl.DataFrame( + { + "length": [1, 4, 6], + "width": [4, 5, 6], + "species": ["setosa", "setosa", "versicolor"], + } + ) + result = df.plot.line(x="length", y="width", color="species").to_dict() + assert result["mark"]["tooltip"] is True + result = df.plot.line( + x="length", y="width", color="species", tooltip=["length", "width"] + ).to_dict() + assert result["encoding"]["tooltip"] == [ + {"field": "length", "type": "quantitative"}, + {"field": "width", "type": "quantitative"}, + ] + + +def test_series_plot() -> None: + # dry-run, check nothing errors + s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6]) + s.plot.kde().to_json() + s.plot.hist().to_json() + s.plot.line().to_json() + s.plot.point().to_json() + + +def test_series_plot_tooltip() -> None: + s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6]) + result = s.plot.line().to_dict() + assert result["mark"]["tooltip"] is True + result = s.plot.line(tooltip=["a"]).to_dict() + assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}] + + +def test_empty_dataframe() -> None: + pl.DataFrame({"a": [], "b": []}).plot.point(x="a", y="b") + + +def test_nameless_series() -> None: + pl.Series([1, 2, 3]).plot.kde().to_json() + + +def test_x_with_axis_18830() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + result = df.plot.line(x=alt.X("a", axis=alt.Axis(labelAngle=-90))).to_dict() + assert result["mark"]["tooltip"] is True + + +def test_errorbar_19787() -> None: + df = pl.DataFrame({"A": [0, 1, 2], "B": [10, 11, 12], "C": [1, 2, 3]}) + result = df.plot.errorbar(x="A", y="B", yError="C").to_dict() + assert "tooltip" not in result["encoding"] + result = df["A"].plot.errorbar().to_dict() + assert "tooltip" not in result["encoding"] diff --git a/py-polars/tests/unit/operations/namespaces/test_strptime.py b/py-polars/tests/unit/operations/namespaces/test_strptime.py new file mode 100644 index 000000000000..79f0a9f62bfb --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_strptime.py @@ -0,0 +1,789 @@ +""" +Module for testing `.str.strptime` of the string namespace. + +This method gets its own module due to its complexity. +""" + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, timezone +from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +from polars.exceptions import ChronoFormatWarning, ComputeError, InvalidOperationError +from polars.testing import assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsTemporalType, TimeUnit + + +def test_str_strptime() -> None: + s = pl.Series(["2020-01-01", "2020-02-02"]) + expected = pl.Series([date(2020, 1, 1), date(2020, 2, 2)]) + assert_series_equal(s.str.strptime(pl.Date, "%Y-%m-%d"), expected) + + s = pl.Series(["2020-01-01 00:00:00", "2020-02-02 03:20:10"]) + expected = pl.Series( + [datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 2, 2, 3, 20, 10)] + ) + assert_series_equal(s.str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S"), expected) + + s = pl.Series(["00:00:00", "03:20:10"]) + expected = pl.Series([0, 12010000000000], dtype=pl.Time) + assert_series_equal(s.str.strptime(pl.Time, "%H:%M:%S"), expected) + + +def test_date_parse_omit_day() -> None: + df = pl.DataFrame({"month": ["2022-01"]}) + assert df.select(pl.col("month").str.to_date(format="%Y-%m")).item() == date( + 2022, 1, 1 + ) + assert df.select( + pl.col("month").str.to_datetime(format="%Y-%m") + ).item() == datetime(2022, 1, 1) + + +def test_to_datetime_precision() -> None: + s = pl.Series( + "date", ["2022-09-12 21:54:36.789321456", "2022-09-13 12:34:56.987456321"] + ) + ds = s.str.to_datetime() + assert ds.cast(pl.Date).is_not_null().all() + assert getattr(ds.dtype, "time_unit", None) == "us" + + time_units: list[TimeUnit] = ["ms", "us", "ns"] + suffixes = ["%.3f", "%.6f", "%.9f"] + test_data = zip( + time_units, + suffixes, + ( + [789000000, 987000000], + [789321000, 987456000], + [789321456, 987456321], + ), + ) + for time_unit, suffix, expected_values in test_data: + ds = s.str.to_datetime(f"%Y-%m-%d %H:%M:%S{suffix}", time_unit=time_unit) + assert getattr(ds.dtype, "time_unit", None) == time_unit + assert ds.dt.nanosecond().to_list() == expected_values + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [("ms", "123000000"), ("us", "123456000"), ("ns", "123456789")], +) +@pytest.mark.parametrize("format", ["%Y-%m-%d %H:%M:%S%.f", None]) +def test_to_datetime_precision_with_time_unit( + time_unit: TimeUnit, expected: str, format: str +) -> None: + s = pl.Series(["2020-01-01 00:00:00.123456789"]) + result = s.str.to_datetime(format, time_unit=time_unit).dt.to_string("%f")[0] + assert result == expected + + +@pytest.mark.parametrize( + ("tz_string", "timedelta"), + [("+01:00", timedelta(minutes=60)), ("-01:30", timedelta(hours=-1, minutes=-30))], +) +def test_timezone_aware_strptime(tz_string: str, timedelta: timedelta) -> None: + times = pl.DataFrame( + { + "delivery_datetime": [ + "2021-12-05 06:00:00" + tz_string, + "2021-12-05 07:00:00" + tz_string, + "2021-12-05 08:00:00" + tz_string, + ] + } + ) + assert times.with_columns( + pl.col("delivery_datetime").str.to_datetime(format="%Y-%m-%d %H:%M:%S%z") + ).to_dict(as_series=False) == { + "delivery_datetime": [ + datetime(2021, 12, 5, 6, 0, tzinfo=timezone(timedelta)), + datetime(2021, 12, 5, 7, 0, tzinfo=timezone(timedelta)), + datetime(2021, 12, 5, 8, 0, tzinfo=timezone(timedelta)), + ] + } + + +def test_to_date_non_exact_strptime() -> None: + s = pl.Series("a", ["2022-01-16", "2022-01-17", "foo2022-01-18", "b2022-01-19ar"]) + format = "%Y-%m-%d" + + result = s.str.to_date(format, strict=False, exact=True) + expected = pl.Series("a", [date(2022, 1, 16), date(2022, 1, 17), None, None]) + assert_series_equal(result, expected) + + result = s.str.to_date(format, strict=False, exact=False) + expected = pl.Series( + "a", + [date(2022, 1, 16), date(2022, 1, 17), date(2022, 1, 18), date(2022, 1, 19)], + ) + assert_series_equal(result, expected) + + with pytest.raises(InvalidOperationError): + s.str.to_date(format, strict=True, exact=True) + + +@pytest.mark.parametrize( + ("time_string", "expected"), + [ + ("01-02-2024", date(2024, 2, 1)), + ("01.02.2024", date(2024, 2, 1)), + ("01/02/2024", date(2024, 2, 1)), + ("2024-02-01", date(2024, 2, 1)), + ("2024/02/01", date(2024, 2, 1)), + ("31-12-2024", date(2024, 12, 31)), + ("31.12.2024", date(2024, 12, 31)), + ("31/12/2024", date(2024, 12, 31)), + ("2024-12-31", date(2024, 12, 31)), + ("2024/12/31", date(2024, 12, 31)), + ], +) +def test_to_date_all_inferred_date_patterns(time_string: str, expected: date) -> None: + result = pl.Series([time_string]).str.to_date() + assert result[0] == expected + + +@pytest.mark.parametrize( + ("time_string", "expected"), + [ + ("2024-12-04 09:08:00", datetime(2024, 12, 4, 9, 8, 0)), + ("2024-12-4 9:8:0", datetime(2024, 12, 4, 9, 8, 0)), + ("2024/12/04 9:8", datetime(2024, 12, 4, 9, 8, 0)), + ("4/12/2024 9:8", datetime(2024, 12, 4, 9, 8, 0)), + ], +) +def test_to_datetime_infer_missing_digit_in_time_16092( + time_string: str, expected: datetime +) -> None: + result = pl.Series([time_string]).str.to_datetime() + assert result[0] == expected + + +@pytest.mark.parametrize( + ("value", "attr"), + [ + ("a", "to_date"), + ("ab", "to_date"), + ("a", "to_datetime"), + ("ab", "to_datetime"), + ], +) +def test_non_exact_short_elements_10223(value: str, attr: str) -> None: + with pytest.raises((InvalidOperationError, ComputeError)): + getattr(pl.Series(["2019-01-01", value]).str, attr)(exact=False) + + +@pytest.mark.parametrize( + ("offset", "time_zone", "tzinfo", "format"), + [ + ("+01:00", "UTC", timezone(timedelta(hours=1)), "%Y-%m-%dT%H:%M%z"), + ("", None, None, "%Y-%m-%dT%H:%M"), + ], +) +def test_to_datetime_non_exact_strptime( + offset: str, time_zone: str | None, tzinfo: timezone | None, format: str +) -> None: + s = pl.Series( + "a", + [ + f"2022-01-16T00:00{offset}", + f"2022-01-17T00:00{offset}", + f"foo2022-01-18T00:00{offset}", + f"b2022-01-19T00:00{offset}ar", + ], + ) + + result = s.str.to_datetime(format, strict=False, exact=True) + expected = pl.Series( + "a", + [ + datetime(2022, 1, 16, tzinfo=tzinfo), + datetime(2022, 1, 17, tzinfo=tzinfo), + None, + None, + ], + ) + assert_series_equal(result, expected) + assert result.dtype == pl.Datetime("us", time_zone) + + result = s.str.to_datetime(format, strict=False, exact=False) + expected = pl.Series( + "a", + [ + datetime(2022, 1, 16, tzinfo=tzinfo), + datetime(2022, 1, 17, tzinfo=tzinfo), + datetime(2022, 1, 18, tzinfo=tzinfo), + datetime(2022, 1, 19, tzinfo=tzinfo), + ], + ) + assert_series_equal(result, expected) + assert result.dtype == pl.Datetime("us", time_zone) + + with pytest.raises(InvalidOperationError): + s.str.to_datetime(format, strict=True, exact=True) + + +def test_to_datetime_dates_datetimes() -> None: + s = pl.Series("date", ["2021-04-22", "2022-01-04 00:00:00"]) + assert s.str.to_datetime().to_list() == [ + datetime(2021, 4, 22, 0, 0), + datetime(2022, 1, 4, 0, 0), + ] + + +@pytest.mark.parametrize( + ("time_string", "expected"), + [ + ("09-05-2019", datetime(2019, 5, 9)), + ("2018-09-05", datetime(2018, 9, 5)), + ("2018-09-05T04:05:01", datetime(2018, 9, 5, 4, 5, 1)), + ("2018-09-05T04:24:01.9", datetime(2018, 9, 5, 4, 24, 1, 900000)), + ("2018-09-05T04:24:02.11", datetime(2018, 9, 5, 4, 24, 2, 110000)), + ("2018-09-05T14:24:02.123", datetime(2018, 9, 5, 14, 24, 2, 123000)), + ("2019-04-18T02:45:55.555000000", datetime(2019, 4, 18, 2, 45, 55, 555000)), + ("2019-04-18T22:45:55.555123", datetime(2019, 4, 18, 22, 45, 55, 555123)), + ( + "2018-09-05T04:05:01+01:00", + datetime(2018, 9, 5, 4, 5, 1, tzinfo=timezone(timedelta(hours=1))), + ), + ( + "2018-09-05T04:24:01.9+01:00", + datetime(2018, 9, 5, 4, 24, 1, 900000, tzinfo=timezone(timedelta(hours=1))), + ), + ( + "2018-09-05T04:24:02.11+01:00", + datetime(2018, 9, 5, 4, 24, 2, 110000, tzinfo=timezone(timedelta(hours=1))), + ), + ( + "2018-09-05T14:24:02.123+01:00", + datetime( + 2018, 9, 5, 14, 24, 2, 123000, tzinfo=timezone(timedelta(hours=1)) + ), + ), + ( + "2019-04-18T02:45:55.555000000+01:00", + datetime( + 2019, 4, 18, 2, 45, 55, 555000, tzinfo=timezone(timedelta(hours=1)) + ), + ), + ( + "2019-04-18T22:45:55.555123+01:00", + datetime( + 2019, 4, 18, 22, 45, 55, 555123, tzinfo=timezone(timedelta(hours=1)) + ), + ), + ], +) +def test_to_datetime_patterns_single(time_string: str, expected: str) -> None: + result = pl.Series([time_string]).str.to_datetime().item() + assert result == expected + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_infer_tz_aware_time_unit(time_unit: TimeUnit) -> None: + result = pl.Series(["2020-01-02T04:00:00+02:00"]).str.to_datetime( + time_unit=time_unit + ) + assert result.dtype == pl.Datetime(time_unit, "UTC") + assert result.item() == datetime(2020, 1, 2, 2, 0, tzinfo=timezone.utc) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_infer_tz_aware_with_utc(time_unit: TimeUnit) -> None: + result = pl.Series(["2020-01-02T04:00:00+02:00"]).str.to_datetime( + time_unit=time_unit + ) + assert result.dtype == pl.Datetime(time_unit, "UTC") + assert result.item() == datetime(2020, 1, 2, 2, 0, tzinfo=timezone.utc) + + +def test_str_to_datetime_infer_tz_aware() -> None: + result = ( + pl.Series(["2020-01-02T04:00:00+02:00"]) + .str.to_datetime(time_unit="us", time_zone="Europe/Vienna") + .item() + ) + assert result == datetime(2020, 1, 2, 3, tzinfo=ZoneInfo("Europe/Vienna")) + + +@pytest.mark.parametrize( + "result", + [ + pl.Series(["2020-01-01T00:00:00+00:00"]).str.strptime( + pl.Datetime("us", "UTC"), format="%Y-%m-%dT%H:%M:%S%z" + ), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.strptime( + pl.Datetime("us"), format="%Y-%m-%dT%H:%M:%S%z" + ), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.strptime(pl.Datetime("us", "UTC")), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.strptime(pl.Datetime("us")), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.to_datetime( + time_zone="UTC", format="%Y-%m-%dT%H:%M:%S%z" + ), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.to_datetime( + format="%Y-%m-%dT%H:%M:%S%z" + ), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.to_datetime(time_zone="UTC"), + pl.Series(["2020-01-01T00:00:00+00:00"]).str.to_datetime(), + ], +) +def test_parsing_offset_aware_with_utc_dtype(result: pl.Series) -> None: + expected = pl.Series([datetime(2020, 1, 1, tzinfo=timezone.utc)]) + assert_series_equal(result, expected) + + +def test_datetime_strptime_patterns_consistent() -> None: + # note that all should be year first + df = pl.Series( + "date", + [ + "2018-09-05", + "2018-09-05T04:05:01", + "2018-09-05T04:24:01.9", + "2018-09-05T04:24:02.11", + "2018-09-05T14:24:02.123", + "2018-09-05T14:24:02.123Z", + "2019-04-18T02:45:55.555000000", + "2019-04-18T22:45:55.555123", + ], + ).to_frame() + s = df.with_columns( + pl.col("date").str.to_datetime(strict=False).alias("parsed"), + )["parsed"] + assert s.null_count() == 1 + assert s[5] is None + + +def test_datetime_strptime_patterns_inconsistent() -> None: + # note that the pattern is inferred from the first element to + # be DatetimeDMY, and so the others (correctly) parse as `null`. + df = pl.Series( + "date", + [ + "09-05-2019", + "2018-09-05", + "2018-09-05T04:05:01", + "2018-09-05T04:24:01.9", + "2018-09-05T04:24:02.11", + "2018-09-05T14:24:02.123", + "2018-09-05T14:24:02.123Z", + "2019-04-18T02:45:55.555000000", + "2019-04-18T22:45:55.555123", + ], + ).to_frame() + s = df.with_columns(pl.col("date").str.to_datetime(strict=False).alias("parsed"))[ + "parsed" + ] + assert s.null_count() == 8 + assert s[0] is not None + + +@pytest.mark.parametrize( + ( + "ts", + "format", + "exp_year", + "exp_month", + "exp_day", + "exp_hour", + "exp_minute", + "exp_second", + ), + [ + ("-0031-04-24 22:13:20", "%Y-%m-%d %H:%M:%S", -31, 4, 24, 22, 13, 20), + ("-0031-04-24", "%Y-%m-%d", -31, 4, 24, 0, 0, 0), + ], +) +def test_parse_negative_dates( + ts: str, + format: str, + exp_year: int, + exp_month: int, + exp_day: int, + exp_hour: int, + exp_minute: int, + exp_second: int, +) -> None: + s = pl.Series([ts]) + result = s.str.to_datetime(format, time_unit="ms") + # Python datetime.datetime doesn't support negative dates, so comparing + # with `result.item()` directly won't work. + assert result.dt.year().item() == exp_year + assert result.dt.month().item() == exp_month + assert result.dt.day().item() == exp_day + assert result.dt.hour().item() == exp_hour + assert result.dt.minute().item() == exp_minute + assert result.dt.second().item() == exp_second + + +def test_short_formats() -> None: + s = pl.Series(["20202020", "2020"]) + assert s.str.to_date("%Y", strict=False).to_list() == [ + None, + date(2020, 1, 1), + ] + assert s.str.to_date("%bar", strict=False).to_list() == [None, None] + + +@pytest.mark.parametrize( + ("time_string", "fmt", "datatype", "expected"), + [ + ("Jul/2020", "%b/%Y", pl.Date, date(2020, 7, 1)), + ("Jan/2020", "%b/%Y", pl.Date, date(2020, 1, 1)), + ("02/Apr/2020", "%d/%b/%Y", pl.Date, date(2020, 4, 2)), + ("Dec/2020", "%b/%Y", pl.Datetime, datetime(2020, 12, 1, 0, 0)), + ("Nov/2020", "%b/%Y", pl.Datetime, datetime(2020, 11, 1, 0, 0)), + ("02/Feb/2020", "%d/%b/%Y", pl.Datetime, datetime(2020, 2, 2, 0, 0)), + ], +) +def test_strptime_abbrev_month( + time_string: str, fmt: str, datatype: PolarsTemporalType, expected: date +) -> None: + s = pl.Series([time_string]) + result = s.str.strptime(datatype, fmt).item() + assert result == expected + + +def test_full_month_name() -> None: + s = pl.Series(["2022-December-01"]).str.to_datetime("%Y-%B-%d") + assert s[0] == datetime(2022, 12, 1) + + +@pytest.mark.parametrize( + ("datatype", "expected"), + [ + (pl.Datetime, datetime(2022, 1, 1)), + (pl.Date, date(2022, 1, 1)), + ], +) +def test_single_digit_month( + datatype: PolarsTemporalType, expected: datetime | date +) -> None: + s = pl.Series(["2022-1-1"]).str.strptime(datatype, "%Y-%m-%d") + assert s[0] == expected + + +def test_invalid_date_parsing_4898() -> None: + assert pl.Series(["2022-09-18", "2022-09-50"]).str.to_date( + "%Y-%m-%d", strict=False + ).to_list() == [date(2022, 9, 18), None] + + +def test_strptime_invalid_timezone() -> None: + ts = pl.Series(["2020-01-01 00:00:00+01:00"]).str.to_datetime("%Y-%m-%d %H:%M:%S%z") + with pytest.raises(ComputeError, match=r"unable to parse time zone: 'foo'"): + ts.dt.replace_time_zone("foo") + + +def test_to_datetime_ambiguous_or_non_existent() -> None: + with pytest.raises( + ComputeError, + match="datetime '2021-11-07 01:00:00' is ambiguous in time zone 'US/Central'", + ): + pl.Series(["2021-11-07 01:00"]).str.to_datetime( + time_unit="us", time_zone="US/Central" + ) + with pytest.raises( + ComputeError, + match="datetime '2021-03-28 02:30:00' is non-existent in time zone 'Europe/Warsaw'", + ): + pl.Series(["2021-03-28 02:30"]).str.to_datetime( + time_unit="us", time_zone="Europe/Warsaw" + ) + with pytest.raises( + ComputeError, + match="datetime '2021-03-28 02:30:00' is non-existent in time zone 'Europe/Warsaw'", + ): + pl.Series(["2021-03-28 02:30"]).str.to_datetime( + time_unit="us", + time_zone="Europe/Warsaw", + ambiguous="null", + ) + with pytest.raises( + ComputeError, + match="datetime '2021-03-28 02:30:00' is non-existent in time zone 'Europe/Warsaw'", + ): + pl.Series(["2021-03-28 02:30"] * 2).str.to_datetime( + time_unit="us", + time_zone="Europe/Warsaw", + ambiguous=pl.Series(["null", "null"]), + ) + + +@pytest.mark.parametrize( + ("ts", "fmt", "expected"), + [ + ("2020-01-01T00:00:00Z", None, datetime(2020, 1, 1, tzinfo=timezone.utc)), + ("2020-01-01T00:00:00Z", "%+", datetime(2020, 1, 1, tzinfo=timezone.utc)), + ( + "2020-01-01T00:00:00+01:00", + "%Y-%m-%dT%H:%M:%S%z", + datetime(2020, 1, 1, tzinfo=timezone(timedelta(seconds=3600))), + ), + ( + "2020-01-01T00:00:00+01:00", + "%Y-%m-%dT%H:%M:%S%:z", + datetime(2020, 1, 1, tzinfo=timezone(timedelta(seconds=3600))), + ), + ( + "2020-01-01T00:00:00+01:00", + "%Y-%m-%dT%H:%M:%S%#z", + datetime(2020, 1, 1, tzinfo=timezone(timedelta(seconds=3600))), + ), + ], +) +def test_to_datetime_tz_aware_strptime(ts: str, fmt: str, expected: datetime) -> None: + result = pl.Series([ts]).str.to_datetime(fmt).item() + assert result == expected + + +@pytest.mark.parametrize("format", ["%+", "%Y-%m-%dT%H:%M:%S%z"]) +def test_crossing_dst(format: str) -> None: + ts = ["2021-03-27T23:59:59+01:00", "2021-03-28T23:59:59+02:00"] + result = pl.Series(ts).str.to_datetime(format) + assert result[0] == datetime(2021, 3, 27, 22, 59, 59, tzinfo=ZoneInfo("UTC")) + assert result[1] == datetime(2021, 3, 28, 21, 59, 59, tzinfo=ZoneInfo("UTC")) + + +@pytest.mark.parametrize("format", ["%+", "%Y-%m-%dT%H:%M:%S%z"]) +def test_crossing_dst_tz_aware(format: str) -> None: + ts = ["2021-03-27T23:59:59+01:00", "2021-03-28T23:59:59+02:00"] + result = pl.Series(ts).str.to_datetime(format) + expected = pl.Series( + [ + datetime(2021, 3, 27, 22, 59, 59, tzinfo=timezone.utc), + datetime(2021, 3, 28, 21, 59, 59, tzinfo=timezone.utc), + ] + ) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("data", "format", "expected"), + [ + ( + "2023-02-05T05:10:10.074000", + "%Y-%m-%dT%H:%M:%S%.f", + datetime(2023, 2, 5, 5, 10, 10, 74000), + ), + ], +) +def test_strptime_subseconds_datetime(data: str, format: str, expected: time) -> None: + s = pl.Series([data]) + result = s.str.to_datetime(format).item() + assert result == expected + + +@pytest.mark.parametrize( + ("string", "fmt"), + [ + pytest.param("2023-05-04|7", "%Y-%m-%d|%H", id="hour but no minute"), + pytest.param("2023-05-04|7", "%Y-%m-%d|%k", id="padded hour but no minute"), + pytest.param("2023-05-04|10", "%Y-%m-%d|%M", id="minute but no hour"), + pytest.param("2023-05-04|10", "%Y-%m-%d|%S", id="second but no hour"), + pytest.param( + "2000-Jan-01 01 00 01", "%Y-%b-%d %I %M %S", id="12-hour clock but no AM/PM" + ), + pytest.param( + "2000-Jan-01 01 00 01", + "%Y-%b-%d %l %M %S", + id="padded 12-hour clock but no AM/PM", + ), + ], +) +def test_strptime_incomplete_formats(string: str, fmt: str) -> None: + with pytest.raises( + ComputeError, + match="Invalid format string", + ): + pl.Series([string]).str.to_datetime(fmt) + + +@pytest.mark.parametrize( + ("string", "fmt", "expected"), + [ + ("2023-05-04|7:3", "%Y-%m-%d|%H:%M", datetime(2023, 5, 4, 7, 3)), + ("2023-05-04|10:03", "%Y-%m-%d|%H:%M", datetime(2023, 5, 4, 10, 3)), + ( + "2000-Jan-01 01 00 01 am", + "%Y-%b-%d %I %M %S %P", + datetime(2000, 1, 1, 1, 0, 1), + ), + ( + "2000-Jan-01 01 00 01 am", + "%Y-%b-%d %_I %M %S %P", + datetime(2000, 1, 1, 1, 0, 1), + ), + ( + "2000-Jan-01 01 00 01 am", + "%Y-%b-%d %l %M %S %P", + datetime(2000, 1, 1, 1, 0, 1), + ), + ( + "2000-Jan-01 01 00 01 AM", + "%Y-%b-%d %I %M %S %p", + datetime(2000, 1, 1, 1, 0, 1), + ), + ( + "2000-Jan-01 01 00 01 AM", + "%Y-%b-%d %_I %M %S %p", + datetime(2000, 1, 1, 1, 0, 1), + ), + ( + "2000-Jan-01 01 00 01 AM", + "%Y-%b-%d %l %M %S %p", + datetime(2000, 1, 1, 1, 0, 1), + ), + ], +) +def test_strptime_complete_formats(string: str, fmt: str, expected: datetime) -> None: + # Similar to the above, but these formats are complete and should work + result = pl.Series([string]).str.to_datetime(fmt).item() + assert result == expected + + +@pytest.mark.parametrize( + ("data", "format", "expected"), + [ + ("05:10:10.074000", "%H:%M:%S%.f", time(5, 10, 10, 74000)), + ("05:10:10.074000", "%T%.6f", time(5, 10, 10, 74000)), + ("05:10:10.074000", "%H:%M:%S%.3f", time(5, 10, 10, 74000)), + ], +) +def test_to_time_subseconds(data: str, format: str, expected: time) -> None: + s = pl.Series([data]) + for res in ( + s.str.to_time().item(), + s.str.to_time(format).item(), + ): + assert res == expected + + +def test_to_time_format_warning() -> None: + s = pl.Series(["05:10:10.074000"]) + with pytest.warns(ChronoFormatWarning, match=".%f"): + result = s.str.to_time("%H:%M:%S.%f").item() + assert result == time(5, 10, 10, 74) + + +@pytest.mark.parametrize("exact", [True, False]) +def test_to_datetime_ambiguous_earliest(exact: bool) -> None: + result = ( + pl.Series(["2020-10-25 01:00"]) + .str.to_datetime(time_zone="Europe/London", ambiguous="earliest", exact=exact) + .item() + ) + expected = datetime(2020, 10, 25, 1, fold=0, tzinfo=ZoneInfo("Europe/London")) + assert result == expected + result = ( + pl.Series(["2020-10-25 01:00"]) + .str.to_datetime(time_zone="Europe/London", ambiguous="latest", exact=exact) + .item() + ) + expected = datetime(2020, 10, 25, 1, fold=1, tzinfo=ZoneInfo("Europe/London")) + assert result == expected + with pytest.raises(ComputeError): + pl.Series(["2020-10-25 01:00"]).str.to_datetime( + time_zone="Europe/London", + exact=exact, + ).item() + + +def test_to_datetime_naive_format_and_time_zone() -> None: + # format-specified path + result = pl.Series(["2020-01-01"]).str.to_datetime( + format="%Y-%m-%d", time_zone="Asia/Kathmandu" + ) + expected = pl.Series([datetime(2020, 1, 1)]).dt.replace_time_zone("Asia/Kathmandu") + assert_series_equal(result, expected) + # format-inferred path + result = pl.Series(["2020-01-01"]).str.to_datetime(time_zone="Asia/Kathmandu") + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("exact", [True, False]) +def test_strptime_ambiguous_earliest(exact: bool) -> None: + result = ( + pl.Series(["2020-10-25 01:00"]) + .str.strptime( + pl.Datetime("us", "Europe/London"), ambiguous="earliest", exact=exact + ) + .item() + ) + expected = datetime(2020, 10, 25, 1, fold=0, tzinfo=ZoneInfo("Europe/London")) + assert result == expected + result = ( + pl.Series(["2020-10-25 01:00"]) + .str.strptime( + pl.Datetime("us", "Europe/London"), ambiguous="latest", exact=exact + ) + .item() + ) + expected = datetime(2020, 10, 25, 1, fold=1, tzinfo=ZoneInfo("Europe/London")) + assert result == expected + with pytest.raises(ComputeError): + pl.Series(["2020-10-25 01:00"]).str.strptime( + pl.Datetime("us", "Europe/London"), + exact=exact, + ).item() + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_to_datetime_out_of_range_13401(time_unit: TimeUnit) -> None: + s = pl.Series(["2020-January-01 12:34:66"]) + with pytest.raises(InvalidOperationError, match="conversion .* failed"): + s.str.to_datetime("%Y-%B-%d %H:%M:%S", time_unit=time_unit) + assert ( + s.str.to_datetime("%Y-%B-%d %H:%M:%S", strict=False, time_unit=time_unit).item() + is None + ) + + +def test_out_of_ns_range_no_tu_specified_13592() -> None: + df = pl.DataFrame({"dates": ["2022-08-31 00:00:00.0", "0920-09-18 00:00:00.0"]}) + result = df.select(pl.col("dates").str.to_datetime(format="%Y-%m-%d %H:%M:%S%.f"))[ + "dates" + ] + expected = pl.Series( + "dates", + [datetime(2022, 8, 31, 0, 0), datetime(920, 9, 18, 0, 0)], + dtype=pl.Datetime("us"), + ) + assert_series_equal(result, expected) + + +def test_wrong_format_percent() -> None: + with pytest.raises(InvalidOperationError): + pl.Series(["2019-01-01"]).str.strptime(pl.Date, format="d%") + + +def test_polars_parser_fooled_by_trailing_nonsense_22167() -> None: + with pytest.raises(InvalidOperationError): + pl.Series(["2025-04-06T18:57:42.77756192Z"]).str.to_datetime( + "%Y-%m-%dT%H:%M:%S.%9fcabbagebananapotato" + ) + with pytest.raises(InvalidOperationError): + pl.Series(["2025-04-06T18:57:42.77756192Z"]).str.to_datetime( + "%Y-%m-%dT%H:%M:%S.%9f#z" + ) + with pytest.raises(InvalidOperationError): + pl.Series(["2025-04-06T18:57:42.77Z"]).str.to_datetime( + "%Y-%m-%dT%H:%M:%S.%3f#z" + ) + with pytest.raises(InvalidOperationError): + pl.Series(["2025-04-06T18:57:42.77123Z"]).str.to_datetime( + "%Y-%m-%dT%H:%M:%S.%6f#z" + ) + + +def test_strptime_empty_input_22214() -> None: + s = pl.Series("x", [], pl.String) + + assert s.str.strptime(pl.Time, "%H:%M:%S%.f").is_empty() + assert s.str.strptime(pl.Date, "%Y-%m-%d").is_empty() + assert s.str.strptime(pl.Datetime, "%Y-%m-%d %H:%M%#z").is_empty() diff --git a/py-polars/tests/unit/operations/namespaces/test_struct.py b/py-polars/tests/unit/operations/namespaces/test_struct.py new file mode 100644 index 000000000000..e418e3c84a2d --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/test_struct.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import datetime +from collections import OrderedDict + +import pytest + +import polars as pl +from polars.exceptions import ( + OutOfBoundsError, +) +from polars.testing import assert_frame_equal + + +def test_struct_various() -> None: + df = pl.DataFrame( + {"int": [1, 2], "str": ["a", "b"], "bool": [True, None], "list": [[1, 2], [3]]} + ) + s = df.to_struct("my_struct") + + assert s.struct.fields == ["int", "str", "bool", "list"] + assert s[0] == {"int": 1, "str": "a", "bool": True, "list": [1, 2]} + assert s[1] == {"int": 2, "str": "b", "bool": None, "list": [3]} + assert s.struct.field("list").to_list() == [[1, 2], [3]] + assert s.struct.field("int").to_list() == [1, 2] + assert s.struct["list"].to_list() == [[1, 2], [3]] + assert s.struct["int"].to_list() == [1, 2] + + for s, expected_name in ( + (df.to_struct(), ""), + (df.to_struct("my_struct"), "my_struct"), + ): + assert s.name == expected_name + assert_frame_equal(s.struct.unnest(), df) + assert s.struct._ipython_key_completions_() == s.struct.fields + + +def test_rename_fields() -> None: + df = pl.DataFrame({"int": [1, 2], "str": ["a", "b"], "bool": [True, None]}) + s = df.to_struct("my_struct").struct.rename_fields(["a", "b"]) + assert s.struct.fields == ["a", "b"] + + +def test_struct_json_encode() -> None: + assert pl.DataFrame( + {"a": [{"a": [1, 2], "b": [45]}, {"a": [9, 1, 3], "b": None}]} + ).with_columns(pl.col("a").struct.json_encode().alias("encoded")).to_dict( + as_series=False + ) == { + "a": [{"a": [1, 2], "b": [45]}, {"a": [9, 1, 3], "b": None}], + "encoded": ['{"a":[1,2],"b":[45]}', '{"a":[9,1,3],"b":null}'], + } + + +def test_struct_json_encode_logical_type() -> None: + df = pl.DataFrame( + { + "a": [ + { + "a": [datetime.date(1997, 1, 1)], + "b": [datetime.datetime(2000, 1, 29, 10, 30)], + "c": [datetime.timedelta(1, 25)], + } + ] + } + ).select(pl.col("a").struct.json_encode().alias("encoded")) + assert df.to_dict(as_series=False) == { + "encoded": ['{"a":["1997-01-01"],"b":["2000-01-29 10:30:00"],"c":["PT86425S"]}'] + } + + +def test_map_fields() -> None: + df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + assert df.schema == OrderedDict([("x", pl.Struct({"a": pl.Int64, "b": pl.Int64}))]) + df = df.select(pl.col("x").name.map_fields(lambda x: x.upper())) + assert df.schema == OrderedDict([("x", pl.Struct({"A": pl.Int64, "B": pl.Int64}))]) + + +def test_prefix_suffix_fields() -> None: + df = pl.DataFrame({"x": {"a": 1, "b": 2}}) + + prefix_df = df.select(pl.col("x").name.prefix_fields("p_")) + assert prefix_df.schema == OrderedDict( + [("x", pl.Struct({"p_a": pl.Int64, "p_b": pl.Int64}))] + ) + + suffix_df = df.select(pl.col("x").name.suffix_fields("_f")) + assert suffix_df.schema == OrderedDict( + [("x", pl.Struct({"a_f": pl.Int64, "b_f": pl.Int64}))] + ) + + +def test_struct_alias_prune_15401() -> None: + df = pl.DataFrame({"a": []}, schema={"a": pl.Struct({"b": pl.Int8})}) + assert df.select(pl.col("a").alias("c").struct.field("b")).columns == ["b"] + + +def test_empty_list_eval_schema_5734() -> None: + df = pl.DataFrame({"a": [[{"b": 1, "c": 2}]]}) + assert df.filter(False).select( + pl.col("a").list.eval(pl.element().struct.field("b")) + ).schema == {"a": pl.List(pl.Int64)} + + +def test_field_by_index_18732() -> None: + df = pl.DataFrame({"foo": [{"a": 1, "b": 2}, {"a": 2, "b": 1}]}) + + # illegal upper bound + with pytest.raises(OutOfBoundsError, match=r"index 2 for length: 2"): + df.filter(pl.col.foo.struct[2] == 1) + + # legal + expected_df = pl.DataFrame({"foo": [{"a": 1, "b": 2}]}) + result_df = df.filter(pl.col.foo.struct[0] == 1) + assert_frame_equal(expected_df, result_df) + + expected_df = pl.DataFrame({"foo": [{"a": 2, "b": 1}]}) + result_df = df.filter(pl.col.foo.struct[-1] == 1) + assert_frame_equal(expected_df, result_df) diff --git a/py-polars/tests/unit/operations/rolling/__init__.py b/py-polars/tests/unit/operations/rolling/__init__.py new file mode 100644 index 000000000000..6cb7a378fc9a --- /dev/null +++ b/py-polars/tests/unit/operations/rolling/__init__.py @@ -0,0 +1 @@ +"""Test module for rolling window operations.""" diff --git a/py-polars/tests/unit/operations/rolling/test_map.py b/py-polars/tests/unit/operations/rolling/test_map.py new file mode 100644 index 000000000000..ff2620788fc5 --- /dev/null +++ b/py-polars/tests/unit/operations/rolling/test_map.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + ("input", "output"), + [ + ([1, 5], [1, 6]), + ([1], [1]), + ], +) +def test_rolling_map_window_size_9160(input: list[int], output: list[int]) -> None: + s = pl.Series(input) + result = s.rolling_map(lambda x: sum(x), window_size=2, min_samples=1) + expected = pl.Series(output) + assert_series_equal(result, expected) + + +def testing_rolling_map_window_size_with_nulls() -> None: + s = pl.Series([0, 1, None, 3, 4, 5]) + result = s.rolling_map(lambda x: sum(x), window_size=3, min_samples=3) + expected = pl.Series([None, None, None, None, None, 12]) + assert_series_equal(result, expected) + + +def test_rolling_map_clear_reuse_series_state_10681() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 1, 1, 2, 2, 2, 2], + "b": [0.0, 1.0, 11.0, 7.0, 4.0, 2.0, 3.0, 8.0], + } + ) + + result = df.select( + pl.col("b") + .rolling_map(lambda s: s.min(), window_size=3, min_samples=2) + .over("a") + .alias("min") + ) + + expected = pl.Series("min", [None, 0.0, 0.0, 1.0, None, 2.0, 2.0, 2.0]) + assert_series_equal(result.to_series(), expected) + + +def test_rolling_map_np_nansum() -> None: + s = pl.Series("a", [11.0, 2.0, 9.0, float("nan"), 8.0]) + + result = s.rolling_map(np.nansum, 3) + + expected = pl.Series("a", [None, None, 22.0, 11.0, 17.0]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) +def test_rolling_map_std(dtype: PolarsDataType) -> None: + s = pl.Series("A", [1.0, 2.0, 9.0, 2.0, 13.0], dtype=dtype) + result = s.rolling_map(function=lambda s: s.std(), window_size=3) + + expected = pl.Series("A", [None, None, 4.358899, 4.041452, 5.567764], dtype=dtype) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) +def test_rolling_map_std_weights(dtype: PolarsDataType) -> None: + s = pl.Series("A", [1.0, 2.0, 9.0, 2.0, 13.0], dtype=dtype) + + result = s.rolling_map( + function=lambda s: s.std(), window_size=3, weights=[1.0, 2.0, 3.0] + ) + + expected = pl.Series("A", [None, None, 14.224392, 8.326664, 18.929694], dtype=dtype) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_rolling_map_sum_int(dtype: PolarsDataType) -> None: + s = pl.Series("A", [1, 2, 9, 2, 13], dtype=dtype) + + result = s.rolling_map(function=lambda s: s.sum(), window_size=3) + + expected = pl.Series("A", [None, None, 12, 13, 24], dtype=dtype) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_rolling_map_sum_int_cast_to_float(dtype: PolarsDataType) -> None: + s = pl.Series("A", [1, 2, 9, None, 13], dtype=dtype) + + result = s.rolling_map( + function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0] + ) + + expected = pl.Series("A", [None, None, 32.0, None, None], dtype=pl.Float64) + assert_series_equal(result, expected) + + +def test_rolling_map_rolling_sum() -> None: + s = pl.Series("A", list(range(5)), dtype=pl.Float64) + + result = s.rolling_map( + function=lambda s: s.sum(), + window_size=3, + weights=[1.0, 2.1, 3.2], + min_samples=2, + center=True, + ) + + expected = s.rolling_sum( + window_size=3, weights=[1.0, 2.1, 3.2], min_samples=2, center=True + ) + assert_series_equal(result, expected) + + +def test_rolling_map_rolling_std() -> None: + s = pl.Series("A", list(range(6)), dtype=pl.Float64) + + result = s.rolling_map( + function=lambda s: s.std(), + window_size=4, + min_samples=3, + center=False, + ) + + expected = s.rolling_std(window_size=4, min_samples=3, center=False) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py new file mode 100644 index 000000000000..5b85ae5bbc95 --- /dev/null +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -0,0 +1,1510 @@ +from __future__ import annotations + +import sys +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import assume, given +from numpy import nan + +import polars as pl +from polars._utils.convert import parse_as_duration_string +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes +from polars.testing.parametric.strategies.dtype import _time_units +from tests.unit.conftest import INTEGER_DTYPES + +if TYPE_CHECKING: + from hypothesis.strategies import SearchStrategy + + from polars._typing import ( + ClosedInterval, + PolarsDataType, + TimeUnit, + ) + + +@pytest.fixture +def example_df() -> pl.DataFrame: + return pl.DataFrame( + { + "dt": [ + datetime(2021, 1, 1), + datetime(2021, 1, 2), + datetime(2021, 1, 4), + datetime(2021, 1, 5), + datetime(2021, 1, 7), + ], + "values": pl.arange(0, 5, eager=True), + } + ) + + +@pytest.mark.parametrize( + "period", + ["1d", "2d", "3d", timedelta(days=1), timedelta(days=2), timedelta(days=3)], +) +@pytest.mark.parametrize("closed", ["left", "right", "none", "both"]) +def test_rolling_kernels_and_rolling( + example_df: pl.DataFrame, period: str | timedelta, closed: ClosedInterval +) -> None: + out1 = example_df.set_sorted("dt").select( + pl.col("dt"), + # this differs from group_by aggregation because the empty window is + # null here + # where the sum aggregation of an empty set is 0 + pl.col("values") + .rolling_sum_by("dt", period, closed=closed) + .fill_null(0) + .alias("sum"), + pl.col("values").rolling_var_by("dt", period, closed=closed).alias("var"), + pl.col("values").rolling_mean_by("dt", period, closed=closed).alias("mean"), + pl.col("values").rolling_std_by("dt", period, closed=closed).alias("std"), + pl.col("values") + .rolling_quantile_by("dt", period, quantile=0.2, closed=closed) + .alias("quantile"), + ) + out2 = ( + example_df.set_sorted("dt") + .rolling("dt", period=period, closed=closed) + .agg( + [ + pl.col("values").sum().alias("sum"), + pl.col("values").var().alias("var"), + pl.col("values").mean().alias("mean"), + pl.col("values").std().alias("std"), + pl.col("values").quantile(quantile=0.2).alias("quantile"), + ] + ) + ) + assert_frame_equal(out1, out2) + + +@pytest.mark.parametrize( + ("offset", "closed", "expected_values"), + [ + pytest.param( + "-1d", + "left", + [[1], [1, 2], [2, 3], [3, 4]], + id="partial lookbehind, left", + ), + pytest.param( + "-1d", + "right", + [[1, 2], [2, 3], [3, 4], [4]], + id="partial lookbehind, right", + ), + pytest.param( + "-1d", + "both", + [[1, 2], [1, 2, 3], [2, 3, 4], [3, 4]], + id="partial lookbehind, both", + ), + pytest.param( + "-1d", + "none", + [[1], [2], [3], [4]], + id="partial lookbehind, none", + ), + pytest.param( + "-2d", + "left", + [[], [1], [1, 2], [2, 3]], + id="full lookbehind, left", + ), + pytest.param( + "-3d", + "left", + [[], [], [1], [1, 2]], + id="full lookbehind, offset > period, left", + ), + pytest.param( + "-3d", + "right", + [[], [1], [1, 2], [2, 3]], + id="full lookbehind, right", + ), + pytest.param( + "-3d", + "both", + [[], [1], [1, 2], [1, 2, 3]], + id="full lookbehind, both", + ), + pytest.param( + "-2d", + "none", + [[], [1], [2], [3]], + id="full lookbehind, none", + ), + pytest.param( + "-3d", + "none", + [[], [], [1], [2]], + id="full lookbehind, offset > period, none", + ), + ], +) +def test_rolling_negative_offset( + offset: str, closed: ClosedInterval, expected_values: list[list[int]] +) -> None: + df = pl.DataFrame( + { + "ts": pl.datetime_range( + datetime(2021, 1, 1), datetime(2021, 1, 4), "1d", eager=True + ), + "value": [1, 2, 3, 4], + } + ) + result = df.rolling("ts", period="2d", offset=offset, closed=closed).agg( + pl.col("value") + ) + expected = pl.DataFrame( + { + "ts": pl.datetime_range( + datetime(2021, 1, 1), datetime(2021, 1, 4), "1d", eager=True + ), + "value": expected_values, + } + ) + assert_frame_equal(result, expected) + + +def test_rolling_skew() -> None: + s = pl.Series([1, 2, 3, 3, 2, 10, 8]) + assert s.rolling_skew(window_size=4, bias=True).to_list() == pytest.approx( + [ + None, + None, + None, + -0.49338220021815865, + 0.0, + 1.097025449363867, + 0.09770939201338157, + ] + ) + + assert s.rolling_skew(window_size=4, bias=False).to_list() == pytest.approx( + [ + None, + None, + None, + -0.8545630383279711, + 0.0, + 1.9001038154942962, + 0.16923763134384154, + ] + ) + + +def test_rolling_kurtosis() -> None: + s = pl.Series([1, 2, 3, 3, 2, 10, 8]) + assert s.rolling_kurtosis(window_size=4, bias=True).to_list() == pytest.approx( + [ + None, + None, + None, + -1.371900826446281, + -1.9999999999999991, + -0.7055324211778693, + -1.7878967572797346, + ] + ) + assert s.rolling_kurtosis( + window_size=4, bias=True, fisher=False + ).to_list() == pytest.approx( + [ + None, + None, + None, + 1.628099173553719, + 1.0000000000000009, + 2.2944675788221307, + 1.2121032427202654, + ] + ) + + +@pytest.mark.parametrize("time_zone", [None, "US/Central"]) +@pytest.mark.parametrize( + ("rolling_fn", "expected_values", "expected_dtype"), + [ + ("rolling_mean_by", [None, 1.0, 2.0, 3.0, 4.0, 5.0], pl.Float64), + ("rolling_sum_by", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_min_by", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_max_by", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_std_by", [None, None, None, None, None, None], pl.Float64), + ("rolling_var_by", [None, None, None, None, None, None], pl.Float64), + ], +) +def test_rolling_crossing_dst( + time_zone: str | None, + rolling_fn: str, + expected_values: list[int | None | float], + expected_dtype: PolarsDataType, +) -> None: + ts = pl.datetime_range( + datetime(2021, 11, 5), datetime(2021, 11, 10), "1d", time_zone="UTC", eager=True + ).dt.replace_time_zone(time_zone) + df = pl.DataFrame({"ts": ts, "value": [1, 2, 3, 4, 5, 6]}) + + result = df.with_columns( + getattr(pl.col("value"), rolling_fn)(by="ts", window_size="1d", closed="left") + ) + + expected = pl.DataFrame( + {"ts": ts, "value": expected_values}, schema_overrides={"value": expected_dtype} + ) + assert_frame_equal(result, expected) + + +def test_rolling_by_invalid() -> None: + df = pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6]}, schema_overrides={"a": pl.Int16} + ).sort("a") + msg = "unsupported data type: i16 for temporal/index column, expected UInt64, UInt32, Int64, Int32, Datetime, Date, Duration, or Time" + with pytest.raises(InvalidOperationError, match=msg): + df.select(pl.col("b").rolling_min_by("a", "2i")) + df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b") + msg = "`window_size` duration may not be a parsed integer" + with pytest.raises(InvalidOperationError, match=msg): + df.select(pl.col("a").rolling_min_by("b", "2i")) + + +def test_rolling_infinity() -> None: + s = pl.Series("col", ["-inf", "5", "5"]).cast(pl.Float64) + s = s.rolling_mean(2) + expected = pl.Series("col", [None, "-inf", "5"]).cast(pl.Float64) + assert_series_equal(s, expected) + + +def test_rolling_by_non_temporal_window_size() -> None: + df = pl.DataFrame( + {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} + ).sort("a", "b") + msg = "`window_size` duration may not be a parsed integer" + with pytest.raises(InvalidOperationError, match=msg): + df.with_columns(pl.col("a").rolling_sum_by("b", "2i", closed="left")) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.UInt8, + pl.Int64, + pl.Float32, + pl.Float64, + pl.Time, + pl.Date, + pl.Datetime("ms"), + pl.Datetime("us"), + pl.Datetime("ns"), + pl.Datetime("ns", "Asia/Kathmandu"), + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + ], +) +def test_rolling_extrema(dtype: PolarsDataType) -> None: + # sorted data and nulls flags trigger different kernels + df = ( + ( + pl.DataFrame( + { + "col1": pl.int_range(0, 7, eager=True), + "col2": pl.int_range(0, 7, eager=True).reverse(), + } + ) + ) + .with_columns( + pl.when(pl.int_range(0, pl.len(), eager=False) < 2) + .then(None) + .otherwise(pl.all()) + .name.suffix("_nulls") + ) + .cast(dtype) + ) + + expected = { + "col1": [None, None, 0, 1, 2, 3, 4], + "col2": [None, None, 4, 3, 2, 1, 0], + "col1_nulls": [None, None, None, None, 2, 3, 4], + "col2_nulls": [None, None, None, None, 2, 1, 0], + } + result = df.select([pl.all().rolling_min(3)]) + assert result.to_dict(as_series=False) == { + k: pl.Series(v, dtype=dtype).to_list() for k, v in expected.items() + } + + expected = { + "col1": [None, None, 2, 3, 4, 5, 6], + "col2": [None, None, 6, 5, 4, 3, 2], + "col1_nulls": [None, None, None, None, 4, 5, 6], + "col2_nulls": [None, None, None, None, 4, 3, 2], + } + result = df.select([pl.all().rolling_max(3)]) + assert result.to_dict(as_series=False) == { + k: pl.Series(v, dtype=dtype).to_list() for k, v in expected.items() + } + + # shuffled data triggers other kernels + df = df.select([pl.all().shuffle(0)]) + expected = { + "col1": [None, None, 0, 0, 1, 2, 2], + "col2": [None, None, 0, 2, 1, 1, 1], + "col1_nulls": [None, None, None, None, None, 2, 2], + "col2_nulls": [None, None, None, None, None, 1, 1], + } + result = df.select([pl.all().rolling_min(3)]) + assert result.to_dict(as_series=False) == { + k: pl.Series(v, dtype=dtype).to_list() for k, v in expected.items() + } + + result = df.select([pl.all().rolling_max(3)]) + expected = { + "col1": [None, None, 6, 4, 5, 5, 5], + "col2": [None, None, 6, 6, 5, 4, 4], + "col1_nulls": [None, None, None, None, None, 5, 5], + "col2_nulls": [None, None, None, None, None, 4, 4], + } + assert result.to_dict(as_series=False) == { + k: pl.Series(v, dtype=dtype).to_list() for k, v in expected.items() + } + + +@pytest.mark.parametrize( + "dtype", + [ + pl.UInt8, + pl.Int64, + pl.Float32, + pl.Float64, + pl.Time, + pl.Date, + pl.Datetime("ms"), + pl.Datetime("us"), + pl.Datetime("ns"), + pl.Datetime("ns", "Asia/Kathmandu"), + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + ], +) +def test_rolling_group_by_extrema(dtype: PolarsDataType) -> None: + # ensure we hit different branches so create + + df = pl.DataFrame( + { + "col1": pl.arange(0, 7, eager=True).reverse(), + } + ).with_columns( + pl.col("col1").reverse().alias("index"), + pl.col("col1").cast(dtype), + ) + + expected = { + "col1_list": pl.Series( + [ + [6], + [6, 5], + [6, 5, 4], + [5, 4, 3], + [4, 3, 2], + [3, 2, 1], + [2, 1, 0], + ], + dtype=pl.List(dtype), + ).to_list(), + "col1_min": pl.Series([6, 5, 4, 3, 2, 1, 0], dtype=dtype).to_list(), + "col1_max": pl.Series([6, 6, 6, 5, 4, 3, 2], dtype=dtype).to_list(), + "col1_first": pl.Series([6, 6, 6, 5, 4, 3, 2], dtype=dtype).to_list(), + "col1_last": pl.Series([6, 5, 4, 3, 2, 1, 0], dtype=dtype).to_list(), + } + result = ( + df.rolling( + index_column="index", + period="3i", + ) + .agg( + [ + pl.col("col1").name.suffix("_list"), + pl.col("col1").min().name.suffix("_min"), + pl.col("col1").max().name.suffix("_max"), + pl.col("col1").first().alias("col1_first"), + pl.col("col1").last().alias("col1_last"), + ] + ) + .select(["col1_list", "col1_min", "col1_max", "col1_first", "col1_last"]) + ) + assert result.to_dict(as_series=False) == expected + + # ascending order + + df = pl.DataFrame( + { + "col1": pl.arange(0, 7, eager=True), + } + ).with_columns( + pl.col("col1").alias("index"), + pl.col("col1").cast(dtype), + ) + + result = ( + df.rolling( + index_column="index", + period="3i", + ) + .agg( + [ + pl.col("col1").name.suffix("_list"), + pl.col("col1").min().name.suffix("_min"), + pl.col("col1").max().name.suffix("_max"), + pl.col("col1").first().alias("col1_first"), + pl.col("col1").last().alias("col1_last"), + ] + ) + .select(["col1_list", "col1_min", "col1_max", "col1_first", "col1_last"]) + ) + expected = { + "col1_list": pl.Series( + [ + [0], + [0, 1], + [0, 1, 2], + [1, 2, 3], + [2, 3, 4], + [3, 4, 5], + [4, 5, 6], + ], + dtype=pl.List(dtype), + ).to_list(), + "col1_min": pl.Series([0, 0, 0, 1, 2, 3, 4], dtype=dtype).to_list(), + "col1_max": pl.Series([0, 1, 2, 3, 4, 5, 6], dtype=dtype).to_list(), + "col1_first": pl.Series([0, 0, 0, 1, 2, 3, 4], dtype=dtype).to_list(), + "col1_last": pl.Series([0, 1, 2, 3, 4, 5, 6], dtype=dtype).to_list(), + } + assert result.to_dict(as_series=False) == expected + + # shuffled data. + df = pl.DataFrame( + { + "col1": pl.arange(0, 7, eager=True).shuffle(1), + } + ).with_columns( + pl.col("col1").cast(dtype), + pl.col("col1").sort().alias("index"), + ) + + result = ( + df.rolling( + index_column="index", + period="3i", + ) + .agg( + [ + pl.col("col1").min().name.suffix("_min"), + pl.col("col1").max().name.suffix("_max"), + pl.col("col1").name.suffix("_list"), + ] + ) + .select(["col1_list", "col1_min", "col1_max"]) + ) + expected = { + "col1_list": pl.Series( + [ + [3], + [3, 4], + [3, 4, 5], + [4, 5, 6], + [5, 6, 2], + [6, 2, 1], + [2, 1, 0], + ], + dtype=pl.List(dtype), + ).to_list(), + "col1_min": pl.Series([3, 3, 3, 4, 2, 1, 0], dtype=dtype).to_list(), + "col1_max": pl.Series([3, 4, 5, 6, 6, 6, 2], dtype=dtype).to_list(), + } + assert result.to_dict(as_series=False) == expected + + +def test_rolling_slice_pushdown() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "a", "b"], "c": [1, 3, 5]}).lazy() + df = ( + df.sort("a") + .rolling( + "a", + group_by="b", + period="2i", + ) + .agg([(pl.col("c") - pl.col("c").shift(fill_value=0)).sum().alias("c")]) + ) + assert df.head(2).collect().to_dict(as_series=False) == { + "b": ["a", "a"], + "a": [1, 2], + "c": [1, 3], + } + + +def test_overlapping_groups_4628() -> None: + df = pl.DataFrame( + { + "index": [1, 2, 3, 4, 5, 6], + "val": [10, 20, 40, 70, 110, 160], + } + ) + assert ( + df.rolling(index_column=pl.col("index").set_sorted(), period="3i").agg( + [ + pl.col("val").diff(n=1).alias("val.diff"), + (pl.col("val") - pl.col("val").shift(1)).alias("val - val.shift"), + ] + ) + ).to_dict(as_series=False) == { + "index": [1, 2, 3, 4, 5, 6], + "val.diff": [ + [None], + [None, 10], + [None, 10, 20], + [None, 20, 30], + [None, 30, 40], + [None, 40, 50], + ], + "val - val.shift": [ + [None], + [None, 10], + [None, 10, 20], + [None, 20, 30], + [None, 30, 40], + [None, 40, 50], + ], + } + + +@pytest.mark.skipif(sys.platform == "win32", reason="Minor numerical diff") +def test_rolling_skew_lagging_null_5179() -> None: + s = pl.Series([None, 3, 4, 1, None, None, None, None, 3, None, 5, 4, 7, 2, 1, None]) + result = s.rolling_skew(3, min_samples=1).fill_nan(-1.0) + expected = pl.Series( + [ + None, + -1.0, + 0.0, + -0.3818017741606059, + 0.0, + -1.0, + None, + None, + -1.0, + -1.0, + 0.0, + 0.0, + 0.38180177416060695, + 0.23906314692954517, + 0.6309038567106234, + 0.0, + ] + ) + assert_series_equal(result, expected, check_names=False) + + +def test_rolling_var_numerical_stability_5197() -> None: + s = pl.Series([*[1.2] * 4, *[3.3] * 7]) + res = s.to_frame("a").with_columns(pl.col("a").rolling_var(5))[:, 0].to_list() + assert res[4:] == pytest.approx( + [ + 0.882, + 1.3229999999999997, + 1.3229999999999997, + 0.8819999999999983, + 0.0, + 0.0, + 0.0, + ] + ) + assert res[:4] == [None] * 4 + + +def test_rolling_iter() -> None: + df = pl.DataFrame( + { + "date": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 5)], + "a": [1, 2, 2], + "b": [4, 5, 6], + } + ).set_sorted("date") + + # Without 'by' argument + result1 = [ + (name[0], data.shape) + for name, data in df.rolling(index_column="date", period="2d") + ] + expected1 = [ + (date(2020, 1, 1), (1, 3)), + (date(2020, 1, 2), (2, 3)), + (date(2020, 1, 5), (1, 3)), + ] + assert result1 == expected1 + + # With 'by' argument + result2 = [ + (name, data.shape) + for name, data in df.rolling(index_column="date", period="2d", group_by="a") + ] + expected2 = [ + ((1, date(2020, 1, 1)), (1, 3)), + ((2, date(2020, 1, 2)), (1, 3)), + ((2, date(2020, 1, 5)), (1, 3)), + ] + assert result2 == expected2 + + +def test_rolling_negative_period() -> None: + df = pl.DataFrame({"ts": [datetime(2020, 1, 1)], "value": [1]}).with_columns( + pl.col("ts").set_sorted() + ) + with pytest.raises( + ComputeError, match="rolling window period should be strictly positive" + ): + df.rolling("ts", period="-1d", offset="-1d").agg(pl.col("value")) + with pytest.raises( + ComputeError, match="rolling window period should be strictly positive" + ): + df.lazy().rolling("ts", period="-1d", offset="-1d").agg( + pl.col("value") + ).collect() + with pytest.raises( + InvalidOperationError, match="`window_size` must be strictly positive" + ): + df.select( + pl.col("value").rolling_min_by("ts", window_size="-1d", closed="left") + ) + with pytest.raises( + InvalidOperationError, match="`window_size` must be strictly positive" + ): + df.lazy().select( + pl.col("value").rolling_min_by("ts", window_size="-1d", closed="left") + ).collect() + + +def test_rolling_skew_window_offset() -> None: + assert (pl.arange(0, 20, eager=True) ** 2).rolling_skew(20)[ + -1 + ] == 0.6612545648596286 + + +def test_rolling_cov_corr() -> None: + df = pl.DataFrame({"x": [3, 3, 3, 5, 8], "y": [3, 4, 4, 4, 8]}) + + res = df.select( + pl.rolling_cov("x", "y", window_size=3).alias("cov"), + pl.rolling_corr("x", "y", window_size=3).alias("corr"), + ).to_dict(as_series=False) + assert res["cov"][2:] == pytest.approx([0.0, 0.0, 5.333333333333336]) + assert res["corr"][2:] == pytest.approx([nan, 0.0, 0.9176629354822473], nan_ok=True) + assert res["cov"][:2] == [None] * 2 + assert res["corr"][:2] == [None] * 2 + + +def test_rolling_cov_corr_nulls() -> None: + df1 = pl.DataFrame( + {"a": [1.06, 1.07, 0.93, 0.78, 0.85], "lag_a": [1.0, 1.06, 1.07, 0.93, 0.78]} + ) + df2 = pl.DataFrame( + { + "a": [1.0, 1.06, 1.07, 0.93, 0.78, 0.85], + "lag_a": [None, 1.0, 1.06, 1.07, 0.93, 0.78], + } + ) + + val_1 = df1.select( + pl.rolling_corr("a", "lag_a", window_size=10, min_samples=5, ddof=1) + ) + val_2 = df2.select( + pl.rolling_corr("a", "lag_a", window_size=10, min_samples=5, ddof=1) + ) + + df1_expected = pl.DataFrame({"a": [None, None, None, None, 0.62204709]}) + df2_expected = pl.DataFrame({"a": [None, None, None, None, None, 0.62204709]}) + + assert_frame_equal(val_1, df1_expected, atol=0.0000001) + assert_frame_equal(val_2, df2_expected, atol=0.0000001) + + val_1 = df1.select( + pl.rolling_cov("a", "lag_a", window_size=10, min_samples=5, ddof=1) + ) + val_2 = df2.select( + pl.rolling_cov("a", "lag_a", window_size=10, min_samples=5, ddof=1) + ) + + df1_expected = pl.DataFrame({"a": [None, None, None, None, 0.009445]}) + df2_expected = pl.DataFrame({"a": [None, None, None, None, None, 0.009445]}) + + assert_frame_equal(val_1, df1_expected, atol=0.0000001) + assert_frame_equal(val_2, df2_expected, atol=0.0000001) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_rolling_empty_window_9406(time_unit: TimeUnit) -> None: + datecol = pl.Series( + "d", + [datetime(2019, 1, x) for x in [16, 17, 18, 22, 23]], + dtype=pl.Datetime(time_unit=time_unit, time_zone=None), + ).set_sorted() + rawdata = pl.Series("x", [1.1, 1.2, 1.3, 1.15, 1.25], dtype=pl.Float64) + rmin = pl.Series("x", [None, 1.1, 1.1, None, 1.15], dtype=pl.Float64) + rmax = pl.Series("x", [None, 1.1, 1.2, None, 1.15], dtype=pl.Float64) + df = pl.DataFrame([datecol, rawdata]) + + assert_frame_equal( + pl.DataFrame([datecol, rmax]), + df.select( + pl.col("d"), + pl.col("x").rolling_max_by("d", window_size="3d", closed="left"), + ), + ) + assert_frame_equal( + pl.DataFrame([datecol, rmin]), + df.select( + pl.col("d"), + pl.col("x").rolling_min_by("d", window_size="3d", closed="left"), + ), + ) + + +def test_rolling_weighted_quantile_10031() -> None: + assert_series_equal( + pl.Series([1, 2]).rolling_median(window_size=2, weights=[0, 1]), + pl.Series([None, 2.0]), + ) + + assert_series_equal( + pl.Series([1, 2, 3, 5]).rolling_quantile(0.7, "linear", 3, [0.1, 0.3, 0.6]), + pl.Series([None, None, 2.55, 4.1]), + ) + + assert_series_equal( + pl.Series([1, 2, 3, 5, 8]).rolling_quantile( + 0.7, "linear", 4, [0.1, 0.2, 0, 0.3] + ), + pl.Series([None, None, None, 3.5, 5.5]), + ) + + +def test_rolling_meta_eq_10101() -> None: + assert pl.col("A").rolling_sum(10).meta.eq(pl.col("A").rolling_sum(10)) is True + + +def test_rolling_aggregations_unsorted_raise_10991() -> None: + df = pl.DataFrame( + { + "dt": [datetime(2020, 1, 3), datetime(2020, 1, 1), datetime(2020, 1, 2)], + "val": [1, 2, 3], + } + ) + result = df.with_columns(roll=pl.col("val").rolling_sum_by("dt", "2d")) + expected = pl.DataFrame( + { + "dt": [datetime(2020, 1, 3), datetime(2020, 1, 1), datetime(2020, 1, 2)], + "val": [1, 2, 3], + "roll": [4, 2, 5], + } + ) + assert_frame_equal(result, expected) + result = ( + df.with_row_index() + .sort("dt") + .with_columns(roll=pl.col("val").rolling_sum_by("dt", "2d")) + .sort("index") + .drop("index") + ) + assert_frame_equal(result, expected) + + +def test_rolling_aggregations_with_over_11225() -> None: + start = datetime(2001, 1, 1) + + df_temporal = pl.DataFrame( + { + "date": [start + timedelta(days=k) for k in range(5)], + "group": ["A"] * 2 + ["B"] * 3, + } + ).with_row_index() + + df_temporal = df_temporal.sort("group", "date") + + result = df_temporal.with_columns( + rolling_row_mean=pl.col("index") + .rolling_mean_by( + by="date", + window_size="2d", + closed="left", + ) + .over("group") + ) + expected = pl.DataFrame( + { + "index": [0, 1, 2, 3, 4], + "date": pl.datetime_range(date(2001, 1, 1), date(2001, 1, 5), eager=True), + "group": ["A", "A", "B", "B", "B"], + "rolling_row_mean": [None, 0.0, None, 2.0, 2.5], + }, + schema_overrides={"index": pl.UInt32}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_rolling_ints(dtype: PolarsDataType) -> None: + s = pl.Series("a", [1, 2, 3, 2, 1], dtype=dtype) + assert_series_equal( + s.rolling_min(2), pl.Series("a", [None, 1, 2, 2, 1], dtype=dtype) + ) + assert_series_equal( + s.rolling_max(2), pl.Series("a", [None, 2, 3, 3, 2], dtype=dtype) + ) + assert_series_equal( + s.rolling_sum(2), + pl.Series( + "a", + [None, 3, 5, 5, 3], + dtype=( + pl.Int64 if dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16] else dtype + ), + ), + ) + assert_series_equal(s.rolling_mean(2), pl.Series("a", [None, 1.5, 2.5, 2.5, 1.5])) + + assert s.rolling_std(2).to_list()[1] == pytest.approx(0.7071067811865476) + assert s.rolling_var(2).to_list()[1] == pytest.approx(0.5) + assert s.rolling_std(2, ddof=0).to_list()[1] == pytest.approx(0.5) + assert s.rolling_var(2, ddof=0).to_list()[1] == pytest.approx(0.25) + + assert_series_equal( + s.rolling_median(4), pl.Series("a", [None, None, None, 2, 2], dtype=pl.Float64) + ) + assert_series_equal( + s.rolling_quantile(0, "nearest", 3), + pl.Series("a", [None, None, 1, 2, 1], dtype=pl.Float64), + ) + assert_series_equal( + s.rolling_quantile(0, "lower", 3), + pl.Series("a", [None, None, 1, 2, 1], dtype=pl.Float64), + ) + assert_series_equal( + s.rolling_quantile(0, "higher", 3), + pl.Series("a", [None, None, 1, 2, 1], dtype=pl.Float64), + ) + assert s.rolling_skew(4).null_count() == 3 + + +def test_rolling_floats() -> None: + # 3099 + # test if we maintain proper dtype + for dt in [pl.Float32, pl.Float64]: + result = pl.Series([1, 2, 3], dtype=dt).rolling_min(2, weights=[0.1, 0.2]) + expected = pl.Series([None, 0.1, 0.2], dtype=dt) + assert_series_equal(result, expected) + + df = pl.DataFrame({"val": [1.0, 2.0, 3.0, np.nan, 5.0, 6.0, 7.0]}) + + for e in [ + pl.col("val").rolling_min(window_size=3), + pl.col("val").rolling_max(window_size=3), + ]: + out = df.with_columns(e).to_series() + assert out.null_count() == 2 + assert np.isnan(out.to_numpy()).sum() == 5 + + expected_values = [None, None, 2.0, 3.0, 5.0, 6.0, 6.0] + assert ( + df.with_columns(pl.col("val").rolling_median(window_size=3)) + .to_series() + .to_list() + == expected_values + ) + assert ( + df.with_columns(pl.col("val").rolling_quantile(0.5, window_size=3)) + .to_series() + .to_list() + == expected_values + ) + + nan = float("nan") + s = pl.Series("a", [11.0, 2.0, 9.0, nan, 8.0]) + assert_series_equal( + s.rolling_sum(3), + pl.Series("a", [None, None, 22.0, nan, nan]), + ) + + +def test_rolling_std_nulls_min_samples_1_20076() -> None: + result = pl.Series([1, 2, None, 4]).rolling_std(3, min_samples=1) + expected = pl.Series( + [None, 0.7071067811865476, 0.7071067811865476, 1.4142135623730951] + ) + assert_series_equal(result, expected) + + +def test_rolling_by_date() -> None: + df = pl.DataFrame( + { + "dt": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "val": [1, 2, 3], + } + ).sort("dt") + + result = df.with_columns(roll=pl.col("val").rolling_sum_by("dt", "2d")) + expected = df.with_columns(roll=pl.Series([1, 3, 5])) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32]) +def test_rolling_by_integer(dtype: PolarsDataType) -> None: + df = ( + pl.DataFrame({"val": [1, 2, 3]}) + .with_row_index() + .with_columns(pl.col("index").cast(dtype)) + ) + result = df.with_columns(roll=pl.col("val").rolling_sum_by("index", "2i")) + expected = df.with_columns(roll=pl.Series([1, 3, 5])) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_rolling_sum_by_integer(dtype: PolarsDataType) -> None: + lf = ( + pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": dtype}) + .with_row_index() + .select(pl.col("a").rolling_sum_by("index", "2i")) + ) + result = lf.collect() + expected_dtype = ( + pl.Int64 if dtype in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16] else dtype + ) + expected = pl.DataFrame({"a": [1, 3, 5]}, schema={"a": expected_dtype}) + assert_frame_equal(result, expected) + assert lf.collect_schema() == expected.schema + + +def test_rolling_nanoseconds_11003() -> None: + df = pl.DataFrame( + { + "dt": [ + "2020-01-01T00:00:00.000000000", + "2020-01-01T00:00:00.000000100", + "2020-01-01T00:00:00.000000200", + ], + "val": [1, 2, 3], + } + ) + df = df.with_columns(pl.col("dt").str.to_datetime(time_unit="ns")).set_sorted("dt") + result = df.with_columns(pl.col("val").rolling_sum_by("dt", "500ns")) + expected = df.with_columns(val=pl.Series([1, 3, 6])) + assert_frame_equal(result, expected) + + +def test_rolling_by_1mo_saturating_12216() -> None: + df = pl.DataFrame( + { + "date": [ + date(2020, 6, 29), + date(2020, 6, 30), + date(2020, 7, 30), + date(2020, 7, 31), + date(2020, 8, 1), + ], + "val": [1, 2, 3, 4, 5], + } + ).set_sorted("date") + result = df.rolling(index_column="date", period="1mo").agg(vals=pl.col("val")) + expected = pl.DataFrame( + { + "date": [ + date(2020, 6, 29), + date(2020, 6, 30), + date(2020, 7, 30), + date(2020, 7, 31), + date(2020, 8, 1), + ], + "vals": [[1], [1, 2], [3], [3, 4], [3, 4, 5]], + } + ) + assert_frame_equal(result, expected) + + # check with `closed='both'` against DuckDB output + result = df.rolling(index_column="date", period="1mo", closed="both").agg( + vals=pl.col("val") + ) + expected = pl.DataFrame( + { + "date": [ + date(2020, 6, 29), + date(2020, 6, 30), + date(2020, 7, 30), + date(2020, 7, 31), + date(2020, 8, 1), + ], + "vals": [[1], [1, 2], [2, 3], [2, 3, 4], [3, 4, 5]], + } + ) + assert_frame_equal(result, expected) + + +def test_index_expr_with_literal() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).sort("a") + out = df.rolling(index_column=(5 * pl.col("a")).set_sorted(), period="2i").agg( + pl.col("b") + ) + expected = pl.DataFrame({"literal": [5, 10, 15], "b": [["a"], ["b"], ["c"]]}) + assert_frame_equal(out, expected) + + +def test_index_expr_output_name_12244() -> None: + df = pl.DataFrame({"A": [1, 2, 3]}) + + out = df.rolling(pl.int_range(0, pl.len()), period="2i").agg("A") + assert out.to_dict(as_series=False) == { + "literal": [0, 1, 2], + "A": [[1], [1, 2], [2, 3]], + } + + +def test_rolling_median() -> None: + for n in range(10, 25): + array = np.random.randint(0, 20, n) + for k in [3, 5, 7]: + a = pl.Series(array) + assert_series_equal( + a.rolling_median(k), pl.from_pandas(a.to_pandas().rolling(k).median()) + ) + + +@pytest.mark.slow +def test_rolling_median_2() -> None: + np.random.seed(12) + n = 1000 + df = pl.DataFrame({"x": np.random.normal(0, 1, n)}) + # this can differ because simd sizes and non-associativity of floats. + assert df.select( + pl.col("x").rolling_median(window_size=10).sum() + ).item() == pytest.approx(5.139429061527812) + assert df.select( + pl.col("x").rolling_median(window_size=100).sum() + ).item() == pytest.approx(26.60506093611384) + + +@pytest.mark.parametrize( + ("dates", "closed", "expected"), + [ + ( + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "right", + [None, 3, 5], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "left", + [None, None, 3], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "both", + [None, 3, 6], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "none", + [None, None, None], + ), + ( + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 4)], + "right", + [None, 3, None], + ), + ( + [date(2020, 1, 1), date(2020, 1, 3), date(2020, 1, 4)], + "right", + [None, None, 5], + ), + ( + [date(2020, 1, 1), date(2020, 1, 3), date(2020, 1, 5)], + "right", + [None, None, None], + ), + ], +) +def test_rolling_min_samples( + dates: list[date], closed: ClosedInterval, expected: list[int] +) -> None: + df = pl.DataFrame({"date": dates, "value": [1, 2, 3]}).sort("date") + result = df.select( + pl.col("value").rolling_sum_by( + "date", window_size="2d", min_samples=2, closed=closed + ) + )["value"] + assert_series_equal(result, pl.Series("value", expected, pl.Int64)) + + # Starting with unsorted data + result = ( + df.sort("date", descending=True) + .with_columns( + pl.col("value").rolling_sum_by( + "date", window_size="2d", min_samples=2, closed=closed + ) + ) + .sort("date")["value"] + ) + assert_series_equal(result, pl.Series("value", expected, pl.Int64)) + + +def test_rolling_returns_scalar_15656() -> None: + df = pl.DataFrame( + { + "a": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "b": [4, 5, 6], + "c": [1, 2, 3], + } + ) + result = df.group_by("c").agg(pl.col("b").rolling_mean_by("a", "2d")).sort("c") + expected = pl.DataFrame({"c": [1, 2, 3], "b": [[4.0], [5.0], [6.0]]}) + assert_frame_equal(result, expected) + + +def test_rolling_invalid() -> None: + df = pl.DataFrame( + { + "values": [1, 4], + "times": [datetime(2020, 1, 3), datetime(2020, 1, 1)], + }, + ) + with pytest.raises( + InvalidOperationError, match="duration may not be a parsed integer" + ): + ( + df.sort("times") + .rolling("times", period="3000i") + .agg(pl.col("values").sum().alias("sum")) + ) + with pytest.raises( + InvalidOperationError, match="duration must be a parsed integer" + ): + ( + df.with_row_index() + .rolling("index", period="3000d") + .agg(pl.col("values").sum().alias("sum")) + ) + + +def test_by_different_length() -> None: + df = pl.DataFrame({"b": [1]}) + with pytest.raises(InvalidOperationError, match="must be the same length"): + df.select( + pl.col("b").rolling_max_by(pl.Series([datetime(2020, 1, 1)] * 2), "1d") + ) + + +def test_incorrect_nulls_16246() -> None: + df = pl.concat( + [ + pl.DataFrame({"a": [datetime(2020, 1, 1)], "b": [1]}), + pl.DataFrame({"a": [datetime(2021, 1, 1)], "b": [1]}), + ], + rechunk=False, + ) + result = df.select(pl.col("b").rolling_max_by("a", "1d")) + expected = pl.DataFrame({"b": [1, 1]}) + assert_frame_equal(result, expected) + + +def test_rolling_with_dst() -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 10, 26, 1), datetime(2020, 10, 26)], "b": [1, 2]} + ).with_columns(pl.col("a").dt.replace_time_zone("Europe/London")) + with pytest.raises(ComputeError, match="is ambiguous"): + df.select(pl.col("b").rolling_sum_by("a", "1d")) + with pytest.raises(ComputeError, match="is ambiguous"): + df.sort("a").select(pl.col("b").rolling_sum_by("a", "1d")) + + +def interval_defs() -> SearchStrategy[ClosedInterval]: + closed: list[ClosedInterval] = ["left", "right", "both", "none"] + return st.sampled_from(closed) + + +@given( + period=st.timedeltas( + min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) + ).map(parse_as_duration_string), + offset=st.timedeltas( + min_value=timedelta(days=-1000), max_value=timedelta(days=1000) + ).map(parse_as_duration_string), + closed=interval_defs(), + data=st.data(), + time_unit=_time_units(), +) +def test_rolling_parametric( + period: str, + offset: str, + closed: ClosedInterval, + data: st.DataObject, + time_unit: TimeUnit, +) -> None: + assume(period != "") + dataframe = data.draw( + dataframes( + [ + column( + "ts", + strategy=st.datetimes( + min_value=datetime(2000, 1, 1), + max_value=datetime(2001, 1, 1), + ), + dtype=pl.Datetime(time_unit), + ), + column( + "value", + strategy=st.integers(min_value=-100, max_value=100), + dtype=pl.Int64, + ), + ], + min_size=1, + ) + ) + df = dataframe.sort("ts") + result = df.rolling("ts", period=period, offset=offset, closed=closed).agg( + pl.col("value") + ) + + expected_dict: dict[str, list[object]] = {"ts": [], "value": []} + for ts, _ in df.iter_rows(): + window = df.filter( + pl.col("ts").is_between( + pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by(offset), + pl.lit(ts, dtype=pl.Datetime(time_unit)) + .dt.offset_by(offset) + .dt.offset_by(period), + closed=closed, + ) + ) + value = window["value"].to_list() + expected_dict["ts"].append(ts) + expected_dict["value"].append(value) + expected = pl.DataFrame(expected_dict).select( + pl.col("ts").cast(pl.Datetime(time_unit)), + pl.col("value").cast(pl.List(pl.Int64)), + ) + assert_frame_equal(result, expected) + + +@given( + window_size=st.timedeltas( + min_value=timedelta(microseconds=0), max_value=timedelta(days=2) + ).map(parse_as_duration_string), + closed=interval_defs(), + data=st.data(), + time_unit=_time_units(), + aggregation=st.sampled_from( + [ + "min", + "max", + "mean", + "sum", + "std", + "var", + "median", + ] + ), +) +def test_rolling_aggs( + window_size: str, + closed: ClosedInterval, + data: st.DataObject, + time_unit: TimeUnit, + aggregation: str, +) -> None: + assume(window_size != "") + + # Testing logic can be faulty when window is more precise than time unit + # https://github.com/pola-rs/polars/issues/11754 + assume(not (time_unit == "ms" and "us" in window_size)) + + dataframe = data.draw( + dataframes( + [ + column( + "ts", + strategy=st.datetimes( + min_value=datetime(2000, 1, 1), + max_value=datetime(2001, 1, 1), + ), + dtype=pl.Datetime(time_unit), + ), + column( + "value", + strategy=st.integers(min_value=-100, max_value=100), + dtype=pl.Int64, + ), + ], + ) + ) + df = dataframe.sort("ts") + func = f"rolling_{aggregation}_by" + result = df.with_columns( + getattr(pl.col("value"), func)("ts", window_size=window_size, closed=closed) + ) + result_from_unsorted = dataframe.with_columns( + getattr(pl.col("value"), func)("ts", window_size=window_size, closed=closed) + ).sort("ts") + + expected_dict: dict[str, list[object]] = {"ts": [], "value": []} + for ts, _ in df.iter_rows(): + window = df.filter( + pl.col("ts").is_between( + pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by( + f"-{window_size}" + ), + pl.lit(ts, dtype=pl.Datetime(time_unit)), + closed=closed, + ) + ) + expected_dict["ts"].append(ts) + if window.is_empty(): + expected_dict["value"].append(None) + else: + value = getattr(window["value"], aggregation)() + expected_dict["value"].append(value) + expected = pl.DataFrame(expected_dict).select( + pl.col("ts").cast(pl.Datetime(time_unit)), + pl.col("value").cast(result["value"].dtype), + ) + assert_frame_equal(result, expected) + assert_frame_equal(result_from_unsorted, expected) + + +def test_rolling_by_nulls() -> None: + df = pl.DataFrame({"a": [1, None], "b": [1, 2]}) + with pytest.raises( + InvalidOperationError, match="not yet supported for series with null values" + ): + df.select(pl.col("a").rolling_min_by("b", "2i")) + with pytest.raises( + InvalidOperationError, match="not yet supported for series with null values" + ): + df.select(pl.col("b").rolling_min_by("a", "2i")) + + +def test_window_size_validation() -> None: + df = pl.DataFrame({"x": [1.0]}) + + with pytest.raises(OverflowError, match=r"can't convert negative int to unsigned"): + df.with_columns(trailing_min=pl.col("x").rolling_min(window_size=-3)) + + +def test_rolling_empty_21032() -> None: + df = pl.DataFrame(schema={"a": pl.Datetime("ms"), "b": pl.Int64()}) + + result = df.rolling(index_column="a", period=timedelta(days=2)).agg( + pl.col("b").sum() + ) + assert_frame_equal(result, df) + + result = df.rolling( + index_column="a", period=timedelta(days=2), offset=timedelta(days=3) + ).agg(pl.col("b").sum()) + assert_frame_equal(result, df) + + +def test_rolling_offset_agg_15122() -> None: + df = pl.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [1, 2, 3, 1, 2, 3]}) + + result = df.rolling(index_column="b", period="1i", offset="0i", group_by="a").agg( + window=pl.col("b") + ) + expected = df.with_columns(window=pl.Series([[2], [3], [], [2], [3], []])) + assert_frame_equal(result, expected) + + result = df.rolling(index_column="b", period="1i", offset="1i", group_by="a").agg( + window=pl.col("b") + ) + expected = df.with_columns(window=pl.Series([[3], [], [], [3], [], []])) + assert_frame_equal(result, expected) + + +def test_rolling_sum_stability_11146() -> None: + data_frame = pl.DataFrame( + { + "value": [ + 0.0, + 290.57, + 107.0, + 172.0, + 124.25, + 304.0, + 379.5, + 347.35, + 1516.41, + 386.12, + 226.5, + 294.62, + 125.5, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + } + ) + assert ( + data_frame.with_columns( + pl.col("value").rolling_mean(window_size=8, min_samples=1).alias("test_col") + )["test_col"][-1] + == 0.0 + ) + + +def test_rolling() -> None: + df = pl.DataFrame( + { + "n": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10], + "col1": ["A", "B"] * 11, + } + ) + + assert df.rolling("n", period="1i", group_by="col1").agg().to_dict( + as_series=False + ) == { + "col1": [ + "A", + "A", + "A", + "A", + "A", + "A", + "A", + "A", + "A", + "A", + "A", + "B", + "B", + "B", + "B", + "B", + "B", + "B", + "B", + "B", + "B", + "B", + ], + "n": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + } diff --git a/py-polars/tests/unit/operations/rolling/test_rolling_fixed.py b/py-polars/tests/unit/operations/rolling/test_rolling_fixed.py new file mode 100644 index 000000000000..b99bc2b04334 --- /dev/null +++ b/py-polars/tests/unit/operations/rolling/test_rolling_fixed.py @@ -0,0 +1,7 @@ +import polars as pl + + +def test_rolling_var_stability_12905() -> None: + s1 = pl.Series("a", [36743.6 for _ in range(10)]) + assert s1.rolling_var(window_size=12, min_samples=2).sum() == 0.0 + assert s1.rolling_std(window_size=12, min_samples=2).sum() == 0.0 diff --git a/py-polars/tests/unit/operations/test_abs.py b/py-polars/tests/unit/operations/test_abs.py new file mode 100644 index 000000000000..ad0d6eadf9c1 --- /dev/null +++ b/py-polars/tests/unit/operations/test_abs.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from datetime import date, timedelta +from decimal import Decimal as D +from typing import TYPE_CHECKING, cast + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES, SIGNED_INTEGER_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_abs() -> None: + # ints + s = pl.Series([1, -2, 3, -4]) + assert_series_equal(s.abs(), pl.Series([1, 2, 3, 4])) + assert_series_equal(cast(pl.Series, np.abs(s)), pl.Series([1, 2, 3, 4])) + + # floats + s = pl.Series([1.0, -2.0, 3, -4.0]) + assert_series_equal(s.abs(), pl.Series([1.0, 2.0, 3.0, 4.0])) + assert_series_equal(cast(pl.Series, np.abs(s)), pl.Series([1.0, 2.0, 3.0, 4.0])) + assert_series_equal( + pl.select(pl.lit(s).abs()).to_series(), pl.Series([1.0, 2.0, 3.0, 4.0]) + ) + + +def test_abs_series_duration() -> None: + s = pl.Series([timedelta(hours=1), timedelta(hours=-1)]) + assert s.abs().to_list() == [timedelta(hours=1), timedelta(hours=1)] + + +def test_abs_expr() -> None: + df = pl.DataFrame({"x": [-1, 0, 1]}) + out = df.select(abs(pl.col("x"))) + + assert out["x"].to_list() == [1, 0, 1] + + +def test_builtin_abs() -> None: + s = pl.Series("s", [-1, 0, 1, None]) + assert abs(s).to_list() == [1, 0, 1, None] + + +@pytest.mark.parametrize("dtype", [*FLOAT_DTYPES, *SIGNED_INTEGER_DTYPES]) +def test_abs_builtin(dtype: PolarsDataType) -> None: + lf = pl.LazyFrame({"a": [-1, 0, 1, None]}, schema={"a": dtype}) + result = lf.select(abs(pl.col("a"))) + expected = pl.LazyFrame({"a": [1, 0, 1, None]}, schema={"a": dtype}) + assert_frame_equal(result, expected) + + +def test_abs_method() -> None: + lf = pl.LazyFrame({"a": [-1, 0, 1, None]}) + result_op = lf.select(abs(pl.col("a"))) + result_method = lf.select(pl.col("a").abs()) + assert_frame_equal(result_op, result_method) + + +def test_abs_decimal() -> None: + lf = pl.LazyFrame({"a": [D("-1.5"), D("0.0"), D("5.0"), None]}) + result = lf.select(pl.col("a").abs()) + expected = pl.LazyFrame({"a": [D("1.5"), D("0.0"), D("5.0"), None]}) + assert_frame_equal(result, expected) + + +def test_abs_duration() -> None: + lf = pl.LazyFrame({"a": [timedelta(hours=2), timedelta(days=-2), None]}) + result = lf.select(pl.col("a").abs()) + expected = pl.LazyFrame({"a": [timedelta(hours=2), timedelta(days=2), None]}) + assert_frame_equal(result, expected) + + +def test_abs_overflow_wrapping() -> None: + df = pl.DataFrame({"a": [-128]}, schema={"a": pl.Int8}) + result = df.select(pl.col("a").abs()) + assert_frame_equal(result, df) + + +def test_abs_unsigned_int() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) + result = df.select(pl.col("a").abs()) + assert_frame_equal(result, df) + + +def test_abs_non_numeric() -> None: + df = pl.DataFrame({"a": ["p", "q", "r"]}) + with pytest.raises( + InvalidOperationError, match="`abs` operation not supported for dtype `str`" + ): + df.select(pl.col("a").abs()) + + +def test_abs_date() -> None: + df = pl.DataFrame({"date": [date(1960, 1, 1), date(1970, 1, 1), date(1980, 1, 1)]}) + + with pytest.raises( + InvalidOperationError, match="`abs` operation not supported for dtype `date`" + ): + df.select(pl.col("date").abs()) + + +def test_abs_series_builtin() -> None: + s = pl.Series("a", [-1, 0, 1, None]) + result = abs(s) + expected = pl.Series("a", [1, 0, 1, None]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_bitwise.py b/py-polars/tests/unit/operations/test_bitwise.py new file mode 100644 index 000000000000..dfbe8306fc80 --- /dev/null +++ b/py-polars/tests/unit/operations/test_bitwise.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import sys +import typing + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES + + +@pytest.mark.parametrize("op", ["and_", "or_"]) +def test_bitwise_integral_schema(op: str) -> None: + df = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + q = df.select(getattr(pl.col("a"), op)(pl.col("b"))) + assert q.collect_schema()["a"] == df.collect_schema()["a"] + + +@pytest.mark.parametrize("op", ["and_", "or_", "xor"]) +def test_bitwise_single_null_value_schema(op: str) -> None: + df = pl.DataFrame({"a": [True, True]}) + q = df.select(getattr(pl.col("a"), op)(None)) + result_schema = q.collect_schema() + assert result_schema.len() == 1 + assert "a" in result_schema + + +def leading_zeros(v: int | None, nb: int) -> int | None: + if v is None: + return None + + b = bin(v)[2:] + blen = len(b) - len(b.lstrip("0")) + if blen == len(b): + return nb + else: + return nb - len(b) + blen + + +def leading_ones(v: int | None, nb: int) -> int | None: + if v is None: + return None + + b = bin(v)[2:] + if len(b) < nb: + return 0 + else: + return len(b) - len(b.lstrip("1")) + + +def trailing_zeros(v: int | None, nb: int) -> int | None: + if v is None: + return None + + b = bin(v)[2:] + blen = len(b) - len(b.rstrip("0")) + if blen == len(b): + return nb + else: + return blen + + +def trailing_ones(v: int | None) -> int | None: + if v is None: + return None + + b = bin(v)[2:] + return len(b) - len(b.rstrip("1")) + + +@pytest.mark.parametrize( + "value", + [ + 0x00, + 0x01, + 0xFCEF_0123, + 0xFFFF_FFFF, + 0xFFF0_FFE1_ABCD_EF01, + 0xAAAA_AAAA_AAAA_AAAA, + None, + ], +) +@pytest.mark.parametrize("dtype", [*INTEGER_DTYPES, pl.Boolean]) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="bit_count introduced in 3.10") +@typing.no_type_check +def test_bit_counts(value: int, dtype: pl.DataType) -> None: + bitsize = 8 + if "Boolean" in str(dtype): + bitsize = 1 + if "16" in str(dtype): + bitsize = 16 + elif "32" in str(dtype): + bitsize = 32 + elif "64" in str(dtype): + bitsize = 64 + elif "128" in str(dtype): + bitsize = 128 + + if bitsize == 1 and value is not None: + value = value & 1 != 0 + + co = 1 if value else 0 + cz = 0 if value else 1 + elif value is not None: + value = value & ((1 << bitsize) - 1) + + if dtype.is_signed_integer() and value >> (bitsize - 1) > 0: + value = value - pow(2, bitsize - 1) + + co = value.bit_count() + cz = bitsize - co + else: + co = None + cz = None + + assert_series_equal( + pl.Series("a", [value], dtype).bitwise_count_ones(), + pl.Series("a", [co], pl.UInt32), + ) + assert_series_equal( + pl.Series("a", [value], dtype).bitwise_count_zeros(), + pl.Series("a", [cz], pl.UInt32), + ) + assert_series_equal( + pl.Series("a", [value], dtype).bitwise_leading_ones(), + pl.Series("a", [leading_ones(value, bitsize)], pl.UInt32), + ) + assert_series_equal( + pl.Series("a", [value], dtype).bitwise_leading_zeros(), + pl.Series("a", [leading_zeros(value, bitsize)], pl.UInt32), + ) + assert_series_equal( + pl.Series("a", [value], dtype).bitwise_trailing_ones(), + pl.Series("a", [trailing_ones(value)], pl.UInt32), + ) + assert_series_equal( + pl.Series("a", [value], dtype).bitwise_trailing_zeros(), + pl.Series("a", [trailing_zeros(value, bitsize)], pl.UInt32), + ) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_bit_aggregations(dtype: pl.DataType) -> None: + s = pl.Series("a", [0x74, 0x1C, 0x05], dtype) + + df = s.to_frame().select( + AND=pl.col.a.bitwise_and(), + OR=pl.col.a.bitwise_or(), + XOR=pl.col.a.bitwise_xor(), + ) + + assert_frame_equal( + df, + pl.DataFrame( + [ + pl.Series("AND", [0x04], dtype), + pl.Series("OR", [0x7D], dtype), + pl.Series("XOR", [0x6D], dtype), + ] + ), + ) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_bit_group_by(dtype: pl.DataType) -> None: + df = pl.DataFrame( + [ + pl.Series("g", [4, 1, 1, 2, 3, 2, 4, 4], pl.Int8), + pl.Series("a", [0x03, 0x74, 0x1C, 0x05, None, 0x70, 0x01, None], dtype), + ] + ) + + df = df.group_by("g").agg( + AND=pl.col.a.bitwise_and(), + OR=pl.col.a.bitwise_or(), + XOR=pl.col.a.bitwise_xor(), + ) + + assert_frame_equal( + df, + pl.DataFrame( + [ + pl.Series("g", [1, 2, 3, 4], pl.Int8), + pl.Series("AND", [0x74 & 0x1C, 0x05 & 0x70, None, 0x01], dtype), + pl.Series("OR", [0x74 | 0x1C, 0x05 | 0x70, None, 0x03], dtype), + pl.Series("XOR", [0x74 ^ 0x1C, 0x05 ^ 0x70, None, 0x02], dtype), + ] + ), + check_row_order=False, + ) + + +def test_scalar_bitwise_xor() -> None: + df = pl.select( + pl.repeat(pl.lit(0x80, pl.UInt8), i).bitwise_xor().alias(f"l{i}") + for i in range(5) + ).transpose() + + assert_series_equal( + df.to_series(), + pl.Series("x", [None, 0x80, 0x00, 0x80, 0x00], pl.UInt8), + check_names=False, + ) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py new file mode 100644 index 000000000000..661ebf50cb84 --- /dev/null +++ b/py-polars/tests/unit/operations/test_cast.py @@ -0,0 +1,742 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars._utils.constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal +from polars.testing.asserts.series import assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType, PythonDataType + + +@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date]) +def test_string_date(dtype: PolarsDataType | PythonDataType) -> None: + df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( + **{"x1-date": pl.col("x1").cast(dtype)} + ) + expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]}) + out = df.select(pl.col("x1-date")) + assert_frame_equal(expected, out) + + +def test_invalid_string_date() -> None: + df = pl.DataFrame({"x1": ["2021-01-aa"]}) + + with pytest.raises(InvalidOperationError): + df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)}) + + +def test_string_datetime() -> None: + df = pl.DataFrame( + {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]} + ).with_columns( + **{ + "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), + "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), + "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + } + ) + first_row = datetime(year=2021, month=12, day=19, hour=00, minute=39, second=57) + second_row = datetime(year=2022, month=12, day=19, hour=16, minute=39, second=57) + expected = pl.DataFrame( + { + "x1-datetime-ns": [first_row, second_row], + "x1-datetime-ms": [first_row, second_row], + "x1-datetime-us": [first_row, second_row], + } + ).select( + pl.col("x1-datetime-ns").dt.cast_time_unit("ns"), + pl.col("x1-datetime-ms").dt.cast_time_unit("ms"), + pl.col("x1-datetime-us").dt.cast_time_unit("us"), + ) + + out = df.select( + pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") + ) + assert_frame_equal(expected, out) + + +def test_invalid_string_datetime() -> None: + df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]}) + with pytest.raises(InvalidOperationError): + df.with_columns( + **{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))} + ) + + +def test_string_datetime_timezone() -> None: + ccs_tz = "America/Caracas" + stg_tz = "America/Santiago" + utc_tz = "UTC" + df = pl.DataFrame( + {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]} + ).with_columns( + **{ + "x1-datetime-ns": pl.col("x1").cast( + pl.Datetime(time_unit="ns", time_zone=ccs_tz) + ), + "x1-datetime-ms": pl.col("x1").cast( + pl.Datetime(time_unit="ms", time_zone=stg_tz) + ), + "x1-datetime-us": pl.col("x1").cast( + pl.Datetime(time_unit="us", time_zone=utc_tz) + ), + } + ) + + expected = pl.DataFrame( + { + "x1-datetime-ns": [ + datetime(year=1996, month=12, day=19, hour=12, minute=39, second=57), + datetime(year=2022, month=12, day=18, hour=20, minute=39, second=57), + ], + "x1-datetime-ms": [ + datetime(year=1996, month=12, day=19, hour=13, minute=39, second=57), + datetime(year=2022, month=12, day=18, hour=21, minute=39, second=57), + ], + "x1-datetime-us": [ + datetime(year=1996, month=12, day=19, hour=16, minute=39, second=57), + datetime(year=2022, month=12, day=19, hour=00, minute=39, second=57), + ], + } + ).select( + pl.col("x1-datetime-ns").dt.cast_time_unit("ns").dt.replace_time_zone(ccs_tz), + pl.col("x1-datetime-ms").dt.cast_time_unit("ms").dt.replace_time_zone(stg_tz), + pl.col("x1-datetime-us").dt.cast_time_unit("us").dt.replace_time_zone(utc_tz), + ) + + out = df.select( + pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") + ) + + assert_frame_equal(expected, out) + + +@pytest.mark.parametrize(("dtype"), [pl.Int8, pl.Int16, pl.Int32, pl.Int64]) +def test_leading_plus_zero_int(dtype: pl.DataType) -> None: + s_int = pl.Series( + [ + "-000000000000002", + "-1", + "-0", + "0", + "+0", + "1", + "+1", + "0000000000000000000002", + "+000000000000000000003", + ] + ) + assert_series_equal( + s_int.cast(dtype), pl.Series([-2, -1, 0, 0, 0, 1, 1, 2, 3], dtype=dtype) + ) + + +@pytest.mark.parametrize(("dtype"), [pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64]) +def test_leading_plus_zero_uint(dtype: pl.DataType) -> None: + s_int = pl.Series( + ["0", "+0", "1", "+1", "0000000000000000000002", "+000000000000000000003"] + ) + assert_series_equal(s_int.cast(dtype), pl.Series([0, 0, 1, 1, 2, 3], dtype=dtype)) + + +@pytest.mark.parametrize(("dtype"), [pl.Float32, pl.Float64]) +def test_leading_plus_zero_float(dtype: pl.DataType) -> None: + s_float = pl.Series( + [ + "-000000000000002.0", + "-1.0", + "-.5", + "-0.0", + "0.", + "+0", + "+.5", + "1", + "+1", + "0000000000000000000002", + "+000000000000000000003", + ] + ) + assert_series_equal( + s_float.cast(dtype), + pl.Series( + [-2.0, -1.0, -0.5, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 2.0, 3.0], dtype=dtype + ), + ) + + +def _cast_series( + val: int | datetime | date | time | timedelta, + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + strict: bool, +) -> int | datetime | date | time | timedelta | None: + return pl.Series("a", [val], dtype=dtype_in).cast(dtype_out, strict=strict).item() # type: ignore[no-any-return] + + +def _cast_expr( + val: int | datetime | date | time | timedelta, + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + strict: bool, +) -> int | datetime | date | time | timedelta | None: + return ( # type: ignore[no-any-return] + pl.Series("a", [val], dtype=dtype_in) + .to_frame() + .select(pl.col("a").cast(dtype_out, strict=strict)) + .item() + ) + + +def _cast_lit( + val: int | datetime | date | time | timedelta, + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + strict: bool, +) -> int | datetime | date | time | timedelta | None: + return pl.select(pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict)).item() # type: ignore[no-any-return] + + +@pytest.mark.parametrize( + ("value", "from_dtype", "to_dtype", "should_succeed", "expected_value"), + [ + (-1, pl.Int8, pl.UInt8, False, None), + (-1, pl.Int16, pl.UInt16, False, None), + (-1, pl.Int32, pl.UInt32, False, None), + (-1, pl.Int64, pl.UInt64, False, None), + (2**7, pl.UInt8, pl.Int8, False, None), + (2**15, pl.UInt16, pl.Int16, False, None), + (2**31, pl.UInt32, pl.Int32, False, None), + (2**63, pl.UInt64, pl.Int64, False, None), + (2**7 - 1, pl.UInt8, pl.Int8, True, 2**7 - 1), + (2**15 - 1, pl.UInt16, pl.Int16, True, 2**15 - 1), + (2**31 - 1, pl.UInt32, pl.Int32, True, 2**31 - 1), + (2**63 - 1, pl.UInt64, pl.Int64, True, 2**63 - 1), + ], +) +def test_strict_cast_int( + value: int, + from_dtype: PolarsDataType, + to_dtype: PolarsDataType, + should_succeed: bool, + expected_value: Any, +) -> None: + args = [value, from_dtype, to_dtype, True] + if should_succeed: + assert _cast_series(*args) == expected_value # type: ignore[arg-type] + assert _cast_expr(*args) == expected_value # type: ignore[arg-type] + assert _cast_lit(*args) == expected_value # type: ignore[arg-type] + else: + with pytest.raises(InvalidOperationError): + _cast_series(*args) # type: ignore[arg-type] + with pytest.raises(InvalidOperationError): + _cast_expr(*args) # type: ignore[arg-type] + with pytest.raises(InvalidOperationError): + _cast_lit(*args) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("value", "from_dtype", "to_dtype", "expected_value"), + [ + (-1, pl.Int8, pl.UInt8, None), + (-1, pl.Int16, pl.UInt16, None), + (-1, pl.Int32, pl.UInt32, None), + (-1, pl.Int64, pl.UInt64, None), + (2**7, pl.UInt8, pl.Int8, None), + (2**15, pl.UInt16, pl.Int16, None), + (2**31, pl.UInt32, pl.Int32, None), + (2**63, pl.UInt64, pl.Int64, None), + (2**7 - 1, pl.UInt8, pl.Int8, 2**7 - 1), + (2**15 - 1, pl.UInt16, pl.Int16, 2**15 - 1), + (2**31 - 1, pl.UInt32, pl.Int32, 2**31 - 1), + (2**63 - 1, pl.UInt64, pl.Int64, 2**63 - 1), + ], +) +def test_cast_int( + value: int, + from_dtype: PolarsDataType, + to_dtype: PolarsDataType, + expected_value: Any, +) -> None: + args = [value, from_dtype, to_dtype, False] + assert _cast_series(*args) == expected_value # type: ignore[arg-type] + assert _cast_expr(*args) == expected_value # type: ignore[arg-type] + assert _cast_lit(*args) == expected_value # type: ignore[arg-type] + + +def _cast_series_t( + val: int | datetime | date | time | timedelta, + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + strict: bool, +) -> pl.Series: + return pl.Series("a", [val], dtype=dtype_in).cast(dtype_out, strict=strict) + + +def _cast_expr_t( + val: int | datetime | date | time | timedelta, + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + strict: bool, +) -> pl.Series: + return ( + pl.Series("a", [val], dtype=dtype_in) + .to_frame() + .select(pl.col("a").cast(dtype_out, strict=strict)) + .to_series() + ) + + +def _cast_lit_t( + val: int | datetime | date | time | timedelta, + dtype_in: PolarsDataType, + dtype_out: PolarsDataType, + strict: bool, +) -> pl.Series: + return pl.select( + pl.lit(val, dtype=dtype_in).cast(dtype_out, strict=strict) + ).to_series() + + +@pytest.mark.parametrize( + ( + "value", + "from_dtype", + "to_dtype", + "should_succeed", + "expected_value", + ), + [ + # date to datetime + (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), True, datetime(1970, 1, 1)), + (date(1970, 1, 1), pl.Date, pl.Datetime("us"), True, datetime(1970, 1, 1)), + (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), True, datetime(1970, 1, 1)), + # datetime to date + (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, True, date(1970, 1, 1)), + (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, True, date(1970, 1, 1)), + (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, True, date(1970, 1, 1)), + # datetime to time + (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, True, time(hour=1)), + (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, True, time(hour=1)), + (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, True, time(hour=1)), + # duration to int + (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, True, MS_PER_SECOND), + (timedelta(seconds=1), pl.Duration("us"), pl.Int64, True, US_PER_SECOND), + (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, True, NS_PER_SECOND), + # time to duration + (time(hour=1), pl.Time, pl.Duration("ms"), True, timedelta(hours=1)), + (time(hour=1), pl.Time, pl.Duration("us"), True, timedelta(hours=1)), + (time(hour=1), pl.Time, pl.Duration("ns"), True, timedelta(hours=1)), + # int to date + (100, pl.UInt8, pl.Date, True, date(1970, 4, 11)), + (100, pl.UInt16, pl.Date, True, date(1970, 4, 11)), + (100, pl.UInt32, pl.Date, True, date(1970, 4, 11)), + (100, pl.UInt64, pl.Date, True, date(1970, 4, 11)), + (100, pl.Int8, pl.Date, True, date(1970, 4, 11)), + (100, pl.Int16, pl.Date, True, date(1970, 4, 11)), + (100, pl.Int32, pl.Date, True, date(1970, 4, 11)), + (100, pl.Int64, pl.Date, True, date(1970, 4, 11)), + # failures + (2**63 - 1, pl.Int64, pl.Date, False, None), + (-(2**62), pl.Int64, pl.Date, False, None), + (date(1970, 5, 10), pl.Date, pl.Int8, False, None), + (date(2149, 6, 7), pl.Date, pl.Int16, False, None), + (datetime(9999, 12, 31), pl.Datetime, pl.Int8, False, None), + (datetime(9999, 12, 31), pl.Datetime, pl.Int16, False, None), + ], +) +def test_strict_cast_temporal( + value: int, + from_dtype: PolarsDataType, + to_dtype: PolarsDataType, + should_succeed: bool, + expected_value: Any, +) -> None: + args = [value, from_dtype, to_dtype, True] + if should_succeed: + out = _cast_series_t(*args) # type: ignore[arg-type] + assert out.item() == expected_value + assert out.dtype == to_dtype + out = _cast_expr_t(*args) # type: ignore[arg-type] + assert out.item() == expected_value + assert out.dtype == to_dtype + out = _cast_lit_t(*args) # type: ignore[arg-type] + assert out.item() == expected_value + assert out.dtype == to_dtype + else: + with pytest.raises(InvalidOperationError): + _cast_series_t(*args) # type: ignore[arg-type] + with pytest.raises(InvalidOperationError): + _cast_expr_t(*args) # type: ignore[arg-type] + with pytest.raises(InvalidOperationError): + _cast_lit_t(*args) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ( + "value", + "from_dtype", + "to_dtype", + "expected_value", + ), + [ + # date to datetime + (date(1970, 1, 1), pl.Date, pl.Datetime("ms"), datetime(1970, 1, 1)), + (date(1970, 1, 1), pl.Date, pl.Datetime("us"), datetime(1970, 1, 1)), + (date(1970, 1, 1), pl.Date, pl.Datetime("ns"), datetime(1970, 1, 1)), + # datetime to date + (datetime(1970, 1, 1), pl.Datetime("ms"), pl.Date, date(1970, 1, 1)), + (datetime(1970, 1, 1), pl.Datetime("us"), pl.Date, date(1970, 1, 1)), + (datetime(1970, 1, 1), pl.Datetime("ns"), pl.Date, date(1970, 1, 1)), + # datetime to time + (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ms"), pl.Time, time(hour=1)), + (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("us"), pl.Time, time(hour=1)), + (datetime(2000, 1, 1, 1, 0, 0), pl.Datetime("ns"), pl.Time, time(hour=1)), + # duration to int + (timedelta(seconds=1), pl.Duration("ms"), pl.Int32, MS_PER_SECOND), + (timedelta(seconds=1), pl.Duration("us"), pl.Int64, US_PER_SECOND), + (timedelta(seconds=1), pl.Duration("ns"), pl.Int64, NS_PER_SECOND), + # time to duration + (time(hour=1), pl.Time, pl.Duration("ms"), timedelta(hours=1)), + (time(hour=1), pl.Time, pl.Duration("us"), timedelta(hours=1)), + (time(hour=1), pl.Time, pl.Duration("ns"), timedelta(hours=1)), + # int to date + (100, pl.UInt8, pl.Date, date(1970, 4, 11)), + (100, pl.UInt16, pl.Date, date(1970, 4, 11)), + (100, pl.UInt32, pl.Date, date(1970, 4, 11)), + (100, pl.UInt64, pl.Date, date(1970, 4, 11)), + (100, pl.Int8, pl.Date, date(1970, 4, 11)), + (100, pl.Int16, pl.Date, date(1970, 4, 11)), + (100, pl.Int32, pl.Date, date(1970, 4, 11)), + (100, pl.Int64, pl.Date, date(1970, 4, 11)), + # failures + (2**63 - 1, pl.Int64, pl.Date, None), + (-(2**62), pl.Int64, pl.Date, None), + (date(1970, 5, 10), pl.Date, pl.Int8, None), + (date(2149, 6, 7), pl.Date, pl.Int16, None), + (datetime(9999, 12, 31), pl.Datetime, pl.Int8, None), + (datetime(9999, 12, 31), pl.Datetime, pl.Int16, None), + ], +) +def test_cast_temporal( + value: int, + from_dtype: PolarsDataType, + to_dtype: PolarsDataType, + expected_value: Any, +) -> None: + args = [value, from_dtype, to_dtype, False] + out = _cast_series_t(*args) # type: ignore[arg-type] + if expected_value is None: + assert out.item() is None + else: + assert out.item() == expected_value + assert out.dtype == to_dtype + + out = _cast_expr_t(*args) # type: ignore[arg-type] + if expected_value is None: + assert out.item() is None + else: + assert out.item() == expected_value + assert out.dtype == to_dtype + + out = _cast_lit_t(*args) # type: ignore[arg-type] + if expected_value is None: + assert out.item() is None + else: + assert out.item() == expected_value + assert out.dtype == to_dtype + + +@pytest.mark.parametrize( + ( + "value", + "from_dtype", + "to_dtype", + "expected_value", + ), + [ + (str(2**7 - 1), pl.String, pl.Int8, 2**7 - 1), + (str(2**15 - 1), pl.String, pl.Int16, 2**15 - 1), + (str(2**31 - 1), pl.String, pl.Int32, 2**31 - 1), + (str(2**63 - 1), pl.String, pl.Int64, 2**63 - 1), + ("1.0", pl.String, pl.Float32, 1.0), + ("1.0", pl.String, pl.Float64, 1.0), + # overflow + (str(2**7), pl.String, pl.Int8, None), + (str(2**15), pl.String, pl.Int16, None), + (str(2**31), pl.String, pl.Int32, None), + (str(2**63), pl.String, pl.Int64, None), + ], +) +def test_cast_string( + value: int, + from_dtype: PolarsDataType, + to_dtype: PolarsDataType, + expected_value: Any, +) -> None: + args = [value, from_dtype, to_dtype, False] + out = _cast_series_t(*args) # type: ignore[arg-type] + if expected_value is None: + assert out.item() is None + else: + assert out.item() == expected_value + assert out.dtype == to_dtype + + out = _cast_expr_t(*args) # type: ignore[arg-type] + if expected_value is None: + assert out.item() is None + else: + assert out.item() == expected_value + assert out.dtype == to_dtype + + out = _cast_lit_t(*args) # type: ignore[arg-type] + if expected_value is None: + assert out.item() is None + else: + assert out.item() == expected_value + assert out.dtype == to_dtype + + +@pytest.mark.parametrize( + ( + "value", + "from_dtype", + "to_dtype", + "should_succeed", + "expected_value", + ), + [ + (str(2**7 - 1), pl.String, pl.Int8, True, 2**7 - 1), + (str(2**15 - 1), pl.String, pl.Int16, True, 2**15 - 1), + (str(2**31 - 1), pl.String, pl.Int32, True, 2**31 - 1), + (str(2**63 - 1), pl.String, pl.Int64, True, 2**63 - 1), + ("1.0", pl.String, pl.Float32, True, 1.0), + ("1.0", pl.String, pl.Float64, True, 1.0), + # overflow + (str(2**7), pl.String, pl.Int8, False, None), + (str(2**15), pl.String, pl.Int16, False, None), + (str(2**31), pl.String, pl.Int32, False, None), + (str(2**63), pl.String, pl.Int64, False, None), + ], +) +def test_strict_cast_string( + value: int, + from_dtype: PolarsDataType, + to_dtype: PolarsDataType, + should_succeed: bool, + expected_value: Any, +) -> None: + args = [value, from_dtype, to_dtype, True] + if should_succeed: + out = _cast_series_t(*args) # type: ignore[arg-type] + assert out.item() == expected_value + assert out.dtype == to_dtype + out = _cast_expr_t(*args) # type: ignore[arg-type] + assert out.item() == expected_value + assert out.dtype == to_dtype + out = _cast_lit_t(*args) # type: ignore[arg-type] + assert out.item() == expected_value + assert out.dtype == to_dtype + else: + with pytest.raises(InvalidOperationError): + _cast_series_t(*args) # type: ignore[arg-type] + with pytest.raises(InvalidOperationError): + _cast_expr_t(*args) # type: ignore[arg-type] + with pytest.raises(InvalidOperationError): + _cast_lit_t(*args) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "dtype_in", + [(pl.Categorical), (pl.Enum(["1"]))], +) +@pytest.mark.parametrize( + "dtype_out", + [ + *INTEGER_DTYPES, + pl.Date, + pl.Datetime, + pl.Time, + pl.Duration, + pl.String, + pl.Categorical, + pl.Enum(["1", "2"]), + ], +) +def test_cast_categorical_name_retention( + dtype_in: PolarsDataType, dtype_out: PolarsDataType +) -> None: + assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a" + + +def test_cast_date_to_time() -> None: + s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)]) + msg = "casting from Date to Time not supported" + with pytest.raises(InvalidOperationError, match=msg): + s.cast(pl.Time) + + +def test_cast_time_to_date() -> None: + s = pl.Series([time(0, 0), time(20, 00)]) + msg = "casting from Time to Date not supported" + with pytest.raises(InvalidOperationError, match=msg): + s.cast(pl.Date) + + +def test_cast_decimal_to_boolean() -> None: + s = pl.Series("s", [Decimal("0.0"), Decimal("1.5"), Decimal("-1.5")]) + assert_series_equal(s.cast(pl.Boolean), pl.Series("s", [False, True, True])) + + df = s.to_frame() + assert_frame_equal( + df.select(pl.col("s").cast(pl.Boolean)), + pl.DataFrame({"s": [False, True, True]}), + ) + + +def test_cast_array_to_different_width() -> None: + s = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int8, 2)) + with pytest.raises( + InvalidOperationError, match="cannot cast Array to a different width" + ): + s.cast(pl.Array(pl.Int16, 3)) + + +def test_cast_decimal_to_decimal_high_precision() -> None: + precision = 22 + values = [Decimal("9" * precision)] + s = pl.Series(values, dtype=pl.Decimal(None, 0)) + + target_dtype = pl.Decimal(precision, 0) + result = s.cast(target_dtype) + + assert result.dtype == target_dtype + assert result.to_list() == values + + +@pytest.mark.parametrize("value", [float("inf"), float("nan")]) +def test_invalid_cast_float_to_decimal(value: float) -> None: + s = pl.Series([value], dtype=pl.Float64) + with pytest.raises( + InvalidOperationError, + match=r"conversion from `f64` to `decimal\[\*,0\]` failed", + ): + s.cast(pl.Decimal) + + +def test_err_on_time_datetime_cast() -> None: + s = pl.Series([time(10, 0, 0), time(11, 30, 59)]) + with pytest.raises( + InvalidOperationError, + match="casting from Time to Datetime\\(Microseconds, None\\) not supported; consider using `dt.combine`", + ): + s.cast(pl.Datetime) + + +def test_err_on_invalid_time_zone_cast() -> None: + s = pl.Series([datetime(2021, 1, 1)]) + with pytest.raises(ComputeError, match=r"unable to parse time zone: 'qwerty'"): + s.cast(pl.Datetime("us", "qwerty")) + + +def test_invalid_inner_type_cast_list() -> None: + s = pl.Series([[-1, 1]]) + with pytest.raises( + InvalidOperationError, + match=r"cannot cast List inner type: 'Int64' to Categorical", + ): + s.cast(pl.List(pl.Categorical)) + + +def test_all_null_cast_5826() -> None: + df = pl.DataFrame(data=[pl.Series("a", [None], dtype=pl.String)]) + out = df.with_columns(pl.col("a").cast(pl.Boolean)) + assert out.dtypes == [pl.Boolean] + assert out.item() is None + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_bool_numeric_supertype(dtype: PolarsDataType) -> None: + df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]}) + result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len()) + assert result.item() - 0.3333333 <= 0.00001 + + +@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str]) +def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None: + assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns( + b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype) + ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]} + + +def test_cast_int_to_string_unsets_sorted_flag_19424() -> None: + s = pl.Series([1, 2]).set_sorted() + assert s.flags["SORTED_ASC"] + assert not s.cast(pl.String).flags["SORTED_ASC"] + + +def test_cast_integer_to_decimal() -> None: + s = pl.Series([1, 2, 3]) + result = s.cast(pl.Decimal(10, 2)) + expected = pl.Series( + "", [Decimal("1.00"), Decimal("2.00"), Decimal("3.00")], pl.Decimal(10, 2) + ) + assert_series_equal(result, expected) + + +def test_cast_python_dtypes() -> None: + s = pl.Series([0, 1]) + assert s.cast(int).dtype == pl.Int64 + assert s.cast(float).dtype == pl.Float64 + assert s.cast(bool).dtype == pl.Boolean + assert s.cast(str).dtype == pl.String + + +def test_overflowing_cast_literals_21023() -> None: + for no_optimization in [True, False]: + assert_frame_equal( + ( + pl.LazyFrame() + .select( + pl.lit(pl.Series([128], dtype=pl.Int64)).cast( + pl.Int8, wrap_numerical=True + ) + ) + .collect(no_optimization=no_optimization) + ), + pl.Series([-128], dtype=pl.Int8).to_frame(), + ) + + +def test_invalid_empty_cast_to_empty_enum() -> None: + with pytest.raises( + InvalidOperationError, + match="cannot cast / initialize Enum without categories present", + ): + pl.Series([], dtype=pl.Enum) + + +@pytest.mark.parametrize("value", [True, False]) +@pytest.mark.parametrize( + "dtype", + [ + pl.Enum(["a", "b"]), + pl.Series(["a", "b"], dtype=pl.Categorical).dtype, + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_invalid_bool_to_cat(value: bool, dtype: PolarsDataType) -> None: + # Enum + with pytest.raises( + InvalidOperationError, + match="cannot cast Boolean to Categorical", + ): + pl.Series([value]).cast(dtype) diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py new file mode 100644 index 000000000000..b3ca4afc7d87 --- /dev/null +++ b/py-polars/tests/unit/operations/test_clear.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import polars as pl +from polars.testing.parametric import series + + +@given(s=series()) +def test_clear_series_parametric(s: pl.Series) -> None: + result = s.clear() + + assert result.dtype == s.dtype + assert result.name == s.name + assert result.is_empty() + + +@given( + s=series( + excluded_dtypes=[ + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ] + ), + n=st.integers(min_value=0, max_value=10), +) +def test_clear_series_n_parametric(s: pl.Series, n: int) -> None: + result = s.clear(n) + + assert result.dtype == s.dtype + assert result.name == s.name + assert len(result) == n + assert result.null_count() == n + + +@pytest.mark.parametrize("n", [0, 2, 5]) +def test_clear_series(n: int) -> None: + a = pl.Series(name="a", values=[1, 2, 3], dtype=pl.Int16) + + result = a.clear(n) + assert result.dtype == a.dtype + assert result.name == a.name + assert len(result) == n + assert result.null_count() == n + + +def test_clear_df() -> None: + df = pl.DataFrame( + {"a": [1, 2], "b": [True, False]}, schema={"a": pl.UInt32, "b": pl.Boolean} + ) + + result = df.clear() + assert result.schema == df.schema + assert result.rows() == [] + + result = df.clear(3) + assert result.schema == df.schema + assert result.rows() == [(None, None), (None, None), (None, None)] + + +def test_clear_lf() -> None: + lf = pl.LazyFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + ldfe = lf.clear() + assert ldfe.collect_schema() == lf.collect_schema() + + ldfe = lf.clear(2) + assert ldfe.collect_schema() == lf.collect_schema() + assert ldfe.collect().rows() == [(None, None, None), (None, None, None)] + + +def test_clear_series_object_starting_with_null() -> None: + s = pl.Series([None, object()]) + + result = s.clear() + + assert result.dtype == s.dtype + assert result.name == s.name + assert result.is_empty() + + +def test_clear_raise_negative_n() -> None: + s = pl.Series([1, 2, 3]) + + msg = "`n` should be greater than or equal to 0, got -1" + with pytest.raises(ValueError, match=msg): + s.clear(-1) + with pytest.raises(ValueError, match=msg): + s.to_frame().clear(-1) diff --git a/py-polars/tests/unit/operations/test_clip.py b/py-polars/tests/unit/operations/test_clip.py new file mode 100644 index 000000000000..08e8998ded02 --- /dev/null +++ b/py-polars/tests/unit/operations/test_clip.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.fixture +def clip_exprs() -> list[pl.Expr]: + return [ + pl.col("a").clip(pl.col("min"), pl.col("max")).alias("clip"), + pl.col("a").clip(lower_bound=pl.col("min")).alias("clip_min"), + pl.col("a").clip(upper_bound=pl.col("max")).alias("clip_max"), + ] + + +def test_clip_int(clip_exprs: list[pl.Expr]) -> None: + lf = pl.LazyFrame( + { + "a": [1, 2, 3, 4, 5, None], + "min": [0, -1, 4, None, 4, -10], + "max": [2, 1, 8, 5, None, 10], + } + ) + result = lf.select(clip_exprs) + expected = pl.LazyFrame( + { + "clip": [1, 1, 4, 4, 5, None], + "clip_min": [1, 2, 4, 4, 5, None], + "clip_max": [1, 1, 3, 4, 5, None], + } + ) + assert_frame_equal(result, expected) + + +def test_clip_float(clip_exprs: list[pl.Expr]) -> None: + lf = pl.LazyFrame( + { + "a": [1.0, 2.0, 3.0, 4.0, 5.0, None], + "min": [0.0, -1.0, 4.0, None, 4.0, None], + "max": [2.0, 1.0, 8.0, 5.0, None, None], + } + ) + result = lf.select(clip_exprs) + expected = pl.LazyFrame( + { + "clip": [1.0, 1.0, 4.0, 4.0, 5.0, None], + "clip_min": [1.0, 2.0, 4.0, 4.0, 5.0, None], + "clip_max": [1.0, 1.0, 3.0, 4.0, 5.0, None], + } + ) + assert_frame_equal(result, expected) + + +def test_clip_datetime(clip_exprs: list[pl.Expr]) -> None: + lf = pl.LazyFrame( + { + "a": [ + datetime(1995, 6, 5, 10, 30), + datetime(1995, 6, 5), + datetime(2023, 10, 20, 18, 30, 6), + None, + datetime(2023, 9, 24), + datetime(2000, 1, 10), + ], + "min": [ + datetime(1995, 6, 5, 10, 29), + datetime(1996, 6, 5), + datetime(2020, 9, 24), + datetime(2020, 1, 1), + None, + datetime(2000, 1, 1), + ], + "max": [ + datetime(1995, 7, 21, 10, 30), + datetime(2000, 1, 1), + datetime(2023, 9, 20, 18, 30, 6), + datetime(2000, 1, 1), + datetime(1993, 3, 13), + None, + ], + } + ) + result = lf.select(clip_exprs) + expected = pl.LazyFrame( + { + "clip": [ + datetime(1995, 6, 5, 10, 30), + datetime(1996, 6, 5), + datetime(2023, 9, 20, 18, 30, 6), + None, + datetime(1993, 3, 13), + datetime(2000, 1, 10), + ], + "clip_min": [ + datetime(1995, 6, 5, 10, 30), + datetime(1996, 6, 5), + datetime(2023, 10, 20, 18, 30, 6), + None, + datetime(2023, 9, 24), + datetime(2000, 1, 10), + ], + "clip_max": [ + datetime(1995, 6, 5, 10, 30), + datetime(1995, 6, 5), + datetime(2023, 9, 20, 18, 30, 6), + None, + datetime(1993, 3, 13), + datetime(2000, 1, 10), + ], + } + ) + assert_frame_equal(result, expected) + + +def test_clip_non_numeric_dtype_fails() -> None: + msg = "`clip` only supports physical numeric types" + + s = pl.Series(["a", "b", "c"]) + with pytest.raises(InvalidOperationError, match=msg): + s.clip(pl.lit("b"), pl.lit("z")) + + +def test_clip_string_input() -> None: + df = pl.DataFrame({"a": [0, 1, 2], "min": [1, None, 1]}) + result = df.select(pl.col("a").clip("min")) + expected = pl.DataFrame({"a": [1, 1, 2]}) + assert_frame_equal(result, expected) + + +def test_clip_bound_invalid_for_original_dtype() -> None: + s = pl.Series([1, 2, 3, 4], dtype=pl.UInt32) + with pytest.raises( + InvalidOperationError, match="conversion from `i32` to `u32` failed" + ): + s.clip(-1, 5) + + +def test_clip_decimal() -> None: + ser = pl.Series("a", ["1.1", "2.2", "3.3"], pl.Decimal(21, 1)) + + result = ser.clip(lower_bound=Decimal("1.5"), upper_bound=Decimal("2.5")) + expected = pl.Series("a", ["1.5", "2.2", "2.5"], pl.Decimal(21, 1)) + assert_series_equal(result, expected) + + result = ser.clip(lower_bound=Decimal("1.5")) + expected = pl.Series("a", ["1.5", "2.2", "3.3"], pl.Decimal(21, 1)) + assert_series_equal(result, expected) + + result = ser.clip(upper_bound=Decimal("2.5")) + expected = pl.Series("a", ["1.1", "2.2", "2.5"], pl.Decimal(21, 1)) + assert_series_equal(result, expected) + + +def test_clip_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([1, 2, 3]).clip(lower_bound=pl.Series([1, 2])) + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([1, 2, 3]).clip(upper_bound=pl.Series([1, 2])) + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([1, 2, 3]).clip(pl.Series([1, 2]), pl.Series([1, 2, 3])) + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([1, 2, 3]).clip(pl.Series([1, 2, 3]), pl.Series([1, 2])) diff --git a/py-polars/tests/unit/operations/test_comparison.py b/py-polars/tests/unit/operations/test_comparison.py new file mode 100644 index 000000000000..a58f718e8ee9 --- /dev/null +++ b/py-polars/tests/unit/operations/test_comparison.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import math +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from contextlib import AbstractContextManager as ContextManager + + from polars._typing import PolarsDataType + + +def test_comparison_order_null_broadcasting() -> None: + # see more: 8183 + exprs = [ + pl.col("v") < pl.col("null"), + pl.col("null") < pl.col("v"), + pl.col("v") <= pl.col("null"), + pl.col("null") <= pl.col("v"), + pl.col("v") > pl.col("null"), + pl.col("null") > pl.col("v"), + pl.col("v") >= pl.col("null"), + pl.col("null") >= pl.col("v"), + ] + + kwargs = {f"out{i}": e for i, e in zip(range(len(exprs)), exprs)} + + # single value, hits broadcasting branch + df = pl.DataFrame({"v": [42], "null": [None]}) + assert all((df.select(**kwargs).null_count() == 1).rows()[0]) + + # multiple values, hits default branch + df = pl.DataFrame({"v": [42, 42], "null": [None, None]}) + assert all((df.select(**kwargs).null_count() == 2).rows()[0]) + + +def test_comparison_nulls_single() -> None: + df1 = pl.DataFrame( + { + "a": pl.Series([None], dtype=pl.String), + "b": pl.Series([None], dtype=pl.Int64), + "c": pl.Series([None], dtype=pl.Boolean), + } + ) + df2 = pl.DataFrame( + { + "a": pl.Series([None], dtype=pl.String), + "b": pl.Series([None], dtype=pl.Int64), + "c": pl.Series([None], dtype=pl.Boolean), + } + ) + assert (df1 == df2).row(0) == (None, None, None) + assert (df1 != df2).row(0) == (None, None, None) + + +def test_comparison_series_expr() -> None: + df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])}) + + assert_frame_equal( + df.select( + (df["a"] == pl.col("b")).alias("eq"), # False, False, True + (df["a"] != pl.col("b")).alias("ne"), # True, True, False + (df["a"] < pl.col("b")).alias("lt"), # True, False, False + (df["a"] <= pl.col("b")).alias("le"), # True, False, True + (df["a"] > pl.col("b")).alias("gt"), # False, True, False + (df["a"] >= pl.col("b")).alias("ge"), # False, True, True + ), + pl.DataFrame( + { + "eq": [False, False, True], + "ne": [True, True, False], + "lt": [True, False, False], + "le": [True, False, True], + "gt": [False, True, False], + "ge": [False, True, True], + } + ), + ) + + +def test_comparison_expr_expr() -> None: + df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])}) + + assert_frame_equal( + df.select( + (pl.col("a") == pl.col("b")).alias("eq"), # False, False, True + (pl.col("a") != pl.col("b")).alias("ne"), # True, True, False + (pl.col("a") < pl.col("b")).alias("lt"), # True, False, False + (pl.col("a") <= pl.col("b")).alias("le"), # True, False, True + (pl.col("a") > pl.col("b")).alias("gt"), # False, True, False + (pl.col("a") >= pl.col("b")).alias("ge"), # False, True, True + ), + pl.DataFrame( + { + "eq": [False, False, True], + "ne": [True, True, False], + "lt": [True, False, False], + "le": [True, False, True], + "gt": [False, True, False], + "ge": [False, True, True], + } + ), + ) + + +def test_comparison_expr_series() -> None: + df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])}) + + assert_frame_equal( + df.select( + (pl.col("a") == df["b"]).alias("eq"), # False, False, True + (pl.col("a") != df["b"]).alias("ne"), # True, True, False + (pl.col("a") < df["b"]).alias("lt"), # True, False, False + (pl.col("a") <= df["b"]).alias("le"), # True, False, True + (pl.col("a") > df["b"]).alias("gt"), # False, True, False + (pl.col("a") >= df["b"]).alias("ge"), # False, True, True + ), + pl.DataFrame( + { + "eq": [False, False, True], + "ne": [True, True, False], + "lt": [True, False, False], + "le": [True, False, True], + "gt": [False, True, False], + "ge": [False, True, True], + } + ), + ) + + +def test_offset_handling_arg_where_7863() -> None: + df_check = pl.DataFrame({"a": [0, 1]}) + df_check.select((pl.lit(0).append(pl.col("a")).append(0)) != 0) + assert ( + df_check.select((pl.lit(0).append(pl.col("a")).append(0)) != 0) + .select(pl.col("literal").arg_true()) + .item() + == 2 + ) + + +def test_missing_equality_on_bools() -> None: + df = pl.DataFrame( + { + "a": [True, None, False], + } + ) + + assert df.select(pl.col("a").ne_missing(True))["a"].to_list() == [False, True, True] + assert df.select(pl.col("a").ne_missing(False))["a"].to_list() == [ + True, + True, + False, + ] + + +def test_struct_equality_18870() -> None: + s = pl.Series([{"a": 1}, None]) + + # eq + result = s.eq(s).to_list() + expected = [True, None] + assert result == expected + + # ne + result = s.ne(s).to_list() + expected = [False, None] + assert result == expected + + # eq_missing + result = s.eq_missing(s).to_list() + expected = [True, True] + assert result == expected + + # ne_missing + result = s.ne_missing(s).to_list() + expected = [False, False] + assert result == expected + + +def test_struct_nested_equality() -> None: + df = pl.DataFrame( + { + "a": [{"foo": 0, "bar": "1"}, {"foo": None, "bar": "1"}, None], + "b": [{"foo": 0, "bar": "1"}] * 3, + } + ) + + # eq + ans = df.select(pl.col("a").eq(pl.col("b"))) + expected = pl.DataFrame({"a": [True, False, None]}) + assert_frame_equal(ans, expected) + + # ne + ans = df.select(pl.col("a").ne(pl.col("b"))) + expected = pl.DataFrame({"a": [False, True, None]}) + assert_frame_equal(ans, expected) + + +def isnan(x: Any) -> bool: + return isinstance(x, float) and math.isnan(x) + + +def reference_ordering_propagating(lhs: Any, rhs: Any) -> str | None: + # normal < nan, nan == nan, nulls propagate + if lhs is None or rhs is None: + return None + + if isnan(lhs) and isnan(rhs): + return "=" + + if isnan(lhs) or lhs > rhs: + return ">" + + if isnan(rhs) or lhs < rhs: + return "<" + + return "=" + + +def reference_ordering_missing(lhs: Any, rhs: Any) -> str: + # null < normal < nan, nan == nan, null == null + if lhs is None and rhs is None: + return "=" + + if lhs is None: + return "<" + + if rhs is None: + return ">" + + if isnan(lhs) and isnan(rhs): + return "=" + + if isnan(lhs) or lhs > rhs: + return ">" + + if isnan(rhs) or lhs < rhs: + return "<" + + return "=" + + +def verify_total_ordering( + lhs: Any, rhs: Any, dummy: Any, dtype: PolarsDataType +) -> None: + ref = reference_ordering_propagating(lhs, rhs) + refmiss = reference_ordering_missing(lhs, rhs) + + # Add dummy variable so we don't broadcast or do full-null optimization. + assert dummy is not None + df = pl.DataFrame( + {"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": dtype, "r": dtype} + ) + + ans = df.select( + (pl.col("l") == pl.col("r")).alias("eq"), + (pl.col("l") != pl.col("r")).alias("ne"), + (pl.col("l") < pl.col("r")).alias("lt"), + (pl.col("l") <= pl.col("r")).alias("le"), + (pl.col("l") > pl.col("r")).alias("gt"), + (pl.col("l") >= pl.col("r")).alias("ge"), + pl.col("l").eq_missing(pl.col("r")).alias("eq_missing"), + pl.col("l").ne_missing(pl.col("r")).alias("ne_missing"), + ) + + ans_correct_dict = { + "eq": [ref and ref == "="], # "ref and X" propagates ref is None + "ne": [ref and ref != "="], + "lt": [ref and ref == "<"], + "le": [ref and (ref == "<" or ref == "=")], + "gt": [ref and ref == ">"], + "ge": [ref and (ref == ">" or ref == "=")], + "eq_missing": [refmiss == "="], + "ne_missing": [refmiss != "="], + } + ans_correct = pl.DataFrame( + ans_correct_dict, schema=dict.fromkeys(ans_correct_dict, pl.Boolean) + ) + + assert_frame_equal(ans[:1], ans_correct) + + +def verify_total_ordering_broadcast( + lhs: Any, rhs: Any, dummy: Any, dtype: PolarsDataType +) -> None: + ref = reference_ordering_propagating(lhs, rhs) + refmiss = reference_ordering_missing(lhs, rhs) + + # Add dummy variable so we don't broadcast inherently. + assert dummy is not None + df = pl.DataFrame( + {"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": dtype, "r": dtype} + ) + + ans_first = df.select( + (pl.col("l") == pl.col("r").first()).alias("eq"), + (pl.col("l") != pl.col("r").first()).alias("ne"), + (pl.col("l") < pl.col("r").first()).alias("lt"), + (pl.col("l") <= pl.col("r").first()).alias("le"), + (pl.col("l") > pl.col("r").first()).alias("gt"), + (pl.col("l") >= pl.col("r").first()).alias("ge"), + pl.col("l").eq_missing(pl.col("r").first()).alias("eq_missing"), + pl.col("l").ne_missing(pl.col("r").first()).alias("ne_missing"), + ) + + ans_scalar = df.select( + (pl.col("l") == rhs).alias("eq"), + (pl.col("l") != rhs).alias("ne"), + (pl.col("l") < rhs).alias("lt"), + (pl.col("l") <= rhs).alias("le"), + (pl.col("l") > rhs).alias("gt"), + (pl.col("l") >= rhs).alias("ge"), + (pl.col("l").eq_missing(rhs)).alias("eq_missing"), + (pl.col("l").ne_missing(rhs)).alias("ne_missing"), + ) + + ans_correct_dict = { + "eq": [ref and ref == "="], # "ref and X" propagates ref is None + "ne": [ref and ref != "="], + "lt": [ref and ref == "<"], + "le": [ref and (ref == "<" or ref == "=")], + "gt": [ref and ref == ">"], + "ge": [ref and (ref == ">" or ref == "=")], + "eq_missing": [refmiss == "="], + "ne_missing": [refmiss != "="], + } + ans_correct = pl.DataFrame( + ans_correct_dict, schema=dict.fromkeys(ans_correct_dict, pl.Boolean) + ) + + assert_frame_equal(ans_first[:1], ans_correct) + assert_frame_equal(ans_scalar[:1], ans_correct) + + +INTERESTING_FLOAT_VALUES = [ + 0.0, + -0.0, + -1.0, + 1.0, + -float("nan"), + float("nan"), + -float("inf"), + float("inf"), + None, +] + + +@pytest.mark.slow +@pytest.mark.parametrize("lhs", INTERESTING_FLOAT_VALUES) +@pytest.mark.parametrize("rhs", INTERESTING_FLOAT_VALUES) +def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> None: + verify_total_ordering(lhs, rhs, 0.0, pl.Float32) + verify_total_ordering(lhs, rhs, 0.0, pl.Float64) + context: pytest.WarningsRecorder | ContextManager[None] = ( + pytest.warns(UserWarning) if rhs is None else nullcontext() + ) + with context: + verify_total_ordering_broadcast(lhs, rhs, 0.0, pl.Float32) + verify_total_ordering_broadcast(lhs, rhs, 0.0, pl.Float64) + + +INTERESTING_STRING_VALUES = [ + "", + "foo", + "bar", + "fooo", + "fooooooooooo", + "foooooooooooo", + "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom", + "foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo", + "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop", + None, +] + + +@pytest.mark.slow +@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES) +@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES) +def test_total_ordering_string_series(lhs: str | None, rhs: str | None) -> None: + verify_total_ordering(lhs, rhs, "", pl.String) + context: pytest.WarningsRecorder | ContextManager[None] = ( + pytest.warns(UserWarning) if rhs is None else nullcontext() + ) + with context: + verify_total_ordering_broadcast(lhs, rhs, "", pl.String) + + +@pytest.mark.slow +@pytest.mark.parametrize("str_lhs", INTERESTING_STRING_VALUES) +@pytest.mark.parametrize("str_rhs", INTERESTING_STRING_VALUES) +def test_total_ordering_binary_series(str_lhs: str | None, str_rhs: str | None) -> None: + lhs = None if str_lhs is None else str_lhs.encode("utf-8") + rhs = None if str_rhs is None else str_rhs.encode("utf-8") + verify_total_ordering(lhs, rhs, b"", pl.Binary) + context: pytest.WarningsRecorder | ContextManager[None] = ( + pytest.warns(UserWarning) if rhs is None else nullcontext() + ) + with context: + verify_total_ordering_broadcast(lhs, rhs, b"", pl.Binary) + + +@pytest.mark.parametrize("lhs", [None, False, True]) +@pytest.mark.parametrize("rhs", [None, False, True]) +def test_total_ordering_bool_series(lhs: bool | None, rhs: bool | None) -> None: + verify_total_ordering(lhs, rhs, False, pl.Boolean) + context: pytest.WarningsRecorder | ContextManager[None] = ( + pytest.warns(UserWarning) if rhs is None else nullcontext() + ) + with context: + verify_total_ordering_broadcast(lhs, rhs, False, pl.Boolean) + + +def test_cat_compare_with_bool() -> None: + data = pl.DataFrame([pl.Series("col1", ["a", "b"], dtype=pl.Categorical)]) + + with pytest.raises(ComputeError, match="cannot compare categorical with bool"): + data.filter(pl.col("col1") == True) # noqa: E712 + + +def test_schema_ne_missing_9256() -> None: + df = pl.DataFrame({"a": [0, 1, None], "b": [True, False, True]}) + + assert df.select(pl.col("a").ne_missing(0).or_(pl.col("b")))["a"].all() + + +def test_nested_binary_literal_super_type_12227() -> None: + # The `.alias` is important here to trigger the bug. + result = pl.select(x=1).select((pl.lit(0) + ((pl.col("x") > 0) * 0.1)).alias("x")) + assert result.item() == 0.1 + + result = pl.select((pl.lit(0) + (pl.lit(0) == pl.lit(0)) * pl.lit(0.1)) + pl.lit(0)) + assert result.item() == 0.1 + + +def test_struct_broadcasting_comparison() -> None: + df = pl.DataFrame({"foo": [{"a": 1}, {"a": 2}, {"a": 1}]}) + assert df.select(eq=pl.col.foo == pl.col.foo.last()).to_dict(as_series=False) == { + "eq": [True, False, True] + } + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 1)]) +def test_compare_list_broadcast_empty_first_chunk_20165(dtype: pl.DataType) -> None: + s = pl.concat(2 * [pl.Series([[1]], dtype=dtype)]).filter([False, True]) + + assert s.len() == 1 + assert s.n_chunks() == 2 + + assert_series_equal( + pl.select(pl.lit(pl.Series([[1], [2]]), dtype=dtype) == pl.lit(s)).to_series(), + pl.Series([True, False]), + ) diff --git a/py-polars/tests/unit/operations/test_concat.py b/py-polars/tests/unit/operations/test_concat.py new file mode 100644 index 000000000000..5158fd251b5d --- /dev/null +++ b/py-polars/tests/unit/operations/test_concat.py @@ -0,0 +1,109 @@ +import io +from typing import IO + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_concat_invalid_schema_err_20355() -> None: + lf1 = pl.LazyFrame({"x": [1], "y": [None]}) + lf2 = pl.LazyFrame({"y": [1]}) + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.concat([lf1, lf2]).collect(engine="streaming") + + +def test_concat_df() -> None: + df1 = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) + df2 = pl.concat([df1, df1], rechunk=True) + + assert df2.shape == (6, 3) + assert df2.n_chunks() == 1 + assert df2.rows() == df1.rows() + df1.rows() + assert pl.concat([df1, df1], rechunk=False).n_chunks() == 2 + + # concat from generator of frames + df3 = pl.concat(items=(df1 for _ in range(2))) + assert_frame_equal(df2, df3) + + # check that df4 is not modified following concat of itself + df4 = pl.from_records(((1, 2), (1, 2))) + _ = pl.concat([df4, df4, df4]) + + assert df4.shape == (2, 2) + assert df4.rows() == [(1, 1), (2, 2)] + + # misc error conditions + with pytest.raises(ValueError): + _ = pl.concat([]) + + with pytest.raises(ValueError): + pl.concat([df1, df1], how="rubbish") # type: ignore[arg-type] + + +def test_concat_to_empty() -> None: + assert pl.concat([pl.DataFrame([]), pl.DataFrame({"a": [1]})]).to_dict( + as_series=False + ) == {"a": [1]} + + +def test_concat_multiple_parquet_inmem() -> None: + f = io.BytesIO() + g = io.BytesIO() + + df1 = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["xyz", "abc", "wow"], + } + ) + df2 = pl.DataFrame( + { + "a": [5, 6, 7], + "b": ["a", "few", "entries"], + } + ) + + dfs = pl.concat([df1, df2]) + + df1.write_parquet(f) + df2.write_parquet(g) + + f.seek(0) + g.seek(0) + + items: list[IO[bytes]] = [f, g] + assert_frame_equal(pl.read_parquet(items), dfs) + + f.seek(0) + g.seek(0) + + assert_frame_equal(pl.read_parquet(items, use_pyarrow=True), dfs) + + f.seek(0) + g.seek(0) + + fb = f.read() + gb = g.read() + + assert_frame_equal(pl.read_parquet([fb, gb]), dfs) + assert_frame_equal(pl.read_parquet([fb, gb], use_pyarrow=True), dfs) + + +def test_concat_series() -> None: + s = pl.Series("a", [2, 1, 3]) + + assert pl.concat([s, s]).len() == 6 + # check if s remains unchanged + assert s.len() == 3 + + +def test_concat_null_20501() -> None: + a = pl.DataFrame({"id": [1], "value": ["foo"]}) + b = pl.DataFrame({"id": [2], "value": [None]}) + + assert pl.concat([a.lazy(), b.lazy()]).collect().to_dict(as_series=False) == { + "id": [1, 2], + "value": ["foo", None], + } diff --git a/py-polars/tests/unit/operations/test_cross_join.py b/py-polars/tests/unit/operations/test_cross_join.py new file mode 100644 index 000000000000..9913ab3f1094 --- /dev/null +++ b/py-polars/tests/unit/operations/test_cross_join.py @@ -0,0 +1,77 @@ +from datetime import datetime +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_cross_join_predicate_pushdown_block_16956() -> None: + lf = pl.LazyFrame( + [ + [1718085600000, 1718172000000, 1718776800000], + [1718114400000, 1718200800000, 1718805600000], + ], + schema=["start_datetime", "end_datetime"], + ).cast(pl.Datetime("ms", "Europe/Amsterdam")) + + assert ( + lf.join(lf, how="cross") + .filter( + pl.col.end_datetime_right.is_between( + pl.col.start_datetime, pl.col.start_datetime.dt.offset_by("132h") + ) + ) + .select("start_datetime", "end_datetime_right") + ).collect(predicate_pushdown=True).to_dict(as_series=False) == { + "start_datetime": [ + datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + datetime(2024, 6, 11, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + datetime(2024, 6, 12, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + datetime(2024, 6, 19, 8, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + ], + "end_datetime_right": [ + datetime(2024, 6, 11, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + datetime(2024, 6, 12, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + datetime(2024, 6, 19, 16, 0, tzinfo=ZoneInfo(key="Europe/Amsterdam")), + ], + } + + +def test_cross_join_raise_on_keys() -> None: + df = pl.DataFrame({"a": [0, 1], "b": ["x", "y"]}) + + with pytest.raises(ValueError): + df.join(df, how="cross", left_on="a", right_on="b") + + +def test_nested_loop_join() -> None: + left = pl.LazyFrame( + { + "a": [1, 2, 1, 3], + "b": [1, 2, 3, 4], + } + ) + right = pl.LazyFrame( + { + "c": [4, 1, 2], + "d": [1, 2, 3], + } + ) + + actual = left.join_where(right, pl.col("a") != pl.col("c")) + plan = actual.explain() + assert "NESTED LOOP JOIN" in plan + expected = pl.DataFrame( + { + "a": [1, 1, 2, 2, 1, 1, 3, 3, 3], + "b": [1, 1, 2, 2, 3, 3, 4, 4, 4], + "c": [4, 2, 4, 1, 4, 2, 4, 1, 2], + "d": [1, 3, 1, 2, 1, 3, 1, 2, 3], + } + ) + assert_frame_equal( + actual.collect(), expected, check_row_order=False, check_exact=True + ) diff --git a/py-polars/tests/unit/operations/test_cut.py b/py-polars/tests/unit/operations/test_cut.py new file mode 100644 index 000000000000..a8e79bc7271f --- /dev/null +++ b/py-polars/tests/unit/operations/test_cut.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import with_string_cache_if_auto_streaming + +inf = float("inf") + + +def test_cut() -> None: + s = pl.Series("a", [-2, -1, 0, 1, 2]) + + result = s.cut([-1, 1]) + + expected = pl.Series( + "a", + [ + "(-inf, -1]", + "(-inf, -1]", + "(-1, 1]", + "(-1, 1]", + "(1, inf]", + ], + dtype=pl.Categorical, + ) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_cut_lazy_schema() -> None: + lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]}) + + result = lf.select(pl.col("a").cut([-1, 1])) + + expected = pl.LazyFrame( + {"a": ["(-inf, -1]", "(-inf, -1]", "(-1, 1]", "(-1, 1]", "(1, inf]"]}, + schema={"a": pl.Categorical}, + ) + assert_frame_equal(result, expected, categorical_as_str=True) + + +def test_cut_include_breaks() -> None: + s = pl.Series("a", [-2, -1, 0, 1, 2]) + + out = s.cut([-1.5, 0.25, 1.0], labels=["a", "b", "c", "d"], include_breaks=True) + + expected = pl.DataFrame( + { + "breakpoint": [-1.5, 0.25, 0.25, 1.0, inf], + "category": ["a", "b", "b", "c", "d"], + }, + schema_overrides={"category": pl.Categorical}, + ).to_struct("a") + assert_series_equal(out, expected, categorical_as_str=True) + + +# https://github.com/pola-rs/polars/issues/11255 +def test_cut_include_breaks_lazy_schema() -> None: + lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]}) + + result = lf.select( + pl.col("a").cut([-1, 1], include_breaks=True).alias("cut") + ).unnest("cut") + + expected = pl.LazyFrame( + { + "breakpoint": [-1.0, -1.0, 1.0, 1.0, inf], + "category": ["(-inf, -1]", "(-inf, -1]", "(-1, 1]", "(-1, 1]", "(1, inf]"], + }, + schema_overrides={"category": pl.Categorical}, + ) + assert_frame_equal(result, expected, categorical_as_str=True) + + +def test_cut_null_values() -> None: + s = pl.Series([-1.0, None, 1.0, 2.0, None, 8.0, 4.0]) + + result = s.cut([1.5, 5.0], labels=["a", "b", "c"]) + + expected = pl.Series(["a", None, "a", "b", None, "c", "b"], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_cut_bin_name_in_agg_context() -> None: + df = pl.DataFrame({"a": [1]}).select( + cut=pl.col("a").cut([1, 2], include_breaks=True).over(1), + qcut=pl.col("a").qcut([1], include_breaks=True).over(1), + qcut_uniform=pl.col("a").qcut(1, include_breaks=True).over(1), + ) + schema = pl.Struct( + {"breakpoint": pl.Float64, "category": pl.Categorical("physical")} + ) + assert df.schema == {"cut": schema, "qcut": schema, "qcut_uniform": schema} + + +@pytest.mark.parametrize( + ("breaks", "expected_labels", "expected_physical", "expected_unique"), + [ + ( + [2, 4], + pl.Series("x", ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]), + pl.Series("x", [0, 0, 1, 1, 2], dtype=pl.UInt32), + 3, + ), + ( + [99, 101], + pl.Series("x", 5 * ["(-inf, 99]"]), + pl.Series("x", 5 * [0], dtype=pl.UInt32), + 1, + ), + ], +) +@with_string_cache_if_auto_streaming +def test_cut_fast_unique_15981( + breaks: list[int], + expected_labels: pl.Series, + expected_physical: pl.Series, + expected_unique: int, +) -> None: + s = pl.Series("x", [1, 2, 3, 4, 5]) + + include_breaks = False + s_cut = s.cut(breaks, include_breaks=include_breaks) + + assert_series_equal(s_cut.cast(pl.String), expected_labels) + assert_series_equal(s_cut.to_physical(), expected_physical) + assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique + s_cut.to_frame().group_by(s.name).len() + + include_breaks = True + s_cut = ( + s.cut(breaks, include_breaks=include_breaks).struct.field("category").alias("x") + ) + + assert_series_equal(s_cut.cast(pl.String), expected_labels) + assert_series_equal(s_cut.to_physical(), expected_physical) + assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique + s_cut.to_frame().group_by(s.name).len() diff --git a/py-polars/tests/unit/operations/test_diff.py b/py-polars/tests/unit/operations/test_diff.py new file mode 100644 index 000000000000..8eb441e0b695 --- /dev/null +++ b/py-polars/tests/unit/operations/test_diff.py @@ -0,0 +1,34 @@ +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + + +def test_diff_duration_dtype() -> None: + data = ["2022-01-01", "2022-01-02", "2022-01-03", "2022-01-03"] + df = pl.Series("date", data).str.to_date("%Y-%m-%d").to_frame() + + result = df.select(pl.col("date").diff() < pl.duration(days=1)) + + expected = pl.Series("date", [None, False, False, True]).to_frame() + assert_frame_equal(result, expected) + + +def test_diff_scalarity() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 2, 2, 3, 0], + "n": [1, 3, 2, 4, 5, 1, 1], + } + ) + + with pytest.raises(ComputeError, match="'n' must be scalar value"): + df.select(pl.col("a").diff("n")) + + result = df.select(pl.col("a").diff(pl.col("n").mean().cast(pl.Int32))) + expected = pl.DataFrame({"a": [None, None, 2, 0, -1, 1, -2]}) + assert_frame_equal(result, expected) + + result = df.select(pl.col("a").diff(2)) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_drop.py b/py-polars/tests/unit/operations/test_drop.py new file mode 100644 index 000000000000..fec5c16c2bf0 --- /dev/null +++ b/py-polars/tests/unit/operations/test_drop.py @@ -0,0 +1,172 @@ +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.testing import assert_frame_equal + + +def test_drop() -> None: + df = pl.DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"], "c": [1, 2, 3]}) + df = df.drop("a") + assert df.shape == (3, 2) + + df = pl.DataFrame({"a": [2, 1, 3], "b": ["a", "b", "c"], "c": [1, 2, 3]}) + s = df.drop_in_place("a") + assert s.name == "a" + + +def test_drop_explode_6641() -> None: + df = pl.DataFrame( + { + "chromosome": ["chr1"] * 2, + "identifier": [["chr1:10426:10429:ACC>A"], ["chr1:10426:10429:ACC>*"]], + "alternate": [["A"], ["T"]], + "quality": pl.Series([None, None], dtype=pl.Float32()), + } + ).lazy() + + assert ( + df.explode(["identifier", "alternate"]) + .with_columns(pl.struct(["identifier", "alternate"]).alias("test")) + .drop(["identifier", "alternate"]) + .select(pl.concat_list([pl.col("test"), pl.col("test")])) + .collect() + ).to_dict(as_series=False) == { + "test": [ + [ + {"identifier": "chr1:10426:10429:ACC>A", "alternate": "A"}, + {"identifier": "chr1:10426:10429:ACC>A", "alternate": "A"}, + ], + [ + {"identifier": "chr1:10426:10429:ACC>*", "alternate": "T"}, + {"identifier": "chr1:10426:10429:ACC>*", "alternate": "T"}, + ], + ] + } + + +@pytest.mark.parametrize( + "subset", + [ + "foo", + ["foo"], + {"foo"}, + ], +) +def test_drop_nulls(subset: Any) -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6, None, 8], + "ham": ["a", "b", "c"], + } + ) + result = df.drop_nulls() + expected = pl.DataFrame( + { + "foo": [1, 3], + "bar": [6, 8], + "ham": ["a", "c"], + } + ) + assert_frame_equal(result, expected) + + # below we only drop entries if they are null in the column 'foo' + result = df.drop_nulls(subset) + assert_frame_equal(result, df) + + +def test_drop_nulls_lazy() -> None: + lf = pl.LazyFrame({"foo": [1, 2, 3], "bar": [6, None, 8], "ham": ["a", "b", "c"]}) + expected = pl.LazyFrame({"foo": [1, 3], "bar": [6, 8], "ham": ["a", "c"]}) + + result = lf.drop_nulls() + assert_frame_equal(result, expected) + + result = lf.drop_nulls(cs.contains("a")) + assert_frame_equal(result, expected) + + +def test_drop_nulls_misc() -> None: + df = pl.DataFrame({"nrs": [None, 1, 2, 3, None, 4, 5, None]}) + assert df.select(pl.col("nrs").drop_nulls()).to_dict(as_series=False) == { + "nrs": [1, 2, 3, 4, 5] + } + + +def test_drop_columns() -> None: + out = pl.LazyFrame({"a": [1], "b": [2], "c": [3]}).drop(["a", "b"]) + assert out.collect_schema().names() == ["c"] + + out = pl.LazyFrame({"a": [1], "b": [2], "c": [3]}).drop(~cs.starts_with("c")) + assert out.collect_schema().names() == ["c"] + + out = pl.LazyFrame({"a": [1], "b": [2], "c": [3]}).drop("a") + assert out.collect_schema().names() == ["b", "c"] + + out2 = pl.DataFrame({"a": [1], "b": [2], "c": [3]}).drop("a", "b") + assert out2.collect_schema().names() == ["c"] + + out2 = pl.DataFrame({"a": [1], "b": [2], "c": [3]}).drop({"a", "b", "c"}) + assert out2.collect_schema().names() == [] + + +@pytest.mark.parametrize("lazy", [True, False]) +def test_drop_nans(lazy: bool) -> None: + DataFrame = pl.LazyFrame if lazy else pl.DataFrame + df = DataFrame( + { + "a": [1.0, float("nan"), 3.0, 4.0], + "b": [10000, 20000, 30000, 40000], + "c": [-90.5, 25.0, 0.0, float("nan")], + } + ) + expected = DataFrame( + { + "a": [1.0, 3.0], + "b": [10000, 30000], + "c": [-90.5, 0.0], + } + ) + assert_frame_equal(expected, df.drop_nans()) + + expected = DataFrame( + { + "a": [1.0, float("nan"), 3.0], + "b": [10000, 20000, 30000], + "c": [-90.5, 25.0, 0.0], + } + ) + assert_frame_equal(expected, df.drop_nans(subset=["c"])) + assert_frame_equal(expected, df.drop_nans(subset=cs.ends_with("c"))) + + +def test_drop_nan_ignore_null_3525() -> None: + df = pl.DataFrame({"a": [1.0, float("nan"), 2.0, None, 3.0, 4.0]}) + assert df.select(pl.col("a").drop_nans()).to_series().to_list() == [ + 1.0, + 2.0, + None, + 3.0, + 4.0, + ] + + +def test_drop_without_parameters() -> None: + df = pl.DataFrame({"a": [1, 2]}) + assert_frame_equal(df.drop(), df) + assert_frame_equal(df.lazy().drop(*[]), df.lazy()) + + +def test_drop_strict() -> None: + df = pl.DataFrame({"a": [1, 2]}) + + df.drop("a") + + with pytest.raises(pl.exceptions.ColumnNotFoundError, match="b"): + df.drop("b") + + df.drop("a", strict=False) + df.drop("b", strict=False) diff --git a/py-polars/tests/unit/operations/test_drop_nulls.py b/py-polars/tests/unit/operations/test_drop_nulls.py new file mode 100644 index 000000000000..6cdfc071fa31 --- /dev/null +++ b/py-polars/tests/unit/operations/test_drop_nulls.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series + + +@given( + s=series( + allow_null=True, + excluded_dtypes=[ + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ], + ) +) +def test_drop_nulls_parametric(s: pl.Series) -> None: + result = s.drop_nulls() + assert result.len() == s.len() - s.null_count() + + filter_result = s.filter(s.is_not_null()) + assert_series_equal(result, filter_result) + + +def test_df_drop_nulls_struct() -> None: + df = pl.DataFrame( + {"x": [{"a": 1, "b": 2}, {"a": 1, "b": None}, {"a": None, "b": 2}, None]} + ) + + result = df.drop_nulls() + + expected = df.head(3) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py new file mode 100644 index 000000000000..4b209e64324c --- /dev/null +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +from typing import Any + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import given + +import polars as pl +from polars.expr.expr import _prepare_alpha +from polars.testing import assert_series_equal +from polars.testing.parametric import series + + +def test_ewm_mean() -> None: + s = pl.Series([2, 5, 3]) + + expected = pl.Series([2.0, 4.0, 3.4285714285714284]) + assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected + ) + + expected = pl.Series([2.0, 3.8, 3.421053]) + assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=True), expected) + assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=False), expected) + + expected = pl.Series([2.0, 3.5, 3.25]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected + ) + + s = pl.Series([2, 3, 5, 7, 4]) + + expected = pl.Series([None, 2.666667, 4.0, 5.6, 4.774194]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_samples=2, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_samples=2, ignore_nulls=False), expected + ) + + expected = pl.Series([None, None, 4.0, 5.6, 4.774194]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_samples=3, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_samples=3, ignore_nulls=False), expected + ) + + s = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4]) + + expected = pl.Series( + [ + None, + 1.0, + 3.6666666666666665, + 5.571428571428571, + None, + 3.6666666666666665, + 4.354838709677419, + 4.174603174603175, + ], + ) + assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) + expected = pl.Series( + [ + None, + 1.0, + 3.666666666666667, + 5.571428571428571, + None, + 3.08695652173913, + 4.2, + 4.092436974789916, + ] + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected + ) + + expected = pl.Series([None, 1.0, 3.0, 5.0, None, 3.5, 4.25, 4.125]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected + ) + + expected = pl.Series([None, 1.0, 3.0, 5.0, None, 3.0, 4.0, 4.0]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected + ) + + +def test_ewm_mean_leading_nulls() -> None: + for min_samples in [1, 2, 3]: + assert ( + pl.Series([1, 2, 3, 4]) + .ewm_mean(com=3, min_samples=min_samples, ignore_nulls=False) + .null_count() + == min_samples - 1 + ) + assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( + alpha=0.5, min_samples=1, ignore_nulls=True + ).to_list() == [None, 1.0, 1.0, 1.0] + assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( + alpha=0.5, min_samples=2, ignore_nulls=True + ).to_list() == [None, None, 1.0, 1.0] + + +def test_ewm_mean_min_samples() -> None: + series = pl.Series([1.0, None, None, None]) + + ewm_mean = series.ewm_mean(alpha=0.5, min_samples=1, ignore_nulls=True) + assert ewm_mean.to_list() == [1.0, None, None, None] + ewm_mean = series.ewm_mean(alpha=0.5, min_samples=2, ignore_nulls=True) + assert ewm_mean.to_list() == [None, None, None, None] + + series = pl.Series([1.0, None, 2.0, None, 3.0]) + + ewm_mean = series.ewm_mean(alpha=0.5, min_samples=1, ignore_nulls=True) + assert_series_equal( + ewm_mean, + pl.Series( + [ + 1.0, + None, + 1.6666666666666665, + None, + 2.4285714285714284, + ] + ), + ) + ewm_mean = series.ewm_mean(alpha=0.5, min_samples=2, ignore_nulls=True) + assert_series_equal( + ewm_mean, + pl.Series( + [ + None, + None, + 1.6666666666666665, + None, + 2.4285714285714284, + ] + ), + ) + + +def test_ewm_std_var() -> None: + series = pl.Series("a", [2, 5, 3]) + + var = series.ewm_var(alpha=0.5, ignore_nulls=False) + std = series.ewm_std(alpha=0.5, ignore_nulls=False) + expected = pl.Series("a", [0.0, 4.5, 1.9285714285714288]) + assert np.allclose(var, std**2, rtol=1e-16) + assert_series_equal(var, expected) + + +def test_ewm_std_var_with_nulls() -> None: + series = pl.Series("a", [2, 5, None, 3]) + + var = series.ewm_var(alpha=0.5, ignore_nulls=True) + std = series.ewm_std(alpha=0.5, ignore_nulls=True) + expected = pl.Series("a", [0.0, 4.5, None, 1.9285714285714288]) + assert_series_equal(var, expected) + assert_series_equal(std**2, expected) + + var = series.ewm_var(alpha=0.5, ignore_nulls=False) + std = series.ewm_std(alpha=0.5, ignore_nulls=False) + expected = pl.Series("a", [0.0, 4.5, None, 1.7307692307692308]) + assert_series_equal(var, expected) + assert_series_equal(std**2, expected) + + +def test_ewm_param_validation() -> None: + s = pl.Series("values", range(10)) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_std(com=0.5, alpha=0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_var(alpha=0.5, span=1.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `com` >= 0"): + s.ewm_std(com=-0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `span` >= 1"): + s.ewm_mean(span=0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `half_life` > 0"): + s.ewm_var(half_life=0, ignore_nulls=False) + + for alpha in (-0.5, -0.0000001, 0.0, 1.0000001, 1.5): + with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"): + s.ewm_std(alpha=alpha, ignore_nulls=False) + + +# https://github.com/pola-rs/polars/issues/4951 +@pytest.mark.may_fail_auto_streaming +def test_ewm_with_multiple_chunks() -> None: + df0 = pl.DataFrame( + data=[ + ("w", 6.0, 1.0), + ("x", 5.0, 2.0), + ("y", 4.0, 3.0), + ("z", 3.0, 4.0), + ], + schema=["a", "b", "c"], + orient="row", + ).with_columns( + pl.col(pl.Float64).log().diff().name.prefix("ld_"), + ) + assert df0.n_chunks() == 1 + + # NOTE: We aren't testing whether `select` creates two chunks; + # we just need two chunks to properly test `ewm_mean` + df1 = df0.select(["ld_b", "ld_c"]) + assert df1.n_chunks() == 2 + + ewm_std = df1.with_columns( + pl.all().ewm_std(com=20, ignore_nulls=False).name.prefix("ewm_"), + ) + assert ewm_std.null_count().sum_horizontal()[0] == 4 + + +def alpha_guard(**decay_param: float) -> bool: + """Protects against unnecessary noise in small number regime.""" + if not next(iter(decay_param.values())): + return True + alpha = _prepare_alpha(**decay_param) + return ((1 - alpha) if round(alpha) else alpha) > 1e-6 + + +@given( + s=series( + min_size=4, + dtype=pl.Float64, + allow_null=True, + strategy=st.floats(min_value=-1e8, max_value=1e8), + ), + half_life=st.floats(min_value=0, max_value=4, exclude_min=True).filter( + lambda x: alpha_guard(half_life=x) + ), + com=st.floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), + span=st.floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), + ignore_nulls=st.booleans(), + adjust=st.booleans(), + bias=st.booleans(), +) +def test_ewm_methods( + s: pl.Series, + com: float | None, + span: float | None, + half_life: float | None, + ignore_nulls: bool, + adjust: bool, + bias: bool, +) -> None: + # validate a large set of varied EWM calculations + for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]: + alpha = _prepare_alpha(**decay_param) + + # convert parametrically-generated series to pandas, then use that as a + # reference implementation for comparison (after normalising NaN/None) + p = s.to_pandas() + + # note: skip min_samples < 2, due to pandas-side inconsistency: + # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178 + for mp in range(2, len(s), len(s) // 3): + # consolidate ewm parameters + pl_params: dict[str, Any] = { + "min_samples": mp, + "adjust": adjust, + "ignore_nulls": ignore_nulls, + } + pl_params.update(decay_param) + pd_params: dict[str, Any] = { + "min_periods": mp, + "adjust": adjust, + "ignore_nulls": ignore_nulls, + } + pd_params.update(decay_param) + + if "half_life" in pl_params: + pd_params["halflife"] = pd_params.pop("half_life") + if "ignore_nulls" in pl_params: + pd_params["ignore_na"] = pd_params.pop("ignore_nulls") + + # mean: + ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None) + ewm_mean_pd = pl.Series(p.ewm(**pd_params).mean()) + if alpha == 1: + # apply fill-forward to nulls to match pandas + # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124 + ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward") + + assert_series_equal(ewm_mean_pl, ewm_mean_pd, atol=1e-07) + + # std: + ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None) + ewm_std_pd = pl.Series(p.ewm(**pd_params).std(bias=bias)) + assert_series_equal(ewm_std_pl, ewm_std_pd, atol=1e-07) + + # var: + ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None) + ewm_var_pd = pl.Series(p.ewm(**pd_params).var(bias=bias)) + assert_series_equal(ewm_var_pl, ewm_var_pd, atol=1e-07) diff --git a/py-polars/tests/unit/operations/test_ewm_by.py b/py-polars/tests/unit/operations/test_ewm_by.py new file mode 100644 index 000000000000..8004303238d4 --- /dev/null +++ b/py-polars/tests/unit/operations/test_ewm_by.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsIntegerType, TimeUnit + +from zoneinfo import ZoneInfo + + +@pytest.mark.parametrize("sort", [True, False]) +def test_ewma_by_date(sort: bool) -> None: + df = pl.LazyFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + date(2020, 1, 4), + date(2020, 1, 11), + date(2020, 1, 16), + date(2020, 1, 18), + ], + } + ) + if sort: + df = df.sort("times") + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result.collect(), expected) + assert result.collect_schema()["values"] == pl.Float64 + assert result.collect().schema["values"] == pl.Float64 + + +def test_ewma_by_date_constant() -> None: + df = pl.DataFrame( + { + "values": [1, 1, 1], + "times": [ + date(2020, 1, 4), + date(2020, 1, 11), + date(2020, 1, 16), + ], + } + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame({"values": [1.0, 1, 1]}) + assert_frame_equal(result, expected) + + +def test_ewma_f32() -> None: + df = pl.LazyFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + date(2020, 1, 4), + date(2020, 1, 11), + date(2020, 1, 16), + date(2020, 1, 18), + ], + }, + schema_overrides={"values": pl.Float32}, + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]}, + schema_overrides={"values": pl.Float32}, + ) + assert_frame_equal(result.collect(), expected) + assert result.collect_schema()["values"] == pl.Float32 + assert result.collect().schema["values"] == pl.Float32 + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +@pytest.mark.parametrize("time_zone", [None, "UTC"]) +def test_ewma_by_datetime(time_unit: TimeUnit, time_zone: str | None) -> None: + df = pl.DataFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + datetime(2020, 1, 4), + datetime(2020, 1, 11), + datetime(2020, 1, 16), + datetime(2020, 1, 18), + ], + }, + schema_overrides={"times": pl.Datetime(time_unit, time_zone)}, + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_ewma_by_datetime_tz_aware(time_unit: TimeUnit) -> None: + tzinfo = ZoneInfo("Asia/Kathmandu") + df = pl.DataFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + datetime(2020, 1, 4, tzinfo=tzinfo), + datetime(2020, 1, 11, tzinfo=tzinfo), + datetime(2020, 1, 16, tzinfo=tzinfo), + datetime(2020, 1, 18, tzinfo=tzinfo), + ], + }, + schema_overrides={"times": pl.Datetime(time_unit, "Asia/Kathmandu")}, + ) + msg = "expected `half_life` to be a constant duration" + with pytest.raises(InvalidOperationError, match=msg): + df.select( + pl.col("values").ewm_mean_by("times", half_life="2d"), + ) + + result = df.select( + pl.col("values").ewm_mean_by("times", half_life="48h0ns"), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("data_type", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32]) +def test_ewma_by_index(data_type: PolarsIntegerType) -> None: + df = pl.LazyFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + 4, + 11, + 16, + 18, + ], + }, + schema_overrides={"times": data_type}, + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life="2i"), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result.collect(), expected) + assert result.collect_schema()["values"] == pl.Float64 + assert result.collect().schema["values"] == pl.Float64 + + +def test_ewma_by_empty() -> None: + df = pl.DataFrame({"values": []}, schema_overrides={"values": pl.Float64}) + result = df.with_row_index().select( + pl.col("values").ewm_mean_by("index", half_life="2i"), + ) + expected = pl.DataFrame({"values": []}, schema_overrides={"values": pl.Float64}) + assert_frame_equal(result, expected) + + +def test_ewma_by_if_unsorted() -> None: + df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]}) + result = df.with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]}) + assert_frame_equal(result, expected) + + result = df.with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + assert_frame_equal(result, expected) + + result = df.sort("by").with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + assert_frame_equal(result, expected.sort("by")) + + +def test_ewma_by_invalid() -> None: + df = pl.DataFrame({"values": [1, 2]}) + with pytest.raises(InvalidOperationError, match="half_life cannot be negative"): + df.with_row_index().select( + pl.col("values").ewm_mean_by("index", half_life="-2i"), + ) + df = pl.DataFrame({"values": [[1, 2], [3, 4]]}) + with pytest.raises( + InvalidOperationError, match=r"expected series to be Float64, Float32, .*" + ): + df.with_row_index().select( + pl.col("values").ewm_mean_by("index", half_life="2i"), + ) + + +def test_ewma_by_warn_two_chunks() -> None: + df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]}) + df = pl.concat([df, df], rechunk=False) + + result = df.with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + expected = pl.DataFrame({"values": [2.5, 2.0, 2.5, 2], "by": [3, 1, 3, 1]}) + assert_frame_equal(result, expected) + result = df.sort("by").with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + assert_frame_equal(result, expected.sort("by")) + + +def test_ewma_by_multiple_chunks() -> None: + # times contains null + times = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64)) + values = pl.Series([1, 2]).append(pl.Series([3])) + result = values.ewm_mean_by(times, half_life="2i") + expected = pl.Series([1.0, 1.292893, None]) + assert_series_equal(result, expected) + + # values contains null + times = pl.Series([1, 2]).append(pl.Series([3])) + values = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64)) + result = values.ewm_mean_by(times, half_life="2i") + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py new file mode 100644 index 000000000000..4d92fd932f5b --- /dev/null +++ b/py-polars/tests/unit/operations/test_explode.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ShapeError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_explode_multiple() -> None: + df = pl.DataFrame({"a": [[1, 2], [3, 4]], "b": [[5, 6], [7, 8]]}) + + expected = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + assert_frame_equal(df.explode(cs.all()), expected) + assert_frame_equal(df.explode(["a", "b"]), expected) + assert_frame_equal(df.explode("a", "b"), expected) + + +def test_group_by_flatten_list() -> None: + df = pl.DataFrame({"group": ["a", "b", "b"], "values": [[1, 2], [2, 3], [4]]}) + result = df.group_by("group", maintain_order=True).agg(pl.col("values").flatten()) + + expected = pl.DataFrame({"group": ["a", "b"], "values": [[1, 2], [2, 3, 4]]}) + assert_frame_equal(result, expected) + + +def test_explode_empty_df_3402() -> None: + df = pl.DataFrame({"a": pa.array([], type=pa.large_list(pa.int32()))}) + assert df.explode("a").dtypes == [pl.Int32] + + +def test_explode_empty_df_3460() -> None: + df = pl.DataFrame({"a": pa.array([[]], type=pa.large_list(pa.int32()))}) + assert df.explode("a").dtypes == [pl.Int32] + + +def test_explode_empty_df_3902() -> None: + df = pl.DataFrame( + { + "first": [1, 2, 3, 4, 5], + "second": [["a"], [], ["b", "c"], [], ["d", "f", "g"]], + } + ) + expected = pl.DataFrame( + { + "first": [1, 2, 3, 3, 4, 5, 5, 5], + "second": ["a", None, "b", "c", None, "d", "f", "g"], + } + ) + assert_frame_equal(df.explode("second"), expected) + + +def test_explode_empty_list_4003() -> None: + df = pl.DataFrame( + [ + {"id": 1, "nested": []}, + {"id": 2, "nested": [1]}, + {"id": 3, "nested": [2]}, + ] + ) + assert df.explode("nested").to_dict(as_series=False) == { + "id": [1, 2, 3], + "nested": [None, 1, 2], + } + + +def test_explode_empty_list_4107() -> None: + df = pl.DataFrame({"b": [[1], [2], []] * 2}).with_row_index() + + assert_frame_equal( + df.explode(["b"]), df.explode(["b"]).drop("index").with_row_index() + ) + + +def test_explode_correct_for_slice() -> None: + df = pl.DataFrame({"b": [[1, 1], [2, 2], [3, 3], [4, 4]]}) + assert df.slice(2, 2).explode(["b"])["b"].to_list() == [3, 3, 4, 4] + + df = ( + ( + pl.DataFrame({"group": pl.arange(0, 5, eager=True)}).join( + pl.DataFrame( + { + "b": [[1, 2, 3], [2, 3], [4], [1, 2, 3], [0]], + } + ), + how="cross", + ) + ) + .sort("group") + .with_row_index() + ) + expected = pl.DataFrame( + { + "index": [0, 0, 0, 1, 1, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 8, 8, 8, 9], + "group": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "b": [1, 2, 3, 2, 3, 4, 1, 2, 3, 0, 1, 2, 3, 2, 3, 4, 1, 2, 3, 0], + }, + schema_overrides={"index": pl.UInt32}, + ) + assert_frame_equal(df.slice(0, 10).explode(["b"]), expected) + + +def test_sliced_null_explode() -> None: + s = pl.Series("", [[1], [2], [3], [4], [], [6]]) + assert s.slice(2, 4).list.explode().to_list() == [3, 4, None, 6] + assert s.slice(2, 2).list.explode().to_list() == [3, 4] + assert pl.Series("", [[1], [2], None, [4], [], [6]]).slice( + 2, 4 + ).list.explode().to_list() == [None, 4, None, 6] + + s = pl.Series("", [["a"], ["b"], ["c"], ["d"], [], ["e"]]) + assert s.slice(2, 4).list.explode().to_list() == ["c", "d", None, "e"] + assert s.slice(2, 2).list.explode().to_list() == ["c", "d"] + assert pl.Series("", [["a"], ["b"], None, ["d"], [], ["e"]]).slice( + 2, 4 + ).list.explode().to_list() == [None, "d", None, "e"] + + s = pl.Series("", [[False], [False], [True], [False], [], [True]]) + assert s.slice(2, 2).list.explode().to_list() == [True, False] + assert s.slice(2, 4).list.explode().to_list() == [True, False, None, True] + + +def test_explode_in_agg_context() -> None: + df = pl.DataFrame( + {"idxs": [[0], [1], [0, 2]], "array": [[0.0, 3.5], [4.6, 0.0], [0.0, 7.8, 0.0]]} + ) + + assert ( + df.with_row_index() + .explode("idxs") + .group_by("index") + .agg(pl.col("array").flatten()) + ).to_dict(as_series=False) == { + "index": [0, 1, 2], + "array": [[0.0, 3.5], [4.6, 0.0], [0.0, 7.8, 0.0, 0.0, 7.8, 0.0]], + } + + +def test_explode_inner_lists_3985() -> None: + df = pl.DataFrame( + data={"id": [1, 1, 1], "categories": [["a"], ["b"], ["a", "c"]]} + ).lazy() + + assert ( + df.group_by("id") + .agg(pl.col("categories")) + .with_columns(pl.col("categories").list.eval(pl.element().list.explode())) + ).collect().to_dict(as_series=False) == { + "id": [1], + "categories": [["a", "b", "a", "c"]], + } + + +def test_list_struct_explode_6905() -> None: + assert pl.DataFrame( + { + "group": [ + [], + [ + {"params": [1]}, + {"params": []}, + ], + ] + }, + schema={"group": pl.List(pl.Struct([pl.Field("params", pl.List(pl.Int32))]))}, + )["group"].list.explode().to_list() == [ + None, + {"params": [1]}, + {"params": []}, + ] + + +def test_explode_binary() -> None: + assert pl.Series([[1, 2], [3]]).cast( + pl.List(pl.Binary) + ).list.explode().to_list() == [ + b"1", + b"2", + b"3", + ] + + +def test_explode_null_list() -> None: + assert pl.Series([["a"], None], dtype=pl.List(pl.String))[ + 1:2 + ].list.min().to_list() == [None] + + +def test_explode_invalid_element_count() -> None: + df = pl.DataFrame( + { + "col1": [["X", "Y", "Z"], ["F", "G"], ["P"]], + "col2": [["A", "B", "C"], ["C"], ["D", "E"]], + } + ).with_row_index() + with pytest.raises( + ShapeError, match=r"exploded columns must have matching element counts" + ): + df.explode(["col1", "col2"]) + + +def test_logical_explode() -> None: + out = ( + pl.DataFrame( + {"cats": ["Value1", "Value2", "Value1"]}, + schema_overrides={"cats": pl.Categorical}, + ) + .group_by(1) + .agg(pl.struct("cats")) + .explode("cats") + .unnest("cats") + ) + assert out["cats"].dtype == pl.Categorical + assert out["cats"].to_list() == ["Value1", "Value2", "Value1"] + + +def test_explode_inner_null() -> None: + expected = pl.DataFrame({"A": [None, None]}, schema={"A": pl.Null}) + out = pl.DataFrame({"A": [[], []]}, schema={"A": pl.List(pl.Null)}).explode("A") + assert_frame_equal(out, expected) + + +def test_explode_array() -> None: + df = pl.LazyFrame( + {"a": [[1, 2], [2, 3]], "b": [1, 2]}, + schema_overrides={"a": pl.Array(pl.Int64, 2)}, + ) + expected = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1, 1, 2, 2]}) + for ex in ("a", ~cs.integer()): + out = df.explode(ex).collect() + assert_frame_equal(out, expected) + + +def test_string_list_agg_explode() -> None: + df = pl.DataFrame({"a": [[None], ["b"]]}) + + df = df.select( + pl.col("a").list.eval(pl.element().filter(pl.element().is_not_null())) + ) + assert not df["a"].flags["FAST_EXPLODE"] + + df2 = pl.DataFrame({"a": [[], ["b"]]}) + + assert_frame_equal(df, df2) + assert_frame_equal(df.explode("a"), df2.explode("a")) + + +def test_explode_null_struct() -> None: + df = [ + {"col1": None}, + { + "col1": [ + {"field1": None, "field2": None, "field3": None}, + {"field1": None, "field2": "some", "field3": "value"}, + ] + }, + ] + + assert pl.DataFrame(df).explode("col1").to_dict(as_series=False) == { + "col1": [ + None, + {"field1": None, "field2": None, "field3": None}, + {"field1": None, "field2": "some", "field3": "value"}, + ] + } + + +def test_df_explode_with_array() -> None: + df = pl.DataFrame( + { + "arr": [["a", "b"], ["c", None], None, ["d", "e"]], + "list": [[1, 2], [3], [4, None], None], + "val": ["x", "y", "z", "q"], + }, + schema={ + "arr": pl.Array(pl.String, 2), + "list": pl.List(pl.Int64), + "val": pl.String, + }, + ) + + expected_by_arr = pl.DataFrame( + { + "arr": ["a", "b", "c", None, None, "d", "e"], + "list": [[1, 2], [1, 2], [3], [3], [4, None], None, None], + "val": ["x", "x", "y", "y", "z", "q", "q"], + } + ) + assert_frame_equal(df.explode(pl.col("arr")), expected_by_arr) + + expected_by_list = pl.DataFrame( + { + "arr": [["a", "b"], ["a", "b"], ["c", None], None, None, ["d", "e"]], + "list": [1, 2, 3, 4, None, None], + "val": ["x", "x", "y", "z", "z", "q"], + }, + schema={ + "arr": pl.Array(pl.String, 2), + "list": pl.Int64, + "val": pl.String, + }, + ) + assert_frame_equal(df.explode(pl.col("list")), expected_by_list) + + df = pl.DataFrame( + { + "arr": [["a", "b"], ["c", None], None, ["d", "e"]], + "list": [[1, 2], [3, 4], None, [5, None]], + "val": [None, 1, 2, None], + }, + schema={ + "arr": pl.Array(pl.String, 2), + "list": pl.List(pl.Int64), + "val": pl.Int64, + }, + ) + expected_by_arr_and_list = pl.DataFrame( + { + "arr": ["a", "b", "c", None, None, "d", "e"], + "list": [1, 2, 3, 4, None, 5, None], + "val": [None, None, 1, 1, 2, None, None], + }, + schema={ + "arr": pl.String, + "list": pl.Int64, + "val": pl.Int64, + }, + ) + assert_frame_equal(df.explode("arr", "list"), expected_by_arr_and_list) + + +def test_explode_nullable_list() -> None: + df = pl.DataFrame({"layout1": [None, [1, 2]], "b": [False, True]}).with_columns( + layout2=pl.when(pl.col("b")).then([1, 2]), + ) + + explode_df = df.explode("layout1", "layout2") + expected_df = pl.DataFrame( + { + "layout1": [None, 1, 2], + "b": [False, True, True], + "layout2": [None, 1, 2], + } + ) + assert_frame_equal(explode_df, expected_df) + + explode_expr = df.select( + pl.col("layout1").explode(), + pl.col("layout2").explode(), + ) + expected_df = pl.DataFrame( + { + "layout1": [None, 1, 2], + "layout2": [None, 1, 2], + } + ) + assert_frame_equal(explode_expr, expected_df) + + +def test_group_by_flatten_string() -> None: + df = pl.DataFrame({"group": ["a", "b", "b"], "values": ["foo", "bar", "baz"]}) + + result = df.group_by("group", maintain_order=True).agg( + pl.col("values").str.split("").explode() + ) + + expected = pl.DataFrame( + { + "group": ["a", "b"], + "values": [["f", "o", "o"], ["b", "a", "r", "b", "a", "z"]], + } + ) + assert_frame_equal(result, expected) + + +def test_fast_explode_merge_right_16923() -> None: + df = pl.concat( + [ + pl.DataFrame({"foo": [["a", "b"], ["c"]]}), + pl.DataFrame({"foo": [None]}, schema={"foo": pl.List(pl.Utf8)}), + ], + how="diagonal", + rechunk=True, + ).explode("foo") + + assert df.height == 4 + + +def test_fast_explode_merge_left_16923() -> None: + df = pl.concat( + [ + pl.DataFrame({"foo": [None]}, schema={"foo": pl.List(pl.Utf8)}), + pl.DataFrame({"foo": [["a", "b"], ["c"]]}), + ], + how="diagonal", + rechunk=True, + ).explode("foo") + + assert df.height == 4 + + +@pytest.mark.parametrize( + ("values", "exploded"), + [ + (["foobar", None], ["f", "o", "o", "b", "a", "r", None]), + ([None, "foo", "bar"], [None, "f", "o", "o", "b", "a", "r"]), + ( + [None, "foo", "bar", None, "ham"], + [None, "f", "o", "o", "b", "a", "r", None, "h", "a", "m"], + ), + (["foo", "bar", "ham"], ["f", "o", "o", "b", "a", "r", "h", "a", "m"]), + (["", None, "foo", "bar"], ["", None, "f", "o", "o", "b", "a", "r"]), + (["", "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), + ], +) +def test_series_str_explode_deprecated( + values: list[str | None], exploded: list[str | None] +) -> None: + with pytest.deprecated_call(): + result = pl.Series(values).str.explode() + assert result.to_list() == exploded + + +def test_expr_str_explode_deprecated() -> None: + df = pl.Series("a", ["Hello", "World"]) + with pytest.deprecated_call(): + result = df.to_frame().select(pl.col("a").str.explode()).to_series() + + expected = pl.Series("a", ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"]) + assert_series_equal(result, expected) + + +def test_undefined_col_15852() -> None: + lf = pl.LazyFrame({"foo": [1]}) + + with pytest.raises(pl.exceptions.ColumnNotFoundError): + lf.explode("bar").join(lf, on="foo").collect() + + +def test_explode_17648() -> None: + df = pl.DataFrame({"a": [[1, 3], [2, 6, 7], [3, 9, 2], [4], [5, 1, 2, 3, 4]]}) + assert ( + df.slice(1, 2) + .with_columns(pl.int_ranges(pl.col("a").list.len()).alias("count")) + .explode("a", "count") + ).to_dict(as_series=False) == {"a": [2, 6, 7, 3, 9, 2], "count": [0, 1, 2, 0, 1, 2]} + + +def test_explode_struct_nulls() -> None: + df = pl.DataFrame({"A": [[{"B": 1}], [None], []]}) + assert df.explode("A").to_dict(as_series=False) == {"A": [{"B": 1}, None, None]} diff --git a/py-polars/tests/unit/operations/test_extend_constant.py b/py-polars/tests/unit/operations/test_extend_constant.py new file mode 100644 index 000000000000..2106b0dfec83 --- /dev/null +++ b/py-polars/tests/unit/operations/test_extend_constant.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + ("const", "dtype"), + [ + (1, pl.Int8), + (4, pl.UInt32), + (4.5, pl.Float32), + (None, pl.Float64), + ("白鵬翔", pl.String), + (date.today(), pl.Date), + (datetime.now(), pl.Datetime("ns")), + (time(23, 59, 59), pl.Time), + (timedelta(hours=7, seconds=123), pl.Duration("ms")), + ], +) +def test_extend_constant(const: Any, dtype: PolarsDataType) -> None: + df = pl.DataFrame({"a": pl.Series("s", [None], dtype=dtype)}) + + expected_df = pl.DataFrame( + {"a": pl.Series("s", [None, const, const, const], dtype=dtype)} + ) + + assert_frame_equal(df.select(pl.col("a").extend_constant(const, 3)), expected_df) + + s = pl.Series("s", [None], dtype=dtype) + expected = pl.Series("s", [None, const, const, const], dtype=dtype) + assert_series_equal(s.extend_constant(const, 3), expected) + + # test n expr + expected = pl.Series("s", [None, const, const], dtype=dtype) + assert_series_equal(s.extend_constant(const, pl.Series([2])), expected) + + # test value expr + expected = pl.Series("s", [None, const, const, const], dtype=dtype) + assert_series_equal(s.extend_constant(pl.Series([const], dtype=dtype), 3), expected) + + +@pytest.mark.parametrize( + ("const", "dtype"), + [ + (1, pl.Int8), + (4, pl.UInt32), + (4.5, pl.Float32), + (None, pl.Float64), + ("白鵬翔", pl.String), + (date.today(), pl.Date), + (datetime.now(), pl.Datetime("ns")), + (time(23, 59, 59), pl.Time), + (timedelta(hours=7, seconds=123), pl.Duration("ms")), + ], +) +def test_extend_constant_arr(const: Any, dtype: PolarsDataType) -> None: + """ + Test extend_constant in pl.List array. + + NOTE: This function currently fails when the Series is a list with a single [None] + value. Hence, this function does not begin with [[None]], but [[const]]. + """ + s = pl.Series("s", [[const]], dtype=pl.List(dtype)) + + expected = pl.Series("s", [[const, const, const, const]], dtype=pl.List(dtype)) + + assert_series_equal(s.list.eval(pl.element().extend_constant(const, 3)), expected) + + +def test_extend_by_not_uint_expr() -> None: + s = pl.Series("s", [1]) + with pytest.raises(ComputeError, match="value and n should have unit length"): + s.extend_constant(pl.Series([2, 3]), 3) + with pytest.raises(ComputeError, match="value and n should have unit length"): + s.extend_constant(2, pl.Series([3, 4])) diff --git a/py-polars/tests/unit/operations/test_fill_null.py b/py-polars/tests/unit/operations/test_fill_null.py new file mode 100644 index 000000000000..0e19606a2ebc --- /dev/null +++ b/py-polars/tests/unit/operations/test_fill_null.py @@ -0,0 +1,123 @@ +import datetime + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_fill_null_minimal_upcast_4056() -> None: + df = pl.DataFrame({"a": [-1, 2, None]}) + df = df.with_columns(pl.col("a").cast(pl.Int8)) + assert df.with_columns(pl.col(pl.Int8).fill_null(-1)).dtypes[0] == pl.Int8 + assert df.with_columns(pl.col(pl.Int8).fill_null(-1000)).dtypes[0] == pl.Int16 + + +def test_fill_enum_upcast() -> None: + dtype = pl.Enum(["a", "b"]) + s = pl.Series(["a", "b", None], dtype=dtype) + s_filled = s.fill_null("b") + expected = pl.Series(["a", "b", "b"], dtype=dtype) + assert s_filled.dtype == dtype + assert_series_equal(s_filled, expected) + + +def test_fill_null_static_schema_4843() -> None: + df1 = pl.DataFrame( + { + "a": [1, 2, None], + "b": [1, None, 4], + } + ).lazy() + + df2 = df1.select([pl.col(pl.Int64).fill_null(0)]) + df3 = df2.select(pl.col(pl.Int64)) + assert df3.collect_schema() == {"a": pl.Int64, "b": pl.Int64} + + +def test_fill_null_non_lit() -> None: + df = pl.DataFrame( + { + "a": pl.Series([1, None], dtype=pl.Int32), + "b": pl.Series([None, 2], dtype=pl.UInt32), + "c": pl.Series([None, 2], dtype=pl.Int64), + "d": pl.Series([None, 2], dtype=pl.Decimal), + } + ) + assert df.fill_null(0).select(pl.all().null_count()).transpose().sum().item() == 0 + + +def test_fill_null_f32_with_lit() -> None: + # ensure the literal integer does not upcast the f32 to an f64 + df = pl.DataFrame({"a": [1.1, 1.2]}, schema=[("a", pl.Float32)]) + assert df.fill_null(value=0).dtypes == [pl.Float32] + + +def test_fill_null_lit_() -> None: + df = pl.DataFrame( + { + "a": pl.Series([1, None], dtype=pl.Int32), + "b": pl.Series([None, 2], dtype=pl.UInt32), + "c": pl.Series([None, 2], dtype=pl.Int64), + } + ) + assert ( + df.fill_null(pl.lit(0)).select(pl.all().null_count()).transpose().sum().item() + == 0 + ) + + +def test_fill_null_decimal_with_int_14331() -> None: + s = pl.Series("a", ["1.1", None], dtype=pl.Decimal(precision=None, scale=5)) + result = s.fill_null(0) + expected = pl.Series("a", ["1.1", "0.0"], dtype=pl.Decimal(precision=None, scale=5)) + assert_series_equal(result, expected) + + +def test_fill_null_date_with_int_11362() -> None: + match = "got invalid or ambiguous dtypes" + + s = pl.Series([datetime.date(2000, 1, 1)]) + with pytest.raises(pl.exceptions.InvalidOperationError, match=match): + s.fill_null(0) + + s = pl.Series([None], dtype=pl.Date) + with pytest.raises(pl.exceptions.InvalidOperationError, match=match): + s.fill_null(1) + + +def test_fill_null_int_dtype_15546() -> None: + df = pl.Series("a", [1, 2, None], dtype=pl.Int8).to_frame().lazy() + result = df.fill_null(0).collect() + expected = pl.Series("a", [1, 2, 0], dtype=pl.Int8).to_frame() + assert_frame_equal(result, expected) + + +def test_fill_null_with_list_10869() -> None: + assert_series_equal( + pl.Series([[1], None]).fill_null([2]), + pl.Series([[1], [2]]), + ) + + match = "failed to determine supertype" + with pytest.raises(pl.exceptions.SchemaError, match=match): + pl.Series([1, None]).fill_null([2]) + + +def test_unequal_lengths_22018() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([1, None]).fill_null(pl.Series([1] * 3)) + with pytest.raises(pl.exceptions.ShapeError): + pl.Series([1, 2]).fill_null(pl.Series([1] * 3)) + + +def test_self_broadcast() -> None: + assert_series_equal( + pl.Series([1]).fill_null(pl.Series(range(3))), + pl.Series([1] * 3), + ) + + assert_series_equal( + pl.Series([None]).fill_null(pl.Series(range(3))), + pl.Series(range(3)), + ) diff --git a/py-polars/tests/unit/operations/test_filter.py b/py-polars/tests/unit/operations/test_filter.py new file mode 100644 index 000000000000..08f5b13bc0ab --- /dev/null +++ b/py-polars/tests/unit/operations/test_filter.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl +import polars.selectors as cs +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_simplify_expression_lit_true_4376() -> None: + df = pl.DataFrame([[1, 4, 7], [2, 5, 8], [3, 6, 9]]) + assert df.lazy().filter(pl.lit(True) | (pl.col("column_0") == 1)).collect( + simplify_expression=True + ).rows() == [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + assert df.lazy().filter((pl.col("column_0") == 1) | pl.lit(True)).collect( + simplify_expression=True + ).rows() == [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + + +def test_filter_contains_nth_11205() -> None: + df = pl.DataFrame({"x": [False]}) + assert df.filter(pl.first()).is_empty() + + +def test_unpivot_values_predicate_pushdown() -> None: + lf = pl.DataFrame( + { + "id": [1], + "asset_key_1": ["123"], + "asset_key_2": ["456"], + "asset_key_3": ["abc"], + } + ).lazy() + + assert ( + lf.unpivot(index="id", on=["asset_key_1", "asset_key_2", "asset_key_3"]) + .filter(pl.col("value") == pl.lit("123")) + .collect() + ).to_dict(as_series=False) == { + "id": [1], + "variable": ["asset_key_1"], + "value": ["123"], + } + + +def test_group_by_filter_all_true() -> None: + df = pl.DataFrame( + { + "name": ["a", "a", "b", "b"], + "type": [None, 1, 1, None], + "order": [1, 2, 3, 4], + } + ) + out = ( + df.group_by("name") + .agg( + [ + pl.col("order") + .filter(pl.col("order") > 0, type=1) + .n_unique() + .alias("n_unique") + ] + ) + .select("n_unique") + ) + assert out.to_dict(as_series=False) == {"n_unique": [1, 1]} + + +def test_filter_is_in_4572() -> None: + df = pl.DataFrame({"id": [1, 2, 1, 2], "k": ["a"] * 2 + ["b"] * 2}) + expected = df.group_by("id").agg(pl.col("k").filter(k="a").implode()).sort("id") + result = ( + df.group_by("id") + .agg(pl.col("k").filter(pl.col("k").is_in(["a"])).implode()) + .sort("id") + ) + assert_frame_equal(result, expected) + result = ( + df.sort("id") + .group_by("id") + .agg(pl.col("k").filter(pl.col("k").is_in(["a"])).implode()) + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", [pl.Int32, pl.Boolean, pl.String, pl.Binary, pl.List(pl.Int64), pl.Object] +) +def test_filter_on_empty(dtype: PolarsDataType) -> None: + df = pl.DataFrame({"a": []}, schema={"a": dtype}) + out = df.filter(pl.col("a").is_null()) + assert out.is_empty() + + +def test_filter_aggregation_any() -> None: + df = pl.DataFrame( + { + "id": [1, 2, 3, 4], + "group": [1, 2, 1, 1], + "pred_a": [False, True, False, False], + "pred_b": [False, False, True, True], + } + ) + + result = ( + df.group_by("group") + .agg( + pl.any_horizontal("pred_a", "pred_b").alias("any"), + pl.col("id") + .filter(pl.any_horizontal("pred_a", "pred_b")) + .alias("filtered"), + ) + .sort("group") + ) + + assert result.to_dict(as_series=False) == { + "group": [1, 2], + "any": [[False, True, True], [True]], + "filtered": [[3, 4], [2]], + } + + +def test_predicate_order_explode_5950() -> None: + df = pl.from_dict( + { + "i": [[0, 1], [1, 2]], + "n": [0, None], + } + ) + + assert ( + df.lazy() + .explode("i") + .filter(pl.len().over(["i"]) == 2) + .filter(pl.col("n").is_not_null()) + ).collect().to_dict(as_series=False) == {"i": [1], "n": [0]} + + +def test_binary_simplification_5971() -> None: + df = pl.DataFrame(pl.Series("a", [1, 2, 3, 4])) + assert df.select((pl.col("a") > 2) | pl.lit(False))["a"].to_list() == [ + False, + False, + True, + True, + ] + + +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_string_comparison_6283() -> None: + scores = pl.DataFrame( + { + "zone": pl.Series( + [ + "North", + "North", + "North", + "South", + "South", + "East", + "East", + "East", + "East", + ] + ).cast(pl.Categorical), + "funding": pl.Series( + ["yes", "yes", "no", "yes", "no", "no", "no", "yes", "yes"] + ).cast(pl.Categorical), + "score": [78, 39, 76, 56, 67, 89, 100, 55, 80], + } + ) + + assert scores.filter(scores["zone"] == "North").to_dict(as_series=False) == { + "zone": ["North", "North", "North"], + "funding": ["yes", "yes", "no"], + "score": [78, 39, 76], + } + + +def test_clear_window_cache_after_filter_10499() -> None: + df = pl.from_dict( + { + "a": [None, None, 3, None, 5, 0, 0, 0, 9, 10], + "b": [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], + } + ) + + assert df.lazy().filter((pl.col("a").null_count() < pl.len()).over("b")).filter( + ((pl.col("a") == 0).sum() < pl.len()).over("b") + ).collect().to_dict(as_series=False) == { + "a": [3, None, 5, 0, 9, 10], + "b": [2, 2, 3, 3, 5, 5], + } + + +def test_agg_function_of_filter_10565() -> None: + df_int = pl.DataFrame(data={"a": []}, schema={"a": pl.Int16}) + assert df_int.filter(pl.col("a").n_unique().over("a") == 1).to_dict( + as_series=False + ) == {"a": []} + + df_str = pl.DataFrame(data={"a": []}, schema={"a": pl.String}) + assert df_str.filter(pl.col("a").n_unique().over("a") == 1).to_dict( + as_series=False + ) == {"a": []} + + assert df_str.lazy().filter(pl.col("a").n_unique().over("a") == 1).collect( + predicate_pushdown=False + ).to_dict(as_series=False) == {"a": []} + + +def test_filter_logical_type_13194() -> None: + data = { + "id": [1, 1, 2], + "date": [ + [datetime(year=2021, month=1, day=1)], + [datetime(year=2021, month=1, day=1)], + [datetime(year=2025, month=1, day=30)], + ], + "cat": [ + ["a", "b", "c"], + ["a", "b", "c"], + ["d", "e", "f"], + ], + } + + df = pl.DataFrame(data).with_columns(pl.col("cat").cast(pl.List(pl.Categorical()))) + + df = df.filter(pl.col("id") == pl.col("id").shift(1)) + expected_df = pl.DataFrame( + { + "id": [1], + "date": [[datetime(year=2021, month=1, day=1)]], + "cat": [["a", "b", "c"]], + }, + schema={ + "id": pl.Int64, + "date": pl.List(pl.Datetime), + "cat": pl.List(pl.Categorical), + }, + ) + assert_frame_equal(df, expected_df) + + +def test_filter_horizontal_selector_15428() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + df = df.filter(pl.all_horizontal((cs.by_name("^.*$") & cs.integer()) <= 2)) + expected_df = pl.DataFrame({"a": [1, 2]}) + + assert_frame_equal(df, expected_df) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "dtype", [pl.Boolean, pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.String] +) +@pytest.mark.parametrize("size", list(range(64)) + [100, 1000, 10000]) +@pytest.mark.parametrize("selectivity", [0.0, 0.01, 0.1, 0.5, 0.9, 0.99, 1.0 + 1e-6]) +def test_filter(dtype: PolarsDataType, size: int, selectivity: float) -> None: + rng = np.random.Generator(np.random.PCG64(size * 100 + int(100 * selectivity))) + np_payload = rng.uniform(size=size) * 100.0 + np_mask = rng.uniform(size=size) < selectivity + payload = pl.Series(np_payload).cast(dtype) + mask = pl.Series(np_mask, dtype=pl.Boolean) + + reference = pl.Series(np_payload[np_mask]).cast(dtype) + result = payload.filter(mask) + assert_series_equal(reference, result) + + +def test_filter_group_aware_17030() -> None: + df = pl.DataFrame({"foo": ["1", "2", "1", "2", "1", "2"]}) + + trim_col = "foo" + group_count = pl.col(trim_col).count().over(trim_col) + group_cum_count = pl.col(trim_col).cum_count().over(trim_col) + filter_expr = ( + (group_count > 2) & (group_cum_count > 1) & (group_cum_count < group_count) + ) + assert df.filter(filter_expr)["foo"].to_list() == ["1", "2"] + + +def test_invalid_filter_18295() -> None: + codes = ["a"] * 5 + ["b"] * 5 + values = list(range(-2, 3)) + list(range(2, -3, -1)) + df = pl.DataFrame({"code": codes, "value": values}) + with pytest.raises(pl.exceptions.ShapeError): + df.group_by("code").agg( + pl.col("value") + .ewm_mean(span=2, ignore_nulls=True) + .tail(3) + .filter(pl.col("value") > 0), + ).sort("code") + + +def test_filter_19771() -> None: + q = pl.LazyFrame({"a": [None, None]}) + assert q.filter(pl.lit(True)).collect()["a"].to_list() == [None, None] + + +def test_filter_expand_20014() -> None: + n_rows = 1 + date_list = [datetime(2000, 1, 1) + timedelta(days=x) for x in range(n_rows)] + df = pl.DataFrame( + { + "date": date_list, + "col1": [1], + } + ) + + df = df.with_columns(pl.col("date").dt.month().alias("month")) + assert ( + df.lazy() + .filter( + pl.col("month") <= 6, + pl.col("date") >= pl.datetime(2000, 1, 1), + pl.col("date") <= pl.datetime(2020, 1, 1), + ) + .explain(optimized=False) + .count("FILTER") + == 3 + ) diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py new file mode 100644 index 000000000000..3b73e0331bc1 --- /dev/null +++ b/py-polars/tests/unit/operations/test_gather.py @@ -0,0 +1,284 @@ +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_negative_index() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]}) + assert df.select(pl.col("a").gather([0, -1])).to_dict(as_series=False) == { + "a": [1, 6] + } + assert df.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1])).sort( + "a" + ).to_dict(as_series=False) == {"a": [0, 1], "b": [[2, 6], [1, 5]]} + + +def test_gather_agg_schema() -> None: + df = pl.DataFrame( + { + "group": [ + "one", + "one", + "one", + "two", + "two", + "two", + ], + "value": [1, 98, 2, 3, 99, 4], + } + ) + assert ( + df.lazy() + .group_by("group", maintain_order=True) + .agg(pl.col("value").get(1)) + .collect_schema()["value"] + == pl.Int64 + ) + + +def test_gather_lit_single_16535() -> None: + df = pl.DataFrame({"x": [1, 2, 2, 1], "y": [1, 2, 3, 4]}) + + assert df.group_by(["x"], maintain_order=True).agg(pl.all().gather([1])).to_dict( + as_series=False + ) == {"x": [1, 2], "y": [[4], [3]]} + + +def test_list_get_null_offset_17248() -> None: + df = pl.DataFrame({"material": [["PB", "PVC", "CI"], ["CI"], ["CI"]]}) + + assert df.select( + result=pl.when(pl.col.material.list.len() == 1).then("material").list.get(0), + )["result"].to_list() == [None, "CI", "CI"] + + +def test_list_get_null_oob_17252() -> None: + df = pl.DataFrame( + { + "name": ["BOB-3", "BOB", None], + } + ) + + split = df.with_columns(pl.col("name").str.split("-")) + assert split.with_columns(pl.col("name").list.get(0))["name"].to_list() == [ + "BOB", + "BOB", + None, + ] + + +def test_list_get_null_on_oob_false_success() -> None: + # test Series (single offset) with nulls + expected = pl.Series("a", [2, None, 2], dtype=pl.Int64) + s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]]) + out = s_nulls.list.get(1, null_on_oob=False) + assert_series_equal(out, expected) + + # test Expr (multiple offsets) with nulls + df = s_nulls.to_frame().with_columns(pl.lit(1).alias("idx")) + out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series() + assert_series_equal(out, expected) + + # test Series (single offset) with no nulls + expected = pl.Series("a", [2, 2, 2], dtype=pl.Int64) + s_no_nulls = pl.Series("a", [[1, 2], [1, 2], [1, 2, 3]]) + out = s_no_nulls.list.get(1, null_on_oob=False) + assert_series_equal(out, expected) + + # test Expr (multiple offsets) with no nulls + df = s_no_nulls.to_frame().with_columns(pl.lit(1).alias("idx")) + out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series() + assert_series_equal(out, expected) + + +def test_list_get_null_on_oob_false_failure() -> None: + # test Series (single offset) with nulls + s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]]) + with pytest.raises(ComputeError, match="get index is out of bounds"): + s_nulls.list.get(2, null_on_oob=False) + + # test Expr (multiple offsets) with nulls + df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx")) + with pytest.raises(ComputeError, match="get index is out of bounds"): + df.select(pl.col("a").list.get("idx", null_on_oob=False)) + + # test Series (single offset) with no nulls + s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]]) + with pytest.raises(ComputeError, match="get index is out of bounds"): + s_no_nulls.list.get(2, null_on_oob=False) + + # test Expr (multiple offsets) with no nulls + df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx")) + with pytest.raises(ComputeError, match="get index is out of bounds"): + df.select(pl.col("a").list.get("idx", null_on_oob=False)) + + +def test_list_get_null_on_oob_true() -> None: + # test Series (single offset) with nulls + s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]]) + out = s_nulls.list.get(2, null_on_oob=True) + expected = pl.Series("a", [None, None, 3], dtype=pl.Int64) + assert_series_equal(out, expected) + + # test Expr (multiple offsets) with nulls + df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx")) + out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series() + assert_series_equal(out, expected) + + # test Series (single offset) with no nulls + s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]]) + out = s_no_nulls.list.get(2, null_on_oob=True) + expected = pl.Series("a", [None, None, 3], dtype=pl.Int64) + assert_series_equal(out, expected) + + # test Expr (multiple offsets) with no nulls + df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx")) + out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series() + assert_series_equal(out, expected) + + +def test_chunked_gather_phys_repr_17446() -> None: + dfa = pl.DataFrame({"replace_unique_id": range(2)}) + + for dt in [pl.Date, pl.Time, pl.Duration]: + dfb = dfa.clone() + dfb = dfb.with_columns(ds_start_date_right=pl.lit(None).cast(dt)) + dfb = pl.concat([dfb, dfb]) + + assert dfa.join(dfb, how="left", on=pl.col("replace_unique_id")).shape == (4, 2) + + +def test_gather_str_col_18099() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "idx": [0, 0, 1]}) + assert df.with_columns(pl.col("foo").gather("idx")).to_dict(as_series=False) == { + "foo": [1, 1, 2], + "idx": [0, 0, 1], + } + + +def test_gather_list_19243() -> None: + df = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]}) + assert df.with_columns(pl.lit([0]).alias("c")).with_columns( + gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True) + ).to_dict(as_series=False) == { + "a": [[0.1, 0.2, 0.3]], + "c": [[0]], + "gather": [[0.1]], + } + + +def test_gather_array_list_null_19302() -> None: + data = pl.DataFrame( + {"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))} + ) + assert data.select(pl.col("data").list.get(0)).to_dict(as_series=False) == { + "data": [None] + } + + +def test_gather_array() -> None: + a = np.arange(16).reshape(-1, 2, 2) + s = pl.Series(a) + + for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]: + assert (s.gather(idx).to_numpy() == a[idx]).all() + + v = s[[0, 1, None, 3]] # type: ignore[list-item] + assert v[2] is None + + +def test_gather_array_outer_validity_19482() -> None: + s = ( + pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1)) + .to_frame() + .select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first())) + .to_series() + ) + + expect = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1)) + assert_series_equal(s, expect) + assert_series_equal(s.gather([0, 1]), expect) + + +def test_gather_len_19561() -> None: + N = 4 + df = pl.DataFrame({"foo": ["baz"] * N, "bar": range(N)}) + idxs = pl.int_range(1, N).repeat_by(pl.int_range(1, N)).flatten() + gather = pl.col.bar.gather(idxs).alias("gather") + + assert df.group_by("foo").agg(gather.len()).to_dict(as_series=False) == { + "foo": ["baz"], + "gather": [6], + } + + +def test_gather_agg_group_update_scalar() -> None: + # If `gather` doesn't update groups properly, `first` will try to access + # index 2 (the original index of the first element of group `1`), but gather + # outputs only two elements (one for each group), leading to an out of + # bounds access. + df = ( + pl.DataFrame({"gid": [0, 0, 1, 1], "x": ["0:0", "0:1", "1:0", "1:1"]}) + .lazy() + .group_by("gid", maintain_order=True) + .agg(x_at_gid=pl.col("x").gather(pl.col("gid").last()).first()) + .collect(no_optimization=True, simplify_expression=False) + ) + expected = pl.DataFrame({"gid": [0, 1], "x_at_gid": ["0:0", "1:1"]}) + assert_frame_equal(df, expected) + + +def test_gather_agg_group_update_literal() -> None: + # If `gather` doesn't update groups properly, `first` will try to access + # index 2 (the original index of the first element of group `1`), but gather + # outputs only two elements (one for each group), leading to an out of + # bounds access. + df = ( + pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]}) + .lazy() + .group_by("gid", maintain_order=True) + .agg(x_at_0=pl.col("x").gather(0).first()) + .collect(no_optimization=True, simplify_expression=False) + ) + expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]}) + assert_frame_equal(df, expected) + + +def test_gather_agg_group_update_negative() -> None: + # If `gather` doesn't update groups properly, `first` will try to access + # index 2 (the original index of the first element of group `1`), but gather + # outputs only two elements (one for each group), leading to an out of + # bounds access. + df = ( + pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]}) + .lazy() + .group_by("gid", maintain_order=True) + .agg(x_last=pl.col("x").gather(-1).first()) + .collect(no_optimization=True, simplify_expression=False) + ) + expected = pl.DataFrame({"gid": [0, 1], "x_last": ["0:1", "1:0"]}) + assert_frame_equal(df, expected) + + +def test_gather_agg_group_update_multiple() -> None: + # If `gather` doesn't update groups properly, `first` will try to access + # index 4 (the original index of the first element of group `1`), but gather + # outputs only four elements (two for each group), leading to an out of + # bounds access. + df = ( + pl.DataFrame( + { + "gid": [0, 0, 0, 0, 1, 1], + "x": ["0:0", "0:1", "0:2", "0:3", "1:0", "1:1"], + } + ) + .lazy() + .group_by("gid", maintain_order=True) + .agg(x_at_0=pl.col("x").gather([0, 1]).first()) + .collect(no_optimization=True, simplify_expression=False) + ) + expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]}) + assert_frame_equal(df, expected) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py new file mode 100644 index 000000000000..9bd96de12873 --- /dev/null +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -0,0 +1,1270 @@ +from __future__ import annotations + +import typing +from collections import OrderedDict +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ColumnNotFoundError +from polars.meta import get_index_type +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_group_by() -> None: + df = pl.DataFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + + # Use lazy API in eager group_by + assert sorted(df.group_by("a").agg([pl.sum("b")]).rows()) == [ + ("a", 4), + ("b", 11), + ("c", 6), + ] + # test if it accepts a single expression + assert df.group_by("a", maintain_order=True).agg(pl.sum("b")).rows() == [ + ("a", 4), + ("b", 11), + ("c", 6), + ] + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": ["a", "a", "b", "b", "b"], + "c": [None, 1, None, 1, None], + } + ) + + # check if this query runs and thus column names propagate + df.group_by("b").agg(pl.col("c").fill_null(strategy="forward")).explode("c") + + # get a specific column + result = df.group_by("b", maintain_order=True).agg(pl.count("a")) + assert result.rows() == [("a", 2), ("b", 3)] + assert result.columns == ["b", "a"] + + +@pytest.mark.parametrize( + ("input", "expected", "input_dtype", "output_dtype"), + [ + ([1, 2, 3, 4], [2, 4], pl.UInt8, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.Int8, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.UInt16, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.Int16, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.UInt32, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.Int32, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.UInt64, pl.Float64), + ([1, 2, 3, 4], [2, 4], pl.Float32, pl.Float32), + ([1, 2, 3, 4], [2, 4], pl.Float64, pl.Float64), + ([False, True, True, True], [2 / 3, 1], pl.Boolean, pl.Float64), + ( + [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)], + [datetime(2023, 1, 2, 8, 0, 0), datetime(2023, 1, 5)], + pl.Date, + pl.Datetime("ms"), + ), + ( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 3), + datetime(2023, 1, 4), + ], + [datetime(2023, 1, 2), datetime(2023, 1, 4)], + pl.Datetime("ms"), + pl.Datetime("ms"), + ), + ( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 3), + datetime(2023, 1, 4), + ], + [datetime(2023, 1, 2), datetime(2023, 1, 4)], + pl.Datetime("us"), + pl.Datetime("us"), + ), + ( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 3), + datetime(2023, 1, 4), + ], + [datetime(2023, 1, 2), datetime(2023, 1, 4)], + pl.Datetime("ns"), + pl.Datetime("ns"), + ), + ( + [timedelta(1), timedelta(2), timedelta(3), timedelta(4)], + [timedelta(2), timedelta(4)], + pl.Duration("ms"), + pl.Duration("ms"), + ), + ( + [timedelta(1), timedelta(2), timedelta(3), timedelta(4)], + [timedelta(2), timedelta(4)], + pl.Duration("us"), + pl.Duration("us"), + ), + ( + [timedelta(1), timedelta(2), timedelta(3), timedelta(4)], + [timedelta(2), timedelta(4)], + pl.Duration("ns"), + pl.Duration("ns"), + ), + ], +) +def test_group_by_mean_by_dtype( + input: list[Any], + expected: list[Any], + input_dtype: PolarsDataType, + output_dtype: PolarsDataType, +) -> None: + # groups are defined by first 3 values, then last value + name = str(input_dtype) + key = ["a", "a", "a", "b"] + df = pl.DataFrame( + { + "key": key, + name: pl.Series(input, dtype=input_dtype), + } + ) + result = df.group_by("key", maintain_order=True).mean() + df_expected = pl.DataFrame( + { + "key": ["a", "b"], + name: pl.Series(expected, dtype=output_dtype), + } + ) + assert_frame_equal(result, df_expected) + + +@pytest.mark.parametrize( + ("input", "expected", "input_dtype", "output_dtype"), + [ + ([1, 2, 4, 5], [2, 5], pl.UInt8, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.Int8, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.UInt16, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.Int16, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.UInt32, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.Int32, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.UInt64, pl.Float64), + ([1, 2, 4, 5], [2, 5], pl.Float32, pl.Float32), + ([1, 2, 4, 5], [2, 5], pl.Float64, pl.Float64), + ([False, True, True, True], [1, 1], pl.Boolean, pl.Float64), + ( + [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 4), date(2023, 1, 5)], + [datetime(2023, 1, 2), datetime(2023, 1, 5)], + pl.Date, + pl.Datetime("ms"), + ), + ( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 4), + datetime(2023, 1, 5), + ], + [datetime(2023, 1, 2), datetime(2023, 1, 5)], + pl.Datetime("ms"), + pl.Datetime("ms"), + ), + ( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 4), + datetime(2023, 1, 5), + ], + [datetime(2023, 1, 2), datetime(2023, 1, 5)], + pl.Datetime("us"), + pl.Datetime("us"), + ), + ( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 4), + datetime(2023, 1, 5), + ], + [datetime(2023, 1, 2), datetime(2023, 1, 5)], + pl.Datetime("ns"), + pl.Datetime("ns"), + ), + ( + [timedelta(1), timedelta(2), timedelta(4), timedelta(5)], + [timedelta(2), timedelta(5)], + pl.Duration("ms"), + pl.Duration("ms"), + ), + ( + [timedelta(1), timedelta(2), timedelta(4), timedelta(5)], + [timedelta(2), timedelta(5)], + pl.Duration("us"), + pl.Duration("us"), + ), + ( + [timedelta(1), timedelta(2), timedelta(4), timedelta(5)], + [timedelta(2), timedelta(5)], + pl.Duration("ns"), + pl.Duration("ns"), + ), + ], +) +def test_group_by_median_by_dtype( + input: list[Any], + expected: list[Any], + input_dtype: PolarsDataType, + output_dtype: PolarsDataType, +) -> None: + # groups are defined by first 3 values, then last value + name = str(input_dtype) + key = ["a", "a", "a", "b"] + df = pl.DataFrame( + { + "key": key, + name: pl.Series(input, dtype=input_dtype), + } + ) + result = df.group_by("key", maintain_order=True).median() + df_expected = pl.DataFrame( + { + "key": ["a", "b"], + name: pl.Series(expected, dtype=output_dtype), + } + ) + assert_frame_equal(result, df_expected) + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": ["a", "a", "b", "b", "b"], + "c": [None, 1, None, 1, None], + } + ) + + +@pytest.mark.parametrize( + ("method", "expected"), + [ + ("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]), + ("len", [("a", 2), ("b", 3)]), + ("first", [("a", 1, None), ("b", 3, None)]), + ("last", [("a", 2, 1), ("b", 5, None)]), + ("max", [("a", 2, 1), ("b", 5, 1)]), + ("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]), + ("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]), + ("min", [("a", 1, 1), ("b", 3, 1)]), + ("n_unique", [("a", 2, 2), ("b", 3, 2)]), + ], +) +def test_group_by_shorthands( + df: pl.DataFrame, method: str, expected: list[tuple[Any]] +) -> None: + gb = df.group_by("b", maintain_order=True) + result = getattr(gb, method)() + assert result.rows() == expected + + gb_lazy = df.lazy().group_by("b", maintain_order=True) + result = getattr(gb_lazy, method)().collect() + assert result.rows() == expected + + +def test_group_by_shorthand_quantile(df: pl.DataFrame) -> None: + result = df.group_by("b", maintain_order=True).quantile(0.5) + expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)] + assert result.rows() == expected + + result = df.lazy().group_by("b", maintain_order=True).quantile(0.5).collect() + assert result.rows() == expected + + +def test_group_by_args() -> None: + df = pl.DataFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + + # Single column name + assert df.group_by("a").agg("b").columns == ["a", "b"] + # Column names as list + expected = ["a", "b", "c"] + assert df.group_by(["a", "b"]).agg("c").columns == expected + # Column names as positional arguments + assert df.group_by("a", "b").agg("c").columns == expected + # With keyword argument + assert df.group_by("a", "b", maintain_order=True).agg("c").columns == expected + # Multiple aggregations as list + assert df.group_by("a").agg(["b", "c"]).columns == expected + # Multiple aggregations as positional arguments + assert df.group_by("a").agg("b", "c").columns == expected + # Multiple aggregations as keyword arguments + assert df.group_by("a").agg(q="b", r="c").columns == ["a", "q", "r"] + + +def test_group_by_empty() -> None: + df = pl.DataFrame({"a": [1, 1, 2]}) + result = df.group_by("a").agg() + expected = pl.DataFrame({"a": [1, 2]}) + assert_frame_equal(result, expected, check_row_order=False) + + +def test_group_by_iteration() -> None: + df = pl.DataFrame( + { + "foo": ["a", "b", "a", "b", "b", "c"], + "bar": [1, 2, 3, 4, 5, 6], + "baz": [6, 5, 4, 3, 2, 1], + } + ) + expected_names = ["a", "b", "c"] + expected_rows = [ + [("a", 1, 6), ("a", 3, 4)], + [("b", 2, 5), ("b", 4, 3), ("b", 5, 2)], + [("c", 6, 1)], + ] + gb_iter = enumerate(df.group_by("foo", maintain_order=True)) + for i, (group, data) in gb_iter: + assert group == (expected_names[i],) + assert data.rows() == expected_rows[i] + + # Grouped by ALL columns should give groups of a single row + result = list(df.group_by(["foo", "bar", "baz"])) + assert len(result) == 6 + + # Iterating over groups should also work when grouping by expressions + result2 = list(df.group_by(["foo", pl.col("bar") * pl.col("baz")])) + assert len(result2) == 5 + + # Single expression, alias in group_by + df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6]}) + gb = df.group_by((pl.col("foo") // 2).alias("bar"), maintain_order=True) + result3 = [(group, df.rows()) for group, df in gb] + expected3 = [ + ((0,), [(1,)]), + ((1,), [(2,), (3,)]), + ((2,), [(4,), (5,)]), + ((3,), [(6,)]), + ] + assert result3 == expected3 + + +def test_group_by_iteration_selector() -> None: + df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]}) + result = dict(df.group_by(cs.string())) + result_first = result["one",] + assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]} + + +@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()]) +def test_group_by_agg_input_types(input: Any) -> None: + df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]}) + result = df.group_by("a", maintain_order=True).agg(input) + expected = pl.LazyFrame({"a": [1, 2], "b": [3, 7]}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("input", [str, "b".join]) +def test_group_by_agg_bad_input_types(input: Any) -> None: + df = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]}) + with pytest.raises(TypeError): + df.group_by("a").agg(input) + + +def test_group_by_sorted_empty_dataframe_3680() -> None: + df = ( + pl.DataFrame( + [ + pl.Series("key", [], dtype=pl.Categorical), + pl.Series("val", [], dtype=pl.Float64), + ] + ) + .lazy() + .sort("key") + .group_by("key") + .tail(1) + .collect(_check_order=False) + ) + assert df.rows() == [] + assert df.shape == (0, 2) + assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64} + + +def test_group_by_custom_agg_empty_list() -> None: + assert ( + pl.DataFrame( + [ + pl.Series("key", [], dtype=pl.Categorical), + pl.Series("val", [], dtype=pl.Float64), + ] + ) + .group_by("key") + .agg( + [ + pl.col("val").mean().alias("mean"), + pl.col("val").std().alias("std"), + pl.col("val").skew().alias("skew"), + pl.col("val").kurtosis().alias("kurt"), + ] + ) + ).dtypes == [pl.Categorical, pl.Float64, pl.Float64, pl.Float64, pl.Float64] + + +def test_apply_after_take_in_group_by_3869() -> None: + assert ( + pl.DataFrame( + { + "k": list("aaabbb"), + "t": [1, 2, 3, 4, 5, 6], + "v": [3, 1, 2, 5, 6, 4], + } + ) + .group_by("k", maintain_order=True) + .agg( + pl.col("v").get(pl.col("t").arg_max()).sqrt() + ) # <- fails for sqrt, exp, log, pow, etc. + ).to_dict(as_series=False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]} + + +def test_group_by_signed_transmutes() -> None: + df = pl.DataFrame({"foo": [-1, -2, -3, -4, -5], "bar": [500, 600, 700, 800, 900]}) + + for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]: + df = ( + df.with_columns([pl.col("foo").cast(dt), pl.col("bar")]) + .group_by("foo", maintain_order=True) + .agg(pl.col("bar").median()) + ) + + assert df.to_dict(as_series=False) == { + "foo": [-1, -2, -3, -4, -5], + "bar": [500.0, 600.0, 700.0, 800.0, 900.0], + } + + +def test_arg_sort_sort_by_groups_update__4360() -> None: + df = pl.DataFrame( + { + "group": ["a"] * 3 + ["b"] * 3 + ["c"] * 3, + "col1": [1, 2, 3] * 3, + "col2": [1, 2, 3, 3, 2, 1, 2, 3, 1], + } + ) + + out = df.with_columns( + pl.col("col2").arg_sort().over("group").alias("col2_arg_sort") + ).with_columns( + pl.col("col1").sort_by(pl.col("col2_arg_sort")).over("group").alias("result_a"), + pl.col("col1") + .sort_by(pl.col("col2").arg_sort()) + .over("group") + .alias("result_b"), + ) + + assert_series_equal(out["result_a"], out["result_b"], check_names=False) + assert out["result_a"].to_list() == [1, 2, 3, 3, 2, 1, 2, 3, 1] + + +def test_unique_order() -> None: + df = pl.DataFrame({"a": [1, 2, 1]}).with_row_index() + assert df.unique(keep="last", subset="a", maintain_order=True).to_dict( + as_series=False + ) == { + "index": [1, 2], + "a": [2, 1], + } + assert df.unique(keep="first", subset="a", maintain_order=True).to_dict( + as_series=False + ) == { + "index": [0, 1], + "a": [1, 2], + } + + +def test_group_by_dynamic_flat_agg_4814() -> None: + df = pl.DataFrame({"a": [1, 2, 2], "b": [1, 8, 12]}).set_sorted("a") + + assert df.group_by_dynamic("a", every="1i", period="2i").agg( + [ + (pl.col("b").sum() / pl.col("a").sum()).alias("sum_ratio_1"), + (pl.col("b").last() / pl.col("a").last()).alias("last_ratio_1"), + (pl.col("b") / pl.col("a")).last().alias("last_ratio_2"), + ] + ).to_dict(as_series=False) == { + "a": [1, 2], + "sum_ratio_1": [4.2, 5.0], + "last_ratio_1": [6.0, 6.0], + "last_ratio_2": [6.0, 6.0], + } + + +@pytest.mark.parametrize( + ("every", "period"), + [ + ("10s", timedelta(seconds=100)), + (timedelta(seconds=10), "100s"), + ], +) +@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"]) +def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038( + every: str | timedelta, period: str | timedelta, time_zone: str | None +) -> None: + res = ( + ( + pl.DataFrame( + { + "a": [ + datetime(2021, 1, 1) + timedelta(seconds=2**i) + for i in range(10) + ], + "b": [float(i) for i in range(10)], + } + ) + .with_columns(pl.col("a").dt.replace_time_zone(time_zone)) + .lazy() + .set_sorted("a") + .group_by_dynamic("a", every=every, period=period) + .agg([pl.col("b").var().sqrt().alias("corr")]) + ) + .collect() + .sum() + .to_dict(as_series=False) + ) + + assert res["corr"] == pytest.approx([6.988674024215477]) + assert res["a"] == [None] + + +def test_take_in_group_by() -> None: + df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]}) + assert df.group_by("group").agg( + pl.col("values").get(1) - pl.col("values").get(2) + ).sort("group").to_dict(as_series=False) == {"group": [1, 2], "values": [197, 494]} + + +def test_group_by_wildcard() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [1, 2], + } + ) + assert df.group_by([pl.col("*")], maintain_order=True).agg( + [pl.col("a").first().name.suffix("_agg")] + ).to_dict(as_series=False) == {"a": [1, 2], "b": [1, 2], "a_agg": [1, 2]} + + +def test_group_by_all_masked_out() -> None: + df = pl.DataFrame( + { + "val": pl.Series( + [None, None, None, None], dtype=pl.Categorical, nan_to_null=True + ).set_sorted(), + "col": [4, 4, 4, 4], + } + ) + parts = df.partition_by("val") + assert len(parts) == 1 + assert_frame_equal(parts[0], df) + + +def test_group_by_null_propagation_6185() -> None: + df_1 = pl.DataFrame({"A": [0, 0], "B": [1, 2]}) + + expr = pl.col("A").filter(pl.col("A") > 0) + + expected = {"B": [1, 2], "A": [None, None]} + assert ( + df_1.group_by("B") + .agg((expr - expr.mean()).mean()) + .sort("B") + .to_dict(as_series=False) + == expected + ) + + +def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None: + df = pl.DataFrame( + {"code": ["a", "b", "b", "b", "a"], "xx": [1.0, -1.5, -0.2, -3.9, 3.0]} + ) + assert ( + df.group_by("code", maintain_order=True).agg( + [pl.when(pl.col("xx") > pl.min("xx")).then(True).otherwise(False)] + ) + ).to_dict(as_series=False) == { + "code": ["a", "b"], + "literal": [[False, True], [True, True, False]], + } + + +def test_group_by_binary_agg_with_literal() -> None: + df = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]}) + + out = df.group_by("id", maintain_order=True).agg( + pl.col("value") + pl.Series([1, 3]) + ) + assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 5], [4, 7]]} + + out = df.group_by("id", maintain_order=True).agg(pl.col("value") + pl.lit(1)) + assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[2, 3], [4, 5]]} + + out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.lit(2)) + assert out.to_dict(as_series=False) == {"id": ["a", "b"], "literal": [3, 3]} + + out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.Series([2, 3])) + assert out.to_dict(as_series=False) == { + "id": ["a", "b"], + "literal": [[3, 4], [3, 4]], + } + + out = df.group_by("id", maintain_order=True).agg( + value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4])) + ) + assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]} + + +@pytest.mark.slow +@pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32]) +def test_overflow_mean_partitioned_group_by_5194(dtype: PolarsDataType) -> None: + df = pl.DataFrame( + [ + pl.Series("data", [10_00_00_00] * 100_000, dtype=dtype), + pl.Series("group", [1, 2] * 50_000, dtype=dtype), + ] + ) + result = df.group_by("group").agg(pl.col("data").mean()).sort(by="group") + expected = {"group": [1, 2], "data": [10000000.0, 10000000.0]} + assert result.to_dict(as_series=False) == expected + + +# https://github.com/pola-rs/polars/issues/7181 +def test_group_by_multiple_column_reference() -> None: + df = pl.DataFrame( + { + "gr": ["a", "b", "a", "b", "a", "b"], + "val": [1, 20, 100, 2000, 10000, 200000], + } + ) + result = df.group_by("gr").agg( + pl.col("val") + pl.col("val").shift().fill_null(0), + ) + + assert result.sort("gr").to_dict(as_series=False) == { + "gr": ["a", "b"], + "val": [[1, 101, 10100], [20, 2020, 202000]], + } + + +@pytest.mark.parametrize( + ("aggregation", "args", "expected_values", "expected_dtype"), + [ + ("first", [], [1, None], pl.Int64), + ("last", [], [1, None], pl.Int64), + ("max", [], [1, None], pl.Int64), + ("mean", [], [1.0, None], pl.Float64), + ("median", [], [1.0, None], pl.Float64), + ("min", [], [1, None], pl.Int64), + ("n_unique", [], [1, 0], pl.UInt32), + ("quantile", [0.5], [1.0, None], pl.Float64), + ], +) +def test_group_by_empty_groups( + aggregation: str, + args: list[object], + expected_values: list[object], + expected_dtype: pl.DataType, +) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [1, 2]}) + result = df.group_by("b", maintain_order=True).agg( + getattr(pl.col("a").filter(pl.col("b") != 2), aggregation)(*args) + ) + expected = pl.DataFrame({"b": [1, 2], "a": expected_values}).with_columns( + pl.col("a").cast(expected_dtype) + ) + assert_frame_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/8663 +def test_perfect_hash_table_null_values() -> None: + # fmt: off + values = ["3", "41", "17", "5", "26", "27", "43", "45", "41", "13", "45", "48", "17", "22", "31", "25", "28", "13", "7", "26", "17", "4", "43", "47", "30", "28", "8", "27", "6", "7", "26", "11", "37", "29", "49", "20", "29", "28", "23", "9", None, "38", "19", "7", "38", "3", "30", "37", "41", "5", "16", "26", "31", "6", "25", "11", "17", "31", "31", "20", "26", None, "39", "10", "38", "4", "39", "15", "13", "35", "38", "11", "39", "11", "48", "36", "18", "11", "34", "16", "28", "9", "37", "8", "17", "48", "44", "28", "25", "30", "37", "30", "18", "12", None, "27", "10", "3", "16", "27", "6"] + groups = ["3", "41", "17", "5", "26", "27", "43", "45", "13", "48", "22", "31", "25", "28", "7", "4", "47", "30", "8", "6", "11", "37", "29", "49", "20", "23", "9", None, "38", "19", "16", "39", "10", "15", "35", "36", "18", "34", "44", "12"] + # fmt: on + + s = pl.Series("a", values, dtype=pl.Categorical) + + result = ( + s.to_frame("a").group_by("a", maintain_order=True).agg(pl.col("a").alias("agg")) + ) + + agg_values = [ + ["3", "3", "3"], + ["41", "41", "41"], + ["17", "17", "17", "17", "17"], + ["5", "5"], + ["26", "26", "26", "26", "26"], + ["27", "27", "27", "27"], + ["43", "43"], + ["45", "45"], + ["13", "13", "13"], + ["48", "48", "48"], + ["22"], + ["31", "31", "31", "31"], + ["25", "25", "25"], + ["28", "28", "28", "28", "28"], + ["7", "7", "7"], + ["4", "4"], + ["47"], + ["30", "30", "30", "30"], + ["8", "8"], + ["6", "6", "6"], + ["11", "11", "11", "11", "11"], + ["37", "37", "37", "37"], + ["29", "29"], + ["49"], + ["20", "20"], + ["23"], + ["9", "9"], + [None, None, None], + ["38", "38", "38", "38"], + ["19"], + ["16", "16", "16"], + ["39", "39", "39"], + ["10", "10"], + ["15"], + ["35"], + ["36"], + ["18", "18"], + ["34"], + ["44"], + ["12"], + ] + expected = pl.DataFrame( + { + "a": groups, + "agg": agg_values, + }, + schema={"a": pl.Categorical, "agg": pl.List(pl.Categorical)}, + ) + assert_frame_equal(result, expected) + + +def test_group_by_partitioned_ending_cast(monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_FORCE_PARTITION", "1") + df = pl.DataFrame({"a": [1] * 5, "b": [1] * 5}) + out = df.group_by(["a", "b"]).agg(pl.len().cast(pl.Int64).alias("num")) + expected = pl.DataFrame({"a": [1], "b": [1], "num": [5]}) + assert_frame_equal(out, expected) + + +def test_group_by_series_partitioned(partition_limit: int) -> None: + # test 15354 + df = pl.DataFrame([0, 0] * partition_limit) + groups = pl.Series([0, 1] * partition_limit) + df.group_by(groups).agg(pl.all().is_not_null().sum()) + + +def test_group_by_list_scalar_11749() -> None: + df = pl.DataFrame( + { + "group_name": ["a;b", "a;b", "c;d", "c;d", "a;b", "a;b"], + "parent_name": ["a", "b", "c", "d", "a", "b"], + "measurement": [ + ["x1", "x2"], + ["x1", "x2"], + ["y1", "y2"], + ["z1", "z2"], + ["x1", "x2"], + ["x1", "x2"], + ], + } + ) + assert ( + df.group_by("group_name").agg( + (pl.col("measurement").first() == pl.col("measurement")).alias("eq"), + ) + ).sort("group_name").to_dict(as_series=False) == { + "group_name": ["a;b", "c;d"], + "eq": [[True, True, True, True], [True, False]], + } + + +def test_group_by_with_expr_as_key() -> None: + gb = pl.select(x=1).group_by(pl.col("x").alias("key")) + result = gb.agg(pl.all().first()) + expected = gb.agg(pl.first("x")) + assert_frame_equal(result, expected) + + # tests: 11766 + result = gb.head(0) + expected = gb.agg(pl.col("x").head(0)).explode("x") + assert_frame_equal(result, expected) + + result = gb.tail(0) + expected = gb.agg(pl.col("x").tail(0)).explode("x") + assert_frame_equal(result, expected) + + +def test_lazy_group_by_reuse_11767() -> None: + lgb = pl.select(x=1).lazy().group_by("x") + a = lgb.len() + b = lgb.len() + assert_frame_equal(a, b) + + +def test_group_by_double_on_empty_12194() -> None: + df = pl.DataFrame({"group": [1], "x": [1]}).clear() + squared_deviation_sum = ((pl.col("x") - pl.col("x").mean()) ** 2).sum() + assert df.group_by("group").agg(squared_deviation_sum).schema == OrderedDict( + [("group", pl.Int64), ("x", pl.Float64)] + ) + + +def test_group_by_when_then_no_aggregation_predicate() -> None: + df = pl.DataFrame( + { + "key": ["aa", "aa", "bb", "bb", "aa", "aa"], + "val": [-3, -2, 1, 4, -3, 5], + } + ) + assert df.group_by("key").agg( + pos=pl.when(pl.col("val") >= 0).then(pl.col("val")).sum(), + neg=pl.when(pl.col("val") < 0).then(pl.col("val")).sum(), + ).sort("key").to_dict(as_series=False) == { + "key": ["aa", "bb"], + "pos": [5, 5], + "neg": [-8, 0], + } + + +def test_group_by_apply_first_input_is_literal() -> None: + df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "g": [1, 1, 2, 2, 2]}) + pow = df.group_by("g").agg(2 ** pl.col("x")) + assert pow.sort("g").to_dict(as_series=False) == { + "g": [1, 2], + "literal": [[2.0, 4.0], [8.0, 16.0, 32.0]], + } + + +def test_group_by_all_12869() -> None: + df = pl.DataFrame({"a": [1]}) + result = next(iter(df.group_by(pl.all())))[1] + assert_frame_equal(df, result) + + +def test_group_by_named() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)}) + result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min()) + expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg( + pl.col("b").min() + ) + assert_frame_equal(result, expected) + + +def test_group_by_with_null() -> None: + df = pl.DataFrame( + {"a": [None, None, None, None], "b": [1, 1, 2, 2], "c": ["x", "y", "z", "u"]} + ) + expected = pl.DataFrame( + {"a": [None, None], "b": [1, 2], "c": [["x", "y"], ["z", "u"]]} + ) + output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c")) + assert_frame_equal(expected, output) + + +def test_partitioned_group_by_14954(monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_FORCE_PARTITION", "1") + assert ( + pl.DataFrame({"a": range(20)}) + .select(pl.col("a") % 2) + .group_by("a") + .agg( + (pl.col("a") > 1000).alias("a > 1000"), + ) + ).sort("a").to_dict(as_series=False) == { + "a": [0, 1], + "a > 1000": [ + [False, False, False, False, False, False, False, False, False, False], + [False, False, False, False, False, False, False, False, False, False], + ], + } + + +def test_partitioned_group_by_nulls_mean_21838() -> None: + size = 10 + a = [1 for i in range(size)] + [2 for i in range(size)] + [3 for i in range(size)] + b = [1 for i in range(size)] + [None for i in range(size * 2)] + df = pl.DataFrame({"a": a, "b": b}) + assert df.group_by("a").mean().sort("a").to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [1.0, None, None], + } + + +def test_aggregated_scalar_elementwise_15602() -> None: + df = pl.DataFrame({"group": [1, 2, 1]}) + + out = df.group_by("group", maintain_order=True).agg( + foo=pl.col("group").is_between(1, pl.max("group")) + ) + expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]}) + assert_frame_equal(out, expected) + + +def test_group_by_multiple_null_cols_15623() -> None: + df = pl.DataFrame(schema={"a": pl.Null, "b": pl.Null}).group_by(pl.all()).len() + assert df.is_empty() + + +@pytest.mark.release +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_vs_str_group_by() -> None: + # this triggers the perfect hash table + s = pl.Series("a", np.random.randint(0, 50, 100)) + s_with_nulls = pl.select( + pl.when(s < 3).then(None).otherwise(s).alias("a") + ).to_series() + + for s_ in [s, s_with_nulls]: + s_ = s_.cast(str) + cat_out = ( + s_.cast(pl.Categorical) + .to_frame("a") + .group_by("a") + .agg(pl.first().alias("first")) + ) + + str_out = s_.to_frame("a").group_by("a").agg(pl.first().alias("first")) + cat_out.with_columns(pl.col("a").cast(str)) + assert_frame_equal( + cat_out.with_columns( + pl.col("a").cast(str), pl.col("first").cast(pl.List(str)) + ).sort("a"), + str_out.sort("a"), + ) + + +@pytest.mark.release +def test_boolean_min_max_agg() -> None: + np.random.seed(0) + idx = np.random.randint(0, 500, 1000) + c = np.random.randint(0, 500, 1000) > 250 + + df = pl.DataFrame({"idx": idx, "c": c}) + aggs = [pl.col("c").min().alias("c_min"), pl.col("c").max().alias("c_max")] + + result = df.group_by("idx").agg(aggs).sum() + + schema = {"idx": pl.Int64, "c_min": pl.UInt32, "c_max": pl.UInt32} + expected = pl.DataFrame( + { + "idx": [107583], + "c_min": [120], + "c_max": [321], + }, + schema=schema, + ) + assert_frame_equal(result, expected) + + nulls = np.random.randint(0, 500, 1000) < 100 + + result = ( + df.with_columns(c=pl.when(pl.lit(nulls)).then(None).otherwise(pl.col("c"))) + .group_by("idx") + .agg(aggs) + .sum() + ) + + expected = pl.DataFrame( + { + "idx": [107583], + "c_min": [133], + "c_max": [276], + }, + schema=schema, + ) + assert_frame_equal(result, expected) + + +def test_partitioned_group_by_chunked(partition_limit: int) -> None: + n = partition_limit + df1 = pl.DataFrame(np.random.randn(n, 2)) + df2 = pl.DataFrame(np.random.randn(n, 2)) + gps = pl.Series(name="oo", values=[0] * n + [1] * n) + df = pl.concat([df1, df2], rechunk=False) + assert_frame_equal( + df.group_by(gps).sum().sort("oo"), + df.rechunk().group_by(gps, maintain_order=True).sum(), + ) + + +def test_schema_on_agg() -> None: + lf = pl.LazyFrame({"a": ["x", "x", "y", "n"], "b": [1, 2, 3, 4]}) + + result = lf.group_by("a").agg( + pl.col("b").min().alias("min"), + pl.col("b").max().alias("max"), + pl.col("b").sum().alias("sum"), + pl.col("b").first().alias("first"), + pl.col("b").last().alias("last"), + ) + expected_schema = { + "a": pl.String, + "min": pl.Int64, + "max": pl.Int64, + "sum": pl.Int64, + "first": pl.Int64, + "last": pl.Int64, + } + assert result.collect_schema() == expected_schema + + +def test_group_by_schema_err() -> None: + lf = pl.LazyFrame({"foo": [None, 1, 2], "bar": [1, 2, 3]}) + with pytest.raises(ColumnNotFoundError): + lf.group_by("not-existent").agg( + pl.col("bar").max().alias("max_bar") + ).collect_schema() + + +@pytest.mark.parametrize( + ("data", "expr", "expected_select", "expected_gb"), + [ + ( + {"x": ["x"], "y": ["y"]}, + pl.coalesce(pl.col("x"), pl.col("y")), + {"x": pl.String}, + {"x": pl.List(pl.String)}, + ), + ( + {"x": [True]}, + pl.col("x").sum(), + {"x": pl.UInt32}, + {"x": pl.UInt32}, + ), + ( + {"a": [[1, 2]]}, + pl.col("a").list.sum(), + {"a": pl.Int64}, + {"a": pl.List(pl.Int64)}, + ), + ], +) +def test_schemas( + data: dict[str, list[Any]], + expr: pl.Expr, + expected_select: dict[str, PolarsDataType], + expected_gb: dict[str, PolarsDataType], +) -> None: + df = pl.DataFrame(data) + + # test selection schema + schema = df.select(expr).schema + for key, dtype in expected_select.items(): + assert schema[key] == dtype + + # test group_by schema + schema = df.group_by(pl.lit(1)).agg(expr).schema + for key, dtype in expected_gb.items(): + assert schema[key] == dtype + + +def test_lit_iter_schema() -> None: + df = pl.DataFrame( + { + "key": ["A", "A", "A", "A"], + "dates": [ + date(1970, 1, 1), + date(1970, 1, 1), + date(1970, 1, 2), + date(1970, 1, 3), + ], + } + ) + + result = df.group_by("key").agg(pl.col("dates").unique() + timedelta(days=1)) + expected = { + "key": ["A"], + "dates": [[date(1970, 1, 2), date(1970, 1, 3), date(1970, 1, 4)]], + } + assert result.to_dict(as_series=False) == expected + + +def test_absence_off_null_prop_8224() -> None: + # a reminder to self to not do null propagation + # it is inconsistent and makes output dtype + # dependent of the data, big no! + + def sub_col_min(column: str, min_column: str) -> pl.Expr: + return pl.col(column).sub(pl.col(min_column).min()) + + df = pl.DataFrame( + { + "group": [1, 1, 2, 2], + "vals_num": [10.0, 11.0, 12.0, 13.0], + "vals_partial": [None, None, 12.0, 13.0], + "vals_null": [None, None, None, None], + } + ) + + q = ( + df.lazy() + .group_by("group") + .agg( + sub_col_min("vals_num", "vals_num").alias("sub_num"), + sub_col_min("vals_num", "vals_partial").alias("sub_partial"), + sub_col_min("vals_num", "vals_null").alias("sub_null"), + ) + ) + + assert q.collect().dtypes == [ + pl.Int64, + pl.List(pl.Float64), + pl.List(pl.Float64), + pl.List(pl.Float64), + ] + + +def test_grouped_slice_literals() -> None: + assert pl.DataFrame({"idx": [1, 2, 3]}).group_by(True).agg( + x=pl.lit([1, 2]).slice( + -1, 1 + ), # slices a list of 1 element, so remains the same element + x2=pl.lit(pl.Series([1, 2])).slice(-1, 1), + ).to_dict(as_series=False) == {"literal": [True], "x": [[1, 2]], "x2": [2]} + + +def test_positional_by_with_list_or_tuple_17540() -> None: + with pytest.raises(TypeError, match="Hint: if you"): + pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"]) + with pytest.raises(TypeError, match="Hint: if you"): + pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"]) + + +def test_group_by_agg_19173() -> None: + df = pl.DataFrame({"x": [1.0], "g": [0]}) + out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2) + assert out.to_dict(as_series=False) == {"g": [], "x": []} + assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))]) + + +def test_group_by_map_groups_slice_pushdown_20002() -> None: + schema = { + "a": pl.Int8, + "b": pl.UInt8, + } + + df = ( + pl.LazyFrame( + data={"a": [1, 2, 3, 4, 5], "b": [90, 80, 70, 60, 50]}, + schema=schema, + ) + .group_by("a", maintain_order=True) + .map_groups(lambda df: df * 2.0, schema=schema) + .head(3) + .collect() + ) + + assert_frame_equal( + df, + pl.DataFrame( + { + "a": [2.0, 4.0, 6.0], + "b": [180.0, 160.0, 140.0], + } + ), + ) + + +@typing.no_type_check +def test_group_by_lit_series(capfd: Any, monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + n = 10 + df = pl.DataFrame({"x": np.ones(2 * n), "y": n * list(range(2))}) + a = np.ones(n, dtype=float) + df.lazy().group_by("y").agg(pl.col("x").dot(a)).collect() + captured = capfd.readouterr().err + assert "are not partitionable" in captured + + +def test_group_by_list_column() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 2], [3], [1, 2]]}) + result = df.group_by("b").agg(pl.sum("a")).sort("b") + expected = pl.DataFrame({"b": [[1, 2], [3]], "a": [4, 2]}) + assert_frame_equal(result, expected) + + +def test_enum_perfect_group_by_21360() -> None: + dtype = pl.Enum(categories=["a", "b"]) + + assert_frame_equal( + pl.from_dicts([{"col": "a"}], schema={"col": dtype}) + .group_by("col") + .agg(pl.len()), + pl.DataFrame( + [ + pl.Series("col", ["a"], dtype), + pl.Series("len", [1], get_index_type()), + ] + ), + ) + + +def test_partitioned_group_by_21634(partition_limit: int) -> None: + n = partition_limit + df = pl.DataFrame({"grp": [1] * n, "x": [1] * n}) + assert df.group_by("grp", True).agg().to_dict(as_series=False) == { + "grp": [1], + "literal": [True], + } + + +def test_group_by_cse_dup_key_alias_22238() -> None: + df = pl.LazyFrame({"a": [1, 1, 2, 2, -1], "x": [0, 1, 2, 3, 10]}) + result = df.group_by( + pl.col("a").abs(), + pl.col("a").abs().alias("a_with_alias"), + ).agg(pl.col("x").sum()) + assert_frame_equal( + result.collect(), + pl.DataFrame({"a": [1, 2], "a_with_alias": [1, 2], "x": [11, 5]}), + check_row_order=False, + ) + + +def test_group_by_22328() -> None: + N = 20 + + df1 = pl.select( + x=pl.repeat(1, N // 2).append(pl.repeat(2, N // 2)).shuffle(), + y=pl.lit(3.0, pl.Float32), + ).lazy() + + df2 = pl.select(x=pl.repeat(4, N)).lazy() + + assert ( + df2.join(df1.group_by("x").mean().with_columns(z="y"), how="left", on="x") + .with_columns(pl.col("z").fill_null(0)) + .collect() + ).shape == (20, 3) diff --git a/py-polars/tests/unit/operations/test_group_by_dynamic.py b/py-polars/tests/unit/operations/test_group_by_dynamic.py new file mode 100644 index 000000000000..b9f6bb3dec19 --- /dev/null +++ b/py-polars/tests/unit/operations/test_group_by_dynamic.py @@ -0,0 +1,1075 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any +from zoneinfo import ZoneInfo + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import Label, StartBy + + +@pytest.mark.parametrize( + ("input_df", "expected_grouped_df"), + [ + ( + ( + pl.DataFrame( + { + "dt": [ + datetime(2021, 12, 31, 0, 0, 0), + datetime(2022, 1, 1, 0, 0, 1), + datetime(2022, 3, 31, 0, 0, 1), + datetime(2022, 4, 1, 0, 0, 1), + ] + } + ) + ), + pl.DataFrame( + { + "dt": [ + datetime(2021, 10, 1), + datetime(2022, 1, 1), + datetime(2022, 4, 1), + ], + "num_points": [1, 2, 1], + }, + schema={"dt": pl.Datetime, "num_points": pl.UInt32}, + ).sort("dt"), + ) + ], +) +def test_group_by_dynamic( + input_df: pl.DataFrame, expected_grouped_df: pl.DataFrame +) -> None: + result = ( + input_df.sort("dt") + .group_by_dynamic("dt", every="1q") + .agg(pl.col("dt").count().alias("num_points")) + .sort("dt") + ) + assert_frame_equal(result, expected_grouped_df) + + +@pytest.mark.parametrize( + ("every", "offset"), + [ + ("3d", "-1d"), + (timedelta(days=3), timedelta(days=-1)), + ], +) +def test_dynamic_group_by_timezone_awareness( + every: str | timedelta, offset: str | timedelta +) -> None: + df = pl.DataFrame( + ( + pl.datetime_range( + datetime(2020, 1, 1), + datetime(2020, 1, 10), + timedelta(days=1), + time_unit="ns", + eager=True, + ) + .alias("datetime") + .dt.replace_time_zone("UTC"), + pl.arange(1, 11, eager=True).alias("value"), + ) + ) + + assert ( + df.group_by_dynamic( + "datetime", + every=every, + offset=offset, + closed="right", + include_boundaries=True, + label="datapoint", + ).agg(pl.col("value").last()) + ).dtypes == [pl.Datetime("ns", "UTC")] * 3 + [pl.Int64] + + +@pytest.mark.parametrize("tzinfo", [None, ZoneInfo("UTC"), ZoneInfo("Asia/Kathmandu")]) +def test_group_by_dynamic_startby_5599(tzinfo: ZoneInfo | None) -> None: + # start by datapoint + start = datetime(2022, 12, 16, tzinfo=tzinfo) + stop = datetime(2022, 12, 16, hour=3, tzinfo=tzinfo) + df = pl.DataFrame({"date": pl.datetime_range(start, stop, "30m", eager=True)}) + + assert df.group_by_dynamic( + "date", + every="31m", + include_boundaries=True, + label="datapoint", + start_by="datapoint", + ).agg(pl.len()).to_dict(as_series=False) == { + "_lower_boundary": [ + datetime(2022, 12, 16, 0, 0, tzinfo=tzinfo), + datetime(2022, 12, 16, 0, 31, tzinfo=tzinfo), + datetime(2022, 12, 16, 1, 2, tzinfo=tzinfo), + datetime(2022, 12, 16, 1, 33, tzinfo=tzinfo), + datetime(2022, 12, 16, 2, 4, tzinfo=tzinfo), + datetime(2022, 12, 16, 2, 35, tzinfo=tzinfo), + ], + "_upper_boundary": [ + datetime(2022, 12, 16, 0, 31, tzinfo=tzinfo), + datetime(2022, 12, 16, 1, 2, tzinfo=tzinfo), + datetime(2022, 12, 16, 1, 33, tzinfo=tzinfo), + datetime(2022, 12, 16, 2, 4, tzinfo=tzinfo), + datetime(2022, 12, 16, 2, 35, tzinfo=tzinfo), + datetime(2022, 12, 16, 3, 6, tzinfo=tzinfo), + ], + "date": [ + datetime(2022, 12, 16, 0, 0, tzinfo=tzinfo), + datetime(2022, 12, 16, 1, 0, tzinfo=tzinfo), + datetime(2022, 12, 16, 1, 30, tzinfo=tzinfo), + datetime(2022, 12, 16, 2, 0, tzinfo=tzinfo), + datetime(2022, 12, 16, 2, 30, tzinfo=tzinfo), + datetime(2022, 12, 16, 3, 0, tzinfo=tzinfo), + ], + "len": [2, 1, 1, 1, 1, 1], + } + + # start by monday + start = datetime(2022, 1, 1, tzinfo=tzinfo) + stop = datetime(2022, 1, 12, 7, tzinfo=tzinfo) + + df = pl.DataFrame( + {"date": pl.datetime_range(start, stop, "12h", eager=True)} + ).with_columns(pl.col("date").dt.weekday().alias("day")) + + result = df.group_by_dynamic( + "date", + every="1w", + period="3d", + include_boundaries=True, + start_by="monday", + label="datapoint", + ).agg([pl.len(), pl.col("day").first().alias("data_day")]) + assert result.to_dict(as_series=False) == { + "_lower_boundary": [ + datetime(2022, 1, 3, 0, 0, tzinfo=tzinfo), + datetime(2022, 1, 10, 0, 0, tzinfo=tzinfo), + ], + "_upper_boundary": [ + datetime(2022, 1, 6, 0, 0, tzinfo=tzinfo), + datetime(2022, 1, 13, 0, 0, tzinfo=tzinfo), + ], + "date": [ + datetime(2022, 1, 3, 0, 0, tzinfo=tzinfo), + datetime(2022, 1, 10, 0, 0, tzinfo=tzinfo), + ], + "len": [6, 5], + "data_day": [1, 1], + } + # start by saturday + result = df.group_by_dynamic( + "date", + every="1w", + period="3d", + include_boundaries=True, + start_by="saturday", + label="datapoint", + ).agg([pl.len(), pl.col("day").first().alias("data_day")]) + assert result.to_dict(as_series=False) == { + "_lower_boundary": [ + datetime(2022, 1, 1, 0, 0, tzinfo=tzinfo), + datetime(2022, 1, 8, 0, 0, tzinfo=tzinfo), + ], + "_upper_boundary": [ + datetime(2022, 1, 4, 0, 0, tzinfo=tzinfo), + datetime(2022, 1, 11, 0, 0, tzinfo=tzinfo), + ], + "date": [ + datetime(2022, 1, 1, 0, 0, tzinfo=tzinfo), + datetime(2022, 1, 8, 0, 0, tzinfo=tzinfo), + ], + "len": [6, 6], + "data_day": [6, 6], + } + + +def test_group_by_dynamic_by_monday_and_offset_5444() -> None: + df = pl.DataFrame( + { + "date": [ + "2022-11-01", + "2022-11-02", + "2022-11-05", + "2022-11-08", + "2022-11-08", + "2022-11-09", + "2022-11-10", + ], + "label": ["a", "b", "a", "a", "b", "a", "b"], + "value": [1, 2, 3, 4, 5, 6, 7], + } + ).with_columns(pl.col("date").str.strptime(pl.Date, "%Y-%m-%d").set_sorted()) + + result = df.group_by_dynamic( + "date", every="1w", offset="1d", group_by="label", start_by="monday" + ).agg(pl.col("value").sum()) + + assert result.to_dict(as_series=False) == { + "label": ["a", "a", "b", "b"], + "date": [ + date(2022, 11, 1), + date(2022, 11, 8), + date(2022, 11, 1), + date(2022, 11, 8), + ], + "value": [4, 10, 2, 12], + } + + # test empty + result_empty = ( + df.filter(pl.col("date") == date(1, 1, 1)) + .group_by_dynamic( + "date", every="1w", offset="1d", group_by="label", start_by="monday" + ) + .agg(pl.col("value").sum()) + ) + assert result_empty.schema == result.schema + + +@pytest.mark.parametrize( + ("label", "expected"), + [ + ("left", [datetime(2020, 1, 1), datetime(2020, 1, 2)]), + ("right", [datetime(2020, 1, 2), datetime(2020, 1, 3)]), + ("datapoint", [datetime(2020, 1, 1, 1), datetime(2020, 1, 2, 3)]), + ], +) +def test_group_by_dynamic_label(label: Label, expected: list[datetime]) -> None: + df = pl.DataFrame( + { + "ts": [ + datetime(2020, 1, 1, 1), + datetime(2020, 1, 1, 2), + datetime(2020, 1, 2, 3), + datetime(2020, 1, 2, 4), + ], + "n": [1, 2, 3, 4], + "group": ["a", "a", "b", "b"], + } + ).sort("ts") + result = ( + df.group_by_dynamic("ts", every="1d", label=label, group_by="group") + .agg(pl.col("n"))["ts"] + .to_list() + ) + assert result == expected + + +@pytest.mark.parametrize( + ("label", "expected"), + [ + ("left", [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)]), + ("right", [datetime(2020, 1, 2), datetime(2020, 1, 3), datetime(2020, 1, 4)]), + ( + "datapoint", + [datetime(2020, 1, 1, 1), datetime(2020, 1, 2, 2), datetime(2020, 1, 3, 3)], + ), + ], +) +def test_group_by_dynamic_label_with_by(label: Label, expected: list[datetime]) -> None: + df = pl.DataFrame( + { + "ts": [ + datetime(2020, 1, 1, 1), + datetime(2020, 1, 2, 2), + datetime(2020, 1, 3, 3), + ], + "n": [1, 2, 3], + } + ).sort("ts") + result = ( + df.group_by_dynamic("ts", every="1d", label=label) + .agg(pl.col("n"))["ts"] + .to_list() + ) + assert result == expected + + +def test_group_by_dynamic_slice_pushdown() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "a", "b"], "c": [1, 3, 5]}).lazy() + df = ( + df.sort("a") + .group_by_dynamic("a", group_by="b", every="2i") + .agg((pl.col("c") - pl.col("c").shift(fill_value=0)).sum().alias("c")) + ) + assert df.head(2).collect().to_dict(as_series=False) == { + "b": ["a", "a"], + "a": [0, 2], + "c": [1, 3], + } + + +def test_rolling_kernels_group_by_dynamic_7548() -> None: + assert pl.DataFrame( + {"time": pl.arange(0, 4, eager=True), "value": pl.arange(0, 4, eager=True)} + ).group_by_dynamic("time", every="1i", period="3i").agg( + pl.col("value"), + pl.col("value").min().alias("min_value"), + pl.col("value").max().alias("max_value"), + pl.col("value").sum().alias("sum_value"), + ).to_dict(as_series=False) == { + "time": [0, 1, 2, 3], + "value": [[0, 1, 2], [1, 2, 3], [2, 3], [3]], + "min_value": [0, 1, 2, 3], + "max_value": [2, 3, 3, 3], + "sum_value": [3, 6, 5, 3], + } + + +def test_rolling_dynamic_sortedness_check() -> None: + # when the by argument is passed, the sortedness flag + # will be unset as the take shuffles data, so we must explicitly + # check the sortedness + df = pl.DataFrame( + { + "idx": [1, 2, -1, 2, 1, 1], + "group": [1, 1, 1, 2, 2, 1], + } + ) + + with pytest.raises(ComputeError, match=r"input data is not sorted"): + df.group_by_dynamic("idx", every="2i", group_by="group").agg( + pl.col("idx").alias("idx1") + ) + + # no `by` argument + with pytest.raises( + InvalidOperationError, + match=r"argument in operation 'group_by_dynamic' is not sorted", + ): + df.group_by_dynamic("idx", every="2i").agg(pl.col("idx").alias("idx1")) + + +@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"]) +def test_group_by_dynamic_elementwise_following_mean_agg_6904( + time_zone: str | None, +) -> None: + df = ( + pl.DataFrame( + { + "a": [datetime(2021, 1, 1) + timedelta(seconds=2**i) for i in range(5)], + "b": [float(i) for i in range(5)], + } + ) + .with_columns(pl.col("a").dt.replace_time_zone(time_zone)) + .lazy() + .set_sorted("a") + .group_by_dynamic("a", every="10s", period="100s") + .agg([pl.col("b").mean().sin().alias("c")]) + .collect() + ) + assert_frame_equal( + df, + pl.DataFrame( + { + "a": [ + datetime(2021, 1, 1, 0, 0), + datetime(2021, 1, 1, 0, 0, 10), + ], + "c": [0.9092974268256817, -0.7568024953079282], + } + ).with_columns(pl.col("a").dt.replace_time_zone(time_zone)), + ) + + +@pytest.mark.parametrize("every", ["1h", timedelta(hours=1)]) +@pytest.mark.parametrize("tzinfo", [None, ZoneInfo("UTC"), ZoneInfo("Asia/Kathmandu")]) +def test_group_by_dynamic_lazy(every: str | timedelta, tzinfo: ZoneInfo | None) -> None: + ldf = pl.LazyFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16, tzinfo=tzinfo), + end=datetime(2021, 12, 16, 2, tzinfo=tzinfo), + interval="30m", + eager=True, + ), + "n": range(5), + } + ) + df = ( + ldf.group_by_dynamic("time", every=every, closed="right") + .agg( + [ + pl.col("time").min().alias("time_min"), + pl.col("time").max().alias("time_max"), + ] + ) + .collect() + ) + assert sorted(df.rows()) == [ + ( + datetime(2021, 12, 15, 23, 0, tzinfo=tzinfo), + datetime(2021, 12, 16, 0, 0, tzinfo=tzinfo), + datetime(2021, 12, 16, 0, 0, tzinfo=tzinfo), + ), + ( + datetime(2021, 12, 16, 0, 0, tzinfo=tzinfo), + datetime(2021, 12, 16, 0, 30, tzinfo=tzinfo), + datetime(2021, 12, 16, 1, 0, tzinfo=tzinfo), + ), + ( + datetime(2021, 12, 16, 1, 0, tzinfo=tzinfo), + datetime(2021, 12, 16, 1, 30, tzinfo=tzinfo), + datetime(2021, 12, 16, 2, 0, tzinfo=tzinfo), + ), + ] + + +def test_group_by_dynamic_validation() -> None: + df = pl.DataFrame( + { + "index": [0, 0, 1, 1], + "group": ["banana", "pear", "banana", "pear"], + "weight": [2, 3, 5, 7], + } + ) + + with pytest.raises(ComputeError, match="'every' argument must be positive"): + df.group_by_dynamic("index", group_by="group", every="-1i", period="2i").agg( + pl.col("weight") + ) + + +def test_no_sorted_no_error() -> None: + df = pl.DataFrame( + { + "dt": [datetime(2001, 1, 1), datetime(2001, 1, 2)], + } + ) + result = df.group_by_dynamic("dt", every="1h").agg(pl.len().alias("count")) + expected = pl.DataFrame( + { + "dt": [datetime(2001, 1, 1), datetime(2001, 1, 2)], + "count": [1, 1], + }, + schema_overrides={"count": pl.get_index_type()}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("tzinfo", [None, ZoneInfo("UTC"), ZoneInfo("Asia/Kathmandu")]) +def test_truncate_negative_offset(tzinfo: ZoneInfo | None) -> None: + time_zone = tzinfo.key if tzinfo is not None else None + df = pl.DataFrame( + { + "event_date": [ + datetime(2021, 4, 11), + datetime(2021, 4, 29), + datetime(2021, 5, 29), + ], + "adm1_code": [1, 2, 1], + } + ).set_sorted("event_date") + df = df.with_columns(pl.col("event_date").dt.replace_time_zone(time_zone)) + out = df.group_by_dynamic( + index_column="event_date", + every="1mo", + period="2mo", + offset="-1mo", + include_boundaries=True, + ).agg( + [ + pl.col("adm1_code"), + ] + ) + + assert out["event_date"].to_list() == [ + datetime(2021, 3, 1, tzinfo=tzinfo), + datetime(2021, 4, 1, tzinfo=tzinfo), + datetime(2021, 5, 1, tzinfo=tzinfo), + ] + df = pl.DataFrame( + { + "event_date": [ + datetime(2021, 4, 11), + datetime(2021, 4, 29), + datetime(2021, 5, 29), + ], + "adm1_code": [1, 2, 1], + "five_type": ["a", "b", "a"], + "actor": ["a", "a", "a"], + "admin": ["a", "a", "a"], + "fatalities": [10, 20, 30], + } + ).set_sorted("event_date") + df = df.with_columns(pl.col("event_date").dt.replace_time_zone(time_zone)) + + out = df.group_by_dynamic( + index_column="event_date", + every="1mo", + group_by=["admin", "five_type", "actor"], + ).agg([pl.col("adm1_code").unique(), (pl.col("fatalities") > 0).sum()]) + + assert out["event_date"].to_list() == [ + datetime(2021, 4, 1, tzinfo=tzinfo), + datetime(2021, 5, 1, tzinfo=tzinfo), + datetime(2021, 4, 1, tzinfo=tzinfo), + ] + + for dt in [pl.Int32, pl.Int64]: + df = ( + pl.DataFrame( + { + "idx": np.arange(6), + "A": ["A", "A", "B", "B", "B", "C"], + } + ) + .with_columns(pl.col("idx").cast(dt)) + .set_sorted("idx") + ) + + out = df.group_by_dynamic( + "idx", every="2i", period="3i", include_boundaries=True + ).agg(pl.col("A")) + + assert out.shape == (3, 4) + assert out["A"].to_list() == [ + ["A", "A", "B"], + ["B", "B", "B"], + ["B", "C"], + ] + + +def test_groupy_by_dynamic_median_10695() -> None: + df = pl.DataFrame( + { + "timestamp": pl.datetime_range( + datetime(2023, 8, 22, 15, 44, 30), + datetime(2023, 8, 22, 15, 48, 50), + "20s", + eager=True, + ), + "foo": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + } + ) + + assert df.group_by_dynamic( + index_column="timestamp", + every="60s", + period="3m", + ).agg(pl.col("foo").median()).to_dict(as_series=False) == { + "timestamp": [ + datetime(2023, 8, 22, 15, 44), + datetime(2023, 8, 22, 15, 45), + datetime(2023, 8, 22, 15, 46), + datetime(2023, 8, 22, 15, 47), + datetime(2023, 8, 22, 15, 48), + ], + "foo": [1.0, 1.0, 1.0, 1.0, 1.0], + } + + +def test_group_by_dynamic_when_conversion_crosses_dates_7274() -> None: + df = ( + pl.DataFrame( + data={ + "timestamp": ["1970-01-01 00:00:00+01:00", "1970-01-01 01:00:00+01:00"], + "value": [1, 1], + } + ) + .with_columns( + pl.col("timestamp") + .str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S%:z") + .dt.convert_time_zone("Africa/Lagos") + .set_sorted() + ) + .with_columns( + pl.col("timestamp") + .dt.convert_time_zone("UTC") + .alias("timestamp_utc") + .set_sorted() + ) + ) + result = df.group_by_dynamic( + index_column="timestamp", every="1d", closed="left" + ).agg(pl.col("value").count()) + expected = pl.DataFrame({"timestamp": [datetime(1970, 1, 1)], "value": [2]}) + expected = expected.with_columns( + pl.col("timestamp").dt.replace_time_zone("Africa/Lagos"), + pl.col("value").cast(pl.UInt32), + ) + assert_frame_equal(result, expected) + result = df.group_by_dynamic( + index_column="timestamp_utc", every="1d", closed="left" + ).agg(pl.col("value").count()) + expected = pl.DataFrame( + { + "timestamp_utc": [datetime(1969, 12, 31), datetime(1970, 1, 1)], + "value": [1, 1], + } + ) + expected = expected.with_columns( + pl.col("timestamp_utc").dt.replace_time_zone("UTC"), + pl.col("value").cast(pl.UInt32), + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("time_zone", [None, "UTC", "Asia/Kathmandu"]) +def test_default_negative_every_offset_dynamic_group_by(time_zone: str | None) -> None: + # 2791 + dts = [ + datetime(2020, 1, 1), + datetime(2020, 1, 2), + datetime(2020, 2, 1), + datetime(2020, 3, 1), + ] + df = pl.DataFrame({"dt": dts, "idx": range(len(dts))}).set_sorted("dt") + df = df.with_columns(pl.col("dt").dt.replace_time_zone(time_zone)) + out = df.group_by_dynamic(index_column="dt", every="1mo", closed="right").agg( + pl.col("idx") + ) + + expected = pl.DataFrame( + { + "dt": [ + datetime(2019, 12, 1, 0, 0), + datetime(2020, 1, 1, 0, 0), + datetime(2020, 2, 1, 0, 0), + ], + "idx": [[0], [1, 2], [3]], + } + ) + expected = expected.with_columns(pl.col("dt").dt.replace_time_zone(time_zone)) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize( + ("rule", "offset"), + [ + ("1h", timedelta(hours=2)), + ("1d", timedelta(days=2)), + ("1w", timedelta(weeks=2)), + ], +) +def test_group_by_dynamic_crossing_dst(rule: str, offset: timedelta) -> None: + start_dt = datetime(2021, 11, 7) + end_dt = start_dt + offset + date_range = pl.datetime_range( + start_dt, end_dt, rule, time_zone="US/Central", eager=True + ) + df = pl.DataFrame({"time": date_range, "value": range(len(date_range))}) + result = df.group_by_dynamic("time", every=rule, start_by="datapoint").agg( + pl.col("value").mean() + ) + expected = pl.DataFrame( + {"time": date_range, "value": range(len(date_range))}, + schema_overrides={"value": pl.Float64}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("start_by", "expected_time", "expected_value"), + [ + ( + "monday", + [ + datetime(2021, 11, 1), + datetime(2021, 11, 8), + ], + [0.0, 4.0], + ), + ( + "tuesday", + [ + datetime(2021, 11, 2), + datetime(2021, 11, 9), + ], + [0.5, 4.5], + ), + ( + "wednesday", + [ + datetime(2021, 11, 3), + datetime(2021, 11, 10), + ], + [1.0, 5.0], + ), + ( + "thursday", + [ + datetime(2021, 11, 4), + datetime(2021, 11, 11), + ], + [1.5, 5.5], + ), + ( + "friday", + [ + datetime(2021, 11, 5), + datetime(2021, 11, 12), + ], + [2.0, 6.0], + ), + ( + "saturday", + [ + datetime(2021, 11, 6), + datetime(2021, 11, 13), + ], + [2.5, 6.5], + ), + ( + "sunday", + [ + datetime(2021, 11, 7), + datetime(2021, 11, 14), + ], + [3.0, 7.0], + ), + ], +) +def test_group_by_dynamic_startby_monday_crossing_dst( + start_by: StartBy, expected_time: list[datetime], expected_value: list[float] +) -> None: + start_dt = datetime(2021, 11, 7) + end_dt = datetime(2021, 11, 14) + date_range = pl.datetime_range( + start_dt, end_dt, "1d", time_zone="US/Central", eager=True + ) + df = pl.DataFrame({"time": date_range, "value": range(len(date_range))}) + result = df.group_by_dynamic("time", every="1w", start_by=start_by).agg( + pl.col("value").mean() + ) + expected = pl.DataFrame( + {"time": expected_time, "value": expected_value}, + ) + expected = expected.with_columns(pl.col("time").dt.replace_time_zone("US/Central")) + assert_frame_equal(result, expected) + + +def test_group_by_dynamic_startby_monday_dst_8737() -> None: + start_dt = datetime(2021, 11, 6, 20) + stop_dt = datetime(2021, 11, 7, 20) + date_range = pl.datetime_range( + start_dt, stop_dt, "1d", time_zone="US/Central", eager=True + ) + df = pl.DataFrame({"time": date_range, "value": range(len(date_range))}) + result = df.group_by_dynamic("time", every="1w", start_by="monday").agg( + pl.col("value").mean() + ) + expected = pl.DataFrame( + { + "time": [ + datetime(2021, 11, 1), + ], + "value": [0.5], + }, + ) + expected = expected.with_columns(pl.col("time").dt.replace_time_zone("US/Central")) + assert_frame_equal(result, expected) + + +def test_group_by_dynamic_monthly_crossing_dst() -> None: + start_dt = datetime(2021, 11, 1) + end_dt = datetime(2021, 12, 1) + date_range = pl.datetime_range( + start_dt, end_dt, "1mo", time_zone="US/Central", eager=True + ) + df = pl.DataFrame({"time": date_range, "value": range(len(date_range))}) + result = df.group_by_dynamic("time", every="1mo").agg(pl.col("value").mean()) + expected = pl.DataFrame( + {"time": date_range, "value": range(len(date_range))}, + schema_overrides={"value": pl.Float64}, + ) + assert_frame_equal(result, expected) + + +def test_group_by_dynamic_2d_9333() -> None: + df = pl.DataFrame({"ts": [datetime(2000, 1, 1, 3)], "values": [10.0]}) + df = df.with_columns(pl.col("ts").set_sorted()) + result = df.group_by_dynamic("ts", every="2d").agg(pl.col("values")) + expected = pl.DataFrame({"ts": [datetime(1999, 12, 31, 0)], "values": [[10.0]]}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("every", ["1h", timedelta(hours=1)]) +@pytest.mark.parametrize("tzinfo", [None, ZoneInfo("UTC"), ZoneInfo("Asia/Kathmandu")]) +def test_group_by_dynamic_iter(every: str | timedelta, tzinfo: ZoneInfo | None) -> None: + time_zone = tzinfo.key if tzinfo is not None else None + df = pl.DataFrame( + { + "datetime": [ + datetime(2020, 1, 1, 10, 0), + datetime(2020, 1, 1, 10, 50), + datetime(2020, 1, 1, 11, 10), + ], + "a": [1, 2, 2], + "b": [4, 5, 6], + } + ).set_sorted("datetime") + df = df.with_columns(pl.col("datetime").dt.replace_time_zone(time_zone)) + + # Without 'by' argument + result1 = [ + (name, data.shape) + for name, data in df.group_by_dynamic("datetime", every=every, closed="left") + ] + expected1 = [ + ((datetime(2020, 1, 1, 10, tzinfo=tzinfo),), (2, 3)), + ((datetime(2020, 1, 1, 11, tzinfo=tzinfo),), (1, 3)), + ] + assert result1 == expected1 + + # With 'by' argument + result2 = [ + (name, data.shape) + for name, data in df.group_by_dynamic( + "datetime", every=every, closed="left", group_by="a" + ) + ] + expected2 = [ + ((1, datetime(2020, 1, 1, 10, tzinfo=tzinfo)), (1, 3)), + ((2, datetime(2020, 1, 1, 10, tzinfo=tzinfo)), (1, 3)), + ((2, datetime(2020, 1, 1, 11, tzinfo=tzinfo)), (1, 3)), + ] + assert result2 == expected2 + + +# https://github.com/pola-rs/polars/issues/11339 +@pytest.mark.parametrize("include_boundaries", [True, False]) +def test_group_by_dynamic_lazy_schema(include_boundaries: bool) -> None: + lf = pl.LazyFrame( + { + "dt": pl.datetime_range( + start=datetime(2022, 2, 10), + end=datetime(2022, 2, 12), + eager=True, + ), + "n": range(3), + } + ) + + result = lf.group_by_dynamic( + "dt", every="2d", closed="right", include_boundaries=include_boundaries + ).agg(pl.col("dt").min().alias("dt_min")) + + assert result.collect_schema() == result.collect().schema + + +def test_group_by_dynamic_12414() -> None: + df = pl.DataFrame( + { + "today": [ + date(2023, 3, 3), + date(2023, 8, 31), + date(2023, 9, 1), + date(2023, 9, 4), + ], + "b": [1, 2, 3, 4], + } + ).sort("today") + assert df.group_by_dynamic( + "today", + every="6mo", + period="3d", + closed="left", + start_by="datapoint", + include_boundaries=True, + ).agg( + gt_min_count=(pl.col.b >= (pl.col.b.min())).sum(), + ).to_dict(as_series=False) == { + "_lower_boundary": [datetime(2023, 3, 3, 0, 0), datetime(2023, 9, 3, 0, 0)], + "_upper_boundary": [datetime(2023, 3, 6, 0, 0), datetime(2023, 9, 6, 0, 0)], + "today": [date(2023, 3, 3), date(2023, 9, 3)], + "gt_min_count": [1, 1], + } + + +@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()]) +def test_group_by_dynamic_agg_input_types(input: Any) -> None: + df = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( + "index_column" + ) + result = df.group_by_dynamic( + index_column="index_column", every="2i", closed="right" + ).agg(input) + + expected = pl.LazyFrame({"index_column": [-2, 0, 2], "b": [1, 4, 2]}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("input", [str, "b".join]) +def test_group_by_dynamic_agg_bad_input_types(input: Any) -> None: + df = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( + "index_column" + ) + with pytest.raises(TypeError): + df.group_by_dynamic( + index_column="index_column", every="2i", closed="right" + ).agg(input) + + +def test_group_by_dynamic_15225() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "c": [1, 1, 2], + } + ) + result = df.group_by_dynamic("b", every="2d").agg(pl.sum("a")) + expected = pl.DataFrame({"b": [date(2020, 1, 1), date(2020, 1, 3)], "a": [3, 3]}) + assert_frame_equal(result, expected) + result = df.group_by_dynamic("b", every="2d", group_by="c").agg(pl.sum("a")) + expected = pl.DataFrame( + {"c": [1, 2], "b": [date(2020, 1, 1), date(2020, 1, 3)], "a": [3, 3]} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("start_by", ["window", "friday"]) +def test_earliest_point_included_when_offset_is_set_15241(start_by: StartBy) -> None: + df = pl.DataFrame( + data={ + "t": pl.Series( + [ + datetime(2024, 3, 22, 3, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 4, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 5, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 6, 0, tzinfo=timezone.utc), + ] + ), + "v": [1, 10, 100, 1000], + } + ).set_sorted("t") + result = df.group_by_dynamic( + index_column="t", + every="1d", + offset=timedelta(hours=5), + start_by=start_by, + ).agg("v") + expected = pl.DataFrame( + { + "t": [ + datetime(2024, 3, 21, 5, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 5, 0, tzinfo=timezone.utc), + ], + "v": [[1, 10], [100, 1000]], + } + ) + assert_frame_equal(result, expected) + + +def test_group_by_dynamic_invalid() -> None: + df = pl.DataFrame( + { + "values": [1, 4], + "times": [datetime(2020, 1, 3), datetime(2020, 1, 1)], + }, + ) + with pytest.raises( + InvalidOperationError, match="duration may not be a parsed integer" + ): + ( + df.sort("times") + .group_by_dynamic("times", every="3000i") + .agg(pl.col("values").sum().alias("sum")) + ) + with pytest.raises( + InvalidOperationError, match="duration must be a parsed integer" + ): + ( + df.with_row_index() + .group_by_dynamic("index", every="3000d") + .agg(pl.col("values").sum().alias("sum")) + ) + + +def test_group_by_dynamic_get() -> None: + df = pl.DataFrame( + { + "time": pl.date_range(pl.date(2021, 1, 1), pl.date(2021, 1, 8), eager=True), + "data": pl.arange(8, eager=True), + } + ) + + assert df.group_by_dynamic( + index_column="time", + every="2d", + period="3d", + start_by="datapoint", + ).agg( + get=pl.col("data").get(1), + ).to_dict(as_series=False) == { + "time": [ + date(2021, 1, 1), + date(2021, 1, 3), + date(2021, 1, 5), + date(2021, 1, 7), + ], + "get": [1, 3, 5, 7], + } + + +def test_group_by_dynamic_exclude_index_from_expansion_17075() -> None: + lf = pl.LazyFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "n": range(7), + "m": range(7), + } + ) + + assert lf.group_by_dynamic( + "time", every="1h", closed="right" + ).last().collect().to_dict(as_series=False) == { + "time": [ + datetime(2021, 12, 15, 23, 0), + datetime(2021, 12, 16, 0, 0), + datetime(2021, 12, 16, 1, 0), + datetime(2021, 12, 16, 2, 0), + ], + "n": [0, 2, 4, 6], + "m": [0, 2, 4, 6], + } + + +def test_group_by_dynamic_overlapping_19704() -> None: + df = pl.DataFrame( + { + "a": [datetime(2020, 1, 1), datetime(2020, 2, 1), datetime(2020, 3, 1)], + "b": [1, 2, 3], + } + ) + result = df.group_by_dynamic( + "a", every="1mo", period="45d", include_boundaries=True + ).agg(pl.col("b").sum()) + expected = pl.DataFrame( + { + "_lower_boundary": [ + datetime(2020, 1, 1, 0, 0), + datetime(2020, 2, 1, 0, 0), + datetime(2020, 3, 1, 0, 0), + ], + "_upper_boundary": [ + datetime(2020, 2, 15, 0, 0), + datetime(2020, 3, 17, 0, 0), + datetime(2020, 4, 15, 0, 0), + ], + "a": [ + datetime(2020, 1, 1, 0, 0), + datetime(2020, 2, 1, 0, 0), + datetime(2020, 3, 1, 0, 0), + ], + "b": [3, 5, 3], + } + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_has_nulls.py b/py-polars/tests/unit/operations/test_has_nulls.py new file mode 100644 index 000000000000..7a78c9a09ff8 --- /dev/null +++ b/py-polars/tests/unit/operations/test_has_nulls.py @@ -0,0 +1,57 @@ +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes, series + + +@given(s=series(allow_null=False)) +def test_has_nulls_series_no_nulls(s: pl.Series) -> None: + assert s.has_nulls() is False + + +@given(df=dataframes(allow_null=False)) +def test_has_nulls_expr_no_nulls(df: pl.DataFrame) -> None: + result = df.select(pl.all().has_nulls()) + assert result.select(pl.any_horizontal(df.columns)).item() is False + + +@given( + s=series( + excluded_dtypes=[ + pl.Struct, # https://github.com/pola-rs/polars/issues/3462 + ] + ) +) +def test_has_nulls_series_parametric(s: pl.Series) -> None: + result = s.has_nulls() + assert result == (s.null_count() > 0) + assert result == s.is_null().any() + + +@given( + lf=dataframes( + excluded_dtypes=[ + pl.Struct, # https://github.com/pola-rs/polars/issues/3462 + ], + lazy=True, + ) +) +def test_has_nulls_expr_parametric(lf: pl.LazyFrame) -> None: + result = lf.select(pl.all().has_nulls()) + + assert_frame_equal(result, lf.select(pl.all().null_count() > 0)) + assert_frame_equal(result, lf.select(pl.all().is_null().any())) + + +def test_has_nulls_series() -> None: + s = pl.Series([1, 2, None]) + assert s.has_nulls() is True + assert s[:2].has_nulls() is False + + +def test_has_nulls_expr() -> None: + lf = pl.LazyFrame({"a": [1, 2, None], "b": ["x", "y", "z"]}) + result = lf.select(pl.all().has_nulls()) + expected = pl.LazyFrame({"a": [True], "b": [False]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_hash.py b/py-polars/tests/unit/operations/test_hash.py new file mode 100644 index 000000000000..9311ae7c383f --- /dev/null +++ b/py-polars/tests/unit/operations/test_hash.py @@ -0,0 +1,11 @@ +import polars as pl + + +def test_hash_struct() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = df.select(pl.struct(pl.all())) + assert df.select(pl.col("a").hash())["a"].to_list() == [ + 2509182728763614268, + 6116564025432436932, + 49592145888590321, + ] diff --git a/py-polars/tests/unit/operations/test_hist.py b/py-polars/tests/unit/operations/test_hist.py new file mode 100644 index 000000000000..2f690294bd5f --- /dev/null +++ b/py-polars/tests/unit/operations/test_hist.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl +from polars import StringCache +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + +inf = float("inf") + + +@StringCache() +def test_hist_empty_data_no_inputs() -> None: + with pl.StringCache(): + s = pl.Series([], dtype=pl.UInt8) + + # No bins or edges specified: 10 bins around unit interval + expected = pl.DataFrame( + { + "breakpoint": pl.Series( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], dtype=pl.Float64 + ), + "category": pl.Series( + [ + "[0.0, 0.1]", + "(0.1, 0.2]", + "(0.2, 0.3]", + "(0.3, 0.4]", + "(0.4, 0.5]", + "(0.5, 0.6]", + "(0.6, 0.7]", + "(0.7, 0.8]", + "(0.8, 0.9]", + "(0.9, 1.0]", + ], + dtype=pl.Categorical, + ), + "count": pl.Series([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=pl.UInt32), + } + ) + result = s.hist() + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_empty_data_empty_bins() -> None: + s = pl.Series([], dtype=pl.UInt8) + + # No bins or edges specified: 10 bins around unit interval + expected = pl.DataFrame( + { + "breakpoint": pl.Series([], dtype=pl.Float64), + "category": pl.Series([], dtype=pl.Categorical), + "count": pl.Series([], dtype=pl.UInt32), + } + ) + result = s.hist(bins=[]) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_empty_data_single_bin_edge() -> None: + s = pl.Series([], dtype=pl.UInt8) + + # No bins or edges specified: 10 bins around unit interval + expected = pl.DataFrame( + { + "breakpoint": pl.Series([], dtype=pl.Float64), + "category": pl.Series([], dtype=pl.Categorical), + "count": pl.Series([], dtype=pl.UInt32), + } + ) + result = s.hist(bins=[2]) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_empty_data_valid_edges() -> None: + s = pl.Series([], dtype=pl.UInt8) + + # No bins or edges specified: 10 bins around unit interval + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.0, 3.0], dtype=pl.Float64), + "category": pl.Series(["[1.0, 2.0]", "(2.0, 3.0]"], dtype=pl.Categorical), + "count": pl.Series([0, 0], dtype=pl.UInt32), + } + ) + result = s.hist(bins=[1, 2, 3]) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_empty_data_invalid_edges() -> None: + s = pl.Series([], dtype=pl.UInt8) + with pytest.raises(ComputeError, match="bins must increase monotonically"): + s.hist(bins=[1, 0]) # invalid order + + +@StringCache() +def test_hist_empty_data_bad_bin_count() -> None: + s = pl.Series([], dtype=pl.UInt8) + with pytest.raises(OverflowError, match="can't convert negative int to unsigned"): + s.hist(bin_count=-1) # invalid order + + +@StringCache() +def test_hist_empty_data_zero_bin_count() -> None: + s = pl.Series([], dtype=pl.UInt8) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([], dtype=pl.Float64), + "category": pl.Series([], dtype=pl.Categorical), + "count": pl.Series([], dtype=pl.UInt32), + } + ) + result = s.hist(bin_count=0) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_empty_data_single_bin_count() -> None: + s = pl.Series([], dtype=pl.UInt8) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([1.0], dtype=pl.Float64), + "category": pl.Series(["[0.0, 1.0]"], dtype=pl.Categorical), + "count": pl.Series([0], dtype=pl.UInt32), + } + ) + result = s.hist(bin_count=1) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_empty_data_valid_bin_count() -> None: + s = pl.Series([], dtype=pl.UInt8) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([0.2, 0.4, 0.6, 0.8, 1.0], dtype=pl.Float64), + "category": pl.Series( + [ + "[0.0, 0.2]", + "(0.2, 0.4]", + "(0.4, 0.6]", + "(0.6, 0.8]", + "(0.8, 1.0]", + ], + dtype=pl.Categorical, + ), + "count": pl.Series([0, 0, 0, 0, 0], dtype=pl.UInt32), + } + ) + result = s.hist(bin_count=5) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_invalid_bin_count() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + with pytest.raises(OverflowError, match="can't convert negative int to unsigned"): + s.hist(bin_count=-1) # invalid order + + +@StringCache() +def test_hist_invalid_bins() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + with pytest.raises(ComputeError, match="bins must increase monotonically"): + s.hist(bins=[1, 0]) # invalid order + + +@StringCache() +def test_hist_bin_outside_data() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bins=[-10, -9]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([-9.0], dtype=pl.Float64), + "category": pl.Series(["[-10.0, -9.0]"], dtype=pl.Categorical), + "count": pl.Series([0], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_bins_between_data() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bins=[4.5, 10.5]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([10.5], dtype=pl.Float64), + "category": pl.Series(["[4.5, 10.5]"], dtype=pl.Categorical), + "count": pl.Series([0], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_bins_first_edge() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bins=[2, 3, 4]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([3.0, 4.0], dtype=pl.Float64), + "category": pl.Series(["[2.0, 3.0]", "(3.0, 4.0]"], dtype=pl.Categorical), + "count": pl.Series([1, 0], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_bins_last_edge() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bins=[-4, 0, 99, 100]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([0.0, 99.0, 100.0], dtype=pl.Float64), + "category": pl.Series( + [ + "[-4.0, 0.0]", + "(0.0, 99.0]", + "(99.0, 100.0]", + ], + dtype=pl.Categorical, + ), + "count": pl.Series([1, 3, 0], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_single_value_single_bin_count() -> None: + s = pl.Series([1], dtype=pl.Int32) + result = s.hist(bin_count=1) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([1.5], dtype=pl.Float64), + "category": pl.Series(["[0.5, 1.5]"], dtype=pl.Categorical), + "count": pl.Series([1], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_single_bin_count() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bin_count=1) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([99.0], dtype=pl.Float64), + "category": pl.Series(["[-5.0, 99.0]"], dtype=pl.Categorical), + "count": pl.Series([5], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_partial_covering() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bins=[-1.5, 2.5, 50, 105]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.5, 50.0, 105.0], dtype=pl.Float64), + "category": pl.Series( + ["[-1.5, 2.5]", "(2.5, 50.0]", "(50.0, 105.0]"], dtype=pl.Categorical + ), + "count": pl.Series([3, 0, 1], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_full_covering() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bins=[-5.5, 2.5, 50, 105]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.5, 50.0, 105.0], dtype=pl.Float64), + "category": pl.Series( + ["[-5.5, 2.5]", "(2.5, 50.0]", "(50.0, 105.0]"], dtype=pl.Categorical + ), + "count": pl.Series([4, 0, 1], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist_more_bins_than_data() -> None: + s = pl.Series([-5, 2, 0, 1, 99], dtype=pl.Int32) + result = s.hist(bin_count=8) + + # manually compute breaks + span = 99 - (-5) + width = span / 8 + breaks = [-5 + width * i for i in range(8 + 1)] + categories = [f"({breaks[i]}, {breaks[i + 1]}]" for i in range(8)] + categories[0] = f"[{categories[0][1:]}" + + expected = pl.DataFrame( + { + "breakpoint": pl.Series(breaks[1:], dtype=pl.Float64), + "category": pl.Series(categories, dtype=pl.Categorical), + "count": pl.Series([4, 0, 0, 0, 0, 0, 0, 1], dtype=pl.UInt32), + } + ) + assert_frame_equal(result, expected) + + +@StringCache() +def test_hist() -> None: + s = pl.Series("a", [1, 3, 8, 8, 2, 1, 3]) + out = s.hist(bin_count=4) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.75, 4.5, 6.25, 8.0], dtype=pl.Float64), + "category": pl.Series( + ["[1.0, 2.75]", "(2.75, 4.5]", "(4.5, 6.25]", "(6.25, 8.0]"], + dtype=pl.Categorical, + ), + "count": pl.Series([3, 2, 0, 2], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_hist_all_null() -> None: + s = pl.Series([None], dtype=pl.Float64) + out = s.hist() + expected = pl.DataFrame( + { + "breakpoint": pl.Series( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], dtype=pl.Float64 + ), + "category": pl.Series( + [ + "[0.0, 0.1]", + "(0.1, 0.2]", + "(0.2, 0.3]", + "(0.3, 0.4]", + "(0.4, 0.5]", + "(0.5, 0.6]", + "(0.6, 0.7]", + "(0.7, 0.8]", + "(0.8, 0.9]", + "(0.9, 1.0]", + ], + dtype=pl.Categorical, + ), + "count": pl.Series([0] * 10, dtype=pl.get_index_type()), + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("n_null", [0, 5]) +@pytest.mark.parametrize("n_values", [3, 10, 250]) +@pytest.mark.may_fail_auto_streaming +def test_hist_rand(n_values: int, n_null: int) -> None: + s_rand = pl.Series([None] * n_null, dtype=pl.Int64) + s_values = pl.Series(np.random.randint(0, 100, n_values), dtype=pl.Int64) + if s_values.n_unique() == 1: + pytest.skip("Identical values not tested.") + s = pl.concat((s_rand, s_values)) + out = s.hist(bin_count=10) + + bp = out["breakpoint"] + count = out["count"] + min_edge = s.min() - (s.max() - s.min()) * 0.001 # type: ignore[operator] + for i in range(out.height): + if i == 0: + lower = min_edge + else: + lower = bp[i - 1] + upper = bp[i] + + assert ((s <= upper) & (s > lower)).sum() == count[i] + + +def test_hist_floating_point() -> None: + # This test hits the specific floating point case where the bin width should be + # 5.7, but it is represented by 5.6999999. The result is that an item equal to the + # upper bound (72) exceeds the maximum bins. This tests the code path that catches + # this case. + n_values = 3 + n_null = 50 + + np.random.seed(2) + s_rand = pl.Series([None] * n_null, dtype=pl.Int64) + s_values = pl.Series(np.random.randint(0, 100, n_values), dtype=pl.Int64) + s = pl.concat((s_rand, s_values)) + out = s.hist(bin_count=10) + min_edge = s.min() - (s.max() - s.min()) * 0.001 # type: ignore[operator] + + bp = out["breakpoint"] + count = out["count"] + for i in range(out.height): + if i == 0: + lower = min_edge + else: + lower = bp[i - 1] + upper = bp[i] + + assert ((s <= upper) & (s > lower)).sum() == count[i] + + +def test_hist_max_boundary_19998() -> None: + s = pl.Series( + [ + 9514.988509739183, + 30738.098872148617, + 41400.15705103004, + 49093.06982022727, + ] + ) + result = s.hist(bin_count=50) + assert result["count"].sum() == 4 + + +def test_hist_max_boundary_20133() -> None: + # Given a set of values that result in bin index to be a floating point number that + # is represented as 5.000000000000001 + s = pl.Series( + [ + 6.197601318359375, + 83.5203145345052, + ] + ) + + # When histogram is calculated + result = s.hist(bin_count=5) + + # Then there is no exception (previously was possible to get index out of bounds + # here) and all the numbers fit into at least one of the bins + assert result["count"].sum() == 2 + + +@pytest.mark.may_fail_auto_streaming +def test_hist_same_values_20030() -> None: + out = pl.Series([1, 1]).hist(bin_count=2) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([1.0, 1.5], dtype=pl.Float64), + "category": pl.Series(["[0.5, 1.0]", "(1.0, 1.5]"], dtype=pl.Categorical), + "count": pl.Series([2, 0], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(out, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_hist_breakpoint_accuracy() -> None: + s = pl.Series([1, 2, 3, 4]) + out = s.hist(bin_count=3) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.0, 3.0, 4.0], dtype=pl.Float64), + "category": pl.Series( + ["[1.0, 2.0]", "(2.0, 3.0]", "(3.0, 4.0]"], dtype=pl.Categorical + ), + "count": pl.Series([2, 1, 1], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(out, expected) + + +def test_hist_ensure_max_value_20879() -> None: + s = pl.Series([-1 / 3, 0, 1, 2, 3, 7]) + with pl.StringCache(): + result = s.hist(bin_count=3) + expected = pl.DataFrame( + { + "breakpoint": pl.Series( + [ + 2.0 + 1 / 9, + 4.0 + 5 / 9, + 7.0, + ], + dtype=pl.Float64, + ), + "category": pl.Series( + [ + "[-0.333333, 2.111111]", + "(2.111111, 4.555556]", + "(4.555556, 7.0]", + ], + dtype=pl.Categorical, + ), + "count": pl.Series([4, 1, 1], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(result, expected) + + +def test_hist_ignore_nans_21082() -> None: + s = pl.Series([0.0, float("nan"), 0.5, float("nan"), 1.0]) + with pl.StringCache(): + result = s.hist(bins=[-0.001, 0.25, 0.5, 0.75, 1.0]) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64), + "category": pl.Series( + [ + "[-0.001, 0.25]", + "(0.25, 0.5]", + "(0.5, 0.75]", + "(0.75, 1.0]", + ], + dtype=pl.Categorical, + ), + "count": pl.Series([1, 1, 0, 1], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(result, expected) + + +def test_hist_include_lower_22056() -> None: + s = pl.Series("a", [1, 5]) + with pl.StringCache(): + result = s.hist(bins=[1, 5], include_category=True) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([5.0], dtype=pl.Float64), + "category": pl.Series(["[1.0, 5.0]"], dtype=pl.Categorical), + "count": pl.Series([2], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(result, expected) + + +def test_hist_ulp_edge_22234() -> None: + # Uniform path + s = pl.Series([1.0, 1e-16, 1.3e-16, -1.0]) + result = s.hist(bin_count=2) + assert result["count"].to_list() == [1, 3] + + # Manual path + result = s.hist(bins=[-1, 0, 1]) + assert result["count"].to_list() == [1, 3] diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py new file mode 100644 index 000000000000..eada0efe8859 --- /dev/null +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from hypothesis import example, given +from hypothesis import strategies as st + +import polars as pl +from polars.exceptions import InvalidOperationError + +if TYPE_CHECKING: + from polars._typing import IntoExpr +from polars.testing import assert_frame_equal + + +def isnan(value: object) -> bool: + if isinstance(value, int): + return False + if not isinstance(value, (np.number, float)): + return False + return np.isnan(value) # type: ignore[no-any-return] + + +def assert_index_of( + series: pl.Series, + value: IntoExpr, + convert_to_literal: bool = False, +) -> None: + """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" + if isnan(value): + expected_index = None + for i, o in enumerate(series.to_list()): + if o is not None and np.isnan(o): + expected_index = i + break + else: + try: + expected_index = series.to_list().index(value) + except ValueError: + expected_index = None + if expected_index == -1: + expected_index = None + + if convert_to_literal: + value = pl.lit(value, dtype=series.dtype) + + # Eager API: + print(series.index_of(value), expected_index, value, series) + assert series.index_of(value) == expected_index + # Lazy API: + assert pl.LazyFrame({"series": series}).select( + pl.col("series").index_of(value) + ).collect().get_column("series").to_list() == [expected_index] + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) +def test_float(dtype: pl.DataType) -> None: + values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan] + series = pl.Series(values, dtype=dtype) + sorted_series_asc = series.sort(descending=False) + sorted_series_desc = series.sort(descending=True) + chunked_series = pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False) + + extra_values = [ + np.int8(3), + np.int64(2**42), + np.float64(1.5), + np.float32(1.5), + np.float32(2**37), + np.float64(2**100), + ] + for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: + for value in values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + for value in extra_values: # type: ignore[assignment] + assert_index_of(s, value) + + # Explicitly check some extra-tricky edge cases: + assert series.index_of(-np.nan) == 1 # -np.nan should match np.nan + assert series.index_of(-0.0) == 6 # -0.0 should match 0.0 + + +def test_null() -> None: + series = pl.Series([None, None], dtype=pl.Null) + assert_index_of(series, None) + + +def test_empty() -> None: + series = pl.Series([], dtype=pl.Null) + assert_index_of(series, None) + series = pl.Series([], dtype=pl.Int64) + assert_index_of(series, None) + assert_index_of(series, 12) + assert_index_of(series.sort(descending=True), 12) + assert_index_of(series.sort(descending=False), 12) + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Int128, + ], +) +def test_integer(dtype: pl.DataType) -> None: + values = [ + 51, + 3, + None, + 4, + pl.select(dtype.max()).item(), # type: ignore[attr-defined] + pl.select(dtype.min()).item(), # type: ignore[attr-defined] + ] + series = pl.Series(values, dtype=dtype) + sorted_series_asc = series.sort(descending=False) + sorted_series_desc = series.sort(descending=True) + chunked_series = pl.concat( + [pl.Series([100, 7], dtype=dtype), series], rechunk=False + ) + + extra_values = [pl.select(v).item() for v in [dtype.max() - 1, dtype.min() + 1]] # type: ignore[attr-defined] + for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: + value: IntoExpr + for value in values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + for value in extra_values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + + # Can't cast floats: + for f in [np.float32(3.1), np.float64(3.1), 50.9]: + with pytest.raises(InvalidOperationError, match="cannot cast lossless"): + s.index_of(f) # type: ignore[arg-type] + + +def test_groupby() -> None: + df = pl.DataFrame( + {"label": ["a", "b", "a", "b", "a", "b"], "value": [10, 3, 20, 2, 40, 20]} + ) + expected = pl.DataFrame( + {"label": ["a", "b"], "value": [1, 2]}, + schema={"label": pl.String, "value": pl.UInt32}, + ) + assert_frame_equal( + df.group_by("label", maintain_order=True).agg(pl.col("value").index_of(20)), + expected, + ) + assert_frame_equal( + df.lazy() + .group_by("label", maintain_order=True) + .agg(pl.col("value").index_of(20)) + .collect(), + expected, + ) + + +LISTS_STRATEGY = st.lists( + st.one_of(st.none(), st.integers(min_value=10, max_value=50)), max_size=10 +) + + +@given( + list1=LISTS_STRATEGY, + list2=LISTS_STRATEGY, + list3=LISTS_STRATEGY, +) +# The examples are cases where this test previously caught bugs: +@example([], [], [None]) +def test_randomized( + list1: list[int | None], list2: list[int | None], list3: list[int | None] +) -> None: + series = pl.concat( + [pl.Series(values, dtype=pl.Int8) for values in [list1, list2, list3]], + rechunk=False, + ) + sorted_series = series.sort(descending=False) + sorted_series2 = series.sort(descending=True) + + # Values are between 10 and 50, plus add None and max/min range values: + for i in set(range(10, 51)) | {-128, 127, None}: + assert_index_of(series, i) + assert_index_of(sorted_series, i) + assert_index_of(sorted_series2, i) + + +ENUM = pl.Enum(["a", "b", "c"]) + + +@pytest.mark.parametrize( + ("series", "extra_values", "sortable"), + [ + (pl.Series(["abc", None, "bb"]), ["", "🚲"], True), + (pl.Series([True, None, False, True, False]), [], True), + ( + pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]), + [datetime(2023, 12, 12, 16, 12, 39)], + True, + ), + ( + pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]), + [date(2023, 12, 12)], + True, + ), + ( + pl.Series([time(16, 12, 31), None, time(11, 10, 53)]), + [time(11, 12, 16)], + True, + ), + ( + pl.Series( + [timedelta(hours=12), None, timedelta(minutes=3)], + ), + [timedelta(minutes=17)], + True, + ), + (pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True), + ( + pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]), + [[[5, 7]], []], + True, + ), + ( + pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)), + [[5, 7], [None, None]], + True, + ), + ( + pl.Series( + [[[1, 2]], [None], [[4, 5]], None, [[None, 3]]], + dtype=pl.Array(pl.Array(pl.Int64(), 2), 1), + ), + [[[5, 7]], [[None, None]]], + True, + ), + ( + pl.Series( + [{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}], + dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}), + ), + [{"a": 7, "b": None}, {"a": 6, "b": 4}], + False, + ), + (pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True), + (pl.Series([Decimal(12), None, Decimal(3)]), [Decimal(4)], True), + ], +) +def test_other_types( + series: pl.Series, extra_values: list[Any], sortable: bool +) -> None: + expected_values = series.to_list() + series_variants = [series, series.drop_nulls()] + if sortable: + series_variants.extend( + [ + series.sort(descending=False), + series.sort(descending=True), + ] + ) + for s in series_variants: + for value in expected_values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + # Extra values may not be expressible as literal of correct dtype, so + # don't try: + for value in extra_values: + assert_index_of(s, value) + + +# Before the output type would be list[idx-type] when no item was found +def test_non_found_correct_type() -> None: + df = pl.DataFrame( + [ + pl.Series("a", [0, 1], pl.Int32), + pl.Series("b", [1, 2], pl.Int32), + ] + ) + + assert_frame_equal( + df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1)), + pl.DataFrame({"a": [0, 1], "b": [0, None]}), + check_dtypes=False, + ) + + +def test_error_on_multiple_values() -> None: + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="needle of `index_of` can only contain", + ): + pl.Series("a", [1, 2, 3]).index_of(pl.Series([2, 3])) + + +@pytest.mark.parametrize( + "convert_to_literal", + [ + True, + False, + ], +) +def test_enum(convert_to_literal: bool) -> None: + series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"])) + expected_values = series.to_list() + for s in [ + series, + series.drop_nulls(), + series.sort(descending=False), + series.sort(descending=True), + ]: + for value in expected_values: + assert_index_of(s, value, convert_to_literal=convert_to_literal) + + +@pytest.mark.parametrize( + "convert_to_literal", + [ + pytest.param( + True, + marks=pytest.mark.xfail( + reason="https://github.com/pola-rs/polars/issues/20318" + ), + ), + pytest.param( + False, + marks=pytest.mark.xfail( + reason="https://github.com/pola-rs/polars/issues/20171" + ), + ), + ], +) +def test_categorical(convert_to_literal: bool) -> None: + series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical) + expected_values = series.to_list() + for s in [ + series, + series.drop_nulls(), + series.sort(descending=False), + series.sort(descending=True), + ]: + for value in expected_values: + assert_index_of(s, value, convert_to_literal=convert_to_literal) diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py new file mode 100644 index 000000000000..9b8ae6a9030a --- /dev/null +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -0,0 +1,683 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric.strategies import series + +if TYPE_CHECKING: + from hypothesis.strategies import DrawFn, SearchStrategy + + +@pytest.mark.parametrize( + ("pred_1", "pred_2"), + [ + (pl.col("time") > pl.col("time_right"), pl.col("cost") < pl.col("cost_right")), + (pl.col("time_right") < pl.col("time"), pl.col("cost_right") > pl.col("cost")), + ], +) +def test_self_join(pred_1: pl.Expr, pred_2: pl.Expr) -> None: + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [100, 140, 80, 90], + "cost": [6, 11, 10, 5], + "cores": [4, 2, 1, 4], + } + ) + + actual = west.join_where(west, pred_1, pred_2) + + expected = pl.DataFrame( + { + "t_id": [742, 404], + "time": [90, 100], + "cost": [5, 6], + "cores": [4, 4], + "t_id_right": [676, 676], + "time_right": [80, 80], + "cost_right": [10, 10], + "cores_right": [1, 1], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +def test_basic_ie_join() -> None: + east = pl.DataFrame( + { + "id": [100, 101, 102], + "dur": [140, 100, 90], + "rev": [12, 12, 5], + "cores": [2, 8, 4], + } + ) + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [100, 140, 80, 90], + "cost": [6, 11, 10, 5], + "cores": [4, 2, 1, 4], + } + ) + + actual = east.join_where( + west, + pl.col("dur") < pl.col("time"), + pl.col("rev") > pl.col("cost"), + ) + + expected = pl.DataFrame( + { + "id": [101], + "dur": [100], + "rev": [12], + "cores": [8], + "t_id": [498], + "time": [140], + "cost": [11], + "cores_right": [2], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@given( + offset=st.integers(-6, 5), + length=st.integers(0, 6), +) +def test_ie_join_with_slice(offset: int, length: int) -> None: + east = pl.DataFrame( + { + "id": [100, 101, 102], + "dur": [120, 140, 160], + "rev": [12, 14, 16], + "cores": [2, 8, 4], + } + ).lazy() + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [90, 130, 150, 170], + "cost": [9, 13, 15, 16], + "cores": [4, 2, 1, 4], + } + ).lazy() + + actual = ( + east.join_where( + west, + pl.col("dur") < pl.col("time"), + pl.col("rev") < pl.col("cost"), + ) + .slice(offset, length) + .collect() + ) + + expected_full = pl.DataFrame( + { + "id": [101, 101, 100, 100, 100], + "dur": [140, 140, 120, 120, 120], + "rev": [14, 14, 12, 12, 12], + "cores": [8, 8, 2, 2, 2], + "t_id": [676, 742, 498, 676, 742], + "time": [150, 170, 130, 150, 170], + "cost": [15, 16, 13, 15, 16], + "cores_right": [1, 4, 2, 1, 4], + } + ) + # The ordering of the result is arbitrary, so we can + # only verify that each row of the slice is present in the full expected result. + assert len(actual) == len(expected_full.slice(offset, length)) + + expected_rows = set(expected_full.iter_rows()) + for row in actual.iter_rows(): + assert row in expected_rows, f"{row} not in expected rows" + + +def test_ie_join_with_expressions() -> None: + east = pl.DataFrame( + { + "id": [100, 101, 102], + "dur": [70, 50, 45], + "rev": [12, 12, 5], + "cores": [2, 8, 4], + } + ) + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [100, 140, 80, 90], + "cost": [12, 22, 20, 10], + "cores": [4, 2, 1, 4], + } + ) + + actual = east.join_where( + west, + (pl.col("dur") * 2) < pl.col("time"), + pl.col("rev") > (pl.col("cost").cast(pl.Int32) // 2).cast(pl.Int64), + ) + + expected = pl.DataFrame( + { + "id": [101], + "dur": [50], + "rev": [12], + "cores": [8], + "t_id": [498], + "time": [140], + "cost": [22], + "cores_right": [2], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@pytest.mark.parametrize( + "range_constraint", + [ + [ + # can write individual components + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + ], + [ + # or a single `is_between` expression + pl.col("time").is_between("start_time", "end_time", closed="left") + ], + ], +) +def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None: + left = pl.DataFrame( + { + "id": [0, 1, 2, 3, 4, 5], + "group": [0, 0, 0, 1, 1, 1], + "time": [ + datetime(2024, 8, 26, 15, 34, 30), + datetime(2024, 8, 26, 15, 35, 30), + datetime(2024, 8, 26, 15, 36, 30), + datetime(2024, 8, 26, 15, 37, 30), + datetime(2024, 8, 26, 15, 38, 0), + datetime(2024, 8, 26, 15, 39, 0), + ], + } + ) + right = pl.DataFrame( + { + "id": [0, 1, 2], + "group": [0, 1, 1], + "start_time": [ + datetime(2024, 8, 26, 15, 34, 0), + datetime(2024, 8, 26, 15, 35, 0), + datetime(2024, 8, 26, 15, 38, 0), + ], + "end_time": [ + datetime(2024, 8, 26, 15, 36, 0), + datetime(2024, 8, 26, 15, 37, 0), + datetime(2024, 8, 26, 15, 39, 0), + ], + } + ) + + actual = left.join_where(right, *range_constraint).select("id", "id_right") + + expected = pl.DataFrame( + { + "id": [0, 1, 1, 2, 4], + "id_right": [0, 0, 1, 1, 2], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("group_right") == pl.col("group"), + *range_constraint, + ) + .select("id", "id_right", "group") + .sort("id") + ) + + explained = q.explain() + assert "INNER JOIN" in explained + assert "FILTER" in explained + actual = q.collect() + + expected = ( + left.join(right, how="cross") + .filter(pl.col("group") == pl.col("group_right"), *range_constraint) + .select("id", "id_right", "group") + .sort("id") + ) + assert_frame_equal(actual, expected, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("group") != pl.col("group_right"), + *range_constraint, + ) + .select("id", "id_right", "group") + .sort("id") + ) + + explained = q.explain() + assert "IEJOIN" in explained + assert "FILTER" in explained + actual = q.collect() + + expected = ( + left.join(right, how="cross") + .filter(pl.col("group") != pl.col("group_right"), *range_constraint) + .select("id", "id_right", "group") + .sort("id") + ) + assert_frame_equal(actual, expected, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("group") != pl.col("group_right"), + ) + .select("id", "group", "group_right") + .sort("id") + .select("group", "group_right") + ) + + explained = q.explain() + assert "NESTED LOOP" in explained + actual = q.collect() + assert actual.to_dict(as_series=False) == { + "group": [0, 0, 0, 0, 0, 0, 1, 1, 1], + "group_right": [1, 1, 1, 1, 1, 1, 0, 0, 0], + } + + +def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr: + if op == "<": + return pl.col(col1) < pl.col(col2) + elif op == "<=": + return pl.col(col1) <= pl.col(col2) + elif op == ">": + return pl.col(col1) > pl.col(col2) + elif op == ">=": + return pl.col(col1) >= pl.col(col2) + else: + message = f"Invalid operator '{op}'" + raise ValueError(message) + + +def operators() -> SearchStrategy[str]: + valid_operators = ["<", "<=", ">", ">="] + return st.sampled_from(valid_operators) + + +@st.composite +def east_df( + draw: DrawFn, with_nulls: bool = False, use_floats: bool = False +) -> pl.DataFrame: + height = draw(st.integers(min_value=0, max_value=20)) + + if use_floats: + dur_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + rev_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + dur_dtype: type[pl.DataType] = pl.Float32 + rev_dtype: type[pl.DataType] = pl.Float32 + else: + dur_strategy = st.integers(min_value=100, max_value=105) + rev_strategy = st.integers(min_value=9, max_value=13) + dur_dtype = pl.Int64 + rev_dtype = pl.Int64 + + if with_nulls: + dur_strategy = dur_strategy | st.none() + rev_strategy = rev_strategy | st.none() + + cores_strategy = st.integers(min_value=1, max_value=10) + + ids = np.arange(0, height) + dur = draw(st.lists(dur_strategy, min_size=height, max_size=height)) + rev = draw(st.lists(rev_strategy, min_size=height, max_size=height)) + cores = draw(st.lists(cores_strategy, min_size=height, max_size=height)) + + return pl.DataFrame( + [ + pl.Series("id", ids, dtype=pl.Int64), + pl.Series("dur", dur, dtype=dur_dtype), + pl.Series("rev", rev, dtype=rev_dtype), + pl.Series("cores", cores, dtype=pl.Int64), + ] + ) + + +@st.composite +def west_df( + draw: DrawFn, with_nulls: bool = False, use_floats: bool = False +) -> pl.DataFrame: + height = draw(st.integers(min_value=0, max_value=20)) + + if use_floats: + time_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + cost_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + time_dtype: type[pl.DataType] = pl.Float32 + cost_dtype: type[pl.DataType] = pl.Float32 + else: + time_strategy = st.integers(min_value=100, max_value=105) + cost_strategy = st.integers(min_value=9, max_value=13) + time_dtype = pl.Int64 + cost_dtype = pl.Int64 + + if with_nulls: + time_strategy = time_strategy | st.none() + cost_strategy = cost_strategy | st.none() + + cores_strategy = st.integers(min_value=1, max_value=10) + + t_id = np.arange(100, 100 + height) + time = draw(st.lists(time_strategy, min_size=height, max_size=height)) + cost = draw(st.lists(cost_strategy, min_size=height, max_size=height)) + cores = draw(st.lists(cores_strategy, min_size=height, max_size=height)) + + return pl.DataFrame( + [ + pl.Series("t_id", t_id, dtype=pl.Int64), + pl.Series("time", time, dtype=time_dtype), + pl.Series("cost", cost, dtype=cost_dtype), + pl.Series("cores", cores, dtype=pl.Int64), + ] + ) + + +@given( + east=east_df(), + west=west_df(), + op1=operators(), + op2=operators(), +) +def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) -> None: + expr0 = _inequality_expression("dur", op1, "time") + expr1 = _inequality_expression("rev", op2, "cost") + + actual = east.join_where(west, expr0 & expr1) + + expected = east.join(west, how="cross").filter(expr0 & expr1) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@given( + east=east_df(with_nulls=True), + west=west_df(with_nulls=True), + op1=operators(), + op2=operators(), +) +def test_ie_join_with_nulls( + east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str +) -> None: + expr0 = _inequality_expression("dur", op1, "time") + expr1 = _inequality_expression("rev", op2, "cost") + + actual = east.join_where(west, expr0 & expr1) + + expected = east.join(west, how="cross").filter(expr0 & expr1) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@given( + east=east_df(use_floats=True), + west=west_df(use_floats=True), + op1=operators(), + op2=operators(), +) +def test_ie_join_with_floats( + east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str +) -> None: + expr0 = _inequality_expression("dur", op1, "time") + expr1 = _inequality_expression("rev", op2, "cost") + + actual = east.join_where(west, expr0, expr1) + + expected = east.join(west, how="cross").filter(expr0 & expr1) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +def test_raise_invalid_input_join_where() -> None: + df = pl.DataFrame({"id": [1, 2]}) + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="expected join keys/predicates", + ): + df.join_where(df) + + +def test_ie_join_use_keys_multiple() -> None: + a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]}) + b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]}) + + assert a.join_where( + b, + pl.col.a >= pl.col.b, + pl.col.a <= pl.col.b, + ).collect().sort("x_right").to_dict(as_series=False) == { + "a": [2, 2, 2], + "x": [2, 2, 2], + "b": [2, 2, 2], + "x_right": [1, 3, 7], + } + + +@given( + left=series( + dtype=pl.Int64, + strategy=st.integers(min_value=0, max_value=10) | st.none(), + max_size=10, + ), + right=series( + dtype=pl.Int64, + strategy=st.integers(min_value=-10, max_value=10) | st.none(), + max_size=10, + ), + op=operators(), +) +def test_single_inequality(left: pl.Series, right: pl.Series, op: str) -> None: + expr = _inequality_expression("x", op, "y") + + left_df = pl.DataFrame( + { + "id": np.arange(len(left)), + "x": left, + } + ) + right_df = pl.DataFrame( + { + "id": np.arange(len(right)), + "y": right, + } + ) + + actual = left_df.join_where(right_df, expr) + + expected = left_df.join(right_df, how="cross").filter(expr) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@given( + offset=st.integers(-6, 5), + length=st.integers(0, 6), +) +def test_single_inequality_with_slice(offset: int, length: int) -> None: + left = pl.DataFrame( + { + "id": list(range(8)), + "x": [0, 1, 1, 2, 3, 5, 5, 7], + } + ) + right = pl.DataFrame( + { + "id": list(range(6)), + "y": [-1, 2, 4, 4, 6, 9], + } + ) + + expr = pl.col("x") > pl.col("y") + actual = left.join_where(right, expr).slice(offset, length) + + expected_full = left.join(right, how="cross").filter(expr) + + assert len(actual) == len(expected_full.slice(offset, length)) + + expected_rows = set(expected_full.iter_rows()) + for row in actual.iter_rows(): + assert row in expected_rows, f"{row} not in expected rows" + + +def test_ie_join_projection_pd_19005() -> None: + lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).with_row_index() + q = ( + lf.join_where( + lf, + pl.col.index < pl.col.index_right, + pl.col.index.cast(pl.Int64) + pl.col.a > pl.col.a_right, + ) + .group_by(pl.col.index) + .agg(pl.col.index_right) + ) + + out = q.collect() + assert out.schema == pl.Schema( + [("index", pl.get_index_type()), ("index_right", pl.List(pl.get_index_type()))] + ) + assert out.shape == (0, 2) + + +def test_single_sided_predicate() -> None: + left = pl.LazyFrame({"a": [1, -1, 2]}).with_row_index() + right = pl.LazyFrame({"b": [1, 2]}) + + result = ( + left.join_where(right, pl.col.index >= pl.col.a) + .collect() + .sort("index", "a", "b") + ) + expected = pl.DataFrame( + { + "index": pl.Series([1, 1, 2, 2], dtype=pl.get_index_type()), + "a": [-1, -1, 2, 2], + "b": [1, 2, 1, 2], + } + ) + assert_frame_equal(result, expected) + + +def test_join_on_strings() -> None: + df = pl.LazyFrame( + { + "a": ["a", "b", "c"], + "b": ["b", "b", "b"], + } + ) + + q = df.join_where(df, pl.col("a").ge(pl.col("a_right"))) + + assert "NESTED LOOP JOIN" in q.explain() + # Note: Output is flaky without sort when POLARS_MAX_THREADS=1 + assert q.collect().sort(pl.all()).to_dict(as_series=False) == { + "a": ["a", "b", "b", "c", "c", "c"], + "b": ["b", "b", "b", "b", "b", "b"], + "a_right": ["a", "a", "b", "a", "b", "c"], + "b_right": ["b", "b", "b", "b", "b", "b"], + } + + +def test_join_partial_column_name_overlap_19119() -> None: + left = pl.LazyFrame({"a": [1], "b": [2]}) + right = pl.LazyFrame({"a": [2], "d": [0]}) + + q = left.join_where(right, pl.col("a") > pl.col("d")) + + assert q.collect().to_dict(as_series=False) == { + "a": [1], + "b": [2], + "a_right": [2], + "d": [0], + } + + +def test_join_predicate_pushdown_19580() -> None: + left = pl.LazyFrame( + { + "a": [1, 2, 3, 1], + "b": [1, 2, 3, 4], + "c": [2, 3, 4, 5], + } + ) + + right = pl.LazyFrame({"a": [1, 3], "c": [2, 4], "d": [6, 3]}) + + q = left.join_where( + right, + pl.col("b") < pl.col("c_right"), + pl.col("a") < pl.col("a_right"), + pl.col("a") < pl.col("d"), + ) + + expect = ( + left.join(right, how="cross") + .collect() + .filter( + (pl.col("a") < pl.col("d")) + & (pl.col("b") < pl.col("c_right")) + & (pl.col("a") < pl.col("a_right")) + ) + ) + + assert_frame_equal(expect, q.collect(), check_row_order=False) + + +def test_join_where_literal_20061() -> None: + df_left = pl.DataFrame( + {"id": [1, 2, 3], "value_left": [10, 20, 30], "flag": [1, 0, 1]} + ) + + df_right = pl.DataFrame( + { + "id": [1, 2, 3], + "value_right": [5, 5, 25], + "flag": [1, 0, 1], + } + ) + + assert df_left.join_where( + df_right, + pl.col("value_left") > pl.col("value_right"), + pl.col("flag_right") == pl.lit(1, dtype=pl.Int8), + ).sort(pl.all()).to_dict(as_series=False) == { + "id": [1, 2, 3, 3], + "value_left": [10, 20, 30, 30], + "flag": [1, 0, 1, 1], + "id_right": [1, 1, 1, 3], + "value_right": [5, 5, 5, 25], + "flag_right": [1, 1, 1, 1], + } + + +def test_boolean_predicate_join_where() -> None: + urls = pl.LazyFrame({"url": "abcd.com/page"}) + categories = pl.LazyFrame({"base_url": "abcd.com", "category": "landing page"}) + assert ( + "NESTED LOOP JOIN" + in urls.join_where( + categories, pl.col("url").str.starts_with(pl.col("base_url")) + ).explain() + ) diff --git a/py-polars/tests/unit/operations/test_interpolate.py b/py-polars/tests/unit/operations/test_interpolate.py new file mode 100644 index 000000000000..5d39ffc751fa --- /dev/null +++ b/py-polars/tests/unit/operations/test_interpolate.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal +from tests.unit.conftest import NUMERIC_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType, PolarsTemporalType + +from zoneinfo import ZoneInfo + + +@pytest.mark.parametrize( + ("input_dtype", "output_dtype"), + [ + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Int128, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_interpolate_linear( + input_dtype: PolarsDataType, output_dtype: PolarsDataType +) -> None: + df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="linear")) + assert result.collect_schema()["a"] == output_dtype + expected = pl.DataFrame( + {"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype} + ) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize( + ("input", "input_dtype", "output"), + [ + ( + [date(2020, 1, 1), None, date(2020, 1, 2)], + pl.Date, + [date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)], + ), + ( + [datetime(2020, 1, 1), None, datetime(2020, 1, 2)], + pl.Datetime("ms"), + [datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)], + ), + ( + [ + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")), + None, + datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")), + ], + pl.Datetime("us", "Asia/Kathmandu"), + [ + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(2020, 1, 1, 12, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")), + ], + ), + ([time(1), None, time(2)], pl.Time, [time(1), time(1, 30), time(2)]), + ( + [timedelta(1), None, timedelta(2)], + pl.Duration("ms"), + [timedelta(1), timedelta(1, hours=12), timedelta(2)], + ), + ], +) +def test_interpolate_temporal_linear( + input: list[Any], input_dtype: PolarsTemporalType, output: list[Any] +) -> None: + df = pl.LazyFrame({"a": input}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="linear")) + assert result.collect_schema()["a"] == input_dtype + expected = pl.DataFrame({"a": output}, schema={"a": input_dtype}) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize("input_dtype", NUMERIC_DTYPES) +def test_interpolate_nearest(input_dtype: PolarsDataType) -> None: + df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="nearest")) + assert result.collect_schema()["a"] == input_dtype + expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema={"a": input_dtype}) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize( + ("input", "input_dtype", "output"), + [ + ( + [date(2020, 1, 1), None, date(2020, 1, 2)], + pl.Date, + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)], + ), + ( + [datetime(2020, 1, 1), None, datetime(2020, 1, 2)], + pl.Datetime("ms"), + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)], + ), + ( + [ + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")), + None, + datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")), + ], + pl.Datetime("us", "Asia/Kathmandu"), + [ + datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")), + datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")), + ], + ), + ([time(1), None, time(2)], pl.Time, [time(1), time(2), time(2)]), + ( + [timedelta(1), None, timedelta(2)], + pl.Duration("ms"), + [timedelta(1), timedelta(2), timedelta(2)], + ), + ], +) +def test_interpolate_temporal_nearest( + input: list[Any], input_dtype: PolarsTemporalType, output: list[Any] +) -> None: + df = pl.LazyFrame({"a": input}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="nearest")) + assert result.collect_schema()["a"] == input_dtype + expected = pl.DataFrame({"a": output}, schema={"a": input_dtype}) + assert_frame_equal(result.collect(), expected) diff --git a/py-polars/tests/unit/operations/test_interpolate_by.py b/py-polars/tests/unit/operations/test_interpolate_by.py new file mode 100644 index 000000000000..2a6335ee6bad --- /dev/null +++ b/py-polars/tests/unit/operations/test_interpolate_by.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from datetime import date +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import assume, given + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@pytest.mark.parametrize( + "times_dtype", + [ + pl.Datetime("ms"), + pl.Datetime("us", "Asia/Kathmandu"), + pl.Datetime("ns"), + pl.Date, + pl.Int64, + pl.Int32, + pl.UInt64, + pl.UInt32, + pl.Float32, + pl.Float64, + ], +) +@pytest.mark.parametrize( + "values_dtype", + [ + pl.Float64, + pl.Float32, + pl.Int64, + pl.Int32, + pl.UInt64, + pl.UInt32, + ], +) +def test_interpolate_by( + values_dtype: PolarsDataType, times_dtype: PolarsDataType +) -> None: + df = pl.DataFrame( + { + "times": [ + 1, + 3, + 10, + 11, + 12, + 16, + 21, + 30, + ], + "values": [1, None, None, 5, None, None, None, 6], + }, + schema={"times": times_dtype, "values": values_dtype}, + ) + result = df.select(pl.col("values").interpolate_by("times")) + expected = pl.DataFrame( + { + "values": [ + 1.0, + 1.7999999999999998, + 4.6, + 5.0, + 5.052631578947368, + 5.2631578947368425, + 5.526315789473684, + 6.0, + ] + } + ) + if values_dtype == pl.Float32: + expected = expected.select(pl.col("values").cast(pl.Float32)) + assert_frame_equal(result, expected) + result = ( + df.sort("times", descending=True) + .with_columns(pl.col("values").interpolate_by("times")) + .sort("times") + .drop("times") + ) + assert_frame_equal(result, expected) + + +def test_interpolate_by_leading_nulls() -> None: + df = pl.DataFrame( + { + "times": [ + date(2020, 1, 1), + date(2020, 1, 1), + date(2020, 1, 1), + date(2020, 1, 1), + date(2020, 1, 3), + date(2020, 1, 10), + date(2020, 1, 11), + ], + "values": [None, None, None, 1, None, None, 5], + } + ) + result = df.select(pl.col("values").interpolate_by("times")) + expected = pl.DataFrame({"values": [None, None, None, 1.0, 1.8, 4.6, 5.0]}) + assert_frame_equal(result, expected) + result = ( + df.sort("times", maintain_order=True, descending=True) + .with_columns(pl.col("values").interpolate_by("times")) + .sort("times", maintain_order=True) + .drop("times") + ) + assert_frame_equal(result, expected, check_exact=False) + + +@pytest.mark.parametrize("dataset", ["floats", "dates"]) +def test_interpolate_by_trailing_nulls(dataset: str) -> None: + input_data = { + "dates": pl.DataFrame( + { + "times": [ + date(2020, 1, 1), + date(2020, 1, 3), + date(2020, 1, 10), + date(2020, 1, 11), + date(2020, 1, 12), + date(2020, 1, 13), + ], + "values": [1, None, None, 5, None, None], + } + ), + "floats": pl.DataFrame( + { + "times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1], + "values": [1, None, None, 5, None, None], + } + ), + } + + expected_data = { + "dates": pl.DataFrame( + {"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]} + ), + "floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}), + } + + df = input_data[dataset] + expected = expected_data[dataset] + + result = df.select(pl.col("values").interpolate_by("times")) + + assert_frame_equal(result, expected) + result = ( + df.sort("times", descending=True) + .with_columns(pl.col("values").interpolate_by("times")) + .sort("times") + .drop("times") + ) + assert_frame_equal(result, expected) + + +@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, pl.Float64])) +def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None: + if x_dtype == pl.Float64: + by_strategy = st.floats( + min_value=-1e150, + max_value=1e150, + allow_nan=False, + allow_infinity=False, + allow_subnormal=False, + ) + else: + by_strategy = None + + dataframe = ( + data.draw( + dataframes( + [ + column( + "ts", + dtype=x_dtype, + allow_null=False, + strategy=by_strategy, + ), + column( + "value", + dtype=pl.Float64, + allow_null=True, + ), + ], + min_size=1, + ) + ) + .sort("ts") + .fill_nan(None) + .unique("ts") + ) + + if x_dtype == pl.Float64: + assume(not dataframe["ts"].is_nan().any()) + assume(not dataframe["ts"].is_null().any()) + assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any()) + + assume(not dataframe["value"].is_null().all()) + assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any()) + + dataframe = dataframe.sort("ts") + + result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"] + + mask = dataframe["value"].is_not_null() + + np_dtype = "int64" if x_dtype == pl.Date else "float64" + x = dataframe["ts"].to_numpy().astype(np_dtype) + xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype) + yp = dataframe["value"].filter(mask).to_numpy().astype("float64") + interp = np.interp(x, xp, yp) + # Polars preserves nulls on boundaries, but NumPy doesn't. + first_non_null = dataframe["value"].is_not_null().arg_max() + last_non_null = len(dataframe) - dataframe["value"][::-1].is_not_null().arg_max() # type: ignore[operator] + interp[:first_non_null] = float("nan") + interp[last_non_null:] = float("nan") + expected = dataframe.with_columns(value=pl.Series(interp, nan_to_null=True))[ + "value" + ] + + # We increase the absolute error threshold, numpy has some instability, see #22348. + assert_series_equal(result, expected, atol=1e-4) + result_from_unsorted = ( + dataframe.sort("ts", descending=True) + .with_columns(pl.col("value").interpolate_by("ts")) + .sort("ts")["value"] + ) + assert_series_equal(result_from_unsorted, expected) + + +def test_interpolate_by_invalid() -> None: + s = pl.Series([1, None, 3]) + by = pl.Series([1, 2]) + with pytest.raises(InvalidOperationError, match=r"\(3\), got 2"): + s.interpolate_by(by) + + by = pl.Series([1, None, 3]) + with pytest.raises( + InvalidOperationError, + match="null values in `by` column are not yet supported in 'interpolate_by'", + ): + s.interpolate_by(by) diff --git a/py-polars/tests/unit/operations/test_is_first_last_distinct.py b/py-polars/tests/unit/operations/test_is_first_last_distinct.py new file mode 100644 index 000000000000..00a6e0a5f259 --- /dev/null +++ b/py-polars/tests/unit/operations/test_is_first_last_distinct.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_is_first_distinct() -> None: + lf = pl.LazyFrame({"a": [4, 1, 4]}) + result = lf.select(pl.col("a").is_first_distinct()).collect()["a"] + expected = pl.Series("a", [True, True, False]) + assert_series_equal(result, expected) + + +def test_is_first_distinct_bool_bit_chunk_index_calc() -> None: + # The fast path activates on sizes >=64 and processes in chunks of 64-bits. + # It calculates the indexes using the bit counts, which needs to be from the + # correct side. + assert pl.arange(0, 64, eager=True).filter( + pl.Series([True] + 63 * [False]).is_first_distinct() + ).to_list() == [0, 1] + + assert pl.arange(0, 64, eager=True).filter( + pl.Series([False] + 63 * [True]).is_first_distinct() + ).to_list() == [0, 1] + + assert pl.arange(0, 64, eager=True).filter( + pl.Series(2 * [True] + 2 * [False] + 60 * [None]).is_first_distinct() + ).to_list() == [0, 2, 4] + + assert pl.arange(0, 64, eager=True).filter( + pl.Series(2 * [False] + 2 * [None] + 60 * [True]).is_first_distinct() + ).to_list() == [0, 2, 4] + + +def test_is_first_distinct_struct() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3, 2, None, 2, 1], "b": [0, 2, 3, 2, None, 2, 0]}) + result = lf.select(pl.struct("a", "b").is_first_distinct()) + expected = pl.LazyFrame({"a": [True, True, True, False, True, False, False]}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + [[1, 2], [3], [1, 2], [4, None], [4, None], [], []], + [[True, None], [True], [True, None], [False], [False], [], []], + [[b"1", b"2"], [b"3"], [b"1", b"2"], [b"4", None], [b"4", None], [], []], + [["a", "b"], ["&"], ["a", "b"], ["...", None], ["...", None], [], []], + [ + [datetime.date(2000, 10, 1), datetime.date(2001, 1, 30)], + [datetime.date(1949, 10, 1)], + [datetime.date(2000, 10, 1), datetime.date(2001, 1, 30)], + [datetime.date(1998, 7, 1), None], + [datetime.date(1998, 7, 1), None], + [], + [], + ], + ], +) +def test_is_first_last_distinct_list(data: list[list[Any] | None]) -> None: + lf = pl.LazyFrame({"a": data}) + result = lf.select( + first=pl.col("a").is_first_distinct(), last=pl.col("a").is_last_distinct() + ) + expected = pl.LazyFrame( + { + "first": [True, True, False, True, False, True, False], + "last": [False, True, True, False, True, False, True], + } + ) + assert_frame_equal(result, expected) + + +def test_is_first_last_distinct_list_inner_nested() -> None: + df = pl.DataFrame({"a": [[[1, 2]], [[1, 2]]]}) + err_msg = "only allowed if the inner type is not nested" + with pytest.raises(InvalidOperationError, match=err_msg): + df.select(pl.col("a").is_first_distinct()) + with pytest.raises(InvalidOperationError, match=err_msg): + df.select(pl.col("a").is_last_distinct()) + + +def test_is_first_distinct_various() -> None: + # numeric + s = pl.Series([1, 1, None, 2, None, 3, 3]) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + # str + s = pl.Series(["x", "x", None, "y", None, "z", "z"]) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + # boolean + s = pl.Series([True, True, None, False, None, False, False]) + expected = [True, False, True, True, False, False, False] + assert s.is_first_distinct().to_list() == expected + # struct + s = pl.Series( + [ + {"x": 1, "y": 2}, + {"x": 1, "y": 2}, + None, + {"x": 2, "y": 1}, + None, + {"x": 3, "y": 2}, + {"x": 3, "y": 2}, + ] + ) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + # list + s = pl.Series([[1, 2], [1, 2], None, [2, 3], None, [3, 4], [3, 4]]) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + + +def test_is_last_distinct() -> None: + # numeric + s = pl.Series([1, 1, None, 2, None, 3, 3]) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + # str + s = pl.Series(["x", "x", None, "y", None, "z", "z"]) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + # boolean + s = pl.Series([True, True, None, False, None, False, False]) + expected = [False, True, False, False, True, False, True] + assert s.is_last_distinct().to_list() == expected + # struct + s = pl.Series( + [ + {"x": 1, "y": 2}, + {"x": 1, "y": 2}, + None, + {"x": 2, "y": 1}, + None, + {"x": 3, "y": 2}, + {"x": 3, "y": 2}, + ] + ) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + + +@pytest.mark.parametrize("dtypes", [pl.Int32, pl.String, pl.Boolean, pl.List(pl.Int32)]) +def test_is_first_last_distinct_all_null(dtypes: PolarsDataType) -> None: + s = pl.Series([None, None, None], dtype=dtypes) + assert s.is_first_distinct().to_list() == [True, False, False] + assert s.is_last_distinct().to_list() == [False, False, True] diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py new file mode 100644 index 000000000000..6c180a1ae93a --- /dev/null +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -0,0 +1,756 @@ +from __future__ import annotations + +from collections.abc import Collection +from datetime import date +from decimal import Decimal as D +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars import StringCache +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from collections.abc import Iterator + + from polars._typing import PolarsDataType + + +def test_struct_logical_is_in() -> None: + df1 = pl.DataFrame( + { + "x": pl.date_range(date(2022, 1, 1), date(2022, 1, 7), eager=True), + "y": [0, 4, 6, 2, 3, 4, 5], + } + ) + df2 = pl.DataFrame( + { + "x": pl.date_range(date(2022, 1, 3), date(2022, 1, 9), eager=True), + "y": [6, 2, 3, 4, 5, 0, 1], + } + ) + + s1 = df1.select(pl.struct(["x", "y"])).to_series() + s2 = df2.select(pl.struct(["x", "y"])).to_series() + assert s1.is_in(s2).to_list() == [False, False, True, True, True, True, True] + + +def test_struct_logical_is_in_nonullpropagate() -> None: + s = pl.Series([date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), None]) + df1 = pl.DataFrame( + { + "x": s, + "y": [0, 4, 6, None], + } + ) + s = pl.Series([date(2022, 2, 1), date(2022, 1, 2), date(2022, 2, 3), None]) + df2 = pl.DataFrame( + { + "x": s, + "y": [6, 4, 3, None], + } + ) + + # Left has no nulls, right has nulls + s1 = df1.select(pl.struct(["x", "y"])).to_series() + s1 = s1.extend_constant(s1[0], 1) + s2 = df2.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + assert s1.is_in(s2, nulls_equal=False).to_list() == [ + False, + True, + False, + True, + False, + ] + assert s1.is_in(s2, nulls_equal=True).to_list() == [ + False, + True, + False, + True, + False, + ] + + # Left has nulls, right has no nulls + s1 = df1.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + s2 = df2.select(pl.struct(["x", "y"])).to_series() + s2 = s2.extend_constant(s2[0], 1) + assert s1.is_in(s2, nulls_equal=False).to_list() == [ + False, + True, + False, + True, + None, + ] + assert s1.is_in(s2, nulls_equal=True).to_list() == [ + False, + True, + False, + True, + False, + ] + + # Both have nulls + # {None, None} is a valid element unaffected by the missing parameter. + s1 = df1.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + s2 = df2.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1) + assert s1.is_in(s2, nulls_equal=False).to_list() == [ + False, + True, + False, + True, + None, + ] + assert s1.is_in(s2, nulls_equal=True).to_list() == [ + False, + True, + False, + True, + True, + ] + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_bool(nulls_equal: bool) -> None: + vals = [True, None] + df = pl.DataFrame({"A": [True, False, None]}) + missing_value = True if nulls_equal else None + assert df.select(pl.col("A").is_in(vals, nulls_equal=nulls_equal)).to_dict( + as_series=False + ) == {"A": [True, False, missing_value]} + + +def test_is_in_bool_11216() -> None: + s = pl.Series([False]).is_in([False, None]) + expected = pl.Series([True]) + assert_series_equal(s, expected) + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_empty_list_4559(nulls_equal: bool) -> None: + assert pl.Series(["a"]).is_in([], nulls_equal=nulls_equal).to_list() == [False] + + +def test_is_in_empty_list_4639() -> None: + df = pl.DataFrame({"a": [1, None]}) + empty_list: list[int] = [] + + result = df.with_columns([pl.col("a").is_in(empty_list).alias("a_in_list")]) + expected = pl.DataFrame({"a": [1, None], "a_in_list": [False, None]}) + assert_frame_equal(result, expected) + + +def test_is_in_struct() -> None: + df = pl.DataFrame( + { + "struct_elem": [{"a": 1, "b": 11}, {"a": 1, "b": 90}], + "struct_list": [ + [{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}], + [{"a": 3, "b": 3}], + ], + } + ) + + assert df.filter(pl.col("struct_elem").is_in("struct_list")).to_dict( + as_series=False + ) == { + "struct_elem": [{"a": 1, "b": 11}], + "struct_list": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}]], + } + + +def test_is_in_null_prop() -> None: + assert pl.Series([None], dtype=pl.Float32).is_in(pl.Series([42])).item() is None + assert pl.Series([{"a": None}, None], dtype=pl.Struct({"a": pl.Float32})).is_in( + pl.Series([{"a": 42}], dtype=pl.Struct({"a": pl.Float32})) + ).to_list() == [False, None] + + assert pl.Series([{"a": None}, None], dtype=pl.Struct({"a": pl.Boolean})).is_in( + pl.Series([{"a": 42}], dtype=pl.Struct({"a": pl.Boolean})) + ).to_list() == [False, None] + + +def test_is_in_9070() -> None: + assert not pl.Series([1]).is_in(pl.Series([1.99])).item() + + +def test_is_in_float_list_10764() -> None: + df = pl.DataFrame( + { + "lst": [[1.0, 2.0, 3.0, 4.0, 5.0], [3.14, 5.28]], + "n": [3.0, 2.0], + } + ) + assert df.select(pl.col("n").is_in("lst").alias("is_in")).to_dict( + as_series=False + ) == {"is_in": [True, False]} + + +def test_is_in_df() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + assert df.select(pl.col("a").is_in([1, 2]))["a"].to_list() == [True, True, False] + + +def test_is_in_series() -> None: + s = pl.Series(["a", "b", "c"]) + + out = s.is_in(["a", "b"]) + assert out.to_list() == [True, True, False] + + # Check if empty list is converted to pl.String + out = s.is_in([]) + assert out.to_list() == [False] * out.len() + + for x_y_z in (["x", "y", "z"], {"x", "y", "z"}): + out = s.is_in(x_y_z) + assert out.to_list() == [False, False, False] + + df = pl.DataFrame({"a": [1.0, 2.0], "b": [1, 4], "c": ["e", "d"]}) + assert df.select(pl.col("a").is_in(pl.col("b"))).to_series().to_list() == [ + True, + False, + ] + assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] * df.height + + with pytest.raises( + InvalidOperationError, + match=r"'is_in' cannot check for List\(String\) values in Int64 data", + ): + df.select(pl.col("b").is_in(["x", "x"])) + + # check we don't shallow-copy and accidentally modify 'a' (see: #10072) + a = pl.Series("a", [1, 2]) + b = pl.Series("b", [1, 3]).is_in(a) + + assert a.name == "a" + assert_series_equal(b, pl.Series("b", [True, False])) + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_null(nulls_equal: bool) -> None: + # No nulls in right + s = pl.Series([None, None], dtype=pl.Null) + result = s.is_in([1, 2], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([missing_value, missing_value], dtype=pl.Boolean) + assert_series_equal(result, expected) + + # Nulls in right + s = pl.Series([None, None], dtype=pl.Null) + result = s.is_in([None, None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([missing_value, missing_value], dtype=pl.Boolean) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_boolean(nulls_equal: bool) -> None: + # Nulls in neither left nor right + s = pl.Series([True, False]) + result = s.is_in([True, False], nulls_equal=nulls_equal) + expected = pl.Series([True, True]) + assert_series_equal(result, expected) + + # Nulls in left only + s = pl.Series([True, None]) + result = s.is_in([False, False], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([False, missing_value]) + assert_series_equal(result, expected) + + # Nulls in right only + s = pl.Series([True, False]) + result = s.is_in([True, None], nulls_equal=nulls_equal) + expected = pl.Series([True, False]) + assert_series_equal(result, expected) + + # Nulls in both + s = pl.Series([True, False, None]) + result = s.is_in([True, None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Boolean), pl.Array(pl.Boolean, 2)]) +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_boolean_list(dtype: PolarsDataType, nulls_equal: bool) -> None: + # Note list is_in does not propagate nulls. + df = pl.DataFrame( + { + "a": [True, False, None, None, None], + "b": pl.Series( + [ + [True, False], + [True, True], + [None, True], + [False, True], + [True, True], + ], + dtype=dtype, + ), + } + ) + missing_true = True if nulls_equal else None + missing_false = False if nulls_equal else None + result = df.select(pl.col("a").is_in("b", nulls_equal=nulls_equal))["a"] + expected = pl.Series("a", [True, False, missing_true, missing_false, missing_false]) + assert_series_equal(result, expected) + + +def test_is_in_invalid_shape() -> None: + with pytest.raises(InvalidOperationError): + pl.Series("a", [1, 2, 3]).is_in([[], []]) + + +@pytest.mark.may_fail_auto_streaming +def test_is_in_list_rhs() -> None: + assert_series_equal( + pl.Series([1, 2, 3, 4, 5]).is_in(pl.Series([[1], [2, 9], [None], None, None])), + pl.Series([True, True, False, None, None]), + ) + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) +def test_is_in_float(dtype: PolarsDataType) -> None: + s = pl.Series([float("nan"), 0.0], dtype=dtype) + result = s.is_in([-0.0, -float("nan")]) + expected = pl.Series([True, True], dtype=pl.Boolean) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("df", "matches", "expected_error"), + [ + ( + pl.DataFrame({"a": [1, 2], "b": [[1.0, 2.5], [3.0, 4.0]]}), + [True, False], + None, + ), + ( + pl.DataFrame({"a": [2.5, 3.0], "b": [[1, 2], [3, 4]]}), + [False, True], + None, + ), + ( + pl.DataFrame( + {"a": [None, None], "b": [[1, 2], [3, 4]]}, + schema_overrides={"a": pl.Null}, + ), + [None, None], + None, + ), + ( + pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}), + None, + r"'is_in' cannot check for List\(Int64\) values in String data", + ), + ( + pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}), + None, + r"'is_in' cannot check for List\(Int64\) values in Date data", + ), + ], +) +def test_is_in_expr_list_series( + df: pl.DataFrame, matches: list[bool] | None, expected_error: str | None +) -> None: + expr_is_in = pl.col("a").is_in(pl.col("b")) + if matches: + assert df.select(expr_is_in).to_series().to_list() == matches + else: + with pytest.raises(InvalidOperationError, match=expected_error): + df.select(expr_is_in) + + +@pytest.mark.parametrize( + ("df", "matches"), + [ + ( + pl.DataFrame({"a": [1, None], "b": [[1.0, 2.5, 4.0], [3.0, 4.0, 5.0]]}), + [True, False], + ), + ( + pl.DataFrame({"a": [1, None], "b": [[0.0, 2.5, None], [3.0, 4.0, None]]}), + [False, True], + ), + ( + pl.DataFrame( + {"a": [None, None], "b": [[1, 2], [3, 4]]}, + schema_overrides={"a": pl.Null}, + ), + [False, False], + ), + ( + pl.DataFrame( + {"a": [None, None], "b": [[1, 2], [3, None]]}, + schema_overrides={"a": pl.Null}, + ), + [False, True], + ), + ], +) +def test_is_in_expr_list_series_nonullpropagate( + df: pl.DataFrame, matches: list[bool] +) -> None: + expr_is_in = pl.col("a").is_in(pl.col("b"), nulls_equal=True) + assert df.select(expr_is_in).to_series().to_list() == matches + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_null_series(nulls_equal: bool) -> None: + df = pl.DataFrame({"a": ["a", "b", None]}) + result = df.select(pl.col("a").is_in([None], nulls_equal=nulls_equal)) + missing_value = True if nulls_equal else None + expected = pl.DataFrame({"a": [False, False, missing_value]}) + assert_frame_equal(result, expected) + + +def test_is_in_int_range() -> None: + r = pl.int_range(0, 3, eager=False) + out = pl.select(r.is_in([1, 2])).to_series() + assert out.to_list() == [False, True, True] + + r = pl.int_range(0, 3, eager=True) # type: ignore[assignment] + out = r.is_in([1, 2]) # type: ignore[assignment] + assert out.to_list() == [False, True, True] + + +def test_is_in_date_range() -> None: + r = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=False) + out = pl.select(r.is_in([date(2023, 1, 2), date(2023, 1, 3)])).to_series() + assert out.to_list() == [False, True, True] + + r = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=True) # type: ignore[assignment] + out = r.is_in([date(2023, 1, 2), date(2023, 1, 3)]) # type: ignore[assignment] + assert out.to_list() == [False, True, True] + + +@StringCache() +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])]) +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_cat_is_in_series(dtype: pl.DataType, nulls_equal: bool) -> None: + s = pl.Series(["a", "b", "c", None], dtype=dtype) + s2 = pl.Series(["b", "c"], dtype=dtype) + missing_value = False if nulls_equal else None + expected = pl.Series([False, True, True, missing_value]) + assert_series_equal(s.is_in(s2, nulls_equal=nulls_equal), expected) + + s2_str = s2.cast(pl.String) + assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected) + + +@StringCache() +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_cat_is_in_series_non_existent(nulls_equal: bool) -> None: + dtype = pl.Categorical + s = pl.Series(["a", "b", "c", None], dtype=dtype) + s2 = pl.Series(["a", "d", "e"], dtype=dtype) + missing_value = False if nulls_equal else None + expected = pl.Series([True, False, False, missing_value]) + assert_series_equal(s.is_in(s2, nulls_equal=nulls_equal), expected) + + s2_str = s2.cast(pl.String) + assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected) + + +@StringCache() +@pytest.mark.parametrize( + "nulls_equal", + [ + False, + pytest.param( + True, + marks=pytest.mark.xfail( + reason="Bug. See https://github.com/pola-rs/polars/issues/22260" + ), + ), + ], +) +def test_enum_is_in_series_non_existent(nulls_equal: bool) -> None: + dtype = pl.Enum(["a", "b", "c"]) + missing_value = False if nulls_equal else None + s = pl.Series(["a", "b", "c", None], dtype=dtype) + s2_str = pl.Series(["a", "d", "e"]) + expected = pl.Series([True, False, False, missing_value]) + assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected) + + +@StringCache() +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])]) +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_cat_is_in_with_lit_str(dtype: pl.DataType, nulls_equal: bool) -> None: + missing_value = False if nulls_equal else None + s = pl.Series(["a", "b", "c", None], dtype=dtype) + lit = ["b"] + expected = pl.Series([False, True, False, missing_value]) + + assert_series_equal(s.is_in(lit, nulls_equal=nulls_equal), expected) + + +@StringCache() +@pytest.mark.parametrize( + ("nulls_equal", "dtype"), + [ + (False, pl.Categorical), + (False, pl.Enum(["a", "b", "c"])), + (True, pl.Categorical), + pytest.param( + True, + pl.Enum(["a", "b", "c"]), + marks=pytest.mark.xfail( + reason="Bug. See https://github.com/pola-rs/polars/issues/22260" + ), + ), + ], +) +def test_cat_is_in_with_lit_str_non_existent( + dtype: pl.DataType, nulls_equal: bool +) -> None: + missing_value = False if nulls_equal else None + s = pl.Series(["a", "b", "c", None], dtype=dtype) + lit = ["d"] + expected = pl.Series([False, False, False, missing_value]) + + assert_series_equal(s.is_in(lit, nulls_equal=nulls_equal), expected) + + +@StringCache() +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])]) +def test_cat_is_in_with_lit_str_cache_setup(dtype: pl.DataType) -> None: + # init the global cache + _ = pl.Series(["c", "b", "a"], dtype=dtype) + + assert_series_equal(pl.Series(["a"], dtype=dtype).is_in(["a"]), pl.Series([True])) + assert_series_equal(pl.Series(["b"], dtype=dtype).is_in(["b"]), pl.Series([True])) + assert_series_equal(pl.Series(["c"], dtype=dtype).is_in(["c"]), pl.Series([True])) + + +def test_is_in_with_wildcard_13809() -> None: + out = pl.DataFrame({"A": ["B"]}).select(pl.all().is_in(["C"])) + expected = pl.DataFrame({"A": [False]}) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(pl.Categorical, marks=pytest.mark.may_fail_auto_streaming), + pl.Enum(["a", "b", "c", "d"]), + ], +) +def test_cat_is_in_from_str(dtype: pl.DataType) -> None: + s = pl.Series(["c", "c", "b"], dtype=dtype) + + # test local + assert_series_equal( + pl.Series(["a", "d", "b"]).is_in(s), + pl.Series([False, False, True]), + ) + + +@pl.StringCache() +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])]) +@pytest.mark.may_fail_auto_streaming +def test_cat_list_is_in_from_cat(dtype: pl.DataType) -> None: + df = pl.DataFrame( + [ + (["a", "b"], "c"), + (["a", "b"], "a"), + (["a", None], None), + (["a", "c"], None), + (["a"], "d"), + ], + schema={"li": pl.List(dtype), "x": dtype}, + orient="row", + ) + res = df.select(pl.col("li").list.contains(pl.col("x"))) + expected_df = pl.DataFrame({"li": [False, True, True, False, False]}) + assert_frame_equal(res, expected_df) + + +@pl.StringCache() +@pytest.mark.parametrize( + ("val", "expected"), + [ + ("b", [True, False, False, None, True]), + (None, [False, False, True, None, False]), + ("e", [False, False, False, None, False]), + ], +) +@pytest.mark.may_fail_auto_streaming +def test_cat_list_is_in_from_cat_single(val: str | None, expected: list[bool]) -> None: + df = pl.Series( + "li", + [["a", "b"], ["a", "c"], ["a", None], None, ["b"]], + dtype=pl.List(pl.Categorical), + ).to_frame() + res = df.select(pl.col("li").list.contains(pl.lit(val, dtype=pl.Categorical))) + expected_df = pl.DataFrame({"li": expected}) + assert_frame_equal(res, expected_df) + + +@pl.StringCache() +def test_cat_list_is_in_from_str() -> None: + df = pl.DataFrame( + [ + (["a", "b"], "c"), + (["a", "b"], "a"), + (["a", None], None), + (["a", "c"], None), + (["a"], "d"), + ], + schema={"li": pl.List(pl.Categorical), "x": pl.String}, + orient="row", + ) + res = df.select(pl.col("li").list.contains(pl.col("x"))) + expected_df = pl.DataFrame({"li": [False, True, True, False, False]}) + assert_frame_equal(res, expected_df) + + +@pl.StringCache() +@pytest.mark.parametrize( + ("val", "expected"), + [ + ("b", [True, False, False, None, True]), + (None, [False, False, True, None, False]), + ("e", [False, False, False, None, False]), + ], +) +def test_cat_list_is_in_from_single_str(val: str | None, expected: list[bool]) -> None: + df = pl.Series( + "li", + [["a", "b"], ["a", "c"], ["a", None], None, ["b"]], + dtype=pl.List(pl.Categorical), + ).to_frame() + res = df.select(pl.col("li").list.contains(pl.lit(val, dtype=pl.String))) + expected_df = pl.DataFrame({"li": expected}) + assert_frame_equal(res, expected_df) + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_struct_enum_17618(nulls_equal: bool) -> None: + df = pl.DataFrame() + dtype = pl.Enum(categories=["HBS"]) + df = df.insert_column(0, pl.Series("category", [], dtype=dtype)) + assert df.filter( + pl.struct("category").is_in( + pl.Series( + [{"category": "HBS"}], + dtype=pl.Struct({"category": df["category"].dtype}), + ), + nulls_equal=nulls_equal, + ) + ).shape == (0, 1) + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_is_in_decimal(nulls_equal: bool) -> None: + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select( + pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, True] + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select( + pl.col("a").is_in([D("0.0"), D("0.1")], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, True] + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select( + pl.col("a").is_in([1, 0, 2], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, False] + missing_value = True if nulls_equal else None + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), None]}).select( + pl.col("a").is_in([0.0, 0.1, None], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, missing_value] + missing_value = False if nulls_equal else None + assert pl.DataFrame({"a": [D("0.0"), D("0.2"), None]}).select( + pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal) + )["a"].to_list() == [True, False, missing_value] + + +def test_is_in_collection() -> None: + df = pl.DataFrame( + { + "lbl": ["aa", "bb", "cc", "dd", "ee"], + "val": [0, 1, 2, 3, 4], + } + ) + + class CustomCollection(Collection[int]): + def __init__(self, vals: Collection[int]) -> None: + super().__init__() + self.vals = vals + + def __contains__(self, x: object) -> bool: + return x in self.vals + + def __iter__(self) -> Iterator[int]: + yield from self.vals + + def __len__(self) -> int: + return len(self.vals) + + for constraint_values in ( + {3, 2, 1}, + frozenset({3, 2, 1}), + CustomCollection([3, 2, 1]), + ): + res = df.filter(pl.col("val").is_in(constraint_values)) + assert set(res["lbl"]) == {"bb", "cc", "dd"} + + +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_null_propagate_all_paths(nulls_equal: bool) -> None: + # No nulls in either + s = pl.Series([1, 2, 3]) + result = s.is_in([1, 3, 8], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in left only + s = pl.Series([1, 2, None]) + result = s.is_in([1, 3, 8], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + # Nulls in right only + s = pl.Series([1, 2, 3]) + result = s.is_in([1, 3, None], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in both + s = pl.Series([1, 2, None]) + result = s.is_in([1, 3, None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + +@pytest.mark.usefixtures("test_global_and_local") +@pytest.mark.parametrize("nulls_equal", [False, True]) +def test_null_propagate_all_paths_cat(nulls_equal: bool) -> None: + # No nulls in either + s = pl.Series(["1", "2", "3"]) + result = s.is_in(["1", "3", "8"], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in left only + s = pl.Series(["1", "2", None]) + result = s.is_in(["1", "3", "8"], nulls_equal=nulls_equal) + missing_value = False if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) + + # Nulls in right only + s = pl.Series(["1", "2", "3"]) + result = s.is_in(["1", "3", None], nulls_equal=nulls_equal) + expected = pl.Series([True, False, True]) + assert_series_equal(result, expected) + + # Nulls in both + s = pl.Series(["1", "2", None]) + result = s.is_in(["1", "3", None], nulls_equal=nulls_equal) + missing_value = True if nulls_equal else None + expected = pl.Series([True, False, missing_value]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_is_null.py b/py-polars/tests/unit/operations/test_is_null.py new file mode 100644 index 000000000000..159977b2d0de --- /dev/null +++ b/py-polars/tests/unit/operations/test_is_null.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series + + +@given(s=series(allow_null=True)) +def test_is_null_parametric(s: pl.Series) -> None: + is_null = s.is_null() + is_not_null = s.is_not_null() + + assert is_null.null_count() == 0 + assert_series_equal(is_null, ~is_not_null) + + +def test_is_null_struct() -> None: + df = pl.DataFrame( + {"x": [{"a": 1, "b": 2}, {"a": None, "b": None}, {"a": None, "b": 2}, None]} + ) + + result = df.select( + null=pl.col("x").is_null(), + not_null=pl.col("x").is_not_null(), + ) + + expected = pl.DataFrame( + { + "null": [False, False, False, True], + "not_null": [True, True, True, False], + } + ) + assert_frame_equal(result, expected) + + +def test_is_null_null() -> None: + s = pl.Series([None, None]) + + result = s.is_null() + expected = pl.Series([True, True]) + assert_series_equal(result, expected) + + result = s.is_not_null() + expected = pl.Series([False, False]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_is_sorted.py b/py-polars/tests/unit/operations/test_is_sorted.py new file mode 100644 index 000000000000..f75bb5fe2c81 --- /dev/null +++ b/py-polars/tests/unit/operations/test_is_sorted.py @@ -0,0 +1,439 @@ +from datetime import date +from typing import Any + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def is_sorted_any(s: pl.Series) -> bool: + return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"] + + +def is_not_sorted(s: pl.Series) -> bool: + return not is_sorted_any(s) + + +def test_sorted_flag_14552() -> None: + a = pl.DataFrame({"a": [2, 1, 3]}) + + a = pl.concat([a, a], rechunk=False) + assert not a.join(a, on="a", how="left")["a"].flags["SORTED_ASC"] + + +def test_sorted_flag_concat_15072() -> None: + # Both all-null + a = pl.Series("x", [None, None], dtype=pl.Int8) + b = pl.Series("x", [None, None], dtype=pl.Int8) + assert pl.concat((a, b)).flags["SORTED_ASC"] + + # left all-null, right 0 < null_count < len + a = pl.Series("x", [None, None], dtype=pl.Int8) + b = pl.Series("x", [1, 2, 1, None], dtype=pl.Int8) + + out = pl.concat((a, b.sort())) + assert out.to_list() == [None, None, None, 1, 1, 2] + assert out.flags["SORTED_ASC"] + + out = pl.concat((a, b.sort(descending=True))) + assert out.to_list() == [None, None, None, 2, 1, 1] + assert out.flags["SORTED_DESC"] + + out = pl.concat((a, b.sort(nulls_last=True))) + assert out.to_list() == [None, None, 1, 1, 2, None] + assert is_not_sorted(out) + + out = pl.concat((a, b.sort(nulls_last=True, descending=True))) + assert out.to_list() == [None, None, 2, 1, 1, None] + assert is_not_sorted(out) + + # left 0 < null_count < len, right all-null + a = pl.Series("x", [1, 2, 1, None], dtype=pl.Int8) + b = pl.Series("x", [None, None], dtype=pl.Int8) + + out = pl.concat((a.sort(), b)) + assert out.to_list() == [None, 1, 1, 2, None, None] + assert is_not_sorted(out) + + out = pl.concat((a.sort(descending=True), b)) + assert out.to_list() == [None, 2, 1, 1, None, None] + assert is_not_sorted(out) + + out = pl.concat((a.sort(nulls_last=True), b)) + assert out.to_list() == [1, 1, 2, None, None, None] + assert out.flags["SORTED_ASC"] + + out = pl.concat((a.sort(nulls_last=True, descending=True), b)) + assert out.to_list() == [2, 1, 1, None, None, None] + assert out.flags["SORTED_DESC"] + + # both 0 < null_count < len + assert pl.concat( + ( + pl.Series([None, 1]).set_sorted(), + pl.Series([2]).set_sorted(), + ) + ).flags["SORTED_ASC"] + + assert is_not_sorted( + pl.concat( + ( + pl.Series([None, 1]).set_sorted(), + pl.Series([2, None]).set_sorted(), + ) + ) + ) + + assert pl.concat( + ( + pl.Series([None, 2]).set_sorted(descending=True), + pl.Series([1]).set_sorted(descending=True), + ) + ).flags["SORTED_DESC"] + + assert is_not_sorted( + pl.concat( + ( + pl.Series([None, 2]).set_sorted(descending=True), + pl.Series([1, None]).set_sorted(descending=True), + ) + ) + ) + + # Concat with empty series + s = pl.Series([None, 1]).set_sorted() + + out = pl.concat((s.clear(), s)) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + out = pl.concat((s, s.clear())) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + s = pl.Series([1, None]).set_sorted() + + out = pl.concat((s.clear(), s)) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + out = pl.concat((s, s.clear())) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + +@pytest.mark.parametrize("unit_descending", [True, False]) +def test_sorted_flag_concat_unit(unit_descending: bool) -> None: + unit = pl.Series([1]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [1, 2, 3] + assert out.flags["SORTED_ASC"] + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, 1] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [1, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, 1] + assert out.flags["SORTED_DESC"] + + # unit with nulls first + unit = pl.Series([None, 1]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [None, 1, 2, 3] + assert out.flags["SORTED_ASC"] + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, None, 1] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [None, 1, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, None, 1] + assert is_not_sorted(out) + + # unit with nulls last + unit = pl.Series([1, None]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [1, None, 2, 3] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, 1, None] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [1, None, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, 1, None] + assert out.flags["SORTED_DESC"] + + +def test_sorted_flag_null() -> None: + assert pl.DataFrame({"x": [None] * 2})["x"].flags["SORTED_ASC"] is False + + +def test_sorted_update_flags_10327() -> None: + assert pl.concat( + [ + pl.Series("a", [1], dtype=pl.Int64).to_frame(), + pl.Series("a", [], dtype=pl.Int64).to_frame(), + pl.Series("a", [2], dtype=pl.Int64).to_frame(), + pl.Series("a", [], dtype=pl.Int64).to_frame(), + ] + )["a"].to_list() == [1, 2] + + +def test_sorted_flag_unset_by_arithmetic_4937() -> None: + df = pl.DataFrame( + { + "ts": [1, 1, 1, 0, 1], + "price": [3.3, 3.0, 3.5, 3.6, 3.7], + "mask": [1, 1, 1, 1, 0], + } + ) + + assert df.sort("price").group_by("ts").agg( + [ + (pl.col("price") * pl.col("mask")).max().alias("pmax"), + (pl.col("price") * pl.col("mask")).min().alias("pmin"), + ] + ).sort("ts").to_dict(as_series=False) == { + "ts": [0, 1], + "pmax": [3.6, 3.5], + "pmin": [3.6, 0.0], + } + + +def test_unset_sorted_flag_after_extend() -> None: + df1 = pl.DataFrame({"Add": [37, 41], "Batch": [48, 49]}).sort("Add") + df2 = pl.DataFrame({"Add": [37], "Batch": [67]}).sort("Add") + + df1.extend(df2) + assert not df1["Add"].flags["SORTED_ASC"] + df = df1.group_by("Add").agg([pl.col("Batch").min()]).sort("Add") + assert df["Add"].flags["SORTED_ASC"] + assert df.to_dict(as_series=False) == {"Add": [37, 41], "Batch": [48, 49]} + + +def test_sorted_flag_after_streaming_join() -> None: + # streaming left join + df1 = pl.DataFrame({"x": [1, 2, 3, 4], "y": [2, 4, 6, 6]}).set_sorted("x") + df2 = pl.DataFrame({"x": [4, 2, 3, 1], "z": [1, 4, 9, 1]}) + assert ( + df1.lazy() + .join(df2.lazy(), on="x", how="left", maintain_order="left") + .collect(engine="old-streaming")["x"] # type: ignore[call-overload] + .flags["SORTED_ASC"] + ) + + +def test_sorted_flag_partition_by() -> None: + assert ( + pl.DataFrame({"one": [1, 2, 3], "two": ["a", "a", "b"]}) + .set_sorted("one") + .partition_by("two", maintain_order=True)[0]["one"] + .flags["SORTED_ASC"] + ) + + +@pytest.mark.parametrize("value", [1, "a", True]) +def test_sorted_flag_singletons(value: Any) -> None: + assert pl.DataFrame({"x": [value]})["x"].flags["SORTED_ASC"] is True + + +def test_is_sorted() -> None: + assert not pl.Series([1, 2, 5, None, 2, None]).is_sorted() + assert pl.Series([1, 2, 4, None, None]).is_sorted(nulls_last=True) + assert pl.Series([None, None, 1, 2, 4]).is_sorted(nulls_last=False) + assert not pl.Series([None, 1, None, 2, 4]).is_sorted() + assert not pl.Series([None, 1, 2, 3, -1, 4]).is_sorted(nulls_last=False) + assert not pl.Series([1, 2, 3, -1, 4, None, None]).is_sorted(nulls_last=True) + assert not pl.Series([1, 2, 3, -1, 4]).is_sorted() + assert pl.Series([1, 2, 3, 4]).is_sorted() + assert pl.Series([5, 2, 1, 1, -1]).is_sorted(descending=True) + assert pl.Series([None, None, 5, 2, 1, 1, -1]).is_sorted( + descending=True, nulls_last=False + ) + assert pl.Series([5, 2, 1, 1, -1, None, None]).is_sorted( + descending=True, nulls_last=True + ) + assert not pl.Series([5, None, 2, 1, 1, -1, None, None]).is_sorted( + descending=True, nulls_last=True + ) + assert not pl.Series([5, 2, 1, 10, 1, -1, None, None]).is_sorted( + descending=True, nulls_last=True + ) + + +def test_sorted_flag() -> None: + s = pl.arange(0, 7, eager=True) + assert s.flags["SORTED_ASC"] + assert s.reverse().flags["SORTED_DESC"] + assert pl.Series([b"a"]).set_sorted().flags["SORTED_ASC"] + assert ( + pl.Series([date(2020, 1, 1), date(2020, 1, 2)]) + .set_sorted() + .cast(pl.Datetime) + .flags["SORTED_ASC"] + ) + + # empty + q = pl.LazyFrame( + schema={ + "store_id": pl.UInt16, + "item_id": pl.UInt32, + "timestamp": pl.Datetime, + } + ).sort("timestamp") + + assert q.collect()["timestamp"].flags["SORTED_ASC"] + + # ensure we don't panic for these types + # struct + pl.Series([{"a": 1}]).set_sorted(descending=True) + # list + pl.Series([[{"a": 1}]]).set_sorted(descending=True) + # object + pl.Series([{"a": 1}], dtype=pl.Object).set_sorted(descending=True) + + +@pytest.mark.may_fail_auto_streaming +def test_sorted_flag_after_joins() -> None: + np.random.seed(1) + dfa = pl.DataFrame( + { + "a": np.random.randint(0, 13, 20), + "b": np.random.randint(0, 13, 20), + } + ).sort("a") + + dfb = pl.DataFrame( + { + "a": np.random.randint(0, 13, 10), + "b": np.random.randint(0, 13, 10), + } + ) + + dfapd = dfa.to_pandas() + dfbpd = dfb.to_pandas() + + def test_with_pd( + dfa: pd.DataFrame, dfb: pd.DataFrame, on: str, how: str, joined: pl.DataFrame + ) -> None: + a = ( + dfa.merge( + dfb, + on=on, + how=how, # type: ignore[arg-type] + suffixes=("", "_right"), + ) + .sort_values(["a", "b"]) + .reset_index(drop=True) + ) + b = joined.sort(["a", "b"]).to_pandas() + pd.testing.assert_frame_equal(a, b) + + joined = dfa.join(dfb, on="b", how="left", coalesce=True) + assert joined["a"].flags["SORTED_ASC"] + test_with_pd(dfapd, dfbpd, "b", "left", joined) + + joined = dfa.join(dfb, on="b", how="inner") + assert joined["a"].flags["SORTED_ASC"] + test_with_pd(dfapd, dfbpd, "b", "inner", joined) + + joined = dfa.join(dfb, on="b", how="semi") + assert joined["a"].flags["SORTED_ASC"] + joined = dfa.join(dfb, on="b", how="semi") + assert joined["a"].flags["SORTED_ASC"] + + joined = dfb.join(dfa, on="b", how="left", coalesce=True) + assert not joined["a"].flags["SORTED_ASC"] + test_with_pd(dfbpd, dfapd, "b", "left", joined) + + joined = dfb.join(dfa, on="b", how="inner") + if (joined["a"] != sorted(joined["a"])).any(): + assert not joined["a"].flags["SORTED_ASC"] + + joined = dfb.join(dfa, on="b", how="semi") + if (joined["a"] != sorted(joined["a"])).any(): + assert not joined["a"].flags["SORTED_ASC"] + + joined = dfb.join(dfa, on="b", how="anti") + if (joined["a"] != sorted(joined["a"])).any(): + assert not joined["a"].flags["SORTED_ASC"] + + +def test_sorted_flag_group_by_dynamic() -> None: + df = pl.DataFrame({"ts": [date(2020, 1, 1), date(2020, 1, 2)], "val": [1, 2]}) + assert ( + ( + df.group_by_dynamic(pl.col("ts").set_sorted(), every="1d").agg( + pl.col("val").sum() + ) + ) + .to_series() + .flags["SORTED_ASC"] + ) + + +def test_is_sorted_rle_id() -> None: + assert pl.Series([12, 3345, 12, 3, 4, 4, 1, 12]).rle_id().flags["SORTED_ASC"] + + +def test_is_sorted_chunked_select() -> None: + df = pl.DataFrame({"a": np.ones(14)}) + + assert ( + pl.concat([df, df, df], rechunk=False) + .set_sorted("a") + .select(pl.col("a").alias("b")) + )["b"].flags["SORTED_ASC"] + + +def test_is_sorted_arithmetic_overflow_14106() -> None: + s = pl.Series([0, 200], dtype=pl.UInt8).sort() + assert not (s + 200).is_sorted() + + +def test_is_sorted_struct() -> None: + s = pl.Series("a", [{"x": 3}, {"x": 1}, {"x": 2}]).sort() + assert s.flags["SORTED_ASC"] + assert not s.flags["SORTED_DESC"] + + s = s.sort(descending=True) + assert s.flags["SORTED_DESC"] + assert not s.flags["SORTED_ASC"] diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py new file mode 100644 index 000000000000..2c21a9444395 --- /dev/null +++ b/py-polars/tests/unit/operations/test_join.py @@ -0,0 +1,2134 @@ +from __future__ import annotations + +import typing +import warnings +from datetime import date, datetime +from time import perf_counter +from typing import TYPE_CHECKING, Any, Callable, Literal + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.exceptions import ( + ColumnNotFoundError, + ComputeError, + DuplicateError, + InvalidOperationError, + SchemaError, +) +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import with_string_cache_if_auto_streaming + +if TYPE_CHECKING: + from polars._typing import JoinStrategy, PolarsDataType + + +def test_semi_anti_join() -> None: + df_a = pl.DataFrame({"key": [1, 2, 3], "payload": ["f", "i", None]}) + + df_b = pl.DataFrame({"key": [3, 4, 5, None]}) + + assert df_a.join(df_b, on="key", how="anti").to_dict(as_series=False) == { + "key": [1, 2], + "payload": ["f", "i"], + } + assert df_a.join(df_b, on="key", how="semi").to_dict(as_series=False) == { + "key": [3], + "payload": [None], + } + + # lazy + result = df_a.lazy().join(df_b.lazy(), on="key", how="anti").collect() + expected_values = {"key": [1, 2], "payload": ["f", "i"]} + assert result.to_dict(as_series=False) == expected_values + + result = df_a.lazy().join(df_b.lazy(), on="key", how="semi").collect() + expected_values = {"key": [3], "payload": [None]} + assert result.to_dict(as_series=False) == expected_values + + df_a = pl.DataFrame( + {"a": [1, 2, 3, 1], "b": ["a", "b", "c", "a"], "payload": [10, 20, 30, 40]} + ) + + df_b = pl.DataFrame({"a": [3, 3, 4, 5], "b": ["c", "c", "d", "e"]}) + + assert df_a.join(df_b, on=["a", "b"], how="anti").to_dict(as_series=False) == { + "a": [1, 2, 1], + "b": ["a", "b", "a"], + "payload": [10, 20, 40], + } + assert df_a.join(df_b, on=["a", "b"], how="semi").to_dict(as_series=False) == { + "a": [3], + "b": ["c"], + "payload": [30], + } + + +@pytest.mark.may_fail_auto_streaming +def test_join_same_cat_src() -> None: + df = pl.DataFrame( + data={"column": ["a", "a", "b"], "more": [1, 2, 3]}, + schema=[("column", pl.Categorical), ("more", pl.Int32)], + ) + df_agg = df.group_by("column").agg(pl.col("more").mean()) + assert_frame_equal( + df.join(df_agg, on="column"), + pl.DataFrame( + { + "column": ["a", "a", "b"], + "more": [1, 2, 3], + "more_right": [1.5, 1.5, 3.0], + }, + schema=[ + ("column", pl.Categorical), + ("more", pl.Int32), + ("more_right", pl.Float64), + ], + ), + check_row_order=False, + ) + + +@pytest.mark.parametrize("reverse", [False, True]) +def test_sorted_merge_joins(reverse: bool) -> None: + n = 30 + df_a = pl.DataFrame({"a": np.sort(np.random.randint(0, n // 2, n))}).with_row_index( + "row_a" + ) + df_b = pl.DataFrame( + {"a": np.sort(np.random.randint(0, n // 2, n // 2))} + ).with_row_index("row_b") + + if reverse: + df_a = df_a.select(pl.all().reverse()) + df_b = df_b.select(pl.all().reverse()) + + join_strategies: list[JoinStrategy] = ["left", "inner"] + for cast_to in [int, str, float]: + for how in join_strategies: + df_a_ = df_a.with_columns(pl.col("a").cast(cast_to)) + df_b_ = df_b.with_columns(pl.col("a").cast(cast_to)) + + # hash join + out_hash_join = df_a_.join(df_b_, on="a", how=how) + + # sorted merge join + out_sorted_merge_join = df_a_.with_columns( + pl.col("a").set_sorted(descending=reverse) + ).join( + df_b_.with_columns(pl.col("a").set_sorted(descending=reverse)), + on="a", + how=how, + ) + + assert_frame_equal( + out_hash_join, out_sorted_merge_join, check_row_order=False + ) + + +def test_join_negative_integers() -> None: + expected = pl.DataFrame({"a": [-6, -1, 0], "b": [-6, -1, 0]}) + df1 = pl.DataFrame( + { + "a": [-1, -6, -3, 0], + } + ) + + df2 = pl.DataFrame( + { + "a": [-6, -1, -4, -2, 0], + "b": [-6, -1, -4, -2, 0], + } + ) + + for dt in [pl.Int8, pl.Int16, pl.Int32, pl.Int64]: + assert_frame_equal( + df1.with_columns(pl.all().cast(dt)).join( + df2.with_columns(pl.all().cast(dt)), on="a", how="inner" + ), + expected.select(pl.all().cast(dt)), + check_row_order=False, + ) + + +def test_deprecated() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + other = pl.DataFrame({"a": [1, 2], "c": [3, 4]}) + result = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [3, 4]}) + + np.testing.assert_equal( + df.join(other=other, on="a", maintain_order="left").to_numpy(), + result.to_numpy(), + ) + np.testing.assert_equal( + df.lazy() + .join(other=other.lazy(), on="a", maintain_order="left") + .collect() + .to_numpy(), + result.to_numpy(), + ) + + +def test_deprecated_parameter_join_nulls() -> None: + df = pl.DataFrame({"a": [1, None]}) + with pytest.deprecated_call( + match=r"The argument `join_nulls` for `DataFrame.join` is deprecated. It has been renamed to `nulls_equal`" + ): + result = df.join(df, on="a", join_nulls=True) # type: ignore[call-arg] + assert_frame_equal(result, df, check_row_order=False) + + +def test_join_on_expressions() -> None: + df_a = pl.DataFrame({"a": [1, 2, 3]}) + + df_b = pl.DataFrame({"b": [1, 4, 9, 9, 0]}) + + assert_frame_equal( + df_a.join(df_b, left_on=(pl.col("a") ** 2).cast(int), right_on=pl.col("b")), + pl.DataFrame({"a": [1, 2, 3, 3], "b": [1, 4, 9, 9]}), + check_row_order=False, + ) + + +def test_join_lazy_frame_on_expression() -> None: + # Tests a lazy frame projection pushdown bug + # https://github.com/pola-rs/polars/issues/19822 + + df = pl.DataFrame(data={"a": [0, 1], "b": [2, 3]}) + + lazy_join = ( + df.lazy() + .join(df.lazy(), left_on=pl.coalesce("b", "a"), right_on="a") + .select("a") + .collect() + ) + + eager_join = df.join(df, left_on=pl.coalesce("b", "a"), right_on="a").select("a") + + assert lazy_join.shape == eager_join.shape + + +def test_join() -> None: + df_left = pl.DataFrame( + { + "a": ["a", "b", "a", "z"], + "b": [1, 2, 3, 4], + "c": [6, 5, 4, 3], + } + ) + df_right = pl.DataFrame( + { + "a": ["b", "c", "b", "a"], + "k": [0, 3, 9, 6], + "c": [1, 0, 2, 1], + } + ) + + joined = df_left.join( + df_right, left_on="a", right_on="a", maintain_order="left_right" + ).sort("a") + assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2])) + + joined = df_left.join( + df_right, left_on="a", right_on="a", how="left", maintain_order="left_right" + ).sort("a") + assert joined["c_right"].is_null().sum() == 1 + assert_series_equal(joined["b"], pl.Series("b", [1, 3, 2, 2, 4])) + + joined = df_left.join(df_right, left_on="a", right_on="a", how="full").sort("a") + assert joined["c_right"].null_count() == 1 + assert joined["c"].null_count() == 1 + assert joined["b"].null_count() == 1 + assert joined["k"].null_count() == 1 + assert joined["a"].null_count() == 1 + + # we need to pass in a column to join on, either by supplying `on`, or both + # `left_on` and `right_on` + with pytest.raises(ValueError): + df_left.join(df_right) + with pytest.raises(ValueError): + df_left.join(df_right, right_on="a") + with pytest.raises(ValueError): + df_left.join(df_right, left_on="a") + + df_a = pl.DataFrame({"a": [1, 2, 1, 1], "b": ["a", "b", "c", "c"]}) + df_b = pl.DataFrame( + {"foo": [1, 1, 1], "bar": ["a", "c", "c"], "ham": ["let", "var", "const"]} + ) + + # just check if join on multiple columns runs + df_a.join(df_b, left_on=["a", "b"], right_on=["foo", "bar"]) + eager_join = df_a.join(df_b, left_on="a", right_on="foo") + lazy_join = df_a.lazy().join(df_b.lazy(), left_on="a", right_on="foo").collect() + + cols = ["a", "b", "bar", "ham"] + assert lazy_join.shape == eager_join.shape + assert_frame_equal(lazy_join.sort(by=cols), eager_join.sort(by=cols)) + + +def test_joins_dispatch() -> None: + # this just flexes the dispatch a bit + + # don't change the data of this dataframe, this triggered: + # https://github.com/pola-rs/polars/issues/1688 + dfa = pl.DataFrame( + { + "a": ["a", "b", "c", "a"], + "b": [1, 2, 3, 1], + "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-01"], + "datetime": [13241324, 12341256, 12341234, 13241324], + } + ).with_columns( + pl.col("date").str.strptime(pl.Date), pl.col("datetime").cast(pl.Datetime) + ) + + join_strategies: list[JoinStrategy] = ["left", "inner", "full"] + for how in join_strategies: + dfa.join(dfa, on=["a", "b", "date", "datetime"], how=how) + dfa.join(dfa, on=["date", "datetime"], how=how) + dfa.join(dfa, on=["date", "datetime", "a"], how=how) + dfa.join(dfa, on=["date", "a"], how=how) + dfa.join(dfa, on=["a", "datetime"], how=how) + dfa.join(dfa, on=["date"], how=how) + + +def test_join_on_cast() -> None: + df_a = ( + pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]}) + .with_row_index() + .with_columns(pl.col("a").cast(pl.Int32)) + ) + + df_b = pl.DataFrame({"a": [-2, -3, 3, 10]}) + + assert_frame_equal( + df_a.join(df_b, on=pl.col("a").cast(pl.Int64)), + pl.DataFrame( + { + "index": [1, 2, 3, 5], + "a": [-2, 3, 3, 10], + "a_right": [-2, 3, 3, 10], + } + ), + check_row_order=False, + check_dtypes=False, + ) + assert df_a.lazy().join( + df_b.lazy(), + on=pl.col("a").cast(pl.Int64), + maintain_order="left", + ).collect().to_dict(as_series=False) == { + "index": [1, 2, 3, 5], + "a": [-2, 3, 3, 10], + "a_right": [-2, 3, 3, 10], + } + + +def test_join_chunks_alignment_4720() -> None: + # https://github.com/pola-rs/polars/issues/4720 + + df1 = pl.DataFrame( + { + "index1": pl.arange(0, 2, eager=True), + "index2": pl.arange(10, 12, eager=True), + } + ) + + df2 = pl.DataFrame( + { + "index3": pl.arange(100, 102, eager=True), + } + ) + + df3 = pl.DataFrame( + { + "index1": pl.arange(0, 2, eager=True), + "index2": pl.arange(10, 12, eager=True), + "index3": pl.arange(100, 102, eager=True), + } + ) + assert_frame_equal( + df1.join(df2, how="cross").join( + df3, + on=["index1", "index2", "index3"], + how="left", + ), + pl.DataFrame( + { + "index1": [0, 0, 1, 1], + "index2": [10, 10, 11, 11], + "index3": [100, 101, 100, 101], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df1.join(df2, how="cross").join( + df3, + on=["index3", "index1", "index2"], + how="left", + ), + pl.DataFrame( + { + "index1": [0, 0, 1, 1], + "index2": [10, 10, 11, 11], + "index3": [100, 101, 100, 101], + } + ), + check_row_order=False, + ) + + +def test_jit_sort_joins() -> None: + n = 200 + # Explicitly specify numpy dtype because of different defaults on Windows + dfa = pd.DataFrame( + { + "a": np.random.randint(0, 100, n, dtype=np.int64), + "b": np.arange(0, n, dtype=np.int64), + } + ) + + n = 40 + dfb = pd.DataFrame( + { + "a": np.random.randint(0, 100, n, dtype=np.int64), + "b": np.arange(0, n, dtype=np.int64), + } + ) + dfa_pl = pl.from_pandas(dfa).sort("a") + dfb_pl = pl.from_pandas(dfb) + + join_strategies: list[Literal["left", "inner"]] = ["left", "inner"] + for how in join_strategies: + pd_result = dfa.merge(dfb, on="a", how=how) + pd_result.columns = pd.Index(["a", "b", "b_right"]) + + # left key sorted right is not + pl_result = dfa_pl.join(dfb_pl, on="a", how=how).sort(["a", "b", "b_right"]) + + a = ( + pl.from_pandas(pd_result) + .with_columns(pl.all().cast(int)) + .sort(["a", "b", "b_right"]) + ) + assert_frame_equal(a, pl_result) + assert pl_result["a"].flags["SORTED_ASC"] + + # left key sorted right is not + pd_result = dfb.merge(dfa, on="a", how=how) + pd_result.columns = pd.Index(["a", "b", "b_right"]) + pl_result = dfb_pl.join(dfa_pl, on="a", how=how).sort(["a", "b", "b_right"]) + + a = ( + pl.from_pandas(pd_result) + .with_columns(pl.all().cast(int)) + .sort(["a", "b", "b_right"]) + ) + assert_frame_equal(a, pl_result) + assert pl_result["a"].flags["SORTED_ASC"] + + +def test_join_panic_on_binary_expr_5915() -> None: + df_a = pl.DataFrame({"a": [1, 2, 3]}).lazy() + df_b = pl.DataFrame({"b": [1, 4, 9, 9, 0]}).lazy() + + z = df_a.join(df_b, left_on=[(pl.col("a") + 1).cast(int)], right_on=[pl.col("b")]) + assert z.collect().to_dict(as_series=False) == {"a": [3], "b": [4]} + + +def test_semi_join_projection_pushdown_6423() -> None: + df1 = pl.DataFrame({"x": [1]}).lazy() + df2 = pl.DataFrame({"y": [1], "x": [1]}).lazy() + + assert ( + df1.join(df2, left_on="x", right_on="y", how="semi") + .join(df2, left_on="x", right_on="y", how="semi") + .select(["x"]) + ).collect().to_dict(as_series=False) == {"x": [1]} + + +def test_semi_join_projection_pushdown_6455() -> None: + df = pl.DataFrame( + { + "id": [1, 1, 2], + "timestamp": [ + datetime(2022, 12, 11), + datetime(2022, 12, 12), + datetime(2022, 1, 1), + ], + "value": [1, 2, 4], + } + ).lazy() + + latest = df.group_by("id").agg(pl.col("timestamp").max()) + df = df.join(latest, on=["id", "timestamp"], how="semi") + assert df.select(["id", "value"]).collect().to_dict(as_series=False) == { + "id": [1, 2], + "value": [2, 4], + } + + +def test_update() -> None: + df1 = pl.DataFrame( + { + "key1": [1, 2, 3, 4], + "key2": [1, 2, 3, 4], + "a": [1, 2, 3, 4], + "b": [1, 2, 3, 4], + "c": ["1", "2", "3", "4"], + "d": [ + date(2023, 1, 1), + date(2023, 1, 2), + date(2023, 1, 3), + date(2023, 1, 4), + ], + } + ) + + df2 = pl.DataFrame( + { + "key1": [1, 2, 3, 4], + "key2": [1, 2, 3, 5], + "a": [1, 1, 1, 1], + "b": [2, 2, 2, 2], + "c": ["3", "3", "3", "3"], + "d": [ + date(2023, 5, 5), + date(2023, 5, 5), + date(2023, 5, 5), + date(2023, 5, 5), + ], + } + ) + + # update only on key1 + expected = pl.DataFrame( + { + "key1": [1, 2, 3, 4], + "key2": [1, 2, 3, 5], + "a": [1, 1, 1, 1], + "b": [2, 2, 2, 2], + "c": ["3", "3", "3", "3"], + "d": [ + date(2023, 5, 5), + date(2023, 5, 5), + date(2023, 5, 5), + date(2023, 5, 5), + ], + } + ) + assert_frame_equal(df1.update(df2, on="key1"), expected) + + # update on key1 using different left/right names + assert_frame_equal( + df1.update( + df2.rename({"key1": "key1b"}), + left_on="key1", + right_on="key1b", + ), + expected, + ) + + # update on key1 and key2. This should fail to match the last item. + expected = pl.DataFrame( + { + "key1": [1, 2, 3, 4], + "key2": [1, 2, 3, 4], + "a": [1, 1, 1, 4], + "b": [2, 2, 2, 4], + "c": ["3", "3", "3", "4"], + "d": [ + date(2023, 5, 5), + date(2023, 5, 5), + date(2023, 5, 5), + date(2023, 1, 4), + ], + } + ) + assert_frame_equal(df1.update(df2, on=["key1", "key2"]), expected) + + # update on key1 and key2 using different left/right names + assert_frame_equal( + df1.update( + df2.rename({"key1": "key1b", "key2": "key2b"}), + left_on=["key1", "key2"], + right_on=["key1b", "key2b"], + ), + expected, + ) + + df = pl.DataFrame({"A": [1, 2, 3, 4], "B": [400, 500, 600, 700]}) + + new_df = pl.DataFrame({"B": [4, None, 6], "C": [7, 8, 9]}) + + assert df.update(new_df).to_dict(as_series=False) == { + "A": [1, 2, 3, 4], + "B": [4, 500, 6, 700], + } + df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pl.DataFrame({"a": [2, 3], "b": [8, 9]}) + + assert df1.update(df2, on="a").to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [4, 8, 9], + } + + a = pl.LazyFrame({"a": [1, 2, 3]}) + b = pl.LazyFrame({"b": [4, 5], "c": [3, 1]}) + c = a.update(b) + + assert_frame_equal(a, c) + + # check behaviour of 'how' param + result = a.update(b, left_on="a", right_on="c") + assert result.collect().to_series().to_list() == [1, 2, 3] + + result = a.update(b, how="inner", left_on="a", right_on="c") + assert sorted(result.collect().to_series().to_list()) == [1, 3] + + result = a.update(b.rename({"b": "a"}), how="full", on="a") + assert sorted(result.collect().to_series().sort().to_list()) == [1, 2, 3, 4, 5] + + # check behavior of include_nulls=True + df = pl.DataFrame( + { + "A": [1, 2, 3, 4], + "B": [400, 500, 600, 700], + } + ) + new_df = pl.DataFrame( + { + "B": [-66, None, -99], + "C": [5, 3, 1], + } + ) + out = df.update(new_df, left_on="A", right_on="C", how="full", include_nulls=True) + expected = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "B": [-99, 500, None, 700, -66], + } + ) + assert_frame_equal(out, expected, check_row_order=False) + + # edge-case #11684 + x = pl.DataFrame({"a": [0, 1]}) + y = pl.DataFrame({"a": [2, 3]}) + assert sorted(x.update(y, on="a", how="full")["a"].to_list()) == [0, 1, 2, 3] + + # disallowed join strategies + for join_strategy in ("cross", "anti", "semi"): + with pytest.raises( + ValueError, + match=f"`how` must be one of {{'left', 'inner', 'full'}}; found '{join_strategy}'", + ): + a.update(b, how=join_strategy) # type: ignore[arg-type] + + +def test_join_frame_consistency() -> None: + df = pl.DataFrame({"A": [1, 2, 3]}) + ldf = pl.DataFrame({"A": [1, 2, 5]}).lazy() + + with pytest.raises(TypeError, match="expected `other`.* LazyFrame"): + _ = ldf.join(df, on="A") # type: ignore[arg-type] + with pytest.raises(TypeError, match="expected `other`.* DataFrame"): + _ = df.join(ldf, on="A") # type: ignore[arg-type] + with pytest.raises(TypeError, match="expected `other`.* LazyFrame"): + _ = ldf.join_asof(df, on="A") # type: ignore[arg-type] + with pytest.raises(TypeError, match="expected `other`.* DataFrame"): + _ = df.join_asof(ldf, on="A") # type: ignore[arg-type] + + +def test_join_concat_projection_pd_case_7071() -> None: + ldf = pl.DataFrame({"id": [1, 2], "value": [100, 200]}).lazy() + ldf2 = pl.DataFrame({"id": [1, 3], "value": [100, 300]}).lazy() + + ldf = ldf.join(ldf2, on=["id", "value"]) + ldf = pl.concat([ldf, ldf2]) + result = ldf.select("id") + + expected = pl.DataFrame({"id": [1, 1, 3]}).lazy() + assert_frame_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming # legacy full join is not order-preserving whereas new-streaming is +def test_join_sorted_fast_paths_null() -> None: + df1 = pl.DataFrame({"x": [0, 1, 0]}).sort("x") + df2 = pl.DataFrame({"x": [0, None], "y": [0, 1]}) + assert df1.join(df2, on="x", how="inner").to_dict(as_series=False) == { + "x": [0, 0], + "y": [0, 0], + } + assert df1.join(df2, on="x", how="left").to_dict(as_series=False) == { + "x": [0, 0, 1], + "y": [0, 0, None], + } + assert df1.join(df2, on="x", how="anti").to_dict(as_series=False) == {"x": [1]} + assert df1.join(df2, on="x", how="semi").to_dict(as_series=False) == {"x": [0, 0]} + assert df1.join(df2, on="x", how="full").to_dict(as_series=False) == { + "x": [0, 0, 1, None], + "x_right": [0, 0, None, None], + "y": [0, 0, None, 1], + } + + +def test_full_outer_join_list_() -> None: + schema = {"id": pl.Int64, "vals": pl.List(pl.Float64)} + join_schema = {**schema, **{k + "_right": t for (k, t) in schema.items()}} + df1 = pl.DataFrame({"id": [1], "vals": [[]]}, schema=schema) # type: ignore[arg-type] + df2 = pl.DataFrame({"id": [2, 3], "vals": [[], [4]]}, schema=schema) # type: ignore[arg-type] + expected = pl.DataFrame( + { + "id": [None, None, 1], + "vals": [None, None, []], + "id_right": [2, 3, None], + "vals_right": [[], [4.0], None], + }, + schema=join_schema, # type: ignore[arg-type] + ) + out = df1.join(df2, on="id", how="full", maintain_order="right_left") + assert_frame_equal(out, expected) + + +@pytest.mark.slow +def test_join_validation() -> None: + def test_each_join_validation( + unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy + ) -> None: + # one_to_many + _one_to_many_success_inner = unique.join( + duplicate, on=on, how=how, validate="1:m" + ) + + with pytest.raises(ComputeError): + _one_to_many_fail_inner = duplicate.join( + unique, on=on, how=how, validate="1:m" + ) + + # one to one + with pytest.raises(ComputeError): + _one_to_one_fail_1_inner = unique.join( + duplicate, on=on, how=how, validate="1:1" + ) + + with pytest.raises(ComputeError): + _one_to_one_fail_2_inner = duplicate.join( + unique, on=on, how=how, validate="1:1" + ) + + # many to one + with pytest.raises(ComputeError): + _many_to_one_fail_inner = unique.join( + duplicate, on=on, how=how, validate="m:1" + ) + + _many_to_one_success_inner = duplicate.join( + unique, on=on, how=how, validate="m:1" + ) + + # many to many + _many_to_many_success_1_inner = duplicate.join( + unique, on=on, how=how, validate="m:m" + ) + + _many_to_many_success_2_inner = unique.join( + duplicate, on=on, how=how, validate="m:m" + ) + + # test data + short_unique = pl.DataFrame( + { + "id": [1, 2, 3, 4], + "id_str": ["1", "2", "3", "4"], + "name": ["hello", "world", "rust", "polars"], + } + ) + short_duplicate = pl.DataFrame( + {"id": [1, 2, 3, 1], "id_str": ["1", "2", "3", "1"], "cnt": [2, 4, 6, 1]} + ) + long_unique = pl.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "id_str": ["1", "2", "3", "4", "5"], + "name": ["hello", "world", "rust", "polars", "meow"], + } + ) + long_duplicate = pl.DataFrame( + { + "id": [1, 2, 3, 1, 5], + "id_str": ["1", "2", "3", "1", "5"], + "cnt": [2, 4, 6, 1, 8], + } + ) + + join_strategies: list[JoinStrategy] = ["inner", "full", "left"] + + for join_col in ["id", "id_str"]: + for how in join_strategies: + # same size + test_each_join_validation(long_unique, long_duplicate, join_col, how) + + # left longer + test_each_join_validation(long_unique, short_duplicate, join_col, how) + + # right longer + test_each_join_validation(short_unique, long_duplicate, join_col, how) + + +@typing.no_type_check +def test_join_validation_many_keys() -> None: + # unique in both + df1 = pl.DataFrame( + { + "val1": [11, 12, 13, 14], + "val2": [1, 2, 3, 4], + } + ) + df2 = pl.DataFrame( + { + "val1": [11, 12, 13, 14], + "val2": [1, 2, 3, 4], + } + ) + for join_type in ["inner", "left", "full"]: + for val in ["m:m", "m:1", "1:1", "1:m"]: + df1.join(df2, on=["val1", "val2"], how=join_type, validate=val) + + # many in lhs + df1 = pl.DataFrame( + { + "val1": [11, 11, 12, 13, 14], + "val2": [1, 1, 2, 3, 4], + } + ) + + for join_type in ["inner", "left", "full"]: + for val in ["1:1", "1:m"]: + with pytest.raises(ComputeError): + df1.join(df2, on=["val1", "val2"], how=join_type, validate=val) + + # many in rhs + df1 = pl.DataFrame( + { + "val1": [11, 12, 13, 14], + "val2": [1, 2, 3, 4], + } + ) + df2 = pl.DataFrame( + { + "val1": [11, 11, 12, 13, 14], + "val2": [1, 1, 2, 3, 4], + } + ) + + for join_type in ["inner", "left", "full"]: + for val in ["m:1", "1:1"]: + with pytest.raises(ComputeError): + df1.join(df2, on=["val1", "val2"], how=join_type, validate=val) + + +def test_full_outer_join_bool() -> None: + df1 = pl.DataFrame({"id": [True, False], "val": [1, 2]}) + df2 = pl.DataFrame({"id": [True, False], "val": [0, -1]}) + assert df1.join(df2, on="id", how="full", maintain_order="right").to_dict( + as_series=False + ) == { + "id": [True, False], + "val": [1, 2], + "id_right": [True, False], + "val_right": [0, -1], + } + + +def test_full_outer_join_coalesce_different_names_13450() -> None: + df1 = pl.DataFrame({"L1": ["a", "b", "c"], "L3": ["b", "c", "d"], "L2": [1, 2, 3]}) + df2 = pl.DataFrame({"L3": ["a", "c", "d"], "R2": [7, 8, 9]}) + + expected = pl.DataFrame( + { + "L1": ["a", "c", "d", "b"], + "L3": ["b", "d", None, "c"], + "L2": [1, 3, None, 2], + "R2": [7, 8, 9, None], + } + ) + + out = df1.join(df2, left_on="L1", right_on="L3", how="full", coalesce=True) + assert_frame_equal(out, expected, check_row_order=False) + + +# https://github.com/pola-rs/polars/issues/10663 +def test_join_on_wildcard_error() -> None: + df = pl.DataFrame({"x": [1]}) + df2 = pl.DataFrame({"x": [1], "y": [2]}) + with pytest.raises( + InvalidOperationError, + ): + df.join(df2, on=pl.all()) + + +def test_join_on_nth_error() -> None: + df = pl.DataFrame({"x": [1]}) + df2 = pl.DataFrame({"x": [1], "y": [2]}) + with pytest.raises( + InvalidOperationError, + ): + df.join(df2, on=pl.first()) + + +def test_join_results_in_duplicate_names() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [1, 2, 3], + "c_right": [1, 2, 3], + } + ) + + def f(x: Any) -> Any: + return x.join(x, on=["a", "b"], how="left") + + # Ensure it also contains the hint + match_str = "(?s)column with name 'c_right' already exists.*You may want to try" + + # Ensure it fails immediately when resolving schema. + with pytest.raises(DuplicateError, match=match_str): + f(df.lazy()).collect_schema() + + with pytest.raises(DuplicateError, match=match_str): + f(df.lazy()).collect() + + with pytest.raises(DuplicateError, match=match_str): + f(df).collect() + + +def test_join_duplicate_suffixed_columns_from_join_key_column_21048() -> None: + df = pl.DataFrame({"a": 1, "b": 1, "b_right": 1}) + + def f(x: Any) -> Any: + return x.join(x, on="a") + + # Ensure it also contains the hint + match_str = "(?s)column with name 'b_right' already exists.*You may want to try" + + # Ensure it fails immediately when resolving schema. + with pytest.raises(DuplicateError, match=match_str): + f(df.lazy()).collect_schema() + + with pytest.raises(DuplicateError, match=match_str): + f(df.lazy()).collect() + + with pytest.raises(DuplicateError, match=match_str): + f(df) + + +def test_join_projection_invalid_name_contains_suffix_15243() -> None: + df1 = pl.DataFrame({"a": [1, 2, 3]}).lazy() + df2 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).lazy() + + with pytest.raises(ColumnNotFoundError): + ( + df1.join(df2, on="a") + .select(pl.col("b").filter(pl.col("b") == pl.col("foo_right"))) + .collect() + ) + + +def test_join_list_non_numeric() -> None: + assert ( + pl.DataFrame( + { + "lists": [ + ["a", "b", "c"], + ["a", "c", "b"], + ["a", "c", "b"], + ["a", "c", "d"], + ] + } + ) + ).group_by("lists", maintain_order=True).agg(pl.len().alias("count")).to_dict( + as_series=False + ) == { + "lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]], + "count": [1, 2, 1], + } + + +@pytest.mark.slow +def test_join_4_columns_with_validity() -> None: + # join on 4 columns so we trigger combine validities + # use 138 as that is 2 u64 and a remainder + a = pl.DataFrame( + {"a": [None if a % 6 == 0 else a for a in range(138)]} + ).with_columns( + b=pl.col("a"), + c=pl.col("a"), + d=pl.col("a"), + ) + + assert a.join(a, on=["a", "b", "c", "d"], how="inner", nulls_equal=True).shape == ( + 644, + 4, + ) + assert a.join(a, on=["a", "b", "c", "d"], how="inner", nulls_equal=False).shape == ( + 115, + 4, + ) + + +@pytest.mark.release +def test_cross_join() -> None: + # triggers > 100 rows implementation + # https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L34 + df1 = pl.DataFrame({"col1": ["a"], "col2": ["d"]}) + df2 = pl.DataFrame({"frame2": pl.arange(0, 100, eager=True)}) + out = df2.join(df1, how="cross") + df2 = pl.DataFrame({"frame2": pl.arange(0, 101, eager=True)}) + assert_frame_equal(df2.join(df1, how="cross").slice(0, 100), out) + + +@pytest.mark.release +def test_cross_join_slice_pushdown() -> None: + # this will likely go out of memory if we did not pushdown the slice + df = ( + pl.Series("x", pl.arange(0, 2**16 - 1, eager=True, dtype=pl.UInt16) % 2**15) + ).to_frame() + + result = df.lazy().join(df.lazy(), how="cross", suffix="_").slice(-5, 10).collect() + expected = pl.DataFrame( + { + "x": [32766, 32766, 32766, 32766, 32766], + "x_": [32762, 32763, 32764, 32765, 32766], + }, + schema={"x": pl.UInt16, "x_": pl.UInt16}, + ) + assert_frame_equal(result, expected) + + result = df.lazy().join(df.lazy(), how="cross", suffix="_").slice(2, 10).collect() + expected = pl.DataFrame( + { + "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "x_": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + }, + schema={"x": pl.UInt16, "x_": pl.UInt16}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("how", ["left", "inner"]) +def test_join_coalesce(how: JoinStrategy) -> None: + a = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}) + b = pl.LazyFrame( + { + "a": [1, 2, 1, 2], + "b": [5, 7, 8, 9], + "c": [1, 2, 1, 2], + } + ) + + how = "inner" + q = a.join(b, on="a", coalesce=False, how=how) + out = q.collect() + assert q.collect_schema() == out.schema + assert out.columns == ["a", "b", "a_right", "b_right", "c"] + + q = a.join(b, on=["a", "b"], coalesce=False, how=how) + out = q.collect() + assert q.collect_schema() == out.schema + assert out.columns == ["a", "b", "a_right", "b_right", "c"] + + q = a.join(b, on=["a", "b"], coalesce=True, how=how) + out = q.collect() + assert q.collect_schema() == out.schema + assert out.columns == ["a", "b", "c"] + + +@pytest.mark.parametrize("how", ["left", "inner", "full"]) +def test_join_empties(how: JoinStrategy) -> None: + df1 = pl.DataFrame({"col1": [], "col2": [], "col3": []}) + df2 = pl.DataFrame({"col2": [], "col4": [], "col5": []}) + + df = df1.join(df2, on="col2", how=how) + assert df.height == 0 + + +def test_join_raise_on_redundant_keys() -> None: + left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]}) + right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]}) + with pytest.raises(InvalidOperationError, match="already joined on"): + left.join(right, on=["a", "a"], how="full", coalesce=True) + + +@pytest.mark.parametrize("coalesce", [False, True]) +def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None: + left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]}) + right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]}) + with ( # noqa: PT012 + pytest.raises(InvalidOperationError, match="already joined on"), + warnings.catch_warnings(), + ): + warnings.simplefilter(action="ignore", category=UserWarning) + left.join( + right, on=[pl.col("a"), pl.col("a") % 2], how="full", coalesce=coalesce + ) + + +def test_join_lit_panic_11410() -> None: + df = pl.LazyFrame({"date": [1, 2, 3], "symbol": [4, 5, 6]}) + dates = df.select("date").unique(maintain_order=True) + symbols = df.select("symbol").unique(maintain_order=True) + + assert symbols.join( + dates, left_on=pl.lit(1), right_on=pl.lit(1), maintain_order="left_right" + ).collect().to_dict(as_series=False) == { + "symbol": [4, 4, 4, 5, 5, 5, 6, 6, 6], + "date": [1, 2, 3, 1, 2, 3, 1, 2, 3], + } + + +def test_join_empty_literal_17027() -> None: + df1 = pl.DataFrame({"a": [1]}) + df2 = pl.DataFrame(schema={"a": pl.Int64}) + + assert df1.join(df2, on=pl.lit(0), how="left").height == 1 + assert df1.join(df2, on=pl.lit(0), how="inner").height == 0 + assert ( + df1.lazy() + .join(df2.lazy(), on=pl.lit(0), how="inner") + .collect(engine="streaming") + .height + == 0 + ) + assert ( + df1.lazy() + .join(df2.lazy(), on=pl.lit(0), how="left") + .collect(engine="streaming") + .height + == 1 + ) + + +@pytest.mark.parametrize( + ("left_on", "right_on"), + zip( + [pl.col("a"), pl.col("a").sort(), [pl.col("a"), pl.col("b")]], + [pl.col("a").slice(0, 2) * 2, pl.col("b"), [pl.col("a"), pl.col("b").head()]], + ), +) +def test_join_non_elementwise_keys_raises(left_on: pl.Expr, right_on: pl.Expr) -> None: + # https://github.com/pola-rs/polars/issues/17184 + left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]}) + right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]}) + + q = left.join( + right, + left_on=left_on, + right_on=right_on, + how="inner", + ) + + with pytest.raises(pl.exceptions.InvalidOperationError): + q.collect() + + +def test_join_coalesce_not_supported_warning() -> None: + # https://github.com/pola-rs/polars/issues/17184 + left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]}) + right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]}) + + q = left.join( + right, + left_on=[pl.col("a") * 2], + right_on=[pl.col("a") * 2], + how="inner", + coalesce=True, + ) + with pytest.warns(UserWarning, match="turning off key coalescing"): + got = q.collect() + expect = pl.DataFrame( + {"a": [1, 2, 3], "b": [3, 4, 5], "a_right": [1, 2, 3], "b_right": [3, 4, 5]} + ) + + assert_frame_equal(expect, got, check_row_order=False) + + +@pytest.mark.parametrize( + ("on_args"), + [ + {"on": "a", "left_on": "a"}, + {"on": "a", "right_on": "a"}, + {"on": "a", "left_on": "a", "right_on": "a"}, + ], +) +def test_join_on_and_left_right_on(on_args: dict[str, str]) -> None: + df1 = pl.DataFrame({"a": [1], "b": [2]}) + df2 = pl.DataFrame({"a": [1], "c": [3]}) + msg = "cannot use 'on' in conjunction with 'left_on' or 'right_on'" + with pytest.raises(ValueError, match=msg): + df1.join(df2, **on_args) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("on_args"), + [ + {"left_on": "a"}, + {"right_on": "a"}, + ], +) +def test_join_only_left_or_right_on(on_args: dict[str, str]) -> None: + df1 = pl.DataFrame({"a": [1]}) + df2 = pl.DataFrame({"a": [1]}) + msg = "'left_on' requires corresponding 'right_on'" + with pytest.raises(ValueError, match=msg): + df1.join(df2, **on_args) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("on_args"), + [ + {"on": "a"}, + {"left_on": "a", "right_on": "a"}, + ], +) +def test_cross_join_no_on_keys(on_args: dict[str, str]) -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"b": [3, 4]}) + msg = "cross join should not pass join keys" + with pytest.raises(ValueError, match=msg): + df1.join(df2, how="cross", **on_args) # type: ignore[arg-type] + + +@pytest.mark.parametrize("set_sorted", [True, False]) +def test_left_join_slice_pushdown_19405(set_sorted: bool) -> None: + left = pl.LazyFrame({"k": [1, 2, 3, 4, 0]}) + right = pl.LazyFrame({"k": [1, 1, 1, 1, 0]}) + + if set_sorted: + # The data isn't actually sorted on purpose to ensure we default to a + # hash join unless we set the sorted flag here, in case there is new + # code in the future that automatically identifies sortedness during + # Series construction from Python. + left = left.set_sorted("k") + right = right.set_sorted("k") + + q = left.join(right, on="k", how="left", maintain_order="left_right").head(5) + assert_frame_equal(q.collect(), pl.DataFrame({"k": [1, 1, 1, 1, 2]})) + + +def test_join_key_type_coercion_19597() -> None: + left = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Float64)}) + right = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int64)}) + + with pytest.raises(SchemaError, match="datatypes of join keys don't match"): + left.join(right, left_on=pl.col("a"), right_on=pl.col("a")).collect_schema() + + with pytest.raises(SchemaError, match="datatypes of join keys don't match"): + left.join( + right, left_on=pl.col("a") * 2, right_on=pl.col("a") * 2 + ).collect_schema() + + +def test_array_explode_join_19763() -> None: + q = pl.LazyFrame().select( + pl.lit(pl.Series([[1], [2]], dtype=pl.Array(pl.Int64, 1))).explode().alias("k") + ) + + q = q.join(pl.LazyFrame({"k": [1, 2]}), on="k") + + assert_frame_equal(q.collect().sort("k"), pl.DataFrame({"k": [1, 2]})) + + +@with_string_cache_if_auto_streaming +def test_join_full_19814() -> None: + schema = {"a": pl.Int64, "c": pl.Categorical} + a = pl.LazyFrame({"a": [1], "c": [None]}, schema=schema) + b = pl.LazyFrame({"a": [1, 3, 4]}) + assert_frame_equal( + a.join(b, on="a", how="full", coalesce=True).collect(), + pl.DataFrame({"a": [1, 3, 4], "c": [None, None, None]}, schema=schema), + check_row_order=False, + ) + + +def test_join_preserve_order_inner() -> None: + left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]}) + right = pl.LazyFrame({"a": [1, 1, None, 2], "b": [6, 7, 8, 9]}) + + # Inner joins + + inner_left = left.join(right, on="a", how="inner", maintain_order="left").collect() + assert inner_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, 1, 1] + inner_left_right = left.join( + right, on="a", how="inner", maintain_order="left" + ).collect() + assert inner_left.get_column("a").equals(inner_left_right.get_column("a")) + + inner_right = left.join( + right, on="a", how="inner", maintain_order="right" + ).collect() + assert inner_right.get_column("a").cast(pl.UInt32).to_list() == [1, 1, 1, 1, 2] + inner_right_left = left.join( + right, on="a", how="inner", maintain_order="right" + ).collect() + assert inner_right.get_column("a").equals(inner_right_left.get_column("a")) + + +# The new streaming engine does not provide the same maintain_order="none" +# ordering guarantee that is currently kept for compatibility on the in-memory +# engine. +@pytest.mark.may_fail_auto_streaming +def test_join_preserve_order_left() -> None: + left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]}) + right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]}) + + # Right now the left join algorithm is ordered without explicitly setting any order + # This behaviour is deprecated but can only be removed in 2.0 + left_none = left.join(right, on="a", how="left", maintain_order="none").collect() + assert left_none.get_column("a").cast(pl.UInt32).to_list() == [ + None, + 2, + 1, + 1, + 5, + ] + + left_left = left.join(right, on="a", how="left", maintain_order="left").collect() + assert left_left.get_column("a").cast(pl.UInt32).to_list() == [ + None, + 2, + 1, + 1, + 5, + ] + + left_left_right = left.join( + right, on="a", how="left", maintain_order="left_right" + ).collect() + # If the left order is preserved then there are no unsorted right rows + assert left_left.get_column("a").equals(left_left_right.get_column("a")) + + left_right = left.join(right, on="a", how="left", maintain_order="right").collect() + assert left_right.get_column("a").cast(pl.UInt32).to_list()[:5] == [ + 1, + 1, + 2, + None, + 5, + ] + + left_right_left = left.join( + right, on="a", how="left", maintain_order="right_left" + ).collect() + assert left_right_left.get_column("a").cast(pl.UInt32).to_list() == [ + 1, + 1, + 2, + None, + 5, + ] + + right_left = left.join(right, on="a", how="right", maintain_order="left").collect() + assert right_left.get_column("a").cast(pl.UInt32).to_list() == [2, 1, 1, None, 6] + + right_right = left.join( + right, on="a", how="right", maintain_order="right" + ).collect() + assert right_right.get_column("a").cast(pl.UInt32).to_list() == [ + 1, + 1, + None, + 2, + 6, + ] + + +def test_join_preserve_order_full() -> None: + left = pl.LazyFrame({"a": [None, 2, 1, 1, 5]}) + right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]}) + + full_left = left.join(right, on="a", how="full", maintain_order="left").collect() + assert full_left.get_column("a").cast(pl.UInt32).to_list()[:5] == [ + None, + 2, + 1, + 1, + 5, + ] + full_right = left.join(right, on="a", how="full", maintain_order="right").collect() + assert full_right.get_column("a").cast(pl.UInt32).to_list()[:5] == [ + 1, + 1, + None, + 2, + None, + ] + + full_left_right = left.join( + right, on="a", how="full", maintain_order="left_right" + ).collect() + assert full_left_right.get_column("a_right").cast(pl.UInt32).to_list() == [ + None, + 2, + 1, + 1, + None, + None, + 6, + ] + + full_right_left = left.join( + right, on="a", how="full", maintain_order="right_left" + ).collect() + assert full_right_left.get_column("a").cast(pl.UInt32).to_list() == [ + 1, + 1, + None, + 2, + None, + None, + 5, + ] + + +@pytest.mark.parametrize( + "dtypes", + [ + ["Int128", "Int128", "Int64"], + ["Int128", "Int128", "Int32"], + ["Int128", "Int128", "Int16"], + ["Int128", "Int128", "Int8"], + ["Int128", "UInt64", "Int128"], + ["Int128", "UInt64", "Int64"], + ["Int128", "UInt64", "Int32"], + ["Int128", "UInt64", "Int16"], + ["Int128", "UInt64", "Int8"], + ["Int128", "UInt32", "Int128"], + ["Int128", "UInt16", "Int128"], + ["Int128", "UInt8", "Int128"], + + ["Int64", "Int64", "Int32"], + ["Int64", "Int64", "Int16"], + ["Int64", "Int64", "Int8"], + ["Int64", "UInt32", "Int64"], + ["Int64", "UInt32", "Int32"], + ["Int64", "UInt32", "Int16"], + ["Int64", "UInt32", "Int8"], + ["Int64", "UInt16", "Int64"], + ["Int64", "UInt8", "Int64"], + + ["Int32", "Int32", "Int16"], + ["Int32", "Int32", "Int8"], + ["Int32", "UInt16", "Int32"], + ["Int32", "UInt16", "Int16"], + ["Int32", "UInt16", "Int8"], + ["Int32", "UInt8", "Int32"], + + ["Int16", "Int16", "Int8"], + ["Int16", "UInt8", "Int16"], + ["Int16", "UInt8", "Int8"], + + ["UInt64", "UInt64", "UInt32"], + ["UInt64", "UInt64", "UInt16"], + ["UInt64", "UInt64", "UInt8"], + + ["UInt32", "UInt32", "UInt16"], + ["UInt32", "UInt32", "UInt8"], + + ["UInt16", "UInt16", "UInt8"], + + ["Float64", "Float64", "Float32"], + ], +) # fmt: skip +@pytest.mark.parametrize("swap", [True, False]) +def test_join_numeric_key_upcast_15338( + dtypes: tuple[str, str, str], swap: bool +) -> None: + supertype, ltype, rtype = (getattr(pl, x) for x in dtypes) + ltype, rtype = (rtype, ltype) if swap else (ltype, rtype) + + left = pl.select(pl.Series("a", [1, 1, 3]).cast(ltype)).lazy() + right = pl.select(pl.Series("a", [1]).cast(rtype), b=pl.lit("A")).lazy() + + assert_frame_equal( + left.join(right, on="a", how="left").collect(), + pl.select(a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])), + check_row_order=False, + ) + + assert_frame_equal( + left.join(right, on="a", how="left", coalesce=False).drop("a_right").collect(), + pl.select(a=pl.Series([1, 1, 3]).cast(ltype), b=pl.Series(["A", "A", None])), + check_row_order=False, + ) + + assert_frame_equal( + left.join(right, on="a", how="full").collect(), + pl.select( + a=pl.Series([1, 1, 3]).cast(ltype), + a_right=pl.Series([1, 1, None]).cast(rtype), + b=pl.Series(["A", "A", None]), + ), + check_row_order=False, + ) + + assert_frame_equal( + left.join(right, on="a", how="full", coalesce=True).collect(), + pl.select( + a=pl.Series([1, 1, 3]).cast(supertype), + b=pl.Series(["A", "A", None]), + ), + check_row_order=False, + ) + + assert_frame_equal( + left.join(right, on="a", how="semi").collect(), + pl.select(a=pl.Series([1, 1]).cast(ltype)), + ) + + # join_where + for no_optimization in [True, False]: + assert_frame_equal( + left.join_where(right, pl.col("a") == pl.col("a_right")).collect( + no_optimization=no_optimization + ), + pl.select( + a=pl.Series([1, 1]).cast(ltype), + a_right=pl.lit(1, dtype=rtype), + b=pl.Series(["A", "A"]), + ), + ) + + +def test_join_numeric_key_upcast_forbid_float_int() -> None: + ltype = pl.Float64 + rtype = pl.Int128 + + left = pl.LazyFrame({"a": [1.0, 0.0]}, schema={"a": ltype}) + right = pl.LazyFrame({"a": [1, 2]}, schema={"a": rtype}) + + # Establish baseline: In a non-join context, comparisons between ltype and + # rtype succeed even if the upcast is lossy. + assert_frame_equal( + left.with_columns(right.collect()["a"].alias("a_right")) + .select(pl.col("a") == pl.col("a_right")) + .collect(), + pl.DataFrame({"a": [True, False]}), + ) + + with pytest.raises(SchemaError, match="datatypes of join keys don't match"): + left.join(right, on="a", how="left").collect() + + for no_optimization in [True, False]: + with pytest.raises( + SchemaError, match="'join_where' cannot compare Float64 with Int128" + ): + left.join_where(right, pl.col("a") == pl.col("a_right")).collect( + no_optimization=no_optimization + ) + + with pytest.raises( + SchemaError, match="'join_where' cannot compare Float64 with Int128" + ): + left.join_where( + right, pl.col("a") == (pl.col("a") == pl.col("a_right")) + ).collect(no_optimization=no_optimization) + + +def test_join_numeric_key_upcast_order() -> None: + # E.g. when we are joining on this expression: + # * col('a') + 127 + # + # and we want to upcast, ensure that we upcast like this: + # * ( col('a') + 127 ) .cast() + # + # and *not* like this: + # * ( col('a').cast() + lit(127).cast() ) + # + # as otherwise the results would be different. + + left = pl.select(pl.Series("a", [1], dtype=pl.Int8)).lazy() + right = pl.select( + pl.Series("a", [1, 128, -128], dtype=pl.Int64), b=pl.lit("A") + ).lazy() + + # col('a') in `left` is Int8, the result will overflow to become -128 + left_expr = pl.col("a") + 127 + + assert_frame_equal( + left.join(right, left_on=left_expr, right_on="a", how="inner").collect(), + pl.DataFrame( + { + "a": pl.Series([1], dtype=pl.Int8), + "a_right": pl.Series([-128], dtype=pl.Int64), + "b": "A", + } + ), + ) + + assert_frame_equal( + left.join_where(right, left_expr == pl.col("a_right")).collect(), + pl.DataFrame( + { + "a": pl.Series([1], dtype=pl.Int8), + "a_right": pl.Series([-128], dtype=pl.Int64), + "b": "A", + } + ), + ) + + assert_frame_equal( + ( + left.join(right, left_on=left_expr, right_on="a", how="full") + .collect() + .sort(pl.all()) + ), + pl.DataFrame( + { + "a": pl.Series([1, None, None], dtype=pl.Int8), + "a_right": pl.Series([-128, 1, 128], dtype=pl.Int64), + "b": ["A", "A", "A"], + } + ).sort(pl.all()), + ) + + +def test_no_collapse_join_when_maintain_order_20725() -> None: + df1 = pl.LazyFrame({"Fraction_1": [0, 25, 50, 75, 100]}) + df2 = pl.LazyFrame({"Fraction_2": [0, 1]}) + df3 = pl.LazyFrame({"Fraction_3": [0, 1]}) + + ldf = df1.join(df2, how="cross", maintain_order="left_right").join( + df3, how="cross", maintain_order="left_right" + ) + + df_pl_lazy = ldf.filter(pl.col("Fraction_1") == 100).collect() + df_pl_eager = ldf.collect().filter(pl.col("Fraction_1") == 100) + + assert_frame_equal(df_pl_lazy, df_pl_eager) + + +def test_join_where_predicate_type_coercion_21009() -> None: + left_frame = pl.LazyFrame( + { + "left_match": ["A", "B", "C", "D", "E", "F"], + "left_date_start": range(6), + } + ) + + right_frame = pl.LazyFrame( + { + "right_match": ["D", "E", "F", "G", "H", "I"], + "right_date": range(6), + } + ) + + # Note: Cannot eq the plans as the operand sides are non-deterministic + + q1 = left_frame.join_where( + right_frame, + pl.col("left_match") == pl.col("right_match"), + pl.col("right_date") >= pl.col("left_date_start"), + ) + + plan = q1.explain().splitlines() + assert plan[0].strip().startswith("FILTER") + assert plan[1] == "FROM" + assert plan[2].strip().startswith("INNER JOIN") + + q2 = left_frame.join_where( + right_frame, + pl.all_horizontal(pl.col("left_match") == pl.col("right_match")), + pl.col("right_date") >= pl.col("left_date_start"), + ) + + plan = q2.explain().splitlines() + assert plan[0].strip().startswith("FILTER") + assert plan[1] == "FROM" + assert plan[2].strip().startswith("INNER JOIN") + + assert_frame_equal(q1.collect(), q2.collect()) + + +def test_join_right_predicate_pushdown_21142() -> None: + left = pl.LazyFrame({"key": [1, 2, 4], "values": ["a", "b", "c"]}) + right = pl.LazyFrame({"key": [1, 2, 3], "values": ["d", "e", "f"]}) + + rjoin = left.join(right, on="key", how="right") + + q = rjoin.filter(pl.col("values").is_null()) + + expect = pl.select( + pl.Series("values", [None], pl.String), + pl.Series("key", [3], pl.Int64), + pl.Series("values_right", ["f"], pl.String), + ) + + assert_frame_equal(q.collect(), expect) + + # Ensure for right join, filter on RHS key-columns are pushed down. + q = rjoin.filter(pl.col("values_right").is_null()) + + plan = q.explain() + assert plan.index("FILTER") > plan.index("RIGHT PLAN ON") + + assert_frame_equal(q.collect(), expect.clear()) + + +def test_join_where_nested_expr_21066() -> None: + left = pl.LazyFrame({"a": [1, 2]}) + right = pl.LazyFrame({"a": [1]}) + + q = left.join_where(right, pl.col("a") == (pl.col("a_right") + 1)) + + assert_frame_equal(q.collect(), pl.DataFrame({"a": 2, "a_right": 1})) + + +def test_select_after_join_where_20831() -> None: + left = pl.LazyFrame( + { + "a": [1, 2, 3, 1, None], + "b": [1, 2, 3, 4, 5], + "c": [2, 3, 4, 5, 6], + } + ) + + right = pl.LazyFrame( + { + "a": [1, 4, 3, 7, None, None, 1], + "c": [2, 3, 4, 5, 6, 7, 8], + "d": [6, None, 7, 8, -1, 2, 4], + } + ) + + q = left.join_where( + right, pl.col("b") * 2 <= pl.col("a_right"), pl.col("a") < pl.col("c_right") + ) + + assert_frame_equal( + q.select("d").collect().sort("d"), + pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame(), + ) + + assert q.select(pl.len()).collect().item() == 6 + + q = ( + left.join(right, how="cross") + .filter(pl.col("b") * 2 <= pl.col("a_right")) + .filter(pl.col("a") < pl.col("c_right")) + ) + + assert_frame_equal( + q.select("d").collect().sort("d"), + pl.Series("d", [None, None, 7, 8, 8, 8]).to_frame(), + ) + + assert q.select(pl.len()).collect().item() == 6 + + +@pytest.mark.parametrize( + ("dtype", "data"), + [ + (pl.Struct, [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]), + (pl.List, [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]), + (pl.Array(pl.Int64, 2), [[1, 1], [2, 2], [3, 3], [4, 4]]), + ], +) +def test_join_on_nested(dtype: PolarsDataType, data: list[Any]) -> None: + lhs = pl.DataFrame( + { + "a": data[:3], + "b": [1, 2, 3], + } + ) + rhs = pl.DataFrame( + { + "a": [data[3], data[1]], + "c": [4, 2], + } + ) + + assert_frame_equal( + lhs.join(rhs, on="a", how="left", maintain_order="left"), + pl.select( + a=pl.Series(data[:3]), + b=pl.Series([1, 2, 3]), + c=pl.Series([None, 2, None]), + ), + ) + assert_frame_equal( + lhs.join(rhs, on="a", how="right", maintain_order="right"), + pl.select( + b=pl.Series([None, 2]), + a=pl.Series([data[3], data[1]]), + c=pl.Series([4, 2]), + ), + ) + assert_frame_equal( + lhs.join(rhs, on="a", how="inner"), + pl.select( + a=pl.Series([data[1]]), + b=pl.Series([2]), + c=pl.Series([2]), + ), + ) + assert_frame_equal( + lhs.join(rhs, on="a", how="full", maintain_order="left_right"), + pl.select( + a=pl.Series(data[:3] + [None]), + b=pl.Series([1, 2, 3, None]), + a_right=pl.Series([None, data[1], None, data[3]]), + c=pl.Series([None, 2, None, 4]), + ), + ) + assert_frame_equal( + lhs.join(rhs, on="a", how="semi"), + pl.select( + a=pl.Series([data[1]]), + b=pl.Series([2]), + ), + ) + assert_frame_equal( + lhs.join(rhs, on="a", how="anti", maintain_order="left"), + pl.select( + a=pl.Series([data[0], data[2]]), + b=pl.Series([1, 3]), + ), + ) + assert_frame_equal( + lhs.join(rhs, how="cross", maintain_order="left_right"), + pl.select( + a=pl.Series([data[0], data[0], data[1], data[1], data[2], data[2]]), + b=pl.Series([1, 1, 2, 2, 3, 3]), + a_right=pl.Series([data[3], data[1], data[3], data[1], data[3], data[1]]), + c=pl.Series([4, 2, 4, 2, 4, 2]), + ), + ) + + +def test_empty_join_result_with_array_15474() -> None: + lhs = pl.DataFrame( + { + "x": [1, 2], + "y": pl.Series([[1, 2, 3], [4, 5, 6]], dtype=pl.Array(pl.Int64, 3)), + } + ) + rhs = pl.DataFrame({"x": [0]}) + result = lhs.join(rhs, on="x") + expected = pl.DataFrame(schema={"x": pl.Int64, "y": pl.Array(pl.Int64, 3)}) + assert_frame_equal(result, expected) + + +@pytest.mark.slow +def test_join_where_eager_perf_21145() -> None: + left = pl.Series("left", range(3_000)).to_frame() + right = pl.Series("right", range(1_000)).to_frame() + + def time_func(func: Callable[[], Any]) -> float: + times = [] + for _ in range(3): + t = perf_counter() + func() + times.append(perf_counter() - t) + + return min(times) + + p = pl.col("left").is_between(pl.lit(0, dtype=pl.Int64), pl.col("right")) + runtime_eager = time_func(lambda: left.join_where(right, p)) + runtime_lazy = time_func(lambda: left.lazy().join_where(right.lazy(), p).collect()) + runtime_ratio = runtime_eager / runtime_lazy + + # Pick as high as reasonably possible for CI stability + # * Was observed to be >=5 seconds on the bugged version, so 3 is a safe bet. + threshold = 3 + + if runtime_ratio > threshold: + msg = f"runtime_ratio ({runtime_ratio}) > {threshold}x ({runtime_eager = }, {runtime_lazy = })" + raise ValueError(msg) + + +def test_select_len_after_semi_anti_join_21343() -> None: + lhs = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + rhs = pl.LazyFrame({"a": [1, 2, 3]}) + + q = lhs.join(rhs, on="a", how="anti").select(pl.len()) + + assert q.collect().item() == 0 + + +def test_multi_leftjoin_empty_right_21701() -> None: + parent_data = { + "id": [1, 30, 80], + "parent_field1": [3, 20, 17], + } + parent_df = pl.LazyFrame(parent_data) + child_df = pl.LazyFrame( + [], + schema={"id": pl.Int32(), "parent_id": pl.Int32(), "child_field1": pl.Int32()}, + ) + subchild_df = pl.LazyFrame( + [], schema={"child_id": pl.Int32(), "subchild_field1": pl.Int32()} + ) + + joined_df = parent_df.join( + child_df.join( + subchild_df, left_on=pl.col("id"), right_on=pl.col("child_id"), how="left" + ), + left_on=pl.col("id"), + right_on=pl.col("parent_id"), + how="left", + ) + joined_df = joined_df.select("id", "parent_field1") + assert_frame_equal(joined_df.collect(), parent_df.collect(), check_row_order=False) + + +@pytest.mark.parametrize("order", ["none", "left_right", "right_left"]) +def test_join_null_equal(order: Literal["none", "left_right", "right_left"]) -> None: + lhs = pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3]}) + with_null = pl.DataFrame({"x": [1, None], "z": [1, 2]}) + without_null = pl.DataFrame({"x": [1, 3], "z": [1, 3]}) + check_row_order = order != "none" + + # Inner join. + assert_frame_equal( + lhs.join(with_null, on="x", nulls_equal=True, maintain_order=order), + pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}), + check_row_order=check_row_order, + ) + assert_frame_equal( + lhs.join(without_null, on="x", nulls_equal=True), + pl.DataFrame({"x": [1], "y": [1], "z": [1]}), + ) + + # Left join. + assert_frame_equal( + lhs.join(with_null, on="x", how="left", nulls_equal=True, maintain_order=order), + pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}), + check_row_order=check_row_order, + ) + assert_frame_equal( + lhs.join( + without_null, on="x", how="left", nulls_equal=True, maintain_order=order + ), + pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, None, None]}), + check_row_order=check_row_order, + ) + + # Full join. + assert_frame_equal( + lhs.join( + with_null, + on="x", + how="full", + nulls_equal=True, + coalesce=True, + maintain_order=order, + ), + pl.DataFrame({"x": [1, None, None], "y": [1, 2, 3], "z": [1, 2, 2]}), + check_row_order=check_row_order, + ) + if order == "left_right": + expected = pl.DataFrame( + { + "x": [1, None, None, None], + "x_right": [1, None, None, 3], + "y": [1, 2, 3, None], + "z": [1, None, None, 3], + } + ) + else: + expected = pl.DataFrame( + { + "x": [1, None, None, None], + "x_right": [1, 3, None, None], + "y": [1, None, 2, 3], + "z": [1, 3, None, None], + } + ) + assert_frame_equal( + lhs.join( + without_null, on="x", how="full", nulls_equal=True, maintain_order=order + ), + expected, + check_row_order=check_row_order, + check_column_order=False, + ) + + +def test_join_categorical_21815() -> None: + with pl.StringCache(): + left = pl.DataFrame({"x": ["a", "b", "c", "d"]}).with_columns( + xc=pl.col.x.cast(pl.Categorical) + ) + right = pl.DataFrame({"x": ["c", "d", "e", "f"]}).with_columns( + xc=pl.col.x.cast(pl.Categorical) + ) + + # As key. + cat_key = left.join(right, on="xc", how="full") + + # As payload. + cat_payload = left.join(right, on="x", how="full") + + expected = pl.DataFrame( + { + "x": ["a", "b", "c", "d", None, None], + "x_right": [None, None, "c", "d", "e", "f"], + } + ).with_columns( + xc=pl.col.x.cast(pl.Categorical), + xc_right=pl.col.x_right.cast(pl.Categorical), + ) + + assert_frame_equal( + cat_key, expected, check_row_order=False, check_column_order=False + ) + assert_frame_equal( + cat_payload, expected, check_row_order=False, check_column_order=False + ) + + +def test_join_where_nested_boolean() -> None: + df1 = pl.DataFrame({"a": [1, 9, 22], "b": [6, 4, 50]}) + df2 = pl.DataFrame({"c": [1]}) + + predicate = (pl.col("a") < pl.col("b")).cast(pl.Int32) < pl.col("c") + result = df1.join_where(df2, predicate) + expected = pl.DataFrame( + { + "a": [9], + "b": [4], + "c": [1], + } + ) + assert_frame_equal(result, expected) + + +def test_join_where_dtype_upcast() -> None: + df1 = pl.DataFrame( + { + "a": pl.Series([1, 9, 22], dtype=pl.Int8), + "b": [6, 4, 50], + } + ) + df2 = pl.DataFrame({"c": [10]}) + + predicate = (pl.col("a") + (pl.col("b") > 0)) < pl.col("c") + result = df1.join_where(df2, predicate) + expected = pl.DataFrame( + { + "a": pl.Series([1], dtype=pl.Int8), + "b": [6], + "c": [10], + } + ) + assert_frame_equal(result, expected) + + +def test_join_where_valid_dtype_upcast_same_side() -> None: + # Unsafe comparisons are all contained entirely within one table (LHS) + # Safe comparisons across both tables. + df1 = pl.DataFrame( + { + "a": pl.Series([1, 9, 22], dtype=pl.Float32), + "b": [6, 4, 50], + } + ) + df2 = pl.DataFrame({"c": [10, 1, 5]}) + + predicate = ((pl.col("a") < pl.col("b")).cast(pl.Int32) + 3) < pl.col("c") + result = df1.join_where(df2, predicate).sort("a", "b", "c") + expected = pl.DataFrame( + { + "a": pl.Series([1, 1, 9, 9, 22, 22], dtype=pl.Float32), + "b": [6, 6, 4, 4, 50, 50], + "c": [5, 10, 5, 10, 5, 10], + } + ) + assert_frame_equal(result, expected) + + +def test_join_where_invalid_dtype_upcast_different_side() -> None: + # Unsafe comparisons exist across tables. + df1 = pl.DataFrame( + { + "a": pl.Series([1, 9, 22], dtype=pl.Float32), + "b": pl.Series([6, 4, 50], dtype=pl.Float64), + } + ) + df2 = pl.DataFrame({"c": [10, 1, 5]}) + + predicate = ((pl.col("a") >= pl.col("c")) + 3) < 4 + with pytest.raises( + SchemaError, match="'join_where' cannot compare Float32 with Int64" + ): + df1.join_where(df2, predicate) + + # add in a cast to predicate to fix + predicate = ((pl.col("a").cast(pl.UInt8) >= pl.col("c")) + 3) < 4 + result = df1.join_where(df2, predicate).sort("a", "b", "c") + expected = pl.DataFrame( + { + "a": pl.Series([1, 1, 9], dtype=pl.Float32), + "b": pl.Series([6, 6, 4], dtype=pl.Float64), + "c": [5, 10, 10], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.Int32, pl.Float32]) +def test_join_where_literals(dtype: PolarsDataType) -> None: + df1 = pl.DataFrame({"a": pl.Series([0, 1], dtype=dtype)}) + df2 = pl.DataFrame({"b": pl.Series([1, 2], dtype=dtype)}) + result = df1.join_where(df2, (pl.col("a") + pl.col("b")) < 2) + expected = pl.DataFrame( + { + "a": pl.Series([0], dtype=dtype), + "b": pl.Series([1], dtype=dtype), + } + ) + assert_frame_equal(result, expected) + + +def test_join_where_categorical_string_compare() -> None: + dt = pl.Enum(["a", "b", "c"]) + df1 = pl.DataFrame({"a": pl.Series(["a", "a", "b", "c"], dtype=dt)}) + df2 = pl.DataFrame({"b": [1, 6, 4]}) + predicate = pl.col("a").is_in(["a", "b"]) & (pl.col("b") < 5) + result = df1.join_where(df2, predicate).sort("a", "b") + expected = pl.DataFrame( + { + "a": pl.Series(["a", "a", "a", "a", "b", "b"], dtype=dt), + "b": [1, 1, 4, 4, 1, 4], + } + ) + assert_frame_equal(result, expected) + + +def test_join_where_nonboolean_predicate() -> None: + df1 = pl.DataFrame({"a": [1, 2, 3]}) + df2 = pl.DataFrame({"b": [1, 2, 3]}) + with pytest.raises( + ComputeError, match="'join_where' predicates must resolve to boolean" + ): + df1.join_where(df2, pl.col("a") * 2) + + +def test_empty_outer_join_22206() -> None: + df = pl.LazyFrame({"a": [5, 6], "b": [1, 2]}) + empty = pl.LazyFrame(schema=df.collect_schema()) + assert_frame_equal( + df.join(empty, on=["a", "b"], how="full", coalesce=True), + df, + check_row_order=False, + ) + assert_frame_equal( + empty.join(df, on=["a", "b"], how="full", coalesce=True), + df, + check_row_order=False, + ) diff --git a/py-polars/tests/unit/operations/test_join_asof.py b/py-polars/tests/unit/operations/test_join_asof.py new file mode 100644 index 000000000000..54f096dbb6e4 --- /dev/null +++ b/py-polars/tests/unit/operations/test_join_asof.py @@ -0,0 +1,1329 @@ +import warnings +from datetime import date, datetime, timedelta +from typing import Any + +import numpy as np +import pytest + +import polars as pl +from polars._typing import AsofJoinStrategy, PolarsIntegerType +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal + + +def test_asof_join_singular_right_11966() -> None: + df = pl.DataFrame({"id": [1, 2, 3], "time": [0.9, 2.1, 2.8]}).sort("time") + lookup = pl.DataFrame({"time": [2.0], "value": [100]}).sort("time") + joined = df.join_asof(lookup, on="time", strategy="nearest") + expected = pl.DataFrame( + {"id": [1, 2, 3], "time": [0.9, 2.1, 2.8], "value": [100, 100, 100]} + ) + assert_frame_equal(joined, expected) + + +def test_asof_join_inline_cast_6438() -> None: + df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 3, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + } + ) + + df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 2, 0), + datetime(2020, 1, 1, 9, 3, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "C", "A"], + "quote": [100, 300, 501, 102], + } + ).with_columns([pl.col("time").dt.cast_time_unit("ns")]) + + assert df_trades.join_asof( + df_quotes, on=pl.col("time").cast(pl.Datetime("ns")).set_sorted(), by="stock" + ).to_dict(as_series=False) == { + "time": [ + datetime(2020, 1, 1, 9, 1), + datetime(2020, 1, 1, 9, 1), + datetime(2020, 1, 1, 9, 3), + datetime(2020, 1, 1, 9, 6), + ], + "time_right": [ + datetime(2020, 1, 1, 9, 0), + None, + datetime(2020, 1, 1, 9, 2), + datetime(2020, 1, 1, 9, 3), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + "quote": [100, None, 300, 501], + } + + +def test_asof_join_projection_resolution_4606() -> None: + a = pl.DataFrame({"a": [1], "b": [2], "c": [3]}).lazy() + b = pl.DataFrame({"a": [1], "b": [2], "d": [4]}).lazy() + joined_tbl = a.join_asof(b, on=pl.col("a").set_sorted(), by="b") + assert joined_tbl.group_by("a").agg( + [pl.col("c").sum().alias("c")] + ).collect().columns == ["a", "c"] + + +def test_asof_join_schema_5211() -> None: + df1 = pl.DataFrame({"today": [1, 2]}) + + df2 = pl.DataFrame({"next_friday": [1, 2]}) + + assert ( + df1.lazy() + .join_asof( + df2.lazy(), left_on="today", right_on="next_friday", strategy="forward" + ) + .collect_schema() + ) == {"today": pl.Int64, "next_friday": pl.Int64} + + +def test_asof_join_schema_5684() -> None: + df_a = ( + pl.DataFrame( + { + "id": [1], + "a": [1], + "b": [1], + } + ) + .lazy() + .set_sorted("a") + ) + + df_b = ( + pl.DataFrame( + { + "id": [1, 1, 2], + "b": [-3, -3, 6], + } + ) + .lazy() + .set_sorted("b") + ) + + q = ( + df_a.join_asof(df_b, by="id", left_on="a", right_on="b") + .drop("b") + .join_asof(df_b, by="id", left_on="a", right_on="b") + .drop("b") + ) + + projected_result = q.select(pl.all()).collect() + result = q.collect() + + assert_frame_equal(projected_result, result) + assert ( + q.collect_schema() + == projected_result.schema + == {"id": pl.Int64, "a": pl.Int64, "b_right": pl.Int64} + ) + + +def test_join_asof_mismatched_dtypes() -> None: + # test 'on' dtype mismatch + df1 = pl.DataFrame( + {"a": pl.Series([1, 2, 3], dtype=pl.Int64), "b": ["a", "b", "c"]} + ) + df2 = pl.DataFrame( + {"a": pl.Series([1.0, 2.0, 3.0], dtype=pl.Float64), "c": ["d", "e", "f"]} + ) + + with pytest.raises( + pl.exceptions.SchemaError, match="datatypes of join keys don't match" + ): + df1.join_asof(df2, on="a", strategy="forward") + + # test 'by' dtype mismatch + df1 = pl.DataFrame( + { + "time": pl.date_range(date(2018, 1, 1), date(2018, 1, 8), eager=True), + "group": pl.Series([1, 1, 1, 1, 2, 2, 2, 2], dtype=pl.Int32), + "value": [0, 0, None, None, 2, None, 1, None], + } + ) + df2 = pl.DataFrame( + { + "time": pl.date_range(date(2018, 1, 1), date(2018, 1, 8), eager=True), + "group": pl.Series([1, 1, 1, 1, 2, 2, 2, 2], dtype=pl.Int64), + "value": [0, 0, None, None, 2, None, 1, None], + } + ) + + with pytest.raises( + pl.exceptions.ComputeError, match="mismatching dtypes in 'by' parameter" + ): + df1.join_asof(df2, on="time", by="group", strategy="forward") + + +def test_join_asof_floats() -> None: + df1 = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": ["lrow1", "lrow2", "lrow3"]}) + df2 = pl.DataFrame({"a": [0.59, 1.49, 2.89], "b": ["rrow1", "rrow2", "rrow3"]}) + + result = df1.join_asof(df2, on=pl.col("a").set_sorted(), strategy="backward") + expected = { + "a": [1.0, 2.0, 3.0], + "b": ["lrow1", "lrow2", "lrow3"], + "a_right": [0.59, 1.49, 2.89], + "b_right": ["rrow1", "rrow2", "rrow3"], + } + assert result.to_dict(as_series=False) == expected + + # with by argument + # 5740 + df1 = pl.DataFrame( + {"b": np.linspace(0, 5, 7), "c": ["x" if i < 4 else "y" for i in range(7)]} + ) + df2 = pl.DataFrame( + { + "val": [0.0, 2.5, 2.6, 2.7, 3.4, 4.0, 5.0], + "c": ["x", "x", "x", "y", "y", "y", "y"], + } + ).with_columns(pl.col("val").alias("b").set_sorted()) + assert df1.set_sorted("b").join_asof(df2, on=pl.col("b"), by="c").to_dict( + as_series=False + ) == { + "b": [ + 0.0, + 0.8333333333333334, + 1.6666666666666667, + 2.5, + 3.3333333333333335, + 4.166666666666667, + 5.0, + ], + "c": ["x", "x", "x", "x", "y", "y", "y"], + "val": [0.0, 0.0, 0.0, 2.5, 2.7, 4.0, 5.0], + } + + +def test_join_asof_tolerance() -> None: + df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + } + ).set_sorted("time") + + df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 4), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "C", "A"], + "quote": [100, 300, 501, 102], + } + ).set_sorted("time") + + assert df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="2s" + ).to_dict(as_series=False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + "quote": [100, None, 300, 501], + } + + assert df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="1s" + ).to_dict(as_series=False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + "quote": [100, None, 300, None], + } + + for invalid_tolerance, match in [ + ("foo", "expected leading integer"), + ([None], "could not extract number"), + ]: + with pytest.raises(pl.exceptions.PolarsError, match=match): + df_trades.join_asof( + df_quotes, + on="time", + by="stock", + tolerance=invalid_tolerance, # type: ignore[arg-type] + ) + + +def test_join_asof_tolerance_forward() -> None: + df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 4), + datetime(2020, 1, 1, 9, 0, 6), + datetime(2020, 1, 1, 9, 0, 7), + ], + "stock": ["A", "B", "C", "A", "D"], + "quote": [100, 300, 501, 102, 10], + } + ).set_sorted("time") + + df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + datetime(2020, 1, 1, 9, 0, 7), + ], + "stock": ["A", "B", "B", "C", "D"], + "trade": [101, 299, 301, 500, 10], + } + ).set_sorted("time") + + assert df_quotes.join_asof( + df_trades, on="time", by="stock", tolerance="2s", strategy="forward" + ).to_dict(as_series=False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 4), + datetime(2020, 1, 1, 9, 0, 6), + datetime(2020, 1, 1, 9, 0, 7), + ], + "stock": ["A", "B", "C", "A", "D"], + "quote": [100, 300, 501, 102, 10], + "trade": [101, 301, 500, None, 10], + } + + assert df_quotes.join_asof( + df_trades, on="time", by="stock", tolerance="1s", strategy="forward" + ).to_dict(as_series=False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 4), + datetime(2020, 1, 1, 9, 0, 6), + datetime(2020, 1, 1, 9, 0, 7), + ], + "stock": ["A", "B", "C", "A", "D"], + "quote": [100, 300, 501, 102, 10], + "trade": [None, 301, None, None, 10], + } + + # Sanity check that this gives us equi-join + assert df_quotes.join_asof( + df_trades, on="time", by="stock", tolerance="0s", strategy="forward" + ).to_dict(as_series=False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 4), + datetime(2020, 1, 1, 9, 0, 6), + datetime(2020, 1, 1, 9, 0, 7), + ], + "stock": ["A", "B", "C", "A", "D"], + "quote": [100, 300, 501, 102, 10], + "trade": [None, None, None, None, 10], + } + + +def test_join_asof_projection() -> None: + df1 = pl.DataFrame( + { + "df1_date": [20221011, 20221012, 20221013, 20221014, 20221016], + "df1_col1": ["foo", "bar", "foo", "bar", "foo"], + "key": ["a", "b", "b", "a", "b"], + } + ).set_sorted("df1_date") + + df2 = pl.DataFrame( + { + "df2_date": [20221012, 20221015, 20221018], + "df2_col1": ["1", "2", "3"], + "key": ["a", "b", "b"], + } + ).set_sorted("df2_date") + + assert ( + ( + df1.lazy().join_asof(df2.lazy(), left_on="df1_date", right_on="df2_date") + ).select([pl.col("df2_date"), "df1_date"]) + ).collect().to_dict(as_series=False) == { + "df2_date": [None, 20221012, 20221012, 20221012, 20221015], + "df1_date": [20221011, 20221012, 20221013, 20221014, 20221016], + } + assert ( + df1.lazy().join_asof( + df2.lazy(), by="key", left_on="df1_date", right_on="df2_date" + ) + ).select(["df2_date", "df1_date"]).collect().to_dict(as_series=False) == { + "df2_date": [None, None, None, 20221012, 20221015], + "df1_date": [20221011, 20221012, 20221013, 20221014, 20221016], + } + + +def test_asof_join_by_logical_types() -> None: + dates = ( + pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 1, 2), interval="2h", eager=True + ) + .cast(pl.Datetime("ns")) + .head(9) + ) + x = pl.DataFrame({"a": dates, "b": map(float, range(9)), "c": ["1", "2", "3"] * 3}) + + result = x.join_asof(x, on=pl.col("b").set_sorted(), by=["c", "a"]) + + expected = { + "a": [ + datetime(2022, 1, 1, 0, 0), + datetime(2022, 1, 1, 2, 0), + datetime(2022, 1, 1, 4, 0), + datetime(2022, 1, 1, 6, 0), + datetime(2022, 1, 1, 8, 0), + datetime(2022, 1, 1, 10, 0), + datetime(2022, 1, 1, 12, 0), + datetime(2022, 1, 1, 14, 0), + datetime(2022, 1, 1, 16, 0), + ], + "b": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + "c": ["1", "2", "3", "1", "2", "3", "1", "2", "3"], + "b_right": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + } + assert result.to_dict(as_series=False) == expected + + +def test_join_asof_projection_7481() -> None: + ldf1 = pl.DataFrame({"a": [1, 2, 2], "b": "bleft"}).lazy().set_sorted("a") + ldf2 = pl.DataFrame({"a": 2, "b": [1, 2, 2]}).lazy().set_sorted("b") + + assert ( + ldf1.join_asof(ldf2, left_on="a", right_on="b").select("a", "b") + ).collect().to_dict(as_series=False) == { + "a": [1, 2, 2], + "b": ["bleft", "bleft", "bleft"], + } + + +def test_asof_join_sorted_by_group(capsys: Any) -> None: + df1 = pl.DataFrame( + { + "key": ["a", "a", "a", "b", "b", "b"], + "asof_key": [2.0, 1.0, 3.0, 1.0, 2.0, 3.0], + "a": [102, 101, 103, 104, 105, 106], + } + ).sort(by=["key", "asof_key"]) + + df2 = pl.DataFrame( + { + "key": ["a", "a", "a", "b", "b", "b"], + "asof_key": [0.9, 1.9, 2.9, 0.9, 1.9, 2.9], + "b": [201, 202, 203, 204, 205, 206], + } + ).sort(by=["key", "asof_key"]) + + expected = pl.DataFrame( + [ + pl.Series("key", ["a", "a", "a", "b", "b", "b"], dtype=pl.String), + pl.Series("asof_key", [1.0, 2.0, 3.0, 1.0, 2.0, 3.0], dtype=pl.Float64), + pl.Series("a", [101, 102, 103, 104, 105, 106], dtype=pl.Int64), + pl.Series("b", [201, 202, 203, 204, 205, 206], dtype=pl.Int64), + ] + ) + + out = df1.join_asof(df2, on="asof_key", by="key") + assert_frame_equal(out, expected) + + _, err = capsys.readouterr() + assert "is not explicitly sorted" not in err + + +def test_asof_join_nearest() -> None: + # Generic join_asof + df1 = pl.DataFrame( + { + "asof_key": [-1, 1, 2, 4, 6], + "a": [1, 2, 3, 4, 5], + } + ).sort(by="asof_key") + + df2 = pl.DataFrame( + { + "asof_key": [-1, 2, 4, 5], + "b": [1, 2, 3, 4], + } + ).sort(by="asof_key") + + expected = pl.DataFrame( + {"asof_key": [-1, 1, 2, 4, 6], "a": [1, 2, 3, 4, 5], "b": [1, 2, 2, 3, 4]} + ) + + out = df1.join_asof(df2, on="asof_key", strategy="nearest") + assert_frame_equal(out, expected) + + # Edge case: last item of right matches multiples on left + df1 = pl.DataFrame( + { + "asof_key": [9, 9, 10, 10, 10], + "a": [1, 2, 3, 4, 5], + } + ).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [1, 2, 3, 10], + "b": [1, 2, 3, 4], + } + ).set_sorted("asof_key") + expected = pl.DataFrame( + { + "asof_key": [9, 9, 10, 10, 10], + "a": [1, 2, 3, 4, 5], + "b": [4, 4, 4, 4, 4], + } + ) + + out = df1.join_asof(df2, on="asof_key", strategy="nearest") + assert_frame_equal(out, expected) + + +def test_asof_join_nearest_with_tolerance() -> None: + a = b = [1, 2, 3, 4, 5] + + nones = pl.Series([None, None, None, None, None], dtype=pl.Int64) + + # Case 1: complete miss + df1 = pl.DataFrame({"asof_key": [1, 2, 3, 4, 5], "a": a}).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [7, 8, 9, 10, 11], + "b": b, + } + ).set_sorted("asof_key") + expected = df1.with_columns(nones.alias("b")) + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance=1) + assert_frame_equal(out, expected) + + # Case 2: complete miss in other direction + df1 = pl.DataFrame({"asof_key": [7, 8, 9, 10, 11], "a": a}).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [1, 2, 3, 4, 5], + "b": b, + } + ).set_sorted("asof_key") + expected = df1.with_columns(nones.alias("b")) + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance=1) + assert_frame_equal(out, expected) + + # Case 3: match first item + df1 = pl.DataFrame({"asof_key": [1, 2, 3, 4, 5], "a": a}).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [6, 7, 8, 9, 10], + "b": b, + } + ).set_sorted("asof_key") + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance=1) + expected = df1.with_columns(pl.Series([None, None, None, None, 1]).alias("b")) + assert_frame_equal(out, expected) + + # Case 4: match last item + df1 = pl.DataFrame({"asof_key": [1, 2, 3, 4, 5], "a": a}).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [-4, -3, -2, -1, 0], + "b": b, + } + ).set_sorted("asof_key") + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance=1) + expected = df1.with_columns(pl.Series([5, None, None, None, None]).alias("b")) + assert_frame_equal(out, expected) + + # Case 5: match multiples, pick closer + df1 = pl.DataFrame( + {"asof_key": pl.Series([1, 2, 3, 4, 5], dtype=pl.Float64), "a": a} + ).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [0.0, 2.0, 2.4, 3.4, 10.0], + "b": b, + } + ).set_sorted("asof_key") + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance=1) + expected = df1.with_columns(pl.Series([2, 2, 4, 4, None]).alias("b")) + assert_frame_equal(out, expected) + + # Case 6: use 0 tolerance + df1 = pl.DataFrame( + {"asof_key": pl.Series([1, 2, 3, 4, 5], dtype=pl.Float64), "a": a} + ).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": [0.0, 2.0, 2.4, 3.4, 10.0], + "b": b, + } + ).set_sorted("asof_key") + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance=0) + expected = df1.with_columns(pl.Series([None, 2, None, None, None]).alias("b")) + assert_frame_equal(out, expected) + + # Case 7: test with datetime + df1 = pl.DataFrame( + { + "asof_key": pl.Series( + [ + datetime(2023, 1, 1), + datetime(2023, 1, 2), + datetime(2023, 1, 3), + datetime(2023, 1, 4), + datetime(2023, 1, 6), + ] + ), + "a": a, + } + ).set_sorted("asof_key") + df2 = pl.DataFrame( + { + "asof_key": pl.Series( + [ + datetime(2022, 1, 1), + datetime(2022, 1, 2), + datetime(2022, 1, 3), + datetime( + 2023, 1, 2, 21, 30, 0 + ), # should match with 2023-01-02, 2023-01-03, and 2021-01-04 + datetime(2023, 1, 7), + ] + ), + "b": b, + } + ).set_sorted("asof_key") + out = df1.join_asof(df2, on="asof_key", strategy="nearest", tolerance="1d4h") + expected = df1.with_columns(pl.Series([None, 4, 4, 4, 5]).alias("b")) + assert_frame_equal(out, expected) + + # Case 8: test using timedelta tolerance + out = df1.join_asof( + df2, on="asof_key", strategy="nearest", tolerance=timedelta(days=1, hours=4) + ) + assert_frame_equal(out, expected) + + # Case #9: last item is closest match + df1 = pl.DataFrame( + { + "asof_key_left": [10.00001, 20.0, 30.0], + } + ).set_sorted("asof_key_left") + df2 = pl.DataFrame( + { + "asof_key_right": [10.00001, 20.0001, 29.0], + } + ).set_sorted("asof_key_right") + out = df1.join_asof( + df2, + left_on="asof_key_left", + right_on="asof_key_right", + strategy="nearest", + tolerance=0.5, + ) + expected = pl.DataFrame( + { + "asof_key_left": [10.00001, 20.0, 30.0], + "asof_key_right": [10.00001, 20.0001, None], + } + ) + assert_frame_equal(out, expected) + + +def test_asof_join_nearest_by() -> None: + # Generic join_asof + df1 = pl.DataFrame( + { + "asof_key": [-1, 1, 2, 6, 1], + "group": [1, 1, 1, 2, 2], + "a": [1, 2, 3, 2, 5], + } + ).sort(by=["group", "asof_key"]) + + df2 = pl.DataFrame( + { + "asof_key": [-1, 2, 5, 1], + "group": [1, 1, 2, 2], + "b": [1, 2, 3, 4], + } + ).sort(by=["group", "asof_key"]) + + expected = pl.DataFrame( + { + "asof_key": [-1, 1, 2, 6, 1], + "group": [1, 1, 1, 2, 2], + "a": [1, 2, 3, 5, 2], + "b": [1, 2, 2, 4, 3], + } + ).sort(by=["group", "asof_key"]) + + # Edge case: last item of right matches multiples on left + df1 = pl.DataFrame( + { + "asof_key": [9, 9, 10, 10, 10], + "group": [1, 1, 1, 2, 2], + "a": [1, 2, 3, 2, 5], + } + ).sort(by=["group", "asof_key"]) + + df2 = pl.DataFrame( + { + "asof_key": [-1, 1, 1, 10], + "group": [1, 1, 2, 2], + "b": [1, 2, 3, 4], + } + ).sort(by=["group", "asof_key"]) + + expected = pl.DataFrame( + { + "asof_key": [9, 9, 10, 10, 10], + "group": [1, 1, 1, 2, 2], + "a": [1, 2, 3, 2, 5], + "b": [2, 2, 2, 4, 4], + } + ) + + out = df1.join_asof(df2, on="asof_key", by="group", strategy="nearest") + assert_frame_equal(out, expected) + + a = pl.DataFrame( + { + "code": [676, 35, 676, 676, 676], + "time": [364360, 364370, 364380, 365400, 367440], + } + ) + b = pl.DataFrame( + { + "code": [676, 676, 35, 676, 676], + "time": [364000, 365000, 365000, 366000, 367000], + "price": [1.0, 2.0, 50, 3.0, None], + } + ) + + expected = pl.DataFrame( + { + "code": [676, 35, 676, 676, 676], + "time": [364360, 364370, 364380, 365400, 367440], + "price": [1.0, 50.0, 1.0, 2.0, None], + } + ) + + out = a.join_asof(b, by="code", on="time", strategy="nearest") + assert_frame_equal(out, expected) + + # last item is closest match + df1 = pl.DataFrame( + { + "a": [1, 1, 1], + "asof_key_left": [10.00001, 20.0, 30.0], + } + ).set_sorted("asof_key_left") + df2 = pl.DataFrame( + { + "a": [1, 1, 1], + "asof_key_right": [10.00001, 20.0001, 29.0], + } + ).set_sorted("asof_key_right") + out = df1.join_asof( + df2, + left_on="asof_key_left", + right_on="asof_key_right", + by="a", + strategy="nearest", + ) + expected = pl.DataFrame( + { + "a": [1, 1, 1], + "asof_key_left": [10.00001, 20.0, 30.0], + "asof_key_right": [10.00001, 20.0001, 29.0], + } + ) + assert_frame_equal(out, expected) + + +def test_asof_join_nearest_by_with_tolerance() -> None: + df1 = pl.DataFrame( + { + "group": [ + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + ], + "asof_key": pl.Series( + [ + 1, + 2, + 3, + 4, + 5, + 7, + 8, + 9, + 10, + 11, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + ], + dtype=pl.Float32, + ), + "a": [ + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + ], + } + ) + + df2 = pl.DataFrame( + { + "group": [ + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + ], + "asof_key": pl.Series( + [ + 7, + 8, + 9, + 10, + 11, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 5, + -3, + -2, + -1, + 0, + 0, + 2, + 2.4, + 3.4, + 10, + -3, + 3, + 8, + 9, + 10, + ], + dtype=pl.Float32, + ), + "b": [ + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + 1, + 2, + 3, + 4, + 5, + ], + } + ) + + expected = df1.with_columns( + pl.Series( + [ + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + 1, + 5, + None, + None, + 1, + 1, + 2, + 2, + 4, + 4, + None, + None, + 2, + 2, + 2, + None, + ] + ).alias("b") + ) + df1 = df1.sort(by=["group", "asof_key"]) + df2 = df2.sort(by=["group", "asof_key"]) + expected = expected.sort(by=["group", "a"]) + + out = df1.join_asof( + df2, by="group", on="asof_key", strategy="nearest", tolerance=1.0 + ).sort(by=["group", "a"]) + assert_frame_equal(out, expected) + + # last item is closest match + df1 = pl.DataFrame( + { + "a": [1, 1, 1], + "asof_key_left": [10.00001, 20.0, 30.0], + } + ).set_sorted("asof_key_left") + df2 = pl.DataFrame( + { + "a": [1, 1, 1], + "asof_key_right": [10.00001, 20.0001, 29.0], + } + ).set_sorted("asof_key_right") + out = df1.join_asof( + df2, + left_on="asof_key_left", + right_on="asof_key_right", + by="a", + strategy="nearest", + tolerance=0.5, + ) + expected = pl.DataFrame( + { + "a": [1, 1, 1], + "asof_key_left": [10.00001, 20.0, 30.0], + "asof_key_right": [10.00001, 20.0001, None], + } + ) + assert_frame_equal(out, expected) + + +def test_asof_join_nearest_by_date() -> None: + df1 = pl.DataFrame( + { + "asof_key": [ + date(2019, 12, 30), + date(2020, 1, 1), + date(2020, 1, 2), + date(2020, 1, 6), + date(2020, 1, 1), + ], + "group": [1, 1, 1, 2, 2], + "a": [1, 2, 3, 2, 5], + } + ).sort(by=["group", "asof_key"]) + + df2 = pl.DataFrame( + { + "asof_key": [ + date(2020, 1, 1), + date(2020, 1, 2), + date(2020, 1, 5), + date(2020, 1, 1), + ], + "group": [1, 1, 2, 2], + "b": [1, 2, 3, 4], + } + ).sort(by=["group", "asof_key"]) + + expected = pl.DataFrame( + { + "asof_key": [ + date(2019, 12, 30), + date(2020, 1, 1), + date(2020, 1, 2), + date(2020, 1, 6), + date(2020, 1, 1), + ], + "group": [1, 1, 1, 2, 2], + "a": [1, 2, 3, 2, 5], + "b": [1, 1, 2, 3, 4], + } + ).sort(by=["group", "asof_key"]) + + out = df1.join_asof(df2, on="asof_key", by="group", strategy="nearest") + assert_frame_equal(out, expected) + + +@pytest.mark.may_fail_auto_streaming # See #18927. +def test_asof_join_string() -> None: + left = pl.DataFrame({"x": [None, "a", "b", "c", None, "d", None]}).set_sorted("x") + right = pl.DataFrame({"x": ["apple", None, "chutney"], "y": [0, 1, 2]}).set_sorted( + "x" + ) + forward = left.join_asof(right, on="x", strategy="forward") + backward = left.join_asof(right, on="x", strategy="backward") + forward_expected = pl.DataFrame( + { + "x": [None, "a", "b", "c", None, "d", None], + "y": [None, 0, 2, 2, None, None, None], + } + ) + backward_expected = pl.DataFrame( + { + "x": [None, "a", "b", "c", None, "d", None], + "y": [None, None, 0, 0, None, 2, None], + } + ) + assert_frame_equal(forward, forward_expected) + assert_frame_equal(backward, backward_expected) + + +def test_join_asof_by_argument_parsing() -> None: + df1 = pl.DataFrame( + { + "n": [10, 20, 30, 40, 50, 60], + "id1": [0, 0, 3, 3, 5, 5], + "id2": [1, 2, 1, 2, 1, 2], + "x": ["a", "b", "c", "d", "e", "f"], + } + ).sort(by="n") + + df2 = pl.DataFrame( + { + "n": [25, 8, 5, 23, 15, 35], + "id1": [0, 0, 3, 3, 5, 5], + "id2": [1, 2, 1, 2, 1, 2], + "y": ["A", "B", "C", "D", "E", "F"], + } + ).sort(by="n") + + # any sequency for by argument is allowed, so we should see the same results here + by_list = df1.join_asof(df2, on="n", by=["id1", "id2"]) + by_tuple = df1.join_asof(df2, on="n", by=("id1", "id2")) + assert_frame_equal(by_list, by_tuple) + + # same for using the by_left and by_right kwargs + by_list2 = df1.join_asof( + df2, on="n", by_left=["id1", "id2"], by_right=["id1", "id2"] + ) + by_tuple2 = df1.join_asof( + df2, on="n", by_left=("id1", "id2"), by_right=("id1", "id2") + ) + assert_frame_equal(by_list2, by_list) + assert_frame_equal(by_tuple2, by_list) + + +def test_join_asof_invalid_args() -> None: + df1 = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [1, 2, 3], + } + ).set_sorted("a") + df2 = pl.DataFrame( + { + "a": [1, 2, 3], + "c": [1, 2, 3], + } + ).set_sorted("a") + + with pytest.raises(TypeError, match="expected `on` to be str or Expr, got 'list'"): + df1.join_asof(df2, on=["a"]) # type: ignore[arg-type] + with pytest.raises( + TypeError, match="expected `left_on` to be str or Expr, got 'list'" + ): + df1.join_asof(df2, left_on=["a"], right_on="a") # type: ignore[arg-type] + with pytest.raises( + TypeError, match="expected `right_on` to be str or Expr, got 'list'" + ): + df1.join_asof(df2, left_on="a", right_on=["a"]) # type: ignore[arg-type] + + +def test_join_as_of_by_schema() -> None: + a = pl.DataFrame({"a": [1], "b": [2], "c": [3]}).lazy() + b = pl.DataFrame({"a": [1], "b": [2], "d": [4]}).lazy() + q = a.join_asof(b, on=pl.col("a").set_sorted(), by="b") + assert q.collect_schema().names() == q.collect().columns + + +def test_asof_join_by_schema() -> None: + # different `by` names. + df1 = pl.DataFrame({"on1": 0, "by1": 0}) + df2 = pl.DataFrame({"on1": 0, "by2": 0}) + + q = df1.lazy().join_asof( + df2.lazy(), + on="on1", + by_left="by1", + by_right="by2", + ) + + assert q.collect_schema() == q.collect().schema + + +def test_raise_invalid_by_arg_13020() -> None: + df1 = pl.DataFrame({"asOfDate": [date(2020, 1, 1)]}) + df2 = pl.DataFrame( + { + "endityId": [date(2020, 1, 1)], + "eventDate": ["A"], + } + ) + with pytest.raises(pl.exceptions.InvalidOperationError, match="expected both"): + df1.sort("asOfDate").join_asof( + df2.sort("eventDate"), + left_on="asOfDate", + right_on="eventDate", + by_left=None, + by_right=["entityId"], + ) + + +def test_join_asof_no_exact_matches() -> None: + trades = pl.DataFrame( + { + "time": [ + "2016-05-25 13:30:00.023", + "2016-05-25 13:30:00.038", + "2016-05-25 13:30:00.048", + "2016-05-25 13:30:00.048", + "2016-05-25 13:30:00.048", + ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "price": [51.95, 51.95, 720.77, 720.92, 98.0], + "quantity": [75, 155, 100, 100, 100], + } + ).with_columns(pl.col("time").str.to_datetime()) + + quotes = pl.DataFrame( + { + "time": [ + "2016-05-25 13:30:00.023", + "2016-05-25 13:30:00.023", + "2016-05-25 13:30:00.030", + "2016-05-25 13:30:00.041", + "2016-05-25 13:30:00.048", + "2016-05-25 13:30:00.049", + "2016-05-25 13:30:00.072", + "2016-05-25 13:30:00.075", + ], + "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], + "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], + } + ).with_columns(pl.col("time").str.to_datetime()) + + assert trades.join_asof( + quotes, on="time", by="ticker", tolerance="10ms", allow_exact_matches=False + ).to_dict(as_series=False) == { + "time": [ + datetime(2016, 5, 25, 13, 30, 0, 23000), + datetime(2016, 5, 25, 13, 30, 0, 38000), + datetime(2016, 5, 25, 13, 30, 0, 48000), + datetime(2016, 5, 25, 13, 30, 0, 48000), + datetime(2016, 5, 25, 13, 30, 0, 48000), + ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + "price": [51.95, 51.95, 720.77, 720.92, 98.0], + "quantity": [75, 155, 100, 100, 100], + "bid": [None, 51.97, None, None, None], + "ask": [None, 51.98, None, None, None], + } + + +def test_join_asof_not_sorted() -> None: + df = pl.DataFrame({"a": [1, 1, 1, 2, 2, 2], "b": [2, 1, 3, 1, 2, 3]}) + with pytest.raises(InvalidOperationError, match="is not sorted"): + df.join_asof(df, on="b") + + # When 'by' is provided, we do not check sortedness, but a warning is received + with pytest.warns( + UserWarning, + match="Sortedness of columns cannot be checked when 'by' groups provided", + ): + df.join_asof(df, on="b", by="a") + + # When sortedness is False, we should get no warning + with warnings.catch_warnings(record=True) as w: + df.join_asof(df, on="b", check_sortedness=False) + df.join_asof(df, on="b", by="a", check_sortedness=False) + assert len(w) == 0 # no warnings caught + + +@pytest.mark.parametrize("left_dtype", [pl.Int64, pl.UInt64]) +@pytest.mark.parametrize("right_dtype", [pl.Int64, pl.UInt64]) +@pytest.mark.parametrize("strategy", ["backward", "forward", "nearest"]) +def test_join_asof_large_int_21276( + left_dtype: PolarsIntegerType, + right_dtype: PolarsIntegerType, + strategy: AsofJoinStrategy, +) -> None: + large_int64 = 1608129000134000123 # it only happen when "on" column is large + left = pl.DataFrame({"ts": pl.Series([large_int64 + 2], dtype=left_dtype)}) + right = pl.DataFrame( + { + "ts": pl.Series([large_int64 + 1, large_int64 + 3], dtype=right_dtype), + "value": [111, 333], + } + ) + result = left.join_asof(right, on="ts", strategy=strategy) + idx = 0 if strategy == "backward" else 1 + expected = pl.DataFrame( + { + "ts": left["ts"], + "value": right["value"].gather(idx), + } + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_join_right.py b/py-polars/tests/unit/operations/test_join_right.py new file mode 100644 index 000000000000..97a01ea2df49 --- /dev/null +++ b/py-polars/tests/unit/operations/test_join_right.py @@ -0,0 +1,121 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_right_join_schemas() -> None: + a = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + + b = pl.DataFrame({"a": [1, 3], "b": [1, 3], "c": [1, 3]}) + + # coalesces the join key, so the key of the right table remains + assert a.join( + b, on="a", how="right", coalesce=True, maintain_order="right" + ).to_dict(as_series=False) == { + "b": [1, 3], + "a": [1, 3], + "b_right": [1, 3], + "c": [1, 3], + } + # doesn't coalesce the join key, so all columns remain + assert a.join(b, on="a", how="right", coalesce=False).columns == [ + "a", + "b", + "a_right", + "b_right", + "c", + ] + + # coalesces the join key, so the key of the right table remains + assert_frame_equal( + b.join(a, on="a", how="right", coalesce=True), + pl.DataFrame( + { + "b": [1, None, 3], + "c": [1, None, 3], + "a": [1, 2, 3], + "b_right": [1, 2, 3], + } + ), + check_row_order=False, + ) + assert b.join(a, on="a", how="right", coalesce=False).columns == [ + "a", + "b", + "c", + "a_right", + "b_right", + ] + + a_ = a.lazy() + b_ = b.lazy() + assert list( + a_.join(b_, on="a", how="right", coalesce=True).collect_schema().keys() + ) == ["b", "a", "b_right", "c"] + assert list( + a_.join(b_, on="a", how="right", coalesce=False).collect_schema().keys() + ) == ["a", "b", "a_right", "b_right", "c"] + assert list( + b_.join(a_, on="a", how="right", coalesce=True).collect_schema().keys() + ) == ["b", "c", "a", "b_right"] + assert list( + b_.join(a_, on="a", how="right", coalesce=False).collect_schema().keys() + ) == ["a", "b", "c", "a_right", "b_right"] + + +def test_right_join_schemas_multikey() -> None: + a = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) + + b = pl.DataFrame({"a": [1, 3], "b": [1, 3], "c": [1, 3]}) + assert a.join(b, on=["a", "b"], how="right", coalesce=False).columns == [ + "a", + "b", + "c", + "a_right", + "b_right", + "c_right", + ] + assert_frame_equal( + a.join(b, on=["a", "b"], how="right", coalesce=True), + pl.DataFrame({"c": [1, 3], "a": [1, 3], "b": [1, 3], "c_right": [1, 3]}), + check_row_order=False, + ) + assert_frame_equal( + b.join(a, on=["a", "b"], how="right", coalesce=True), + pl.DataFrame( + {"c": [1, None, 3], "a": [1, 2, 3], "b": [1, 2, 3], "c_right": [1, 2, 3]} + ), + check_row_order=False, + ) + + +def test_join_right_different_key() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham1": ["a", "b", "c"], + } + ) + other_df = pl.DataFrame( + { + "apple": ["x", "y", "z"], + "ham2": ["a", "b", "d"], + } + ) + assert df.join( + other_df, left_on="ham1", right_on="ham2", how="right", maintain_order="right" + ).to_dict(as_series=False) == { + "foo": [1, 2, None], + "bar": [6.0, 7.0, None], + "apple": ["x", "y", "z"], + "ham2": ["a", "b", "d"], + } + + +def test_join_right_different_multikey() -> None: + left = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}) + right = pl.LazyFrame({"c": [1, 2], "d": [1, 2]}) + result = left.join(right, left_on=["a", "b"], right_on=["c", "d"], how="right") + expected = pl.DataFrame({"c": [1, 2], "d": [1, 2]}) + assert_frame_equal(result.collect(), expected, check_row_order=False) + assert result.collect_schema() == expected.schema diff --git a/py-polars/tests/unit/operations/test_merge_sorted.py b/py-polars/tests/unit/operations/test_merge_sorted.py new file mode 100644 index 000000000000..2ef2441fefd4 --- /dev/null +++ b/py-polars/tests/unit/operations/test_merge_sorted.py @@ -0,0 +1,354 @@ +from datetime import time + +import pytest +from hypothesis import given + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series + +left = pl.DataFrame({"a": [42, 13, 37], "b": [3, 8, 9]}) +right = pl.DataFrame({"a": [5, 10, 1996], "b": [1, 5, 7]}) +expected = pl.DataFrame( + { + "a": [5, 42, 10, 1996, 13, 37], + "b": [1, 3, 5, 7, 8, 9], + } +) + +lf = left.lazy().merge_sorted(right.lazy(), "b") + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_merge_sorted(streaming: bool) -> None: + assert_frame_equal( + lf.collect(engine="streaming" if streaming else "in-memory"), + expected, + ) + + +def test_merge_sorted_pred_pd() -> None: + assert_frame_equal( + lf.filter(pl.col.b > 30).collect(), + expected.filter(pl.col.b > 30), + ) + assert_frame_equal( + lf.filter(pl.col.a < 6).collect(), + expected.filter(pl.col.a < 6), + ) + + +def test_merge_sorted_proj_pd() -> None: + assert_frame_equal( + lf.select("b").collect(), + lf.collect().select("b"), + ) + assert_frame_equal( + lf.select("a").collect(), + lf.collect().select("a"), + ) + + +@pytest.mark.parametrize("precision", [2, 3]) +def test_merge_sorted_decimal_20990(precision: int) -> None: + dtype = pl.Decimal(precision=precision, scale=1) + s = pl.Series("a", ["1.0", "0.1"], dtype) + df = pl.DataFrame([s.sort()]) + result = df.lazy().merge_sorted(df.lazy(), "a").collect().get_column("a") + expected = pl.Series("a", ["0.1", "0.1", "1.0", "1.0"], dtype) + assert_series_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_merge_sorted_categorical() -> None: + left = pl.Series("a", ["a", "b"], pl.Categorical()).sort().to_frame() + right = pl.Series("a", ["a", "b", "b"], pl.Categorical()).sort().to_frame() + result = left.merge_sorted(right, "a").get_column("a") + expected = pl.Series("a", ["a", "a", "b", "b", "b"], pl.Categorical()) + assert_series_equal(result, expected) + + right = pl.Series("a", ["b", "a"], pl.Categorical()).sort().to_frame() + with pytest.raises( + ComputeError, match="can only merge-sort categoricals with the same categories" + ): + left.merge_sorted(right, "a") + + +@pytest.mark.may_fail_auto_streaming +def test_merge_sorted_categorical_lexical() -> None: + left = pl.Series("a", ["b", "a"], pl.Categorical("lexical")).sort().to_frame() + right = pl.Series("a", ["b", "b", "a"], pl.Categorical("lexical")).sort().to_frame() + result = left.merge_sorted(right, "a").get_column("a") + expected = left.get_column("a").append(right.get_column("a")).sort() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("size", "ra"), + [ + (10, [1, 7, 9]), + (10, [0, 0, 0]), + (10, [10, 10, 10]), + (10, [1, None, None]), + (10_000, [1, 2471, 6432]), + (10_000, [777, 777, 777]), + (10_000, [510, 1509, 1996, 2000]), + (10_000, [None, None, None]), + (10_000, [1, None, None]), + (10_000, [None, None, 1]), + ], +) +def test_merge_sorted_unbalanced(size: int, ra: list[int]) -> None: + lhs = pl.DataFrame( + [ + pl.Series("a", range(size), pl.Int32), + pl.Series("b", range(size), pl.Int32), + ] + ) + rhs = pl.DataFrame( + [ + pl.Series("a", ra, pl.Int32), + pl.Series("b", [x * 7 for x in range(len(ra))], pl.Int32), + ] + ) + + lf = lhs.lazy().merge_sorted(rhs.lazy(), "a") + df = lf.collect(engine="streaming") + + nulls_last = ra[0] is not None + + assert df.height == size + len(ra) + assert df.get_column("a").is_sorted(nulls_last=nulls_last) + + reference = ( + lhs.get_column("a").append(rhs.get_column("a")).sort(nulls_last=nulls_last) + ) + assert_series_equal(df.get_column("a"), reference) + + +@given( + lhs=series( + name="a", allowed_dtypes=[pl.Int32], allow_null=False + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 + rhs=series( + name="a", allowed_dtypes=[pl.Int32], allow_null=False + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 +) +def test_merge_sorted_parametric_int(lhs: pl.Series, rhs: pl.Series) -> None: + l_df = pl.DataFrame([lhs.sort()]) + r_df = pl.DataFrame([rhs.sort()]) + + merge_sorted = l_df.lazy().merge_sorted(r_df.lazy(), "a").collect().get_column("a") + append_sorted = lhs.append(rhs).sort() + + assert_series_equal(merge_sorted, append_sorted) + + +@given( + lhs=series( + name="a", allowed_dtypes=[pl.Binary], allow_null=False + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 + rhs=series( + name="a", allowed_dtypes=[pl.Binary], allow_null=False + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 +) +def test_merge_sorted_parametric_binary(lhs: pl.Series, rhs: pl.Series) -> None: + l_df = pl.DataFrame([lhs.sort()]) + r_df = pl.DataFrame([rhs.sort()]) + + merge_sorted = l_df.lazy().merge_sorted(r_df.lazy(), "a").collect().get_column("a") + append_sorted = lhs.append(rhs).sort() + + assert_series_equal(merge_sorted, append_sorted) + + +@given( + lhs=series( + name="a", allowed_dtypes=[pl.String], allow_null=False + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 + rhs=series( + name="a", allowed_dtypes=[pl.String], allow_null=False + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 +) +def test_merge_sorted_parametric_string(lhs: pl.Series, rhs: pl.Series) -> None: + l_df = pl.DataFrame([lhs.sort()]) + r_df = pl.DataFrame([rhs.sort()]) + + merge_sorted = l_df.lazy().merge_sorted(r_df.lazy(), "a").collect().get_column("a") + append_sorted = lhs.append(rhs).sort() + + assert_series_equal(merge_sorted, append_sorted) + + +@given( + lhs=series( + name="a", + allowed_dtypes=[ + pl.Struct({"x": pl.Int32, "y": pl.Struct({"x": pl.Int8, "y": pl.Int8})}) + ], + allow_null=False, + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 + rhs=series( + name="a", + allowed_dtypes=[ + pl.Struct({"x": pl.Int32, "y": pl.Struct({"x": pl.Int8, "y": pl.Int8})}) + ], + allow_null=False, + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 +) +def test_merge_sorted_parametric_struct(lhs: pl.Series, rhs: pl.Series) -> None: + l_df = pl.DataFrame([lhs.sort()]) + r_df = pl.DataFrame([rhs.sort()]) + + merge_sorted = l_df.lazy().merge_sorted(r_df.lazy(), "a").collect().get_column("a") + append_sorted = lhs.append(rhs).sort() + + assert_series_equal(merge_sorted, append_sorted) + + +@given( + s=series( + name="a", + excluded_dtypes=[ + pl.Categorical( + ordering="lexical" + ), # Bug. See https://github.com/pola-rs/polars/issues/21025 + ], + allow_null=False, # See: https://github.com/pola-rs/polars/issues/20991 + ), +) +def test_merge_sorted_self_parametric(s: pl.Series) -> None: + df = pl.DataFrame([s.sort()]) + + merge_sorted = df.lazy().merge_sorted(df.lazy(), "a").collect().get_column("a") + append_sorted = s.append(s).sort() + + assert_series_equal(merge_sorted, append_sorted) + + +# This was an encountered bug in the streaming engine, it was actually a bug +# with split_at. +def test_merge_time() -> None: + s = pl.Series("a", [time(0, 0)], pl.Time) + df = pl.DataFrame([s]) + assert df.merge_sorted(df, "a").get_column("a").dtype == pl.Time() + + +@pytest.mark.may_fail_auto_streaming +def test_merge_sorted_invalid_categorical_local() -> None: + df1 = pl.DataFrame({"a": pl.Series(["a", "b", "c"], dtype=pl.Categorical)}) + df2 = pl.DataFrame({"a": pl.Series(["a", "b", "d"], dtype=pl.Categorical)}) + + with pytest.raises( + ComputeError, match="can only merge-sort categoricals with the same categories" + ): + df1.merge_sorted(df2, key="a") + + +@pytest.mark.may_fail_auto_streaming +def test_merge_sorted_categorical_global_physical() -> None: + with pl.StringCache(): + df1 = pl.DataFrame( + {"a": pl.Series(["e", "a", "f"], dtype=pl.Categorical("physical"))} + ) + df2 = pl.DataFrame( + {"a": pl.Series(["a", "c", "d"], dtype=pl.Categorical("physical"))} + ) + expected = pl.DataFrame( + { + "a": pl.Series( + (["e", "a", "a", "f", "c", "d"]), + dtype=pl.Categorical("physical"), + ) + } + ) + result = df1.merge_sorted(df2, key="a") + assert_frame_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_merge_sorted_categorical_global_lexical() -> None: + with pl.StringCache(): + df1 = pl.DataFrame( + {"a": pl.Series(["a", "e", "f"], dtype=pl.Categorical("lexical"))} + ) + df2 = pl.DataFrame( + {"a": pl.Series(["a", "c", "d"], dtype=pl.Categorical("lexical"))} + ) + expected = pl.DataFrame( + { + "a": pl.Series( + (["a", "a", "c", "d", "e", "f"]), + dtype=pl.Categorical("lexical"), + ) + } + ) + result = df1.merge_sorted(df2, key="a") + assert_frame_equal(result, expected) + + +def test_merge_sorted_categorical_21952() -> None: + with pl.StringCache(): + df1 = pl.DataFrame({"a": ["a", "b", "c"]}).cast(pl.Categorical("lexical")) + df2 = pl.DataFrame({"a": ["a", "b", "d"]}).cast(pl.Categorical("lexical")) + df = df1.merge_sorted(df2, key="a") + assert repr(df) == ( + "shape: (6, 1)\n" + "┌─────┐\n" + "│ a │\n" + "│ --- │\n" + "│ cat │\n" + "╞═════╡\n" + "│ a │\n" + "│ a │\n" + "│ b │\n" + "│ b │\n" + "│ c │\n" + "│ d │\n" + "└─────┘" + ) + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_merge_sorted_chain_streaming_21789_a(streaming: bool) -> None: + lf0 = pl.LazyFrame({"foo": ["a1", "a2"], "n": [10, 20]}) + lf1 = pl.LazyFrame({"foo": ["b1", "b2"], "n": [11, 21]}) + lf2 = pl.LazyFrame({"foo": ["c1", "c2"], "n": [12, 22]}) + + pq = lf0.merge_sorted(lf1, key="n").merge_sorted(lf2, key="n") + + expected = pl.DataFrame( + { + "foo": ["a1", "b1", "c1", "a2", "b2", "c2"], + "n": [10, 11, 12, 20, 21, 22], + } + ) + + out = pq.collect(engine="streaming" if streaming else "in-memory") + + assert_frame_equal(out, expected) + + +# The following expression triggers [Blocked, Ready] [Ready] in merge_sorted. +@pytest.mark.parametrize("streaming", [False, True]) +def test_merge_sorted_chain_streaming_21789_b(streaming: bool) -> None: + lf0 = pl.LazyFrame({"foo": ["a1", "a2"], "n": [10, 20]}) + lf1 = pl.LazyFrame({"foo": ["b1", "b2"], "n": [11, 21]}) + lf2 = pl.LazyFrame({"foo": ["c1", "c2"], "n": [12, 22]}) + lf3 = pl.LazyFrame({"foo": ["d1", "d2"], "n": [13, 23]}) + + lf01 = lf0.merge_sorted(lf1, key="n").top_k(3, by="n").sort(by="n") + lf23 = lf2.merge_sorted(lf3, key="n") + pq = lf01.merge_sorted(lf23, key="n").bottom_k(6, by="n").sort(by="n") + + expected = pl.DataFrame( + { + "foo": ["b1", "c1", "d1", "a2", "b2", "c2"], + "n": [11, 12, 13, 20, 21, 22], + } + ) + + out = pq.collect(engine="streaming" if streaming else "in-memory") + + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/operations/test_over.py b/py-polars/tests/unit/operations/test_over.py new file mode 100644 index 000000000000..a5d61aa02bc4 --- /dev/null +++ b/py-polars/tests/unit/operations/test_over.py @@ -0,0 +1,26 @@ +import polars as pl +from polars.testing import assert_series_equal + + +def test_implode_explode_over_22188() -> None: + df = pl.DataFrame( + { + "x": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "y": [2, 2, 2, 3, 3, 3, 4, 4, 4], + } + ) + result = df.select( + (pl.col.x * (pl.lit(pl.Series([1, 1, 1])).implode().explode())).over(pl.col.y), + ) + + assert_series_equal(result.to_series(), df.get_column("x")) + + +def test_implode_in_over_22188() -> None: + df = pl.DataFrame( + { + "x": [[1], [2], [3]], + "y": [2, 3, 4], + } + ).select(pl.col.x.list.set_union(pl.lit(pl.Series([1])).implode()).over(pl.col.y)) + assert_series_equal(df.to_series(), pl.Series("x", [[1], [2, 1], [3, 1]])) diff --git a/py-polars/tests/unit/operations/test_pivot.py b/py-polars/tests/unit/operations/test_pivot.py new file mode 100644 index 000000000000..208a58cbd085 --- /dev/null +++ b/py-polars/tests/unit/operations/test_pivot.py @@ -0,0 +1,575 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import ComputeError, DuplicateError +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import PivotAgg, PolarsIntegerType + + +def test_pivot() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "bar": ["k", "l", "m", "n", "o"], + "N": [1, 2, 2, 4, 2], + } + ) + result = df.pivot("bar", values="N", aggregate_function=None) + + expected = pl.DataFrame( + [ + ("A", 1, 2, None, None, None), + ("B", None, None, 2, 4, None), + ("C", None, None, None, None, 2), + ], + schema=["foo", "k", "l", "m", "n", "o"], + orient="row", + ) + assert_frame_equal(result, expected) + + +def test_pivot_no_values() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "bar": ["k", "l", "m", "n", "o"], + "N1": [1, 2, 2, 4, 2], + "N2": [1, 2, 2, 4, 2], + } + ) + result = df.pivot(on="bar", index="foo", aggregate_function=None) + expected = pl.DataFrame( + { + "foo": ["A", "B", "C"], + "N1_k": [1, None, None], + "N1_l": [2, None, None], + "N1_m": [None, 2, None], + "N1_n": [None, 4, None], + "N1_o": [None, None, 2], + "N2_k": [1, None, None], + "N2_l": [2, None, None], + "N2_m": [None, 2, None], + "N2_n": [None, 4, None], + "N2_o": [None, None, 2], + } + ) + + assert_frame_equal(result, expected) + + +def test_pivot_list() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]}) + + expected = pl.DataFrame( + { + "a": [1, 2, 3], + "1": [[1, 1], None, None], + "2": [None, [2, 2], None], + "3": [None, None, [3, 3]], + } + ) + out = df.pivot( + index="a", + on="a", + values="b", + aggregate_function="first", + sort_columns=True, + ) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize( + ("agg_fn", "expected_rows"), + [ + ("first", [("a", 2, None, None), ("b", None, None, 10)]), + ("len", [("a", 2, None, None), ("b", None, 2, 1)]), + ("min", [("a", 2, None, None), ("b", None, 8, 10)]), + ("max", [("a", 4, None, None), ("b", None, 8, 10)]), + ("sum", [("a", 6, None, None), ("b", None, 8, 10)]), + ("mean", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]), + ("median", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]), + ], +) +def test_pivot_aggregate(agg_fn: PivotAgg, expected_rows: list[tuple[Any]]) -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 2, 3], + "b": ["a", "a", "b", "b", "b"], + "c": [2, 4, None, 8, 10], + } + ) + result = df.pivot( + index="b", on="a", values="c", aggregate_function=agg_fn, sort_columns=True + ) + assert result.rows() == expected_rows + + +def test_pivot_categorical_3968() -> None: + df = pl.DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + } + ) + + result = df.with_columns(pl.col("baz").cast(str).cast(pl.Categorical)) + + expected = pl.DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": ["1", "2", "3", "4", "5", "6"], + }, + schema_overrides={"baz": pl.Categorical}, + ) + assert_frame_equal(result, expected, categorical_as_str=True) + + +def test_pivot_categorical_index() -> None: + df = pl.DataFrame( + {"A": ["Fire", "Water", "Water", "Fire"], "B": ["Car", "Car", "Car", "Ship"]}, + schema=[("A", pl.Categorical), ("B", pl.Categorical)], + ) + + result = df.pivot(index=["A"], on="B", values="B", aggregate_function="len") + expected = {"A": ["Fire", "Water"], "Car": [1, 2], "Ship": [1, None]} + assert result.to_dict(as_series=False) == expected + + # test expression dispatch + result = df.pivot(index=["A"], on="B", values="B", aggregate_function=pl.len()) + assert result.to_dict(as_series=False) == expected + + df = pl.DataFrame( + { + "A": ["Fire", "Water", "Water", "Fire"], + "B": ["Car", "Car", "Car", "Ship"], + "C": ["Paper", "Paper", "Paper", "Paper"], + }, + schema=[("A", pl.Categorical), ("B", pl.Categorical), ("C", pl.Categorical)], + ) + result = df.pivot(index=["A", "C"], on="B", values="B", aggregate_function="len") + expected = { + "A": ["Fire", "Water"], + "C": ["Paper", "Paper"], + "Car": [1, 2], + "Ship": [1, None], + } + assert result.to_dict(as_series=False) == expected + + +def test_pivot_multiple_values_column_names_5116() -> None: + df = pl.DataFrame( + { + "x1": [1, 2, 3, 4, 5, 6, 7, 8], + "x2": [8, 7, 6, 5, 4, 3, 2, 1], + "c1": ["A", "B"] * 4, + "c2": ["C", "C", "D", "D"] * 2, + } + ) + + with pytest.raises(ComputeError, match="found multiple elements in the same group"): + df.pivot( + index="c1", + on="c2", + values=["x1", "x2"], + separator="|", + aggregate_function=None, + ) + + result = df.pivot( + index="c1", + on="c2", + values=["x1", "x2"], + separator="|", + aggregate_function="first", + ) + expected = { + "c1": ["A", "B"], + "x1|C": [1, 2], + "x1|D": [3, 4], + "x2|C": [8, 7], + "x2|D": [6, 5], + } + assert result.to_dict(as_series=False) == expected + + +def test_pivot_duplicate_names_7731() -> None: + df = pl.DataFrame( + { + "a": [1, 4], + "b": [1.5, 2.5], + "c": ["x", "x"], + "d": [7, 8], + "e": ["x", "y"], + } + ) + result = df.pivot( + index=cs.float(), + on=cs.string(), + values=cs.integer(), + aggregate_function="first", + ).to_dict(as_series=False) + expected = { + "b": [1.5, 2.5], + 'a_{"x","x"}': [1, None], + 'a_{"x","y"}': [None, 4], + 'd_{"x","x"}': [7, None], + 'd_{"x","y"}': [None, 8], + } + assert result == expected + + +def test_pivot_duplicate_names_11663() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": ["x", "x"], "d": ["x", "y"]}) + result = df.pivot(index="b", on=["c", "d"], values="a").to_dict(as_series=False) + expected = {"b": [1, 2], '{"x","x"}': [1, None], '{"x","y"}': [None, 2]} + assert result == expected + + +def test_pivot_multiple_columns_12407() -> None: + df = pl.DataFrame( + { + "a": ["beep", "bop"], + "b": ["a", "b"], + "c": ["s", "f"], + "d": [7, 8], + "e": ["x", "y"], + } + ) + result = df.pivot( + index="b", on=["c", "e"], values=["a"], aggregate_function="len" + ).to_dict(as_series=False) + expected = {"b": ["a", "b"], '{"s","x"}': [1, None], '{"f","y"}': [None, 1]} + assert result == expected + + +def test_pivot_struct_13120() -> None: + df = pl.DataFrame( + { + "index": [1, 2, 3, 1, 2, 3], + "item_type": ["a", "a", "a", "b", "b", "b"], + "item_id": [123, 123, 123, 456, 456, 456], + "values": [4, 5, 6, 7, 8, 9], + } + ) + df = df.with_columns(pl.struct(["item_type", "item_id"]).alias("columns")).drop( + "item_type", "item_id" + ) + result = df.pivot(index="index", on="columns", values="values").to_dict( + as_series=False + ) + expected = {"index": [1, 2, 3], '{"a",123}': [4, 5, 6], '{"b",456}': [7, 8, 9]} + assert result == expected + + +def test_pivot_index_struct_14101() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 1], + "b": [{"a": 1}, {"a": 1}, {"a": 2}], + "c": ["x", "y", "y"], + "d": [1, 1, 3], + } + ) + result = df.pivot(index="b", on="c", values="a") + expected = pl.DataFrame({"b": [{"a": 1}, {"a": 2}], "x": [1, None], "y": [2, 1]}) + assert_frame_equal(result, expected) + + +def test_pivot_nested_struct_17065() -> None: + df = pl.DataFrame( + { + "foo": ["one", "two", "one", "two"], + "bar": ["x", "x", "y", "y"], + "baz": [ + {"a": 1, "b": {"c": 2}}, + {"a": 3, "b": {"c": 4}}, + {"a": 5, "b": {"c": 6}}, + {"a": 7, "b": {"c": 8}}, + ], + } + ) + result = df.pivot(on="bar", index="foo", values="baz") + expected = pl.DataFrame( + { + "foo": ["one", "two"], + "x": [ + {"a": 1, "b": {"c": 2}}, + {"a": 3, "b": {"c": 4}}, + ], + "y": [ + {"a": 5, "b": {"c": 6}}, + {"a": 7, "b": {"c": 8}}, + ], + } + ) + assert_frame_equal(result, expected) + + +def test_pivot_name_already_exists() -> None: + # This should be extremely rare...but still, good to check it + df = pl.DataFrame( + { + "a": ["a", "b"], + "b": ["a", "b"], + '{"a","b"}': [1, 2], + } + ) + with pytest.raises(ComputeError, match="already exists in the DataFrame"): + df.pivot( + values='{"a","b"}', + index="a", + on=["a", "b"], + aggregate_function="first", + ) + + +def test_pivot_floats() -> None: + df = pl.DataFrame( + { + "article": ["a", "a", "a", "b", "b", "b"], + "weight": [1.0, 1.0, 4.4, 1.0, 8.8, 1.0], + "quantity": [1.0, 5.0, 1.0, 1.0, 1.0, 7.5], + "price": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + ) + + with pytest.raises(ComputeError, match="found multiple elements in the same group"): + result = df.pivot( + index="weight", on="quantity", values="price", aggregate_function=None + ) + + result = df.pivot( + index="weight", on="quantity", values="price", aggregate_function="first" + ) + expected = { + "weight": [1.0, 4.4, 8.8], + "1.0": [1.0, 3.0, 5.0], + "5.0": [2.0, None, None], + "7.5": [6.0, None, None], + } + assert result.to_dict(as_series=False) == expected + + result = df.pivot( + index=["article", "weight"], + on="quantity", + values="price", + aggregate_function=None, + ) + expected = { + "article": ["a", "a", "b", "b"], + "weight": [1.0, 4.4, 1.0, 8.8], + "1.0": [1.0, 3.0, 4.0, 5.0], + "5.0": [2.0, None, None, None], + "7.5": [None, None, 6.0, None], + } + assert result.to_dict(as_series=False) == expected + + +def test_pivot_reinterpret_5907() -> None: + df = pl.DataFrame( + { + "A": pl.Series([3, -2, 3, -2], dtype=pl.Int32), + "B": ["x", "x", "y", "y"], + "C": [100, 50, 500, -80], + } + ) + + result = df.pivot( + index=["A"], on=["B"], values=["C"], aggregate_function=pl.element().sum() + ) + expected = {"A": [3, -2], "x": [100, 50], "y": [500, -80]} + assert result.to_dict(as_series=False) == expected + + +def test_pivot_temporal_logical_types() -> None: + date_lst = [datetime(_, 1, 1) for _ in range(1960, 1980)] + + df = pl.DataFrame( + { + "idx": date_lst[-3:] + date_lst[0:5], + "foo": ["a"] * 3 + ["b"] * 5, + "value": [0] * 8, + } + ) + assert df.pivot( + index="idx", on="foo", values="value", aggregate_function=None + ).to_dict(as_series=False) == { + "idx": [ + datetime(1977, 1, 1, 0, 0), + datetime(1978, 1, 1, 0, 0), + datetime(1979, 1, 1, 0, 0), + datetime(1960, 1, 1, 0, 0), + datetime(1961, 1, 1, 0, 0), + datetime(1962, 1, 1, 0, 0), + datetime(1963, 1, 1, 0, 0), + datetime(1964, 1, 1, 0, 0), + ], + "a": [0, 0, 0, None, None, None, None, None], + "b": [None, None, None, 0, 0, 0, 0, 0], + } + + +def test_pivot_negative_duration() -> None: + df1 = pl.DataFrame({"root": [date(2020, i, 15) for i in (1, 2)]}) + df2 = pl.DataFrame({"delta": [timedelta(days=i) for i in (-2, -1, 0, 1)]}) + + df = df1.join(df2, how="cross").with_columns( + pl.Series(name="value", values=range(len(df1) * len(df2))) + ) + assert df.pivot( + index="delta", on="root", values="value", aggregate_function=None + ).to_dict(as_series=False) == { + "delta": [ + timedelta(days=-2), + timedelta(days=-1), + timedelta(0), + timedelta(days=1), + ], + "2020-01-15": [0, 1, 2, 3], + "2020-02-15": [4, 5, 6, 7], + } + + +def test_aggregate_function_default() -> None: + df = pl.DataFrame({"a": [1, 2], "b": ["foo", "foo"], "c": ["x", "x"]}) + with pytest.raises(ComputeError, match="found multiple elements in the same group"): + df.pivot(index="b", on="c", values="a") + + +def test_pivot_aggregate_function_count_deprecated() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "n", "o"], + } + ) + with pytest.deprecated_call(): + df.pivot(index="foo", on="bar", values="N", aggregate_function="count") # type: ignore[arg-type] + + +def test_pivot_struct() -> None: + data = { + "id": ["a", "a", "b", "c", "c", "c"], + "week": ["1", "2", "3", "4", "3", "1"], + "num1": [1, 3, 5, 4, 3, 6], + "num2": [4, 5, 3, 4, 6, 6], + } + df = pl.DataFrame(data).with_columns(nums=pl.struct(["num1", "num2"])) + + assert df.pivot( + values="nums", index="id", on="week", aggregate_function="first" + ).to_dict(as_series=False) == { + "id": ["a", "b", "c"], + "1": [ + {"num1": 1, "num2": 4}, + None, + {"num1": 6, "num2": 6}, + ], + "2": [ + {"num1": 3, "num2": 5}, + None, + None, + ], + "3": [ + None, + {"num1": 5, "num2": 3}, + {"num1": 3, "num2": 6}, + ], + "4": [ + None, + None, + {"num1": 4, "num2": 4}, + ], + } + + +def test_duplicate_column_names_which_should_raise_14305() -> None: + df = pl.DataFrame({"a": [1, 3, 2], "c": ["a", "a", "a"], "d": [7, 8, 9]}) + with pytest.raises(DuplicateError, match="has more than one occurrence"): + df.pivot(index="a", on="c", values="d") + + +def test_multi_index_containing_struct() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 1], + "b": [{"a": 1}, {"a": 1}, {"a": 2}], + "c": ["x", "y", "y"], + "d": [1, 1, 3], + } + ) + result = df.pivot(index=("b", "d"), on="c", values="a") + expected = pl.DataFrame( + {"b": [{"a": 1}, {"a": 2}], "d": [1, 3], "x": [1, None], "y": [2, 1]} + ) + assert_frame_equal(result, expected) + + +def test_list_pivot() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 1], + "b": [[1, 2], [3, 4], [5, 6], [1, 2]], + "c": ["x", "x", "y", "y"], + "d": [1, 2, 3, 4], + } + ) + assert df.pivot( + index=["a", "b"], + on="c", + values="d", + ).to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [[1, 2], [3, 4], [5, 6]], + "x": [1, 2, None], + "y": [4, None, 3], + } + + +def test_pivot_string_17081() -> None: + df = pl.DataFrame( + { + "a": ["1", "2", "3"], + "b": ["4", "5", "6"], + "c": ["7", "8", "9"], + } + ) + assert df.pivot(index="a", on="b", values="c", aggregate_function="min").to_dict( + as_series=False + ) == { + "a": ["1", "2", "3"], + "4": ["7", None, None], + "5": [None, "8", None], + "6": [None, None, "9"], + } + + +def test_pivot_invalid() -> None: + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="`index` and `values` cannot both be None in `pivot` operation", + ): + pl.DataFrame({"a": [1, 2], "b": [2, 3], "c": [3, 4]}).pivot("a") + + +@pytest.mark.parametrize( + "dtype", + [pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64], +) +def test_pivot_empty_index_dtypes(dtype: PolarsIntegerType) -> None: + index = pl.Series([], dtype=dtype) + df = pl.DataFrame({"index": index, "on": [], "values": []}) + result = df.pivot(index="index", on="on", values="values") + expected = pl.DataFrame({"index": index}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_profile.py b/py-polars/tests/unit/operations/test_profile.py new file mode 100644 index 000000000000..4e371b74b3a6 --- /dev/null +++ b/py-polars/tests/unit/operations/test_profile.py @@ -0,0 +1,32 @@ +import polars as pl + + +def test_profile_columns() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) + + # profile lazyframe operation/plan + lazy = ldf.group_by("a").agg(pl.implode("b")) + profiling_info = lazy.profile() + # ┌──────────────┬───────┬─────┐ + # │ node ┆ start ┆ end │ + # │ --- ┆ --- ┆ --- │ + # │ str ┆ u64 ┆ u64 │ + # ╞══════════════╪═══════╪═════╡ + # │ optimization ┆ 0 ┆ 69 │ + # │ group_by(a) ┆ 69 ┆ 342 │ + # └──────────────┴───────┴─────┘ + assert len(profiling_info) == 2 + assert profiling_info[1].columns == ["node", "start", "end"] + + +def test_profile_with_cse() -> None: + df = pl.DataFrame({"x": [], "y": []}, schema={"x": pl.Float32, "y": pl.Float32}) + + x = pl.col("x") + y = pl.col("y") + + assert df.lazy().with_columns( + pl.when(x.is_null()) + .then(None) + .otherwise(pl.when(y == 0).then(None).otherwise(x + y)) + ).profile(comm_subexpr_elim=True)[1].shape == (2, 3) diff --git a/py-polars/tests/unit/operations/test_qcut.py b/py-polars/tests/unit/operations/test_qcut.py new file mode 100644 index 000000000000..7b7b1c399a45 --- /dev/null +++ b/py-polars/tests/unit/operations/test_qcut.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError +from polars.testing import assert_frame_equal, assert_series_equal + +inf = float("inf") + + +def test_qcut() -> None: + s = pl.Series("a", [-2, -1, 0, 1, 2]) + + result = s.qcut([0.25, 0.50]) + + expected = pl.Series( + "a", + [ + "(-inf, -1]", + "(-inf, -1]", + "(-1, 0]", + "(0, inf]", + "(0, inf]", + ], + dtype=pl.Categorical, + ) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_qcut_lazy_schema() -> None: + lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]}) + + result = lf.select(pl.col("a").qcut([0.25, 0.75])) + + expected = pl.LazyFrame( + {"a": ["(-inf, -1]", "(-inf, -1]", "(-1, 1]", "(-1, 1]", "(1, inf]"]}, + schema={"a": pl.Categorical}, + ) + assert_frame_equal(result, expected, categorical_as_str=True) + + +def test_qcut_n() -> None: + s = pl.Series("a", [-2, -1, 0, 1, 2]) + + out = s.qcut(2, labels=["x", "y"], left_closed=True) + + expected = pl.Series("a", ["x", "x", "y", "y", "y"], dtype=pl.Categorical) + assert_series_equal(out, expected, categorical_as_str=True) + + +def test_qcut_include_breaks() -> None: + s = pl.int_range(-2, 3, eager=True).alias("a") + + out = s.qcut([0.0, 0.25, 0.75], labels=["a", "b", "c", "d"], include_breaks=True) + + expected = pl.DataFrame( + { + "breakpoint": [-2.0, -1.0, 1.0, 1.0, inf], + "category": ["a", "b", "c", "c", "d"], + }, + schema_overrides={"category": pl.Categorical}, + ).to_struct("a") + assert_series_equal(out, expected, categorical_as_str=True) + + +# https://github.com/pola-rs/polars/issues/11255 +def test_qcut_include_breaks_lazy_schema() -> None: + lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]}) + + result = lf.select( + pl.col("a").qcut([0.25, 0.75], include_breaks=True).alias("qcut") + ).unnest("qcut") + + expected = pl.LazyFrame( + { + "breakpoint": [-1.0, -1.0, 1.0, 1.0, inf], + "category": ["(-inf, -1]", "(-inf, -1]", "(-1, 1]", "(-1, 1]", "(1, inf]"], + }, + schema_overrides={"category": pl.Categorical}, + ) + assert_frame_equal(result, expected, categorical_as_str=True) + + +def test_qcut_null_values() -> None: + s = pl.Series([-1.0, None, 1.0, 2.0, None, 8.0, 4.0]) + + result = s.qcut([0.2, 0.3], labels=["a", "b", "c"]) + + expected = pl.Series(["a", None, "b", "c", None, "c", "c"], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_qcut_full_null() -> None: + s = pl.Series("a", [None, None, None, None]) + + result = s.qcut([0.25, 0.50]) + + expected = pl.Series("a", [None, None, None, None], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_qcut_full_null_with_labels() -> None: + s = pl.Series("a", [None, None, None, None]) + + result = s.qcut([0.25, 0.50], labels=["1", "2", "3"]) + + expected = pl.Series("a", [None, None, None, None], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_qcut_allow_duplicates() -> None: + s = pl.Series([1, 2, 2, 3]) + + with pytest.raises(DuplicateError): + s.qcut([0.50, 0.51]) + + result = s.qcut([0.50, 0.51], allow_duplicates=True) + + expected = pl.Series( + ["(-inf, 2]", "(-inf, 2]", "(-inf, 2]", "(2, inf]"], dtype=pl.Categorical + ) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_qcut_over() -> None: + df = pl.DataFrame( + { + "group": ["a"] * 4 + ["b"] * 4, + "value": range(8), + } + ) + + out = df.select( + pl.col("value").qcut([0.5], labels=["low", "high"]).over("group") + ).to_series() + + expected = pl.Series( + "value", + ["low", "low", "high", "high", "low", "low", "high", "high"], + dtype=pl.Categorical, + ) + assert_series_equal(out, expected, categorical_as_str=True) diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py new file mode 100644 index 000000000000..ce8e5644bc77 --- /dev/null +++ b/py-polars/tests/unit/operations/test_random.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import ShapeError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_shuffle_group_by_reseed() -> None: + def unique_shuffle_groups(n: int, seed: int | None) -> int: + ls = [1, 2, 3] * n # 1, 2, 3, 1, 2, 3... + groups = sorted(list(range(n)) * 3) # 0, 0, 0, 1, 1, 1, ... + df = pl.DataFrame({"l": ls, "group": groups}) + shuffled = df.group_by("group", maintain_order=True).agg( + pl.col("l").shuffle(seed) + ) + num_unique = shuffled.group_by("l").agg(pl.lit(0)).select(pl.len()) + return int(num_unique[0, 0]) + + assert unique_shuffle_groups(50, None) > 1 # Astronomically unlikely. + assert ( + unique_shuffle_groups(50, 0xDEADBEEF) == 1 + ) # Fixed seed should be always the same. + + +def test_sample_expr() -> None: + a = pl.Series("a", range(20)) + out = pl.select( + pl.lit(a).sample(fraction=0.5, with_replacement=False, seed=1) + ).to_series() + + assert out.shape == (10,) + assert out.to_list() != out.sort().to_list() + assert out.unique().shape == (10,) + assert set(out).issubset(set(a)) + + out = pl.select(pl.lit(a).sample(n=10, with_replacement=False, seed=1)).to_series() + assert out.shape == (10,) + assert out.to_list() != out.sort().to_list() + assert out.unique().shape == (10,) + + # pl.set_random_seed should lead to reproducible results. + pl.set_random_seed(1) + result1 = pl.select(pl.lit(a).sample(n=10)).to_series() + pl.set_random_seed(1) + result2 = pl.select(pl.lit(a).sample(n=10)).to_series() + assert_series_equal(result1, result2) + + +def test_sample_df() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]}) + + assert df.sample().shape == (1, 3) + assert df.sample(n=2, seed=0).shape == (2, 3) + assert df.sample(fraction=0.4, seed=0).shape == (1, 3) + assert df.sample(n=pl.Series([2]), seed=0).shape == (2, 3) + assert df.sample(fraction=pl.Series([0.4]), seed=0).shape == (1, 3) + assert df.select(pl.col("foo").sample(n=pl.Series([2]), seed=0)).shape == (2, 1) + assert df.select(pl.col("foo").sample(fraction=pl.Series([0.4]), seed=0)).shape == ( + 1, + 1, + ) + with pytest.raises(ValueError, match="cannot specify both `n` and `fraction`"): + df.sample(n=2, fraction=0.4) + + +def test_sample_n_expr() -> None: + df = pl.DataFrame( + { + "group": [1, 1, 1, 2, 2, 2], + "val": [1, 2, 3, 2, 1, 1], + } + ) + + out_df = df.sample(pl.Series([3]), seed=0) + expected_df = pl.DataFrame({"group": [2, 2, 1], "val": [1, 1, 3]}) + assert_frame_equal(out_df, expected_df) + + agg_df = df.group_by("group", maintain_order=True).agg( + pl.col("val").sample(pl.col("val").max(), seed=0) + ) + expected_df = pl.DataFrame({"group": [1, 2], "val": [[1, 2, 3], [1, 1]]}) + assert_frame_equal(agg_df, expected_df) + + select_df = df.select(pl.col("val").sample(pl.col("val").max(), seed=0)) + expected_df = pl.DataFrame({"val": [1, 1, 3]}) + assert_frame_equal(select_df, expected_df) + + +def test_sample_empty_df() -> None: + df = pl.DataFrame({"foo": []}) + + # // If with replacement, then expect empty df + assert df.sample(n=3, with_replacement=True).shape == (0, 1) + assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1) + + # // If without replacement, then expect shape mismatch on sample_n not sample_frac + with pytest.raises(ShapeError): + df.sample(n=3, with_replacement=False) + assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1) + + +def test_sample_series() -> None: + s = pl.Series("a", [1, 2, 3, 4, 5]) + + assert len(s.sample(n=2, seed=0)) == 2 + assert len(s.sample(fraction=0.4, seed=0)) == 2 + + assert len(s.sample(n=2, with_replacement=True, seed=0)) == 2 + + # on a series of length 5, you cannot sample more than 5 items + with pytest.raises(ShapeError): + s.sample(n=10, with_replacement=False, seed=0) + # unless you use with_replacement=True + assert len(s.sample(n=10, with_replacement=True, seed=0)) == 10 + + +def test_shuffle_expr() -> None: + # pl.set_random_seed should lead to reproducible results. + s = pl.Series("a", range(20)) + + pl.set_random_seed(1) + result1 = pl.select(pl.lit(s).shuffle()).to_series() + + pl.set_random_seed(1) + result2 = pl.select(pl.lit(s).shuffle()).to_series() + assert_series_equal(result1, result2) + + +def test_shuffle_series() -> None: + a = pl.Series("a", [1, 2, 3]) + out = a.shuffle(2) + expected = pl.Series("a", [2, 1, 3]) + assert_series_equal(out, expected) + + out = pl.select(pl.lit(a).shuffle(2)).to_series() + assert_series_equal(out, expected) + + +def test_sample_16232() -> None: + k = 2 + p = 0 + + df = pl.DataFrame({"a": [p] * k + [1 + p], "b": [[1] * p] * k + [range(1, p + 2)]}) + assert df.select(pl.col("b").list.sample(n=pl.col("a"), seed=0)).to_dict( + as_series=False + ) == {"b": [[], [], [1]]} diff --git a/py-polars/tests/unit/operations/test_rank.py b/py-polars/tests/unit/operations/test_rank.py new file mode 100644 index 000000000000..6f83663875b6 --- /dev/null +++ b/py-polars/tests/unit/operations/test_rank.py @@ -0,0 +1,97 @@ +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_rank_nulls() -> None: + assert pl.Series([]).rank().to_list() == [] + assert pl.Series([None]).rank().to_list() == [None] + assert pl.Series([None, None]).rank().to_list() == [None, None] + + +def test_rank_random_expr() -> None: + df = pl.from_dict( + {"a": [1] * 5, "b": [1, 2, 3, 4, 5], "c": [200, 100, 100, 50, 100]} + ) + + df_ranks1 = df.with_columns( + pl.col("c").rank(method="random", seed=1).over("a").alias("rank") + ) + df_ranks2 = df.with_columns( + pl.col("c").rank(method="random", seed=1).over("a").alias("rank") + ) + assert_frame_equal(df_ranks1, df_ranks2) + + +def test_rank_random_series() -> None: + s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + assert_series_equal( + s.rank("random", seed=1), pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=pl.UInt32) + ) + + +def test_rank_df() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 2, 3], + } + ) + + s = df.select(pl.col("a").rank(method="average").alias("b")).to_series() + assert s.to_list() == [1.5, 1.5, 3.5, 3.5, 5.0] + assert s.dtype == pl.Float64 + + s = df.select(pl.col("a").rank(method="max").alias("b")).to_series() + assert s.to_list() == [2, 2, 4, 4, 5] + assert s.dtype == pl.get_index_type() + + +def test_rank_so_4109() -> None: + # also tests ranks null behavior + df = pl.from_dict( + { + "id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "rank": [None, 3, 2, 4, 1, 4, 3, 2, 1, None, 3, 4, 4, 1, None, 3], + } + ).sort(by=["id", "rank"]) + + assert df.group_by("id").agg( + [ + pl.col("rank").alias("original"), + pl.col("rank").rank(method="dense").alias("dense"), + pl.col("rank").rank(method="average").alias("average"), + ] + ).to_dict(as_series=False) == { + "id": [1, 2, 3, 4], + "original": [[None, 2, 3, 4], [1, 2, 3, 4], [None, 1, 3, 4], [None, 1, 3, 4]], + "dense": [[None, 1, 2, 3], [1, 2, 3, 4], [None, 1, 2, 3], [None, 1, 2, 3]], + "average": [ + [None, 1.0, 2.0, 3.0], + [1.0, 2.0, 3.0, 4.0], + [None, 1.0, 2.0, 3.0], + [None, 1.0, 2.0, 3.0], + ], + } + + +def test_rank_string_null_11252() -> None: + rank = pl.Series([None, "", "z", None, "a"]).rank() + assert rank.to_list() == [None, 1.0, 3.0, None, 2.0] + + +def test_rank_series() -> None: + s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + + assert_series_equal( + s.rank("dense"), pl.Series("a", [2, 3, 4, 3, 3, 4, 1], dtype=pl.UInt32) + ) + + df = pl.DataFrame([s]) + assert df.select(pl.col("a").rank("dense"))["a"].to_list() == [2, 3, 4, 3, 3, 4, 1] + + assert_series_equal( + s.rank("dense", descending=True), + pl.Series("a", [3, 2, 1, 2, 2, 1, 4], dtype=pl.UInt32), + ) + + assert s.rank(method="average").dtype == pl.Float64 + assert s.rank(method="max").dtype == pl.get_index_type() diff --git a/py-polars/tests/unit/operations/test_rename.py b/py-polars/tests/unit/operations/test_rename.py new file mode 100644 index 000000000000..3d9809b20842 --- /dev/null +++ b/py-polars/tests/unit/operations/test_rename.py @@ -0,0 +1,162 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_rename_invalidate_cache_15884() -> None: + assert ( + pl.LazyFrame({"a": [1], "b": [1]}) + .rename({"b": "b1"}) # to cache schema + .with_columns( + c=pl.col("b1").drop_nulls(), d=pl.col("b1").drop_nulls() + ) # to trigger CSE + .select("c", "d") # to trigger project push down + ).collect().to_dict(as_series=False) == {"c": [1], "d": [1]} + + +def test_rename_lf() -> None: + ldf = pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + out = ldf.rename({"a": "foo", "b": "bar"}).collect() + assert out.columns == ["foo", "bar", "c"] + + +def test_with_column_renamed_lf(fruits_cars: pl.DataFrame) -> None: + res = fruits_cars.lazy().rename({"A": "C"}).collect() + assert res.columns[0] == "C" + + +def test_rename_lf_lambda() -> None: + ldf = pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + out = ldf.rename( + lambda col: "foo" if col == "a" else "bar" if col == "b" else col + ).collect() + assert out.columns == ["foo", "bar", "c"] + + +def test_with_column_renamed() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df.rename({"b": "c"}) + expected = pl.DataFrame({"a": [1, 2], "c": [3, 4]}) + assert_frame_equal(result, expected) + + +def test_rename_swap() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [5, 4, 3, 2, 1], + } + ) + + out = df.rename({"a": "b", "b": "a"}) + expected = pl.DataFrame( + { + "b": [1, 2, 3, 4, 5], + "a": [5, 4, 3, 2, 1], + } + ) + assert_frame_equal(out, expected) + + # 6195 + ldf = pl.DataFrame( + { + "weekday": [ + 1, + ], + "priority": [ + 2, + ], + "roundNumber": [ + 3, + ], + "flag": [ + 4, + ], + } + ).lazy() + + # Rename some columns (note: swapping two columns) + rename_dict = { + "weekday": "priority", + "priority": "weekday", + "roundNumber": "round_number", + } + ldf = ldf.rename(rename_dict) + + # Select some columns + ldf = ldf.select(["priority", "weekday", "round_number"]) + + assert ldf.collect().to_dict(as_series=False) == { + "priority": [1], + "weekday": [2], + "round_number": [3], + } + + +def test_rename_same_name() -> None: + df = pl.DataFrame( + { + "nrs": [1, 2, 3, 4, 5], + "groups": ["A", "A", "B", "C", "B"], + } + ).lazy() + df = df.rename({"groups": "groups"}) + df = df.select(["groups"]) + assert df.collect().to_dict(as_series=False) == { + "groups": ["A", "A", "B", "C", "B"] + } + df = pl.DataFrame( + { + "nrs": [1, 2, 3, 4, 5], + "groups": ["A", "A", "B", "C", "B"], + "test": [1, 2, 3, 4, 5], + } + ).lazy() + df = df.rename({"nrs": "nrs", "groups": "groups"}) + df = df.select(["groups"]) + df.collect() + assert df.collect().to_dict(as_series=False) == { + "groups": ["A", "A", "B", "C", "B"] + } + + +def test_rename_df(df: pl.DataFrame) -> None: + out = df.rename({"strings": "bars", "int": "foos"}) + # check if we can select these new columns + _ = out[["foos", "bars"]] + + +def test_rename_df_lambda() -> None: + df = pl.DataFrame({"a": [1], "b": [2], "c": [3]}) + out = df.rename(lambda col: "foo" if col == "a" else "bar" if col == "b" else col) + assert out.columns == ["foo", "bar", "c"] + + +def test_rename_schema_order_6660() -> None: + df = pl.DataFrame( + { + "a": [], + "b": [], + "c": [], + "d": [], + } + ) + + mapper = {"a": "1", "b": "2", "c": "3", "d": "4"} + + renamed = df.lazy().rename(mapper) + + computed = renamed.select([pl.all(), pl.col("4").alias("computed")]) + + assert renamed.collect_schema() == renamed.collect().schema + assert computed.collect_schema() == computed.collect().schema + + +def test_rename_schema_17427() -> None: + assert ( + pl.LazyFrame({"A": [1]}) + .with_columns(B=2) + .select(["A", "B"]) + .rename({"A": "C", "B": "A"}) + .select(["C", "A"]) + .collect() + ).to_dict(as_series=False) == {"C": [1], "A": [2]} diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py new file mode 100644 index 000000000000..81edb16a6d49 --- /dev/null +++ b/py-polars/tests/unit/operations/test_replace.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.fixture(scope="module") +def str_mapping() -> dict[str | None, str]: + return { + "CA": "Canada", + "DE": "Germany", + "FR": "France", + None: "Not specified", + } + + +def test_replace_str_to_str(str_mapping: dict[str | None, str]) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select(replaced=pl.col("country_code").replace(str_mapping)) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_enum() -> None: + dtype = pl.Enum(["a", "b", "c", "d"]) + s = pl.Series(["a", "b", "c"], dtype=dtype) + old = ["a", "b"] + new = pl.Series(["c", "d"], dtype=dtype) + + result = s.replace(old, new) + + expected = pl.Series(["c", "d", "c"], dtype=dtype) + assert_series_equal(result, expected) + + +def test_replace_enum_to_str() -> None: + dtype = pl.Enum(["a", "b", "c", "d"]) + s = pl.Series(["a", "b", "c"], dtype=dtype) + + result = s.replace({"a": "c", "b": "d"}) + + expected = pl.Series(["c", "d", "c"], dtype=dtype) + assert_series_equal(result, expected) + + +@pl.StringCache() +def test_replace_cat_to_cat(str_mapping: dict[str | None, str]) -> None: + lf = pl.LazyFrame( + {"country_code": ["FR", None, "ES", "DE"]}, + schema={"country_code": pl.Categorical}, + ) + old = pl.Series(["CA", "DE", "FR", None], dtype=pl.Categorical) + new = pl.Series( + ["Canada", "Germany", "France", "Not specified"], dtype=pl.Categorical + ) + + result = lf.select(replaced=pl.col("country_code").replace(old, new)) + + expected = pl.LazyFrame( + {"replaced": ["France", "Not specified", "ES", "Germany"]}, + schema_overrides={"replaced": pl.Categorical}, + ) + assert_frame_equal(result, expected) + + +def test_replace_invalid_old_dtype() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}) + mapping = {"a": 10, "b": 20} + with pytest.raises( + InvalidOperationError, match="conversion from `str` to `i64` failed" + ): + lf.select(pl.col("a").replace(mapping)).collect() + + +def test_replace_int_to_int() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {1: 5, 3: 7} + result = df.select(replaced=pl.col("int").replace(mapping)) + expected = pl.DataFrame( + {"replaced": [None, 5, None, 7]}, schema={"replaced": pl.Int16} + ) + assert_frame_equal(result, expected) + + +def test_replace_int_to_int_keep_dtype() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + old = [1, 3] + new = pl.Series([5, 7], dtype=pl.Int16) + + result = df.select(replaced=pl.col("int").replace(old, new)) + expected = pl.DataFrame( + {"replaced": [None, 5, None, 7]}, schema={"replaced": pl.Int16} + ) + assert_frame_equal(result, expected) + + +def test_replace_int_to_str() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {1: "b", 3: "d"} + with pytest.raises( + InvalidOperationError, match="conversion from `str` to `i16` failed" + ): + df.select(replaced=pl.col("int").replace(mapping)) + + +def test_replace_int_to_str_with_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {1: "b", 3: "d", None: "e"} + with pytest.raises( + InvalidOperationError, match="conversion from `str` to `i16` failed" + ): + df.select(replaced=pl.col("int").replace(mapping)) + + +def test_replace_empty_mapping() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping: dict[Any, Any] = {} + result = df.select(pl.col("int").replace(mapping)) + assert_frame_equal(result, df) + + +def test_replace_mapping_different_dtype_str_int() -> None: + df = pl.DataFrame({"int": [None, "1", None, "3"]}) + mapping = {1: "b", 3: "d"} + + result = df.select(pl.col("int").replace(mapping)) + expected = pl.DataFrame({"int": [None, "b", None, "d"]}) + assert_frame_equal(result, expected) + + +def test_replace_mapping_different_dtype_map_none() -> None: + df = pl.DataFrame({"int": [None, "1", None, "3"]}) + mapping = {1: "b", 3: "d", None: "e"} + result = df.select(pl.col("int").replace(mapping)) + expected = pl.DataFrame({"int": ["e", "b", "e", "d"]}) + assert_frame_equal(result, expected) + + +def test_replace_mapping_different_dtype_str_float() -> None: + df = pl.DataFrame({"int": [None, "1", None, "3"]}) + mapping = {1.0: "b", 3.0: "d"} + + result = df.select(pl.col("int").replace(mapping)) + assert_frame_equal(result, df) + + +# https://github.com/pola-rs/polars/issues/7132 +def test_replace_str_to_str_replace_all() -> None: + df = pl.DataFrame({"text": ["abc"]}) + mapping = {"abc": "123"} + result = df.select(pl.col("text").replace(mapping).str.replace_all("1", "-")) + expected = pl.DataFrame({"text": ["-23"]}) + assert_frame_equal(result, expected) + + +@pytest.fixture(scope="module") +def int_mapping() -> dict[int, int]: + return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} + + +def test_replace_int_to_int1(int_mapping: dict[int, int]) -> None: + s = pl.Series([-1, 22, None, 44, -5]) + result = s.replace(int_mapping) + expected = pl.Series([-1, 22, None, 44, -5]) + assert_series_equal(result, expected) + + +def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: + s = pl.Series([-1, 22, None, 44, -5]) + result = s.replace(int_mapping) + expected = pl.Series([-1, 22, None, 44, -5]) + assert_series_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/12728 +def test_replace_str_to_int2() -> None: + s = pl.Series(["a", "b"]) + mapping = {"a": 1, "b": 2} + result = s.replace(mapping) + expected = pl.Series(["1", "2"]) + assert_series_equal(result, expected) + + +def test_replace_str_to_bool_without_default() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace(mapping) + expected = pl.Series(["true", "false", "false", None]) + assert_series_equal(result, expected) + + +def test_replace_old_new() -> None: + s = pl.Series([1, 2, 2, 3]) + result = s.replace(2, 9) + expected = s = pl.Series([1, 9, 9, 3]) + assert_series_equal(result, expected) + + +def test_replace_old_new_many_to_one() -> None: + s = pl.Series([1, 2, 2, 3]) + result = s.replace([2, 3], 9) + expected = s = pl.Series([1, 9, 9, 9]) + assert_series_equal(result, expected) + + +def test_replace_old_new_mismatched_lengths() -> None: + s = pl.Series([1, 2, 2, 3, 4]) + with pytest.raises(InvalidOperationError): + s.replace([2, 3, 4], [8, 9]) + + +def test_replace_fast_path_one_to_one() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace(2, 100)) + expected = pl.LazyFrame({"a": [1, 100, 100, 3]}) + assert_frame_equal(result, expected) + + +def test_replace_fast_path_one_null_to_one() -> None: + # https://github.com/pola-rs/polars/issues/13391 + lf = pl.LazyFrame({"a": [1, None]}) + result = lf.select(pl.col("a").replace(None, 100)) + expected = pl.LazyFrame({"a": [1, 100]}) + assert_frame_equal(result, expected) + + +def test_replace_fast_path_many_with_null_to_one() -> None: + lf = pl.LazyFrame({"a": [1, 2, None]}) + result = lf.select(pl.col("a").replace([1, None], 100)) + expected = pl.LazyFrame({"a": [100, 2, 100]}) + assert_frame_equal(result, expected) + + +def test_replace_fast_path_many_to_one() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace([2, 3], 100)) + expected = pl.LazyFrame({"a": [1, 100, 100, 100]}) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("old", "new"), + [ + ([2, 2], 100), + ([2, 2], [100, 200]), + ([2, 2], [100, 100]), + ], +) +def test_replace_duplicates_old(old: list[int], new: int | list[int]) -> None: + s = pl.Series([1, 2, 3, 2, 3]) + with pytest.raises( + InvalidOperationError, + match="`old` input for `replace` must not contain duplicates", + ): + s.replace(old, new) + + +def test_replace_duplicates_new() -> None: + s = pl.Series([1, 2, 3, 2, 3]) + result = s.replace([1, 2], [100, 100]) + expected = s = pl.Series([100, 100, 3, 100, 3]) + assert_series_equal(result, expected) + + +def test_replace_return_dtype_deprecated() -> None: + s = pl.Series([1, 2, 3]) + with pytest.deprecated_call(): + result = s.replace(1, 10, return_dtype=pl.Int8) + expected = pl.Series([10, 2, 3], dtype=pl.Int8) + assert_series_equal(result, expected) + + +def test_replace_default_deprecated() -> None: + s = pl.Series([1, 2, 3]) + with pytest.deprecated_call(): + result = s.replace(1, 10, default=None) + expected = pl.Series([10, None, None], dtype=pl.Int32) + assert_series_equal(result, expected) + + +def test_replace_single_argument_not_mapping() -> None: + df = pl.DataFrame({"a": ["a", "b", "c"]}) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df.select(pl.col("a").replace("b")) diff --git a/py-polars/tests/unit/operations/test_replace_strict.py b/py-polars/tests/unit/operations/test_replace_strict.py new file mode 100644 index 000000000000..ef54cf0de638 --- /dev/null +++ b/py-polars/tests/unit/operations/test_replace_strict.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import contextlib +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import CategoricalRemappingWarning, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_replace_strict_incomplete_mapping() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + + with pytest.raises(InvalidOperationError, match="incomplete mapping"): + lf.select(pl.col("a").replace_strict({2: 200, 3: 300})).collect() + + +def test_replace_strict_incomplete_mapping_null_raises() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + with pytest.raises(InvalidOperationError): + s.replace_strict({1: 10}) + + +def test_replace_strict_mapping_null_not_specified() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + + result = s.replace_strict({1: 10, 2: 20}) + + expected = pl.Series("a", [10, 20, 20, None, None]) + assert_series_equal(result, expected) + + +def test_replace_strict_mapping_null_specified() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + + result = s.replace_strict({1: 10, 2: 20, None: 0}) + + expected = pl.Series("a", [10, 20, 20, 0, 0]) + assert_series_equal(result, expected) + + +def test_replace_strict_mapping_null_replace_by_null() -> None: + s = pl.Series("a", [1, 2, 2, None]) + + result = s.replace_strict({1: 10, 2: None, None: 0}) + + expected = pl.Series("a", [10, None, None, 0]) + assert_series_equal(result, expected) + + +def test_replace_strict_mapping_null_with_default() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + + result = s.replace_strict({1: 10}, default=0) + + expected = pl.Series("a", [10, 0, 0, 0, 0]) + assert_series_equal(result, expected) + + +def test_replace_strict_empty() -> None: + lf = pl.LazyFrame({"a": [None, None]}) + result = lf.select(pl.col("a").replace_strict({})) + assert_frame_equal(lf, result) + + +def test_replace_strict_fast_path_many_to_one_default() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace_strict([2, 3], 100, default=-1)) + expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int32}) + assert_frame_equal(result, expected) + + +def test_replace_strict_fast_path_many_to_one_null() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace_strict([2, 3], None, default=-1)) + expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int32}) + assert_frame_equal(result, expected) + + +@pytest.fixture(scope="module") +def str_mapping() -> dict[str | None, str]: + return { + "CA": "Canada", + "DE": "Germany", + "FR": "France", + None: "Not specified", + } + + +def test_replace_strict_str_to_str_default_self( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace_strict( + str_mapping, default=pl.col("country_code") + ) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_str_default_null( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace_strict(str_mapping, default=None) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_str_default_other( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + + result = df.with_row_index().select( + replaced=pl.col("country_code").replace_strict( + str_mapping, default=pl.col("index") + ) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_cat() -> None: + s = pl.Series(["a", "b", "c"]) + mapping = {"a": "c", "b": "d"} + result = s.replace_strict(mapping, default=None, return_dtype=pl.Categorical) + expected = pl.Series(["c", "d", None], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_replace_strict_enum_to_new_enum() -> None: + s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) + old = ["a", "b"] + + new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) + new = pl.Series(["c", "e"], dtype=new_dtype) + + result = s.replace_strict(old, new, default=None, return_dtype=new_dtype) + + expected = pl.Series(["c", "e", None], dtype=new_dtype) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select( + replaced=pl.col("int").replace_strict(mapping, default=pl.lit(6).cast(pl.Int16)) + ) + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} + ) + assert_frame_equal(result, expected) + + +def test_replace_strict_int_to_int_null_default_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select(replaced=pl.col("int").replace_strict(mapping, default=None)) + expected = pl.DataFrame( + {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} + ) + assert_frame_equal(result, expected) + + +def test_replace_strict_int_to_int_null_return_dtype() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + + result = df.select( + replaced=pl.col("int").replace_strict(mapping, default=6, return_dtype=pl.Int32) + ) + + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int32} + ) + assert_frame_equal(result, expected) + + +def test_replace_strict_empty_mapping_default() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping: dict[Any, Any] = {} + result = df.select(pl.col("int").replace_strict(mapping, default=pl.lit("A"))) + expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_int_to_int_df() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) + mapping = {1: 11, 2: 22} + + result = lf.select( + pl.col("a").replace_strict( + old=pl.Series(mapping.keys()), + new=pl.Series(mapping.values(), dtype=pl.UInt8), + default=pl.lit(99).cast(pl.UInt8), + ) + ) + expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_int_fill_null() -> None: + lf = pl.LazyFrame({"a": ["one", "two"]}) + mapping = {"one": 1} + + result = lf.select( + pl.col("a") + .replace_strict(mapping, default=None, return_dtype=pl.UInt32) + .fill_null(999) + ) + + expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) + assert_frame_equal(result, expected) + + +def test_replace_strict_mix() -> None: + df = pl.DataFrame( + [ + pl.Series("float_to_boolean", [1.0, None]), + pl.Series("boolean_to_int", [True, False]), + pl.Series("boolean_to_str", [True, False]), + ] + ) + + result = df.with_columns( + pl.col("float_to_boolean").replace_strict({1.0: True}), + pl.col("boolean_to_int").replace_strict({True: 1, False: 0}), + pl.col("boolean_to_str").replace_strict({True: "1", False: "0"}), + ) + + expected = pl.DataFrame( + [ + pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), + pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), + pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), + ] + ) + assert_frame_equal(result, expected) + + +@pytest.fixture(scope="module") +def int_mapping() -> dict[int, int]: + return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} + + +def test_replace_strict_int_to_int2(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5]) + result = s.replace_strict(int_mapping, default=None) + expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int3(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_strict(int_mapping, default=9) + expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_strict(int_mapping, default=s, return_dtype=pl.Float32) + expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_strict(int_mapping, default=9, return_dtype=pl.Float32) + expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_strict_bool_to_int() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: 1, False: 0} + result = s.replace_strict(mapping) + expected = pl.Series([1, 0, 0, None]) + assert_series_equal(result, expected) + + +def test_replace_strict_bool_to_str() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: "1", False: "0"} + result = s.replace_strict(mapping) + expected = pl.Series(["1", "0", "0", None]) + assert_series_equal(result, expected) + + +def test_replace_strict_str_to_bool() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace_strict(mapping) + expected = pl.Series([True, False, False, None]) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_str() -> None: + s = pl.Series("a", [-1, 2, None, 4, -5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + with pytest.raises(InvalidOperationError, match="incomplete mapping"): + s.replace_strict(mapping) + result = s.replace_strict(mapping, default=None) + + expected = pl.Series("a", [None, "two", None, "four", None]) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_str2() -> None: + s = pl.Series("a", [1, 2, None, 4, 5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace_strict(mapping) + + expected = pl.Series("a", ["one", "two", None, "four", "five"]) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_str_with_default() -> None: + s = pl.Series("a", [1, 2, None, 4, 5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace_strict(mapping, default="?") + + expected = pl.Series("a", ["one", "two", "?", "four", "five"]) + assert_series_equal(result, expected) + + +def test_replace_strict_str_to_int() -> None: + s = pl.Series(["a", "b"]) + mapping = {"a": 1, "b": 2} + result = s.replace_strict(mapping) + expected = pl.Series([1, 2]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("context", "dtype"), + [ + (pl.StringCache(), pl.Categorical), + (pytest.warns(CategoricalRemappingWarning), pl.Categorical), + (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), + ], +) +@pytest.mark.may_fail_auto_streaming +def test_replace_strict_cat_str( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] + dtype: pl.DataType, +) -> None: + with context: + for old, new, expected in [ + ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), + (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + ( + pl.Series(["a", "b"], dtype=dtype), + ["c", "d"], + pl.Series("s", ["c", "d"], dtype=pl.Utf8), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_strict(old, new, default="OTHER") # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +@pytest.mark.parametrize( + "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] +) +@pytest.mark.may_fail_auto_streaming +def test_replace_strict_cat_cat( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] +) -> None: + with context: + dt = pl.Categorical + for old, new, expected in [ + ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), + ( + ["a", "b"], + pl.Series(["c", "d"], dtype=dt), + pl.Series("s", ["c", "d"], dtype=dt), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_strict(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +def test_replace_strict_single_argument_not_mapping() -> None: + df = pl.DataFrame({"a": ["b", "b", "b"]}) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df.select(pl.col("a").replace_strict("b")) + + +def test_replace_strict_unique_22134() -> None: + df = pl.LazyFrame({"mapped_column": ["Jelly", "Soap", "Jelly"]}) + mapping = { + "Jelly": "Jelly", + "Soap": "Soap", + } + df = df.with_columns(pl.col("mapped_column").replace_strict(mapping, default=None)) + df = df.select(["mapped_column"]).unique() + assert_frame_equal( + df.collect(), + pl.DataFrame({"mapped_column": ["Jelly", "Soap"]}), + check_row_order=False, + ) diff --git a/py-polars/tests/unit/operations/test_reshape.py b/py-polars/tests/unit/operations/test_reshape.py new file mode 100644 index 000000000000..d6d000769637 --- /dev/null +++ b/py-polars/tests/unit/operations/test_reshape.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import re + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError +from polars.testing import assert_series_equal + + +def display_shape(shape: tuple[int, ...]) -> str: + return "(" + ", ".join(tuple(str(d) if d >= 0 else "inferred" for d in shape)) + ")" + + +def test_reshape() -> None: + s = pl.Series("a", [1, 2, 3, 4]) + out = s.reshape((-1, 2)) + expected = pl.Series("a", [[1, 2], [3, 4]], dtype=pl.Array(pl.Int64, 2)) + assert_series_equal(out, expected) + out = s.reshape((2, 2)) + assert_series_equal(out, expected) + out = s.reshape((2, -1)) + assert_series_equal(out, expected) + + out = s.reshape((-1, 1)) + expected = pl.Series("a", [[1], [2], [3], [4]], dtype=pl.Array(pl.Int64, 1)) + assert_series_equal(out, expected) + out = s.reshape((4, -1)) + assert_series_equal(out, expected) + out = s.reshape((4, 1)) + assert_series_equal(out, expected) + + # single dimension + out = s.reshape((4,)) + assert_series_equal(out, s) + out = s.reshape((-1,)) + assert_series_equal(out, s) + + # test lazy_dispatch + out = pl.select(pl.lit(s).reshape((-1, 1))).to_series() + assert_series_equal(out, expected) + + # invalid (empty) dimensions + with pytest.raises( + InvalidOperationError, match="at least one dimension must be specified" + ): + s.reshape(()) + + +@pytest.mark.parametrize("shape", [(1, 3), (5, 1), (-1, 5), (3, -1)]) +def test_reshape_invalid_dimension_size(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [1, 2, 3, 4]) + with pytest.raises( + InvalidOperationError, + match=re.escape( + f"cannot reshape array of size 4 into shape {display_shape(shape)}" + ), + ): + s.reshape(shape) + + +def test_reshape_invalid_zero_dimension() -> None: + s = pl.Series("a", [1, 2, 3, 4]) + shape = (-1, 0) + with pytest.raises( + InvalidOperationError, + match=re.escape( + f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}" + ), + ): + s.reshape(shape) + + +@pytest.mark.parametrize("shape", [(0, -1), (0, 4), (0, 0)]) +def test_reshape_invalid_zero_dimension2(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [1, 2, 3, 4]) + with pytest.raises( + InvalidOperationError, + match=re.escape( + f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}" + ), + ): + s.reshape(shape) + + +@pytest.mark.parametrize("shape", [(-1, -1), (-1, -2), (-2, -2)]) +def test_reshape_invalid_multiple_unknown_dims(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [1, 2, 3, 4]) + with pytest.raises( + InvalidOperationError, match="can only specify one inferred dimension" + ): + s.reshape(shape) + + +@pytest.mark.parametrize("shape", [(0,), (-1,), (-2,)]) +def test_reshape_empty_valid_1d(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [], dtype=pl.Int64) + out = s.reshape(shape) + assert_series_equal(out, s) + + +@pytest.mark.parametrize("shape", [(1,), (2,)]) +def test_reshape_empty_invalid_1d(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [], dtype=pl.Int64) + with pytest.raises( + InvalidOperationError, + match=re.escape( + f"cannot reshape empty array into shape without zero dimension: ({shape[0]})" + ), + ): + s.reshape(shape) + + +def test_array_ndarray_reshape() -> None: + shape = (8, 4, 2, 1) + s = pl.Series(range(64)).reshape(shape) + n = s.to_numpy() + assert n.shape == shape + assert (n[0] == s[0].to_numpy()).all() + n = n[0] + s = s[0] + assert (n[0] == s[0].to_numpy()).all() + + +@pytest.mark.parametrize( + "shape", + [ + (0, 1), + (1, 0), + (-1, 10, 20, 10), + (-1, 1, 0), + (10, 1, 0), + (10, 0, 1, 0), + (10, 0, 1), + (42, 2, 3, 4, 0, 2, 3, 4), + (42, 1, 1, 1, 0), + ], +) +def test_reshape_empty(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [], dtype=pl.Int64) + expected_len = max(shape[0], 0) + assert s.reshape(shape).len() == expected_len diff --git a/py-polars/tests/unit/operations/test_rle.py b/py-polars/tests/unit/operations/test_rle.py new file mode 100644 index 000000000000..5c25ac601a48 --- /dev/null +++ b/py-polars/tests/unit/operations/test_rle.py @@ -0,0 +1,38 @@ +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_rle() -> None: + values = [1, 1, 2, 1, None, 1, 3, 3] + lf = pl.LazyFrame({"a": values}) + + expected = pl.LazyFrame( + {"len": [2, 1, 1, 1, 1, 2], "value": [1, 2, 1, None, 1, 3]}, + schema_overrides={"len": pl.get_index_type()}, + ) + + result_expr = lf.select(pl.col("a").rle()).unnest("a") + assert_frame_equal(result_expr, expected) + + result_series = lf.collect().to_series().rle().struct.unnest() + assert_frame_equal(result_series, expected.collect()) + + +def test_rle_id() -> None: + values = [1, 1, 2, 1, None, 1, 3, 3] + lf = pl.LazyFrame({"a": values}) + + expected = pl.LazyFrame( + {"a": [0, 0, 1, 2, 3, 4, 5, 5]}, schema={"a": pl.get_index_type()} + ) + + result_expr = lf.select(pl.col("a").rle_id()) + assert_frame_equal(result_expr, expected) + + result_series = lf.collect().to_series().rle_id() + assert_frame_equal(result_series.to_frame(), expected.collect()) + + +def test_empty_rle_21787() -> None: + assert pl.Series("a", [], pl.Int64).rle().is_empty() + assert pl.Series("a", [], pl.Int64).rle_id().is_empty() diff --git a/py-polars/tests/unit/operations/test_rolling.py b/py-polars/tests/unit/operations/test_rolling.py new file mode 100644 index 000000000000..a7f1ef3601f0 --- /dev/null +++ b/py-polars/tests/unit/operations/test_rolling.py @@ -0,0 +1,632 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import ClosedInterval, PolarsIntegerType + + +def test_rolling() -> None: + dates = [ + "2020-01-01 13:45:48", + "2020-01-01 16:42:13", + "2020-01-01 16:45:09", + "2020-01-02 18:12:48", + "2020-01-03 19:45:32", + "2020-01-08 23:16:43", + ] + + df = ( + pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}) + .with_columns(pl.col("dt").str.strptime(pl.Datetime)) + .set_sorted("dt") + ) + + period: str | timedelta + for period in ("2d", timedelta(days=2)): + out = df.rolling(index_column="dt", period=period).agg( + [ + pl.sum("a").alias("sum_a"), + pl.min("a").alias("min_a"), + pl.max("a").alias("max_a"), + ] + ) + assert out["sum_a"].to_list() == [3, 10, 15, 24, 11, 1] + assert out["max_a"].to_list() == [3, 7, 7, 9, 9, 1] + assert out["min_a"].to_list() == [3, 3, 3, 3, 2, 1] + + +@pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) +def test_rolling_group_by_overlapping_groups(dtype: PolarsIntegerType) -> None: + # this first aggregates overlapping groups so they cannot be naively flattened + df = pl.DataFrame({"a": [41, 60, 37, 51, 52, 39, 40]}) + + assert_series_equal( + ( + df.with_row_index() + .with_columns(pl.col("index").cast(dtype)) + .rolling(index_column="index", period="5i") + .agg( + # trigger the apply on the expression engine + pl.col("a").map_elements(lambda x: x).sum() + ) + )["a"], + df["a"].rolling_sum(window_size=5, min_samples=1), + ) + + +# TODO: This test requires the environment variable to be set prior to starting +# the thread pool, which implies prior to import. The test is only valid when +# run in isolation, and invalid otherwise because of xdist import caching. +# See GH issue #22070 +def test_rolling_group_by_overlapping_groups_21859_a(monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_MAX_THREADS", "1") + # assert pl.thread_pool_size() == 1 # pending resolution, see TODO + df = pl.select( + pl.date_range(pl.date(2023, 1, 1), pl.date(2023, 1, 5)) + ).with_row_index() + + out = df.rolling(index_column="date", period="1y").agg( + a1=pl.when(pl.col("date") >= pl.col("date")) + .then(pl.col("index").cast(pl.Int64).cum_sum()) + .last(), + a2=pl.when(pl.col("date") >= pl.col("date")) + .then(pl.col("index").cast(pl.Int64).cum_sum()) + .last(), + )["a1", "a2"] + expected = pl.DataFrame({"a1": [0, 1, 3, 6, 10], "a2": [0, 1, 3, 6, 10]}) + assert_frame_equal(out, expected) + + +# TODO: This test requires the environment variable to be set prior to starting +# the thread pool, which implies prior to import. The test is only valid when +# run in isolation, and invalid otherwise because of xdist import caching. +# See GH issue #22070 +def test_rolling_group_by_overlapping_groups_21859_b(monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_MAX_THREADS", "1") + # assert pl.thread_pool_size() == 1 # pending resolution, see TODO + df = pl.DataFrame({"a": [20, 30, 40]}) + out = ( + df.with_row_index() + .with_columns(pl.col("index")) + .cast(pl.Int64) + .rolling(index_column="index", period="3i") + .agg( + # trigger the apply on the expression engine + pl.col("a").map_elements(lambda x: x).sum().alias("a1"), + pl.col("a").map_elements(lambda x: x).sum().alias("a2"), + )["a1", "a2"] + ) + expected = pl.DataFrame({"a1": [20, 50, 90], "a2": [20, 50, 90]}) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()]) +@pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) +def test_rolling_agg_input_types(input: Any, dtype: PolarsIntegerType) -> None: + df = pl.LazyFrame( + {"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}, + schema_overrides={"index_column": dtype}, + ).set_sorted("index_column") + result = df.rolling(index_column="index_column", period="2i").agg(input) + expected = pl.LazyFrame( + {"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}, + schema_overrides={"index_column": dtype}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("input", [str, "b".join]) +def test_rolling_agg_bad_input_types(input: Any) -> None: + df = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( + "index_column" + ) + with pytest.raises(TypeError): + df.rolling(index_column="index_column", period="2i").agg(input) + + +def test_rolling_negative_offset_3914() -> None: + df = pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2020, 1, 1), datetime(2020, 1, 5), "1d", eager=True + ), + } + ) + result = df.rolling(index_column="datetime", period="2d", offset="-4d").agg( + pl.len() + ) + assert result["len"].to_list() == [0, 0, 1, 2, 2] + + df = pl.DataFrame({"ints": range(20)}) + + result = df.rolling(index_column="ints", period="2i", offset="-5i").agg( + pl.col("ints").alias("matches") + ) + expected = [ + [], + [], + [], + [0], + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [5, 6], + [6, 7], + [7, 8], + [8, 9], + [9, 10], + [10, 11], + [11, 12], + [12, 13], + [13, 14], + [14, 15], + [15, 16], + ] + assert result["matches"].to_list() == expected + + +@pytest.mark.parametrize("time_zone", [None, "US/Central"]) +def test_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: + df = pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2021, 11, 6), + datetime(2021, 11, 9), + "1d", + time_zone=time_zone, + eager=True, + ), + "value": [1, 4, 9, 155], + } + ) + result = df.rolling(index_column="datetime", period="2d", offset="-1d").agg( + pl.col("value") + ) + expected = pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2021, 11, 6), + datetime(2021, 11, 9), + "1d", + time_zone=time_zone, + eager=True, + ), + "value": [[1, 4], [4, 9], [9, 155], [155]], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("time_zone", [None, "US/Central"]) +@pytest.mark.parametrize( + ("offset", "closed", "expected_values"), + [ + ("0d", "left", [[1, 4], [4, 9], [9, 155], [155]]), + ("0d", "right", [[4, 9], [9, 155], [155], []]), + ("0d", "both", [[1, 4, 9], [4, 9, 155], [9, 155], [155]]), + ("0d", "none", [[4], [9], [155], []]), + ("1d", "left", [[4, 9], [9, 155], [155], []]), + ("1d", "right", [[9, 155], [155], [], []]), + ("1d", "both", [[4, 9, 155], [9, 155], [155], []]), + ("1d", "none", [[9], [155], [], []]), + ], +) +def test_rolling_non_negative_offset_9077( + time_zone: str | None, + offset: str, + closed: ClosedInterval, + expected_values: list[list[int]], +) -> None: + df = pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2021, 11, 6), + datetime(2021, 11, 9), + "1d", + time_zone=time_zone, + eager=True, + ), + "value": [1, 4, 9, 155], + } + ) + result = df.rolling( + index_column="datetime", period="2d", offset=offset, closed=closed + ).agg(pl.col("value")) + expected = pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2021, 11, 6), + datetime(2021, 11, 9), + "1d", + time_zone=time_zone, + eager=True, + ), + "value": expected_values, + } + ) + assert_frame_equal(result, expected) + + +def test_rolling_dynamic_sortedness_check() -> None: + # when the by argument is passed, the sortedness flag + # will be unset as the take shuffles data, so we must explicitly + # check the sortedness + df = pl.DataFrame( + { + "idx": [1, 2, -1, 2, 1, 1], + "group": [1, 1, 1, 2, 2, 1], + } + ) + + with pytest.raises(ComputeError, match=r"input data is not sorted"): + df.rolling("idx", period="2i", group_by="group").agg( + pl.col("idx").alias("idx1") + ) + + # no `group_by` argument + with pytest.raises( + InvalidOperationError, + match="argument in operation 'rolling' is not sorted", + ): + df.rolling("idx", period="2i").agg(pl.col("idx").alias("idx1")) + + +def test_rolling_empty_groups_9973() -> None: + dt1 = date(2001, 1, 1) + dt2 = date(2001, 1, 2) + + data = pl.DataFrame( + { + "id": ["A", "A", "B", "B", "C", "C"], + "date": [dt1, dt2, dt1, dt2, dt1, dt2], + "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + ).sort(by=["id", "date"]) + + expected = pl.DataFrame( + { + "id": ["A", "A", "B", "B", "C", "C"], + "date": [ + date(2001, 1, 1), + date(2001, 1, 2), + date(2001, 1, 1), + date(2001, 1, 2), + date(2001, 1, 1), + date(2001, 1, 2), + ], + "value": [[2.0], [], [4.0], [], [6.0], []], + } + ) + + out = data.rolling( + index_column="date", + group_by="id", + period="2d", + offset="1d", + closed="left", + ).agg(pl.col("value")) + + assert_frame_equal(out, expected) + + +def test_rolling_duplicates_11281() -> None: + df = pl.DataFrame( + { + "ts": [ + datetime(2020, 1, 1), + datetime(2020, 1, 2), + datetime(2020, 1, 2), + datetime(2020, 1, 2), + datetime(2020, 1, 3), + datetime(2020, 1, 4), + ], + "val": [1, 2, 2, 2, 3, 4], + } + ).sort("ts") + result = df.rolling("ts", period="1d", closed="left").agg(pl.col("val")) + expected = df.with_columns(val=pl.Series([[], [1], [1], [1], [2, 2, 2], [3]])) + assert_frame_equal(result, expected) + + +def test_rolling_15225() -> None: + # https://github.com/pola-rs/polars/issues/15225 + + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "c": [1, 1, 2], + } + ) + result = df.rolling("b", period="2d").agg(pl.sum("a")) + expected = pl.DataFrame( + {"b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], "a": [1, 3, 5]} + ) + assert_frame_equal(result, expected) + result = df.rolling("b", period="2d", group_by="c").agg(pl.sum("a")) + expected = pl.DataFrame( + { + "c": [1, 1, 2], + "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)], + "a": [1, 3, 3], + } + ) + assert_frame_equal(result, expected) + + +def test_multiple_rolling_in_single_expression() -> None: + df = pl.DataFrame( + { + "timestamp": pl.datetime_range( + datetime(2024, 1, 12), + datetime(2024, 1, 12, 0, 0, 0, 150_000), + "10ms", + eager=True, + closed="left", + ), + "price": [0] * 15, + } + ) + + front_count = ( + pl.col("price") + .count() + .rolling("timestamp", period=timedelta(milliseconds=100)) + .cast(pl.Int64) + ) + back_count = ( + pl.col("price") + .count() + .rolling("timestamp", period=timedelta(milliseconds=200)) + .cast(pl.Int64) + ) + assert df.with_columns( + back_count.alias("back"), + front_count.alias("front"), + (back_count - front_count).alias("back - front"), + )["back - front"].to_list() == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5] + + +def test_negative_zero_offset_16168() -> None: + df = pl.DataFrame({"foo": [1] * 3}).sort("foo").with_row_index() + result = df.rolling(index_column="foo", period="1i", offset="0i").agg("index") + expected = pl.DataFrame( + {"foo": [1, 1, 1], "index": [[], [], []]}, + schema_overrides={"index": pl.List(pl.UInt32)}, + ) + assert_frame_equal(result, expected) + result = df.rolling(index_column="foo", period="1i", offset="-0i").agg("index") + assert_frame_equal(result, expected) + + +def test_rolling_sorted_empty_groups_16145() -> None: + df = pl.DataFrame( + { + "id": [1, 2], + "time": [ + datetime(year=1989, month=12, day=1, hour=12, minute=3), + datetime(year=1989, month=12, day=1, hour=13, minute=14), + ], + } + ) + + assert ( + df.sort("id") + .rolling( + index_column="time", + group_by="id", + period="1d", + offset="0d", + closed="right", + ) + .agg() + .select("id") + )["id"].to_list() == [1, 2] + + +def test_rolling_by_() -> None: + df = pl.DataFrame({"group": pl.arange(0, 3, eager=True)}).join( + pl.DataFrame( + { + "datetime": pl.datetime_range( + datetime(2020, 1, 1), datetime(2020, 1, 5), "1d", eager=True + ), + } + ), + how="cross", + ) + out = ( + df.sort("datetime") + .rolling(index_column="datetime", group_by="group", period=timedelta(days=3)) + .agg([pl.len().alias("count")]) + ) + + expected = ( + df.sort(["group", "datetime"]) + .rolling(index_column="datetime", group_by="group", period="3d") + .agg([pl.len().alias("count")]) + ) + assert_frame_equal(out.sort(["group", "datetime"]), expected) + assert out.to_dict(as_series=False) == { + "group": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "datetime": [ + datetime(2020, 1, 1, 0, 0), + datetime(2020, 1, 2, 0, 0), + datetime(2020, 1, 3, 0, 0), + datetime(2020, 1, 4, 0, 0), + datetime(2020, 1, 5, 0, 0), + datetime(2020, 1, 1, 0, 0), + datetime(2020, 1, 2, 0, 0), + datetime(2020, 1, 3, 0, 0), + datetime(2020, 1, 4, 0, 0), + datetime(2020, 1, 5, 0, 0), + datetime(2020, 1, 1, 0, 0), + datetime(2020, 1, 2, 0, 0), + datetime(2020, 1, 3, 0, 0), + datetime(2020, 1, 4, 0, 0), + datetime(2020, 1, 5, 0, 0), + ], + "count": [1, 2, 3, 3, 3, 1, 2, 3, 3, 3, 1, 2, 3, 3, 3], + } + + +def test_rolling_group_by_empty_groups_by_take_6330() -> None: + df1 = pl.DataFrame({"Event": ["Rain", "Sun"]}) + df2 = pl.DataFrame({"Date": [1, 2, 3, 4]}) + df = df1.join(df2, how="cross").set_sorted("Date") + + result = df.rolling( + index_column="Date", period="2i", offset="-2i", group_by="Event", closed="left" + ).agg(pl.len()) + + assert result.to_dict(as_series=False) == { + "Event": ["Sun", "Sun", "Sun", "Sun", "Rain", "Rain", "Rain", "Rain"], + "Date": [1, 2, 3, 4, 1, 2, 3, 4], + "len": [0, 1, 2, 2, 0, 1, 2, 2], + } + + +def test_rolling_duplicates() -> None: + df = pl.DataFrame( + { + "ts": [datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 1, 0, 0)], + "value": [0, 1], + } + ) + assert df.sort("ts").with_columns(pl.col("value").rolling_max_by("ts", "1d"))[ + "value" + ].to_list() == [1, 1] + + +def test_rolling_group_by_by_argument() -> None: + df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6}) + + out = df.rolling("times", period="5i", group_by=["groups"]).agg( + pl.col("times").alias("agg_list") + ) + + expected = pl.DataFrame( + { + "groups": [1, 1, 1, 1, 2, 2, 2, 2, 2, 2], + "times": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "agg_list": [ + [0], + [0, 1], + [0, 1, 2], + [0, 1, 2, 3], + [4], + [4, 5], + [4, 5, 6], + [4, 5, 6, 7], + [4, 5, 6, 7, 8], + [5, 6, 7, 8, 9], + ], + } + ) + + assert_frame_equal(out, expected) + + +def test_rolling_by_ordering() -> None: + # we must check that the keys still match the time labels after the rolling window + # with a `by` argument. + df = pl.DataFrame( + { + "dt": [ + datetime(2022, 1, 1, 0, 1), + datetime(2022, 1, 1, 0, 2), + datetime(2022, 1, 1, 0, 3), + datetime(2022, 1, 1, 0, 4), + datetime(2022, 1, 1, 0, 5), + datetime(2022, 1, 1, 0, 6), + datetime(2022, 1, 1, 0, 7), + ], + "key": ["A", "A", "B", "B", "A", "B", "A"], + "val": [1, 1, 1, 1, 1, 1, 1], + } + ).set_sorted("dt") + + assert df.rolling( + index_column="dt", + period="2m", + closed="both", + offset="-1m", + group_by="key", + ).agg( + [ + pl.col("val").sum().alias("sum val"), + ] + ).to_dict(as_series=False) == { + "key": ["A", "A", "A", "A", "B", "B", "B"], + "dt": [ + datetime(2022, 1, 1, 0, 1), + datetime(2022, 1, 1, 0, 2), + datetime(2022, 1, 1, 0, 5), + datetime(2022, 1, 1, 0, 7), + datetime(2022, 1, 1, 0, 3), + datetime(2022, 1, 1, 0, 4), + datetime(2022, 1, 1, 0, 6), + ], + "sum val": [2, 2, 1, 1, 2, 2, 1], + } + + +def test_rolling_bool() -> None: + dates = [ + "2020-01-01 13:45:48", + "2020-01-01 16:42:13", + "2020-01-01 16:45:09", + "2020-01-02 18:12:48", + "2020-01-03 19:45:32", + "2020-01-08 23:16:43", + ] + + df = ( + pl.DataFrame({"dt": dates, "a": [True, False, None, None, True, False]}) + .with_columns(pl.col("dt").str.strptime(pl.Datetime)) + .set_sorted("dt") + ) + + period: str | timedelta + for period in ("2d", timedelta(days=2)): + out = df.rolling(index_column="dt", period=period).agg( + sum_a=pl.col.a.sum(), + min_a=pl.col.a.min(), + max_a=pl.col.a.max(), + sum_a_ref=pl.col.a.cast(pl.Int32).sum(), + min_a_ref=pl.col.a.cast(pl.Int32).min().cast(pl.Boolean), + max_a_ref=pl.col.a.cast(pl.Int32).max().cast(pl.Boolean), + ) + assert out["sum_a"].to_list() == out["sum_a_ref"].to_list() + assert out["max_a"].to_list() == out["max_a_ref"].to_list() + assert out["min_a"].to_list() == out["min_a_ref"].to_list() + + +def test_rolling_var_zero_weight() -> None: + assert_series_equal( + pl.Series([1.0, None, 1.0, 2.0]).rolling_var(2), + pl.Series([None, None, None, 0.5]), + ) + + +def test_rolling_unsupported_22065() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series("a", [[]]).rolling_sum(10) + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series("a", ["1.0"], pl.Decimal).rolling_min(1) + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series("a", [None]).rolling_sum(10) + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series("a", []).rolling_sum(10) + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series("a", [[None]], pl.List(pl.Null)).rolling_sum(10) diff --git a/py-polars/tests/unit/operations/test_search_sorted.py b/py-polars/tests/unit/operations/test_search_sorted.py new file mode 100644 index 000000000000..ff1fdef1c7b9 --- /dev/null +++ b/py-polars/tests/unit/operations/test_search_sorted.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_search_sorted() -> None: + for seed in [1, 2, 3]: + np.random.seed(seed) + arr = np.sort(np.random.randn(10) * 100) + s = pl.Series(arr) + + for v in range(int(np.min(arr)), int(np.max(arr)), 20): + assert np.searchsorted(arr, v) == s.search_sorted(v) + + a = pl.Series([1, 2, 3]) + b = pl.Series([1, 2, 2, -1]) + assert a.search_sorted(b).to_list() == [0, 1, 1, 0] + b = pl.Series([1, 2, 2, None, 3]) + assert a.search_sorted(b).to_list() == [0, 1, 1, 0, 2] + + a = pl.Series(["b", "b", "d", "d"]) + b = pl.Series(["a", "b", "c", "d", "e"]) + assert a.search_sorted(b, side="left").to_list() == [0, 0, 2, 2, 4] + assert a.search_sorted(b, side="right").to_list() == [0, 2, 2, 4, 4] + + a = pl.Series([1, 1, 4, 4]) + b = pl.Series([0, 1, 2, 4, 5]) + assert a.search_sorted(b, side="left").to_list() == [0, 0, 2, 2, 4] + assert a.search_sorted(b, side="right").to_list() == [0, 2, 2, 4, 4] + + +def test_search_sorted_multichunk() -> None: + for seed in [1, 2, 3]: + np.random.seed(seed) + arr = np.sort(np.random.randn(10) * 100) + q = len(arr) // 4 + a, b, c, d = map( + pl.Series, (arr[:q], arr[q : 2 * q], arr[2 * q : 3 * q], arr[3 * q :]) + ) + s = pl.concat([a, b, c, d], rechunk=False) + assert s.n_chunks() == 4 + + for v in range(int(np.min(arr)), int(np.max(arr)), 20): + assert np.searchsorted(arr, v) == s.search_sorted(v) + + a = pl.concat( + [ + pl.Series([None, None, None], dtype=pl.Int64), + pl.Series([None, 1, 1, 2, 3]), + pl.Series([4, 4, 5, 6, 7, 8, 8]), + ], + rechunk=False, + ) + assert a.n_chunks() == 3 + b = pl.Series([-10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, None]) + left_ref = pl.Series( + [4, 4, 4, 6, 7, 8, 10, 11, 12, 13, 15, 0], dtype=pl.get_index_type() + ) + right_ref = pl.Series( + [4, 4, 6, 7, 8, 10, 11, 12, 13, 15, 15, 4], dtype=pl.get_index_type() + ) + assert_series_equal(a.search_sorted(b, side="left"), left_ref) + assert_series_equal(a.search_sorted(b, side="right"), right_ref) + + +def test_search_sorted_right_nulls() -> None: + a = pl.Series([1, 2, None, None]) + assert a.search_sorted(None, side="left") == 2 + assert a.search_sorted(None, side="right") == 4 + + +def test_raise_literal_numeric_search_sorted_18096() -> None: + df = pl.DataFrame({"foo": [1, 4, 7], "bar": [2, 3, 5]}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.with_columns(idx=pl.col("foo").search_sorted("bar")) diff --git a/py-polars/tests/unit/operations/test_select.py b/py-polars/tests/unit/operations/test_select.py new file mode 100644 index 000000000000..d6f07327948a --- /dev/null +++ b/py-polars/tests/unit/operations/test_select.py @@ -0,0 +1,73 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_select_by_col_list(fruits_cars: pl.DataFrame) -> None: + ldf = fruits_cars.lazy() + result = ldf.select(pl.col(["A", "B"]).sum()) + expected = pl.LazyFrame({"A": 15, "B": 15}) + assert_frame_equal(result, expected) + + +def test_select_args_kwargs() -> None: + ldf = pl.LazyFrame({"foo": [1, 2], "bar": [3, 4], "ham": ["a", "b"]}) + + # Single column name + result = ldf.select("foo") + expected = pl.LazyFrame({"foo": [1, 2]}) + assert_frame_equal(result, expected) + + # Column names as list + result = ldf.select(["foo", "bar"]) + expected = pl.LazyFrame({"foo": [1, 2], "bar": [3, 4]}) + assert_frame_equal(result, expected) + + # Column names as positional arguments + result, expected = ldf.select("foo", "bar", "ham"), ldf + assert_frame_equal(result, expected) + + # Keyword arguments + result = ldf.select(oof="foo") + expected = pl.LazyFrame({"oof": [1, 2]}) + assert_frame_equal(result, expected) + + # Mixed + result = ldf.select("bar", "foo", oof="foo") + expected = pl.LazyFrame({"bar": [3, 4], "foo": [1, 2], "oof": [1, 2]}) + assert_frame_equal(result, expected) + + +def test_select_empty() -> None: + result = pl.select() + expected = pl.DataFrame() + assert_frame_equal(result, expected) + + +def test_select_none() -> None: + result = pl.select(None) + expected = pl.select(pl.lit(None)) + assert_frame_equal(result, expected) + + +def test_select_none_combined() -> None: + other = pl.lit(1).alias("one") + + result = pl.select(None, other) + expected = pl.select(pl.lit(None), other) + assert_frame_equal(result, expected) + + result = pl.select(other, None) + expected = pl.select(other, pl.lit(None)) + assert_frame_equal(result, expected) + + +def test_select_empty_list() -> None: + result = pl.select([]) + expected = pl.DataFrame() + assert_frame_equal(result, expected) + + +def test_select_named_inputs_reserved() -> None: + result = pl.select(inputs=1.0, structify=pl.lit("x")) + expected = pl.DataFrame({"inputs": [1.0], "structify": ["x"]}) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_sets.py b/py-polars/tests/unit/operations/test_sets.py new file mode 100644 index 000000000000..8b13de7e4fd4 --- /dev/null +++ b/py-polars/tests/unit/operations/test_sets.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import CategoricalRemappingWarning +from polars.testing import assert_series_equal + + +def test_set_intersection_13765() -> None: + df = pl.DataFrame( + { + "a": pl.Series([[1], [1]], dtype=pl.List(pl.UInt32)), + "f": pl.Series([1, 2], dtype=pl.UInt32), + } + ) + + df = df.join(df, how="cross", suffix="_other") + df = df.filter(pl.col("f") == 1) + + df.select(pl.col("a").list.set_intersection("a_other")).to_dict(as_series=False) + + +def test_set_intersection_st_17129() -> None: + df = pl.DataFrame({"a": [1, 2, 2], "b": [2, 2, 4]}) + + assert df.with_columns( + pl.col("b") + .over("a", mapping_strategy="join") + .list.set_intersection([4, 8]) + .alias("intersect") + ).to_dict(as_series=False) == { + "a": [1, 2, 2], + "b": [2, 2, 4], + "intersect": [[], [4], [4]], + } + + +@pytest.mark.parametrize( + ("set_operation", "outcome"), + [ + ("set_difference", [{"z1", "z"}, {"z"}, set(), {"z", "x2"}, {"z", "x3"}]), + ("set_intersection", [{"x", "y"}, {"y"}, {"y", "x"}, {"x", "y"}, set()]), + ( + "set_symmetric_difference", + [{"z1", "z"}, {"x", "z"}, set(), {"z", "x2"}, {"x", "y", "z", "x3"}], + ), + ], +) +@pytest.mark.may_fail_auto_streaming +def test_set_operations_cats(set_operation: str, outcome: list[set[str]]) -> None: + with pytest.warns(CategoricalRemappingWarning): + df = pl.DataFrame( + { + "a": [ + ["z1", "x", "y", "z"], + ["y", "z"], + ["x", "y"], + ["x", "y", "z", "x2"], + ["z", "x3"], + ] + }, + schema={"a": pl.List(pl.Categorical)}, + ) + df = df.with_columns( + getattr(pl.col("a").list, set_operation)(["x", "y"]).alias("b") + ) + assert df.get_column("b").dtype == pl.List(pl.Categorical) + assert [set(el) for el in df["b"].to_list()] == outcome + + +def test_set_invalid_types() -> None: + df = pl.DataFrame({"a": [1, 2, 2, 3, 3], "b": [2, 2, 4, 7, 8]}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.with_columns( + pl.col("b") + .implode() + .implode() + .over("a", mapping_strategy="join") + .list.set_intersection([1]) + ) + + +@pytest.mark.parametrize("input", [[], [1, 2], [1, None]]) +@pytest.mark.parametrize( + "set_op", + [ + "set_union", + "set_intersection", + "set_difference", + "set_symmetric_difference", + ], +) +def test_set_opts_set_input(input: list[list[int | None]], set_op: str) -> None: + a = pl.Series([[1, 2, 3], [], [None, 3], [5, 6, 7]]) + op = getattr(a.list, set_op) + assert_series_equal(op(input).list.sort(), op(set(input)).list.sort()) diff --git a/py-polars/tests/unit/operations/test_shift.py b/py-polars/tests/unit/operations/test_shift.py new file mode 100644 index 000000000000..134c62e23a72 --- /dev/null +++ b/py-polars/tests/unit/operations/test_shift.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from datetime import date + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_shift() -> None: + a = pl.Series("a", [1, 2, 3]) + assert_series_equal(a.shift(1), pl.Series("a", [None, 1, 2])) + assert_series_equal(a.shift(-1), pl.Series("a", [2, 3, None])) + assert_series_equal(a.shift(-2), pl.Series("a", [3, None, None])) + assert_series_equal(a.shift(-1, fill_value=10), pl.Series("a", [2, 3, 10])) + + +def test_shift_frame(fruits_cars: pl.DataFrame) -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]}) + out = df.select(pl.col("a").shift(1)) + assert_series_equal(out["a"], pl.Series("a", [None, 1, 2, 3, 4])) + + res = fruits_cars.lazy().shift(2).collect() + + expected = pl.DataFrame( + { + "A": [None, None, 1, 2, 3], + "fruits": [None, None, "banana", "banana", "apple"], + "B": [None, None, 5, 4, 3], + "cars": [None, None, "beetle", "audi", "beetle"], + } + ) + assert_frame_equal(res, expected) + + # negative value + res = fruits_cars.lazy().shift(-2).collect() + for rows in [3, 4]: + for cols in range(4): + assert res[rows, cols] is None + + +def test_shift_fill_value() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]}) + + # use exprs + out = ldf.with_columns( + pl.col("a").shift(n=-2, fill_value=pl.col("b").mean()) + ).collect() + assert not out["a"].has_nulls() + + # use df method + out = ldf.shift(n=2, fill_value=pl.col("b").std()).collect() + assert not out["a"].has_nulls() + + +def test_shift_expr() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]}) + + # use exprs + out = ldf.select(pl.col("a").shift(n=pl.col("b").min())).collect() + assert out.to_dict(as_series=False) == {"a": [None, 1, 2, 3, 4]} + + out = ldf.select( + pl.col("a").shift(pl.col("b").min(), fill_value=pl.col("b").max()) + ).collect() + assert out.to_dict(as_series=False) == {"a": [5, 1, 2, 3, 4]} + + # use df method + out = ldf.shift(pl.lit(3)).collect() + assert out.to_dict(as_series=False) == { + "a": [None, None, None, 1, 2], + "b": [None, None, None, 1, 2], + } + out = ldf.shift(pl.lit(2), fill_value=pl.col("b").max()).collect() + assert out.to_dict(as_series=False) == {"a": [5, 5, 1, 2, 3], "b": [5, 5, 1, 2, 3]} + + +@pytest.mark.may_fail_auto_streaming +def test_shift_categorical() -> None: + df = pl.Series("a", ["a", "b"], dtype=pl.Categorical).to_frame() + + s = df.with_columns(pl.col("a").shift(fill_value="c"))["a"] + assert s.dtype == pl.Categorical + assert s.to_list() == ["c", "a"] + + +def test_shift_frame_with_fill() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6, 7, 8], + "ham": ["a", "b", "c"], + } + ) + result = df.shift(fill_value=0) + expected = pl.DataFrame( + { + "foo": [0, 1, 2], + "bar": [0, 6, 7], + "ham": ["0", "a", "b"], + } + ) + assert_frame_equal(result, expected) + + +def test_shift_fill_value_group_logicals() -> None: + df = pl.DataFrame( + [ + (date(2001, 1, 2), "A"), + (date(2001, 1, 3), "A"), + (date(2001, 1, 4), "A"), + (date(2001, 1, 3), "B"), + (date(2001, 1, 4), "B"), + ], + schema=["d", "s"], + orient="row", + ) + result = df.select(pl.col("d").shift(fill_value=pl.col("d").max(), n=-1).over("s")) + + assert result.dtypes == [pl.Date] + + +def test_shift_n_null() -> None: + df = pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int32)}) + out = df.shift(None) # type: ignore[arg-type] + expected = pl.DataFrame({"a": pl.Series([None, None, None], dtype=pl.Int32)}) + assert_frame_equal(out, expected) + + out = df.shift(None, fill_value=1) # type: ignore[arg-type] + assert_frame_equal(out, expected) + + out = df.select(pl.col("a").shift(None)) # type: ignore[arg-type] + assert_frame_equal(out, expected) + + out = df.select(pl.col("a").shift(None, fill_value=1)) # type: ignore[arg-type] + assert_frame_equal(out, expected) + + +def test_shift_n_nonscalar() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + } + ) + with pytest.raises( + ComputeError, + match="'n' must be scalar value", + ): + # Note: Expressions are not in the signature for `n`, but they work. + # We can still verify that n is scalar up-front. + df.shift(pl.col("b"), fill_value=1) # type: ignore[arg-type] + + with pytest.raises( + ComputeError, + match="'n' must be scalar value", + ): + df.select(pl.col("a").shift(pl.col("b"), fill_value=1)) + + +def test_shift_fill_value_nonscalar() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + } + ) + with pytest.raises( + ComputeError, + match="'fill_value' must be scalar value", + ): + df.shift(1, fill_value=pl.col("b")) + + with pytest.raises( + ComputeError, + match="'fill_value' must be scalar value", + ): + df.select(pl.col("a").shift(1, fill_value=pl.col("b"))) diff --git a/py-polars/tests/unit/operations/test_shrink_dtype.py b/py-polars/tests/unit/operations/test_shrink_dtype.py new file mode 100644 index 000000000000..443f55814ec1 --- /dev/null +++ b/py-polars/tests/unit/operations/test_shrink_dtype.py @@ -0,0 +1,46 @@ +import polars as pl + + +def test_shrink_dtype() -> None: + out = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [1, 2, 2 << 32], + "c": [-1, 2, 1 << 30], + "d": [-112, 2, 112], + "e": [-112, 2, 129], + "f": ["a", "b", "c"], + "g": [0.1, 1.32, 0.12], + "h": [True, None, False], + "i": pl.Series([None, None, None], dtype=pl.UInt64), + "j": pl.Series([None, None, None], dtype=pl.Int64), + "k": pl.Series([None, None, None], dtype=pl.Float64), + } + ).select(pl.all().shrink_dtype()) + assert out.dtypes == [ + pl.Int8, + pl.Int64, + pl.Int32, + pl.Int8, + pl.Int16, + pl.String, + pl.Float32, + pl.Boolean, + pl.UInt8, + pl.Int8, + pl.Float32, + ] + + assert out.to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [1, 2, 8589934592], + "c": [-1, 2, 1073741824], + "d": [-112, 2, 112], + "e": [-112, 2, 129], + "f": ["a", "b", "c"], + "g": [0.10000000149011612, 1.3200000524520874, 0.11999999731779099], + "h": [True, None, False], + "i": [None, None, None], + "j": [None, None, None], + "k": [None, None, None], + } diff --git a/py-polars/tests/unit/operations/test_slice.py b/py-polars/tests/unit/operations/test_slice.py new file mode 100644 index 000000000000..61934dfe9585 --- /dev/null +++ b/py-polars/tests/unit/operations/test_slice.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_frame_not_equal + + +def test_tail_union() -> None: + assert ( + pl.concat( + [ + pl.LazyFrame({"a": [1, 2]}), + pl.LazyFrame({"a": [3, 4]}), + pl.LazyFrame({"a": [5, 6]}), + ] + ) + .tail(1) + .collect() + ).to_dict(as_series=False) == {"a": [6]} + + +def test_python_slicing_data_frame() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + expected = pl.DataFrame({"a": [2, 3], "b": ["b", "c"]}) + for slice_params in ( + [1, 10], # slice > len(df) + [1, 2], # slice == len(df) + [1], # optional len + ): + assert_frame_equal(df.slice(*slice_params), expected) + + # Negative starting index before start of dataframe. + expected = pl.DataFrame({"a": [1, 2], "b": ["a", "b"]}) + assert_frame_equal(df.slice(-5, 4), expected) + + for py_slice in ( + slice(1, 2), + slice(0, 2, 2), + slice(3, -3, -1), + slice(1, None, -2), + slice(-1, -3, -1), + slice(-3, None, -3), + ): + # confirm frame slice matches python slice + assert df[py_slice].rows() == df.rows()[py_slice] + + +def test_python_slicing_series() -> None: + s = pl.Series(name="a", values=[0, 1, 2, 3, 4, 5], dtype=pl.UInt8) + for srs_slice, expected in ( + [s.slice(2, 3), [2, 3, 4]], + [s.slice(4, 1), [4]], + [s.slice(4, None), [4, 5]], + [s.slice(3), [3, 4, 5]], + [s.slice(-2), [4, 5]], + [s.slice(-7, 4), [0, 1, 2]], + [s.slice(-700, 4), []], + ): + assert srs_slice.to_list() == expected # type: ignore[attr-defined] + + for py_slice in ( + slice(1, 2), + slice(0, 2, 2), + slice(3, -3, -1), + slice(1, None, -2), + slice(-1, -3, -1), + slice(-3, None, -3), + ): + # confirm series slice matches python slice + assert s[py_slice].to_list() == s.to_list()[py_slice] + + +def test_python_slicing_lazy_frame() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "d"]}) + expected = pl.LazyFrame({"a": [3, 4], "b": ["c", "d"]}) + for slice_params in ( + [2, 10], # slice > len(df) + [2, 4], # slice == len(df) + [2], # optional len + ): + assert_frame_equal(ldf.slice(*slice_params), expected) + + for py_slice in ( + slice(1, 2), + slice(0, 3, 2), + slice(-3, None), + slice(None, 2, 2), + slice(3, None, -1), + slice(1, None, -2), + slice(0, None, -1), + ): + # confirm frame slice matches python slice + assert ldf[py_slice].collect().rows() == ldf.collect().rows()[py_slice] + + assert_frame_equal(ldf[0::-1], ldf.head(1)) + assert_frame_equal(ldf[2::-1], ldf.head(3).reverse()) + assert_frame_equal(ldf[::-1], ldf.reverse()) + assert_frame_equal(ldf[::-2], ldf.reverse().gather_every(2)) + + +def test_head_tail_limit() -> None: + df = pl.DataFrame({"a": range(10), "b": range(10)}) + + assert df.head(5).rows() == [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + assert_frame_equal(df.limit(5), df.head(5)) + assert df.tail(5).rows() == [(5, 5), (6, 6), (7, 7), (8, 8), (9, 9)] + assert_frame_not_equal(df.head(5), df.tail(5)) + + # check if it doesn't fail when out of bounds + assert df.head(100).height == 10 + assert df.limit(100).height == 10 + assert df.tail(100).height == 10 + + # limit is an alias of head + assert_frame_equal(df.head(5), df.limit(5)) + + # negative values + assert df.head(-7).rows() == [(0, 0), (1, 1), (2, 2)] + assert len(df.head(-2)) == 8 + assert df.tail(-8).rows() == [(8, 8), (9, 9)] + assert len(df.tail(-6)) == 4 + + # negative values out of bounds + assert len(df.head(-12)) == 0 + assert len(df.limit(-12)) == 0 + assert len(df.tail(-12)) == 0 + + +def test_hstack_slice_pushdown() -> None: + lf = pl.LazyFrame({f"column_{i}": [i] for i in range(2)}) + + out = lf.with_columns(pl.col("column_0") * 1000).slice(0, 5) + plan = out.explain() + + assert not plan.startswith("SLICE") + + +def test_hconcat_slice_pushdown() -> None: + num_dfs = 3 + lfs = [ + pl.LazyFrame({f"column_{i}": list(range(i, i + 10))}) for i in range(num_dfs) + ] + + out = pl.concat(lfs, how="horizontal").slice(2, 3) + plan = out.explain() + + assert not plan.startswith("SLICE") + + expected = pl.DataFrame( + {f"column_{i}": list(range(i + 2, i + 5)) for i in range(num_dfs)} + ) + + df_out = out.collect() + assert_frame_equal(df_out, expected) + + +@pytest.mark.parametrize( + "ref", + [ + [0, None], # Mixed. + [None, None], # Full-null. + [0, 0], # All-valid. + ], +) +def test_slice_nullcount(ref: list[int | None]) -> None: + ref *= 128 # Embiggen input. + s = pl.Series(ref) + assert s.null_count() == sum(x is None for x in ref) + assert s.slice(64).null_count() == sum(x is None for x in ref[64:]) + assert s.slice(50, 60).slice(25).null_count() == sum(x is None for x in ref[75:110]) + + +def test_slice_pushdown_set_sorted() -> None: + ldf = pl.LazyFrame({"foo": [1, 2, 3]}) + ldf = ldf.set_sorted("foo").head(2) + plan = ldf.explain() + assert "SLICE" not in plan + assert ldf.collect().height == 2 + + +def test_slice_pushdown_literal_projection_14349() -> None: + lf = pl.select(a=pl.int_range(10)).lazy() + expect = pl.DataFrame({"a": [0, 1, 2, 3, 4], "b": [10, 11, 12, 13, 14]}) + + out = lf.with_columns(b=pl.int_range(10, 20, eager=True)).head(5).collect() + assert_frame_equal(expect, out) + + out = lf.select("a", b=pl.int_range(10, 20, eager=True)).head(5).collect() + assert_frame_equal(expect, out) + + assert pl.LazyFrame().select(x=1).head(0).collect().height == 0 + assert pl.LazyFrame().with_columns(x=1).head(0).collect().height == 0 + + q = lf.select(x=1).head(0) + assert q.collect().height == 0 + + # For select, slice pushdown should happen when at least 1 input column is selected + q = lf.select("a", x=1).head(0) + # slice isn't in plan if it has been pushed down to the dataframe + assert "SLICE" not in q.explain() + assert q.collect().height == 0 + + # For with_columns, slice pushdown should happen if the input has at least 1 column + q = lf.with_columns(x=1).head(0) + assert "SLICE" not in q.explain() + assert q.collect().height == 0 + + q = lf.with_columns(pl.col("a") + 1).head(0) + assert "SLICE" not in q.explain() + assert q.collect().height == 0 + + # This does not project any of the original columns + q = lf.with_columns(a=1, b=2).head(0) + plan = q.explain() + assert plan.index("SLICE") < plan.index("WITH_COLUMNS") + assert q.collect().height == 0 + + q = lf.with_columns(b=1, c=2).head(0) + assert "SLICE" not in q.explain() + assert q.collect().height == 0 + + +@pytest.mark.parametrize( + "input_slice", + [ + (-1, None, -1), + (None, 0, -1), + (1, -1, 1), + (None, -1, None), + (1, 2, -1), + (-1, 1, 1), + ], +) +def test_slice_lazy_frame_raises_proper(input_slice: tuple[int | None]) -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3]}) + s = slice(*input_slice) + with pytest.raises(ValueError, match="not supported"): + ldf[s].collect() + + +def test_double_sort_slice_pushdown_15779() -> None: + assert ( + pl.LazyFrame({"foo": [1, 2]}).sort("foo").head(0).sort("foo").collect() + ).shape == (0, 1) + + +def test_slice_pushdown_simple_projection_18288() -> None: + lf = pl.DataFrame({"col": ["0", "notanumber"]}).lazy() + lf = lf.with_columns([pl.col("col").cast(pl.Int64)]) + lf = lf.with_columns([pl.col("col"), pl.lit(None)]) + assert lf.head(1).collect().to_dict(as_series=False) == { + "col": [0], + "literal": [None], + } + + +def test_group_by_slice_all_keys() -> None: + df = pl.DataFrame( + { + "a": ["Tom", "Nick", "Marry", "Krish", "Jack", None], + "b": [ + "2020-01-01", + "2020-01-02", + "2020-01-03", + "2020-01-04", + "2020-01-05", + None, + ], + "c": [5, 6, 6, 7, 8, 5], + } + ) + + gb = df.group_by(["a", "b", "c"], maintain_order=True) + assert_frame_equal(gb.tail(1), gb.head(1)) + + +def test_slice_first_in_agg_18551() -> None: + df = pl.DataFrame({"id": [1, 1, 2], "name": ["A", "B", "C"], "value": [31, 21, 32]}) + + assert df.group_by("id", maintain_order=True).agg( + sort_by=pl.col("name").sort_by("value"), + x=pl.col("name").sort_by("value").slice(0, 1).first(), + y=pl.col("name").sort_by("value").slice(1, 1).first(), + ).to_dict(as_series=False) == { + "id": [1, 2], + "sort_by": [["B", "A"], ["C"]], + "x": ["B", "C"], + "y": ["A", None], + } + + +def test_slice_after_sort_with_nulls_20079() -> None: + df = pl.LazyFrame({"a": [None, 1.2, None]}) + out = df.sort("a", nulls_last=True).slice(0, 10).collect() + expected = pl.DataFrame({"a": [1.2, None, None]}) + assert_frame_equal(out, expected) + + out = df.sort("a", nulls_last=False).slice(0, 10).collect() + expected = pl.DataFrame({"a": [None, None, 1.2]}) + assert_frame_equal(out, expected) + + +def test_slice_pushdown_panic_20216() -> None: + col = pl.col("A") + + q = pl.LazyFrame({"A": "1/1"}) + q = q.with_columns(col.str.split("/")) + q = q.with_columns(pl.when(col.is_not_null()).then(col.list.get(0)).otherwise(None)) + + assert_frame_equal(q.slice(0, 1).collect(), pl.DataFrame({"A": ["1"]})) + assert_frame_equal(q.collect(), pl.DataFrame({"A": ["1"]})) + + +def test_slice_empty_morsel_input() -> None: + q = pl.LazyFrame({"a": []}) + assert_frame_equal(q.slice(999, 3).slice(999, 3).collect(), q.collect().clear()) + assert_frame_equal(q.slice(-999, 3).slice(-999, 3).collect(), q.collect().clear()) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py new file mode 100644 index 000000000000..65d846ab9163 --- /dev/null +++ b/py-polars/tests/unit/operations/test_sort.py @@ -0,0 +1,1193 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Callable + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import dataframes, series + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +@given( + s=series( + excluded_dtypes=[ + pl.Object, # Unsortable type + pl.Struct, # Bug, see: https://github.com/pola-rs/polars/issues/17007 + ], + ) +) +def test_series_sort_idempotent(s: pl.Series) -> None: + result = s.sort() + assert result.len() == s.len() + assert_series_equal(result, result.sort()) + + +@given( + df=dataframes( + excluded_dtypes=[ + pl.Object, # Unsortable type + pl.Null, # Bug, see: https://github.com/pola-rs/polars/issues/17007 + pl.Decimal, # Bug, see: https://github.com/pola-rs/polars/issues/17009 + pl.Categorical( + ordering="lexical" + ), # Bug, see: https://github.com/pola-rs/polars/issues/20364 + ], + ) +) +def test_df_sort_idempotent(df: pl.DataFrame) -> None: + cols = df.columns + result = df.sort(cols, maintain_order=True) + assert result.shape == df.shape + assert_frame_equal(result, result.sort(cols, maintain_order=True)) + + +def test_sort_dates_multiples() -> None: + df = pl.DataFrame( + [ + pl.Series( + "date", + [ + "2021-01-01 00:00:00", + "2021-01-01 00:00:00", + "2021-01-02 00:00:00", + "2021-01-02 00:00:00", + "2021-01-03 00:00:00", + ], + ).str.strptime(pl.Datetime, "%Y-%m-%d %T"), + pl.Series("values", [5, 4, 3, 2, 1]), + ] + ) + + expected = [4, 5, 2, 3, 1] + + # datetime + out: pl.DataFrame = df.sort(["date", "values"]) + assert out["values"].to_list() == expected + + # Date + out = df.with_columns(pl.col("date").cast(pl.Date)).sort(["date", "values"]) + assert out["values"].to_list() == expected + + +@pytest.mark.parametrize( + ("sort_function", "expected"), + [ + ( + lambda x: x, + ([3, 1, 2, 5, 4], [4, 5, 2, 1, 3], [5, 4, 3, 1, 2], [1, 2, 3, 4, 5]), + ), + ( + lambda x: x.sort("b", descending=True, maintain_order=True), + ([3, 1, 2, 5, 4], [4, 5, 2, 1, 3], [5, 4, 3, 1, 2], [1, 2, 3, 4, 5]), + ), + ( + lambda x: x.sort("b", "c", descending=False, maintain_order=True), + ([3, 1, 2, 5, 4], [4, 5, 2, 1, 3], [5, 4, 3, 1, 2], [3, 1, 2, 5, 4]), + ), + ], +) +def test_sort_by( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], + expected: tuple[list[int], list[int], list[int], list[int]], +) -> None: + df = pl.DataFrame( + {"a": [1, 2, 3, 4, 5], "b": [1, 1, 1, 2, 2], "c": [2, 3, 1, 2, 1]} + ) + df = sort_function(df) + by: list[pl.Expr | str] + for by in [["b", "c"], [pl.col("b"), "c"]]: # type: ignore[assignment] + out = df.select(pl.col("a").sort_by(by)) + assert out["a"].to_list() == expected[0] + + # Columns as positional arguments are also accepted + out = df.select(pl.col("a").sort_by("b", "c")) + assert out["a"].to_list() == expected[0] + + out = df.select(pl.col("a").sort_by(by, descending=False)) + assert out["a"].to_list() == expected[0] + + out = df.select(pl.col("a").sort_by(by, descending=True)) + assert out["a"].to_list() == expected[1] + + out = df.select(pl.col("a").sort_by(by, descending=[True, False])) + assert out["a"].to_list() == expected[2] + + # by can also be a single column + out = df.select(pl.col("a").sort_by("b", descending=[False], maintain_order=True)) + assert out["a"].to_list() == expected[3] + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("a", descending=False, maintain_order=True), + lambda x: x.sort("a", descending=True, maintain_order=True), + lambda x: x.sort("a", descending=False, nulls_last=True), + ], +) +def test_expr_sort_by_nulls_last( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], +) -> None: + df = pl.DataFrame({"a": [1, 2, None, None, 5], "b": [None, 1, 1, 2, None]}) + df = sort_function(df) + + # nulls last + out = df.select(pl.all().sort_by("a", nulls_last=True)) + assert out["a"].to_list() == [1, 2, 5, None, None] + # We don't maintain order so there are two possibilities + assert out["b"].to_list()[:3] == [None, 1, None] + assert out["b"].to_list()[3:] in [[1, 2], [2, 1]] + + # nulls first (default) + for out in ( + df.select(pl.all().sort_by("a", nulls_last=False)), + df.select(pl.all().sort_by("a")), + ): + assert out["a"].to_list() == [None, None, 1, 2, 5] + # We don't maintain order so there are two possibilities + assert out["b"].to_list()[2:] == [None, 1, None] + assert out["b"].to_list()[:2] in [[1, 2], [2, 1]] + + +def test_expr_sort_by_multi_nulls_last() -> None: + df = pl.DataFrame({"x": [None, 1, None, 3], "y": [3, 2, None, 1]}) + + res = df.sort("x", "y", nulls_last=[False, True]) + assert res.to_dict(as_series=False) == { + "x": [None, None, 1, 3], + "y": [3, None, 2, 1], + } + + res = df.sort("x", "y", nulls_last=[True, False]) + assert res.to_dict(as_series=False) == { + "x": [1, 3, None, None], + "y": [2, 1, None, 3], + } + + res = df.sort("x", "y", nulls_last=[True, False], descending=True) + assert res.to_dict(as_series=False) == { + "x": [3, 1, None, None], + "y": [1, 2, None, 3], + } + + res = df.sort("x", "y", nulls_last=[False, True], descending=True) + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } + + res = df.sort("x", "y", nulls_last=[False, True], descending=[True, False]) + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } + + +def test_sort_by_exprs() -> None: + # make sure that the expression does not overwrite columns in the dataframe + df = pl.DataFrame({"a": [1, 2, -1, -2]}) + out = df.sort(pl.col("a").abs()).to_series() + + assert out.to_list() == [1, -1, 2, -2] + + +@pytest.mark.parametrize( + ("sort_function", "expected"), + [ + (lambda x: x, ([0, 1, 2, 3, 4], [3, 4, 0, 1, 2])), + ( + lambda x: x.sort(descending=False, nulls_last=True), + ([0, 1, 2, 3, 4], [3, 4, 0, 1, 2]), + ), + ( + lambda x: x.sort(descending=False, nulls_last=False), + ([2, 3, 4, 0, 1], [0, 1, 2, 3, 4]), + ), + ], +) +def test_arg_sort_nulls( + sort_function: Callable[[pl.Series], pl.Series], + expected: tuple[list[int], list[int]], +) -> None: + a = pl.Series("a", [1.0, 2.0, 3.0, None, None]) + + a = sort_function(a) + + assert a.arg_sort(nulls_last=True).to_list() == expected[0] + assert a.arg_sort(nulls_last=False).to_list() == expected[1] + + res = a.to_frame().sort(by="a", nulls_last=False).to_series().to_list() + assert res == [None, None, 1.0, 2.0, 3.0] + + res = a.to_frame().sort(by="a", nulls_last=True).to_series().to_list() + assert res == [1.0, 2.0, 3.0, None, None] + + +@pytest.mark.parametrize( + ("nulls_last", "expected"), + [ + (True, [0, 1, 4, 3, 2]), + (False, [2, 3, 0, 1, 4]), + ([True, False], [0, 1, 4, 2, 3]), + ([False, True], [3, 2, 0, 1, 4]), + ], +) +def test_expr_arg_sort_nulls_last( + nulls_last: bool | list[bool], expected: list[int] +) -> None: + df = pl.DataFrame( + { + "a": [1, 2, None, None, 5], + "b": [1, 2, None, 1, None], + "c": [2, 3, 1, 2, 1], + }, + ) + + out = ( + df.select(pl.arg_sort_by("a", "b", nulls_last=nulls_last, maintain_order=True)) + .to_series() + .to_list() + ) + assert out == expected + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("Id", descending=False, maintain_order=True), + lambda x: x.sort("Id", descending=True, maintain_order=True), + ], +) +def test_arg_sort_window_functions( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], +) -> None: + df = pl.DataFrame({"Id": [1, 1, 2, 2, 3, 3], "Age": [1, 2, 3, 4, 5, 6]}) + df = sort_function(df) + + out = df.select( + pl.col("Age").arg_sort().over("Id").alias("arg_sort"), + pl.arg_sort_by("Age").over("Id").alias("arg_sort_by"), + ) + assert ( + out["arg_sort"].to_list() == out["arg_sort_by"].to_list() == [0, 1, 0, 1, 0, 1] + ) + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("val", descending=False, maintain_order=True), + lambda x: x.sort("val", descending=True, maintain_order=True), + ], +) +def test_sort_nans_3740(sort_function: Callable[[pl.DataFrame], pl.DataFrame]) -> None: + df = pl.DataFrame( + { + "key": [1, 2, 3, 4, 5], + "val": [0.0, None, float("nan"), float("-inf"), float("inf")], + } + ) + df = sort_function(df) + assert df.sort("val")["key"].to_list() == [2, 4, 1, 5, 3] + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("a", descending=False, maintain_order=True), + lambda x: x.sort("a", descending=True, maintain_order=True), + ], +) +def test_sort_by_exps_nulls_last( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], +) -> None: + df = pl.DataFrame({"a": [1, 3, -2, None, 1]}).with_row_index() + df = sort_function(df) + + assert df.sort(pl.col("a") ** 2, nulls_last=True).to_dict(as_series=False) == { + "index": [0, 4, 2, 1, 3], + "a": [1, 1, -2, 3, None], + } + + +def test_sort_aggregation_fast_paths() -> None: + df = pl.DataFrame( + { + "a": [None, 3, 2, 1], + "b": [3, 2, 1, None], + "c": [3, None, None, None], + "e": [None, None, None, 1], + "f": [1, 2, 5, 1], + } + ) + + expected = df.select( + pl.all().max().name.suffix("_max"), + pl.all().min().name.suffix("_min"), + ) + + assert expected.to_dict(as_series=False) == { + "a_max": [3], + "b_max": [3], + "c_max": [3], + "e_max": [1], + "f_max": [5], + "a_min": [1], + "b_min": [1], + "c_min": [3], + "e_min": [1], + "f_min": [1], + } + + for descending in [True, False]: + for null_last in [True, False]: + out = df.select( + pl.all() + .sort(descending=descending, nulls_last=null_last) + .max() + .name.suffix("_max"), + pl.all() + .sort(descending=descending, nulls_last=null_last) + .min() + .name.suffix("_min"), + ) + assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("dtype", [pl.Int8, pl.Int16, pl.Int32, pl.Int64]) +def test_sorted_join_and_dtypes(dtype: PolarsDataType) -> None: + df_a = ( + pl.DataFrame({"a": [-5, -2, 3, 3, 9, 10]}) + .with_row_index() + .with_columns(pl.col("a").cast(dtype).set_sorted()) + ) + + df_b = pl.DataFrame({"a": [-2, -3, 3, 10]}).with_columns( + pl.col("a").cast(dtype).set_sorted() + ) + + result_inner = df_a.join(df_b, on="a", how="inner") + assert_frame_equal( + result_inner, + pl.DataFrame( + { + "index": [1, 2, 3, 5], + "a": [-2, 3, 3, 10], + }, + schema={"index": pl.UInt32, "a": dtype}, + ), + check_row_order=False, + ) + + result_left = df_a.join(df_b, on="a", how="left") + assert_frame_equal( + result_left, + pl.DataFrame( + { + "index": [0, 1, 2, 3, 4, 5], + "a": [-5, -2, 3, 3, 9, 10], + }, + schema={"index": pl.UInt32, "a": dtype}, + ), + check_row_order=False, + ) + + +def test_sorted_fast_paths() -> None: + s = pl.Series([1, 2, 3]).sort() + rev = s.sort(descending=True) + + assert rev.to_list() == [3, 2, 1] + assert s.sort().to_list() == [1, 2, 3] + + s = pl.Series([None, 1, 2, 3]).sort() + rev = s.sort(descending=True) + assert rev.to_list() == [None, 3, 2, 1] + assert rev.sort(descending=True).to_list() == [None, 3, 2, 1] + assert rev.sort().to_list() == [None, 1, 2, 3] + + +@pytest.mark.parametrize( + ("df", "expected"), + [ + ( + pl.DataFrame({"Idx": [0, 1, 2, 3, 4, 5, 6], "Val": [0, 1, 2, 3, 4, 5, 6]}), + ( + [[0, 1, 2, 3, 4, 5, 6]], + [[6, 5, 4, 3, 2, 1, 0]], + [[0, 1, 2, 3, 4, 5, 6]], + [[6, 5, 4, 3, 2, 1, 0]], + ), + ), + ( + pl.DataFrame( + {"Idx": [0, 1, 2, 3, 4, 5, 6], "Val": [0, 1, None, 3, None, 5, 6]} + ), + # We don't use maintain order here, so it might as well do anything + # with the None elements. + ( + [[0, 1, 3, 5, 6, 2, 4], [0, 1, 3, 5, 6, 4, 2]], + [[6, 5, 3, 1, 0, 2, 4], [6, 5, 3, 1, 0, 4, 2]], + [[2, 4, 0, 1, 3, 5, 6], [4, 2, 0, 1, 3, 5, 6]], + [[2, 4, 6, 5, 3, 1, 0], [4, 2, 6, 5, 3, 1, 0]], + ), + ), + ], +) +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x.sort("Val", descending=False, nulls_last=True), + lambda x: x.sort("Val", descending=True, nulls_last=True), + lambda x: x.sort("Val", descending=False, nulls_last=False), + lambda x: x.sort("Val", descending=True, nulls_last=False), + ], +) +def test_sorted_arg_sort_fast_paths( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], + df: pl.DataFrame, + expected: tuple[list[list[int]], list[list[int]], list[list[int]], list[list[int]]], +) -> None: + # Test that an already sorted df is correctly sorted (by a single column) + # In certain cases below we will not go through fast path; this test + # is to assert correctness of sorting. + + df = sort_function(df) + # s will be sorted + s = df["Val"] + + # Test dataframe.sort + assert ( + df.sort("Val", descending=False, nulls_last=True)["Idx"].to_list() + in expected[0] + ) + assert ( + df.sort("Val", descending=True, nulls_last=True)["Idx"].to_list() in expected[1] + ) + assert ( + df.sort("Val", descending=False, nulls_last=False)["Idx"].to_list() + in expected[2] + ) + assert ( + df.sort("Val", descending=True, nulls_last=False)["Idx"].to_list() + in expected[3] + ) + # Test series.arg_sort + assert ( + df["Idx"][s.arg_sort(descending=False, nulls_last=True)].to_list() + in expected[0] + ) + assert ( + df["Idx"][s.arg_sort(descending=True, nulls_last=True)].to_list() in expected[1] + ) + assert ( + df["Idx"][s.arg_sort(descending=False, nulls_last=False)].to_list() + in expected[2] + ) + assert ( + df["Idx"][s.arg_sort(descending=True, nulls_last=False)].to_list() + in expected[3] + ) + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("val", descending=False, maintain_order=True), + ], +) +def test_arg_sort_rank_nans( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], +) -> None: + df = pl.DataFrame( + { + "val": [1.0, float("nan")], + } + ) + df = sort_function(df) + assert ( + df.with_columns( + pl.col("val").rank().alias("rank"), + pl.col("val").arg_sort().alias("arg_sort"), + ).select(["rank", "arg_sort"]) + ).to_dict(as_series=False) == {"rank": [1.0, 2.0], "arg_sort": [0, 1]} + + +def test_set_sorted_schema() -> None: + assert ( + pl.DataFrame({"A": [0, 1]}) + .lazy() + .with_columns(pl.col("A").set_sorted()) + .collect_schema() + ) == {"A": pl.Int64} + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("foo", descending=False, maintain_order=True), + lambda x: x.sort("foo", descending=True, maintain_order=True), + ], +) +def test_sort_slice_fast_path_5245( + sort_function: Callable[[pl.LazyFrame], pl.LazyFrame], +) -> None: + df = pl.DataFrame( + { + "foo": ["f", "c", "b", "a"], + "bar": [1, 2, 3, 4], + } + ).lazy() + df = sort_function(df) + + assert df.sort("foo").limit(1).select("foo").collect().to_dict(as_series=False) == { + "foo": ["a"] + } + + +def test_explicit_list_agg_sort_in_group_by() -> None: + df = pl.DataFrame({"A": ["a", "a", "a", "b", "b", "a"], "B": [1, 2, 3, 4, 5, 6]}) + + # this was col().implode().sort() before we changed the logic + result = df.group_by("A").agg(pl.col("B").sort(descending=True)).sort("A") + expected = df.group_by("A").agg(pl.col("B").sort(descending=True)).sort("A") + assert_frame_equal(result, expected) + + +def test_sorted_join_query_5406() -> None: + df = ( + pl.DataFrame( + { + "Datetime": [ + "2022-11-02 08:00:00", + "2022-11-02 08:00:00", + "2022-11-02 08:01:00", + "2022-11-02 07:59:00", + "2022-11-02 08:02:00", + "2022-11-02 08:02:00", + ], + "Group": ["A", "A", "A", "B", "B", "B"], + "Value": [1, 2, 1, 1, 2, 1], + } + ) + .with_columns(pl.col("Datetime").str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S")) + .with_row_index("RowId") + ) + + df1 = df.sort(by=["Datetime", "RowId"]) + + filter1 = ( + df1.group_by(["Datetime", "Group"]) + .agg([pl.all().sort_by("Value", descending=True).first()]) + .sort(["Datetime", "RowId"]) + ) + + out = df1.join(filter1, on="RowId", how="left").select( + pl.exclude(["Datetime_right", "Group_right"]) + ) + assert_series_equal( + out["Value_right"], + pl.Series("Value_right", [1, None, 2, 1, 2, None]), + check_order=False, + ) + + +def test_sort_by_in_over_5499() -> None: + df = pl.DataFrame( + { + "group": [1, 1, 1, 2, 2, 2], + "idx": pl.arange(0, 6, eager=True), + "a": [1, 3, 2, 3, 1, 2], + } + ) + assert df.select( + pl.col("idx").sort_by("a").over("group").alias("sorted_1"), + pl.col("idx").shift(1).sort_by("a").over("group").alias("sorted_2"), + ).to_dict(as_series=False) == { + "sorted_1": [0, 2, 1, 4, 5, 3], + "sorted_2": [None, 1, 0, 3, 4, None], + } + + +def test_merge_sorted() -> None: + df_a = ( + pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 12, 1), "1mo", eager=True + ) + .to_frame("range") + .with_row_index() + ) + + df_b = ( + pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 12, 1), "2mo", eager=True + ) + .to_frame("range") + .with_row_index() + .with_columns(pl.col("index") * 10) + ) + out = df_a.merge_sorted(df_b, key="range") + assert out["range"].is_sorted() + assert out.to_dict(as_series=False) == { + "index": [0, 0, 1, 2, 10, 3, 4, 20, 5, 6, 30, 7, 8, 40, 9, 10, 50, 11], + "range": [ + datetime(2022, 1, 1, 0, 0), + datetime(2022, 1, 1, 0, 0), + datetime(2022, 2, 1, 0, 0), + datetime(2022, 3, 1, 0, 0), + datetime(2022, 3, 1, 0, 0), + datetime(2022, 4, 1, 0, 0), + datetime(2022, 5, 1, 0, 0), + datetime(2022, 5, 1, 0, 0), + datetime(2022, 6, 1, 0, 0), + datetime(2022, 7, 1, 0, 0), + datetime(2022, 7, 1, 0, 0), + datetime(2022, 8, 1, 0, 0), + datetime(2022, 9, 1, 0, 0), + datetime(2022, 9, 1, 0, 0), + datetime(2022, 10, 1, 0, 0), + datetime(2022, 11, 1, 0, 0), + datetime(2022, 11, 1, 0, 0), + datetime(2022, 12, 1, 0, 0), + ], + } + + +def test_merge_sorted_one_empty() -> None: + df1 = pl.DataFrame({"key": [1, 2, 3], "a": [1, 2, 3]}) + df2 = pl.DataFrame([], schema=df1.schema) + out = df1.merge_sorted(df2, key="a") + assert_frame_equal(out, df1) + out = df2.merge_sorted(df1, key="a") + assert_frame_equal(out, df1) + + +def test_sort_args() -> None: + df = pl.DataFrame( + { + "a": [1, 2, None], + "b": [6.0, 5.0, 4.0], + "c": ["a", "c", "b"], + } + ) + expected = pl.DataFrame( + { + "a": [None, 1, 2], + "b": [4.0, 6.0, 5.0], + "c": ["b", "a", "c"], + } + ) + + # Single column name + result = df.sort("a") + assert_frame_equal(result, expected) + + # Column names as list + result = df.sort(["a", "b"]) + assert_frame_equal(result, expected) + + # Column names as positional arguments + result = df.sort("a", "b") + assert_frame_equal(result, expected) + + # nulls_last + result = df.sort("a", nulls_last=True) + assert_frame_equal(result, df) + + +def test_sort_type_coercion_6892() -> None: + df = pl.DataFrame({"a": [2, 1], "b": [2, 3]}) + assert df.lazy().sort(pl.col("a") // 2).collect().to_dict(as_series=False) == { + "a": [1, 2], + "b": [3, 2], + } + + +@pytest.mark.slow +def test_sort_row_fmt(str_ints_df: pl.DataFrame) -> None: + # we sort nulls_last as this will always dispatch + # to row_fmt and is the default in pandas + + df = str_ints_df + df_pd = df.to_pandas() + + for descending in [True, False]: + assert_frame_equal( + df.sort(["strs", "vals"], nulls_last=True, descending=descending), + pl.from_pandas( + df_pd.sort_values(["strs", "vals"], ascending=not descending) + ), + ) + + +def test_sort_by_logical() -> None: + test = pl.DataFrame( + { + "start": [date(2020, 5, 6), date(2020, 5, 13), date(2020, 5, 10)], + "end": [date(2020, 12, 31), date(2020, 12, 31), date(2021, 1, 1)], + "num": [0, 1, 2], + } + ) + assert test.select([pl.col("num").sort_by(["start", "end"]).alias("n1")])[ + "n1" + ].to_list() == [0, 2, 1] + df = pl.DataFrame( + { + "dt1": [date(2022, 2, 1), date(2022, 3, 1), date(2022, 4, 1)], + "dt2": [date(2022, 2, 2), date(2022, 3, 2), date(2022, 4, 2)], + "name": ["a", "b", "a"], + "num": [3, 4, 1], + } + ) + assert df.group_by("name").agg([pl.col("num").sort_by(["dt1", "dt2"])]).sort( + "name" + ).to_dict(as_series=False) == {"name": ["a", "b"], "num": [[3, 1], [4]]} + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("a", descending=False), + lambda x: x.sort("a", descending=True), + ], +) +def test_limit_larger_than_sort( + sort_function: Callable[[pl.LazyFrame], pl.LazyFrame], +) -> None: + df = pl.LazyFrame({"a": [1]}) + df = sort_function(df) + assert df.sort("a").limit(30).collect().to_dict(as_series=False) == {"a": [1]} + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("st", descending=False), + lambda x: x.sort("st", descending=True), + ], +) +def test_sort_by_struct( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], +) -> None: + df = pl.Series([{"a": 300}, {"a": 20}, {"a": 55}]).to_frame("st").with_row_index() + df = sort_function(df) + assert df.sort("st").to_dict(as_series=False) == { + "index": [1, 2, 0], + "st": [{"a": 20}, {"a": 55}, {"a": 300}], + } + + +def test_sort_descending() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = df.sort(["a", "b"], descending=True) + expected = pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4]}) + assert_frame_equal(result, expected) + result = df.sort(["a", "b"], descending=[True, True]) + assert_frame_equal(result, expected) + with pytest.raises( + ValueError, + match=r"the length of `descending` \(1\) does not match the length of `by` \(2\)", + ): + df.sort(["a", "b"], descending=[True]) + + +def test_sort_by_descending() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = df.select(pl.col("a").sort_by(["a", "b"], descending=True)) + expected = pl.DataFrame({"a": [3, 2, 1]}) + assert_frame_equal(result, expected) + result = df.select(pl.col("a").sort_by(["a", "b"], descending=[True, True])) + assert_frame_equal(result, expected) + with pytest.raises( + ValueError, + match=r"the length of `descending` \(1\) does not match the length of `by` \(2\)", + ): + df.select(pl.col("a").sort_by(["a", "b"], descending=[True])) + + +def test_arg_sort_by_descending() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = df.select(pl.arg_sort_by(["a", "b"], descending=True)) + expected = pl.DataFrame({"a": [2, 1, 0]}).select(pl.col("a").cast(pl.UInt32)) + assert_frame_equal(result, expected) + result = df.select(pl.arg_sort_by(["a", "b"], descending=[True, True])) + assert_frame_equal(result, expected) + with pytest.raises( + ValueError, + match=r"the length of `descending` \(1\) does not match the length of `exprs` \(2\)", + ): + df.select(pl.arg_sort_by(["a", "b"], descending=[True])) + + +def test_arg_sort_struct() -> None: + df = pl.DataFrame( + { + "a": [100, 300, 100, 200, 200, 100, 300, 200, 400, 400], + "b": [5, 5, 6, 7, 8, 1, 1, 2, 2, 3], + } + ) + expected = [5, 0, 2, 7, 3, 4, 6, 1, 8, 9] + assert df.select(pl.struct("a", "b").arg_sort()).to_series().to_list() == expected + + +def test_sort_top_k_fast_path() -> None: + df = pl.DataFrame( + { + "a": [1, 2, None], + "b": [6.0, 5.0, 4.0], + "c": ["a", "c", "b"], + } + ) + # this triggers fast path as head is equal to n-rows + assert df.lazy().sort("b").head(3).collect().to_dict(as_series=False) == { + "a": [None, 2, 1], + "b": [4.0, 5.0, 6.0], + "c": ["b", "c", "a"], + } + + +def test_sort_by_11653() -> None: + df = pl.DataFrame( + { + "id": [0, 0, 0, 0, 0, 1], + "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "other": [0.8, 0.4, 0.5, 0.6, 0.7, 0.8], + } + ) + + assert df.group_by("id").agg( + (pl.col("weights") / pl.col("weights").sum()) + .sort_by("other") + .sum() + .alias("sort_by"), + ).sort("id").to_dict(as_series=False) == {"id": [0, 1], "sort_by": [1.0, 1.0]} + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("bool", descending=False, nulls_last=True), + lambda x: x.sort("bool", descending=True, nulls_last=True), + lambda x: x.sort("bool", descending=False, nulls_last=False), + lambda x: x.sort("bool", descending=True, nulls_last=False), + ], +) +def test_sort_with_null_12139( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], +) -> None: + df = pl.DataFrame( + { + "bool": [True, False, None, True, False], + "float": [1.0, 2.0, 3.0, 4.0, 5.0], + } + ) + df = sort_function(df) + assert df.sort( + "bool", descending=False, nulls_last=False, maintain_order=True + ).to_dict(as_series=False) == { + "bool": [None, False, False, True, True], + "float": [3.0, 2.0, 5.0, 1.0, 4.0], + } + + assert df.sort( + "bool", descending=False, nulls_last=True, maintain_order=True + ).to_dict(as_series=False) == { + "bool": [False, False, True, True, None], + "float": [2.0, 5.0, 1.0, 4.0, 3.0], + } + + assert df.sort( + "bool", descending=True, nulls_last=True, maintain_order=True + ).to_dict(as_series=False) == { + "bool": [True, True, False, False, None], + "float": [1.0, 4.0, 2.0, 5.0, 3.0], + } + + assert df.sort( + "bool", descending=True, nulls_last=False, maintain_order=True + ).to_dict(as_series=False) == { + "bool": [None, True, True, False, False], + "float": [3.0, 1.0, 4.0, 2.0, 5.0], + } + + +def test_sort_with_null_12272() -> None: + df = pl.DataFrame( + { + "a": [1.0, 1.0, 1.0], + "b": [2.0, -1.0, None], + } + ) + out = df.select((pl.col("a") * pl.col("b")).alias("product")) + + assert out.sort("product").to_dict(as_series=False) == { + "product": [None, -1.0, 2.0] + } + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + ([1, None, 3], [1, 3, None]), + ( + [date(2024, 1, 1), None, date(2024, 1, 3)], + [date(2024, 1, 1), date(2024, 1, 3), None], + ), + (["a", None, "c"], ["a", "c", None]), + ], +) +def test_sort_series_nulls_last(input: list[Any], expected: list[Any]) -> None: + assert pl.Series(input).sort(nulls_last=True).to_list() == expected + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("x", descending=False, nulls_last=True), + lambda x: x.sort("x", descending=True, nulls_last=True), + lambda x: x.sort("x", descending=False, nulls_last=False), + lambda x: x.sort("x", descending=True, nulls_last=False), + ], +) +@pytest.mark.parametrize("descending", [True, False]) +@pytest.mark.parametrize("nulls_last", [True, False]) +def test_sort_descending_nulls_last( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], + descending: bool, + nulls_last: bool, +) -> None: + df = pl.DataFrame({"x": [1, 3, None, 2, None], "y": [1, 3, 0, 2, 0]}) + df = sort_function(df) + null_sentinel = 100 if descending ^ nulls_last else -100 + ref_x = [1, 3, None, 2, None] + ref_x.sort(key=lambda k: null_sentinel if k is None else k, reverse=descending) + ref_y = [1, 3, 0, 2, 0] + ref_y.sort(key=lambda k: null_sentinel if k == 0 else k, reverse=descending) + + assert_frame_equal( + df.sort("x", descending=descending, nulls_last=nulls_last), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) + + assert_frame_equal( + df.sort(["x", "y"], descending=descending, nulls_last=nulls_last), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) + + +@pytest.mark.release +def test_sort_nan_1942() -> None: + # https://github.com/pola-rs/polars/issues/1942 + import time + + start = time.time() + pl.repeat(float("nan"), 2**13, eager=True).sort() + end = time.time() + + assert (end - start) < 1.0 + + +@pytest.mark.parametrize( + ("sort_function", "expected"), + [ + (lambda x: x, [1, 3, 0, 2]), + (lambda x: x.sort("values", descending=False), [0, 1, 2, 3]), + (lambda x: x.sort("values", descending=True), [2, 3, 0, 1]), + ], +) +def test_sort_chunked_no_nulls( + sort_function: Callable[[pl.DataFrame], pl.DataFrame], expected: list[int] +) -> None: + df = pl.DataFrame({"values": [3.0, 2.0]}) + df = pl.concat([df, df], rechunk=False) + df = sort_function(df) + + assert df.with_columns(pl.col("values").arg_sort())["values"].to_list() == expected + + +def test_sort_string_nulls() -> None: + str_series = pl.Series( + "b", ["a", None, "c", None, "x", "z", "y", None], dtype=pl.String + ) + assert str_series.sort(descending=False, nulls_last=False).to_list() == [ + None, + None, + None, + "a", + "c", + "x", + "y", + "z", + ] + assert str_series.sort(descending=True, nulls_last=False).to_list() == [ + None, + None, + None, + "z", + "y", + "x", + "c", + "a", + ] + assert str_series.sort(descending=True, nulls_last=True).to_list() == [ + "z", + "y", + "x", + "c", + "a", + None, + None, + None, + ] + assert str_series.sort(descending=False, nulls_last=True).to_list() == [ + "a", + "c", + "x", + "y", + "z", + None, + None, + None, + ] + + +def test_sort_by_unequal_lengths_7207() -> None: + df = pl.DataFrame({"a": [0, 1, 1, 0]}) + result = df.select(pl.arg_sort_by(["a", 1])) + + expected = pl.DataFrame({"a": [0, 3, 1, 2]}) + assert_frame_equal(result, expected, check_dtypes=False) + + +def test_sort_literals() -> None: + df = pl.DataFrame({"foo": [1, 2, 3]}) + s = pl.Series([3, 2, 1]) + assert df.sort([s])["foo"].to_list() == [3, 2, 1] + + with pytest.raises(pl.exceptions.ShapeError): + df.sort(pl.Series(values=[1, 2])) + + +def test_sorted_slice_after_function_20712() -> None: + assert_frame_equal( + pl.LazyFrame({"a": 10 * ["A"]}) + .with_columns(b=pl.col("a").str.extract("(.*)")) + .sort("b") + .head(2) + .collect(), + pl.DataFrame({"a": ["A", "A"], "b": ["A", "A"]}), + ) + + +@pytest.mark.slow +def test_sort_into_function_into_dynamic_groupby_20715() -> None: + assert ( + pl.select( + time=pl.datetime_range( + pl.lit("2025-01-13 00:01:00.000000").str.to_datetime( + "%Y-%m-%d %H:%M:%S%.f" + ), + pl.lit("2025-01-17 00:00:00.000000").str.to_datetime( + "%Y-%m-%d %H:%M:%S%.f" + ), + interval="64m", + ) + .cast(pl.String) + .reverse(), + val=pl.Series(range(90)), + cat=pl.Series(list(range(2)) * 45), + ) + .lazy() + .with_columns( + pl.col("time") + .str.to_datetime("%Y-%m-%d %H:%M:%S%.f", strict=False) + .alias("time2") + ) + .sort("time2") + .group_by_dynamic("time2", every="1m", group_by=["cat"]) + .agg(pl.sum("val")) + .sort("time2") + .collect() + .shape + ) == (90, 3) + + +def test_sort_multicolum_null() -> None: + df = pl.DataFrame({"a": [1], "b": [None]}) + assert df.sort(["a", "b"]).shape == (1, 2) + + +def test_sort_nested_multi_column() -> None: + assert pl.DataFrame({"a": [0, 0], "b": [[2], [1]]}).sort(["a", "b"]).to_dict( + as_series=False + ) == {"a": [0, 0], "b": [[1], [2]]} + + +def test_sort_bool_nulls_last() -> None: + assert_series_equal(pl.Series([False]).sort(nulls_last=True), pl.Series([False])) + assert_series_equal( + pl.Series([None, True, False]).sort(nulls_last=True), + pl.Series([False, True, None]), + ) + assert_series_equal( + pl.Series([None, True, False]).sort(nulls_last=False), + pl.Series([None, False, True]), + ) + assert_series_equal( + pl.Series([None, True, False]).sort(nulls_last=True, descending=True), + pl.Series([True, False, None]), + ) + assert_series_equal( + pl.Series([None, True, False]).sort(nulls_last=False, descending=True), + pl.Series([None, True, False]), + ) + + +@pl.StringCache() +@pytest.mark.parametrize( + "dtype", + [ + pl.Enum(["a", "b"]), + pl.Categorical(ordering="lexical"), + pl.Categorical(ordering="physical"), + ], +) +def test_sort_cat_nulls_last(dtype: PolarsDataType) -> None: + assert_series_equal( + pl.Series(["a"], dtype=dtype).sort(nulls_last=True), + pl.Series(["a"], dtype=dtype), + ) + assert_series_equal( + pl.Series([None, "b", "a"], dtype=dtype).sort(nulls_last=True), + pl.Series(["a", "b", None], dtype=dtype), + ) + assert_series_equal( + pl.Series([None, "b", "a"], dtype=dtype).sort(nulls_last=False), + pl.Series([None, "a", "b"], dtype=dtype), + ) + assert_series_equal( + pl.Series([None, "b", "a"], dtype=dtype).sort(nulls_last=True, descending=True), + pl.Series(["b", "a", None], dtype=dtype), + ) + assert_series_equal( + pl.Series([None, "b", "a"], dtype=dtype).sort( + nulls_last=False, descending=True + ), + pl.Series([None, "b", "a"], dtype=dtype), + ) diff --git a/py-polars/tests/unit/operations/test_statistics.py b/py-polars/tests/unit/operations/test_statistics.py new file mode 100644 index 000000000000..b9cab98ba569 --- /dev/null +++ b/py-polars/tests/unit/operations/test_statistics.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import cast + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_corr() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + result = df.corr() + expected = pl.DataFrame({"a": [1.0]}) + assert_frame_equal(result, expected) + + df = pl.DataFrame( + { + "a": [1, 2, 4], + "b": [-1, 23, 8], + } + ) + result = df.corr() + expected = pl.DataFrame( + { + "a": [1.0, 0.18898223650461357], + "b": [0.1889822365046136, 1.0], + } + ) + assert_frame_equal(result, expected) + + +def test_corr_nan() -> None: + df = pl.DataFrame({"a": [1.0, 1.0], "b": [1.0, 2.0]}) + assert str(df.select(pl.corr("a", "b"))[0, 0]) == "nan" + + +def test_median_quantile_duration() -> None: + df = pl.DataFrame({"A": [timedelta(days=0), timedelta(days=1)]}) + + result = df.select(pl.col("A").median()) + expected = pl.DataFrame({"A": [timedelta(seconds=43200)]}) + assert_frame_equal(result, expected) + + result = df.select(pl.col("A").quantile(0.5, interpolation="linear")) + expected = pl.DataFrame({"A": [timedelta(seconds=43200)]}) + assert_frame_equal(result, expected) + + +def test_correlation_cast_supertype() -> None: + df = pl.DataFrame({"a": [1, 8, 3], "b": [4.0, 5.0, 2.0]}) + df = df.with_columns(pl.col("b")) + assert_frame_equal( + df.select(pl.corr("a", "b")), pl.DataFrame({"a": [0.5447047794019219]}) + ) + + +def test_cov_corr_f32_type() -> None: + df = pl.DataFrame({"a": [1, 8, 3], "b": [4, 5, 2]}).select( + pl.all().cast(pl.Float32) + ) + assert df.select(pl.cov("a", "b")).dtypes == [pl.Float32] + assert df.select(pl.corr("a", "b")).dtypes == [pl.Float32] + + +def test_cov(fruits_cars: pl.DataFrame) -> None: + ldf = fruits_cars.lazy() + for cov_ab in (pl.cov(pl.col("A"), pl.col("B")), pl.cov("A", "B")): + assert cast(float, ldf.select(cov_ab).collect().item()) == -2.5 + + +def test_std(fruits_cars: pl.DataFrame) -> None: + res = fruits_cars.lazy().std().collect() + assert res["A"][0] == pytest.approx(1.5811388300841898) + + +def test_var(fruits_cars: pl.DataFrame) -> None: + res = fruits_cars.lazy().var().collect() + assert res["A"][0] == pytest.approx(2.5) + + +def test_max(fruits_cars: pl.DataFrame) -> None: + assert fruits_cars.lazy().max().collect()["A"][0] == 5 + assert fruits_cars.select(pl.col("A").max())["A"][0] == 5 + + +def test_min(fruits_cars: pl.DataFrame) -> None: + assert fruits_cars.lazy().min().collect()["A"][0] == 1 + assert fruits_cars.select(pl.col("A").min())["A"][0] == 1 + + +def test_median(fruits_cars: pl.DataFrame) -> None: + assert fruits_cars.lazy().median().collect()["A"][0] == 3 + assert fruits_cars.select(pl.col("A").median())["A"][0] == 3 + + +def test_quantile(fruits_cars: pl.DataFrame) -> None: + assert fruits_cars.lazy().quantile(0.25, "nearest").collect()["A"][0] == 2 + assert fruits_cars.select(pl.col("A").quantile(0.25, "nearest"))["A"][0] == 2 + + assert fruits_cars.lazy().quantile(0.24, "lower").collect()["A"][0] == 1 + assert fruits_cars.select(pl.col("A").quantile(0.24, "lower"))["A"][0] == 1 + + assert fruits_cars.lazy().quantile(0.26, "higher").collect()["A"][0] == 3 + assert fruits_cars.select(pl.col("A").quantile(0.26, "higher"))["A"][0] == 3 + + assert fruits_cars.lazy().quantile(0.24, "midpoint").collect()["A"][0] == 1.5 + assert fruits_cars.select(pl.col("A").quantile(0.24, "midpoint"))["A"][0] == 1.5 + + assert fruits_cars.lazy().quantile(0.24, "linear").collect()["A"][0] == 1.96 + assert fruits_cars.select(pl.col("A").quantile(0.24, "linear"))["A"][0] == 1.96 + + +def test_count() -> None: + lf = pl.LazyFrame( + { + "nulls": [None, None, None], + "one_null_str": ["one", None, "three"], + "one_null_float": [1.0, 2.0, None], + "no_nulls_int": [1, 2, 3], + } + ) + df = lf.collect() + + lf_result = lf.count() + df_result = df.count() + + expected = pl.LazyFrame( + { + "nulls": [0], + "one_null_str": [2], + "one_null_float": [2], + "no_nulls_int": [3], + }, + ).cast(pl.UInt32) + assert_frame_equal(lf_result, expected) + assert_frame_equal(df_result, expected.collect()) + + +def test_kurtosis_same_vals() -> None: + df = pl.DataFrame({"a": [1.0042855193121334] * 11}) + assert_frame_equal( + df.select(pl.col("a").kurtosis()), pl.select(a=pl.lit(float("nan"))) + ) + + +def test_correction_shape_mismatch_22080() -> None: + with pytest.raises(pl.exceptions.ShapeError): + pl.select(pl.corr(pl.Series([1, 2]), pl.Series([2, 3, 5]))) diff --git a/py-polars/tests/unit/operations/test_top_k.py b/py-polars/tests/unit/operations/test_top_k.py new file mode 100644 index 000000000000..77314ea39cbb --- /dev/null +++ b/py-polars/tests/unit/operations/test_top_k.py @@ -0,0 +1,546 @@ +from typing import Callable + +import pytest +from hypothesis import given +from hypothesis.strategies import booleans + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series + + +def test_top_k() -> None: + # expression + s = pl.Series("a", [3, 8, 1, 5, 2]) + + assert_series_equal(s.top_k(3), pl.Series("a", [8, 5, 3]), check_order=False) + assert_series_equal(s.bottom_k(4), pl.Series("a", [3, 2, 1, 5]), check_order=False) + + # 5886 + df = pl.DataFrame( + { + "test": [2, 4, 1, 3], + "val": [2, 4, 9, 3], + "bool_val": [False, True, True, False], + "str_value": ["d", "b", "a", "c"], + } + ) + assert_frame_equal( + df.select(pl.col("test").top_k(10)), + pl.DataFrame({"test": [4, 3, 2, 1]}), + check_row_order=False, + ) + + assert_frame_equal( + df.select( + top_k=pl.col("test").top_k(pl.col("val").min()), + bottom_k=pl.col("test").bottom_k(pl.col("val").min()), + ), + pl.DataFrame({"top_k": [4, 3], "bottom_k": [1, 2]}), + check_row_order=False, + ) + + assert_frame_equal( + df.select( + pl.col("bool_val").top_k(2).alias("top_k"), + pl.col("bool_val").bottom_k(2).alias("bottom_k"), + ), + pl.DataFrame({"top_k": [True, True], "bottom_k": [False, False]}), + check_row_order=False, + ) + + assert_frame_equal( + df.select( + pl.col("str_value").top_k(2).alias("top_k"), + pl.col("str_value").bottom_k(2).alias("bottom_k"), + ), + pl.DataFrame({"top_k": ["d", "c"], "bottom_k": ["a", "b"]}), + check_row_order=False, + ) + + with pytest.raises(ComputeError, match="`k` must be set for `top_k`"): + df.select( + pl.col("bool_val").top_k(pl.lit(None)), + ) + + with pytest.raises(ComputeError, match="`k` must be a single value for `top_k`."): + df.select(pl.col("test").top_k(pl.lit(pl.Series("s", [1, 2])))) + + # dataframe + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 2, 2, None], + "b": [None, 2, 1, 4, 3, 2, None], + } + ) + + assert_frame_equal( + df.top_k(3, by=["a", "b"]), + pl.DataFrame({"a": [4, 3, 2], "b": [4, 1, 3]}), + check_row_order=False, + ) + + assert_frame_equal( + df.top_k(3, by=["a", "b"], reverse=True), + pl.DataFrame({"a": [1, 2, 2], "b": [None, 2, 2]}), + check_row_order=False, + ) + assert_frame_equal( + df.bottom_k(4, by=["a", "b"], reverse=True), + pl.DataFrame({"a": [4, 3, 2, 2], "b": [4, 1, 3, 2]}), + check_row_order=False, + ) + + df2 = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6], + "b": [12, 11, 10, 9, 8, 7], + "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"], + } + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b").top_k_by("a", 2).name.suffix("_top_by_a"), + pl.col("a", "b").top_k_by("b", 2).name.suffix("_top_by_b"), + ), + pl.DataFrame( + { + "a_top_by_a": [6, 5], + "b_top_by_a": [7, 8], + "a_top_by_b": [1, 2], + "b_top_by_b": [12, 11], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b").top_k_by("a", 2, reverse=True).name.suffix("_top_by_a"), + pl.col("a", "b").top_k_by("b", 2, reverse=True).name.suffix("_top_by_b"), + ), + pl.DataFrame( + { + "a_top_by_a": [1, 2], + "b_top_by_a": [12, 11], + "a_top_by_b": [6, 5], + "b_top_by_b": [7, 8], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b").bottom_k_by("a", 2).name.suffix("_bottom_by_a"), + pl.col("a", "b").bottom_k_by("b", 2).name.suffix("_bottom_by_b"), + ), + pl.DataFrame( + { + "a_bottom_by_a": [1, 2], + "b_bottom_by_a": [12, 11], + "a_bottom_by_b": [6, 5], + "b_bottom_by_b": [7, 8], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b") + .bottom_k_by("a", 2, reverse=True) + .name.suffix("_bottom_by_a"), + pl.col("a", "b") + .bottom_k_by("b", 2, reverse=True) + .name.suffix("_bottom_by_b"), + ), + pl.DataFrame( + { + "a_bottom_by_a": [6, 5], + "b_bottom_by_a": [7, 8], + "a_bottom_by_b": [1, 2], + "b_bottom_by_b": [12, 11], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.group_by("c", maintain_order=True) + .agg(pl.all().top_k_by("a", 2)) + .explode(pl.all().exclude("c")), + pl.DataFrame( + { + "c": ["Apple", "Apple", "Orange", "Banana", "Banana"], + "a": [4, 3, 2, 6, 5], + "b": [9, 10, 11, 7, 8], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.group_by("c", maintain_order=True) + .agg(pl.all().bottom_k_by("a", 2)) + .explode(pl.all().exclude("c")), + pl.DataFrame( + { + "c": ["Apple", "Apple", "Orange", "Banana", "Banana"], + "a": [1, 3, 2, 5, 6], + "b": [12, 10, 11, 8, 7], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c").top_k_by(["c", "a"], 2).name.suffix("_top_by_ca"), + pl.col("a", "b", "c").top_k_by(["c", "b"], 2).name.suffix("_top_by_cb"), + ), + pl.DataFrame( + { + "a_top_by_ca": [2, 6], + "b_top_by_ca": [11, 7], + "c_top_by_ca": ["Orange", "Banana"], + "a_top_by_cb": [2, 5], + "b_top_by_cb": [11, 8], + "c_top_by_cb": ["Orange", "Banana"], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .bottom_k_by(["c", "a"], 2) + .name.suffix("_bottom_by_ca"), + pl.col("a", "b", "c") + .bottom_k_by(["c", "b"], 2) + .name.suffix("_bottom_by_cb"), + ), + pl.DataFrame( + { + "a_bottom_by_ca": [1, 3], + "b_bottom_by_ca": [12, 10], + "c_bottom_by_ca": ["Apple", "Apple"], + "a_bottom_by_cb": [4, 3], + "b_bottom_by_cb": [9, 10], + "c_bottom_by_cb": ["Apple", "Apple"], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .top_k_by(["c", "a"], 2, reverse=[True, False]) + .name.suffix("_top_by_ca"), + pl.col("a", "b", "c") + .top_k_by(["c", "b"], 2, reverse=[True, False]) + .name.suffix("_top_by_cb"), + ), + pl.DataFrame( + { + "a_top_by_ca": [4, 3], + "b_top_by_ca": [9, 10], + "c_top_by_ca": ["Apple", "Apple"], + "a_top_by_cb": [1, 3], + "b_top_by_cb": [12, 10], + "c_top_by_cb": ["Apple", "Apple"], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .bottom_k_by(["c", "a"], 2, reverse=[True, False]) + .name.suffix("_bottom_by_ca"), + pl.col("a", "b", "c") + .bottom_k_by(["c", "b"], 2, reverse=[True, False]) + .name.suffix("_bottom_by_cb"), + ), + pl.DataFrame( + { + "a_bottom_by_ca": [2, 5], + "b_bottom_by_ca": [11, 8], + "c_bottom_by_ca": ["Orange", "Banana"], + "a_bottom_by_cb": [2, 6], + "b_bottom_by_cb": [11, 7], + "c_bottom_by_cb": ["Orange", "Banana"], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .top_k_by(["c", "a"], 2, reverse=[False, True]) + .name.suffix("_top_by_ca"), + pl.col("a", "b", "c") + .top_k_by(["c", "b"], 2, reverse=[False, True]) + .name.suffix("_top_by_cb"), + ), + pl.DataFrame( + { + "a_top_by_ca": [2, 5], + "b_top_by_ca": [11, 8], + "c_top_by_ca": ["Orange", "Banana"], + "a_top_by_cb": [2, 6], + "b_top_by_cb": [11, 7], + "c_top_by_cb": ["Orange", "Banana"], + } + ), + check_row_order=False, + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .top_k_by(["c", "a"], 2, reverse=[False, True]) + .name.suffix("_bottom_by_ca"), + pl.col("a", "b", "c") + .top_k_by(["c", "b"], 2, reverse=[False, True]) + .name.suffix("_bottom_by_cb"), + ), + pl.DataFrame( + { + "a_bottom_by_ca": [2, 5], + "b_bottom_by_ca": [11, 8], + "c_bottom_by_ca": ["Orange", "Banana"], + "a_bottom_by_cb": [2, 6], + "b_bottom_by_cb": [11, 7], + "c_bottom_by_cb": ["Orange", "Banana"], + } + ), + check_row_order=False, + ) + + with pytest.raises( + ValueError, + match=r"the length of `reverse` \(2\) does not match the length of `by` \(1\)", + ): + df2.select(pl.all().top_k_by("a", 2, reverse=[True, False])) + + with pytest.raises( + ValueError, + match=r"the length of `reverse` \(2\) does not match the length of `by` \(1\)", + ): + df2.select(pl.all().bottom_k_by("a", 2, reverse=[True, False])) + + +def test_top_k_reverse() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = df.top_k(1, by=["a", "b"], reverse=True) + expected = pl.DataFrame({"a": [1], "b": [4]}) + assert_frame_equal(result, expected, check_row_order=False) + result = df.top_k(1, by=["a", "b"], reverse=[True, True]) + assert_frame_equal(result, expected, check_row_order=False) + with pytest.raises( + ValueError, + match=r"the length of `reverse` \(1\) does not match the length of `by` \(2\)", + ): + df.top_k(1, by=["a", "b"], reverse=[True]) + + +def test_top_k_9385() -> None: + lf = pl.LazyFrame({"b": [True, False]}) + result = lf.sort(["b"]).slice(0, 1) + assert result.collect()["b"].to_list() == [False] + + +def test_top_k_empty() -> None: + df = pl.DataFrame({"test": []}) + + assert_frame_equal(df.select([pl.col("test").top_k(2)]), df) + + +@given(s=series(excluded_dtypes=[pl.Null, pl.Struct]), should_sort=booleans()) +def test_top_k_nulls(s: pl.Series, should_sort: bool) -> None: + if should_sort: + s = s.sort() + + valid_count = s.len() - s.null_count() + result = s.top_k(valid_count) + assert result.null_count() == 0 + + result = s.top_k(s.len()) + assert result.null_count() == s.null_count() + + result = s.top_k(s.len() * 2) + assert_series_equal(result, s, check_order=False) + + +@given(s=series(excluded_dtypes=[pl.Null, pl.Struct]), should_sort=booleans()) +def test_bottom_k_nulls(s: pl.Series, should_sort: bool) -> None: + if should_sort: + s = s.sort() + + valid_count = s.len() - s.null_count() + + result = s.bottom_k(valid_count) + assert result.null_count() == 0 + + result = s.bottom_k(s.len()) + assert result.null_count() == s.null_count() + + result = s.bottom_k(s.len() * 2) + assert_series_equal(result, s, check_order=False) + + +def test_top_k_descending_deprecated() -> None: + with pytest.deprecated_call(): + pl.col("a").top_k_by("b", descending=True) # type: ignore[call-arg] + + +@pytest.mark.parametrize( + ("sort_function"), + [ + lambda x: x, + lambda x: x.sort("a", descending=False, maintain_order=True), + lambda x: x.sort("a", descending=True, maintain_order=True), + ], +) +@pytest.mark.parametrize( + ("df", "df2"), + [ + ( + pl.LazyFrame({"a": [3, 4, 1, 2, 5]}), + pl.LazyFrame({"a": [1, None, None, 4, 5]}), + ), + ( + pl.LazyFrame({"a": [3, 4, 1, 2, 5], "b": [1, 2, 3, 4, 5]}), + pl.LazyFrame({"a": [1, None, None, 4, 5], "b": [1, 2, 3, 4, 5]}), + ), + ], +) +def test_top_k_df( + sort_function: Callable[[pl.LazyFrame], pl.LazyFrame], + df: pl.LazyFrame, + df2: pl.LazyFrame, +) -> None: + df = sort_function(df) + expected = [5, 4, 3] + assert df.sort("a", descending=True).limit(3).collect()["a"].to_list() == expected + assert df.top_k(3, by="a").collect()["a"].to_list() == expected + expected = [1, 2, 3] + assert df.sort("a", descending=False).limit(3).collect()["a"].to_list() == expected + assert df.bottom_k(3, by="a").collect()["a"].to_list() == expected + + df = sort_function(df2) + expected2 = [5, 4, 1, None] + assert ( + df.sort("a", descending=True, nulls_last=True).limit(4).collect()["a"].to_list() + == expected2 + ) + assert df.top_k(4, by="a").collect()["a"].to_list() == expected2 + expected2 = [1, 4, 5, None] + assert ( + df.sort("a", descending=False, nulls_last=True) + .limit(4) + .collect()["a"] + .to_list() + == expected2 + ) + assert df.bottom_k(4, by="a").collect()["a"].to_list() == expected2 + + assert df.sort("a", descending=False, nulls_last=False).limit(4).collect()[ + "a" + ].to_list() == [None, None, 1, 4] + assert df.sort("a", descending=True, nulls_last=False).limit(4).collect()[ + "a" + ].to_list() == [None, None, 5, 4] + + +@pytest.mark.parametrize("descending", [True, False]) +def test_sorted_top_k_20719(descending: bool) -> None: + df = pl.DataFrame( + [ + {"a": 1, "b": 1}, + {"a": 5, "b": 5}, + {"a": 9, "b": 9}, + {"a": 10, "b": 20}, + ] + ).sort(by="a", descending=descending) + + # Note: Output stability is guaranteed by the input sortedness as an + # implementation detail. + + for func, reverse in [ + [pl.DataFrame.top_k, False], + [pl.DataFrame.bottom_k, True], + ]: + assert_frame_equal( + df.pipe(func, 2, by="a", reverse=reverse), # type: ignore[arg-type] + pl.DataFrame( + [ + {"a": 10, "b": 20}, + {"a": 9, "b": 9}, + ] + ), + ) + + for func, reverse in [ + [pl.DataFrame.top_k, True], + [pl.DataFrame.bottom_k, False], + ]: + assert_frame_equal( + df.pipe(func, 2, by="a", reverse=reverse), # type: ignore[arg-type] + pl.DataFrame( + [ + {"a": 1, "b": 1}, + {"a": 5, "b": 5}, + ] + ), + ) + + +@pytest.mark.parametrize( + ("func", "reverse", "expect"), + [ + (pl.DataFrame.top_k, False, pl.DataFrame({"a": [2, 2]})), + (pl.DataFrame.bottom_k, True, pl.DataFrame({"a": [2, 2]})), + (pl.DataFrame.top_k, True, pl.DataFrame({"a": [1, 2]})), + (pl.DataFrame.bottom_k, False, pl.DataFrame({"a": [1, 2]})), + ], +) +@pytest.mark.parametrize("descending", [True, False]) +def test_sorted_top_k_duplicates( + func: Callable[[pl.DataFrame], pl.DataFrame], + reverse: bool, + expect: pl.DataFrame, + descending: bool, +) -> None: + assert_frame_equal( + pl.DataFrame({"a": [1, 2, 2]}) # type: ignore[call-arg] + .sort("a", descending=descending) + .pipe(func, 2, by="a", reverse=reverse), + expect, + ) + + +def test_top_k_list_dtype() -> None: + s = pl.Series([[1, 2], [3, 4], [], [0]], dtype=pl.List(pl.Int64)) + assert s.top_k(2).to_list() == [[3, 4], [1, 2]] + + s = pl.Series([[[1, 2], [3]], [[4], []], [[0]]], dtype=pl.List(pl.List(pl.Int64))) + assert s.top_k(2).to_list() == [[[4], []], [[1, 2], [3]]] + + +def test_top_k_sorted_21260() -> None: + s = pl.Series([1, 2, 3, 4, 5]) + assert s.top_k(3).sort().to_list() == [3, 4, 5] + assert s.sort(descending=False).top_k(3).sort().to_list() == [3, 4, 5] + assert s.sort(descending=True).top_k(3).sort().to_list() == [3, 4, 5] + + assert s.bottom_k(3).sort().to_list() == [1, 2, 3] + assert s.sort(descending=False).bottom_k(3).sort().to_list() == [1, 2, 3] + assert s.sort(descending=True).bottom_k(3).sort().to_list() == [1, 2, 3] diff --git a/py-polars/tests/unit/operations/test_transpose.py b/py-polars/tests/unit/operations/test_transpose.py new file mode 100644 index 000000000000..8bc691ed6acd --- /dev/null +++ b/py-polars/tests/unit/operations/test_transpose.py @@ -0,0 +1,226 @@ +import io +from collections.abc import Iterator +from datetime import date, datetime + +import pytest + +import polars as pl +from polars.exceptions import ( + InvalidOperationError, + SchemaError, + StringCacheMismatchError, +) +from polars.testing import assert_frame_equal, assert_series_equal + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_supertype() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["foo", "bar", "ham"]}) + result = df.transpose() + expected = pl.DataFrame( + { + "column_0": ["1", "foo"], + "column_1": ["2", "bar"], + "column_2": ["3", "ham"], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_tz_naive_and_tz_aware() -> None: + df = pl.DataFrame( + { + "a": [datetime(2020, 1, 1)], + "b": [datetime(2020, 1, 1)], + } + ) + df = df.with_columns(pl.col("b").dt.replace_time_zone("Asia/Kathmandu")) + with pytest.raises( + SchemaError, + match=r"failed to determine supertype of datetime\[μs\] and datetime\[μs, Asia/Kathmandu\]", + ): + df.transpose() + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_struct() -> None: + df = pl.DataFrame( + { + "a": ["foo", "bar", "ham"], + "b": [ + {"a": date(2022, 1, 1), "b": True}, + {"a": date(2022, 1, 2), "b": False}, + {"a": date(2022, 1, 3), "b": False}, + ], + } + ) + result = df.transpose() + expected = pl.DataFrame( + { + "column_0": ["foo", "{2022-01-01,true}"], + "column_1": ["bar", "{2022-01-02,false}"], + "column_2": ["ham", "{2022-01-03,false}"], + } + ) + assert_frame_equal(result, expected) + + df = pl.DataFrame( + { + "b": [ + {"a": date(2022, 1, 1), "b": True}, + {"a": date(2022, 1, 2), "b": False}, + {"a": date(2022, 1, 3), "b": False}, + ] + } + ) + result = df.transpose() + expected = pl.DataFrame( + { + "column_0": [{"a": date(2022, 1, 1), "b": True}], + "column_1": [{"a": date(2022, 1, 2), "b": False}], + "column_2": [{"a": date(2022, 1, 3), "b": False}], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_arguments() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + expected = pl.DataFrame( + { + "column": ["a", "b"], + "column_0": [1, 1], + "column_1": [2, 2], + "column_2": [3, 3], + } + ) + out = df.transpose(include_header=True) + assert_frame_equal(expected, out) + + out = df.transpose(include_header=False, column_names=["a", "b", "c"]) + expected = pl.DataFrame( + { + "a": [1, 1], + "b": [2, 2], + "c": [3, 3], + } + ) + assert_frame_equal(expected, out) + + out = df.transpose( + include_header=True, header_name="foo", column_names=["a", "b", "c"] + ) + expected = pl.DataFrame( + { + "foo": ["a", "b"], + "a": [1, 1], + "b": [2, 2], + "c": [3, 3], + } + ) + assert_frame_equal(expected, out) + + def name_generator() -> Iterator[str]: + base_name = "my_column_" + count = 0 + while True: + yield f"{base_name}{count}" + count += 1 + + out = df.transpose(include_header=False, column_names=name_generator()) + expected = pl.DataFrame( + { + "my_column_0": [1, 1], + "my_column_1": [2, 2], + "my_column_2": [3, 3], + } + ) + assert_frame_equal(expected, out) + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_categorical_data() -> None: + with pl.StringCache(): + df = pl.DataFrame( + [ + pl.Series(["a", "b", "c"], dtype=pl.Categorical), + pl.Series(["c", "g", "c"], dtype=pl.Categorical), + pl.Series(["d", "b", "c"], dtype=pl.Categorical), + ] + ) + df_transposed = df.transpose( + include_header=False, column_names=["col1", "col2", "col3"] + ) + assert_series_equal( + df_transposed.get_column("col1"), + pl.Series("col1", ["a", "c", "d"], dtype=pl.Categorical), + ) + + # Without string Cache only works if they have the same categories in the same order + df = pl.DataFrame( + [ + pl.Series(["a", "b", "c", "c"], dtype=pl.Categorical), + pl.Series(["a", "b", "b", "c"], dtype=pl.Categorical), + pl.Series(["a", "a", "b", "c"], dtype=pl.Categorical), + ] + ) + df_transposed = df.transpose( + include_header=False, column_names=["col1", "col2", "col3", "col4"] + ) + + with pytest.raises(StringCacheMismatchError): + pl.DataFrame( + [ + pl.Series(["a", "b", "c", "c"], dtype=pl.Categorical), + pl.Series(["c", "b", "b", "c"], dtype=pl.Categorical), + ] + ).transpose() + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_logical_data() -> None: + df = pl.DataFrame( + { + "a": [date(2022, 2, 1), date(2022, 2, 2), date(2022, 1, 3)], + "b": [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)], + } + ) + result = df.transpose() + expected = pl.DataFrame( + { + "column_0": [datetime(2022, 2, 1, 0, 0), datetime(2022, 1, 1, 0, 0)], + "column_1": [datetime(2022, 2, 2, 0, 0), datetime(2022, 1, 2, 0, 0)], + "column_2": [datetime(2022, 1, 3, 0, 0), datetime(2022, 1, 3, 0, 0)], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_err_transpose_object() -> None: + class CustomObject: + pass + + with pytest.raises(InvalidOperationError): + pl.DataFrame([CustomObject()]).transpose() + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_name_from_column_13777() -> None: + csv_file = io.BytesIO(b"id,kc\nhi,3") + df = pl.read_csv(csv_file).transpose(column_names="id") + assert_series_equal(df.to_series(0), pl.Series("hi", [3])) + + +@pytest.mark.may_fail_auto_streaming +def test_transpose_multiple_chunks() -> None: + df = pl.DataFrame({"a": ["1"]}) + expected = pl.DataFrame({"column_0": ["1"], "column_1": ["1"]}) + assert_frame_equal(df.vstack(df).transpose(), expected) + + +def test_nested_struct_transpose_21923() -> None: + df = pl.DataFrame({"x": [{"a": {"b": 1, "c": 2}}]}) + assert df.transpose().item() == df.item() diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py new file mode 100644 index 000000000000..26cfb98d97b4 --- /dev/null +++ b/py-polars/tests/unit/operations/test_unpivot.py @@ -0,0 +1,121 @@ +import pytest + +import polars as pl +import polars.selectors as cs +from polars import StringCache +from polars.testing import assert_frame_equal + + +def test_unpivot() -> None: + df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5], "C": [2, 4, 6]}) + expected = { + ("a", "B", 1), + ("b", "B", 3), + ("c", "B", 5), + ("a", "C", 2), + ("b", "C", 4), + ("c", "C", 6), + } + for _idv, _vv in (("A", ("B", "C")), (cs.string(), cs.integer())): + unpivoted_eager = df.unpivot(index="A", on=["B", "C"]) + assert set(unpivoted_eager.iter_rows()) == expected + + unpivoted_lazy = df.lazy().unpivot(index="A", on=["B", "C"]).collect() + assert set(unpivoted_lazy.iter_rows()) == expected + + unpivoted = df.unpivot(index="A", on="B") + assert set(unpivoted["value"]) == {1, 3, 5} + + expected_full = { + ("A", "a"), + ("A", "b"), + ("A", "c"), + ("B", "1"), + ("B", "3"), + ("B", "5"), + ("C", "2"), + ("C", "4"), + ("C", "6"), + } + for unpivoted in [df.unpivot(), df.lazy().unpivot().collect()]: + assert set(unpivoted.iter_rows()) == expected_full + + with pytest.deprecated_call(match="unpivot"): + for unpivoted in [ + df.melt(value_name="foo", variable_name="bar"), + df.lazy().melt(value_name="foo", variable_name="bar").collect(), + ]: + assert set(unpivoted.iter_rows()) == expected_full + + +def test_unpivot_projection_pd_7747() -> None: + df = pl.LazyFrame( + { + "number": [1, 2, 1, 2, 1], + "age": [40, 30, 21, 33, 45], + "weight": [100, 103, 95, 90, 110], + } + ) + with pytest.deprecated_call(match="unpivot"): + result = ( + df.with_columns(pl.col("age").alias("wgt")) + .melt(id_vars="number", value_vars="wgt") + .select("number", "value") + .collect() + ) + expected = pl.DataFrame( + { + "number": [1, 2, 1, 2, 1], + "value": [40, 30, 21, 33, 45], + } + ) + assert_frame_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/10075 +def test_unpivot_no_on() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}) + + result = lf.unpivot(index="a") + + expected = pl.LazyFrame( + schema={"a": pl.Int64, "variable": pl.String, "value": pl.Null} + ) + assert_frame_equal(result, expected) + + +def test_unpivot_raise_list() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.LazyFrame( + {"a": ["x", "y"], "b": [["test", "test2"], ["test3", "test4"]]} + ).unpivot().collect() + + +def test_unpivot_empty_18170() -> None: + assert pl.DataFrame().unpivot().schema == pl.Schema( + {"variable": pl.String(), "value": pl.Null()} + ) + + +@StringCache() +def test_unpivot_categorical_global() -> None: + df = pl.DataFrame( + { + "index": [0, 1], + "1": pl.Series(["a", "b"], dtype=pl.Categorical), + "2": pl.Series(["b", "c"], dtype=pl.Categorical), + } + ) + out = df.unpivot(["1", "2"], index="index") + assert out.dtypes == [pl.Int64, pl.String, pl.Categorical(ordering="physical")] + assert out.to_dict(as_series=False) == { + "index": [0, 1, 0, 1], + "variable": ["1", "1", "2", "2"], + "value": ["a", "b", "b", "c"], + } + + +@pytest.mark.may_fail_auto_streaming +def test_unpivot_categorical_raise_19770() -> None: + with pytest.raises(pl.exceptions.ComputeError): + (pl.DataFrame({"x": ["foo"]}).cast(pl.Categorical).unpivot()) diff --git a/py-polars/tests/unit/operations/test_value_counts.py b/py-polars/tests/unit/operations/test_value_counts.py new file mode 100644 index 000000000000..9c1a5ef4b8ed --- /dev/null +++ b/py-polars/tests/unit/operations/test_value_counts.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError +from polars.testing import assert_frame_equal + + +def test_value_counts() -> None: + s = pl.Series("a", [1, 2, 2, 3]) + result = s.value_counts() + expected = pl.DataFrame( + {"a": [1, 2, 3], "count": [1, 2, 1]}, schema_overrides={"count": pl.UInt32} + ) + result_sorted = result.sort("a") + assert_frame_equal(result_sorted, expected) + + out = pl.Series("a", [12, 3345, 12, 3, 4, 4, 1, 12]).value_counts( + normalize=True, sort=True + ) + assert out["proportion"].sum() == 1.0 + assert out.to_dict(as_series=False) == { + "a": [12, 4, 3345, 3, 1], + "proportion": [0.375, 0.25, 0.125, 0.125, 0.125], + } + + +def test_value_counts_logical_type() -> None: + # test logical type + df = pl.DataFrame({"a": ["b", "c"]}).with_columns( + pl.col("a").cast(pl.Categorical).alias("ac") + ) + out = df.select(pl.all().value_counts()) + assert out["ac"].struct.field("ac").dtype == pl.Categorical + assert out["a"].struct.field("a").dtype == pl.String + + +def test_value_counts_expr() -> None: + df = pl.DataFrame( + { + "id": ["a", "b", "b", "c", "c", "c", "d", "d"], + } + ) + out = df.select(pl.col("id").value_counts(sort=True)).to_series().to_list() + assert out == [ + {"id": "c", "count": 3}, + {"id": "b", "count": 2}, + {"id": "d", "count": 2}, + {"id": "a", "count": 1}, + ] + + # nested value counts. Then the series needs the name + # 6200 + + df = pl.DataFrame({"session": [1, 1, 1], "id": [2, 2, 3]}) + + assert df.group_by("session").agg( + pl.col("id").value_counts(sort=True).first() + ).to_dict(as_series=False) == {"session": [1], "id": [{"id": 2, "count": 2}]} + + +def test_value_counts_duplicate_name() -> None: + s = pl.Series("count", [1, 0, 1]) + + # default name is 'count' ... + with pytest.raises( + DuplicateError, + match="duplicate column names; change `name` to fix", + ): + s.value_counts() + + # ... but can customize that + result = s.value_counts(name="n", sort=True) + expected = pl.DataFrame( + {"count": [1, 0], "n": [2, 1]}, schema_overrides={"n": pl.UInt32} + ) + assert_frame_equal(result, expected) + + df = pl.DataFrame({"a": [None, 1, None, 2, 3]}) + result = df.select(pl.col("a").count()) + assert result.item() == 3 + + result = df.group_by(1).agg(pl.col("a").count()) + assert result.to_dict(as_series=False) == {"literal": [1], "a": [3]} + + +def test_count() -> None: + assert pl.Series([None, 1, None, 2, 3]).count() == 3 diff --git a/py-polars/tests/unit/operations/test_window.py b/py-polars/tests/unit/operations/test_window.py new file mode 100644 index 000000000000..f1a70b2719c5 --- /dev/null +++ b/py-polars/tests/unit/operations/test_window.py @@ -0,0 +1,598 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_over_args() -> None: + df = pl.DataFrame( + { + "a": ["a", "a", "b"], + "b": [1, 2, 3], + "c": [3, 2, 1], + } + ) + + # Single input + expected = pl.Series("c", [3, 3, 1]).to_frame() + result = df.select(pl.col("c").max().over("a")) + assert_frame_equal(result, expected) + + # Multiple input as list + expected = pl.Series("c", [3, 2, 1]).to_frame() + result = df.select(pl.col("c").max().over(["a", "b"])) + assert_frame_equal(result, expected) + + # Multiple input as positional args + result = df.select(pl.col("c").max().over("a", "b")) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64, pl.Int32]) +def test_std(dtype: type[pl.DataType]) -> None: + if dtype == pl.Int32: + df = pl.DataFrame( + [ + pl.Series("groups", ["a", "a", "b", "b"]), + pl.Series("values", [1, 2, 3, 4], dtype=dtype), + ] + ) + else: + df = pl.DataFrame( + [ + pl.Series("groups", ["a", "a", "b", "b"]), + pl.Series("values", [1.0, 2.0, 3.0, 4.0], dtype=dtype), + ] + ) + + out = df.select(pl.col("values").std().over("groups")) + assert np.isclose(out["values"][0], 0.7071067690849304) + + out = df.select(pl.col("values").var().over("groups")) + assert np.isclose(out["values"][0], 0.5) + out = df.select(pl.col("values").mean().over("groups")) + assert np.isclose(out["values"][0], 1.5) + + +def test_issue_2529() -> None: + def stdize_out(value: str, control_for: str) -> pl.Expr: + return (pl.col(value) - pl.mean(value).over(control_for)) / pl.std(value).over( + control_for + ) + + df = pl.from_dicts( + [ + {"cat": cat, "val1": cat + _, "val2": cat + _} + for cat in range(2) + for _ in range(2) + ] + ) + + out = df.select( + "*", + stdize_out("val1", "cat").alias("out1"), + stdize_out("val2", "cat").alias("out2"), + ) + assert out["out1"].to_list() == out["out2"].to_list() + + +def test_window_function_cache() -> None: + # ensures that the cache runs the flattened first (that are the sorted groups) + # otherwise the flattened results are not ordered correctly + out = pl.DataFrame( + { + "groups": ["A", "A", "B", "B", "B"], + "groups_not_sorted": ["A", "B", "A", "B", "A"], + "values": range(5), + } + ).with_columns( + pl.col("values") + .over("groups", mapping_strategy="join") + .alias("values_list"), # aggregation to list + join + pl.col("values") + .over("groups", mapping_strategy="explode") + .alias("values_flat"), # aggregation to list + explode and concat back + pl.col("values") + .reverse() + .over("groups", mapping_strategy="explode") + .alias("values_rev"), # use flatten to reverse within a group + ) + + assert out["values_list"].to_list() == [ + [0, 1], + [0, 1], + [2, 3, 4], + [2, 3, 4], + [2, 3, 4], + ] + assert out["values_flat"].to_list() == [0, 1, 2, 3, 4] + assert out["values_rev"].to_list() == [1, 0, 4, 3, 2] + + +def test_window_range_no_rows() -> None: + df = pl.DataFrame({"x": [5, 5, 4, 4, 2, 2]}) + expr = pl.int_range(0, pl.len()).over("x") + out = df.with_columns(int=expr) + assert_frame_equal( + out, pl.DataFrame({"x": [5, 5, 4, 4, 2, 2], "int": [0, 1, 0, 1, 0, 1]}) + ) + + df = pl.DataFrame({"x": []}, schema={"x": pl.Float32}) + out = df.with_columns(int=expr) + + expected = pl.DataFrame(schema={"x": pl.Float32, "int": pl.Int64}) + assert_frame_equal(out, expected) + + +def test_no_panic_on_nan_3067() -> None: + df = pl.DataFrame( + { + "group": ["a", "a", "a", "b", "b", "b"], + "total": [1.0, 2, 3, 4, 5, np.nan], + } + ) + + expected = [None, 1.0, 2.0, None, 4.0, 5.0] + assert ( + df.select([pl.col("total").shift().over("group")])["total"].to_list() + == expected + ) + + +def test_quantile_as_window() -> None: + result = ( + pl.DataFrame( + { + "group": [0, 0, 1, 1], + "value": [0, 1, 0, 2], + } + ) + .select(pl.quantile("value", 0.9).over("group")) + .to_series() + ) + expected = pl.Series("value", [1.0, 1.0, 2.0, 2.0]) + assert_series_equal(result, expected) + + +def test_cumulative_eval_window_functions() -> None: + df = pl.DataFrame( + { + "group": [0, 0, 0, 1, 1, 1], + "val": [20, 40, 30, 2, 4, 3], + } + ) + + result = df.with_columns( + pl.col("val") + .cumulative_eval(pl.element().max()) + .over("group") + .alias("cumulative_eval_max") + ) + expected = pl.DataFrame( + { + "group": [0, 0, 0, 1, 1, 1], + "val": [20, 40, 30, 2, 4, 3], + "cumulative_eval_max": [20, 40, 40, 2, 4, 4], + } + ) + assert_frame_equal(result, expected) + + # 6394 + df = pl.DataFrame({"group": [1, 1, 2, 3], "value": [1, None, 3, None]}) + result = df.select( + pl.col("value").cumulative_eval(pl.element().mean()).over("group") + ) + expected = pl.DataFrame({"value": [1.0, 1.0, 3.0, None]}) + assert_frame_equal(result, expected) + + +def test_len_window() -> None: + assert ( + pl.DataFrame( + { + "a": [1, 1, 2], + } + ) + .with_columns(pl.len().over("a"))["len"] + .to_list() + ) == [2, 2, 1] + + +def test_window_cached_keys_sorted_update_4183() -> None: + df = pl.DataFrame( + { + "customer_ID": ["0", "0", "1"], + "date": [1, 2, 3], + } + ) + result = df.sort(by=["customer_ID", "date"]).select( + pl.count("date").over(pl.col("customer_ID")).alias("count"), + pl.col("date").rank(method="ordinal").over(pl.col("customer_ID")).alias("rank"), + ) + expected = pl.DataFrame( + {"count": [2, 2, 1], "rank": [1, 2, 1]}, + schema={"count": pl.UInt32, "rank": pl.UInt32}, + ) + assert_frame_equal(result, expected) + + +def test_window_functions_list_types() -> None: + df = pl.DataFrame( + { + "col_int": [1, 1, 2, 2], + "col_list": [[1], [1], [2], [2]], + } + ) + assert (df.select(pl.col("col_list").shift(1).alias("list_shifted")))[ + "list_shifted" + ].to_list() == [None, [1], [1], [2]] + + assert (df.select(pl.col("col_list").shift().alias("list_shifted")))[ + "list_shifted" + ].to_list() == [None, [1], [1], [2]] + + assert (df.select(pl.col("col_list").shift(fill_value=[]).alias("list_shifted")))[ + "list_shifted" + ].to_list() == [[], [1], [1], [2]] + + +def test_sorted_window_expression() -> None: + size = 10 + df = pl.DataFrame( + {"a": np.random.randint(10, size=size), "b": np.random.randint(10, size=size)} + ) + expr = (pl.col("a") + pl.col("b")).over("b").alias("computed") + + out1 = df.with_columns(expr).sort("b") + + # explicit sort + df = df.sort("b") + out2 = df.with_columns(expr) + + assert_frame_equal(out1, out2) + + +def test_nested_aggregation_window_expression() -> None: + df = pl.DataFrame( + { + "x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 2, 13, 4, 15, 6, None, None, 19], + "y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + } + ) + + result = df.with_columns( + pl.when(pl.col("x") >= pl.col("x").quantile(0.1)) + .then(1) + .otherwise(None) + .over("y") + .alias("foo") + ) + expected = pl.DataFrame( + { + "x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 2, 13, 4, 15, 6, None, None, 19], + "y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "foo": [None, 1, 1, 1, 1, 1, 1, 1, 1, 1, None, 1, 1, 1, 1, None, None, 1], + }, + # Resulting column is Int32, see https://github.com/pola-rs/polars/issues/8041 + schema_overrides={"foo": pl.Int32}, + ) + assert_frame_equal(result, expected) + + +def test_window_5868() -> None: + df = pl.DataFrame({"value": [None, 2], "id": [None, 1]}) + + result_df = df.with_columns(pl.col("value").max().over("id")) + expected_df = pl.DataFrame({"value": [None, 2], "id": [None, 1]}) + assert_frame_equal(result_df, expected_df) + + df = pl.DataFrame({"a": [None, 1, 2, 3, 3, 3, 4, 4]}) + + result = df.select(pl.col("a").sum().over("a")).get_column("a") + expected = pl.Series("a", [0, 1, 2, 9, 9, 9, 8, 8]) + assert_series_equal(result, expected) + + result = ( + df.with_columns(pl.col("a").set_sorted()) + .select(pl.col("a").sum().over("a")) + .get_column("a") + ) + assert_series_equal(result, expected) + + result = df.drop_nulls().select(pl.col("a").sum().over("a")).get_column("a") + expected = pl.Series("a", [1, 2, 9, 9, 9, 8, 8]) + assert_series_equal(result, expected) + + result = ( + df.drop_nulls() + .with_columns(pl.col("a").set_sorted()) + .select(pl.col("a").sum().over("a")) + .get_column("a") + ) + assert_series_equal(result, expected) + + +def test_window_function_implode_contention_8536() -> None: + df = pl.DataFrame( + data={ + "policy": ["a", "b", "c", "c", "d", "d", "d", "d", "e", "e"], + "memo": ["LE", "RM", "", "", "", "LE", "", "", "", "RM"], + }, + schema={"policy": pl.String, "memo": pl.String}, + ) + + assert df.select( + (pl.lit("LE").is_in(pl.col("memo").over("policy", mapping_strategy="join"))) + | (pl.lit("RM").is_in(pl.col("memo").over("policy", mapping_strategy="join"))) + ).to_series().to_list() == [ + True, + True, + False, + False, + True, + True, + True, + True, + True, + True, + ] + + +def test_cached_windows_sync_8803() -> None: + assert ( + pl.DataFrame( + [ + pl.Series("id", [4, 5, 4, 6, 4, 5], dtype=pl.Int64), + pl.Series( + "is_valid", + [True, False, False, False, False, False], + dtype=pl.Boolean, + ), + ] + ) + .with_columns( + a=pl.lit(True).is_in(pl.col("is_valid")).over("id"), + b=pl.col("is_valid").sum().gt(0).over("id"), + ) + .sum() + ).to_dict(as_series=False) == {"id": [28], "is_valid": [1], "a": [3], "b": [3]} + + +def test_window_filtered_aggregation() -> None: + df = pl.DataFrame( + { + "group": ["A", "A", "B", "B"], + "field1": [2, 4, 6, 8], + "flag": [1, 0, 1, 1], + } + ) + out = df.with_columns( + pl.col("field1").filter(pl.col("flag") == 1).mean().over("group").alias("mean") + ) + expected = pl.DataFrame( + { + "group": ["A", "A", "B", "B"], + "field1": [2, 4, 6, 8], + "flag": [1, 0, 1, 1], + "mean": [2.0, 2.0, 7.0, 7.0], + } + ) + assert_frame_equal(out, expected) + + +def test_window_filtered_false_15483() -> None: + df = pl.DataFrame( + { + "group": ["A", "A"], + "value": [1, 2], + } + ) + out = df.with_columns( + pl.col("value").filter(pl.col("group") != "A").arg_max().over("group") + ) + expected = pl.DataFrame( + { + "group": ["A", "A"], + "value": [None, None], + }, + schema_overrides={"value": pl.UInt32}, + ) + assert_frame_equal(out, expected) + + +def test_window_and_cse_10152() -> None: + q = pl.LazyFrame( + { + "a": [0.0], + "b": [0.0], + } + ) + + q = q.with_columns( + a=pl.col("a").diff().over("a") / pl.col("a"), + b=pl.col("b").diff().over("a") / pl.col("a"), + c=pl.col("a").diff().over("a"), + ) + + assert q.collect().columns == ["a", "b", "c"] + + +def test_window_10417() -> None: + df = pl.DataFrame({"a": [1], "b": [1.2], "c": [2.1]}) + + assert df.lazy().with_columns( + pl.col("b") - pl.col("b").mean().over("a"), + pl.col("c") - pl.col("c").mean().over("a"), + ).collect().to_dict(as_series=False) == {"a": [1], "b": [0.0], "c": [0.0]} + + +def test_window_13173() -> None: + df = pl.DataFrame( + data={ + "color": ["yellow", "yellow"], + "color2": [None, "light"], + "val": ["2", "3"], + } + ) + assert df.with_columns( + pl.min("val").over(["color", "color2"]).alias("min_val_per_color") + ).to_dict(as_series=False) == { + "color": ["yellow", "yellow"], + "color2": [None, "light"], + "val": ["2", "3"], + "min_val_per_color": ["2", "3"], + } + + +def test_window_agg_list_null_15437() -> None: + df = pl.DataFrame({"a": [None]}) + output = df.select(pl.concat_list("a").over(1)) + expected = pl.DataFrame({"a": [[None]]}) + assert_frame_equal(output, expected) + + +@pytest.mark.release +def test_windows_not_cached() -> None: + ldf = ( + pl.DataFrame( + [ + pl.Series("key", ["a", "a", "b", "b"]), + pl.Series("val", [2, 2, 1, 3]), + ] + ) + .lazy() + .filter( + (pl.col("key").cum_count().over("key") == 1) + | (pl.col("val").shift(1).over("key").is_not_null()) + | (pl.col("val") != pl.col("val").shift(1).over("key")) + ) + ) + # this might fail if they are cached + for _ in range(1000): + ldf.collect() + + +def test_window_order_by_8662() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 1, 2, 2, 2, 2], + "t": [1, 2, 3, 4, 4, 1, 2, 3], + "x": [10, 20, 30, 40, 10, 20, 30, 40], + } + ) + + assert df.with_columns( + x_lag0=pl.col("x").shift(1).over("g"), + x_lag1=pl.col("x").shift(1).over("g", order_by="t"), + x_lag2=pl.col("x").shift(1).over("g", order_by="t", descending=True), + ).to_dict(as_series=False) == { + "g": [1, 1, 1, 1, 2, 2, 2, 2], + "t": [1, 2, 3, 4, 4, 1, 2, 3], + "x": [10, 20, 30, 40, 10, 20, 30, 40], + "x_lag0": [None, 10, 20, 30, None, 10, 20, 30], + "x_lag1": [None, 10, 20, 30, 40, None, 20, 30], + "x_lag2": [20, 30, 40, None, None, 30, 40, 10], + } + + +def test_window_chunked_std_17102() -> None: + c1 = pl.DataFrame({"A": [1, 1], "B": [1.0, 2.0]}) + c2 = pl.DataFrame({"A": [2, 2], "B": [1.0, 2.0]}) + + df = pl.concat([c1, c2], rechunk=False) + out = df.select(pl.col("B").std().over("A").alias("std")) + assert out.unique().item() == 0.7071067811865476 + + +def test_window_17308() -> None: + df = pl.DataFrame({"A": [1, 2], "B": [3, 4], "grp": ["A", "B"]}) + + assert df.select(pl.col("A").sum(), pl.col("B").sum().over("grp")).to_dict( + as_series=False + ) == {"A": [3, 3], "B": [3, 4]} + + +def test_lit_window_broadcast() -> None: + # the broadcast should happen in the window function + assert pl.DataFrame({"a": [1, 1, 2]}).select(pl.lit(0).over("a").alias("a"))[ + "a" + ].to_list() == [0, 0, 0] + + +def test_order_by_sorted_keys_18943() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 1], + "t": [4, 3, 2, 1], + "x": [10, 20, 30, 40], + } + ) + + expect = pl.DataFrame({"x": [100, 90, 70, 40]}) + + out = df.select(pl.col("x").cum_sum().over("g", order_by="t")) + assert_frame_equal(out, expect) + + out = df.set_sorted("g").select(pl.col("x").cum_sum().over("g", order_by="t")) + assert_frame_equal(out, expect) + + +def test_nested_window_keys() -> None: + df = pl.DataFrame({"x": 1, "y": "two"}) + assert df.select(pl.col("y").first().over(pl.struct("x").implode())).item() == "two" + assert df.select(pl.col("y").first().over(pl.struct("x"))).item() == "two" + + +def test_window_21692() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 1, 2, 3], + "b": [4, 3, 0, 1, 2, 0], + "c": ["first", "first", "first", "second", "second", "second"], + } + ) + gt0 = pl.col("b") > 0 + assert df.with_columns( + corr=pl.corr( + pl.col("a").filter(gt0), + pl.col("b").filter(gt0), + ).over("c"), + ).to_dict(as_series=False) == { + "a": [1, 2, 3, 1, 2, 3], + "b": [4, 3, 0, 1, 2, 0], + "c": ["first", "first", "first", "second", "second", "second"], + "corr": [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], + } + + +def test_window_implode_explode() -> None: + assert pl.DataFrame( + { + "x": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "y": [2, 2, 2, 3, 3, 3, 4, 4, 4], + } + ).select( + works=(pl.col.x * pl.col.x.implode().explode()).over(pl.col.y), + ).to_dict(as_series=False) == {"works": [1, 4, 9, 1, 4, 9, 1, 4, 9]} + + +def test_window_22006() -> None: + df = pl.DataFrame( + [ + {"a": 1, "b": 1}, + {"a": 1, "b": 2}, + {"a": 2, "b": 3}, + {"a": 2, "b": 4}, + ] + ) + + df_empty = pl.DataFrame([], df.schema) + + df_out = df.select(c=pl.col("b").over("a", mapping_strategy="join")) + + df_empty_out = df_empty.select(c=pl.col("b").over("a", mapping_strategy="join")) + + assert df_out.schema == df_empty_out.schema diff --git a/py-polars/tests/unit/operations/test_with_columns.py b/py-polars/tests/unit/operations/test_with_columns.py new file mode 100644 index 000000000000..e7d39dafed98 --- /dev/null +++ b/py-polars/tests/unit/operations/test_with_columns.py @@ -0,0 +1,183 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_with_columns() -> None: + import datetime + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [0.5, 4, 10, 13], + "c": [True, True, False, True], + } + ) + srs_named = pl.Series("f", [3, 2, 1, 0]) + srs_unnamed = pl.Series(values=[3, 2, 1, 0]) + + expected = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [0.5, 4, 10, 13], + "c": [True, True, False, True], + "d": [0.5, 8.0, 30.0, 52.0], + "e": [False, False, True, False], + "f": [3, 2, 1, 0], + "g": True, + "h": pl.Series(values=[1, 1, 1, 1], dtype=pl.Int32), + "i": 3.2, + "j": [1, 2, 3, 4], + "k": pl.Series(values=[None, None, None, None], dtype=pl.Null), + "l": datetime.datetime(2001, 1, 1, 0, 0), + } + ) + + # as exprs list + dx = df.with_columns( + (pl.col("a") * pl.col("b")).alias("d"), + ~pl.col("c").alias("e"), + srs_named, + pl.lit(True).alias("g"), + pl.lit(1).alias("h"), + pl.lit(3.2).alias("i"), + pl.col("a").alias("j"), + pl.lit(None).alias("k"), + pl.lit(datetime.datetime(2001, 1, 1, 0, 0)).alias("l"), + ) + assert_frame_equal(dx, expected) + + # as positional arguments + dx = df.with_columns( + (pl.col("a") * pl.col("b")).alias("d"), + ~pl.col("c").alias("e"), + srs_named, + pl.lit(True).alias("g"), + pl.lit(1).alias("h"), + pl.lit(3.2).alias("i"), + pl.col("a").alias("j"), + pl.lit(None).alias("k"), + pl.lit(datetime.datetime(2001, 1, 1, 0, 0)).alias("l"), + ) + assert_frame_equal(dx, expected) + + # as keyword arguments + dx = df.with_columns( + d=pl.col("a") * pl.col("b"), + e=~pl.col("c"), + f=srs_unnamed, + g=True, + h=1, + i=3.2, + j="a", # Note: string interpreted as column name, resolves to `pl.col("a")` + k=None, + l=datetime.datetime(2001, 1, 1, 0, 0), + ) + assert_frame_equal(dx, expected) + + # mixed + dx = df.with_columns( + (pl.col("a") * pl.col("b")).alias("d"), + ~pl.col("c").alias("e"), + f=srs_unnamed, + g=True, + h=1, + i=3.2, + j="a", # Note: string interpreted as column name, resolves to `pl.col("a")` + k=None, + l=datetime.datetime(2001, 1, 1, 0, 0), + ) + assert_frame_equal(dx, expected) + + # automatically upconvert multi-output expressions to struct + with pl.Config() as cfg: + cfg.set_auto_structify(True) + + ldf = ( + pl.DataFrame({"x1": [1, 2, 6], "x2": [1, 2, 3]}) + .lazy() + .with_columns( + pl.col(["x1", "x2"]).pct_change().alias("pct_change"), + maxes=pl.all().max().name.suffix("_max"), + xcols=pl.col("^x.*$"), + ) + ) + # ┌─────┬─────┬─────────────┬───────────┬───────────┐ + # │ x1 ┆ x2 ┆ pct_change ┆ maxes ┆ xcols │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ i64 ┆ struct[2] ┆ struct[2] ┆ struct[2] │ + # ╞═════╪═════╪═════════════╪═══════════╪═══════════╡ + # │ 1 ┆ 1 ┆ {null,null} ┆ {6,3} ┆ {1,1} │ + # │ 2 ┆ 2 ┆ {1.0,1.0} ┆ {6,3} ┆ {2,2} │ + # │ 6 ┆ 3 ┆ {2.0,0.5} ┆ {6,3} ┆ {6,3} │ + # └─────┴─────┴─────────────┴───────────┴───────────┘ + assert ldf.collect().to_dicts() == [ + { + "x1": 1, + "x2": 1, + "pct_change": {"x1": None, "x2": None}, + "maxes": {"x1_max": 6, "x2_max": 3}, + "xcols": {"x1": 1, "x2": 1}, + }, + { + "x1": 2, + "x2": 2, + "pct_change": {"x1": 1.0, "x2": 1.0}, + "maxes": {"x1_max": 6, "x2_max": 3}, + "xcols": {"x1": 2, "x2": 2}, + }, + { + "x1": 6, + "x2": 3, + "pct_change": {"x1": 2.0, "x2": 0.5}, + "maxes": {"x1_max": 6, "x2_max": 3}, + "xcols": {"x1": 6, "x2": 3}, + }, + ] + + +def test_with_columns_empty() -> None: + df = pl.DataFrame({"a": [1, 2]}) + result = df.with_columns() + assert_frame_equal(result, df) + + +def test_with_columns_single_series() -> None: + ldf = pl.LazyFrame({"a": [1, 2]}) + result = ldf.with_columns(pl.Series("b", [3, 4])) + + expected = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + assert_frame_equal(result.collect(), expected) + + +def test_with_columns_seq() -> None: + df = pl.DataFrame({"a": [1, 2]}) + result = df.with_columns_seq( + pl.lit(5).alias("b"), + pl.lit("foo").alias("c"), + ) + expected = pl.DataFrame( + { + "a": [1, 2], + "b": pl.Series([5, 5], dtype=pl.Int32), + "c": ["foo", "foo"], + } + ) + assert_frame_equal(result, expected) + + +# https://github.com/pola-rs/polars/issues/15588 +def test_with_columns_invalid_type() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}) + with pytest.raises( + TypeError, match="cannot create expression literal for value of type LazyFrame" + ): + lf.with_columns(lf) # type: ignore[arg-type] + + +def test_with_columns_scalar_20981() -> None: + expected = pl.DataFrame({"a": [2.0, 2.0, 2.0]}) + lf = pl.LazyFrame({"a": [1.0, 2.0, 3.0]}) + assert_frame_equal(lf.with_columns(a=2.0).collect(), expected) + assert_frame_equal(lf.with_columns(pl.col.a.mean()).collect(), expected) diff --git a/py-polars/tests/unit/operations/unique/__init__.py b/py-polars/tests/unit/operations/unique/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/operations/unique/test_approx_n_unique.py b/py-polars/tests/unit/operations/unique/test_approx_n_unique.py new file mode 100644 index 000000000000..35b9c1598366 --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_approx_n_unique.py @@ -0,0 +1,20 @@ +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_df_approx_n_unique_deprecated() -> None: + df = pl.DataFrame({"a": [1, 2, 2], "b": [2, 2, 2]}) + with pytest.deprecated_call(): + result = df.approx_n_unique() + expected = pl.DataFrame({"a": [2], "b": [1]}).cast(pl.UInt32) + assert_frame_equal(result, expected) + + +def test_lf_approx_n_unique_deprecated() -> None: + df = pl.LazyFrame({"a": [1, 2, 2], "b": [2, 2, 2]}) + with pytest.deprecated_call(): + result = df.approx_n_unique() + expected = pl.LazyFrame({"a": [2], "b": [1]}).cast(pl.UInt32) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/unique/test_is_unique.py b/py-polars/tests/unit/operations/unique/test_is_unique.py new file mode 100644 index 000000000000..44923c122374 --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_is_unique.py @@ -0,0 +1,87 @@ +import polars as pl +from polars.testing import assert_series_equal + + +def test_is_unique_series() -> None: + s = pl.Series("a", [1, 2, 2, 3]) + assert_series_equal(s.is_unique(), pl.Series("a", [True, False, False, True])) + + # str + assert pl.Series(["a", "b", "c", "a"]).is_duplicated().to_list() == [ + True, + False, + False, + True, + ] + assert pl.Series(["a", "b", "c", "a"]).is_unique().to_list() == [ + False, + True, + True, + False, + ] + + +def test_is_unique() -> None: + df = pl.DataFrame({"foo": [1, 2, 2], "bar": [6, 7, 7]}) + + assert_series_equal(df.is_unique(), pl.Series("", [True, False, False])) + assert df.unique(maintain_order=True).rows() == [(1, 6), (2, 7)] + assert df.n_unique() == 2 + + +def test_is_unique2() -> None: + df = pl.DataFrame({"a": [4, 1, 4]}) + result = df.select(pl.col("a").is_unique())["a"] + assert_series_equal(result, pl.Series("a", [False, True, False])) + + +def test_is_unique_null() -> None: + s = pl.Series([]) + expected = pl.Series([], dtype=pl.Boolean) + assert_series_equal(s.is_unique(), expected) + + s = pl.Series([None]) + expected = pl.Series([True], dtype=pl.Boolean) + assert_series_equal(s.is_unique(), expected) + + s = pl.Series([None, None, None]) + expected = pl.Series([False, False, False], dtype=pl.Boolean) + assert_series_equal(s.is_unique(), expected) + + +def test_is_unique_struct() -> None: + assert pl.Series( + [{"a": 1, "b": 1}, {"a": 2, "b": 1}, {"a": 1, "b": 1}] + ).is_unique().to_list() == [False, True, False] + assert pl.Series( + [{"a": 1, "b": 1}, {"a": 2, "b": 1}, {"a": 1, "b": 1}] + ).is_duplicated().to_list() == [True, False, True] + + +def test_is_duplicated_series() -> None: + s = pl.Series("a", [1, 2, 2, 3]) + assert_series_equal(s.is_duplicated(), pl.Series("a", [False, True, True, False])) + + +def test_is_duplicated_df() -> None: + df = pl.DataFrame({"foo": [1, 2, 2], "bar": [6, 7, 7]}) + assert_series_equal(df.is_duplicated(), pl.Series("", [False, True, True])) + + +def test_is_duplicated_lf() -> None: + ldf = pl.LazyFrame({"a": [4, 1, 4]}).select(pl.col("a").is_duplicated()) + assert_series_equal(ldf.collect()["a"], pl.Series("a", [True, False, True])) + + +def test_is_duplicated_null() -> None: + s = pl.Series([]) + expected = pl.Series([], dtype=pl.Boolean) + assert_series_equal(s.is_duplicated(), expected) + + s = pl.Series([None]) + expected = pl.Series([False], dtype=pl.Boolean) + assert_series_equal(s.is_duplicated(), expected) + + s = pl.Series([None, None, None]) + expected = pl.Series([True, True, True], dtype=pl.Boolean) + assert_series_equal(s.is_duplicated(), expected) diff --git a/py-polars/tests/unit/operations/unique/test_n_unique.py b/py-polars/tests/unit/operations/unique/test_n_unique.py new file mode 100644 index 000000000000..34e441ce42be --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_n_unique.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_n_unique() -> None: + s = pl.Series("s", [11, 11, 11, 22, 22, 33, None, None, None]) + assert s.n_unique() == 4 + + +def test_n_unique_subsets() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 3, 4, 5], + "b": [0.5, 0.5, 1.0, 2.0, 3.0, 3.0], + "c": [True, True, True, False, True, True], + } + ) + # omitting 'subset' counts unique rows + assert df.n_unique() == 5 + + # providing it counts unique col/expr subsets + assert df.n_unique(subset=["b", "c"]) == 4 + assert df.n_unique(subset=pl.col("c")) == 2 + assert ( + df.n_unique(subset=[(pl.col("a") // 2), (pl.col("c") | (pl.col("b") >= 2))]) + == 3 + ) + + +def test_n_unique_null() -> None: + assert pl.Series([]).n_unique() == 0 + assert pl.Series([None]).n_unique() == 1 + assert pl.Series([None, None]).n_unique() == 1 + + +@pytest.mark.parametrize( + ("input", "output"), + [ + ([], 0), + (["a", "b", "b", "c"], 3), + (["a", "b", "b", None], 3), + ], +) +def test_n_unique_categorical(input: list[str | None], output: int) -> None: + assert pl.Series(input, dtype=pl.Categorical).n_unique() == output + + +def test_n_unique_list_of_struct_20341() -> None: + df = pl.DataFrame( + { + "a": [ + [{"a": 1, "b": 2}, {"a": 10, "b": 20}], + [{"a": 1, "b": 2}, {"a": 10, "b": 20}], + [{"a": 3, "b": 4}], + ] + } + ) + assert df.select("a").n_unique() == 2 + assert df["a"].n_unique() == 2 diff --git a/py-polars/tests/unit/operations/unique/test_unique.py b/py-polars/tests/unit/operations/unique/test_unique.py new file mode 100644 index 000000000000..4fe2bbf613f6 --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_unique.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from datetime import date +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import ColumnNotFoundError +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import with_string_cache_if_auto_streaming + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_unique_predicate_pd() -> None: + lf = pl.LazyFrame( + { + "x": ["abc", "abc"], + "y": ["xxx", "xxx"], + "z": [True, False], + } + ) + + result = ( + lf.unique(subset=["x", "y"], maintain_order=True, keep="last") + .filter(pl.col("z")) + .collect() + ) + expected = pl.DataFrame(schema={"x": pl.String, "y": pl.String, "z": pl.Boolean}) + assert_frame_equal(result, expected) + + result = ( + lf.unique(subset=["x", "y"], maintain_order=True, keep="any") + .filter(pl.col("z")) + .collect() + ) + expected = pl.DataFrame({"x": ["abc"], "y": ["xxx"], "z": [True]}) + assert_frame_equal(result, expected) + + # Issue #14595: filter should not naively be pushed past unique() + for maintain_order in (True, False): + for keep in ("first", "last", "any", "none"): + q = ( + lf.unique("x", maintain_order=maintain_order, keep=keep) + .filter(pl.col("x") == "abc") + .filter(pl.col("z")) + ) + assert_frame_equal(q.collect(predicate_pushdown=False), q.collect()) + + +def test_unique_on_list_df() -> None: + assert pl.DataFrame( + {"a": [1, 2, 3, 4, 4], "b": [[1, 1], [2], [3], [4, 4], [4, 4]]} + ).unique(maintain_order=True).to_dict(as_series=False) == { + "a": [1, 2, 3, 4], + "b": [[1, 1], [2], [3], [4, 4]], + } + + +def test_list_unique() -> None: + s = pl.Series("a", [[1, 2], [3], [1, 2], [4, 5], [2], [2]]) + assert s.unique(maintain_order=True).to_list() == [[1, 2], [3], [4, 5], [2]] + assert s.arg_unique().to_list() == [0, 1, 3, 4] + assert s.n_unique() == 4 + + +def test_unique_and_drop_stability() -> None: + # see: 2898 + # the original cause was that we wrote: + # expr_a = a.unique() + # expr_a.filter(a.unique().is_not_null()) + # meaning that the a.unique was executed twice, which is an unstable algorithm + df = pl.DataFrame({"a": [1, None, 1, None]}) + assert df.select(pl.col("a").unique().drop_nulls()).to_series()[0] == 1 + + +def test_unique_empty() -> None: + for dt in [pl.String, pl.Boolean, pl.Int32, pl.UInt32]: + s = pl.Series([], dtype=dt) + assert_series_equal(s.unique(), s) + + +def test_unique() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 2], "b": [3, 3, 3]}) + + expected = pl.DataFrame({"a": [1, 2], "b": [3, 3]}) + assert_frame_equal(ldf.unique(maintain_order=True).collect(), expected) + + result = ldf.unique(subset="b", maintain_order=True).collect() + expected = pl.DataFrame({"a": [1], "b": [3]}) + assert_frame_equal(result, expected) + + s0 = pl.Series("a", [1, 2, None, 2]) + expected_s = pl.Series("a", [1, 2, None]) + assert_series_equal(s0.unique(maintain_order=True), expected_s) + assert_series_equal(s0.unique(maintain_order=False), expected_s, check_order=False) + + +def test_struct_unique_df() -> None: + df = pl.DataFrame( + { + "numerical": [1, 2, 1], + "struct": [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 1, "y": 2}], + } + ) + + df.select("numerical", "struct").unique().sort("numerical") + + +def test_sorted_unique_dates() -> None: + out = ( + pl.DataFrame( + [pl.Series("dt", [date(2015, 6, 24), date(2015, 6, 23)], dtype=pl.Date)] + ) + .sort("dt") + .unique(maintain_order=False) + ) + expected = pl.DataFrame({"dt": [date(2015, 6, 23), date(2015, 6, 24)]}) + assert_frame_equal(out, expected, check_row_order=False) + + +@pytest.mark.parametrize("maintain_order", [True, False]) +def test_unique_null(maintain_order: bool) -> None: + s0 = pl.Series([]) + assert_series_equal(s0.unique(maintain_order=maintain_order), s0) + + s1 = pl.Series([None]) + assert_series_equal(s1.unique(maintain_order=maintain_order), s1) + + s2 = pl.Series([None, None]) + assert_series_equal(s2.unique(maintain_order=maintain_order), s1) + + +@pytest.mark.parametrize( + ("input", "output"), + [ + ([], []), + (["a", "b", "b", "c"], ["a", "b", "c"]), + ([None, "a", "b", "b"], [None, "a", "b"]), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_unique_categorical(input: list[str | None], output: list[str | None]) -> None: + with pl.StringCache(): + s = pl.Series(input, dtype=pl.Categorical) + result = s.unique(maintain_order=True) + expected = pl.Series(output, dtype=pl.Categorical) + assert_series_equal(result, expected) + + result = s.unique(maintain_order=False) + expected = pl.Series(output, dtype=pl.Categorical) + assert_series_equal(result, expected, check_order=False) + + +def test_unique_categorical_global() -> None: + with pl.StringCache(): + pl.Series(["aaaa", "bbbb", "cccc"]) # pre-fill global cache + s = pl.Series(["a", "b", "c"], dtype=pl.Categorical) + s_empty = s.slice(0, 0) + + assert s_empty.unique().to_list() == [] + assert_series_equal(s_empty.cat.get_categories(), pl.Series(["a", "b", "c"])) + + +def test_unique_with_null() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 2, 3, 4], + "b": ["a", "a", "b", "b", "c", "c"], + "c": [None, None, None, None, None, None], + } + ) + expected_df = pl.DataFrame( + {"a": [1, 2, 3, 4], "b": ["a", "b", "c", "c"], "c": [None, None, None, None]} + ) + assert_frame_equal(df.unique(maintain_order=True), expected_df) + + +@pytest.mark.parametrize( + ("input_json_data", "input_schema", "subset"), + [ + ({"ID": [], "Name": []}, {"ID": pl.Int64, "Name": pl.String}, "id"), + ({"ID": [], "Name": []}, {"ID": pl.Int64, "Name": pl.String}, ["age", "place"]), + ( + {"ID": [1, 2, 1, 2], "Name": ["foo", "bar", "baz", "baa"]}, + {"ID": pl.Int64, "Name": pl.String}, + "id", + ), + ( + {"ID": [1, 2, 1, 2], "Name": ["foo", "bar", "baz", "baa"]}, + {"ID": pl.Int64, "Name": pl.String}, + ["age", "place"], + ), + ], +) +def test_unique_with_bad_subset( + input_json_data: dict[str, list[Any]], + input_schema: dict[str, PolarsDataType], + subset: str | list[str], +) -> None: + df = pl.DataFrame(input_json_data, schema=input_schema) + + with pytest.raises(ColumnNotFoundError, match="not found"): + df.unique(subset=subset) + + +@pytest.mark.usefixtures("test_global_and_local") +@with_string_cache_if_auto_streaming +def test_categorical_unique_19409() -> None: + df = pl.DataFrame({"x": [str(n % 50) for n in range(127)]}).cast(pl.Categorical) + uniq = df.unique() + assert uniq.height == 50 + assert uniq.null_count().item() == 0 + assert set(uniq["x"]) == set(df["x"]) + + +def test_categorical_updated_revmap_unique_20233() -> None: + with pl.StringCache(): + s = pl.Series("a", ["A"], pl.Categorical) + + s = ( + pl.select(a=pl.when(True).then(pl.lit("C", pl.Categorical))) + .select(a=pl.when(True).then(pl.lit("D", pl.Categorical))) + .to_series() + ) + + assert_series_equal(s.unique(), pl.Series("a", ["D"], pl.Categorical)) + + +def test_unique_check_order_20480() -> None: + df = pl.DataFrame( + [ + { + "key": "some_key", + "value": "second", + "number": 2, + }, + { + "key": "some_key", + "value": "first", + "number": 1, + }, + ] + ) + assert ( + df.lazy() + .sort("key", "number") + .unique(subset="key", keep="first") + .collect()["number"] + .item() + == 1 + ) + + +def test_predicate_pushdown_unique() -> None: + q = ( + pl.LazyFrame({"id": [1, 2, 3]}) + .with_columns(pl.date(2024, 1, 1) + pl.duration(days=pl.Series([1, 2, 3]))) # type: ignore[arg-type] + .unique() + ) + + print(q.filter(pl.col("id").is_in([1, 2, 3])).explain()) + assert not q.filter(pl.col("id").is_in([1, 2, 3])).explain().startswith("FILTER") + assert q.filter(pl.col("id").sum() == pl.col("id")).explain().startswith("FILTER") + + +def test_unique_enum_19338() -> None: + for data in [ + {"enum": ["A"]}, + [{"enum": "A"}], + ]: + df = pl.DataFrame(data, schema={"enum": pl.Enum(["A", "B", "C"])}) + result = df.select(pl.col("enum").unique()) + expected = pl.DataFrame( + {"enum": ["A"]}, schema={"enum": pl.Enum(["A", "B", "C"])} + ) + assert_frame_equal(result, expected) + + +def test_unique_nan_12950() -> None: + df = pl.DataFrame({"x": float("nan")}) + result = df.unique() + assert_frame_equal(result, df) + + +def test_unique_lengths_21654() -> None: + for n in range(0, 1000, 37): + df = pl.DataFrame({"x": pl.int_range(n, eager=True)}) + assert df.unique().height == n diff --git a/py-polars/tests/unit/operations/unique/test_unique_counts.py b/py-polars/tests/unit/operations/unique/test_unique_counts.py new file mode 100644 index 000000000000..0ca4fe82091e --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_unique_counts.py @@ -0,0 +1,43 @@ +from datetime import datetime + +import polars as pl +from polars.testing import assert_series_equal + + +def test_unique_counts() -> None: + s = pl.Series("id", ["a", "b", "b", "c", "c", "c"]) + expected = pl.Series("id", [1, 2, 3], dtype=pl.UInt32) + assert_series_equal(s.unique_counts(), expected) + + +def test_unique_counts_on_dates() -> None: + assert pl.DataFrame( + { + "dt_ns": pl.datetime_range( + datetime(2020, 1, 1), datetime(2020, 3, 1), "1mo", eager=True + ), + } + ).with_columns( + pl.col("dt_ns").dt.cast_time_unit("us").alias("dt_us"), + pl.col("dt_ns").dt.cast_time_unit("ms").alias("dt_ms"), + pl.col("dt_ns").cast(pl.Date).alias("date"), + ).select(pl.all().unique_counts().sum()).to_dict(as_series=False) == { + "dt_ns": [3], + "dt_us": [3], + "dt_ms": [3], + "date": [3], + } + + +def test_unique_counts_null() -> None: + s = pl.Series([]) + expected = pl.Series([], dtype=pl.UInt32) + assert_series_equal(s.unique_counts(), expected) + + s = pl.Series([None]) + expected = pl.Series([1], dtype=pl.UInt32) + assert_series_equal(s.unique_counts(), expected) + + s = pl.Series([None, None, None]) + expected = pl.Series([3], dtype=pl.UInt32) + assert_series_equal(s.unique_counts(), expected) diff --git a/py-polars/tests/unit/series/__init__.py b/py-polars/tests/unit/series/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/series/buffers/__init__.py b/py-polars/tests/unit/series/buffers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/series/buffers/test_from_buffer.py b/py-polars/tests/unit/series/buffers/test_from_buffer.py new file mode 100644 index 000000000000..7bdb6e45c0b9 --- /dev/null +++ b/py-polars/tests/unit/series/buffers/test_from_buffer.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from datetime import date + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_series_equal +from polars.testing.parametric import series +from tests.unit.conftest import NUMERIC_DTYPES + + +@given( + s=series( + allowed_dtypes=[*NUMERIC_DTYPES, pl.Boolean], + allow_chunks=False, + allow_null=False, + ) +) +def test_series_from_buffer(s: pl.Series) -> None: + buffer_info = s._get_buffer_info() + result = pl.Series._from_buffer(s.dtype, buffer_info, owner=s) + assert_series_equal(s, result) + + +def test_series_from_buffer_numeric() -> None: + s = pl.Series([1, 2, 3], dtype=pl.UInt16) + buffer_info = s._get_buffer_info() + result = pl.Series._from_buffer(s.dtype, buffer_info, owner=s) + assert_series_equal(s, result) + + +def test_series_from_buffer_sliced_bitmask() -> None: + s = pl.Series([True] * 9, dtype=pl.Boolean)[5:] + buffer_info = s._get_buffer_info() + result = pl.Series._from_buffer(s.dtype, buffer_info, owner=s) + assert_series_equal(s, result) + + +def test_series_from_buffer_unsupported() -> None: + s = pl.Series([date(2020, 1, 1), date(2020, 2, 5)]) + buffer_info = s._get_buffer_info() + + msg = "`_from_buffer` requires a physical type as input for `dtype`, got date" + with pytest.raises(TypeError, match=msg): + pl.Series._from_buffer(pl.Date, buffer_info, owner=s) diff --git a/py-polars/tests/unit/series/buffers/test_from_buffers.py b/py-polars/tests/unit/series/buffers/test_from_buffers.py new file mode 100644 index 000000000000..43185b50a757 --- /dev/null +++ b/py-polars/tests/unit/series/buffers/test_from_buffers.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from datetime import datetime +from zoneinfo import ZoneInfo + +import pytest +from hypothesis import given + +import polars as pl +from polars.exceptions import PanicException +from polars.testing import assert_series_equal +from polars.testing.parametric import series +from tests.unit.conftest import NUMERIC_DTYPES + + +@given( + s=series( + allowed_dtypes=[*NUMERIC_DTYPES, pl.Boolean], + allow_chunks=False, + ) +) +def test_series_from_buffers_numeric_with_validity(s: pl.Series) -> None: + validity = s.is_not_null() + result = pl.Series._from_buffers(s.dtype, data=s, validity=validity) + assert_series_equal(s, result) + + +@given( + s=series( + allowed_dtypes=[*NUMERIC_DTYPES, pl.Boolean], + allow_chunks=False, + allow_null=False, + ) +) +def test_series_from_buffers_numeric(s: pl.Series) -> None: + result = pl.Series._from_buffers(s.dtype, data=s) + assert_series_equal(s, result) + + +@given( + s=series( + allowed_dtypes=[pl.Date, pl.Time, pl.Datetime, pl.Duration], + allow_chunks=False, + ) +) +def test_series_from_buffers_temporal_with_validity(s: pl.Series) -> None: + validity = s.is_not_null() + physical = pl.Int32 if s.dtype == pl.Date else pl.Int64 + data = s.cast(physical) + result = pl.Series._from_buffers(s.dtype, data=data, validity=validity) + assert_series_equal(s, result) + + +def test_series_from_buffers_int() -> None: + dtype = pl.UInt16 + data = pl.Series([97, 98, 99, 195], dtype=dtype) + validity = pl.Series([True, True, False, True]) + + result = pl.Series._from_buffers(dtype, data=data, validity=validity) + + expected = pl.Series([97, 98, None, 195], dtype=dtype) + assert_series_equal(result, expected) + + +def test_series_from_buffers_float() -> None: + dtype = pl.Float64 + data = pl.Series([0.0, 1.0, -1.0, float("nan"), float("inf")], dtype=dtype) + validity = pl.Series([True, True, False, True, True]) + + result = pl.Series._from_buffers(dtype, data=data, validity=validity) + + expected = pl.Series([0.0, 1.0, None, float("nan"), float("inf")], dtype=dtype) + assert_series_equal(result, expected) + + +def test_series_from_buffers_boolean() -> None: + dtype = pl.Boolean + data = pl.Series([True, False, True]) + validity = pl.Series([True, True, False]) + + result = pl.Series._from_buffers(dtype, data=data, validity=validity) + + expected = pl.Series([True, False, None]) + assert_series_equal(result, expected) + + +def test_series_from_buffers_datetime() -> None: + dtype = pl.Datetime(time_zone="Europe/Amsterdam") + tzinfo = ZoneInfo("Europe/Amsterdam") + data = pl.Series( + [ + datetime(2022, 2, 10, 6, tzinfo=tzinfo), + datetime(2022, 2, 11, 12, tzinfo=tzinfo), + datetime(2022, 2, 12, 18, tzinfo=tzinfo), + ], + dtype=dtype, + ).cast(pl.Int64) + validity = pl.Series([True, False, True]) + + result = pl.Series._from_buffers(dtype, data=data, validity=validity) + + expected = pl.Series( + [ + datetime(2022, 2, 10, 6, tzinfo=tzinfo), + None, + datetime(2022, 2, 12, 18, tzinfo=tzinfo), + ], + dtype=dtype, + ) + assert_series_equal(result, expected) + + +def test_series_from_buffers_string() -> None: + dtype = pl.String + data = pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8) + validity = pl.Series([True, True, False, True]) + offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int64) + + result = pl.Series._from_buffers(dtype, data=[data, offsets], validity=validity) + + expected = pl.Series(["a", "bc", None, "éâç"], dtype=dtype) + assert_series_equal(result, expected) + + +def test_series_from_buffers_enum() -> None: + dtype = pl.Enum(["a", "b", "c"]) + data = pl.Series([0, 1, 0, 2], dtype=pl.UInt32) + validity = pl.Series([True, True, False, True]) + + result = pl.Series._from_buffers(dtype, data=data, validity=validity) + + expected = pl.Series(["a", "b", None, "c"], dtype=dtype) + assert_series_equal(result, expected) + + +def test_series_from_buffers_sliced() -> None: + dtype = pl.Int64 + data = pl.Series([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=dtype) + data = data[5:] + validity = pl.Series([True, True, True, True, False, True, False, False, True]) + validity = validity[5:] + + result = pl.Series._from_buffers(dtype, data=data, validity=validity) + + expected = pl.Series([6, None, None, 9], dtype=dtype) + assert_series_equal(result, expected) + + +def test_series_from_buffers_unsupported_validity() -> None: + s = pl.Series([1, 2, 3]) + + msg = "validity buffer must have data type Boolean, got Int64" + with pytest.raises(TypeError, match=msg): + pl.Series._from_buffers(pl.Date, data=s, validity=s) + + +def test_series_from_buffers_unsupported_offsets() -> None: + data = pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8) + offsets = pl.Series([0, 1, 3, 3, 9], dtype=pl.Int8) + + msg = "offsets buffer must have data type Int64, got Int8" + with pytest.raises(TypeError, match=msg): + pl.Series._from_buffers(pl.String, data=[data, offsets]) + + +def test_series_from_buffers_offsets_do_not_match_data() -> None: + data = pl.Series([97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8) + offsets = pl.Series([0, 1, 3, 3, 9, 11], dtype=pl.Int64) + + msg = "offsets must not exceed the values length" + with pytest.raises(PanicException, match=msg): + pl.Series._from_buffers(pl.String, data=[data, offsets]) + + +def test_series_from_buffers_no_buffers() -> None: + msg = "`data` input to `_from_buffers` must contain at least one buffer" + with pytest.raises(TypeError, match=msg): + pl.Series._from_buffers(pl.Int32, data=[]) diff --git a/py-polars/tests/unit/series/buffers/test_get_buffer_info.py b/py-polars/tests/unit/series/buffers/test_get_buffer_info.py new file mode 100644 index 000000000000..71cab47d95a8 --- /dev/null +++ b/py-polars/tests/unit/series/buffers/test_get_buffer_info.py @@ -0,0 +1,44 @@ +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from tests.unit.conftest import NUMERIC_DTYPES + + +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_get_buffer_info_numeric(dtype: pl.DataType) -> None: + s = pl.Series([1, 2, 3], dtype=dtype) + assert s._get_buffer_info()[0] > 0 + + +def test_get_buffer_info_bool() -> None: + s = pl.Series([True, False, False, True]) + assert s._get_buffer_info()[0] > 0 + assert s[1:]._get_buffer_info()[1] == 1 + + +def test_get_buffer_info_after_rechunk() -> None: + s = pl.Series([1, 2, 3]) + ptr = s._get_buffer_info()[0] + assert isinstance(ptr, int) + + s2 = s.append(pl.Series([1, 2])) + ptr2 = s2.rechunk()._get_buffer_info()[0] + assert ptr != ptr2 + + +def test_get_buffer_info_invalid_data_type() -> None: + s = pl.Series(["a", "bc"]) + + msg = "`_get_buffer_info` not implemented for non-physical type str; try to select a buffer first" + with pytest.raises(TypeError, match=msg): + s._get_buffer_info() + + +def test_get_buffer_info_chunked() -> None: + s1 = pl.Series([1, 2]) + s2 = pl.Series([3, 4]) + s = pl.concat([s1, s2], rechunk=False) + + with pytest.raises(ComputeError): + s._get_buffer_info() diff --git a/py-polars/tests/unit/series/buffers/test_get_buffers.py b/py-polars/tests/unit/series/buffers/test_get_buffers.py new file mode 100644 index 000000000000..ca05cb2ca0e4 --- /dev/null +++ b/py-polars/tests/unit/series/buffers/test_get_buffers.py @@ -0,0 +1,113 @@ +from datetime import date +from typing import cast + +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_get_buffers_only_values() -> None: + s = pl.Series([1, 2, 3]) + + result = s._get_buffers() + + assert_series_equal(result["values"], s) + assert result["validity"] is None + assert result["offsets"] is None + + +def test_get_buffers_with_validity() -> None: + s = pl.Series([1.5, None, 3.5]) + + result = s._get_buffers() + + expected_values = pl.Series([1.5, 0.0, 3.5]) + assert_series_equal(result["values"], expected_values) + + validity = cast(pl.Series, result["validity"]) + expected_validity = pl.Series([True, False, True]) + assert_series_equal(validity, expected_validity) + + assert result["offsets"] is None + + +def test_get_buffers_string_type() -> None: + s = pl.Series(["a", "bc", None, "éâç", ""]) + + result = s._get_buffers() + + expected_values = pl.Series( + [97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8 + ) + assert_series_equal(result["values"], expected_values) + + validity = cast(pl.Series, result["validity"]) + expected_validity = pl.Series([True, True, False, True, True]) + assert_series_equal(validity, expected_validity) + + offsets = cast(pl.Series, result["offsets"]) + expected_offsets = pl.Series([0, 1, 3, 3, 9, 9], dtype=pl.Int64) + assert_series_equal(offsets, expected_offsets) + + +def test_get_buffers_logical_sliced() -> None: + s = pl.Series([date(1970, 1, 1), None, date(1970, 1, 3)])[1:] + + result = s._get_buffers() + + expected_values = pl.Series([0, 2], dtype=pl.Int32) + assert_series_equal(result["values"], expected_values) + + validity = cast(pl.Series, result["validity"]) + expected_validity = pl.Series([False, True]) + assert_series_equal(validity, expected_validity) + + assert result["offsets"] is None + + +def test_get_buffers_chunked() -> None: + s = pl.Series([1, 2, None, 4], dtype=pl.UInt8) + s_chunked = pl.concat([s[:2], s[2:]], rechunk=False) + + result = s_chunked._get_buffers() + + expected_values = pl.Series([1, 2, 0, 4], dtype=pl.UInt8) + assert_series_equal(result["values"], expected_values) + assert result["values"].n_chunks() == 2 + + validity = cast(pl.Series, result["validity"]) + expected_validity = pl.Series([True, True, False, True]) + assert_series_equal(validity, expected_validity) + assert validity.n_chunks() == 2 + + +def test_get_buffers_chunked_string_type() -> None: + s = pl.Series(["a", "bc", None, "éâç", ""]) + s_chunked = pl.concat([s[:2], s[2:]], rechunk=False) + + result = s_chunked._get_buffers() + + expected_values = pl.Series( + [97, 98, 99, 195, 169, 195, 162, 195, 167], dtype=pl.UInt8 + ) + assert_series_equal(result["values"], expected_values) + assert result["values"].n_chunks() == 1 + + validity = cast(pl.Series, result["validity"]) + expected_validity = pl.Series([True, True, False, True, True]) + assert_series_equal(validity, expected_validity) + assert validity.n_chunks() == 1 + + offsets = cast(pl.Series, result["offsets"]) + expected_offsets = pl.Series([0, 1, 3, 3, 9, 9], dtype=pl.Int64) + assert_series_equal(offsets, expected_offsets) + assert offsets.n_chunks() == 1 + + +def test_get_buffers_unsupported_data_type() -> None: + s = pl.Series([[1, 2], [3]]) + + msg = "`_get_buffers` not implemented for `dtype` list\\[i64\\]" + with pytest.raises(TypeError, match=msg): + s._get_buffers() diff --git a/py-polars/tests/unit/series/test_all_any.py b/py-polars/tests/unit/series/test_all_any.py new file mode 100644 index 000000000000..df2a44f2783b --- /dev/null +++ b/py-polars/tests/unit/series/test_all_any.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import SchemaError + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + ([], False), + ([None], False), + ([False], False), + ([False, None], False), + ([True], True), + ([True, None], True), + ], +) +def test_any(data: list[bool | None], expected: bool) -> None: + assert pl.Series(data, dtype=pl.Boolean).any() is expected + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + ([], False), + ([None], None), + ([False], False), + ([False, None], None), + ([True], True), + ([True, None], True), + ], +) +def test_any_kleene(data: list[bool | None], expected: bool | None) -> None: + assert pl.Series(data, dtype=pl.Boolean).any(ignore_nulls=False) is expected + + +def test_any_wrong_dtype() -> None: + with pytest.raises(SchemaError, match="expected `Boolean`"): + pl.Series([0, 1, 0]).any() + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + ([], True), + ([None], True), + ([False], False), + ([False, None], False), + ([True], True), + ([True, None], True), + ], +) +def test_all(data: list[bool | None], expected: bool) -> None: + assert pl.Series(data, dtype=pl.Boolean).all() is expected + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + ([], True), + ([None], None), + ([False], False), + ([False, None], False), + ([True], True), + ([True, None], None), + ], +) +def test_all_kleene(data: list[bool | None], expected: bool | None) -> None: + assert pl.Series(data, dtype=pl.Boolean).all(ignore_nulls=False) is expected + + +def test_all_wrong_dtype() -> None: + with pytest.raises(SchemaError, match="expected `Boolean`"): + pl.Series([0, 1, 0]).all() diff --git a/py-polars/tests/unit/series/test_append.py b/py-polars/tests/unit/series/test_append.py new file mode 100644 index 000000000000..6dabf9ff9f2a --- /dev/null +++ b/py-polars/tests/unit/series/test_append.py @@ -0,0 +1,89 @@ +import pytest + +import polars as pl +from polars.exceptions import SchemaError +from polars.testing import assert_series_equal + + +def test_append() -> None: + a = pl.Series("a", [1, 2]) + b = pl.Series("b", [8, 9, None]) + + result = a.append(b) + + expected = pl.Series("a", [1, 2, 8, 9, None]) + assert_series_equal(a, expected) + assert_series_equal(result, expected) + assert a.n_chunks() == 2 + + +def test_append_self_3915() -> None: + a = pl.Series("a", [1, 2]) + + a.append(a) + + expected = pl.Series("a", [1, 2, 1, 2]) + assert_series_equal(a, expected) + assert a.n_chunks() == 2 + + +def test_append_bad_input() -> None: + a = pl.Series("a", [1, 2]) + b = a.to_frame() + + with pytest.raises(AttributeError): + a.append(b) # type: ignore[arg-type] + + +def test_struct_schema_on_append_extend_3452() -> None: + housing1_data = [ + { + "city": "Chicago", + "address": "100 Main St", + "price": 250000, + "nbr_bedrooms": 3, + }, + { + "city": "New York", + "address": "100 First Ave", + "price": 450000, + "nbr_bedrooms": 2, + }, + ] + + housing2_data = [ + { + "address": "303 Mockingbird Lane", + "city": "Los Angeles", + "nbr_bedrooms": 2, + "price": 450000, + }, + { + "address": "404 Moldave Dr", + "city": "Miami Beach", + "nbr_bedrooms": 1, + "price": 250000, + }, + ] + housing1, housing2 = pl.Series(housing1_data), pl.Series(housing2_data) + with pytest.raises( + SchemaError, + ): + housing1.append(housing2) + + with pytest.raises( + SchemaError, + ): + housing1.extend(housing2) + + +def test_append_null_series() -> None: + a = pl.Series("a", [1, 2], pl.Int64) + b = pl.Series("b", [None, None], pl.Null) + + result = a.append(b) + + expected = pl.Series("a", [1, 2, None, None], pl.Int64) + assert_series_equal(a, expected) + assert_series_equal(result, expected) + assert a.n_chunks() == 2 diff --git a/py-polars/tests/unit/series/test_contains.py b/py-polars/tests/unit/series/test_contains.py new file mode 100644 index 000000000000..67dea7cee074 --- /dev/null +++ b/py-polars/tests/unit/series/test_contains.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from datetime import date +from typing import Any + +import pytest + +import polars as pl + + +@pytest.mark.parametrize( + ("item", "data", "expected"), + [ + (1, [1, 2, 3], True), + (4, [1, 2, 3], False), + (None, [1, None], True), + (None, [1, 2], False), + (date(2022, 1, 1), [date(2022, 1, 1), date(2023, 1, 1)], True), + ], +) +def test_contains(item: Any, data: list[Any], expected: bool) -> None: + s = pl.Series(data) + result = item in s + assert result is expected + + +def test_contains_none() -> None: + s = pl.Series([1, None]) + result = None in s + assert result is True + + s = pl.Series([1, 2]) + assert (None in s) is False diff --git a/py-polars/tests/unit/series/test_describe.py b/py-polars/tests/unit/series/test_describe.py new file mode 100644 index 000000000000..be0641a02fb7 --- /dev/null +++ b/py-polars/tests/unit/series/test_describe.py @@ -0,0 +1,128 @@ +from datetime import date + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_series_describe_int() -> None: + s = pl.Series([1, 2, 3]) + result = s.describe() + + stats = { + "count": 3.0, + "null_count": 0.0, + "mean": 2.0, + "std": 1.0, + "min": 1.0, + "25%": 2.0, + "50%": 2.0, + "75%": 3.0, + "max": 3.0, + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) + + +def test_series_describe_float() -> None: + s = pl.Series([1.3, 4.6, 8.9]) + result = s.describe() + + stats = { + "count": 3.0, + "null_count": 0.0, + "mean": 4.933333333333334, + "std": 3.8109491381194442, + "min": 1.3, + "25%": 4.6, + "50%": 4.6, + "75%": 8.9, + "max": 8.9, + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) + + +def test_series_describe_string() -> None: + s = pl.Series(["abc", "pqr", "xyz"]) + result = s.describe() + + stats = { + "count": "3", + "null_count": "0", + "min": "abc", + "max": "xyz", + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) + + +def test_series_describe_boolean() -> None: + s = pl.Series([True, False, None, True, True]) + result = s.describe() + + stats = { + "count": 4, + "null_count": 1, + "mean": 0.75, + "min": False, + "max": True, + } + expected = pl.DataFrame( + data={"statistic": stats.keys(), "value": stats.values()}, + schema_overrides={"value": pl.Float64}, + ) + assert_frame_equal(expected, result) + + +def test_series_describe_date() -> None: + s = pl.Series([date(1999, 12, 31), date(2011, 3, 11), date(2021, 1, 18)]) + result = s.describe(interpolation="linear") + + stats = { + "count": "3", + "null_count": "0", + "mean": "2010-09-29 16:00:00", + "min": "1999-12-31", + "25%": "2005-08-05", + "50%": "2011-03-11", + "75%": "2016-02-13", + "max": "2021-01-18", + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) + + +def test_series_describe_empty() -> None: + s = pl.Series(dtype=pl.Float64) + result = s.describe() + stats = { + "count": 0.0, + "null_count": 0.0, + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) + + +def test_series_describe_null() -> None: + s = pl.Series([None, None], dtype=pl.Null) + result = s.describe() + stats = { + "count": 0.0, + "null_count": 2.0, + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) + + +def test_series_describe_nested_list() -> None: + s = pl.Series( + values=[[10e10, 10e15], [10e12, 10e13], [10e10, 10e15]], + dtype=pl.List(pl.Float64), + ) + result = s.describe() + stats = { + "count": 3.0, + "null_count": 0.0, + } + expected = pl.DataFrame({"statistic": stats.keys(), "value": stats.values()}) + assert_frame_equal(expected, result) diff --git a/py-polars/tests/unit/series/test_equals.py b/py-polars/tests/unit/series/test_equals.py new file mode 100644 index 000000000000..91df683d2423 --- /dev/null +++ b/py-polars/tests/unit/series/test_equals.py @@ -0,0 +1,305 @@ +from datetime import datetime +from typing import Callable + +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_equals() -> None: + s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64) + s2 = pl.Series("a", [1, 2, None], pl.Int64) + + assert s1.equals(s2) is True + assert s1.equals(s2, check_dtypes=True) is False + assert s1.equals(s2, null_equal=False) is False + + df = pl.DataFrame( + {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]}, + schema_overrides={"dtm": pl.Datetime(time_zone="UTC")}, + ).with_columns( + s3=pl.col("dtm").dt.convert_time_zone("Europe/London"), + s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"), + ) + s3 = df["s3"].rename("b") + s4 = df["s4"].rename("b") + + assert s3.equals(s4) is False + assert s3.equals(s4, check_dtypes=True) is False + assert s3.equals(s4, null_equal=False) is False + assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True + + +def test_series_equals_check_names() -> None: + s1 = pl.Series("foo", [1, 2, 3]) + s2 = pl.Series("bar", [1, 2, 3]) + assert s1.equals(s2) is True + assert s1.equals(s2, check_names=True) is False + + +def test_eq_list_cmp_list() -> None: + s = pl.Series([[1], [1, 2]]) + result = s == [1, 2] + expected = pl.Series([False, True]) + assert_series_equal(result, expected) + + +def test_eq_list_cmp_int() -> None: + s = pl.Series([[1], [1, 2]]) + with pytest.raises( + TypeError, match="cannot convert Python type 'int' to List\\(Int64\\)" + ): + s == 1 # noqa: B015 + + +def test_eq_array_cmp_list() -> None: + s = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2)) + result = s == [1, 2] + expected = pl.Series([False, True]) + assert_series_equal(result, expected) + + +def test_eq_array_cmp_int() -> None: + s = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2)) + with pytest.raises( + TypeError, + match="cannot convert Python type 'int' to Array\\(Int16, shape=\\(2,\\)\\)", + ): + s == 1 # noqa: B015 + + +def test_eq_list() -> None: + s = pl.Series([1, 1]) + + result = s == [1, 2] + expected = pl.Series([True, False]) + assert_series_equal(result, expected) + + result = s == 1 + expected = pl.Series([True, True]) + assert_series_equal(result, expected) + + +def test_eq_missing_expr() -> None: + s = pl.Series([1, None]) + result = s.eq_missing(pl.lit(1)) + + assert isinstance(result, pl.Expr) + result_evaluated = pl.select(result).to_series() + expected = pl.Series([True, False]) + assert_series_equal(result_evaluated, expected) + + +def test_ne_missing_expr() -> None: + s = pl.Series([1, None]) + result = s.ne_missing(pl.lit(1)) + + assert isinstance(result, pl.Expr) + result_evaluated = pl.select(result).to_series() + expected = pl.Series([False, True]) + assert_series_equal(result_evaluated, expected) + + +def test_series_equals_strict_deprecated() -> None: + s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64) + s2 = pl.Series("a", [1, 2, None], pl.Int64) + with pytest.deprecated_call(): + assert not s1.equals(s2, strict=True) # type: ignore[call-arg] + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)]) +@pytest.mark.parametrize( + ("cmp_eq", "cmp_ne"), + [ + # We parametrize the comparison sides as the impl looks like this: + # match (left.len(), right.len()) { + # (1, _) => ..., + # (_, 1) => ..., + # (_, _) => ..., + # } + (pl.Series.eq, pl.Series.ne), + ( + lambda a, b: pl.Series.eq(b, a), + lambda a, b: pl.Series.ne(b, a), + ), + ], +) +def test_eq_lists_arrays( + dtype: pl.DataType, + cmp_eq: Callable[[pl.Series, pl.Series], pl.Series], + cmp_ne: Callable[[pl.Series, pl.Series], pl.Series], +) -> None: + # Broadcast NULL + assert_series_equal( + cmp_eq( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + # Non-broadcast full-NULL + assert_series_equal( + cmp_eq( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, None, None], dtype=pl.Boolean), + ) + + # Broadcast valid + assert_series_equal( + cmp_eq( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, True, False], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, False, True], dtype=pl.Boolean), + ) + + # Non-broadcast mixed + assert_series_equal( + cmp_eq( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, False, True], dtype=pl.Boolean), + ) + + assert_series_equal( + cmp_ne( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([None, True, False], dtype=pl.Boolean), + ) + + +@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)]) +@pytest.mark.parametrize( + ("cmp_eq_missing", "cmp_ne_missing"), + [ + (pl.Series.eq_missing, pl.Series.ne_missing), + ( + lambda a, b: pl.Series.eq_missing(b, a), + lambda a, b: pl.Series.ne_missing(b, a), + ), + ], +) +def test_eq_missing_lists_arrays_19153( + dtype: pl.DataType, + cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series], + cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series], +) -> None: + def assert_series_equal( + left: pl.Series, + right: pl.Series, + *, + assert_series_equal_impl: Callable[[pl.Series, pl.Series], None] = globals()[ + "assert_series_equal" + ], + ) -> None: + # `assert_series_equal` also uses `ne_missing` underneath so we have + # some extra checks here to be sure. + assert_series_equal_impl(left, right) + assert left.to_list() == right.to_list() + assert left.null_count() == 0 + assert right.null_count() == 0 + + # Broadcast NULL + assert_series_equal( + cmp_eq_missing( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, False]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series([None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, True]), + ) + + # Non-broadcast full-NULL + assert_series_equal( + cmp_eq_missing( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, False]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series(3 * [None], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, True]), + ) + + # Broadcast valid + assert_series_equal( + cmp_eq_missing( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, False]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series([[1, None]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, True]), + ) + + # Non-broadcast mixed + assert_series_equal( + cmp_eq_missing( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([True, False, True]), + ) + + assert_series_equal( + cmp_ne_missing( + pl.Series([None, [1, 1], [1, 1]], dtype=dtype), + pl.Series([None, [1, None], [1, 1]], dtype=dtype), + ), + pl.Series([False, True, False]), + ) + + +def test_equals_nested_null_categorical_14875() -> None: + dtype = pl.List(pl.Struct({"cat": pl.Categorical})) + s = pl.Series([[{"cat": None}]], dtype=dtype) + assert s.equals(s) diff --git a/py-polars/tests/unit/series/test_extend.py b/py-polars/tests/unit/series/test_extend.py new file mode 100644 index 000000000000..031bcdd847ca --- /dev/null +++ b/py-polars/tests/unit/series/test_extend.py @@ -0,0 +1,53 @@ +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_extend() -> None: + a = pl.Series("a", [1, 2]) + b = pl.Series("b", [8, 9, None]) + + result = a.extend(b) + + expected = pl.Series("a", [1, 2, 8, 9, None]) + assert_series_equal(a, expected) + assert_series_equal(result, expected) + assert a.n_chunks() == 1 + + +def test_extend_self() -> None: + a = pl.Series("a", [1, 2]) + + a.extend(a) + + expected = pl.Series("a", [1, 2, 1, 2]) + assert_series_equal(a, expected) + assert a.n_chunks() == 1 + + +def test_extend_bad_input() -> None: + a = pl.Series("a", [1, 2]) + b = a.to_frame() + + with pytest.raises(AttributeError): + a.extend(b) # type: ignore[arg-type] + + +def test_extend_with_null_series() -> None: + a = pl.Series("a", [1, 2], pl.Int64) + b = pl.Series("b", [None, None], pl.Null) + + result = a.extend(b) + + expected = pl.Series("a", [1, 2, None, None], pl.Int64) + assert_series_equal(a, expected) + assert_series_equal(result, expected) + assert a.n_chunks() == 1 + + +def test_extend_sliced_12968() -> None: + assert pl.Series(["a", "b"]).slice(0, 1).extend(pl.Series(["c"])).to_list() == [ + "a", + "c", + ] diff --git a/py-polars/tests/unit/series/test_getitem.py b/py-polars/tests/unit/series/test_getitem.py new file mode 100644 index 000000000000..6b7f16c580c9 --- /dev/null +++ b/py-polars/tests/unit/series/test_getitem.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from typing import Any + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_series_equal +from polars.testing.parametric import series + + +@given( + srs=series(max_size=10, dtype=pl.Int64), + start=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), + stop=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]), + step=st.sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]), +) +def test_series_getitem( + srs: pl.Series, + start: int | None, + stop: int | None, + step: int | None, +) -> None: + py_data = srs.to_list() + + s = slice(start, stop, step) + sliced_py_data = py_data[s] + sliced_pl_data = srs[s].to_list() + + assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed" + assert_series_equal(srs, srs, check_exact=True) + + +@pytest.mark.parametrize( + ("rng", "expected_values"), + [ + (range(2), [1, 2]), + (range(1, 4), [2, 3, 4]), + (range(3, 0, -2), [4, 2]), + ], +) +def test_series_getitem_range(rng: range, expected_values: list[int]) -> None: + s = pl.Series([1, 2, 3, 4]) + result = s[rng] + expected = pl.Series(expected_values) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "mask", + [ + [True, False, True], + pl.Series([True, False, True]), + np.array([True, False, True]), + ], +) +def test_series_getitem_boolean_mask(mask: Any) -> None: + s = pl.Series([1, 2, 3]) + print(mask) + with pytest.raises( + TypeError, + match="selecting rows by passing a boolean mask to `__getitem__` is not supported", + ): + s[mask] + + +@pytest.mark.parametrize( + "input", [[], (), pl.Series(dtype=pl.Int64), np.array([], dtype=np.uint32)] +) +def test_series_getitem_empty_inputs(input: Any) -> None: + s = pl.Series("a", ["x", "y", "z"], dtype=pl.String) + result = s[input] + expected = pl.Series("a", dtype=pl.String) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("indices", [[0, 2], pl.Series([0, 2]), np.array([0, 2])]) +def test_series_getitem_multiple_indices(indices: Any) -> None: + s = pl.Series(["x", "y", "z"]) + result = s[indices] + expected = pl.Series(["x", "z"]) + assert_series_equal(result, expected) + + +def test_series_getitem_numpy() -> None: + s = pl.Series([9, 8, 7]) + + assert s[np.array([0, 2])].to_list() == [9, 7] + assert s[np.array([-1, -3])].to_list() == [7, 9] + assert s[np.array(-2)].to_list() == [8] + + +@pytest.mark.parametrize( + ("input", "match"), + [ + ( + [0.0, 1.0], + "cannot select elements using Sequence with elements of type 'float'", + ), + ( + "foobar", + "cannot select elements using Sequence with elements of type 'str'", + ), + ( + pl.Series([[1, 2], [3, 4]]), + "cannot treat Series of type List\\(Int64\\) as indices", + ), + (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"), + (object(), "cannot select elements using key of type 'object'"), + ], +) +def test_series_getitem_col_invalid_inputs(input: Any, match: str) -> None: + s = pl.Series([1, 2, 3]) + with pytest.raises(TypeError, match=match): + s[input] diff --git a/py-polars/tests/unit/series/test_item.py b/py-polars/tests/unit/series/test_item.py new file mode 100644 index 000000000000..7d8be87ee946 --- /dev/null +++ b/py-polars/tests/unit/series/test_item.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import datetime + +import pytest + +import polars as pl + + +def test_series_item() -> None: + s = pl.Series("a", [1]) + assert s.item() == 1 + + +def test_series_item_empty() -> None: + s = pl.Series("a", []) + with pytest.raises(ValueError): + s.item() + + +def test_series_item_incorrect_shape() -> None: + s = pl.Series("a", [1, 2]) + with pytest.raises(ValueError): + s.item() + + +@pytest.fixture(scope="module") +def s() -> pl.Series: + return pl.Series("a", [1, 2]) + + +@pytest.mark.parametrize(("index", "expected"), [(0, 1), (1, 2), (-1, 2), (-2, 1)]) +def test_series_item_with_index(index: int, expected: int, s: pl.Series) -> None: + assert s.item(index) == expected + + +@pytest.mark.parametrize("index", [-10, 10]) +def test_df_item_out_of_bounds(index: int, s: pl.Series) -> None: + with pytest.raises(IndexError, match="out of bounds"): + s.item(index) + + +def test_series_item_out_of_range_date() -> None: + s = pl.Series([datetime.date(9999, 12, 31)]).dt.offset_by("1d") + with pytest.raises(ValueError, match="out of range"): + s.item() diff --git a/py-polars/tests/unit/series/test_scatter.py b/py-polars/tests/unit/series/test_scatter.py new file mode 100644 index 000000000000..843fd1727e7d --- /dev/null +++ b/py-polars/tests/unit/series/test_scatter.py @@ -0,0 +1,96 @@ +from datetime import date, datetime +from typing import Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError, OutOfBoundsError +from polars.testing import assert_series_equal + + +@pytest.mark.parametrize( + "input", + [ + (), + [], + pl.Series(), + pl.Series(dtype=pl.Int8), + np.array([]), + ], +) +def test_scatter_noop(input: Any) -> None: + s = pl.Series("s", [1, 2, 3]) + s.scatter(input, 8) + assert s.to_list() == [1, 2, 3] + + +def test_scatter() -> None: + s = pl.Series("s", [1, 2, 3]) + + # set new values, one index at a time + s.scatter(0, 8) + s.scatter([1], None) + assert s.to_list() == [8, None, 3] + + # set new value at multiple indexes in one go + s.scatter([0, 2], None) + assert s.to_list() == [None, None, None] + + # try with different series dtype + s = pl.Series("s", ["a", "b", "c"]) + s.scatter((1, 2), "x") + assert s.to_list() == ["a", "x", "x"] + assert s.scatter([0, 2], 0.12345).to_list() == ["0.12345", "x", "0.12345"] + + # set multiple values + s = pl.Series(["z", "z", "z"]) + assert s.scatter([0, 1], ["a", "b"]).to_list() == ["a", "b", "z"] + s = pl.Series([True, False, True]) + assert s.scatter([0, 1], [False, True]).to_list() == [False, True, True] + + # set negative indices + a = pl.Series("r", range(5)) + a[-2] = None + a[-5] = None + assert a.to_list() == [None, 1, 2, None, 4] + + a = pl.Series("x", [1, 2]) + with pytest.raises(OutOfBoundsError): + a[-100] = None + assert_series_equal(a, pl.Series("x", [1, 2])) + + +def test_index_with_None_errors_16905() -> None: + s = pl.Series("s", [1, 2, 3]) + with pytest.raises(ComputeError, match="index values should not be null"): + s[[1, None]] = 5 + # The error doesn't trash the series, as it used to: + assert_series_equal(s, pl.Series("s", [1, 2, 3])) + + +def test_object_dtype_16905() -> None: + obj = object() + s = pl.Series("s", [obj, 27], dtype=pl.Object) + # This operation is not semantically wrong, it might be supported in the + # future, but for now it isn't. + with pytest.raises(InvalidOperationError): + s[0] = 5 + # The error doesn't trash the series, as it used to: + assert s.dtype.is_object() + assert s.name == "s" + assert s.to_list() == [obj, 27] + + +def test_scatter_datetime() -> None: + s = pl.Series("dt", [None, datetime(2024, 1, 31)]) + result = s.scatter(0, datetime(2022, 2, 2)) + expected = pl.Series("dt", [datetime(2022, 2, 2), datetime(2024, 1, 31)]) + assert_series_equal(result, expected) + + +def test_scatter_logical_all_null() -> None: + s = pl.Series("dt", [None, None], dtype=pl.Date) + result = s.scatter(0, date(2022, 2, 2)) + expected = pl.Series("dt", [date(2022, 2, 2), None]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py new file mode 100644 index 000000000000..f9e888d59724 --- /dev/null +++ b/py-polars/tests/unit/series/test_series.py @@ -0,0 +1,2236 @@ +from __future__ import annotations + +import math +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any, cast +from zoneinfo import ZoneInfo + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +import polars as pl +from polars._utils.construction import iterable_to_pyseries +from polars.datatypes import ( + Datetime, + Field, + Float64, + Int32, + Int64, + Time, + UInt32, + UInt64, + Unknown, +) +from polars.exceptions import ( + DuplicateError, + InvalidOperationError, + PolarsInefficientMapWarning, + ShapeError, +) +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES +from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder + +if TYPE_CHECKING: + from collections.abc import Iterator + + from polars._typing import EpochTimeUnit, PolarsDataType, TimeUnit + + +def test_cum_agg() -> None: + # confirm that known series give expected results + s = pl.Series("a", [1, 2, 3, 2]) + assert_series_equal(s.cum_sum(), pl.Series("a", [1, 3, 6, 8])) + assert_series_equal(s.cum_min(), pl.Series("a", [1, 1, 1, 1])) + assert_series_equal(s.cum_max(), pl.Series("a", [1, 2, 3, 3])) + assert_series_equal(s.cum_prod(), pl.Series("a", [1, 2, 6, 12])) + + +def test_cum_agg_with_nulls() -> None: + # confirm that known series give expected results + s = pl.Series("a", [None, 2, None, 7, 8, None]) + assert_series_equal(s.cum_sum(), pl.Series("a", [None, 2, None, 9, 17, None])) + assert_series_equal(s.cum_min(), pl.Series("a", [None, 2, None, 2, 2, None])) + assert_series_equal(s.cum_max(), pl.Series("a", [None, 2, None, 7, 8, None])) + assert_series_equal(s.cum_prod(), pl.Series("a", [None, 2, None, 14, 112, None])) + + +def test_cum_min_max_bool() -> None: + s = pl.Series("a", [None, True, True, None, False, None, True, False, False, None]) + assert_series_equal(s.cum_min().cast(pl.Int32), s.cast(pl.Int32).cum_min()) + assert_series_equal(s.cum_max().cast(pl.Int32), s.cast(pl.Int32).cum_max()) + assert_series_equal( + s.cum_min(reverse=True).cast(pl.Int32), s.cast(pl.Int32).cum_min(reverse=True) + ) + assert_series_equal( + s.cum_max(reverse=True).cast(pl.Int32), s.cast(pl.Int32).cum_max(reverse=True) + ) + + +def test_init_inputs(monkeypatch: Any) -> None: + nan = float("nan") + # Good inputs + pl.Series("a", [1, 2]) + pl.Series("a", values=[1, 2]) + pl.Series(name="a", values=[1, 2]) + pl.Series(values=[1, 2], name="a") + + assert pl.Series([1, 2]).dtype == pl.Int64 + assert pl.Series(values=[1, 2]).dtype == pl.Int64 + assert pl.Series("a").dtype == pl.Null # Null dtype used in case of no data + assert pl.Series().dtype == pl.Null + assert pl.Series([]).dtype == pl.Null + assert ( + pl.Series([None, None, None]).dtype == pl.Null + ) # f32 type used for list with only None + assert pl.Series(values=[True, False]).dtype == pl.Boolean + assert pl.Series(values=np.array([True, False])).dtype == pl.Boolean + assert pl.Series(values=np.array(["foo", "bar"])).dtype == pl.String + assert pl.Series(values=["foo", "bar"]).dtype == pl.String + assert pl.Series("a", [pl.Series([1, 2, 4]), pl.Series([3, 2, 1])]).dtype == pl.List + assert pl.Series("a", [10000, 20000, 30000], dtype=pl.Time).dtype == pl.Time + + # 2d numpy array and/or list of 1d numpy arrays + for res in ( + pl.Series( + name="a", + values=np.array([[1, 2], [3, nan]], dtype=np.float32), + nan_to_null=True, + ), + pl.Series( + name="a", + values=[ + np.array([1, 2], dtype=np.float32), + np.array([3, nan], dtype=np.float32), + ], + nan_to_null=True, + ), + pl.Series( + name="a", + values=( + np.ndarray((2,), np.float32, np.array([1, 2], dtype=np.float32)), + np.ndarray((2,), np.float32, np.array([3, nan], dtype=np.float32)), + ), + nan_to_null=True, + ), + ): + assert res.dtype == pl.Array(pl.Float32, shape=2) + assert res[0].to_list() == [1.0, 2.0] + assert res[1].to_list() == [3.0, None] + + # numpy from arange, with/without dtype + two_ints = np.arange(2, dtype=np.int64) + three_ints = np.arange(3, dtype=np.int64) + for res in ( + pl.Series("a", [two_ints, three_ints]), + pl.Series("a", [two_ints, three_ints], dtype=pl.List(pl.Int64)), + ): + assert res.dtype == pl.List(pl.Int64) + assert res.to_list() == [[0, 1], [0, 1, 2]] + + assert pl.Series( + values=np.array([["foo", "bar"], ["foo2", "bar2"]]) + ).dtype == pl.Array(pl.String, shape=2) + + # lists + assert pl.Series("a", [[1, 2], [3, 4]]).dtype == pl.List(pl.Int64) + + # conversion of Date to Datetime + s = pl.Series([date(2023, 1, 1), date(2023, 1, 2)], dtype=pl.Datetime) + assert s.to_list() == [datetime(2023, 1, 1), datetime(2023, 1, 2)] + assert Datetime == s.dtype + assert s.dtype.time_unit == "us" # type: ignore[attr-defined] + assert s.dtype.time_zone is None # type: ignore[attr-defined] + + # conversion of Date to Datetime with specified timezone and units + tu: TimeUnit = "ms" + tz = "America/Argentina/Rio_Gallegos" + s = pl.Series( + [date(2023, 1, 1), date(2023, 1, 2)], dtype=pl.Datetime(tu) + ).dt.replace_time_zone(tz) + d1 = datetime(2023, 1, 1, 0, 0, 0, 0, ZoneInfo(tz)) + d2 = datetime(2023, 1, 2, 0, 0, 0, 0, ZoneInfo(tz)) + assert s.to_list() == [d1, d2] + assert Datetime == s.dtype + assert s.dtype.time_unit == tu # type: ignore[attr-defined] + assert s.dtype.time_zone == tz # type: ignore[attr-defined] + + # datetime64: check timeunit (auto-detect, implicit/explicit) and NaT + d64 = pd.date_range(date(2021, 8, 1), date(2021, 8, 3)).values + d64[1] = None + + expected = [datetime(2021, 8, 1, 0), None, datetime(2021, 8, 3, 0)] + for dtype in (None, Datetime, Datetime("ns")): + s = pl.Series("dates", d64, dtype) + assert s.to_list() == expected + assert Datetime == s.dtype + assert s.dtype.time_unit == "ns" # type: ignore[attr-defined] + + s = pl.Series(values=d64.astype(" None: + # validate init from dataclass, namedtuple, and pydantic model objects + from typing import NamedTuple + + from polars.dependencies import dataclasses, pydantic + + @dataclasses.dataclass + class TeaShipmentDC: + exporter: str + importer: str + product: str + tonnes: int | None + + class TeaShipmentNT(NamedTuple): + exporter: str + importer: str + product: str + tonnes: None | int + + class TeaShipmentPD(pydantic.BaseModel): + exporter: str + importer: str + product: str + tonnes: int + + for Tea in (TeaShipmentDC, TeaShipmentNT, TeaShipmentPD): + t0 = Tea(exporter="Sri Lanka", importer="USA", product="Ceylon", tonnes=10) + t1 = Tea(exporter="India", importer="UK", product="Darjeeling", tonnes=25) + t2 = Tea(exporter="China", importer="UK", product="Keemum", tonnes=40) + + s = pl.Series("t", [t0, t1, t2]) + + assert isinstance(s, pl.Series) + assert s.dtype.fields == [ # type: ignore[attr-defined] + Field("exporter", pl.String), + Field("importer", pl.String), + Field("product", pl.String), + Field("tonnes", pl.Int64), + ] + assert s.to_list() == [ + { + "exporter": "Sri Lanka", + "importer": "USA", + "product": "Ceylon", + "tonnes": 10, + }, + { + "exporter": "India", + "importer": "UK", + "product": "Darjeeling", + "tonnes": 25, + }, + { + "exporter": "China", + "importer": "UK", + "product": "Keemum", + "tonnes": 40, + }, + ] + assert_frame_equal(s.to_frame(), pl.DataFrame({"t": [t0, t1, t2]})) + + +def test_to_frame() -> None: + s1 = pl.Series([1, 2]) + s2 = pl.Series("s", [1, 2]) + + df1 = s1.to_frame() + df2 = s2.to_frame() + df3 = s1.to_frame("xyz") + df4 = s2.to_frame("xyz") + + for df, name in ((df1, ""), (df2, "s"), (df3, "xyz"), (df4, "xyz")): + assert isinstance(df, pl.DataFrame) + assert df.rows() == [(1,), (2,)] + assert df.columns == [name] + + # note: the empty string IS technically a valid column name + assert s2.to_frame("").columns == [""] + assert s2.name == "s" + + +def test_bitwise_ops() -> None: + a = pl.Series([True, False, True]) + b = pl.Series([False, True, True]) + assert_series_equal((a & b), pl.Series([False, False, True])) + assert_series_equal((a | b), pl.Series([True, True, True])) + assert_series_equal((a ^ b), pl.Series([True, True, False])) + assert_series_equal((~a), pl.Series([False, True, False])) + + # rand/rxor/ror we trigger by casting the left hand to a list here in the test + # Note that the type annotations only allow Series to be passed in, but there is + # specific code to deal with non-Series inputs. + assert_series_equal( + (True & a), + pl.Series([True, False, True]), + ) + assert_series_equal( + (True | a), + pl.Series([True, True, True]), + ) + assert_series_equal( + (True ^ a), + pl.Series([False, True, False]), + ) + + +def test_bitwise_floats_invert() -> None: + s = pl.Series([2.0, 3.0, 0.0]) + + with pytest.raises(InvalidOperationError): + ~s + + +def test_equality() -> None: + a = pl.Series("a", [1, 2]) + b = a + + cmp = a == b + assert isinstance(cmp, pl.Series) + assert cmp.sum() == 2 + assert (a != b).sum() == 0 + assert (a >= b).sum() == 2 + assert (a <= b).sum() == 2 + assert (a > b).sum() == 0 + assert (a < b).sum() == 0 + assert a.sum() == 3 + assert_series_equal(a, b) + + a = pl.Series("name", ["ham", "foo", "bar"]) + assert_series_equal((a == "ham"), pl.Series("name", [True, False, False])) + + a = pl.Series("name", [[1], [1, 2], [2, 3]]) + assert_series_equal((a == [1]), pl.Series("name", [True, False, False])) + + +def test_agg() -> None: + series = pl.Series("a", [1, 2]) + assert series.mean() == 1.5 + assert series.min() == 1 + assert series.max() == 2 + + +def test_date_agg() -> None: + series = pl.Series( + [ + date(2022, 8, 2), + date(2096, 8, 1), + date(9009, 9, 9), + ], + dtype=pl.Date, + ) + assert series.min() == date(2022, 8, 2) + assert series.max() == date(9009, 9, 9) + + +@pytest.mark.parametrize( + ("s", "min", "max"), + [ + (pl.Series(["c", "b", "a"], dtype=pl.Categorical("lexical")), "a", "c"), + (pl.Series(["a", "c", "b"], dtype=pl.Categorical), "a", "b"), + (pl.Series([None, "a", "c", "b"], dtype=pl.Categorical("lexical")), "a", "c"), + (pl.Series([None, "c", "a", "b"], dtype=pl.Categorical), "c", "b"), + (pl.Series([], dtype=pl.Categorical("lexical")), None, None), + (pl.Series(["c", "b", "a"], dtype=pl.Enum(["c", "b", "a"])), "c", "a"), + (pl.Series(["c", "b", "a"], dtype=pl.Enum(["c", "b", "a", "d"])), "c", "a"), + ], +) +@pytest.mark.usefixtures("test_global_and_local") +def test_categorical_agg(s: pl.Series, min: str | None, max: str | None) -> None: + assert s.min() == min + assert s.max() == max + + +def test_add_string() -> None: + s = pl.Series(["hello", "weird"]) + result = s + " world" + print(result) + assert_series_equal(result, pl.Series(["hello world", "weird world"])) + + result = "pfx:" + s + assert_series_equal(result, pl.Series("literal", ["pfx:hello", "pfx:weird"])) + + +@pytest.mark.parametrize( + ("data", "expected_dtype"), + [ + (100, pl.Int64), + (8.5, pl.Float64), + ("서울특별시", pl.String), + (date.today(), pl.Date), + (datetime.now(), pl.Datetime("us")), + (time(23, 59, 59), pl.Time), + (timedelta(hours=7, seconds=123), pl.Duration("us")), + ], +) +def test_unknown_dtype(data: Any, expected_dtype: PolarsDataType) -> None: + # if given 'Unknown', should be able to infer the correct dtype + s = pl.Series([data], dtype=Unknown) + assert s.dtype == expected_dtype + assert s.to_list() == [data] + + +def test_various() -> None: + a = pl.Series("a", [1, 2]) + assert a.is_null().sum() == 0 + assert a.name == "a" + + a = a.rename("b") + assert a.name == "b" + assert a.len() == 2 + assert len(a) == 2 + + a.append(a.clone()) + assert_series_equal(a, pl.Series("b", [1, 2, 1, 2])) + + a = pl.Series("a", range(20)) + assert a.head(5).len() == 5 + assert a.tail(5).len() == 5 + assert (a.head(5) != a.tail(5)).all() + + a = pl.Series("a", [2, 1, 4]) + a.sort(in_place=True) + assert_series_equal(a, pl.Series("a", [1, 2, 4])) + a = pl.Series("a", [2, 1, 1, 4, 4, 4]) + assert_series_equal(a.arg_unique(), pl.Series("a", [0, 1, 3], dtype=UInt32)) + + assert_series_equal(a.gather([2, 3]), pl.Series("a", [1, 4])) + + +def test_series_dtype_is() -> None: + s = pl.Series("s", [1, 2, 3]) + + assert s.dtype.is_numeric() + assert s.dtype.is_integer() + assert s.dtype.is_signed_integer() + assert not s.dtype.is_unsigned_integer() + assert (s * 0.99).dtype.is_float() + + s = pl.Series("s", [1, 2, 3], dtype=pl.UInt8) + assert s.dtype.is_numeric() + assert s.dtype.is_integer() + assert not s.dtype.is_signed_integer() + assert s.dtype.is_unsigned_integer() + + s = pl.Series("bool", [True, None, False]) + assert not s.dtype.is_numeric() + + s = pl.Series("s", ["testing..."]) + assert s.dtype == pl.String + assert s.dtype != pl.Boolean + + s = pl.Series("s", [], dtype=pl.Decimal(20, 15)) + assert not s.dtype.is_float() + assert s.dtype.is_numeric() + assert s.is_empty() + + s = pl.Series("s", [], dtype=pl.Datetime("ms", time_zone="UTC")) + assert s.dtype.is_temporal() + + +def test_series_head_tail_limit() -> None: + s = pl.Series(range(10)) + + assert_series_equal(s.head(5), pl.Series(range(5))) + assert_series_equal(s.limit(5), s.head(5)) + assert_series_equal(s.tail(5), pl.Series(range(5, 10))) + + # check if it doesn't fail when out of bounds + assert s.head(100).len() == 10 + assert s.limit(100).len() == 10 + assert s.tail(100).len() == 10 + + # negative values + assert_series_equal(s.head(-7), pl.Series(range(3))) + assert s.head(-2).len() == 8 + assert_series_equal(s.tail(-8), pl.Series(range(8, 10))) + assert s.head(-6).len() == 4 + + # negative values out of bounds + assert s.head(-12).len() == 0 + assert s.limit(-12).len() == 0 + assert s.tail(-12).len() == 0 + + +def test_filter_ops() -> None: + a = pl.Series("a", range(20)) + assert a.filter(a > 1).len() == 18 + assert a.filter(a < 1).len() == 1 + assert a.filter(a <= 1).len() == 2 + assert a.filter(a >= 1).len() == 19 + assert a.filter(a == 1).len() == 1 + assert a.filter(a != 1).len() == 19 + + +def test_cast() -> None: + a = pl.Series("a", range(20)) + + assert a.cast(pl.Float32).dtype == pl.Float32 + assert a.cast(pl.Float64).dtype == pl.Float64 + assert a.cast(pl.Int32).dtype == pl.Int32 + assert a.cast(pl.UInt32).dtype == pl.UInt32 + assert a.cast(pl.Datetime).dtype == pl.Datetime + assert a.cast(pl.Date).dtype == pl.Date + + # display failed values, GH#4706 + with pytest.raises(InvalidOperationError, match="foobar"): + pl.Series(["1", "2", "3", "4", "foobar"]).cast(int) + + +@pytest.mark.parametrize( + "test_data", + [ + [1, None, 2], + ["abc", None, "xyz"], + [None, datetime.now()], + [[1, 2], [3, 4], None], + ], +) +def test_to_pandas(test_data: list[Any]) -> None: + a = pl.Series("s", test_data) + b = a.to_pandas() + + assert a.name == b.name + assert b.isnull().sum() == 1 + + vals_b: list[Any] + if a.dtype == pl.List: + vals_b = [(None if x is None else x.tolist()) for x in b] + else: + v = b.replace({np.nan: None}).values.tolist() + vals_b = cast(list[Any], v) + + assert vals_b == test_data + + try: + c = a.to_pandas(use_pyarrow_extension_array=True) + assert a.name == c.name + assert c.isnull().sum() == 1 + vals_c = [None if x is pd.NA else x for x in c.tolist()] + assert vals_c == test_data + except ModuleNotFoundError: + # Skip test if pandas>=1.5.0 or Pyarrow>=8.0.0 is not installed. + pass + + +def test_series_to_list() -> None: + s = pl.Series("a", range(20)) + result = s.to_list() + assert isinstance(result, list) + assert len(result) == 20 + + a = pl.Series("a", [1, None, 2]) + assert a.null_count() == 1 + assert a.to_list() == [1, None, 2] + + +def test_to_struct() -> None: + s = pl.Series("nums", ["12 34", "56 78", "90 00"]).str.extract_all(r"\d+") + + assert s.list.to_struct().struct.fields == ["field_0", "field_1"] + assert s.list.to_struct(fields=lambda idx: f"n{idx:02}").struct.fields == [ + "n00", + "n01", + ] + assert_frame_equal( + s.list.to_struct(fields=["one", "two"]).struct.unnest(), + pl.DataFrame({"one": ["12", "56", "90"], "two": ["34", "78", "00"]}), + ) + + +def test_sort() -> None: + a = pl.Series("a", [2, 1, 3]) + assert_series_equal(a.sort(), pl.Series("a", [1, 2, 3])) + assert_series_equal(a.sort(descending=True), pl.Series("a", [3, 2, 1])) + + +def test_rechunk() -> None: + a = pl.Series("a", [1, 2, 3]) + b = pl.Series("b", [4, 5, 6]) + a.append(b) + assert a.n_chunks() == 2 + assert a.rechunk(in_place=False).n_chunks() == 1 + a.rechunk(in_place=True) + assert a.n_chunks() == 1 + + +def test_indexing() -> None: + a = pl.Series("a", [1, 2, None]) + assert a[1] == 2 + assert a[2] is None + b = pl.Series("b", [True, False]) + assert b[0] + assert not b[1] + a = pl.Series("a", ["a", None]) + assert a[0] == "a" + assert a[1] is None + a = pl.Series("a", [0.1, None]) + assert a[0] == 0.1 + assert a[1] is None + + +def test_arrow() -> None: + a = pl.Series("a", [1, 2, 3, None]) + out = a.to_arrow() + assert out == pa.array([1, 2, 3, None]) + + b = pl.Series("b", [1.0, 2.0, 3.0, None]) + out = b.to_arrow() + assert out == pa.array([1.0, 2.0, 3.0, None]) + + c = pl.Series("c", ["A", "BB", "CCC", None]) + out = c.to_arrow() + assert out == pa.array(["A", "BB", "CCC", None], type=pa.large_string()) + assert_series_equal(pl.from_arrow(out), c.rename("")) # type: ignore[arg-type] + + out = c.to_frame().to_arrow()["c"] + assert isinstance(out, (pa.Array, pa.ChunkedArray)) + assert_series_equal(pl.from_arrow(out), c) # type: ignore[arg-type] + assert_series_equal(pl.from_arrow(out, schema=["x"]), c.rename("x")) # type: ignore[arg-type] + + d = pl.Series("d", [None, None, None], pl.Null) + out = d.to_arrow() + assert out == pa.nulls(3) + + s = cast( + pl.Series, + pl.from_arrow(pa.array([["foo"], ["foo", "bar"]], pa.list_(pa.utf8()))), + ) + assert s.dtype == pl.List + + # categorical dtype tests (including various forms of empty pyarrow array) + with pl.StringCache(): + arr0 = pa.array(["foo", "bar"], pa.dictionary(pa.int32(), pa.utf8())) + assert_series_equal( + pl.Series("arr", ["foo", "bar"], pl.Categorical), pl.Series("arr", arr0) + ) + arr1 = pa.array(["xxx", "xxx", None, "yyy"]).dictionary_encode() + arr2 = pa.array([]).dictionary_encode() + arr3 = pa.chunked_array([], arr1.type) + arr4 = pa.array([], arr1.type) + + assert_series_equal( + pl.Series("arr", ["xxx", "xxx", None, "yyy"], dtype=pl.Categorical), + pl.Series("arr", arr1), + ) + for arr in (arr2, arr3, arr4): + assert_series_equal( + pl.Series("arr", [], dtype=pl.Categorical), pl.Series("arr", arr) + ) + + +def test_pycapsule_interface() -> None: + a = pl.Series("a", [1, 2, 3, None]) + out = pa.chunked_array(PyCapsuleStreamHolder(a)) + out_arr = out.combine_chunks() + assert out_arr == pa.array([1, 2, 3, None]) + + +def test_get() -> None: + a = pl.Series("a", [1, 2, 3]) + pos_idxs = pl.Series("idxs", [2, 0, 1, 0], dtype=pl.Int8) + neg_and_pos_idxs = pl.Series( + "neg_and_pos_idxs", [-2, 1, 0, -1, 2, -3], dtype=pl.Int8 + ) + empty_idxs = pl.Series("idxs", [], dtype=pl.Int8) + empty_ints: list[int] = [] + assert a[0] == 1 + assert a[:2].to_list() == [1, 2] + assert a[range(1)].to_list() == [1] + assert a[range(0, 4, 2)].to_list() == [1, 3] + assert a[:0].to_list() == [] + assert a[empty_ints].to_list() == [] + assert a[neg_and_pos_idxs.to_list()].to_list() == [2, 2, 1, 3, 3, 1] + for dtype in ( + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + ): + assert a[pos_idxs.cast(dtype)].to_list() == [3, 1, 2, 1] + assert a[pos_idxs.cast(dtype).to_numpy()].to_list() == [3, 1, 2, 1] + assert a[empty_idxs.cast(dtype)].to_list() == [] + assert a[empty_idxs.cast(dtype).to_numpy()].to_list() == [] + + for dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64): + nps = a[neg_and_pos_idxs.cast(dtype).to_numpy()] + assert nps.to_list() == [2, 2, 1, 3, 3, 1] + + +def test_set() -> None: + a = pl.Series("a", [True, False, True]) + mask = pl.Series("msk", [True, False, True]) + a[mask] = False + assert_series_equal(a, pl.Series("a", [False] * 3)) + + +def test_set_value_as_list_fail() -> None: + # only allowed for numerical physical types + s = pl.Series("a", [1, 2, 3]) + s[[0, 2]] = [4, 5] + assert s.to_list() == [4, 2, 5] + + # for other types it is not allowed + s = pl.Series("a", ["a", "b", "c"]) + with pytest.raises(TypeError): + s[[0, 1]] = ["d", "e"] + + s = pl.Series("a", [True, False, False]) + with pytest.raises(TypeError): + s[[0, 1]] = [True, False] + + +@pytest.mark.parametrize("key", [True, False, 1.0]) +def test_set_invalid_key(key: Any) -> None: + s = pl.Series("a", [1, 2, 3]) + with pytest.raises(TypeError): + s[key] = 1 + + +@pytest.mark.parametrize( + "key", + [ + pl.Series([False, True, True]), + pl.Series([1, 2], dtype=UInt32), + pl.Series([1, 2], dtype=UInt64), + ], +) +def test_set_key_series(key: pl.Series) -> None: + """Only UInt32/UInt64/bool are allowed.""" + s = pl.Series("a", [1, 2, 3]) + s[key] = 4 + assert_series_equal(s, pl.Series("a", [1, 4, 4])) + + +def test_set_np_array_boolean_mask() -> None: + a = pl.Series("a", [1, 2, 3]) + mask = np.array([True, False, True]) + a[mask] = 4 + assert_series_equal(a, pl.Series("a", [4, 2, 4])) + + +@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32, np.uint64]) +def test_set_np_array(dtype: Any) -> None: + a = pl.Series("a", [1, 2, 3]) + idx = np.array([0, 2], dtype=dtype) + a[idx] = 4 + assert_series_equal(a, pl.Series("a", [4, 2, 4])) + + +@pytest.mark.parametrize("idx", [[0, 2], (0, 2)]) +def test_set_list_and_tuple(idx: list[int] | tuple[int]) -> None: + a = pl.Series("a", [1, 2, 3]) + a[idx] = 4 + assert_series_equal(a, pl.Series("a", [4, 2, 4])) + + +def test_init_nested_tuple() -> None: + s1 = pl.Series("s", (1, 2, 3)) + assert s1.to_list() == [1, 2, 3] + + s2 = pl.Series("s", ((1, 2, 3),), dtype=pl.List(pl.UInt8)) + assert s2.to_list() == [[1, 2, 3]] + assert s2.dtype == pl.List(pl.UInt8) + + s3 = pl.Series("s", ((1, 2, 3), (1, 2, 3)), dtype=pl.List(pl.Int32)) + assert s3.to_list() == [[1, 2, 3], [1, 2, 3]] + assert s3.dtype == pl.List(pl.Int32) + + +@pytest.mark.may_fail_auto_streaming +def test_fill_null() -> None: + s = pl.Series("a", [1, 2, None]) + assert_series_equal(s.fill_null(strategy="forward"), pl.Series("a", [1, 2, 2])) + assert_series_equal(s.fill_null(14), pl.Series("a", [1, 2, 14], dtype=Int64)) + + a = pl.Series("a", [0.0, 1.0, None, 2.0, None, 3.0]) + + assert a.fill_null(0).to_list() == [0.0, 1.0, 0.0, 2.0, 0.0, 3.0] + assert a.fill_null(strategy="zero").to_list() == [0.0, 1.0, 0.0, 2.0, 0.0, 3.0] + assert a.fill_null(strategy="max").to_list() == [0.0, 1.0, 3.0, 2.0, 3.0, 3.0] + assert a.fill_null(strategy="min").to_list() == [0.0, 1.0, 0.0, 2.0, 0.0, 3.0] + assert a.fill_null(strategy="one").to_list() == [0.0, 1.0, 1.0, 2.0, 1.0, 3.0] + assert a.fill_null(strategy="forward").to_list() == [0.0, 1.0, 1.0, 2.0, 2.0, 3.0] + assert a.fill_null(strategy="backward").to_list() == [0.0, 1.0, 2.0, 2.0, 3.0, 3.0] + assert a.fill_null(strategy="mean").to_list() == [0.0, 1.0, 1.5, 2.0, 1.5, 3.0] + assert a.forward_fill().to_list() == [0.0, 1.0, 1.0, 2.0, 2.0, 3.0] + assert a.backward_fill().to_list() == [0.0, 1.0, 2.0, 2.0, 3.0, 3.0] + + b = pl.Series("b", ["a", None, "c", None, "e"]) + assert b.fill_null(strategy="min").to_list() == ["a", "a", "c", "a", "e"] + assert b.fill_null(strategy="max").to_list() == ["a", "e", "c", "e", "e"] + assert b.fill_null(strategy="zero").to_list() == ["a", "", "c", "", "e"] + assert b.fill_null(strategy="forward").to_list() == ["a", "a", "c", "c", "e"] + assert b.fill_null(strategy="backward").to_list() == ["a", "c", "c", "e", "e"] + + c = pl.Series("c", [b"a", None, b"c", None, b"e"]) + assert c.fill_null(strategy="min").to_list() == [b"a", b"a", b"c", b"a", b"e"] + assert c.fill_null(strategy="max").to_list() == [b"a", b"e", b"c", b"e", b"e"] + assert c.fill_null(strategy="zero").to_list() == [b"a", b"", b"c", b"", b"e"] + assert c.fill_null(strategy="forward").to_list() == [b"a", b"a", b"c", b"c", b"e"] + assert c.fill_null(strategy="backward").to_list() == [b"a", b"c", b"c", b"e", b"e"] + + df = pl.DataFrame( + [ + pl.Series("i32", [1, 2, None], dtype=pl.Int32), + pl.Series("i64", [1, 2, None], dtype=pl.Int64), + pl.Series("f32", [1, 2, None], dtype=pl.Float32), + pl.Series("cat", ["a", "b", None], dtype=pl.Categorical), + pl.Series("str", ["a", "b", None], dtype=pl.String), + pl.Series("bool", [True, True, None], dtype=pl.Boolean), + ] + ) + + assert df.fill_null(0, matches_supertype=False).fill_null("bar").fill_null( + False + ).to_dict(as_series=False) == { + "i32": [1, 2, None], + "i64": [1, 2, 0], + "f32": [1.0, 2.0, None], + "cat": ["a", "b", "bar"], + "str": ["a", "b", "bar"], + "bool": [True, True, False], + } + + assert df.fill_null(0, matches_supertype=True).fill_null("bar").fill_null( + False + ).to_dict(as_series=False) == { + "i32": [1, 2, 0], + "i64": [1, 2, 0], + "f32": [1.0, 2.0, 0.0], + "cat": ["a", "b", "bar"], + "str": ["a", "b", "bar"], + "bool": [True, True, False], + } + df = pl.DataFrame({"a": [1, None, 2, None]}) + + out = df.with_columns( + pl.col("a").cast(pl.UInt8).alias("u8"), + pl.col("a").cast(pl.UInt16).alias("u16"), + pl.col("a").cast(pl.UInt32).alias("u32"), + pl.col("a").cast(pl.UInt64).alias("u64"), + ).fill_null(3) + + assert out.to_dict(as_series=False) == { + "a": [1, 3, 2, 3], + "u8": [1, 3, 2, 3], + "u16": [1, 3, 2, 3], + "u32": [1, 3, 2, 3], + "u64": [1, 3, 2, 3], + } + assert out.dtypes == [pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64] + + +def test_str_series_min_max_10674() -> None: + str_series = pl.Series("b", ["a", None, "c", None, "e"], dtype=pl.String) + assert str_series.min() == "a" + assert str_series.max() == "e" + assert str_series.sort(descending=False).min() == "a" + assert str_series.sort(descending=True).max() == "e" + + +def test_fill_nan() -> None: + nan = float("nan") + a = pl.Series("a", [1.0, nan, 2.0, nan, 3.0]) + assert_series_equal(a.fill_nan(None), pl.Series("a", [1.0, None, 2.0, None, 3.0])) + assert_series_equal(a.fill_nan(0), pl.Series("a", [1.0, 0.0, 2.0, 0.0, 3.0])) + + +def test_map_elements() -> None: + with pytest.warns(PolarsInefficientMapWarning): + a = pl.Series("a", [1, 2, None]) + b = a.map_elements(lambda x: x**2, return_dtype=pl.Int64) + assert list(b) == [1, 4, None] + + with pytest.warns(PolarsInefficientMapWarning): + a = pl.Series("a", ["foo", "bar", None]) + b = a.map_elements(lambda x: x + "py", return_dtype=pl.String) + assert list(b) == ["foopy", "barpy", None] + + b = a.map_elements(lambda x: len(x), return_dtype=pl.Int32) + assert list(b) == [3, 3, None] + + b = a.map_elements(lambda x: len(x)) + assert list(b) == [3, 3, None] + + # just check that it runs (somehow problem with conditional compilation) + a = pl.Series("a", [2, 2, 3]).cast(pl.Datetime) + a.map_elements(lambda x: x) + a = pl.Series("a", [2, 2, 3]).cast(pl.Date) + a.map_elements(lambda x: x) + + +def test_shape() -> None: + s = pl.Series([1, 2, 3]) + assert s.shape == (3,) + + +@pytest.mark.parametrize("arrow_available", [True, False]) +def test_create_list_series(arrow_available: bool, monkeypatch: Any) -> None: + monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", arrow_available) + a = [[1, 2], None, [None, 3]] + s = pl.Series("", a) + assert s.to_list() == a + + +def test_iter() -> None: + s = pl.Series("", [1, 2, 3]) + + itr = s.__iter__() + assert itr.__next__() == 1 + assert itr.__next__() == 2 + assert itr.__next__() == 3 + assert sum(s) == 6 + + +def test_empty() -> None: + a = pl.Series(dtype=pl.Int8) + assert a.dtype == pl.Int8 + assert a.is_empty() + + a = pl.Series() + assert a.dtype == pl.Null + assert a.is_empty() + + a = pl.Series("name", []) + assert a.dtype == pl.Null + assert a.is_empty() + + a = pl.Series(values=(), dtype=pl.Int8) + assert a.dtype == pl.Int8 + assert a.is_empty() + + assert_series_equal(pl.Series(), pl.Series()) + assert_series_equal( + pl.Series(dtype=pl.Int32), pl.Series(dtype=pl.Int64), check_dtypes=False + ) + + with pytest.raises(TypeError, match="ambiguous"): + not pl.Series() + + +def test_round() -> None: + a = pl.Series("f", [1.003, 2.003]) + b = a.round(2) + assert b.to_list() == [1.00, 2.00] + + b = a.round() + assert b.to_list() == [1.0, 2.0] + + +@pytest.mark.parametrize( + ("series", "digits", "expected_result"), + [ + pytest.param(pl.Series([1.234, 0.1234]), 2, pl.Series([1.2, 0.12]), id="f64"), + pytest.param( + pl.Series([1.234, 0.1234]).cast(pl.Float32), + 2, + pl.Series([1.2, 0.12]).cast(pl.Float32), + id="f32", + ), + pytest.param(pl.Series([123400, 1234]), 2, pl.Series([120000, 1200]), id="i64"), + pytest.param( + pl.Series([123400, 1234]).cast(pl.Int32), + 2, + pl.Series([120000, 1200]).cast(pl.Int32), + id="i32", + ), + pytest.param( + pl.Series([0.0]), 2, pl.Series([0.0]), id="0 should remain the same" + ), + ], +) +def test_round_sig_figs( + series: pl.Series, digits: int, expected_result: pl.Series +) -> None: + result = series.round_sig_figs(digits=digits) + assert_series_equal(result, expected_result) + + +def test_round_sig_figs_raises_exc() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.Series([1.234, 0.1234]).round_sig_figs(digits=0) + + +def test_apply_list_out() -> None: + s = pl.Series("count", [3, 2, 2]) + out = s.map_elements(lambda val: pl.repeat(val, val, eager=True)) + assert out[0].to_list() == [3, 3, 3] + assert out[1].to_list() == [2, 2] + assert out[2].to_list() == [2, 2] + + +def test_reinterpret() -> None: + s = pl.Series("a", [1, 1, 2], dtype=pl.UInt64) + assert s.reinterpret(signed=True).dtype == pl.Int64 + df = pl.DataFrame([s]) + assert df.select([pl.col("a").reinterpret(signed=True)])["a"].dtype == pl.Int64 + + +def test_mode() -> None: + s = pl.Series("a", [1, 1, 2]) + assert s.mode().to_list() == [1] + assert s.set_sorted().mode().to_list() == [1] + + df = pl.DataFrame([s]) + assert df.select([pl.col("a").mode()])["a"].to_list() == [1] + assert ( + pl.Series(["foo", "bar", "buz", "bar"], dtype=pl.Categorical).mode().item() + == "bar" + ) + assert pl.Series([1.0, 2.0, 3.0, 2.0]).mode().item() == 2.0 + assert pl.Series(["a", "b", "c", "b"]).mode().item() == "b" + + # sorted data + assert set(pl.int_range(0, 3, eager=True).mode().to_list()) == {0, 1, 2} + + +def test_diff() -> None: + s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + expected = pl.Series("a", [1, 1, -1, 0, 1, -3]) + + assert_series_equal(s.diff(null_behavior="drop"), expected) + + df = pl.DataFrame([s]) + assert_series_equal( + df.select(pl.col("a").diff())["a"], pl.Series("a", [None, 1, 1, -1, 0, 1, -3]) + ) + + +def test_pct_change() -> None: + s = pl.Series("a", [1, 2, 4, 8, 16, 32, 64]) + expected = pl.Series("a", [None, None, 3.0, 3.0, 3.0, 3.0, 3.0]) + assert_series_equal(s.pct_change(2), expected) + assert_series_equal(s.pct_change(pl.Series([2])), expected) + # negative + assert pl.Series(range(5)).pct_change(-1).to_list() == [ + -1.0, + -0.5, + -0.3333333333333333, + -0.25, + None, + ] + + +def test_skew() -> None: + s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + + assert s.skew(bias=True) == pytest.approx(-0.5953924651018018) + assert s.skew(bias=False) == pytest.approx(-0.7717168360221258) + + df = pl.DataFrame([s]) + assert np.isclose( + df.select(pl.col("a").skew(bias=False))["a"][0], -0.7717168360221258 + ) + + +def test_kurtosis() -> None: + s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + expected = -0.6406250000000004 + + assert s.kurtosis() == pytest.approx(expected) + df = pl.DataFrame([s]) + assert np.isclose(df.select(pl.col("a").kurtosis())["a"][0], expected) + + +def test_sqrt() -> None: + s = pl.Series("a", [1, 2]) + assert_series_equal(s.sqrt(), pl.Series("a", [1.0, np.sqrt(2)])) + df = pl.DataFrame([s]) + assert_series_equal( + df.select(pl.col("a").sqrt())["a"], pl.Series("a", [1.0, np.sqrt(2)]) + ) + + +def test_cbrt() -> None: + s = pl.Series("a", [1, 2]) + assert_series_equal(s.cbrt(), pl.Series("a", [1.0, np.cbrt(2)])) + df = pl.DataFrame([s]) + assert_series_equal( + df.select(pl.col("a").cbrt())["a"], pl.Series("a", [1.0, np.cbrt(2)]) + ) + + +def test_range() -> None: + s1 = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) + assert_series_equal(s1[2:5], s1[range(2, 5)]) + + ranges = [range(-2, 1), range(3), range(2, 8, 2)] + + s2 = pl.Series("b", ranges, dtype=pl.List(pl.Int8)) + assert s2.to_list() == [[-2, -1, 0], [0, 1, 2], [2, 4, 6]] + assert s2.dtype == pl.List(pl.Int8) + assert s2.name == "b" + + s3 = pl.Series("c", (ranges for _ in range(3))) + assert s3.to_list() == [ + [[-2, -1, 0], [0, 1, 2], [2, 4, 6]], + [[-2, -1, 0], [0, 1, 2], [2, 4, 6]], + [[-2, -1, 0], [0, 1, 2], [2, 4, 6]], + ] + assert s3.dtype == pl.List(pl.List(pl.Int64)) + + df = pl.DataFrame([s1]) + assert_frame_equal(df[2:5], df[range(2, 5)]) + + +def test_strict_cast() -> None: + with pytest.raises(InvalidOperationError): + pl.Series("a", [2**16]).cast(dtype=pl.Int16, strict=True) + with pytest.raises(InvalidOperationError): + pl.DataFrame({"a": [2**16]}).select([pl.col("a").cast(pl.Int16, strict=True)]) + + +def test_floor_divide() -> None: + s = pl.Series("a", [1, 2, 3]) + assert_series_equal(s // 2, pl.Series("a", [0, 1, 1])) + assert_series_equal( + pl.DataFrame([s]).select(pl.col("a") // 2)["a"], pl.Series("a", [0, 1, 1]) + ) + + +def test_true_divide() -> None: + s = pl.Series("a", [1, 2]) + assert_series_equal(s / 2, pl.Series("a", [0.5, 1.0])) + assert_series_equal( + pl.DataFrame([s]).select(pl.col("a") / 2)["a"], pl.Series("a", [0.5, 1.0]) + ) + + # rtruediv + assert_series_equal( + pl.DataFrame([s]).select(2 / pl.col("a"))["literal"], + pl.Series("literal", [2.0, 1.0]), + ) + + # https://github.com/pola-rs/polars/issues/1369 + vals = [3000000000, 2, 3] + foo = pl.Series(vals) + assert_series_equal(foo / 1, pl.Series(vals, dtype=Float64)) + assert_series_equal( + pl.DataFrame({"a": vals}).select([pl.col("a") / 1])["a"], + pl.Series("a", vals, dtype=Float64), + ) + + +def test_bitwise() -> None: + a = pl.Series("a", [1, 2, 3]) + b = pl.Series("b", [3, 4, 5]) + assert_series_equal(a & b, pl.Series("a", [1, 0, 1])) + assert_series_equal(a | b, pl.Series("a", [3, 6, 7])) + assert_series_equal(a ^ b, pl.Series("a", [2, 6, 6])) + + df = pl.DataFrame([a, b]) + out = df.select( + (pl.col("a") & pl.col("b")).alias("and"), + (pl.col("a") | pl.col("b")).alias("or"), + (pl.col("a") ^ pl.col("b")).alias("xor"), + ) + assert_series_equal(out["and"], pl.Series("and", [1, 0, 1])) + assert_series_equal(out["or"], pl.Series("or", [3, 6, 7])) + assert_series_equal(out["xor"], pl.Series("xor", [2, 6, 6])) + + # ensure mistaken use of logical 'and'/'or' raises an exception + with pytest.raises(TypeError, match="ambiguous"): + a and b # type: ignore[redundant-expr] + + with pytest.raises(TypeError, match="ambiguous"): + a or b # type: ignore[redundant-expr] + + +def test_from_generator_or_iterable() -> None: + # generator function + def gen(n: int) -> Iterator[int]: + yield from range(n) + + # iterable object + class Data: + def __init__(self, n: int) -> None: + self._n = n + + def __iter__(self) -> Iterator[int]: + yield from gen(self._n) + + expected = pl.Series("s", range(10)) + assert expected.dtype == pl.Int64 + + for generated_series in ( + pl.Series("s", values=gen(10)), + pl.Series("s", values=Data(10)), + pl.Series("s", values=(x for x in gen(10))), + ): + assert_series_equal(expected, generated_series) + + # test 'iterable_to_pyseries' directly to validate 'chunk_size' behaviour + ps1 = iterable_to_pyseries("s", gen(10), dtype=pl.UInt8) + ps2 = iterable_to_pyseries("s", gen(10), dtype=pl.UInt8, chunk_size=3) + ps3 = iterable_to_pyseries("s", Data(10), dtype=pl.UInt8, chunk_size=6) + + expected = pl.Series("s", range(10), dtype=pl.UInt8) + assert expected.dtype == pl.UInt8 + + for ps in (ps1, ps2, ps3): + generated_series = pl.Series("s") + generated_series._s = ps + assert_series_equal(expected, generated_series) + + # empty generator + assert_series_equal(pl.Series("s", []), pl.Series("s", values=gen(0))) + + +def test_from_sequences(monkeypatch: Any) -> None: + # test int, str, bool, flt + values = [ + [[1], [None, 3]], + [["foo"], [None, "bar"]], + [[True], [None, False]], + [[1.0], [None, 3.0]], + ] + + for vals in values: + monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", False) + a = pl.Series("a", vals) + monkeypatch.setattr(pl.series.series, "_PYARROW_AVAILABLE", True) + b = pl.Series("a", vals) + assert_series_equal(a, b) + assert a.to_list() == vals + + +def test_comparisons_int_series_to_float() -> None: + srs_int = pl.Series([1, 2, 3, 4]) + + assert_series_equal(srs_int - 1.0, pl.Series([0.0, 1.0, 2.0, 3.0])) + assert_series_equal(srs_int + 1.0, pl.Series([2.0, 3.0, 4.0, 5.0])) + assert_series_equal(srs_int * 2.0, pl.Series([2.0, 4.0, 6.0, 8.0])) + assert_series_equal(srs_int / 2.0, pl.Series([0.5, 1.0, 1.5, 2.0])) + assert_series_equal(srs_int % 2.0, pl.Series([1.0, 0.0, 1.0, 0.0])) + assert_series_equal(4.0 % srs_int, pl.Series([0.0, 0.0, 1.0, 0.0])) + + assert_series_equal(srs_int // 2.0, pl.Series([0.0, 1.0, 1.0, 2.0])) + assert_series_equal(srs_int < 3.0, pl.Series([True, True, False, False])) + assert_series_equal(srs_int <= 3.0, pl.Series([True, True, True, False])) + assert_series_equal(srs_int > 3.0, pl.Series([False, False, False, True])) + assert_series_equal(srs_int >= 3.0, pl.Series([False, False, True, True])) + assert_series_equal(srs_int == 3.0, pl.Series([False, False, True, False])) + assert_series_equal(srs_int - True, pl.Series([0, 1, 2, 3])) + + +def test_comparisons_int_series_to_float_scalar() -> None: + srs_int = pl.Series([1, 2, 3, 4]) + + assert_series_equal(srs_int < 1.5, pl.Series([True, False, False, False])) + assert_series_equal(srs_int > 1.5, pl.Series([False, True, True, True])) + + +def test_comparisons_datetime_series_to_date_scalar() -> None: + srs_date = pl.Series([date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3)]) + dt = datetime(2023, 1, 1, 12, 0, 0) + + assert_series_equal(srs_date < dt, pl.Series([True, False, False])) + assert_series_equal(srs_date > dt, pl.Series([False, True, True])) + + +def test_comparisons_float_series_to_int() -> None: + srs_float = pl.Series([1.0, 2.0, 3.0, 4.0]) + + assert_series_equal(srs_float - 1, pl.Series([0.0, 1.0, 2.0, 3.0])) + assert_series_equal(srs_float + 1, pl.Series([2.0, 3.0, 4.0, 5.0])) + assert_series_equal(srs_float * 2, pl.Series([2.0, 4.0, 6.0, 8.0])) + assert_series_equal(srs_float / 2, pl.Series([0.5, 1.0, 1.5, 2.0])) + assert_series_equal(srs_float % 2, pl.Series([1.0, 0.0, 1.0, 0.0])) + assert_series_equal(4 % srs_float, pl.Series([0.0, 0.0, 1.0, 0.0])) + + assert_series_equal(srs_float // 2, pl.Series([0.0, 1.0, 1.0, 2.0])) + assert_series_equal(srs_float < 3, pl.Series([True, True, False, False])) + assert_series_equal(srs_float <= 3, pl.Series([True, True, True, False])) + assert_series_equal(srs_float > 3, pl.Series([False, False, False, True])) + assert_series_equal(srs_float >= 3, pl.Series([False, False, True, True])) + assert_series_equal(srs_float == 3, pl.Series([False, False, True, False])) + assert_series_equal(srs_float - True, pl.Series([0.0, 1.0, 2.0, 3.0])) + + +def test_comparisons_bool_series_to_int() -> None: + srs_bool = pl.Series([True, False]) + + # (native bool comparison should work...) + for t, f in ((True, False), (False, True)): + assert list(srs_bool == t) == list(srs_bool != f) == [t, f] + + # TODO: do we want this to work? + assert_series_equal(srs_bool / 1, pl.Series([True, False], dtype=Float64)) + match = ( + r"cannot do arithmetic with Series of dtype: Boolean" + r" and argument of type: 'bool'" + ) + with pytest.raises(TypeError, match=match): + srs_bool - 1 + with pytest.raises(TypeError, match=match): + srs_bool + 1 + match = ( + r"cannot do arithmetic with Series of dtype: Boolean" + r" and argument of type: 'bool'" + ) + with pytest.raises(TypeError, match=match): + srs_bool % 2 + with pytest.raises(TypeError, match=match): + srs_bool * 1 + + from operator import ge, gt, le, lt + + for op in (ge, gt, le, lt): + for scalar in (0, 1.0, True, False): + with pytest.raises( + TypeError, + match=r"'\W{1,2}' not supported .* 'Series' and '(int|bool|float)'", + ): + op(srs_bool, scalar) + + +@pytest.mark.parametrize( + ("values", "compare_with", "compares_equal"), + [ + ( + [date(1999, 12, 31), date(2021, 1, 31)], + date(2021, 1, 31), + [False, True], + ), + ( + [datetime(2021, 1, 1, 12, 0, 0), datetime(2021, 1, 2, 12, 0, 0)], + datetime(2021, 1, 1, 12, 0, 0), + [True, False], + ), + ( + [timedelta(days=1), timedelta(days=2)], + timedelta(days=1), + [True, False], + ), + ], +) +def test_temporal_comparison( + values: list[Any], compare_with: Any, compares_equal: list[bool] +) -> None: + assert_series_equal( + pl.Series(values) == compare_with, + pl.Series(compares_equal, dtype=pl.Boolean), + ) + + +def test_to_dummies() -> None: + s = pl.Series("a", [1, 2, 3]) + result = s.to_dummies() + expected = pl.DataFrame( + {"a_1": [1, 0, 0], "a_2": [0, 1, 0], "a_3": [0, 0, 1]}, + schema={"a_1": pl.UInt8, "a_2": pl.UInt8, "a_3": pl.UInt8}, + ) + assert_frame_equal(result, expected) + + +def test_to_dummies_drop_first() -> None: + s = pl.Series("a", [1, 2, 3]) + result = s.to_dummies(drop_first=True) + expected = pl.DataFrame( + {"a_2": [0, 1, 0], "a_3": [0, 0, 1]}, + schema={"a_2": pl.UInt8, "a_3": pl.UInt8}, + ) + assert_frame_equal(result, expected) + + +def test_to_dummies_null_clash_19096() -> None: + with pytest.raises( + DuplicateError, match="column with name '_null' has more than one occurrence" + ): + pl.Series([None, "null"]).to_dummies() + + +def test_chunk_lengths() -> None: + s = pl.Series("a", [1, 2, 2, 3]) + # this is a Series with one chunk, of length 4 + assert s.n_chunks() == 1 + assert s.chunk_lengths() == [4] + + +def test_limit() -> None: + s = pl.Series("a", [1, 2, 3]) + assert_series_equal(s.limit(2), pl.Series("a", [1, 2])) + + +def test_filter() -> None: + s = pl.Series("a", [1, 2, 3]) + mask = pl.Series("", [True, False, True]) + + assert_series_equal(s.filter(mask), pl.Series("a", [1, 3])) + assert_series_equal(s.filter([True, False, True]), pl.Series("a", [1, 3])) + assert_series_equal(s.filter(np.array([True, False, True])), pl.Series("a", [1, 3])) + + with pytest.raises(RuntimeError, match="Expected a boolean mask"): + s.filter(np.array([1, 0, 1])) + + +def test_gather_every() -> None: + s = pl.Series("a", [1, 2, 3, 4]) + assert_series_equal(s.gather_every(2), pl.Series("a", [1, 3])) + assert_series_equal(s.gather_every(2, offset=1), pl.Series("a", [2, 4])) + + +def test_arg_sort() -> None: + s = pl.Series("a", [5, 3, 4, 1, 2]) + expected = pl.Series("a", [3, 4, 1, 2, 0], dtype=UInt32) + + assert_series_equal(s.arg_sort(), expected) + + expected_descending = pl.Series("a", [0, 2, 1, 4, 3], dtype=UInt32) + assert_series_equal(s.arg_sort(descending=True), expected_descending) + + +@pytest.mark.parametrize( + ("series", "argmin", "argmax"), + [ + # Numeric + (pl.Series([5, 3, 4, 1, 2]), 3, 0), + (pl.Series([None, 5, 1]), 2, 1), + # Boolean + (pl.Series([True, False]), 1, 0), + (pl.Series([True, True]), 0, 0), + (pl.Series([False, False]), 0, 0), + (pl.Series([None, True, False, True]), 2, 1), + (pl.Series([None, True, True]), 1, 1), + (pl.Series([None, False, False]), 1, 1), + # String + (pl.Series(["a", "c", "b"]), 0, 1), + (pl.Series([None, "a", None, "b"]), 1, 3), + # Categorical + (pl.Series(["c", "b", "a"], dtype=pl.Categorical), 0, 2), + (pl.Series([None, "c", "b", None, "a"], dtype=pl.Categorical), 1, 4), + (pl.Series(["c", "b", "a"], dtype=pl.Categorical(ordering="lexical")), 2, 0), + ( + pl.Series( + [None, "c", "b", None, "a"], dtype=pl.Categorical(ordering="lexical") + ), + 4, + 1, + ), + ], +) +def test_arg_min_arg_max(series: pl.Series, argmin: int, argmax: int) -> None: + assert series.arg_min() == argmin + assert series.arg_max() == argmax + + +@pytest.mark.parametrize( + ("series"), + [ + # All nulls + pl.Series([None, None], dtype=pl.Int32), + pl.Series([None, None], dtype=pl.Boolean), + pl.Series([None, None], dtype=pl.String), + pl.Series([None, None], dtype=pl.Categorical), + pl.Series([None, None], dtype=pl.Categorical(ordering="lexical")), + # Empty Series + pl.Series([], dtype=pl.Int32), + pl.Series([], dtype=pl.Boolean), + pl.Series([], dtype=pl.String), + pl.Series([], dtype=pl.Categorical), + ], +) +def test_arg_min_arg_max_all_nulls_or_empty(series: pl.Series) -> None: + assert series.arg_min() is None + assert series.arg_max() is None + + +def test_arg_min_and_arg_max_sorted() -> None: + # test ascending and descending numerical series + s = pl.Series([None, 1, 2, 3, 4, 5]) + s.sort(in_place=True) # set ascending sorted flag + assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False} + assert s.arg_min() == 1 + assert s.arg_max() == 5 + s = pl.Series([None, 5, 4, 3, 2, 1]) + s.sort(descending=True, in_place=True) # set descing sorted flag + assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True} + assert s.arg_min() == 5 + assert s.arg_max() == 1 + + # test ascending and descending str series + s = pl.Series([None, "a", "b", "c", "d", "e"]) + s.sort(in_place=True) # set ascending sorted flag + assert s.flags == {"SORTED_ASC": True, "SORTED_DESC": False} + assert s.arg_min() == 1 + assert s.arg_max() == 5 + s = pl.Series([None, "e", "d", "c", "b", "a"]) + s.sort(descending=True, in_place=True) # set descing sorted flag + assert s.flags == {"SORTED_ASC": False, "SORTED_DESC": True} + assert s.arg_min() == 5 + assert s.arg_max() == 1 + + +def test_is_null_is_not_null() -> None: + s = pl.Series("a", [1.0, 2.0, 3.0, None]) + assert_series_equal(s.is_null(), pl.Series("a", [False, False, False, True])) + assert_series_equal(s.is_not_null(), pl.Series("a", [True, True, True, False])) + + +def test_is_finite_is_infinite() -> None: + s = pl.Series("a", [1.0, 2.0, np.inf]) + assert_series_equal(s.is_finite(), pl.Series("a", [True, True, False])) + assert_series_equal(s.is_infinite(), pl.Series("a", [False, False, True])) + + +@pytest.mark.parametrize("float_type", [pl.Float32, pl.Float64]) +def test_is_nan_is_not_nan(float_type: PolarsDataType) -> None: + s = pl.Series([1.0, np.nan, None], dtype=float_type) + + assert_series_equal(s.is_nan(), pl.Series([False, True, None])) + assert_series_equal(s.is_not_nan(), pl.Series([True, False, None])) + assert_series_equal(s.fill_nan(2.0), pl.Series([1.0, 2.0, None], dtype=float_type)) + assert_series_equal(s.drop_nans(), pl.Series([1.0, None], dtype=float_type)) + + +def test_float_methods_on_ints() -> None: + # these float-specific methods work on non-float numeric types + s = pl.Series([1, None], dtype=pl.Int32) + assert_series_equal(s.is_finite(), pl.Series([True, None])) + assert_series_equal(s.is_infinite(), pl.Series([False, None])) + assert_series_equal(s.is_nan(), pl.Series([False, None])) + assert_series_equal(s.is_not_nan(), pl.Series([True, None])) + + +def test_dot() -> None: + s1 = pl.Series("a", [1, 2, 3]) + s2 = pl.Series("b", [4.0, 5.0, 6.0]) + + assert np.array([1, 2, 3]) @ np.array([4, 5, 6]) == 32 + + for dot_result in ( + s1.dot(s2), + s1 @ s2, + [1, 2, 3] @ s2, + s1 @ np.array([4, 5, 6]), + ): + assert dot_result == 32 + + with pytest.raises(ShapeError, match="length mismatch"): + s1 @ [4, 5, 6, 7, 8] + + +def test_peak_max_peak_min() -> None: + s = pl.Series("a", [4, 1, 3, 2, 5]) + result = s.peak_min() + expected = pl.Series("a", [False, True, False, True, False]) + assert_series_equal(result, expected) + + result = s.peak_max() + expected = pl.Series("a", [True, False, True, False, True]) + assert_series_equal(result, expected) + + +def test_shrink_to_fit() -> None: + s = pl.Series("a", [4, 1, 3, 2, 5]) + sf = s.shrink_to_fit(in_place=True) + assert sf is s + + s = pl.Series("a", [4, 1, 3, 2, 5]) + sf = s.shrink_to_fit(in_place=False) + assert s is not sf + + +@pytest.mark.parametrize("unit", ["ns", "us", "ms"]) +def test_cast_datetime_to_time(unit: TimeUnit) -> None: + a = pl.Series( + "a", + [ + datetime(2022, 9, 7, 0, 0), + datetime(2022, 9, 6, 12, 0), + datetime(2022, 9, 7, 23, 59, 59), + datetime(2022, 9, 7, 23, 59, 59, 201), + ], + dtype=Datetime(unit), + ) + if unit == "ms": + # NOTE: microseconds are lost for `unit=ms` + expected_values = [time(0, 0), time(12, 0), time(23, 59, 59), time(23, 59, 59)] + else: + expected_values = [ + time(0, 0), + time(12, 0), + time(23, 59, 59), + time(23, 59, 59, 201), + ] + expected = pl.Series("a", expected_values) + assert_series_equal(a.cast(Time), expected) + + +def test_init_categorical() -> None: + with pl.StringCache(): + for values in [[None], ["foo", "bar"], [None, "foo", "bar"]]: + expected = pl.Series("a", values, dtype=pl.String).cast(pl.Categorical) + a = pl.Series("a", values, dtype=pl.Categorical) + assert_series_equal(a, expected) + + +def test_iter_nested_list() -> None: + elems = list(pl.Series("s", [[1, 2], [3, 4]])) + assert_series_equal(elems[0], pl.Series([1, 2])) + assert_series_equal(elems[1], pl.Series([3, 4])) + + rev_elems = list(reversed(pl.Series("s", [[1, 2], [3, 4]]))) + assert_series_equal(rev_elems[0], pl.Series([3, 4])) + assert_series_equal(rev_elems[1], pl.Series([1, 2])) + + +def test_iter_nested_struct() -> None: + # note: this feels inconsistent with the above test for nested list, but + # let's ensure the behaviour is codified before potentially modifying... + elems = list(pl.Series("s", [{"a": 1, "b": 2}, {"a": 3, "b": 4}])) + assert elems[0] == {"a": 1, "b": 2} + assert elems[1] == {"a": 3, "b": 4} + + rev_elems = list(reversed(pl.Series("s", [{"a": 1, "b": 2}, {"a": 3, "b": 4}]))) + assert rev_elems[0] == {"a": 3, "b": 4} + assert rev_elems[1] == {"a": 1, "b": 2} + + +@pytest.mark.parametrize( + "dtype", + [ + pl.UInt8, + pl.Float32, + pl.Int32, + pl.Boolean, + pl.List(pl.String), + pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.Boolean)]), + ], +) +def test_nested_list_types_preserved(dtype: pl.DataType) -> None: + srs = pl.Series([pl.Series([], dtype=dtype) for _ in range(5)]) + for srs_nested in srs: + assert srs_nested.dtype == dtype + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Float64, + pl.Int32, + pl.Decimal(21, 3), + ], +) +def test_log_exp(dtype: pl.DataType) -> None: + a = pl.Series("a", [1, 100, 1000], dtype=dtype) + b = pl.Series("a", [0, 2, 3], dtype=dtype) + assert_series_equal(a.log10(), b.cast(pl.Float64)) + + expected = pl.Series("a", np.log(a.cast(pl.Float64).to_numpy())) + assert_series_equal(a.log(), expected) + + expected = pl.Series("a", np.exp(b.cast(pl.Float64).to_numpy())) + assert_series_equal(b.exp(), expected) + + expected = pl.Series("a", np.log1p(a.cast(pl.Float64).to_numpy())) + assert_series_equal(a.log1p(), expected) + + +def test_to_physical() -> None: + # casting an int result in an int + s = pl.Series("a", [1, 2, 3]) + assert_series_equal(s.to_physical(), s) + + # casting a date results in an Int32 + s = pl.Series("a", [date(2020, 1, 1)] * 3) + expected = pl.Series("a", [18262] * 3, dtype=Int32) + assert_series_equal(s.to_physical(), expected) + + # casting a categorical results in a UInt32 + s = pl.Series(["cat1"]).cast(pl.Categorical) + expected = pl.Series([0], dtype=UInt32) + assert_series_equal(s.to_physical(), expected) + + # casting a List(Categorical) results in a List(UInt32) + s = pl.Series([["cat1"]]).cast(pl.List(pl.Categorical)) + expected = pl.Series([[0]], dtype=pl.List(UInt32)) + assert_series_equal(s.to_physical(), expected) + + +def test_to_physical_rechunked_21285() -> None: + # A series with multiple chunks, dtype is array or list of structs with a + # null field (causes rechunking) and a field with a different physical and + # logical repr (causes the full body of `to_physical_repr` to run). + arr_dtype = pl.Array(pl.Struct({"f0": pl.Time, "f1": pl.Null}), shape=(1,)) + s = pl.Series("a", [None], arr_dtype) # content doesn't matter + s = s.append(s) + expected_arr_dtype = pl.Array(pl.Struct({"f0": Int64, "f1": pl.Null}), shape=(1,)) + expected = pl.Series("a", [None, None], expected_arr_dtype) + assert_series_equal(s.to_physical(), expected) + + list_dtype = pl.List(pl.Struct({"f0": pl.Time, "f1": pl.Null})) + s = pl.Series("a", [None], list_dtype) # content doesn't matter + s = s.append(s) + expected_list_dtype = pl.List(pl.Struct({"f0": Int64, "f1": pl.Null})) + expected = pl.Series("a", [None, None], expected_list_dtype) + assert_series_equal(s.to_physical(), expected) + + +def test_is_between_datetime() -> None: + s = pl.Series("a", [datetime(2020, 1, 1, 10, 0, 0), datetime(2020, 1, 1, 20, 0, 0)]) + start = datetime(2020, 1, 1, 12, 0, 0) + end = datetime(2020, 1, 1, 23, 0, 0) + expected = pl.Series("a", [False, True]) + + # only on the expression api + result = s.to_frame().with_columns(pl.col("*").is_between(start, end)).to_series() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "f", + [ + "sin", + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + ], +) +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_trigonometric(f: str) -> None: + s = pl.Series("a", [0.0, math.pi, None, math.nan]) + expected = ( + pl.Series("a", getattr(np, f)(s.to_numpy())) + .to_frame() + .with_columns(pl.when(s.is_null()).then(None).otherwise(pl.col("a")).alias("a")) + .to_series() + ) + result = getattr(s, f)() + assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_trigonometric_cot() -> None: + # cotangent is not available in numpy... + s = pl.Series("a", [0.0, math.pi, None, math.nan]) + expected = pl.Series("a", [math.inf, -8.1656e15, None, math.nan]) + assert_series_equal(s.cot(), expected) + + +def test_trigonometric_invalid_input() -> None: + # String + s = pl.Series("a", ["1", "2", "3"]) + with pytest.raises(InvalidOperationError): + s.sin() + + # Date + s = pl.Series("a", [date(1990, 2, 28), date(2022, 7, 26)]) + with pytest.raises(InvalidOperationError): + s.cosh() + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_product_ints(dtype: PolarsDataType) -> None: + a = pl.Series("a", [1, 2, 3], dtype=dtype) + out = a.product() + assert out == 6 + a = pl.Series("a", [1, 2, None], dtype=dtype) + out = a.product() + assert out == 2 + a = pl.Series("a", [None, 2, 3], dtype=dtype) + out = a.product() + assert out == 6 + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_product_floats(dtype: PolarsDataType) -> None: + a = pl.Series("a", [], dtype=dtype) + out = a.product() + assert out == 1 + a = pl.Series("a", [None, None], dtype=dtype) + out = a.product() + assert out == 1 + a = pl.Series("a", [3.0, None, float("nan")], dtype=dtype) + out = a.product() + assert math.isnan(out) + + +def test_ceil() -> None: + s = pl.Series([1.8, 1.2, 3.0]) + expected = pl.Series([2.0, 2.0, 3.0]) + assert_series_equal(s.ceil(), expected) + + +def test_duration_arithmetic() -> None: + # apply some basic duration math to series + s = pl.Series([datetime(2022, 1, 1, 10, 20, 30), datetime(2022, 1, 2, 20, 40, 50)]) + d1 = pl.duration(days=5, microseconds=123456) + d2 = timedelta(days=5, microseconds=123456) + + expected_values = [ + datetime(2022, 1, 6, 10, 20, 30, 123456), + datetime(2022, 1, 7, 20, 40, 50, 123456), + ] + for d in (d1, d2): + df1 = pl.select((s + d).alias("d_offset")) + df2 = pl.select((d + s).alias("d_offset")) + assert df1["d_offset"].to_list() == expected_values + assert_series_equal(df1["d_offset"], df2["d_offset"]) + + +def test_mean_overflow() -> None: + arr = np.array([255] * (1 << 17), dtype="int16") + assert arr.mean() == 255.0 + + +def test_sign() -> None: + # Integers + a = pl.Series("a", [-9, -0, 0, 4, None]) + expected = pl.Series("a", [-1, 0, 0, 1, None]) + assert_series_equal(a.sign(), expected) + + # Floats + a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None]) + expected = pl.Series("a", [-1.0, 0.0, 0.0, 1.0, float("nan"), None]) + assert_series_equal(a.sign(), expected) + + # Invalid input + a = pl.Series("a", [date(1950, 2, 1), date(1970, 1, 1), date(2022, 12, 12), None]) + with pytest.raises(InvalidOperationError): + a.sign() + + +def test_exp() -> None: + s = pl.Series("a", [0.1, 0.01, None]) + expected = pl.Series("a", [1.1051709180756477, 1.010050167084168, None]) + assert_series_equal(s.exp(), expected) + # test if we can run on empty series as well. + assert s[:0].exp().to_list() == [] + + +def test_cumulative_eval() -> None: + s = pl.Series("values", [1, 2, 3, 4, 5]) + + # evaluate expressions individually + expr1 = pl.element().first() + expr2 = pl.element().last() ** 2 + + expected1 = pl.Series("values", [1, 1, 1, 1, 1]) + expected2 = pl.Series("values", [1, 4, 9, 16, 25]) + assert_series_equal(s.cumulative_eval(expr1), expected1) + assert_series_equal(s.cumulative_eval(expr2), expected2) + + # evaluate combined expressions and validate + expr3 = expr1 - expr2 + expected3 = pl.Series("values", [0, -3, -8, -15, -24]) + assert_series_equal(s.cumulative_eval(expr3), expected3) + + +def test_reverse() -> None: + s = pl.Series("values", [1, 2, 3, 4, 5]) + assert s.reverse().to_list() == [5, 4, 3, 2, 1] + + s = pl.Series("values", ["a", "b", None, "y", "x"]) + assert s.reverse().to_list() == ["x", "y", None, "b", "a"] + + +def test_reverse_binary() -> None: + # single chunk + s = pl.Series("values", ["a", "b", "c", "d"]).cast(pl.Binary) + assert s.reverse().to_list() == [b"d", b"c", b"b", b"a"] + + # multiple chunks + chunk1 = pl.Series("values", ["a", "b"]) + chunk2 = pl.Series("values", ["c", "d"]) + s = chunk1.extend(chunk2).cast(pl.Binary) + assert s.n_chunks() == 2 + assert s.reverse().to_list() == [b"d", b"c", b"b", b"a"] + + +def test_clip() -> None: + s = pl.Series("foo", [-50, 5, None, 50]) + assert s.clip(1, 10).to_list() == [1, 5, None, 10] + + +def test_repr() -> None: + s = pl.Series("ints", [1001, 2002, 3003]) + s_repr = repr(s) + + assert "shape: (3,)" in s_repr + assert "Series: 'ints' [i64]" in s_repr + for n in s.to_list(): + assert str(n) in s_repr + + class XSeries(pl.Series): + """Custom Series class.""" + + # check custom class name reflected in repr output + x = XSeries("ints", [1001, 2002, 3003]) + x_repr = repr(x) + + assert "shape: (3,)" in x_repr + assert "XSeries: 'ints' [i64]" in x_repr + assert "1001" in x_repr + for n in x.to_list(): + assert str(n) in x_repr + + +def test_repr_html(df: pl.DataFrame) -> None: + # check it does not panic/error, and appears to contain a table + html = pl.Series("misc", [123, 456, 789])._repr_html_() + assert " None: + s = pl.Series("timestamp", [value, None]) + result = pl.from_epoch(s, time_unit=time_unit) + + expected = pl.Series("timestamp", [exp, None]).cast(exp_type) + assert_series_equal(result, expected) + + +def test_get_chunks() -> None: + a = pl.Series("a", [1, 2]) + b = pl.Series("a", [3, 4]) + chunks = pl.concat([a, b], rechunk=False).get_chunks() + assert_series_equal(chunks[0], a) + assert_series_equal(chunks[1], b) + + +def test_null_comparisons() -> None: + s = pl.Series("s", [None, "str", "a"]) + assert (s.shift() == s).null_count() == 2 + assert (s.shift() != s).null_count() == 2 + + +def test_min_max_agg_on_str() -> None: + strings = ["b", "a", "x"] + s = pl.Series(strings) + assert (s.min(), s.max()) == ("a", "x") + + +def test_min_max_full_nan_15058() -> None: + s = pl.Series([float("nan")] * 2) + assert all(x != x for x in [s.min(), s.max()]) + + +def test_is_between() -> None: + s = pl.Series("num", [1, 2, None, 4, 5]) + assert s.is_between(2, 4).to_list() == [False, True, None, True, False] + + s = pl.Series("num", [1, 2, None, 4, 5]) + assert s.is_between(2, 4, closed="left").to_list() == [ + False, + True, + None, + False, + False, + ] + + s = pl.Series("num", [1, 2, None, 4, 5]) + assert s.is_between(2, 4, closed="right").to_list() == [ + False, + False, + None, + True, + False, + ] + + s = pl.Series("num", [1, 2, None, 4, 5]) + assert s.is_between(pl.lit(2) / 2, pl.lit(4) * 2, closed="both").to_list() == [ + True, + True, + None, + True, + True, + ] + + s = pl.Series("s", ["a", "b", "c", "d", "e"]) + assert s.is_between("b", "d").to_list() == [ + False, + True, + True, + True, + False, + ] + + +@pytest.mark.parametrize( + ("dtype", "lower", "upper"), + [ + (pl.Int8, -128, 127), + (pl.UInt8, 0, 255), + (pl.Int16, -32768, 32767), + (pl.UInt16, 0, 65535), + (pl.Int32, -2147483648, 2147483647), + (pl.UInt32, 0, 4294967295), + (pl.Int64, -9223372036854775808, 9223372036854775807), + (pl.UInt64, 0, 18446744073709551615), + (pl.Float32, float("-inf"), float("inf")), + (pl.Float64, float("-inf"), float("inf")), + ], +) +def test_upper_lower_bounds( + dtype: PolarsDataType, upper: int | float, lower: int | float +) -> None: + s = pl.Series("s", dtype=dtype) + assert s.lower_bound().item() == lower + assert s.upper_bound().item() == upper + + +def test_numpy_series_arithmetic() -> None: + sx = pl.Series(values=[1, 2]) + y = np.array([3.0, 4.0]) + + result_add1 = y + sx + result_add2 = sx + y + expected_add = pl.Series([4.0, 6.0], dtype=pl.Float64) + assert_series_equal(result_add1, expected_add) # type: ignore[arg-type] + assert_series_equal(result_add2, expected_add) + + result_sub1 = cast(pl.Series, y - sx) # py37 is different vs py311 on this one + expected = pl.Series([2.0, 2.0], dtype=pl.Float64) + assert_series_equal(result_sub1, expected) + result_sub2 = sx - y + expected = pl.Series([-2.0, -2.0], dtype=pl.Float64) + assert_series_equal(result_sub2, expected) + + result_mul1 = y * sx + result_mul2 = sx * y + expected = pl.Series([3.0, 8.0], dtype=pl.Float64) + assert_series_equal(result_mul1, expected) # type: ignore[arg-type] + assert_series_equal(result_mul2, expected) + + result_div1 = y / sx + expected = pl.Series([3.0, 2.0], dtype=pl.Float64) + assert_series_equal(result_div1, expected) # type: ignore[arg-type] + result_div2 = sx / y + expected = pl.Series([1 / 3, 0.5], dtype=pl.Float64) + assert_series_equal(result_div2, expected) + + result_pow1 = y**sx + expected = pl.Series([3.0, 16.0], dtype=pl.Float64) + assert_series_equal(result_pow1, expected) # type: ignore[arg-type] + result_pow2 = sx**y + expected = pl.Series([1.0, 16.0], dtype=pl.Float64) + assert_series_equal(result_pow2, expected) # type: ignore[arg-type] + + +def test_from_epoch_seq_input() -> None: + seq_input = [1147880044] + expected = pl.Series([datetime(2006, 5, 17, 15, 34, 4)]) + result = pl.from_epoch(seq_input) + assert_series_equal(result, expected) + + +def test_symmetry_for_max_in_names() -> None: + # int + a = pl.Series("a", [1]) + assert (a - a.max()).name == (a.max() - a).name == a.name + # float + a = pl.Series("a", [1.0]) + assert (a - a.max()).name == (a.max() - a).name == a.name + # duration + a = pl.Series("a", [1], dtype=pl.Duration("ns")) + assert (a - a.max()).name == (a.max() - a).name == a.name + # datetime + a = pl.Series("a", [1], dtype=pl.Datetime("ns")) + assert (a - a.max()).name == (a.max() - a).name == a.name + + # TODO: time arithmetic support? + # a = pl.Series("a", [1], dtype=pl.Time) + # assert (a - a.max()).name == (a.max() - a).name == a.name + + +def test_series_getitem_out_of_bounds_positive() -> None: + s = pl.Series([1, 2]) + with pytest.raises( + IndexError, match="index 10 is out of bounds for sequence of length 2" + ): + s[10] + + +def test_series_getitem_out_of_bounds_negative() -> None: + s = pl.Series([1, 2]) + with pytest.raises( + IndexError, match="index -10 is out of bounds for sequence of length 2" + ): + s[-10] + + +def test_series_cmp_fast_paths() -> None: + assert ( + pl.Series([None], dtype=pl.Int32) != pl.Series([1, 2], dtype=pl.Int32) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Int32) == pl.Series([1, 2], dtype=pl.Int32) + ).to_list() == [None, None] + + assert ( + pl.Series([None], dtype=pl.String) != pl.Series(["a", "b"], dtype=pl.String) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.String) == pl.Series(["a", "b"], dtype=pl.String) + ).to_list() == [None, None] + + assert ( + pl.Series([None], dtype=pl.Boolean) + != pl.Series([True, False], dtype=pl.Boolean) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Boolean) + == pl.Series([False, False], dtype=pl.Boolean) + ).to_list() == [None, None] + + +def test_comp_series_with_str_13123() -> None: + s = pl.Series(["1", "2", None]) + assert_series_equal(s != "1", pl.Series([False, True, None])) + assert_series_equal(s == "1", pl.Series([True, False, None])) + assert_series_equal(s.eq_missing("1"), pl.Series([True, False, False])) + assert_series_equal(s.ne_missing("1"), pl.Series([False, True, True])) + + +@pytest.mark.parametrize( + ("data", "single", "multiple", "single_expected", "multiple_expected"), + [ + ([1, 2, 3], 1, [2, 4], 0, [1, 3]), + (["a", "b", "c"], "d", ["a", "d"], 3, [0, 3]), + ([b"a", b"b", b"c"], b"d", [b"a", b"d"], 3, [0, 3]), + ( + [date(2022, 1, 2), date(2023, 4, 1)], + date(2022, 1, 1), + [date(1999, 10, 1), date(2024, 1, 1)], + 0, + [0, 2], + ), + ([1, 2, 3], 1, np.array([2, 4]), 0, [1, 3]), # test np array. + ], +) +def test_search_sorted( + data: list[Any], + single: Any, + multiple: list[Any], + single_expected: Any, + multiple_expected: list[Any], +) -> None: + s = pl.Series(data) + single_s = s.search_sorted(single) + assert single_s == single_expected + + multiple_s = s.search_sorted(multiple) + assert_series_equal(multiple_s, pl.Series(multiple_expected, dtype=pl.UInt32)) + + +def test_series_from_pandas_with_dtype() -> None: + expected = pl.Series("foo", [1, 2, 3], dtype=pl.Int8) + s = pl.Series("foo", pd.Series([1, 2, 3]), pl.Int8) + assert_series_equal(s, expected) + s = pl.Series("foo", pd.Series([1, 2, 3], dtype="Int16"), pl.Int8) + assert_series_equal(s, expected) + + with pytest.raises(InvalidOperationError, match="conversion from"): + pl.Series("foo", pd.Series([-1, 2, 3]), pl.UInt8) + s = pl.Series("foo", pd.Series([-1, 2, 3]), pl.UInt8, strict=False) + assert s.to_list() == [None, 2, 3] + assert s.dtype == pl.UInt8 + + with pytest.raises(InvalidOperationError, match="conversion from"): + pl.Series("foo", pd.Series([-1, 2, 3], dtype="Int8"), pl.UInt8) + s = pl.Series("foo", pd.Series([-1, 2, 3], dtype="Int8"), pl.UInt8, strict=False) + assert s.to_list() == [None, 2, 3] + assert s.dtype == pl.UInt8 + + +def test_series_from_pyarrow_with_dtype() -> None: + s = pl.Series("foo", pa.array([-1, 2, 3]), pl.Int8) + assert_series_equal(s, pl.Series("foo", [-1, 2, 3], dtype=pl.Int8)) + + with pytest.raises(InvalidOperationError, match="conversion from"): + pl.Series("foo", pa.array([-1, 2, 3]), pl.UInt8) + + s = pl.Series("foo", pa.array([-1, 2, 3]), dtype=pl.UInt8, strict=False) + assert s.to_list() == [None, 2, 3] + assert s.dtype == pl.UInt8 + + +def test_series_from_numpy_with_dtype() -> None: + s = pl.Series("foo", np.array([-1, 2, 3]), pl.Int8) + assert_series_equal(s, pl.Series("foo", [-1, 2, 3], dtype=pl.Int8)) + + with pytest.raises(InvalidOperationError, match="conversion from"): + pl.Series("foo", np.array([-1, 2, 3]), pl.UInt8) + + s = pl.Series("foo", np.array([-1, 2, 3]), dtype=pl.UInt8, strict=False) + assert s.to_list() == [None, 2, 3] + assert s.dtype == pl.UInt8 + + +def test_raise_invalid_is_between() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.select(pl.lit(2).is_between(pl.lit("11"), pl.lit("33"))) + + +def test_construction_large_nested_u64_17231() -> None: + import polars as pl + + values = [{"f0": [9223372036854775808]}] + dtype = pl.Struct({"f0": pl.List(pl.UInt64)}) + + assert pl.Series(values, dtype=dtype).to_list() == values + + +def test_repeat_by() -> None: + calculated = pl.select(a=pl.Series("a", [1, 2]).repeat_by(2)) + expected = pl.select(a=pl.Series("a", [[1, 1], [2, 2]])) + assert calculated.equals(expected) diff --git a/py-polars/tests/unit/series/test_to_list.py b/py-polars/tests/unit/series/test_to_list.py new file mode 100644 index 000000000000..84d18879c3fb --- /dev/null +++ b/py-polars/tests/unit/series/test_to_list.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from hypothesis import example, given + +import polars as pl +from polars.testing import assert_series_equal +from polars.testing.parametric import series + + +@given( + s=series( + # Roundtrip doesn't work with time zones: + # https://github.com/pola-rs/polars/issues/16297 + allow_time_zones=False, + ) +) +@example(s=pl.Series(dtype=pl.Array(pl.Date, 1))) +def test_to_list(s: pl.Series) -> None: + values = s.to_list() + result = pl.Series(values, dtype=s.dtype) + assert_series_equal(s, result, categorical_as_str=True) diff --git a/py-polars/tests/unit/sql/__init__.py b/py-polars/tests/unit/sql/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/sql/test_array.py b/py-polars/tests/unit/sql/test_array.py new file mode 100644 index 000000000000..2a9a70ac3156 --- /dev/null +++ b/py-polars/tests/unit/sql/test_array.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize( + ("sort_order", "limit", "expected"), + [ + (None, None, [("a", ["x", "y"]), ("b", ["z", "X", "Y"])]), + ("ASC", None, [("a", ["x", "y"]), ("b", ["z", "Y", "X"])]), + ("DESC", None, [("a", ["y", "x"]), ("b", ["X", "Y", "z"])]), + ("ASC", 2, [("a", ["x", "y"]), ("b", ["z", "Y"])]), + ("DESC", 2, [("a", ["y", "x"]), ("b", ["X", "Y"])]), + ("ASC", 1, [("a", ["x"]), ("b", ["z"])]), + ("DESC", 1, [("a", ["y"]), ("b", ["X"])]), + ], +) +def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) -> None: + order_by = "" if not sort_order else f" ORDER BY col0 {sort_order}" + limit_clause = "" if not limit else f" LIMIT {limit}" + + res = pl.sql( + f""" + WITH data (col0, col1, col2) as ( + VALUES + (1,'a','x'), + (2,'a','y'), + (4,'b','z'), + (8,'b','X'), + (7,'b','Y') + ) + SELECT col1, ARRAY_AGG(col2{order_by}{limit_clause}) AS arrs + FROM data + GROUP BY col1 + ORDER BY col1 + """ + ).collect() + + assert res.rows() == expected + + +def test_array_literals() -> None: + with pl.SQLContext(df=None, eager=True) as ctx: + res = ctx.execute( + """ + SELECT + a1, a2, + -- test some array ops + ARRAY_AGG(a1) AS a3, + ARRAY_AGG(a2) AS a4, + ARRAY_CONTAINS(a1,20) AS i20, + ARRAY_CONTAINS(a2,'zz') AS izz, + ARRAY_REVERSE(a1) AS ar1, + ARRAY_REVERSE(a2) AS ar2 + FROM ( + SELECT + -- declare array literals + [10,20,30] AS a1, + ['a','b','c'] AS a2, + FROM df + ) tbl + """ + ) + assert_frame_equal( + res, + pl.DataFrame( + { + "a1": [[10, 20, 30]], + "a2": [["a", "b", "c"]], + "a3": [[[10, 20, 30]]], + "a4": [[["a", "b", "c"]]], + "i20": [True], + "izz": [False], + "ar1": [[30, 20, 10]], + "ar2": [["c", "b", "a"]], + } + ), + ) + + +@pytest.mark.parametrize( + ("array_index", "expected"), + [ + (-4, None), + (-3, 99), + (-2, 66), + (-1, 33), + (0, None), + (1, 99), + (2, 66), + (3, 33), + (4, None), + ], +) +def test_array_indexing(array_index: int, expected: int | None) -> None: + res = pl.sql( + f""" + SELECT + arr[{array_index}] AS idx1, + ARRAY_GET(arr,{array_index}) AS idx2, + FROM (SELECT [99,66,33] AS arr) tbl + """ + ).collect() + + assert_frame_equal( + res, + pl.DataFrame( + {"idx1": [expected], "idx2": [expected]}, + ), + check_dtypes=False, + ) + + +def test_array_indexing_by_expr() -> None: + df = pl.DataFrame( + { + "idx": [-2, -1, 0, None, 1, 2, 3], + "arr": [[0, 1, 2, 3], [4, 5], [6], [7, 8, 9], [8, 7], [6, 5, 4], [3, 2, 1]], + } + ) + res = df.sql( + """ + SELECT + arr[idx] AS idx1, + ARRAY_GET(arr, idx) AS idx2 + FROM self + """ + ) + expected = [2, 5, None, None, 8, 5, 1] + assert_frame_equal(res, pl.DataFrame({"idx1": expected, "idx2": expected})) + + +def test_array_to_string() -> None: + data = { + "s_values": [["aa", "bb"], [None, "cc"], ["dd", None]], + "n_values": [[999, 777], [None, 555], [333, None]], + } + res = pl.DataFrame(data).sql( + """ + SELECT + ARRAY_TO_STRING(s_values, '') AS vs1, + ARRAY_TO_STRING(s_values, ':') AS vs2, + ARRAY_TO_STRING(s_values, ':', 'NA') AS vs3, + ARRAY_TO_STRING(n_values, '') AS vn1, + ARRAY_TO_STRING(n_values, ':') AS vn2, + ARRAY_TO_STRING(n_values, ':', 'NA') AS vn3 + FROM self + """ + ) + assert_frame_equal( + res, + pl.DataFrame( + { + "vs1": ["aabb", "cc", "dd"], + "vs2": ["aa:bb", "cc", "dd"], + "vs3": ["aa:bb", "NA:cc", "dd:NA"], + "vn1": ["999777", "555", "333"], + "vn2": ["999:777", "555", "333"], + "vn3": ["999:777", "NA:555", "333:NA"], + } + ), + ) + with pytest.raises( + SQLSyntaxError, + match=r"ARRAY_TO_STRING expects 2-3 arguments \(found 1\)", + ): + pl.sql_expr("ARRAY_TO_STRING(arr)") + + +@pytest.mark.parametrize( + "array_keyword", + ["ARRAY", ""], +) +def test_unnest_table_function(array_keyword: str) -> None: + with pl.SQLContext(df=None, eager=True) as ctx: + res = ctx.execute( + f""" + SELECT * FROM + UNNEST( + {array_keyword}[1, 2, 3, 4], + {array_keyword}['ww','xx','yy','zz'], + {array_keyword}[23.0, 24.5, 28.0, 27.5] + ) AS tbl (x,y,z); + """ + ) + assert_frame_equal( + res, + pl.DataFrame( + { + "x": [1, 2, 3, 4], + "y": ["ww", "xx", "yy", "zz"], + "z": [23.0, 24.5, 28.0, 27.5], + } + ), + ) + + +def test_unnest_table_function_errors() -> None: + with pl.SQLContext(df=None, eager=True) as ctx: + with pytest.raises( + SQLSyntaxError, + match=r'UNNEST table alias must also declare column names, eg: "frame data" \(a,b,c\)', + ): + ctx.execute('SELECT * FROM UNNEST([1, 2, 3]) AS "frame data"') + + with pytest.raises( + SQLSyntaxError, + match="UNNEST table alias requires 1 column name, found 2", + ): + ctx.execute("SELECT * FROM UNNEST([1, 2, 3]) AS tbl (a, b)") + + with pytest.raises( + SQLSyntaxError, + match="UNNEST table alias requires 2 column names, found 1", + ): + ctx.execute("SELECT * FROM UNNEST([1,2,3], [3,4,5]) AS tbl (a)") + + with pytest.raises( + SQLSyntaxError, + match=r"UNNEST table must have an alias", + ): + ctx.execute("SELECT * FROM UNNEST([1, 2, 3])") + + with pytest.raises( + SQLInterfaceError, + match=r"UNNEST tables do not \(yet\) support WITH OFFSET|ORDINALITY", + ): + ctx.execute("SELECT * FROM UNNEST([1, 2, 3]) tbl (colx) WITH OFFSET") + + with pytest.raises( + SQLInterfaceError, + match="nested array literals are not currently supported", + ): + pl.sql_expr("[[1,2,3]] AS nested") diff --git a/py-polars/tests/unit/sql/test_bitwise.py b/py-polars/tests/unit/sql/test_bitwise.py new file mode 100644 index 000000000000..7bba8cfde5c5 --- /dev/null +++ b/py-polars/tests/unit/sql/test_bitwise.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [20, 32, 50, 88, 128], + "y": [-128, 0, 10, -1, None], + } + ) + + +def test_bitwise_and(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x & y AS x_bitand_op_y, + BITAND(y, x) AS y_bitand_x, + BIT_AND(x, y) AS x_bitand_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitand_op_y": [0, 0, 2, 88, None], + "y_bitand_x": [0, 0, 2, 88, None], + "x_bitand_y": [0, 0, 2, 88, None], + } + + +def test_bitwise_count(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + BITCOUNT(x) AS x_bits_set, + BIT_COUNT(y) AS y_bits_set, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bits_set": [2, 1, 3, 3, 1], + "y_bits_set": [57, 0, 2, 64, None], + } + + +def test_bitwise_or(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x | y AS x_bitor_op_y, + BITOR(y, x) AS y_bitor_x, + BIT_OR(x, y) AS x_bitor_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitor_op_y": [-108, 32, 58, -1, None], + "y_bitor_x": [-108, 32, 58, -1, None], + "x_bitor_y": [-108, 32, 58, -1, None], + } + + +def test_bitwise_xor(df: pl.DataFrame) -> None: + res = df.sql( + """ + SELECT + x XOR y AS x_bitxor_op_y, + BITXOR(y, x) AS y_bitxor_x, + BIT_XOR(x, y) AS x_bitxor_y, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "x_bitxor_op_y": [-108, 32, 56, -89, None], + "y_bitxor_x": [-108, 32, 56, -89, None], + "x_bitxor_y": [-108, 32, 56, -89, None], + } diff --git a/py-polars/tests/unit/sql/test_cast.py b/py-polars/tests/unit/sql/test_cast.py new file mode 100644 index 000000000000..f3ec824b6023 --- /dev/null +++ b/py-polars/tests/unit/sql/test_cast.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.exceptions import InvalidOperationError, SQLInterfaceError +from polars.testing import assert_frame_equal + + +def test_cast() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [1.1, 2.2, 3.3, 4.4, 5.5], + "c": ["a", "b", "c", "d", "e"], + "d": [True, False, True, False, True], + "e": [-1, 0, None, 1, 2], + } + ) + + # test various dtype casts, using standard ("CAST AS ") + # and postgres-specific ("::") cast syntax + with pl.SQLContext(df=df, eager=True) as ctx: + res = ctx.execute( + """ + SELECT + -- float + CAST(a AS DOUBLE PRECISION) AS a_f64, + a::real AS a_f32, + b::float(24) AS b_f32, + b::float(25) AS b_f64, + e::float8 AS e_f64, + e::float4 AS e_f32, + + -- integer + CAST(b AS TINYINT) AS b_i8, + CAST(b AS SMALLINT) AS b_i16, + b::bigint AS b_i64, + d::tinyint AS d_i8, + d::hugeint AS d_i128, + a::int1 AS a_i8, + a::int2 AS a_i16, + a::int4 AS a_i32, + a::int8 AS a_i64, + + -- unsigned integer + CAST(a AS TINYINT UNSIGNED) AS a_u8, + d::uint1 AS d_u8, + a::uint2 AS a_u16, + b::uint4 AS b_u32, + b::uint8 AS b_u64, + CAST(a AS BIGINT UNSIGNED) AS a_u64, + b::utinyint AS b_u8, + b::usmallint AS b_u16, + a::uinteger AS a_u32, + d::ubigint AS d_u64, + + -- string/binary + CAST(a AS CHAR) AS a_char, + CAST(b AS VARCHAR) AS b_varchar, + c::blob AS c_blob, + c::bytes AS c_bytes, + c::VARBINARY AS c_varbinary, + CAST(d AS CHARACTER VARYING) AS d_charvar, + + -- boolean + e::bool AS e_bool, + e::boolean AS e_boolean + FROM df + """ + ) + assert res.schema == { + "a_f64": pl.Float64, + "a_f32": pl.Float32, + "b_f32": pl.Float32, + "b_f64": pl.Float64, + "e_f64": pl.Float64, + "e_f32": pl.Float32, + "b_i8": pl.Int8, + "b_i16": pl.Int16, + "b_i64": pl.Int64, + "d_i8": pl.Int8, + "d_i128": pl.Int128, + "a_i8": pl.Int8, + "a_i16": pl.Int16, + "a_i32": pl.Int32, + "a_i64": pl.Int64, + "a_u8": pl.UInt8, + "d_u8": pl.UInt8, + "a_u16": pl.UInt16, + "b_u32": pl.UInt32, + "b_u64": pl.UInt64, + "a_u64": pl.UInt64, + "b_u8": pl.UInt8, + "b_u16": pl.UInt16, + "a_u32": pl.UInt32, + "d_u64": pl.UInt64, + "a_char": pl.String, + "b_varchar": pl.String, + "c_blob": pl.Binary, + "c_bytes": pl.Binary, + "c_varbinary": pl.Binary, + "d_charvar": pl.String, + "e_bool": pl.Boolean, + "e_boolean": pl.Boolean, + } + assert res.select(cs.by_dtype(pl.Float32)).rows() == pytest.approx( + [ + (1.0, 1.100000023841858, -1.0), + (2.0, 2.200000047683716, 0.0), + (3.0, 3.299999952316284, None), + (4.0, 4.400000095367432, 1.0), + (5.0, 5.5, 2.0), + ] + ) + assert res.select(cs.by_dtype(pl.Float64)).rows() == [ + (1.0, 1.1, -1.0), + (2.0, 2.2, 0.0), + (3.0, 3.3, None), + (4.0, 4.4, 1.0), + (5.0, 5.5, 2.0), + ] + assert res.select(cs.integer()).rows() == [ + (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + (2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0), + (3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1), + (4, 4, 4, 0, 0, 4, 4, 4, 4, 4, 0, 4, 4, 4, 4, 4, 4, 4, 0), + (5, 5, 5, 1, 1, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 1), + ] + assert res.select(cs.string()).rows() == [ + ("1", "1.1", "true"), + ("2", "2.2", "false"), + ("3", "3.3", "true"), + ("4", "4.4", "false"), + ("5", "5.5", "true"), + ] + assert res.select(cs.binary()).rows() == [ + (b"a", b"a", b"a"), + (b"b", b"b", b"b"), + (b"c", b"c", b"c"), + (b"d", b"d", b"d"), + (b"e", b"e", b"e"), + ] + assert res.select(cs.boolean()).rows() == [ + (True, True), + (False, False), + (None, None), + (True, True), + (True, True), + ] + + with pytest.raises( + SQLInterfaceError, + match="use of FORMAT is not currently supported in CAST", + ): + pl.SQLContext(df=df, eager=True).execute( + "SELECT CAST(a AS STRING FORMAT 'HEX') FROM df" + ) + + +@pytest.mark.parametrize( + ("values", "cast_op", "error"), + [ + ([1.0, -1.0], "values::uint8", "conversion from `f64` to `u64` failed"), + ([10, 0, -1], "values::uint4", "conversion from `i64` to `u32` failed"), + ([int(1e8)], "values::int1", "conversion from `i64` to `i8` failed"), + (["a", "b"], "values::date", "conversion from `str` to `date` failed"), + (["a", "b"], "values::time", "conversion from `str` to `time` failed"), + (["a", "b"], "values::int4", "conversion from `str` to `i32` failed"), + ], +) +def test_cast_errors(values: Any, cast_op: str, error: str) -> None: + df = pl.DataFrame({"values": values}) + + # invalid CAST should raise an error... + with pytest.raises(InvalidOperationError, match=error): + df.sql(f"SELECT {cast_op} FROM self") + + # ... or return `null` values if using TRY_CAST + target_type = cast_op.split("::")[1] + res = df.sql(f"SELECT TRY_CAST(values AS {target_type}) AS cast_values FROM self") + assert None in res.to_series() + + +def test_cast_json() -> None: + df = pl.DataFrame({"txt": ['{"a":[1,2,3],"b":["x","y","z"],"c":5.0}']}) + + with pl.SQLContext(df=df, eager=True) as ctx: + for json_cast in ("txt::json", "CAST(txt AS JSON)"): + res = ctx.execute(f"SELECT {json_cast} AS j FROM df") + + assert res.schema == { + "j": pl.Struct( + { + "a": pl.List(pl.Int64), + "b": pl.List(pl.String), + "c": pl.Float64, + }, + ) + } + assert_frame_equal( + res.unnest("j"), + pl.DataFrame( + { + "a": [[1, 2, 3]], + "b": [["x", "y", "z"]], + "c": [5.0], + } + ), + ) diff --git a/py-polars/tests/unit/sql/test_conditional.py b/py-polars/tests/unit/sql/test_conditional.py new file mode 100644 index 000000000000..3a80c1234aff --- /dev/null +++ b/py-polars/tests/unit/sql/test_conditional.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from datetime import date +from pathlib import Path + +import pytest + +import polars as pl +from polars.exceptions import SQLSyntaxError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_case_when() -> None: + lf = pl.LazyFrame( + { + "v1": [None, 2, None, 4], + "v2": [101, 202, 303, 404], + } + ) + with pl.SQLContext(test_data=lf, eager=True) as ctx: + out = ctx.execute( + """ + SELECT *, CASE WHEN COALESCE(v1, v2) % 2 != 0 THEN 'odd' ELSE 'even' END as "v3" + FROM test_data + """ + ) + assert out.to_dict(as_series=False) == { + "v1": [None, 2, None, 4], + "v2": [101, 202, 303, 404], + "v3": ["odd", "even", "odd", "even"], + } + + +@pytest.mark.parametrize("else_clause", ["ELSE NULL ", ""]) +def test_case_when_optional_else(else_clause: str) -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7], + "b": [7, 6, 5, 4, 3, 2, 1], + "c": [3, 4, 0, 3, 4, 1, 1], + } + ) + query = f""" + SELECT + AVG(CASE WHEN a <= b THEN c {else_clause}END) AS conditional_mean + FROM self + """ + res = df.sql(query) + assert res.to_dict(as_series=False) == {"conditional_mean": [2.5]} + + +def test_control_flow(foods_ipc_path: Path) -> None: + nums = pl.LazyFrame( + { + "x": [1, None, 2, 3, None, 4], + "y": [5, 4, None, 3, None, 2], + "z": [3, 4, None, 3, 6, None], + } + ) + res = pl.SQLContext(df=nums).execute( + """ + SELECT + COALESCE(x,y,z) as "coalsc", + NULLIF(x, y) as "nullif x_y", + NULLIF(y, z) as "nullif y_z", + IFNULL(x, y) as "ifnull x_y", + IFNULL(y,-1) as "inullf y_z", + COALESCE(x, NULLIF(y,z)) as "both", + IF(x = y, 'eq', 'ne') as "x_eq_y", + FROM df + """, + eager=True, + ) + assert res.to_dict(as_series=False) == { + "coalsc": [1, 4, 2, 3, 6, 4], + "nullif x_y": [1, None, 2, None, None, 4], + "nullif y_z": [5, None, None, None, None, 2], + "ifnull x_y": [1, 4, 2, 3, None, 4], + "inullf y_z": [5, 4, -1, 3, -1, 2], + "both": [1, None, 2, 3, None, 4], + "x_eq_y": ["ne", "ne", "ne", "eq", "ne", "ne"], + } + + for null_func in ("IFNULL", "NULLIF"): + with pytest.raises( + SQLSyntaxError, + match=r"(IFNULL|NULLIF) expects 2 arguments \(found 3\)", + ): + pl.SQLContext(df=nums).execute(f"SELECT {null_func}(x,y,z) FROM df") + + +def test_greatest_least() -> None: + df = pl.DataFrame( + { + "a": [-100, None, 200, 99], + "b": [None, -0.1, 99.0, 100.0], + "c": ["bb", "aa", "dd", "cc"], + "d": ["cc", "bb", "aa", "dd"], + "e": [date(1969, 12, 31), date(2021, 1, 2), None, date(2021, 1, 4)], + "f": [date(1970, 1, 1), date(2000, 10, 20), date(2077, 7, 5), None], + } + ) + with pl.SQLContext(df=df) as ctx: + df_max_horizontal = ctx.execute( + """ + SELECT + GREATEST("a", 0, "b") AS max_ab_zero, + GREATEST("a", "b") AS max_ab, + GREATEST("c", "d", ) AS max_cd, + GREATEST("e", "f") AS max_ef, + GREATEST('1999-12-31'::date, "e", "f") AS max_efx + FROM df + """ + ).collect() + + assert_frame_equal( + df_max_horizontal, + pl.DataFrame( + { + "max_ab_zero": [0.0, 0.0, 200.0, 100.0], + "max_ab": [-100.0, -0.1, 200.0, 100.0], + "max_cd": ["cc", "bb", "dd", "dd"], + "max_ef": [ + date(1970, 1, 1), + date(2021, 1, 2), + date(2077, 7, 5), + date(2021, 1, 4), + ], + "max_efx": [ + date(1999, 12, 31), + date(2021, 1, 2), + date(2077, 7, 5), + date(2021, 1, 4), + ], + } + ), + ) + + df_min_horizontal = ctx.execute( + """ + SELECT + LEAST("b", "a", 0) AS min_ab_zero, + LEAST("a", "b") AS min_ab, + LEAST("c", "d") AS min_cd, + LEAST("e", "f") AS min_ef, + LEAST("f", "e", '1999-12-31'::date) AS min_efx + FROM df + """ + ).collect() + + assert_frame_equal( + df_min_horizontal, + pl.DataFrame( + { + "min_ab_zero": [-100.0, -0.1, 0.0, 0.0], + "min_ab": [-100.0, -0.1, 99.0, 99.0], + "min_cd": ["bb", "aa", "aa", "cc"], + "min_ef": [ + date(1969, 12, 31), + date(2000, 10, 20), + date(2077, 7, 5), + date(2021, 1, 4), + ], + "min_efx": [ + date(1969, 12, 31), + date(1999, 12, 31), + date(1999, 12, 31), + date(1999, 12, 31), + ], + } + ), + ) diff --git a/py-polars/tests/unit/sql/test_functions.py b/py-polars/tests/unit/sql/test_functions.py new file mode 100644 index 000000000000..84f2ecd972bc --- /dev/null +++ b/py-polars/tests/unit/sql/test_functions.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_sql_expr() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": ["xyz", "abcde", None]}) + sql_exprs = pl.sql_expr( + [ + "MIN(a)", + "POWER(a,a) AS aa", + "SUBSTR(b,2,2) AS b2", + ] + ) + result = df.select(*sql_exprs) + expected = pl.DataFrame( + {"a": [1, 1, 1], "aa": [1, 4, 27], "b2": ["yz", "bc", None]} + ) + assert_frame_equal(result, expected) + + # expect expressions that can't reasonably be parsed as expressions to raise + # (for example: those that explicitly reference tables and/or use wildcards) + with pytest.raises( + SQLInterfaceError, + match=r"unable to parse 'xyz\.\*' as Expr", + ): + pl.sql_expr("xyz.*") diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py new file mode 100644 index 000000000000..42f19af9d1c7 --- /dev/null +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +from datetime import date +from pathlib import Path + +import pytest + +import polars as pl +from polars.exceptions import SQLSyntaxError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_group_by(foods_ipc_path: Path) -> None: + lf = pl.scan_ipc(foods_ipc_path) + + ctx = pl.SQLContext(eager=True) + ctx.register("foods", lf) + + out = ctx.execute( + """ + SELECT + count(category) as n, + category, + max(calories) as max_cal, + median(calories) as median_cal, + min(fats_g) as min_fats + FROM foods + GROUP BY category + HAVING n > 5 + ORDER BY n, category DESC + """ + ) + assert out.to_dict(as_series=False) == { + "n": [7, 7, 8], + "category": ["vegetables", "fruit", "seafood"], + "max_cal": [45, 130, 200], + "median_cal": [25.0, 50.0, 145.0], + "min_fats": [0.0, 0.0, 1.5], + } + + lf = pl.LazyFrame( + { + "grp": ["a", "b", "c", "c", "b"], + "att": ["x", "y", "x", "y", "y"], + } + ) + assert ctx.tables() == ["foods"] + + ctx.register("test", lf) + assert ctx.tables() == ["foods", "test"] + + out = ctx.execute( + """ + SELECT + grp, + COUNT(DISTINCT att) AS n_dist_attr + FROM test + GROUP BY grp + HAVING n_dist_attr > 1 + """ + ) + assert out.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]} + + +def test_group_by_all() -> None: + df = pl.DataFrame( + { + "a": ["xx", "yy", "xx", "yy", "xx", "zz"], + "b": [1, 2, 3, 4, 5, 6], + "c": [99, 99, 66, 66, 66, 66], + } + ) + + # basic group/agg + res = df.sql( + """ + SELECT + a, + SUM(b), + SUM(c), + COUNT(*) AS n + FROM self + GROUP BY ALL + ORDER BY ALL + """ + ) + expected = pl.DataFrame( + { + "a": ["xx", "yy", "zz"], + "b": [9, 6, 6], + "c": [231, 165, 66], + "n": [3, 2, 1], + } + ) + assert_frame_equal(expected, res, check_dtypes=False) + + # more involved determination of agg/group columns + res = df.sql( + """ + SELECT + SUM(b) AS sum_b, + SUM(c) AS sum_c, + (SUM(b) + SUM(c)) / 2.0 AS sum_bc_over_2, -- nested agg + a as grp, --aliased group key + FROM self + GROUP BY ALL + ORDER BY grp + """ + ) + expected = pl.DataFrame( + { + "sum_b": [9, 6, 6], + "sum_c": [231, 165, 66], + "sum_bc_over_2": [120.0, 85.5, 36.0], + "grp": ["xx", "yy", "zz"], + } + ) + assert_frame_equal(expected, res.sort(by="grp")) + + +def test_group_by_all_multi() -> None: + dt1 = date(1999, 12, 31) + dt2 = date(2028, 7, 5) + + df = pl.DataFrame( + { + "key": ["xx", "yy", "xx", "yy", "xx", "xx"], + "dt": [dt1, dt1, dt1, dt2, dt2, dt2], + "value": [10.5, -5.5, 20.5, 8.0, -3.0, 5.0], + } + ) + expected = pl.DataFrame( + { + "dt": [dt1, dt1, dt2, dt2], + "key": ["xx", "yy", "xx", "yy"], + "sum_value": [31.0, -5.5, 2.0, 8.0], + "ninety_nine": [99, 99, 99, 99], + }, + schema_overrides={"ninety_nine": pl.Int16}, + ) + + # the following groupings should all be equivalent + for group in ( + "ALL", + "1, 2", + "dt, key", + ): + res = df.sql( + f""" + SELECT dt, key, sum_value, ninety_nine::int2 FROM + ( + SELECT + dt, + key, + SUM(value) AS sum_value, + 99 AS ninety_nine + FROM self + GROUP BY {group} + ORDER BY dt, key + ) AS grp + """ + ) + assert_frame_equal(expected, res) + + +def test_group_by_ordinal_position() -> None: + df = pl.DataFrame( + { + "a": ["xx", "yy", "xx", "yy", "xx", "zz"], + "b": [1, None, 3, 4, 5, 6], + "c": [99, 99, 66, 66, 66, 66], + } + ) + expected = pl.LazyFrame( + { + "c": [66, 99], + "total_b": [18, 1], + "count_b": [4, 1], + "count_star": [4, 2], + } + ) + + with pl.SQLContext(frame=df) as ctx: + res1 = ctx.execute( + """ + SELECT + c, + SUM(b) AS total_b, + COUNT(b) AS count_b, + COUNT(*) AS count_star + FROM frame + GROUP BY 1 + ORDER BY c + """ + ) + assert_frame_equal(res1, expected, check_dtypes=False) + + res2 = ctx.execute( + """ + WITH "grp" AS ( + SELECT NULL::date as dt, c, SUM(b) AS total_b + FROM frame + GROUP BY 2, 1 + ) + SELECT c, total_b FROM grp ORDER BY c""" + ) + assert_frame_equal(res2, expected.select(pl.nth(0, 1))) + + +def test_group_by_errors() -> None: + df = pl.DataFrame( + { + "a": ["xx", "yy", "xx"], + "b": [10, 20, 30], + "c": [99, 99, 66], + } + ) + + with pytest.raises( + SQLSyntaxError, + match=r"negative ordinal values are invalid for GROUP BY; found -99", + ): + df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a") + + with pytest.raises( + SQLSyntaxError, + match=r"GROUP BY requires a valid expression or positive ordinal; found '!!!'", + ): + df.sql("SELECT a, SUM(b) FROM self GROUP BY a, '!!!'") + + with pytest.raises( + SQLSyntaxError, + match=r"'a' should participate in the GROUP BY clause or an aggregate function", + ): + df.sql("SELECT a, SUM(b) FROM self GROUP BY b") + + with pytest.raises( + SQLSyntaxError, + match=r"HAVING clause not valid outside of GROUP BY", + ): + df.sql("SELECT a, COUNT(a) AS n FROM self HAVING n > 1") + + +def test_group_by_output_struct() -> None: + df = pl.DataFrame({"g": [1], "x": [2], "y": [3]}) + out = df.group_by("g").agg(pl.struct(pl.col.x.min(), pl.col.y.sum())) + assert out.rows() == [(1, {"x": 2, "y": 3})] diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py new file mode 100644 index 000000000000..55ebc511730f --- /dev/null +++ b/py-polars/tests/unit/sql/test_joins.py @@ -0,0 +1,722 @@ +from __future__ import annotations + +from io import BytesIO +from pathlib import Path +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +@pytest.mark.parametrize( + ("sql", "expected"), + [ + ( + "SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (a,c)", + pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}), + ), + ( + "SELECT * FROM tbl_a SEMI JOIN tbl_b USING (a,c)", + pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}), + ), + ( + "SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (a)", + pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}), + ), + ( + "SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b USING (a)", + pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String}), + ), + ( + "SELECT * FROM tbl_a ANTI JOIN tbl_b USING (a)", + pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String}), + ), + ( + "SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)", + pl.DataFrame({"a": [1, 3], "b": [4, 6], "c": ["w", "z"]}), + ), + ( + "SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)", + pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}), + ), + ( + "SELECT * FROM tbl_a RIGHT ANTI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)", + pl.DataFrame({"a": [2], "b": [5], "c": ["y"]}), + ), + ( + "SELECT * FROM tbl_a RIGHT SEMI JOIN tbl_b USING (b) RIGHT SEMI JOIN tbl_c USING (c)", + pl.DataFrame({"c": ["z"], "d": [25.5]}), + ), + ( + "SELECT * FROM tbl_a RIGHT SEMI JOIN tbl_b USING (b) RIGHT ANTI JOIN tbl_c USING (c)", + pl.DataFrame({"c": ["w", "y"], "d": [10.5, -50.0]}), + ), + ], +) +def test_join_anti_semi(sql: str, expected: pl.DataFrame) -> None: + frames = { + "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}), + "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), + "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}), + } + ctx = pl.SQLContext(frames, eager=True) + assert_frame_equal(expected, ctx.execute(sql)) + + +def test_join_cross() -> None: + frames = { + "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}), + "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), + } + with pl.SQLContext(frames, eager=True) as ctx: + out = ctx.execute( + """ + SELECT * + FROM tbl_a + CROSS JOIN tbl_b + ORDER BY a, b, c + """ + ) + assert out.rows() == [ + (1, 4, "w", 3, 6, "x"), + (1, 4, "w", 2, 5, "y"), + (1, 4, "w", 1, 4, "z"), + (2, 0, "y", 3, 6, "x"), + (2, 0, "y", 2, 5, "y"), + (2, 0, "y", 1, 4, "z"), + (3, 6, "z", 3, 6, "x"), + (3, 6, "z", 2, 5, "y"), + (3, 6, "z", 1, 4, "z"), + ] + + +def test_join_cross_11927() -> None: + df1 = pl.DataFrame({"id": [1, 2, 3]}) # noqa: F841 + df2 = pl.DataFrame({"id": [3, 4, 5]}) # noqa: F841 + df3 = pl.DataFrame({"id": [4, 5, 6]}) # noqa: F841 + + res = pl.sql("SELECT df1.id FROM df1 CROSS JOIN df2 WHERE df1.id = df2.id") + assert_frame_equal(res.collect(), pl.DataFrame({"id": [3]})) + + res = pl.sql("SELECT * FROM df1 CROSS JOIN df3 WHERE df1.id = df3.id") + assert res.collect().is_empty() + + +@pytest.mark.parametrize( + "join_clause", + [ + "ON foods1.category = foods2.category", + "ON foods2.category = foods1.category", + "USING (category)", + ], +) +def test_join_inner(foods_ipc_path: Path, join_clause: str) -> None: + foods1 = pl.scan_ipc(foods_ipc_path) + foods2 = foods1 # noqa: F841 + schema = foods1.collect_schema() + + sort_clause = ", ".join(f'{c} ASC, "{c}:foods2" DESC' for c in schema) + out = pl.sql( + f""" + SELECT * + FROM foods1 + INNER JOIN foods2 {join_clause} + ORDER BY {sort_clause} + LIMIT 2 + """, + eager=True, + ) + + assert_frame_equal( + out, + pl.DataFrame( + { + "category": ["fruit", "fruit"], + "calories": [30, 30], + "fats_g": [0.0, 0.0], + "sugars_g": [3, 5], + "category:foods2": ["fruit", "fruit"], + "calories:foods2": [130, 130], + "fats_g:foods2": [0.0, 0.0], + "sugars_g:foods2": [25, 25], + } + ), + check_dtypes=False, + ) + + +@pytest.mark.parametrize( + "join_clause", + [ + """ + INNER JOIN tbl_b USING (a,b) + INNER JOIN tbl_c USING (c) + """, + """ + INNER JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b + INNER JOIN tbl_c ON tbl_a.c = tbl_c.c + """, + ], +) +def test_join_inner_multi(join_clause: str) -> None: + frames = { + "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}), + "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), + "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}), + } + with pl.SQLContext(frames) as ctx: + assert ctx.tables() == ["tbl_a", "tbl_b", "tbl_c"] + for select_cols in ("a, b, c, d", "tbl_a.a, tbl_a.b, tbl_b.c, tbl_c.d"): + out = ctx.execute( + f"SELECT {select_cols} FROM tbl_a {join_clause} ORDER BY a DESC" + ) + assert out.collect().rows() == [(1, 4, "z", 25.5)] + + +def test_join_inner_15663() -> None: + df_a = pl.DataFrame({"LOCID": [1, 2, 3], "VALUE": [0.1, 0.2, 0.3]}) # noqa: F841 + df_b = pl.DataFrame({"LOCID": [1, 2, 3], "VALUE": [25.6, 53.4, 12.7]}) # noqa: F841 + expected = pl.DataFrame( + { + "LOCID": [1, 2, 3], + "VALUE_A": [0.1, 0.2, 0.3], + "VALUE_B": [25.6, 53.4, 12.7], + } + ) + with pl.SQLContext(register_globals=True, eager=True) as ctx: + query = """ + SELECT + a.LOCID, + a.VALUE AS VALUE_A, + b.VALUE AS VALUE_B + FROM df_a AS a + INNER JOIN df_b AS b + USING (LOCID) + ORDER BY LOCID + """ + actual = ctx.execute(query) + assert_frame_equal(expected, actual) + + +@pytest.mark.parametrize( + "join_clause", + [ + """ + LEFT JOIN tbl_b USING (a,b) + LEFT JOIN tbl_c USING (c) + """, + """ + LEFT JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b + LEFT JOIN tbl_c ON tbl_a.c = tbl_c.c + """, + ], +) +def test_join_left_multi(join_clause: str) -> None: + frames = { + "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}), + "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), + "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}), + } + with pl.SQLContext(frames) as ctx: + for select_cols in ("a, b, c, d", "tbl_a.a, tbl_a.b, tbl_b.c, tbl_c.d"): + out = ctx.execute( + f"SELECT {select_cols} FROM tbl_a {join_clause} ORDER BY a DESC" + ) + assert out.collect().rows() == [ + (3, 6, "x", None), + (2, None, None, None), + (1, 4, "z", 25.5), + ] + + +def test_join_left_multi_nested() -> None: + frames = { + "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}), + "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), + "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}), + } + with pl.SQLContext(frames) as ctx: + for select_cols in ("a, b, c, d", "tbl_x.a, tbl_x.b, tbl_x.c, tbl_c.d"): + out = ctx.execute( + f""" + SELECT {select_cols} FROM (SELECT * + FROM tbl_a + LEFT JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b + ) tbl_x + LEFT JOIN tbl_c ON tbl_x.c = tbl_c.c + ORDER BY tbl_x.a ASC + """ + ).collect() + + assert out.rows() == [ + (1, 4, "z", 25.5), + (2, None, None, None), + (3, 6, "x", None), + ] + + +def test_join_misc_13618() -> None: + import polars as pl + + df = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "B": [5, 4, 3, 2, 1], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + } + ) + res = ( + pl.SQLContext(t=df, t1=df, eager=True) + .execute( + """ + SELECT t.A, t.fruits, t1.B, t1.cars + FROM t + JOIN t1 ON t.A = t1.B + ORDER BY t.A DESC + """ + ) + .to_dict(as_series=False) + ) + assert res == { + "A": [5, 4, 3, 2, 1], + "fruits": ["banana", "apple", "apple", "banana", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + } + + +def test_join_misc_16255() -> None: + df1 = pl.read_csv(BytesIO(b"id,data\n1,open")) # noqa: F841 + df2 = pl.read_csv(BytesIO(b"id,data\n1,closed")) # noqa: F841 + res = pl.sql( + """ + SELECT a.id, a.data AS d1, b.data AS d2 + FROM df1 AS a JOIN df2 AS b + ON a.id = b.id + """, + eager=True, + ) + assert res.rows() == [(1, "open", "closed")] + + +@pytest.mark.parametrize( + "constraint", ["tbl.a != tbl.b", "tbl.a > tbl.b", "a >= b", "a < b", "b <= a"] +) +def test_non_equi_joins(constraint: str) -> None: + # no support (yet) for non equi-joins in polars joins + # TODO: integrate awareness of new IEJoin + with ( + pytest.raises( + SQLInterfaceError, + match=r"only equi-join constraints \(combined with 'AND'\) are currently supported", + ), + pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx, + ): + ctx.execute( + f""" + SELECT * + FROM tbl + LEFT JOIN tbl ON {constraint} -- not an equi-join + """ + ) + + +def test_implicit_joins() -> None: + # no support for this yet; ensure we catch it + with ( + pytest.raises( + SQLInterfaceError, + match=r"not currently supported .* use explicit JOIN syntax instead", + ), + pl.SQLContext( + { + "tbl": pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 3, 2], "c": ["x", "y", "z"]} + ) + } + ) as ctx, + ): + ctx.execute( + """ + SELECT t1.* + FROM tbl AS t1, tbl AS t2 + WHERE t1.a = t2.b + """ + ) + + +@pytest.mark.parametrize( + ("query", "expected"), + [ + # INNER joins + ( + "SELECT df1.* FROM df1 INNER JOIN df2 USING (a)", + {"a": [1, 3], "b": ["x", "z"], "c": [100, 300]}, + ), + ( + "SELECT df2.* FROM df1 INNER JOIN df2 USING (a)", + {"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]}, + ), + ( + "SELECT df1.* FROM df2 INNER JOIN df1 USING (a)", + {"a": [1, 3], "b": ["x", "z"], "c": [100, 300]}, + ), + ( + "SELECT df2.* FROM df2 INNER JOIN df1 USING (a)", + {"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]}, + ), + # LEFT joins + ( + "SELECT df1.* FROM df1 LEFT JOIN df2 USING (a)", + {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}, + ), + ( + "SELECT df2.* FROM df1 LEFT JOIN df2 USING (a)", + {"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]}, + ), + ( + "SELECT df1.* FROM df2 LEFT JOIN df1 USING (a)", + {"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]}, + ), + ( + "SELECT df2.* FROM df2 LEFT JOIN df1 USING (a)", + {"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}, + ), + # RIGHT joins + ( + "SELECT df1.* FROM df1 RIGHT JOIN df2 USING (a)", + {"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]}, + ), + ( + "SELECT df2.* FROM df1 RIGHT JOIN df2 USING (a)", + {"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}, + ), + ( + "SELECT df1.* FROM df2 RIGHT JOIN df1 USING (a)", + {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}, + ), + ( + "SELECT df2.* FROM df2 RIGHT JOIN df1 USING (a)", + {"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]}, + ), + # FULL joins + ( + "SELECT df1.* FROM df1 FULL JOIN df2 USING (a)", + { + "a": [1, 2, 3, None], + "b": ["x", "y", "z", None], + "c": [100, 200, 300, None], + }, + ), + ( + "SELECT df2.* FROM df1 FULL JOIN df2 USING (a)", + { + "a": [1, 3, 4, None], + "b": ["qq", "pp", "oo", None], + "c": [400, 500, 600, None], + }, + ), + ( + "SELECT df1.* FROM df2 FULL JOIN df1 USING (a)", + { + "a": [1, 2, 3, None], + "b": ["x", "y", "z", None], + "c": [100, 200, 300, None], + }, + ), + ( + "SELECT df2.* FROM df2 FULL JOIN df1 USING (a)", + { + "a": [1, 3, 4, None], + "b": ["qq", "pp", "oo", None], + "c": [400, 500, 600, None], + }, + ), + ], +) +def test_wildcard_resolution_and_join_order( + query: str, expected: dict[str, Any] +) -> None: + df1 = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]}) # noqa: F841 + df2 = pl.DataFrame({"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]}) # noqa: F841 + + res = pl.sql(query).collect() + assert_frame_equal( + res, + pl.DataFrame(expected), + check_row_order=False, + ) + + +def test_natural_joins_01() -> None: + df1 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 4], + "FirstName": ["Jernau Morat", "Cheradenine", "Byr", "Diziet"], + "LastName": ["Gurgeh", "Zakalwe", "Genar-Hofoen", "Sma"], + } + ) + df2 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 5], + "Role": ["Protagonist", "Protagonist", "Protagonist", "Antagonist"], + "Book": [ + "Player of Games", + "Use of Weapons", + "Excession", + "Consider Phlebas", + ], + } + ) + df3 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 4], + "Affiliation": ["Culture", "Culture", "Culture", "Shellworld"], + "Species": ["Pan-human", "Human", "Human", "Oct"], + } + ) + df4 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 6], + "Ship": [ + "Limiting Factor", + "Xenophobe", + "Grey Area", + "Falling Outside The Normal Moral Constraints", + ], + "Drone": ["Flere-Imsaho", "Skaffen-Amtiskaw", "Eccentric", "Psychopath"], + } + ) + with pl.SQLContext( + {"df1": df1, "df2": df2, "df3": df3, "df4": df4}, eager=True + ) as ctx: + # note: use of 'COLUMNS' is a neat way to drop + # all non-coalesced ":" cols + res = ctx.execute( + """ + SELECT COLUMNS('^[^:]*$') + FROM df1 + NATURAL LEFT JOIN df2 + NATURAL INNER JOIN df3 + NATURAL LEFT JOIN df4 + ORDER BY ALL + """ + ) + assert res.rows(named=True) == [ + { + "CharacterID": 1, + "FirstName": "Jernau Morat", + "LastName": "Gurgeh", + "Role": "Protagonist", + "Book": "Player of Games", + "Affiliation": "Culture", + "Species": "Pan-human", + "Ship": "Limiting Factor", + "Drone": "Flere-Imsaho", + }, + { + "CharacterID": 2, + "FirstName": "Cheradenine", + "LastName": "Zakalwe", + "Role": "Protagonist", + "Book": "Use of Weapons", + "Affiliation": "Culture", + "Species": "Human", + "Ship": "Xenophobe", + "Drone": "Skaffen-Amtiskaw", + }, + { + "CharacterID": 3, + "FirstName": "Byr", + "LastName": "Genar-Hofoen", + "Role": "Protagonist", + "Book": "Excession", + "Affiliation": "Culture", + "Species": "Human", + "Ship": "Grey Area", + "Drone": "Eccentric", + }, + { + "CharacterID": 4, + "FirstName": "Diziet", + "LastName": "Sma", + "Role": None, + "Book": None, + "Affiliation": "Shellworld", + "Species": "Oct", + "Ship": None, + "Drone": None, + }, + ] + + # misc errors + with pytest.raises(SQLSyntaxError, match=r"did you mean COLUMNS\(\*\)\?"): + pl.sql("SELECT * FROM df1 NATURAL JOIN df2 WHERE COLUMNS('*') >= 5") + + with pytest.raises(SQLSyntaxError, match=r"COLUMNS expects a regex"): + pl.sql("SELECT COLUMNS(1234) FROM df1 NATURAL JOIN df2") + + +@pytest.mark.parametrize( + ("cols_constraint", "expect_data"), + [ + (">= 5", [(8, 8, 6)]), + ("< 7", [(5, 4, 4)]), + ("< 8", [(5, 4, 4), (7, 4, 4), (0, 7, 2)]), + ("!= 4", [(8, 8, 6), (2, 8, 6), (0, 7, 2)]), + ], +) +def test_natural_joins_02(cols_constraint: str, expect_data: list[tuple[int]]) -> None: + df1 = pl.DataFrame( # noqa: F841 + { + "x": [1, 5, 3, 8, 6, 7, 4, 0, 2], + "y": [3, 4, 6, 8, 3, 4, 1, 7, 8], + } + ) + df2 = pl.DataFrame( # noqa: F841 + { + "y": [0, 4, 0, 8, 0, 4, 0, 7, None], + "z": [9, 8, 7, 6, 5, 4, 3, 2, 1], + }, + ) + actual = pl.sql( + f""" + SELECT * EXCLUDE "y:df2" + FROM df1 NATURAL JOIN df2 + WHERE COLUMNS(*) {cols_constraint} + """ + ).collect() + + expected = pl.DataFrame(expect_data, schema=actual.columns, orient="row") + assert_frame_equal(actual, expected, check_row_order=False) + + +@pytest.mark.parametrize( + "join_clause", + [ + """ + df2 JOIN df3 ON + df2.CharacterID = df3.CharacterID + """, + """ + df2 INNER JOIN ( + df3 JOIN df4 ON df3.CharacterID = df4.CharacterID + ) AS r0 ON df2.CharacterID = df3.CharacterID + """, + ], +) +def test_nested_join(join_clause: str) -> None: + df1 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 4], + "FirstName": ["Jernau Morat", "Cheradenine", "Byr", "Diziet"], + "LastName": ["Gurgeh", "Zakalwe", "Genar-Hofoen", "Sma"], + } + ) + df2 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 5], + "Role": ["Protagonist", "Protagonist", "Protagonist", "Antagonist"], + "Book": [ + "Player of Games", + "Use of Weapons", + "Excession", + "Consider Phlebas", + ], + } + ) + df3 = pl.DataFrame( + { + "CharacterID": [1, 2, 5, 6], + "Affiliation": ["Culture", "Culture", "Culture", "Shellworld"], + "Species": ["Pan-human", "Human", "Human", "Oct"], + } + ) + df4 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 6], + "Ship": [ + "Limiting Factor", + "Xenophobe", + "Grey Area", + "Falling Outside The Normal Moral Constraints", + ], + "Drone": ["Flere-Imsaho", "Skaffen-Amtiskaw", "Eccentric", "Psychopath"], + } + ) + + with pl.SQLContext( + {"df1": df1, "df2": df2, "df3": df3, "df4": df4}, eager=True + ) as ctx: + res = ctx.execute( + f""" + SELECT df1.CharacterID, df1.FirstName, df2.Role, df3.Species + FROM df1 + INNER JOIN ({join_clause}) AS r99 + ON df1.CharacterID = df2.CharacterID + ORDER BY ALL + """ + ) + assert res.rows(named=True) == [ + { + "CharacterID": 1, + "FirstName": "Jernau Morat", + "Role": "Protagonist", + "Species": "Pan-human", + }, + { + "CharacterID": 2, + "FirstName": "Cheradenine", + "Role": "Protagonist", + "Species": "Human", + }, + ] + + +def test_sql_forbid_nested_join_unnamed_relation() -> None: + df = pl.DataFrame({"a": 1}) + + with ( + pl.SQLContext({"left": df, "right": df}) as ctx, + pytest.raises(SQLInterfaceError, match="cannot join on unnamed relation"), + ): + ctx.execute( + """\ +SELECT * +FROM left +JOIN (right JOIN right ON right.a = right.a) +ON left.a = right.a +""" + ) + + +def test_nulls_equal_19624() -> None: + df1 = pl.DataFrame({"a": [1, 2, None, None]}) + df2 = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]}) + + # left join + result_df = df1.join(df2, how="left", on="a", nulls_equal=False, validate="1:m") + expected_df = pl.DataFrame( + {"a": [1, 1, 2, 2, None, None], "b": [0, 1, 2, 3, None, None]} + ) + assert_frame_equal(result_df, expected_df) + result_df = df2.join(df1, how="left", on="a", nulls_equal=False, validate="m:1") + expected_df = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]}) + assert_frame_equal(result_df, expected_df) + + # inner join + result_df = df1.join(df2, how="inner", on="a", nulls_equal=False, validate="1:m") + expected_df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]}) + assert_frame_equal(result_df, expected_df) + result_df = df2.join(df1, how="inner", on="a", nulls_equal=False, validate="m:1") + expected_df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]}) + assert_frame_equal(result_df, expected_df) diff --git a/py-polars/tests/unit/sql/test_literals.py b/py-polars/tests/unit/sql/test_literals.py new file mode 100644 index 000000000000..ebf6834e0664 --- /dev/null +++ b/py-polars/tests/unit/sql/test_literals.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal + + +def test_bit_hex_literals() -> None: + with pl.SQLContext(df=None, eager=True) as ctx: + out = ctx.execute( + """ + SELECT *, + -- bit strings + b'' AS b0, + b'1001' AS b1, + b'11101011' AS b2, + b'1111110100110010' AS b3, + -- hex strings + x'' AS x0, + x'FF' AS x1, + x'4142' AS x2, + x'DeadBeef' AS x3, + FROM df + """ + ) + + assert out.to_dict(as_series=False) == { + "b0": [b""], + "b1": [b"\t"], + "b2": [b"\xeb"], + "b3": [b"\xfd2"], + "x0": [b""], + "x1": [b"\xff"], + "x2": [b"AB"], + "x3": [b"\xde\xad\xbe\xef"], + } + + +def test_bit_hex_filter() -> None: + df = pl.DataFrame( + {"bin": [b"\x01", b"\x02", b"\x03", b"\x04"], "val": [9, 8, 7, 6]} + ) + with pl.SQLContext(test=df) as ctx: + for two in ("b'10'", "x'02'", "'\x02'", "b'0010'"): + out = ctx.execute(f"SELECT val FROM test WHERE bin > {two}", eager=True) + assert out.to_series().to_list() == [7, 6] + + +def test_bit_hex_errors() -> None: + with pl.SQLContext(test=None) as ctx: + with pytest.raises( + SQLSyntaxError, + match="bit string literal should contain only 0s and 1s", + ): + ctx.execute("SELECT b'007' FROM test", eager=True) + + with pytest.raises( + SQLSyntaxError, + match="hex string literal must have an even number of digits", + ): + ctx.execute("SELECT x'00F' FROM test", eager=True) + + with pytest.raises( + SQLSyntaxError, + match="hex string literal must have an even number of digits", + ): + pl.sql_expr("colx IN (x'FF',x'123')") + + with pytest.raises( + SQLInterfaceError, + match=r'NationalStringLiteral\("hmmm"\) is not a supported literal', + ): + pl.sql_expr("N'hmmm'") + + +def test_bit_hex_membership() -> None: + df = pl.DataFrame( + { + "x": [b"\x05", b"\xff", b"\xcc", b"\x0b"], + "y": [1, 2, 3, 4], + } + ) + # this checks the internal `visit_any_value` codepath + for values in ( + "b'0101', b'1011'", + "x'05', x'0b'", + ): + dff = df.filter(pl.sql_expr(f"x IN ({values})")) + assert dff["y"].to_list() == [1, 4] + + +def test_dollar_quoted_literals() -> None: + df = pl.sql( + """ + SELECT + $$xyz$$ AS dq1, + $q$xyz$q$ AS dq2, + $tag$xyz$tag$ AS dq3, + $QUOTE$xyz$QUOTE$ AS dq4, + """ + ).collect() + assert df.to_dict(as_series=False) == {f"dq{n}": ["xyz"] for n in range(1, 5)} + + df = pl.sql("SELECT $$x$z$$ AS dq").collect() + assert df.item() == "x$z" + + +def test_fixed_intervals() -> None: + with pl.SQLContext(df=None, eager=True) as ctx: + out = ctx.execute( + """ + SELECT + -- short form with/without spaces + INTERVAL '1w2h3m4s' AS i1, + INTERVAL '100ms 100us' AS i2, + -- long form with/without commas (case-insensitive) + INTERVAL '1 week, 2 hours, 3 minutes, 4 seconds' AS i3 + FROM df + """ + ) + expected = pl.DataFrame( + { + "i1": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)], + "i2": [timedelta(microseconds=100100)], + "i3": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)], + }, + ).cast(pl.Duration("ns")) + + assert_frame_equal(expected, out) + + # TODO: negative intervals + with pytest.raises( + SQLInterfaceError, + match="minus signs are not yet supported in interval strings; found '-7d'", + ): + ctx.execute("SELECT INTERVAL '-7d' AS one_week_ago FROM df") + + with pytest.raises( + SQLSyntaxError, + match="unary ops are not valid on interval strings; found -'7d'", + ): + ctx.execute("SELECT INTERVAL -'7d' AS one_week_ago FROM df") + + with pytest.raises( + SQLSyntaxError, + match="fixed-duration interval cannot contain years, quarters, or months", + ): + ctx.execute("SELECT INTERVAL '1 quarter 1 month' AS q FROM df") + + +def test_interval_offsets() -> None: + df = pl.DataFrame( + { + "dtm": [ + datetime(1899, 12, 31, 8), + datetime(1999, 6, 8, 10, 30), + datetime(2010, 5, 7, 20, 20, 20), + ], + "dt": [ + date(1950, 4, 10), + date(2048, 1, 20), + date(2026, 8, 5), + ], + } + ) + + out = df.sql( + """ + SELECT + dtm + INTERVAL '2 months, 30 minutes' AS dtm_plus_2mo30m, + dt + INTERVAL '100 years' AS dt_plus_100y, + dt - INTERVAL '1 quarter' AS dt_minus_1q + FROM self + ORDER BY 1 + """ + ) + assert out.to_dict(as_series=False) == { + "dtm_plus_2mo30m": [ + datetime(1900, 2, 28, 8, 30), + datetime(1999, 8, 8, 11, 0), + datetime(2010, 7, 7, 20, 50, 20), + ], + "dt_plus_100y": [ + date(2050, 4, 10), + date(2148, 1, 20), + date(2126, 8, 5), + ], + "dt_minus_1q": [ + date(1950, 1, 10), + date(2047, 10, 20), + date(2026, 5, 5), + ], + } + + +@pytest.mark.parametrize( + ("interval_comparison", "expected_result"), + [ + ("INTERVAL '3 days' <= INTERVAL '3 days, 1 microsecond'", True), + ("INTERVAL '3 days, 1 microsecond' <= INTERVAL '3 days'", False), + ("INTERVAL '3 months' >= INTERVAL '3 months'", True), + ("INTERVAL '2 quarters' < INTERVAL '2 quarters'", False), + ("INTERVAL '2 quarters' > INTERVAL '2 quarters'", False), + ("INTERVAL '3 years' <=> INTERVAL '3 years'", True), + ("INTERVAL '3 years' == INTERVAL '1008 weeks'", False), + ("INTERVAL '8 weeks' != INTERVAL '2 months'", True), + ("INTERVAL '8 weeks' = INTERVAL '2 months'", False), + ("INTERVAL '1 year' != INTERVAL '365 days'", True), + ("INTERVAL '1 year' = INTERVAL '1 year'", True), + ], +) +def test_interval_comparisons(interval_comparison: str, expected_result: bool) -> None: + with pl.SQLContext() as ctx: + res = ctx.execute(f"SELECT {interval_comparison} AS res") + assert res.collect().to_dict(as_series=False) == {"res": [expected_result]} + + +def test_select_literals_no_table() -> None: + res = pl.sql("SELECT 1 AS one, '2' AS two, 3.0 AS three", eager=True) + assert res.to_dict(as_series=False) == { + "one": [1], + "two": ["2"], + "three": [3.0], + } + + +def test_select_from_table_with_reserved_names() -> None: + select = pl.DataFrame({"select": [1, 2, 3], "from": [4, 5, 6]}) # noqa: F841 + out = pl.sql( + """ + SELECT "from", "select" + FROM "select" + WHERE "from" >= 5 AND "select" % 2 != 1 + """, + eager=True, + ) + assert out.rows() == [(5, 2)] diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py new file mode 100644 index 000000000000..66942e9f02c6 --- /dev/null +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -0,0 +1,509 @@ +from __future__ import annotations + +from datetime import date +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal +from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder + +if TYPE_CHECKING: + from polars.datatypes import DataType + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_any_all() -> None: + df = pl.DataFrame( # noqa: F841 + { + "x": [-1, 0, 1, 2, 3, 4], + "y": [1, 0, 0, 1, 2, 3], + } + ) + res = pl.sql( + """ + SELECT + x >= ALL(df.y) AS "All Geq", + x > ALL(df.y) AS "All G", + x < ALL(df.y) AS "All L", + x <= ALL(df.y) AS "All Leq", + x >= ANY(df.y) AS "Any Geq", + x > ANY(df.y) AS "Any G", + x < ANY(df.y) AS "Any L", + x <= ANY(df.y) AS "Any Leq", + x == ANY(df.y) AS "Any eq", + x != ANY(df.y) AS "Any Neq", + FROM df + """, + ).collect() + + assert res.to_dict(as_series=False) == { + "All Geq": [0, 0, 0, 0, 1, 1], + "All G": [0, 0, 0, 0, 0, 1], + "All L": [1, 0, 0, 0, 0, 0], + "All Leq": [1, 1, 0, 0, 0, 0], + "Any Geq": [0, 1, 1, 1, 1, 1], + "Any G": [0, 0, 1, 1, 1, 1], + "Any L": [1, 1, 1, 1, 0, 0], + "Any Leq": [1, 1, 1, 1, 1, 0], + "Any eq": [0, 1, 1, 1, 1, 0], + "Any Neq": [1, 0, 0, 0, 0, 1], + } + + +@pytest.mark.parametrize( + ("data", "schema"), + [ + ({"x": [1, 2, 3, 4]}, None), + ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}), + ({"x": ["aa", "bb"]}, {"x": pl.Struct}), + ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}), + ], +) +def test_boolean_where_clauses( + data: dict[str, Any], schema: dict[str, DataType] | None +) -> None: + df = pl.DataFrame(data=data, schema=schema) + empty_df = df.clear() + + for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"): + assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}")) + + for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"): + assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}")) + + +def test_count() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [1, 1, 22, 22, 333], + "c": [1, 1, None, None, 2], + } + ) + res = df.sql( + """ + SELECT + -- count + COUNT(a) AS count_a, + COUNT(b) AS count_b, + COUNT(c) AS count_c, + COUNT(*) AS count_star, + COUNT(NULL) AS count_null, + -- count distinct + COUNT(DISTINCT a) AS count_unique_a, + COUNT(DISTINCT b) AS count_unique_b, + COUNT(DISTINCT c) AS count_unique_c, + COUNT(DISTINCT NULL) AS count_unique_null, + FROM self + """, + ) + assert res.to_dict(as_series=False) == { + "count_a": [5], + "count_b": [5], + "count_c": [3], + "count_star": [5], + "count_null": [0], + "count_unique_a": [5], + "count_unique_b": [3], + "count_unique_c": [2], + "count_unique_null": [0], + } + + df = pl.DataFrame({"x": [None, None, None]}) + res = df.sql( + """ + SELECT + COUNT(x) AS count_x, + COUNT(*) AS count_star, + COUNT(DISTINCT x) AS count_unique_x + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "count_x": [0], + "count_star": [3], + "count_unique_x": [0], + } + + +def test_distinct() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 1, 2, 2, 3], + "b": [1, 2, 3, 4, 5, 6], + } + ) + ctx = pl.SQLContext(register_globals=True, eager=True) + res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC") + assert_frame_equal( + left=df.select("a").unique().sort(by="a", descending=True), + right=res1, + ) + + res2 = ctx.execute( + """ + SELECT DISTINCT + a * 2 AS two_a, + b / 2 AS half_b + FROM df + ORDER BY two_a ASC, half_b DESC + """, + ) + assert res2.to_dict(as_series=False) == { + "two_a": [2, 2, 4, 6], + "half_b": [1, 0, 2, 3], + } + + # test unregistration + ctx.unregister("df") + with pytest.raises(SQLInterfaceError, match="relation 'df' was not found"): + ctx.execute("SELECT * FROM df") + + +def test_frame_sql_globals_error() -> None: + df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pl.DataFrame({"a": [2, 3, 4], "b": [7, 6, 5]}) # noqa: F841 + + query = """ + SELECT df1.a, df2.b + FROM df2 JOIN df1 ON df1.a = df2.a + ORDER BY b DESC + """ + with pytest.raises(SQLInterfaceError, match="relation.*not found.*"): + df1.sql(query=query) + + res = pl.sql(query=query, eager=True) + assert res.to_dict(as_series=False) == {"a": [2, 3], "b": [7, 6]} + + +def test_in_no_ops_11946() -> None: + lf = pl.LazyFrame( + [ + {"i1": 1}, + {"i1": 2}, + {"i1": 3}, + ] + ) + out = lf.sql( + query="SELECT * FROM frame_data WHERE i1 in (1, 3)", + table_name="frame_data", + ).collect() + assert out.to_dict(as_series=False) == {"i1": [1, 3]} + + +def test_limit_offset() -> None: + n_values = 11 + lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))}) + ctx = pl.SQLContext(tbl=lf) + + assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [ + (4, 6), + (5, 5), + (6, 4), + ] + for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]: + out = ctx.execute( + f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True + ) + assert_frame_equal(out, lf.slice(offset, limit).collect()) + assert len(out) == min(limit, n_values - offset) + + +def test_register_context() -> None: + # use as context manager unregisters tables created within each scope + # on exit from that scope; arbitrary levels of nesting are supported. + with pl.SQLContext() as ctx: + _lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["m", "n", "o"]}) + _lf2 = pl.LazyFrame({"a": [2, 3, 4], "c": ["p", "q", "r"]}) + ctx.register_globals() + assert ctx.tables() == ["_lf1", "_lf2"] + + with ctx: + _lf3 = pl.LazyFrame({"a": [3, 4, 5], "b": ["s", "t", "u"]}) + _lf4 = pl.LazyFrame({"a": [4, 5, 6], "c": ["v", "w", "x"]}) + ctx.register_globals(n=2) + assert ctx.tables() == ["_lf1", "_lf2", "_lf3", "_lf4"] + + assert ctx.tables() == ["_lf1", "_lf2"] + + assert ctx.tables() == [] + + +def test_sql_on_compatible_frame_types() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + # create various different frame types + dfp = df.to_pandas() + dfa = df.to_arrow() + dfb = dfa.to_batches()[0] # noqa: F841 + dfo = PyCapsuleStreamHolder(df) # noqa: F841 + + # run polars sql query against all frame types + for dfs in ( # noqa: B007 + (df["a"] * 2).rename("c"), # polars series + (dfp["a"] * 2).rename("c"), # pandas series + ): + res = pl.sql( + """ + SELECT a, b, SUM(c) AS cc FROM ( + SELECT * FROM df -- polars frame + UNION ALL SELECT * FROM dfp -- pandas frame + UNION ALL SELECT * FROM dfa -- pyarrow table + UNION ALL SELECT * FROM dfb -- pyarrow record batch + UNION ALL SELECT * FROM dfo -- arbitrary pycapsule object + ) tbl + INNER JOIN dfs ON dfs.c == tbl.b -- join on pandas/polars series + GROUP BY "a", "b" + ORDER BY "a", "b" + """ + ).collect() + + expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [20, 30]}) + assert_frame_equal(left=expected, right=res) + + # register and operate on non-polars frames + for obj in (dfa, dfp): + with pl.SQLContext(obj=obj) as ctx: + res = ctx.execute("SELECT * FROM obj", eager=True) + assert_frame_equal(df, res) + + # don't register all compatible objects + with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"): + pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp") + + +def test_nested_cte_column_aliasing() -> None: + # trace through nested CTEs with multiple levels of column & table aliasing + df = pl.sql( + """ + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y (m, n) AS ( + WITH z(c, d) AS (SELECT a, b FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT n, m FROM y + """, + eager=True, + ) + assert df.to_dict(as_series=False) == { + "n": [3, 9], + "m": [4, 8], + } + + +def test_invalid_derived_table_column_aliases() -> None: + values_query = "SELECT * FROM (VALUES (1,2), (3,4))" + + with pytest.raises( + SQLSyntaxError, + match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)", + ): + pl.sql(f"{values_query} AS tbl(a, b, c, d, e)") + + assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)] + + +def test_values_clause_table_registration() -> None: + with pl.SQLContext(frames=None, eager=True) as ctx: + # initially no tables are registered + assert ctx.tables() == [] + + # confirm that VALUES clause derived table is registered, post-query + res1 = ctx.execute("SELECT * FROM (VALUES (-1,1)) AS tbl(x, y)") + assert ctx.tables() == ["tbl"] + + # and confirm that we can select from it by the registered name + res2 = ctx.execute("SELECT x, y FROM tbl") + for res in (res1, res2): + assert res.to_dict(as_series=False) == {"x": [-1], "y": [1]} + + +def test_read_csv(tmp_path: Path) -> None: + # check empty string vs null, parsing of dates, etc + df = pl.DataFrame( + { + "label": ["lorem", None, "", "ipsum"], + "num": [-1, None, 0, 1], + "dt": [ + date(1969, 7, 5), + date(1999, 12, 31), + date(2077, 10, 10), + None, + ], + } + ) + csv_target = tmp_path / "test_sql_read.csv" + df.write_csv(csv_target) + + res = pl.sql(f"SELECT * FROM read_csv('{csv_target}')").collect() + assert_frame_equal(df, res) + + with pytest.raises( + SQLSyntaxError, + match="`read_csv` expects a single file path; found 3 arguments", + ): + pl.sql("SELECT * FROM read_csv('a','b','c')") + + +def test_global_variable_inference_17398() -> None: + users = pl.DataFrame({"id": "1"}) + + res = pl.sql( + query=""" + WITH user_by_email AS (SELECT id FROM users) + SELECT * FROM user_by_email + """, + eager=True, + ) + assert_frame_equal(res, users) + + +@pytest.mark.parametrize( + "query", + [ + "SELECT invalid_column FROM self", + "SELECT key, invalid_column FROM self", + "SELECT invalid_column * 2 FROM self", + "SELECT * FROM self ORDER BY invalid_column", + "SELECT * FROM self WHERE invalid_column = 200", + "SELECT * FROM self WHERE invalid_column = '200'", + "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column", + ], +) +def test_invalid_cols(query: str) -> None: + df = pl.DataFrame( + { + "key": ["xx", "xx", "yy"], + "n": ["100", "200", "300"], + } + ) + with pytest.raises(ColumnNotFoundError, match="invalid_column"): + df.sql(query) + + +@pytest.mark.parametrize("filter_expr", ["", "WHERE 1 = 1", "WHERE a == 1 OR a != 1"]) +@pytest.mark.parametrize("order_expr", ["", "ORDER BY 1", "ORDER BY a"]) +def test_select_output_heights_20058_21084(filter_expr: str, order_expr: str) -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + # Queries that maintain original height + + assert_frame_equal( + df.sql(f"SELECT 1 as a FROM self {filter_expr} {order_expr}").cast(pl.Int64), + pl.select(a=pl.Series([1, 1, 1])), + ) + + assert_frame_equal( + df.sql(f"SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}").cast( + pl.Int64 + ), + pl.DataFrame({"a": [2, 2, 2], "b": [1, 1, 1]}), + ) + + # Queries that aggregate to unit height + + assert_frame_equal( + df.sql(f"SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}").cast( + pl.Int64 + ), + pl.DataFrame({"a": 3}), + ) + + assert_frame_equal( + df.sql( + f"SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}" + ).cast(pl.Int64), + pl.DataFrame({"a": 3, "b": 1}), + ) + + assert_frame_equal( + df.sql( + f"SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}" + ).cast(pl.Int64), + pl.DataFrame({"a": 1, "b": 1}), + ) + + assert_frame_equal( + df.sql(f"SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}").cast( + pl.Int64 + ), + pl.DataFrame({"a": 6, "b": 1}), + ) + + assert_frame_equal( + df.sql( + f"SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}" + ).cast(pl.Int64), + pl.DataFrame({"a": 1, "b": 1}), + ) + + assert_frame_equal( + df.sql( + f"SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}" + ).cast(pl.Int64), + pl.DataFrame({"a": 2, "b": 1}), + ) + + assert_frame_equal( + df.sql( + f"SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}" + ).cast(pl.Int64), + pl.DataFrame({"a": 2, "b": 1}), + ) + + +def test_select_explode_height_filter_order_by() -> None: + # Note: `unnest()` from SQL equates to `Expr.explode()` + df = pl.DataFrame( + { + "list_long": [[1, 2, 3], [4, 5, 6]], + "sort_key": [2, 1], + "filter_mask": [False, True], + "filter_mask_all_true": True, + } + ) + + # Height of unnest is larger than height of sort_key, the sort_key is + # extended with NULLs. + + assert_frame_equal( + df.sql("SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key"), + pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(), + ) + + assert_frame_equal( + df.sql( + "SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST" + ), + pl.Series("list", [3, 4, 5, 6, 2, 1]).to_frame(), + ) + + # Literals are broadcasted to output height of UNNEST: + assert_frame_equal( + df.sql("SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key"), + pl.select(pl.Series("list", [2, 1, 3, 4, 5, 6]), x=1), + ) + + # Note: Filter applies before projections in SQL + assert_frame_equal( + df.sql( + "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key" + ), + pl.Series("list", [4, 5, 6]).to_frame(), + ) + + assert_frame_equal( + df.sql( + "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key" + ), + pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(), + ) diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py new file mode 100644 index 000000000000..fed1d97a948e --- /dev/null +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from decimal import Decimal as D +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + + +def test_div() -> None: + res = pl.sql( + """ + SELECT label, DIV(a, b) AS a_div_b, DIV(tbl.b, tbl.a) AS b_div_a + FROM ( + VALUES + ('a', 20.5, 6), + ('b', NULL, 12), + ('c', 10.0, 24), + ('d', 5.0, NULL), + ('e', 2.5, 5) + ) AS tbl(label, a, b) + """ + ).collect() + + assert res.to_dict(as_series=False) == { + "label": ["a", "b", "c", "d", "e"], + "a_div_b": [3, None, 0, None, 0], + "b_div_a": [0, None, 2, None, 2], + } + + +def test_modulo() -> None: + df = pl.DataFrame( + { + "a": [1.5, None, 3.0, 13 / 3, 5.0], + "b": [6, 7, 8, 9, 10], + "c": [11, 12, 13, 14, 15], + "d": [16.5, 17.0, 18.5, None, 20.0], + } + ) + out = df.sql( + """ + SELECT + a % 2 AS a2, + b % 3 AS b3, + MOD(c, 4) AS c4, + MOD(d, 5.5) AS d55 + FROM self + """ + ) + assert_frame_equal( + out, + pl.DataFrame( + { + "a2": [1.5, None, 1.0, 1 / 3, 1.0], + "b3": [0, 1, 2, 0, 1], + "c4": [3, 0, 1, 2, 3], + "d55": [0.0, 0.5, 2.0, None, 3.5], + } + ), + ) + + +@pytest.mark.parametrize( + ("value", "sqltype", "prec_scale", "expected_value", "expected_dtype"), + [ + (64.5, "numeric", "(3,1)", D("64.5"), pl.Decimal(3, 1)), + (512.5, "decimal", "(4,1)", D("512.5"), pl.Decimal(4, 1)), + (512.5, "numeric", "(4,0)", D("512"), pl.Decimal(4, 0)), + (-1024.75, "decimal", "(10,0)", D("-1024"), pl.Decimal(10, 0)), + (-1024.75, "numeric", "(10)", D("-1024"), pl.Decimal(10, 0)), + (-1024.75, "dec", "", D("-1024.75"), pl.Decimal(38, 9)), + ], +) +def test_numeric_decimal_type( + value: float, + sqltype: str, + prec_scale: str, + expected_value: D, + expected_dtype: PolarsDataType, +) -> None: + df = pl.DataFrame({"n": [value]}) + with pl.SQLContext(df=df) as ctx: + result = ctx.execute( + f""" + SELECT n::{sqltype}{prec_scale} AS "dec" FROM df + """ + ) + expected = pl.LazyFrame( + data={"dec": [expected_value]}, + schema={"dec": expected_dtype}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("decimals", "expected"), + [ + (0, [-8192.0, -4.0, -2.0, 2.0, 4.0, 8193.0]), + (1, [-8192.5, -4.0, -1.5, 2.5, 3.6, 8192.5]), + (2, [-8192.5, -3.96, -1.54, 2.46, 3.6, 8192.5]), + (3, [-8192.499, -3.955, -1.543, 2.457, 3.599, 8192.5]), + (4, [-8192.499, -3.955, -1.5432, 2.4568, 3.599, 8192.5001]), + ], +) +def test_round_ndigits(decimals: int, expected: list[float]) -> None: + df = pl.DataFrame( + {"n": [-8192.499, -3.9550, -1.54321, 2.45678, 3.59901, 8192.5001]}, + ) + with pl.SQLContext(df=df, eager=True) as ctx: + if decimals == 0: + out = ctx.execute("SELECT ROUND(n) AS n FROM df") + assert_series_equal(out["n"], pl.Series("n", values=expected)) + + out = ctx.execute(f'SELECT ROUND("n",{decimals}) AS n FROM df') + assert_series_equal(out["n"], pl.Series("n", values=expected)) + + +def test_round_ndigits_errors() -> None: + df = pl.DataFrame({"n": [99.999]}) + with pl.SQLContext(df=df, eager=True) as ctx: + with pytest.raises( + SQLSyntaxError, match=r"invalid value for ROUND decimals \('!!'\)" + ): + ctx.execute("SELECT ROUND(n,'!!') AS n FROM df") + + with pytest.raises( + SQLInterfaceError, match=r"ROUND .* negative decimals value \(-1\)" + ): + ctx.execute("SELECT ROUND(n,-1) AS n FROM df") + + with pytest.raises( + SQLSyntaxError, match=r"ROUND expects 1-2 arguments \(found 4\)" + ): + ctx.execute("SELECT ROUND(1.2345,6,7,8) AS n FROM df") + + +def test_stddev_variance() -> None: + df = pl.DataFrame( + { + "v1": [-1.0, 0.0, 1.0], + "v2": [5.5, 0.0, 3.0], + "v3": [-10, None, 10], + "v4": [-100.0, 0.0, -50.0], + } + ) + with pl.SQLContext(df=df) as ctx: + # note: we support all common aliases for std/var + out = ctx.execute( + """ + SELECT + STDEV(v1) AS "v1_std", + STDDEV(v2) AS "v2_std", + STDEV_SAMP(v3) AS "v3_std", + STDDEV_SAMP(v4) AS "v4_std", + VAR(v1) AS "v1_var", + VARIANCE(v2) AS "v2_var", + VARIANCE(v3) AS "v3_var", + VAR_SAMP(v4) AS "v4_var" + FROM df + """ + ).collect() + + assert_frame_equal( + out, + pl.DataFrame( + { + "v1_std": [1.0], + "v2_std": [2.7537852736431], + "v3_std": [14.142135623731], + "v4_std": [50.0], + "v1_var": [1.0], + "v2_var": [7.5833333333333], + "v3_var": [200.0], + "v4_var": [2500.0], + } + ), + ) diff --git a/py-polars/tests/unit/sql/test_operators.py b/py-polars/tests/unit/sql/test_operators.py new file mode 100644 index 000000000000..e42825259b34 --- /dev/null +++ b/py-polars/tests/unit/sql/test_operators.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.testing import assert_frame_equal + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_div() -> None: + df = pl.LazyFrame( + { + "a": [10.0, 20.0, 30.0, 40.0, 50.0], + "b": [-100.5, 7.0, 2.5, None, -3.14], + } + ) + with pl.SQLContext(df=df, eager=True) as ctx: + res = ctx.execute( + """ + SELECT + a / b AS a_div_b, + a // b AS a_floordiv_b, + SIGN(b) AS b_sign, + FROM df + """ + ) + + assert_frame_equal( + pl.DataFrame( + [ + [-0.0995024875621891, 2.85714285714286, 12.0, None, -15.92356687898089], + [-1, 2, 12, None, -16], + [-1.0, 1.0, 1.0, None, -1.0], + ], + schema=["a_div_b", "a_floordiv_b", "b_sign"], + ), + res, + ) + + +def test_equal_not_equal() -> None: + # validate null-aware/unaware equality operators + df = pl.DataFrame({"a": [1, None, 3, 6, 5], "b": [1, None, 3, 4, None]}) + + with pl.SQLContext(frame_data=df) as ctx: + out = ctx.execute( + """ + SELECT + -- not null-aware + (a = b) as "1_eq_unaware", + (a <> b) as "2_neq_unaware", + (a != b) as "3_neq_unaware", + -- null-aware + (a <=> b) as "4_eq_aware", + (a IS NOT DISTINCT FROM b) as "5_eq_aware", + (a IS DISTINCT FROM b) as "6_neq_aware", + FROM frame_data + """ + ).collect() + + assert out.select(cs.contains("_aware").null_count().sum()).row(0) == (0, 0, 0) + assert out.select(cs.contains("_unaware").null_count().sum()).row(0) == (2, 2, 2) + + assert out.to_dict(as_series=False) == { + "1_eq_unaware": [True, None, True, False, None], + "2_neq_unaware": [False, None, False, True, None], + "3_neq_unaware": [False, None, False, True, None], + "4_eq_aware": [True, True, True, False, False], + "5_eq_aware": [True, True, True, False, False], + "6_neq_aware": [False, False, False, True, True], + } + + +def test_is_between(foods_ipc_path: Path) -> None: + lf = pl.scan_ipc(foods_ipc_path) + + ctx = pl.SQLContext(foods1=lf, eager=True) + out = ctx.execute( + """ + SELECT * + FROM foods1 + WHERE foods1.calories BETWEEN 22 AND 30 + ORDER BY "calories" DESC, "sugars_g" DESC + """ + ) + assert out.rows() == [ + ("fruit", 30, 0.0, 5), + ("vegetables", 30, 0.0, 5), + ("fruit", 30, 0.0, 3), + ("vegetables", 25, 0.0, 4), + ("vegetables", 25, 0.0, 3), + ("vegetables", 25, 0.0, 2), + ("vegetables", 22, 0.0, 3), + ] + out = ctx.execute( + """ + SELECT * + FROM foods1 + WHERE calories NOT BETWEEN 22 AND 30 + ORDER BY "calories" ASC + """ + ) + assert not any((22 <= cal <= 30) for cal in out["calories"]) + + +def test_starts_with() -> None: + lf = pl.LazyFrame( + { + "x": ["aaa", "bbb", "a"], + "y": ["abc", "b", "aa"], + }, + ) + assert lf.sql("SELECT x ^@ 'a' AS x_starts_with_a FROM self").collect().rows() == [ + (True,), + (False,), + (True,), + ] + assert lf.sql("SELECT x ^@ y AS x_starts_with_y FROM self").collect().rows() == [ + (False,), + (True,), + (False,), + ] + + +@pytest.mark.parametrize("match_float", [False, True]) +def test_unary_ops_8890(match_float: bool) -> None: + with pl.SQLContext( + df=pl.DataFrame({"a": [-2, -1, 1, 2], "b": ["w", "x", "y", "z"]}), + ) as ctx: + in_values = "(-3.0, -1.0, +2.0, +4.0)" if match_float else "(-3, -1, +2, +4)" + res = ctx.execute( + f""" + SELECT *, -(3) as c, (+4) as d + FROM df WHERE a IN {in_values} + """ + ) + assert res.collect().to_dict(as_series=False) == { + "a": [-1, 2], + "b": ["x", "z"], + "c": [-3, -3], + "d": [4, 4], + } diff --git a/py-polars/tests/unit/sql/test_order_by.py b/py-polars/tests/unit/sql/test_order_by.py new file mode 100644 index 000000000000..704ebd3e773e --- /dev/null +++ b/py-polars/tests/unit/sql/test_order_by.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError, SQLSyntaxError + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_order_by_basic(foods_ipc_path: Path) -> None: + foods = pl.scan_ipc(foods_ipc_path) + + order_by_distinct_res = foods.sql( + """ + SELECT DISTINCT category + FROM self + ORDER BY category DESC + """ + ).collect() + assert order_by_distinct_res.to_dict(as_series=False) == { + "category": ["vegetables", "seafood", "meat", "fruit"] + } + + for category in ("category", "category AS cat"): + category_col = category.split(" ")[-1] + order_by_group_by_res = foods.sql( + f""" + SELECT {category} + FROM self + GROUP BY category + ORDER BY {category_col} DESC + """ + ).collect() + assert order_by_group_by_res.to_dict(as_series=False) == { + category_col: ["vegetables", "seafood", "meat", "fruit"] + } + + order_by_constructed_group_by_res = foods.sql( + """ + SELECT category, SUM(calories) as summed_calories + FROM self + GROUP BY category + ORDER BY summed_calories DESC + """ + ).collect() + assert order_by_constructed_group_by_res.to_dict(as_series=False) == { + "category": ["seafood", "meat", "fruit", "vegetables"], + "summed_calories": [1250, 540, 410, 192], + } + + order_by_unselected_res = foods.sql( + """ + SELECT SUM(calories) as summed_calories + FROM self + GROUP BY category + ORDER BY summed_calories DESC + """ + ).collect() + assert order_by_unselected_res.to_dict(as_series=False) == { + "summed_calories": [1250, 540, 410, 192], + } + + +def test_order_by_misc_selection() -> None: + df = pl.DataFrame({"x": [None, 1, 2, 3], "y": [4, 2, None, 8]}) + + # order by aliased col + res = df.sql("SELECT x, y AS y2 FROM self ORDER BY y2") + assert res.to_dict(as_series=False) == {"x": [1, None, 3, 2], "y2": [2, 4, 8, None]} + + res = df.sql("SELECT x, y AS y2 FROM self ORDER BY y2 DESC") + assert res.to_dict(as_series=False) == {"x": [2, 3, None, 1], "y2": [None, 8, 4, 2]} + + # order by col found in wildcard + res = df.sql("SELECT *, y AS y2 FROM self ORDER BY y") + assert res.to_dict(as_series=False) == { + "x": [1, None, 3, 2], + "y": [2, 4, 8, None], + "y2": [2, 4, 8, None], + } + res = df.sql("SELECT *, y AS y2 FROM self ORDER BY y NULLS FIRST") + assert res.to_dict(as_series=False) == { + "x": [2, 1, None, 3], + "y": [None, 2, 4, 8], + "y2": [None, 2, 4, 8], + } + + # order by col found in qualified wildcard + res = df.sql("SELECT self.* FROM self ORDER BY x NULLS FIRST") + assert res.to_dict(as_series=False) == {"x": [None, 1, 2, 3], "y": [4, 2, None, 8]} + + res = df.sql("SELECT self.* FROM self ORDER BY y NULLS FIRST") + assert res.to_dict(as_series=False) == {"x": [2, 1, None, 3], "y": [None, 2, 4, 8]} + + # order by col excluded from wildcard + res = df.sql("SELECT * EXCLUDE y FROM self ORDER BY y") + assert res.to_dict(as_series=False) == {"x": [1, None, 3, 2]} + + res = df.sql("SELECT * EXCLUDE y FROM self ORDER BY y NULLS FIRST") + assert res.to_dict(as_series=False) == {"x": [2, 1, None, 3]} + + # order by col excluded from qualified wildcard + res = df.sql("SELECT self.* EXCLUDE y FROM self ORDER BY y") + assert res.to_dict(as_series=False) == {"x": [1, None, 3, 2]} + + # order by expression + res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY -(x % y)") + assert res.to_dict(as_series=False) == {"xmy": [3, 1, None, None]} + + res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY x % y NULLS FIRST") + assert res.to_dict(as_series=False) == {"xmy": [None, None, 1, 3]} + + # confirm that 'order by all' syntax prioritises cols + df = pl.DataFrame({"SOME": [0, 1], "ALL": [1, 0]}) + res = df.sql("SELECT * FROM self ORDER BY ALL") + assert res.to_dict(as_series=False) == {"SOME": [1, 0], "ALL": [0, 1]} + + res = df.sql("SELECT * FROM self ORDER BY ALL DESC") + assert res.to_dict(as_series=False) == {"SOME": [0, 1], "ALL": [1, 0]} + + +def test_order_by_misc_16579() -> None: + res = pl.DataFrame( + { + "x": ["apple", "orange"], + "y": ["sheep", "alligator"], + "z": ["hello", "world"], + } + ).sql( + """ + SELECT z, y, x + FROM self ORDER BY y DESC + """ + ) + assert res.columns == ["z", "y", "x"] + assert res.to_dict(as_series=False) == { + "z": ["hello", "world"], + "y": ["sheep", "alligator"], + "x": ["apple", "orange"], + } + + +def test_order_by_multi_nulls_first_last() -> None: + df = pl.DataFrame({"x": [None, 1, None, 3], "y": [3, 2, None, 1]}) + # ┌──────┬──────┐ + # │ x ┆ y │ + # │ --- ┆ --- │ + # │ i64 ┆ i64 │ + # ╞══════╪══════╡ + # │ null ┆ 3 │ + # │ 1 ┆ 2 │ + # │ null ┆ null │ + # │ 3 ┆ 1 │ + # └──────┴──────┘ + + res1 = df.sql("SELECT * FROM self ORDER BY x, y") + res2 = df.sql("SELECT * FROM self ORDER BY ALL") + for res in (res1, res2): + assert res.to_dict(as_series=False) == { + "x": [1, 3, None, None], + "y": [2, 1, 3, None], + } + + res = df.sql("SELECT * FROM self ORDER BY x NULLS FIRST, y") + assert res.to_dict(as_series=False) == { + "x": [None, None, 1, 3], + "y": [3, None, 2, 1], + } + + res = df.sql("SELECT * FROM self ORDER BY x, y NULLS FIRST") + assert res.to_dict(as_series=False) == { + "x": [1, 3, None, None], + "y": [2, 1, None, 3], + } + + res1 = df.sql("SELECT * FROM self ORDER BY x NULLS FIRST, y NULLS FIRST") + res2 = df.sql("SELECT * FROM self ORDER BY All NULLS FIRST") + for res in (res1, res2): + assert res.to_dict(as_series=False) == { + "x": [None, None, 1, 3], + "y": [None, 3, 2, 1], + } + + res1 = df.sql("SELECT * FROM self ORDER BY x DESC NULLS FIRST, y DESC NULLS FIRST") + res2 = df.sql("SELECT * FROM self ORDER BY all DESC NULLS FIRST") + for res in (res1, res2): + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [None, 3, 1, 2], + } + + res = df.sql("SELECT * FROM self ORDER BY x DESC NULLS FIRST, y DESC NULLS LAST") + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } + + res = df.sql("SELECT * FROM self ORDER BY y DESC NULLS FIRST, x NULLS LAST") + assert res.to_dict(as_series=False) == { + "x": [None, None, 1, 3], + "y": [None, 3, 2, 1], + } + + +def test_order_by_ordinal() -> None: + df = pl.DataFrame({"x": [None, 1, None, 3], "y": [3, 2, None, 1]}) + + res = df.sql("SELECT * FROM self ORDER BY 1, 2") + assert res.to_dict(as_series=False) == { + "x": [1, 3, None, None], + "y": [2, 1, 3, None], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC, 2") + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC NULLS LAST, 2 ASC") + assert res.to_dict(as_series=False) == { + "x": [3, 1, None, None], + "y": [1, 2, 3, None], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC NULLS LAST, 2 ASC NULLS FIRST") + assert res.to_dict(as_series=False) == { + "x": [3, 1, None, None], + "y": [1, 2, None, 3], + } + + res = df.sql("SELECT * FROM self ORDER BY 1 DESC, 2 DESC NULLS FIRST") + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [None, 3, 1, 2], + } + + +def test_order_by_errors() -> None: + df = pl.DataFrame({"a": ["w", "x", "y", "z"], "b": [1, 2, 3, 4]}) + + with pytest.raises( + SQLInterfaceError, + match="ORDER BY ordinal value must refer to a valid column; found 99", + ): + df.sql("SELECT * FROM self ORDER BY 99") + + with pytest.raises( + SQLSyntaxError, + match="negative ordinal values are invalid for ORDER BY; found -1", + ): + df.sql("SELECT * FROM self ORDER BY -1") diff --git a/py-polars/tests/unit/sql/test_regex.py b/py-polars/tests/unit/sql/test_regex.py new file mode 100644 index 000000000000..af0b9a31fa67 --- /dev/null +++ b/py-polars/tests/unit/sql/test_regex.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +import polars as pl +from polars.exceptions import SQLSyntaxError + + +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +@pytest.mark.parametrize( + ("regex_op", "expected"), + [ + ("RLIKE", [0, 3]), + ("REGEXP", [0, 3]), + ("NOT RLIKE", [1, 2, 4]), + ("NOT REGEXP", [1, 2, 4]), + ], +) +def test_regex_expr_match(regex_op: str, expected: list[int]) -> None: + # note: the REGEXP and RLIKE operators can also use another + # column/expression as the source of the match pattern + df = pl.DataFrame( + { + "idx": [0, 1, 2, 3, 4], + "str": ["ABC", "abc", "000", "A0C", "a0c"], + "pat": ["^A", "^A", "^A", r"[AB]\d.*$", ".*xxx$"], + } + ) + with pl.SQLContext(df=df, eager=True) as ctx: + out = ctx.execute(f"SELECT idx, str FROM df WHERE str {regex_op} pat") + assert out.to_series().to_list() == expected + + +@pytest.mark.parametrize( + ("op", "pattern", "expected"), + [ + ("~", "^veg", "vegetables"), + ("~", "^VEG", None), + ("~*", "^VEG", "vegetables"), + ("!~", "(t|s)$", "seafood"), + ("!~*", "(T|S)$", "seafood"), + ("!~*", "^.E", "fruit"), + ("!~*", "[aeiOU]", None), + ("RLIKE", "^veg", "vegetables"), + ("RLIKE", "^VEG", None), + ("RLIKE", "(?i)^VEG", "vegetables"), + ("NOT RLIKE", "(t|s)$", "seafood"), + ("NOT RLIKE", "(?i)(T|S)$", "seafood"), + ("NOT RLIKE", "(?i)^.E", "fruit"), + ("NOT RLIKE", "(?i)[aeiOU]", None), + ("REGEXP", "^veg", "vegetables"), + ("REGEXP", "^VEG", None), + ("REGEXP", "(?i)^VEG", "vegetables"), + ("NOT REGEXP", "(t|s)$", "seafood"), + ("NOT REGEXP", "(?i)(T|S)$", "seafood"), + ("NOT REGEXP", "(?i)^.E", "fruit"), + ("NOT REGEXP", "(?i)[aeiOU]", None), + ], +) +def test_regex_operators( + foods_ipc_path: Path, op: str, pattern: str, expected: str | None +) -> None: + lf = pl.scan_ipc(foods_ipc_path) + + with pl.SQLContext(foods=lf, eager=True) as ctx: + out = ctx.execute( + f""" + SELECT DISTINCT category FROM foods + WHERE category {op} '{pattern}' + """ + ) + assert out.rows() == ([(expected,)] if expected else []) + + +def test_regex_operators_error() -> None: + df = pl.LazyFrame({"sval": ["ABC", "abc", "000", "A0C", "a0c"]}) + with pl.SQLContext(df=df, eager=True) as ctx: + with pytest.raises( + SQLSyntaxError, match="invalid pattern for '~' operator: dyn .*12345" + ): + ctx.execute("SELECT * FROM df WHERE sval ~ 12345") + with pytest.raises( + SQLSyntaxError, + match=r"""invalid pattern for '!~\*' operator: col\("abcde"\)""", + ): + ctx.execute("SELECT * FROM df WHERE sval !~* abcde") + + +@pytest.mark.parametrize( + ("not_", "pattern", "flags", "expected"), + [ + ("", "^veg", None, "vegetables"), + ("", "^VEG", None, None), + ("", "(?i)^VEG", None, "vegetables"), + ("NOT", "(t|s)$", None, "seafood"), + ("NOT", "T|S$", "i", "seafood"), + ("NOT", "^.E", "i", "fruit"), + ("NOT", "[aeiOU]", "i", None), + ], +) +def test_regexp_like( + foods_ipc_path: Path, + not_: str, + pattern: str, + flags: str | None, + expected: str | None, +) -> None: + lf = pl.scan_ipc(foods_ipc_path) + flags = "" if flags is None else f",'{flags}'" + with pl.SQLContext(foods=lf, eager=True) as ctx: + out = ctx.execute( + f""" + SELECT DISTINCT category FROM foods + WHERE {not_} REGEXP_LIKE(category,'{pattern}'{flags}) + """ + ) + assert out.rows() == ([(expected,)] if expected else []) + + +def test_regexp_like_errors() -> None: + with pl.SQLContext(df=pl.DataFrame({"scol": ["xyz"]})) as ctx: + with pytest.raises( + SQLSyntaxError, + match="invalid/empty 'flags' for REGEXP_LIKE", + ): + ctx.execute("SELECT * FROM df WHERE REGEXP_LIKE(scol,'[x-z]+','')") + + with pytest.raises( + SQLSyntaxError, + match="invalid arguments for REGEXP_LIKE", + ): + ctx.execute("SELECT * FROM df WHERE REGEXP_LIKE(scol,999,999)") + + with pytest.raises( + SQLSyntaxError, + match=r"REGEXP_LIKE expects 2-3 arguments \(found 1\)", + ): + ctx.execute("SELECT * FROM df WHERE REGEXP_LIKE(scol)") diff --git a/py-polars/tests/unit/sql/test_set_ops.py b/py-polars/tests/unit/sql/test_set_ops.py new file mode 100644 index 000000000000..2472e39474d0 --- /dev/null +++ b/py-polars/tests/unit/sql/test_set_ops.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError +from polars.testing import assert_frame_equal + + +def test_except_intersect() -> None: + df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]}) # noqa: F841 + df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]}) # noqa: F841 + + res_e = pl.sql("SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2", eager=True) + res_i = pl.sql("SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2", eager=True) + + assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)] + assert sorted(res_i.rows()) == [(1, 4, 5)] + + res_e = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True) + res_i = pl.sql( + """ + SELECT * FROM df2 + INTERSECT + SELECT x::int8, y::int8, z::int8 + FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z) + """, + eager=True, + ) + assert sorted(res_e.rows()) == [(1, 2, 7), (9, None, 6)] + assert sorted(res_i.rows()) == [(1, 4, 5)] + + # check null behaviour of nulls + with pl.SQLContext( + tbl1=pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]}), + tbl2=pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]}), + ) as ctx: + res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True) + assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res) + + +def test_except_intersect_by_name() -> None: + df1 = pl.DataFrame( # noqa: F841 + { + "x": [1, 9, 1, 1], + "y": [2, 3, 4, 4], + "z": [5, 5, 5, 5], + } + ) + df2 = pl.DataFrame( # noqa: F841 + { + "y": [2, None, 4], + "w": ["?", "!", "%"], + "z": [7, 6, 5], + "x": [1, 9, 1], + } + ) + res_e = pl.sql( + "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2", + eager=True, + ) + res_i = pl.sql( + "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2", + eager=True, + ) + assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)] + assert sorted(res_i.rows()) == [(1, 4, 5)] + assert res_e.columns == ["x", "y", "z"] + assert res_i.columns == ["x", "y", "z"] + + +@pytest.mark.parametrize( + ("op", "op_subtype"), + [ + ("EXCEPT", "ALL"), + ("EXCEPT", "ALL BY NAME"), + ("INTERSECT", "ALL"), + ("INTERSECT", "ALL BY NAME"), + ], +) +def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None: + df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]}) # noqa: F841 + df2 = pl.DataFrame({"n": [1, 1, 2, 2]}) # noqa: F841 + + with pytest.raises( + SQLInterfaceError, + match=f"'{op} {op_subtype}' is not supported", + ): + pl.sql(f"SELECT * FROM df1 {op} {op_subtype} SELECT * FROM df2") + + +def test_update_statement_error() -> None: + df_large = pl.DataFrame( + { + "FQDN": ["c.ORG.na", "a.COM.na"], + "NS1": ["ns1.c.org.na", "ns1.d.net.na"], + "NS2": ["ns2.c.org.na", "ns2.d.net.na"], + "NS3": ["ns3.c.org.na", "ns3.d.net.na"], + } + ) + df_small = pl.DataFrame( + { + "FQDN": ["c.org.na"], + "NS1": ["ns1.c.org.na|127.0.0.1"], + "NS2": ["ns2.c.org.na|127.0.0.1"], + "NS3": ["ns3.c.org.na|127.0.0.1"], + } + ) + + # Create a context and register the tables + ctx = pl.SQLContext() + ctx.register("large", df_large) + ctx.register("small", df_small) + + with pytest.raises( + SQLInterfaceError, + match="'UPDATE large SET FQDN = u.FQDN, NS1 = u.NS1, NS2 = u.NS2, NS3 = u.NS3 FROM u WHERE large.FQDN = u.FQDN' operation is currently unsupported", + ): + ctx.execute(""" + WITH u AS ( + SELECT + small.FQDN, + small.NS1, + small.NS2, + small.NS3 + FROM small + INNER JOIN large ON small.FQDN = large.FQDN + ) + UPDATE large + SET + FQDN = u.FQDN, + NS1 = u.NS1, + NS2 = u.NS2, + NS3 = u.NS3 + FROM u + WHERE large.FQDN = u.FQDN + """) + + +@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"]) +def test_except_intersect_errors(op: str) -> None: + df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]}) # noqa: F841 + df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]}) # noqa: F841 + + if op != "UNION": + with pytest.raises( + SQLInterfaceError, + match=f"'{op} ALL' is not supported", + ): + pl.sql(f"SELECT * FROM df1 {op} ALL SELECT * FROM df2", eager=False) + + with pytest.raises( + SQLInterfaceError, + match=f"{op} requires equal number of columns in each table", + ): + pl.sql(f"SELECT x FROM df1 {op} SELECT x, y FROM df2", eager=False) + + +@pytest.mark.parametrize( + ("cols1", "cols2", "union_subtype", "expected"), + [ + ( + ["*"], + ["*"], + "", + [(1, "zz"), (2, "yy"), (3, "xx")], + ), + ( + ["*"], + ["frame2.*"], + "ALL", + [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")], + ), + ( + ["frame1.*"], + ["c1", "c2"], + "DISTINCT", + [(1, "zz"), (2, "yy"), (3, "xx")], + ), + ( + ["*"], + ["c2", "c1"], + "ALL BY NAME", + [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")], + ), + ( + ["c1", "c2"], + ["c2", "c1"], + "BY NAME", + [(1, "zz"), (2, "yy"), (3, "xx")], + ), + pytest.param( + ["c1", "c2"], + ["c2", "c1"], + "DISTINCT BY NAME", + [(1, "zz"), (2, "yy"), (3, "xx")], + ), + ], +) +def test_union( + cols1: list[str], + cols2: list[str], + union_subtype: str, + expected: list[tuple[int, str]], +) -> None: + with pl.SQLContext( + frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}), + frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}), + eager=True, + ) as ctx: + query = f""" + SELECT {", ".join(cols1)} FROM frame1 + UNION {union_subtype} + SELECT {", ".join(cols2)} FROM frame2 + """ + assert sorted(ctx.execute(query).rows()) == expected diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py new file mode 100644 index 000000000000..dcb9b983bc0b --- /dev/null +++ b/py-polars/tests/unit/sql/test_strings.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +import polars as pl +from polars.exceptions import SQLSyntaxError +from polars.testing import assert_frame_equal + + +# TODO: Do not rely on I/O for these tests +@pytest.fixture +def foods_ipc_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" + + +def test_string_case() -> None: + df = pl.DataFrame({"words": ["Test SOME words"]}) + + with pl.SQLContext(frame=df) as ctx: + res = ctx.execute( + """ + SELECT + words, + INITCAP(words) as cap, + UPPER(words) as upper, + LOWER(words) as lower, + FROM frame + """ + ).collect() + + assert res.to_dict(as_series=False) == { + "words": ["Test SOME words"], + "cap": ["Test Some Words"], + "upper": ["TEST SOME WORDS"], + "lower": ["test some words"], + } + + +def test_string_concat() -> None: + lf = pl.LazyFrame( + { + "x": ["a", None, "c"], + "y": ["d", "e", "f"], + "z": [1, 2, 3], + } + ) + res = lf.sql( + """ + SELECT + ("x" || "x" || "y") AS c0, + ("x" || "y" || "z") AS c1, + CONCAT(("x" || '-'), "y") AS c2, + CONCAT("x", "x", "y") AS c3, + CONCAT("x", "y", ("z" * 2)) AS c4, + CONCAT_WS(':', "x", "y", "z") AS c5, + CONCAT_WS('', "y", "z", '!') AS c6 + FROM self + """, + ).collect() + + assert res.to_dict(as_series=False) == { + "c0": ["aad", None, "ccf"], + "c1": ["ad1", None, "cf3"], + "c2": ["a-d", "e", "c-f"], + "c3": ["aad", "e", "ccf"], + "c4": ["ad2", "e4", "cf6"], + "c5": ["a:d:1", "e:2", "c:f:3"], + "c6": ["d1!", "e2!", "f3!"], + } + + +@pytest.mark.parametrize( + "invalid_concat", ["CONCAT()", "CONCAT_WS()", "CONCAT_WS(':')"] +) +def test_string_concat_errors(invalid_concat: str) -> None: + lf = pl.LazyFrame({"x": ["a", "b", "c"]}) + with pytest.raises( + SQLSyntaxError, + match=r"CONCAT.*expects at least \d argument[s]? \(found \d\)", + ): + pl.SQLContext(data=lf).execute(f"SELECT {invalid_concat} FROM data") + + +def test_string_left_right_reverse() -> None: + df = pl.DataFrame({"txt": ["abcde", "abc", "a", None]}) + ctx = pl.SQLContext(df=df) + res = ctx.execute( + """ + SELECT + LEFT(txt,2) AS "l", + RIGHT(txt,2) AS "r", + REVERSE(txt) AS "rev" + FROM df + """, + ).collect() + + assert res.to_dict(as_series=False) == { + "l": ["ab", "ab", "a", None], + "r": ["de", "bc", "a", None], + "rev": ["edcba", "cba", "a", None], + } + for func, invalid_arg, invalid_err in ( + ("LEFT", "'xyz'", '"xyz"'), + ("RIGHT", "6.66", "(dyn float: 6.66)"), + ): + with pytest.raises( + SQLSyntaxError, + match=rf"""invalid 'n_chars' for {func} \({invalid_err}\)""", + ): + ctx.execute(f"""SELECT {func}(txt,{invalid_arg}) FROM df""").collect() + + +def test_string_left_negative_expr() -> None: + # negative values and expressions + df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]}) + with pl.SQLContext(df=df, eager=True) as sql: + res = sql.execute( + """ + SELECT + LEFT("s",-50) AS l0, -- empty string + LEFT("s",-3) AS l1, -- all but last three chars + LEFT("s",SIGN(-1)) AS l2, -- all but last char (expr => -1) + LEFT("s",0) AS l3, -- empty string + LEFT("s",NULL) AS l4, -- null + LEFT("s",1) AS l5, -- first char + LEFT("s",SIGN(1)) AS l6, -- first char (expr => 1) + LEFT("s",3) AS l7, -- first three chars + LEFT("s",50) AS l8, -- entire string + LEFT("s","n") AS l9, -- from other col + FROM df + """ + ) + assert res.to_dict(as_series=False) == { + "l0": ["", ""], + "l1": ["alpha", "alpha"], + "l2": ["alphabe", "alphabe"], + "l3": ["", ""], + "l4": [None, None], + "l5": ["a", "a"], + "l6": ["a", "a"], + "l7": ["alp", "alp"], + "l8": ["alphabet", "alphabet"], + "l9": ["al", "alphab"], + } + + +def test_string_right_negative_expr() -> None: + # negative values and expressions + df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]}) + with pl.SQLContext(df=df, eager=True) as sql: + res = sql.execute( + """ + SELECT + RIGHT("s",-50) AS l0, -- empty string + RIGHT("s",-3) AS l1, -- all but first three chars + RIGHT("s",SIGN(-1)) AS l2, -- all but first char (expr => -1) + RIGHT("s",0) AS l3, -- empty string + RIGHT("s",NULL) AS l4, -- null + RIGHT("s",1) AS l5, -- last char + RIGHT("s",SIGN(1)) AS l6, -- last char (expr => 1) + RIGHT("s",3) AS l7, -- last three chars + RIGHT("s",50) AS l8, -- entire string + RIGHT("s","n") AS l9, -- from other col + FROM df + """ + ) + assert res.to_dict(as_series=False) == { + "l0": ["", ""], + "l1": ["habet", "habet"], + "l2": ["lphabet", "lphabet"], + "l3": ["", ""], + "l4": [None, None], + "l5": ["t", "t"], + "l6": ["t", "t"], + "l7": ["bet", "bet"], + "l8": ["alphabet", "alphabet"], + "l9": ["et", "phabet"], + } + + +def test_string_lengths() -> None: + df = pl.DataFrame({"words": ["Café", None, "東京", ""]}) + + with pl.SQLContext(frame=df) as ctx: + res = ctx.execute( + """ + SELECT + words, + LENGTH(words) AS n_chrs1, + CHAR_LENGTH(words) AS n_chrs2, + CHARACTER_LENGTH(words) AS n_chrs3, + OCTET_LENGTH(words) AS n_bytes, + BIT_LENGTH(words) AS n_bits + FROM frame + """ + ).collect() + + assert res.to_dict(as_series=False) == { + "words": ["Café", None, "東京", ""], + "n_chrs1": [4, None, 2, 0], + "n_chrs2": [4, None, 2, 0], + "n_chrs3": [4, None, 2, 0], + "n_bytes": [5, None, 6, 0], + "n_bits": [40, None, 48, 0], + } + + +@pytest.mark.parametrize( + ("pattern", "like", "expected"), + [ + ("a%", "LIKE", [1, 4]), + ("a%", "ILIKE", [0, 1, 3, 4]), + ("ab%", "LIKE", [1]), + ("AB%", "ILIKE", [0, 1]), + ("ab_", "LIKE", [1]), + ("A__", "ILIKE", [0, 1]), + ("_0%_", "LIKE", [2, 4]), + ("%0", "LIKE", [2]), + ("0%", "LIKE", [2]), + ("__0%", "~~", [2, 3]), + ("%*%", "~~*", [3]), + ("____", "~~", [4]), + ("a%C", "~~", []), + ("a%C", "~~*", [0, 1, 3]), + ("%C?", "~~*", [4]), + ("a0c?", "~~", [4]), + ("000", "~~", [2]), + ("00", "~~", []), + ], +) +def test_string_like(pattern: str, like: str, expected: list[int]) -> None: + df = pl.DataFrame( + { + "idx": [0, 1, 2, 3, 4], + "txt": ["ABC", "abc", "000", "A[0]*C", "a0c?"], + } + ) + with pl.SQLContext(df=df) as ctx: + for not_ in ("", ("NOT " if like.endswith("LIKE") else "!")): + out = ctx.execute( + f"SELECT idx FROM df WHERE txt {not_}{like} '{pattern}'" + ).collect() + + res = out["idx"].to_list() + if not_: + expected = [i for i in df["idx"] if i not in expected] + assert res == expected + + +def test_string_like_multiline() -> None: + s1 = "Hello World" + s2 = "Hello\nWorld" + s3 = "hello\nWORLD" + + df = pl.DataFrame({"idx": [0, 1, 2], "txt": [s1, s2, s3]}) + + # starts with... + res1 = df.sql("SELECT * FROM self WHERE txt LIKE 'Hello%' ORDER BY idx") + res2 = df.sql("SELECT * FROM self WHERE txt ILIKE 'HELLO%' ORDER BY idx") + + assert res1["txt"].to_list() == [s1, s2] + assert res2["txt"].to_list() == [s1, s2, s3] + + # ends with... + res3 = df.sql("SELECT * FROM self WHERE txt LIKE '%WORLD' ORDER BY idx") + res4 = df.sql("SELECT * FROM self WHERE txt ILIKE '%\nWORLD' ORDER BY idx") + + assert res3["txt"].to_list() == [s3] + assert res4["txt"].to_list() == [s2, s3] + + # exact match + for s in (s1, s2, s3): + assert df.sql(f"SELECT txt FROM self WHERE txt LIKE '{s}'").item() == s + + +@pytest.mark.parametrize("form", ["NFKC", "NFKD"]) +def test_string_normalize(form: str) -> None: + df = pl.DataFrame({"txt": ["Test", "𝕋𝕖𝕤𝕥", "𝕿𝖊𝖘𝖙", "𝗧𝗲𝘀𝘁", "Ⓣⓔⓢⓣ"]}) # noqa: RUF001 + res = df.sql( + f""" + SELECT txt, NORMALIZE(txt,{form}) AS norm_txt + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "txt": ["Test", "𝕋𝕖𝕤𝕥", "𝕿𝖊𝖘𝖙", "𝗧𝗲𝘀𝘁", "Ⓣⓔⓢⓣ"], # noqa: RUF001 + "norm_txt": ["Test", "Test", "Test", "Test", "Test"], + } + + +def test_string_position() -> None: + df = pl.Series( + name="city", + values=["Dubai", "Abu Dhabi", "Sharjah", "Al Ain", "Ajman", "Ras Al Khaimah"], + ).to_frame() + + with pl.SQLContext(cities=df, eager=True) as ctx: + res = ctx.execute( + """ + SELECT + POSITION('a' IN city) AS a_lc1, + POSITION('A' IN city) AS a_uc1, + STRPOS(city,'a') AS a_lc2, + STRPOS(city,'A') AS a_uc2, + FROM cities + """ + ) + expected_lc = [4, 7, 3, 0, 4, 2] + expected_uc = [0, 1, 0, 1, 1, 5] + + assert res.to_dict(as_series=False) == { + "a_lc1": expected_lc, + "a_uc1": expected_uc, + "a_lc2": expected_lc, + "a_uc2": expected_uc, + } + + df = pl.DataFrame({"txt": ["AbCdEXz", "XyzFDkE"]}) + with pl.SQLContext(txt=df) as ctx: + res = ctx.execute( + """ + SELECT + txt, + POSITION('E' IN txt) AS match_E, + STRPOS(txt,'X') AS match_X + FROM txt + """, + eager=True, + ) + assert_frame_equal( + res, + pl.DataFrame( + data={ + "txt": ["AbCdEXz", "XyzFDkE"], + "match_E": [5, 7], + "match_X": [6, 1], + }, + schema={ + "txt": pl.String, + "match_E": pl.UInt32, + "match_X": pl.UInt32, + }, + ), + ) + + +def test_string_replace() -> None: + df = pl.DataFrame({"words": ["Yemeni coffee is the best coffee", "", None]}) + with pl.SQLContext(df=df) as ctx: + out = ctx.execute( + """ + SELECT + REPLACE( + REPLACE(words, 'coffee', 'tea'), + 'Yemeni', + 'English breakfast' + ) + FROM df + """ + ).collect() + + res = out["words"].to_list() + assert res == ["English breakfast tea is the best tea", "", None] + + with pytest.raises( + SQLSyntaxError, match=r"REPLACE expects 3 arguments \(found 2\)" + ): + ctx.execute("SELECT REPLACE(words,'coffee') FROM df") + + +def test_string_split() -> None: + df = pl.DataFrame({"s": ["xx,yy,zz", "abc,,xyz", "", None]}) + res = df.sql("SELECT *, STRING_TO_ARRAY(s,',') AS s_array FROM self") + + assert res.schema == {"s": pl.String, "s_array": pl.List(pl.String)} + assert res.to_dict(as_series=False) == { + "s": ["xx,yy,zz", "abc,,xyz", "", None], + "s_array": [["xx", "yy", "zz"], ["abc", "", "xyz"], [""], None], + } + + +def test_string_split_part() -> None: + df = pl.DataFrame({"s": ["xx,yy,zz", "abc,,xyz,???,hmm", "", None]}) + res = df.sql( + """ + SELECT + SPLIT_PART(s,',',1) AS "s+1", + SPLIT_PART(s,',',3) AS "s+3", + SPLIT_PART(s,',',-2) AS "s-2", + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "s+1": ["xx", "abc", "", None], + "s+3": ["zz", "xyz", "", None], + "s-2": ["yy", "???", "", None], + } + + +def test_string_substr() -> None: + df = pl.DataFrame( + {"scol": ["abcdefg", "abcde", "abc", None], "n": [-2, 3, 2, None]} + ) + with pl.SQLContext(df=df) as ctx: + res = ctx.execute( + """ + SELECT + -- note: sql is 1-indexed + SUBSTR(scol,1) AS s1, + SUBSTR(scol,2) AS s2, + SUBSTR(scol,3) AS s3, + SUBSTR(scol,1,5) AS s1_5, + SUBSTR(scol,2,2) AS s2_2, + SUBSTR(scol,3,1) AS s3_1, + SUBSTR(scol,-3) AS "s-3", + SUBSTR(scol,-3,3) AS "s-3_3", + SUBSTR(scol,-3,4) AS "s-3_4", + SUBSTR(scol,-3,5) AS "s-3_5", + SUBSTR(scol,-10,13) AS "s-10_13", + SUBSTR(scol,"n",2) AS "s-n2", + SUBSTR(scol,2,"n"+3) AS "s-2n3" + FROM df + """ + ).collect() + + with pytest.raises( + SQLSyntaxError, + match=r"SUBSTR does not support negative length \(-99\)", + ): + ctx.execute("SELECT SUBSTR(scol,2,-99) FROM df") + + with pytest.raises( + SQLSyntaxError, + match=r"SUBSTR expects 2-3 arguments \(found 1\)", + ): + pl.sql_expr("SUBSTR(s)") + + assert res.to_dict(as_series=False) == { + "s1": ["abcdefg", "abcde", "abc", None], + "s2": ["bcdefg", "bcde", "bc", None], + "s3": ["cdefg", "cde", "c", None], + "s1_5": ["abcde", "abcde", "abc", None], + "s2_2": ["bc", "bc", "bc", None], + "s3_1": ["c", "c", "c", None], + "s-3": ["abcdefg", "abcde", "abc", None], + "s-3_3": ["", "", "", None], + "s-3_4": ["", "", "", None], + "s-3_5": ["a", "a", "a", None], + "s-10_13": ["ab", "ab", "ab", None], + "s-n2": ["", "cd", "bc", None], + "s-2n3": ["b", "bcde", "bc", None], + } + + +def test_string_trim(foods_ipc_path: Path) -> None: + lf = pl.scan_ipc(foods_ipc_path) + out = lf.sql( + """ + SELECT DISTINCT TRIM(LEADING 'vmf' FROM category) as new_category + FROM self ORDER BY new_category DESC + """ + ).collect() + assert out.to_dict(as_series=False) == { + "new_category": ["seafood", "ruit", "egetables", "eat"] + } + with pytest.raises( + SQLSyntaxError, + match="unsupported TRIM syntax", + ): + # currently unsupported (snowflake-style) trim syntax + lf.sql("SELECT DISTINCT TRIM('*^xxxx^*', '^*') as new_category FROM self") diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py new file mode 100644 index 000000000000..4f7a7b16c151 --- /dev/null +++ b/py-polars/tests/unit/sql/test_structs.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import ( + SQLInterfaceError, + SQLSyntaxError, + StructFieldNotFoundError, +) +from polars.testing import assert_frame_equal + + +@pytest.fixture +def df_struct() -> pl.DataFrame: + return pl.DataFrame( + { + "id": [200, 300, 400], + "name": ["Bob", "David", "Zoe"], + "age": [45, 19, 45], + "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}], + } + ).select(pl.struct(pl.all()).alias("json_msg")) + + +def test_struct_field_nested_dot_notation_22107() -> None: + # ensure dot-notation references the given name at the right level of nesting + df = pl.DataFrame( + { + "id": ["012345", "987654"], + "name": ["A Book", "Another Book"], + "author": [ + {"id": "888888", "name": "Iain M. Banks"}, + {"id": "444444", "name": "Dan Abnett"}, + ], + } + ) + + res = df.sql("SELECT id, author.id AS author_id FROM self ORDER BY id") + assert res.to_dict(as_series=False) == { + "id": ["012345", "987654"], + "author_id": ["888888", "444444"], + } + + for name in ("author.name", "self.author.name"): + res = df.sql(f"SELECT {name} FROM self ORDER BY id") + assert res.to_dict(as_series=False) == {"name": ["Iain M. Banks", "Dan Abnett"]} + + for name in ("name", "self.name"): + res = df.sql(f"SELECT {name} FROM self ORDER BY self.id DESC") + assert res.to_dict(as_series=False) == {"name": ["Another Book", "A Book"]} + + # expected errors + with pytest.raises( + SQLInterfaceError, + match="no table or struct column named 'foo' found", + ): + df.sql("SELECT foo.id FROM self ORDER BY id") + + with pytest.raises( + SQLInterfaceError, + match="no column named 'foo' found", + ): + df.sql("SELECT self.foo FROM self ORDER BY id") + + +@pytest.mark.parametrize( + "order_by", + [ + "ORDER BY json_msg.id DESC", + "ORDER BY 2 DESC", + "", + ], +) +def test_struct_field_selection(order_by: str, df_struct: pl.DataFrame) -> None: + res = df_struct.sql( + f""" + SELECT + -- validate table alias resolution + frame.json_msg.id AS ID, + self.json_msg.name AS NAME, + json_msg.age AS AGE + FROM + self AS frame + WHERE + json_msg.age > 20 AND + json_msg.other.n IS NOT NULL -- note: nested struct field + {order_by} + """ + ) + if not order_by: + res = res.sort(by="ID", descending=True) + + expected = pl.DataFrame({"ID": [400, 200], "NAME": ["Zoe", "Bob"], "AGE": [45, 45]}) + assert_frame_equal(expected, res) + + +def test_struct_field_group_by(df_struct: pl.DataFrame) -> None: + res = pl.sql( + """ + SELECT + COUNT(json_msg.age) AS n, + ARRAY_AGG(json_msg.name) AS names + FROM df_struct + GROUP BY json_msg.age + ORDER BY 1 DESC + """ + ).collect() + + expected = pl.DataFrame( + data={"n": [2, 1], "names": [["Bob", "Zoe"], ["David"]]}, + schema_overrides={"n": pl.UInt32}, + ) + assert_frame_equal(expected, res) + + +def test_struct_field_group_by_errors(df_struct: pl.DataFrame) -> None: + with pytest.raises( + SQLSyntaxError, + match="'name' should participate in the GROUP BY clause or an aggregate function", + ): + pl.sql( + """ + SELECT + json_msg.name, + SUM(json_msg.age) AS sum_age + FROM df_struct + GROUP BY json_msg.age + """ + ).collect() + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ("nested #> '{c,1}'", 2), + ("nested #> '{c,-1}'", 1), + ("nested #>> '{c,0}'", "3"), + ("nested -> '0' -> 0", "baz"), + ("nested -> 'c' -> -1", 1), + ("nested -> 'c' ->> 2", "1"), + ], +) +def test_struct_field_operator_access(expr: str, expected: int | str) -> None: + df = pl.DataFrame( + { + "nested": { + "0": ["baz"], + "b": ["foo", "bar"], + "c": [3, 2, 1], + }, + }, + ) + assert df.sql(f"SELECT {expr} FROM self").item() == expected + + +@pytest.mark.parametrize( + ("fields", "excluding", "rename"), + [ + ("json_msg.*", "age", {}), + ("json_msg.*", "name", {"other": "misc"}), + ("self.json_msg.*", "(age,other)", {"name": "ident"}), + ("json_msg.other.*", "", {"n": "num"}), + ("self.json_msg.other.*", "", {}), + ("self.json_msg.other.*", "n", {}), + ], +) +def test_struct_field_selection_wildcards( + fields: str, + excluding: str, + rename: dict[str, str], + df_struct: pl.DataFrame, +) -> None: + exclude_cols = f"EXCLUDE {excluding}" if excluding else "" + rename_cols = ( + f"RENAME ({','.join(f'{k} AS {v}' for k, v in rename.items())})" + if rename + else "" + ) + res = df_struct.sql( + f""" + SELECT {fields} {exclude_cols} {rename_cols} + FROM self ORDER BY json_msg.id + """ + ) + + expected = df_struct.unnest("json_msg") + if fields.endswith(".other.*"): + expected = expected["other"].struct.unnest() + if excluding: + expected = expected.drop(excluding.strip(")(").split(",")) + if rename: + expected = expected.rename(rename) + + assert_frame_equal(expected, res) + + +@pytest.mark.parametrize( + ("invalid_column", "error_type"), + [ + ("json_msg.invalid_column", StructFieldNotFoundError), + ("json_msg.other.invalid_column", StructFieldNotFoundError), + ("self.json_msg.other.invalid_column", StructFieldNotFoundError), + ("json_msg.other -> invalid_column", SQLSyntaxError), + ("json_msg -> DATE '2020-09-11'", SQLSyntaxError), + ], +) +def test_struct_field_selection_errors( + invalid_column: str, + error_type: type[Exception], + df_struct: pl.DataFrame, +) -> None: + error_msg = ( + "invalid json/struct path-extract" + if ("->" in invalid_column) + else "invalid_column" + ) + with pytest.raises(error_type, match=error_msg): + df_struct.sql(f"SELECT {invalid_column} FROM self") diff --git a/py-polars/tests/unit/sql/test_subqueries.py b/py-polars/tests/unit/sql/test_subqueries.py new file mode 100644 index 000000000000..35655dc6fc54 --- /dev/null +++ b/py-polars/tests/unit/sql/test_subqueries.py @@ -0,0 +1,148 @@ +import pytest + +import polars as pl +from polars.exceptions import SQLSyntaxError +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize( + ("cols", "join_type", "constraint"), + [ + ("x", "INNER", ""), + ("y", "INNER", ""), + ("x", "LEFT", "WHERE y IN (0,1,2,3,4,5)"), + ("y", "LEFT", "WHERE y >= 0"), + ("df1.*", "FULL", "WHERE y >= 0"), + ("df2.*", "FULL", "WHERE x >= 0"), + ("* EXCLUDE y", "LEFT", "WHERE y >= 0"), + ("* EXCLUDE x", "LEFT", "WHERE x >= 0"), + ], +) +def test_from_subquery(cols: str, join_type: str, constraint: str) -> None: + df1 = pl.DataFrame({"x": [-1, 0, 3, 1, 2, -1]}) + df2 = pl.DataFrame({"y": [0, 1, 2, 3]}) + + sql = pl.SQLContext(df1=df1, df2=df2) + res = sql.execute( + f""" + SELECT {cols} FROM (SELECT * FROM df1) AS df1 + {join_type} JOIN (SELECT * FROM df2) AS df2 + ON df1.x = df2.y {constraint} + """, + eager=True, + ) + assert sorted(res.to_series()) == [0, 1, 2, 3] + + +def test_in_subquery() -> None: + df = pl.DataFrame( + { + "x": [1, 2, 3, 4, 5, 6], + "y": [2, 3, 4, 5, 6, 7], + } + ) + df_other = pl.DataFrame( + { + "w": [1, 2, 3, 4, 5, 6], + "z": [2, 3, 4, 5, 6, 7], + } + ) + df_chars = pl.DataFrame( + { + "one": ["a", "b", "c", "d", "e", "f"], + "two": ["b", "c", "d", "e", "f", "g"], + } + ) + + sql = pl.SQLContext(df=df, df_other=df_other, df_chars=df_chars) + res_same = sql.execute( + """ + SELECT + df.x as x + FROM df + WHERE x IN (SELECT y FROM df) + """, + eager=True, + ) + df_expected_same = pl.DataFrame({"x": [2, 3, 4, 5, 6]}) + assert_frame_equal( + left=df_expected_same, + right=res_same, + ) + + res_double = sql.execute( + """ + SELECT + df.x as x + FROM df + WHERE x IN (SELECT y FROM df) + AND y IN(SELECT w FROM df_other) + """, + eager=True, + ) + df_expected_double = pl.DataFrame({"x": [2, 3, 4, 5]}) + assert_frame_equal( + left=df_expected_double, + right=res_double, + ) + + res_expressions = sql.execute( + """ + SELECT + df.x as x + FROM df + WHERE x+1 IN (SELECT y FROM df) + AND y IN(SELECT w-1 FROM df_other) + """, + eager=True, + ) + df_expected_expressions = pl.DataFrame({"x": [1, 2, 3, 4]}) + assert_frame_equal( + left=df_expected_expressions, + right=res_expressions, + ) + + res_not_in = sql.execute( + """ + SELECT + df.x as x + FROM df + WHERE x NOT IN (SELECT y-5 FROM df) + AND y NOT IN(SELECT w+5 FROM df_other) + """, + eager=True, + ) + df_not_in = pl.DataFrame({"x": [3, 4]}) + assert_frame_equal( + left=df_not_in, + right=res_not_in, + ) + + res_chars = sql.execute( + """ + SELECT + df_chars.one + FROM df_chars + WHERE one IN (SELECT two FROM df_chars) + """, + eager=True, + ) + df_expected_chars = pl.DataFrame({"one": ["b", "c", "d", "e", "f"]}) + assert_frame_equal( + left=res_chars, + right=df_expected_chars, + ) + + with pytest.raises( + SQLSyntaxError, + match="SQL subquery returns more than one column", + ): + sql.execute( + """ + SELECT + df_chars.one + FROM df_chars + WHERE one IN (SELECT one, two FROM df_chars) + """, + eager=True, + ) diff --git a/py-polars/tests/unit/sql/test_table_operations.py b/py-polars/tests/unit/sql/test_table_operations.py new file mode 100644 index 000000000000..7220a5809ea4 --- /dev/null +++ b/py-polars/tests/unit/sql/test_table_operations.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import re +from datetime import date + +import pytest + +import polars as pl +from polars.exceptions import SQLInterfaceError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def test_frame() -> pl.LazyFrame: + return pl.LazyFrame( + { + "x": [1, 2, 3], + "y": ["aaa", "bbb", "ccc"], + "z": [date(2000, 12, 31), date(1978, 11, 15), date(2077, 10, 20)], + }, + schema_overrides={"x": pl.UInt8}, + ) + + +@pytest.mark.parametrize( + ("delete_constraint", "expected_ids"), + [ + # basic constraints + ("WHERE id = 200", {100, 300}), + ("WHERE id = 200 OR id = 300", {100}), + ("WHERE id IN (200, 300, 400)", {100}), + ("WHERE id NOT IN (200, 300, 400)", {200, 300}), + # more involved constraints + ("WHERE EXTRACT(year FROM dt) >= 2000", {200}), + # null-handling (in the data) + ("WHERE v1 < 0", {100, 300}), + ("WHERE v1 > 0", {200, 300}), + # null handling (in the constraint) + ("WHERE v1 IS NULL", {100, 200}), + ("WHERE v1 IS NOT NULL", {300}), + # boolean handling (delete all/none) + ("WHERE FALSE", {100, 200, 300}), + ("WHERE TRUE", set()), + # no constraint; equivalent to TRUNCATE (drop all rows) + ("", set()), + ], +) +def test_delete_clause(delete_constraint: str, expected_ids: set[int]) -> None: + df = pl.DataFrame( + { + "id": [100, 200, 300], + "dt": [date(2020, 10, 10), date(1999, 1, 2), date(2001, 7, 5)], + "v1": [3.5, -4.0, None], + "v2": [10.0, 2.5, -1.5], + } + ) + res = df.sql(f"DELETE FROM self {delete_constraint}") + assert set(res["id"]) == expected_ids + + +def test_drop_table(test_frame: pl.LazyFrame) -> None: + # 'drop' completely removes the table from sql context + expected = pl.DataFrame() + + with pl.SQLContext(frame=test_frame, eager=True) as ctx: + res = ctx.execute("DROP TABLE frame") + assert_frame_equal(res, expected) + + with pytest.raises(SQLInterfaceError, match="'frame' was not found"): + ctx.execute("SELECT * FROM frame") + + +def test_explain_query(test_frame: pl.LazyFrame) -> None: + # 'explain' returns the query plan for the given sql + with pl.SQLContext(frame=test_frame) as ctx: + plan = ( + ctx.execute("EXPLAIN SELECT * FROM frame") + .select(pl.col("Logical Plan").str.join()) + .collect() + .item() + ) + assert ( + re.search( + pattern=r"PROJECT.+?COLUMNS", + string=plan, + flags=re.IGNORECASE, + ) + is not None + ) + + +def test_show_tables(test_frame: pl.LazyFrame) -> None: + # 'show tables' lists all tables registered with the sql context in sorted order + with pl.SQLContext( + tbl3=test_frame, + tbl2=test_frame, + tbl1=test_frame, + ) as ctx: + res = ctx.execute("SHOW TABLES").collect() + assert_frame_equal(res, pl.DataFrame({"name": ["tbl1", "tbl2", "tbl3"]})) + + +@pytest.mark.parametrize( + "truncate_sql", + [ + "TRUNCATE TABLE frame", + "TRUNCATE frame", + ], +) +def test_truncate_table(truncate_sql: str, test_frame: pl.LazyFrame) -> None: + # 'truncate' preserves the table, but optimally drops all rows within it + expected = pl.DataFrame(schema=test_frame.collect_schema()) + + with pl.SQLContext(frame=test_frame, eager=True) as ctx: + res = ctx.execute(truncate_sql) + assert_frame_equal(res, expected) + + res = ctx.execute("SELECT * FROM frame") + assert_frame_equal(res, expected) diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py new file mode 100644 index 000000000000..7cae56b24948 --- /dev/null +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -0,0 +1,460 @@ +from __future__ import annotations + +from datetime import date, datetime, time +from typing import Any, Literal + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError, SQLInterfaceError, SQLSyntaxError +from polars.testing import assert_frame_equal + + +def test_date_func() -> None: + df = pl.DataFrame( + { + "date": [ + date(2021, 3, 15), + date(2021, 3, 28), + date(2021, 4, 4), + ], + "version": ["0.0.1", "0.7.3", "0.7.4"], + } + ) + with pl.SQLContext(df=df, eager=True) as ctx: + result = ctx.execute("SELECT date < DATE('2021-03-20') from df") + + expected = pl.DataFrame({"date": [True, False, False]}) + assert_frame_equal(result, expected) + + result = pl.select(pl.sql_expr("CAST(DATE('2023-03', '%Y-%m') as STRING)")) + expected = pl.DataFrame({"literal": ["2023-03-01"]}) + assert_frame_equal(result, expected) + + with pytest.raises( + SQLSyntaxError, + match=r"DATE expects 1-2 arguments \(found 0\)", + ): + df.sql("SELECT DATE() FROM self") + + with pytest.raises(InvalidOperationError): + df.sql("SELECT DATE('2077-07-07','not_a_valid_strftime_format') FROM self") + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_datetime_to_time(time_unit: Literal["ns", "us", "ms"]) -> None: + df = pl.DataFrame( # noqa: F841 + { + "dtm": [ + datetime(2099, 12, 31, 23, 59, 59), + datetime(1999, 12, 31, 12, 30, 30), + datetime(1969, 12, 31, 1, 1, 1), + datetime(1899, 12, 31, 0, 0, 0), + ], + }, + schema={"dtm": pl.Datetime(time_unit)}, + ) + + res = pl.sql("SELECT dtm::time AS tm from df").collect() + assert res["tm"].to_list() == [ + time(23, 59, 59), + time(12, 30, 30), + time(1, 1, 1), + time(0, 0, 0), + ] + + +@pytest.mark.parametrize( + ("parts", "dtype", "expected"), + [ + (["decade", "decades"], pl.Int32, [202, 202, 200]), + (["isoyear"], pl.Int32, [2024, 2020, 2005]), + (["year", "y"], pl.Int32, [2024, 2020, 2006]), + (["quarter"], pl.Int8, [1, 4, 1]), + (["month", "months", "mon", "mons"], pl.Int8, [1, 12, 1]), + (["week", "weeks"], pl.Int8, [1, 53, 52]), + (["doy"], pl.Int16, [7, 365, 1]), + (["isodow"], pl.Int8, [7, 3, 7]), + (["dow"], pl.Int8, [0, 3, 0]), + (["day", "days", "d"], pl.Int8, [7, 30, 1]), + (["hour", "hours", "h"], pl.Int8, [1, 10, 23]), + (["minute", "min", "mins", "m"], pl.Int8, [2, 30, 59]), + (["second", "seconds", "secs", "sec"], pl.Int8, [3, 45, 59]), + ( + ["millisecond", "milliseconds", "ms"], + pl.Float64, + [3123.456, 45987.654, 59555.555], + ), + ( + ["microsecond", "microseconds", "us"], + pl.Float64, + [3123456.0, 45987654.0, 59555555.0], + ), + ( + ["nanosecond", "nanoseconds", "ns"], + pl.Float64, + [3123456000.0, 45987654000.0, 59555555000.0], + ), + ( + ["time"], + pl.Time, + [time(1, 2, 3, 123456), time(10, 30, 45, 987654), time(23, 59, 59, 555555)], + ), + ( + ["epoch"], + pl.Float64, + [1704589323.123456, 1609324245.987654, 1136159999.555555], + ), + ], +) +def test_extract(parts: list[str], dtype: pl.DataType, expected: list[Any]) -> None: + df = pl.DataFrame( + { + "dt": [ + # note: these values test several edge-cases, such as isoyear, + # the mon/sun wrapping of dow vs isodow, epoch rounding, etc, + # and the results have been validated against postgresql. + datetime(2024, 1, 7, 1, 2, 3, 123456), + datetime(2020, 12, 30, 10, 30, 45, 987654), + datetime(2006, 1, 1, 23, 59, 59, 555555), + ], + } + ) + with pl.SQLContext(frame_data=df, eager=True) as ctx: + for part in parts: + for fn in ( + f"EXTRACT({part} FROM dt)", + f"DATE_PART('{part}',dt)", + ): + res = ctx.execute(f"SELECT {fn} AS {part} FROM frame_data").to_series() + assert res.dtype == dtype + assert res.to_list() == expected + + +def test_extract_errors() -> None: + df = pl.DataFrame({"dt": [datetime(2024, 1, 7, 1, 2, 3, 123456)]}) + + with pl.SQLContext(frame_data=df, eager=True) as ctx: + for part in ("femtosecond", "stroopwafel"): + with pytest.raises( + SQLSyntaxError, + match=f"EXTRACT/DATE_PART does not support '{part}' part", + ): + ctx.execute(f"SELECT EXTRACT({part} FROM dt) FROM frame_data") + + with pytest.raises( + SQLSyntaxError, + match=r"EXTRACT/DATE_PART does not support 'week\(tuesday\)' part", + ): + ctx.execute("SELECT DATE_PART('week(tuesday)', dt) FROM frame_data") + + +@pytest.mark.parametrize( + ("dt", "expected"), + [ + (date(1, 1, 1), [1, 1]), + (date(100, 1, 1), [1, 1]), + (date(101, 1, 1), [1, 2]), + (date(1000, 1, 1), [1, 10]), + (date(1001, 1, 1), [2, 11]), + (date(1899, 12, 31), [2, 19]), + (date(1900, 12, 31), [2, 19]), + (date(1901, 1, 1), [2, 20]), + (date(2000, 12, 31), [2, 20]), + (date(2001, 1, 1), [3, 21]), + (date(5555, 5, 5), [6, 56]), + (date(9999, 12, 31), [10, 100]), + ], +) +def test_extract_century_millennium(dt: date, expected: list[int]) -> None: + with pl.SQLContext(frame_data=pl.DataFrame({"dt": [dt]}), eager=True) as ctx: + res = ctx.execute( + """ + SELECT + EXTRACT(MILLENNIUM FROM dt) AS c1, + DATE_PART('century',dt) AS c2, + EXTRACT(millennium FROM dt) AS c3, + DATE_PART('CENTURY',dt) AS c4, + FROM frame_data + """ + ) + assert_frame_equal( + left=res, + right=pl.DataFrame( + data=[expected + expected], + schema=["c1", "c2", "c3", "c4"], + orient="row", + ).cast(pl.Int32), + ) + + +@pytest.mark.parametrize( + ("constraint", "expected"), + [ + ("dtm >= '2020-12-30T10:30:45.987'", [0, 2]), + ("dtm::date > '2006-01-01'", [0, 2]), + ("dtm > '2006-01-01'", [0, 1, 2]), # << implies '2006-01-01 00:00:00' + ("dtm <= '2006-01-01'", []), # << implies '2006-01-01 00:00:00' + ("dt != '1960-01-07'", [0, 1]), + ("tm != '22:10:30'", [0, 2]), + ("tm >= '11:00:00' AND tm < '22:00:00'", [0]), + ("tm BETWEEN '12:00:00' AND '23:59:58'", [0, 1]), + ("dt BETWEEN '2050-01-01' AND '2100-12-31'", [1]), + ("dt::datetime = '1960-01-07'", [2]), + ("dt::datetime = '1960-01-07 00:00:00'", [2]), + ("dtm BETWEEN '2020-12-30 10:30:44' AND '2023-01-01 00:00:00'", [2]), + ("dt IN ('1960-01-07','2077-01-01','2222-02-22')", [1, 2]), + ( + "dtm = '2024-01-07 01:02:03.123456000' OR dtm = '2020-12-30 10:30:45.987654'", + [0, 2], + ), + ], +) +def test_implicit_temporal_strings(constraint: str, expected: list[int]) -> None: + df = pl.DataFrame( + { + "idx": [0, 1, 2], + "dtm": [ + datetime(2024, 1, 7, 1, 2, 3, 123456), + datetime(2006, 1, 1, 23, 59, 59, 555555), + datetime(2020, 12, 30, 10, 30, 45, 987654), + ], + "dt": [ + date(2020, 12, 30), + date(2077, 1, 1), + date(1960, 1, 7), + ], + "tm": [ + time(17, 30, 45), + time(22, 10, 30), + time(10, 25, 15), + ], + } + ) + res = df.sql(f"SELECT idx FROM self WHERE {constraint}") + actual = sorted(res["idx"]) + assert actual == expected + + +@pytest.mark.parametrize( + "dtval", + [ + "2020-12-30T10:30:45", + "yyyy-mm-dd", + "2222-22-22", + "10:30:45", + ], +) +def test_implicit_temporal_string_errors(dtval: str) -> None: + df = pl.DataFrame({"dt": [date(2020, 12, 30)]}) + + with pytest.raises( + InvalidOperationError, + match="(conversion.*failed)|(cannot compare.*string.*temporal)", + ): + df.sql(f"SELECT * FROM self WHERE dt = '{dtval}'") + + +def test_strftime() -> None: + df = pl.DataFrame( + { + "dtm": [ + None, + datetime(1980, 9, 30, 1, 25, 50), + datetime(2077, 7, 17, 11, 30, 55), + ], + "dt": [date(1978, 7, 5), date(1969, 12, 31), date(2020, 4, 10)], + "tm": [time(10, 10, 10), time(22, 33, 55), None], + } + ) + res = df.sql( + """ + SELECT + STRFTIME(dtm,'%m.%d.%Y/%T') AS s_dtm, + STRFTIME(dt ,'%B %d, %Y') AS s_dt, + STRFTIME(tm ,'%S.%M.%H') AS s_tm, + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "s_dtm": [None, "09.30.1980/01:25:50", "07.17.2077/11:30:55"], + "s_dt": ["July 05, 1978", "December 31, 1969", "April 10, 2020"], + "s_tm": ["10.10.10", "55.33.22", None], + } + + with pytest.raises( + SQLSyntaxError, + match=r"STRFTIME expects 2 arguments \(found 4\)", + ): + pl.sql_expr("STRFTIME(dtm,'%Y-%m-%d','[extra]','[param]')") + + +def test_strptime() -> None: + df = pl.DataFrame( + { + "s_dtm": [None, "09.30.1980/01:25:50", "07.17.2077/11:30:55"], + "s_dt": ["July 5, 1978", "December 31, 1969", "April 10, 2020"], + "s_tm": ["10.10.10", "55.33.22", None], + } + ) + res = df.sql( + """ + SELECT + STRPTIME(s_dtm,'%m.%d.%Y/%T') AS dtm, + STRPTIME(s_dt ,'%B %d, %Y')::date AS dt, + STRPTIME(s_tm ,'%S.%M.%H')::time AS tm + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "dtm": [ + None, + datetime(1980, 9, 30, 1, 25, 50), + datetime(2077, 7, 17, 11, 30, 55), + ], + "dt": [date(1978, 7, 5), date(1969, 12, 31), date(2020, 4, 10)], + "tm": [time(10, 10, 10), time(22, 33, 55), None], + } + with pytest.raises( + SQLSyntaxError, + match=r"STRPTIME expects 2 arguments \(found 3\)", + ): + pl.sql_expr("STRPTIME(s,'%Y.%m.%d',false) AS dt") + + +def test_temporal_stings_to_datetime() -> None: + df = pl.DataFrame( + { + "s_dt": ["2077-10-10", "1942-01-08", "2000-07-05"], + "s_dtm1": [ + "1999-12-31 10:30:45", + "2020-06-10", + "2022-08-07T00:01:02.654321", + ], + "s_dtm2": ["31-12-1999 10:30", "10-06-2020 00:00", "07-08-2022 00:01"], + "s_tm": ["02:04:06", "12:30:45.999", "23:59:59.123456"], + } + ) + res = df.sql( + """ + SELECT + DATE(s_dt) AS dt1, + DATETIME(s_dt) AS dt2, + DATETIME(s_dtm1) AS dtm1, + DATETIME(s_dtm2,'%d-%m-%Y %H:%M') AS dtm2, + TIME(s_tm) AS tm + FROM self + """ + ) + assert res.schema == { + "dt1": pl.Date, + "dt2": pl.Datetime("us"), + "dtm1": pl.Datetime("us"), + "dtm2": pl.Datetime("us"), + "tm": pl.Time, + } + assert res.rows() == [ + ( + date(2077, 10, 10), + datetime(2077, 10, 10, 0, 0), + datetime(1999, 12, 31, 10, 30, 45), + datetime(1999, 12, 31, 10, 30), + time(2, 4, 6), + ), + ( + date(1942, 1, 8), + datetime(1942, 1, 8, 0, 0), + datetime(2020, 6, 10, 0, 0), + datetime(2020, 6, 10, 0, 0), + time(12, 30, 45, 999000), + ), + ( + date(2000, 7, 5), + datetime(2000, 7, 5, 0, 0), + datetime(2022, 8, 7, 0, 1, 2, 654321), + datetime(2022, 8, 7, 0, 1), + time(23, 59, 59, 123456), + ), + ] + + for fn in ("DATE", "TIME", "DATETIME"): + with pytest.raises( + SQLSyntaxError, + match=rf"{fn} expects 1-2 arguments \(found 3\)", + ): + pl.sql_expr(rf"{fn}(s,fmt,misc) AS xyz") + + +def test_temporal_typed_literals() -> None: + res = pl.sql( + """ + SELECT + DATE '2020-12-30' AS dt, + TIME '00:01:02' AS tm1, + TIME '23:59:59.123456' AS tm2, + TIMESTAMP '1930-01-01 12:30:00' AS dtm1, + TIMESTAMP '2077-04-27T23:45:30.123456' AS dtm2 + FROM + (VALUES (0)) tbl (x) + """, + eager=True, + ) + assert res.to_dict(as_series=False) == { + "dt": [date(2020, 12, 30)], + "tm1": [time(0, 1, 2)], + "tm2": [time(23, 59, 59, 123456)], + "dtm1": [datetime(1930, 1, 1, 12, 30)], + "dtm2": [datetime(2077, 4, 27, 23, 45, 30, 123456)], + } + + +@pytest.mark.parametrize("fn", ["DATE", "TIME", "TIMESTAMP"]) +def test_typed_literals_errors(fn: str) -> None: + with pytest.raises(SQLSyntaxError, match=f"invalid {fn} literal '999'"): + pl.sql_expr(f"{fn} '999'") + + +@pytest.mark.parametrize( + ("unit", "expected"), + [ + ("ms", [1704589323123, 1609324245987, 1136159999555]), + ("us", [1704589323123456, 1609324245987654, 1136159999555555]), + ("ns", [1704589323123456000, 1609324245987654000, 1136159999555555000]), + ], +) +def test_timestamp_time_unit(unit: str | None, expected: list[int]) -> None: + df = pl.DataFrame( + { + "ts": [ + datetime(2024, 1, 7, 1, 2, 3, 123456), + datetime(2020, 12, 30, 10, 30, 45, 987654), + datetime(2006, 1, 1, 23, 59, 59, 555555), + ], + } + ) + precision = {"ms": 3, "us": 6, "ns": 9} + + with pl.SQLContext(frame_data=df, eager=True) as ctx: + prec = f"({precision[unit]})" if unit else "" + res = ctx.execute(f"SELECT ts::timestamp{prec} FROM frame_data").to_series() + + assert res.dtype == pl.Datetime(time_unit=unit) # type: ignore[arg-type] + assert res.to_physical().to_list() == expected + + +def test_timestamp_time_unit_errors() -> None: + df = pl.DataFrame({"ts": [datetime(2024, 1, 7, 1, 2, 3, 123456)]}) + + with pl.SQLContext(frame_data=df, eager=True) as ctx: + for prec in (0, 15): + with pytest.raises( + SQLSyntaxError, + match=rf"invalid temporal type precision \(expected 1-9, found {prec}\)", + ): + ctx.execute(f"SELECT ts::timestamp({prec}) FROM frame_data") + + with pytest.raises( + SQLInterfaceError, + match="sql parser error: Expected: literal int, found: - ", + ): + ctx.execute("SELECT ts::timestamp(-3) FROM frame_data") diff --git a/py-polars/tests/unit/sql/test_trigonometric.py b/py-polars/tests/unit/sql/test_trigonometric.py new file mode 100644 index 000000000000..7e52174d6805 --- /dev/null +++ b/py-polars/tests/unit/sql/test_trigonometric.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import math + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_arctan2() -> None: + twoRootTwo = math.sqrt(2) / 2.0 + df = pl.DataFrame( # noqa: F841 + { + "y": [twoRootTwo, -twoRootTwo, twoRootTwo, -twoRootTwo], + "x": [twoRootTwo, twoRootTwo, -twoRootTwo, -twoRootTwo], + } + ) + res = pl.sql( + """ + SELECT + ATAN2D(y,x) as "atan2d", + ATAN2(y,x) as "atan2" + FROM df + """, + eager=True, + ) + df_result = pl.DataFrame({"atan2d": [45.0, -45.0, 135.0, -135.0]}) + df_result = df_result.with_columns(pl.col("atan2d").cast(pl.Float64)) + df_result = df_result.with_columns(pl.col("atan2d").radians().alias("atan2")) + + assert_frame_equal(df_result, res) + + +def test_trig() -> None: + df = pl.DataFrame( + { + "a": [-4.0, -3.0, -2.0, -1.00001, 0.0, 1.00001, 2.0, 3.0, 4.0], + } + ) + + ctx = pl.SQLContext(df=df) + res = ctx.execute( + """ + SELECT + asin(1.0)/a as "pi values", + cos(asin(1.0)/a) AS "cos", + cot(asin(1.0)/a) AS "cot", + sin(asin(1.0)/a) AS "sin", + tan(asin(1.0)/a) AS "tan", + + cosd(asind(1.0)/a) AS "cosd", + cotd(asind(1.0)/a) AS "cotd", + sind(asind(1.0)/a) AS "sind", + tand(asind(1.0)/a) AS "tand", + + 1.0/a as "inverse pi values", + acos(1.0/a) AS "acos", + asin(1.0/a) AS "asin", + atan(1.0/a) AS "atan", + + acosd(1.0/a) AS "acosd", + asind(1.0/a) AS "asind", + atand(1.0/a) AS "atand" + FROM df + """, + eager=True, + ) + + df_result = pl.DataFrame( + { + "pi values": [ + -0.392699, + -0.523599, + -0.785398, + -1.570781, + float("inf"), + 1.570781, + 0.785398, + 0.523599, + 0.392699, + ], + "cos": [ + 0.92388, + 0.866025, + 0.707107, + 0.000016, + float("nan"), + 0.000016, + 0.707107, + 0.866025, + 0.92388, + ], + "cot": [ + -2.414214, + -1.732051, + -1.0, + -0.000016, + float("nan"), + 0.000016, + 1.0, + 1.732051, + 2.414214, + ], + "sin": [ + -0.382683, + -0.5, + -0.707107, + -1.0, + float("nan"), + 1, + 0.707107, + 0.5, + 0.382683, + ], + "tan": [ + -0.414214, + -0.57735, + -1, + -63662.613851, + float("nan"), + 63662.613851, + 1, + 0.57735, + 0.414214, + ], + "cosd": [ + 0.92388, + 0.866025, + 0.707107, + 0.000016, + float("nan"), + 0.000016, + 0.707107, + 0.866025, + 0.92388, + ], + "cotd": [ + -2.414214, + -1.732051, + -1.0, + -0.000016, + float("nan"), + 0.000016, + 1.0, + 1.732051, + 2.414214, + ], + "sind": [ + -0.382683, + -0.5, + -0.707107, + -1.0, + float("nan"), + 1, + 0.707107, + 0.5, + 0.382683, + ], + "tand": [ + -0.414214, + -0.57735, + -1, + -63662.613851, + float("nan"), + 63662.613851, + 1, + 0.57735, + 0.414214, + ], + "inverse pi values": [ + -0.25, + -0.333333, + -0.5, + -0.99999, + float("inf"), + 0.99999, + 0.5, + 0.333333, + 0.25, + ], + "acos": [ + 1.823477, + 1.910633, + 2.094395, + 3.137121, + float("nan"), + 0.004472, + 1.047198, + 1.230959, + 1.318116, + ], + "asin": [ + -0.25268, + -0.339837, + -0.523599, + -1.566324, + float("nan"), + 1.566324, + 0.523599, + 0.339837, + 0.25268, + ], + "atan": [ + -0.244979, + -0.321751, + -0.463648, + -0.785393, + 1.570796, + 0.785393, + 0.463648, + 0.321751, + 0.244979, + ], + "acosd": [ + 104.477512, + 109.471221, + 120.0, + 179.743767, + float("nan"), + 0.256233, + 60.0, + 70.528779, + 75.522488, + ], + "asind": [ + -14.477512, + -19.471221, + -30.0, + -89.743767, + float("nan"), + 89.743767, + 30.0, + 19.471221, + 14.477512, + ], + "atand": [ + -14.036243, + -18.434949, + -26.565051, + -44.999714, + 90.0, + 44.999714, + 26.565051, + 18.434949, + 14.036243, + ], + } + ) + + assert_frame_equal(left=df_result, right=res, atol=1e-5) diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py new file mode 100644 index 000000000000..c31a55d61829 --- /dev/null +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError, SQLInterfaceError +from polars.testing import assert_frame_equal + + +@pytest.fixture +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "ID": [333, 666, 999], + "FirstName": ["Bruce", "Diana", "Clark"], + "LastName": ["Wayne", "Prince", "Kent"], + "Address": ["Batcave", "Paradise Island", "Fortress of Solitude"], + "City": ["Gotham", "Themyscira", "Metropolis"], + } + ) + + +@pytest.mark.parametrize( + ("excluded", "order_by", "expected"), + [ + ("ID", "ORDER BY 2, 1", ["FirstName", "LastName", "Address", "City"]), + ("(ID)", "ORDER BY City", ["FirstName", "LastName", "Address", "City"]), + ("(Address, LastName, FirstName)", "", ["ID", "City"]), + ('("ID", "FirstName", "LastName", "Address", "City")', "", []), + ], +) +def test_select_exclude( + excluded: str, + order_by: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + for exclude_keyword in ("EXCLUDE", "EXCEPT"): + assert ( + df.sql(f"SELECT * {exclude_keyword} {excluded} FROM self").columns + == expected + ) + + +def test_select_exclude_order_by( + df: pl.DataFrame, +) -> None: + expected = pl.DataFrame( + { + "FirstName": ["Diana", "Clark", "Bruce"], + "Address": ["Paradise Island", "Fortress of Solitude", "Batcave"], + } + ) + for order_by in ("", "ORDER BY 1 DESC", "ORDER BY 2 DESC", "ORDER BY Address DESC"): + actual = df.sql(f"SELECT * EXCLUDE (ID,LastName,City) FROM self {order_by}") + if not order_by: + actual = actual.sort("FirstName", descending=True) + assert_frame_equal(actual, expected) + + +def test_ilike(df: pl.DataFrame) -> None: + assert df.sql("SELECT * ILIKE 'a%e' FROM self").columns == [] + assert df.sql("SELECT * ILIKE '%nam_' FROM self").columns == [ + "FirstName", + "LastName", + ] + assert df.sql("SELECT * ILIKE '%a%e%' FROM self").columns == [ + "FirstName", + "LastName", + "Address", + ] + assert df.sql( + """SELECT * ILIKE '%I%' RENAME (FirstName AS Name) FROM self""" + ).columns == [ + "ID", + "Name", + "City", + ] + + +@pytest.mark.parametrize( + ("renames", "expected"), + [ + ( + "Address AS Location", + ["ID", "FirstName", "LastName", "Location", "City"], + ), + ( + '(Address AS "Location")', + ["ID", "FirstName", "LastName", "Location", "City"], + ), + ( + '("Address" AS Location, "ID" AS PersonID)', + ["PersonID", "FirstName", "LastName", "Location", "City"], + ), + ], +) +def test_select_rename( + renames: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * RENAME {renames} FROM self").columns == expected + + +@pytest.mark.parametrize("order_by", ["1 DESC", "Name DESC", "FirstName DESC"]) +def test_select_rename_exclude_sort(order_by: str, df: pl.DataFrame) -> None: + actual = df.sql( + f""" + SELECT * EXCLUDE (ID, City, LastName) RENAME FirstName AS Name + FROM self + ORDER BY {order_by} + """ + ) + expected = pl.DataFrame( + { + "Name": ["Diana", "Clark", "Bruce"], + "Address": ["Paradise Island", "Fortress of Solitude", "Batcave"], + } + ) + assert_frame_equal(expected, actual) + + +@pytest.mark.parametrize( + ("replacements", "check_cols", "expected"), + [ + ( + "(ID // 3 AS ID)", + ["ID"], + [(333,), (222,), (111,)], + ), + ( + "(ID // 3 AS ID) RENAME (ID AS Identifier)", + ["Identifier"], + [(333,), (222,), (111,)], + ), + ( + "((City || ':' || City) AS City, ID // -3 AS ID)", + ["City", "ID"], + [ + ("Gotham:Gotham", -111), + ("Themyscira:Themyscira", -222), + ("Metropolis:Metropolis", -333), + ], + ), + ], +) +def test_select_replace( + replacements: str, + check_cols: list[str], + expected: list[tuple[Any]], + df: pl.DataFrame, +) -> None: + for order_by in ("", "ORDER BY ID DESC", "ORDER BY -ID ASC"): + res = df.sql(f"SELECT * REPLACE {replacements} FROM self {order_by}") + if not order_by: + res = res.sort(check_cols[-1], descending=True) + + assert res.select(check_cols).rows() == expected + expected_columns = ( + check_cols + df.columns[1:] if check_cols == ["Identifier"] else df.columns + ) + assert res.columns == expected_columns + + +def test_select_wildcard_errors(df: pl.DataFrame) -> None: + # EXCLUDE and ILIKE are not allowed together + with pytest.raises(SQLInterfaceError, match="ILIKE"): + assert df.sql("SELECT * EXCLUDE Address ILIKE '%o%' FROM self") + + # these two options are aliases, with EXCLUDE being preferred + with pytest.raises( + SQLInterfaceError, + match="EXCLUDE and EXCEPT wildcard options cannot be used together", + ): + assert df.sql("SELECT * EXCLUDE Address EXCEPT City FROM self") + + # note: missing "()" around the exclude option results in dupe col + with pytest.raises( + DuplicateError, + match="City", + ): + assert df.sql("SELECT * EXCLUDE Address, City FROM self") diff --git a/py-polars/tests/unit/streaming/__init__.py b/py-polars/tests/unit/streaming/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/streaming/conftest.py b/py-polars/tests/unit/streaming/conftest.py new file mode 100644 index 000000000000..d3c434edcc3e --- /dev/null +++ b/py-polars/tests/unit/streaming/conftest.py @@ -0,0 +1,8 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture +def io_files_path() -> Path: + return Path(__file__).parent.parent / "io" / "files" diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py new file mode 100644 index 000000000000..7d6da0adc571 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -0,0 +1,413 @@ +from __future__ import annotations + +import os +import time +from datetime import date +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import PolarsInefficientMapWarning +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars._typing import JoinStrategy + +pytestmark = pytest.mark.xdist_group("streaming") + + +def test_streaming_categoricals_5921() -> None: + with pl.StringCache(): + out_lazy = ( + pl.DataFrame({"X": ["a", "a", "a", "b", "b"], "Y": [2, 2, 2, 1, 1]}) + .lazy() + .with_columns(pl.col("X").cast(pl.Categorical)) + .group_by("X") + .agg(pl.col("Y").min()) + .sort("Y", descending=True) + .collect(engine="streaming") + ) + + out_eager = ( + pl.DataFrame({"X": ["a", "a", "a", "b", "b"], "Y": [2, 2, 2, 1, 1]}) + .with_columns(pl.col("X").cast(pl.Categorical)) + .group_by("X") + .agg(pl.col("Y").min()) + .sort("Y", descending=True) + ) + + for out in [out_eager, out_lazy]: + assert out.dtypes == [pl.Categorical, pl.Int64] + assert out.to_dict(as_series=False) == {"X": ["a", "b"], "Y": [2, 1]} + + +def test_streaming_block_on_literals_6054() -> None: + df = pl.DataFrame({"col_1": [0] * 5 + [1] * 5}) + s = pl.Series("col_2", list(range(10))) + + assert df.lazy().with_columns(s).group_by("col_1").agg(pl.all().first()).collect( + engine="streaming" + ).sort("col_1").to_dict(as_series=False) == {"col_1": [0, 1], "col_2": [0, 5]} + + +@pytest.mark.may_fail_auto_streaming +def test_streaming_streamable_functions(monkeypatch: Any, capfd: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + assert ( + pl.DataFrame({"a": [1, 2, 3]}) + .lazy() + .map_batches( + function=lambda df: df.with_columns(pl.col("a").alias("b")), + schema={"a": pl.Int64, "b": pl.Int64}, + streamable=True, + ) + ).collect(engine="old-streaming").to_dict(as_series=False) == { # type: ignore[call-overload] + "a": [1, 2, 3], + "b": [1, 2, 3], + } + + (_, err) = capfd.readouterr() + assert "df -> function -> ordered_sink" in err + + +@pytest.mark.slow +@pytest.mark.may_fail_auto_streaming +def test_cross_join_stack() -> None: + a = pl.Series(np.arange(100_000)).to_frame().lazy() + t0 = time.time() + # this should be instant if directly pushed into sink + # if not the cross join will first fill the stack with all matches of a single chunk + assert a.join(a, how="cross").head().collect(engine="old-streaming").shape == (5, 2) # type: ignore[call-overload] + t1 = time.time() + assert (t1 - t0) < 0.5 + + +def test_streaming_literal_expansion() -> None: + df = pl.DataFrame( + { + "y": ["a", "b"], + "z": [1, 2], + } + ) + + q = df.lazy().select( + x=pl.lit("constant"), + y=pl.col("y"), + z=pl.col("z"), + ) + + assert q.collect(engine="streaming").to_dict(as_series=False) == { + "x": ["constant", "constant"], + "y": ["a", "b"], + "z": [1, 2], + } + assert q.group_by(["x", "y"]).agg(pl.mean("z")).sort("y").collect( + engine="streaming" + ).to_dict(as_series=False) == { + "x": ["constant", "constant"], + "y": ["a", "b"], + "z": [1.0, 2.0], + } + assert q.group_by(["x"]).agg(pl.mean("z")).collect().to_dict(as_series=False) == { + "x": ["constant"], + "z": [1.5], + } + + +@pytest.mark.may_fail_auto_streaming +def test_streaming_apply(monkeypatch: Any, capfd: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + + q = pl.DataFrame({"a": [1, 2]}).lazy() + with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"): + ( + q.select( + pl.col("a").map_elements(lambda x: x * 2, return_dtype=pl.Int64) + ).collect(engine="old-streaming") # type: ignore[call-overload] + ) + + +def test_streaming_ternary() -> None: + q = pl.LazyFrame({"a": [1, 2, 3]}) + + assert ( + q.with_columns( + pl.when(pl.col("a") >= 2).then(pl.col("a")).otherwise(None).alias("b"), + ) + .explain(engine="old-streaming") # type: ignore[arg-type] + .startswith("STREAMING") + ) + + +def test_streaming_sortedness_propagation_9494() -> None: + assert ( + pl.DataFrame( + { + "when": [date(2023, 5, 10), date(2023, 5, 20), date(2023, 6, 10)], + "what": [1, 2, 3], + } + ) + .lazy() + .sort("when") + .group_by_dynamic("when", every="1mo") + .agg(pl.col("what").sum()) + .collect(engine="streaming") + ).to_dict(as_series=False) == { + "when": [date(2023, 5, 1), date(2023, 6, 1)], + "what": [3, 3], + } + + +@pytest.mark.write_disk +@pytest.mark.slow +def test_streaming_generic_left_and_inner_join_from_disk(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + p0 = tmp_path / "df0.parquet" + p1 = tmp_path / "df1.parquet" + # by loading from disk, we get different chunks + n = 200_000 + k = 100 + + d0: dict[str, np.ndarray[Any, Any]] = { + f"x{i}": np.random.random(n) for i in range(k) + } + d0.update({"id": np.arange(n)}) + + df0 = pl.DataFrame(d0) + df1 = df0.clone().select(pl.all().shuffle(111)) + + df0.write_parquet(p0) + df1.write_parquet(p1) + + lf0 = pl.scan_parquet(p0) + lf1 = pl.scan_parquet(p1).select(pl.all().name.suffix("_r")) + + join_strategies: list[JoinStrategy] = ["left", "inner"] + for how in join_strategies: + assert_frame_equal( + lf0.join( + lf1, left_on="id", right_on="id_r", how=how, maintain_order="left" + ).collect(engine="streaming"), + lf0.join(lf1, left_on="id", right_on="id_r", how=how).collect( + engine="in-memory" + ), + check_row_order=how == "left", + ) + + +def test_streaming_9776() -> None: + df = pl.DataFrame({"col_1": ["a"] * 1000, "ID": [None] + ["a"] * 999}) + ordered = ( + df.group_by("col_1", "ID", maintain_order=True) + .len() + .filter(pl.col("col_1") == "a") + ) + unordered = ( + df.group_by("col_1", "ID", maintain_order=False) + .len() + .filter(pl.col("col_1") == "a") + ) + expected = [("a", None, 1), ("a", "a", 999)] + assert ordered.rows() == expected + assert unordered.sort(["col_1", "ID"]).rows() == expected + + +@pytest.mark.write_disk +def test_stream_empty_file(tmp_path: Path) -> None: + p = tmp_path / "in.parquet" + schema = { + "KLN_NR": pl.String, + } + + df = pl.DataFrame( + { + "KLN_NR": [], + }, + schema=schema, + ) + df.write_parquet(p) + assert pl.scan_parquet(p).collect(engine="streaming").schema == schema + + +def test_streaming_empty_df() -> None: + df = pl.DataFrame( + [ + pl.Series("a", ["a", "b", "c", "b", "a", "a"], dtype=pl.Categorical()), + pl.Series("b", ["b", "c", "c", "b", "a", "c"], dtype=pl.Categorical()), + ] + ) + + result = ( + df.lazy() + .join(df.lazy(), on="a", how="inner") + .filter(False) + .collect(engine="streaming") + ) + + assert result.to_dict(as_series=False) == {"a": [], "b": [], "b_right": []} + + +def test_streaming_duplicate_cols_5537() -> None: + assert pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).lazy().with_columns( + (pl.col("a") * 2).alias("foo"), (pl.col("a") * 3) + ).collect(engine="streaming").to_dict(as_series=False) == { + "a": [3, 6, 9], + "b": [1, 2, 3], + "foo": [2, 4, 6], + } + + +def test_null_sum_streaming_10455() -> None: + df = pl.DataFrame( + { + "x": [1] * 10, + "y": [None] * 10, + }, + schema={"x": pl.Int64, "y": pl.Float32}, + ) + result = df.lazy().group_by("x").sum().collect(engine="streaming") + expected = {"x": [1], "y": [0.0]} + assert result.to_dict(as_series=False) == expected + + +def test_boolean_agg_schema() -> None: + df = pl.DataFrame( + { + "x": [1, 1, 1], + "y": [False, True, False], + } + ).lazy() + + agg_df = df.group_by("x").agg(pl.col("y").max().alias("max_y")) + + for streaming in [True, False]: + assert ( + agg_df.collect(engine="streaming" if streaming else "in-memory").schema + == agg_df.collect_schema() + == {"x": pl.Int64, "max_y": pl.Boolean} + ) + + +@pytest.mark.write_disk +def test_streaming_csv_headers_but_no_data_13770(tmp_path: Path) -> None: + with Path.open(tmp_path / "header_no_data.csv", "w") as f: + f.write("name, age\n") + + schema = {"name": pl.String, "age": pl.Int32} + df = ( + pl.scan_csv(tmp_path / "header_no_data.csv", schema=schema) + .head() + .collect(engine="streaming") + ) + assert df.height == 0 + assert df.schema == schema + + +@pytest.mark.write_disk +def test_custom_temp_dir(tmp_path: Path, monkeypatch: Any) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + monkeypatch.setenv("POLARS_VERBOSE", "1") + + s = pl.arange(0, 100_000, eager=True).rename("idx") + df = s.shuffle().to_frame() + df.lazy().sort("idx").collect(engine="old-streaming") # type: ignore[call-overload] + + assert os.listdir(tmp_path), f"Temp directory '{tmp_path}' is empty" + + +@pytest.mark.write_disk +def test_streaming_with_hconcat(tmp_path: Path) -> None: + df1 = pl.DataFrame( + { + "id": [0, 0, 1, 1, 2, 2], + "x": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], + } + ) + df1.write_parquet(tmp_path / "df1.parquet") + + df2 = pl.DataFrame( + { + "y": [6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + } + ) + df2.write_parquet(tmp_path / "df2.parquet") + + lf1 = pl.scan_parquet(tmp_path / "df1.parquet") + lf2 = pl.scan_parquet(tmp_path / "df2.parquet") + query = ( + pl.concat([lf1, lf2], how="horizontal") + .group_by("id") + .agg(pl.all().mean()) + .sort(pl.col("id")) + ) + + result = query.collect(engine="streaming") + + expected = pl.DataFrame( + { + "id": [0, 1, 2], + "x": [0.5, 2.5, 4.5], + "y": [6.5, 8.5, 10.5], + } + ) + + assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +def test_elementwise_identification_in_ternary_15767(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + ( + pl.LazyFrame({"a": pl.Series([1])}) + .with_columns(b=pl.col("a").is_in(pl.Series([1, 2, 3]))) + .sink_parquet(tmp_path / "1") + ) + + ( + pl.LazyFrame({"a": pl.Series([1])}) + .with_columns( + b=pl.when(pl.col("a").is_in(pl.Series([1, 2, 3]))).then(pl.col("a")) + ) + .sink_parquet(tmp_path / "1") + ) + + +def test_streaming_temporal_17669() -> None: + df = ( + pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.Datetime("us")}) + .with_columns( + b=pl.col("a").dt.date(), + c=pl.col("a").dt.time(), + ) + .collect(engine="streaming") + ) + assert df.schema == { + "a": pl.Datetime("us"), + "b": pl.Date, + "c": pl.Time, + } + + +def test_i128_sum_reduction() -> None: + assert ( + pl.Series("a", [1, 2, 3], pl.Int128) + .to_frame() + .lazy() + .sum() + .collect(engine="streaming") + .item() + == 6 + ) + + +def test_streaming_flag_21799() -> None: + with pytest.raises(DeprecationWarning): + pl.LazyFrame({"a": 1}).collect(streaming=False) # type: ignore[call-overload] + with pytest.raises(DeprecationWarning): + pl.LazyFrame({"a": 1}).collect(streaming=True) # type: ignore[call-overload] diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py new file mode 100644 index 000000000000..679e133412b0 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py @@ -0,0 +1,34 @@ +import pytest + +import polars as pl + +pytestmark = pytest.mark.xdist_group("streaming") + + +def test_streaming_nested_categorical() -> None: + with pl.StringCache(): + assert ( + pl.LazyFrame({"numbers": [1, 1, 2], "cat": [["str"], ["foo"], ["bar"]]}) + .with_columns(pl.col("cat").cast(pl.List(pl.Categorical))) + .group_by("numbers") + .agg(pl.col("cat").first()) + .sort("numbers") + ).collect(engine="streaming").to_dict(as_series=False) == { + "numbers": [1, 2], + "cat": [["str"], ["bar"]], + } + + +def test_streaming_cat_14933() -> None: + # https://github.com/pola-rs/polars/issues/14933 + + df1 = pl.LazyFrame({"a": pl.Series([0], dtype=pl.UInt32)}) + df2 = pl.LazyFrame( + [ + pl.Series("a", [0, 1], dtype=pl.UInt32), + pl.Series("l", [None, None], dtype=pl.Categorical(ordering="physical")), + ] + ) + result = df1.join(df2, on="a", how="left") + expected = {"a": [0], "l": [None]} + assert result.collect(engine="streaming").to_dict(as_series=False) == expected diff --git a/py-polars/tests/unit/streaming/test_streaming_cse.py b/py-polars/tests/unit/streaming/test_streaming_cse.py new file mode 100644 index 000000000000..d91ec0a617e0 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_cse.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +pytestmark = pytest.mark.xdist_group("streaming") + + +def test_cse_expr_selection_streaming(monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + q = pl.LazyFrame( + { + "a": [1, 2, 3, 4], + "b": [1, 2, 3, 4], + "c": [1, 2, 3, 4], + } + ) + + derived = pl.col("a") * pl.col("b") + derived2 = derived * derived + + exprs = [ + derived.alias("d1"), + derived2.alias("d2"), + (derived2 * 10).alias("d3"), + ] + + result = q.select(exprs).collect(comm_subexpr_elim=True, engine="streaming") + expected = pl.DataFrame( + {"d1": [1, 4, 9, 16], "d2": [1, 16, 81, 256], "d3": [10, 160, 810, 2560]} + ) + assert_frame_equal(result, expected) + + result = q.with_columns(exprs).collect(comm_subexpr_elim=True, engine="streaming") + expected = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [1, 2, 3, 4], + "c": [1, 2, 3, 4], + "d1": [1, 4, 9, 16], + "d2": [1, 16, 81, 256], + "d3": [10, 160, 810, 2560], + } + ) + assert_frame_equal(result, expected) + + +def test_cse_expr_group_by() -> None: + q = pl.LazyFrame( + { + "a": [1, 2, 3, 4], + "b": [1, 2, 3, 4], + "c": [1, 2, 3, 4], + } + ) + + derived = pl.col("a") * pl.col("b") + + q = ( + q.group_by("a") + .agg(derived.sum().alias("sum"), derived.min().alias("min")) + .sort("min") + ) + + assert "__POLARS_CSER" in q.explain(comm_subexpr_elim=True, optimized=True) + + s = q.explain( + comm_subexpr_elim=True, + optimized=True, + engine="old-streaming", # type: ignore[arg-type] + comm_subplan_elim=False, + ) + assert s.startswith("STREAMING") + + expected = pl.DataFrame( + {"a": [1, 2, 3, 4], "sum": [1, 4, 9, 16], "min": [1, 4, 9, 16]} + ) + for streaming in [True, False]: + out = q.collect( + comm_subexpr_elim=True, engine="streaming" if streaming else "in-memory" + ) + assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py new file mode 100644 index 000000000000..e31391c43791 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -0,0 +1,533 @@ +from __future__ import annotations + +from datetime import date +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import DuplicateError +from polars.testing import assert_frame_equal +from tests.unit.conftest import INTEGER_DTYPES + +if TYPE_CHECKING: + from pathlib import Path + +pytestmark = pytest.mark.xdist_group("streaming") + + +@pytest.mark.slow +def test_streaming_group_by_sorted_fast_path_nulls_10273() -> None: + df = pl.Series( + name="x", + values=( + *(i for i in range(4) for _ in range(100)), + *(None for _ in range(100)), + ), + ).to_frame() + + assert ( + df.set_sorted("x") + .lazy() + .group_by("x") + .agg(pl.len()) + .collect(engine="streaming") + .sort("x") + ).to_dict(as_series=False) == { + "x": [None, 0, 1, 2, 3], + "len": [100, 100, 100, 100, 100], + } + + +def test_streaming_group_by_types() -> None: + df = pl.DataFrame( + { + "person_id": [1, 1], + "year": [1995, 1995], + "person_name": ["bob", "foo"], + "bool": [True, False], + "date": [date(2022, 1, 1), date(2022, 1, 1)], + } + ) + + for by in ["person_id", "year", "date", ["person_id", "year"]]: + out = ( + ( + df.lazy() + .group_by(by) + .agg( + [ + pl.col("person_name").first().alias("str_first"), + pl.col("person_name").last().alias("str_last"), + pl.col("person_name").mean().alias("str_mean"), + pl.col("person_name").sum().alias("str_sum"), + pl.col("bool").first().alias("bool_first"), + pl.col("bool").last().alias("bool_last"), + pl.col("bool").mean().alias("bool_mean"), + pl.col("bool").sum().alias("bool_sum"), + # pl.col("date").sum().alias("date_sum"), + # Date streaming mean/median has been temporarily disabled + # pl.col("date").mean().alias("date_mean"), + pl.col("date").first().alias("date_first"), + pl.col("date").last().alias("date_last"), + pl.col("date").min().alias("date_min"), + pl.col("date").max().alias("date_max"), + ] + ) + ) + .select(pl.all().exclude(by)) + .collect(engine="streaming") + ) + assert out.schema == { + "str_first": pl.String, + "str_last": pl.String, + "str_mean": pl.String, + "str_sum": pl.String, + "bool_first": pl.Boolean, + "bool_last": pl.Boolean, + "bool_mean": pl.Float64, + "bool_sum": pl.UInt32, + # "date_sum": pl.Date, + # "date_mean": pl.Date, + "date_first": pl.Date, + "date_last": pl.Date, + "date_min": pl.Date, + "date_max": pl.Date, + } + + assert out.to_dict(as_series=False) == { + "str_first": ["bob"], + "str_last": ["foo"], + "str_mean": [None], + "str_sum": [None], + "bool_first": [True], + "bool_last": [False], + "bool_mean": [0.5], + "bool_sum": [1], + # "date_sum": [None], + # Date streaming mean/median has been temporarily disabled + # "date_mean": [date(2022, 1, 1)], + "date_first": [date(2022, 1, 1)], + "date_last": [date(2022, 1, 1)], + "date_min": [date(2022, 1, 1)], + "date_max": [date(2022, 1, 1)], + } + + with pytest.raises(DuplicateError): + ( + df.lazy() + .group_by("person_id") + .agg( + [ + pl.col("person_name").first().alias("str_first"), + pl.col("person_name").last().alias("str_last"), + pl.col("person_name").mean().alias("str_mean"), + pl.col("person_name").sum().alias("str_sum"), + pl.col("bool").first().alias("bool_first"), + pl.col("bool").last().alias("bool_first"), + ] + ) + .select(pl.all().exclude("person_id")) + .collect(engine="streaming") + ) + + +def test_streaming_group_by_min_max() -> None: + df = pl.DataFrame( + { + "person_id": [1, 2, 3, 4, 5, 6], + "year": [1995, 1995, 1995, 2, 2, 2], + } + ) + out = ( + df.lazy() + .group_by("year") + .agg([pl.min("person_id").alias("min"), pl.max("person_id").alias("max")]) + .collect() + .sort("year") + ) + assert out["min"].to_list() == [4, 1] + assert out["max"].to_list() == [6, 3] + + +def test_streaming_non_streaming_gb() -> None: + n = 100 + df = pl.DataFrame({"a": np.random.randint(0, 20, n)}) + q = df.lazy().group_by("a").agg(pl.len()).sort("a") + assert_frame_equal(q.collect(engine="streaming"), q.collect()) + + q = df.lazy().with_columns(pl.col("a").cast(pl.String)) + q = q.group_by("a").agg(pl.len()).sort("a") + assert_frame_equal(q.collect(engine="streaming"), q.collect()) + q = df.lazy().with_columns(pl.col("a").alias("b")) + q = q.group_by(["a", "b"]).agg(pl.len(), pl.col("a").sum().alias("sum_a")).sort("a") + assert_frame_equal(q.collect(engine="streaming"), q.collect()) + + +def test_streaming_group_by_sorted_fast_path() -> None: + a = np.random.randint(0, 20, 80) + df = pl.DataFrame( + { + # test on int8 as that also tests proper conversions + "a": pl.Series(np.sort(a), dtype=pl.Int8) + } + ).with_row_index() + + df_sorted = df.with_columns(pl.col("a").set_sorted()) + + for streaming in [True, False]: + results = [] + for df_ in [df, df_sorted]: + out = ( + df_.lazy() + .group_by("a") + .agg( + [ + pl.first("a").alias("first"), + pl.last("a").alias("last"), + pl.sum("a").alias("sum"), + pl.mean("a").alias("mean"), + pl.count("a").alias("count"), + pl.min("a").alias("min"), + pl.max("a").alias("max"), + ] + ) + .sort("a") + .collect(engine="streaming" if streaming else "in-memory") + ) + results.append(out) + + assert_frame_equal(results[0], results[1]) + + +@pytest.fixture(scope="module") +def random_integers() -> pl.Series: + np.random.seed(1) + return pl.Series("a", np.random.randint(0, 10, 100), dtype=pl.Int64) + + +@pytest.mark.write_disk +def test_streaming_group_by_ooc_q1( + random_integers: pl.Series, + tmp_path: Path, + monkeypatch: Any, +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + + lf = random_integers.to_frame().lazy() + result = ( + lf.group_by("a") + .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last")) + .sort("a") + .collect(engine="streaming") + ) + + expected = pl.DataFrame( + { + "a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "a_first": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "a_last": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +def test_streaming_group_by_ooc_q2( + random_integers: pl.Series, + tmp_path: Path, + monkeypatch: Any, +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + + lf = random_integers.cast(str).to_frame().lazy() + result = ( + lf.group_by("a") + .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last")) + .sort("a") + .collect(engine="streaming") + ) + + expected = pl.DataFrame( + { + "a": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + "a_first": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + "a_last": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +def test_streaming_group_by_ooc_q3( + random_integers: pl.Series, + tmp_path: Path, + monkeypatch: Any, +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + + lf = pl.LazyFrame({"a": random_integers, "b": random_integers}) + result = ( + lf.group_by("a", "b") + .agg(pl.first("a").alias("a_first"), pl.last("a").alias("a_last")) + .sort("a") + .collect(engine="streaming") + ) + + expected = pl.DataFrame( + { + "a": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "b": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "a_first": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "a_last": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + ) + assert_frame_equal(result, expected) + + +def test_streaming_group_by_struct_key() -> None: + df = pl.DataFrame( + {"A": [1, 2, 3, 2], "B": ["google", "ms", "apple", "ms"], "C": [2, 3, 4, 3]} + ) + df1 = df.lazy().with_columns(pl.struct(["A", "C"]).alias("tuples")) + assert df1.group_by("tuples").agg(pl.len(), pl.col("B").first()).sort("B").collect( + engine="streaming" + ).to_dict(as_series=False) == { + "tuples": [{"A": 3, "C": 4}, {"A": 1, "C": 2}, {"A": 2, "C": 3}], + "len": [1, 1, 2], + "B": ["apple", "google", "ms"], + } + + +@pytest.mark.slow +def test_streaming_group_by_all_numeric_types_stability_8570() -> None: + m = 1000 + n = 1000 + + rng = np.random.default_rng(seed=0) + dfa = pl.DataFrame({"x": pl.arange(start=0, end=n, eager=True)}) + dfb = pl.DataFrame( + { + "y": rng.integers(low=0, high=10, size=m), + "z": rng.integers(low=0, high=2, size=m), + } + ) + dfc = dfa.join(dfb, how="cross") + + for keys in [["x", "y"], "z"]: + for dtype in [*INTEGER_DTYPES, pl.Boolean]: + # the alias checks if the schema is correctly handled + dfd = ( + dfc.lazy() + .with_columns(pl.col("z").cast(dtype)) + .group_by(keys) + .agg(pl.col("z").sum().alias("z_sum")) + .collect(engine="old-streaming") # type: ignore[call-overload] + ) + assert dfd["z_sum"].sum() == dfc["z"].sum() + + +def test_streaming_group_by_categorical_aggregate() -> None: + with pl.StringCache(): + out = ( + pl.LazyFrame( + { + "a": pl.Series( + ["a", "a", "b", "b", "c", "c", None, None], dtype=pl.Categorical + ), + "b": pl.Series( + pl.date_range( + date(2023, 4, 28), + date(2023, 5, 5), + eager=True, + ).to_list(), + dtype=pl.Date, + ), + } + ) + .group_by(["a", "b"]) + .agg([pl.col("a").first().alias("sum")]) + .collect(engine="streaming") + ) + + assert out.sort("b").to_dict(as_series=False) == { + "a": ["a", "a", "b", "b", "c", "c", None, None], + "b": [ + date(2023, 4, 28), + date(2023, 4, 29), + date(2023, 4, 30), + date(2023, 5, 1), + date(2023, 5, 2), + date(2023, 5, 3), + date(2023, 5, 4), + date(2023, 5, 5), + ], + "sum": ["a", "a", "b", "b", "c", "c", None, None], + } + + +def test_streaming_group_by_list_9758() -> None: + payload = {"a": [[1, 2]]} + assert ( + pl.LazyFrame(payload) + .group_by("a") + .first() + .collect(engine="streaming") + .to_dict(as_series=False) + == payload + ) + + +def test_streaming_restart_non_streamable_group_by() -> None: + df = pl.DataFrame({"id": [1], "id2": [1], "id3": [1], "value": [1]}) + res = ( + df.lazy() + .join(df.lazy(), on=["id", "id2"], how="left") + .filter( + (pl.col("id3") > pl.col("id3_right")) + & (pl.col("id3") - pl.col("id3_right") < 30) + ) + .group_by(["id2", "id3", "id3_right"]) + .agg( + pl.col("value").map_elements(lambda x: x).sum() * pl.col("value").sum() + ) # non-streamable UDF + nested_agg + ) + + assert "STREAMING" in res.explain(engine="old-streaming") # type: ignore[arg-type] + + +def test_group_by_min_max_string_type() -> None: + table = pl.from_dict({"a": [1, 1, 2, 2, 2], "b": ["a", "b", "c", "d", None]}) + + expected = {"a": [1, 2], "min": ["a", "c"], "max": ["b", "d"]} + + for streaming in [True, False]: + assert ( + table.lazy() + .group_by("a") + .agg([pl.min("b").alias("min"), pl.max("b").alias("max")]) + .collect(engine="streaming" if streaming else "in-memory") + .sort("a") + .to_dict(as_series=False) + == expected + ) + + +@pytest.mark.parametrize("literal", [True, "foo", 1]) +def test_streaming_group_by_literal(literal: Any) -> None: + df = pl.LazyFrame({"a": range(20)}) + + assert df.group_by(pl.lit(literal)).agg( + [ + pl.col("a").count().alias("a_count"), + pl.col("a").sum().alias("a_sum"), + ] + ).collect(engine="streaming").to_dict(as_series=False) == { + "literal": [literal], + "a_count": [20], + "a_sum": [190], + } + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_group_by_multiple_keys_one_literal(streaming: bool) -> None: + df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + + expected = {"a": [1, 2], "literal": [1, 1], "b": [5, 6]} + assert ( + df.lazy() + .group_by("a", pl.lit(1)) + .agg(pl.col("b").max()) + .sort(["a", "b"]) + .collect(engine="streaming" if streaming else "in-memory") + .to_dict(as_series=False) + == expected + ) + + +def test_streaming_group_null_count() -> None: + df = pl.DataFrame({"g": [1] * 6, "a": ["yes", None] * 3}).lazy() + assert df.group_by("g").agg(pl.col("a").count()).collect( + engine="streaming" + ).to_dict(as_series=False) == {"g": [1], "a": [3]} + + +def test_streaming_group_by_binary_15116() -> None: + assert ( + pl.LazyFrame( + { + "str": [ + "A", + "A", + "BB", + "BB", + "CCCC", + "CCCC", + "DDDDDDDD", + "DDDDDDDD", + "EEEEEEEEEEEEEEEE", + "A", + ] + } + ) + .select([pl.col("str").cast(pl.Binary)]) + .group_by(["str"]) + .agg([pl.len().alias("count")]) + ).sort("str").collect(engine="streaming").to_dict(as_series=False) == { + "str": [b"A", b"BB", b"CCCC", b"DDDDDDDD", b"EEEEEEEEEEEEEEEE"], + "count": [3, 2, 2, 2, 1], + } + + +def test_streaming_group_by_convert_15380(partition_limit: int) -> None: + assert ( + pl.DataFrame({"a": [1] * partition_limit}).group_by(b="a").len()["len"].item() + == partition_limit + ) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("n_rows_limit_offset", [-1, +3]) +def test_streaming_group_by_boolean_mean_15610( + n_rows_limit_offset: int, streaming: bool, partition_limit: int +) -> None: + n_rows = partition_limit + n_rows_limit_offset + + # Also test non-streaming because it sometimes dispatched to streaming agg. + expect = pl.DataFrame({"a": [False, True], "c": [0.0, 0.5]}) + + n_repeats = n_rows // 3 + assert n_repeats > 0 + + out = ( + pl.select( + a=pl.repeat([True, False, True], n_repeats).explode(), + b=pl.repeat([True, False, False], n_repeats).explode(), + ) + .lazy() + .group_by("a") + .agg(c=pl.mean("b")) + .sort("a") + .collect(engine="streaming" if streaming else "in-memory") + ) + + assert_frame_equal(out, expect) + + +def test_streaming_group_by_all_null_21593() -> None: + df = pl.DataFrame( + { + "col_1": ["A", "B", "C", "D"], + "col_2": ["test", None, None, None], + } + ) + + out = df.lazy().group_by(pl.all()).min().collect(engine="streaming") + assert_frame_equal(df, out, check_row_order=False) diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py new file mode 100644 index 000000000000..e39b1d00d5e0 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + +pytestmark = pytest.mark.xdist_group("streaming") + + +@pytest.mark.write_disk +def test_streaming_parquet_glob_5900(df: pl.DataFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "small.parquet" + df.write_parquet(file_path) + + path_glob = tmp_path / "small*.parquet" + result = ( + pl.scan_parquet(path_glob).select(pl.all().first()).collect(engine="streaming") + ) + assert result.shape == (1, df.width) + + +def test_scan_slice_streaming(io_files_path: Path) -> None: + foods_file_path = io_files_path / "foods1.csv" + df = pl.scan_csv(foods_file_path).head(5).collect(engine="streaming") + assert df.shape == (5, 4) + + # globbing + foods_file_path = io_files_path / "foods*.csv" + df = pl.scan_csv(foods_file_path).head(5).collect(engine="streaming") + assert df.shape == (5, 4) + + +@pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16]) +def test_scan_csv_overwrite_small_dtypes( + io_files_path: Path, dtype: pl.DataType +) -> None: + file_path = io_files_path / "foods1.csv" + df = pl.scan_csv(file_path, schema_overrides={"sugars_g": dtype}).collect( + engine="streaming" + ) + assert df.dtypes == [pl.String, pl.Int64, pl.Float64, dtype] + + +@pytest.mark.write_disk +def test_sink_parquet(io_files_path: Path, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file = io_files_path / "small.parquet" + + file_path = tmp_path / "sink.parquet" + + df_scanned = pl.scan_parquet(file) + df_scanned.sink_parquet(file_path) + + with pl.StringCache(): + result = pl.read_parquet(file_path) + df_read = pl.read_parquet(file) + assert_frame_equal(result, df_read) + + +@pytest.mark.write_disk +def test_sink_parquet_10115(tmp_path: Path) -> None: + in_path = tmp_path / "in.parquet" + out_path = tmp_path / "out.parquet" + + # this fails if the schema will be incorrectly due to the projection + # pushdown + (pl.DataFrame([{"x": 1, "y": "foo"}]).write_parquet(in_path)) + + joiner = pl.LazyFrame([{"y": "foo", "z": "_"}]) + + ( + pl.scan_parquet(in_path) + .join(joiner, how="left", on="y") + .select("x", "y", "z") + .sink_parquet(out_path) # + ) + + assert pl.read_parquet(out_path).to_dict(as_series=False) == { + "x": [1], + "y": ["foo"], + "z": ["_"], + } + + +@pytest.mark.write_disk +def test_sink_ipc(io_files_path: Path, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file = io_files_path / "small.parquet" + + file_path = tmp_path / "sink.ipc" + + df_scanned = pl.scan_parquet(file) + df_scanned.sink_ipc(file_path) + + with pl.StringCache(): + result = pl.read_ipc(file_path) + df_read = pl.read_parquet(file) + assert_frame_equal(result, df_read) + + +@pytest.mark.write_disk +def test_sink_csv(io_files_path: Path, tmp_path: Path) -> None: + source_file = io_files_path / "small.parquet" + target_file = tmp_path / "sink.csv" + + pl.scan_parquet(source_file).sink_csv(target_file) + + with pl.StringCache(): + source_data = pl.read_parquet(source_file) + target_data = pl.read_csv(target_file) + assert_frame_equal(target_data, source_data) + + +@pytest.mark.write_disk +def test_sink_csv_14494(tmp_path: Path) -> None: + pl.LazyFrame({"c": [1, 2, 3]}, schema={"c": pl.Int64}).filter( + pl.col("c") > 10 + ).sink_csv(tmp_path / "sink.csv") + assert pl.read_csv(tmp_path / "sink.csv").columns == ["c"] + + +@pytest.mark.parametrize(("value"), ["abc", ""]) +def test_sink_csv_exception_for_separator(value: str) -> None: + df = pl.LazyFrame({"dummy": ["abc"]}) + with pytest.raises(ValueError, match="should be a single byte character, but is"): + df.sink_csv("path", separator=value) + + +@pytest.mark.parametrize(("value"), ["abc", ""]) +def test_sink_csv_exception_for_quote(value: str) -> None: + df = pl.LazyFrame({"dummy": ["abc"]}) + with pytest.raises(ValueError, match="should be a single byte character, but is"): + df.sink_csv("path", quote_char=value) + + +def test_sink_csv_batch_size_zero() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + with pytest.raises(ValueError, match="invalid zero value"): + lf.sink_csv("test.csv", batch_size=0) + + +@pytest.mark.write_disk +def test_sink_csv_nested_data(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.csv" + + lf = pl.LazyFrame({"list": [[1, 2, 3, 4, 5]]}) + with pytest.raises( + pl.exceptions.ComputeError, match="CSV format does not support nested data" + ): + lf.sink_csv(path) + + +def test_scan_csv_only_header_10792(io_files_path: Path) -> None: + foods_file_path = io_files_path / "only_header.csv" + df = pl.scan_csv(foods_file_path).collect(engine="streaming") + assert df.to_dict(as_series=False) == {"Name": [], "Address": []} + + +def test_scan_empty_csv_10818(io_files_path: Path) -> None: + empty_file_path = io_files_path / "empty.csv" + df = pl.scan_csv(empty_file_path, raise_if_empty=False).collect(engine="streaming") + assert df.is_empty() + + +@pytest.mark.write_disk +def test_streaming_cross_join_schema(tmp_path: Path) -> None: + file_path = tmp_path / "temp.parquet" + a = pl.DataFrame({"a": [1, 2]}).lazy() + b = pl.DataFrame({"b": ["b"]}).lazy() + a.join(b, how="cross").sink_parquet(file_path) + read = pl.read_parquet(file_path, parallel="none") + assert read.to_dict(as_series=False) == {"a": [1, 2], "b": ["b", "b"]} + + +@pytest.mark.write_disk +def test_sink_ndjson_should_write_same_data( + io_files_path: Path, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + + source_path = io_files_path / "foods1.csv" + target_path = tmp_path / "foods_test.ndjson" + + expected = pl.read_csv(source_path) + + lf = pl.scan_csv(source_path) + lf.sink_ndjson(target_path) + df = pl.read_ndjson(target_path) + + assert_frame_equal(df, expected) + + +@pytest.mark.write_disk +@pytest.mark.parametrize("streaming", [False, True]) +def test_parquet_eq_statistics( + monkeypatch: Any, capfd: Any, tmp_path: Path, streaming: bool +) -> None: + tmp_path.mkdir(exist_ok=True) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.DataFrame({"idx": pl.arange(100, 200, eager=True)}).with_columns( + (pl.col("idx") // 25).alias("part") + ) + df = pl.concat(df.partition_by("part", as_dict=False), rechunk=False) + assert df.n_chunks("all") == [4, 4] + + file_path = tmp_path / "stats.parquet" + df.write_parquet(file_path, statistics=True, use_pyarrow=False) + + for pred in [ + pl.col("idx") == 50, + pl.col("idx") == 150, + pl.col("idx") == 210, + ]: + result = ( + pl.scan_parquet(file_path) + .filter(pred) + .collect(engine="streaming" if streaming else "in-memory") + ) + assert_frame_equal(result, df.filter(pred)) + + captured = capfd.readouterr().err + assert ( + "[ParquetFileReader]: Predicate pushdown: reading 1 / 1 row groups" in captured + ) + assert ( + "[ParquetFileReader]: Predicate pushdown: reading 0 / 1 row groups" in captured + ) + + +@pytest.mark.write_disk +def test_streaming_empty_parquet_16523(tmp_path: Path) -> None: + file_path = tmp_path / "foo.parquet" + df = pl.DataFrame({"a": []}, schema={"a": pl.Int32}) + df.write_parquet(file_path) + q = pl.scan_parquet(file_path) + q2 = pl.LazyFrame({"a": [1]}, schema={"a": pl.Int32}) + assert q.join(q2, on="a").collect(engine="streaming").shape == (0, 1) + + +@pytest.mark.parametrize( + "method", + ["parquet", "csv", "ipc", "ndjson"], +) +@pytest.mark.write_disk +def test_sink_phases(tmp_path: Path, method: str) -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7], + "b": [ + "some", + "text", + "over-here-is-very-long", + "and", + "some", + "more", + "text", + ], + } + ) + + # Ordered Unions lead to many phase transitions. + ref_df = pl.concat([df] * 100) + lf = pl.concat([df.lazy()] * 100) + + (getattr(lf, f"sink_{method}"))(tmp_path / f"t.{method}", engine="streaming") + df = (getattr(pl, f"scan_{method}"))(tmp_path / f"t.{method}").collect() + + assert_frame_equal(df, ref_df) + + (getattr(lf, f"sink_{method}"))( + tmp_path / f"t.{method}", maintain_order=False, engine="streaming" + ) + height = ( + (getattr(pl, f"scan_{method}"))(tmp_path / f"t.{method}") + .select(pl.len()) + .collect()[0, 0] + ) + assert height == ref_df.height + + +def test_empty_sink_parquet_join_14863(tmp_path: Path) -> None: + file_path = tmp_path / "empty.parquet" + lf = pl.LazyFrame(schema=["a", "b", "c"]).cast(pl.String) + lf.sink_parquet(file_path) + assert_frame_equal( + pl.LazyFrame({"a": ["uno"]}).join(pl.scan_parquet(file_path), on="a").collect(), + lf.collect(), + ) + + +@pytest.mark.write_disk +def test_scan_non_existent_file_21527() -> None: + with pytest.raises( + FileNotFoundError, + match=r"a-file-that-does-not-exist", + ): + pl.scan_parquet("a-file-that-does-not-exist").sink_ipc( + "x.ipc", engine="streaming" + ) diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py new file mode 100644 index 000000000000..01e251862b7e --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Literal + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + from polars._typing import JoinStrategy + +pytestmark = pytest.mark.xdist_group("streaming") + + +def test_streaming_full_outer_joins() -> None: + n = 100 + dfa = pl.DataFrame( + { + "a": np.random.randint(0, 40, n), + "idx": np.arange(0, n), + } + ) + + n = 100 + dfb = pl.DataFrame( + { + "a": np.random.randint(0, 40, n), + "idx": np.arange(0, n), + } + ) + + join_strategies: list[tuple[JoinStrategy, bool]] = [ + ("full", False), + ("full", True), + ] + for how, coalesce in join_strategies: + q = ( + dfa.lazy() + .join(dfb.lazy(), on="a", how=how, coalesce=coalesce) + .sort(["idx"]) + ) + a = q.collect(engine="streaming") + b = q.collect(engine="in-memory") + assert_frame_equal(a, b, check_row_order=False) + + +def test_streaming_joins() -> None: + n = 100 + dfa = pd.DataFrame( + { + "a": np.random.randint(0, 40, n), + "b": np.arange(0, n), + } + ) + + n = 100 + dfb = pd.DataFrame( + { + "a": np.random.randint(0, 40, n), + "b": np.arange(0, n), + } + ) + dfa_pl = pl.from_pandas(dfa).sort("a") + dfb_pl = pl.from_pandas(dfb) + + join_strategies: list[Literal["inner", "left"]] = ["inner", "left"] + for how in join_strategies: + pd_result = dfa.merge(dfb, on="a", how=how) + pd_result.columns = pd.Index(["a", "b", "b_right"]) + + pl_result = ( + dfa_pl.lazy() + .join(dfb_pl.lazy(), on="a", how=how) + .sort(["a", "b", "b_right"]) + .collect(engine="streaming") + ) + + a = ( + pl.from_pandas(pd_result) + .with_columns(pl.all().cast(int)) + .sort(["a", "b", "b_right"]) + ) + assert_frame_equal(a, pl_result, check_dtypes=False) + + pd_result = dfa.merge(dfb, on=["a", "b"], how=how) + + pl_result = ( + dfa_pl.lazy() + .join(dfb_pl.lazy(), on=["a", "b"], how=how) + .sort(["a", "b"]) + .collect(engine="streaming") + ) + + # we cast to integer because pandas joins creates floats + a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"]) + assert_frame_equal(a, pl_result, check_dtypes=False) + + +def test_streaming_cross_join_empty() -> None: + df1 = pl.LazyFrame(data={"col1": ["a"]}) + + df2 = pl.LazyFrame( + data={"col1": []}, + schema={"col1": str}, + ) + + out = df1.join(df2, how="cross").collect(engine="streaming") + assert out.shape == (0, 2) + assert out.columns == ["col1", "col1_right"] + + +def test_streaming_join_rechunk_12498() -> None: + rows = pl.int_range(0, 2) + + a = pl.select(A=rows).lazy() + b = pl.select(B=rows).lazy() + + q = a.join(b, how="cross") + assert q.collect(engine="streaming").sort(["B", "A"]).to_dict(as_series=False) == { + "A": [0, 1, 0, 1], + "B": [0, 0, 1, 1], + } + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_join_null_matches(streaming: bool) -> None: + # null values in joins should never find a match. + df_a = pl.LazyFrame( + { + "idx_a": [0, 1, 2], + "a": [None, 1, 2], + } + ) + + df_b = pl.LazyFrame( + { + "idx_b": [0, 1, 2, 3], + "a": [None, 2, 1, None], + } + ) + # Semi + assert df_a.join(df_b, on="a", how="semi", nulls_equal=True).collect( + engine="streaming" if streaming else "in-memory" + )["idx_a"].to_list() == [0, 1, 2] + assert df_a.join(df_b, on="a", how="semi", nulls_equal=False).collect( + engine="streaming" if streaming else "in-memory" + )["idx_a"].to_list() == [1, 2] + + # Inner + expected = pl.DataFrame({"idx_a": [2, 1], "a": [2, 1], "idx_b": [1, 2]}) + assert_frame_equal( + df_a.join(df_b, on="a", how="inner").collect( + engine="streaming" if streaming else "in-memory" + ), + expected, + check_row_order=False, + ) + + # Left outer + expected = pl.DataFrame( + {"idx_a": [0, 1, 2], "a": [None, 1, 2], "idx_b": [None, 2, 1]} + ) + assert_frame_equal( + df_a.join(df_b, on="a", how="left").collect( + engine="streaming" if streaming else "in-memory" + ), + expected, + check_row_order=False, + ) + # Full outer + expected = pl.DataFrame( + { + "idx_a": [None, 2, 1, None, 0], + "a": [None, 2, 1, None, None], + "idx_b": [0, 1, 2, 3, None], + "a_right": [None, 2, 1, None, None], + } + ) + assert_frame_equal( + df_a.join(df_b, on="a", how="full").collect(), expected, check_row_order=False + ) + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_join_null_matches_multiple_keys(streaming: bool) -> None: + df_a = pl.LazyFrame( + { + "a": [None, 1, 2], + "idx": [0, 1, 2], + } + ) + + df_b = pl.LazyFrame( + { + "a": [None, 2, 1, None, 1], + "idx": [0, 1, 2, 3, 1], + "c": [10, 20, 30, 40, 50], + } + ) + + expected = pl.DataFrame({"a": [1], "idx": [1], "c": [50]}) + assert_frame_equal( + df_a.join(df_b, on=["a", "idx"], how="inner").collect( + engine="streaming" if streaming else "in-memory" + ), + expected, + check_row_order=False, + ) + expected = pl.DataFrame( + {"a": [None, 1, 2], "idx": [0, 1, 2], "c": [None, 50, None]} + ) + assert_frame_equal( + df_a.join(df_b, on=["a", "idx"], how="left").collect( + engine="streaming" if streaming else "in-memory" + ), + expected, + check_row_order=False, + ) + + expected = pl.DataFrame( + { + "a": [None, None, None, None, None, 1, 2], + "idx": [None, None, None, None, 0, 1, 2], + "a_right": [None, 2, 1, None, None, 1, None], + "idx_right": [0, 1, 2, 3, None, 1, None], + "c": [10, 20, 30, 40, None, 50, None], + } + ) + assert_frame_equal( + df_a.join(df_b, on=["a", "idx"], how="full").sort("a").collect(), + expected, + check_row_order=False, + ) + + +def test_streaming_join_and_union() -> None: + a = pl.LazyFrame({"a": [1, 2]}) + + b = pl.LazyFrame({"a": [1, 2, 4, 8]}) + + c = a.join(b, on="a", maintain_order="left_right") + # The join node latest ensures that the dispatcher + # needs to replace placeholders in unions. + q = pl.concat([a, b, c]) + + out = q.collect(engine="streaming") + assert_frame_equal(out, q.collect(engine="in-memory")) + assert out.to_series().to_list() == [1, 2, 1, 2, 4, 8, 1, 2] + + +def test_non_coalescing_streaming_left_join() -> None: + df1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + df2 = pl.LazyFrame({"a": [1, 2], "c": ["j", "i"]}) + + q = df1.join(df2, on="a", how="left", coalesce=False) + assert q.explain(engine="old-streaming").startswith("STREAMING") # type: ignore[arg-type] + assert_frame_equal( + q.collect(engine="streaming"), + pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "a_right": [1, 2, None], + "c": ["j", "i", None], + } + ), + check_row_order=False, + ) + + +@pytest.mark.write_disk +def test_streaming_outer_join_partial_flush(tmp_path: Path) -> None: + data = { + "value_at": [datetime(2024, i + 1, 1) for i in range(6)], + "value": list(range(6)), + } + + parquet_path = tmp_path / "data.parquet" + pl.DataFrame(data=data).write_parquet(parquet_path) + + other_parquet_path = tmp_path / "data2.parquet" + pl.DataFrame(data=data).write_parquet(other_parquet_path) + + lf1 = pl.scan_parquet(other_parquet_path) + lf2 = pl.scan_parquet(parquet_path) + + join_cols = set(lf1.collect_schema()).intersection(set(lf2.collect_schema())) + final_lf = lf1.join(lf2, on=list(join_cols), how="full", coalesce=True) + + assert_frame_equal( + final_lf.collect(engine="streaming"), + pl.DataFrame( + { + "value_at": [ + datetime(2024, 1, 1, 0, 0), + datetime(2024, 2, 1, 0, 0), + datetime(2024, 3, 1, 0, 0), + datetime(2024, 4, 1, 0, 0), + datetime(2024, 5, 1, 0, 0), + datetime(2024, 6, 1, 0, 0), + ], + "value": [0, 1, 2, 3, 4, 5], + } + ), + check_row_order=False, + ) + + +def test_flush_join_and_operation_19040() -> None: + df_A = pl.LazyFrame({"K": [True, False], "A": [1, 1]}) + + df_B = pl.LazyFrame({"K": [True], "B": [1]}) + + df_C = pl.LazyFrame({"K": [True], "C": [1]}) + + q = ( + df_A.join(df_B, how="full", on=["K"], coalesce=True) + .join(df_C, how="full", on=["K"], coalesce=True) + .with_columns(B=pl.col("B")) + .sort("K") + ) + assert q.collect(engine="streaming").to_dict(as_series=False) == { + "K": [False, True], + "A": [1, 1], + "B": [None, 1], + "C": [None, 1], + } + + +def test_full_coalesce_join_and_rename_15583() -> None: + df1 = pl.LazyFrame({"a": [1, 2, 3]}) + df2 = pl.LazyFrame({"a": [3, 4, 5]}) + + result = ( + df1.join(df2, on="a", how="full", coalesce=True) + .select(pl.all().name.map(lambda c: c.upper())) + .sort("A") + .collect(engine="streaming") + ) + assert result["A"].to_list() == [1, 2, 3, 4, 5] + + +def test_invert_order_full_join_22295() -> None: + lf = pl.LazyFrame( + { + "value_at": [datetime(2024, i + 1, 1) for i in range(6)], + "value": list(range(6)), + } + ) + + lf.join(lf, on=["value", "value_at"], how="full", coalesce=True).collect( + engine="streaming" + ) diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py new file mode 100644 index 000000000000..42660fd80d31 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +from collections import Counter +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + +if TYPE_CHECKING: + from pathlib import Path + +pytestmark = pytest.mark.xdist_group("streaming") + + +def assert_df_sorted_by( + df: pl.DataFrame, + sort_df: pl.DataFrame, + cols: list[str], + descending: list[bool] | None = None, +) -> None: + if descending is None: + descending = [False] * len(cols) + + # Is sorted by the key columns? + keycols = sort_df[cols] + equal = keycols.head(-1) == keycols.tail(-1) + + # Tuple inequality. + # a0 < b0 || (a0 == b0 && (a1 < b1 || (a1 == b1 && ...)) + # Evaluating in reverse is easiest. + ordered = equal[cols[-1]] + for c, desc in zip(cols[::-1], descending[::-1]): + ordered &= equal[c] + if desc: + ordered |= keycols[c].head(-1) > keycols[c].tail(-1) + else: + ordered |= keycols[c].head(-1) < keycols[c].tail(-1) + + assert ordered.all() + + # Do all the rows still exist? + assert Counter(df.rows()) == Counter(sort_df.rows()) + + +def test_streaming_sort_multiple_columns_logical_types() -> None: + data = { + "foo": [3, 2, 1], + "bar": ["a", "b", "c"], + "baz": [ + datetime(2023, 5, 1, 15, 45), + datetime(2023, 5, 1, 13, 45), + datetime(2023, 5, 1, 14, 45), + ], + } + + result = pl.LazyFrame(data).sort("foo", "baz").collect(engine="streaming") + + expected = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": ["c", "b", "a"], + "baz": [ + datetime(2023, 5, 1, 14, 45), + datetime(2023, 5, 1, 13, 45), + datetime(2023, 5, 1, 15, 45), + ], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +@pytest.mark.slow +def test_ooc_sort(tmp_path: Path, monkeypatch: Any) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + + s = pl.arange(0, 100_000, eager=True).rename("idx") + + df = s.shuffle().to_frame() + + for descending in [True, False]: + out = ( + df.lazy().sort("idx", descending=descending).collect(engine="streaming") + ).to_series() + + assert_series_equal(out, s.sort(descending=descending)) + + +@pytest.mark.debug +@pytest.mark.write_disk +@pytest.mark.parametrize("spill_source", [True, False]) +def test_streaming_sort( + tmp_path: Path, monkeypatch: Any, capfd: Any, spill_source: bool +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + monkeypatch.setenv("POLARS_VERBOSE", "1") + if spill_source: + monkeypatch.setenv("POLARS_SPILL_SORT_PARTITIONS", "1") + # this creates a lot of duplicate partitions and triggers: #7568 + assert ( + pl.Series(np.random.randint(0, 100, 100)) + .to_frame("s") + .lazy() + .sort("s") + .collect(engine="old-streaming")["s"] # type: ignore[call-overload] + .is_sorted() + ) + (_, err) = capfd.readouterr() + assert "df -> sort" in err + if spill_source: + assert "PARTITIONED FORCE SPILLED" in err + + +@pytest.mark.write_disk +@pytest.mark.parametrize("spill_source", [True, False]) +def test_out_of_core_sort_9503( + tmp_path: Path, monkeypatch: Any, spill_source: bool +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + if spill_source: + monkeypatch.setenv("POLARS_SPILL_SORT_PARTITIONS", "1") + np.random.seed(0) + + num_rows = 100_000 + num_columns = 2 + num_tables = 10 + + # ensure we create many chunks + # this will ensure we create more files + # and that creates contention while dumping + df = pl.concat( + [ + pl.DataFrame( + [ + pl.Series(np.random.randint(0, 10000, size=num_rows)) + for _ in range(num_columns) + ] + ) + for _ in range(num_tables) + ], + rechunk=False, + ) + lf = df.lazy() + + result = lf.sort(df.columns).collect(engine="streaming") + + assert result.shape == (1_000_000, 2) + assert result["column_0"].flags["SORTED_ASC"] + assert result.head(20).to_dict(as_series=False) == { + "column_0": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "column_1": [ + 242, + 245, + 588, + 618, + 732, + 902, + 925, + 945, + 1009, + 1161, + 1352, + 1365, + 1451, + 1581, + 1778, + 1836, + 1976, + 2091, + 2120, + 2124, + ], + } + + +@pytest.mark.write_disk +@pytest.mark.slow +def test_streaming_sort_multiple_columns( + str_ints_df: pl.DataFrame, tmp_path: Path, monkeypatch: Any, capfd: Any +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + monkeypatch.setenv("POLARS_VERBOSE", "1") + df = str_ints_df + + out = df.lazy().sort(["strs", "vals"]).collect(engine="old-streaming") # type: ignore[call-overload] + assert_frame_equal(out, out.sort(["strs", "vals"])) + err = capfd.readouterr().err + assert "OOC sort forced" in err + assert "RUN STREAMING PIPELINE" in err + assert "df -> sort_multiple" in err + assert out.columns == ["vals", "strs"] + + +def test_streaming_sort_sorted_flag() -> None: + # empty + q = pl.LazyFrame( + schema={ + "store_id": pl.UInt16, + "item_id": pl.UInt32, + "timestamp": pl.Datetime, + } + ).sort("timestamp") + + assert q.collect(engine="streaming")["timestamp"].flags["SORTED_ASC"] + + +@pytest.mark.parametrize( + ("sort_by"), + [ + ["fats_g", "category"], + ["fats_g", "category", "calories"], + ["fats_g", "category", "calories", "sugars_g"], + ], +) +def test_streaming_sort_varying_order_and_dtypes( + io_files_path: Path, sort_by: list[str] +) -> None: + q = pl.scan_parquet(io_files_path / "foods*.parquet") + df = q.collect() + assert_df_sorted_by(df, q.sort(sort_by).collect(engine="streaming"), sort_by) + assert_df_sorted_by(df, q.sort(sort_by).collect(engine="in-memory"), sort_by) + + +def test_streaming_sort_fixed_reverse() -> None: + df = pl.DataFrame( + { + "a": [1, 1, 2, 1, 2, 4, 1, 7], + "b": [1, 2, 2, 1, 2, 4, 8, 7], + } + ) + descending = [True, False] + q = df.lazy().sort(by=["a", "b"], descending=descending) + + assert_df_sorted_by( + df, q.collect(engine="streaming"), ["a", "b"], descending=descending + ) + assert_df_sorted_by( + df, q.collect(engine="in-memory"), ["a", "b"], descending=descending + ) + + +def test_reverse_variable_sort_13573() -> None: + df = pl.DataFrame( + { + "a": ["one", "two", "three"], + "b": ["four", "five", "six"], + } + ).lazy() + assert df.sort("a", "b", descending=[True, False]).collect( + engine="streaming" + ).to_dict(as_series=False) == { + "a": ["two", "three", "one"], + "b": ["five", "six", "four"], + } + + +def test_nulls_last_streaming_sort() -> None: + assert pl.LazyFrame({"x": [1, None]}).sort("x", nulls_last=True).collect( + engine="streaming" + ).to_dict(as_series=False) == {"x": [1, None]} + + +@pytest.mark.parametrize("descending", [True, False]) +@pytest.mark.parametrize("nulls_last", [True, False]) +def test_sort_descending_nulls_last(descending: bool, nulls_last: bool) -> None: + df = pl.DataFrame({"x": [1, 3, None, 2, None], "y": [1, 3, 0, 2, 0]}) + + null_sentinel = 100 if descending ^ nulls_last else -100 + ref_x = [1, 3, None, 2, None] + ref_x.sort(key=lambda k: null_sentinel if k is None else k, reverse=descending) + ref_y = [1, 3, 0, 2, 0] + ref_y.sort(key=lambda k: null_sentinel if k == 0 else k, reverse=descending) + + assert_frame_equal( + df.lazy() + .sort("x", descending=descending, nulls_last=nulls_last) + .collect(engine="streaming"), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) + + assert_frame_equal( + df.lazy() + .sort(["x", "y"], descending=descending, nulls_last=nulls_last) + .collect(engine="streaming"), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) diff --git a/py-polars/tests/unit/streaming/test_streaming_unique.py b/py-polars/tests/unit/streaming/test_streaming_unique.py new file mode 100644 index 000000000000..0f2769030983 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_unique.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + +pytestmark = pytest.mark.xdist_group("streaming") + + +@pytest.mark.write_disk +@pytest.mark.slow +def test_streaming_out_of_core_unique( + io_files_path: Path, tmp_path: Path, monkeypatch: Any, capfd: Any +) -> None: + tmp_path.mkdir(exist_ok=True) + monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) + monkeypatch.setenv("POLARS_FORCE_OOC", "1") + monkeypatch.setenv("POLARS_VERBOSE", "1") + monkeypatch.setenv("POLARS_STREAMING_GROUPBY_SPILL_SIZE", "256") + df = pl.read_csv(io_files_path / "foods*.csv") + # this creates 10M rows + q = df.lazy() + q = q.join(q, how="cross").select(df.columns).head(10_000) + + # uses out-of-core unique + df1 = q.join(q.head(1000), how="cross").unique().collect(engine="streaming") + # this ensures the cross join gives equal result but uses the in-memory unique + df2 = q.join(q.head(1000), how="cross").collect(engine="streaming").unique() + assert df1.shape == df2.shape + + # TODO: Re-enable this check when this issue is fixed: https://github.com/pola-rs/polars/issues/10466 + _ = capfd.readouterr().err + # assert "OOC group_by started" in err + + +@pytest.mark.may_fail_auto_streaming +def test_streaming_unique(monkeypatch: Any, capfd: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + df = pl.DataFrame({"a": [1, 2, 2, 2], "b": [3, 4, 4, 4], "c": [5, 6, 7, 7]}) + q = df.lazy().unique(subset=["a", "c"], maintain_order=False).sort(["a", "b", "c"]) + assert_frame_equal(q.collect(engine="old-streaming"), q.collect(engine="in-memory")) # type: ignore[call-overload] + + q = df.lazy().unique(subset=["b", "c"], maintain_order=False).sort(["a", "b", "c"]) + assert_frame_equal(q.collect(engine="old-streaming"), q.collect(engine="in-memory")) # type: ignore[call-overload] + + q = df.lazy().unique(subset=None, maintain_order=False).sort(["a", "b", "c"]) + assert_frame_equal(q.collect(engine="old-streaming"), q.collect(engine="in-memory")) # type: ignore[call-overload] + (_, err) = capfd.readouterr() + assert "df -> re-project-sink -> sort_multiple" in err diff --git a/py-polars/tests/unit/test_api.py b/py-polars/tests/unit/test_api.py new file mode 100644 index 000000000000..4b37f47c38b4 --- /dev/null +++ b/py-polars/tests/unit/test_api.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_custom_df_namespace() -> None: + @pl.api.register_dataframe_namespace("split") + class SplitFrame: + def __init__(self, df: pl.DataFrame) -> None: + self._df = df + + def by_first_letter_of_column_names(self) -> list[pl.DataFrame]: + return [ + self._df.select([col for col in self._df.columns if col[0] == f]) + for f in sorted({col[0] for col in self._df.columns}) + ] + + def by_first_letter_of_column_values(self, col: str) -> list[pl.DataFrame]: + return [ + self._df.filter(pl.col(col).str.starts_with(c)) + for c in sorted(set(df.select(pl.col(col).str.slice(0, 1)).to_series())) + ] + + df = pl.DataFrame( + data=[["xx", 2, 3, 4], ["xy", 4, 5, 6], ["yy", 5, 6, 7], ["yz", 6, 7, 8]], + schema=["a1", "a2", "b1", "b2"], + orient="row", + ) + + dfs = df.split.by_first_letter_of_column_names() # type: ignore[attr-defined] + assert [d.rows() for d in dfs] == [ + [("xx", 2), ("xy", 4), ("yy", 5), ("yz", 6)], + [(3, 4), (5, 6), (6, 7), (7, 8)], + ] + dfs = df.split.by_first_letter_of_column_values("a1") # type: ignore[attr-defined] + assert [d.rows() for d in dfs] == [ + [("xx", 2, 3, 4), ("xy", 4, 5, 6)], + [("yy", 5, 6, 7), ("yz", 6, 7, 8)], + ] + + +def test_custom_expr_namespace() -> None: + @pl.api.register_expr_namespace("power") + class PowersOfN: + def __init__(self, expr: pl.Expr) -> None: + self._expr = expr + + def next(self, p: int) -> pl.Expr: + return (p ** (self._expr.log(p).ceil()).cast(pl.Int64)).cast(pl.Int64) + + def previous(self, p: int) -> pl.Expr: + return (p ** (self._expr.log(p).floor()).cast(pl.Int64)).cast(pl.Int64) + + def nearest(self, p: int) -> pl.Expr: + return (p ** (self._expr.log(p)).round(0).cast(pl.Int64)).cast(pl.Int64) + + df = pl.DataFrame([1.4, 24.3, 55.0, 64.001], schema=["n"]) + assert df.select( + pl.col("n"), + pl.col("n").power.next(p=2).alias("next_pow2"), # type: ignore[attr-defined] + pl.col("n").power.previous(p=2).alias("prev_pow2"), # type: ignore[attr-defined] + pl.col("n").power.nearest(p=2).alias("nearest_pow2"), # type: ignore[attr-defined] + ).rows() == [ + (1.4, 2, 1, 1), + (24.3, 32, 16, 32), + (55.0, 64, 32, 64), + (64.001, 128, 64, 64), + ] + + +def test_custom_lazy_namespace() -> None: + @pl.api.register_lazyframe_namespace("split") + class SplitFrame: + def __init__(self, lf: pl.LazyFrame) -> None: + self._lf = lf + + def by_column_dtypes(self) -> list[pl.LazyFrame]: + return [ + self._lf.select(pl.col(tp)) + for tp in dict.fromkeys(self._lf.collect_schema().dtypes()) + ] + + ldf = pl.DataFrame( + data=[["xx", 2, 3, 4], ["xy", 4, 5, 6], ["yy", 5, 6, 7], ["yz", 6, 7, 8]], + schema=["a1", "a2", "b1", "b2"], + orient="row", + ).lazy() + + df1, df2 = (d.collect() for d in ldf.split.by_column_dtypes()) # type: ignore[attr-defined] + assert_frame_equal( + df1, + pl.DataFrame([("xx",), ("xy",), ("yy",), ("yz",)], schema=["a1"], orient="row"), + ) + assert_frame_equal( + df2, + pl.DataFrame( + [(2, 3, 4), (4, 5, 6), (5, 6, 7), (6, 7, 8)], + schema=["a2", "b1", "b2"], + orient="row", + ), + ) + + +def test_custom_series_namespace() -> None: + @pl.api.register_series_namespace("math") + class CustomMath: + def __init__(self, s: pl.Series) -> None: + self._s = s + + def square(self) -> pl.Series: + return self._s * self._s + + s = pl.Series("n", [1.5, 31.0, 42.0, 64.5]) + assert s.math.square().to_list() == [ # type: ignore[attr-defined] + 2.25, + 961.0, + 1764.0, + 4160.25, + ] + + +@pytest.mark.slow +@pytest.mark.parametrize("pcls", [pl.Expr, pl.DataFrame, pl.LazyFrame, pl.Series]) +def test_class_namespaces_are_registered(pcls: Any) -> None: + # confirm that existing (and new) namespaces + # have been added to that class's "_accessors" attr + namespaces: set[str] = getattr(pcls, "_accessors", set()) + for name in dir(pcls): + if not name.startswith("_"): + attr = getattr(pcls, name) + if isinstance(attr, property): + try: + obj = attr.fget(pcls) # type: ignore[misc] + except Exception: + continue + + if obj.__class__.__name__.endswith("NameSpace"): + ns = obj._accessor + assert ns in namespaces, ( + f"{ns!r} should be registered in {pcls.__name__}._accessors" + ) + + +def test_namespace_cannot_override_builtin() -> None: + with pytest.raises(AttributeError): + + @pl.api.register_dataframe_namespace("dt") + class CustomDt: + def __init__(self, df: pl.DataFrame) -> None: + self._df = df + + +def test_namespace_warning_on_override() -> None: + @pl.api.register_dataframe_namespace("math") + class CustomMath: + def __init__(self, df: pl.DataFrame) -> None: + self._df = df + + with pytest.raises(UserWarning): + + @pl.api.register_dataframe_namespace("math") + class CustomMath2: + def __init__(self, df: pl.DataFrame) -> None: + self._df = df diff --git a/py-polars/tests/unit/test_arity.py b/py-polars/tests/unit/test_arity.py new file mode 100644 index 000000000000..099a760f9fe0 --- /dev/null +++ b/py-polars/tests/unit/test_arity.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_expression_literal_series_order() -> None: + s = pl.Series([1, 2, 3]) + df = pl.DataFrame({"a": [1, 2, 3]}) + + result = df.select(pl.col("a") + s) + expected = pl.DataFrame({"a": [2, 4, 6]}) + assert_frame_equal(result, expected) + + result = df.select(pl.lit(s) + pl.col("a")) + expected = pl.DataFrame({"": [2, 4, 6]}) + assert_frame_equal(result, expected) + + +def test_when_then_broadcast_nulls_12665() -> None: + df = pl.DataFrame( + { + "val": [1, 2, 3, 4], + "threshold": [4, None, None, 1], + } + ) + + assert df.select( + when=pl.when(pl.col("val") > pl.col("threshold")).then(1).otherwise(0), + ).to_dict(as_series=False) == {"when": [0, 0, 0, 1]} + + +@pytest.mark.parametrize( + ("needs_broadcast", "expect_contains"), + [ + (pl.lit("a"), [True, False, False]), + (pl.col("name").head(1), [True, False, False]), + (pl.lit(None, dtype=pl.String), [None, None, None]), + (pl.col("null_utf8").head(1), [None, None, None]), + ], +) +@pytest.mark.parametrize("literal", [True, False]) +@pytest.mark.parametrize( + "df", + [ + pl.DataFrame( + { + "name": ["a", "b", "c"], + "null_utf8": pl.Series([None, None, None], dtype=pl.String), + } + ) + ], +) +def test_broadcast_string_ops_12632( + df: pl.DataFrame, + needs_broadcast: pl.Expr, + expect_contains: list[bool], + literal: bool, +) -> None: + assert ( + df.select(needs_broadcast.str.contains(pl.col("name"), literal=literal)) + .to_series() + .to_list() + == expect_contains + ) + + assert ( + df.select(needs_broadcast.str.starts_with(pl.col("name"))).to_series().to_list() + == expect_contains + ) + + assert ( + df.select(needs_broadcast.str.ends_with(pl.col("name"))).to_series().to_list() + == expect_contains + ) + + assert df.select(needs_broadcast.str.strip_chars(pl.col("name"))).height == 3 + assert df.select(needs_broadcast.str.strip_chars_start(pl.col("name"))).height == 3 + assert df.select(needs_broadcast.str.strip_chars_end(pl.col("name"))).height == 3 + + +def test_negate_inlined_14278() -> None: + df = pl.DataFrame( + {"group": ["A", "A", "B", "B", "B", "C", "C"], "value": [1, 2, 3, 4, 5, 6, 7]} + ) + + agg_expr = [ + pl.struct("group", "value").tail(2).alias("list"), + pl.col("value").sort().tail(2).count().alias("count"), + ] + + q = df.lazy().group_by("group").agg(agg_expr) + assert q.collect().sort("group").to_dict(as_series=False) == { + "group": ["A", "B", "C"], + "list": [ + [{"group": "A", "value": 1}, {"group": "A", "value": 2}], + [{"group": "B", "value": 4}, {"group": "B", "value": 5}], + [{"group": "C", "value": 6}, {"group": "C", "value": 7}], + ], + "count": [2, 2, 2], + } + + +def test_nested_level_literals_17377() -> None: + df = pl.LazyFrame({"group": [1, 2], "value": [1, 2]}) + + df2 = df.group_by("group").agg( + [ + pl.when((pl.col("value") < 0).all()) + .then(None) + .otherwise(pl.col("value").mean()) + .alias("res") + ] + ) + + assert df2.collect_schema() == pl.Schema({"group": pl.Int64(), "res": pl.Float64()}) diff --git a/py-polars/tests/unit/test_async.py b/py-polars/tests/unit/test_async.py new file mode 100644 index 000000000000..2f13771c4dad --- /dev/null +++ b/py-polars/tests/unit/test_async.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import asyncio +import sys +import time +from functools import partial +from typing import Any, Callable + +import pytest + +import polars as pl +from polars.dependencies import gevent +from polars.exceptions import ColumnNotFoundError + +pytestmark = pytest.mark.slow() + + +async def _aio_collect_async(raises: bool = False) -> pl.DataFrame: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + return await lf.collect_async() + + +async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + + lf2 = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum() + + return await pl.collect_all_async([lf, lf2]) + + +_aio_collect = pytest.mark.parametrize( + ("collect", "raises"), + [ + (_aio_collect_async, None), + (_aio_collect_all_async, None), + (partial(_aio_collect_async, True), ColumnNotFoundError), + (partial(_aio_collect_all_async, True), ColumnNotFoundError), + ], +) + + +def _aio_run(coroutine: Any, raises: Exception | None = None) -> None: + if raises is not None: + with pytest.raises(raises): # type: ignore[call-overload] + asyncio.run(coroutine) + else: + assert len(asyncio.run(coroutine)) > 0 + + +@_aio_collect +def test_collect_async_switch( + collect: Callable[[], Any], + raises: Exception | None, +) -> None: + async def main() -> Any: + df = collect() + await asyncio.sleep(0.3) + return await df + + _aio_run(main(), raises) + + +@_aio_collect +def test_collect_async_task( + collect: Callable[[], Any], raises: Exception | None +) -> None: + async def main() -> Any: + df = asyncio.create_task(collect()) + await asyncio.sleep(0.3) + return await df + + _aio_run(main(), raises) + + +def _gevent_collect_async(raises: bool = False) -> Any: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + return lf.collect_async(gevent=True) + + +def _gevent_collect_all_async(raises: bool = False) -> Any: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + return pl.collect_all_async([lf], gevent=True) + + +_gevent_collect = pytest.mark.parametrize( + ("get_result", "raises"), + [ + (_gevent_collect_async, None), + (_gevent_collect_all_async, None), + (partial(_gevent_collect_async, True), ColumnNotFoundError), + (partial(_gevent_collect_all_async, True), ColumnNotFoundError), + ], +) + + +def _gevent_run(callback: Callable[[], Any], raises: Exception | None = None) -> None: + if raises is not None: + with pytest.raises(raises): # type: ignore[call-overload] + callback() + else: + assert len(callback()) > 0 + + +@_gevent_collect +def test_gevent_collect_async_without_hub( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + return get_result().get() + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_with_hub( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + _hub = gevent.get_hub() + + def main() -> Any: + return get_result().get() + + _gevent_run(main, raises) + + +@pytest.mark.skipif(sys.platform == "win32", reason="May time out on Windows") +@_gevent_collect +def test_gevent_collect_async_switch( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + result = get_result() + gevent.sleep(0.1) + return result.get(block=False, timeout=3) + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_no_switch( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + result = get_result() + time.sleep(1) + return result.get(block=False, timeout=None) + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_spawn( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + result_greenlet = gevent.spawn(get_result) + gevent.spawn(gevent.sleep, 0.1) + return result_greenlet.get().get() + + _gevent_run(main, raises) diff --git a/py-polars/tests/unit/test_chunks.py b/py-polars/tests/unit/test_chunks.py new file mode 100644 index 000000000000..63e7a327f003 --- /dev/null +++ b/py-polars/tests/unit/test_chunks.py @@ -0,0 +1,16 @@ +import numpy as np + +import polars as pl + + +def test_chunks_align_16830() -> None: + n = 2 + df = pl.DataFrame( + {"index_1": np.repeat(np.arange(10), n), "index_2": np.repeat(np.arange(10), n)} + ) + df = pl.concat([df[0:10], df[10:]], rechunk=False) + df = df.filter(df["index_1"] == 0) # filter chunks + df = df.with_columns( + index_2=pl.Series(values=[0] * n) + ) # set a chunk of different size + df.set_sorted("index_2") # triggers `select_chunk`. diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py new file mode 100644 index 000000000000..cf3f00b07cfb --- /dev/null +++ b/py-polars/tests/unit/test_config.py @@ -0,0 +1,1010 @@ +from __future__ import annotations + +import os +from pathlib import Path +from textwrap import dedent +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +import polars.polars as plr +from polars._utils.unstable import issue_unstable_warning +from polars.config import _POLARS_CFG_ENV_VARS + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@pytest.fixture(autouse=True) +def _environ() -> Iterator[None]: + """Fixture to restore the environment after/during tests.""" + with pl.StringCache(), pl.Config(restore_defaults=True): + yield + + +def test_ascii_tables() -> None: + df = pl.DataFrame( + { + "a": [1, 2], + "b": [4, 5], + "c": [[list(range(1, 26))], [list(range(1, 76))]], + } + ) + + ascii_table_repr = ( + "shape: (2, 3)\n" + "+-----+-----+------------------+\n" + "| a | b | c |\n" + "| --- | --- | --- |\n" + "| i64 | i64 | list[list[i64]] |\n" + "+==============================+\n" + "| 1 | 4 | [[1, 2, ... 25]] |\n" + "| 2 | 5 | [[1, 2, ... 75]] |\n" + "+-----+-----+------------------+" + ) + # note: expect to render ascii only within the given scope + with pl.Config(set_ascii_tables=True): + assert repr(df) == ascii_table_repr + + # confirm back to utf8 default after scope-exit + assert ( + repr(df) == "shape: (2, 3)\n" + "┌─────┬─────┬─────────────────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ list[list[i64]] │\n" + "╞═════╪═════╪═════════════════╡\n" + "│ 1 ┆ 4 ┆ [[1, 2, … 25]] │\n" + "│ 2 ┆ 5 ┆ [[1, 2, … 75]] │\n" + "└─────┴─────┴─────────────────┘" + ) + + @pl.Config(set_ascii_tables=True) + def ascii_table() -> str: + return repr(df) + + assert ascii_table() == ascii_table_repr + + +def test_hide_header_elements() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + + pl.Config.set_tbl_hide_column_data_types(True) + assert ( + str(df) == "shape: (3, 3)\n" + "┌───┬───┬───┐\n" + "│ a ┆ b ┆ c │\n" + "╞═══╪═══╪═══╡\n" + "│ 1 ┆ 4 ┆ 7 │\n" + "│ 2 ┆ 5 ┆ 8 │\n" + "│ 3 ┆ 6 ┆ 9 │\n" + "└───┴───┴───┘" + ) + + pl.Config.set_tbl_hide_column_data_types(False).set_tbl_hide_column_names(True) + assert ( + str(df) == "shape: (3, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 4 ┆ 7 │\n" + "│ 2 ┆ 5 ┆ 8 │\n" + "│ 3 ┆ 6 ┆ 9 │\n" + "└─────┴─────┴─────┘" + ) + + +def test_set_tbl_cols() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + + pl.Config.set_tbl_cols(1) + assert str(df).split("\n")[2] == "│ a ┆ … │" + pl.Config.set_tbl_cols(2) + assert str(df).split("\n")[2] == "│ a ┆ … ┆ c │" + pl.Config.set_tbl_cols(3) + assert str(df).split("\n")[2] == "│ a ┆ b ┆ c │" + + df = pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12]} + ) + pl.Config.set_tbl_cols(2) + assert str(df).split("\n")[2] == "│ a ┆ … ┆ d │" + pl.Config.set_tbl_cols(3) + assert str(df).split("\n")[2] == "│ a ┆ b ┆ … ┆ d │" + pl.Config.set_tbl_cols(-1) + assert str(df).split("\n")[2] == "│ a ┆ b ┆ c ┆ d │" + + +def test_set_tbl_rows() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "c": [9, 10, 11, 12]}) + ser = pl.Series("ser", [1, 2, 3, 4, 5]) + + pl.Config.set_tbl_rows(0) + assert ( + str(df) == "shape: (4, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ … ┆ … ┆ … │\n" + "└─────┴─────┴─────┘" + ) + assert str(ser) == "shape: (5,)\nSeries: 'ser' [i64]\n[\n\t…\n]" + + pl.Config.set_tbl_rows(1) + assert ( + str(df) == "shape: (4, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 5 ┆ 9 │\n" + "│ … ┆ … ┆ … │\n" + "└─────┴─────┴─────┘" + ) + assert str(ser) == "shape: (5,)\nSeries: 'ser' [i64]\n[\n\t1\n\t…\n]" + + pl.Config.set_tbl_rows(2) + assert ( + str(df) == "shape: (4, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 5 ┆ 9 │\n" + "│ … ┆ … ┆ … │\n" + "│ 4 ┆ 8 ┆ 12 │\n" + "└─────┴─────┴─────┘" + ) + assert str(ser) == "shape: (5,)\nSeries: 'ser' [i64]\n[\n\t1\n\t…\n\t5\n]" + + pl.Config.set_tbl_rows(3) + assert ( + str(df) == "shape: (4, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 5 ┆ 9 │\n" + "│ 2 ┆ 6 ┆ 10 │\n" + "│ … ┆ … ┆ … │\n" + "│ 4 ┆ 8 ┆ 12 │\n" + "└─────┴─────┴─────┘" + ) + assert str(ser) == "shape: (5,)\nSeries: 'ser' [i64]\n[\n\t1\n\t2\n\t…\n\t5\n]" + + pl.Config.set_tbl_rows(4) + assert ( + str(df) == "shape: (4, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 5 ┆ 9 │\n" + "│ 2 ┆ 6 ┆ 10 │\n" + "│ 3 ┆ 7 ┆ 11 │\n" + "│ 4 ┆ 8 ┆ 12 │\n" + "└─────┴─────┴─────┘" + ) + assert str(ser) == "shape: (5,)\nSeries: 'ser' [i64]\n[\n\t1\n\t2\n\t…\n\t4\n\t5\n]" + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [6, 7, 8, 9, 10], + "c": [11, 12, 13, 14, 15], + } + ) + + pl.Config.set_tbl_rows(3) + assert ( + str(df) == "shape: (5, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 6 ┆ 11 │\n" + "│ 2 ┆ 7 ┆ 12 │\n" + "│ … ┆ … ┆ … │\n" + "│ 5 ┆ 10 ┆ 15 │\n" + "└─────┴─────┴─────┘" + ) + + pl.Config.set_tbl_rows(-1) + assert str(ser) == "shape: (5,)\nSeries: 'ser' [i64]\n[\n\t1\n\t2\n\t3\n\t4\n\t5\n]" + + pl.Config.set_tbl_hide_dtype_separator(True) + assert ( + str(df) == "shape: (5, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 6 ┆ 11 │\n" + "│ 2 ┆ 7 ┆ 12 │\n" + "│ 3 ┆ 8 ┆ 13 │\n" + "│ 4 ┆ 9 ┆ 14 │\n" + "│ 5 ┆ 10 ┆ 15 │\n" + "└─────┴─────┴─────┘" + ) + + +def test_set_tbl_formats() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + pl.Config().set_tbl_formatting("ASCII_MARKDOWN") + assert str(df) == ( + "shape: (3, 3)\n" + "| foo | bar | ham |\n" + "| --- | --- | --- |\n" + "| i64 | f64 | str |\n" + "|-----|-----|-----|\n" + "| 1 | 6.0 | a |\n" + "| 2 | 7.0 | b |\n" + "| 3 | 8.0 | c |" + ) + + pl.Config().set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED") + with pl.Config(tbl_hide_dtype_separator=True): + assert str(df) == ( + "shape: (3, 3)\n" + "+-----------------+\n" + "| foo bar ham |\n" + "| i64 f64 str |\n" + "+=================+\n" + "| 1 6.0 a |\n" + "| 2 7.0 b |\n" + "| 3 8.0 c |\n" + "+-----------------+" + ) + + # temporarily scope "nothing" style, with no data types + with pl.Config( + tbl_formatting="NOTHING", + tbl_hide_column_data_types=True, + ): + assert str(df) == ( + "shape: (3, 3)\n" + " foo bar ham \n" + " 1 6.0 a \n" + " 2 7.0 b \n" + " 3 8.0 c " + ) + + # after scope, expect previous style + assert str(df) == ( + "shape: (3, 3)\n" + "+-----------------+\n" + "| foo bar ham |\n" + "| --- --- --- |\n" + "| i64 f64 str |\n" + "+=================+\n" + "| 1 6.0 a |\n" + "| 2 7.0 b |\n" + "| 3 8.0 c |\n" + "+-----------------+" + ) + + # invalid style + with pytest.raises(ValueError, match="invalid table format name: 'NOPE'"): + pl.Config().set_tbl_formatting("NOPE") # type: ignore[arg-type] + + +def test_set_tbl_width_chars() -> None: + df = pl.DataFrame( + { + "a really long col": [1, 2, 3], + "b": ["", "this is a string value that will be truncated", None], + "this is 10": [4, 5, 6], + } + ) + assert max(len(line) for line in str(df).split("\n")) == 68 + + pl.Config.set_tbl_width_chars(60) + assert max(len(line) for line in str(df).split("\n")) == 60 + + # force minimal table size (will hard-wrap everything; "don't try this at home" :p) + pl.Config.set_tbl_width_chars(0) + assert max(len(line) for line in str(df).split("\n")) == 19 + + # this check helps to check that column width bucketing + # is exact; no extraneous character allocation + df = pl.DataFrame( + { + "A": [1, 2, 3, 4, 5], + "fruits": ["banana", "banana", "apple", "apple", "banana"], + "B": [5, 4, 3, 2, 1], + "cars": ["beetle", "audi", "beetle", "beetle", "beetle"], + }, + schema_overrides={"A": pl.Int64, "B": pl.Int64}, + ).select(pl.all(), pl.all().name.suffix("_suffix!")) + + with pl.Config(tbl_width_chars=87): + assert max(len(line) for line in str(df).split("\n")) == 87 + + # check that -1 is interpreted as no limit + df = pl.DataFrame({str(i): ["a" * 25] for i in range(5)}) + for tbl_width_chars, expected_width in [ + (None, 100), + (-1, 141), + ]: + with pl.Config(tbl_width_chars=tbl_width_chars): + assert max(len(line) for line in str(df).split("\n")) == expected_width + + +def test_shape_below_table_and_inlined_dtype() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + + pl.Config.set_tbl_column_data_type_inline(True).set_tbl_dataframe_shape_below(True) + pl.Config.set_tbl_formatting("UTF8_FULL", rounded_corners=True) + assert ( + str(df) == "" + "╭─────────┬─────────┬─────────╮\n" + "│ a (i64) ┆ b (i64) ┆ c (i64) │\n" + "╞═════════╪═════════╪═════════╡\n" + "│ 1 ┆ 3 ┆ 5 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 2 ┆ 4 ┆ 6 │\n" + "╰─────────┴─────────┴─────────╯\n" + "shape: (2, 3)" + ) + + pl.Config.set_tbl_dataframe_shape_below(False) + assert ( + str(df) == "shape: (2, 3)\n" + "╭─────────┬─────────┬─────────╮\n" + "│ a (i64) ┆ b (i64) ┆ c (i64) │\n" + "╞═════════╪═════════╪═════════╡\n" + "│ 1 ┆ 3 ┆ 5 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 2 ┆ 4 ┆ 6 │\n" + "╰─────────┴─────────┴─────────╯" + ) + ( + pl.Config.set_tbl_formatting(None, rounded_corners=False) + .set_tbl_column_data_type_inline(False) + .set_tbl_cell_alignment("RIGHT") + ) + assert ( + str(df) == "shape: (2, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 3 ┆ 5 │\n" + "│ 2 ┆ 4 ┆ 6 │\n" + "└─────┴─────┴─────┘" + ) + with pytest.raises(ValueError): + pl.Config.set_tbl_cell_alignment("INVALID") # type: ignore[arg-type] + + +def test_shape_format_for_big_numbers() -> None: + df = pl.DataFrame({"a": range(1, 1001), "b": range(1001, 1001 + 1000)}) + + pl.Config.set_tbl_column_data_type_inline(True).set_tbl_dataframe_shape_below(True) + pl.Config.set_tbl_formatting("UTF8_FULL", rounded_corners=True) + assert ( + str(df) == "" + "╭─────────┬─────────╮\n" + "│ a (i64) ┆ b (i64) │\n" + "╞═════════╪═════════╡\n" + "│ 1 ┆ 1001 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 2 ┆ 1002 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 3 ┆ 1003 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 4 ┆ 1004 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 5 ┆ 1005 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ … ┆ … │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 996 ┆ 1996 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 997 ┆ 1997 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 998 ┆ 1998 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 999 ┆ 1999 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 1000 ┆ 2000 │\n" + "╰─────────┴─────────╯\n" + "shape: (1_000, 2)" + ) + + pl.Config.set_tbl_column_data_type_inline(True).set_tbl_dataframe_shape_below(False) + assert ( + str(df) == "shape: (1_000, 2)\n" + "╭─────────┬─────────╮\n" + "│ a (i64) ┆ b (i64) │\n" + "╞═════════╪═════════╡\n" + "│ 1 ┆ 1001 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 2 ┆ 1002 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 3 ┆ 1003 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 4 ┆ 1004 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 5 ┆ 1005 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ … ┆ … │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 996 ┆ 1996 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 997 ┆ 1997 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 998 ┆ 1998 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 999 ┆ 1999 │\n" + "├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤\n" + "│ 1000 ┆ 2000 │\n" + "╰─────────┴─────────╯" + ) + + pl.Config.set_tbl_rows(0) + ser = pl.Series("ser", range(1000)) + assert str(ser) == "shape: (1_000,)\nSeries: 'ser' [i64]\n[\n\t…\n]" + + pl.Config.set_tbl_rows(1) + pl.Config.set_tbl_cols(1) + df = pl.DataFrame({str(col_num): 1 for col_num in range(1000)}) + + assert ( + str(df) == "shape: (1, 1_000)\n" + "╭─────────┬───╮\n" + "│ 0 (i64) ┆ … │\n" + "╞═════════╪═══╡\n" + "│ 1 ┆ … │\n" + "╰─────────┴───╯" + ) + + pl.Config.set_tbl_formatting("ASCII_FULL_CONDENSED") + assert ( + str(df) == "shape: (1, 1_000)\n" + "+---------+-----+\n" + "| 0 (i64) | ... |\n" + "+===============+\n" + "| 1 | ... |\n" + "+---------+-----+" + ) + + +def test_numeric_right_alignment() -> None: + pl.Config.set_tbl_cell_numeric_alignment("RIGHT") + + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + assert ( + str(df) == "shape: (3, 3)\n" + "┌─────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ i64 ┆ i64 ┆ i64 │\n" + "╞═════╪═════╪═════╡\n" + "│ 1 ┆ 4 ┆ 7 │\n" + "│ 2 ┆ 5 ┆ 8 │\n" + "│ 3 ┆ 6 ┆ 9 │\n" + "└─────┴─────┴─────┘" + ) + + df = pl.DataFrame( + {"a": [1.1, 2.22, 3.333], "b": [4.0, 5.0, 6.0], "c": [7.0, 8.0, 9.0]} + ) + with pl.Config(): + pl.Config.set_fmt_float("full") + assert ( + str(df) == "shape: (3, 3)\n" + "┌───────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ f64 ┆ f64 ┆ f64 │\n" + "╞═══════╪═════╪═════╡\n" + "│ 1.1 ┆ 4 ┆ 7 │\n" + "│ 2.22 ┆ 5 ┆ 8 │\n" + "│ 3.333 ┆ 6 ┆ 9 │\n" + "└───────┴─────┴─────┘" + ) + + with pl.Config(fmt_float="mixed"): + assert ( + str(df) == "shape: (3, 3)\n" + "┌───────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ f64 ┆ f64 ┆ f64 │\n" + "╞═══════╪═════╪═════╡\n" + "│ 1.1 ┆ 4.0 ┆ 7.0 │\n" + "│ 2.22 ┆ 5.0 ┆ 8.0 │\n" + "│ 3.333 ┆ 6.0 ┆ 9.0 │\n" + "└───────┴─────┴─────┘" + ) + + with pl.Config(float_precision=6): + assert str(df) == ( + "shape: (3, 3)\n" + "┌──────────┬──────────┬──────────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ f64 ┆ f64 ┆ f64 │\n" + "╞══════════╪══════════╪══════════╡\n" + "│ 1.100000 ┆ 4.000000 ┆ 7.000000 │\n" + "│ 2.220000 ┆ 5.000000 ┆ 8.000000 │\n" + "│ 3.333000 ┆ 6.000000 ┆ 9.000000 │\n" + "└──────────┴──────────┴──────────┘" + ) + with pl.Config(float_precision=None): + assert ( + str(df) == "shape: (3, 3)\n" + "┌───────┬─────┬─────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ f64 ┆ f64 ┆ f64 │\n" + "╞═══════╪═════╪═════╡\n" + "│ 1.1 ┆ 4.0 ┆ 7.0 │\n" + "│ 2.22 ┆ 5.0 ┆ 8.0 │\n" + "│ 3.333 ┆ 6.0 ┆ 9.0 │\n" + "└───────┴─────┴─────┘" + ) + + df = pl.DataFrame( + {"a": [1.1, 22.2, 3.33], "b": [444.0, 55.5, 6.6], "c": [77.7, 8888.0, 9.9999]} + ) + with pl.Config(fmt_float="full", float_precision=1): + assert ( + str(df) == "shape: (3, 3)\n" + "┌──────┬───────┬────────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ f64 ┆ f64 ┆ f64 │\n" + "╞══════╪═══════╪════════╡\n" + "│ 1.1 ┆ 444.0 ┆ 77.7 │\n" + "│ 22.2 ┆ 55.5 ┆ 8888.0 │\n" + "│ 3.3 ┆ 6.6 ┆ 10.0 │\n" + "└──────┴───────┴────────┘" + ) + + df = pl.DataFrame( + { + "a": [1100000000000000000.1, 22200000000000000.2, 33330000000000000.33333], + "b": [40000000000000000000.0, 5, 600000000000000000.0], + "c": [700000.0, 80000000000000000.0, 900], + } + ) + with pl.Config(float_precision=2): + assert ( + str(df) == "shape: (3, 3)\n" + "┌─────────┬─────────┬───────────┐\n" + "│ a ┆ b ┆ c │\n" + "│ --- ┆ --- ┆ --- │\n" + "│ f64 ┆ f64 ┆ f64 │\n" + "╞═════════╪═════════╪═══════════╡\n" + "│ 1.10e18 ┆ 4.00e19 ┆ 700000.00 │\n" + "│ 2.22e16 ┆ 5.00 ┆ 8.00e16 │\n" + "│ 3.33e16 ┆ 6.00e17 ┆ 900.00 │\n" + "└─────────┴─────────┴───────────┘" + ) + + +@pytest.mark.write_disk +def test_config_load_save(tmp_path: Path) -> None: + for file in ( + None, + tmp_path / "polars.config", + str(tmp_path / "polars.config"), + ): + # set some config options... + pl.Config.set_tbl_cols(12) + pl.Config.set_verbose(True) + pl.Config.set_fmt_float("full") + pl.Config.set_float_precision(6) + pl.Config.set_thousands_separator(",") + assert os.environ.get("POLARS_VERBOSE") == "1" + + if file is None: + cfg = pl.Config.save() + assert isinstance(cfg, str) + else: + assert pl.Config.save_to_file(file) is None + + assert "POLARS_VERBOSE" in pl.Config.state(if_set=True) + + # ...modify the same options... + pl.Config.set_tbl_cols(10) + pl.Config.set_verbose(False) + pl.Config.set_fmt_float("mixed") + pl.Config.set_float_precision(2) + pl.Config.set_thousands_separator(None) + assert os.environ.get("POLARS_VERBOSE") == "0" + + # ...load back from config file/string... + if file is None: + pl.Config.load(cfg) + else: + with pytest.raises(ValueError, match="invalid Config file"): + pl.Config.load_from_file(cfg) + + if isinstance(file, Path): + with pytest.raises(TypeError, match="the JSON object must be str"): + pl.Config.load(file) # type: ignore[arg-type] + else: + with pytest.raises(ValueError, match="invalid Config string"): + pl.Config.load(file) + + pl.Config.load_from_file(file) + + # ...and confirm the saved options were set. + assert os.environ.get("POLARS_FMT_MAX_COLS") == "12" + assert os.environ.get("POLARS_VERBOSE") == "1" + assert plr.get_float_fmt() == "full" + assert plr.get_float_precision() == 6 + + # restore all default options (unsets from env) + pl.Config.restore_defaults() + for e in ("POLARS_FMT_MAX_COLS", "POLARS_VERBOSE"): + assert e not in pl.Config.state(if_set=True) + assert e in pl.Config.state() + + assert os.environ.get("POLARS_FMT_MAX_COLS") is None + assert os.environ.get("POLARS_VERBOSE") is None + assert plr.get_float_fmt() == "mixed" + assert plr.get_float_precision() is None + + # ref: #11094 + with pl.Config( + streaming_chunk_size=100, + tbl_cols=2000, + tbl_formatting="UTF8_NO_BORDERS", + tbl_hide_column_data_types=True, + tbl_hide_dtype_separator=True, + tbl_rows=2000, + tbl_width_chars=2000, + verbose=True, + ): + assert isinstance(repr(pl.DataFrame({"xyz": [0]})), str) + + +def test_config_load_save_context() -> None: + # store the default configuration state + default_state = pl.Config.save() + + # establish some non-default settings + pl.Config.set_tbl_formatting("ASCII_MARKDOWN") + pl.Config.set_verbose(True) + + # load the default config, validate load & context manager behaviour + with pl.Config.load(default_state): + assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") is None + assert os.environ.get("POLARS_VERBOSE") is None + + # ensure earlier state was restored + assert os.environ["POLARS_FMT_TABLE_FORMATTING"] == "ASCII_MARKDOWN" + assert os.environ["POLARS_VERBOSE"] + + +def test_config_instances() -> None: + # establish two config instances that defer setting their options + cfg_markdown = pl.Config( + tbl_formatting="MARKDOWN", + apply_on_context_enter=True, + ) + cfg_compact = pl.Config( + tbl_rows=4, + tbl_cols=4, + tbl_column_data_type_inline=True, + apply_on_context_enter=True, + ) + + # check instance (in)equality + assert cfg_markdown != cfg_compact + assert cfg_markdown == pl.Config( + tbl_formatting="MARKDOWN", apply_on_context_enter=True + ) + + # confirm that the options have not been applied yet + assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") is None + + # confirm that the deferred options are applied when the instance context + # is entered into, and that they can be re-used without leaking state + @cfg_markdown + def fn1() -> str | None: + return os.environ.get("POLARS_FMT_TABLE_FORMATTING") + + assert fn1() == "MARKDOWN" + assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") is None + + with cfg_markdown: # can re-use instance as decorator and context + assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") == "MARKDOWN" + assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") is None + + @cfg_markdown + def fn2() -> str | None: + return os.environ.get("POLARS_FMT_TABLE_FORMATTING") + + assert fn2() == "MARKDOWN" + assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") is None + + df = pl.DataFrame({f"c{idx}": [idx] * 10 for idx in range(10)}) + + @cfg_compact + def fn3(df: pl.DataFrame) -> str: + return repr(df) + + # reuse config instance and confirm state does not leak between invocations + for _ in range(3): + assert ( + fn3(df) + == dedent(""" + shape: (10, 10) + ┌──────────┬──────────┬───┬──────────┬──────────┐ + │ c0 (i64) ┆ c1 (i64) ┆ … ┆ c8 (i64) ┆ c9 (i64) │ + ╞══════════╪══════════╪═══╪══════════╪══════════╡ + │ 0 ┆ 1 ┆ … ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ … ┆ 8 ┆ 9 │ + │ … ┆ … ┆ … ┆ … ┆ … │ + │ 0 ┆ 1 ┆ … ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ … ┆ 8 ┆ 9 │ + └──────────┴──────────┴───┴──────────┴──────────┘""").lstrip() + ) + + assert ( + repr(df) + == dedent(""" + shape: (10, 10) + ┌─────┬─────┬─────┬─────┬───┬─────┬─────┬─────┬─────┐ + │ c0 ┆ c1 ┆ c2 ┆ c3 ┆ … ┆ c6 ┆ c7 ┆ c8 ┆ c9 │ + │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 ┆ ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╪═════╪═══╪═════╪═════╪═════╪═════╡ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + │ 0 ┆ 1 ┆ 2 ┆ 3 ┆ … ┆ 6 ┆ 7 ┆ 8 ┆ 9 │ + └─────┴─────┴─────┴─────┴───┴─────┴─────┴─────┴─────┘""").lstrip() + ) + + +def test_config_scope() -> None: + pl.Config.set_verbose(False) + pl.Config.set_tbl_cols(8) + + initial_state = pl.Config.state() + + with pl.Config() as cfg: + ( + cfg.set_tbl_formatting(rounded_corners=True) + .set_verbose(True) + .set_tbl_hide_dtype_separator(True) + .set_ascii_tables() + ) + new_state_entries = set( + { + "POLARS_FMT_MAX_COLS": "8", + "POLARS_FMT_TABLE_FORMATTING": "ASCII_FULL_CONDENSED", + "POLARS_FMT_TABLE_HIDE_COLUMN_SEPARATOR": "1", + "POLARS_FMT_TABLE_ROUNDED_CORNERS": "1", + "POLARS_VERBOSE": "1", + }.items() + ) + assert set(initial_state.items()) != new_state_entries + assert new_state_entries.issubset(set(cfg.state().items())) + + # expect scope-exit to restore original state + assert pl.Config.state() == initial_state + + +def test_config_raise_error_if_not_exist() -> None: + with pytest.raises(AttributeError), pl.Config(i_do_not_exist=True): # type: ignore[call-arg] + pass + + +def test_config_state_env_only() -> None: + pl.Config.set_verbose(False) + pl.Config.set_fmt_float("full") + + state_all = pl.Config.state(env_only=False) + state_env_only = pl.Config.state(env_only=True) + assert len(state_env_only) < len(state_all) + assert "set_fmt_float" in state_all + assert "set_fmt_float" not in state_env_only + + +def test_set_streaming_chunk_size() -> None: + with pl.Config() as cfg: + cfg.set_streaming_chunk_size(8) + assert os.environ.get("POLARS_STREAMING_CHUNK_SIZE") == "8" + + with pytest.raises(ValueError), pl.Config() as cfg: + cfg.set_streaming_chunk_size(0) + + +def test_set_fmt_str_lengths_invalid_length() -> None: + with pl.Config() as cfg: + with pytest.raises(ValueError): + cfg.set_fmt_str_lengths(0) + with pytest.raises(ValueError): + cfg.set_fmt_str_lengths(-2) + + +def test_truncated_rows_cols_values_ascii() -> None: + df = pl.DataFrame({f"c{n}": list(range(-n, 100 - n)) for n in range(10)}) + + pl.Config.set_tbl_formatting("UTF8_BORDERS_ONLY", rounded_corners=True) + assert ( + str(df) == "shape: (100, 10)\n" + "╭───────────────────────────────────────────────────╮\n" + "│ c0 c1 c2 c3 … c6 c7 c8 c9 │\n" + "│ --- --- --- --- --- --- --- --- │\n" + "│ i64 i64 i64 i64 i64 i64 i64 i64 │\n" + "╞═══════════════════════════════════════════════════╡\n" + "│ 0 -1 -2 -3 … -6 -7 -8 -9 │\n" + "│ 1 0 -1 -2 … -5 -6 -7 -8 │\n" + "│ 2 1 0 -1 … -4 -5 -6 -7 │\n" + "│ 3 2 1 0 … -3 -4 -5 -6 │\n" + "│ 4 3 2 1 … -2 -3 -4 -5 │\n" + "│ … … … … … … … … … │\n" + "│ 95 94 93 92 … 89 88 87 86 │\n" + "│ 96 95 94 93 … 90 89 88 87 │\n" + "│ 97 96 95 94 … 91 90 89 88 │\n" + "│ 98 97 96 95 … 92 91 90 89 │\n" + "│ 99 98 97 96 … 93 92 91 90 │\n" + "╰───────────────────────────────────────────────────╯" + ) + with pl.Config(tbl_formatting="ASCII_FULL_CONDENSED"): + assert ( + str(df) == "shape: (100, 10)\n" + "+-----+-----+-----+-----+-----+-----+-----+-----+-----+\n" + "| c0 | c1 | c2 | c3 | ... | c6 | c7 | c8 | c9 |\n" + "| --- | --- | --- | --- | | --- | --- | --- | --- |\n" + "| i64 | i64 | i64 | i64 | | i64 | i64 | i64 | i64 |\n" + "+=====================================================+\n" + "| 0 | -1 | -2 | -3 | ... | -6 | -7 | -8 | -9 |\n" + "| 1 | 0 | -1 | -2 | ... | -5 | -6 | -7 | -8 |\n" + "| 2 | 1 | 0 | -1 | ... | -4 | -5 | -6 | -7 |\n" + "| 3 | 2 | 1 | 0 | ... | -3 | -4 | -5 | -6 |\n" + "| 4 | 3 | 2 | 1 | ... | -2 | -3 | -4 | -5 |\n" + "| ... | ... | ... | ... | ... | ... | ... | ... | ... |\n" + "| 95 | 94 | 93 | 92 | ... | 89 | 88 | 87 | 86 |\n" + "| 96 | 95 | 94 | 93 | ... | 90 | 89 | 88 | 87 |\n" + "| 97 | 96 | 95 | 94 | ... | 91 | 90 | 89 | 88 |\n" + "| 98 | 97 | 96 | 95 | ... | 92 | 91 | 90 | 89 |\n" + "| 99 | 98 | 97 | 96 | ... | 93 | 92 | 91 | 90 |\n" + "+-----+-----+-----+-----+-----+-----+-----+-----+-----+" + ) + + with pl.Config(tbl_formatting="MARKDOWN"): + df = pl.DataFrame({"b": [b"0tigohij1prisdfj1gs2io3fbjg0pfihodjgsnfbbmfgnd8j"]}) + assert ( + str(df) + == dedent(""" + shape: (1, 1) + | b | + | --- | + | binary | + |---------------------------------| + | b"0tigohij1prisdfj1gs2io3fbjg0… |""").lstrip() + ) + + with pl.Config(tbl_formatting="ASCII_MARKDOWN"): + df = pl.DataFrame({"b": [b"0tigohij1prisdfj1gs2io3fbjg0pfihodjgsnfbbmfgnd8j"]}) + assert ( + str(df) + == dedent(""" + shape: (1, 1) + | b | + | --- | + | binary | + |-----------------------------------| + | b"0tigohij1prisdfj1gs2io3fbjg0... |""").lstrip() + ) + + +def test_warn_unstable(recwarn: pytest.WarningsRecorder) -> None: + issue_unstable_warning() + assert len(recwarn) == 0 + + pl.Config().warn_unstable(True) + + issue_unstable_warning() + assert len(recwarn) == 1 + + pl.Config().warn_unstable(False) + + issue_unstable_warning() + assert len(recwarn) == 1 + + +@pytest.mark.parametrize( + ("environment_variable", "config_setting", "value", "expected"), + [ + ("POLARS_ENGINE_AFFINITY", "set_engine_affinity", "gpu", "gpu"), + ("POLARS_AUTO_STRUCTIFY", "set_auto_structify", True, "1"), + ("POLARS_FMT_MAX_COLS", "set_tbl_cols", 12, "12"), + ("POLARS_FMT_MAX_ROWS", "set_tbl_rows", 3, "3"), + ("POLARS_FMT_STR_LEN", "set_fmt_str_lengths", 42, "42"), + ("POLARS_FMT_TABLE_CELL_ALIGNMENT", "set_tbl_cell_alignment", "RIGHT", "RIGHT"), + ( + "POLARS_FMT_TABLE_CELL_NUMERIC_ALIGNMENT", + "set_tbl_cell_numeric_alignment", + "RIGHT", + "RIGHT", + ), + ("POLARS_FMT_TABLE_HIDE_COLUMN_NAMES", "set_tbl_hide_column_names", True, "1"), + ( + "POLARS_FMT_TABLE_DATAFRAME_SHAPE_BELOW", + "set_tbl_dataframe_shape_below", + True, + "1", + ), + ( + "POLARS_FMT_TABLE_FORMATTING", + "set_ascii_tables", + True, + "ASCII_FULL_CONDENSED", + ), + ( + "POLARS_FMT_TABLE_FORMATTING", + "set_tbl_formatting", + "ASCII_MARKDOWN", + "ASCII_MARKDOWN", + ), + ( + "POLARS_FMT_TABLE_HIDE_COLUMN_DATA_TYPES", + "set_tbl_hide_column_data_types", + True, + "1", + ), + ( + "POLARS_FMT_TABLE_HIDE_COLUMN_SEPARATOR", + "set_tbl_hide_dtype_separator", + True, + "1", + ), + ( + "POLARS_FMT_TABLE_HIDE_DATAFRAME_SHAPE_INFORMATION", + "set_tbl_hide_dataframe_shape", + True, + "1", + ), + ( + "POLARS_FMT_TABLE_INLINE_COLUMN_DATA_TYPE", + "set_tbl_column_data_type_inline", + True, + "1", + ), + ("POLARS_STREAMING_CHUNK_SIZE", "set_streaming_chunk_size", 100, "100"), + ("POLARS_TABLE_WIDTH", "set_tbl_width_chars", 80, "80"), + ("POLARS_VERBOSE", "set_verbose", True, "1"), + ("POLARS_WARN_UNSTABLE", "warn_unstable", True, "1"), + ], +) +def test_unset_config_env_vars( + environment_variable: str, config_setting: str, value: Any, expected: str +) -> None: + assert environment_variable in _POLARS_CFG_ENV_VARS + + with pl.Config(**{config_setting: value}): + assert os.environ[environment_variable] == expected + + with pl.Config(**{config_setting: None}): # type: ignore[arg-type] + assert environment_variable not in os.environ diff --git a/py-polars/tests/unit/test_conftest.py b/py-polars/tests/unit/test_conftest.py new file mode 100644 index 000000000000..0a1712dcf578 --- /dev/null +++ b/py-polars/tests/unit/test_conftest.py @@ -0,0 +1,38 @@ +"""Tests for the testing infrastructure.""" + +import pytest + + +@pytest.mark.xfail +def test_memory_usage() -> None: + pytest.fail(reason="Disabled for now") + # """The ``memory_usage`` fixture gives somewhat accurate results.""" + # memory_usage = memory_usage_without_pyarrow + # assert memory_usage.get_current() < 100_000 + # assert memory_usage.get_peak() < 100_000 + # + # # Memory from Python is tracked: + # b = b"X" * 1_300_000 + # assert 1_300_000 <= memory_usage.get_current() <= 2_000_000 + # assert 1_300_000 <= memory_usage.get_peak() <= 2_000_000 + # del b + # assert memory_usage.get_current() <= 500_000 + # assert 1_300_000 <= memory_usage.get_peak() <= 2_000_000 + # memory_usage.reset_tracking() + # assert memory_usage.get_current() < 100_000 + # assert memory_usage.get_peak() < 100_000 + # + # # Memory from Polars is tracked: + # df = pl.DataFrame({"x": pl.arange(0, 1_000_000, eager=True, dtype=pl.Int64)}) + # del df + # peak_bytes = memory_usage.get_peak() + # assert 8_000_000 <= peak_bytes < 8_500_000 + # + # memory_usage.reset_tracking() + # assert memory_usage.get_peak() < 1_000_000 + # + # # Memory from NumPy is tracked: + # arr = np.ones((1_400_000,), dtype=np.uint8) + # del arr + # peak = memory_usage.get_peak() + # assert 1_400_000 < peak < 1_500_000 diff --git a/py-polars/tests/unit/test_convert.py b/py-polars/tests/unit/test_convert.py new file mode 100644 index 000000000000..123c0a0cdebe --- /dev/null +++ b/py-polars/tests/unit/test_convert.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import ComputeError, NoDataError + + +def test_from_records_schema_inference() -> None: + data = [[1, 2.1, 3], [4, 5, 6.4]] + + with pytest.raises(TypeError, match="unexpected value"): + pl.from_records(data) + + result = pl.from_records(data, strict=False) + assert result.to_dict(as_series=False) == { + "column_0": [1.0, 2.1, 3.0], + "column_1": [4.0, 5.0, 6.4], + } + + +def test_from_dicts_schema_inference() -> None: + data = [{"a": 1, "b": 2}, {"a": 3.1, "b": 4.5}] + result = pl.from_dicts(data) # type: ignore[arg-type] + assert result.to_dict(as_series=False) == { + "a": [1.0, 3.1], + "b": [2.0, 4.5], + } + + +def test_from_dicts_nested_nulls() -> None: + result = pl.from_dicts([{"a": [None, None]}, {"a": [1, 2]}]) + assert result.to_dict(as_series=False) == {"a": [[None, None], [1, 2]]} + + +def test_from_dicts_empty() -> None: + with pytest.raises(NoDataError, match="no data, cannot infer schema"): + pl.from_dicts([]) + + +def test_from_dicts_all_cols_6716() -> None: + dicts = [{"a": None} for _ in range(20)] + [{"a": "crash"}] + + with pytest.raises( + ComputeError, match="make sure that all rows have the same schema" + ): + pl.from_dicts(dicts, infer_schema_length=20) + assert pl.from_dicts(dicts, infer_schema_length=None).dtypes == [pl.String] + + +def test_dict_float_string_roundtrip_18882() -> None: + assert pl.from_dicts([{"A": "0.1"}]).to_dicts() == [{"A": "0.1"}] diff --git a/py-polars/tests/unit/test_cpu_check.py b/py-polars/tests/unit/test_cpu_check.py new file mode 100644 index 000000000000..efc72fe4bc0e --- /dev/null +++ b/py-polars/tests/unit/test_cpu_check.py @@ -0,0 +1,71 @@ +from unittest.mock import Mock + +import pytest + +from polars import _cpu_check +from polars._cpu_check import check_cpu_flags + + +@pytest.fixture +def _feature_flags(monkeypatch: pytest.MonkeyPatch) -> None: + """Use the default set of feature flags.""" + feature_flags = "+sse3,+ssse3" + monkeypatch.setattr(_cpu_check, "_POLARS_FEATURE_FLAGS", feature_flags) + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags( + monkeypatch: pytest.MonkeyPatch, recwarn: pytest.WarningsRecorder +) -> None: + cpu_flags = {"sse3": True, "ssse3": True} + mock_read_cpu_flags = Mock(return_value=cpu_flags) + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert len(recwarn) == 0 + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags_missing_features(monkeypatch: pytest.MonkeyPatch) -> None: + cpu_flags = {"sse3": True, "ssse3": False} + mock_read_cpu_flags = Mock(return_value=cpu_flags) + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + with pytest.warns(RuntimeWarning, match="Missing required CPU features") as w: + check_cpu_flags() + + assert "ssse3" in str(w[0].message) + + +def test_check_cpu_flags_unknown_flag( + monkeypatch: pytest.MonkeyPatch, +) -> None: + real_cpu_flags = {"sse3": True, "ssse3": False} + mock_read_cpu_flags = Mock(return_value=real_cpu_flags) + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + unknown_feature_flags = "+sse3,+ssse3,+HelloWorld!" + monkeypatch.setattr(_cpu_check, "_POLARS_FEATURE_FLAGS", unknown_feature_flags) + with pytest.raises(RuntimeError, match="unknown feature flag: 'HelloWorld!'"): + check_cpu_flags() + + +def test_check_cpu_flags_skipped_no_flags(monkeypatch: pytest.MonkeyPatch) -> None: + mock_read_cpu_flags = Mock() + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert mock_read_cpu_flags.call_count == 0 + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags_skipped_env_var(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("POLARS_SKIP_CPU_CHECK", "1") + + mock_read_cpu_flags = Mock() + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert mock_read_cpu_flags.call_count == 0 diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py new file mode 100644 index 000000000000..c7c856f3c442 --- /dev/null +++ b/py-polars/tests/unit/test_cse.py @@ -0,0 +1,860 @@ +from __future__ import annotations + +import re +from datetime import date, datetime, timedelta +from tempfile import NamedTemporaryFile +from typing import Any + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def num_cse_occurrences(explanation: str) -> int: + """The number of unique CSE columns in an explain string.""" + return len(set(re.findall('__POLARS_CSER_0x[^"]+"', explanation))) + + +def test_cse_rename_cross_join_5405() -> None: + # https://github.com/pola-rs/polars/issues/5405 + + right = pl.DataFrame({"A": [1, 2], "B": [3, 4], "D": [5, 6]}).lazy() + left = pl.DataFrame({"C": [3, 4]}).lazy().join(right.select("A"), how="cross") + + result = left.join(right.rename({"B": "C"}), on=["A", "C"], how="left").collect( + comm_subplan_elim=True + ) + + expected = pl.DataFrame( + { + "C": [3, 3, 4, 4], + "A": [1, 2, 1, 2], + "D": [5, None, None, 6], + } + ) + assert_frame_equal(result, expected, check_row_order=False) + + +def test_union_duplicates() -> None: + n_dfs = 10 + df_lazy = pl.DataFrame({}).lazy() + lazy_dfs = [df_lazy for _ in range(n_dfs)] + + result = len( + re.findall( + r".*CACHE\[id: .*, cache_hits: 9].*", + pl.concat(lazy_dfs).explain(), + flags=re.MULTILINE, + ) + ) + assert result + + +def test_cse_with_struct_expr_11116() -> None: + # https://github.com/pola-rs/polars/issues/11116 + + df = pl.DataFrame([{"s": {"a": 1, "b": 4}, "c": 3}]).lazy() + + result = df.with_columns( + pl.col("s").struct.field("a").alias("s_a"), + pl.col("s").struct.field("b").alias("s_b"), + ( + (pl.col("s").struct.field("a") <= pl.col("c")) + & (pl.col("s").struct.field("b") > pl.col("c")) + ).alias("c_between_a_and_b"), + ).collect(comm_subexpr_elim=True) + + expected = pl.DataFrame( + { + "s": [{"a": 1, "b": 4}], + "c": [3], + "s_a": [1], + "s_b": [4], + "c_between_a_and_b": [True], + } + ) + assert_frame_equal(result, expected) + + +def test_cse_schema_6081() -> None: + # https://github.com/pola-rs/polars/issues/6081 + + df = pl.DataFrame( + data=[ + [date(2022, 12, 12), 1, 1], + [date(2022, 12, 12), 1, 2], + [date(2022, 12, 13), 5, 2], + ], + schema=["date", "id", "value"], + orient="row", + ).lazy() + + min_value_by_group = df.group_by(["date", "id"]).agg( + pl.col("value").min().alias("min_value") + ) + + result = df.join(min_value_by_group, on=["date", "id"], how="left").collect( + comm_subplan_elim=True, projection_pushdown=True + ) + expected = pl.DataFrame( + { + "date": [date(2022, 12, 12), date(2022, 12, 12), date(2022, 12, 13)], + "id": [1, 1, 5], + "value": [1, 2, 2], + "min_value": [1, 1, 2], + } + ) + assert_frame_equal(result, expected, check_row_order=False) + + +def test_cse_9630() -> None: + lf1 = pl.LazyFrame({"key": [1], "x": [1]}) + lf2 = pl.LazyFrame({"key": [1], "y": [2]}) + + joined_lf2 = lf1.join(lf2, on="key") + + all_subsections = ( + pl.concat( + [ + lf1.select("key", pl.col("x").alias("value")), + joined_lf2.select("key", pl.col("y").alias("value")), + ] + ) + .group_by("key") + .agg(pl.col("value")) + ) + + intersected_df1 = all_subsections.join(lf1, on="key") + intersected_df2 = all_subsections.join(lf2, on="key") + + result = intersected_df1.join(intersected_df2, on=["key"], how="left").collect( + comm_subplan_elim=True + ) + + expected = pl.DataFrame( + { + "key": [1], + "value": [[1, 2]], + "x": [1], + "value_right": [[1, 2]], + "y": [2], + } + ) + assert_frame_equal(result, expected) + + +@pytest.mark.write_disk +def test_schema_row_index_cse() -> None: + with NamedTemporaryFile() as csv_a: + csv_a.write(b"A,B\nGr1,A\nGr1,B") + csv_a.seek(0) + + df_a = pl.scan_csv(csv_a.name).with_row_index("Idx") + + result = ( + df_a.join(df_a, on="B") + .group_by("A", maintain_order=True) + .all() + .collect(comm_subexpr_elim=True) + ) + + expected = pl.DataFrame( + { + "A": ["Gr1"], + "Idx": [[0, 1]], + "B": [["A", "B"]], + "Idx_right": [[0, 1]], + "A_right": [["Gr1", "Gr1"]], + }, + schema_overrides={"Idx": pl.List(pl.UInt32), "Idx_right": pl.List(pl.UInt32)}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.debug +def test_cse_expr_selection_context() -> None: + q = pl.LazyFrame( + { + "a": [1, 2, 3, 4], + "b": [1, 2, 3, 4], + "c": [1, 2, 3, 4], + } + ) + + derived = (pl.col("a") * pl.col("b")).sum() + derived2 = derived * derived + + exprs = [ + derived.alias("d1"), + (derived * pl.col("c").sum() - 1).alias("foo"), + derived2.alias("d2"), + (derived2 * 10).alias("d3"), + ] + + result = q.select(exprs).collect(comm_subexpr_elim=True) + assert num_cse_occurrences(q.select(exprs).explain(comm_subexpr_elim=True)) == 2 + expected = pl.DataFrame( + { + "d1": [30], + "foo": [299], + "d2": [900], + "d3": [9000], + } + ) + assert_frame_equal(result, expected) + + result = q.with_columns(exprs).collect(comm_subexpr_elim=True) + assert ( + num_cse_occurrences(q.with_columns(exprs).explain(comm_subexpr_elim=True)) == 2 + ) + expected = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [1, 2, 3, 4], + "c": [1, 2, 3, 4], + "d1": [30, 30, 30, 30], + "foo": [299, 299, 299, 299], + "d2": [900, 900, 900, 900], + "d3": [9000, 9000, 9000, 9000], + } + ) + assert_frame_equal(result, expected) + + +def test_windows_cse_excluded() -> None: + lf = pl.LazyFrame( + data=[ + ("a", "aaa", 1), + ("a", "bbb", 3), + ("a", "ccc", 1), + ("c", "xxx", 2), + ("c", "yyy", 3), + ("c", "zzz", 4), + ("b", "qqq", 0), + ], + schema=["a", "b", "c"], + orient="row", + ) + + result = lf.select( + c_diff=pl.col("c").diff(1), + c_diff_by_a=pl.col("c").diff(1).over("a"), + ).collect(comm_subexpr_elim=True) + + expected = pl.DataFrame( + { + "c_diff": [None, 2, -2, 1, 1, 1, -4], + "c_diff_by_a": [None, 2, -2, None, 1, 1, None], + } + ) + assert_frame_equal(result, expected) + + +def test_cse_group_by_10215() -> None: + lf = pl.LazyFrame({"a": [1], "b": [1]}) + + result = lf.group_by("b").agg( + (pl.col("a").sum() * pl.col("a").sum()).alias("x"), + (pl.col("b").sum() * pl.col("b").sum()).alias("y"), + (pl.col("a").sum() * pl.col("a").sum()).alias("x2"), + ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"), + ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"), + ((pl.col("a") + 2).sum() * pl.col("b").sum()), + ) + + assert "__POLARS_CSER" in result.explain(comm_subexpr_elim=True) + expected = pl.DataFrame( + { + "b": [1], + "x": [1], + "y": [1], + "x2": [1], + "x3": [3], + "x4": [3], + "a": [3], + } + ) + assert_frame_equal(result.collect(comm_subexpr_elim=True), expected) + + +def test_cse_mixed_window_functions() -> None: + # checks if the window caches are cleared + # there are windows in the cse's and the default expressions + lf = pl.LazyFrame({"a": [1], "b": [1], "c": [1]}) + + result = lf.select( + pl.col("a"), + pl.col("b"), + pl.col("c"), + pl.col("b").rank().alias("rank"), + pl.col("b").rank().alias("d_rank"), + pl.col("b").first().over([pl.col("a")]).alias("b_first"), + pl.col("b").last().over([pl.col("a")]).alias("b_last"), + pl.col("b").shift().alias("b_lag_1"), + pl.col("b").shift().alias("b_lead_1"), + pl.col("c").cum_sum().alias("c_cumsum"), + pl.col("c").cum_sum().over([pl.col("a")]).alias("c_cumsum_by_a"), + pl.col("c").diff().alias("c_diff"), + pl.col("c").diff().over([pl.col("a")]).alias("c_diff_by_a"), + ).collect(comm_subexpr_elim=True) + + expected = pl.DataFrame( + { + "a": [1], + "b": [1], + "c": [1], + "rank": [1.0], + "d_rank": [1.0], + "b_first": [1], + "b_last": [1], + "b_lag_1": [None], + "b_lead_1": [None], + "c_cumsum": [1], + "c_cumsum_by_a": [1], + "c_diff": [None], + "c_diff_by_a": [None], + }, + ).with_columns(pl.col(pl.Null).cast(pl.Int64)) + assert_frame_equal(result, expected) + + +def test_cse_10401() -> None: + df = pl.LazyFrame({"clicks": [1.0, float("nan"), None]}) + + q = df.with_columns(pl.all().fill_null(0).fill_nan(0)) + + assert r"""col("clicks").fill_null([0.0]).alias("__POLARS_CSER""" in q.explain() + + expected = pl.DataFrame({"clicks": [1.0, 0.0, 0.0]}) + assert_frame_equal(q.collect(comm_subexpr_elim=True), expected) + + +def test_cse_10441() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]}) + + result = lf.select( + pl.col("a").sum() + pl.col("a").sum() + pl.col("b").sum() + ).collect(comm_subexpr_elim=True) + + expected = pl.DataFrame({"a": [18]}) + assert_frame_equal(result, expected) + + +def test_cse_10452() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 2, 1]}) + q = lf.select( + pl.col("b").sum() + pl.col("a").sum().over(pl.col("b")) + pl.col("b").sum() + ) + + assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True) + + expected = pl.DataFrame({"b": [13, 14, 15]}) + assert_frame_equal(q.collect(comm_subexpr_elim=True), expected) + + +def test_cse_group_by_ternary_10490() -> None: + lf = pl.LazyFrame( + { + "a": [1, 1, 2, 2], + "b": [1, 2, 3, 4], + "c": [2, 3, 4, 5], + } + ) + + result = ( + lf.group_by("a") + .agg( + [ + pl.when(pl.col(col).is_null().all()).then(None).otherwise(1).alias(col) + for col in ["b", "c"] + ] + + [ + (pl.col("a").sum() * pl.col("a").sum()).alias("x"), + (pl.col("b").sum() * pl.col("b").sum()).alias("y"), + (pl.col("a").sum() * pl.col("a").sum()).alias("x2"), + ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"), + ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"), + ] + ) + .collect(comm_subexpr_elim=True) + .sort("a") + ) + + expected = pl.DataFrame( + { + "a": [1, 2], + "b": [1, 1], + "c": [1, 1], + "x": [4, 16], + "y": [9, 49], + "x2": [4, 16], + "x3": [12, 32], + "x4": [18, 56], + }, + schema_overrides={"b": pl.Int32, "c": pl.Int32}, + ) + assert_frame_equal(result, expected) + + +def test_cse_quantile_10815() -> None: + np.random.seed(1) + a = np.random.random(10) + b = np.random.random(10) + df = pl.DataFrame({"a": a, "b": b}) + cols = ["a", "b"] + q = df.lazy().select( + *( + pl.col(c).quantile(0.75, interpolation="midpoint").name.suffix("_3") + for c in cols + ), + *( + pl.col(c).quantile(0.25, interpolation="midpoint").name.suffix("_1") + for c in cols + ), + ) + assert "__POLARS_CSE" not in q.explain() + assert q.collect().to_dict(as_series=False) == { + "a_3": [0.40689473946662197], + "b_3": [0.6145786693120769], + "a_1": [0.16650805109739197], + "b_1": [0.2012768694081981], + } + + +def test_cse_nan_10824() -> None: + v = pl.col("a") / pl.col("b") + magic = pl.when(v > 0).then(pl.lit(float("nan"))).otherwise(v) + assert ( + str( + ( + pl.DataFrame( + { + "a": [1.0], + "b": [1.0], + } + ) + .lazy() + .select(magic) + .collect(comm_subexpr_elim=True) + ).to_dict(as_series=False) + ) + == "{'literal': [nan]}" + ) + + +def test_cse_10901() -> None: + df = pl.DataFrame(data=range(6), schema={"a": pl.Int64}) + a = pl.col("a").rolling_sum(window_size=2) + b = pl.col("a").rolling_sum(window_size=3) + exprs = { + "ax1": a, + "ax2": a * 2, + "bx1": b, + "bx2": b * 2, + } + + expected = pl.DataFrame( + { + "a": [0, 1, 2, 3, 4, 5], + "ax1": [None, 1, 3, 5, 7, 9], + "ax2": [None, 2, 6, 10, 14, 18], + "bx1": [None, None, 3, 6, 9, 12], + "bx2": [None, None, 6, 12, 18, 24], + } + ) + + assert_frame_equal(df.lazy().with_columns(**exprs).collect(), expected) + + +def test_cse_count_in_group_by() -> None: + q = ( + pl.LazyFrame({"a": [1, 1, 2], "b": [1, 2, 3], "c": [40, 51, 12]}) + .group_by("a") + .agg(pl.all().slice(0, pl.len() - 1)) + ) + + assert "POLARS_CSER" not in q.explain() + assert q.collect().sort("a").to_dict(as_series=False) == { + "a": [1, 2], + "b": [[1], []], + "c": [[40], []], + } + + +def test_cse_slice_11594() -> None: + df = pl.LazyFrame({"a": [1, 2, 1, 2, 1, 2]}) + + q = df.select( + pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"), + pl.col("a").slice(offset=1, length=pl.len() - 1).alias("2"), + ) + + assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True) + + assert q.collect(comm_subexpr_elim=True).to_dict(as_series=False) == { + "1": [2, 1, 2, 1, 2], + "2": [2, 1, 2, 1, 2], + } + + q = df.select( + pl.col("a").slice(offset=1, length=pl.len() - 1).alias("1"), + pl.col("a").slice(offset=0, length=pl.len() - 1).alias("2"), + ) + + assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True) + + assert q.collect(comm_subexpr_elim=True).to_dict(as_series=False) == { + "1": [2, 1, 2, 1, 2], + "2": [1, 2, 1, 2, 1], + } + + +def test_cse_is_in_11489() -> None: + df = pl.DataFrame( + {"cond": [1, 2, 3, 2, 1], "x": [1.0, 0.20, 3.0, 4.0, 0.50]} + ).lazy() + any_cond = ( + pl.when(pl.col("cond").is_in([2, 3])) + .then(True) + .when(pl.col("cond").is_in([1])) + .then(False) + .otherwise(None) + .alias("any_cond") + ) + val = ( + pl.when(any_cond) + .then(1.0) + .when(~any_cond) + .then(0.0) + .otherwise(None) + .alias("val") + ) + assert df.select("cond", any_cond, val).collect().to_dict(as_series=False) == { + "cond": [1, 2, 3, 2, 1], + "any_cond": [False, True, True, True, False], + "val": [0.0, 1.0, 1.0, 1.0, 0.0], + } + + +def test_cse_11958() -> None: + df = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + vector_losses = [] + for lag in range(1, 5): + difference = pl.col("a") - pl.col("a").shift(lag) + component_loss = pl.when(difference >= 0).then(difference * 10) + vector_losses.append(component_loss.alias(f"diff{lag}")) + + q = df.select(vector_losses) + assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True) + assert q.collect(comm_subexpr_elim=True).to_dict(as_series=False) == { + "diff1": [None, 10, 10, 10, 10], + "diff2": [None, None, 20, 20, 20], + "diff3": [None, None, None, 30, 30], + "diff4": [None, None, None, None, 40], + } + + +def test_cse_14047() -> None: + ldf = pl.LazyFrame( + { + "timestamp": pl.datetime_range( + datetime(2024, 1, 12), + datetime(2024, 1, 12, 0, 0, 0, 150_000), + "10ms", + eager=True, + closed="left", + ), + "price": list(range(15)), + } + ) + + def count_diff( + price: pl.Expr, upper_bound: float = 0.1, lower_bound: float = 0.001 + ) -> pl.Expr: + span_end_to_curr = ( + price.count() + .cast(int) + .rolling("timestamp", period=timedelta(seconds=lower_bound)) + ) + span_start_to_curr = ( + price.count() + .cast(int) + .rolling("timestamp", period=timedelta(seconds=upper_bound)) + ) + return (span_start_to_curr - span_end_to_curr).alias( + f"count_diff_{upper_bound}_{lower_bound}" + ) + + def s_per_count(count_diff: pl.Expr, span: tuple[float, float]) -> pl.Expr: + return (span[1] * 1000 - span[0] * 1000) / count_diff + + spans = [(0.001, 0.1), (1, 10)] + count_diff_exprs = [count_diff(pl.col("price"), span[0], span[1]) for span in spans] + s_per_count_exprs = [ + s_per_count(count_diff, span).alias(f"zz_{span}") + for count_diff, span in zip(count_diff_exprs, spans) + ] + + exprs = count_diff_exprs + s_per_count_exprs + ldf = ldf.with_columns(*exprs) + assert_frame_equal( + ldf.collect(comm_subexpr_elim=True), ldf.collect(comm_subexpr_elim=False) + ) + + +def test_cse_15536() -> None: + source = pl.DataFrame({"a": range(10)}) + + data = source.lazy().filter(pl.col("a") >= 5) + + assert pl.concat( + [ + data.filter(pl.lit(True) & (pl.col("a") == 6) | (pl.col("a") == 9)), + data.filter(pl.lit(True) & (pl.col("a") == 7) | (pl.col("a") == 8)), + ] + ).collect()["a"].to_list() == [6, 9, 7, 8] + + +def test_cse_15548() -> None: + ldf = pl.LazyFrame({"a": [1, 2, 3]}) + ldf2 = ldf.filter(pl.col("a") == 1).cache() + ldf3 = pl.concat([ldf, ldf2]) + + assert len(ldf3.collect(comm_subplan_elim=False)) == 4 + assert len(ldf3.collect(comm_subplan_elim=True)) == 4 + + +@pytest.mark.debug +def test_cse_and_schema_update_projection_pd() -> None: + df = pl.LazyFrame({"a": [1, 2], "b": [99, 99]}) + + q = ( + df.lazy() + .with_row_index() + .select( + pl.when(pl.col("b") < 10) + .then(0.1 * pl.col("b")) + .when(pl.col("b") < 100) + .then(0.2 * pl.col("b")) + ) + ) + assert q.collect(comm_subplan_elim=False).to_dict(as_series=False) == { + "literal": [19.8, 19.8] + } + assert num_cse_occurrences(q.explain(comm_subexpr_elim=True)) == 1 + + +@pytest.mark.debug +@pytest.mark.may_fail_auto_streaming +def test_cse_predicate_self_join(capfd: Any, monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + y = pl.LazyFrame({"a": [1], "b": [2], "y": [3]}) + + xf = y.filter(pl.col("y") == 2).select(["a", "b"]) + y_xf = y.join(xf, on=["a", "b"], how="left") + + y_xf_c = y_xf.select("a", "b") + assert y_xf_c.collect().to_dict(as_series=False) == {"a": [1], "b": [2]} + captured = capfd.readouterr().err + assert "CACHE HIT" in captured + + +def test_cse_manual_cache_15688() -> None: + df = pl.LazyFrame( + {"a": [1, 2, 3, 1, 2, 3], "b": [1, 1, 1, 1, 1, 1], "id": [1, 1, 1, 2, 2, 2]} + ) + + df1 = df.filter(id=1).join(df.filter(id=2), on=["a", "b"], how="semi") + df2 = df.filter(id=1).join(df1, on=["a", "b"], how="semi") + df2 = df2.cache() + res = df2.group_by("b").agg(pl.all().sum()) + + print(res.cache().with_columns(foo=1).explain(comm_subplan_elim=True)) + assert res.cache().with_columns(foo=1).collect().to_dict(as_series=False) == { + "b": [1], + "a": [6], + "id": [3], + "foo": [1], + } + + +def test_cse_drop_nulls_15795() -> None: + A = pl.LazyFrame({"X": 1}) + B = pl.LazyFrame({"X": 1, "Y": 0}).filter(pl.col("Y").is_not_null()) + C = A.join(B, on="X").select("X") + D = B.select("X") + assert C.join(D, on="X").collect().shape == (1, 1) + + +def test_cse_no_projection_15980() -> None: + df = pl.LazyFrame({"x": "a", "y": 1}) + df = pl.concat(df.with_columns(pl.col("y").add(n)) for n in range(2)) + + assert df.filter(pl.col("x").eq("a")).select("x").collect().to_dict( + as_series=False + ) == {"x": ["a", "a"]} + + +@pytest.mark.debug +def test_cse_series_collision_16138() -> None: + holdings = pl.DataFrame( + { + "fund_currency": ["CLP", "CLP"], + "asset_currency": ["EUR", "USA"], + } + ) + + usd = ["USD"] + eur = ["EUR"] + clp = ["CLP"] + + currency_factor_query_dict = [ + pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(clp), + pl.col("asset_currency").is_in(eur) & pl.col("fund_currency").is_in(usd), + pl.col("asset_currency").is_in(clp) & pl.col("fund_currency").is_in(clp), + pl.col("asset_currency").is_in(usd) & pl.col("fund_currency").is_in(usd), + ] + + factor_holdings = holdings.lazy().with_columns( + pl.coalesce(currency_factor_query_dict).alias("currency_factor"), + ) + + assert factor_holdings.collect(comm_subexpr_elim=True).to_dict(as_series=False) == { + "fund_currency": ["CLP", "CLP"], + "asset_currency": ["EUR", "USA"], + "currency_factor": [True, False], + } + assert num_cse_occurrences(factor_holdings.explain(comm_subexpr_elim=True)) == 3 + + +def test_nested_cache_no_panic_16553() -> None: + assert pl.LazyFrame().select(a=[[[1]]]).collect(comm_subexpr_elim=True).to_dict( + as_series=False + ) == {"a": [[[[1]]]]} + + +def test_hash_empty_series_16577() -> None: + s = pl.Series(values=None) + out = pl.LazyFrame().select(s).collect() + assert out.equals(s.to_frame()) + + +def test_cse_non_scalar_length_mismatch_17732() -> None: + df = pl.LazyFrame({"a": pl.Series(range(30), dtype=pl.Int32)}) + got = ( + df.lazy() + .with_columns( + pl.col("a").head(5).min().alias("b"), + pl.col("a").head(5).max().alias("c"), + ) + .collect(comm_subexpr_elim=True) + ) + expect = pl.DataFrame( + { + "a": pl.Series(range(30), dtype=pl.Int32), + "b": pl.Series([0] * 30, dtype=pl.Int32), + "c": pl.Series([4] * 30, dtype=pl.Int32), + } + ) + + assert_frame_equal(expect, got) + + +def test_cse_chunks_18124() -> None: + df = pl.DataFrame( + { + "ts_diff": [timedelta(seconds=60)] * 2, + "ts_diff_after": [timedelta(seconds=120)] * 2, + } + ) + df = pl.concat([df, df], rechunk=False) + assert ( + df.lazy() + .with_columns( + ts_diff_sign=pl.col("ts_diff") > pl.duration(seconds=0), + ts_diff_after_sign=pl.col("ts_diff_after") > pl.duration(seconds=0), + ) + .filter(pl.col("ts_diff") > 1) + ).collect().shape == (4, 4) + + +@pytest.mark.may_fail_auto_streaming +def test_eager_cse_during_struct_expansion_18411() -> None: + df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]}) + vc = pl.col("foo").value_counts() + classes = vc.struct[0] + counts = vc.struct[1] + # Check if output is stable + assert ( + df.select(pl.col("foo").replace(classes, counts)) + == df.select(pl.col("foo").replace(classes, counts)) + )["foo"].all() + + +def test_cse_as_struct_19253() -> None: + df = pl.LazyFrame({"x": [1, 2], "y": [4, 5]}) + + assert ( + df.with_columns( + q1=pl.struct(pl.col.x - pl.col.y.mean()), + q2=pl.struct(pl.col.x - pl.col.y.mean().over("y")), + ).collect() + ).to_dict(as_series=False) == { + "x": [1, 2], + "y": [4, 5], + "q1": [{"x": -3.5}, {"x": -2.5}], + "q2": [{"x": -3.0}, {"x": -3.0}], + } + + +@pytest.mark.may_fail_auto_streaming +def test_cse_as_struct_value_counts_20927() -> None: + assert pl.DataFrame({"x": [i for i in range(1, 6) for _ in range(i)]}).select( + pl.struct("x").value_counts().struct.unnest() + ).sort("count").to_dict(as_series=False) == { + "x": [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}, {"x": 5}], + "count": [1, 2, 3, 4, 5], + } + + +def test_cse_union_19227() -> None: + lf = pl.LazyFrame({"A": [1], "B": [2]}) + lf_1 = lf.select(C="A", B="B") + lf_2 = lf.select(C="A", A="B") + + direct = lf_2.join(lf, on=["A"]).select("C", "A", "B") + + indirect = lf_1.join(direct, on=["C", "B"]).select("C", "A", "B") + + out = pl.concat([direct, indirect]) + assert out.collect().schema == pl.Schema( + [("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)] + ) + + +def test_cse_21115() -> None: + lf = pl.LazyFrame({"x": 1, "y": 5}) + + assert lf.with_columns( + pl.all().exp() + pl.min_horizontal(pl.all().exp()) + ).collect().to_dict(as_series=False) == { + "x": [5.43656365691809], + "y": [151.13144093103566], + } + + +def test_cse_cache_leakage_22339() -> None: + lf1 = pl.LazyFrame({"x": [True] * 2}) + lf2 = pl.LazyFrame({"x": [True] * 3}) + + a = lf1 + b = lf1.filter(pl.col("x").not_().over(1)) + c = lf2.filter(pl.col("x").not_().over(1)) + + ab = a.join(b, on="x") + bc = b.join(c, on="x") + ac = a.join(c, on="x") + + assert pl.concat([ab, bc, ac]).collect().to_dict(as_series=False) == {"x": []} diff --git a/py-polars/tests/unit/test_cwc.py b/py-polars/tests/unit/test_cwc.py new file mode 100644 index 000000000000..1bd12058272e --- /dev/null +++ b/py-polars/tests/unit/test_cwc.py @@ -0,0 +1,250 @@ +# Tests for the optimization pass cluster WITH_COLUMNS + +import polars as pl + + +def test_basic_cwc() -> None: + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns(pl.col("a").alias("c") * 3) + .with_columns(pl.col("a").alias("d") * 4) + ) + + assert ( + """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c"), [(col("a")) * (4)].alias("d")]""" + in df.explain() + ) + + +def test_disable_cwc() -> None: + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns(pl.col("a").alias("c") * 3) + .with_columns(pl.col("a").alias("d") * 4) + ) + + explain = df.explain(cluster_with_columns=False) + + assert """[[(col("a")) * (2)].alias("b")]""" in explain + assert """[[(col("a")) * (3)].alias("c")]""" in explain + assert """[[(col("a")) * (4)].alias("d")]""" in explain + + +def test_refuse_with_deps() -> None: + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns(pl.col("b").alias("c") * 3) + .with_columns(pl.col("c").alias("d") * 4) + ) + + explain = df.explain() + + assert """[[(col("a")) * (2)].alias("b")]""" in explain + assert """[[(col("b")) * (3)].alias("c")]""" in explain + assert """[[(col("c")) * (4)].alias("d")]""" in explain + + +def test_partial_deps() -> None: + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns( + pl.col("a").alias("c") * 3, + pl.col("b").alias("d") * 4, + pl.col("a").alias("e") * 5, + ) + .with_columns(pl.col("b").alias("f") * 6) + ) + + explain = df.explain() + + assert ( + """[[(col("b")) * (4)].alias("d"), [(col("b")) * (6)].alias("f")]""" in explain + ) + assert ( + """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c"), [(col("a")) * (5)].alias("e")]""" + in explain + ) + + +def test_swap_remove() -> None: + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns( + pl.col("b").alias("f") * 6, + pl.col("a").alias("c") * 3, + pl.col("b").alias("d") * 4, + pl.col("b").alias("e") * 5, + ) + ) + + explain = df.explain() + assert df.collect().equals( + pl.DataFrame( + { + "a": [1, 2], + "b": [2, 4], + "f": [12, 24], + "c": [3, 6], + "d": [8, 16], + "e": [10, 20], + } + ) + ) + + assert ( + """[[(col("b")) * (6)].alias("f"), [(col("b")) * (4)].alias("d"), [(col("b")) * (5)].alias("e")]""" + in explain + ) + assert ( + """[[(col("a")) * (2)].alias("b"), [(col("a")) * (3)].alias("c")]""" in explain + ) + assert """simple π""" in explain + + +def test_try_remove_simple_project() -> None: + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns(pl.col("a").alias("d") * 4, pl.col("b").alias("c") * 3) + ) + + explain = df.explain() + + assert ( + """[[(col("a")) * (2)].alias("b"), [(col("a")) * (4)].alias("d")]""" in explain + ) + assert """[[(col("b")) * (3)].alias("c")]""" in explain + assert """simple π""" not in explain + + df = ( + pl.LazyFrame({"a": [1, 2]}) + .with_columns(pl.col("a").alias("b") * 2) + .with_columns(pl.col("b").alias("c") * 3, pl.col("a").alias("d") * 4) + ) + + explain = df.explain() + + assert ( + """[[(col("a")) * (2)].alias("b"), [(col("a")) * (4)].alias("d")]""" in explain + ) + assert """[[(col("b")) * (3)].alias("c")]""" in explain + assert """simple π""" in explain + + +def test_cwc_with_internal_aliases() -> None: + df = ( + pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + .with_columns(pl.any_horizontal((pl.col("a") == 2).alias("b")).alias("c")) + .with_columns(pl.col("b").alias("d") * 3) + ) + + explain = df.explain() + + assert ( + """[[(col("a")) == (2)].alias("c"), [(col("b")) * (3)].alias("d")]""" in explain + ) + + +def test_read_of_pushed_column_16436() -> None: + df = pl.DataFrame( + { + "x": [1.12, 2.21, 4.2, 3.21], + "y": [2.11, 3.32, 2.1, 6.12], + } + ) + + df = ( + df.lazy() + .with_columns((pl.col("y") / pl.col("x")).alias("z")) + .with_columns( + pl.when(pl.col("z").is_infinite()).then(0).otherwise(pl.col("z")).alias("z") + ) + .fill_nan(0) + .collect() + ) + + +def test_multiple_simple_projections_16435() -> None: + df = pl.DataFrame({"a": [1]}).lazy() + + df = ( + df.with_columns(b=pl.col("a")) + .with_columns(c=pl.col("b")) + .with_columns(l2a=pl.lit(2)) + .with_columns(l2b=pl.col("l2a")) + .with_columns(m=pl.lit(3)) + ) + + df.collect() + + +def test_reverse_order() -> None: + df = pl.LazyFrame({"a": [1], "b": [2]}) + + df = ( + df.with_columns(a=pl.col("a"), b=pl.col("b"), c=pl.col("a") * pl.col("b")) + .with_columns(x=pl.col("a"), y=pl.col("b")) + .with_columns(b=pl.col("a"), a=pl.col("b")) + ) + + df.collect() + + +def test_realias_of_unread_column_16530() -> None: + df = ( + pl.LazyFrame({"x": [True]}) + .with_columns(x=pl.lit(False)) + .with_columns(y=~pl.col("x")) + .with_columns(y=pl.lit(False)) + ) + + explain = df.explain() + + assert explain.count("WITH_COLUMNS") == 1 + assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False]})) + + +def test_realias_with_dependencies() -> None: + df = ( + pl.LazyFrame({"x": [True]}) + .with_columns(x=pl.lit(False)) + .with_columns(y=~pl.col("x")) + .with_columns(y=pl.lit(False), z=pl.col("y") | True) + ) + + explain = df.explain() + + assert explain.count("WITH_COLUMNS") == 3 + assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]})) + + +def test_refuse_pushdown_with_aliases() -> None: + df = ( + pl.LazyFrame({"x": [True]}) + .with_columns(x=pl.lit(False)) + .with_columns(y=pl.lit(True)) + .with_columns(y=pl.lit(False), z=pl.col("y") | True) + ) + + explain = df.explain() + + assert explain.count("WITH_COLUMNS") == 2 + assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]})) + + +def test_neighbour_live_expr() -> None: + df = ( + pl.LazyFrame({"x": [True]}) + .with_columns(y=pl.lit(False)) + .with_columns(x=pl.lit(False), z=pl.col("x") | False) + ) + + explain = df.explain() + + assert explain.count("WITH_COLUMNS") == 1 + assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]})) diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py new file mode 100644 index 000000000000..6364c6864025 --- /dev/null +++ b/py-polars/tests/unit/test_datatypes.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import pickle +from datetime import datetime, time, timedelta +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars import datatypes +from polars.datatypes import ( + DTYPE_TEMPORAL_UNITS, + Field, + Int64, + List, + Struct, + parse_into_dtype, +) +from polars.datatypes.group import DataTypeGroup +from tests.unit.conftest import DATETIME_DTYPES, NUMERIC_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + from polars.datatypes.classes import DataTypeClass + +SIMPLE_DTYPES: list[DataTypeClass] = [ + *[dt.base_type() for dt in NUMERIC_DTYPES], + pl.Boolean, + pl.String, + pl.Binary, + pl.Time, + pl.Date, + pl.Object, + pl.Null, + pl.Unknown, +] + + +@pytest.mark.parametrize("dtype", SIMPLE_DTYPES) +def test_simple_dtype_init_takes_no_args(dtype: DataTypeClass) -> None: + with pytest.raises(TypeError): + dtype(10) + + +def test_simple_dtype_init_returns_instance() -> None: + dtype = pl.Int8() + assert isinstance(dtype, pl.Int8) + + +def test_complex_dtype_init_returns_instance() -> None: + dtype = pl.Datetime() + assert isinstance(dtype, pl.Datetime) + assert dtype.time_unit == "us" + + +def test_dtype_time_units() -> None: + # check (in)equality behaviour of temporal types that take units + for time_unit in DTYPE_TEMPORAL_UNITS: + assert pl.Datetime == pl.Datetime(time_unit) + assert pl.Duration == pl.Duration(time_unit) + + assert pl.Datetime(time_unit) == pl.Datetime + assert pl.Duration(time_unit) == pl.Duration + + assert pl.Datetime("ms") != pl.Datetime("ns") + assert pl.Duration("ns") != pl.Duration("us") + + # check timeunit from pytype + assert parse_into_dtype(datetime) == pl.Datetime("us") + assert parse_into_dtype(timedelta) == pl.Duration + + with pytest.raises(ValueError, match="invalid `time_unit`"): + pl.Datetime("?") # type: ignore[arg-type] + + with pytest.raises(ValueError, match="invalid `time_unit`"): + pl.Duration("?") # type: ignore[arg-type] + + +def test_dtype_base_type() -> None: + assert pl.Date.base_type() is pl.Date + assert pl.List(pl.Int32).base_type() is pl.List + assert ( + pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.Boolean)]).base_type() + is pl.Struct + ) + for dtype in DATETIME_DTYPES: + assert dtype.base_type() is pl.Datetime + + +def test_dtype_groups() -> None: + grp = DataTypeGroup([pl.Datetime], match_base_type=False) + assert pl.Datetime("ms", "Asia/Tokyo") not in grp + + grp = DataTypeGroup([pl.Datetime]) + assert pl.Datetime("ms", "Asia/Tokyo") in grp + + +def test_dtypes_picklable() -> None: + parametric_type = pl.Datetime("ns") + singleton_type = pl.Float64 + assert pickle.loads(pickle.dumps(parametric_type)) == parametric_type + assert pickle.loads(pickle.dumps(singleton_type)) == singleton_type + + +def test_dtypes_hashable() -> None: + # ensure that all the types can be hashed, and that their hashes + # are sufficient to ensure distinct entries in a dictionary/set + + all_dtypes = [ + getattr(datatypes, d) + for d in dir(datatypes) + if isinstance(getattr(datatypes, d), datatypes.DataType) + ] + assert len(set(all_dtypes + all_dtypes)) == len(all_dtypes) + assert len({pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")}) == 3 + assert len({pl.List, pl.List(pl.Int16), pl.List(pl.Int32), pl.List(pl.Int64)}) == 4 + + +@pytest.mark.parametrize( + ("dtype", "representation"), + [ + (pl.Boolean, "Boolean"), + (pl.Datetime, "Datetime"), + ( + pl.Datetime(time_zone="Europe/Amsterdam"), + "Datetime(time_unit='us', time_zone='Europe/Amsterdam')", + ), + (pl.List(pl.Int8), "List(Int8)"), + (pl.List(pl.Duration(time_unit="ns")), "List(Duration(time_unit='ns'))"), + (pl.Struct, "Struct"), + ( + pl.Struct({"name": pl.String, "ids": pl.List(pl.UInt32)}), + "Struct({'name': String, 'ids': List(UInt32)})", + ), + ], +) +def test_repr(dtype: PolarsDataType, representation: str) -> None: + assert repr(dtype) == representation + + +@pytest.mark.may_fail_auto_streaming +def test_conversion_dtype() -> None: + df = ( + pl.DataFrame( + { + "id_column": [1, 2, 3, 4], + "some_column": ["a", "b", "c", "d"], + "some_partition_column": [ + "partition_1", + "partition_2", + "partition_1", + "partition_2", + ], + } + ) + .select( + pl.struct( + pl.col("id_column"), pl.col("some_column").cast(pl.Categorical) + ).alias("struct"), + pl.col("some_partition_column"), + ) + .group_by("some_partition_column", maintain_order=True) + .agg("struct") + ) + + result: pl.DataFrame = pl.from_arrow(df.to_arrow()) # type: ignore[assignment] + # the assertion is not the real test + # this tests if dtype has bubbled up correctly in conversion + # if not we would UB + expected = { + "some_partition_column": ["partition_1", "partition_2"], + "struct": [ + [ + {"id_column": 1, "some_column": "a"}, + {"id_column": 3, "some_column": "c"}, + ], + [ + {"id_column": 2, "some_column": "b"}, + {"id_column": 4, "some_column": "d"}, + ], + ], + } + assert result.to_dict(as_series=False) == expected + + +def test_struct_field_iter() -> None: + s = Struct( + [Field("a", List(List(Int64))), Field("b", List(Int64)), Field("c", Int64)] + ) + assert list(s) == [ + ("a", List(List(Int64))), + ("b", List(Int64)), + ("c", Int64), + ] + assert list(reversed(s)) == [ + ("c", Int64), + ("b", List(Int64)), + ("a", List(List(Int64))), + ] + + +def test_raise_invalid_namespace() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.select(pl.lit(1.5).str.replace("1", "2")) + + +@pytest.mark.parametrize( + ("dtype", "lower", "upper"), + [ + (pl.Int8, -128, 127), + (pl.UInt8, 0, 255), + (pl.Int16, -32768, 32767), + (pl.UInt16, 0, 65535), + (pl.Int32, -2147483648, 2147483647), + (pl.UInt32, 0, 4294967295), + (pl.Int64, -9223372036854775808, 9223372036854775807), + (pl.UInt64, 0, 18446744073709551615), + ( + pl.Int128, + -170141183460469231731687303715884105728, + 170141183460469231731687303715884105727, + ), + (pl.Float32, float("-inf"), float("inf")), + (pl.Float64, float("-inf"), float("inf")), + (pl.Time, time(0, 0), time(23, 59, 59, 999999)), + ], +) +def test_max_min( + dtype: datatypes.IntegerType + | datatypes.Float32 + | datatypes.Float64 + | datatypes.Time, + upper: int | float | time, + lower: int | float | time, +) -> None: + df = pl.select(min=dtype.min(), max=dtype.max()) + assert df.to_series(0).item() == lower + assert df.to_series(1).item() == upper diff --git a/py-polars/tests/unit/test_empty.py b/py-polars/tests/unit/test_empty.py new file mode 100644 index 000000000000..6cedcaa59b54 --- /dev/null +++ b/py-polars/tests/unit/test_empty.py @@ -0,0 +1,167 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_empty_str_concat_lit() -> None: + df = pl.DataFrame({"a": [], "b": []}, schema=[("a", pl.String), ("b", pl.String)]) + assert df.with_columns(pl.lit("asd") + pl.col("a")).schema == { + "a": pl.String, + "b": pl.String, + "literal": pl.String, + } + + +def test_empty_cross_join() -> None: + a = pl.LazyFrame(schema={"a": pl.Int32}) + b = pl.LazyFrame(schema={"b": pl.Int32}) + + assert (a.join(b, how="cross").collect()).schema == {"a": pl.Int32, "b": pl.Int32} + + +def test_empty_string_replace() -> None: + s = pl.Series("", [], dtype=pl.String) + assert_series_equal(s.str.replace("a", "b", literal=True), s) + assert_series_equal(s.str.replace("a", "b"), s) + assert_series_equal(s.str.replace("ab", "b", literal=True), s) + assert_series_equal(s.str.replace("ab", "b"), s) + + +def test_empty_window_function() -> None: + expr = (pl.col("VAL") / pl.col("VAL").sum()).over("KEY") + + df = pl.DataFrame(schema={"KEY": pl.String, "VAL": pl.Float64}) + df.select(expr) # ComputeError + + lf = pl.DataFrame(schema={"KEY": pl.String, "VAL": pl.Float64}).lazy() + expected = pl.DataFrame(schema={"VAL": pl.Float64}) + assert_frame_equal(lf.select(expr).collect(), expected) + + +def test_empty_count_window() -> None: + df = pl.DataFrame( + {"ID": [], "DESC": [], "dataset": []}, + schema={"ID": pl.String, "DESC": pl.String, "dataset": pl.String}, + ) + + out = df.select(pl.col("ID").count().over(["ID", "DESC"])) + assert out.schema == {"ID": pl.UInt32} + assert out.height == 0 + + +def test_empty_sort_by_args() -> None: + df = pl.DataFrame({"x": [2, 1, 3]}) + assert_frame_equal(df, df.select(pl.col.x.sort_by([]))) + assert_frame_equal(df, df.sort([])) + + +def test_empty_9137() -> None: + out = ( + pl.DataFrame( + {"id": [], "value": []}, + schema={"id": pl.Float32, "value": pl.Float32}, + ) + .group_by("id") + .agg(pl.col("value").pow(2).mean()) + ) + assert out.shape == (0, 2) + assert out.dtypes == [pl.Float32, pl.Float32] + + +@pytest.mark.parametrize("dtype", [pl.String, pl.Binary, pl.UInt32]) +@pytest.mark.parametrize( + "set_operation", + ["set_intersection", "set_union", "set_difference", "set_symmetric_difference"], +) +def test_empty_df_set_operations(set_operation: str, dtype: pl.DataType) -> None: + expr = getattr(pl.col("list1").list, set_operation)(pl.col("list2")) + df = pl.DataFrame([], {"list1": pl.List(dtype), "list2": pl.List(dtype)}) + assert df.select(expr).is_empty() + + +def test_empty_set_intersection() -> None: + full = pl.Series("full", [[1, 2, 3]], pl.List(pl.UInt32)) + empty = pl.Series("empty", [[]], pl.List(pl.UInt32)) + + assert_series_equal(empty.rename("full"), full.list.set_intersection(empty)) + assert_series_equal(empty, empty.list.set_intersection(full)) + + +def test_empty_set_difference() -> None: + full = pl.Series("full", [[1, 2, 3]], pl.List(pl.UInt32)) + empty = pl.Series("empty", [[]], pl.List(pl.UInt32)) + + assert_series_equal(full, full.list.set_difference(empty)) + assert_series_equal(empty, empty.list.set_difference(full)) + + +def test_empty_set_union() -> None: + full = pl.Series("full", [[1, 2, 3]], pl.List(pl.UInt32)) + empty = pl.Series("empty", [[]], pl.List(pl.UInt32)) + + assert_series_equal(full, full.list.set_union(empty)) + assert_series_equal(full.rename("empty"), empty.list.set_union(full)) + + +def test_empty_set_symmetric_difference() -> None: + full = pl.Series("full", [[1, 2, 3]], pl.List(pl.UInt32)) + empty = pl.Series("empty", [[]], pl.List(pl.UInt32)) + + assert_series_equal(full, full.list.set_symmetric_difference(empty)) + assert_series_equal(full.rename("empty"), empty.list.set_symmetric_difference(full)) + + +@pytest.mark.parametrize("name", ["sort", "unique", "head", "tail", "shift", "reverse"]) +def test_empty_list_namespace_output_9585(name: str) -> None: + dtype = pl.List(pl.String) + df = pl.DataFrame([[None]], schema={"A": dtype}) + + expr = getattr(pl.col("A").list, name)() + result = df.select(expr) + + assert result.dtypes == df.dtypes + + +def test_empty_is_in() -> None: + assert_series_equal( + pl.Series("a", [1, 2, 3]).is_in([]), pl.Series("a", [False] * 3) + ) + + +@pytest.mark.parametrize("method", ["drop_nulls", "unique"]) +def test_empty_to_empty(method: str) -> None: + assert getattr(pl.DataFrame(), method)().shape == (0, 0) + + +def test_empty_shift_over_16676() -> None: + df = pl.DataFrame({"a": [], "b": []}) + assert df.with_columns(pl.col("a").shift(fill_value=0).over("b")).shape == (0, 2) + + +def test_empty_list_cat_16405() -> None: + df = pl.DataFrame(schema={"cat": pl.List(pl.Categorical)}) + df.select(pl.col("cat") == pl.col("cat")) + + +def test_empty_list_concat_16924() -> None: + df = pl.DataFrame(schema={"a": pl.Int16, "b": pl.List(pl.String)}) + df.with_columns(pl.col("b").list.concat([pl.col("a").cast(pl.String)])) + + +def test_empty_input_expansion() -> None: + df = pl.DataFrame({"A": [1], "B": [2]}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + ( + df.select("A", "B").with_columns( + pl.col("B").sort_by(pl.struct(pl.exclude("A", "B"))) + ) + ) + + +def test_empty_list_15523() -> None: + s = pl.Series("", [["a"], []], dtype=pl.List) + assert s.dtype == pl.List(pl.String) + s = pl.Series("", [[], ["a"]], dtype=pl.List) + assert s.dtype == pl.List(pl.String) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py new file mode 100644 index 000000000000..7ba4113b92bc --- /dev/null +++ b/py-polars/tests/unit/test_errors.py @@ -0,0 +1,743 @@ +from __future__ import annotations + +import io +from datetime import date, datetime, time, tzinfo +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.datatypes.convert import dtype_to_py_type +from polars.exceptions import ( + ColumnNotFoundError, + ComputeError, + InvalidOperationError, + OutOfBoundsError, + SchemaError, + SchemaFieldNotFoundError, + ShapeError, + StructFieldNotFoundError, +) +from tests.unit.conftest import TEMPORAL_DTYPES + +if TYPE_CHECKING: + from polars._typing import ConcatMethod + + +def test_error_on_empty_group_by() -> None: + with pytest.raises( + ComputeError, match="at least one key is required in a group_by operation" + ): + pl.DataFrame({"x": [0, 0, 1, 1]}).group_by([]).agg(pl.len()) + + +def test_error_on_reducing_map() -> None: + df = pl.DataFrame( + {"id": [0, 0, 0, 1, 1, 1], "t": [2, 4, 5, 10, 11, 14], "y": [0, 1, 1, 2, 3, 4]} + ) + with pytest.raises( + InvalidOperationError, + match=( + r"output length of `map` \(1\) must be equal to " + r"the input length \(6\); consider using `apply` instead" + ), + ): + df.group_by("id").agg(pl.map_batches(["t", "y"], np.mean)) + + df = pl.DataFrame({"x": [1, 2, 3, 4], "group": [1, 2, 1, 2]}) + with pytest.raises( + InvalidOperationError, + match=( + r"output length of `map` \(1\) must be equal to " + r"the input length \(4\); consider using `apply` instead" + ), + ): + df.select( + pl.col("x") + .map_batches( + lambda x: x.cut(breaks=[1, 2, 3], include_breaks=True).struct.unnest(), + is_elementwise=True, + ) + .over("group") + ) + + +def test_error_on_invalid_by_in_asof_join() -> None: + df1 = pl.DataFrame( + { + "a": ["a", "b", "a"], + "b": [1, 2, 3], + "c": ["a", "b", "a"], + } + ).set_sorted("b") + + df2 = df1.with_columns(pl.col("a").cast(pl.Categorical)) + with pytest.raises(ComputeError): + df1.join_asof(df2, on="b", by=["a", "c"]) + + +@pytest.mark.parametrize("dtype", TEMPORAL_DTYPES) +def test_error_on_invalid_series_init(dtype: pl.DataType) -> None: + py_type = dtype_to_py_type(dtype) + with pytest.raises( + TypeError, + match=f"'float' object cannot be interpreted as a {py_type.__name__!r}", + ): + pl.Series([1.5, 2.0, 3.75], dtype=dtype) + + +def test_error_on_invalid_series_init2() -> None: + with pytest.raises(TypeError, match="unexpected value"): + pl.Series([1.5, 2.0, 3.75], dtype=pl.Int32) + + +def test_error_on_invalid_struct_field() -> None: + with pytest.raises(StructFieldNotFoundError): + pl.struct( + [pl.Series("a", [1, 2]), pl.Series("b", ["a", "b"])], eager=True + ).struct.field("z") + + +def test_not_found_error() -> None: + csv = "a,b,c\n2,1,1" + df = pl.read_csv(io.StringIO(csv)) + with pytest.raises(ColumnNotFoundError): + df.select("d") + + +def test_string_numeric_comp_err() -> None: + with pytest.raises(ComputeError, match="cannot compare string with numeric type"): + pl.DataFrame({"a": [1.1, 21, 31, 21, 51, 61, 71, 81]}).select(pl.col("a") < "9") + + +def test_panic_error() -> None: + with pytest.raises( + InvalidOperationError, + match="unit: 'k' not supported", + ): + pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="99k", + eager=True, + ) + + +def test_join_lazy_on_df() -> None: + df_left = pl.DataFrame( + { + "Id": [1, 2, 3, 4], + "Names": ["A", "B", "C", "D"], + } + ) + df_right = pl.DataFrame({"Id": [1, 3], "Tags": ["xxx", "yyy"]}) + + with pytest.raises( + TypeError, + match="expected `other` .* to be a LazyFrame.* not 'DataFrame'", + ): + df_left.lazy().join(df_right, on="Id") # type: ignore[arg-type] + + with pytest.raises( + TypeError, + match="expected `other` .* to be a LazyFrame.* not 'DataFrame'", + ): + df_left.lazy().join_asof(df_right, on="Id") # type: ignore[arg-type] + + with pytest.raises( + TypeError, + match="expected `other` .* to be a LazyFrame.* not 'pandas.core.frame.DataFrame'", + ): + df_left.lazy().join_asof(df_right.to_pandas(), on="Id") # type: ignore[arg-type] + + +def test_projection_update_schema_missing_column() -> None: + with pytest.raises( + ColumnNotFoundError, + match='unable to find column "colC"', + ): + ( + pl.DataFrame({"colA": ["a", "b", "c"], "colB": [1, 2, 3]}) + .lazy() + .filter(~pl.col("colC").is_null()) + .group_by(["colA"]) + .agg([pl.col("colB").sum().alias("result")]) + .collect() + ) + + +def test_not_found_on_rename() -> None: + df = pl.DataFrame({"exists": [1, 2, 3]}) + + err_type = (SchemaFieldNotFoundError, ColumnNotFoundError) + with pytest.raises(err_type): + df.rename({"does_not_exist": "exists"}) + + with pytest.raises(err_type): + df.select(pl.col("does_not_exist").alias("new_name")) + + +def test_getitem_errs() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + with pytest.raises( + TypeError, + match=r"cannot select columns using key of type 'set': {'some'}", + ): + df[{"some"}] # type: ignore[call-overload] + + with pytest.raises( + TypeError, + match=r"cannot select elements using key of type 'set': {'strange'}", + ): + df["a"][{"strange"}] # type: ignore[call-overload] + + with pytest.raises( + TypeError, + match=r"cannot use `__setitem__` on DataFrame with key {'some'} of type 'set' and value 'foo' of type 'str'", + ): + df[{"some"}] = "foo" # type: ignore[index] + + +def test_err_bubbling_up_to_lit() -> None: + df = pl.DataFrame({"date": [date(2020, 1, 1)], "value": [42]}) + + with pytest.raises(TypeError): + df.filter(pl.col("date") == pl.Date("2020-01-01")) # type: ignore[call-arg,operator] + + +def test_error_on_double_agg() -> None: + for e in [ + "mean", + "max", + "min", + "sum", + "std", + "var", + "n_unique", + "last", + "first", + "median", + "skew", # this one is comes from Apply + ]: + with pytest.raises(ComputeError, match="the column is already aggregated"): + ( + pl.DataFrame( + { + "a": [1, 1, 1, 2, 2], + "b": [1, 2, 3, 4, 5], + } + ) + .group_by("a") + .agg([getattr(pl.col("b").min(), e)()]) + ) + + +def test_filter_not_of_type_bool() -> None: + df = pl.DataFrame({"json_val": ['{"a":"hello"}', None, '{"a":"world"}']}) + with pytest.raises( + InvalidOperationError, match="filter predicate must be of type `Boolean`, got" + ): + df.filter(pl.col("json_val").str.json_path_match("$.a")) + + +def test_is_nan_on_non_boolean() -> None: + with pytest.raises(InvalidOperationError): + pl.Series(["1", "2", "3"]).fill_nan("2") # type: ignore[arg-type] + + +def test_window_expression_different_group_length() -> None: + try: + pl.DataFrame({"groups": ["a", "a", "b", "a", "b"]}).select( + pl.col("groups").map_elements(lambda _: pl.Series([1, 2])).over("groups") + ) + except ComputeError as exc: + msg = str(exc) + assert ( + "the length of the window expression did not match that of the group" in msg + ) + assert "group:" in msg + assert "group length:" in msg + assert "output: 'shape:" in msg + + +def test_invalid_concat_type_err() -> None: + df = pl.DataFrame( + { + "foo": [1, 2], + "bar": [6, 7], + "ham": ["a", "b"], + } + ) + with pytest.raises( + ValueError, + match="DataFrame `how` must be one of {'vertical', '.+', 'align_right'}, got 'sausage'", + ): + pl.concat([df, df], how="sausage") # type: ignore[arg-type] + + +@pytest.mark.parametrize("how", ["horizontal", "diagonal"]) +def test_series_concat_err(how: ConcatMethod) -> None: + s = pl.Series([1, 2, 3]) + with pytest.raises( + ValueError, + match="Series only supports 'vertical' concat strategy", + ): + pl.concat([s, s], how=how) + + +def test_invalid_sort_by() -> None: + df = pl.DataFrame( + { + "a": ["bill", "bob", "jen", "allie", "george"], + "b": ["M", "M", "F", "F", "M"], + "c": [32, 40, 20, 19, 39], + } + ) + + # `select a where b order by c desc` + with pytest.raises(ShapeError): + df.select(pl.col("a").filter(pl.col("b") == "M").sort_by("c", descending=True)) + + +def test_epoch_time_type() -> None: + with pytest.raises( + InvalidOperationError, + match="`timestamp` operation not supported for dtype `time`", + ): + pl.Series([time(0, 0, 1)]).dt.epoch("s") + + +def test_duplicate_columns_arg_csv() -> None: + f = io.BytesIO() + pl.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}).write_csv(f) + f.seek(0) + with pytest.raises( + ValueError, match=r"`columns` arg should only have unique values" + ): + pl.read_csv(f, columns=["x", "x", "y"]) + + +def test_datetime_time_add_err() -> None: + with pytest.raises(SchemaError, match="failed to determine supertype"): + pl.Series([datetime(1970, 1, 1, 0, 0, 1)]) + pl.Series([time(0, 0, 2)]) + + +def test_invalid_dtype() -> None: + with pytest.raises( + TypeError, + match=r"cannot parse input of type 'str' into Polars data type \(given: 'mayonnaise'\)", + ): + pl.Series([1, 2], dtype="mayonnaise") # type: ignore[arg-type] + + with pytest.raises( + TypeError, + match="cannot parse input into Polars data type", + ): + pl.Series([None], dtype=tzinfo) # type: ignore[arg-type] + + +def test_arr_eval_named_cols() -> None: + df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]}) + + with pytest.raises( + ComputeError, + ): + df.select(pl.col("B").list.eval(pl.element().append(pl.col("A")))) + + +def test_alias_in_join_keys() -> None: + df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]}) + with pytest.raises( + InvalidOperationError, + match=r"'alias' is not allowed in a join key, use 'with_columns' first", + ): + df.join(df, on=pl.col("A").alias("foo")) + + +def test_sort_by_different_lengths() -> None: + df = pl.DataFrame( + { + "group": ["a"] * 3 + ["b"] * 3, + "col1": [1, 2, 3, 300, 200, 100], + "col2": [1, 2, 3, 300, 1, 1], + } + ) + with pytest.raises( + ComputeError, + match=r"the expression in `sort_by` argument must result in the same length", + ): + df.group_by("group").agg( + [ + pl.col("col1").sort_by(pl.col("col2").unique()), + ] + ) + + with pytest.raises( + ComputeError, + match=r"the expression in `sort_by` argument must result in the same length", + ): + df.group_by("group").agg( + [ + pl.col("col1").sort_by(pl.col("col2").arg_unique()), + ] + ) + + +def test_err_filter_no_expansion() -> None: + # df contains floats + df = pl.DataFrame( + { + "a": [0.1, 0.2], + } + ) + + with pytest.raises( + ComputeError, match=r"The predicate expanded to zero expressions" + ): + # we filter by ints + df.filter(pl.col(pl.Int16).min() < 0.1) + + +@pytest.mark.parametrize( + ("e"), + [ + pl.col("date") > "2021-11-10", + pl.col("date") < "2021-11-10", + ], +) +def test_date_string_comparison(e: pl.Expr) -> None: + df = pl.DataFrame( + { + "date": [ + "2022-11-01", + "2022-11-02", + "2022-11-05", + ], + } + ).with_columns(pl.col("date").str.strptime(pl.Date, "%Y-%m-%d")) + + with pytest.raises( + InvalidOperationError, + match=r"cannot compare 'date/datetime/time' to a string value", + ): + df.select(e) + + +def test_err_on_multiple_column_expansion() -> None: + # this would be a great feature :) + with pytest.raises( + ComputeError, match=r"expanding more than one `col` is not allowed" + ): + pl.DataFrame( + { + "a": [1], + "b": [2], + "c": [3], + "d": [4], + } + ).select([pl.col(["a", "b"]) + pl.col(["c", "d"])]) + + +def test_compare_different_len() -> None: + df = pl.DataFrame( + { + "idx": list(range(5)), + } + ) + + s = pl.Series([2, 5, 8]) + with pytest.raises(ShapeError): + df.filter(pl.col("idx") == s) + + +def test_take_negative_index_is_oob() -> None: + df = pl.DataFrame({"value": [1, 2, 3]}) + with pytest.raises(OutOfBoundsError): + df["value"].gather(-4) + + +def test_string_numeric_arithmetic_err() -> None: + df = pl.DataFrame({"s": ["x"]}) + with pytest.raises( + InvalidOperationError, match=r"arithmetic on string and numeric not allowed" + ): + df.select(pl.col("s") + 1) + + +def test_ambiguous_filter_err() -> None: + df = pl.DataFrame({"a": [None, "2", "3"], "b": [None, None, "z"]}) + with pytest.raises( + ComputeError, + match=r"The predicate passed to 'LazyFrame.filter' expanded to multiple expressions", + ): + df.filter(pl.col(["a", "b"]).is_null()) + + +def test_with_column_duplicates() -> None: + df = pl.DataFrame({"a": [0, None, 2, 3, None], "b": [None, 1, 2, 3, None]}) + with pytest.raises( + ComputeError, + match=r"the name 'same' passed to `LazyFrame.with_columns` is duplicate.*", + ): + assert df.with_columns([pl.all().alias("same")]).columns == ["a", "b", "same"] + + +def test_skip_nulls_err() -> None: + df = pl.DataFrame({"foo": [None, None]}) + + with pytest.raises( + ComputeError, + match=r"The output type of the 'map_elements' function cannot be determined", + ): + df.with_columns(pl.col("foo").map_elements(lambda x: x, skip_nulls=True)) + + +@pytest.mark.parametrize( + ("test_df", "type", "expected_message"), + [ + pytest.param( + pl.DataFrame({"A": [1, 2, 3], "B": ["1", "2", "help"]}), + pl.UInt32, + "conversion .* failed", + id="Unsigned integer", + ) + ], +) +def test_cast_err_column_value_highlighting( + test_df: pl.DataFrame, type: pl.DataType, expected_message: str +) -> None: + with pytest.raises(InvalidOperationError, match=expected_message): + test_df.with_columns(pl.all().cast(type)) + + +def test_lit_agg_err() -> None: + with pytest.raises(ComputeError, match=r"cannot aggregate a literal"): + pl.DataFrame({"y": [1]}).with_columns(pl.lit(1).sum().over("y")) + + +def test_invalid_group_by_arg() -> None: + df = pl.DataFrame({"a": [1]}) + with pytest.raises( + TypeError, match="specifying aggregations as a dictionary is not supported" + ): + df.group_by(1).agg({"a": "sum"}) + + +def test_overflow_msg() -> None: + with pytest.raises( + ComputeError, + match=r"could not append value: 2147483648 of type: i64 to the builder", + ): + pl.DataFrame([[2**31]], [("a", pl.Int32)], orient="row") + + +def test_sort_by_err_9259() -> None: + df = pl.DataFrame( + {"a": [1, 1, 1], "b": [3, 2, 1], "c": [1, 1, 2]}, + schema={"a": pl.Float32, "b": pl.Float32, "c": pl.Float32}, + ) + with pytest.raises(ComputeError): + df.lazy().group_by("c").agg( + [pl.col("a").sort_by(pl.col("b").filter(pl.col("b") > 100)).sum()] + ).collect() + + +def test_empty_inputs_error() -> None: + df = pl.DataFrame({"col1": [1]}) + with pytest.raises( + pl.exceptions.InvalidOperationError, match="expected at least 1 input" + ): + df.select(pl.sum_horizontal(pl.exclude("col1"))) + + +@pytest.mark.parametrize( + ("colname", "values", "expected"), + [ + ("a", [2], [False, True, False]), + ("a", [True, False], None), + ("a", ["2", "3", "4"], None), + ("b", [Decimal("3.14")], None), + ("c", [-2, -1, 0, 1, 2], None), + ( + "d", + pl.datetime_range( + datetime.now(), + datetime.now(), + interval="2345ns", + time_unit="ns", + eager=True, + ), + None, + ), + ("d", [time(10, 30)], None), + ("e", [datetime(1999, 12, 31, 10, 30)], None), + ("f", ["xx", "zz"], None), + ], +) +def test_invalid_is_in_dtypes( + colname: str, values: list[Any], expected: list[Any] | None +) -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [-2.5, 0.0, 2.5], + "c": [True, None, False], + "d": [datetime(2001, 10, 30), None, datetime(2009, 7, 5)], + "e": [date(2029, 12, 31), date(1999, 12, 31), None], + "f": [b"xx", b"yy", b"zz"], + } + ) + if expected is None: + with pytest.raises( + InvalidOperationError, + match="'is_in' cannot check for .*? values in .*? data", + ): + df.select(pl.col(colname).is_in(values)) + else: + assert df.select(pl.col(colname).is_in(values))[colname].to_list() == expected + + +def test_sort_by_error() -> None: + df = pl.DataFrame( + { + "id": [1, 1, 1, 2, 2, 3, 3, 3], + "number": [1, 3, 2, 1, 2, 2, 1, 3], + "type": ["A", "B", "A", "B", "B", "A", "B", "C"], + "cost": [10, 25, 20, 25, 30, 30, 50, 100], + } + ) + + with pytest.raises( + ComputeError, + match="expressions in 'sort_by' produced a different number of groups", + ): + df.group_by("id", maintain_order=True).agg( + pl.col("cost").filter(pl.col("type") == "A").sort_by("number") + ) + + +def test_non_existent_expr_inputs_in_lazy() -> None: + with pytest.raises(ColumnNotFoundError): + pl.LazyFrame().filter(pl.col("x") == 1).explain() # tests: 12074 + + lf = pl.LazyFrame({"foo": [1, 1, -2, 3]}) + + with pytest.raises(ColumnNotFoundError): + ( + lf.select(pl.col("foo").cum_sum().alias("bar")) + .filter(pl.col("bar") == pl.col("foo")) + .explain() + ) + + +def test_error_list_to_array() -> None: + with pytest.raises(ComputeError, match="not all elements have the specified width"): + pl.DataFrame( + data={"a": [[1, 2], [3, 4, 5]]}, schema={"a": pl.List(pl.Int8)} + ).with_columns(array=pl.col("a").list.to_array(2)) + + +def test_raise_not_found_in_simplify_14974() -> None: + df = pl.DataFrame() + with pytest.raises(ColumnNotFoundError): + df.select(1 / (1 + pl.col("a"))) + + +def test_invalid_product_type() -> None: + with pytest.raises( + InvalidOperationError, + match="`product` operation not supported for dtype", + ): + pl.Series([[1, 2, 3]]).product() + + +def test_fill_null_invalid_supertype() -> None: + df = pl.DataFrame({"date": [date(2022, 1, 1), None]}) + with pytest.raises(InvalidOperationError, match="got invalid or ambiguous"): + df.select(pl.col("date").fill_null(1.0)) + + +def test_raise_array_of_cats() -> None: + with pytest.raises(InvalidOperationError, match="is not yet supported"): + pl.Series([["a", "b"], ["a", "c"]], dtype=pl.Array(pl.Categorical, 2)) + + +def test_raise_invalid_arithmetic() -> None: + df = pl.Series("a", [object()]).to_frame() + + with pytest.raises(InvalidOperationError): + df.select(pl.col("a") - pl.col("a")) + + +def test_raise_on_sorted_multi_args() -> None: + with pytest.raises(TypeError): + pl.DataFrame({"a": [1], "b": [1]}).set_sorted( + ["a", "b"] # type: ignore[arg-type] + ) + + +def test_err_invalid_comparison() -> None: + with pytest.raises( + SchemaError, + match="could not evaluate comparison between series 'a' of dtype: date and series 'b' of dtype: bool", + ): + _ = pl.Series("a", [date(2020, 1, 1)]) == pl.Series("b", [True]) + + with pytest.raises( + InvalidOperationError, + match="could not apply comparison on series of dtype 'object; operand names: 'a', 'b'", + ): + _ = pl.Series("a", [object()]) == pl.Series("b", [object]) + + +def test_no_panic_pandas_nat() -> None: + # we don't want to support pd.nat, but don't want to panic. + with pytest.raises(Exception): # noqa: B017 + pl.DataFrame({"x": [pd.NaT]}) + + +def test_list_to_struct_invalid_type() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.DataFrame({"a": 1}).to_series().list.to_struct() + + +def test_raise_invalid_agg() -> None: + with pytest.raises(pl.exceptions.ColumnNotFoundError): + ( + pl.LazyFrame({"foo": [1]}) + .with_row_index() + .group_by("index") + .agg(pl.col("foo").filter(pl.col("i_do_not_exist"))) + ).collect() + + +def test_err_mean_horizontal_lists() -> None: + df = pl.DataFrame( + { + "experiment_id": [1, 2], + "sensor1": [[1, 2, 3], [7, 8, 9]], + "sensor2": [[4, 5, 6], [10, 11, 12]], + } + ) + with pytest.raises(pl.exceptions.InvalidOperationError): + df.with_columns(pl.mean_horizontal("sensor1", "sensor2").alias("avg_sensor")) + + +def test_raise_column_not_found_in_join_arg() -> None: + a = pl.DataFrame({"x": [1, 2, 3]}) + b = pl.DataFrame({"y": [1, 2, 3]}) + with pytest.raises(pl.exceptions.ColumnNotFoundError): + a.join(b, on="y") + + +def test_raise_on_different_results_20104() -> None: + df = pl.DataFrame({"x": [1, 2]}) + + with pytest.raises(pl.exceptions.SchemaError): + df.rolling("x", period="3i").agg( + result=pl.col("x").gather_every(2, offset=1).map_batches(pl.Series.min) + ) diff --git a/py-polars/tests/unit/test_exceptions.py b/py-polars/tests/unit/test_exceptions.py new file mode 100644 index 000000000000..d94a81d92aa7 --- /dev/null +++ b/py-polars/tests/unit/test_exceptions.py @@ -0,0 +1,34 @@ +import pytest + +from polars.exceptions import ( + CategoricalRemappingWarning, + ComputeError, + CustomUFuncWarning, + MapWithoutReturnDtypeWarning, + OutOfBoundsError, + PerformanceWarning, + PolarsError, + PolarsInefficientMapWarning, + PolarsWarning, +) + + +def test_polars_error_base_class() -> None: + msg = "msg" + assert isinstance(ComputeError(msg), PolarsError) + with pytest.raises(PolarsError, match=msg): + raise OutOfBoundsError(msg) + + +def test_polars_warning_base_class() -> None: + msg = "msg" + assert isinstance(MapWithoutReturnDtypeWarning(msg), PolarsWarning) + with pytest.raises(PolarsWarning, match=msg): + raise CustomUFuncWarning(msg) + + +def test_performance_warning_base_class() -> None: + msg = "msg" + assert isinstance(PolarsInefficientMapWarning(msg), PerformanceWarning) + with pytest.raises(PerformanceWarning, match=msg): + raise CategoricalRemappingWarning(msg) diff --git a/py-polars/tests/unit/test_expansion.py b/py-polars/tests/unit/test_expansion.py new file mode 100644 index 000000000000..b01fc625f5cc --- /dev/null +++ b/py-polars/tests/unit/test_expansion.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal +from tests.unit.conftest import NUMERIC_DTYPES + + +def test_regex_exclude() -> None: + df = pl.DataFrame({f"col_{i}": [i] for i in range(5)}) + + assert df.select(pl.col("^col_.*$").exclude("col_0")).columns == [ + "col_1", + "col_2", + "col_3", + "col_4", + ] + + +def test_regex_in_filter() -> None: + df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", None], + "flt": [1.0, None, 3.0, 1.0, None], + } + ) + + res = df.filter( + pl.fold( + acc=False, function=lambda acc, s: acc | s, exprs=(pl.col("^nrs|flt*$") < 3) + ) + ).row(0) + expected = (1, "foo", 1.0) + assert res == expected + + +def test_regex_selection() -> None: + lf = pl.LazyFrame( + { + "foo": [1], + "fooey": [1], + "foobar": [1], + "bar": [1], + } + ) + result = lf.select([pl.col("^foo.*$")]) + assert result.collect_schema().names() == ["foo", "fooey", "foobar"] + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (pl.exclude("a"), ["b", "c"]), + (pl.all().exclude(pl.Boolean), ["a", "b"]), + (pl.all().exclude([pl.Boolean]), ["a", "b"]), + (pl.all().exclude(NUMERIC_DTYPES), ["c"]), + ], +) +def test_exclude_selection(expr: pl.Expr, expected: list[str]) -> None: + lf = pl.LazyFrame({"a": [1], "b": [1], "c": [True]}) + + assert lf.select(expr).collect_schema().names() == expected + + +def test_struct_name_resolving_15430() -> None: + q = pl.LazyFrame([{"a": {"b": "c"}}]) + a = ( + q.with_columns(pl.col("a").struct.field("b")) + .drop("a") + .collect(projection_pushdown=True) + ) + + b = ( + q.with_columns(pl.col("a").struct[0]) + .drop("a") + .collect(projection_pushdown=True) + ) + + assert a["b"].item() == "c" + assert b["b"].item() == "c" + assert a.columns == ["b"] + assert b.columns == ["b"] + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (pl.all().name.prefix("agg_"), ["A", "agg_B", "agg_C"]), + (pl.col("B", "C").name.prefix("agg_"), ["A", "agg_B", "agg_C"]), + (pl.col("A", "C").name.prefix("agg_"), ["A", "agg_A", "agg_C"]), + ], +) +def test_exclude_keys_in_aggregation_16170(expr: pl.Expr, expected: list[str]) -> None: + df = pl.DataFrame({"A": [4, 4, 3], "B": [1, 2, 3], "C": [5, 6, 7]}) + + # wildcard excludes aggregation column + result = df.lazy().group_by("A").agg(expr) + assert result.collect_schema().names() == expected + + +@pytest.mark.parametrize( + "field", + [ + ["aaa", "ccc"], + [["aaa", "ccc"]], + [["aaa"], "ccc"], + [["^aa.+|cc.+$"]], + ], +) +def test_struct_field_expand(field: Any) -> None: + df = pl.DataFrame( + { + "aaa": [1, 2], + "bbb": ["ab", "cd"], + "ccc": [True, None], + "ddd": [[1, 2], [3]], + } + ) + struct_df = df.select(pl.struct(["aaa", "bbb", "ccc", "ddd"]).alias("struct_col")) + res_df = struct_df.select(pl.col("struct_col").struct.field(*field)) + assert_frame_equal(res_df, df.select("aaa", "ccc")) + + +def test_struct_field_expand_star() -> None: + df = pl.DataFrame( + { + "aaa": [1, 2], + "bbb": ["ab", "cd"], + "ccc": [True, None], + "ddd": [[1, 2], [3]], + } + ) + struct_df = df.select(pl.struct(["aaa", "bbb", "ccc", "ddd"]).alias("struct_col")) + assert_frame_equal(struct_df.select(pl.col("struct_col").struct.field("*")), df) + + +def test_struct_unnest() -> None: + """Same as test_struct_field_expand_star but using the unnest alias.""" + df = pl.DataFrame( + { + "aaa": [1, 2], + "bbb": ["ab", "cd"], + "ccc": [True, None], + "ddd": [[1, 2], [3]], + } + ) + struct_df = df.select(pl.struct(["aaa", "bbb", "ccc", "ddd"]).alias("struct_col")) + assert_frame_equal(struct_df.select(pl.col("struct_col").struct.unnest()), df) + + +def test_struct_field_expand_rewrite() -> None: + df = pl.DataFrame({"A": [1], "B": [2]}) + assert df.select( + pl.struct(["A", "B"]).struct.field("*").name.prefix("foo_") + ).to_dict(as_series=False) == {"foo_A": [1], "foo_B": [2]} + + +def test_struct_field_expansion_16410() -> None: + q = pl.LazyFrame({"coords": [{"x": 4, "y": 4}]}) + + assert q.with_columns( + pl.col("coords").struct.with_fields(pl.field("x").sqrt()).struct.field("*") + ).collect().to_dict(as_series=False) == { + "coords": [{"x": 4, "y": 4}], + "x": [2.0], + "y": [4], + } + + +def test_field_and_column_expansion() -> None: + df = pl.DataFrame({"a": [{"x": 1, "y": 2}], "b": [{"i": 3, "j": 4}]}) + + assert df.select(pl.col("a", "b").struct.field("*")).to_dict(as_series=False) == { + "x": [1], + "y": [2], + "i": [3], + "j": [4], + } + + +def test_struct_field_exclude_and_wildcard_expansion() -> None: + df = pl.DataFrame({"a": [{"x": 1, "y": 2}], "b": [{"i": 3, "j": 4}]}) + + assert df.select(pl.exclude("foo").struct.field("*")).to_dict(as_series=False) == { + "x": [1], + "y": [2], + "i": [3], + "j": [4], + } + assert df.select(pl.all().struct.field("*")).to_dict(as_series=False) == { + "x": [1], + "y": [2], + "i": [3], + "j": [4], + } diff --git a/py-polars/tests/unit/test_expr_multi_cols.py b/py-polars/tests/unit/test_expr_multi_cols.py new file mode 100644 index 000000000000..9cd259122940 --- /dev/null +++ b/py-polars/tests/unit/test_expr_multi_cols.py @@ -0,0 +1,114 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_exclude_name_from_dtypes() -> None: + df = pl.DataFrame({"a": ["a"], "b": ["b"]}) + + assert_frame_equal( + df.with_columns(pl.col(pl.String).exclude("a").name.suffix("_foo")), + pl.DataFrame({"a": ["a"], "b": ["b"], "b_foo": ["b"]}), + ) + + +def test_fold_regex_expand() -> None: + df = pl.DataFrame( + { + "x": [0, 1, 2], + "y_1": [1.1, 2.2, 3.3], + "y_2": [1.0, 2.5, 3.5], + } + ) + assert df.with_columns( + pl.fold( + acc=pl.lit(0), function=lambda acc, x: acc + x, exprs=pl.col("^y_.*$") + ).alias("y_sum"), + ).to_dict(as_series=False) == { + "x": [0, 1, 2], + "y_1": [1.1, 2.2, 3.3], + "y_2": [1.0, 2.5, 3.5], + "y_sum": [2.1, 4.7, 6.8], + } + + +def test_arg_sort_argument_expansion() -> None: + df = pl.DataFrame( + { + "col1": [1, 2, 3], + "col2": [4, 5, 6], + "sort_order": [9, 8, 7], + } + ) + assert df.select( + pl.col("col1").sort_by(pl.col("sort_order").arg_sort()).name.suffix("_suffix") + ).to_dict(as_series=False) == {"col1_suffix": [3, 2, 1]} + assert df.select( + pl.col("^col.*$").sort_by(pl.col("sort_order")).arg_sort() + ).to_dict(as_series=False) == {"col1": [2, 1, 0], "col2": [2, 1, 0]} + assert df.select( + pl.all().exclude("sort_order").sort_by(pl.col("sort_order")).arg_sort() + ).to_dict(as_series=False) == {"col1": [2, 1, 0], "col2": [2, 1, 0]} + + +def test_multiple_columns_length_9137() -> None: + df = pl.DataFrame( + { + "a": [1, 1], + "b": ["c", "d"], + } + ) + + # list is larger than groups + cmp_list = ["a", "b", "c"] + + assert df.group_by("a").agg(pl.col("b").is_in(cmp_list)).to_dict( + as_series=False + ) == { + "a": [1], + "b": [[True, False]], + } + + +def test_regex_in_cols() -> None: + df = pl.DataFrame( + { + "col1": [1, 2, 3], + "col2": [4, 5, 6], + "val1": ["a", "b", "c"], + "val2": ["A", "B", "C"], + } + ) + + assert df.select(pl.col("^col.*$").name.prefix("matched_")).to_dict( + as_series=False + ) == { + "matched_col1": [1, 2, 3], + "matched_col2": [4, 5, 6], + } + + assert df.with_columns( + pl.col("^col.*$", "^val.*$").name.prefix("matched_") + ).to_dict(as_series=False) == { + "col1": [1, 2, 3], + "col2": [4, 5, 6], + "val1": ["a", "b", "c"], + "val2": ["A", "B", "C"], + "matched_col1": [1, 2, 3], + "matched_col2": [4, 5, 6], + "matched_val1": ["a", "b", "c"], + "matched_val2": ["A", "B", "C"], + } + assert df.select(pl.col("^col.*$", "val1").name.prefix("matched_")).to_dict( + as_series=False + ) == { + "matched_col1": [1, 2, 3], + "matched_col2": [4, 5, 6], + "matched_val1": ["a", "b", "c"], + } + + assert df.select(pl.col("^col.*$", "val1").exclude("col2")).to_dict( + as_series=False + ) == { + "col1": [1, 2, 3], + "val1": ["a", "b", "c"], + } diff --git a/py-polars/tests/unit/test_format.py b/py-polars/tests/unit/test_format.py new file mode 100644 index 000000000000..98593f9fd948 --- /dev/null +++ b/py-polars/tests/unit/test_format.py @@ -0,0 +1,513 @@ +from __future__ import annotations + +import string +from decimal import Decimal as D +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.exceptions import InvalidOperationError + +if TYPE_CHECKING: + from collections.abc import Iterator + + from polars._typing import PolarsDataType + + +@pytest.fixture(autouse=True) +def _environ() -> Iterator[None]: + """Fixture to ensure we run with default Config settings during tests.""" + with pl.Config(restore_defaults=True): + yield + + +@pytest.mark.parametrize( + ("expected", "values"), + [ + pytest.param( + """shape: (1,) +Series: 'foo' [str] +[ + "Somelongstringt… +] +""", + ["Somelongstringto eeat wit me oundaf"], + id="Long string", + ), + pytest.param( + """shape: (1,) +Series: 'foo' [str] +[ + "😀😁😂😃😄😅😆😇😈😉😊😋😌😎😏… +] +""", + ["😀😁😂😃😄😅😆😇😈😉😊😋😌😎😏😐😑😒😓"], + id="Emojis", + ), + pytest.param( + """shape: (1,) +Series: 'foo' [str] +[ + "yzäöüäöüäöüäö" +] +""", + ["yzäöüäöüäöüäö"], + id="Characters with accents", + ), + pytest.param( + """shape: (100,) +Series: 'foo' [i64] +[ + 0 + 1 + 2 + 3 + 4 + … + 95 + 96 + 97 + 98 + 99 +] +""", + [*range(100)], + id="Long series", + ), + ], +) +def test_fmt_series( + capfd: pytest.CaptureFixture[str], expected: str, values: list[Any] +) -> None: + s = pl.Series(name="foo", values=values) + with pl.Config(fmt_str_lengths=15): + print(s) + out, _err = capfd.readouterr() + assert out == expected + + +def test_fmt_series_string_truncate_default(capfd: pytest.CaptureFixture[str]) -> None: + values = [ + string.ascii_lowercase + "123", + string.ascii_lowercase + "1234", + string.ascii_lowercase + "12345", + ] + s = pl.Series(name="foo", values=values) + print(s) + out, _ = capfd.readouterr() + expected = """shape: (3,) +Series: 'foo' [str] +[ + "abcdefghijklmnopqrstuvwxyz123" + "abcdefghijklmnopqrstuvwxyz1234" + "abcdefghijklmnopqrstuvwxyz1234… +] +""" + assert out == expected + + +@pytest.mark.parametrize( + "dtype", [pl.String, pl.Categorical, pl.Enum(["abc", "abcd", "abcde"])] +) +def test_fmt_series_string_truncate_cat( + dtype: PolarsDataType, capfd: pytest.CaptureFixture[str] +) -> None: + s = pl.Series(name="foo", values=["abc", "abcd", "abcde"], dtype=dtype) + with pl.Config(fmt_str_lengths=4): + print(s) + out, _ = capfd.readouterr() + result = [s.strip() for s in out.split("\n")[3:6]] + expected = ['"abc"', '"abcd"', '"abcd…'] + print(result) + assert result == expected + + +@pytest.mark.parametrize( + ("values", "dtype", "expected"), + [ + ( + [-127, -1, 0, 1, 127], + pl.Int8, + """shape: (5,) +Series: 'foo' [i8] +[ + -127 + -1 + 0 + 1 + 127 +]""", + ), + ( + [-32768, -1, 0, 1, 32767], + pl.Int16, + """shape: (5,) +Series: 'foo' [i16] +[ + -32,768 + -1 + 0 + 1 + 32,767 +]""", + ), + ( + [-2147483648, -1, 0, 1, 2147483647], + pl.Int32, + """shape: (5,) +Series: 'foo' [i32] +[ + -2,147,483,648 + -1 + 0 + 1 + 2,147,483,647 +]""", + ), + ( + [-9223372036854775808, -1, 0, 1, 9223372036854775807], + pl.Int64, + """shape: (5,) +Series: 'foo' [i64] +[ + -9,223,372,036,854,775,808 + -1 + 0 + 1 + 9,223,372,036,854,775,807 +]""", + ), + ], +) +def test_fmt_signed_int_thousands_sep( + values: list[int], dtype: PolarsDataType, expected: str +) -> None: + s = pl.Series(name="foo", values=values, dtype=dtype) + with pl.Config(thousands_separator=True): + assert str(s) == expected + + +@pytest.mark.parametrize( + ("values", "dtype", "expected"), + [ + ( + [0, 1, 127], + pl.UInt8, + """shape: (3,) +Series: 'foo' [u8] +[ + 0 + 1 + 127 +]""", + ), + ( + [0, 1, 32767], + pl.UInt16, + """shape: (3,) +Series: 'foo' [u16] +[ + 0 + 1 + 32,767 +]""", + ), + ( + [0, 1, 2147483647], + pl.UInt32, + """shape: (3,) +Series: 'foo' [u32] +[ + 0 + 1 + 2,147,483,647 +]""", + ), + ( + [0, 1, 9223372036854775807], + pl.UInt64, + """shape: (3,) +Series: 'foo' [u64] +[ + 0 + 1 + 9,223,372,036,854,775,807 +]""", + ), + ], +) +def test_fmt_unsigned_int_thousands_sep( + values: list[int], dtype: PolarsDataType, expected: str +) -> None: + s = pl.Series(name="foo", values=values, dtype=dtype) + with pl.Config(thousands_separator=True): + assert str(s) == expected + + +def test_fmt_float(capfd: pytest.CaptureFixture[str]) -> None: + s = pl.Series(name="foo", values=[7.966e-05, 7.9e-05, 8.4666e-05, 8.00007966]) + print(s) + out, _err = capfd.readouterr() + expected = """shape: (4,) +Series: 'foo' [f64] +[ + 0.00008 + 0.000079 + 0.000085 + 8.00008 +] +""" + assert out == expected + + +def test_duration_smallest_units() -> None: + s = pl.Series(range(6), dtype=pl.Duration("us")) + assert ( + str(s) + == "shape: (6,)\nSeries: '' [duration[μs]]\n[\n\t0µs\n\t1µs\n\t2µs\n\t3µs\n\t4µs\n\t5µs\n]" + ) + s = pl.Series(range(6), dtype=pl.Duration("ms")) + assert ( + str(s) + == "shape: (6,)\nSeries: '' [duration[ms]]\n[\n\t0ms\n\t1ms\n\t2ms\n\t3ms\n\t4ms\n\t5ms\n]" + ) + s = pl.Series(range(6), dtype=pl.Duration("ns")) + assert ( + str(s) + == "shape: (6,)\nSeries: '' [duration[ns]]\n[\n\t0ns\n\t1ns\n\t2ns\n\t3ns\n\t4ns\n\t5ns\n]" + ) + + +def test_fmt_float_full() -> None: + fmt_float_full = "shape: (1,)\nSeries: '' [f64]\n[\n\t1.230498095872587\n]" + s = pl.Series([1.2304980958725870923]) + + with pl.Config() as cfg: + cfg.set_fmt_float("full") + assert str(s) == fmt_float_full + + assert str(s) != fmt_float_full + + +def test_fmt_list_12188() -> None: + # set max_items to 1 < 4(size of failed list) to touch the testing branch. + with ( + pl.Config(fmt_table_cell_list_len=1), + pytest.raises(InvalidOperationError, match="from `i64` to `u8` failed"), + ): + pl.DataFrame( + { + "x": pl.int_range(250, 260, 1, eager=True), + } + ).with_columns(u8=pl.col("x").cast(pl.UInt8)) + + +def test_date_list_fmt() -> None: + df = pl.DataFrame( + { + "mydate": ["2020-01-01", "2020-01-02", "2020-01-05", "2020-01-05"], + "index": [1, 2, 5, 5], + } + ) + + df = df.with_columns(pl.col("mydate").str.strptime(pl.Date, "%Y-%m-%d")) + assert ( + str(df.group_by("index", maintain_order=True).agg(pl.col("mydate"))["mydate"]) + == """shape: (3,) +Series: 'mydate' [list[date]] +[ + [2020-01-01] + [2020-01-02] + [2020-01-05, 2020-01-05] +]""" + ) + + +def test_fmt_series_cat_list() -> None: + s = pl.Series( + [ + ["a", "b"], + ["b", "a"], + ["b"], + ], + ).cast(pl.List(pl.Categorical)) + + assert ( + str(s) + == """shape: (3,) +Series: '' [list[cat]] +[ + ["a", "b"] + ["b", "a"] + ["b"] +]""" + ) + + +def test_format_numeric_locale_options() -> None: + df = pl.DataFrame( + { + "a": ["xx", "yy"], + "b": [100000.987654321, -234567.89], + "c": [-11111111, 44444444444], + "d": [D("12345.6789"), D("-9999999.99")], + }, + strict=False, + ) + + # note: numeric digit grouping looks much better + # when right-aligned with fixed float precision + with pl.Config( + tbl_cell_numeric_alignment="RIGHT", + thousands_separator=",", + float_precision=3, + ): + print(df) + assert ( + str(df) + == """shape: (2, 4) +┌─────┬──────────────┬────────────────┬─────────────────┐ +│ a ┆ b ┆ c ┆ d │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ f64 ┆ i64 ┆ decimal[*,4] │ +╞═════╪══════════════╪════════════════╪═════════════════╡ +│ xx ┆ 100,000.988 ┆ -11,111,111 ┆ 12,345.6789 │ +│ yy ┆ -234,567.890 ┆ 44,444,444,444 ┆ -9,999,999.9900 │ +└─────┴──────────────┴────────────────┴─────────────────┘""" + ) + + # switch digit/decimal separators + with pl.Config( + decimal_separator=",", + thousands_separator=".", + ): + assert ( + str(df) + == """shape: (2, 4) +┌─────┬────────────────┬────────────────┬─────────────────┐ +│ a ┆ b ┆ c ┆ d │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ f64 ┆ i64 ┆ decimal[*,4] │ +╞═════╪════════════════╪════════════════╪═════════════════╡ +│ xx ┆ 100.000,987654 ┆ -11.111.111 ┆ 12.345,6789 │ +│ yy ┆ -234.567,89 ┆ 44.444.444.444 ┆ -9.999.999,9900 │ +└─────┴────────────────┴────────────────┴─────────────────┘""" + ) + + # default (no digit grouping, standard digit/decimal separators) + assert ( + str(df) + == """shape: (2, 4) +┌─────┬───────────────┬─────────────┬───────────────┐ +│ a ┆ b ┆ c ┆ d │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ f64 ┆ i64 ┆ decimal[*,4] │ +╞═════╪═══════════════╪═════════════╪═══════════════╡ +│ xx ┆ 100000.987654 ┆ -11111111 ┆ 12345.6789 │ +│ yy ┆ -234567.89 ┆ 44444444444 ┆ -9999999.9900 │ +└─────┴───────────────┴─────────────┴───────────────┘""" + ) + + +def test_fmt_decimal_max_scale() -> None: + values = [D("0.14282911023321884847623576259639164703")] + dtype = pl.Decimal(precision=38, scale=38) + s = pl.Series(values, dtype=dtype) + result = str(s) + expected = """shape: (1,) +Series: '' [decimal[38,38]] +[ + 0.14282911023321884847623576259639164703 +]""" + assert result == expected + + +@pytest.mark.parametrize( + ("lf", "expected"), + [ + ( + ( + pl.LazyFrame({"a": [1]}) + .with_columns(b=pl.col("a")) + .with_columns(c=pl.col("b"), d=pl.col("a")) + ), + 'simple π 4/4 ["a", "b", "c", "d"]', + ), + ( + ( + pl.LazyFrame({"a_very_very_long_string": [1], "a": [1]}) + .with_columns(b=pl.col("a")) + .with_columns(c=pl.col("b"), d=pl.col("a")) + ), + 'simple π 5/5 ["a_very_very_long_string", "a", ... 3 other columns]', + ), + ( + ( + pl.LazyFrame({"an_even_longer_very_very_long_string": [1], "a": [1]}) + .with_columns(b=pl.col("a")) + .with_columns(c=pl.col("b"), d=pl.col("a")) + ), + 'simple π 5/5 ["an_even_longer_very_very_long_string", ... 4 other columns]', + ), + ( + ( + pl.LazyFrame({"a": [1]}) + .with_columns(b=pl.col("a")) + .with_columns(c=pl.col("b"), a_very_long_string_at_the_end=pl.col("a")) + ), + 'simple π 4/4 ["a", "b", "c", ... 1 other column]', + ), + ( + ( + pl.LazyFrame({"a": [1]}) + .with_columns(b=pl.col("a")) + .with_columns( + a_very_long_string_in_the_middle=pl.col("b"), d=pl.col("a") + ) + ), + 'simple π 4/4 ["a", "b", ... 2 other columns]', + ), + ], +) +def test_simple_project_format(lf: pl.LazyFrame, expected: str) -> None: + result = lf.explain() + assert expected in result + + +@pytest.mark.parametrize( + ("df", "expected"), + [ + pytest.param( + pl.DataFrame({"A": range(4)}), + """shape: (4, 1) ++-----+ +| A | ++=====+ +| 0 | +| 1 | +| ... | +| 3 | ++-----+""", + id="Ellipsis correctly aligned", + ), + pytest.param( + pl.DataFrame({"A": range(2)}), + """shape: (2, 1) ++---+ +| A | ++===+ +| 0 | +| 1 | ++---+""", + id="No ellipsis needed", + ), + ], +) +def test_format_ascii_table_truncation(df: pl.DataFrame, expected: str) -> None: + with pl.Config(tbl_rows=3, tbl_hide_column_data_types=True, ascii_tables=True): + assert str(df) == expected + + +def test_format_21393() -> None: + assert pl.select(pl.format("{}", pl.lit(1, pl.Int128))).item() == "1" diff --git a/py-polars/tests/unit/test_init.py b/py-polars/tests/unit/test_init.py new file mode 100644 index 000000000000..716ade074e12 --- /dev/null +++ b/py-polars/tests/unit/test_init.py @@ -0,0 +1,41 @@ +import pytest + +import polars as pl +from polars.exceptions import ComputeError + + +def test_init_nonexistent_attribute() -> None: + with pytest.raises( + AttributeError, match="module 'polars' has no attribute 'stroopwafel'" + ): + pl.stroopwafel + + +def test_init_exceptions_deprecated() -> None: + with pytest.deprecated_call( + match="Accessing `ComputeError` from the top-level `polars` module is deprecated." + ): + exc = pl.ComputeError + + msg = "nope" + with pytest.raises(ComputeError, match=msg): + raise exc(msg) + + +def test_dtype_groups_deprecated() -> None: + with pytest.deprecated_call(match="`INTEGER_DTYPES` is deprecated."): + dtypes = pl.INTEGER_DTYPES + + assert pl.Int8 in dtypes + + +def test_type_aliases_deprecated() -> None: + with pytest.deprecated_call( + match="The `polars.type_aliases` module is deprecated." + ): + from polars.type_aliases import PolarsDataType + assert str(PolarsDataType).startswith("typing.Union") + + +def test_import_all() -> None: + exec("from polars import *") diff --git a/py-polars/tests/unit/test_plugins.py b/py-polars/tests/unit/test_plugins.py new file mode 100644 index 000000000000..3e90b6367944 --- /dev/null +++ b/py-polars/tests/unit/test_plugins.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.plugins import ( + _is_dynamic_lib, + _resolve_plugin_path, + _serialize_kwargs, + register_plugin_function, +) + + +@pytest.mark.write_disk +def test_register_plugin_function_invalid_plugin_path(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + plugin_path = tmp_path / "lib.so" + plugin_path.touch() + + expr = register_plugin_function( + plugin_path=plugin_path, function_name="hello", args=5 + ) + + with pytest.raises(ComputeError, match="error loading dynamic library"): + pl.select(expr) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (None, b""), + ({}, b""), + ( + {"hi": 0}, + b"\x80\x05\x95\x0b\x00\x00\x00\x00\x00\x00\x00}\x94\x8c\x02hi\x94K\x00s.", + ), + ], +) +def test_serialize_kwargs(input: dict[str, Any] | None, expected: bytes) -> None: + assert _serialize_kwargs(input) == expected + + +@pytest.mark.write_disk +@pytest.mark.parametrize("use_abs_path", [True, False]) +def test_resolve_plugin_path( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + use_abs_path: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + + mock_venv = tmp_path / ".venv" + mock_venv.mkdir(exist_ok=True) + mock_venv_lib = mock_venv / "lib" + mock_venv_lib.mkdir(exist_ok=True) + (mock_venv_lib / "lib1.so").touch() + (mock_venv_lib / "__init__.py").touch() + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(sys, "prefix", str(mock_venv)) + expected_full_path = mock_venv_lib / "lib1.so" + expected_relative_path = expected_full_path.relative_to(mock_venv) + + if use_abs_path: + result = _resolve_plugin_path(mock_venv_lib, use_abs_path=use_abs_path) + assert result == expected_full_path + else: + result = _resolve_plugin_path(mock_venv_lib, use_abs_path=use_abs_path) + assert result == expected_relative_path + + +@pytest.mark.write_disk +def test_resolve_plugin_path_raises(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + (tmp_path / "__init__.py").touch() + + with pytest.raises(FileNotFoundError, match="no dynamic library found"): + _resolve_plugin_path(tmp_path) + + +@pytest.mark.write_disk +@pytest.mark.parametrize( + ("path", "expected"), + [ + (Path("lib.so"), True), + (Path("lib.pyd"), True), + (Path("lib.dll"), True), + (Path("lib.py"), False), + ], +) +def test_is_dynamic_lib(path: Path, expected: bool, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + full_path = tmp_path / path + full_path.touch() + assert _is_dynamic_lib(full_path) is expected + + +@pytest.mark.write_disk +def test_is_dynamic_lib_dir(tmp_path: Path) -> None: + path = Path("lib.so") + full_path = tmp_path / path + + full_path.mkdir(exist_ok=True) + (full_path / "hello.txt").touch() + + assert _is_dynamic_lib(full_path) is False diff --git a/py-polars/tests/unit/test_polars_import.py b/py-polars/tests/unit/test_polars_import.py new file mode 100644 index 000000000000..c4ae66690c0f --- /dev/null +++ b/py-polars/tests/unit/test_polars_import.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import compileall +import subprocess +import sys +from pathlib import Path + +import pytest + +import polars as pl +from polars import selectors as cs + +# set a maximum cutoff at 0.3 secs; note that we are typically much faster +# than this (more like ~0.07 secs, depending on hardware), but we allow a +# margin of error to account for frequent noise from slow/contended CI. +MAX_ALLOWED_IMPORT_TIME = 300_000 # << microseconds + + +def _import_time_from_frame(tm: pl.DataFrame) -> int: + return int( + tm.filter(pl.col("import").str.strip_chars() == "polars") + .select("cumulative_time") + .item() + ) + + +def _import_timings() -> bytes: + # assemble suitable command to get polars module import timing; + # run in a separate process to ensure clean timing results. + cmd = f'{sys.executable} -S -X importtime -c "import polars"' + output = subprocess.run(cmd, shell=True, capture_output=True).stderr + if b"Traceback" in output: + msg = f"measuring import timings failed\n\nCommand output:\n{output.decode()}" + raise RuntimeError(msg) + return output.replace(b"import time:", b"").strip() + + +def _import_timings_as_frame(n_tries: int) -> tuple[pl.DataFrame, int]: + import_timings = [] + for _ in range(n_tries): + df_import = ( + pl.read_csv( + source=_import_timings(), + separator="|", + has_header=True, + new_columns=["own_time", "cumulative_time", "import"], + ) + .with_columns(cs.ends_with("_time").str.strip_chars().cast(pl.UInt32)) + .select("import", "own_time", "cumulative_time") + .reverse() + ) + polars_import_time = _import_time_from_frame(df_import) + if polars_import_time < MAX_ALLOWED_IMPORT_TIME: + return df_import, polars_import_time + + import_timings.append(df_import) + + # note: if a qualifying import time was already achieved, we won't get here. + # if we do, let's see all the failed timings to help see what's going on: + import_times = [_import_time_from_frame(df) for df in import_timings] + msg = "\n".join(f"({idx}) {tm:,}μs" for idx, tm in enumerate(import_times)) + min_max = f"Min => {min(import_times):,}μs, Max => {max(import_times):,}μs)" + print(f"\nImport times achieved over {n_tries} tries:\n{min_max}\n\n{msg}") + + sorted_timing_frames = sorted(import_timings, key=_import_time_from_frame) + return sorted_timing_frames[0], min(import_times) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Unreliable on Windows") +@pytest.mark.debug +@pytest.mark.slow +def test_polars_import() -> None: + # up-front compile '.py' -> '.pyc' before timing + polars_path = Path(pl.__file__).parent + compileall.compile_dir(polars_path, quiet=1) + + # note: reduce noise by allowing up to 'n' tries (but return immediately if/when + # a qualifying time is achieved, so we don't waste time running unnecessary tests) + df_import, polars_import_time = _import_timings_as_frame(n_tries=10) + + with pl.Config( + # get a complete view of what's going on in case of failure + tbl_rows=250, + fmt_str_lengths=100, + tbl_hide_dataframe_shape=True, + ): + # ensure that we have not broken lazy-loading (numpy, pandas, pyarrow, etc). + lazy_modules = [ + dep for dep in pl.dependencies.__all__ if not dep.startswith("_") + ] + for mod in lazy_modules: + not_imported = not df_import["import"].str.starts_with(mod).any() + if_err = f"lazy-loading regression: found {mod!r} at import time" + assert not_imported, f"{if_err}\n{df_import}" + + # ensure that we do not have an import speed regression. + if polars_import_time > MAX_ALLOWED_IMPORT_TIME: + import_time_ms = polars_import_time // 1_000 + msg = f"Possible import speed regression; took {import_time_ms}ms\n{df_import}" + raise AssertionError(msg) diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py new file mode 100644 index 000000000000..8105894325ad --- /dev/null +++ b/py-polars/tests/unit/test_predicates.py @@ -0,0 +1,751 @@ +import re +from datetime import date, datetime, timedelta +from typing import Any + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal +from polars.testing.asserts.series import assert_series_equal + + +def test_predicate_4906() -> None: + one_day = timedelta(days=1) + + ldf = pl.DataFrame( + { + "dt": [ + date(2022, 9, 1), + date(2022, 9, 10), + date(2022, 9, 20), + ] + } + ).lazy() + + assert ldf.filter( + pl.min_horizontal((pl.col("dt") + one_day), date(2022, 9, 30)) + > date(2022, 9, 10) + ).collect().to_dict(as_series=False) == { + "dt": [date(2022, 9, 10), date(2022, 9, 20)] + } + + +def test_predicate_null_block_asof_join() -> None: + left = ( + pl.DataFrame( + { + "id": [1, 2, 3, 4], + "timestamp": [ + datetime(2022, 1, 1, 10, 0), + datetime(2022, 1, 1, 10, 1), + datetime(2022, 1, 1, 10, 2), + datetime(2022, 1, 1, 10, 3), + ], + } + ) + .lazy() + .set_sorted("timestamp") + ) + + right = ( + pl.DataFrame( + { + "id": [1, 2, 3] * 2, + "timestamp": [ + datetime(2022, 1, 1, 9, 59, 50), + datetime(2022, 1, 1, 10, 0, 50), + datetime(2022, 1, 1, 10, 1, 50), + datetime(2022, 1, 1, 8, 0, 0), + datetime(2022, 1, 1, 8, 0, 0), + datetime(2022, 1, 1, 8, 0, 0), + ], + "value": ["a", "b", "c"] * 2, + } + ) + .lazy() + .set_sorted("timestamp") + ) + + assert left.join_asof(right, by="id", on="timestamp").filter( + pl.col("value").is_not_null() + ).collect().to_dict(as_series=False) == { + "id": [1, 2, 3], + "timestamp": [ + datetime(2022, 1, 1, 10, 0), + datetime(2022, 1, 1, 10, 1), + datetime(2022, 1, 1, 10, 2), + ], + "value": ["a", "b", "c"], + } + + +def test_predicate_strptime_6558() -> None: + assert ( + pl.DataFrame({"date": ["2022-01-03", "2020-01-04", "2021-02-03", "2019-01-04"]}) + .lazy() + .select(pl.col("date").str.strptime(pl.Date, format="%F")) + .filter((pl.col("date").dt.year() == 2022) & (pl.col("date").dt.month() == 1)) + .collect() + ).to_dict(as_series=False) == {"date": [date(2022, 1, 3)]} + + +def test_predicate_arr_first_6573() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6], + "b": [6, 5, 4, 3, 2, 1], + } + ) + + assert ( + df.lazy() + .with_columns(pl.col("a").implode()) + .with_columns(pl.col("a").list.first()) + .filter(pl.col("a") == pl.col("b")) + .collect() + ).to_dict(as_series=False) == {"a": [1], "b": [1]} + + +def test_fast_path_comparisons() -> None: + s = pl.Series(np.sort(np.random.randint(0, 50, 100))) + + assert_series_equal(s > 25, s.set_sorted() > 25) + assert_series_equal(s >= 25, s.set_sorted() >= 25) + assert_series_equal(s < 25, s.set_sorted() < 25) + assert_series_equal(s <= 25, s.set_sorted() <= 25) + + +def test_predicate_pushdown_block_8661() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 1, 2, 2, 2, 2], + "t": [1, 2, 3, 4, 4, 3, 2, 1], + "x": [10, 20, 30, 40, 10, 20, 30, 40], + } + ) + assert df.lazy().sort(["g", "t"]).filter( + (pl.col("x").shift() > 20).over("g") + ).collect().to_dict(as_series=False) == { + "g": [1, 2, 2], + "t": [4, 2, 3], + "x": [40, 30, 20], + } + + +def test_predicate_pushdown_cumsum_9566() -> None: + df = pl.DataFrame({"A": range(10), "B": ["b"] * 5 + ["a"] * 5}) + + q = df.lazy().sort(["B", "A"]).filter(pl.col("A").is_in([8, 2]).cum_sum() == 1) + + assert q.collect()["A"].to_list() == [8, 9, 0, 1] + + +def test_predicate_pushdown_join_fill_null_10058() -> None: + ids = pl.LazyFrame({"id": [0, 1, 2]}) + filters = pl.LazyFrame({"id": [0, 1], "filter": [True, False]}) + + assert sorted( + ids.join(filters, how="left", on="id") + .filter(pl.col("filter").fill_null(True)) + .collect() + .to_dict(as_series=False)["id"] + ) == [0, 2] + + +def test_is_in_join_blocked() -> None: + lf1 = pl.LazyFrame( + {"Groups": ["A", "B", "C", "D", "E", "F"], "values0": [1, 2, 3, 4, 5, 6]} + ) + lf2 = pl.LazyFrame( + {"values22": [1, 2, None, 4, 5, 6], "values20": [1, 2, 3, 4, 5, 6]} + ) + lf_all = lf2.join( + lf1, + left_on="values20", + right_on="values0", + how="left", + maintain_order="right_left", + ) + + for result in ( + lf_all.filter(~pl.col("Groups").is_in(["A", "B", "F"])), + lf_all.remove(pl.col("Groups").is_in(["A", "B", "F"])), + ): + expected = pl.LazyFrame( + { + "values22": [None, 4, 5], + "values20": [3, 4, 5], + "Groups": ["C", "D", "E"], + } + ) + assert_frame_equal(result, expected) + + +def test_predicate_pushdown_group_by_keys() -> None: + df = pl.LazyFrame( + {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]} + ).lazy() + assert ( + "SELECTION: None" + not in df.group_by("group") + .agg([pl.len().alias("str_list")]) + .filter(pl.col("group") == 1) + .explain() + ) + + +def test_no_predicate_push_down_with_cast_and_alias_11883() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + out = ( + df.lazy() + .select(pl.col("a").cast(pl.Int64).alias("b")) + .filter(pl.col("b") == 1) + .filter((pl.col("b") >= 1) & (pl.col("b") < 1)) + ) + assert ( + re.search(r"FILTER.*FROM\n\s*DF", out.explain(predicate_pushdown=True)) is None + ) + + +@pytest.mark.parametrize( + "predicate", + [ + 0, + "x", + [2, 3], + {"x": 1}, + pl.Series([1, 2, 3]), + None, + ], +) +def test_invalid_filter_predicates(predicate: Any) -> None: + df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]}) + with pytest.raises(TypeError, match="invalid predicate"): + df.filter(predicate) + + +def test_fast_path_boolean_filter_predicates() -> None: + df = pl.DataFrame({"colx": ["aa", "bb", "cc", "dd"]}) + df_empty = df.clear() + + assert_frame_equal(df.filter(False), df_empty) + assert_frame_equal(df.filter(True), df) + + assert_frame_equal(df.remove(True), df_empty) + assert_frame_equal(df.remove(False), df) + + +def test_predicate_pushdown_boundary_12102() -> None: + df = pl.DataFrame({"x": [1, 2, 4], "y": [1, 2, 4]}) + + lf = ( + df.lazy() + .filter(pl.col("y") > 1) + .filter(pl.col("x") == pl.min("x")) + .filter(pl.col("y") > 2) + ) + + result = lf.collect() + result_no_ppd = lf.collect(predicate_pushdown=False) + assert_frame_equal(result, result_no_ppd) + + +def test_take_can_block_predicate_pushdown() -> None: + df = pl.DataFrame({"x": [1, 2, 4], "y": [False, True, True]}) + lf = ( + df.lazy() + .filter(pl.col("y")) + .filter(pl.col("x") == pl.col("x").gather(0)) + .filter(pl.col("y")) + ) + result = lf.collect(predicate_pushdown=True) + assert result.to_dict(as_series=False) == {"x": [2], "y": [True]} + + +def test_literal_series_expr_predicate_pushdown() -> None: + # No pushdown should occur in this case, because otherwise the filter will + # attempt to filter 3 rows with a boolean mask of 2 rows. + lf = pl.LazyFrame({"x": [0, 1, 2]}) + + for res in ( + lf.filter(pl.col("x") > 0).filter(pl.Series([True, True])), + lf.remove(pl.col("x") <= 0).remove(pl.Series([False, False])), + ): + assert res.collect().to_series().to_list() == [1, 2] + + # Pushdown should occur here; series is being used as part of an `is_in`. + for res in ( + lf.filter(pl.col("x") > 0).filter(pl.col("x").is_in([0, 1])), + lf.remove(pl.col("x") <= 0).remove(~pl.col("x").is_in([0, 1])), + ): + assert re.search(r"FILTER .*\nFROM\n\s*DF", res.explain(), re.DOTALL) + assert res.collect().to_series().to_list() == [1] + + +def test_multi_alias_pushdown() -> None: + lf = pl.LazyFrame({"a": [1], "b": [1]}) + + actual = lf.with_columns(m="a", n="b").filter((pl.col("m") + pl.col("n")) < 2) + plan = actual.explain() + + print(plan) + assert plan.count("FILTER") == 1 + assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None + + with pytest.warns(UserWarning, match="Comparisons with None always result in null"): + # confirm we aren't using `eq_missing` in the query plan (denoted as " ==v ") + assert " ==v " not in lf.select(pl.col("a").filter(a=None)).explain() + + +def test_predicate_pushdown_with_window_projections_12637() -> None: + lf = pl.LazyFrame( + { + "key": [1], + "key_2": [1], + "key_3": [1], + "value": [1], + "value_2": [1], + "value_3": [1], + } + ) + + actual = lf.with_columns( + (pl.col("value") * 2).over("key").alias("value_2"), + (pl.col("value") * 2).over("key").alias("value_3"), + ).filter(pl.col("key") == 5) + + plan = actual.explain() + + assert ( + re.search( + r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL + ) + is not None + ) + assert plan.count("FILTER") == 1 + + actual = ( + lf.with_columns( + (pl.col("value") * 2).over("key", "key_2").alias("value_2"), + (pl.col("value") * 2).over("key", "key_2").alias("value_3"), + ) + .filter(pl.col("key") == 5) + .filter(pl.col("key_2") == 5) + ) + + plan = actual.explain() + assert plan.count("FILTER") == 1 + assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None + actual = ( + lf.with_columns( + (pl.col("value") * 2).over("key", "key_2").alias("value_2"), + (pl.col("value") * 2).over("key", "key_3").alias("value_3"), + ) + .filter(pl.col("key") == 5) + .filter(pl.col("key_2") == 5) + ) + + plan = actual.explain() + assert plan.count("FILTER") == 2 + assert ( + re.search( + r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL + ) + is not None + ) + + actual = ( + lf.with_columns( + (pl.col("value") * 2).over("key", pl.col("key_2") + 1).alias("value_2"), + (pl.col("value") * 2).over("key", "key_2").alias("value_3"), + ) + .filter(pl.col("key") == 5) + .filter(pl.col("key_2") == 5) + ) + plan = actual.explain() + assert plan.count("FILTER") == 2 + assert ( + re.search( + r'FILTER \[\(col\("key"\)\) == \(5\)\]\s*FROM\n\s*DF', plan, re.DOTALL + ) + is not None + ) + + # Should block when .over() contains groups-sensitive expr + actual = ( + lf.with_columns( + (pl.col("value") * 2).over("key", pl.sum("key_2")).alias("value_2"), + (pl.col("value") * 2).over("key", "key_2").alias("value_3"), + ) + .filter(pl.col("key") == 5) + .filter(pl.col("key_2") == 5) + ) + + plan = actual.explain() + assert plan.count("FILTER") == 1 + assert "FILTER" in plan + assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is None + # Ensure the implementation doesn't accidentally push a window expression + # that only refers to the common window keys. + actual = lf.with_columns( + (pl.col("value") * 2).over("key").alias("value_2"), + ).filter(pl.len().over("key") == 1) + + plan = actual.explain() + assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is None + assert plan.count("FILTER") == 1 + + # Test window in filter + actual = lf.filter(pl.len().over("key") == 1).filter(pl.col("key") == 1) + plan = actual.explain() + assert plan.count("FILTER") == 2 + assert ( + re.search( + r'FILTER \[\(len\(\).over\(\[col\("key"\)\]\)\) == \(1\)\]\s*FROM\n\s*FILTER', + plan, + ) + is not None + ) + assert ( + re.search( + r'FILTER \[\(col\("key"\)\) == \(1\)\]\s*FROM\n\s*DF', plan, re.DOTALL + ) + is not None + ) + + +def test_predicate_reduction() -> None: + # ensure we get clean reduction without casts + lf = pl.LazyFrame({"a": [1], "b": [2]}) + for filter_frame in (lf.filter, lf.remove): + assert ( + "cast" + not in filter_frame( + pl.col("a") > 1, + pl.col("b") > 1, + ).explain() + ) + + +def test_all_any_cleanup_at_single_predicate_case() -> None: + plan = pl.LazyFrame({"a": [1], "b": [2]}).select(["a"]).drop_nulls().explain() + assert "horizontal" not in plan + assert "all" not in plan + + +def test_hconcat_predicate() -> None: + # Predicates shouldn't be pushed down past an hconcat as we can't filter + # across the different inputs + lf1 = pl.LazyFrame( + { + "a1": [0, 1, 2, 3, 4], + "a2": [5, 6, 7, 8, 9], + } + ) + lf2 = pl.LazyFrame( + { + "b1": [0, 1, 2, 3, 4], + "b2": [5, 6, 7, 8, 9], + } + ) + + query = pl.concat( + [ + lf1.filter(pl.col("a1") < 4), + lf2.filter(pl.col("b1") > 0), + ], + how="horizontal", + ).filter(pl.col("b2") < 9) + + expected = pl.DataFrame( + { + "a1": [0, 1, 2], + "a2": [5, 6, 7], + "b1": [1, 2, 3], + "b2": [6, 7, 8], + } + ) + result = query.collect(predicate_pushdown=True) + assert_frame_equal(result, expected) + + +def test_predicate_pd_join_13300() -> None: + # https://github.com/pola-rs/polars/issues/13300 + + lf = pl.LazyFrame({"col3": range(10, 14), "new_col": range(11, 15)}) + lf_other = pl.LazyFrame({"col4": [0, 11, 2, 13]}) + + lf = lf.join(lf_other, left_on="new_col", right_on="col4", how="left") + for res in ( + lf.filter(pl.col("new_col") < 12), + lf.remove(pl.col("new_col") >= 12), + ): + assert res.collect().to_dict(as_series=False) == {"col3": [10], "new_col": [11]} + + +def test_filter_eq_missing_13861() -> None: + lf = pl.LazyFrame({"a": [1, None, 3], "b": ["xx", "yy", None]}) + lf_empty = lf.clear() + + with pytest.warns(UserWarning, match="Comparisons with None always result in null"): + assert_frame_equal(lf.collect().filter(a=None), lf_empty.collect()) + + with pytest.warns(UserWarning, match="Comparisons with None always result in null"): + assert_frame_equal(lf.collect().remove(a=None), lf.collect()) + + with pytest.warns(UserWarning, match="Comparisons with None always result in null"): + lff = lf.filter(a=None) + assert lff.collect().rows() == [] + assert " ==v " not in lff.explain() # check no `eq_missing` op + + with pytest.warns(UserWarning, match="Comparisons with None always result in null"): + assert_frame_equal(lf.collect().filter(a=None), lf_empty.collect()) + + with pytest.warns(UserWarning, match="Comparisons with None always result in null"): + assert_frame_equal(lf.collect().remove(a=None), lf.collect()) + + for filter_expr in ( + pl.col("a").eq_missing(None), + pl.col("a").is_null(), + ): + assert lf.collect().filter(filter_expr).rows() == [(None, "yy")] + + +@pytest.mark.parametrize("how", ["left", "inner"]) +def test_predicate_pushdown_block_join(how: Any) -> None: + q = ( + pl.LazyFrame({"a": [1]}) + .join( + pl.LazyFrame({"a": [2], "b": [1]}), + left_on=["a"], + right_on=["b"], + how=how, + ) + .filter(pl.col("a") == 1) + ) + assert_frame_equal(q.collect(no_optimization=True), q.collect()) + + +def test_predicate_push_down_with_alias_15442() -> None: + df = pl.DataFrame({"a": [1]}) + output = ( + df.lazy() + .filter(pl.col("a").alias("x").drop_nulls() > 0) + .collect(predicate_pushdown=True) + ) + assert output.to_dict(as_series=False) == {"a": [1]} + + +def test_predicate_slice_pushdown_list_gather_17492() -> None: + lf = pl.LazyFrame({"val": [[1], [1, 1]], "len": [1, 2]}) + + assert_frame_equal( + lf.filter(pl.col("len") == 2).filter(pl.col("val").list.get(1) == 1), + lf.slice(1, 1), + ) + + # null_on_oob=True can pass + + plan = ( + lf.filter(pl.col("len") == 2) + .filter(pl.col("val").list.get(1, null_on_oob=True) == 1) + .explain() + ) + + assert re.search(r"FILTER.*FROM\n\s*DF", plan, re.DOTALL) is not None + + # Also check slice pushdown + q = lf.with_columns(pl.col("val").list.get(1).alias("b")).slice(1, 1) + + with pytest.raises(ComputeError, match="get index is out of bounds"): + q.collect() + + +def test_predicate_pushdown_struct_unnest_19632() -> None: + lf = pl.LazyFrame({"a": [{"a": 1, "b": 2}]}).unnest("a") + + q = lf.filter(pl.col("a") == 1) + plan = q.explain() + + assert "FILTER" in plan + assert plan.index("FILTER") < plan.index("UNNEST") + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": 1, "b": 2}), + ) + + # With `pl.struct()` + lf = pl.LazyFrame({"a": 1, "b": 2}).select(pl.struct(pl.all())).unnest("a") + + q = lf.filter(pl.col("a") == 1) + plan = q.explain() + + assert "FILTER" in plan + assert plan.index("FILTER") < plan.index("UNNEST") + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": 1, "b": 2}), + ) + + # With `value_counts()` + lf = pl.LazyFrame({"a": [1]}).select(pl.col("a").value_counts()).unnest("a") + + q = lf.filter(pl.col("a") == 1) + plan = q.explain() + + assert plan.index("FILTER") < plan.index("UNNEST") + + assert_frame_equal( + q.collect(), + pl.DataFrame({"a": 1, "count": 1}, schema={"a": pl.Int64, "count": pl.UInt32}), + ) + + +@pytest.mark.parametrize( + "predicate", + [ + pl.col("v") == 7, + pl.col("v") != 99, + pl.col("v") > 0, + pl.col("v") < 999, + pl.col("v").is_in([7]), + pl.col("v").cast(pl.Boolean), + pl.col("b"), + ], +) +@pytest.mark.parametrize("alias", [True, False]) +@pytest.mark.parametrize("join_type", ["left", "right"]) +def test_predicate_pushdown_join_19772( + predicate: pl.Expr, join_type: str, alias: bool +) -> None: + left = pl.LazyFrame({"k": [1, 2]}) + right = pl.LazyFrame({"k": [1], "v": [7], "b": True}) + + if join_type == "right": + [left, right] = [right, left] + + if alias: + predicate = predicate.alias(":V") + + q = left.join(right, on="k", how=join_type).filter(predicate) # type: ignore[arg-type] + + plan = q.explain() + assert plan.startswith("FILTER") + + expect = pl.DataFrame({"k": 1, "v": 7, "b": True}) + + if join_type == "right": + expect = expect.select("v", "b", "k") + + assert_frame_equal(q.collect(no_optimization=True), expect) + assert_frame_equal(q.collect(), expect) + + +def test_predicate_pushdown_scalar_20489() -> None: + df = pl.DataFrame({"a": [1]}) + mask = pl.Series([False]) + + assert_frame_equal( + df.lazy().with_columns(b=pl.Series([2])).filter(mask).collect(), + pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64}), + ) + + +def test_predicates_not_split_when_pushdown_disabled_20475() -> None: + # This is important for the eager `DataFrame.filter()`, as that runs without + # predicate pushdown enabled. Splitting the predicates in that case can + # severely degrade performance. + q = pl.LazyFrame({"a": 1, "b": 1, "c": 1}).filter( + pl.col("a") > 0, pl.col("b") > 0, pl.col("c") > 0 + ) + assert q.explain(predicate_pushdown=False).count("FILTER") == 1 + + +def test_predicate_filtering_against_nulls() -> None: + df = pl.DataFrame({"num": [1, 2, None, 4]}) + + for res in ( + df.filter(pl.col("num") > 2), + df.filter(pl.col("num").is_in([3, 4, 5])), + ): + assert res["num"].to_list() == [4] + + for res in ( + df.remove(pl.col("num") <= 2), + df.remove(pl.col("num").is_in([1, 2, 3])), + ): + assert res["num"].to_list() == [None, 4] + + for res in ( + df.filter(pl.col("num").ne_missing(None)), + df.remove(pl.col("num").eq_missing(None)), + ): + assert res["num"].to_list() == [1, 2, 4] + + +@pytest.mark.parametrize( + ("query", "expected"), + [ + ( + ( + pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + .rename({"a": "A", "b": "a"}) + .select("A", "c") + .filter(pl.col("A") == 1) + ), + pl.DataFrame({"A": 1, "c": 3}), + ), + ( + ( + pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + .rename({"b": "a", "a": "A"}) + .select("A", "c") + .filter(pl.col("A") == 1) + ), + pl.DataFrame({"A": 1, "c": 3}), + ), + ( + ( + pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + .rename({"a": "b", "b": "a"}) + .select("a", "b", "c") + .filter(pl.col("b") == 1) + ), + pl.DataFrame({"a": 2, "b": 1, "c": 3}), + ), + ( + ( + pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + .rename({"a": "b", "b": "a"}) + .select("b", "c") + .filter(pl.col("b") == 1) + ), + pl.DataFrame({"b": 1, "c": 3}), + ), + ( + ( + pl.LazyFrame({"a": [1], "b": [2], "c": [3]}) + .rename({"b": "a", "a": "b"}) + .select("a", "b", "c") + .filter(pl.col("b") == 1) + ), + pl.DataFrame({"a": 2, "b": 1, "c": 3}), + ), + ], +) +def test_predicate_pushdown_lazy_rename_22373( + query: pl.LazyFrame, + expected: pl.DataFrame, +) -> None: + assert_frame_equal( + query.collect(), + expected, + ) + + # Ensure filter is pushed past rename + plan = query.explain() + assert plan.index("FILTER") > plan.index("RENAME") diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py new file mode 100644 index 000000000000..d626bb66ddcc --- /dev/null +++ b/py-polars/tests/unit/test_projections.py @@ -0,0 +1,670 @@ +from typing import Literal + +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_projection_on_semi_join_4789() -> None: + lfa = pl.DataFrame({"a": [1], "p": [1]}).lazy() + + lfb = pl.DataFrame({"seq": [1], "p": [1]}).lazy() + + ab = lfa.join(lfb, on="p", how="semi").inspect() + + intermediate_agg = (ab.group_by("a").agg([pl.col("a").alias("seq")])).select( + ["a", "seq"] + ) + + q = ab.join(intermediate_agg, on="a") + + assert q.collect().to_dict(as_series=False) == {"a": [1], "p": [1], "seq": [[1]]} + + +def test_unpivot_projection_pd_block_4997() -> None: + assert ( + pl.DataFrame({"col1": ["a"], "col2": ["b"]}) + .with_row_index() + .lazy() + .unpivot(index="index") + .group_by("index") + .agg(pl.col("variable").alias("result")) + .collect() + ).to_dict(as_series=False) == {"index": [0], "result": [["col1", "col2"]]} + + +def test_double_projection_pushdown() -> None: + assert ( + "2/3 COLUMNS" + in ( + pl.DataFrame({"c0": [], "c1": [], "c2": []}) + .lazy() + .select(["c0", "c1", "c2"]) + .select(["c0", "c1"]) + ).explain() + ) + + +def test_group_by_projection_pushdown() -> None: + assert ( + "2/3 COLUMNS" + in ( + pl.DataFrame({"c0": [], "c1": [], "c2": []}) + .lazy() + .group_by("c0") + .agg( + [ + pl.col("c1").sum().alias("sum(c1)"), + pl.col("c2").mean().alias("mean(c2)"), + ] + ) + .select(["sum(c1)"]) + ).explain() + ) + + +def test_unnest_projection_pushdown() -> None: + lf = pl.DataFrame({"x|y|z": [1, 2], "a|b|c": [2, 3]}).lazy() + + mlf = ( + lf.unpivot() + .with_columns(pl.col("variable").str.split_exact("|", 2)) + .unnest("variable") + ) + mlf = mlf.select( + pl.col("field_1").cast(pl.Categorical).alias("row"), + pl.col("field_2").cast(pl.Categorical).alias("col"), + pl.col("value"), + ) + + out = ( + mlf.sort( + [pl.col.row.cast(pl.String), pl.col.col.cast(pl.String)], + maintain_order=True, + ) + .collect() + .to_dict(as_series=False) + ) + assert out == { + "row": ["b", "b", "y", "y"], + "col": ["c", "c", "z", "z"], + "value": [2, 3, 1, 2], + } + + +def test_hconcat_projection_pushdown() -> None: + lf1 = pl.LazyFrame({"a": [0, 1, 2], "b": [3, 4, 5]}) + lf2 = pl.LazyFrame({"c": [6, 7, 8], "d": [9, 10, 11]}) + query = pl.concat([lf1, lf2], how="horizontal").select(["a", "d"]) + + explanation = query.explain() + assert explanation.count("1/2 COLUMNS") == 2 + + out = query.collect() + expected = pl.DataFrame({"a": [0, 1, 2], "d": [9, 10, 11]}) + assert_frame_equal(out, expected) + + +def test_hconcat_projection_pushdown_length_maintained() -> None: + # We can't eliminate the second input completely as this affects + # the length of the result, even though no columns are used. + lf1 = pl.LazyFrame({"a": [0, 1], "b": [2, 3]}) + lf2 = pl.LazyFrame({"c": [4, 5, 6, 7], "d": [8, 9, 10, 11]}) + query = pl.concat([lf1, lf2], how="horizontal").select(["a"]) + + explanation = query.explain() + assert "1/2 COLUMNS" in explanation + + out = query.collect() + expected = pl.DataFrame({"a": [0, 1, None, None]}) + assert_frame_equal(out, expected) + + +@pytest.mark.may_fail_auto_streaming +def test_unnest_columns_available() -> None: + df = pl.DataFrame( + { + "title": ["Avatar", "spectre", "King Kong"], + "content_rating": ["PG-13"] * 3, + "genres": [ + "Action|Adventure|Fantasy|Sci-Fi", + "Action|Adventure|Thriller", + "Action|Adventure|Drama|Romance", + ], + } + ).lazy() + + q = df.with_columns( + pl.col("genres") + .str.split("|") + .list.to_struct( + n_field_strategy="max_width", fields=lambda i: f"genre{i + 1}", _eager=True + ) + ).unnest("genres") + + out = q.collect() + assert out.to_dict(as_series=False) == { + "title": ["Avatar", "spectre", "King Kong"], + "content_rating": ["PG-13", "PG-13", "PG-13"], + "genre1": ["Action", "Action", "Action"], + "genre2": ["Adventure", "Adventure", "Adventure"], + "genre3": ["Fantasy", "Thriller", "Drama"], + "genre4": ["Sci-Fi", None, "Romance"], + } + + +def test_double_projection_union() -> None: + lf1 = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [2, 3, 4, 5], + "c": [1, 1, 2, 2], + "d": [1, 2, 2, 1], + } + ).lazy() + + lf2 = pl.DataFrame( + { + "a": [5, 6, 7, 8], + "b": [6, 7, 8, 9], + "c": [1, 2, 1, 3], + } + ).lazy() + + # in this query the group_by projects only 2 columns, that's one + # less than the upstream projection so the union will fail if + # the select node does not prune one column + q = lf1.select(["a", "b", "c"]) + + q = pl.concat([q, lf2]) + + q = q.group_by("c", maintain_order=True).agg([pl.col("a")]) + assert q.collect().to_dict(as_series=False) == { + "c": [1, 2, 3], + "a": [[1, 2, 5, 7], [3, 4, 6], [8]], + } + + +def test_asof_join_projection_() -> None: + lf1 = ( + pl.DataFrame( + { + "m": np.linspace(0, 5, 7), + "a": np.linspace(0, 5, 7), + "b": np.linspace(0, 5, 7), + "c": pl.Series(np.linspace(0, 5, 7)).cast(str), + "d": np.linspace(0, 5, 7), + } + ) + .lazy() + .set_sorted("b") + ) + lf2 = ( + pl.DataFrame( + { + "group": [0, 2, 3, 0, 1, 2, 3], + "val": [0.0, 2.5, 2.6, 2.7, 3.4, 4.0, 5.0], + "c": ["x", "x", "x", "y", "y", "y", "y"], + } + ) + .with_columns(pl.col("val").alias("b")) + .lazy() + .set_sorted("b") + ) + + joined = lf1.join_asof( + lf2, + on="b", + by=["c"], + strategy="backward", + ) + + expressions = [ + "m", + "a", + "b", + "c", + "d", + pl.lit(0, dtype=pl.Int64).alias("group"), + pl.lit(0.1).alias("val"), + ] + dirty_lf1 = lf1.select(expressions) + + concatted = pl.concat([joined, dirty_lf1]) + assert concatted.select(["b", "a"]).collect().to_dict(as_series=False) == { + "b": [ + 0.0, + 0.8333333333333334, + 1.6666666666666667, + 2.5, + 3.3333333333333335, + 4.166666666666667, + 5.0, + 0.0, + 0.8333333333333334, + 1.6666666666666667, + 2.5, + 3.3333333333333335, + 4.166666666666667, + 5.0, + ], + "a": [ + 0.0, + 0.8333333333333334, + 1.6666666666666667, + 2.5, + 3.3333333333333335, + 4.166666666666667, + 5.0, + 0.0, + 0.8333333333333334, + 1.6666666666666667, + 2.5, + 3.3333333333333335, + 4.166666666666667, + 5.0, + ], + } + + +def test_merge_sorted_projection_pd() -> None: + lf = pl.LazyFrame( + { + "foo": [1, 2, 3, 4], + "bar": ["patrick", "lukas", "onion", "afk"], + } + ).sort("foo") + + lf2 = pl.LazyFrame({"foo": [5, 6], "bar": ["nice", "false"]}).sort("foo") + + assert ( + lf.merge_sorted(lf2, key="foo").reverse().select(["bar"]) + ).collect().to_dict(as_series=False) == { + "bar": ["false", "nice", "afk", "onion", "lukas", "patrick"] + } + + +def test_distinct_projection_pd_7578() -> None: + lf = pl.LazyFrame( + { + "foo": ["0", "1", "2", "1", "2"], + "bar": ["a", "a", "a", "b", "b"], + } + ) + + result = lf.unique().group_by("bar").agg(pl.len()) + expected = pl.LazyFrame( + { + "bar": ["a", "b"], + "len": [3, 2], + }, + schema_overrides={"len": pl.UInt32}, + ) + assert_frame_equal(result, expected, check_row_order=False) + + +def test_join_suffix_collision_9562() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + other_df = pl.DataFrame( + { + "apple": ["x", "y", "z"], + "ham": ["a", "b", "d"], + } + ) + df.join(other_df, on="ham") + assert df.lazy().join( + other_df.lazy(), + how="inner", + left_on="ham", + right_on="ham", + suffix="m", + maintain_order="right", + ).select("ham").collect().to_dict(as_series=False) == {"ham": ["a", "b"]} + + +def test_projection_join_names_9955() -> None: + batting = pl.LazyFrame( + { + "playerID": ["abercda01"], + "yearID": [1871], + "lgID": ["NA"], + } + ) + + awards_players = pl.LazyFrame( + { + "playerID": ["bondto01"], + "yearID": [1877], + "lgID": ["NL"], + } + ) + + right = awards_players.filter(pl.col("lgID") == "NL").select("playerID") + + q = batting.join( + right, + left_on=[pl.col("playerID")], + right_on=[pl.col("playerID")], + how="inner", + ) + + q = q.select(batting.collect_schema()) + + assert q.collect().schema == { + "playerID": pl.String, + "yearID": pl.Int64, + "lgID": pl.String, + } + + +def test_projection_rename_10595() -> None: + lf = pl.LazyFrame(schema={"a": pl.Float32, "b": pl.Float32}) + + result = lf.select("a", "b").rename({"b": "a", "a": "b"}).select("a") + assert result.collect().schema == {"a": pl.Float32} + + result = ( + lf.select("a", "b") + .rename({"c": "d", "b": "a", "d": "c", "a": "b"}, strict=False) + .select("a") + ) + assert result.collect().schema == {"a": pl.Float32} + + +def test_projection_count_11841() -> None: + pl.LazyFrame({"x": 1}).select(records=pl.len()).select( + pl.lit(1).alias("x"), pl.all() + ).collect() + + +def test_schema_full_outer_join_projection_pd_13287() -> None: + lf = pl.LazyFrame({"a": [1, 1], "b": [2, 3]}) + lf2 = pl.LazyFrame({"a": [1, 1], "c": [2, 3]}) + + assert lf.join( + lf2, + how="full", + left_on="a", + right_on="c", + maintain_order="right_left", + ).with_columns( + pl.col("a").fill_null(pl.col("c")), + ).select("a").collect().to_dict(as_series=False) == {"a": [2, 3, 1, 1]} + + +def test_projection_pushdown_full_outer_join_duplicates() -> None: + df1 = pl.DataFrame({"a": [1, 2, 3], "b": [10, 20, 30]}).lazy() + df2 = pl.DataFrame({"a": [1, 2, 3], "b": [10, 20, 30]}).lazy() + assert ( + df1.join(df2, on="a", how="full", maintain_order="right") + .with_columns(c=0) + .select("a", "c") + .collect() + ).to_dict(as_series=False) == {"a": [1, 2, 3], "c": [0, 0, 0]} + + +def test_rolling_key_projected_13617() -> None: + df = pl.DataFrame({"idx": [1, 2], "value": ["a", "b"]}).set_sorted("idx") + ldf = df.lazy().select(pl.col("value").rolling("idx", period="1i")) + plan = ldf.explain(projection_pushdown=True) + assert r"2/2 COLUMNS" in plan + out = ldf.collect(projection_pushdown=True) + assert out.to_dict(as_series=False) == {"value": [["a"], ["b"]]} + + +def test_projection_drop_with_series_lit_14382() -> None: + df = pl.DataFrame({"b": [1, 6, 8, 7]}) + df2 = pl.DataFrame({"a": [1, 2, 4, 4], "b": [True, True, True, False]}) + + q = ( + df2.lazy() + .select( + *["a", "b"], pl.lit("b").alias("b_name"), df.get_column("b").alias("b_old") + ) + .filter(pl.col("b").not_()) + .drop("b") + ) + assert q.collect().to_dict(as_series=False) == { + "a": [4], + "b_name": ["b"], + "b_old": [7], + } + + +def test_cached_schema_15651() -> None: + q = pl.LazyFrame({"col1": [1], "col2": [2], "col3": [3]}) + q = q.with_row_index() + q = q.filter(~pl.col("col1").is_null()) + # create a subplan diverging from q + _ = q.select(pl.len()).collect(projection_pushdown=True) + + # ensure that q's "cached" columns are still correct + assert q.collect_schema().names() == q.collect().columns + + +def test_double_projection_pushdown_15895() -> None: + df = ( + pl.LazyFrame({"A": [0], "B": [1]}) + .select(C="A", A="B") + .group_by(1) + .all() + .collect(projection_pushdown=True) + ) + assert df.to_dict(as_series=False) == { + "literal": [1], + "C": [[0]], + "A": [[1]], + } + + +@pytest.mark.parametrize("join_type", ["inner", "left", "full"]) +def test_non_coalesce_join_projection_pushdown_16515( + join_type: Literal["inner", "left", "full"], +) -> None: + left = pl.LazyFrame({"x": 1}) + right = pl.LazyFrame({"y": 1}) + + assert ( + left.join(right, how=join_type, left_on="x", right_on="y", coalesce=False) + .select("y") + .collect() + .item() + == 1 + ) + + +@pytest.mark.parametrize("join_type", ["inner", "left", "full"]) +def test_non_coalesce_multi_key_join_projection_pushdown_16554( + join_type: Literal["inner", "left", "full"], +) -> None: + lf1 = pl.LazyFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [1, 2, 3, 4, 5], + } + ) + lf2 = pl.LazyFrame( + { + "a": [0, 2, 3, 4, 5], + "b": [1, 2, 3, 5, 6], + "c": [7, 5, 3, 5, 7], + } + ) + + expect = ( + lf1.with_columns(a2="a") + .join( + other=lf2, + how=join_type, + left_on=["a", "a2"], + right_on=["b", "c"], + coalesce=False, + ) + .select("a", "b", "c") + .collect() + ) + + out = ( + lf1.join( + other=lf2, + how=join_type, + left_on=["a", "a"], + right_on=["b", "c"], + coalesce=False, + ) + .select("a", "b", "c") + .collect() + ) + + assert_frame_equal(out, expect, check_row_order=False) + + +@pytest.mark.parametrize("how", ["semi", "anti"]) +def test_projection_pushdown_semi_anti_no_selection( + how: Literal["semi", "anti"], +) -> None: + q_a = pl.LazyFrame({"a": [1, 2, 3]}) + + q_b = pl.LazyFrame({"b": [1, 2, 3], "c": [1, 2, 3]}) + + assert "1/2 COLUMNS" in ( + q_a.join(q_b, left_on="a", right_on="b", how=how).explain() + ) + + +def test_projection_empty_frame_len_16904() -> None: + df = pl.LazyFrame({}) + + q = df.select(pl.len()) + + assert "0/0 COLUMNS" in q.explain() + + expect = pl.DataFrame({"len": [0]}, schema_overrides={"len": pl.UInt32()}) + assert_frame_equal(q.collect(), expect) + + +def test_projection_literal_no_alias_17739() -> None: + df = pl.LazyFrame({}) + assert df.select(pl.lit(False)).select("literal").collect().to_dict( + as_series=False + ) == {"literal": [False]} + + +def test_projections_collapse_17781() -> None: + frame1 = pl.LazyFrame( + { + "index": [0], + "data1": [0], + "data2": [0], + } + ) + frame2 = pl.LazyFrame( + { + "index": [0], + "label1": [True], + "label2": [False], + "label3": [False], + }, + schema=[ + ("index", pl.Int64), + ("label1", pl.Boolean), + ("label2", pl.Boolean), + ("label3", pl.Boolean), + ], + ) + cols = ["index", "data1", "label1", "label2"] + + lf = None + for lfj in [frame1, frame2]: + use_columns = [c for c in cols if c in lfj.collect_schema().names()] + lfj = lfj.select(use_columns) + lfj = lfj.select(use_columns) + if lf is None: + lf = lfj + else: + lf = lf.join(lfj, on="index", how="left") + assert "SELECT " not in lf.explain() # type: ignore[union-attr] + + +def test_with_columns_projection_pushdown() -> None: + # # Summary + # `process_hstack` in projection PD incorrectly took a fast-path meant for + # LP nodes that don't add new columns to the schema, which stops projection + # PD if it sees that the schema lengths on the upper node matches. + # + # To trigger this, we drop the same number of columns before and after + # the with_columns, and in the with_columns we also add the same number of + # columns. + lf = ( + pl.scan_csv( + b"""\ +a,b,c,d,e +1,1,1,1,1 +""", + include_file_paths="path", + ) + .drop("a", "b") + .with_columns(pl.lit(1).alias(x) for x in ["x", "y"]) + .drop("c", "d") + ) + + plan = lf.explain().strip() + + assert plan.startswith("WITH_COLUMNS:") + # [dyn int: 1.alias("x"), dyn int: 1.alias("y")] + # Csv SCAN [20 in-mem bytes] + assert plan.endswith("1/6 COLUMNS") + + +def test_projection_pushdown_height_20221() -> None: + q = pl.LazyFrame({"a": range(5)}).select("a", b=pl.col("a").first()).select("b") + assert_frame_equal(q.collect(), pl.DataFrame({"b": [0, 0, 0, 0, 0]})) + + +def test_select_len_20337() -> None: + strs = [str(i) for i in range(3)] + q = pl.LazyFrame({"a": strs, "b": strs, "c": strs, "d": range(3)}) + + q = q.group_by(pl.col("c")).agg( + (pl.col("d") * j).alias(f"mult {j}") for j in [1, 2] + ) + + q = q.with_row_index("foo") + assert q.select(pl.len()).collect().item() == 3 + + +def test_filter_count_projection_20902() -> None: + lineitem_ldf = pl.LazyFrame( + { + "l_partkey": [1], + "l_quantity": [1], + "l_extendedprice": [1], + } + ) + assert ( + "1/3 COLUMNS" + in lineitem_ldf.filter(pl.col("l_partkey").is_between(10, 20)) + .select(pl.len()) + .explain() + ) + + +def test_projection_count_21154() -> None: + lf = pl.LazyFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + } + ) + + assert lf.unique("a").select(pl.len()).collect().to_dict(as_series=False) == { + "len": [3] + } diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py new file mode 100644 index 000000000000..d5620f2d2ce8 --- /dev/null +++ b/py-polars/tests/unit/test_queries.py @@ -0,0 +1,390 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta + +import numpy as np +import pandas as pd +import pytest + +import polars as pl +from polars.testing import assert_frame_equal +from tests.unit.conftest import NUMERIC_DTYPES + + +def test_sort_by_bools() -> None: + # tests dispatch + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + out = df.with_columns((pl.col("foo") % 2 == 1).alias("foo_odd")).sort( + by=["foo_odd", "foo"] + ) + assert out.rows() == [ + (2, 7.0, "b", False), + (1, 6.0, "a", True), + (3, 8.0, "c", True), + ] + assert out.shape == (3, 4) + + +def test_repeat_expansion_in_group_by() -> None: + out = ( + pl.DataFrame({"g": [1, 2, 2, 3, 3, 3]}) + .group_by("g", maintain_order=True) + .agg(pl.repeat(1, pl.len()).cum_sum()) + .to_dict(as_series=False) + ) + assert out == {"g": [1, 2, 3], "repeat": [[1], [1, 2], [1, 2, 3]]} + + +def test_agg_after_head() -> None: + a = [1, 1, 1, 2, 2, 3, 3, 3, 3] + + df = pl.DataFrame({"a": a, "b": pl.arange(1, len(a) + 1, eager=True)}) + + expected = pl.DataFrame({"a": [1, 2, 3], "b": [6, 9, 21]}) + + for maintain_order in [True, False]: + out = df.group_by("a", maintain_order=maintain_order).agg( + [pl.col("b").head(3).sum()] + ) + + if not maintain_order: + out = out.sort("a") + + assert_frame_equal(out, expected) + + +def test_overflow_uint16_agg_mean() -> None: + assert ( + pl.DataFrame( + { + "col1": ["A" for _ in range(1025)], + "col3": [64 for _ in range(1025)], + } + ) + .with_columns(pl.col("col3").cast(pl.UInt16)) + .group_by(["col1"]) + .agg(pl.col("col3").mean()) + .to_dict(as_series=False) + ) == {"col1": ["A"], "col3": [64.0]} + + +def test_binary_on_list_agg_3345() -> None: + df = pl.DataFrame( + { + "group": ["A", "A", "A", "B", "B", "B", "B"], + "id": [1, 2, 1, 4, 5, 4, 6], + } + ) + + assert ( + df.group_by(["group"], maintain_order=True) + .agg( + [ + ( + (pl.col("id").unique_counts() / pl.col("id").len()).log() + * -1 + * (pl.col("id").unique_counts() / pl.col("id").len()) + ).sum() + ] + ) + .to_dict(as_series=False) + ) == {"group": ["A", "B"], "id": [0.6365141682948128, 1.0397207708399179]} + + +def test_maintain_order_after_sampling() -> None: + # internally samples cardinality + # check if the maintain_order kwarg is dispatched + df = pl.DataFrame( + { + "type": ["A", "B", "C", "D", "A", "B", "C", "D"], + "value": [1, 3, 2, 3, 4, 5, 3, 4], + } + ) + + result = df.group_by("type", maintain_order=True).agg(pl.col("value").sum()) + expected = {"type": ["A", "B", "C", "D"], "value": [5, 8, 5, 7]} + assert result.to_dict(as_series=False) == expected + + +@pytest.mark.may_fail_auto_streaming +def test_sorted_group_by_optimization() -> None: + df = pl.DataFrame({"a": np.random.randint(0, 5, 20)}) + + # the sorted optimization should not randomize the + # groups, so this is tests that we hit the sorted optimization + for descending in [True, False]: + sorted_implicit = ( + df.with_columns(pl.col("a").sort(descending=descending)) + .group_by("a") + .agg(pl.len()) + ) + sorted_explicit = ( + df.group_by("a").agg(pl.len()).sort("a", descending=descending) + ) + assert_frame_equal(sorted_explicit, sorted_implicit) + + +def test_median_on_shifted_col_3522() -> None: + df = pl.DataFrame( + { + "foo": [ + datetime(2022, 5, 5, 12, 31, 34), + datetime(2022, 5, 5, 12, 47, 1), + datetime(2022, 5, 6, 8, 59, 11), + ] + } + ) + diffs = df.select(pl.col("foo").diff().dt.total_seconds()) + assert diffs.select(pl.col("foo").median()).to_series()[0] == 36828.5 + + +def test_group_by_agg_equals_zero_3535() -> None: + # setup test frame + df = pl.DataFrame( + data=[ + # note: the 'bb'-keyed values should clearly sum to 0 + ("aa", 10, None), + ("bb", -10, 0.5), + ("bb", 10, -0.5), + ("cc", -99, 10.5), + ("cc", None, 0.0), + ], + schema=[ + ("key", pl.String), + ("val1", pl.Int16), + ("val2", pl.Float32), + ], + orient="row", + ) + # group by the key, aggregating the two numeric cols + assert df.group_by(pl.col("key"), maintain_order=True).agg( + [pl.col("val1").sum(), pl.col("val2").sum()] + ).to_dict(as_series=False) == { + "key": ["aa", "bb", "cc"], + "val1": [10, 0, -99], + "val2": [0.0, 0.0, 10.5], + } + + +def test_group_by_followed_by_limit() -> None: + lf = pl.LazyFrame( + { + "key": ["xx", "yy", "zz", "xx", "zz", "zz"], + "val1": [15, 25, 10, 20, 20, 20], + "val2": [-33, 20, 44, -2, 16, 71], + } + ) + grp = lf.group_by("key", maintain_order=True).agg(pl.col("val1", "val2").sum()) + assert sorted(grp.collect().rows()) == [ + ("xx", 35, -35), + ("yy", 25, 20), + ("zz", 50, 131), + ] + assert sorted(grp.head(2).collect().rows()) == [ + ("xx", 35, -35), + ("yy", 25, 20), + ] + assert sorted(grp.head(10).collect().rows()) == [ + ("xx", 35, -35), + ("yy", 25, 20), + ("zz", 50, 131), + ] + + +def test_dtype_concat_3735() -> None: + for dt in NUMERIC_DTYPES: + d1 = pl.DataFrame([pl.Series("val", [1, 2], dtype=dt)]) + + d2 = pl.DataFrame([pl.Series("val", [3, 4], dtype=dt)]) + df = pl.concat([d1, d2]) + + assert df.shape == (4, 1) + assert df.columns == ["val"] + assert df.to_series().to_list() == [1, 2, 3, 4] + + +def test_opaque_filter_on_lists_3784() -> None: + df = pl.DataFrame( + {"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]} + ).lazy() + df = df.with_columns(pl.col("str").cast(pl.Categorical)) + + df_groups = df.group_by("group").agg([pl.col("str").alias("str_list")]) + + pre = "A" + succ = "B" + + assert ( + df_groups.filter( + pl.col("str_list").map_elements( + lambda variant: pre in variant + and succ in variant + and variant.to_list().index(pre) < variant.to_list().index(succ) + ) + ) + ).collect().to_dict(as_series=False) == { + "group": [1], + "str_list": [["A", "B", "B"]], + } + + +def test_ternary_none_struct() -> None: + ignore_nulls = False + + def map_expr(name: str) -> pl.Expr: + return ( + pl.when(ignore_nulls or pl.col(name).null_count() == 0) + .then( + pl.struct( + [ + pl.sum(name).alias("sum"), + (pl.len() - pl.col(name).null_count()).alias("count"), + ] + ), + ) + .otherwise(None) + ).alias("out") + + assert ( + pl.DataFrame({"groups": [1, 2, 3, 4], "values": [None, None, 1, 2]}) + .group_by("groups", maintain_order=True) + .agg([map_expr("values")]) + ).to_dict(as_series=False) == { + "groups": [1, 2, 3, 4], + "out": [ + None, + None, + {"sum": 1, "count": 1}, + {"sum": 2, "count": 1}, + ], + } + + +def test_edge_cast_string_duplicates_4259() -> None: + # carefully constructed data. + # note that row 2, 3 concatenated are the same string ('5461214484') + df = pl.DataFrame( + { + "a": [99, 54612, 546121], + "b": [1, 14484, 4484], + } + ).with_columns(pl.all().cast(pl.String)) + + mask = df.select(["a", "b"]).is_duplicated() + df_filtered = df.filter(pl.lit(mask)) + + assert df_filtered.shape == (0, 2) + assert df_filtered.rows() == [] + + +def test_query_4438() -> None: + df = pl.DataFrame({"x": [1, 2, 3, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1]}) + + q = ( + df.lazy() + .with_columns(pl.col("x").rolling_max(window_size=3).alias("rolling_max")) + .fill_null(strategy="backward") + .with_columns( + pl.col("rolling_max").rolling_max(window_size=3).alias("rolling_max_2") + ) + ) + assert q.collect()["rolling_max_2"].to_list() == [ + None, + None, + 3, + 10, + 10, + 10, + 10, + 10, + 9, + 8, + 7, + 6, + 5, + 4, + 3, + ] + + +def test_query_4538() -> None: + df = pl.DataFrame( + [ + pl.Series("value", ["aaa", "bbb"]), + ] + ) + assert df.select([pl.col("value").str.to_uppercase().is_in(["AAA"])])[ + "value" + ].to_list() == [True, False] + + +def test_none_comparison_4773() -> None: + df = pl.DataFrame( + { + "x": [0, 1, None, 2], + "y": [1, 2, None, 3], + } + ).filter(pl.col("x") != pl.col("y")) + assert df.shape == (3, 2) + assert df.rows() == [(0, 1), (1, 2), (2, 3)] + + +def test_datetime_supertype_5236() -> None: + df = pd.DataFrame( + { + "StartDateTime": [pd.Timestamp.now(tz="UTC"), pd.Timestamp.now(tz="UTC")], + "EndDateTime": [pd.Timestamp.now(tz="UTC"), pd.Timestamp.now(tz="UTC")], + } + ) + out = pl.from_pandas(df).filter( + pl.col("StartDateTime") + < (pl.col("EndDateTime").dt.truncate("1d").max() - timedelta(days=1)) + ) + assert out.shape == (0, 2) + assert out.dtypes == [pl.Datetime("ns", "UTC")] * 2 + + +def test_shift_drop_nulls_10875() -> None: + assert pl.LazyFrame({"a": [1, 2, 3]}).shift(1).drop_nulls().collect()[ + "a" + ].to_list() == [1, 2] + + +def test_temporal_downcasts() -> None: + s = pl.Series([-1, 0, 1]).cast(pl.Datetime("us")) + + assert s.to_list() == [ + datetime(1969, 12, 31, 23, 59, 59, 999999), + datetime(1970, 1, 1), + datetime(1970, 1, 1, 0, 0, 0, 1), + ] + + # downcast (from us to ms, or from datetime to date) should NOT change the date + for s_dt in (s.dt.date(), s.cast(pl.Date)): + assert s_dt.to_list() == [ + date(1969, 12, 31), + date(1970, 1, 1), + date(1970, 1, 1), + ] + assert s.cast(pl.Datetime("ms")).to_list() == [ + datetime(1969, 12, 31, 23, 59, 59, 999000), + datetime(1970, 1, 1), + datetime(1970, 1, 1), + ] + + +def test_temporal_time_casts() -> None: + s = pl.Series([-1, 0, 1]).cast(pl.Datetime("us")) + + for s_dt in (s.dt.time(), s.cast(pl.Time)): + assert s_dt.to_list() == [ + time(23, 59, 59, 999999), + time(0, 0, 0, 0), + time(0, 0, 0, 1), + ] diff --git a/py-polars/tests/unit/test_row_encoding.py b/py-polars/tests/unit/test_row_encoding.py new file mode 100644 index 000000000000..228d7ec2d859 --- /dev/null +++ b/py-polars/tests/unit/test_row_encoding.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +from decimal import Decimal as D +from typing import TYPE_CHECKING + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import dataframes, series +from polars.testing.parametric.strategies.dtype import dtypes +from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES + +if TYPE_CHECKING: + from polars._typing import PolarsDataType + +FIELD_COMBS = [ + (descending, nulls_last, False) + for descending in [False, True] + for nulls_last in [False, True] +] + [(False, False, True)] + + +def roundtrip_re( + df: pl.DataFrame, fields: list[tuple[bool, bool, bool]] | None = None +) -> None: + if fields is None: + fields = [(False, False, False)] * df.width + + row_encoded = df._row_encode(fields) + if any(f[2] for f in fields): + return + + dtypes = [(c, df.get_column(c).dtype) for c in df.columns] + result = row_encoded._row_decode(dtypes, fields) + + assert_frame_equal(df, result) + + +def roundtrip_series_re( + values: pl.series.series.ArrayLike, + dtype: PolarsDataType, + field: tuple[bool, bool, bool], +) -> None: + roundtrip_re(pl.Series("series", values, dtype).to_frame(), [field]) + + +@given( + df=dataframes( + excluded_dtypes=[ + pl.Categorical, + pl.Decimal, # Bug: see https://github.com/pola-rs/polars/issues/20308 + ] + ) +) +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_row_encoding_parametric( + df: pl.DataFrame, field: tuple[bool, bool, bool] +) -> None: + roundtrip_re(df, [field] * df.width) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_nulls(field: tuple[bool, bool, bool]) -> None: + roundtrip_series_re([], pl.Null, field) + roundtrip_series_re([None], pl.Null, field) + roundtrip_series_re([None] * 2, pl.Null, field) + roundtrip_series_re([None] * 13, pl.Null, field) + roundtrip_series_re([None] * 42, pl.Null, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_bool(field: tuple[bool, bool, bool]) -> None: + roundtrip_series_re([], pl.Boolean, field) + roundtrip_series_re([False], pl.Boolean, field) + roundtrip_series_re([True], pl.Boolean, field) + roundtrip_series_re([False, True], pl.Boolean, field) + roundtrip_series_re([True, False], pl.Boolean, field) + + +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_int(dtype: pl.DataType, field: tuple[bool, bool, bool]) -> None: + min = pl.select(x=dtype.min()).item() # type: ignore[attr-defined] + max = pl.select(x=dtype.max()).item() # type: ignore[attr-defined] + + roundtrip_series_re([], dtype, field) + roundtrip_series_re([0], dtype, field) + roundtrip_series_re([min], dtype, field) + roundtrip_series_re([max], dtype, field) + + roundtrip_series_re([1, 2, 3], dtype, field) + roundtrip_series_re([0, 1, 2, 3], dtype, field) + roundtrip_series_re([min, 0, max], dtype, field) + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_float(dtype: pl.DataType, field: tuple[bool, bool, bool]) -> None: + inf = float("inf") + inf_b = float("-inf") + + roundtrip_series_re([], dtype, field) + roundtrip_series_re([0.0], dtype, field) + roundtrip_series_re([inf], dtype, field) + roundtrip_series_re([-inf_b], dtype, field) + + roundtrip_series_re([1.0, 2.0, 3.0], dtype, field) + roundtrip_series_re([0.0, 1.0, 2.0, 3.0], dtype, field) + roundtrip_series_re([inf, 0, -inf_b], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_str(field: tuple[bool, bool, bool]) -> None: + dtype = pl.String + roundtrip_series_re([], dtype, field) + roundtrip_series_re([""], dtype, field) + + roundtrip_series_re(["a", "b", "c"], dtype, field) + roundtrip_series_re(["", "a", "b", "c"], dtype, field) + + roundtrip_series_re( + ["different", "length", "strings"], + dtype, + field, + ) + roundtrip_series_re( + ["different", "", "length", "", "strings"], + dtype, + field, + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_struct(field: tuple[bool, bool, bool]) -> None: + dtype = pl.Struct({}) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([{}], dtype, field) + roundtrip_series_re([{}, {}, {}], dtype, field) + roundtrip_series_re([{}, None, {}], dtype, field) + + dtype = pl.Struct({"x": pl.Int32}) + roundtrip_series_re([{"x": 1}], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([{"x": 1}] * 3, dtype, field) + + dtype = pl.Struct({"x": pl.Int32, "y": pl.Int32}) + roundtrip_series_re( + [{"x": 1}, {"y": 2}], + dtype, + field, + ) + roundtrip_series_re([None], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Int32) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([[1], [2]], dtype, field) + roundtrip_series_re([[1, 2], [3]], dtype, field) + roundtrip_series_re([[1, 2], [], [3]], dtype, field) + roundtrip_series_re([None, [1, 2], None, [], [3]], dtype, field) + + dtype = pl.List(pl.String) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([[""], [""]], dtype, field) + roundtrip_series_re([["abc"], ["xyzw"]], dtype, field) + roundtrip_series_re([["x", "yx"], ["abc"]], dtype, field) + roundtrip_series_re([["wow", "this is"], [], ["cool"]], dtype, field) + roundtrip_series_re( + [None, ["very", "very"], None, [], ["cool"]], + dtype, + field, + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_array(field: tuple[bool, bool, bool]) -> None: + dtype = pl.Array(pl.Int32, 0) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([None, [], None], dtype, field) + roundtrip_series_re([None], dtype, field) + + dtype = pl.Array(pl.Int32, 2) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[5, 6]], dtype, field) + roundtrip_series_re([[1, 2], [2, 3]], dtype, field) + roundtrip_series_re([[1, 2], [3, 7]], dtype, field) + roundtrip_series_re([[1, 2], [13, 11], [3, 7]], dtype, field) + roundtrip_series_re( + [None, [1, 2], None, [13, 11], [5, 7]], + dtype, + field, + ) + + dtype = pl.Array(pl.String, 2) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([["a", "b"]], dtype, field) + roundtrip_series_re([["", ""], ["", "a"]], dtype, field) + roundtrip_series_re([["abc", "def"], ["ghi", "xyzw"]], dtype, field) + roundtrip_series_re([["x", "yx"], ["abc", "xxx"]], dtype, field) + roundtrip_series_re( + [["wow", "this is"], ["soo", "so"], ["veryyy", "cool"]], + dtype, + field, + ) + roundtrip_series_re( + [None, ["very", "very"], None, [None, None], ["verryy", "cool"]], + dtype, + field, + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +@pytest.mark.parametrize("precision", range(1, 38)) +def test_decimal(field: tuple[bool, bool, bool], precision: int) -> None: + dtype = pl.Decimal(precision=precision, scale=0) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([D("1")], dtype, field) + roundtrip_series_re([D("-1")], dtype, field) + roundtrip_series_re([D("9" * precision)], dtype, field) + roundtrip_series_re([D("-" + "9" * precision)], dtype, field) + roundtrip_series_re([None, D("-1"), None], dtype, field) + roundtrip_series_re([D("-1"), D("0"), D("1")], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_enum(field: tuple[bool, bool, bool]) -> None: + dtype = pl.Enum([]) + + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([None, None], dtype, field) + + dtype = pl.Enum(["a", "x", "b"]) + + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re(["a"], dtype, field) + roundtrip_series_re(["x"], dtype, field) + roundtrip_series_re(["b"], dtype, field) + roundtrip_series_re(["b", "x", "a"], dtype, field) + roundtrip_series_re([None, "b", None], dtype, field) + roundtrip_series_re([None, "a", None], dtype, field) + + +@pytest.mark.parametrize("size", [127, 128, 255, 256, 2**15, 2**15 + 1]) +@pytest.mark.parametrize("field", FIELD_COMBS) +@pytest.mark.slow +def test_large_enum(size: int, field: tuple[bool, bool, bool]) -> None: + dtype = pl.Enum([str(i) for i in range(size)]) + roundtrip_series_re([None, "1"], dtype, field) + roundtrip_series_re(["1", None], dtype, field) + + roundtrip_series_re( + [str(i) for i in range(3, size, int(7 * size / (2**8)))], dtype, field + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list_arr(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Array(pl.String, 2)) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([[None]], dtype, field) + roundtrip_series_re([[[None, None]]], dtype, field) + roundtrip_series_re([[["a", "b"]]], dtype, field) + roundtrip_series_re([[["a", "b"], ["xyz", "wowie"]]], dtype, field) + roundtrip_series_re([[["a", "b"]], None, [None, None]], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list_struct_arr(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List( + pl.Struct({"x": pl.Array(pl.String, 2), "y": pl.Array(pl.Int64, 3)}) + ) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([[None]], dtype, field) + roundtrip_series_re([[{"x": None, "y": None}]], dtype, field) + roundtrip_series_re([[{"x": ["a", None], "y": [1, None, 3]}]], dtype, field) + roundtrip_series_re([[{"x": ["a", "xyz"], "y": [1, 7, 3]}]], dtype, field) + roundtrip_series_re([[{"x": ["a", "xyz"], "y": [1, 7, 3]}], []], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_list_nulls(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Null) + roundtrip_series_re([], dtype, field) + roundtrip_series_re([[]], dtype, field) + roundtrip_series_re([None], dtype, field) + roundtrip_series_re([[None]], dtype, field) + roundtrip_series_re([[None, None, None]], dtype, field) + roundtrip_series_re([[None], [None, None], [None, None, None]], dtype, field) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +def test_masked_out_list_20151(field: tuple[bool, bool, bool]) -> None: + dtype = pl.List(pl.Int64()) + + values = [[1, 2], None, [4, 5], [None, 3]] + + array_series = pl.Series(values, dtype=pl.Array(pl.Int64(), 2)) + list_from_array_series = array_series.cast(dtype) + + roundtrip_series_re(list_from_array_series, dtype, field) + + +def test_int_after_null() -> None: + roundtrip_re( + pl.DataFrame( + [ + pl.Series("a", [None], pl.Null), + pl.Series("b", [None], pl.Int8), + ] + ), + [(False, True, False), (False, True, False)], + ) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +@given(s=series(allow_null=False, allow_chunks=False, excluded_dtypes=[pl.Categorical])) +def test_optional_eq_non_optional_20320( + field: tuple[bool, bool, bool], s: pl.Series +) -> None: + with_null = s.extend(pl.Series([None], dtype=s.dtype)) + + re_without_null = s.to_frame()._row_encode([field]) + re_with_null = with_null.to_frame()._row_encode([field]) + + re_without_null = re_without_null.cast(pl.Binary) + re_with_null = re_with_null.cast(pl.Binary) + + assert_series_equal(re_with_null.head(s.len()), re_without_null) + + +@pytest.mark.parametrize("field", FIELD_COMBS) +@given(dtype=dtypes(excluded_dtypes=[pl.Categorical])) +def test_null( + field: tuple[bool, bool, bool], + dtype: pl.DataType, +) -> None: + s = pl.Series("a", [None], dtype) + + assert_series_equal( + s.to_frame() + ._row_encode([field]) + ._row_decode([("a", dtype)], [field]) + .to_series(), + s, + ) diff --git a/py-polars/tests/unit/test_row_encoding_sort.py b/py-polars/tests/unit/test_row_encoding_sort.py new file mode 100644 index 000000000000..6184f2972a77 --- /dev/null +++ b/py-polars/tests/unit/test_row_encoding_sort.py @@ -0,0 +1,399 @@ +# mypy: disable-error-code="valid-type" + +from __future__ import annotations + +import datetime +import decimal +import functools +from typing import Any, Literal, Optional, Union, cast + +import pytest +from hypothesis import example, given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes, series + +Element = Optional[ + Union[ + bool, + int, + float, + str, + decimal.Decimal, + datetime.date, + datetime.datetime, + datetime.time, + datetime.timedelta, + list[Any], + dict[Any, Any], + ] +] +OrderSign = Literal[-1, 0, 1] + + +def elem_order_sign( + lhs: Element, rhs: Element, *, descending: bool, nulls_last: bool +) -> OrderSign: + if isinstance(lhs, pl.Series) and isinstance(rhs, pl.Series): + assert lhs.dtype == rhs.dtype + + if isinstance(lhs.dtype, pl.Enum) or lhs.dtype == pl.Categorical( + ordering="physical" + ): + lhs = cast(Element, lhs.to_physical()) + rhs = cast(Element, rhs.to_physical()) + assert isinstance(lhs, pl.Series) + assert isinstance(rhs, pl.Series) + + if lhs.dtype == pl.Categorical(ordering="lexical"): + lhs = cast(Element, lhs.cast(pl.String)) + rhs = cast(Element, rhs.cast(pl.String)) + assert isinstance(lhs, pl.Series) + assert isinstance(rhs, pl.Series) + + if lhs.is_null().equals(rhs.is_null()) and lhs.equals(rhs): + return 0 + + lhs = lhs.to_list() + rhs = rhs.to_list() + + if lhs is None and rhs is None: + return 0 + elif lhs is None: + return 1 if nulls_last else -1 + elif rhs is None: + return -1 if nulls_last else 1 + elif lhs == rhs: + return 0 + elif isinstance(lhs, bool) and isinstance(rhs, bool): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.date) and isinstance(rhs, datetime.date): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.datetime) and isinstance(rhs, datetime.datetime): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.time) and isinstance(rhs, datetime.time): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, datetime.timedelta) and isinstance(rhs, datetime.timedelta): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, decimal.Decimal) and isinstance(rhs, decimal.Decimal): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, int) and isinstance(rhs, int): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, float) and isinstance(rhs, float): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, bytes) and isinstance(rhs, bytes): + lhs_b: bytes = lhs + rhs_b: bytes = rhs + + for lh, rh in zip(lhs_b, rhs_b): + o = elem_order_sign(lh, rh, descending=descending, nulls_last=nulls_last) + if o != 0: + return o + + if len(lhs_b) == len(rhs_b): + return 0 + else: + return -1 if (len(lhs_b) < len(rhs_b)) ^ descending else 1 + elif isinstance(lhs, str) and isinstance(rhs, str): + return -1 if (lhs < rhs) ^ descending else 1 + elif isinstance(lhs, list) and isinstance(rhs, list): + for lh, rh in zip(lhs, rhs): + o = elem_order_sign(lh, rh, descending=descending, nulls_last=nulls_last) + if o != 0: + return o + + if len(lhs) == len(rhs): + return 0 + else: + return -1 if (len(lhs) < len(rhs)) ^ descending else 1 + elif isinstance(lhs, dict) and isinstance(rhs, dict): + assert len(lhs) == len(rhs) + + for lh, rh in zip(lhs.values(), rhs.values()): + o = elem_order_sign(lh, rh, descending=descending, nulls_last=nulls_last) + if o != 0: + return o + + return 0 + else: + pytest.fail("type mismatch") + + +def tuple_order( + lhs: tuple[Element, ...], + rhs: tuple[Element, ...], + *, + descending: list[bool], + nulls_last: list[bool], +) -> OrderSign: + assert len(lhs) == len(rhs) + + for lh, rh, dsc, nl in zip(lhs, rhs, descending, nulls_last): + o = elem_order_sign(lh, rh, descending=dsc, nulls_last=nl) + if o != 0: + return o + + return 0 + + +@given( + s=series( + excluded_dtypes=[ + pl.Float32, # We cannot really deal with totalOrder + pl.Float64, # We cannot really deal with totalOrder + pl.Decimal, # Bug: see https://github.com/pola-rs/polars/issues/20308 + pl.Categorical, + ], + max_size=5, + ) +) +@example(s=pl.Series("col0", [None, [None]], pl.List(pl.Int64))) +def test_series_sort_parametric(s: pl.Series) -> None: + for descending in [False, True]: + for nulls_last in [False, True]: + fields = [(descending, nulls_last, False)] + + def cmp( + lhs: Element, + rhs: Element, + descending: bool = descending, + nulls_last: bool = nulls_last, + ) -> OrderSign: + return elem_order_sign( + lhs, rhs, descending=descending, nulls_last=nulls_last + ) + + rows = list(s) + rows.sort(key=functools.cmp_to_key(cmp)) # type: ignore[arg-type, unused-ignore] + + re = s.to_frame()._row_encode(fields) + re_sorted = re.sort() + re_decoded = re_sorted._row_decode([("s", s.dtype)], fields) + + assert_series_equal( + pl.Series("s", rows, s.dtype), re_decoded.get_column("s") + ) + + +@given( + df=dataframes( + excluded_dtypes=[ + pl.Float32, # We cannot really deal with totalOrder + pl.Float64, # We cannot really deal with totalOrder + pl.Decimal, # Bug: see https://github.com/pola-rs/polars/issues/20308 + pl.List, # I am not sure what is broken here. + pl.Array, # I am not sure what is broken here. + pl.Enum, + pl.Categorical, + ], + max_cols=3, + max_size=5, + ) +) +def test_df_sort_parametric(df: pl.DataFrame) -> None: + for i in range(4**df.width): + descending = [((i // (4**j)) % 4) in [2, 3] for j in range(df.width)] + nulls_last = [((i // (4**j)) % 4) in [1, 3] for j in range(df.width)] + + fields = [ + (descending, nulls_last, False) + for (descending, nulls_last) in zip(descending, nulls_last) + ] + + def cmp( + lhs: tuple[Element, ...], + rhs: tuple[Element, ...], + descending: list[bool] = descending, + nulls_last: list[bool] = nulls_last, + ) -> OrderSign: + return tuple_order(lhs, rhs, descending=descending, nulls_last=nulls_last) + + rows = df.rows() + rows.sort(key=functools.cmp_to_key(cmp)) # type: ignore[arg-type, unused-ignore] + + re = df._row_encode(fields) + re_sorted = re.sort() + re_decoded = re_sorted._row_decode(df.schema.items(), fields) + + assert_frame_equal(pl.DataFrame(rows, df.schema, orient="row"), re_decoded) + + +def assert_order_series( + lhs: pl.series.series.ArrayLike, + rhs: pl.series.series.ArrayLike, + dtype: pl._typing.PolarsDataType, +) -> None: + lhs_df = pl.Series("lhs", lhs, dtype).to_frame() + rhs_df = pl.Series("rhs", rhs, dtype).to_frame() + + for descending in [False, True]: + for nulls_last in [False, True]: + field = (descending, nulls_last, False) + l_re = lhs_df._row_encode([field]).cast(pl.Binary) + r_re = rhs_df._row_encode([field]).cast(pl.Binary) + + order = [ + elem_order_sign( + lh[0], rh[0], descending=descending, nulls_last=nulls_last + ) + for (lh, rh) in zip(lhs_df.rows(), rhs_df.rows()) + ] + + assert_series_equal( + l_re < r_re, pl.Series([o == -1 for o in order]), check_names=False + ) + assert_series_equal( + l_re == r_re, pl.Series([o == 0 for o in order]), check_names=False + ) + assert_series_equal( + l_re > r_re, pl.Series([o == 1 for o in order]), check_names=False + ) + + +def parametric_order_base(df: pl.DataFrame) -> None: + lhs = df.get_columns()[0] + rhs = df.get_columns()[1] + + field = (False, False, False) + lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) + rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) + + assert_series_equal(lhs < rhs, lhs_re < rhs_re, check_names=False) + assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) + assert_series_equal(lhs > rhs, lhs_re > rhs_re, check_names=False) + + field = (True, False, False) + lhs_re = lhs.to_frame()._row_encode([field]).cast(pl.Binary) + rhs_re = rhs.to_frame()._row_encode([field]).cast(pl.Binary) + + assert_series_equal(lhs > rhs, lhs_re < rhs_re, check_names=False) + assert_series_equal(lhs == rhs, lhs_re == rhs_re, check_names=False) + assert_series_equal(lhs < rhs, lhs_re > rhs_re, check_names=False) + + +@given( + df=dataframes([column(dtype=pl.Int32), column(dtype=pl.Int32)], allow_null=False) +) +def test_parametric_int_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.UInt32), column(dtype=pl.UInt32)], allow_null=False) +) +def test_parametric_uint_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.String), column(dtype=pl.String)], allow_null=False) +) +def test_parametric_string_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +@given( + df=dataframes([column(dtype=pl.Binary), column(dtype=pl.Binary)], allow_null=False) +) +def test_parametric_binary_order(df: pl.DataFrame) -> None: + parametric_order_base(df) + + +def test_order_bool() -> None: + dtype = pl.Boolean + assert_order_series([None, False, True], [True, False, None], dtype) + assert_order_series( + [None, False, True], + [True, False, None], + dtype, + ) + + assert_order_series( + [False, False, True, True], + [True, False, True, False], + dtype, + ) + assert_order_series( + [False, False, True, True], + [True, False, True, False], + dtype, + ) + + +def test_order_int() -> None: + dtype = pl.Int32 + assert_order_series([1, 2, 3], [3, 2, 1], dtype) + assert_order_series([-1, 0, 1], [1, 0, -1], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [1], dtype) + + +def test_order_uint() -> None: + dtype = pl.UInt32 + assert_order_series([1, 2, 3], [3, 2, 1], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [1], dtype) + + +def test_order_str() -> None: + dtype = pl.String + assert_order_series(["a", "b", "c"], ["c", "b", "a"], dtype) + assert_order_series(["a", "aa", "aaa"], ["aaa", "aa", "a"], dtype) + assert_order_series(["", "a", "aa"], ["aa", "a", ""], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], ["a"], dtype) + + +def test_order_bin() -> None: + dtype = pl.Binary + assert_order_series([b"a", b"b", b"c"], [b"c", b"b", b"a"], dtype) + assert_order_series([b"a", b"aa", b"aaa"], [b"aaa", b"aa", b"a"], dtype) + assert_order_series([b"", b"a", b"aa"], [b"aa", b"a", b""], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [b"a"], dtype) + assert_order_series([None], [b"a"], dtype) + + +def test_order_list() -> None: + dtype = pl.List(pl.Int32) + assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype) + assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [[1, 2, 3]], dtype) + assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype) + + assert_order_series([[]], [[None]], dtype) + assert_order_series([[]], [[1]], dtype) + assert_order_series([[1]], [[1, 2]], dtype) + + +def test_order_array() -> None: + dtype = pl.Array(pl.Int32, 3) + assert_order_series([[1, 2, 3]], [[3, 2, 1]], dtype) + assert_order_series([[-1, 0, 1]], [[1, 0, -1]], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], [[1, 2, 3]], dtype) + assert_order_series([[None, 2, 3]], [[None, 2, 1]], dtype) + + +def test_order_masked_array() -> None: + dtype = pl.Array(pl.Int32, 3) + lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) + rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) + assert_order_series(lhs, rhs, dtype) + + +def test_order_masked_struct() -> None: + dtype = pl.Array(pl.Int32, 3) + lhs = pl.Series("l", [1, 2, 3], pl.Int32).replace(1, None).reshape((1, 3)) + rhs = pl.Series("r", [3, 2, 1], pl.Int32).replace(3, None).reshape((1, 3)) + assert_order_series(lhs.to_frame().to_struct(), rhs.to_frame().to_struct(), dtype) + + +def test_order_enum() -> None: + dtype = pl.Enum(["a", "b", "c"]) + + assert_order_series(["a", "b", "c"], ["c", "b", "a"], dtype) + assert_order_series([None], [None], dtype) + assert_order_series([None], ["a"], dtype) diff --git a/py-polars/tests/unit/test_rows.py b/py-polars/tests/unit/test_rows.py new file mode 100644 index 000000000000..d9c68a2e3fe7 --- /dev/null +++ b/py-polars/tests/unit/test_rows.py @@ -0,0 +1,259 @@ +from datetime import date + +import pytest + +import polars as pl +from polars.exceptions import NoRowsReturnedError, TooManyRowsReturnedError +from tests.unit.conftest import INTEGER_DTYPES + + +def test_row_tuple() -> None: + df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]}) + + # return row by index + assert df.row(0) == ("foo", 1, 1.0) + assert df.row(1) == ("bar", 2, 2.0) + assert df.row(-1) == ("2", 3, 3.0) + + # return named row by index + row = df.row(0, named=True) + assert row == {"a": "foo", "b": 1, "c": 1.0} + + # return row by predicate + assert df.row(by_predicate=pl.col("a") == "bar") == ("bar", 2, 2.0) + assert df.row(by_predicate=pl.col("b").is_in([2, 4, 6])) == ("bar", 2, 2.0) + + # return named row by predicate + row = df.row(by_predicate=pl.col("a") == "bar", named=True) + assert row == {"a": "bar", "b": 2, "c": 2.0} + + # expected error conditions + with pytest.raises(TooManyRowsReturnedError): + df.row(by_predicate=pl.col("b").is_in([1, 3, 5])) + + with pytest.raises(NoRowsReturnedError): + df.row(by_predicate=pl.col("a") == "???") + + # cannot set both 'index' and 'by_predicate' + with pytest.raises(ValueError): + df.row(0, by_predicate=pl.col("a") == "bar") + + # must call 'by_predicate' by keyword + with pytest.raises(TypeError): + df.row(None, pl.col("a") == "bar") # type: ignore[call-overload] + + # cannot pass predicate into 'index' + with pytest.raises(TypeError): + df.row(pl.col("a") == "bar") # type: ignore[call-overload] + + # at least one of 'index' and 'by_predicate' must be set + with pytest.raises(ValueError): + df.row() + + +def test_rows() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [1, 2]}) + + # Regular rows + assert df.rows() == [(1, 1), (2, 2)] + assert df.reverse().rows() == [(2, 2), (1, 1)] + + # Named rows + rows = df.rows(named=True) + assert rows == [{"a": 1, "b": 1}, {"a": 2, "b": 2}] + + # Rows with nullarray cols + df = df.with_columns(c=pl.lit(None)) + assert df.schema == {"a": pl.Int64, "b": pl.Int64, "c": pl.Null} + assert df.rows() == [(1, 1, None), (2, 2, None)] + assert df.rows(named=True) == [ + {"a": 1, "b": 1, "c": None}, + {"a": 2, "b": 2, "c": None}, + ] + + +def test_rows_by_key() -> None: + df = pl.DataFrame( + { + "w": ["a", "b", "b", "a"], + "x": ["q", "q", "q", "k"], + "y": [1.0, 2.5, 3.0, 4.5], + "z": [9, 8, 7, 6], + } + ) + + # tuple (unnamed) rows + assert df.rows_by_key("w") == { + "a": [("q", 1.0, 9), ("k", 4.5, 6)], + "b": [("q", 2.5, 8), ("q", 3.0, 7)], + } + assert df.rows_by_key("w", unique=True) == { + "a": ("k", 4.5, 6), + "b": ("q", 3.0, 7), + } + assert df.rows_by_key("w", include_key=True) == { + "a": [("a", "q", 1.0, 9), ("a", "k", 4.5, 6)], + "b": [("b", "q", 2.5, 8), ("b", "q", 3.0, 7)], + } + assert df.rows_by_key("w", include_key=True) == { + key[0]: grp.rows() for key, grp in df.group_by(["w"]) + } + assert df.rows_by_key("w", include_key=True, unique=True) == { + "a": ("a", "k", 4.5, 6), + "b": ("b", "q", 3.0, 7), + } + assert df.rows_by_key(["x", "w"]) == { + ("q", "a"): [(1.0, 9)], + ("q", "b"): [(2.5, 8), (3.0, 7)], + ("k", "a"): [(4.5, 6)], + } + assert df.rows_by_key(["w", "x"], include_key=True) == { + ("a", "q"): [("a", "q", 1.0, 9)], + ("a", "k"): [("a", "k", 4.5, 6)], + ("b", "q"): [("b", "q", 2.5, 8), ("b", "q", 3.0, 7)], + } + assert df.rows_by_key(["w", "x"], include_key=True, unique=True) == { + ("a", "q"): ("a", "q", 1.0, 9), + ("b", "q"): ("b", "q", 3.0, 7), + ("a", "k"): ("a", "k", 4.5, 6), + } + + # dict (named) rows + assert df.rows_by_key("w", named=True) == { + "a": [{"x": "q", "y": 1.0, "z": 9}, {"x": "k", "y": 4.5, "z": 6}], + "b": [{"x": "q", "y": 2.5, "z": 8}, {"x": "q", "y": 3.0, "z": 7}], + } + assert df.rows_by_key("w", named=True, unique=True) == { + "a": {"x": "k", "y": 4.5, "z": 6}, + "b": {"x": "q", "y": 3.0, "z": 7}, + } + assert df.rows_by_key("w", named=True, include_key=True) == { + "a": [ + {"w": "a", "x": "q", "y": 1.0, "z": 9}, + {"w": "a", "x": "k", "y": 4.5, "z": 6}, + ], + "b": [ + {"w": "b", "x": "q", "y": 2.5, "z": 8}, + {"w": "b", "x": "q", "y": 3.0, "z": 7}, + ], + } + assert df.rows_by_key("w", named=True, include_key=True) == { + key[0]: grp.rows(named=True) for key, grp in df.group_by(["w"]) + } + assert df.rows_by_key("w", named=True, include_key=True, unique=True) == { + "a": {"w": "a", "x": "k", "y": 4.5, "z": 6}, + "b": {"w": "b", "x": "q", "y": 3.0, "z": 7}, + } + assert df.rows_by_key(["x", "w"], named=True) == { + ("q", "a"): [{"y": 1.0, "z": 9}], + ("q", "b"): [{"y": 2.5, "z": 8}, {"y": 3.0, "z": 7}], + ("k", "a"): [{"y": 4.5, "z": 6}], + } + assert df.rows_by_key(["w", "x"], named=True, include_key=True) == { + ("a", "q"): [{"w": "a", "x": "q", "y": 1.0, "z": 9}], + ("a", "k"): [{"w": "a", "x": "k", "y": 4.5, "z": 6}], + ("b", "q"): [ + {"w": "b", "x": "q", "y": 2.5, "z": 8}, + {"w": "b", "x": "q", "y": 3.0, "z": 7}, + ], + } + assert df.rows_by_key(["w", "x"], named=True, include_key=True, unique=True) == { + ("a", "q"): {"w": "a", "x": "q", "y": 1.0, "z": 9}, + ("b", "q"): {"w": "b", "x": "q", "y": 3.0, "z": 7}, + ("a", "k"): {"w": "a", "x": "k", "y": 4.5, "z": 6}, + } + + +def test_iter_rows() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [True, False, None], + } + ).with_columns(pl.Series(["a:b", "c:d", "e:f"]).str.split_exact(":", 1).alias("c")) + + # expected struct values + c1 = {"field_0": "a", "field_1": "b"} + c2 = {"field_0": "c", "field_1": "d"} + c3 = {"field_0": "e", "field_1": "f"} + + # Default iter_rows behaviour + it = df.iter_rows() + assert next(it) == (1, True, c1) + assert next(it) == (2, False, c2) + assert next(it) == (3, None, c3) + with pytest.raises(StopIteration): + next(it) + + # Apply explicit row-buffer size + for sz in (0, 1, 2, 3, 4): + it = df.iter_rows(buffer_size=sz) + assert next(it) == (1, True, c1) + assert next(it) == (2, False, c2) + assert next(it) == (3, None, c3) + with pytest.raises(StopIteration): + next(it) + + # Return named rows + it_named = df.iter_rows(named=True, buffer_size=sz) + row = next(it_named) + assert row == {"a": 1, "b": True, "c": c1} + row = next(it_named) + assert row == {"a": 2, "b": False, "c": c2} + row = next(it_named) + assert row == {"a": 3, "b": None, "c": c3} + + with pytest.raises(StopIteration): + next(it_named) + + # test over chunked frame + df = pl.concat( + [ + pl.DataFrame({"id": [0, 1], "values": ["a", "b"]}), + pl.DataFrame({"id": [2, 3], "values": ["c", "d"]}), + ], + rechunk=False, + ) + assert df.n_chunks() == 2 + assert df.to_dicts() == [ + {"id": 0, "values": "a"}, + {"id": 1, "values": "b"}, + {"id": 2, "values": "c"}, + {"id": 3, "values": "d"}, + ] + + +@pytest.mark.parametrize("primitive", INTEGER_DTYPES) +def test_row_constructor_schema(primitive: pl.DataType) -> None: + result = pl.DataFrame(data=[[1], [2], [3]], schema={"d": primitive}, orient="row") + + assert result.dtypes == [primitive] + assert result.to_dict(as_series=False) == {"d": [1, 2, 3]} + + +def test_row_constructor_uint64() -> None: + # validate init with a valid UInt64 that exceeds Int64 upper bound + df = pl.DataFrame( + data=[[0], [int(2**63) + 1]], schema={"x": pl.UInt64}, orient="row" + ) + assert df.rows() == [(0,), (9223372036854775809,)] + + +def test_physical_row_encoding() -> None: + dt_str = [ + { + "ts": date(2023, 7, 1), + "files": "AGG_202307.xlsx", + "period_bins": [date(2023, 7, 1), date(2024, 1, 1)], + }, + ] + + df = pl.from_dicts(dt_str) + df_groups = df.group_by("period_bins") + assert df_groups.all().to_dicts() == [ + { + "period_bins": [date(2023, 7, 1), date(2024, 1, 1)], + "ts": [date(2023, 7, 1)], + "files": ["AGG_202307.xlsx"], + } + ] diff --git a/py-polars/tests/unit/test_scalar.py b/py-polars/tests/unit/test_scalar.py new file mode 100644 index 000000000000..b00d2a66e043 --- /dev/null +++ b/py-polars/tests/unit/test_scalar.py @@ -0,0 +1,104 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.mark.may_fail_auto_streaming +def test_invalid_broadcast() -> None: + df = pl.DataFrame( + { + "a": [100, 103], + "group": [0, 1], + } + ) + with pytest.raises(pl.exceptions.ShapeError): + df.select(pl.col("group").filter(pl.col("group") == 0), "a") + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Null, + pl.Int32, + pl.String, + pl.Enum(["foo"]), + pl.Binary, + pl.List(pl.Int32), + pl.Struct({"a": pl.Int32}), + pl.Array(pl.Int32, 1), + pl.List(pl.List(pl.Int32)), + ], +) +def test_null_literals(dtype: pl.DataType) -> None: + assert ( + pl.DataFrame([pl.Series("a", [1, 2], pl.Int64)]) + .with_columns(pl.lit(None).cast(dtype).alias("b")) + .collect_schema() + .dtypes() + ) == [pl.Int64, dtype] + + +def test_scalar_19957() -> None: + value = 1 + values = [value] * 5 + foo = pl.DataFrame({"foo": values}) + foo_with_bar_from_literal = foo.with_columns(pl.lit(value).alias("bar")) + assert foo_with_bar_from_literal.gather_every(2).to_dict(as_series=False) == { + "foo": [1, 1, 1], + "bar": [1, 1, 1], + } + + +def test_scalar_len_20046() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + assert ( + df.lazy() + .select( + pl.col("a"), + pl.lit(1), + ) + .select(pl.len()) + .collect() + .item() + == 3 + ) + + q = pl.LazyFrame({"a": range(3)}).select( + pl.first("a"), + pl.col("a").alias("b"), + ) + + assert q.select(pl.len()).collect().item() == 3 + + +def test_scalar_identification_function_expr_in_binary() -> None: + x = pl.Series("x", [1, 2, 3]) + assert_frame_equal( + pl.select(x).with_columns(o=pl.col("x").null_count() > 0), + pl.select(x, o=False), + ) + + +def test_scalar_rechunk_20627() -> None: + df = pl.concat(2 * [pl.Series([1])]).filter(pl.Series([False, True])).to_frame() + assert df.rechunk().to_series().n_chunks() == 1 + + +def test_split_scalar_21581() -> None: + df = pl.DataFrame({"a": [1.0, 2.0, 3.0]}) + df = df.with_columns( + [ + pl.col("a").shift(-1).alias("next_a"), + pl.lit(True).alias("lit"), + ] + ) + + assert df.filter(df["next_a"] != 99.0).with_columns( + [pl.lit(False).alias("lit")] + ).to_dict(as_series=False) == { + "a": [1.0, 2.0], + "next_a": [2.0, 3.0], + "lit": [False, False], + } diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py new file mode 100644 index 000000000000..bdfc4bd21195 --- /dev/null +++ b/py-polars/tests/unit/test_schema.py @@ -0,0 +1,371 @@ +import pickle +from datetime import datetime +from typing import Any + +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_schema() -> None: + s = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + + assert s["foo"] == pl.Int8() + assert s["bar"] == pl.String() + assert s.len() == 2 + assert s.names() == ["foo", "bar"] + assert s.dtypes() == [pl.Int8(), pl.String()] + + with pytest.raises( + TypeError, + match="dtypes must be fully-specified, got: List", + ): + pl.Schema({"foo": pl.String, "bar": pl.List}) + + +@pytest.mark.parametrize( + "schema", + [ + pl.Schema(), + pl.Schema({"foo": pl.Int8()}), + pl.Schema({"foo": pl.Datetime("us"), "bar": pl.String()}), + pl.Schema( + { + "foo": pl.UInt32(), + "bar": pl.Categorical("physical"), + "baz": pl.Struct({"x": pl.Int64(), "y": pl.Float64()}), + } + ), + ], +) +def test_schema_empty_frame(schema: pl.Schema) -> None: + assert_frame_equal( + schema.to_frame(), + pl.DataFrame(schema=schema), + ) + + +def test_schema_equality() -> None: + s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()}) + s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()}) + + assert s1 == s1 + assert s2 == s2 + assert s3 == s3 + assert s1 != s2 + assert s1 != s3 + assert s2 != s3 + + s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")}) + s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")}) + s6 = {"foo": pl.Datetime, "bar": pl.Duration} + + assert s4 != s5 + assert s4 != s6 + + +def test_schema_parse_python_dtypes() -> None: + cardinal_directions = pl.Enum(["north", "south", "east", "west"]) + + s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type] + s["ham"] = datetime + + assert s["foo"] == pl.List(pl.Int32) + assert s["bar"] == pl.Int64 + assert s["baz"] == cardinal_directions + assert s["ham"] == pl.Datetime("us") + + assert s.len() == 4 + assert s.names() == ["foo", "bar", "baz", "ham"] + assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")] + + assert list(s.to_python().values()) == [list, int, str, datetime] + assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime] + + +def test_schema_picklable() -> None: + s = pl.Schema( + { + "foo": pl.Int8(), + "bar": pl.String(), + "ham": pl.Struct({"x": pl.List(pl.Date)}), + } + ) + pickled = pickle.dumps(s) + s2 = pickle.loads(pickled) + assert s == s2 + + +def test_schema_python() -> None: + input = { + "foo": pl.Int8(), + "bar": pl.String(), + "baz": pl.Categorical("lexical"), + "ham": pl.Object(), + "spam": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}), + } + expected = { + "foo": int, + "bar": str, + "baz": str, + "ham": object, + "spam": dict, + } + for schema in (input, input.items(), list(input.items())): + s = pl.Schema(schema) + assert expected == s.to_python() + + +def test_schema_in_map_elements_returns_scalar() -> None: + schema = pl.Schema([("portfolio", pl.String()), ("irr", pl.Float64())]) + + ldf = pl.LazyFrame( + { + "portfolio": ["A", "A", "B", "B"], + "amounts": [100.0, -110.0] * 2, + } + ) + q = ldf.group_by("portfolio").agg( + pl.col("amounts") + .map_elements( + lambda x: float(x.sum()), return_dtype=pl.Float64, returns_scalar=True + ) + .alias("irr") + ) + assert q.collect_schema() == schema + assert q.collect().schema == schema + + +def test_ir_cache_unique_18198() -> None: + lf = pl.LazyFrame({"a": [1]}) + lf.collect_schema() + assert pl.concat([lf, lf]).collect().to_dict(as_series=False) == {"a": [1, 1]} + + +def test_schema_functions_in_agg_with_literal_arg_19011() -> None: + q = ( + pl.LazyFrame({"a": [1, 2, 3, None, 5]}) + .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i") + .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2")) + ) + assert q.collect_schema() == pl.Schema( + [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] + ) + + +def test_lazy_explode_in_agg_schema_19562() -> None: + def new_df_check_schema( + value: dict[str, Any], schema: dict[str, Any] + ) -> pl.DataFrame: + df = pl.DataFrame(value) + assert df.schema == schema + return df + + lf = pl.LazyFrame({"a": [1], "b": [[1]]}) + + q = lf.group_by("a").agg(pl.col("b")) + schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))} + + assert q.collect_schema() == schema + assert_frame_equal( + q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema) + ) + + q = lf.group_by("a").agg(pl.col("b").explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema)) + + q = lf.group_by("a").agg(pl.col("b").explode().explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema)) + + # 2x nested + lf = pl.LazyFrame({"a": [1], "b": [[[1]]]}) + + q = lf.group_by("a").agg(pl.col("b")) + schema = { + "a": pl.Int64, + "b": pl.List(pl.List(pl.List(pl.Int64))), + } + + assert q.collect_schema() == schema + assert_frame_equal( + q.collect(), new_df_check_schema({"a": [1], "b": [[[[1]]]]}, schema) + ) + + q = lf.group_by("a").agg(pl.col("b").explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))} + + assert q.collect_schema() == schema + assert_frame_equal( + q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema) + ) + + q = lf.group_by("a").agg(pl.col("b").explode().explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema)) + + +def test_lazy_nested_function_expr_agg_schema() -> None: + q = ( + pl.LazyFrame({"k": [1, 1, 2]}) + .group_by(pl.first(), maintain_order=True) + .agg(o=pl.int_range(pl.len()).reverse() < 1) + ) + + assert q.collect_schema() == {"k": pl.Int64, "o": pl.List(pl.Boolean)} + assert_frame_equal( + q.collect(), pl.DataFrame({"k": [1, 2], "o": [[False, True], [True]]}) + ) + + +def test_lazy_agg_scalar_return_schema() -> None: + q = pl.LazyFrame({"k": [1]}).group_by("k").agg(pl.col("k").null_count().alias("o")) + + schema = {"k": pl.Int64, "o": pl.UInt32} + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema)) + + +def test_lazy_agg_nested_expr_schema() -> None: + q = ( + pl.LazyFrame({"k": [1]}) + .group_by("k") + .agg( + ( + ( + (pl.col("k").reverse().shuffle() + 1) + + pl.col("k").shuffle().reverse() + ) + .shuffle() + .reverse() + .sum() + * 0 + ).alias("o") + ) + ) + + schema = {"k": pl.Int64, "o": pl.Int64} + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema)) + + +def test_lazy_agg_lit_explode() -> None: + q = ( + pl.LazyFrame({"k": [1]}) + .group_by("k") + .agg(pl.lit(1, dtype=pl.Int64).explode().alias("o")) + ) + + schema = {"k": pl.Int64, "o": pl.List(pl.Int64)} + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema)) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "expr_op", [ + "approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or", + "bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis", + "last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max", + "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound", + "var" + ] +) # fmt: skip +@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")]) +def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None: + op = getattr(pl.Expr, expr_op) + + lf = pl.LazyFrame({"a": 1, "b": 1}) + + q = lf.group_by("a").agg(lhs.reverse().pipe(op)) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(lhs.shuffle().reverse().pipe(op)) + + assert q.collect_schema() == q.collect().collect_schema() + + +def test_lazy_agg_schema_after_elementwise_19984() -> None: + lf = pl.LazyFrame({"a": 1, "b": 1}) + + q = lf.group_by("a").agg(pl.col("b").first().fill_null(0)) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(pl.col("b").first().fill_null(0).fill_null(0)) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(pl.col("b").first() + 1) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(1 + pl.col("b").first()) + assert q.collect_schema() == q.collect().collect_schema() + + +@pytest.mark.parametrize( + "expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()] +) +@pytest.mark.parametrize("mapping_strategy", ["explode", "join", "group_to_rows"]) +def test_lazy_window_schema(expr: pl.Expr, mapping_strategy: str) -> None: + q = pl.LazyFrame({"a": 1, "b": 1}).select( + expr.over("a", mapping_strategy=mapping_strategy) # type: ignore[arg-type] + ) + + assert q.collect_schema() == q.collect().collect_schema() + + +def test_lazy_explode_schema() -> None: + lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.Array(pl.Int64, 1))}) + + q = lf.select(pl.col("x").explode()) + assert q.collect_schema() == {"x": pl.Int64} + + q = lf.select(pl.col("x").arr.explode()) + assert q.collect_schema() == {"x": pl.Int64} + + lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.List(pl.Int64))}) + + q = lf.select(pl.col("x").explode()) + assert q.collect_schema() == {"x": pl.Int64} + + q = lf.select(pl.col("x").list.explode()) + assert q.collect_schema() == {"x": pl.Int64} + + # `LazyFrame.explode()` goes through a different codepath than `Expr.expode` + lf = pl.LazyFrame().with_columns( + pl.Series([[1]], dtype=pl.List(pl.Int64)).alias("list"), + pl.Series([[1]], dtype=pl.Array(pl.Int64, 1)).alias("array"), + ) + + q = lf.explode("*") + assert q.collect_schema() == {"list": pl.Int64, "array": pl.Int64} + + q = lf.explode("list") + assert q.collect_schema() == {"list": pl.Int64, "array": pl.Array(pl.Int64, 1)} + + +def test_raise_subnodes_18787() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + + with pytest.raises(pl.exceptions.ColumnNotFoundError): + ( + df.select(pl.struct(pl.all())).select( + pl.first().struct.field("a", "b").filter(pl.col("foo") == 1) + ) + ) + + +def test_scalar_agg_schema_20044() -> None: + assert ( + pl.DataFrame(None, schema={"a": pl.Int64, "b": pl.String, "c": pl.String}) + .with_columns(d=pl.col("a").max()) + .group_by("c") + .agg(pl.col("d").mean()) + ).schema == pl.Schema([("c", pl.String), ("d", pl.Float64)]) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py new file mode 100644 index 000000000000..e18d5c51e4cc --- /dev/null +++ b/py-polars/tests/unit/test_selectors.py @@ -0,0 +1,898 @@ +from collections import OrderedDict +from datetime import datetime, timedelta +from decimal import Decimal as PyDecimal +from typing import Any +from zoneinfo import ZoneInfo + +import pytest + +import polars as pl +import polars.selectors as cs +from polars._typing import SelectorType +from polars._utils.various import qualified_type_name +from polars.exceptions import ColumnNotFoundError, InvalidOperationError +from polars.selectors import expand_selector, is_selector +from polars.testing import assert_frame_equal +from tests.unit.conftest import INTEGER_DTYPES, TEMPORAL_DTYPES + + +def assert_repr_equals(item: Any, expected: str) -> None: + """Assert that the repr of an item matches the expected string.""" + if not isinstance(expected, str): + msg = f"`expected` must be a string; found {qualified_type_name(expected)!r}" + raise TypeError(msg) + assert repr(item) == expected + + +@pytest.fixture +def df() -> pl.DataFrame: + # set up an empty dataframe with plenty of columns of various dtypes + df = pl.DataFrame( + schema={ + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + "eee": pl.Boolean, + "fgg": pl.Boolean, + "ghi": pl.Time, + "JJK": pl.Date, + "Lmn": pl.Duration, + "opp": pl.Datetime("ms"), + "qqR": pl.String, + }, + ) + return df + + +def test_selector_all(df: pl.DataFrame) -> None: + assert df.schema == df.select(cs.all()).schema + assert df.select(~cs.all()).schema == {} + assert df.schema == df.select(~(~cs.all())).schema + assert df.select(cs.all() & pl.col("abc")).schema == {"abc": pl.UInt16} + + +def test_selector_alpha() -> None: + df = pl.DataFrame( + schema=["Hello 123", "こんにちは (^_^)", "مرحبا", "你好!", "World"], + ) + # alphabetical-only (across all languages) + assert expand_selector(df, cs.alpha()) == ("مرحبا", "World") + assert expand_selector(df, cs.alpha(ascii_only=True)) == ("World",) + assert expand_selector(df, ~cs.alpha()) == ( + "Hello 123", + "こんにちは (^_^)", + "你好!", + ) + assert expand_selector(df, ~cs.alpha(ignore_spaces=True)) == ( + "Hello 123", + "こんにちは (^_^)", + "你好!", + ) + + # alphanumeric-only (across all languages) + assert expand_selector(df, cs.alphanumeric(True)) == ("World",) + assert expand_selector(df, ~cs.alphanumeric()) == ( + "Hello 123", + "こんにちは (^_^)", + "你好!", + ) + assert expand_selector(df, ~cs.alphanumeric(True, ignore_spaces=True)) == ( + "こんにちは (^_^)", + "مرحبا", + "你好!", + ) + assert expand_selector(df, cs.alphanumeric(ignore_spaces=True)) == ( + "Hello 123", + "مرحبا", + "World", + ) + assert expand_selector(df, ~cs.alphanumeric(ignore_spaces=True)) == ( + "こんにちは (^_^)", + "你好!", + ) + + +def test_selector_by_dtype(df: pl.DataFrame) -> None: + assert df.select(cs.boolean() | cs.by_dtype(pl.UInt16)).schema == OrderedDict( + { + "abc": pl.UInt16, + "eee": pl.Boolean, + "fgg": pl.Boolean, + } + ) + assert df.select( + ~cs.by_dtype(*INTEGER_DTYPES, *TEMPORAL_DTYPES) + ).schema == pl.Schema( + { + "cde": pl.Float64(), + "def": pl.Float32(), + "eee": pl.Boolean(), + "fgg": pl.Boolean(), + "qqR": pl.String(), + } + ) + assert df.select( + cs.by_dtype(pl.Datetime("ns"), pl.Float32, pl.UInt32, pl.Date) + ).schema == pl.Schema( + { + "bbb": pl.UInt32, + "def": pl.Float32, + "JJK": pl.Date, + } + ) + + # select using python types + assert df.select(cs.by_dtype(int, float)).schema == pl.Schema( + { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + } + ) + assert df.select(cs.by_dtype(bool, datetime, timedelta)).schema == pl.Schema( + { + "eee": pl.Boolean(), + "fgg": pl.Boolean(), + "Lmn": pl.Duration("us"), + "opp": pl.Datetime("ms"), + } + ) + + # cover timezones and decimal + dfx = pl.DataFrame( + {"idx": [], "dt1": [], "dt2": []}, + schema_overrides={ + "idx": pl.Decimal(24), + "dt1": pl.Datetime("ms"), + "dt2": pl.Datetime(time_zone="Asia/Tokyo"), + }, + ) + assert dfx.select(cs.by_dtype(PyDecimal)).schema == pl.Schema( + {"idx": pl.Decimal(24)}, + ) + assert dfx.select(cs.by_dtype(pl.Datetime(time_zone="*"))).schema == pl.Schema( + {"dt2": pl.Datetime(time_zone="Asia/Tokyo")} + ) + assert dfx.select(cs.by_dtype(pl.Datetime("ms", None))).schema == pl.Schema( + {"dt1": pl.Datetime("ms")}, + ) + for dt in (datetime, pl.Datetime): + assert dfx.select(cs.by_dtype(dt)).schema == pl.Schema( + {"dt1": pl.Datetime("ms"), "dt2": pl.Datetime(time_zone="Asia/Tokyo")}, + ) + + # empty selection selects nothing + assert df.select(cs.by_dtype()).schema == {} + assert df.select(cs.by_dtype([])).schema == {} + + # expected errors + with pytest.raises(TypeError): + df.select(cs.by_dtype(999)) # type: ignore[arg-type] + + +def test_selector_by_index(df: pl.DataFrame) -> None: + # one or more +ve indexes + assert df.select(cs.by_index(0)).columns == ["abc"] + assert df.select(pl.nth(0, 1, 2)).columns == ["abc", "bbb", "cde"] + assert df.select(cs.by_index(0, 1, 2)).columns == ["abc", "bbb", "cde"] + + # one or more -ve indexes + assert df.select(cs.by_index(-1)).columns == ["qqR"] + assert df.select(cs.by_index(-3, -2, -1)).columns == ["Lmn", "opp", "qqR"] + + # range objects + assert df.select(cs.by_index(range(3))).columns == ["abc", "bbb", "cde"] + assert df.select(cs.by_index(0, range(-3, 0))).columns == [ + "abc", + "Lmn", + "opp", + "qqR", + ] + + # exclude by index + assert df.select(~cs.by_index(range(0, df.width, 2))).columns == [ + "bbb", + "def", + "fgg", + "JJK", + "opp", + ] + + # expected errors + with pytest.raises(ColumnNotFoundError): + df.select(cs.by_index(999)) + + for invalid in ("one", ["two", "three"]): + with pytest.raises(TypeError): + df.select(cs.by_index(invalid)) # type: ignore[arg-type] + + +def test_selector_by_name(df: pl.DataFrame) -> None: + for selector in ( + cs.by_name("abc", "cde"), + cs.by_name("abc") | pl.col("cde"), + ): + assert df.select(selector).columns == ["abc", "cde"] + + assert df.select(~cs.by_name("abc", "cde", "ghi", "Lmn", "opp", "eee")).columns == [ + "bbb", + "def", + "fgg", + "JJK", + "qqR", + ] + assert df.select(cs.by_name()).columns == [] + assert df.select(cs.by_name([])).columns == [] + + selected_cols = df.select( + cs.by_name("???", "fgg", "!!!", require_all=False) + ).columns + assert selected_cols == ["fgg"] + + for missing_column in ("missing", "???"): + assert df.select(cs.by_name(missing_column, require_all=False)).columns == [] + + # check "by_name & col" + for selector_expr, expected in ( + (cs.by_name("abc", "cde") & pl.col("ghi"), []), + (cs.by_name("abc", "cde") & pl.col("cde"), ["cde"]), + (pl.col("cde") & cs.by_name("cde", "abc"), ["cde"]), + ): + assert df.select(selector_expr).columns == expected + + # check "by_name & by_name" + assert df.select( + cs.by_name("abc", "cde", "def", "eee") & cs.by_name("cde", "eee", "fgg") + ).columns == ["cde", "eee"] + + # expected errors + with pytest.raises(ColumnNotFoundError, match="xxx"): + df.select(cs.by_name("xxx", "fgg", "!!!")) + + with pytest.raises(ColumnNotFoundError): + df.select(cs.by_name("stroopwafel")) + + with pytest.raises(TypeError): + df.select(cs.by_name(999)) # type: ignore[arg-type] + + +def test_selector_contains(df: pl.DataFrame) -> None: + assert df.select(cs.contains("b")).columns == ["abc", "bbb"] + assert df.select(cs.contains(("e", "g"))).columns == [ # type: ignore[arg-type] + "cde", + "def", + "eee", + "fgg", + "ghi", + ] + assert df.select(~cs.contains("b", "e", "g")).columns == [ + "JJK", + "Lmn", + "opp", + "qqR", + ] + assert df.select(cs.contains("ee", "x")).columns == ["eee"] + + # expected errors + with pytest.raises(TypeError): + df.select(cs.contains(999)) # type: ignore[arg-type] + + +def test_selector_datetime(df: pl.DataFrame) -> None: + assert df.select(cs.datetime()).schema == {"opp": pl.Datetime("ms")} + assert df.select(cs.datetime("ns")).schema == {} + + all_columns = set(df.columns) + assert set(df.select(~cs.datetime()).columns) == all_columns - {"opp"} + + df = pl.DataFrame( + schema={ + "d1": pl.Datetime("ns", "Asia/Tokyo"), + "d2": pl.Datetime("ns", "UTC"), + "d3": pl.Datetime("us", "UTC"), + "d4": pl.Datetime("us"), + "d5": pl.Datetime("ms"), + }, + ) + assert df.select(cs.datetime()).columns == ["d1", "d2", "d3", "d4", "d5"] + assert df.select(~cs.datetime()).schema == {} + + assert df.select(cs.datetime(["ms", "ns"])).columns == ["d1", "d2", "d5"] + assert df.select(cs.datetime(["ms", "ns"], time_zone="*")).columns == ["d1", "d2"] + + assert df.select(~cs.datetime(["ms", "ns"])).columns == ["d3", "d4"] + assert df.select(~cs.datetime(["ms", "ns"], time_zone="*")).columns == [ + "d3", + "d4", + "d5", + ] + assert df.select( + cs.datetime(time_zone=["UTC", "Asia/Tokyo", "Europe/London"]) + ).columns == ["d1", "d2", "d3"] + + assert df.select(cs.datetime(time_zone="*")).columns == ["d1", "d2", "d3"] + assert df.select(cs.datetime("ns", time_zone="*")).columns == ["d1", "d2"] + assert df.select(cs.datetime(time_zone="UTC")).columns == ["d2", "d3"] + assert df.select(cs.datetime("us", time_zone="UTC")).columns == ["d3"] + assert df.select(cs.datetime(time_zone="Asia/Tokyo")).columns == ["d1"] + assert df.select(cs.datetime("us", time_zone="Asia/Tokyo")).columns == [] + assert df.select(cs.datetime(time_zone=None)).columns == ["d4", "d5"] + assert df.select(cs.datetime("ns", time_zone=None)).columns == [] + + assert df.select(~cs.datetime(time_zone="*")).columns == ["d4", "d5"] + assert df.select(~cs.datetime("ns", time_zone="*")).columns == ["d3", "d4", "d5"] + assert df.select(~cs.datetime(time_zone="UTC")).columns == ["d1", "d4", "d5"] + assert df.select(~cs.datetime("us", time_zone="UTC")).columns == [ + "d1", + "d2", + "d4", + "d5", + ] + assert df.select(~cs.datetime(time_zone="Asia/Tokyo")).columns == [ + "d2", + "d3", + "d4", + "d5", + ] + assert df.select(~cs.datetime("us", time_zone="Asia/Tokyo")).columns == [ + "d1", + "d2", + "d3", + "d4", + "d5", + ] + assert df.select(~cs.datetime(time_zone=None)).columns == ["d1", "d2", "d3"] + assert df.select(~cs.datetime("ns", time_zone=None)).columns == [ + "d1", + "d2", + "d3", + "d4", + "d5", + ] + assert df.select(cs.datetime("ns")).columns == ["d1", "d2"] + assert df.select(cs.datetime("us")).columns == ["d3", "d4"] + assert df.select(cs.datetime("ms")).columns == ["d5"] + + # bonus check; significantly more verbose, but equivalent to a selector - + assert ( + df.select( + pl.all().exclude( + pl.Datetime("ms", time_zone="*"), pl.Datetime("ns", time_zone="*") + ) + ).columns + == df.select(~cs.datetime(["ms", "ns"], time_zone="*")).columns + ) + + # expected errors + with pytest.raises(TypeError): + df.select(cs.datetime(999)) # type: ignore[arg-type] + + +def test_select_decimal(df: pl.DataFrame) -> None: + assert df.select(cs.decimal()).columns == [] + df = pl.DataFrame( + schema={ + "zz0": pl.Float64, + "zz1": pl.Decimal, + "zz2": pl.Decimal(10, 10), + } + ) + print(df.select(cs.numeric()).columns) + assert df.select(cs.numeric()).columns == ["zz0", "zz1", "zz2"] + assert df.select(cs.decimal()).columns == ["zz1", "zz2"] + assert df.select(~cs.decimal()).columns == ["zz0"] + + +def test_selector_digit() -> None: + df = pl.DataFrame(schema=["Portfolio", "Year", "2000", "2010", "2020", "✌️"]) + assert expand_selector(df, cs.digit()) == ("2000", "2010", "2020") + assert expand_selector(df, ~cs.digit()) == ("Portfolio", "Year", "✌️") + + df = pl.DataFrame({"১৯৯৯": [1999], "২০৭৭": [2077], "3000": [3000]}) + assert expand_selector(df, cs.digit()) == tuple(df.columns) + assert expand_selector(df, cs.digit(ascii_only=True)) == ("3000",) + assert expand_selector(df, (cs.digit() - cs.digit(True))) == ("১৯৯৯", "২০৭৭") + + +def test_selector_drop(df: pl.DataFrame) -> None: + dfd = df.drop(cs.numeric(), cs.temporal()) + assert dfd.columns == ["eee", "fgg", "qqR"] + + df = pl.DataFrame([["x"], [1]], schema={"foo": pl.String, "foo_right": pl.Int8}) + assert df.drop(cs.ends_with("_right")).schema == {"foo": pl.String()} + + +def test_selector_duration(df: pl.DataFrame) -> None: + assert df.select(cs.duration("ms")).columns == [] + assert df.select(cs.duration(["ms", "ns"])).columns == [] + assert expand_selector(df, cs.duration()) == ("Lmn",) + + df = pl.DataFrame( + schema={ + "d1": pl.Duration("ns"), + "d2": pl.Duration("us"), + "d3": pl.Duration("ms"), + }, + ) + assert expand_selector(df, cs.duration()) == ("d1", "d2", "d3") + assert expand_selector(df, cs.duration("us")) == ("d2",) + assert expand_selector(df, cs.duration(["ms", "ns"])) == ("d1", "d3") + + +def test_selector_ends_with(df: pl.DataFrame) -> None: + assert df.select(cs.ends_with("e")).columns == ["cde", "eee"] + assert df.select(cs.ends_with("ee")).columns == ["eee"] + assert df.select(cs.ends_with("e", "g", "i", "n", "p")).columns == [ + "cde", + "eee", + "fgg", + "ghi", + "Lmn", + "opp", + ] + assert df.select(~cs.ends_with("b", "e", "g", "i", "n", "p")).columns == [ + "abc", + "def", + "JJK", + "qqR", + ] + + # expected errors + with pytest.raises(TypeError): + df.select(cs.ends_with(999)) # type: ignore[arg-type] + + +def test_selector_expand() -> None: + schema = { + "id": pl.Int64, + "desc": pl.String, + "count": pl.UInt32, + "value": pl.Float64, + } + + expanded = cs.expand_selector(schema, cs.numeric() - cs.unsigned_integer()) + assert expanded == ("id", "value") + + with pytest.raises(TypeError, match="expected a selector"): + cs.expand_selector(schema, pl.exclude("id", "count")) + + with pytest.raises(TypeError, match="expected a selector"): + cs.expand_selector(schema, pl.col("value") // 100) + + expanded = cs.expand_selector(schema, pl.exclude("id", "count"), strict=False) + assert expanded == ("desc", "value") + + expanded = cs.expand_selector(schema, cs.numeric().exclude("id"), strict=False) + assert expanded == ("count", "value") + + +def test_selector_first_last(df: pl.DataFrame) -> None: + assert df.select(cs.first()).columns == ["abc"] + assert df.select(cs.last()).columns == ["qqR"] + + all_columns = set(df.columns) + assert set(df.select(~cs.first()).columns) == (all_columns - {"abc"}) + assert set(df.select(~cs.last()).columns) == (all_columns - {"qqR"}) + + +def test_selector_float(df: pl.DataFrame) -> None: + assert df.select(cs.float()).schema == { + "cde": pl.Float64, + "def": pl.Float32, + } + all_columns = set(df.columns) + assert set(df.select(~cs.float()).columns) == (all_columns - {"cde", "def"}) + + +def test_selector_integer(df: pl.DataFrame) -> None: + assert df.select(cs.integer()).schema == { + "abc": pl.UInt16, + "bbb": pl.UInt32, + } + all_columns = set(df.columns) + assert set(df.select(~cs.integer()).columns) == (all_columns - {"abc", "bbb"}) + + +def test_selector_matches(df: pl.DataFrame) -> None: + assert df.select(cs.matches(r"^(?i)[E-N]{3}$")).columns == [ + "eee", + "fgg", + "ghi", + "JJK", + "Lmn", + ] + assert df.select(~cs.matches(r"^(?i)[E-N]{3}$")).columns == [ + "abc", + "bbb", + "cde", + "def", + "opp", + "qqR", + ] + + +def test_selector_miscellaneous(df: pl.DataFrame) -> None: + assert df.select(cs.string()).columns == ["qqR"] + assert df.select(cs.categorical()).columns == [] + + test_schema = { + "abc": pl.String, + "mno": pl.Binary, + "tuv": pl.Object, + "xyz": pl.Categorical, + } + assert expand_selector(test_schema, cs.binary()) == ("mno",) + assert expand_selector(test_schema, ~cs.binary()) == ("abc", "tuv", "xyz") + assert expand_selector(test_schema, cs.object()) == ("tuv",) + assert expand_selector(test_schema, ~cs.object()) == ("abc", "mno", "xyz") + assert expand_selector(test_schema, cs.categorical()) == ("xyz",) + assert expand_selector(test_schema, ~cs.categorical()) == ("abc", "mno", "tuv") + + +def test_selector_numeric(df: pl.DataFrame) -> None: + assert df.select(cs.numeric()).schema == { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + } + assert df.select(cs.numeric().exclude(pl.UInt16)).schema == { + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + } + all_columns = set(df.columns) + assert set(df.select(~cs.numeric()).columns) == ( + all_columns - {"abc", "bbb", "cde", "def"} + ) + + +def test_selector_startswith(df: pl.DataFrame) -> None: + assert df.select(cs.starts_with("a")).columns == ["abc"] + assert df.select(cs.starts_with("ee")).columns == ["eee"] + assert df.select(cs.starts_with("d", "e", "f", "g", "h", "i", "j")).columns == [ + "def", + "eee", + "fgg", + "ghi", + ] + assert df.select(~cs.starts_with("d", "e", "f", "g", "h", "i", "j")).columns == [ + "abc", + "bbb", + "cde", + "JJK", + "Lmn", + "opp", + "qqR", + ] + # expected errors + with pytest.raises(TypeError): + df.select(cs.starts_with(999)) # type: ignore[arg-type] + + +def test_selector_temporal(df: pl.DataFrame) -> None: + assert df.select(cs.temporal()).schema == { + "ghi": pl.Time, + "JJK": pl.Date, + "Lmn": pl.Duration("us"), + "opp": pl.Datetime("ms"), + } + all_columns = set(df.columns) + assert set(df.select(~cs.temporal()).columns) == ( + all_columns - {"ghi", "JJK", "Lmn", "opp"} + ) + assert df.select(cs.time()).schema == {"ghi": pl.Time} + assert df.select(cs.date() | cs.time()).schema == {"ghi": pl.Time, "JJK": pl.Date} + + +def test_selector_temporal_13665() -> None: + df = pl.DataFrame( + data={"utc": [datetime(1950, 7, 5), datetime(2099, 12, 31)]}, + schema={"utc": pl.Datetime(time_zone="UTC")}, + ).with_columns( + idx=pl.int_range(0, 2), + utc=pl.col("utc").dt.replace_time_zone(None), + tokyo=pl.col("utc").dt.convert_time_zone("Asia/Tokyo"), + hawaii=pl.col("utc").dt.convert_time_zone("US/Hawaii"), + ) + for selector in (cs.datetime(), cs.datetime("us"), cs.temporal()): + assert df.select(selector).to_dict(as_series=False) == { + "utc": [ + datetime(1950, 7, 5, 0, 0), + datetime(2099, 12, 31, 0, 0), + ], + "tokyo": [ + datetime(1950, 7, 5, 10, 0, tzinfo=ZoneInfo(key="Asia/Tokyo")), + datetime(2099, 12, 31, 9, 0, tzinfo=ZoneInfo(key="Asia/Tokyo")), + ], + "hawaii": [ + datetime(1950, 7, 4, 14, 0, tzinfo=ZoneInfo(key="US/Hawaii")), + datetime(2099, 12, 30, 14, 0, tzinfo=ZoneInfo(key="US/Hawaii")), + ], + } + + +def test_selector_expansion() -> None: + df = pl.DataFrame({name: [] for name in "abcde"}) + + s1 = pl.all().meta._as_selector() + s2 = pl.col(["a", "b"]) + s = s1.meta._selector_sub(s2) + assert df.select(s).columns == ["c", "d", "e"] + + s1 = pl.col("^a|b$").meta._as_selector() + s = s1.meta._selector_add(pl.col(["d", "e"])) + assert df.select(s).columns == ["a", "b", "d", "e"] + + s = s.meta._selector_sub(pl.col("d")) + assert df.select(s).columns == ["a", "b", "e"] + + # add a duplicate, this tests if they are pruned + s = s.meta._selector_add(pl.col("a")) + assert df.select(s).columns == ["a", "b", "e"] + + s1 = pl.col(["a", "b", "c"]) + s2 = pl.col(["b", "c", "d"]) + + s = s1.meta._as_selector() + s = s.meta._selector_and(s2) + assert df.select(s).columns == ["b", "c"] + + +def test_selector_repr() -> None: + assert_repr_equals(cs.all() - cs.first(), "(cs.all() - cs.first())") + assert_repr_equals(~cs.starts_with("a", "b"), "~cs.starts_with('a', 'b')") + assert_repr_equals(cs.float() | cs.by_name("x"), "(cs.float() | cs.by_name('x'))") + assert_repr_equals( + cs.integer() & cs.matches("z"), + "(cs.integer() & cs.matches(pattern='z'))", + ) + assert_repr_equals( + cs.by_name("baz", "moose", "foo", "bear"), + "cs.by_name('baz', 'moose', 'foo', 'bear')", + ) + assert_repr_equals( + cs.by_name("baz", "moose", "foo", "bear", require_all=False), + "cs.by_name('baz', 'moose', 'foo', 'bear', require_all=False)", + ) + assert_repr_equals( + cs.temporal() | cs.by_dtype(pl.String) & cs.string(include_categorical=False), + "(cs.temporal() | (cs.by_dtype(dtypes=[String]) & cs.string(include_categorical=False)))", + ) + + +def test_selector_sets(df: pl.DataFrame) -> None: + # or + assert df.select( + cs.temporal() | cs.string() | cs.starts_with("e") + ).schema == OrderedDict( + { + "eee": pl.Boolean, + "ghi": pl.Time, + "JJK": pl.Date, + "Lmn": pl.Duration("us"), + "opp": pl.Datetime("ms"), + "qqR": pl.String, + } + ) + + # and + assert df.select(cs.temporal() & cs.matches("opp|JJK")).schema == OrderedDict( + { + "JJK": pl.Date, + "opp": pl.Datetime("ms"), + } + ) + + # SET A - SET B + assert df.select(cs.temporal() - cs.matches("opp|JJK")).schema == OrderedDict( + { + "ghi": pl.Time, + "Lmn": pl.Duration("us"), + } + ) + + # equivalent (though more verbose) to the above, using `exclude` + assert df.select( + cs.exclude(~cs.temporal() | cs.matches("opp|JJK")) + ).schema == OrderedDict( + { + "ghi": pl.Time, + "Lmn": pl.Duration("us"), + } + ) + + frame = pl.DataFrame({"colx": [0, 1, 2], "coly": [3, 4, 5], "colz": [6, 7, 8]}) + sub_expr = cs.matches("[yz]$") - pl.col("colx") # << shouldn't behave as set + assert frame.select(sub_expr).rows() == [(3, 6), (3, 6), (3, 6)] + + with pytest.raises(TypeError, match=r"unsupported .* \('Expr' - 'Selector'\)"): + df.select(pl.col("colx") - cs.matches("[yz]$")) + + with pytest.raises(TypeError, match=r"unsupported .* \('Expr' \+ 'Selector'\)"): + df.select(pl.col("colx") + cs.numeric()) + + with pytest.raises(TypeError, match=r"unsupported .* \('Selector' \+ 'Selector'\)"): + df.select(cs.string() + cs.numeric()) + + # complement + assert df.select(~cs.by_dtype([pl.Duration, pl.Time])).schema == { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + "eee": pl.Boolean, + "fgg": pl.Boolean, + "JJK": pl.Date, + "opp": pl.Datetime("ms"), + "qqR": pl.String, + } + + # exclusive or + for selected in ( + df.select((cs.matches("e|g")) ^ cs.numeric()), + df.select((cs.contains("b", "g")) ^ pl.col("eee")), + ): + assert selected.schema == OrderedDict( + { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "eee": pl.Boolean, + "fgg": pl.Boolean, + "ghi": pl.Time, + } + ) + + +def test_selector_dispatch_default_operator() -> None: + df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "abc": [3, 3]}) + out = df.select((cs.numeric() & ~cs.by_name("abc")) + 1) + expected = pl.DataFrame( + { + "a": [2, 2], + "b": [3, 3], + } + ) + assert_frame_equal(out, expected) + + +def test_selector_expr_dispatch() -> None: + df = pl.DataFrame( + data={ + "colx": [float("inf"), -1, float("nan"), 25], + "coly": [1, float("-inf"), 10, float("nan")], + }, + schema={"colx": pl.Float64, "coly": pl.Float32}, + ) + expected = pl.DataFrame( + data={ + "colx": [0.0, -1.0, 0.0, 25.0], + "coly": [1.0, 0.0, 10.0, 0.0], + }, + schema={"colx": pl.Float64, "coly": pl.Float32}, + ) + + # basic selector-broadcast expression + assert_frame_equal( + expected, + df.with_columns( + pl.when(cs.float().is_finite()).then(cs.float()).otherwise(0.0).name.keep() + ), + ) + + # inverted selector-broadcast expression + assert_frame_equal( + expected, + df.with_columns( + pl.when(~cs.float().is_finite()).then(0.0).otherwise(cs.float()).name.keep() + ), + ) + + +def test_regex_expansion_group_by_9947() -> None: + df = pl.DataFrame({"g": [3], "abc": [1], "abcd": [3]}) + assert df.group_by("g").agg(pl.col("^ab.*$")).columns == ["g", "abc", "abcd"] + + +def test_regex_expansion_exclude_10002() -> None: + df = pl.DataFrame({"col_1": [1, 2, 3], "col_2": [2, 4, 3]}) + expected = pl.DataFrame({"col_1": [10, 20, 30], "col_2": [0.2, 0.4, 0.3]}) + + assert_frame_equal( + df.select( + pl.col("^col_.*$").exclude("col_2").mul(10), + pl.col("^col_.*$").exclude("col_1") / 10, + ), + expected, + ) + + +def test_is_selector() -> None: + # only actual/compound selectors should pass this check + assert is_selector(cs.numeric()) + assert is_selector(cs.by_dtype(pl.UInt32) | pl.col("xyz")) + + # expressions (and literals, etc) should fail + assert not is_selector(pl.col("xyz")) + assert not is_selector(cs.numeric().name.suffix(":num")) + assert not is_selector(cs.date() + pl.col("time")) + assert not is_selector(None) + assert not is_selector("x") + + schema = {"x": pl.Int64, "y": pl.Float64} + with pytest.raises(TypeError): + expand_selector(schema, 999) + + with pytest.raises(TypeError): + expand_selector(schema, "colname") + + +def test_selector_or() -> None: + df = pl.DataFrame( + { + "int": [1, 2, 3], + "float": [1.0, 2.0, 3.0], + "str": ["x", "y", "z"], + } + ).with_row_index("idx") + + result = df.select(cs.by_name("idx") | ~cs.numeric()) + + expected = pl.DataFrame( + {"idx": [0, 1, 2], "str": ["x", "y", "z"]}, + schema_overrides={"idx": pl.UInt32}, + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "selector", + [ + (cs.string() | cs.numeric()), + (cs.numeric() | cs.string()), + ~(~cs.numeric() & ~cs.string()), + ~(~cs.string() & ~cs.numeric()), + (cs.signed_integer() ^ cs.contains("b", "e", "q")) - cs.starts_with("e"), + ], +) +def test_selector_result_order(df: pl.DataFrame, selector: SelectorType) -> None: + # ensure that selector results always match schema column-order + assert df.select(selector).schema == OrderedDict( + { + "abc": pl.UInt16, + "bbb": pl.UInt32, + "cde": pl.Float64, + "def": pl.Float32, + "qqR": pl.String, + } + ) + + +def test_selector_list_of_lists_18499() -> None: + lf = pl.DataFrame( + { + "foo": [1, 2, 3, 1], + "bar": ["a", "a", "a", "a"], + "ham": ["b", "b", "b", "b"], + } + ) + + with pytest.raises(InvalidOperationError, match="invalid selector expression"): + lf.unique(subset=[["bar", "ham"]]) # type: ignore[list-item] + + +def test_selector_python_dtypes() -> None: + df = pl.DataFrame( + { + "int": [1, 2, 3], + "float": [1.0, 2.0, 3.0], + "bool": [True, False, True], + "str": ["x", "y", "z"], + } + ) + assert df.select(cs.by_dtype(int)).columns == ["int"] + assert df.select(cs.by_dtype(float)).columns == ["float"] + assert df.select(cs.by_dtype(bool)).columns == ["bool"] + assert df.select(cs.by_dtype(str)).columns == ["str"] diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py new file mode 100644 index 000000000000..d15df29f51c9 --- /dev/null +++ b/py-polars/tests/unit/test_serde.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import io +import pickle +from datetime import datetime, timedelta + +import pytest + +import polars as pl +from polars import StringCache +from polars.exceptions import SchemaError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_pickling_simple_expression() -> None: + e = pl.col("foo").sum() + buf = pickle.dumps(e) + assert str(pickle.loads(buf)) == str(e) + + +def test_pickling_as_struct_11100() -> None: + e = pl.struct("a") + buf = pickle.dumps(e) + assert str(pickle.loads(buf)) == str(e) + + +def test_serde_time_unit() -> None: + values = [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)] + s = pl.Series(values).cast(pl.Datetime("ns")) + result = pickle.loads(pickle.dumps(s)) + assert result.dtype == pl.Datetime("ns") + + +def test_serde_duration() -> None: + df = ( + pl.DataFrame( + { + "a": [ + datetime(2021, 2, 1, 9, 20), + datetime(2021, 2, 2, 9, 20), + ], + "b": [4, 5], + } + ) + .with_columns([pl.col("a").cast(pl.Datetime("ns")).alias("a")]) + .select(pl.all()) + ) + df = df.with_columns([pl.col("a").diff(n=1).alias("a_td")]) + serde_df = pickle.loads(pickle.dumps(df)) + assert serde_df["a_td"].dtype == pl.Duration("ns") + assert_series_equal( + serde_df["a_td"], + pl.Series("a_td", [None, timedelta(days=1)], dtype=pl.Duration("ns")), + ) + + +def test_serde_expression_5461() -> None: + e = pl.col("a").sqrt() / pl.col("b").alias("c") + assert pickle.loads(pickle.dumps(e)).meta == e.meta + + +def test_serde_binary() -> None: + data = pl.Series( + "binary_data", + [ + b"\xba\x9b\xca\xd3y\xcb\xc9#", + b"9\x04\xab\xe2\x11\xf3\x85", + b"\xb8\xcb\xc9^\\\xa9-\x94\xe0H\x9d ", + b"S\xbc:\xcb\xf0\xf5r\xfe\x18\xfeH", + b",\xf5)y\x00\xe5\xf7", + b"\xfd\xf6\xf1\xc2X\x0cn\xb9#", + b"\x06\xef\xa6\xa2\xb7", + b"@\xff\x95\xda\xff\xd2\x18", + ], + ) + assert_series_equal( + data, + pickle.loads(pickle.dumps(data)), + ) + + +def test_pickle_lazyframe() -> None: + q = pl.LazyFrame({"a": [1, 4, 3]}).sort("a") + + s = pickle.dumps(q) + assert_frame_equal(pickle.loads(s).collect(), pl.DataFrame({"a": [1, 3, 4]})) + + +def test_deser_empty_list() -> None: + s = pickle.loads(pickle.dumps(pl.Series([[[42.0]], []]))) + assert s.dtype == pl.List(pl.List(pl.Float64)) + assert s.to_list() == [[[42.0]], []] + + +def times2(x: pl.Series) -> pl.Series: + return x * 2 + + +def test_pickle_udf_expression() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + e = pl.col("a").map_batches(times2) + b = pickle.dumps(e) + e = pickle.loads(b) + + result = df.select(e) + expected = pl.DataFrame({"a": [2, 4, 6]}) + assert_frame_equal(result, expected) + + e = pl.col("a").map_batches(times2, return_dtype=pl.String) + b = pickle.dumps(e) + e = pickle.loads(b) + + # tests that 'GetOutput' is also deserialized + with pytest.raises( + SchemaError, + match=r"expected output type 'String', got 'Int64'; set `return_dtype` to the proper datatype", + ): + df.select(e) + + +def test_pickle_small_integers() -> None: + df = pl.DataFrame( + [ + pl.Series([1, 2], dtype=pl.Int16), + pl.Series([3, 2], dtype=pl.Int8), + pl.Series([32, 2], dtype=pl.UInt8), + pl.Series([3, 3], dtype=pl.UInt16), + ] + ) + b = pickle.dumps(df) + assert_frame_equal(pickle.loads(b), df) + + +def df_times2(df: pl.DataFrame) -> pl.DataFrame: + return df.select(pl.all() * 2) + + +def test_pickle_lazyframe_udf() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + q = df.lazy().map_batches(df_times2) + b = pickle.dumps(q) + + q = pickle.loads(b) + assert q.collect()["a"].to_list() == [2, 4, 6] + + +def test_pickle_lazyframe_nested_function_udf() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + # NOTE: This is only possible when we're using cloudpickle. + def inner_df_times2(df: pl.DataFrame) -> pl.DataFrame: + return df.select(pl.all() * 2) + + q = df.lazy().map_batches(inner_df_times2) + b = pickle.dumps(q) + + q = pickle.loads(b) + assert q.collect()["a"].to_list() == [2, 4, 6] + + +@StringCache() +def test_serde_categorical_series_10586() -> None: + s = pl.Series(["a", "b", "b", "a", "c"], dtype=pl.Categorical) + loaded_s = pickle.loads(pickle.dumps(s)) + assert_series_equal(loaded_s, s) + + +def test_serde_keep_dtype_empty_list() -> None: + s = pl.Series([{"a": None}], dtype=pl.Struct([pl.Field("a", pl.List(pl.String))])) + assert s.dtype == pickle.loads(pickle.dumps(s)).dtype + + +def test_serde_array_dtype() -> None: + s = pl.Series( + [[1, 2, 3], [None, None, None], [1, None, 3]], + dtype=pl.Array(pl.Int32(), 3), + ) + assert_series_equal(pickle.loads(pickle.dumps(s)), s) + + nested_s = pl.Series( + [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]], + dtype=pl.List(pl.Array(pl.Int32(), 3)), + ) + assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s) + + +def test_serde_data_type_class() -> None: + dtype = pl.Datetime + serialized = pickle.dumps(dtype) + deserialized = pickle.loads(serialized) + assert deserialized == dtype + assert isinstance(deserialized, type) + + +def test_serde_data_type_instantiated() -> None: + dtype = pl.Int8() + serialized = pickle.dumps(dtype) + deserialized = pickle.loads(serialized) + assert deserialized == dtype + assert isinstance(deserialized, pl.DataType) + + +def test_serde_data_type_instantiated_with_attributes() -> None: + dtype = pl.Enum(["a", "b"]) + serialized = pickle.dumps(dtype) + deserialized = pickle.loads(serialized) + assert deserialized == dtype + assert isinstance(deserialized, pl.DataType) + + +def test_serde_udf() -> None: + lf = pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select( + pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int32) + ) + result = pl.LazyFrame.deserialize(io.BytesIO(lf.serialize())) + + assert_frame_equal(lf, result) + + +def test_serde_empty_df_lazy_frame() -> None: + lf = pl.LazyFrame() + f = io.BytesIO() + f.write(lf.serialize()) + f.seek(0) + assert pl.LazyFrame.deserialize(f).collect().shape == (0, 0) + + +def test_pickle_class_objects_21021() -> None: + assert isinstance(pickle.loads(pickle.dumps(pl.col))("A"), pl.Expr) + assert isinstance(pickle.loads(pickle.dumps(pl.DataFrame))(), pl.DataFrame) + assert isinstance(pickle.loads(pickle.dumps(pl.LazyFrame))(), pl.LazyFrame) diff --git a/py-polars/tests/unit/test_show_graph.py b/py-polars/tests/unit/test_show_graph.py new file mode 100644 index 000000000000..f46d135e0792 --- /dev/null +++ b/py-polars/tests/unit/test_show_graph.py @@ -0,0 +1,15 @@ +import polars as pl + + +def test_show_graph() -> None: + # only test raw output, otherwise we need graphviz and matplotlib + ldf = pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + query = ldf.group_by("a", maintain_order=True).agg(pl.all().sum()).sort("a") + out = query.show_graph(raw_output=True) + assert isinstance(out, str) diff --git a/py-polars/tests/unit/test_simplify.py b/py-polars/tests/unit/test_simplify.py new file mode 100644 index 000000000000..e7bc7ec819e3 --- /dev/null +++ b/py-polars/tests/unit/test_simplify.py @@ -0,0 +1,10 @@ +import polars as pl + + +def test_flatten_alias() -> None: + assert ( + """len().alias("bar")""" + in pl.LazyFrame({"a": [1, 2]}) + .select(pl.len().alias("foo").alias("bar")) + .explain() + ) diff --git a/py-polars/tests/unit/test_single.py b/py-polars/tests/unit/test_single.py new file mode 100644 index 000000000000..bc509abe449a --- /dev/null +++ b/py-polars/tests/unit/test_single.py @@ -0,0 +1,26 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_single_row_literal_ambiguity_8481() -> None: + df = pl.DataFrame( + { + "store_id": [1], + "cost_price": [2.0], + } + ) + + inverse_cost_price = 1.0 / pl.col("cost_price") + result = df.with_columns( + (inverse_cost_price / inverse_cost_price.sum()).over("store_id").alias("result") + ) + # exceptions.ComputeError: cannot aggregate a literal + + expected = pl.DataFrame( + { + "store_id": [1], + "cost_price": [2.0], + "result": [1.0], + } + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/test_string_cache.py b/py-polars/tests/unit/test_string_cache.py new file mode 100644 index 000000000000..3fa20730b57e --- /dev/null +++ b/py-polars/tests/unit/test_string_cache.py @@ -0,0 +1,185 @@ +from collections.abc import Iterator + +import pytest + +import polars as pl +from polars.exceptions import CategoricalRemappingWarning +from polars.testing import assert_frame_equal + + +@pytest.fixture(autouse=True) +def _disable_string_cache() -> Iterator[None]: + """Fixture to make sure the string cache is disabled before and after each test.""" + pl.disable_string_cache() + yield + pl.disable_string_cache() + + +def sc(set: bool) -> None: + """Short syntax for asserting whether the global string cache is being used.""" + assert pl.using_string_cache() is set + + +def test_string_cache_enable_disable() -> None: + sc(False) + pl.enable_string_cache() + sc(True) + pl.disable_string_cache() + sc(False) + + +def test_string_cache_enable_disable_repeated() -> None: + sc(False) + pl.enable_string_cache() + sc(True) + pl.enable_string_cache() + sc(True) + pl.disable_string_cache() + sc(False) + pl.disable_string_cache() + sc(False) + + +def test_string_cache_context_manager() -> None: + sc(False) + with pl.StringCache(): + sc(True) + sc(False) + + +def test_string_cache_context_manager_nested() -> None: + sc(False) + with pl.StringCache(): + sc(True) + with pl.StringCache(): + sc(True) + sc(True) + sc(False) + + +def test_string_cache_context_manager_mixed_with_enable_disable() -> None: + sc(False) + with pl.StringCache(): + sc(True) + pl.enable_string_cache() + sc(True) + sc(True) + + with pl.StringCache(): + sc(True) + sc(True) + + with pl.StringCache(): + sc(True) + with pl.StringCache(): + sc(True) + pl.disable_string_cache() + sc(True) + sc(True) + sc(False) + + with pl.StringCache(): + sc(True) + pl.disable_string_cache() + sc(True) + sc(False) + + +def test_string_cache_decorator() -> None: + @pl.StringCache() + def my_function() -> None: + sc(True) + + sc(False) + my_function() + sc(False) + + +def test_string_cache_decorator_mixed_with_enable() -> None: + @pl.StringCache() + def my_function() -> None: + sc(True) + pl.enable_string_cache() + sc(True) + + sc(False) + my_function() + sc(True) + + +@pytest.mark.may_fail_auto_streaming +def test_string_cache_join() -> None: + df1 = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) + df2 = pl.DataFrame({"a": ["eggs", "spam", "foo"], "c": [2, 2, 3]}) + + # ensure cache is off when casting to categorical; the join will fail + pl.disable_string_cache() + assert pl.using_string_cache() is False + + with pytest.warns( + CategoricalRemappingWarning, + match="Local categoricals have different encodings", + ): + df1a = df1.with_columns(pl.col("a").cast(pl.Categorical)) + df2a = df2.with_columns(pl.col("a").cast(pl.Categorical)) + out = df1a.join(df2a, on="a", how="inner") + + expected = pl.DataFrame( + {"a": ["foo"], "b": [1], "c": [3]}, schema_overrides={"a": pl.Categorical} + ) + + # Can not do equality checks on local categoricals with different categories + assert_frame_equal(out, expected, categorical_as_str=True) + + # now turn on the cache + pl.enable_string_cache() + assert pl.using_string_cache() is True + + df1b = df1.with_columns(pl.col("a").cast(pl.Categorical)) + df2b = df2.with_columns(pl.col("a").cast(pl.Categorical)) + out = df1b.join(df2b, on="a", how="inner") + + expected = pl.DataFrame( + {"a": ["foo"], "b": [1], "c": [3]}, schema_overrides={"a": pl.Categorical} + ) + assert_frame_equal(out, expected) + + +def test_string_cache_eager_lazy() -> None: + # tests if the global string cache is really global and not interfered by the lazy + # execution. first the global settings was thread-local and this breaks with the + # parallel execution of lazy + with pl.StringCache(): + df1 = pl.DataFrame( + {"region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"]} + ).select([pl.col("region_ids").cast(pl.Categorical)]) + + df2 = pl.DataFrame( + {"seq_name": ["reg4", "reg2", "reg1"], "score": [3.0, 1.0, 2.0]} + ).select([pl.col("seq_name").cast(pl.Categorical), pl.col("score")]) + + expected = pl.DataFrame( + { + "region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"], + "score": [2.0, 1.0, None, 3.0, None], + } + ).with_columns(pl.col("region_ids").cast(pl.Categorical)) + + result = df1.join(df2, left_on="region_ids", right_on="seq_name", how="left") + assert_frame_equal(result, expected, check_row_order=False) + + # also check row-wise categorical insert. + # (column-wise is preferred, but this shouldn't fail) + for params in ( + {"schema": [("region_ids", pl.Categorical)]}, + { + "schema": ["region_ids"], + "schema_overrides": {"region_ids": pl.Categorical}, + }, + ): + df3 = pl.DataFrame( + data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]], + orient="row", + **params, + ) + assert_frame_equal(df1, df3) diff --git a/py-polars/tests/unit/testing/__init__.py b/py-polars/tests/unit/testing/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py new file mode 100644 index 000000000000..8edf018e6f21 --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -0,0 +1,309 @@ +from datetime import datetime +from typing import Any + +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings +from hypothesis.errors import InvalidArgument + +import polars as pl +from polars.testing.parametric import ( + column, + dataframes, + dtypes, + lists, + series, +) + +TEMPORAL_DTYPES = {pl.Date, pl.Time, pl.Datetime, pl.Duration} + + +@given(s=series()) +@settings(max_examples=5) +def test_series_defaults(s: pl.Series) -> None: + assert isinstance(s, pl.Series) + assert s.name == "" + + +@given(s=series(name="hello")) +@settings(max_examples=5) +def test_series_name(s: pl.Series) -> None: + assert s.name == "hello" + + +@given(st.data()) +def test_series_dtype(data: st.DataObject) -> None: + dtype = data.draw(dtypes()) + s = data.draw(series(dtype=dtype)) + assert s.dtype == dtype + + +@given(s=series(dtype=pl.Enum, allow_null=False)) +@settings(max_examples=5) +def test_series_dtype_enum(s: pl.Series) -> None: + assert isinstance(s.dtype, pl.Enum) + assert all(v in s.dtype.categories for v in s) + + +@given(s=series(dtype=pl.Boolean, min_size=5, max_size=5)) +@settings(max_examples=5) +def test_series_size(s: pl.Series) -> None: + assert s.len() == 5 + + +@given(s=series(min_size=3, max_size=8)) +@settings(max_examples=5) +def test_series_size_range(s: pl.Series) -> None: + assert 3 <= s.len() <= 8 + + +@given(s=series(allow_null=False)) +def test_series_allow_null_false(s: pl.Series) -> None: + assert not s.has_nulls() + assert s.dtype != pl.Null + + +@given(s=series(allowed_dtypes=[pl.Null], allow_null=False)) +def test_series_allow_null_allowed_dtypes(s: pl.Series) -> None: + assert s.dtype == pl.Null + + +@given(s=series(allowed_dtypes=[pl.List(pl.Int8)], allow_null=False)) +def test_series_allow_null_nested(s: pl.Series) -> None: + for v in s: + assert not v.has_nulls() + + +@given(df=dataframes()) +@settings(max_examples=5) +def test_dataframes_defaults(df: pl.DataFrame) -> None: + assert isinstance(df, pl.DataFrame) + assert df.columns == [f"col{i}" for i in range(df.width)] + + +@given(lf=dataframes(lazy=True)) +@settings(max_examples=5) +def test_dataframes_lazy(lf: pl.LazyFrame) -> None: + assert isinstance(lf, pl.LazyFrame) + + +@given(df=dataframes(cols=3, min_size=5, max_size=5)) +@settings(max_examples=5) +def test_dataframes_size(df: pl.DataFrame) -> None: + assert df.height == 5 + assert df.width == 3 + + +@given(df=dataframes(min_cols=2, max_cols=5, min_size=3, max_size=8)) +@settings(max_examples=5) +def test_dataframes_size_range(df: pl.DataFrame) -> None: + assert 3 <= df.height <= 8 + assert 2 <= df.width <= 5 + + +@given(df=dataframes(cols=1, allow_null=True)) +@settings(max_examples=5) +def test_dataframes_allow_null_global(df: pl.DataFrame) -> None: + null_count = sum(df.null_count().row(0)) + assert 0 <= null_count <= df.height * df.width + + +@given(df=dataframes(cols=2, allow_null={"col0": True})) +@settings(max_examples=5) +def test_dataframes_allow_null_column(df: pl.DataFrame) -> None: + null_count = sum(df.null_count().row(0)) + assert 0 <= null_count <= df.height * df.width + + +@given( + df=dataframes( + cols=1, + allow_null=False, + include_cols=[column(name="colx", allow_null=True)], + ) +) +def test_dataframes_allow_null_override(df: pl.DataFrame) -> None: + assert not df.get_column("col0").has_nulls() + assert 0 <= df.get_column("colx").null_count() <= df.height + + +@given( + lf=dataframes( + # generate lazyframes with at least one row + lazy=True, + min_size=1, + allow_null=False, + # test mix & match of bulk-assigned cols with custom cols + cols=[column(n, dtype=pl.UInt8, unique=True) for n in ["a", "b"]], + include_cols=[ + column("c", dtype=pl.Boolean), + column("d", strategy=st.sampled_from(["x", "y", "z"])), + ], + ) +) +def test_dataframes_columns(lf: pl.LazyFrame) -> None: + assert lf.collect_schema() == { + "a": pl.UInt8, + "b": pl.UInt8, + "c": pl.Boolean, + "d": pl.String, + } + assert lf.collect_schema().names() == ["a", "b", "c", "d"] + df = lf.collect() + + # confirm uint cols bounds + uint8_max = (2**8) - 1 + assert df["a"].min() >= 0 # type: ignore[operator] + assert df["b"].min() >= 0 # type: ignore[operator] + assert df["a"].max() <= uint8_max # type: ignore[operator] + assert df["b"].max() <= uint8_max # type: ignore[operator] + + # confirm uint cols uniqueness + assert df["a"].is_unique().all() + assert df["b"].is_unique().all() + + # boolean col + assert all(isinstance(v, bool) for v in df["c"].to_list()) + + # string col, entries selected from custom values + xyz = {"x", "y", "z"} + assert all(v in xyz for v in df["d"].to_list()) + + +@pytest.mark.hypothesis +def test_column_invalid_probability() -> None: + with pytest.deprecated_call(), pytest.raises(InvalidArgument): + column("col", null_probability=2.0) + + +@pytest.mark.hypothesis +def test_column_null_probability_deprecated() -> None: + with pytest.deprecated_call(): + col = column("col", allow_null=False, null_probability=0.5) + assert col.null_probability == 0.5 + assert col.allow_null is True # null_probability takes precedence + + +@given(st.data()) +def test_allow_infinities_deprecated(data: st.DataObject) -> None: + with pytest.deprecated_call(): + strategy = series(dtype=pl.Float64, allow_infinities=False) + s = data.draw(strategy) + + assert all(v not in (float("inf"), float("-inf")) for v in s) + + +@given( + df=dataframes( + cols=[ + column("colx", dtype=pl.Array(pl.UInt8, shape=3)), + column("coly", dtype=pl.List(pl.Datetime("ms"))), + column( + name="colz", + dtype=pl.List(pl.List(pl.String)), + strategy=lists( + inner_dtype=pl.List(pl.String), + select_from=["aa", "bb", "cc"], + min_size=1, + ), + ), + ], + allow_null=False, + ), +) +def test_dataframes_nested_strategies(df: pl.DataFrame) -> None: + assert df.schema == { + "colx": pl.Array(pl.UInt8, shape=3), + "coly": pl.List(pl.Datetime("ms")), + "colz": pl.List(pl.List(pl.String)), + } + uint8_max = (2**8) - 1 + + for colx, coly, colz in df.iter_rows(): + assert len(colx) == 3 + assert all(i <= uint8_max for i in colx) + assert all(isinstance(d, datetime) for d in coly) + for inner_list in colz: + assert all(s in ("aa", "bb", "cc") for s in inner_list) + + +@given( + df=dataframes(allowed_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5), + lf=dataframes(excluded_dtypes=TEMPORAL_DTYPES, max_size=1, max_cols=5, lazy=True), + s1=series(allowed_dtypes=TEMPORAL_DTYPES, max_size=1), + s2=series(excluded_dtypes=TEMPORAL_DTYPES, max_size=1), +) +@settings(max_examples=50) +def test_strategy_dtypes( + df: pl.DataFrame, + lf: pl.LazyFrame, + s1: pl.Series, + s2: pl.Series, +) -> None: + # dataframe, lazyframe + assert all(tp.is_temporal() for tp in df.dtypes) + assert all(not tp.is_temporal() for tp in lf.collect_schema().dtypes()) + + # series + assert s1.dtype.is_temporal() + assert not s2.dtype.is_temporal() + + +@given(s=series(allow_chunks=False)) +@settings(max_examples=10) +def test_series_allow_chunks(s: pl.Series) -> None: + assert s.n_chunks() == 1 + + +@given(df=dataframes(allow_chunks=False)) +@settings(max_examples=10) +def test_dataframes_allow_chunks(df: pl.DataFrame) -> None: + assert df.n_chunks("first") == 1 + assert df.n_chunks("all") == [1] * df.width + + +@given( + df=dataframes( + allowed_dtypes=[pl.Float32, pl.Float64], + max_cols=4, + allow_null=False, + allow_infinity=False, + ), + s=series(dtype=pl.Float64, allow_null=False, allow_infinity=False), +) +def test_infinities( + df: pl.DataFrame, + s: pl.Series, +) -> None: + from math import isfinite, isnan + + def finite_float(value: Any) -> bool: + return isfinite(value) or isnan(value) + + assert all(finite_float(val) for val in s.to_list()) + for col in df.columns: + assert all(finite_float(val) for val in df[col].to_list()) + + +@given( + df=dataframes( + cols=10, + max_size=1, + allowed_dtypes=[pl.Int8, pl.UInt16, pl.List(pl.Int32)], + ) +) +@settings(max_examples=3) +def test_dataframes_allowed_dtypes_integer_cols(df: pl.DataFrame) -> None: + # ensure dtype constraint works in conjunction with 'n' cols + assert all( + tp in (pl.Int8, pl.UInt16, pl.List(pl.Int32)) for tp in df.schema.values() + ) + + +@given(st.data()) +@settings(max_examples=1) +def test_series_chunked_deprecated(data: st.DataObject) -> None: + with pytest.deprecated_call(): + data.draw(series(chunked=True)) + with pytest.deprecated_call(): + data.draw(dataframes(chunked=True)) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_data.py b/py-polars/tests/unit/testing/parametric/strategies/test_data.py new file mode 100644 index 000000000000..d3c5282ac959 --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_data.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from hypothesis import given + +import polars as pl +from polars.testing.parametric.strategies.data import categories, data + + +@given(cat=categories(3)) +def test_categories(cat: str) -> None: + assert cat in ("c0", "c1", "c2") + + +@given(cat=data(pl.Categorical, n_categories=3)) +def test_data_kwargs(cat: str) -> None: + assert cat in ("c0", "c1", "c2") + + +@given(categories=data(pl.List(pl.Categorical), n_categories=3)) +def test_data_nested_kwargs(categories: list[str]) -> None: + assert all(c in ("c0", "c1", "c2") for c in categories) + + +@given(cat=data(pl.Enum)) +def test_data_enum(cat: str) -> None: + assert cat in ("c0", "c1", "c2") + + +@given(cat=data(pl.Enum(["hello", "world"]))) +def test_data_enum_instantiated(cat: str) -> None: + assert cat in ("hello", "world") + + +@given(struct=data(pl.Struct({"a": pl.Int8, "b": pl.String}))) +def test_data_struct(struct: dict[str, int | str]) -> None: + assert isinstance(struct["a"], int) + assert isinstance(struct["b"], str) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_dtype.py b/py-polars/tests/unit/testing/parametric/strategies/test_dtype.py new file mode 100644 index 000000000000..81055a6f27db --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_dtype.py @@ -0,0 +1,59 @@ +import hypothesis.strategies as st +from hypothesis import given + +import polars as pl +from polars.testing.parametric.strategies.dtype import dtypes + + +@given(dtype=dtypes()) +def test_dtypes(dtype: pl.DataType) -> None: + assert isinstance(dtype, pl.DataType) + + +@given(dtype=dtypes(nesting_level=0)) +def test_dtypes_nesting_level(dtype: pl.DataType) -> None: + assert not dtype.is_nested() + + +@given(dtype=dtypes(allowed_dtypes=[pl.Datetime], allow_time_zones=False)) +def test_dtypes_allow_time_zones(dtype: pl.DataType) -> None: + assert getattr(dtype, "time_zone", None) is None + + +@given(st.data()) +def test_dtypes_allowed(data: st.DataObject) -> None: + allowed_dtype = data.draw(dtypes()) + result = data.draw(dtypes(allowed_dtypes=[allowed_dtype])) + assert result == allowed_dtype + + +@given(st.data()) +def test_dtypes_excluded(data: st.DataObject) -> None: + excluded_dtype = data.draw(dtypes()) + result = data.draw(dtypes(excluded_dtypes=[excluded_dtype])) + assert result != excluded_dtype + + +@given(dtype=dtypes(allowed_dtypes=[pl.Duration], excluded_dtypes=[pl.Duration("ms")])) +def test_dtypes_allowed_excluded_instance(dtype: pl.DataType) -> None: + assert isinstance(dtype, pl.Duration) + assert dtype.time_unit != "ms" + + +@given( + dtype=dtypes( + allowed_dtypes=[pl.Duration("ns"), pl.Date], excluded_dtypes=[pl.Duration] + ) +) +def test_dtypes_allowed_excluded_priority(dtype: pl.DataType) -> None: + assert dtype == pl.Date + + +@given(dtype=dtypes(allowed_dtypes=[pl.Int8(), pl.Duration("ms")])) +def test_dtypes_allowed_instantiated(dtype: pl.DataType) -> None: + assert dtype in (pl.Int8(), pl.Duration("ms")) + + +@given(dtype=dtypes(allowed_dtypes=[pl.List(pl.List), pl.Int64])) +def test_dtypes_allowed_uninstantiated_nested(dtype: pl.DataType) -> None: + assert dtype in (pl.List, pl.Int64) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py b/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py new file mode 100644 index 000000000000..d9cd5bc5bbb7 --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py @@ -0,0 +1,20 @@ +import pytest +from hypothesis.errors import NonInteractiveExampleWarning + +from polars.testing.parametric import columns, create_list_strategy +from polars.testing.parametric.strategies.core import _COL_LIMIT + + +@pytest.mark.hypothesis +def test_columns_deprecated() -> None: + with pytest.deprecated_call(), pytest.warns(NonInteractiveExampleWarning): + result = columns() + assert 0 <= len(result) <= _COL_LIMIT + + +@pytest.mark.hypothesis +def test_create_list_strategy_deprecated() -> None: + with pytest.deprecated_call(), pytest.warns(NonInteractiveExampleWarning): + result = create_list_strategy(size=5) + with pytest.warns(NonInteractiveExampleWarning): + assert len(result.example()) == 5 diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_utils.py b/py-polars/tests/unit/testing/parametric/strategies/test_utils.py new file mode 100644 index 000000000000..f192bc410e21 --- /dev/null +++ b/py-polars/tests/unit/testing/parametric/strategies/test_utils.py @@ -0,0 +1,22 @@ +from typing import Any + +import pytest + +from polars.testing.parametric.strategies._utils import flexhash + + +@pytest.mark.parametrize( + ("left", "right"), + [ + (1, 2), + (1.0, 2.0), + ("x", "y"), + ([1, 2], [3, 4]), + ({"a": 1, "b": 2}, {"a": 1, "b": 3}), + ({"a": 1, "b": [1.0]}, {"a": 1, "b": [1.5]}), + ], +) +def test_flexhash_flat(left: Any, right: Any) -> None: + assert flexhash(left) != flexhash(right) + assert flexhash(left) == flexhash(left) + assert flexhash(right) == flexhash(right) diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py new file mode 100644 index 000000000000..f13898bb26c5 --- /dev/null +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -0,0 +1,464 @@ +from __future__ import annotations + +import math +from typing import Any + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal, assert_frame_not_equal +from polars.testing.parametric import dataframes + +nan = float("nan") +pytest_plugins = ["pytester"] + + +@given(df=dataframes()) +def test_equal(df: pl.DataFrame) -> None: + assert_frame_equal(df, df.clone(), check_exact=True) + + +@pytest.mark.parametrize( + ("df1", "df2", "kwargs"), + [ + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.3]}), + {"atol": 1e-15}, + id="equal_floats_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.3000000000000001]}), + {"atol": 1e-15}, + id="approx_equal_float_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.31]}), + {"atol": 0.1}, + id="approx_equal_float_high_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 1.3]}), + pl.DataFrame({"a": [0.2, 0.9]}), + {"atol": 1}, + id="approx_equal_float_integer_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + {"check_dtypes": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}), + {"check_dtypes": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + {}, + id="equal_int", + ), + pytest.param( + pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.String}), + pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.String}), + {}, + id="equal_str", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300001]]}), + {"atol": 1e-5}, + id="list_of_float_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.31]]}), + {"atol": 0.1}, + id="list_of_float_high_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 1.3]]}), + pl.DataFrame({"a": [[0.2, 0.9]]}), + {"atol": 1}, + id="list_of_float_integer_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300000001]]}), + {"rtol": 1e-5}, + id="list_of_float_low_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.301]]}), + {"rtol": 0.1}, + id="list_of_float_high_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 1.3]]}), + pl.DataFrame({"a": [[0.2, 0.9]]}), + {"rtol": 1}, + id="list_of_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[None, 1.3]]}), + pl.DataFrame({"a": [[None, 0.9]]}), + {"rtol": 1}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), + {"atol": 0.1}, + id="nested_list_of_float_atol_high", + ), + ], +) +def test_assert_frame_equal_passes_assertion( + df1: pl.DataFrame, + df2: pl.DataFrame, + kwargs: dict[str, Any], +) -> None: + assert_frame_equal(df1, df2, **kwargs) + with pytest.raises(AssertionError): + assert_frame_not_equal(df1, df2, **kwargs) + + +@pytest.mark.parametrize( + ("df1", "df2", "kwargs"), + [ + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}), + {}, + id="list_of_float_different_lengths", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}), + {"check_exact": True}, + id="list_of_float_check_exact", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300001]]}), + {"atol": 1e-15, "rtol": 0}, + id="list_of_float_too_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.30000001]]}), + {"atol": -1, "rtol": 0}, + id="list_of_float_negative_atol", + ), + pytest.param( + pl.DataFrame({"a": [[2.0, 3.0]]}), + pl.DataFrame({"a": [[2, 3]]}), + {"check_exact": False, "check_dtypes": True}, + id="list_of_float_list_of_int_check_dtype_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}), + {"atol": 0.1, "rtol": 0}, + id="nested_list_of_float_and_nan_atol_high", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), + {"atol": 0.1, "rtol": 0}, + id="double_nested_list_of_float_atol_high", + ), + pytest.param( + pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}), + pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}), + {"atol": 0.1, "rtol": 0}, + id="triple_nested_list_of_float_atol_high", + ), + ], +) +def test_assert_frame_equal_raises_assertion_error( + df1: pl.DataFrame, + df2: pl.DataFrame, + kwargs: dict[str, Any], +) -> None: + with pytest.raises(AssertionError): + assert_frame_equal(df1, df2, **kwargs) + assert_frame_not_equal(df1, df2, **kwargs) + + +def test_compare_frame_equal_nans() -> None: + df1 = pl.DataFrame( + data={"x": [1.0, nan], "y": [nan, 2.0]}, + schema=[("x", pl.Float32), ("y", pl.Float64)], + ) + assert_frame_equal(df1, df1, check_exact=True) + + df2 = pl.DataFrame( + data={"x": [1.0, nan], "y": [None, 2.0]}, + schema=[("x", pl.Float32), ("y", pl.Float64)], + ) + assert_frame_not_equal(df1, df2) + with pytest.raises(AssertionError, match="value mismatch for column 'y'"): + assert_frame_equal(df1, df2, check_exact=True) + + +def test_compare_frame_equal_nested_nans() -> None: + # list dtype + df1 = pl.DataFrame( + data={"x": [[1.0, nan]], "y": [[nan, 2.0]]}, + schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], + ) + assert_frame_equal(df1, df1, check_exact=True) + + df2 = pl.DataFrame( + data={"x": [[1.0, nan]], "y": [[None, 2.0]]}, + schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], + ) + assert_frame_not_equal(df1, df2) + with pytest.raises(AssertionError, match="value mismatch for column 'y'"): + assert_frame_equal(df1, df2, check_exact=True) + + # struct dtype + df3 = pl.from_dicts( + [ + { + "id": 1, + "struct": [ + {"x": "text", "y": [0.0, nan]}, + {"x": "text", "y": [0.0, nan]}, + ], + }, + { + "id": 2, + "struct": [ + {"x": "text", "y": [1]}, + {"x": "text", "y": [1]}, + ], + }, + ] + ) + df4 = pl.from_dicts( + [ + { + "id": 1, + "struct": [ + {"x": "text", "y": [0.0, nan], "z": ["$"]}, + {"x": "text", "y": [0.0, nan], "z": ["$"]}, + ], + }, + { + "id": 2, + "struct": [ + {"x": "text", "y": [nan, 1.0], "z": ["!"]}, + {"x": "text", "y": [nan, 1.0], "z": ["?"]}, + ], + }, + ] + ) + + assert_frame_equal(df3, df3) + assert_frame_equal(df4, df4) + + assert_frame_not_equal(df3, df4) + for check_dtype in (True, False): + with pytest.raises(AssertionError, match="mismatch|different"): + assert_frame_equal(df3, df4, check_dtypes=check_dtype) + + +def test_assert_frame_equal_pass() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2]}) + assert_frame_equal(df1, df2) + + +@pytest.mark.parametrize( + "assert_function", + [assert_frame_equal, assert_frame_not_equal], +) +def test_assert_frame_equal_types(assert_function: Any) -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + srs1 = pl.Series(values=[1, 2], name="a") + with pytest.raises( + AssertionError, match=r"inputs are different \(unexpected input types\)" + ): + assert_function(df1, srs1) + + +def test_assert_frame_equal_length_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2, 3]}) + with pytest.raises( + AssertionError, + match=r"DataFrames are different \(number of rows does not match\)", + ): + assert_frame_equal(df1, df2) + assert_frame_not_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"b": [1, 2]}) + with pytest.raises( + AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right" + ): + assert_frame_equal(df1, df2) + assert_frame_not_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch2() -> None: + df1 = pl.LazyFrame({"a": [1, 2]}) + df2 = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + with pytest.raises( + AssertionError, + match="columns \\['b', 'c'\\] in right LazyFrame, but not in left", + ): + assert_frame_equal(df1, df2) + assert_frame_not_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch_order() -> None: + df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + with pytest.raises(AssertionError, match="columns are not in the same order"): + assert_frame_equal(df1, df2) + + assert_frame_equal(df1, df2, check_column_order=False) + assert_frame_not_equal(df1, df2) + + +def test_assert_frame_equal_check_row_order() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) + df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) + + with pytest.raises(AssertionError, match="value mismatch for column 'a'"): + assert_frame_equal(df1, df2) + + assert_frame_equal(df1, df2, check_row_order=False) + assert_frame_not_equal(df1, df2) + + +def test_assert_frame_equal_check_row_col_order() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) + df2 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) + + with pytest.raises(AssertionError, match="columns are not in the same order"): + assert_frame_equal(df1, df2, check_row_order=False) + + assert_frame_equal(df1, df2, check_row_order=False, check_column_order=False) + assert_frame_not_equal(df1, df2) + + +@pytest.mark.parametrize( + "assert_function", + [assert_frame_equal, assert_frame_not_equal], +) +def test_assert_frame_equal_check_row_order_unsortable(assert_function: Any) -> None: + df1 = pl.DataFrame({"a": [object(), object()], "b": [3, 4]}) + df2 = pl.DataFrame({"a": [object(), object()], "b": [4, 3]}) + with pytest.raises( + TypeError, match="cannot set `check_row_order=False`.*unsortable columns" + ): + assert_function(df1, df2, check_row_order=False) + + +def test_assert_frame_equal_dtypes_mismatch() -> None: + data = {"a": [1, 2], "b": [3, 4]} + df1 = pl.DataFrame(data, schema={"a": pl.Int8, "b": pl.Int16}) + df2 = pl.DataFrame(data, schema={"b": pl.Int16, "a": pl.Int16}) + + with pytest.raises(AssertionError, match="dtypes do not match"): + assert_frame_equal(df1, df2, check_column_order=False) + + assert_frame_not_equal(df1, df2, check_column_order=False) + assert_frame_not_equal(df1, df2) + + +def test_assert_frame_not_equal() -> None: + df = pl.DataFrame({"a": [1, 2]}) + with pytest.raises(AssertionError, match="DataFrames are equal"): + assert_frame_not_equal(df, df) + lf = df.lazy() + with pytest.raises(AssertionError, match="LazyFrames are equal"): + assert_frame_not_equal(lf, lf) + + +def test_assert_frame_equal_check_dtype_deprecated() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1.0, 2.0]}) + df3 = pl.DataFrame({"a": [2, 1]}) + + with pytest.deprecated_call(): + assert_frame_equal(df1, df2, check_dtype=False) # type: ignore[call-arg] + + with pytest.deprecated_call(): + assert_frame_not_equal(df1, df3, check_dtype=False) # type: ignore[call-arg] + + +def test_tracebackhide(testdir: pytest.Testdir) -> None: + testdir.makefile( + ".py", + test_path="""\ +import polars as pl +from polars.testing import assert_frame_equal, assert_frame_not_equal + +def test_frame_equal_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 3]}) + assert_frame_equal(df1, df2) + +def test_frame_not_equal_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2]}) + assert_frame_not_equal(df1, df2) + +def test_frame_data_type_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = {"a": [1, 2]} + assert_frame_equal(df1, df2) + +def test_frame_schema_fail(): + df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64}) + df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32}) + assert_frame_equal(df1, df2) +""", + ) + result = testdir.runpytest() + result.assert_outcomes(passed=0, failed=4) + stdout = "\n".join(result.outlines) + + assert "polars/py-polars/polars/testing" not in stdout + + # The above should catch any polars testing functions that appear in the + # stack trace. But we keep the following checks (for specific function + # names) just to double-check. + + assert "def assert_frame_equal" not in stdout + assert "def assert_frame_not_equal" not in stdout + assert "def _assert_correct_input_type" not in stdout + assert "def _assert_frame_schema_equal" not in stdout + + assert "def assert_series_equal" not in stdout + assert "def assert_series_not_equal" not in stdout + assert "def _assert_series_values_equal" not in stdout + assert "def _assert_series_nested_values_equal" not in stdout + assert "def _assert_series_null_values_match" not in stdout + assert "def _assert_series_nan_values_match" not in stdout + assert "def _assert_series_values_within_tolerance" not in stdout + + # Make sure the tests are failing for the expected reason (e.g. not because + # an import is missing or something like that): + + assert ( + "AssertionError: DataFrames are different (value mismatch for column 'a')" + in stdout + ) + assert "AssertionError: DataFrames are equal" in stdout + assert "AssertionError: inputs are different (unexpected input types)" in stdout + assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py new file mode 100644 index 000000000000..d576762fc288 --- /dev/null +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -0,0 +1,865 @@ +from __future__ import annotations + +import math +from datetime import datetime, time, timedelta +from decimal import Decimal as D +from typing import Any + +import hypothesis.strategies as st +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_series_equal, assert_series_not_equal +from polars.testing.parametric import dtypes, series + +nan = float("nan") +pytest_plugins = ["pytester"] + + +@given(s=series()) +def test_assert_series_equal_parametric(s: pl.Series) -> None: + assert_series_equal(s, s) + + +@given(data=st.data()) +def test_assert_series_equal_parametric_array(data: st.DataObject) -> None: + inner = data.draw(dtypes(excluded_dtypes=[pl.Categorical])) + shape = data.draw(st.integers(min_value=1, max_value=3)) + dtype = pl.Array(inner, shape=shape) + s = data.draw(series(dtype=dtype)) + + assert_series_equal(s, s) + + +def test_compare_series_value_mismatch() -> None: + srs1 = pl.Series([1, 2, 3]) + srs2 = pl.Series([2, 3, 4]) + assert_series_not_equal(srs1, srs2) + + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(srs1, srs2) + + +def test_compare_series_empty_equal() -> None: + srs1 = pl.Series([]) + srs2 = pl.Series(()) + assert_series_equal(srs1, srs2) + + with pytest.raises( + AssertionError, + match=r"Series are equal \(but are expected not to be\)", + ): + assert_series_not_equal(srs1, srs2) + + +def test_assert_series_equal_check_order() -> None: + srs1 = pl.Series([1, 2, 3, None]) + srs2 = pl.Series([2, None, 3, 1]) + assert_series_equal(srs1, srs2, check_order=False) + + with pytest.raises( + AssertionError, + match=r"Series are equal \(but are expected not to be\)", + ): + assert_series_not_equal(srs1, srs2, check_order=False) + + +def test_assert_series_equal_check_order_unsortable_type() -> None: + s = pl.Series([object(), object()]) + with pytest.raises( + TypeError, + match="cannot set `check_order=False` on Series with unsortable data type", + ): + assert_series_equal(s, s, check_order=False) + + +def test_compare_series_nans_assert_equal() -> None: + srs1 = pl.Series([1.0, 2.0, nan, 4.0, None, 6.0]) + srs2 = pl.Series([1.0, nan, 3.0, 4.0, None, 6.0]) + srs3 = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0]) + + for srs in (srs1, srs2, srs3): + assert_series_equal(srs, srs) + assert_series_equal(srs, srs, check_exact=True) + + for check_exact in (False, True): + if check_exact: + check_msg = "exact value mismatch" + else: + check_msg = "Series are different.*value mismatch.*" + + with pytest.raises(AssertionError, match=check_msg): + assert_series_equal(srs1, srs2, check_exact=check_exact) + with pytest.raises(AssertionError, match=check_msg): + assert_series_equal(srs1, srs3, check_exact=check_exact) + + srs4 = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0]) + srs5 = pl.Series([1.0, 2.0, 3.0, 4.0, nan, 6.0]) + srs6 = pl.Series([1, 2, 3, 4, None, 6]) + + assert_series_equal(srs4, srs6, check_dtypes=False) + with pytest.raises(AssertionError): + assert_series_equal(srs5, srs6, check_dtypes=False) + assert_series_not_equal(srs5, srs6, check_dtypes=True) + + # nested + for float_type in (pl.Float32, pl.Float64): + srs = pl.Series([[0.0, nan]], dtype=pl.List(float_type)) + assert srs.dtype == pl.List(float_type) + assert_series_equal(srs, srs) + + +def test_compare_series_nulls() -> None: + srs1 = pl.Series([1, 2, None]) + srs2 = pl.Series([1, 2, None]) + assert_series_equal(srs1, srs2) + + srs1 = pl.Series([1, 2, 3]) + srs2 = pl.Series([1, None, None]) + assert_series_not_equal(srs1, srs2) + + with pytest.raises(AssertionError, match="value mismatch"): + assert_series_equal(srs1, srs2) + + +def test_compare_series_value_mismatch_string() -> None: + srs1 = pl.Series(["hello", "no"]) + srs2 = pl.Series(["hello", "yes"]) + + assert_series_not_equal(srs1, srs2) + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(srs1, srs2) + + +def test_compare_series_dtype_mismatch() -> None: + srs1 = pl.Series([1, 2, 3]) + srs2 = pl.Series([1.0, 2.0, 3.0]) + assert_series_not_equal(srs1, srs2) + + with pytest.raises( + AssertionError, + match=r"Series are different \(dtype mismatch\)", + ): + assert_series_equal(srs1, srs2) + + +@pytest.mark.parametrize( + "assert_function", [assert_series_equal, assert_series_not_equal] +) +def test_compare_series_input_type_mismatch(assert_function: Any) -> None: + srs1 = pl.Series([1, 2, 3]) + srs2 = pl.DataFrame({"col1": [2, 3, 4]}) + + with pytest.raises( + AssertionError, + match=r"inputs are different \(unexpected input types\)", + ): + assert_function(srs1, srs2) + + +def test_compare_series_name_mismatch() -> None: + srs1 = pl.Series(values=[1, 2, 3], name="srs1") + srs2 = pl.Series(values=[1, 2, 3], name="srs2") + with pytest.raises( + AssertionError, + match=r"Series are different \(name mismatch\)", + ): + assert_series_equal(srs1, srs2) + + +def test_compare_series_length_mismatch() -> None: + srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1") + srs2 = pl.Series(values=[1, 2, 3], name="srs2") + + assert_series_not_equal(srs1, srs2) + with pytest.raises( + AssertionError, + match=r"Series are different \(length mismatch\)", + ): + assert_series_equal(srs1, srs2) + + +def test_compare_series_value_exact_mismatch() -> None: + srs1 = pl.Series([1.0, 2.0, 3.0]) + srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0]) + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(srs1, srs2, check_exact=True) + + +def test_assert_series_equal_int_overflow() -> None: + # internally may call 'abs' if not check_exact, which can overflow on signed int + s0 = pl.Series([-128], dtype=pl.Int8) + s1 = pl.Series([0, -128], dtype=pl.Int8) + s2 = pl.Series([1, -128], dtype=pl.Int8) + + for check_exact in (True, False): + assert_series_equal(s0, s0, check_exact=check_exact) + with pytest.raises(AssertionError): + assert_series_equal(s1, s2, check_exact=check_exact) + + +@pytest.mark.parametrize( + ("data1", "data2"), + [ + ([datetime(2022, 10, 2, 12)], [datetime(2022, 10, 2, 13)]), + ([time(10, 0, 0)], [time(10, 0, 10)]), + ([timedelta(10, 0, 0)], [timedelta(10, 0, 10)]), + ], +) +def test_assert_series_equal_temporal(data1: Any, data2: Any) -> None: + s1 = pl.Series(data1) + s2 = pl.Series(data2) + assert_series_not_equal(s1, s2) + + +@pytest.mark.parametrize( + ("s1", "s2", "kwargs"), + [ + pytest.param( + pl.Series([0.2, 0.3]), + pl.Series([0.2, 0.3]), + {"atol": 1e-15}, + id="equal_floats_low_atol", + ), + pytest.param( + pl.Series([0.2, 0.3]), + pl.Series([0.2, 0.3000000000000001]), + {"atol": 1e-15}, + id="approx_equal_float_low_atol", + ), + pytest.param( + pl.Series([0.2, 0.3]), + pl.Series([0.2, 0.31]), + {"atol": 0.1}, + id="approx_equal_float_high_atol", + ), + pytest.param( + pl.Series([0.2, 1.3]), + pl.Series([0.2, 0.9]), + {"atol": 1}, + id="approx_equal_float_integer_atol", + ), + pytest.param( + pl.Series([1.0, 2.0, nan]), + pl.Series([1.005, 2.005, nan]), + {"atol": 1e-2, "rtol": 0.0}, + id="approx_equal_float_nan_atol", + ), + pytest.param( + pl.Series([1.0, 2.0, None]), + pl.Series([1.005, 2.005, None]), + {"atol": 1e-2}, + id="approx_equal_float_none_atol", + ), + pytest.param( + pl.Series([1.0, 2.0, nan]), + pl.Series([1.005, 2.015, nan]), + {"atol": 0.0, "rtol": 1e-2}, + id="approx_equal_float_nan_rtol", + ), + pytest.param( + pl.Series([1.0, 2.0, None]), + pl.Series([1.005, 2.015, None]), + {"rtol": 1e-2}, + id="approx_equal_float_none_rtol", + ), + pytest.param( + pl.Series([0.0, 1.0, 2.0], dtype=pl.Float64), + pl.Series([0, 1, 2], dtype=pl.Int64), + {"check_dtypes": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.Series([0, 1, 2], dtype=pl.Float64), + pl.Series([0, 1, 2], dtype=pl.Float32), + {"check_dtypes": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.Series([0, 1, 2], dtype=pl.Int64), + pl.Series([0, 1, 2], dtype=pl.Int64), + {}, + id="equal_int", + ), + pytest.param( + pl.Series(["a", "b", "c"], dtype=pl.String), + pl.Series(["a", "b", "c"], dtype=pl.String), + {}, + id="equal_str", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.31]]), + {"atol": 0.1}, + id="list_of_float_high_atol", + ), + pytest.param( + pl.Series([[0.2, 1.3]]), + pl.Series([[0.2, 0.9]]), + {"atol": 1}, + id="list_of_float_integer_atol", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.300000001]]), + {"rtol": 1e-15}, + id="list_of_float_low_rtol", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.301]]), + {"rtol": 0.1}, + id="list_of_float_high_rtol", + ), + pytest.param( + pl.Series([[0.2, 1.3]]), + pl.Series([[0.2, 0.9]]), + {"rtol": 1}, + id="list_of_float_integer_rtol", + ), + pytest.param( + pl.Series([[None, 1.3]]), + pl.Series([[None, 0.9]]), + {"rtol": 1}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.Series([[None, 1]], dtype=pl.List(pl.Int64)), + pl.Series([[None, 1]], dtype=pl.List(pl.Int64)), + {"rtol": 1}, + id="list_of_none_and_int_integer_rtol", + ), + pytest.param( + pl.Series([[math.nan, 1.3]]), + pl.Series([[math.nan, 0.9]]), + {"rtol": 1}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.Series([[2.0, 3.0]]), + pl.Series([[2, 3]]), + {"check_exact": False, "check_dtypes": False}, + id="list_of_float_list_of_int_check_dtype_false", + ), + pytest.param( + pl.Series([[[0.2, 3.0]]]), + pl.Series([[[0.2, 3.00000001]]]), + {"atol": 0.1}, + id="nested_list_of_float_atol_high", + ), + pytest.param( + pl.Series([[[0.2, math.nan, 3.0]]]), + pl.Series([[[0.2, math.nan, 3.00000001]]]), + {"atol": 0.1}, + id="nested_list_of_float_and_nan_atol_high", + ), + pytest.param( + pl.Series([[[[0.2, 3.0]]]]), + pl.Series([[[[0.2, 3.00000001]]]]), + {"atol": 0.1}, + id="double_nested_list_of_float_atol_high", + ), + pytest.param( + pl.Series([[[[0.2, math.nan, 3.0]]]]), + pl.Series([[[[0.2, math.nan, 3.00000001]]]]), + {"atol": 0.1}, + id="double_nested_list_of_float__and_nan_atol_high", + ), + pytest.param( + pl.Series([[[[[0.2, 3.0]]]]]), + pl.Series([[[[[0.2, 3.00000001]]]]]), + {"atol": 0.1}, + id="triple_nested_list_of_float_atol_high", + ), + pytest.param( + pl.Series([[[[[0.2, math.nan, 3.0]]]]]), + pl.Series([[[[[0.2, math.nan, 3.00000001]]]]]), + {"atol": 0.1}, + id="triple_nested_list_of_float_and_nan_atol_high", + ), + pytest.param( + pl.struct(a=0, b=1, eager=True), + pl.struct(a=0, b=1, eager=True), + {}, + id="struct_equal", + ), + pytest.param( + pl.struct(a=0, b=1.1, eager=True), + pl.struct(a=0, b=1.01, eager=True), + {"atol": 0.1, "rtol": 0}, + id="struct_approx_equal", + ), + pytest.param( + pl.struct(a=0, b=[0.0, 1.1], eager=True), + pl.struct(a=0, b=[0.0, 1.11], eager=True), + {"atol": 0.1}, + id="struct_with_list_approx_equal", + ), + pytest.param( + pl.struct(a=0, b=[0.0, math.nan], eager=True), + pl.struct(a=0, b=[0.0, math.nan], eager=True), + {"atol": 0.1}, + id="struct_with_list_with_nan_compare_equal_true", + ), + ], +) +def test_assert_series_equal_passes_assertion( + s1: pl.Series, + s2: pl.Series, + kwargs: Any, +) -> None: + assert_series_equal(s1, s2, **kwargs) + with pytest.raises(AssertionError): + assert_series_not_equal(s1, s2, **kwargs) + + +@pytest.mark.parametrize( + ("s1", "s2", "kwargs"), + [ + pytest.param( + pl.Series([0.2, 0.3]), + pl.Series([0.2, 0.39]), + {"atol": 0.09, "rtol": 0}, + id="approx_equal_float_high_atol_zero_rtol", + ), + pytest.param( + pl.Series([0.2, 1.3]), + pl.Series([0.2, 2.31]), + {"atol": 1, "rtol": 0}, + id="approx_equal_float_integer_atol_zero_rtol", + ), + pytest.param( + pl.Series([0, 1, 2], dtype=pl.Float64), + pl.Series([0, 1, 2], dtype=pl.Int64), + {"check_dtypes": True}, + id="equal_int_float_integer_check_dtype", + ), + pytest.param( + pl.Series([0, 1, 2], dtype=pl.Float64), + pl.Series([0, 1, 2], dtype=pl.Float32), + {"check_dtypes": True}, + id="equal_int_float_integer_check_dtype", + ), + pytest.param( + pl.Series([1.0, 2.0, nan]), + pl.Series([1.005, 2.005, 3.005]), + {"atol": 1e-2, "rtol": 0.0}, + id="approx_equal_float_left_nan_atol", + ), + pytest.param( + pl.Series([1.0, 2.0, 3.0]), + pl.Series([1.005, 2.005, nan]), + {"atol": 1e-2, "rtol": 0.0}, + id="approx_equal_float_right_nan_atol", + ), + pytest.param( + pl.Series([1.0, 2.0, nan]), + pl.Series([1.005, 2.015, 3.025]), + {"atol": 0.0, "rtol": 1e-2}, + id="approx_equal_float_left_nan_rtol", + ), + pytest.param( + pl.Series([1.0, 2.0, 3.0]), + pl.Series([1.005, 2.015, nan]), + {"atol": 0.0, "rtol": 1e-2}, + id="approx_equal_float_right_nan_rtol", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.3, 0.4]]), + {}, + id="list_of_float_different_lengths", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.3000000000000001]]), + {"check_exact": True}, + id="list_of_float_check_exact", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.300001]]), + {"atol": 1e-15, "rtol": 0}, + id="list_of_float_too_low_atol", + ), + pytest.param( + pl.Series([[0.2, 0.3]]), + pl.Series([[0.2, 0.30000001]]), + {"atol": -1, "rtol": 0}, + id="list_of_float_negative_atol", + ), + pytest.param( + pl.Series([[2.0, 3.0]]), + pl.Series([[2, 3]]), + {"check_exact": False, "check_dtypes": True}, + id="list_of_float_list_of_int_check_dtype_true", + ), + pytest.param( + pl.struct(a=0, b=1.1, eager=True), + pl.struct(a=0, b=1, eager=True), + {"atol": 0.1, "rtol": 0, "check_dtypes": True}, + id="struct_approx_equal_different_type", + ), + pytest.param( + pl.struct(a=0, b=1.09, eager=True), + pl.struct(a=0, b=1, eager=True), + {"atol": 0.1, "rtol": 0, "check_dtypes": False}, + id="struct_approx_equal_different_type", + ), + ], +) +def test_assert_series_equal_raises_assertion_error( + s1: pl.Series, + s2: pl.Series, + kwargs: Any, +) -> None: + with pytest.raises(AssertionError): + assert_series_equal(s1, s2, **kwargs) + assert_series_not_equal(s1, s2, **kwargs) + + +def test_assert_series_equal_categorical_vs_str() -> None: + s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + s2 = pl.Series(["a", "b", "a"], dtype=pl.String) + + with pytest.raises(AssertionError, match="dtype mismatch"): + assert_series_equal(s1, s2, categorical_as_str=True) + + assert_series_equal(s1, s2, check_dtypes=False, categorical_as_str=True) + assert_series_equal(s2, s1, check_dtypes=False, categorical_as_str=True) + + +def test_assert_series_equal_incompatible_data_types() -> None: + s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + s2 = pl.Series([0, 1, 0], dtype=pl.Int8) + + with pytest.raises(AssertionError, match="incompatible data types"): + assert_series_equal(s1, s2, check_dtypes=False) + + +def test_assert_series_equal_full_series() -> None: + s1 = pl.Series([1, 2, 3]) + s2 = pl.Series([1, 2, 4]) + msg = ( + r"Series are different \(exact value mismatch\)\n" + r"\[left\]: \[1, 2, 3\]\n" + r"\[right\]: \[1, 2, 4\]" + ) + with pytest.raises(AssertionError, match=msg): + assert_series_equal(s1, s2) + + +def test_assert_series_not_equal() -> None: + s = pl.Series("a", [1, 2]) + with pytest.raises( + AssertionError, + match=r"Series are equal \(but are expected not to be\)", + ): + assert_series_not_equal(s, s) + + +def test_assert_series_equal_nested_list_float() -> None: + # First entry has only integers + s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64)) + s2 = pl.Series([[1.0, 2.0], [3.0, 4.9]], dtype=pl.List(pl.Float64)) + + with pytest.raises( + AssertionError, + match=r"Series are different \(nested value mismatch\)", + ): + assert_series_equal(s1, s2) + + +def test_assert_series_equal_nested_struct_float() -> None: + s1 = pl.Series( + [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.0}], + dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), + ) + s2 = pl.Series( + [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.9}], + dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), + ) + + with pytest.raises( + AssertionError, + match=r"Series are different \(nested value mismatch\)", + ): + assert_series_equal(s1, s2) + + +def test_assert_series_equal_full_null_incompatible_dtypes_raises() -> None: + s1 = pl.Series([None, None], dtype=pl.Categorical) + s2 = pl.Series([None, None], dtype=pl.Int16) + + # You could argue this should pass, but it's rare enough not to warrant the + # additional check + with pytest.raises( + AssertionError, + match="incompatible data types", + ): + assert_series_equal(s1, s2, check_dtypes=False) + + +def test_assert_series_equal_full_null_nested_list() -> None: + s = pl.Series([None, None], dtype=pl.List(pl.Float64)) + assert_series_equal(s, s) + + +def test_assert_series_equal_nested_list_nan() -> None: + s = pl.Series([[1.0, 2.0], [3.0, nan]], dtype=pl.List(pl.Float64)) + assert_series_equal(s, s) + + +def test_assert_series_equal_nested_list_none() -> None: + s1 = pl.Series([[1.0, 2.0], None], dtype=pl.List(pl.Float64)) + s2 = pl.Series([[1.0, 2.0], None], dtype=pl.List(pl.Float64)) + + assert_series_equal(s1, s2) + + +def test_assert_series_equal_uint_overflow() -> None: + s1 = pl.Series([1, 2, 3], dtype=pl.UInt8) + s2 = pl.Series([2, 3, 4], dtype=pl.UInt8) + + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(s1, s2, atol=0) + + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(s1, s2, atol=1) + + left = pl.Series( + values=[2810428175213635359], + dtype=pl.UInt64, + ) + right = pl.Series( + values=[15807433754238349345], + dtype=pl.UInt64, + ) + with pytest.raises(AssertionError): + assert_series_equal(left, right) + + +def test_assert_series_equal_uint_always_checked_exactly() -> None: + s1 = pl.Series([1, 3], dtype=pl.UInt8) + s2 = pl.Series([2, 4], dtype=pl.Int64) + + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(s1, s2, atol=1, check_dtypes=False) + + +def test_assert_series_equal_nested_int_always_checked_exactly() -> None: + s1 = pl.Series([[1, 2], [3, 4]]) + s2 = pl.Series([[1, 2], [3, 5]]) + + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(s1, s2, atol=1) + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): + assert_series_equal(s1, s2, check_exact=True) + + +@pytest.mark.parametrize("check_exact", [True, False]) +def test_assert_series_equal_array_equal(check_exact: bool) -> None: + s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.Array(pl.Float64, 2)) + s2 = pl.Series([[1.0, 2.0], [3.0, 4.2]], dtype=pl.Array(pl.Float64, 2)) + + with pytest.raises( + AssertionError, match=r"Series are different \(nested value mismatch\)" + ): + assert_series_equal(s1, s2, check_exact=check_exact) + + +def test_series_equal_nested_lengths_mismatch() -> None: + s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64)) + s2 = pl.Series([[1.0, 2.0, 3.0], [4.0]], dtype=pl.List(pl.Float64)) + + with pytest.raises(AssertionError, match="nested value mismatch"): + assert_series_equal(s1, s2) + + +@pytest.mark.parametrize("check_exact", [True, False]) +def test_series_equal_decimals(check_exact: bool) -> None: + s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal) + s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal) + + assert_series_equal(s1, s1, check_exact=check_exact) + assert_series_equal(s2, s2, check_exact=check_exact) + + with pytest.raises(AssertionError, match="exact value mismatch"): + assert_series_equal(s1, s2, check_exact=check_exact) + + +def test_assert_series_equal_w_large_integers_12328() -> None: + left = pl.Series([1577840521123000]) + right = pl.Series([1577840521123543]) + with pytest.raises(AssertionError): + assert_series_equal(left, right) + + +def test_assert_series_equal_check_dtype_deprecated() -> None: + s1 = pl.Series("a", [1, 2]) + s2 = pl.Series("a", [1.0, 2.0]) + s3 = pl.Series("a", [2, 1]) + + with pytest.deprecated_call(): + assert_series_equal(s1, s2, check_dtype=False) # type: ignore[call-arg] + + with pytest.deprecated_call(): + assert_series_not_equal(s1, s3, check_dtype=False) # type: ignore[call-arg] + + +def test_assert_series_equal_nested_categorical_as_str_global() -> None: + # https://github.com/pola-rs/polars/issues/16196 + + # Global + with pl.StringCache(): + s1 = pl.Series(["c0"], dtype=pl.Categorical) + s2 = pl.Series(["c1"], dtype=pl.Categorical) + s_global = pl.DataFrame([s1, s2]).to_struct("col0") + + # Local + s1 = pl.Series(["c0"], dtype=pl.Categorical) + s2 = pl.Series(["c1"], dtype=pl.Categorical) + s_local = pl.DataFrame([s1, s2]).to_struct("col0") + + assert_series_equal(s_global, s_local, categorical_as_str=True) + assert_series_not_equal(s_global, s_local, categorical_as_str=False) + + +@pytest.mark.parametrize( + "s", + [ + pl.Series([["a", "b"], ["a"]], dtype=pl.List(pl.Categorical)), + pl.Series([{"a": "x"}, {"a": "y"}], dtype=pl.Struct({"a": pl.Categorical})), + ], +) +def test_assert_series_equal_nested_categorical_as_str(s: pl.Series) -> None: + assert_series_equal(s, s, categorical_as_str=True) + + +def test_tracebackhide(testdir: pytest.Testdir) -> None: + testdir.makefile( + ".py", + test_path="""\ +import polars as pl +from polars.testing import assert_series_equal, assert_series_not_equal + +nan = float("nan") + +def test_series_equal_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, 3]) + assert_series_equal(s1, s2) + +def test_series_not_equal_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, 2]) + assert_series_not_equal(s1, s2) + +def test_series_nested_fail(): + s1 = pl.Series([[1, 2], [3, 4]]) + s2 = pl.Series([[1, 2], [3, 5]]) + assert_series_equal(s1, s2) + +def test_series_null_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, None]) + assert_series_equal(s1, s2) + +def test_series_nan_fail(): + s1 = pl.Series([1.0, 2.0]) + s2 = pl.Series([1.0, nan]) + assert_series_equal(s1, s2) + +def test_series_float_tolerance_fail(): + s1 = pl.Series([1.0, 2.0]) + s2 = pl.Series([1.0, 2.1]) + assert_series_equal(s1, s2) + +def test_series_schema_fail(): + s1 = pl.Series([1, 2], dtype=pl.Int64) + s2 = pl.Series([1, 2], dtype=pl.Int32) + assert_series_equal(s1, s2) + +def test_series_data_type_fail(): + s1 = pl.Series([1, 2]) + s2 = [1, 2] + assert_series_equal(s1, s2) +""", + ) + result = testdir.runpytest() + result.assert_outcomes(passed=0, failed=8) + stdout = "\n".join(result.outlines) + + assert "polars/py-polars/polars/testing" not in stdout + + # The above should catch any polars testing functions that appear in the + # stack trace. But we keep the following checks (for specific function + # names) just to double-check. + + assert "def assert_series_equal" not in stdout + assert "def assert_series_not_equal" not in stdout + assert "def _assert_series_values_equal" not in stdout + assert "def _assert_series_nested_values_equal" not in stdout + assert "def _assert_series_null_values_match" not in stdout + assert "def _assert_series_nan_values_match" not in stdout + assert "def _assert_series_values_within_tolerance" not in stdout + + # Make sure the tests are failing for the expected reason (e.g. not because + # an import is missing or something like that): + + assert "AssertionError: Series are different (exact value mismatch)" in stdout + assert "AssertionError: Series are equal" in stdout + assert "AssertionError: Series are different (nan value mismatch)" in stdout + assert "AssertionError: Series are different (dtype mismatch)" in stdout + assert "AssertionError: inputs are different (unexpected input types)" in stdout + + +def test_assert_series_equal_inf() -> None: + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, float("inf")]) + assert_series_equal(s1, s2) + + s1 = pl.Series([1.0, float("-inf")]) + s2 = pl.Series([1.0, float("-inf")]) + assert_series_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([float("inf"), 1.0]) + assert_series_not_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, float("-inf")]) + assert_series_not_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, 2.0]) + assert_series_not_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, float("nan")]) + assert_series_not_equal(s1, s2) diff --git a/py-polars/tests/unit/utils/__init__.py b/py-polars/tests/unit/utils/__init__.py new file mode 100644 index 000000000000..f27b760953b4 --- /dev/null +++ b/py-polars/tests/unit/utils/__init__.py @@ -0,0 +1 @@ +"""Test module for utility functions.""" diff --git a/py-polars/tests/unit/utils/parse/__init__.py b/py-polars/tests/unit/utils/parse/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/utils/parse/test_expr.py b/py-polars/tests/unit/utils/parse/test_expr.py new file mode 100644 index 000000000000..0bd1fd1621a3 --- /dev/null +++ b/py-polars/tests/unit/utils/parse/test_expr.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from datetime import date +from typing import Any + +import pytest + +import polars as pl +from polars._utils.parse.expr import parse_into_expression +from polars._utils.wrap import wrap_expr +from polars.testing import assert_frame_equal + + +def assert_expr_equal(result: pl.Expr, expected: pl.Expr) -> None: + """ + Evaluate the given expressions in a simple context to assert equality. + + WARNING: This is not a fully featured function - it's just to evaluate the tests in + this module. Do not use it elsewhere. + """ + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + assert_frame_equal(df.select(result), df.select(expected)) + + +@pytest.mark.parametrize( + "input", [5, 2.0, pl.Series([1, 2, 3]), date(2022, 1, 1), b"hi"] +) +def test_parse_into_expression_lit(input: Any) -> None: + result = wrap_expr(parse_into_expression(input)) + expected = pl.lit(input) + assert_expr_equal(result, expected) + + +def test_parse_into_expression_col() -> None: + result = wrap_expr(parse_into_expression("a")) + expected = pl.col("a") + assert_expr_equal(result, expected) + + +@pytest.mark.parametrize("input", [pl.lit(4), pl.col("a")]) +def test_parse_into_expression_expr(input: pl.Expr) -> None: + result = wrap_expr(parse_into_expression(input)) + expected = input + assert_expr_equal(result, expected) + + +@pytest.mark.parametrize( + "input", [pl.when(True).then(1), pl.when(True).then(1).when(False).then(0)] +) +def test_parse_into_expression_whenthen(input: Any) -> None: + result = wrap_expr(parse_into_expression(input)) + expected = input.otherwise(None) + assert_expr_equal(result, expected) + + +@pytest.mark.parametrize("input", [[1, 2, 3], (1, 2)]) +def test_parse_into_expression_list(input: Any) -> None: + result = wrap_expr(parse_into_expression(input)) + expected = pl.lit(pl.Series("literal", [input])) + assert_expr_equal(result, expected) + + +def test_parse_into_expression_str_as_lit() -> None: + result = wrap_expr(parse_into_expression("a", str_as_lit=True)) + expected = pl.lit("a") + assert_expr_equal(result, expected) + + +def test_parse_into_expression_structify() -> None: + result = wrap_expr(parse_into_expression(pl.col("a", "b"), structify=True)) + expected = pl.struct("a", "b") + assert_expr_equal(result, expected) + + +def test_parse_into_expression_structify_multiple_outputs() -> None: + # note: this only works because assert_expr_equal evaluates on a dataframe with + # columns "a" and "b" + result = wrap_expr(parse_into_expression(pl.col("*"), structify=True)) + expected = pl.struct("a", "b") + assert_expr_equal(result, expected) diff --git a/py-polars/tests/unit/utils/pycapsule_utils.py b/py-polars/tests/unit/utils/pycapsule_utils.py new file mode 100644 index 000000000000..4bb4b17261e2 --- /dev/null +++ b/py-polars/tests/unit/utils/pycapsule_utils.py @@ -0,0 +1,43 @@ +from typing import Any + + +class PyCapsuleStreamHolder: + """ + Hold the Arrow C Stream pycapsule. + + A class that exposes the Arrow C Stream interface via Arrow PyCapsules. This + ensures that the consumer is seeing _only_ the `__arrow_c_stream__` dunder, and that + nothing else (e.g. the dataframe or array interface) is actually being used. + """ + + arrow_obj: Any + + def __init__(self, arrow_obj: object) -> None: + self.arrow_obj = arrow_obj + + def __arrow_c_stream__(self, requested_schema: object = None) -> object: + return self.arrow_obj.__arrow_c_stream__(requested_schema) + + def __iter__(self) -> None: + return + + def __next__(self) -> None: + return + + +class PyCapsuleArrayHolder: + """ + Hold the Arrow C Array pycapsule. + + A class that exposes _only_ the Arrow C Array interface via Arrow PyCapsules. This + ensures that the consumer is seeing _only_ the `__arrow_c_array__` dunder, and that + nothing else (e.g. the dataframe or array interface) is actually being used. + """ + + arrow_obj: Any + + def __init__(self, arrow_obj: object) -> None: + self.arrow_obj = arrow_obj + + def __arrow_c_array__(self, requested_schema: object = None) -> object: + return self.arrow_obj.__arrow_c_array__(requested_schema) diff --git a/py-polars/tests/unit/utils/test_deprecation.py b/py-polars/tests/unit/utils/test_deprecation.py new file mode 100644 index 000000000000..c9264164eb0a --- /dev/null +++ b/py-polars/tests/unit/utils/test_deprecation.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import inspect +from typing import Any + +import pytest + +from polars._utils.deprecation import ( + deprecate_function, + deprecate_nonkeyword_arguments, + deprecate_parameter_as_multi_positional, + deprecate_renamed_function, + deprecate_renamed_parameter, + issue_deprecation_warning, +) + + +def test_issue_deprecation_warning() -> None: + with pytest.deprecated_call(): + issue_deprecation_warning("deprecated", version="0.1.2") + + +def test_deprecate_function() -> None: + @deprecate_function("This is deprecated.", version="1.0.0") + def hello() -> None: ... + + with pytest.deprecated_call(): + hello() + + +def test_deprecate_renamed_function() -> None: + @deprecate_renamed_function("new_hello", version="1.0.0") + def hello() -> None: ... + + with pytest.deprecated_call(match="new_hello"): + hello() + + +def test_deprecate_renamed_parameter(recwarn: Any) -> None: + @deprecate_renamed_parameter("foo", "oof", version="1.0.0") + @deprecate_renamed_parameter("bar", "rab", version="2.0.0") + def hello(oof: str, rab: str, ham: str) -> None: ... + + hello(foo="x", bar="y", ham="z") # type: ignore[call-arg] + + assert len(recwarn) == 2 + assert "oof" in str(recwarn[0].message) + assert "rab" in str(recwarn[1].message) + + +class Foo: # noqa: D101 + @deprecate_nonkeyword_arguments(allowed_args=["self", "baz"], version="0.1.2") + def bar( + self, baz: str, ham: str | None = None, foobar: str | None = None + ) -> None: ... + + +def test_deprecate_nonkeyword_arguments_method_signature() -> None: + # Note the added star indicating keyword-only arguments after 'baz' + expected = "(self, baz: 'str', *, ham: 'str | None' = None, foobar: 'str | None' = None) -> 'None'" + assert str(inspect.signature(Foo.bar)) == expected + + +def test_deprecate_nonkeyword_arguments_method_warning() -> None: + msg = ( + r"All arguments of Foo\.bar except for \'baz\' will be keyword-only in the next breaking release." + r" Use keyword arguments to silence this warning." + ) + with pytest.deprecated_call(match=msg): + Foo().bar("qux", "quox") + + +def test_deprecate_parameter_as_multi_positional(recwarn: Any) -> None: + @deprecate_parameter_as_multi_positional("foo", version="1.0.0") + def hello(*foo: str) -> tuple[str, ...]: + return foo + + with pytest.deprecated_call(): + result = hello(foo="x") + assert result == hello("x") + + with pytest.deprecated_call(): + result = hello(foo=["x", "y"]) # type: ignore[arg-type] + assert result == hello("x", "y") + + +def test_deprecate_parameter_as_multi_positional_existing_arg(recwarn: Any) -> None: + @deprecate_parameter_as_multi_positional("foo", version="1.0.0") + def hello(bar: int, *foo: str) -> tuple[int, tuple[str, ...]]: + return bar, foo + + with pytest.deprecated_call(): + result = hello(5, foo="x") + assert result == hello(5, "x") + + with pytest.deprecated_call(): + result = hello(5, foo=["x", "y"]) # type: ignore[arg-type] + assert result == hello(5, "x", "y") diff --git a/py-polars/tests/unit/utils/test_unstable.py b/py-polars/tests/unit/utils/test_unstable.py new file mode 100644 index 000000000000..7dcd6983c538 --- /dev/null +++ b/py-polars/tests/unit/utils/test_unstable.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import pytest + +from polars._utils.unstable import issue_unstable_warning, unstable +from polars.exceptions import UnstableWarning + + +def test_issue_unstable_warning(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("POLARS_WARN_UNSTABLE", "1") + + msg = "`func` is considered unstable." + expected = ( + msg + + " It may be changed at any point without it being considered a breaking change." + ) + with pytest.warns(UnstableWarning, match=expected): + issue_unstable_warning(msg) + + +def test_issue_unstable_warning_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("POLARS_WARN_UNSTABLE", "1") + + msg = "This functionality is considered unstable." + with pytest.warns(UnstableWarning, match=msg): + issue_unstable_warning() + + +def test_issue_unstable_warning_setting_disabled( + recwarn: pytest.WarningsRecorder, +) -> None: + issue_unstable_warning() + assert len(recwarn) == 0 + + +def test_unstable_decorator(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("POLARS_WARN_UNSTABLE", "1") + + @unstable() + def hello() -> None: ... + + msg = "`hello` is considered unstable." + with pytest.warns(UnstableWarning, match=msg): + hello() + + +def test_unstable_decorator_setting_disabled(recwarn: pytest.WarningsRecorder) -> None: + @unstable() + def hello() -> None: ... + + hello() + assert len(recwarn) == 0 diff --git a/py-polars/tests/unit/utils/test_utils.py b/py-polars/tests/unit/utils/test_utils.py new file mode 100644 index 000000000000..d7577fda9f35 --- /dev/null +++ b/py-polars/tests/unit/utils/test_utils.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any +from zoneinfo import ZoneInfo + +import numpy as np +import pytest + +import polars as pl +from polars._utils.convert import ( + date_to_int, + datetime_to_int, + parse_as_duration_string, + time_to_int, + timedelta_to_int, +) +from polars._utils.various import ( + _in_notebook, + is_bool_sequence, + is_int_sequence, + is_sequence, + is_str_sequence, + parse_percentiles, + parse_version, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from polars._typing import TimeUnit + + +@pytest.mark.parametrize( + ("td", "expected"), + [ + (timedelta(), ""), + (timedelta(days=1), "1d"), + (timedelta(days=-1), "-1d"), + (timedelta(seconds=1), "1s"), + (timedelta(seconds=-1), "-1s"), + (timedelta(microseconds=1), "1us"), + (timedelta(microseconds=-1), "-1us"), + (timedelta(days=1, seconds=1), "1d1s"), + (timedelta(minutes=-1, seconds=1), "-59s"), + (timedelta(days=-1, seconds=-1), "-1d1s"), + (timedelta(days=1, microseconds=1), "1d1us"), + (timedelta(days=-1, microseconds=-1), "-1d1us"), + (None, None), + ("1d2s", "1d2s"), + ], +) +def test_parse_as_duration_string( + td: timedelta | str | None, expected: str | None +) -> None: + assert parse_as_duration_string(td) == expected + + +@pytest.mark.parametrize( + ("d", "expected"), + [ + (date(1999, 9, 9), 10_843), + (date(1969, 12, 31), -1), + (date.min, -719_162), + (date.max, 2_932_896), + ], +) +def test_date_to_int(d: date, expected: int) -> None: + assert date_to_int(d) == expected + + +@pytest.mark.parametrize( + ("t", "expected"), + [ + (time(0, 0, 1), 1_000_000_000), + (time(20, 52, 10), 75_130_000_000_000), + (time(20, 52, 10, 200), 75_130_000_200_000), + (time.min, 0), + (time.max, 86_399_999_999_000), + (time(12, 0, tzinfo=None), 43_200_000_000_000), + (time(12, 0, tzinfo=ZoneInfo("UTC")), 43_200_000_000_000), + (time(12, 0, tzinfo=ZoneInfo("Asia/Shanghai")), 43_200_000_000_000), + (time(12, 0, tzinfo=ZoneInfo("US/Central")), 43_200_000_000_000), + ], +) +def test_time_to_int(t: time, expected: int) -> None: + assert time_to_int(t) == expected + + +@pytest.mark.parametrize( + "tzinfo", [None, ZoneInfo("UTC"), ZoneInfo("Asia/Shanghai"), ZoneInfo("US/Central")] +) +def test_time_to_int_with_time_zone(tzinfo: Any) -> None: + t = time(12, 0, tzinfo=tzinfo) + assert time_to_int(t) == 43_200_000_000_000 + + +@pytest.mark.parametrize( + ("dt", "time_unit", "expected"), + [ + (datetime(2121, 1, 1), "ns", 4_765_132_800_000_000_000), + (datetime(2121, 1, 1), "us", 4_765_132_800_000_000), + (datetime(2121, 1, 1), "ms", 4_765_132_800_000), + (datetime(1969, 12, 31, 23, 59, 59, 999999), "us", -1), + (datetime(1969, 12, 30, 23, 59, 59, 999999), "us", -86_400_000_001), + (datetime.min, "ns", -62_135_596_800_000_000_000), + (datetime.max, "ns", 253_402_300_799_999_999_000), + (datetime.min, "ms", -62_135_596_800_000), + (datetime.max, "ms", 253_402_300_799_999), + ], +) +def test_datetime_to_int(dt: datetime, time_unit: TimeUnit, expected: int) -> None: + assert datetime_to_int(dt, time_unit) == expected + + +@pytest.mark.parametrize( + ("dt", "expected"), + [ + ( + datetime(2000, 1, 1, 12, 0, tzinfo=None), + 946_728_000_000_000, + ), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=ZoneInfo("UTC")), + 946_728_000_000_000, + ), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + 946_699_200_000_000, + ), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=ZoneInfo("US/Central")), + 946_749_600_000_000, + ), + ], +) +def test_datetime_to_int_with_time_zone(dt: datetime, expected: int) -> None: + assert datetime_to_int(dt, "us") == expected + + +@pytest.mark.parametrize( + ("td", "time_unit", "expected"), + [ + (timedelta(days=1), "ns", 86_400_000_000_000), + (timedelta(days=1), "us", 86_400_000_000), + (timedelta(days=1), "ms", 86_400_000), + (timedelta.min, "ns", -86_399_999_913_600_000_000_000), + (timedelta.max, "ns", 86_399_999_999_999_999_999_000), + (timedelta.min, "ms", -86_399_999_913_600_000), + (timedelta.max, "ms", 86_399_999_999_999_999), + ], +) +def test_timedelta_to_int(td: timedelta, time_unit: TimeUnit, expected: int) -> None: + assert timedelta_to_int(td, time_unit) == expected + + +def test_estimated_size() -> None: + s = pl.Series("n", list(range(100))) + df = s.to_frame() + + for sz in (s.estimated_size(), s.estimated_size("b"), s.estimated_size("bytes")): + assert sz == df.estimated_size() + + assert s.estimated_size("kb") == (df.estimated_size("b") / 1024) + assert s.estimated_size("mb") == (df.estimated_size("kb") / 1024) + assert s.estimated_size("gb") == (df.estimated_size("mb") / 1024) + assert s.estimated_size("tb") == (df.estimated_size("gb") / 1024) + + with pytest.raises(ValueError): + s.estimated_size("milkshake") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("v1", "v2"), + [ + ("0.16.8", "0.16.7"), + ("23.0.0", (3, 1000)), + ((23, 0, 0), "3.1000"), + (("0", "0", "2beta"), "0.0.1"), + (("2", "5", "0", "1"), (2, 5, 0)), + ], +) +def test_parse_version(v1: Any, v2: Any) -> None: + assert parse_version(v1) > parse_version(v2) + assert parse_version(v2) < parse_version(v1) + + +@pytest.mark.slow +def test_in_notebook() -> None: + # private function, but easier to test this separately and mock it in the callers + assert not _in_notebook() + + +@pytest.mark.parametrize( + ("percentiles", "expected", "inject_median"), + [ + (None, [0.5], True), + (0.2, [0.2, 0.5], True), + (0.5, [0.5], True), + ((0.25, 0.75), [0.25, 0.5, 0.75], True), + # Undocumented effect - percentiles get sorted. + # Can be changed, this serves as documentation of current behaviour. + ((0.6, 0.3), [0.3, 0.5, 0.6], True), + (None, [], False), + (0.2, [0.2], False), + (0.5, [0.5], False), + ((0.25, 0.75), [0.25, 0.75], False), + ((0.6, 0.3), [0.3, 0.6], False), + ], +) +def test_parse_percentiles( + percentiles: Sequence[float] | float | None, + expected: Sequence[float], + inject_median: bool, +) -> None: + assert parse_percentiles(percentiles, inject_median=inject_median) == expected + + +@pytest.mark.parametrize(("percentiles"), [(1.1), ([-0.1])]) +def test_parse_percentiles_errors(percentiles: Sequence[float] | float | None) -> None: + with pytest.raises(ValueError): + parse_percentiles(percentiles) + + +@pytest.mark.parametrize( + ("sequence", "include_series", "expected"), + [ + (pl.Series(["xx", "yy"]), True, False), + (pl.Series([True, False]), False, False), + (pl.Series([True, False]), True, True), + (np.array([False, True]), False, True), + (np.array([False, True]), True, True), + ([True, False], False, True), + (["xx", "yy"], False, False), + (True, False, False), + ], +) +def test_is_bool_sequence_check( + sequence: Any, + include_series: bool, + expected: bool, +) -> None: + assert is_bool_sequence(sequence, include_series=include_series) == expected + if expected: + assert is_sequence(sequence, include_series=include_series) + + +@pytest.mark.parametrize( + ("sequence", "include_series", "expected"), + [ + (pl.Series(["xx", "yy"]), True, False), + (pl.Series([123, 345]), False, False), + (pl.Series([123, 345]), True, True), + (np.array([123, 345]), False, True), + (np.array([123, 345]), True, True), + (["xx", "yy"], False, False), + ([123, 456], False, True), + (123, False, False), + ], +) +def test_is_int_sequence_check( + sequence: Any, + include_series: bool, + expected: bool, +) -> None: + assert is_int_sequence(sequence, include_series=include_series) == expected + if expected: + assert is_sequence(sequence, include_series=include_series) + + +@pytest.mark.parametrize( + ("sequence", "include_series", "expected"), + [ + (pl.Series(["xx", "yy"]), False, False), + (pl.Series(["xx", "yy"]), True, True), + (pl.Series([123, 345]), True, False), + (np.array(["xx", "yy"]), False, True), + (np.array(["xx", "yy"]), True, True), + (["xx", "yy"], False, True), + ([123, 456], False, False), + ("xx", False, False), + ], +) +def test_is_str_sequence_check( + sequence: Any, + include_series: bool, + expected: bool, +) -> None: + assert is_str_sequence(sequence, include_series=include_series) == expected + if expected: + assert is_sequence(sequence, include_series=include_series) diff --git a/py-polars/tests/unit/utils/test_various.py b/py-polars/tests/unit/utils/test_various.py new file mode 100644 index 000000000000..a0bab09c11c5 --- /dev/null +++ b/py-polars/tests/unit/utils/test_various.py @@ -0,0 +1,10 @@ +import pytest + +from polars._utils.various import issue_warning +from polars.exceptions import PerformanceWarning + + +def test_issue_warning() -> None: + msg = "hello" + with pytest.warns(PerformanceWarning, match=msg): + issue_warning(msg, PerformanceWarning) diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000000..93879cb36a57 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2025-04-19" diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000000..f5c108d49ac9 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,4 @@ +group_imports = "StdExternalCrate" +imports_granularity = "Module" +match_block_trailing_comma = true +use_field_init_shorthand = true

This API page has moved here.

{ + primitive_type, + null_count: options.null_count.then_some(array.null_count() as i64), + distinct_count: None, + max_value, + min_value, + } +} diff --git a/crates/polars-parquet/src/arrow/write/primitive/mod.rs b/crates/polars-parquet/src/arrow/write/primitive/mod.rs new file mode 100644 index 000000000000..96318ab0a89b --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/primitive/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use basic::{array_to_page_integer, array_to_page_plain}; +pub(crate) use basic::{build_statistics, encode_plain}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/polars-parquet/src/arrow/write/primitive/nested.rs b/crates/polars-parquet/src/arrow/write/primitive/nested.rs new file mode 100644 index 000000000000..de911689f822 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/primitive/nested.rs @@ -0,0 +1,54 @@ +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType as ArrowNativeType; +use polars_error::PolarsResult; + +use super::super::{WriteOptions, nested, utils}; +use super::basic::{build_statistics, encode_plain}; +use crate::arrow::read::schema::is_nullable; +use crate::arrow::write::Nested; +use crate::parquet::encoding::Encoding; +use crate::parquet::page::DataPage; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::types::NativeType; +use crate::write::EncodeNullability; + +pub fn array_to_page( + array: &PrimitiveArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> PolarsResult +where + T: ArrowNativeType, + R: NativeType, + T: num_traits::AsPrimitive, +{ + let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); + + let mut buffer = vec![]; + + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + let buffer = encode_plain(array, encode_options, buffer); + + let statistics = if options.has_statistics() { + Some(build_statistics(array, type_.clone(), &options.statistics).serialize()) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/polars-parquet/src/arrow/write/row_group.rs b/crates/polars-parquet/src/arrow/write/row_group.rs new file mode 100644 index 000000000000..9033339e0f69 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/row_group.rs @@ -0,0 +1,130 @@ +use arrow::array::Array; +use arrow::datatypes::ArrowSchema; +use arrow::record_batch::RecordBatchT; +use polars_error::{PolarsError, PolarsResult, polars_bail, to_compute_err}; + +use super::{ + DynIter, DynStreamingIterator, Encoding, RowGroupIterColumns, SchemaDescriptor, WriteOptions, + array_to_columns, to_parquet_schema, +}; +use crate::parquet::FallibleStreamingIterator; +use crate::parquet::error::ParquetError; +use crate::parquet::schema::types::ParquetType; +use crate::parquet::write::Compressor; + +/// Maps a [`RecordBatchT`] and parquet-specific options to an [`RowGroupIterColumns`] used to +/// write to parquet +/// # Panics +/// Iff +/// * `encodings.len() != fields.len()` or +/// * `encodings.len() != chunk.arrays().len()` +pub fn row_group_iter + 'static + Send + Sync>( + chunk: RecordBatchT, + encodings: Vec>, + fields: Vec, + options: WriteOptions, +) -> RowGroupIterColumns<'static, PolarsError> { + assert_eq!(encodings.len(), fields.len()); + assert_eq!(encodings.len(), chunk.arrays().len()); + DynIter::new( + chunk + .into_arrays() + .into_iter() + .zip(fields) + .zip(encodings) + .flat_map(move |((array, type_), encoding)| { + let encoded_columns = array_to_columns(array, type_, options, &encoding).unwrap(); + encoded_columns + .into_iter() + .map(|encoded_pages| { + let pages = encoded_pages; + + let pages = DynIter::new( + pages + .into_iter() + .map(|x| x.map_err(|e| ParquetError::oos(e.to_string()))), + ); + + let compressed_pages = Compressor::new(pages, options.compression, vec![]) + .map_err(to_compute_err); + Ok(DynStreamingIterator::new(compressed_pages)) + }) + .collect::>() + }), + ) +} + +/// An iterator adapter that converts an iterator over [`RecordBatchT`] into an iterator +/// of row groups. +/// Use it to create an iterator consumable by the parquet's API. +pub struct RowGroupIterator< + A: AsRef + 'static, + I: Iterator>>, +> { + iter: I, + options: WriteOptions, + parquet_schema: SchemaDescriptor, + encodings: Vec>, +} + +impl + 'static, I: Iterator>>> + RowGroupIterator +{ + /// Creates a new [`RowGroupIterator`] from an iterator over [`RecordBatchT`]. + /// + /// # Errors + /// Iff + /// * the Arrow schema can't be converted to a valid Parquet schema. + /// * the length of the encodings is different from the number of fields in schema + pub fn try_new( + iter: I, + schema: &ArrowSchema, + options: WriteOptions, + encodings: Vec>, + ) -> PolarsResult { + if encodings.len() != schema.len() { + polars_bail!(InvalidOperation: + "The number of encodings must equal the number of fields".to_string(), + ) + } + let parquet_schema = to_parquet_schema(schema)?; + + Ok(Self { + iter, + options, + parquet_schema, + encodings, + }) + } + + /// Returns the [`SchemaDescriptor`] of the [`RowGroupIterator`]. + pub fn parquet_schema(&self) -> &SchemaDescriptor { + &self.parquet_schema + } +} + +impl + 'static + Send + Sync, I: Iterator>>> + Iterator for RowGroupIterator +{ + type Item = PolarsResult>; + + fn next(&mut self) -> Option { + let options = self.options; + + self.iter.next().map(|maybe_chunk| { + let chunk = maybe_chunk?; + if self.encodings.len() != chunk.arrays().len() { + polars_bail!(InvalidOperation: + "The number of arrays in the chunk must equal the number of fields in the schema" + ) + }; + let encodings = self.encodings.clone(); + Ok(row_group_iter( + chunk, + encodings, + self.parquet_schema.fields().to_vec(), + options, + )) + }) + } +} diff --git a/crates/polars-parquet/src/arrow/write/schema.rs b/crates/polars-parquet/src/arrow/write/schema.rs new file mode 100644 index 000000000000..105539d09745 --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/schema.rs @@ -0,0 +1,433 @@ +use arrow::datatypes::{ArrowDataType, ArrowSchema, ExtensionType, Field, TimeUnit}; +use arrow::io::ipc::write::{default_ipc_fields, schema_to_bytes}; +use base64::Engine as _; +use base64::engine::general_purpose; +use polars_error::{PolarsResult, polars_bail}; +use polars_utils::pl_str::PlSmallStr; + +use super::super::ARROW_SCHEMA_META_KEY; +use crate::arrow::write::decimal_length_from_precision; +use crate::parquet::metadata::KeyValue; +use crate::parquet::schema::Repetition; +use crate::parquet::schema::types::{ + GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType, + PrimitiveConvertedType, PrimitiveLogicalType, TimeUnit as ParquetTimeUnit, +}; + +fn convert_field(field: Field) -> Field { + Field { + name: field.name, + dtype: convert_dtype(field.dtype), + is_nullable: field.is_nullable, + metadata: field.metadata, + } +} + +fn convert_dtype(dtype: ArrowDataType) -> ArrowDataType { + use ArrowDataType as D; + match dtype { + D::LargeList(field) => D::LargeList(Box::new(convert_field(*field))), + D::Struct(mut fields) => { + for field in &mut fields { + *field = convert_field(std::mem::take(field)) + } + D::Struct(fields) + }, + D::BinaryView => D::LargeBinary, + D::Utf8View => D::LargeUtf8, + D::Dictionary(it, dtype, sorted) => { + let dtype = convert_dtype(*dtype); + D::Dictionary(it, Box::new(dtype), sorted) + }, + D::Extension(ext) => { + let dtype = convert_dtype(ext.inner); + D::Extension(Box::new(ExtensionType { + inner: dtype, + ..*ext + })) + }, + dt => dt, + } +} + +pub fn schema_to_metadata_key(schema: &ArrowSchema) -> KeyValue { + // Convert schema until more arrow readers are aware of binview + let serialized_schema = if schema.iter_values().any(|field| field.dtype.is_view()) { + let schema = schema + .iter_values() + .map(|field| convert_field(field.clone())) + .map(|x| (x.name.clone(), x)) + .collect(); + schema_to_bytes(&schema, &default_ipc_fields(schema.iter_values()), None) + } else { + schema_to_bytes(schema, &default_ipc_fields(schema.iter_values()), None) + }; + + // manually prepending the length to the schema as arrow uses the legacy IPC format + // TODO: change after addressing ARROW-9777 + let schema_len = serialized_schema.len(); + let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); + len_prefix_schema.extend_from_slice(&[255u8, 255, 255, 255]); + len_prefix_schema.extend_from_slice(&(schema_len as u32).to_le_bytes()); + len_prefix_schema.extend_from_slice(&serialized_schema); + + let encoded = general_purpose::STANDARD.encode(&len_prefix_schema); + + KeyValue { + key: ARROW_SCHEMA_META_KEY.to_string(), + value: Some(encoded), + } +} + +/// Creates a [`ParquetType`] from a [`Field`]. +pub fn to_parquet_type(field: &Field) -> PolarsResult { + let name = field.name.clone(); + let repetition = if field.is_nullable { + Repetition::Optional + } else { + Repetition::Required + }; + // create type from field + match field.dtype().to_logical_type() { + ArrowDataType::Null => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + Some(PrimitiveLogicalType::Unknown), + None, + )?), + ArrowDataType::Boolean => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Boolean, + repetition, + None, + None, + None, + )?), + ArrowDataType::Int32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + None, + None, + )?), + // ArrowDataType::Duration(_) has no parquet representation => do not apply any logical type + ArrowDataType::Int64 | ArrowDataType::Duration(_) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. + ArrowDataType::Date64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + ArrowDataType::Float32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Float, + repetition, + None, + None, + None, + )?), + ArrowDataType::Float64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Double, + repetition, + None, + None, + None, + )?), + ArrowDataType::Binary | ArrowDataType::LargeBinary | ArrowDataType::BinaryView => { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::ByteArray, + repetition, + None, + None, + None, + )?) + }, + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View => { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::ByteArray, + repetition, + Some(PrimitiveConvertedType::Utf8), + Some(PrimitiveLogicalType::String), + None, + )?) + }, + ArrowDataType::Date32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Date), + Some(PrimitiveLogicalType::Date), + None, + )?), + ArrowDataType::Int8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Int8), + Some(PrimitiveLogicalType::Integer(IntegerType::Int8)), + None, + )?), + ArrowDataType::Int16 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Int16), + Some(PrimitiveLogicalType::Integer(IntegerType::Int16)), + None, + )?), + ArrowDataType::UInt8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint8), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt8)), + None, + )?), + ArrowDataType::UInt16 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint16), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt16)), + None, + )?), + ArrowDataType::UInt32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint32), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt32)), + None, + )?), + ArrowDataType::UInt64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + Some(PrimitiveConvertedType::Uint64), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt64)), + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. + ArrowDataType::Timestamp(TimeUnit::Second, _) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + ArrowDataType::Timestamp(time_unit, zone) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + Some(PrimitiveLogicalType::Timestamp { + is_adjusted_to_utc: matches!(zone, Some(z) if !z.as_str().is_empty()), + unit: match time_unit { + TimeUnit::Second => unreachable!(), + TimeUnit::Millisecond => ParquetTimeUnit::Milliseconds, + TimeUnit::Microsecond => ParquetTimeUnit::Microseconds, + TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds, + }, + }), + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. + ArrowDataType::Time32(TimeUnit::Second) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + None, + None, + )?), + ArrowDataType::Time32(TimeUnit::Millisecond) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::TimeMillis), + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: ParquetTimeUnit::Milliseconds, + }), + None, + )?), + ArrowDataType::Time64(time_unit) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + match time_unit { + TimeUnit::Microsecond => Some(PrimitiveConvertedType::TimeMicros), + TimeUnit::Nanosecond => None, + _ => unreachable!(), + }, + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: match time_unit { + TimeUnit::Microsecond => ParquetTimeUnit::Microseconds, + TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds, + _ => unreachable!(), + }, + }), + None, + )?), + ArrowDataType::Struct(fields) => { + if fields.is_empty() { + polars_bail!(InvalidOperation: + "Unable to write struct type with no child field to Parquet. Consider adding a dummy child field.".to_string(), + ) + } + // recursively convert children to types/nodes + let fields = fields + .iter() + .map(to_parquet_type) + .collect::>>()?; + Ok(ParquetType::from_group( + name, repetition, None, None, fields, None, + )) + }, + ArrowDataType::Dictionary(_, value, _) => { + let dict_field = Field::new(name.clone(), value.as_ref().clone(), field.is_nullable); + to_parquet_type(&dict_field) + }, + ArrowDataType::FixedSizeBinary(size) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(*size), + repetition, + None, + None, + None, + )?), + ArrowDataType::Decimal(precision, scale) => { + let precision = *precision; + let scale = *scale; + let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale)); + + let physical_type = if precision <= 9 { + PhysicalType::Int32 + } else if precision <= 18 { + PhysicalType::Int64 + } else { + let len = decimal_length_from_precision(precision); + PhysicalType::FixedLenByteArray(len) + }; + Ok(ParquetType::try_from_primitive( + name, + physical_type, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + }, + ArrowDataType::Decimal256(precision, scale) => { + let precision = *precision; + let scale = *scale; + let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale)); + + if precision <= 9 { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else if precision <= 18 { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else if precision <= 38 { + let len = decimal_length_from_precision(precision); + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(len), + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(32), + repetition, + None, + None, + None, + )?) + } + }, + ArrowDataType::Interval(_) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(12), + repetition, + Some(PrimitiveConvertedType::Interval), + None, + None, + )?), + ArrowDataType::List(f) + | ArrowDataType::FixedSizeList(f, _) + | ArrowDataType::LargeList(f) => { + let mut f = f.clone(); + f.name = PlSmallStr::from_static("element"); + + Ok(ParquetType::from_group( + name, + repetition, + Some(GroupConvertedType::List), + Some(GroupLogicalType::List), + vec![ParquetType::from_group( + PlSmallStr::from_static("list"), + Repetition::Repeated, + None, + None, + vec![to_parquet_type(&f)?], + None, + )], + None, + )) + }, + ArrowDataType::Map(f, _) => Ok(ParquetType::from_group( + name, + repetition, + Some(GroupConvertedType::Map), + Some(GroupLogicalType::Map), + vec![ParquetType::from_group( + PlSmallStr::from_static("map"), + Repetition::Repeated, + None, + None, + vec![to_parquet_type(f)?], + None, + )], + None, + )), + other => polars_bail!(nyi = "Writing the data type {other:?} is not yet implemented"), + } +} diff --git a/crates/polars-parquet/src/arrow/write/utils.rs b/crates/polars-parquet/src/arrow/write/utils.rs new file mode 100644 index 000000000000..e574bb8275fa --- /dev/null +++ b/crates/polars-parquet/src/arrow/write/utils.rs @@ -0,0 +1,151 @@ +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::*; + +use super::{Version, WriteOptions}; +use crate::parquet::CowBuffer; +use crate::parquet::compression::CompressionOptions; +use crate::parquet::encoding::Encoding; +use crate::parquet::encoding::hybrid_rle::{self, encode}; +use crate::parquet::metadata::Descriptor; +use crate::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, DataPageHeaderV2}; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::statistics::ParquetStatistics; + +/// writes the def levels to a `Vec` and returns it. +pub fn write_def_levels( + writer: &mut Vec, + is_optional: bool, + validity: Option<&Bitmap>, + len: usize, + version: Version, +) -> PolarsResult<()> { + if is_optional { + match version { + Version::V1 => { + writer.extend(&[0, 0, 0, 0]); + let start = writer.len(); + + match validity { + None => >::run_length_encode( + writer, len, true, 1, + )?, + Some(validity) => encode::(writer, validity.iter(), 1)?, + } + + // write the first 4 bytes as length + let length = ((writer.len() - start) as i32).to_le_bytes(); + (0..4).for_each(|i| writer[start - 4 + i] = length[i]); + }, + Version::V2 => match validity { + None => { + >::run_length_encode(writer, len, true, 1)? + }, + Some(validity) => encode::(writer, validity.iter(), 1)?, + }, + } + + Ok(()) + } else { + // is required => no def levels + Ok(()) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn build_plain_page( + buffer: Vec, + num_values: usize, + num_rows: usize, + null_count: usize, + repetition_levels_byte_length: usize, + definition_levels_byte_length: usize, + statistics: Option, + type_: PrimitiveType, + options: WriteOptions, + encoding: Encoding, +) -> PolarsResult { + let header = match options.version { + Version::V1 => DataPageHeader::V1(DataPageHeaderV1 { + num_values: num_values as i32, + encoding: encoding.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }), + Version::V2 => DataPageHeader::V2(DataPageHeaderV2 { + num_values: num_values as i32, + encoding: encoding.into(), + num_nulls: null_count as i32, + num_rows: num_rows as i32, + definition_levels_byte_length: definition_levels_byte_length as i32, + repetition_levels_byte_length: repetition_levels_byte_length as i32, + is_compressed: Some(options.compression != CompressionOptions::Uncompressed), + statistics, + }), + }; + Ok(DataPage::new( + header, + CowBuffer::Owned(buffer), + Descriptor { + primitive_type: type_, + max_def_level: 0, + max_rep_level: 0, + }, + num_rows, + )) +} + +/// Auxiliary iterator adapter to declare the size hint of an iterator. +pub(super) struct ExactSizedIter> { + iter: I, + remaining: usize, +} + +impl + Clone> Clone for ExactSizedIter { + fn clone(&self) -> Self { + Self { + iter: self.iter.clone(), + remaining: self.remaining, + } + } +} + +impl> ExactSizedIter { + pub fn new(iter: I, length: usize) -> Self { + Self { + iter, + remaining: length, + } + } +} + +impl> Iterator for ExactSizedIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().inspect(|_| self.remaining -= 1) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +impl> std::iter::ExactSizeIterator for ExactSizedIter {} + +/// Returns the number of bits needed to bitpack `max` +#[inline] +pub fn get_bit_width(max: u64) -> u32 { + 64 - max.leading_zeros() +} + +pub(super) fn invalid_encoding(encoding: Encoding, dtype: &ArrowDataType) -> PolarsError { + polars_err!(InvalidOperation: + "Datatype {:?} cannot be encoded by {:?} encoding", + dtype, + encoding + ) +} diff --git a/crates/polars-parquet/src/lib.rs b/crates/polars-parquet/src/lib.rs new file mode 100644 index 000000000000..c429e83ad328 --- /dev/null +++ b/crates/polars-parquet/src/lib.rs @@ -0,0 +1,5 @@ +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![allow(clippy::len_without_is_empty)] +pub mod arrow; +pub use crate::arrow::{read, write}; +pub mod parquet; diff --git a/crates/polars-parquet/src/parquet/bloom_filter/hash.rs b/crates/polars-parquet/src/parquet/bloom_filter/hash.rs new file mode 100644 index 000000000000..c535faa44d76 --- /dev/null +++ b/crates/polars-parquet/src/parquet/bloom_filter/hash.rs @@ -0,0 +1,17 @@ +use xxhash_rust::xxh64::xxh64; + +use crate::parquet::types::NativeType; + +const SEED: u64 = 0; + +/// (xxh64) hash of a [`NativeType`]. +#[inline] +pub fn hash_native(value: T) -> u64 { + xxh64(value.to_le_bytes().as_ref(), SEED) +} + +/// (xxh64) hash of a sequence of bytes (e.g. ByteArray). +#[inline] +pub fn hash_byte>(value: A) -> u64 { + xxh64(value.as_ref(), SEED) +} diff --git a/crates/polars-parquet/src/parquet/bloom_filter/mod.rs b/crates/polars-parquet/src/parquet/bloom_filter/mod.rs new file mode 100644 index 000000000000..218715d7ac5f --- /dev/null +++ b/crates/polars-parquet/src/parquet/bloom_filter/mod.rs @@ -0,0 +1,71 @@ +//! API to read and use bloom filters +mod hash; +mod read; +mod split_block; + +pub use hash::{hash_byte, hash_native}; +pub use read::read; +pub use split_block::{insert, is_in_set}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let mut bitset = vec![0; 32]; + + // insert + for a in 0..10i64 { + let hash = hash_native(a); + insert(&mut bitset, hash); + } + + // bloom filter produced by parquet-mr/spark for a column of i64 (0..=10) + /* + import pyspark.sql // 3.2.1 + spark = pyspark.sql.SparkSession.builder.getOrCreate() + spark.conf.set("parquet.bloom.filter.enabled", True) + spark.conf.set("parquet.bloom.filter.expected.ndv", 10) + spark.conf.set("parquet.bloom.filter.max.bytes", 32) + + data = [(i % 10,) for i in range(100)] + df = spark.createDataFrame(data, ["id"]).repartition(1) + + df.write.parquet("bla.parquet", mode = "overwrite") + */ + let expected: &[u8] = &[ + 24, 130, 24, 8, 134, 8, 68, 6, 2, 101, 128, 10, 64, 2, 38, 78, 114, 1, 64, 38, 1, 192, + 194, 152, 64, 70, 0, 36, 56, 121, 64, 0, + ]; + assert_eq!(bitset, expected); + + // check + for a in 0..11i64 { + let hash = hash_native(a); + + let valid = is_in_set(&bitset, hash); + + assert_eq!(a < 10, valid); + } + } + + #[test] + fn binary() { + let mut bitset = vec![0; 32]; + + // insert + for a in 0..10i64 { + let value = format!("a{}", a); + let hash = hash_byte(value); + insert(&mut bitset, hash); + } + + // bloom filter produced by parquet-mr/spark for a column of i64 f"a{i}" for i in 0..10 + let expected: &[u8] = &[ + 200, 1, 80, 20, 64, 68, 8, 109, 6, 37, 4, 67, 144, 80, 96, 32, 8, 132, 43, 33, 0, 5, + 99, 65, 2, 0, 224, 44, 64, 78, 96, 4, + ]; + assert_eq!(bitset, expected); + } +} diff --git a/crates/polars-parquet/src/parquet/bloom_filter/read.rs b/crates/polars-parquet/src/parquet/bloom_filter/read.rs new file mode 100644 index 000000000000..fe4ee718cb7e --- /dev/null +++ b/crates/polars-parquet/src/parquet/bloom_filter/read.rs @@ -0,0 +1,51 @@ +use std::io::{Read, Seek, SeekFrom}; + +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; +use polars_parquet_format::{ + BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHeader, SplitBlockAlgorithm, + Uncompressed, +}; + +use crate::parquet::error::ParquetResult; +use crate::parquet::metadata::ColumnChunkMetadata; + +/// Reads the bloom filter associated to [`ColumnChunkMetadata`] into `bitset`. +/// Results in an empty `bitset` if there is no associated bloom filter or the algorithm is not supported. +/// # Error +/// Errors if the column contains no metadata or the filter can't be read or deserialized. +pub fn read( + column_metadata: &ColumnChunkMetadata, + mut reader: &mut R, + bitset: &mut Vec, +) -> ParquetResult<()> { + let offset = column_metadata.metadata().bloom_filter_offset; + + let offset = if let Some(offset) = offset { + offset as u64 + } else { + bitset.clear(); + return Ok(()); + }; + reader.seek(SeekFrom::Start(offset))?; + + // deserialize header + let mut prot = TCompactInputProtocol::new(&mut reader, usize::MAX); // max is ok since `BloomFilterHeader` never allocates + let header = BloomFilterHeader::read_from_in_protocol(&mut prot)?; + + if header.algorithm != BloomFilterAlgorithm::BLOCK(SplitBlockAlgorithm {}) { + bitset.clear(); + return Ok(()); + } + if header.compression != BloomFilterCompression::UNCOMPRESSED(Uncompressed {}) { + bitset.clear(); + return Ok(()); + } + + let length: usize = header.num_bytes.try_into()?; + + bitset.clear(); + bitset.try_reserve(length)?; + reader.by_ref().take(length as u64).read_to_end(bitset)?; + + Ok(()) +} diff --git a/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs b/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs new file mode 100644 index 000000000000..e8672648cda4 --- /dev/null +++ b/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs @@ -0,0 +1,80 @@ +/// magic numbers taken from https://github.com/apache/parquet-format/blob/master/BloomFilter.md +const SALT: [u32; 8] = [ + 1203114875, 1150766481, 2284105051, 2729912477, 1884591559, 770785867, 2667333959, 1550580529, +]; + +fn hash_to_block_index(hash: u64, len: usize) -> usize { + let number_of_blocks = len as u64 / 32; + let low_hash = hash >> 32; + let block_index = ((low_hash * number_of_blocks) >> 32) as u32; + block_index as usize +} + +fn new_mask(x: u32) -> [u32; 8] { + let mut a = [0u32; 8]; + for i in 0..8 { + let mask = x.wrapping_mul(SALT[i]); + let mask = mask >> 27; + let mask = 0x1 << mask; + a[i] = mask; + } + a +} + +/// loads a block from the bitset to the stack +#[inline] +fn load_block(bitset: &[u8]) -> [u32; 8] { + let mut a = [0u32; 8]; + let bitset = bitset.chunks_exact(4).take(8); + for (a, chunk) in a.iter_mut().zip(bitset) { + *a = u32::from_le_bytes(chunk.try_into().unwrap()) + } + a +} + +/// assigns a block from the stack to `bitset` +#[inline] +fn unload_block(block: [u32; 8], bitset: &mut [u8]) { + let bitset = bitset.chunks_exact_mut(4).take(8); + for (a, chunk) in block.iter().zip(bitset) { + let a = a.to_le_bytes(); + chunk[0] = a[0]; + chunk[1] = a[1]; + chunk[2] = a[2]; + chunk[3] = a[3]; + } +} + +/// Returns whether the `hash` is in the set +pub fn is_in_set(bitset: &[u8], hash: u64) -> bool { + let block_index = hash_to_block_index(hash, bitset.len()); + let key = hash as u32; + + let mask = new_mask(key); + let slice = &bitset[block_index * 32..(block_index + 1) * 32]; + let block_mask = load_block(slice); + + for i in 0..8 { + if mask[i] & block_mask[i] == 0 { + return false; + } + } + true +} + +/// Inserts a new hash to the set +pub fn insert(bitset: &mut [u8], hash: u64) { + let block_index = hash_to_block_index(hash, bitset.len()); + let key = hash as u32; + + let mask = new_mask(key); + let slice = &bitset[block_index * 32..(block_index + 1) * 32]; + let mut block_mask = load_block(slice); + + for i in 0..8 { + block_mask[i] |= mask[i]; + + let mut_slice = &mut bitset[block_index * 32..(block_index + 1) * 32]; + unload_block(block_mask, mut_slice) + } +} diff --git a/crates/polars-parquet/src/parquet/compression.rs b/crates/polars-parquet/src/parquet/compression.rs new file mode 100644 index 000000000000..c029f93435c2 --- /dev/null +++ b/crates/polars-parquet/src/parquet/compression.rs @@ -0,0 +1,397 @@ +//! Functionality to compress and decompress data according to the parquet specification +pub use super::parquet_bridge::{ + BrotliLevel, Compression, CompressionOptions, GzipLevel, ZstdLevel, +}; +use crate::parquet::error::{ParquetError, ParquetResult}; + +#[cfg(any(feature = "snappy", feature = "lz4"))] +fn inner_compress< + G: Fn(usize) -> ParquetResult, + F: Fn(&[u8], &mut [u8]) -> ParquetResult, +>( + input: &[u8], + output: &mut Vec, + get_length: G, + compress: F, +) -> ParquetResult<()> { + let original_length = output.len(); + let max_required_length = get_length(input.len())?; + + output.resize(original_length + max_required_length, 0); + let compressed_size = compress(input, &mut output[original_length..])?; + + output.truncate(original_length + compressed_size); + Ok(()) +} + +/// Compresses data stored in slice `input_buf` and writes the compressed result +/// to `output_buf`. +/// +/// Note that you'll need to call `clear()` before reusing the same `output_buf` +/// across different `compress` calls. +#[allow(unused_variables)] +pub fn compress( + compression: CompressionOptions, + input_buf: &[u8], + #[allow(clippy::ptr_arg)] output_buf: &mut Vec, +) -> ParquetResult<()> { + match compression { + #[cfg(feature = "brotli")] + CompressionOptions::Brotli(level) => { + use std::io::Write; + const BROTLI_DEFAULT_BUFFER_SIZE: usize = 4096; + const BROTLI_DEFAULT_LG_WINDOW_SIZE: u32 = 22; // recommended between 20-22 + + let q = level.unwrap_or_default(); + let mut encoder = brotli::CompressorWriter::new( + output_buf, + BROTLI_DEFAULT_BUFFER_SIZE, + q.compression_level(), + BROTLI_DEFAULT_LG_WINDOW_SIZE, + ); + encoder.write_all(input_buf)?; + encoder.flush().map_err(|e| e.into()) + }, + #[cfg(not(feature = "brotli"))] + CompressionOptions::Brotli(_) => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Brotli, + "compress to brotli".to_string(), + )), + #[cfg(feature = "gzip")] + CompressionOptions::Gzip(level) => { + use std::io::Write; + let level = level.unwrap_or_default(); + let mut encoder = flate2::write::GzEncoder::new(output_buf, level.into()); + encoder.write_all(input_buf)?; + encoder.try_finish().map_err(|e| e.into()) + }, + #[cfg(not(feature = "gzip"))] + CompressionOptions::Gzip(_) => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Gzip, + "compress to gzip".to_string(), + )), + #[cfg(feature = "snappy")] + CompressionOptions::Snappy => inner_compress( + input_buf, + output_buf, + |len| Ok(snap::raw::max_compress_len(len)), + |input, output| Ok(snap::raw::Encoder::new().compress(input, output)?), + ), + #[cfg(not(feature = "snappy"))] + CompressionOptions::Snappy => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Snappy, + "compress to snappy".to_string(), + )), + #[cfg(feature = "lz4")] + CompressionOptions::Lz4Raw => inner_compress( + input_buf, + output_buf, + |len| Ok(lz4::block::compress_bound(len)?), + |input, output| { + let compressed_size = lz4::block::compress_to_buffer(input, None, false, output)?; + Ok(compressed_size) + }, + ), + #[cfg(all(not(feature = "lz4"), not(feature = "lz4_flex")))] + CompressionOptions::Lz4Raw => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Lz4, + "compress to lz4".to_string(), + )), + #[cfg(feature = "zstd")] + CompressionOptions::Zstd(level) => { + let level = level.map(|v| v.compression_level()).unwrap_or_default(); + // Make sure the buffer is large enough; the interface assumption is + // that decompressed data is appended to the output buffer. + let old_len = output_buf.len(); + output_buf.resize( + old_len + zstd::zstd_safe::compress_bound(input_buf.len()), + 0, + ); + match zstd::bulk::compress_to_buffer(input_buf, &mut output_buf[old_len..], level) { + Ok(written_size) => { + output_buf.truncate(old_len + written_size); + Ok(()) + }, + Err(e) => Err(e.into()), + } + }, + #[cfg(not(feature = "zstd"))] + CompressionOptions::Zstd(_) => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Zstd, + "compress to zstd".to_string(), + )), + CompressionOptions::Uncompressed => Err(ParquetError::InvalidParameter( + "Compressing uncompressed".to_string(), + )), + _ => Err(ParquetError::FeatureNotSupported(format!( + "Compression {:?} is not supported", + compression, + ))), + } +} + +/// Decompresses data stored in slice `input_buf` and writes output to `output_buf`. +/// Returns the total number of bytes written. +#[allow(unused_variables)] +pub fn decompress( + compression: Compression, + input_buf: &[u8], + output_buf: &mut [u8], +) -> ParquetResult<()> { + match compression { + #[cfg(feature = "brotli")] + Compression::Brotli => { + use std::io::Read; + const BROTLI_DEFAULT_BUFFER_SIZE: usize = 4096; + brotli::Decompressor::new(input_buf, BROTLI_DEFAULT_BUFFER_SIZE) + .read_exact(output_buf) + .map_err(|e| e.into()) + }, + #[cfg(not(feature = "brotli"))] + Compression::Brotli => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Brotli, + "decompress with brotli".to_string(), + )), + #[cfg(feature = "gzip")] + Compression::Gzip => { + use std::io::Read; + let mut decoder = flate2::read::GzDecoder::new(input_buf); + decoder.read_exact(output_buf).map_err(|e| e.into()) + }, + #[cfg(not(feature = "gzip"))] + Compression::Gzip => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Gzip, + "decompress with gzip".to_string(), + )), + #[cfg(feature = "snappy")] + Compression::Snappy => { + use snap::raw::{Decoder, decompress_len}; + + let len = decompress_len(input_buf)?; + if len > output_buf.len() { + return Err(ParquetError::oos("snappy header out of spec")); + } + Decoder::new() + .decompress(input_buf, output_buf) + .map_err(|e| e.into()) + .map(|_| ()) + }, + #[cfg(not(feature = "snappy"))] + Compression::Snappy => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Snappy, + "decompress with snappy".to_string(), + )), + #[cfg(all(feature = "lz4_flex", not(feature = "lz4")))] + Compression::Lz4Raw => lz4_flex::block::decompress_into(input_buf, output_buf) + .map(|_| {}) + .map_err(|e| e.into()), + #[cfg(feature = "lz4")] + Compression::Lz4Raw => { + lz4::block::decompress_to_buffer(input_buf, Some(output_buf.len() as i32), output_buf) + .map(|_| {}) + .map_err(|e| e.into()) + }, + #[cfg(all(not(feature = "lz4"), not(feature = "lz4_flex")))] + Compression::Lz4Raw => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Lz4, + "decompress with lz4".to_string(), + )), + + #[cfg(any(feature = "lz4_flex", feature = "lz4"))] + Compression::Lz4 => try_decompress_hadoop(input_buf, output_buf).or_else(|_| { + lz4_decompress_to_buffer(input_buf, Some(output_buf.len() as i32), output_buf) + .map(|_| {}) + }), + + #[cfg(all(not(feature = "lz4_flex"), not(feature = "lz4")))] + Compression::Lz4 => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Lz4, + "decompress with legacy lz4".to_string(), + )), + + #[cfg(feature = "zstd")] + Compression::Zstd => { + use std::io::Read; + let mut decoder = zstd::Decoder::with_buffer(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) + }, + #[cfg(not(feature = "zstd"))] + Compression::Zstd => Err(ParquetError::FeatureNotActive( + crate::parquet::error::Feature::Zstd, + "decompress with zstd".to_string(), + )), + Compression::Uncompressed => Err(ParquetError::InvalidParameter( + "Compressing uncompressed".to_string(), + )), + _ => Err(ParquetError::FeatureNotSupported(format!( + "Compression {:?} is not supported", + compression, + ))), + } +} + +/// Try to decompress the buffer as if it was compressed with the Hadoop Lz4Codec. +/// Translated from the apache arrow c++ function [TryDecompressHadoop](https://github.com/apache/arrow/blob/bf18e6e4b5bb6180706b1ba0d597a65a4ce5ca48/cpp/src/arrow/util/compression_lz4.cc#L474). +/// Returns error if decompression failed. +#[cfg(any(feature = "lz4", feature = "lz4_flex"))] +fn try_decompress_hadoop(input_buf: &[u8], output_buf: &mut [u8]) -> ParquetResult<()> { + // Parquet files written with the Hadoop Lz4Codec use their own framing. + // The input buffer can contain an arbitrary number of "frames", each + // with the following structure: + // - bytes 0..3: big-endian uint32_t representing the frame decompressed size + // - bytes 4..7: big-endian uint32_t representing the frame compressed size + // - bytes 8...: frame compressed data + // + // The Hadoop Lz4Codec source code can be found here: + // https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/Lz4Codec.cc + + const SIZE_U32: usize = size_of::(); + const PREFIX_LEN: usize = SIZE_U32 * 2; + let mut input_len = input_buf.len(); + let mut input = input_buf; + let mut output_len = output_buf.len(); + let mut output: &mut [u8] = output_buf; + while input_len >= PREFIX_LEN { + let mut bytes = [0; SIZE_U32]; + bytes.copy_from_slice(&input[0..4]); + let expected_decompressed_size = u32::from_be_bytes(bytes); + let mut bytes = [0; SIZE_U32]; + bytes.copy_from_slice(&input[4..8]); + let expected_compressed_size = u32::from_be_bytes(bytes); + input = &input[PREFIX_LEN..]; + input_len -= PREFIX_LEN; + + if input_len < expected_compressed_size as usize { + return Err(ParquetError::oos("Not enough bytes for Hadoop frame")); + } + + if output_len < expected_decompressed_size as usize { + return Err(ParquetError::oos( + "Not enough bytes to hold advertised output", + )); + } + let decompressed_size = lz4_decompress_to_buffer( + &input[..expected_compressed_size as usize], + Some(output_len as i32), + output, + )?; + if decompressed_size != expected_decompressed_size as usize { + return Err(ParquetError::oos("unexpected decompressed size")); + } + input_len -= expected_compressed_size as usize; + output_len -= expected_decompressed_size as usize; + if input_len > expected_compressed_size as usize { + input = &input[expected_compressed_size as usize..]; + output = &mut output[expected_decompressed_size as usize..]; + } else { + break; + } + } + if input_len == 0 { + Ok(()) + } else { + Err(ParquetError::oos("Not all input are consumed")) + } +} + +#[cfg(feature = "lz4")] +#[inline] +fn lz4_decompress_to_buffer( + src: &[u8], + uncompressed_size: Option, + buffer: &mut [u8], +) -> ParquetResult { + let size = lz4::block::decompress_to_buffer(src, uncompressed_size, buffer)?; + Ok(size) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_roundtrip(c: CompressionOptions, data: &[u8]) { + let offset = 2048; + + // Compress to a buffer that already has data is possible + let mut compressed = vec![2; offset]; + compress(c, data, &mut compressed).expect("Error when compressing"); + + // data is compressed... + assert!(compressed.len() - offset < data.len()); + + let mut decompressed = vec![0; data.len()]; + decompress(c.into(), &compressed[offset..], &mut decompressed) + .expect("Error when decompressing"); + assert_eq!(data, decompressed.as_slice()); + } + + fn test_codec(c: CompressionOptions) { + let sizes = vec![1000, 10000, 100000]; + for size in sizes { + let data = (0..size).map(|x| (x % 255) as u8).collect::>(); + test_roundtrip(c, &data); + } + } + + #[test] + fn test_codec_snappy() { + test_codec(CompressionOptions::Snappy); + } + + #[test] + fn test_codec_gzip_default() { + test_codec(CompressionOptions::Gzip(None)); + } + + #[test] + fn test_codec_gzip_low_compression() { + test_codec(CompressionOptions::Gzip(Some( + GzipLevel::try_new(1).unwrap(), + ))); + } + + #[test] + fn test_codec_brotli_default() { + test_codec(CompressionOptions::Brotli(None)); + } + + #[test] + fn test_codec_brotli_low_compression() { + test_codec(CompressionOptions::Brotli(Some( + BrotliLevel::try_new(1).unwrap(), + ))); + } + + #[test] + fn test_codec_brotli_high_compression() { + test_codec(CompressionOptions::Brotli(Some( + BrotliLevel::try_new(11).unwrap(), + ))); + } + + #[test] + fn test_codec_lz4_raw() { + test_codec(CompressionOptions::Lz4Raw); + } + + #[test] + fn test_codec_zstd_default() { + test_codec(CompressionOptions::Zstd(None)); + } + + #[cfg(feature = "zstd")] + #[test] + fn test_codec_zstd_low_compression() { + test_codec(CompressionOptions::Zstd(Some( + ZstdLevel::try_new(1).unwrap(), + ))); + } + + #[cfg(feature = "zstd")] + #[test] + fn test_codec_zstd_high_compression() { + test_codec(CompressionOptions::Zstd(Some( + ZstdLevel::try_new(21).unwrap(), + ))); + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs new file mode 100644 index 000000000000..16052359886d --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs @@ -0,0 +1,415 @@ +use polars_utils::chunks::Chunks; + +use super::{Packed, Unpackable, Unpacked}; +use crate::parquet::error::{ParquetError, ParquetResult}; + +/// An [`Iterator`] of [`Unpackable`] unpacked from a bitpacked slice of bytes. +/// # Implementation +/// This iterator unpacks bytes in chunks and does not allocate. +#[derive(Debug, Clone)] +pub struct Decoder<'a, T: Unpackable> { + packed: Chunks<'a, u8>, + num_bits: usize, + /// number of items + pub(crate) length: usize, + _pd: std::marker::PhantomData, +} + +impl Default for Decoder<'_, T> { + fn default() -> Self { + Self { + packed: Chunks::new(&[], 1), + num_bits: 0, + length: 0, + _pd: std::marker::PhantomData, + } + } +} + +#[inline] +fn decode_pack(packed: &[u8], num_bits: usize, unpacked: &mut T::Unpacked) { + if packed.len() < T::Unpacked::LENGTH * num_bits / 8 { + let mut buf = T::Packed::zero(); + buf.as_mut()[..packed.len()].copy_from_slice(packed); + T::unpack(buf.as_ref(), num_bits, unpacked) + } else { + T::unpack(packed, num_bits, unpacked) + } +} + +impl<'a, T: Unpackable> Decoder<'a, T> { + /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. + pub fn new(packed: &'a [u8], num_bits: usize, length: usize) -> Self { + Self::try_new(packed, num_bits, length).unwrap() + } + + /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. + /// + /// `num_bits` is allowed to be `0`. + pub fn new_allow_zero(packed: &'a [u8], num_bits: usize, length: usize) -> Self { + Self::try_new_allow_zero(packed, num_bits, length).unwrap() + } + + /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. + /// + /// `num_bits` is allowed to be `0`. + pub fn try_new_allow_zero( + packed: &'a [u8], + num_bits: usize, + length: usize, + ) -> ParquetResult { + let block_size = size_of::() * num_bits; + + if packed.len() * 8 < length * num_bits { + return Err(ParquetError::oos(format!( + "Unpacking {length} items with a number of bits {num_bits} requires at least {} bytes.", + length * num_bits / 8 + ))); + } + + debug_assert!(num_bits != 0 || packed.is_empty()); + let block_size = block_size.max(1); + let packed = Chunks::new(packed, block_size); + + Ok(Self { + length, + packed, + num_bits, + _pd: Default::default(), + }) + } + + /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. + pub fn try_new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { + let block_size = size_of::() * num_bits; + + if num_bits == 0 { + return Err(ParquetError::oos("Bitpacking requires num_bits > 0")); + } + + if packed.len() * 8 < length * num_bits { + return Err(ParquetError::oos(format!( + "Unpacking {length} items with a number of bits {num_bits} requires at least {} bytes.", + length * num_bits / 8 + ))); + } + + let packed = Chunks::new(packed, block_size); + + Ok(Self { + length, + packed, + num_bits, + _pd: Default::default(), + }) + } + + pub fn num_bits(&self) -> usize { + self.num_bits + } + + pub fn as_slice(&self) -> &[u8] { + self.packed.as_slice() + } + + pub fn lower_element(self) -> ParquetResult> { + let packed = self.packed.as_slice(); + Decoder::try_new(packed, self.num_bits, self.length) + } +} + +/// A iterator over the exact chunks in a [`Decoder`]. +/// +/// The remainder can be accessed using `remainder` or `next_inexact`. +#[derive(Debug)] +pub struct ChunkedDecoder<'a, 'b, T: Unpackable> { + pub(crate) decoder: &'b mut Decoder<'a, T>, +} + +impl Iterator for ChunkedDecoder<'_, '_, T> { + type Item = T::Unpacked; + + #[inline] + fn next(&mut self) -> Option { + if self.decoder.len() < T::Unpacked::LENGTH { + return None; + } + + let mut unpacked = T::Unpacked::zero(); + self.next_into(&mut unpacked)?; + Some(unpacked) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.decoder.len() / T::Unpacked::LENGTH; + (len, Some(len)) + } +} + +impl ExactSizeIterator for ChunkedDecoder<'_, '_, T> {} + +impl ChunkedDecoder<'_, '_, T> { + /// Get and consume the remainder chunk if it exists. + /// + /// This should only be called after all the chunks full are consumed. + pub fn remainder(&mut self) -> Option<(T::Unpacked, usize)> { + if self.decoder.len() == 0 { + return None; + } + + debug_assert!(self.decoder.len() < T::Unpacked::LENGTH); + let remainder_len = self.decoder.len() % T::Unpacked::LENGTH; + + let mut unpacked = T::Unpacked::zero(); + let packed = self.decoder.packed.next()?; + decode_pack::(packed, self.decoder.num_bits, &mut unpacked); + self.decoder.length -= remainder_len; + Some((unpacked, remainder_len)) + } + + /// Get the next (possibly partial) chunk and its filled length + pub fn next_inexact(&mut self) -> Option<(T::Unpacked, usize)> { + if self.decoder.len() >= T::Unpacked::LENGTH { + Some((self.next().unwrap(), T::Unpacked::LENGTH)) + } else { + self.remainder() + } + } + + /// Consume the next chunk into `unpacked`. + pub fn next_into(&mut self, unpacked: &mut T::Unpacked) -> Option { + if self.decoder.len() == 0 { + return None; + } + + let unpacked_len = self.decoder.len().min(T::Unpacked::LENGTH); + let packed = self.decoder.packed.next()?; + decode_pack::(packed, self.decoder.num_bits, unpacked); + self.decoder.length -= unpacked_len; + + Some(unpacked_len) + } +} + +impl<'a, T: Unpackable> Decoder<'a, T> { + pub fn chunked<'b>(&'b mut self) -> ChunkedDecoder<'a, 'b, T> { + ChunkedDecoder { decoder: self } + } + + pub fn len(&self) -> usize { + self.length + } + + pub fn skip_chunks(&mut self, n: usize) { + debug_assert!(n * T::Unpacked::LENGTH <= self.length); + + for _ in (&mut self.packed).take(n) {} + self.length -= n * T::Unpacked::LENGTH; + } + + pub fn take(&mut self) -> Self { + let block_size = self.packed.chunk_size(); + let packed = std::mem::replace(&mut self.packed, Chunks::new(&[], block_size)); + let length = self.length; + self.length = 0; + + Self { + packed, + num_bits: self.num_bits, + length, + _pd: Default::default(), + } + } + + #[inline] + pub fn collect_into(mut self, vec: &mut Vec) { + // @NOTE: + // When microbenchmarking changing this from a element-wise iterator to a collect into + // improves the speed by around 4x. + // + // The unsafe code here allows us to not have to do a double memcopy. This saves us 20% in + // our microbenchmark. + // + // GB: I did some profiling on this function using the Yellow NYC Taxi dataset. There, the + // average self.length is ~52.8 and the average num_packs is ~2.2. Let this guide your + // decisions surrounding the optimization of this function. + + // @NOTE: + // Since T::Unpacked::LENGTH is always a power of two and known at compile time. Division, + // modulo and multiplication are just trivial operators. + let num_packs = (self.length / T::Unpacked::LENGTH) + + usize::from(self.length % T::Unpacked::LENGTH != 0); + + // We reserve enough space here for self.length rounded up to the next multiple of + // T::Unpacked::LENGTH so that we can safely just write into that memory. Otherwise, we + // would have to make a special path where we memcopy twice which is less than ideal. + vec.reserve(num_packs * T::Unpacked::LENGTH); + + // IMPORTANT: This pointer calculation has to appear after the reserve since that reserve + // might move the buffer. + let mut unpacked_ptr = vec.as_mut_ptr().wrapping_add(vec.len()); + + for _ in 0..num_packs { + // This unwrap should never fail since the packed length is checked on initialized of + // the `Decoder`. + let packed = self.packed.next().unwrap(); + + // SAFETY: + // Since we did a `vec::reserve` before with the total length, we know that the memory + // necessary for a `T::Unpacked` is available. + // + // - The elements in this buffer are properly aligned, so elements in a slice will also + // be properly aligned. + // - It is deferencable because it is (i) not null, (ii) in one allocated object, (iii) + // not pointing to deallocated memory, (iv) we do not rely on atomicity and (v) we do + // not read or write beyond the lifetime of `vec`. + // - All data is initialized before reading it. This is not perfect but should not lead + // to any UB. + // - We don't alias the same data from anywhere else at the same time, because we have + // the mutable reference to `vec`. + let unpacked_ref = unsafe { (unpacked_ptr as *mut T::Unpacked).as_mut() }.unwrap(); + + decode_pack::(packed, self.num_bits, unpacked_ref); + + unpacked_ptr = unpacked_ptr.wrapping_add(T::Unpacked::LENGTH); + } + + // SAFETY: + // We have written these elements before so we know that these are available now. + // + // - The capacity is larger since we reserved enough spaced with the opening + // `vec::reserve`. + // - All elements are initialized by the `decode_pack` into the `unpacked_ref`. + unsafe { vec.set_len(vec.len() + self.length) } + } +} + +#[cfg(test)] +mod tests { + use super::super::tests::case1; + use super::*; + + impl Decoder<'_, T> { + pub fn collect(self) -> Vec { + let mut vec = Vec::new(); + self.collect_into(&mut vec); + vec + } + } + + #[test] + fn test_decode_rle() { + // Test data: 0-7 with bit width 3 + // 0: 000 + // 1: 001 + // 2: 010 + // 3: 011 + // 4: 100 + // 5: 101 + // 6: 110 + // 7: 111 + let num_bits = 3; + let length = 8; + // encoded: 0b10001000u8, 0b11000110, 0b11111010 + let data = vec![0b10001000u8, 0b11000110, 0b11111010]; + + let decoded = Decoder::::try_new(&data, num_bits, length) + .unwrap() + .collect(); + assert_eq!(decoded, vec![0, 1, 2, 3, 4, 5, 6, 7]); + } + + #[test] + fn decode_large() { + let (num_bits, expected, data) = case1(); + + let decoded = Decoder::::try_new(&data, num_bits, expected.len()) + .unwrap() + .collect(); + assert_eq!(decoded, expected); + } + + #[test] + fn test_decode_bool() { + let num_bits = 1; + let length = 8; + let data = vec![0b10101010]; + + let decoded = Decoder::::try_new(&data, num_bits, length) + .unwrap() + .collect(); + assert_eq!(decoded, vec![0, 1, 0, 1, 0, 1, 0, 1]); + } + + #[test] + fn test_decode_u64() { + let num_bits = 1; + let length = 8; + let data = vec![0b10101010]; + + let decoded = Decoder::::try_new(&data, num_bits, length) + .unwrap() + .collect(); + assert_eq!(decoded, vec![0, 1, 0, 1, 0, 1, 0, 1]); + } + + #[test] + fn even_case() { + // [0, 1, 2, 3, 4, 5, 6, 0]x99 + let data = &[0b10001000u8, 0b11000110, 0b00011010]; + let num_bits = 3; + let copies = 99; // 8 * 99 % 32 != 0 + let expected = std::iter::repeat_n(&[0u32, 1, 2, 3, 4, 5, 6, 0], copies) + .flatten() + .copied() + .collect::>(); + let data = std::iter::repeat_n(data, copies) + .flatten() + .copied() + .collect::>(); + let length = expected.len(); + + let decoded = Decoder::::try_new(&data, num_bits, length) + .unwrap() + .collect(); + assert_eq!(decoded, expected); + } + + #[test] + fn odd_case() { + // [0, 1, 2, 3, 4, 5, 6, 0]x4 + [2] + let data = &[0b10001000u8, 0b11000110, 0b00011010]; + let num_bits = 3; + let copies = 4; + let expected = std::iter::repeat_n(&[0u32, 1, 2, 3, 4, 5, 6, 0], copies) + .flatten() + .copied() + .chain(std::iter::once(2)) + .collect::>(); + let data = std::iter::repeat_n(data, copies) + .flatten() + .copied() + .chain(std::iter::once(0b00000010u8)) + .collect::>(); + let length = expected.len(); + + let decoded = Decoder::::try_new(&data, num_bits, length) + .unwrap() + .collect(); + assert_eq!(decoded, expected); + } + + #[test] + fn test_errors() { + // zero length + assert!(Decoder::::try_new(&[], 1, 0).is_ok()); + // no bytes + assert!(Decoder::::try_new(&[], 1, 1).is_err()); + // too few bytes + assert!(Decoder::::try_new(&[1], 1, 8).is_ok()); + assert!(Decoder::::try_new(&[1, 1], 2, 8).is_ok()); + assert!(Decoder::::try_new(&[1], 1, 9).is_err()); + // zero num_bits + assert!(Decoder::::try_new(&[1], 0, 1).is_err()); + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs new file mode 100644 index 000000000000..e88e4f0b105e --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs @@ -0,0 +1,52 @@ +use super::{Unpackable, Unpacked}; + +/// Encodes (packs) a slice of [`Unpackable`] into bitpacked bytes `packed`, using `num_bits` per value. +/// +/// This function assumes that the maximum value in `unpacked` fits in `num_bits` bits +/// and saturates higher values. +/// +/// Only the first `ceil8(unpacked.len() * num_bits)` of `packed` are populated. +pub fn encode(unpacked: &[T], num_bits: usize, packed: &mut [u8]) { + let chunks = unpacked.chunks_exact(T::Unpacked::LENGTH); + + let remainder = chunks.remainder(); + + let packed_size = (T::Unpacked::LENGTH * num_bits).div_ceil(8); + if !remainder.is_empty() { + let packed_chunks = packed.chunks_mut(packed_size); + let mut last_chunk = T::Unpacked::zero(); + for i in 0..remainder.len() { + last_chunk[i] = remainder[i] + } + + chunks + .chain(std::iter::once(last_chunk.as_ref())) + .zip(packed_chunks) + .for_each(|(unpacked, packed)| { + T::pack(&unpacked.try_into().unwrap(), num_bits, packed); + }); + } else { + let packed_chunks = packed.chunks_exact_mut(packed_size); + chunks.zip(packed_chunks).for_each(|(unpacked, packed)| { + T::pack(&unpacked.try_into().unwrap(), num_bits, packed); + }); + } +} + +/// Encodes (packs) a potentially incomplete pack of [`Unpackable`] into bitpacked +/// bytes `packed`, using `num_bits` per value. +/// +/// This function assumes that the maximum value in `unpacked` fits in `num_bits` bits +/// and saturates higher values. +/// +/// Only the first `ceil8(unpacked.len() * num_bits)` of `packed` are populated. +#[inline] +pub fn encode_pack(unpacked: &[T], num_bits: usize, packed: &mut [u8]) { + if unpacked.len() < T::Unpacked::LENGTH { + let mut complete_unpacked = T::Unpacked::zero(); + complete_unpacked.as_mut()[..unpacked.len()].copy_from_slice(unpacked); + T::pack(&complete_unpacked, num_bits, packed) + } else { + T::pack(&unpacked.try_into().unwrap(), num_bits, packed) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs new file mode 100644 index 000000000000..10a1f61af963 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs @@ -0,0 +1,276 @@ +macro_rules! seq_macro { + ($i:ident in 1..15 $block:block) => { + seq_macro!($i in [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + ] $block) + }; + ($i:ident in 0..16 $block:block) => { + seq_macro!($i in [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + ] $block) + }; + ($i:ident in 0..=16 $block:block) => { + seq_macro!($i in [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, + ] $block) + }; + ($i:ident in 1..31 $block:block) => { + seq_macro!($i in [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + ] $block) + }; + ($i:ident in 0..32 $block:block) => { + seq_macro!($i in [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ] $block) + }; + ($i:ident in 0..=32 $block:block) => { + seq_macro!($i in [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, + ] $block) + }; + ($i:ident in 1..63 $block:block) => { + seq_macro!($i in [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, + ] $block) + }; + ($i:ident in 0..64 $block:block) => { + seq_macro!($i in [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + ] $block) + }; + ($i:ident in 0..=64 $block:block) => { + seq_macro!($i in [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, + ] $block) + }; + ($i:ident in [$($value:literal),+ $(,)?] $block:block) => { + $({ + #[allow(non_upper_case_globals)] + const $i: usize = $value; + { $block } + })+ + }; +} + +mod decode; +mod encode; +mod pack; +mod unpack; + +pub use decode::{ChunkedDecoder, Decoder}; +pub use encode::{encode, encode_pack}; + +/// A byte slice (e.g. `[u8; 8]`) denoting types that represent complete packs. +pub trait Packed: + Copy + + Sized + + AsRef<[u8]> + + AsMut<[u8]> + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> +{ + const LENGTH: usize; + fn zero() -> Self; +} + +impl Packed for [u8; 8] { + const LENGTH: usize = 8; + #[inline] + fn zero() -> Self { + [0; 8] + } +} + +impl Packed for [u8; 16 * 2] { + const LENGTH: usize = 16 * 2; + #[inline] + fn zero() -> Self { + [0; 16 * 2] + } +} + +impl Packed for [u8; 32 * 4] { + const LENGTH: usize = 32 * 4; + #[inline] + fn zero() -> Self { + [0; 32 * 4] + } +} + +impl Packed for [u8; 64 * 8] { + const LENGTH: usize = 64 * 8; + #[inline] + fn zero() -> Self { + [0; 64 * 8] + } +} + +/// A byte slice of [`Unpackable`] denoting complete unpacked arrays. +pub trait Unpacked: + Copy + + Sized + + AsRef<[T]> + + AsMut<[T]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [T], Error = std::array::TryFromSliceError> +{ + const LENGTH: usize; + fn zero() -> Self; +} + +impl Unpacked for [u8; 8] { + const LENGTH: usize = 8; + #[inline] + fn zero() -> Self { + [0; 8] + } +} + +impl Unpacked for [u16; 16] { + const LENGTH: usize = 16; + #[inline] + fn zero() -> Self { + [0; 16] + } +} + +impl Unpacked for [u32; 32] { + const LENGTH: usize = 32; + #[inline] + fn zero() -> Self { + [0; 32] + } +} + +impl Unpacked for [u64; 64] { + const LENGTH: usize = 64; + #[inline] + fn zero() -> Self { + [0; 64] + } +} + +/// A type representing a type that can be bitpacked and unpacked by this crate. +pub trait Unpackable: Copy + Sized + Default { + type Packed: Packed; + type Unpacked: Unpacked; + + fn unpack(packed: &[u8], num_bits: usize, unpacked: &mut Self::Unpacked); + fn pack(unpacked: &Self::Unpacked, num_bits: usize, packed: &mut [u8]); +} + +impl Unpackable for u16 { + type Packed = [u8; 16 * 2]; + type Unpacked = [u16; 16]; + + #[inline] + fn unpack(packed: &[u8], num_bits: usize, unpacked: &mut Self::Unpacked) { + unpack::unpack16(packed, unpacked, num_bits) + } + + #[inline] + fn pack(packed: &Self::Unpacked, num_bits: usize, unpacked: &mut [u8]) { + pack::pack16(packed, unpacked, num_bits) + } +} + +impl Unpackable for u32 { + type Packed = [u8; 32 * 4]; + type Unpacked = [u32; 32]; + + #[inline] + fn unpack(packed: &[u8], num_bits: usize, unpacked: &mut Self::Unpacked) { + unpack::unpack32(packed, unpacked, num_bits) + } + + #[inline] + fn pack(packed: &Self::Unpacked, num_bits: usize, unpacked: &mut [u8]) { + pack::pack32(packed, unpacked, num_bits) + } +} + +impl Unpackable for u64 { + type Packed = [u8; 64 * 8]; + type Unpacked = [u64; 64]; + + #[inline] + fn unpack(packed: &[u8], num_bits: usize, unpacked: &mut Self::Unpacked) { + unpack::unpack64(packed, unpacked, num_bits) + } + + #[inline] + fn pack(packed: &Self::Unpacked, num_bits: usize, unpacked: &mut [u8]) { + pack::pack64(packed, unpacked, num_bits) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + pub fn case1() -> (usize, Vec, Vec) { + let num_bits = 3; + let compressed = vec![ + 0b10001000u8, + 0b11000110, + 0b11111010, + 0b10001000u8, + 0b11000110, + 0b11111010, + 0b10001000u8, + 0b11000110, + 0b11111010, + 0b10001000u8, + 0b11000110, + 0b11111010, + 0b10001000u8, + 0b11000110, + 0b11111010, + ]; + let decompressed = vec![ + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, + 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + ]; + (num_bits, decompressed, compressed) + } + + #[test] + fn encode_large() { + let (num_bits, unpacked, expected) = case1(); + let mut packed = vec![0u8; 4 * 32]; + + encode(&unpacked, num_bits, &mut packed); + assert_eq!(&packed[..15], expected); + } + + #[test] + fn test_encode() { + let num_bits = 3; + let unpacked = vec![0, 1, 2, 3, 4, 5, 6, 7]; + + let mut packed = vec![0u8; 4 * 32]; + + encode::(&unpacked, num_bits, &mut packed); + + let expected = vec![0b10001000u8, 0b11000110, 0b11111010]; + + assert_eq!(&packed[..3], expected); + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/pack.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/pack.rs new file mode 100644 index 000000000000..d60aa3be1de7 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/pack.rs @@ -0,0 +1,144 @@ +#![allow(unsafe_op_in_unsafe_fn)] +/// Macro that generates a packing function taking the number of bits as a const generic +macro_rules! pack_impl { + ($t:ty, $bytes:literal, $bits:tt, $bits_minus_one:tt) => { + // Adapted from https://github.com/quickwit-oss/bitpacking + pub unsafe fn pack(input: &[$t; $bits], output: &mut [u8]) { + if NUM_BITS == 0 { + for out in output { + *out = 0; + } + return; + } + assert!(NUM_BITS <= $bits); + assert!(output.len() >= NUM_BITS * $bytes); + + let input_ptr = input.as_ptr(); + let mut output_ptr = output.as_mut_ptr() as *mut $t; + let mut out_register: $t = read_unaligned(input_ptr); + + if $bits == NUM_BITS { + write_unaligned(output_ptr, out_register); + output_ptr = output_ptr.offset(1); + } + + // Using microbenchmark (79d1fff), unrolling this loop is over 10x + // faster than not (>20x faster than old algorithm) + seq_macro!(i in 1..$bits_minus_one { + let bits_filled: usize = i * NUM_BITS; + let inner_cursor: usize = bits_filled % $bits; + let remaining: usize = $bits - inner_cursor; + + let offset_ptr = input_ptr.add(i); + let in_register: $t = read_unaligned(offset_ptr); + + out_register = + if inner_cursor > 0 { + out_register | (in_register << inner_cursor) + } else { + in_register + }; + + if remaining <= NUM_BITS { + write_unaligned(output_ptr, out_register); + output_ptr = output_ptr.offset(1); + if 0 < remaining && remaining < NUM_BITS { + out_register = in_register >> remaining + } + } + }); + + let in_register: $t = read_unaligned(input_ptr.add($bits - 1)); + out_register = if $bits - NUM_BITS > 0 { + out_register | (in_register << ($bits - NUM_BITS)) + } else { + in_register + }; + write_unaligned(output_ptr, out_register) + } + }; +} + +/// Macro that generates pack functions that accept num_bits as a parameter +macro_rules! pack { + ($name:ident, $t:ty, $bytes:literal, $bits:tt, $bits_minus_one:tt) => { + mod $name { + use std::ptr::{read_unaligned, write_unaligned}; + pack_impl!($t, $bytes, $bits, $bits_minus_one); + } + + /// Pack unpacked `input` into `output` with a bit width of `num_bits` + pub fn $name(input: &[$t; $bits], output: &mut [u8], num_bits: usize) { + // This will get optimised into a jump table + seq_macro!(i in 0..=$bits { + if i == num_bits { + unsafe { + return $name::pack::(input, output); + } + } + }); + unreachable!("invalid num_bits {}", num_bits); + } + }; +} + +pack!(pack16, u16, 2, 16, 15); +pack!(pack32, u32, 4, 32, 31); +pack!(pack64, u64, 8, 64, 63); + +#[cfg(test)] +mod tests { + use rand::distributions::{Distribution, Uniform}; + + use super::super::unpack::*; + use super::*; + + #[test] + fn test_u32() { + let input = [ + 0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0u32, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, + ]; + for num_bits in 4..32 { + let mut output = [0u8; 32 * 4]; + pack32(&input, &mut output, num_bits); + let mut other = [0u32; 32]; + unpack32(&output, &mut other, num_bits); + assert_eq!(other, input); + } + } + + #[test] + fn test_u32_random() { + let mut rng = rand::thread_rng(); + let mut random_array = [0u32; 32]; + let between = Uniform::from(0..131_072); + for num_bits in 17..=32 { + for i in &mut random_array { + *i = between.sample(&mut rng); + } + let mut output = [0u8; 32 * 4]; + pack32(&random_array, &mut output, num_bits); + let mut other = [0u32; 32]; + unpack32(&output, &mut other, num_bits); + assert_eq!(other, random_array); + } + } + + #[test] + fn test_u64_random() { + let mut rng = rand::thread_rng(); + let mut random_array = [0u64; 64]; + let between = Uniform::from(0..131_072); + for num_bits in 17..=64 { + for i in &mut random_array { + *i = between.sample(&mut rng); + } + let mut output = [0u8; 64 * 8]; + pack64(&random_array, &mut output, num_bits); + let mut other = [0u64; 64]; + unpack64(&output, &mut other, num_bits); + assert_eq!(other, random_array); + } + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/unpack.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/unpack.rs new file mode 100644 index 000000000000..c52c17d21681 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/unpack.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Copied from https://github.com/apache/arrow-rs/blob/6859efa690d4c9530cf8a24053bc6ed81025a164/parquet/src/util/bit_pack.rs + +// This implements bit unpacking. For example, for `u8` and `num_bits=3`. +// 0b001_101_110 -> 0b0000_0001, 0b0000_0101, 0b0000_0110 +// +// This file is a bit insane. It unrolls all the possible num_bits vs. combinations. These are very +// highly used functions in Parquet and therefore this that been extensively unrolled and +// optimized. Attempts have been done to introduce SIMD here, but those attempts have not paid off +// in comparison to auto-vectorization. +// +// Generally, there are two code-size vs. runtime tradeoffs taken here in favor of +// runtime. +// +// 1. Each individual function unrolled to a point where all constants are known to +// the compiler. In microbenchmarks, this increases the performance by around 4.5 +// to 5 times. +// 2. All functions are compiled separately and dispatch is done using a +// jumptable. In microbenchmarks, this increases the performance by around 2 to 2.5 +// times. + +/// Macro that generates an unpack function taking the number of bits as a const generic +macro_rules! unpack_impl { + ($t:ty, $bytes:literal, $bits:tt) => { + pub fn unpack(input: &[u8], output: &mut [$t; $bits]) { + if NUM_BITS == 0 { + for out in output { + *out = 0; + } + return; + } + + assert!(NUM_BITS <= $bytes * 8); + + let mask = match NUM_BITS { + $bits => <$t>::MAX, + _ => ((1 << NUM_BITS) - 1), + }; + + assert!(input.len() >= NUM_BITS * $bytes); + + let r = |output_idx: usize| { + <$t>::from_le_bytes( + input[output_idx * $bytes..output_idx * $bytes + $bytes] + .try_into() + .unwrap(), + ) + }; + + // @NOTE + // I was surprised too, but this macro vs. a for loop saves around 4.5 - 5x on + // performance in a microbenchmark. Although the code it generates is completely + // insane. There should be something we can do here to make this less code, sane code + // and faster code. + seq_macro!(i in 0..$bits { + let start_bit = i * NUM_BITS; + let end_bit = start_bit + NUM_BITS; + + let start_bit_offset = start_bit % $bits; + let end_bit_offset = end_bit % $bits; + let start_byte = start_bit / $bits; + let end_byte = end_bit / $bits; + if start_byte != end_byte && end_bit_offset != 0 { + let val = r(start_byte); + let a = val >> start_bit_offset; + let val = r(end_byte); + let b = val << (NUM_BITS - end_bit_offset); + + output[i] = a | (b & mask); + } else { + let val = r(start_byte); + output[i] = (val >> start_bit_offset) & mask; + } + }); + } + }; +} + +/// Macro that generates unpack functions that accept num_bits as a parameter +macro_rules! unpack { + ($name:ident, $t:ty, $bytes:literal, $bits:tt) => { + mod $name { + unpack_impl!($t, $bytes, $bits); + } + + /// Unpack packed `input` into `output` with a bit width of `num_bits` + pub fn $name(input: &[u8], output: &mut [$t; $bits], num_bits: usize) { + // This will get optimised into a jump table + // + // @NOTE + // This jumptable appoach saves around 2 - 2.5x on performance over no jumptable and no + // generics. + seq_macro!(i in 0..=$bits { + if i == num_bits { + return $name::unpack::(input, output); + } + }); + unreachable!("invalid num_bits {}", num_bits); + } + }; +} + +unpack!(unpack16, u16, 2, 16); +unpack!(unpack32, u32, 4, 32); +unpack!(unpack64, u64, 8, 64); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic() { + let input = [0xFF; 4096]; + + for i in 0..=32 { + let mut output = [0; 32]; + unpack32(&input, &mut output, i); + for (idx, out) in output.iter().enumerate() { + assert_eq!(out.trailing_ones() as usize, i, "out[{}] = {}", idx, out); + } + } + + for i in 0..=64 { + let mut output = [0; 64]; + unpack64(&input, &mut output, i); + for (idx, out) in output.iter().enumerate() { + assert_eq!(out.trailing_ones() as usize, i, "out[{}] = {}", idx, out); + } + } + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs new file mode 100644 index 000000000000..1b383e9522f1 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/decoder.rs @@ -0,0 +1,118 @@ +use crate::parquet::error::ParquetError; + +const MAX_ELEMENT_SIZE: usize = 8; + +/// Decodes using the [Byte Stream Split](https://github.com/apache/parquet-format/blob/master/Encodings.md#byte-stream-split-byte_stream_split--9) encoding. +/// # Implementation +/// A fixed size buffer is stored inline to support reading types of up to 8 bytes in size. +#[derive(Debug)] +pub struct Decoder<'a> { + values: &'a [u8], + buffer: [u8; MAX_ELEMENT_SIZE], + num_elements: usize, + position: usize, + element_size: usize, +} + +impl<'a> Decoder<'a> { + pub fn try_new(values: &'a [u8], element_size: usize) -> Result { + if element_size > MAX_ELEMENT_SIZE { + // Since Parquet format version 2.11 it's valid to use byte stream split for fixed-length byte array data, + // which could be larger than 8 bytes, but Polars doesn't yet support reading byte stream split encoded FLBA data. + return Err(ParquetError::oos(format!( + "Byte stream split decoding only supports up to {} byte element sizes", + MAX_ELEMENT_SIZE + ))); + } + + let values_size = values.len(); + if values_size % element_size != 0 { + return Err(ParquetError::oos(format!( + "Values array length ({}) is not a multiple of the element size ({})", + values_size, element_size + ))); + } + let num_elements = values.len() / element_size; + + Ok(Self { + values, + buffer: [0; MAX_ELEMENT_SIZE], + num_elements, + position: 0, + element_size, + }) + } + + pub fn move_next(&mut self) -> bool { + if self.position >= self.num_elements { + return false; + } + + debug_assert!(self.element_size <= MAX_ELEMENT_SIZE); + debug_assert!(self.values.len() >= self.num_elements * self.element_size); + for n in 0..self.element_size { + unsafe { + // SAFETY: + // We have the invariants that element_size <= MAX_ELEMENT_SIZE, + // buffer.len() == MAX_ELEMENT_SIZE, + // position < num_elements and + // values.len() >= num_elements * element_size. + *self.buffer.get_unchecked_mut(n) = *self + .values + .get_unchecked((self.num_elements * n) + self.position) + } + } + + self.position += 1; + true + } + + /// The number of remaining values + pub fn len(&self) -> usize { + self.num_elements - self.position + } + + pub fn current_value(&self) -> &[u8] { + &self.buffer[0..self.element_size] + } + + pub fn iter_converted<'b, T, F>(&'b mut self, converter: F) -> DecoderIterator<'a, 'b, T, F> + where + F: Copy + Fn(&[u8]) -> T, + { + DecoderIterator { + decoder: self, + converter, + } + } +} + +#[derive(Debug)] +pub struct DecoderIterator<'a, 'b, T, F> +where + F: Copy + Fn(&[u8]) -> T, +{ + decoder: &'b mut Decoder<'a>, + converter: F, +} + +impl Iterator for DecoderIterator<'_, '_, T, F> +where + F: Copy + Fn(&[u8]) -> T, +{ + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.decoder.move_next() { + Some((self.converter)(self.decoder.current_value())) + } else { + None + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.decoder.len(), Some(self.decoder.len())) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs new file mode 100644 index 000000000000..555954ff8069 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/byte_stream_split/mod.rs @@ -0,0 +1,77 @@ +mod decoder; + +pub use decoder::Decoder; + +#[cfg(test)] +mod tests { + use super::*; + use crate::parquet::error::ParquetError; + use crate::parquet::types::NativeType; + + #[test] + fn round_trip_f32() -> Result<(), ParquetError> { + let data = vec![1.0e-2_f32, 2.5_f32, 3.0e2_f32]; + let mut buffer = vec![]; + encode(&data, &mut buffer); + + let mut decoder = Decoder::try_new(&buffer, size_of::())?; + let values = decoder + .iter_converted(|bytes| f32::from_le_bytes(bytes.try_into().unwrap())) + .collect::>(); + + assert_eq!(data, values); + + Ok(()) + } + + #[test] + fn round_trip_f64() -> Result<(), ParquetError> { + let data = vec![1.0e-2_f64, 2.5_f64, 3.0e2_f64]; + let mut buffer = vec![]; + encode(&data, &mut buffer); + + let mut decoder = Decoder::try_new(&buffer, size_of::())?; + let values = decoder + .iter_converted(|bytes| f64::from_le_bytes(bytes.try_into().unwrap())) + .collect::>(); + + assert_eq!(data, values); + + Ok(()) + } + + #[test] + fn fails_for_invalid_values_size() -> Result<(), ParquetError> { + let buffer = vec![0; 12]; + + let result = Decoder::try_new(&buffer, 8); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn fails_for_invalid_element_size() -> Result<(), ParquetError> { + let buffer = vec![0; 16]; + + let result = Decoder::try_new(&buffer, 16); + assert!(result.is_err()); + + Ok(()) + } + + fn encode(data: &[T], buffer: &mut Vec) { + let element_size = size_of::(); + let num_elements = data.len(); + let total_length = size_of_val(data); + buffer.resize(total_length, 0); + + for (i, v) in data.iter().enumerate() { + let value_bytes = v.to_le_bytes(); + let value_bytes_ref = value_bytes.as_ref(); + for n in 0..element_size { + buffer[(num_elements * n) + i] = value_bytes_ref[n]; + } + } + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs new file mode 100644 index 000000000000..a89d824bdb6a --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs @@ -0,0 +1,893 @@ +//! This module implements the `DELTA_BINARY_PACKED` encoding. +//! +//! For performance reasons this is done without iterators. Instead, we have `gather_n` functions +//! and a `DeltaGatherer` trait. These allow efficient decoding and mapping of the decoded values. +//! +//! Full information on the delta encoding can be found on the Apache Parquet Format repository. +//! +//! +//! +//! Delta encoding compresses sequential integer values by encoding the first value and the +//! differences between consequentive values. This variant encodes the data into `Block`s and +//! `MiniBlock`s. +//! +//! - A `Block` contains a minimum delta, bitwidths and one or more miniblocks. +//! - A `MiniBlock` contains many deltas that are encoded in [`bitpacked`] encoding. +//! +//! The decoder keeps track of the last value and calculates a new value with the following +//! function. +//! +//! ```text +//! NextValue(Delta) = { +//! Value = Decoder.LastValue + Delta + Block.MinDelta +//! Decoder.LastValue = Value +//! return Value +//! } +//! ``` +//! +//! Note that all these additions need to be wrapping. + +use super::super::{bitpacked, uleb128, zigzag_leb128}; +use super::lin_natural_sum; +use crate::parquet::encoding::bitpacked::{Unpackable, Unpacked}; +use crate::parquet::error::{ParquetError, ParquetResult}; + +const MAX_BITWIDTH: u8 = 64; + +/// Decoder of parquets' `DELTA_BINARY_PACKED`. +#[derive(Debug)] +pub struct Decoder<'a> { + num_miniblocks_per_block: usize, + values_per_block: usize, + + values_remaining: usize, + + last_value: i64, + + values: &'a [u8], + + block: Block<'a>, +} + +#[derive(Debug)] +struct Block<'a> { + min_delta: i64, + + /// Bytes that give the `num_bits` for the [`bitpacked::Decoder`]. + /// + /// Invariant: `bitwidth[i] <= MAX_BITWIDTH` for all `i` + bitwidths: &'a [u8], + values_remaining: usize, + miniblock: MiniBlock<'a>, +} + +#[derive(Debug)] +struct MiniBlock<'a> { + decoder: bitpacked::Decoder<'a, u64>, + buffered: ::Unpacked, + unpacked_start: usize, + unpacked_end: usize, +} + +pub(crate) struct SumGatherer(pub(crate) usize); + +pub trait DeltaGatherer { + type Target: std::fmt::Debug; + + fn target_len(&self, target: &Self::Target) -> usize; + fn target_reserve(&self, target: &mut Self::Target, n: usize); + + /// Gather one element with value `v` into `target`. + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()>; + + /// Gather `num_repeats` elements into `target`. + /// + /// The first value is `v` and the `n`-th value is `v + (n-1)*delta`. + fn gather_constant( + &mut self, + target: &mut Self::Target, + v: i64, + delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + for i in 0..num_repeats { + self.gather_one(target, v + (i as i64) * delta)?; + } + Ok(()) + } + /// Gather a `slice` of elements into `target`. + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + for &v in slice { + self.gather_one(target, v)?; + } + Ok(()) + } + /// Gather a `chunk` of elements into `target`. + fn gather_chunk(&mut self, target: &mut Self::Target, chunk: &[i64; 64]) -> ParquetResult<()> { + self.gather_slice(target, chunk) + } +} + +impl DeltaGatherer for SumGatherer { + type Target = usize; + + fn target_len(&self, _target: &Self::Target) -> usize { + self.0 + } + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + if v < 0 { + return Err(ParquetError::oos(format!( + "Invalid delta encoding length {v}" + ))); + } + + *target += v as usize; + self.0 += 1; + Ok(()) + } + fn gather_constant( + &mut self, + target: &mut Self::Target, + v: i64, + delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + if v < 0 || (delta < 0 && num_repeats > 0 && (num_repeats - 1) as i64 * delta + v < 0) { + return Err(ParquetError::oos("Invalid delta encoding length")); + } + + *target += lin_natural_sum(v, delta, num_repeats) as usize; + + Ok(()) + } + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + let min = slice.iter().copied().min().unwrap_or_default(); + if min < 0 { + return Err(ParquetError::oos(format!( + "Invalid delta encoding length {min}" + ))); + } + + *target += slice.iter().copied().map(|v| v as usize).sum::(); + self.0 += slice.len(); + Ok(()) + } + fn gather_chunk(&mut self, target: &mut Self::Target, chunk: &[i64; 64]) -> ParquetResult<()> { + let min = chunk.iter().copied().min().unwrap_or_default(); + if min < 0 { + return Err(ParquetError::oos(format!( + "Invalid delta encoding length {min}" + ))); + } + *target += chunk.iter().copied().map(|v| v as usize).sum::(); + self.0 += chunk.len(); + Ok(()) + } +} + +/// Gather the rest of the [`bitpacked::Decoder`] into `target` +fn gather_bitpacked( + target: &mut G::Target, + min_delta: i64, + last_value: &mut i64, + mut decoder: bitpacked::Decoder, + gatherer: &mut G, +) -> ParquetResult<()> { + let mut chunked = decoder.chunked(); + for mut chunk in chunked.by_ref() { + for value in &mut chunk { + *last_value = last_value + .wrapping_add(*value as i64) + .wrapping_add(min_delta); + *value = *last_value as u64; + } + + let chunk = bytemuck::cast_ref(&chunk); + gatherer.gather_chunk(target, chunk)?; + } + + if let Some((mut chunk, length)) = chunked.remainder() { + let slice = &mut chunk[..length]; + + for value in slice.iter_mut() { + *last_value = last_value + .wrapping_add(*value as i64) + .wrapping_add(min_delta); + *value = *last_value as u64; + } + + let slice = bytemuck::cast_slice(slice); + gatherer.gather_slice(target, slice)?; + } + + Ok(()) +} + +/// Gather an entire [`MiniBlock`] into `target` +fn gather_miniblock( + target: &mut G::Target, + min_delta: i64, + bitwidth: u8, + values: &[u8], + values_per_miniblock: usize, + last_value: &mut i64, + gatherer: &mut G, +) -> ParquetResult<()> { + let bitwidth = bitwidth as usize; + + if bitwidth == 0 { + let v = last_value.wrapping_add(min_delta); + gatherer.gather_constant(target, v, min_delta, values_per_miniblock)?; + *last_value = last_value.wrapping_add(min_delta * values_per_miniblock as i64); + return Ok(()); + } + + debug_assert!(bitwidth <= 64); + debug_assert_eq!((bitwidth * values_per_miniblock).div_ceil(8), values.len()); + + let start_length = gatherer.target_len(target); + gather_bitpacked( + target, + min_delta, + last_value, + bitpacked::Decoder::new(values, bitwidth, values_per_miniblock), + gatherer, + )?; + let target_length = gatherer.target_len(target); + + debug_assert_eq!(target_length - start_length, values_per_miniblock); + + Ok(()) +} + +/// Gather an entire [`Block`] into `target` +fn gather_block<'a, G: DeltaGatherer>( + target: &mut G::Target, + num_miniblocks: usize, + values_per_miniblock: usize, + mut values: &'a [u8], + last_value: &mut i64, + gatherer: &mut G, +) -> ParquetResult<&'a [u8]> { + let (min_delta, consumed) = zigzag_leb128::decode(values); + values = &values[consumed..]; + let bitwidths; + (bitwidths, values) = values + .split_at_checked(num_miniblocks) + .ok_or_else(|| ParquetError::oos("Not enough bitwidths available in delta encoding"))?; + + gatherer.target_reserve(target, num_miniblocks * values_per_miniblock); + for &bitwidth in bitwidths { + let miniblock; + (miniblock, values) = values + .split_at_checked((bitwidth as usize * values_per_miniblock).div_ceil(8)) + .ok_or_else(|| ParquetError::oos("Not enough bytes for miniblock in delta encoding"))?; + gather_miniblock( + target, + min_delta, + bitwidth, + miniblock, + values_per_miniblock, + last_value, + gatherer, + )?; + } + + Ok(values) +} + +impl<'a> Decoder<'a> { + pub fn try_new(mut values: &'a [u8]) -> ParquetResult<(Self, &'a [u8])> { + let header_err = || ParquetError::oos("Insufficient bytes for Delta encoding header"); + + // header: + // + + let (values_per_block, consumed) = uleb128::decode(values); + let values_per_block = values_per_block as usize; + values = values.get(consumed..).ok_or_else(header_err)?; + + assert_eq!(values_per_block % 128, 0); + + let (num_miniblocks_per_block, consumed) = uleb128::decode(values); + let num_miniblocks_per_block = num_miniblocks_per_block as usize; + values = values.get(consumed..).ok_or_else(header_err)?; + + let (total_count, consumed) = uleb128::decode(values); + let total_count = total_count as usize; + values = values.get(consumed..).ok_or_else(header_err)?; + + let (first_value, consumed) = zigzag_leb128::decode(values); + values = values.get(consumed..).ok_or_else(header_err)?; + + assert_eq!(values_per_block % num_miniblocks_per_block, 0); + assert_eq!((values_per_block / num_miniblocks_per_block) % 32, 0); + + let values_per_miniblock = values_per_block / num_miniblocks_per_block; + assert_eq!(values_per_miniblock % 8, 0); + + // We skip over all the values to determine where the slice stops. + // + // This also has the added benefit of error checking in advance, thus we can unwrap in + // other places. + + let mut rem = values; + if total_count > 1 { + let mut num_values_left = total_count - 1; + while num_values_left > 0 { + // If the number of values is does not need all the miniblocks anymore, we need to + // ignore the later miniblocks and regard them as having bitwidth = 0. + // + // Quoted from the specification: + // + // > If, in the last block, less than miniblocks + // > are needed to store the values, the bytes storing the bit widths of the + // > unneeded miniblocks are still present, their value should be zero, but readers + // > must accept arbitrary values as well. There are no additional padding bytes for + // > the miniblock bodies though, as if their bit widths were 0 (regardless of the + // > actual byte values). The reader knows when to stop reading by keeping track of + // > the number of values read. + let num_remaining_mini_blocks = usize::min( + num_miniblocks_per_block, + num_values_left.div_ceil(values_per_miniblock), + ); + + // block: + // + + let (_, consumed) = zigzag_leb128::decode(rem); + rem = rem.get(consumed..).ok_or_else(|| { + ParquetError::oos("No min-delta value in delta encoding miniblock") + })?; + + if rem.len() < num_miniblocks_per_block { + return Err(ParquetError::oos( + "Not enough bitwidths available in delta encoding", + )); + } + if let Some(err_bitwidth) = rem + .get(..num_remaining_mini_blocks) + .expect("num_remaining_mini_blocks <= num_miniblocks_per_block") + .iter() + .copied() + .find(|&bitwidth| bitwidth > MAX_BITWIDTH) + { + return Err(ParquetError::oos(format!( + "Delta encoding miniblock with bitwidth {err_bitwidth} higher than maximum {MAX_BITWIDTH} bits", + ))); + } + + let num_bitpacking_bytes = rem[..num_remaining_mini_blocks] + .iter() + .copied() + .map(|bitwidth| (bitwidth as usize * values_per_miniblock).div_ceil(8)) + .sum::(); + + rem = rem + .get(num_miniblocks_per_block + num_bitpacking_bytes..) + .ok_or_else(|| { + ParquetError::oos( + "Not enough bytes for all bitpacked values in delta encoding", + ) + })?; + + num_values_left = num_values_left.saturating_sub(values_per_block); + } + } + + let values = &values[..values.len() - rem.len()]; + + let decoder = Self { + num_miniblocks_per_block, + values_per_block, + values_remaining: total_count.saturating_sub(1), + last_value: first_value, + values, + + block: Block { + // @NOTE: + // We add one delta=0 into the buffered block which allows us to + // prepend like the `first_value` is just any normal value. + // + // This is a bit of a hack, but makes the rest of the logic + // **A LOT** simpler. + values_remaining: usize::from(total_count > 0), + min_delta: 0, + bitwidths: &[], + miniblock: MiniBlock { + decoder: bitpacked::Decoder::try_new_allow_zero(&[], 0, 1)?, + buffered: ::Unpacked::zero(), + unpacked_start: 0, + unpacked_end: 0, + }, + }, + }; + + Ok((decoder, rem)) + } + + /// Consume a new [`Block`] from `self.values`. + fn consume_block(&mut self) { + // @NOTE: All the panics here should be prevented in the `Decoder::try_new`. + + debug_assert!(!self.values.is_empty()); + + let values_per_miniblock = self.values_per_miniblock(); + + let length = usize::min(self.values_remaining, self.values_per_block); + let actual_num_miniblocks = usize::min( + self.num_miniblocks_per_block, + length.div_ceil(values_per_miniblock), + ); + + debug_assert!(actual_num_miniblocks > 0); + + // + + let (min_delta, consumed) = zigzag_leb128::decode(self.values); + + self.values = &self.values[consumed..]; + let (bitwidths, remainder) = self.values.split_at(self.num_miniblocks_per_block); + + let first_bitwidth = bitwidths[0]; + let bitwidths = &bitwidths[1..actual_num_miniblocks]; + debug_assert!(first_bitwidth <= MAX_BITWIDTH); + let first_bitwidth = first_bitwidth as usize; + + let values_in_first_miniblock = usize::min(length, values_per_miniblock); + let num_allocated_bytes = (first_bitwidth * values_per_miniblock).div_ceil(8); + let num_actual_bytes = (first_bitwidth * values_in_first_miniblock).div_ceil(8); + let (bytes, remainder) = remainder.split_at(num_allocated_bytes); + let bytes = &bytes[..num_actual_bytes]; + + let decoder = + bitpacked::Decoder::new_allow_zero(bytes, first_bitwidth, values_in_first_miniblock); + + self.block = Block { + min_delta, + bitwidths, + values_remaining: length, + miniblock: MiniBlock { + decoder, + // We can leave this as it should not be read before it is updated + buffered: self.block.miniblock.buffered, + unpacked_start: 0, + unpacked_end: 0, + }, + }; + + self.values_remaining -= length; + self.values = remainder; + } + + /// Gather `n` elements from the current [`MiniBlock`] to `target` + fn gather_miniblock_n_into( + &mut self, + target: &mut G::Target, + mut n: usize, + gatherer: &mut G, + ) -> ParquetResult<()> { + debug_assert!(n > 0); + debug_assert!(self.miniblock_len() >= n); + + // If the `num_bits == 0`, the delta is constant and equal to `min_delta`. The + // `bitpacked::Decoder` basically only keeps track of the length. + if self.block.miniblock.decoder.num_bits() == 0 { + let num_repeats = usize::min(self.miniblock_len(), n); + let v = self.last_value.wrapping_add(self.block.min_delta); + gatherer.gather_constant(target, v, self.block.min_delta, num_repeats)?; + self.last_value = self + .last_value + .wrapping_add(self.block.min_delta * num_repeats as i64); + self.block.miniblock.decoder.length -= num_repeats; + return Ok(()); + } + + if self.block.miniblock.unpacked_start < self.block.miniblock.unpacked_end { + let length = usize::min( + n, + self.block.miniblock.unpacked_end - self.block.miniblock.unpacked_start, + ); + self.block.miniblock.buffered + [self.block.miniblock.unpacked_start..self.block.miniblock.unpacked_start + length] + .iter_mut() + .for_each(|v| { + self.last_value = self + .last_value + .wrapping_add(*v as i64) + .wrapping_add(self.block.min_delta); + *v = self.last_value as u64; + }); + gatherer.gather_slice( + target, + bytemuck::cast_slice( + &self.block.miniblock.buffered[self.block.miniblock.unpacked_start + ..self.block.miniblock.unpacked_start + length], + ), + )?; + n -= length; + self.block.miniblock.unpacked_start += length; + } + + if n == 0 { + return Ok(()); + } + + const ITEMS_PER_PACK: usize = <::Unpacked as Unpacked>::LENGTH; + for _ in 0..n / ITEMS_PER_PACK { + let mut chunk = self.block.miniblock.decoder.chunked().next().unwrap(); + chunk.iter_mut().for_each(|v| { + self.last_value = self + .last_value + .wrapping_add(*v as i64) + .wrapping_add(self.block.min_delta); + *v = self.last_value as u64; + }); + gatherer.gather_chunk(target, bytemuck::cast_ref(&chunk))?; + n -= ITEMS_PER_PACK; + } + + if n == 0 { + return Ok(()); + } + + let Some((chunk, len)) = self.block.miniblock.decoder.chunked().next_inexact() else { + debug_assert_eq!(n, 0); + self.block.miniblock.buffered = ::Unpacked::zero(); + self.block.miniblock.unpacked_start = 0; + self.block.miniblock.unpacked_end = 0; + return Ok(()); + }; + + self.block.miniblock.buffered = chunk; + self.block.miniblock.unpacked_start = 0; + self.block.miniblock.unpacked_end = len; + + if n > 0 { + let length = usize::min(n, self.block.miniblock.unpacked_end); + self.block.miniblock.buffered[..length] + .iter_mut() + .for_each(|v| { + self.last_value = self + .last_value + .wrapping_add(*v as i64) + .wrapping_add(self.block.min_delta); + *v = self.last_value as u64; + }); + gatherer.gather_slice( + target, + bytemuck::cast_slice(&self.block.miniblock.buffered[..length]), + )?; + self.block.miniblock.unpacked_start = length; + } + + Ok(()) + } + + /// Gather `n` elements from the current [`Block`] to `target` + fn gather_block_n_into( + &mut self, + target: &mut G::Target, + n: usize, + gatherer: &mut G, + ) -> ParquetResult<()> { + let values_per_miniblock = self.values_per_miniblock(); + + debug_assert!(n <= self.values_per_block); + debug_assert!(self.values_per_block >= values_per_miniblock); + debug_assert_eq!(self.values_per_block % values_per_miniblock, 0); + + let mut n = usize::min(self.block.values_remaining, n); + + if n == 0 { + return Ok(()); + } + + let miniblock_len = self.miniblock_len(); + if n < miniblock_len { + self.gather_miniblock_n_into(target, n, gatherer)?; + debug_assert_eq!(self.miniblock_len(), miniblock_len - n); + self.block.values_remaining -= n; + return Ok(()); + } + + if miniblock_len > 0 { + self.gather_miniblock_n_into(target, miniblock_len, gatherer)?; + n -= miniblock_len; + self.block.values_remaining -= miniblock_len; + } + + while n >= values_per_miniblock { + let bitwidth = self.block.bitwidths[0]; + self.block.bitwidths = &self.block.bitwidths[1..]; + + let miniblock; + (miniblock, self.values) = self + .values + .split_at((bitwidth as usize * values_per_miniblock).div_ceil(8)); + gather_miniblock( + target, + self.block.min_delta, + bitwidth, + miniblock, + values_per_miniblock, + &mut self.last_value, + gatherer, + )?; + n -= values_per_miniblock; + self.block.values_remaining -= values_per_miniblock; + } + + if n == 0 { + return Ok(()); + } + + if !self.block.bitwidths.is_empty() { + let bitwidth = self.block.bitwidths[0]; + self.block.bitwidths = &self.block.bitwidths[1..]; + + if bitwidth > MAX_BITWIDTH { + return Err(ParquetError::oos(format!( + "Delta encoding bitwidth '{bitwidth}' is larger than maximum {MAX_BITWIDTH})" + ))); + } + + let length = usize::min(values_per_miniblock, self.block.values_remaining); + + let num_allocated_bytes = (bitwidth as usize * values_per_miniblock).div_ceil(8); + let num_actual_bytes = (bitwidth as usize * length).div_ceil(8); + + let miniblock; + (miniblock, self.values) = + self.values + .split_at_checked(num_allocated_bytes) + .ok_or(ParquetError::oos( + "Not enough space for delta encoded miniblock", + ))?; + + let miniblock = &miniblock[..num_actual_bytes]; + + let decoder = + bitpacked::Decoder::try_new_allow_zero(miniblock, bitwidth as usize, length)?; + self.block.miniblock = MiniBlock { + decoder, + buffered: self.block.miniblock.buffered, + unpacked_start: 0, + unpacked_end: 0, + }; + + if n > 0 { + self.gather_miniblock_n_into(target, n, gatherer)?; + self.block.values_remaining -= n; + } + } + + Ok(()) + } + + /// Gather `n` elements to `target` + pub fn gather_n_into( + &mut self, + target: &mut G::Target, + mut n: usize, + gatherer: &mut G, + ) -> ParquetResult<()> { + n = usize::min(n, self.len()); + + if n == 0 { + return Ok(()); + } + + let values_per_miniblock = self.values_per_block / self.num_miniblocks_per_block; + + let start_num_values_remaining = self.block.values_remaining; + if n <= self.block.values_remaining { + self.gather_block_n_into(target, n, gatherer)?; + debug_assert_eq!(self.block.values_remaining, start_num_values_remaining - n); + return Ok(()); + } + + n -= self.block.values_remaining; + self.gather_block_n_into(target, self.block.values_remaining, gatherer)?; + debug_assert_eq!(self.block.values_remaining, 0); + + while usize::min(n, self.values_remaining) >= self.values_per_block { + self.values = gather_block( + target, + self.num_miniblocks_per_block, + values_per_miniblock, + self.values, + &mut self.last_value, + gatherer, + )?; + n -= self.values_per_block; + self.values_remaining -= self.values_per_block; + } + + if n == 0 { + return Ok(()); + } + + self.consume_block(); + self.gather_block_n_into(target, n, gatherer)?; + + Ok(()) + } + + pub(crate) fn collect_n>( + &mut self, + e: &mut E, + n: usize, + ) -> ParquetResult<()> { + struct ExtendGatherer<'a, E: std::fmt::Debug + Extend>( + std::marker::PhantomData<&'a E>, + ); + + impl<'a, E: std::fmt::Debug + Extend> DeltaGatherer for ExtendGatherer<'a, E> { + type Target = (usize, &'a mut E); + + fn target_len(&self, target: &Self::Target) -> usize { + target.0 + } + + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + target.1.extend(Some(v)); + target.0 += 1; + Ok(()) + } + } + + let mut gatherer = ExtendGatherer(std::marker::PhantomData); + let mut target = (0, e); + + self.gather_n_into(&mut target, n, &mut gatherer) + } + + pub(crate) fn collect + Default>( + mut self, + ) -> ParquetResult { + let mut e = E::default(); + self.collect_n(&mut e, self.len())?; + Ok(e) + } + + pub fn len(&self) -> usize { + self.values_remaining + self.block.values_remaining + } + + fn values_per_miniblock(&self) -> usize { + debug_assert_eq!(self.values_per_block % self.num_miniblocks_per_block, 0); + self.values_per_block / self.num_miniblocks_per_block + } + + fn miniblock_len(&self) -> usize { + self.block.miniblock.unpacked_end - self.block.miniblock.unpacked_start + + self.block.miniblock.decoder.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn single_value() { + // Generated by parquet-rs + // + // header: [128, 1, 4, 1, 2] + // block size: 128, 1 + // mini-blocks: 4 + // elements: 1 + // first_value: 2 <=z> 1 + let data = &[128, 1, 4, 1, 2]; + + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); + + assert_eq!(&r[..], &[1]); + assert_eq!(data.len() - rem.len(), 5); + } + + #[test] + fn test_from_spec() { + let expected = (1..=5).collect::>(); + // VALIDATED FROM SPARK==3.1.1 + // header: [128, 1, 4, 5, 2] + // block size: 128, 1 + // mini-blocks: 4 + // elements: 5 + // first_value: 2 <=z> 1 + // block1: [2, 0, 0, 0, 0] + // min_delta: 2 <=z> 1 + // bit_width: 0 + let data = &[128, 1, 4, 5, 2, 2, 0, 0, 0, 0]; + + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); + + assert_eq!(expected, r); + + assert_eq!(data.len() - rem.len(), 10); + } + + #[test] + fn case2() { + let expected = vec![1, 2, 3, 4, 5, 1]; + // VALIDATED FROM SPARK==3.1.1 + // header: [128, 1, 4, 6, 2] + // block size: 128, 1 <=u> 128 + // mini-blocks: 4 <=u> 4 + // elements: 6 <=u> 6 + // first_value: 2 <=z> 1 + // block1: [7, 3, 0, 0, 0] + // min_delta: 7 <=z> -4 + // bit_widths: [3, 0, 0, 0] + // values: [ + // 0b01101101 + // 0b00001011 + // ... + // ] <=b> [3, 3, 3, 3, 0] + let data = &[ + 128, 1, 4, 6, 2, 7, 3, 0, 0, 0, 0b01101101, 0b00001011, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + // these should not be consumed + 1, 2, 3, + ]; + + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); + + assert_eq!(expected, r); + assert_eq!(rem, &[1, 2, 3]); + } + + #[test] + fn multiple_miniblocks() { + #[rustfmt::skip] + let data = &[ + // Header: [128, 1, 4, 65, 100] + 128, 1, // block size <=u> 128 + 4, // number of mini-blocks <=u> 4 + 65, // number of elements <=u> 65 + 100, // first_value <=z> 50 + + // Block 1 header: [7, 3, 4, 0, 0] + 7, // min_delta <=z> -4 + 3, 4, 255, 0, // bit_widths (255 should not be used as only two miniblocks are needed) + + // 32 3-bit values of 0 for mini-block 1 (12 bytes) + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + // 32 4-bit values of 8 for mini-block 2 (16 bytes) + 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, + 0x88, 0x88, + + // these should not be consumed + 1, 2, 3, + ]; + + #[rustfmt::skip] + let expected = [ + // First value + 50, + + // Mini-block 1: 32 deltas of -4 + 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2, -2, -6, -10, -14, -18, -22, -26, -30, -34, + -38, -42, -46, -50, -54, -58, -62, -66, -70, -74, -78, + + // Mini-block 2: 32 deltas of 4 + -74, -70, -66, -62, -58, -54, -50, -46, -42, -38, -34, -30, -26, -22, -18, -14, -10, -6, + -2, 2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, + ]; + + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); + + assert_eq!(&expected[..], &r[..]); + assert_eq!(data.len() - rem.len(), data.len() - 3); + assert_eq!(rem.len(), 3); + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs new file mode 100644 index 000000000000..fdcf50cd52e3 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs @@ -0,0 +1,146 @@ +use super::super::{bitpacked, uleb128, zigzag_leb128}; +use crate::parquet::encoding::ceil8; + +/// Encodes an iterator of `i64` according to parquet's `DELTA_BINARY_PACKED`. +/// # Implementation +/// * This function does not allocate on the heap. +/// * The number of mini-blocks is always 1. This may change in the future. +pub fn encode>( + mut iterator: I, + buffer: &mut Vec, + num_miniblocks_per_block: usize, +) { + const BLOCK_SIZE: usize = 256; + assert!([1, 2, 4].contains(&num_miniblocks_per_block)); + let values_per_miniblock = BLOCK_SIZE / num_miniblocks_per_block; + + let mut container = [0u8; 10]; + let encoded_len = uleb128::encode(BLOCK_SIZE as u64, &mut container); + buffer.extend_from_slice(&container[..encoded_len]); + + let encoded_len = uleb128::encode(num_miniblocks_per_block as u64, &mut container); + buffer.extend_from_slice(&container[..encoded_len]); + + let length = iterator.len(); + let encoded_len = uleb128::encode(length as u64, &mut container); + buffer.extend_from_slice(&container[..encoded_len]); + + let mut values = [0i64; BLOCK_SIZE]; + let mut deltas = [0u64; BLOCK_SIZE]; + let mut num_bits = [0u8; 4]; + + let first_value = iterator.next().unwrap_or_default(); + let (container, encoded_len) = zigzag_leb128::encode(first_value); + buffer.extend_from_slice(&container[..encoded_len]); + + let mut prev = first_value; + let mut length = iterator.len(); + while length != 0 { + let mut min_delta = i64::MAX; + let mut max_delta = i64::MIN; + for (i, integer) in iterator.by_ref().enumerate().take(BLOCK_SIZE) { + if i % values_per_miniblock == 0 { + min_delta = i64::MAX; + max_delta = i64::MIN + } + + let delta = integer.wrapping_sub(prev); + min_delta = min_delta.min(delta); + max_delta = max_delta.max(delta); + + let miniblock_idx = i / values_per_miniblock; + num_bits[miniblock_idx] = (64 - max_delta.abs_diff(min_delta).leading_zeros()) as u8; + values[i] = delta; + prev = integer; + } + let consumed = std::cmp::min(length - iterator.len(), BLOCK_SIZE); + length = iterator.len(); + let values = &values[..consumed]; + + values.iter().zip(deltas.iter_mut()).for_each(|(v, delta)| { + *delta = v.wrapping_sub(min_delta) as u64; + }); + + // + let (container, encoded_len) = zigzag_leb128::encode(min_delta); + buffer.extend_from_slice(&container[..encoded_len]); + + // one miniblock => 1 byte + let mut values_remaining = consumed; + buffer.extend_from_slice(&num_bits[..num_miniblocks_per_block]); + for i in 0..num_miniblocks_per_block { + if values_remaining == 0 { + break; + } + + values_remaining = values_remaining.saturating_sub(values_per_miniblock); + write_miniblock( + buffer, + num_bits[i], + &deltas[i * values_per_miniblock..(i + 1) * values_per_miniblock], + ); + } + } +} + +fn write_miniblock(buffer: &mut Vec, num_bits: u8, deltas: &[u64]) { + let num_bits = num_bits as usize; + if num_bits > 0 { + let start = buffer.len(); + + // bitpack encode all (deltas.len = 128 which is a multiple of 32) + let bytes_needed = start + ceil8(deltas.len() * num_bits); + buffer.resize(bytes_needed, 0); + bitpacked::encode(deltas, num_bits, &mut buffer[start..]); + + let bytes_needed = start + ceil8(deltas.len() * num_bits); + buffer.truncate(bytes_needed); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constant_delta() { + // header: [128, 2, 1, 5, 2]: + // block size: 256 <=u> 128, 2 + // mini-blocks: 1 <=u> 1 + // elements: 5 <=u> 5 + // first_value: 2 <=z> 1 + // block1: [2, 0, 0, 0, 0] + // min_delta: 1 <=z> 2 + // bitwidth: 0 + let data = 1..=5; + let expected = vec![128u8, 2, 1, 5, 2, 2, 0]; + + let mut buffer = vec![]; + encode(data.collect::>().into_iter(), &mut buffer, 1); + assert_eq!(expected, buffer); + } + + #[test] + fn negative_min_delta() { + // max - min = 1 - -4 = 5 + let data = vec![1, 2, 3, 4, 5, 1]; + // header: [128, 2, 4, 6, 2] + // block size: 256 <=u> 128, 2 + // mini-blocks: 1 <=u> 1 + // elements: 6 <=u> 5 + // first_value: 2 <=z> 1 + // block1: [7, 3, 253, 255] + // min_delta: -4 <=z> 7 + // bitwidth: 3 + // values: [5, 5, 5, 5, 0] <=b> [ + // 0b01101101 + // 0b00001011 + // ] + let mut expected = vec![128u8, 2, 1, 6, 2, 7, 3, 0b01101101, 0b00001011]; + expected.extend(std::iter::repeat_n(0, 256 * 3 / 8 - 2)); // 128 values, 3 bits, 2 already used + + let mut buffer = vec![]; + encode(data.into_iter(), &mut buffer, 1); + assert_eq!(expected, buffer); + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs new file mode 100644 index 000000000000..040909a336bb --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs @@ -0,0 +1,151 @@ +mod decoder; +mod encoder; + +pub(crate) use decoder::{Decoder, SumGatherer}; +pub(crate) use encoder::encode; + +/// The sum of `start, start + delta, start + 2 * delta, ... len times`. +pub(crate) fn lin_natural_sum(start: i64, delta: i64, len: usize) -> i64 { + debug_assert!(len < i64::MAX as usize); + + let base = start * len as i64; + let sum = if len == 0 { + 0 + } else { + let is_odd = len & 1; + // SUM_i=0^n f * i = f * (n(n+1)/2) + let sum = (len >> (is_odd ^ 1)) * (len.wrapping_sub(1) >> is_odd); + delta * sum as i64 + }; + + base + sum +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parquet::error::{ParquetError, ParquetResult}; + + #[test] + fn linear_natural_sum() { + assert_eq!(lin_natural_sum(0, 0, 0), 0); + assert_eq!(lin_natural_sum(10, 4, 0), 0); + assert_eq!(lin_natural_sum(0, 1, 1), 0); + assert_eq!(lin_natural_sum(0, 1, 3), 3); + assert_eq!(lin_natural_sum(0, 1, 4), 6); + assert_eq!(lin_natural_sum(0, 2, 3), 6); + assert_eq!(lin_natural_sum(2, 2, 3), 12); + } + + #[test] + fn basic() -> Result<(), ParquetError> { + let data = vec![1, 3, 1, 2, 3]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + Ok(()) + } + + #[test] + fn negative_value() -> Result<(), ParquetError> { + let data = vec![1, 3, -1, 2, 3]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + Ok(()) + } + + #[test] + fn some() -> Result<(), ParquetError> { + let data = vec![ + -2147483648, + -1777158217, + -984917788, + -1533539476, + -731221386, + -1322398478, + 906736096, + ]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + Ok(()) + } + + #[test] + fn more_than_one_block() -> Result<(), ParquetError> { + let mut data = vec![1, 3, -1, 2, 3, 10, 1]; + for x in 0..128 { + data.push(x - 10) + } + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + Ok(()) + } + + #[test] + fn test_another() -> Result<(), ParquetError> { + let data = vec![2, 3, 1, 2, 1]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + + Ok(()) + } + + #[test] + fn overflow_constant() -> ParquetResult<()> { + let data = vec![i64::MIN, i64::MAX, i64::MIN, i64::MAX]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + + Ok(()) + } + + #[test] + fn overflow_vary() -> ParquetResult<()> { + let data = vec![ + 0, + i64::MAX, + i64::MAX - 1, + i64::MIN + 1, + i64::MAX, + i64::MIN + 2, + ]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs new file mode 100644 index 000000000000..bcdf9403b7be --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs @@ -0,0 +1,136 @@ +use super::super::delta_bitpacked; +use crate::parquet::encoding::delta_bitpacked::SumGatherer; +use crate::parquet::error::ParquetResult; + +/// Decodes according to [Delta strings](https://github.com/apache/parquet-format/blob/master/Encodings.md#delta-strings-delta_byte_array--7), +/// prefixes, lengths and values +/// # Implementation +/// This struct does not allocate on the heap. +#[derive(Debug)] +pub struct Decoder<'a> { + pub(crate) prefix_lengths: delta_bitpacked::Decoder<'a>, + pub(crate) suffix_lengths: delta_bitpacked::Decoder<'a>, + pub(crate) values: &'a [u8], + + pub(crate) offset: usize, + pub(crate) last: Vec, +} + +impl<'a> Decoder<'a> { + pub fn try_new(values: &'a [u8]) -> ParquetResult { + let (prefix_lengths, values) = delta_bitpacked::Decoder::try_new(values)?; + let (suffix_lengths, values) = delta_bitpacked::Decoder::try_new(values)?; + + Ok(Self { + prefix_lengths, + suffix_lengths, + values, + + offset: 0, + last: Vec::with_capacity(32), + }) + } + + pub fn values(&self) -> &'a [u8] { + self.values + } + + pub fn len(&self) -> usize { + debug_assert_eq!(self.prefix_lengths.len(), self.suffix_lengths.len()); + self.prefix_lengths.len() + } + + pub fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + let mut prefix_sum = 0usize; + self.prefix_lengths + .gather_n_into(&mut prefix_sum, n, &mut SumGatherer(0))?; + let mut suffix_sum = 0usize; + self.suffix_lengths + .gather_n_into(&mut suffix_sum, n, &mut SumGatherer(0))?; + self.offset += prefix_sum + suffix_sum; + Ok(()) + } +} + +impl Iterator for Decoder<'_> { + type Item = ParquetResult>; + + fn next(&mut self) -> Option { + if self.len() == 0 { + return None; + } + + let mut prefix_length = vec![]; + let mut suffix_length = vec![]; + if let Err(e) = self.prefix_lengths.collect_n(&mut prefix_length, 1) { + return Some(Err(e)); + } + if let Err(e) = self.suffix_lengths.collect_n(&mut suffix_length, 1) { + return Some(Err(e)); + } + let prefix_length = prefix_length[0]; + let suffix_length = suffix_length[0]; + + let prefix_length = prefix_length as usize; + let suffix_length = suffix_length as usize; + + let mut value = Vec::with_capacity(prefix_length + suffix_length); + + value.extend_from_slice(&self.last[..prefix_length]); + value.extend_from_slice(&self.values[self.offset..self.offset + suffix_length]); + + self.last.clear(); + self.last.extend_from_slice(&value); + + self.offset += suffix_length; + + Some(Ok(value)) + } + + fn size_hint(&self) -> (usize, Option) { + (self.prefix_lengths.len(), Some(self.prefix_lengths.len())) + } +} + +impl ExactSizeIterator for Decoder<'_> {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bla() -> ParquetResult<()> { + // VALIDATED from spark==3.1.1 + let data = &[ + 128, 1, 4, 2, 0, 0, 0, 0, 0, 0, 128, 1, 4, 2, 10, 0, 0, 0, 0, 0, 72, 101, 108, 108, + 111, 87, 111, 114, 108, 100, + // extra bytes are not from spark, but they should be ignored by the decoder + // because they are beyond the sum of all lengths. + 1, 2, 3, + ]; + + let decoder = Decoder::try_new(data)?; + let values = decoder.collect::, _>>()?; + assert_eq!(values, vec![b"Hello".to_vec(), b"World".to_vec()]); + + Ok(()) + } + + #[test] + fn test_with_prefix() -> ParquetResult<()> { + // VALIDATED from spark==3.1.1 + let data = &[ + 128, 1, 4, 2, 0, 6, 0, 0, 0, 0, 128, 1, 4, 2, 10, 4, 0, 0, 0, 0, 72, 101, 108, 108, + 111, 105, 99, 111, 112, 116, 101, 114, + // extra bytes are not from spark, but they should be ignored by the decoder + // because they are beyond the sum of all lengths. + 1, 2, 3, + ]; + + let decoder = Decoder::try_new(data)?; + let prefixes = decoder.collect::, _>>()?; + assert_eq!(prefixes, vec![b"Hello".to_vec(), b"Helicopter".to_vec()]); + + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs new file mode 100644 index 000000000000..3a36e90b9966 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs @@ -0,0 +1,35 @@ +use super::super::delta_bitpacked; +use crate::parquet::encoding::delta_length_byte_array; + +/// Encodes an iterator of according to DELTA_BYTE_ARRAY +pub fn encode<'a, I: ExactSizeIterator + Clone>( + iterator: I, + buffer: &mut Vec, +) { + let mut previous = b"".as_ref(); + + let mut sum_lengths = 0; + let prefixes = iterator + .clone() + .map(|item| { + let prefix_length = item + .iter() + .zip(previous.iter()) + .enumerate() + // find first difference + .find_map(|(length, (lhs, rhs))| (lhs != rhs).then_some(length)) + .unwrap_or(previous.len()); + previous = item; + + sum_lengths += item.len() - prefix_length; + prefix_length as i64 + }) + .collect::>(); + delta_bitpacked::encode(prefixes.iter().copied(), buffer, 1); + + let remaining = iterator + .zip(prefixes) + .map(|(item, prefix)| &item[prefix as usize..]); + + delta_length_byte_array::encode(remaining, buffer); +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs new file mode 100644 index 000000000000..2bb51511d67e --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs @@ -0,0 +1,27 @@ +mod decoder; +mod encoder; + +pub use decoder::Decoder; +pub use encoder::encode; + +#[cfg(test)] +mod tests { + use super::*; + use crate::parquet::error::ParquetError; + + #[test] + fn basic() -> Result<(), ParquetError> { + let data = vec![b"Hello".as_ref(), b"Helicopter"]; + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer); + + let mut decoder = Decoder::try_new(&buffer)?; + let prefixes = decoder.by_ref().collect::, _>>()?; + assert_eq!(prefixes, vec![b"Hello".to_vec(), b"Helicopter".to_vec()]); + + // move to the values + let values = decoder.values(); + assert_eq!(values, b"Helloicopter"); + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs new file mode 100644 index 000000000000..97e2e9c49b8a --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs @@ -0,0 +1,50 @@ +use super::super::delta_bitpacked; +use crate::parquet::error::ParquetResult; + +/// Decodes [Delta-length byte array](https://github.com/apache/parquet-format/blob/master/Encodings.md#delta-length-byte-array-delta_length_byte_array--6) +/// lengths and values. +/// # Implementation +/// This struct does not allocate on the heap. +#[derive(Debug)] +pub(crate) struct Decoder<'a> { + pub(crate) lengths: delta_bitpacked::Decoder<'a>, + pub(crate) values: &'a [u8], + #[cfg(test)] + pub(crate) offset: usize, +} + +impl<'a> Decoder<'a> { + pub fn try_new(values: &'a [u8]) -> ParquetResult { + let (lengths, values) = delta_bitpacked::Decoder::try_new(values)?; + Ok(Self { + lengths, + values, + #[cfg(test)] + offset: 0, + }) + } + + pub fn len(&self) -> usize { + self.lengths.len() + } +} + +#[cfg(test)] +impl<'a> Iterator for Decoder<'a> { + type Item = ParquetResult<&'a [u8]>; + + fn next(&mut self) -> Option { + if self.lengths.len() == 0 { + return None; + } + + let mut length = vec![]; + if let Err(e) = self.lengths.collect_n(&mut length, 1) { + return Some(Err(e)); + } + let length = length[0] as usize; + let value = &self.values[self.offset..self.offset + length]; + self.offset += length; + Some(Ok(value)) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs new file mode 100644 index 000000000000..4e57c699504f --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs @@ -0,0 +1,23 @@ +use crate::parquet::encoding::delta_bitpacked; + +/// Encodes a cloneable iterator of `&[u8]` into `buffer`. This does not allocated on the heap. +/// # Implementation +/// This encoding is equivalent to call [`delta_bitpacked::encode`] on the lengths of the items +/// of the iterator followed by extending the buffer from each item of the iterator. +pub fn encode, I: ExactSizeIterator + Clone>( + iterator: I, + buffer: &mut Vec, +) { + let mut total_length = 0; + delta_bitpacked::encode( + iterator.clone().map(|x| { + let len = x.as_ref().len(); + total_length += len; + len as i64 + }), + buffer, + 1, + ); + buffer.reserve(total_length); + iterator.for_each(|x| buffer.extend(x.as_ref())) +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs new file mode 100644 index 000000000000..050ac766f545 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs @@ -0,0 +1,60 @@ +mod decoder; +mod encoder; + +pub(crate) use decoder::Decoder; +pub(crate) use encoder::encode; + +#[cfg(test)] +mod tests { + use super::*; + use crate::parquet::error::ParquetError; + + #[test] + fn basic() -> Result<(), ParquetError> { + let data = vec!["aa", "bbb", "a", "aa", "b"]; + + let mut buffer = vec![]; + encode(data.into_iter().map(|x| x.as_bytes()), &mut buffer); + + let mut iter = Decoder::try_new(&buffer)?; + + let result = iter.by_ref().collect::, _>>()?; + assert_eq!( + result, + vec![ + b"aa".as_ref(), + b"bbb".as_ref(), + b"a".as_ref(), + b"aa".as_ref(), + b"b".as_ref() + ] + ); + + let result = iter.values; + assert_eq!(result, b"aabbbaaab".as_ref()); + Ok(()) + } + + #[test] + fn many_numbers() -> Result<(), ParquetError> { + let mut data = vec![]; + for i in 0..136 { + data.push(format!("a{}", i)) + } + + let expected = data + .iter() + .map(|v| v.as_bytes().to_vec()) + .collect::>(); + + let mut buffer = vec![]; + encode(data.into_iter(), &mut buffer); + + let mut iter = Decoder::try_new(&buffer)?; + + let result = iter.by_ref().collect::, _>>()?; + assert_eq!(result, expected); + + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs new file mode 100644 index 000000000000..0d67dc935857 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/bitmap.rs @@ -0,0 +1,102 @@ +use std::io::Write; + +const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; + +/// Sets bit at position `i` in `byte` +#[inline] +pub fn set(byte: u8, i: usize) -> u8 { + byte | BIT_MASK[i] +} + +/// An [`Iterator`] of bool that decodes a bitmap. +/// This is a specialization of [`super::super::bitpacked::Decoder`] for `num_bits == 1`. +#[derive(Debug)] +pub struct BitmapIter<'a> { + iter: std::slice::Iter<'a, u8>, + current_byte: &'a u8, + remaining: usize, + mask: u8, +} + +impl<'a> BitmapIter<'a> { + /// Returns a new [`BitmapIter`]. + /// # Panics + /// This function panics iff `offset / 8 > slice.len()` + #[inline] + pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self { + let bytes = &slice[offset / 8..]; + + let mut iter = bytes.iter(); + + let current_byte = iter.next().unwrap_or(&0); + + Self { + iter, + mask: 1u8.rotate_left(offset as u32), + remaining: len, + current_byte, + } + } +} + +impl Iterator for BitmapIter<'_> { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + // easily predictable in branching + if self.remaining == 0 { + return None; + } else { + self.remaining -= 1; + } + let value = self.current_byte & self.mask != 0; + self.mask = self.mask.rotate_left(1); + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + if let Some(v) = self.iter.next() { + self.current_byte = v + } + } + Some(value) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +/// Writes an iterator of bools into writer, with LSB first. +pub fn encode_bool>( + writer: &mut W, + mut iterator: I, +) -> std::io::Result<()> { + // the length of the iterator. + let length = iterator.size_hint().1.unwrap(); + + let chunks = length / 8; + let reminder = length % 8; + + (0..chunks).try_for_each(|_| { + let mut byte = 0u8; + (0..8).for_each(|i| { + if iterator.next().unwrap() { + byte = set(byte, i) + } + }); + writer.write_all(&[byte]) + })?; + + if reminder != 0 { + let mut last = 0u8; + iterator.enumerate().for_each(|(i, value)| { + if value { + last = set(last, i) + } + }); + writer.write_all(&[last]) + } else { + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs new file mode 100644 index 000000000000..7a15d1fac520 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs @@ -0,0 +1,314 @@ +use std::io::Write; + +use super::bitpacked_encode; +use crate::parquet::encoding::{bitpacked, ceil8, uleb128}; + +// Arbitrary value that balances memory usage and storage overhead +const MAX_VALUES_PER_LITERAL_RUN: usize = (1 << 10) * 8; + +pub trait Encoder { + fn bitpacked_encode>( + writer: &mut W, + iterator: I, + num_bits: usize, + ) -> std::io::Result<()>; + + fn run_length_encode( + writer: &mut W, + run_length: usize, + value: T, + bit_width: u32, + ) -> std::io::Result<()>; +} + +const U32_BLOCK_LEN: usize = 32; + +impl Encoder for u32 { + fn bitpacked_encode>( + writer: &mut W, + mut iterator: I, + num_bits: usize, + ) -> std::io::Result<()> { + // the length of the iterator. + let length = iterator.size_hint().1.unwrap(); + + let mut header = ceil8(length) as u64; + header <<= 1; + header |= 1; // it is bitpacked => first bit is set + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + + let chunks = length / U32_BLOCK_LEN; + let remainder = length - chunks * U32_BLOCK_LEN; + let mut buffer = [0u32; U32_BLOCK_LEN]; + + // simplified from ceil8(U32_BLOCK_LEN * num_bits) since U32_BLOCK_LEN = 32 + let compressed_chunk_size = 4 * num_bits; + + for _ in 0..chunks { + iterator + .by_ref() + .take(U32_BLOCK_LEN) + .zip(buffer.iter_mut()) + .for_each(|(item, buf)| *buf = item); + + let mut packed = [0u8; 4 * U32_BLOCK_LEN]; + bitpacked::encode_pack::(&buffer, num_bits, packed.as_mut()); + writer.write_all(&packed[..compressed_chunk_size])?; + } + + if remainder != 0 { + // Must be careful here to ensure we write a multiple of `num_bits` + // (the bit width) to align with the spec. Some readers also rely on + // this - see https://github.com/pola-rs/polars/pull/13883. + + // this is ceil8(remainder * num_bits), but we ensure the output is a + // multiple of num_bits by rewriting it as ceil8(remainder) * num_bits + let compressed_remainder_size = ceil8(remainder) * num_bits; + iterator + .by_ref() + .take(remainder) + .zip(buffer.iter_mut()) + .for_each(|(item, buf)| *buf = item); + + let mut packed = [0u8; 4 * U32_BLOCK_LEN]; + // No need to zero rest of buffer because remainder is either: + // * Multiple of 8: We pad non-terminal literal runs to have a + // multiple of 8 values. Once compressed, the data will end on + // clean byte boundaries and packed[..compressed_remainder_size] + // will include only the remainder values and nothing extra. + // * Final run: Extra values from buffer will be included in + // packed[..compressed_remainder_size] but ignored when decoding + // because they extend beyond known column length + bitpacked::encode_pack(&buffer, num_bits, packed.as_mut()); + writer.write_all(&packed[..compressed_remainder_size])?; + }; + Ok(()) + } + + fn run_length_encode( + writer: &mut W, + run_length: usize, + value: u32, + bit_width: u32, + ) -> std::io::Result<()> { + // write the length + indicator + let mut header = run_length as u64; + header <<= 1; + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + + let num_bytes = ceil8(bit_width as usize); + let bytes = value.to_le_bytes(); + writer.write_all(&bytes[..num_bytes])?; + Ok(()) + } +} + +impl Encoder for bool { + fn bitpacked_encode>( + writer: &mut W, + iterator: I, + _num_bits: usize, + ) -> std::io::Result<()> { + // the length of the iterator. + let length = iterator.size_hint().1.unwrap(); + + let mut header = ceil8(length) as u64; + header <<= 1; + header |= 1; // it is bitpacked => first bit is set + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + bitpacked_encode(writer, iterator)?; + Ok(()) + } + + fn run_length_encode( + writer: &mut W, + run_length: usize, + value: bool, + _bit_width: u32, + ) -> std::io::Result<()> { + // write the length + indicator + let mut header = run_length as u64; + header <<= 1; + let mut container = [0; 10]; + let used = uleb128::encode(header, &mut container); + writer.write_all(&container[..used])?; + writer.write_all(&(value as u8).to_le_bytes())?; + Ok(()) + } +} + +#[allow(clippy::comparison_chain)] +pub fn encode, W: Write, I: Iterator>( + writer: &mut W, + iterator: I, + num_bits: u32, +) -> std::io::Result<()> { + let mut consecutive_repeats: usize = 0; + let mut previous_val = T::default(); + let mut buffered_bits = [previous_val; MAX_VALUES_PER_LITERAL_RUN]; + let mut buffer_idx = 0; + let mut literal_run_idx = 0; + for val in iterator { + if val == previous_val { + consecutive_repeats += 1; + if consecutive_repeats >= 8 { + // Run is long enough to RLE, no need to buffer values + if consecutive_repeats > 8 { + continue; + } else { + // When we encounter a run long enough to potentially RLE, + // we must first ensure that the buffered literal run has + // a multiple of 8 values for bit-packing. If not, we pad + // up by taking some of the consecutive repeats + let literal_padding = (8 - (literal_run_idx % 8)) % 8; + consecutive_repeats -= literal_padding; + literal_run_idx += literal_padding; + } + } + // Too short to RLE, continue to buffer values + } else if consecutive_repeats > 8 { + // Value changed so start a new run but the current run is long + // enough to RLE. First, bit-pack any buffered literal run. Then, + // RLE current run and reset consecutive repeat counter and buffer. + if literal_run_idx > 0 { + debug_assert!(literal_run_idx % 8 == 0); + T::bitpacked_encode( + writer, + buffered_bits.iter().take(literal_run_idx).copied(), + num_bits as usize, + )?; + literal_run_idx = 0; + } + T::run_length_encode(writer, consecutive_repeats, previous_val, num_bits)?; + consecutive_repeats = 1; + buffer_idx = 0; + } else { + // Value changed so start a new run but the current run is not long + // enough to RLE. Consolidate all consecutive repeats into buffered + // literal run. + literal_run_idx = buffer_idx; + consecutive_repeats = 1; + } + // If buffer is full, bit-pack as literal run and reset + if buffer_idx == MAX_VALUES_PER_LITERAL_RUN { + T::bitpacked_encode(writer, buffered_bits.iter().copied(), num_bits as usize)?; + // If buffer fills up in the middle of a run, all but the last + // repeat is consolidated into the literal run. + debug_assert!( + (consecutive_repeats < 8) + && (buffer_idx - literal_run_idx == consecutive_repeats - 1) + ); + consecutive_repeats = 1; + buffer_idx = 0; + literal_run_idx = 0; + } + buffered_bits[buffer_idx] = val; + previous_val = val; + buffer_idx += 1; + } + // Final run not long enough to RLE, extend literal run. + if consecutive_repeats <= 8 { + literal_run_idx = buffer_idx; + } + // Bit-pack final buffered literal run, if any + if literal_run_idx > 0 { + T::bitpacked_encode( + writer, + buffered_bits.iter().take(literal_run_idx).copied(), + num_bits as usize, + )?; + } + // RLE final consecutive run if long enough + if consecutive_repeats > 8 { + T::run_length_encode(writer, consecutive_repeats, previous_val, num_bits)?; + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::super::bitmap::BitmapIter; + use super::*; + + #[test] + fn bool_basics_1() -> std::io::Result<()> { + let iter = BitmapIter::new(&[0b10011101u8, 0b10011101], 0, 14); + + let mut vec = vec![]; + + encode::(&mut vec, iter, 1)?; + + assert_eq!(vec, vec![((2 << 1) | 1), 0b10011101u8, 0b00011101]); + + Ok(()) + } + + #[test] + fn bool_from_iter() -> std::io::Result<()> { + let mut vec = vec![]; + + encode::( + &mut vec, + vec![true, true, true, true, true, true, true, true].into_iter(), + 1, + )?; + + assert_eq!(vec, vec![((1 << 1) | 1), 0b11111111]); + Ok(()) + } + + #[test] + fn test_encode_u32() -> std::io::Result<()> { + let mut vec = vec![]; + + encode::(&mut vec, vec![0, 1, 2, 1, 2, 1, 1, 0, 3].into_iter(), 2)?; + + assert_eq!( + vec, + vec![ + ((2 << 1) | 1), + 0b01_10_01_00, + 0b00_01_01_10, + 0b_00_00_00_11, + 0b0 + ] + ); + Ok(()) + } + + #[test] + fn test_encode_u32_large() -> std::io::Result<()> { + let mut vec = vec![]; + + let values = (0..128).map(|x| x % 4); + + encode::(&mut vec, values, 2)?; + + let length = 128; + let expected = 0b11_10_01_00u8; + + let mut expected = vec![expected; length / 4]; + expected.insert(0, (((length / 8) as u8) << 1) | 1); + + assert_eq!(vec, expected); + Ok(()) + } + + #[test] + fn test_u32_other() -> std::io::Result<()> { + let values = vec![3, 3, 0, 3, 2, 3, 3, 3, 3, 1, 3, 3, 3, 0, 3].into_iter(); + + let mut vec = vec![]; + encode::(&mut vec, values, 2)?; + + let expected = vec![5, 207, 254, 247, 51]; + assert_eq!(expected, vec); + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs new file mode 100644 index 000000000000..66d30c5300c4 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -0,0 +1,350 @@ +// See https://github.com/apache/parquet-format/blob/master/Encodings.md#run-length-encoding--bit-packing-hybrid-rle--3 +mod bitmap; +mod encoder; + +pub use bitmap::{BitmapIter, encode_bool as bitpacked_encode}; +pub use encoder::{Encoder, encode}; + +use super::{bitpacked, uleb128}; +use crate::parquet::error::{ParquetError, ParquetResult}; + +/// A [`Iterator`] for Hybrid Run-Length Encoding +/// +/// The hybrid here means that each second is prepended by a bit that differentiates between two +/// modes. +/// +/// 1. Run-Length Encoding in the shape of `[Number of Values, Value]` +/// 2. Bitpacking in the shape of `[Value 1 in n bits, Value 2 in n bits, ...]` +/// +/// Note, that this can iterate, but the set of `collect_*` and `translate_and_collect_*` methods +/// should be highly preferred as they are way more efficient and have better error handling. +#[derive(Debug, Clone)] +pub struct HybridRleDecoder<'a> { + data: &'a [u8], + num_bits: usize, + num_values: usize, +} + +pub struct HybridRleChunkIter<'a> { + decoder: HybridRleDecoder<'a>, +} + +#[derive(Debug)] +pub enum HybridRleChunk<'a> { + Rle(u32, usize), + Bitpacked(bitpacked::Decoder<'a, u32>), +} + +impl<'a> Iterator for HybridRleChunkIter<'a> { + type Item = ParquetResult>; + + fn next(&mut self) -> Option { + self.decoder.next_chunk().transpose() + } +} + +impl HybridRleChunk<'_> { + #[inline] + pub fn len(&self) -> usize { + match self { + HybridRleChunk::Rle(_, size) => *size, + HybridRleChunk::Bitpacked(decoder) => decoder.len(), + } + } +} + +impl<'a> HybridRleDecoder<'a> { + /// Returns a new [`HybridRleDecoder`] + pub fn new(data: &'a [u8], num_bits: u32, num_values: usize) -> Self { + Self { + data, + num_bits: num_bits as usize, + num_values, + } + } + + pub fn len(&self) -> usize { + self.num_values + } + + pub fn num_bits(&self) -> usize { + self.num_bits + } + + pub fn into_chunk_iter(self) -> HybridRleChunkIter<'a> { + HybridRleChunkIter { decoder: self } + } + + pub fn next_chunk(&mut self) -> ParquetResult>> { + if self.len() == 0 { + return Ok(None); + } + + if self.num_bits == 0 { + let num_values = self.num_values; + self.num_values = 0; + return Ok(Some(HybridRleChunk::Rle(0, num_values))); + } + + if self.data.is_empty() { + return Ok(None); + } + + let (indicator, consumed) = uleb128::decode(self.data); + self.data = unsafe { self.data.get_unchecked(consumed..) }; + + Ok(Some(if indicator & 1 == 1 { + // is bitpacking + let bytes = (indicator as usize >> 1) * self.num_bits; + let bytes = std::cmp::min(bytes, self.data.len()); + let Some((packed, remaining)) = self.data.split_at_checked(bytes) else { + return Err(ParquetError::oos("Not enough bytes for bitpacked data")); + }; + self.data = remaining; + + let length = std::cmp::min(packed.len() * 8 / self.num_bits, self.num_values); + let decoder = bitpacked::Decoder::::try_new(packed, self.num_bits, length)?; + + self.num_values -= length; + + HybridRleChunk::Bitpacked(decoder) + } else { + // is rle + let run_length = indicator as usize >> 1; + // repeated-value := value that is repeated, using a fixed-width of round-up-to-next-byte(bit-width) + let rle_bytes = self.num_bits.div_ceil(8); + let Some((pack, remaining)) = self.data.split_at_checked(rle_bytes) else { + return Err(ParquetError::oos("Not enough bytes for RLE encoded data")); + }; + self.data = remaining; + + let mut bytes = [0u8; std::mem::size_of::()]; + pack.iter().zip(bytes.iter_mut()).for_each(|(src, dst)| { + *dst = *src; + }); + let value = u32::from_le_bytes(bytes); + + let length = std::cmp::min(run_length, self.num_values); + + self.num_values -= length; + + HybridRleChunk::Rle(value, length) + })) + } + + pub fn next_chunk_length(&mut self) -> ParquetResult> { + if self.len() == 0 { + return Ok(None); + } + + if self.num_bits == 0 { + let num_values = self.num_values; + self.num_values = 0; + return Ok(Some(num_values)); + } + + if self.data.is_empty() { + return Ok(None); + } + + let (indicator, consumed) = uleb128::decode(self.data); + self.data = unsafe { self.data.get_unchecked(consumed..) }; + + Ok(Some(if indicator & 1 == 1 { + // is bitpacking + let bytes = (indicator as usize >> 1) * self.num_bits; + let bytes = std::cmp::min(bytes, self.data.len()); + let Some((packed, remaining)) = self.data.split_at_checked(bytes) else { + return Err(ParquetError::oos("Not enough bytes for bitpacked data")); + }; + self.data = remaining; + + let length = std::cmp::min(packed.len() * 8 / self.num_bits, self.num_values); + self.num_values -= length; + + length + } else { + // is rle + let run_length = indicator as usize >> 1; + // repeated-value := value that is repeated, using a fixed-width of round-up-to-next-byte(bit-width) + let rle_bytes = self.num_bits.div_ceil(8); + let Some(remaining) = self.data.get(rle_bytes..) else { + return Err(ParquetError::oos("Not enough bytes for RLE encoded data")); + }; + self.data = remaining; + + let length = std::cmp::min(run_length, self.num_values); + self.num_values -= length; + + length + })) + } + + pub fn limit_to(&mut self, length: usize) { + self.num_values = self.num_values.min(length); + } + + pub fn collect(self) -> ParquetResult> { + let mut target = Vec::with_capacity(self.len()); + + for chunk in self.into_chunk_iter() { + match chunk? { + HybridRleChunk::Rle(value, size) => { + target.resize(target.len() + size, value); + }, + HybridRleChunk::Bitpacked(decoder) => { + decoder.collect_into(&mut target); + }, + } + } + + Ok(target) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn roundtrip() -> ParquetResult<()> { + let mut buffer = vec![]; + let num_bits = 10u32; + + let data = (0..1000).collect::>(); + + encode::(&mut buffer, data.iter().cloned(), num_bits).unwrap(); + + let decoder = HybridRleDecoder::new(&buffer, num_bits, data.len()); + + let result = decoder.collect()?; + + assert_eq!(result, data); + Ok(()) + } + + #[test] + fn pyarrow_integration() -> ParquetResult<()> { + // data encoded from pyarrow representing (0..1000) + let data = vec![ + 127, 0, 4, 32, 192, 0, 4, 20, 96, 192, 1, 8, 36, 160, 192, 2, 12, 52, 224, 192, 3, 16, + 68, 32, 193, 4, 20, 84, 96, 193, 5, 24, 100, 160, 193, 6, 28, 116, 224, 193, 7, 32, + 132, 32, 194, 8, 36, 148, 96, 194, 9, 40, 164, 160, 194, 10, 44, 180, 224, 194, 11, 48, + 196, 32, 195, 12, 52, 212, 96, 195, 13, 56, 228, 160, 195, 14, 60, 244, 224, 195, 15, + 64, 4, 33, 196, 16, 68, 20, 97, 196, 17, 72, 36, 161, 196, 18, 76, 52, 225, 196, 19, + 80, 68, 33, 197, 20, 84, 84, 97, 197, 21, 88, 100, 161, 197, 22, 92, 116, 225, 197, 23, + 96, 132, 33, 198, 24, 100, 148, 97, 198, 25, 104, 164, 161, 198, 26, 108, 180, 225, + 198, 27, 112, 196, 33, 199, 28, 116, 212, 97, 199, 29, 120, 228, 161, 199, 30, 124, + 244, 225, 199, 31, 128, 4, 34, 200, 32, 132, 20, 98, 200, 33, 136, 36, 162, 200, 34, + 140, 52, 226, 200, 35, 144, 68, 34, 201, 36, 148, 84, 98, 201, 37, 152, 100, 162, 201, + 38, 156, 116, 226, 201, 39, 160, 132, 34, 202, 40, 164, 148, 98, 202, 41, 168, 164, + 162, 202, 42, 172, 180, 226, 202, 43, 176, 196, 34, 203, 44, 180, 212, 98, 203, 45, + 184, 228, 162, 203, 46, 188, 244, 226, 203, 47, 192, 4, 35, 204, 48, 196, 20, 99, 204, + 49, 200, 36, 163, 204, 50, 204, 52, 227, 204, 51, 208, 68, 35, 205, 52, 212, 84, 99, + 205, 53, 216, 100, 163, 205, 54, 220, 116, 227, 205, 55, 224, 132, 35, 206, 56, 228, + 148, 99, 206, 57, 232, 164, 163, 206, 58, 236, 180, 227, 206, 59, 240, 196, 35, 207, + 60, 244, 212, 99, 207, 61, 248, 228, 163, 207, 62, 252, 244, 227, 207, 63, 0, 5, 36, + 208, 64, 4, 21, 100, 208, 65, 8, 37, 164, 208, 66, 12, 53, 228, 208, 67, 16, 69, 36, + 209, 68, 20, 85, 100, 209, 69, 24, 101, 164, 209, 70, 28, 117, 228, 209, 71, 32, 133, + 36, 210, 72, 36, 149, 100, 210, 73, 40, 165, 164, 210, 74, 44, 181, 228, 210, 75, 48, + 197, 36, 211, 76, 52, 213, 100, 211, 77, 56, 229, 164, 211, 78, 60, 245, 228, 211, 79, + 64, 5, 37, 212, 80, 68, 21, 101, 212, 81, 72, 37, 165, 212, 82, 76, 53, 229, 212, 83, + 80, 69, 37, 213, 84, 84, 85, 101, 213, 85, 88, 101, 165, 213, 86, 92, 117, 229, 213, + 87, 96, 133, 37, 214, 88, 100, 149, 101, 214, 89, 104, 165, 165, 214, 90, 108, 181, + 229, 214, 91, 112, 197, 37, 215, 92, 116, 213, 101, 215, 93, 120, 229, 165, 215, 94, + 124, 245, 229, 215, 95, 128, 5, 38, 216, 96, 132, 21, 102, 216, 97, 136, 37, 166, 216, + 98, 140, 53, 230, 216, 99, 144, 69, 38, 217, 100, 148, 85, 102, 217, 101, 152, 101, + 166, 217, 102, 156, 117, 230, 217, 103, 160, 133, 38, 218, 104, 164, 149, 102, 218, + 105, 168, 165, 166, 218, 106, 172, 181, 230, 218, 107, 176, 197, 38, 219, 108, 180, + 213, 102, 219, 109, 184, 229, 166, 219, 110, 188, 245, 230, 219, 111, 192, 5, 39, 220, + 112, 196, 21, 103, 220, 113, 200, 37, 167, 220, 114, 204, 53, 231, 220, 115, 208, 69, + 39, 221, 116, 212, 85, 103, 221, 117, 216, 101, 167, 221, 118, 220, 117, 231, 221, 119, + 224, 133, 39, 222, 120, 228, 149, 103, 222, 121, 232, 165, 167, 222, 122, 236, 181, + 231, 222, 123, 240, 197, 39, 223, 124, 244, 213, 103, 223, 125, 125, 248, 229, 167, + 223, 126, 252, 245, 231, 223, 127, 0, 6, 40, 224, 128, 4, 22, 104, 224, 129, 8, 38, + 168, 224, 130, 12, 54, 232, 224, 131, 16, 70, 40, 225, 132, 20, 86, 104, 225, 133, 24, + 102, 168, 225, 134, 28, 118, 232, 225, 135, 32, 134, 40, 226, 136, 36, 150, 104, 226, + 137, 40, 166, 168, 226, 138, 44, 182, 232, 226, 139, 48, 198, 40, 227, 140, 52, 214, + 104, 227, 141, 56, 230, 168, 227, 142, 60, 246, 232, 227, 143, 64, 6, 41, 228, 144, 68, + 22, 105, 228, 145, 72, 38, 169, 228, 146, 76, 54, 233, 228, 147, 80, 70, 41, 229, 148, + 84, 86, 105, 229, 149, 88, 102, 169, 229, 150, 92, 118, 233, 229, 151, 96, 134, 41, + 230, 152, 100, 150, 105, 230, 153, 104, 166, 169, 230, 154, 108, 182, 233, 230, 155, + 112, 198, 41, 231, 156, 116, 214, 105, 231, 157, 120, 230, 169, 231, 158, 124, 246, + 233, 231, 159, 128, 6, 42, 232, 160, 132, 22, 106, 232, 161, 136, 38, 170, 232, 162, + 140, 54, 234, 232, 163, 144, 70, 42, 233, 164, 148, 86, 106, 233, 165, 152, 102, 170, + 233, 166, 156, 118, 234, 233, 167, 160, 134, 42, 234, 168, 164, 150, 106, 234, 169, + 168, 166, 170, 234, 170, 172, 182, 234, 234, 171, 176, 198, 42, 235, 172, 180, 214, + 106, 235, 173, 184, 230, 170, 235, 174, 188, 246, 234, 235, 175, 192, 6, 43, 236, 176, + 196, 22, 107, 236, 177, 200, 38, 171, 236, 178, 204, 54, 235, 236, 179, 208, 70, 43, + 237, 180, 212, 86, 107, 237, 181, 216, 102, 171, 237, 182, 220, 118, 235, 237, 183, + 224, 134, 43, 238, 184, 228, 150, 107, 238, 185, 232, 166, 171, 238, 186, 236, 182, + 235, 238, 187, 240, 198, 43, 239, 188, 244, 214, 107, 239, 189, 248, 230, 171, 239, + 190, 252, 246, 235, 239, 191, 0, 7, 44, 240, 192, 4, 23, 108, 240, 193, 8, 39, 172, + 240, 194, 12, 55, 236, 240, 195, 16, 71, 44, 241, 196, 20, 87, 108, 241, 197, 24, 103, + 172, 241, 198, 28, 119, 236, 241, 199, 32, 135, 44, 242, 200, 36, 151, 108, 242, 201, + 40, 167, 172, 242, 202, 44, 183, 236, 242, 203, 48, 199, 44, 243, 204, 52, 215, 108, + 243, 205, 56, 231, 172, 243, 206, 60, 247, 236, 243, 207, 64, 7, 45, 244, 208, 68, 23, + 109, 244, 209, 72, 39, 173, 244, 210, 76, 55, 237, 244, 211, 80, 71, 45, 245, 212, 84, + 87, 109, 245, 213, 88, 103, 173, 245, 214, 92, 119, 237, 245, 215, 96, 135, 45, 246, + 216, 100, 151, 109, 246, 217, 104, 167, 173, 246, 218, 108, 183, 237, 246, 219, 112, + 199, 45, 247, 220, 116, 215, 109, 247, 221, 120, 231, 173, 247, 222, 124, 247, 237, + 247, 223, 128, 7, 46, 248, 224, 132, 23, 110, 248, 225, 136, 39, 174, 248, 226, 140, + 55, 238, 248, 227, 144, 71, 46, 249, 228, 148, 87, 110, 249, 229, 152, 103, 174, 249, + 230, 156, 119, 238, 249, 231, 160, 135, 46, 250, 232, 164, 151, 110, 250, 233, 168, + 167, 174, 250, 234, 172, 183, 238, 250, 235, 176, 199, 46, 251, 236, 180, 215, 110, + 251, 237, 184, 231, 174, 251, 238, 188, 247, 238, 251, 239, 192, 7, 47, 252, 240, 196, + 23, 111, 252, 241, 200, 39, 175, 252, 242, 204, 55, 239, 252, 243, 208, 71, 47, 253, + 244, 212, 87, 111, 253, 245, 216, 103, 175, 253, 246, 220, 119, 239, 253, 247, 224, + 135, 47, 254, 248, 228, 151, 111, 254, 249, + ]; + let num_bits = 10; + + let decoder = HybridRleDecoder::new(&data, num_bits, 1000); + let result = decoder.collect()?; + + assert_eq!(result, (0..1000).collect::>()); + Ok(()) + } + + #[test] + fn small() -> ParquetResult<()> { + let data = vec![3, 2]; + + let num_bits = 3; + + let decoder = HybridRleDecoder::new(&data, num_bits, 1); + + let result = decoder.collect()?; + + assert_eq!(result, &[2]); + Ok(()) + } + + #[test] + fn zero_bit_width() -> ParquetResult<()> { + let data = vec![3]; + + let num_bits = 0; + + let decoder = HybridRleDecoder::new(&data, num_bits, 2); + + let result = decoder.collect()?; + + assert_eq!(result, &[0, 0]); + Ok(()) + } + + #[test] + fn empty_values() -> ParquetResult<()> { + let data = []; + + let num_bits = 0; + + let decoder = HybridRleDecoder::new(&data, num_bits, 100); + + let result = decoder.collect()?; + + assert_eq!(result, vec![0; 100]); + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/mod.rs b/crates/polars-parquet/src/parquet/encoding/mod.rs new file mode 100644 index 000000000000..4e31a44f8c9e --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/mod.rs @@ -0,0 +1,27 @@ +pub mod bitpacked; +pub mod byte_stream_split; +pub mod delta_bitpacked; +pub mod delta_byte_array; +pub mod delta_length_byte_array; +pub mod hybrid_rle; +pub mod plain_byte_array; +pub mod uleb128; +pub mod zigzag_leb128; + +pub use crate::parquet::parquet_bridge::Encoding; + +/// # Panics +/// This function panics iff `values.len() < 4`. +#[inline] +pub fn get_length(values: &[u8]) -> Option { + assert!(values.len() >= 4); + values + .get(0..4) + .map(|x| u32::from_le_bytes(x.try_into().unwrap()) as usize) +} + +/// Returns the ceil of value / 8 +#[inline] +pub fn ceil8(value: usize) -> usize { + value / 8 + ((value % 8 != 0) as usize) +} diff --git a/crates/polars-parquet/src/parquet/encoding/plain_byte_array.rs b/crates/polars-parquet/src/parquet/encoding/plain_byte_array.rs new file mode 100644 index 000000000000..36c317c3647d --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/plain_byte_array.rs @@ -0,0 +1,46 @@ +/// Decodes according to [Plain strings](https://github.com/apache/parquet-format/blob/master/Encodings.md#plain-plain--0), +/// prefixes, lengths and values +/// # Implementation +/// This struct does not allocate on the heap. +use crate::parquet::error::ParquetError; + +#[derive(Debug)] +pub struct BinaryIter<'a> { + values: &'a [u8], + length: Option, +} + +impl<'a> BinaryIter<'a> { + pub fn new(values: &'a [u8], length: Option) -> Self { + Self { values, length } + } +} + +impl<'a> Iterator for BinaryIter<'a> { + type Item = Result<&'a [u8], ParquetError>; + + #[inline] + fn next(&mut self) -> Option { + if self.values.len() < 4 { + return None; + } + if let Some(x) = self.length.as_mut() { + *x = x.saturating_sub(1) + } + let length = u32::from_le_bytes(self.values[0..4].try_into().unwrap()) as usize; + self.values = &self.values[4..]; + if length > self.values.len() { + return Some(Err(ParquetError::oos( + "A string in plain encoding declares a length that is out of range", + ))); + } + let (result, remaining) = self.values.split_at(length); + self.values = remaining; + Some(Ok(result)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.length.unwrap_or_default(), self.length) + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/uleb128.rs b/crates/polars-parquet/src/parquet/encoding/uleb128.rs new file mode 100644 index 000000000000..ec461075b8a7 --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/uleb128.rs @@ -0,0 +1,132 @@ +// Reads an uleb128 encoded integer with at most 56 bits (8 bytes with 7 bits worth of payload each). +/// Returns the integer and the number of bytes that made up this integer. +/// +/// If the returned length is bigger than 8 this means the integer required more than 8 bytes and the remaining bytes need to be read sequentially and combined with the return value. +/// +/// # Safety +/// `data` needs to contain at least 8 bytes. +#[target_feature(enable = "bmi2")] +#[cfg(target_feature = "bmi2")] +pub unsafe fn decode_uleb_bmi2(data: &[u8]) -> (u64, usize) { + const CONT_MARKER: u64 = 0x80808080_80808080; + debug_assert!(data.len() >= 8); + + unsafe { + let word = data.as_ptr().cast::().read_unaligned(); + // mask indicating continuation bytes + let mask = std::arch::x86_64::_pext_u64(word, CONT_MARKER); + let len = (!mask).trailing_zeros() + 1; + // which payload bits to extract + let ext = std::arch::x86_64::_bzhi_u64(!CONT_MARKER, 8 * len); + let payload = std::arch::x86_64::_pext_u64(word, ext); + + (payload, len as _) + } +} + +pub fn decode(values: &[u8]) -> (u64, usize) { + #[cfg(target_feature = "bmi2")] + { + if polars_utils::cpuid::has_fast_bmi2() && values.len() >= 8 { + let (result, consumed) = unsafe { decode_uleb_bmi2(values) }; + if consumed <= 8 { + return (result, consumed); + } + } + } + + let mut result = 0; + let mut shift = 0; + + let mut consumed = 0; + for byte in values { + consumed += 1; + + #[cfg(debug_assertions)] + debug_assert!(!(shift == 63 && *byte > 1)); + + result |= u64::from(byte & 0b01111111) << shift; + + if byte & 0b10000000 == 0 { + break; + } + + shift += 7; + } + (result, consumed) +} + +/// Encodes `value` in ULEB128 into `container`. The exact number of bytes written +/// depends on `value`, and cannot be determined upfront. The maximum number of bytes +/// required are 10. +/// # Panic +/// This function may panic if `container.len() < 10` and `value` requires more bytes. +pub fn encode(mut value: u64, container: &mut [u8]) -> usize { + assert!(container.len() >= 10); + let mut consumed = 0; + let mut iter = container.iter_mut(); + loop { + let mut byte = (value as u8) & !128; + value >>= 7; + if value != 0 { + byte |= 128; + } + *iter.next().unwrap() = byte; + consumed += 1; + if value == 0 { + break; + } + } + consumed +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decode_1() { + let data = vec![0xe5, 0x8e, 0x26, 0xDE, 0xAD, 0xBE, 0xEF]; + let (value, len) = decode(&data); + assert_eq!(value, 624_485); + assert_eq!(len, 3); + } + + #[test] + fn decode_2() { + let data = vec![0b00010000, 0b00000001, 0b00000011, 0b00000011]; + let (value, len) = decode(&data); + assert_eq!(value, 16); + assert_eq!(len, 1); + } + + #[test] + fn round_trip() { + let original = 123124234u64; + let mut container = [0u8; 10]; + let encoded_len = encode(original, &mut container); + let (value, len) = decode(&container); + assert_eq!(value, original); + assert_eq!(len, encoded_len); + } + + #[test] + fn min_value() { + let original = u64::MIN; + let mut container = [0u8; 10]; + let encoded_len = encode(original, &mut container); + let (value, len) = decode(&container); + assert_eq!(value, original); + assert_eq!(len, encoded_len); + } + + #[test] + fn max_value() { + let original = u64::MAX; + let mut container = [0u8; 10]; + let encoded_len = encode(original, &mut container); + let (value, len) = decode(&container); + assert_eq!(value, original); + assert_eq!(len, encoded_len); + } +} diff --git a/crates/polars-parquet/src/parquet/encoding/zigzag_leb128.rs b/crates/polars-parquet/src/parquet/encoding/zigzag_leb128.rs new file mode 100644 index 000000000000..63ab565cf8bd --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/zigzag_leb128.rs @@ -0,0 +1,68 @@ +use super::uleb128; + +pub fn decode(values: &[u8]) -> (i64, usize) { + let (u, consumed) = uleb128::decode(values); + ((u >> 1) as i64 ^ -((u & 1) as i64), consumed) +} + +pub fn encode(value: i64) -> ([u8; 10], usize) { + let value = ((value << 1) ^ (value >> (64 - 1))) as u64; + let mut a = [0u8; 10]; + let produced = uleb128::encode(value, &mut a); + (a, produced) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode() { + // see e.g. https://stackoverflow.com/a/2211086/931303 + let cases = vec![ + (0u8, 0i64), + (1, -1), + (2, 1), + (3, -2), + (4, 2), + (5, -3), + (6, 3), + (7, -4), + (8, 4), + (9, -5), + ]; + for (data, expected) in cases { + let (result, _) = decode(&[data]); + assert_eq!(result, expected) + } + } + + #[test] + fn test_encode() { + let cases = vec![ + (0u8, 0i64), + (1, -1), + (2, 1), + (3, -2), + (4, 2), + (5, -3), + (6, 3), + (7, -4), + (8, 4), + (9, -5), + ]; + for (expected, data) in cases { + let (result, size) = encode(data); + assert_eq!(size, 1); + assert_eq!(result[0], expected) + } + } + + #[test] + fn test_roundtrip() { + let value = -1001212312; + let (data, size) = encode(value); + let (result, _) = decode(&data[..size]); + assert_eq!(value, result); + } +} diff --git a/crates/polars-parquet/src/parquet/error.rs b/crates/polars-parquet/src/parquet/error.rs new file mode 100644 index 000000000000..3d00cf3c647f --- /dev/null +++ b/crates/polars-parquet/src/parquet/error.rs @@ -0,0 +1,140 @@ +//! Contains [`Error`] + +/// List of features whose non-activation may cause a runtime error. +/// Used to indicate which lack of feature caused [`Error::FeatureNotActive`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum Feature { + /// Snappy compression and decompression + Snappy, + /// Brotli compression and decompression + Brotli, + /// Gzip compression and decompression + Gzip, + /// Lz4 raw compression and decompression + Lz4, + /// Zstd compression and decompression + Zstd, +} + +/// Errors generated by this crate +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum ParquetError { + /// When the parquet file is known to be out of spec. + OutOfSpec(String), + /// Error presented when trying to use a code branch that requires activating a feature. + FeatureNotActive(Feature, String), + /// Error presented when trying to use a feature from parquet that is not yet supported + FeatureNotSupported(String), + /// When encoding, the user passed an invalid parameter + InvalidParameter(String), + /// When decoding or decompressing, the page would allocate more memory than allowed + WouldOverAllocate, +} + +impl ParquetError { + /// Create an OutOfSpec error from any Into + pub(crate) fn oos>(message: I) -> Self { + Self::OutOfSpec(message.into()) + } + + /// Create an FeatureNotSupported error from any Into + pub(crate) fn not_supported>(message: I) -> Self { + Self::FeatureNotSupported(message.into()) + } +} + +impl std::error::Error for ParquetError {} + +impl std::fmt::Display for ParquetError { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ParquetError::OutOfSpec(message) => { + write!(fmt, "File out of specification: {}", message) + }, + ParquetError::FeatureNotActive(feature, reason) => { + write!( + fmt, + "The feature \"{:?}\" needs to be active to {}", + feature, reason + ) + }, + ParquetError::FeatureNotSupported(reason) => { + write!(fmt, "Not yet supported: {}", reason) + }, + ParquetError::InvalidParameter(message) => { + write!(fmt, "Invalid parameter: {}", message) + }, + ParquetError::WouldOverAllocate => { + write!(fmt, "Operation would exceed memory use threshold") + }, + } + } +} + +#[cfg(feature = "snappy")] +impl From for ParquetError { + fn from(e: snap::Error) -> ParquetError { + ParquetError::OutOfSpec(format!("underlying snap error: {}", e)) + } +} + +#[cfg(feature = "lz4_flex")] +impl From for ParquetError { + fn from(e: lz4_flex::block::DecompressError) -> ParquetError { + ParquetError::OutOfSpec(format!("underlying lz4_flex error: {}", e)) + } +} + +#[cfg(feature = "lz4_flex")] +impl From for ParquetError { + fn from(e: lz4_flex::block::CompressError) -> ParquetError { + ParquetError::OutOfSpec(format!("underlying lz4_flex error: {}", e)) + } +} + +impl From for ParquetError { + fn from(e: polars_parquet_format::thrift::Error) -> ParquetError { + ParquetError::OutOfSpec(format!("Invalid thrift: {}", e)) + } +} + +impl From for ParquetError { + fn from(e: std::io::Error) -> ParquetError { + ParquetError::OutOfSpec(format!("underlying IO error: {}", e)) + } +} + +impl From for ParquetError { + fn from(e: std::collections::TryReserveError) -> ParquetError { + ParquetError::OutOfSpec(format!("OOM: {}", e)) + } +} + +impl From for ParquetError { + fn from(e: std::num::TryFromIntError) -> ParquetError { + ParquetError::OutOfSpec(format!("Number must be zero or positive: {}", e)) + } +} + +impl From for ParquetError { + fn from(e: std::array::TryFromSliceError) -> ParquetError { + ParquetError::OutOfSpec(format!("Can't deserialize to parquet native type: {}", e)) + } +} + +/// A specialized `Result` for Parquet errors. +pub type ParquetResult = std::result::Result; + +impl From for polars_error::PolarsError { + fn from(e: ParquetError) -> polars_error::PolarsError { + polars_error::PolarsError::ComputeError(format!("parquet: {}", e).into()) + } +} + +impl From for ParquetError { + fn from(e: polars_error::PolarsError) -> ParquetError { + ParquetError::OutOfSpec(format!("OOM: {}", e)) + } +} diff --git a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs new file mode 100644 index 000000000000..9fc241d5c51c --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs @@ -0,0 +1,219 @@ +use polars_parquet_format::{ColumnChunk, ColumnMetaData, Encoding}; + +use super::column_descriptor::ColumnDescriptor; +use crate::parquet::compression::Compression; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::PhysicalType; +use crate::parquet::statistics::Statistics; + +#[cfg(feature = "serde_types")] +mod serde_types { + pub use std::io::Cursor; + + pub use polars_parquet_format::thrift::protocol::{ + TCompactInputProtocol, TCompactOutputProtocol, + }; + pub use serde::de::Error as DeserializeError; + pub use serde::ser::Error as SerializeError; + pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; +} +#[cfg(feature = "serde_types")] +use serde_types::*; + +/// Metadata for a column chunk. +/// +/// This contains the `ColumnDescriptor` associated with the chunk so that deserializers have +/// access to the descriptor (e.g. physical, converted, logical). +/// +/// This struct is intentionally not `Clone`, as it is a huge struct. +#[derive(Debug)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub struct ColumnChunkMetadata { + #[cfg_attr( + feature = "serde_types", + serde(serialize_with = "serialize_column_chunk") + )] + #[cfg_attr( + feature = "serde_types", + serde(deserialize_with = "deserialize_column_chunk") + )] + column_chunk: ColumnChunk, + column_descr: ColumnDescriptor, +} + +#[cfg(feature = "serde_types")] +fn serialize_column_chunk( + column_chunk: &ColumnChunk, + serializer: S, +) -> std::result::Result +where + S: Serializer, +{ + let mut buf = vec![]; + let cursor = Cursor::new(&mut buf[..]); + let mut protocol = TCompactOutputProtocol::new(cursor); + column_chunk + .write_to_out_protocol(&mut protocol) + .map_err(S::Error::custom)?; + serializer.serialize_bytes(&buf) +} + +#[cfg(feature = "serde_types")] +fn deserialize_column_chunk<'de, D>(deserializer: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + use polars_utils::pl_serialize::deserialize_map_bytes; + + deserialize_map_bytes(deserializer, |b| { + let mut b = b.as_ref(); + let mut protocol = TCompactInputProtocol::new(&mut b, usize::MAX); + ColumnChunk::read_from_in_protocol(&mut protocol).map_err(D::Error::custom) + })? +} + +// Represents common operations for a column chunk. +impl ColumnChunkMetadata { + /// Returns a new [`ColumnChunkMetadata`] + pub fn new(column_chunk: ColumnChunk, column_descr: ColumnDescriptor) -> Self { + Self { + column_chunk, + column_descr, + } + } + + /// File where the column chunk is stored. + /// + /// If not set, assumed to belong to the same file as the metadata. + /// This path is relative to the current file. + pub fn file_path(&self) -> &Option { + &self.column_chunk.file_path + } + + /// Byte offset in `file_path()`. + pub fn file_offset(&self) -> i64 { + self.column_chunk.file_offset + } + + /// Returns this column's [`ColumnChunk`] + pub fn column_chunk(&self) -> &ColumnChunk { + &self.column_chunk + } + + /// The column's [`ColumnMetaData`] + pub fn metadata(&self) -> &ColumnMetaData { + self.column_chunk.meta_data.as_ref().unwrap() + } + + /// The [`ColumnDescriptor`] for this column. This descriptor contains the physical and logical type + /// of the pages. + pub fn descriptor(&self) -> &ColumnDescriptor { + &self.column_descr + } + + /// The [`PhysicalType`] of this column. + pub fn physical_type(&self) -> PhysicalType { + self.column_descr.descriptor.primitive_type.physical_type + } + + /// Decodes the raw statistics into [`Statistics`]. + pub fn statistics(&self) -> Option> { + self.metadata().statistics.as_ref().map(|x| { + Statistics::deserialize(x, self.column_descr.descriptor.primitive_type.clone()) + }) + } + + /// Total number of values in this column chunk. Note that this is not necessarily the number + /// of rows. E.g. the (nested) array `[[1, 2], [3]]` has 2 rows and 3 values. + pub fn num_values(&self) -> i64 { + self.metadata().num_values + } + + /// [`Compression`] for this column. + pub fn compression(&self) -> Compression { + self.metadata().codec.try_into().unwrap() + } + + /// Returns the total compressed data size of this column chunk. + pub fn compressed_size(&self) -> i64 { + self.metadata().total_compressed_size + } + + /// Returns the total uncompressed data size of this column chunk. + pub fn uncompressed_size(&self) -> i64 { + self.metadata().total_uncompressed_size + } + + /// Returns the offset for the column data. + pub fn data_page_offset(&self) -> i64 { + self.metadata().data_page_offset + } + + /// Returns `true` if this column chunk contains a index page, `false` otherwise. + pub fn has_index_page(&self) -> bool { + self.metadata().index_page_offset.is_some() + } + + /// Returns the offset for the index page. + pub fn index_page_offset(&self) -> Option { + self.metadata().index_page_offset + } + + /// Returns the offset for the dictionary page, if any. + pub fn dictionary_page_offset(&self) -> Option { + self.metadata().dictionary_page_offset + } + + /// Returns the encoding for this column + pub fn column_encoding(&self) -> &Vec { + &self.metadata().encodings + } + + /// Returns the offset and length in bytes of the column chunk within the file + pub fn byte_range(&self) -> core::ops::Range { + // this has been validated in [`try_from_thrift`] + column_metadata_byte_range(self.metadata()) + } + + /// Method to convert from Thrift. + pub(crate) fn try_from_thrift( + column_descr: ColumnDescriptor, + column_chunk: ColumnChunk, + ) -> ParquetResult { + // validate metadata + if let Some(meta) = &column_chunk.meta_data { + let _: u64 = meta.total_compressed_size.try_into()?; + + if let Some(offset) = meta.dictionary_page_offset { + let _: u64 = offset.try_into()?; + } + let _: u64 = meta.data_page_offset.try_into()?; + + let _: Compression = meta.codec.try_into()?; + } else { + return Err(ParquetError::oos("Column chunk requires metadata")); + } + + Ok(Self { + column_chunk, + column_descr, + }) + } + + /// Method to convert to Thrift. + pub fn into_thrift(self) -> ColumnChunk { + self.column_chunk + } +} + +pub(super) fn column_metadata_byte_range( + column_metadata: &ColumnMetaData, +) -> core::ops::Range { + let offset = if let Some(dict_page_offset) = column_metadata.dictionary_page_offset { + dict_page_offset as u64 + } else { + column_metadata.data_page_offset as u64 + }; + let len = column_metadata.total_compressed_size as u64; + offset..offset.checked_add(len).unwrap() +} diff --git a/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs new file mode 100644 index 000000000000..991a37280d31 --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs @@ -0,0 +1,87 @@ +use std::ops::Deref; +use std::sync::Arc; + +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use crate::parquet::schema::types::{ParquetType, PrimitiveType}; + +/// A descriptor of a parquet column. It contains the necessary information to deserialize +/// a parquet column. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub struct Descriptor { + /// The [`PrimitiveType`] of this column + pub primitive_type: PrimitiveType, + + /// The maximum definition level + pub max_def_level: i16, + + /// The maximum repetition level + pub max_rep_level: i16, +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum BaseType { + Owned(ParquetType), + Arc(Arc), +} + +impl BaseType { + pub fn into_arc(self) -> Self { + match self { + BaseType::Owned(t) => Self::Arc(Arc::new(t)), + BaseType::Arc(t) => Self::Arc(t), + } + } +} + +impl PartialEq for BaseType { + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl Deref for BaseType { + type Target = ParquetType; + + fn deref(&self) -> &Self::Target { + match self { + BaseType::Owned(i) => i, + BaseType::Arc(i) => i.as_ref(), + } + } +} + +/// A descriptor for leaf-level primitive columns. +/// This encapsulates information such as definition and repetition levels and is used to +/// re-assemble nested data. +#[derive(Debug, PartialEq, Clone)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub struct ColumnDescriptor { + /// The descriptor this columns' leaf. + pub descriptor: Descriptor, + + /// The path of this column. For instance, "a.b.c.d". + pub path_in_schema: Vec, + + /// The [`ParquetType`] this descriptor is a leaf of + pub base_type: BaseType, +} + +impl ColumnDescriptor { + /// Creates new descriptor for leaf-level column. + pub fn new( + descriptor: Descriptor, + path_in_schema: Vec, + base_type: BaseType, + ) -> Self { + Self { + descriptor, + path_in_schema, + base_type, + } + } +} diff --git a/crates/polars-parquet/src/parquet/metadata/column_order.rs b/crates/polars-parquet/src/parquet/metadata/column_order.rs new file mode 100644 index 000000000000..4d66f615bfa0 --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/column_order.rs @@ -0,0 +1,30 @@ +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use super::sort::SortOrder; + +/// Column order that specifies what method was used to aggregate min/max values for +/// statistics. +/// +/// If column order is undefined, then it is the legacy behaviour and all values should +/// be compared as signed values/bytes. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum ColumnOrder { + /// Column uses the order defined by its logical or physical type + /// (if there is no logical type), parquet-format 2.4.0+. + TypeDefinedOrder(SortOrder), + /// Undefined column order, means legacy behaviour before parquet-format 2.4.0. + /// Sort order is always SIGNED. + Undefined, +} + +impl ColumnOrder { + /// Returns sort order associated with this column order. + pub fn sort_order(&self) -> SortOrder { + match *self { + ColumnOrder::TypeDefinedOrder(order) => order, + ColumnOrder::Undefined => SortOrder::Signed, + } + } +} diff --git a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs new file mode 100644 index 000000000000..4d2dfddd7de2 --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs @@ -0,0 +1,121 @@ +use polars_parquet_format::ColumnOrder as TColumnOrder; + +use super::RowGroupMetadata; +use super::column_order::ColumnOrder; +use super::schema_descriptor::SchemaDescriptor; +use crate::parquet::error::ParquetError; +use crate::parquet::metadata::get_sort_order; +pub use crate::parquet::thrift_format::KeyValue; + +/// Metadata for a Parquet file. +// This is almost equal to [`polars_parquet_format::FileMetaData`] but contains the descriptors, +// which are crucial to deserialize pages. +#[derive(Debug, Clone)] +pub struct FileMetadata { + /// version of this file. + pub version: i32, + /// number of rows in the file. + pub num_rows: usize, + /// Max row group height, useful for sharing column materializations. + pub max_row_group_height: usize, + /// String message for application that wrote this file. + /// + /// This should have the following format: + /// ` version (build )`. + /// + /// ```shell + /// parquet-mr version 1.8.0 (build 0fda28af84b9746396014ad6a415b90592a98b3b) + /// ``` + pub created_by: Option, + /// The row groups of this file + pub row_groups: Vec, + /// key_value_metadata of this file. + pub key_value_metadata: Option>, + /// schema descriptor. + pub schema_descr: SchemaDescriptor, + /// Column (sort) order used for `min` and `max` values of each column in this file. + /// + /// Each column order corresponds to one column, determined by its position in the + /// list, matching the position of the column in the schema. + /// + /// When `None` is returned, there are no column orders available, and each column + /// should be assumed to have undefined (legacy) column order. + pub column_orders: Option>, +} + +impl FileMetadata { + /// Returns the [`SchemaDescriptor`] that describes schema of this file. + pub fn schema(&self) -> &SchemaDescriptor { + &self.schema_descr + } + + /// returns the metadata + pub fn key_value_metadata(&self) -> &Option> { + &self.key_value_metadata + } + + /// Returns column order for `i`th column in this file. + /// If column orders are not available, returns undefined (legacy) column order. + pub fn column_order(&self, i: usize) -> ColumnOrder { + self.column_orders + .as_ref() + .map(|data| data[i]) + .unwrap_or(ColumnOrder::Undefined) + } + + /// Deserializes [`crate::parquet::thrift_format::FileMetadata`] into this struct + pub fn try_from_thrift( + metadata: polars_parquet_format::FileMetaData, + ) -> Result { + let schema_descr = SchemaDescriptor::try_from_thrift(&metadata.schema)?; + + let mut max_row_group_height = 0; + + let row_groups = metadata + .row_groups + .into_iter() + .map(|rg| { + let md = RowGroupMetadata::try_from_thrift(&schema_descr, rg)?; + max_row_group_height = max_row_group_height.max(md.num_rows()); + Ok(md) + }) + .collect::>()?; + + let column_orders = metadata + .column_orders + .map(|orders| parse_column_orders(&orders, &schema_descr)); + + Ok(FileMetadata { + version: metadata.version, + num_rows: metadata.num_rows.try_into()?, + max_row_group_height, + created_by: metadata.created_by, + row_groups, + key_value_metadata: metadata.key_value_metadata, + schema_descr, + column_orders, + }) + } +} + +/// Parses [`ColumnOrder`] from Thrift definition. +fn parse_column_orders( + orders: &[TColumnOrder], + schema_descr: &SchemaDescriptor, +) -> Vec { + schema_descr + .columns() + .iter() + .zip(orders.iter()) + .map(|(column, order)| match order { + TColumnOrder::TYPEORDER(_) => { + let sort_order = get_sort_order( + &column.descriptor.primitive_type.logical_type, + &column.descriptor.primitive_type.converted_type, + &column.descriptor.primitive_type.physical_type, + ); + ColumnOrder::TypeDefinedOrder(sort_order) + }, + }) + .collect() +} diff --git a/crates/polars-parquet/src/parquet/metadata/mod.rs b/crates/polars-parquet/src/parquet/metadata/mod.rs new file mode 100644 index 000000000000..b7a80739e719 --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/mod.rs @@ -0,0 +1,17 @@ +mod column_chunk_metadata; +mod column_descriptor; +mod column_order; +mod file_metadata; +mod row_metadata; +mod schema_descriptor; +mod sort; + +pub use column_chunk_metadata::ColumnChunkMetadata; +pub use column_descriptor::{ColumnDescriptor, Descriptor}; +pub use column_order::ColumnOrder; +pub use file_metadata::{FileMetadata, KeyValue}; +pub use row_metadata::RowGroupMetadata; +pub use schema_descriptor::SchemaDescriptor; +pub use sort::*; + +pub use crate::parquet::thrift_format::FileMetaData as ThriftFileMetadata; diff --git a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs new file mode 100644 index 000000000000..7cf4c5622701 --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; + +use hashbrown::hash_map::RawEntryMut; +use polars_parquet_format::{RowGroup, SortingColumn}; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::unitvec; + +use super::column_chunk_metadata::{ColumnChunkMetadata, column_metadata_byte_range}; +use super::schema_descriptor::SchemaDescriptor; +use crate::parquet::error::{ParquetError, ParquetResult}; + +type ColumnLookup = PlHashMap>; + +trait InitColumnLookup { + fn add_column(&mut self, index: usize, column: &ColumnChunkMetadata); +} + +impl InitColumnLookup for ColumnLookup { + #[inline(always)] + fn add_column(&mut self, index: usize, column: &ColumnChunkMetadata) { + let root_name = &column.descriptor().path_in_schema[0]; + + match self.raw_entry_mut().from_key(root_name) { + RawEntryMut::Vacant(slot) => { + slot.insert(root_name.clone(), unitvec![index]); + }, + RawEntryMut::Occupied(mut slot) => { + slot.get_mut().push(index); + }, + }; + } +} + +/// Metadata for a row group. +#[derive(Debug, Clone, Default)] +pub struct RowGroupMetadata { + // Moving of `ColumnChunkMetadata` is very expensive they are rather big. So, we arc the vec + // instead of having an arc slice. This way we don't to move the vec values into an arc when + // collecting. + columns: Arc>, + column_lookup: PlHashMap>, + num_rows: usize, + total_byte_size: usize, + full_byte_range: core::ops::Range, + sorting_columns: Option>, +} + +impl RowGroupMetadata { + #[inline(always)] + pub fn n_columns(&self) -> usize { + self.columns.len() + } + + /// Fetch all columns under this root name if it exists. + pub fn columns_under_root_iter( + &self, + root_name: &str, + ) -> Option + DoubleEndedIterator> { + self.column_lookup + .get(root_name) + .map(|x| x.iter().map(|&x| &self.columns[x])) + } + + /// Fetch all columns under this root name if it exists. + pub fn columns_idxs_under_root_iter<'a>(&'a self, root_name: &str) -> Option<&'a [usize]> { + self.column_lookup.get(root_name).map(|x| x.as_slice()) + } + + pub fn parquet_columns(&self) -> &[ColumnChunkMetadata] { + self.columns.as_ref().as_slice() + } + + /// Number of rows in this row group. + pub fn num_rows(&self) -> usize { + self.num_rows + } + + /// Total byte size of all uncompressed column data in this row group. + pub fn total_byte_size(&self) -> usize { + self.total_byte_size + } + + /// Total size of all compressed column data in this row group. + pub fn compressed_size(&self) -> usize { + self.columns + .iter() + .map(|c| c.compressed_size() as usize) + .sum::() + } + + pub fn full_byte_range(&self) -> core::ops::Range { + self.full_byte_range.clone() + } + + pub fn byte_ranges_iter(&self) -> impl '_ + ExactSizeIterator> { + self.columns.iter().map(|x| x.byte_range()) + } + + pub fn sorting_columns(&self) -> Option<&[SortingColumn]> { + self.sorting_columns.as_deref() + } + + /// Method to convert from Thrift. + pub(crate) fn try_from_thrift( + schema_descr: &SchemaDescriptor, + rg: RowGroup, + ) -> ParquetResult { + if schema_descr.columns().len() != rg.columns.len() { + return Err(ParquetError::oos(format!( + "The number of columns in the row group ({}) must be equal to the number of columns in the schema ({})", + rg.columns.len(), + schema_descr.columns().len() + ))); + } + let total_byte_size = rg.total_byte_size.try_into()?; + let num_rows = rg.num_rows.try_into()?; + + let mut column_lookup = ColumnLookup::with_capacity(rg.columns.len()); + let mut full_byte_range = if let Some(first_column_chunk) = rg.columns.first() { + let Some(metadata) = &first_column_chunk.meta_data else { + return Err(ParquetError::oos("Column chunk requires metadata")); + }; + column_metadata_byte_range(metadata) + } else { + 0..0 + }; + + let sorting_columns = rg.sorting_columns.clone(); + + let columns = rg + .columns + .into_iter() + .zip(schema_descr.columns()) + .enumerate() + .map(|(i, (column_chunk, descriptor))| { + let column = + ColumnChunkMetadata::try_from_thrift(descriptor.clone(), column_chunk)?; + + column_lookup.add_column(i, &column); + + let byte_range = column.byte_range(); + full_byte_range = full_byte_range.start.min(byte_range.start) + ..full_byte_range.end.max(byte_range.end); + + Ok(column) + }) + .collect::>>()?; + let columns = Arc::new(columns); + + Ok(RowGroupMetadata { + columns, + column_lookup, + num_rows, + total_byte_size, + full_byte_range, + sorting_columns, + }) + } +} diff --git a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs new file mode 100644 index 000000000000..33dc5023f5ab --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs @@ -0,0 +1,148 @@ +use polars_parquet_format::SchemaElement; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use super::column_descriptor::{BaseType, ColumnDescriptor, Descriptor}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::Repetition; +use crate::parquet::schema::io_message::from_message; +use crate::parquet::schema::types::{FieldInfo, ParquetType}; + +/// A schema descriptor. This encapsulates the top-level schemas for all the columns, +/// as well as all descriptors for all the primitive columns. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub struct SchemaDescriptor { + name: PlSmallStr, + // The top-level schema (the "message" type). + fields: Vec, + + // All the descriptors for primitive columns in this schema, constructed from + // `schema` in DFS order. + leaves: Vec, +} + +impl SchemaDescriptor { + /// Creates new schema descriptor from Parquet schema. + pub fn new(name: PlSmallStr, fields: Vec) -> Self { + let mut leaves = vec![]; + for f in &fields { + let mut path = vec![]; + build_tree(f, BaseType::Owned(f.clone()), 0, 0, &mut leaves, &mut path); + } + + Self { + name, + fields, + leaves, + } + } + + /// The [`ColumnDescriptor`] (leaves) of this schema. + /// + /// Note that, for nested fields, this may contain more entries than the number of fields + /// in the file - e.g. a struct field may have two columns. + pub fn columns(&self) -> &[ColumnDescriptor] { + &self.leaves + } + + /// The schemas' name. + pub fn name(&self) -> &str { + &self.name + } + + /// The schemas' fields. + pub fn fields(&self) -> &[ParquetType] { + &self.fields + } + + /// The schemas' leaves. + pub fn leaves(&self) -> &[ColumnDescriptor] { + &self.leaves + } + + pub(crate) fn into_thrift(self) -> Vec { + ParquetType::GroupType { + field_info: FieldInfo { + name: self.name, + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: self.fields, + } + .to_thrift() + } + + fn try_from_type(type_: ParquetType) -> ParquetResult { + match type_ { + ParquetType::GroupType { + field_info, fields, .. + } => Ok(Self::new(field_info.name, fields)), + _ => Err(ParquetError::oos("The parquet schema MUST be a group type")), + } + } + + pub(crate) fn try_from_thrift(elements: &[SchemaElement]) -> ParquetResult { + let schema = ParquetType::try_from_thrift(elements)?; + Self::try_from_type(schema) + } + + /// Creates a schema from + pub fn try_from_message(message: &str) -> ParquetResult { + let schema = from_message(message)?; + Self::try_from_type(schema) + } +} + +fn build_tree<'a>( + tp: &'a ParquetType, + base_tp: BaseType, + mut max_rep_level: i16, + mut max_def_level: i16, + leaves: &mut Vec, + path_so_far: &mut Vec<&'a str>, +) { + path_so_far.push(tp.name()); + match tp.get_field_info().repetition { + Repetition::Optional => { + max_def_level += 1; + }, + Repetition::Repeated => { + max_def_level += 1; + max_rep_level += 1; + }, + _ => {}, + } + + match tp { + ParquetType::PrimitiveType(p) => { + let path_in_schema = path_so_far.iter().copied().map(Into::into).collect(); + leaves.push(ColumnDescriptor::new( + Descriptor { + primitive_type: p.clone(), + max_def_level, + max_rep_level, + }, + path_in_schema, + base_tp, + )); + }, + ParquetType::GroupType { fields, .. } => { + let base_tp = base_tp.into_arc(); + for f in fields { + build_tree( + f, + base_tp.clone(), + max_rep_level, + max_def_level, + leaves, + path_so_far, + ); + path_so_far.pop(); + } + }, + } +} diff --git a/crates/polars-parquet/src/parquet/metadata/sort.rs b/crates/polars-parquet/src/parquet/metadata/sort.rs new file mode 100644 index 000000000000..d75d77134103 --- /dev/null +++ b/crates/polars-parquet/src/parquet/metadata/sort.rs @@ -0,0 +1,95 @@ +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use crate::parquet::schema::types::{ + IntegerType, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, +}; + +/// Sort order for page and column statistics. +/// +/// Types are associated with sort orders and column stats are aggregated using a sort +/// order, and a sort order should be considered when comparing values with statistics +/// min/max. +/// +/// See reference in +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum SortOrder { + /// Signed (either value or legacy byte-wise) comparison. + Signed, + /// Unsigned (depending on physical type either value or byte-wise) comparison. + Unsigned, + /// Comparison is undefined. + Undefined, +} + +/// Returns sort order for a physical/logical type. +pub fn get_sort_order( + logical_type: &Option, + converted_type: &Option, + physical_type: &PhysicalType, +) -> SortOrder { + if let Some(logical_type) = logical_type { + return get_logical_sort_order(logical_type); + }; + if let Some(converted_type) = converted_type { + return get_converted_sort_order(converted_type); + }; + get_physical_sort_order(physical_type) +} + +fn get_logical_sort_order(logical_type: &PrimitiveLogicalType) -> SortOrder { + // TODO: Should this take converted and logical type, for compatibility? + use PrimitiveLogicalType::*; + match logical_type { + String | Enum | Json | Bson => SortOrder::Unsigned, + Integer(t) => match t { + IntegerType::Int8 | IntegerType::Int16 | IntegerType::Int32 | IntegerType::Int64 => { + SortOrder::Signed + }, + _ => SortOrder::Unsigned, + }, + Decimal(_, _) => SortOrder::Signed, + Date => SortOrder::Signed, + Time { .. } => SortOrder::Signed, + Timestamp { .. } => SortOrder::Signed, + Unknown => SortOrder::Undefined, + Uuid => SortOrder::Unsigned, + Float16 => SortOrder::Unsigned, + } +} + +fn get_converted_sort_order(converted_type: &PrimitiveConvertedType) -> SortOrder { + use PrimitiveConvertedType::*; + match converted_type { + // Unsigned byte-wise comparison. + Utf8 | Json | Bson | Enum => SortOrder::Unsigned, + Int8 | Int16 | Int32 | Int64 => SortOrder::Signed, + Uint8 | Uint16 | Uint32 | Uint64 => SortOrder::Unsigned, + // Signed comparison of the represented value. + Decimal(_, _) => SortOrder::Signed, + Date => SortOrder::Signed, + TimeMillis | TimeMicros | TimestampMillis | TimestampMicros => SortOrder::Signed, + Interval => SortOrder::Undefined, + } +} + +fn get_physical_sort_order(physical_type: &PhysicalType) -> SortOrder { + use PhysicalType::*; + match physical_type { + // Order: false, true + Boolean => SortOrder::Unsigned, + Int32 | Int64 => SortOrder::Signed, + Int96 => SortOrder::Undefined, + // Notes to remember when comparing float/double values: + // If the min is a NaN, it should be ignored. + // If the max is a NaN, it should be ignored. + // If the min is +0, the row group may contain -0 values as well. + // If the max is -0, the row group may contain +0 values as well. + // When looking for NaN values, min and max should be ignored. + Float | Double => SortOrder::Signed, + // Unsigned byte-wise comparison + ByteArray | FixedLenByteArray(_) => SortOrder::Unsigned, + } +} diff --git a/crates/polars-parquet/src/parquet/mod.rs b/crates/polars-parquet/src/parquet/mod.rs new file mode 100644 index 000000000000..03b6ab42d353 --- /dev/null +++ b/crates/polars-parquet/src/parquet/mod.rs @@ -0,0 +1,65 @@ +#[macro_use] +pub mod error; +#[cfg(feature = "bloom_filter")] +pub mod bloom_filter; +pub mod compression; +pub mod encoding; +pub mod metadata; +pub mod page; +mod parquet_bridge; +pub mod read; +pub mod schema; +pub mod statistics; +pub mod types; +pub mod write; + +use std::ops::Deref; + +use polars_parquet_format as thrift_format; +use polars_utils::mmap::MemSlice; +pub use streaming_decompression::{FallibleStreamingIterator, fallible_streaming_iterator}; + +pub const HEADER_SIZE: u64 = PARQUET_MAGIC.len() as u64; +pub const FOOTER_SIZE: u64 = 8; +pub const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; + +/// The number of bytes read at the end of the parquet file on first read +const DEFAULT_FOOTER_READ_SIZE: u64 = 64 * 1024; + +/// A copy-on-write buffer over bytes +#[derive(Debug, Clone)] +pub enum CowBuffer { + Borrowed(MemSlice), + Owned(Vec), +} + +impl Deref for CowBuffer { + type Target = [u8]; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + match self { + CowBuffer::Borrowed(v) => v.deref(), + CowBuffer::Owned(v) => v.deref(), + } + } +} + +impl CowBuffer { + pub fn to_mut(&mut self) -> &mut Vec { + match self { + CowBuffer::Borrowed(v) => { + *self = Self::Owned(v.clone().to_vec()); + self.to_mut() + }, + CowBuffer::Owned(v) => v, + } + } + + pub fn into_vec(self) -> Vec { + match self { + CowBuffer::Borrowed(v) => v.to_vec(), + CowBuffer::Owned(v) => v, + } + } +} diff --git a/crates/polars-parquet/src/parquet/page/mod.rs b/crates/polars-parquet/src/parquet/page/mod.rs new file mode 100644 index 000000000000..363999c1175c --- /dev/null +++ b/crates/polars-parquet/src/parquet/page/mod.rs @@ -0,0 +1,453 @@ +use super::CowBuffer; +use crate::parquet::compression::Compression; +use crate::parquet::encoding::{Encoding, get_length}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::Descriptor; +pub use crate::parquet::parquet_bridge::{DataPageHeaderExt, PageType}; +use crate::parquet::statistics::Statistics; +pub use crate::parquet::thrift_format::{ + DataPageHeader as DataPageHeaderV1, DataPageHeaderV2, Encoding as FormatEncoding, + PageHeader as ParquetPageHeader, +}; + +pub enum PageResult { + Single(Page), + Two { dict: DictPage, data: DataPage }, +} + +/// A [`CompressedDataPage`] is compressed, encoded representation of a Parquet data page. +/// It holds actual data and thus cloning it is expensive. +#[derive(Debug)] +pub struct CompressedDataPage { + pub(crate) header: DataPageHeader, + pub(crate) buffer: CowBuffer, + pub(crate) compression: Compression, + uncompressed_page_size: usize, + pub(crate) descriptor: Descriptor, + pub num_rows: Option, +} + +impl CompressedDataPage { + /// Returns a new [`CompressedDataPage`]. + pub fn new( + header: DataPageHeader, + buffer: CowBuffer, + compression: Compression, + uncompressed_page_size: usize, + descriptor: Descriptor, + num_rows: usize, + ) -> Self { + Self { + header, + buffer, + compression, + uncompressed_page_size, + descriptor, + num_rows: Some(num_rows), + } + } + + /// Returns a new [`CompressedDataPage`]. + pub(crate) fn new_read( + header: DataPageHeader, + buffer: CowBuffer, + compression: Compression, + uncompressed_page_size: usize, + descriptor: Descriptor, + ) -> Self { + Self { + header, + buffer, + compression, + uncompressed_page_size, + descriptor, + num_rows: None, + } + } + + pub fn header(&self) -> &DataPageHeader { + &self.header + } + + pub fn uncompressed_size(&self) -> usize { + self.uncompressed_page_size + } + + pub fn compressed_size(&self) -> usize { + self.buffer.len() + } + + /// The compression of the data in this page. + /// Note that what is compressed in a page depends on its version: + /// in V1, the whole data (`[repetition levels][definition levels][values]`) is compressed; in V2 only the values are compressed. + pub fn compression(&self) -> Compression { + self.compression + } + + pub fn num_values(&self) -> usize { + self.header.num_values() + } + + pub fn num_rows(&self) -> Option { + self.num_rows + } + + /// Decodes the raw statistics into a statistics + pub fn statistics(&self) -> Option> { + match &self.header { + DataPageHeader::V1(d) => d + .statistics + .as_ref() + .map(|x| Statistics::deserialize(x, self.descriptor.primitive_type.clone())), + DataPageHeader::V2(d) => d + .statistics + .as_ref() + .map(|x| Statistics::deserialize(x, self.descriptor.primitive_type.clone())), + } + } + + pub fn slice_mut(&mut self) -> &mut CowBuffer { + &mut self.buffer + } +} + +#[derive(Debug, Clone)] +pub enum DataPageHeader { + V1(DataPageHeaderV1), + V2(DataPageHeaderV2), +} + +impl DataPageHeader { + pub fn num_values(&self) -> usize { + match &self { + DataPageHeader::V1(d) => d.num_values as usize, + DataPageHeader::V2(d) => d.num_values as usize, + } + } + + pub fn null_count(&self) -> Option { + match &self { + DataPageHeader::V1(_) => None, + DataPageHeader::V2(d) => Some(d.num_nulls as usize), + } + } + + pub fn encoding(&self) -> FormatEncoding { + match self { + DataPageHeader::V1(d) => d.encoding, + DataPageHeader::V2(d) => d.encoding, + } + } + + pub fn is_dictionary_encoded(&self) -> bool { + matches!(self.encoding(), FormatEncoding::RLE_DICTIONARY) + } +} + +/// A [`DataPage`] is an uncompressed, encoded representation of a Parquet data page. It holds actual data +/// and thus cloning it is expensive. +#[derive(Debug, Clone)] +pub struct DataPage { + pub(super) header: DataPageHeader, + pub(super) buffer: CowBuffer, + pub descriptor: Descriptor, + pub num_rows: Option, +} + +impl DataPage { + pub fn new( + header: DataPageHeader, + buffer: CowBuffer, + descriptor: Descriptor, + num_rows: usize, + ) -> Self { + Self { + header, + buffer, + descriptor, + num_rows: Some(num_rows), + } + } + + pub(crate) fn new_read( + header: DataPageHeader, + buffer: CowBuffer, + descriptor: Descriptor, + ) -> Self { + Self { + header, + buffer, + descriptor, + num_rows: None, + } + } + + pub fn header(&self) -> &DataPageHeader { + &self.header + } + + pub fn buffer(&self) -> &[u8] { + &self.buffer + } + + /// Returns a mutable reference to the internal buffer. + /// Useful to recover the buffer after the page has been decoded. + pub fn buffer_mut(&mut self) -> &mut Vec { + self.buffer.to_mut() + } + + pub fn num_values(&self) -> usize { + self.header.num_values() + } + + pub fn null_count(&self) -> Option { + self.header.null_count() + } + + pub fn num_rows(&self) -> Option { + self.num_rows + } + + pub fn encoding(&self) -> Encoding { + match &self.header { + DataPageHeader::V1(d) => d.encoding(), + DataPageHeader::V2(d) => d.encoding(), + } + } + + pub fn definition_level_encoding(&self) -> Encoding { + match &self.header { + DataPageHeader::V1(d) => d.definition_level_encoding(), + DataPageHeader::V2(_) => Encoding::Rle, + } + } + + pub fn repetition_level_encoding(&self) -> Encoding { + match &self.header { + DataPageHeader::V1(d) => d.repetition_level_encoding(), + DataPageHeader::V2(_) => Encoding::Rle, + } + } + + /// Decodes the raw statistics into a statistics + pub fn statistics(&self) -> Option> { + match &self.header { + DataPageHeader::V1(d) => d + .statistics + .as_ref() + .map(|x| Statistics::deserialize(x, self.descriptor.primitive_type.clone())), + DataPageHeader::V2(d) => d + .statistics + .as_ref() + .map(|x| Statistics::deserialize(x, self.descriptor.primitive_type.clone())), + } + } +} + +/// A [`Page`] is an uncompressed, encoded representation of a Parquet page. It may hold actual data +/// and thus cloning it may be expensive. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum Page { + /// A [`DataPage`] + Data(DataPage), + /// A [`DictPage`] + Dict(DictPage), +} + +impl Page { + pub(crate) fn buffer_mut(&mut self) -> &mut Vec { + match self { + Self::Data(page) => page.buffer.to_mut(), + Self::Dict(page) => page.buffer.to_mut(), + } + } + + pub(crate) fn unwrap_data(self) -> DataPage { + match self { + Self::Data(page) => page, + _ => panic!(), + } + } +} + +/// A [`CompressedPage`] is a compressed, encoded representation of a Parquet page. It holds actual data +/// and thus cloning it is expensive. +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum CompressedPage { + Data(CompressedDataPage), + Dict(CompressedDictPage), +} + +impl CompressedPage { + pub(crate) fn buffer_mut(&mut self) -> &mut Vec { + match self { + CompressedPage::Data(page) => page.buffer.to_mut(), + CompressedPage::Dict(page) => page.buffer.to_mut(), + } + } + + pub(crate) fn compression(&self) -> Compression { + match self { + CompressedPage::Data(page) => page.compression(), + CompressedPage::Dict(page) => page.compression(), + } + } + + pub(crate) fn num_values(&self) -> usize { + match self { + CompressedPage::Data(page) => page.num_values(), + CompressedPage::Dict(_) => 0, + } + } + + pub(crate) fn num_rows(&self) -> Option { + match self { + CompressedPage::Data(page) => page.num_rows(), + CompressedPage::Dict(_) => Some(0), + } + } +} + +/// An uncompressed, encoded dictionary page. +#[derive(Debug, Clone)] +pub struct DictPage { + pub buffer: CowBuffer, + pub num_values: usize, + pub is_sorted: bool, +} + +impl DictPage { + pub fn new(buffer: CowBuffer, num_values: usize, is_sorted: bool) -> Self { + Self { + buffer, + num_values, + is_sorted, + } + } +} + +/// A compressed, encoded dictionary page. +#[derive(Debug)] +pub struct CompressedDictPage { + pub(crate) buffer: CowBuffer, + compression: Compression, + pub(crate) num_values: usize, + pub(crate) uncompressed_page_size: usize, + pub is_sorted: bool, +} + +impl CompressedDictPage { + pub fn new( + buffer: CowBuffer, + compression: Compression, + uncompressed_page_size: usize, + num_values: usize, + is_sorted: bool, + ) -> Self { + Self { + buffer, + compression, + uncompressed_page_size, + num_values, + is_sorted, + } + } + + /// The compression of the data in this page. + pub fn compression(&self) -> Compression { + self.compression + } +} + +pub struct EncodedSplitBuffer<'a> { + /// Encoded Repetition Levels + pub rep: &'a [u8], + /// Encoded Definition Levels + pub def: &'a [u8], + /// Encoded Values + pub values: &'a [u8], +} + +/// Splits the page buffer into 3 slices corresponding to (encoded rep levels, encoded def levels, encoded values) for v1 pages. +#[inline] +pub fn split_buffer_v1( + buffer: &[u8], + has_rep: bool, + has_def: bool, +) -> ParquetResult { + let (rep, buffer) = if has_rep { + let level_buffer_length = get_length(buffer).ok_or_else(|| { + ParquetError::oos( + "The number of bytes declared in v1 rep levels is higher than the page size", + ) + })?; + + if buffer.len() < level_buffer_length + 4 { + return Err(ParquetError::oos( + "The number of bytes declared in v1 rep levels is higher than the page size", + )); + } + + buffer[4..].split_at(level_buffer_length) + } else { + (&[] as &[u8], buffer) + }; + + let (def, buffer) = if has_def { + let level_buffer_length = get_length(buffer).ok_or_else(|| { + ParquetError::oos( + "The number of bytes declared in v1 def levels is higher than the page size", + ) + })?; + + if buffer.len() < level_buffer_length + 4 { + return Err(ParquetError::oos( + "The number of bytes declared in v1 def levels is higher than the page size", + )); + } + + buffer[4..].split_at(level_buffer_length) + } else { + (&[] as &[u8], buffer) + }; + + Ok(EncodedSplitBuffer { + rep, + def, + values: buffer, + }) +} + +/// Splits the page buffer into 3 slices corresponding to (encoded rep levels, encoded def levels, encoded values) for v2 pages. +pub fn split_buffer_v2( + buffer: &[u8], + rep_level_buffer_length: usize, + def_level_buffer_length: usize, +) -> ParquetResult { + let (rep, buffer) = buffer.split_at(rep_level_buffer_length); + let (def, values) = buffer.split_at(def_level_buffer_length); + + Ok(EncodedSplitBuffer { rep, def, values }) +} + +/// Splits the page buffer into 3 slices corresponding to (encoded rep levels, encoded def levels, encoded values). +pub fn split_buffer(page: &DataPage) -> ParquetResult { + match page.header() { + DataPageHeader::V1(_) => split_buffer_v1( + page.buffer(), + page.descriptor.max_rep_level > 0, + page.descriptor.max_def_level > 0, + ), + DataPageHeader::V2(header) => { + let def_level_buffer_length: usize = header.definition_levels_byte_length.try_into()?; + let rep_level_buffer_length: usize = header.repetition_levels_byte_length.try_into()?; + split_buffer_v2( + page.buffer(), + rep_level_buffer_length, + def_level_buffer_length, + ) + }, + } +} diff --git a/crates/polars-parquet/src/parquet/parquet_bridge.rs b/crates/polars-parquet/src/parquet/parquet_bridge.rs new file mode 100644 index 000000000000..523ea0e3e12e --- /dev/null +++ b/crates/polars-parquet/src/parquet/parquet_bridge.rs @@ -0,0 +1,706 @@ +// Bridges structs from thrift-generated code to rust enums. + +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use super::thrift_format::{ + BoundaryOrder as ParquetBoundaryOrder, CompressionCodec, DataPageHeader, DataPageHeaderV2, + DecimalType, Encoding as ParquetEncoding, FieldRepetitionType, IntType, + LogicalType as ParquetLogicalType, PageType as ParquetPageType, TimeType, + TimeUnit as ParquetTimeUnit, TimestampType, +}; +use crate::parquet::error::ParquetError; + +/// The repetition of a parquet field +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum Repetition { + /// When the field has no null values + Required, + /// When the field may have null values + Optional, + /// When the field may be repeated (list field) + Repeated, +} + +impl TryFrom for Repetition { + type Error = ParquetError; + + fn try_from(repetition: FieldRepetitionType) -> Result { + Ok(match repetition { + FieldRepetitionType::REQUIRED => Repetition::Required, + FieldRepetitionType::OPTIONAL => Repetition::Optional, + FieldRepetitionType::REPEATED => Repetition::Repeated, + _ => return Err(ParquetError::oos("Thrift out of range")), + }) + } +} + +impl From for FieldRepetitionType { + fn from(repetition: Repetition) -> Self { + match repetition { + Repetition::Required => FieldRepetitionType::REQUIRED, + Repetition::Optional => FieldRepetitionType::OPTIONAL, + Repetition::Repeated => FieldRepetitionType::REPEATED, + } + } +} + +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum Compression { + Uncompressed, + Snappy, + Gzip, + Lzo, + Brotli, + Lz4, + Zstd, + Lz4Raw, +} + +impl TryFrom for Compression { + type Error = ParquetError; + + fn try_from(codec: CompressionCodec) -> Result { + Ok(match codec { + CompressionCodec::UNCOMPRESSED => Compression::Uncompressed, + CompressionCodec::SNAPPY => Compression::Snappy, + CompressionCodec::GZIP => Compression::Gzip, + CompressionCodec::LZO => Compression::Lzo, + CompressionCodec::BROTLI => Compression::Brotli, + CompressionCodec::LZ4 => Compression::Lz4, + CompressionCodec::ZSTD => Compression::Zstd, + CompressionCodec::LZ4_RAW => Compression::Lz4Raw, + _ => return Err(ParquetError::oos("Thrift out of range")), + }) + } +} + +impl From for CompressionCodec { + fn from(codec: Compression) -> Self { + match codec { + Compression::Uncompressed => CompressionCodec::UNCOMPRESSED, + Compression::Snappy => CompressionCodec::SNAPPY, + Compression::Gzip => CompressionCodec::GZIP, + Compression::Lzo => CompressionCodec::LZO, + Compression::Brotli => CompressionCodec::BROTLI, + Compression::Lz4 => CompressionCodec::LZ4, + Compression::Zstd => CompressionCodec::ZSTD, + Compression::Lz4Raw => CompressionCodec::LZ4_RAW, + } + } +} + +/// Defines the compression settings for writing a parquet file. +/// +/// If None is provided as a compression setting, then the default compression level is used. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub enum CompressionOptions { + Uncompressed, + Snappy, + Gzip(Option), + Lzo, + Brotli(Option), + Lz4, + Zstd(Option), + Lz4Raw, +} + +impl From for Compression { + fn from(value: CompressionOptions) -> Self { + match value { + CompressionOptions::Uncompressed => Compression::Uncompressed, + CompressionOptions::Snappy => Compression::Snappy, + CompressionOptions::Gzip(_) => Compression::Gzip, + CompressionOptions::Lzo => Compression::Lzo, + CompressionOptions::Brotli(_) => Compression::Brotli, + CompressionOptions::Lz4 => Compression::Lz4, + CompressionOptions::Zstd(_) => Compression::Zstd, + CompressionOptions::Lz4Raw => Compression::Lz4Raw, + } + } +} + +impl From for CompressionCodec { + fn from(codec: CompressionOptions) -> Self { + match codec { + CompressionOptions::Uncompressed => CompressionCodec::UNCOMPRESSED, + CompressionOptions::Snappy => CompressionCodec::SNAPPY, + CompressionOptions::Gzip(_) => CompressionCodec::GZIP, + CompressionOptions::Lzo => CompressionCodec::LZO, + CompressionOptions::Brotli(_) => CompressionCodec::BROTLI, + CompressionOptions::Lz4 => CompressionCodec::LZ4, + CompressionOptions::Zstd(_) => CompressionCodec::ZSTD, + CompressionOptions::Lz4Raw => CompressionCodec::LZ4_RAW, + } + } +} + +/// Defines valid compression levels. +pub(crate) trait CompressionLevel { + const MINIMUM_LEVEL: T; + const MAXIMUM_LEVEL: T; + + /// Tests if the provided compression level is valid. + fn is_valid_level(level: T) -> Result<(), ParquetError> { + let compression_range = Self::MINIMUM_LEVEL..=Self::MAXIMUM_LEVEL; + if compression_range.contains(&level) { + Ok(()) + } else { + Err(ParquetError::InvalidParameter(format!( + "valid compression range {}..={} exceeded.", + compression_range.start(), + compression_range.end() + ))) + } + } +} + +/// Represents a valid brotli compression level. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub struct BrotliLevel(u32); + +impl Default for BrotliLevel { + fn default() -> Self { + Self(1) + } +} + +impl CompressionLevel for BrotliLevel { + const MINIMUM_LEVEL: u32 = 0; + const MAXIMUM_LEVEL: u32 = 11; +} + +impl BrotliLevel { + /// Attempts to create a brotli compression level. + /// + /// Compression levels must be valid. + pub fn try_new(level: u32) -> Result { + Self::is_valid_level(level).map(|_| Self(level)) + } + + /// Returns the compression level. + pub fn compression_level(&self) -> u32 { + self.0 + } +} + +/// Represents a valid gzip compression level. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub struct GzipLevel(u8); + +impl Default for GzipLevel { + fn default() -> Self { + // The default as of miniz_oxide 0.5.1 is 6 for compression level + // (miniz_oxide::deflate::CompressionLevel::DefaultLevel) + Self(6) + } +} + +impl CompressionLevel for GzipLevel { + const MINIMUM_LEVEL: u8 = 0; + const MAXIMUM_LEVEL: u8 = 10; +} + +impl GzipLevel { + /// Attempts to create a gzip compression level. + /// + /// Compression levels must be valid (i.e. be acceptable for [`flate2::Compression`]). + pub fn try_new(level: u8) -> Result { + Self::is_valid_level(level).map(|_| Self(level)) + } + + /// Returns the compression level. + pub fn compression_level(&self) -> u8 { + self.0 + } +} + +#[cfg(feature = "gzip")] +impl From for flate2::Compression { + fn from(level: GzipLevel) -> Self { + Self::new(level.compression_level() as u32) + } +} + +/// Represents a valid zstd compression level. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub struct ZstdLevel(i32); + +impl CompressionLevel for ZstdLevel { + // zstd binds to C, and hence zstd::compression_level_range() is not const as this calls the + // underlying C library. + const MINIMUM_LEVEL: i32 = 1; + const MAXIMUM_LEVEL: i32 = 22; +} + +impl ZstdLevel { + /// Attempts to create a zstd compression level from a given compression level. + /// + /// Compression levels must be valid (i.e. be acceptable for [`zstd::compression_level_range`]). + pub fn try_new(level: i32) -> Result { + Self::is_valid_level(level).map(|_| Self(level)) + } + + /// Returns the compression level. + pub fn compression_level(&self) -> i32 { + self.0 + } +} + +#[cfg(feature = "zstd")] +impl Default for ZstdLevel { + fn default() -> Self { + Self(zstd::DEFAULT_COMPRESSION_LEVEL) + } +} + +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub enum PageType { + DataPage, + DataPageV2, + DictionaryPage, +} + +impl TryFrom for PageType { + type Error = ParquetError; + + fn try_from(type_: ParquetPageType) -> Result { + Ok(match type_ { + ParquetPageType::DATA_PAGE => PageType::DataPage, + ParquetPageType::DATA_PAGE_V2 => PageType::DataPageV2, + ParquetPageType::DICTIONARY_PAGE => PageType::DictionaryPage, + _ => return Err(ParquetError::oos("Thrift out of range")), + }) + } +} + +impl From for ParquetPageType { + fn from(type_: PageType) -> Self { + match type_ { + PageType::DataPage => ParquetPageType::DATA_PAGE, + PageType::DataPageV2 => ParquetPageType::DATA_PAGE_V2, + PageType::DictionaryPage => ParquetPageType::DICTIONARY_PAGE, + } + } +} + +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub enum Encoding { + /// Default encoding. + /// BOOLEAN - 1 bit per value. 0 is false; 1 is true. + /// INT32 - 4 bytes per value. Stored as little-endian. + /// INT64 - 8 bytes per value. Stored as little-endian. + /// FLOAT - 4 bytes per value. IEEE. Stored as little-endian. + /// DOUBLE - 8 bytes per value. IEEE. Stored as little-endian. + /// BYTE_ARRAY - 4 byte length stored as little endian, followed by bytes. + /// FIXED_LEN_BYTE_ARRAY - Just the bytes. + Plain, + /// Deprecated: Dictionary encoding. The values in the dictionary are encoded in the + /// plain type. + /// in a data page use RLE_DICTIONARY instead. + /// in a Dictionary page use PLAIN instead + PlainDictionary, + /// Group packed run length encoding. Usable for definition/repetition levels + /// encoding and Booleans (on one bit: 0 is false; 1 is true.) + Rle, + /// Bit packed encoding. This can only be used if the data has a known max + /// width. Usable for definition/repetition levels encoding. + BitPacked, + /// Delta encoding for integers. This can be used for int columns and works best + /// on sorted data + DeltaBinaryPacked, + /// Encoding for byte arrays to separate the length values and the data. The lengths + /// are encoded using DELTA_BINARY_PACKED + DeltaLengthByteArray, + /// Incremental-encoded byte array. Prefix lengths are encoded using DELTA_BINARY_PACKED. + /// Suffixes are stored as delta length byte arrays. + DeltaByteArray, + /// Dictionary encoding: the ids are encoded using the RLE encoding + RleDictionary, + /// Encoding for floating-point data. + /// K byte-streams are created where K is the size in bytes of the data type. + /// The individual bytes of an FP value are scattered to the corresponding stream and + /// the streams are concatenated. + /// This itself does not reduce the size of the data but can lead to better compression + /// afterwards. + ByteStreamSplit, +} + +impl TryFrom for Encoding { + type Error = ParquetError; + + fn try_from(encoding: ParquetEncoding) -> Result { + Ok(match encoding { + ParquetEncoding::PLAIN => Encoding::Plain, + ParquetEncoding::PLAIN_DICTIONARY => Encoding::PlainDictionary, + ParquetEncoding::RLE => Encoding::Rle, + ParquetEncoding::BIT_PACKED => Encoding::BitPacked, + ParquetEncoding::DELTA_BINARY_PACKED => Encoding::DeltaBinaryPacked, + ParquetEncoding::DELTA_LENGTH_BYTE_ARRAY => Encoding::DeltaLengthByteArray, + ParquetEncoding::DELTA_BYTE_ARRAY => Encoding::DeltaByteArray, + ParquetEncoding::RLE_DICTIONARY => Encoding::RleDictionary, + ParquetEncoding::BYTE_STREAM_SPLIT => Encoding::ByteStreamSplit, + _ => return Err(ParquetError::oos("Thrift out of range")), + }) + } +} + +impl From for ParquetEncoding { + fn from(encoding: Encoding) -> Self { + match encoding { + Encoding::Plain => ParquetEncoding::PLAIN, + Encoding::PlainDictionary => ParquetEncoding::PLAIN_DICTIONARY, + Encoding::Rle => ParquetEncoding::RLE, + Encoding::BitPacked => ParquetEncoding::BIT_PACKED, + Encoding::DeltaBinaryPacked => ParquetEncoding::DELTA_BINARY_PACKED, + Encoding::DeltaLengthByteArray => ParquetEncoding::DELTA_LENGTH_BYTE_ARRAY, + Encoding::DeltaByteArray => ParquetEncoding::DELTA_BYTE_ARRAY, + Encoding::RleDictionary => ParquetEncoding::RLE_DICTIONARY, + Encoding::ByteStreamSplit => ParquetEncoding::BYTE_STREAM_SPLIT, + } + } +} + +/// Enum to annotate whether lists of min/max elements inside ColumnIndex +/// are ordered and if so, in which direction. +#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] +pub enum BoundaryOrder { + Unordered, + Ascending, + Descending, +} + +impl Default for BoundaryOrder { + fn default() -> Self { + Self::Unordered + } +} + +impl TryFrom for BoundaryOrder { + type Error = ParquetError; + + fn try_from(encoding: ParquetBoundaryOrder) -> Result { + Ok(match encoding { + ParquetBoundaryOrder::UNORDERED => BoundaryOrder::Unordered, + ParquetBoundaryOrder::ASCENDING => BoundaryOrder::Ascending, + ParquetBoundaryOrder::DESCENDING => BoundaryOrder::Descending, + _ => return Err(ParquetError::oos("BoundaryOrder Thrift value out of range")), + }) + } +} + +impl From for ParquetBoundaryOrder { + fn from(encoding: BoundaryOrder) -> Self { + match encoding { + BoundaryOrder::Unordered => ParquetBoundaryOrder::UNORDERED, + BoundaryOrder::Ascending => ParquetBoundaryOrder::ASCENDING, + BoundaryOrder::Descending => ParquetBoundaryOrder::DESCENDING, + } + } +} + +pub trait DataPageHeaderExt { + fn encoding(&self) -> Encoding; + fn repetition_level_encoding(&self) -> Encoding; + fn definition_level_encoding(&self) -> Encoding; +} + +impl DataPageHeaderExt for DataPageHeader { + fn encoding(&self) -> Encoding { + self.encoding.try_into().unwrap() + } + + fn repetition_level_encoding(&self) -> Encoding { + self.repetition_level_encoding.try_into().unwrap() + } + + fn definition_level_encoding(&self) -> Encoding { + self.definition_level_encoding.try_into().unwrap() + } +} + +impl DataPageHeaderExt for DataPageHeaderV2 { + fn encoding(&self) -> Encoding { + self.encoding.try_into().unwrap() + } + + fn repetition_level_encoding(&self) -> Encoding { + Encoding::Rle + } + + fn definition_level_encoding(&self) -> Encoding { + Encoding::Rle + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum TimeUnit { + Milliseconds, + Microseconds, + Nanoseconds, +} + +impl From for TimeUnit { + fn from(encoding: ParquetTimeUnit) -> Self { + match encoding { + ParquetTimeUnit::MILLIS(_) => TimeUnit::Milliseconds, + ParquetTimeUnit::MICROS(_) => TimeUnit::Microseconds, + ParquetTimeUnit::NANOS(_) => TimeUnit::Nanoseconds, + } + } +} + +impl From for ParquetTimeUnit { + fn from(unit: TimeUnit) -> Self { + match unit { + TimeUnit::Milliseconds => ParquetTimeUnit::MILLIS(Default::default()), + TimeUnit::Microseconds => ParquetTimeUnit::MICROS(Default::default()), + TimeUnit::Nanoseconds => ParquetTimeUnit::NANOS(Default::default()), + } + } +} + +/// Enum of all valid logical integer types +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum IntegerType { + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum PrimitiveLogicalType { + String, + Enum, + Decimal(usize, usize), + Date, + Time { + unit: TimeUnit, + is_adjusted_to_utc: bool, + }, + Timestamp { + unit: TimeUnit, + is_adjusted_to_utc: bool, + }, + Integer(IntegerType), + Unknown, + Json, + Bson, + Uuid, + Float16, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum GroupLogicalType { + Map, + List, +} + +impl From for ParquetLogicalType { + fn from(type_: GroupLogicalType) -> Self { + match type_ { + GroupLogicalType::Map => ParquetLogicalType::MAP(Default::default()), + GroupLogicalType::List => ParquetLogicalType::LIST(Default::default()), + } + } +} + +impl From<(i32, bool)> for IntegerType { + fn from((bit_width, is_signed): (i32, bool)) -> Self { + match (bit_width, is_signed) { + (8, true) => IntegerType::Int8, + (16, true) => IntegerType::Int16, + (32, true) => IntegerType::Int32, + (64, true) => IntegerType::Int64, + (8, false) => IntegerType::UInt8, + (16, false) => IntegerType::UInt16, + (32, false) => IntegerType::UInt32, + (64, false) => IntegerType::UInt64, + // The above are the only possible annotations for parquet's int32. Anything else + // is a deviation to the parquet specification and we ignore + _ => IntegerType::Int32, + } + } +} + +impl From for (usize, bool) { + fn from(type_: IntegerType) -> (usize, bool) { + match type_ { + IntegerType::Int8 => (8, true), + IntegerType::Int16 => (16, true), + IntegerType::Int32 => (32, true), + IntegerType::Int64 => (64, true), + IntegerType::UInt8 => (8, false), + IntegerType::UInt16 => (16, false), + IntegerType::UInt32 => (32, false), + IntegerType::UInt64 => (64, false), + } + } +} + +impl TryFrom for PrimitiveLogicalType { + type Error = ParquetError; + + fn try_from(type_: ParquetLogicalType) -> Result { + Ok(match type_ { + ParquetLogicalType::STRING(_) => PrimitiveLogicalType::String, + ParquetLogicalType::ENUM(_) => PrimitiveLogicalType::Enum, + ParquetLogicalType::DECIMAL(decimal) => PrimitiveLogicalType::Decimal( + decimal.precision.try_into()?, + decimal.scale.try_into()?, + ), + ParquetLogicalType::DATE(_) => PrimitiveLogicalType::Date, + ParquetLogicalType::TIME(time) => PrimitiveLogicalType::Time { + unit: time.unit.into(), + is_adjusted_to_utc: time.is_adjusted_to_u_t_c, + }, + ParquetLogicalType::TIMESTAMP(time) => PrimitiveLogicalType::Timestamp { + unit: time.unit.into(), + is_adjusted_to_utc: time.is_adjusted_to_u_t_c, + }, + ParquetLogicalType::INTEGER(int) => { + PrimitiveLogicalType::Integer((int.bit_width as i32, int.is_signed).into()) + }, + ParquetLogicalType::UNKNOWN(_) => PrimitiveLogicalType::Unknown, + ParquetLogicalType::JSON(_) => PrimitiveLogicalType::Json, + ParquetLogicalType::BSON(_) => PrimitiveLogicalType::Bson, + ParquetLogicalType::UUID(_) => PrimitiveLogicalType::Uuid, + ParquetLogicalType::FLOAT16(_) => PrimitiveLogicalType::Float16, + _ => return Err(ParquetError::oos("LogicalType value out of range")), + }) + } +} + +impl TryFrom for GroupLogicalType { + type Error = ParquetError; + + fn try_from(type_: ParquetLogicalType) -> Result { + Ok(match type_ { + ParquetLogicalType::LIST(_) => GroupLogicalType::List, + ParquetLogicalType::MAP(_) => GroupLogicalType::Map, + _ => return Err(ParquetError::oos("LogicalType value out of range")), + }) + } +} + +impl From for ParquetLogicalType { + fn from(type_: PrimitiveLogicalType) -> Self { + match type_ { + PrimitiveLogicalType::String => ParquetLogicalType::STRING(Default::default()), + PrimitiveLogicalType::Enum => ParquetLogicalType::ENUM(Default::default()), + PrimitiveLogicalType::Decimal(precision, scale) => { + ParquetLogicalType::DECIMAL(DecimalType { + precision: precision as i32, + scale: scale as i32, + }) + }, + PrimitiveLogicalType::Date => ParquetLogicalType::DATE(Default::default()), + PrimitiveLogicalType::Time { + unit, + is_adjusted_to_utc, + } => ParquetLogicalType::TIME(TimeType { + unit: unit.into(), + is_adjusted_to_u_t_c: is_adjusted_to_utc, + }), + PrimitiveLogicalType::Timestamp { + unit, + is_adjusted_to_utc, + } => ParquetLogicalType::TIMESTAMP(TimestampType { + unit: unit.into(), + is_adjusted_to_u_t_c: is_adjusted_to_utc, + }), + PrimitiveLogicalType::Integer(integer) => { + let (bit_width, is_signed) = integer.into(); + ParquetLogicalType::INTEGER(IntType { + bit_width: bit_width as i8, + is_signed, + }) + }, + PrimitiveLogicalType::Unknown => ParquetLogicalType::UNKNOWN(Default::default()), + PrimitiveLogicalType::Json => ParquetLogicalType::JSON(Default::default()), + PrimitiveLogicalType::Bson => ParquetLogicalType::BSON(Default::default()), + PrimitiveLogicalType::Uuid => ParquetLogicalType::UUID(Default::default()), + PrimitiveLogicalType::Float16 => ParquetLogicalType::FLOAT16(Default::default()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trip_primitive() -> Result<(), ParquetError> { + use PrimitiveLogicalType::*; + let a = vec![ + String, + Enum, + Decimal(3, 1), + Date, + Time { + unit: TimeUnit::Milliseconds, + is_adjusted_to_utc: true, + }, + Timestamp { + unit: TimeUnit::Milliseconds, + is_adjusted_to_utc: true, + }, + Integer(IntegerType::Int16), + Unknown, + Json, + Bson, + Uuid, + ]; + for a in a { + let c: ParquetLogicalType = a.into(); + let e: PrimitiveLogicalType = c.try_into()?; + assert_eq!(e, a); + } + Ok(()) + } + + #[test] + fn round_trip_encoding() -> Result<(), ParquetError> { + use Encoding::*; + let a = vec![ + Plain, + PlainDictionary, + Rle, + BitPacked, + DeltaBinaryPacked, + DeltaLengthByteArray, + DeltaByteArray, + RleDictionary, + ByteStreamSplit, + ]; + for a in a { + let c: ParquetEncoding = a.into(); + let e: Encoding = c.try_into()?; + assert_eq!(e, a); + } + Ok(()) + } + + #[test] + fn round_compression() -> Result<(), ParquetError> { + use Compression::*; + let a = vec![Uncompressed, Snappy, Gzip, Lzo, Brotli, Lz4, Zstd, Lz4Raw]; + for a in a { + let c: CompressionCodec = a.into(); + let e: Compression = c.try_into()?; + assert_eq!(e, a); + } + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/read/column/mod.rs b/crates/polars-parquet/src/parquet/read/column/mod.rs new file mode 100644 index 000000000000..406fc1d850c0 --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/column/mod.rs @@ -0,0 +1,148 @@ +use std::vec::IntoIter; + +use polars_utils::idx_vec::UnitVec; + +use super::{MemReader, PageReader, get_page_iterator}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::{ColumnChunkMetadata, RowGroupMetadata}; +use crate::parquet::page::CompressedPage; +use crate::parquet::schema::types::ParquetType; + +/// Returns a [`ColumnIterator`] of column chunks corresponding to `field`. +/// +/// Contrarily to [`get_page_iterator`] that returns a single iterator of pages, this iterator +/// iterates over columns, one by one, and returns a [`PageReader`] per column. +/// For primitive fields (e.g. `i64`), [`ColumnIterator`] yields exactly one column. +/// For complex fields, it yields multiple columns. +/// `max_page_size` is the maximum number of bytes allowed. +pub fn get_column_iterator<'a>( + reader: MemReader, + row_group: &'a RowGroupMetadata, + field_name: &str, + max_page_size: usize, +) -> ColumnIterator<'a> { + let columns = row_group + .columns_under_root_iter(field_name) + .unwrap() + .rev() + .collect::>(); + ColumnIterator::new(reader, columns, max_page_size) +} + +/// State of [`MutStreamingIterator`]. +#[derive(Debug)] +pub enum State { + /// Iterator still has elements + Some(T), + /// Iterator finished + Finished(Vec), +} + +/// A special kind of fallible streaming iterator where `advance` consumes the iterator. +pub trait MutStreamingIterator: Sized { + type Item; + type Error; + + fn advance(self) -> std::result::Result, Self::Error>; + fn get(&mut self) -> Option<&mut Self::Item>; +} + +/// A [`MutStreamingIterator`] that reads column chunks one by one, +/// returning a [`PageReader`] per column. +pub struct ColumnIterator<'a> { + reader: MemReader, + columns: UnitVec<&'a ColumnChunkMetadata>, + max_page_size: usize, +} + +impl<'a> ColumnIterator<'a> { + /// Returns a new [`ColumnIterator`] + /// `max_page_size` is the maximum allowed page size + pub fn new( + reader: MemReader, + columns: UnitVec<&'a ColumnChunkMetadata>, + max_page_size: usize, + ) -> Self { + Self { + reader, + columns, + max_page_size, + } + } +} + +impl<'a> Iterator for ColumnIterator<'a> { + type Item = ParquetResult<(PageReader, &'a ColumnChunkMetadata)>; + + fn next(&mut self) -> Option { + if self.columns.is_empty() { + return None; + }; + let column = self.columns.pop().unwrap(); + + let iter = + match get_page_iterator(column, self.reader.clone(), Vec::new(), self.max_page_size) { + Err(e) => return Some(Err(e)), + Ok(v) => v, + }; + Some(Ok((iter, column))) + } +} + +/// A [`MutStreamingIterator`] of pre-read column chunks +#[derive(Debug)] +pub struct ReadColumnIterator { + field: ParquetType, + chunks: Vec<( + Vec>, + ColumnChunkMetadata, + )>, + current: Option<( + IntoIter>, + ColumnChunkMetadata, + )>, +} + +impl ReadColumnIterator { + /// Returns a new [`ReadColumnIterator`] + pub fn new( + field: ParquetType, + chunks: Vec<( + Vec>, + ColumnChunkMetadata, + )>, + ) -> Self { + Self { + field, + chunks, + current: None, + } + } +} + +impl MutStreamingIterator for ReadColumnIterator { + type Item = ( + IntoIter>, + ColumnChunkMetadata, + ); + type Error = ParquetError; + + fn advance(mut self) -> Result, ParquetError> { + if self.chunks.is_empty() { + return Ok(State::Finished(vec![])); + } + self.current = self + .chunks + .pop() + .map(|(pages, meta)| (pages.into_iter(), meta)); + Ok(State::Some(Self { + field: self.field, + chunks: self.chunks, + current: self.current, + })) + } + + fn get(&mut self) -> Option<&mut Self::Item> { + self.current.as_mut() + } +} diff --git a/crates/polars-parquet/src/parquet/read/compression.rs b/crates/polars-parquet/src/parquet/read/compression.rs new file mode 100644 index 000000000000..b864b600075c --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/compression.rs @@ -0,0 +1,294 @@ +use polars_parquet_format::DataPageHeaderV2; + +use super::PageReader; +use crate::parquet::CowBuffer; +use crate::parquet::compression::{self, Compression}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::page::{ + CompressedDataPage, CompressedPage, DataPage, DataPageHeader, DictPage, Page, +}; + +fn decompress_v1( + compressed: &[u8], + compression: Compression, + buffer: &mut [u8], +) -> ParquetResult<()> { + compression::decompress(compression, compressed, buffer) +} + +fn decompress_v2( + compressed: &[u8], + page_header: &DataPageHeaderV2, + compression: Compression, + buffer: &mut [u8], +) -> ParquetResult<()> { + // When processing data page v2, depending on enabled compression for the + // page, we should account for uncompressed data ('offset') of + // repetition and definition levels. + // + // We always use 0 offset for other pages other than v2, `true` flag means + // that compression will be applied if decompressor is defined + let offset = (page_header.definition_levels_byte_length + + page_header.repetition_levels_byte_length) as usize; + // When is_compressed flag is missing the page is considered compressed + let can_decompress = page_header.is_compressed.unwrap_or(true); + + if can_decompress { + if offset > buffer.len() || offset > compressed.len() { + return Err(ParquetError::oos( + "V2 Page Header reported incorrect offset to compressed data", + )); + } + + (buffer[..offset]).copy_from_slice(&compressed[..offset]); + + // https://github.com/pola-rs/polars/issues/22170 + if compressed.len() > offset { + compression::decompress(compression, &compressed[offset..], &mut buffer[offset..])?; + } + } else { + if buffer.len() != compressed.len() { + return Err(ParquetError::oos( + "V2 Page Header reported incorrect decompressed size", + )); + } + buffer.copy_from_slice(compressed); + } + Ok(()) +} + +/// Decompresses the page, using `buffer` for decompression. +/// If `page.buffer.len() == 0`, there was no decompression and the buffer was moved. +/// Else, decompression took place. +pub fn decompress(compressed_page: CompressedPage, buffer: &mut Vec) -> ParquetResult { + Ok(match (compressed_page.compression(), compressed_page) { + (Compression::Uncompressed, CompressedPage::Data(page)) => Page::Data(DataPage::new_read( + page.header, + page.buffer, + page.descriptor, + )), + (_, CompressedPage::Data(page)) => { + // prepare the compression buffer + let read_size = page.uncompressed_size(); + + if read_size > buffer.capacity() { + // dealloc and ignore region, replacing it by a new region. + // This won't reallocate - it frees and calls `alloc_zeroed` + *buffer = vec![0; read_size]; + } else if read_size > buffer.len() { + // fill what we need with zeros so that we can use them in `Read`. + // This won't reallocate + buffer.resize(read_size, 0); + } else { + buffer.truncate(read_size); + } + + match page.header() { + DataPageHeader::V1(_) => decompress_v1(&page.buffer, page.compression, buffer)?, + DataPageHeader::V2(header) => { + decompress_v2(&page.buffer, header, page.compression, buffer)? + }, + } + let buffer = CowBuffer::Owned(std::mem::take(buffer)); + + Page::Data(DataPage::new_read(page.header, buffer, page.descriptor)) + }, + (Compression::Uncompressed, CompressedPage::Dict(page)) => Page::Dict(DictPage { + buffer: page.buffer, + num_values: page.num_values, + is_sorted: page.is_sorted, + }), + (_, CompressedPage::Dict(page)) => { + // prepare the compression buffer + let read_size = page.uncompressed_page_size; + + if read_size > buffer.capacity() { + // dealloc and ignore region, replacing it by a new region. + // This won't reallocate - it frees and calls `alloc_zeroed` + *buffer = vec![0; read_size]; + } else if read_size > buffer.len() { + // fill what we need with zeros so that we can use them in `Read`. + // This won't reallocate + buffer.resize(read_size, 0); + } else { + buffer.truncate(read_size); + } + decompress_v1(&page.buffer, page.compression(), buffer)?; + let buffer = CowBuffer::Owned(std::mem::take(buffer)); + + Page::Dict(DictPage { + buffer, + num_values: page.num_values, + is_sorted: page.is_sorted, + }) + }, + }) +} + +type _Decompressor = streaming_decompression::Decompressor< + CompressedPage, + Page, + fn(CompressedPage, &mut Vec) -> ParquetResult, + ParquetError, + I, +>; + +impl streaming_decompression::Compressed for CompressedPage { + #[inline] + fn is_compressed(&self) -> bool { + self.compression() != Compression::Uncompressed + } +} + +impl streaming_decompression::Decompressed for Page { + #[inline] + fn buffer_mut(&mut self) -> &mut Vec { + self.buffer_mut() + } +} + +/// A [`FallibleStreamingIterator`] that decompresses [`CompressedPage`] into [`DataPage`]. +/// # Implementation +/// This decompressor uses an internal [`Vec`] to perform decompressions which +/// is reused across pages, so that a single allocation is required. +/// If the pages are not compressed, the internal buffer is not used. +pub struct BasicDecompressor { + reader: PageReader, + buffer: Vec, +} + +impl BasicDecompressor { + /// Create a new [`BasicDecompressor`] + pub fn new(reader: PageReader, buffer: Vec) -> Self { + Self { reader, buffer } + } + + /// The total number of values is given from the `ColumnChunk` metadata. + /// + /// - Nested column: equal to the number of non-null values at the lowest nesting level. + /// - Unnested column: equal to the number of non-null rows. + pub fn total_num_values(&self) -> usize { + self.reader.total_num_values() + } + + /// Returns its internal buffer, consuming itself. + pub fn into_inner(self) -> Vec { + self.buffer + } + + pub fn read_dict_page(&mut self) -> ParquetResult> { + match self.reader.read_dict()? { + None => Ok(None), + Some(p) => { + let num_values = p.num_values; + let page = + decompress(CompressedPage::Dict(p), &mut Vec::with_capacity(num_values))?; + + match page { + Page::Dict(d) => Ok(Some(d)), + Page::Data(_) => unreachable!(), + } + }, + } + } + + pub fn reuse_page_buffer(&mut self, page: DataPage) { + let buffer = match page.buffer { + CowBuffer::Borrowed(_) => return, + CowBuffer::Owned(vec) => vec, + }; + + if self.buffer.capacity() > buffer.capacity() { + return; + }; + + self.buffer = buffer; + } +} + +pub struct DataPageItem { + page: CompressedDataPage, +} + +impl DataPageItem { + pub fn num_values(&self) -> usize { + self.page.num_values() + } + + pub fn page(&self) -> &CompressedDataPage { + &self.page + } + + pub fn decompress(self, decompressor: &mut BasicDecompressor) -> ParquetResult { + let p = decompress(CompressedPage::Data(self.page), &mut decompressor.buffer)?; + let Page::Data(p) = p else { + panic!("Decompressing a data page should result in a data page"); + }; + + Ok(p) + } +} + +impl Iterator for BasicDecompressor { + type Item = ParquetResult; + + fn next(&mut self) -> Option { + let page = match self.reader.next() { + None => return None, + Some(Err(e)) => return Some(Err(e)), + Some(Ok(p)) => p, + }; + + let CompressedPage::Data(page) = page else { + return Some(Err(ParquetError::oos( + "Found dictionary page beyond the first page of a column chunk", + ))); + }; + + Some(Ok(DataPageItem { page })) + } + + fn size_hint(&self) -> (usize, Option) { + self.reader.size_hint() + } +} + +#[cfg(test)] +mod tests { + use polars_parquet_format::Encoding; + + use super::*; + + #[test] + fn test_decompress_v2_empty_datapage() { + let compressions = [ + Compression::Snappy, + Compression::Gzip, + Compression::Lzo, + Compression::Brotli, + Compression::Lz4, + Compression::Zstd, + Compression::Lz4Raw, + ]; + + // this datapage has an empty compressed section after the first two bytes (uncompressed definition levels) + let compressed: &mut Vec = &mut vec![0x03, 0x00]; + let page_header = DataPageHeaderV2::new(1, 1, 1, Encoding::PLAIN, 2, 0, true, None); + let buffer: &mut Vec = &mut vec![0, 2]; + + compressions.iter().for_each(|compression| { + test_decompress_v2_datapage(compressed, &page_header, *compression, buffer, compressed) + }); + } + + fn test_decompress_v2_datapage( + compressed: &[u8], + page_header: &DataPageHeaderV2, + compression: Compression, + buffer: &mut [u8], + expected: &[u8], + ) { + decompress_v2(compressed, page_header, compression, buffer).unwrap(); + assert_eq!(buffer, expected); + } +} diff --git a/crates/polars-parquet/src/parquet/read/levels.rs b/crates/polars-parquet/src/parquet/read/levels.rs new file mode 100644 index 000000000000..69d12cff9194 --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/levels.rs @@ -0,0 +1,27 @@ +/// Returns the number of bits needed to store the given maximum definition or repetition level. +#[inline] +pub fn get_bit_width(max_level: i16) -> u32 { + 16 - max_level.leading_zeros() +} + +#[cfg(test)] +mod tests { + use super::get_bit_width; + + #[test] + fn test_get_bit_width() { + assert_eq!(0, get_bit_width(0)); + assert_eq!(1, get_bit_width(1)); + assert_eq!(2, get_bit_width(2)); + assert_eq!(2, get_bit_width(3)); + assert_eq!(3, get_bit_width(4)); + assert_eq!(3, get_bit_width(5)); + assert_eq!(3, get_bit_width(6)); + assert_eq!(3, get_bit_width(7)); + assert_eq!(4, get_bit_width(8)); + assert_eq!(4, get_bit_width(15)); + + assert_eq!(8, get_bit_width(255)); + assert_eq!(9, get_bit_width(256)); + } +} diff --git a/crates/polars-parquet/src/parquet/read/metadata.rs b/crates/polars-parquet/src/parquet/read/metadata.rs new file mode 100644 index 000000000000..627fdd8c92ba --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/metadata.rs @@ -0,0 +1,100 @@ +use std::cmp::min; +use std::io::{Read, Seek, SeekFrom}; + +use polars_parquet_format::FileMetaData as TFileMetadata; +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; + +use super::super::metadata::FileMetadata; +use super::super::{DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, HEADER_SIZE, PARQUET_MAGIC}; +use crate::parquet::error::{ParquetError, ParquetResult}; + +pub(super) fn metadata_len(buffer: &[u8], len: usize) -> i32 { + i32::from_le_bytes(buffer[len - 8..len - 4].try_into().unwrap()) +} + +// see (unstable) Seek::stream_len +fn stream_len(seek: &mut impl Seek) -> std::result::Result { + let old_pos = seek.stream_position()?; + let len = seek.seek(SeekFrom::End(0))?; + + // Avoid seeking a third time when we were already at the end of the + // stream. The branch is usually way cheaper than a seek operation. + if old_pos != len { + seek.seek(SeekFrom::Start(old_pos))?; + } + + Ok(len) +} + +/// Reads a [`FileMetadata`] from the reader, located at the end of the file. +pub fn read_metadata(reader: &mut R) -> ParquetResult { + // check file is large enough to hold footer + let file_size = stream_len(reader)?; + read_metadata_with_size(reader, file_size) +} + +/// Reads a [`FileMetadata`] from the reader, located at the end of the file, with known file size. +pub fn read_metadata_with_size( + reader: &mut R, + file_size: u64, +) -> ParquetResult { + if file_size < HEADER_SIZE + FOOTER_SIZE { + return Err(ParquetError::oos( + "A parquet file must contain a header and footer with at least 12 bytes", + )); + } + + // read and cache up to DEFAULT_FOOTER_READ_SIZE bytes from the end and process the footer + let default_end_len = min(DEFAULT_FOOTER_READ_SIZE, file_size) as usize; + reader.seek(SeekFrom::End(-(default_end_len as i64)))?; + + let mut buffer = Vec::with_capacity(default_end_len); + reader + .by_ref() + .take(default_end_len as u64) + .read_to_end(&mut buffer)?; + + // check this is indeed a parquet file + if buffer[default_end_len - 4..] != PARQUET_MAGIC { + return Err(ParquetError::oos("The file must end with PAR1")); + } + + let metadata_len = metadata_len(&buffer, default_end_len); + + let metadata_len: u64 = metadata_len.try_into()?; + + let footer_len = FOOTER_SIZE + metadata_len; + if footer_len > file_size { + return Err(ParquetError::oos( + "The footer size must be smaller or equal to the file's size", + )); + } + + let reader: &[u8] = if (footer_len as usize) < buffer.len() { + // the whole metadata is in the bytes we already read + let remaining = buffer.len() - footer_len as usize; + &buffer[remaining..] + } else { + // the end of file read by default is not long enough, read again including the metadata. + reader.seek(SeekFrom::End(-(footer_len as i64)))?; + + buffer.clear(); + buffer.try_reserve(footer_len as usize)?; + reader.take(footer_len).read_to_end(&mut buffer)?; + + &buffer + }; + + // a highly nested but sparse struct could result in many allocations + let max_size = reader.len() * 2 + 1024; + + deserialize_metadata(reader, max_size) +} + +/// Parse loaded metadata bytes +pub fn deserialize_metadata(reader: R, max_size: usize) -> ParquetResult { + let mut prot = TCompactInputProtocol::new(reader, max_size); + let metadata = TFileMetadata::read_from_in_protocol(&mut prot)?; + + FileMetadata::try_from_thrift(metadata) +} diff --git a/crates/polars-parquet/src/parquet/read/mod.rs b/crates/polars-parquet/src/parquet/read/mod.rs new file mode 100644 index 000000000000..bddd59ee8e8e --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/mod.rs @@ -0,0 +1,39 @@ +mod column; +mod compression; +pub mod levels; +mod metadata; +mod page; +#[cfg(feature = "async")] +mod stream; + +use std::io::{Seek, SeekFrom}; + +pub use column::*; +pub use compression::{BasicDecompressor, decompress}; +pub use metadata::{deserialize_metadata, read_metadata, read_metadata_with_size}; +pub use page::{PageIterator, PageMetaData, PageReader}; +#[cfg(feature = "async")] +pub use page::{get_page_stream, get_page_stream_from_column_start}; +use polars_utils::mmap::MemReader; +#[cfg(feature = "async")] +pub use stream::read_metadata as read_metadata_async; + +use crate::parquet::error::ParquetResult; +use crate::parquet::metadata::ColumnChunkMetadata; + +/// Returns a new [`PageReader`] by seeking `reader` to the beginning of `column_chunk`. +pub fn get_page_iterator( + column_chunk: &ColumnChunkMetadata, + mut reader: MemReader, + scratch: Vec, + max_page_size: usize, +) -> ParquetResult { + let col_start = column_chunk.byte_range().start; + reader.seek(SeekFrom::Start(col_start))?; + Ok(PageReader::new( + reader, + column_chunk, + scratch, + max_page_size, + )) +} diff --git a/crates/polars-parquet/src/parquet/read/page/mod.rs b/crates/polars-parquet/src/parquet/read/page/mod.rs new file mode 100644 index 000000000000..14801839a693 --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/page/mod.rs @@ -0,0 +1,16 @@ +mod reader; +#[cfg(feature = "async")] +mod stream; + +pub use reader::{PageMetaData, PageReader}; + +use crate::parquet::error::ParquetError; +use crate::parquet::page::CompressedPage; + +pub trait PageIterator: Iterator> { + fn swap_buffer(&mut self, buffer: &mut Vec); +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub use stream::{get_page_stream, get_page_stream_from_column_start}; diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs new file mode 100644 index 000000000000..6bf04c173663 --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -0,0 +1,352 @@ +use std::io::Seek; +use std::sync::OnceLock; + +use polars_parquet_format::thrift::protocol::TCompactInputProtocol; +use polars_utils::mmap::{MemReader, MemSlice}; + +use super::PageIterator; +use crate::parquet::CowBuffer; +use crate::parquet::compression::Compression; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::{ColumnChunkMetadata, Descriptor}; +use crate::parquet::page::{ + CompressedDataPage, CompressedDictPage, CompressedPage, DataPageHeader, PageType, + ParquetPageHeader, +}; +use crate::write::Encoding; + +/// This meta is a small part of [`ColumnChunkMetadata`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PageMetaData { + /// The start offset of this column chunk in file. + pub column_start: u64, + /// The number of values in this column chunk. + pub num_values: i64, + /// Compression type + pub compression: Compression, + /// The descriptor of this parquet column + pub descriptor: Descriptor, +} + +impl PageMetaData { + /// Returns a new [`PageMetaData`]. + pub fn new( + column_start: u64, + num_values: i64, + compression: Compression, + descriptor: Descriptor, + ) -> Self { + Self { + column_start, + num_values, + compression, + descriptor, + } + } +} + +impl From<&ColumnChunkMetadata> for PageMetaData { + fn from(column: &ColumnChunkMetadata) -> Self { + Self { + column_start: column.byte_range().start, + num_values: column.num_values(), + compression: column.compression(), + descriptor: column.descriptor().descriptor.clone(), + } + } +} + +/// A fallible [`Iterator`] of [`CompressedDataPage`]. This iterator reads pages back +/// to back until all pages have been consumed. +/// +/// The pages from this iterator always have [`None`] [`crate::parquet::page::CompressedDataPage::selected_rows()`] since +/// filter pushdown is not supported without a +/// pre-computed [page index](https://github.com/apache/parquet-format/blob/master/PageIndex.md). +pub struct PageReader { + // The source + reader: MemReader, + + compression: Compression, + + // The number of values we have seen so far. + seen_num_values: i64, + + // The number of total values in this column chunk. + total_num_values: i64, + + descriptor: Descriptor, + + // The currently allocated buffer. + pub(crate) scratch: Vec, + + // Maximum page size (compressed or uncompressed) to limit allocations + max_page_size: usize, +} + +impl PageReader { + /// Returns a new [`PageReader`]. + /// + /// It assumes that the reader has been `sought` (`seek`) to the beginning of `column`. + /// The parameter `max_header_size` + pub fn new( + reader: MemReader, + column: &ColumnChunkMetadata, + scratch: Vec, + max_page_size: usize, + ) -> Self { + Self::new_with_page_meta(reader, column.into(), scratch, max_page_size) + } + + /// Create a new [`PageReader`] with [`PageMetaData`]. + /// + /// It assumes that the reader has been `sought` (`seek`) to the beginning of `column`. + pub fn new_with_page_meta( + reader: MemReader, + reader_meta: PageMetaData, + scratch: Vec, + max_page_size: usize, + ) -> Self { + Self { + reader, + total_num_values: reader_meta.num_values, + compression: reader_meta.compression, + seen_num_values: 0, + descriptor: reader_meta.descriptor, + scratch, + max_page_size, + } + } + + /// Returns the reader and this Readers' interval buffer + pub fn into_inner(self) -> (MemReader, Vec) { + (self.reader, self.scratch) + } + + pub fn total_num_values(&self) -> usize { + debug_assert!(self.total_num_values >= 0); + self.total_num_values as usize + } + + pub fn read_dict(&mut self) -> ParquetResult> { + // If there are no pages, we cannot check if the first page is a dictionary page. Just + // return the fact there is no dictionary page. + if self.reader.remaining_len() == 0 { + return Ok(None); + } + + // a dictionary page exists iff the first data page is not at the start of + // the column + let seek_offset = self.reader.position(); + let page_header = read_page_header(&mut self.reader, self.max_page_size)?; + let page_type = page_header.type_.try_into()?; + + if !matches!(page_type, PageType::DictionaryPage) { + self.reader + .seek(std::io::SeekFrom::Start(seek_offset as u64))?; + return Ok(None); + } + + let read_size: usize = page_header.compressed_page_size.try_into()?; + + if read_size > self.max_page_size { + return Err(ParquetError::WouldOverAllocate); + } + + let buffer = self.reader.read_slice(read_size); + + if buffer.len() != read_size { + return Err(ParquetError::oos( + "The page header reported the wrong page size", + )); + } + + finish_page(page_header, buffer, self.compression, &self.descriptor).map(|p| { + if let CompressedPage::Dict(d) = p { + Some(d) + } else { + unreachable!() + } + }) + } +} + +impl PageIterator for PageReader { + fn swap_buffer(&mut self, scratch: &mut Vec) { + std::mem::swap(&mut self.scratch, scratch) + } +} + +impl Iterator for PageReader { + type Item = ParquetResult; + + fn next(&mut self) -> Option { + let mut buffer = std::mem::take(&mut self.scratch); + let maybe_maybe_page = next_page(self).transpose(); + if maybe_maybe_page.is_none() { + // no page => we take back the buffer + self.scratch = std::mem::take(&mut buffer); + } + maybe_maybe_page + } +} + +/// Reads Page header from Thrift. +pub(super) fn read_page_header( + reader: &mut MemReader, + max_size: usize, +) -> ParquetResult { + let mut prot = TCompactInputProtocol::new(reader, max_size); + let page_header = ParquetPageHeader::read_from_in_protocol(&mut prot)?; + Ok(page_header) +} + +/// This function is lightweight and executes a minimal amount of work so that it is IO bounded. +// Any un-necessary CPU-intensive tasks SHOULD be executed on individual pages. +fn next_page(reader: &mut PageReader) -> ParquetResult> { + if reader.seen_num_values >= reader.total_num_values { + return Ok(None); + }; + build_page(reader) +} + +pub(super) fn build_page(reader: &mut PageReader) -> ParquetResult> { + let page_header = read_page_header(&mut reader.reader, reader.max_page_size)?; + + reader.seen_num_values += get_page_num_values(&page_header)? as i64; + + let read_size: usize = page_header.compressed_page_size.try_into()?; + + if read_size > reader.max_page_size { + return Err(ParquetError::WouldOverAllocate); + } + + let buffer = reader.reader.read_slice(read_size); + + if buffer.len() != read_size { + return Err(ParquetError::oos( + "The page header reported the wrong page size", + )); + } + + finish_page(page_header, buffer, reader.compression, &reader.descriptor).map(Some) +} + +pub(super) fn finish_page( + page_header: ParquetPageHeader, + data: MemSlice, + compression: Compression, + descriptor: &Descriptor, +) -> ParquetResult { + let type_ = page_header.type_.try_into()?; + let uncompressed_page_size = page_header.uncompressed_page_size.try_into()?; + + static DO_VERBOSE: OnceLock = OnceLock::new(); + let do_verbose = *DO_VERBOSE.get_or_init(|| std::env::var("PARQUET_DO_VERBOSE").is_ok()); + + match type_ { + PageType::DictionaryPage => { + let dict_header = page_header.dictionary_page_header.as_ref().ok_or_else(|| { + ParquetError::oos( + "The page header type is a dictionary page but the dictionary header is empty", + ) + })?; + + if do_verbose { + eprintln!( + "Parquet DictPage ( num_values: {}, datatype: {:?} )", + dict_header.num_values, descriptor.primitive_type + ); + } + + let is_sorted = dict_header.is_sorted.unwrap_or(false); + + // move the buffer to `dict_page` + let page = CompressedDictPage::new( + CowBuffer::Borrowed(data), + compression, + uncompressed_page_size, + dict_header.num_values.try_into()?, + is_sorted, + ); + + Ok(CompressedPage::Dict(page)) + }, + PageType::DataPage => { + let header = page_header.data_page_header.ok_or_else(|| { + ParquetError::oos( + "The page header type is a v1 data page but the v1 data header is empty", + ) + })?; + + if do_verbose { + eprintln!( + "Parquet DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() + ); + } + + Ok(CompressedPage::Data(CompressedDataPage::new_read( + DataPageHeader::V1(header), + CowBuffer::Borrowed(data), + compression, + uncompressed_page_size, + descriptor.clone(), + ))) + }, + PageType::DataPageV2 => { + let header = page_header.data_page_header_v2.ok_or_else(|| { + ParquetError::oos( + "The page header type is a v2 data page but the v2 data header is empty", + ) + })?; + + if do_verbose { + println!( + "Parquet DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() + ); + } + + Ok(CompressedPage::Data(CompressedDataPage::new_read( + DataPageHeader::V2(header), + CowBuffer::Borrowed(data), + compression, + uncompressed_page_size, + descriptor.clone(), + ))) + }, + } +} + +pub(super) fn get_page_num_values(header: &ParquetPageHeader) -> ParquetResult { + let type_ = header.type_.try_into()?; + Ok(match type_ { + PageType::DataPage => { + header + .data_page_header + .as_ref() + .ok_or_else(|| { + ParquetError::oos( + "The page header type is a v1 data page but the v1 header is empty", + ) + })? + .num_values + }, + PageType::DataPageV2 => { + header + .data_page_header_v2 + .as_ref() + .ok_or_else(|| { + ParquetError::oos( + "The page header type is a v1 data page but the v1 header is empty", + ) + })? + .num_values + }, + _ => 0, + }) +} diff --git a/crates/polars-parquet/src/parquet/read/page/stream.rs b/crates/polars-parquet/src/parquet/read/page/stream.rs new file mode 100644 index 000000000000..c504845f5cc4 --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/page/stream.rs @@ -0,0 +1,144 @@ +use std::io::SeekFrom; + +use async_stream::try_stream; +use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream}; +use polars_parquet_format::thrift::protocol::TCompactInputStreamProtocol; +use polars_utils::mmap::MemSlice; + +use super::reader::{PageMetaData, finish_page}; +use crate::parquet::compression::Compression; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::{ColumnChunkMetadata, Descriptor}; +use crate::parquet::page::{CompressedPage, DataPageHeader, ParquetPageHeader}; +use crate::parquet::parquet_bridge::{Encoding, PageType}; + +/// Returns a stream of compressed data pages +pub async fn get_page_stream<'a, RR: AsyncRead + Unpin + Send + AsyncSeek>( + column_metadata: &'a ColumnChunkMetadata, + reader: &'a mut RR, + scratch: Vec, + max_page_size: usize, +) -> ParquetResult> + 'a> { + get_page_stream_with_page_meta(column_metadata.into(), reader, scratch, max_page_size).await +} + +/// Returns a stream of compressed data pages from a reader that begins at the start of the column +pub async fn get_page_stream_from_column_start<'a, R: AsyncRead + Unpin + Send>( + column_metadata: &'a ColumnChunkMetadata, + reader: &'a mut R, + scratch: Vec, + max_header_size: usize, +) -> ParquetResult> + 'a> { + let page_metadata: PageMetaData = column_metadata.into(); + Ok(_get_page_stream( + reader, + page_metadata.num_values, + page_metadata.compression, + page_metadata.descriptor, + scratch, + max_header_size, + )) +} + +/// Returns a stream of compressed data pages with [`PageMetaData`] +pub async fn get_page_stream_with_page_meta( + page_metadata: PageMetaData, + reader: &mut RR, + scratch: Vec, + max_page_size: usize, +) -> ParquetResult> + '_> { + let column_start = page_metadata.column_start; + reader.seek(SeekFrom::Start(column_start)).await?; + Ok(_get_page_stream( + reader, + page_metadata.num_values, + page_metadata.compression, + page_metadata.descriptor, + scratch, + max_page_size, + )) +} + +fn _get_page_stream( + reader: &mut R, + total_num_values: i64, + compression: Compression, + descriptor: Descriptor, + mut scratch: Vec, + max_page_size: usize, +) -> impl Stream> + '_ { + let mut seen_values = 0i64; + try_stream! { + while seen_values < total_num_values { + // the header + let page_header = read_page_header(reader, max_page_size).await?; + + let data_header = get_page_header(&page_header)?; + seen_values += data_header.as_ref().map(|x| x.num_values() as i64).unwrap_or_default(); + + let read_size: usize = page_header.compressed_page_size.try_into()?; + + if read_size > max_page_size { + Err(ParquetError::WouldOverAllocate)? + } + + // followed by the buffer + scratch.clear(); + scratch.try_reserve(read_size)?; + let bytes_read = reader + .take(read_size as u64) + .read_to_end(&mut scratch).await?; + + if bytes_read != read_size { + Err(ParquetError::oos( + "The page header reported the wrong page size", + ))? + } + + yield finish_page( + page_header, + MemSlice::from_vec(std::mem::take(&mut scratch)), + compression, + &descriptor, + )?; + } + } +} + +/// Reads Page header from Thrift. +async fn read_page_header( + reader: &mut R, + max_page_size: usize, +) -> ParquetResult { + let mut prot = TCompactInputStreamProtocol::new(reader, max_page_size); + let page_header = ParquetPageHeader::stream_from_in_protocol(&mut prot).await?; + Ok(page_header) +} + +pub(super) fn get_page_header(header: &ParquetPageHeader) -> ParquetResult> { + let type_ = header.type_.try_into()?; + Ok(match type_ { + PageType::DataPage => { + let header = header.data_page_header.clone().ok_or_else(|| { + ParquetError::oos( + "The page header type is a v1 data page but the v1 header is empty", + ) + })?; + let _: Encoding = header.encoding.try_into()?; + let _: Encoding = header.repetition_level_encoding.try_into()?; + let _: Encoding = header.definition_level_encoding.try_into()?; + + Some(DataPageHeader::V1(header)) + }, + PageType::DataPageV2 => { + let header = header.data_page_header_v2.clone().ok_or_else(|| { + ParquetError::oos( + "The page header type is a v1 data page but the v1 header is empty", + ) + })?; + let _: Encoding = header.encoding.try_into()?; + Some(DataPageHeader::V2(header)) + }, + _ => None, + }) +} diff --git a/crates/polars-parquet/src/parquet/read/stream.rs b/crates/polars-parquet/src/parquet/read/stream.rs new file mode 100644 index 000000000000..6925de5f0fd5 --- /dev/null +++ b/crates/polars-parquet/src/parquet/read/stream.rs @@ -0,0 +1,88 @@ +use std::io::SeekFrom; + +use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; + +use super::super::metadata::FileMetadata; +use super::super::{DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, PARQUET_MAGIC}; +use super::metadata::{deserialize_metadata, metadata_len}; +use crate::parquet::HEADER_SIZE; +use crate::parquet::error::{ParquetError, ParquetResult}; + +async fn stream_len( + seek: &mut (impl AsyncSeek + std::marker::Unpin), +) -> std::result::Result { + let old_pos = seek.seek(SeekFrom::Current(0)).await?; + let len = seek.seek(SeekFrom::End(0)).await?; + + // Avoid seeking a third time when we were already at the end of the + // stream. The branch is usually way cheaper than a seek operation. + if old_pos != len { + seek.seek(SeekFrom::Start(old_pos)).await?; + } + + Ok(len) +} + +/// Asynchronously reads the files' metadata +pub async fn read_metadata( + reader: &mut R, +) -> ParquetResult { + let file_size = stream_len(reader).await?; + + if file_size < HEADER_SIZE + FOOTER_SIZE { + return Err(ParquetError::oos( + "A parquet file must contain a header and footer with at least 12 bytes", + )); + } + + // read and cache up to DEFAULT_FOOTER_READ_SIZE bytes from the end and process the footer + let default_end_len = std::cmp::min(DEFAULT_FOOTER_READ_SIZE, file_size) as usize; + reader + .seek(SeekFrom::End(-(default_end_len as i64))) + .await?; + + let mut buffer = vec![]; + buffer.try_reserve(default_end_len)?; + reader + .take(default_end_len as u64) + .read_to_end(&mut buffer) + .await?; + + // check this is indeed a parquet file + if buffer[default_end_len - 4..] != PARQUET_MAGIC { + return Err(ParquetError::oos("Invalid Parquet file. Corrupt footer")); + } + + let metadata_len = metadata_len(&buffer, default_end_len); + let metadata_len: u64 = metadata_len.try_into()?; + + let footer_len = FOOTER_SIZE + metadata_len; + if footer_len > file_size { + return Err(ParquetError::oos( + "The footer size must be smaller or equal to the file's size", + )); + } + + let reader = if (footer_len as usize) < buffer.len() { + // the whole metadata is in the bytes we already read + let remaining = buffer.len() - footer_len as usize; + &buffer[remaining..] + } else { + // the end of file read by default is not long enough, read again including the metadata. + reader.seek(SeekFrom::End(-(footer_len as i64))).await?; + + buffer.clear(); + buffer.try_reserve(footer_len as usize)?; + reader + .take(footer_len as u64) + .read_to_end(&mut buffer) + .await?; + + &buffer + }; + + // a highly nested but sparse struct could result in many allocations + let max_size = reader.len() * 2 + 1024; + + deserialize_metadata(reader, max_size) +} diff --git a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs new file mode 100644 index 000000000000..71b490ab623a --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs @@ -0,0 +1,1178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parquet schema parser. +//! Provides methods to parse and validate string message type into Parquet +//! [`ParquetType`](crate::parquet::schema::types::ParquetType). +//! +//! # Example +//! +//! ```rust +//! use polars_parquet::parquet::schema::io_message::from_message; +//! +//! let message_type = " +//! message spark_schema { +//! OPTIONAL BYTE_ARRAY a (UTF8); +//! REQUIRED INT32 b; +//! REQUIRED DOUBLE c; +//! REQUIRED BOOLEAN d; +//! OPTIONAL group e (LIST) { +//! REPEATED group list { +//! REQUIRED INT32 element; +//! } +//! } +//! } +//! "; +//! +//! let schema = from_message(message_type).expect("Expected valid schema"); +//! println!("{:?}", schema); +//! ``` + +use polars_parquet_format::Type; +use polars_utils::pl_str::PlSmallStr; +use types::PrimitiveLogicalType; + +use super::super::types::{ParquetType, TimeUnit}; +use super::super::*; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::{GroupConvertedType, PrimitiveConvertedType}; + +fn is_logical_type(s: &str) -> bool { + matches!( + s, + "INTEGER" + | "MAP" + | "LIST" + | "ENUM" + | "DECIMAL" + | "DATE" + | "TIME" + | "TIMESTAMP" + | "STRING" + | "JSON" + | "BSON" + | "UUID" + | "UNKNOWN" + | "INTERVAL" + ) +} + +fn is_converted_type(s: &str) -> bool { + matches!( + s, + "UTF8" + | "ENUM" + | "DECIMAL" + | "DATE" + | "TIME_MILLIS" + | "TIME_MICROS" + | "TIMESTAMP_MILLIS" + | "TIMESTAMP_MICROS" + | "UINT_8" + | "UINT_16" + | "UINT_32" + | "UINT_64" + | "INT_8" + | "INT_16" + | "INT_32" + | "INT_64" + | "JSON" + | "BSON" + | "INTERVAL" + ) +} + +fn converted_group_from_str(s: &str) -> ParquetResult { + Ok(match s { + "MAP" => GroupConvertedType::Map, + "MAP_KEY_VALUE" => GroupConvertedType::MapKeyValue, + "LIST" => GroupConvertedType::List, + other => { + return Err(ParquetError::oos(format!( + "Invalid converted type {}", + other + ))); + }, + }) +} + +fn converted_primitive_from_str(s: &str) -> Option { + use PrimitiveConvertedType::*; + Some(match s { + "UTF8" => Utf8, + "ENUM" => Enum, + "DECIMAL" => Decimal(0, 0), + "DATE" => Date, + "TIME_MILLIS" => TimeMillis, + "TIME_MICROS" => TimeMicros, + "TIMESTAMP_MILLIS" => TimestampMillis, + "TIMESTAMP_MICROS" => TimestampMicros, + "UINT_8" => Uint8, + "UINT_16" => Uint16, + "UINT_32" => Uint32, + "UINT_64" => Uint64, + "INT_8" => Int8, + "INT_16" => Int16, + "INT_32" => Int32, + "INT_64" => Int64, + "JSON" => Json, + "BSON" => Bson, + "INTERVAL" => Interval, + _ => return None, + }) +} + +fn repetition_from_str(s: &str) -> ParquetResult { + Ok(match s { + "REQUIRED" => Repetition::Required, + "OPTIONAL" => Repetition::Optional, + "REPEATED" => Repetition::Repeated, + other => return Err(ParquetError::oos(format!("Invalid repetition {}", other))), + }) +} + +fn type_from_str(s: &str) -> ParquetResult { + match s { + "BOOLEAN" => Ok(Type::BOOLEAN), + "INT32" => Ok(Type::INT32), + "INT64" => Ok(Type::INT64), + "INT96" => Ok(Type::INT96), + "FLOAT" => Ok(Type::FLOAT), + "DOUBLE" => Ok(Type::DOUBLE), + "BYTE_ARRAY" | "BINARY" => Ok(Type::BYTE_ARRAY), + "FIXED_LEN_BYTE_ARRAY" => Ok(Type::FIXED_LEN_BYTE_ARRAY), + other => Err(ParquetError::oos(format!("Invalid type {}", other))), + } +} + +/// Parses message type as string into a Parquet [`ParquetType`](crate::parquet::schema::types::ParquetType). +/// +/// This could, for example, be used to extract individual columns. +/// +/// Returns Parquet general error when parsing or validation fails. +pub fn from_message(message_type: &str) -> ParquetResult { + let mut parser = Parser { + tokenizer: &mut Tokenizer::from_str(message_type), + }; + parser.parse_message_type() +} + +/// Tokenizer to split message type string into tokens that are separated using characters +/// defined in `is_schema_delim` method. Tokenizer also preserves delimiters as tokens. +/// Tokenizer provides Iterator interface to process tokens; it also allows to step back +/// to reprocess previous tokens. +struct Tokenizer<'a> { + // List of all tokens for a string + tokens: Vec<&'a str>, + // Current index of vector + index: usize, +} + +impl<'a> Tokenizer<'a> { + // Create tokenizer from message type string + pub fn from_str(string: &'a str) -> Self { + let vec = string + .split_whitespace() + .flat_map(Self::split_token) + .collect(); + Tokenizer { + tokens: vec, + index: 0, + } + } + + // List of all special characters in schema + fn is_schema_delim(c: char) -> bool { + c == ';' || c == '{' || c == '}' || c == '(' || c == ')' || c == '=' || c == ',' + } + + /// Splits string into tokens; input string can already be token or can contain + /// delimiters, e.g. required" -> Vec("required") and + /// "(UTF8);" -> Vec("(", "UTF8", ")", ";") + fn split_token(string: &str) -> Vec<&str> { + let mut buffer: Vec<&str> = Vec::new(); + let mut tail = string; + while let Some(index) = tail.find(Self::is_schema_delim) { + let (h, t) = tail.split_at(index); + if !h.is_empty() { + buffer.push(h); + } + buffer.push(&t[0..1]); + tail = &t[1..]; + } + if !tail.is_empty() { + buffer.push(tail); + } + buffer + } + + // Move pointer to a previous element + fn backtrack(&mut self) { + self.index -= 1; + } +} + +impl<'a> Iterator for Tokenizer<'a> { + type Item = &'a str; + + fn next(&mut self) -> Option<&'a str> { + if self.index < self.tokens.len() { + self.index += 1; + Some(self.tokens[self.index - 1]) + } else { + None + } + } +} + +/// Internal Schema parser. +/// Traverses message type using tokenizer and parses each group/primitive type +/// recursively. +struct Parser<'a> { + tokenizer: &'a mut Tokenizer<'a>, +} + +// Utility function to assert token on validity. +fn assert_token(token: Option<&str>, expected: &str) -> ParquetResult<()> { + match token { + Some(value) if value == expected => Ok(()), + Some(other) => Err(ParquetError::oos(format!( + "Expected '{}', found token '{}'", + expected, other + ))), + None => Err(ParquetError::oos(format!( + "Expected '{}', but no token found (None)", + expected + ))), + } +} + +// Utility function to parse i32 or return general error. +fn parse_i32(value: Option<&str>, not_found_msg: &str, parse_fail_msg: &str) -> ParquetResult { + value + .ok_or_else(|| ParquetError::oos(not_found_msg)) + .and_then(|v| { + v.parse::() + .map_err(|_| ParquetError::oos(parse_fail_msg)) + }) +} + +// Utility function to parse boolean or return general error. +#[inline] +fn parse_bool( + value: Option<&str>, + not_found_msg: &str, + parse_fail_msg: &str, +) -> ParquetResult { + value + .ok_or_else(|| ParquetError::oos(not_found_msg)) + .and_then(|v| { + v.to_lowercase() + .parse::() + .map_err(|_| ParquetError::oos(parse_fail_msg)) + }) +} + +// Utility function to parse TimeUnit or return general error. +fn parse_timeunit( + value: Option<&str>, + not_found_msg: &str, + parse_fail_msg: &str, +) -> ParquetResult { + value + .ok_or_else(|| ParquetError::oos(not_found_msg)) + .and_then(|v| match v.to_uppercase().as_str() { + "MILLIS" => Ok(TimeUnit::Milliseconds), + "MICROS" => Ok(TimeUnit::Microseconds), + "NANOS" => Ok(TimeUnit::Nanoseconds), + _ => Err(ParquetError::oos(parse_fail_msg)), + }) +} + +impl Parser<'_> { + // Entry function to parse message type, uses internal tokenizer. + fn parse_message_type(&mut self) -> ParquetResult { + // Check that message type starts with "message". + match self.tokenizer.next() { + Some("message") => { + let name = self + .tokenizer + .next() + .ok_or_else(|| ParquetError::oos("Expected name, found None"))?; + let fields = self.parse_child_types()?; + Ok(ParquetType::new_root(PlSmallStr::from_str(name), fields)) + }, + _ => Err(ParquetError::oos( + "Message type does not start with 'message'", + )), + } + } + + // Parses child types for a current group type. + // This is only invoked on root and group types. + fn parse_child_types(&mut self) -> ParquetResult> { + assert_token(self.tokenizer.next(), "{")?; + let mut vec = Vec::new(); + while let Some(value) = self.tokenizer.next() { + if value == "}" { + break; + } else { + self.tokenizer.backtrack(); + vec.push(self.add_type()?); + } + } + Ok(vec) + } + + fn add_type(&mut self) -> ParquetResult { + // Parse repetition + let repetition = self + .tokenizer + .next() + .ok_or_else(|| ParquetError::oos("Expected repetition, found None")) + .and_then(|v| repetition_from_str(&v.to_uppercase()))?; + + match self.tokenizer.next() { + Some(group) if group.to_uppercase() == "GROUP" => self.add_group_type(repetition), + Some(type_string) => { + let physical_type = type_from_str(&type_string.to_uppercase())?; + self.add_primitive_type(repetition, physical_type) + }, + None => Err(ParquetError::oos( + "Invalid type, could not extract next token", + )), + } + } + + fn add_group_type(&mut self, repetition: Repetition) -> ParquetResult { + // Parse name of the group type + let name = self + .tokenizer + .next() + .ok_or_else(|| ParquetError::oos("Expected name, found None"))?; + + // Parse converted type if exists + let converted_type = if let Some("(") = self.tokenizer.next() { + let converted_type = self + .tokenizer + .next() + .ok_or_else(|| ParquetError::oos("Expected converted type, found None")) + .and_then(|v| converted_group_from_str(&v.to_uppercase()))?; + assert_token(self.tokenizer.next(), ")")?; + Some(converted_type) + } else { + self.tokenizer.backtrack(); + None + }; + + // Parse optional id + let id = if let Some("=") = self.tokenizer.next() { + self.tokenizer.next().and_then(|v| v.parse::().ok()) + } else { + self.tokenizer.backtrack(); + None + }; + + let fields = self.parse_child_types()?; + + Ok(ParquetType::from_converted( + PlSmallStr::from_str(name), + fields, + repetition, + converted_type, + id, + )) + } + + fn add_primitive_type( + &mut self, + repetition: Repetition, + physical_type: Type, + ) -> ParquetResult { + // Read type length if the type is FIXED_LEN_BYTE_ARRAY. + let length = if physical_type == Type::FIXED_LEN_BYTE_ARRAY { + assert_token(self.tokenizer.next(), "(")?; + let length = parse_i32( + self.tokenizer.next(), + "Expected length for FIXED_LEN_BYTE_ARRAY, found None", + "Failed to parse length for FIXED_LEN_BYTE_ARRAY", + )?; + assert_token(self.tokenizer.next(), ")")?; + Some(length) + } else { + None + }; + + // Parse name of the primitive type + let name = self + .tokenizer + .next() + .ok_or_else(|| ParquetError::oos("Expected name, found None"))?; + + // Parse logical types + let (converted_type, logical_type) = if let Some("(") = self.tokenizer.next() { + let (is_logical_type, converted_type, token) = self + .tokenizer + .next() + .ok_or_else(|| ParquetError::oos("Expected converted or logical type, found None")) + .and_then(|v| { + let string = v.to_uppercase(); + Ok(if is_logical_type(&string) { + (true, None, string) + } else if is_converted_type(&string) { + (false, converted_primitive_from_str(&string), string) + } else { + return Err(ParquetError::oos(format!( + "Expected converted or logical type, found {}", + string + ))); + }) + })?; + + let logical_type = if is_logical_type { + Some(self.parse_logical_type(&token)?) + } else { + None + }; + + // converted type decimal + let converted_type = match converted_type { + Some(PrimitiveConvertedType::Decimal(_, _)) => { + Some(self.parse_converted_decimal()?) + }, + other => other, + }; + + assert_token(self.tokenizer.next(), ")")?; + (converted_type, logical_type) + } else { + self.tokenizer.backtrack(); + (None, None) + }; + + // Parse optional id + let id = if let Some("=") = self.tokenizer.next() { + self.tokenizer.next().and_then(|v| v.parse::().ok()) + } else { + self.tokenizer.backtrack(); + None + }; + assert_token(self.tokenizer.next(), ";")?; + + ParquetType::try_from_primitive( + PlSmallStr::from_str(name), + (physical_type, length).try_into()?, + repetition, + converted_type, + logical_type, + id, + ) + } + + fn parse_converted_decimal(&mut self) -> ParquetResult { + assert_token(self.tokenizer.next(), "(")?; + // Parse precision + let precision = parse_i32( + self.tokenizer.next(), + "Expected precision, found None", + "Failed to parse precision for DECIMAL type", + )?; + + // Parse scale + let scale = if let Some(",") = self.tokenizer.next() { + parse_i32( + self.tokenizer.next(), + "Expected scale, found None", + "Failed to parse scale for DECIMAL type", + )? + } else { + // Scale is not provided, set it to 0. + self.tokenizer.backtrack(); + 0 + }; + + assert_token(self.tokenizer.next(), ")")?; + Ok(PrimitiveConvertedType::Decimal( + precision.try_into()?, + scale.try_into()?, + )) + } + + fn parse_logical_type(&mut self, tpe: &str) -> ParquetResult { + Ok(match tpe { + "ENUM" => PrimitiveLogicalType::Enum, + "DATE" => PrimitiveLogicalType::Date, + "DECIMAL" => { + let (precision, scale) = if let Some("(") = self.tokenizer.next() { + let precision = parse_i32( + self.tokenizer.next(), + "Expected precision, found None", + "Failed to parse precision for DECIMAL type", + )?; + let scale = if let Some(",") = self.tokenizer.next() { + parse_i32( + self.tokenizer.next(), + "Expected scale, found None", + "Failed to parse scale for DECIMAL type", + )? + } else { + self.tokenizer.backtrack(); + 0 + }; + assert_token(self.tokenizer.next(), ")")?; + (precision, scale) + } else { + self.tokenizer.backtrack(); + (0, 0) + }; + PrimitiveLogicalType::Decimal(precision.try_into()?, scale.try_into()?) + }, + "TIME" => { + let (unit, is_adjusted_to_utc) = if let Some("(") = self.tokenizer.next() { + let unit = parse_timeunit( + self.tokenizer.next(), + "Invalid timeunit found", + "Failed to parse timeunit for TIME type", + )?; + let is_adjusted_to_utc = if let Some(",") = self.tokenizer.next() { + parse_bool( + self.tokenizer.next(), + "Invalid boolean found", + "Failed to parse timezone info for TIME type", + )? + } else { + self.tokenizer.backtrack(); + false + }; + assert_token(self.tokenizer.next(), ")")?; + (unit, is_adjusted_to_utc) + } else { + self.tokenizer.backtrack(); + (TimeUnit::Milliseconds, false) + }; + PrimitiveLogicalType::Time { + is_adjusted_to_utc, + unit, + } + }, + "TIMESTAMP" => { + let (unit, is_adjusted_to_utc) = if let Some("(") = self.tokenizer.next() { + let unit = parse_timeunit( + self.tokenizer.next(), + "Invalid timeunit found", + "Failed to parse timeunit for TIMESTAMP type", + )?; + let is_adjusted_to_utc = if let Some(",") = self.tokenizer.next() { + parse_bool( + self.tokenizer.next(), + "Invalid boolean found", + "Failed to parse timezone info for TIMESTAMP type", + )? + } else { + // Invalid token for unit + self.tokenizer.backtrack(); + false + }; + assert_token(self.tokenizer.next(), ")")?; + (unit, is_adjusted_to_utc) + } else { + self.tokenizer.backtrack(); + (TimeUnit::Milliseconds, false) + }; + PrimitiveLogicalType::Timestamp { + is_adjusted_to_utc, + unit, + } + }, + "INTEGER" => { + let (bit_width, is_signed) = if let Some("(") = self.tokenizer.next() { + let bit_width = parse_i32( + self.tokenizer.next(), + "Invalid bit_width found", + "Failed to parse bit_width for INTEGER type", + )?; + let is_signed = if let Some(",") = self.tokenizer.next() { + parse_bool( + self.tokenizer.next(), + "Invalid boolean found", + "Failed to parse is_signed for INTEGER type", + )? + } else { + // Invalid token for unit + self.tokenizer.backtrack(); + return Err(ParquetError::oos("INTEGER requires sign")); + }; + assert_token(self.tokenizer.next(), ")")?; + (bit_width, is_signed) + } else { + // Invalid token for unit + self.tokenizer.backtrack(); + return Err(ParquetError::oos("INTEGER requires width and sign")); + }; + PrimitiveLogicalType::Integer((bit_width, is_signed).into()) + }, + "STRING" => PrimitiveLogicalType::String, + "JSON" => PrimitiveLogicalType::Json, + "BSON" => PrimitiveLogicalType::Bson, + "UUID" => PrimitiveLogicalType::Uuid, + "UNKNOWN" => PrimitiveLogicalType::Unknown, + "INTERVAL" => return Err(ParquetError::oos("Interval logical type not yet supported")), + _ => unreachable!(), + }) + } +} + +#[cfg(test)] +mod tests { + use types::IntegerType; + + use super::*; + use crate::parquet::schema::types::PhysicalType; + + #[test] + fn test_tokenize_empty_string() { + assert_eq!(Tokenizer::from_str("").next(), None); + } + + #[test] + fn test_tokenize_delimiters() { + let mut iter = Tokenizer::from_str(",;{}()="); + assert_eq!(iter.next(), Some(",")); + assert_eq!(iter.next(), Some(";")); + assert_eq!(iter.next(), Some("{")); + assert_eq!(iter.next(), Some("}")); + assert_eq!(iter.next(), Some("(")); + assert_eq!(iter.next(), Some(")")); + assert_eq!(iter.next(), Some("=")); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_tokenize_delimiters_with_whitespaces() { + let mut iter = Tokenizer::from_str(" , ; { } ( ) = "); + assert_eq!(iter.next(), Some(",")); + assert_eq!(iter.next(), Some(";")); + assert_eq!(iter.next(), Some("{")); + assert_eq!(iter.next(), Some("}")); + assert_eq!(iter.next(), Some("(")); + assert_eq!(iter.next(), Some(")")); + assert_eq!(iter.next(), Some("=")); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_tokenize_words() { + let mut iter = Tokenizer::from_str("abc def ghi jkl mno"); + assert_eq!(iter.next(), Some("abc")); + assert_eq!(iter.next(), Some("def")); + assert_eq!(iter.next(), Some("ghi")); + assert_eq!(iter.next(), Some("jkl")); + assert_eq!(iter.next(), Some("mno")); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_tokenize_backtrack() { + let mut iter = Tokenizer::from_str("abc;"); + assert_eq!(iter.next(), Some("abc")); + assert_eq!(iter.next(), Some(";")); + iter.backtrack(); + assert_eq!(iter.next(), Some(";")); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_tokenize_message_type() { + let schema = " + message schema { + required int32 a; + optional binary c (UTF8); + required group d { + required int32 a; + optional binary c (UTF8); + } + required group e (LIST) { + repeated group list { + required int32 element; + } + } + } + "; + let iter = Tokenizer::from_str(schema); + let mut res = Vec::new(); + for token in iter { + res.push(token); + } + assert_eq!( + res, + vec![ + "message", "schema", "{", "required", "int32", "a", ";", "optional", "binary", "c", + "(", "UTF8", ")", ";", "required", "group", "d", "{", "required", "int32", "a", + ";", "optional", "binary", "c", "(", "UTF8", ")", ";", "}", "required", "group", + "e", "(", "LIST", ")", "{", "repeated", "group", "list", "{", "required", "int32", + "element", ";", "}", "}", "}" + ] + ); + } + + #[test] + fn test_assert_token() { + assert!(assert_token(Some("a"), "a").is_ok()); + assert!(assert_token(Some("a"), "b").is_err()); + assert!(assert_token(None, "b").is_err()); + } + + #[test] + fn test_parse_message_type_invalid() { + let mut iter = Tokenizer::from_str("test"); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "File out of specification: Message type does not start with 'message'" + ); + } + + #[test] + fn test_parse_message_type_no_name() { + let mut iter = Tokenizer::from_str("message"); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "File out of specification: Expected name, found None" + ); + } + + #[test] + fn test_parse_message_type_fixed_byte_array() { + let schema = " + message schema { + REQUIRED FIXED_LEN_BYTE_ARRAY col; + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + + let schema = " + message schema { + REQUIRED FIXED_LEN_BYTE_ARRAY(16) col; + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_ok()); + } + + #[test] + fn test_parse_message_type_decimal() { + // It is okay for decimal to omit precision and scale with right syntax. + // Here we test wrong syntax of decimal type + + // Invalid decimal syntax + let schema = " + message root { + optional int32 f1 (DECIMAL(); + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + + // Invalid decimal, need precision and scale + let schema = " + message root { + optional int32 f1 (DECIMAL()); + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + + // Invalid decimal because of `,` - has precision, needs scale + let schema = " + message root { + optional int32 f1 (DECIMAL(8,)); + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + } + + #[test] + fn test_parse_decimal_wrong() { + // Invalid decimal because, we always require either precision or scale to be + // specified as part of converted type + let schema = " + message root { + optional int32 f3 (DECIMAL); + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_err()); + + // Valid decimal (precision, scale) + let schema = " + message root { + optional int32 f1 (DECIMAL(8, 3)); + optional int32 f2 (DECIMAL(8)); + } + "; + let mut iter = Tokenizer::from_str(schema); + let result = Parser { + tokenizer: &mut iter, + } + .parse_message_type(); + assert!(result.is_ok()); + } + + #[test] + fn test_parse_message_type_compare_1() -> ParquetResult<()> { + let schema = " + message root { + optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + optional fixed_len_byte_array (16) f2 (DECIMAL (38, 18)); + } + "; + let mut iter = Tokenizer::from_str(schema); + let message = Parser { + tokenizer: &mut iter, + } + .parse_message_type() + .unwrap(); + + let fields = vec![ + ParquetType::try_from_primitive( + PlSmallStr::from_static("f1"), + PhysicalType::FixedLenByteArray(5), + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Decimal(9, 3)), + None, + )?, + ParquetType::try_from_primitive( + PlSmallStr::from_static("f2"), + PhysicalType::FixedLenByteArray(16), + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Decimal(38, 18)), + None, + )?, + ]; + + let expected = ParquetType::new_root(PlSmallStr::from_static("root"), fields); + + assert_eq!(message, expected); + Ok(()) + } + + #[test] + fn test_parse_message_type_compare_2() -> ParquetResult<()> { + let schema = " + message root { + required group a0 { + optional group a1 (LIST) { + repeated binary a2 (UTF8); + } + + optional group b1 (LIST) { + repeated group b2 { + optional int32 b3; + optional double b4; + } + } + } + } + "; + let mut iter = Tokenizer::from_str(schema); + let message = Parser { + tokenizer: &mut iter, + } + .parse_message_type() + .unwrap(); + + let a2 = ParquetType::try_from_primitive( + "a2".into(), + PhysicalType::ByteArray, + Repetition::Repeated, + Some(PrimitiveConvertedType::Utf8), + None, + None, + )?; + let a1 = ParquetType::from_converted( + "a1".into(), + vec![a2], + Repetition::Optional, + Some(GroupConvertedType::List), + None, + ); + let b2 = ParquetType::from_converted( + "b2".into(), + vec![ + ParquetType::from_physical("b3".into(), PhysicalType::Int32), + ParquetType::from_physical("b4".into(), PhysicalType::Double), + ], + Repetition::Repeated, + None, + None, + ); + let b1 = ParquetType::from_converted( + "b1".into(), + vec![b2], + Repetition::Optional, + Some(GroupConvertedType::List), + None, + ); + let a0 = ParquetType::from_converted( + "a0".into(), + vec![a1, b1], + Repetition::Required, + None, + None, + ); + + let expected = ParquetType::new_root("root".into(), vec![a0]); + + assert_eq!(message, expected); + Ok(()) + } + + #[test] + fn test_parse_message_type_compare_3() -> ParquetResult<()> { + let schema = " + message root { + required int32 _1 (INT_8); + required int32 _2 (INT_16); + required float _3; + required double _4; + optional int32 _5 (DATE); + optional binary _6 (UTF8); + } + "; + let mut iter = Tokenizer::from_str(schema); + let message = Parser { + tokenizer: &mut iter, + } + .parse_message_type() + .unwrap(); + + let f1 = ParquetType::try_from_primitive( + "_1".into(), + PhysicalType::Int32, + Repetition::Required, + Some(PrimitiveConvertedType::Int8), + None, + None, + )?; + let f2 = ParquetType::try_from_primitive( + "_2".into(), + PhysicalType::Int32, + Repetition::Required, + Some(PrimitiveConvertedType::Int16), + None, + None, + )?; + let f3 = ParquetType::try_from_primitive( + "_3".into(), + PhysicalType::Float, + Repetition::Required, + None, + None, + None, + )?; + let f4 = ParquetType::try_from_primitive( + "_4".into(), + PhysicalType::Double, + Repetition::Required, + None, + None, + None, + )?; + let f5 = ParquetType::try_from_primitive( + "_5".into(), + PhysicalType::Int32, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Date), + None, + )?; + let f6 = ParquetType::try_from_primitive( + "_6".into(), + PhysicalType::ByteArray, + Repetition::Optional, + Some(PrimitiveConvertedType::Utf8), + None, + None, + )?; + + let fields = vec![f1, f2, f3, f4, f5, f6]; + + let expected = ParquetType::new_root("root".into(), fields); + assert_eq!(message, expected); + Ok(()) + } + + #[test] + fn test_parse_message_type_compare_4() -> ParquetResult<()> { + let schema = " + message root { + required int32 _1 (INTEGER(8,true)); + required int32 _2 (INTEGER(16,false)); + required float _3; + required double _4; + optional int32 _5 (DATE); + optional int32 _6 (TIME(MILLIS,false)); + optional int64 _7 (TIME(MICROS,true)); + optional int64 _8 (TIMESTAMP(MILLIS,true)); + optional int64 _9 (TIMESTAMP(NANOS,false)); + optional binary _10 (STRING); + } + "; + let mut iter = Tokenizer::from_str(schema); + let message = Parser { + tokenizer: &mut iter, + } + .parse_message_type()?; + + let f1 = ParquetType::try_from_primitive( + "_1".into(), + PhysicalType::Int32, + Repetition::Required, + None, + Some(PrimitiveLogicalType::Integer(IntegerType::Int8)), + None, + )?; + let f2 = ParquetType::try_from_primitive( + "_2".into(), + PhysicalType::Int32, + Repetition::Required, + None, + Some(PrimitiveLogicalType::Integer(IntegerType::UInt16)), + None, + )?; + let f3 = ParquetType::try_from_primitive( + "_3".into(), + PhysicalType::Float, + Repetition::Required, + None, + None, + None, + )?; + let f4 = ParquetType::try_from_primitive( + "_4".into(), + PhysicalType::Double, + Repetition::Required, + None, + None, + None, + )?; + let f5 = ParquetType::try_from_primitive( + "_5".into(), + PhysicalType::Int32, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Date), + None, + )?; + let f6 = ParquetType::try_from_primitive( + "_6".into(), + PhysicalType::Int32, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: TimeUnit::Milliseconds, + }), + None, + )?; + let f7 = ParquetType::try_from_primitive( + "_7".into(), + PhysicalType::Int64, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: true, + unit: TimeUnit::Microseconds, + }), + None, + )?; + let f8 = ParquetType::try_from_primitive( + "_8".into(), + PhysicalType::Int64, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Timestamp { + is_adjusted_to_utc: true, + unit: TimeUnit::Milliseconds, + }), + None, + )?; + let f9 = ParquetType::try_from_primitive( + "_9".into(), + PhysicalType::Int64, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::Timestamp { + is_adjusted_to_utc: false, + unit: TimeUnit::Nanoseconds, + }), + None, + )?; + + let f10 = ParquetType::try_from_primitive( + "_10".into(), + PhysicalType::ByteArray, + Repetition::Optional, + None, + Some(PrimitiveLogicalType::String), + None, + )?; + + let fields = vec![f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]; + + let expected = ParquetType::new_root("root".into(), fields); + assert_eq!(message, expected); + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/schema/io_message/mod.rs b/crates/polars-parquet/src/parquet/schema/io_message/mod.rs new file mode 100644 index 000000000000..1e296a7f3724 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/io_message/mod.rs @@ -0,0 +1,3 @@ +mod from_message; + +pub use from_message::from_message; diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs new file mode 100644 index 000000000000..7296e04f17dd --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs @@ -0,0 +1,147 @@ +use polars_parquet_format::SchemaElement; +use polars_utils::pl_str::PlSmallStr; + +use super::super::types::ParquetType; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::FieldInfo; + +impl ParquetType { + /// Method to convert from Thrift. + pub fn try_from_thrift(elements: &[SchemaElement]) -> ParquetResult { + let mut index = 0; + let mut schema_nodes = Vec::new(); + while index < elements.len() { + let t = from_thrift_helper(elements, index)?; + index = t.0; + schema_nodes.push(t.1); + } + if schema_nodes.len() != 1 { + return Err(ParquetError::oos(format!( + "Expected exactly one root node, but found {}", + schema_nodes.len() + ))); + } + + Ok(schema_nodes.remove(0)) + } +} + +/// Constructs a new Type from the `elements`, starting at index `index`. +/// The first result is the starting index for the next Type after this one. If it is +/// equal to `elements.len()`, then this Type is the last one. +/// The second result is the result Type. +fn from_thrift_helper( + elements: &[SchemaElement], + index: usize, +) -> ParquetResult<(usize, ParquetType)> { + // Whether or not the current node is root (message type). + // There is only one message type node in the schema tree. + let is_root_node = index == 0; + + let element = elements.get(index).ok_or_else(|| { + ParquetError::oos(format!("index {} on SchemaElement is not valid", index)) + })?; + let name = PlSmallStr::from_str(element.name.as_str()); + let converted_type = element.converted_type; + + let id = element.field_id; + match element.num_children { + // empty root + None | Some(0) if is_root_node => { + let fields = vec![]; + let tp = ParquetType::new_root(name, fields); + Ok((index + 1, tp)) + }, + + // From parquet-format: + // The children count is used to construct the nested relationship. + // This field is not set when the element is a primitive type + // Sometimes parquet-cpp sets num_children field to 0 for primitive types, so we + // have to handle this case too. + None | Some(0) => { + // primitive type + let repetition = element + .repetition_type + .ok_or_else(|| { + ParquetError::oos("Repetition level must be defined for a primitive type") + })? + .try_into()?; + let physical_type = element.type_.ok_or_else(|| { + ParquetError::oos("Physical type must be defined for a primitive type") + })?; + + let converted_type = converted_type + .map(|converted_type| { + let maybe_decimal = match (element.precision, element.scale) { + (Some(precision), Some(scale)) => Some((precision, scale)), + (None, None) => None, + _ => { + return Err(ParquetError::oos( + "When precision or scale are defined, both must be defined", + )); + }, + }; + (converted_type, maybe_decimal).try_into() + }) + .transpose()?; + + let logical_type = element + .logical_type + .clone() + .map(|x| x.try_into()) + .transpose()?; + + let tp = ParquetType::try_from_primitive( + name, + (physical_type, element.type_length).try_into()?, + repetition, + converted_type, + logical_type, + id, + )?; + + Ok((index + 1, tp)) + }, + Some(n) => { + let mut fields = vec![]; + let mut next_index = index + 1; + for _ in 0..n { + let child_result = from_thrift_helper(elements, next_index)?; + next_index = child_result.0; + fields.push(child_result.1); + } + + let tp = if is_root_node { + ParquetType::new_root(name, fields) + } else { + let repetition = if let Some(repetition) = element.repetition_type { + repetition.try_into()? + } else { + return Err(ParquetError::oos( + "The repetition level of a non-root must be non-null", + )); + }; + + let converted_type = converted_type.map(|x| x.try_into()).transpose()?; + + let logical_type = element + .logical_type + .clone() + .map(|x| x.try_into()) + .transpose()?; + + ParquetType::GroupType { + field_info: FieldInfo { + name, + repetition, + id, + }, + fields, + converted_type, + logical_type, + } + }; + Ok((next_index, tp)) + }, + } +} diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/mod.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/mod.rs new file mode 100644 index 000000000000..193a49cd8879 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/mod.rs @@ -0,0 +1,83 @@ +mod from_thrift; + +mod to_thrift; + +#[cfg(test)] +mod tests { + use crate::parquet::error::ParquetResult; + use crate::parquet::schema::io_message::from_message; + use crate::parquet::schema::types::ParquetType; + + fn test_round_trip(message: &str) -> ParquetResult<()> { + let expected_schema = from_message(message)?; + let thrift_schema = expected_schema.to_thrift(); + let thrift_schema = thrift_schema.into_iter().collect::>(); + let result_schema = ParquetType::try_from_thrift(&thrift_schema)?; + assert_eq!(result_schema, expected_schema); + Ok(()) + } + + #[test] + fn test_schema_type_thrift_conversion() { + let message_type = " + message conversions { + REQUIRED INT64 id; + OPTIONAL group int_array_Array (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL INT32 element; + } + } + } + } + OPTIONAL group int_map (MAP) { + REPEATED group map (MAP_KEY_VALUE) { + REQUIRED BYTE_ARRAY key (UTF8); + OPTIONAL INT32 value; + } + } + OPTIONAL group int_Map_Array (LIST) { + REPEATED group list { + OPTIONAL group g (MAP) { + REPEATED group map (MAP_KEY_VALUE) { + REQUIRED BYTE_ARRAY key (UTF8); + OPTIONAL group value { + OPTIONAL group H { + OPTIONAL group i (LIST) { + REPEATED group list { + OPTIONAL DOUBLE element; + } + } + } + } + } + } + } + } + OPTIONAL group nested_struct { + OPTIONAL INT32 A; + OPTIONAL group b (LIST) { + REPEATED group list { + REQUIRED FIXED_LEN_BYTE_ARRAY (16) element; + } + } + } + } + "; + test_round_trip(message_type).unwrap(); + } + + #[test] + fn test_schema_type_thrift_conversion_decimal() { + let message_type = " + message decimals { + OPTIONAL INT32 field0; + OPTIONAL INT64 field1 (DECIMAL (18, 2)); + OPTIONAL FIXED_LEN_BYTE_ARRAY (16) field2 (DECIMAL (38, 18)); + OPTIONAL BYTE_ARRAY field3 (DECIMAL (9)); + } + "; + test_round_trip(message_type).unwrap(); + } +} diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs new file mode 100644 index 000000000000..db372b733593 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs @@ -0,0 +1,82 @@ +use polars_parquet_format::{ConvertedType, SchemaElement}; + +use super::super::types::ParquetType; +use crate::parquet::schema::types::PrimitiveType; + +impl ParquetType { + /// Method to convert to Thrift. + pub(crate) fn to_thrift(&self) -> Vec { + let mut elements: Vec = Vec::new(); + to_thrift_helper(self, &mut elements, true); + elements + } +} + +/// Constructs list of `SchemaElement` from the schema using depth-first traversal. +/// Here we assume that schema is always valid and starts with group type. +fn to_thrift_helper(schema: &ParquetType, elements: &mut Vec, is_root: bool) { + match schema { + ParquetType::PrimitiveType(PrimitiveType { + field_info, + logical_type, + converted_type, + physical_type, + }) => { + let (type_, type_length) = (*physical_type).into(); + let (converted_type, maybe_decimal) = converted_type + .map(|x| x.into()) + .map(|x: (ConvertedType, Option<(i32, i32)>)| (Some(x.0), x.1)) + .unwrap_or((None, None)); + + let element = SchemaElement { + type_: Some(type_), + type_length, + repetition_type: Some(field_info.repetition.into()), + name: field_info.name.to_string(), + num_children: None, + converted_type, + precision: maybe_decimal.map(|x| x.0), + scale: maybe_decimal.map(|x| x.1), + field_id: field_info.id, + logical_type: logical_type.map(|x| x.into()), + }; + + elements.push(element); + }, + ParquetType::GroupType { + field_info, + fields, + logical_type, + converted_type, + } => { + let converted_type = converted_type.map(|x| x.into()); + + let repetition_type = if is_root { + // https://github.com/apache/parquet-format/blob/7f06e838cbd1b7dbd722ff2580b9c2525e37fc46/src/main/thrift/parquet.thrift#L363 + None + } else { + Some(field_info.repetition) + }; + + let element = SchemaElement { + type_: None, + type_length: None, + repetition_type: repetition_type.map(|x| x.into()), + name: field_info.name.to_string(), + num_children: Some(fields.len() as i32), + converted_type, + scale: None, + precision: None, + field_id: field_info.id, + logical_type: logical_type.map(|x| x.into()), + }; + + elements.push(element); + + // Add child elements for a group + for field in fields { + to_thrift_helper(field, elements, false); + } + }, + } +} diff --git a/crates/polars-parquet/src/parquet/schema/mod.rs b/crates/polars-parquet/src/parquet/schema/mod.rs new file mode 100644 index 000000000000..af1918afa7f9 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/mod.rs @@ -0,0 +1,7 @@ +pub use super::thrift_format::SchemaElement; +pub use crate::parquet::parquet_bridge::Repetition; + +pub mod io_message; +pub mod io_thrift; + +pub mod types; diff --git a/crates/polars-parquet/src/parquet/schema/types/basic_type.rs b/crates/polars-parquet/src/parquet/schema/types/basic_type.rs new file mode 100644 index 000000000000..e882f83516f5 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/types/basic_type.rs @@ -0,0 +1,17 @@ +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use super::super::Repetition; + +/// Common type information. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub struct FieldInfo { + /// The field name + pub name: PlSmallStr, + /// The repetition + pub repetition: Repetition, + /// the optional id, to select fields by id + pub id: Option, +} diff --git a/crates/polars-parquet/src/parquet/schema/types/converted_type.rs b/crates/polars-parquet/src/parquet/schema/types/converted_type.rs new file mode 100644 index 000000000000..99f0e1cac029 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/types/converted_type.rs @@ -0,0 +1,231 @@ +use polars_parquet_format::ConvertedType; +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use crate::parquet::error::ParquetError; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum PrimitiveConvertedType { + Utf8, + /// an enum is converted into a binary field + Enum, + /// A decimal value. + /// + /// This may be used to annotate binary or fixed primitive types. The underlying byte array + /// stores the unscaled value encoded as two's complement using big-endian byte order (the most + /// significant byte is the zeroth element). The value of the decimal is the value * + /// 10^{-scale}. + /// + /// This must be accompanied by a (maximum) precision and a scale in the SchemaElement. The + /// precision specifies the number of digits in the decimal and the scale stores the location + /// of the decimal point. For example 1.23 would have precision 3 (3 total digits) and scale 2 + /// (the decimal point is 2 digits over). + // (precision, scale) + Decimal(usize, usize), + /// A Date + /// + /// Stored as days since Unix epoch, encoded as the INT32 physical type. + /// + Date, + /// A time + /// + /// The total number of milliseconds since midnight. The value is stored as an INT32 physical + /// type. + TimeMillis, + /// A time. + /// + /// The total number of microseconds since midnight. The value is stored as an INT64 physical + /// type. + TimeMicros, + /// A date/time combination + /// + /// Date and time recorded as milliseconds since the Unix epoch. Recorded as a physical type + /// of INT64. + TimestampMillis, + /// A date/time combination + /// + /// Date and time recorded as microseconds since the Unix epoch. The value is stored as an + /// INT64 physical type. + TimestampMicros, + /// An unsigned integer value. + /// + /// The number describes the maximum number of meaningful data bits in the stored value. 8, 16 + /// and 32 bit values are stored using the INT32 physical type. 64 bit values are stored using + /// the INT64 physical type. + Uint8, + Uint16, + Uint32, + Uint64, + /// A signed integer value. + /// + /// The number describes the maximum number of meainful data bits in the stored value. 8, 16 + /// and 32 bit values are stored using the INT32 physical type. 64 bit values are stored using + /// the INT64 physical type. + /// + Int8, + Int16, + Int32, + Int64, + /// An embedded JSON document + /// + /// A JSON document embedded within a single UTF8 column. + Json, + /// An embedded BSON document + /// + /// A BSON document embedded within a single BINARY column. + Bson, + /// An interval of time + /// + /// This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12 This data is + /// composed of three separate little endian unsigned integers. Each stores a component of a + /// duration of time. The first integer identifies the number of months associated with the + /// duration, the second identifies the number of days associated with the duration and the + /// third identifies the number of milliseconds associated with the provided duration. This + /// duration of time is independent of any particular timezone or date. + Interval, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum GroupConvertedType { + /// a map is converted as an optional field containing a repeated key/value pair + Map, + /// a key/value pair is converted into a group of two fields + MapKeyValue, + /// a list is converted into an optional field containing a repeated field for its values + List, +} + +impl TryFrom<(ConvertedType, Option<(i32, i32)>)> for PrimitiveConvertedType { + type Error = ParquetError; + + fn try_from( + (ty, maybe_decimal): (ConvertedType, Option<(i32, i32)>), + ) -> Result { + use PrimitiveConvertedType::*; + Ok(match ty { + ConvertedType::UTF8 => Utf8, + ConvertedType::ENUM => Enum, + ConvertedType::DECIMAL => { + if let Some((precision, scale)) = maybe_decimal { + Decimal(precision.try_into()?, scale.try_into()?) + } else { + return Err(ParquetError::oos("Decimal requires a precision and scale")); + } + }, + ConvertedType::DATE => Date, + ConvertedType::TIME_MILLIS => TimeMillis, + ConvertedType::TIME_MICROS => TimeMicros, + ConvertedType::TIMESTAMP_MILLIS => TimestampMillis, + ConvertedType::TIMESTAMP_MICROS => TimestampMicros, + ConvertedType::UINT_8 => Uint8, + ConvertedType::UINT_16 => Uint16, + ConvertedType::UINT_32 => Uint32, + ConvertedType::UINT_64 => Uint64, + ConvertedType::INT_8 => Int8, + ConvertedType::INT_16 => Int16, + ConvertedType::INT_32 => Int32, + ConvertedType::INT_64 => Int64, + ConvertedType::JSON => Json, + ConvertedType::BSON => Bson, + ConvertedType::INTERVAL => Interval, + _ => { + return Err(ParquetError::oos(format!( + "Converted type \"{:?}\" cannot be applied to a primitive type", + ty + ))); + }, + }) + } +} + +impl TryFrom for GroupConvertedType { + type Error = ParquetError; + + fn try_from(type_: ConvertedType) -> Result { + Ok(match type_ { + ConvertedType::LIST => GroupConvertedType::List, + ConvertedType::MAP => GroupConvertedType::Map, + ConvertedType::MAP_KEY_VALUE => GroupConvertedType::MapKeyValue, + _ => return Err(ParquetError::oos("LogicalType value out of range")), + }) + } +} + +impl From for ConvertedType { + fn from(type_: GroupConvertedType) -> Self { + match type_ { + GroupConvertedType::Map => ConvertedType::MAP, + GroupConvertedType::List => ConvertedType::LIST, + GroupConvertedType::MapKeyValue => ConvertedType::MAP_KEY_VALUE, + } + } +} + +impl From for (ConvertedType, Option<(i32, i32)>) { + fn from(ty: PrimitiveConvertedType) -> Self { + use PrimitiveConvertedType::*; + match ty { + Utf8 => (ConvertedType::UTF8, None), + Enum => (ConvertedType::ENUM, None), + Decimal(precision, scale) => ( + ConvertedType::DECIMAL, + Some((precision as i32, scale as i32)), + ), + Date => (ConvertedType::DATE, None), + TimeMillis => (ConvertedType::TIME_MILLIS, None), + TimeMicros => (ConvertedType::TIME_MICROS, None), + TimestampMillis => (ConvertedType::TIMESTAMP_MILLIS, None), + TimestampMicros => (ConvertedType::TIMESTAMP_MICROS, None), + Uint8 => (ConvertedType::UINT_8, None), + Uint16 => (ConvertedType::UINT_16, None), + Uint32 => (ConvertedType::UINT_32, None), + Uint64 => (ConvertedType::UINT_64, None), + Int8 => (ConvertedType::INT_8, None), + Int16 => (ConvertedType::INT_16, None), + Int32 => (ConvertedType::INT_32, None), + Int64 => (ConvertedType::INT_64, None), + Json => (ConvertedType::JSON, None), + Bson => (ConvertedType::BSON, None), + Interval => (ConvertedType::INTERVAL, None), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_trip() -> Result<(), ParquetError> { + use PrimitiveConvertedType::*; + let a = vec![ + Utf8, + Enum, + Decimal(3, 1), + Date, + TimeMillis, + TimeMicros, + TimestampMillis, + TimestampMicros, + Uint8, + Uint16, + Uint32, + Uint64, + Int8, + Int16, + Int32, + Int64, + Json, + Bson, + Interval, + ]; + for a in a { + let (c, d): (ConvertedType, Option<(i32, i32)>) = a.into(); + let e: PrimitiveConvertedType = (c, d).try_into()?; + assert_eq!(e, a); + } + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/schema/types/mod.rs b/crates/polars-parquet/src/parquet/schema/types/mod.rs new file mode 100644 index 000000000000..0516d75069bb --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/types/mod.rs @@ -0,0 +1,17 @@ +mod spec; + +mod physical_type; +pub use physical_type::*; + +mod basic_type; +pub use basic_type::*; + +mod converted_type; +pub use converted_type::*; + +mod parquet_type; +pub use parquet_type::*; + +pub use crate::parquet::parquet_bridge::{ + GroupLogicalType, IntegerType, PrimitiveLogicalType, TimeUnit, +}; diff --git a/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs b/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs new file mode 100644 index 000000000000..42d01d2f2210 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs @@ -0,0 +1,213 @@ +// see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md +use polars_utils::aliases::*; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use super::super::Repetition; +use super::{ + FieldInfo, GroupConvertedType, GroupLogicalType, PhysicalType, PrimitiveConvertedType, + PrimitiveLogicalType, spec, +}; +use crate::parquet::error::ParquetResult; + +/// The complete description of a parquet column +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub struct PrimitiveType { + /// The fields' generic information + pub field_info: FieldInfo, + /// The optional logical type + pub logical_type: Option, + /// The optional converted type + pub converted_type: Option, + /// The physical type + pub physical_type: PhysicalType, +} + +impl PrimitiveType { + /// Helper method to create an optional field with no logical or converted types. + pub fn from_physical(name: PlSmallStr, physical_type: PhysicalType) -> Self { + let field_info = FieldInfo { + name, + repetition: Repetition::Optional, + id: None, + }; + Self { + field_info, + converted_type: None, + logical_type: None, + physical_type, + } + } +} + +/// Representation of a Parquet type describing primitive and nested fields, +/// including the top-level schema of the parquet file. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum ParquetType { + PrimitiveType(PrimitiveType), + GroupType { + field_info: FieldInfo, + logical_type: Option, + converted_type: Option, + fields: Vec, + }, +} + +/// Accessors +impl ParquetType { + /// Returns [`FieldInfo`] information about the type. + pub fn get_field_info(&self) -> &FieldInfo { + match self { + Self::PrimitiveType(primitive) => &primitive.field_info, + Self::GroupType { field_info, .. } => field_info, + } + } + + /// Returns this type's field name. + pub fn name(&self) -> &str { + &self.get_field_info().name + } + + /// Checks if `sub_type` schema is part of current schema. + /// This method can be used to check if projected columns are part of the root schema. + pub fn check_contains(&self, sub_type: &ParquetType) -> bool { + let basic_match = self.get_field_info() == sub_type.get_field_info(); + + match (self, sub_type) { + ( + Self::PrimitiveType(PrimitiveType { physical_type, .. }), + Self::PrimitiveType(PrimitiveType { + physical_type: other_physical_type, + .. + }), + ) => basic_match && physical_type == other_physical_type, + ( + Self::GroupType { fields, .. }, + Self::GroupType { + fields: other_fields, + .. + }, + ) => { + // build hashmap of name -> Type + let mut field_map = PlHashMap::new(); + for field in fields { + field_map.insert(field.name(), field); + } + + for field in other_fields { + if !field_map + .get(field.name()) + .map(|tpe| tpe.check_contains(field)) + .unwrap_or(false) + { + return false; + } + } + true + }, + _ => false, + } + } +} + +/// Constructors +impl ParquetType { + pub(crate) fn new_root(name: PlSmallStr, fields: Vec) -> Self { + let field_info = FieldInfo { + name, + repetition: Repetition::Optional, + id: None, + }; + ParquetType::GroupType { + field_info, + fields, + logical_type: None, + converted_type: None, + } + } + + pub fn from_converted( + name: PlSmallStr, + fields: Vec, + repetition: Repetition, + converted_type: Option, + id: Option, + ) -> Self { + let field_info = FieldInfo { + name, + repetition, + id, + }; + + ParquetType::GroupType { + field_info, + fields, + converted_type, + logical_type: None, + } + } + + /// # Error + /// Errors iff the combination of physical, logical and converted type is not valid. + pub fn try_from_primitive( + name: PlSmallStr, + physical_type: PhysicalType, + repetition: Repetition, + converted_type: Option, + logical_type: Option, + id: Option, + ) -> ParquetResult { + // LogicalType has replaced the ConvertedType and there are certain LogicalType's that do + // not have a good counterpart in ConvertedType (e.g. Timestamp::Nanos). Therefore, we only + // check the ConvertedType if no LogicalType is given. This would signify a lot older + // Parquet file which could not have these new unsupported ConvertedTypes. + match logical_type { + None => spec::check_converted_invariants(&physical_type, &converted_type)?, + Some(logical_type) => spec::check_logical_invariants(&physical_type, logical_type)?, + } + + let field_info = FieldInfo { + name, + repetition, + id, + }; + + Ok(ParquetType::PrimitiveType(PrimitiveType { + field_info, + converted_type, + logical_type, + physical_type, + })) + } + + /// Helper method to create a [`ParquetType::PrimitiveType`] optional field + /// with no logical or converted types. + pub fn from_physical(name: PlSmallStr, physical_type: PhysicalType) -> Self { + ParquetType::PrimitiveType(PrimitiveType::from_physical(name, physical_type)) + } + + pub fn from_group( + name: PlSmallStr, + repetition: Repetition, + converted_type: Option, + logical_type: Option, + fields: Vec, + id: Option, + ) -> Self { + let field_info = FieldInfo { + name, + repetition, + id, + }; + + ParquetType::GroupType { + field_info, + logical_type, + converted_type, + fields, + } + } +} diff --git a/crates/polars-parquet/src/parquet/schema/types/physical_type.rs b/crates/polars-parquet/src/parquet/schema/types/physical_type.rs new file mode 100644 index 000000000000..ed7242adac71 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/types/physical_type.rs @@ -0,0 +1,59 @@ +use polars_parquet_format::Type; +#[cfg(feature = "serde_types")] +use serde::{Deserialize, Serialize}; + +use crate::parquet::error::ParquetError; + +/// The set of all physical types representable in Parquet +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] +pub enum PhysicalType { + Boolean, + Int32, + Int64, + Int96, + Float, + Double, + ByteArray, + FixedLenByteArray(usize), +} + +impl TryFrom<(Type, Option)> for PhysicalType { + type Error = ParquetError; + + fn try_from((type_, length): (Type, Option)) -> Result { + Ok(match type_ { + Type::BOOLEAN => PhysicalType::Boolean, + Type::INT32 => PhysicalType::Int32, + Type::INT64 => PhysicalType::Int64, + Type::INT96 => PhysicalType::Int96, + Type::FLOAT => PhysicalType::Float, + Type::DOUBLE => PhysicalType::Double, + Type::BYTE_ARRAY => PhysicalType::ByteArray, + Type::FIXED_LEN_BYTE_ARRAY => { + let length = length.ok_or_else(|| { + ParquetError::oos("Length must be defined for FixedLenByteArray") + })?; + PhysicalType::FixedLenByteArray(length.try_into()?) + }, + _ => return Err(ParquetError::oos("Unknown type")), + }) + } +} + +impl From for (Type, Option) { + fn from(physical_type: PhysicalType) -> Self { + match physical_type { + PhysicalType::Boolean => (Type::BOOLEAN, None), + PhysicalType::Int32 => (Type::INT32, None), + PhysicalType::Int64 => (Type::INT64, None), + PhysicalType::Int96 => (Type::INT96, None), + PhysicalType::Float => (Type::FLOAT, None), + PhysicalType::Double => (Type::DOUBLE, None), + PhysicalType::ByteArray => (Type::BYTE_ARRAY, None), + PhysicalType::FixedLenByteArray(length) => { + (Type::FIXED_LEN_BYTE_ARRAY, Some(length as i32)) + }, + } + } +} diff --git a/crates/polars-parquet/src/parquet/schema/types/spec.rs b/crates/polars-parquet/src/parquet/schema/types/spec.rs new file mode 100644 index 000000000000..a7403f513273 --- /dev/null +++ b/crates/polars-parquet/src/parquet/schema/types/spec.rs @@ -0,0 +1,177 @@ +// see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md +use super::{IntegerType, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, TimeUnit}; +use crate::parquet::error::{ParquetError, ParquetResult}; + +fn check_decimal_invariants( + physical_type: &PhysicalType, + precision: usize, + scale: usize, +) -> ParquetResult<()> { + if precision < 1 { + return Err(ParquetError::oos(format!( + "DECIMAL precision must be larger than 0; It is {}", + precision, + ))); + } + if scale > precision { + return Err(ParquetError::oos(format!( + "Invalid DECIMAL: scale ({}) cannot be greater than precision \ + ({})", + scale, precision + ))); + } + + match physical_type { + PhysicalType::Int32 => { + if !(1..=9).contains(&precision) { + return Err(ParquetError::oos(format!( + "Cannot represent INT32 as DECIMAL with precision {}", + precision + ))); + } + }, + PhysicalType::Int64 => { + if !(1..=18).contains(&precision) { + return Err(ParquetError::oos(format!( + "Cannot represent INT64 as DECIMAL with precision {}", + precision + ))); + } + }, + PhysicalType::FixedLenByteArray(length) => { + let oos_error = + || ParquetError::oos(format!("Byte Array length {} out of spec", length)); + let max_precision = (2f64.powi( + (*length as i32) + .checked_mul(8) + .ok_or_else(oos_error)? + .checked_sub(1) + .ok_or_else(oos_error)?, + ) - 1f64) + .log10() + .floor() as usize; + + if precision > max_precision { + return Err(ParquetError::oos(format!( + "Cannot represent FIXED_LEN_BYTE_ARRAY as DECIMAL with length {} and \ + precision {}. The max precision can only be {}", + length, precision, max_precision + ))); + } + }, + PhysicalType::ByteArray => {}, + _ => { + return Err(ParquetError::oos( + "DECIMAL can only annotate INT32, INT64, BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY", + )); + }, + }; + Ok(()) +} + +pub fn check_converted_invariants( + physical_type: &PhysicalType, + converted_type: &Option, +) -> ParquetResult<()> { + if converted_type.is_none() { + return Ok(()); + }; + let converted_type = converted_type.as_ref().unwrap(); + + use PrimitiveConvertedType::*; + match converted_type { + Utf8 | Bson | Json => { + if physical_type != &PhysicalType::ByteArray { + return Err(ParquetError::oos(format!( + "{:?} can only annotate BYTE_ARRAY fields", + converted_type + ))); + } + }, + Decimal(precision, scale) => { + check_decimal_invariants(physical_type, *precision, *scale)?; + }, + Date | TimeMillis | Uint8 | Uint16 | Uint32 | Int8 | Int16 | Int32 => { + if physical_type != &PhysicalType::Int32 { + return Err(ParquetError::oos(format!( + "{:?} can only annotate INT32", + converted_type + ))); + } + }, + TimeMicros | TimestampMillis | TimestampMicros | Uint64 | Int64 => { + if physical_type != &PhysicalType::Int64 { + return Err(ParquetError::oos(format!( + "{:?} can only annotate INT64", + converted_type + ))); + } + }, + Interval => { + if physical_type != &PhysicalType::FixedLenByteArray(12) { + return Err(ParquetError::oos( + "INTERVAL can only annotate FIXED_LEN_BYTE_ARRAY(12)", + )); + } + }, + Enum => { + if physical_type != &PhysicalType::ByteArray { + return Err(ParquetError::oos( + "ENUM can only annotate BYTE_ARRAY fields", + )); + } + }, + }; + Ok(()) +} + +pub fn check_logical_invariants( + physical_type: &PhysicalType, + logical_type: PrimitiveLogicalType, +) -> ParquetResult<()> { + // Check that logical type and physical type are compatible + use PrimitiveLogicalType::*; + match (logical_type, physical_type) { + (Enum, PhysicalType::ByteArray) => {}, + (Decimal(precision, scale), _) => { + check_decimal_invariants(physical_type, precision, scale)?; + }, + (Date, PhysicalType::Int32) => {}, + ( + Time { + unit: TimeUnit::Milliseconds, + .. + }, + PhysicalType::Int32, + ) => {}, + (Time { unit, .. }, PhysicalType::Int64) => { + if unit == TimeUnit::Milliseconds { + return Err(ParquetError::oos( + "Cannot use millisecond unit on INT64 type", + )); + } + }, + (Timestamp { .. }, PhysicalType::Int64) => {}, + (Integer(IntegerType::Int8), PhysicalType::Int32) => {}, + (Integer(IntegerType::Int16), PhysicalType::Int32) => {}, + (Integer(IntegerType::Int32), PhysicalType::Int32) => {}, + (Integer(IntegerType::UInt8), PhysicalType::Int32) => {}, + (Integer(IntegerType::UInt16), PhysicalType::Int32) => {}, + (Integer(IntegerType::UInt32), PhysicalType::Int32) => {}, + (Integer(IntegerType::UInt64), PhysicalType::Int64) => {}, + (Integer(IntegerType::Int64), PhysicalType::Int64) => {}, + // Null type + (Unknown, PhysicalType::Int32) => {}, + (String | Json | Bson, PhysicalType::ByteArray) => {}, + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#uuid + (Uuid, PhysicalType::FixedLenByteArray(16)) => {}, + (Float16, PhysicalType::FixedLenByteArray(2)) => {}, + (a, b) => { + return Err(ParquetError::oos(format!( + "Cannot annotate {:?} from {:?} fields", + a, b + ))); + }, + }; + Ok(()) +} diff --git a/crates/polars-parquet/src/parquet/statistics/binary.rs b/crates/polars-parquet/src/parquet/statistics/binary.rs new file mode 100644 index 000000000000..7f1dabf21ec8 --- /dev/null +++ b/crates/polars-parquet/src/parquet/statistics/binary.rs @@ -0,0 +1,41 @@ +use polars_parquet_format::Statistics as ParquetStatistics; + +use crate::parquet::error::ParquetResult; +use crate::parquet::schema::types::PrimitiveType; + +#[derive(Debug, Clone, PartialEq)] +pub struct BinaryStatistics { + pub primitive_type: PrimitiveType, + pub null_count: Option, + pub distinct_count: Option, + pub max_value: Option>, + pub min_value: Option>, +} + +impl BinaryStatistics { + pub fn deserialize( + v: &ParquetStatistics, + primitive_type: PrimitiveType, + ) -> ParquetResult { + Ok(BinaryStatistics { + primitive_type, + null_count: v.null_count, + distinct_count: v.distinct_count, + max_value: v.max_value.clone(), + min_value: v.min_value.clone(), + }) + } + + pub fn serialize(&self) -> ParquetStatistics { + ParquetStatistics { + null_count: self.null_count, + distinct_count: self.distinct_count, + max_value: self.max_value.clone(), + min_value: self.min_value.clone(), + max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, + } + } +} diff --git a/crates/polars-parquet/src/parquet/statistics/boolean.rs b/crates/polars-parquet/src/parquet/statistics/boolean.rs new file mode 100644 index 000000000000..55e478d5b957 --- /dev/null +++ b/crates/polars-parquet/src/parquet/statistics/boolean.rs @@ -0,0 +1,58 @@ +use polars_parquet_format::Statistics as ParquetStatistics; + +use crate::parquet::error::{ParquetError, ParquetResult}; + +#[derive(Debug, Clone, PartialEq)] +pub struct BooleanStatistics { + pub null_count: Option, + pub distinct_count: Option, + pub max_value: Option, + pub min_value: Option, +} + +impl BooleanStatistics { + pub fn deserialize(v: &ParquetStatistics) -> ParquetResult { + if let Some(ref v) = v.max_value { + if v.len() != size_of::() { + return Err(ParquetError::oos( + "The max_value of statistics MUST be plain encoded", + )); + } + }; + if let Some(ref v) = v.min_value { + if v.len() != size_of::() { + return Err(ParquetError::oos( + "The min_value of statistics MUST be plain encoded", + )); + } + }; + + Ok(Self { + null_count: v.null_count, + distinct_count: v.distinct_count, + max_value: v + .max_value + .as_ref() + .and_then(|x| x.first()) + .map(|x| *x != 0), + min_value: v + .min_value + .as_ref() + .and_then(|x| x.first()) + .map(|x| *x != 0), + }) + } + + pub fn serialize(&self) -> ParquetStatistics { + ParquetStatistics { + null_count: self.null_count, + distinct_count: self.distinct_count, + max_value: self.max_value.map(|x| vec![x as u8]), + min_value: self.min_value.map(|x| vec![x as u8]), + max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, + } + } +} diff --git a/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs b/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs new file mode 100644 index 000000000000..87642246907d --- /dev/null +++ b/crates/polars-parquet/src/parquet/statistics/fixed_len_binary.rs @@ -0,0 +1,63 @@ +use polars_parquet_format::Statistics as ParquetStatistics; + +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::PrimitiveType; + +#[derive(Debug, Clone, PartialEq)] +pub struct FixedLenStatistics { + pub primitive_type: PrimitiveType, + pub null_count: Option, + pub distinct_count: Option, + pub max_value: Option>, + pub min_value: Option>, +} + +impl FixedLenStatistics { + pub fn deserialize( + v: &ParquetStatistics, + size: usize, + primitive_type: PrimitiveType, + ) -> ParquetResult { + if let Some(ref v) = v.max_value { + if v.len() != size { + return Err(ParquetError::oos( + "The max_value of statistics MUST be plain encoded", + )); + } + }; + if let Some(ref v) = v.min_value { + if v.len() != size { + return Err(ParquetError::oos( + "The min_value of statistics MUST be plain encoded", + )); + } + }; + + Ok(Self { + primitive_type, + null_count: v.null_count, + distinct_count: v.distinct_count, + max_value: v.max_value.clone().map(|mut x| { + x.truncate(size); + x + }), + min_value: v.min_value.clone().map(|mut x| { + x.truncate(size); + x + }), + }) + } + + pub fn serialize(&self) -> ParquetStatistics { + ParquetStatistics { + null_count: self.null_count, + distinct_count: self.distinct_count, + max_value: self.max_value.clone(), + min_value: self.min_value.clone(), + max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, + } + } +} diff --git a/crates/polars-parquet/src/parquet/statistics/mod.rs b/crates/polars-parquet/src/parquet/statistics/mod.rs new file mode 100644 index 000000000000..1f2b4b85a82f --- /dev/null +++ b/crates/polars-parquet/src/parquet/statistics/mod.rs @@ -0,0 +1,213 @@ +mod binary; +mod boolean; +mod fixed_len_binary; +mod primitive; + +pub use binary::BinaryStatistics; +pub use boolean::BooleanStatistics; +pub use fixed_len_binary::FixedLenStatistics; +pub use primitive::PrimitiveStatistics; + +use crate::parquet::error::ParquetResult; +use crate::parquet::schema::types::{PhysicalType, PrimitiveType}; +pub use crate::parquet::thrift_format::Statistics as ParquetStatistics; + +#[derive(Debug, PartialEq)] +pub enum Statistics { + Binary(BinaryStatistics), + Boolean(BooleanStatistics), + FixedLen(FixedLenStatistics), + Int32(PrimitiveStatistics), + Int64(PrimitiveStatistics), + Int96(PrimitiveStatistics<[u32; 3]>), + Float(PrimitiveStatistics), + Double(PrimitiveStatistics), +} + +impl Statistics { + #[inline] + pub const fn physical_type(&self) -> &PhysicalType { + use Statistics as S; + + match self { + S::Binary(_) => &PhysicalType::ByteArray, + S::Boolean(_) => &PhysicalType::Boolean, + S::FixedLen(s) => &s.primitive_type.physical_type, + S::Int32(_) => &PhysicalType::Int32, + S::Int64(_) => &PhysicalType::Int64, + S::Int96(_) => &PhysicalType::Int96, + S::Float(_) => &PhysicalType::Float, + S::Double(_) => &PhysicalType::Double, + } + } + + pub fn clear_min(&mut self) { + use Statistics as S; + match self { + S::Binary(s) => _ = s.min_value.take(), + S::Boolean(s) => _ = s.min_value.take(), + S::FixedLen(s) => _ = s.min_value.take(), + S::Int32(s) => _ = s.min_value.take(), + S::Int64(s) => _ = s.min_value.take(), + S::Int96(s) => _ = s.min_value.take(), + S::Float(s) => _ = s.min_value.take(), + S::Double(s) => _ = s.min_value.take(), + }; + } + + pub fn clear_max(&mut self) { + use Statistics as S; + match self { + S::Binary(s) => _ = s.max_value.take(), + S::Boolean(s) => _ = s.max_value.take(), + S::FixedLen(s) => _ = s.max_value.take(), + S::Int32(s) => _ = s.max_value.take(), + S::Int64(s) => _ = s.max_value.take(), + S::Int96(s) => _ = s.max_value.take(), + S::Float(s) => _ = s.max_value.take(), + S::Double(s) => _ = s.max_value.take(), + }; + } + + /// Deserializes a raw parquet statistics into [`Statistics`]. + /// # Error + /// This function errors if it is not possible to read the statistics to the + /// corresponding `physical_type`. + #[inline] + pub fn deserialize( + statistics: &ParquetStatistics, + primitive_type: PrimitiveType, + ) -> ParquetResult { + use {PhysicalType as T, PrimitiveStatistics as PrimStat}; + let mut stats: Self = match primitive_type.physical_type { + T::ByteArray => BinaryStatistics::deserialize(statistics, primitive_type)?.into(), + T::Boolean => BooleanStatistics::deserialize(statistics)?.into(), + T::Int32 => PrimStat::::deserialize(statistics, primitive_type)?.into(), + T::Int64 => PrimStat::::deserialize(statistics, primitive_type)?.into(), + T::Int96 => PrimStat::<[u32; 3]>::deserialize(statistics, primitive_type)?.into(), + T::Float => PrimStat::::deserialize(statistics, primitive_type)?.into(), + T::Double => PrimStat::::deserialize(statistics, primitive_type)?.into(), + T::FixedLenByteArray(size) => { + FixedLenStatistics::deserialize(statistics, size, primitive_type)?.into() + }, + }; + + if statistics.is_min_value_exact.is_some_and(|v| !v) { + stats.clear_min(); + } + if statistics.is_max_value_exact.is_some_and(|v| !v) { + stats.clear_max(); + } + + // Parquet Format: + // > - If the min is a NaN, it should be ignored. + // > - If the max is a NaN, it should be ignored. + match &mut stats { + Statistics::Float(stats) => { + stats.min_value.take_if(|v| v.is_nan()); + stats.max_value.take_if(|v| v.is_nan()); + }, + Statistics::Double(stats) => { + stats.min_value.take_if(|v| v.is_nan()); + stats.max_value.take_if(|v| v.is_nan()); + }, + _ => {}, + } + + Ok(stats) + } +} + +macro_rules! statistics_from_as { + ($($variant:ident($struct:ty) => ($as_ident:ident, $into_ident:ident, $expect_ident:ident, $owned_expect_ident:ident),)+) => { + $( + impl From<$struct> for Statistics { + #[inline] + fn from(stats: $struct) -> Self { + Self::$variant(stats) + } + } + )+ + + impl Statistics { + #[inline] + pub const fn null_count(&self) -> Option { + match self { + $(Self::$variant(s) => s.null_count,)+ + } + } + + /// Serializes [`Statistics`] into a raw parquet statistics. + #[inline] + pub fn serialize(&self) -> ParquetStatistics { + match self { + $(Self::$variant(s) => s.serialize(),)+ + } + } + + const fn variant_str(&self) -> &'static str { + match self { + $(Self::$variant(_) => stringify!($struct),)+ + } + } + + $( + #[doc = concat!("Try to take [`Statistics`] as [`", stringify!($struct), "`]")] + #[inline] + pub fn $as_ident(&self) -> Option<&$struct> { + match self { + Self::$variant(s) => Some(s), + _ => None, + } + } + + #[doc = concat!("Try to take [`Statistics`] as [`", stringify!($struct), "`]")] + #[inline] + pub fn $into_ident(self) -> Option<$struct> { + match self { + Self::$variant(s) => Some(s), + _ => None, + } + } + + #[doc = concat!("Interpret [`Statistics`] to be [`", stringify!($struct), "`]")] + /// + /// Panics if it is not the correct variant. + #[track_caller] + #[inline] + pub fn $expect_ident(&self) -> &$struct { + let Self::$variant(s) = self else { + panic!("Expected Statistics to be {}, found {} instead", stringify!($struct), self.variant_str()); + }; + + s + } + + #[doc = concat!("Interpret [`Statistics`] to be [`", stringify!($struct), "`]")] + /// + /// Panics if it is not the correct variant. + #[track_caller] + #[inline] + pub fn $owned_expect_ident(self) -> $struct { + let Self::$variant(s) = self else { + panic!("Expected Statistics to be {}, found {} instead", stringify!($struct), self.variant_str()); + }; + + s + } + )+ + + } + }; +} + +statistics_from_as! { + Binary (BinaryStatistics ) => (as_binary, into_binary, expect_as_binary, expect_binary ), + Boolean (BooleanStatistics ) => (as_boolean, into_boolean, expect_as_boolean, expect_boolean ), + FixedLen (FixedLenStatistics ) => (as_fixedlen, into_fixedlen, expect_as_fixedlen, expect_fixedlen), + Int32 (PrimitiveStatistics ) => (as_int32, into_int32, expect_as_int32, expect_int32 ), + Int64 (PrimitiveStatistics ) => (as_int64, into_int64, expect_as_int64, expect_int64 ), + Int96 (PrimitiveStatistics<[u32; 3]>) => (as_int96, into_int96, expect_as_int96, expect_int96 ), + Float (PrimitiveStatistics ) => (as_float, into_float, expect_as_float, expect_float ), + Double (PrimitiveStatistics ) => (as_double, into_double, expect_as_double, expect_double ), +} diff --git a/crates/polars-parquet/src/parquet/statistics/primitive.rs b/crates/polars-parquet/src/parquet/statistics/primitive.rs new file mode 100644 index 000000000000..7bd8d227faa2 --- /dev/null +++ b/crates/polars-parquet/src/parquet/statistics/primitive.rs @@ -0,0 +1,59 @@ +use polars_parquet_format::Statistics as ParquetStatistics; + +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::PrimitiveType; +use crate::parquet::types; + +#[derive(Debug, Clone, PartialEq)] +pub struct PrimitiveStatistics { + pub primitive_type: PrimitiveType, + pub null_count: Option, + pub distinct_count: Option, + pub min_value: Option, + pub max_value: Option, +} + +impl PrimitiveStatistics { + pub fn deserialize( + v: &ParquetStatistics, + primitive_type: PrimitiveType, + ) -> ParquetResult { + if v.max_value + .as_ref() + .is_some_and(|v| v.len() != size_of::()) + { + return Err(ParquetError::oos( + "The max_value of statistics MUST be plain encoded", + )); + }; + if v.min_value + .as_ref() + .is_some_and(|v| v.len() != size_of::()) + { + return Err(ParquetError::oos( + "The min_value of statistics MUST be plain encoded", + )); + }; + + Ok(Self { + primitive_type, + null_count: v.null_count, + distinct_count: v.distinct_count, + max_value: v.max_value.as_ref().map(|x| types::decode(x)), + min_value: v.min_value.as_ref().map(|x| types::decode(x)), + }) + } + + pub fn serialize(&self) -> ParquetStatistics { + ParquetStatistics { + null_count: self.null_count, + distinct_count: self.distinct_count, + max_value: self.max_value.map(|x| x.to_le_bytes().as_ref().to_vec()), + min_value: self.min_value.map(|x| x.to_le_bytes().as_ref().to_vec()), + max: None, + min: None, + is_max_value_exact: None, + is_min_value_exact: None, + } + } +} diff --git a/crates/polars-parquet/src/parquet/types.rs b/crates/polars-parquet/src/parquet/types.rs new file mode 100644 index 000000000000..f0b179d0fea6 --- /dev/null +++ b/crates/polars-parquet/src/parquet/types.rs @@ -0,0 +1,164 @@ +use arrow::types::{ + AlignedBytes, AlignedBytesCast, Bytes4Alignment4, Bytes8Alignment8, Bytes12Alignment4, +}; + +use crate::parquet::schema::types::PhysicalType; + +/// A physical native representation of a Parquet fixed-sized type. +pub trait NativeType: + std::fmt::Debug + Send + Sync + 'static + Copy + Clone + AlignedBytesCast +{ + type Bytes: AsRef<[u8]> + + bytemuck::Pod + + IntoIterator + + for<'a> TryFrom<&'a [u8], Error = std::array::TryFromSliceError> + + std::fmt::Debug + + Clone + + Copy; + type AlignedBytes: AlignedBytes + From + Into; + + fn to_le_bytes(&self) -> Self::Bytes; + + fn from_le_bytes(bytes: Self::Bytes) -> Self; + + fn ord(&self, other: &Self) -> std::cmp::Ordering; + + const TYPE: PhysicalType; +} + +macro_rules! native { + ($type:ty, $unaligned:ty, $physical_type:expr) => { + impl NativeType for $type { + type Bytes = [u8; size_of::()]; + type AlignedBytes = $unaligned; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self::from_le_bytes(bytes) + } + + #[inline] + fn ord(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal) + } + + const TYPE: PhysicalType = $physical_type; + } + }; +} + +native!(i32, Bytes4Alignment4, PhysicalType::Int32); +native!(i64, Bytes8Alignment8, PhysicalType::Int64); +native!(f32, Bytes4Alignment4, PhysicalType::Float); +native!(f64, Bytes8Alignment8, PhysicalType::Double); + +impl NativeType for [u32; 3] { + const TYPE: PhysicalType = PhysicalType::Int96; + + type Bytes = [u8; size_of::()]; + type AlignedBytes = Bytes12Alignment4; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let mut bytes = [0; 12]; + let first = self[0].to_le_bytes(); + bytes[0] = first[0]; + bytes[1] = first[1]; + bytes[2] = first[2]; + bytes[3] = first[3]; + let second = self[1].to_le_bytes(); + bytes[4] = second[0]; + bytes[5] = second[1]; + bytes[6] = second[2]; + bytes[7] = second[3]; + let third = self[2].to_le_bytes(); + bytes[8] = third[0]; + bytes[9] = third[1]; + bytes[10] = third[2]; + bytes[11] = third[3]; + bytes + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut first = [0; 4]; + first[0] = bytes[0]; + first[1] = bytes[1]; + first[2] = bytes[2]; + first[3] = bytes[3]; + let mut second = [0; 4]; + second[0] = bytes[4]; + second[1] = bytes[5]; + second[2] = bytes[6]; + second[3] = bytes[7]; + let mut third = [0; 4]; + third[0] = bytes[8]; + third[1] = bytes[9]; + third[2] = bytes[10]; + third[3] = bytes[11]; + [ + u32::from_le_bytes(first), + u32::from_le_bytes(second), + u32::from_le_bytes(third), + ] + } + + #[inline] + fn ord(&self, other: &Self) -> std::cmp::Ordering { + int96_to_i64_ns(*self).ord(&int96_to_i64_ns(*other)) + } +} + +#[inline] +pub fn int96_to_i64_ns(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const NANOS_PER_SECOND: i64 = 1_000_000_000; + + let day = value[2] as i64; + let nanoseconds = ((value[1] as i64) << 32) + value[0] as i64; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * NANOS_PER_SECOND + nanoseconds +} + +/// Returns the ordering of two binary values. +pub fn ord_binary<'a>(a: &'a [u8], b: &'a [u8]) -> std::cmp::Ordering { + use std::cmp::Ordering::*; + match (a.is_empty(), b.is_empty()) { + (true, true) => return Equal, + (true, false) => return Less, + (false, true) => return Greater, + (false, false) => {}, + } + + for (v1, v2) in a.iter().zip(b.iter()) { + match v1.cmp(v2) { + Equal => continue, + other => return other, + } + } + Equal +} + +#[inline] +pub fn decode(chunk: &[u8]) -> T { + assert!(chunk.len() >= size_of::<::Bytes>()); + unsafe { decode_unchecked(chunk) } +} + +/// Convert a Little-Endian byte-slice into the `T` +/// +/// # Safety +/// +/// This is safe if the length is properly checked. +#[inline] +pub unsafe fn decode_unchecked(chunk: &[u8]) -> T { + let chunk: ::Bytes = unsafe { chunk.try_into().unwrap_unchecked() }; + T::from_le_bytes(chunk) +} diff --git a/crates/polars-parquet/src/parquet/write/column_chunk.rs b/crates/polars-parquet/src/parquet/write/column_chunk.rs new file mode 100644 index 000000000000..c44ce6bfb404 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/column_chunk.rs @@ -0,0 +1,213 @@ +use std::io::Write; + +#[cfg(feature = "async")] +use futures::AsyncWrite; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; +#[cfg(feature = "async")] +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::{ColumnChunk, ColumnMetaData, Type}; +use polars_utils::aliases::PlHashSet; + +use super::DynStreamingIterator; +#[cfg(feature = "async")] +use super::page::write_page_async; +use super::page::{PageWriteSpec, write_page}; +use super::statistics::reduce; +use crate::parquet::FallibleStreamingIterator; +use crate::parquet::compression::Compression; +use crate::parquet::encoding::Encoding; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::ColumnDescriptor; +use crate::parquet::page::{CompressedPage, PageType}; + +pub fn write_column_chunk( + writer: &mut W, + mut offset: u64, + descriptor: &ColumnDescriptor, + mut compressed_pages: DynStreamingIterator<'_, CompressedPage, E>, +) -> ParquetResult<(ColumnChunk, Vec, u64)> +where + W: Write, + ParquetError: From, + E: std::error::Error, +{ + // write every page + + let initial = offset; + + let mut specs = vec![]; + while let Some(compressed_page) = compressed_pages.next()? { + let spec = write_page(writer, offset, compressed_page)?; + offset += spec.bytes_written; + specs.push(spec); + } + let mut bytes_written = offset - initial; + + let column_chunk = build_column_chunk(&specs, descriptor)?; + + // write metadata + let mut protocol = TCompactOutputProtocol::new(writer); + bytes_written += column_chunk + .meta_data + .as_ref() + .unwrap() + .write_to_out_protocol(&mut protocol)? as u64; + + Ok((column_chunk, specs, bytes_written)) +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub async fn write_column_chunk_async( + writer: &mut W, + mut offset: u64, + descriptor: &ColumnDescriptor, + mut compressed_pages: DynStreamingIterator<'_, CompressedPage, E>, +) -> ParquetResult<(ColumnChunk, Vec, u64)> +where + W: AsyncWrite + Unpin + Send, + ParquetError: From, + E: std::error::Error, +{ + let initial = offset; + // write every page + let mut specs = vec![]; + while let Some(compressed_page) = compressed_pages.next()? { + let spec = write_page_async(writer, offset, compressed_page).await?; + offset += spec.bytes_written; + specs.push(spec); + } + let mut bytes_written = offset - initial; + + let column_chunk = build_column_chunk(&specs, descriptor)?; + + // write metadata + let mut protocol = TCompactOutputStreamProtocol::new(writer); + bytes_written += column_chunk + .meta_data + .as_ref() + .unwrap() + .write_to_out_stream_protocol(&mut protocol) + .await? as u64; + + Ok((column_chunk, specs, bytes_written)) +} + +fn build_column_chunk( + specs: &[PageWriteSpec], + descriptor: &ColumnDescriptor, +) -> ParquetResult { + // compute stats to build header at the end of the chunk + + let compression = specs + .iter() + .map(|spec| spec.compression) + .collect::>(); + if compression.len() > 1 { + return Err(crate::parquet::error::ParquetError::oos( + "All pages within a column chunk must be compressed with the same codec", + )); + } + let compression = compression + .into_iter() + .next() + .unwrap_or(Compression::Uncompressed); + + // SPEC: the total compressed size is the total compressed size of each page + the header size + let total_compressed_size = specs + .iter() + .map(|x| x.header_size as i64 + x.header.compressed_page_size as i64) + .sum(); + // SPEC: the total compressed size is the total compressed size of each page + the header size + let total_uncompressed_size = specs + .iter() + .map(|x| x.header_size as i64 + x.header.uncompressed_page_size as i64) + .sum(); + let data_page_offset = specs.first().map(|spec| spec.offset).unwrap_or(0) as i64; + let num_values = specs + .iter() + .map(|spec| { + let type_ = spec.header.type_.try_into().unwrap(); + match type_ { + PageType::DataPage => { + spec.header.data_page_header.as_ref().unwrap().num_values as i64 + }, + PageType::DataPageV2 => { + spec.header.data_page_header_v2.as_ref().unwrap().num_values as i64 + }, + _ => 0, // only data pages contribute + } + }) + .sum(); + let mut encodings = specs + .iter() + .flat_map(|spec| { + let type_ = spec.header.type_.try_into().unwrap(); + match type_ { + PageType::DataPage => vec![ + spec.header.data_page_header.as_ref().unwrap().encoding, + Encoding::Rle.into(), + ], + PageType::DataPageV2 => { + vec![ + spec.header.data_page_header_v2.as_ref().unwrap().encoding, + Encoding::Rle.into(), + ] + }, + PageType::DictionaryPage => vec![ + spec.header + .dictionary_page_header + .as_ref() + .unwrap() + .encoding, + ], + } + }) + .collect::>() // unique + .into_iter() // to vec + .collect::>(); + + // Sort the encodings to have deterministic metadata + encodings.sort(); + + let statistics = specs.iter().map(|x| &x.statistics).collect::>(); + let statistics = reduce(&statistics)?; + let statistics = statistics.map(|x| x.serialize()); + + let (type_, _): (Type, Option) = descriptor.descriptor.primitive_type.physical_type.into(); + + let metadata = ColumnMetaData { + type_, + encodings, + path_in_schema: descriptor + .path_in_schema + .iter() + .map(|x| x.to_string()) + .collect::>(), + codec: compression.into(), + num_values, + total_uncompressed_size, + total_compressed_size, + key_value_metadata: None, + data_page_offset, + index_page_offset: None, + dictionary_page_offset: None, + statistics, + encoding_stats: None, + bloom_filter_offset: None, + bloom_filter_length: None, + size_statistics: None, + }; + + Ok(ColumnChunk { + file_path: None, // same file for now. + file_offset: data_page_offset + total_compressed_size, + meta_data: Some(metadata), + offset_index_offset: None, + offset_index_length: None, + column_index_offset: None, + column_index_length: None, + crypto_metadata: None, + encrypted_column_metadata: None, + }) +} diff --git a/crates/polars-parquet/src/parquet/write/compression.rs b/crates/polars-parquet/src/parquet/write/compression.rs new file mode 100644 index 000000000000..be23cd752556 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/compression.rs @@ -0,0 +1,186 @@ +use crate::parquet::compression::CompressionOptions; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::page::{ + CompressedDataPage, CompressedDictPage, CompressedPage, DataPage, DataPageHeader, DictPage, + Page, +}; +use crate::parquet::{CowBuffer, FallibleStreamingIterator, compression}; + +/// Compresses a [`DataPage`] into a [`CompressedDataPage`]. +fn compress_data( + page: DataPage, + mut compressed_buffer: Vec, + compression: CompressionOptions, +) -> ParquetResult { + let DataPage { + mut buffer, + header, + descriptor, + num_rows, + } = page; + let uncompressed_page_size = buffer.len(); + let num_rows = num_rows.expect("We should have num_rows when we are writing"); + if compression != CompressionOptions::Uncompressed { + match &header { + DataPageHeader::V1(_) => { + compression::compress(compression, &buffer, &mut compressed_buffer)?; + }, + DataPageHeader::V2(header) => { + let levels_byte_length = (header.repetition_levels_byte_length + + header.definition_levels_byte_length) + as usize; + compressed_buffer.extend_from_slice(&buffer[..levels_byte_length]); + compression::compress( + compression, + &buffer[levels_byte_length..], + &mut compressed_buffer, + )?; + }, + }; + } else { + std::mem::swap(buffer.to_mut(), &mut compressed_buffer); + } + + Ok(CompressedDataPage::new( + header, + CowBuffer::Owned(compressed_buffer), + compression.into(), + uncompressed_page_size, + descriptor, + num_rows, + )) +} + +fn compress_dict( + page: DictPage, + mut compressed_buffer: Vec, + compression: CompressionOptions, +) -> ParquetResult { + let DictPage { + buffer, + num_values, + is_sorted, + } = page; + + let uncompressed_page_size = buffer.len(); + let compressed_buffer = if compression != CompressionOptions::Uncompressed { + compression::compress(compression, &buffer, &mut compressed_buffer)?; + CowBuffer::Owned(compressed_buffer) + } else { + buffer + }; + + Ok(CompressedDictPage::new( + compressed_buffer, + compression.into(), + uncompressed_page_size, + num_values, + is_sorted, + )) +} + +/// Compresses an [`EncodedPage`] into a [`CompressedPage`] using `compressed_buffer` as the +/// intermediary buffer. +/// +/// `compressed_buffer` is taken by value because it becomes owned by [`CompressedPage`] +/// +/// # Errors +/// Errors if the compressor fails +pub fn compress( + page: Page, + compressed_buffer: Vec, + compression: CompressionOptions, +) -> ParquetResult { + match page { + Page::Data(page) => { + compress_data(page, compressed_buffer, compression).map(CompressedPage::Data) + }, + Page::Dict(page) => { + compress_dict(page, compressed_buffer, compression).map(CompressedPage::Dict) + }, + } +} + +/// A [`FallibleStreamingIterator`] that consumes [`Page`] and yields [`CompressedPage`] +/// holding a reusable buffer ([`Vec`]) for compression. +pub struct Compressor>> { + iter: I, + compression: CompressionOptions, + buffer: Vec, + current: Option, +} + +impl>> Compressor { + /// Creates a new [`Compressor`] + pub fn new(iter: I, compression: CompressionOptions, buffer: Vec) -> Self { + Self { + iter, + compression, + buffer, + current: None, + } + } + + /// Creates a new [`Compressor`] (same as `new`) + pub fn new_from_vec(iter: I, compression: CompressionOptions, buffer: Vec) -> Self { + Self::new(iter, compression, buffer) + } + + /// Deconstructs itself into its iterator and scratch buffer. + pub fn into_inner(mut self) -> (I, Vec) { + let mut buffer = if let Some(page) = self.current.as_mut() { + std::mem::take(page.buffer_mut()) + } else { + std::mem::take(&mut self.buffer) + }; + buffer.clear(); + (self.iter, buffer) + } +} + +impl>> FallibleStreamingIterator for Compressor { + type Item = CompressedPage; + type Error = ParquetError; + + fn advance(&mut self) -> std::result::Result<(), Self::Error> { + let mut compressed_buffer = if let Some(page) = self.current.as_mut() { + std::mem::take(page.buffer_mut()) + } else { + std::mem::take(&mut self.buffer) + }; + compressed_buffer.clear(); + + let next = self + .iter + .next() + .map(|x| x.and_then(|page| compress(page, compressed_buffer, self.compression))) + .transpose()?; + self.current = next; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + self.current.as_ref() + } +} + +impl>> Iterator for Compressor { + type Item = ParquetResult; + + fn next(&mut self) -> Option { + let mut compressed_buffer = if let Some(page) = self.current.as_mut() { + std::mem::take(page.buffer_mut()) + } else { + std::mem::take(&mut self.buffer) + }; + compressed_buffer.clear(); + + let page = self.iter.next()?; + let page = match page { + Ok(page) => page, + Err(err) => return Some(Err(err)), + }; + + Some(compress(page, compressed_buffer, self.compression)) + } +} diff --git a/crates/polars-parquet/src/parquet/write/dyn_iter.rs b/crates/polars-parquet/src/parquet/write/dyn_iter.rs new file mode 100644 index 000000000000..a232c06375e8 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/dyn_iter.rs @@ -0,0 +1,65 @@ +use crate::parquet::FallibleStreamingIterator; + +/// [`DynIter`] is an implementation of a single-threaded, dynamically-typed iterator. +/// +/// This implementation is object safe. +pub struct DynIter<'a, V> { + iter: Box + 'a + Send + Sync>, +} + +impl Iterator for DynIter<'_, V> { + type Item = V; + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +impl<'a, V> DynIter<'a, V> { + /// Returns a new [`DynIter`], boxing the incoming iterator + pub fn new(iter: I) -> Self + where + I: Iterator + 'a + Send + Sync, + { + Self { + iter: Box::new(iter), + } + } +} + +/// Dynamically-typed [`FallibleStreamingIterator`]. +pub struct DynStreamingIterator<'a, V, E> { + iter: Box + 'a + Send + Sync>, +} + +impl FallibleStreamingIterator for DynStreamingIterator<'_, V, E> { + type Item = V; + type Error = E; + + fn advance(&mut self) -> Result<(), Self::Error> { + self.iter.advance() + } + + fn get(&self) -> Option<&Self::Item> { + self.iter.get() + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +impl<'a, V, E> DynStreamingIterator<'a, V, E> { + /// Returns a new [`DynStreamingIterator`], boxing the incoming iterator + pub fn new(iter: I) -> Self + where + I: FallibleStreamingIterator + 'a + Send + Sync, + { + Self { + iter: Box::new(iter), + } + } +} diff --git a/crates/polars-parquet/src/parquet/write/file.rs b/crates/polars-parquet/src/parquet/write/file.rs new file mode 100644 index 000000000000..c869f097fdb5 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/file.rs @@ -0,0 +1,259 @@ +use std::io::Write; + +use polars_parquet_format::RowGroup; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; + +use super::indexes::{write_column_index, write_offset_index}; +use super::page::PageWriteSpec; +use super::row_group::write_row_group; +use super::{RowGroupIterColumns, WriteOptions}; +use crate::parquet::error::{ParquetError, ParquetResult}; +pub use crate::parquet::metadata::KeyValue; +use crate::parquet::metadata::{SchemaDescriptor, ThriftFileMetadata}; +use crate::parquet::write::State; +use crate::parquet::{FOOTER_SIZE, PARQUET_MAGIC}; + +pub(super) fn start_file(writer: &mut W) -> ParquetResult { + writer.write_all(&PARQUET_MAGIC)?; + Ok(PARQUET_MAGIC.len() as u64) +} + +pub(super) fn end_file( + mut writer: &mut W, + metadata: &ThriftFileMetadata, +) -> ParquetResult { + // Write metadata + let mut protocol = TCompactOutputProtocol::new(&mut writer); + let metadata_len = metadata.write_to_out_protocol(&mut protocol)? as i32; + + // Write footer + let metadata_bytes = metadata_len.to_le_bytes(); + let mut footer_buffer = [0u8; FOOTER_SIZE as usize]; + (0..4).for_each(|i| { + footer_buffer[i] = metadata_bytes[i]; + }); + + (&mut footer_buffer[4..]).write_all(&PARQUET_MAGIC)?; + writer.write_all(&footer_buffer)?; + writer.flush()?; + Ok(metadata_len as u64 + FOOTER_SIZE) +} + +fn create_column_orders(schema_desc: &SchemaDescriptor) -> Vec { + // We only include ColumnOrder for leaf nodes. + // Currently only supported ColumnOrder is TypeDefinedOrder so we set this + // for all leaf nodes. + // Even if the column has an undefined sort order, such as INTERVAL, this + // is still technically the defined TYPEORDER so it should still be set. + (0..schema_desc.columns().len()) + .map(|_| { + polars_parquet_format::ColumnOrder::TYPEORDER( + polars_parquet_format::TypeDefinedOrder {}, + ) + }) + .collect() +} + +/// An interface to write a parquet file. +/// Use `start` to write the header, `write` to write a row group, +/// and `end` to write the footer. +pub struct FileWriter { + writer: W, + schema: SchemaDescriptor, + options: WriteOptions, + created_by: Option, + + offset: u64, + row_groups: Vec, + page_specs: Vec>>, + /// Used to store the current state for writing the file + state: State, + // when the file is written, metadata becomes available + metadata: Option, +} + +/// Writes a parquet file containing only the header and footer +/// +/// This is used to write the metadata as a separate Parquet file, usually when data +/// is partitioned across multiple files. +/// +/// Note: Recall that when combining row groups from [`ThriftFileMetadata`], the `file_path` on each +/// of their column chunks must be updated with their path relative to where they are written to. +pub fn write_metadata_sidecar( + writer: &mut W, + metadata: &ThriftFileMetadata, +) -> ParquetResult { + let mut len = start_file(writer)?; + len += end_file(writer, metadata)?; + Ok(len) +} + +// Accessors +impl FileWriter { + /// The options assigned to the file + pub fn options(&self) -> &WriteOptions { + &self.options + } + + /// The [`SchemaDescriptor`] assigned to this file + pub fn schema(&self) -> &SchemaDescriptor { + &self.schema + } + + /// Returns the [`ThriftFileMetadata`]. This is Some iff the [`Self::end`] has been called. + /// + /// This is used to write the metadata as a separate Parquet file, usually when data + /// is partitioned across multiple files + pub fn metadata(&self) -> Option<&ThriftFileMetadata> { + self.metadata.as_ref() + } +} + +impl FileWriter { + /// Returns a new [`FileWriter`]. + pub fn new( + writer: W, + schema: SchemaDescriptor, + options: WriteOptions, + created_by: Option, + ) -> Self { + Self { + writer, + schema, + options, + created_by, + offset: 0, + row_groups: vec![], + page_specs: vec![], + state: State::Initialised, + metadata: None, + } + } + + /// Writes the header of the file. + /// + /// This is automatically called by [`Self::write`] if not called following [`Self::new`]. + /// + /// # Errors + /// Returns an error if data has been written to the file. + fn start(&mut self) -> ParquetResult<()> { + if self.offset == 0 { + self.offset = start_file(&mut self.writer)?; + self.state = State::Started; + Ok(()) + } else { + Err(ParquetError::InvalidParameter( + "Start cannot be called twice".to_string(), + )) + } + } + + /// Writes a row group to the file. + /// + /// This call is IO-bounded + pub fn write(&mut self, row_group: RowGroupIterColumns<'_, E>) -> ParquetResult<()> + where + ParquetError: From, + E: std::error::Error, + { + if self.offset == 0 { + self.start()?; + } + let ordinal = self.row_groups.len(); + let (group, specs, size) = write_row_group( + &mut self.writer, + self.offset, + self.schema.columns(), + row_group, + ordinal, + )?; + self.offset += size; + self.row_groups.push(group); + self.page_specs.push(specs); + Ok(()) + } + + /// Writes the footer of the parquet file. Returns the total size of the file and the + /// underlying writer. + pub fn end(&mut self, key_value_metadata: Option>) -> ParquetResult { + if self.offset == 0 { + self.start()?; + } + + if self.state != State::Started { + return Err(ParquetError::InvalidParameter( + "End cannot be called twice".to_string(), + )); + } + // compute file stats + let num_rows = self.row_groups.iter().map(|group| group.num_rows).sum(); + + if self.options.write_statistics { + // write column indexes (require page statistics) + self.row_groups + .iter_mut() + .zip(self.page_specs.iter()) + .try_for_each(|(group, pages)| { + group.columns.iter_mut().zip(pages.iter()).try_for_each( + |(column, pages)| { + let offset = self.offset; + column.column_index_offset = Some(offset as i64); + self.offset += write_column_index(&mut self.writer, pages)?; + let length = self.offset - offset; + column.column_index_length = Some(length as i32); + ParquetResult::Ok(()) + }, + )?; + ParquetResult::Ok(()) + })?; + }; + + // write offset index + self.row_groups + .iter_mut() + .zip(self.page_specs.iter()) + .try_for_each(|(group, pages)| { + group + .columns + .iter_mut() + .zip(pages.iter()) + .try_for_each(|(column, pages)| { + let offset = self.offset; + column.offset_index_offset = Some(offset as i64); + self.offset += write_offset_index(&mut self.writer, pages)?; + column.offset_index_length = Some((self.offset - offset) as i32); + ParquetResult::Ok(()) + })?; + ParquetResult::Ok(()) + })?; + + let metadata = ThriftFileMetadata::new( + self.options.version.into(), + self.schema.clone().into_thrift(), + num_rows, + self.row_groups.clone(), + key_value_metadata, + self.created_by.clone(), + Some(create_column_orders(&self.schema)), + None, + None, + ); + + let len = end_file(&mut self.writer, &metadata)?; + self.state = State::Finished; + self.metadata = Some(metadata); + Ok(self.offset + len) + } + + /// Returns the underlying writer. + pub fn into_inner(self) -> W { + self.writer + } + + /// Returns the underlying writer and [`ThriftFileMetadata`] + /// # Panics + /// This function panics if [`Self::end`] has not yet been called + pub fn into_inner_and_metadata(self) -> (W, ThriftFileMetadata) { + (self.writer, self.metadata.expect("File to have ended")) + } +} diff --git a/crates/polars-parquet/src/parquet/write/indexes/mod.rs b/crates/polars-parquet/src/parquet/write/indexes/mod.rs new file mode 100644 index 000000000000..9f413a15d26a --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/indexes/mod.rs @@ -0,0 +1,4 @@ +mod serialize; +mod write; + +pub use write::*; diff --git a/crates/polars-parquet/src/parquet/write/indexes/serialize.rs b/crates/polars-parquet/src/parquet/write/indexes/serialize.rs new file mode 100644 index 000000000000..b5f902e28732 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/indexes/serialize.rs @@ -0,0 +1,77 @@ +use polars_parquet_format::{BoundaryOrder, ColumnIndex, OffsetIndex, PageLocation}; + +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::write::page::{PageWriteSpec, is_data_page}; + +pub fn serialize_column_index(pages: &[PageWriteSpec]) -> ParquetResult { + let mut null_pages = Vec::with_capacity(pages.len()); + let mut min_values = Vec::with_capacity(pages.len()); + let mut max_values = Vec::with_capacity(pages.len()); + let mut null_counts = Vec::with_capacity(pages.len()); + + pages + .iter() + .filter(|x| is_data_page(x)) + .try_for_each(|spec| { + if let Some(stats) = &spec.statistics { + let stats = stats.serialize(); + + let null_count = stats + .null_count + .ok_or_else(|| ParquetError::oos("null count of a page is required"))?; + null_counts.push(null_count); + + if let Some(min_value) = stats.min_value { + min_values.push(min_value); + max_values.push( + stats + .max_value + .ok_or_else(|| ParquetError::oos("max value of a page is required"))?, + ); + null_pages.push(false) + } else { + min_values.push(vec![0]); + max_values.push(vec![0]); + null_pages.push(true) + } + + ParquetResult::Ok(()) + } else { + Err(ParquetError::oos( + "options were set to write statistics but some pages miss them", + )) + } + })?; + Ok(ColumnIndex { + null_pages, + min_values, + max_values, + boundary_order: BoundaryOrder::UNORDERED, + null_counts: Some(null_counts), + repetition_level_histograms: None, + definition_level_histograms: None, + }) +} + +pub fn serialize_offset_index(pages: &[PageWriteSpec]) -> ParquetResult { + let mut first_row_index = 0; + let page_locations = pages + .iter() + .filter(|x| is_data_page(x)) + .map(|spec| { + let location = PageLocation { + offset: spec.offset.try_into()?, + compressed_page_size: spec.bytes_written.try_into()?, + first_row_index, + }; + let num_rows = spec.num_rows; + first_row_index += num_rows as i64; + Ok(location) + }) + .collect::>>()?; + + Ok(OffsetIndex { + page_locations, + unencoded_byte_array_data_bytes: None, + }) +} diff --git a/crates/polars-parquet/src/parquet/write/indexes/write.rs b/crates/polars-parquet/src/parquet/write/indexes/write.rs new file mode 100644 index 000000000000..73325654e518 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/indexes/write.rs @@ -0,0 +1,45 @@ +use std::io::Write; + +#[cfg(feature = "async")] +use futures::AsyncWrite; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; +#[cfg(feature = "async")] +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; + +use super::serialize::{serialize_column_index, serialize_offset_index}; +use crate::parquet::error::ParquetResult; +use crate::parquet::write::page::PageWriteSpec; + +pub fn write_column_index(writer: &mut W, pages: &[PageWriteSpec]) -> ParquetResult { + let index = serialize_column_index(pages)?; + let mut protocol = TCompactOutputProtocol::new(writer); + Ok(index.write_to_out_protocol(&mut protocol)? as u64) +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub async fn write_column_index_async( + writer: &mut W, + pages: &[PageWriteSpec], +) -> ParquetResult { + let index = serialize_column_index(pages)?; + let mut protocol = TCompactOutputStreamProtocol::new(writer); + Ok(index.write_to_out_stream_protocol(&mut protocol).await? as u64) +} + +pub fn write_offset_index(writer: &mut W, pages: &[PageWriteSpec]) -> ParquetResult { + let index = serialize_offset_index(pages)?; + let mut protocol = TCompactOutputProtocol::new(&mut *writer); + Ok(index.write_to_out_protocol(&mut protocol)? as u64) +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub async fn write_offset_index_async( + writer: &mut W, + pages: &[PageWriteSpec], +) -> ParquetResult { + let index = serialize_offset_index(pages)?; + let mut protocol = TCompactOutputStreamProtocol::new(&mut *writer); + Ok(index.write_to_out_stream_protocol(&mut protocol).await? as u64) +} diff --git a/crates/polars-parquet/src/parquet/write/mod.rs b/crates/polars-parquet/src/parquet/write/mod.rs new file mode 100644 index 000000000000..7c124987d0b5 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/mod.rs @@ -0,0 +1,59 @@ +mod column_chunk; +mod compression; +mod file; +mod indexes; +pub(crate) mod page; +mod row_group; +mod statistics; + +#[cfg(feature = "async")] +mod stream; +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub use stream::FileStreamer; + +mod dyn_iter; +pub use compression::{Compressor, compress}; +pub use dyn_iter::{DynIter, DynStreamingIterator}; +pub use file::{FileWriter, write_metadata_sidecar}; +pub use row_group::ColumnOffsetsMetadata; + +use crate::parquet::page::CompressedPage; + +pub type RowGroupIterColumns<'a, E> = + DynIter<'a, Result, E>>; + +pub type RowGroupIter<'a, E> = DynIter<'a, RowGroupIterColumns<'a, E>>; + +/// Write options of different interfaces on this crate +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct WriteOptions { + /// Whether to write statistics, including indexes + pub write_statistics: bool, + /// Which Parquet version to use + pub version: Version, +} + +/// The parquet version to use +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum Version { + V1, + V2, +} + +/// Used to recall the state of the parquet writer - whether sync or async. +#[derive(PartialEq)] +enum State { + Initialised, + Started, + Finished, +} + +impl From for i32 { + fn from(version: Version) -> Self { + match version { + Version::V1 => 1, + Version::V2 => 2, + } + } +} diff --git a/crates/polars-parquet/src/parquet/write/page.rs b/crates/polars-parquet/src/parquet/write/page.rs new file mode 100644 index 000000000000..8fb65c3daf12 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/page.rs @@ -0,0 +1,251 @@ +use std::io::Write; + +#[cfg(feature = "async")] +use futures::{AsyncWrite, AsyncWriteExt}; +use polars_parquet_format::thrift::protocol::TCompactOutputProtocol; +#[cfg(feature = "async")] +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; +use polars_parquet_format::{DictionaryPageHeader, Encoding, PageType}; + +use crate::parquet::compression::Compression; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::page::{ + CompressedDataPage, CompressedDictPage, CompressedPage, DataPageHeader, ParquetPageHeader, +}; +use crate::parquet::statistics::Statistics; + +pub(crate) fn is_data_page(page: &PageWriteSpec) -> bool { + page.header.type_ == PageType::DATA_PAGE || page.header.type_ == PageType::DATA_PAGE_V2 +} + +fn maybe_bytes(uncompressed: usize, compressed: usize) -> ParquetResult<(i32, i32)> { + let uncompressed_page_size: i32 = uncompressed.try_into().map_err(|_| { + ParquetError::oos(format!( + "A page can only contain i32::MAX uncompressed bytes. This one contains {}", + uncompressed + )) + })?; + + let compressed_page_size: i32 = compressed.try_into().map_err(|_| { + ParquetError::oos(format!( + "A page can only contain i32::MAX compressed bytes. This one contains {}", + compressed + )) + })?; + + Ok((uncompressed_page_size, compressed_page_size)) +} + +/// Contains page write metrics. +pub struct PageWriteSpec { + pub header: ParquetPageHeader, + #[allow(dead_code)] + pub num_values: usize, + /// The number of actual rows. For non-nested values, this is equal to the number of values. + pub num_rows: usize, + pub header_size: u64, + pub offset: u64, + pub bytes_written: u64, + pub compression: Compression, + pub statistics: Option, +} + +pub fn write_page( + writer: &mut W, + offset: u64, + compressed_page: &CompressedPage, +) -> ParquetResult { + let num_values = compressed_page.num_values(); + let num_rows = compressed_page + .num_rows() + .expect("We should have num_rows when we are writing"); + + let header = match &compressed_page { + CompressedPage::Data(compressed_page) => assemble_data_page_header(compressed_page), + CompressedPage::Dict(compressed_page) => assemble_dict_page_header(compressed_page), + }?; + + let header_size = write_page_header(writer, &header)?; + let mut bytes_written = header_size; + + bytes_written += match &compressed_page { + CompressedPage::Data(compressed_page) => { + writer.write_all(&compressed_page.buffer)?; + compressed_page.buffer.len() as u64 + }, + CompressedPage::Dict(compressed_page) => { + writer.write_all(&compressed_page.buffer)?; + compressed_page.buffer.len() as u64 + }, + }; + + let statistics = match &compressed_page { + CompressedPage::Data(compressed_page) => compressed_page.statistics().transpose()?, + CompressedPage::Dict(_) => None, + }; + + Ok(PageWriteSpec { + header, + header_size, + offset, + bytes_written, + compression: compressed_page.compression(), + statistics, + num_values, + num_rows, + }) +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub async fn write_page_async( + writer: &mut W, + offset: u64, + compressed_page: &CompressedPage, +) -> ParquetResult { + let num_values = compressed_page.num_values(); + let num_rows = compressed_page + .num_rows() + .expect("We should have the num_rows when we are writing"); + + let header = match &compressed_page { + CompressedPage::Data(compressed_page) => assemble_data_page_header(compressed_page), + CompressedPage::Dict(compressed_page) => assemble_dict_page_header(compressed_page), + }?; + + let header_size = write_page_header_async(writer, &header).await?; + let mut bytes_written = header_size as u64; + + bytes_written += match &compressed_page { + CompressedPage::Data(compressed_page) => { + writer.write_all(&compressed_page.buffer).await?; + compressed_page.buffer.len() as u64 + }, + CompressedPage::Dict(compressed_page) => { + writer.write_all(&compressed_page.buffer).await?; + compressed_page.buffer.len() as u64 + }, + }; + + let statistics = match &compressed_page { + CompressedPage::Data(compressed_page) => compressed_page.statistics().transpose()?, + CompressedPage::Dict(_) => None, + }; + + Ok(PageWriteSpec { + header, + header_size, + offset, + bytes_written, + compression: compressed_page.compression(), + statistics, + num_rows, + num_values, + }) +} + +fn assemble_data_page_header(page: &CompressedDataPage) -> ParquetResult { + let (uncompressed_page_size, compressed_page_size) = + maybe_bytes(page.uncompressed_size(), page.compressed_size())?; + + let mut page_header = ParquetPageHeader { + type_: match page.header() { + DataPageHeader::V1(_) => PageType::DATA_PAGE, + DataPageHeader::V2(_) => PageType::DATA_PAGE_V2, + }, + uncompressed_page_size, + compressed_page_size, + crc: None, + data_page_header: None, + index_page_header: None, + dictionary_page_header: None, + data_page_header_v2: None, + }; + + match page.header() { + DataPageHeader::V1(header) => { + page_header.data_page_header = Some(header.clone()); + }, + DataPageHeader::V2(header) => { + page_header.data_page_header_v2 = Some(header.clone()); + }, + } + Ok(page_header) +} + +fn assemble_dict_page_header(page: &CompressedDictPage) -> ParquetResult { + let (uncompressed_page_size, compressed_page_size) = + maybe_bytes(page.uncompressed_page_size, page.buffer.len())?; + + let num_values: i32 = page.num_values.try_into().map_err(|_| { + ParquetError::oos(format!( + "A dictionary page can only contain i32::MAX items. This one contains {}", + page.num_values + )) + })?; + + Ok(ParquetPageHeader { + type_: PageType::DICTIONARY_PAGE, + uncompressed_page_size, + compressed_page_size, + crc: None, + data_page_header: None, + index_page_header: None, + dictionary_page_header: Some(DictionaryPageHeader { + num_values, + encoding: Encoding::PLAIN, + is_sorted: None, + }), + data_page_header_v2: None, + }) +} + +/// writes the page header into `writer`, returning the number of bytes used in the process. +fn write_page_header( + mut writer: &mut W, + header: &ParquetPageHeader, +) -> ParquetResult { + let mut protocol = TCompactOutputProtocol::new(&mut writer); + Ok(header.write_to_out_protocol(&mut protocol)? as u64) +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +/// writes the page header into `writer`, returning the number of bytes used in the process. +async fn write_page_header_async( + mut writer: &mut W, + header: &ParquetPageHeader, +) -> ParquetResult { + let mut protocol = TCompactOutputStreamProtocol::new(&mut writer); + Ok(header.write_to_out_stream_protocol(&mut protocol).await? as u64) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parquet::CowBuffer; + + #[test] + fn dict_too_large() { + let page = CompressedDictPage::new( + CowBuffer::Owned(vec![]), + Compression::Uncompressed, + i32::MAX as usize + 1, + 100, + false, + ); + assert!(assemble_dict_page_header(&page).is_err()); + } + + #[test] + fn dict_too_many_values() { + let page = CompressedDictPage::new( + CowBuffer::Owned(vec![]), + Compression::Uncompressed, + 0, + i32::MAX as usize + 1, + false, + ); + assert!(assemble_dict_page_header(&page).is_err()); + } +} diff --git a/crates/polars-parquet/src/parquet/write/row_group.rs b/crates/polars-parquet/src/parquet/write/row_group.rs new file mode 100644 index 000000000000..d5774bbdd588 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/row_group.rs @@ -0,0 +1,198 @@ +use std::io::Write; + +#[cfg(feature = "async")] +use futures::AsyncWrite; +use polars_parquet_format::{ColumnChunk, RowGroup}; + +use super::column_chunk::write_column_chunk; +#[cfg(feature = "async")] +use super::column_chunk::write_column_chunk_async; +use super::page::{PageWriteSpec, is_data_page}; +use super::{DynIter, DynStreamingIterator}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::{ColumnChunkMetadata, ColumnDescriptor}; +use crate::parquet::page::CompressedPage; + +pub struct ColumnOffsetsMetadata { + pub dictionary_page_offset: Option, + pub data_page_offset: Option, +} + +impl ColumnOffsetsMetadata { + pub fn from_column_chunk(column_chunk: &ColumnChunk) -> ColumnOffsetsMetadata { + ColumnOffsetsMetadata { + dictionary_page_offset: column_chunk + .meta_data + .as_ref() + .map(|meta| meta.dictionary_page_offset) + .unwrap_or(None), + data_page_offset: column_chunk + .meta_data + .as_ref() + .map(|meta| meta.data_page_offset), + } + } + + pub fn from_column_chunk_metadata( + column_chunk_metadata: &ColumnChunkMetadata, + ) -> ColumnOffsetsMetadata { + ColumnOffsetsMetadata { + dictionary_page_offset: column_chunk_metadata.dictionary_page_offset(), + data_page_offset: Some(column_chunk_metadata.data_page_offset()), + } + } + + pub fn calc_row_group_file_offset(&self) -> Option { + self.dictionary_page_offset + .filter(|x| *x > 0_i64) + .or(self.data_page_offset) + } +} + +fn compute_num_rows(columns: &[(ColumnChunk, Vec)]) -> ParquetResult { + columns + .first() + .map(|(_, specs)| { + let mut num_rows = 0; + specs + .iter() + .filter(|x| is_data_page(x)) + .try_for_each(|spec| { + num_rows += spec.num_rows as i64; + ParquetResult::Ok(()) + })?; + ParquetResult::Ok(num_rows) + }) + .unwrap_or(Ok(0)) +} + +pub fn write_row_group< + 'a, + W, + E, // external error any of the iterators may emit +>( + writer: &mut W, + mut offset: u64, + descriptors: &[ColumnDescriptor], + columns: DynIter<'a, std::result::Result, E>>, + ordinal: usize, +) -> ParquetResult<(RowGroup, Vec>, u64)> +where + W: Write, + ParquetError: From, + E: std::error::Error, +{ + let column_iter = descriptors.iter().zip(columns); + + let initial = offset; + let columns = column_iter + .map(|(descriptor, page_iter)| { + let (column, page_specs, size) = + write_column_chunk(writer, offset, descriptor, page_iter?)?; + offset += size; + Ok((column, page_specs)) + }) + .collect::>>()?; + let bytes_written = offset - initial; + + let num_rows = compute_num_rows(&columns)?; + + // compute row group stats + let file_offset = columns + .first() + .map(|(column_chunk, _)| { + ColumnOffsetsMetadata::from_column_chunk(column_chunk).calc_row_group_file_offset() + }) + .unwrap_or(None); + + let total_byte_size = columns + .iter() + .map(|(c, _)| c.meta_data.as_ref().unwrap().total_uncompressed_size) + .sum(); + let total_compressed_size = columns + .iter() + .map(|(c, _)| c.meta_data.as_ref().unwrap().total_compressed_size) + .sum(); + + let (columns, specs) = columns.into_iter().unzip(); + + Ok(( + RowGroup { + columns, + total_byte_size, + num_rows, + sorting_columns: None, + file_offset, + total_compressed_size: Some(total_compressed_size), + ordinal: ordinal.try_into().ok(), + }, + specs, + bytes_written, + )) +} + +#[cfg(feature = "async")] +#[cfg_attr(docsrs, doc(cfg(feature = "async")))] +pub async fn write_row_group_async< + 'a, + W, + E, // external error any of the iterators may emit +>( + writer: &mut W, + mut offset: u64, + descriptors: &[ColumnDescriptor], + columns: DynIter<'a, std::result::Result, E>>, + ordinal: usize, +) -> ParquetResult<(RowGroup, Vec>, u64)> +where + W: AsyncWrite + Unpin + Send, + ParquetError: From, + E: std::error::Error, +{ + let column_iter = descriptors.iter().zip(columns); + + let initial = offset; + let mut columns = vec![]; + for (descriptor, page_iter) in column_iter { + let (column, page_specs, size) = + write_column_chunk_async(writer, offset, descriptor, page_iter?).await?; + offset += size; + columns.push((column, page_specs)); + } + let bytes_written = offset - initial; + + let num_rows = compute_num_rows(&columns)?; + + // compute row group stats + let file_offset = columns + .first() + .map(|(column_chunk, _)| { + ColumnOffsetsMetadata::from_column_chunk(column_chunk).calc_row_group_file_offset() + }) + .unwrap_or(None); + + let total_byte_size = columns + .iter() + .map(|(c, _)| c.meta_data.as_ref().unwrap().total_uncompressed_size) + .sum(); + let total_compressed_size = columns + .iter() + .map(|(c, _)| c.meta_data.as_ref().unwrap().total_compressed_size) + .sum(); + + let (columns, specs) = columns.into_iter().unzip(); + + Ok(( + RowGroup { + columns, + total_byte_size, + num_rows: num_rows as i64, + sorting_columns: None, + file_offset, + total_compressed_size: Some(total_compressed_size), + ordinal: ordinal.try_into().ok(), + }, + specs, + bytes_written, + )) +} diff --git a/crates/polars-parquet/src/parquet/write/statistics.rs b/crates/polars-parquet/src/parquet/write/statistics.rs new file mode 100644 index 000000000000..064a16eb931b --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/statistics.rs @@ -0,0 +1,295 @@ +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::schema::types::PhysicalType; +use crate::parquet::statistics::*; +use crate::parquet::types::NativeType; + +#[inline] +fn reduce_single T>(lhs: Option, rhs: Option, op: F) -> Option { + match (lhs, rhs) { + (None, None) => None, + (Some(x), None) => Some(x), + (None, Some(x)) => Some(x), + (Some(x), Some(y)) => Some(op(x, y)), + } +} + +#[inline] +fn reduce_vec8(lhs: Option>, rhs: &Option>, max: bool) -> Option> { + match (lhs, rhs) { + (None, None) => None, + (Some(x), None) => Some(x), + (None, Some(x)) => Some(x.clone()), + (Some(x), Some(y)) => Some(ord_binary(x, y.clone(), max)), + } +} + +pub fn reduce(stats: &[&Option]) -> ParquetResult> { + if stats.is_empty() { + return Ok(None); + } + let stats = stats + .iter() + .filter_map(|x| x.as_ref()) + .collect::>(); + if stats.is_empty() { + return Ok(None); + }; + + let same_type = stats + .iter() + .skip(1) + .all(|x| x.physical_type() == stats[0].physical_type()); + if !same_type { + return Err(ParquetError::oos( + "The statistics do not have the same dtype", + )); + }; + + use PhysicalType as T; + let stats = match stats[0].physical_type() { + T::Boolean => reduce_boolean(stats.iter().map(|x| x.expect_as_boolean())).into(), + T::Int32 => reduce_primitive::(stats.iter().map(|x| x.expect_as_int32())).into(), + T::Int64 => reduce_primitive(stats.iter().map(|x| x.expect_as_int64())).into(), + T::Float => reduce_primitive(stats.iter().map(|x| x.expect_as_float())).into(), + T::Double => reduce_primitive(stats.iter().map(|x| x.expect_as_double())).into(), + T::ByteArray => reduce_binary(stats.iter().map(|x| x.expect_as_binary())).into(), + T::FixedLenByteArray(_) => { + reduce_fix_len_binary(stats.iter().map(|x| x.expect_as_fixedlen())).into() + }, + _ => todo!(), + }; + + Ok(Some(stats)) +} + +fn reduce_binary<'a, I: Iterator>(mut stats: I) -> BinaryStatistics { + let initial = stats.next().unwrap().clone(); + stats.fold(initial, |mut acc, new| { + acc.min_value = reduce_vec8(acc.min_value, &new.min_value, false); + acc.max_value = reduce_vec8(acc.max_value, &new.max_value, true); + acc.null_count = reduce_single(acc.null_count, new.null_count, |x, y| x + y); + acc.distinct_count = None; + acc + }) +} + +fn reduce_fix_len_binary<'a, I: Iterator>( + mut stats: I, +) -> FixedLenStatistics { + let initial = stats.next().unwrap().clone(); + stats.fold(initial, |mut acc, new| { + acc.min_value = reduce_vec8(acc.min_value, &new.min_value, false); + acc.max_value = reduce_vec8(acc.max_value, &new.max_value, true); + acc.null_count = reduce_single(acc.null_count, new.null_count, |x, y| x + y); + acc.distinct_count = None; + acc + }) +} + +fn ord_binary(a: Vec, b: Vec, max: bool) -> Vec { + for (v1, v2) in a.iter().zip(b.iter()) { + match v1.cmp(v2) { + std::cmp::Ordering::Greater => { + if max { + return a; + } else { + return b; + } + }, + std::cmp::Ordering::Less => { + if max { + return b; + } else { + return a; + } + }, + _ => {}, + } + } + a +} + +fn reduce_boolean<'a, I: Iterator>( + mut stats: I, +) -> BooleanStatistics { + let initial = stats.next().unwrap().clone(); + stats.fold(initial, |mut acc, new| { + acc.min_value = reduce_single( + acc.min_value, + new.min_value, + |x, y| if x & !(y) { y } else { x }, + ); + acc.max_value = reduce_single( + acc.max_value, + new.max_value, + |x, y| if x & !(y) { x } else { y }, + ); + acc.null_count = reduce_single(acc.null_count, new.null_count, |x, y| x + y); + acc.distinct_count = None; + acc + }) +} + +fn reduce_primitive< + 'a, + T: NativeType + std::cmp::PartialOrd, + I: Iterator>, +>( + mut stats: I, +) -> PrimitiveStatistics { + let initial = stats.next().unwrap().clone(); + stats.fold(initial, |mut acc, new| { + acc.min_value = reduce_single( + acc.min_value, + new.min_value, + |x, y| if x > y { y } else { x }, + ); + acc.max_value = reduce_single( + acc.max_value, + new.max_value, + |x, y| if x > y { x } else { y }, + ); + acc.null_count = reduce_single(acc.null_count, new.null_count, |x, y| x + y); + acc.distinct_count = None; + acc + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parquet::schema::types::PrimitiveType; + + #[test] + fn binary() -> ParquetResult<()> { + let iter = vec![ + BinaryStatistics { + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::ByteArray), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![1, 2]), + max_value: Some(vec![3, 4]), + }, + BinaryStatistics { + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::ByteArray), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![4, 5]), + max_value: None, + }, + ]; + let a = reduce_binary(iter.iter()); + + assert_eq!( + a, + BinaryStatistics { + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::ByteArray,), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![1, 2]), + max_value: Some(vec![3, 4]), + }, + ); + + Ok(()) + } + + #[test] + fn fixed_len_binary() -> ParquetResult<()> { + let iter = vec![ + FixedLenStatistics { + primitive_type: PrimitiveType::from_physical( + "bla".into(), + PhysicalType::FixedLenByteArray(2), + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![1, 2]), + max_value: Some(vec![3, 4]), + }, + FixedLenStatistics { + primitive_type: PrimitiveType::from_physical( + "bla".into(), + PhysicalType::FixedLenByteArray(2), + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![4, 5]), + max_value: None, + }, + ]; + let a = reduce_fix_len_binary(iter.iter()); + + assert_eq!( + a, + FixedLenStatistics { + primitive_type: PrimitiveType::from_physical( + "bla".into(), + PhysicalType::FixedLenByteArray(2), + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![1, 2]), + max_value: Some(vec![3, 4]), + }, + ); + + Ok(()) + } + + #[test] + fn boolean() -> ParquetResult<()> { + let iter = [ + BooleanStatistics { + null_count: Some(0), + distinct_count: None, + min_value: Some(false), + max_value: Some(false), + }, + BooleanStatistics { + null_count: Some(0), + distinct_count: None, + min_value: Some(true), + max_value: Some(true), + }, + ]; + let a = reduce_boolean(iter.iter()); + + assert_eq!( + a, + BooleanStatistics { + null_count: Some(0), + distinct_count: None, + min_value: Some(false), + max_value: Some(true), + }, + ); + + Ok(()) + } + + #[test] + fn primitive() -> ParquetResult<()> { + let iter = [PrimitiveStatistics { + null_count: Some(2), + distinct_count: None, + min_value: Some(30), + max_value: Some(70), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::Int32), + }]; + let a = reduce_primitive(iter.iter()); + + assert_eq!( + a, + PrimitiveStatistics { + null_count: Some(2), + distinct_count: None, + min_value: Some(30), + max_value: Some(70), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::Int32,), + }, + ); + + Ok(()) + } +} diff --git a/crates/polars-parquet/src/parquet/write/stream.rs b/crates/polars-parquet/src/parquet/write/stream.rs new file mode 100644 index 000000000000..0eef141a1b67 --- /dev/null +++ b/crates/polars-parquet/src/parquet/write/stream.rs @@ -0,0 +1,192 @@ +use std::io::Write; + +use futures::{AsyncWrite, AsyncWriteExt}; +use polars_parquet_format::RowGroup; +use polars_parquet_format::thrift::protocol::TCompactOutputStreamProtocol; + +use super::row_group::write_row_group_async; +use super::{RowGroupIterColumns, WriteOptions}; +use crate::parquet::error::{ParquetError, ParquetResult}; +use crate::parquet::metadata::{KeyValue, SchemaDescriptor}; +use crate::parquet::write::State; +use crate::parquet::write::indexes::{write_column_index_async, write_offset_index_async}; +use crate::parquet::write::page::PageWriteSpec; +use crate::parquet::{FOOTER_SIZE, PARQUET_MAGIC}; + +async fn start_file(writer: &mut W) -> ParquetResult { + writer.write_all(&PARQUET_MAGIC).await?; + Ok(PARQUET_MAGIC.len() as u64) +} + +async fn end_file( + mut writer: &mut W, + metadata: polars_parquet_format::FileMetaData, +) -> ParquetResult { + // Write file metadata + let mut protocol = TCompactOutputStreamProtocol::new(&mut writer); + let metadata_len = metadata.write_to_out_stream_protocol(&mut protocol).await? as i32; + + // Write footer + let metadata_bytes = metadata_len.to_le_bytes(); + let mut footer_buffer = [0u8; FOOTER_SIZE as usize]; + (0..4).for_each(|i| { + footer_buffer[i] = metadata_bytes[i]; + }); + + (&mut footer_buffer[4..]).write_all(&PARQUET_MAGIC)?; + writer.write_all(&footer_buffer).await?; + writer.flush().await?; + Ok(metadata_len as u64 + FOOTER_SIZE) +} + +/// An interface to write a parquet file asynchronously. +/// Use `start` to write the header, `write` to write a row group, +/// and `end` to write the footer. +pub struct FileStreamer { + writer: W, + schema: SchemaDescriptor, + options: WriteOptions, + created_by: Option, + + offset: u64, + row_groups: Vec, + page_specs: Vec>>, + /// Used to store the current state for writing the file + state: State, +} + +// Accessors +impl FileStreamer { + /// The options assigned to the file + pub fn options(&self) -> &WriteOptions { + &self.options + } + + /// The [`SchemaDescriptor`] assigned to this file + pub fn schema(&self) -> &SchemaDescriptor { + &self.schema + } +} + +impl FileStreamer { + /// Returns a new [`FileStreamer`]. + pub fn new( + writer: W, + schema: SchemaDescriptor, + options: WriteOptions, + created_by: Option, + ) -> Self { + Self { + writer, + schema, + options, + created_by, + offset: 0, + row_groups: vec![], + page_specs: vec![], + state: State::Initialised, + } + } + + /// Writes the header of the file. + /// + /// This is automatically called by [`Self::write`] if not called following [`Self::new`]. + /// + /// # Errors + /// Returns an error if data has been written to the file. + async fn start(&mut self) -> ParquetResult<()> { + if self.offset == 0 { + self.offset = start_file(&mut self.writer).await? as u64; + self.state = State::Started; + Ok(()) + } else { + Err(ParquetError::InvalidParameter( + "Start cannot be called twice".to_string(), + )) + } + } + + /// Writes a row group to the file. + pub async fn write(&mut self, row_group: RowGroupIterColumns<'_, E>) -> ParquetResult<()> + where + ParquetError: From, + E: std::error::Error, + { + if self.offset == 0 { + self.start().await?; + } + + let ordinal = self.row_groups.len(); + let (group, specs, size) = write_row_group_async( + &mut self.writer, + self.offset, + self.schema.columns(), + row_group, + ordinal, + ) + .await?; + self.offset += size; + self.row_groups.push(group); + self.page_specs.push(specs); + Ok(()) + } + + /// Writes the footer of the parquet file. Returns the total size of the file and the + /// underlying writer. + pub async fn end(&mut self, key_value_metadata: Option>) -> ParquetResult { + if self.offset == 0 { + self.start().await?; + } + + if self.state != State::Started { + return Err(ParquetError::InvalidParameter( + "End cannot be called twice".to_string(), + )); + } + // compute file stats + let num_rows = self.row_groups.iter().map(|group| group.num_rows).sum(); + + if self.options.write_statistics { + // write column indexes (require page statistics) + for (group, pages) in self.row_groups.iter_mut().zip(self.page_specs.iter()) { + for (column, pages) in group.columns.iter_mut().zip(pages.iter()) { + let offset = self.offset; + column.column_index_offset = Some(offset as i64); + self.offset += write_column_index_async(&mut self.writer, pages).await?; + let length = self.offset - offset; + column.column_index_length = Some(length as i32); + } + } + }; + + // write offset index + for (group, pages) in self.row_groups.iter_mut().zip(self.page_specs.iter()) { + for (column, pages) in group.columns.iter_mut().zip(pages.iter()) { + let offset = self.offset; + column.offset_index_offset = Some(offset as i64); + self.offset += write_offset_index_async(&mut self.writer, pages).await?; + column.offset_index_length = Some((self.offset - offset) as i32); + } + } + + let metadata = polars_parquet_format::FileMetaData::new( + self.options.version.into(), + self.schema.clone().into_thrift(), + num_rows, + self.row_groups.clone(), + key_value_metadata, + self.created_by.clone(), + None, + None, + None, + ); + + let len = end_file(&mut self.writer, metadata).await?; + Ok(self.offset + len) + } + + /// Returns the underlying writer. + pub fn into_inner(self) -> W { + self.writer + } +} diff --git a/crates/polars-pipe/Cargo.toml b/crates/polars-pipe/Cargo.toml new file mode 100644 index 000000000000..2d5bbf4ba0e1 --- /dev/null +++ b/crates/polars-pipe/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "polars-pipe" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Lazy query engine for the Polars DataFrame library" + +[dependencies] +arrow = { workspace = true } +futures = { workspace = true, optional = true } +polars-compute = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random", "rows"] } +polars-expr = { workspace = true } +polars-io = { workspace = true, features = ["ipc"] } +polars-ops = { workspace = true, features = ["search_sorted", "chunked_ids"] } +polars-plan = { workspace = true } +polars-row = { workspace = true } +polars-utils = { workspace = true, features = ["sysinfo"] } +tokio = { workspace = true, optional = true } +uuid = { workspace = true } + +crossbeam-channel = { workspace = true } +crossbeam-queue = { workspace = true } +enum_dispatch = { version = "0.3" } +hashbrown = { workspace = true } +num-traits = { workspace = true } +rayon = { workspace = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +csv = ["polars-plan/csv", "polars-io/csv"] +cloud = ["async", "polars-io/cloud", "polars-plan/cloud", "tokio", "futures"] +parquet = ["polars-plan/parquet", "polars-io/parquet", "polars-io/async", "futures"] +ipc = ["polars-plan/ipc", "polars-io/ipc"] +json = ["polars-plan/json", "polars-io/json"] +async = ["polars-plan/async", "polars-io/async", "futures"] +nightly = ["polars-core/nightly", "polars-utils/nightly", "hashbrown/nightly"] +cross_join = ["polars-ops/cross_join"] +dtype-u8 = ["polars-core/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16"] +dtype-i8 = ["polars-core/dtype-i8"] +dtype-i16 = ["polars-core/dtype-i16"] +dtype-i128 = ["polars-core/dtype-i128"] +dtype-decimal = ["dtype-i128"] +dtype-array = ["polars-core/dtype-array"] +dtype-categorical = ["polars-core/dtype-categorical"] +trigger_ooc = [] diff --git a/crates/polars-pipe/LICENSE b/crates/polars-pipe/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-pipe/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-pipe/README.md b/crates/polars-pipe/README.md new file mode 100644 index 000000000000..ee3ee757f80d --- /dev/null +++ b/crates/polars-pipe/README.md @@ -0,0 +1,7 @@ +# polars-pipe + +`polars-pipe` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, +introducing OOC (out of core) algorithms to polars physical plans. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-pipe/build.rs b/crates/polars-pipe/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-pipe/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-pipe/src/executors/mod.rs b/crates/polars-pipe/src/executors/mod.rs new file mode 100644 index 000000000000..981fd0279d6f --- /dev/null +++ b/crates/polars-pipe/src/executors/mod.rs @@ -0,0 +1,6 @@ +pub(crate) mod operators; +pub(crate) mod sinks; +pub(crate) mod sources; + +#[cfg(feature = "csv")] +use crate::operators::*; diff --git a/crates/polars-pipe/src/executors/operators/filter.rs b/crates/polars-pipe/src/executors/operators/filter.rs new file mode 100644 index 000000000000..5823fb3c860d --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/filter.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; + +use polars_core::error::PolarsResult; +use polars_core::prelude::polars_err; + +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +#[derive(Clone)] +pub(crate) struct FilterOperator { + pub(crate) predicate: Arc, +} + +impl Operator for FilterOperator { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let s = self.predicate.evaluate(chunk, &context.execution_state)?; + let mask = s.bool().map_err(|_| { + polars_err!( + ComputeError: "filter predicate must be of type `Boolean`, got `{}`", s.dtype() + ) + })?; + // the filter is sequential as they are already executed on different threads + // we don't want to increase contention and data copies + let df = chunk.data._filter_seq(mask)?; + + Ok(OperatorResult::Finished(chunk.with_data(df))) + } + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + fn fmt(&self) -> &str { + "filter" + } +} diff --git a/crates/polars-pipe/src/executors/operators/function.rs b/crates/polars-pipe/src/executors/operators/function.rs new file mode 100644 index 000000000000..181faefe8693 --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/function.rs @@ -0,0 +1,121 @@ +use std::collections::VecDeque; + +use polars_core::POOL; +use polars_core::error::PolarsResult; +use polars_core::utils::_split_offsets; +use polars_plan::prelude::*; + +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; +use crate::pipeline::determine_chunk_size; + +#[derive(Clone)] +pub struct FunctionOperator { + n_threads: usize, + chunk_size: usize, + offsets: VecDeque<(usize, usize)>, + function: FunctionIR, +} + +impl FunctionOperator { + pub(crate) fn new(function: FunctionIR) -> Self { + FunctionOperator { + n_threads: POOL.current_num_threads(), + function, + chunk_size: 128, + offsets: VecDeque::new(), + } + } + + fn execute_no_expanding(&mut self, chunk: &DataChunk) -> PolarsResult { + Ok(OperatorResult::Finished( + chunk.with_data(self.function.evaluate(chunk.data.clone())?), + )) + } + + // Combine every two `(offset, len)` pairs so that we double the chunk size + fn combine_offsets(&mut self) { + self.offsets = self + .offsets + .make_contiguous() + .chunks(2) + .map(|chunk| { + if chunk.len() == 2 { + let offset = chunk[0].0; + let len = chunk[0].1 + chunk[1].1; + (offset, len) + } else { + chunk[0] + } + }) + .collect() + } +} + +impl Operator for FunctionOperator { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + if self.function.expands_rows() { + let input_height = chunk.data.height(); + // ideal chunk size we want to have + // we cannot rely on input chunk size as that can increase due to multiple explode calls + // for instance. + let chunk_size_ambition = determine_chunk_size(chunk.data.width(), self.n_threads)?; + + if self.offsets.is_empty() { + let n = input_height / self.chunk_size; + if n > 1 { + self.offsets = _split_offsets(input_height, n).into(); + } else { + return self.execute_no_expanding(chunk); + } + } + if let Some((offset, len)) = self.offsets.pop_front() { + let df = chunk.data.slice(offset as i64, len); + let output = self.function.evaluate(df)?; + if output.height() * 2 < chunk.data.height() + && output.height() * 2 < chunk_size_ambition + { + self.chunk_size *= 2; + // ensure that next slice is larger + self.combine_offsets(); + } + // allow some increase in chunk size so that we don't toggle the chunk size + // every iteration + else if output.height() * 4 > chunk.data.height() + || output.height() > chunk_size_ambition * 2 + { + let new_chunk_size = self.chunk_size / 2; + + if context.verbose && new_chunk_size < 5 { + eprintln!( + "chunk size in 'function operation' shrank to {new_chunk_size} and has been set to 5 as lower limit" + ) + } + // ensure it is never 0 + self.chunk_size = std::cmp::max(new_chunk_size, 5); + }; + let output = chunk.with_data(output); + if self.offsets.is_empty() { + Ok(OperatorResult::Finished(output)) + } else { + Ok(OperatorResult::HaveMoreOutPut(output)) + } + } else { + self.execute_no_expanding(chunk) + } + } else { + self.execute_no_expanding(chunk) + } + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + + fn fmt(&self) -> &str { + "function" + } +} diff --git a/crates/polars-pipe/src/executors/operators/mod.rs b/crates/polars-pipe/src/executors/operators/mod.rs new file mode 100644 index 000000000000..a078a4fd5f44 --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/mod.rs @@ -0,0 +1,13 @@ +mod filter; +mod function; +mod pass; +mod placeholder; +mod projection; +mod reproject; + +pub(crate) use filter::*; +pub(crate) use function::*; +pub(crate) use pass::Pass; +pub(crate) use placeholder::PlaceHolder; +pub(crate) use projection::*; +pub(crate) use reproject::*; diff --git a/crates/polars-pipe/src/executors/operators/pass.rs b/crates/polars-pipe/src/executors/operators/pass.rs new file mode 100644 index 000000000000..6b2189dbc2b8 --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/pass.rs @@ -0,0 +1,32 @@ +use polars_core::error::PolarsResult; + +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +/// Simply pass through the chunks +pub struct Pass { + name: &'static str, +} + +impl Pass { + pub(crate) fn new(name: &'static str) -> Self { + Self { name } + } +} + +impl Operator for Pass { + fn execute( + &mut self, + _context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + Ok(OperatorResult::Finished(chunk.clone())) + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(Self { name: self.name }) + } + + fn fmt(&self) -> &str { + self.name + } +} diff --git a/crates/polars-pipe/src/executors/operators/placeholder.rs b/crates/polars-pipe/src/executors/operators/placeholder.rs new file mode 100644 index 000000000000..8a12be07d12b --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/placeholder.rs @@ -0,0 +1,93 @@ +use std::sync::{Arc, Mutex}; + +use polars_core::error::PolarsResult; + +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +#[derive(Clone)] +struct CallBack { + inner: Arc>>>, +} + +impl CallBack { + fn new() -> Self { + Self { + inner: Default::default(), + } + } + + fn replace(&self, op: Box) { + let mut lock = self.inner.try_lock().expect("no-contention"); + *lock = Some(op); + } +} + +impl Operator for CallBack { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let mut lock = self.inner.try_lock().expect("no-contention"); + lock.as_mut().unwrap().execute(context, chunk) + } + + fn flush(&mut self) -> PolarsResult { + let mut lock = self.inner.try_lock().expect("no-contention"); + lock.as_mut().unwrap().flush() + } + + fn must_flush(&self) -> bool { + let lock = self.inner.try_lock().expect("no-contention"); + lock.as_ref().unwrap().must_flush() + } + + fn split(&self, _thread_no: usize) -> Box { + panic!("should not be called") + } + + fn fmt(&self) -> &str { + "callback" + } +} + +#[derive(Clone, Default)] +pub struct PlaceHolder { + inner: Arc>>, +} + +impl PlaceHolder { + pub fn new() -> Self { + Self { + inner: Arc::new(Default::default()), + } + } + + pub fn replace(&self, op: Box) { + let inner = self.inner.lock().unwrap(); + for (thread_no, cb) in inner.iter() { + cb.replace(op.split(*thread_no)) + } + } +} + +impl Operator for PlaceHolder { + fn execute( + &mut self, + _context: &PExecutionContext, + _chunk: &DataChunk, + ) -> PolarsResult { + panic!("placeholder should be replaced") + } + + fn split(&self, thread_no: usize) -> Box { + let cb = CallBack::new(); + let mut inner = self.inner.lock().unwrap(); + inner.push((thread_no, cb.clone())); + Box::new(cb) + } + + fn fmt(&self) -> &str { + "placeholder" + } +} diff --git a/crates/polars-pipe/src/executors/operators/projection.rs b/crates/polars-pipe/src/executors/operators/projection.rs new file mode 100644 index 000000000000..32e108bd2034 --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/projection.rs @@ -0,0 +1,158 @@ +use std::sync::Arc; + +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::frame::column::{Column, IntoColumn}; +use polars_core::schema::SchemaRef; +use polars_plan::prelude::ProjectionOptions; +use polars_utils::pl_str::PlSmallStr; + +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +#[derive(Clone)] +pub(crate) struct SimpleProjectionOperator { + columns: Arc<[PlSmallStr]>, + input_schema: SchemaRef, +} + +impl SimpleProjectionOperator { + pub(crate) fn new(columns: Arc<[PlSmallStr]>, input_schema: SchemaRef) -> Self { + Self { + columns, + input_schema, + } + } +} + +impl Operator for SimpleProjectionOperator { + fn execute( + &mut self, + _context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let check_duplicates = false; + let chunk = chunk.with_data(chunk.data._select_with_schema_impl( + self.columns.as_ref(), + &self.input_schema, + check_duplicates, + )?); + Ok(OperatorResult::Finished(chunk)) + } + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + fn fmt(&self) -> &str { + "fast_projection" + } +} + +#[derive(Clone)] +pub(crate) struct ProjectionOperator { + pub(crate) exprs: Vec>, + pub(crate) options: ProjectionOptions, +} + +impl Operator for ProjectionOperator { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let mut has_literals = false; + let mut has_empty = false; + let mut projected = self + .exprs + .iter() + .map(|e| { + #[allow(unused_mut)] + let mut s = e.evaluate(chunk, &context.execution_state)?; + + has_literals |= s.len() == 1; + has_empty |= s.is_empty(); + + Ok(s.into_column()) + }) + .collect::>>()?; + + if has_empty { + for s in &mut projected { + *s = s.clear(); + } + } else if has_literals && self.options.should_broadcast { + let height = projected.iter().map(|s| s.len()).max().unwrap(); + for s in &mut projected { + let len = s.len(); + if len == 1 && len != height { + *s = s.new_from_index(0, height) + } + } + } + + let chunk = + chunk.with_data(unsafe { DataFrame::new_no_checks_height_from_first(projected) }); + Ok(OperatorResult::Finished(chunk)) + } + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + fn fmt(&self) -> &str { + "projection" + } +} + +#[derive(Clone)] +pub(crate) struct HstackOperator { + pub(crate) exprs: Vec>, + pub(crate) input_schema: SchemaRef, + pub(crate) options: ProjectionOptions, +} + +impl Operator for HstackOperator { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + // add temporary cse column to the chunk + let width = chunk.data.width(); + let projected = self + .exprs + .iter() + .map(|e| { + e.evaluate(chunk, &context.execution_state) + .map(Column::from) + }) + .collect::>>()?; + + let columns = chunk.data.get_columns()[..width].to_vec(); + let mut df = unsafe { DataFrame::new_no_checks_height_from_first(columns) }; + + let schema = &*self.input_schema; + if self.options.should_broadcast { + df._add_columns(projected, schema)?; + } else { + debug_assert!( + projected + .iter() + .all(|column| column.name().starts_with("__POLARS_CSER_0x")), + "non-broadcasting hstack should only be used for CSE columns" + ); + // Safety: this case only appears as a result of CSE + // optimization, and the usage there produces new, unique + // column names. It is immediately followed by a + // projection which pulls out the possibly mismatching + // column lengths. + unsafe { df.get_columns_mut().extend(projected) } + } + + let chunk = chunk.with_data(df); + Ok(OperatorResult::Finished(chunk)) + } + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + fn fmt(&self) -> &str { + "hstack" + } +} diff --git a/crates/polars-pipe/src/executors/operators/reproject.rs b/crates/polars-pipe/src/executors/operators/reproject.rs new file mode 100644 index 000000000000..b3355b30cc87 --- /dev/null +++ b/crates/polars-pipe/src/executors/operators/reproject.rs @@ -0,0 +1,34 @@ +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::schema::Schema; + +use crate::operators::DataChunk; + +pub(crate) fn reproject_chunk( + chunk: &mut DataChunk, + positions: &mut Vec, + schema: &Schema, +) -> PolarsResult<()> { + let out = if positions.is_empty() { + // use the chunk schema to cache + // the positions for subsequent calls + let chunk_schema = chunk.data.schema(); + + let out = chunk + .data + .select_with_schema_unchecked(schema.iter_names_cloned(), chunk_schema)?; + + *positions = out + .get_columns() + .iter() + .map(|s| chunk_schema.get_full(s.name()).unwrap().0) + .collect(); + out + } else { + let columns = chunk.data.get_columns(); + let cols = positions.iter().map(|i| columns[*i].clone()).collect(); + unsafe { DataFrame::new_no_checks(chunk.data.height(), cols) } + }; + *chunk = chunk.with_data(out); + Ok(()) +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs new file mode 100644 index 000000000000..4e004efbc1ad --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs @@ -0,0 +1,293 @@ +use std::sync::Arc; + +use polars_core::datatypes::Field; +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::prelude::{DataType, IDX_DTYPE, SchemaRef, Series}; +use polars_core::schema::Schema; +use polars_expr::state::ExecutionState; +use polars_io::predicates::PhysicalIoExpr; +use polars_plan::dsl::Expr; +use polars_plan::plans::expr_ir::ExprIR; +use polars_plan::plans::{ArenaExprIter, Context}; +use polars_plan::prelude::{AExpr, IRAggExpr}; +use polars_utils::IdxSize; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; + +use crate::executors::sinks::group_by::aggregates::count::CountAgg; +use crate::executors::sinks::group_by::aggregates::first::FirstAgg; +use crate::executors::sinks::group_by::aggregates::last::LastAgg; +use crate::executors::sinks::group_by::aggregates::mean::MeanAgg; +use crate::executors::sinks::group_by::aggregates::min_max::{new_max, new_min}; +use crate::executors::sinks::group_by::aggregates::null::NullAgg; +use crate::executors::sinks::group_by::aggregates::{AggregateFunction, SumAgg}; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::DataChunk; + +struct Len {} + +impl PhysicalIoExpr for Len { + fn evaluate_io(&self, _df: &DataFrame) -> PolarsResult { + unimplemented!() + } +} +impl PhysicalPipedExpr for Len { + fn evaluate(&self, chunk: &DataChunk, _lazy_state: &ExecutionState) -> PolarsResult { + // the length must match the chunks as the operators expect that + // so we fill a null series. + Ok(Series::new_null(PlSmallStr::EMPTY, chunk.data.height())) + } + + fn field(&self, _input_schema: &Schema) -> PolarsResult { + todo!() + } + + fn expression(&self) -> Expr { + Expr::Len + } +} + +pub fn can_convert_to_hash_agg( + mut node: Node, + expr_arena: &Arena, + input_schema: &Schema, +) -> bool { + let mut can_run_partitioned = true; + if expr_arena + .iter(node) + .map(|(_, ae)| { + match ae { + AExpr::Agg(_) + | AExpr::Len + | AExpr::Cast { .. } + | AExpr::Literal(_) + | AExpr::Column(_) + | AExpr::BinaryExpr { .. } + | AExpr::Ternary { .. } + | AExpr::Alias(_, _) => {}, + _ => { + can_run_partitioned = false; + }, + } + ae + }) + .filter(|ae| matches!(ae, AExpr::Agg(_) | AExpr::Len)) + .count() + == 1 + && can_run_partitioned + { + // last expression must be agg or agg.alias + if let AExpr::Alias(input, _) = expr_arena.get(node) { + node = *input + } + match expr_arena.get(node) { + AExpr::Len => true, + ae @ AExpr::Agg(agg_fn) => { + matches!( + agg_fn, + IRAggExpr::Sum(_) + | IRAggExpr::First(_) + | IRAggExpr::Last(_) + | IRAggExpr::Mean(_) + | IRAggExpr::Count(_, false) + ) || (matches!( + agg_fn, + IRAggExpr::Max { + propagate_nans: false, + .. + } | IRAggExpr::Min { + propagate_nans: false, + .. + } + ) && { + if let Ok(field) = ae.to_field(input_schema, Context::Default, expr_arena) { + match field.dtype { + DataType::Date => { + matches!(agg_fn, IRAggExpr::Mean(_) | IRAggExpr::Median(_)) + }, + _ => field.dtype.to_physical().is_primitive_numeric(), + } + } else { + false + } + }) + }, + _ => false, + } + } else { + false + } +} + +/// # Returns: +/// - input_dtype: dtype that goes into the agg expression +/// - physical expr: physical expression that produces the input of the aggregation +/// - aggregation function: the aggregation function +pub(crate) fn convert_to_hash_agg( + node: Node, + expr_arena: &Arena, + schema: &SchemaRef, + to_physical: &F, +) -> (DataType, Arc, AggregateFunction) +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + match expr_arena.get(node) { + AExpr::Alias(input, _) => convert_to_hash_agg(*input, expr_arena, schema, to_physical), + AExpr::Len => ( + IDX_DTYPE, + Arc::new(Len {}), + AggregateFunction::Len(CountAgg::new()), + ), + AExpr::Agg(agg) => match agg { + IRAggExpr::Min { input, .. } => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + + let agg_fn = match logical_dtype.to_physical() { + DataType::Int8 => AggregateFunction::MinMaxI8(new_min()), + DataType::Int16 => AggregateFunction::MinMaxI16(new_min()), + DataType::Int32 => AggregateFunction::MinMaxI32(new_min()), + DataType::Int64 => AggregateFunction::MinMaxI64(new_min()), + DataType::UInt8 => AggregateFunction::MinMaxU8(new_min()), + DataType::UInt16 => AggregateFunction::MinMaxU16(new_min()), + DataType::UInt32 => AggregateFunction::MinMaxU32(new_min()), + DataType::UInt64 => AggregateFunction::MinMaxU64(new_min()), + DataType::Float32 => AggregateFunction::MinMaxF32(new_min()), + DataType::Float64 => AggregateFunction::MinMaxF64(new_min()), + dt => panic!("{dt} unexpected"), + }; + (logical_dtype, phys_expr, agg_fn) + }, + IRAggExpr::Max { input, .. } => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + + let agg_fn = match logical_dtype.to_physical() { + DataType::Int8 => AggregateFunction::MinMaxI8(new_max()), + DataType::Int16 => AggregateFunction::MinMaxI16(new_max()), + DataType::Int32 => AggregateFunction::MinMaxI32(new_max()), + DataType::Int64 => AggregateFunction::MinMaxI64(new_max()), + DataType::UInt8 => AggregateFunction::MinMaxU8(new_max()), + DataType::UInt16 => AggregateFunction::MinMaxU16(new_max()), + DataType::UInt32 => AggregateFunction::MinMaxU32(new_max()), + DataType::UInt64 => AggregateFunction::MinMaxU64(new_max()), + DataType::Float32 => AggregateFunction::MinMaxF32(new_max()), + DataType::Float64 => AggregateFunction::MinMaxF64(new_max()), + dt => panic!("{dt} unexpected"), + }; + (logical_dtype, phys_expr, agg_fn) + }, + IRAggExpr::Sum(input) => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + + #[cfg(feature = "dtype-categorical")] + if matches!( + logical_dtype, + DataType::Categorical(_, _) | DataType::Enum(_, _) + ) { + return ( + logical_dtype.clone(), + phys_expr, + AggregateFunction::Null(NullAgg::new(logical_dtype)), + ); + } + + let agg_fn = match logical_dtype.to_physical() { + // Boolean is aggregated as the IDX type. + DataType::Boolean => { + if size_of::() == 4 { + AggregateFunction::SumU32(SumAgg::::new()) + } else { + AggregateFunction::SumU64(SumAgg::::new()) + } + }, + // these are aggregated as i64 to prevent overflow + DataType::Int8 => AggregateFunction::SumI64(SumAgg::::new()), + DataType::Int16 => AggregateFunction::SumI64(SumAgg::::new()), + DataType::UInt8 => AggregateFunction::SumI64(SumAgg::::new()), + DataType::UInt16 => AggregateFunction::SumI64(SumAgg::::new()), + // these stay true to there types + DataType::UInt32 => AggregateFunction::SumU32(SumAgg::::new()), + DataType::UInt64 => AggregateFunction::SumU64(SumAgg::::new()), + DataType::Int32 => AggregateFunction::SumI32(SumAgg::::new()), + DataType::Int64 => AggregateFunction::SumI64(SumAgg::::new()), + DataType::Float32 => AggregateFunction::SumF32(SumAgg::::new()), + DataType::Float64 => AggregateFunction::SumF64(SumAgg::::new()), + dt => AggregateFunction::Null(NullAgg::new(dt)), + }; + (logical_dtype, phys_expr, agg_fn) + }, + IRAggExpr::Mean(input) => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + #[cfg(feature = "dtype-categorical")] + if matches!( + logical_dtype, + DataType::Categorical(_, _) | DataType::Enum(_, _) | DataType::Date + ) { + return ( + logical_dtype.clone(), + phys_expr, + AggregateFunction::Null(NullAgg::new(logical_dtype)), + ); + } + let agg_fn = match logical_dtype.to_physical() { + dt if dt.is_integer() | dt.is_bool() => { + AggregateFunction::MeanF64(MeanAgg::::new()) + }, + DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::::new()), + DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::::new()), + dt => AggregateFunction::Null(NullAgg::new(dt)), + }; + (logical_dtype, phys_expr, agg_fn) + }, + IRAggExpr::First(input) => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + ( + logical_dtype.clone(), + phys_expr, + AggregateFunction::First(FirstAgg::new(logical_dtype.to_physical())), + ) + }, + IRAggExpr::Last(input) => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + ( + logical_dtype.clone(), + phys_expr, + AggregateFunction::Last(LastAgg::new(logical_dtype.to_physical())), + ) + }, + IRAggExpr::Count(input, _) => { + let phys_expr = + to_physical(&ExprIR::from_node(*input, expr_arena), expr_arena, schema) + .unwrap(); + let logical_dtype = phys_expr.field(schema).unwrap().dtype; + ( + logical_dtype, + phys_expr, + AggregateFunction::Count(CountAgg::new()), + ) + }, + agg => panic!("{agg:?} not yet implemented."), + }, + _ => todo!(), + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs new file mode 100644 index 000000000000..92b3a3f56a0f --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs @@ -0,0 +1,57 @@ +use std::any::Any; + +use polars_core::datatypes::{AnyValue, DataType}; +use polars_core::prelude::{IDX_DTYPE, Series}; + +use super::*; +use crate::operators::IdxSize; + +pub(crate) struct CountAgg { + count: IdxSize, +} + +impl CountAgg { + pub(crate) fn new() -> Self { + CountAgg { count: 0 } + } +} + +impl AggregateFn for CountAgg { + fn has_physical_agg(&self) -> bool { + false + } + + fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked() }; + if INCLUDE_NULL { + self.count += 1; + } else { + self.count += !matches!(item, AnyValue::Null) as IdxSize; + } + } + fn pre_agg_ordered( + &mut self, + _chunk_idx: IdxSize, + _offset: IdxSize, + length: IdxSize, + _values: &Series, + ) { + self.count += length + } + + fn dtype(&self) -> DataType { + IDX_DTYPE + } + + fn combine(&mut self, other: &dyn Any) { + let other = unsafe { other.downcast_ref::().unwrap_unchecked() }; + self.count += other.count; + } + + fn finalize(&mut self) -> AnyValue<'static> { + AnyValue::from(self.count) + } + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs new file mode 100644 index 000000000000..1ec3711568ca --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs @@ -0,0 +1,64 @@ +use std::any::Any; + +use polars_core::datatypes::DataType; +use polars_core::prelude::{AnyValue, Series}; + +use crate::executors::sinks::group_by::aggregates::AggregateFn; +use crate::operators::IdxSize; + +pub(crate) struct FirstAgg { + chunk_idx: IdxSize, + first: Option>, + pub(crate) dtype: DataType, +} + +impl FirstAgg { + pub(crate) fn new(dtype: DataType) -> Self { + Self { + chunk_idx: IdxSize::MAX, + first: None, + dtype, + } + } +} + +impl AggregateFn for FirstAgg { + fn pre_agg(&mut self, chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked() }; + if self.first.is_none() { + self.chunk_idx = chunk_idx; + self.first = Some(item.into_static()) + } + } + fn pre_agg_ordered( + &mut self, + chunk_idx: IdxSize, + offset: IdxSize, + _length: IdxSize, + values: &Series, + ) { + if self.first.is_none() { + self.chunk_idx = chunk_idx; + self.first = Some(unsafe { values.get_unchecked(offset as usize) }.into_static()) + } + } + + fn dtype(&self) -> DataType { + self.dtype.clone() + } + + fn combine(&mut self, other: &dyn Any) { + let other = unsafe { other.downcast_ref::().unwrap_unchecked() }; + if other.first.is_some() && other.chunk_idx < self.chunk_idx { + self.first.clone_from(&other.first); + self.chunk_idx = other.chunk_idx; + }; + } + + fn finalize(&mut self) -> AnyValue<'static> { + std::mem::take(&mut self.first).unwrap_or(AnyValue::Null) + } + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs new file mode 100644 index 000000000000..a05469811401 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs @@ -0,0 +1,101 @@ +use std::any::Any; + +use enum_dispatch::enum_dispatch; +use num_traits::NumCast; +use polars_core::datatypes::DataType; +use polars_core::prelude::{AnyValue, Series}; + +use crate::executors::sinks::group_by::aggregates::SumAgg; +use crate::executors::sinks::group_by::aggregates::count::CountAgg; +use crate::executors::sinks::group_by::aggregates::first::FirstAgg; +use crate::executors::sinks::group_by::aggregates::last::LastAgg; +use crate::executors::sinks::group_by::aggregates::mean::MeanAgg; +use crate::executors::sinks::group_by::aggregates::min_max::MinMaxAgg; +use crate::executors::sinks::group_by::aggregates::null::NullAgg; +use crate::operators::IdxSize; + +#[enum_dispatch(AggregateFunction)] +pub(crate) trait AggregateFn: Send + Sync { + fn has_physical_agg(&self) -> bool { + false + } + fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator); + fn pre_agg_ordered( + &mut self, + _chunk_idx: IdxSize, + offset: IdxSize, + length: IdxSize, + values: &Series, + ); + fn pre_agg_primitive(&mut self, _chunk_idx: IdxSize, _item: Option) { + unimplemented!() + } + + fn dtype(&self) -> DataType; + + fn combine(&mut self, other: &dyn Any); + + fn finalize(&mut self) -> AnyValue<'static>; + + fn as_any(&self) -> &dyn Any; +} + +// We dispatch via an enum +// as that saves an indirection +#[enum_dispatch] +pub(crate) enum AggregateFunction { + First(FirstAgg), + Last(LastAgg), + Count(CountAgg), + Len(CountAgg), + SumF32(SumAgg), + SumF64(SumAgg), + SumU32(SumAgg), + SumU64(SumAgg), + SumI32(SumAgg), + SumI64(SumAgg), + MeanF32(MeanAgg), + MeanF64(MeanAgg), + Null(NullAgg), + MinMaxF32(MinMaxAgg f32>), + MinMaxF64(MinMaxAgg f64>), + MinMaxU8(MinMaxAgg u8>), + MinMaxU16(MinMaxAgg u16>), + MinMaxU32(MinMaxAgg u32>), + MinMaxU64(MinMaxAgg u64>), + MinMaxI8(MinMaxAgg i8>), + MinMaxI16(MinMaxAgg i16>), + MinMaxI32(MinMaxAgg i32>), + MinMaxI64(MinMaxAgg i64>), +} + +impl AggregateFunction { + pub(crate) fn split(&self) -> Self { + use AggregateFunction::*; + match self { + First(agg) => First(FirstAgg::new(agg.dtype.clone())), + Last(agg) => Last(LastAgg::new(agg.dtype.clone())), + SumF32(_) => SumF32(SumAgg::new()), + SumF64(_) => SumF64(SumAgg::new()), + SumU32(_) => SumU32(SumAgg::new()), + SumU64(_) => SumU64(SumAgg::new()), + SumI32(_) => SumI32(SumAgg::new()), + SumI64(_) => SumI64(SumAgg::new()), + MeanF32(_) => MeanF32(MeanAgg::new()), + MeanF64(_) => MeanF64(MeanAgg::new()), + Count(_) => Count(CountAgg::new()), + Len(_) => Len(CountAgg::new()), + Null(a) => Null(a.clone()), + MinMaxF32(inner) => MinMaxF32(inner.split()), + MinMaxF64(inner) => MinMaxF64(inner.split()), + MinMaxU8(inner) => MinMaxU8(inner.split()), + MinMaxU16(inner) => MinMaxU16(inner.split()), + MinMaxU32(inner) => MinMaxU32(inner.split()), + MinMaxU64(inner) => MinMaxU64(inner.split()), + MinMaxI8(inner) => MinMaxI8(inner.split()), + MinMaxI16(inner) => MinMaxI16(inner.split()), + MinMaxI32(inner) => MinMaxI32(inner.split()), + MinMaxI64(inner) => MinMaxI64(inner.split()), + } + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs new file mode 100644 index 000000000000..0febeec611e3 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs @@ -0,0 +1,61 @@ +use std::any::Any; + +use polars_core::datatypes::DataType; +use polars_core::prelude::{AnyValue, Series}; + +use crate::executors::sinks::group_by::aggregates::AggregateFn; +use crate::operators::IdxSize; + +pub(crate) struct LastAgg { + chunk_idx: IdxSize, + last: Option>, + pub(crate) dtype: DataType, +} + +impl LastAgg { + pub(crate) fn new(dtype: DataType) -> Self { + Self { + chunk_idx: 0, + last: None, + dtype, + } + } +} + +impl AggregateFn for LastAgg { + fn pre_agg(&mut self, chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked() }; + self.chunk_idx = chunk_idx; + self.last = Some(item.into_static()); + } + fn pre_agg_ordered( + &mut self, + chunk_idx: IdxSize, + offset: IdxSize, + length: IdxSize, + values: &Series, + ) { + self.chunk_idx = chunk_idx; + self.last = + Some(unsafe { values.get_unchecked((offset + length - 1) as usize) }.into_static()) + } + + fn dtype(&self) -> DataType { + self.dtype.clone() + } + + fn combine(&mut self, other: &dyn Any) { + let other = unsafe { other.downcast_ref::().unwrap_unchecked() }; + if other.last.is_some() && other.chunk_idx >= self.chunk_idx { + self.last.clone_from(&other.last); + self.chunk_idx = other.chunk_idx; + }; + } + + fn finalize(&mut self) -> AnyValue<'static> { + std::mem::take(&mut self.last).unwrap_or(AnyValue::Null) + } + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs new file mode 100644 index 000000000000..6c93618d4c16 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs @@ -0,0 +1,134 @@ +use std::any::Any; +use std::ops::Add; + +use arrow::array::{Array, PrimitiveArray}; +use arrow::datatypes::PrimitiveType; +use num_traits::NumCast; +use polars_compute::sum::{WrappingSum, wrapping_sum_arr}; +use polars_core::prelude::*; + +use super::*; + +pub struct MeanAgg { + sum: Option, + count: IdxSize, +} + +impl MeanAgg { + pub(crate) fn new() -> Self { + MeanAgg { + sum: None, + count: 0, + } + } +} + +impl AggregateFn for MeanAgg +where + K::PolarsType: PolarsNumericType, + K: NumericNative + Add + WrappingSum, +{ + fn has_physical_agg(&self) -> bool { + true + } + + fn pre_agg_primitive(&mut self, _chunk_idx: IdxSize, item: Option) { + match (item.map(|v| K::from(v).unwrap()), self.sum) { + (Some(val), Some(sum)) => { + self.sum = Some(sum + val); + self.count += 1; + }, + (Some(val), None) => { + self.sum = Some(val); + self.count += 1; + }, + _ => {}, + } + } + + fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked() }; + match (item.extract::(), self.sum) { + (Some(val), Some(sum)) => { + self.sum = Some(sum + val); + self.count += 1; + }, + (Some(val), None) => { + self.sum = Some(val); + self.count += 1; + }, + _ => {}, + } + } + + fn pre_agg_ordered( + &mut self, + _chunk_idx: IdxSize, + offset: IdxSize, + length: IdxSize, + values: &Series, + ) { + // we must cast because mean might be a different dtype + let arr = unsafe { + let arr = values.chunks().get_unchecked(0); + arr.sliced_unchecked(offset as usize, length as usize) + }; + let dtype = K::PolarsType::get_dtype().to_arrow(CompatLevel::newest()); + let arr = polars_compute::cast::cast_unchecked(arr.as_ref(), &dtype).unwrap(); + let arr = unsafe { + arr.as_any() + .downcast_ref::>() + .unwrap_unchecked() + }; + match (wrapping_sum_arr(arr), self.sum) { + (val, Some(sum)) => { + self.sum = Some(sum + val); + self.count += (arr.len() - arr.null_count()) as IdxSize; + }, + (val, None) => { + self.sum = Some(val); + self.count += (arr.len() - arr.null_count()) as IdxSize; + }, + } + } + + fn dtype(&self) -> DataType { + DataType::from_arrow_dtype(&ArrowDataType::from(K::PRIMITIVE)) + } + + fn combine(&mut self, other: &dyn Any) { + let other = unsafe { other.downcast_ref::().unwrap_unchecked() }; + match (self.sum, other.sum) { + (Some(lhs), Some(rhs)) => { + self.sum = Some(lhs + rhs); + self.count += other.count; + }, + (None, Some(rhs)) => { + self.sum = Some(rhs); + self.count = other.count; + }, + _ => {}, + }; + } + + fn finalize(&mut self) -> AnyValue<'static> { + if let Some(val) = self.sum { + unsafe { + match K::PRIMITIVE { + PrimitiveType::Float32 => { + AnyValue::Float32(val.to_f32().unwrap_unchecked() / self.count as f32) + }, + PrimitiveType::Float64 => { + AnyValue::Float64(val.to_f64().unwrap_unchecked() / self.count as f64) + }, + _ => todo!(), + } + } + } else { + AnyValue::Null + } + } + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs new file mode 100644 index 000000000000..ef97d3ea0017 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs @@ -0,0 +1,106 @@ +use std::any::Any; + +use arrow::array::PrimitiveArray; +use num_traits::NumCast; +use polars_compute::min_max::MinMaxKernel; +use polars_core::prelude::*; +use polars_utils::min_max::MinMax; + +use super::*; + +pub(super) fn new_min() -> MinMaxAgg K> { + MinMaxAgg::new(MinMax::min_ignore_nan, true) +} + +pub(super) fn new_max() -> MinMaxAgg K> { + MinMaxAgg::new(MinMax::max_ignore_nan, false) +} + +pub struct MinMaxAgg { + agg: Option, + agg_fn: F, + is_min: bool, +} + +impl K + Copy> MinMaxAgg { + pub(crate) fn new(f: F, is_min: bool) -> Self { + MinMaxAgg { + agg: None, + agg_fn: f, + is_min, + } + } + + pub(crate) fn split(&self) -> Self { + MinMaxAgg { + agg: None, + agg_fn: self.agg_fn, + is_min: self.is_min, + } + } +} + +impl K + Send + Sync + 'static> AggregateFn for MinMaxAgg +where + K: NumericNative, + PrimitiveArray: for<'a> MinMaxKernel = K>, +{ + fn has_physical_agg(&self) -> bool { + true + } + + fn pre_agg(&mut self, chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked() }; + self.pre_agg_primitive(chunk_idx, item.extract::()) + } + + fn pre_agg_primitive(&mut self, _chunk_idx: IdxSize, item: Option) { + match (item.map(|v| K::from(v).unwrap()), self.agg) { + (Some(val), Some(current_agg)) => { + self.agg = Some((self.agg_fn)(current_agg, val)); + }, + (Some(val), None) => self.agg = Some(val), + (None, _) => {}, + } + } + + fn pre_agg_ordered( + &mut self, + _chunk_idx: IdxSize, + offset: IdxSize, + length: IdxSize, + values: &Series, + ) { + let ca: &ChunkedArray = values.as_ref().as_ref(); + let arr = ca.downcast_iter().next().unwrap(); + let arr = unsafe { arr.slice_typed_unchecked(offset as usize, length as usize) }; + // convince the compiler that K::POLARSTYPE::Native == K + let arr = unsafe { std::mem::transmute::, PrimitiveArray>(arr) }; + let agg = if self.is_min { + arr.min_ignore_nan_kernel() + } else { + arr.max_ignore_nan_kernel() + }; + self.pre_agg_primitive(0, agg) + } + + fn dtype(&self) -> DataType { + DataType::from_arrow_dtype(&ArrowDataType::from(K::PRIMITIVE)) + } + + fn combine(&mut self, other: &dyn Any) { + let other = unsafe { other.downcast_ref::().unwrap_unchecked() }; + self.pre_agg_primitive(0, other.agg) + } + + fn finalize(&mut self) -> AnyValue<'static> { + if let Some(val) = self.agg { + val.into() + } else { + AnyValue::Null + } + } + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mod.rs new file mode 100644 index 000000000000..939a6edc451a --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mod.rs @@ -0,0 +1,13 @@ +mod convert; +mod count; +mod first; +mod interface; +mod last; +mod mean; +mod min_max; +mod null; +mod sum; + +pub use convert::*; +pub(crate) use interface::{AggregateFn, AggregateFunction}; +pub(crate) use sum::SumAgg; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/null.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/null.rs new file mode 100644 index 000000000000..768bcde96947 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/null.rs @@ -0,0 +1,45 @@ +use std::any::Any; + +use polars_core::prelude::*; + +use crate::executors::sinks::group_by::aggregates::AggregateFn; + +#[derive(Clone)] +pub struct NullAgg(DataType); + +impl NullAgg { + pub(crate) fn new(dt: DataType) -> Self { + Self(dt) + } +} + +impl AggregateFn for NullAgg { + fn pre_agg(&mut self, _chunk_idx: IdxSize, _item: &mut dyn ExactSizeIterator) { + // no-op + } + fn pre_agg_ordered( + &mut self, + _chunk_idx: IdxSize, + _offset: IdxSize, + _length: IdxSize, + _values: &Series, + ) { + // no-op + } + + fn dtype(&self) -> DataType { + self.0.clone() + } + + fn combine(&mut self, _other: &dyn Any) { + // no-op + } + + fn finalize(&mut self) -> AnyValue<'static> { + AnyValue::Null + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs new file mode 100644 index 000000000000..3d6836905f21 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs @@ -0,0 +1,96 @@ +use std::any::Any; +use std::ops::Add; + +use arrow::array::PrimitiveArray; +use num_traits::NumCast; +use polars_compute::sum::{WrappingSum, wrapping_sum_arr}; +use polars_core::prelude::*; + +use super::*; + +pub struct SumAgg { + sum: Option, +} + +impl SumAgg { + pub(crate) fn new() -> Self { + SumAgg { sum: None } + } +} + +impl AggregateFn for SumAgg +where + K::PolarsType: PolarsNumericType, + K: NumericNative + Add + WrappingSum, +{ + fn has_physical_agg(&self) -> bool { + true + } + + fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked() }; + self.pre_agg_primitive(0, item.extract::()) + } + fn pre_agg_primitive(&mut self, _chunk_idx: IdxSize, item: Option) { + match (item.map(|v| K::from(v).unwrap()), self.sum) { + (Some(val), Some(sum)) => self.sum = Some(sum + val), + (Some(val), None) => self.sum = Some(val), + (None, _) => {}, + } + } + + fn pre_agg_ordered( + &mut self, + _chunk_idx: IdxSize, + offset: IdxSize, + length: IdxSize, + values: &Series, + ) { + // we must cast because sum output type might be different than input type. + let arr = unsafe { + let arr = values.chunks().get_unchecked(0); + arr.sliced_unchecked(offset as usize, length as usize) + }; + let dtype = K::PolarsType::get_dtype().to_arrow(CompatLevel::newest()); + let arr = polars_compute::cast::cast_unchecked(arr.as_ref(), &dtype).unwrap(); + let arr = unsafe { + arr.as_any() + .downcast_ref::>() + .unwrap_unchecked() + }; + match (wrapping_sum_arr(arr), self.sum) { + (val, Some(sum)) => { + self.sum = Some(sum + val); + }, + (val, None) => { + self.sum = Some(val); + }, + } + } + + fn dtype(&self) -> DataType { + DataType::from_arrow_dtype(&ArrowDataType::from(K::PRIMITIVE)) + } + + fn combine(&mut self, other: &dyn Any) { + let other = unsafe { other.downcast_ref::().unwrap_unchecked() }; + let sum = match (self.sum, other.sum) { + (Some(lhs), Some(rhs)) => Some(lhs + rhs), + (Some(lhs), None) => Some(lhs), + (None, Some(rhs)) => Some(rhs), + (None, None) => None, + }; + self.sum = sum; + } + + fn finalize(&mut self) -> AnyValue<'static> { + if let Some(val) = self.sum { + val.into() + } else { + K::zero().into() + } + } + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs new file mode 100644 index 000000000000..cf89dc02a245 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs @@ -0,0 +1,119 @@ +use std::cell::UnsafeCell; + +use polars_row::{RowEncodingOptions, RowsEncoded}; +use polars_utils::aliases::PlSeedableRandomStateQuality; + +use self::row_encode::get_row_encoding_context; +use super::*; +use crate::executors::sinks::group_by::utils::prepare_key; +use crate::executors::sinks::utils::hash_rows; +use crate::expressions::PhysicalPipedExpr; + +pub(super) struct Eval { + // the keys that will be aggregated on + key_columns_expr: Arc>>, + // the columns that will be aggregated + aggregation_columns_expr: Arc>>, + hb: PlSeedableRandomStateQuality, + // amortize allocations + aggregation_series: UnsafeCell>, + keys_columns: UnsafeCell>, + hashes: Vec, + key_fields: Vec, + // amortizes the encoding buffers + rows_encoded: RowsEncoded, +} + +impl Eval { + pub(super) fn new( + key_columns: Arc>>, + aggregation_columns: Arc>>, + ) -> Self { + let hb = PlSeedableRandomStateQuality::default(); + Self { + key_columns_expr: key_columns, + aggregation_columns_expr: aggregation_columns, + hb, + aggregation_series: Default::default(), + keys_columns: Default::default(), + hashes: Default::default(), + key_fields: Default::default(), + rows_encoded: Default::default(), + } + } + pub(super) fn split(&self) -> Self { + Self { + key_columns_expr: self.key_columns_expr.clone(), + aggregation_columns_expr: self.aggregation_columns_expr.clone(), + hb: self.hb, + aggregation_series: Default::default(), + keys_columns: Default::default(), + hashes: Default::default(), + key_fields: vec![Default::default(); self.key_columns_expr.len()], + rows_encoded: Default::default(), + } + } + + pub(super) unsafe fn clear(&mut self) { + let keys_series = &mut *self.keys_columns.get(); + let aggregation_series = &mut *self.aggregation_series.get(); + keys_series.clear(); + aggregation_series.clear(); + self.hashes.clear(); + } + + pub(super) unsafe fn evaluate_keys_aggs_and_hashes( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult<()> { + let keys_columns = &mut *self.keys_columns.get(); + let aggregation_series = &mut *self.aggregation_series.get(); + + for phys_e in self.aggregation_columns_expr.iter() { + let s = phys_e.evaluate(chunk, &context.execution_state)?; + let s = s.to_physical_repr(); + aggregation_series.push(s.into_owned()); + } + let mut dicts = Vec::with_capacity(self.key_columns_expr.len()); + for phys_e in self.key_columns_expr.iter() { + let s = phys_e.evaluate(chunk, &context.execution_state)?; + dicts.push(get_row_encoding_context(s.dtype(), false)); + let s = s.to_physical_repr().into_owned(); + let s = prepare_key(&s, chunk); + keys_columns.push(s.to_arrow(0, CompatLevel::newest())); + } + + polars_row::convert_columns_amortized( + keys_columns[0].len(), // @NOTE: does not work for ZFS + keys_columns, + self.key_fields + .iter() + .copied() + .zip(dicts.iter().map(|v| v.as_ref())), + &mut self.rows_encoded, + ); + // drop the series, all data is in the rows encoding now + keys_columns.clear(); + + // write the hashes to self.hashes buffer + let keys_array = self.rows_encoded.borrow_array(); + hash_rows(&keys_array, &mut self.hashes, &self.hb); + Ok(()) + } + + /// # Safety + /// Caller must ensure `self.rows_encoded` stays alive as the lifetime + /// is bound to the returned array. + pub(super) unsafe fn get_keys_iter(&self) -> BinaryArray { + self.rows_encoded.borrow_array() + } + pub(super) unsafe fn get_aggs_iters(&self) -> Vec { + let aggregation_series = &*self.aggregation_series.get(); + aggregation_series.iter().map(|s| s.phys_iter()).collect() + } + + pub(super) fn hashes(&self) -> &[u64] { + &self.hashes + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs new file mode 100644 index 000000000000..55701f80ac2d --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs @@ -0,0 +1,226 @@ +use std::collections::LinkedList; +use std::sync::atomic::{AtomicU16, Ordering}; + +use polars_core::POOL; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use rayon::prelude::*; + +use super::*; +use crate::pipeline::{FORCE_OOC, PARTITION_SIZE}; + +struct SpillPartitions { + // outer vec: partitions (factor of 2) + partitions: PartitionVec>>, +} + +impl SpillPartitions { + fn new() -> Self { + let mut partitions = Vec::with_capacity(PARTITION_SIZE); + partitions.resize_with(PARTITION_SIZE, Default::default); + + Self { partitions } + } + + #[inline] + fn insert(&self, partition: usize, to_spill: SpillPayload) -> usize { + let partition = &self.partitions[partition]; + let mut partition = partition.lock().unwrap(); + partition.push_back(to_spill); + partition.len() + } + + fn drain_partition( + &self, + partition: usize, + min_size: usize, + ) -> Option> { + let partition = &self.partitions[partition]; + let mut partition = partition.lock().unwrap(); + if partition.len() > min_size { + Some(std::mem::take(&mut partition)) + } else { + None + } + } + + fn spill_schema(&self) -> Option { + for part in &self.partitions { + let bucket = part.lock().unwrap(); + if let Some(payload) = bucket.front() { + return Some(payload.get_schema()); + } + } + None + } +} + +pub(super) struct GlobalTable { + inner_maps: PartitionVec>>, + spill_partitions: SpillPartitions, + early_merge_counter: Arc, + // IO is expensive so we only spill if we have `N` payloads to dump. + spill_partition_ob_size: usize, +} + +impl GlobalTable { + pub(super) fn new( + agg_constructors: Arc<[AggregateFunction]>, + key_dtypes: &[DataType], + output_schema: SchemaRef, + ) -> Self { + let spill_partitions = SpillPartitions::new(); + + let spill_partition_ob_size = if std::env::var(FORCE_OOC).is_ok() { + 1 + } else { + 64 + }; + + let mut inner_maps = Vec::with_capacity(PARTITION_SIZE); + inner_maps.resize_with(PARTITION_SIZE, || { + Mutex::new(AggHashTable::new( + agg_constructors.clone(), + key_dtypes, + output_schema.clone(), + None, + )) + }); + + Self { + inner_maps, + spill_partitions, + early_merge_counter: Default::default(), + spill_partition_ob_size, + } + } + + #[inline] + pub(super) fn spill(&self, partition: usize, payload: SpillPayload) { + self.spill_partitions.insert(partition, payload); + } + + pub(super) fn early_merge(&self) { + // round robin a partition to merge early + let partition = + self.early_merge_counter.fetch_add(1, Ordering::Relaxed) as usize % PARTITION_SIZE; + self.process_partition(partition) + } + + pub(super) fn get_ooc_dump_schema(&self) -> Option { + self.spill_partitions.spill_schema() + } + + pub(super) fn get_ooc_dump(&self) -> Option<(usize, DataFrame)> { + // round robin a partition to dump + let partition = + self.early_merge_counter.fetch_add(1, Ordering::Relaxed) as usize % PARTITION_SIZE; + + // IO is expensive so we only spill if we have `N` payloads to dump. + let bucket = self + .spill_partitions + .drain_partition(partition, self.spill_partition_ob_size)?; + Some(( + partition, + accumulate_dataframes_vertical_unchecked(bucket.into_iter().map(|pl| pl.into_df())), + )) + } + + fn process_partition_impl( + &self, + hash_map: &mut AggHashTable, + hashes: &[u64], + chunk_indexes: &[IdxSize], + keys: &BinaryArray, + agg_cols: &[Column], + ) { + debug_assert_eq!(hashes.len(), chunk_indexes.len()); + debug_assert_eq!(hashes.len(), keys.len()); + + // let mut keys_iters = keys.iter().map(|s| s.phys_iter()).collect::>(); + let mut agg_cols_iters = agg_cols.iter().map(|s| s.phys_iter()).collect::>(); + + // amortize loop counter + for (i, row) in keys.values_iter().enumerate() { + unsafe { + let hash = *hashes.get_unchecked(i); + let chunk_index = *chunk_indexes.get_unchecked(i); + + // SAFETY: keys_iters and cols_iters are not depleted + let overflow = hash_map.insert(hash, row, &mut agg_cols_iters, chunk_index); + // should never overflow + debug_assert!(!overflow); + } + } + } + + pub(super) fn process_partition_from_dumped(&self, partition: usize, spilled: &DataFrame) { + let mut hash_map = self.inner_maps[partition].lock().unwrap(); + let (hashes, chunk_indexes, keys, aggs) = SpillPayload::spilled_to_columns(spilled); + self.process_partition_impl(&mut hash_map, hashes, chunk_indexes, keys, aggs); + } + + fn process_partition(&self, partition: usize) { + if let Some(bucket) = self.spill_partitions.drain_partition(partition, 0) { + let mut hash_map = self.inner_maps[partition].lock().unwrap(); + + for payload in bucket { + let hashes = payload.hashes(); + let keys = payload.keys(); + let chunk_indexes = payload.chunk_index(); + let agg_cols = payload.cols(); + + // @scalar-opt + let agg_cols = agg_cols + .iter() + .map(|v| v.clone().into_column()) + .collect::>(); + + self.process_partition_impl(&mut hash_map, hashes, chunk_indexes, keys, &agg_cols); + } + } + } + + pub(super) fn merge_local_map(&self, finalized_local_map: &AggHashTable) { + // TODO! maybe parallelize? + // needs unsafe, first benchmark. + for (partition_i, pt_map) in self.inner_maps.iter().enumerate() { + let mut pt_map = pt_map.lock().unwrap(); + pt_map.combine_on_partition(partition_i, finalized_local_map) + } + } + + pub(super) fn finalize_partition( + &self, + partition: usize, + slice: &mut Option<(i64, usize)>, + ) -> DataFrame { + // ensure all spilled partitions are processed + self.process_partition(partition); + let mut hash_map = self.inner_maps[partition].lock().unwrap(); + hash_map.finalize(slice) + } + + // only should be called if all state is in-memory + pub(super) fn finalize(&self, slice: &mut Option<(i64, usize)>) -> Vec { + if slice.is_none() { + POOL.install(|| { + (0..PARTITION_SIZE) + .into_par_iter() + .map(|part_i| { + self.process_partition(part_i); + let mut hash_map = self.inner_maps[part_i].lock().unwrap(); + hash_map.finalize(&mut None) + }) + .collect() + }) + } else { + (0..PARTITION_SIZE) + .map(|part_i| { + self.process_partition(part_i); + let mut hash_map = self.inner_maps[part_i].lock().unwrap(); + hash_map.finalize(slice) + }) + .collect() + } + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs new file mode 100644 index 000000000000..2dad78673f87 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs @@ -0,0 +1,288 @@ +use arrow::legacy::trusted_len::TrustedLenPush; +use polars_utils::hashing::hash_to_partition; + +use self::row_encode::get_row_encoding_context; +use super::*; +use crate::pipeline::PARTITION_SIZE; + +pub(super) struct AggHashTable { + inner_map: PlIdHashMap, + // row data of the keys + keys: Vec, + // the aggregation that are in process + // the index the hashtable points to the start of the aggregations of that key/group + running_aggregations: Vec, + // n aggregation function constructors + // The are used to create new running aggregators + agg_constructors: Arc<[AggregateFunction]>, + output_schema: SchemaRef, + pub num_keys: usize, + spill_size: usize, +} + +impl AggHashTable { + pub(super) fn new( + agg_constructors: Arc<[AggregateFunction]>, + key_dtypes: &[DataType], + output_schema: SchemaRef, + spill_size: Option, + ) -> Self { + assert_eq!(FIXED, spill_size.is_some()); + Self { + inner_map: Default::default(), + keys: Default::default(), + running_aggregations: Default::default(), + agg_constructors, + num_keys: key_dtypes.len(), + spill_size: spill_size.unwrap_or(usize::MAX), + output_schema, + } + } + + pub(super) fn split(&self) -> Self { + Self { + inner_map: Default::default(), + keys: Default::default(), + running_aggregations: Default::default(), + agg_constructors: self.agg_constructors.iter().map(|c| c.split()).collect(), + num_keys: self.num_keys, + spill_size: self.spill_size, + output_schema: self.output_schema.clone(), + } + } + + unsafe fn get_keys_row(&self, key: &Key) -> &[u8] { + let start = key.offset as usize; + let end = start + key.len as usize; + self.keys.get_unchecked(start..end) + } + + pub(super) fn is_empty(&self) -> bool { + self.inner_map.is_empty() + } + + fn get_entry(&mut self, hash: u64, row: &[u8]) -> RawEntryMut { + let keys = self.keys.as_ptr(); + + self.inner_map + .raw_entry_mut() + .from_hash(hash, |hash_map_key| { + // first check the hash as that has no indirection + hash_map_key.hash == hash && { + let offset = hash_map_key.offset as usize; + let len = hash_map_key.len as usize; + + unsafe { std::slice::from_raw_parts(keys.add(offset), len) == row } + } + }) + } + + fn insert_key<'a>(&'a mut self, hash: u64, row: &[u8]) -> Option { + let entry = self.get_entry(hash, row); + + match entry { + RawEntryMut::Occupied(entry) => Some(*entry.get()), + RawEntryMut::Vacant(entry) => { + // bchk shenanigans: + // it does not allow us to hold a `raw entry` and in the meantime + // have &self access to get the length of keys + // so we work with pointers instead + let borrow = &entry; + let borrow = borrow as *const RawVacantEntryMut<_, _, _> as usize; + // ensure the bck forgets this guy + #[allow(clippy::forget_non_drop)] + std::mem::forget(entry); + + // OVERFLOW logic + if FIXED && self.inner_map.len() > self.spill_size { + unsafe { + // take a hold of the entry again and ensure it gets dropped + let borrow = + borrow as *const RawVacantEntryMut<'a, Key, u32, IdBuildHasher>; + let _entry = std::ptr::read(borrow); + } + return None; + } + + let aggregation_idx = self.running_aggregations.len() as u32; + let key_offset = self.keys.len() as u32; + let key_len = row.len() as u32; + let key = Key::new(hash, key_offset, key_len); + + unsafe { + // take a hold of the entry again and ensure it gets dropped + let borrow = borrow as *const RawVacantEntryMut<'a, Key, u32, IdBuildHasher>; + let entry = std::ptr::read(borrow); + entry.insert(key, aggregation_idx); + } + + for agg in self.agg_constructors.as_ref() { + self.running_aggregations.push(agg.split()) + } + + self.keys.extend_from_slice(row); + Some(aggregation_idx) + }, + } + } + + /// # Safety + /// Caller must ensure that `keys` and `agg_iters` are not depleted. + /// # Returns &keys + pub(super) unsafe fn insert( + &mut self, + hash: u64, + key: &[u8], + agg_iters: &mut [SeriesPhysIter], + chunk_index: IdxSize, + ) -> bool { + let agg_idx = match self.insert_key(hash, key) { + // overflow + None => return true, + Some(agg_idx) => agg_idx, + }; + + // apply the aggregation + for (i, agg_iter) in agg_iters.iter_mut().enumerate() { + let i = agg_idx as usize + i; + let agg_fn = unsafe { self.running_aggregations.get_unchecked_mut(i) }; + + agg_fn.pre_agg(chunk_index, agg_iter.as_mut()) + } + // no overflow + false + } + + pub(super) fn combine(&mut self, other: &Self) { + self.combine_impl(other, |_hash| true) + } + + pub(super) fn combine_on_partition( + &mut self, + partition: usize, + other: &AggHashTable, + ) { + self.combine_impl(other, |hash| { + partition == hash_to_partition(hash, PARTITION_SIZE) + }) + } + + pub(super) fn combine_impl( + &mut self, + other: &AggHashTable, + on_condition: C, + ) + // takes a hash and if true, this keys will be combined + where + C: Fn(u64) -> bool, + { + let spill_size = self.spill_size; + self.spill_size = usize::MAX; + for (key_other, agg_idx_other) in other.inner_map.iter() { + // SAFETY: idx is from the hashmap, so is in bounds + let row = unsafe { other.get_keys_row(key_other) }; + + if on_condition(key_other.hash) { + // SAFETY: will not overflow as we set it to usize::MAX; + let agg_idx_self = + unsafe { self.insert_key(key_other.hash, row).unwrap_unchecked() }; + let start = *agg_idx_other as usize; + let end = start + self.agg_constructors.len(); + let aggs_other = unsafe { other.running_aggregations.get_unchecked(start..end) }; + let start = agg_idx_self as usize; + let end = start + self.agg_constructors.len(); + let aggs_self = unsafe { self.running_aggregations.get_unchecked_mut(start..end) }; + for i in 0..aggs_self.len() { + unsafe { + let agg_self = aggs_self.get_unchecked_mut(i); + let other = aggs_other.get_unchecked(i); + // TODO!: try transmutes + agg_self.combine(other.as_any()) + } + } + } + } + self.spill_size = spill_size; + } + + pub(super) fn finalize(&mut self, slice: &mut Option<(i64, usize)>) -> DataFrame { + let local_len = self.inner_map.len(); + let (skip_len, take_len) = if let Some((offset, slice_len)) = slice { + if *offset as usize >= local_len { + *offset -= local_len as i64; + return DataFrame::empty_with_schema(&self.output_schema); + } else { + let out = (*offset as usize, *slice_len); + *offset = 0; + *slice_len = slice_len.saturating_sub(local_len); + out + } + } else { + (0, local_len) + }; + let inner_map = std::mem::take(&mut self.inner_map); + let mut running_aggregations = std::mem::take(&mut self.running_aggregations); + + let mut agg_builders = self + .agg_constructors + .iter() + .map(|ac| AnyValueBufferTrusted::new(&ac.dtype(), take_len)) + .collect::>(); + let num_aggs = self.agg_constructors.len(); + let mut key_rows = Vec::with_capacity(take_len); + + inner_map + .into_iter() + .skip(skip_len) + .take(take_len) + .for_each(|(k, agg_offset)| { + unsafe { + key_rows.push_unchecked(self.get_keys_row(&k)); + } + + let start = agg_offset as usize; + let end = start + num_aggs; + for (i, buffer) in (start..end).zip(agg_builders.iter_mut()) { + unsafe { + let running_agg = running_aggregations.get_unchecked_mut(i); + let av = running_agg.finalize(); + // SAFETY: finalize creates owned AnyValues + buffer.add_unchecked_owned_physical(&av); + } + } + }); + + let key_dtypes = self + .output_schema + .iter_values() + .take(self.num_keys) + .map(|dtype| dtype.to_physical().to_arrow(CompatLevel::newest())) + .collect::>(); + let dicts = self + .output_schema + .iter_values() + .take(self.num_keys) + .map(|dt| get_row_encoding_context(dt, false)) + .collect::>(); + let fields = vec![Default::default(); self.num_keys]; + let key_columns = + unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &dicts, &key_dtypes) }; + + let mut cols = Vec::with_capacity(self.num_keys + self.agg_constructors.len()); + cols.extend(key_columns.into_iter().map(|arr| { + Series::try_from((PlSmallStr::EMPTY, arr)) + .unwrap() + .into_column() + })); + cols.extend( + agg_builders + .into_iter() + .map(|buf| buf.into_series().into_column()), + ); + physical_agg_to_logical(&mut cols, &self.output_schema); + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +unsafe impl Send for AggHashTable {} +unsafe impl Sync for AggHashTable {} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs new file mode 100644 index 000000000000..0d2a17ccb211 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs @@ -0,0 +1,128 @@ +mod eval; +mod global; +mod hash_table; +mod ooc_state; +mod sink; +mod source; +mod thread_local; + +use std::any::Any; +use std::hash::{Hash, Hasher}; +use std::sync::Mutex; + +use arrow::array::BinaryArray; +use eval::Eval; +use hash_table::AggHashTable; +use hashbrown::hash_map::{RawEntryMut, RawVacantEntryMut}; +use polars_core::IdBuildHasher; +use polars_core::frame::row::AnyValueBufferTrusted; +use polars_core::series::SeriesPhysIter; +pub(crate) use sink::GenericGroupby2; +use thread_local::ThreadLocalTable; + +use super::*; +use crate::executors::sinks::group_by::aggregates::{AggregateFn, AggregateFunction}; +use crate::executors::sinks::io::IOThread; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; + +type PartitionVec = Vec; +type IOThreadRef = Arc>>; + +#[derive(Clone)] +struct SpillPayload { + hashes: Vec, + chunk_idx: Vec, + keys: BinaryArray, + aggs: Vec, +} + +static HASH_COL: &str = "__POLARS_h"; +static INDEX_COL: &str = "__POLARS_idx"; +static KEYS_COL: &str = "__POLARS_keys"; + +impl SpillPayload { + fn hashes(&self) -> &[u64] { + &self.hashes + } + + fn keys(&self) -> &BinaryArray { + &self.keys + } + + fn cols(&self) -> &[Series] { + &self.aggs + } + + fn chunk_index(&self) -> &[IdxSize] { + &self.chunk_idx + } + + fn get_schema(&self) -> Schema { + let mut schema = Schema::with_capacity(self.aggs.len() + 2); + schema.with_column(HASH_COL.into(), DataType::UInt64); + schema.with_column(INDEX_COL.into(), IDX_DTYPE); + schema.with_column(KEYS_COL.into(), DataType::BinaryOffset); + for s in &self.aggs { + schema.with_column(s.name().clone(), s.dtype().clone()); + } + schema + } + + fn into_df(self) -> DataFrame { + debug_assert_eq!(self.hashes.len(), self.chunk_idx.len()); + debug_assert_eq!(self.hashes.len(), self.keys.len()); + + let height = self.hashes.len(); + + let hashes = + UInt64Chunked::from_vec(PlSmallStr::from_static(HASH_COL), self.hashes).into_column(); + let chunk_idx = + IdxCa::from_vec(PlSmallStr::from_static(INDEX_COL), self.chunk_idx).into_column(); + let keys = BinaryOffsetChunked::with_chunk(PlSmallStr::from_static(KEYS_COL), self.keys) + .into_column(); + + let mut cols = Vec::with_capacity(self.aggs.len() + 3); + cols.push(hashes); + cols.push(chunk_idx); + cols.push(keys); + // @scalar-opt + cols.extend(self.aggs.into_iter().map(Column::from)); + unsafe { DataFrame::new_no_checks(height, cols) } + } + + fn spilled_to_columns( + spilled: &DataFrame, + ) -> (&[u64], &[IdxSize], &BinaryArray, &[Column]) { + let cols = spilled.get_columns(); + let hashes = cols[0].u64().unwrap(); + let hashes = hashes.cont_slice().unwrap(); + let chunk_indexes = cols[1].idx().unwrap(); + let chunk_indexes = chunk_indexes.cont_slice().unwrap(); + let keys = cols[2].binary_offset().unwrap(); + let keys = keys.downcast_iter().next().unwrap(); + let aggs = &cols[3..]; + (hashes, chunk_indexes, keys, aggs) + } +} + +// This is the hash and the Index offset in the linear buffer +#[derive(Copy, Clone)] +pub(super) struct Key { + pub(super) hash: u64, + pub(super) offset: u32, + pub(super) len: u32, +} + +impl Key { + #[inline] + pub(super) fn new(hash: u64, offset: u32, len: u32) -> Self { + Self { hash, offset, len } + } +} + +impl Hash for Key { + #[inline] + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs new file mode 100644 index 000000000000..1f7d33998af3 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs @@ -0,0 +1,99 @@ +use polars_core::config::verbose; + +use super::*; +use crate::executors::sinks::memory::MemTracker; +use crate::pipeline::{FORCE_OOC, morsels_per_sink}; + +#[derive(Clone)] +pub(super) struct OocState { + // OOC + // Stores available memory in the system at the start of this sink. + // and stores the memory used by this sink. + mem_track: MemTracker, + // sort in-memory or out-of-core + pub(super) ooc: bool, + // when ooc, we write to disk using an IO thread + pub(super) io_thread: IOThreadRef, + count: u16, + to_disk_threshold: f64, +} + +impl Default for OocState { + fn default() -> Self { + let to_disk_threshold = if std::env::var(FORCE_OOC).is_ok() { + 1.0 + } else { + TO_DISK_THRESHOLD + }; + + Self { + mem_track: MemTracker::new(morsels_per_sink()), + ooc: false, + io_thread: Default::default(), + count: 0, + to_disk_threshold, + } + } +} + +// If this is reached we early merge the overflow buckets +// to free up memory +const EARLY_MERGE_THRESHOLD: f64 = 0.5; +// If this is reached we spill to disk and +// aggregate in a second run +const TO_DISK_THRESHOLD: f64 = 0.3; + +pub(super) enum SpillAction { + EarlyMerge, + Dump, + None, +} + +impl OocState { + fn init_ooc(&mut self, spill_schema: Schema) -> PolarsResult<()> { + if verbose() { + eprintln!("OOC group_by started"); + } + self.ooc = true; + + // start IO thread + let mut iot = self.io_thread.lock().unwrap(); + if iot.is_none() { + *iot = Some(IOThread::try_new(Arc::new(spill_schema), "group_by").unwrap()); + } + Ok(()) + } + + pub(super) fn check_memory_usage( + &mut self, + spill_schema: &dyn Fn() -> Option, + ) -> PolarsResult { + if self.ooc { + return Ok(SpillAction::Dump); + } + let free_frac = self.mem_track.free_memory_fraction_since_start(); + self.count += 1; + + if free_frac < self.to_disk_threshold { + if let Some(schema) = spill_schema() { + self.init_ooc(schema)?; + Ok(SpillAction::Dump) + } else { + Ok(SpillAction::None) + } + } else if free_frac < EARLY_MERGE_THRESHOLD + // clean up some spills + || (self.count % 512) == 0 + { + Ok(SpillAction::EarlyMerge) + } else { + Ok(SpillAction::None) + } + } + + pub(super) fn dump(&self, partition_no: usize, df: DataFrame) { + let iot = self.io_thread.lock().unwrap(); + let iot = iot.as_ref().unwrap(); + iot.dump_partition(partition_no as IdxSize, df) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs new file mode 100644 index 000000000000..50a68cd34d27 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs @@ -0,0 +1,184 @@ +use std::cell::UnsafeCell; + +use polars_core::utils::accumulate_dataframes_vertical_unchecked; + +use super::*; +use crate::executors::sinks::group_by::generic::global::GlobalTable; +use crate::executors::sinks::group_by::generic::ooc_state::{OocState, SpillAction}; +use crate::executors::sinks::group_by::generic::source::GroupBySource; +use crate::executors::sources::DataFrameSource; +use crate::expressions::PhysicalPipedExpr; + +pub(crate) struct GenericGroupby2 { + thread_local_table: UnsafeCell, + global_table: Arc, + eval: Eval, + slice: Option<(i64, usize)>, + ooc_state: OocState, +} + +impl GenericGroupby2 { + pub(crate) fn new( + key_columns: Arc>>, + aggregation_columns: Arc>>, + agg_constructors: Arc<[AggregateFunction]>, + output_schema: SchemaRef, + agg_input_dtypes: Vec, + slice: Option<(i64, usize)>, + ) -> Self { + let key_dtypes: Arc<[DataType]> = Arc::from( + output_schema + .iter_values() + .take(key_columns.len()) + .cloned() + .collect::>(), + ); + + let agg_dtypes: Arc<[DataType]> = Arc::from(agg_input_dtypes); + + let global_map = GlobalTable::new( + agg_constructors.clone(), + key_dtypes.as_ref(), + output_schema.clone(), + ); + + Self { + thread_local_table: UnsafeCell::new(ThreadLocalTable::new( + agg_constructors, + key_dtypes, + agg_dtypes, + output_schema, + )), + global_table: Arc::new(global_map), + eval: Eval::new(key_columns, aggregation_columns), + slice, + ooc_state: Default::default(), + } + } +} + +impl Sink for GenericGroupby2 { + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + if chunk.is_empty() { + return Ok(SinkResult::CanHaveMoreInput); + } + // load data and hashes + unsafe { + // SAFETY: we don't hold mutable refs + self.eval.evaluate_keys_aggs_and_hashes(context, &chunk)?; + } + // SAFETY: eval is alive for the duration of keys + let keys = unsafe { self.eval.get_keys_iter() }; + // SAFETY: we don't hold mutable refs + let mut aggs = unsafe { self.eval.get_aggs_iters() }; + + let chunk_idx = chunk.chunk_index; + unsafe { + // SAFETY: the mutable borrows are not aliasing + let table = &mut *self.thread_local_table.get(); + + for (hash, row) in self.eval.hashes().iter().zip(keys.values_iter()) { + if let Some((partition, spill_payload)) = + table.insert(*hash, row, &mut aggs, chunk_idx) + { + self.global_table.spill(partition, spill_payload) + } + } + } + + // clear memory + unsafe { + drop(aggs); + // SAFETY: we don't hold mutable refs, we just dropped them + self.eval.clear() + }; + + // indicates if we should early merge a partition + // other scenario could be that we must spill to disk + match self + .ooc_state + .check_memory_usage(&|| self.global_table.get_ooc_dump_schema())? + { + SpillAction::None => {}, + SpillAction::EarlyMerge => self.global_table.early_merge(), + SpillAction::Dump => { + if let Some((partition_no, spill)) = self.global_table.get_ooc_dump() { + self.ooc_state.dump(partition_no, spill) + } else { + // do nothing + } + }, + } + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + let other = other.as_any().downcast_mut::().unwrap(); + unsafe { + let map = &mut *self.thread_local_table.get(); + let other_map = &mut *other.thread_local_table.get(); + map.combine(other_map); + } + } + + fn split(&self, _thread_no: usize) -> Box { + // SAFETY: no mutable refs at this point + let map = unsafe { (*self.thread_local_table.get()).split() }; + Box::new(Self { + eval: self.eval.split(), + thread_local_table: UnsafeCell::new(map), + global_table: self.global_table.clone(), + slice: self.slice, + ooc_state: self.ooc_state.clone(), + }) + } + + fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { + let map = unsafe { &mut *self.thread_local_table.get() }; + + // only succeeds if it hasn't spilled to global + if let Some(out) = map.finalize(&mut self.slice) { + if context.verbose { + eprintln!("finish streaming aggregation with local in-memory table") + } + Ok(FinalizedSink::Finished(out)) + } else { + // ensure the global map gets all overflow buckets + for (partition, payload) in map.get_all_spilled() { + self.global_table.spill(partition, payload); + } + // ensure the global map update the partitioned hash tables with keys from local map + self.global_table.merge_local_map(map.get_inner_map_mut()); + + // all data is in memory + // finalize + if !self.ooc_state.ooc { + if context.verbose { + eprintln!("finish streaming aggregation with global in-memory table") + } + + let out = self.global_table.finalize(&mut self.slice); + let src = DataFrameSource::from_df(accumulate_dataframes_vertical_unchecked(out)); + Ok(FinalizedSink::Source(Box::new(src))) + } + // create an ooc source + else { + Ok(FinalizedSink::Source(Box::new(GroupBySource::new( + &self.ooc_state.io_thread, + self.slice, + self.global_table.clone(), + )?))) + } + } + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "generic-group_by" + } +} + +unsafe impl Sync for GenericGroupby2 {} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs new file mode 100644 index 000000000000..a0882f45764a --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs @@ -0,0 +1,94 @@ +use polars_core::utils::flatten::flatten_df_iter; +use polars_io::SerReader; +use polars_io::ipc::IpcReader; + +use super::*; +use crate::executors::sinks::group_by::generic::global::GlobalTable; +use crate::executors::sinks::io::block_thread_until_io_thread_done; +use crate::operators::{Source, SourceResult}; +use crate::pipeline::PARTITION_SIZE; + +pub(super) struct GroupBySource { + // holding this keeps the lockfile in place + _io_thread: IOThread, + global_table: Arc, + slice: Option<(i64, usize)>, + partition_processed: usize, +} + +impl GroupBySource { + pub(super) fn new( + io_thread: &IOThreadRef, + slice: Option<(i64, usize)>, + global_table: Arc, + ) -> PolarsResult { + let mut io_thread = io_thread.lock().unwrap(); + let io_thread = io_thread.take().unwrap(); + + if let Some(slice) = slice { + polars_ensure!(slice.0 >= 0, ComputeError: "negative slice not supported with out-of-core group_by") + } + + block_thread_until_io_thread_done(&io_thread); + Ok(Self { + _io_thread: io_thread, + slice, + global_table, + partition_processed: 0, + }) + } +} + +impl Source for GroupBySource { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { + if self.slice == Some((0, 0)) { + return Ok(SourceResult::Finished); + } + + let partition = self.partition_processed; + self.partition_processed += 1; + + if partition >= PARTITION_SIZE { + return Ok(SourceResult::Finished); + } + let mut partition_dir = self._io_thread.dir.clone(); + partition_dir.push(format!("{partition}")); + + if context.verbose { + eprintln!("process partition {partition} during {}", self.fmt()) + } + + // merge the dumped tables + // if no tables are spilled we simply skip + // this and finalize the in memory state + if partition_dir.exists() { + for file in std::fs::read_dir(partition_dir).expect("should be there") { + let spilled = file.unwrap().path(); + let file = polars_utils::open_file(&spilled)?; + let reader = IpcReader::new(file); + let spilled = reader.finish().unwrap(); + if spilled.first_col_n_chunks() > 1 { + for spilled in flatten_df_iter(&spilled) { + self.global_table + .process_partition_from_dumped(partition, &spilled) + } + } else { + self.global_table + .process_partition_from_dumped(partition, &spilled) + } + } + } + + let df = self + .global_table + .finalize_partition(partition, &mut self.slice); + + let chunk_idx = self.partition_processed as IdxSize; + Ok(SourceResult::GotMoreData(vec![DataChunk::new( + chunk_idx, df, + )])) + } + fn fmt(&self) -> &str { + "generic-group_by-source" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs new file mode 100644 index 000000000000..ad03c2775fca --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs @@ -0,0 +1,299 @@ +use std::sync::LazyLock; + +use arrow::array::MutableBinaryArray; +use polars_utils::hashing::hash_to_partition; + +use super::*; +use crate::pipeline::PARTITION_SIZE; + +const OB_SIZE: usize = 2048; + +static SPILL_SIZE: LazyLock = LazyLock::new(|| { + std::env::var("POLARS_STREAMING_GROUPBY_SPILL_SIZE") + .map(|v| v.parse::().unwrap()) + .unwrap_or(10_000) +}); + +#[derive(Clone)] +struct SpillPartitions { + // outer vec: partitions (factor of 2) + // inner vec: number of keys + number of aggregated columns + keys_partitioned: PartitionVec>, + aggs_partitioned: PartitionVec>>, + hash_partitioned: PartitionVec>, + chunk_index_partitioned: PartitionVec>, + spilled: bool, + // this only fills during the reduce phase IFF + // there are spilled tuples + finished_payloads: PartitionVec>, + keys_dtypes: Arc<[DataType]>, + agg_dtypes: Arc<[DataType]>, + output_schema: SchemaRef, +} + +impl SpillPartitions { + fn new(keys: Arc<[DataType]>, aggs: Arc<[DataType]>, output_schema: SchemaRef) -> Self { + let hash_partitioned = vec![]; + let chunk_index_partitioned = vec![]; + + // construct via split so that pre-allocation succeeds + Self { + keys_partitioned: vec![], + aggs_partitioned: vec![], + hash_partitioned, + chunk_index_partitioned, + spilled: false, + finished_payloads: vec![], + keys_dtypes: keys, + agg_dtypes: aggs, + output_schema, + } + .split() + } + + fn split(&self) -> Self { + let n_columns = self.agg_dtypes.as_ref().len(); + + let aggs_partitioned = (0..PARTITION_SIZE) + .map(|_| { + let mut buf = Vec::with_capacity(n_columns); + for dtype in self.agg_dtypes.as_ref() { + let builder = AnyValueBufferTrusted::new(&dtype.to_physical(), OB_SIZE); + buf.push(builder); + } + buf + }) + .collect(); + + let keys_partitioned = (0..PARTITION_SIZE) + .map(|_| MutableBinaryArray::with_capacity(OB_SIZE)) + .collect(); + + let hash_partitioned = (0..PARTITION_SIZE) + .map(|_| Vec::with_capacity(OB_SIZE)) + .collect::>(); + let chunk_index_partitioned = (0..PARTITION_SIZE) + .map(|_| Vec::with_capacity(OB_SIZE)) + .collect::>(); + + Self { + keys_partitioned, + aggs_partitioned, + hash_partitioned, + chunk_index_partitioned, + spilled: false, + finished_payloads: vec![], + keys_dtypes: self.keys_dtypes.clone(), + agg_dtypes: self.agg_dtypes.clone(), + output_schema: self.output_schema.clone(), + } + } +} + +impl SpillPartitions { + /// Returns (partition, overflowing hashes, chunk_indexes, keys and aggs) + fn insert( + &mut self, + hash: u64, + chunk_idx: IdxSize, + row: &[u8], + agg_iters: &mut [SeriesPhysIter], + ) -> Option<(usize, SpillPayload)> { + let partition = hash_to_partition(hash, self.aggs_partitioned.len()); + self.spilled = true; + unsafe { + let agg_values = self.aggs_partitioned.get_unchecked_mut(partition); + let hashes = self.hash_partitioned.get_unchecked_mut(partition); + let chunk_indexes = self.chunk_index_partitioned.get_unchecked_mut(partition); + let key_builder = self.keys_partitioned.get_unchecked_mut(partition); + + hashes.push(hash); + chunk_indexes.push(chunk_idx); + + // amortize the loop counter + key_builder.push(Some(row)); + for (i, agg) in agg_iters.iter_mut().enumerate() { + let av = agg.next().unwrap_unchecked(); + let buf = agg_values.get_unchecked_mut(i); + buf.add_unchecked_borrowed_physical(&av); + } + + if hashes.len() >= OB_SIZE { + let mut new_hashes = Vec::with_capacity(OB_SIZE); + let mut new_chunk_indexes = Vec::with_capacity(OB_SIZE); + let mut new_keys_builder = MutableBinaryArray::with_capacity(OB_SIZE); + std::mem::swap(&mut new_hashes, hashes); + std::mem::swap(&mut new_chunk_indexes, chunk_indexes); + std::mem::swap(&mut new_keys_builder, key_builder); + + Some(( + partition, + SpillPayload { + hashes: new_hashes, + chunk_idx: new_chunk_indexes, + keys: new_keys_builder.into(), + aggs: agg_values + .iter_mut() + .zip(self.output_schema.iter_names()) + .map(|(b, name)| { + let mut s = b.reset(OB_SIZE, false).unwrap(); + s.rename(name.clone()); + s + }) + .collect(), + }, + )) + } else { + None + } + } + } + + fn finish(&mut self) { + if self.spilled { + let all_spilled = self.get_all_spilled().collect::>(); + for (partition_i, payload) in all_spilled { + let buf = if let Some(buf) = self.finished_payloads.get_mut(partition_i) { + buf + } else { + self.finished_payloads.push(vec![]); + self.finished_payloads.last_mut().unwrap() + }; + buf.push(payload) + } + } + } + + fn combine(&mut self, other: &mut Self) { + match (self.spilled, other.spilled) { + (false, true) => std::mem::swap(self, other), + (true, false) => {}, + (false, false) => {}, + (true, true) => { + self.finish(); + other.finish(); + let other_payloads = std::mem::take(&mut other.finished_payloads); + + for (part_self, part_other) in self.finished_payloads.iter_mut().zip(other_payloads) + { + part_self.extend(part_other) + } + }, + } + } + + fn get_all_spilled(&mut self) -> impl Iterator + '_ { + // todo! allocate + let mut flattened = vec![]; + let finished_payloads = std::mem::take(&mut self.finished_payloads); + for (part, payloads) in finished_payloads.into_iter().enumerate() { + for payload in payloads { + flattened.push((part, payload)) + } + } + + (0..PARTITION_SIZE) + .map(|partition| unsafe { + let spilled_aggs = self.aggs_partitioned.get_unchecked_mut(partition); + let hashes = self.hash_partitioned.get_unchecked_mut(partition); + let chunk_indexes = self.chunk_index_partitioned.get_unchecked_mut(partition); + let keys_builder = + std::mem::take(self.keys_partitioned.get_unchecked_mut(partition)); + let hashes = std::mem::take(hashes); + let chunk_idx = std::mem::take(chunk_indexes); + + ( + partition, + SpillPayload { + hashes, + chunk_idx, + keys: keys_builder.into(), + aggs: spilled_aggs + .iter_mut() + .map(|b| b.reset(0, false).unwrap()) + .collect(), + }, + ) + }) + .chain(flattened) + } +} + +pub(super) struct ThreadLocalTable { + inner_map: AggHashTable, + spill_partitions: SpillPartitions, +} + +impl ThreadLocalTable { + pub(super) fn new( + agg_constructors: Arc<[AggregateFunction]>, + key_dtypes: Arc<[DataType]>, + agg_dtypes: Arc<[DataType]>, + output_schema: SchemaRef, + ) -> Self { + let spill_partitions = + SpillPartitions::new(key_dtypes.clone(), agg_dtypes, output_schema.clone()); + + Self { + inner_map: AggHashTable::new( + agg_constructors, + key_dtypes.as_ref(), + output_schema, + Some(*SPILL_SIZE), + ), + spill_partitions, + } + } + + pub(super) fn split(&self) -> Self { + // should be called before any chunk is processed + debug_assert!(self.inner_map.is_empty()); + + Self { + inner_map: self.inner_map.split(), + spill_partitions: self.spill_partitions.clone(), + } + } + + pub(super) fn get_inner_map_mut(&mut self) -> &mut AggHashTable { + &mut self.inner_map + } + + /// # Safety + /// Caller must ensure that `keys` and `agg_iters` are not depleted. + #[inline] + pub(super) unsafe fn insert( + &mut self, + hash: u64, + keys_row: &[u8], + agg_iters: &mut [SeriesPhysIter], + chunk_index: IdxSize, + ) -> Option<(usize, SpillPayload)> { + if self + .inner_map + .insert(hash, keys_row, agg_iters, chunk_index) + { + self.spill_partitions + .insert(hash, chunk_index, keys_row, agg_iters) + } else { + None + } + } + + pub(super) fn combine(&mut self, other: &mut Self) { + self.inner_map.combine(&other.inner_map); + self.spill_partitions.combine(&mut other.spill_partitions); + } + + pub(super) fn finalize(&mut self, slice: &mut Option<(i64, usize)>) -> Option { + if !self.spill_partitions.spilled { + Some(self.inner_map.finalize(slice)) + } else { + None + } + } + + pub(super) fn get_all_spilled(&mut self) -> impl Iterator + '_ { + self.spill_partitions.get_all_spilled() + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/mod.rs new file mode 100644 index 000000000000..b8478dd9eb7e --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/mod.rs @@ -0,0 +1,62 @@ +pub(crate) mod aggregates; +mod generic; +mod ooc; +mod ooc_state; +mod primitive; +mod string; +mod utils; + +pub(crate) use generic::GenericGroupby2; +use polars_core::prelude::*; +#[cfg(feature = "dtype-categorical")] +use polars_core::using_string_cache; +pub(crate) use primitive::*; +pub(crate) use string::*; + +pub(super) fn physical_agg_to_logical(cols: &mut [Column], output_schema: &Schema) { + for (s, (name, dtype)) in cols.iter_mut().zip(output_schema.iter()) { + if s.name() != name { + s.rename(name.clone()); + } + match dtype { + #[cfg(feature = "dtype-categorical")] + dt @ (DataType::Categorical(rev_map, ordering) | DataType::Enum(rev_map, ordering)) => { + if let Some(rev_map) = rev_map { + let cats = s.u32().unwrap().clone(); + // SAFETY: + // the rev-map comes from these categoricals + unsafe { + *s = CategoricalChunked::from_cats_and_rev_map_unchecked( + cats, + rev_map.clone(), + matches!(dt, DataType::Enum(_, _)), + *ordering, + ) + .into_column() + } + } else { + let cats = s.u32().unwrap().clone(); + if using_string_cache() { + // SAFETY, we go from logical to primitive back to logical so the categoricals should still match the global map. + *s = unsafe { + CategoricalChunked::from_global_indices_unchecked(cats, *ordering) + .into_column() + }; + } else { + // we set the global string cache once we start a streaming pipeline + unreachable!() + } + } + }, + _ => { + let dtype_left = s.dtype(); + if dtype_left != dtype + && !matches!(dtype, DataType::Boolean) + && !(dtype.is_float() && dtype_left.is_float()) + { + *s = s.cast(dtype).unwrap() + } + }, + } + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs b/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs new file mode 100644 index 000000000000..56166f1ebc95 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs @@ -0,0 +1,134 @@ +use polars_core::config::verbose; +use polars_core::prelude::*; +use polars_core::utils::split_df; + +use crate::executors::sinks::io::IOThread; +use crate::executors::sources::IpcSourceOneShot; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, Source, SourceResult}; +use crate::pipeline::{PipeLine, morsels_per_sink}; + +pub(super) struct GroupBySource { + // Holding this keeps the lockfile in place + io_thread: IOThread, + already_finished: Option, + partitions: std::fs::ReadDir, + group_by_sink: Box, + chunk_idx: IdxSize, + morsels_per_sink: usize, + slice: Option<(usize, usize)>, +} + +impl GroupBySource { + pub(super) fn new( + io_thread: IOThread, + already_finished: DataFrame, + group_by_sink: Box, + slice: Option<(i64, usize)>, + ) -> PolarsResult { + let partitions = std::fs::read_dir(&io_thread.dir)?; + + if let Some(slice) = slice { + if slice.0 < 0 { + polars_bail!(ComputeError: "negative slice not supported with out-of-core group_by") + } + } + + Ok(Self { + io_thread, + already_finished: Some(already_finished), + partitions, + group_by_sink, + chunk_idx: 0, + morsels_per_sink: morsels_per_sink(), + slice: slice.map(|slice| (slice.0 as usize, slice.1)), + }) + } +} + +impl Source for GroupBySource { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { + if self.slice == Some((0, 0)) { + return Ok(SourceResult::Finished); + } + + if let Some(df) = self.already_finished.take() { + let chunk_idx = self.chunk_idx; + self.chunk_idx += 1; + return Ok(SourceResult::GotMoreData(vec![DataChunk::new( + chunk_idx, df, + )])); + } + + match self.partitions.next() { + None => Ok(SourceResult::Finished), + Some(dir) => { + let partition_dir = dir?; + if partition_dir.path().ends_with(".lock") { + return self.get_batches(context); + } + + // read the files in the partition into sources + // ensure we read in the right order + let mut files = std::fs::read_dir(partition_dir.path())? + .map(|e| e.map(|e| e.path())) + .collect::, _>>()?; + files.sort_unstable(); + + let sources = files + .iter() + .map(|path| { + Ok(Box::new(IpcSourceOneShot::new(path.as_path())?) as Box) + }) + .collect::>>()?; + + // create a pipeline with a the files as sources and the group_by as sink + let mut pipe = + PipeLine::new_simple(sources, vec![], self.group_by_sink.split(0), verbose()); + + let out = match pipe.run_pipeline(context, &mut vec![])?.unwrap() { + FinalizedSink::Finished(mut df) => { + if let Some(slice) = &mut self.slice { + let height = df.height(); + if slice.0 >= height { + slice.0 -= height; + return self.get_batches(context); + } else { + df = df.slice(slice.0 as i64, slice.1); + slice.0 = 0; + slice.1 = slice.1.saturating_sub(height); + } + } + + let dfs = split_df(&mut df, self.morsels_per_sink, false); + let chunks = dfs + .into_iter() + .map(|data| { + let chunk = DataChunk { + chunk_index: self.chunk_idx, + data, + }; + self.chunk_idx += 1; + + chunk + }) + .collect::>(); + + Ok(SourceResult::GotMoreData(chunks)) + }, + // recursively out of core path + FinalizedSink::Source(mut src) => src.get_batches(context), + _ => unreachable!(), + }; + for path in files { + self.io_thread.clean(path) + } + + out + }, + } + } + + fn fmt(&self) -> &str { + "ooc-group_by-source" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs new file mode 100644 index 000000000000..f2c664087daf --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/ooc_state.rs @@ -0,0 +1,64 @@ +use std::sync::Mutex; + +use polars_core::config::verbose; +use polars_core::prelude::*; + +use crate::executors::sinks::io::IOThread; +use crate::executors::sinks::memory::MemTracker; +use crate::pipeline::morsels_per_sink; + +// THIS CODE DOESN'T MAKE SENSE +// It is a remnant of OOC, but will be rewritten to use the generic OOC Table. + +pub(super) struct OocState { + // OOC + // Stores available memory in the system at the start of this sink. + // and stores the memory used by this sink. + _mem_track: MemTracker, + // sort in-memory or out-of-core + pub(super) ooc: bool, + // when ooc, we write to disk using an IO thread + pub(super) io_thread: Arc>>, +} + +impl OocState { + pub(super) fn new(io_thread: Option>>>, ooc: bool) -> Self { + Self { + _mem_track: MemTracker::new(morsels_per_sink()), + ooc, + io_thread: io_thread.unwrap_or_default(), + } + } + + pub(super) fn init_ooc(&mut self, input_schema: SchemaRef) -> PolarsResult<()> { + if verbose() { + eprintln!("OOC group_by started"); + } + self.ooc = true; + + // start IO thread + let mut iot = self.io_thread.lock().unwrap(); + if iot.is_none() { + *iot = Some(IOThread::try_new(input_schema, "group_by")?) + } + Ok(()) + } + + pub(super) fn reset_ooc_filter_rows(&mut self, _len: usize) { + // no-op + } + + pub(super) fn check_memory_usage(&mut self, _schema: &SchemaRef) -> PolarsResult<()> { + // ooc is broken we will rewrite to generic table + // n-op + Ok(()) + } + + #[inline] + pub(super) unsafe fn set_row_as_ooc(&mut self, _idx: usize) {} + + pub(super) fn dump(&self, _data: DataFrame, _hashes: &mut [u64]) { + // ooc is broken we will rewrite to generic table + todo!() + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs new file mode 100644 index 000000000000..9e8ae2939c15 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs @@ -0,0 +1,571 @@ +use std::any::Any; +use std::fmt::Debug; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::sync::Mutex; + +use arrow::legacy::is_valid::IsValid; +use arrow::legacy::kernels::sort_partition::partition_to_groups_amortized; +use hashbrown::hash_map::RawEntryMut; +use num_traits::NumCast; +use polars_core::POOL; +use polars_core::frame::row::AnyValueBuffer; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::_set_partition_size; +use polars_utils::aliases::PlSeedableRandomStateQuality; +use polars_utils::hashing::{DirtyHash, hash_to_partition}; +use rayon::prelude::*; + +use super::aggregates::AggregateFn; +use crate::executors::sinks::HASHMAP_INIT_SIZE; +use crate::executors::sinks::group_by::aggregates::AggregateFunction; +use crate::executors::sinks::group_by::ooc_state::OocState; +use crate::executors::sinks::group_by::physical_agg_to_logical; +use crate::executors::sinks::group_by::string::{apply_aggregate, write_agg_idx}; +use crate::executors::sinks::group_by::utils::{compute_slices, finalize_group_by, prepare_key}; +use crate::executors::sinks::io::IOThread; +use crate::executors::sinks::utils::load_vec; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; + +// hash + value +#[derive(Eq, Copy, Clone)] +struct Key { + hash: u64, + value: T, +} + +impl Hash for Key { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} + +impl PartialEq for Key { + fn eq(&self, other: &Self) -> bool { + self.value == other.value + } +} + +pub struct PrimitiveGroupbySink { + thread_no: usize, + // idx is the offset in the array with aggregators + pre_agg_partitions: Vec>, IdxSize>>, + // the aggregations are all tightly packed + // the aggregation function of a group can be found + // by: + // * offset = (idx) + // * end = (offset + n_aggs) + aggregators: Vec, + key: Arc, + // the columns that will be aggregated + aggregation_columns: Arc>>, + hb: PlSeedableRandomStateQuality, + // Initializing Aggregation functions. If we aggregate by 2 columns + // this vec will have two functions. We will use these functions + // to populate the buffer where the hashmap points to + agg_fns: Vec, + input_schema: SchemaRef, + output_schema: SchemaRef, + // amortize allocations + aggregation_series: Vec, + hashes: Vec, + slice: Option<(i64, usize)>, + // for sorted fast paths + sort_partitions: Vec<[IdxSize; 2]>, + + ooc_state: OocState, +} + +impl PrimitiveGroupbySink +where + ChunkedArray: IntoSeries, + K::Native: Hash + DirtyHash, +{ + pub(crate) fn new( + key: Arc, + aggregation_columns: Arc>>, + agg_fns: Vec, + input_schema: SchemaRef, + output_schema: SchemaRef, + slice: Option<(i64, usize)>, + ) -> Self { + // this ooc is broken fix later + Self::new_inner( + key, + aggregation_columns, + agg_fns, + input_schema, + output_schema, + slice, + None, + false, + ) + } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn new_inner( + key: Arc, + aggregation_columns: Arc>>, + agg_fns: Vec, + input_schema: SchemaRef, + output_schema: SchemaRef, + slice: Option<(i64, usize)>, + io_thread: Option>>>, + ooc: bool, + ) -> Self { + let hb = PlSeedableRandomStateQuality::default(); + let partitions = _set_partition_size(); + + let pre_agg = load_vec(partitions, || PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE)); + let aggregators = + Vec::with_capacity(HASHMAP_INIT_SIZE * aggregation_columns.len() * partitions); + + let mut out = Self { + thread_no: 0, + pre_agg_partitions: pre_agg, + aggregators, + key, + aggregation_columns, + hb, + agg_fns, + input_schema, + output_schema, + aggregation_series: vec![], + hashes: vec![], + slice, + sort_partitions: vec![], + ooc_state: OocState::new(io_thread, ooc), + }; + if ooc { + out.ooc_state.init_ooc(out.input_schema.clone()).unwrap(); + } + out + } + + #[inline] + fn number_of_aggs(&self) -> usize { + self.aggregation_columns.len() + } + + fn pre_finalize(&mut self) -> PolarsResult> { + // we create a pointer to the aggregation functions buffer + // we will deref *mut on every partition thread + // this will be safe, as the partitions guarantee that access don't alias. + let aggregators = self.aggregators.as_ptr() as usize; + let aggregators_len = self.aggregators.len(); + let slices = compute_slices(&self.pre_agg_partitions, self.slice); + + POOL.install(|| { + let dfs = + self.pre_agg_partitions + .par_iter() + .zip(slices.par_iter()) + .filter_map(|(agg_map, slice)| { + let (offset, slice_len) = (*slice)?; + if agg_map.is_empty() { + return None; + } + // SAFETY: + // we will not alias. + let ptr = aggregators as *mut AggregateFunction; + let agg_fns = + unsafe { std::slice::from_raw_parts_mut(ptr, aggregators_len) }; + let mut key_builder = PrimitiveChunkedBuilder::::new( + self.output_schema.get_at_index(0).unwrap().0.clone(), + agg_map.len(), + ); + let dtypes = agg_fns + .iter() + .take(self.number_of_aggs()) + .map(|func| func.dtype()) + .collect::>(); + + let mut buffers = dtypes + .iter() + .map(|dtype| AnyValueBuffer::new(dtype, slice_len)) + .collect::>(); + + agg_map.into_iter().skip(offset).take(slice_len).for_each( + |(k, &offset)| { + key_builder.append_option(k.value); + + for (i, buffer) in (offset as usize + ..offset as usize + self.aggregation_columns.len()) + .zip(buffers.iter_mut()) + { + unsafe { + let agg_fn = agg_fns.get_unchecked_mut(i); + let av = agg_fn.finalize(); + buffer.add(av); + } + } + }, + ); + + let mut cols = Vec::with_capacity(1 + self.number_of_aggs()); + cols.push(key_builder.finish().into_series().into_column()); + cols.extend( + buffers + .into_iter() + .map(|buf| buf.into_series().into_column()), + ); + physical_agg_to_logical(&mut cols, &self.output_schema); + Some(unsafe { DataFrame::new_no_checks_height_from_first(cols) }) + }) + .collect::>(); + Ok(dfs) + }) + } + + fn sink_sorted(&mut self, ca: &ChunkedArray, chunk: DataChunk) -> PolarsResult { + if chunk.is_empty() { + return Ok(SinkResult::CanHaveMoreInput); + } + let arr = ca.downcast_iter().next().unwrap(); + let values = arr.values().as_slice(); + partition_to_groups_amortized(values, 0, false, 0, &mut self.sort_partitions); + + let pre_agg_len = self.pre_agg_partitions.len(); + let null: Option = None; + let null_hash = self.hb.hash_one(null); + + for group in &self.sort_partitions { + let [offset, length] = group; + let (opt_v, h) = if unsafe { arr.is_valid_unchecked(*offset as usize) } { + let first_g_value = unsafe { *values.get_unchecked(*offset as usize) }; + // Ensure that this hash is equal to the default non-sorted sink. + let h = self.hb.hash_one(first_g_value); + // let h = integer_hash(first_g_value); + (Some(first_g_value), h) + } else { + (null, null_hash) + }; + + let agg_idx = insert_and_get( + h, + opt_v, + pre_agg_len, + &mut self.pre_agg_partitions, + &mut self.aggregators, + &self.agg_fns, + ); + + for (i, aggregation_s) in + (0..self.number_of_aggs() as IdxSize).zip(&self.aggregation_series) + { + let agg_fn = unsafe { self.aggregators.get_unchecked_mut((agg_idx + i) as usize) }; + agg_fn.pre_agg_ordered(chunk.chunk_index, *offset, *length, aggregation_s) + } + } + self.aggregation_series.clear(); + self.ooc_state.check_memory_usage(&self.input_schema)?; + Ok(SinkResult::CanHaveMoreInput) + } + + // we don't yet hash here as the sorted fast path doesn't need hashes + fn prepare_key_and_aggregation_series( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let s = self.key.evaluate(chunk, &context.execution_state)?; + let s = s.to_physical_repr(); + let s = prepare_key(&s, chunk); + + // TODO: Amortize allocation. + for phys_e in self.aggregation_columns.iter() { + let s = phys_e.evaluate(chunk, &context.execution_state)?; + let s = s.to_physical_repr(); + self.aggregation_series.push(s.rechunk()); + } + Ok(s) + } + + fn sink_ooc( + &mut self, + context: &PExecutionContext, + chunk: DataChunk, + ) -> PolarsResult { + let s = self.prepare_key_and_aggregation_series(context, &chunk)?; + // cow -> &series -> &dyn series_trait -> &chunkedarray + let ca: &ChunkedArray = s.as_ref().as_ref(); + + // ensure the hashes are set + s.vec_hash(self.hb, &mut self.hashes).unwrap(); + + let arr = ca.downcast_iter().next().unwrap(); + let pre_agg_len = self.pre_agg_partitions.len(); + + // set all bits to false + self.ooc_state.reset_ooc_filter_rows(ca.len()); + + // this reuses the hashes buffer as [u64] as idx buffer as [idxsize] + // write the hashes to self.hashes buffer + // s.vec_hash(self.hb, &mut self.hashes).unwrap(); + // now we have written hashes, we take the pointer to this buffer + // we will write the aggregation_function indexes in the same buffer + // this is unsafe and we must check that we only write the hashes that + // already read/taken. So we write on the slots we just read + let agg_idx_ptr = self.hashes.as_ptr() as *mut u64 as *mut IdxSize; + + // different from standard sink + // we only set aggregation idx when the entry in the hashmap already + // exists. This way we don't grow the hashmap + // rows that are not processed are sinked to disk and loaded in a second pass + let mut processed = 0; + for (iteration_idx, (opt_v, &h)) in arr.iter().zip(self.hashes.iter()).enumerate() { + let opt_v = opt_v.copied(); + if let Some(agg_idx) = + try_insert_and_get(h, opt_v, pre_agg_len, &mut self.pre_agg_partitions) + { + // # Safety + // we write to the hashes buffer we iterate over at the moment. + // this is sound because the writes are trailing from iteration + unsafe { write_agg_idx(agg_idx_ptr, processed, agg_idx) }; + processed += 1; + } else { + // set this row to true: e.g. processed ooc + // SAFETY: we correctly set the length with `reset_ooc_filter_rows` + unsafe { + self.ooc_state.set_row_as_ooc(iteration_idx); + } + } + } + + let agg_idxs = unsafe { std::slice::from_raw_parts(agg_idx_ptr, processed) }; + apply_aggregation( + agg_idxs, + &chunk, + self.number_of_aggs(), + &self.aggregation_series, + &self.agg_fns, + &mut self.aggregators, + ); + + self.ooc_state.dump(chunk.data, &mut self.hashes); + + Ok(SinkResult::CanHaveMoreInput) + } +} + +impl Sink for PrimitiveGroupbySink +where + K::Native: Hash + Eq + Debug + Hash + DirtyHash, + ChunkedArray: IntoSeries, +{ + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + if self.ooc_state.ooc { + return self.sink_ooc(context, chunk); + } + let s = self.prepare_key_and_aggregation_series(context, &chunk)?; + // cow -> &series -> &dyn series_trait -> &chunkedarray + let ca: &ChunkedArray = s.as_ref().as_ref(); + + // sorted fast path + if matches!(ca.is_sorted_flag(), IsSorted::Ascending) { + return self.sink_sorted(ca, chunk); + } + + s.vec_hash(self.hb, &mut self.hashes).unwrap(); + + // this reuses the hashes buffer as [u64] as idx buffer as [idxsize] + // write the hashes to self.hashes buffer + // s.vec_hash(self.hb, &mut self.hashes).unwrap(); + // now we have written hashes, we take the pointer to this buffer + // we will write the aggregation_function indexes in the same buffer + // this is unsafe and we must check that we only write the hashes that + // already read/taken. So we write on the slots we just read + let agg_idx_ptr = self.hashes.as_ptr() as *mut u64 as *mut IdxSize; + + let arr = ca.downcast_iter().next().unwrap(); + let pre_agg_len = self.pre_agg_partitions.len(); + for (iteration_idx, (opt_v, &h)) in arr.iter().zip(self.hashes.iter()).enumerate() { + let opt_v = opt_v.copied(); + let agg_idx = insert_and_get( + h, + opt_v, + pre_agg_len, + &mut self.pre_agg_partitions, + &mut self.aggregators, + &self.agg_fns, + ); + // # Safety + // we write to the hashes buffer we iterate over at the moment. + // this is sound because the writes are trailing from iteration + unsafe { write_agg_idx(agg_idx_ptr, iteration_idx, agg_idx) }; + } + + // note that this slice looks into the self.hashes buffer + let agg_idxs = unsafe { std::slice::from_raw_parts(agg_idx_ptr, ca.len()) }; + apply_aggregation( + agg_idxs, + &chunk, + self.number_of_aggs(), + &self.aggregation_series, + &self.agg_fns, + &mut self.aggregators, + ); + + self.aggregation_series.clear(); + self.ooc_state.check_memory_usage(&self.input_schema)?; + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + // Don't parallelize this as `combine` is called in parallel. + let other = other.as_any().downcast_ref::().unwrap(); + + self.pre_agg_partitions + .iter_mut() + .zip(other.pre_agg_partitions.iter()) + .for_each(|(map_self, map_other)| { + for (key, &agg_idx_other) in map_other.iter() { + let entry = map_self.raw_entry_mut().from_key(key); + + let agg_idx_self = match entry { + RawEntryMut::Vacant(entry) => { + let offset = NumCast::from(self.aggregators.len()).unwrap(); + entry.insert(*key, offset); + // initialize the aggregators + for agg_fn in &self.agg_fns { + self.aggregators.push(agg_fn.split()) + } + offset + }, + RawEntryMut::Occupied(entry) => *entry.get(), + }; + // combine the aggregation functions + for i in 0..self.aggregation_columns.len() { + unsafe { + let agg_fn_other = + other.aggregators.get_unchecked(agg_idx_other as usize + i); + let agg_fn_self = self + .aggregators + .get_unchecked_mut(agg_idx_self as usize + i); + agg_fn_self.combine(agg_fn_other.as_any()) + } + } + } + }); + } + + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + let dfs = self.pre_finalize()?; + let payload = if self.ooc_state.ooc { + let mut guard = self.ooc_state.io_thread.lock().unwrap(); + // Type hint fixes rust-analyzer thinking .take() is an iterator method. + let iot: &mut Option<_> = &mut *guard; + // Make sure that we reset the shared states. The OOC group_by will + // call split as well and it should not send continue spilling to disk. + let iot = iot.take().unwrap(); + self.ooc_state.ooc = false; + + Some((iot, self.split(0))) + } else { + None + }; + finalize_group_by(dfs, &self.output_schema, self.slice, payload) + } + + fn split(&self, thread_no: usize) -> Box { + let mut new = Self::new_inner( + self.key.clone(), + self.aggregation_columns.clone(), + self.agg_fns.iter().map(|func| func.split()).collect(), + self.input_schema.clone(), + self.output_schema.clone(), + self.slice, + Some(self.ooc_state.io_thread.clone()), + self.ooc_state.ooc, + ); + new.hb = self.hb; + new.thread_no = thread_no; + Box::new(new) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + fn fmt(&self) -> &str { + "primitive_group_by" + } +} + +fn insert_and_get( + h: u64, + opt_v: Option, + pre_agg_len: usize, + pre_agg_partitions: &mut [PlIdHashMap>, IdxSize>], + current_aggregators: &mut Vec, + agg_fns: &Vec, +) -> IdxSize +where + T: NumericNative + Hash, +{ + let part = hash_to_partition(h, pre_agg_len); + let current_partition = unsafe { pre_agg_partitions.get_unchecked_mut(part) }; + + let entry = current_partition + .raw_entry_mut() + .from_hash(h, |k| k.value == opt_v); + match entry { + RawEntryMut::Vacant(entry) => { + let offset = unsafe { NumCast::from(current_aggregators.len()).unwrap_unchecked() }; + let key = Key { + hash: h, + value: opt_v, + }; + entry.insert(key, offset); + // initialize the aggregators + for agg_fn in agg_fns { + current_aggregators.push(agg_fn.split()) + } + offset + }, + RawEntryMut::Occupied(entry) => *entry.get(), + } +} + +fn try_insert_and_get( + h: u64, + opt_v: Option, + pre_agg_len: usize, + pre_agg_partitions: &mut [PlIdHashMap>, IdxSize>], +) -> Option +where + T: NumericNative + Hash, +{ + let part = hash_to_partition(h, pre_agg_len); + let current_partition = unsafe { pre_agg_partitions.get_unchecked_mut(part) }; + + let entry = current_partition + .raw_entry_mut() + .from_hash(h, |k| k.value == opt_v); + match entry { + RawEntryMut::Vacant(_) => None, + RawEntryMut::Occupied(entry) => Some(*entry.get()), + } +} + +pub(super) fn apply_aggregation( + agg_idxs: &[IdxSize], + chunk: &DataChunk, + num_aggs: usize, + aggregation_series: &[Series], + agg_fns: &[AggregateFunction], + aggregators: &mut [AggregateFunction], +) { + let chunk_idx = chunk.chunk_index; + for (agg_i, aggregation_s) in (0..num_aggs).zip(aggregation_series) { + let has_physical_agg = agg_fns[agg_i].has_physical_agg(); + apply_aggregate( + agg_i, + chunk_idx, + agg_idxs, + aggregation_s, + has_physical_agg, + aggregators, + ); + } +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/string.rs b/crates/polars-pipe/src/executors/sinks/group_by/string.rs new file mode 100644 index 000000000000..c6fa2615481e --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/string.rs @@ -0,0 +1,586 @@ +use std::any::Any; +use std::hash::{Hash, Hasher}; +use std::sync::Mutex; + +use hashbrown::hash_map::RawEntryMut; +use num_traits::NumCast; +use polars_core::frame::row::AnyValueBuffer; +use polars_core::prelude::*; +use polars_core::utils::_set_partition_size; +use polars_core::{IdBuildHasher, POOL}; +use polars_utils::aliases::PlSeedableRandomStateQuality; +use polars_utils::hashing::hash_to_partition; +use rayon::prelude::*; + +use super::aggregates::AggregateFn; +use crate::executors::sinks::HASHMAP_INIT_SIZE; +use crate::executors::sinks::group_by::aggregates::AggregateFunction; +use crate::executors::sinks::group_by::ooc_state::OocState; +use crate::executors::sinks::group_by::physical_agg_to_logical; +use crate::executors::sinks::group_by::primitive::apply_aggregation; +use crate::executors::sinks::group_by::utils::{compute_slices, finalize_group_by, prepare_key}; +use crate::executors::sinks::io::IOThread; +use crate::executors::sinks::utils::load_vec; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; + +// This is the hash and the Index offset in the linear buffer +#[derive(Copy, Clone)] +struct Key { + pub(super) hash: u64, + pub(super) idx: IdxSize, +} + +impl Key { + #[inline] + pub(super) fn new(hash: u64, idx: IdxSize) -> Self { + Self { hash, idx } + } +} + +impl Hash for Key { + #[inline] + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} + +// we store a hashmap per partition (partitioned by hash) +// the hashmap contains indexes as keys and as values +// those indexes point into the keys buffer and the values buffer +// the keys buffer are buffers of AnyValue per partition +// and the values are buffer of Aggregation functions per partition +pub struct StringGroupbySink { + thread_no: usize, + // idx is the offset in the array with keys + // idx is the offset in the array with aggregators + pre_agg_partitions: Vec>, + // the aggregations/keys are all tightly packed + // the aggregation function of a group can be found + // by: + // * offset = (idx) + // * end = (offset + 1) + keys: Vec>, + aggregators: Vec, + // the key that will be aggregated on + key_column: Arc, + // the columns that will be aggregated + aggregation_columns: Arc>>, + hb: PlSeedableRandomStateQuality, + // Initializing Aggregation functions. If we aggregate by 2 columns + // this vec will have two functions. We will use these functions + // to populate the buffer where the hashmap points to + agg_fns: Vec, + input_schema: SchemaRef, + output_schema: SchemaRef, + // amortize allocations + aggregation_series: Vec, + hashes: Vec, + slice: Option<(i64, usize)>, + + ooc_state: OocState, +} + +impl StringGroupbySink { + pub(crate) fn new( + key_column: Arc, + aggregation_columns: Arc>>, + agg_fns: Vec, + input_schema: SchemaRef, + output_schema: SchemaRef, + slice: Option<(i64, usize)>, + ) -> Self { + Self::new_inner( + key_column, + aggregation_columns, + agg_fns, + input_schema, + output_schema, + slice, + None, + false, + ) + } + + #[allow(clippy::too_many_arguments)] + fn new_inner( + key_column: Arc, + aggregation_columns: Arc>>, + agg_fns: Vec, + input_schema: SchemaRef, + output_schema: SchemaRef, + slice: Option<(i64, usize)>, + io_thread: Option>>>, + ooc: bool, + ) -> Self { + let hb = Default::default(); + let partitions = _set_partition_size(); + + let pre_agg = load_vec(partitions, || PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE)); + let keys = Vec::with_capacity(HASHMAP_INIT_SIZE * partitions); + let aggregators = + Vec::with_capacity(HASHMAP_INIT_SIZE * aggregation_columns.len() * partitions); + + let mut out = Self { + thread_no: 0, + pre_agg_partitions: pre_agg, + keys, + aggregators, + key_column, + aggregation_columns, + hb, + agg_fns, + input_schema, + output_schema, + aggregation_series: vec![], + hashes: vec![], + slice, + ooc_state: OocState::new(io_thread, ooc), + }; + if ooc { + out.ooc_state.init_ooc(out.input_schema.clone()).unwrap(); + } + out + } + + #[inline] + fn number_of_aggs(&self) -> usize { + self.aggregation_columns.len() + } + + fn pre_finalize(&mut self) -> PolarsResult> { + // we create a pointer to the aggregation functions buffer + // we will deref *mut on every partition thread + // this will be safe, as the partitions guarantee that access don't alias. + let aggregators = self.aggregators.as_ptr() as usize; + let aggregators_len = self.aggregators.len(); + + let slices = compute_slices(&self.pre_agg_partitions, self.slice); + + POOL.install(|| { + let dfs = + self.pre_agg_partitions + .par_iter() + .zip(slices.par_iter()) + .filter_map(|(agg_map, slice)| { + let ptr = aggregators as *mut AggregateFunction; + // SAFETY: + // we will not alias. + let aggregators = + unsafe { std::slice::from_raw_parts_mut(ptr, aggregators_len) }; + + let (offset, slice_len) = (*slice)?; + if agg_map.is_empty() { + return None; + } + let dtypes = aggregators + .iter() + .take(self.number_of_aggs()) + .map(|func| func.dtype()) + .collect::>(); + + let mut buffers = dtypes + .iter() + .map(|dtype| AnyValueBuffer::new(dtype, slice_len)) + .collect::>(); + + let cap = std::cmp::min(slice_len, agg_map.len()); + let mut key_builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, cap); + agg_map.into_iter().skip(offset).take(slice_len).for_each( + |(k, &offset)| { + let key_offset = k.idx as usize; + let key = unsafe { self.keys.get_unchecked(key_offset).as_deref() }; + key_builder.append_option(key); + + for (i, buffer) in (offset as usize + ..offset as usize + self.aggregation_columns.len()) + .zip(buffers.iter_mut()) + { + unsafe { + let agg_fn = aggregators.get_unchecked_mut(i); + let av = agg_fn.finalize(); + buffer.add(av); + } + } + }, + ); + + let mut cols = Vec::with_capacity(1 + self.number_of_aggs()); + cols.push(key_builder.finish().into_series().into_column()); + cols.extend( + buffers + .into_iter() + .map(|buf| buf.into_series().into_column()), + ); + physical_agg_to_logical(&mut cols, &self.output_schema); + Some(unsafe { DataFrame::new_no_checks_height_from_first(cols) }) + }) + .collect::>(); + + Ok(dfs) + }) + } + fn prepare_key_and_aggregation_series( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let s = self.key_column.evaluate(chunk, &context.execution_state)?; + let s = s.to_physical_repr(); + let s = prepare_key(&s, chunk); + + // TODO: Amortize allocation. + for phys_e in self.aggregation_columns.iter() { + let s = phys_e.evaluate(chunk, &context.execution_state)?; + let s = s.to_physical_repr(); + self.aggregation_series.push(s.rechunk()); + } + s.vec_hash(self.hb, &mut self.hashes).unwrap(); + Ok(s) + } + #[inline] + fn get_partitions(&mut self, h: u64) -> &mut PlIdHashMap { + let partition = hash_to_partition(h, self.pre_agg_partitions.len()); + unsafe { self.pre_agg_partitions.get_unchecked_mut(partition) } + } + + fn sink_ooc( + &mut self, + context: &PExecutionContext, + chunk: DataChunk, + ) -> PolarsResult { + let s = self.prepare_key_and_aggregation_series(context, &chunk)?; + + // take containers to please bchk + // we put them back once done + let mut hashes = std::mem::take(&mut self.hashes); + let keys = std::mem::take(&mut self.keys); + let agg_fns = std::mem::take(&mut self.agg_fns); + let mut aggregators = std::mem::take(&mut self.aggregators); + + // write the hashes to self.hashes buffer + // s.vec_hash(self.hb, &mut self.hashes).unwrap(); + // now we have written hashes, we take the pointer to this buffer + // we will write the aggregation_function indexes in the same buffer + // this is unsafe and we must check that we only write the hashes that + // already read/taken. So we write on the slots we just read + let agg_idx_ptr = hashes.as_ptr() as *mut u64 as *mut IdxSize; + // array of the keys + let keys_arr = s.str().unwrap().downcast_iter().next().unwrap().clone(); + + // set all bits to false + self.ooc_state.reset_ooc_filter_rows(chunk.data.height()); + + let mut processed = 0; + for (iteration_idx, (key_val, &h)) in keys_arr.iter().zip(&hashes).enumerate() { + let current_partition = self.get_partitions(h); + let entry = get_entry(key_val, h, current_partition, &keys); + + match entry { + RawEntryMut::Vacant(_) => { + // set this row to true: e.g. processed ooc + // SAFETY: we correctly set the length with `reset_ooc_filter_rows` + unsafe { + self.ooc_state.set_row_as_ooc(iteration_idx); + } + }, + RawEntryMut::Occupied(entry) => { + let agg_idx = *entry.get(); + // # Safety + // we write to the hashes buffer we iterate over at the moment. + // this is sound because we writes are trailing from iteration + unsafe { write_agg_idx(agg_idx_ptr, processed, agg_idx) }; + processed += 1; + }, + }; + } + + // note that this slice looks into the self.hashes buffer + let agg_idxs = unsafe { std::slice::from_raw_parts(agg_idx_ptr, processed) }; + + apply_aggregation( + agg_idxs, + &chunk, + self.number_of_aggs(), + &self.aggregation_series, + &agg_fns, + &mut aggregators, + ); + self.ooc_state.dump(chunk.data, &mut hashes); + + self.aggregation_series.clear(); + self.hashes = hashes; + self.keys = keys; + self.agg_fns = agg_fns; + self.aggregators = aggregators; + self.hashes.clear(); + self.ooc_state.check_memory_usage(&self.input_schema)?; + Ok(SinkResult::CanHaveMoreInput) + } +} + +impl Sink for StringGroupbySink { + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + if chunk.is_empty() { + return Ok(SinkResult::CanHaveMoreInput); + } + if self.ooc_state.ooc { + return self.sink_ooc(context, chunk); + } + let s = self.prepare_key_and_aggregation_series(context, &chunk)?; + + // take containers to please bchk + // we put them back once done + let hashes = std::mem::take(&mut self.hashes); + let mut keys = std::mem::take(&mut self.keys); + let agg_fns = std::mem::take(&mut self.agg_fns); + let mut aggregators = std::mem::take(&mut self.aggregators); + + // write the hashes to self.hashes buffer + // s.vec_hash(self.hb, &mut self.hashes).unwrap(); + // now we have written hashes, we take the pointer to this buffer + // we will write the aggregation_function indexes in the same buffer + // this is unsafe and we must check that we only write the hashes that + // already read/taken. So we write on the slots we just read + let agg_idx_ptr = hashes.as_ptr() as *mut u64 as *mut IdxSize; + // array of the keys + let keys_arr = s.str().unwrap().downcast_iter().next().unwrap().clone(); + + for (iteration_idx, (key_val, &h)) in keys_arr.iter().zip(&hashes).enumerate() { + let current_partition = self.get_partitions(h); + let entry = get_entry(key_val, h, current_partition, &keys); + + let agg_idx = match entry { + RawEntryMut::Vacant(entry) => { + let value_offset = + unsafe { NumCast::from(aggregators.len()).unwrap_unchecked() }; + let keys_offset = + unsafe { Key::new(h, NumCast::from(keys.len()).unwrap_unchecked()) }; + entry.insert(keys_offset, value_offset); + + keys.push(key_val.map(|s| s.into())); + + // initialize the aggregators + for agg_fn in &agg_fns { + aggregators.push(agg_fn.split()) + } + value_offset + }, + RawEntryMut::Occupied(entry) => *entry.get(), + }; + // # Safety + // we write to the hashes buffer we iterate over at the moment. + // this is sound because we writes are trailing from iteration + unsafe { write_agg_idx(agg_idx_ptr, iteration_idx, agg_idx) }; + } + + // note that this slice looks into the self.hashes buffer + let agg_idxs = unsafe { std::slice::from_raw_parts(agg_idx_ptr, keys_arr.len()) }; + + apply_aggregation( + agg_idxs, + &chunk, + self.number_of_aggs(), + &self.aggregation_series, + &agg_fns, + &mut aggregators, + ); + self.aggregation_series.clear(); + self.hashes = hashes; + self.keys = keys; + self.agg_fns = agg_fns; + self.aggregators = aggregators; + self.hashes.clear(); + self.ooc_state.check_memory_usage(&self.input_schema)?; + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + // don't parallelize this as this is already done in parallel. + + let other = other.as_any().downcast_ref::().unwrap(); + let n_partitions = self.pre_agg_partitions.len(); + debug_assert_eq!(n_partitions, other.pre_agg_partitions.len()); + + self.pre_agg_partitions + .iter_mut() + .zip(other.pre_agg_partitions.iter()) + .for_each(|(map_self, map_other)| { + for (k_other, &agg_idx_other) in map_other.iter() { + // the hash value + let h = k_other.hash; + + // the offset in the keys of other + let idx_other = k_other.idx as usize; + // slice to the keys of other + let key_other = unsafe { other.keys.get_unchecked(idx_other) }; + + let entry = map_self.raw_entry_mut().from_hash(h, |k_self| { + h == k_self.hash && { + // the offset in the keys of self + let idx_self = k_self.idx as usize; + // slice to the keys of self + // SAFETY: + // in bounds + let key_self = unsafe { self.keys.get_unchecked(idx_self) }; + // compare the keys + key_self == key_other + } + }); + + let agg_idx_self = match entry { + // the keys of other are not in this table, so we must update this table + RawEntryMut::Vacant(entry) => { + // get the current offset in the values buffer + let values_offset = + unsafe { NumCast::from(self.aggregators.len()).unwrap_unchecked() }; + // get the key, comprised of the hash and the current offset in the keys buffer + let key = unsafe { + Key::new(h, NumCast::from(self.keys.len()).unwrap_unchecked()) + }; + + // extend the keys buffer with the new key from other + self.keys.push(key_other.clone()); + + // insert the keys and values_offset + entry.insert(key, values_offset); + // initialize the new aggregators + for agg_fn in &self.agg_fns { + self.aggregators.push(agg_fn.split()) + } + values_offset + }, + RawEntryMut::Occupied(entry) => *entry.get(), + }; + + // combine the aggregation functions + for i in 0..self.aggregation_columns.len() { + unsafe { + let agg_fn_other = + other.aggregators.get_unchecked(agg_idx_other as usize + i); + let agg_fn_self = self + .aggregators + .get_unchecked_mut(agg_idx_self as usize + i); + agg_fn_self.combine(agg_fn_other.as_any()) + } + } + } + }); + } + + fn split(&self, thread_no: usize) -> Box { + let mut new = Self::new_inner( + self.key_column.clone(), + self.aggregation_columns.clone(), + self.agg_fns.iter().map(|func| func.split()).collect(), + self.input_schema.clone(), + self.output_schema.clone(), + self.slice, + Some(self.ooc_state.io_thread.clone()), + self.ooc_state.ooc, + ); + new.hb = self.hb; + new.thread_no = thread_no; + Box::new(new) + } + + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + let dfs = self.pre_finalize()?; + let payload = if self.ooc_state.ooc { + let mut iot = self.ooc_state.io_thread.lock().unwrap(); + // make sure that we reset the shared states + // the OOC group_by will call split as well and it should + // not send continue spilling to disk + let iot = iot.take().unwrap(); + self.ooc_state.ooc = false; + + Some((iot, self.split(0))) + } else { + None + }; + finalize_group_by(dfs, &self.output_schema, self.slice, payload) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + fn fmt(&self) -> &str { + "string_group_by" + } +} + +// write agg_idx to the hashes buffer. +pub(super) unsafe fn write_agg_idx(h: *mut IdxSize, i: usize, agg_idx: IdxSize) { + h.add(i).write(agg_idx) +} + +pub(super) fn apply_aggregate( + agg_i: usize, + chunk_idx: IdxSize, + agg_idxs: &[IdxSize], + aggregation_s: &Series, + has_physical_agg: bool, + aggregators: &mut [AggregateFunction], +) { + macro_rules! apply_agg { + ($self:expr, $macro:ident $(, $opt_args:expr)*) => {{ + match $self.dtype() { + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => $macro!($self.u8().unwrap(), pre_agg_primitive $(, $opt_args)*), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => $macro!($self.u16().unwrap(), pre_agg_primitive $(, $opt_args)*), + DataType::UInt32 => $macro!($self.u32().unwrap(), pre_agg_primitive $(, $opt_args)*), + DataType::UInt64 => $macro!($self.u64().unwrap(), pre_agg_primitive $(, $opt_args)*), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => $macro!($self.i8().unwrap(), pre_agg_primitive $(, $opt_args)*), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => $macro!($self.i16().unwrap(), pre_agg_primitive $(, $opt_args)*), + DataType::Int32 => $macro!($self.i32().unwrap(), pre_agg_primitive $(, $opt_args)*), + DataType::Int64 => $macro!($self.i64().unwrap(), pre_agg_primitive $(, $opt_args)*), + DataType::Float32 => $macro!($self.f32().unwrap(), pre_agg_primitive $(, $opt_args)*), + DataType::Float64 => $macro!($self.f64().unwrap(), pre_agg_primitive $(, $opt_args)*), + dt => panic!("not implemented for {:?}", dt), + } + }}; + } + + if has_physical_agg && aggregation_s.dtype().is_primitive_numeric() { + macro_rules! dispatch { + ($ca:expr, $name:ident) => {{ + let arr = $ca.downcast_iter().next().unwrap(); + + for (&agg_idx, av) in agg_idxs.iter().zip(arr.into_iter()) { + let i = agg_idx as usize + agg_i; + let agg_fn = unsafe { aggregators.get_unchecked_mut(i) }; + + agg_fn.$name(chunk_idx, av.copied()) + } + }}; + } + + apply_agg!(aggregation_s, dispatch); + } else { + let mut iter = aggregation_s.phys_iter(); + for &agg_idx in agg_idxs.iter() { + let i = agg_idx as usize + agg_i; + let agg_fn = unsafe { aggregators.get_unchecked_mut(i) }; + agg_fn.pre_agg(chunk_idx, &mut iter) + } + } +} + +#[inline] +fn get_entry<'a>( + key_val: Option<&str>, + h: u64, + current_partition: &'a mut PlIdHashMap, + keys: &[Option], +) -> RawEntryMut<'a, Key, IdxSize, IdBuildHasher> { + current_partition.raw_entry_mut().from_hash(h, |key| { + // first compare the hash before we incur the cache miss + key.hash == h && { + let idx = key.idx as usize; + unsafe { keys.get_unchecked(idx).as_deref() == key_val } + } + }) +} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/utils.rs b/crates/polars-pipe/src/executors/sinks/group_by/utils.rs new file mode 100644 index 000000000000..82f1ab0b13f1 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/group_by/utils.rs @@ -0,0 +1,86 @@ +use hashbrown::HashMap; +use polars_core::prelude::*; +use polars_core::utils::{accumulate_dataframes_vertical_unchecked, slice_offsets}; + +use crate::executors::sinks::group_by::ooc::GroupBySource; +use crate::executors::sinks::io::{IOThread, block_thread_until_io_thread_done}; +use crate::operators::{DataChunk, FinalizedSink, Sink}; + +pub(super) fn default_slices( + pre_agg_partitions: &[HashMap], +) -> Vec> { + pre_agg_partitions + .iter() + .map(|agg_map| Some((0, agg_map.len()))) + .collect() +} + +pub(super) fn compute_slices( + pre_agg_partitions: &[HashMap], + slice: Option<(i64, usize)>, +) -> Vec> { + if let Some((offset, slice_len)) = slice { + let total_len = pre_agg_partitions + .iter() + .map(|agg_map| agg_map.len()) + .sum::(); + + if total_len <= slice_len { + return default_slices(pre_agg_partitions); + } + + let (mut offset, mut len) = slice_offsets(offset, slice_len, total_len); + + pre_agg_partitions + .iter() + .map(|agg_map| { + if offset > agg_map.len() { + offset -= agg_map.len(); + None + } else { + let slice = Some((offset, std::cmp::min(len, agg_map.len()))); + len = len.saturating_sub(agg_map.len() - offset); + offset = 0; + slice + } + }) + .collect::>() + } else { + default_slices(pre_agg_partitions) + } +} + +pub(super) fn finalize_group_by( + dfs: Vec, + output_schema: &Schema, + slice: Option<(i64, usize)>, + ooc_payload: Option<(IOThread, Box)>, +) -> PolarsResult { + let df = if dfs.is_empty() { + DataFrame::empty_with_schema(output_schema) + } else { + let df = accumulate_dataframes_vertical_unchecked(dfs); + // re init to check duplicates + DataFrame::new(df.take_columns())? + }; + + match ooc_payload { + None => Ok(FinalizedSink::Finished(df)), + Some((iot, sink)) => { + // we wait until all chunks are spilled + block_thread_until_io_thread_done(&iot); + + Ok(FinalizedSink::Source(Box::new(GroupBySource::new( + iot, df, sink, slice, + )?))) + }, + } +} + +pub(super) fn prepare_key(s: &Series, chunk: &DataChunk) -> Series { + if s.len() == 1 && chunk.data.height() > 1 { + s.new_from_index(0, chunk.data.height()) + } else { + s.rechunk() + } +} diff --git a/crates/polars-pipe/src/executors/sinks/io.rs b/crates/polars-pipe/src/executors/sinks/io.rs new file mode 100644 index 000000000000..52c27ff75f9c --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/io.rs @@ -0,0 +1,313 @@ +use std::fs; +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::{Duration, SystemTime}; + +use crossbeam_channel::{Receiver, Sender, bounded, unbounded}; +use polars_core::error::ErrString; +use polars_core::prelude::*; +use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; +use polars_io::prelude::*; + +use crate::executors::sinks::get_base_temp_dir; +use crate::pipeline::morsels_per_sink; + +pub(in crate::executors::sinks) type DfIter = + Box + Sync + Send>; +// The Option are the partitions it should be written to, if any +type Payload = (Option, DfIter); + +/// A helper that can be used to spill to disk +pub(crate) struct IOThread { + payload_tx: Sender, + cleanup_tx: Sender, + _lockfile: Arc, + pub(in crate::executors::sinks) dir: PathBuf, + pub(in crate::executors::sinks) sent: Arc, + pub(in crate::executors::sinks) total: Arc, + pub(in crate::executors::sinks) thread_local_count: Arc, + schema: SchemaRef, +} + +fn get_lockfile_path(dir: &Path) -> PathBuf { + let mut lockfile_path = dir.to_path_buf(); + lockfile_path.push(".lock"); + lockfile_path +} + +fn get_spill_dir(operation_name: &'static str) -> PolarsResult { + let id = uuid::Uuid::new_v4(); + + let mut dir = std::path::PathBuf::from(get_base_temp_dir()); + dir.push(format!("polars/{operation_name}/{id}")); + + if !dir.exists() { + fs::create_dir_all(&dir).map_err(|err| { + PolarsError::ComputeError(ErrString::from(format!( + "Failed to create spill directory: {}", + err + ))) + })?; + } else if !dir.is_dir() { + return Err(PolarsError::ComputeError( + "Specified spill path is not a directory".into(), + )); + } + + Ok(dir) +} + +fn clean_after_delay(time: Option, secs: u64, path: &Path) { + if let Some(time) = time { + let modified_since = SystemTime::now().duration_since(time).unwrap().as_secs(); + if modified_since > secs { + // This can be fallible if another thread removes this. + // That is fine. + let _ = std::fs::remove_dir_all(path); + } + } else { + polars_warn!("could not modified time on this platform") + } +} + +/// Starts a new thread that will clean up operations of directories that don't +/// have a lockfile (opened with 'w' permissions). +fn gc_thread(operation_name: &'static str, rx: Receiver) { + let _ = std::thread::spawn(move || { + // First clean all existing + let mut dir = std::path::PathBuf::from(get_base_temp_dir()); + dir.push(format!("polars/{operation_name}")); + + // if the directory does not exist, there is nothing to clean + let rd = match std::fs::read_dir(&dir) { + Ok(rd) => rd, + _ => panic!("cannot find {:?}", dir), + }; + + for entry in rd { + let path = entry.unwrap().path(); + if path.is_dir() { + let lockfile_path = get_lockfile_path(&path); + + if let Ok(lockfile) = File::open(lockfile_path) { + // lockfile can be read + if let Ok(md) = lockfile.metadata() { + let time = md.modified().ok(); + // The lockfile can still exist if a process was canceled + // so we also check the modified date + // we don't expect queries that run a month. + clean_after_delay(time, SECONDS_IN_DAY as u64 * 30, &path); + } + } else { + // If path already removed, we simply continue. + if let Ok(md) = path.metadata() { + let time = md.modified().ok(); + // Wait 15 seconds to ensure we don't remove before lockfile is created + // in a `collect_all` contention case + clean_after_delay(time, 15, &path); + } + } + } + } + + // Clean on receive + while let Ok(path) = rx.recv() { + if path.is_file() { + let res = std::fs::remove_file(path); + debug_assert!(res.is_ok()); + } else { + let res = std::fs::remove_dir_all(path); + debug_assert!(res.is_ok()); + } + } + }); +} + +impl IOThread { + pub(in crate::executors::sinks) fn try_new( + // Schema of the file that will be dumped to disk + schema: SchemaRef, + // Will be used as subdirectory name in `~/.base_dir/polars/` + operation_name: &'static str, + ) -> PolarsResult { + let dir = get_spill_dir(operation_name)?; + + // make sure we create lockfile before we GC + let lockfile_path = get_lockfile_path(&dir); + let lockfile = Arc::new(LockFile::new(lockfile_path)?); + + let (cleanup_tx, rx) = unbounded::(); + // start a thread that will clean up old dumps. + // TODO: if we will have more ooc in the future we will have a dedicated GC thread + gc_thread(operation_name, rx); + + // we need some pushback otherwise we still could go OOM. + let (tx, rx) = bounded::(morsels_per_sink() * 2); + + let sent: Arc = Default::default(); + let total: Arc = Default::default(); + let thread_local_count: Arc = Default::default(); + + let dir2 = dir.clone(); + let total2 = total.clone(); + let lockfile2 = lockfile.clone(); + let schema2 = schema.clone(); + std::thread::spawn(move || { + let schema = schema2; + // this moves the lockfile in the thread + // we keep one in the thread and one in the `IoThread` struct + let _keep_hold_on_lockfile = lockfile2; + + let mut count = 0usize; + + // We accept 2 cases. E.g. + // 1. (None, DfIter): + // This will dump to `dir/count.ipc` + // 2. (Some(partitions), DfIter) + // This will dump to `dir/partition/count.ipc` + while let Ok((partitions, iter)) = rx.recv() { + if let Some(partitions) = partitions { + for (part, mut df) in partitions.into_no_null_iter().zip(iter) { + df.shrink_to_fit(); + df.align_chunks_par(); + let mut path = dir2.clone(); + path.push(format!("{part}")); + + let _ = std::fs::create_dir(&path); + path.push(format!("{count}.ipc")); + + let file = File::create(path).unwrap(); + let writer = IpcWriter::new(file).with_compat_level(CompatLevel::newest()); + let mut writer = writer.batched(&schema).unwrap(); + writer.write_batch(&df).unwrap(); + writer.finish().unwrap(); + count += 1; + } + } else { + let mut path = dir2.clone(); + path.push(format!("{count}_0_pass.ipc")); + + let file = File::create(path).unwrap(); + let writer = IpcWriter::new(file).with_compat_level(CompatLevel::newest()); + let mut writer = writer.batched(&schema).unwrap(); + + for mut df in iter { + df.shrink_to_fit(); + df.align_chunks_par(); + writer.write_batch(&df).unwrap(); + } + writer.finish().unwrap(); + + count += 1; + } + total2.store(count, Ordering::Relaxed); + } + }); + + Ok(Self { + payload_tx: tx, + cleanup_tx, + dir, + sent, + total, + _lockfile: lockfile, + thread_local_count, + schema, + }) + } + + pub(in crate::executors::sinks) fn dump_chunk(&self, mut df: DataFrame) { + // if IO thread is blocked + // we write locally on this thread + if self.payload_tx.is_full() { + df.shrink_to_fit(); + let mut path = self.dir.clone(); + let count = self.thread_local_count.fetch_add(1, Ordering::Relaxed); + // thread local name we start with an underscore to ensure we don't get + // duplicates + path.push(format!("_{count}_full.ipc")); + + let file = File::create(path).unwrap(); + let mut writer = IpcWriter::new(file).with_compat_level(CompatLevel::newest()); + writer.finish(&mut df).unwrap(); + } else { + let iter = Box::new(std::iter::once(df)); + self.dump_iter(None, iter) + } + } + + pub(in crate::executors::sinks) fn clean(&self, path: PathBuf) { + self.cleanup_tx.send(path).unwrap() + } + + pub(in crate::executors::sinks) fn dump_partition(&self, partition_no: IdxSize, df: DataFrame) { + let partition = Some(IdxCa::from_vec(PlSmallStr::EMPTY, vec![partition_no])); + let iter = Box::new(std::iter::once(df)); + self.dump_iter(partition, iter) + } + + pub(in crate::executors::sinks) fn dump_partition_local( + &self, + partition_no: IdxSize, + mut df: DataFrame, + ) { + df.align_chunks(); + let count = self.thread_local_count.fetch_add(1, Ordering::Relaxed); + let mut path = self.dir.clone(); + path.push(format!("{partition_no}")); + + let _ = std::fs::create_dir(&path); + // Thread local name we start with an underscore to ensure we don't get + // duplicates + path.push(format!("_{count}.ipc")); + let file = File::create(path).unwrap(); + let writer = IpcWriter::new(file).with_compat_level(CompatLevel::newest()); + let mut writer = writer.batched(&self.schema).unwrap(); + writer.write_batch(&df).unwrap(); + writer.finish().unwrap(); + } + + pub(in crate::executors::sinks) fn dump_iter(&self, partition: Option, iter: DfIter) { + let add = iter.size_hint().1.unwrap(); + self.payload_tx.send((partition, iter)).unwrap(); + self.sent.fetch_add(add, Ordering::Relaxed); + } +} + +impl Drop for IOThread { + fn drop(&mut self) { + // we drop the lockfile explicitly as the thread GC will leak. + std::fs::remove_file(&self._lockfile.path).unwrap(); + } +} + +pub(in crate::executors::sinks) fn block_thread_until_io_thread_done(io_thread: &IOThread) { + // get number sent + let sent = io_thread.sent.load(Ordering::Relaxed); + // get number processed + while io_thread.total.load(Ordering::Relaxed) != sent { + std::thread::park_timeout(Duration::from_millis(6)) + } +} + +struct LockFile { + path: PathBuf, +} + +impl LockFile { + fn new(path: PathBuf) -> PolarsResult { + match File::create(&path) { + Ok(_) => Ok(Self { path }), + Err(e) => { + polars_bail!(ComputeError: "could not create lockfile: {e}") + }, + } + } +} + +impl Drop for LockFile { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs new file mode 100644 index 000000000000..865c8eb9e651 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -0,0 +1,206 @@ +use std::any::Any; +use std::iter::StepBy; +use std::ops::Range; +use std::sync::Arc; +use std::vec; + +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_ops::prelude::CrossJoin as CrossJoinTrait; +use polars_utils::arena::Node; +use polars_utils::pl_str::PlSmallStr; + +use crate::executors::operators::PlaceHolder; +use crate::operators::{ + DataChunk, FinalizedSink, Operator, OperatorResult, PExecutionContext, Sink, SinkResult, + chunks_to_df_unchecked, +}; + +#[derive(Default)] +pub struct CrossJoin { + chunks: Vec, + suffix: PlSmallStr, + swapped: bool, + node: Node, + placeholder: PlaceHolder, +} + +impl CrossJoin { + pub(crate) fn new( + suffix: PlSmallStr, + swapped: bool, + node: Node, + placeholder: PlaceHolder, + ) -> Self { + CrossJoin { + chunks: vec![], + suffix, + swapped, + node, + placeholder, + } + } +} + +impl Sink for CrossJoin { + fn node(&self) -> Node { + self.node + } + fn is_join_build(&self) -> bool { + true + } + + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + self.chunks.push(chunk); + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + let other = other.as_any().downcast_mut::().unwrap(); + let other_chunks = std::mem::take(&mut other.chunks); + self.chunks.extend(other_chunks); + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(Self { + suffix: self.suffix.clone(), + swapped: self.swapped, + placeholder: self.placeholder.clone(), + ..Default::default() + }) + } + + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + let op = Box::new(CrossJoinProbe { + df: Arc::new(chunks_to_df_unchecked(std::mem::take(&mut self.chunks))), + suffix: self.suffix.clone(), + in_process_left: None, + in_process_right: None, + in_process_left_df: Default::default(), + output_names: None, + swapped: self.swapped, + }); + self.placeholder.replace(op); + + Ok(FinalizedSink::Operator) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "cross_join_sink" + } +} + +#[derive(Clone)] +pub struct CrossJoinProbe { + df: Arc, + suffix: PlSmallStr, + in_process_left: Option>>, + in_process_right: Option>>, + in_process_left_df: DataFrame, + output_names: Option>, + swapped: bool, +} + +impl Operator for CrossJoinProbe { + fn execute( + &mut self, + _context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + // Expected output is size**2, so this needs to be a small number. + // However, if one of the DataFrames is much smaller than 250, we want + // to take rather more from the other DataFrame so we don't end up with + // overly small chunks. + let mut size = 250; + if chunk.data.height() > 0 { + size *= (250 / chunk.data.height()).max(1); + } + if self.df.height() > 0 { + size *= (250 / self.df.height()).max(1); + } + + if self.in_process_left.is_none() { + let mut iter_left = (0..self.df.height()).step_by(size); + let offset = iter_left.next().unwrap_or(0); + self.in_process_left_df = self.df.slice(offset as i64, size); + self.in_process_left = Some(iter_left); + } + if self.in_process_right.is_none() { + self.in_process_right = Some((0..chunk.data.height()).step_by(size)); + } + // output size is large we process in chunks + let iter_left = self.in_process_left.as_mut().unwrap(); + let iter_right = self.in_process_right.as_mut().unwrap(); + + match iter_right.next() { + None => { + self.in_process_right = None; + + // if right is depleted take the next left chunk + match iter_left.next() { + None => { + self.in_process_left = None; + Ok(OperatorResult::NeedsNewData) + }, + Some(offset) => { + self.in_process_left_df = self.df.slice(offset as i64, size); + self.in_process_right = Some((0..chunk.data.height()).step_by(size)); + let iter_right = self.in_process_right.as_mut().unwrap(); + let offset = iter_right.next().unwrap_or(0); + let right_df = chunk.data.slice(offset as i64, size); + + let (a, b) = if self.swapped { + (&right_df, &self.in_process_left_df) + } else { + (&self.in_process_left_df, &right_df) + }; + + let mut df = a.cross_join(b, Some(self.suffix.clone()), None)?; + // Cross joins can produce multiple chunks. + // No parallelize in operators + df.as_single_chunk(); + Ok(OperatorResult::HaveMoreOutPut(chunk.with_data(df))) + }, + } + }, + // deplete the right chunks over the current left chunk + Some(offset) => { + // this will be the branch of the first call + + let right_df = chunk.data.slice(offset as i64, size); + + let (a, b) = if self.swapped { + (&right_df, &self.in_process_left_df) + } else { + (&self.in_process_left_df, &right_df) + }; + + // we use the first join to determine the output names + // this we can amortize the name allocations. + let mut df = match &self.output_names { + None => { + let df = a.cross_join(b, Some(self.suffix.clone()), None)?; + self.output_names = Some(df.get_column_names_owned()); + df + }, + Some(names) => a._cross_join_with_names(b, names)?, + }; + // Cross joins can produce multiple chunks. + df.as_single_chunk(); + + Ok(OperatorResult::HaveMoreOutPut(chunk.with_data(df))) + }, + } + } + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + + fn fmt(&self) -> &str { + "cross_join_probe" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs new file mode 100644 index 000000000000..789738582482 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -0,0 +1,376 @@ +use std::any::Any; + +use arrow::array::BinaryArray; +use hashbrown::hash_map::RawEntryMut; +use polars_core::prelude::*; +use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; +use polars_ops::prelude::JoinArgs; +use polars_utils::arena::Node; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::unitvec; + +use self::row_encode::get_row_encoding_context; +use super::*; +use crate::executors::operators::PlaceHolder; +use crate::executors::sinks::HASHMAP_INIT_SIZE; +use crate::executors::sinks::joins::generic_probe_inner_left::GenericJoinProbe; +use crate::executors::sinks::joins::generic_probe_outer::GenericFullOuterJoinProbe; +use crate::executors::sinks::utils::{hash_rows, load_vec}; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; + +pub(super) type ChunkIdx = IdxSize; +pub(super) type DfIdx = IdxSize; + +pub struct GenericBuild { + chunks: Vec, + // the join columns are all tightly packed + // the values of a join column(s) can be found + // by: + // first get the offset of the chunks and multiply that with the number of join + // columns + // * chunk_offset = (idx * n_join_keys) + // * end = (offset + n_join_keys) + materialized_join_cols: Vec>, + suffix: PlSmallStr, + hb: PlSeedableRandomStateQuality, + join_args: JoinArgs, + // partitioned tables that will be used for probing + // stores the key and the chunk_idx, df_idx of the left table + hash_tables: PartitionedMap, + + // the columns that will be joined on + join_columns_left: Arc>>, + join_columns_right: Arc>>, + + // amortize allocations + join_columns: Vec, + hashes: Vec, + // the join order is swapped to ensure we hash the smaller table + swapped: bool, + nulls_equal: bool, + node: Node, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, + placeholder: PlaceHolder, +} + +impl GenericBuild { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + suffix: PlSmallStr, + join_args: JoinArgs, + swapped: bool, + join_columns_left: Arc>>, + join_columns_right: Arc>>, + nulls_equal: bool, + node: Node, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, + placeholder: PlaceHolder, + ) -> Self { + let hb = PlSeedableRandomStateQuality::default(); + let partitions = _set_partition_size(); + let hash_tables = PartitionedHashMap::new(load_vec(partitions, || { + PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE) + })); + GenericBuild { + chunks: vec![], + join_args, + suffix, + hb, + swapped, + join_columns_left, + join_columns_right, + join_columns: vec![], + materialized_join_cols: vec![], + hash_tables, + hashes: vec![], + nulls_equal, + node, + key_names_left, + key_names_right, + placeholder, + } + } +} + +#[inline] +pub(super) fn compare_fn( + key: &Key, + h: u64, + join_columns_all_chunks: &[BinaryArray], + current_row: &[u8], +) -> bool { + let key_hash = key.hash; + + // we check the hash first + // as that has no indirection + key_hash == h && { + // we get the appropriate values from the join columns and compare it with the current row + let (chunk_idx, df_idx) = key.idx.extract(); + let chunk_idx = chunk_idx as usize; + let df_idx = df_idx as usize; + + // get the right columns from the linearly packed buffer + let other_row = unsafe { + join_columns_all_chunks + .get_unchecked(chunk_idx) + .value_unchecked(df_idx) + }; + current_row == other_row + } +} + +impl GenericBuild { + fn is_empty(&self) -> bool { + match self.chunks.len() { + 0 => true, + 1 => self.chunks[0].is_empty(), + _ => false, + } + } + + fn set_join_series( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult<&BinaryArray> { + debug_assert!(self.join_columns.is_empty()); + let mut ctxts = Vec::with_capacity(self.join_columns_left.len()); + for phys_e in self.join_columns_left.iter() { + let s = phys_e.evaluate(chunk, &context.execution_state)?; + let arr = s.to_physical_repr().rechunk().array_ref(0).clone(); + self.join_columns.push(arr); + ctxts.push(get_row_encoding_context(s.dtype(), false)); + } + let rows_encoded = polars_row::convert_columns_no_order( + self.join_columns[0].len(), // @NOTE: does not work for ZFS + &self.join_columns, + &ctxts, + ) + .into_array(); + self.materialized_join_cols.push(rows_encoded); + Ok(self.materialized_join_cols.last().unwrap()) + } + unsafe fn get_row(&self, chunk_idx: ChunkIdx, df_idx: DfIdx) -> &[u8] { + self.materialized_join_cols + .get_unchecked(chunk_idx as usize) + .value_unchecked(df_idx as usize) + } +} + +impl Sink for GenericBuild { + fn node(&self) -> Node { + self.node + } + fn is_join_build(&self) -> bool { + true + } + + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + // we do some juggling here so that we don't + // end up with empty chunks + // But we always want one empty chunk if all is empty as we need + // to finish the join + if self.chunks.len() == 1 && self.chunks[0].is_empty() { + self.chunks.pop().unwrap(); + } + if chunk.is_empty() { + if self.chunks.is_empty() { + self.chunks.push(chunk) + } + return Ok(SinkResult::CanHaveMoreInput); + } + let mut hashes = std::mem::take(&mut self.hashes); + let rows = self.set_join_series(context, &chunk)?.clone(); + hash_rows(&rows, &mut hashes, &self.hb); + self.hashes = hashes; + + let current_chunk_offset = self.chunks.len() as ChunkIdx; + + // row offset in the chunk belonging to the hash + let mut current_df_idx = 0 as IdxSize; + for (row, h) in rows.values_iter().zip(&self.hashes) { + let entry = self.hash_tables.raw_entry_mut(*h).from_hash(*h, |key| { + compare_fn(key, *h, &self.materialized_join_cols, row) + }); + + let payload = ChunkId::store(current_chunk_offset, current_df_idx); + match entry { + RawEntryMut::Vacant(entry) => { + let key = Key::new(*h, current_chunk_offset, current_df_idx); + entry.insert(key, (unitvec![payload], Default::default())); + }, + RawEntryMut::Occupied(mut entry) => { + entry.get_mut().0.push(payload); + }, + }; + + current_df_idx += 1; + } + + // clear memory + self.hashes.clear(); + self.join_columns.clear(); + + self.chunks.push(chunk); + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + if self.is_empty() { + let other = other.as_any().downcast_mut::().unwrap(); + if !other.is_empty() { + std::mem::swap(self, other); + } + return; + } + let other = other.as_any().downcast_ref::().unwrap(); + if other.is_empty() { + return; + } + + let chunks_offset = self.chunks.len() as IdxSize; + self.chunks.extend_from_slice(&other.chunks); + self.materialized_join_cols + .extend_from_slice(&other.materialized_join_cols); + + // we combine the other hashtable with ours, but we must offset the chunk_idx + // values by the number of chunks we already got. + self.hash_tables + .inner_mut() + .iter_mut() + .zip(other.hash_tables.inner()) + .for_each(|(ht, other_ht)| { + for (k, val) in other_ht.iter() { + let val = &val.0; + let (chunk_idx, df_idx) = k.idx.extract(); + // Use the indexes to materialize the row. + let other_row = unsafe { other.get_row(chunk_idx, df_idx) }; + + let h = k.hash; + let entry = ht.raw_entry_mut().from_hash(h, |key| { + compare_fn(key, h, &self.materialized_join_cols, other_row) + }); + + match entry { + RawEntryMut::Vacant(entry) => { + let chunk_id = unsafe { val.get_unchecked(0) }; + let (chunk_idx, df_idx) = chunk_id.extract(); + let new_chunk_idx = chunk_idx + chunks_offset; + let key = Key::new(h, new_chunk_idx, df_idx); + let mut payload = unitvec![ChunkId::store(new_chunk_idx, df_idx)]; + if val.len() > 1 { + let iter = val[1..].iter().map(|chunk_id| { + let (chunk_idx, val_idx) = chunk_id.extract(); + ChunkId::store(chunk_idx + chunks_offset, val_idx) + }); + payload.extend(iter); + } + entry.insert(key, (payload, Default::default())); + }, + RawEntryMut::Occupied(mut entry) => { + let iter = val.iter().map(|chunk_id| { + let (chunk_idx, val_idx) = chunk_id.extract(); + ChunkId::store(chunk_idx + chunks_offset, val_idx) + }); + entry.get_mut().0.extend(iter); + }, + } + } + }) + } + + fn split(&self, _thread_no: usize) -> Box { + let mut new = Self::new( + self.suffix.clone(), + self.join_args.clone(), + self.swapped, + self.join_columns_left.clone(), + self.join_columns_right.clone(), + self.nulls_equal, + self.node, + self.key_names_left.clone(), + self.key_names_right.clone(), + self.placeholder.clone(), + ); + new.hb = self.hb; + Box::new(new) + } + + fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { + let chunks_len = self.chunks.len(); + let left_df = accumulate_dataframes_vertical_unchecked( + std::mem::take(&mut self.chunks) + .into_iter() + .map(|chunk| chunk.data), + ); + if left_df.height() > 0 { + assert_eq!(left_df.first_col_n_chunks(), chunks_len); + } + // Reallocate to Arc<[]> to get rid of double indirection as this is accessed on every + // hashtable cmp. + let materialized_join_cols = Arc::from(std::mem::take(&mut self.materialized_join_cols)); + let suffix = self.suffix.clone(); + let hb = self.hb; + let hash_tables = Arc::new(PartitionedHashMap::new(std::mem::take( + self.hash_tables.inner_mut(), + ))); + let join_columns_left = self.join_columns_left.clone(); + let join_columns_right = self.join_columns_right.clone(); + + // take the buffers, this saves one allocation + let mut hashes = std::mem::take(&mut self.hashes); + hashes.clear(); + + match self.join_args.how { + JoinType::Inner | JoinType::Left => { + let probe_operator = GenericJoinProbe::new( + left_df, + materialized_join_cols, + suffix, + hb, + hash_tables, + join_columns_left, + join_columns_right, + self.swapped, + hashes, + context, + self.join_args.clone(), + self.nulls_equal, + ); + self.placeholder.replace(Box::new(probe_operator)); + Ok(FinalizedSink::Operator) + }, + JoinType::Full => { + let coalesce = self.join_args.coalesce.coalesce(&JoinType::Full); + let probe_operator = GenericFullOuterJoinProbe::new( + left_df, + materialized_join_cols, + suffix, + hb, + hash_tables, + join_columns_left, + self.swapped, + hashes, + self.nulls_equal, + coalesce, + self.key_names_left.clone(), + self.key_names_right.clone(), + ); + self.placeholder.replace(Box::new(probe_operator)); + Ok(FinalizedSink::Operator) + }, + + _ => unimplemented!(), + } + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + fn fmt(&self) -> &str { + "generic_join_build" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs new file mode 100644 index 000000000000..ea51913940be --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs @@ -0,0 +1,326 @@ +use std::borrow::Cow; + +use arrow::array::{Array, BinaryArray}; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_ops::frame::join::_finish_join; +use polars_ops::prelude::{JoinArgs, JoinType, TakeChunked}; +use polars_utils::nulls::IsNull; +use polars_utils::pl_str::PlSmallStr; + +use crate::executors::sinks::joins::generic_build::*; +use crate::executors::sinks::joins::row_values::RowValues; +use crate::executors::sinks::joins::{ExtraPayload, PartitionedMap, ToRow}; +use crate::executors::sinks::utils::hash_rows; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +#[derive(Clone)] +pub struct GenericJoinProbe { + /// All chunks are stacked into a single dataframe + /// the dataframe is not rechunked. + df_a: Arc, + /// The join columns are all tightly packed + /// the values of a join column(s) can be found + /// by: + /// first get the offset of the chunks and multiply that with the number of join + /// columns + /// * chunk_offset = (idx * n_join_keys) + /// * end = (offset + n_join_keys) + materialized_join_cols: Arc<[BinaryArray]>, + suffix: PlSmallStr, + hb: PlSeedableRandomStateQuality, + /// partitioned tables that will be used for probing + /// stores the key and the chunk_idx, df_idx of the left table + hash_tables: Arc>, + + /// Amortize allocations + /// In inner join these are the left table. + /// In left join there are the right table. + join_tuples_a: Vec, + /// in inner join these are the right table + /// in left join there are the left table + join_tuples_b: Vec, + hashes: Vec, + /// the join order is swapped to ensure we hash the smaller table + swapped_or_left: bool, + /// cached output names + output_names: Option>, + args: JoinArgs, + nulls_equal: bool, + row_values: RowValues, +} + +impl GenericJoinProbe { + #[allow(clippy::too_many_arguments)] + pub(super) fn new( + mut df_a: DataFrame, + materialized_join_cols: Arc<[BinaryArray]>, + suffix: PlSmallStr, + hb: PlSeedableRandomStateQuality, + hash_tables: Arc>, + join_columns_left: Arc>>, + join_columns_right: Arc>>, + swapped_or_left: bool, + // Re-use the hashes allocation of the build side. + amortized_hashes: Vec, + context: &PExecutionContext, + args: JoinArgs, + nulls_equal: bool, + ) -> Self { + if swapped_or_left && args.should_coalesce() { + let tmp = DataChunk { + data: df_a.slice(0, 1), + chunk_index: 0, + }; + + // remove duplicate_names caused by joining + // on the same column + let names = join_columns_left + .iter() + .flat_map(|phys_e| { + phys_e + .evaluate(&tmp, &context.execution_state) + .ok() + .map(|s| s.name().clone()) + }) + .collect::>(); + df_a = df_a.drop_many_amortized(&names) + } + + GenericJoinProbe { + df_a: Arc::new(df_a), + materialized_join_cols, + suffix, + hb, + hash_tables, + join_tuples_a: vec![], + join_tuples_b: vec![], + hashes: amortized_hashes, + swapped_or_left, + output_names: None, + args, + nulls_equal, + row_values: RowValues::new(join_columns_right, !swapped_or_left), + } + } + + fn finish_join( + &mut self, + mut left_df: DataFrame, + right_df: DataFrame, + ) -> PolarsResult { + Ok(match &self.output_names { + None => { + let out = _finish_join(left_df, right_df, Some(self.suffix.clone()))?; + self.output_names = Some(out.get_column_names_owned()); + out + }, + Some(names) => unsafe { + // SAFETY: + // if we have duplicate names, we overwrite + // them in the next snippet + left_df + .get_columns_mut() + .extend_from_slice(right_df.get_columns()); + left_df + .get_columns_mut() + .iter_mut() + .zip(names) + .for_each(|(s, name)| { + s.rename(name.clone()); + }); + left_df.clear_schema(); + left_df + }, + }) + } + + fn match_left<'b, I, T>(&mut self, iter: I) + where + I: Iterator + 'b, + T: IsNull + // Temporary trait to get concrete &[u8] + // Input is either &[u8] or Option<&[u8]> + + ToRow, + { + for (i, (h, row)) in iter { + let df_idx_left = i as IdxSize; + + let entry = if row.is_null() { + None + } else { + let row = row.get_row(); + self.hash_tables + .raw_entry(*h) + .from_hash(*h, |key| { + compare_fn(key, *h, &self.materialized_join_cols, row) + }) + .map(|key_val| key_val.1) + }; + + match entry { + Some(indexes_right) => { + let indexes_right = &indexes_right.0; + self.join_tuples_a.extend_from_slice(indexes_right); + self.join_tuples_b + .extend(std::iter::repeat_n(df_idx_left, indexes_right.len())); + }, + None => { + self.join_tuples_b.push(df_idx_left); + self.join_tuples_a.push(ChunkId::null()); + }, + } + } + } + + fn execute_left( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + // A left join holds the right table as build table + // and streams the left table through. This allows us to maintain + // the left table order + + self.join_tuples_a.clear(); + self.join_tuples_b.clear(); + let mut hashes = std::mem::take(&mut self.hashes); + let rows = self + .row_values + .get_values(context, chunk, self.nulls_equal)?; + hash_rows(&rows, &mut hashes, &self.hb); + + if self.nulls_equal || rows.null_count() == 0 { + let iter = hashes.iter().zip(rows.values_iter()).enumerate(); + self.match_left(iter); + } else { + let iter = hashes.iter().zip(rows.iter()).enumerate(); + self.match_left(iter); + } + self.hashes = hashes; + let right_df = self.df_a.as_ref(); + + // join tuples of left joins are always sorted + // this will ensure sorted flags maintain + let left_df = unsafe { + chunk + .data + ._take_unchecked_slice_sorted(&self.join_tuples_b, false, IsSorted::Ascending) + }; + let right_df = unsafe { right_df.take_opt_chunked_unchecked(&self.join_tuples_a, false) }; + + let out = self.finish_join(left_df, right_df)?; + + // Clear memory. + self.row_values.clear(); + self.hashes.clear(); + + Ok(OperatorResult::Finished(chunk.with_data(out))) + } + + fn match_inner<'b, I>(&mut self, iter: I) + where + I: Iterator + 'b, + { + for (i, (h, row)) in iter { + let df_idx_right = i as IdxSize; + + let entry = self + .hash_tables + .raw_entry(*h) + .from_hash(*h, |key| { + compare_fn(key, *h, &self.materialized_join_cols, row) + }) + .map(|key_val| key_val.1); + + if let Some(indexes_left) = entry { + let indexes_left = &indexes_left.0; + self.join_tuples_a.extend_from_slice(indexes_left); + self.join_tuples_b + .extend(std::iter::repeat_n(df_idx_right, indexes_left.len())); + } + } + } + + fn execute_inner( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + self.join_tuples_a.clear(); + self.join_tuples_b.clear(); + let mut hashes = std::mem::take(&mut self.hashes); + let rows = self + .row_values + .get_values(context, chunk, self.nulls_equal)?; + hash_rows(&rows, &mut hashes, &self.hb); + + if self.nulls_equal || rows.null_count() == 0 { + let iter = hashes.iter().zip(rows.values_iter()).enumerate(); + self.match_inner(iter); + } else { + let iter = hashes + .iter() + .zip(rows.iter()) + .enumerate() + .filter_map(|(i, (h, row))| row.map(|row| (i, (h, row)))); + self.match_inner(iter); + } + self.hashes = hashes; + + let left_df = unsafe { + self.df_a + .take_chunked_unchecked(&self.join_tuples_a, IsSorted::Not, false) + }; + let right_df = unsafe { + let mut df = Cow::Borrowed(&chunk.data); + if let Some(ids) = &self.row_values.join_column_idx { + let mut tmp = df.into_owned(); + let cols = tmp.get_columns_mut(); + // we go from higher idx to lower so that lower indices remain untouched + // by our mutation + for idx in ids.iter().rev() { + let _ = cols.remove(*idx); + } + df = Cow::Owned(tmp); + } + df._take_unchecked_slice(&self.join_tuples_b, false) + }; + + let (a, b) = if self.swapped_or_left { + (right_df, left_df) + } else { + (left_df, right_df) + }; + let out = self.finish_join(a, b)?; + + // Clear memory. + self.row_values.clear(); + self.hashes.clear(); + + Ok(OperatorResult::Finished(chunk.with_data(out))) + } +} + +impl Operator for GenericJoinProbe { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + match self.args.how { + JoinType::Inner => self.execute_inner(context, chunk), + JoinType::Left => self.execute_left(context, chunk), + _ => unreachable!(), + } + } + + fn split(&self, _thread_no: usize) -> Box { + let new = self.clone(); + Box::new(new) + } + fn fmt(&self) -> &str { + "generic_join_probe" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs new file mode 100644 index 000000000000..6e47c22e0f0b --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs @@ -0,0 +1,310 @@ +use std::sync::atomic::Ordering; + +use arrow::array::{Array, BinaryArray, MutablePrimitiveArray}; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_ops::frame::join::_finish_join; +use polars_ops::prelude::{_coalesce_full_join, TakeChunked}; +use polars_utils::pl_str::PlSmallStr; + +use crate::executors::sinks::ExtraPayload; +use crate::executors::sinks::joins::PartitionedMap; +use crate::executors::sinks::joins::generic_build::*; +use crate::executors::sinks::joins::row_values::RowValues; +use crate::executors::sinks::utils::hash_rows; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +#[derive(Clone)] +pub struct GenericFullOuterJoinProbe { + /// all chunks are stacked into a single dataframe + /// the dataframe is not rechunked. + df_a: Arc, + // Dummy needed for the flush phase. + df_b_flush_dummy: Option, + /// The join columns are all tightly packed + /// the values of a join column(s) can be found + /// by: + /// first get the offset of the chunks and multiply that with the number of join + /// columns + /// * chunk_offset = (idx * n_join_keys) + /// * end = (offset + n_join_keys) + materialized_join_cols: Arc<[BinaryArray]>, + suffix: PlSmallStr, + hb: PlSeedableRandomStateQuality, + /// partitioned tables that will be used for probing. + /// stores the key and the chunk_idx, df_idx of the left table. + hash_tables: Arc>, + + // amortize allocations + // in inner join these are the left table + // in left join there are the right table + join_tuples_a: Vec, + // in inner join these are the right table + // in left join there are the left table + join_tuples_b: MutablePrimitiveArray, + hashes: Vec, + // the join order is swapped to ensure we hash the smaller table + swapped: bool, + // cached output names + output_names: Option>, + nulls_equal: bool, + coalesce: bool, + thread_no: usize, + row_values: RowValues, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, +} + +impl GenericFullOuterJoinProbe { + #[allow(clippy::too_many_arguments)] + pub(super) fn new( + df_a: DataFrame, + materialized_join_cols: Arc<[BinaryArray]>, + suffix: PlSmallStr, + hb: PlSeedableRandomStateQuality, + hash_tables: Arc>, + join_columns_right: Arc>>, + swapped: bool, + // Re-use the hashes allocation of the build side. + amortized_hashes: Vec, + nulls_equal: bool, + coalesce: bool, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, + ) -> Self { + GenericFullOuterJoinProbe { + df_a: Arc::new(df_a), + df_b_flush_dummy: None, + materialized_join_cols, + suffix, + hb, + hash_tables, + join_tuples_a: vec![], + join_tuples_b: MutablePrimitiveArray::new(), + hashes: amortized_hashes, + swapped, + output_names: None, + nulls_equal, + coalesce, + thread_no: 0, + row_values: RowValues::new(join_columns_right, false), + key_names_left, + key_names_right, + } + } + + fn finish_join(&mut self, left_df: DataFrame, right_df: DataFrame) -> PolarsResult { + fn inner( + left_df: DataFrame, + right_df: DataFrame, + suffix: PlSmallStr, + swapped: bool, + output_names: &mut Option>, + ) -> PolarsResult { + let (mut left_df, right_df) = if swapped { + (right_df, left_df) + } else { + (left_df, right_df) + }; + Ok(match output_names { + None => { + let out = _finish_join(left_df, right_df, Some(suffix))?; + *output_names = Some(out.get_column_names_owned()); + out + }, + Some(names) => unsafe { + // SAFETY: + // if we have duplicate names, we overwrite + // them in the next snippet + left_df + .get_columns_mut() + .extend_from_slice(right_df.get_columns()); + + // @TODO: Is this actually the case? + // SAFETY: output_names should be unique. + left_df + .get_columns_mut() + .iter_mut() + .zip(names) + .for_each(|(s, name)| { + s.rename(name.clone()); + }); + left_df.clear_schema(); + left_df + }, + }) + } + + if self.coalesce { + let out = inner( + left_df.clone(), + right_df, + self.suffix.clone(), + self.swapped, + &mut self.output_names, + )?; + let l = self.key_names_left.iter().cloned().collect::>(); + let r = self.key_names_right.iter().cloned().collect::>(); + Ok(_coalesce_full_join( + out, + l.as_slice(), + r.as_slice(), + Some(self.suffix.clone()), + &left_df, + )) + } else { + inner( + left_df.clone(), + right_df, + self.suffix.clone(), + self.swapped, + &mut self.output_names, + ) + } + } + + fn match_outer<'b, I>(&mut self, iter: I) + where + I: Iterator + 'b, + { + for (i, (h, row)) in iter { + let df_idx_right = i as IdxSize; + + let entry = self + .hash_tables + .raw_entry(*h) + .from_hash(*h, |key| { + compare_fn(key, *h, &self.materialized_join_cols, row) + }) + .map(|key_val| key_val.1); + + if let Some((indexes_left, tracker)) = entry { + // compiles to normal store: https://rust.godbolt.org/z/331hMo339 + tracker.get_tracker().store(true, Ordering::Relaxed); + + self.join_tuples_a.extend_from_slice(indexes_left); + self.join_tuples_b + .extend_constant(indexes_left.len(), Some(df_idx_right)); + } else { + self.join_tuples_a.push(ChunkId::null()); + self.join_tuples_b.push_value(df_idx_right); + } + } + } + + fn execute_outer( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + self.join_tuples_a.clear(); + self.join_tuples_b.clear(); + + if self.df_b_flush_dummy.is_none() { + self.df_b_flush_dummy = Some(chunk.data.clear()) + } + + let mut hashes = std::mem::take(&mut self.hashes); + let rows = self + .row_values + .get_values(context, chunk, self.nulls_equal)?; + hash_rows(&rows, &mut hashes, &self.hb); + + if self.nulls_equal || rows.null_count() == 0 { + let iter = hashes.iter().zip(rows.values_iter()).enumerate(); + self.match_outer(iter); + } else { + let iter = hashes + .iter() + .zip(rows.iter()) + .enumerate() + .filter_map(|(i, (h, row))| row.map(|row| (i, (h, row)))); + self.match_outer(iter); + } + self.hashes = hashes; + + let left_df = unsafe { + self.df_a + .take_opt_chunked_unchecked(&self.join_tuples_a, false) + }; + let right_df = unsafe { + self.join_tuples_b.with_freeze(|idx| { + let idx = IdxCa::from(idx.clone()); + let out = chunk.data.take_unchecked_impl(&idx, false); + // Drop so that the freeze context can go back to mutable array. + drop(idx); + out + }) + }; + let out = self.finish_join(left_df, right_df)?; + Ok(OperatorResult::Finished(chunk.with_data(out))) + } + + fn execute_flush(&mut self) -> PolarsResult { + let ht = self.hash_tables.inner(); + let n = ht.len(); + self.join_tuples_a.clear(); + + ht.iter().enumerate().for_each(|(i, ht)| { + if i % n == self.thread_no { + ht.iter().for_each(|(_k, (idx_left, tracker))| { + let found_match = tracker.get_tracker().load(Ordering::Relaxed); + + if !found_match { + self.join_tuples_a.extend_from_slice(idx_left); + } + }) + } + }); + + let left_df = unsafe { + self.df_a + .take_chunked_unchecked(&self.join_tuples_a, IsSorted::Not, false) + }; + + let size = left_df.height(); + let right_df = self.df_b_flush_dummy.as_ref().unwrap(); + + let right_df = unsafe { + DataFrame::new_no_checks( + size, + right_df + .get_columns() + .iter() + .map(|s| Column::full_null(s.name().clone(), size, s.dtype())) + .collect(), + ) + }; + + let out = self.finish_join(left_df, right_df)?; + Ok(OperatorResult::Finished(DataChunk::new(0, out))) + } +} + +impl Operator for GenericFullOuterJoinProbe { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + self.execute_outer(context, chunk) + } + + fn flush(&mut self) -> PolarsResult { + self.execute_flush() + } + + fn must_flush(&self) -> bool { + self.df_b_flush_dummy.is_some() + } + + fn split(&self, thread_no: usize) -> Box { + let mut new = self.clone(); + new.thread_no = thread_no; + Box::new(new) + } + fn fmt(&self) -> &str { + "generic_full_join_probe" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/joins/mod.rs b/crates/polars-pipe/src/executors/sinks/joins/mod.rs new file mode 100644 index 000000000000..53ba6e896b71 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/mod.rs @@ -0,0 +1,101 @@ +#[cfg(feature = "cross_join")] +mod cross; +mod generic_build; +mod generic_probe_inner_left; +mod generic_probe_outer; +mod row_values; + +use std::hash::{BuildHasherDefault, Hash, Hasher}; +use std::sync::atomic::AtomicBool; + +#[cfg(feature = "cross_join")] +pub(crate) use cross::*; +pub(crate) use generic_build::GenericBuild; +use polars_core::hashing::IdHasher; +use polars_core::prelude::IdxSize; +use polars_ops::prelude::JoinType; +use polars_utils::idx_vec::UnitVec; +use polars_utils::index::ChunkId; +use polars_utils::partitioned::PartitionedHashMap; + +trait ToRow { + fn get_row(&self) -> &[u8]; +} + +impl ToRow for &[u8] { + #[inline(always)] + fn get_row(&self) -> &[u8] { + self + } +} + +impl ToRow for Option<&[u8]> { + #[inline(always)] + fn get_row(&self) -> &[u8] { + self.unwrap() + } +} + +// This is the hash and the Index offset in the chunks and the index offset in the dataframe +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub(super) struct Key { + pub(super) hash: u64, + /// We use the MSB as tracker for outer join matches + /// So the 25th bit of the chunk_idx will be used for that. + idx: ChunkId, +} + +impl Key { + #[inline] + fn new(hash: u64, chunk_idx: IdxSize, df_idx: IdxSize) -> Self { + let idx = ChunkId::store(chunk_idx, df_idx); + Key { hash, idx } + } +} + +impl Hash for Key { + #[inline] + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} + +pub(crate) trait ExtraPayload: Clone + Sync + Send + Default + 'static { + /// Tracker used in the outer join. + fn get_tracker(&self) -> &AtomicBool { + panic!() + } +} +impl ExtraPayload for () {} + +#[repr(transparent)] +pub(crate) struct Tracker { + inner: AtomicBool, +} + +impl Default for Tracker { + #[inline] + fn default() -> Self { + Self { + inner: AtomicBool::new(false), + } + } +} + +// Needed for the trait resolving. We should never hit this. +impl Clone for Tracker { + fn clone(&self) -> Self { + panic!() + } +} + +impl ExtraPayload for Tracker { + #[inline(always)] + fn get_tracker(&self) -> &AtomicBool { + &self.inner + } +} + +type PartitionedMap = + PartitionedHashMap, V), BuildHasherDefault>; diff --git a/crates/polars-pipe/src/executors/sinks/joins/row_values.rs b/crates/polars-pipe/src/executors/sinks/joins/row_values.rs new file mode 100644 index 000000000000..1b2b852a8f61 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/row_values.rs @@ -0,0 +1,100 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, BinaryArray, StaticArray}; +use arrow::compute::utils::combine_validities_and_many; +use polars_core::error::PolarsResult; +use polars_core::prelude::row_encode::get_row_encoding_context; +use polars_row::RowsEncoded; + +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, PExecutionContext}; + +#[derive(Clone)] +pub(super) struct RowValues { + current_rows: RowsEncoded, + join_column_eval: Arc>>, + join_columns_material: Vec, + // Location of join columns. + // These column locations need to be dropped from the rhs + pub join_column_idx: Option>, + det_join_idx: bool, +} + +impl RowValues { + pub(super) fn new( + join_column_eval: Arc>>, + det_join_idx: bool, + ) -> Self { + Self { + current_rows: Default::default(), + join_column_eval, + join_column_idx: None, + join_columns_material: vec![], + det_join_idx, + } + } + + pub(super) fn clear(&mut self) { + self.join_columns_material.clear(); + } + + pub(super) fn get_values( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + nulls_equal: bool, + ) -> PolarsResult> { + // Memory should already be cleared on previous iteration. + debug_assert!(self.join_columns_material.is_empty()); + let determine_idx = self.det_join_idx && self.join_column_idx.is_none(); + let mut names = vec![]; + + let mut ctxts = Vec::with_capacity(self.join_column_eval.len()); + for phys_e in self.join_column_eval.iter() { + let s = phys_e.evaluate(chunk, &context.execution_state)?; + let mut s = s.to_physical_repr().rechunk(); + if chunk.data.is_empty() { + s = s.clear() + }; + if determine_idx { + names.push(s.name().to_string()); + } + self.join_columns_material.push(s.array_ref(0).clone()); + ctxts.push(get_row_encoding_context(s.dtype(), false)); + } + + // We determine the indices of the columns that have to be removed + // if swapped the join column is already removed from the `build_df` as that will + // be the rhs one. + if determine_idx { + let mut idx = names + .iter() + .filter_map(|name| chunk.data.get_column_index(name)) + .collect::>(); + // Ensure that it is sorted so that we can later remove columns in + // a predictable order + idx.sort_unstable(); + self.join_column_idx = Some(idx); + } + polars_row::convert_columns_amortized_no_order( + self.join_columns_material[0].len(), // @NOTE: does not work for ZFS + &self.join_columns_material, + &ctxts, + &mut self.current_rows, + ); + + // SAFETY: we keep rows-encode alive + let array = unsafe { self.current_rows.borrow_array() }; + Ok(if nulls_equal { + array + } else { + let validities = self + .join_columns_material + .iter() + .map(|arr| arr.validity()) + .collect::>(); + let validity = combine_validities_and_many(&validities); + array.with_validity_typed(validity) + }) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/memory.rs b/crates/polars-pipe/src/executors/sinks/memory.rs new file mode 100644 index 000000000000..27d10792fb65 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/memory.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use polars_utils::sys::MEMINFO; + +use crate::pipeline::FORCE_OOC; + +const TO_MB: usize = 2 << 19; + +#[derive(Clone)] +pub(super) struct MemTracker { + // available memory at the start of this node + available_mem: Arc, + used_by_node: Arc, + fetch_count: Arc, + thread_count: usize, + available_at_start: usize, + refresh_interval: usize, +} + +impl MemTracker { + pub(super) fn new(thread_count: usize) -> Self { + let refresh_interval = if std::env::var(FORCE_OOC).is_ok() { + 1 + } else { + 64 + }; + + let mut out = Self { + available_mem: Default::default(), + used_by_node: Default::default(), + fetch_count: Arc::new(AtomicUsize::new(1)), + thread_count, + available_at_start: 0, + refresh_interval, + }; + let available = MEMINFO.free() as usize; + out.available_mem.store(available, Ordering::Relaxed); + out.available_at_start = available; + out + } + + /// This shouldn't be called often as this is expensive. + pub fn refresh_memory(&self) { + self.available_mem + .store(MEMINFO.free() as usize, Ordering::Relaxed); + } + + /// Get available memory of the system measured on latest refresh. + pub(super) fn get_available(&self) -> usize { + // once in every n passes we fetch mem usage. + let fetch_count = self.fetch_count.fetch_add(1, Ordering::Relaxed); + + if fetch_count % (self.refresh_interval * self.thread_count) == 0 { + self.refresh_memory() + } + self.available_mem.load(Ordering::Relaxed) + } + + pub(super) fn get_available_latest(&self) -> usize { + self.refresh_memory(); + self.fetch_count.store(0, Ordering::Relaxed); + self.available_mem.load(Ordering::Relaxed) + } + + pub(super) fn free_memory_fraction_since_start(&self) -> f64 { + // We divide first to reduce the precision loss in floats. + // We also add 1.0 to available_at_start to prevent division by zero. + let available_at_start = (self.available_at_start / TO_MB) as f64 + 1.0; + let available = (self.get_available() / TO_MB) as f64; + available / available_at_start + } + + /// Increment the used memory and return the previous value. + pub(super) fn fetch_add(&self, add: usize) -> usize { + self.used_by_node.fetch_add(add, Ordering::Relaxed) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/mod.rs b/crates/polars-pipe/src/executors/sinks/mod.rs new file mode 100644 index 000000000000..86d46eb3082c --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/mod.rs @@ -0,0 +1,43 @@ +pub(crate) mod group_by; +mod io; +mod joins; +mod memory; +mod ordered; +mod output; +mod reproject; +mod slice; +mod sort; +mod utils; + +use std::sync::OnceLock; + +pub(crate) use joins::*; +pub(crate) use ordered::*; +#[cfg(any( + feature = "parquet", + feature = "ipc", + feature = "csv", + feature = "json" +))] +pub(crate) use output::*; +pub(crate) use reproject::*; +pub(crate) use slice::*; +pub(crate) use sort::*; + +// We must strike a balance between cache coherence and resizing costs. +// Overallocation seems a lot more expensive than resizing so we start reasonable small. +const HASHMAP_INIT_SIZE: usize = 64; + +pub(crate) static POLARS_TEMP_DIR: OnceLock = OnceLock::new(); + +pub(crate) fn get_base_temp_dir() -> &'static str { + POLARS_TEMP_DIR.get_or_init(|| { + let tmp = std::env::var("POLARS_TEMP_DIR") + .unwrap_or_else(|_| std::env::temp_dir().to_string_lossy().into_owned()); + + if polars_core::config::verbose() { + eprintln!("Temporary directory path in use: {}", &tmp); + } + tmp + }) +} diff --git a/crates/polars-pipe/src/executors/sinks/ordered.rs b/crates/polars-pipe/src/executors/sinks/ordered.rs new file mode 100644 index 000000000000..09bc10e9046e --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/ordered.rs @@ -0,0 +1,66 @@ +use std::any::Any; + +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::schema::SchemaRef; + +use crate::operators::{ + DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult, chunks_to_df_unchecked, +}; + +// Ensure the data is return in the order it was streamed +#[derive(Clone)] +pub struct OrderedSink { + chunks: Vec, + schema: SchemaRef, +} + +impl OrderedSink { + pub fn new(schema: SchemaRef) -> Self { + OrderedSink { + chunks: vec![], + schema, + } + } + + fn sort(&mut self) { + self.chunks.sort_unstable_by_key(|chunk| chunk.chunk_index); + } +} + +impl Sink for OrderedSink { + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + // don't add empty dataframes + if chunk.data.height() > 0 || self.chunks.is_empty() { + self.chunks.push(chunk); + } + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + let other = other.as_any().downcast_ref::().unwrap(); + self.chunks.extend_from_slice(&other.chunks); + self.sort(); + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + if self.chunks.is_empty() { + return Ok(FinalizedSink::Finished(DataFrame::empty_with_schema( + &self.schema, + ))); + } + self.sort(); + + let chunks = std::mem::take(&mut self.chunks); + Ok(FinalizedSink::Finished(chunks_to_df_unchecked(chunks))) + } + fn as_any(&mut self) -> &mut dyn Any { + self + } + fn fmt(&self) -> &str { + "ordered_sink" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/output/csv.rs b/crates/polars-pipe/src/executors/sinks/output/csv.rs new file mode 100644 index 000000000000..55fa347cf9d2 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/output/csv.rs @@ -0,0 +1,67 @@ +use std::path::Path; + +use crossbeam_channel::bounded; +use polars_core::prelude::*; +use polars_io::SerWriter; +use polars_io::cloud::CloudOptions; +use polars_io::csv::write::{CsvWriter, CsvWriterOptions}; +use polars_io::utils::file::try_get_writeable; + +use crate::executors::sinks::output::file_sink::{FilesSink, SinkWriter, init_writer_thread}; +use crate::pipeline::morsels_per_sink; + +pub struct CsvSink {} +impl CsvSink { + #[allow(clippy::new_ret_no_self)] + pub fn new( + path: &Path, + options: CsvWriterOptions, + schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = CsvWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?) + .include_bom(options.include_bom) + .include_header(options.include_header) + .with_separator(options.serialize_options.separator) + .with_line_terminator(options.serialize_options.line_terminator) + .with_quote_char(options.serialize_options.quote_char) + .with_batch_size(options.batch_size) + .with_datetime_format(options.serialize_options.datetime_format) + .with_date_format(options.serialize_options.date_format) + .with_time_format(options.serialize_options.time_format) + .with_float_scientific(options.serialize_options.float_scientific) + .with_float_precision(options.serialize_options.float_precision) + .with_null_value(options.serialize_options.null) + .with_quote_style(options.serialize_options.quote_style) + .n_threads(1) + .batched(schema)?; + + let writer = Box::new(writer) as Box; + + let morsels_per_sink = morsels_per_sink(); + let backpressure = morsels_per_sink * 2; + let (sender, receiver) = bounded(backpressure); + + let io_thread_handle = Arc::new(Some(init_writer_thread( + receiver, + writer, + true, + morsels_per_sink, + ))); + + Ok(FilesSink { + sender, + io_thread_handle, + }) + } +} + +impl SinkWriter for polars_io::csv::write::BatchedWriter { + fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + self.write_batch(df) + } + + fn _finish(&mut self) -> PolarsResult<()> { + self.finish() + } +} diff --git a/crates/polars-pipe/src/executors/sinks/output/file_sink.rs b/crates/polars-pipe/src/executors/sinks/output/file_sink.rs new file mode 100644 index 000000000000..02b8d2a26851 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/output/file_sink.rs @@ -0,0 +1,121 @@ +use std::any::Any; +use std::thread::JoinHandle; + +use crossbeam_channel::{Receiver, Sender}; +use polars_core::prelude::*; + +use crate::operators::{ + DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult, StreamingVstacker, +}; + +pub(super) trait SinkWriter { + fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()>; + + fn _finish(&mut self) -> PolarsResult<()>; +} + +pub(super) fn init_writer_thread( + receiver: Receiver>, + mut writer: Box, + maintain_order: bool, + // this is used to determine when a batch of chunks should be written to disk + // all chunks per push should be collected to determine in which order they should + // be written + morsels_per_sink: usize, +) -> JoinHandle> { + std::thread::spawn(move || -> PolarsResult<()> { + // keep chunks around until all chunks per sink are written + // then we write them all at once. + let mut chunks = Vec::with_capacity(morsels_per_sink); + let mut vstacker = StreamingVstacker::default(); + + while let Ok(chunk) = receiver.recv() { + // `last_write` indicates if all chunks are processed, e.g. this is the last write. + // this is when `write_chunks` is called with `None`. + let last_write = if let Some(chunk) = chunk { + chunks.push(chunk); + false + } else { + true + }; + + if chunks.len() == morsels_per_sink || last_write { + if maintain_order { + chunks.sort_by_key(|chunk| chunk.chunk_index); + } + + for chunk in chunks.drain(0..) { + for mut df in vstacker.add(chunk.data) { + // The dataframe may only be a single, large chunk, in + // which case we don't want to bother with copying it... + if df.first_col_n_chunks() > 1 { + df.as_single_chunk(); + } + writer._write_batch(&df)?; + } + } + // all chunks are written remove them + chunks.clear(); + + if last_write { + if let Some(mut df) = vstacker.finish() { + if df.first_col_n_chunks() > 1 { + df.as_single_chunk(); + } + writer._write_batch(&df)?; + } + writer._finish()?; + return Ok(()); + } + } + } + Ok(()) + }) +} + +// Ensure the data is return in the order it was streamed +#[derive(Clone)] +pub struct FilesSink { + pub(crate) sender: Sender>, + pub(crate) io_thread_handle: Arc>>>, +} + +impl Sink for FilesSink { + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + // don't add empty dataframes + if chunk.data.height() > 0 { + self.sender.send(Some(chunk)).unwrap(); + }; + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, _other: &mut dyn Sink) { + // already synchronized + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + // `None` indicates that we can flush all remaining chunks. + self.sender.send(None).unwrap(); + + // wait until all files written + // some unwrap/mut kung-fu to get a hold of `self` + Arc::get_mut(&mut self.io_thread_handle) + .unwrap() + .take() + .unwrap() + .join() + .unwrap()?; + + // return a dummy dataframe; + Ok(FinalizedSink::Finished(Default::default())) + } + fn as_any(&mut self) -> &mut dyn Any { + self + } + fn fmt(&self) -> &str { + "parquet_sink" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/output/ipc.rs b/crates/polars-pipe/src/executors/sinks/output/ipc.rs new file mode 100644 index 000000000000..c12780eaecd2 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/output/ipc.rs @@ -0,0 +1,55 @@ +use std::path::Path; + +use cloud::CloudOptions; +use crossbeam_channel::bounded; +use file::try_get_writeable; +use polars_core::prelude::*; +use polars_io::ipc::IpcWriterOptions; +use polars_io::prelude::*; + +use crate::executors::sinks::output::file_sink::{FilesSink, SinkWriter, init_writer_thread}; +use crate::pipeline::morsels_per_sink; + +pub struct IpcSink {} +impl IpcSink { + #[allow(clippy::new_ret_no_self)] + pub fn new( + path: &Path, + options: IpcWriterOptions, + schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = IpcWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?) + .with_compression(options.compression) + .batched(schema)?; + + let writer = Box::new(writer) as Box; + + let morsels_per_sink = morsels_per_sink(); + let backpressure = morsels_per_sink * 2; + let (sender, receiver) = bounded(backpressure); + + let io_thread_handle = Arc::new(Some(init_writer_thread( + receiver, + writer, + true, + morsels_per_sink, + ))); + + Ok(FilesSink { + sender, + io_thread_handle, + }) + } +} + +impl SinkWriter for polars_io::ipc::BatchedWriter { + fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + self.write_batch(df) + } + + fn _finish(&mut self) -> PolarsResult<()> { + self.finish()?; + Ok(()) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/output/json.rs b/crates/polars-pipe/src/executors/sinks/output/json.rs new file mode 100644 index 000000000000..56a4b77a9305 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/output/json.rs @@ -0,0 +1,50 @@ +use std::path::Path; + +use crossbeam_channel::bounded; +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_io::json::{BatchedWriter, JsonWriterOptions}; +use polars_io::utils::file::try_get_writeable; + +use crate::executors::sinks::output::file_sink::{FilesSink, SinkWriter, init_writer_thread}; +use crate::pipeline::morsels_per_sink; + +impl SinkWriter for BatchedWriter { + fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + self.write_batch(df) + } + + fn _finish(&mut self) -> PolarsResult<()> { + Ok(()) + } +} + +pub struct JsonSink {} +impl JsonSink { + #[allow(clippy::new_ret_no_self)] + pub fn new( + path: &Path, + _options: JsonWriterOptions, + _schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = BatchedWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?); + let writer = Box::new(writer) as Box; + + let morsels_per_sink = morsels_per_sink(); + let backpressure = morsels_per_sink * 2; + let (sender, receiver) = bounded(backpressure); + + let io_thread_handle = Arc::new(Some(init_writer_thread( + receiver, + writer, + true, + morsels_per_sink, + ))); + + Ok(FilesSink { + sender, + io_thread_handle, + }) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/output/mod.rs b/crates/polars-pipe/src/executors/sinks/output/mod.rs new file mode 100644 index 000000000000..602e525fa853 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/output/mod.rs @@ -0,0 +1,24 @@ +#[cfg(feature = "csv")] +mod csv; +#[cfg(any( + feature = "parquet", + feature = "ipc", + feature = "csv", + feature = "json" +))] +mod file_sink; +#[cfg(feature = "ipc")] +mod ipc; +#[cfg(feature = "json")] +mod json; +#[cfg(feature = "parquet")] +mod parquet; + +#[cfg(feature = "csv")] +pub use csv::*; +#[cfg(feature = "ipc")] +pub use ipc::*; +#[cfg(feature = "json")] +pub use json::*; +#[cfg(feature = "parquet")] +pub use parquet::*; diff --git a/crates/polars-pipe/src/executors/sinks/output/parquet.rs b/crates/polars-pipe/src/executors/sinks/output/parquet.rs new file mode 100644 index 000000000000..2dbfd88910e1 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/output/parquet.rs @@ -0,0 +1,161 @@ +use std::any::Any; +use std::path::Path; +use std::thread::JoinHandle; + +use crossbeam_channel::{Receiver, Sender, bounded}; +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_io::parquet::write::{ + BatchedWriter, ParquetWriteOptions, ParquetWriter, RowGroupIterColumns, +}; +use polars_io::utils::file::try_get_writeable; +use polars_utils::file::WriteClose; + +use crate::executors::sinks::output::file_sink::SinkWriter; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; +use crate::pipeline::morsels_per_sink; + +type RowGroups = Vec>; + +pub(super) fn init_row_group_writer_thread( + receiver: Receiver>, + writer: Arc>, + // this is used to determine when a batch of chunks should be written to disk + // all chunks per push should be collected to determine in which order they should + // be written + morsels_per_sink: usize, +) -> JoinHandle<()> +where + W: std::io::Write + Send + 'static, +{ + std::thread::spawn(move || { + // keep chunks around until all chunks per sink are written + // then we write them all at once. + let mut batched = Vec::with_capacity(morsels_per_sink); + while let Ok(rgs) = receiver.recv() { + // `last_write` indicates if all chunks are processed, e.g. this is the last write. + // this is when `write_chunks` is called with `None`. + let last_write = if let Some(rgs) = rgs { + batched.push(rgs); + false + } else { + true + }; + + if batched.len() == morsels_per_sink || last_write { + batched.sort_by_key(|chunk| chunk.0); + + for (_, rg) in batched.drain(0..) { + writer.write_row_groups(rg).unwrap() + } + } + if last_write { + writer.finish().unwrap(); + return; + } + } + }) +} + +#[derive(Clone)] +pub struct ParquetSink { + writer: Arc>>, + io_thread_handle: Arc>>, + sender: Sender>, +} +impl ParquetSink { + #[allow(clippy::new_ret_no_self)] + pub fn new( + path: &Path, + options: ParquetWriteOptions, + schema: &Schema, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let writer = ParquetWriter::new(try_get_writeable(path.to_str().unwrap(), cloud_options)?) + .with_compression(options.compression) + .with_data_page_size(options.data_page_size) + .with_statistics(options.statistics) + .with_row_group_size(options.row_group_size) + // This is important! Otherwise we will deadlock + // See: #7074 + .set_parallel(false) + .batched(schema)?; + + let writer = Arc::new(writer); + let morsels_per_sink = morsels_per_sink(); + + let backpressure = morsels_per_sink * 4; + let (sender, receiver) = bounded(backpressure); + + let io_thread_handle = Arc::new(Some(init_row_group_writer_thread( + receiver, + writer.clone(), + morsels_per_sink, + ))); + + Ok(Self { + writer, + io_thread_handle, + sender, + }) + } +} + +impl Sink for ParquetSink { + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + // Encode and compress row-groups on every thread. + let row_groups = self + .writer + .encode_and_compress(&chunk.data) + .collect::>>()?; + // Only then send the compressed pages to the writer. + self.sender + .send(Some((chunk.chunk_index, row_groups))) + .unwrap(); + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, _other: &mut dyn Sink) { + // Nothing to do + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + // `None` indicates that we can flush all remaining chunks. + self.sender.send(None).unwrap(); + + // wait until all files written + // some unwrap/mut kung-fu to get a hold of `self` + Arc::get_mut(&mut self.io_thread_handle) + .unwrap() + .take() + .unwrap() + .join() + .unwrap(); + + // return a dummy dataframe; + Ok(FinalizedSink::Finished(Default::default())) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "parquet_sink" + } +} + +impl SinkWriter for polars_io::parquet::write::BatchedWriter { + fn _write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> { + self.write_batch(df) + } + + fn _finish(&mut self) -> PolarsResult<()> { + self.finish()?; + Ok(()) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/reproject.rs b/crates/polars-pipe/src/executors/sinks/reproject.rs new file mode 100644 index 000000000000..bd9553b75f97 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/reproject.rs @@ -0,0 +1,59 @@ +use std::any::Any; + +use polars_core::schema::SchemaRef; + +use crate::executors::sources::ReProjectSource; +use crate::operators::{ + DataChunk, FinalizedSink, PExecutionContext, PolarsResult, Sink, SinkResult, +}; + +/// A sink that will ensure we keep the schema order +pub(crate) struct ReProjectSink { + schema: SchemaRef, + sink: Box, +} + +impl ReProjectSink { + pub(crate) fn new(schema: SchemaRef, sink: Box) -> Self { + Self { schema, sink } + } +} + +impl Sink for ReProjectSink { + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + self.sink.sink(context, chunk) + } + + fn combine(&mut self, other: &mut dyn Sink) { + let other = other.as_any().downcast_mut::().unwrap(); + self.sink.combine(other.sink.as_mut()) + } + + fn split(&self, thread_no: usize) -> Box { + let sink = self.sink.split(thread_no); + Box::new(Self { + schema: self.schema.clone(), + sink, + }) + } + + fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { + Ok(match self.sink.finalize(context)? { + FinalizedSink::Finished(df) => { + FinalizedSink::Finished(df.select(self.schema.iter_names_cloned())?) + }, + FinalizedSink::Source(source) => { + FinalizedSink::Source(Box::new(ReProjectSource::new(self.schema.clone(), source))) + }, + _ => unimplemented!(), + }) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "re-project-sink" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/slice.rs b/crates/polars-pipe/src/executors/sinks/slice.rs new file mode 100644 index 000000000000..e44f4a26a3ec --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/slice.rs @@ -0,0 +1,102 @@ +use std::any::Any; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::schema::SchemaRef; + +use crate::operators::{ + DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult, chunks_to_df_unchecked, +}; + +#[derive(Clone)] +// Ensure the data is return in the order it was streamed +pub struct SliceSink { + offset: Arc, + current_len: Arc, + len: usize, + chunks: Arc>>, + schema: SchemaRef, +} + +impl SliceSink { + pub fn new(offset: u64, len: usize, schema: SchemaRef) -> SliceSink { + let offset = Arc::new(AtomicUsize::new(offset as usize)); + SliceSink { + offset, + current_len: Arc::default(), + len, + chunks: Default::default(), + schema, + } + } + + fn sort(&mut self) { + let mut chunks = self.chunks.lock().unwrap(); + chunks.sort_unstable_by_key(|chunk| chunk.chunk_index); + } +} + +impl Sink for SliceSink { + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + // there is contention here. + + let height = chunk.data.height(); + let mut chunks = self.chunks.lock().unwrap(); + // don't add empty dataframes + if height > 0 || chunks.is_empty() { + // TODO! deal with offset + // this is a bit harder as the chunks come in randomly + + // we are under a mutex lock here + // so ordering doesn't seem too important + let current_offset = self.offset.load(Ordering::Acquire); + let current_len = self.current_len.fetch_add(height, Ordering::Acquire); + + // always push as they come in random order + + chunks.push(chunk); + + if current_len > (self.len + current_offset) { + Ok(SinkResult::Finished) + } else { + Ok(SinkResult::CanHaveMoreInput) + } + } else { + Ok(SinkResult::CanHaveMoreInput) + } + } + + fn combine(&mut self, _other: &mut dyn Sink) { + // no-op + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + self.sort(); + let chunks = std::mem::take(&mut self.chunks); + let mut chunks = chunks.lock().unwrap(); + let chunks: Vec = std::mem::take(chunks.as_mut()); + if chunks.is_empty() { + return Ok(FinalizedSink::Finished(DataFrame::empty_with_schema( + &self.schema, + ))); + } + + let df = chunks_to_df_unchecked(chunks); + let offset = self.offset.load(Ordering::Acquire) as i64; + Ok(FinalizedSink::Finished(df.slice(offset, self.len))) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "slice_sink" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/sort/mod.rs b/crates/polars-pipe/src/executors/sinks/sort/mod.rs new file mode 100644 index 000000000000..6b5f952aaffa --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/sort/mod.rs @@ -0,0 +1,7 @@ +mod ooc; +mod sink; +mod sink_multiple; +mod source; + +pub(crate) use sink::SortSink; +pub(crate) use sink_multiple::SortSinkMultiple; diff --git a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs new file mode 100644 index 000000000000..9c6f33cb1469 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs @@ -0,0 +1,263 @@ +use std::path::Path; +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::time::Instant; + +use crossbeam_queue::SegQueue; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_core::utils::{ + accumulate_dataframes_vertical_unchecked, accumulate_dataframes_vertical_unchecked_optional, +}; +use polars_io::SerReader; +use polars_io::ipc::IpcReader; +use polars_ops::prelude::*; +use rayon::prelude::*; + +use crate::executors::sinks::io::{DfIter, IOThread}; +use crate::executors::sinks::memory::MemTracker; +use crate::executors::sinks::sort::source::SortSource; +use crate::operators::FinalizedSink; + +pub(super) fn read_df(path: &Path) -> PolarsResult { + let file = polars_utils::open_file(path)?; + IpcReader::new(file).set_rechunk(false).finish() +} + +// Utility to buffer partitioned dataframes +// this ensures we don't write really small dataframes +// and amortize IO cost +#[derive(Default)] +struct PartitionSpillBuf { + // keep track of the length + // that's cheaper than iterating the linked list + len: AtomicU32, + size: AtomicU64, + chunks: SegQueue, +} + +impl PartitionSpillBuf { + fn push(&self, df: DataFrame, spill_limit: u64) -> Option { + debug_assert!(df.height() > 0); + let size = self + .size + .fetch_add(df.estimated_size() as u64, Ordering::Relaxed); + let len = self.len.fetch_add(1, Ordering::Relaxed); + self.chunks.push(df); + if size > spill_limit { + // Reset all statistics. + self.len.store(0, Ordering::Relaxed); + self.size.store(0, Ordering::Relaxed); + // other threads can be pushing while we drain + // so we pop no more than the current size. + let pop_max = len; + let iter = (0..pop_max).flat_map(|_| self.chunks.pop()); + // Due to race conditions, the chunks can already be popped, so we use optional. + accumulate_dataframes_vertical_unchecked_optional(iter) + } else { + None + } + } + + fn finish(&self) -> Option { + if !self.chunks.is_empty() { + let len = self.len.load(Ordering::Relaxed) + 1; + let mut out = Vec::with_capacity(len as usize); + while let Some(df) = self.chunks.pop() { + out.push(df) + } + Some(accumulate_dataframes_vertical_unchecked(out)) + } else { + None + } + } +} + +pub(crate) struct PartitionSpiller { + partitions: Vec, + // Spill limit in bytes. + spill_limit: u64, +} + +impl PartitionSpiller { + fn new(n_parts: usize, spill_limit: u64) -> Self { + let mut partitions = vec![]; + partitions.resize_with(n_parts + 1, PartitionSpillBuf::default); + Self { + partitions, + spill_limit, + } + } + + fn push(&self, partition: usize, df: DataFrame) -> Option { + self.partitions[partition].push(df, self.spill_limit) + } + + pub(crate) fn get(&self, partition: usize) -> Option { + self.partitions[partition].finish() + } + + pub(crate) fn len(&self) -> usize { + self.partitions.len() + } + + #[cfg(debug_assertions)] + // Used in testing only. + fn spill_all(&self, io_thread: &IOThread) { + let min_len = std::cmp::max(self.partitions.len() / POOL.current_num_threads(), 2); + POOL.install(|| { + self.partitions + .par_iter() + .with_min_len(min_len) + .enumerate() + .for_each(|(part, part_buf)| { + if let Some(df) = part_buf.finish() { + io_thread.dump_partition_local(part as IdxSize, df) + } + }) + }); + eprintln!("PARTITIONED FORCE SPILLED") + } +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn sort_ooc( + io_thread: IOThread, + // these partitions are the samples + // these are not yet assigned to a buckets + samples: Series, + idx: usize, + descending: bool, + nulls_last: bool, + slice: Option<(i64, usize)>, + verbose: bool, + memtrack: MemTracker, + ooc_start: Instant, +) -> PolarsResult { + let now = Instant::now(); + let multithreaded_partition = std::env::var("POLARS_OOC_SORT_PAR_PARTITION").is_ok(); + let spill_size = std::env::var("POLARS_OOC_SORT_SPILL_SIZE") + .map(|v| v.parse::().expect("integer")) + .unwrap_or(1 << 26); + let samples = samples.to_physical_repr().into_owned(); + let spill_size = std::cmp::min( + memtrack.get_available_latest() / (samples.len() * 3 + 1), + spill_size, + ); + + // we collect as I am not sure that if we write to the same directory the + // iterator will read those also. + // We don't want to merge files we just written to disk + let dir = &io_thread.dir; + let files = std::fs::read_dir(dir)?.collect::>>()?; + + if verbose { + eprintln!("spill size: {} mb", spill_size / 1024 / 1024); + eprintln!("processing {} files", files.len()); + } + + let partitions_spiller = PartitionSpiller::new(samples.len(), spill_size as u64); + + POOL.install(|| { + files.par_iter().try_for_each(|entry| { + let path = entry.path(); + // don't read the lock file + if path.ends_with(".lock") { + return PolarsResult::Ok(()); + } + let df = read_df(&path)?; + + let sort_col = &df.get_columns()[idx]; + let assigned_parts = + det_partitions(sort_col.as_materialized_series(), &samples, descending); + + // partition the dataframe into proper buckets + let (iter, unique_assigned_parts) = + partition_df(df, &assigned_parts, multithreaded_partition)?; + for (part, df) in unique_assigned_parts.into_no_null_iter().zip(iter) { + if let Some(df) = partitions_spiller.push(part as usize, df) { + io_thread.dump_partition_local(part, df) + } + } + io_thread.clean(path); + PolarsResult::Ok(()) + }) + })?; + if verbose { + eprintln!("partitioning sort took: {:?}", now.elapsed()); + } + + // Branch for testing so we hit different parts in the Source phase. + #[cfg(debug_assertions)] + { + if std::env::var("POLARS_SPILL_SORT_PARTITIONS").is_ok() { + partitions_spiller.spill_all(&io_thread) + } + } + + let files = std::fs::read_dir(dir)? + .flat_map(|entry| { + entry + .map(|entry| { + let path = entry.path(); + if path.is_dir() { + let dirname = path.file_name().unwrap(); + let partition = dirname.to_string_lossy().parse::().unwrap(); + Some((partition, path)) + } else { + None + } + }) + .transpose() + }) + .collect::>>()?; + + let source = SortSource::new( + files, + idx, + descending, + nulls_last, + slice, + verbose, + io_thread, + memtrack, + ooc_start, + partitions_spiller, + ); + Ok(FinalizedSink::Source(Box::new(source))) +} + +fn det_partitions(s: &Series, partitions: &Series, descending: bool) -> IdxCa { + let s = s.to_physical_repr(); + + search_sorted(partitions, &s, SearchSortedSide::Any, descending).unwrap() +} + +fn partition_df( + df: DataFrame, + partitions: &IdxCa, + multithreaded: bool, +) -> PolarsResult<(DfIter, IdxCa)> { + let groups = partitions.group_tuples(multithreaded, false)?; + let partitions = unsafe { partitions.clone().into_series().agg_first(&groups) }; + let partitions = partitions.idx().unwrap().clone(); + + let out = match groups { + GroupsType::Idx(idx) => { + let iter = idx.into_iter().map(move |(_, group)| { + // groups are in bounds and sorted + unsafe { + df._take_unchecked_slice_sorted(&group, multithreaded, IsSorted::Ascending) + } + }); + Box::new(iter) as DfIter + }, + GroupsType::Slice { groups, .. } => { + let iter = groups + .into_iter() + .map(move |[first, len]| df.slice(first as i64, len as usize)); + Box::new(iter) as DfIter + }, + }; + Ok((out, partitions)) +} diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs new file mode 100644 index 000000000000..77680975ab4b --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -0,0 +1,253 @@ +use std::any::Any; +use std::sync::{Arc, RwLock}; +use std::time::Instant; + +use polars_core::chunked_array::ops::SortMultipleOptions; +use polars_core::config::verbose; +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::prelude::{AnyValue, SchemaRef, Series, SortOptions}; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_utils::pl_str::PlSmallStr; + +use crate::executors::sinks::io::{IOThread, block_thread_until_io_thread_done}; +use crate::executors::sinks::memory::MemTracker; +use crate::executors::sinks::sort::ooc::sort_ooc; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; +use crate::pipeline::{FORCE_OOC, morsels_per_sink}; + +pub struct SortSink { + schema: SchemaRef, + chunks: Vec, + // Stores available memory in the system at the start of this sink. + // and stores the memory used by this sink. + mem_track: MemTracker, + // sort in-memory or out-of-core + ooc: bool, + // when ooc, we write to disk using an IO thread + // RwLock as we want to have multiple readers at once. + io_thread: Arc>>, + // location in the dataframe of the columns to sort by + sort_idx: usize, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + // Statistics + // sampled values so we can find the distribution. + dist_sample: Vec>, + // total rows accumulated in current chunk + current_chunk_rows: usize, + // total bytes of tables in current chunks + current_chunks_size: usize, + // Start time of OOC phase. + ooc_start: Option, +} + +impl SortSink { + pub(crate) fn new( + sort_idx: usize, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + schema: SchemaRef, + ) -> Self { + // for testing purposes + let ooc = std::env::var(FORCE_OOC).is_ok(); + let n_morsels_per_sink = morsels_per_sink(); + + let mut out = Self { + schema, + chunks: Default::default(), + mem_track: MemTracker::new(n_morsels_per_sink), + ooc, + io_thread: Default::default(), + sort_idx, + slice, + sort_options, + dist_sample: vec![], + current_chunk_rows: 0, + current_chunks_size: 0, + ooc_start: None, + }; + if ooc { + if verbose() { + eprintln!("OOC sort forced"); + } + out.init_ooc().unwrap(); + } + out + } + + fn init_ooc(&mut self) -> PolarsResult<()> { + if verbose() { + eprintln!("OOC sort started"); + } + self.ooc_start = Some(Instant::now()); + self.ooc = true; + + // start IO thread + let mut iot = self.io_thread.write().unwrap(); + if iot.is_none() { + *iot = Some(IOThread::try_new(self.schema.clone(), "sort")?) + } + Ok(()) + } + + fn store_chunk(&mut self, chunk: DataChunk) -> PolarsResult<()> { + let chunk_bytes = chunk.data.estimated_size(); + if !self.ooc { + let used = self.mem_track.fetch_add(chunk_bytes); + let free = self.mem_track.get_available(); + + // we need some free memory to be able to sort + // so we keep 3x the sort data size before we go out of core + if used * 3 > free { + self.init_ooc()?; + self.dump(true)?; + } + }; + // don't add empty dataframes + if chunk.data.height() > 0 || self.chunks.is_empty() { + self.current_chunks_size += chunk_bytes; + self.current_chunk_rows += chunk.data.height(); + self.chunks.push(chunk.data); + } + Ok(()) + } + + fn dump(&mut self, force: bool) -> PolarsResult<()> { + let larger_than_32_mb = self.current_chunks_size > (1 << 25); + if (force || larger_than_32_mb) && !self.chunks.is_empty() { + // into a single chunk because multiple file IO's is expensive + // and may lead to many smaller files in ooc-sort later, which is exponentially + // expensive + let df = accumulate_dataframes_vertical_unchecked(self.chunks.drain(..)); + if df.height() > 0 { + // SAFETY: we just asserted height > 0 + let sample = unsafe { + let s = &df.get_columns()[self.sort_idx]; + s.to_physical_repr().get_unchecked(0).into_static() + }; + self.dist_sample.push(sample); + + let iot = self.io_thread.read().unwrap(); + let iot = iot.as_ref().unwrap(); + + iot.dump_chunk(df); + + // reset sizes + self.current_chunk_rows = 0; + self.current_chunks_size = 0; + } + } + Ok(()) + } +} + +impl Sink for SortSink { + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + self.store_chunk(chunk)?; + + if self.ooc { + self.dump(false)?; + } + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, other: &mut dyn Sink) { + let other = other.as_any().downcast_mut::().unwrap(); + if let Some(ooc_start) = other.ooc_start { + self.ooc_start = Some(ooc_start); + } + self.chunks.extend(std::mem::take(&mut other.chunks)); + self.ooc |= other.ooc; + self.dist_sample + .extend(std::mem::take(&mut other.dist_sample)); + + if self.ooc { + self.dump(false).unwrap() + } + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(Self { + schema: self.schema.clone(), + chunks: Default::default(), + mem_track: self.mem_track.clone(), + ooc: self.ooc, + io_thread: self.io_thread.clone(), + sort_idx: self.sort_idx, + slice: self.slice, + sort_options: self.sort_options.clone(), + dist_sample: vec![], + current_chunk_rows: 0, + current_chunks_size: 0, + ooc_start: self.ooc_start, + }) + } + + fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { + if self.ooc { + // spill everything + self.dump(true).unwrap(); + let mut lock = self.io_thread.write().unwrap(); + let io_thread = lock.take().unwrap(); + + let dist = Series::from_any_values(PlSmallStr::EMPTY, &self.dist_sample, true).unwrap(); + let dist = dist.sort_with(SortOptions::from(&self.sort_options))?; + + let instant = self.ooc_start.unwrap(); + if context.verbose { + eprintln!("finished sinking into OOC sort in {:?}", instant.elapsed()); + } + block_thread_until_io_thread_done(&io_thread); + if context.verbose { + eprintln!("full file dump of OOC sort took {:?}", instant.elapsed()); + } + + sort_ooc( + io_thread, + dist, + self.sort_idx, + self.sort_options.descending[0], + self.sort_options.nulls_last[0], + self.slice, + context.verbose, + self.mem_track.clone(), + instant, + ) + } else { + let chunks = std::mem::take(&mut self.chunks); + let df = accumulate_dataframes_vertical_unchecked(chunks); + let df = sort_accumulated( + df, + self.sort_idx, + self.slice, + SortOptions::from(&self.sort_options), + )?; + Ok(FinalizedSink::Finished(df)) + } + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "sort" + } +} + +pub(super) fn sort_accumulated( + mut df: DataFrame, + sort_idx: usize, + slice: Option<(i64, usize)>, + sort_options: SortOptions, +) -> PolarsResult { + // This is needed because we can have empty blocks and we require chunks to have single chunks. + df.as_single_chunk_par(); + let sort_column = df.get_columns()[sort_idx].clone(); + df.sort_impl( + vec![sort_column], + SortMultipleOptions::from(&sort_options), + slice, + ) +} diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs new file mode 100644 index 000000000000..24cd778a89f6 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -0,0 +1,361 @@ +use std::any::Any; + +use arrow::array::BinaryArray; +use polars_core::prelude::sort::_broadcast_bools; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_row::decode::decode_rows_from_binary; +use polars_row::{RowEncodingContext, RowEncodingOptions}; + +use self::row_encode::get_row_encoding_context; +use super::*; +use crate::operators::{ + DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult, Source, SourceResult, +}; +const POLARS_SORT_COLUMN: &str = "__POLARS_SORT_COLUMN"; + +fn get_sort_fields( + sort_idx: &[usize], + sort_options: &SortMultipleOptions, +) -> Vec { + let mut descending = sort_options.descending.clone(); + let mut nulls_last = sort_options.nulls_last.clone(); + + _broadcast_bools(sort_idx.len(), &mut descending); + _broadcast_bools(sort_idx.len(), &mut nulls_last); + + descending + .into_iter() + .zip(nulls_last) + .map(|(descending, nulls_last)| RowEncodingOptions::new_sorted(descending, nulls_last)) + .collect() +} + +fn sort_by_idx(values: &[V], idx: &[usize]) -> Vec { + assert_eq!(values.len(), idx.len()); + + let mut tmp = values + .iter() + .cloned() + .zip(idx.iter().copied()) + .collect::>(); + tmp.sort_unstable_by_key(|k| k.1); + tmp.into_iter().map(|k| k.0).collect() +} + +#[allow(clippy::too_many_arguments)] +fn finalize_dataframe( + df: &mut DataFrame, + sort_idx: &[usize], + sort_options: &SortMultipleOptions, + sort_dtypes: Option<&[ArrowDataType]>, + rows: &mut Vec<&'static [u8]>, + sort_opts: &[RowEncodingOptions], + sort_dicts: &[Option], + schema: &Schema, +) { + // pop the encoded sort column + // SAFETY: We only pop a value + let encoded = unsafe { df.get_columns_mut() }.pop().unwrap(); + + // we decode the row-encoded binary column + // this will be decoded into multiple columns + // these are the columns we sorted by + // those need to be inserted at the `sort_idx` position + // in the `DataFrame`. + let sort_dtypes = sort_dtypes.expect("should be set if 'can_decode'"); + + let encoded = encoded.binary_offset().unwrap(); + assert_eq!(encoded.chunks().len(), 1); + let arr = encoded.downcast_iter().next().unwrap(); + + // SAFETY: + // temporary extend lifetime + // this is safe as the lifetime in rows stays bound to this scope + let arrays = unsafe { + let arr = std::mem::transmute::<&'_ BinaryArray, &'static BinaryArray>(arr); + decode_rows_from_binary(arr, sort_opts, sort_dicts, sort_dtypes, rows) + }; + rows.clear(); + + let arrays = sort_by_idx(&arrays, sort_idx); + let mut sort_idx = sort_idx.to_vec(); + sort_idx.sort_unstable(); + + for (&sort_idx, arr) in sort_idx.iter().zip(arrays) { + let (name, logical_dtype) = schema.get_at_index(sort_idx).unwrap(); + assert_eq!( + logical_dtype.to_physical(), + DataType::from_arrow_dtype(arr.dtype()) + ); + let col = unsafe { + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![arr], logical_dtype) + } + .into_column(); + + // SAFETY: col has the same length as the df height because it was popped from df. + unsafe { df.get_columns_mut() }.insert(sort_idx, col); + df.clear_schema(); + } + + // SAFETY: We just change the sorted flag. + let first_sort_col = &mut unsafe { df.get_columns_mut() }[sort_idx[0]]; + let flag = if sort_options.descending[0] { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + first_sort_col.set_sorted_flag(flag) +} + +/// This struct will dispatch all sorting to `SortSink` +/// But before it does that it will encode the column that we +/// must sort by and set that row-encoded column as last +/// column in the data chunks. +/// +/// Once the sorting is finished it adapts the result so that +/// the encoded column is removed +pub struct SortSinkMultiple { + output_schema: SchemaRef, + sort_idx: Arc<[usize]>, + sort_sink: Box, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + // Needed for encoding + sort_opts: Arc<[RowEncodingOptions]>, + sort_dicts: Arc<[Option]>, + sort_dtypes: Option>, + // amortize allocs + sort_column: Vec, +} + +impl SortSinkMultiple { + pub(crate) fn new( + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + output_schema: SchemaRef, + sort_idx: Vec, + ) -> PolarsResult { + let mut schema = (*output_schema).clone(); + + let sort_ctxts = sort_idx + .iter() + .map(|i| { + let (_, dtype) = schema.get_at_index(*i).unwrap(); + get_row_encoding_context(dtype, true) + }) + .collect::>(); + + polars_ensure!(sort_idx.iter().collect::>().len() == sort_idx.len(), ComputeError: "only supports sorting by unique columns"); + + let mut dtypes = vec![DataType::Null; sort_idx.len()]; + + // we remove columns by index, but then the indices aren't correct anymore + // so we do it in the proper order and keep track of the indices removed + let mut sorted_sort_idx = sort_idx.iter().copied().enumerate().collect::>(); + // Sort by `sort_idx`. + sorted_sort_idx.sort_unstable_by_key(|k| k.1); + // remove the sort indices as we will encode them into the sort binary + for (iterator_i, (original_idx, sort_i)) in sorted_sort_idx.iter().enumerate() { + dtypes[*original_idx] = schema.shift_remove_index(*sort_i - iterator_i).unwrap().1; + } + let sort_dtypes = Some(dtypes.into()); + schema.with_column(POLARS_SORT_COLUMN.into(), DataType::BinaryOffset); + let sort_fields = get_sort_fields(&sort_idx, &sort_options); + + // don't set descending and nulls last as this + // will be solved by the row encoding + let sort_sink = Box::new(SortSink::new( + // we will set the last column as sort column + schema.len() - 1, + slice, + sort_options + .clone() + .with_order_descending(false) + .with_nulls_last(false) + .with_maintain_order(false), + Arc::new(schema), + )); + + Ok(SortSinkMultiple { + sort_sink, + slice, + sort_options, + sort_idx: Arc::from(sort_idx), + sort_opts: Arc::from(sort_fields), + sort_dicts: Arc::from(sort_ctxts), + sort_dtypes, + sort_column: vec![], + output_schema, + }) + } + + fn encode(&mut self, chunk: &mut DataChunk) -> PolarsResult<()> { + let df = &mut chunk.data; + + self.sort_column.clear(); + + for i in self.sort_idx.iter() { + let s = &df.get_columns()[*i].as_materialized_series(); + let arr = s.to_physical_repr().rechunk().chunks()[0].to_boxed(); + self.sort_column.push(arr); + } + + // we remove columns by index, but then the aren't correct anymore + // so we do it in the proper order and keep track of the indices removed + let mut sorted_sort_idx = self.sort_idx.to_vec(); + sorted_sort_idx.sort_unstable(); + + // SAFETY: We do not adjust the names or lengths or columns. + let cols = unsafe { df.get_columns_mut() }; + + sorted_sort_idx + .into_iter() + .enumerate() + .for_each(|(i, sort_idx)| { + // shifts all columns right from removed one to the left so + // therefore we subtract `i` as the shifted count + let _ = cols.remove(sort_idx - i); + }); + + df.clear_schema(); + let name = PlSmallStr::from_static(POLARS_SORT_COLUMN); + let column = if chunk.data.height() == 0 && chunk.data.width() > 0 { + Column::new_empty(name, &DataType::BinaryOffset) + } else { + let rows_encoded = polars_row::convert_columns( + self.sort_column[0].len(), // @NOTE: does not work for ZFS + &self.sort_column, + &self.sort_opts, + &self.sort_dicts, + ); + let series = unsafe { + Series::from_chunks_and_dtype_unchecked( + name, + vec![Box::new(rows_encoded.into_array())], + &DataType::BinaryOffset, + ) + }; + debug_assert_eq!(series.chunks().len(), 1); + series.into() + }; + + // SAFETY: length is correct + unsafe { chunk.data.with_column_unchecked(column) }; + Ok(()) + } +} + +impl Sink for SortSinkMultiple { + fn sink( + &mut self, + context: &PExecutionContext, + mut chunk: DataChunk, + ) -> PolarsResult { + self.encode(&mut chunk)?; + self.sort_sink.sink(context, chunk) + } + + fn combine(&mut self, other: &mut dyn Sink) { + let other = other.as_any().downcast_mut::().unwrap(); + self.sort_sink.combine(&mut *other.sort_sink) + } + + fn split(&self, thread_no: usize) -> Box { + let sort_sink = self.sort_sink.split(thread_no); + Box::new(Self { + sort_idx: self.sort_idx.clone(), + sort_sink, + sort_opts: self.sort_opts.clone(), + slice: self.slice, + sort_options: self.sort_options.clone(), + sort_dicts: self.sort_dicts.clone(), + sort_column: vec![], + sort_dtypes: self.sort_dtypes.clone(), + output_schema: self.output_schema.clone(), + }) + } + + fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { + let out = self.sort_sink.finalize(context)?; + + let sort_dtypes = self.sort_dtypes.take().map(|arr| { + arr.iter() + .map(|dt| dt.to_physical().to_arrow(CompatLevel::newest())) + .collect::>() + }); + + // we must adapt the finalized sink result so that the sort encoded column is dropped + match out { + FinalizedSink::Finished(mut df) => { + finalize_dataframe( + &mut df, + self.sort_idx.as_ref(), + &self.sort_options, + sort_dtypes.as_deref(), + &mut vec![], + self.sort_opts.as_ref(), + self.sort_dicts.as_ref(), + &self.output_schema, + ); + Ok(FinalizedSink::Finished(df)) + }, + FinalizedSink::Source(source) => Ok(FinalizedSink::Source(Box::new(DropEncoded { + source, + sort_idx: self.sort_idx.clone(), + sort_options: self.sort_options.clone(), + sort_dtypes, + rows: vec![], + sort_opts: self.sort_opts.clone(), + sort_dicts: self.sort_dicts.clone(), + output_schema: self.output_schema.clone(), + }))), + // SortSink should not produce this branch + FinalizedSink::Operator => unreachable!(), + } + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "sort_multiple" + } +} + +struct DropEncoded { + source: Box, + sort_idx: Arc<[usize]>, + sort_options: SortMultipleOptions, + sort_dtypes: Option>, + rows: Vec<&'static [u8]>, + sort_opts: Arc<[RowEncodingOptions]>, + sort_dicts: Arc<[Option]>, + output_schema: SchemaRef, +} + +impl Source for DropEncoded { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { + let mut result = self.source.get_batches(context); + if let Ok(SourceResult::GotMoreData(data)) = &mut result { + for chunk in data { + finalize_dataframe( + &mut chunk.data, + self.sort_idx.as_ref(), + &self.sort_options, + self.sort_dtypes.as_deref(), + &mut self.rows, + self.sort_opts.as_ref(), + self.sort_dicts.as_ref(), + &self.output_schema, + ) + } + }; + result + } + + fn fmt(&self) -> &str { + "sort_multiple_source" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/sort/source.rs b/crates/polars-pipe/src/executors/sinks/sort/source.rs new file mode 100644 index 000000000000..b630990b976a --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/sort/source.rs @@ -0,0 +1,243 @@ +use std::iter::Peekable; +use std::path::PathBuf; +use std::time::Instant; + +use polars_core::POOL; +use polars_core::prelude::*; +use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df}; +use rayon::prelude::*; + +use crate::executors::sinks::io::IOThread; +use crate::executors::sinks::memory::MemTracker; +use crate::executors::sinks::sort::ooc::{PartitionSpiller, read_df}; +use crate::executors::sinks::sort::sink::sort_accumulated; +use crate::executors::sources::get_source_index; +use crate::operators::{DataChunk, PExecutionContext, Source, SourceResult}; + +pub struct SortSource { + files: Peekable>, + n_threads: usize, + sort_idx: usize, + descending: bool, + nulls_last: bool, + chunk_offset: IdxSize, + slice: Option<(i64, usize)>, + finished: bool, + io_thread: IOThread, + memtrack: MemTracker, + // Start of the Source phase + source_start: Instant, + // Start of the OOC sort operation. + ooc_start: Instant, + partition_spiller: PartitionSpiller, + current_part: usize, +} + +impl SortSource { + #[allow(clippy::too_many_arguments)] + pub(super) fn new( + mut files: Vec<(u32, PathBuf)>, + sort_idx: usize, + descending: bool, + nulls_last: bool, + slice: Option<(i64, usize)>, + verbose: bool, + io_thread: IOThread, + memtrack: MemTracker, + ooc_start: Instant, + partition_spiller: PartitionSpiller, + ) -> Self { + if verbose { + eprintln!("started sort source phase"); + } + + files.sort_unstable_by_key(|entry| entry.0); + + let n_threads = POOL.current_num_threads(); + let files = files.into_iter().peekable(); + + Self { + files, + n_threads, + sort_idx, + descending, + nulls_last, + chunk_offset: get_source_index(1) as IdxSize, + slice, + finished: false, + io_thread, + memtrack, + source_start: Instant::now(), + ooc_start, + partition_spiller, + current_part: 0, + } + } + fn finish_batch(&mut self, dfs: Vec) -> Vec { + // TODO: make utility functions to save these allocations + let chunk_offset = self.chunk_offset; + self.chunk_offset += dfs.len() as IdxSize; + dfs.into_iter() + .enumerate() + .map(|(i, df)| DataChunk { + chunk_index: chunk_offset + i as IdxSize, + data: df, + }) + .collect() + } + + fn finish_from_df(&mut self, df: DataFrame) -> PolarsResult { + // Sort a single partition + // We always need to sort again! + let current_slice = self.slice; + + let mut df = match &mut self.slice { + None => sort_accumulated( + df, + self.sort_idx, + None, + SortOptions { + descending: self.descending, + nulls_last: self.nulls_last, + multithreaded: true, + maintain_order: false, + limit: None, + }, + ), + Some((offset, len)) => { + let df_len = df.height(); + debug_assert!(*offset >= 0); + let out = if *offset as usize >= df_len { + *offset -= df_len as i64; + Ok(df.slice(0, 0)) + } else { + let out = sort_accumulated( + df, + self.sort_idx, + current_slice, + SortOptions { + descending: self.descending, + nulls_last: self.nulls_last, + multithreaded: true, + maintain_order: false, + limit: None, + }, + ); + *len = len.saturating_sub(df_len); + *offset = 0; + out + }; + if *len == 0 { + self.finished = true; + } + out + }, + }?; + + // convert to chunks + let dfs = split_df(&mut df, self.n_threads, true); + Ok(SourceResult::GotMoreData(self.finish_batch(dfs))) + } + fn print_verbose(&self, verbose: bool) { + if verbose { + eprintln!("sort source phase took: {:?}", self.source_start.elapsed()); + eprintln!("full ooc sort took: {:?}", self.ooc_start.elapsed()); + } + } + + fn get_from_memory( + &mut self, + read: &mut Vec, + read_size: &mut usize, + part: usize, + keep_track: bool, + ) { + while self.current_part <= part { + if let Some(df) = self.partition_spiller.get(self.current_part - 1) { + if keep_track { + *read_size += df.estimated_size(); + } + read.push(df); + } + self.current_part += 1; + } + } +} + +impl Source for SortSource { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { + // early return + if self.finished || self.current_part >= self.partition_spiller.len() { + self.print_verbose(context.verbose); + return Ok(SourceResult::Finished); + } + self.current_part += 1; + let mut read_size = 0; + let mut read = vec![]; + + match self.files.next() { + None => { + // Ensure we fetch all from memory. + self.get_from_memory( + &mut read, + &mut read_size, + self.partition_spiller.len(), + false, + ); + if read.is_empty() { + self.print_verbose(context.verbose); + Ok(SourceResult::Finished) + } else { + self.finished = true; + let df = accumulate_dataframes_vertical_unchecked(read); + self.finish_from_df(df) + } + }, + Some((mut partition, mut path)) => { + self.get_from_memory(&mut read, &mut read_size, partition as usize, true); + let limit = self.memtrack.get_available() / 3; + + loop { + if let Some(in_mem) = self.partition_spiller.get(partition as usize) { + read_size += in_mem.estimated_size(); + read.push(in_mem) + } + + let files = std::fs::read_dir(&path)?.collect::>>()?; + + // read the files in a single partition in parallel + let dfs = POOL.install(|| { + files + .par_iter() + .map(|entry| { + let df = read_df(&entry.path())?; + Ok(df) + }) + .collect::>>() + })?; + + let df = accumulate_dataframes_vertical_unchecked(dfs); + read_size += df.estimated_size(); + read.push(df); + if read_size > limit { + break; + } + + let Some((next_part, next_path)) = self.files.next() else { + break; + }; + path = next_path; + partition = next_part; + } + let df = accumulate_dataframes_vertical_unchecked(read); + let out = self.finish_from_df(df); + self.io_thread.clean(path); + out + }, + } + } + + fn fmt(&self) -> &str { + "sort_source" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/utils.rs b/crates/polars-pipe/src/executors/sinks/utils.rs new file mode 100644 index 000000000000..bcb5b6ed6067 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/utils.rs @@ -0,0 +1,20 @@ +use arrow::array::BinaryArray; +use polars_core::hashing::_hash_binary_array; +use polars_utils::aliases::PlSeedableRandomStateQuality; + +pub(super) fn hash_rows( + columns: &BinaryArray, + buf: &mut Vec, + hb: &PlSeedableRandomStateQuality, +) { + debug_assert!(buf.is_empty()); + _hash_binary_array(columns, *hb, buf); +} + +pub(super) fn load_vec T>(partitions: usize, item: F) -> Vec { + let mut buf = Vec::with_capacity(partitions); + for _ in 0..partitions { + buf.push(item()); + } + buf +} diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs new file mode 100644 index 000000000000..09ab974f15f3 --- /dev/null +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -0,0 +1,238 @@ +use std::fs::File; + +use polars_core::error::feature_gated; +use polars_core::{POOL, config}; +use polars_io::csv::read::{BatchedCsvReader, CsvReadOptions, CsvReader}; +use polars_io::path_utils::is_cloud_url; +use polars_plan::dsl::{ScanSources, UnifiedScanArgs}; +use polars_plan::global::_set_n_rows_for_scan; +use polars_utils::itertools::Itertools; +use polars_utils::slice_enum::Slice; + +use super::*; +use crate::pipeline::determine_chunk_size; + +pub(crate) struct CsvSource { + #[allow(dead_code)] + // this exist because we need to keep ownership + schema: SchemaRef, + // Safety: `reader` outlives `batched_reader` + // (so we have to order the `batched_reader` first in the struct fields) + batched_reader: Option>, + reader: Option>, + n_threads: usize, + sources: ScanSources, + options: Option, + unified_scan_args: Box, + verbose: bool, + // state for multi-file reads + current_path_idx: usize, + n_rows_read: usize, + first_schema: SchemaRef, + include_file_path: Option, +} + +impl CsvSource { + // Delay initializing the reader + // otherwise all files would be opened during construction of the pipeline + // leading to Too many Open files error + fn init_next_reader(&mut self) -> PolarsResult<()> { + let paths = self + .sources + .as_paths() + .ok_or_else(|| polars_err!(nyi = "Streaming scanning of in-memory buffers"))?; + let unified_scan_args = self.unified_scan_args.as_ref(); + + let n_rows = unified_scan_args.pre_slice.clone().map(|slice| { + assert!(matches!(slice, Slice::Positive { len: 0, .. })); + slice.len() + }); + + if self.current_path_idx == paths.len() + || (n_rows.is_some() && n_rows.unwrap() <= self.n_rows_read) + { + return Ok(()); + } + let path = &paths[self.current_path_idx]; + + let force_async = config::force_async(); + let run_async = force_async || is_cloud_url(path); + + if self.current_path_idx == 0 && force_async && self.verbose { + eprintln!("ASYNC READING FORCED"); + } + + self.current_path_idx += 1; + + let options = self.options.clone().unwrap(); + let mut with_columns = unified_scan_args.projection.clone(); + let mut projected_len = 0; + with_columns + .as_ref() + .inspect(|columns| projected_len = columns.len()); + + if projected_len == 0 { + with_columns = None; + } + + let n_cols = if projected_len > 0 { + projected_len + } else { + self.schema.len() + }; + let n_rows = _set_n_rows_for_scan( + unified_scan_args + .pre_slice + .clone() + .map(|slice| { + assert!(matches!(slice, Slice::Positive { len: 0, .. })); + slice.len() + }) + .map(|n| n.saturating_sub(self.n_rows_read)), + ); + let row_index = unified_scan_args.row_index.clone().map(|mut ri| { + ri.offset += self.n_rows_read as IdxSize; + ri + }); + // inversely scale the chunk size by the number of threads so that we reduce memory pressure + // in streaming + let chunk_size = determine_chunk_size(n_cols, POOL.current_num_threads())?; + + if self.verbose { + eprintln!("STREAMING CHUNK SIZE: {chunk_size} rows") + } + + let options = options + .with_schema(Some(self.schema.clone())) + .with_n_rows(n_rows) + .with_columns(with_columns) + .with_rechunk(false) + .with_row_index(row_index); + + let reader: CsvReader = if run_async { + feature_gated!("cloud", { + options.into_reader_with_file_handle( + polars_io::file_cache::FILE_CACHE + .get_entry(path.to_str().unwrap()) + // Safety: This was initialized by schema inference. + .unwrap() + .try_open_assume_latest()?, + ) + }) + } else { + options + .with_path(Some(path)) + .try_into_reader_with_file_path(None)? + }; + + if let Some(col) = &unified_scan_args.include_file_paths { + self.include_file_path = + Some(StringChunked::full(col.clone(), path.to_str().unwrap(), 1)); + }; + + self.reader = Some(reader); + let reader = self.reader.as_mut().unwrap(); + + // Safety: `reader` outlives `batched_reader` + let reader: &'static mut CsvReader = unsafe { std::mem::transmute(reader) }; + let batched_reader = reader.batched_borrowed()?; + self.batched_reader = Some(batched_reader); + Ok(()) + } + + pub(crate) fn new( + sources: ScanSources, + schema: SchemaRef, + options: CsvReadOptions, + unified_scan_args: Box, + verbose: bool, + ) -> PolarsResult { + Ok(CsvSource { + schema, + reader: None, + batched_reader: None, + n_threads: POOL.current_num_threads(), + sources, + options: Some(options), + unified_scan_args, + verbose, + current_path_idx: 0, + n_rows_read: 0, + first_schema: Default::default(), + include_file_path: None, + }) + } +} + +impl Source for CsvSource { + fn get_batches(&mut self, _context: &PExecutionContext) -> PolarsResult { + loop { + let first_read_from_file = self.reader.is_none(); + + if first_read_from_file { + self.init_next_reader()?; + } + + if self.reader.is_none() { + // No more readers + return Ok(SourceResult::Finished); + } + + let Some(batches) = self + .batched_reader + .as_mut() + .unwrap() + .next_batches(self.n_threads)? + else { + self.reader = None; + continue; + }; + + if first_read_from_file { + if self.first_schema.is_empty() { + self.first_schema = batches[0].schema().clone(); + } + ensure_matching_schema(&self.first_schema, batches[0].schema())?; + } + + let index = get_source_index(0); + let mut n_rows_read = 0; + let mut max_height = 0; + let mut out = batches + .into_iter() + .enumerate_u32() + .map(|(i, data)| { + max_height = max_height.max(data.height()); + n_rows_read += data.height(); + DataChunk { + chunk_index: (index + i) as IdxSize, + data, + } + }) + .collect::>(); + + if let Some(ca) = &mut self.include_file_path { + if ca.len() < max_height { + *ca = ca.new_from_index(0, max_height); + }; + + for data_chunk in &mut out { + let n = data_chunk.data.height(); + // SAFETY: Columns are only replaced with columns + // 1. of the same name, and + // 2. of the same length. + unsafe { data_chunk.data.get_columns_mut() }.push(ca.slice(0, n).into_column()); + data_chunk.data.clear_schema(); + } + } + + self.n_rows_read = self.n_rows_read.saturating_add(n_rows_read); + get_source_index(out.len() as u32); + + return Ok(SourceResult::GotMoreData(out)); + } + } + fn fmt(&self) -> &str { + "csv" + } +} diff --git a/crates/polars-pipe/src/executors/sources/frame.rs b/crates/polars-pipe/src/executors/sources/frame.rs new file mode 100644 index 000000000000..14110e36b241 --- /dev/null +++ b/crates/polars-pipe/src/executors/sources/frame.rs @@ -0,0 +1,48 @@ +use std::iter::Enumerate; +use std::vec::IntoIter; + +use polars_core::POOL; +use polars_core::error::PolarsResult; +use polars_core::frame::DataFrame; +use polars_core::utils::split_df; +use polars_utils::IdxSize; + +use crate::executors::sources::get_source_index; +use crate::operators::{DataChunk, PExecutionContext, Source, SourceResult}; + +pub struct DataFrameSource { + dfs: Enumerate>, + n_threads: usize, +} + +impl DataFrameSource { + pub(crate) fn from_df(mut df: DataFrame) -> Self { + let n_threads = POOL.current_num_threads(); + let dfs = split_df(&mut df, n_threads, false); + let dfs = dfs.into_iter().enumerate(); + Self { dfs, n_threads } + } +} + +impl Source for DataFrameSource { + fn get_batches(&mut self, _context: &PExecutionContext) -> PolarsResult { + let idx_offset = get_source_index(0); + let chunks = (&mut self.dfs) + .map(|(chunk_index, data)| DataChunk { + chunk_index: (chunk_index as u32 + idx_offset) as IdxSize, + data, + }) + .take(self.n_threads) + .collect::>(); + get_source_index(chunks.len() as u32); + + if chunks.is_empty() { + Ok(SourceResult::Finished) + } else { + Ok(SourceResult::GotMoreData(chunks)) + } + } + fn fmt(&self) -> &str { + "df" + } +} diff --git a/crates/polars-pipe/src/executors/sources/ipc_one_shot.rs b/crates/polars-pipe/src/executors/sources/ipc_one_shot.rs new file mode 100644 index 000000000000..6f20bec84fa8 --- /dev/null +++ b/crates/polars-pipe/src/executors/sources/ipc_one_shot.rs @@ -0,0 +1,37 @@ +use std::fs::File; +use std::path::Path; + +use polars_core::prelude::*; +use polars_io::SerReader; +use polars_io::ipc::IpcReader; + +use crate::operators::{DataChunk, PExecutionContext, Source, SourceResult}; + +/// Reads the whole file in one pass +pub struct IpcSourceOneShot { + reader: Option>, +} + +impl IpcSourceOneShot { + #[allow(unused_variables)] + pub(crate) fn new(path: &Path) -> PolarsResult { + let file = polars_utils::open_file(path)?; + let reader = Some(IpcReader::new(file)); + + Ok(IpcSourceOneShot { reader }) + } +} + +impl Source for IpcSourceOneShot { + fn get_batches(&mut self, _context: &PExecutionContext) -> PolarsResult { + if self.reader.is_none() { + Ok(SourceResult::Finished) + } else { + let df = self.reader.take().unwrap().finish()?; + Ok(SourceResult::GotMoreData(vec![DataChunk::new(0, df)])) + } + } + fn fmt(&self) -> &str { + "ipc-one-shot" + } +} diff --git a/crates/polars-pipe/src/executors/sources/mod.rs b/crates/polars-pipe/src/executors/sources/mod.rs new file mode 100644 index 000000000000..11f3e2400461 --- /dev/null +++ b/crates/polars-pipe/src/executors/sources/mod.rs @@ -0,0 +1,24 @@ +#[cfg(feature = "csv")] +mod csv; +mod frame; +mod ipc_one_shot; +mod reproject; +mod union; + +use std::sync::atomic::{AtomicU32, Ordering}; + +#[cfg(feature = "csv")] +pub(crate) use csv::CsvSource; +pub(crate) use frame::*; +pub(crate) use ipc_one_shot::*; +pub(crate) use reproject::*; +pub(crate) use union::*; + +#[cfg(feature = "csv")] +use super::*; + +static CHUNK_INDEX: AtomicU32 = AtomicU32::new(0); + +pub(super) fn get_source_index(add: u32) -> u32 { + CHUNK_INDEX.fetch_add(add, Ordering::Relaxed) +} diff --git a/crates/polars-pipe/src/executors/sources/reproject.rs b/crates/polars-pipe/src/executors/sources/reproject.rs new file mode 100644 index 000000000000..250d6fc5adf4 --- /dev/null +++ b/crates/polars-pipe/src/executors/sources/reproject.rs @@ -0,0 +1,39 @@ +use polars_core::prelude::SchemaRef; + +use crate::executors::operators::reproject_chunk; +use crate::operators::{PExecutionContext, PolarsResult, Source, SourceResult}; + +/// A source that will ensure we keep the schema order +pub(crate) struct ReProjectSource { + schema: SchemaRef, + source: Box, + positions: Vec, +} + +impl ReProjectSource { + pub(crate) fn new(schema: SchemaRef, source: Box) -> Self { + ReProjectSource { + schema, + source, + positions: vec![], + } + } +} + +impl Source for ReProjectSource { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { + Ok(match self.source.get_batches(context)? { + SourceResult::Finished => SourceResult::Finished, + SourceResult::GotMoreData(mut chunks) => { + for chunk in &mut chunks { + reproject_chunk(chunk, &mut self.positions, self.schema.as_ref())?; + } + SourceResult::GotMoreData(chunks) + }, + }) + } + + fn fmt(&self) -> &str { + "re-project-source" + } +} diff --git a/crates/polars-pipe/src/executors/sources/union.rs b/crates/polars-pipe/src/executors/sources/union.rs new file mode 100644 index 000000000000..771c21f27c04 --- /dev/null +++ b/crates/polars-pipe/src/executors/sources/union.rs @@ -0,0 +1,35 @@ +use polars_core::error::PolarsResult; + +use crate::operators::{PExecutionContext, Source, SourceResult}; + +pub struct UnionSource { + sources: Vec>, + source_index: usize, +} + +impl UnionSource { + pub(crate) fn new(sources: Vec>) -> Self { + Self { + sources, + source_index: 0, + } + } +} + +impl Source for UnionSource { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { + // early return if we have data + // if no data we deplete the loop and are finished + while self.source_index < self.sources.len() { + let src = &mut self.sources[self.source_index]; + match src.get_batches(context)? { + SourceResult::Finished => self.source_index += 1, + SourceResult::GotMoreData(chunks) => return Ok(SourceResult::GotMoreData(chunks)), + } + } + Ok(SourceResult::Finished) + } + fn fmt(&self) -> &str { + "union" + } +} diff --git a/crates/polars-pipe/src/expressions.rs b/crates/polars-pipe/src/expressions.rs new file mode 100644 index 000000000000..f4efe498cc2d --- /dev/null +++ b/crates/polars-pipe/src/expressions.rs @@ -0,0 +1,16 @@ +use polars_core::prelude::*; +use polars_expr::state::ExecutionState; +use polars_io::predicates::PhysicalIoExpr; +use polars_plan::dsl::Expr; + +use crate::operators::DataChunk; + +pub trait PhysicalPipedExpr: PhysicalIoExpr + Send + Sync { + /// Take a [`DataFrame`] and produces a boolean [`Series`] that serves + /// as a predicate mask + fn evaluate(&self, chunk: &DataChunk, lazy_state: &ExecutionState) -> PolarsResult; + + fn field(&self, input_schema: &Schema) -> PolarsResult; + + fn expression(&self) -> Expr; +} diff --git a/crates/polars-pipe/src/lib.rs b/crates/polars-pipe/src/lib.rs new file mode 100644 index 000000000000..3566706fd4f6 --- /dev/null +++ b/crates/polars-pipe/src/lib.rs @@ -0,0 +1,5 @@ +#![allow(unsafe_op_in_unsafe_fn)] +mod executors; +pub mod expressions; +pub mod operators; +pub mod pipeline; diff --git a/crates/polars-pipe/src/operators/chunks.rs b/crates/polars-pipe/src/operators/chunks.rs new file mode 100644 index 000000000000..d237510b4636 --- /dev/null +++ b/crates/polars-pipe/src/operators/chunks.rs @@ -0,0 +1,185 @@ +use polars_core::utils::accumulate_dataframes_vertical_unchecked; + +use super::*; + +#[derive(Clone, Debug)] +pub struct DataChunk { + pub chunk_index: IdxSize, + pub data: DataFrame, +} + +impl DataChunk { + pub(crate) fn new(chunk_index: IdxSize, data: DataFrame) -> Self { + // Check the invariant that all columns have a single chunk. + #[cfg(debug_assertions)] + { + for c in data.get_columns() { + assert_eq!(c.as_materialized_series().chunks().len(), 1); + } + } + Self { chunk_index, data } + } + pub(crate) fn with_data(&self, data: DataFrame) -> Self { + Self::new(self.chunk_index, data) + } + pub(crate) fn is_empty(&self) -> bool { + self.data.is_empty() + } +} + +pub(crate) fn chunks_to_df_unchecked(chunks: Vec) -> DataFrame { + accumulate_dataframes_vertical_unchecked(chunks.into_iter().map(|c| c.data)) +} + +/// Combine a series of `DataFrame`s, and if they're small enough, combine them +/// into larger `DataFrame`s using `vstack`. This allows the caller to turn them +/// into contiguous memory allocations so that we don't suffer from overhead of +/// many small writes. The assumption is that added `DataFrame`s are already in +/// the correct order, and can therefore be combined. +/// +/// The benefit of having a series of `DataFrame` that are e.g. 4MB each that +/// are then made contiguous is that you're not using a lot of memory (an extra +/// 4MB), but you're still doing better than if you had a series of 2KB +/// `DataFrame`s. +/// +/// Changing the `DataFrame` into contiguous chunks is the caller's +/// responsibility. +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] +#[derive(Clone)] +pub(crate) struct StreamingVstacker { + current_dataframe: Option, + /// How big should resulting chunks be, if possible? + output_chunk_size: usize, +} + +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] +impl StreamingVstacker { + /// Create a new instance. + pub fn new(output_chunk_size: usize) -> Self { + Self { + current_dataframe: None, + output_chunk_size, + } + } + + /// Add another `DataFrame`, return (potentially combined) `DataFrame`s that + /// result, if any. + pub fn add(&mut self, next_frame: DataFrame) -> impl Iterator { + let mut result: [Option; 2] = [None, None]; + + // If the next chunk is too large, we probably don't want make copies of + // it if a caller does as_single_chunk(), so we flush in advance. + if self.current_dataframe.is_some() + && next_frame.estimated_size() > self.output_chunk_size / 4 + { + result[0] = self.flush(); + } + + if let Some(ref mut current_frame) = self.current_dataframe { + current_frame + .vstack_mut(&next_frame) + .expect("These are chunks from the same dataframe"); + } else { + self.current_dataframe = Some(next_frame); + }; + + if self.current_dataframe.as_ref().unwrap().estimated_size() > self.output_chunk_size { + result[1] = self.flush(); + } + result.into_iter().flatten() + } + + /// Clear and return any cached `DataFrame` data. + #[must_use] + fn flush(&mut self) -> Option { + std::mem::take(&mut self.current_dataframe) + } + + /// Finish and return any remaining cached `DataFrame` data. The only way + /// that `SemicontiguousVstacker` should be cleaned up. + #[must_use] + pub fn finish(mut self) -> Option { + self.flush() + } +} + +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] +impl Default for StreamingVstacker { + /// 4 MB was chosen based on some empirical experiments that showed it to + /// be decently faster than lower or higher values, and it's small enough + /// it won't impact memory usage significantly. + fn default() -> Self { + StreamingVstacker::new(4 * 1024 * 1024) + } +} + +#[cfg(test)] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] +mod test { + use super::*; + + /// DataFrames get merged into chunks that are bigger than the specified + /// size when possible. + #[test] + fn semicontiguous_vstacker_merges() { + let test = semicontiguous_vstacker_merges_impl; + test(vec![10]); + test(vec![10, 10, 10, 10, 10, 10, 10]); + test(vec![10, 40, 10, 10, 10, 10]); + test(vec![40, 10, 10, 40, 10, 10, 40]); + test(vec![50, 50, 50]); + } + + /// Eventually would be nice to drive this with proptest. + fn semicontiguous_vstacker_merges_impl(df_lengths: Vec) { + // Convert the lengths into a series of DataFrames: + let mut vstacker = StreamingVstacker::new(4096); + let dfs: Vec = df_lengths + .iter() + .enumerate() + .map(|(i, length)| { + let series = Column::new("val".into(), vec![i as u64; *length]); + DataFrame::new(vec![series]).unwrap() + }) + .collect(); + + // Combine the DataFrames using a SemicontiguousVstacker: + let mut results = vec![]; + for (i, df) in dfs.iter().enumerate() { + for mut result_df in vstacker.add(df.clone()) { + result_df.as_single_chunk(); + results.push((i, result_df)); + } + } + if let Some(mut result_df) = vstacker.finish() { + result_df.as_single_chunk(); + results.push((df_lengths.len() - 1, result_df)); + } + + // Make sure the lengths are as sufficiently large, and the chunks + // were merged, the whole point of the exercise: + for (original_idx, result_df) in &results { + if result_df.height() < 40 { + // This means either this was the last df, or the next one + // was big enough we decided not to aggregate. + if *original_idx < results.len() - 1 { + assert!(dfs[original_idx + 1].height() > 10); + } + } + // Make sure all result DataFrames only have a single chunk. + assert_eq!( + result_df.get_columns()[0] + .as_materialized_series() + .chunk_lengths() + .len(), + 1 + ); + } + + // Make sure the data was preserved: + assert_eq!( + accumulate_dataframes_vertical_unchecked(dfs.into_iter()), + accumulate_dataframes_vertical_unchecked(results.into_iter().map(|(_, df)| df)), + ); + } +} diff --git a/crates/polars-pipe/src/operators/context.rs b/crates/polars-pipe/src/operators/context.rs new file mode 100644 index 000000000000..a7e52820bcd4 --- /dev/null +++ b/crates/polars-pipe/src/operators/context.rs @@ -0,0 +1,16 @@ +use polars_expr::state::ExecutionState; + +pub struct PExecutionContext { + // injected upstream in polars-lazy + pub(crate) execution_state: ExecutionState, + pub(crate) verbose: bool, +} + +impl PExecutionContext { + pub(crate) fn new(state: ExecutionState, verbose: bool) -> Self { + PExecutionContext { + execution_state: state, + verbose, + } + } +} diff --git a/crates/polars-pipe/src/operators/mod.rs b/crates/polars-pipe/src/operators/mod.rs new file mode 100644 index 000000000000..af47a3b87017 --- /dev/null +++ b/crates/polars-pipe/src/operators/mod.rs @@ -0,0 +1,12 @@ +pub mod chunks; +mod context; +mod operator; +mod sink; +mod source; + +pub(crate) use chunks::*; +pub use context::*; +pub(crate) use operator::*; +pub(crate) use polars_core::prelude::*; +pub use sink::*; +pub(crate) use source::*; diff --git a/crates/polars-pipe/src/operators/operator.rs b/crates/polars-pipe/src/operators/operator.rs new file mode 100644 index 000000000000..9082728a9fdb --- /dev/null +++ b/crates/polars-pipe/src/operators/operator.rs @@ -0,0 +1,31 @@ +use super::*; + +pub enum OperatorResult { + /// needs to be called again with new chunk. + /// Or in case of `flush` needs to be called again. + NeedsNewData, + /// needs to be called again with same chunk. + HaveMoreOutPut(DataChunk), + /// this operator is finished + Finished(DataChunk), +} + +pub trait Operator: Send + Sync { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult; + + fn flush(&mut self) -> PolarsResult { + unimplemented!() + } + + fn must_flush(&self) -> bool { + false + } + + fn split(&self, thread_no: usize) -> Box; + + fn fmt(&self) -> &str; +} diff --git a/crates/polars-pipe/src/operators/sink.rs b/crates/polars-pipe/src/operators/sink.rs new file mode 100644 index 000000000000..a1933c346b5f --- /dev/null +++ b/crates/polars-pipe/src/operators/sink.rs @@ -0,0 +1,52 @@ +use std::any::Any; +use std::fmt::{Debug, Formatter}; + +use polars_utils::arena::Node; + +use super::*; + +#[derive(Debug)] +pub enum SinkResult { + Finished, + CanHaveMoreInput, +} + +pub enum FinalizedSink { + Finished(DataFrame), + Operator, + Source(Box), +} + +impl Debug for FinalizedSink { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + FinalizedSink::Finished(_) => "finished", + FinalizedSink::Operator => "operator", + FinalizedSink::Source(_) => "source", + }; + write!(f, "{s}") + } +} + +pub trait Sink: Send + Sync { + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult; + + fn combine(&mut self, other: &mut dyn Sink); + + fn split(&self, thread_no: usize) -> Box; + + fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult; + + fn as_any(&mut self) -> &mut dyn Any; + + fn fmt(&self) -> &str; + + fn is_join_build(&self) -> bool { + false + } + + // Only implemented for Join sinks + fn node(&self) -> Node { + unimplemented!() + } +} diff --git a/crates/polars-pipe/src/operators/source.rs b/crates/polars-pipe/src/operators/source.rs new file mode 100644 index 000000000000..80006347b65b --- /dev/null +++ b/crates/polars-pipe/src/operators/source.rs @@ -0,0 +1,12 @@ +use super::*; + +pub enum SourceResult { + Finished, + GotMoreData(Vec), +} + +pub trait Source: Send + Sync { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult; + + fn fmt(&self) -> &str; +} diff --git a/crates/polars-pipe/src/pipeline/config.rs b/crates/polars-pipe/src/pipeline/config.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/crates/polars-pipe/src/pipeline/config.rs @@ -0,0 +1 @@ + diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs new file mode 100644 index 000000000000..e97577ea6c4d --- /dev/null +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -0,0 +1,693 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use hashbrown::hash_map::Entry; +use polars_core::prelude::*; +use polars_core::with_match_physical_integer_polars_type; +use polars_ops::prelude::JoinType; +use polars_plan::prelude::expr_ir::{ExprIR, OutputName}; +use polars_plan::prelude::*; + +use crate::executors::operators::{HstackOperator, PlaceHolder}; +use crate::executors::sinks::group_by::GenericGroupby2; +use crate::executors::sinks::group_by::aggregates::convert_to_hash_agg; +use crate::executors::sinks::*; +use crate::executors::{operators, sources}; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{Operator, Sink as SinkTrait, Source}; +use crate::pipeline::dispatcher::ThreadedSink; +use crate::pipeline::{PhysOperator, PipeLine}; + +pub type CallBacks = PlHashMap; + +fn exprs_to_physical( + exprs: &[ExprIR], + expr_arena: &Arena, + to_physical: &F, + schema: &SchemaRef, +) -> PolarsResult>> +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + exprs + .iter() + .map(|e| to_physical(e, expr_arena, schema)) + .collect() +} + +#[allow(unused_variables)] +fn get_source( + source: IR, + operator_objects: &mut Vec>, + expr_arena: &mut Arena, + to_physical: &F, + push_predicate: bool, + verbose: bool, +) -> PolarsResult> +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + use IR::*; + match source { + DataFrameScan { + df, output_schema, .. + } => { + let mut df = (*df).clone(); + let schema = output_schema.clone().unwrap_or_else(|| df.schema().clone()); + if push_predicate { + // projection is free + if let Some(schema) = output_schema { + let columns = schema.iter_names_cloned().collect::>(); + df = df._select_impl_unchecked(&columns)?; + } + } + Ok(Box::new(sources::DataFrameSource::from_df(df)) as Box) + }, + Scan { + sources, + file_info, + hive_parts, + unified_scan_args, + predicate, + output_schema, + scan_type, + } => { + let paths = sources.into_paths(); + let schema = output_schema.as_ref().unwrap_or(&file_info.schema); + + // Add predicate to operators. + // Except for parquet, as that format can use statistics to prune file/row-groups. + #[cfg(feature = "parquet")] + let is_parquet = matches!(&*scan_type, FileScan::Parquet { .. }); + #[cfg(not(feature = "parquet"))] + let is_parquet = false; + + if let (false, true, Some(predicate)) = (is_parquet, push_predicate, predicate.clone()) + { + #[cfg(feature = "parquet")] + debug_assert!(!matches!(&*scan_type, FileScan::Parquet { .. })); + let predicate = to_physical(&predicate, expr_arena, schema)?; + let op = operators::FilterOperator { predicate }; + let op = Box::new(op) as Box; + operator_objects.push(op) + } + match *scan_type { + #[cfg(feature = "csv")] + FileScan::Csv { options, .. } => { + let src = sources::CsvSource::new( + sources, + file_info.reader_schema.clone().unwrap().unwrap_right(), + options, + unified_scan_args, + verbose, + )?; + Ok(Box::new(src) as Box) + }, + #[cfg(feature = "parquet")] + FileScan::Parquet { + options: parquet_options, + metadata, + } => panic!("Parquet no longer supported for old streaming engine"), + _ => todo!(), + } + }, + _ => unreachable!(), + } +} + +pub fn get_sink( + node: Node, + lp_arena: &Arena, + expr_arena: &mut Arena, + to_physical: &F, + callbacks: &mut CallBacks, +) -> PolarsResult> +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + use IR::*; + let out = match lp_arena.get(node) { + Sink { input, payload } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + match payload { + SinkTypeIR::Memory => { + Box::new(OrderedSink::new(input_schema.into_owned())) as Box + }, + #[allow(unused_variables)] + SinkTypeIR::File(FileSinkType { + target, + file_type, + sink_options: _, + cloud_options, + }) => { + let SinkTarget::Path(path) = target else { + polars_bail!(InvalidOperation: "in-memory sinks are not supported for the old streaming engine"); + }; + let path = path.as_ref().as_path(); + match &file_type { + #[cfg(feature = "parquet")] + FileType::Parquet(options) => Box::new(ParquetSink::new( + path, + *options, + input_schema.as_ref(), + cloud_options.as_ref(), + )?) + as Box, + #[cfg(feature = "ipc")] + FileType::Ipc(options) => Box::new(IpcSink::new( + path, + *options, + input_schema.as_ref(), + cloud_options.as_ref(), + )?) as Box, + #[cfg(feature = "csv")] + FileType::Csv(options) => Box::new(CsvSink::new( + path, + options.clone(), + input_schema.as_ref(), + cloud_options.as_ref(), + )?) as Box, + #[cfg(feature = "json")] + FileType::Json(options) => Box::new(JsonSink::new( + path, + *options, + input_schema.as_ref(), + cloud_options.as_ref(), + )?) + as Box, + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + }, + SinkTypeIR::Partition { .. } => { + polars_bail!(InvalidOperation: "partitioning sink not supported in old streaming engine") + }, + } + }, + Join { + input_left, + input_right, + options, + left_on, + right_on, + .. + } => { + // slice pushdown optimization should not set this one in a streaming query. + assert!(options.args.slice.is_none()); + let swapped = swap_join_order(options); + let placeholder = callbacks.get(&node).unwrap().clone(); + + match &options.args.how { + #[cfg(feature = "cross_join")] + JoinType::Cross => Box::new(CrossJoin::new( + options.args.suffix().clone(), + swapped, + node, + placeholder, + )) as Box, + jt => { + let input_schema_left = lp_arena.get(*input_left).schema(lp_arena); + let join_columns_left = Arc::new(exprs_to_physical( + left_on, + expr_arena, + to_physical, + input_schema_left.as_ref(), + )?); + let input_schema_right = lp_arena.get(*input_right).schema(lp_arena); + let join_columns_right = Arc::new(exprs_to_physical( + right_on, + expr_arena, + to_physical, + input_schema_right.as_ref(), + )?); + + let swap_eval = || { + if swapped { + (join_columns_right.clone(), join_columns_left.clone()) + } else { + (join_columns_left.clone(), join_columns_right.clone()) + } + }; + + match jt { + JoinType::Inner | JoinType::Left => { + let (join_columns_left, join_columns_right) = swap_eval(); + + Box::new(GenericBuild::<()>::new( + options.args.suffix().clone(), + options.args.clone(), + swapped, + join_columns_left, + join_columns_right, + options.args.nulls_equal, + node, + // We don't need the key names for these joins. + vec![].into(), + vec![].into(), + placeholder, + )) as Box + }, + JoinType::Full => { + // First get the names before we (potentially) swap. + let key_names_left = join_columns_left + .iter() + .map(|e| e.field(&input_schema_left).unwrap().name) + .collect(); + let key_names_right = join_columns_left + .iter() + .map(|e| e.field(&input_schema_left).unwrap().name) + .collect(); + // Swap. + let (join_columns_left, join_columns_right) = swap_eval(); + + Box::new(GenericBuild::::new( + options.args.suffix().clone(), + options.args.clone(), + swapped, + join_columns_left, + join_columns_right, + options.args.nulls_equal, + node, + key_names_left, + key_names_right, + placeholder, + )) as Box + }, + _ => unimplemented!(), + } + }, + } + }, + Slice { input, offset, len } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + let slice = SliceSink::new(*offset as u64, *len as usize, input_schema.into_owned()); + Box::new(slice) as Box + }, + Sort { + input, + by_column, + slice, + sort_options, + } => { + let input_schema = lp_arena.get(*input).schema(lp_arena).into_owned(); + + if by_column.len() == 1 { + let by_column = aexpr_to_leaf_names_iter(by_column[0].node(), expr_arena) + .next() + .unwrap(); + let index = input_schema.try_index_of(by_column.as_ref())?; + + let sort_sink = SortSink::new(index, *slice, sort_options.clone(), input_schema); + Box::new(sort_sink) as Box + } else { + let sort_idx = by_column + .iter() + .map(|e| { + let name = aexpr_to_leaf_names_iter(e.node(), expr_arena) + .next() + .unwrap(); + input_schema.try_index_of(name.as_ref()) + }) + .collect::>>()?; + + let sort_sink = + SortSinkMultiple::new(*slice, sort_options.clone(), input_schema, sort_idx)?; + Box::new(sort_sink) as Box + } + }, + Distinct { input, options } => { + // We create a Groupby.agg_first()/agg_last (depending on the keep strategy + let input_schema = lp_arena.get(*input).schema(lp_arena).into_owned(); + + let (keys, aggs, output_schema) = match &options.subset { + None => { + let keys = input_schema + .iter_names() + .map(|name| { + let name: PlSmallStr = name.clone(); + let node = expr_arena.add(AExpr::Column(name.clone())); + ExprIR::new(node, OutputName::Alias(name)) + }) + .collect::>(); + let aggs = vec![]; + (keys, aggs, input_schema.clone()) + }, + Some(keys) => { + let mut group_by_out_schema = Schema::with_capacity(input_schema.len()); + let key_names = PlHashSet::from_iter(keys.iter().map(|s| s.as_ref())); + let keys = keys + .iter() + .map(|key| { + let (_, name, dtype) = input_schema.get_full(key.as_ref()).unwrap(); + group_by_out_schema.with_column(name.clone(), dtype.clone()); + let node = expr_arena.add(AExpr::Column(key.clone())); + ExprIR::new(node, OutputName::Alias(key.clone())) + }) + .collect::>(); + + let aggs = input_schema + .iter_names() + .flat_map(|name| { + if key_names.contains(name.as_str()) { + None + } else { + let (_, name, dtype) = + input_schema.get_full(name.as_str()).unwrap(); + group_by_out_schema.with_column(name.clone(), dtype.clone()); + + let name: PlSmallStr = name.clone(); + let col = expr_arena.add(AExpr::Column(name.clone())); + let node = match options.keep_strategy { + UniqueKeepStrategy::First | UniqueKeepStrategy::Any => { + expr_arena.add(AExpr::Agg(IRAggExpr::First(col))) + }, + UniqueKeepStrategy::Last => { + expr_arena.add(AExpr::Agg(IRAggExpr::Last(col))) + }, + UniqueKeepStrategy::None => { + unreachable!() + }, + }; + Some(ExprIR::new(node, OutputName::Alias(name))) + } + }) + .collect(); + (keys, aggs, group_by_out_schema.into()) + }, + }; + + let key_columns = Arc::new(exprs_to_physical( + &keys, + expr_arena, + to_physical, + &input_schema, + )?); + + let mut aggregation_columns = Vec::with_capacity(aggs.len()); + let mut agg_fns = Vec::with_capacity(aggs.len()); + let mut input_agg_dtypes = Vec::with_capacity(aggs.len()); + + for e in &aggs { + let (input_dtype, index, agg_fn) = + convert_to_hash_agg(e.node(), expr_arena, &input_schema, &to_physical); + aggregation_columns.push(index); + agg_fns.push(agg_fn); + input_agg_dtypes.push(input_dtype); + } + let aggregation_columns = Arc::new(aggregation_columns); + + let group_by_sink = Box::new(GenericGroupby2::new( + key_columns, + aggregation_columns, + Arc::from(agg_fns), + output_schema, + input_agg_dtypes, + options.slice, + )); + + Box::new(ReProjectSink::new(input_schema, group_by_sink)) + }, + GroupBy { + input, + keys, + aggs, + schema: output_schema, + options, + .. + } => { + let input_schema = lp_arena.get(*input).schema(lp_arena).as_ref().clone(); + let key_columns = Arc::new(exprs_to_physical( + keys, + expr_arena, + to_physical, + &input_schema, + )?); + + let mut aggregation_columns = Vec::with_capacity(aggs.len()); + let mut agg_fns = Vec::with_capacity(aggs.len()); + let mut input_agg_dtypes = Vec::with_capacity(aggs.len()); + + for e in aggs { + let (input_dtype, index, agg_fn) = + convert_to_hash_agg(e.node(), expr_arena, &input_schema, &to_physical); + aggregation_columns.push(index); + agg_fns.push(agg_fn); + input_agg_dtypes.push(input_dtype); + } + let aggregation_columns = Arc::new(aggregation_columns); + + if std::env::var("POLARS_STREAMING_GB2").as_deref() == Ok("1") { + Box::new(GenericGroupby2::new( + key_columns, + aggregation_columns, + Arc::from(agg_fns), + output_schema.clone(), + input_agg_dtypes, + options.slice, + )) + } else { + match ( + output_schema.get_at_index(0).unwrap().1.to_physical(), + keys.len(), + ) { + (dt, 1) if dt.is_integer() => { + with_match_physical_integer_polars_type!(dt, |$T| { + Box::new(group_by::PrimitiveGroupbySink::<$T>::new( + key_columns[0].clone(), + aggregation_columns, + agg_fns, + input_schema, + output_schema.clone(), + options.slice, + )) as Box + }) + }, + (DataType::String, 1) => Box::new(group_by::StringGroupbySink::new( + key_columns[0].clone(), + aggregation_columns, + agg_fns, + input_schema, + output_schema.clone(), + options.slice, + )) as Box, + _ => Box::new(GenericGroupby2::new( + key_columns, + aggregation_columns, + Arc::from(agg_fns), + output_schema.clone(), + input_agg_dtypes, + options.slice, + )), + } + } + }, + lp => { + panic!("{lp:?} not implemented") + }, + }; + Ok(out) +} + +pub fn get_dummy_operator() -> PlaceHolder { + operators::PlaceHolder::new() +} + +fn get_hstack( + exprs: &[ExprIR], + expr_arena: &Arena, + to_physical: &F, + input_schema: SchemaRef, + options: ProjectionOptions, +) -> PolarsResult +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + Ok(HstackOperator { + exprs: exprs_to_physical(exprs, expr_arena, &to_physical, &input_schema)?, + input_schema, + options, + }) +} + +pub fn get_operator( + node: Node, + lp_arena: &Arena, + expr_arena: &Arena, + to_physical: &F, +) -> PolarsResult> +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + use IR::*; + let op = match lp_arena.get(node) { + SimpleProjection { input, columns, .. } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + let columns = columns.iter_names_cloned().collect(); + let op = operators::SimpleProjectionOperator::new(columns, input_schema.into_owned()); + Box::new(op) as Box + }, + Select { + expr, + input, + options, + .. + } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + let op = operators::ProjectionOperator { + exprs: exprs_to_physical(expr, expr_arena, &to_physical, input_schema.as_ref())?, + options: *options, + }; + Box::new(op) as Box + }, + HStack { + exprs, + input, + options, + .. + } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + + let op = get_hstack( + exprs, + expr_arena, + to_physical, + (*input_schema).clone(), + *options, + )?; + + Box::new(op) as Box + }, + Filter { predicate, input } => { + let input_schema = lp_arena.get(*input).schema(lp_arena); + let predicate = to_physical(predicate, expr_arena, input_schema.as_ref())?; + let op = operators::FilterOperator { predicate }; + Box::new(op) as Box + }, + MapFunction { function, .. } => { + let op = operators::FunctionOperator::new(function.clone()); + Box::new(op) as Box + }, + Union { .. } => { + let op = operators::Pass::new("union"); + Box::new(op) as Box + }, + + lp => { + panic!("operator {lp:?} not (yet) supported") + }, + }; + Ok(op) +} + +#[allow(clippy::too_many_arguments)] +pub fn create_pipeline( + sources: &[Node], + operators: Vec>, + sink_nodes: Vec<(usize, Node, Rc>)>, + lp_arena: &Arena, + expr_arena: &mut Arena, + to_physical: F, + verbose: bool, + // Shared sinks are stored in a cache, so that they share state. + // If the shared sink is already in cache, that one is used. + sink_cache: &mut PlHashMap>, + callbacks: &mut CallBacks, +) -> PolarsResult +where + F: Fn(&ExprIR, &Arena, &SchemaRef) -> PolarsResult>, +{ + use IR::*; + + let mut source_objects = Vec::with_capacity(sources.len()); + let mut operator_objects = Vec::with_capacity(operators.len() + 1); + + for node in sources { + let src = match lp_arena.get(*node) { + lp @ DataFrameScan { .. } => get_source( + lp.clone(), + &mut operator_objects, + expr_arena, + &to_physical, + true, + verbose, + )?, + lp @ Scan { .. } => get_source( + lp.clone(), + &mut operator_objects, + expr_arena, + &to_physical, + true, + verbose, + )?, + Union { inputs, .. } => { + let sources = inputs + .iter() + .enumerate() + .map(|(i, node)| { + let lp = lp_arena.get(*node); + // only push predicate of first source + get_source( + lp.clone(), + &mut operator_objects, + expr_arena, + &to_physical, + i == 0, + verbose && i == 0, + ) + }) + .collect::>>()?; + Box::new(sources::UnionSource::new(sources)) as Box + }, + lp => { + panic!("source {lp:?} not (yet) supported") + }, + }; + source_objects.push(src) + } + + // this offset is because the source might have inserted operators + let operator_offset = operator_objects.len(); + operator_objects.extend(operators); + + let sinks = sink_nodes + .into_iter() + .map(|(offset, node, shared_count)| { + // ensure that shared sinks are really shared + // to achieve this we store/fetch them in a cache + let sink = if *shared_count.borrow() == 1 { + get_sink(node, lp_arena, expr_arena, &to_physical, callbacks)? + } else { + match sink_cache.entry(node.0) { + Entry::Vacant(entry) => { + let sink = get_sink(node, lp_arena, expr_arena, &to_physical, callbacks)?; + entry.insert(sink.split(0)); + sink + }, + Entry::Occupied(entry) => entry.get().split(0), + } + }; + Ok(ThreadedSink::new( + sink, + shared_count, + offset + operator_offset, + )) + }) + .collect::>>()?; + + Ok(PipeLine::new( + source_objects, + unsafe { + std::mem::transmute::>, Vec>(operator_objects) + }, + sinks, + verbose, + )) +} + +pub fn swap_join_order(options: &JoinOptions) -> bool { + matches!(options.args.how, JoinType::Left) + || match (options.rows_left, options.rows_right) { + ((Some(left), _), (Some(right), _)) => left > right, + ((_, left), (_, right)) => left > right, + } +} diff --git a/crates/polars-pipe/src/pipeline/dispatcher/drive_operator.rs b/crates/polars-pipe/src/pipeline/dispatcher/drive_operator.rs new file mode 100644 index 000000000000..4676c0d0041f --- /dev/null +++ b/crates/polars-pipe/src/pipeline/dispatcher/drive_operator.rs @@ -0,0 +1,260 @@ +use super::*; +use crate::pipeline::*; + +/// Take data chunks from the sources and pushes them into the operators + sink. Every operator +/// works thread local. +/// The caller passes an `operator_start`/`operator_end` to indicate which part of the pipeline +/// branch should be executed. +#[allow(clippy::too_many_arguments)] +pub(super) fn par_process_chunks( + chunks: Vec, + sink: ThreadedSinkMut, + ec: &PExecutionContext, + operators: &mut [ThreadedOperator], + operator_start: usize, + operator_end: usize, + src: &mut Box, + must_flush: &AtomicBool, +) -> PolarsResult<(Option, SourceResult)> { + debug_assert!(chunks.len() <= sink.len()); + let sink_results = Arc::new(Mutex::new(None)); + let mut next_batches: Option> = None; + let next_batches_ptr = &mut next_batches as *mut Option>; + let next_batches_ptr = unsafe { SyncPtr::new(next_batches_ptr) }; + + // 1. We will iterate the chunks/sinks/operators + // where every iteration belongs to a single thread + // 2. Then we will truncate the pipeline by `start`/`end` + // so that the pipeline represents pipeline that belongs to this sink + // 3. Then we push the data + // # Threading + // Within a rayon scope + // we spawn the jobs. They don't have to finish in any specific order, + // this makes it more lightweight than `par_iter` + + // borrow as ref and move into the closure + POOL.scope(|s| { + for ((chunk, sink), operator_pipe) in chunks + .into_iter() + .zip(sink.iter_mut()) + .zip(operators.iter_mut()) + { + let sink_results = sink_results.clone(); + // Truncate the operators that should run into the current sink. + let operator_pipe = &mut operator_pipe[operator_start..operator_end]; + + s.spawn(move |_| { + let out = if operator_pipe.is_empty() { + sink.sink(ec, chunk) + } else { + push_operators_single_thread(chunk, ec, operator_pipe, sink, must_flush) + }; + + match out { + Ok(SinkResult::Finished) | Err(_) => { + let mut lock = sink_results.lock().unwrap(); + *lock = Some(out) + }, + _ => {}, + } + }) + } + // already get batches on the thread pool + // if one job is finished earlier we can already start that work + s.spawn(|_| { + let out = src.get_batches(ec); + unsafe { + let ptr = next_batches_ptr.get(); + *ptr = Some(out); + } + }) + }); + + let next_batches = next_batches.unwrap()?; + let mut lock = sink_results.lock().unwrap(); + lock.take() + .transpose() + .map(|sink_result| (sink_result, next_batches)) +} + +/// This thread local logic that pushed a data chunk into the operators + sink +/// It can be that a single operator needs to be called multiple times, this is for instance the +/// case with joins that produce many tuples, that's why we keep a stack of `in_process` +/// operators. +pub(super) fn push_operators_single_thread( + chunk: DataChunk, + ec: &PExecutionContext, + operators: ThreadedOperatorMut, + sink: &mut Box, + must_flush: &AtomicBool, +) -> PolarsResult { + debug_assert!(!operators.is_empty()); + + // Stack based operator execution. + let mut in_process = vec![]; + let operator_offset = 0usize; + in_process.push((operator_offset, chunk)); + + while let Some((op_i, chunk)) = in_process.pop() { + match operators.get_mut(op_i) { + None => { + if let SinkResult::Finished = sink.sink(ec, chunk)? { + return Ok(SinkResult::Finished); + } + }, + Some(op) => { + let op = op.get_mut(); + match op.execute(ec, &chunk)? { + OperatorResult::Finished(chunk) => { + let flag = op.must_flush(); + let _ = must_flush.compare_exchange( + false, + flag, + Ordering::Relaxed, + Ordering::Relaxed, + ); + in_process.push((op_i + 1, chunk)) + }, + OperatorResult::HaveMoreOutPut(output_chunk) => { + // Push the next operator call with the same chunk on the stack + in_process.push((op_i, chunk)); + + // But first push the output in the next operator + // If a join can produce many rows, we want the filter to + // be executed in between, or sink into a slice so that we get + // sink::finished before we grow the stack with ever more coming chunks + in_process.push((op_i + 1, output_chunk)); + }, + OperatorResult::NeedsNewData => { + // done, take another chunk from the stack + }, + } + }, + } + } + + Ok(SinkResult::CanHaveMoreInput) +} + +/// Similar to `par_process_chunks`. +/// The caller passes an `operator_start`/`operator_end` to indicate which part of the pipeline +/// branch should be executed. +pub(super) fn par_flush( + sink: ThreadedSinkMut, + ec: &PExecutionContext, + operators: &mut [ThreadedOperator], + operator_start: usize, + operator_end: usize, +) { + // 1. We will iterate the chunks/sinks/operators + // where every iteration belongs to a single thread + // 2. Then we will truncate the pipeline by `start`/`end` + // so that the pipeline represents pipeline that belongs to this sink + // 3. Then we push the data + // # Threading + // Within a rayon scope + // we spawn the jobs. They don't have to finish in any specific order, + // this makes it more lightweight than `par_iter` + + // borrow as ref and move into the closure + POOL.scope(|s| { + for (sink, operator_pipe) in sink.iter_mut().zip(operators.iter_mut()) { + // Truncate the operators that should run into the current sink. + let operator_pipe = &mut operator_pipe[operator_start..operator_end]; + + s.spawn(move |_| { + flush_operators(ec, operator_pipe, sink).unwrap(); + }) + } + }); +} + +pub(super) fn flush_operators( + ec: &PExecutionContext, + operators: &mut [PhysOperator], + sink: &mut Box, +) -> PolarsResult { + let needs_flush = operators + .iter_mut() + .enumerate() + .filter_map(|(i, op)| { + if op.get_mut().must_flush() { + Some(i) + } else { + None + } + }) + .collect::>(); + + // Stack based flushing + operator execution. + if !needs_flush.is_empty() { + let mut in_process = vec![]; + + for op_i in needs_flush.into_iter() { + // Push all operators that need flushing on the stack. + // The `None` indicates that we have no `chunk` input, so we `flush`. + // `Some(chunk)` is the pushing branch + in_process.push((op_i, None)); + + // Next we immediately pop and determine the order of execution below. + // This is to ensure that all operators below upper operators are completely + // flushed when the `flush` is called in higher operators. As operators can `flush` + // multiple times. + while let Some((op_i, chunk)) = in_process.pop() { + match chunk { + // The branch for flushing. + None => { + let op = operators.get_mut(op_i).unwrap().get_mut(); + match op.flush()? { + OperatorResult::Finished(chunk) => { + // Push the chunk in the next operator. + in_process.push((op_i + 1, Some(chunk))) + }, + OperatorResult::HaveMoreOutPut(chunk) => { + // Ensure it is flushed again + in_process.push((op_i, None)); + // Push the chunk in the next operator. + in_process.push((op_i + 1, Some(chunk))) + }, + _ => unreachable!(), + } + }, + // The branch for pushing data in the operators. + // This is the same as the default stack executor, except now it pushes + // `Some(chunk)` instead of `chunk`. + Some(chunk) => { + match operators.get_mut(op_i) { + None => { + if let SinkResult::Finished = sink.sink(ec, chunk)? { + return Ok(SinkResult::Finished); + } + }, + Some(op) => { + let op = op.get_mut(); + match op.execute(ec, &chunk)? { + OperatorResult::Finished(chunk) => { + in_process.push((op_i + 1, Some(chunk))) + }, + OperatorResult::HaveMoreOutPut(output_chunk) => { + // Push the next operator call with the same chunk on the stack + in_process.push((op_i, Some(chunk))); + + // But first push the output in the next operator + // If a join can produce many rows, we want the filter to + // be executed in between, or sink into a slice so that we get + // sink::finished before we grow the stack with ever more coming chunks + in_process.push((op_i + 1, Some(output_chunk))); + }, + OperatorResult::NeedsNewData => { + // Done, take another chunk from the stack + }, + } + }, + } + }, + } + } + } + } + Ok(SinkResult::Finished) +} diff --git a/crates/polars-pipe/src/pipeline/dispatcher/mod.rs b/crates/polars-pipe/src/pipeline/dispatcher/mod.rs new file mode 100644 index 000000000000..cdb7f8e7f048 --- /dev/null +++ b/crates/polars-pipe/src/pipeline/dispatcher/mod.rs @@ -0,0 +1,383 @@ +use std::cell::RefCell; +use std::fmt::{Debug, Formatter}; +use std::rc::Rc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use polars_core::POOL; +use polars_core::error::PolarsResult; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_expr::state::ExecutionState; +use polars_utils::sync::SyncPtr; +use rayon::prelude::*; + +use crate::executors::sources::DataFrameSource; +use crate::operators::{ + DataChunk, FinalizedSink, OperatorResult, PExecutionContext, Sink, SinkResult, Source, + SourceResult, +}; +use crate::pipeline::dispatcher::drive_operator::{par_flush, par_process_chunks}; +mod drive_operator; +use super::*; + +pub(super) struct ThreadedSink { + /// A sink split per thread. + pub sinks: Vec>, + /// when that hits 0, the sink will finalize + pub shared_count: Rc>, + initial_shared_count: u32, + /// - offset in the operators vec + /// at that point the sink should be called. + /// the pipeline will first call the operators on that point and then + /// push the result in the sink. + pub operator_end: usize, +} + +impl ThreadedSink { + pub fn new(sink: Box, shared_count: Rc>, operator_end: usize) -> Self { + let n_threads = morsels_per_sink(); + let sinks = (0..n_threads).map(|i| sink.split(i)).collect(); + let initial_shared_count = *shared_count.borrow(); + ThreadedSink { + sinks, + initial_shared_count, + shared_count, + operator_end, + } + } + + // Only the first node of a shared sink should recurse. The others should return. + fn allow_recursion(&self) -> bool { + self.initial_shared_count == *self.shared_count.borrow() + } +} + +// A pipeline consists of: +// +// - 1. One or more sources. +// Sources get pulled and their data is pushed into operators. +// - 2. Zero or more operators. +// The operators simply pass through data, modifying it as they need. +// Operators can work on batches and don't need all data in scope to +// succeed. +// Think for example on multiply a few columns, or applying a predicate. +// Operators can shrink the batches: filter +// Grow the batches: explode/ unpivot +// Keep them the same size: element-wise operations +// The probe side of join operations is also an operator. +// +// +// - 3. One or more sinks +// A sink needs all data in scope to finalize a pipeline branch. +// Think of sorts, preparing a build phase of a join, group_by + aggregations. +// +// This struct will have the SOS (source, operators, sinks) of its own pipeline branch, but also +// the SOS of other branches. The SOS are stored data oriented and the sinks have an offset that +// indicates the last operator node before that specific sink. We only store the `end offset` and +// keep track of the starting operator during execution. +// +// Pipelines branches are shared with other pipeline branches at the join/union nodes. +// # JOIN +// Consider this tree: +// out +// / +// /\ +// 1 2 +// +// And let's consider that branch 2 runs first. It will run until the join node where it will sink +// into a build table. Once that is done it will replace the build-phase placeholder operator in +// branch 1. Branch one can then run completely until out. +pub struct PipeLine { + // All the sources of this pipeline + sources: Vec>, + // All the operators of this pipeline. Some may be placeholders that will be replaced during + // execution + operators: Vec, + // - offset in the operators vec + // at that point the sink should be called. + // the pipeline will first call the operators on that point and then + // push the result in the sink. + // - shared_count + // when that hits 0, the sink will finalize + // - node of the sink + sinks: Vec, + // Log runtime info to stderr + verbose: bool, +} + +impl PipeLine { + #[allow(clippy::type_complexity)] + pub(super) fn new( + sources: Vec>, + operators: Vec, + sinks: Vec, + verbose: bool, + ) -> PipeLine { + // we don't use the power of two partition size here + // we only do that in the sinks itself. + let n_threads = morsels_per_sink(); + + // We split so that every thread gets an operator + // every index maps to a chain of operators than can be pushed as a pipeline for one thread + let operators = (0..n_threads) + .map(|i| { + operators + .iter() + .map(|op| op.get_ref().split(i).into()) + .collect() + }) + .collect(); + + PipeLine { + sources, + operators, + sinks, + verbose, + } + } + + /// Create a pipeline only consisting of a single branch that always finishes with a sink + pub(crate) fn new_simple( + sources: Vec>, + operators: Vec, + sink: Box, + verbose: bool, + ) -> Self { + let operators_len = operators.len(); + Self::new( + sources, + operators, + vec![ThreadedSink::new( + sink, + Rc::new(RefCell::new(1)), + operators_len, + )], + verbose, + ) + } + + /// Replace the current sources with a [`DataFrameSource`]. + fn set_df_as_sources(&mut self, df: DataFrame) { + let src = Box::new(DataFrameSource::from_df(df)) as Box; + self.set_sources(src) + } + + /// Replace the current sources. + fn set_sources(&mut self, src: Box) { + self.sources.clear(); + self.sources.push(src); + } + + fn run_pipeline_no_finalize( + &mut self, + ec: &PExecutionContext, + pipelines: &mut Vec, + ) -> PolarsResult<(u32, Box)> { + let mut out = None; + let mut operator_start = 0; + let last_i = self.sinks.len() - 1; + + // For unions we typically first want to push all pipelines + // into the union sink before we call `finalize` + // however if the sink is finished early, (for instance a `head`) + // we don't want to run the rest of the pipelines and we finalize early + let mut sink_finished = false; + + for (i, mut sink) in std::mem::take(&mut self.sinks).into_iter().enumerate() { + for src in &mut std::mem::take(&mut self.sources) { + let mut next_batches = src.get_batches(ec)?; + + let must_flush: AtomicBool = AtomicBool::new(false); + while let SourceResult::GotMoreData(chunks) = next_batches { + // Every batches iteration we check if we must continue. + ec.execution_state.should_stop()?; + + let (sink_result, next_batches2) = par_process_chunks( + chunks, + &mut sink.sinks, + ec, + &mut self.operators, + operator_start, + sink.operator_end, + src, + &must_flush, + )?; + next_batches = next_batches2; + + if let Some(SinkResult::Finished) = sink_result { + sink_finished = true; + break; + } + } + if !sink_finished && must_flush.load(Ordering::Relaxed) { + par_flush( + &mut sink.sinks, + ec, + &mut self.operators, + operator_start, + sink.operator_end, + ); + } + } + + // Before we reduce we also check if we should continue. + ec.execution_state.should_stop()?; + let allow_recursion = sink.allow_recursion(); + + // The sinks have taken all chunks thread locally, now we reduce them into a single + // result sink. + let mut reduced_sink = POOL + .install(|| { + sink.sinks.into_par_iter().reduce_with(|mut a, mut b| { + a.combine(&mut *b); + a + }) + }) + .unwrap(); + operator_start = sink.operator_end; + + let mut shared_sink_count = { + let mut shared_sink_count = sink.shared_count.borrow_mut(); + *shared_sink_count -= 1; + *shared_sink_count + }; + + // Prevent very deep recursion. Only the outer callee can pop and run. + if allow_recursion { + while shared_sink_count > 0 && !sink_finished { + let mut pipeline = pipelines.pop().unwrap(); + let (count, mut sink) = pipeline.run_pipeline_no_finalize(ec, pipelines)?; + // This branch is hit when we have a Union of joins. + // The build side must be converted into an operator and replaced in the next pipeline. + + // Check either: + // 1. There can be a union source that sinks into a single join: + // scan_parquet(*) -> join B + // 2. There can be a union of joins + // C - JOIN A, B + // concat (A, B, C) + // + // So to ensure that we don't finalize we check + // - They are not both join builds + // - If they are both join builds, check they are note the same build, otherwise + // we must call the `combine` branch. + if sink.is_join_build() + && (!reduced_sink.is_join_build() || (sink.node() != reduced_sink.node())) + { + let FinalizedSink::Operator = sink.finalize(ec)? else { + unreachable!() + }; + } else { + reduced_sink.combine(sink.as_mut()); + shared_sink_count = count; + } + } + } + + if i != last_i { + let sink_result = reduced_sink.finalize(ec)?; + match sink_result { + // turn this sink an a new source + FinalizedSink::Finished(df) => self.set_df_as_sources(df), + FinalizedSink::Source(src) => self.set_sources(src), + // should not happen + FinalizedSink::Operator => { + unreachable!() + }, + } + } else { + out = Some((shared_sink_count, reduced_sink)) + } + } + Ok(out.unwrap()) + } + + /// Run a single pipeline branch. + /// This pulls data from the sources and pushes it into the operators which run on a different + /// thread and finalize in a sink. + /// + /// The sink can be finished, but can also become a new source and then rinse and repeat. + pub fn run_pipeline( + &mut self, + ec: &PExecutionContext, + pipelines: &mut Vec, + ) -> PolarsResult> { + let (sink_shared_count, mut reduced_sink) = self.run_pipeline_no_finalize(ec, pipelines)?; + assert_eq!(sink_shared_count, 0); + + let finalized_reduced_sink = reduced_sink.finalize(ec)?; + Ok(Some(finalized_reduced_sink)) + } +} + +/// Executes all branches and replaces operators and sinks during execution to ensure +/// we materialize. +pub fn execute_pipeline( + state: ExecutionState, + mut pipelines: Vec, +) -> PolarsResult { + let mut pipeline = pipelines.pop().unwrap(); + let ec = PExecutionContext::new(state, pipeline.verbose); + + let mut sink_out = pipeline.run_pipeline(&ec, &mut pipelines)?; + loop { + match &mut sink_out { + None => { + let mut pipeline = pipelines.pop().unwrap(); + sink_out = pipeline.run_pipeline(&ec, &mut pipelines)?; + }, + Some(FinalizedSink::Finished(df)) => return Ok(std::mem::take(df)), + Some(FinalizedSink::Source(src)) => return consume_source(&mut **src, &ec), + + // + // 1/\ + // 2/\ + // 3\ + // the left hand side of the join has finished and now is an operator + // we replace the dummy node in the right hand side pipeline with this + // operator and then we run the pipeline rinse and repeat + // until the final right hand side pipeline ran + Some(FinalizedSink::Operator) => { + // we unwrap, because the latest pipeline should not return an Operator + let mut pipeline = pipelines.pop().unwrap(); + + sink_out = pipeline.run_pipeline(&ec, &mut pipelines)?; + }, + } + } +} + +impl Debug for PipeLine { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut fmt = String::new(); + let mut start = 0usize; + fmt.push_str(self.sources[0].fmt()); + for sink in &self.sinks { + fmt.push_str(" -> "); + // take operators of a single thread + let ops = &self.operators[0]; + // slice the pipeline + let ops = &ops[start..sink.operator_end]; + for op in ops { + fmt.push_str(op.get_ref().fmt()); + fmt.push_str(" -> ") + } + start = sink.operator_end; + fmt.push_str(sink.sinks[0].fmt()) + } + write!(f, "{fmt}") + } +} + +/// Take a source and materialize it into a [`DataFrame`]. +fn consume_source(src: &mut dyn Source, context: &PExecutionContext) -> PolarsResult { + let mut frames = Vec::with_capacity(32); + + while let SourceResult::GotMoreData(batch) = src.get_batches(context)? { + frames.extend(batch.into_iter().map(|chunk| chunk.data)) + } + Ok(accumulate_dataframes_vertical_unchecked(frames)) +} + +unsafe impl Send for PipeLine {} +unsafe impl Sync for PipeLine {} diff --git a/crates/polars-pipe/src/pipeline/mod.rs b/crates/polars-pipe/src/pipeline/mod.rs new file mode 100644 index 000000000000..9799d8313857 --- /dev/null +++ b/crates/polars-pipe/src/pipeline/mod.rs @@ -0,0 +1,68 @@ +mod config; +mod convert; +mod dispatcher; + +pub use convert::{ + CallBacks, create_pipeline, get_dummy_operator, get_operator, get_sink, swap_join_order, +}; +pub use dispatcher::{PipeLine, execute_pipeline}; +use polars_core::POOL; +use polars_core::prelude::*; +use polars_utils::cell::SyncUnsafeCell; + +pub use crate::executors::sinks::group_by::aggregates::can_convert_to_hash_agg; +use crate::operators::{Operator, Sink}; + +pub(crate) fn morsels_per_sink() -> usize { + POOL.current_num_threads() +} + +// Number of OOC partitions. +// proxy for RAM size multiplier +pub(crate) const PARTITION_SIZE: usize = 64; + +// env vars +pub(crate) static FORCE_OOC: &str = "POLARS_FORCE_OOC"; + +/// ideal chunk size we strive to have +/// scale the chunk size depending on the number of +/// columns. With 10 columns we use a chunk size of 40_000 +pub(crate) fn determine_chunk_size(n_cols: usize, n_threads: usize) -> PolarsResult { + if let Ok(val) = std::env::var("POLARS_STREAMING_CHUNK_SIZE") { + val.parse().map_err( + |_| polars_err!(ComputeError: "could not parse 'POLARS_STREAMING_CHUNK_SIZE' env var"), + ) + } else { + let thread_factor = std::cmp::max(12 / n_threads, 1); + Ok(std::cmp::max(50_000 / n_cols.max(1) * thread_factor, 1000)) + } +} + +type PhysSink = Box; +/// A physical operator/sink per thread. +type ThreadedOperator = Vec; +type ThreadedOperatorMut<'a> = &'a mut [PhysOperator]; +type ThreadedSinkMut<'a> = &'a mut [PhysSink]; + +#[repr(transparent)] +pub(crate) struct PhysOperator { + inner: SyncUnsafeCell>, +} + +impl From> for PhysOperator { + fn from(value: Box) -> Self { + Self { + inner: SyncUnsafeCell::new(value), + } + } +} + +impl PhysOperator { + pub(crate) fn get_mut(&mut self) -> &mut dyn Operator { + &mut **self.inner.get_mut() + } + + pub(crate) fn get_ref(&self) -> &dyn Operator { + unsafe { &**self.inner.get() } + } +} diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml new file mode 100644 index 000000000000..1817a245fbbb --- /dev/null +++ b/crates/polars-plan/Cargo.toml @@ -0,0 +1,297 @@ +[package] +name = "polars-plan" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Lazy query engine for the Polars DataFrame library" + +[lib] +doctest = false + +[dependencies] +libloading = { version = "0.8.0", optional = true } +polars-compute = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } +polars-ffi = { workspace = true, optional = true } +polars-io = { workspace = true, features = ["lazy", "csv"] } +polars-json = { workspace = true, optional = true } +polars-ops = { workspace = true, features = [] } +polars-parquet = { workspace = true, optional = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } + +arrow = { workspace = true } +bitflags = { workspace = true } +bytemuck = { workspace = true } +bytes = { workspace = true, features = ["serde"] } +chrono = { workspace = true, optional = true } +chrono-tz = { workspace = true, optional = true } +either = { workspace = true } +futures = { workspace = true, optional = true } +hashbrown = { workspace = true } +memmap = { workspace = true } +num-traits = { workspace = true } +percent-encoding = { workspace = true } +pyo3 = { workspace = true, optional = true } +rayon = { workspace = true } +recursive = { workspace = true } +regex = { workspace = true, optional = true } +serde = { workspace = true, features = ["rc"], optional = true } +serde_json = { workspace = true, optional = true } +strum_macros = { workspace = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +# debugging utility +debugging = [] +python = ["dep:pyo3", "polars-utils/python", "polars-ffi"] +serde = [ + "ir_serde", + "dep:serde", + "polars-core/serde-lazy", + "polars-time/serde", + "polars-io/serde", + "polars-ops/serde", + "polars-utils/serde", + "polars-compute/serde", + "either/serde", +] +streaming = [] +parquet = ["polars-io/parquet", "polars-parquet"] +async = ["polars-io/async", "futures"] +cloud = ["async", "polars-io/cloud"] +ipc = ["polars-io/ipc"] +json = ["polars-io/json", "polars-json"] +csv = ["polars-io/csv"] +temporal = [ + "chrono", + "polars-core/temporal", + "polars-core/dtype-date", + "polars-core/dtype-datetime", + "polars-core/dtype-time", + "polars-core/dtype-i8", + "polars-core/dtype-i16", +] +# debugging purposes +fmt = ["polars-core/fmt"] +strings = ["polars-core/strings", "polars-ops/strings"] +future = [] +dtype-u8 = ["polars-core/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16"] +dtype-i8 = ["polars-core/dtype-i8"] +dtype-i128 = ["polars-core/dtype-i128"] +dtype-i16 = ["polars-core/dtype-i16"] +dtype-decimal = ["polars-core/dtype-decimal", "dtype-i128"] +dtype-date = ["polars-time/dtype-date", "temporal"] +dtype-datetime = ["polars-time/dtype-datetime", "temporal"] +dtype-duration = ["polars-core/dtype-duration", "polars-time/dtype-duration", "temporal", "polars-ops/dtype-duration"] +dtype-time = ["polars-time/dtype-time", "temporal"] +dtype-array = ["polars-core/dtype-array", "polars-ops/dtype-array"] +dtype-categorical = ["polars-core/dtype-categorical"] +dtype-struct = ["polars-core/dtype-struct"] +object = ["polars-core/object"] +list_gather = ["polars-ops/list_gather"] +list_count = ["polars-ops/list_count"] +array_count = ["polars-ops/array_count", "dtype-array"] +trigonometry = [] +sign = [] +timezones = ["chrono-tz", "polars-time/timezones", "polars-core/timezones", "regex"] +binary_encoding = ["polars-ops/binary_encoding"] +string_encoding = ["polars-ops/string_encoding"] +true_div = [] +nightly = ["polars-utils/nightly", "polars-ops/nightly"] +extract_jsonpath = ["polars-ops/extract_jsonpath"] + +# operations +bitwise = ["polars-core/bitwise", "polars-ops/bitwise"] +approx_unique = ["polars-ops/approx_unique", "polars-core/approx_unique"] +is_in = ["polars-ops/is_in"] +repeat_by = ["polars-ops/repeat_by"] +round_series = ["polars-ops/round_series"] +is_first_distinct = ["polars-core/is_first_distinct", "polars-ops/is_first_distinct"] +is_last_distinct = ["polars-core/is_last_distinct", "polars-ops/is_last_distinct"] +is_unique = ["polars-ops/is_unique"] +is_between = ["polars-ops/is_between"] +cross_join = ["polars-ops/cross_join"] +asof_join = ["polars-time", "polars-ops/asof_join"] +iejoin = ["polars-ops/iejoin"] +concat_str = [] +business = ["polars-ops/business"] +range = [] +mode = ["polars-ops/mode"] +cum_agg = ["polars-ops/cum_agg"] +interpolate = ["polars-ops/interpolate"] +interpolate_by = ["polars-ops/interpolate_by"] +rolling_window = [ + "polars-core/rolling_window", + "polars-time/rolling_window", + "polars-ops/rolling_window", +] +rolling_window_by = [ + "polars-core/rolling_window_by", + "polars-time/rolling_window_by", + "polars-ops/rolling_window_by", +] +rank = ["polars-ops/rank"] +diff = ["polars-ops/diff"] +pct_change = ["polars-ops/pct_change"] +moment = ["polars-ops/moment"] +abs = ["polars-ops/abs"] +random = ["polars-core/random"] +dynamic_group_by = ["polars-core/dynamic_group_by"] +ewma = ["polars-ops/ewma"] +ewma_by = ["polars-ops/ewma_by"] +dot_diagram = [] +unique_counts = ["polars-ops/unique_counts"] +log = ["polars-ops/log"] +chunked_ids = [] +list_to_struct = ["polars-ops/list_to_struct"] +array_to_struct = ["polars-ops/array_to_struct"] +row_hash = ["polars-core/row_hash", "polars-ops/hash"] +reinterpret = ["polars-core/reinterpret", "polars-ops/reinterpret"] +string_pad = ["polars-ops/string_pad"] +string_normalize = ["polars-ops/string_normalize"] +string_reverse = ["polars-ops/string_reverse"] +string_to_integer = ["polars-ops/string_to_integer"] +arg_where = [] +index_of = ["polars-ops/index_of"] +search_sorted = ["polars-ops/search_sorted"] +merge_sorted = ["polars-ops/merge_sorted"] +meta = [] +pivot = ["polars-core/rows", "polars-ops/pivot"] +top_k = ["polars-ops/top_k"] +semi_anti_join = ["polars-ops/semi_anti_join"] +cse = [] +propagate_nans = ["polars-ops/propagate_nans"] +coalesce = [] +fused = ["polars-ops/fused"] +array_any_all = ["polars-ops/array_any_all", "dtype-array"] +list_sets = ["polars-ops/list_sets"] +list_any_all = ["polars-ops/list_any_all"] +list_drop_nulls = ["polars-ops/list_drop_nulls"] +list_sample = ["polars-ops/list_sample"] +cutqcut = ["polars-ops/cutqcut"] +rle = ["polars-ops/rle"] +extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"] +ffi_plugin = ["libloading", "polars-ffi"] +hive_partitions = [] +peaks = ["polars-ops/peaks"] +cov = ["polars-ops/cov"] +hist = ["polars-ops/hist"] +replace = ["polars-ops/replace"] +find_many = ["polars-ops/find_many"] +month_start = ["polars-time/month_start"] +month_end = ["polars-time/month_end"] +offset_by = ["polars-time/offset_by"] + +bigidx = ["polars-core/bigidx", "polars-utils/bigidx"] +polars_cloud = ["serde"] +ir_serde = ["serde", "polars-utils/ir_serde"] + +[package.metadata.docs.rs] +features = [ + "bitwise", + "temporal", + "serde", + "rolling_window", + "rolling_window_by", + "timezones", + "dtype-date", + "extract_groups", + "dtype-datetime", + "asof_join", + "dtype-duration", + "is_first_distinct", + "pivot", + "dtype-array", + "is_last_distinct", + "dtype-time", + "array_any_all", + "month_start", + "month_end", + "offset_by", + "parquet", + "strings", + "row_hash", + "json", + "python", + "cloud", + "string_to_integer", + "list_any_all", + "pct_change", + "list_gather", + "dtype-i16", + "round_series", + "cutqcut", + "async", + "ewma", + "ewma_by", + "random", + "chunked_ids", + "repeat_by", + "is_in", + "log", + "string_reverse", + "list_sets", + "propagate_nans", + "mode", + "rank", + "hist", + "object", + "approx_unique", + "dtype-categorical", + "merge_sorted", + "bigidx", + "cov", + "list_sample", + "dtype-i8", + "fused", + "binary_encoding", + "list_drop_nulls", + "fmt", + "list_to_struct", + "string_pad", + "diff", + "rle", + "is_unique", + "find_many", + "string_encoding", + "ipc", + "index_of", + "search_sorted", + "unique_counts", + "dtype-u8", + "dtype-struct", + "peaks", + "abs", + "interpolate", + "interpolate_by", + "list_count", + "cum_agg", + "top_k", + "moment", + "semi_anti_join", + "replace", + "dtype-u16", + "regex", + "dtype-decimal", + "arg_where", + "business", + "range", + "meta", + "hive_partitions", + "concat_str", + "coalesce", + "dot_diagram", + "trigonometry", + "streaming", + "true_div", + "sign", +] +# defines the configuration attribute `docsrs` +rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars-plan/LICENSE b/crates/polars-plan/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-plan/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-plan/README.md b/crates/polars-plan/README.md new file mode 100644 index 000000000000..3a9589b18463 --- /dev/null +++ b/crates/polars-plan/README.md @@ -0,0 +1,7 @@ +# polars-plan + +`polars-plan` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, +that provides source code responsible for Polars logical planning. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-plan/build.rs b/crates/polars-plan/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-plan/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-plan/src/client/check.rs b/crates/polars-plan/src/client/check.rs new file mode 100644 index 000000000000..619e8557cf72 --- /dev/null +++ b/crates/polars-plan/src/client/check.rs @@ -0,0 +1,128 @@ +use polars_core::error::{PolarsResult, polars_err}; +use polars_io::path_utils::is_cloud_url; + +use crate::constants::POLARS_PLACEHOLDER; +use crate::dsl::{DslPlan, FileScan, ScanSources}; + +/// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud. +pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> { + if std::env::var("POLARS_SKIP_CLIENT_CHECK").as_deref() == Ok("1") { + return Ok(()); + } + for plan_node in dsl.into_iter() { + match plan_node { + #[cfg(feature = "python")] + DslPlan::PythonScan { .. } => (), + DslPlan::GroupBy { apply, .. } if apply.is_some() => { + return ineligible_error("contains map groups"); + }, + DslPlan::Scan { + sources, scan_type, .. + } => { + match sources { + ScanSources::Paths(paths) => { + if paths + .iter() + .any(|p| !is_cloud_url(p) && p.to_str() != Some(POLARS_PLACEHOLDER)) + { + return ineligible_error("contains scan of local file system"); + } + }, + ScanSources::Files(_) => { + return ineligible_error("contains scan of opened files"); + }, + ScanSources::Buffers(_) => { + return ineligible_error("contains scan of in-memory buffer"); + }, + } + + if matches!(&**scan_type, FileScan::Anonymous { .. }) { + return ineligible_error("contains anonymous scan"); + } + }, + DslPlan::Sink { payload, .. } => { + if !payload.is_cloud_destination() { + return ineligible_error("contains sink to non-cloud location"); + } + }, + _ => (), + } + } + Ok(()) +} + +fn ineligible_error(message: &str) -> PolarsResult<()> { + Err(polars_err!( + InvalidOperation: + "logical plan ineligible for execution on Polars Cloud: {message}" + )) +} + +impl DslPlan { + fn inputs<'a>(&'a self, scratch: &mut Vec<&'a DslPlan>) { + use DslPlan::*; + match self { + Select { input, .. } + | GroupBy { input, .. } + | Filter { input, .. } + | Distinct { input, .. } + | Sort { input, .. } + | Slice { input, .. } + | HStack { input, .. } + | MapFunction { input, .. } + | Sink { input, .. } + | Cache { input, .. } => scratch.push(input), + Union { inputs, .. } | HConcat { inputs, .. } | SinkMultiple { inputs } => { + scratch.extend(inputs) + }, + Join { + input_left, + input_right, + .. + } => { + scratch.push(input_left); + scratch.push(input_right); + }, + ExtContext { input, contexts } => { + scratch.push(input); + scratch.extend(contexts); + }, + IR { dsl, .. } => scratch.push(dsl), + Scan { .. } | DataFrameScan { .. } => (), + #[cfg(feature = "python")] + PythonScan { .. } => (), + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + .. + } => { + scratch.push(input_left); + scratch.push(input_right); + }, + } + } +} + +pub struct DslPlanIter<'a> { + stack: Vec<&'a DslPlan>, +} + +impl<'a> Iterator for DslPlanIter<'a> { + type Item = &'a DslPlan; + + fn next(&mut self) -> Option { + self.stack + .pop() + .inspect(|next| next.inputs(&mut self.stack)) + } +} + +impl<'a> IntoIterator for &'a DslPlan { + type Item = &'a DslPlan; + type IntoIter = DslPlanIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + DslPlanIter { stack: vec![self] } + } +} diff --git a/crates/polars-plan/src/client/mod.rs b/crates/polars-plan/src/client/mod.rs new file mode 100644 index 000000000000..9c5bd0387d0d --- /dev/null +++ b/crates/polars-plan/src/client/mod.rs @@ -0,0 +1,17 @@ +mod check; + +use polars_core::error::PolarsResult; + +use crate::dsl::DslPlan; + +/// Prepare the given [`DslPlan`] for execution on Polars Cloud. +pub fn prepare_cloud_plan(dsl: DslPlan) -> PolarsResult> { + // Check the plan for cloud eligibility. + check::assert_cloud_eligible(&dsl)?; + + // Serialize the plan. + let mut writer = Vec::new(); + dsl.serialize_versioned(&mut writer)?; + + Ok(writer) +} diff --git a/crates/polars-plan/src/constants.rs b/crates/polars-plan/src/constants.rs new file mode 100644 index 000000000000..5c02dbb79e15 --- /dev/null +++ b/crates/polars-plan/src/constants.rs @@ -0,0 +1,24 @@ +use std::sync::OnceLock; + +use polars_utils::pl_str::PlSmallStr; + +pub static MAP_LIST_NAME: &str = "map_list"; +pub static CSE_REPLACED: &str = "__POLARS_CSER_"; +pub static POLARS_TMP_PREFIX: &str = "_POLARS_"; +pub static POLARS_PLACEHOLDER: &str = "_POLARS_<>"; +pub const LEN: &str = "len"; +const LITERAL_NAME: &str = "literal"; +pub const UNLIMITED_CACHE: u32 = u32::MAX; + +// Cache the often used LITERAL and LEN constants +static LITERAL_NAME_INIT: OnceLock = OnceLock::new(); +static LEN_INIT: OnceLock = OnceLock::new(); + +pub fn get_literal_name() -> &'static PlSmallStr { + LITERAL_NAME_INIT.get_or_init(|| PlSmallStr::from_static(LITERAL_NAME)) +} +pub(crate) fn get_len_name() -> PlSmallStr { + LEN_INIT + .get_or_init(|| PlSmallStr::from_static(LEN)) + .clone() +} diff --git a/crates/polars-plan/src/dsl/arithmetic.rs b/crates/polars-plan/src/dsl/arithmetic.rs new file mode 100644 index 000000000000..c2b68a068961 --- /dev/null +++ b/crates/polars-plan/src/dsl/arithmetic.rs @@ -0,0 +1,176 @@ +use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; + +use super::*; + +// Arithmetic ops +impl Add for Expr { + type Output = Expr; + + fn add(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::Plus, rhs) + } +} + +impl Sub for Expr { + type Output = Expr; + + fn sub(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::Minus, rhs) + } +} + +impl Div for Expr { + type Output = Expr; + + fn div(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::Divide, rhs) + } +} + +impl Mul for Expr { + type Output = Expr; + + fn mul(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::Multiply, rhs) + } +} + +impl Rem for Expr { + type Output = Expr; + + fn rem(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::Modulus, rhs) + } +} + +impl Neg for Expr { + type Output = Expr; + + fn neg(self) -> Self::Output { + self.map_unary(FunctionExpr::Negate) + } +} + +impl Expr { + /// Floor divide `self` by `rhs`. + pub fn floor_div(self, rhs: Self) -> Self { + binary_expr(self, Operator::FloorDivide, rhs) + } + + /// Raise expression to the power `exponent` + pub fn pow>(self, exponent: E) -> Self { + self.map_binary(PowFunction::Generic, exponent.into()) + } + + /// Compute the square root of the given expression + pub fn sqrt(self) -> Self { + self.map_unary(PowFunction::Sqrt) + } + + /// Compute the cube root of the given expression + pub fn cbrt(self) -> Self { + self.map_unary(PowFunction::Cbrt) + } + + /// Compute the cosine of the given expression + #[cfg(feature = "trigonometry")] + pub fn cos(self) -> Self { + self.map_unary(TrigonometricFunction::Cos) + } + + /// Compute the cotangent of the given expression + #[cfg(feature = "trigonometry")] + pub fn cot(self) -> Self { + self.map_unary(TrigonometricFunction::Cot) + } + + /// Compute the sine of the given expression + #[cfg(feature = "trigonometry")] + pub fn sin(self) -> Self { + self.map_unary(TrigonometricFunction::Sin) + } + + /// Compute the tangent of the given expression + #[cfg(feature = "trigonometry")] + pub fn tan(self) -> Self { + self.map_unary(TrigonometricFunction::Tan) + } + + /// Compute the inverse cosine of the given expression + #[cfg(feature = "trigonometry")] + pub fn arccos(self) -> Self { + self.map_unary(TrigonometricFunction::ArcCos) + } + + /// Compute the inverse sine of the given expression + #[cfg(feature = "trigonometry")] + pub fn arcsin(self) -> Self { + self.map_unary(TrigonometricFunction::ArcSin) + } + + /// Compute the inverse tangent of the given expression + #[cfg(feature = "trigonometry")] + pub fn arctan(self) -> Self { + self.map_unary(TrigonometricFunction::ArcTan) + } + + /// Compute the inverse tangent of the given expression, with the angle expressed as the argument of a complex number + #[cfg(feature = "trigonometry")] + pub fn arctan2(self, x: Self) -> Self { + self.map_binary(FunctionExpr::Atan2, x) + } + + /// Compute the hyperbolic cosine of the given expression + #[cfg(feature = "trigonometry")] + pub fn cosh(self) -> Self { + self.map_unary(TrigonometricFunction::Cosh) + } + + /// Compute the hyperbolic sine of the given expression + #[cfg(feature = "trigonometry")] + pub fn sinh(self) -> Self { + self.map_unary(TrigonometricFunction::Sinh) + } + + /// Compute the hyperbolic tangent of the given expression + #[cfg(feature = "trigonometry")] + pub fn tanh(self) -> Self { + self.map_unary(TrigonometricFunction::Tanh) + } + + /// Compute the inverse hyperbolic cosine of the given expression + #[cfg(feature = "trigonometry")] + pub fn arccosh(self) -> Self { + self.map_unary(TrigonometricFunction::ArcCosh) + } + + /// Compute the inverse hyperbolic sine of the given expression + #[cfg(feature = "trigonometry")] + pub fn arcsinh(self) -> Self { + self.map_unary(TrigonometricFunction::ArcSinh) + } + + /// Compute the inverse hyperbolic tangent of the given expression + #[cfg(feature = "trigonometry")] + pub fn arctanh(self) -> Self { + self.map_unary(TrigonometricFunction::ArcTanh) + } + + /// Convert from radians to degrees + #[cfg(feature = "trigonometry")] + pub fn degrees(self) -> Self { + self.map_unary(TrigonometricFunction::Degrees) + } + + /// Convert from degrees to radians + #[cfg(feature = "trigonometry")] + pub fn radians(self) -> Self { + self.map_unary(TrigonometricFunction::Radians) + } + + /// Compute the sign of the given expression + #[cfg(feature = "sign")] + pub fn sign(self) -> Self { + self.map_unary(FunctionExpr::Sign) + } +} diff --git a/crates/polars-plan/src/dsl/arity.rs b/crates/polars-plan/src/dsl/arity.rs new file mode 100644 index 000000000000..9883936f6c10 --- /dev/null +++ b/crates/polars-plan/src/dsl/arity.rs @@ -0,0 +1,155 @@ +use super::*; + +/// Utility struct for the `when-then-otherwise` expression. +/// +/// Represents the state of the expression after [when] is called. +/// +/// In this state, `then` must be called to continue to finish the expression. +#[derive(Clone)] +pub struct When { + condition: Expr, +} + +/// Utility struct for the `when-then-otherwise` expression. +/// +/// Represents the state of the expression after `when(...).then(...)` is called. +#[derive(Clone)] +pub struct Then { + condition: Expr, + statement: Expr, +} + +/// Utility struct for the `when-then-otherwise` expression. +/// +/// Represents the state of the expression after an additional `when` is called. +/// +/// In this state, `then` must be called to continue to finish the expression. +#[derive(Clone)] +pub struct ChainedWhen { + conditions: Vec, + statements: Vec, +} + +/// Utility struct for the `when-then-otherwise` expression. +/// +/// Represents the state of the expression after an additional `then` is called. +#[derive(Clone)] +pub struct ChainedThen { + conditions: Vec, + statements: Vec, +} + +impl When { + /// Add a condition to the `when-then-otherwise` expression. + pub fn then>(self, expr: E) -> Then { + Then { + condition: self.condition, + statement: expr.into(), + } + } +} + +impl Then { + /// Attach a statement to the corresponding condition. + pub fn when>(self, condition: E) -> ChainedWhen { + ChainedWhen { + conditions: vec![self.condition, condition.into()], + statements: vec![self.statement], + } + } + + /// Define a default for the `when-then-otherwise` expression. + pub fn otherwise>(self, statement: E) -> Expr { + ternary_expr(self.condition, self.statement, statement.into()) + } +} + +impl ChainedWhen { + pub fn then>(mut self, statement: E) -> ChainedThen { + self.statements.push(statement.into()); + ChainedThen { + conditions: self.conditions, + statements: self.statements, + } + } +} + +impl ChainedThen { + /// Add another condition to the `when-then-otherwise` expression. + pub fn when>(mut self, condition: E) -> ChainedWhen { + self.conditions.push(condition.into()); + + ChainedWhen { + conditions: self.conditions, + statements: self.statements, + } + } + + /// Define a default for the `when-then-otherwise` expression. + pub fn otherwise>(self, expr: E) -> Expr { + // we iterate the preds/ exprs last in first out + // and nest them. + // + // // this expr: + // when((col('x') == 'a')).then(1) + // .when(col('x') == 'b').then(2) + // .when(col('x') == 'c').then(3) + // .otherwise(4) + // + // needs to become: + // when((col('x') == 'a')).then(1) - + // .otherwise( | + // when(col('x') == 'b').then(2) - | + // .otherwise( | | + // pl.when(col('x') == 'c').then(3) | | + // .otherwise(4) | inner | outer + // ) | | + // ) _| _| + // + // by iterating LIFO we first create + // `inner` and then assign that to `otherwise`, + // which will be used in the next layer `outer` + // + + let conditions_iter = self.conditions.into_iter().rev(); + let mut statements_iter = self.statements.into_iter().rev(); + + let mut otherwise = expr.into(); + + for e in conditions_iter { + otherwise = ternary_expr( + e, + statements_iter + .next() + .expect("expr expected, did you call when().then().otherwise?"), + otherwise, + ); + } + + otherwise + } +} + +/// Start a `when-then-otherwise` expression. +pub fn when>(condition: E) -> When { + When { + condition: condition.into(), + } +} + +pub fn ternary_expr(predicate: Expr, truthy: Expr, falsy: Expr) -> Expr { + Expr::Ternary { + predicate: Arc::new(predicate), + truthy: Arc::new(truthy), + falsy: Arc::new(falsy), + } +} + +/// Compute `op(l, r)` (or equivalently `l op r`). `l` and `r` must have types compatible with the Operator. +pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { + Expr::BinaryExpr { + left: Arc::new(l), + op, + right: Arc::new(r), + } +} diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs new file mode 100644 index 000000000000..3c72c616dcd8 --- /dev/null +++ b/crates/polars-plan/src/dsl/array.rs @@ -0,0 +1,184 @@ +use polars_core::prelude::*; +#[cfg(feature = "array_to_struct")] +use polars_ops::chunked_array::array::{ + ArrToStructNameGenerator, ToStruct, arr_default_struct_name_gen, +}; + +use crate::dsl::function_expr::ArrayFunction; +use crate::prelude::*; + +/// Specialized expressions for [`Series`] of [`DataType::Array`]. +pub struct ArrayNameSpace(pub Expr); + +impl ArrayNameSpace { + pub fn len(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Length)) + } + /// Compute the maximum of the items in every subarray. + pub fn max(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Max)) + } + + /// Compute the minimum of the items in every subarray. + pub fn min(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Min)) + } + + /// Compute the sum of the items in every subarray. + pub fn sum(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Sum)) + } + + /// Compute the std of the items in every subarray. + pub fn std(self, ddof: u8) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Std(ddof))) + } + + /// Compute the var of the items in every subarray. + pub fn var(self, ddof: u8) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Var(ddof))) + } + + /// Compute the median of the items in every subarray. + pub fn median(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Median)) + } + + /// Keep only the unique values in every sub-array. + pub fn unique(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Unique(false))) + } + + /// Keep only the unique values in every sub-array. + pub fn unique_stable(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Unique(true))) + } + + pub fn n_unique(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::NUnique)) + } + + /// Cast the Array column to List column with the same inner data type. + pub fn to_list(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::ToList)) + } + + #[cfg(feature = "array_any_all")] + /// Evaluate whether all boolean values are true for every subarray. + pub fn all(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::All)) + } + + #[cfg(feature = "array_any_all")] + /// Evaluate whether any boolean value is true for every subarray + pub fn any(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Any)) + } + + pub fn sort(self, options: SortOptions) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Sort(options))) + } + + pub fn reverse(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Reverse)) + } + + pub fn arg_min(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::ArgMin)) + } + + pub fn arg_max(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::ArgMax)) + } + + /// Get items in every sub-array by index. + pub fn get(self, index: Expr, null_on_oob: bool) -> Expr { + self.0.map_binary( + FunctionExpr::ArrayExpr(ArrayFunction::Get(null_on_oob)), + index, + ) + } + + /// Join all string items in a sub-array and place a separator between them. + /// # Error + /// Raise if inner type of array is not `DataType::String`. + pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr { + self.0.map_binary( + FunctionExpr::ArrayExpr(ArrayFunction::Join(ignore_nulls)), + separator, + ) + } + + #[cfg(feature = "is_in")] + /// Check if the sub-array contains specific element + pub fn contains>(self, other: E) -> Expr { + self.0.map_binary( + FunctionExpr::ArrayExpr(ArrayFunction::Contains), + other.into(), + ) + } + + #[cfg(feature = "array_count")] + /// Count how often the value produced by ``element`` occurs. + pub fn count_matches>(self, element: E) -> Expr { + self.0.map_binary( + FunctionExpr::ArrayExpr(ArrayFunction::CountMatches), + element.into(), + ) + } + + #[cfg(feature = "array_to_struct")] + pub fn to_struct(self, name_generator: Option) -> PolarsResult { + Ok(self + .0 + .map( + move |s| { + s.array()? + .to_struct(name_generator.clone()) + .map(|s| Some(s.into_column())) + }, + GetOutput::map_dtype(move |dt: &DataType| { + let DataType::Array(inner, width) = dt else { + polars_bail!(InvalidOperation: "expected Array type, got: {}", dt) + }; + + let fields = (0..*width) + .map(|i| { + let name = arr_default_struct_name_gen(i); + Field::new(name, inner.as_ref().clone()) + }) + .collect(); + Ok(DataType::Struct(fields)) + }), + ) + .with_fmt("arr.to_struct")) + } + + /// Shift every sub-array. + pub fn shift(self, n: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::ArrayExpr(ArrayFunction::Shift), n) + } + /// Returns a column with a separate row for every array element. + pub fn explode(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Explode)) + } +} diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs new file mode 100644 index 000000000000..7edd2f0cb21a --- /dev/null +++ b/crates/polars-plan/src/dsl/binary.rs @@ -0,0 +1,64 @@ +use super::*; +/// Specialized expressions for [`Series`] of [`DataType::String`]. +pub struct BinaryNameSpace(pub(crate) Expr); + +impl BinaryNameSpace { + /// Check if a binary value contains a literal binary. + pub fn contains_literal(self, pat: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::BinaryExpr(BinaryFunction::Contains), pat) + } + + /// Check if a binary value ends with the given sequence. + pub fn ends_with(self, sub: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::BinaryExpr(BinaryFunction::EndsWith), sub) + } + + /// Check if a binary value starts with the given sequence. + pub fn starts_with(self, sub: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::BinaryExpr(BinaryFunction::StartsWith), sub) + } + + /// Return the size (number of bytes) in each element. + pub fn size_bytes(self) -> Expr { + self.0 + .map_unary(FunctionExpr::BinaryExpr(BinaryFunction::Size)) + } + + #[cfg(feature = "binary_encoding")] + pub fn hex_decode(self, strict: bool) -> Expr { + self.0 + .map_unary(FunctionExpr::BinaryExpr(BinaryFunction::HexDecode(strict))) + } + + #[cfg(feature = "binary_encoding")] + pub fn hex_encode(self) -> Expr { + self.0 + .map_unary(FunctionExpr::BinaryExpr(BinaryFunction::HexEncode)) + } + + #[cfg(feature = "binary_encoding")] + pub fn base64_decode(self, strict: bool) -> Expr { + self.0 + .map_unary(FunctionExpr::BinaryExpr(BinaryFunction::Base64Decode( + strict, + ))) + } + + #[cfg(feature = "binary_encoding")] + pub fn base64_encode(self) -> Expr { + self.0 + .map_unary(FunctionExpr::BinaryExpr(BinaryFunction::Base64Encode)) + } + + #[cfg(feature = "binary_encoding")] + pub fn from_buffer(self, to_type: DataType, is_little_endian: bool) -> Expr { + self.0 + .map_unary(FunctionExpr::BinaryExpr(BinaryFunction::FromBuffer( + to_type, + is_little_endian, + ))) + } +} diff --git a/crates/polars-plan/src/dsl/bitwise.rs b/crates/polars-plan/src/dsl/bitwise.rs new file mode 100644 index 000000000000..07a3fb7052f4 --- /dev/null +++ b/crates/polars-plan/src/dsl/bitwise.rs @@ -0,0 +1,48 @@ +use super::{BitwiseFunction, Expr, FunctionExpr}; + +impl Expr { + /// Evaluate the number of set bits. + pub fn bitwise_count_ones(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::CountOnes)) + } + + /// Evaluate the number of unset bits. + pub fn bitwise_count_zeros(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::CountZeros)) + } + + /// Evaluate the number most-significant set bits before seeing an unset bit. + pub fn bitwise_leading_ones(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::LeadingOnes)) + } + + /// Evaluate the number most-significant unset bits before seeing an set bit. + pub fn bitwise_leading_zeros(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::LeadingZeros)) + } + + /// Evaluate the number least-significant set bits before seeing an unset bit. + pub fn bitwise_trailing_ones(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::TrailingOnes)) + } + + /// Evaluate the number least-significant unset bits before seeing an set bit. + pub fn bitwise_trailing_zeros(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::TrailingZeros)) + } + + /// Perform an aggregation of bitwise ANDs + pub fn bitwise_and(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::And)) + } + + /// Perform an aggregation of bitwise ORs + pub fn bitwise_or(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::Or)) + } + + /// Perform an aggregation of bitwise XORs + pub fn bitwise_xor(self) -> Self { + self.map_unary(FunctionExpr::Bitwise(BitwiseFunction::Xor)) + } +} diff --git a/crates/polars-plan/src/dsl/builder_dsl.rs b/crates/polars-plan/src/dsl/builder_dsl.rs new file mode 100644 index 000000000000..d053f58926c9 --- /dev/null +++ b/crates/polars-plan/src/dsl/builder_dsl.rs @@ -0,0 +1,378 @@ +use std::sync::Arc; + +use polars_core::prelude::*; +#[cfg(feature = "csv")] +use polars_io::csv::read::CsvReadOptions; +#[cfg(feature = "ipc")] +use polars_io::ipc::IpcScanOptions; +#[cfg(feature = "parquet")] +use polars_io::parquet::read::ParquetOptions; + +#[cfg(feature = "python")] +use crate::dsl::python_dsl::PythonFunction; +use crate::prelude::*; + +pub struct DslBuilder(pub DslPlan); + +impl From for DslBuilder { + fn from(lp: DslPlan) -> Self { + DslBuilder(lp) + } +} + +impl DslBuilder { + pub fn anonymous_scan( + function: Arc, + options: AnonymousScanOptions, + unified_scan_args: UnifiedScanArgs, + ) -> PolarsResult { + let schema = unified_scan_args.schema.clone().ok_or_else(|| { + polars_err!( + ComputeError: + "anonymous scan requires schema to be specified in unified_scan_args" + ) + })?; + + Ok(DslPlan::Scan { + sources: ScanSources::Buffers(Arc::default()), + file_info: Some(FileInfo { + schema: schema.clone(), + reader_schema: Some(either::Either::Right(schema)), + ..Default::default() + }), + unified_scan_args: Box::new(unified_scan_args), + scan_type: Box::new(FileScan::Anonymous { + function, + options: Arc::new(options), + }), + cached_ir: Default::default(), + } + .into()) + } + + #[cfg(feature = "parquet")] + #[allow(clippy::too_many_arguments)] + pub fn scan_parquet( + sources: ScanSources, + options: ParquetOptions, + unified_scan_args: UnifiedScanArgs, + ) -> PolarsResult { + Ok(DslPlan::Scan { + sources, + file_info: None, + unified_scan_args: Box::new(unified_scan_args), + scan_type: Box::new(FileScan::Parquet { + options, + metadata: None, + }), + cached_ir: Default::default(), + } + .into()) + } + + #[cfg(feature = "ipc")] + #[allow(clippy::too_many_arguments)] + pub fn scan_ipc( + sources: ScanSources, + options: IpcScanOptions, + unified_scan_args: UnifiedScanArgs, + ) -> PolarsResult { + Ok(DslPlan::Scan { + sources, + file_info: None, + unified_scan_args: Box::new(unified_scan_args), + scan_type: Box::new(FileScan::Ipc { + options, + metadata: None, + }), + cached_ir: Default::default(), + } + .into()) + } + + #[allow(clippy::too_many_arguments)] + #[cfg(feature = "csv")] + pub fn scan_csv( + sources: ScanSources, + options: CsvReadOptions, + unified_scan_args: UnifiedScanArgs, + ) -> PolarsResult { + Ok(DslPlan::Scan { + sources, + file_info: None, + unified_scan_args: Box::new(unified_scan_args), + scan_type: Box::new(FileScan::Csv { options }), + cached_ir: Default::default(), + } + .into()) + } + + pub fn cache(self) -> Self { + let input = Arc::new(self.0); + let id = input.as_ref() as *const DslPlan as usize; + DslPlan::Cache { input, id }.into() + } + + pub fn drop(self, to_drop: Vec, strict: bool) -> Self { + self.map_private(DslFunction::Drop(DropFunction { to_drop, strict })) + } + + pub fn project(self, exprs: Vec, options: ProjectionOptions) -> Self { + DslPlan::Select { + expr: exprs, + input: Arc::new(self.0), + options, + } + .into() + } + + pub fn fill_null(self, fill_value: Expr) -> Self { + self.project( + vec![all().fill_null(fill_value)], + ProjectionOptions { + duplicate_check: false, + ..Default::default() + }, + ) + } + + pub fn drop_nans(self, subset: Option>) -> Self { + if let Some(subset) = subset { + self.filter( + all_horizontal( + subset + .into_iter() + .map(|v| v.is_not_nan()) + .collect::>(), + ) + .unwrap(), + ) + } else { + self.filter( + // TODO: when Decimal supports NaN, include here + all_horizontal([dtype_cols([DataType::Float32, DataType::Float64]).is_not_nan()]) + .unwrap(), + ) + } + } + + pub fn drop_nulls(self, subset: Option>) -> Self { + if let Some(subset) = subset { + self.filter( + all_horizontal( + subset + .into_iter() + .map(|v| v.is_not_null()) + .collect::>(), + ) + .unwrap(), + ) + } else { + self.filter(all_horizontal([all().is_not_null()]).unwrap()) + } + } + + pub fn fill_nan(self, fill_value: Expr) -> Self { + self.map_private(DslFunction::FillNan(fill_value)) + } + + pub fn with_columns(self, exprs: Vec, options: ProjectionOptions) -> Self { + if exprs.is_empty() { + return self; + } + + DslPlan::HStack { + input: Arc::new(self.0), + exprs, + options, + } + .into() + } + + pub fn with_context(self, contexts: Vec) -> Self { + DslPlan::ExtContext { + input: Arc::new(self.0), + contexts, + } + .into() + } + + /// Apply a filter + pub fn filter(self, predicate: Expr) -> Self { + DslPlan::Filter { + predicate, + input: Arc::new(self.0), + } + .into() + } + + pub fn group_by>( + self, + keys: Vec, + aggs: E, + apply: Option<(Arc, SchemaRef)>, + maintain_order: bool, + #[cfg(feature = "dynamic_group_by")] dynamic_options: Option, + #[cfg(feature = "dynamic_group_by")] rolling_options: Option, + ) -> Self { + let aggs = aggs.as_ref().to_vec(); + let options = GroupbyOptions { + #[cfg(feature = "dynamic_group_by")] + dynamic: dynamic_options, + #[cfg(feature = "dynamic_group_by")] + rolling: rolling_options, + slice: None, + }; + + DslPlan::GroupBy { + input: Arc::new(self.0), + keys, + aggs, + apply, + maintain_order, + options: Arc::new(options), + } + .into() + } + + pub fn build(self) -> DslPlan { + self.0 + } + + pub fn from_existing_df(df: DataFrame) -> Self { + let schema = df.schema().clone(); + DslPlan::DataFrameScan { + df: Arc::new(df), + schema, + } + .into() + } + + pub fn sort(self, by_column: Vec, sort_options: SortMultipleOptions) -> Self { + DslPlan::Sort { + input: Arc::new(self.0), + by_column, + slice: None, + sort_options, + } + .into() + } + + pub fn explode(self, columns: Vec, allow_empty: bool) -> Self { + DslPlan::MapFunction { + input: Arc::new(self.0), + function: DslFunction::Explode { + columns, + allow_empty, + }, + } + .into() + } + + #[cfg(feature = "pivot")] + pub fn unpivot(self, args: UnpivotArgsDSL) -> Self { + DslPlan::MapFunction { + input: Arc::new(self.0), + function: DslFunction::Unpivot { args }, + } + .into() + } + + pub fn row_index(self, name: PlSmallStr, offset: Option) -> Self { + DslPlan::MapFunction { + input: Arc::new(self.0), + function: DslFunction::RowIndex { name, offset }, + } + .into() + } + + pub fn distinct(self, options: DistinctOptionsDSL) -> Self { + DslPlan::Distinct { + input: Arc::new(self.0), + options, + } + .into() + } + + pub fn slice(self, offset: i64, len: IdxSize) -> Self { + DslPlan::Slice { + input: Arc::new(self.0), + offset, + len, + } + .into() + } + + pub fn join( + self, + other: DslPlan, + left_on: Vec, + right_on: Vec, + options: Arc, + ) -> Self { + DslPlan::Join { + input_left: Arc::new(self.0), + input_right: Arc::new(other), + left_on, + right_on, + predicates: Default::default(), + options, + } + .into() + } + pub fn map_private(self, function: DslFunction) -> Self { + DslPlan::MapFunction { + input: Arc::new(self.0), + function, + } + .into() + } + + #[cfg(feature = "python")] + pub fn map_python( + self, + function: PythonFunction, + optimizations: AllowedOptimizations, + schema: Option, + validate_output: bool, + ) -> Self { + DslPlan::MapFunction { + input: Arc::new(self.0), + function: DslFunction::OpaquePython(OpaquePythonUdf { + function, + schema, + predicate_pd: optimizations.contains(OptFlags::PREDICATE_PUSHDOWN), + projection_pd: optimizations.contains(OptFlags::PROJECTION_PUSHDOWN), + streamable: optimizations.contains(OptFlags::STREAMING), + validate_output, + }), + } + .into() + } + + pub fn map( + self, + function: F, + optimizations: AllowedOptimizations, + schema: Option>, + name: PlSmallStr, + ) -> Self + where + F: DataFrameUdf + 'static, + { + let function = Arc::new(function); + + DslPlan::MapFunction { + input: Arc::new(self.0), + function: DslFunction::FunctionIR(FunctionIR::Opaque { + function, + schema, + predicate_pd: optimizations.contains(OptFlags::PREDICATE_PUSHDOWN), + projection_pd: optimizations.contains(OptFlags::PROJECTION_PUSHDOWN), + streamable: optimizations.contains(OptFlags::STREAMING), + fmt_str: name, + }), + } + .into() + } +} diff --git a/crates/polars-plan/src/dsl/cat.rs b/crates/polars-plan/src/dsl/cat.rs new file mode 100644 index 000000000000..6cc65a3080dd --- /dev/null +++ b/crates/polars-plan/src/dsl/cat.rs @@ -0,0 +1,35 @@ +use super::*; + +/// Specialized expressions for Categorical dtypes. +pub struct CategoricalNameSpace(pub(crate) Expr); + +impl CategoricalNameSpace { + pub fn get_categories(self) -> Expr { + self.0.map_unary(CategoricalFunction::GetCategories) + } + + #[cfg(feature = "strings")] + pub fn len_bytes(self) -> Expr { + self.0.map_unary(CategoricalFunction::LenBytes) + } + + #[cfg(feature = "strings")] + pub fn len_chars(self) -> Expr { + self.0.map_unary(CategoricalFunction::LenChars) + } + + #[cfg(feature = "strings")] + pub fn starts_with(self, prefix: String) -> Expr { + self.0.map_unary(CategoricalFunction::StartsWith(prefix)) + } + + #[cfg(feature = "strings")] + pub fn ends_with(self, suffix: String) -> Expr { + self.0.map_unary(CategoricalFunction::EndsWith(suffix)) + } + + #[cfg(feature = "strings")] + pub fn slice(self, offset: i64, length: Option) -> Expr { + self.0.map_unary(CategoricalFunction::Slice(offset, length)) + } +} diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs new file mode 100644 index 000000000000..cf533e040057 --- /dev/null +++ b/crates/polars-plan/src/dsl/dt.rs @@ -0,0 +1,356 @@ +use super::*; + +/// Specialized expressions for [`Series`] with dates/datetimes. +pub struct DateLikeNameSpace(pub(crate) Expr); + +impl DateLikeNameSpace { + /// Add a given number of business days. + #[cfg(feature = "business")] + pub fn add_business_days( + self, + n: Expr, + week_mask: [bool; 7], + holidays: Vec, + roll: Roll, + ) -> Expr { + self.0.map_binary( + FunctionExpr::Business(BusinessFunction::AddBusinessDay { + week_mask, + holidays, + roll, + }), + n, + ) + } + + /// Convert from Date/Time/Datetime into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + pub fn to_string(self, format: &str) -> Expr { + let format = format.to_string(); + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::ToString( + format, + ))) + } + + /// Convert from Date/Time/Datetime into String with the given format. + /// See [chrono strftime/strptime](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html). + /// + /// Alias for `to_string`. + pub fn strftime(self, format: &str) -> Expr { + self.to_string(format) + } + + /// Change the underlying [`TimeUnit`]. And update the data accordingly. + pub fn cast_time_unit(self, tu: TimeUnit) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::CastTimeUnit( + tu, + ))) + } + + /// Change the underlying [`TimeUnit`] of the [`Series`]. This does not modify the data. + pub fn with_time_unit(self, tu: TimeUnit) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::WithTimeUnit( + tu, + ))) + } + + /// Change the underlying [`TimeZone`] of the [`Series`]. This does not modify the data. + #[cfg(feature = "timezones")] + pub fn convert_time_zone(self, time_zone: TimeZone) -> Expr { + self.0.map_unary(FunctionExpr::TemporalExpr( + TemporalFunction::ConvertTimeZone(time_zone), + )) + } + + /// Get the millennium of a Date/Datetime + pub fn millennium(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Millennium)) + } + + /// Get the century of a Date/Datetime + pub fn century(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Century)) + } + + /// Get the year of a Date/Datetime + pub fn year(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Year)) + } + + /// Determine whether days are business days. + #[cfg(feature = "business")] + pub fn is_business_day(self, week_mask: [bool; 7], holidays: Vec) -> Expr { + self.0 + .map_unary(FunctionExpr::Business(BusinessFunction::IsBusinessDay { + week_mask, + holidays, + })) + } + + // Compute whether the year of a Date/Datetime is a leap year. + pub fn is_leap_year(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::IsLeapYear)) + } + + /// Get the iso-year of a Date/Datetime. + /// This may not correspond with a calendar year. + pub fn iso_year(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::IsoYear)) + } + + /// Get the month of a Date/Datetime. + pub fn month(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Month)) + } + + /// Extract quarter from underlying NaiveDateTime representation. + /// Quarters range from 1 to 4. + pub fn quarter(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Quarter)) + } + + /// Extract the week from the underlying Date representation. + /// Can be performed on Date and Datetime + /// + /// Returns the ISO week number starting from 1. + /// The return value ranges from 1 to 53. (The last week of year differs by years.) + pub fn week(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Week)) + } + + /// Extract the ISO week day from the underlying Date representation. + /// Can be performed on Date and Datetime. + /// + /// Returns the weekday number where monday = 1 and sunday = 7 + pub fn weekday(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::WeekDay)) + } + + /// Get the month of a Date/Datetime. + pub fn day(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Day)) + } + + /// Get the ordinal_day of a Date/Datetime. + pub fn ordinal_day(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::OrdinalDay)) + } + + /// Get the (local) time of a Date/Datetime/Time. + pub fn time(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Time)) + } + + /// Get the (local) date of a Date/Datetime. + pub fn date(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Date)) + } + + /// Get the (local) datetime of a Datetime. + pub fn datetime(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Datetime)) + } + + /// Get the hour of a Datetime/Time64. + pub fn hour(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Hour)) + } + + /// Get the minute of a Datetime/Time64. + pub fn minute(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Minute)) + } + + /// Get the second of a Datetime/Time64. + pub fn second(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Second)) + } + + /// Get the millisecond of a Time64 (scaled from nanosecs). + pub fn millisecond(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Millisecond)) + } + + /// Get the microsecond of a Time64 (scaled from nanosecs). + pub fn microsecond(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Microsecond)) + } + + /// Get the nanosecond part of a Time64. + pub fn nanosecond(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::Nanosecond)) + } + + /// Return the timestamp (UNIX epoch) of a Datetime/Date. + pub fn timestamp(self, tu: TimeUnit) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::TimeStamp(tu))) + } + + /// Truncate the Datetime/Date range into buckets. + pub fn truncate(self, every: Expr) -> Expr { + self.0.map_binary( + FunctionExpr::TemporalExpr(TemporalFunction::Truncate), + every, + ) + } + + /// Roll backward to the first day of the month. + #[cfg(feature = "month_start")] + pub fn month_start(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::MonthStart)) + } + + /// Roll forward to the last day of the month. + #[cfg(feature = "month_end")] + pub fn month_end(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::MonthEnd)) + } + + /// Get the base offset from UTC. + #[cfg(feature = "timezones")] + pub fn base_utc_offset(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::BaseUtcOffset)) + } + + /// Get the additional offset from UTC currently in effect (usually due to daylight saving time). + #[cfg(feature = "timezones")] + pub fn dst_offset(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::DSTOffset)) + } + + /// Round the Datetime/Date range into buckets. + pub fn round(self, every: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::TemporalExpr(TemporalFunction::Round), every) + } + + /// Offset this `Date/Datetime` by a given offset [`Duration`]. + /// This will take leap years/ months into account. + #[cfg(feature = "offset_by")] + pub fn offset_by(self, by: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::TemporalExpr(TemporalFunction::OffsetBy), by) + } + + #[cfg(feature = "timezones")] + pub fn replace_time_zone( + self, + time_zone: Option, + ambiguous: Expr, + non_existent: NonExistent, + ) -> Expr { + self.0.map_binary( + FunctionExpr::TemporalExpr(TemporalFunction::ReplaceTimeZone(time_zone, non_existent)), + ambiguous, + ) + } + + /// Combine an existing Date/Datetime with a Time, creating a new Datetime value. + pub fn combine(self, time: Expr, tu: TimeUnit) -> Expr { + self.0.map_binary( + FunctionExpr::TemporalExpr(TemporalFunction::Combine(tu)), + time, + ) + } + + /// Express a Duration in terms of its total number of integer days. + pub fn total_days(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::TotalDays)) + } + + /// Express a Duration in terms of its total number of integer hours. + pub fn total_hours(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::TotalHours)) + } + + /// Express a Duration in terms of its total number of integer minutes. + pub fn total_minutes(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::TotalMinutes)) + } + + /// Express a Duration in terms of its total number of integer seconds. + pub fn total_seconds(self) -> Expr { + self.0 + .map_unary(FunctionExpr::TemporalExpr(TemporalFunction::TotalSeconds)) + } + + /// Express a Duration in terms of its total number of milliseconds. + pub fn total_milliseconds(self) -> Expr { + self.0.map_unary(FunctionExpr::TemporalExpr( + TemporalFunction::TotalMilliseconds, + )) + } + + /// Express a Duration in terms of its total number of microseconds. + pub fn total_microseconds(self) -> Expr { + self.0.map_unary(FunctionExpr::TemporalExpr( + TemporalFunction::TotalMicroseconds, + )) + } + + /// Express a Duration in terms of its total number of nanoseconds. + pub fn total_nanoseconds(self) -> Expr { + self.0.map_unary(FunctionExpr::TemporalExpr( + TemporalFunction::TotalNanoseconds, + )) + } + + /// Replace the time units of a value + #[allow(clippy::too_many_arguments)] + pub fn replace( + self, + year: Expr, + month: Expr, + day: Expr, + hour: Expr, + minute: Expr, + second: Expr, + microsecond: Expr, + ambiguous: Expr, + ) -> Expr { + self.0.map_n_ary( + FunctionExpr::TemporalExpr(TemporalFunction::Replace), + [ + year, + month, + day, + hour, + minute, + second, + microsecond, + ambiguous, + ], + ) + } +} diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs new file mode 100644 index 000000000000..748879e21d55 --- /dev/null +++ b/crates/polars-plan/src/dsl/expr.rs @@ -0,0 +1,548 @@ +use std::fmt::{Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; + +use bytes::Bytes; +use polars_compute::rolling::QuantileMethod; +use polars_core::chunked_array::cast::CastOptions; +use polars_core::error::feature_gated; +use polars_core::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +pub use super::expr_dyn_fn::*; +use crate::prelude::*; + +#[derive(PartialEq, Clone, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum AggExpr { + Min { + input: Arc, + propagate_nans: bool, + }, + Max { + input: Arc, + propagate_nans: bool, + }, + Median(Arc), + NUnique(Arc), + First(Arc), + Last(Arc), + Mean(Arc), + Implode(Arc), + // include_nulls + Count(Arc, bool), + Quantile { + expr: Arc, + quantile: Arc, + method: QuantileMethod, + }, + Sum(Arc), + AggGroups(Arc), + Std(Arc, u8), + Var(Arc, u8), +} + +impl AsRef for AggExpr { + fn as_ref(&self) -> &Expr { + use AggExpr::*; + match self { + Min { input, .. } => input, + Max { input, .. } => input, + Median(e) => e, + NUnique(e) => e, + First(e) => e, + Last(e) => e, + Mean(e) => e, + Implode(e) => e, + Count(e, _) => e, + Quantile { expr, .. } => expr, + Sum(e) => e, + AggGroups(e) => e, + Std(e, _) => e, + Var(e, _) => e, + } + } +} + +/// Expressions that can be used in various contexts. +/// +/// Queries consist of multiple expressions. +/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using +/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info. +#[derive(Clone, PartialEq)] +#[must_use] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Expr { + Alias(Arc, PlSmallStr), + Column(PlSmallStr), + Columns(Arc<[PlSmallStr]>), + DtypeColumn(Vec), + IndexColumn(Arc<[i64]>), + Literal(LiteralValue), + BinaryExpr { + left: Arc, + op: Operator, + right: Arc, + }, + Cast { + expr: Arc, + dtype: DataType, + options: CastOptions, + }, + Sort { + expr: Arc, + options: SortOptions, + }, + Gather { + expr: Arc, + idx: Arc, + returns_scalar: bool, + }, + SortBy { + expr: Arc, + by: Vec, + sort_options: SortMultipleOptions, + }, + Agg(AggExpr), + /// A ternary operation + /// if true then "foo" else "bar" + Ternary { + predicate: Arc, + truthy: Arc, + falsy: Arc, + }, + Function { + /// function arguments + input: Vec, + /// function to apply + function: FunctionExpr, + options: FunctionOptions, + }, + Explode(Arc), + Filter { + input: Arc, + by: Arc, + }, + /// Polars flavored window functions. + Window { + /// Also has the input. i.e. avg("foo") + function: Arc, + partition_by: Vec, + order_by: Option<(Arc, SortOptions)>, + options: WindowType, + }, + Wildcard, + Slice { + input: Arc, + /// length is not yet known so we accept negative offsets + offset: Arc, + length: Arc, + }, + /// Can be used in a select statement to exclude a column from selection + /// TODO: See if we can replace `Vec` with `Arc` + Exclude(Arc, Vec), + /// Set root name as Alias + KeepName(Arc), + Len, + /// Take the nth column in the `DataFrame` + Nth(i64), + RenameAlias { + function: SpecialEq>, + expr: Arc, + }, + #[cfg(feature = "dtype-struct")] + Field(Arc<[PlSmallStr]>), + AnonymousFunction { + /// function arguments + input: Vec, + /// function to apply + function: OpaqueColumnUdf, + /// output dtype of the function + output_type: GetOutput, + options: FunctionOptions, + }, + SubPlan(SpecialEq>, Vec), + /// Expressions in this node should only be expanding + /// e.g. + /// `Expr::Columns` + /// `Expr::Dtypes` + /// `Expr::Wildcard` + /// `Expr::Exclude` + Selector(super::selector::Selector), +} + +pub type OpaqueColumnUdf = LazySerde>>; +pub(crate) fn new_column_udf(func: F) -> OpaqueColumnUdf { + LazySerde::Deserialized(SpecialEq::new(Arc::new(func))) +} + +#[derive(Clone)] +pub enum LazySerde { + Deserialized(T), + Bytes(Bytes), +} + +impl PartialEq for LazySerde { + fn eq(&self, other: &Self) -> bool { + use LazySerde as L; + match (self, other) { + (L::Deserialized(a), L::Deserialized(b)) => a == b, + (L::Bytes(a), L::Bytes(b)) => { + std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len() + }, + _ => false, + } + } +} + +impl Debug for LazySerde { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bytes(_) => write!(f, "lazy-serde"), + Self::Deserialized(_) => write!(f, "lazy-serde"), + } + } +} + +impl OpaqueColumnUdf { + pub fn materialize(self) -> PolarsResult>> { + match self { + Self::Deserialized(t) => Ok(t), + Self::Bytes(_b) => { + feature_gated!("serde";"python", { + crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(_b.as_ref()).map(SpecialEq::new) + }) + }, + } + } +} + +#[allow(clippy::derived_hash_with_manual_eq)] +impl Hash for Expr { + fn hash(&self, state: &mut H) { + let d = std::mem::discriminant(self); + d.hash(state); + match self { + Expr::Column(name) => name.hash(state), + Expr::Columns(names) => names.hash(state), + Expr::DtypeColumn(dtypes) => dtypes.hash(state), + Expr::IndexColumn(indices) => indices.hash(state), + Expr::Literal(lv) => std::mem::discriminant(lv).hash(state), + Expr::Selector(s) => s.hash(state), + Expr::Nth(v) => v.hash(state), + Expr::Filter { input, by } => { + input.hash(state); + by.hash(state); + }, + Expr::BinaryExpr { left, op, right } => { + left.hash(state); + right.hash(state); + std::mem::discriminant(op).hash(state) + }, + Expr::Cast { + expr, + dtype, + options: strict, + } => { + expr.hash(state); + dtype.hash(state); + strict.hash(state) + }, + Expr::Sort { expr, options } => { + expr.hash(state); + options.hash(state); + }, + Expr::Alias(input, name) => { + input.hash(state); + name.hash(state) + }, + Expr::KeepName(input) => input.hash(state), + Expr::Ternary { + predicate, + truthy, + falsy, + } => { + predicate.hash(state); + truthy.hash(state); + falsy.hash(state); + }, + Expr::Function { + input, + function, + options, + } => { + input.hash(state); + std::mem::discriminant(function).hash(state); + options.hash(state); + }, + Expr::Gather { + expr, + idx, + returns_scalar, + } => { + expr.hash(state); + idx.hash(state); + returns_scalar.hash(state); + }, + // already hashed by discriminant + Expr::Wildcard | Expr::Len => {}, + Expr::SortBy { + expr, + by, + sort_options, + } => { + expr.hash(state); + by.hash(state); + sort_options.hash(state); + }, + Expr::Agg(input) => input.hash(state), + Expr::Explode(input) => input.hash(state), + Expr::Window { + function, + partition_by, + order_by, + options, + } => { + function.hash(state); + partition_by.hash(state); + order_by.hash(state); + options.hash(state); + }, + Expr::Slice { + input, + offset, + length, + } => { + input.hash(state); + offset.hash(state); + length.hash(state); + }, + Expr::Exclude(input, excl) => { + input.hash(state); + excl.hash(state); + }, + Expr::RenameAlias { function: _, expr } => expr.hash(state), + Expr::AnonymousFunction { + input, + function: _, + output_type: _, + options, + } => { + input.hash(state); + options.hash(state); + }, + Expr::SubPlan(_, names) => names.hash(state), + #[cfg(feature = "dtype-struct")] + Expr::Field(names) => names.hash(state), + } + } +} + +impl Eq for Expr {} + +impl Default for Expr { + fn default() -> Self { + Expr::Literal(LiteralValue::Scalar(Scalar::default())) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Excluded { + Name(PlSmallStr), + Dtype(DataType), +} + +impl Expr { + /// Get Field result of the expression. The schema is the input data. + pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult { + // this is not called much and the expression depth is typically shallow + let mut arena = Arena::with_capacity(5); + self.to_field_amortized(schema, ctxt, &mut arena) + } + pub(crate) fn to_field_amortized( + &self, + schema: &Schema, + ctxt: Context, + expr_arena: &mut Arena, + ) -> PolarsResult { + let root = to_aexpr(self.clone(), expr_arena)?; + expr_arena + .get(root) + .to_field_and_validate(schema, ctxt, expr_arena) + } + + /// Extract a constant usize from an expression. + pub fn extract_usize(&self) -> PolarsResult { + match self { + Expr::Literal(n) => n.extract_usize(), + Expr::Cast { expr, dtype, .. } => { + // lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal. + if dtype.is_integer() { + expr.extract_usize() + } else { + polars_bail!(InvalidOperation: "expression must be constant literal to extract integer") + } + }, + _ => { + polars_bail!(InvalidOperation: "expression must be constant literal to extract integer") + }, + } + } + + #[inline] + pub fn map_unary(self, function: impl Into) -> Self { + Expr::n_ary(function, vec![self]) + } + #[inline] + pub fn map_binary(self, function: impl Into, rhs: Self) -> Self { + Expr::n_ary(function, vec![self, rhs]) + } + + #[inline] + pub fn map_ternary(self, function: impl Into, arg1: Expr, arg2: Expr) -> Expr { + Expr::n_ary(function, vec![self, arg1, arg2]) + } + + #[inline] + pub fn try_map_n_ary( + self, + function: impl Into, + exprs: impl IntoIterator>, + ) -> PolarsResult { + let exprs = exprs.into_iter(); + let mut input = Vec::with_capacity(exprs.size_hint().0 + 1); + input.push(self); + for e in exprs { + input.push(e?); + } + Ok(Expr::n_ary(function, input)) + } + + #[inline] + pub fn map_n_ary( + self, + function: impl Into, + exprs: impl IntoIterator, + ) -> Expr { + let exprs = exprs.into_iter(); + let mut input = Vec::with_capacity(exprs.size_hint().0 + 1); + input.push(self); + input.extend(exprs); + Expr::n_ary(function, input) + } + + #[inline] + pub fn n_ary(function: impl Into, input: Vec) -> Expr { + let function = function.into(); + let options = function.function_options(); + Expr::Function { + input, + function, + options, + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Operator { + Eq, + EqValidity, + NotEq, + NotEqValidity, + Lt, + LtEq, + Gt, + GtEq, + Plus, + Minus, + Multiply, + Divide, + TrueDivide, + FloorDivide, + Modulus, + And, + Or, + Xor, + LogicalAnd, + LogicalOr, +} + +impl Display for Operator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use Operator::*; + let tkn = match self { + Eq => "==", + EqValidity => "==v", + NotEq => "!=", + NotEqValidity => "!=v", + Lt => "<", + LtEq => "<=", + Gt => ">", + GtEq => ">=", + Plus => "+", + Minus => "-", + Multiply => "*", + Divide => "//", + TrueDivide => "/", + FloorDivide => "floor_div", + Modulus => "%", + And | LogicalAnd => "&", + Or | LogicalOr => "|", + Xor => "^", + }; + write!(f, "{tkn}") + } +} + +impl Operator { + pub fn is_comparison(&self) -> bool { + matches!( + self, + Self::Eq + | Self::NotEq + | Self::Lt + | Self::LtEq + | Self::Gt + | Self::GtEq + | Self::EqValidity + | Self::NotEqValidity + ) + } + + pub fn is_bitwise(&self) -> bool { + matches!(self, Self::And | Self::Or | Self::Xor) + } + + pub fn is_comparison_or_bitwise(&self) -> bool { + self.is_comparison() || self.is_bitwise() + } + + pub fn swap_operands(self) -> Self { + match self { + Operator::Eq => Operator::Eq, + Operator::Gt => Operator::Lt, + Operator::GtEq => Operator::LtEq, + Operator::LtEq => Operator::GtEq, + Operator::Or => Operator::Or, + Operator::LogicalAnd => Operator::LogicalAnd, + Operator::LogicalOr => Operator::LogicalOr, + Operator::Xor => Operator::Xor, + Operator::NotEq => Operator::NotEq, + Operator::EqValidity => Operator::EqValidity, + Operator::NotEqValidity => Operator::NotEqValidity, + Operator::Divide => Operator::Multiply, + Operator::Multiply => Operator::Divide, + Operator::And => Operator::And, + Operator::Plus => Operator::Minus, + Operator::Minus => Operator::Plus, + Operator::Lt => Operator::Gt, + _ => unimplemented!(), + } + } + + pub fn is_arithmetic(&self) -> bool { + !(self.is_comparison_or_bitwise()) + } +} diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs new file mode 100644 index 000000000000..c0cde176f2be --- /dev/null +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -0,0 +1,459 @@ +use std::fmt::Formatter; +use std::ops::Deref; +use std::sync::Arc; + +#[cfg(feature = "python")] +use polars_utils::pl_serialize::deserialize_map_bytes; +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use super::*; + +/// A wrapper trait for any closure `Fn(Vec) -> PolarsResult` +pub trait ColumnsUdf: Send + Sync { + fn as_any(&self) -> &dyn std::any::Any { + unimplemented!("as_any not implemented for this 'opaque' function") + } + + fn call_udf(&self, s: &mut [Column]) -> PolarsResult>; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this 'opaque' function") + } +} + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl Serialize for LazySerde { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + Self::Deserialized(t) => t.serialize(serializer), + Self::Bytes(b) => b.serialize(serializer), + } + } +} + +#[cfg(feature = "serde")] +impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'a>, + { + let buf = bytes::Bytes::deserialize(deserializer)?; + Ok(Self::Bytes(buf)) + } +} + +#[cfg(feature = "serde")] +// impl Deserialize for crate::dsl::expr::LazySerde { +impl<'a> Deserialize<'a> for SpecialEq> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + #[cfg(feature = "python")] + { + deserialize_map_bytes(deserializer, |buf| { + if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) { + let udf = crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(&buf) + .map_err(|e| D::Error::custom(format!("{e}")))?; + Ok(SpecialEq::new(udf)) + } else { + Err(D::Error::custom( + "deserialization not supported for this 'opaque' function", + )) + } + })? + } + #[cfg(not(feature = "python"))] + { + _ = deserializer; + + Err(D::Error::custom( + "deserialization not supported for this 'opaque' function", + )) + } + } +} + +impl ColumnsUdf for F +where + F: Fn(&mut [Column]) -> PolarsResult> + Send + Sync, +{ + fn call_udf(&self, s: &mut [Column]) -> PolarsResult> { + self(s) + } +} + +impl Debug for dyn ColumnsUdf { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ColumnUdf") + } +} + +/// A wrapper trait for any binary closure `Fn(Column, Column) -> PolarsResult` +pub trait ColumnBinaryUdf: Send + Sync { + fn call_udf(&self, a: Column, b: Column) -> PolarsResult; +} + +impl ColumnBinaryUdf for F +where + F: Fn(Column, Column) -> PolarsResult + Send + Sync, +{ + fn call_udf(&self, a: Column, b: Column) -> PolarsResult { + self(a, b) + } +} + +impl Debug for dyn ColumnBinaryUdf { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ColumnBinaryUdf") + } +} + +impl Default for SpecialEq> { + fn default() -> Self { + panic!("implementation error"); + } +} + +impl Default for SpecialEq> { + fn default() -> Self { + let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None; + SpecialEq::new(Arc::new(output_field)) + } +} + +pub trait RenameAliasFn: Send + Sync { + fn call(&self, name: &PlSmallStr) -> PolarsResult; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this renaming function") + } +} + +impl RenameAliasFn for F +where + F: Fn(&PlSmallStr) -> PolarsResult + Send + Sync, +{ + fn call(&self, name: &PlSmallStr) -> PolarsResult { + self(name) + } +} + +impl Debug for dyn RenameAliasFn { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "RenameAliasFn") + } +} + +#[derive(Clone)] +/// Wrapper type that has special equality properties +/// depending on the inner type specialization +pub struct SpecialEq(T); + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + self.0.serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for SpecialEq { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + let t = Series::deserialize(deserializer)?; + Ok(SpecialEq(t)) + } +} + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + self.0.serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + let t = T::deserialize(deserializer)?; + Ok(SpecialEq(Arc::new(t))) + } +} + +impl SpecialEq { + pub fn new(val: T) -> Self { + SpecialEq(val) + } +} + +impl PartialEq for SpecialEq> { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Eq for SpecialEq> {} + +impl PartialEq for SpecialEq { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Debug for SpecialEq { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "no_eq") + } +} + +impl Deref for SpecialEq { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub trait BinaryUdfOutputField: Send + Sync { + fn get_field( + &self, + input_schema: &Schema, + cntxt: Context, + field_a: &Field, + field_b: &Field, + ) -> Option; +} + +impl BinaryUdfOutputField for F +where + F: Fn(&Schema, Context, &Field, &Field) -> Option + Send + Sync, +{ + fn get_field( + &self, + input_schema: &Schema, + cntxt: Context, + field_a: &Field, + field_b: &Field, + ) -> Option { + self(input_schema, cntxt, field_a, field_b) + } +} + +pub trait FunctionOutputField: Send + Sync { + fn get_field( + &self, + input_schema: &Schema, + cntxt: Context, + fields: &[Field], + ) -> PolarsResult; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this output field") + } +} + +pub type GetOutput = SpecialEq>; + +impl Default for GetOutput { + fn default() -> Self { + SpecialEq::new(Arc::new( + |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()), + )) + } +} + +impl GetOutput { + pub fn same_type() -> Self { + Default::default() + } + + pub fn first() -> Self { + SpecialEq::new(Arc::new( + |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()), + )) + } + + pub fn from_type(dt: DataType) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + Ok(Field::new(flds[0].name().clone(), dt.clone())) + })) + } + + pub fn map_field PolarsResult + Send + Sync>(f: F) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + f(&flds[0]) + })) + } + + pub fn map_fields PolarsResult + Send + Sync>( + f: F, + ) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + f(flds) + })) + } + + pub fn map_dtype PolarsResult + Send + Sync>( + f: F, + ) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + let mut fld = flds[0].clone(); + let new_type = f(fld.dtype())?; + fld.coerce(new_type); + Ok(fld) + })) + } + + pub fn float_type() -> Self { + Self::map_dtype(|dt| { + Ok(match dt { + DataType::Float32 => DataType::Float32, + _ => DataType::Float64, + }) + }) + } + + pub fn super_type() -> Self { + Self::map_dtypes(|dtypes| { + let mut st = dtypes[0].clone(); + for dt in &dtypes[1..] { + st = try_get_supertype(&st, dt)?; + } + Ok(st) + }) + } + + pub fn map_dtypes(f: F) -> Self + where + F: 'static + Fn(&[&DataType]) -> PolarsResult + Send + Sync, + { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + let mut fld = flds[0].clone(); + let dtypes = flds.iter().map(|fld| fld.dtype()).collect::>(); + let new_type = f(&dtypes)?; + fld.coerce(new_type); + Ok(fld) + })) + } +} + +impl FunctionOutputField for F +where + F: Fn(&Schema, Context, &[Field]) -> PolarsResult + Send + Sync, +{ + fn get_field( + &self, + input_schema: &Schema, + cntxt: Context, + fields: &[Field], + ) -> PolarsResult { + self(input_schema, cntxt, fields) + } +} + +#[cfg(feature = "serde")] +impl Serialize for GetOutput { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for GetOutput { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + #[cfg(feature = "python")] + { + deserialize_map_bytes(deserializer, |buf| { + if buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) { + let get_output = self::python_dsl::PythonGetOutput::try_deserialize(&buf) + .map_err(|e| D::Error::custom(format!("{e}")))?; + Ok(SpecialEq::new(get_output)) + } else { + Err(D::Error::custom( + "deserialization not supported for this output field", + )) + } + })? + } + #[cfg(not(feature = "python"))] + { + _ = deserializer; + + Err(D::Error::custom( + "deserialization not supported for this output field", + )) + } + } +} + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for SpecialEq> { + fn deserialize(_deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + Err(D::Error::custom( + "deserialization not supported for this renaming function", + )) + } +} diff --git a/crates/polars-plan/src/dsl/file_scan.rs b/crates/polars-plan/src/dsl/file_scan.rs new file mode 100644 index 000000000000..5a290d83060a --- /dev/null +++ b/crates/polars-plan/src/dsl/file_scan.rs @@ -0,0 +1,240 @@ +use std::hash::Hash; + +use polars_io::cloud::CloudOptions; +#[cfg(feature = "csv")] +use polars_io::csv::read::CsvReadOptions; +#[cfg(feature = "ipc")] +use polars_io::ipc::IpcScanOptions; +#[cfg(feature = "parquet")] +use polars_io::parquet::metadata::FileMetadataRef; +#[cfg(feature = "parquet")] +use polars_io::parquet::read::ParquetOptions; +use polars_io::{HiveOptions, RowIndex}; +use polars_utils::slice_enum::Slice; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +use super::*; + +bitflags::bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct ScanFlags : u32 { + const SPECIALIZED_PREDICATE_FILTER = 0x01; + } +} + +#[derive(Clone, Debug, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +// TODO: Arc<> some of the options and the cloud options. +pub enum FileScan { + #[cfg(feature = "csv")] + Csv { options: CsvReadOptions }, + + #[cfg(feature = "json")] + NDJson { options: NDJsonReadOptions }, + + #[cfg(feature = "parquet")] + Parquet { + options: ParquetOptions, + #[cfg_attr(feature = "serde", serde(skip))] + metadata: Option, + }, + + #[cfg(feature = "ipc")] + Ipc { + options: IpcScanOptions, + #[cfg_attr(feature = "serde", serde(skip))] + metadata: Option>, + }, + + #[cfg_attr(feature = "serde", serde(skip))] + Anonymous { + options: Arc, + function: Arc, + }, +} + +impl FileScan { + pub fn flags(&self) -> ScanFlags { + match self { + #[cfg(feature = "csv")] + Self::Csv { .. } => ScanFlags::empty(), + #[cfg(feature = "ipc")] + Self::Ipc { .. } => ScanFlags::empty(), + #[cfg(feature = "parquet")] + Self::Parquet { .. } => ScanFlags::SPECIALIZED_PREDICATE_FILTER, + #[cfg(feature = "json")] + Self::NDJson { .. } => ScanFlags::empty(), + #[allow(unreachable_patterns)] + _ => ScanFlags::empty(), + } + } + + pub(crate) fn sort_projection(&self, _has_row_index: bool) -> bool { + match self { + #[cfg(feature = "csv")] + Self::Csv { .. } => true, + #[cfg(feature = "ipc")] + Self::Ipc { .. } => _has_row_index, + #[cfg(feature = "parquet")] + Self::Parquet { .. } => false, + #[allow(unreachable_patterns)] + _ => false, + } + } + + pub fn streamable(&self) -> bool { + match self { + #[cfg(feature = "csv")] + Self::Csv { .. } => true, + #[cfg(feature = "ipc")] + Self::Ipc { .. } => false, + #[cfg(feature = "parquet")] + Self::Parquet { .. } => true, + #[cfg(feature = "json")] + Self::NDJson { .. } => false, + #[allow(unreachable_patterns)] + _ => false, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum MissingColumnsPolicy { + #[default] + Raise, + /// Inserts full-NULL columns for the missing ones. + Insert, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum CastColumnsPolicy { + /// Raise an error if the datatypes do not match + #[default] + ErrorOnMismatch, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ExtraColumnsPolicy { + /// Error if there are extra columns outside the target schema. + #[default] + Raise, + Ignore, +} + +/// Scan arguments shared across different scan types. +#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct UnifiedScanArgs { + /// User-provided schema of the file. Will be inferred during IR conversion + /// if None. + pub schema: Option, + pub cloud_options: Option, + pub hive_options: HiveOptions, + + pub rechunk: bool, + pub cache: bool, + pub glob: bool, + + pub projection: Option>, + pub row_index: Option, + /// Slice applied before predicates + pub pre_slice: Option, + + pub cast_columns_policy: CastColumnsPolicy, + pub missing_columns_policy: MissingColumnsPolicy, + pub include_file_paths: Option, +} + +/// Manual impls of Eq/Hash, as some fields are `Arc` where T does not have Eq/Hash. For these +/// fields we compare the pointer addresses instead. +mod _file_scan_eq_hash { + use std::hash::{Hash, Hasher}; + use std::sync::Arc; + + use super::FileScan; + + impl PartialEq for FileScan { + fn eq(&self, other: &Self) -> bool { + FileScanEqHashWrap::from(self) == FileScanEqHashWrap::from(other) + } + } + + impl Eq for FileScan {} + + impl Hash for FileScan { + fn hash(&self, state: &mut H) { + FileScanEqHashWrap::from(self).hash(state) + } + } + + /// # Hash / Eq safety + /// * All usizes originate from `Arc<>`s, and the lifetime of this enum is bound to that of the + /// input ref. + #[derive(PartialEq, Hash)] + pub enum FileScanEqHashWrap<'a> { + #[cfg(feature = "csv")] + Csv { + options: &'a polars_io::csv::read::CsvReadOptions, + }, + + #[cfg(feature = "json")] + NDJson { + options: &'a crate::prelude::NDJsonReadOptions, + }, + + #[cfg(feature = "parquet")] + Parquet { + options: &'a polars_io::prelude::ParquetOptions, + metadata: Option, + }, + + #[cfg(feature = "ipc")] + Ipc { + options: &'a polars_io::prelude::IpcScanOptions, + metadata: Option, + }, + + Anonymous { + options: &'a crate::dsl::AnonymousScanOptions, + function: usize, + }, + } + + impl<'a> From<&'a FileScan> for FileScanEqHashWrap<'a> { + fn from(value: &'a FileScan) -> Self { + match value { + #[cfg(feature = "csv")] + FileScan::Csv { options } => FileScanEqHashWrap::Csv { options }, + + #[cfg(feature = "json")] + FileScan::NDJson { options } => FileScanEqHashWrap::NDJson { options }, + + #[cfg(feature = "parquet")] + FileScan::Parquet { options, metadata } => FileScanEqHashWrap::Parquet { + options, + metadata: metadata.as_ref().map(arc_as_ptr), + }, + + #[cfg(feature = "ipc")] + FileScan::Ipc { options, metadata } => FileScanEqHashWrap::Ipc { + options, + metadata: metadata.as_ref().map(arc_as_ptr), + }, + + FileScan::Anonymous { options, function } => FileScanEqHashWrap::Anonymous { + options, + function: arc_as_ptr(function), + }, + } + } + } + + fn arc_as_ptr(arc: &Arc) -> usize { + Arc::as_ptr(arc) as *const () as usize + } +} diff --git a/crates/polars-plan/src/dsl/format.rs b/crates/polars-plan/src/dsl/format.rs new file mode 100644 index 000000000000..6d7bd9cbf701 --- /dev/null +++ b/crates/polars-plan/src/dsl/format.rs @@ -0,0 +1,170 @@ +use std::fmt; + +use crate::prelude::*; + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl fmt::Debug for Expr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Expr::*; + match self { + Window { + function, + partition_by, + order_by, + options, + } => match options { + #[cfg(feature = "dynamic_group_by")] + WindowType::Rolling(options) => { + write!( + f, + "{:?}.rolling(by='{}', offset={}, period={})", + function, options.index_column, options.offset, options.period + ) + }, + _ => { + if let Some((order_by, _)) = order_by { + write!( + f, + "{function:?}.over(partition_by: {partition_by:?}, order_by: {order_by:?})" + ) + } else { + write!(f, "{function:?}.over({partition_by:?})") + } + }, + }, + Nth(i) => write!(f, "nth({i})"), + Len => write!(f, "len()"), + Explode(expr) => write!(f, "{expr:?}.explode()"), + Alias(expr, name) => write!(f, "{expr:?}.alias(\"{name}\")"), + Column(name) => write!(f, "col(\"{name}\")"), + Literal(v) => write!(f, "{v:?}"), + BinaryExpr { left, op, right } => write!(f, "[({left:?}) {op:?} ({right:?})]"), + Sort { expr, options } => { + if options.descending { + write!(f, "{expr:?}.sort(desc)") + } else { + write!(f, "{expr:?}.sort(asc)") + } + }, + SortBy { + expr, + by, + sort_options, + } => { + write!( + f, + "{expr:?}.sort_by(by={by:?}, sort_option={sort_options:?})", + ) + }, + Filter { input, by } => { + write!(f, "{input:?}.filter({by:?})") + }, + Gather { + expr, + idx, + returns_scalar, + } => { + if *returns_scalar { + write!(f, "{expr:?}.get({idx:?})") + } else { + write!(f, "{expr:?}.gather({idx:?})") + } + }, + SubPlan(lf, _) => { + write!(f, ".subplan({lf:?})") + }, + Agg(agg) => { + use AggExpr::*; + match agg { + Min { + input, + propagate_nans, + } => { + if *propagate_nans { + write!(f, "{input:?}.nan_min()") + } else { + write!(f, "{input:?}.min()") + } + }, + Max { + input, + propagate_nans, + } => { + if *propagate_nans { + write!(f, "{input:?}.nan_max()") + } else { + write!(f, "{input:?}.max()") + } + }, + Median(expr) => write!(f, "{expr:?}.median()"), + Mean(expr) => write!(f, "{expr:?}.mean()"), + First(expr) => write!(f, "{expr:?}.first()"), + Last(expr) => write!(f, "{expr:?}.last()"), + Implode(expr) => write!(f, "{expr:?}.list()"), + NUnique(expr) => write!(f, "{expr:?}.n_unique()"), + Sum(expr) => write!(f, "{expr:?}.sum()"), + AggGroups(expr) => write!(f, "{expr:?}.groups()"), + Count(expr, _) => write!(f, "{expr:?}.count()"), + Var(expr, _) => write!(f, "{expr:?}.var()"), + Std(expr, _) => write!(f, "{expr:?}.std()"), + Quantile { expr, .. } => write!(f, "{expr:?}.quantile()"), + } + }, + Cast { + expr, + dtype, + options, + } => { + if options.strict() { + write!(f, "{expr:?}.strict_cast({dtype:?})") + } else { + write!(f, "{expr:?}.cast({dtype:?})") + } + }, + Ternary { + predicate, + truthy, + falsy, + } => write!( + f, + ".when({predicate:?}).then({truthy:?}).otherwise({falsy:?})", + ), + Function { + input, function, .. + } => { + if input.len() >= 2 { + write!(f, "{:?}.{function}({:?})", input[0], &input[1..]) + } else { + write!(f, "{:?}.{function}()", input[0]) + } + }, + AnonymousFunction { input, options, .. } => { + if input.len() >= 2 { + write!(f, "{:?}.{}({:?})", input[0], options.fmt_str, &input[1..]) + } else { + write!(f, "{:?}.{}()", input[0], options.fmt_str) + } + }, + Slice { + input, + offset, + length, + } => write!(f, "{input:?}.slice(offset={offset:?}, length={length:?})",), + Wildcard => write!(f, "*"), + Exclude(column, names) => write!(f, "{column:?}.exclude({names:?})"), + KeepName(e) => write!(f, "{e:?}.name.keep()"), + RenameAlias { expr, .. } => write!(f, ".rename_alias({expr:?})"), + Columns(names) => write!(f, "cols({names:?})"), + DtypeColumn(dt) => write!(f, "dtype_columns({dt:?})"), + IndexColumn(idxs) => write!(f, "index_columns({idxs:?})"), + Selector(_) => write!(f, "selector"), + #[cfg(feature = "dtype-struct")] + Field(names) => write!(f, ".field({names:?})"), + } + } +} diff --git a/crates/polars-plan/src/dsl/from.rs b/crates/polars-plan/src/dsl/from.rs new file mode 100644 index 000000000000..dcc53f51e1f9 --- /dev/null +++ b/crates/polars-plan/src/dsl/from.rs @@ -0,0 +1,39 @@ +use super::*; + +impl From for Expr { + fn from(agg: AggExpr) -> Self { + Expr::Agg(agg) + } +} + +impl From<&str> for Expr { + fn from(s: &str) -> Self { + col(PlSmallStr::from_str(s)) + } +} + +macro_rules! from_literals { + ($type:ty) => { + impl From<$type> for Expr { + fn from(val: $type) -> Self { + lit(val) + } + } + }; +} + +from_literals!(f32); +from_literals!(f64); +#[cfg(feature = "dtype-i8")] +from_literals!(i8); +#[cfg(feature = "dtype-i16")] +from_literals!(i16); +from_literals!(i32); +from_literals!(i64); +#[cfg(feature = "dtype-u8")] +from_literals!(u8); +#[cfg(feature = "dtype-u16")] +from_literals!(u16); +from_literals!(u32); +from_literals!(u64); +from_literals!(bool); diff --git a/crates/polars-plan/src/dsl/function_expr/abs.rs b/crates/polars-plan/src/dsl/function_expr/abs.rs new file mode 100644 index 000000000000..5464f06daada --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/abs.rs @@ -0,0 +1,5 @@ +use super::*; + +pub(super) fn abs(s: &Column) -> PolarsResult { + polars_ops::prelude::abs(s.as_materialized_series()).map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/arg_where.rs b/crates/polars-plan/src/dsl/function_expr/arg_where.rs new file mode 100644 index 000000000000..ab0afba55960 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/arg_where.rs @@ -0,0 +1,42 @@ +use polars_core::utils::arrow::bitmap::utils::SlicesIterator; + +use super::*; + +pub(super) fn arg_where(s: &mut [Column]) -> PolarsResult> { + let predicate = s[0].bool()?; + + if predicate.is_empty() { + Ok(Some(Column::full_null( + predicate.name().clone(), + 0, + &IDX_DTYPE, + ))) + } else { + let capacity = predicate.sum().unwrap(); + let mut out = Vec::with_capacity(capacity as usize); + let mut total_offset = 0; + + predicate.downcast_iter().for_each(|arr| { + let values = match arr.validity() { + Some(validity) if validity.unset_bits() > 0 => validity & arr.values(), + _ => arr.values().clone(), + }; + + for (offset, len) in SlicesIterator::new(&values) { + // law of small numbers optimization + if len == 1 { + out.push((total_offset + offset) as IdxSize) + } else { + let offset = (offset + total_offset) as IdxSize; + let len = len as IdxSize; + let iter = offset..offset + len; + out.extend(iter) + } + } + + total_offset += arr.len(); + }); + let ca = IdxCa::with_chunk(predicate.name().clone(), IdxArr::from_vec(out)); + Ok(Some(ca.into_column())) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs new file mode 100644 index 000000000000..1ede2c40c57a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -0,0 +1,375 @@ +use polars_ops::chunked_array::array::*; + +use super::*; +use crate::{map, map_as_slice}; + +#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ArrayFunction { + Length, + Min, + Max, + Sum, + ToList, + Unique(bool), + NUnique, + Std(u8), + Var(u8), + Median, + #[cfg(feature = "array_any_all")] + Any, + #[cfg(feature = "array_any_all")] + All, + Sort(SortOptions), + Reverse, + ArgMin, + ArgMax, + Get(bool), + Join(bool), + #[cfg(feature = "is_in")] + Contains, + #[cfg(feature = "array_count")] + CountMatches, + Shift, + Explode, + Concat, +} + +impl ArrayFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use ArrayFunction::*; + match self { + Concat => Ok(Field::new( + mapper + .args() + .first() + .map_or(PlSmallStr::EMPTY, |x| x.name.clone()), + concat_arr_output_dtype( + &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)), + )?, + )), + Length => mapper.with_dtype(IDX_DTYPE), + Min | Max => mapper.map_to_list_and_array_inner_dtype(), + Sum => mapper.nested_sum_type(), + ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype), + Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype), + NUnique => mapper.with_dtype(IDX_DTYPE), + Std(_) => mapper.map_to_float_dtype(), + Var(_) => mapper.map_to_float_dtype(), + Median => mapper.map_to_float_dtype(), + #[cfg(feature = "array_any_all")] + Any | All => mapper.with_dtype(DataType::Boolean), + Sort(_) => mapper.with_same_dtype(), + Reverse => mapper.with_same_dtype(), + ArgMin | ArgMax => mapper.with_dtype(IDX_DTYPE), + Get(_) => mapper.map_to_list_and_array_inner_dtype(), + Join(_) => mapper.with_dtype(DataType::String), + #[cfg(feature = "is_in")] + Contains => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "array_count")] + CountMatches => mapper.with_dtype(IDX_DTYPE), + Shift => mapper.with_same_dtype(), + Explode => mapper.try_map_to_array_inner_dtype(), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use ArrayFunction as A; + match self { + #[cfg(feature = "array_any_all")] + A::Any | A::All => FunctionOptions::elementwise(), + #[cfg(feature = "is_in")] + A::Contains => FunctionOptions::elementwise(), + #[cfg(feature = "array_count")] + A::CountMatches => FunctionOptions::elementwise(), + A::Length + | A::Min + | A::Max + | A::Sum + | A::ToList + | A::Unique(_) + | A::NUnique + | A::Std(_) + | A::Var(_) + | A::Median + | A::Sort(_) + | A::Reverse + | A::ArgMin + | A::ArgMax + | A::Concat + | A::Get(_) + | A::Join(_) + | A::Shift => FunctionOptions::elementwise(), + A::Explode => FunctionOptions::row_separable(), + } + } +} + +fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult { + if let DataType::Array(inner, _) = datatype { + Ok(DataType::List(inner.clone())) + } else { + polars_bail!(ComputeError: "expected array dtype") + } +} + +impl Display for ArrayFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use ArrayFunction::*; + let name = match self { + Concat => "concat", + Length => "length", + Min => "min", + Max => "max", + Sum => "sum", + ToList => "to_list", + Unique(_) => "unique", + NUnique => "n_unique", + Std(_) => "std", + Var(_) => "var", + Median => "median", + #[cfg(feature = "array_any_all")] + Any => "any", + #[cfg(feature = "array_any_all")] + All => "all", + Sort(_) => "sort", + Reverse => "reverse", + ArgMin => "arg_min", + ArgMax => "arg_max", + Get(_) => "get", + Join(_) => "join", + #[cfg(feature = "is_in")] + Contains => "contains", + #[cfg(feature = "array_count")] + CountMatches => "count_matches", + Shift => "shift", + Explode => "explode", + }; + write!(f, "arr.{name}") + } +} + +impl From for SpecialEq> { + fn from(func: ArrayFunction) -> Self { + use ArrayFunction::*; + match func { + Concat => map_as_slice!(concat_arr), + Length => map!(length), + Min => map!(min), + Max => map!(max), + Sum => map!(sum), + ToList => map!(to_list), + Unique(stable) => map!(unique, stable), + NUnique => map!(n_unique), + Std(ddof) => map!(std, ddof), + Var(ddof) => map!(var, ddof), + Median => map!(median), + #[cfg(feature = "array_any_all")] + Any => map!(any), + #[cfg(feature = "array_any_all")] + All => map!(all), + Sort(options) => map!(sort, options), + Reverse => map!(reverse), + ArgMin => map!(arg_min), + ArgMax => map!(arg_max), + Get(null_on_oob) => map_as_slice!(get, null_on_oob), + Join(ignore_nulls) => map_as_slice!(join, ignore_nulls), + #[cfg(feature = "is_in")] + Contains => map_as_slice!(contains), + #[cfg(feature = "array_count")] + CountMatches => map_as_slice!(count_matches), + Shift => map_as_slice!(shift), + Explode => map_as_slice!(explode), + } + } +} + +pub(super) fn length(s: &Column) -> PolarsResult { + let array = s.array()?; + let width = array.width(); + let width = IdxSize::try_from(width) + .map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?; + + let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len()); + if let Some(validity) = array.rechunk_validity() { + let mut series = c.into_materialized_series().clone(); + + // SAFETY: We keep datatypes intact and call compute_len afterwards. + let chunks = unsafe { series.chunks_mut() }; + assert_eq!(chunks.len(), 1); + + chunks[0] = chunks[0].with_validity(Some(validity)); + + series.compute_len(); + c = series.into_column(); + } + + Ok(c) +} + +pub(super) fn max(s: &Column) -> PolarsResult { + Ok(s.array()?.array_max().into()) +} + +pub(super) fn min(s: &Column) -> PolarsResult { + Ok(s.array()?.array_min().into()) +} + +pub(super) fn sum(s: &Column) -> PolarsResult { + s.array()?.array_sum().map(Column::from) +} + +pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult { + s.array()?.array_std(ddof).map(Column::from) +} + +pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult { + s.array()?.array_var(ddof).map(Column::from) +} +pub(super) fn median(s: &Column) -> PolarsResult { + s.array()?.array_median().map(Column::from) +} + +pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult { + let ca = s.array()?; + let out = if stable { + ca.array_unique_stable() + } else { + ca.array_unique() + }; + out.map(|ca| ca.into_column()) +} + +pub(super) fn n_unique(s: &Column) -> PolarsResult { + Ok(s.array()?.array_n_unique()?.into_column()) +} + +pub(super) fn to_list(s: &Column) -> PolarsResult { + let list_dtype = map_array_dtype_to_list_dtype(s.dtype())?; + s.cast(&list_dtype) +} + +#[cfg(feature = "array_any_all")] +pub(super) fn any(s: &Column) -> PolarsResult { + s.array()?.array_any().map(Column::from) +} + +#[cfg(feature = "array_any_all")] +pub(super) fn all(s: &Column) -> PolarsResult { + s.array()?.array_all().map(Column::from) +} + +pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult { + Ok(s.array()?.array_sort(options)?.into_column()) +} + +pub(super) fn reverse(s: &Column) -> PolarsResult { + Ok(s.array()?.array_reverse().into_column()) +} + +pub(super) fn arg_min(s: &Column) -> PolarsResult { + Ok(s.array()?.array_arg_min().into_column()) +} + +pub(super) fn arg_max(s: &Column) -> PolarsResult { + Ok(s.array()?.array_arg_max().into_column()) +} + +pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult { + let ca = s[0].array()?; + let index = s[1].cast(&DataType::Int64)?; + let index = index.i64().unwrap(); + ca.array_get(index, null_on_oob).map(Column::from) +} + +pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult { + let ca = s[0].array()?; + let separator = s[1].str()?; + ca.array_join(separator, ignore_nulls).map(Column::from) +} + +#[cfg(feature = "is_in")] +pub(super) fn contains(s: &[Column]) -> PolarsResult { + let array = &s[0]; + let item = &s[1]; + polars_ensure!(matches!(array.dtype(), DataType::Array(_, _)), + SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", array.dtype(), + ); + Ok(is_in( + item.as_materialized_series(), + array.as_materialized_series(), + true, + )? + .with_name(array.name().clone()) + .into_column()) +} + +#[cfg(feature = "array_count")] +pub(super) fn count_matches(args: &[Column]) -> PolarsResult { + let s = &args[0]; + let element = &args[1]; + polars_ensure!( + element.len() == 1, + ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}", + element.len() + ); + let ca = s.array()?; + ca.array_count_matches(element.get(0).unwrap()) + .map(Column::from) +} + +pub(super) fn shift(s: &[Column]) -> PolarsResult { + let ca = s[0].array()?; + let n = &s[1]; + + ca.array_shift(n.as_materialized_series()).map(Column::from) +} + +fn explode(c: &[Column]) -> PolarsResult { + c[0].explode() +} + +fn concat_arr(args: &[Column]) -> PolarsResult { + let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?; + + polars_ops::series::concat_arr::concat_arr(args, &dtype) +} + +/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input +/// dtypes are compatible. +fn concat_arr_output_dtype( + inputs: &mut dyn ExactSizeIterator, +) -> PolarsResult { + #[allow(clippy::len_zero)] + if inputs.len() == 0 { + // should not be reachable - we did not set ALLOW_EMPTY_INPUTS + panic!(); + } + + let mut inputs = inputs.map(|(name, dtype)| { + let (inner_dtype, width) = match dtype { + DataType::Array(inner, width) => (inner.as_ref(), *width), + dt => (dt, 1), + }; + (name, dtype, inner_dtype, width) + }); + let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap(); + + for (col_name, dtype, inner_dtype, width) in inputs { + out_width += width; + + if inner_dtype != first_inner_dtype { + polars_bail!( + SchemaMismatch: + "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \ + input column (name: {}, dtype: {}), got {} instead for column {}", + first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name, + ) + } + } + + Ok(DataType::Array( + Box::new(first_inner_dtype.clone()), + out_width, + )) +} diff --git a/crates/polars-plan/src/dsl/function_expr/binary.rs b/crates/polars-plan/src/dsl/function_expr/binary.rs new file mode 100644 index 000000000000..af17e58b6239 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/binary.rs @@ -0,0 +1,183 @@ +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; +use crate::{map, map_as_slice}; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum BinaryFunction { + Contains, + StartsWith, + EndsWith, + #[cfg(feature = "binary_encoding")] + HexDecode(bool), + #[cfg(feature = "binary_encoding")] + HexEncode, + #[cfg(feature = "binary_encoding")] + Base64Decode(bool), + #[cfg(feature = "binary_encoding")] + Base64Encode, + Size, + #[cfg(feature = "binary_encoding")] + FromBuffer(DataType, bool), +} + +impl BinaryFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use BinaryFunction::*; + match self { + Contains => mapper.with_dtype(DataType::Boolean), + EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "binary_encoding")] + HexDecode(_) | Base64Decode(_) => mapper.with_same_dtype(), + #[cfg(feature = "binary_encoding")] + HexEncode | Base64Encode => mapper.with_dtype(DataType::String), + Size => mapper.with_dtype(DataType::UInt32), + #[cfg(feature = "binary_encoding")] + FromBuffer(dtype, _) => mapper.with_dtype(dtype.clone()), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use BinaryFunction as B; + match self { + B::Contains | B::StartsWith | B::EndsWith => { + FunctionOptions::elementwise().with_supertyping(Default::default()) + }, + B::Size => FunctionOptions::elementwise(), + #[cfg(feature = "binary_encoding")] + B::HexDecode(_) + | B::HexEncode + | B::Base64Decode(_) + | B::Base64Encode + | B::FromBuffer(_, _) => FunctionOptions::elementwise(), + } + } +} + +impl Display for BinaryFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use BinaryFunction::*; + let s = match self { + Contains => "contains", + StartsWith => "starts_with", + EndsWith => "ends_with", + #[cfg(feature = "binary_encoding")] + HexDecode(_) => "hex_decode", + #[cfg(feature = "binary_encoding")] + HexEncode => "hex_encode", + #[cfg(feature = "binary_encoding")] + Base64Decode(_) => "base64_decode", + #[cfg(feature = "binary_encoding")] + Base64Encode => "base64_encode", + Size => "size_bytes", + #[cfg(feature = "binary_encoding")] + FromBuffer(_, _) => "from_buffer", + }; + write!(f, "bin.{s}") + } +} + +impl From for SpecialEq> { + fn from(func: BinaryFunction) -> Self { + use BinaryFunction::*; + match func { + Contains => { + map_as_slice!(contains) + }, + EndsWith => { + map_as_slice!(ends_with) + }, + StartsWith => { + map_as_slice!(starts_with) + }, + #[cfg(feature = "binary_encoding")] + HexDecode(strict) => map!(hex_decode, strict), + #[cfg(feature = "binary_encoding")] + HexEncode => map!(hex_encode), + #[cfg(feature = "binary_encoding")] + Base64Decode(strict) => map!(base64_decode, strict), + #[cfg(feature = "binary_encoding")] + Base64Encode => map!(base64_encode), + Size => map!(size_bytes), + #[cfg(feature = "binary_encoding")] + FromBuffer(dtype, is_little_endian) => map!(from_buffer, &dtype, is_little_endian), + } + } +} + +pub(super) fn contains(s: &[Column]) -> PolarsResult { + let ca = s[0].binary()?; + let lit = s[1].binary()?; + Ok(ca + .contains_chunked(lit)? + .with_name(ca.name().clone()) + .into_column()) +} + +pub(super) fn ends_with(s: &[Column]) -> PolarsResult { + let ca = s[0].binary()?; + let suffix = s[1].binary()?; + + Ok(ca + .ends_with_chunked(suffix)? + .with_name(ca.name().clone()) + .into_column()) +} + +pub(super) fn starts_with(s: &[Column]) -> PolarsResult { + let ca = s[0].binary()?; + let prefix = s[1].binary()?; + + Ok(ca + .starts_with_chunked(prefix)? + .with_name(ca.name().clone()) + .into_column()) +} + +pub(super) fn size_bytes(s: &Column) -> PolarsResult { + let ca = s.binary()?; + Ok(ca.size_bytes().into_column()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult { + let ca = s.binary()?; + ca.hex_decode(strict).map(|ok| ok.into_column()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn hex_encode(s: &Column) -> PolarsResult { + let ca = s.binary()?; + Ok(ca.hex_encode().into()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult { + let ca = s.binary()?; + ca.base64_decode(strict).map(|ok| ok.into_column()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn base64_encode(s: &Column) -> PolarsResult { + let ca = s.binary()?; + Ok(ca.base64_encode().into()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn from_buffer( + s: &Column, + dtype: &DataType, + is_little_endian: bool, +) -> PolarsResult { + let ca = s.binary()?; + ca.from_buffer(dtype, is_little_endian) + .map(|val| val.into()) +} + +impl From for FunctionExpr { + fn from(b: BinaryFunction) -> Self { + FunctionExpr::BinaryExpr(b) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/bitwise.rs b/crates/polars-plan/src/dsl/function_expr/bitwise.rs new file mode 100644 index 000000000000..2a8bb537bd25 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/bitwise.rs @@ -0,0 +1,146 @@ +use std::fmt; +use std::sync::Arc; + +use polars_core::prelude::*; +use strum_macros::IntoStaticStr; + +use super::{ColumnsUdf, SpecialEq}; +use crate::dsl::{FieldsMapper, FunctionOptions}; +use crate::map; + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)] +#[strum(serialize_all = "snake_case")] +pub enum BitwiseFunction { + CountOnes, + CountZeros, + + LeadingOnes, + LeadingZeros, + + TrailingOnes, + TrailingZeros, + + // Bitwise Aggregations + And, + Or, + Xor, +} + +impl fmt::Display for BitwiseFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + use BitwiseFunction as B; + + let s = match self { + B::CountOnes => "count_ones", + B::CountZeros => "count_zeros", + B::LeadingOnes => "leading_ones", + B::LeadingZeros => "leading_zeros", + B::TrailingOnes => "trailing_ones", + B::TrailingZeros => "trailing_zeros", + + B::And => "and", + B::Or => "or", + B::Xor => "xor", + }; + + f.write_str(s) + } +} + +impl From for SpecialEq> { + fn from(func: BitwiseFunction) -> Self { + use BitwiseFunction as B; + + match func { + B::CountOnes => map!(count_ones), + B::CountZeros => map!(count_zeros), + B::LeadingOnes => map!(leading_ones), + B::LeadingZeros => map!(leading_zeros), + B::TrailingOnes => map!(trailing_ones), + B::TrailingZeros => map!(trailing_zeros), + + B::And => map!(reduce_and), + B::Or => map!(reduce_or), + B::Xor => map!(reduce_xor), + } + } +} + +impl BitwiseFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + mapper.try_map_dtype(|dtype| { + let is_valid = match dtype { + DataType::Boolean => true, + dt if dt.is_integer() => true, + dt if dt.is_float() => true, + _ => false, + }; + + if !is_valid { + polars_bail!(InvalidOperation: "dtype {} not supported in '{}' operation", dtype, self); + } + + match self { + Self::CountOnes | + Self::CountZeros | + Self::LeadingOnes | + Self::LeadingZeros | + Self::TrailingOnes | + Self::TrailingZeros => Ok(DataType::UInt32), + Self::And | + Self::Or | + Self::Xor => Ok(dtype.clone()), + } + }) + } + + pub fn function_options(&self) -> FunctionOptions { + use BitwiseFunction as B; + match self { + B::CountOnes + | B::CountZeros + | B::LeadingOnes + | B::LeadingZeros + | B::TrailingOnes + | B::TrailingZeros => FunctionOptions::elementwise(), + B::And | B::Or | B::Xor => FunctionOptions::aggregation(), + } + } +} + +fn count_ones(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(polars_ops::series::count_ones) +} + +fn count_zeros(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(polars_ops::series::count_zeros) +} + +fn leading_ones(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(polars_ops::series::leading_ones) +} + +fn leading_zeros(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(polars_ops::series::leading_zeros) +} + +fn trailing_ones(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(polars_ops::series::trailing_ones) +} + +fn trailing_zeros(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(polars_ops::series::trailing_zeros) +} + +fn reduce_and(c: &Column) -> PolarsResult { + c.and_reduce().map(|v| v.into_column(c.name().clone())) +} + +fn reduce_or(c: &Column) -> PolarsResult { + c.or_reduce().map(|v| v.into_column(c.name().clone())) +} + +fn reduce_xor(c: &Column) -> PolarsResult { + c.xor_reduce().map(|v| v.into_column(c.name().clone())) +} diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs new file mode 100644 index 000000000000..b19126bfe8d9 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -0,0 +1,297 @@ +use std::ops::{BitAnd, BitOr}; + +use polars_core::POOL; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +use super::*; +#[cfg(feature = "is_in")] +use crate::wrap; +use crate::{map, map_as_slice}; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum BooleanFunction { + Any { + ignore_nulls: bool, + }, + All { + ignore_nulls: bool, + }, + IsNull, + IsNotNull, + IsFinite, + IsInfinite, + IsNan, + IsNotNan, + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct, + #[cfg(feature = "is_last_distinct")] + IsLastDistinct, + #[cfg(feature = "is_unique")] + IsUnique, + #[cfg(feature = "is_unique")] + IsDuplicated, + #[cfg(feature = "is_between")] + IsBetween { + closed: ClosedInterval, + }, + #[cfg(feature = "is_in")] + IsIn { + nulls_equal: bool, + }, + AllHorizontal, + AnyHorizontal, + // Also bitwise negate + Not, +} + +impl BooleanFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + match self { + BooleanFunction::Not => { + mapper.try_map_dtype(|dtype| { + match dtype { + DataType::Boolean => Ok(DataType::Boolean), + dt if dt.is_integer() => Ok(dt.clone()), + dt => polars_bail!(InvalidOperation: "dtype {:?} not supported in 'not' operation", dt) + } + }) + + }, + _ => mapper.with_dtype(DataType::Boolean), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use BooleanFunction as B; + match self { + B::Any { .. } | B::All { .. } => FunctionOptions::aggregation(), + B::IsNull | B::IsNotNull | B::IsFinite | B::IsInfinite | B::IsNan | B::IsNotNan => { + FunctionOptions::elementwise() + }, + #[cfg(feature = "is_first_distinct")] + B::IsFirstDistinct => FunctionOptions::length_preserving(), + #[cfg(feature = "is_last_distinct")] + B::IsLastDistinct => FunctionOptions::length_preserving(), + #[cfg(feature = "is_unique")] + B::IsUnique => FunctionOptions::length_preserving(), + #[cfg(feature = "is_unique")] + B::IsDuplicated => FunctionOptions::length_preserving(), + #[cfg(feature = "is_between")] + B::IsBetween { .. } => FunctionOptions::elementwise().with_supertyping( + (SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(), + ), + #[cfg(feature = "is_in")] + B::IsIn { .. } => FunctionOptions::elementwise().with_supertyping(Default::default()), + B::AllHorizontal | B::AnyHorizontal => FunctionOptions::elementwise() + .with_input_wildcard_expansion(true) + .with_allow_empty_inputs(true), + B::Not => FunctionOptions::elementwise(), + } + } +} + +impl Display for BooleanFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use BooleanFunction::*; + let s = match self { + All { .. } => "all", + Any { .. } => "any", + IsNull => "is_null", + IsNotNull => "is_not_null", + IsFinite => "is_finite", + IsInfinite => "is_infinite", + IsNan => "is_nan", + IsNotNan => "is_not_nan", + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct => "is_first_distinct", + #[cfg(feature = "is_last_distinct")] + IsLastDistinct => "is_last_distinct", + #[cfg(feature = "is_unique")] + IsUnique => "is_unique", + #[cfg(feature = "is_unique")] + IsDuplicated => "is_duplicated", + #[cfg(feature = "is_between")] + IsBetween { .. } => "is_between", + #[cfg(feature = "is_in")] + IsIn { .. } => "is_in", + AnyHorizontal => "any_horizontal", + AllHorizontal => "all_horizontal", + Not => "not", + }; + write!(f, "{s}") + } +} + +impl From for SpecialEq> { + fn from(func: BooleanFunction) -> Self { + use BooleanFunction::*; + match func { + Any { ignore_nulls } => map!(any, ignore_nulls), + All { ignore_nulls } => map!(all, ignore_nulls), + IsNull => map!(is_null), + IsNotNull => map!(is_not_null), + IsFinite => map!(is_finite), + IsInfinite => map!(is_infinite), + IsNan => map!(is_nan), + IsNotNan => map!(is_not_nan), + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct => map!(is_first_distinct), + #[cfg(feature = "is_last_distinct")] + IsLastDistinct => map!(is_last_distinct), + #[cfg(feature = "is_unique")] + IsUnique => map!(is_unique), + #[cfg(feature = "is_unique")] + IsDuplicated => map!(is_duplicated), + #[cfg(feature = "is_between")] + IsBetween { closed } => map_as_slice!(is_between, closed), + #[cfg(feature = "is_in")] + IsIn { nulls_equal } => wrap!(is_in, nulls_equal), + Not => map!(not), + AllHorizontal => map_as_slice!(all_horizontal), + AnyHorizontal => map_as_slice!(any_horizontal), + } + } +} + +impl From for FunctionExpr { + fn from(func: BooleanFunction) -> Self { + FunctionExpr::Boolean(func) + } +} + +fn any(s: &Column, ignore_nulls: bool) -> PolarsResult { + let ca = s.bool()?; + if ignore_nulls { + Ok(Column::new(s.name().clone(), [ca.any()])) + } else { + Ok(Column::new(s.name().clone(), [ca.any_kleene()])) + } +} + +fn all(s: &Column, ignore_nulls: bool) -> PolarsResult { + let ca = s.bool()?; + if ignore_nulls { + Ok(Column::new(s.name().clone(), [ca.all()])) + } else { + Ok(Column::new(s.name().clone(), [ca.all_kleene()])) + } +} + +fn is_null(s: &Column) -> PolarsResult { + Ok(s.is_null().into_column()) +} + +fn is_not_null(s: &Column) -> PolarsResult { + Ok(s.is_not_null().into_column()) +} + +fn is_finite(s: &Column) -> PolarsResult { + s.is_finite().map(|ca| ca.into_column()) +} + +fn is_infinite(s: &Column) -> PolarsResult { + s.is_infinite().map(|ca| ca.into_column()) +} + +pub(super) fn is_nan(s: &Column) -> PolarsResult { + s.is_nan().map(|ca| ca.into_column()) +} + +pub(super) fn is_not_nan(s: &Column) -> PolarsResult { + s.is_not_nan().map(|ca| ca.into_column()) +} + +#[cfg(feature = "is_first_distinct")] +fn is_first_distinct(s: &Column) -> PolarsResult { + polars_ops::prelude::is_first_distinct(s.as_materialized_series()).map(|ca| ca.into_column()) +} + +#[cfg(feature = "is_last_distinct")] +fn is_last_distinct(s: &Column) -> PolarsResult { + polars_ops::prelude::is_last_distinct(s.as_materialized_series()).map(|ca| ca.into_column()) +} + +#[cfg(feature = "is_unique")] +fn is_unique(s: &Column) -> PolarsResult { + polars_ops::prelude::is_unique(s.as_materialized_series()).map(|ca| ca.into_column()) +} + +#[cfg(feature = "is_unique")] +fn is_duplicated(s: &Column) -> PolarsResult { + polars_ops::prelude::is_duplicated(s.as_materialized_series()).map(|ca| ca.into_column()) +} + +#[cfg(feature = "is_between")] +fn is_between(s: &[Column], closed: ClosedInterval) -> PolarsResult { + let ser = &s[0]; + let lower = &s[1]; + let upper = &s[2]; + polars_ops::prelude::is_between( + ser.as_materialized_series(), + lower.as_materialized_series(), + upper.as_materialized_series(), + closed, + ) + .map(|ca| ca.into_column()) +} + +#[cfg(feature = "is_in")] +fn is_in(s: &mut [Column], nulls_equal: bool) -> PolarsResult> { + let left = &s[0]; + let other = &s[1]; + polars_ops::prelude::is_in( + left.as_materialized_series(), + other.as_materialized_series(), + nulls_equal, + ) + .map(|ca| Some(ca.into_column())) +} + +fn not(s: &Column) -> PolarsResult { + polars_ops::series::negate_bitwise(s.as_materialized_series()).map(Column::from) +} + +// We shouldn't hit these often only on very wide dataframes where we don't reduce to & expressions. +fn any_horizontal(s: &[Column]) -> PolarsResult { + let out = POOL + .install(|| { + s.par_iter() + .try_fold( + || BooleanChunked::new(PlSmallStr::EMPTY, &[false]), + |acc, b| { + let b = b.cast(&DataType::Boolean)?; + let b = b.bool()?; + PolarsResult::Ok((&acc).bitor(b)) + }, + ) + .try_reduce( + || BooleanChunked::new(PlSmallStr::EMPTY, [false]), + |a, b| Ok(a.bitor(b)), + ) + })? + .with_name(s[0].name().clone()); + Ok(out.into_column()) +} + +// We shouldn't hit these often only on very wide dataframes where we don't reduce to & expressions. +fn all_horizontal(s: &[Column]) -> PolarsResult { + let out = POOL + .install(|| { + s.par_iter() + .try_fold( + || BooleanChunked::new(PlSmallStr::EMPTY, &[true]), + |acc, b| { + let b = b.cast(&DataType::Boolean)?; + let b = b.bool()?; + PolarsResult::Ok((&acc).bitand(b)) + }, + ) + .try_reduce( + || BooleanChunked::new(PlSmallStr::EMPTY, [true]), + |a, b| Ok(a.bitand(b)), + ) + })? + .with_name(s[0].name().clone()); + Ok(out.into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/bounds.rs b/crates/polars-plan/src/dsl/function_expr/bounds.rs new file mode 100644 index 000000000000..ae0f36a0956e --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/bounds.rs @@ -0,0 +1,13 @@ +use super::*; + +pub(super) fn upper_bound(s: &Column) -> PolarsResult { + let name = s.name().clone(); + let scalar = s.dtype().to_physical().max()?; + Ok(Column::new_scalar(name, scalar, 1)) +} + +pub(super) fn lower_bound(s: &Column) -> PolarsResult { + let name = s.name().clone(); + let scalar = s.dtype().to_physical().min()?; + Ok(Column::new_scalar(name, scalar, 1)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/business.rs b/crates/polars-plan/src/dsl/function_expr/business.rs new file mode 100644 index 000000000000..3028b1090034 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/business.rs @@ -0,0 +1,127 @@ +use std::fmt::{Display, Formatter}; + +use polars_core::prelude::*; +use polars_ops::prelude::Roll; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::FunctionOptions; +use crate::dsl::{FieldsMapper, SpecialEq}; +use crate::map_as_slice; +use crate::prelude::ColumnsUdf; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum BusinessFunction { + BusinessDayCount { + week_mask: [bool; 7], + holidays: Vec, + }, + AddBusinessDay { + week_mask: [bool; 7], + holidays: Vec, + roll: Roll, + }, + IsBusinessDay { + week_mask: [bool; 7], + holidays: Vec, + }, +} + +impl BusinessFunction { + pub fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + match self { + Self::BusinessDayCount { .. } => mapper.with_dtype(DataType::Int32), + Self::AddBusinessDay { .. } => mapper.with_same_dtype(), + Self::IsBusinessDay { .. } => mapper.with_dtype(DataType::Boolean), + } + } + pub fn function_options(&self) -> FunctionOptions { + use BusinessFunction as B; + match self { + B::BusinessDayCount { .. } => FunctionOptions::elementwise().with_allow_rename(true), + B::AddBusinessDay { .. } | B::IsBusinessDay { .. } => FunctionOptions::elementwise(), + } + } +} + +impl Display for BusinessFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use BusinessFunction::*; + let s = match self { + BusinessDayCount { .. } => "business_day_count", + AddBusinessDay { .. } => "add_business_days", + IsBusinessDay { .. } => "is_business_day", + }; + write!(f, "{s}") + } +} +impl From for SpecialEq> { + fn from(func: BusinessFunction) -> Self { + use BusinessFunction::*; + match func { + BusinessDayCount { + week_mask, + holidays, + } => { + map_as_slice!(business_day_count, week_mask, &holidays) + }, + AddBusinessDay { + week_mask, + holidays, + roll, + } => { + map_as_slice!(add_business_days, week_mask, &holidays, roll) + }, + IsBusinessDay { + week_mask, + holidays, + } => { + map_as_slice!(is_business_day, week_mask, &holidays) + }, + } + } +} + +pub(super) fn business_day_count( + s: &[Column], + week_mask: [bool; 7], + holidays: &[i32], +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + polars_ops::prelude::business_day_count( + start.as_materialized_series(), + end.as_materialized_series(), + week_mask, + holidays, + ) + .map(Column::from) +} +pub(super) fn add_business_days( + s: &[Column], + week_mask: [bool; 7], + holidays: &[i32], + roll: Roll, +) -> PolarsResult { + let start = &s[0]; + let n = &s[1]; + polars_ops::prelude::add_business_days( + start.as_materialized_series(), + n.as_materialized_series(), + week_mask, + holidays, + roll, + ) + .map(Column::from) +} + +pub(super) fn is_business_day( + s: &[Column], + week_mask: [bool; 7], + holidays: &[i32], +) -> PolarsResult { + let dates = &s[0]; + polars_ops::prelude::is_business_day(dates.as_materialized_series(), week_mask, holidays) + .map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/cat.rs b/crates/polars-plan/src/dsl/function_expr/cat.rs new file mode 100644 index 000000000000..075b0e70fc45 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/cat.rs @@ -0,0 +1,189 @@ +use super::*; +use crate::map; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum CategoricalFunction { + GetCategories, + #[cfg(feature = "strings")] + LenBytes, + #[cfg(feature = "strings")] + LenChars, + #[cfg(feature = "strings")] + StartsWith(String), + #[cfg(feature = "strings")] + EndsWith(String), + #[cfg(feature = "strings")] + Slice(i64, Option), +} + +impl CategoricalFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use CategoricalFunction::*; + match self { + GetCategories => mapper.with_dtype(DataType::String), + #[cfg(feature = "strings")] + LenBytes => mapper.with_dtype(DataType::UInt32), + #[cfg(feature = "strings")] + LenChars => mapper.with_dtype(DataType::UInt32), + #[cfg(feature = "strings")] + StartsWith(_) => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "strings")] + EndsWith(_) => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "strings")] + Slice(_, _) => mapper.with_dtype(DataType::String), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use CategoricalFunction as C; + match self { + C::GetCategories => FunctionOptions::groupwise(), + #[cfg(feature = "strings")] + C::LenBytes | C::LenChars | C::StartsWith(_) | C::EndsWith(_) | C::Slice(_, _) => { + FunctionOptions::elementwise() + }, + } + } +} + +impl Display for CategoricalFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use CategoricalFunction::*; + let s = match self { + GetCategories => "get_categories", + #[cfg(feature = "strings")] + LenBytes => "len_bytes", + #[cfg(feature = "strings")] + LenChars => "len_chars", + #[cfg(feature = "strings")] + StartsWith(_) => "starts_with", + #[cfg(feature = "strings")] + EndsWith(_) => "ends_with", + #[cfg(feature = "strings")] + Slice(_, _) => "slice", + }; + write!(f, "cat.{s}") + } +} + +impl From for SpecialEq> { + fn from(func: CategoricalFunction) -> Self { + use CategoricalFunction::*; + match func { + GetCategories => map!(get_categories), + #[cfg(feature = "strings")] + LenBytes => map!(len_bytes), + #[cfg(feature = "strings")] + LenChars => map!(len_chars), + #[cfg(feature = "strings")] + StartsWith(prefix) => map!(starts_with, prefix.as_str()), + #[cfg(feature = "strings")] + EndsWith(suffix) => map!(ends_with, suffix.as_str()), + #[cfg(feature = "strings")] + Slice(offset, length) => map!(slice, offset, length), + } + } +} + +impl From for FunctionExpr { + fn from(func: CategoricalFunction) -> Self { + FunctionExpr::Categorical(func) + } +} + +fn get_categories(s: &Column) -> PolarsResult { + // categorical check + let ca = s.categorical()?; + let rev_map = ca.get_rev_map(); + let arr = rev_map.get_categories().clone().boxed(); + Series::try_from((ca.name().clone(), arr)).map(Column::from) +} + +// Determine mapping between categories and underlying physical. For local, this is just 0..n. +// For global, this is the global indexes. +fn _get_cat_phys_map(ca: &CategoricalChunked) -> (StringChunked, Series) { + let (categories, phys) = match &**ca.get_rev_map() { + RevMapping::Local(c, _) => (c, ca.physical().cast(&IDX_DTYPE).unwrap()), + RevMapping::Global(physical_map, c, _) => { + // Map physical to its local representation for use with take() later. + let phys = ca + .physical() + .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap())); + let out = phys.cast(&IDX_DTYPE).unwrap(); + (c, out) + }, + }; + let categories = StringChunked::with_chunk(ca.name().clone(), categories.clone()); + (categories, phys) +} + +/// Fast path: apply a string function to the categories of a categorical column and broadcast the +/// result back to the array. +// fn apply_to_cats(ca: &CategoricalChunked, mut op: F) -> PolarsResult +fn apply_to_cats(c: &Column, mut op: F) -> PolarsResult +where + F: FnMut(&StringChunked) -> ChunkedArray, + ChunkedArray: IntoSeries, + T: PolarsDataType, +{ + let ca = c.categorical()?; + let (categories, phys) = _get_cat_phys_map(ca); + let result = op(&categories); + // SAFETY: physical idx array is valid. + let out = unsafe { result.take_unchecked(phys.idx().unwrap()) }; + Ok(out.into_column()) +} + +/// Fast path: apply a binary function to the categories of a categorical column and broadcast the +/// result back to the array. +fn apply_to_cats_binary(c: &Column, mut op: F) -> PolarsResult +where + F: FnMut(&BinaryChunked) -> ChunkedArray, + ChunkedArray: IntoSeries, + T: PolarsDataType, +{ + let ca = c.categorical()?; + let (categories, phys) = _get_cat_phys_map(ca); + let result = op(&categories.as_binary()); + // SAFETY: physical idx array is valid. + let out = unsafe { result.take_unchecked(phys.idx().unwrap()) }; + Ok(out.into_column()) +} + +#[cfg(feature = "strings")] +fn len_bytes(c: &Column) -> PolarsResult { + apply_to_cats(c, |s| s.str_len_bytes()) +} + +#[cfg(feature = "strings")] +fn len_chars(c: &Column) -> PolarsResult { + apply_to_cats(c, |s| s.str_len_chars()) +} + +#[cfg(feature = "strings")] +fn starts_with(c: &Column, prefix: &str) -> PolarsResult { + apply_to_cats_binary(c, |s| s.starts_with(prefix.as_bytes())) +} + +#[cfg(feature = "strings")] +fn ends_with(c: &Column, suffix: &str) -> PolarsResult { + apply_to_cats_binary(c, |s| s.ends_with(suffix.as_bytes())) +} + +#[cfg(feature = "strings")] +fn slice(c: &Column, offset: i64, length: Option) -> PolarsResult { + let length = length.unwrap_or(usize::MAX) as u64; + let ca = c.categorical()?; + let (categories, phys) = _get_cat_phys_map(ca); + + let result = unsafe { + categories.apply_views(|view, val| { + let (start, end) = substring_ternary_offsets_value(val, offset, length); + update_view(view, start, end, val) + }) + }; + // SAFETY: physical idx array is valid. + let out = unsafe { result.take_unchecked(phys.idx().unwrap()) }; + Ok(out.into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/clip.rs b/crates/polars-plan/src/dsl/function_expr/clip.rs new file mode 100644 index 000000000000..9a721d65d198 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/clip.rs @@ -0,0 +1,21 @@ +use super::*; + +pub(super) fn clip(s: &[Column], has_min: bool, has_max: bool) -> PolarsResult { + match (has_min, has_max) { + (true, true) => polars_ops::series::clip( + s[0].as_materialized_series(), + s[1].as_materialized_series(), + s[2].as_materialized_series(), + ), + (true, false) => polars_ops::series::clip_min( + s[0].as_materialized_series(), + s[1].as_materialized_series(), + ), + (false, true) => polars_ops::series::clip_max( + s[0].as_materialized_series(), + s[1].as_materialized_series(), + ), + _ => unreachable!(), + } + .map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/coerce.rs b/crates/polars-plan/src/dsl/function_expr/coerce.rs new file mode 100644 index 000000000000..2ade8737d077 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/coerce.rs @@ -0,0 +1,22 @@ +use polars_core::prelude::*; + +pub fn as_struct(cols: &[Column]) -> PolarsResult { + let Some(fst) = cols.first() else { + polars_bail!(nyi = "turning no columns as_struct"); + }; + + let mut min_length = usize::MAX; + let mut max_length = usize::MIN; + + for col in cols { + let len = col.len(); + + min_length = min_length.min(len); + max_length = max_length.max(len); + } + + // @NOTE: Any additional errors should be handled by the StructChunked::from_columns + let length = if min_length == 0 { 0 } else { max_length }; + + Ok(StructChunked::from_columns(fst.name().clone(), length, cols)?.into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/concat.rs b/crates/polars-plan/src/dsl/function_expr/concat.rs new file mode 100644 index 000000000000..a021545f2ad0 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/concat.rs @@ -0,0 +1,13 @@ +use super::*; + +pub(super) fn concat_expr(s: &[Column], rechunk: bool) -> PolarsResult { + let mut first = s[0].clone(); + + for s in &s[1..] { + first.append(s)?; + } + if rechunk { + first = first.rechunk() + } + Ok(first) +} diff --git a/crates/polars-plan/src/dsl/function_expr/correlation.rs b/crates/polars-plan/src/dsl/function_expr/correlation.rs new file mode 100644 index 000000000000..a653c6a25f81 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/correlation.rs @@ -0,0 +1,143 @@ +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Copy, Clone, PartialEq, Debug, Hash)] +pub enum CorrelationMethod { + Pearson, + #[cfg(all(feature = "rank", feature = "propagate_nans"))] + SpearmanRank(bool), + Covariance(u8), +} + +impl Display for CorrelationMethod { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use CorrelationMethod::*; + let s = match self { + Pearson => "pearson", + #[cfg(all(feature = "rank", feature = "propagate_nans"))] + SpearmanRank(_) => "spearman_rank", + Covariance(_) => return write!(f, "covariance"), + }; + write!(f, "{}_correlation", s) + } +} + +pub(super) fn corr(s: &[Column], method: CorrelationMethod) -> PolarsResult { + polars_ensure!( + s[0].len() == s[1].len() || s[0].len() == 1 || s[1].len() == 1, + length_mismatch = "corr", + s[0].len(), + s[1].len() + ); + + match method { + CorrelationMethod::Pearson => pearson_corr(s), + #[cfg(all(feature = "rank", feature = "propagate_nans"))] + CorrelationMethod::SpearmanRank(propagate_nans) => spearman_rank_corr(s, propagate_nans), + CorrelationMethod::Covariance(ddof) => covariance(s, ddof), + } +} + +fn covariance(s: &[Column], ddof: u8) -> PolarsResult { + let a = &s[0]; + let b = &s[1]; + let name = PlSmallStr::from_static("cov"); + + use polars_ops::chunked_array::cov::cov; + let ret = match a.dtype() { + DataType::Float32 => { + let ret = cov(a.f32().unwrap(), b.f32().unwrap(), ddof).map(|v| v as f32); + return Ok(Column::new(name, &[ret])); + }, + DataType::Float64 => cov(a.f64().unwrap(), b.f64().unwrap(), ddof), + DataType::Int32 => cov(a.i32().unwrap(), b.i32().unwrap(), ddof), + DataType::Int64 => cov(a.i64().unwrap(), b.i64().unwrap(), ddof), + DataType::UInt32 => cov(a.u32().unwrap(), b.u32().unwrap(), ddof), + DataType::UInt64 => cov(a.u64().unwrap(), b.u64().unwrap(), ddof), + _ => { + let a = a.cast(&DataType::Float64)?; + let b = b.cast(&DataType::Float64)?; + cov(a.f64().unwrap(), b.f64().unwrap(), ddof) + }, + }; + Ok(Column::new(name, &[ret])) +} + +fn pearson_corr(s: &[Column]) -> PolarsResult { + let a = &s[0]; + let b = &s[1]; + let name = PlSmallStr::from_static("pearson_corr"); + + use polars_ops::chunked_array::cov::pearson_corr; + let ret = match a.dtype() { + DataType::Float32 => { + let ret = pearson_corr(a.f32().unwrap(), b.f32().unwrap()).map(|v| v as f32); + return Ok(Column::new(name.clone(), &[ret])); + }, + DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap()), + DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap()), + DataType::Int64 => pearson_corr(a.i64().unwrap(), b.i64().unwrap()), + DataType::UInt32 => pearson_corr(a.u32().unwrap(), b.u32().unwrap()), + _ => { + let a = a.cast(&DataType::Float64)?; + let b = b.cast(&DataType::Float64)?; + pearson_corr(a.f64().unwrap(), b.f64().unwrap()) + }, + }; + Ok(Column::new(name, &[ret])) +} + +#[cfg(all(feature = "rank", feature = "propagate_nans"))] +fn spearman_rank_corr(s: &[Column], propagate_nans: bool) -> PolarsResult { + use polars_core::utils::coalesce_nulls_columns; + use polars_ops::chunked_array::nan_propagating_aggregate::nan_max_s; + let a = &s[0]; + let b = &s[1]; + + let (a, b) = coalesce_nulls_columns(a, b); + + let name = PlSmallStr::from_static("spearman_rank_correlation"); + if propagate_nans && a.dtype().is_float() { + for s in [&a, &b] { + if nan_max_s(s.as_materialized_series(), PlSmallStr::EMPTY) + .get(0) + .unwrap() + .extract::() + .unwrap() + .is_nan() + { + return Ok(Column::new(name, &[f64::NAN])); + } + } + } + + // drop nulls so that they are excluded + let a = a.drop_nulls(); + let b = b.drop_nulls(); + + let a_rank = a + .as_materialized_series() + .rank( + RankOptions { + method: RankMethod::Average, + ..Default::default() + }, + None, + ) + .into(); + let b_rank = b + .as_materialized_series() + .rank( + RankOptions { + method: RankMethod::Average, + ..Default::default() + }, + None, + ) + .into(); + + pearson_corr(&[a_rank, b_rank]) +} diff --git a/crates/polars-plan/src/dsl/function_expr/cum.rs b/crates/polars-plan/src/dsl/function_expr/cum.rs new file mode 100644 index 000000000000..2866b44bcd11 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/cum.rs @@ -0,0 +1,66 @@ +use super::*; + +pub(super) fn cum_count(s: &Column, reverse: bool) -> PolarsResult { + // @scalar-opt + polars_ops::prelude::cum_count(s.as_materialized_series(), reverse).map(Column::from) +} + +pub(super) fn cum_sum(s: &Column, reverse: bool) -> PolarsResult { + // @scalar-opt + polars_ops::prelude::cum_sum(s.as_materialized_series(), reverse).map(Column::from) +} + +pub(super) fn cum_prod(s: &Column, reverse: bool) -> PolarsResult { + // @scalar-opt + polars_ops::prelude::cum_prod(s.as_materialized_series(), reverse).map(Column::from) +} + +pub(super) fn cum_min(s: &Column, reverse: bool) -> PolarsResult { + // @scalar-opt + polars_ops::prelude::cum_min(s.as_materialized_series(), reverse).map(Column::from) +} + +pub(super) fn cum_max(s: &Column, reverse: bool) -> PolarsResult { + // @scalar-opt + polars_ops::prelude::cum_max(s.as_materialized_series(), reverse).map(Column::from) +} + +pub(super) mod dtypes { + use DataType::*; + use polars_core::utils::materialize_dyn_int; + + use super::*; + + pub fn cum_sum(dt: &DataType) -> DataType { + if dt.is_logical() { + dt.clone() + } else { + match dt { + Boolean => UInt32, + Int32 => Int32, + Int128 => Int128, + UInt32 => UInt32, + UInt64 => UInt64, + Float32 => Float32, + Float64 => Float64, + Unknown(kind) => match kind { + UnknownKind::Int(v) => cum_sum(&materialize_dyn_int(*v).dtype()), + UnknownKind::Float => Float64, + _ => dt.clone(), + }, + _ => Int64, + } + } + } + + pub fn cum_prod(dt: &DataType) -> DataType { + match dt { + Boolean => Int64, + UInt64 => UInt64, + Int128 => Int128, + Float32 => Float32, + Float64 => Float64, + _ => Int64, + } + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/cut.rs b/crates/polars-plan/src/dsl/function_expr/cut.rs new file mode 100644 index 000000000000..faafc7aa3f76 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/cut.rs @@ -0,0 +1,37 @@ +use polars_core::prelude::*; + +pub(crate) fn cut( + s: &Column, + breaks: Vec, + labels: Option>, + left_closed: bool, + include_breaks: bool, +) -> PolarsResult { + polars_ops::prelude::cut( + s.as_materialized_series(), + breaks, + labels, + left_closed, + include_breaks, + ) + .map(Column::from) +} + +pub(crate) fn qcut( + s: &Column, + probs: Vec, + labels: Option>, + left_closed: bool, + allow_duplicates: bool, + include_breaks: bool, +) -> PolarsResult { + polars_ops::prelude::qcut( + s.as_materialized_series(), + probs, + labels, + left_closed, + allow_duplicates, + include_breaks, + ) + .map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs new file mode 100644 index 000000000000..f7f61e4461ba --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -0,0 +1,659 @@ +#[cfg(feature = "timezones")] +use chrono_tz::Tz; +#[cfg(feature = "timezones")] +use polars_core::chunked_array::temporal::validate_time_zone; +#[cfg(feature = "timezones")] +use polars_time::base_utc_offset as base_utc_offset_fn; +#[cfg(feature = "timezones")] +use polars_time::dst_offset as dst_offset_fn; +#[cfg(feature = "offset_by")] +use polars_time::impl_offset_by; +#[cfg(any(feature = "dtype-date", feature = "dtype-datetime"))] +use polars_time::replace::{replace_date, replace_datetime}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum TemporalFunction { + Millennium, + Century, + Year, + IsLeapYear, + IsoYear, + Quarter, + Month, + Week, + WeekDay, + Day, + OrdinalDay, + Time, + Date, + Datetime, + Duration(TimeUnit), + Hour, + Minute, + Second, + Millisecond, + Microsecond, + Nanosecond, + TotalDays, + TotalHours, + TotalMinutes, + TotalSeconds, + TotalMilliseconds, + TotalMicroseconds, + TotalNanoseconds, + ToString(String), + CastTimeUnit(TimeUnit), + WithTimeUnit(TimeUnit), + #[cfg(feature = "timezones")] + ConvertTimeZone(TimeZone), + TimeStamp(TimeUnit), + Truncate, + #[cfg(feature = "offset_by")] + OffsetBy, + #[cfg(feature = "month_start")] + MonthStart, + #[cfg(feature = "month_end")] + MonthEnd, + #[cfg(feature = "timezones")] + BaseUtcOffset, + #[cfg(feature = "timezones")] + DSTOffset, + Round, + Replace, + #[cfg(feature = "timezones")] + ReplaceTimeZone(Option, NonExistent), + Combine(TimeUnit), + DatetimeFunction { + time_unit: TimeUnit, + time_zone: Option, + }, +} + +impl TemporalFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use TemporalFunction::*; + match self { + Millennium | Century => mapper.with_dtype(DataType::Int8), + Year | IsoYear => mapper.with_dtype(DataType::Int32), + OrdinalDay => mapper.with_dtype(DataType::Int16), + Month | Quarter | Week | WeekDay | Day | Hour | Minute | Second => { + mapper.with_dtype(DataType::Int8) + }, + Millisecond | Microsecond | Nanosecond => mapper.with_dtype(DataType::Int32), + TotalDays | TotalHours | TotalMinutes | TotalSeconds | TotalMilliseconds + | TotalMicroseconds | TotalNanoseconds => mapper.with_dtype(DataType::Int64), + ToString(_) => mapper.with_dtype(DataType::String), + WithTimeUnit(_) => mapper.with_same_dtype(), + CastTimeUnit(tu) => mapper.try_map_dtype(|dt| match dt { + DataType::Duration(_) => Ok(DataType::Duration(*tu)), + DataType::Datetime(_, tz) => Ok(DataType::Datetime(*tu, tz.clone())), + dtype => polars_bail!(ComputeError: "expected duration or datetime, got {}", dtype), + }), + #[cfg(feature = "timezones")] + ConvertTimeZone(tz) => mapper.try_map_dtype(|dt| match dt { + DataType::Datetime(tu, _) => Ok(DataType::Datetime(*tu, Some(tz.clone()))), + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), + }), + TimeStamp(_) => mapper.with_dtype(DataType::Int64), + IsLeapYear => mapper.with_dtype(DataType::Boolean), + Time => mapper.with_dtype(DataType::Time), + Duration(tu) => mapper.with_dtype(DataType::Duration(*tu)), + Date => mapper.with_dtype(DataType::Date), + Datetime => mapper.try_map_dtype(|dt| match dt { + DataType::Datetime(tu, _) => Ok(DataType::Datetime(*tu, None)), + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), + }), + Truncate => mapper.with_same_dtype(), + #[cfg(feature = "offset_by")] + OffsetBy => mapper.with_same_dtype(), + #[cfg(feature = "month_start")] + MonthStart => mapper.with_same_dtype(), + #[cfg(feature = "month_end")] + MonthEnd => mapper.with_same_dtype(), + #[cfg(feature = "timezones")] + BaseUtcOffset => mapper.with_dtype(DataType::Duration(TimeUnit::Milliseconds)), + #[cfg(feature = "timezones")] + DSTOffset => mapper.with_dtype(DataType::Duration(TimeUnit::Milliseconds)), + Round => mapper.with_same_dtype(), + Replace => mapper.with_same_dtype(), + #[cfg(feature = "timezones")] + ReplaceTimeZone(tz, _non_existent) => mapper.map_datetime_dtype_timezone(tz.as_ref()), + DatetimeFunction { + time_unit, + time_zone, + } => Ok(Field::new( + PlSmallStr::from_static("datetime"), + DataType::Datetime(*time_unit, time_zone.clone()), + )), + Combine(tu) => mapper.try_map_dtype(|dt| match dt { + DataType::Datetime(_, tz) => Ok(DataType::Datetime(*tu, tz.clone())), + DataType::Date => Ok(DataType::Datetime(*tu, None)), + dtype => { + polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype) + }, + }), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use TemporalFunction as T; + match self { + T::Millennium + | T::Century + | T::Year + | T::IsLeapYear + | T::IsoYear + | T::Quarter + | T::Month + | T::Week + | T::WeekDay + | T::Day + | T::OrdinalDay + | T::Time + | T::Date + | T::Datetime + | T::Hour + | T::Minute + | T::Second + | T::Millisecond + | T::Microsecond + | T::Nanosecond + | T::TotalDays + | T::TotalHours + | T::TotalMinutes + | T::TotalSeconds + | T::TotalMilliseconds + | T::TotalMicroseconds + | T::TotalNanoseconds + | T::ToString(_) + | T::TimeStamp(_) + | T::CastTimeUnit(_) + | T::WithTimeUnit(_) => FunctionOptions::elementwise(), + #[cfg(feature = "timezones")] + T::ConvertTimeZone(_) => FunctionOptions::elementwise(), + #[cfg(feature = "month_start")] + T::MonthStart => FunctionOptions::elementwise(), + #[cfg(feature = "month_end")] + T::MonthEnd => FunctionOptions::elementwise(), + #[cfg(feature = "timezones")] + T::BaseUtcOffset | T::DSTOffset => FunctionOptions::elementwise(), + T::Truncate => FunctionOptions::elementwise(), + #[cfg(feature = "offset_by")] + T::OffsetBy => FunctionOptions::elementwise(), + T::Round => FunctionOptions::elementwise(), + T::Replace => FunctionOptions::elementwise(), + T::Duration(_) => FunctionOptions::elementwise(), + #[cfg(feature = "timezones")] + T::ReplaceTimeZone(_, _) => FunctionOptions::elementwise(), + T::Combine(_) => FunctionOptions::elementwise(), + T::DatetimeFunction { .. } => FunctionOptions::elementwise().with_allow_rename(true), + } + } +} + +impl Display for TemporalFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use TemporalFunction::*; + let s = match self { + Millennium => "millennium", + Century => "century", + Year => "year", + IsLeapYear => "is_leap_year", + IsoYear => "iso_year", + Quarter => "quarter", + Month => "month", + Week => "week", + WeekDay => "weekday", + Day => "day", + OrdinalDay => "ordinal_day", + Time => "time", + Date => "date", + Datetime => "datetime", + Duration(_) => "duration", + Hour => "hour", + Minute => "minute", + Second => "second", + Millisecond => "millisecond", + Microsecond => "microsecond", + Nanosecond => "nanosecond", + TotalDays => "total_days", + TotalHours => "total_hours", + TotalMinutes => "total_minutes", + TotalSeconds => "total_seconds", + TotalMilliseconds => "total_milliseconds", + TotalMicroseconds => "total_microseconds", + TotalNanoseconds => "total_nanoseconds", + ToString(_) => "to_string", + #[cfg(feature = "timezones")] + ConvertTimeZone(_) => "convert_time_zone", + CastTimeUnit(_) => "cast_time_unit", + WithTimeUnit(_) => "with_time_unit", + TimeStamp(tu) => return write!(f, "dt.timestamp({tu})"), + Truncate => "truncate", + #[cfg(feature = "offset_by")] + OffsetBy => "offset_by", + #[cfg(feature = "month_start")] + MonthStart => "month_start", + #[cfg(feature = "month_end")] + MonthEnd => "month_end", + #[cfg(feature = "timezones")] + BaseUtcOffset => "base_utc_offset", + #[cfg(feature = "timezones")] + DSTOffset => "dst_offset", + Round => "round", + Replace => "replace", + #[cfg(feature = "timezones")] + ReplaceTimeZone(_, _) => "replace_time_zone", + DatetimeFunction { .. } => return write!(f, "dt.datetime"), + Combine(_) => "combine", + }; + write!(f, "dt.{s}") + } +} + +pub(super) fn millennium(s: &Column) -> PolarsResult { + s.as_materialized_series() + .millennium() + .map(|ca| ca.into_column()) +} +pub(super) fn century(s: &Column) -> PolarsResult { + s.as_materialized_series() + .century() + .map(|ca| ca.into_column()) +} +pub(super) fn year(s: &Column) -> PolarsResult { + s.as_materialized_series().year().map(|ca| ca.into_column()) +} +pub(super) fn is_leap_year(s: &Column) -> PolarsResult { + s.as_materialized_series() + .is_leap_year() + .map(|ca| ca.into_column()) +} +pub(super) fn iso_year(s: &Column) -> PolarsResult { + s.as_materialized_series() + .iso_year() + .map(|ca| ca.into_column()) +} +pub(super) fn month(s: &Column) -> PolarsResult { + s.as_materialized_series() + .month() + .map(|ca| ca.into_column()) +} +pub(super) fn quarter(s: &Column) -> PolarsResult { + s.as_materialized_series() + .quarter() + .map(|ca| ca.into_column()) +} +pub(super) fn week(s: &Column) -> PolarsResult { + s.as_materialized_series().week().map(|ca| ca.into_column()) +} +pub(super) fn weekday(s: &Column) -> PolarsResult { + s.as_materialized_series() + .weekday() + .map(|ca| ca.into_column()) +} +pub(super) fn day(s: &Column) -> PolarsResult { + s.as_materialized_series().day().map(|ca| ca.into_column()) +} +pub(super) fn ordinal_day(s: &Column) -> PolarsResult { + s.as_materialized_series() + .ordinal_day() + .map(|ca| ca.into_column()) +} +pub(super) fn time(s: &Column) -> PolarsResult { + match s.dtype() { + #[cfg(feature = "timezones")] + DataType::Datetime(_, Some(_)) => polars_ops::prelude::replace_time_zone( + s.datetime().unwrap(), + None, + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&DataType::Time) + .map(Column::from), + DataType::Datetime(_, _) => s + .datetime() + .unwrap() + .cast(&DataType::Time) + .map(Column::from), + DataType::Time => Ok(s.clone()), + dtype => polars_bail!(ComputeError: "expected Datetime or Time, got {}", dtype), + } +} +pub(super) fn date(s: &Column) -> PolarsResult { + match s.dtype() { + #[cfg(feature = "timezones")] + DataType::Datetime(_, Some(_)) => { + let mut out = { + polars_ops::chunked_array::replace_time_zone( + s.datetime().unwrap(), + None, + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&DataType::Date)? + }; + // `replace_time_zone` may unset sorted flag. But, we're only taking the date + // part of the result, so we can safely preserve the sorted flag here. We may + // need to make an exception if a time zone introduces a change which involves + // "going back in time" and repeating a day, but we're not aware of that ever + // having happened. + out.set_sorted_flag(s.is_sorted_flag()); + Ok(out.into()) + }, + DataType::Datetime(_, _) => s + .datetime() + .unwrap() + .cast(&DataType::Date) + .map(Column::from), + DataType::Date => Ok(s.clone()), + dtype => polars_bail!(ComputeError: "expected Datetime or Date, got {}", dtype), + } +} +pub(super) fn datetime(s: &Column) -> PolarsResult { + match s.dtype() { + #[cfg(feature = "timezones")] + DataType::Datetime(tu, Some(_)) => polars_ops::chunked_array::replace_time_zone( + s.datetime().unwrap(), + None, + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&DataType::Datetime(*tu, None)) + .map(|x| x.into()), + DataType::Datetime(tu, _) => s + .datetime() + .unwrap() + .cast(&DataType::Datetime(*tu, None)) + .map(Column::from), + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), + } +} +pub(super) fn hour(s: &Column) -> PolarsResult { + s.as_materialized_series().hour().map(|ca| ca.into_column()) +} +pub(super) fn minute(s: &Column) -> PolarsResult { + s.as_materialized_series() + .minute() + .map(|ca| ca.into_column()) +} +pub(super) fn second(s: &Column) -> PolarsResult { + s.as_materialized_series() + .second() + .map(|ca| ca.into_column()) +} +pub(super) fn millisecond(s: &Column) -> PolarsResult { + s.as_materialized_series() + .nanosecond() + .map(|ca| (ca.wrapping_trunc_div_scalar(1_000_000)).into_column()) +} +pub(super) fn microsecond(s: &Column) -> PolarsResult { + s.as_materialized_series() + .nanosecond() + .map(|ca| (ca.wrapping_trunc_div_scalar(1_000)).into_column()) +} +pub(super) fn nanosecond(s: &Column) -> PolarsResult { + s.as_materialized_series() + .nanosecond() + .map(|ca| ca.into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_days(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.days().into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_hours(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.hours().into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_minutes(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.minutes().into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_seconds(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.seconds().into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_milliseconds(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.milliseconds().into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_microseconds(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.microseconds().into_column()) +} +#[cfg(feature = "dtype-duration")] +pub(super) fn total_nanoseconds(s: &Column) -> PolarsResult { + s.as_materialized_series() + .duration() + .map(|ca| ca.nanoseconds().into_column()) +} +pub(super) fn timestamp(s: &Column, tu: TimeUnit) -> PolarsResult { + s.as_materialized_series() + .timestamp(tu) + .map(|ca| ca.into_column()) +} +pub(super) fn to_string(s: &Column, format: &str) -> PolarsResult { + TemporalMethods::to_string(s.as_materialized_series(), format).map(Column::from) +} + +#[cfg(feature = "timezones")] +pub(super) fn convert_time_zone(s: &Column, time_zone: &TimeZone) -> PolarsResult { + match s.dtype() { + DataType::Datetime(_, _) => { + let mut ca = s.datetime()?.clone(); + validate_time_zone(time_zone)?; + ca.set_time_zone(time_zone.clone())?; + Ok(ca.into_column()) + }, + dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), + } +} +pub(super) fn with_time_unit(s: &Column, tu: TimeUnit) -> PolarsResult { + match s.dtype() { + DataType::Datetime(_, _) => { + let mut ca = s.datetime()?.clone(); + ca.set_time_unit(tu); + Ok(ca.into_column()) + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => { + let mut ca = s.as_materialized_series().duration()?.clone(); + ca.set_time_unit(tu); + Ok(ca.into_column()) + }, + dt => polars_bail!(ComputeError: "dtype `{}` has no time unit", dt), + } +} +pub(super) fn cast_time_unit(s: &Column, tu: TimeUnit) -> PolarsResult { + match s.dtype() { + DataType::Datetime(_, _) => { + let ca = s.datetime()?; + Ok(ca.cast_time_unit(tu).into_column()) + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => { + let ca = s.as_materialized_series().duration()?; + Ok(ca.cast_time_unit(tu).into_column()) + }, + dt => polars_bail!(ComputeError: "dtype `{}` has no time unit", dt), + } +} + +pub(super) fn truncate(s: &[Column]) -> PolarsResult { + let time_series = &s[0]; + let every = s[1].str()?; + + let mut out = match time_series.dtype() { + DataType::Datetime(_, tz) => match tz { + #[cfg(feature = "timezones")] + Some(tz) => time_series + .datetime()? + .truncate(tz.parse::().ok().as_ref(), every)? + .into_column(), + _ => time_series.datetime()?.truncate(None, every)?.into_column(), + }, + DataType::Date => time_series.date()?.truncate(None, every)?.into_column(), + dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), + }; + out.set_sorted_flag(time_series.is_sorted_flag()); + Ok(out) +} + +#[cfg(feature = "offset_by")] +pub(super) fn offset_by(s: &[Column]) -> PolarsResult { + impl_offset_by(s[0].as_materialized_series(), s[1].as_materialized_series()).map(Column::from) +} + +#[cfg(feature = "month_start")] +pub(super) fn month_start(s: &Column) -> PolarsResult { + Ok(match s.dtype() { + DataType::Datetime(_, tz) => match tz { + #[cfg(feature = "timezones")] + Some(tz) => s + .datetime() + .unwrap() + .month_start(tz.parse::().ok().as_ref())? + .into_column(), + _ => s.datetime().unwrap().month_start(None)?.into_column(), + }, + DataType::Date => s.date().unwrap().month_start(None)?.into_column(), + dt => polars_bail!(opq = month_start, got = dt, expected = "date/datetime"), + }) +} + +#[cfg(feature = "month_end")] +pub(super) fn month_end(s: &Column) -> PolarsResult { + Ok(match s.dtype() { + DataType::Datetime(_, tz) => match tz { + #[cfg(feature = "timezones")] + Some(tz) => s + .datetime() + .unwrap() + .month_end(tz.parse::().ok().as_ref())? + .into_column(), + _ => s.datetime().unwrap().month_end(None)?.into_column(), + }, + DataType::Date => s.date().unwrap().month_end(None)?.into_column(), + dt => polars_bail!(opq = month_end, got = dt, expected = "date/datetime"), + }) +} + +#[cfg(feature = "timezones")] +pub(super) fn base_utc_offset(s: &Column) -> PolarsResult { + match s.dtype() { + DataType::Datetime(time_unit, Some(tz)) => { + let tz = tz + .parse::() + .expect("Time zone has already been validated"); + Ok(base_utc_offset_fn(s.datetime().unwrap(), time_unit, &tz).into_column()) + }, + dt => polars_bail!( + opq = base_utc_offset, + got = dt, + expected = "time-zone-aware datetime" + ), + } +} +#[cfg(feature = "timezones")] +pub(super) fn dst_offset(s: &Column) -> PolarsResult { + match s.dtype() { + DataType::Datetime(time_unit, Some(tz)) => { + let tz = tz + .parse::() + .expect("Time zone has already been validated"); + Ok(dst_offset_fn(s.datetime().unwrap(), time_unit, &tz).into_column()) + }, + dt => polars_bail!( + opq = dst_offset, + got = dt, + expected = "time-zone-aware datetime" + ), + } +} + +pub(super) fn round(s: &[Column]) -> PolarsResult { + let time_series = &s[0]; + let every = s[1].str()?; + + Ok(match time_series.dtype() { + DataType::Datetime(_, tz) => match tz { + #[cfg(feature = "timezones")] + Some(tz) => time_series + .datetime() + .unwrap() + .round(every, tz.parse::().ok().as_ref())? + .into_column(), + _ => time_series + .datetime() + .unwrap() + .round(every, None)? + .into_column(), + }, + DataType::Date => time_series + .date() + .unwrap() + .round(every, None)? + .into_column(), + dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), + }) +} + +pub(super) fn replace(s: &[Column]) -> PolarsResult { + let time_series = &s[0]; + let s_year = &s[1].strict_cast(&DataType::Int32)?; + let s_month = &s[2].strict_cast(&DataType::Int8)?; + let s_day = &s[3].strict_cast(&DataType::Int8)?; + let year = s_year.i32()?; + let month = s_month.i8()?; + let day = s_day.i8()?; + + match time_series.dtype() { + DataType::Datetime(_, _) => { + let s_hour = &s[4].strict_cast(&DataType::Int8)?; + let s_minute = &s[5].strict_cast(&DataType::Int8)?; + let s_second = &s[6].strict_cast(&DataType::Int8)?; + let s_microsecond = &s[7].strict_cast(&DataType::Int32)?; + let hour = s_hour.i8()?; + let minute = s_minute.i8()?; + let second = s_second.i8()?; + let nanosecond = &(s_microsecond.i32()? * 1_000); + let s_ambiguous = &s[8].strict_cast(&DataType::String)?; + let ambiguous = s_ambiguous.str()?; + + let out = replace_datetime( + time_series.datetime().unwrap(), + year, + month, + day, + hour, + minute, + second, + nanosecond, + ambiguous, + ); + out.map(|s| s.into_column()) + }, + DataType::Date => { + let out = replace_date(time_series.date().unwrap(), year, month, day); + out.map(|s| s.into_column()) + }, + dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs new file mode 100644 index 000000000000..6d83a993c532 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -0,0 +1,251 @@ +use polars_ops::series::NullStrategy; + +use super::*; + +pub(super) fn reverse(s: &Column) -> PolarsResult { + Ok(s.reverse()) +} + +#[cfg(feature = "approx_unique")] +pub(super) fn approx_n_unique(s: &Column) -> PolarsResult { + s.approx_n_unique() + .map(|v| Column::new_scalar(s.name().clone(), Scalar::new(IDX_DTYPE, v.into()), 1)) +} + +#[cfg(feature = "diff")] +pub(super) fn diff(s: &[Column], null_behavior: NullBehavior) -> PolarsResult { + let s1 = s[0].as_materialized_series(); + let n = &s[1]; + + polars_ensure!( + n.len() == 1, + ComputeError: "n must be a single value." + ); + let n = n.strict_cast(&DataType::Int64)?; + match n.i64()?.get(0) { + Some(n) => polars_ops::prelude::diff(s1, n, null_behavior).map(Column::from), + None => polars_bail!(ComputeError: "'n' can not be None for diff"), + } +} + +#[cfg(feature = "pct_change")] +pub(super) fn pct_change(s: &[Column]) -> PolarsResult { + polars_ops::prelude::pct_change(s[0].as_materialized_series(), s[1].as_materialized_series()) + .map(Column::from) +} + +#[cfg(feature = "interpolate")] +pub(super) fn interpolate(s: &Column, method: InterpolationMethod) -> PolarsResult { + Ok(polars_ops::prelude::interpolate(s.as_materialized_series(), method).into()) +} + +#[cfg(feature = "interpolate_by")] +pub(super) fn interpolate_by(s: &[Column]) -> PolarsResult { + let by = &s[1]; + let by_is_sorted = by.as_materialized_series().is_sorted(Default::default())?; + polars_ops::prelude::interpolate_by(&s[0], by, by_is_sorted) +} + +pub(super) fn to_physical(s: &Column) -> PolarsResult { + Ok(s.to_physical_repr()) +} + +pub(super) fn set_sorted_flag(s: &Column, sorted: IsSorted) -> PolarsResult { + let mut s = s.clone(); + s.set_sorted_flag(sorted); + Ok(s) +} + +#[cfg(feature = "timezones")] +pub(super) fn replace_time_zone( + s: &[Column], + time_zone: Option<&str>, + non_existent: NonExistent, +) -> PolarsResult { + let s1 = &s[0]; + let ca = s1.datetime().unwrap(); + let s2 = &s[1].str()?; + Ok(polars_ops::prelude::replace_time_zone(ca, time_zone, s2, non_existent)?.into_column()) +} + +#[cfg(feature = "dtype-struct")] +pub(super) fn value_counts( + s: &Column, + sort: bool, + parallel: bool, + name: PlSmallStr, + normalize: bool, +) -> PolarsResult { + s.as_materialized_series() + .value_counts(sort, parallel, name, normalize) + .map(|df| df.into_struct(s.name().clone()).into_column()) +} + +#[cfg(feature = "unique_counts")] +pub(super) fn unique_counts(s: &Column) -> PolarsResult { + polars_ops::prelude::unique_counts(s.as_materialized_series()).map(Column::from) +} + +#[cfg(feature = "dtype-array")] +pub(super) fn reshape(c: &Column, dimensions: &[ReshapeDimension]) -> PolarsResult { + c.reshape_array(dimensions) +} + +#[cfg(feature = "repeat_by")] +pub(super) fn repeat_by(s: &[Column]) -> PolarsResult { + let by = &s[1]; + let s = &s[0]; + let by = by.cast(&IDX_DTYPE)?; + polars_ops::chunked_array::repeat_by(s.as_materialized_series(), by.idx()?) + .map(|ok| ok.into_column()) +} + +pub(super) fn max_horizontal(s: &mut [Column]) -> PolarsResult> { + polars_ops::prelude::max_horizontal(s) +} + +pub(super) fn min_horizontal(s: &mut [Column]) -> PolarsResult> { + polars_ops::prelude::min_horizontal(s) +} + +pub(super) fn sum_horizontal(s: &mut [Column], ignore_nulls: bool) -> PolarsResult> { + let null_strategy = if ignore_nulls { + NullStrategy::Ignore + } else { + NullStrategy::Propagate + }; + polars_ops::prelude::sum_horizontal(s, null_strategy) +} + +pub(super) fn mean_horizontal( + s: &mut [Column], + ignore_nulls: bool, +) -> PolarsResult> { + let null_strategy = if ignore_nulls { + NullStrategy::Ignore + } else { + NullStrategy::Propagate + }; + polars_ops::prelude::mean_horizontal(s, null_strategy) +} + +pub(super) fn drop_nulls(s: &Column) -> PolarsResult { + Ok(s.drop_nulls()) +} + +#[cfg(feature = "mode")] +pub(super) fn mode(s: &Column) -> PolarsResult { + mode::mode(s.as_materialized_series()).map(Column::from) +} + +#[cfg(feature = "moment")] +pub(super) fn skew(s: &Column, bias: bool) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .skew(bias) + .map(|opt_v| Column::new(s.name().clone(), &[opt_v])) +} + +#[cfg(feature = "moment")] +pub(super) fn kurtosis(s: &Column, fisher: bool, bias: bool) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .kurtosis(fisher, bias) + .map(|opt_v| Column::new(s.name().clone(), &[opt_v])) +} + +pub(super) fn arg_unique(s: &Column) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .arg_unique() + .map(|ok| ok.into_column()) +} + +#[cfg(feature = "rank")] +pub(super) fn rank(s: &Column, options: RankOptions, seed: Option) -> PolarsResult { + Ok(s.as_materialized_series().rank(options, seed).into_column()) +} + +#[cfg(feature = "hist")] +pub(super) fn hist( + s: &[Column], + bin_count: Option, + include_category: bool, + include_breakpoint: bool, +) -> PolarsResult { + let bins = if s.len() == 2 { Some(&s[1]) } else { None }; + let s = s[0].as_materialized_series(); + hist_series( + s, + bin_count, + bins.map(|b| b.as_materialized_series().clone()), + include_category, + include_breakpoint, + ) + .map(Column::from) +} + +#[cfg(feature = "replace")] +pub(super) fn replace(s: &[Column]) -> PolarsResult { + polars_ops::series::replace( + s[0].as_materialized_series(), + s[1].as_materialized_series(), + s[2].as_materialized_series(), + ) + .map(Column::from) +} + +#[cfg(feature = "replace")] +pub(super) fn replace_strict(s: &[Column], return_dtype: Option) -> PolarsResult { + match s.get(3) { + Some(default) => polars_ops::series::replace_or_default( + s[0].as_materialized_series(), + s[1].as_materialized_series(), + s[2].as_materialized_series(), + default.as_materialized_series(), + return_dtype, + ), + None => polars_ops::series::replace_strict( + s[0].as_materialized_series(), + s[1].as_materialized_series(), + s[2].as_materialized_series(), + return_dtype, + ), + } + .map(Column::from) +} + +pub(super) fn fill_null_with_strategy( + s: &Column, + strategy: FillNullStrategy, +) -> PolarsResult { + s.fill_null(strategy) +} + +pub(super) fn gather_every(s: &Column, n: usize, offset: usize) -> PolarsResult { + s.gather_every(n, offset) +} + +#[cfg(feature = "reinterpret")] +pub(super) fn reinterpret(s: &Column, signed: bool) -> PolarsResult { + polars_ops::series::reinterpret(s.as_materialized_series(), signed).map(Column::from) +} + +pub(super) fn negate(s: &Column) -> PolarsResult { + polars_ops::series::negate(s.as_materialized_series()).map(Column::from) +} + +pub(super) fn extend_constant(s: &[Column]) -> PolarsResult { + let value = &s[1]; + let n = &s[2]; + polars_ensure!(value.len() == 1 && n.len() == 1, ComputeError: "value and n should have unit length."); + let n = n.strict_cast(&DataType::UInt64)?; + let v = value.get(0)?; + let s = &s[0]; + match n.u64()?.get(0) { + Some(n) => s.extend_constant(v, n as usize), + None => { + polars_bail!(ComputeError: "n can not be None for extend_constant.") + }, + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs b/crates/polars-plan/src/dsl/function_expr/ewm.rs new file mode 100644 index 000000000000..6f7a20045503 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/ewm.rs @@ -0,0 +1,13 @@ +use super::*; + +pub(super) fn ewm_mean(s: &Column, options: EWMOptions) -> PolarsResult { + polars_ops::prelude::ewm_mean(s.as_materialized_series(), options).map(Column::from) +} + +pub(super) fn ewm_std(s: &Column, options: EWMOptions) -> PolarsResult { + polars_ops::prelude::ewm_std(s.as_materialized_series(), options).map(Column::from) +} + +pub(super) fn ewm_var(s: &Column, options: EWMOptions) -> PolarsResult { + polars_ops::prelude::ewm_var(s.as_materialized_series(), options).map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/ewm_by.rs b/crates/polars-plan/src/dsl/function_expr/ewm_by.rs new file mode 100644 index 000000000000..adfc66a01524 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/ewm_by.rs @@ -0,0 +1,26 @@ +use polars_ops::series::SeriesMethods; + +use super::*; + +pub(super) fn ewm_mean_by(s: &[Column], half_life: Duration) -> PolarsResult { + let time_zone = match s[1].dtype() { + DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()), + _ => None, + }; + polars_ensure!(!half_life.negative(), InvalidOperation: "half_life cannot be negative"); + ensure_is_constant_duration(half_life, time_zone, "half_life")?; + // `half_life` is a constant duration so we can safely use `duration_ns()`. + let half_life = half_life.duration_ns(); + let values = &s[0]; + let times = &s[1]; + let times_is_sorted = times + .as_materialized_series() + .is_sorted(Default::default())?; + polars_ops::prelude::ewm_mean_by( + values.as_materialized_series(), + times.as_materialized_series(), + half_life, + times_is_sorted, + ) + .map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/fill_null.rs b/crates/polars-plan/src/dsl/function_expr/fill_null.rs new file mode 100644 index 000000000000..eaa65c2f753e --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/fill_null.rs @@ -0,0 +1,64 @@ +use super::*; + +pub(super) fn fill_null(s: &[Column]) -> PolarsResult { + match (s[0].len(), s[1].len()) { + (a, b) if a == b || b == 1 => { + let series = s[0].clone(); + + // Nothing to fill, so return early + // this is done after casting as the output type must be correct + if series.null_count() == 0 { + return Ok(series); + } + + let fill_value = s[1].clone(); + + // default branch + fn default(series: Column, fill_value: Column) -> PolarsResult { + let mask = series.is_not_null(); + series.zip_with_same_type(&mask, &fill_value) + } + + match series.dtype() { + #[cfg(feature = "dtype-categorical")] + // for Categoricals we first need to check if the category already exist + DataType::Categorical(Some(rev_map), _) => { + if rev_map.is_local() && fill_value.len() == 1 && fill_value.null_count() == 0 { + let fill_av = fill_value.get(0).unwrap(); + let fill_str = fill_av.get_str().unwrap(); + + if let Some(idx) = rev_map.find(fill_str) { + let cats = series.to_physical_repr(); + let mask = cats.is_not_null(); + let out = cats + .zip_with_same_type(&mask, &Column::new(PlSmallStr::EMPTY, &[idx])) + .unwrap(); + unsafe { return out.from_physical_unchecked(series.dtype()) } + } + } + let fill_value = if fill_value.dtype().is_string() { + fill_value + .cast(&DataType::Categorical(None, Default::default())) + .unwrap() + } else { + fill_value + }; + default(series, fill_value) + }, + _ => default(series, fill_value), + } + }, + (1, other_len) => { + if s[0].has_nulls() { + Ok(s[1].clone()) + } else { + Ok(s[0].new_from_index(0, other_len)) + } + }, + (self_len, other_len) => polars_bail!(length_mismatch = "fill_null", self_len, other_len), + } +} + +pub(super) fn coalesce(s: &mut [Column]) -> PolarsResult { + coalesce_columns(s) +} diff --git a/crates/polars-plan/src/dsl/function_expr/fused.rs b/crates/polars-plan/src/dsl/function_expr/fused.rs new file mode 100644 index 000000000000..088078105216 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/fused.rs @@ -0,0 +1,34 @@ +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Copy, Clone, PartialEq, Debug, Hash)] +pub enum FusedOperator { + MultiplyAdd, + SubMultiply, + MultiplySub, +} + +impl Display for FusedOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + FusedOperator::MultiplyAdd => "fma", + FusedOperator::SubMultiply => "fsm", + FusedOperator::MultiplySub => "fms", + }; + write!(f, "{s}") + } +} + +pub(super) fn fused(input: &[Column], op: FusedOperator) -> PolarsResult { + let s0 = &input[0]; + let s1 = &input[1]; + let s2 = &input[2]; + match op { + FusedOperator::MultiplyAdd => Ok(fma_columns(s0, s1, s2)), + FusedOperator::SubMultiply => Ok(fsm_columns(s0, s1, s2)), + FusedOperator::MultiplySub => Ok(fms_columns(s0, s1, s2)), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/index_of.rs b/crates/polars-plan/src/dsl/function_expr/index_of.rs new file mode 100644 index 000000000000..d396d7065091 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/index_of.rs @@ -0,0 +1,61 @@ +use polars_ops::series::index_of as index_of_op; + +use super::*; + +/// Given two columns, find the index of a value (the second column) within the +/// first column. Will use binary search if possible, as an optimization. +pub(super) fn index_of(s: &mut [Column]) -> PolarsResult { + let series = if let Column::Scalar(ref sc) = s[0] { + // We only care about the first value: + &sc.as_single_value_series() + } else { + s[0].as_materialized_series() + }; + + let needle_s = &s[1]; + polars_ensure!( + needle_s.len() == 1, + InvalidOperation: "needle of `index_of` can only contain a single value, found {} values", + needle_s.len() + ); + let needle = Scalar::new( + needle_s.dtype().clone(), + needle_s.get(0).unwrap().into_static(), + ); + + let is_sorted_flag = series.is_sorted_flag(); + let result = match is_sorted_flag { + // If the Series is sorted, we can use an optimized binary search to + // find the value. + IsSorted::Ascending | IsSorted::Descending + if !needle.is_null() && + // search_sorted() doesn't support decimals at the moment. + !series.dtype().is_decimal() => + { + search_sorted( + series, + needle_s.as_materialized_series(), + SearchSortedSide::Left, + IsSorted::Descending == is_sorted_flag, + )? + .get(0) + .and_then(|idx| { + // search_sorted() gives an index even if it's not an exact + // match! So we want to make sure it actually found the value. + if series.get(idx as usize).ok()? == needle.as_any_value() { + Some(idx as usize) + } else { + None + } + }) + }, + _ => index_of_op(series, needle)?, + }; + + let av = match result { + None => AnyValue::Null, + Some(idx) => AnyValue::from(idx as IdxSize), + }; + let scalar = Scalar::new(IDX_DTYPE, av); + Ok(Column::new_scalar(series.name().clone(), scalar, 1)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs new file mode 100644 index 000000000000..c773c127fa05 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -0,0 +1,785 @@ +use arrow::legacy::utils::CustomIterTools; +use polars_ops::chunked_array::list::*; + +use super::*; +use crate::{map, map_as_slice, wrap}; + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ListFunction { + Concat, + #[cfg(feature = "is_in")] + Contains, + #[cfg(feature = "list_drop_nulls")] + DropNulls, + #[cfg(feature = "list_sample")] + Sample { + is_fraction: bool, + with_replacement: bool, + shuffle: bool, + seed: Option, + }, + Slice, + Shift, + Get(bool), + #[cfg(feature = "list_gather")] + Gather(bool), + #[cfg(feature = "list_gather")] + GatherEvery, + #[cfg(feature = "list_count")] + CountMatches, + Sum, + Length, + Max, + Min, + Mean, + Median, + Std(u8), + Var(u8), + ArgMin, + ArgMax, + #[cfg(feature = "diff")] + Diff { + n: i64, + null_behavior: NullBehavior, + }, + Sort(SortOptions), + Reverse, + Unique(bool), + NUnique, + #[cfg(feature = "list_sets")] + SetOperation(SetOperation), + #[cfg(feature = "list_any_all")] + Any, + #[cfg(feature = "list_any_all")] + All, + Join(bool), + #[cfg(feature = "dtype-array")] + ToArray(usize), + #[cfg(feature = "list_to_struct")] + ToStruct(ListToStructArgs), +} + +impl ListFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use ListFunction::*; + match self { + Concat => mapper.map_to_list_supertype(), + #[cfg(feature = "is_in")] + Contains => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "list_drop_nulls")] + DropNulls => mapper.with_same_dtype(), + #[cfg(feature = "list_sample")] + Sample { .. } => mapper.with_same_dtype(), + Slice => mapper.with_same_dtype(), + Shift => mapper.with_same_dtype(), + Get(_) => mapper.map_to_list_and_array_inner_dtype(), + #[cfg(feature = "list_gather")] + Gather(_) => mapper.with_same_dtype(), + #[cfg(feature = "list_gather")] + GatherEvery => mapper.with_same_dtype(), + #[cfg(feature = "list_count")] + CountMatches => mapper.with_dtype(IDX_DTYPE), + Sum => mapper.nested_sum_type(), + Min => mapper.map_to_list_and_array_inner_dtype(), + Max => mapper.map_to_list_and_array_inner_dtype(), + Mean => mapper.nested_mean_median_type(), + Median => mapper.nested_mean_median_type(), + Std(_) => mapper.map_to_float_dtype(), // Need to also have this sometimes marked as float32 or duration.. + Var(_) => mapper.map_to_float_dtype(), + ArgMin => mapper.with_dtype(IDX_DTYPE), + ArgMax => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "diff")] + Diff { .. } => mapper.map_dtype(|dt| { + let inner_dt = match dt.inner_dtype().unwrap() { + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, _) => DataType::Duration(*tu), + #[cfg(feature = "dtype-date")] + DataType::Date => DataType::Duration(TimeUnit::Milliseconds), + #[cfg(feature = "dtype-time")] + DataType::Time => DataType::Duration(TimeUnit::Nanoseconds), + DataType::UInt64 | DataType::UInt32 => DataType::Int64, + DataType::UInt16 => DataType::Int32, + DataType::UInt8 => DataType::Int16, + inner_dt => inner_dt.clone(), + }; + DataType::List(Box::new(inner_dt)) + }), + Sort(_) => mapper.with_same_dtype(), + Reverse => mapper.with_same_dtype(), + Unique(_) => mapper.with_same_dtype(), + Length => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "list_sets")] + SetOperation(_) => mapper.with_same_dtype(), + #[cfg(feature = "list_any_all")] + Any => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "list_any_all")] + All => mapper.with_dtype(DataType::Boolean), + Join(_) => mapper.with_dtype(DataType::String), + #[cfg(feature = "dtype-array")] + ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), + NUnique => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use ListFunction as L; + match self { + L::Concat => FunctionOptions::elementwise(), + #[cfg(feature = "is_in")] + L::Contains => FunctionOptions::elementwise(), + #[cfg(feature = "list_sample")] + L::Sample { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "list_gather")] + L::Gather(_) => FunctionOptions::groupwise(), + #[cfg(feature = "list_gather")] + L::GatherEvery => FunctionOptions::elementwise(), + #[cfg(feature = "list_sets")] + L::SetOperation(_) => FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + cast_options: Some(CastingRules::Supertype(SuperTypeOptions { + flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST, + })), + + flags: FunctionFlags::default() & !FunctionFlags::RETURNS_SCALAR, + ..Default::default() + }, + #[cfg(feature = "diff")] + L::Diff { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "list_drop_nulls")] + L::DropNulls => FunctionOptions::elementwise(), + #[cfg(feature = "list_count")] + L::CountMatches => FunctionOptions::elementwise(), + L::Sum + | L::Slice + | L::Shift + | L::Get(_) + | L::Length + | L::Max + | L::Min + | L::Mean + | L::Median + | L::Std(_) + | L::Var(_) + | L::ArgMin + | L::ArgMax + | L::Sort(_) + | L::Reverse + | L::Unique(_) + | L::Join(_) + | L::NUnique => FunctionOptions::elementwise(), + #[cfg(feature = "list_any_all")] + L::Any | L::All => FunctionOptions::elementwise(), + #[cfg(feature = "dtype-array")] + L::ToArray(_) => FunctionOptions::elementwise(), + #[cfg(feature = "list_to_struct")] + L::ToStruct(ListToStructArgs::FixedWidth(_)) => FunctionOptions::elementwise(), + #[cfg(feature = "list_to_struct")] + L::ToStruct(ListToStructArgs::InferWidth { .. }) => FunctionOptions::groupwise(), + } + } +} + +#[cfg(feature = "dtype-array")] +fn map_list_dtype_to_array_dtype(datatype: &DataType, width: usize) -> PolarsResult { + if let DataType::List(inner) = datatype { + Ok(DataType::Array(inner.clone(), width)) + } else { + polars_bail!(ComputeError: "expected List dtype") + } +} + +impl Display for ListFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use ListFunction::*; + + let name = match self { + Concat => "concat", + #[cfg(feature = "is_in")] + Contains => "contains", + #[cfg(feature = "list_drop_nulls")] + DropNulls => "drop_nulls", + #[cfg(feature = "list_sample")] + Sample { is_fraction, .. } => { + if *is_fraction { + "sample_fraction" + } else { + "sample_n" + } + }, + Slice => "slice", + Shift => "shift", + Get(_) => "get", + #[cfg(feature = "list_gather")] + Gather(_) => "gather", + #[cfg(feature = "list_gather")] + GatherEvery => "gather_every", + #[cfg(feature = "list_count")] + CountMatches => "count_matches", + Sum => "sum", + Min => "min", + Max => "max", + Mean => "mean", + Median => "median", + Std(_) => "std", + Var(_) => "var", + ArgMin => "arg_min", + ArgMax => "arg_max", + #[cfg(feature = "diff")] + Diff { .. } => "diff", + Length => "length", + Sort(_) => "sort", + Reverse => "reverse", + Unique(is_stable) => { + if *is_stable { + "unique_stable" + } else { + "unique" + } + }, + NUnique => "n_unique", + #[cfg(feature = "list_sets")] + SetOperation(s) => return write!(f, "list.{s}"), + #[cfg(feature = "list_any_all")] + Any => "any", + #[cfg(feature = "list_any_all")] + All => "all", + Join(_) => "join", + #[cfg(feature = "dtype-array")] + ToArray(_) => "to_array", + #[cfg(feature = "list_to_struct")] + ToStruct(_) => "to_struct", + }; + write!(f, "list.{name}") + } +} + +impl From for SpecialEq> { + fn from(func: ListFunction) -> Self { + use ListFunction::*; + match func { + Concat => wrap!(concat), + #[cfg(feature = "is_in")] + Contains => wrap!(contains), + #[cfg(feature = "list_drop_nulls")] + DropNulls => map!(drop_nulls), + #[cfg(feature = "list_sample")] + Sample { + is_fraction, + with_replacement, + shuffle, + seed, + } => { + if is_fraction { + map_as_slice!(sample_fraction, with_replacement, shuffle, seed) + } else { + map_as_slice!(sample_n, with_replacement, shuffle, seed) + } + }, + Slice => wrap!(slice), + Shift => map_as_slice!(shift), + Get(null_on_oob) => wrap!(get, null_on_oob), + #[cfg(feature = "list_gather")] + Gather(null_on_oob) => map_as_slice!(gather, null_on_oob), + #[cfg(feature = "list_gather")] + GatherEvery => map_as_slice!(gather_every), + #[cfg(feature = "list_count")] + CountMatches => map_as_slice!(count_matches), + Sum => map!(sum), + Length => map!(length), + Max => map!(max), + Min => map!(min), + Mean => map!(mean), + Median => map!(median), + Std(ddof) => map!(std, ddof), + Var(ddof) => map!(var, ddof), + ArgMin => map!(arg_min), + ArgMax => map!(arg_max), + #[cfg(feature = "diff")] + Diff { n, null_behavior } => map!(diff, n, null_behavior), + Sort(options) => map!(sort, options), + Reverse => map!(reverse), + Unique(is_stable) => map!(unique, is_stable), + #[cfg(feature = "list_sets")] + SetOperation(s) => map_as_slice!(set_operation, s), + #[cfg(feature = "list_any_all")] + Any => map!(lst_any), + #[cfg(feature = "list_any_all")] + All => map!(lst_all), + Join(ignore_nulls) => map_as_slice!(join, ignore_nulls), + #[cfg(feature = "dtype-array")] + ToArray(width) => map!(to_array, width), + NUnique => map!(n_unique), + #[cfg(feature = "list_to_struct")] + ToStruct(args) => map!(to_struct, &args), + } + } +} + +#[cfg(feature = "is_in")] +pub(super) fn contains(args: &mut [Column]) -> PolarsResult> { + let list = &args[0]; + let item = &args[1]; + polars_ensure!(matches!(list.dtype(), DataType::List(_)), + SchemaMismatch: "invalid series dtype: expected `List`, got `{}`", list.dtype(), + ); + polars_ops::prelude::is_in( + item.as_materialized_series(), + list.as_materialized_series(), + true, + ) + .map(|mut ca| { + ca.rename(list.name().clone()); + Some(ca.into_column()) + }) +} + +#[cfg(feature = "list_drop_nulls")] +pub(super) fn drop_nulls(s: &Column) -> PolarsResult { + let list = s.list()?; + Ok(list.lst_drop_nulls().into_column()) +} + +#[cfg(feature = "list_sample")] +pub(super) fn sample_n( + s: &[Column], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let list = s[0].list()?; + let n = &s[1]; + list.lst_sample_n(n.as_materialized_series(), with_replacement, shuffle, seed) + .map(|ok| ok.into_column()) +} + +#[cfg(feature = "list_sample")] +pub(super) fn sample_fraction( + s: &[Column], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let list = s[0].list()?; + let fraction = &s[1]; + list.lst_sample_fraction( + fraction.as_materialized_series(), + with_replacement, + shuffle, + seed, + ) + .map(|ok| ok.into_column()) +} + +fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> { + polars_ensure!( + slice_len == ca_len, + ComputeError: + "shape of the slice '{}' argument: {} does not match that of the list column: {}", + name, slice_len, ca_len + ); + Ok(()) +} + +pub(super) fn shift(s: &[Column]) -> PolarsResult { + let list = s[0].list()?; + let periods = &s[1]; + + list.lst_shift(periods).map(|ok| ok.into_column()) +} + +pub(super) fn slice(args: &mut [Column]) -> PolarsResult> { + let s = &args[0]; + let list_ca = s.list()?; + let offset_s = &args[1]; + let length_s = &args[2]; + + let mut out: ListChunked = match (offset_s.len(), length_s.len()) { + (1, 1) => { + let offset = offset_s.get(0).unwrap().try_extract::()?; + let slice_len = length_s + .get(0) + .unwrap() + .extract::() + .unwrap_or(usize::MAX); + return Ok(Some(list_ca.lst_slice(offset, slice_len).into_column())); + }, + (1, length_slice_len) => { + check_slice_arg_shape(length_slice_len, list_ca.len(), "length")?; + let offset = offset_s.get(0).unwrap().try_extract::()?; + // cast to i64 as it is more likely that it is that dtype + // instead of usize/u64 (we never need that max length) + let length_ca = length_s.cast(&DataType::Int64)?; + let length_ca = length_ca.i64().unwrap(); + + list_ca + .amortized_iter() + .zip(length_ca) + .map(|(opt_s, opt_length)| match (opt_s, opt_length) { + (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)), + _ => None, + }) + .collect_trusted() + }, + (offset_len, 1) => { + check_slice_arg_shape(offset_len, list_ca.len(), "offset")?; + let length_slice = length_s + .get(0) + .unwrap() + .extract::() + .unwrap_or(usize::MAX); + let offset_ca = offset_s.cast(&DataType::Int64)?; + let offset_ca = offset_ca.i64().unwrap(); + list_ca + .amortized_iter() + .zip(offset_ca) + .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) { + (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)), + _ => None, + }) + .collect_trusted() + }, + _ => { + check_slice_arg_shape(offset_s.len(), list_ca.len(), "offset")?; + check_slice_arg_shape(length_s.len(), list_ca.len(), "length")?; + let offset_ca = offset_s.cast(&DataType::Int64)?; + let offset_ca = offset_ca.i64()?; + // cast to i64 as it is more likely that it is that dtype + // instead of usize/u64 (we never need that max length) + let length_ca = length_s.cast(&DataType::Int64)?; + let length_ca = length_ca.i64().unwrap(); + + list_ca + .amortized_iter() + .zip(offset_ca) + .zip(length_ca) + .map( + |((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) { + (Some(s), Some(offset), Some(length)) => { + Some(s.as_ref().slice(offset, length as usize)) + }, + _ => None, + }, + ) + .collect_trusted() + }, + }; + out.rename(s.name().clone()); + Ok(Some(out.into_column())) +} + +pub(super) fn concat(s: &mut [Column]) -> PolarsResult> { + let mut first = std::mem::take(&mut s[0]); + let other = &s[1..]; + + // TODO! don't auto cast here, but implode beforehand. + let mut first_ca = match first.try_list() { + Some(ca) => ca, + None => { + first = first + .reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)]) + .unwrap(); + first.list().unwrap() + }, + } + .clone(); + + if first_ca.len() == 1 && !other.is_empty() { + let max_len = other.iter().map(|s| s.len()).max().unwrap(); + if max_len > 1 { + first_ca = first_ca.new_from_index(0, max_len) + } + } + + first_ca.lst_concat(other).map(|ca| Some(ca.into_column())) +} + +pub(super) fn get(s: &mut [Column], null_on_oob: bool) -> PolarsResult> { + let ca = s[0].list()?; + let index = s[1].cast(&DataType::Int64)?; + let index = index.i64().unwrap(); + + match index.len() { + 1 => { + let index = index.get(0); + if let Some(index) = index { + ca.lst_get(index, null_on_oob).map(Column::from).map(Some) + } else { + Ok(Some(Column::full_null( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + ))) + } + }, + len if len == ca.len() => { + let tmp = ca.rechunk(); + let arr = tmp.downcast_as_array(); + let offsets = arr.offsets().as_slice(); + let take_by = if ca.null_count() == 0 { + index + .iter() + .enumerate() + .map(|(i, opt_idx)| match opt_idx { + Some(idx) => { + let (start, end) = unsafe { + (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) + }; + let offset = if idx >= 0 { start + idx } else { end + idx }; + if offset >= end || offset < start || start == end { + if null_on_oob { + Ok(None) + } else { + polars_bail!(ComputeError: "get index is out of bounds"); + } + } else { + Ok(Some(offset as IdxSize)) + } + }, + None => Ok(None), + }) + .collect::>()? + } else { + index + .iter() + .zip(arr.validity().unwrap()) + .enumerate() + .map(|(i, (opt_idx, valid))| match (valid, opt_idx) { + (true, Some(idx)) => { + let (start, end) = unsafe { + (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) + }; + let offset = if idx >= 0 { start + idx } else { end + idx }; + if offset >= end || offset < start || start == end { + if null_on_oob { + Ok(None) + } else { + polars_bail!(ComputeError: "get index is out of bounds"); + } + } else { + Ok(Some(offset as IdxSize)) + } + }, + _ => Ok(None), + }) + .collect::>()? + }; + let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap(); + unsafe { s.take_unchecked(&take_by) } + .cast(ca.inner_dtype()) + .map(Column::from) + .map(Some) + }, + _ if ca.len() == 1 => { + if ca.null_count() > 0 { + return Ok(Some(Column::full_null( + ca.name().clone(), + index.len(), + ca.inner_dtype(), + ))); + } + let tmp = ca.rechunk(); + let arr = tmp.downcast_as_array(); + let offsets = arr.offsets().as_slice(); + let start = offsets[0]; + let end = offsets[1]; + let out_of_bounds = |offset| offset >= end || offset < start || start == end; + let take_by: IdxCa = index + .iter() + .map(|opt_idx| match opt_idx { + Some(idx) => { + let offset = if idx >= 0 { start + idx } else { end + idx }; + if out_of_bounds(offset) { + if null_on_oob { + Ok(None) + } else { + polars_bail!(ComputeError: "get index is out of bounds"); + } + } else { + let Ok(offset) = IdxSize::try_from(offset) else { + polars_bail!(ComputeError: "get index is out of bounds"); + }; + Ok(Some(offset)) + } + }, + None => Ok(None), + }) + .collect::>()?; + + let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap(); + unsafe { s.take_unchecked(&take_by) } + .cast(ca.inner_dtype()) + .map(Column::from) + .map(Some) + }, + len => polars_bail!( + ComputeError: + "`list.get` expression got an index array of length {} while the list has {} elements", + len, ca.len() + ), + } +} + +#[cfg(feature = "list_gather")] +pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult { + let ca = &args[0]; + let idx = &args[1]; + let ca = ca.list()?; + + if idx.len() == 1 && idx.dtype().is_primitive_numeric() && null_on_oob { + // fast path + let idx = idx.get(0)?.try_extract::()?; + let out = ca.lst_get(idx, null_on_oob).map(Column::from)?; + // make sure we return a list + out.reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)]) + } else { + ca.lst_gather(idx.as_materialized_series(), null_on_oob) + .map(Column::from) + } +} + +#[cfg(feature = "list_gather")] +pub(super) fn gather_every(args: &[Column]) -> PolarsResult { + let ca = &args[0]; + let n = &args[1].strict_cast(&IDX_DTYPE)?; + let offset = &args[2].strict_cast(&IDX_DTYPE)?; + + ca.list()? + .lst_gather_every(n.idx()?, offset.idx()?) + .map(Column::from) +} + +#[cfg(feature = "list_count")] +pub(super) fn count_matches(args: &[Column]) -> PolarsResult { + let s = &args[0]; + let element = &args[1]; + polars_ensure!( + element.len() == 1, + ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}", + element.len() + ); + let ca = s.list()?; + list_count_matches(ca, element.get(0).unwrap()).map(Column::from) +} + +pub(super) fn sum(s: &Column) -> PolarsResult { + s.list()?.lst_sum().map(Column::from) +} + +pub(super) fn length(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_lengths().into_column()) +} + +pub(super) fn max(s: &Column) -> PolarsResult { + s.list()?.lst_max().map(Column::from) +} + +pub(super) fn min(s: &Column) -> PolarsResult { + s.list()?.lst_min().map(Column::from) +} + +pub(super) fn mean(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_mean().into()) +} + +pub(super) fn median(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_median().into()) +} + +pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult { + Ok(s.list()?.lst_std(ddof).into()) +} + +pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult { + Ok(s.list()?.lst_var(ddof).into()) +} + +pub(super) fn arg_min(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_arg_min().into_column()) +} + +pub(super) fn arg_max(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_arg_max().into_column()) +} + +#[cfg(feature = "diff")] +pub(super) fn diff(s: &Column, n: i64, null_behavior: NullBehavior) -> PolarsResult { + Ok(s.list()?.lst_diff(n, null_behavior)?.into_column()) +} + +pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult { + Ok(s.list()?.lst_sort(options)?.into_column()) +} + +pub(super) fn reverse(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_reverse().into_column()) +} + +pub(super) fn unique(s: &Column, is_stable: bool) -> PolarsResult { + if is_stable { + Ok(s.list()?.lst_unique_stable()?.into_column()) + } else { + Ok(s.list()?.lst_unique()?.into_column()) + } +} + +#[cfg(feature = "list_sets")] +pub(super) fn set_operation(s: &[Column], set_type: SetOperation) -> PolarsResult { + let s0 = &s[0]; + let s1 = &s[1]; + + if s0.is_empty() || s1.is_empty() { + return match set_type { + SetOperation::Intersection => { + if s0.is_empty() { + Ok(s0.clone()) + } else { + Ok(s1.clone().with_name(s0.name().clone())) + } + }, + SetOperation::Difference => Ok(s0.clone()), + SetOperation::Union | SetOperation::SymmetricDifference => { + if s0.is_empty() { + Ok(s1.clone().with_name(s0.name().clone())) + } else { + Ok(s0.clone()) + } + }, + }; + } + + list_set_operation(s0.list()?, s1.list()?, set_type).map(|ca| ca.into_column()) +} + +#[cfg(feature = "list_any_all")] +pub(super) fn lst_any(s: &Column) -> PolarsResult { + s.list()?.lst_any().map(Column::from) +} + +#[cfg(feature = "list_any_all")] +pub(super) fn lst_all(s: &Column) -> PolarsResult { + s.list()?.lst_all().map(Column::from) +} + +pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult { + let ca = s[0].list()?; + let separator = s[1].str()?; + Ok(ca.lst_join(separator, ignore_nulls)?.into_column()) +} + +#[cfg(feature = "dtype-array")] +pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult { + let array_dtype = map_list_dtype_to_array_dtype(s.dtype(), width)?; + s.cast(&array_dtype) +} + +#[cfg(feature = "list_to_struct")] +pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult { + Ok(s.list()?.to_struct(args)?.into_series().into()) +} + +pub(super) fn n_unique(s: &Column) -> PolarsResult { + Ok(s.list()?.lst_n_unique()?.into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/log.rs b/crates/polars-plan/src/dsl/function_expr/log.rs new file mode 100644 index 000000000000..23b6b1e970b1 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/log.rs @@ -0,0 +1,23 @@ +use super::*; + +pub(super) fn entropy(s: &Column, base: f64, normalize: bool) -> PolarsResult { + let out = s.as_materialized_series().entropy(base, normalize)?; + if matches!(s.dtype(), DataType::Float32) { + let out = out as f32; + Ok(Column::new(s.name().clone(), [out])) + } else { + Ok(Column::new(s.name().clone(), [out])) + } +} + +pub(super) fn log(s: &Column, base: f64) -> PolarsResult { + Ok(s.as_materialized_series().log(base).into()) +} + +pub(super) fn log1p(s: &Column) -> PolarsResult { + Ok(s.as_materialized_series().log1p().into()) +} + +pub(super) fn exp(s: &Column) -> PolarsResult { + Ok(s.as_materialized_series().exp().into()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs new file mode 100644 index 000000000000..d50446aca13a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -0,0 +1,1397 @@ +#[cfg(feature = "abs")] +mod abs; +#[cfg(feature = "arg_where")] +mod arg_where; +#[cfg(feature = "dtype-array")] +mod array; +mod binary; +#[cfg(feature = "bitwise")] +mod bitwise; +mod boolean; +mod bounds; +#[cfg(feature = "business")] +mod business; +#[cfg(feature = "dtype-categorical")] +pub mod cat; +#[cfg(feature = "round_series")] +mod clip; +#[cfg(feature = "dtype-struct")] +mod coerce; +mod concat; +#[cfg(feature = "cov")] +mod correlation; +#[cfg(feature = "cum_agg")] +mod cum; +#[cfg(feature = "cutqcut")] +mod cut; +#[cfg(feature = "temporal")] +mod datetime; +mod dispatch; +#[cfg(feature = "ewma")] +mod ewm; +#[cfg(feature = "ewma_by")] +mod ewm_by; +mod fill_null; +#[cfg(feature = "fused")] +mod fused; +#[cfg(feature = "index_of")] +mod index_of; +mod list; +#[cfg(feature = "log")] +mod log; +mod nan; +#[cfg(feature = "peaks")] +mod peaks; +#[cfg(feature = "ffi_plugin")] +mod plugin; +pub mod pow; +#[cfg(feature = "random")] +mod random; +#[cfg(feature = "range")] +mod range; +mod repeat; +#[cfg(feature = "rolling_window")] +pub mod rolling; +#[cfg(feature = "rolling_window_by")] +pub mod rolling_by; +#[cfg(feature = "round_series")] +mod round; +#[cfg(feature = "row_hash")] +mod row_hash; +pub(super) mod schema; +#[cfg(feature = "search_sorted")] +mod search_sorted; +mod shift_and_fill; +mod shrink_type; +#[cfg(feature = "sign")] +mod sign; +#[cfg(feature = "strings")] +mod strings; +#[cfg(feature = "dtype-struct")] +mod struct_; +#[cfg(feature = "temporal")] +mod temporal; +#[cfg(feature = "trigonometry")] +pub mod trigonometry; +mod unique; + +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; + +#[cfg(feature = "dtype-array")] +pub use array::ArrayFunction; +#[cfg(feature = "cov")] +pub use correlation::CorrelationMethod; +#[cfg(feature = "fused")] +pub use fused::FusedOperator; +pub use list::ListFunction; +pub use polars_core::datatypes::ReshapeDimension; +use polars_core::prelude::*; +#[cfg(feature = "random")] +pub use random::RandomMethod; +use schema::FieldsMapper; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +pub use self::binary::BinaryFunction; +#[cfg(feature = "bitwise")] +pub use self::bitwise::BitwiseFunction; +pub use self::boolean::BooleanFunction; +#[cfg(feature = "business")] +pub use self::business::BusinessFunction; +#[cfg(feature = "dtype-categorical")] +pub use self::cat::CategoricalFunction; +#[cfg(feature = "temporal")] +pub use self::datetime::TemporalFunction; +pub use self::pow::PowFunction; +#[cfg(feature = "range")] +pub(super) use self::range::RangeFunction; +#[cfg(feature = "rolling_window")] +pub(super) use self::rolling::RollingFunction; +#[cfg(feature = "rolling_window_by")] +pub(super) use self::rolling_by::RollingFunctionBy; +#[cfg(feature = "strings")] +pub use self::strings::StringFunction; +#[cfg(feature = "dtype-struct")] +pub use self::struct_::StructFunction; +#[cfg(feature = "trigonometry")] +pub use self::trigonometry::TrigonometricFunction; +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug)] +pub enum FunctionExpr { + // Namespaces + #[cfg(feature = "dtype-array")] + ArrayExpr(ArrayFunction), + BinaryExpr(BinaryFunction), + #[cfg(feature = "dtype-categorical")] + Categorical(CategoricalFunction), + ListExpr(ListFunction), + #[cfg(feature = "strings")] + StringExpr(StringFunction), + #[cfg(feature = "dtype-struct")] + StructExpr(StructFunction), + #[cfg(feature = "temporal")] + TemporalExpr(TemporalFunction), + #[cfg(feature = "bitwise")] + Bitwise(BitwiseFunction), + + // Other expressions + Boolean(BooleanFunction), + #[cfg(feature = "business")] + Business(BusinessFunction), + #[cfg(feature = "abs")] + Abs, + Negate, + #[cfg(feature = "hist")] + Hist { + bin_count: Option, + include_category: bool, + include_breakpoint: bool, + }, + NullCount, + Pow(PowFunction), + #[cfg(feature = "row_hash")] + Hash(u64, u64, u64, u64), + #[cfg(feature = "arg_where")] + ArgWhere, + #[cfg(feature = "index_of")] + IndexOf, + #[cfg(feature = "search_sorted")] + SearchSorted(SearchSortedSide), + #[cfg(feature = "range")] + Range(RangeFunction), + #[cfg(feature = "trigonometry")] + Trigonometry(TrigonometricFunction), + #[cfg(feature = "trigonometry")] + Atan2, + #[cfg(feature = "sign")] + Sign, + FillNull, + FillNullWithStrategy(FillNullStrategy), + #[cfg(feature = "rolling_window")] + RollingExpr(RollingFunction), + #[cfg(feature = "rolling_window_by")] + RollingExprBy(RollingFunctionBy), + ShiftAndFill, + Shift, + DropNans, + DropNulls, + #[cfg(feature = "mode")] + Mode, + #[cfg(feature = "moment")] + Skew(bool), + #[cfg(feature = "moment")] + Kurtosis(bool, bool), + #[cfg(feature = "dtype-array")] + Reshape(Vec), + #[cfg(feature = "repeat_by")] + RepeatBy, + ArgUnique, + #[cfg(feature = "rank")] + Rank { + options: RankOptions, + seed: Option, + }, + Repeat, + #[cfg(feature = "round_series")] + Clip { + has_min: bool, + has_max: bool, + }, + #[cfg(feature = "dtype-struct")] + AsStruct, + #[cfg(feature = "top_k")] + TopK { + descending: bool, + }, + #[cfg(feature = "top_k")] + TopKBy { + descending: Vec, + }, + #[cfg(feature = "cum_agg")] + CumCount { + reverse: bool, + }, + #[cfg(feature = "cum_agg")] + CumSum { + reverse: bool, + }, + #[cfg(feature = "cum_agg")] + CumProd { + reverse: bool, + }, + #[cfg(feature = "cum_agg")] + CumMin { + reverse: bool, + }, + #[cfg(feature = "cum_agg")] + CumMax { + reverse: bool, + }, + Reverse, + #[cfg(feature = "dtype-struct")] + ValueCounts { + sort: bool, + parallel: bool, + name: PlSmallStr, + normalize: bool, + }, + #[cfg(feature = "unique_counts")] + UniqueCounts, + #[cfg(feature = "approx_unique")] + ApproxNUnique, + Coalesce, + ShrinkType, + #[cfg(feature = "diff")] + Diff(NullBehavior), + #[cfg(feature = "pct_change")] + PctChange, + #[cfg(feature = "interpolate")] + Interpolate(InterpolationMethod), + #[cfg(feature = "interpolate_by")] + InterpolateBy, + #[cfg(feature = "log")] + Entropy { + base: f64, + normalize: bool, + }, + #[cfg(feature = "log")] + Log { + base: f64, + }, + #[cfg(feature = "log")] + Log1p, + #[cfg(feature = "log")] + Exp, + Unique(bool), + #[cfg(feature = "round_series")] + Round { + decimals: u32, + }, + #[cfg(feature = "round_series")] + RoundSF { + digits: i32, + }, + #[cfg(feature = "round_series")] + Floor, + #[cfg(feature = "round_series")] + Ceil, + UpperBound, + LowerBound, + #[cfg(feature = "fused")] + Fused(fused::FusedOperator), + ConcatExpr(bool), + #[cfg(feature = "cov")] + Correlation { + method: correlation::CorrelationMethod, + }, + #[cfg(feature = "peaks")] + PeakMin, + #[cfg(feature = "peaks")] + PeakMax, + #[cfg(feature = "cutqcut")] + Cut { + breaks: Vec, + labels: Option>, + left_closed: bool, + include_breaks: bool, + }, + #[cfg(feature = "cutqcut")] + QCut { + probs: Vec, + labels: Option>, + left_closed: bool, + allow_duplicates: bool, + include_breaks: bool, + }, + #[cfg(feature = "rle")] + RLE, + #[cfg(feature = "rle")] + RLEID, + ToPhysical, + #[cfg(feature = "random")] + Random { + method: random::RandomMethod, + seed: Option, + }, + SetSortedFlag(IsSorted), + #[cfg(feature = "ffi_plugin")] + /// Creating this node is unsafe + /// This will lead to calls over FFI. + FfiPlugin { + flags: FunctionOptions, + /// Shared library. + lib: PlSmallStr, + /// Identifier in the shared lib. + symbol: PlSmallStr, + /// Pickle serialized keyword arguments. + kwargs: Arc<[u8]>, + }, + MaxHorizontal, + MinHorizontal, + SumHorizontal { + ignore_nulls: bool, + }, + MeanHorizontal { + ignore_nulls: bool, + }, + #[cfg(feature = "ewma")] + EwmMean { + options: EWMOptions, + }, + #[cfg(feature = "ewma_by")] + EwmMeanBy { + half_life: Duration, + }, + #[cfg(feature = "ewma")] + EwmStd { + options: EWMOptions, + }, + #[cfg(feature = "ewma")] + EwmVar { + options: EWMOptions, + }, + #[cfg(feature = "replace")] + Replace, + #[cfg(feature = "replace")] + ReplaceStrict { + return_dtype: Option, + }, + GatherEvery { + n: usize, + offset: usize, + }, + #[cfg(feature = "reinterpret")] + Reinterpret(bool), + ExtendConstant, +} + +impl Hash for FunctionExpr { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + use FunctionExpr::*; + match self { + // Namespaces + #[cfg(feature = "dtype-array")] + ArrayExpr(f) => f.hash(state), + BinaryExpr(f) => f.hash(state), + #[cfg(feature = "dtype-categorical")] + Categorical(f) => f.hash(state), + ListExpr(f) => f.hash(state), + #[cfg(feature = "strings")] + StringExpr(f) => f.hash(state), + #[cfg(feature = "dtype-struct")] + StructExpr(f) => f.hash(state), + #[cfg(feature = "temporal")] + TemporalExpr(f) => f.hash(state), + #[cfg(feature = "bitwise")] + Bitwise(f) => f.hash(state), + + // Other expressions + Boolean(f) => f.hash(state), + #[cfg(feature = "business")] + Business(f) => f.hash(state), + Pow(f) => f.hash(state), + #[cfg(feature = "index_of")] + IndexOf => {}, + #[cfg(feature = "search_sorted")] + SearchSorted(f) => f.hash(state), + #[cfg(feature = "random")] + Random { method, .. } => method.hash(state), + #[cfg(feature = "cov")] + Correlation { method, .. } => method.hash(state), + #[cfg(feature = "range")] + Range(f) => f.hash(state), + #[cfg(feature = "trigonometry")] + Trigonometry(f) => f.hash(state), + #[cfg(feature = "fused")] + Fused(f) => f.hash(state), + #[cfg(feature = "diff")] + Diff(null_behavior) => null_behavior.hash(state), + #[cfg(feature = "interpolate")] + Interpolate(f) => f.hash(state), + #[cfg(feature = "interpolate_by")] + InterpolateBy => {}, + #[cfg(feature = "ffi_plugin")] + FfiPlugin { + flags: _, + lib, + symbol, + kwargs, + } => { + kwargs.hash(state); + lib.hash(state); + symbol.hash(state); + }, + MaxHorizontal + | MinHorizontal + | SumHorizontal { .. } + | MeanHorizontal { .. } + | DropNans + | DropNulls + | Reverse + | ArgUnique + | Shift + | ShiftAndFill => {}, + #[cfg(feature = "mode")] + Mode => {}, + #[cfg(feature = "abs")] + Abs => {}, + Negate => {}, + NullCount => {}, + #[cfg(feature = "arg_where")] + ArgWhere => {}, + #[cfg(feature = "trigonometry")] + Atan2 => {}, + #[cfg(feature = "dtype-struct")] + AsStruct => {}, + #[cfg(feature = "sign")] + Sign => {}, + #[cfg(feature = "row_hash")] + Hash(a, b, c, d) => (a, b, c, d).hash(state), + FillNull => {}, + #[cfg(feature = "rolling_window")] + RollingExpr(f) => { + f.hash(state); + }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(f) => { + f.hash(state); + }, + #[cfg(feature = "moment")] + Skew(a) => a.hash(state), + #[cfg(feature = "moment")] + Kurtosis(a, b) => { + a.hash(state); + b.hash(state); + }, + Repeat => {}, + #[cfg(feature = "rank")] + Rank { options, seed } => { + options.hash(state); + seed.hash(state); + }, + #[cfg(feature = "round_series")] + Clip { has_min, has_max } => { + has_min.hash(state); + has_max.hash(state); + }, + #[cfg(feature = "top_k")] + TopK { descending } => descending.hash(state), + #[cfg(feature = "cum_agg")] + CumCount { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + CumSum { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + CumProd { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + CumMin { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + CumMax { reverse } => reverse.hash(state), + #[cfg(feature = "dtype-struct")] + ValueCounts { + sort, + parallel, + name, + normalize, + } => { + sort.hash(state); + parallel.hash(state); + name.hash(state); + normalize.hash(state); + }, + #[cfg(feature = "unique_counts")] + UniqueCounts => {}, + #[cfg(feature = "approx_unique")] + ApproxNUnique => {}, + Coalesce => {}, + ShrinkType => {}, + #[cfg(feature = "pct_change")] + PctChange => {}, + #[cfg(feature = "log")] + Entropy { base, normalize } => { + base.to_bits().hash(state); + normalize.hash(state); + }, + #[cfg(feature = "log")] + Log { base } => base.to_bits().hash(state), + #[cfg(feature = "log")] + Log1p => {}, + #[cfg(feature = "log")] + Exp => {}, + Unique(a) => a.hash(state), + #[cfg(feature = "round_series")] + Round { decimals } => decimals.hash(state), + #[cfg(feature = "round_series")] + FunctionExpr::RoundSF { digits } => digits.hash(state), + #[cfg(feature = "round_series")] + FunctionExpr::Floor => {}, + #[cfg(feature = "round_series")] + Ceil => {}, + UpperBound => {}, + LowerBound => {}, + ConcatExpr(a) => a.hash(state), + #[cfg(feature = "peaks")] + PeakMin => {}, + #[cfg(feature = "peaks")] + PeakMax => {}, + #[cfg(feature = "cutqcut")] + Cut { + breaks, + labels, + left_closed, + include_breaks, + } => { + let slice = bytemuck::cast_slice::<_, u64>(breaks); + slice.hash(state); + labels.hash(state); + left_closed.hash(state); + include_breaks.hash(state); + }, + #[cfg(feature = "dtype-array")] + Reshape(dims) => dims.hash(state), + #[cfg(feature = "repeat_by")] + RepeatBy => {}, + #[cfg(feature = "cutqcut")] + QCut { + probs, + labels, + left_closed, + allow_duplicates, + include_breaks, + } => { + let slice = bytemuck::cast_slice::<_, u64>(probs); + slice.hash(state); + labels.hash(state); + left_closed.hash(state); + allow_duplicates.hash(state); + include_breaks.hash(state); + }, + #[cfg(feature = "rle")] + RLE => {}, + #[cfg(feature = "rle")] + RLEID => {}, + ToPhysical => {}, + SetSortedFlag(is_sorted) => is_sorted.hash(state), + #[cfg(feature = "ewma")] + EwmMean { options } => options.hash(state), + #[cfg(feature = "ewma_by")] + EwmMeanBy { half_life } => (half_life).hash(state), + #[cfg(feature = "ewma")] + EwmStd { options } => options.hash(state), + #[cfg(feature = "ewma")] + EwmVar { options } => options.hash(state), + #[cfg(feature = "hist")] + Hist { + bin_count, + include_category, + include_breakpoint, + } => { + bin_count.hash(state); + include_category.hash(state); + include_breakpoint.hash(state); + }, + #[cfg(feature = "replace")] + Replace => {}, + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype } => return_dtype.hash(state), + FillNullWithStrategy(strategy) => strategy.hash(state), + GatherEvery { n, offset } => (n, offset).hash(state), + #[cfg(feature = "reinterpret")] + Reinterpret(signed) => signed.hash(state), + ExtendConstant => {}, + #[cfg(feature = "top_k")] + TopKBy { descending } => descending.hash(state), + } + } +} + +impl Display for FunctionExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use FunctionExpr::*; + let s = match self { + // Namespaces + #[cfg(feature = "dtype-array")] + ArrayExpr(func) => return write!(f, "{func}"), + BinaryExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "dtype-categorical")] + Categorical(func) => return write!(f, "{func}"), + ListExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "strings")] + StringExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "dtype-struct")] + StructExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "temporal")] + TemporalExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "bitwise")] + Bitwise(func) => return write!(f, "bitwise_{func}"), + + // Other expressions + Boolean(func) => return write!(f, "{func}"), + #[cfg(feature = "business")] + Business(func) => return write!(f, "{func}"), + #[cfg(feature = "abs")] + Abs => "abs", + Negate => "negate", + NullCount => "null_count", + Pow(func) => return write!(f, "{func}"), + #[cfg(feature = "row_hash")] + Hash(_, _, _, _) => "hash", + #[cfg(feature = "arg_where")] + ArgWhere => "arg_where", + #[cfg(feature = "index_of")] + IndexOf => "index_of", + #[cfg(feature = "search_sorted")] + SearchSorted(_) => "search_sorted", + #[cfg(feature = "range")] + Range(func) => return write!(f, "{func}"), + #[cfg(feature = "trigonometry")] + Trigonometry(func) => return write!(f, "{func}"), + #[cfg(feature = "trigonometry")] + Atan2 => return write!(f, "arctan2"), + #[cfg(feature = "sign")] + Sign => "sign", + FillNull => "fill_null", + #[cfg(feature = "rolling_window")] + RollingExpr(func, ..) => return write!(f, "{func}"), + #[cfg(feature = "rolling_window_by")] + RollingExprBy(func, ..) => return write!(f, "{func}"), + ShiftAndFill => "shift_and_fill", + DropNans => "drop_nans", + DropNulls => "drop_nulls", + #[cfg(feature = "mode")] + Mode => "mode", + #[cfg(feature = "moment")] + Skew(_) => "skew", + #[cfg(feature = "moment")] + Kurtosis(..) => "kurtosis", + ArgUnique => "arg_unique", + Repeat => "repeat", + #[cfg(feature = "rank")] + Rank { .. } => "rank", + #[cfg(feature = "round_series")] + Clip { has_min, has_max } => match (has_min, has_max) { + (true, true) => "clip", + (false, true) => "clip_max", + (true, false) => "clip_min", + _ => unreachable!(), + }, + #[cfg(feature = "dtype-struct")] + AsStruct => "as_struct", + #[cfg(feature = "top_k")] + TopK { descending } => { + if *descending { + "bottom_k" + } else { + "top_k" + } + }, + #[cfg(feature = "top_k")] + TopKBy { .. } => "top_k_by", + Shift => "shift", + #[cfg(feature = "cum_agg")] + CumCount { .. } => "cum_count", + #[cfg(feature = "cum_agg")] + CumSum { .. } => "cum_sum", + #[cfg(feature = "cum_agg")] + CumProd { .. } => "cum_prod", + #[cfg(feature = "cum_agg")] + CumMin { .. } => "cum_min", + #[cfg(feature = "cum_agg")] + CumMax { .. } => "cum_max", + #[cfg(feature = "dtype-struct")] + ValueCounts { .. } => "value_counts", + #[cfg(feature = "unique_counts")] + UniqueCounts => "unique_counts", + Reverse => "reverse", + #[cfg(feature = "approx_unique")] + ApproxNUnique => "approx_n_unique", + Coalesce => "coalesce", + ShrinkType => "shrink_dtype", + #[cfg(feature = "diff")] + Diff(_) => "diff", + #[cfg(feature = "pct_change")] + PctChange => "pct_change", + #[cfg(feature = "interpolate")] + Interpolate(_) => "interpolate", + #[cfg(feature = "interpolate_by")] + InterpolateBy => "interpolate_by", + #[cfg(feature = "log")] + Entropy { .. } => "entropy", + #[cfg(feature = "log")] + Log { .. } => "log", + #[cfg(feature = "log")] + Log1p => "log1p", + #[cfg(feature = "log")] + Exp => "exp", + Unique(stable) => { + if *stable { + "unique_stable" + } else { + "unique" + } + }, + #[cfg(feature = "round_series")] + Round { .. } => "round", + #[cfg(feature = "round_series")] + RoundSF { .. } => "round_sig_figs", + #[cfg(feature = "round_series")] + Floor => "floor", + #[cfg(feature = "round_series")] + Ceil => "ceil", + UpperBound => "upper_bound", + LowerBound => "lower_bound", + #[cfg(feature = "fused")] + Fused(fused) => return Display::fmt(fused, f), + ConcatExpr(_) => "concat_expr", + #[cfg(feature = "cov")] + Correlation { method, .. } => return Display::fmt(method, f), + #[cfg(feature = "peaks")] + PeakMin => "peak_min", + #[cfg(feature = "peaks")] + PeakMax => "peak_max", + #[cfg(feature = "cutqcut")] + Cut { .. } => "cut", + #[cfg(feature = "cutqcut")] + QCut { .. } => "qcut", + #[cfg(feature = "dtype-array")] + Reshape(_) => "reshape", + #[cfg(feature = "repeat_by")] + RepeatBy => "repeat_by", + #[cfg(feature = "rle")] + RLE => "rle", + #[cfg(feature = "rle")] + RLEID => "rle_id", + ToPhysical => "to_physical", + #[cfg(feature = "random")] + Random { method, .. } => method.into(), + SetSortedFlag(_) => "set_sorted", + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol, .. } => return write!(f, "{lib}:{symbol}"), + MaxHorizontal => "max_horizontal", + MinHorizontal => "min_horizontal", + SumHorizontal { .. } => "sum_horizontal", + MeanHorizontal { .. } => "mean_horizontal", + #[cfg(feature = "ewma")] + EwmMean { .. } => "ewm_mean", + #[cfg(feature = "ewma_by")] + EwmMeanBy { .. } => "ewm_mean_by", + #[cfg(feature = "ewma")] + EwmStd { .. } => "ewm_std", + #[cfg(feature = "ewma")] + EwmVar { .. } => "ewm_var", + #[cfg(feature = "hist")] + Hist { .. } => "hist", + #[cfg(feature = "replace")] + Replace => "replace", + #[cfg(feature = "replace")] + ReplaceStrict { .. } => "replace_strict", + FillNullWithStrategy(_) => "fill_null_with_strategy", + GatherEvery { .. } => "gather_every", + #[cfg(feature = "reinterpret")] + Reinterpret(_) => "reinterpret", + ExtendConstant => "extend_constant", + }; + write!(f, "{s}") + } +} + +#[macro_export] +macro_rules! wrap { + ($e:expr) => { + SpecialEq::new(Arc::new($e)) + }; + + ($e:expr, $($args:expr),*) => {{ + let f = move |s: &mut [Column]| { + $e(s, $($args),*) + }; + + SpecialEq::new(Arc::new(f)) + }}; +} + +/// `Fn(&[Column], args)` +/// * all expression arguments are in the slice. +/// * the first element is the root expression. +#[macro_export] +macro_rules! map_as_slice { + ($func:path) => {{ + let f = move |s: &mut [Column]| { + $func(s).map(Some) + }; + + SpecialEq::new(Arc::new(f)) + }}; + + ($func:path, $($args:expr),*) => {{ + let f = move |s: &mut [Column]| { + $func(s, $($args),*).map(Some) + }; + + SpecialEq::new(Arc::new(f)) + }}; +} + +/// * `FnOnce(Series)` +/// * `FnOnce(Series, args)` +#[macro_export] +macro_rules! map_owned { + ($func:path) => {{ + let f = move |c: &mut [Column]| { + let c = std::mem::take(&mut c[0]); + $func(c).map(Some) + }; + + SpecialEq::new(Arc::new(f)) + }}; + + ($func:path, $($args:expr),*) => {{ + let f = move |c: &mut [Column]| { + let c = std::mem::take(&mut c[0]); + $func(c, $($args),*).map(Some) + }; + + SpecialEq::new(Arc::new(f)) + }}; +} + +/// `Fn(&Series, args)` +#[macro_export] +macro_rules! map { + ($func:path) => {{ + let f = move |c: &mut [Column]| { + let c = &c[0]; + $func(c).map(Some) + }; + + SpecialEq::new(Arc::new(f)) + }}; + + ($func:path, $($args:expr),*) => {{ + let f = move |c: &mut [Column]| { + let c = &c[0]; + $func(c, $($args),*).map(Some) + }; + + SpecialEq::new(Arc::new(f)) + }}; +} + +impl From for SpecialEq> { + fn from(func: FunctionExpr) -> Self { + use FunctionExpr::*; + match func { + // Namespaces + #[cfg(feature = "dtype-array")] + ArrayExpr(func) => func.into(), + BinaryExpr(func) => func.into(), + #[cfg(feature = "dtype-categorical")] + Categorical(func) => func.into(), + ListExpr(func) => func.into(), + #[cfg(feature = "strings")] + StringExpr(func) => func.into(), + #[cfg(feature = "dtype-struct")] + StructExpr(func) => func.into(), + #[cfg(feature = "temporal")] + TemporalExpr(func) => func.into(), + #[cfg(feature = "bitwise")] + Bitwise(func) => func.into(), + + // Other expressions + Boolean(func) => func.into(), + #[cfg(feature = "business")] + Business(func) => func.into(), + #[cfg(feature = "abs")] + Abs => map!(abs::abs), + Negate => map!(dispatch::negate), + NullCount => { + let f = |s: &mut [Column]| { + let s = &s[0]; + Ok(Some(Column::new( + s.name().clone(), + [s.null_count() as IdxSize], + ))) + }; + wrap!(f) + }, + Pow(func) => match func { + PowFunction::Generic => wrap!(pow::pow), + PowFunction::Sqrt => map!(pow::sqrt), + PowFunction::Cbrt => map!(pow::cbrt), + }, + #[cfg(feature = "row_hash")] + Hash(k0, k1, k2, k3) => { + map!(row_hash::row_hash, k0, k1, k2, k3) + }, + #[cfg(feature = "arg_where")] + ArgWhere => { + wrap!(arg_where::arg_where) + }, + #[cfg(feature = "index_of")] + IndexOf => { + map_as_slice!(index_of::index_of) + }, + #[cfg(feature = "search_sorted")] + SearchSorted(side) => { + map_as_slice!(search_sorted::search_sorted_impl, side) + }, + #[cfg(feature = "range")] + Range(func) => func.into(), + + #[cfg(feature = "trigonometry")] + Trigonometry(trig_function) => { + map!(trigonometry::apply_trigonometric_function, trig_function) + }, + #[cfg(feature = "trigonometry")] + Atan2 => { + wrap!(trigonometry::apply_arctan2) + }, + + #[cfg(feature = "sign")] + Sign => { + map!(sign::sign) + }, + FillNull => { + map_as_slice!(fill_null::fill_null) + }, + #[cfg(feature = "rolling_window")] + RollingExpr(f) => { + use RollingFunction::*; + match f { + Min(options) => map!(rolling::rolling_min, options.clone()), + Max(options) => map!(rolling::rolling_max, options.clone()), + Mean(options) => map!(rolling::rolling_mean, options.clone()), + Sum(options) => map!(rolling::rolling_sum, options.clone()), + Quantile(options) => map!(rolling::rolling_quantile, options.clone()), + Var(options) => map!(rolling::rolling_var, options.clone()), + Std(options) => map!(rolling::rolling_std, options.clone()), + #[cfg(feature = "moment")] + Skew(options) => map!(rolling::rolling_skew, options.clone()), + #[cfg(feature = "moment")] + Kurtosis(options) => map!(rolling::rolling_kurtosis, options.clone()), + #[cfg(feature = "cov")] + CorrCov { + rolling_options, + corr_cov_options, + is_corr, + } => { + map_as_slice!( + rolling::rolling_corr_cov, + rolling_options.clone(), + corr_cov_options, + is_corr + ) + }, + } + }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(f) => { + use RollingFunctionBy::*; + match f { + MinBy(options) => map_as_slice!(rolling_by::rolling_min_by, options.clone()), + MaxBy(options) => map_as_slice!(rolling_by::rolling_max_by, options.clone()), + MeanBy(options) => map_as_slice!(rolling_by::rolling_mean_by, options.clone()), + SumBy(options) => map_as_slice!(rolling_by::rolling_sum_by, options.clone()), + QuantileBy(options) => { + map_as_slice!(rolling_by::rolling_quantile_by, options.clone()) + }, + VarBy(options) => map_as_slice!(rolling_by::rolling_var_by, options.clone()), + StdBy(options) => map_as_slice!(rolling_by::rolling_std_by, options.clone()), + } + }, + #[cfg(feature = "hist")] + Hist { + bin_count, + include_category, + include_breakpoint, + } => { + map_as_slice!( + dispatch::hist, + bin_count, + include_category, + include_breakpoint + ) + }, + ShiftAndFill => { + map_as_slice!(shift_and_fill::shift_and_fill) + }, + DropNans => map_owned!(nan::drop_nans), + DropNulls => map!(dispatch::drop_nulls), + #[cfg(feature = "round_series")] + Clip { has_min, has_max } => { + map_as_slice!(clip::clip, has_min, has_max) + }, + #[cfg(feature = "mode")] + Mode => map!(dispatch::mode), + #[cfg(feature = "moment")] + Skew(bias) => map!(dispatch::skew, bias), + #[cfg(feature = "moment")] + Kurtosis(fisher, bias) => map!(dispatch::kurtosis, fisher, bias), + ArgUnique => map!(dispatch::arg_unique), + Repeat => map_as_slice!(repeat::repeat), + #[cfg(feature = "rank")] + Rank { options, seed } => map!(dispatch::rank, options, seed), + #[cfg(feature = "dtype-struct")] + AsStruct => { + map_as_slice!(coerce::as_struct) + }, + #[cfg(feature = "top_k")] + TopK { descending } => { + map_as_slice!(top_k, descending) + }, + #[cfg(feature = "top_k")] + TopKBy { descending } => map_as_slice!(top_k_by, descending.clone()), + Shift => map_as_slice!(shift_and_fill::shift), + #[cfg(feature = "cum_agg")] + CumCount { reverse } => map!(cum::cum_count, reverse), + #[cfg(feature = "cum_agg")] + CumSum { reverse } => map!(cum::cum_sum, reverse), + #[cfg(feature = "cum_agg")] + CumProd { reverse } => map!(cum::cum_prod, reverse), + #[cfg(feature = "cum_agg")] + CumMin { reverse } => map!(cum::cum_min, reverse), + #[cfg(feature = "cum_agg")] + CumMax { reverse } => map!(cum::cum_max, reverse), + #[cfg(feature = "dtype-struct")] + ValueCounts { + sort, + parallel, + name, + normalize, + } => map!( + dispatch::value_counts, + sort, + parallel, + name.clone(), + normalize + ), + #[cfg(feature = "unique_counts")] + UniqueCounts => map!(dispatch::unique_counts), + Reverse => map!(dispatch::reverse), + #[cfg(feature = "approx_unique")] + ApproxNUnique => map!(dispatch::approx_n_unique), + Coalesce => map_as_slice!(fill_null::coalesce), + ShrinkType => map_owned!(shrink_type::shrink), + #[cfg(feature = "diff")] + Diff(null_behavior) => map_as_slice!(dispatch::diff, null_behavior), + #[cfg(feature = "pct_change")] + PctChange => map_as_slice!(dispatch::pct_change), + #[cfg(feature = "interpolate")] + Interpolate(method) => { + map!(dispatch::interpolate, method) + }, + #[cfg(feature = "interpolate_by")] + InterpolateBy => { + map_as_slice!(dispatch::interpolate_by) + }, + #[cfg(feature = "log")] + Entropy { base, normalize } => map!(log::entropy, base, normalize), + #[cfg(feature = "log")] + Log { base } => map!(log::log, base), + #[cfg(feature = "log")] + Log1p => map!(log::log1p), + #[cfg(feature = "log")] + Exp => map!(log::exp), + Unique(stable) => map!(unique::unique, stable), + #[cfg(feature = "round_series")] + Round { decimals } => map!(round::round, decimals), + #[cfg(feature = "round_series")] + RoundSF { digits } => map!(round::round_sig_figs, digits), + #[cfg(feature = "round_series")] + Floor => map!(round::floor), + #[cfg(feature = "round_series")] + Ceil => map!(round::ceil), + UpperBound => map!(bounds::upper_bound), + LowerBound => map!(bounds::lower_bound), + #[cfg(feature = "fused")] + Fused(op) => map_as_slice!(fused::fused, op), + ConcatExpr(rechunk) => map_as_slice!(concat::concat_expr, rechunk), + #[cfg(feature = "cov")] + Correlation { method } => map_as_slice!(correlation::corr, method), + #[cfg(feature = "peaks")] + PeakMin => map!(peaks::peak_min), + #[cfg(feature = "peaks")] + PeakMax => map!(peaks::peak_max), + #[cfg(feature = "repeat_by")] + RepeatBy => map_as_slice!(dispatch::repeat_by), + #[cfg(feature = "dtype-array")] + Reshape(dims) => map!(dispatch::reshape, &dims), + #[cfg(feature = "cutqcut")] + Cut { + breaks, + labels, + left_closed, + include_breaks, + } => map!( + cut::cut, + breaks.clone(), + labels.clone(), + left_closed, + include_breaks + ), + #[cfg(feature = "cutqcut")] + QCut { + probs, + labels, + left_closed, + allow_duplicates, + include_breaks, + } => map!( + cut::qcut, + probs.clone(), + labels.clone(), + left_closed, + allow_duplicates, + include_breaks + ), + #[cfg(feature = "rle")] + RLE => map!(rle), + #[cfg(feature = "rle")] + RLEID => map!(rle_id), + ToPhysical => map!(dispatch::to_physical), + #[cfg(feature = "random")] + Random { method, seed } => { + use RandomMethod::*; + match method { + Shuffle => map!(random::shuffle, seed), + Sample { + is_fraction, + with_replacement, + shuffle, + } => { + if is_fraction { + map_as_slice!(random::sample_frac, with_replacement, shuffle, seed) + } else { + map_as_slice!(random::sample_n, with_replacement, shuffle, seed) + } + }, + } + }, + SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { + flags: _, + lib, + symbol, + kwargs, + } => unsafe { + map_as_slice!( + plugin::call_plugin, + lib.as_ref(), + symbol.as_ref(), + kwargs.as_ref() + ) + }, + MaxHorizontal => wrap!(dispatch::max_horizontal), + MinHorizontal => wrap!(dispatch::min_horizontal), + SumHorizontal { ignore_nulls } => wrap!(dispatch::sum_horizontal, ignore_nulls), + MeanHorizontal { ignore_nulls } => wrap!(dispatch::mean_horizontal, ignore_nulls), + #[cfg(feature = "ewma")] + EwmMean { options } => map!(ewm::ewm_mean, options), + #[cfg(feature = "ewma_by")] + EwmMeanBy { half_life } => map_as_slice!(ewm_by::ewm_mean_by, half_life), + #[cfg(feature = "ewma")] + EwmStd { options } => map!(ewm::ewm_std, options), + #[cfg(feature = "ewma")] + EwmVar { options } => map!(ewm::ewm_var, options), + #[cfg(feature = "replace")] + Replace => { + map_as_slice!(dispatch::replace) + }, + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype } => { + map_as_slice!(dispatch::replace_strict, return_dtype.clone()) + }, + + FillNullWithStrategy(strategy) => map!(dispatch::fill_null_with_strategy, strategy), + GatherEvery { n, offset } => map!(dispatch::gather_every, n, offset), + #[cfg(feature = "reinterpret")] + Reinterpret(signed) => map!(dispatch::reinterpret, signed), + ExtendConstant => map_as_slice!(dispatch::extend_constant), + } + } +} + +impl FunctionExpr { + pub fn function_options(&self) -> FunctionOptions { + use FunctionExpr as F; + match self { + #[cfg(feature = "dtype-array")] + F::ArrayExpr(e) => e.function_options(), + F::BinaryExpr(e) => e.function_options(), + #[cfg(feature = "dtype-categorical")] + F::Categorical(e) => e.function_options(), + F::ListExpr(e) => e.function_options(), + #[cfg(feature = "strings")] + F::StringExpr(e) => e.function_options(), + #[cfg(feature = "dtype-struct")] + F::StructExpr(e) => e.function_options(), + #[cfg(feature = "temporal")] + F::TemporalExpr(e) => e.function_options(), + #[cfg(feature = "bitwise")] + F::Bitwise(e) => e.function_options(), + F::Boolean(e) => e.function_options(), + #[cfg(feature = "business")] + F::Business(e) => e.function_options(), + F::Pow(e) => e.function_options(), + #[cfg(feature = "range")] + F::Range(e) => e.function_options(), + #[cfg(feature = "abs")] + F::Abs => FunctionOptions::elementwise(), + F::Negate => FunctionOptions::elementwise(), + #[cfg(feature = "hist")] + F::Hist { .. } => FunctionOptions::groupwise(), + F::NullCount => FunctionOptions::aggregation(), + #[cfg(feature = "row_hash")] + F::Hash(_, _, _, _) => FunctionOptions::elementwise(), + #[cfg(feature = "arg_where")] + F::ArgWhere => FunctionOptions::groupwise(), + #[cfg(feature = "index_of")] + F::IndexOf => { + FunctionOptions::aggregation().with_casting_rules(CastingRules::FirstArgLossless) + }, + #[cfg(feature = "search_sorted")] + F::SearchSorted(_) => FunctionOptions::groupwise().with_supertyping( + (SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(), + ), + #[cfg(feature = "trigonometry")] + F::Trigonometry(_) => FunctionOptions::elementwise(), + #[cfg(feature = "trigonometry")] + F::Atan2 => FunctionOptions::elementwise(), + #[cfg(feature = "sign")] + F::Sign => FunctionOptions::elementwise(), + F::FillNull => FunctionOptions::elementwise().with_supertyping(Default::default()), + F::FillNullWithStrategy(strategy) if strategy.is_elementwise() => { + FunctionOptions::elementwise() + }, + F::FillNullWithStrategy(_) => FunctionOptions::groupwise(), + #[cfg(feature = "rolling_window")] + F::RollingExpr(_) => FunctionOptions::length_preserving(), + #[cfg(feature = "rolling_window_by")] + F::RollingExprBy(_) => FunctionOptions::length_preserving(), + F::ShiftAndFill => FunctionOptions::length_preserving(), + F::Shift => FunctionOptions::length_preserving(), + F::DropNans => FunctionOptions::row_separable(), + F::DropNulls => FunctionOptions::row_separable().with_allow_empty_inputs(true), + #[cfg(feature = "mode")] + F::Mode => FunctionOptions::groupwise(), + #[cfg(feature = "moment")] + F::Skew(_) => FunctionOptions::aggregation(), + #[cfg(feature = "moment")] + F::Kurtosis(_, _) => FunctionOptions::aggregation(), + #[cfg(feature = "dtype-array")] + F::Reshape(_) => FunctionOptions::groupwise(), + #[cfg(feature = "repeat_by")] + F::RepeatBy => FunctionOptions::elementwise(), + F::ArgUnique => FunctionOptions::groupwise(), + #[cfg(feature = "rank")] + F::Rank { .. } => FunctionOptions::groupwise(), + F::Repeat => FunctionOptions::groupwise() + .with_allow_rename(true) + .with_changes_length(true), + #[cfg(feature = "round_series")] + F::Clip { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "dtype-struct")] + F::AsStruct => FunctionOptions::elementwise() + .with_pass_name_to_apply(true) + .with_input_wildcard_expansion(true), + #[cfg(feature = "top_k")] + F::TopK { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "top_k")] + F::TopKBy { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "cum_agg")] + F::CumCount { .. } + | F::CumSum { .. } + | F::CumProd { .. } + | F::CumMin { .. } + | F::CumMax { .. } => FunctionOptions::length_preserving(), + F::Reverse => FunctionOptions::length_preserving(), + #[cfg(feature = "dtype-struct")] + F::ValueCounts { .. } => FunctionOptions::groupwise().with_pass_name_to_apply(true), + #[cfg(feature = "unique_counts")] + F::UniqueCounts => FunctionOptions::groupwise(), + #[cfg(feature = "approx_unique")] + F::ApproxNUnique => FunctionOptions::aggregation(), + F::Coalesce => FunctionOptions::elementwise() + .with_input_wildcard_expansion(true) + .with_supertyping(Default::default()), + F::ShrinkType => FunctionOptions::length_preserving(), + #[cfg(feature = "diff")] + F::Diff(NullBehavior::Drop) => FunctionOptions::groupwise(), + #[cfg(feature = "diff")] + F::Diff(NullBehavior::Ignore) => FunctionOptions::length_preserving(), + #[cfg(feature = "pct_change")] + F::PctChange => FunctionOptions::length_preserving(), + #[cfg(feature = "interpolate")] + F::Interpolate(_) => FunctionOptions::length_preserving(), + #[cfg(feature = "interpolate_by")] + F::InterpolateBy => FunctionOptions::length_preserving(), + #[cfg(feature = "log")] + F::Log { .. } | F::Log1p | F::Exp => FunctionOptions::elementwise(), + #[cfg(feature = "log")] + F::Entropy { .. } => FunctionOptions::aggregation(), + F::Unique(_) => FunctionOptions::groupwise(), + #[cfg(feature = "round_series")] + F::Round { .. } | F::RoundSF { .. } | F::Floor | F::Ceil => { + FunctionOptions::elementwise() + }, + F::UpperBound | F::LowerBound => FunctionOptions::aggregation(), + #[cfg(feature = "fused")] + F::Fused(_) => FunctionOptions::elementwise(), + F::ConcatExpr(_) => FunctionOptions::groupwise() + .with_input_wildcard_expansion(true) + .with_supertyping(Default::default()), + #[cfg(feature = "cov")] + F::Correlation { .. } => { + FunctionOptions::aggregation().with_supertyping(Default::default()) + }, + #[cfg(feature = "peaks")] + F::PeakMin | F::PeakMax => FunctionOptions::length_preserving(), + #[cfg(feature = "cutqcut")] + F::Cut { .. } | F::QCut { .. } => { + FunctionOptions::length_preserving().with_pass_name_to_apply(true) + }, + #[cfg(feature = "rle")] + F::RLE => FunctionOptions::groupwise(), + #[cfg(feature = "rle")] + F::RLEID => FunctionOptions::length_preserving(), + F::ToPhysical => FunctionOptions::elementwise(), + #[cfg(feature = "random")] + F::Random { + method: RandomMethod::Sample { .. }, + .. + } => FunctionOptions::groupwise(), + #[cfg(feature = "random")] + F::Random { + method: RandomMethod::Shuffle, + .. + } => FunctionOptions::length_preserving(), + F::SetSortedFlag(_) => FunctionOptions::elementwise(), + #[cfg(feature = "ffi_plugin")] + F::FfiPlugin { flags, .. } => *flags, + F::MaxHorizontal | F::MinHorizontal => FunctionOptions::elementwise() + .with_input_wildcard_expansion(true) + .with_allow_rename(true), + F::MeanHorizontal { .. } | F::SumHorizontal { .. } => { + FunctionOptions::elementwise().with_input_wildcard_expansion(true) + }, + #[cfg(feature = "ewma")] + F::EwmMean { .. } | F::EwmStd { .. } | F::EwmVar { .. } => { + FunctionOptions::length_preserving() + }, + #[cfg(feature = "ewma_by")] + F::EwmMeanBy { .. } => FunctionOptions::length_preserving(), + #[cfg(feature = "replace")] + F::Replace => FunctionOptions::elementwise(), + #[cfg(feature = "replace")] + F::ReplaceStrict { .. } => FunctionOptions::elementwise(), + F::GatherEvery { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "reinterpret")] + F::Reinterpret(_) => FunctionOptions::elementwise(), + F::ExtendConstant => FunctionOptions::groupwise(), + } + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/nan.rs b/crates/polars-plan/src/dsl/function_expr/nan.rs new file mode 100644 index 000000000000..035556336a65 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/nan.rs @@ -0,0 +1,17 @@ +use super::*; + +pub(super) fn drop_nans(s: Column) -> PolarsResult { + match s.dtype() { + DataType::Float32 => { + let ca = s.f32()?; + let mask = ca.is_not_nan() | ca.is_null(); + ca.filter(&mask).map(|ca| ca.into_column()) + }, + DataType::Float64 => { + let ca = s.f64()?; + let mask = ca.is_not_nan() | ca.is_null(); + ca.filter(&mask).map(|ca| ca.into_column()) + }, + _ => Ok(s), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/peaks.rs b/crates/polars-plan/src/dsl/function_expr/peaks.rs new file mode 100644 index 000000000000..702a9dc3c86d --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/peaks.rs @@ -0,0 +1,38 @@ +use polars_core::with_match_physical_numeric_polars_type; +use polars_ops::chunked_array::peaks::{peak_max as pmax, peak_min as pmin}; + +use super::*; + +pub(super) fn peak_min(s: &Column) -> PolarsResult { + let s = s.to_physical_repr(); + let s = s.as_materialized_series(); + let s = match s.dtype() { + DataType::Boolean => polars_bail!(opq = peak_min, DataType::Boolean), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => pmin(s.decimal()?).into_column(), + dt => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + pmin(ca).into_column() + }) + }, + }; + Ok(s) +} + +pub(super) fn peak_max(s: &Column) -> PolarsResult { + let s = s.to_physical_repr(); + let s = s.as_materialized_series(); + let s = match s.dtype() { + DataType::Boolean => polars_bail!(opq = peak_max, DataType::Boolean), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => pmax(s.decimal()?).into_column(), + dt => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + pmax(ca).into_column() + }) + }, + }; + Ok(s) +} diff --git a/crates/polars-plan/src/dsl/function_expr/plugin.rs b/crates/polars-plan/src/dsl/function_expr/plugin.rs new file mode 100644 index 000000000000..ad6840ebe75d --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/plugin.rs @@ -0,0 +1,221 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use std::ffi::CStr; +use std::sync::{LazyLock, RwLock}; + +use arrow::ffi::{ArrowSchema, import_field_from_c}; +use libloading::Library; +use pyo3::Python; +use pyo3::types::PyAnyMethods; + +use super::*; + +type PluginAndVersion = (Library, u16, u16); +static LOADED: LazyLock>> = + LazyLock::new(Default::default); + +fn get_lib(lib: &str) -> PolarsResult<&'static PluginAndVersion> { + let lib_map = LOADED.read().unwrap(); + if let Some(library) = lib_map.get(lib) { + // lifetime is static as we never remove libraries. + Ok(unsafe { std::mem::transmute::<&PluginAndVersion, &'static PluginAndVersion>(library) }) + } else { + drop(lib_map); + + let load_path = if !std::path::Path::new(lib).is_absolute() { + // Get python virtual environment path + let prefix = Python::with_gil(|py| { + let sys = py.import("sys").unwrap(); + let prefix = sys.getattr("prefix").unwrap(); + prefix.to_string() + }); + let full_path = std::path::Path::new(&prefix).join(lib); + full_path.to_string_lossy().into_owned() + } else { + lib.to_string() + }; + + let library = unsafe { + Library::new(&load_path).map_err(|e| { + PolarsError::ComputeError(format!("error loading dynamic library: {e}").into()) + })? + }; + let version_function: libloading::Symbol u32> = unsafe { + library + .get("_polars_plugin_get_version".as_bytes()) + .unwrap() + }; + + let version = unsafe { version_function() }; + let major = (version >> 16) as u16; + let minor = version as u16; + + let mut lib_map = LOADED.write().unwrap(); + lib_map.insert(lib.to_string(), (library, major, minor)); + drop(lib_map); + + get_lib(lib) + } +} + +unsafe fn retrieve_error_msg(lib: &Library) -> &CStr { + let symbol: libloading::Symbol *mut std::os::raw::c_char> = + lib.get(b"_polars_plugin_get_last_error_message\0").unwrap(); + let msg_ptr = symbol(); + CStr::from_ptr(msg_ptr) +} + +pub(super) unsafe fn call_plugin( + s: &[Column], + lib: &str, + symbol: &str, + kwargs: &[u8], +) -> PolarsResult { + let plugin = get_lib(lib)?; + let lib = &plugin.0; + let major = plugin.1; + + if major == 0 { + use polars_ffi::version_0::*; + // *const SeriesExport: pointer to Box + // * usize: length of that pointer + // *const u8: pointer to &[u8] + // usize: length of the u8 slice + // *mut SeriesExport: pointer where return value should be written. + // *const CallerContext + let symbol: libloading::Symbol< + unsafe extern "C" fn( + *const SeriesExport, + usize, + *const u8, + usize, + *mut SeriesExport, + *const CallerContext, + ), + > = lib + .get(format!("_polars_plugin_{}", symbol).as_bytes()) + .unwrap(); + + // @scalar-correctness? + let input = s.iter().map(export_column).collect::>(); + let input_len = s.len(); + let slice_ptr = input.as_ptr(); + + let kwargs_ptr = kwargs.as_ptr(); + let kwargs_len = kwargs.len(); + + let mut return_value = SeriesExport::empty(); + let return_value_ptr = &mut return_value as *mut SeriesExport; + let context = CallerContext::default(); + let context_ptr = &context as *const CallerContext; + symbol( + slice_ptr, + input_len, + kwargs_ptr, + kwargs_len, + return_value_ptr, + context_ptr, + ); + + // The inputs get dropped when the ffi side calls the drop callback. + for e in input { + std::mem::forget(e); + } + + if !return_value.is_null() { + import_series(return_value).map(Column::from) + } else { + let msg = retrieve_error_msg(lib); + let msg = msg.to_string_lossy(); + check_panic(msg.as_ref())?; + polars_bail!(ComputeError: "the plugin failed with message: {}", msg) + } + } else { + polars_bail!(ComputeError: "this polars engine doesn't support plugin version: {}", major) + } +} + +pub(super) unsafe fn plugin_field( + fields: &[Field], + lib: &str, + symbol: &str, + kwargs: &[u8], +) -> PolarsResult { + let plugin = get_lib(lib)?; + let lib = &plugin.0; + let major = plugin.1; + let minor = plugin.2; + + // we deallocate the fields buffer + let ffi_fields = fields + .iter() + .map(|field| arrow::ffi::export_field_to_c(&field.to_arrow(CompatLevel::newest()))) + .collect::>() + .into_boxed_slice(); + let n_args = ffi_fields.len(); + let slice_ptr = ffi_fields.as_ptr(); + + let mut return_value = ArrowSchema::empty(); + let return_value_ptr = &mut return_value as *mut ArrowSchema; + + if major == 0 { + match minor { + 0 => { + let views = fields.iter().any(|field| field.dtype.contains_views()); + polars_ensure!(!views, ComputeError: "cannot call plugin\n\nThis Polars' version has a different 'binary/string' layout. Please compile with latest 'pyo3-polars'"); + + // *const ArrowSchema: pointer to heap Box + // usize: length of the boxed slice + // *mut ArrowSchema: pointer where the return value can be written + let symbol: libloading::Symbol< + unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema), + > = lib + .get((format!("_polars_plugin_field_{}", symbol)).as_bytes()) + .unwrap(); + symbol(slice_ptr, n_args, return_value_ptr); + }, + 1 => { + // *const ArrowSchema: pointer to heap Box + // usize: length of the boxed slice + // *mut ArrowSchema: pointer where the return value can be written + // *const u8: pointer to &[u8] (kwargs) + // usize: length of the u8 slice + let symbol: libloading::Symbol< + unsafe extern "C" fn( + *const ArrowSchema, + usize, + *mut ArrowSchema, + *const u8, + usize, + ), + > = lib + .get((format!("_polars_plugin_field_{}", symbol)).as_bytes()) + .unwrap(); + + let kwargs_ptr = kwargs.as_ptr(); + let kwargs_len = kwargs.len(); + + symbol(slice_ptr, n_args, return_value_ptr, kwargs_ptr, kwargs_len); + }, + _ => { + polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}-{}", major, minor) + }, + } + if !return_value.is_null() { + let arrow_field = import_field_from_c(&return_value)?; + let out = Field::from(&arrow_field); + Ok(out) + } else { + let msg = retrieve_error_msg(lib); + let msg = msg.to_string_lossy(); + check_panic(msg.as_ref())?; + polars_bail!(ComputeError: "the plugin failed with message: {}", msg) + } + } else { + polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}", major) + } +} + +fn check_panic(msg: &str) -> PolarsResult<()> { + polars_ensure!(msg != "PANIC", ComputeError: "the plugin panicked\n\nThe message is suppressed. Set POLARS_VERBOSE=1 to send the panic message to stderr."); + Ok(()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/pow.rs b/crates/polars-plan/src/dsl/function_expr/pow.rs new file mode 100644 index 000000000000..cecb661c5cbf --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/pow.rs @@ -0,0 +1,275 @@ +use num_traits::pow::Pow; +use num_traits::{Float, One, ToPrimitive, Zero}; +use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values}; +use polars_core::with_match_physical_integer_type; + +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +pub enum PowFunction { + Generic, + Sqrt, + Cbrt, +} + +impl PowFunction { + pub fn function_options(&self) -> FunctionOptions { + use PowFunction as P; + match self { + P::Generic | P::Sqrt | P::Cbrt => FunctionOptions::elementwise(), + } + } +} + +impl Display for PowFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use self::*; + match self { + PowFunction::Generic => write!(f, "pow"), + PowFunction::Sqrt => write!(f, "sqrt"), + PowFunction::Cbrt => write!(f, "cbrt"), + } + } +} + +impl From for FunctionExpr { + fn from(value: PowFunction) -> Self { + Self::Pow(value) + } +} + +fn pow_on_chunked_arrays( + base: &ChunkedArray, + exponent: &ChunkedArray, +) -> ChunkedArray +where + T: PolarsNumericType, + F: PolarsNumericType, + T::Native: Pow + ToPrimitive, +{ + if exponent.len() == 1 { + if let Some(e) = exponent.get(0) { + if e == F::Native::zero() { + return unary_elementwise_values(base, |_| T::Native::one()); + } + if e == F::Native::one() { + return base.clone(); + } + if e == F::Native::one() + F::Native::one() { + return base * base; + } + } + } + + broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?))) +} + +fn pow_on_floats( + base: &ChunkedArray, + exponent: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Pow + ToPrimitive + Float, + ChunkedArray: IntoColumn, +{ + let dtype = T::get_dtype(); + + if exponent.len() == 1 { + let Some(exponent_value) = exponent.get(0) else { + return Ok(Some(Column::full_null( + base.name().clone(), + base.len(), + &dtype, + ))); + }; + let s = match exponent_value.to_f64().unwrap() { + 1.0 => base.clone().into_column(), + // specialized sqrt will ensure (-inf)^0.5 = NaN + // and will likely be faster as well. + 0.5 => base.apply_values(|v| v.sqrt()).into_column(), + a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => { + let mut out = base.clone(); + + for _ in 1..exponent_value.to_u8().unwrap() { + out = out * base.clone() + } + out.into_column() + }, + _ => base + .apply_values(|v| Pow::pow(v, exponent_value)) + .into_column(), + }; + Ok(Some(s)) + } else { + Ok(Some(pow_on_chunked_arrays(base, exponent).into_column())) + } +} + +fn pow_to_uint_dtype( + base: &ChunkedArray, + exponent: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsIntegerType, + F: PolarsIntegerType, + T::Native: Pow + ToPrimitive, + ChunkedArray: IntoColumn, +{ + let dtype = T::get_dtype(); + + if exponent.len() == 1 { + let Some(exponent_value) = exponent.get(0) else { + return Ok(Some(Column::full_null( + base.name().clone(), + base.len(), + &dtype, + ))); + }; + let s = match exponent_value.to_u64().unwrap() { + 1 => base.clone().into_column(), + 2..=10 => { + let mut out = base.clone(); + + for _ in 1..exponent_value.to_u8().unwrap() { + out = out * base.clone() + } + out.into_column() + }, + _ => base + .apply_values(|v| Pow::pow(v, exponent_value)) + .into_column(), + }; + Ok(Some(s)) + } else { + Ok(Some(pow_on_chunked_arrays(base, exponent).into_column())) + } +} + +fn pow_on_series(base: &Column, exponent: &Column) -> PolarsResult> { + use DataType::*; + + let base_dtype = base.dtype(); + polars_ensure!( + base_dtype.is_primitive_numeric(), + InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype + ); + let exponent_dtype = exponent.dtype(); + polars_ensure!( + exponent_dtype.is_primitive_numeric(), + InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype + ); + + // if false, dtype is float + if base_dtype.is_integer() { + with_match_physical_integer_type!(base_dtype, |$native_type| { + if exponent_dtype.is_float() { + match exponent_dtype { + Float32 => { + let ca = base.cast(&DataType::Float32)?; + pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap()) + }, + Float64 => { + let ca = base.cast(&DataType::Float64)?; + pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap()) + }, + _ => unreachable!(), + } + } else { + let ca = base.$native_type().unwrap(); + let exponent = exponent.strict_cast(&DataType::UInt32).map_err(|err| polars_err!( + InvalidOperation: + "{}\n\nHint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first.", + err + ))?; + pow_to_uint_dtype(ca, exponent.u32().unwrap()) + } + }) + } else { + match base_dtype { + Float32 => { + let ca = base.f32().unwrap(); + let exponent = exponent.strict_cast(&DataType::Float32)?; + pow_on_floats(ca, exponent.f32().unwrap()) + }, + Float64 => { + let ca = base.f64().unwrap(); + let exponent = exponent.strict_cast(&DataType::Float64)?; + pow_on_floats(ca, exponent.f64().unwrap()) + }, + _ => unreachable!(), + } + } +} + +pub(super) fn pow(s: &mut [Column]) -> PolarsResult> { + let base = &s[0]; + let exponent = &s[1]; + + let base_len = base.len(); + let exp_len = exponent.len(); + match (base_len, exp_len) { + (1, _) | (_, 1) => pow_on_series(base, exponent), + (len_a, len_b) if len_a == len_b => pow_on_series(base, exponent), + _ => polars_bail!( + ComputeError: + "exponent shape: {} in `pow` expression does not match that of the base: {}", + exp_len, base_len, + ), + } +} + +pub(super) fn sqrt(base: &Column) -> PolarsResult { + use DataType::*; + match base.dtype() { + Float32 => { + let ca = base.f32().unwrap(); + sqrt_on_floats(ca) + }, + Float64 => { + let ca = base.f64().unwrap(); + sqrt_on_floats(ca) + }, + _ => { + let base = base.cast(&DataType::Float64)?; + sqrt(&base) + }, + } +} + +fn sqrt_on_floats(base: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Pow + ToPrimitive + Float, + ChunkedArray: IntoColumn, +{ + Ok(base.apply_values(|v| v.sqrt()).into_column()) +} + +pub(super) fn cbrt(base: &Column) -> PolarsResult { + use DataType::*; + match base.dtype() { + Float32 => { + let ca = base.f32().unwrap(); + cbrt_on_floats(ca) + }, + Float64 => { + let ca = base.f64().unwrap(); + cbrt_on_floats(ca) + }, + _ => { + let base = base.cast(&DataType::Float64)?; + cbrt(&base) + }, + } +} + +fn cbrt_on_floats(base: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Pow + ToPrimitive + Float, + ChunkedArray: IntoColumn, +{ + Ok(base.apply_values(|v| v.cbrt()).into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/random.rs b/crates/polars-plan/src/dsl/function_expr/random.rs new file mode 100644 index 000000000000..91b0decc1c71 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/random.rs @@ -0,0 +1,74 @@ +use polars_core::prelude::DataType::Float64; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Copy, Clone, PartialEq, Debug, IntoStaticStr)] +#[strum(serialize_all = "snake_case")] +pub enum RandomMethod { + Shuffle, + Sample { + is_fraction: bool, + with_replacement: bool, + shuffle: bool, + }, +} + +impl Hash for RandomMethod { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state) + } +} + +pub(super) fn shuffle(s: &Column, seed: Option) -> PolarsResult { + Ok(s.shuffle(seed)) +} + +pub(super) fn sample_frac( + s: &[Column], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let src = &s[0]; + let frac_s = &s[1]; + + polars_ensure!( + frac_s.len() == 1, + ComputeError: "Sample fraction must be a single value." + ); + + let frac_s = frac_s.cast(&Float64)?; + let frac = frac_s.f64()?; + + match frac.get(0) { + Some(frac) => src.sample_frac(frac, with_replacement, shuffle, seed), + None => Ok(Column::new_empty(src.name().clone(), src.dtype())), + } +} + +pub(super) fn sample_n( + s: &[Column], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let src = &s[0]; + let n_s = &s[1]; + + polars_ensure!( + n_s.len() == 1, + ComputeError: "Sample size must be a single value." + ); + + let n_s = n_s.cast(&IDX_DTYPE)?; + let n = n_s.idx()?; + + match n.get(0) { + Some(n) => src.sample_n(n as usize, with_replacement, shuffle, seed), + None => Ok(Column::new_empty(src.name().clone(), src.dtype())), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs new file mode 100644 index 000000000000..d081f2d819e9 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs @@ -0,0 +1,97 @@ +use polars_core::prelude::*; +use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY; +use polars_time::{ClosedWindow, Duration, datetime_range_impl}; + +use super::utils::{ + ensure_range_bounds_contain_exactly_one_value, temporal_ranges_impl_broadcast, + temporal_series_to_i64_scalar, +}; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn date_range( + s: &[Column], + interval: Duration, + closed: ClosedWindow, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + let start = start.strict_cast(&DataType::Date)?; + let end = end.strict_cast(&DataType::Date)?; + polars_ensure!( + interval.is_full_days(), + ComputeError: "`interval` input for `date_range` must consist of full days, got: {interval}" + ); + + let name = start.name().clone(); + let start = temporal_series_to_i64_scalar(&start) + .ok_or_else(|| polars_err!(ComputeError: "start is an out-of-range time."))? + * MILLISECONDS_IN_DAY; + let end = temporal_series_to_i64_scalar(&end) + .ok_or_else(|| polars_err!(ComputeError: "end is an out-of-range time."))? + * MILLISECONDS_IN_DAY; + + let out = datetime_range_impl( + name, + start, + end, + interval, + closed, + TimeUnit::Milliseconds, + None, + )?; + + let to_type = DataType::Date; + out.cast(&to_type).map(Column::from) +} + +pub(super) fn date_ranges( + s: &[Column], + interval: Duration, + closed: ClosedWindow, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + polars_ensure!( + interval.is_full_days(), + ComputeError: "`interval` input for `date_ranges` must consist of full days, got: {interval}" + ); + + let start = start.strict_cast(&DataType::Date)?.cast(&DataType::Int64)?; + let end = end.strict_cast(&DataType::Date)?.cast(&DataType::Int64)?; + + let start = start.i64().unwrap() * MILLISECONDS_IN_DAY; + let end = end.i64().unwrap() * MILLISECONDS_IN_DAY; + + let mut builder = ListPrimitiveChunkedBuilder::::new( + start.name().clone(), + start.len(), + start.len() * CAPACITY_FACTOR, + DataType::Int32, + ); + + let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { + let rng = datetime_range_impl( + PlSmallStr::EMPTY, + start, + end, + interval, + closed, + TimeUnit::Milliseconds, + None, + )?; + let rng = rng.cast(&DataType::Date).unwrap(); + let rng = rng.to_physical_repr(); + let rng = rng.i32().unwrap(); + builder.append_slice(rng.cont_slice().unwrap()); + Ok(()) + }; + + let out = temporal_ranges_impl_broadcast(&start, &end, range_impl, &mut builder)?; + + let to_type = DataType::List(Box::new(DataType::Date)); + out.cast(&to_type) +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs new file mode 100644 index 000000000000..9453a045657e --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -0,0 +1,248 @@ +#[cfg(feature = "timezones")] +use polars_core::chunked_array::temporal::parse_time_zone; +use polars_core::prelude::*; +use polars_time::{ClosedWindow, Duration, datetime_range_impl}; + +use super::utils::{ + ensure_range_bounds_contain_exactly_one_value, temporal_ranges_impl_broadcast, + temporal_series_to_i64_scalar, +}; +use crate::dsl::function_expr::FieldsMapper; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn datetime_range( + s: &[Column], + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> PolarsResult { + let mut start = s[0].clone(); + let mut end = s[1].clone(); + + ensure_range_bounds_contain_exactly_one_value(&start, &end)?; + + // Note: `start` and `end` have already been cast to their supertype, + // so only `start`'s dtype needs to be matched against. + #[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block. + let mut dtype = match (start.dtype(), time_unit) { + (DataType::Date, time_unit) => { + if let Some(tu) = time_unit { + DataType::Datetime(tu, None) + } else if interval.nanoseconds() % 1_000 != 0 { + DataType::Datetime(TimeUnit::Nanoseconds, None) + } else { + DataType::Datetime(TimeUnit::Microseconds, None) + } + }, + // overwrite nothing, keep as-is + (DataType::Datetime(_, _), None) => start.dtype().clone(), + // overwrite time unit, keep timezone + (DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()), + (dt, _) => polars_bail!(InvalidOperation: "expected a temporal datatype, got {}", dt), + }; + + // overwrite time zone, if specified + match (&dtype, &time_zone) { + #[cfg(feature = "timezones")] + (DataType::Datetime(tu, _), Some(tz)) => { + dtype = DataType::Datetime(*tu, Some(tz.clone())); + }, + _ => {}, + }; + + if start.dtype() == &DataType::Date { + start = start.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?; + end = end.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?; + } + + // If `start` and `end` are naive, but a time zone was specified, + // then first localize them + let (start, end) = match (start.dtype(), time_zone) { + #[cfg(feature = "timezones")] + (DataType::Datetime(_, None), Some(tz)) => ( + polars_ops::prelude::replace_time_zone( + start.datetime().unwrap(), + Some(&tz), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&dtype)? + .into_column(), + polars_ops::prelude::replace_time_zone( + end.datetime().unwrap(), + Some(&tz), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&dtype)? + .into_column(), + ), + _ => (start.cast(&dtype)?, end.cast(&dtype)?), + }; + + let name = start.name(); + let start = temporal_series_to_i64_scalar(&start) + .ok_or_else(|| polars_err!(ComputeError: "start is an out-of-range time."))?; + let end = temporal_series_to_i64_scalar(&end) + .ok_or_else(|| polars_err!(ComputeError: "end is an out-of-range time."))?; + + let result = match dtype { + DataType::Datetime(tu, ref tz) => { + let tz = match tz { + #[cfg(feature = "timezones")] + Some(tz) => Some(parse_time_zone(tz)?), + _ => None, + }; + datetime_range_impl(name.clone(), start, end, interval, closed, tu, tz.as_ref())? + }, + _ => unimplemented!(), + }; + Ok(result.cast(&dtype).unwrap().into_column()) +} + +pub(super) fn datetime_ranges( + s: &[Column], + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> PolarsResult { + let mut start = s[0].clone(); + let mut end = s[1].clone(); + + // Note: `start` and `end` have already been cast to their supertype, + // so only `start`'s dtype needs to be matched against. + #[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block. + let mut dtype = match (start.dtype(), time_unit) { + (DataType::Date, time_unit) => { + if let Some(tu) = time_unit { + DataType::Datetime(tu, None) + } else if interval.nanoseconds() % 1_000 != 0 { + DataType::Datetime(TimeUnit::Nanoseconds, None) + } else { + DataType::Datetime(TimeUnit::Microseconds, None) + } + }, + // overwrite nothing, keep as-is + (DataType::Datetime(_, _), None) => start.dtype().clone(), + // overwrite time unit, keep timezone + (DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()), + _ => unreachable!(), + }; + + // overwrite time zone, if specified + match (&dtype, &time_zone) { + #[cfg(feature = "timezones")] + (DataType::Datetime(tu, _), Some(tz)) => { + dtype = DataType::Datetime(*tu, Some(tz.clone())); + }, + _ => {}, + }; + + if start.dtype() == &DataType::Date { + start = start.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?; + end = end.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?; + } + + // If `start` and `end` are naive, but a time zone was specified, + // then first localize them + let (start, end) = match (start.dtype(), time_zone) { + #[cfg(feature = "timezones")] + (DataType::Datetime(_, None), Some(tz)) => ( + polars_ops::prelude::replace_time_zone( + start.datetime().unwrap(), + Some(&tz), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&dtype)? + .into_column() + .to_physical_repr() + .cast(&DataType::Int64)?, + polars_ops::prelude::replace_time_zone( + end.datetime().unwrap(), + Some(&tz), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .cast(&dtype)? + .into_column() + .to_physical_repr() + .cast(&DataType::Int64)?, + ), + _ => ( + start + .cast(&dtype)? + .to_physical_repr() + .cast(&DataType::Int64)?, + end.cast(&dtype)? + .to_physical_repr() + .cast(&DataType::Int64)?, + ), + }; + + let start = start.i64().unwrap(); + let end = end.i64().unwrap(); + + let out = match dtype { + DataType::Datetime(tu, ref tz) => { + let mut builder = ListPrimitiveChunkedBuilder::::new( + start.name().clone(), + start.len(), + start.len() * CAPACITY_FACTOR, + DataType::Int64, + ); + + let tz = match tz { + #[cfg(feature = "timezones")] + Some(tz) => Some(parse_time_zone(tz)?), + _ => None, + }; + let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { + let rng = datetime_range_impl( + PlSmallStr::EMPTY, + start, + end, + interval, + closed, + tu, + tz.as_ref(), + )?; + builder.append_slice(rng.cont_slice().unwrap()); + Ok(()) + }; + + temporal_ranges_impl_broadcast(start, end, range_impl, &mut builder)? + }, + _ => unimplemented!(), + }; + + let to_type = DataType::List(Box::new(dtype)); + out.cast(&to_type) +} + +impl FieldsMapper<'_> { + pub(super) fn map_to_datetime_range_dtype( + &self, + time_unit: Option<&TimeUnit>, + time_zone: Option<&PlSmallStr>, + ) -> PolarsResult { + let data_dtype = self.map_to_supertype()?.dtype; + + let (data_tu, data_tz) = if let DataType::Datetime(tu, tz) = data_dtype { + (tu, tz) + } else { + (TimeUnit::Microseconds, None) + }; + + let tu = match time_unit { + Some(tu) => *tu, + None => data_tu, + }; + let tz = time_zone.cloned().or(data_tz); + + Ok(DataType::Datetime(tu, tz)) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs new file mode 100644 index 000000000000..12bd9082c7a6 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -0,0 +1,82 @@ +use polars_core::prelude::*; +use polars_core::with_match_physical_integer_polars_type; +use polars_ops::series::new_int_range; + +use super::utils::{ensure_range_bounds_contain_exactly_one_value, numeric_ranges_impl_broadcast}; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn int_range(s: &[Column], step: i64, dtype: DataType) -> PolarsResult { + let mut start = &s[0]; + let mut end = &s[1]; + let name = start.name(); + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + polars_ensure!(dtype.is_integer(), ComputeError: "non-integer `dtype` passed to `int_range`: {:?}", dtype); + + let (start_storage, end_storage); + if *start.dtype() != dtype { + start_storage = start.strict_cast(&dtype)?; + start = &start_storage; + } + if *end.dtype() != dtype { + end_storage = end.strict_cast(&dtype)?; + end = &end_storage; + } + + with_match_physical_integer_polars_type!(dtype, |$T| { + let start_v = get_first_series_value::<$T>(start)?; + let end_v = get_first_series_value::<$T>(end)?; + new_int_range::<$T>(start_v, end_v, step, name.clone()).map(Column::from) + }) +} + +fn get_first_series_value(s: &Column) -> PolarsResult +where + T: PolarsIntegerType, +{ + let ca: &ChunkedArray = s.as_materialized_series().as_any().downcast_ref().unwrap(); + let value_opt = ca.get(0); + let value = + value_opt.ok_or_else(|| polars_err!(ComputeError: "invalid null input for `int_range`"))?; + Ok(value) +} + +pub(super) fn int_ranges(s: &[Column]) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + let step = &s[2]; + + let start = start.cast(&DataType::Int64)?; + let end = end.cast(&DataType::Int64)?; + let step = step.cast(&DataType::Int64)?; + + let start = start.i64()?; + let end = end.i64()?; + let step = step.i64()?; + + let len = std::cmp::max(start.len(), end.len()); + let mut builder = ListPrimitiveChunkedBuilder::::new( + // The name should follow our left hand rule. + start.name().clone(), + len, + len * CAPACITY_FACTOR, + DataType::Int64, + ); + + let range_impl = + |start, end, step: i64, builder: &mut ListPrimitiveChunkedBuilder| { + match step { + 1 => builder.append_values_iter_trusted_len(start..end), + 2.. => builder.append_values_iter_trusted_len((start..end).step_by(step as usize)), + _ => builder.append_values_iter_trusted_len( + (end..start) + .step_by(step.unsigned_abs() as usize) + .map(|x| start - (x - end)), + ), + }; + Ok(()) + }; + + numeric_ranges_impl_broadcast(start, end, step, range_impl, &mut builder) +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs b/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs new file mode 100644 index 000000000000..8f2590bf72ae --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/linear_space.rs @@ -0,0 +1,366 @@ +use arrow::temporal_conversions::MILLISECONDS_IN_DAY; +use polars_core::prelude::*; +use polars_ops::series::{ClosedInterval, new_linear_space_f32, new_linear_space_f64}; + +use super::utils::{build_nulls, ensure_range_bounds_contain_exactly_one_value}; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn linear_space(s: &[Column], closed: ClosedInterval) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + let num_samples = &s[2]; + let name = start.name(); + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + polars_ensure!( + num_samples.len() == 1, + ComputeError: "`num_samples` must contain exactly one value, got {} values", num_samples.len() + ); + + let start = start.get(0).unwrap(); + let end = end.get(0).unwrap(); + let num_samples = num_samples.get(0).unwrap(); + let num_samples = num_samples + .extract::() + .ok_or(PolarsError::ComputeError( + format!( + "'num_samples' must be non-negative integer, got {}", + num_samples + ) + .into(), + ))?; + + match (start.dtype(), end.dtype()) { + (DataType::Float32, DataType::Float32) => new_linear_space_f32( + start.extract::().unwrap(), + end.extract::().unwrap(), + num_samples, + closed, + name.clone(), + ) + .map(|s| s.into_column()), + (mut dt, dt2) if dt.is_temporal() && dt == dt2 => { + let mut start = start.extract::().unwrap(); + let mut end = end.extract::().unwrap(); + + // A linear space of a Date produces a sequence of Datetimes, so we must upcast. + if dt == DataType::Date { + start *= MILLISECONDS_IN_DAY; + end *= MILLISECONDS_IN_DAY; + dt = DataType::Datetime(TimeUnit::Milliseconds, None); + } + new_linear_space_f64(start as f64, end as f64, num_samples, closed, name.clone()) + .map(|s| s.cast(&dt).unwrap().into_column()) + }, + (dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => { + Err(PolarsError::ComputeError( + format!( + "'start' and 'end' have incompatible dtypes, got {:?} and {:?}", + dt1, dt2 + ) + .into(), + )) + }, + (_, _) => new_linear_space_f64( + start.extract::().unwrap(), + end.extract::().unwrap(), + num_samples, + closed, + name.clone(), + ) + .map(|s| s.into_column()), + } +} + +pub(super) fn linear_spaces( + s: &[Column], + closed: ClosedInterval, + array_width: Option, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + let (num_samples, capacity_factor) = match array_width { + Some(ns) => { + // An array width is provided instead of a column of `num_sample`s. + let scalar = Scalar::new(DataType::UInt64, AnyValue::UInt64(ns as u64)); + (&Column::new_scalar(PlSmallStr::EMPTY, scalar, 1), ns) + }, + None => (&s[2], CAPACITY_FACTOR), + }; + let name = start.name().clone(); + + let num_samples = num_samples.strict_cast(&DataType::UInt64)?; + let num_samples = num_samples.u64()?; + let len = start.len().max(end.len()).max(num_samples.len()); + + match (start.dtype(), end.dtype()) { + (DataType::Float32, DataType::Float32) => { + let mut builder = ListPrimitiveChunkedBuilder::::new( + name, + len, + len * capacity_factor, + DataType::Float32, + ); + + let linspace_impl = + |start, + end, + num_samples, + builder: &mut ListPrimitiveChunkedBuilder| { + let ls = + new_linear_space_f32(start, end, num_samples, closed, PlSmallStr::EMPTY)?; + builder.append_slice(ls.cont_slice().unwrap()); + Ok(()) + }; + + let start = start.f32()?; + let end = end.f32()?; + let out = + linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?; + + let to_type = array_width.map_or_else( + || DataType::List(Box::new(DataType::Float32)), + |width| DataType::Array(Box::new(DataType::Float32), width), + ); + out.cast(&to_type) + }, + (mut dt, dt2) if dt.is_temporal() && dt == dt2 => { + let mut start = start.to_physical_repr(); + let mut end = end.to_physical_repr(); + + // A linear space of a Date produces a sequence of Datetimes, so we must upcast. + if dt == &DataType::Date { + start = start.cast(&DataType::Int64)? * MILLISECONDS_IN_DAY; + end = end.cast(&DataType::Int64)? * MILLISECONDS_IN_DAY; + dt = &DataType::Datetime(TimeUnit::Milliseconds, None); + } + + let start = start.cast(&DataType::Float64)?; + let start = start.f64()?; + let end = end.cast(&DataType::Float64)?; + let end = end.f64()?; + + let mut builder = ListPrimitiveChunkedBuilder::::new( + name, + len, + len * capacity_factor, + DataType::Float64, + ); + + let linspace_impl = + |start, + end, + num_samples, + builder: &mut ListPrimitiveChunkedBuilder| { + let ls = + new_linear_space_f64(start, end, num_samples, closed, PlSmallStr::EMPTY)?; + builder.append_slice(ls.cont_slice().unwrap()); + Ok(()) + }; + let out = + linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?; + + let to_type = array_width.map_or_else( + || DataType::List(Box::new(dt.clone())), + |width| DataType::Array(Box::new(dt.clone()), width), + ); + out.cast(&to_type) + }, + (dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => { + Err(PolarsError::ComputeError( + format!( + "'start' and 'end' have incompatible dtypes, got {:?} and {:?}", + dt1, dt2 + ) + .into(), + )) + }, + (_, _) => { + let start = start.strict_cast(&DataType::Float64)?; + let end = end.strict_cast(&DataType::Float64)?; + let start = start.f64()?; + let end = end.f64()?; + + let mut builder = ListPrimitiveChunkedBuilder::::new( + name, + len, + len * capacity_factor, + DataType::Float64, + ); + + let linspace_impl = + |start, + end, + num_samples, + builder: &mut ListPrimitiveChunkedBuilder| { + let ls = + new_linear_space_f64(start, end, num_samples, closed, PlSmallStr::EMPTY)?; + builder.append_slice(ls.cont_slice().unwrap()); + Ok(()) + }; + let out = + linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?; + + let to_type = array_width.map_or_else( + || DataType::List(Box::new(DataType::Float64)), + |width| DataType::Array(Box::new(DataType::Float64), width), + ); + out.cast(&to_type) + }, + } +} + +/// Create a ranges column from the given start/end columns and a range function. +pub(super) fn linear_spaces_impl_broadcast( + start: &ChunkedArray, + end: &ChunkedArray, + num_samples: &UInt64Chunked, + linear_space_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult +where + T: PolarsFloatType, + F: Fn(T::Native, T::Native, u64, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + match (start.len(), end.len(), num_samples.len()) { + (len_start, len_end, len_samples) if len_start == len_end && len_start == len_samples => { + // (n, n, n) + build_linear_spaces::<_, _, _, T, F>( + start.iter(), + end.iter(), + num_samples.iter(), + linear_space_impl, + builder, + )?; + }, + // (1, n, n) + (1, len_end, len_samples) if len_end == len_samples => { + let start_value = start.get(0); + if start_value.is_some() { + build_linear_spaces::<_, _, _, T, F>( + std::iter::repeat(start_value), + end.iter(), + num_samples.iter(), + linear_space_impl, + builder, + )? + } else { + build_nulls(builder, len_end) + } + }, + // (n, 1, n) + (len_start, 1, len_samples) if len_start == len_samples => { + let end_value = end.get(0); + if end_value.is_some() { + build_linear_spaces::<_, _, _, T, F>( + start.iter(), + std::iter::repeat(end_value), + num_samples.iter(), + linear_space_impl, + builder, + )? + } else { + build_nulls(builder, len_start) + } + }, + // (n, n, 1) + (len_start, len_end, 1) if len_start == len_end => { + let num_samples_value = num_samples.get(0); + if num_samples_value.is_some() { + build_linear_spaces::<_, _, _, T, F>( + start.iter(), + end.iter(), + std::iter::repeat(num_samples_value), + linear_space_impl, + builder, + )? + } else { + build_nulls(builder, len_start) + } + }, + // (n, 1, 1) + (len_start, 1, 1) => { + let end_value = end.get(0); + let num_samples_value = num_samples.get(0); + match (end_value, num_samples_value) { + (Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>( + start.iter(), + std::iter::repeat(end_value), + std::iter::repeat(num_samples_value), + linear_space_impl, + builder, + )?, + _ => build_nulls(builder, len_start), + } + }, + // (1, n, 1) + (1, len_end, 1) => { + let start_value = start.get(0); + let num_samples_value = num_samples.get(0); + match (start_value, num_samples_value) { + (Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>( + std::iter::repeat(start_value), + end.iter(), + std::iter::repeat(num_samples_value), + linear_space_impl, + builder, + )?, + _ => build_nulls(builder, len_end), + } + }, + // (1, 1, n) + (1, 1, len_num_samples) => { + let start_value = start.get(0); + let end_value = end.get(0); + match (start_value, end_value) { + (Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>( + std::iter::repeat(start_value), + std::iter::repeat(end_value), + num_samples.iter(), + linear_space_impl, + builder, + )?, + _ => build_nulls(builder, len_num_samples), + } + }, + (len_start, len_end, len_num_samples) => { + polars_bail!( + ComputeError: + "lengths of `start` ({}), `end` ({}), and `num_samples` ({}) do not match", + len_start, len_end, len_num_samples + ) + }, + }; + let out = builder.finish().into_column(); + Ok(out) +} + +/// Iterate over a start and end column and create a range for each entry. +fn build_linear_spaces( + start: I, + end: J, + num_samples: K, + linear_space_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult<()> +where + I: Iterator>, + J: Iterator>, + K: Iterator>, + T: PolarsFloatType, + F: Fn(T::Native, T::Native, u64, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + for ((start, end), num_samples) in start.zip(end).zip(num_samples) { + match (start, end, num_samples) { + (Some(start), Some(end), Some(num_samples)) => { + linear_space_impl(start, end, num_samples, builder)? + }, + _ => builder.append_null(), + } + } + Ok(()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs new file mode 100644 index 000000000000..d9e0696608f3 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -0,0 +1,279 @@ +#[cfg(feature = "dtype-date")] +mod date_range; +#[cfg(feature = "dtype-datetime")] +mod datetime_range; +mod int_range; +mod linear_space; +#[cfg(feature = "dtype-time")] +mod time_range; +mod utils; + +use std::fmt::{Display, Formatter}; + +use polars_core::prelude::*; +use polars_ops::series::ClosedInterval; +#[cfg(feature = "temporal")] +use polars_time::{ClosedWindow, Duration}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::{FunctionExpr, FunctionOptions}; +use crate::dsl::SpecialEq; +use crate::dsl::function_expr::FieldsMapper; +use crate::map_as_slice; +use crate::prelude::ColumnsUdf; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum RangeFunction { + IntRange { + step: i64, + dtype: DataType, + }, + IntRanges, + LinearSpace { + closed: ClosedInterval, + }, + LinearSpaces { + closed: ClosedInterval, + array_width: Option, + }, + #[cfg(feature = "dtype-date")] + DateRange { + interval: Duration, + closed: ClosedWindow, + }, + #[cfg(feature = "dtype-date")] + DateRanges { + interval: Duration, + closed: ClosedWindow, + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRange { + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, + }, + #[cfg(feature = "dtype-time")] + TimeRange { + interval: Duration, + closed: ClosedWindow, + }, + #[cfg(feature = "dtype-time")] + TimeRanges { + interval: Duration, + closed: ClosedWindow, + }, +} + +fn map_linspace_dtype(mapper: &FieldsMapper) -> PolarsResult { + let fields = mapper.args(); + let start_dtype = fields[0].dtype(); + let end_dtype = fields[1].dtype(); + Ok(match (start_dtype, end_dtype) { + (&DataType::Float32, &DataType::Float32) => DataType::Float32, + // A linear space of a Date produces a sequence of Datetimes + (dt1, dt2) if dt1.is_temporal() && dt1 == dt2 => { + if dt1 == &DataType::Date { + DataType::Datetime(TimeUnit::Milliseconds, None) + } else { + dt1.clone() + } + }, + (dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => { + polars_bail!(ComputeError: + "'start' and 'end' have incompatible dtypes, got {:?} and {:?}", + dt1, dt2 + ) + }, + _ => DataType::Float64, + }) +} + +impl RangeFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use RangeFunction::*; + match self { + IntRange { dtype, .. } => mapper.with_dtype(dtype.clone()), + IntRanges => mapper.with_dtype(DataType::List(Box::new(DataType::Int64))), + LinearSpace { .. } => mapper.with_dtype(map_linspace_dtype(&mapper)?), + LinearSpaces { + closed: _, + array_width, + } => { + let inner = Box::new(map_linspace_dtype(&mapper)?); + let dt = match array_width { + Some(width) => DataType::Array(inner, *width), + None => DataType::List(inner), + }; + mapper.with_dtype(dt) + }, + #[cfg(feature = "dtype-date")] + DateRange { .. } => mapper.with_dtype(DataType::Date), + #[cfg(feature = "dtype-date")] + DateRanges { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::Date))), + #[cfg(feature = "dtype-datetime")] + DatetimeRange { + interval: _, + closed: _, + time_unit, + time_zone, + } => { + // output dtype may change based on `interval`, `time_unit`, and `time_zone` + let dtype = + mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_ref())?; + mapper.with_dtype(dtype) + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { + interval: _, + closed: _, + time_unit, + time_zone, + } => { + // output dtype may change based on `interval`, `time_unit`, and `time_zone` + let inner_dtype = + mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_ref())?; + mapper.with_dtype(DataType::List(Box::new(inner_dtype))) + }, + #[cfg(feature = "dtype-time")] + TimeRange { .. } => mapper.with_dtype(DataType::Time), + #[cfg(feature = "dtype-time")] + TimeRanges { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::Time))), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use RangeFunction as R; + match self { + R::IntRange { .. } => FunctionOptions::row_separable().with_allow_rename(true), + R::LinearSpace { .. } => FunctionOptions::row_separable().with_allow_rename(true), + #[cfg(feature = "dtype-date")] + R::DateRange { .. } => FunctionOptions::row_separable().with_allow_rename(true), + #[cfg(feature = "dtype-datetime")] + R::DatetimeRange { .. } => FunctionOptions::row_separable() + .with_allow_rename(true) + .with_supertyping(Default::default()), + #[cfg(feature = "dtype-time")] + R::TimeRange { .. } => FunctionOptions::row_separable().with_allow_rename(true), + R::IntRanges => FunctionOptions::elementwise().with_allow_rename(true), + R::LinearSpaces { .. } => FunctionOptions::elementwise().with_allow_rename(true), + #[cfg(feature = "dtype-date")] + R::DateRanges { .. } => FunctionOptions::elementwise().with_allow_rename(true), + #[cfg(feature = "dtype-datetime")] + R::DatetimeRanges { .. } => FunctionOptions::elementwise() + .with_allow_rename(true) + .with_supertyping(Default::default()), + #[cfg(feature = "dtype-time")] + R::TimeRanges { .. } => FunctionOptions::elementwise().with_allow_rename(true), + } + } +} + +impl Display for RangeFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RangeFunction::*; + let s = match self { + IntRange { .. } => "int_range", + IntRanges => "int_ranges", + LinearSpace { .. } => "linear_space", + LinearSpaces { .. } => "linear_spaces", + #[cfg(feature = "dtype-date")] + DateRange { .. } => "date_range", + #[cfg(feature = "temporal")] + DateRanges { .. } => "date_ranges", + #[cfg(feature = "dtype-datetime")] + DatetimeRange { .. } => "datetime_range", + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { .. } => "datetime_ranges", + #[cfg(feature = "dtype-time")] + TimeRange { .. } => "time_range", + #[cfg(feature = "dtype-time")] + TimeRanges { .. } => "time_ranges", + }; + write!(f, "{s}") + } +} + +impl From for SpecialEq> { + fn from(func: RangeFunction) -> Self { + use RangeFunction::*; + match func { + IntRange { step, dtype } => { + map_as_slice!(int_range::int_range, step, dtype.clone()) + }, + IntRanges => { + map_as_slice!(int_range::int_ranges) + }, + LinearSpace { closed } => { + map_as_slice!(linear_space::linear_space, closed) + }, + LinearSpaces { + closed, + array_width, + } => { + map_as_slice!(linear_space::linear_spaces, closed, array_width) + }, + #[cfg(feature = "dtype-date")] + DateRange { interval, closed } => { + map_as_slice!(date_range::date_range, interval, closed) + }, + #[cfg(feature = "dtype-date")] + DateRanges { interval, closed } => { + map_as_slice!(date_range::date_ranges, interval, closed) + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRange { + interval, + closed, + time_unit, + time_zone, + } => { + map_as_slice!( + datetime_range::datetime_range, + interval, + closed, + time_unit, + time_zone.clone() + ) + }, + #[cfg(feature = "dtype-datetime")] + DatetimeRanges { + interval, + closed, + time_unit, + time_zone, + } => { + map_as_slice!( + datetime_range::datetime_ranges, + interval, + closed, + time_unit, + time_zone.clone() + ) + }, + #[cfg(feature = "dtype-time")] + TimeRange { interval, closed } => { + map_as_slice!(time_range::time_range, interval, closed) + }, + #[cfg(feature = "dtype-time")] + TimeRanges { interval, closed } => { + map_as_slice!(time_range::time_ranges, interval, closed) + }, + } + } +} + +impl From for FunctionExpr { + fn from(value: RangeFunction) -> Self { + Self::Range(value) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs new file mode 100644 index 000000000000..640be245c40a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs @@ -0,0 +1,66 @@ +use polars_core::prelude::*; +use polars_time::{ClosedWindow, Duration, time_range_impl}; + +use super::utils::{ + ensure_range_bounds_contain_exactly_one_value, temporal_ranges_impl_broadcast, + temporal_series_to_i64_scalar, +}; + +const CAPACITY_FACTOR: usize = 5; + +pub(super) fn time_range( + s: &[Column], + interval: Duration, + closed: ClosedWindow, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + let name = start.name(); + + ensure_range_bounds_contain_exactly_one_value(start, end)?; + + let dtype = DataType::Time; + let start = temporal_series_to_i64_scalar(&start.cast(&dtype)?) + .ok_or_else(|| polars_err!(ComputeError: "start is an out-of-range time."))?; + let end = temporal_series_to_i64_scalar(&end.cast(&dtype)?) + .ok_or_else(|| polars_err!(ComputeError: "end is an out-of-range time."))?; + + let out = time_range_impl(name.clone(), start, end, interval, closed)?; + Ok(out.cast(&dtype).unwrap().into_column()) +} + +pub(super) fn time_ranges( + s: &[Column], + interval: Duration, + closed: ClosedWindow, +) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + let start = start.cast(&DataType::Time)?; + let end = end.cast(&DataType::Time)?; + + let start_phys = start.to_physical_repr(); + let end_phys = end.to_physical_repr(); + let start = start_phys.i64().unwrap(); + let end = end_phys.i64().unwrap(); + + let len = std::cmp::max(start.len(), end.len()); + let mut builder = ListPrimitiveChunkedBuilder::::new( + start.name().clone(), + len, + len * CAPACITY_FACTOR, + DataType::Int64, + ); + + let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { + let rng = time_range_impl(PlSmallStr::EMPTY, start, end, interval, closed)?; + builder.append_slice(rng.cont_slice().unwrap()); + Ok(()) + }; + + let out = temporal_ranges_impl_broadcast(start, end, range_impl, &mut builder)?; + + let to_type = DataType::List(Box::new(DataType::Time)); + out.cast(&to_type) +} diff --git a/crates/polars-plan/src/dsl/function_expr/range/utils.rs b/crates/polars-plan/src/dsl/function_expr/range/utils.rs new file mode 100644 index 000000000000..a2b7b2accd54 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/range/utils.rs @@ -0,0 +1,258 @@ +use polars_core::prelude::{ + ChunkedArray, Column, Int64Chunked, IntoColumn, ListBuilderTrait, ListPrimitiveChunkedBuilder, + PolarsIntegerType, PolarsNumericType, PolarsResult, polars_bail, polars_ensure, +}; + +pub(super) fn temporal_series_to_i64_scalar(s: &Column) -> Option { + s.to_physical_repr().get(0).unwrap().extract::() +} +pub(super) fn ensure_range_bounds_contain_exactly_one_value( + start: &Column, + end: &Column, +) -> PolarsResult<()> { + polars_ensure!( + start.len() == 1, + ComputeError: "`start` must contain exactly one value, got {} values", start.len() + ); + polars_ensure!( + end.len() == 1, + ComputeError: "`end` must contain exactly one value, got {} values", end.len() + ); + Ok(()) +} + +/// Create a numeric ranges column from the given start/end/step columns and a range function. +pub(super) fn numeric_ranges_impl_broadcast( + start: &ChunkedArray, + end: &ChunkedArray, + step: &Int64Chunked, + range_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult +where + T: PolarsIntegerType, + U: PolarsIntegerType, + F: Fn(T::Native, T::Native, i64, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + match (start.len(), end.len(), step.len()) { + (len_start, len_end, len_step) if len_start == len_end && len_start == len_step => { + build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + end.downcast_iter().flatten(), + step.downcast_iter().flatten(), + range_impl, + builder, + )?; + }, + (1, len_end, 1) => { + let start_scalar = start.get(0); + let step_scalar = step.get(0); + match (start_scalar, step_scalar) { + (Some(start), Some(step)) => build_numeric_ranges::<_, _, _, T, U, F>( + std::iter::repeat(Some(&start)), + end.downcast_iter().flatten(), + std::iter::repeat(Some(&step)), + range_impl, + builder, + )?, + _ => build_nulls(builder, len_end), + } + }, + (len_start, 1, 1) => { + let end_scalar = end.get(0); + let step_scalar = step.get(0); + match (end_scalar, step_scalar) { + (Some(end), Some(step)) => build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + std::iter::repeat(Some(&end)), + std::iter::repeat(Some(&step)), + range_impl, + builder, + )?, + _ => build_nulls(builder, len_start), + } + }, + (1, 1, len_step) => { + let start_scalar = start.get(0); + let end_scalar = end.get(0); + match (start_scalar, end_scalar) { + (Some(start), Some(end)) => build_numeric_ranges::<_, _, _, T, U, F>( + std::iter::repeat(Some(&start)), + std::iter::repeat(Some(&end)), + step.downcast_iter().flatten(), + range_impl, + builder, + )?, + _ => build_nulls(builder, len_step), + } + }, + (len_start, len_end, 1) if len_start == len_end => { + let step_scalar = step.get(0); + match step_scalar { + Some(step) => build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + end.downcast_iter().flatten(), + std::iter::repeat(Some(&step)), + range_impl, + builder, + )?, + None => build_nulls(builder, len_start), + } + }, + (len_start, 1, len_step) if len_start == len_step => { + let end_scalar = end.get(0); + match end_scalar { + Some(end) => build_numeric_ranges::<_, _, _, T, U, F>( + start.downcast_iter().flatten(), + std::iter::repeat(Some(&end)), + step.downcast_iter().flatten(), + range_impl, + builder, + )?, + None => build_nulls(builder, len_start), + } + }, + (1, len_end, len_step) if len_end == len_step => { + let start_scalar = start.get(0); + match start_scalar { + Some(start) => build_numeric_ranges::<_, _, _, T, U, F>( + std::iter::repeat(Some(&start)), + end.downcast_iter().flatten(), + step.downcast_iter().flatten(), + range_impl, + builder, + )?, + None => build_nulls(builder, len_end), + } + }, + (len_start, len_end, len_step) => { + polars_bail!( + ComputeError: + "lengths of `start` ({}), `end` ({}) and `step` ({}) do not match", + len_start, len_end, len_step + ) + }, + }; + let out = builder.finish().into_column(); + Ok(out) +} + +/// Create a ranges column from the given start/end columns and a range function. +pub(super) fn temporal_ranges_impl_broadcast( + start: &ChunkedArray, + end: &ChunkedArray, + range_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult +where + T: PolarsIntegerType, + U: PolarsIntegerType, + F: Fn(T::Native, T::Native, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + match (start.len(), end.len()) { + (len_start, len_end) if len_start == len_end => { + build_temporal_ranges::<_, _, T, U, F>( + start.downcast_iter().flatten(), + end.downcast_iter().flatten(), + range_impl, + builder, + )?; + }, + (1, len_end) => { + let start_scalar = start.get(0); + match start_scalar { + Some(start) => build_temporal_ranges::<_, _, T, U, F>( + std::iter::repeat(Some(&start)), + end.downcast_iter().flatten(), + range_impl, + builder, + )?, + None => build_nulls(builder, len_end), + } + }, + (len_start, 1) => { + let end_scalar = end.get(0); + match end_scalar { + Some(end) => build_temporal_ranges::<_, _, T, U, F>( + start.downcast_iter().flatten(), + std::iter::repeat(Some(&end)), + range_impl, + builder, + )?, + None => build_nulls(builder, len_start), + } + }, + (len_start, len_end) => { + polars_bail!( + ComputeError: + "lengths of `start` ({}) and `end` ({}) do not match", + len_start, len_end + ) + }, + }; + let out = builder.finish().into_column(); + Ok(out) +} + +/// Iterate over a start and end column and create a range with the step for each entry. +fn build_numeric_ranges<'a, I, J, K, T, U, F>( + start: I, + end: J, + step: K, + range_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult<()> +where + I: Iterator>, + J: Iterator>, + K: Iterator>, + T: PolarsIntegerType, + U: PolarsIntegerType, + F: Fn(T::Native, T::Native, i64, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + for ((start, end), step) in start.zip(end).zip(step) { + match (start, end, step) { + (Some(start), Some(end), Some(step)) => range_impl(*start, *end, *step, builder)?, + _ => builder.append_null(), + } + } + Ok(()) +} + +/// Iterate over a start and end column and create a range for each entry. +fn build_temporal_ranges<'a, I, J, T, U, F>( + start: I, + end: J, + range_impl: F, + builder: &mut ListPrimitiveChunkedBuilder, +) -> PolarsResult<()> +where + I: Iterator>, + J: Iterator>, + T: PolarsIntegerType, + U: PolarsIntegerType, + F: Fn(T::Native, T::Native, &mut ListPrimitiveChunkedBuilder) -> PolarsResult<()>, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + for (start, end) in start.zip(end) { + match (start, end) { + (Some(start), Some(end)) => range_impl(*start, *end, builder)?, + _ => builder.append_null(), + } + } + Ok(()) +} + +/// Add `n` nulls to the builder. +pub fn build_nulls(builder: &mut ListPrimitiveChunkedBuilder, n: usize) +where + U: PolarsNumericType, + ListPrimitiveChunkedBuilder: ListBuilderTrait, +{ + for _ in 0..n { + builder.append_null() + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/repeat.rs b/crates/polars-plan/src/dsl/function_expr/repeat.rs new file mode 100644 index 000000000000..229f88fc050a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/repeat.rs @@ -0,0 +1,18 @@ +use polars_core::prelude::{Column, PolarsResult, polars_ensure, polars_err}; + +pub fn repeat(args: &[Column]) -> PolarsResult { + let c = &args[0]; + let n = &args[1]; + + polars_ensure!( + n.dtype().is_integer(), + SchemaMismatch: "expected expression of dtype 'integer', got '{}'", n.dtype() + ); + + let first_value = n.get(0)?; + let n = first_value.extract::().ok_or_else( + || polars_err!(ComputeError: "could not parse value '{}' as a size.", first_value), + )?; + + Ok(c.new_from_index(0, n)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs new file mode 100644 index 000000000000..d7c7f3371c9f --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -0,0 +1,228 @@ +#[cfg(feature = "cov")] +use std::ops::BitAnd; + +use polars_core::utils::Container; +use polars_time::chunkedarray::*; + +use super::*; +#[cfg(feature = "cov")] +use crate::dsl::pow::pow; + +#[derive(Clone, PartialEq, Debug, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum RollingFunction { + Min(RollingOptionsFixedWindow), + Max(RollingOptionsFixedWindow), + Mean(RollingOptionsFixedWindow), + Sum(RollingOptionsFixedWindow), + Quantile(RollingOptionsFixedWindow), + Var(RollingOptionsFixedWindow), + Std(RollingOptionsFixedWindow), + #[cfg(feature = "moment")] + Skew(RollingOptionsFixedWindow), + #[cfg(feature = "moment")] + Kurtosis(RollingOptionsFixedWindow), + #[cfg(feature = "cov")] + CorrCov { + rolling_options: RollingOptionsFixedWindow, + corr_cov_options: RollingCovOptions, + // Whether is Corr or Cov + is_corr: bool, + }, +} + +impl Display for RollingFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RollingFunction::*; + + let name = match self { + Min(_) => "min", + Max(_) => "max", + Mean(_) => "mean", + Sum(_) => "rsum", + Quantile(_) => "quantile", + Var(_) => "var", + Std(_) => "std", + #[cfg(feature = "moment")] + Skew(..) => "skew", + #[cfg(feature = "moment")] + Kurtosis(..) => "kurtosis", + #[cfg(feature = "cov")] + CorrCov { is_corr, .. } => { + if *is_corr { + "corr" + } else { + "cov" + } + }, + }; + + write!(f, "rolling_{name}") + } +} + +pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_min(options) + .map(Column::from) +} + +pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_max(options) + .map(Column::from) +} + +pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_mean(options) + .map(Column::from) +} + +pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_sum(options) + .map(Column::from) +} + +pub(super) fn rolling_quantile( + s: &Column, + options: RollingOptionsFixedWindow, +) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_quantile(options) + .map(Column::from) +} + +pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_var(options) + .map(Column::from) +} + +pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + s.as_materialized_series() + .rolling_std(options) + .map(Column::from) +} + +#[cfg(feature = "moment")] +pub(super) fn rolling_skew(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult { + // @scalar-opt + let s = s.as_materialized_series(); + polars_ops::series::rolling_skew(s, options).map(Column::from) +} + +#[cfg(feature = "moment")] +pub(super) fn rolling_kurtosis( + s: &Column, + options: RollingOptionsFixedWindow, +) -> PolarsResult { + // @scalar-opt + let s = s.as_materialized_series(); + polars_ops::series::rolling_kurtosis(s, options).map(Column::from) +} + +#[cfg(feature = "cov")] +fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series { + match dtype { + DataType::Float64 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f64) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + DataType::Float32 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f32) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + _ => unreachable!(), + } +} + +#[cfg(feature = "cov")] +pub(super) fn rolling_corr_cov( + s: &[Column], + rolling_options: RollingOptionsFixedWindow, + cov_options: RollingCovOptions, + is_corr: bool, +) -> PolarsResult { + let mut x = s[0].as_materialized_series().rechunk(); + let mut y = s[1].as_materialized_series().rechunk(); + + if !x.dtype().is_float() { + x = x.cast(&DataType::Float64)?; + } + if !y.dtype().is_float() { + y = y.cast(&DataType::Float64)?; + } + let dtype = x.dtype().clone(); + + let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?; + let rolling_options_count = RollingOptionsFixedWindow { + window_size: rolling_options.window_size, + min_periods: 0, + ..Default::default() + }; + + let count_x_y = if (x.null_count() + y.null_count()) > 0 { + // mask out nulls on both sides before compute mean/var + let valids = x.is_not_null().bitand(y.is_not_null()); + let valids_arr = valids.downcast_as_array(); + let valids_bitmap = valids_arr.values(); + + unsafe { + let xarr = &mut x.chunks_mut()[0]; + *xarr = xarr.with_validity(Some(valids_bitmap.clone())); + let yarr = &mut y.chunks_mut()[0]; + *yarr = yarr.with_validity(Some(valids_bitmap.clone())); + x.compute_len(); + y.compute_len(); + } + valids + .cast(&dtype) + .unwrap() + .rolling_sum(rolling_options_count)? + } else { + det_count_x_y(rolling_options.window_size, x.len(), &dtype) + }; + + let mean_x = x.rolling_mean(rolling_options.clone())?; + let mean_y = y.rolling_mean(rolling_options.clone())?; + let ddof = Series::new( + PlSmallStr::EMPTY, + &[AnyValue::from(cov_options.ddof).cast(&dtype)], + ); + + let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap() + * (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap()) + .unwrap(); + + if is_corr { + let var_x = x.rolling_var(rolling_options.clone())?; + let var_y = y.rolling_var(rolling_options.clone())?; + + let base = (var_x * var_y).unwrap(); + let sc = Scalar::new( + base.dtype().clone(), + AnyValue::Float64(0.5).cast(&dtype).into_static(), + ); + let denominator = pow(&mut [base.into_column(), sc.into_column("".into())]) + .unwrap() + .unwrap() + .take_materialized_series(); + + Ok((numerator / denominator)?.into_column()) + } else { + Ok(numerator.into_column()) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/rolling_by.rs b/crates/polars-plan/src/dsl/function_expr/rolling_by.rs new file mode 100644 index 000000000000..3077c83355f2 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/rolling_by.rs @@ -0,0 +1,109 @@ +use polars_time::chunkedarray::*; + +use super::*; + +#[derive(Clone, PartialEq, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum RollingFunctionBy { + MinBy(RollingOptionsDynamicWindow), + MaxBy(RollingOptionsDynamicWindow), + MeanBy(RollingOptionsDynamicWindow), + SumBy(RollingOptionsDynamicWindow), + QuantileBy(RollingOptionsDynamicWindow), + VarBy(RollingOptionsDynamicWindow), + StdBy(RollingOptionsDynamicWindow), +} + +impl Display for RollingFunctionBy { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RollingFunctionBy::*; + + let name = match self { + MinBy(_) => "rolling_min_by", + MaxBy(_) => "rolling_max_by", + MeanBy(_) => "rolling_mean_by", + SumBy(_) => "rolling_sum_by", + QuantileBy(_) => "rolling_quantile_by", + VarBy(_) => "rolling_var_by", + StdBy(_) => "rolling_std_by", + }; + + write!(f, "{name}") + } +} + +impl Hash for RollingFunctionBy { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + } +} + +pub(super) fn rolling_min_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_min_by(s[1].as_materialized_series(), options) + .map(Column::from) +} + +pub(super) fn rolling_max_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_max_by(s[1].as_materialized_series(), options) + .map(Column::from) +} + +pub(super) fn rolling_mean_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_mean_by(s[1].as_materialized_series(), options) + .map(Column::from) +} + +pub(super) fn rolling_sum_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_sum_by(s[1].as_materialized_series(), options) + .map(Column::from) +} + +pub(super) fn rolling_quantile_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_quantile_by(s[1].as_materialized_series(), options) + .map(Column::from) +} + +pub(super) fn rolling_var_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_var_by(s[1].as_materialized_series(), options) + .map(Column::from) +} + +pub(super) fn rolling_std_by( + s: &[Column], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + // @scalar-opt + s[0].as_materialized_series() + .rolling_std_by(s[1].as_materialized_series(), options) + .map(Column::from) +} diff --git a/crates/polars-plan/src/dsl/function_expr/round.rs b/crates/polars-plan/src/dsl/function_expr/round.rs new file mode 100644 index 000000000000..e2fe75b240d5 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/round.rs @@ -0,0 +1,17 @@ +use super::*; + +pub(super) fn round(c: &Column, decimals: u32) -> PolarsResult { + c.try_apply_unary_elementwise(|s| s.round(decimals)) +} + +pub(super) fn round_sig_figs(c: &Column, digits: i32) -> PolarsResult { + c.try_apply_unary_elementwise(|s| s.round_sig_figs(digits)) +} + +pub(super) fn floor(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(Series::floor) +} + +pub(super) fn ceil(c: &Column) -> PolarsResult { + c.try_apply_unary_elementwise(Series::ceil) +} diff --git a/crates/polars-plan/src/dsl/function_expr/row_hash.rs b/crates/polars-plan/src/dsl/function_expr/row_hash.rs new file mode 100644 index 000000000000..e3e29196714b --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/row_hash.rs @@ -0,0 +1,17 @@ +use std::hash::BuildHasher; + +use polars_utils::aliases::{ + PlFixedStateQuality, PlSeedableRandomStateQuality, SeedableFromU64SeedExt, +}; + +use super::*; + +pub(super) fn row_hash(c: &Column, k0: u64, k1: u64, k2: u64, k3: u64) -> PolarsResult { + // TODO: don't expose all these seeds. + let seed = PlFixedStateQuality::default().hash_one((k0, k1, k2, k3)); + + // @scalar-opt + Ok(c.as_materialized_series() + .hash(PlSeedableRandomStateQuality::seed_from_u64(seed)) + .into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs new file mode 100644 index 000000000000..1286dc321d65 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -0,0 +1,662 @@ +use polars_core::utils::materialize_dyn_int; + +use super::*; + +impl FunctionExpr { + pub(crate) fn get_field( + &self, + _input_schema: &Schema, + _cntxt: Context, + fields: &[Field], + ) -> PolarsResult { + use FunctionExpr::*; + + let mapper = FieldsMapper { fields }; + match self { + // Namespaces + #[cfg(feature = "dtype-array")] + ArrayExpr(func) => func.get_field(mapper), + BinaryExpr(s) => s.get_field(mapper), + #[cfg(feature = "dtype-categorical")] + Categorical(func) => func.get_field(mapper), + ListExpr(func) => func.get_field(mapper), + #[cfg(feature = "strings")] + StringExpr(s) => s.get_field(mapper), + #[cfg(feature = "dtype-struct")] + StructExpr(s) => s.get_field(mapper), + #[cfg(feature = "temporal")] + TemporalExpr(fun) => fun.get_field(mapper), + #[cfg(feature = "bitwise")] + Bitwise(fun) => fun.get_field(mapper), + + // Other expressions + Boolean(func) => func.get_field(mapper), + #[cfg(feature = "business")] + Business(func) => func.get_field(mapper), + #[cfg(feature = "abs")] + Abs => mapper.with_same_dtype(), + Negate => mapper.with_same_dtype(), + NullCount => mapper.with_dtype(IDX_DTYPE), + Pow(pow_function) => match pow_function { + PowFunction::Generic => mapper.pow_dtype(), + _ => mapper.map_to_float_dtype(), + }, + Coalesce => mapper.map_to_supertype(), + #[cfg(feature = "row_hash")] + Hash(..) => mapper.with_dtype(DataType::UInt64), + #[cfg(feature = "arg_where")] + ArgWhere => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "index_of")] + IndexOf => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "search_sorted")] + SearchSorted(_) => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "range")] + Range(func) => func.get_field(mapper), + #[cfg(feature = "trigonometry")] + Trigonometry(_) => mapper.map_to_float_dtype(), + #[cfg(feature = "trigonometry")] + Atan2 => mapper.map_to_float_dtype(), + #[cfg(feature = "sign")] + Sign => mapper.with_dtype(DataType::Int64), + FillNull => mapper.map_to_supertype(), + #[cfg(feature = "rolling_window")] + RollingExpr(rolling_func, ..) => { + use RollingFunction::*; + match rolling_func { + Min(_) | Max(_) => mapper.with_same_dtype(), + Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(), + Sum(_) => mapper.sum_dtype(), + #[cfg(feature = "cov")] + CorrCov {..} => mapper.map_to_float_dtype(), + #[cfg(feature = "moment")] + Skew(..) | Kurtosis(..) => mapper.map_to_float_dtype(), + } + }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(rolling_func, ..) => { + use RollingFunctionBy::*; + match rolling_func { + MinBy(_) | MaxBy(_) => mapper.with_same_dtype(), + MeanBy(_) | QuantileBy(_) | VarBy(_) | StdBy(_) => mapper.map_to_float_dtype(), + SumBy(_) => mapper.sum_dtype(), + } + }, + ShiftAndFill => mapper.with_same_dtype(), + DropNans => mapper.with_same_dtype(), + DropNulls => mapper.with_same_dtype(), + #[cfg(feature = "round_series")] + Clip { .. } => mapper.with_same_dtype(), + #[cfg(feature = "mode")] + Mode => mapper.with_same_dtype(), + #[cfg(feature = "moment")] + Skew(_) => mapper.with_dtype(DataType::Float64), + #[cfg(feature = "moment")] + Kurtosis(..) => mapper.with_dtype(DataType::Float64), + ArgUnique => mapper.with_dtype(IDX_DTYPE), + Repeat => mapper.with_same_dtype(), + #[cfg(feature = "rank")] + Rank { options, .. } => mapper.with_dtype(match options.method { + RankMethod::Average => DataType::Float64, + _ => IDX_DTYPE, + }), + #[cfg(feature = "dtype-struct")] + AsStruct => Ok(Field::new( + fields[0].name().clone(), + DataType::Struct(fields.to_vec()), + )), + #[cfg(feature = "top_k")] + TopK { .. } => mapper.with_same_dtype(), + #[cfg(feature = "top_k")] + TopKBy { .. } => mapper.with_same_dtype(), + #[cfg(feature = "dtype-struct")] + ValueCounts { + sort: _, + parallel: _, + name, + normalize, + } => mapper.map_dtype(|dt| { + let count_dt = if *normalize { + DataType::Float64 + } else { + IDX_DTYPE + }; + DataType::Struct(vec![ + Field::new(fields[0].name().clone(), dt.clone()), + Field::new(name.clone(), count_dt), + ]) + }), + #[cfg(feature = "unique_counts")] + UniqueCounts => mapper.with_dtype(IDX_DTYPE), + Shift | Reverse => mapper.with_same_dtype(), + #[cfg(feature = "cum_agg")] + CumCount { .. } => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "cum_agg")] + CumSum { .. } => mapper.map_dtype(cum::dtypes::cum_sum), + #[cfg(feature = "cum_agg")] + CumProd { .. } => mapper.map_dtype(cum::dtypes::cum_prod), + #[cfg(feature = "cum_agg")] + CumMin { .. } => mapper.with_same_dtype(), + #[cfg(feature = "cum_agg")] + CumMax { .. } => mapper.with_same_dtype(), + #[cfg(feature = "approx_unique")] + ApproxNUnique => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "hist")] + Hist { + include_category, + include_breakpoint, + .. + } => { + if *include_breakpoint || *include_category { + let mut fields = Vec::with_capacity(3); + if *include_breakpoint { + fields.push(Field::new( + PlSmallStr::from_static("breakpoint"), + DataType::Float64, + )); + } + if *include_category { + fields.push(Field::new( + PlSmallStr::from_static("category"), + DataType::Categorical(None, Default::default()), + )); + } + fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE)); + mapper.with_dtype(DataType::Struct(fields)) + } else { + mapper.with_dtype(IDX_DTYPE) + } + }, + #[cfg(feature = "diff")] + Diff(_) => mapper.map_dtype(|dt| match dt { + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, _) => DataType::Duration(*tu), + #[cfg(feature = "dtype-date")] + DataType::Date => DataType::Duration(TimeUnit::Milliseconds), + #[cfg(feature = "dtype-time")] + DataType::Time => DataType::Duration(TimeUnit::Nanoseconds), + DataType::UInt64 | DataType::UInt32 => DataType::Int64, + DataType::UInt16 => DataType::Int32, + DataType::UInt8 => DataType::Int16, + dt => dt.clone(), + }), + #[cfg(feature = "pct_change")] + PctChange => mapper.map_dtype(|dt| match dt { + DataType::Float64 | DataType::Float32 => dt.clone(), + _ => DataType::Float64, + }), + #[cfg(feature = "interpolate")] + Interpolate(method) => match method { + InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(), + InterpolationMethod::Nearest => mapper.with_same_dtype(), + }, + #[cfg(feature = "interpolate_by")] + InterpolateBy => mapper.map_numeric_to_float_dtype(), + ShrinkType => { + // we return the smallest type this can return + // this might not be correct once the actual data + // comes in, but if we set the smallest datatype + // we have the least chance that the smaller dtypes + // get cast to larger types in type-coercion + // this will lead to an incorrect schema in polars + // but we because only the numeric types deviate in + // bit size this will likely not lead to issues + mapper.map_dtype(|dt| { + if dt.is_primitive_numeric() { + if dt.is_float() { + DataType::Float32 + } else if dt.is_unsigned_integer() { + DataType::Int8 + } else { + DataType::UInt8 + } + } else { + dt.clone() + } + }) + }, + #[cfg(feature = "log")] + Entropy { .. } | Log { .. } | Log1p | Exp => mapper.map_to_float_dtype(), + Unique(_) => mapper.with_same_dtype(), + #[cfg(feature = "round_series")] + Round { .. } | RoundSF { .. } | Floor | Ceil => mapper.with_same_dtype(), + UpperBound | LowerBound => mapper.with_same_dtype(), + #[cfg(feature = "fused")] + Fused(_) => mapper.map_to_supertype(), + ConcatExpr(_) => mapper.map_to_supertype(), + #[cfg(feature = "cov")] + Correlation { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "peaks")] + PeakMin => mapper.with_same_dtype(), + #[cfg(feature = "peaks")] + PeakMax => mapper.with_same_dtype(), + #[cfg(feature = "cutqcut")] + Cut { + include_breaks: false, + .. + } => mapper.with_dtype(DataType::Categorical(None, Default::default())), + #[cfg(feature = "cutqcut")] + Cut { + include_breaks: true, + .. + } => { + let struct_dt = DataType::Struct(vec![ + Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64), + Field::new( + PlSmallStr::from_static("category"), + DataType::Categorical(None, Default::default()), + ), + ]); + mapper.with_dtype(struct_dt) + }, + #[cfg(feature = "repeat_by")] + RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())), + #[cfg(feature = "dtype-array")] + Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| { + let dtype = dt.inner_dtype().unwrap_or(dt).clone(); + + if dims.len() == 1 { + return Ok(dtype); + } + + let num_infers = dims.iter().filter(|d| matches!(d, ReshapeDimension::Infer)).count(); + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + let mut inferred_size = 0; + if num_infers == 1 { + let mut total_size = 1u64; + let mut current = dt; + while let DataType::Array(dt, width) = current { + if *width == 0 { + total_size = 0; + break; + } + + current = dt.as_ref(); + total_size *= *width as u64; + } + + let current_size = dims.iter().map(|d| d.get_or_infer(1)).product::(); + inferred_size = total_size / current_size; + } + + let mut prev_dtype = dtype.leaf_dtype().clone(); + + // We pop the outer dimension as that is the height of the series. + for dim in &dims[1..] { + prev_dtype = DataType::Array(Box::new(prev_dtype), dim.get_or_infer(inferred_size) as usize); + } + Ok(prev_dtype) + }), + #[cfg(feature = "cutqcut")] + QCut { + include_breaks: false, + .. + } => mapper.with_dtype(DataType::Categorical(None, Default::default())), + #[cfg(feature = "cutqcut")] + QCut { + include_breaks: true, + .. + } => { + let struct_dt = DataType::Struct(vec![ + Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64), + Field::new( + PlSmallStr::from_static("category"), + DataType::Categorical(None, Default::default()), + ), + ]); + mapper.with_dtype(struct_dt) + }, + #[cfg(feature = "rle")] + RLE => mapper.map_dtype(|dt| { + DataType::Struct(vec![ + Field::new(PlSmallStr::from_static("len"), IDX_DTYPE), + Field::new(PlSmallStr::from_static("value"), dt.clone()), + ]) + }), + #[cfg(feature = "rle")] + RLEID => mapper.with_dtype(IDX_DTYPE), + ToPhysical => mapper.to_physical_type(), + #[cfg(feature = "random")] + Random { .. } => mapper.with_same_dtype(), + SetSortedFlag(_) => mapper.with_same_dtype(), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { + flags: _, + lib, + symbol, + kwargs, + } => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) }, + MaxHorizontal => mapper.map_to_supertype(), + MinHorizontal => mapper.map_to_supertype(), + SumHorizontal { .. } => { + mapper.map_to_supertype().map(|mut f| { + if f.dtype == DataType::Boolean { + f.dtype = IDX_DTYPE; + } + f + }) + }, + MeanHorizontal { .. } => { + mapper.map_to_supertype().map(|mut f| { + match f.dtype { + dt @ DataType::Float32 => { f.dtype = dt; }, + _ => { f.dtype = DataType::Float64; }, + }; + f + }) + } + #[cfg(feature = "ewma")] + EwmMean { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma_by")] + EwmMeanBy { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma")] + EwmStd { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma")] + EwmVar { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "replace")] + Replace => mapper.with_same_dtype(), + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()), + FillNullWithStrategy(_) => mapper.with_same_dtype(), + GatherEvery { .. } => mapper.with_same_dtype(), + #[cfg(feature = "reinterpret")] + Reinterpret(signed) => { + let dt = if *signed { + DataType::Int64 + } else { + DataType::UInt64 + }; + mapper.with_dtype(dt) + }, + ExtendConstant => mapper.with_same_dtype(), + } + } + + pub(crate) fn output_name(&self) -> Option { + match self { + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => { + Some(OutputName::Field(name.clone())) + }, + _ => None, + } + } +} + +pub struct FieldsMapper<'a> { + fields: &'a [Field], +} + +impl<'a> FieldsMapper<'a> { + pub fn new(fields: &'a [Field]) -> Self { + Self { fields } + } + + pub fn args(&self) -> &[Field] { + self.fields + } + + /// Field with the same dtype. + pub fn with_same_dtype(&self) -> PolarsResult { + self.map_dtype(|dtype| dtype.clone()) + } + + /// Set a dtype. + pub fn with_dtype(&self, dtype: DataType) -> PolarsResult { + Ok(Field::new(self.fields[0].name().clone(), dtype)) + } + + /// Map a single dtype. + pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult { + let dtype = func(self.fields[0].dtype()); + Ok(Field::new(self.fields[0].name().clone(), dtype)) + } + + pub fn get_fields_lens(&self) -> usize { + self.fields.len() + } + + /// Map a single field with a potentially failing mapper function. + pub fn try_map_field( + &self, + func: impl FnOnce(&Field) -> PolarsResult, + ) -> PolarsResult { + func(&self.fields[0]) + } + + /// Map to a float supertype. + pub fn map_to_float_dtype(&self) -> PolarsResult { + self.map_dtype(|dtype| match dtype { + DataType::Float32 => DataType::Float32, + _ => DataType::Float64, + }) + } + + /// Map to a float supertype if numeric, else preserve + pub fn map_numeric_to_float_dtype(&self) -> PolarsResult { + self.map_dtype(|dtype| { + if dtype.is_primitive_numeric() { + match dtype { + DataType::Float32 => DataType::Float32, + _ => DataType::Float64, + } + } else { + dtype.clone() + } + }) + } + + /// Map to a physical type. + pub fn to_physical_type(&self) -> PolarsResult { + self.map_dtype(|dtype| dtype.to_physical()) + } + + /// Map a single dtype with a potentially failing mapper function. + pub fn try_map_dtype( + &self, + func: impl FnOnce(&DataType) -> PolarsResult, + ) -> PolarsResult { + let dtype = func(self.fields[0].dtype())?; + Ok(Field::new(self.fields[0].name().clone(), dtype)) + } + + /// Map all dtypes with a potentially failing mapper function. + pub fn try_map_dtypes( + &self, + func: impl FnOnce(&[&DataType]) -> PolarsResult, + ) -> PolarsResult { + let mut fld = self.fields[0].clone(); + let dtypes = self + .fields + .iter() + .map(|fld| fld.dtype()) + .collect::>(); + let new_type = func(&dtypes)?; + fld.coerce(new_type); + Ok(fld) + } + + /// Map the dtype to the "supertype" of all fields. + pub fn map_to_supertype(&self) -> PolarsResult { + let st = args_to_supertype(self.fields)?; + let mut first = self.fields[0].clone(); + first.coerce(st); + Ok(first) + } + + /// Map the dtype to the dtype of the list/array elements. + pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult { + let mut first = self.fields[0].clone(); + let dt = first + .dtype() + .inner_dtype() + .cloned() + .unwrap_or_else(|| DataType::Unknown(Default::default())); + first.coerce(dt); + Ok(first) + } + + #[cfg(feature = "dtype-array")] + /// Map the dtype to the dtype of the array elements, with typo validation. + pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult { + let dt = self.fields[0].dtype(); + match dt { + DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(), + _ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt), + } + } + + /// Map the dtypes to the "supertype" of a list of lists. + pub fn map_to_list_supertype(&self) -> PolarsResult { + self.try_map_dtypes(|dts| { + let mut super_type_inner = None; + + for dt in dts { + match dt { + DataType::List(inner) => match super_type_inner { + None => super_type_inner = Some(*inner.clone()), + Some(st_inner) => { + super_type_inner = Some(try_get_supertype(&st_inner, inner)?) + }, + }, + dt => match super_type_inner { + None => super_type_inner = Some((*dt).clone()), + Some(st_inner) => { + super_type_inner = Some(try_get_supertype(&st_inner, dt)?) + }, + }, + } + } + Ok(DataType::List(Box::new(super_type_inner.unwrap()))) + }) + } + + /// Set the timezone of a datetime dtype. + #[cfg(feature = "timezones")] + pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult { + self.try_map_dtype(|dt| { + if let DataType::Datetime(tu, _) = dt { + Ok(DataType::Datetime(*tu, tz.cloned())) + } else { + polars_bail!(op = "replace-time-zone", got = dt, expected = "Datetime"); + } + }) + } + + pub fn sum_dtype(&self) -> PolarsResult { + use DataType::*; + self.map_dtype(|dtype| match dtype { + Int8 | UInt8 | Int16 | UInt16 => Int64, + dt => dt.clone(), + }) + } + + pub fn nested_sum_type(&self) -> PolarsResult { + let mut first = self.fields[0].clone(); + use DataType::*; + let dt = first + .dtype() + .inner_dtype() + .cloned() + .unwrap_or_else(|| Unknown(Default::default())); + + match dt { + Boolean => first.coerce(IDX_DTYPE), + UInt8 | Int8 | Int16 | UInt16 => first.coerce(Int64), + _ => first.coerce(dt), + } + Ok(first) + } + + pub fn nested_mean_median_type(&self) -> PolarsResult { + let mut first = self.fields[0].clone(); + use DataType::*; + let dt = first + .dtype() + .inner_dtype() + .cloned() + .unwrap_or_else(|| Unknown(Default::default())); + + let new_dt = match dt { + #[cfg(feature = "dtype-datetime")] + Date => Datetime(TimeUnit::Milliseconds, None), + dt if dt.is_temporal() => dt, + Float32 => Float32, + _ => Float64, + }; + first.coerce(new_dt); + Ok(first) + } + + pub(super) fn pow_dtype(&self) -> PolarsResult { + let base_dtype = self.fields[0].dtype(); + let exponent_dtype = self.fields[1].dtype(); + if base_dtype.is_integer() { + if exponent_dtype.is_float() { + Ok(Field::new( + self.fields[0].name().clone(), + exponent_dtype.clone(), + )) + } else { + Ok(Field::new( + self.fields[0].name().clone(), + base_dtype.clone(), + )) + } + } else { + Ok(Field::new( + self.fields[0].name().clone(), + base_dtype.clone(), + )) + } + } + + #[cfg(feature = "extract_jsonpath")] + pub fn with_opt_dtype(&self, dtype: Option) -> PolarsResult { + let dtype = dtype.unwrap_or_else(|| DataType::Unknown(Default::default())); + self.with_dtype(dtype) + } + + #[cfg(feature = "replace")] + pub fn replace_dtype(&self, return_dtype: Option) -> PolarsResult { + let dtype = match return_dtype { + Some(dtype) => dtype, + None => { + let new = &self.fields[2]; + let default = self.fields.get(3); + match default { + Some(default) => try_get_supertype(default.dtype(), new.dtype())?, + None => new.dtype().clone(), + } + }, + }; + self.with_dtype(dtype) + } +} + +pub(crate) fn args_to_supertype>(dtypes: &[D]) -> PolarsResult { + let mut st = dtypes[0].as_ref().clone(); + for dt in &dtypes[1..] { + st = try_get_supertype(&st, dt.as_ref())? + } + + match (dtypes[0].as_ref(), &st) { + #[cfg(feature = "dtype-categorical")] + (DataType::Categorical(_, ord), DataType::String) => st = DataType::Categorical(None, *ord), + _ => { + if let DataType::Unknown(kind) = st { + match kind { + UnknownKind::Float => st = DataType::Float64, + UnknownKind::Int(v) => { + st = materialize_dyn_int(v).dtype(); + }, + UnknownKind::Str => st = DataType::String, + _ => {}, + } + } + }, + } + + Ok(st) +} diff --git a/crates/polars-plan/src/dsl/function_expr/search_sorted.rs b/crates/polars-plan/src/dsl/function_expr/search_sorted.rs new file mode 100644 index 000000000000..38f6ef81d5c3 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/search_sorted.rs @@ -0,0 +1,14 @@ +use super::*; + +pub(super) fn search_sorted_impl(s: &mut [Column], side: SearchSortedSide) -> PolarsResult { + let sorted_array = &s[0]; + let search_value = &s[1]; + + search_sorted( + sorted_array.as_materialized_series(), + search_value.as_materialized_series(), + side, + false, + ) + .map(|ca| ca.into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs b/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs new file mode 100644 index 000000000000..811628dd07e4 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs @@ -0,0 +1,123 @@ +use polars_core::downcast_as_macro_arg_physical; + +use super::*; + +fn shift_and_fill_numeric(ca: &ChunkedArray, n: i64, fill_value: AnyValue) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: ChunkShiftFill>, +{ + let fill_value = fill_value.extract::(); + ca.shift_and_fill(n, fill_value) +} + +#[cfg(any( + feature = "object", + feature = "dtype-struct", + feature = "dtype-categorical" +))] +fn shift_and_fill_with_mask(s: &Column, n: i64, fill_value: &Column) -> PolarsResult { + use arrow::array::BooleanArray; + use arrow::bitmap::BitmapBuilder; + + let mask: BooleanChunked = if n > 0 { + let len = s.len(); + let mut bits = BitmapBuilder::with_capacity(s.len()); + bits.extend_constant(n as usize, false); + bits.extend_constant(len.saturating_sub(n as usize), true); + let mask = BooleanArray::from_data_default(bits.freeze(), None); + mask.into() + } else { + let length = s.len() as i64; + // n is negative, so subtraction. + let tipping_point = std::cmp::max(length + n, 0); + let mut bits = BitmapBuilder::with_capacity(s.len()); + bits.extend_constant(tipping_point as usize, true); + bits.extend_constant(-n as usize, false); + let mask = BooleanArray::from_data_default(bits.freeze(), None); + mask.into() + }; + s.shift(n).zip_with_same_type(&mask, fill_value) +} + +pub(super) fn shift_and_fill(args: &[Column]) -> PolarsResult { + let s = &args[0]; + let n_s = &args[1].cast(&DataType::Int64)?; + let n = n_s.i64()?; + + if let Some(n) = n.get(0) { + let logical = s.dtype(); + let physical = s.to_physical_repr(); + let fill_value_s = &args[2]; + let fill_value = fill_value_s.get(0).unwrap(); + + use DataType::*; + match logical { + Boolean => { + let ca = s.bool()?; + let fill_value = match fill_value { + AnyValue::Boolean(v) => Some(v), + AnyValue::Null => None, + v => polars_bail!(ComputeError: "fill value '{}' is not supported", v), + }; + ca.shift_and_fill(n, fill_value).into_column().cast(logical) + }, + String => { + let ca = s.str()?; + let fill_value = match fill_value { + AnyValue::String(v) => Some(v), + AnyValue::StringOwned(ref v) => Some(v.as_str()), + AnyValue::Null => None, + v => polars_bail!(ComputeError: "fill value '{}' is not supported", v), + }; + ca.shift_and_fill(n, fill_value).into_column().cast(logical) + }, + List(_) => { + let ca = s.list()?; + let fill_value = match fill_value { + AnyValue::List(v) => Some(v), + AnyValue::Null => None, + v => polars_bail!(ComputeError: "fill value '{}' is not supported", v), + }; + unsafe { + ca.shift_and_fill(n, fill_value.as_ref()) + .into_column() + .from_physical_unchecked(logical) + } + }, + #[cfg(feature = "object")] + Object(_) => shift_and_fill_with_mask(s, n, fill_value_s), + #[cfg(feature = "dtype-struct")] + Struct(_) => shift_and_fill_with_mask(s, n, fill_value_s), + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => shift_and_fill_with_mask(s, n, fill_value_s), + dt if dt.is_primitive_numeric() || dt.is_logical() => { + macro_rules! dispatch { + ($ca:expr, $n:expr, $fill_value:expr) => {{ shift_and_fill_numeric($ca, $n, $fill_value).into_column() }}; + } + let out = downcast_as_macro_arg_physical!(physical, dispatch, n, fill_value); + unsafe { out.from_physical_unchecked(logical) } + }, + dt => polars_bail!(opq = shift_and_fill, dt), + } + } else { + Ok(Column::full_null(s.name().clone(), s.len(), s.dtype())) + } +} + +pub fn shift(args: &[Column]) -> PolarsResult { + let s = &args[0]; + let n_s = &args[1]; + polars_ensure!( + n_s.len() == 1, + ComputeError: "n must be a single value." + ); + + let n_s = n_s.cast(&DataType::Int64)?; + let n = n_s.i64()?; + + match n.get(0) { + Some(n) => Ok(s.shift(n)), + None => Ok(Column::full_null(s.name().clone(), s.len(), s.dtype())), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/shrink_type.rs b/crates/polars-plan/src/dsl/function_expr/shrink_type.rs new file mode 100644 index 000000000000..0cc917c0c264 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/shrink_type.rs @@ -0,0 +1,38 @@ +use super::*; + +pub(super) fn shrink(c: Column) -> PolarsResult { + if !c.dtype().is_primitive_numeric() { + return Ok(c); + } + + if c.dtype().is_float() { + return c.cast(&DataType::Float32); + } + + if c.dtype().is_unsigned_integer() { + let max = c.max_reduce()?.value().extract::().unwrap_or(0_u64); + + if cfg!(feature = "dtype-u8") && max <= u8::MAX as u64 { + c.cast(&DataType::UInt8) + } else if cfg!(feature = "dtype-u16") && max <= u16::MAX as u64 { + c.cast(&DataType::UInt16) + } else if max <= u32::MAX as u64 { + c.cast(&DataType::UInt32) + } else { + Ok(c) + } + } else { + let min = c.min_reduce()?.value().extract::().unwrap_or(0_i64); + let max = c.max_reduce()?.value().extract::().unwrap_or(0_i64); + + if cfg!(feature = "dtype-i8") && min >= i8::MIN as i64 && max <= i8::MAX as i64 { + c.cast(&DataType::Int8) + } else if cfg!(feature = "dtype-i16") && min >= i16::MIN as i64 && max <= i16::MAX as i64 { + c.cast(&DataType::Int16) + } else if min >= i32::MIN as i64 && max <= i32::MAX as i64 { + c.cast(&DataType::Int32) + } else { + Ok(c) + } + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/sign.rs b/crates/polars-plan/src/dsl/function_expr/sign.rs new file mode 100644 index 000000000000..3b8b886b0787 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/sign.rs @@ -0,0 +1,34 @@ +use num_traits::{One, Zero}; +use polars_core::with_match_physical_numeric_polars_type; + +use super::*; + +pub(super) fn sign(s: &Column) -> PolarsResult { + let s = s.as_materialized_series(); + let dt = s.dtype(); + polars_ensure!(dt.is_primitive_numeric(), opq = sign, dt); + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref(); + Ok(sign_impl(ca)) + }) +} + +fn sign_impl(ca: &ChunkedArray) -> Column +where + T: PolarsNumericType, + ChunkedArray: IntoColumn, +{ + ca.apply_values(|x| { + if x < T::Native::zero() { + T::Native::zero() - T::Native::one() + } else if x > T::Native::zero() { + T::Native::one() + } else { + // Returning x here ensures we return NaN for NaN input, and + // maintain the sign for signed zeroes (although we don't really + // care about the latter). + x + } + }) + .into_column() +} diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs new file mode 100644 index 000000000000..a5c3a20be557 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -0,0 +1,1202 @@ +use std::borrow::Cow; + +use arrow::legacy::utils::CustomIterTools; +#[cfg(feature = "timezones")] +use polars_core::chunked_array::temporal::validate_time_zone; +use polars_core::utils::handle_casting_failures; +#[cfg(feature = "dtype-struct")] +use polars_utils::format_pl_smallstr; +#[cfg(feature = "regex")] +use regex::{NoExpand, escape}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; +use crate::{map, map_as_slice}; + +#[cfg(all(feature = "regex", feature = "timezones"))] +polars_utils::regex_cache::cached_regex! { + static TZ_AWARE_RE = r"(%z)|(%:z)|(%::z)|(%:::z)|(%#z)|(^%\+$)"; +} + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +pub enum StringFunction { + #[cfg(feature = "concat_str")] + ConcatHorizontal { + delimiter: PlSmallStr, + ignore_nulls: bool, + }, + #[cfg(feature = "concat_str")] + ConcatVertical { + delimiter: PlSmallStr, + ignore_nulls: bool, + }, + #[cfg(feature = "regex")] + Contains { + literal: bool, + strict: bool, + }, + CountMatches(bool), + EndsWith, + Extract(usize), + ExtractAll, + #[cfg(feature = "extract_groups")] + ExtractGroups { + dtype: DataType, + pat: PlSmallStr, + }, + #[cfg(feature = "regex")] + Find { + literal: bool, + strict: bool, + }, + #[cfg(feature = "string_to_integer")] + ToInteger(bool), + LenBytes, + LenChars, + Lowercase, + #[cfg(feature = "extract_jsonpath")] + JsonDecode { + dtype: Option, + infer_schema_len: Option, + }, + #[cfg(feature = "extract_jsonpath")] + JsonPathMatch, + #[cfg(feature = "regex")] + Replace { + // negative is replace all + // how many matches to replace + n: i64, + literal: bool, + }, + #[cfg(feature = "string_normalize")] + Normalize { + form: UnicodeForm, + }, + #[cfg(feature = "string_reverse")] + Reverse, + #[cfg(feature = "string_pad")] + PadStart { + length: usize, + fill_char: char, + }, + #[cfg(feature = "string_pad")] + PadEnd { + length: usize, + fill_char: char, + }, + Slice, + Head, + Tail, + #[cfg(feature = "string_encoding")] + HexEncode, + #[cfg(feature = "binary_encoding")] + HexDecode(bool), + #[cfg(feature = "string_encoding")] + Base64Encode, + #[cfg(feature = "binary_encoding")] + Base64Decode(bool), + StartsWith, + StripChars, + StripCharsStart, + StripCharsEnd, + StripPrefix, + StripSuffix, + #[cfg(feature = "dtype-struct")] + SplitExact { + n: usize, + inclusive: bool, + }, + #[cfg(feature = "dtype-struct")] + SplitN(usize), + #[cfg(feature = "temporal")] + Strptime(DataType, StrptimeOptions), + Split(bool), + #[cfg(feature = "dtype-decimal")] + ToDecimal(usize), + #[cfg(feature = "nightly")] + Titlecase, + Uppercase, + #[cfg(feature = "string_pad")] + ZFill, + #[cfg(feature = "find_many")] + ContainsAny { + ascii_case_insensitive: bool, + }, + #[cfg(feature = "find_many")] + ReplaceMany { + ascii_case_insensitive: bool, + }, + #[cfg(feature = "find_many")] + ExtractMany { + ascii_case_insensitive: bool, + overlapping: bool, + }, + #[cfg(feature = "find_many")] + FindMany { + ascii_case_insensitive: bool, + overlapping: bool, + }, + #[cfg(feature = "regex")] + EscapeRegex, +} + +impl StringFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use StringFunction::*; + match self { + #[cfg(feature = "concat_str")] + ConcatVertical { .. } | ConcatHorizontal { .. } => mapper.with_dtype(DataType::String), + #[cfg(feature = "regex")] + Contains { .. } => mapper.with_dtype(DataType::Boolean), + CountMatches(_) => mapper.with_dtype(DataType::UInt32), + EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), + Extract(_) => mapper.with_same_dtype(), + ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::String))), + #[cfg(feature = "extract_groups")] + ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()), + #[cfg(feature = "string_to_integer")] + ToInteger { .. } => mapper.with_dtype(DataType::Int64), + #[cfg(feature = "regex")] + Find { .. } => mapper.with_dtype(DataType::UInt32), + #[cfg(feature = "extract_jsonpath")] + JsonDecode { dtype, .. } => mapper.with_opt_dtype(dtype.clone()), + #[cfg(feature = "extract_jsonpath")] + JsonPathMatch => mapper.with_dtype(DataType::String), + LenBytes => mapper.with_dtype(DataType::UInt32), + LenChars => mapper.with_dtype(DataType::UInt32), + #[cfg(feature = "regex")] + Replace { .. } => mapper.with_same_dtype(), + #[cfg(feature = "string_normalize")] + Normalize { .. } => mapper.with_same_dtype(), + #[cfg(feature = "string_reverse")] + Reverse => mapper.with_same_dtype(), + #[cfg(feature = "temporal")] + Strptime(dtype, _) => mapper.with_dtype(dtype.clone()), + Split(_) => mapper.with_dtype(DataType::List(Box::new(DataType::String))), + #[cfg(feature = "nightly")] + Titlecase => mapper.with_same_dtype(), + #[cfg(feature = "dtype-decimal")] + ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)), + #[cfg(feature = "string_encoding")] + HexEncode => mapper.with_same_dtype(), + #[cfg(feature = "binary_encoding")] + HexDecode(_) => mapper.with_dtype(DataType::Binary), + #[cfg(feature = "string_encoding")] + Base64Encode => mapper.with_same_dtype(), + #[cfg(feature = "binary_encoding")] + Base64Decode(_) => mapper.with_dtype(DataType::Binary), + Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix + | StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(), + #[cfg(feature = "string_pad")] + PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(), + #[cfg(feature = "dtype-struct")] + SplitExact { n, .. } => mapper.with_dtype(DataType::Struct( + (0..n + 1) + .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String)) + .collect(), + )), + #[cfg(feature = "dtype-struct")] + SplitN(n) => mapper.with_dtype(DataType::Struct( + (0..*n) + .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String)) + .collect(), + )), + #[cfg(feature = "find_many")] + ContainsAny { .. } => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "find_many")] + ReplaceMany { .. } => mapper.with_same_dtype(), + #[cfg(feature = "find_many")] + ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))), + #[cfg(feature = "find_many")] + FindMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))), + #[cfg(feature = "regex")] + EscapeRegex => mapper.with_same_dtype(), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use StringFunction as S; + match self { + #[cfg(feature = "concat_str")] + S::ConcatHorizontal { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "concat_str")] + S::ConcatVertical { .. } => FunctionOptions::aggregation(), + #[cfg(feature = "regex")] + S::Contains { .. } => { + FunctionOptions::elementwise().with_supertyping(Default::default()) + }, + S::CountMatches(_) => FunctionOptions::elementwise(), + S::EndsWith | S::StartsWith | S::Extract(_) => { + FunctionOptions::elementwise().with_supertyping(Default::default()) + }, + S::ExtractAll => FunctionOptions::elementwise(), + #[cfg(feature = "extract_groups")] + S::ExtractGroups { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "string_to_integer")] + S::ToInteger { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "regex")] + S::Find { .. } => FunctionOptions::elementwise().with_supertyping(Default::default()), + #[cfg(feature = "extract_jsonpath")] + S::JsonDecode { dtype: Some(_), .. } => FunctionOptions::elementwise(), + // because dtype should be inferred only once and be consistent over chunks/morsels. + #[cfg(feature = "extract_jsonpath")] + S::JsonDecode { dtype: None, .. } => FunctionOptions::elementwise_with_infer(), + #[cfg(feature = "extract_jsonpath")] + S::JsonPathMatch => FunctionOptions::elementwise(), + S::LenBytes | S::LenChars => FunctionOptions::elementwise(), + #[cfg(feature = "regex")] + S::Replace { .. } => { + FunctionOptions::elementwise().with_supertyping(Default::default()) + }, + #[cfg(feature = "string_normalize")] + S::Normalize { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "string_reverse")] + S::Reverse => FunctionOptions::elementwise(), + #[cfg(feature = "temporal")] + S::Strptime(_, options) if options.format.is_some() => FunctionOptions::elementwise(), + S::Strptime(_, _) => FunctionOptions::elementwise_with_infer(), + S::Split(_) => FunctionOptions::elementwise(), + #[cfg(feature = "nightly")] + S::Titlecase => FunctionOptions::elementwise(), + #[cfg(feature = "dtype-decimal")] + S::ToDecimal(_) => FunctionOptions::elementwise_with_infer(), + #[cfg(feature = "string_encoding")] + S::HexEncode | S::Base64Encode => FunctionOptions::elementwise(), + #[cfg(feature = "binary_encoding")] + S::HexDecode(_) | S::Base64Decode(_) => FunctionOptions::elementwise(), + S::Uppercase | S::Lowercase => FunctionOptions::elementwise(), + S::StripChars + | S::StripCharsStart + | S::StripCharsEnd + | S::StripPrefix + | S::StripSuffix + | S::Head + | S::Tail => FunctionOptions::elementwise(), + S::Slice => FunctionOptions::elementwise(), + #[cfg(feature = "string_pad")] + S::PadStart { .. } | S::PadEnd { .. } | S::ZFill => FunctionOptions::elementwise(), + #[cfg(feature = "dtype-struct")] + S::SplitExact { .. } => FunctionOptions::elementwise(), + #[cfg(feature = "dtype-struct")] + S::SplitN(_) => FunctionOptions::elementwise(), + #[cfg(feature = "find_many")] + S::ContainsAny { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "find_many")] + S::ReplaceMany { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "find_many")] + S::ExtractMany { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "find_many")] + S::FindMany { .. } => FunctionOptions::groupwise(), + #[cfg(feature = "regex")] + S::EscapeRegex => FunctionOptions::elementwise(), + } + } +} + +impl Display for StringFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use StringFunction::*; + let s = match self { + #[cfg(feature = "regex")] + Contains { .. } => "contains", + CountMatches(_) => "count_matches", + EndsWith => "ends_with", + Extract(_) => "extract", + #[cfg(feature = "concat_str")] + ConcatHorizontal { .. } => "concat_horizontal", + #[cfg(feature = "concat_str")] + ConcatVertical { .. } => "concat_vertical", + ExtractAll => "extract_all", + #[cfg(feature = "extract_groups")] + ExtractGroups { .. } => "extract_groups", + #[cfg(feature = "string_to_integer")] + ToInteger { .. } => "to_integer", + #[cfg(feature = "regex")] + Find { .. } => "find", + Head => "head", + Tail => "tail", + #[cfg(feature = "extract_jsonpath")] + JsonDecode { .. } => "json_decode", + #[cfg(feature = "extract_jsonpath")] + JsonPathMatch => "json_path_match", + LenBytes => "len_bytes", + Lowercase => "lowercase", + LenChars => "len_chars", + #[cfg(feature = "string_pad")] + PadEnd { .. } => "pad_end", + #[cfg(feature = "string_pad")] + PadStart { .. } => "pad_start", + #[cfg(feature = "regex")] + Replace { .. } => "replace", + #[cfg(feature = "string_normalize")] + Normalize { .. } => "normalize", + #[cfg(feature = "string_reverse")] + Reverse => "reverse", + #[cfg(feature = "string_encoding")] + HexEncode => "hex_encode", + #[cfg(feature = "binary_encoding")] + HexDecode(_) => "hex_decode", + #[cfg(feature = "string_encoding")] + Base64Encode => "base64_encode", + #[cfg(feature = "binary_encoding")] + Base64Decode(_) => "base64_decode", + Slice => "slice", + StartsWith => "starts_with", + StripChars => "strip_chars", + StripCharsStart => "strip_chars_start", + StripCharsEnd => "strip_chars_end", + StripPrefix => "strip_prefix", + StripSuffix => "strip_suffix", + #[cfg(feature = "dtype-struct")] + SplitExact { inclusive, .. } => { + if *inclusive { + "split_exact_inclusive" + } else { + "split_exact" + } + }, + #[cfg(feature = "dtype-struct")] + SplitN(_) => "splitn", + #[cfg(feature = "temporal")] + Strptime(_, _) => "strptime", + Split(inclusive) => { + if *inclusive { + "split_inclusive" + } else { + "split" + } + }, + #[cfg(feature = "nightly")] + Titlecase => "titlecase", + #[cfg(feature = "dtype-decimal")] + ToDecimal(_) => "to_decimal", + Uppercase => "uppercase", + #[cfg(feature = "string_pad")] + ZFill => "zfill", + #[cfg(feature = "find_many")] + ContainsAny { .. } => "contains_any", + #[cfg(feature = "find_many")] + ReplaceMany { .. } => "replace_many", + #[cfg(feature = "find_many")] + ExtractMany { .. } => "extract_many", + #[cfg(feature = "find_many")] + FindMany { .. } => "extract_many", + #[cfg(feature = "regex")] + EscapeRegex => "escape_regex", + }; + write!(f, "str.{s}") + } +} + +impl From for SpecialEq> { + fn from(func: StringFunction) -> Self { + use StringFunction::*; + match func { + #[cfg(feature = "regex")] + Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict), + CountMatches(literal) => { + map_as_slice!(strings::count_matches, literal) + }, + EndsWith => map_as_slice!(strings::ends_with), + StartsWith => map_as_slice!(strings::starts_with), + Extract(group_index) => map_as_slice!(strings::extract, group_index), + ExtractAll => { + map_as_slice!(strings::extract_all) + }, + #[cfg(feature = "extract_groups")] + ExtractGroups { pat, dtype } => { + map!(strings::extract_groups, &pat, &dtype) + }, + #[cfg(feature = "regex")] + Find { literal, strict } => map_as_slice!(strings::find, literal, strict), + LenBytes => map!(strings::len_bytes), + LenChars => map!(strings::len_chars), + #[cfg(feature = "string_pad")] + PadEnd { length, fill_char } => { + map!(strings::pad_end, length, fill_char) + }, + #[cfg(feature = "string_pad")] + PadStart { length, fill_char } => { + map!(strings::pad_start, length, fill_char) + }, + #[cfg(feature = "string_pad")] + ZFill => { + map_as_slice!(strings::zfill) + }, + #[cfg(feature = "temporal")] + Strptime(dtype, options) => { + map_as_slice!(strings::strptime, dtype.clone(), &options) + }, + Split(inclusive) => { + map_as_slice!(strings::split, inclusive) + }, + #[cfg(feature = "dtype-struct")] + SplitExact { n, inclusive } => map_as_slice!(strings::split_exact, n, inclusive), + #[cfg(feature = "dtype-struct")] + SplitN(n) => map_as_slice!(strings::splitn, n), + #[cfg(feature = "concat_str")] + ConcatVertical { + delimiter, + ignore_nulls, + } => map!(strings::join, &delimiter, ignore_nulls), + #[cfg(feature = "concat_str")] + ConcatHorizontal { + delimiter, + ignore_nulls, + } => map_as_slice!(strings::concat_hor, &delimiter, ignore_nulls), + #[cfg(feature = "regex")] + Replace { n, literal } => map_as_slice!(strings::replace, literal, n), + #[cfg(feature = "string_normalize")] + Normalize { form } => map!(strings::normalize, form.clone()), + #[cfg(feature = "string_reverse")] + Reverse => map!(strings::reverse), + Uppercase => map!(uppercase), + Lowercase => map!(lowercase), + #[cfg(feature = "nightly")] + Titlecase => map!(strings::titlecase), + StripChars => map_as_slice!(strings::strip_chars), + StripCharsStart => map_as_slice!(strings::strip_chars_start), + StripCharsEnd => map_as_slice!(strings::strip_chars_end), + StripPrefix => map_as_slice!(strings::strip_prefix), + StripSuffix => map_as_slice!(strings::strip_suffix), + #[cfg(feature = "string_to_integer")] + ToInteger(strict) => map_as_slice!(strings::to_integer, strict), + Slice => map_as_slice!(strings::str_slice), + Head => map_as_slice!(strings::str_head), + Tail => map_as_slice!(strings::str_tail), + #[cfg(feature = "string_encoding")] + HexEncode => map!(strings::hex_encode), + #[cfg(feature = "binary_encoding")] + HexDecode(strict) => map!(strings::hex_decode, strict), + #[cfg(feature = "string_encoding")] + Base64Encode => map!(strings::base64_encode), + #[cfg(feature = "binary_encoding")] + Base64Decode(strict) => map!(strings::base64_decode, strict), + #[cfg(feature = "dtype-decimal")] + ToDecimal(infer_len) => map!(strings::to_decimal, infer_len), + #[cfg(feature = "extract_jsonpath")] + JsonDecode { + dtype, + infer_schema_len, + } => map!(strings::json_decode, dtype.clone(), infer_schema_len), + #[cfg(feature = "extract_jsonpath")] + JsonPathMatch => map_as_slice!(strings::json_path_match), + #[cfg(feature = "find_many")] + ContainsAny { + ascii_case_insensitive, + } => { + map_as_slice!(contains_any, ascii_case_insensitive) + }, + #[cfg(feature = "find_many")] + ReplaceMany { + ascii_case_insensitive, + } => { + map_as_slice!(replace_many, ascii_case_insensitive) + }, + #[cfg(feature = "find_many")] + ExtractMany { + ascii_case_insensitive, + overlapping, + } => { + map_as_slice!(extract_many, ascii_case_insensitive, overlapping) + }, + #[cfg(feature = "find_many")] + FindMany { + ascii_case_insensitive, + overlapping, + } => { + map_as_slice!(find_many, ascii_case_insensitive, overlapping) + }, + #[cfg(feature = "regex")] + EscapeRegex => map!(escape_regex), + } + } +} + +#[cfg(feature = "find_many")] +fn contains_any(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult { + let ca = s[0].str()?; + let patterns = s[1].str()?; + polars_ops::chunked_array::strings::contains_any(ca, patterns, ascii_case_insensitive) + .map(|out| out.into_column()) +} + +#[cfg(feature = "find_many")] +fn replace_many(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult { + let ca = s[0].str()?; + let patterns = s[1].str()?; + let replace_with = s[2].str()?; + polars_ops::chunked_array::strings::replace_all( + ca, + patterns, + replace_with, + ascii_case_insensitive, + ) + .map(|out| out.into_column()) +} + +#[cfg(feature = "find_many")] +fn extract_many( + s: &[Column], + ascii_case_insensitive: bool, + overlapping: bool, +) -> PolarsResult { + let ca = s[0].str()?; + let patterns = &s[1]; + + polars_ops::chunked_array::strings::extract_many( + ca, + patterns.as_materialized_series(), + ascii_case_insensitive, + overlapping, + ) + .map(|out| out.into_column()) +} + +#[cfg(feature = "find_many")] +fn find_many( + s: &[Column], + ascii_case_insensitive: bool, + overlapping: bool, +) -> PolarsResult { + let ca = s[0].str()?; + let patterns = &s[1]; + + polars_ops::chunked_array::strings::find_many( + ca, + patterns.as_materialized_series(), + ascii_case_insensitive, + overlapping, + ) + .map(|out| out.into_column()) +} + +fn uppercase(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.to_uppercase().into_column()) +} + +fn lowercase(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.to_lowercase().into_column()) +} + +#[cfg(feature = "nightly")] +pub(super) fn titlecase(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.to_titlecase().into_column()) +} + +pub(super) fn len_chars(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_len_chars().into_column()) +} + +pub(super) fn len_bytes(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_len_bytes().into_column()) +} + +#[cfg(feature = "regex")] +pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResult { + _check_same_length(s, "contains")?; + let ca = s[0].str()?; + let pat = s[1].str()?; + ca.contains_chunked(pat, literal, strict) + .map(|ok| ok.into_column()) +} + +#[cfg(feature = "regex")] +pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult { + _check_same_length(s, "find")?; + let ca = s[0].str()?; + let pat = s[1].str()?; + ca.find_chunked(pat, literal, strict) + .map(|ok| ok.into_column()) +} + +pub(super) fn ends_with(s: &[Column]) -> PolarsResult { + _check_same_length(s, "ends_with")?; + let ca = s[0].str()?.as_binary(); + let suffix = s[1].str()?.as_binary(); + + Ok(ca.ends_with_chunked(&suffix)?.into_column()) +} + +pub(super) fn starts_with(s: &[Column]) -> PolarsResult { + _check_same_length(s, "starts_with")?; + let ca = s[0].str()?.as_binary(); + let prefix = s[1].str()?.as_binary(); + Ok(ca.starts_with_chunked(&prefix)?.into_column()) +} + +/// Extract a regex pattern from the a string value. +pub(super) fn extract(s: &[Column], group_index: usize) -> PolarsResult { + let ca = s[0].str()?; + let pat = s[1].str()?; + ca.extract(pat, group_index).map(|ca| ca.into_column()) +} + +#[cfg(feature = "extract_groups")] +/// Extract all capture groups from a regex pattern as a struct +pub(super) fn extract_groups(s: &Column, pat: &str, dtype: &DataType) -> PolarsResult { + let ca = s.str()?; + ca.extract_groups(pat, dtype).map(Column::from) +} + +#[cfg(feature = "string_pad")] +pub(super) fn pad_start(s: &Column, length: usize, fill_char: char) -> PolarsResult { + let ca = s.str()?; + Ok(ca.pad_start(length, fill_char).into_column()) +} + +#[cfg(feature = "string_pad")] +pub(super) fn pad_end(s: &Column, length: usize, fill_char: char) -> PolarsResult { + let ca = s.str()?; + Ok(ca.pad_end(length, fill_char).into_column()) +} + +#[cfg(feature = "string_pad")] +pub(super) fn zfill(s: &[Column]) -> PolarsResult { + _check_same_length(s, "zfill")?; + let ca = s[0].str()?; + let length_s = s[1].strict_cast(&DataType::UInt64)?; + let length = length_s.u64()?; + Ok(ca.zfill(length).into_column()) +} + +pub(super) fn strip_chars(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_chars")?; + let ca = s[0].str()?; + let pat_s = &s[1]; + ca.strip_chars(pat_s).map(|ok| ok.into_column()) +} + +pub(super) fn strip_chars_start(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_chars_start")?; + let ca = s[0].str()?; + let pat_s = &s[1]; + ca.strip_chars_start(pat_s).map(|ok| ok.into_column()) +} + +pub(super) fn strip_chars_end(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_chars_end")?; + let ca = s[0].str()?; + let pat_s = &s[1]; + ca.strip_chars_end(pat_s).map(|ok| ok.into_column()) +} + +pub(super) fn strip_prefix(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_prefix")?; + let ca = s[0].str()?; + let prefix = s[1].str()?; + Ok(ca.strip_prefix(prefix).into_column()) +} + +pub(super) fn strip_suffix(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_suffix")?; + let ca = s[0].str()?; + let suffix = s[1].str()?; + Ok(ca.strip_suffix(suffix).into_column()) +} + +pub(super) fn extract_all(args: &[Column]) -> PolarsResult { + let s = &args[0]; + let pat = &args[1]; + + let ca = s.str()?; + let pat = pat.str()?; + + if pat.len() == 1 { + if let Some(pat) = pat.get(0) { + ca.extract_all(pat).map(|ca| ca.into_column()) + } else { + Ok(Column::full_null( + ca.name().clone(), + ca.len(), + &DataType::List(Box::new(DataType::String)), + )) + } + } else { + ca.extract_all_many(pat).map(|ca| ca.into_column()) + } +} + +pub(super) fn count_matches(args: &[Column], literal: bool) -> PolarsResult { + let s = &args[0]; + let pat = &args[1]; + + let ca = s.str()?; + let pat = pat.str()?; + if pat.len() == 1 { + if let Some(pat) = pat.get(0) { + ca.count_matches(pat, literal).map(|ca| ca.into_column()) + } else { + Ok(Column::full_null( + ca.name().clone(), + ca.len(), + &DataType::UInt32, + )) + } + } else { + ca.count_matches_many(pat, literal) + .map(|ca| ca.into_column()) + } +} + +#[cfg(feature = "temporal")] +pub(super) fn strptime( + s: &[Column], + dtype: DataType, + options: &StrptimeOptions, +) -> PolarsResult { + match dtype { + #[cfg(feature = "dtype-date")] + DataType::Date => to_date(&s[0], options), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(time_unit, time_zone) => { + to_datetime(s, &time_unit, time_zone.as_ref(), options) + }, + #[cfg(feature = "dtype-time")] + DataType::Time => to_time(&s[0], options), + dt => polars_bail!(ComputeError: "not implemented for dtype {}", dt), + } +} + +#[cfg(feature = "dtype-struct")] +pub(super) fn split_exact(s: &[Column], n: usize, inclusive: bool) -> PolarsResult { + let ca = s[0].str()?; + let by = s[1].str()?; + + if inclusive { + ca.split_exact_inclusive(by, n).map(|ca| ca.into_column()) + } else { + ca.split_exact(by, n).map(|ca| ca.into_column()) + } +} + +#[cfg(feature = "dtype-struct")] +pub(super) fn splitn(s: &[Column], n: usize) -> PolarsResult { + let ca = s[0].str()?; + let by = s[1].str()?; + + ca.splitn(by, n).map(|ca| ca.into_column()) +} + +pub(super) fn split(s: &[Column], inclusive: bool) -> PolarsResult { + let ca = s[0].str()?; + let by = s[1].str()?; + + if inclusive { + Ok(ca.split_inclusive(by)?.into_column()) + } else { + Ok(ca.split(by)?.into_column()) + } +} + +#[cfg(feature = "dtype-date")] +fn to_date(s: &Column, options: &StrptimeOptions) -> PolarsResult { + let ca = s.str()?; + let out = { + if options.exact { + ca.as_date(options.format.as_deref(), options.cache)? + .into_column() + } else { + ca.as_date_not_exact(options.format.as_deref())? + .into_column() + } + }; + + if options.strict && ca.null_count() != out.null_count() { + handle_casting_failures(s.as_materialized_series(), out.as_materialized_series())?; + } + Ok(out.into_column()) +} + +#[cfg(feature = "dtype-datetime")] +fn to_datetime( + s: &[Column], + time_unit: &TimeUnit, + time_zone: Option<&TimeZone>, + options: &StrptimeOptions, +) -> PolarsResult { + let datetime_strings = &s[0].str()?; + let ambiguous = &s[1].str()?; + + polars_ensure!( + datetime_strings.len() == ambiguous.len() + || datetime_strings.len() == 1 + || ambiguous.len() == 1, + length_mismatch = "str.strptime", + datetime_strings.len(), + ambiguous.len() + ); + + let tz_aware = match &options.format { + #[cfg(all(feature = "regex", feature = "timezones"))] + Some(format) => TZ_AWARE_RE.is_match(format), + _ => false, + }; + #[cfg(feature = "timezones")] + if let Some(time_zone) = time_zone { + validate_time_zone(time_zone)?; + } + let out = if options.exact { + datetime_strings + .as_datetime( + options.format.as_deref(), + *time_unit, + options.cache, + tz_aware, + time_zone, + ambiguous, + )? + .into_column() + } else { + datetime_strings + .as_datetime_not_exact( + options.format.as_deref(), + *time_unit, + tz_aware, + time_zone, + ambiguous, + )? + .into_column() + }; + + if options.strict && datetime_strings.null_count() != out.null_count() { + handle_casting_failures(s[0].as_materialized_series(), out.as_materialized_series())?; + } + Ok(out.into_column()) +} + +#[cfg(feature = "dtype-time")] +fn to_time(s: &Column, options: &StrptimeOptions) -> PolarsResult { + polars_ensure!( + options.exact, ComputeError: "non-exact not implemented for Time data type" + ); + + let ca = s.str()?; + let out = ca + .as_time(options.format.as_deref(), options.cache)? + .into_column(); + + if options.strict && ca.null_count() != out.null_count() { + handle_casting_failures(s.as_materialized_series(), out.as_materialized_series())?; + } + Ok(out.into_column()) +} + +#[cfg(feature = "concat_str")] +pub(super) fn join(s: &Column, delimiter: &str, ignore_nulls: bool) -> PolarsResult { + let str_s = s.cast(&DataType::String)?; + let joined = polars_ops::chunked_array::str_join(str_s.str()?, delimiter, ignore_nulls); + Ok(joined.into_column()) +} + +#[cfg(feature = "concat_str")] +pub(super) fn concat_hor( + series: &[Column], + delimiter: &str, + ignore_nulls: bool, +) -> PolarsResult { + let str_series: Vec<_> = series + .iter() + .map(|s| s.cast(&DataType::String)) + .collect::>()?; + let cas: Vec<_> = str_series.iter().map(|s| s.str().unwrap()).collect(); + Ok(polars_ops::chunked_array::hor_str_concat(&cas, delimiter, ignore_nulls)?.into_column()) +} + +impl From for FunctionExpr { + fn from(str: StringFunction) -> Self { + FunctionExpr::StringExpr(str) + } +} + +#[cfg(feature = "regex")] +fn get_pat(pat: &StringChunked) -> PolarsResult<&str> { + pat.get(0).ok_or_else( + || polars_err!(ComputeError: "pattern cannot be 'null' in 'replace' expression"), + ) +} + +// used only if feature="regex" +#[allow(dead_code)] +fn iter_and_replace<'a, F>(ca: &'a StringChunked, val: &'a StringChunked, f: F) -> StringChunked +where + F: Fn(&'a str, &'a str) -> Cow<'a, str>, +{ + let mut out: StringChunked = ca + .into_iter() + .zip(val) + .map(|(opt_src, opt_val)| match (opt_src, opt_val) { + (Some(src), Some(val)) => Some(f(src, val)), + _ => None, + }) + .collect_trusted(); + + out.rename(ca.name().clone()); + out +} + +#[cfg(feature = "regex")] +fn is_literal_pat(pat: &str) -> bool { + pat.chars().all(|c| !c.is_ascii_punctuation()) +} + +#[cfg(feature = "regex")] +fn replace_n<'a>( + ca: &'a StringChunked, + pat: &'a StringChunked, + val: &'a StringChunked, + literal: bool, + n: usize, +) -> PolarsResult { + match (pat.len(), val.len()) { + (1, 1) => { + let pat = get_pat(pat)?; + let val = val.get(0).ok_or_else( + || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"), + )?; + let literal = literal || is_literal_pat(pat); + + match literal { + true => ca.replace_literal(pat, val, n), + false => { + if n > 1 { + polars_bail!(ComputeError: "regex replacement with 'n > 1' not yet supported") + } + ca.replace(pat, val) + }, + } + }, + (1, len_val) => { + if n > 1 { + polars_bail!(ComputeError: "multivalue replacement with 'n > 1' not yet supported") + } + let mut pat = get_pat(pat)?.to_string(); + polars_ensure!( + len_val == ca.len(), + ComputeError: + "replacement value length ({}) does not match string column length ({})", + len_val, ca.len(), + ); + let lit = is_literal_pat(&pat); + let literal_pat = literal || lit; + + if literal_pat { + pat = escape(&pat) + } + + let reg = polars_utils::regex_cache::compile_regex(&pat)?; + + let f = |s: &'a str, val: &'a str| { + if lit && (s.len() <= 32) { + Cow::Owned(s.replacen(&pat, val, 1)) + } else { + // According to the docs for replace + // when literal = True then capture groups are ignored. + if literal { + reg.replace(s, NoExpand(val)) + } else { + reg.replace(s, val) + } + } + }; + Ok(iter_and_replace(ca, val, f)) + }, + _ => polars_bail!( + ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet" + ), + } +} + +#[cfg(feature = "regex")] +fn replace_all<'a>( + ca: &'a StringChunked, + pat: &'a StringChunked, + val: &'a StringChunked, + literal: bool, +) -> PolarsResult { + match (pat.len(), val.len()) { + (1, 1) => { + let pat = get_pat(pat)?; + let val = val.get(0).ok_or_else( + || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"), + )?; + let literal = literal || is_literal_pat(pat); + + match literal { + true => ca.replace_literal_all(pat, val), + false => ca.replace_all(pat, val), + } + }, + (1, len_val) => { + let mut pat = get_pat(pat)?.to_string(); + polars_ensure!( + len_val == ca.len(), + ComputeError: + "replacement value length ({}) does not match string column length ({})", + len_val, ca.len(), + ); + + let literal_pat = literal || is_literal_pat(&pat); + + if literal_pat { + pat = escape(&pat) + } + + let reg = polars_utils::regex_cache::compile_regex(&pat)?; + + let f = |s: &'a str, val: &'a str| { + // According to the docs for replace_all + // when literal = True then capture groups are ignored. + if literal { + reg.replace_all(s, NoExpand(val)) + } else { + reg.replace_all(s, val) + } + }; + + Ok(iter_and_replace(ca, val, f)) + }, + _ => polars_bail!( + ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet" + ), + } +} + +#[cfg(feature = "regex")] +pub(super) fn replace(s: &[Column], literal: bool, n: i64) -> PolarsResult { + let column = &s[0]; + let pat = &s[1]; + let val = &s[2]; + let all = n < 0; + + let column = column.str()?; + let pat = pat.str()?; + let val = val.str()?; + + if all { + replace_all(column, pat, val, literal) + } else { + replace_n(column, pat, val, literal, n as usize) + } + .map(|ca| ca.into_column()) +} + +#[cfg(feature = "string_normalize")] +pub(super) fn normalize(s: &Column, form: UnicodeForm) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_normalize(form).into_column()) +} + +#[cfg(feature = "string_reverse")] +pub(super) fn reverse(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_reverse().into_column()) +} + +#[cfg(feature = "string_to_integer")] +pub(super) fn to_integer(s: &[Column], strict: bool) -> PolarsResult { + let ca = s[0].str()?; + let base = s[1].strict_cast(&DataType::UInt32)?; + ca.to_integer(base.u32()?, strict) + .map(|ok| ok.into_column()) +} + +fn _ensure_lengths(s: &[Column]) -> bool { + // Calculate the post-broadcast length and ensure everything is consistent. + let len = s + .iter() + .map(|series| series.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + s.iter() + .all(|series| series.len() == 1 || series.len() == len) +} + +fn _check_same_length(s: &[Column], fn_name: &str) -> Result<(), PolarsError> { + polars_ensure!( + _ensure_lengths(s), + ShapeMismatch: "all series in `str.{}()` should have equal or unit length", + fn_name + ); + Ok(()) +} + +pub(super) fn str_slice(s: &[Column]) -> PolarsResult { + _check_same_length(s, "slice")?; + let ca = s[0].str()?; + let offset = &s[1]; + let length = &s[2]; + Ok(ca.str_slice(offset, length)?.into_column()) +} + +pub(super) fn str_head(s: &[Column]) -> PolarsResult { + _check_same_length(s, "head")?; + let ca = s[0].str()?; + let n = &s[1]; + Ok(ca.str_head(n)?.into_column()) +} + +pub(super) fn str_tail(s: &[Column]) -> PolarsResult { + _check_same_length(s, "tail")?; + let ca = s[0].str()?; + let n = &s[1]; + Ok(ca.str_tail(n)?.into_column()) +} + +#[cfg(feature = "string_encoding")] +pub(super) fn hex_encode(s: &Column) -> PolarsResult { + Ok(s.str()?.hex_encode().into_column()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult { + s.str()?.hex_decode(strict).map(|ca| ca.into_column()) +} + +#[cfg(feature = "string_encoding")] +pub(super) fn base64_encode(s: &Column) -> PolarsResult { + Ok(s.str()?.base64_encode().into_column()) +} + +#[cfg(feature = "binary_encoding")] +pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult { + s.str()?.base64_decode(strict).map(|ca| ca.into_column()) +} + +#[cfg(feature = "dtype-decimal")] +pub(super) fn to_decimal(s: &Column, infer_len: usize) -> PolarsResult { + let ca = s.str()?; + ca.to_decimal(infer_len).map(Column::from) +} + +#[cfg(feature = "extract_jsonpath")] +pub(super) fn json_decode( + s: &Column, + dtype: Option, + infer_schema_len: Option, +) -> PolarsResult { + let ca = s.str()?; + ca.json_decode(dtype, infer_schema_len).map(Column::from) +} + +#[cfg(feature = "extract_jsonpath")] +pub(super) fn json_path_match(s: &[Column]) -> PolarsResult { + _check_same_length(s, "json_path_match")?; + let ca = s[0].str()?; + let pat = s[1].str()?; + Ok(ca.json_path_match(pat)?.into_column()) +} + +#[cfg(feature = "regex")] +pub(super) fn escape_regex(s: &Column) -> PolarsResult { + let ca = s.str()?; + Ok(ca.str_escape_regex().into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/struct_.rs b/crates/polars-plan/src/dsl/function_expr/struct_.rs new file mode 100644 index 000000000000..a0644a989185 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/struct_.rs @@ -0,0 +1,269 @@ +use polars_core::utils::slice_offsets; +use polars_utils::format_pl_smallstr; + +use super::*; +use crate::{map, map_as_slice}; + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum StructFunction { + FieldByIndex(i64), + FieldByName(PlSmallStr), + RenameFields(Arc<[PlSmallStr]>), + PrefixFields(PlSmallStr), + SuffixFields(PlSmallStr), + #[cfg(feature = "json")] + JsonEncode, + WithFields, + MultipleFields(Arc<[PlSmallStr]>), +} + +impl StructFunction { + pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { + use StructFunction::*; + + match self { + FieldByIndex(index) => mapper.try_map_field(|field| { + let (index, _) = slice_offsets(*index, 0, mapper.get_fields_lens()); + if let DataType::Struct(ref fields) = field.dtype { + fields.get(index).cloned().ok_or_else( + || polars_err!(ComputeError: "index out of bounds in `struct.field`"), + ) + } else { + polars_bail!( + ComputeError: "expected struct dtype, got: `{}`", &field.dtype + ) + } + }), + FieldByName(name) => mapper.try_map_field(|field| { + if let DataType::Struct(ref fields) = field.dtype { + let fld = fields + .iter() + .find(|fld| fld.name() == name) + .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))?; + Ok(fld.clone()) + } else { + polars_bail!(StructFieldNotFound: "{}", name); + } + }), + RenameFields(names) => mapper.map_dtype(|dt| match dt { + DataType::Struct(fields) => { + let fields = fields + .iter() + .zip(names.as_ref()) + .map(|(fld, name)| Field::new(name.clone(), fld.dtype().clone())) + .collect(); + DataType::Struct(fields) + }, + // The types will be incorrect, but its better than nothing + // we can get an incorrect type with python lambdas, because we only know return type when running + // the query + dt => DataType::Struct( + names + .iter() + .map(|name| Field::new(name.clone(), dt.clone())) + .collect(), + ), + }), + PrefixFields(prefix) => mapper.try_map_dtype(|dt| match dt { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|fld| { + let name = fld.name(); + Field::new(format_pl_smallstr!("{prefix}{name}"), fld.dtype().clone()) + }) + .collect(); + Ok(DataType::Struct(fields)) + }, + _ => polars_bail!(op = "prefix_fields", got = dt, expected = "Struct"), + }), + SuffixFields(suffix) => mapper.try_map_dtype(|dt| match dt { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|fld| { + let name = fld.name(); + Field::new(format_pl_smallstr!("{name}{suffix}"), fld.dtype().clone()) + }) + .collect(); + Ok(DataType::Struct(fields)) + }, + _ => polars_bail!(op = "suffix_fields", got = dt, expected = "Struct"), + }), + #[cfg(feature = "json")] + JsonEncode => mapper.with_dtype(DataType::String), + WithFields => { + let args = mapper.args(); + let struct_ = &args[0]; + + if let DataType::Struct(fields) = struct_.dtype() { + let mut name_2_dtype = PlIndexMap::with_capacity(fields.len() * 2); + + for field in fields { + name_2_dtype.insert(field.name(), field.dtype()); + } + for arg in &args[1..] { + name_2_dtype.insert(arg.name(), arg.dtype()); + } + let dtype = DataType::Struct( + name_2_dtype + .iter() + .map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone())) + .collect(), + ); + let mut out = struct_.clone(); + out.coerce(dtype); + Ok(out) + } else { + let dt = struct_.dtype(); + polars_bail!(op = "with_fields", got = dt, expected = "Struct") + } + }, + MultipleFields(_) => panic!("should be expanded"), + } + } + + pub fn function_options(&self) -> FunctionOptions { + use StructFunction as S; + match self { + S::FieldByIndex(_) | S::FieldByName(_) => { + FunctionOptions::elementwise().with_allow_rename(true) + }, + S::RenameFields(_) | S::PrefixFields(_) | S::SuffixFields(_) => { + FunctionOptions::elementwise() + }, + #[cfg(feature = "json")] + S::JsonEncode => FunctionOptions::elementwise(), + S::WithFields => FunctionOptions::elementwise() + .with_pass_name_to_apply(true) + .with_input_wildcard_expansion(true), + S::MultipleFields(_) => FunctionOptions::elementwise().with_allow_rename(true), + } + } +} + +impl Display for StructFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use StructFunction::*; + match self { + FieldByIndex(index) => write!(f, "struct.field_by_index({index})"), + FieldByName(name) => write!(f, "struct.field_by_name({name})"), + RenameFields(names) => write!(f, "struct.rename_fields({:?})", names), + PrefixFields(_) => write!(f, "name.prefix_fields"), + SuffixFields(_) => write!(f, "name.suffixFields"), + #[cfg(feature = "json")] + JsonEncode => write!(f, "struct.to_json"), + WithFields => write!(f, "with_fields"), + MultipleFields(_) => write!(f, "multiple_fields"), + } + } +} + +impl From for SpecialEq> { + fn from(func: StructFunction) -> Self { + use StructFunction::*; + match func { + FieldByIndex(_) => panic!("should be replaced"), + FieldByName(name) => map!(get_by_name, &name), + RenameFields(names) => map!(rename_fields, names.clone()), + PrefixFields(prefix) => map!(prefix_fields, prefix.as_str()), + SuffixFields(suffix) => map!(suffix_fields, suffix.as_str()), + #[cfg(feature = "json")] + JsonEncode => map!(to_json), + WithFields => map_as_slice!(with_fields), + MultipleFields(_) => unimplemented!(), + } + } +} + +pub(super) fn get_by_name(s: &Column, name: &str) -> PolarsResult { + let ca = s.struct_()?; + ca.field_by_name(name).map(Column::from) +} + +pub(super) fn rename_fields(s: &Column, names: Arc<[PlSmallStr]>) -> PolarsResult { + let ca = s.struct_()?; + let fields = ca + .fields_as_series() + .iter() + .zip(names.as_ref()) + .map(|(s, name)| { + let mut s = s.clone(); + s.rename(name.clone()); + s + }) + .collect::>(); + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?; + out.zip_outer_validity(ca); + Ok(out.into_column()) +} + +pub(super) fn prefix_fields(s: &Column, prefix: &str) -> PolarsResult { + let ca = s.struct_()?; + let fields = ca + .fields_as_series() + .iter() + .map(|s| { + let mut s = s.clone(); + let name = s.name(); + s.rename(format_pl_smallstr!("{prefix}{name}")); + s + }) + .collect::>(); + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?; + out.zip_outer_validity(ca); + Ok(out.into_column()) +} + +pub(super) fn suffix_fields(s: &Column, suffix: &str) -> PolarsResult { + let ca = s.struct_()?; + let fields = ca + .fields_as_series() + .iter() + .map(|s| { + let mut s = s.clone(); + let name = s.name(); + s.rename(format_pl_smallstr!("{name}{suffix}")); + s + }) + .collect::>(); + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())?; + out.zip_outer_validity(ca); + Ok(out.into_column()) +} + +#[cfg(feature = "json")] +pub(super) fn to_json(s: &Column) -> PolarsResult { + let ca = s.struct_()?; + let dtype = ca.dtype().to_arrow(CompatLevel::newest()); + + let iter = ca.chunks().iter().map(|arr| { + let arr = polars_compute::cast::cast_unchecked(arr.as_ref(), &dtype).unwrap(); + polars_json::json::write::serialize_to_utf8(arr.as_ref()) + }); + + Ok(StringChunked::from_chunk_iter(ca.name().clone(), iter).into_column()) +} + +pub(super) fn with_fields(args: &[Column]) -> PolarsResult { + let s = &args[0]; + + let ca = s.struct_()?; + let current = ca.fields_as_series(); + + let mut fields = PlIndexMap::with_capacity(current.len() + s.len() - 1); + + for field in current.iter() { + fields.insert(field.name(), field); + } + + for field in &args[1..] { + fields.insert(field.name(), field.as_materialized_series()); + } + + let new_fields = fields.into_values().cloned().collect::>(); + let mut out = StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?; + out.zip_outer_validity(ca); + Ok(out.into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/temporal.rs b/crates/polars-plan/src/dsl/function_expr/temporal.rs new file mode 100644 index 000000000000..c3017ad516e6 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/temporal.rs @@ -0,0 +1,202 @@ +#[cfg(feature = "timezones")] +use polars_core::chunked_array::temporal::validate_time_zone; + +use super::*; +use crate::{map, map_as_slice}; + +impl From for SpecialEq> { + fn from(func: TemporalFunction) -> Self { + use TemporalFunction::*; + match func { + Millennium => map!(datetime::millennium), + Century => map!(datetime::century), + Year => map!(datetime::year), + IsLeapYear => map!(datetime::is_leap_year), + IsoYear => map!(datetime::iso_year), + Month => map!(datetime::month), + Quarter => map!(datetime::quarter), + Week => map!(datetime::week), + WeekDay => map!(datetime::weekday), + Duration(tu) => map_as_slice!(impl_duration, tu), + Day => map!(datetime::day), + OrdinalDay => map!(datetime::ordinal_day), + Time => map!(datetime::time), + Date => map!(datetime::date), + Datetime => map!(datetime::datetime), + Hour => map!(datetime::hour), + Minute => map!(datetime::minute), + Second => map!(datetime::second), + Millisecond => map!(datetime::millisecond), + Microsecond => map!(datetime::microsecond), + Nanosecond => map!(datetime::nanosecond), + TotalDays => map!(datetime::total_days), + TotalHours => map!(datetime::total_hours), + TotalMinutes => map!(datetime::total_minutes), + TotalSeconds => map!(datetime::total_seconds), + TotalMilliseconds => map!(datetime::total_milliseconds), + TotalMicroseconds => map!(datetime::total_microseconds), + TotalNanoseconds => map!(datetime::total_nanoseconds), + ToString(format) => map!(datetime::to_string, &format), + TimeStamp(tu) => map!(datetime::timestamp, tu), + #[cfg(feature = "timezones")] + ConvertTimeZone(tz) => map!(datetime::convert_time_zone, &tz), + WithTimeUnit(tu) => map!(datetime::with_time_unit, tu), + CastTimeUnit(tu) => map!(datetime::cast_time_unit, tu), + Truncate => { + map_as_slice!(datetime::truncate) + }, + #[cfg(feature = "offset_by")] + OffsetBy => { + map_as_slice!(datetime::offset_by) + }, + #[cfg(feature = "month_start")] + MonthStart => map!(datetime::month_start), + #[cfg(feature = "month_end")] + MonthEnd => map!(datetime::month_end), + #[cfg(feature = "timezones")] + BaseUtcOffset => map!(datetime::base_utc_offset), + #[cfg(feature = "timezones")] + DSTOffset => map!(datetime::dst_offset), + Round => map_as_slice!(datetime::round), + Replace => map_as_slice!(datetime::replace), + #[cfg(feature = "timezones")] + ReplaceTimeZone(tz, non_existent) => { + map_as_slice!(dispatch::replace_time_zone, tz.as_deref(), non_existent) + }, + Combine(tu) => map_as_slice!(temporal::combine, tu), + DatetimeFunction { + time_unit, + time_zone, + } => { + map_as_slice!(temporal::datetime, &time_unit, time_zone.as_deref()) + }, + } + } +} + +#[cfg(feature = "dtype-datetime")] +pub(super) fn datetime( + s: &[Column], + time_unit: &TimeUnit, + time_zone: Option<&str>, +) -> PolarsResult { + let col_name = PlSmallStr::from_static("datetime"); + + if s.iter().any(|s| s.is_empty()) { + return Ok(Column::new_empty( + col_name, + &DataType::Datetime( + time_unit.to_owned(), + match time_zone { + #[cfg(feature = "timezones")] + Some(time_zone) => { + validate_time_zone(time_zone)?; + Some(PlSmallStr::from_str(time_zone)) + }, + _ => { + assert!( + time_zone.is_none(), + "cannot make use of the `time_zone` argument without the 'timezones' feature enabled." + ); + None + }, + }, + ), + )); + } + + let year = &s[0]; + let month = &s[1]; + let day = &s[2]; + let hour = &s[3]; + let minute = &s[4]; + let second = &s[5]; + let microsecond = &s[6]; + let ambiguous = &s[7]; + + let max_len = s.iter().map(|s| s.len()).max().unwrap(); + + let mut year = year.cast(&DataType::Int32)?; + if year.len() < max_len { + year = year.new_from_index(0, max_len) + } + let year = year.i32()?; + + let mut month = month.cast(&DataType::Int8)?; + if month.len() < max_len { + month = month.new_from_index(0, max_len); + } + let month = month.i8()?; + + let mut day = day.cast(&DataType::Int8)?; + if day.len() < max_len { + day = day.new_from_index(0, max_len); + } + let day = day.i8()?; + + let mut hour = hour.cast(&DataType::Int8)?; + if hour.len() < max_len { + hour = hour.new_from_index(0, max_len); + } + let hour = hour.i8()?; + + let mut minute = minute.cast(&DataType::Int8)?; + if minute.len() < max_len { + minute = minute.new_from_index(0, max_len); + } + let minute = minute.i8()?; + + let mut second = second.cast(&DataType::Int8)?; + if second.len() < max_len { + second = second.new_from_index(0, max_len); + } + let second = second.i8()?; + + let mut nanosecond = microsecond.cast(&DataType::Int32)? * 1_000; + if nanosecond.len() < max_len { + nanosecond = nanosecond.new_from_index(0, max_len); + } + let nanosecond = nanosecond.i32()?; + + let mut _ambiguous = ambiguous.cast(&DataType::String)?; + if _ambiguous.len() < max_len { + _ambiguous = _ambiguous.new_from_index(0, max_len); + } + let ambiguous = _ambiguous.str()?; + + let ca = DatetimeChunked::new_from_parts( + year, month, day, hour, minute, second, nanosecond, ambiguous, time_unit, time_zone, + col_name, + ); + ca.map(|s| s.into_column()) +} + +pub(super) fn combine(s: &[Column], tu: TimeUnit) -> PolarsResult { + let date = &s[0]; + let time = &s[1]; + + let tz = match date.dtype() { + DataType::Date => None, + DataType::Datetime(_, tz) => tz.as_ref(), + _dtype => { + polars_bail!(ComputeError: format!("expected Date or Datetime, got {}", _dtype)) + }, + }; + + let date = date.cast(&DataType::Date)?; + let datetime = date.cast(&DataType::Datetime(tu, None)).unwrap(); + + let duration = time.cast(&DataType::Duration(tu))?; + let result_naive = datetime + duration; + match tz { + #[cfg(feature = "timezones")] + Some(tz) => Ok(polars_ops::prelude::replace_time_zone( + result_naive?.datetime().unwrap(), + Some(tz), + &StringChunked::from_iter(std::iter::once("raise")), + NonExistent::Raise, + )? + .into_column()), + _ => result_naive, + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/trigonometry.rs b/crates/polars-plan/src/dsl/function_expr/trigonometry.rs new file mode 100644 index 000000000000..269263a073d5 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/trigonometry.rs @@ -0,0 +1,291 @@ +use num_traits::Float; +use polars_core::chunked_array::ops::arity::broadcast_binary_elementwise; + +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +pub enum TrigonometricFunction { + Cos, + Cot, + Sin, + Tan, + ArcCos, + ArcSin, + ArcTan, + Cosh, + Sinh, + Tanh, + ArcCosh, + ArcSinh, + ArcTanh, + Degrees, + Radians, +} + +impl Display for TrigonometricFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use self::*; + match self { + TrigonometricFunction::Cos => write!(f, "cos"), + TrigonometricFunction::Cot => write!(f, "cot"), + TrigonometricFunction::Sin => write!(f, "sin"), + TrigonometricFunction::Tan => write!(f, "tan"), + TrigonometricFunction::ArcCos => write!(f, "arccos"), + TrigonometricFunction::ArcSin => write!(f, "arcsin"), + TrigonometricFunction::ArcTan => write!(f, "arctan"), + TrigonometricFunction::Cosh => write!(f, "cosh"), + TrigonometricFunction::Sinh => write!(f, "sinh"), + TrigonometricFunction::Tanh => write!(f, "tanh"), + TrigonometricFunction::ArcCosh => write!(f, "arccosh"), + TrigonometricFunction::ArcSinh => write!(f, "arcsinh"), + TrigonometricFunction::ArcTanh => write!(f, "arctanh"), + TrigonometricFunction::Degrees => write!(f, "degrees"), + TrigonometricFunction::Radians => write!(f, "radians"), + } + } +} + +impl From for FunctionExpr { + fn from(value: TrigonometricFunction) -> Self { + Self::Trigonometry(value) + } +} + +pub(super) fn apply_trigonometric_function( + s: &Column, + trig_function: TrigonometricFunction, +) -> PolarsResult { + use DataType::*; + match s.dtype() { + Float32 => { + let ca = s.f32().unwrap(); + apply_trigonometric_function_to_float(ca, trig_function) + }, + Float64 => { + let ca = s.f64().unwrap(); + apply_trigonometric_function_to_float(ca, trig_function) + }, + dt if dt.is_primitive_numeric() => { + let s = s.cast(&Float64)?; + apply_trigonometric_function(&s, trig_function) + }, + dt => polars_bail!(op = "trigonometry", dt), + } +} + +pub(super) fn apply_arctan2(s: &mut [Column]) -> PolarsResult> { + let y = &s[0]; + let x = &s[1]; + + let y_len = y.len(); + let x_len = x.len(); + + match (y_len, x_len) { + (1, _) | (_, 1) => arctan2_on_columns(y, x), + (len_a, len_b) if len_a == len_b => arctan2_on_columns(y, x), + _ => polars_bail!( + ComputeError: + "y shape: {} in `arctan2` expression does not match that of x: {}", + y_len, x_len, + ), + } +} + +fn arctan2_on_columns(y: &Column, x: &Column) -> PolarsResult> { + use DataType::*; + match y.dtype() { + Float32 => { + let y_ca: &ChunkedArray = y.f32().unwrap(); + arctan2_on_floats(y_ca, x) + }, + Float64 => { + let y_ca: &ChunkedArray = y.f64().unwrap(); + arctan2_on_floats(y_ca, x) + }, + _ => { + let y = y.cast(&DataType::Float64)?; + arctan2_on_columns(&y, x) + }, + } +} + +fn arctan2_on_floats(y: &ChunkedArray, x: &Column) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + let dtype = T::get_dtype(); + let x = x.cast(&dtype)?; + let x = y + .unpack_series_matching_type(x.as_materialized_series()) + .unwrap(); + + Ok(Some( + broadcast_binary_elementwise(y, x, |yv, xv| Some(yv?.atan2(xv?))).into_column(), + )) +} + +fn apply_trigonometric_function_to_float( + ca: &ChunkedArray, + trig_function: TrigonometricFunction, +) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + match trig_function { + TrigonometricFunction::Cos => cos(ca), + TrigonometricFunction::Cot => cot(ca), + TrigonometricFunction::Sin => sin(ca), + TrigonometricFunction::Tan => tan(ca), + TrigonometricFunction::ArcCos => arccos(ca), + TrigonometricFunction::ArcSin => arcsin(ca), + TrigonometricFunction::ArcTan => arctan(ca), + TrigonometricFunction::Cosh => cosh(ca), + TrigonometricFunction::Sinh => sinh(ca), + TrigonometricFunction::Tanh => tanh(ca), + TrigonometricFunction::ArcCosh => arccosh(ca), + TrigonometricFunction::ArcSinh => arcsinh(ca), + TrigonometricFunction::ArcTanh => arctanh(ca), + TrigonometricFunction::Degrees => degrees(ca), + TrigonometricFunction::Radians => radians(ca), + } +} + +fn cos(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.cos()).into_column()) +} + +fn cot(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.tan().powi(-1)).into_column()) +} + +fn sin(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.sin()).into_column()) +} + +fn tan(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.tan()).into_column()) +} + +fn arccos(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.acos()).into_column()) +} + +fn arcsin(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.asin()).into_column()) +} + +fn arctan(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.atan()).into_column()) +} + +fn cosh(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.cosh()).into_column()) +} + +fn sinh(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.sinh()).into_column()) +} + +fn tanh(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.tanh()).into_column()) +} + +fn arccosh(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.acosh()).into_column()) +} + +fn arcsinh(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.asinh()).into_column()) +} + +fn arctanh(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.atanh()).into_column()) +} + +fn degrees(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.to_degrees()).into_column()) +} + +fn radians(ca: &ChunkedArray) -> PolarsResult +where + T: PolarsFloatType, + T::Native: Float, + ChunkedArray: IntoColumn, +{ + Ok(ca.apply_values(|v| v.to_radians()).into_column()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/unique.rs b/crates/polars-plan/src/dsl/function_expr/unique.rs new file mode 100644 index 000000000000..c9f22a841f37 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/unique.rs @@ -0,0 +1,9 @@ +use super::*; + +pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult { + if stable { + s.unique_stable() + } else { + s.unique() + } +} diff --git a/crates/polars-plan/src/dsl/functions/arity.rs b/crates/polars-plan/src/dsl/functions/arity.rs new file mode 100644 index 000000000000..e3fb7b884885 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/arity.rs @@ -0,0 +1,34 @@ +use super::*; + +macro_rules! prepare_binary_function { + ($f:ident) => { + move |c: &mut [Column]| { + let s0 = std::mem::take(&mut c[0]); + let s1 = std::mem::take(&mut c[1]); + + $f(s0, s1) + } + }; +} + +/// Apply a closure on the two columns that are evaluated from [`Expr`] a and [`Expr`] b. +/// +/// The closure takes two arguments, each a [`Series`]. `output_type` must be the output dtype of the resulting [`Series`]. +pub fn map_binary(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr +where + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, +{ + let function = prepare_binary_function!(f); + a.map_many(function, &[b], output_type) +} + +/// Like [`map_binary`], but used in a group_by-aggregation context. +/// +/// See [`Expr::apply`] for the difference between [`map`](Expr::map) and [`apply`](Expr::apply). +pub fn apply_binary(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr +where + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, +{ + let function = prepare_binary_function!(f); + a.apply_many(function, &[b], output_type) +} diff --git a/crates/polars-plan/src/dsl/functions/business.rs b/crates/polars-plan/src/dsl/functions/business.rs new file mode 100644 index 000000000000..f3c85863dde9 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/business.rs @@ -0,0 +1,24 @@ +use super::*; + +#[cfg(feature = "dtype-date")] +pub fn business_day_count( + start: Expr, + end: Expr, + week_mask: [bool; 7], + holidays: Vec, +) -> Expr { + let input = vec![start, end]; + + Expr::Function { + input, + function: FunctionExpr::Business(BusinessFunction::BusinessDayCount { + week_mask, + holidays, + }), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, + ..Default::default() + }, + } +} diff --git a/crates/polars-plan/src/dsl/functions/coerce.rs b/crates/polars-plan/src/dsl/functions/coerce.rs new file mode 100644 index 000000000000..80adc959f2e1 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/coerce.rs @@ -0,0 +1,12 @@ +use super::*; + +/// Take several expressions and collect them into a [`StructChunked`]. +/// # Panics +/// panics if `exprs` is empty. +pub fn as_struct(exprs: Vec) -> Expr { + assert!( + !exprs.is_empty(), + "expected at least 1 field in 'as_struct'" + ); + Expr::n_ary(FunctionExpr::AsStruct, exprs) +} diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs new file mode 100644 index 000000000000..0cefccde7bcf --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/concat.rs @@ -0,0 +1,96 @@ +use super::*; + +#[cfg(all(feature = "concat_str", feature = "strings"))] +/// Horizontally concat string columns in linear time +pub fn concat_str>(s: E, separator: &str, ignore_nulls: bool) -> Expr { + let input = s.as_ref().to_vec(); + let separator = separator.into(); + + Expr::Function { + input, + function: StringFunction::ConcatHorizontal { + delimiter: separator, + ignore_nulls, + } + .into(), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR, + ..Default::default() + }, + } +} + +#[cfg(all(feature = "concat_str", feature = "strings"))] +/// Format the results of an array of expressions using a format string +pub fn format_str>(format: &str, args: E) -> PolarsResult { + let mut args: std::collections::VecDeque = args.as_ref().to_vec().into(); + + // Parse the format string, and separate substrings between placeholders + let segments: Vec<&str> = format.split("{}").collect(); + + polars_ensure!( + segments.len() - 1 == args.len(), + ShapeMismatch: "number of placeholders should equal the number of arguments" + ); + + let mut exprs: Vec = Vec::new(); + + for (i, s) in segments.iter().enumerate() { + if i > 0 { + if let Some(arg) = args.pop_front() { + exprs.push(arg); + } + } + + if !s.is_empty() { + exprs.push(lit(s.to_string())) + } + } + + Ok(concat_str(exprs, "", false)) +} + +/// Concat lists entries. +pub fn concat_list, IE: Into + Clone>(s: E) -> PolarsResult { + let s: Vec<_> = s.as_ref().iter().map(|e| e.clone().into()).collect(); + + polars_ensure!(!s.is_empty(), ComputeError: "`concat_list` needs one or more expressions"); + + Ok(Expr::Function { + input: s, + function: FunctionExpr::ListExpr(ListFunction::Concat), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, + ..Default::default() + }, + }) +} + +/// Horizontally concatenate columns into a single array-type column. +pub fn concat_arr(input: Vec) -> PolarsResult { + feature_gated!("dtype-array", { + polars_ensure!(!input.is_empty(), ComputeError: "`concat_arr` needs one or more expressions"); + + Ok(Expr::Function { + input, + function: FunctionExpr::ArrayExpr(ArrayFunction::Concat), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION, + ..Default::default() + }, + }) + }) +} + +pub fn concat_expr, IE: Into + Clone>( + s: E, + rechunk: bool, +) -> PolarsResult { + let s: Vec<_> = s.as_ref().iter().map(|e| e.clone().into()).collect(); + polars_ensure!(!s.is_empty(), ComputeError: "`concat_expr` needs one or more expressions"); + Ok(Expr::n_ary(FunctionExpr::ConcatExpr(rechunk), s)) +} diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs new file mode 100644 index 000000000000..3ddd76b73772 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -0,0 +1,62 @@ +use super::*; + +/// Compute the covariance between two columns. +pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr { + let function = FunctionExpr::Correlation { + method: CorrelationMethod::Covariance(ddof), + }; + a.map_binary(function, b) +} + +/// Compute the pearson correlation between two columns. +pub fn pearson_corr(a: Expr, b: Expr) -> Expr { + let function = FunctionExpr::Correlation { + method: CorrelationMethod::Pearson, + }; + a.map_binary(function, b) +} + +/// Compute the spearman rank correlation between two columns. +/// Missing data will be excluded from the computation. +/// # Arguments +/// * propagate_nans +/// If `true` any `NaN` encountered will lead to `NaN` in the output. +/// If to `false` then `NaN` are regarded as larger than any finite number +/// and thus lead to the highest rank. +#[cfg(all(feature = "rank", feature = "propagate_nans"))] +pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr { + let function = FunctionExpr::Correlation { + method: CorrelationMethod::SpearmanRank(propagate_nans), + }; + a.map_binary(function, b) +} + +#[cfg(all(feature = "rolling_window", feature = "cov"))] +fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr { + // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804 + let rolling_options = RollingOptionsFixedWindow { + window_size: options.window_size as usize, + min_periods: options.min_periods as usize, + ..Default::default() + }; + + Expr::Function { + input: vec![x, y], + function: FunctionExpr::RollingExpr(RollingFunction::CorrCov { + rolling_options, + corr_cov_options: options, + is_corr, + }), + options: Default::default(), + } +} + +#[cfg(all(feature = "rolling_window", feature = "cov"))] +pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { + dispatch_corr_cov(x, y, options, true) +} + +#[cfg(all(feature = "rolling_window", feature = "cov"))] +pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { + dispatch_corr_cov(x, y, options, false) +} diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs new file mode 100644 index 000000000000..27f5abb8aff4 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -0,0 +1,268 @@ +use super::*; + +#[cfg(feature = "dtype-struct")] +fn cum_fold_dtype() -> GetOutput { + GetOutput::map_fields(|fields| { + let mut st = fields[0].dtype.clone(); + for fld in &fields[1..] { + st = get_supertype(&st, &fld.dtype).unwrap(); + } + Ok(Field::new( + fields[0].name.clone(), + DataType::Struct( + fields + .iter() + .map(|fld| Field::new(fld.name().clone(), st.clone())) + .collect(), + ), + )) + }) +} + +/// Accumulate over multiple columns horizontally / row wise. +pub fn fold_exprs(acc: Expr, f: F, exprs: E) -> Expr +where + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, + E: AsRef<[Expr]>, +{ + let mut exprs_v = Vec::with_capacity(exprs.as_ref().len() + 1); + exprs_v.push(acc); + exprs_v.extend(exprs.as_ref().iter().cloned()); + let exprs = exprs_v; + + let function = new_column_udf(move |columns: &mut [Column]| { + let mut acc = columns.first().unwrap().clone(); + for c in &columns[1..] { + if let Some(a) = f(acc.clone(), c.clone())? { + acc = a + } + } + Ok(Some(acc)) + }); + + Expr::AnonymousFunction { + input: exprs, + function, + // Take the type of the accumulator. + output_type: GetOutput::first(), + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, + fmt_str: "fold", + ..Default::default() + }, + } +} + +/// Analogous to [`Iterator::reduce`](std::iter::Iterator::reduce). +/// +/// An accumulator is initialized to the series given by the first expression in `exprs`, and then each subsequent value +/// of the accumulator is computed from `f(acc, next_expr_series)`. If `exprs` is empty, an error is returned when +/// `collect` is called. +pub fn reduce_exprs(f: F, exprs: E) -> Expr +where + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, + E: AsRef<[Expr]>, +{ + let exprs = exprs.as_ref().to_vec(); + + let function = new_column_udf(move |columns: &mut [Column]| { + let mut c_iter = columns.iter(); + + match c_iter.next() { + Some(acc) => { + let mut acc = acc.clone(); + + for c in c_iter { + if let Some(a) = f(acc.clone(), c.clone())? { + acc = a + } + } + Ok(Some(acc)) + }, + None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")), + } + }); + + Expr::AnonymousFunction { + input: exprs, + function, + output_type: GetOutput::super_type(), + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, + fmt_str: "reduce", + ..Default::default() + }, + } +} + +/// Accumulate over multiple columns horizontally / row wise. +#[cfg(feature = "dtype-struct")] +pub fn cum_reduce_exprs(f: F, exprs: E) -> Expr +where + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, + E: AsRef<[Expr]>, +{ + let exprs = exprs.as_ref().to_vec(); + + let function = new_column_udf(move |columns: &mut [Column]| { + let mut c_iter = columns.iter(); + + match c_iter.next() { + Some(acc) => { + let mut acc = acc.clone(); + let mut result = vec![acc.clone()]; + + for c in c_iter { + let name = c.name().clone(); + if let Some(a) = f(acc.clone(), c.clone())? { + acc = a; + } + acc.rename(name); + result.push(acc.clone()); + } + + StructChunked::from_columns(acc.name().clone(), result[0].len(), &result) + .map(|ca| Some(ca.into_column())) + }, + None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")), + } + }); + + Expr::AnonymousFunction { + input: exprs, + function, + output_type: cum_fold_dtype(), + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, + fmt_str: "cum_reduce", + ..Default::default() + }, + } +} + +/// Accumulate over multiple columns horizontally / row wise. +#[cfg(feature = "dtype-struct")] +pub fn cum_fold_exprs(acc: Expr, f: F, exprs: E, include_init: bool) -> Expr +where + F: 'static + Fn(Column, Column) -> PolarsResult> + Send + Sync, + E: AsRef<[Expr]>, +{ + let mut exprs = exprs.as_ref().to_vec(); + exprs.push(acc); + + let function = new_column_udf(move |columns: &mut [Column]| { + let mut columns = columns.to_vec(); + let mut acc = columns.pop().unwrap(); + + let mut result = vec![]; + if include_init { + result.push(acc.clone()) + } + + for c in columns { + let name = c.name().clone(); + if let Some(a) = f(acc.clone(), c)? { + acc = a; + acc.rename(name); + result.push(acc.clone()); + } + } + + StructChunked::from_columns(acc.name().clone(), result[0].len(), &result) + .map(|ca| Some(ca.into_column())) + }); + + Expr::AnonymousFunction { + input: exprs, + function, + output_type: cum_fold_dtype(), + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + | FunctionFlags::RETURNS_SCALAR, + fmt_str: "cum_fold", + ..Default::default() + }, + } +} + +/// Create a new column with the bitwise-and of the elements in each row. +/// +/// The name of the resulting column will be "all"; use [`alias`](Expr::alias) to choose a different name. +pub fn all_horizontal>(exprs: E) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + // This will be reduced to `expr & expr` during conversion to IR. + Ok(Expr::n_ary( + FunctionExpr::Boolean(BooleanFunction::AllHorizontal), + exprs, + )) +} + +/// Create a new column with the bitwise-or of the elements in each row. +/// +/// The name of the resulting column will be "any"; use [`alias`](Expr::alias) to choose a different name. +pub fn any_horizontal>(exprs: E) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + // This will be reduced to `expr | expr` during conversion to IR. + Ok(Expr::n_ary( + FunctionExpr::Boolean(BooleanFunction::AnyHorizontal), + exprs, + )) +} + +/// Create a new column with the maximum value per row. +/// +/// The name of the resulting column will be `"max"`; use [`alias`](Expr::alias) to choose a different name. +pub fn max_horizontal>(exprs: E) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + Ok(Expr::n_ary(FunctionExpr::MaxHorizontal, exprs)) +} + +/// Create a new column with the minimum value per row. +/// +/// The name of the resulting column will be `"min"`; use [`alias`](Expr::alias) to choose a different name. +pub fn min_horizontal>(exprs: E) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + Ok(Expr::n_ary(FunctionExpr::MinHorizontal, exprs)) +} + +/// Sum all values horizontally across columns. +pub fn sum_horizontal>(exprs: E, ignore_nulls: bool) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + Ok(Expr::n_ary( + FunctionExpr::SumHorizontal { ignore_nulls }, + exprs, + )) +} + +/// Compute the mean of all values horizontally across columns. +pub fn mean_horizontal>(exprs: E, ignore_nulls: bool) -> PolarsResult { + let exprs = exprs.as_ref().to_vec(); + polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); + Ok(Expr::n_ary( + FunctionExpr::MeanHorizontal { ignore_nulls }, + exprs, + )) +} + +/// Folds the expressions from left to right keeping the first non-null values. +/// +/// It is an error to provide an empty `exprs`. +pub fn coalesce(exprs: &[Expr]) -> Expr { + Expr::n_ary(FunctionExpr::Coalesce, exprs.to_vec()) +} diff --git a/crates/polars-plan/src/dsl/functions/index.rs b/crates/polars-plan/src/dsl/functions/index.rs new file mode 100644 index 000000000000..ff9ca2599fdb --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/index.rs @@ -0,0 +1,21 @@ +use super::*; + +/// Find the indexes that would sort these series in order of appearance. +/// +/// That means that the first `Series` will be used to determine the ordering +/// until duplicates are found. Once duplicates are found, the next `Series` will +/// be used and so on. +#[cfg(feature = "range")] +pub fn arg_sort_by>(by: E, sort_options: SortMultipleOptions) -> Expr { + let e = &by.as_ref()[0]; + let name = expr_output_name(e).unwrap(); + int_range(lit(0 as IdxSize), len().cast(IDX_DTYPE), 1, IDX_DTYPE) + .sort_by(by, sort_options) + .alias(name) +} + +#[cfg(feature = "arg_where")] +/// Get the indices where `condition` evaluates `true`. +pub fn arg_where>(condition: E) -> Expr { + condition.into().map_unary(FunctionExpr::ArgWhere) +} diff --git a/crates/polars-plan/src/dsl/functions/mod.rs b/crates/polars-plan/src/dsl/functions/mod.rs new file mode 100644 index 000000000000..9219704bfc2f --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/mod.rs @@ -0,0 +1,53 @@ +//! # Functions +//! +//! Functions on expressions that might be useful. +mod arity; +#[cfg(feature = "business")] +mod business; +#[cfg(feature = "dtype-struct")] +mod coerce; +mod concat; +#[cfg(feature = "cov")] +mod correlation; +pub(crate) mod horizontal; +#[cfg(any(feature = "range", feature = "arg_where"))] +mod index; +#[cfg(feature = "range")] +mod range; +mod repeat; +mod selectors; +mod syntactic_sugar; +#[cfg(feature = "temporal")] +mod temporal; + +pub use arity::*; +#[cfg(all(feature = "business", feature = "dtype-date"))] +pub use business::*; +#[cfg(feature = "dtype-struct")] +pub use coerce::*; +pub use concat::*; +#[cfg(feature = "cov")] +pub use correlation::*; +pub use horizontal::*; +#[cfg(any(feature = "range", feature = "arg_where"))] +pub use index::*; +#[cfg(feature = "dtype-struct")] +use polars_core::utils::get_supertype; +#[cfg(all(feature = "range", feature = "temporal"))] +pub use range::date_range; // This shouldn't be necessary, but clippy complains about dead code +#[cfg(all(feature = "range", feature = "dtype-time"))] +pub use range::time_range; // This shouldn't be necessary, but clippy complains about dead code +#[cfg(feature = "range")] +pub use range::*; +pub use repeat::*; +pub use selectors::*; +pub use syntactic_sugar::*; +#[cfg(feature = "temporal")] +pub use temporal::*; + +#[cfg(feature = "arg_where")] +use crate::dsl::function_expr::FunctionExpr; +use crate::dsl::function_expr::ListFunction; +#[cfg(all(feature = "concat_str", feature = "strings"))] +use crate::dsl::function_expr::StringFunction; +use crate::dsl::*; diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs new file mode 100644 index 000000000000..d1567da4a6bc --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/range.rs @@ -0,0 +1,137 @@ +use polars_ops::series::ClosedInterval; +#[cfg(feature = "temporal")] +use polars_time::ClosedWindow; + +use super::*; + +/// Generate a range of integers. +/// +/// Alias for `int_range`. +pub fn arange(start: Expr, end: Expr, step: i64, dtype: DataType) -> Expr { + int_range(start, end, step, dtype) +} + +/// Generate a range of integers. +pub fn int_range(start: Expr, end: Expr, step: i64, dtype: DataType) -> Expr { + Expr::n_ary(RangeFunction::IntRange { step, dtype }, vec![start, end]) +} + +/// Generate a range of integers for each row of the input columns. +pub fn int_ranges(start: Expr, end: Expr, step: Expr) -> Expr { + Expr::n_ary(RangeFunction::IntRanges, vec![start, end, step]) +} + +/// Create a date range from a `start` and `stop` expression. +#[cfg(feature = "temporal")] +pub fn date_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWindow) -> Expr { + Expr::n_ary( + RangeFunction::DateRange { interval, closed }, + vec![start, end], + ) +} + +/// Create a column of date ranges from a `start` and `stop` expression. +#[cfg(feature = "temporal")] +pub fn date_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWindow) -> Expr { + Expr::n_ary( + RangeFunction::DateRanges { interval, closed }, + vec![start, end], + ) +} + +/// Create a datetime range from a `start` and `stop` expression. +#[cfg(feature = "dtype-datetime")] +pub fn datetime_range( + start: Expr, + end: Expr, + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> Expr { + Expr::n_ary( + RangeFunction::DatetimeRange { + interval, + closed, + time_unit, + time_zone, + }, + vec![start, end], + ) +} + +/// Create a column of datetime ranges from a `start` and `stop` expression. +#[cfg(feature = "dtype-datetime")] +pub fn datetime_ranges( + start: Expr, + end: Expr, + interval: Duration, + closed: ClosedWindow, + time_unit: Option, + time_zone: Option, +) -> Expr { + Expr::n_ary( + RangeFunction::DatetimeRanges { + interval, + closed, + time_unit, + time_zone, + }, + vec![start, end], + ) +} + +/// Generate a time range. +#[cfg(feature = "dtype-time")] +pub fn time_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWindow) -> Expr { + Expr::n_ary( + RangeFunction::TimeRange { interval, closed }, + vec![start, end], + ) +} + +/// Create a column of time ranges from a `start` and `stop` expression. +#[cfg(feature = "dtype-time")] +pub fn time_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWindow) -> Expr { + Expr::n_ary( + RangeFunction::TimeRanges { interval, closed }, + vec![start, end], + ) +} + +/// Generate a series of equally-spaced points. +pub fn linear_space(start: Expr, end: Expr, num_samples: Expr, closed: ClosedInterval) -> Expr { + Expr::n_ary( + RangeFunction::LinearSpace { closed }, + vec![start, end, num_samples], + ) +} + +/// Create a column of linearly-spaced sequences from 'start', 'end', and 'num_samples' expressions. +pub fn linear_spaces( + start: Expr, + end: Expr, + num_samples: Expr, + closed: ClosedInterval, + as_array: bool, +) -> PolarsResult { + let mut input = Vec::::with_capacity(3); + input.push(start); + input.push(end); + let array_width = if as_array { + Some(num_samples.extract_usize().map_err(|_| { + polars_err!(InvalidOperation: "'as_array' is only valid when 'num_samples' is a constant integer") + })?) + } else { + input.push(num_samples); + None + }; + + Ok(Expr::n_ary( + RangeFunction::LinearSpaces { + closed, + array_width, + }, + input, + )) +} diff --git a/crates/polars-plan/src/dsl/functions/repeat.rs b/crates/polars-plan/src/dsl/functions/repeat.rs new file mode 100644 index 000000000000..6646b5fadfee --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/repeat.rs @@ -0,0 +1,13 @@ +use super::*; + +/// Create a column of length `n` containing `n` copies of the literal `value`. +/// +/// Generally you won't need this function, as `lit(value)` already represents a column containing +/// only `value` whose length is automatically set to the correct number of rows. +pub fn repeat>(value: E, n: Expr) -> Expr { + let expr = Expr::n_ary(FunctionExpr::Repeat, vec![value.into(), n]); + + // @NOTE: This alias should probably not be here for consistency, but it is here for backwards + // compatibility until 2.0. + expr.alias(PlSmallStr::from_static("repeat")) +} diff --git a/crates/polars-plan/src/dsl/functions/selectors.rs b/crates/polars-plan/src/dsl/functions/selectors.rs new file mode 100644 index 000000000000..28d52c10f835 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/selectors.rs @@ -0,0 +1,68 @@ +use super::*; + +/// Create a Column Expression based on a column name. +/// +/// # Arguments +/// +/// * `name` - A string slice that holds the name of the column. If a column with this name does not exist when the +/// LazyFrame is collected, an error is returned. +/// +/// # Examples +/// +/// ```ignore +/// // select a column name +/// col("foo") +/// ``` +/// +/// ```ignore +/// // select all columns by using a wildcard +/// col("*") +/// ``` +/// +/// ```ignore +/// // select specific columns by writing a regular expression that starts with `^` and ends with `$` +/// // only if regex features is activated +/// col("^foo.*$") +/// ``` +pub fn col(name: S) -> Expr +where + S: Into, +{ + let name = name.into(); + match name.as_str() { + "*" => Expr::Wildcard, + _ => Expr::Column(name), + } +} + +/// Selects all columns. Shorthand for `col("*")`. +pub fn all() -> Expr { + Expr::Wildcard +} + +/// Select multiple columns by name. +pub fn cols(names: I) -> Expr +where + I: IntoIterator, + S: Into, +{ + let names = names.into_iter().map(|x| x.into()).collect(); + Expr::Columns(names) +} + +/// Select multiple columns by dtype. +pub fn dtype_col(dtype: &DataType) -> Expr { + Expr::DtypeColumn(vec![dtype.clone()]) +} + +/// Select multiple columns by dtype. +pub fn dtype_cols>(dtype: DT) -> Expr { + let dtypes = dtype.as_ref().to_vec(); + Expr::DtypeColumn(dtypes) +} + +/// Select multiple columns by index. +pub fn index_cols>(indices: N) -> Expr { + let indices = indices.as_ref().to_vec(); + Expr::IndexColumn(Arc::from(indices)) +} diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs new file mode 100644 index 000000000000..4d0e4c105014 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -0,0 +1,66 @@ +use polars_core::chunked_array::cast::CastOptions; + +use super::*; + +/// Sum all the values in the column named `name`. Shorthand for `col(name).sum()`. +pub fn sum(name: &str) -> Expr { + col(name).sum() +} + +/// Find the minimum of all the values in the column named `name`. Shorthand for `col(name).min()`. +pub fn min(name: &str) -> Expr { + col(name).min() +} + +/// Find the maximum of all the values in the column named `name`. Shorthand for `col(name).max()`. +pub fn max(name: &str) -> Expr { + col(name).max() +} + +/// Find the mean of all the values in the column named `name`. Shorthand for `col(name).mean()`. +pub fn mean(name: &str) -> Expr { + col(name).mean() +} + +/// Find the mean of all the values in the column named `name`. Alias for [`mean`]. +pub fn avg(name: &str) -> Expr { + col(name).mean() +} + +/// Find the median of all the values in the column named `name`. Shorthand for `col(name).median()`. +pub fn median(name: &str) -> Expr { + col(name).median() +} + +/// Find a specific quantile of all the values in the column named `name`. +pub fn quantile(name: &str, quantile: Expr, method: QuantileMethod) -> Expr { + col(name).quantile(quantile, method) +} + +/// Negates a boolean column. +pub fn not(expr: Expr) -> Expr { + expr.not() +} + +/// A column which is `true` wherever `expr` is null, `false` elsewhere. +pub fn is_null(expr: Expr) -> Expr { + expr.is_null() +} + +/// A column which is `false` wherever `expr` is null, `true` elsewhere. +pub fn is_not_null(expr: Expr) -> Expr { + expr.is_not_null() +} + +/// Casts the column given by `Expr` to a different type. +/// +/// Follows the rules of Rust casting, with the exception that integers and floats can be cast to `DataType::Date` and +/// `DataType::DateTime(_, _)`. A column consisting entirely of `Null` can be cast to any type, regardless of the +/// nominal type of the column. +pub fn cast(expr: Expr, dtype: DataType) -> Expr { + Expr::Cast { + expr: Arc::new(expr), + dtype, + options: CastOptions::NonStrict, + } +} diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs new file mode 100644 index 000000000000..2ebec16b5b77 --- /dev/null +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -0,0 +1,438 @@ +use chrono::{Datelike, Timelike}; + +use super::*; + +macro_rules! impl_unit_setter { + ($fn_name:ident($field:ident)) => { + #[doc = concat!("Set the ", stringify!($field))] + pub fn $fn_name(mut self, n: Expr) -> Self { + self.$field = n.into(); + self + } + }; +} + +/// Arguments used by `datetime` in order to produce an [`Expr`] of Datetime +/// +/// Construct a [`DatetimeArgs`] with `DatetimeArgs::new(y, m, d)`. This will set the other time units to `lit(0)`. You +/// can then set the other fields with the `with_*` methods, or use `with_hms` to set `hour`, `minute`, and `second` all +/// at once. +/// +/// # Examples +/// ``` +/// use polars_plan::prelude::*; +/// // construct a DatetimeArgs set to July 20, 1969 at 20:17 +/// let args = DatetimeArgs::new(lit(1969), lit(7), lit(20)).with_hms(lit(20), lit(17), lit(0)); +/// // or +/// let args = DatetimeArgs::new(lit(1969), lit(7), lit(20)).with_hour(lit(20)).with_minute(lit(17)); +/// +/// // construct a DatetimeArgs using existing columns +/// let args = DatetimeArgs::new(lit(2023), col("month"), col("day")); +/// ``` +#[derive(Debug, Clone)] +pub struct DatetimeArgs { + pub year: Expr, + pub month: Expr, + pub day: Expr, + pub hour: Expr, + pub minute: Expr, + pub second: Expr, + pub microsecond: Expr, + pub time_unit: TimeUnit, + pub time_zone: Option, + pub ambiguous: Expr, +} + +impl Default for DatetimeArgs { + fn default() -> Self { + Self { + year: lit(1970), + month: lit(1), + day: lit(1), + hour: lit(0), + minute: lit(0), + second: lit(0), + microsecond: lit(0), + time_unit: TimeUnit::Microseconds, + time_zone: None, + ambiguous: lit(String::from("raise")), + } + } +} + +impl DatetimeArgs { + /// Construct a new `DatetimeArgs` set to `year`, `month`, `day` + /// + /// Other fields default to `lit(0)`. Use the `with_*` methods to set them. + pub fn new(year: Expr, month: Expr, day: Expr) -> Self { + Self { + year, + month, + day, + ..Default::default() + } + } + + /// Set `hour`, `minute`, and `second` + /// + /// Equivalent to + /// ```ignore + /// self.with_hour(hour) + /// .with_minute(minute) + /// .with_second(second) + /// ``` + pub fn with_hms(self, hour: Expr, minute: Expr, second: Expr) -> Self { + Self { + hour, + minute, + second, + ..self + } + } + + impl_unit_setter!(with_year(year)); + impl_unit_setter!(with_month(month)); + impl_unit_setter!(with_day(day)); + impl_unit_setter!(with_hour(hour)); + impl_unit_setter!(with_minute(minute)); + impl_unit_setter!(with_second(second)); + impl_unit_setter!(with_microsecond(microsecond)); + + pub fn with_time_unit(self, time_unit: TimeUnit) -> Self { + Self { time_unit, ..self } + } + #[cfg(feature = "timezones")] + pub fn with_time_zone(self, time_zone: Option) -> Self { + Self { time_zone, ..self } + } + #[cfg(feature = "timezones")] + pub fn with_ambiguous(self, ambiguous: Expr) -> Self { + Self { ambiguous, ..self } + } + + fn all_literal(&self) -> bool { + use Expr::*; + [ + &self.year, + &self.month, + &self.day, + &self.hour, + &self.minute, + &self.second, + &self.microsecond, + ] + .iter() + .all(|e| matches!(e, Literal(_))) + } + + fn as_literal(&self) -> Option { + if self.time_zone.is_some() || !self.all_literal() { + return None; + }; + let Expr::Literal(lv) = &self.year else { + unreachable!() + }; + let year = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.month else { + unreachable!() + }; + let month = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.day else { + unreachable!() + }; + let day = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.hour else { + unreachable!() + }; + let hour = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.minute else { + unreachable!() + }; + let minute = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.second else { + unreachable!() + }; + let second = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.microsecond else { + unreachable!() + }; + let ms: u32 = lv.to_any_value()?.extract()?; + + let dt = chrono::NaiveDateTime::default() + .with_year(year)? + .with_month(month)? + .with_day(day)? + .with_hour(hour)? + .with_minute(minute)? + .with_second(second)? + .with_nanosecond(ms * 1000)?; + + let ts = match self.time_unit { + TimeUnit::Milliseconds => dt.and_utc().timestamp_millis(), + TimeUnit::Microseconds => dt.and_utc().timestamp_micros(), + TimeUnit::Nanoseconds => dt.and_utc().timestamp_nanos_opt()?, + }; + + Some( + Expr::Literal(LiteralValue::Scalar(Scalar::new( + DataType::Datetime(self.time_unit, None), + AnyValue::Datetime(ts, self.time_unit, None), + ))) + .alias(PlSmallStr::from_static("datetime")), + ) + } +} + +/// Construct a column of `Datetime` from the provided [`DatetimeArgs`]. +pub fn datetime(args: DatetimeArgs) -> Expr { + if let Some(e) = args.as_literal() { + return e; + } + + let year = args.year; + let month = args.month; + let day = args.day; + let hour = args.hour; + let minute = args.minute; + let second = args.second; + let microsecond = args.microsecond; + let time_unit = args.time_unit; + let time_zone = args.time_zone; + let ambiguous = args.ambiguous; + + let input = vec![ + year, + month, + day, + hour, + minute, + second, + microsecond, + ambiguous, + ]; + + Expr::Alias( + Arc::new(Expr::Function { + input, + function: FunctionExpr::TemporalExpr(TemporalFunction::DatetimeFunction { + time_unit, + time_zone, + }), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME, + fmt_str: "datetime", + ..Default::default() + }, + }), + // TODO: follow left-hand rule in Polars 2.0. + PlSmallStr::from_static("datetime"), + ) +} + +/// Arguments used by `duration` in order to produce an [`Expr`] of [`Duration`] +/// +/// To construct a [`DurationArgs`], use struct literal syntax with `..Default::default()` to leave unspecified fields at +/// their default value of `lit(0)`, as demonstrated below. +/// +/// ``` +/// # use polars_plan::prelude::*; +/// let args = DurationArgs { +/// days: lit(5), +/// hours: col("num_hours"), +/// minutes: col("num_minutes"), +/// ..Default::default() // other fields are lit(0) +/// }; +/// ``` +/// If you prefer builder syntax, `with_*` methods are also available. +/// ``` +/// # use polars_plan::prelude::*; +/// let args = DurationArgs::new().with_weeks(lit(42)).with_hours(lit(84)); +/// ``` +#[derive(Debug, Clone)] +pub struct DurationArgs { + pub weeks: Expr, + pub days: Expr, + pub hours: Expr, + pub minutes: Expr, + pub seconds: Expr, + pub milliseconds: Expr, + pub microseconds: Expr, + pub nanoseconds: Expr, + pub time_unit: TimeUnit, +} + +impl Default for DurationArgs { + fn default() -> Self { + Self { + weeks: lit(0), + days: lit(0), + hours: lit(0), + minutes: lit(0), + seconds: lit(0), + milliseconds: lit(0), + microseconds: lit(0), + nanoseconds: lit(0), + time_unit: TimeUnit::Microseconds, + } + } +} + +impl DurationArgs { + /// Create a new [`DurationArgs`] with all fields set to `lit(0)`. Use the `with_*` methods to set the fields. + pub fn new() -> Self { + Self::default() + } + + /// Set `hours`, `minutes`, and `seconds` + /// + /// Equivalent to: + /// + /// ```ignore + /// self.with_hours(hours) + /// .with_minutes(minutes) + /// .with_seconds(seconds) + /// ``` + pub fn with_hms(self, hours: Expr, minutes: Expr, seconds: Expr) -> Self { + Self { + hours, + minutes, + seconds, + ..self + } + } + + /// Set `milliseconds`, `microseconds`, and `nanoseconds` + /// + /// Equivalent to + /// ```ignore + /// self.with_milliseconds(milliseconds) + /// .with_microseconds(microseconds) + /// .with_nanoseconds(nanoseconds) + /// ``` + pub fn with_fractional_seconds( + self, + milliseconds: Expr, + microseconds: Expr, + nanoseconds: Expr, + ) -> Self { + Self { + milliseconds, + microseconds, + nanoseconds, + ..self + } + } + + impl_unit_setter!(with_weeks(weeks)); + impl_unit_setter!(with_days(days)); + impl_unit_setter!(with_hours(hours)); + impl_unit_setter!(with_minutes(minutes)); + impl_unit_setter!(with_seconds(seconds)); + impl_unit_setter!(with_milliseconds(milliseconds)); + impl_unit_setter!(with_microseconds(microseconds)); + impl_unit_setter!(with_nanoseconds(nanoseconds)); + + fn all_literal(&self) -> bool { + use Expr::*; + [ + &self.weeks, + &self.days, + &self.hours, + &self.seconds, + &self.minutes, + &self.milliseconds, + &self.microseconds, + &self.nanoseconds, + ] + .iter() + .all(|e| matches!(e, Literal(_))) + } + + fn as_literal(&self) -> Option { + if !self.all_literal() { + return None; + }; + let Expr::Literal(lv) = &self.weeks else { + unreachable!() + }; + let weeks = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.days else { + unreachable!() + }; + let days = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.hours else { + unreachable!() + }; + let hours = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.seconds else { + unreachable!() + }; + let seconds = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.minutes else { + unreachable!() + }; + let minutes = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.milliseconds else { + unreachable!() + }; + let milliseconds = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.microseconds else { + unreachable!() + }; + let microseconds = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.nanoseconds else { + unreachable!() + }; + let nanoseconds = lv.to_any_value()?.extract()?; + + type D = chrono::Duration; + let delta = D::weeks(weeks) + + D::days(days) + + D::hours(hours) + + D::seconds(seconds) + + D::minutes(minutes) + + D::milliseconds(milliseconds) + + D::microseconds(microseconds) + + D::nanoseconds(nanoseconds); + + let d = match self.time_unit { + TimeUnit::Milliseconds => delta.num_milliseconds(), + TimeUnit::Microseconds => delta.num_microseconds()?, + TimeUnit::Nanoseconds => delta.num_nanoseconds()?, + }; + + Some( + Expr::Literal(LiteralValue::Scalar(Scalar::new( + DataType::Duration(self.time_unit), + AnyValue::Duration(d, self.time_unit), + ))) + .alias(PlSmallStr::from_static("duration")), + ) + } +} + +/// Construct a column of [`Duration`] from the provided [`DurationArgs`] +pub fn duration(args: DurationArgs) -> Expr { + if let Some(e) = args.as_literal() { + return e; + } + Expr::Function { + input: vec![ + args.weeks, + args.days, + args.hours, + args.minutes, + args.seconds, + args.milliseconds, + args.microseconds, + args.nanoseconds, + ], + function: FunctionExpr::TemporalExpr(TemporalFunction::Duration(args.time_unit)), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default(), + ..Default::default() + }, + } +} diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs new file mode 100644 index 000000000000..4cc28b2a4956 --- /dev/null +++ b/crates/polars-plan/src/dsl/list.rs @@ -0,0 +1,298 @@ +use polars_core::prelude::*; +#[cfg(feature = "diff")] +use polars_core::series::ops::NullBehavior; + +use crate::prelude::function_expr::ListFunction; +use crate::prelude::*; + +/// Specialized expressions for [`Series`] of [`DataType::List`]. +pub struct ListNameSpace(pub Expr); + +impl ListNameSpace { + #[cfg(feature = "list_any_all")] + pub fn any(self) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::Any)) + } + + #[cfg(feature = "list_any_all")] + pub fn all(self) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::All)) + } + + #[cfg(feature = "list_drop_nulls")] + pub fn drop_nulls(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::DropNulls)) + } + + #[cfg(feature = "list_sample")] + pub fn sample_n( + self, + n: Expr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::Sample { + is_fraction: false, + with_replacement, + shuffle, + seed, + }), + n, + ) + } + + #[cfg(feature = "list_sample")] + pub fn sample_fraction( + self, + fraction: Expr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::Sample { + is_fraction: true, + with_replacement, + shuffle, + seed, + }), + fraction, + ) + } + + /// Return the number of elements in each list. + /// + /// Null values are treated like regular elements in this context. + pub fn len(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Length)) + } + + /// Compute the maximum of the items in every sublist. + pub fn max(self) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::Max)) + } + + /// Compute the minimum of the items in every sublist. + pub fn min(self) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::Min)) + } + + /// Compute the sum the items in every sublist. + pub fn sum(self) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::Sum)) + } + + /// Compute the mean of every sublist and return a `Series` of dtype `Float64` + pub fn mean(self) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::Mean)) + } + + pub fn median(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Median)) + } + + pub fn std(self, ddof: u8) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Std(ddof))) + } + + pub fn var(self, ddof: u8) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Var(ddof))) + } + + /// Sort every sublist. + pub fn sort(self, options: SortOptions) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Sort(options))) + } + + /// Reverse every sublist + pub fn reverse(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Reverse)) + } + + /// Keep only the unique values in every sublist. + pub fn unique(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Unique(false))) + } + + /// Keep only the unique values in every sublist. + pub fn unique_stable(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::Unique(true))) + } + + pub fn n_unique(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::NUnique)) + } + + /// Get items in every sublist by index. + pub fn get(self, index: Expr, null_on_oob: bool) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)), + index, + ) + } + + /// Get items in every sublist by multiple indexes. + /// + /// # Arguments + /// - `null_on_oob`: Return a null when an index is out of bounds. + /// This behavior is more expensive than defaulting to returning an `Error`. + #[cfg(feature = "list_gather")] + pub fn gather(self, index: Expr, null_on_oob: bool) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::Gather(null_on_oob)), + index, + ) + } + + #[cfg(feature = "list_gather")] + pub fn gather_every(self, n: Expr, offset: Expr) -> Expr { + self.0 + .map_ternary(FunctionExpr::ListExpr(ListFunction::GatherEvery), n, offset) + } + + /// Get first item of every sublist. + pub fn first(self) -> Expr { + self.get(lit(0i64), true) + } + + /// Get last item of every sublist. + pub fn last(self) -> Expr { + self.get(lit(-1i64), true) + } + + /// Join all string items in a sublist and place a separator between them. + /// # Error + /// This errors if inner type of list `!= DataType::String`. + pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::Join(ignore_nulls)), + separator, + ) + } + + /// Return the index of the minimal value of every sublist + pub fn arg_min(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::ArgMin)) + } + + /// Return the index of the maximum value of every sublist + pub fn arg_max(self) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::ArgMax)) + } + + /// Diff every sublist. + #[cfg(feature = "diff")] + pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr { + self.0.map_unary(FunctionExpr::ListExpr(ListFunction::Diff { + n, + null_behavior, + })) + } + + /// Shift every sublist. + pub fn shift(self, periods: Expr) -> Expr { + self.0 + .map_binary(FunctionExpr::ListExpr(ListFunction::Shift), periods) + } + + /// Slice every sublist. + pub fn slice(self, offset: Expr, length: Expr) -> Expr { + self.0 + .map_ternary(FunctionExpr::ListExpr(ListFunction::Slice), offset, length) + } + + /// Get the head of every sublist + pub fn head(self, n: Expr) -> Expr { + self.slice(lit(0), n) + } + + /// Get the tail of every sublist + pub fn tail(self, n: Expr) -> Expr { + self.slice(lit(0i64) - n.clone().cast(DataType::Int64), n) + } + + #[cfg(feature = "dtype-array")] + /// Convert a List column into an Array column with the same inner data type. + pub fn to_array(self, width: usize) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::ToArray(width))) + } + + #[cfg(feature = "list_to_struct")] + #[allow(clippy::wrong_self_convention)] + /// Convert this `List` to a `Series` of type `Struct`. The width will be determined according to + /// `ListToStructWidthStrategy` and the names of the fields determined by the given `name_generator`. + /// + /// # Schema + /// + /// A polars `LazyFrame` needs to know the schema at all time. The caller therefore must provide + /// an `upper_bound` of struct fields that will be set. + /// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression + /// will look in the current schema to determine which columns to select. + pub fn to_struct(self, args: ListToStructArgs) -> Expr { + self.0 + .map_unary(FunctionExpr::ListExpr(ListFunction::ToStruct(args))) + } + + #[cfg(feature = "is_in")] + /// Check if the list array contain an element + pub fn contains>(self, other: E) -> Expr { + self.0 + .map_binary(FunctionExpr::ListExpr(ListFunction::Contains), other.into()) + } + + #[cfg(feature = "list_count")] + /// Count how often the value produced by ``element`` occurs. + pub fn count_matches>(self, element: E) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::CountMatches), + element.into(), + ) + } + + #[cfg(feature = "list_sets")] + fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr { + self.0.map_binary( + FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)), + other, + ) + } + + /// Return the SET UNION between both list arrays. + #[cfg(feature = "list_sets")] + pub fn union>(self, other: E) -> Expr { + self.set_operation(other.into(), SetOperation::Union) + } + + /// Return the SET DIFFERENCE between both list arrays. + #[cfg(feature = "list_sets")] + pub fn set_difference>(self, other: E) -> Expr { + self.set_operation(other.into(), SetOperation::Difference) + } + + /// Return the SET INTERSECTION between both list arrays. + #[cfg(feature = "list_sets")] + pub fn set_intersection>(self, other: E) -> Expr { + self.set_operation(other.into(), SetOperation::Intersection) + } + + /// Return the SET SYMMETRIC DIFFERENCE between both list arrays. + #[cfg(feature = "list_sets")] + pub fn set_symmetric_difference>(self, other: E) -> Expr { + self.set_operation(other.into(), SetOperation::SymmetricDifference) + } +} diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs new file mode 100644 index 000000000000..45a5c4ab92b9 --- /dev/null +++ b/crates/polars-plan/src/dsl/meta.rs @@ -0,0 +1,187 @@ +use std::fmt::Display; +use std::ops::BitAnd; + +use super::*; +use crate::plans::conversion::is_regex_projection; +use crate::plans::ir::tree_format::TreeFmtVisitor; +use crate::plans::visitor::{AexprNode, TreeWalker}; +use crate::prelude::tree_format::TreeFmtVisitorDisplay; + +/// Specialized expressions for Categorical dtypes. +pub struct MetaNameSpace(pub(crate) Expr); + +impl MetaNameSpace { + /// Pop latest expression and return the input(s) of the popped expression. + pub fn pop(self) -> PolarsResult> { + let mut arena = Arena::with_capacity(8); + let node = to_aexpr(self.0, &mut arena)?; + let ae = arena.get(node); + let mut inputs = Vec::with_capacity(2); + ae.inputs_rev(&mut inputs); + Ok(inputs + .iter() + .map(|node| node_to_expr(*node, &arena)) + .collect()) + } + + /// Get the root column names. + pub fn root_names(&self) -> Vec { + expr_to_leaf_column_names(&self.0) + } + + /// A projection that only takes a column or a column + alias. + pub fn is_simple_projection(&self) -> bool { + let mut arena = Arena::with_capacity(8); + to_aexpr(self.0.clone(), &mut arena) + .map(|node| aexpr_is_simple_projection(node, &arena)) + .unwrap_or(false) + } + + /// Get the output name of this expression. + pub fn output_name(&self) -> PolarsResult { + expr_output_name(&self.0) + } + + /// Undo any renaming operation like `alias`, `keep_name`. + pub fn undo_aliases(self) -> Expr { + self.0.map_expr(|e| match e { + Expr::Alias(input, _) + | Expr::KeepName(input) + | Expr::RenameAlias { expr: input, .. } => Arc::unwrap_or_clone(input), + e => e, + }) + } + + /// Indicate if this expression expands to multiple expressions. + pub fn has_multiple_outputs(&self) -> bool { + self.0.into_iter().any(|e| match e { + Expr::Selector(_) | Expr::Wildcard | Expr::Columns(_) | Expr::DtypeColumn(_) => true, + Expr::IndexColumn(idxs) => idxs.len() > 1, + Expr::Column(name) => is_regex_projection(name), + _ => false, + }) + } + + /// Indicate if this expression is a basic (non-regex) column. + pub fn is_column(&self) -> bool { + match &self.0 { + Expr::Column(name) => !is_regex_projection(name), + _ => false, + } + } + + /// Indicate if this expression only selects columns; the presence of any + /// transform operations will cause the check to return `false`, though + /// aliasing of the selected columns is optionally allowed. + pub fn is_column_selection(&self, allow_aliasing: bool) -> bool { + self.0.into_iter().all(|e| match e { + Expr::Column(_) + | Expr::Columns(_) + | Expr::DtypeColumn(_) + | Expr::Exclude(_, _) + | Expr::Nth(_) + | Expr::IndexColumn(_) + | Expr::Selector(_) + | Expr::Wildcard => true, + Expr::Alias(_, _) | Expr::KeepName(_) | Expr::RenameAlias { .. } => allow_aliasing, + _ => false, + }) + } + + /// Indicate if this expression represents a literal value (optionally aliased). + pub fn is_literal(&self, allow_aliasing: bool) -> bool { + self.0.into_iter().all(|e| match e { + Expr::Literal(_) => true, + Expr::Alias(_, _) => allow_aliasing, + Expr::Cast { + expr, + dtype: DataType::Datetime(_, _), + options: CastOptions::Strict, + } if matches!(&**expr, Expr::Literal(LiteralValue::Scalar(sc)) if matches!(sc.as_any_value(), AnyValue::Datetime(..))) => true, + _ => false, + }) + } + + /// Indicate if this expression expands to multiple expressions with regex expansion. + pub fn is_regex_projection(&self) -> bool { + self.0.into_iter().any(|e| match e { + Expr::Column(name) => is_regex_projection(name), + _ => false, + }) + } + + pub fn _selector_add(self, other: Expr) -> PolarsResult { + if let Expr::Selector(mut s) = self.0 { + if let Expr::Selector(s_other) = other { + s = s + s_other; + } else { + s = s + Selector::Root(Box::new(other)) + } + Ok(Expr::Selector(s)) + } else { + polars_bail!(ComputeError: "expected selector, got {:?}", self.0) + } + } + + pub fn _selector_and(self, other: Expr) -> PolarsResult { + if let Expr::Selector(mut s) = self.0 { + if let Expr::Selector(s_other) = other { + s = s.bitand(s_other); + } else { + s = s.bitand(Selector::Root(Box::new(other))) + } + Ok(Expr::Selector(s)) + } else { + polars_bail!(ComputeError: "expected selector, got {:?}", self.0) + } + } + + pub fn _selector_sub(self, other: Expr) -> PolarsResult { + if let Expr::Selector(mut s) = self.0 { + if let Expr::Selector(s_other) = other { + s = s - s_other; + } else { + s = s - Selector::Root(Box::new(other)) + } + Ok(Expr::Selector(s)) + } else { + polars_bail!(ComputeError: "expected selector, got {:?}", self.0) + } + } + + pub fn _selector_xor(self, other: Expr) -> PolarsResult { + if let Expr::Selector(mut s) = self.0 { + if let Expr::Selector(s_other) = other { + s = s ^ s_other; + } else { + s = s ^ Selector::Root(Box::new(other)) + } + Ok(Expr::Selector(s)) + } else { + polars_bail!(ComputeError: "expected selector, got {:?}", self.0) + } + } + + pub fn _into_selector(self) -> Expr { + if let Expr::Selector(_) = self.0 { + self.0 + } else { + Expr::Selector(Selector::new(self.0)) + } + } + + /// Get a hold to an implementor of the `Display` trait that will format as + /// the expression as a tree + pub fn into_tree_formatter(self, display_as_dot: bool) -> PolarsResult { + let mut arena = Default::default(); + let node = to_aexpr(self.0, &mut arena)?; + let mut visitor = TreeFmtVisitor::default(); + if display_as_dot { + visitor.display = TreeFmtVisitorDisplay::DisplayDot; + } + + AexprNode::new(node).visit(&mut visitor, &arena)?; + + Ok(visitor) + } +} diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs new file mode 100644 index 000000000000..a47d9ff84f24 --- /dev/null +++ b/crates/polars-plan/src/dsl/mod.rs @@ -0,0 +1,1895 @@ +#![allow(ambiguous_glob_reexports)] +//! Domain specific language for the Lazy API. +#[cfg(feature = "dtype-categorical")] +pub mod cat; + +#[cfg(feature = "dtype-categorical")] +pub use cat::*; +#[cfg(feature = "rolling_window_by")] +pub(crate) use polars_time::prelude::*; + +mod arithmetic; +mod arity; +#[cfg(feature = "dtype-array")] +mod array; +pub mod binary; +#[cfg(feature = "bitwise")] +mod bitwise; +mod builder_dsl; +pub use builder_dsl::*; +#[cfg(feature = "temporal")] +pub mod dt; +mod expr; +mod expr_dyn_fn; +mod format; +mod from; +pub mod function_expr; +pub mod functions; +mod list; +#[cfg(feature = "meta")] +mod meta; +mod name; +mod options; +#[cfg(feature = "python")] +pub mod python_dsl; +#[cfg(feature = "random")] +mod random; +mod scan_sources; +mod selector; +mod statistics; +#[cfg(feature = "strings")] +pub mod string; +#[cfg(feature = "dtype-struct")] +mod struct_; +pub mod udf; + +use std::fmt::Debug; +use std::sync::Arc; + +mod plan; +pub use arity::*; +#[cfg(feature = "dtype-array")] +pub use array::*; +pub use expr::*; +pub use function_expr::schema::FieldsMapper; +pub use function_expr::*; +pub use functions::*; +pub use list::*; +#[cfg(feature = "meta")] +pub use meta::*; +pub use name::*; +pub use options::*; +pub use plan::*; +use polars_compute::rolling::QuantileMethod; +use polars_core::chunked_array::cast::CastOptions; +use polars_core::error::feature_gated; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +#[cfg(feature = "diff")] +use polars_core::series::ops::NullBehavior; +#[cfg(any(feature = "search_sorted", feature = "is_between"))] +use polars_core::utils::SuperTypeFlags; +use polars_core::utils::{SuperTypeOptions, try_get_supertype}; +pub use selector::Selector; +#[cfg(feature = "dtype-struct")] +pub use struct_::*; +pub use udf::UserDefinedFunction; +mod file_scan; +pub use file_scan::*; +pub use scan_sources::{ScanSource, ScanSourceIter, ScanSourceRef, ScanSources}; + +use crate::constants::MAP_LIST_NAME; +pub use crate::plans::lit; +use crate::prelude::*; + +impl Expr { + /// Modify the Options passed to the `Function` node. + pub(crate) fn with_function_options(self, func: F) -> Expr + where + F: Fn(FunctionOptions) -> FunctionOptions, + { + match self { + Self::AnonymousFunction { + input, + function, + output_type, + mut options, + } => { + options = func(options); + Self::AnonymousFunction { + input, + function, + output_type, + options, + } + }, + Self::Function { + input, + function, + mut options, + } => { + options = func(options); + Self::Function { + input, + function, + options, + } + }, + _ => { + panic!("implementation error") + }, + } + } + + /// Overwrite the function name used for formatting. + /// (this is not intended to be used). + #[doc(hidden)] + pub fn with_fmt(self, name: &'static str) -> Expr { + self.with_function_options(|mut options| { + options.fmt_str = name; + options + }) + } + + /// Compare `Expr` with other `Expr` on equality. + pub fn eq>(self, other: E) -> Expr { + binary_expr(self, Operator::Eq, other.into()) + } + + /// Compare `Expr` with other `Expr` on equality where `None == None`. + pub fn eq_missing>(self, other: E) -> Expr { + binary_expr(self, Operator::EqValidity, other.into()) + } + + /// Compare `Expr` with other `Expr` on non-equality. + pub fn neq>(self, other: E) -> Expr { + binary_expr(self, Operator::NotEq, other.into()) + } + + /// Compare `Expr` with other `Expr` on non-equality where `None == None`. + pub fn neq_missing>(self, other: E) -> Expr { + binary_expr(self, Operator::NotEqValidity, other.into()) + } + + /// Check if `Expr` < `Expr`. + pub fn lt>(self, other: E) -> Expr { + binary_expr(self, Operator::Lt, other.into()) + } + + /// Check if `Expr` > `Expr`. + pub fn gt>(self, other: E) -> Expr { + binary_expr(self, Operator::Gt, other.into()) + } + + /// Check if `Expr` >= `Expr`. + pub fn gt_eq>(self, other: E) -> Expr { + binary_expr(self, Operator::GtEq, other.into()) + } + + /// Check if `Expr` <= `Expr`. + pub fn lt_eq>(self, other: E) -> Expr { + binary_expr(self, Operator::LtEq, other.into()) + } + + /// Negate `Expr`. + #[allow(clippy::should_implement_trait)] + pub fn not(self) -> Expr { + self.map_unary(BooleanFunction::Not) + } + + /// Rename Column. + pub fn alias(self, name: S) -> Expr + where + S: Into, + { + Expr::Alias(Arc::new(self), name.into()) + } + + /// Run is_null operation on `Expr`. + #[allow(clippy::wrong_self_convention)] + pub fn is_null(self) -> Self { + self.map_unary(BooleanFunction::IsNull) + } + + /// Run is_not_null operation on `Expr`. + #[allow(clippy::wrong_self_convention)] + pub fn is_not_null(self) -> Self { + self.map_unary(BooleanFunction::IsNotNull) + } + + /// Drop null values. + pub fn drop_nulls(self) -> Self { + self.map_unary(FunctionExpr::DropNulls) + } + + /// Drop NaN values. + pub fn drop_nans(self) -> Self { + self.map_unary(FunctionExpr::DropNans) + } + + /// Get the number of unique values in the groups. + pub fn n_unique(self) -> Self { + AggExpr::NUnique(Arc::new(self)).into() + } + + /// Get the first value in the group. + pub fn first(self) -> Self { + AggExpr::First(Arc::new(self)).into() + } + + /// Get the last value in the group. + pub fn last(self) -> Self { + AggExpr::Last(Arc::new(self)).into() + } + + /// GroupBy the group to a Series. + pub fn implode(self) -> Self { + AggExpr::Implode(Arc::new(self)).into() + } + + /// Compute the quantile per group. + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { + AggExpr::Quantile { + expr: Arc::new(self), + quantile: Arc::new(quantile), + method, + } + .into() + } + + /// Get the group indexes of the group by operation. + pub fn agg_groups(self) -> Self { + AggExpr::AggGroups(Arc::new(self)).into() + } + + /// Alias for `explode`. + pub fn flatten(self) -> Self { + self.explode() + } + + /// Explode the String/List column. + pub fn explode(self) -> Self { + Expr::Explode(Arc::new(self)) + } + + /// Slice the Series. + /// `offset` may be negative. + pub fn slice, F: Into>(self, offset: E, length: F) -> Self { + Expr::Slice { + input: Arc::new(self), + offset: Arc::new(offset.into()), + length: Arc::new(length.into()), + } + } + + /// Append expressions. This is done by adding the chunks of `other` to this [`Series`]. + pub fn append>(self, other: E, upcast: bool) -> Self { + let output_type = if upcast { + GetOutput::super_type() + } else { + GetOutput::same_type() + }; + + apply_binary( + self, + other.into(), + move |mut a, mut b| { + if upcast { + let dtype = try_get_supertype(a.dtype(), b.dtype())?; + a = a.cast(&dtype)?; + b = b.cast(&dtype)?; + } + a.append(&b)?; + Ok(Some(a)) + }, + output_type, + ) + } + + /// Get the first `n` elements of the Expr result. + pub fn head(self, length: Option) -> Self { + self.slice(lit(0), lit(length.unwrap_or(10) as u64)) + } + + /// Get the last `n` elements of the Expr result. + pub fn tail(self, length: Option) -> Self { + let len = length.unwrap_or(10); + self.slice(lit(-(len as i64)), lit(len as u64)) + } + + /// Get unique values of this expression. + pub fn unique(self) -> Self { + self.map_unary(FunctionExpr::Unique(false)) + } + + /// Get unique values of this expression, while maintaining order. + /// This requires more work than [`Expr::unique`]. + pub fn unique_stable(self) -> Self { + self.map_unary(FunctionExpr::Unique(true)) + } + + /// Get the first index of unique values of this expression. + pub fn arg_unique(self) -> Self { + self.map_unary(FunctionExpr::ArgUnique) + } + + /// Get the index value that has the minimum value. + pub fn arg_min(self) -> Self { + let options = FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + fmt_str: "arg_min", + ..Default::default() + }; + + self.function_with_options( + move |c: Column| { + Ok(Some(Column::new( + c.name().clone(), + &[c.as_materialized_series().arg_min().map(|idx| idx as u32)], + ))) + }, + GetOutput::from_type(IDX_DTYPE), + options, + ) + } + + /// Get the index value that has the maximum value. + pub fn arg_max(self) -> Self { + let options = FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + fmt_str: "arg_max", + ..Default::default() + }; + + self.function_with_options( + move |c: Column| { + Ok(Some(Column::new( + c.name().clone(), + &[c.as_materialized_series() + .arg_max() + .map(|idx| idx as IdxSize)], + ))) + }, + GetOutput::from_type(IDX_DTYPE), + options, + ) + } + + /// Get the index values that would sort this expression. + pub fn arg_sort(self, sort_options: SortOptions) -> Self { + let options = FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "arg_sort", + ..Default::default() + }; + + self.function_with_options( + move |c: Column| { + Ok(Some( + c.as_materialized_series() + .arg_sort(sort_options) + .into_column(), + )) + }, + GetOutput::from_type(IDX_DTYPE), + options, + ) + } + + #[cfg(feature = "index_of")] + /// Find the index of a value. + pub fn index_of>(self, element: E) -> Expr { + self.map_binary(FunctionExpr::IndexOf, element.into()) + } + + #[cfg(feature = "search_sorted")] + /// Find indices where elements should be inserted to maintain order. + pub fn search_sorted>(self, element: E, side: SearchSortedSide) -> Expr { + self.map_binary(FunctionExpr::SearchSorted(side), element.into()) + } + + /// Cast expression to another data type. + /// Throws an error if conversion had overflows. + /// Returns an Error if cast is invalid on rows after predicates are pushed down. + pub fn strict_cast(self, dtype: DataType) -> Self { + Expr::Cast { + expr: Arc::new(self), + dtype, + options: CastOptions::Strict, + } + } + + /// Cast expression to another data type. + pub fn cast(self, dtype: DataType) -> Self { + Expr::Cast { + expr: Arc::new(self), + dtype, + options: CastOptions::NonStrict, + } + } + + /// Cast expression to another data type. + pub fn cast_with_options(self, dtype: DataType, cast_options: CastOptions) -> Self { + Expr::Cast { + expr: Arc::new(self), + dtype, + options: cast_options, + } + } + + /// Take the values by idx. + pub fn gather>(self, idx: E) -> Self { + Expr::Gather { + expr: Arc::new(self), + idx: Arc::new(idx.into()), + returns_scalar: false, + } + } + + /// Take the values by a single index. + pub fn get>(self, idx: E) -> Self { + Expr::Gather { + expr: Arc::new(self), + idx: Arc::new(idx.into()), + returns_scalar: true, + } + } + + /// Sort with given options. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// # fn main() -> PolarsResult<()> { + /// let lf = df! { + /// "a" => [Some(5), Some(4), Some(3), Some(2), None] + /// }? + /// .lazy(); + /// + /// let sorted = lf + /// .select( + /// vec![col("a").sort(SortOptions::default())], + /// ) + /// .collect()?; + /// + /// assert_eq!( + /// sorted, + /// df! { + /// "a" => [None, Some(2), Some(3), Some(4), Some(5)] + /// }? + /// ); + /// # Ok(()) + /// # } + /// ``` + /// See [`SortOptions`] for more options. + pub fn sort(self, options: SortOptions) -> Self { + Expr::Sort { + expr: Arc::new(self), + options, + } + } + + /// Returns the `k` largest elements. + /// + /// This has time complexity `O(n + k log(n))`. + #[cfg(feature = "top_k")] + pub fn top_k(self, k: Expr) -> Self { + self.map_binary(FunctionExpr::TopK { descending: false }, k) + } + + /// Returns the `k` largest rows by given column. + /// + /// For single column, use [`Expr::top_k`]. + #[cfg(feature = "top_k")] + pub fn top_k_by, E: AsRef<[IE]>, IE: Into + Clone>( + self, + k: K, + by: E, + descending: Vec, + ) -> Self { + self.map_n_ary( + FunctionExpr::TopKBy { descending }, + [k.into()] + .into_iter() + .chain(by.as_ref().iter().map(|e| -> Expr { e.clone().into() })), + ) + } + + /// Returns the `k` smallest elements. + /// + /// This has time complexity `O(n + k log(n))`. + #[cfg(feature = "top_k")] + pub fn bottom_k(self, k: Expr) -> Self { + self.map_binary(FunctionExpr::TopK { descending: true }, k) + } + + /// Returns the `k` smallest rows by given column. + /// + /// For single column, use [`Expr::bottom_k`]. + // #[cfg(feature = "top_k")] + #[cfg(feature = "top_k")] + pub fn bottom_k_by, E: AsRef<[IE]>, IE: Into + Clone>( + self, + k: K, + by: E, + descending: Vec, + ) -> Self { + let descending = descending.into_iter().map(|x| !x).collect(); + self.map_n_ary( + FunctionExpr::TopKBy { descending }, + [k.into()] + .into_iter() + .chain(by.as_ref().iter().map(|e| -> Expr { e.clone().into() })), + ) + } + + /// Reverse column + pub fn reverse(self) -> Self { + self.map_unary(FunctionExpr::Reverse) + } + + /// Apply a function/closure once the logical plan get executed. + /// + /// This function is very similar to [`Expr::apply`], but differs in how it handles aggregations. + /// + /// * `map` should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` + /// * `apply` should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. + /// + /// It is the responsibility of the caller that the schema is correct by giving + /// the correct output_type. If None given the output type of the input expr is used. + pub fn map(self, function: F, output_type: GetOutput) -> Self + where + F: Fn(Column) -> PolarsResult> + 'static + Send + Sync, + { + let f = move |c: &mut [Column]| function(std::mem::take(&mut c[0])); + + Expr::AnonymousFunction { + input: vec![self], + function: new_column_udf(f), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + fmt_str: "map", + flags: FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT, + ..Default::default() + }, + } + } + + /// Apply a function/closure once the logical plan get executed with many arguments. + /// + /// See the [`Expr::map`] function for the differences between [`map`](Expr::map) and [`apply`](Expr::apply). + pub fn map_many(self, function: F, arguments: &[Expr], output_type: GetOutput) -> Self + where + F: Fn(&mut [Column]) -> PolarsResult> + 'static + Send + Sync, + { + let mut input = vec![self]; + input.extend_from_slice(arguments); + + Expr::AnonymousFunction { + input, + function: new_column_udf(function), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + fmt_str: "", + ..Default::default() + }, + } + } + + /// Apply a function/closure once the logical plan get executed. + /// + /// This function is very similar to [apply](Expr::apply), but differs in how it handles aggregations. + /// + /// * `map` should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` + /// * `apply` should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. + /// * `map_list` should be used when the function expects a list aggregated series. + pub fn map_list(self, function: F, output_type: GetOutput) -> Self + where + F: Fn(Column) -> PolarsResult> + 'static + Send + Sync, + { + let f = move |c: &mut [Column]| function(std::mem::take(&mut c[0])); + + Expr::AnonymousFunction { + input: vec![self], + function: new_column_udf(f), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyList, + fmt_str: MAP_LIST_NAME, + ..Default::default() + }, + } + } + + /// A function that cannot be expressed with `map` or `apply` and requires extra settings. + pub fn function_with_options( + self, + function: F, + output_type: GetOutput, + options: FunctionOptions, + ) -> Self + where + F: Fn(Column) -> PolarsResult> + 'static + Send + Sync, + { + let f = move |c: &mut [Column]| function(std::mem::take(&mut c[0])); + + Expr::AnonymousFunction { + input: vec![self], + function: new_column_udf(f), + output_type, + options, + } + } + + /// Apply a function/closure over the groups. This should only be used in a group_by aggregation. + /// + /// It is the responsibility of the caller that the schema is correct by giving + /// the correct output_type. If None given the output type of the input expr is used. + /// + /// This difference with [map](Self::map) is that `apply` will create a separate `Series` per group. + /// + /// * `map` should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` + /// * `apply` should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. + pub fn apply(self, function: F, output_type: GetOutput) -> Self + where + F: Fn(Column) -> PolarsResult> + 'static + Send + Sync, + { + let f = move |c: &mut [Column]| function(std::mem::take(&mut c[0])); + + Expr::AnonymousFunction { + input: vec![self], + function: new_column_udf(f), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "", + ..Default::default() + }, + } + } + + /// Apply a function/closure over the groups with many arguments. This should only be used in a group_by aggregation. + /// + /// See the [`Expr::apply`] function for the differences between [`map`](Expr::map) and [`apply`](Expr::apply). + pub fn apply_many(self, function: F, arguments: &[Expr], output_type: GetOutput) -> Self + where + F: Fn(&mut [Column]) -> PolarsResult> + 'static + Send + Sync, + { + let mut input = vec![self]; + input.extend_from_slice(arguments); + + Expr::AnonymousFunction { + input, + function: new_column_udf(function), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "", + ..Default::default() + }, + } + } + + /// Get mask of finite values if dtype is Float. + #[allow(clippy::wrong_self_convention)] + pub fn is_finite(self) -> Self { + self.map_unary(BooleanFunction::IsFinite) + } + + /// Get mask of infinite values if dtype is Float. + #[allow(clippy::wrong_self_convention)] + pub fn is_infinite(self) -> Self { + self.map_unary(BooleanFunction::IsInfinite) + } + + /// Get mask of NaN values if dtype is Float. + pub fn is_nan(self) -> Self { + self.map_unary(BooleanFunction::IsNan) + } + + /// Get inverse mask of NaN values if dtype is Float. + pub fn is_not_nan(self) -> Self { + self.map_unary(BooleanFunction::IsNotNan) + } + + /// Shift the values in the array by some period. See [the eager implementation](polars_core::series::SeriesTrait::shift). + pub fn shift(self, n: Expr) -> Self { + self.map_binary(FunctionExpr::Shift, n) + } + + /// Shift the values in the array by some period and fill the resulting empty values. + pub fn shift_and_fill, IE: Into>(self, n: E, fill_value: IE) -> Self { + self.map_ternary(FunctionExpr::ShiftAndFill, n.into(), fill_value.into()) + } + + /// Cumulatively count values from 0 to len. + #[cfg(feature = "cum_agg")] + pub fn cum_count(self, reverse: bool) -> Self { + self.map_unary(FunctionExpr::CumCount { reverse }) + } + + /// Get an array with the cumulative sum computed at every element. + #[cfg(feature = "cum_agg")] + pub fn cum_sum(self, reverse: bool) -> Self { + self.map_unary(FunctionExpr::CumSum { reverse }) + } + + /// Get an array with the cumulative product computed at every element. + #[cfg(feature = "cum_agg")] + pub fn cum_prod(self, reverse: bool) -> Self { + self.map_unary(FunctionExpr::CumProd { reverse }) + } + + /// Get an array with the cumulative min computed at every element. + #[cfg(feature = "cum_agg")] + pub fn cum_min(self, reverse: bool) -> Self { + self.map_unary(FunctionExpr::CumMin { reverse }) + } + + /// Get an array with the cumulative max computed at every element. + #[cfg(feature = "cum_agg")] + pub fn cum_max(self, reverse: bool) -> Self { + self.map_unary(FunctionExpr::CumMax { reverse }) + } + + /// Get the product aggregation of an expression. + pub fn product(self) -> Self { + let options = FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + fmt_str: "product", + ..Default::default() + }; + + self.function_with_options( + move |c: Column| { + Some( + c.product() + .map(|sc| sc.into_series(c.name().clone()).into_column()), + ) + .transpose() + }, + GetOutput::map_dtype(|dt| { + use DataType as T; + Ok(match dt { + T::Float32 => T::Float32, + T::Float64 => T::Float64, + T::UInt64 => T::UInt64, + #[cfg(feature = "dtype-i128")] + T::Int128 => T::Int128, + _ => T::Int64, + }) + }), + options, + ) + } + + /// Round underlying floating point array to given decimal numbers. + #[cfg(feature = "round_series")] + pub fn round(self, decimals: u32) -> Self { + self.map_unary(FunctionExpr::Round { decimals }) + } + + /// Round to a number of significant figures. + #[cfg(feature = "round_series")] + pub fn round_sig_figs(self, digits: i32) -> Self { + self.map_unary(FunctionExpr::RoundSF { digits }) + } + + /// Floor underlying floating point array to the lowest integers smaller or equal to the float value. + #[cfg(feature = "round_series")] + pub fn floor(self) -> Self { + self.map_unary(FunctionExpr::Floor) + } + + /// Constant Pi + #[cfg(feature = "round_series")] + pub fn pi() -> Self { + lit(std::f64::consts::PI) + } + + /// Ceil underlying floating point array to the highest integers smaller or equal to the float value. + #[cfg(feature = "round_series")] + pub fn ceil(self) -> Self { + self.map_unary(FunctionExpr::Ceil) + } + + /// Clip underlying values to a set boundary. + #[cfg(feature = "round_series")] + pub fn clip(self, min: Expr, max: Expr) -> Self { + self.map_ternary( + FunctionExpr::Clip { + has_min: true, + has_max: true, + }, + min, + max, + ) + } + + /// Clip underlying values to a set boundary. + #[cfg(feature = "round_series")] + pub fn clip_max(self, max: Expr) -> Self { + self.map_binary( + FunctionExpr::Clip { + has_min: false, + has_max: true, + }, + max, + ) + } + + /// Clip underlying values to a set boundary. + #[cfg(feature = "round_series")] + pub fn clip_min(self, min: Expr) -> Self { + self.map_binary( + FunctionExpr::Clip { + has_min: true, + has_max: false, + }, + min, + ) + } + + /// Convert all values to their absolute/positive value. + #[cfg(feature = "abs")] + pub fn abs(self) -> Self { + self.map_unary(FunctionExpr::Abs) + } + + /// Apply window function over a subgroup. + /// This is similar to a group_by + aggregation + self join. + /// Or similar to [window functions in Postgres](https://www.postgresql.org/docs/9.1/tutorial-window.html). + /// + /// # Example + /// + /// ``` rust + /// #[macro_use] extern crate polars_core; + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// fn example() -> PolarsResult<()> { + /// let df = df! { + /// "groups" => &[1, 1, 2, 2, 1, 2, 3, 3, 1], + /// "values" => &[1, 2, 3, 4, 5, 6, 7, 8, 8] + /// }?; + /// + /// let out = df + /// .lazy() + /// .select(&[ + /// col("groups"), + /// sum("values").over([col("groups")]), + /// ]) + /// .collect()?; + /// println!("{}", &out); + /// Ok(()) + /// } + /// + /// ``` + /// + /// Outputs: + /// + /// ``` text + /// ╭────────┬────────╮ + /// │ groups ┆ values │ + /// │ --- ┆ --- │ + /// │ i32 ┆ i32 │ + /// ╞════════╪════════╡ + /// │ 1 ┆ 16 │ + /// │ 1 ┆ 16 │ + /// │ 2 ┆ 13 │ + /// │ 2 ┆ 13 │ + /// │ … ┆ … │ + /// │ 1 ┆ 16 │ + /// │ 2 ┆ 13 │ + /// │ 3 ┆ 15 │ + /// │ 3 ┆ 15 │ + /// │ 1 ┆ 16 │ + /// ╰────────┴────────╯ + /// ``` + pub fn over, IE: Into + Clone>(self, partition_by: E) -> Self { + self.over_with_options(partition_by, None, Default::default()) + } + + pub fn over_with_options, IE: Into + Clone>( + self, + partition_by: E, + order_by: Option<(E, SortOptions)>, + options: WindowMapping, + ) -> Self { + let partition_by = partition_by + .as_ref() + .iter() + .map(|e| e.clone().into()) + .collect(); + + let order_by = order_by.map(|(e, options)| { + let e = e.as_ref(); + let e = if e.len() == 1 { + Arc::new(e[0].clone().into()) + } else { + feature_gated!["dtype-struct", { + let e = e.iter().map(|e| e.clone().into()).collect::>(); + Arc::new(as_struct(e)) + }] + }; + (e, options) + }); + + Expr::Window { + function: Arc::new(self), + partition_by, + order_by, + options: options.into(), + } + } + + #[cfg(feature = "dynamic_group_by")] + pub fn rolling(self, options: RollingGroupOptions) -> Self { + // We add the index column as `partition expr` so that the optimizer will + // not ignore it. + let index_col = col(options.index_column.clone()); + Expr::Window { + function: Arc::new(self), + partition_by: vec![index_col], + order_by: None, + options: WindowType::Rolling(options), + } + } + + fn fill_null_impl(self, fill_value: Expr) -> Self { + self.map_binary(FunctionExpr::FillNull, fill_value) + } + + /// Replace the null values by a value. + pub fn fill_null>(self, fill_value: E) -> Self { + self.fill_null_impl(fill_value.into()) + } + + pub fn fill_null_with_strategy(self, strategy: FillNullStrategy) -> Self { + self.map_unary(FunctionExpr::FillNullWithStrategy(strategy)) + } + + /// Replace the floating point `NaN` values by a value. + pub fn fill_nan>(self, fill_value: E) -> Self { + // we take the not branch so that self is truthy value of `when -> then -> otherwise` + // and that ensure we keep the name of `self` + + when(self.clone().is_not_nan().or(self.clone().is_null())) + .then(self) + .otherwise(fill_value.into()) + } + /// Count the values of the Series + /// or + /// Get counts of the group by operation. + pub fn count(self) -> Self { + AggExpr::Count(Arc::new(self), false).into() + } + + pub fn len(self) -> Self { + AggExpr::Count(Arc::new(self), true).into() + } + + /// Get a mask of duplicated values. + #[allow(clippy::wrong_self_convention)] + #[cfg(feature = "is_unique")] + pub fn is_duplicated(self) -> Self { + self.map_unary(BooleanFunction::IsDuplicated) + } + + #[allow(clippy::wrong_self_convention)] + #[cfg(feature = "is_between")] + pub fn is_between>(self, lower: E, upper: E, closed: ClosedInterval) -> Self { + self.map_ternary( + BooleanFunction::IsBetween { closed }, + lower.into(), + upper.into(), + ) + } + + /// Get a mask of unique values. + #[allow(clippy::wrong_self_convention)] + #[cfg(feature = "is_unique")] + pub fn is_unique(self) -> Self { + self.map_unary(BooleanFunction::IsUnique) + } + + /// Get the approximate count of unique values. + #[cfg(feature = "approx_unique")] + pub fn approx_n_unique(self) -> Self { + self.map_unary(FunctionExpr::ApproxNUnique) + } + + /// Bitwise "and" operation. + pub fn and>(self, expr: E) -> Self { + binary_expr(self, Operator::And, expr.into()) + } + + /// Bitwise "xor" operation. + pub fn xor>(self, expr: E) -> Self { + binary_expr(self, Operator::Xor, expr.into()) + } + + /// Bitwise "or" operation. + pub fn or>(self, expr: E) -> Self { + binary_expr(self, Operator::Or, expr.into()) + } + + /// Logical "or" operation. + pub fn logical_or>(self, expr: E) -> Self { + binary_expr(self, Operator::LogicalOr, expr.into()) + } + + /// Logical "and" operation. + pub fn logical_and>(self, expr: E) -> Self { + binary_expr(self, Operator::LogicalAnd, expr.into()) + } + + /// Filter a single column. + /// + /// Should be used in aggregation context. If you want to filter on a + /// DataFrame level, use `LazyFrame::filter`. + pub fn filter>(self, predicate: E) -> Self { + if has_expr(&self, |e| matches!(e, Expr::Wildcard)) { + panic!("filter '*' not allowed, use LazyFrame::filter") + }; + Expr::Filter { + input: Arc::new(self), + by: Arc::new(predicate.into()), + } + } + + /// Check if the values of the left expression are in the lists of the right expr. + #[allow(clippy::wrong_self_convention)] + #[cfg(feature = "is_in")] + pub fn is_in>(self, other: E, nulls_equal: bool) -> Self { + let other = other.into(); + let function = BooleanFunction::IsIn { nulls_equal }; + let options = function.function_options(); + let function = function.into(); + Expr::Function { + input: vec![self, other], + function, + options, + } + } + + /// Sort this column by the ordering of another column evaluated from given expr. + /// Can also be used in a group_by context to sort the groups. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// # fn main() -> PolarsResult<()> { + /// let lf = df! { + /// "a" => [1, 2, 3, 4, 5], + /// "b" => [5, 4, 3, 2, 1] + /// }?.lazy(); + /// + /// let sorted = lf + /// .select( + /// vec![col("a").sort_by(col("b"), SortOptions::default())], + /// ) + /// .collect()?; + /// + /// assert_eq!( + /// sorted, + /// df! { "a" => [5, 4, 3, 2, 1] }? + /// ); + /// # Ok(()) + /// # } + pub fn sort_by, IE: Into + Clone>( + self, + by: E, + sort_options: SortMultipleOptions, + ) -> Expr { + let by = by.as_ref().iter().map(|e| e.clone().into()).collect(); + Expr::SortBy { + expr: Arc::new(self), + by, + sort_options, + } + } + + #[cfg(feature = "repeat_by")] + /// Repeat the column `n` times, where `n` is determined by the values in `by`. + /// This yields an `Expr` of dtype `List`. + pub fn repeat_by>(self, by: E) -> Expr { + self.map_binary(FunctionExpr::RepeatBy, by.into()) + } + + #[cfg(feature = "is_first_distinct")] + #[allow(clippy::wrong_self_convention)] + /// Get a mask of the first unique value. + pub fn is_first_distinct(self) -> Expr { + self.map_unary(BooleanFunction::IsFirstDistinct) + } + + #[cfg(feature = "is_last_distinct")] + #[allow(clippy::wrong_self_convention)] + /// Get a mask of the last unique value. + pub fn is_last_distinct(self) -> Expr { + self.map_unary(BooleanFunction::IsLastDistinct) + } + + fn dot_impl(self, other: Expr) -> Expr { + (self * other).sum() + } + + /// Compute the dot/inner product between two expressions. + pub fn dot>(self, other: E) -> Expr { + self.dot_impl(other.into()) + } + + #[cfg(feature = "mode")] + /// Compute the mode(s) of this column. This is the most occurring value. + pub fn mode(self) -> Expr { + self.map_unary(FunctionExpr::Mode) + } + + /// Exclude a column from a wildcard/regex selection. + /// + /// You may also use regexes in the exclude as long as they start with `^` and end with `$`. + pub fn exclude(self, columns: impl IntoVec) -> Expr { + let v = columns.into_vec().into_iter().map(Excluded::Name).collect(); + Expr::Exclude(Arc::new(self), v) + } + + pub fn exclude_dtype>(self, dtypes: D) -> Expr { + let v = dtypes + .as_ref() + .iter() + .map(|dt| Excluded::Dtype(dt.clone())) + .collect(); + Expr::Exclude(Arc::new(self), v) + } + + #[cfg(feature = "interpolate")] + /// Interpolate intermediate values. + /// Nulls at the beginning and end of the series remain null. + pub fn interpolate(self, method: InterpolationMethod) -> Expr { + self.map_unary(FunctionExpr::Interpolate(method)) + } + + #[cfg(feature = "rolling_window_by")] + #[allow(clippy::type_complexity)] + fn finish_rolling_by( + self, + by: Expr, + options: RollingOptionsDynamicWindow, + rolling_function_by: fn(RollingOptionsDynamicWindow) -> RollingFunctionBy, + ) -> Expr { + self.map_binary( + FunctionExpr::RollingExprBy(rolling_function_by(options)), + by, + ) + } + + #[cfg(feature = "interpolate_by")] + /// Interpolate intermediate values. + /// Nulls at the beginning and end of the series remain null. + /// The `by` column provides the x-coordinates for interpolation and must not contain nulls. + pub fn interpolate_by(self, by: Expr) -> Expr { + self.map_binary(FunctionExpr::InterpolateBy, by) + } + + #[cfg(feature = "rolling_window")] + #[allow(clippy::type_complexity)] + fn finish_rolling( + self, + options: RollingOptionsFixedWindow, + rolling_function: fn(RollingOptionsFixedWindow) -> RollingFunction, + ) -> Expr { + self.map_unary(FunctionExpr::RollingExpr(rolling_function(options))) + } + + /// Apply a rolling minimum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_min_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MinBy) + } + + /// Apply a rolling maximum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_max_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MaxBy) + } + + /// Apply a rolling mean based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_mean_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MeanBy) + } + + /// Apply a rolling sum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_sum_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::SumBy) + } + + /// Apply a rolling quantile based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_quantile_by( + self, + by: Expr, + method: QuantileMethod, + quantile: f64, + mut options: RollingOptionsDynamicWindow, + ) -> Expr { + use polars_compute::rolling::{RollingFnParams, RollingQuantileParams}; + options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: quantile, + method, + })); + + self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) + } + + /// Apply a rolling variance based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_var_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::VarBy) + } + + /// Apply a rolling std-dev based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_std_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::StdBy) + } + + /// Apply a rolling median based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.rolling_quantile_by(by, QuantileMethod::Linear, 0.5, options) + } + + /// Apply a rolling minimum. + /// + /// See: [`RollingAgg::rolling_min`] + #[cfg(feature = "rolling_window")] + pub fn rolling_min(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Min) + } + + /// Apply a rolling maximum. + /// + /// See: [`RollingAgg::rolling_max`] + #[cfg(feature = "rolling_window")] + pub fn rolling_max(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Max) + } + + /// Apply a rolling mean. + /// + /// See: [`RollingAgg::rolling_mean`] + #[cfg(feature = "rolling_window")] + pub fn rolling_mean(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Mean) + } + + /// Apply a rolling sum. + /// + /// See: [`RollingAgg::rolling_sum`] + #[cfg(feature = "rolling_window")] + pub fn rolling_sum(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Sum) + } + + /// Apply a rolling median. + /// + /// See: [`RollingAgg::rolling_median`] + #[cfg(feature = "rolling_window")] + pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { + self.rolling_quantile(QuantileMethod::Linear, 0.5, options) + } + + /// Apply a rolling quantile. + /// + /// See: [`RollingAgg::rolling_quantile`] + #[cfg(feature = "rolling_window")] + pub fn rolling_quantile( + self, + method: QuantileMethod, + quantile: f64, + mut options: RollingOptionsFixedWindow, + ) -> Expr { + use polars_compute::rolling::{RollingFnParams, RollingQuantileParams}; + + options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { + prob: quantile, + method, + })); + + self.finish_rolling(options, RollingFunction::Quantile) + } + + /// Apply a rolling variance. + #[cfg(feature = "rolling_window")] + pub fn rolling_var(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Var) + } + + /// Apply a rolling std-dev. + #[cfg(feature = "rolling_window")] + pub fn rolling_std(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Std) + } + + /// Apply a rolling skew. + #[cfg(feature = "rolling_window")] + #[cfg(feature = "moment")] + pub fn rolling_skew(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Skew) + } + + /// Apply a rolling skew. + #[cfg(feature = "rolling_window")] + #[cfg(feature = "moment")] + pub fn rolling_kurtosis(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Kurtosis) + } + + #[cfg(feature = "rolling_window")] + /// Apply a custom function over a rolling/ moving window of the array. + /// This has quite some dynamic dispatch, so prefer rolling_min, max, mean, sum over this. + pub fn rolling_map( + self, + f: Arc Series + Send + Sync>, + output_type: GetOutput, + options: RollingOptionsFixedWindow, + ) -> Expr { + self.apply( + move |c: Column| { + c.as_materialized_series() + .rolling_map(f.as_ref(), options.clone()) + .map(Column::from) + .map(Some) + }, + output_type, + ) + .with_fmt("rolling_map") + } + + #[cfg(feature = "rolling_window")] + /// Apply a custom function over a rolling/ moving window of the array. + /// Prefer this over rolling_apply in case of floating point numbers as this is faster. + /// This has quite some dynamic dispatch, so prefer rolling_min, max, mean, sum over this. + pub fn rolling_map_float(self, window_size: usize, f: F) -> Expr + where + F: 'static + FnMut(&mut Float64Chunked) -> Option + Send + Sync + Copy, + { + self.apply( + move |c: Column| { + let out = match c.dtype() { + DataType::Float64 => c + .f64() + .unwrap() + .rolling_map_float(window_size, f) + .map(|ca| ca.into_column()), + _ => c + .cast(&DataType::Float64)? + .f64() + .unwrap() + .rolling_map_float(window_size, f) + .map(|ca| ca.into_column()), + }?; + if let DataType::Float32 = c.dtype() { + out.cast(&DataType::Float32).map(Some) + } else { + Ok(Some(out)) + } + }, + GetOutput::map_field(|field| { + Ok(match field.dtype() { + DataType::Float64 => field.clone(), + DataType::Float32 => Field::new(field.name().clone(), DataType::Float32), + _ => Field::new(field.name().clone(), DataType::Float64), + }) + }), + ) + .with_fmt("rolling_map_float") + } + + #[cfg(feature = "peaks")] + pub fn peak_min(self) -> Expr { + self.map_unary(FunctionExpr::PeakMin) + } + + #[cfg(feature = "peaks")] + pub fn peak_max(self) -> Expr { + self.map_unary(FunctionExpr::PeakMax) + } + + #[cfg(feature = "rank")] + /// Assign ranks to data, dealing with ties appropriately. + pub fn rank(self, options: RankOptions, seed: Option) -> Expr { + self.map_unary(FunctionExpr::Rank { options, seed }) + } + + #[cfg(feature = "replace")] + /// Replace the given values with other values. + pub fn replace>(self, old: E, new: E) -> Expr { + let old = old.into(); + let new = new.into(); + let literal_args = is_column_independent(&old) && is_column_independent(&new); + let function = FunctionExpr::Replace; + let mut options = function.function_options(); + if !literal_args { + // If we search and replace by constants, we can run on batches. + // TODO: this optimization should be done during conversion to IR. + options.collect_groups = ApplyOptions::GroupWise; + } + Expr::Function { + input: vec![self, old, new], + function, + options, + } + } + + #[cfg(feature = "replace")] + /// Replace the given values with other values. + pub fn replace_strict>( + self, + old: E, + new: E, + default: Option, + return_dtype: Option, + ) -> Expr { + let old = old.into(); + let new = new.into(); + + // If we replace by constants, we can run on batches. + // TODO: this optimization should be done during conversion to IR. + let literal_args = is_column_independent(&old) && is_column_independent(&new); + + let mut args = vec![self, old, new]; + args.extend(default.map(Into::into)); + let function = FunctionExpr::ReplaceStrict { return_dtype }; + let mut options = function.function_options(); + if !literal_args { + // If we search and replace by constants, we can run on batches. + // TODO: this optimization should be done during conversion to IR. + options.collect_groups = ApplyOptions::GroupWise; + } + Expr::Function { + input: args, + function, + options, + } + } + + #[cfg(feature = "cutqcut")] + /// Bin continuous values into discrete categories. + pub fn cut( + self, + breaks: Vec, + labels: Option>, + left_closed: bool, + include_breaks: bool, + ) -> Expr { + self.map_unary(FunctionExpr::Cut { + breaks, + labels: labels.map(|x| x.into_vec()), + left_closed, + include_breaks, + }) + } + + #[cfg(feature = "cutqcut")] + /// Bin continuous values into discrete categories based on their quantiles. + pub fn qcut( + self, + probs: Vec, + labels: Option>, + left_closed: bool, + allow_duplicates: bool, + include_breaks: bool, + ) -> Expr { + self.map_unary(FunctionExpr::QCut { + probs, + labels: labels.map(|x| x.into_vec()), + left_closed, + allow_duplicates, + include_breaks, + }) + } + + #[cfg(feature = "cutqcut")] + /// Bin continuous values into discrete categories using uniform quantile probabilities. + pub fn qcut_uniform( + self, + n_bins: usize, + labels: Option>, + left_closed: bool, + allow_duplicates: bool, + include_breaks: bool, + ) -> Expr { + let probs = (1..n_bins).map(|b| b as f64 / n_bins as f64).collect(); + self.map_unary(FunctionExpr::QCut { + probs, + labels: labels.map(|x| x.into_vec()), + left_closed, + allow_duplicates, + include_breaks, + }) + } + + #[cfg(feature = "rle")] + /// Get the lengths of runs of identical values. + pub fn rle(self) -> Expr { + self.map_unary(FunctionExpr::RLE) + } + + #[cfg(feature = "rle")] + /// Similar to `rle`, but maps values to run IDs. + pub fn rle_id(self) -> Expr { + self.map_unary(FunctionExpr::RLEID) + } + + #[cfg(feature = "diff")] + /// Calculate the n-th discrete difference between values. + pub fn diff(self, n: Expr, null_behavior: NullBehavior) -> Expr { + self.map_binary(FunctionExpr::Diff(null_behavior), n) + } + + #[cfg(feature = "pct_change")] + /// Computes percentage change between values. + pub fn pct_change(self, n: Expr) -> Expr { + self.map_binary(FunctionExpr::PctChange, n) + } + + #[cfg(feature = "moment")] + /// Compute the sample skewness of a data set. + /// + /// For normally distributed data, the skewness should be about zero. For + /// uni-modal continuous distributions, a skewness value greater than zero means + /// that there is more weight in the right tail of the distribution. The + /// function `skewtest` can be used to determine if the skewness value + /// is close enough to zero, statistically speaking. + /// + /// see: [scipy](https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/stats/stats.py#L1024) + pub fn skew(self, bias: bool) -> Expr { + self.map_unary(FunctionExpr::Skew(bias)) + } + + #[cfg(feature = "moment")] + /// Compute the kurtosis (Fisher or Pearson). + /// + /// Kurtosis is the fourth central moment divided by the square of the + /// variance. If Fisher's definition is used, then 3.0 is subtracted from + /// the result to give 0.0 for a normal distribution. + /// If bias is False then the kurtosis is calculated using k statistics to + /// eliminate bias coming from biased moment estimators. + pub fn kurtosis(self, fisher: bool, bias: bool) -> Expr { + self.map_unary(FunctionExpr::Kurtosis(fisher, bias)) + } + + /// Get maximal value that could be hold by this dtype. + pub fn upper_bound(self) -> Expr { + self.map_unary(FunctionExpr::UpperBound) + } + + /// Get minimal value that could be hold by this dtype. + pub fn lower_bound(self) -> Expr { + self.map_unary(FunctionExpr::LowerBound) + } + + #[cfg(feature = "dtype-array")] + pub fn reshape(self, dimensions: &[i64]) -> Self { + let dimensions = dimensions + .iter() + .map(|&v| ReshapeDimension::new(v)) + .collect(); + self.map_unary(FunctionExpr::Reshape(dimensions)) + } + + #[cfg(feature = "ewma")] + /// Calculate the exponentially-weighted moving average. + pub fn ewm_mean(self, options: EWMOptions) -> Self { + self.map_unary(FunctionExpr::EwmMean { options }) + } + + #[cfg(feature = "ewma_by")] + /// Calculate the exponentially-weighted moving average by a time column. + pub fn ewm_mean_by(self, times: Expr, half_life: Duration) -> Self { + self.map_binary(FunctionExpr::EwmMeanBy { half_life }, times) + } + + #[cfg(feature = "ewma")] + /// Calculate the exponentially-weighted moving standard deviation. + pub fn ewm_std(self, options: EWMOptions) -> Self { + self.map_unary(FunctionExpr::EwmStd { options }) + } + + #[cfg(feature = "ewma")] + /// Calculate the exponentially-weighted moving variance. + pub fn ewm_var(self, options: EWMOptions) -> Self { + self.map_unary(FunctionExpr::EwmVar { options }) + } + + /// Returns whether any of the values in the column are `true`. + /// + /// If `ignore_nulls` is `False`, [Kleene logic] is used to deal with nulls: + /// if the column contains any null values and no `true` values, the output + /// is null. + /// + /// [Kleene logic]: https://en.wikipedia.org/wiki/Three-valued_logic + pub fn any(self, ignore_nulls: bool) -> Self { + self.map_unary(BooleanFunction::Any { ignore_nulls }) + } + + /// Returns whether all values in the column are `true`. + /// + /// If `ignore_nulls` is `False`, [Kleene logic] is used to deal with nulls: + /// if the column contains any null values and no `false` values, the output + /// is null. + /// + /// [Kleene logic]: https://en.wikipedia.org/wiki/Three-valued_logic + pub fn all(self, ignore_nulls: bool) -> Self { + self.map_unary(BooleanFunction::All { ignore_nulls }) + } + + /// Shrink numeric columns to the minimal required datatype + /// needed to fit the extrema of this [`Series`]. + /// This can be used to reduce memory pressure. + pub fn shrink_dtype(self) -> Self { + self.map_unary(FunctionExpr::ShrinkType) + } + + #[cfg(feature = "dtype-struct")] + /// Count all unique values and create a struct mapping value to count. + /// (Note that it is better to turn parallel off in the aggregation context). + /// The name of the struct field with the counts is given by the parameter `name`. + pub fn value_counts(self, sort: bool, parallel: bool, name: &str, normalize: bool) -> Self { + self.map_unary(FunctionExpr::ValueCounts { + sort, + parallel, + name: name.into(), + normalize, + }) + } + + #[cfg(feature = "unique_counts")] + /// Returns a count of the unique values in the order of appearance. + /// This method differs from [`Expr::value_counts`] in that it does not return the + /// values, only the counts and might be faster. + pub fn unique_counts(self) -> Self { + self.map_unary(FunctionExpr::UniqueCounts) + } + + #[cfg(feature = "log")] + /// Compute the logarithm to a given base. + pub fn log(self, base: f64) -> Self { + self.map_unary(FunctionExpr::Log { base }) + } + + #[cfg(feature = "log")] + /// Compute the natural logarithm of all elements plus one in the input array. + pub fn log1p(self) -> Self { + self.map_unary(FunctionExpr::Log1p) + } + + #[cfg(feature = "log")] + /// Calculate the exponential of all elements in the input array. + pub fn exp(self) -> Self { + self.map_unary(FunctionExpr::Exp) + } + + #[cfg(feature = "log")] + /// Compute the entropy as `-sum(pk * log(pk)`. + /// where `pk` are discrete probabilities. + pub fn entropy(self, base: f64, normalize: bool) -> Self { + self.map_unary(FunctionExpr::Entropy { base, normalize }) + } + /// Get the null count of the column/group. + pub fn null_count(self) -> Expr { + self.map_unary(FunctionExpr::NullCount) + } + + /// Set this `Series` as `sorted` so that downstream code can use + /// fast paths for sorted arrays. + /// # Warning + /// This can lead to incorrect results if this `Series` is not sorted!! + /// Use with care! + pub fn set_sorted_flag(self, sorted: IsSorted) -> Expr { + // This is `map`. If a column is sorted. Chunks of that column are also sorted. + self.map_unary(FunctionExpr::SetSortedFlag(sorted)) + } + + #[cfg(feature = "row_hash")] + /// Compute the hash of every element. + pub fn hash(self, k0: u64, k1: u64, k2: u64, k3: u64) -> Expr { + self.map_unary(FunctionExpr::Hash(k0, k1, k2, k3)) + } + + pub fn to_physical(self) -> Expr { + self.map_unary(FunctionExpr::ToPhysical) + } + + pub fn gather_every(self, n: usize, offset: usize) -> Expr { + self.map_unary(FunctionExpr::GatherEvery { n, offset }) + } + + #[cfg(feature = "reinterpret")] + pub fn reinterpret(self, signed: bool) -> Expr { + self.map_unary(FunctionExpr::Reinterpret(signed)) + } + + pub fn extend_constant(self, value: Expr, n: Expr) -> Expr { + self.map_ternary(FunctionExpr::ExtendConstant, value, n) + } + + #[cfg(feature = "strings")] + /// Get the [`string::StringNameSpace`] + pub fn str(self) -> string::StringNameSpace { + string::StringNameSpace(self) + } + + /// Get the [`binary::BinaryNameSpace`] + pub fn binary(self) -> binary::BinaryNameSpace { + binary::BinaryNameSpace(self) + } + + #[cfg(feature = "temporal")] + /// Get the [`dt::DateLikeNameSpace`] + pub fn dt(self) -> dt::DateLikeNameSpace { + dt::DateLikeNameSpace(self) + } + + /// Get the [`list::ListNameSpace`] + pub fn list(self) -> list::ListNameSpace { + list::ListNameSpace(self) + } + + /// Get the [`name::ExprNameNameSpace`] + pub fn name(self) -> name::ExprNameNameSpace { + name::ExprNameNameSpace(self) + } + + /// Get the [`array::ArrayNameSpace`]. + #[cfg(feature = "dtype-array")] + pub fn arr(self) -> array::ArrayNameSpace { + array::ArrayNameSpace(self) + } + + /// Get the [`CategoricalNameSpace`]. + #[cfg(feature = "dtype-categorical")] + pub fn cat(self) -> cat::CategoricalNameSpace { + cat::CategoricalNameSpace(self) + } + + /// Get the [`struct_::StructNameSpace`]. + #[cfg(feature = "dtype-struct")] + pub fn struct_(self) -> struct_::StructNameSpace { + struct_::StructNameSpace(self) + } + + /// Get the [`meta::MetaNameSpace`] + #[cfg(feature = "meta")] + pub fn meta(self) -> meta::MetaNameSpace { + meta::MetaNameSpace(self) + } +} + +/// Apply a function/closure over multiple columns once the logical plan get executed. +/// +/// This function is very similar to [`apply_multiple`], but differs in how it handles aggregations. +/// +/// * [`map_multiple`] should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` +/// * [`apply_multiple`] should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. +/// +/// It is the responsibility of the caller that the schema is correct by giving +/// the correct output_type. If None given the output type of the input expr is used. +pub fn map_multiple(function: F, expr: E, output_type: GetOutput) -> Expr +where + F: Fn(&mut [Column]) -> PolarsResult> + 'static + Send + Sync, + E: AsRef<[Expr]>, +{ + let input = expr.as_ref().to_vec(); + + Expr::AnonymousFunction { + input, + function: new_column_udf(function), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + fmt_str: "", + ..Default::default() + }, + } +} + +/// Apply a function/closure over multiple columns once the logical plan get executed. +/// +/// This function is very similar to [`apply_multiple`], but differs in how it handles aggregations. +/// +/// * [`map_multiple`] should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` +/// * [`apply_multiple`] should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. +/// * [`map_list_multiple`] should be used when the function expects a list aggregated series. +pub fn map_list_multiple(function: F, expr: E, output_type: GetOutput) -> Expr +where + F: Fn(&mut [Column]) -> PolarsResult> + 'static + Send + Sync, + E: AsRef<[Expr]>, +{ + let input = expr.as_ref().to_vec(); + + Expr::AnonymousFunction { + input, + function: new_column_udf(function), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyList, + fmt_str: "", + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + ..Default::default() + }, + } +} + +/// Apply a function/closure over the groups of multiple columns. This should only be used in a group_by aggregation. +/// +/// It is the responsibility of the caller that the schema is correct by giving +/// the correct output_type. If None given the output type of the input expr is used. +/// +/// This difference with [`map_multiple`] is that [`apply_multiple`] will create a separate [`Series`] per group. +/// +/// * [`map_multiple`] should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` +/// * [`apply_multiple`] should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. +pub fn apply_multiple( + function: F, + expr: E, + output_type: GetOutput, + returns_scalar: bool, +) -> Expr +where + F: Fn(&mut [Column]) -> PolarsResult> + 'static + Send + Sync, + E: AsRef<[Expr]>, +{ + let input = expr.as_ref().to_vec(); + let mut flags = FunctionFlags::default(); + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } + + Expr::AnonymousFunction { + input, + function: new_column_udf(function), + output_type, + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + // don't set this to true + // this is for the caller to decide + fmt_str: "", + flags, + ..Default::default() + }, + } +} + +/// Return the number of rows in the context. +pub fn len() -> Expr { + Expr::Len +} + +/// First column in a DataFrame. +pub fn first() -> Expr { + Expr::Nth(0) +} + +/// Last column in a DataFrame. +pub fn last() -> Expr { + Expr::Nth(-1) +} + +/// Nth column in a DataFrame. +pub fn nth(n: i64) -> Expr { + Expr::Nth(n) +} diff --git a/crates/polars-plan/src/dsl/name.rs b/crates/polars-plan/src/dsl/name.rs new file mode 100644 index 000000000000..e1115c8b3a0b --- /dev/null +++ b/crates/polars-plan/src/dsl/name.rs @@ -0,0 +1,114 @@ +use polars_utils::format_pl_smallstr; +#[cfg(feature = "dtype-struct")] +use polars_utils::pl_str::PlSmallStr; + +use super::*; + +/// Specialized expressions for modifying the name of existing expressions. +pub struct ExprNameNameSpace(pub(crate) Expr); + +impl ExprNameNameSpace { + /// Keep the original root name + /// + /// ```rust,no_run + /// # use polars_core::prelude::*; + /// # use polars_plan::prelude::*; + /// fn example(df: LazyFrame) -> LazyFrame { + /// df.select([ + /// // even thought the alias yields a different column name, + /// // `keep` will make sure that the original column name is used + /// col("*").alias("foo").name().keep() + /// ]) + /// } + /// ``` + pub fn keep(self) -> Expr { + Expr::KeepName(Arc::new(self.0)) + } + + /// Define an alias by mapping a function over the original root column name. + pub fn map(self, function: F) -> Expr + where + F: Fn(&PlSmallStr) -> PolarsResult + 'static + Send + Sync, + { + let function = SpecialEq::new(Arc::new(function) as Arc); + Expr::RenameAlias { + expr: Arc::new(self.0), + function, + } + } + + /// Add a prefix to the root column name. + pub fn prefix(self, prefix: &str) -> Expr { + let prefix = prefix.to_string(); + self.map(move |name| Ok(format_pl_smallstr!("{prefix}{name}"))) + } + + /// Add a suffix to the root column name. + pub fn suffix(self, suffix: &str) -> Expr { + let suffix = suffix.to_string(); + self.map(move |name| Ok(format_pl_smallstr!("{name}{suffix}"))) + } + + /// Update the root column name to use lowercase characters. + #[allow(clippy::wrong_self_convention)] + pub fn to_lowercase(self) -> Expr { + self.map(move |name| Ok(PlSmallStr::from_string(name.to_lowercase()))) + } + + /// Update the root column name to use uppercase characters. + #[allow(clippy::wrong_self_convention)] + pub fn to_uppercase(self) -> Expr { + self.map(move |name| Ok(PlSmallStr::from_string(name.to_uppercase()))) + } + + #[cfg(feature = "dtype-struct")] + pub fn map_fields(self, function: FieldsNameMapper) -> Expr { + let f = function.clone(); + self.0.map( + move |s| { + let s = s.struct_()?; + let fields = s + .fields_as_series() + .iter() + .map(|fd| { + let mut fd = fd.clone(); + fd.rename(function(fd.name())); + fd + }) + .collect::>(); + let mut out = StructChunked::from_series(s.name().clone(), s.len(), fields.iter())?; + out.zip_outer_validity(s); + Ok(Some(out.into_column())) + }, + GetOutput::map_dtype(move |dt| match dt { + DataType::Struct(fds) => { + let fields = fds + .iter() + .map(|fd| Field::new(f(fd.name()), fd.dtype().clone())) + .collect(); + Ok(DataType::Struct(fields)) + }, + _ => panic!("Only struct dtype is supported for `map_fields`."), + }), + ) + } + + #[cfg(feature = "dtype-struct")] + pub fn prefix_fields(self, prefix: &str) -> Expr { + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::PrefixFields( + PlSmallStr::from_str(prefix), + ))) + } + + #[cfg(feature = "dtype-struct")] + pub fn suffix_fields(self, suffix: &str) -> Expr { + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::SuffixFields( + PlSmallStr::from_str(suffix), + ))) + } +} + +#[cfg(feature = "dtype-struct")] +pub type FieldsNameMapper = Arc PlSmallStr + Send + Sync>; diff --git a/crates/polars-plan/src/dsl/options/mod.rs b/crates/polars-plan/src/dsl/options/mod.rs new file mode 100644 index 000000000000..c0482192d564 --- /dev/null +++ b/crates/polars-plan/src/dsl/options/mod.rs @@ -0,0 +1,399 @@ +use std::hash::Hash; +#[cfg(feature = "json")] +use std::num::NonZeroUsize; +use std::str::FromStr; +use std::sync::Arc; + +mod sink; + +use polars_core::error::PolarsResult; +use polars_core::prelude::*; +#[cfg(feature = "csv")] +use polars_io::csv::write::CsvWriterOptions; +#[cfg(feature = "ipc")] +use polars_io::ipc::IpcWriterOptions; +#[cfg(feature = "json")] +use polars_io::json::JsonWriterOptions; +#[cfg(feature = "parquet")] +use polars_io::parquet::write::ParquetWriteOptions; +#[cfg(feature = "iejoin")] +use polars_ops::frame::IEJoinOptions; +use polars_ops::frame::{CrossJoinFilter, CrossJoinOptions, JoinTypeOptions}; +use polars_ops::prelude::{JoinArgs, JoinType}; +#[cfg(feature = "dynamic_group_by")] +use polars_time::DynamicGroupOptions; +#[cfg(feature = "dynamic_group_by")] +use polars_time::RollingGroupOptions; +use polars_utils::IdxSize; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +pub use sink::*; +use strum_macros::IntoStaticStr; + +use super::ExprIR; +use crate::dsl::Selector; + +#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct RollingCovOptions { + pub window_size: IdxSize, + pub min_periods: IdxSize, + pub ddof: u8, +} + +#[derive(Clone, PartialEq, Debug, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct StrptimeOptions { + /// Formatting string + pub format: Option, + /// If set then polars will return an error if any date parsing fails + pub strict: bool, + /// If polars may parse matches that not contain the whole string + /// e.g. "foo-2021-01-01-bar" could match "2021-01-01" + pub exact: bool, + /// use a cache of unique, converted dates to apply the datetime conversion. + pub cache: bool, +} + +impl Default for StrptimeOptions { + fn default() -> Self { + StrptimeOptions { + format: None, + strict: true, + exact: true, + cache: true, + } + } +} + +#[derive(Clone, PartialEq, Eq, IntoStaticStr, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum JoinTypeOptionsIR { + #[cfg(feature = "iejoin")] + IEJoin(IEJoinOptions), + #[cfg_attr(all(feature = "serde", not(feature = "ir_serde")), serde(skip))] + // Fused cross join and filter (only in in-memory engine) + Cross { predicate: ExprIR }, +} + +impl Hash for JoinTypeOptionsIR { + fn hash(&self, state: &mut H) { + use JoinTypeOptionsIR::*; + match self { + #[cfg(feature = "iejoin")] + IEJoin(opt) => opt.hash(state), + Cross { predicate } => predicate.node().hash(state), + } + } +} + +impl JoinTypeOptionsIR { + pub fn compile PolarsResult>>( + self, + plan: C, + ) -> PolarsResult { + use JoinTypeOptionsIR::*; + match self { + Cross { predicate } => { + let predicate = plan(&predicate)?; + + Ok(JoinTypeOptions::Cross(CrossJoinOptions { predicate })) + }, + #[cfg(feature = "iejoin")] + IEJoin(opt) => Ok(JoinTypeOptions::IEJoin(opt)), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct JoinOptions { + pub allow_parallel: bool, + pub force_parallel: bool, + pub args: JoinArgs, + pub options: Option, + /// Proxy of the number of rows in both sides of the joins + /// Holds `(Option, estimated_size)` + pub rows_left: (Option, usize), + pub rows_right: (Option, usize), +} + +impl Default for JoinOptions { + fn default() -> Self { + JoinOptions { + allow_parallel: true, + force_parallel: false, + // Todo!: make default + args: JoinArgs::new(JoinType::Left), + options: Default::default(), + rows_left: (None, usize::MAX), + rows_right: (None, usize::MAX), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum WindowType { + /// Explode the aggregated list and just do a hstack instead of a join + /// this requires the groups to be sorted to make any sense + Over(WindowMapping), + #[cfg(feature = "dynamic_group_by")] + Rolling(RollingGroupOptions), +} + +impl From for WindowType { + fn from(value: WindowMapping) -> Self { + Self::Over(value) + } +} + +impl Default for WindowType { + fn default() -> Self { + Self::Over(WindowMapping::default()) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum WindowMapping { + /// Map the group values to the position + #[default] + GroupsToRows, + /// Explode the aggregated list and just do a hstack instead of a join + /// this requires the groups to be sorted to make any sense + Explode, + /// Join the groups as 'List' to the row positions. + /// warning: this can be memory intensive + Join, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum NestedType { + #[cfg(feature = "dtype-array")] + Array, + // List, +} + +#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnpivotArgsDSL { + pub on: Vec, + pub index: Vec, + pub variable_name: Option, + pub value_name: Option, +} + +#[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Engine { + Auto, + OldStreaming, + Streaming, + InMemory, + Gpu, +} + +impl FromStr for Engine { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + // "cpu" for backwards compatibility + "auto" => Ok(Engine::Auto), + "cpu" | "in-memory" => Ok(Engine::InMemory), + "streaming" => Ok(Engine::Streaming), + "old-streaming" => Ok(Engine::OldStreaming), + "gpu" => Ok(Engine::Gpu), + v => Err(format!( + "`engine` must be one of {{'auto', 'in-memory', 'streaming', 'old-streaming', 'gpu'}}, got {v}", + )), + } + } +} + +impl Engine { + pub fn into_static_str(self) -> &'static str { + match self { + Self::Auto => "auto", + Self::OldStreaming => "old-streaming", + Self::Streaming => "streaming", + Self::InMemory => "in-memory", + Self::Gpu => "gpu", + } + } +} + +#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnionOptions { + pub slice: Option<(i64, usize)>, + // known row_output, estimated row output + pub rows: (Option, usize), + pub parallel: bool, + pub from_partitioned_ds: bool, + pub flattened_by_opt: bool, + pub rechunk: bool, + pub maintain_order: bool, +} + +#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct HConcatOptions { + pub parallel: bool, +} + +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct GroupbyOptions { + #[cfg(feature = "dynamic_group_by")] + pub dynamic: Option, + #[cfg(feature = "dynamic_group_by")] + pub rolling: Option, + /// Take only a slice of the result + pub slice: Option<(i64, usize)>, +} + +impl GroupbyOptions { + pub(crate) fn is_rolling(&self) -> bool { + #[cfg(feature = "dynamic_group_by")] + { + self.rolling.is_some() + } + #[cfg(not(feature = "dynamic_group_by"))] + { + false + } + } + + pub(crate) fn is_dynamic(&self) -> bool { + #[cfg(feature = "dynamic_group_by")] + { + self.dynamic.is_some() + } + #[cfg(not(feature = "dynamic_group_by"))] + { + false + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DistinctOptionsDSL { + /// Subset of columns that will be taken into account. + pub subset: Option>, + /// This will maintain the order of the input. + /// Note that this is more expensive. + /// `maintain_order` is not supported in the streaming + /// engine. + pub maintain_order: bool, + /// Which rows to keep. + pub keep_strategy: UniqueKeepStrategy, +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub struct LogicalPlanUdfOptions { + /// allow predicate pushdown optimizations + pub predicate_pd: bool, + /// allow projection pushdown optimizations + pub projection_pd: bool, + // used for formatting + pub fmt_str: &'static str, +} + +#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct AnonymousScanOptions { + pub skip_rows: Option, + pub fmt_str: &'static str, +} + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum FileType { + #[cfg(feature = "parquet")] + Parquet(ParquetWriteOptions), + #[cfg(feature = "ipc")] + Ipc(IpcWriterOptions), + #[cfg(feature = "csv")] + Csv(CsvWriterOptions), + #[cfg(feature = "json")] + Json(JsonWriterOptions), +} + +impl FileType { + pub fn extension(&self) -> &'static str { + match self { + #[cfg(feature = "parquet")] + Self::Parquet(_) => "parquet", + #[cfg(feature = "ipc")] + Self::Ipc(_) => "ipc", + #[cfg(feature = "csv")] + Self::Csv(_) => "csv", + #[cfg(feature = "json")] + Self::Json(_) => "jsonl", + + #[allow(unreachable_patterns)] + _ => unreachable!("enable file type features"), + } + } +} + +// +// Arguments given to `concat`. Differs from `UnionOptions` as the latter is IR state. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct UnionArgs { + pub parallel: bool, + pub rechunk: bool, + pub to_supertypes: bool, + pub diagonal: bool, + // If it is a union from a scan over multiple files. + pub from_partitioned_ds: bool, + pub maintain_order: bool, +} + +impl Default for UnionArgs { + fn default() -> Self { + Self { + parallel: true, + rechunk: false, + to_supertypes: false, + diagonal: false, + from_partitioned_ds: false, + maintain_order: true, + } + } +} + +impl From for UnionOptions { + fn from(args: UnionArgs) -> Self { + UnionOptions { + slice: None, + parallel: args.parallel, + rows: (None, 0), + from_partitioned_ds: args.from_partitioned_ds, + flattened_by_opt: false, + rechunk: args.rechunk, + maintain_order: args.maintain_order, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg(feature = "json")] +pub struct NDJsonReadOptions { + pub n_threads: Option, + pub infer_schema_length: Option, + pub chunk_size: NonZeroUsize, + pub low_memory: bool, + pub ignore_errors: bool, + pub schema: Option, + pub schema_overwrite: Option, +} diff --git a/crates/polars-plan/src/dsl/options/sink.rs b/crates/polars-plan/src/dsl/options/sink.rs new file mode 100644 index 000000000000..a05625359f27 --- /dev/null +++ b/crates/polars-plan/src/dsl/options/sink.rs @@ -0,0 +1,413 @@ +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::path::PathBuf; +use std::sync::Arc; + +use polars_core::error::{PolarsResult, to_compute_err}; +use polars_core::prelude::DataType; +use polars_core::scalar::Scalar; +use polars_io::cloud::CloudOptions; +use polars_io::utils::file::{DynWriteable, Writeable}; +use polars_io::utils::sync_on_close::SyncOnCloseType; +use polars_utils::IdxSize; +use polars_utils::arena::Arena; +use polars_utils::pl_str::PlSmallStr; + +use super::{ExprIR, FileType}; +use crate::dsl::{AExpr, Expr, SpecialEq}; + +/// Options that apply to all sinks. +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct SinkOptions { + /// Call sync when closing the file. + pub sync_on_close: SyncOnCloseType, + + /// The output file needs to maintain order of the data that comes in. + pub maintain_order: bool, + + /// Recursively create all the directories in the path. + pub mkdir: bool, +} + +impl Default for SinkOptions { + fn default() -> Self { + Self { + sync_on_close: Default::default(), + maintain_order: true, + mkdir: false, + } + } +} + +type DynSinkTarget = SpecialEq>>>>; + +#[derive(Clone, PartialEq, Eq)] +pub enum SinkTarget { + Path(Arc), + Dyn(DynSinkTarget), +} + +impl SinkTarget { + pub fn open_into_writeable( + &self, + sink_options: &SinkOptions, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + match self { + SinkTarget::Path(path) => { + if sink_options.mkdir { + polars_io::utils::mkdir::mkdir_recursive(path.as_path())?; + } + + let path = path.as_ref().display().to_string(); + polars_io::utils::file::Writeable::try_new(&path, cloud_options) + }, + SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn( + memory_writer.lock().unwrap().take().unwrap(), + )), + } + } + + #[cfg(feature = "cloud")] + pub async fn open_into_writeable_async( + &self, + sink_options: &SinkOptions, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + match self { + SinkTarget::Path(path) => { + if sink_options.mkdir { + polars_io::utils::mkdir::tokio_mkdir_recursive(path.as_path()).await?; + } + + let path = path.as_ref().display().to_string(); + polars_io::utils::file::Writeable::try_new(&path, cloud_options) + }, + SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn( + memory_writer.lock().unwrap().take().unwrap(), + )), + } + } +} + +impl fmt::Debug for SinkTarget { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("SinkTarget::")?; + match self { + Self::Path(p) => write!(f, "Path({p:?})"), + Self::Dyn(_) => f.write_str("Dyn"), + } + } +} + +impl std::hash::Hash for SinkTarget { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Path(p) => p.hash(state), + Self::Dyn(p) => Arc::as_ptr(p).hash(state), + } + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for SinkTarget { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Path(p) => p.serialize(serializer), + Self::Dyn(_) => Err(serde::ser::Error::custom( + "cannot serialize in-memory sink target", + )), + } + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for SinkTarget { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Self::Path(Arc::new(PathBuf::deserialize(deserializer)?))) + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct FileSinkType { + pub target: SinkTarget, + pub file_type: FileType, + pub sink_options: SinkOptions, + pub cloud_options: Option, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq)] +pub enum SinkTypeIR { + Memory, + File(FileSinkType), + Partition(PartitionSinkTypeIR), +} + +#[cfg_attr(feature = "python", pyo3::pyclass)] +#[derive(Clone)] +pub struct PartitionTargetContextKey { + pub name: PlSmallStr, + pub raw_value: Scalar, +} + +#[cfg_attr(feature = "python", pyo3::pyclass)] +pub struct PartitionTargetContext { + pub file_idx: usize, + pub part_idx: usize, + pub in_part_idx: usize, + pub keys: Vec, + pub file_path: PathBuf, + pub full_path: PathBuf, +} + +#[cfg(feature = "python")] +#[pyo3::pymethods] +impl PartitionTargetContext { + #[getter] + pub fn file_idx(&self) -> usize { + self.file_idx + } + #[getter] + pub fn part_idx(&self) -> usize { + self.part_idx + } + #[getter] + pub fn in_part_idx(&self) -> usize { + self.in_part_idx + } + #[getter] + pub fn keys(&self) -> Vec { + self.keys.clone() + } + #[getter] + pub fn file_path(&self) -> &std::path::Path { + self.file_path.as_path() + } + #[getter] + pub fn full_path(&self) -> &std::path::Path { + self.full_path.as_path() + } +} +#[cfg(feature = "python")] +#[pyo3::pymethods] +impl PartitionTargetContextKey { + #[getter] + pub fn name(&self) -> &str { + self.name.as_str() + } + #[getter] + pub fn str_value(&self) -> pyo3::PyResult { + let value = self + .raw_value + .clone() + .into_series(PlSmallStr::EMPTY) + .strict_cast(&DataType::String) + .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))?; + let value = value.str().unwrap(); + let value = value.get(0).unwrap_or("null").as_bytes(); + let value = percent_encoding::percent_encode(value, polars_io::utils::URL_ENCODE_CHAR_SET); + Ok(value.to_string()) + } + #[getter] + pub fn raw_value(&self) -> pyo3::PyObject { + let converter = polars_core::chunked_array::object::registry::get_pyobject_converter(); + *(converter.as_ref())(self.raw_value.as_any_value()) + .downcast::() + .unwrap() + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum PartitionTargetCallback { + Rust(SpecialEq PolarsResult + Send + Sync>>), + #[cfg(feature = "python")] + Python(polars_utils::python_function::PythonFunction), +} + +impl PartitionTargetCallback { + pub fn call(&self, ctx: PartitionTargetContext) -> PolarsResult { + match self { + Self::Rust(f) => f(ctx), + #[cfg(feature = "python")] + Self::Python(f) => pyo3::Python::with_gil(|py| { + let sink_target = f.call1(py, (ctx,)).map_err(to_compute_err)?; + let converter = + polars_utils::python_convert_registry::get_python_convert_registry(); + let sink_target = + (converter.from_py.sink_target)(sink_target).map_err(to_compute_err)?; + let sink_target = sink_target.downcast_ref::().unwrap().clone(); + PolarsResult::Ok(sink_target) + }), + } + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for PartitionTargetCallback { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cfg(feature = "python")] + { + Ok(Self::Python( + polars_utils::python_function::PythonFunction::deserialize(_deserializer)?, + )) + } + #[cfg(not(feature = "python"))] + { + use serde::de::Error; + Err(D::Error::custom( + "cannot deserialize PartitionOutputCallback", + )) + } + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PartitionTargetCallback { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + + #[cfg(feature = "python")] + if let Self::Python(v) = self { + return v.serialize(_serializer); + } + + Err(S::Error::custom(format!("cannot serialize {:?}", self))) + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq)] +pub struct PartitionSinkType { + pub base_path: Arc, + pub file_path_cb: Option, + pub file_type: FileType, + pub sink_options: SinkOptions, + pub variant: PartitionVariant, + pub cloud_options: Option, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq)] +pub struct PartitionSinkTypeIR { + pub base_path: Arc, + pub file_path_cb: Option, + pub file_type: FileType, + pub sink_options: SinkOptions, + pub variant: PartitionVariantIR, + pub cloud_options: Option, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq)] +pub enum SinkType { + Memory, + File(FileSinkType), + Partition(PartitionSinkType), +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum PartitionVariant { + MaxSize(IdxSize), + Parted { + key_exprs: Vec, + include_key: bool, + }, + ByKey { + key_exprs: Vec, + include_key: bool, + }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum PartitionVariantIR { + MaxSize(IdxSize), + Parted { + key_exprs: Vec, + include_key: bool, + }, + ByKey { + key_exprs: Vec, + include_key: bool, + }, +} + +impl SinkTypeIR { + #[cfg(feature = "cse")] + pub(crate) fn traverse_and_hash(&self, expr_arena: &Arena, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Memory => {}, + Self::File(f) => f.hash(state), + Self::Partition(f) => { + f.file_type.hash(state); + f.sink_options.hash(state); + f.variant.traverse_and_hash(expr_arena, state); + f.cloud_options.hash(state); + }, + } + } +} + +impl PartitionVariantIR { + #[cfg(feature = "cse")] + pub(crate) fn traverse_and_hash(&self, expr_arena: &Arena, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::MaxSize(size) => size.hash(state), + Self::Parted { + key_exprs, + include_key, + } + | Self::ByKey { + key_exprs, + include_key, + } => { + include_key.hash(state); + for key_expr in key_exprs.as_slice() { + key_expr.traverse_and_hash(expr_arena, state); + } + }, + } + } +} + +impl SinkType { + pub(crate) fn is_cloud_destination(&self) -> bool { + match self { + Self::Memory => false, + Self::File(f) => { + let SinkTarget::Path(p) = &f.target else { + return false; + }; + + polars_io::is_cloud_url(p.as_path()) + }, + Self::Partition(f) => polars_io::is_cloud_url(f.base_path.as_path()), + } + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone, Debug)] +pub struct FileSinkOptions { + pub path: Arc, + pub file_type: FileType, +} diff --git a/crates/polars-plan/src/dsl/plan.rs b/crates/polars-plan/src/dsl/plan.rs new file mode 100644 index 000000000000..9e5d50ce1b2f --- /dev/null +++ b/crates/polars-plan/src/dsl/plan.rs @@ -0,0 +1,260 @@ +use std::fmt; +use std::io::{Read, Write}; +use std::sync::{Arc, Mutex}; + +use polars_utils::arena::Node; +#[cfg(feature = "serde")] +use polars_utils::pl_serialize; +use recursive::recursive; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; +// (Major, Minor) +// Add a field -> increment minor +// Remove or modify a field -> increment major and reset minor +pub static DSL_VERSION: (u16, u16) = (1, 0); +static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION"; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum DslPlan { + #[cfg(feature = "python")] + PythonScan { + options: crate::dsl::python_dsl::PythonOptionsDsl, + }, + /// Filter on a boolean mask + Filter { + input: Arc, + predicate: Expr, + }, + /// Cache the input at this point in the LP + Cache { + input: Arc, + id: usize, + }, + Scan { + sources: ScanSources, + /// Materialized at IR except for AnonymousScan. + file_info: Option, + unified_scan_args: Box, + scan_type: Box, + /// Local use cases often repeatedly collect the same `LazyFrame` (e.g. in interactive notebook use-cases), + /// so we cache the IR conversion here, as the path expansion can be quite slow (especially for cloud paths). + /// We don't have the arena, as this is always a source node. + #[cfg_attr(feature = "serde", serde(skip))] + cached_ir: Arc>>, + }, + // we keep track of the projection and selection as it is cheaper to first project and then filter + /// In memory DataFrame + DataFrameScan { + df: Arc, + schema: SchemaRef, + }, + /// Polars' `select` operation, this can mean projection, but also full data access. + Select { + expr: Vec, + input: Arc, + options: ProjectionOptions, + }, + /// Groupby aggregation + GroupBy { + input: Arc, + keys: Vec, + aggs: Vec, + maintain_order: bool, + options: Arc, + #[cfg_attr(feature = "serde", serde(skip))] + apply: Option<(Arc, SchemaRef)>, + }, + /// Join operation + Join { + input_left: Arc, + input_right: Arc, + // Invariant: left_on and right_on are equal length. + left_on: Vec, + right_on: Vec, + // Invariant: Either left_on/right_on or predicates is set (non-empty). + predicates: Vec, + options: Arc, + }, + /// Adding columns to the table without a Join + HStack { + input: Arc, + exprs: Vec, + options: ProjectionOptions, + }, + /// Remove duplicates from the table + Distinct { + input: Arc, + options: DistinctOptionsDSL, + }, + /// Sort the table + Sort { + input: Arc, + by_column: Vec, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + }, + /// Slice the table + Slice { + input: Arc, + offset: i64, + len: IdxSize, + }, + /// A (User Defined) Function + MapFunction { + input: Arc, + function: DslFunction, + }, + /// Vertical concatenation + Union { + inputs: Vec, + args: UnionArgs, + }, + /// Horizontal concatenation of multiple plans + HConcat { + inputs: Vec, + options: HConcatOptions, + }, + /// This allows expressions to access other tables + ExtContext { + input: Arc, + contexts: Vec, + }, + Sink { + input: Arc, + payload: SinkType, + }, + SinkMultiple { + inputs: Vec, + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left: Arc, + input_right: Arc, + key: PlSmallStr, + }, + IR { + // Keep the original Dsl around as we need that for serialization. + dsl: Arc, + version: u32, + #[cfg_attr(feature = "serde", serde(skip))] + node: Option, + }, +} + +impl Clone for DslPlan { + // Autogenerated by rust-analyzer, don't care about it looking nice, it just + // calls clone on every member of every enum variant. + #[rustfmt::skip] + #[allow(clippy::clone_on_copy)] + #[recursive] + fn clone(&self) -> Self { + match self { + #[cfg(feature = "python")] + Self::PythonScan { options } => Self::PythonScan { options: options.clone() }, + Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() }, + Self::Cache { input, id } => Self::Cache { input: input.clone(), id: id.clone() }, + Self::Scan { sources, file_info, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), file_info: file_info.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() }, + Self::DataFrameScan { df, schema, } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), }, + Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() }, + Self::GroupBy { input, keys, aggs, apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() }, + Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() }, + Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), options: options.clone() }, + Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() }, + Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() }, + Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() }, + Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() }, + Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() }, + Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() }, + Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() }, + Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() }, + Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() }, + #[cfg(feature = "merge_sorted")] + Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() }, + Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version}, + } + } +} + +impl Default for DslPlan { + fn default() -> Self { + let df = DataFrame::empty(); + let schema = df.schema().clone(); + DslPlan::DataFrameScan { + df: Arc::new(df), + schema, + } + } +} + +impl DslPlan { + pub fn describe(&self) -> PolarsResult { + Ok(self.clone().to_alp()?.describe()) + } + + pub fn describe_tree_format(&self) -> PolarsResult { + Ok(self.clone().to_alp()?.describe_tree_format()) + } + + pub fn display(&self) -> PolarsResult { + struct DslPlanDisplay(IRPlan); + impl fmt::Display for DslPlanDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0.as_ref().display(), f) + } + } + Ok(DslPlanDisplay(self.clone().to_alp()?)) + } + + pub fn to_alp(self) -> PolarsResult { + let mut lp_arena = Arena::with_capacity(16); + let mut expr_arena = Arena::with_capacity(16); + + let node = to_alp( + self, + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::default(), + )?; + let plan = IRPlan::new(node, lp_arena, expr_arena); + + Ok(plan) + } + + #[cfg(feature = "serde")] + pub fn serialize_versioned(&self, mut writer: W) -> PolarsResult<()> { + let le_major = DSL_VERSION.0.to_le_bytes(); + let le_minor = DSL_VERSION.1.to_le_bytes(); + writer.write_all(DSL_MAGIC_BYTES)?; + writer.write_all(&le_major)?; + writer.write_all(&le_minor)?; + pl_serialize::SerializeOptions::default().serialize_into_writer::<_, _, true>(writer, self) + } + + #[cfg(feature = "serde")] + pub fn deserialize_versioned(mut reader: R) -> PolarsResult { + const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len(); + let mut version_magic = [0u8; MAGIC_LEN + 4]; + reader.read_exact(&mut version_magic)?; + + if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES { + polars_bail!(ComputeError: "dsl magic bytes not found") + } + + // The DSL serialization is forward compatible if fields don't change, + // so we don't check equality here, we just use this version + // to inform users when the deserialization fails. + let major = u16::from_be_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap()); + let minor = u16::from_be_bytes( + version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4] + .try_into() + .unwrap(), + ); + + pl_serialize::SerializeOptions::default() + .deserialize_from_reader::<_, _, true>(reader).map_err(|e| { + polars_err!(ComputeError: "deserialization failed\n\ngiven DSL_VERSION: {:?} is not compatible with this Polars version which uses DSL_VERSION: {:?}\nerror: {}", (major, minor), DSL_VERSION, e) + }) + } +} diff --git a/crates/polars-plan/src/dsl/python_dsl/mod.rs b/crates/polars-plan/src/dsl/python_dsl/mod.rs new file mode 100644 index 000000000000..a6d84eb1a7bd --- /dev/null +++ b/crates/polars-plan/src/dsl/python_dsl/mod.rs @@ -0,0 +1,5 @@ +mod python_udf; +mod source; + +pub use python_udf::*; +pub use source::*; diff --git a/crates/polars-plan/src/dsl/python_dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_dsl/python_udf.rs new file mode 100644 index 000000000000..8d0ceae3921a --- /dev/null +++ b/crates/polars-plan/src/dsl/python_dsl/python_udf.rs @@ -0,0 +1,267 @@ +use std::io::Cursor; +use std::sync::Arc; + +use polars_core::datatypes::{DataType, Field}; +use polars_core::error::*; +use polars_core::frame::DataFrame; +use polars_core::frame::column::Column; +use polars_core::schema::Schema; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedBytes; +use pyo3::types::PyBytes; + +use crate::constants::MAP_LIST_NAME; +use crate::prelude::*; + +// Will be overwritten on Python Polars start up. +pub static mut CALL_COLUMNS_UDF_PYTHON: Option< + fn(s: Column, lambda: &PyObject) -> PolarsResult, +> = None; +pub static mut CALL_DF_UDF_PYTHON: Option< + fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, +> = None; + +pub use polars_utils::python_function::{ + PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION, PythonFunction, +}; + +pub struct PythonUdfExpression { + python_function: PyObject, + output_type: Option, + is_elementwise: bool, + returns_scalar: bool, +} + +impl PythonUdfExpression { + pub fn new( + lambda: PyObject, + output_type: Option, + is_elementwise: bool, + returns_scalar: bool, + ) -> Self { + Self { + python_function: lambda, + output_type, + is_elementwise, + returns_scalar, + } + } + + #[cfg(feature = "serde")] + pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { + // Handle byte mark + + use polars_utils::pl_serialize; + debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); + let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; + + // Handle pickle metadata + let use_cloudpickle = buf[0]; + if use_cloudpickle != 0 { + let ser_py_version = &buf[1..3]; + let cur_py_version = *PYTHON3_VERSION; + polars_ensure!( + ser_py_version == cur_py_version, + InvalidOperation: + "current Python version {:?} does not match the Python version used to serialize the UDF {:?}", + (3, cur_py_version[0], cur_py_version[1]), + (3, ser_py_version[0], ser_py_version[1] ) + ); + } + let buf = &buf[3..]; + + // Load UDF metadata + let mut reader = Cursor::new(buf); + let (output_type, is_elementwise, returns_scalar): (Option, bool, bool) = + pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?; + + let remainder = &buf[reader.position() as usize..]; + + // Load UDF + Python::with_gil(|py| { + let pickle = PyModule::import(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("loads") + .unwrap(); + let arg = (PyBytes::new(py, remainder),); + let python_function = pickle.call1(arg).map_err(from_pyerr)?; + Ok(Arc::new(Self::new( + python_function.into(), + output_type, + is_elementwise, + returns_scalar, + )) as Arc) + }) + } +} + +fn from_pyerr(e: PyErr) -> PolarsError { + PolarsError::ComputeError(format!("error raised in python: {e}").into()) +} + +impl DataFrameUdf for polars_utils::python_function::PythonFunction { + fn call_udf(&self, df: DataFrame) -> PolarsResult { + let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() }; + func(df, &self.0) + } +} + +impl ColumnsUdf for PythonUdfExpression { + fn call_udf(&self, s: &mut [Column]) -> PolarsResult> { + let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() }; + + let output_type = self + .output_type + .clone() + .unwrap_or_else(|| DataType::Unknown(Default::default())); + let mut out = func(s[0].clone(), &self.python_function)?; + if !matches!(output_type, DataType::Unknown(_)) { + let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| { + polars_err!( + SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype", + output_type, out.dtype(), + ) + })?; + if must_cast { + out = out.cast(&output_type)?; + } + } + + Ok(Some(out)) + } + + #[cfg(feature = "serde")] + fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { + // Write byte marks + + use polars_utils::pl_serialize; + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); + + Python::with_gil(|py| { + // Try pickle to serialize the UDF, otherwise fall back to cloudpickle. + let pickle = PyModule::import(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("dumps") + .unwrap(); + let pickle_result = pickle.call1((self.python_function.clone_ref(py),)); + let (dumped, use_cloudpickle) = match pickle_result { + Ok(dumped) => (dumped, false), + Err(_) => { + let cloudpickle = PyModule::import(py, "cloudpickle") + .map_err(from_pyerr)? + .getattr("dumps") + .unwrap(); + let dumped = cloudpickle + .call1((self.python_function.clone_ref(py),)) + .map_err(from_pyerr)?; + (dumped, true) + }, + }; + + // Write pickle metadata + buf.push(use_cloudpickle as u8); + buf.extend_from_slice(&*PYTHON3_VERSION); + + // Write UDF metadata + pl_serialize::serialize_into_writer::<_, _, true>( + &mut *buf, + &( + self.output_type.clone(), + self.is_elementwise, + self.returns_scalar, + ), + )?; + + // Write UDF + let dumped = dumped.extract::().unwrap(); + buf.extend_from_slice(&dumped); + Ok(()) + }) + } +} + +/// Serializable version of [`GetOutput`] for Python UDFs. +pub struct PythonGetOutput { + return_dtype: Option, +} + +impl PythonGetOutput { + pub fn new(return_dtype: Option) -> Self { + Self { return_dtype } + } + + #[cfg(feature = "serde")] + pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { + // Skip header. + + use polars_utils::pl_serialize; + debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); + let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; + + let mut reader = Cursor::new(buf); + let return_dtype: Option = + pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?; + + Ok(Arc::new(Self::new(return_dtype)) as Arc) + } +} + +impl FunctionOutputField for PythonGetOutput { + fn get_field( + &self, + _input_schema: &Schema, + _cntxt: Context, + fields: &[Field], + ) -> PolarsResult { + // Take the name of first field, just like [`GetOutput::map_field`]. + let name = fields[0].name(); + let return_dtype = match self.return_dtype { + Some(ref dtype) => dtype.clone(), + None => DataType::Unknown(Default::default()), + }; + Ok(Field::new(name.clone(), return_dtype)) + } + + #[cfg(feature = "serde")] + fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { + use polars_utils::pl_serialize; + + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); + pl_serialize::serialize_into_writer::<_, _, true>(&mut *buf, &self.return_dtype) + } +} + +impl Expr { + pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr { + let (collect_groups, name) = if agg_list { + (ApplyOptions::ApplyList, MAP_LIST_NAME) + } else if func.is_elementwise { + (ApplyOptions::ElementWise, "python_udf") + } else { + (ApplyOptions::GroupWise, "python_udf") + }; + + let returns_scalar = func.returns_scalar; + let return_dtype = func.output_type.clone(); + + let output_field = PythonGetOutput::new(return_dtype); + let output_type = SpecialEq::new(Arc::new(output_field) as Arc); + + let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT; + if returns_scalar { + flags |= FunctionFlags::RETURNS_SCALAR; + } + + Expr::AnonymousFunction { + input: vec![self], + function: new_column_udf(func), + output_type, + options: FunctionOptions { + collect_groups, + fmt_str: name, + flags, + ..Default::default() + }, + } + } +} diff --git a/crates/polars-plan/src/dsl/python_dsl/source.rs b/crates/polars-plan/src/dsl/python_dsl/source.rs new file mode 100644 index 000000000000..4a38efbcdb16 --- /dev/null +++ b/crates/polars-plan/src/dsl/python_dsl/source.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use either::Either; +use polars_core::error::{PolarsResult, polars_err}; +use polars_core::schema::SchemaRef; +use polars_utils::python_function::PythonFunction; +use pyo3::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::dsl::SpecialEq; + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct PythonOptionsDsl { + /// A function that returns a Python Generator. + /// The generator should produce Polars DataFrame's. + pub scan_fn: Option, + /// Either the schema fn or schema is set. + pub schema_fn: Option>>>, + pub python_source: PythonScanSource, + pub validate_schema: bool, +} + +impl PythonOptionsDsl { + pub(crate) fn get_schema(&self) -> PolarsResult { + match self.schema_fn.as_ref().expect("should be set").as_ref() { + Either::Left(func) => Python::with_gil(|py| { + let schema = func + .0 + .call0(py) + .map_err(|e| polars_err!(ComputeError: "schema callable failed: {}", e))?; + crate::plans::python::python_schema_to_rust(py, schema.into_bound(py)) + }), + Either::Right(schema) => Ok(schema.clone()), + } + } +} + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum PythonScanSource { + Pyarrow, + Cuda, + #[default] + IOPlugin, +} diff --git a/crates/polars-plan/src/dsl/random.rs b/crates/polars-plan/src/dsl/random.rs new file mode 100644 index 000000000000..2cba43e23021 --- /dev/null +++ b/crates/polars-plan/src/dsl/random.rs @@ -0,0 +1,50 @@ +use super::*; + +impl Expr { + pub fn shuffle(self, seed: Option) -> Self { + self.map_unary(FunctionExpr::Random { + method: RandomMethod::Shuffle, + seed, + }) + } + + pub fn sample_n( + self, + n: Expr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Self { + self.map_binary( + FunctionExpr::Random { + method: RandomMethod::Sample { + is_fraction: false, + with_replacement, + shuffle, + }, + seed, + }, + n, + ) + } + + pub fn sample_frac( + self, + frac: Expr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Self { + self.map_binary( + FunctionExpr::Random { + method: RandomMethod::Sample { + is_fraction: true, + with_replacement, + shuffle, + }, + seed, + }, + frac, + ) + } +} diff --git a/crates/polars-plan/src/dsl/scan_sources.rs b/crates/polars-plan/src/dsl/scan_sources.rs new file mode 100644 index 000000000000..26b9d6f25118 --- /dev/null +++ b/crates/polars-plan/src/dsl/scan_sources.rs @@ -0,0 +1,459 @@ +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use polars_core::error::{PolarsResult, feature_gated}; +use polars_io::cloud::CloudOptions; +#[cfg(feature = "cloud")] +use polars_io::file_cache::FileCacheEntry; +#[cfg(feature = "cloud")] +use polars_io::utils::byte_source::{DynByteSource, DynByteSourceBuilder}; +use polars_io::{expand_paths, expand_paths_hive, expanded_from_single_directory}; +use polars_utils::mmap::MemSlice; +use polars_utils::pl_str::PlSmallStr; + +use super::UnifiedScanArgs; + +/// Set of sources to scan from +/// +/// This can either be a list of paths to files, opened files or in-memory buffers. Mixing of +/// buffers is not currently possible. +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Clone)] +pub enum ScanSources { + Paths(Arc<[PathBuf]>), + + #[cfg_attr(feature = "serde", serde(skip))] + Files(Arc<[File]>), + #[cfg_attr(feature = "serde", serde(skip))] + Buffers(Arc<[MemSlice]>), +} + +impl Debug for ScanSources { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Paths(p) => write!(f, "paths: {:?}", p.as_ref()), + Self::Files(p) => write!(f, "files: {} files", p.len()), + Self::Buffers(b) => write!(f, "buffers: {} in-memory-buffers", b.len()), + } + } +} + +/// A reference to a single item in [`ScanSources`] +#[derive(Debug, Clone, Copy)] +pub enum ScanSourceRef<'a> { + Path(&'a Path), + File(&'a File), + Buffer(&'a MemSlice), +} + +/// A single source to scan from +#[derive(Debug, Clone)] +pub enum ScanSource { + Path(Arc), + File(Arc), + Buffer(MemSlice), +} + +impl ScanSource { + pub fn from_sources(sources: ScanSources) -> Result { + if sources.len() == 1 { + match sources { + ScanSources::Paths(ps) => Ok(Self::Path(ps.as_ref()[0].clone().into())), + ScanSources::Files(fs) => { + assert_eq!(fs.len(), 1); + let ptr: *const File = Arc::into_raw(fs) as *const File; + // SAFETY: A [T] with length 1 can be interpreted as T + let f: Arc = unsafe { Arc::from_raw(ptr) }; + + Ok(Self::File(f)) + }, + ScanSources::Buffers(bs) => Ok(Self::Buffer(bs.as_ref()[0].clone())), + } + } else { + Err(sources) + } + } + + pub fn into_sources(self) -> ScanSources { + match self { + ScanSource::Path(p) => ScanSources::Paths([p.to_path_buf()].into()), + ScanSource::File(f) => { + let ptr: *const [File] = std::ptr::slice_from_raw_parts(Arc::into_raw(f), 1); + // SAFETY: A T can be interpreted as [T] with length 1. + let fs: Arc<[File]> = unsafe { Arc::from_raw(ptr) }; + ScanSources::Files(fs) + }, + ScanSource::Buffer(m) => ScanSources::Buffers([m].into()), + } + } + + pub fn as_scan_source_ref(&self) -> ScanSourceRef { + match self { + ScanSource::Path(path) => ScanSourceRef::Path(path.as_ref()), + ScanSource::File(file) => ScanSourceRef::File(file.as_ref()), + ScanSource::Buffer(mem_slice) => ScanSourceRef::Buffer(mem_slice), + } + } + + pub fn run_async(&self) -> bool { + self.as_scan_source_ref().run_async() + } + + pub fn is_cloud_url(&self) -> bool { + if let ScanSource::Path(path) = self { + polars_io::is_cloud_url(path.as_ref()) + } else { + false + } + } +} + +/// An iterator for [`ScanSources`] +pub struct ScanSourceIter<'a> { + sources: &'a ScanSources, + offset: usize, +} + +impl Default for ScanSources { + fn default() -> Self { + // We need to use `Paths` here to avoid erroring when doing hive-partitioned scans of empty + // file lists. + Self::Paths(Arc::default()) + } +} + +impl std::hash::Hash for ScanSources { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + + // @NOTE: This is a bit crazy + // + // We don't really want to hash the file descriptors or the whole buffers so for now we + // just settle with the fact that the memory behind Arc's does not really move. Therefore, + // we can just hash the pointer. + match self { + Self::Paths(paths) => paths.hash(state), + Self::Files(files) => files.as_ptr().hash(state), + Self::Buffers(buffers) => buffers.as_ptr().hash(state), + } + } +} + +impl PartialEq for ScanSources { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ScanSources::Paths(l), ScanSources::Paths(r)) => l == r, + (ScanSources::Files(l), ScanSources::Files(r)) => std::ptr::eq(l.as_ptr(), r.as_ptr()), + (ScanSources::Buffers(l), ScanSources::Buffers(r)) => { + std::ptr::eq(l.as_ptr(), r.as_ptr()) + }, + _ => false, + } + } +} + +impl Eq for ScanSources {} + +impl ScanSources { + pub fn expand_paths( + &self, + scan_args: &UnifiedScanArgs, + #[allow(unused_variables)] cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + match self { + Self::Paths(paths) => Ok(Self::Paths(expand_paths( + paths, + scan_args.glob, + cloud_options, + )?)), + v => Ok(v.clone()), + } + } + + /// This will update `scan_args.hive_options.enabled` to `true` if the existing value is `None` + /// and the paths are expanded from a single directory. Otherwise the existing value is maintained. + #[cfg(any(feature = "ipc", feature = "parquet"))] + pub fn expand_paths_with_hive_update( + &self, + scan_args: &mut UnifiedScanArgs, + #[allow(unused_variables)] cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + match self { + Self::Paths(paths) => { + let (expanded_paths, hive_start_idx) = expand_paths_hive( + paths, + scan_args.glob, + cloud_options, + scan_args.hive_options.enabled.unwrap_or(false), + )?; + + if scan_args.hive_options.enabled.is_none() + && expanded_from_single_directory(paths, expanded_paths.as_ref()) + { + scan_args.hive_options.enabled = Some(true); + } + scan_args.hive_options.hive_start_idx = hive_start_idx; + + Ok(Self::Paths(expanded_paths)) + }, + v => Ok(v.clone()), + } + } + + pub fn iter(&self) -> ScanSourceIter { + ScanSourceIter { + sources: self, + offset: 0, + } + } + + /// Are the sources all paths? + pub fn is_paths(&self) -> bool { + matches!(self, Self::Paths(_)) + } + + /// Try cast the scan sources to [`ScanSources::Paths`] + pub fn as_paths(&self) -> Option<&[PathBuf]> { + match self { + Self::Paths(paths) => Some(paths.as_ref()), + Self::Files(_) | Self::Buffers(_) => None, + } + } + + /// Try cast the scan sources to [`ScanSources::Paths`] with a clone + pub fn into_paths(&self) -> Option> { + match self { + Self::Paths(paths) => Some(paths.clone()), + Self::Files(_) | Self::Buffers(_) => None, + } + } + + /// Try get the first path in the scan sources + pub fn first_path(&self) -> Option<&Path> { + match self { + Self::Paths(paths) => paths.first().map(|p| p.as_path()), + Self::Files(_) | Self::Buffers(_) => None, + } + } + + /// Is the first path a cloud URL? + pub fn is_cloud_url(&self) -> bool { + self.first_path().is_some_and(polars_io::is_cloud_url) + } + + pub fn len(&self) -> usize { + match self { + Self::Paths(s) => s.len(), + Self::Files(s) => s.len(), + Self::Buffers(s) => s.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn first(&self) -> Option { + self.get(0) + } + + /// Turn the [`ScanSources`] into some kind of identifier + pub fn id(&self) -> PlSmallStr { + if self.is_empty() { + return PlSmallStr::from_static("EMPTY"); + } + + match self { + Self::Paths(paths) => { + PlSmallStr::from_str(paths.first().unwrap().to_string_lossy().as_ref()) + }, + Self::Files(_) => PlSmallStr::from_static("OPEN_FILES"), + Self::Buffers(_) => PlSmallStr::from_static("IN_MEMORY"), + } + } + + /// Get the scan source at specific address + pub fn get(&self, idx: usize) -> Option { + match self { + Self::Paths(paths) => paths.get(idx).map(|p| ScanSourceRef::Path(p)), + Self::Files(files) => files.get(idx).map(ScanSourceRef::File), + Self::Buffers(buffers) => buffers.get(idx).map(ScanSourceRef::Buffer), + } + } + + /// Get the scan source at specific address + /// + /// # Panics + /// + /// If the `idx` is out of range. + #[track_caller] + pub fn at(&self, idx: usize) -> ScanSourceRef { + self.get(idx).unwrap() + } +} + +impl ScanSourceRef<'_> { + /// Get the name for `include_paths` + pub fn to_include_path_name(&self) -> &str { + match self { + Self::Path(path) => path.to_str().unwrap(), + Self::File(_) => "open-file", + Self::Buffer(_) => "in-mem", + } + } + + // @TODO: I would like to remove this function eventually. + pub fn into_owned(&self) -> PolarsResult { + Ok(match self { + ScanSourceRef::Path(path) => ScanSource::Path((*path).into()), + ScanSourceRef::File(file) => { + if let Ok(file) = file.try_clone() { + ScanSource::File(Arc::new(file)) + } else { + ScanSource::Buffer(self.to_memslice()?) + } + }, + ScanSourceRef::Buffer(buffer) => ScanSource::Buffer((*buffer).clone()), + }) + } + + /// Turn the scan source into a memory slice + pub fn to_memslice(&self) -> PolarsResult { + self.to_memslice_possibly_async(false, None, 0) + } + + #[allow(clippy::wrong_self_convention)] + #[cfg(feature = "cloud")] + fn to_memslice_async) -> PolarsResult>( + &self, + assume: F, + run_async: bool, + ) -> PolarsResult { + match self { + ScanSourceRef::Path(path) => { + let path_str = path.to_str(); + let file = if run_async && path_str.is_some() { + feature_gated!("cloud", { + // This isn't filled if we modified the DSL (e.g. in cloud) + let entry = polars_io::file_cache::FILE_CACHE.get_entry(path_str.unwrap()); + + if let Some(entry) = entry { + assume(entry)? + } else { + polars_utils::open_file(path)? + } + }) + } else { + polars_utils::open_file(path)? + }; + + MemSlice::from_file(&file) + }, + ScanSourceRef::File(file) => MemSlice::from_file(file), + ScanSourceRef::Buffer(buff) => Ok((*buff).clone()), + } + } + + #[cfg(feature = "cloud")] + pub fn to_memslice_async_assume_latest(&self, run_async: bool) -> PolarsResult { + self.to_memslice_async(|entry| entry.try_open_assume_latest(), run_async) + } + + #[cfg(feature = "cloud")] + pub fn to_memslice_async_check_latest(&self, run_async: bool) -> PolarsResult { + self.to_memslice_async(|entry| entry.try_open_check_latest(), run_async) + } + + #[cfg(not(feature = "cloud"))] + fn to_memslice_async(&self, run_async: bool) -> PolarsResult { + match self { + ScanSourceRef::Path(path) => { + let file = polars_utils::open_file(path)?; + MemSlice::from_file(&file) + }, + ScanSourceRef::File(file) => MemSlice::from_file(file), + ScanSourceRef::Buffer(buff) => Ok((*buff).clone()), + } + } + + #[cfg(not(feature = "cloud"))] + pub fn to_memslice_async_assume_latest(&self, run_async: bool) -> PolarsResult { + self.to_memslice_async(run_async) + } + + #[cfg(not(feature = "cloud"))] + pub fn to_memslice_async_check_latest(&self, run_async: bool) -> PolarsResult { + self.to_memslice_async(run_async) + } + + pub fn to_memslice_possibly_async( + &self, + run_async: bool, + #[cfg(feature = "cloud")] cache_entries: Option< + &Vec>, + >, + #[cfg(not(feature = "cloud"))] cache_entries: Option<&()>, + index: usize, + ) -> PolarsResult { + match self { + Self::Path(path) => { + let file = if run_async { + feature_gated!("cloud", { + cache_entries.unwrap()[index].try_open_check_latest()? + }) + } else { + polars_utils::open_file(path)? + }; + + MemSlice::from_file(&file) + }, + Self::File(file) => MemSlice::from_file(file), + Self::Buffer(buff) => Ok((*buff).clone()), + } + } + + #[cfg(feature = "cloud")] + pub async fn to_dyn_byte_source( + &self, + builder: &DynByteSourceBuilder, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + match self { + Self::Path(path) => { + builder + .try_build_from_path(path.to_str().unwrap(), cloud_options) + .await + }, + Self::File(file) => Ok(DynByteSource::from(MemSlice::from_file(file)?)), + Self::Buffer(buff) => Ok(DynByteSource::from((*buff).clone())), + } + } + + pub(crate) fn run_async(&self) -> bool { + matches!(self, Self::Path(p) if polars_io::is_cloud_url(p) || polars_core::config::force_async()) + } +} + +impl<'a> Iterator for ScanSourceIter<'a> { + type Item = ScanSourceRef<'a>; + + fn next(&mut self) -> Option { + let item = match self.sources { + ScanSources::Paths(paths) => ScanSourceRef::Path(paths.get(self.offset)?), + ScanSources::Files(files) => ScanSourceRef::File(files.get(self.offset)?), + ScanSources::Buffers(buffers) => ScanSourceRef::Buffer(buffers.get(self.offset)?), + }; + + self.offset += 1; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.sources.len() - self.offset; + (len, Some(len)) + } +} + +impl ExactSizeIterator for ScanSourceIter<'_> {} diff --git a/crates/polars-plan/src/dsl/selector.rs b/crates/polars-plan/src/dsl/selector.rs new file mode 100644 index 000000000000..7877edb152df --- /dev/null +++ b/crates/polars-plan/src/dsl/selector.rs @@ -0,0 +1,81 @@ +use std::ops::{Add, BitAnd, BitXor, Sub}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; + +#[derive(Clone, PartialEq, Hash, Debug, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Selector { + Add(Box, Box), + Sub(Box, Box), + ExclusiveOr(Box, Box), + Intersect(Box, Box), + Root(Box), +} + +impl Selector { + pub fn new(e: Expr) -> Self { + Self::Root(Box::new(e)) + } +} + +impl Add for Selector { + type Output = Selector; + + fn add(self, rhs: Self) -> Self::Output { + Selector::Add(Box::new(self), Box::new(rhs)) + } +} + +impl BitAnd for Selector { + type Output = Selector; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn bitand(self, rhs: Self) -> Self::Output { + Selector::Intersect(Box::new(self), Box::new(rhs)) + } +} + +impl BitXor for Selector { + type Output = Selector; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn bitxor(self, rhs: Self) -> Self::Output { + Selector::ExclusiveOr(Box::new(self), Box::new(rhs)) + } +} + +impl Sub for Selector { + type Output = Selector; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self::Output { + Selector::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl From<&str> for Selector { + fn from(value: &str) -> Self { + Selector::new(col(PlSmallStr::from_str(value))) + } +} + +impl From for Selector { + fn from(value: String) -> Self { + Selector::new(col(PlSmallStr::from_string(value))) + } +} + +impl From for Selector { + fn from(value: PlSmallStr) -> Self { + Selector::new(Expr::Column(value)) + } +} + +impl From for Selector { + fn from(value: Expr) -> Self { + Selector::new(value) + } +} diff --git a/crates/polars-plan/src/dsl/statistics.rs b/crates/polars-plan/src/dsl/statistics.rs new file mode 100644 index 000000000000..7c03dcf1e1cf --- /dev/null +++ b/crates/polars-plan/src/dsl/statistics.rs @@ -0,0 +1,86 @@ +use super::*; + +impl Expr { + /// Standard deviation of the values of the Series. + pub fn std(self, ddof: u8) -> Self { + AggExpr::Std(Arc::new(self), ddof).into() + } + + /// Variance of the values of the Series. + pub fn var(self, ddof: u8) -> Self { + AggExpr::Var(Arc::new(self), ddof).into() + } + + /// Reduce groups to minimal value. + pub fn min(self) -> Self { + AggExpr::Min { + input: Arc::new(self), + propagate_nans: false, + } + .into() + } + + /// Reduce groups to maximum value. + pub fn max(self) -> Self { + AggExpr::Max { + input: Arc::new(self), + propagate_nans: false, + } + .into() + } + + /// Reduce groups to minimal value. + pub fn nan_min(self) -> Self { + AggExpr::Min { + input: Arc::new(self), + propagate_nans: true, + } + .into() + } + + /// Reduce groups to maximum value. + pub fn nan_max(self) -> Self { + AggExpr::Max { + input: Arc::new(self), + propagate_nans: true, + } + .into() + } + + /// Reduce groups to the mean value. + pub fn mean(self) -> Self { + AggExpr::Mean(Arc::new(self)).into() + } + + /// Reduce groups to the median value. + pub fn median(self) -> Self { + AggExpr::Median(Arc::new(self)).into() + } + + /// Reduce groups to the sum of all the values. + pub fn sum(self) -> Self { + AggExpr::Sum(Arc::new(self)).into() + } + + /// Compute the histogram of a dataset. + #[cfg(feature = "hist")] + pub fn hist( + self, + bins: Option, + bin_count: Option, + include_category: bool, + include_breakpoint: bool, + ) -> Self { + let mut input = vec![self]; + input.extend(bins); + + Expr::n_ary( + FunctionExpr::Hist { + bin_count, + include_category, + include_breakpoint, + }, + input, + ) + } +} diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs new file mode 100644 index 000000000000..4cecacbf5042 --- /dev/null +++ b/crates/polars-plan/src/dsl/string.rs @@ -0,0 +1,498 @@ +use super::*; +/// Specialized expressions for [`Series`] of [`DataType::String`]. +pub struct StringNameSpace(pub(crate) Expr); + +impl StringNameSpace { + /// Check if a string value contains a literal substring. + #[cfg(feature = "regex")] + pub fn contains_literal(self, pat: Expr) -> Expr { + self.0.map_binary( + StringFunction::Contains { + literal: true, + strict: false, + }, + pat, + ) + } + + /// Check if this column of strings contains a Regex. If `strict` is `true`, then it is an error if any `pat` is + /// an invalid regex, whereas if `strict` is `false`, an invalid regex will simply evaluate to `false`. + #[cfg(feature = "regex")] + pub fn contains(self, pat: Expr, strict: bool) -> Expr { + self.0.map_binary( + StringFunction::Contains { + literal: false, + strict, + }, + pat, + ) + } + + /// Uses aho-corasick to find many patterns. + /// + /// # Arguments + /// - `patterns`: an expression that evaluates to an String column + /// - `ascii_case_insensitive`: Enable ASCII-aware case insensitive matching. + /// When this option is enabled, searching will be performed without respect to case for + /// ASCII letters (a-z and A-Z) only. + #[cfg(feature = "find_many")] + pub fn contains_any(self, patterns: Expr, ascii_case_insensitive: bool) -> Expr { + self.0.map_binary( + StringFunction::ContainsAny { + ascii_case_insensitive, + }, + patterns, + ) + } + + /// Uses aho-corasick to replace many patterns. + /// # Arguments + /// - `patterns`: an expression that evaluates to a String column + /// - `replace_with`: an expression that evaluates to a String column + /// - `ascii_case_insensitive`: Enable ASCII-aware case-insensitive matching. + /// When this option is enabled, searching will be performed without respect to case for + /// ASCII letters (a-z and A-Z) only. + #[cfg(feature = "find_many")] + pub fn replace_many( + self, + patterns: Expr, + replace_with: Expr, + ascii_case_insensitive: bool, + ) -> Expr { + self.0.map_ternary( + StringFunction::ReplaceMany { + ascii_case_insensitive, + }, + patterns, + replace_with, + ) + } + + /// Uses aho-corasick to replace many patterns. + /// # Arguments + /// - `patterns`: an expression that evaluates to a String column + /// - `ascii_case_insensitive`: Enable ASCII-aware case-insensitive matching. + /// When this option is enabled, searching will be performed without respect to case for + /// ASCII letters (a-z and A-Z) only. + /// - `overlapping`: Whether matches may overlap. + #[cfg(feature = "find_many")] + pub fn extract_many( + self, + patterns: Expr, + ascii_case_insensitive: bool, + overlapping: bool, + ) -> Expr { + self.0.map_binary( + StringFunction::ExtractMany { + ascii_case_insensitive, + overlapping, + }, + patterns, + ) + } + + /// Uses aho-corasick to find many patterns. + /// # Arguments + /// - `patterns`: an expression that evaluates to a String column + /// - `ascii_case_insensitive`: Enable ASCII-aware case-insensitive matching. + /// When this option is enabled, searching will be performed without respect to case for + /// ASCII letters (a-z and A-Z) only. + /// - `overlapping`: Whether matches may overlap. + #[cfg(feature = "find_many")] + pub fn find_many( + self, + patterns: Expr, + ascii_case_insensitive: bool, + overlapping: bool, + ) -> Expr { + self.0.map_binary( + StringFunction::FindMany { + ascii_case_insensitive, + overlapping, + }, + patterns, + ) + } + + /// Check if a string value ends with the `sub` string. + pub fn ends_with(self, sub: Expr) -> Expr { + self.0.map_binary(StringFunction::EndsWith, sub) + } + + /// Check if a string value starts with the `sub` string. + pub fn starts_with(self, sub: Expr) -> Expr { + self.0.map_binary(StringFunction::StartsWith, sub) + } + + #[cfg(feature = "string_encoding")] + pub fn hex_encode(self) -> Expr { + self.0.map_unary(StringFunction::HexEncode) + } + + #[cfg(feature = "binary_encoding")] + pub fn hex_decode(self, strict: bool) -> Expr { + self.0.map_unary(StringFunction::HexDecode(strict)) + } + + #[cfg(feature = "string_encoding")] + pub fn base64_encode(self) -> Expr { + self.0.map_unary(StringFunction::Base64Encode) + } + + #[cfg(feature = "binary_encoding")] + pub fn base64_decode(self, strict: bool) -> Expr { + self.0.map_unary(StringFunction::Base64Decode(strict)) + } + + /// Extract a regex pattern from the a string value. If `group_index` is out of bounds, null is returned. + pub fn extract(self, pat: Expr, group_index: usize) -> Expr { + self.0.map_binary(StringFunction::Extract(group_index), pat) + } + + #[cfg(feature = "extract_groups")] + // Extract all captures groups from a regex pattern as a struct + pub fn extract_groups(self, pat: &str) -> PolarsResult { + // regex will be compiled twice, because it doesn't support serde + // and we need to compile it here to determine the output datatype + + use polars_utils::format_pl_smallstr; + let reg = polars_utils::regex_cache::compile_regex(pat)?; + let names = reg + .capture_names() + .enumerate() + .skip(1) + .map(|(idx, opt_name)| { + opt_name + .map(PlSmallStr::from_str) + .unwrap_or_else(|| format_pl_smallstr!("{idx}")) + }) + .collect::>(); + + let dtype = DataType::Struct( + names + .iter() + .map(|name| Field::new(name.clone(), DataType::String)) + .collect(), + ); + + Ok(self.0.map_unary(StringFunction::ExtractGroups { + dtype, + pat: pat.into(), + })) + } + + /// Pad the start of the string until it reaches the given length. + /// + /// Padding is done using the specified `fill_char`. + /// Strings with length equal to or greater than the given length are + /// returned as-is. + #[cfg(feature = "string_pad")] + pub fn pad_start(self, length: usize, fill_char: char) -> Expr { + self.0 + .map_unary(StringFunction::PadStart { length, fill_char }) + } + + /// Pad the end of the string until it reaches the given length. + /// + /// Padding is done using the specified `fill_char`. + /// Strings with length equal to or greater than the given length are + /// returned as-is. + #[cfg(feature = "string_pad")] + pub fn pad_end(self, length: usize, fill_char: char) -> Expr { + self.0 + .map_unary(StringFunction::PadEnd { length, fill_char }) + } + + /// Pad the start of the string with zeros until it reaches the given length. + /// + /// A sign prefix (`-`) is handled by inserting the padding after the sign + /// character rather than before. + /// Strings with length equal to or greater than the given length are + /// returned as-is. + #[cfg(feature = "string_pad")] + pub fn zfill(self, length: Expr) -> Expr { + self.0.map_binary(StringFunction::ZFill, length) + } + + /// Find the index of a literal substring within another string value. + #[cfg(feature = "regex")] + pub fn find_literal(self, pat: Expr) -> Expr { + self.0.map_binary( + StringFunction::Find { + literal: true, + strict: false, + }, + pat, + ) + } + + /// Find the index of a substring defined by a regular expressions within another string value. + #[cfg(feature = "regex")] + pub fn find(self, pat: Expr, strict: bool) -> Expr { + self.0.map_binary( + StringFunction::Find { + literal: false, + strict, + }, + pat, + ) + } + + /// Extract each successive non-overlapping match in an individual string as an array + pub fn extract_all(self, pat: Expr) -> Expr { + self.0.map_binary(StringFunction::ExtractAll, pat) + } + + /// Count all successive non-overlapping regex matches. + pub fn count_matches(self, pat: Expr, literal: bool) -> Expr { + self.0 + .map_binary(StringFunction::CountMatches(literal), pat) + } + + /// Convert a String column into a Date/Datetime/Time column. + #[cfg(feature = "temporal")] + pub fn strptime(self, dtype: DataType, options: StrptimeOptions, ambiguous: Expr) -> Expr { + let is_column_independent = is_column_independent(&self.0); + // Only elementwise if the format is explicitly set, or we're constant. + self.0 + .map_binary(StringFunction::Strptime(dtype, options), ambiguous) + .with_function_options(|mut options| { + // @HACK. This needs to be done because literals still block predicate pushdown, + // but this should be an exception in the predicate pushdown. + if is_column_independent { + options.collect_groups = ApplyOptions::ElementWise; + } + options + }) + } + + /// Convert a String column into a Date column. + #[cfg(feature = "dtype-date")] + pub fn to_date(self, options: StrptimeOptions) -> Expr { + self.strptime(DataType::Date, options, lit("raise")) + } + + /// Convert a String column into a Datetime column. + #[cfg(feature = "dtype-datetime")] + pub fn to_datetime( + self, + time_unit: Option, + time_zone: Option, + options: StrptimeOptions, + ambiguous: Expr, + ) -> Expr { + // If time_unit is None, try to infer it from the format or set a default + let time_unit = match (&options.format, time_unit) { + (_, Some(time_unit)) => time_unit, + (Some(format), None) => { + if format.contains("%.9f") || format.contains("%9f") { + TimeUnit::Nanoseconds + } else if format.contains("%.3f") || format.contains("%3f") { + TimeUnit::Milliseconds + } else { + TimeUnit::Microseconds + } + }, + (None, None) => TimeUnit::Microseconds, + }; + + self.strptime(DataType::Datetime(time_unit, time_zone), options, ambiguous) + } + + /// Convert a String column into a Time column. + #[cfg(feature = "dtype-time")] + pub fn to_time(self, options: StrptimeOptions) -> Expr { + self.strptime(DataType::Time, options, lit("raise")) + } + + /// Convert a String column into a Decimal column. + #[cfg(feature = "dtype-decimal")] + pub fn to_decimal(self, infer_length: usize) -> Expr { + self.0.map_unary(StringFunction::ToDecimal(infer_length)) + } + + /// Concat the values into a string array. + /// # Arguments + /// + /// * `delimiter` - A string that will act as delimiter between values. + #[cfg(feature = "concat_str")] + pub fn join(self, delimiter: &str, ignore_nulls: bool) -> Expr { + self.0.map_unary(StringFunction::ConcatVertical { + delimiter: delimiter.into(), + ignore_nulls, + }) + } + + /// Split the string by a substring. The resulting dtype is `List`. + pub fn split(self, by: Expr) -> Expr { + self.0.map_binary(StringFunction::Split(false), by) + } + + /// Split the string by a substring and keep the substring. The resulting dtype is `List`. + pub fn split_inclusive(self, by: Expr) -> Expr { + self.0.map_binary(StringFunction::Split(true), by) + } + + #[cfg(feature = "dtype-struct")] + /// Split exactly `n` times by a given substring. The resulting dtype is [`DataType::Struct`]. + pub fn split_exact(self, by: Expr, n: usize) -> Expr { + self.0.map_binary( + StringFunction::SplitExact { + n, + inclusive: false, + }, + by, + ) + } + + #[cfg(feature = "dtype-struct")] + /// Split exactly `n` times by a given substring and keep the substring. + /// The resulting dtype is [`DataType::Struct`]. + pub fn split_exact_inclusive(self, by: Expr, n: usize) -> Expr { + self.0 + .map_binary(StringFunction::SplitExact { n, inclusive: true }, by) + } + + #[cfg(feature = "dtype-struct")] + /// Split by a given substring, returning exactly `n` items. If there are more possible splits, + /// keeps the remainder of the string intact. The resulting dtype is [`DataType::Struct`]. + pub fn splitn(self, by: Expr, n: usize) -> Expr { + self.0.map_binary(StringFunction::SplitN(n), by) + } + + #[cfg(feature = "regex")] + /// Replace values that match a regex `pat` with a `value`. + pub fn replace(self, pat: Expr, value: Expr, literal: bool) -> Expr { + self.0 + .map_ternary(StringFunction::Replace { n: 1, literal }, pat, value) + } + + #[cfg(feature = "regex")] + /// Replace values that match a regex `pat` with a `value`. + pub fn replace_n(self, pat: Expr, value: Expr, literal: bool, n: i64) -> Expr { + self.0 + .map_ternary(StringFunction::Replace { n, literal }, pat, value) + } + + #[cfg(feature = "regex")] + /// Replace all values that match a regex `pat` with a `value`. + pub fn replace_all(self, pat: Expr, value: Expr, literal: bool) -> Expr { + self.0 + .map_ternary(StringFunction::Replace { n: -1, literal }, pat, value) + } + + #[cfg(feature = "string_normalize")] + /// Normalize each string + pub fn normalize(self, form: UnicodeForm) -> Expr { + self.0.map_unary(StringFunction::Normalize { form }) + } + + #[cfg(feature = "string_reverse")] + /// Reverse each string + pub fn reverse(self) -> Expr { + self.0.map_unary(StringFunction::Reverse) + } + + /// Remove leading and trailing characters, or whitespace if matches is None. + pub fn strip_chars(self, matches: Expr) -> Expr { + self.0.map_binary(StringFunction::StripChars, matches) + } + + /// Remove leading characters, or whitespace if matches is None. + pub fn strip_chars_start(self, matches: Expr) -> Expr { + self.0.map_binary(StringFunction::StripCharsStart, matches) + } + + /// Remove trailing characters, or whitespace if matches is None. + pub fn strip_chars_end(self, matches: Expr) -> Expr { + self.0.map_binary(StringFunction::StripCharsEnd, matches) + } + + /// Remove prefix. + pub fn strip_prefix(self, prefix: Expr) -> Expr { + self.0.map_binary(StringFunction::StripPrefix, prefix) + } + + /// Remove suffix. + pub fn strip_suffix(self, suffix: Expr) -> Expr { + self.0.map_binary(StringFunction::StripSuffix, suffix) + } + + /// Convert all characters to lowercase. + pub fn to_lowercase(self) -> Expr { + self.0.map_unary(StringFunction::Lowercase) + } + + /// Convert all characters to uppercase. + pub fn to_uppercase(self) -> Expr { + self.0.map_unary(StringFunction::Uppercase) + } + + /// Convert all characters to titlecase. + #[cfg(feature = "nightly")] + pub fn to_titlecase(self) -> Expr { + self.0.map_unary(StringFunction::Titlecase) + } + + #[cfg(feature = "string_to_integer")] + /// Parse string in base radix into decimal. + pub fn to_integer(self, base: Expr, strict: bool) -> Expr { + self.0.map_binary(StringFunction::ToInteger(strict), base) + } + + /// Return the length of each string as the number of bytes. + /// + /// When working with non-ASCII text, the length in bytes is not the same + /// as the length in characters. You may want to use + /// [`len_chars`] instead. Note that `len_bytes` is much more + /// performant (_O(1)_) than [`len_chars`] (_O(n)_). + /// + /// [`len_chars`]: StringNameSpace::len_chars + pub fn len_bytes(self) -> Expr { + self.0.map_unary(StringFunction::LenBytes) + } + + /// Return the length of each string as the number of characters. + /// + /// When working with ASCII text, use [`len_bytes`] instead to achieve + /// equivalent output with much better performance: + /// [`len_bytes`] runs in _O(1)_, while `len_chars` runs in _O(n)_. + /// + /// [`len_bytes`]: StringNameSpace::len_bytes + pub fn len_chars(self) -> Expr { + self.0.map_unary(StringFunction::LenChars) + } + + /// Slice the string values. + pub fn slice(self, offset: Expr, length: Expr) -> Expr { + self.0.map_ternary(StringFunction::Slice, offset, length) + } + + /// Take the first `n` characters of the string values. + pub fn head(self, n: Expr) -> Expr { + self.0.map_binary(StringFunction::Head, n) + } + + /// Take the last `n` characters of the string values. + pub fn tail(self, n: Expr) -> Expr { + self.0.map_binary(StringFunction::Tail, n) + } + + #[cfg(feature = "extract_jsonpath")] + pub fn json_decode(self, dtype: Option, infer_schema_len: Option) -> Expr { + self.0.map_unary(StringFunction::JsonDecode { + dtype, + infer_schema_len, + }) + } + + #[cfg(feature = "extract_jsonpath")] + pub fn json_path_match(self, pat: Expr) -> Expr { + self.0.map_binary(StringFunction::JsonPathMatch, pat) + } + + #[cfg(feature = "regex")] + pub fn escape_regex(self) -> Expr { + self.0.map_unary(StringFunction::EscapeRegex) + } +} diff --git a/crates/polars-plan/src/dsl/struct_.rs b/crates/polars-plan/src/dsl/struct_.rs new file mode 100644 index 000000000000..4dc9829cb278 --- /dev/null +++ b/crates/polars-plan/src/dsl/struct_.rs @@ -0,0 +1,90 @@ +use super::*; +use crate::plans::conversion::is_regex_projection; + +/// Specialized expressions for Struct dtypes. +pub struct StructNameSpace(pub(crate) Expr); + +impl StructNameSpace { + pub fn field_by_index(self, index: i64) -> Expr { + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::FieldByIndex( + index, + ))) + } + + /// Retrieve one or multiple of the fields of this [`StructChunked`] as a new Series. + /// This expression also expands the `"*"` wildcard column. + pub fn field_by_names(self, names: I) -> Expr + where + I: IntoIterator, + S: Into, + { + self.field_by_names_impl(names.into_iter().map(|x| x.into()).collect()) + } + + fn field_by_names_impl(self, names: Arc<[PlSmallStr]>) -> Expr { + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::MultipleFields( + names, + ))) + } + + /// Retrieve one of the fields of this [`StructChunked`] as a new Series. + /// This expression also supports wildcard "*" and regex expansion. + pub fn field_by_name(self, name: &str) -> Expr { + if name == "*" || is_regex_projection(name) { + return self.field_by_names([name]); + } + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::FieldByName( + name.into(), + ))) + } + + /// Rename the fields of the [`StructChunked`]. + pub fn rename_fields(self, names: I) -> Expr + where + I: IntoIterator, + S: Into, + { + self._rename_fields_impl(names.into_iter().map(|x| x.into()).collect()) + } + + pub fn _rename_fields_impl(self, names: Arc<[PlSmallStr]>) -> Expr { + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::RenameFields( + names, + ))) + } + + #[cfg(feature = "json")] + pub fn json_encode(self) -> Expr { + self.0 + .map_unary(FunctionExpr::StructExpr(StructFunction::JsonEncode)) + } + + pub fn with_fields(self, fields: Vec) -> PolarsResult { + fn materialize_field(this: &Expr, field: Expr) -> PolarsResult { + field.try_map_expr(|e| match e { + Expr::Field(names) => { + let this = this.clone().struct_(); + Ok(if names.len() == 1 { + this.field_by_name(names[0].as_ref()) + } else { + this.field_by_names_impl(names) + }) + }, + Expr::Exclude(_, _) => { + polars_bail!(InvalidOperation: "'exclude' not allowed in 'field'") + }, + _ => Ok(e), + }) + } + + let s = self.0.clone(); + self.0.try_map_n_ary( + FunctionExpr::StructExpr(StructFunction::WithFields), + fields.into_iter().map(|e| materialize_field(&s, e)), + ) + } +} diff --git a/crates/polars-plan/src/dsl/udf.rs b/crates/polars-plan/src/dsl/udf.rs new file mode 100644 index 000000000000..e973fabbd961 --- /dev/null +++ b/crates/polars-plan/src/dsl/udf.rs @@ -0,0 +1,49 @@ +use polars_utils::pl_str::PlSmallStr; + +use super::{ColumnsUdf, Expr, GetOutput, OpaqueColumnUdf}; +use crate::prelude::{FunctionOptions, new_column_udf}; + +/// Represents a user-defined function +#[derive(Clone)] +pub struct UserDefinedFunction { + /// name + pub name: PlSmallStr, + /// The function output type. + pub return_type: GetOutput, + /// The function implementation. + pub fun: OpaqueColumnUdf, + /// Options for the function. + pub options: FunctionOptions, +} + +impl std::fmt::Debug for UserDefinedFunction { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("UserDefinedFunction") + .field("name", &self.name) + .field("fun", &"") + .field("options", &self.options) + .finish() + } +} + +impl UserDefinedFunction { + /// Create a new UserDefinedFunction + pub fn new(name: PlSmallStr, return_type: GetOutput, fun: impl ColumnsUdf + 'static) -> Self { + Self { + name, + return_type, + fun: new_column_udf(fun), + options: FunctionOptions::default(), + } + } + + /// creates a logical expression with a call of the UDF + pub fn call(self, args: Vec) -> Expr { + Expr::AnonymousFunction { + input: args, + function: self.fun, + output_type: self.return_type.clone(), + options: self.options, + } + } +} diff --git a/crates/polars-plan/src/frame/mod.rs b/crates/polars-plan/src/frame/mod.rs new file mode 100644 index 000000000000..df505965eb7d --- /dev/null +++ b/crates/polars-plan/src/frame/mod.rs @@ -0,0 +1,3 @@ +mod opt_state; + +pub use opt_state::*; diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs new file mode 100644 index 000000000000..7472b623d8fc --- /dev/null +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -0,0 +1,92 @@ +use bitflags::bitflags; + +bitflags! { +#[derive(Copy, Clone, Debug)] + /// Allowed optimizations. + pub struct OptFlags: u32 { + /// Only read columns that are used later in the query. + const PROJECTION_PUSHDOWN = 1; + /// Apply predicates/filters as early as possible. + const PREDICATE_PUSHDOWN = 1 << 2; + /// Cluster sequential `with_columns` calls to independent calls. + const CLUSTER_WITH_COLUMNS = 1 << 3; + /// Run many type coercion optimization rules until fixed point. + const TYPE_COERCION = 1 << 4; + /// Run many expression optimization rules until fixed point. + const SIMPLIFY_EXPR = 1 << 5; + /// Do type checking of the IR. + const TYPE_CHECK = 1 << 6; + /// Pushdown slices/limits. + const SLICE_PUSHDOWN = 1 << 7; + /// Run common-subplan-elimination. This elides duplicate plans and caches their + /// outputs. + const COMM_SUBPLAN_ELIM = 1 << 8; + /// Run common-subexpression-elimination. This elides duplicate expressions and caches their + /// outputs. + const COMM_SUBEXPR_ELIM = 1 << 9; + /// Run nodes that are capably of doing so on the streaming engine. + const STREAMING = 1 << 10; + const NEW_STREAMING = 1 << 11; + /// Run every node eagerly. This turns off multi-node optimizations. + const EAGER = 1 << 12; + /// Try to estimate the number of rows so that joins can determine which side to keep in memory. + const ROW_ESTIMATE = 1 << 13; + /// Replace simple projections with a faster inlined projection that skips the expression engine. + const FAST_PROJECTION = 1 << 14; + /// Collapse slower joins with filters into faster joins. + const COLLAPSE_JOINS = 1 << 15; + /// Check if operations are order dependent and unset maintaining_order if + /// the order would not be observed. + const CHECK_ORDER_OBSERVE = 1 << 16; + } +} + +impl OptFlags { + pub fn schema_only() -> Self { + Self::TYPE_COERCION | Self::TYPE_CHECK + } + + pub fn eager(&self) -> bool { + self.contains(OptFlags::EAGER) + } + + pub fn cluster_with_columns(&self) -> bool { + self.contains(OptFlags::CLUSTER_WITH_COLUMNS) + } + + pub fn collapse_joins(&self) -> bool { + self.contains(OptFlags::COLLAPSE_JOINS) + } + + pub fn predicate_pushdown(&self) -> bool { + self.contains(OptFlags::PREDICATE_PUSHDOWN) + } + + pub fn projection_pushdown(&self) -> bool { + self.contains(OptFlags::PROJECTION_PUSHDOWN) + } + pub fn simplify_expr(&self) -> bool { + self.contains(OptFlags::SIMPLIFY_EXPR) + } + pub fn slice_pushdown(&self) -> bool { + self.contains(OptFlags::SLICE_PUSHDOWN) + } + pub fn streaming(&self) -> bool { + self.contains(OptFlags::STREAMING) + } + pub fn new_streaming(&self) -> bool { + self.contains(OptFlags::NEW_STREAMING) + } + pub fn fast_projection(&self) -> bool { + self.contains(OptFlags::FAST_PROJECTION) + } +} + +impl Default for OptFlags { + fn default() -> Self { + Self::from_bits_truncate(u32::MAX) & !Self::NEW_STREAMING & !Self::STREAMING & !Self::EAGER + } +} + +/// AllowedOptimizations +pub type AllowedOptimizations = OptFlags; diff --git a/crates/polars-plan/src/global.rs b/crates/polars-plan/src/global.rs new file mode 100644 index 000000000000..5ae72b7dd2f9 --- /dev/null +++ b/crates/polars-plan/src/global.rs @@ -0,0 +1,13 @@ +use std::cell::Cell; + +// Will be set/ unset in the fetch operation to communicate overwriting the number of rows to scan. +thread_local! {pub static FETCH_ROWS: Cell> = const { Cell::new(None) }} + +pub fn _set_n_rows_for_scan(n_rows: Option) -> Option { + let fetch_rows = FETCH_ROWS.with(|fetch_rows| fetch_rows.get()); + fetch_rows.or(n_rows) +} + +pub fn _is_fetch_query() -> bool { + FETCH_ROWS.with(|fetch_rows| fetch_rows.get().is_some()) +} diff --git a/crates/polars-plan/src/lib.rs b/crates/polars-plan/src/lib.rs new file mode 100644 index 000000000000..3c495d30b9be --- /dev/null +++ b/crates/polars-plan/src/lib.rs @@ -0,0 +1,15 @@ +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(feature = "nightly", allow(clippy::needless_pass_by_ref_mut))] // remove once stable +#![cfg_attr(feature = "nightly", allow(clippy::blocks_in_conditions))] // Remove once stable. + +extern crate core; + +#[cfg(feature = "polars_cloud")] +pub mod client; +pub mod constants; +pub mod dsl; +pub mod frame; +pub mod global; +pub mod plans; +pub mod prelude; +pub mod utils; diff --git a/crates/polars-plan/src/plans/aexpr/evaluate.rs b/crates/polars-plan/src/plans/aexpr/evaluate.rs new file mode 100644 index 000000000000..3b003fff8fcc --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/evaluate.rs @@ -0,0 +1,37 @@ +use std::borrow::Cow; + +use polars_core::schema::Schema; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; + +use super::{AExpr, LiteralValue, aexpr_to_leaf_names_iter}; + +pub fn constant_evaluate<'a>( + e: Node, + expr_arena: &'a Arena, + _schema: &Schema, + _depth: usize, +) -> Option>> { + match expr_arena.get(e) { + AExpr::Literal(lv) => Some(Some(Cow::Borrowed(lv))), + _ => { + if aexpr_to_leaf_names_iter(e, expr_arena).next().is_none() { + Some(None) + } else { + None + } + }, + } +} + +pub fn into_column<'a>( + e: Node, + expr_arena: &'a Arena, + _schema: &Schema, + _depth: usize, +) -> Option<&'a PlSmallStr> { + match expr_arena.get(e) { + AExpr::Column(c) => Some(c), + _ => None, + } +} diff --git a/crates/polars-plan/src/plans/aexpr/hash.rs b/crates/polars-plan/src/plans/aexpr/hash.rs new file mode 100644 index 000000000000..6f8c97dc2a36 --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/hash.rs @@ -0,0 +1,46 @@ +use std::hash::{Hash, Hasher}; + +use polars_utils::arena::{Arena, Node}; + +use crate::plans::ArenaExprIter; +use crate::prelude::AExpr; + +impl Hash for AExpr { + // This hashes the variant, not the whole expression + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + + match self { + AExpr::Column(name) => name.hash(state), + AExpr::Alias(_, name) => name.hash(state), + AExpr::Literal(lv) => lv.hash(state), + AExpr::Function { + options, function, .. + } => { + options.hash(state); + function.hash(state) + }, + AExpr::AnonymousFunction { options, .. } => { + options.hash(state); + }, + AExpr::Agg(agg) => agg.hash(state), + AExpr::SortBy { sort_options, .. } => sort_options.hash(state), + AExpr::Cast { + options: strict, .. + } => strict.hash(state), + AExpr::Window { options, .. } => options.hash(state), + AExpr::BinaryExpr { op, .. } => op.hash(state), + _ => {}, + } + } +} + +pub(crate) fn traverse_and_hash_aexpr( + node: Node, + expr_arena: &Arena, + state: &mut H, +) { + for (_, ae) in expr_arena.iter(node) { + ae.hash(state); + } +} diff --git a/crates/polars-plan/src/plans/aexpr/minterm_iter.rs b/crates/polars-plan/src/plans/aexpr/minterm_iter.rs new file mode 100644 index 000000000000..4a9e8d601e1d --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/minterm_iter.rs @@ -0,0 +1,55 @@ +use polars_utils::arena::{Arena, Node}; + +use super::{AExpr, Operator}; + +/// An iterator over all the minterms in a boolean expression boolean. +/// +/// In other words, all the terms that can `AND` together to form this expression. +/// +/// # Example +/// +/// ``` +/// a & (b | c) & (b & (c | (a & c))) +/// ``` +/// +/// Gives terms: +/// +/// ``` +/// a +/// b | c +/// b +/// c | (a & c) +/// ``` +pub struct MintermIter<'a> { + stack: Vec, + expr_arena: &'a Arena, +} + +impl Iterator for MintermIter<'_> { + type Item = Node; + + fn next(&mut self) -> Option { + let mut top = self.stack.pop()?; + + while let AExpr::BinaryExpr { + left, + op: Operator::And, + right, + } = self.expr_arena.get(top) + { + self.stack.push(*right); + top = *left; + } + + Some(top) + } +} + +impl<'a> MintermIter<'a> { + pub fn new(root: Node, expr_arena: &'a Arena) -> Self { + Self { + stack: vec![root], + expr_arena, + } + } +} diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs new file mode 100644 index 000000000000..8d644b57f74b --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -0,0 +1,230 @@ +mod evaluate; +#[cfg(feature = "cse")] +mod hash; +mod minterm_iter; +pub mod predicates; +mod scalar; +mod schema; +mod traverse; + +use std::hash::{Hash, Hasher}; + +#[cfg(feature = "cse")] +pub(super) use hash::traverse_and_hash_aexpr; +pub use minterm_iter::MintermIter; +use polars_compute::rolling::QuantileMethod; +use polars_core::chunked_array::cast::CastOptions; +use polars_core::prelude::*; +use polars_core::utils::{get_time_units, try_get_supertype}; +use polars_utils::arena::{Arena, Node}; +pub use scalar::is_scalar_ae; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; +pub use traverse::*; +mod properties; +pub use properties::*; + +use crate::constants::LEN; +use crate::plans::Context; +use crate::prelude::*; + +#[derive(Clone, Debug, IntoStaticStr)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub enum IRAggExpr { + Min { + input: Node, + propagate_nans: bool, + }, + Max { + input: Node, + propagate_nans: bool, + }, + Median(Node), + NUnique(Node), + First(Node), + Last(Node), + Mean(Node), + Implode(Node), + Quantile { + expr: Node, + quantile: Node, + method: QuantileMethod, + }, + Sum(Node), + // include_nulls + Count(Node, bool), + Std(Node, u8), + Var(Node, u8), + AggGroups(Node), +} + +impl Hash for IRAggExpr { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => { + propagate_nans.hash(state) + }, + Self::Quantile { + method: interpol, .. + } => interpol.hash(state), + Self::Std(_, v) | Self::Var(_, v) => v.hash(state), + _ => {}, + } + } +} + +#[cfg(feature = "cse")] +impl IRAggExpr { + pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool { + use IRAggExpr::*; + match (self, other) { + ( + Min { + propagate_nans: l, .. + }, + Min { + propagate_nans: r, .. + }, + ) => l == r, + ( + Max { + propagate_nans: l, .. + }, + Max { + propagate_nans: r, .. + }, + ) => l == r, + (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r, + (Std(_, l), Std(_, r)) => l == r, + (Var(_, l), Var(_, r)) => l == r, + _ => std::mem::discriminant(self) == std::mem::discriminant(other), + } + } +} + +impl From for GroupByMethod { + fn from(value: IRAggExpr) -> Self { + use IRAggExpr::*; + match value { + Min { propagate_nans, .. } => { + if propagate_nans { + GroupByMethod::NanMin + } else { + GroupByMethod::Min + } + }, + Max { propagate_nans, .. } => { + if propagate_nans { + GroupByMethod::NanMax + } else { + GroupByMethod::Max + } + }, + Median(_) => GroupByMethod::Median, + NUnique(_) => GroupByMethod::NUnique, + First(_) => GroupByMethod::First, + Last(_) => GroupByMethod::Last, + Mean(_) => GroupByMethod::Mean, + Implode(_) => GroupByMethod::Implode, + Sum(_) => GroupByMethod::Sum, + Count(_, include_nulls) => GroupByMethod::Count { include_nulls }, + Std(_, ddof) => GroupByMethod::Std(ddof), + Var(_, ddof) => GroupByMethod::Var(ddof), + AggGroups(_) => GroupByMethod::Groups, + Quantile { .. } => unreachable!(), + } + } +} + +/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena]. +#[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub enum AExpr { + Explode(Node), + Alias(Node, PlSmallStr), + Column(PlSmallStr), + Literal(LiteralValue), + BinaryExpr { + left: Node, + op: Operator, + right: Node, + }, + Cast { + expr: Node, + dtype: DataType, + options: CastOptions, + }, + Sort { + expr: Node, + options: SortOptions, + }, + Gather { + expr: Node, + idx: Node, + returns_scalar: bool, + }, + SortBy { + expr: Node, + by: Vec, + sort_options: SortMultipleOptions, + }, + Filter { + input: Node, + by: Node, + }, + Agg(IRAggExpr), + Ternary { + predicate: Node, + truthy: Node, + falsy: Node, + }, + AnonymousFunction { + input: Vec, + function: OpaqueColumnUdf, + output_type: GetOutput, + options: FunctionOptions, + }, + Function { + /// Function arguments + /// Some functions rely on aliases, + /// for instance assignment of struct fields. + /// Therefor we need [`ExprIr`]. + input: Vec, + /// function to apply + function: FunctionExpr, + options: FunctionOptions, + }, + Window { + function: Node, + partition_by: Vec, + order_by: Option<(Node, SortOptions)>, + options: WindowType, + }, + Slice { + input: Node, + offset: Node, + length: Node, + }, + #[default] + Len, +} + +impl AExpr { + #[cfg(feature = "cse")] + pub(crate) fn col(name: PlSmallStr) -> Self { + AExpr::Column(name) + } + + /// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out. + pub fn get_type( + &self, + schema: &Schema, + ctxt: Context, + arena: &Arena, + ) -> PolarsResult { + self.to_field(schema, ctxt, arena) + .map(|f| f.dtype().clone()) + } +} diff --git a/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs b/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs new file mode 100644 index 000000000000..9c0a9cf684f9 --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/predicates/column_expr.rs @@ -0,0 +1,119 @@ +//! This module creates predicates splits predicates into partial per-column predicates. + +use polars_core::datatypes::DataType; +use polars_core::scalar::Scalar; +use polars_core::schema::Schema; +use polars_io::predicates::SpecializedColumnPredicateExpr; +use polars_utils::aliases::PlHashMap; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; + +use super::get_binary_expr_col_and_lv; +use crate::dsl::Operator; +use crate::plans::{AExpr, MintermIter, aexpr_to_leaf_names_iter}; + +pub struct ColumnPredicates { + pub predicates: PlHashMap)>, + + /// Are all column predicates AND-ed together the original predicate. + pub is_sumwise_complete: bool, +} + +pub fn aexpr_to_column_predicates( + root: Node, + expr_arena: &mut Arena, + schema: &Schema, +) -> ColumnPredicates { + let mut predicates = + PlHashMap::)>::default(); + let mut is_sumwise_complete = true; + + let minterms = MintermIter::new(root, expr_arena).collect::>(); + + let mut leaf_names = Vec::with_capacity(2); + for minterm in minterms { + leaf_names.clear(); + leaf_names.extend(aexpr_to_leaf_names_iter(minterm, expr_arena)); + + if leaf_names.len() != 1 { + is_sumwise_complete = false; + continue; + } + + let column = leaf_names.pop().unwrap(); + let Some(dtype) = schema.get(&column) else { + is_sumwise_complete = false; + continue; + }; + + // We really don't want to deal with these types. + use DataType as D; + match dtype { + #[cfg(feature = "dtype-categorical")] + D::Enum(_, _) | D::Categorical(_, _) => { + is_sumwise_complete = false; + continue; + }, + #[cfg(feature = "dtype-decimal")] + D::Decimal(_, _) => { + is_sumwise_complete = false; + continue; + }, + _ if dtype.is_nested() => { + is_sumwise_complete = false; + continue; + }, + _ => {}, + } + + let dtype = dtype.clone(); + let entry = predicates.entry(column); + + entry + .and_modify(|n| { + let left = n.0; + n.0 = expr_arena.add(AExpr::BinaryExpr { + left, + op: Operator::LogicalAnd, + right: minterm, + }); + n.1 = None; + }) + .or_insert_with(|| { + ( + minterm, + Some(()).and_then(|_| { + if std::env::var("POLARS_SPECIALIZED_COLUMN_PRED").as_deref() != Ok("1") { + return None; + } + + let aexpr = expr_arena.get(minterm); + + let AExpr::BinaryExpr { left, op, right } = aexpr else { + return None; + }; + let ((_, _), (lv, _)) = + get_binary_expr_col_and_lv(*left, *right, expr_arena, schema)?; + let lv = lv?; + let av = lv.to_any_value()?; + if av.dtype() != dtype { + return None; + } + let scalar = Scalar::new(dtype, av.into_static()); + use Operator as O; + match op { + O::Eq | O::EqValidity => { + Some(SpecializedColumnPredicateExpr::Eq(scalar)) + }, + _ => None, + } + }), + ) + }); + } + + ColumnPredicates { + predicates, + is_sumwise_complete, + } +} diff --git a/crates/polars-plan/src/plans/aexpr/predicates/mod.rs b/crates/polars-plan/src/plans/aexpr/predicates/mod.rs new file mode 100644 index 000000000000..8de7c240729e --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/predicates/mod.rs @@ -0,0 +1,35 @@ +mod column_expr; +mod skip_batches; + +use std::borrow::Cow; + +pub use column_expr::*; +use polars_core::schema::Schema; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; +pub use skip_batches::*; + +use super::evaluate::{constant_evaluate, into_column}; +use super::{AExpr, LiteralValue}; + +#[allow(clippy::type_complexity)] +fn get_binary_expr_col_and_lv<'a>( + left: Node, + right: Node, + expr_arena: &'a Arena, + schema: &Schema, +) -> Option<( + (&'a PlSmallStr, Node), + (Option>, Node), +)> { + match ( + into_column(left, expr_arena, schema, 0), + into_column(right, expr_arena, schema, 0), + constant_evaluate(left, expr_arena, schema, 0), + constant_evaluate(right, expr_arena, schema, 0), + ) { + (Some(col), _, _, Some(lv)) => Some(((col, left), (lv, right))), + (_, Some(col), Some(lv), _) => Some(((col, right), (lv, left))), + _ => None, + } +} diff --git a/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs b/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs new file mode 100644 index 000000000000..64c77b65951d --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs @@ -0,0 +1,593 @@ +//! This module creates predicates that can skip record batches of rows based on statistics about +//! that record batch. + +use polars_core::prelude::{AnyValue, DataType, Scalar}; +use polars_core::schema::Schema; +use polars_utils::aliases::PlIndexMap; +use polars_utils::arena::{Arena, Node}; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; + +use super::super::evaluate::{constant_evaluate, into_column}; +use super::super::{AExpr, BooleanFunction, Operator, OutputName}; +use crate::dsl::FunctionExpr; +use crate::plans::predicates::get_binary_expr_col_and_lv; +use crate::plans::{ExprIR, LiteralValue, aexpr_to_leaf_names_iter, is_scalar_ae, rename_columns}; +use crate::prelude::FunctionOptions; + +/// Return a new boolean expression determines whether a batch can be skipped based on min, max and +/// null count statistics. +/// +/// This is conversative and may return `None` or `false` when an expression is not yet supported. +/// +/// To evaluate, the expression it is given all the original column appended with `_min` and +/// `_max`. The `min` or `max` cannot be null and when they are null it is assumed they are not +/// known. +pub fn aexpr_to_skip_batch_predicate( + e: Node, + expr_arena: &mut Arena, + schema: &Schema, +) -> Option { + aexpr_to_skip_batch_predicate_rec(e, expr_arena, schema, 0) +} + +fn does_dtype_have_sufficient_order(dtype: &DataType) -> bool { + // Rules surrounding floats are really complicated. I should get around to that. + !dtype.is_nested() && !dtype.is_float() && !dtype.is_null() && !dtype.is_categorical() +} + +#[recursive::recursive] +fn aexpr_to_skip_batch_predicate_rec( + e: Node, + expr_arena: &mut Arena, + schema: &Schema, + depth: usize, +) -> Option { + use Operator as O; + + macro_rules! rec { + ($node:expr) => {{ aexpr_to_skip_batch_predicate_rec($node, expr_arena, schema, depth + 1) }}; + } + macro_rules! and { + ($l:expr, $($r:expr),+ $(,)?) => {{ + let node = $l; + $( + let node = expr_arena.add(AExpr::BinaryExpr { + left: node, + op: O::LogicalAnd, + right: $r, + }); + )+ + node + }} + } + macro_rules! or { + ($l:expr, $($r:expr),+ $(,)?) => {{ + let node = $l; + $( + let node = expr_arena.add(AExpr::BinaryExpr { + left: node, + op: O::LogicalOr, + right: $r, + }); + )+ + node + }} + } + macro_rules! binexpr { + (op: $op:expr, $l:expr, $r:expr) => {{ + expr_arena.add(AExpr::BinaryExpr { + left: $l, + op: $op, + right: $r, + }) + }}; + ($op:ident, $l:expr, $r:expr) => {{ binexpr!(op: O::$op, $l, $r) }}; + } + macro_rules! lt { + ($l:expr, $r:expr) => {{ binexpr!(Lt, $l, $r) }}; + } + macro_rules! gt { + ($l:expr, $r:expr) => {{ binexpr!(Gt, $l, $r) }}; + } + macro_rules! eq { + ($l:expr, $r:expr) => {{ binexpr!(Eq, $l, $r) }}; + } + macro_rules! null_count { + ($i:expr) => {{ + expr_arena.add(AExpr::Function { + input: vec![ExprIR::new($i, OutputName::Alias(PlSmallStr::EMPTY))], + function: FunctionExpr::NullCount, + options: FunctionOptions::default(), + }) + }}; + } + macro_rules! has_no_nulls { + ($i:expr) => {{ + let expr = null_count!($i); + let idx_zero = lv!(0); + eq!(expr, idx_zero) + }}; + } + macro_rules! has_nulls { + ($i:expr) => {{ + let expr = null_count!($i); + let idx_zero = lv!(0); + gt!(expr, idx_zero) + }}; + } + macro_rules! is_stat_defined { + ($i:expr, $dtype:expr) => {{ + let mut expr = expr_arena.add(AExpr::Function { + input: vec![ExprIR::new($i, OutputName::Alias(PlSmallStr::EMPTY))], + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: FunctionOptions::default(), + }); + + if $dtype.is_float() { + let is_not_nan = expr_arena.add(AExpr::Function { + input: vec![ExprIR::new($i, OutputName::Alias(PlSmallStr::EMPTY))], + function: FunctionExpr::Boolean(BooleanFunction::IsNotNan), + options: FunctionOptions::default(), + }); + expr = and!(is_not_nan, expr); + } + + expr + }}; + } + macro_rules! lv_cases { + ( + $lv:expr, $lv_node:expr, + null: $null_case:expr, + not_null: $non_null_case:expr $(,)? + ) => {{ + if let Some(lv) = $lv { + if lv.is_null() { + $null_case + } else { + $non_null_case + } + } else { + let lv_is_null = has_nulls!($lv_node); + let lv_not_null = has_no_nulls!($lv_node); + + let null_case = $null_case; + let null_case = and!(lv_is_null, null_case); + let non_null_case = $non_null_case; + let non_null_case = and!(lv_not_null, non_null_case); + + or!(null_case, non_null_case) + } + }}; + } + macro_rules! col { + (len) => {{ col!(PlSmallStr::from_static("len")) }}; + ($name:expr) => {{ expr_arena.add(AExpr::Column($name)) }}; + (min: $name:expr) => {{ col!(format_pl_smallstr!("{}_min", $name)) }}; + (max: $name:expr) => {{ col!(format_pl_smallstr!("{}_max", $name)) }}; + (null_count: $name:expr) => {{ col!(format_pl_smallstr!("{}_nc", $name)) }}; + } + macro_rules! lv { + ($lv:expr) => {{ expr_arena.add(AExpr::Literal(Scalar::from($lv).into())) }}; + (idx: $lv:expr) => {{ expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize($lv))) }}; + } + + let specialized = (|| { + if let Some(Some(lv)) = constant_evaluate(e, expr_arena, schema, 0) { + if let Some(av) = lv.to_any_value() { + return match av { + AnyValue::Null => Some(lv!(true)), + AnyValue::Boolean(b) => Some(lv!(!b)), + _ => None, + }; + } + } + + match expr_arena.get(e) { + AExpr::Explode(_) => None, + AExpr::Alias(_, _) => None, + AExpr::Column(_) => None, + AExpr::Literal(_) => None, + AExpr::BinaryExpr { left, op, right } => { + let left = *left; + let right = *right; + + match op { + O::Eq | O::EqValidity => { + let ((col, _), (lv, lv_node)) = + get_binary_expr_col_and_lv(left, right, expr_arena, schema)?; + let dtype = schema.get(col)?; + + if !does_dtype_have_sufficient_order(dtype) { + return None; + } + + let op = *op; + let col = col.clone(); + + // col(A) == B -> { + // null_count(A) == 0 , if B.is_null(), + // null_count(A) == LEN || min(A) > B || max(A) < B, if B.is_not_null(), + // } + + Some(lv_cases!( + lv, lv_node, + null: { + if matches!(op, O::Eq) { + lv!(false) + } else { + let col_nc = col!(null_count: col); + let idx_zero = lv!(idx: 0); + eq!(col_nc, idx_zero) + } + }, + not_null: { + let col_min = col!(min: col); + let col_max = col!(max: col); + + let min_is_defined = is_stat_defined!(col_min, dtype); + let max_is_defined = is_stat_defined!(col_max, dtype); + + let min_gt = gt!(col_min, lv_node); + let min_gt = and!(min_is_defined, min_gt); + + let max_lt = lt!(col_max, lv_node); + let max_lt = and!(max_is_defined, max_lt); + + let col_nc = col!(null_count: col); + let len = col!(len); + let all_nulls = eq!(col_nc, len); + + or!(all_nulls, min_gt, max_lt) + } + )) + }, + O::NotEq | O::NotEqValidity => { + let ((col, _), (lv, lv_node)) = + get_binary_expr_col_and_lv(left, right, expr_arena, schema)?; + let dtype = schema.get(col)?; + + if !does_dtype_have_sufficient_order(dtype) { + return None; + } + + let op = *op; + let col = col.clone(); + + // col(A) != B -> { + // null_count(A) == LEN , if B.is_null(), + // null_count(A) == 0 && min(A) == B && max(A) == B, if B.is_not_null(), + // } + + Some(lv_cases!( + lv, lv_node, + null: { + if matches!(op, O::NotEq) { + lv!(false) + } else { + let col_nc = col!(null_count: col); + let len = col!(len); + eq!(col_nc, len) + } + }, + not_null: { + let col_min = col!(min: col); + let col_max = col!(max: col); + let min_eq = eq!(col_min, lv_node); + let max_eq = eq!(col_max, lv_node); + + let col_nc = col!(null_count: col); + let idx_zero = lv!(idx: 0); + let no_nulls = eq!(col_nc, idx_zero); + + and!(no_nulls, min_eq, max_eq) + } + )) + }, + O::Lt | O::Gt | O::LtEq | O::GtEq => { + let ((col, col_node), (lv, lv_node)) = + get_binary_expr_col_and_lv(left, right, expr_arena, schema)?; + let dtype = schema.get(col)?; + + if !does_dtype_have_sufficient_order(dtype) { + return None; + } + + let col_is_left = col_node == left; + + let op = *op; + let col = col.clone(); + let lv_may_be_null = lv.is_none_or(|lv| lv.is_null()); + + // If B is null, this is always true. + // + // col(A) < B ~ B > col(A) -> + // null_count(A) == LEN || min(A) >= B + // + // col(A) > B ~ B < col(A) -> + // null_count(A) == LEN || max(A) <= B + // + // col(A) <= B ~ B >= col(A) -> + // null_count(A) == LEN || min(A) > B + // + // col(A) >= B ~ B <= col(A) -> + // null_count(A) == LEN || max(A) < B + + let stat = match (op, col_is_left) { + (O::Lt | O::LtEq, true) | (O::Gt | O::GtEq, false) => col!(min: col), + (O::Lt | O::LtEq, false) | (O::Gt | O::GtEq, true) => col!(max: col), + _ => unreachable!(), + }; + let cmp_op = match (op, col_is_left) { + (O::Lt, true) | (O::Gt, false) => O::GtEq, + (O::Lt, false) | (O::Gt, true) => O::LtEq, + + (O::LtEq, true) | (O::GtEq, false) => O::Gt, + (O::LtEq, false) | (O::GtEq, true) => O::Lt, + + _ => unreachable!(), + }; + + let stat_is_defined = is_stat_defined!(stat, dtype); + let cmp_op = binexpr!(op: cmp_op, stat, lv_node); + let mut expr = and!(stat_is_defined, cmp_op); + + if lv_may_be_null { + let has_nulls = has_nulls!(lv_node); + expr = or!(has_nulls, expr); + } + Some(expr) + }, + + O::And | O::LogicalAnd => match (rec!(left), rec!(right)) { + (Some(left), Some(right)) => Some(or!(left, right)), + (Some(n), None) | (None, Some(n)) => Some(n), + (None, None) => None, + }, + O::Or | O::LogicalOr => { + let left = rec!(left)?; + let right = rec!(right)?; + Some(and!(left, right)) + }, + + O::Plus + | O::Minus + | O::Multiply + | O::Divide + | O::TrueDivide + | O::FloorDivide + | O::Modulus + | O::Xor => None, + } + }, + AExpr::Cast { .. } => None, + AExpr::Sort { .. } => None, + AExpr::Gather { .. } => None, + AExpr::SortBy { .. } => None, + AExpr::Filter { .. } => None, + AExpr::Agg(..) => None, + AExpr::Ternary { .. } => None, + AExpr::AnonymousFunction { .. } => None, + AExpr::Function { + input, function, .. + } => match function { + FunctionExpr::Boolean(f) => match f { + #[cfg(feature = "is_in")] + BooleanFunction::IsIn { nulls_equal } => { + if !is_scalar_ae(input[1].node(), expr_arena) { + return None; + } + + let nulls_equal = *nulls_equal; + let lv_node = input[1].node(); + match ( + into_column(input[0].node(), expr_arena, schema, 0), + constant_evaluate(lv_node, expr_arena, schema, 0), + ) { + (Some(col), Some(_)) => { + let dtype = schema.get(col)?; + if !does_dtype_have_sufficient_order(dtype) { + return None; + } + + // col(A).is_in([B1, ..., Bn]) -> + // ([B1, ..., Bn].has_no_nulls() || null_count(A) == 0) && + // ( + // min(A) > max[B1, ..., Bn] || + // max(A) < min[B1, ..., Bn] + // ) + let col = col.clone(); + + let lv_node_exploded = expr_arena.add(AExpr::Explode(lv_node)); + let lv_min = + expr_arena.add(AExpr::Agg(crate::plans::IRAggExpr::Min { + input: lv_node_exploded, + propagate_nans: true, + })); + let lv_max = + expr_arena.add(AExpr::Agg(crate::plans::IRAggExpr::Max { + input: lv_node_exploded, + propagate_nans: true, + })); + + let col_min = col!(min: col); + let col_max = col!(max: col); + + let min_is_defined = is_stat_defined!(col_min, dtype); + let max_is_defined = is_stat_defined!(col_max, dtype); + + let min_gt = gt!(col_min, lv_max); + let min_gt = and!(min_is_defined, min_gt); + + let max_lt = lt!(col_max, lv_min); + let max_lt = and!(max_is_defined, max_lt); + + let expr = or!(min_gt, max_lt); + + let col_nc = col!(null_count: col); + let idx_zero = lv!(idx: 0); + let col_has_no_nulls = eq!(col_nc, idx_zero); + + let lv_has_not_nulls = has_no_nulls!(lv_node_exploded); + let null_case = or!(lv_has_not_nulls, col_has_no_nulls); + + let min_max_is_in = and!(null_case, expr); + + let col_nc = col!(null_count: col); + + let min_is_max = binexpr!(Eq, col_min, col_max); // Eq so that (None == None) == None + let idx_zero = lv!(idx: 0); + let has_no_nulls = eq!(col_nc, idx_zero); + + // The above case does always cover the fallback path. Since there + // is code that relies on the `min==max` always filtering normally, + // we add it here. + let exact_not_in = expr_arena.add(AExpr::Function { + input: vec![ + ExprIR::new(col_min, OutputName::Alias(PlSmallStr::EMPTY)), + ExprIR::new(lv_node, OutputName::Alias(PlSmallStr::EMPTY)), + ], + function: FunctionExpr::Boolean(BooleanFunction::IsIn { + nulls_equal, + }), + options: BooleanFunction::IsIn { nulls_equal } + .function_options(), + }); + let exact_not_in = expr_arena.add(AExpr::Function { + input: vec![ExprIR::new( + exact_not_in, + OutputName::Alias(PlSmallStr::EMPTY), + )], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: BooleanFunction::Not.function_options(), + }); + let exact_not_in = and!(min_is_max, has_no_nulls, exact_not_in); + + Some(or!(exact_not_in, min_max_is_in)) + }, + _ => None, + } + }, + BooleanFunction::IsNull => { + let col = into_column(input[0].node(), expr_arena, schema, 0)?; + + // col(A).is_null() -> null_count(A) == 0 + let col_nc = col!(null_count: col); + let idx_zero = lv!(idx: 0); + Some(eq!(col_nc, idx_zero)) + }, + BooleanFunction::IsNotNull => { + let col = into_column(input[0].node(), expr_arena, schema, 0)?; + + // col(A).is_not_null() -> null_count(A) == LEN + let col_nc = col!(null_count: col); + let len = col!(len); + Some(eq!(col_nc, len)) + }, + #[cfg(feature = "is_between")] + BooleanFunction::IsBetween { closed } => { + let col = into_column(input[0].node(), expr_arena, schema, 0)?; + let dtype = schema.get(col)?; + + if !does_dtype_have_sufficient_order(dtype) { + return None; + } + + // col(A).is_between(X, Y) -> + // null_count(A) == LEN || + // min(A) >(=) Y || + // max(A) <(=) X + + let left_node = input[1].node(); + let right_node = input[2].node(); + + _ = constant_evaluate(left_node, expr_arena, schema, 0)?; + _ = constant_evaluate(right_node, expr_arena, schema, 0)?; + + let col = col.clone(); + let closed = *closed; + + let lhs_no_nulls = has_no_nulls!(left_node); + let rhs_no_nulls = has_no_nulls!(right_node); + + let col_min = col!(min: col); + let col_max = col!(max: col); + + use polars_ops::series::ClosedInterval; + let (left, right) = match closed { + ClosedInterval::Both => (O::Lt, O::Gt), + ClosedInterval::Left => (O::Lt, O::GtEq), + ClosedInterval::Right => (O::LtEq, O::Gt), + ClosedInterval::None => (O::LtEq, O::GtEq), + }; + + let left = binexpr!(op: left, col_max, left_node); + let right = binexpr!(op: right, col_min, right_node); + + let min_is_defined = is_stat_defined!(col_min, dtype); + let max_is_defined = is_stat_defined!(col_max, dtype); + + let left = and!(max_is_defined, left); + let right = and!(min_is_defined, right); + + let interval = or!(left, right); + Some(and!(lhs_no_nulls, rhs_no_nulls, interval)) + }, + _ => None, + }, + _ => None, + }, + AExpr::Window { .. } => None, + AExpr::Slice { .. } => None, + AExpr::Len => None, + } + })(); + + if let Some(specialized) = specialized { + return Some(specialized); + } + + // If we don't have a specialized implementation we can check if the whole block is constant + // and fill that value in. This is especially useful when filtering hive partitions which are + // filtered using this expression and which set their min == max. + // + // Essentially, what this does is + // E -> all(col(A_min) == col(A_max) & col(A_nc) == 0 for A in LIVE(E)) & ~(E) + + let live_columns = PlIndexMap::from_iter(aexpr_to_leaf_names_iter(e, expr_arena).map(|col| { + let min_name = format_pl_smallstr!("{col}_min"); + (col, min_name) + })); + + // We cannot do proper equalities for these. + if live_columns + .iter() + .any(|(c, _)| schema.get(c).is_none_or(|dt| dt.is_categorical())) + { + return None; + } + + // Rename all uses of column names with the min value. + let expr = rename_columns(e, expr_arena, &live_columns); + let mut expr = expr_arena.add(AExpr::Function { + input: vec![ExprIR::new(expr, OutputName::Alias(PlSmallStr::EMPTY))], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: FunctionOptions { + collect_groups: crate::plans::ApplyOptions::ElementWise, + ..Default::default() + }, + }); + for col in live_columns.keys() { + let col_min = col!(min: col); + let col_max = col!(max: col); + let col_nc = col!(null_count: col); + + let min_is_max = binexpr!(Eq, col_min, col_max); // Eq so that (None == None) == None + let idx_zero = lv!(idx: 0); + let has_no_nulls = eq!(col_nc, idx_zero); + + expr = and!(min_is_max, has_no_nulls, expr); + } + Some(expr) +} diff --git a/crates/polars-plan/src/plans/aexpr/properties.rs b/crates/polars-plan/src/plans/aexpr/properties.rs new file mode 100644 index 000000000000..02c31de59c53 --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/properties.rs @@ -0,0 +1,289 @@ +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; + +use super::*; + +impl AExpr { + pub(crate) fn is_leaf(&self) -> bool { + matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len) + } + + pub(crate) fn is_col(&self) -> bool { + matches!(self, AExpr::Column(_)) + } + + /// Checks whether this expression is elementwise. This only checks the top level expression. + pub(crate) fn is_elementwise_top_level(&self) -> bool { + use AExpr::*; + + match self { + AnonymousFunction { options, .. } => options.is_elementwise(), + + Function { options, .. } => options.is_elementwise(), + + Literal(v) => v.is_scalar(), + + Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true, + + Agg { .. } + | Explode(_) + | Filter { .. } + | Gather { .. } + | Len + | Slice { .. } + | Sort { .. } + | SortBy { .. } + | Window { .. } => false, + } + } +} + +/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will +/// be extended further with any nested expression nodes. +pub fn is_elementwise(stack: &mut UnitVec, ae: &AExpr, expr_arena: &Arena) -> bool { + use AExpr::*; + + if !ae.is_elementwise_top_level() { + return false; + } + + match ae { + // Literals that aren't being projected are allowed to be non-scalar, so we don't add them + // for inspection. (e.g. `is_in()`). + #[cfg(feature = "is_in")] + Function { + function: FunctionExpr::Boolean(BooleanFunction::IsIn { .. }), + input, + .. + } => (|| { + if let Some(rhs) = input.get(1) { + assert_eq!(input.len(), 2); // A.is_in(B) + let rhs = rhs.node(); + + if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) { + stack.extend([input[0].node()]); + return; + } + }; + + ae.inputs_rev(stack); + })(), + _ => ae.inputs_rev(stack), + } + + true +} + +pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena) -> bool +where + Node: From<&'a N>, +{ + nodes + .iter() + .all(|n| is_elementwise_rec(n.into(), expr_arena)) +} + +/// Recursive variant of `is_elementwise` +pub fn is_elementwise_rec(node: Node, expr_arena: &Arena) -> bool { + let mut stack = unitvec![]; + let mut ae = expr_arena.get(node); + + loop { + if !is_elementwise(&mut stack, ae, expr_arena) { + return false; + } + + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); + } + + true +} + +/// Recursive variant of `is_elementwise` that also forbids casting to categoricals. This function +/// is used to determine if an expression evaluation can be vertically parallelized. +pub fn is_elementwise_rec_no_cat_cast<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena) -> bool { + let mut stack = unitvec![]; + + loop { + if !is_elementwise(&mut stack, ae, expr_arena) { + return false; + } + + #[cfg(feature = "dtype-categorical")] + { + if let AExpr::Cast { + dtype: DataType::Categorical(..), + .. + } = ae + { + return false; + } + } + + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); + } + + true +} + +/// Check whether filters can be pushed past this expression. +/// +/// A query, `with_columns(C).filter(P)` can be re-ordered as `filter(P).with_columns(C)`, iff +/// both P and C permit filter pushdown. +/// +/// If filter pushdown is permitted, `stack` is extended with any input expression nodes that this +/// expression may have. +/// +/// Note that this function is not recursive - the caller should repeatedly +/// call this function with the `stack` to perform a recursive check. +pub(crate) fn permits_filter_pushdown( + stack: &mut UnitVec, + ae: &AExpr, + expr_arena: &Arena, +) -> bool { + // This is a subset of an `is_elementwise` check that also blocks exprs that raise errors + // depending on the data. The idea is that, although the success value of these functions + // are elementwise, their error behavior is non-elementwise. Their error behavior is essentially + // performing an aggregation `ANY(evaluation_result_was_error)`, and if this is the case then + // the query result should be an error. + match ae { + // Rows that go OOB on get/gather may be filtered out in earlier operations, + // so we don't push these down. + AExpr::Function { + function: FunctionExpr::ListExpr(ListFunction::Get(false)), + .. + } => false, + #[cfg(feature = "list_gather")] + AExpr::Function { + function: FunctionExpr::ListExpr(ListFunction::Gather(false)), + .. + } => false, + #[cfg(feature = "dtype-array")] + AExpr::Function { + function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)), + .. + } => false, + // TODO: There are a lot more functions that should be caught here. + ae => is_elementwise(stack, ae, expr_arena), + } +} + +pub fn permits_filter_pushdown_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena) -> bool { + let mut stack = unitvec![]; + + loop { + if !permits_filter_pushdown(&mut stack, ae, expr_arena) { + return false; + } + + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); + } + + true +} + +pub fn can_pre_agg_exprs( + exprs: &[ExprIR], + expr_arena: &Arena, + _input_schema: &Schema, +) -> bool { + exprs + .iter() + .all(|e| can_pre_agg(e.node(), expr_arena, _input_schema)) +} + +/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be +/// implemented physically, so this isn't a complete list. +pub fn can_pre_agg(agg: Node, expr_arena: &Arena, _input_schema: &Schema) -> bool { + let aexpr = expr_arena.get(agg); + + match aexpr { + AExpr::Len => true, + AExpr::Column(_) | AExpr::Literal(_) => false, + // We only allow expressions that end with an aggregation. + AExpr::Agg(_) => { + let has_aggregation = + |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_))); + + // check if the aggregation type is partitionable + // only simple aggregation like col().sum + // that can be divided in to the aggregation of their partitions are allowed + let can_partition = (expr_arena).iter(agg).all(|(_, ae)| { + use AExpr::*; + match ae { + // struct is needed to keep both states + #[cfg(feature = "dtype-struct")] + Agg(IRAggExpr::Mean(_)) => { + // only numeric means for now. + // logical types seem to break because of casts to float. + matches!( + expr_arena + .get(agg) + .get_type(_input_schema, Context::Default, expr_arena) + .map(|dt| { dt.is_primitive_numeric() }), + Ok(true) + ) + }, + // only allowed expressions + Agg(agg_e) => { + matches!( + agg_e, + IRAggExpr::Min { .. } + | IRAggExpr::Max { .. } + | IRAggExpr::Sum(_) + | IRAggExpr::Last(_) + | IRAggExpr::First(_) + | IRAggExpr::Count(_, true) + ) + }, + Function { input, options, .. } => { + matches!(options.collect_groups, ApplyOptions::ElementWise) + && input.len() == 1 + && !has_aggregation(input[0].node()) + }, + BinaryExpr { left, right, .. } => { + !has_aggregation(*left) && !has_aggregation(*right) + }, + Ternary { + truthy, + falsy, + predicate, + .. + } => { + !has_aggregation(*truthy) + && !has_aggregation(*falsy) + && !has_aggregation(*predicate) + }, + Literal(lv) => lv.is_scalar(), + Column(_) | Len | Cast { .. } => true, + _ => false, + } + }); + + #[cfg(feature = "object")] + { + for name in aexpr_to_leaf_names(agg, expr_arena) { + let dtype = _input_schema.get(&name).unwrap(); + + if let DataType::Object(_) = dtype { + return false; + } + } + } + can_partition + }, + _ => false, + } +} diff --git a/crates/polars-plan/src/plans/aexpr/scalar.rs b/crates/polars-plan/src/plans/aexpr/scalar.rs new file mode 100644 index 000000000000..4ad0199b66d4 --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/scalar.rs @@ -0,0 +1,31 @@ +use recursive::recursive; + +use super::*; + +#[recursive] +pub fn is_scalar_ae(node: Node, expr_arena: &Arena) -> bool { + match expr_arena.get(node) { + AExpr::Literal(lv) => lv.is_scalar(), + AExpr::Function { options, input, .. } + | AExpr::AnonymousFunction { options, input, .. } => { + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { + true + } else if options.is_elementwise() + || !options.flags.contains(FunctionFlags::CHANGES_LENGTH) + { + input.iter().all(|e| e.is_scalar(expr_arena)) + } else { + false + } + }, + AExpr::BinaryExpr { left, right, .. } => { + is_scalar_ae(*left, expr_arena) && is_scalar_ae(*right, expr_arena) + }, + AExpr::Ternary { truthy, falsy, .. } => { + is_scalar_ae(*truthy, expr_arena) && is_scalar_ae(*falsy, expr_arena) + }, + AExpr::Agg(_) | AExpr::Len => true, + AExpr::Cast { expr, .. } | AExpr::Alias(expr, _) => is_scalar_ae(*expr, expr_arena), + _ => false, + } +} diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs new file mode 100644 index 000000000000..1b478816b08f --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -0,0 +1,751 @@ +#[cfg(feature = "dtype-decimal")] +use polars_core::chunked_array::arithmetic::{ + _get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul, +}; +use recursive::recursive; + +use super::*; + +fn float_type(field: &mut Field) { + let should_coerce = match &field.dtype { + DataType::Float32 => false, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(..) => true, + DataType::Boolean => true, + dt => dt.is_primitive_numeric(), + }; + if should_coerce { + field.coerce(DataType::Float64); + } +} + +fn validate_expr(node: Node, arena: &Arena, schema: &Schema) -> PolarsResult<()> { + let mut ctx = ToFieldContext { + schema, + ctx: Context::Default, + arena, + validate: true, + }; + arena + .get(node) + .to_field_impl(&mut ctx, &mut false) + .map(|_| ()) +} + +struct ToFieldContext<'a> { + schema: &'a Schema, + ctx: Context, + arena: &'a Arena, + // Traverse all expressions to validate they are in the schema. + validate: bool, +} + +impl AExpr { + pub fn to_dtype( + &self, + schema: &Schema, + ctx: Context, + arena: &Arena, + ) -> PolarsResult { + self.to_field(schema, ctx, arena).map(|f| f.dtype) + } + + /// Get Field result of the expression. The schema is the input data. + pub fn to_field( + &self, + schema: &Schema, + ctx: Context, + arena: &Arena, + ) -> PolarsResult { + // Indicates whether we should auto-implode the result. This is initialized to true if we are + // in an aggregation context, so functions that return scalars should explicitly set this + // to false in `to_field_impl`. + let mut agg_list = matches!(ctx, Context::Aggregation); + let mut ctx = ToFieldContext { + schema, + ctx, + arena, + validate: true, + }; + let mut field = self.to_field_impl(&mut ctx, &mut agg_list)?; + + if agg_list { + field.coerce(field.dtype().clone().implode()); + } + + Ok(field) + } + + /// Get Field result of the expression. The schema is the input data. + pub fn to_field_and_validate( + &self, + schema: &Schema, + ctx: Context, + arena: &Arena, + ) -> PolarsResult { + // Indicates whether we should auto-implode the result. This is initialized to true if we are + // in an aggregation context, so functions that return scalars should explicitly set this + // to false in `to_field_impl`. + let mut agg_list = matches!(ctx, Context::Aggregation); + + let mut ctx = ToFieldContext { + schema, + ctx, + arena, + validate: true, + }; + let mut field = self.to_field_impl(&mut ctx, &mut agg_list)?; + + if agg_list { + field.coerce(field.dtype().clone().implode()); + } + + Ok(field) + } + + /// Get Field result of the expression. The schema is the input data. + /// + /// This is taken as `&mut bool` as for some expressions this is determined by the upper node + /// (e.g. `alias`, `cast`). + #[recursive] + pub fn to_field_impl( + &self, + ctx: &mut ToFieldContext, + agg_list: &mut bool, + ) -> PolarsResult { + use AExpr::*; + use DataType::*; + match self { + Len => { + *agg_list = false; + Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) + }, + Window { + function, + options, + partition_by, + order_by, + } => { + if let WindowType::Over(WindowMapping::Join) = options { + // expr.over(..), defaults to agg-list unless explicitly unset + // by the `to_field_impl` of the `expr` + *agg_list = true; + } + + if ctx.validate { + for node in partition_by { + validate_expr(*node, ctx.arena, ctx.schema)?; + } + if let Some((node, _)) = order_by { + validate_expr(*node, ctx.arena, ctx.schema)?; + } + } + + let e = ctx.arena.get(*function); + e.to_field_impl(ctx, agg_list) + }, + Explode(expr) => { + // `Explode` is a "flatten" operation, which is not the same as returning a scalar. + // Namely, it should be auto-imploded in the aggregation context, so we don't update + // the `agg_list` state here. + let field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + + let field = match field.dtype() { + List(inner) => Field::new(field.name().clone(), *inner.clone()), + #[cfg(feature = "dtype-array")] + Array(inner, ..) => Field::new(field.name().clone(), *inner.clone()), + _ => field, + }; + + Ok(field) + }, + Alias(expr, name) => Ok(Field::new( + name.clone(), + ctx.arena.get(*expr).to_field_impl(ctx, agg_list)?.dtype, + )), + Column(name) => ctx + .schema + .get_field(name) + .ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())), + Literal(sv) => { + *agg_list = false; + Ok(match sv { + LiteralValue::Series(s) => s.field().into_owned(), + _ => Field::new(sv.output_name().clone(), sv.get_datatype()), + }) + }, + BinaryExpr { left, right, op } => { + use DataType::*; + + let field = match op { + Operator::Lt + | Operator::Gt + | Operator::Eq + | Operator::NotEq + | Operator::LogicalAnd + | Operator::LtEq + | Operator::GtEq + | Operator::NotEqValidity + | Operator::EqValidity + | Operator::LogicalOr => { + let out_field; + let out_name = { + out_field = ctx.arena.get(*left).to_field_impl(ctx, agg_list)?; + out_field.name() + }; + Field::new(out_name.clone(), Boolean) + }, + Operator::TrueDivide => get_truediv_field(*left, *right, ctx, agg_list)?, + _ => get_arithmetic_field(*left, *right, *op, ctx, agg_list)?, + }; + + Ok(field) + }, + Sort { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx, agg_list), + Gather { + expr, + idx, + returns_scalar, + .. + } => { + if *returns_scalar { + *agg_list = false; + } + if ctx.validate { + validate_expr(*idx, ctx.arena, ctx.schema)? + } + ctx.arena.get(*expr).to_field_impl(ctx, &mut false) + }, + SortBy { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx, agg_list), + Filter { input, by } => { + if ctx.validate { + validate_expr(*by, ctx.arena, ctx.schema)? + } + ctx.arena.get(*input).to_field_impl(ctx, agg_list) + }, + Agg(agg) => { + use IRAggExpr::*; + match agg { + Max { input: expr, .. } + | Min { input: expr, .. } + | First(expr) + | Last(expr) => { + *agg_list = false; + ctx.arena.get(*expr).to_field_impl(ctx, &mut false) + }, + Sum(expr) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + let dt = match field.dtype() { + Boolean => Some(IDX_DTYPE), + UInt8 | Int8 | Int16 | UInt16 => Some(Int64), + _ => None, + }; + if let Some(dt) = dt { + field.coerce(dt); + } + Ok(field) + }, + Median(expr) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + match field.dtype { + Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), + _ => float_type(&mut field), + } + Ok(field) + }, + Mean(expr) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + match field.dtype { + Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), + _ => float_type(&mut field), + } + Ok(field) + }, + Implode(expr) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + field.coerce(DataType::List(field.dtype().clone().into())); + Ok(field) + }, + Std(expr, _) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + float_type(&mut field); + Ok(field) + }, + Var(expr, _) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + float_type(&mut field); + Ok(field) + }, + NUnique(expr) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + field.coerce(IDX_DTYPE); + Ok(field) + }, + Count(expr, _) => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + field.coerce(IDX_DTYPE); + Ok(field) + }, + AggGroups(expr) => { + *agg_list = true; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + field.coerce(List(IDX_DTYPE.into())); + Ok(field) + }, + Quantile { expr, .. } => { + *agg_list = false; + let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?; + float_type(&mut field); + Ok(field) + }, + } + }, + Cast { expr, dtype, .. } => { + let field = ctx.arena.get(*expr).to_field_impl(ctx, agg_list)?; + Ok(Field::new(field.name().clone(), dtype.clone())) + }, + Ternary { truthy, falsy, .. } => { + let mut agg_list_truthy = *agg_list; + let mut agg_list_falsy = *agg_list; + + // During aggregation: + // left: col(foo): list nesting: 1 + // right; col(foo).first(): T nesting: 0 + // col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list. + let mut truthy = ctx + .arena + .get(*truthy) + .to_field_impl(ctx, &mut agg_list_truthy)?; + let falsy = ctx + .arena + .get(*falsy) + .to_field_impl(ctx, &mut agg_list_falsy)?; + + let st = if let DataType::Null = *truthy.dtype() { + falsy.dtype().clone() + } else { + try_get_supertype(truthy.dtype(), falsy.dtype())? + }; + + *agg_list = agg_list_truthy | agg_list_falsy; + + truthy.coerce(st); + Ok(truthy) + }, + AnonymousFunction { + output_type, + input, + options, + .. + } => { + let fields = func_args_to_fields(input, ctx, agg_list)?; + polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str); + let out = output_type.get_field(ctx.schema, ctx.ctx, &fields)?; + + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { + *agg_list = false; + } else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) { + *agg_list = true; + } + + Ok(out) + }, + Function { + function, + input, + options, + } => { + let fields = func_args_to_fields(input, ctx, agg_list)?; + polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); + let out = function.get_field(ctx.schema, ctx.ctx, &fields)?; + + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { + *agg_list = false; + } else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) { + *agg_list = true; + } + + Ok(out) + }, + Slice { + input, + offset, + length, + } => { + if ctx.validate { + validate_expr(*offset, ctx.arena, ctx.schema)?; + validate_expr(*length, ctx.arena, ctx.schema)?; + } + + ctx.arena.get(*input).to_field_impl(ctx, agg_list) + }, + } + } +} + +fn func_args_to_fields( + input: &[ExprIR], + ctx: &mut ToFieldContext, + agg_list: &mut bool, +) -> PolarsResult> { + input + .iter() + .enumerate() + // Default context because `col()` would return a list in aggregation context + .map(|(i, e)| { + let tmp = &mut false; + + ctx.arena + .get(e.node()) + .to_field_impl( + ctx, + if i == 0 { + // Only mutate first agg_list as that is the dtype of the function. + agg_list + } else { + tmp + }, + ) + .map(|mut field| { + field.name = e.output_name().clone(); + field + }) + }) + .collect() +} + +#[allow(clippy::too_many_arguments)] +fn get_arithmetic_field( + left: Node, + right: Node, + op: Operator, + ctx: &mut ToFieldContext, + agg_list: &mut bool, +) -> PolarsResult { + use DataType::*; + let left_ae = ctx.arena.get(left); + let right_ae = ctx.arena.get(right); + + // don't traverse tree until strictly needed. Can have terrible performance. + // # 3210 + + // take the left field as a whole. + // don't take dtype and name separate as that splits the tree every node + // leading to quadratic behavior. # 4736 + // + // further right_type is only determined when needed. + let mut left_field = left_ae.to_field_impl(ctx, agg_list)?; + + let super_type = match op { + Operator::Minus => { + let right_type = right_ae.to_field_impl(ctx, agg_list)?.dtype; + match (&left_field.dtype, &right_type) { + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_)) => { + return Ok(left_field); + }, + (Duration(_), Datetime(_, _)) + | (Datetime(_, _), Duration(_)) + | (Duration(_), Date) + | (Date, Duration(_)) + | (Duration(_), Time) + | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?, + (Datetime(tu, _), Date) | (Date, Datetime(tu, _)) => Duration(*tu), + // T - T != T if T is a datetime / date + (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)), + (_, Datetime(_, _)) | (Datetime(_, _), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Date, Date) => Duration(TimeUnit::Milliseconds), + (_, Date) | (Date, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), + (_, Duration(_)) | (Duration(_), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Time, Time) => Duration(TimeUnit::Nanoseconds), + (_, Time) | (Time, _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_supported_list_arithmetic_input()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "sub", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + // FIXME: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block. + // Otherwise we will silently permit addition operations between logical types (see above). + // This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors + // if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema + // may be incorrect. + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, + #[cfg(feature = "dtype-array")] + (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => { + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => { + let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right); + Decimal(None, Some(scale)) + }, + (left, right) => try_get_supertype(left, right)?, + } + }, + Operator::Plus => { + let right_type = right_ae.to_field_impl(ctx, agg_list)?.dtype; + match (&left_field.dtype, &right_type) { + (Duration(_), Datetime(_, _)) + | (Datetime(_, _), Duration(_)) + | (Duration(_), Date) + | (Date, Duration(_)) + | (Duration(_), Time) + | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?, + (_, Datetime(_, _)) + | (Datetime(_, _), _) + | (_, Date) + | (Date, _) + | (Time, _) + | (_, Time) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)), + (_, Duration(_)) | (Duration(_), _) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Boolean, Boolean) => IDX_DTYPE, + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_supported_list_arithmetic_input()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "add", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, + #[cfg(feature = "dtype-array")] + (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => { + list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?) + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => { + let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right); + Decimal(None, Some(scale)) + }, + (left, right) => try_get_supertype(left, right)?, + } + }, + _ => { + let right_type = right_ae.to_field_impl(ctx, agg_list)?.dtype; + + match (&left_field.dtype, &right_type) { + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_)) => { + return Ok(left_field); + }, + (Datetime(_, _), _) + | (_, Datetime(_, _)) + | (Time, _) + | (_, Time) + | (Date, _) + | (_, Date) => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (Duration(_), Duration(_)) => { + // True divide handled somewhere else + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + (l, Duration(_)) if l.is_primitive_numeric() => match op { + Operator::Multiply => { + left_field.coerce(right_type); + return Ok(left_field); + }, + _ => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + }, + (Duration(_), r) if r.is_primitive_numeric() => match op { + Operator::Multiply => { + return Ok(left_field); + }, + _ => { + polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) + }, + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => { + let scale = match op { + Operator::Multiply => _get_decimal_scale_mul(*scale_left, *scale_right), + Operator::Divide | Operator::TrueDivide => { + _get_decimal_scale_div(*scale_left) + }, + _ => { + debug_assert!(false); + *scale_left + }, + }; + let dtype = Decimal(None, Some(scale)); + left_field.coerce(dtype); + return Ok(left_field); + }, + + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_supported_list_arithmetic_input()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + op, l, r, + ) + }, + // List<->primitive operations can be done directly after casting the to the primitive + // supertype for the primitive values on both sides. + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + let dtype = list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?); + left_field.coerce(dtype); + return Ok(left_field); + }, + #[cfg(feature = "dtype-array")] + (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => { + let dtype = list_dtype.cast_leaf(try_get_supertype( + list_dtype.leaf_dtype(), + other_dtype.leaf_dtype(), + )?); + left_field.coerce(dtype); + return Ok(left_field); + }, + _ => { + // Avoid needlessly type casting numeric columns during arithmetic + // with literals. + if (left_field.dtype.is_integer() && right_type.is_integer()) + || (left_field.dtype.is_float() && right_type.is_float()) + { + match (left_ae, right_ae) { + (AExpr::Literal(_), AExpr::Literal(_)) => {}, + (AExpr::Literal(_), _) => { + // literal will be coerced to match right type + left_field.coerce(right_type); + return Ok(left_field); + }, + (_, AExpr::Literal(_)) => { + // literal will be coerced to match right type + return Ok(left_field); + }, + _ => {}, + } + } + }, + } + + try_get_supertype(&left_field.dtype, &right_type)? + }, + }; + + left_field.coerce(super_type); + Ok(left_field) +} + +fn get_truediv_field( + left: Node, + right: Node, + ctx: &mut ToFieldContext, + agg_list: &mut bool, +) -> PolarsResult { + let mut left_field = ctx.arena.get(left).to_field_impl(ctx, agg_list)?; + let right_field = ctx.arena.get(right).to_field_impl(ctx, agg_list)?; + let out_type = get_truediv_dtype(left_field.dtype(), right_field.dtype())?; + left_field.coerce(out_type); + Ok(left_field) +} + +fn get_truediv_dtype(left_dtype: &DataType, right_dtype: &DataType) -> PolarsResult { + use DataType::*; + + // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code + // originally (mostly) only looked at the LHS dtype. + let out_type = match (left_dtype, right_dtype) { + (l @ List(a), r @ List(b)) + if ![a, b] + .into_iter() + .all(|x| x.is_supported_list_arithmetic_input()) => + { + polars_bail!( + InvalidOperation: + "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})", + "div", l, r, + ) + }, + (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => { + let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?; + list_dtype.cast_leaf(dtype) + }, + #[cfg(feature = "dtype-array")] + (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => { + let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?; + list_dtype.cast_leaf(dtype) + }, + (Float32, _) => Float32, + #[cfg(feature = "dtype-decimal")] + (Decimal(_, Some(scale_left)), Decimal(_, _)) => { + let scale = _get_decimal_scale_div(*scale_left); + Decimal(None, Some(scale)) + }, + (dt, _) if dt.is_primitive_numeric() => Float64, + #[cfg(feature = "dtype-duration")] + (Duration(_), Duration(_)) => Float64, + #[cfg(feature = "dtype-duration")] + (Duration(_), dt) if dt.is_primitive_numeric() => left_dtype.clone(), + #[cfg(feature = "dtype-duration")] + (Duration(_), dt) => { + polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_dtype, dt) + }, + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), _) => { + polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed") + }, + #[cfg(feature = "dtype-time")] + (Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"), + #[cfg(feature = "dtype-date")] + (Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"), + // we don't know what to do here, best return the dtype + (dt, _) => dt.clone(), + }; + Ok(out_type) +} diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs new file mode 100644 index 000000000000..1b1b34c535f3 --- /dev/null +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -0,0 +1,214 @@ +use super::*; + +impl AExpr { + /// Push the inputs of this node to the given container, in reverse order. + /// This ensures the primary node responsible for the name is pushed last. + pub fn inputs_rev(&self, container: &mut E) + where + E: Extend, + { + use AExpr::*; + + match self { + Column(_) | Literal(_) | Len => {}, + Alias(e, _) => container.extend([*e]), + BinaryExpr { left, op: _, right } => { + container.extend([*right, *left]); + }, + Cast { expr, .. } => container.extend([*expr]), + Sort { expr, .. } => container.extend([*expr]), + Gather { expr, idx, .. } => { + container.extend([*idx, *expr]); + }, + SortBy { expr, by, .. } => { + container.extend(by.iter().cloned().rev()); + container.extend([*expr]); + }, + Filter { input, by } => { + container.extend([*by, *input]); + }, + Agg(agg_e) => match agg_e.get_input() { + NodeInputs::Single(node) => container.extend([node]), + NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()), + NodeInputs::Leaf => {}, + }, + Ternary { + truthy, + falsy, + predicate, + } => { + container.extend([*predicate, *falsy, *truthy]); + }, + AnonymousFunction { input, .. } | Function { input, .. } => { + container.extend(input.iter().rev().map(|e| e.node())) + }, + Explode(e) => container.extend([*e]), + Window { + function, + partition_by, + order_by, + options: _, + } => { + if let Some((n, _)) = order_by { + container.extend([*n]); + } + container.extend(partition_by.iter().rev().cloned()); + container.extend([*function]); + }, + Slice { + input, + offset, + length, + } => { + container.extend([*length, *offset, *input]); + }, + } + } + + pub fn replace_inputs(mut self, inputs: &[Node]) -> Self { + use AExpr::*; + let input = match &mut self { + Column(_) | Literal(_) | Len => return self, + Alias(input, _) => input, + Cast { expr, .. } => expr, + Explode(input) => input, + BinaryExpr { left, right, .. } => { + *left = inputs[0]; + *right = inputs[1]; + return self; + }, + Gather { expr, idx, .. } => { + *expr = inputs[0]; + *idx = inputs[1]; + return self; + }, + Sort { expr, .. } => expr, + SortBy { expr, by, .. } => { + *expr = inputs[0]; + by.clear(); + by.extend_from_slice(&inputs[1..]); + return self; + }, + Filter { input, by, .. } => { + *input = inputs[0]; + *by = inputs[1]; + return self; + }, + Agg(a) => { + match a { + IRAggExpr::Quantile { expr, quantile, .. } => { + *expr = inputs[0]; + *quantile = inputs[1]; + }, + _ => { + a.set_input(inputs[0]); + }, + } + return self; + }, + Ternary { + truthy, + falsy, + predicate, + } => { + *truthy = inputs[0]; + *falsy = inputs[1]; + *predicate = inputs[2]; + return self; + }, + AnonymousFunction { input, .. } | Function { input, .. } => { + assert_eq!(input.len(), inputs.len()); + for (e, node) in input.iter_mut().zip(inputs.iter()) { + e.set_node(*node); + } + return self; + }, + Slice { + input, + offset, + length, + } => { + *input = inputs[0]; + *offset = inputs[1]; + *length = inputs[2]; + return self; + }, + Window { + function, + partition_by, + order_by, + .. + } => { + let offset = order_by.is_some() as usize; + *function = inputs[0]; + partition_by.clear(); + partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]); + if let Some((_, options)) = order_by { + *order_by = Some((*inputs.last().unwrap(), *options)); + } + return self; + }, + }; + *input = inputs[0]; + self + } +} + +impl IRAggExpr { + pub fn get_input(&self) -> NodeInputs { + use IRAggExpr::*; + use NodeInputs::*; + match self { + Min { input, .. } => Single(*input), + Max { input, .. } => Single(*input), + Median(input) => Single(*input), + NUnique(input) => Single(*input), + First(input) => Single(*input), + Last(input) => Single(*input), + Mean(input) => Single(*input), + Implode(input) => Single(*input), + Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]), + Sum(input) => Single(*input), + Count(input, _) => Single(*input), + Std(input, _) => Single(*input), + Var(input, _) => Single(*input), + AggGroups(input) => Single(*input), + } + } + pub fn set_input(&mut self, input: Node) { + use IRAggExpr::*; + let node = match self { + Min { input, .. } => input, + Max { input, .. } => input, + Median(input) => input, + NUnique(input) => input, + First(input) => input, + Last(input) => input, + Mean(input) => input, + Implode(input) => input, + Quantile { expr, .. } => expr, + Sum(input) => input, + Count(input, _) => input, + Std(input, _) => input, + Var(input, _) => input, + AggGroups(input) => input, + }; + *node = input; + } +} + +pub enum NodeInputs { + Leaf, + Single(Node), + Many(Vec), +} + +impl NodeInputs { + pub fn first(&self) -> Node { + match self { + NodeInputs::Single(node) => *node, + NodeInputs::Many(nodes) => nodes[0], + NodeInputs::Leaf => panic!(), + } + } +} diff --git a/crates/polars-plan/src/plans/anonymous_scan.rs b/crates/polars-plan/src/plans/anonymous_scan.rs new file mode 100644 index 000000000000..3be098f8bf8e --- /dev/null +++ b/crates/polars-plan/src/plans/anonymous_scan.rs @@ -0,0 +1,56 @@ +use std::any::Any; +use std::fmt::{Debug, Formatter}; + +use polars_core::prelude::*; + +use crate::dsl::Expr; + +pub struct AnonymousScanArgs { + pub n_rows: Option, + pub with_columns: Option>, + pub schema: SchemaRef, + pub output_schema: Option, + pub predicate: Option, +} + +pub trait AnonymousScan: Send + Sync { + fn as_any(&self) -> &dyn Any; + /// Creates a DataFrame from the supplied function & scan options. + fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult; + + /// Produce the next batch Polars can consume. Implement this method to get proper + /// streaming support. + fn next_batch(&self, scan_opts: AnonymousScanArgs) -> PolarsResult> { + self.scan(scan_opts).map(Some) + } + + /// function to supply the schema. + /// Allows for an optional infer schema argument for data sources with dynamic schemas + fn schema(&self, _infer_schema_length: Option) -> PolarsResult { + polars_bail!(ComputeError: "must supply either a schema or a schema function"); + } + /// Specify if the scan provider should allow predicate pushdowns. + /// + /// Defaults to `false` + fn allows_predicate_pushdown(&self) -> bool { + false + } + /// Specify if the scan provider should allow projection pushdowns. + /// + /// Defaults to `false` + fn allows_projection_pushdown(&self) -> bool { + false + } + /// Specify if the scan provider should allow slice pushdowns. + /// + /// Defaults to `false` + fn allows_slice_pushdown(&self) -> bool { + false + } +} + +impl Debug for dyn AnonymousScan { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "anonymous_scan") + } +} diff --git a/crates/polars-plan/src/plans/apply.rs b/crates/polars-plan/src/plans/apply.rs new file mode 100644 index 000000000000..468331ef8332 --- /dev/null +++ b/crates/polars-plan/src/plans/apply.rs @@ -0,0 +1,59 @@ +use std::fmt::{Debug, Formatter}; + +use polars_core::prelude::*; + +pub trait DataFrameUdf: Send + Sync { + fn call_udf(&self, df: DataFrame) -> PolarsResult; +} + +impl DataFrameUdf for F +where + F: Fn(DataFrame) -> PolarsResult + Send + Sync, +{ + fn call_udf(&self, df: DataFrame) -> PolarsResult { + self(df) + } +} + +pub trait DataFrameUdfMut: Send + Sync { + fn call_udf(&mut self, df: DataFrame) -> PolarsResult; +} + +impl DataFrameUdfMut for F +where + F: FnMut(DataFrame) -> PolarsResult + Send + Sync, +{ + fn call_udf(&mut self, df: DataFrame) -> PolarsResult { + self(df) + } +} + +impl Debug for dyn DataFrameUdf { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "dyn DataFrameUdf") + } +} +impl Debug for dyn DataFrameUdfMut { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "dyn DataFrameUdfMut") + } +} + +pub trait UdfSchema: Send + Sync { + fn get_schema(&self, input_schema: &Schema) -> PolarsResult; +} + +impl UdfSchema for F +where + F: Fn(&Schema) -> PolarsResult + Send + Sync, +{ + fn get_schema(&self, input_schema: &Schema) -> PolarsResult { + self(input_schema) + } +} + +impl Debug for dyn UdfSchema { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "dyn UdfSchema") + } +} diff --git a/crates/polars-plan/src/plans/builder_ir.rs b/crates/polars-plan/src/plans/builder_ir.rs new file mode 100644 index 000000000000..bf84f989109e --- /dev/null +++ b/crates/polars-plan/src/plans/builder_ir.rs @@ -0,0 +1,330 @@ +use std::borrow::Cow; + +use super::*; + +pub struct IRBuilder<'a> { + root: Node, + expr_arena: &'a mut Arena, + lp_arena: &'a mut Arena, +} + +impl<'a> IRBuilder<'a> { + pub fn new(root: Node, expr_arena: &'a mut Arena, lp_arena: &'a mut Arena) -> Self { + IRBuilder { + root, + expr_arena, + lp_arena, + } + } + + pub fn from_lp(lp: IR, expr_arena: &'a mut Arena, lp_arena: &'a mut Arena) -> Self { + let root = lp_arena.add(lp); + IRBuilder { + root, + expr_arena, + lp_arena, + } + } + + pub fn add_alp(self, lp: IR) -> Self { + let node = self.lp_arena.add(lp); + IRBuilder::new(node, self.expr_arena, self.lp_arena) + } + + /// Adds IR and runs optimizations on its expressions (simplify, coerce, type-check). + pub fn add_alp_optimize_exprs(self, f: F) -> PolarsResult + where + F: FnOnce(Node) -> IR, + { + let lp = f(self.root); + let ir_name = lp.name(); + + let b = self.add_alp(lp); + + // Run the optimizer + let mut conversion_optimizer = ConversionOptimizer::new(true, true, true); + conversion_optimizer.fill_scratch(&b.lp_arena.get(b.root).get_exprs(), b.expr_arena); + conversion_optimizer + .optimize_exprs(b.expr_arena, b.lp_arena, b.root) + .map_err(|e| e.context(format!("optimizing '{ir_name}' failed").into()))?; + + Ok(b) + } + + /// An escape hatch to add an `Expr`. Working with IR is preferred. + pub fn add_expr(&mut self, expr: Expr) -> PolarsResult { + to_expr_ir(expr, self.expr_arena) + } + + pub fn project(self, exprs: Vec, options: ProjectionOptions) -> Self { + // if len == 0, no projection has to be done. This is a select all operation. + if exprs.is_empty() { + self + } else { + let input_schema = self.schema(); + let schema = + expr_irs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena); + + let lp = IR::Select { + expr: exprs, + input: self.root, + schema: Arc::new(schema), + options, + }; + let node = self.lp_arena.add(lp); + IRBuilder::new(node, self.expr_arena, self.lp_arena) + } + } + + pub fn project_simple_nodes(self, nodes: I) -> PolarsResult + where + I: IntoIterator, + N: Into, + I::IntoIter: ExactSizeIterator, + { + let names = nodes + .into_iter() + .map(|node| match self.expr_arena.get(node.into()) { + AExpr::Column(name) => name, + _ => unreachable!(), + }); + // This is a duplication of `project_simple` because we already borrow self.expr_arena :/ + if names.size_hint().0 == 0 { + Ok(self) + } else { + let input_schema = self.schema(); + let mut count = 0; + let schema = names + .map(|name| { + let dtype = input_schema.try_get(name)?; + count += 1; + Ok(Field::new(name.clone(), dtype.clone())) + }) + .collect::>()?; + + polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns"); + + let lp = IR::SimpleProjection { + input: self.root, + columns: Arc::new(schema), + }; + let node = self.lp_arena.add(lp); + Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena)) + } + } + + pub fn project_simple(self, names: I) -> PolarsResult + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + S: Into, + { + let names = names.into_iter(); + // if len == 0, no projection has to be done. This is a select all operation. + if names.size_hint().0 == 0 { + Ok(self) + } else { + let input_schema = self.schema(); + let mut count = 0; + let schema = names + .map(|name| { + let name: PlSmallStr = name.into(); + let dtype = input_schema.try_get(name.as_str())?; + count += 1; + Ok(Field::new(name, dtype.clone())) + }) + .collect::>()?; + + polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns"); + + let lp = IR::SimpleProjection { + input: self.root, + columns: Arc::new(schema), + }; + let node = self.lp_arena.add(lp); + Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena)) + } + } + + pub fn node(self) -> Node { + self.root + } + + pub fn build(self) -> IR { + if self.root.0 == self.lp_arena.len() { + self.lp_arena.pop().unwrap() + } else { + self.lp_arena.take(self.root) + } + } + + pub fn schema(&'a self) -> Cow<'a, SchemaRef> { + self.lp_arena.get(self.root).schema(self.lp_arena) + } + + pub fn with_columns(self, exprs: Vec, options: ProjectionOptions) -> Self { + let schema = self.schema(); + let mut new_schema = (**schema).clone(); + + let hstack_schema = expr_irs_to_schema(&exprs, &schema, Context::Default, self.expr_arena); + new_schema.merge(hstack_schema); + + let lp = IR::HStack { + input: self.root, + exprs, + schema: Arc::new(new_schema), + options, + }; + self.add_alp(lp) + } + + pub fn with_columns_simple>(self, exprs: I, options: ProjectionOptions) -> Self + where + I: IntoIterator, + { + let schema = self.schema(); + let mut new_schema = (**schema).clone(); + + let iter = exprs.into_iter(); + let mut expr_irs = Vec::with_capacity(iter.size_hint().0); + for node in iter { + let node = node.into(); + let field = self + .expr_arena + .get(node) + .to_field(&schema, Context::Default, self.expr_arena) + .unwrap(); + + expr_irs.push( + ExprIR::new(node, OutputName::ColumnLhs(field.name.clone())) + .with_dtype(field.dtype.clone()), + ); + new_schema.with_column(field.name().clone(), field.dtype().clone()); + } + + let lp = IR::HStack { + input: self.root, + exprs: expr_irs, + schema: Arc::new(new_schema), + options, + }; + self.add_alp(lp) + } + + // call this if the schema needs to be updated + pub fn explode(self, columns: Arc<[PlSmallStr]>) -> Self { + let lp = IR::MapFunction { + input: self.root, + function: FunctionIR::Explode { + columns, + schema: Default::default(), + }, + }; + self.add_alp(lp) + } + + pub fn group_by( + self, + keys: Vec, + aggs: Vec, + apply: Option>, + maintain_order: bool, + options: Arc, + ) -> Self { + let current_schema = self.schema(); + let mut schema = + expr_irs_to_schema(&keys, ¤t_schema, Context::Default, self.expr_arena); + + #[cfg(feature = "dynamic_group_by")] + { + if let Some(options) = options.rolling.as_ref() { + let name = &options.index_column; + let dtype = current_schema.get(name).unwrap(); + schema.with_column(name.clone(), dtype.clone()); + } else if let Some(options) = options.dynamic.as_ref() { + let name = &options.index_column; + let dtype = current_schema.get(name).unwrap(); + if options.include_boundaries { + schema.with_column("_lower_boundary".into(), dtype.clone()); + schema.with_column("_upper_boundary".into(), dtype.clone()); + } + schema.with_column(name.clone(), dtype.clone()); + } + } + + let agg_schema = expr_irs_to_schema( + &aggs, + ¤t_schema, + Context::Aggregation, + self.expr_arena, + ); + schema.merge(agg_schema); + + let lp = IR::GroupBy { + input: self.root, + keys, + aggs, + schema: Arc::new(schema), + apply, + maintain_order, + options, + }; + self.add_alp(lp) + } + + pub fn join( + self, + other: Node, + left_on: Vec, + right_on: Vec, + options: Arc, + ) -> Self { + let schema_left = self.schema(); + let schema_right = self.lp_arena.get(other).schema(self.lp_arena); + + let schema = det_join_schema( + &schema_left, + &schema_right, + &left_on, + &right_on, + &options, + self.expr_arena, + ) + .unwrap(); + + let lp = IR::Join { + input_left: self.root, + input_right: other, + schema, + left_on, + right_on, + options, + }; + + self.add_alp(lp) + } + + #[cfg(feature = "pivot")] + pub fn unpivot(self, args: Arc) -> Self { + let lp = IR::MapFunction { + input: self.root, + function: FunctionIR::Unpivot { + args, + schema: Default::default(), + }, + }; + self.add_alp(lp) + } + + pub fn row_index(self, name: PlSmallStr, offset: Option) -> Self { + let lp = IR::MapFunction { + input: self.root, + function: FunctionIR::RowIndex { + name, + offset, + schema: Default::default(), + }, + }; + self.add_alp(lp) + } +} diff --git a/crates/polars-plan/src/plans/conversion/convert_utils.rs b/crates/polars-plan/src/plans/conversion/convert_utils.rs new file mode 100644 index 000000000000..51e6940483ba --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/convert_utils.rs @@ -0,0 +1,118 @@ +use super::*; + +pub(super) fn convert_st_union( + inputs: &mut [Node], + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult<()> { + let mut schema = (**lp_arena.get(inputs[0]).schema(lp_arena)).clone(); + + let mut changed = false; + for input in inputs[1..].iter() { + let schema_other = lp_arena.get(*input).schema(lp_arena); + changed |= schema.to_supertype(schema_other.as_ref())?; + } + + if changed { + for input in inputs { + let mut exprs = vec![]; + let input_schema = lp_arena.get(*input).schema(lp_arena); + + let to_cast = input_schema.iter().zip(schema.iter_values()).flat_map( + |((left_name, left_type), st)| { + if left_type != st { + Some(col(left_name.clone()).cast(st.clone())) + } else { + None + } + }, + ); + exprs.extend(to_cast); + + if !exprs.is_empty() { + let expr = to_expr_irs(exprs, expr_arena)?; + let lp = IRBuilder::new(*input, expr_arena, lp_arena) + .with_columns(expr, Default::default()) + .build(); + + let node = lp_arena.add(lp); + *input = node + } + } + } + Ok(()) +} + +fn nodes_to_schemas(inputs: &[Node], lp_arena: &mut Arena) -> Vec { + inputs + .iter() + .map(|n| lp_arena.get(*n).schema(lp_arena).into_owned()) + .collect() +} + +pub(super) fn convert_diagonal_concat( + mut inputs: Vec, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult> { + let schemas = nodes_to_schemas(&inputs, lp_arena); + + let upper_bound_width = schemas.iter().map(|sch| sch.len()).sum(); + + let mut total_schema = Schema::with_capacity(upper_bound_width); + + for sch in schemas.iter() { + sch.iter().for_each(|(name, dtype)| { + if !total_schema.contains(name) { + total_schema.with_column(name.as_str().into(), dtype.clone()); + } + }); + } + if total_schema.is_empty() { + return Ok(inputs); + } + + let mut has_empty = false; + + for (node, lf_schema) in inputs.iter_mut().zip(schemas.iter()) { + // Discard, this works physically + if lf_schema.is_empty() { + has_empty = true; + } + let mut columns_to_add = vec![]; + + for (name, dtype) in total_schema.iter() { + // If a name from Total Schema is not present - append + if lf_schema.get_field(name).is_none() { + columns_to_add.push(NULL.lit().cast(dtype.clone()).alias(name.clone())) + } + } + let expr = to_expr_irs(columns_to_add, expr_arena)?; + *node = IRBuilder::new(*node, expr_arena, lp_arena) + // Add the missing columns + .with_columns(expr, Default::default()) + // Now, reorder to match schema. + .project_simple(total_schema.iter_names().map(|v| v.as_str())) + .unwrap() + .node(); + } + + if has_empty { + Ok(inputs + .into_iter() + .zip(schemas) + .filter_map(|(input, schema)| if schema.is_empty() { None } else { Some(input) }) + .collect()) + } else { + Ok(inputs) + } +} + +pub(super) fn h_concat_schema( + inputs: &[Node], + lp_arena: &mut Arena, +) -> PolarsResult { + let schemas = nodes_to_schemas(inputs, lp_arena); + let combined_schema = merge_schemas(&schemas)?; + Ok(Arc::new(combined_schema)) +} diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs new file mode 100644 index 000000000000..afb1ebd73e98 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -0,0 +1,1263 @@ +use arrow::datatypes::ArrowSchemaRef; +use either::Either; +use expr_expansion::{is_regex_projection, rewrite_projections}; +use hive::hive_partitions_from_paths; + +use super::stack_opt::ConversionOptimizer; +use super::*; +use crate::plans::conversion::expr_expansion::expand_selectors; + +fn expand_expressions( + input: Node, + exprs: Vec, + lp_arena: &Arena, + expr_arena: &mut Arena, + opt_flags: &mut OptFlags, +) -> PolarsResult> { + let schema = lp_arena.get(input).schema(lp_arena); + let exprs = rewrite_projections(exprs, &schema, &[], opt_flags)?; + to_expr_irs(exprs, expr_arena) +} + +fn empty_df() -> IR { + IR::DataFrameScan { + df: Arc::new(Default::default()), + schema: Arc::new(Default::default()), + output_schema: None, + } +} + +fn validate_expression( + node: Node, + expr_arena: &Arena, + input_schema: &Schema, + operation_name: &str, +) -> PolarsResult<()> { + let iter = aexpr_to_leaf_names_iter(node, expr_arena); + validate_columns_in_input(iter, input_schema, operation_name) +} + +fn validate_expressions, I: IntoIterator>( + nodes: I, + expr_arena: &Arena, + input_schema: &Schema, + operation_name: &str, +) -> PolarsResult<()> { + let nodes = nodes.into_iter(); + + for node in nodes { + validate_expression(node.into(), expr_arena, input_schema, operation_name)? + } + Ok(()) +} + +macro_rules! failed_here { + ($($t:tt)*) => { + format!("'{}'", stringify!($($t)*)).into() + } +} +pub(super) use failed_here; + +pub fn to_alp( + lp: DslPlan, + expr_arena: &mut Arena, + lp_arena: &mut Arena, + // Only `SIMPLIFY_EXPR`, `TYPE_COERCION`, `TYPE_CHECK` are respected. + opt_flags: &mut OptFlags, +) -> PolarsResult { + let conversion_optimizer = ConversionOptimizer::new( + opt_flags.contains(OptFlags::SIMPLIFY_EXPR), + opt_flags.contains(OptFlags::TYPE_COERCION), + opt_flags.contains(OptFlags::TYPE_CHECK), + ); + + let mut ctxt = DslConversionContext { + expr_arena, + lp_arena, + conversion_optimizer, + opt_flags, + }; + + match to_alp_impl(lp, &mut ctxt) { + Ok(out) => Ok(out), + Err(err) => { + if let Some(ir_until_then) = lp_arena.last_node() { + let node_name = if let PolarsError::Context { msg, .. } = &err { + msg + } else { + "THIS_NODE" + }; + let plan = IRPlan::new( + ir_until_then, + std::mem::take(lp_arena), + std::mem::take(expr_arena), + ); + let location = format!("{}", plan.display()); + Err(err.wrap_msg(|msg| { + format!("{msg}\n\nResolved plan until failure:\n\n\t---> FAILED HERE RESOLVING {node_name} <---\n{location}") + })) + } else { + Err(err) + } + }, + } +} + +pub(super) struct DslConversionContext<'a> { + pub(super) expr_arena: &'a mut Arena, + pub(super) lp_arena: &'a mut Arena, + pub(super) conversion_optimizer: ConversionOptimizer, + pub(super) opt_flags: &'a mut OptFlags, +} + +pub(super) fn run_conversion( + lp: IR, + ctxt: &mut DslConversionContext, + name: &str, +) -> PolarsResult { + let lp_node = ctxt.lp_arena.add(lp); + ctxt.conversion_optimizer + .optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, lp_node) + .map_err(|e| e.context(format!("'{name}' failed").into()))?; + + Ok(lp_node) +} + +/// converts LogicalPlan to IR +/// it adds expressions & lps to the respective arenas as it traverses the plan +/// finally it returns the top node of the logical plan +#[recursive] +pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult { + let owned = Arc::unwrap_or_clone; + + let v = match lp { + DslPlan::Scan { + sources, + file_info, + unified_scan_args: mut unified_scan_args_box, + scan_type, + cached_ir, + } => { + // Note that the first metadata can still end up being `None` later if the files were + // filtered from predicate pushdown. + let mut cached_ir = cached_ir.lock().unwrap(); + + if cached_ir.is_none() { + let cloud_options = unified_scan_args_box.cloud_options.clone(); + let cloud_options = cloud_options.as_ref(); + + let unified_scan_args = unified_scan_args_box.as_mut(); + let mut scan_type = scan_type.clone(); + + if let Some(hive_schema) = unified_scan_args.hive_options.schema.as_deref() { + match unified_scan_args.hive_options.enabled { + // Enable hive_partitioning if it is unspecified but a non-empty hive_schema given + None if !hive_schema.is_empty() => { + unified_scan_args.hive_options.enabled = Some(true) + }, + // hive_partitioning was explicitly disabled + Some(false) => polars_bail!( + ComputeError: + "a hive schema was given but hive_partitioning was disabled" + ), + Some(true) | None => {}, + } + } + + let sources = + match &*scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => sources + .expand_paths_with_hive_update(unified_scan_args, cloud_options)?, + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => sources + .expand_paths_with_hive_update(unified_scan_args, cloud_options)?, + #[cfg(feature = "csv")] + FileScan::Csv { .. } => { + sources.expand_paths(unified_scan_args, cloud_options)? + }, + #[cfg(feature = "json")] + FileScan::NDJson { .. } => { + sources.expand_paths(unified_scan_args, cloud_options)? + }, + FileScan::Anonymous { .. } => sources, + }; + + let mut file_info = match &mut *scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { options, metadata } => { + if let Some(schema) = &options.schema { + // We were passed a schema, we don't have to call `parquet_file_info`, + // but this does mean we don't have `row_estimation` and `first_metadata`. + FileInfo { + schema: schema.clone(), + reader_schema: Some(either::Either::Left(Arc::new( + schema.to_arrow(CompatLevel::newest()), + ))), + row_estimation: (None, 0), + } + } else { + let (file_info, md) = scans::parquet_file_info( + &sources, + unified_scan_args.row_index.as_ref(), + cloud_options, + ) + .map_err(|e| e.context(failed_here!(parquet scan)))?; + + *metadata = md; + file_info + } + }, + #[cfg(feature = "ipc")] + FileScan::Ipc { metadata, .. } => { + let (file_info, md) = scans::ipc_file_info( + &sources, + unified_scan_args.row_index.as_ref(), + cloud_options, + ) + .map_err(|e| e.context(failed_here!(ipc scan)))?; + *metadata = Some(Arc::new(md)); + file_info + }, + #[cfg(feature = "csv")] + FileScan::Csv { options } => { + // TODO: This is a hack. We conditionally set `allow_missing_columns` to + // mimic existing behavior, but this should be taken from a user provided + // parameter instead. + if options.schema.is_some() && options.has_header { + unified_scan_args.missing_columns_policy = MissingColumnsPolicy::Insert; + } + + scans::csv_file_info( + &sources, + unified_scan_args.row_index.as_ref(), + options, + cloud_options, + ) + .map_err(|e| e.context(failed_here!(csv scan)))? + }, + #[cfg(feature = "json")] + FileScan::NDJson { options } => scans::ndjson_file_info( + &sources, + unified_scan_args.row_index.as_ref(), + options, + cloud_options, + ) + .map_err(|e| e.context(failed_here!(ndjson scan)))?, + FileScan::Anonymous { .. } => { + file_info.expect("FileInfo should be set for AnonymousScan") + }, + }; + + if unified_scan_args.hive_options.enabled.is_none() { + // We expect this to be `Some(_)` after this point. If it hasn't been auto-enabled + // we explicitly set it to disabled. + unified_scan_args.hive_options.enabled = Some(false); + } + + let hive_parts = if unified_scan_args.hive_options.enabled.unwrap() + && file_info.reader_schema.is_some() + { + let paths = sources.as_paths().ok_or_else(|| { + polars_err!(nyi = "Hive-partitioning of in-memory buffers") + })?; + + #[allow(unused_assignments)] + let mut owned = None; + + hive_partitions_from_paths( + paths, + unified_scan_args.hive_options.hive_start_idx, + unified_scan_args.hive_options.schema.clone(), + match file_info.reader_schema.as_ref().unwrap() { + Either::Left(v) => { + owned = Some(Schema::from_arrow_schema(v.as_ref())); + owned.as_ref().unwrap() + }, + Either::Right(v) => v.as_ref(), + }, + unified_scan_args.hive_options.try_parse_dates, + )? + } else { + None + }; + + if let Some(ref hive_parts) = hive_parts { + let hive_schema = hive_parts.schema(); + file_info.update_schema_with_hive_schema(hive_schema.clone()); + } else if let Some(hive_schema) = unified_scan_args.hive_options.schema.clone() { + // We hit here if we are passed the `hive_schema` to `scan_parquet` but end up with an empty file + // list during path expansion. In this case we still want to return an empty DataFrame with this + // schema. + file_info.update_schema_with_hive_schema(hive_schema); + } + + unified_scan_args.include_file_paths = unified_scan_args + .include_file_paths + .as_ref() + .filter(|_| match &*scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => true, + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => true, + #[cfg(feature = "csv")] + FileScan::Csv { .. } => true, + #[cfg(feature = "json")] + FileScan::NDJson { .. } => true, + FileScan::Anonymous { .. } => false, + }) + .cloned(); + + if let Some(ref file_path_col) = unified_scan_args.include_file_paths { + let schema = Arc::make_mut(&mut file_info.schema); + + if schema.contains(file_path_col) { + polars_bail!( + Duplicate: r#"column name for file paths "{}" conflicts with column name from file"#, + file_path_col + ); + } + + schema.insert_at_index( + schema.len(), + file_path_col.clone(), + DataType::String, + )?; + } + + unified_scan_args.projection = if file_info.reader_schema.is_some() { + maybe_init_projection_excluding_hive( + file_info.reader_schema.as_ref().unwrap(), + hive_parts.as_ref().map(|h| h.schema()), + ) + } else { + None + }; + + if let Some(row_index) = &unified_scan_args.row_index { + let schema = Arc::make_mut(&mut file_info.schema); + *schema = schema + .new_inserting_at_index(0, row_index.name.clone(), IDX_DTYPE) + .unwrap(); + } + + let ir = if sources.is_empty() && !matches!(&*scan_type, FileScan::Anonymous { .. }) + { + IR::DataFrameScan { + df: Arc::new(DataFrame::empty_with_schema(&file_info.schema)), + schema: file_info.schema, + output_schema: None, + } + } else { + let unified_scan_args = unified_scan_args_box; + + IR::Scan { + sources, + file_info, + hive_parts, + predicate: None, + scan_type, + output_schema: None, + unified_scan_args, + } + }; + + cached_ir.replace(ir); + } + + cached_ir.clone().unwrap() + }, + #[cfg(feature = "python")] + DslPlan::PythonScan { mut options } => { + let scan_fn = options.scan_fn.take(); + let schema = options.get_schema()?; + IR::PythonScan { + options: PythonOptions { + scan_fn, + schema, + python_source: options.python_source, + validate_schema: options.validate_schema, + output_schema: Default::default(), + with_columns: Default::default(), + n_rows: Default::default(), + predicate: Default::default(), + }, + } + }, + DslPlan::Union { inputs, args } => { + let mut inputs = inputs + .into_iter() + .map(|lp| to_alp_impl(lp, ctxt)) + .collect::>>() + .map_err(|e| e.context(failed_here!(vertical concat)))?; + + if args.diagonal { + inputs = + convert_utils::convert_diagonal_concat(inputs, ctxt.lp_arena, ctxt.expr_arena)?; + } + + if args.to_supertypes { + convert_utils::convert_st_union(&mut inputs, ctxt.lp_arena, ctxt.expr_arena) + .map_err(|e| e.context(failed_here!(vertical concat)))?; + } + + let first = *inputs.first().ok_or_else( + || polars_err!(InvalidOperation: "expected at least one input in 'union'/'concat'"), + )?; + let schema = ctxt.lp_arena.get(first).schema(ctxt.lp_arena); + for n in &inputs[1..] { + let schema_i = ctxt.lp_arena.get(*n).schema(ctxt.lp_arena); + // The first argument + schema_i.matches_schema(schema.as_ref()).map_err(|_| polars_err!(InvalidOperation: "'union'/'concat' inputs should all have the same schema,\ + got\n{:?} and \n{:?}", schema, schema_i) + )?; + } + + let options = args.into(); + IR::Union { inputs, options } + }, + DslPlan::HConcat { inputs, options } => { + let inputs = inputs + .into_iter() + .map(|lp| to_alp_impl(lp, ctxt)) + .collect::>>() + .map_err(|e| e.context(failed_here!(horizontal concat)))?; + + let schema = convert_utils::h_concat_schema(&inputs, ctxt.lp_arena)?; + + IR::HConcat { + inputs, + schema, + options, + } + }, + DslPlan::Filter { input, predicate } => { + let mut input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(filter)))?; + let predicate = expand_filter(predicate, input, ctxt.lp_arena, ctxt.opt_flags) + .map_err(|e| e.context(failed_here!(filter)))?; + + let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; + + // TODO: We could do better here by using `pushdown_eligibility()` + return if ctxt.opt_flags.predicate_pushdown() + && permits_filter_pushdown_rec( + ctxt.expr_arena.get(predicate_ae.node()), + ctxt.expr_arena, + ) { + // Split expression that are ANDed into multiple Filter nodes as the optimizer can then + // push them down independently. Especially if they refer columns from different tables + // this will be more performant. + // So: + // filter[foo == bar & ham == spam] + // filter [foo == bar] + // filter [ham == spam] + let mut predicates = vec![]; + + let mut stack = vec![predicate_ae.node()]; + while let Some(n) = stack.pop() { + if let AExpr::BinaryExpr { + left, + op: Operator::And | Operator::LogicalAnd, + right, + } = ctxt.expr_arena.get(n) + { + stack.push(*left); + stack.push(*right); + } else { + predicates.push(n) + } + } + + for predicate in predicates { + let predicate = ExprIR::from_node(predicate, ctxt.expr_arena); + ctxt.conversion_optimizer + .push_scratch(predicate.node(), ctxt.expr_arena); + let lp = IR::Filter { input, predicate }; + input = run_conversion(lp, ctxt, "filter")?; + } + + Ok(input) + } else { + ctxt.conversion_optimizer + .push_scratch(predicate_ae.node(), ctxt.expr_arena); + let lp = IR::Filter { + input, + predicate: predicate_ae, + }; + run_conversion(lp, ctxt, "filter") + }; + }, + DslPlan::Slice { input, offset, len } => { + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(slice)))?; + IR::Slice { input, offset, len } + }, + DslPlan::DataFrameScan { df, schema } => IR::DataFrameScan { + df, + schema, + output_schema: None, + }, + DslPlan::Select { + expr, + input, + options, + } => { + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(select)))?; + let schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); + let (exprs, schema) = prepare_projection(expr, &schema, ctxt.opt_flags) + .map_err(|e| e.context(failed_here!(select)))?; + + if exprs.is_empty() { + ctxt.lp_arena.replace(input, empty_df()); + } + + let schema = Arc::new(schema); + let eirs = to_expr_irs(exprs, ctxt.expr_arena)?; + ctxt.conversion_optimizer + .fill_scratch(&eirs, ctxt.expr_arena); + + let lp = IR::Select { + expr: eirs, + input, + schema, + options, + }; + + return run_conversion(lp, ctxt, "select").map_err(|e| e.context(failed_here!(select))); + }, + DslPlan::Sort { + input, + by_column, + slice, + mut sort_options, + } => { + // note: if given an Expr::Columns, count the individual cols + let n_by_exprs = if by_column.len() == 1 { + match &by_column[0] { + Expr::Columns(cols) => cols.len(), + _ => 1, + } + } else { + by_column.len() + }; + let n_desc = sort_options.descending.len(); + polars_ensure!( + n_desc == n_by_exprs || n_desc == 1, + ComputeError: "the length of `descending` ({}) does not match the length of `by` ({})", n_desc, by_column.len() + ); + let n_nulls_last = sort_options.nulls_last.len(); + polars_ensure!( + n_nulls_last == n_by_exprs || n_nulls_last == 1, + ComputeError: "the length of `nulls_last` ({}) does not match the length of `by` ({})", n_nulls_last, by_column.len() + ); + + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(sort)))?; + + let mut expanded_cols = Vec::new(); + let mut nulls_last = Vec::new(); + let mut descending = Vec::new(); + + // note: nulls_last/descending need to be matched to expanded multi-output expressions. + // when one of nulls_last/descending has not been updated from the default (single + // value true/false), 'cycle' ensures that "by_column" iter is not truncated. + for (c, (&n, &d)) in by_column.into_iter().zip( + sort_options + .nulls_last + .iter() + .cycle() + .zip(sort_options.descending.iter().cycle()), + ) { + let exprs = expand_expressions( + input, + vec![c], + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(sort)))?; + + nulls_last.extend(std::iter::repeat_n(n, exprs.len())); + descending.extend(std::iter::repeat_n(d, exprs.len())); + expanded_cols.extend(exprs); + } + sort_options.nulls_last = nulls_last; + sort_options.descending = descending; + + ctxt.conversion_optimizer + .fill_scratch(&expanded_cols, ctxt.expr_arena); + let mut by_column = expanded_cols; + + // Remove null columns in multi-columns sort + if by_column.len() > 1 { + let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); + + let mut null_columns = vec![]; + + for (i, c) in by_column.iter().enumerate() { + if let DataType::Null = + c.dtype(&input_schema, Context::Default, ctxt.expr_arena)? + { + null_columns.push(i); + } + } + // All null columns, only take one. + if null_columns.len() == by_column.len() { + by_column.truncate(1); + sort_options.nulls_last.truncate(1); + sort_options.descending.truncate(1); + } + // Remove the null columns + else if !null_columns.is_empty() { + for i in null_columns.into_iter().rev() { + by_column.remove(i); + sort_options.nulls_last.remove(i); + sort_options.descending.remove(i); + } + } + }; + + let lp = IR::Sort { + input, + by_column, + slice, + sort_options, + }; + + return run_conversion(lp, ctxt, "sort").map_err(|e| e.context(failed_here!(sort))); + }, + DslPlan::Cache { input, id } => { + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(cache)))?; + IR::Cache { + input, + id, + cache_hits: crate::constants::UNLIMITED_CACHE, + } + }, + DslPlan::GroupBy { + input, + keys, + aggs, + apply, + maintain_order, + options, + } => { + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(group_by)))?; + + // Rolling + group-by sorts the whole table, so remove unneeded columns + if ctxt.opt_flags.eager() && options.is_rolling() && !keys.is_empty() { + ctxt.opt_flags.insert(OptFlags::PROJECTION_PUSHDOWN) + } + + let (keys, aggs, schema) = resolve_group_by( + input, + keys, + aggs, + &options, + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(group_by)))?; + + let (apply, schema) = if let Some((apply, schema)) = apply { + (Some(apply), schema) + } else { + (None, schema) + }; + + ctxt.conversion_optimizer + .fill_scratch(&keys, ctxt.expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&aggs, ctxt.expr_arena); + + let lp = IR::GroupBy { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + + return run_conversion(lp, ctxt, "group_by") + .map_err(|e| e.context(failed_here!(group_by))); + }, + DslPlan::Join { + input_left, + input_right, + left_on, + right_on, + predicates, + options, + } => { + return join::resolve_join( + Either::Left(input_left), + Either::Left(input_right), + left_on, + right_on, + predicates, + options, + ctxt, + ) + .map_err(|e| e.context(failed_here!(join))) + .map(|t| t.0); + }, + DslPlan::HStack { + input, + exprs, + options, + } => { + let input = to_alp_impl(owned(input), ctxt) + .map_err(|e| e.context(failed_here!(with_columns)))?; + let (exprs, schema) = + resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena, ctxt.opt_flags) + .map_err(|e| e.context(failed_here!(with_columns)))?; + + ctxt.conversion_optimizer + .fill_scratch(&exprs, ctxt.expr_arena); + let lp = IR::HStack { + input, + exprs, + schema, + options, + }; + return run_conversion(lp, ctxt, "with_columns"); + }, + DslPlan::Distinct { input, options } => { + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(unique)))?; + let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); + + let subset = options + .subset + .map(|s| { + let cols = expand_selectors(s, input_schema.as_ref(), &[])?; + + // Checking if subset columns exist in the dataframe + for col in cols.iter() { + let _ = input_schema + .try_get(col) + .map_err(|_| polars_err!(col_not_found = col))?; + } + + Ok::<_, PolarsError>(cols) + }) + .transpose()?; + + let options = DistinctOptionsIR { + subset, + maintain_order: options.maintain_order, + keep_strategy: options.keep_strategy, + slice: None, + }; + + IR::Distinct { input, options } + }, + DslPlan::MapFunction { input, function } => { + let input = to_alp_impl(owned(input), ctxt) + .map_err(|e| e.context(failed_here!(format!("{}", function).to_lowercase())))?; + let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); + + match function { + DslFunction::Explode { + columns, + allow_empty, + } => { + let columns = expand_selectors(columns, &input_schema, &[])?; + validate_columns_in_input(columns.as_ref(), &input_schema, "explode")?; + polars_ensure!(!columns.is_empty() || allow_empty, InvalidOperation: "no columns provided in explode"); + if columns.is_empty() { + return Ok(input); + } + let function = FunctionIR::Explode { + columns, + schema: Default::default(), + }; + let ir = IR::MapFunction { input, function }; + return Ok(ctxt.lp_arena.add(ir)); + }, + DslFunction::FillNan(fill_value) => { + let exprs = input_schema + .iter() + .filter_map(|(name, dtype)| match dtype { + DataType::Float32 | DataType::Float64 => Some( + col(name.clone()) + .fill_nan(fill_value.clone()) + .alias(name.clone()), + ), + _ => None, + }) + .collect::>(); + + let (exprs, schema) = resolve_with_columns( + exprs, + input, + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(fill_nan)))?; + + ctxt.conversion_optimizer + .fill_scratch(&exprs, ctxt.expr_arena); + + let lp = IR::HStack { + input, + exprs, + schema, + options: ProjectionOptions { + duplicate_check: false, + ..Default::default() + }, + }; + return run_conversion(lp, ctxt, "fill_nan"); + }, + DslFunction::Drop(DropFunction { to_drop, strict }) => { + let to_drop = expand_selectors(to_drop, &input_schema, &[])?; + let to_drop = to_drop.iter().map(|s| s.as_ref()).collect::>(); + + if strict { + for col_name in to_drop.iter() { + polars_ensure!( + input_schema.contains(col_name), + col_not_found = col_name + ); + } + } + + let mut output_schema = + Schema::with_capacity(input_schema.len().saturating_sub(to_drop.len())); + + for (col_name, dtype) in input_schema.iter() { + if !to_drop.contains(col_name.as_str()) { + output_schema.with_column(col_name.clone(), dtype.clone()); + } + } + + if output_schema.is_empty() { + ctxt.lp_arena.replace(input, empty_df()); + } + + IR::SimpleProjection { + input, + columns: Arc::new(output_schema), + } + }, + DslFunction::Stats(sf) => { + let exprs = match sf { + StatsFunction::Var { ddof } => stats_helper( + |dt| dt.is_primitive_numeric() || dt.is_bool(), + |name| col(name.clone()).var(ddof), + &input_schema, + ), + StatsFunction::Std { ddof } => stats_helper( + |dt| dt.is_primitive_numeric() || dt.is_bool(), + |name| col(name.clone()).std(ddof), + &input_schema, + ), + StatsFunction::Quantile { quantile, method } => stats_helper( + |dt| dt.is_primitive_numeric(), + |name| col(name.clone()).quantile(quantile.clone(), method), + &input_schema, + ), + StatsFunction::Mean => stats_helper( + |dt| { + dt.is_primitive_numeric() + || dt.is_temporal() + || dt == &DataType::Boolean + }, + |name| col(name.clone()).mean(), + &input_schema, + ), + StatsFunction::Sum => stats_helper( + |dt| { + dt.is_primitive_numeric() + || dt.is_decimal() + || matches!(dt, DataType::Boolean | DataType::Duration(_)) + }, + |name| col(name.clone()).sum(), + &input_schema, + ), + StatsFunction::Min => stats_helper( + |dt| dt.is_ord(), + |name| col(name.clone()).min(), + &input_schema, + ), + StatsFunction::Max => stats_helper( + |dt| dt.is_ord(), + |name| col(name.clone()).max(), + &input_schema, + ), + StatsFunction::Median => stats_helper( + |dt| { + dt.is_primitive_numeric() + || dt.is_temporal() + || dt == &DataType::Boolean + }, + |name| col(name.clone()).median(), + &input_schema, + ), + }; + let schema = Arc::new(expressions_to_schema( + &exprs, + &input_schema, + Context::Default, + )?); + let eirs = to_expr_irs(exprs, ctxt.expr_arena)?; + + ctxt.conversion_optimizer + .fill_scratch(&eirs, ctxt.expr_arena); + + let lp = IR::Select { + input, + expr: eirs, + schema, + options: ProjectionOptions { + duplicate_check: false, + ..Default::default() + }, + }; + return run_conversion(lp, ctxt, "stats"); + }, + _ => { + let function = function.into_function_ir(&input_schema)?; + IR::MapFunction { input, function } + }, + } + }, + DslPlan::ExtContext { input, contexts } => { + let input = to_alp_impl(owned(input), ctxt) + .map_err(|e| e.context(failed_here!(with_context)))?; + let contexts = contexts + .into_iter() + .map(|lp| to_alp_impl(lp, ctxt)) + .collect::>>() + .map_err(|e| e.context(failed_here!(with_context)))?; + + let mut schema = (**ctxt.lp_arena.get(input).schema(ctxt.lp_arena)).clone(); + for input in &contexts { + let other_schema = ctxt.lp_arena.get(*input).schema(ctxt.lp_arena); + for fld in other_schema.iter_fields() { + if schema.get(fld.name()).is_none() { + schema.with_column(fld.name, fld.dtype); + } + } + } + + IR::ExtContext { + input, + contexts, + schema: Arc::new(schema), + } + }, + DslPlan::Sink { input, payload } => { + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(sink)))?; + let payload = match payload { + SinkType::Memory => SinkTypeIR::Memory, + SinkType::File(f) => SinkTypeIR::File(f), + SinkType::Partition(f) => SinkTypeIR::Partition(PartitionSinkTypeIR { + base_path: f.base_path, + file_path_cb: f.file_path_cb, + file_type: f.file_type, + sink_options: f.sink_options, + variant: match f.variant { + PartitionVariant::MaxSize(max_size) => { + PartitionVariantIR::MaxSize(max_size) + }, + PartitionVariant::Parted { + key_exprs, + include_key, + } => { + let eirs = to_expr_irs(key_exprs, ctxt.expr_arena)?; + ctxt.conversion_optimizer + .fill_scratch(&eirs, ctxt.expr_arena); + + PartitionVariantIR::Parted { + key_exprs: eirs, + include_key, + } + }, + PartitionVariant::ByKey { + key_exprs, + include_key, + } => { + let eirs = to_expr_irs(key_exprs, ctxt.expr_arena)?; + ctxt.conversion_optimizer + .fill_scratch(&eirs, ctxt.expr_arena); + + PartitionVariantIR::ByKey { + key_exprs: eirs, + include_key, + } + }, + }, + cloud_options: f.cloud_options, + }), + }; + + let lp = IR::Sink { input, payload }; + return run_conversion(lp, ctxt, "sink"); + }, + DslPlan::SinkMultiple { inputs } => { + let inputs = inputs + .into_iter() + .map(|lp| to_alp_impl(lp, ctxt)) + .collect::>>() + .map_err(|e| e.context(failed_here!(vertical concat)))?; + IR::SinkMultiple { inputs } + }, + #[cfg(feature = "merge_sorted")] + DslPlan::MergeSorted { + input_left, + input_right, + key, + } => { + let input_left = to_alp_impl(owned(input_left), ctxt) + .map_err(|e| e.context(failed_here!(merge_sorted)))?; + let input_right = to_alp_impl(owned(input_right), ctxt) + .map_err(|e| e.context(failed_here!(merge_sorted)))?; + + IR::MergeSorted { + input_left, + input_right, + key, + } + }, + DslPlan::IR { node, dsl, version } => { + return if node.is_some() + && version == ctxt.lp_arena.version() + && ctxt.conversion_optimizer.used_arenas.insert(version) + { + Ok(node.unwrap()) + } else { + to_alp_impl(owned(dsl), ctxt) + }; + }, + }; + Ok(ctxt.lp_arena.add(v)) +} + +fn expand_filter( + predicate: Expr, + input: Node, + lp_arena: &Arena, + opt_flags: &mut OptFlags, +) -> PolarsResult { + let schema = lp_arena.get(input).schema(lp_arena); + let predicate = if has_expr(&predicate, |e| match e { + Expr::Column(name) => is_regex_projection(name), + Expr::Wildcard + | Expr::Selector(_) + | Expr::RenameAlias { .. } + | Expr::Columns(_) + | Expr::DtypeColumn(_) + | Expr::IndexColumn(_) + | Expr::Nth(_) => true, + #[cfg(feature = "dtype-struct")] + Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::FieldByIndex(_)), + .. + } => true, + _ => false, + }) { + let mut rewritten = rewrite_projections(vec![predicate], &schema, &[], opt_flags)?; + match rewritten.len() { + 1 => { + // all good + rewritten.pop().unwrap() + }, + 0 => { + let msg = "The predicate expanded to zero expressions. \ + This may for example be caused by a regex not matching column names or \ + a column dtype match not hitting any dtypes in the DataFrame"; + polars_bail!(ComputeError: msg); + }, + _ => { + let mut expanded = String::new(); + for e in rewritten.iter().take(5) { + expanded.push_str(&format!("\t{e:?},\n")) + } + // pop latest comma + expanded.pop(); + if rewritten.len() > 5 { + expanded.push_str("\t...\n") + } + + let msg = if cfg!(feature = "python") { + format!( + "The predicate passed to 'LazyFrame.filter' expanded to multiple expressions: \n\n{expanded}\n\ + This is ambiguous. Try to combine the predicates with the 'all' or `any' expression." + ) + } else { + format!( + "The predicate passed to 'LazyFrame.filter' expanded to multiple expressions: \n\n{expanded}\n\ + This is ambiguous. Try to combine the predicates with the 'all_horizontal' or `any_horizontal' expression." + ) + }; + polars_bail!(ComputeError: msg) + }, + } + } else { + predicate + }; + expr_to_leaf_column_names_iter(&predicate) + .try_for_each(|c| schema.try_index_of(&c).and(Ok(())))?; + + Ok(predicate) +} + +fn resolve_with_columns( + exprs: Vec, + input: Node, + lp_arena: &Arena, + expr_arena: &mut Arena, + opt_flags: &mut OptFlags, +) -> PolarsResult<(Vec, SchemaRef)> { + let schema = lp_arena.get(input).schema(lp_arena); + let mut new_schema = (**schema).clone(); + let (exprs, _) = prepare_projection(exprs, &schema, opt_flags)?; + let mut output_names = PlHashSet::with_capacity(exprs.len()); + + let mut arena = Arena::with_capacity(8); + for e in &exprs { + let field = e + .to_field_amortized(&schema, Context::Default, &mut arena) + .unwrap(); + + if !output_names.insert(field.name().clone()) { + let msg = format!( + "the name '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\ + It's possible that multiple expressions are returning the same default column name. \ + If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \ + duplicate column names.", + field.name() + ); + polars_bail!(ComputeError: msg) + } + new_schema.with_column(field.name, field.dtype.materialize_unknown(true)?); + arena.clear(); + } + + let eirs = to_expr_irs(exprs, expr_arena)?; + Ok((eirs, Arc::new(new_schema))) +} + +fn resolve_group_by( + input: Node, + keys: Vec, + aggs: Vec, + _options: &GroupbyOptions, + lp_arena: &Arena, + expr_arena: &mut Arena, + opt_flags: &mut OptFlags, +) -> PolarsResult<(Vec, Vec, SchemaRef)> { + let current_schema = lp_arena.get(input).schema(lp_arena); + let current_schema = current_schema.as_ref(); + let mut keys = rewrite_projections(keys, current_schema, &[], opt_flags)?; + + // Initialize schema from keys + let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?; + + #[allow(unused_mut)] + let mut pop_keys = false; + // Add dynamic groupby index column(s) + // Also add index columns to keys for expression expansion. + #[cfg(feature = "dynamic_group_by")] + { + if let Some(options) = _options.rolling.as_ref() { + let name = options.index_column.clone(); + let dtype = current_schema.try_get(name.as_str())?; + keys.push(col(name.clone())); + pop_keys = true; + schema.with_column(name.clone(), dtype.clone()); + } else if let Some(options) = _options.dynamic.as_ref() { + let name = options.index_column.clone(); + keys.push(col(name.clone())); + pop_keys = true; + let dtype = current_schema.try_get(name.as_str())?; + if options.include_boundaries { + schema.with_column("_lower_boundary".into(), dtype.clone()); + schema.with_column("_upper_boundary".into(), dtype.clone()); + } + schema.with_column(name.clone(), dtype.clone()); + } + } + let keys_index_len = schema.len(); + + let aggs = rewrite_projections(aggs, current_schema, &keys, opt_flags)?; + if pop_keys { + let _ = keys.pop(); + } + + // Add aggregation column(s) + let aggs_schema = expressions_to_schema(&aggs, current_schema, Context::Aggregation)?; + schema.merge(aggs_schema); + + // Make sure aggregation columns do not contain keys or index columns + if schema.len() < (keys_index_len + aggs.len()) { + let mut names = PlHashSet::with_capacity(schema.len()); + for expr in aggs.iter().chain(keys.iter()) { + let name = expr_output_name(expr)?; + polars_ensure!(names.insert(name.clone()), duplicate = name) + } + } + let keys = to_expr_irs(keys, expr_arena)?; + let aggs = to_expr_irs(aggs, expr_arena)?; + validate_expressions(&keys, expr_arena, current_schema, "group by")?; + validate_expressions(&aggs, expr_arena, current_schema, "group by")?; + + Ok((keys, aggs, Arc::new(schema))) +} +fn stats_helper(condition: F, expr: E, schema: &Schema) -> Vec +where + F: Fn(&DataType) -> bool, + E: Fn(&PlSmallStr) -> Expr, +{ + schema + .iter() + .map(|(name, dt)| { + if condition(dt) { + expr(name) + } else { + lit(NULL).cast(dt.clone()).alias(name.clone()) + } + }) + .collect() +} + +pub(crate) fn maybe_init_projection_excluding_hive( + reader_schema: &Either, + hive_parts: Option<&SchemaRef>, +) -> Option> { + // Update `with_columns` with a projection so that hive columns aren't loaded from the + // file + let hive_schema = hive_parts?; + + match &reader_schema { + Either::Left(reader_schema) => hive_schema + .iter_names() + .any(|x| reader_schema.contains(x)) + .then(|| { + reader_schema + .iter_names_cloned() + .filter(|x| !hive_schema.contains(x)) + .collect::>() + }), + Either::Right(reader_schema) => hive_schema + .iter_names() + .any(|x| reader_schema.contains(x)) + .then(|| { + reader_schema + .iter_names_cloned() + .filter(|x| !hive_schema.contains(x)) + .collect::>() + }), + } +} diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs new file mode 100644 index 000000000000..b74e219cdbf7 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -0,0 +1,947 @@ +//! this contains code used for rewriting projections, expanding wildcards, regex selection etc. + +use super::*; + +pub(crate) fn prepare_projection( + exprs: Vec, + schema: &Schema, + opt_flags: &mut OptFlags, +) -> PolarsResult<(Vec, Schema)> { + let exprs = rewrite_projections(exprs, schema, &[], opt_flags)?; + let schema = expressions_to_schema(&exprs, schema, Context::Default)?; + Ok((exprs, schema)) +} + +/// This replaces the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the +/// expression chain. +pub(super) fn replace_wildcard_with_column(expr: Expr, column_name: &PlSmallStr) -> Expr { + expr.map_expr(|e| match e { + Expr::Wildcard => Expr::Column(column_name.clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) +} + +#[cfg(feature = "regex")] +fn remove_exclude(expr: Expr) -> Expr { + expr.map_expr(|e| match e { + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) +} + +fn rewrite_special_aliases(expr: Expr) -> PolarsResult { + // the blocks are added by cargo fmt + if has_expr(&expr, |e| { + matches!(e, Expr::KeepName(_) | Expr::RenameAlias { .. }) + }) { + match expr { + Expr::KeepName(expr) => { + let roots = expr_to_leaf_column_names(&expr); + let name = roots + .first() + .expect("expected root column to keep expression name"); + Ok(Expr::Alias(expr, name.clone())) + }, + Expr::RenameAlias { expr, function } => { + let name = get_single_leaf(&expr)?; + let name = function.call(&name)?; + Ok(Expr::Alias(expr, name)) + }, + _ => { + polars_bail!(InvalidOperation: "`keep`, `suffix`, `prefix` should be last expression") + }, + } + } else { + Ok(expr) + } +} + +/// Take an expression with a root: col("*") and copies that expression for all columns in the schema, +/// with the exclusion of the `names` in the exclude expression. +/// The resulting expressions are written to result. +fn replace_wildcard( + expr: &Expr, + result: &mut Vec, + exclude: &PlHashSet, + schema: &Schema, +) -> PolarsResult<()> { + for name in schema.iter_names() { + if !exclude.contains(name.as_str()) { + let new_expr = replace_wildcard_with_column(expr.clone(), name); + let new_expr = rewrite_special_aliases(new_expr)?; + result.push(new_expr) + } + } + Ok(()) +} + +fn replace_nth(expr: Expr, schema: &Schema) -> Expr { + expr.map_expr(|e| { + if let Expr::Nth(i) = e { + match i.negative_to_usize(schema.len()) { + None => { + let name = match i { + 0 => "first", + -1 => "last", + _ => "nth", + }; + Expr::Column(PlSmallStr::from_static(name)) + }, + Some(idx) => { + let (name, _dtype) = schema.get_at_index(idx).unwrap(); + Expr::Column(name.clone()) + }, + } + } else { + e + } + }) +} + +#[cfg(feature = "regex")] +/// This function takes an expression containing a regex in `col("..")` and expands the columns +/// that are selected by that regex in `result`. +fn expand_regex( + expr: &Expr, + result: &mut Vec, + schema: &Schema, + pattern: &str, + exclude: &PlHashSet, +) -> PolarsResult<()> { + let re = polars_utils::regex_cache::compile_regex(pattern) + .map_err(|e| polars_err!(ComputeError: "invalid regex {}", e))?; + for name in schema.iter_names() { + if re.is_match(name) && !exclude.contains(name.as_str()) { + let mut new_expr = remove_exclude(expr.clone()); + + new_expr = new_expr.map_expr(|e| match e { + Expr::Column(pat) if pat.as_str() == pattern => Expr::Column(name.clone()), + e => e, + }); + + let new_expr = rewrite_special_aliases(new_expr)?; + result.push(new_expr); + } + } + Ok(()) +} + +pub(crate) fn is_regex_projection(name: &str) -> bool { + name.starts_with('^') && name.ends_with('$') +} + +#[cfg(feature = "regex")] +/// This function searches for a regex expression in `col("..")` and expands the columns +/// that are selected by that regex in `result`. The regex should start with `^` and end with `$`. +fn replace_regex( + expr: &Expr, + result: &mut Vec, + schema: &Schema, + exclude: &PlHashSet, +) -> PolarsResult<()> { + let roots = expr_to_leaf_column_names(expr); + let mut regex = None; + for name in &roots { + if is_regex_projection(name) { + match regex { + None => { + regex = Some(name); + expand_regex(expr, result, schema, name, exclude)?; + }, + Some(r) => { + polars_ensure!( + r == name, + ComputeError: + "an expression is not allowed to have different regexes" + ) + }, + } + } + } + if regex.is_none() { + let expr = rewrite_special_aliases(expr.clone())?; + result.push(expr) + } + Ok(()) +} + +/// replace `columns(["A", "B"])..` with `col("A")..`, `col("B")..` +#[allow(unused_variables)] +fn expand_columns( + expr: &Expr, + result: &mut Vec, + names: &[PlSmallStr], + schema: &Schema, + exclude: &PlHashSet, +) -> PolarsResult<()> { + if !expr.into_iter().all(|e| match e { + // check for invalid expansions such as `col([a, b]) + col([c, d])` + Expr::Columns(members) => members.as_ref() == names, + _ => true, + }) { + polars_bail!(ComputeError: "expanding more than one `col` is not allowed"); + } + for name in names { + if !exclude.contains(name) { + let new_expr = expr.clone().map_expr(|e| match e { + Expr::Columns(_) => Expr::Column((*name).clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }); + + #[cfg(feature = "regex")] + replace_regex(&new_expr, result, schema, exclude)?; + + #[cfg(not(feature = "regex"))] + result.push(rewrite_special_aliases(new_expr)?); + } + } + Ok(()) +} + +#[cfg(feature = "dtype-struct")] +fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult { + expr.try_map_expr(|e| match e { + Expr::Function { + input, + function: FunctionExpr::StructExpr(sf), + options, + } => { + if let StructFunction::FieldByIndex(index) = sf { + let dtype = input[0].to_field(schema, Context::Default)?.dtype; + let DataType::Struct(fields) = dtype else { + polars_bail!(InvalidOperation: "expected 'struct' dtype, got {:?}", dtype) + }; + let index = index.try_negative_to_usize(fields.len())?; + let name = fields[index].name.clone(); + Ok(Expr::Function { + input, + function: FunctionExpr::StructExpr(StructFunction::FieldByName(name)), + options, + }) + } else { + Ok(Expr::Function { + input, + function: FunctionExpr::StructExpr(sf), + options, + }) + } + }, + e => Ok(e), + }) +} + +/// This replaces the dtype or index expanded Expr with a Column Expr. +/// ()It also removes the Exclude Expr from the expression chain). +fn replace_dtype_or_index_with_column( + expr: Expr, + column_name: &PlSmallStr, + replace_dtype: bool, +) -> Expr { + expr.map_expr(|e| match e { + Expr::DtypeColumn(_) if replace_dtype => Expr::Column(column_name.clone()), + Expr::IndexColumn(_) if !replace_dtype => Expr::Column(column_name.clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) +} + +fn dtypes_match(d1: &DataType, d2: &DataType) -> bool { + match (d1, d2) { + // note: allow Datetime "*" wildcard for timezones... + (DataType::Datetime(tu_l, tz_l), DataType::Datetime(tu_r, tz_r)) => { + tu_l == tu_r + && (tz_l == tz_r + || tz_r.is_some() && (tz_l.as_deref().unwrap_or("") == "*") + || tz_l.is_some() && (tz_r.as_deref().unwrap_or("") == "*")) + }, + // ...but otherwise require exact match + _ => d1 == d2, + } +} + +/// replace `DtypeColumn` with `col("foo")..col("bar")` +fn expand_dtypes( + expr: &Expr, + result: &mut Vec, + schema: &Schema, + dtypes: &[DataType], + exclude: &PlHashSet, +) -> PolarsResult<()> { + // note: we loop over the schema to guarantee that we return a stable + // field-order, irrespective of which dtypes are filtered against + for field in schema.iter_fields().filter(|f| { + dtypes.iter().any(|dtype| dtypes_match(dtype, &f.dtype)) + && !exclude.contains(f.name().as_str()) + }) { + let name = field.name(); + let new_expr = expr.clone(); + let new_expr = replace_dtype_or_index_with_column(new_expr, name, true); + let new_expr = rewrite_special_aliases(new_expr)?; + result.push(new_expr) + } + Ok(()) +} + +#[cfg(feature = "dtype-struct")] +fn replace_struct_multiple_fields_with_field( + expr: Expr, + column_name: &PlSmallStr, +) -> PolarsResult { + let mut count = 0; + let out = expr.map_expr(|e| match e { + Expr::Function { + function, + input, + options, + } => { + if matches!( + function, + FunctionExpr::StructExpr(StructFunction::MultipleFields(_)) + ) { + count += 1; + Expr::Function { + input, + function: FunctionExpr::StructExpr(StructFunction::FieldByName( + column_name.clone(), + )), + options, + } + } else { + Expr::Function { + input, + function, + options, + } + } + }, + e => e, + }); + polars_ensure!(count == 1, InvalidOperation: "multiple expanding fields in a single struct not yet supported"); + Ok(out) +} + +#[cfg(feature = "dtype-struct")] +fn expand_struct_fields( + struct_expr: &Expr, + full_expr: &Expr, + result: &mut Vec, + schema: &Schema, + names: &[PlSmallStr], + exclude: &PlHashSet, +) -> PolarsResult<()> { + let first_name = names[0].as_ref(); + if names.len() == 1 && first_name == "*" || is_regex_projection(first_name) { + let Expr::Function { input, .. } = struct_expr else { + unreachable!() + }; + let field = input[0].to_field(schema, Context::Default)?; + let dtype = field.dtype(); + let DataType::Struct(fields) = dtype else { + if !dtype.is_known() { + let mut msg = String::from( + "expected 'struct' got an unknown data type + +This means there was an operation of which the output data type could not be determined statically. +Try setting the output data type for that operation.", + ); + for e in input[0].into_iter() { + #[allow(clippy::single_match)] + match e { + #[cfg(feature = "list_to_struct")] + Expr::Function { + input: _, + function, + options: _, + } => { + if matches!( + function, + FunctionExpr::ListExpr(ListFunction::ToStruct(..)) + ) { + msg.push_str( + " + +Hint: set 'upper_bound' for 'list.to_struct'.", + ); + } + }, + _ => {}, + } + } + + polars_bail!(InvalidOperation: msg) + } else { + polars_bail!(InvalidOperation: "expected 'struct' got {}", field.dtype()) + } + }; + + // Wildcard. + let names = if first_name == "*" { + fields + .iter() + .flat_map(|field| { + let name = field.name(); + + if exclude.contains(name.as_str()) { + None + } else { + Some(name.clone()) + } + }) + .collect::>() + } + // Regex + else { + #[cfg(feature = "regex")] + { + let re = polars_utils::regex_cache::compile_regex(first_name) + .map_err(|e| polars_err!(ComputeError: "invalid regex {}", e))?; + + fields + .iter() + .flat_map(|field| { + let name = field.name(); + if exclude.contains(name.as_str()) || !re.is_match(name.as_str()) { + None + } else { + Some(name.clone()) + } + }) + .collect::>() + } + #[cfg(not(feature = "regex"))] + { + panic!("activate 'regex' feature") + } + }; + + return expand_struct_fields( + struct_expr, + full_expr, + result, + schema, + names.as_slice(), + exclude, + ); + } + + for name in names { + polars_ensure!(name.as_str() != "*", InvalidOperation: "cannot combine wildcards and column names"); + + if !exclude.contains(name) { + let mut new_expr = replace_struct_multiple_fields_with_field(full_expr.clone(), name)?; + match new_expr { + Expr::KeepName(expr) => { + new_expr = Expr::Alias(expr, name.clone()); + }, + Expr::RenameAlias { expr, function } => { + let name = function.call(name)?; + new_expr = Expr::Alias(expr, name); + }, + _ => {}, + } + + result.push(new_expr) + } + } + Ok(()) +} + +/// replace `IndexColumn` with `col("foo")..col("bar")` +fn expand_indices( + expr: &Expr, + result: &mut Vec, + schema: &Schema, + indices: &[i64], + exclude: &PlHashSet, +) -> PolarsResult<()> { + let n_fields = schema.len() as i64; + for idx in indices { + let mut idx = *idx; + if idx < 0 { + idx += n_fields; + if idx < 0 { + polars_bail!(ComputeError: "invalid column index {}", idx) + } + } + if let Some((name, _)) = schema.get_at_index(idx as usize) { + if !exclude.contains(name.as_str()) { + let new_expr = expr.clone(); + let new_expr = replace_dtype_or_index_with_column(new_expr, name, false); + let new_expr = rewrite_special_aliases(new_expr)?; + result.push(new_expr); + } + } + } + Ok(()) +} + +// schema is not used if regex not activated +#[allow(unused_variables)] +fn prepare_excluded( + expr: &Expr, + schema: &Schema, + keys: &[Expr], + has_exclude: bool, +) -> PolarsResult> { + let mut exclude = PlHashSet::new(); + + // explicit exclude branch + if has_exclude { + for e in expr { + if let Expr::Exclude(_, to_exclude) = e { + #[cfg(feature = "regex")] + { + // instead of matching the names for regex patterns and + // expanding the matches in the schema we reuse the + // `replace_regex` func; this is a bit slower but DRY. + let mut buf = vec![]; + for to_exclude_single in to_exclude { + match to_exclude_single { + Excluded::Name(name) => { + let e = Expr::Column(name.clone()); + replace_regex(&e, &mut buf, schema, &Default::default())?; + // we cannot loop because of bchck + while let Some(col) = buf.pop() { + if let Expr::Column(name) = col { + exclude.insert(name); + } + } + }, + Excluded::Dtype(dt) => { + for fld in schema.iter_fields() { + if dtypes_match(fld.dtype(), dt) { + exclude.insert(fld.name.clone()); + } + } + }, + } + } + } + + #[cfg(not(feature = "regex"))] + { + for to_exclude_single in to_exclude { + match to_exclude_single { + Excluded::Name(name) => { + exclude.insert(name.clone()); + }, + Excluded::Dtype(dt) => { + for (name, dtype) in schema.iter() { + if matches!(dtype, dt) { + exclude.insert(name.clone()); + } + } + }, + } + } + } + } + } + } + + // exclude group_by keys + for expr in keys.iter() { + if let Ok(name) = expr_output_name(expr) { + exclude.insert(name.clone()); + } + } + Ok(exclude) +} + +// functions can have col(["a", "b"]) or col(String) as inputs +fn expand_function_inputs( + expr: Expr, + schema: &Schema, + opt_flags: &mut OptFlags, +) -> PolarsResult { + expr.try_map_expr(|mut e| match &mut e { + Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } + if options + .flags + .contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) => + { + *input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags)?; + if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) { + // Needed to visualize the error + *input = vec![Expr::Literal(LiteralValue::Scalar(Scalar::null( + DataType::Null, + )))]; + polars_bail!(InvalidOperation: "expected at least 1 input in {}", e) + } + Ok(e) + }, + _ => Ok(e), + }) +} + +#[derive(Copy, Clone, Debug)] +struct ExpansionFlags { + multiple_columns: bool, + has_nth: bool, + has_wildcard: bool, + has_selector: bool, + has_exclude: bool, + #[cfg(feature = "dtype-struct")] + expands_fields: bool, + #[cfg(feature = "dtype-struct")] + has_struct_field_by_index: bool, +} + +impl ExpansionFlags { + fn expands(&self) -> bool { + #[cfg(feature = "dtype-struct")] + let expands_fields = self.expands_fields; + #[cfg(not(feature = "dtype-struct"))] + let expands_fields = false; + + self.multiple_columns || expands_fields + } +} + +fn find_flags(expr: &Expr) -> PolarsResult { + let mut multiple_columns = false; + let mut has_nth = false; + let mut has_wildcard = false; + let mut has_selector = false; + let mut has_exclude = false; + #[cfg(feature = "dtype-struct")] + let mut has_struct_field_by_index = false; + #[cfg(feature = "dtype-struct")] + let mut expands_fields = false; + + // Do a single pass and collect all flags at once. + // Supertypes/modification that can be done in place are also done in that pass + for expr in expr { + match expr { + Expr::Columns(_) | Expr::DtypeColumn(_) => multiple_columns = true, + Expr::IndexColumn(idx) => multiple_columns = idx.len() > 1, + Expr::Nth(_) => has_nth = true, + Expr::Wildcard => has_wildcard = true, + Expr::Selector(_) => has_selector = true, + #[cfg(feature = "dtype-struct")] + Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::FieldByIndex(_)), + .. + } => { + has_struct_field_by_index = true; + }, + #[cfg(feature = "dtype-struct")] + Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::MultipleFields(_)), + .. + } => { + expands_fields = true; + }, + Expr::Exclude(_, _) => has_exclude = true, + #[cfg(feature = "dtype-struct")] + Expr::Field(_) => { + polars_bail!(InvalidOperation: "field expression not allowed at location/context") + }, + _ => {}, + } + } + Ok(ExpansionFlags { + multiple_columns, + has_nth, + has_wildcard, + has_selector, + has_exclude, + #[cfg(feature = "dtype-struct")] + has_struct_field_by_index, + #[cfg(feature = "dtype-struct")] + expands_fields, + }) +} + +#[cfg(feature = "dtype-struct")] +fn toggle_cse(opt_flags: &mut OptFlags) { + if opt_flags.contains(OptFlags::EAGER) && !opt_flags.contains(OptFlags::NEW_STREAMING) { + #[cfg(debug_assertions)] + { + use polars_core::config::verbose; + if verbose() { + eprintln!("CSE turned on because of struct expansion") + } + } + *opt_flags |= OptFlags::COMM_SUBEXPR_ELIM; + } +} + +/// In case of single col(*) -> do nothing, no selection is the same as select all +/// In other cases replace the wildcard with an expression with all columns +pub(crate) fn rewrite_projections( + exprs: Vec, + schema: &Schema, + keys: &[Expr], + opt_flags: &mut OptFlags, +) -> PolarsResult> { + let mut result = Vec::with_capacity(exprs.len() + schema.len()); + + for mut expr in exprs { + #[cfg(feature = "dtype-struct")] + let result_offset = result.len(); + + // Functions can have col(["a", "b"]) or col(String) as inputs. + expr = expand_function_inputs(expr, schema, opt_flags)?; + + let mut flags = find_flags(&expr)?; + if flags.has_selector { + expr = replace_selector(expr, schema, keys)?; + // the selector is replaced with Expr::Columns + flags.multiple_columns = true; + } + + replace_and_add_to_results(expr, flags, &mut result, schema, keys, opt_flags)?; + + #[cfg(feature = "dtype-struct")] + if flags.has_struct_field_by_index { + toggle_cse(opt_flags); + for e in &mut result[result_offset..] { + *e = struct_index_to_field(std::mem::take(e), schema)?; + } + } + } + Ok(result) +} + +fn replace_and_add_to_results( + mut expr: Expr, + flags: ExpansionFlags, + result: &mut Vec, + schema: &Schema, + keys: &[Expr], + opt_flags: &mut OptFlags, +) -> PolarsResult<()> { + if flags.has_nth { + expr = replace_nth(expr, schema); + } + + // has multiple column names + // the expanded columns are added to the result + if flags.expands() { + if let Some(e) = expr.into_iter().find(|e| match e { + Expr::Columns(_) | Expr::DtypeColumn(_) | Expr::IndexColumn(_) => true, + #[cfg(feature = "dtype-struct")] + Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::MultipleFields(_)), + .. + } => flags.expands_fields, + _ => false, + }) { + match &e { + Expr::Columns(names) => { + // Don't exclude grouping keys if columns are explicitly specified. + let exclude = prepare_excluded(&expr, schema, &[], flags.has_exclude)?; + expand_columns(&expr, result, names, schema, &exclude)?; + }, + Expr::DtypeColumn(dtypes) => { + let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?; + expand_dtypes(&expr, result, schema, dtypes, &exclude)? + }, + Expr::IndexColumn(indices) => { + let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?; + expand_indices(&expr, result, schema, indices, &exclude)? + }, + #[cfg(feature = "dtype-struct")] + Expr::Function { function, .. } => { + let FunctionExpr::StructExpr(StructFunction::MultipleFields(names)) = function + else { + unreachable!() + }; + let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?; + + // has both column and field expansion + // col('a', 'b').struct.field('*') + if flags.multiple_columns | flags.has_wildcard { + // First expand col('a', 'b') into an intermediate result. + let mut intermediate = vec![]; + let mut flags = flags; + flags.expands_fields = false; + replace_and_add_to_results( + expr.clone(), + flags, + &mut intermediate, + schema, + keys, + opt_flags, + )?; + + // Then expand the fields and add to the final result vec. + flags.expands_fields = true; + flags.multiple_columns = false; + flags.has_wildcard = false; + for e in intermediate { + replace_and_add_to_results(e, flags, result, schema, keys, opt_flags)?; + } + } + // has only field expansion + // col('a').struct.field('*') + else { + toggle_cse(opt_flags); + expand_struct_fields(e, &expr, result, schema, names, &exclude)? + } + }, + _ => {}, + } + } + } + // has multiple column names due to wildcards + else if flags.has_wildcard { + // keep track of column excluded from the wildcard + let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?; + // this path prepares the wildcard as input for the Function Expr + replace_wildcard(&expr, result, &exclude, schema)?; + } + // can have multiple column names due to a regex + else { + #[allow(clippy::collapsible_else_if)] + #[cfg(feature = "regex")] + { + // keep track of column excluded from the dtypes + let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?; + replace_regex(&expr, result, schema, &exclude)?; + } + #[cfg(not(feature = "regex"))] + { + let expr = rewrite_special_aliases(expr)?; + result.push(expr) + } + } + Ok(()) +} + +fn replace_selector_inner( + s: Selector, + members: &mut PlIndexSet, + scratch: &mut Vec, + schema: &Schema, + keys: &[Expr], +) -> PolarsResult<()> { + match s { + Selector::Root(expr) => { + let local_flags = find_flags(&expr)?; + replace_and_add_to_results( + *expr, + local_flags, + scratch, + schema, + keys, + &mut Default::default(), + )?; + members.extend(scratch.drain(..)) + }, + Selector::Add(lhs, rhs) => { + let mut tmp_members: PlIndexSet = Default::default(); + replace_selector_inner(*lhs, members, scratch, schema, keys)?; + replace_selector_inner(*rhs, &mut tmp_members, scratch, schema, keys)?; + members.extend(tmp_members) + }, + Selector::ExclusiveOr(lhs, rhs) => { + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; + + *members = tmp_members.symmetric_difference(members).cloned().collect(); + }, + Selector::Intersect(lhs, rhs) => { + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; + + *members = tmp_members.intersection(members).cloned().collect(); + }, + Selector::Sub(lhs, rhs) => { + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; + + *members = tmp_members.difference(members).cloned().collect(); + }, + } + Ok(()) +} + +fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult { + // First pass we replace the selectors with Expr::Columns, we expand the `to_add` columns + // and then subtract the `to_subtract` columns. + expr.try_map_expr(|e| match e { + Expr::Selector(mut s) => { + let mut swapped = Selector::Root(Box::new(Expr::Wildcard)); + std::mem::swap(&mut s, &mut swapped); + + let cols = expand_selector(swapped, schema, keys)?; + Ok(Expr::Columns(cols)) + }, + e => Ok(e), + }) +} + +pub(crate) fn expand_selectors( + s: Vec, + schema: &Schema, + keys: &[Expr], +) -> PolarsResult> { + let mut columns = vec![]; + + // Skip the column fast paths. + fn skip(name: &str) -> bool { + is_regex_projection(name) || name == "*" + } + + for s in s { + match s { + Selector::Root(e) => match *e { + Expr::Column(name) if !skip(name.as_ref()) => columns.push(name), + Expr::Columns(names) if names.iter().all(|n| !skip(n.as_ref())) => { + columns.extend_from_slice(names.as_ref()) + }, + Expr::Selector(s) => { + let names = expand_selector(s, schema, keys)?; + columns.extend_from_slice(names.as_ref()); + }, + e => { + let names = expand_selector(Selector::new(e), schema, keys)?; + columns.extend_from_slice(names.as_ref()); + }, + }, + other => { + let names = expand_selector(other, schema, keys)?; + columns.extend_from_slice(names.as_ref()); + }, + } + } + + Ok(Arc::from(columns)) +} + +pub(super) fn expand_selector( + s: Selector, + schema: &Schema, + keys: &[Expr], +) -> PolarsResult> { + let mut members = PlIndexSet::new(); + replace_selector_inner(s, &mut members, &mut vec![], schema, keys)?; + + if members.len() <= 1 { + members + .into_iter() + .map(|e| { + let Expr::Column(name) = e else { + polars_bail!(InvalidOperation: "invalid selector expression: {}", e) + }; + Ok(name) + }) + .collect() + } else { + // Ensure that multiple columns returned from combined/nested selectors remain in schema order + let selected = schema + .iter_fields() + .map(|field| field.name().clone()) + .filter(|field_name| members.contains(&Expr::Column(field_name.clone()))) + .collect(); + + Ok(selected) + } +} diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs new file mode 100644 index 000000000000..85d6f499a3fc --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -0,0 +1,342 @@ +use super::*; +use crate::plans::conversion::functions::convert_functions; + +pub fn to_expr_ir(expr: Expr, arena: &mut Arena) -> PolarsResult { + let mut state = ConversionContext::new(); + let node = to_aexpr_impl(expr, arena, &mut state)?; + Ok(ExprIR::new(node, state.output_name)) +} + +pub(super) fn to_expr_irs(input: Vec, arena: &mut Arena) -> PolarsResult> { + input.into_iter().map(|e| to_expr_ir(e, arena)).collect() +} + +pub fn to_expr_ir_ignore_alias(expr: Expr, arena: &mut Arena) -> PolarsResult { + let mut state = ConversionContext::new(); + state.ignore_alias = true; + let node = to_aexpr_impl_materialized_lit(expr, arena, &mut state)?; + Ok(ExprIR::new(node, state.output_name)) +} + +pub(super) fn to_expr_irs_ignore_alias( + input: Vec, + arena: &mut Arena, +) -> PolarsResult> { + input + .into_iter() + .map(|e| to_expr_ir_ignore_alias(e, arena)) + .collect() +} + +/// converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation +pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> PolarsResult { + to_aexpr_impl_materialized_lit( + expr, + arena, + &mut ConversionContext { + prune_alias: false, + ..Default::default() + }, + ) +} + +#[derive(Default)] +pub(super) struct ConversionContext { + pub(super) output_name: OutputName, + /// Remove alias from the expressions and set as [`OutputName`]. + pub(super) prune_alias: bool, + /// If an `alias` is encountered prune and ignore it. + pub(super) ignore_alias: bool, +} + +impl ConversionContext { + fn new() -> Self { + Self { + prune_alias: true, + ..Default::default() + } + } +} + +fn to_aexprs( + input: Vec, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult> { + input + .into_iter() + .map(|e| to_aexpr_impl_materialized_lit(e, arena, state)) + .collect() +} + +pub(super) fn set_function_output_name( + e: &[ExprIR], + state: &mut ConversionContext, + function_fmt: F, +) where + F: FnOnce() -> PlSmallStr, +{ + if state.output_name.is_none() { + if e.is_empty() { + let s = function_fmt(); + state.output_name = OutputName::LiteralLhs(s); + } else { + state.output_name = e[0].output_name_inner().clone(); + } + } +} + +fn to_aexpr_impl_materialized_lit( + expr: Expr, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult { + // Already convert `Lit Float and Lit Int` expressions that are not used in a binary / function expression. + // This means they can be materialized immediately + let e = match expr { + Expr::Literal(lv @ LiteralValue::Dyn(_)) => Expr::Literal(lv.materialize()), + Expr::Alias(inner, name) if matches!(&*inner, Expr::Literal(LiteralValue::Dyn(_))) => { + let Expr::Literal(lv) = &*inner else { + unreachable!() + }; + Expr::Alias(Arc::new(Expr::Literal(lv.clone().materialize())), name) + }, + e => e, + }; + to_aexpr_impl(e, arena, state) +} + +/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation. +#[recursive] +pub(super) fn to_aexpr_impl( + expr: Expr, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult { + let owned = Arc::unwrap_or_clone; + let v = match expr { + Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(owned(expr), arena, state)?), + Expr::Alias(e, name) => { + if state.prune_alias { + if state.output_name.is_none() && !state.ignore_alias { + state.output_name = OutputName::Alias(name); + } + let _ = to_aexpr_impl(owned(e), arena, state)?; + arena.pop().unwrap() + } else { + AExpr::Alias(to_aexpr_impl(owned(e), arena, state)?, name) + } + }, + Expr::Literal(lv) => { + if state.output_name.is_none() { + state.output_name = OutputName::LiteralLhs(lv.output_column_name().clone()); + } + AExpr::Literal(lv) + }, + Expr::Column(name) => { + if state.output_name.is_none() { + state.output_name = OutputName::ColumnLhs(name.clone()) + } + AExpr::Column(name) + }, + Expr::BinaryExpr { left, op, right } => { + let l = to_aexpr_impl(owned(left), arena, state)?; + let r = to_aexpr_impl(owned(right), arena, state)?; + AExpr::BinaryExpr { + left: l, + op, + right: r, + } + }, + Expr::Cast { + expr, + dtype, + options, + } => AExpr::Cast { + expr: to_aexpr_impl(owned(expr), arena, state)?, + dtype, + options, + }, + Expr::Gather { + expr, + idx, + returns_scalar, + } => AExpr::Gather { + expr: to_aexpr_impl(owned(expr), arena, state)?, + idx: to_aexpr_impl_materialized_lit(owned(idx), arena, state)?, + returns_scalar, + }, + Expr::Sort { expr, options } => AExpr::Sort { + expr: to_aexpr_impl(owned(expr), arena, state)?, + options, + }, + Expr::SortBy { + expr, + by, + sort_options, + } => AExpr::SortBy { + expr: to_aexpr_impl(owned(expr), arena, state)?, + by: by + .into_iter() + .map(|e| to_aexpr_impl(e, arena, state)) + .collect::>()?, + sort_options, + }, + Expr::Filter { input, by } => AExpr::Filter { + input: to_aexpr_impl(owned(input), arena, state)?, + by: to_aexpr_impl(owned(by), arena, state)?, + }, + Expr::Agg(agg) => { + let a_agg = match agg { + AggExpr::Min { + input, + propagate_nans, + } => IRAggExpr::Min { + input: to_aexpr_impl_materialized_lit(owned(input), arena, state)?, + propagate_nans, + }, + AggExpr::Max { + input, + propagate_nans, + } => IRAggExpr::Max { + input: to_aexpr_impl_materialized_lit(owned(input), arena, state)?, + propagate_nans, + }, + AggExpr::Median(expr) => { + IRAggExpr::Median(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::NUnique(expr) => { + IRAggExpr::NUnique(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::First(expr) => { + IRAggExpr::First(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::Last(expr) => { + IRAggExpr::Last(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::Mean(expr) => { + IRAggExpr::Mean(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::Implode(expr) => { + IRAggExpr::Implode(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::Count(expr, include_nulls) => IRAggExpr::Count( + to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, + include_nulls, + ), + AggExpr::Quantile { + expr, + quantile, + method, + } => IRAggExpr::Quantile { + expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, + quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?, + method, + }, + AggExpr::Sum(expr) => { + IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + AggExpr::Std(expr, ddof) => IRAggExpr::Std( + to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, + ddof, + ), + AggExpr::Var(expr, ddof) => IRAggExpr::Var( + to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, + ddof, + ), + AggExpr::AggGroups(expr) => { + IRAggExpr::AggGroups(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) + }, + }; + AExpr::Agg(a_agg) + }, + Expr::Ternary { + predicate, + truthy, + falsy, + } => { + // Truthy must be resolved first to get the lhs name first set. + let t = to_aexpr_impl(owned(truthy), arena, state)?; + let p = to_aexpr_impl_materialized_lit(owned(predicate), arena, state)?; + let f = to_aexpr_impl(owned(falsy), arena, state)?; + AExpr::Ternary { + predicate: p, + truthy: t, + falsy: f, + } + }, + Expr::AnonymousFunction { + input, + function, + output_type, + options, + } => { + let e = to_expr_irs(input, arena)?; + set_function_output_name(&e, state, || PlSmallStr::from_static(options.fmt_str)); + AExpr::AnonymousFunction { + input: e, + function, + output_type, + options, + } + }, + Expr::Function { + input, + function, + options, + } => return convert_functions(input, function, options, arena, state), + Expr::Window { + function, + partition_by, + order_by, + options, + } => { + // Process function first so name is correct. + let function = to_aexpr_impl(owned(function), arena, state)?; + let order_by = if let Some((e, options)) = order_by { + Some((to_aexpr_impl(owned(e.clone()), arena, state)?, options)) + } else { + None + }; + + AExpr::Window { + function, + partition_by: to_aexprs(partition_by, arena, state)?, + order_by, + options, + } + }, + Expr::Slice { + input, + offset, + length, + } => AExpr::Slice { + input: to_aexpr_impl(owned(input), arena, state)?, + offset: to_aexpr_impl_materialized_lit(owned(offset), arena, state)?, + length: to_aexpr_impl_materialized_lit(owned(length), arena, state)?, + }, + Expr::Len => { + if state.output_name.is_none() { + state.output_name = OutputName::LiteralLhs(get_len_name()) + } + AExpr::Len + }, + #[cfg(feature = "dtype-struct")] + e @ Expr::Field(_) => { + polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e) + }, + e @ Expr::IndexColumn(_) + | e @ Expr::Wildcard + | e @ Expr::Nth(_) + | e @ Expr::SubPlan { .. } + | e @ Expr::KeepName(_) + | e @ Expr::Exclude(_, _) + | e @ Expr::RenameAlias { .. } + | e @ Expr::Columns { .. } + | e @ Expr::DtypeColumn { .. } + | e @ Expr::Selector(_) => { + polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e) + }, + }; + Ok(arena.add(v)) +} diff --git a/crates/polars-plan/src/plans/conversion/functions.rs b/crates/polars-plan/src/plans/conversion/functions.rs new file mode 100644 index 000000000000..7d658b4237dc --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/functions.rs @@ -0,0 +1,103 @@ +use arrow::legacy::error::PolarsResult; +use polars_utils::arena::{Arena, Node}; +use polars_utils::format_pl_smallstr; + +use super::*; +use crate::dsl::{Expr, FunctionExpr}; +use crate::plans::AExpr; +use crate::prelude::FunctionOptions; + +pub(super) fn convert_functions( + input: Vec, + function: FunctionExpr, + mut options: FunctionOptions, + arena: &mut Arena, + ctx: &mut ConversionContext, +) -> PolarsResult { + use FunctionExpr as F; + + // Return before converting inputs + match function { + // This can be created by col(*).is_null() on empty dataframes. + F::Boolean(BooleanFunction::AllHorizontal) if input.is_empty() => { + return to_aexpr_impl(lit(true), arena, ctx); + }, + F::Boolean(BooleanFunction::AnyHorizontal) if input.is_empty() => { + return to_aexpr_impl(lit(false), arena, ctx); + }, + // Convert to binary expression as the optimizer understands those. + // Don't exceed 128 expressions as we might stackoverflow. + F::Boolean(BooleanFunction::AllHorizontal) => { + if input.len() < 128 { + let single = input.len() == 1; + let mut expr = input.into_iter().reduce(|l, r| l.logical_and(r)).unwrap(); + if single { + expr = expr.cast(DataType::Boolean) + } + return to_aexpr_impl(expr, arena, ctx); + } + }, + F::Boolean(BooleanFunction::AnyHorizontal) => { + if input.len() < 128 { + let single = input.len() == 1; + let mut expr = input.into_iter().reduce(|l, r| l.logical_or(r)).unwrap(); + if single { + expr = expr.cast(DataType::Boolean) + } + return to_aexpr_impl(expr, arena, ctx); + } + }, + _ => {}, + } + + // Converts inputs + let e = to_expr_irs(input, arena)?; + + match function { + #[cfg(feature = "diff")] + F::Diff(_) => { + polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value"); + }, + F::Repeat => { + polars_ensure!(&e[0].is_scalar(arena), ComputeError: "'value' must be scalar value"); + polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value"); + }, + #[cfg(feature = "replace")] + F::Replace | F::ReplaceStrict { .. } => { + let old = &e[1]; + let new = &e[1]; + + // if old is scalar and new is scalar -> elementwise + if old.is_scalar(arena) && new.is_scalar(arena) { + options.set_elementwise(); + } + }, + F::ShiftAndFill => { + polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value"); + polars_ensure!(&e[2].is_scalar(arena), ComputeError: "'fill_value' must be scalar value"); + }, + _ => {}, + } + + // Validate inputs. + if function == FunctionExpr::ShiftAndFill { + polars_ensure!(&e[1].is_scalar(arena), ComputeError: "'n' must be scalar value"); + polars_ensure!(&e[2].is_scalar(arena), ComputeError: "'fill_value' must be scalar value"); + } + + if ctx.output_name.is_none() { + // Handles special case functions like `struct.field`. + if let Some(name) = function.output_name() { + ctx.output_name = name + } else { + set_function_output_name(&e, ctx, || format_pl_smallstr!("{}", &function)); + } + } + + let ae_function = AExpr::Function { + input: e, + function, + options, + }; + Ok(arena.add(ae_function)) +} diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs new file mode 100644 index 000000000000..bee5cd3b8b4d --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -0,0 +1,261 @@ +use super::*; + +/// converts a node from the AExpr arena to Expr +#[recursive] +pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { + let expr = expr_arena.get(node).clone(); + + match expr { + AExpr::Explode(node) => Expr::Explode(Arc::new(node_to_expr(node, expr_arena))), + AExpr::Alias(expr, name) => { + let exp = node_to_expr(expr, expr_arena); + Expr::Alias(Arc::new(exp), name) + }, + AExpr::Column(a) => Expr::Column(a), + AExpr::Literal(s) => Expr::Literal(s), + AExpr::BinaryExpr { left, op, right } => { + let l = node_to_expr(left, expr_arena); + let r = node_to_expr(right, expr_arena); + Expr::BinaryExpr { + left: Arc::new(l), + op, + right: Arc::new(r), + } + }, + AExpr::Cast { + expr, + dtype, + options: strict, + } => { + let exp = node_to_expr(expr, expr_arena); + Expr::Cast { + expr: Arc::new(exp), + dtype, + options: strict, + } + }, + AExpr::Sort { expr, options } => { + let exp = node_to_expr(expr, expr_arena); + Expr::Sort { + expr: Arc::new(exp), + options, + } + }, + AExpr::Gather { + expr, + idx, + returns_scalar, + } => { + let expr = node_to_expr(expr, expr_arena); + let idx = node_to_expr(idx, expr_arena); + Expr::Gather { + expr: Arc::new(expr), + idx: Arc::new(idx), + returns_scalar, + } + }, + AExpr::SortBy { + expr, + by, + sort_options, + } => { + let expr = node_to_expr(expr, expr_arena); + let by = by + .iter() + .map(|node| node_to_expr(*node, expr_arena)) + .collect(); + Expr::SortBy { + expr: Arc::new(expr), + by, + sort_options, + } + }, + AExpr::Filter { input, by } => { + let input = node_to_expr(input, expr_arena); + let by = node_to_expr(by, expr_arena); + Expr::Filter { + input: Arc::new(input), + by: Arc::new(by), + } + }, + AExpr::Agg(agg) => match agg { + IRAggExpr::Min { + input, + propagate_nans, + } => { + let exp = node_to_expr(input, expr_arena); + AggExpr::Min { + input: Arc::new(exp), + propagate_nans, + } + .into() + }, + IRAggExpr::Max { + input, + propagate_nans, + } => { + let exp = node_to_expr(input, expr_arena); + AggExpr::Max { + input: Arc::new(exp), + propagate_nans, + } + .into() + }, + + IRAggExpr::Median(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Median(Arc::new(exp)).into() + }, + IRAggExpr::NUnique(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::NUnique(Arc::new(exp)).into() + }, + IRAggExpr::First(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::First(Arc::new(exp)).into() + }, + IRAggExpr::Last(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Last(Arc::new(exp)).into() + }, + IRAggExpr::Mean(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Mean(Arc::new(exp)).into() + }, + IRAggExpr::Implode(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Implode(Arc::new(exp)).into() + }, + IRAggExpr::Quantile { + expr, + quantile, + method, + } => { + let expr = node_to_expr(expr, expr_arena); + let quantile = node_to_expr(quantile, expr_arena); + AggExpr::Quantile { + expr: Arc::new(expr), + quantile: Arc::new(quantile), + method, + } + .into() + }, + IRAggExpr::Sum(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Sum(Arc::new(exp)).into() + }, + IRAggExpr::Std(expr, ddof) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Std(Arc::new(exp), ddof).into() + }, + IRAggExpr::Var(expr, ddof) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Var(Arc::new(exp), ddof).into() + }, + IRAggExpr::AggGroups(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::AggGroups(Arc::new(exp)).into() + }, + IRAggExpr::Count(expr, include_nulls) => { + let expr = node_to_expr(expr, expr_arena); + AggExpr::Count(Arc::new(expr), include_nulls).into() + }, + }, + AExpr::Ternary { + predicate, + truthy, + falsy, + } => { + let p = node_to_expr(predicate, expr_arena); + let t = node_to_expr(truthy, expr_arena); + let f = node_to_expr(falsy, expr_arena); + + Expr::Ternary { + predicate: Arc::new(p), + truthy: Arc::new(t), + falsy: Arc::new(f), + } + }, + AExpr::AnonymousFunction { + input, + function, + output_type, + options, + } => Expr::AnonymousFunction { + input: expr_irs_to_exprs(input, expr_arena), + function, + output_type, + options, + }, + AExpr::Function { + input, + function, + options, + } => Expr::Function { + input: expr_irs_to_exprs(input, expr_arena), + function, + options, + }, + AExpr::Window { + function, + partition_by, + order_by, + options, + } => { + let function = Arc::new(node_to_expr(function, expr_arena)); + let partition_by = nodes_to_exprs(&partition_by, expr_arena); + let order_by = + order_by.map(|(n, options)| (Arc::new(node_to_expr(n, expr_arena)), options)); + Expr::Window { + function, + partition_by, + order_by, + options, + } + }, + AExpr::Slice { + input, + offset, + length, + } => Expr::Slice { + input: Arc::new(node_to_expr(input, expr_arena)), + offset: Arc::new(node_to_expr(offset, expr_arena)), + length: Arc::new(node_to_expr(length, expr_arena)), + }, + AExpr::Len => Expr::Len, + } +} + +fn nodes_to_exprs(nodes: &[Node], expr_arena: &Arena) -> Vec { + nodes.iter().map(|n| node_to_expr(*n, expr_arena)).collect() +} + +pub fn node_to_lp_cloned( + node: Node, + expr_arena: &Arena, + mut lp_arena: &Arena, +) -> DslPlan { + // we borrow again mutably only to make the types happy + // we want to initialize `to_lp` from a mutable and a immutable lp_arena + // by borrowing an immutable mutably, we still are immutable down the line. + let alp = lp_arena.get(node).clone(); + alp.into_lp( + &|node, lp_arena: &mut &Arena| lp_arena.get(node).clone(), + &mut lp_arena, + expr_arena, + ) +} + +/// converts a node from the IR arena to a LogicalPlan +pub fn node_to_lp(node: Node, expr_arena: &Arena, lp_arena: &mut Arena) -> DslPlan { + let alp = lp_arena.get_mut(node); + let alp = std::mem::take(alp); + alp.into_lp( + &|node, lp_arena: &mut Arena| { + let lp = lp_arena.get_mut(node); + std::mem::take(lp) + }, + lp_arena, + expr_arena, + ) +} diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs new file mode 100644 index 000000000000..ca3281145ed3 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -0,0 +1,595 @@ +use arrow::legacy::error::PolarsResult; +use either::Either; +use polars_core::chunked_array::cast::CastOptions; +use polars_core::error::feature_gated; +use polars_core::utils::{get_numeric_upcast_supertype_lossless, try_get_supertype}; +use polars_utils::format_pl_smallstr; +use polars_utils::itertools::Itertools; + +use super::*; +use crate::constants::POLARS_TMP_PREFIX; +use crate::dsl::Expr; +#[cfg(feature = "iejoin")] +use crate::plans::AExpr; + +fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> { + for e in keys { + if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { + polars_bail!( + InvalidOperation: + "'alias' is not allowed in a join key, use 'with_columns' first", + ) + } + } + Ok(()) +} + +/// Returns: left: join_node, right: last_node (often both the same) +pub fn resolve_join( + input_left: Either, Node>, + input_right: Either, Node>, + left_on: Vec, + right_on: Vec, + predicates: Vec, + mut options: Arc, + ctxt: &mut DslConversionContext, +) -> PolarsResult<(Node, Node)> { + if !predicates.is_empty() { + feature_gated!("iejoin", { + debug_assert!(left_on.is_empty() && right_on.is_empty()); + return resolve_join_where( + input_left.unwrap_left(), + input_right.unwrap_left(), + predicates, + options, + ctxt, + ); + }) + } + + let owned = Arc::unwrap_or_clone; + let mut input_left = input_left.map_right(Ok).right_or_else(|input| { + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left))) + })?; + let mut input_right = input_right.map_right(Ok).right_or_else(|input| { + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right))) + })?; + + let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); + let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); + + if options.args.how.is_cross() { + polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); + } else { + polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates"); + check_join_keys(&left_on)?; + check_join_keys(&right_on)?; + + let mut turn_off_coalesce = false; + for e in left_on.iter().chain(right_on.iter()) { + // Any expression that is not a simple column expression will turn of coalescing. + turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_))); + } + if turn_off_coalesce { + let options = Arc::make_mut(&mut options); + if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) { + polars_warn!( + "coalescing join requested but not all join keys are column references, turning off key coalescing" + ); + } + options.args.coalesce = JoinCoalesce::KeepColumns; + } + + options.args.validation.is_valid_join(&options.args.how)?; + + #[cfg(feature = "asof_join")] + if let JoinType::AsOf(opt) = &options.args.how { + match (&opt.left_by, &opt.right_by) { + (None, None) => {}, + (Some(l), Some(r)) => { + polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'"); + validate_columns_in_input(l, &schema_left, "asof_join")?; + validate_columns_in_input(r, &schema_right, "asof_join")?; + }, + _ => { + polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'") + }, + } + } + + polars_ensure!( + left_on.len() == right_on.len(), + InvalidOperation: + format!( + "the number of columns given as join key (left: {}, right:{}) should be equal", + left_on.len(), + right_on.len() + ) + ); + } + + let mut left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?; + let mut right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?; + let mut joined_on = PlHashSet::new(); + + #[cfg(feature = "iejoin")] + let check = !matches!(options.args.how, JoinType::IEJoin); + #[cfg(not(feature = "iejoin"))] + let check = true; + if check { + for (l, r) in left_on.iter().zip(right_on.iter()) { + polars_ensure!( + joined_on.insert((l.output_name(), r.output_name())), + InvalidOperation: "joining with repeated key names; already joined on {} and {}", + l.output_name(), + r.output_name() + ) + } + } + drop(joined_on); + + ctxt.conversion_optimizer + .fill_scratch(&left_on, ctxt.expr_arena); + ctxt.conversion_optimizer + .optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_left) + .map_err(|e| e.context("'join' failed".into()))?; + ctxt.conversion_optimizer + .fill_scratch(&right_on, ctxt.expr_arena); + ctxt.conversion_optimizer + .optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_right) + .map_err(|e| e.context("'join' failed".into()))?; + + // Re-evaluate because of mutable borrows earlier. + let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); + let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); + + // # Resolve scalars + // + // Scalars need to be expanded. We translate them to temporary columns added with + // `with_columns` and remove them later with `project` + // This way the backends don't have to expand the literals in the join implementation + + let has_scalars = left_on + .iter() + .chain(right_on.iter()) + .any(|e| e.is_scalar(ctxt.expr_arena)); + + let (schema_left, schema_right) = if has_scalars { + let mut as_with_columns_l = vec![]; + let mut as_with_columns_r = vec![]; + for (i, e) in left_on.iter().enumerate() { + if e.is_scalar(ctxt.expr_arena) { + as_with_columns_l.push((i, e.clone())); + } + } + for (i, e) in right_on.iter().enumerate() { + if e.is_scalar(ctxt.expr_arena) { + as_with_columns_r.push((i, e.clone())); + } + } + + let mut count = 0; + let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}"); + + // Early clone because of bck. + let mut schema_right_new = if !as_with_columns_r.is_empty() { + (**schema_right).clone() + } else { + Default::default() + }; + if !as_with_columns_l.is_empty() { + let mut schema_left_new = (**schema_left).clone(); + + let mut exprs = Vec::with_capacity(as_with_columns_l.len()); + for (i, mut e) in as_with_columns_l { + let tmp_name = get_tmp_name(count); + count += 1; + e.set_alias(tmp_name.clone()); + let dtype = e.dtype(&schema_left_new, Context::Default, ctxt.expr_arena)?; + schema_left_new.with_column(tmp_name.clone(), dtype.clone()); + + let col = ctxt.expr_arena.add(AExpr::Column(tmp_name)); + left_on[i] = ExprIR::from_node(col, ctxt.expr_arena); + exprs.push(e); + } + input_left = ctxt.lp_arena.add(IR::HStack { + input: input_left, + exprs, + schema: Arc::new(schema_left_new), + options: ProjectionOptions::default(), + }) + } + if !as_with_columns_r.is_empty() { + let mut exprs = Vec::with_capacity(as_with_columns_r.len()); + for (i, mut e) in as_with_columns_r { + let tmp_name = get_tmp_name(count); + count += 1; + e.set_alias(tmp_name.clone()); + let dtype = e.dtype(&schema_right_new, Context::Default, ctxt.expr_arena)?; + schema_right_new.with_column(tmp_name.clone(), dtype.clone()); + + let col = ctxt.expr_arena.add(AExpr::Column(tmp_name)); + right_on[i] = ExprIR::from_node(col, ctxt.expr_arena); + exprs.push(e); + } + input_right = ctxt.lp_arena.add(IR::HStack { + input: input_right, + exprs, + schema: Arc::new(schema_right_new), + options: ProjectionOptions::default(), + }) + } + + ( + ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena), + ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena), + ) + } else { + (schema_left, schema_right) + }; + + // Not a closure to avoid borrow issues because we mutate expr_arena as well. + macro_rules! get_dtype { + ($expr:expr, $schema:expr) => { + ctxt.expr_arena + .get($expr.node()) + .get_type($schema, Context::Default, ctxt.expr_arena) + }; + } + + // # Cast lossless + // + // If we do a full join and keys are coalesced, the cast keys must be added up front. + let key_cols_coalesced = + options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full); + let mut as_with_columns_l = vec![]; + let mut as_with_columns_r = vec![]; + for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) { + let ltype = get_dtype!(lnode, &schema_left)?; + let rtype = get_dtype!(rnode, &schema_right)?; + + if let Some(dtype) = get_numeric_upcast_supertype_lossless(<ype, &rtype) { + // We use overflowing cast to allow better optimization as we are casting to a known + // lossless supertype. + // + // We have unique references to these nodes (they are created by this function), + // so we can mutate in-place without causing side effects somewhere else. + let casted_l = ctxt.expr_arena.add(AExpr::Cast { + expr: lnode.node(), + dtype: dtype.clone(), + options: CastOptions::Overflowing, + }); + let casted_r = ctxt.expr_arena.add(AExpr::Cast { + expr: rnode.node(), + dtype, + options: CastOptions::Overflowing, + }); + + if key_cols_coalesced { + let mut lnode = lnode.clone(); + let mut rnode = rnode.clone(); + + let ae_l = ctxt.expr_arena.get(lnode.node()); + let ae_r = ctxt.expr_arena.get(rnode.node()); + + polars_ensure!( + ae_l.is_col() && ae_r.is_col(), + SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions", + ); + + lnode.set_node(casted_l); + rnode.set_node(casted_r); + + as_with_columns_r.push(rnode); + as_with_columns_l.push(lnode); + } else { + lnode.set_node(casted_l); + rnode.set_node(casted_r); + } + } else { + polars_ensure!( + ltype == rtype, + SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right", + lnode.output_name(), ltype, rnode.output_name(), rtype + ) + } + } + + // Every expression must be elementwise so that we are + // guaranteed the keys for a join are all the same length. + + polars_ensure!( + all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena), + InvalidOperation: "all join key expressions must be elementwise." + ); + + // These are Arc, into_owned is free. + let schema_left = schema_left.into_owned(); + let schema_right = schema_right.into_owned(); + + let join_schema = det_join_schema( + &schema_left, + &schema_right, + &left_on, + &right_on, + &options, + ctxt.expr_arena, + ) + .map_err(|e| e.context(failed_here!(join schema resolving)))?; + + if key_cols_coalesced { + input_left = if as_with_columns_l.is_empty() { + input_left + } else { + ctxt.lp_arena.add(IR::HStack { + input: input_left, + exprs: as_with_columns_l, + schema: schema_left, + options: ProjectionOptions::default(), + }) + }; + + input_right = if as_with_columns_r.is_empty() { + input_right + } else { + ctxt.lp_arena.add(IR::HStack { + input: input_right, + exprs: as_with_columns_r, + schema: schema_right, + options: ProjectionOptions::default(), + }) + }; + } + + let ir = IR::Join { + input_left, + input_right, + schema: join_schema.clone(), + left_on, + right_on, + options, + }; + let join_node = ctxt.lp_arena.add(ir); + + if has_scalars { + let names = join_schema + .iter_names() + .filter_map(|n| { + if n.starts_with(POLARS_TMP_PREFIX) { + None + } else { + Some(n.clone()) + } + }) + .collect_vec(); + + let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena); + let ir = builder.project_simple(names).map(|b| b.build())?; + let select_node = ctxt.lp_arena.add(ir); + + Ok((select_node, join_node)) + } else { + Ok((join_node, join_node)) + } +} + +#[cfg(feature = "iejoin")] +impl From for Operator { + fn from(value: InequalityOperator) -> Self { + match value { + InequalityOperator::LtEq => Operator::LtEq, + InequalityOperator::Lt => Operator::Lt, + InequalityOperator::GtEq => Operator::GtEq, + InequalityOperator::Gt => Operator::Gt, + } + } +} + +#[cfg(feature = "iejoin")] +/// Returns: left: join_node, right: last_node (often both the same) +fn resolve_join_where( + input_left: Arc, + input_right: Arc, + predicates: Vec, + mut options: Arc, + ctxt: &mut DslConversionContext, +) -> PolarsResult<(Node, Node)> { + // If not eager, respect the flag. + if ctxt.opt_flags.eager() { + ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true); + } + ctxt.opt_flags.set(OptFlags::COLLAPSE_JOINS, true); + check_join_keys(&predicates)?; + let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt) + .map_err(|e| e.context(failed_here!(join left)))?; + let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt) + .map_err(|e| e.context(failed_here!(join left)))?; + + let schema_left = ctxt + .lp_arena + .get(input_left) + .schema(ctxt.lp_arena) + .into_owned(); + + let opts = Arc::make_mut(&mut options); + opts.args.how = JoinType::Cross; + + let (mut last_node, join_node) = resolve_join( + Either::Right(input_left), + Either::Right(input_right), + vec![], + vec![], + vec![], + options.clone(), + ctxt, + )?; + + let schema_merged = ctxt + .lp_arena + .get(last_node) + .schema(ctxt.lp_arena) + .into_owned(); + + // Perform predicate validation. + let mut upcast_exprs = Vec::<(Node, DataType)>::new(); + for e in predicates { + let arena = &mut ctxt.expr_arena; + let predicate = to_expr_ir_ignore_alias(e, arena)?; + let node = predicate.node(); + + // Ensure the predicate dtype output of the root node is Boolean + let ae = arena.get(node); + let dt_out = ae.to_dtype(&schema_merged, Context::Default, arena)?; + polars_ensure!( + dt_out == DataType::Boolean, + ComputeError: "'join_where' predicates must resolve to boolean" + ); + + ensure_lossless_binary_comparisons( + &node, + &schema_left, + &schema_merged, + arena, + &mut upcast_exprs, + )?; + + ctxt.conversion_optimizer + .push_scratch(predicate.node(), ctxt.expr_arena); + + let ir = IR::Filter { + input: last_node, + predicate, + }; + + last_node = ctxt.lp_arena.add(ir); + } + + ctxt.conversion_optimizer + .optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, last_node) + .map_err(|e| e.context("'join_where' failed".into()))?; + + Ok((last_node, join_node)) +} + +/// Locate nodes that are operands in a binary comparison involving both tables, and ensure that +/// these nodes are losslessly upcast to a safe dtype. +fn ensure_lossless_binary_comparisons( + node: &Node, + schema_left: &Schema, + schema_merged: &Schema, + expr_arena: &mut Arena, + upcast_exprs: &mut Vec<(Node, DataType)>, +) -> PolarsResult<()> { + // let mut upcast_exprs = Vec::<(Node, DataType)>::new(); + // Ensure that all binary comparisons that use both tables are lossless. + build_upcast_node_list(node, schema_left, schema_merged, expr_arena, upcast_exprs)?; + // Replace each node with its casted counterpart + for (expr, dtype) in upcast_exprs.drain(..) { + let old_expr = expr_arena.duplicate(expr); + let new_aexpr = AExpr::Cast { + expr: old_expr, + dtype, + options: CastOptions::Overflowing, + }; + expr_arena.replace(expr, new_aexpr); + } + Ok(()) +} + +/// If we are dealing with a binary comparison involving columns from exclusively the left table +/// on the LHS and the right table on the RHS side, ensure that the cast is lossless. +/// Expressions involving binaries using either table alone we leave up to the user to verify +/// that they are valid, as they could theoretically be pushed outside of the join. +#[recursive] +fn build_upcast_node_list( + node: &Node, + schema_left: &Schema, + schema_merged: &Schema, + expr_arena: &Arena, + to_replace: &mut Vec<(Node, DataType)>, +) -> PolarsResult { + let expr_origin = match expr_arena.get(*node) { + AExpr::Column(name) => { + if schema_left.contains(name) { + ExprOrigin::Left + } else if schema_merged.contains(name) { + ExprOrigin::Right + } else { + polars_bail!(ColumnNotFound: "{}", name); + } + }, + AExpr::Literal(..) => ExprOrigin::None, + AExpr::Cast { expr: node, .. } => { + build_upcast_node_list(node, schema_left, schema_merged, expr_arena, to_replace)? + }, + AExpr::BinaryExpr { + left: left_node, + op, + right: right_node, + } => { + // If left and right node has both, ensure the dtypes are valid. + let left_origin = build_upcast_node_list( + left_node, + schema_left, + schema_merged, + expr_arena, + to_replace, + )?; + let right_origin = build_upcast_node_list( + right_node, + schema_left, + schema_merged, + expr_arena, + to_replace, + )?; + // We only update casts during comparisons if the operands are from different tables. + if op.is_comparison() { + match (left_origin, right_origin) { + (ExprOrigin::Left, ExprOrigin::Right) + | (ExprOrigin::Right, ExprOrigin::Left) => { + // Ensure our dtype casts are lossless + let left = expr_arena.get(*left_node); + let right = expr_arena.get(*right_node); + let dtype_left = + left.to_dtype(schema_merged, Context::Default, expr_arena)?; + let dtype_right = + right.to_dtype(schema_merged, Context::Default, expr_arena)?; + if dtype_left != dtype_right { + // Ensure that we have a lossless cast between the two types. + let dt = if dtype_left.is_primitive_numeric() + || dtype_right.is_primitive_numeric() + { + get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right) + .ok_or(PolarsError::SchemaMismatch( + format!( + "'join_where' cannot compare {:?} with {:?}", + dtype_left, dtype_right + ) + .into(), + )) + } else { + try_get_supertype(&dtype_left, &dtype_right) + }?; + + // Store the nodes and their replacements if a cast is required. + let replace_left = dt != dtype_left; + let replace_right = dt != dtype_right; + if replace_left && replace_right { + to_replace.push((*left_node, dt.clone())); + to_replace.push((*right_node, dt)); + } else if replace_left { + to_replace.push((*left_node, dt)); + } else if replace_right { + to_replace.push((*right_node, dt)); + } + } + }, + _ => (), + } + } + left_origin | right_origin + }, + _ => ExprOrigin::None, + }; + Ok(expr_origin) +} diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs new file mode 100644 index 000000000000..3725c8a1e031 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -0,0 +1,357 @@ +mod convert_utils; +mod dsl_to_ir; +mod expr_expansion; +mod expr_to_ir; +mod ir_to_dsl; +#[cfg(any( + feature = "ipc", + feature = "parquet", + feature = "csv", + feature = "json" +))] +mod scans; +mod stack_opt; + +use std::borrow::Cow; +use std::sync::{Arc, Mutex}; + +pub use dsl_to_ir::*; +pub use expr_to_ir::*; +pub use ir_to_dsl::*; +use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; +use polars_utils::vec::ConvertVec; +use recursive::recursive; +#[cfg(any( + feature = "ipc", + feature = "parquet", + feature = "csv", + feature = "json" +))] +pub use scans::*; +mod functions; +mod join; +pub(crate) mod type_check; +pub(crate) mod type_coercion; + +pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection}; +pub(crate) use stack_opt::ConversionOptimizer; + +use crate::constants::get_len_name; +use crate::prelude::*; + +fn expr_irs_to_exprs(expr_irs: Vec, expr_arena: &Arena) -> Vec { + expr_irs.convert_owned(|e| e.to_expr(expr_arena)) +} + +impl IR { + #[recursive] + fn into_lp( + self, + conversion_fn: &F, + lp_arena: &mut LPA, + expr_arena: &Arena, + ) -> DslPlan + where + F: Fn(Node, &mut LPA) -> IR, + { + let lp = self; + let convert_to_lp = |node: Node, lp_arena: &mut LPA| { + conversion_fn(node, lp_arena).into_lp(conversion_fn, lp_arena, expr_arena) + }; + match lp { + ir @ IR::Scan { .. } => { + let IR::Scan { + sources, + file_info, + hive_parts: _, + predicate: _, + scan_type, + output_schema: _, + unified_scan_args, + } = ir.clone() + else { + unreachable!() + }; + + DslPlan::Scan { + sources, + file_info: Some(file_info), + scan_type, + unified_scan_args, + cached_ir: Arc::new(Mutex::new(Some(ir))), + } + }, + #[cfg(feature = "python")] + IR::PythonScan { .. } => DslPlan::PythonScan { + options: Default::default(), + }, + IR::Union { inputs, .. } => { + let inputs = inputs + .into_iter() + .map(|node| convert_to_lp(node, lp_arena)) + .collect(); + DslPlan::Union { + inputs, + args: Default::default(), + } + }, + IR::HConcat { + inputs, + schema: _, + options, + } => { + let inputs = inputs + .into_iter() + .map(|node| convert_to_lp(node, lp_arena)) + .collect(); + DslPlan::HConcat { inputs, options } + }, + IR::Slice { input, offset, len } => { + let lp = convert_to_lp(input, lp_arena); + DslPlan::Slice { + input: Arc::new(lp), + offset, + len, + } + }, + IR::Filter { input, predicate } => { + let lp = convert_to_lp(input, lp_arena); + let predicate = predicate.to_expr(expr_arena); + DslPlan::Filter { + input: Arc::new(lp), + predicate, + } + }, + IR::DataFrameScan { + df, + schema, + output_schema: _, + } => DslPlan::DataFrameScan { df, schema }, + IR::Select { + expr, + input, + schema: _, + options, + } => { + let i = convert_to_lp(input, lp_arena); + let expr = expr_irs_to_exprs(expr, expr_arena); + DslPlan::Select { + expr, + input: Arc::new(i), + options, + } + }, + IR::SimpleProjection { input, columns } => { + let input = convert_to_lp(input, lp_arena); + let expr = columns + .iter_names() + .map(|name| Expr::Column(name.clone())) + .collect::>(); + DslPlan::Select { + expr, + input: Arc::new(input), + options: Default::default(), + } + }, + IR::Sort { + input, + by_column, + slice, + sort_options, + } => { + let input = Arc::new(convert_to_lp(input, lp_arena)); + let by_column = expr_irs_to_exprs(by_column, expr_arena); + DslPlan::Sort { + input, + by_column, + slice, + sort_options, + } + }, + IR::Cache { + input, + id, + cache_hits: _, + } => { + let input = Arc::new(convert_to_lp(input, lp_arena)); + DslPlan::Cache { input, id } + }, + IR::GroupBy { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options: dynamic_options, + } => { + let i = convert_to_lp(input, lp_arena); + let keys = expr_irs_to_exprs(keys, expr_arena); + let aggs = expr_irs_to_exprs(aggs, expr_arena); + + DslPlan::GroupBy { + input: Arc::new(i), + keys, + aggs, + apply: apply.map(|apply| (apply, schema)), + maintain_order, + options: dynamic_options, + } + }, + IR::Join { + input_left, + input_right, + schema: _, + left_on, + right_on, + options, + } => { + let i_l = convert_to_lp(input_left, lp_arena); + let i_r = convert_to_lp(input_right, lp_arena); + + let left_on = expr_irs_to_exprs(left_on, expr_arena); + let right_on = expr_irs_to_exprs(right_on, expr_arena); + + DslPlan::Join { + input_left: Arc::new(i_l), + input_right: Arc::new(i_r), + predicates: Default::default(), + left_on, + right_on, + options, + } + }, + IR::HStack { + input, + exprs, + options, + .. + } => { + let i = convert_to_lp(input, lp_arena); + let exprs = expr_irs_to_exprs(exprs, expr_arena); + + DslPlan::HStack { + input: Arc::new(i), + exprs, + options, + } + }, + IR::Distinct { input, options } => { + let i = convert_to_lp(input, lp_arena); + let options = DistinctOptionsDSL { + subset: options.subset.map(|s| { + s.iter() + .map(|name| Expr::Column(name.clone()).into()) + .collect() + }), + maintain_order: options.maintain_order, + keep_strategy: options.keep_strategy, + }; + DslPlan::Distinct { + input: Arc::new(i), + options, + } + }, + IR::MapFunction { input, function } => { + let input = Arc::new(convert_to_lp(input, lp_arena)); + DslPlan::MapFunction { + input, + function: function.into(), + } + }, + IR::ExtContext { + input, contexts, .. + } => { + let input = Arc::new(convert_to_lp(input, lp_arena)); + let contexts = contexts + .into_iter() + .map(|node| convert_to_lp(node, lp_arena)) + .collect(); + DslPlan::ExtContext { input, contexts } + }, + IR::Sink { input, payload } => { + let input = Arc::new(convert_to_lp(input, lp_arena)); + let payload = match payload { + SinkTypeIR::Memory => SinkType::Memory, + SinkTypeIR::File(f) => SinkType::File(f), + SinkTypeIR::Partition(f) => SinkType::Partition(PartitionSinkType { + base_path: f.base_path, + file_path_cb: f.file_path_cb, + file_type: f.file_type, + sink_options: f.sink_options, + variant: match f.variant { + PartitionVariantIR::MaxSize(max_size) => { + PartitionVariant::MaxSize(max_size) + }, + PartitionVariantIR::Parted { + key_exprs, + include_key, + } => PartitionVariant::Parted { + key_exprs: expr_irs_to_exprs(key_exprs, expr_arena), + include_key, + }, + PartitionVariantIR::ByKey { + key_exprs, + include_key, + } => PartitionVariant::ByKey { + key_exprs: expr_irs_to_exprs(key_exprs, expr_arena), + include_key, + }, + }, + cloud_options: f.cloud_options, + }), + }; + DslPlan::Sink { input, payload } + }, + IR::SinkMultiple { inputs } => { + let inputs = inputs + .into_iter() + .map(|node| convert_to_lp(node, lp_arena)) + .collect(); + DslPlan::SinkMultiple { inputs } + }, + #[cfg(feature = "merge_sorted")] + IR::MergeSorted { + input_left, + input_right, + key, + } => { + let input_left = Arc::new(convert_to_lp(input_left, lp_arena)); + let input_right = Arc::new(convert_to_lp(input_right, lp_arena)); + + DslPlan::MergeSorted { + input_left, + input_right, + key, + } + }, + IR::Invalid => unreachable!(), + } + } +} + +fn get_input(lp_arena: &Arena, lp_node: Node) -> UnitVec { + let plan = lp_arena.get(lp_node); + let mut inputs: UnitVec = unitvec!(); + + // Used to get the schema of the input. + if is_scan(plan) { + inputs.push(lp_node); + } else { + plan.copy_inputs(&mut inputs); + }; + inputs +} + +fn get_schema(lp_arena: &Arena, lp_node: Node) -> Cow<'_, SchemaRef> { + let inputs = get_input(lp_arena, lp_node); + if inputs.is_empty() { + // Files don't have an input, so we must take their schema. + Cow::Borrowed(lp_arena.get(lp_node).scan_schema()) + } else { + let input = inputs[0]; + lp_arena.get(input).schema(lp_arena) + } +} diff --git a/crates/polars-plan/src/plans/conversion/scans.rs b/crates/polars-plan/src/plans/conversion/scans.rs new file mode 100644 index 000000000000..c0f8f21d6ab3 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/scans.rs @@ -0,0 +1,373 @@ +use either::Either; +use polars_io::RowIndex; +use polars_io::path_utils::is_cloud_url; +#[cfg(feature = "cloud")] +use polars_io::pl_async::get_runtime; +use polars_io::prelude::*; +use polars_io::utils::compression::maybe_decompress_bytes; + +use super::*; + +#[cfg(any(feature = "parquet", feature = "ipc"))] +fn prepare_output_schema(mut schema: Schema, row_index: Option<&RowIndex>) -> SchemaRef { + if let Some(rc) = row_index { + let _ = schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE); + } + Arc::new(schema) +} + +#[cfg(any(feature = "json", feature = "csv"))] +fn prepare_schemas(mut schema: Schema, row_index: Option<&RowIndex>) -> (SchemaRef, SchemaRef) { + if let Some(rc) = row_index { + let reader_schema = schema.clone(); + let _ = schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE); + (Arc::new(reader_schema), Arc::new(schema)) + } else { + let schema = Arc::new(schema); + (schema.clone(), schema) + } +} + +#[cfg(feature = "parquet")] +pub(super) fn parquet_file_info( + sources: &ScanSources, + row_index: Option<&RowIndex>, + #[allow(unused)] cloud_options: Option<&polars_io::cloud::CloudOptions>, +) -> PolarsResult<(FileInfo, Option)> { + use polars_core::error::feature_gated; + + let (reader_schema, num_rows, metadata) = { + if sources.is_cloud_url() { + let first_path = &sources.as_paths().unwrap()[0]; + feature_gated!("cloud", { + let uri = first_path.to_string_lossy(); + get_runtime().block_in_place_on(async { + let mut reader = + ParquetObjectStore::from_uri(&uri, cloud_options, None).await?; + + PolarsResult::Ok(( + reader.schema().await?, + Some(reader.num_rows().await?), + Some(reader.get_metadata().await?.clone()), + )) + })? + }) + } else { + let first_source = sources + .first() + .ok_or_else(|| polars_err!(ComputeError: "expected at least 1 source"))?; + let memslice = first_source.to_memslice()?; + let mut reader = ParquetReader::new(std::io::Cursor::new(memslice)); + ( + reader.schema()?, + Some(reader.num_rows()?), + Some(reader.get_metadata()?.clone()), + ) + } + }; + + let schema = + prepare_output_schema(Schema::from_arrow_schema(reader_schema.as_ref()), row_index); + + let file_info = FileInfo::new( + schema, + Some(Either::Left(reader_schema)), + (num_rows, num_rows.unwrap_or(0)), + ); + + Ok((file_info, metadata)) +} + +// TODO! return metadata arced +#[cfg(feature = "ipc")] +pub(super) fn ipc_file_info( + sources: &ScanSources, + row_index: Option<&RowIndex>, + cloud_options: Option<&polars_io::cloud::CloudOptions>, +) -> PolarsResult<(FileInfo, arrow::io::ipc::read::FileMetadata)> { + use polars_core::error::feature_gated; + + let Some(first) = sources.first() else { + polars_bail!(ComputeError: "expected at least 1 source"); + }; + + let metadata = match first { + ScanSourceRef::Path(path) => { + if is_cloud_url(path) { + feature_gated!("cloud", { + let uri = path.to_string_lossy(); + get_runtime().block_on(async { + polars_io::ipc::IpcReaderAsync::from_uri(&uri, cloud_options) + .await? + .metadata() + .await + })? + }) + } else { + arrow::io::ipc::read::read_file_metadata(&mut std::io::BufReader::new( + polars_utils::open_file(path)?, + ))? + } + }, + ScanSourceRef::File(file) => { + arrow::io::ipc::read::read_file_metadata(&mut std::io::BufReader::new(file))? + }, + ScanSourceRef::Buffer(buff) => { + arrow::io::ipc::read::read_file_metadata(&mut std::io::Cursor::new(buff))? + }, + }; + + let file_info = FileInfo::new( + prepare_output_schema( + Schema::from_arrow_schema(metadata.schema.as_ref()), + row_index, + ), + Some(Either::Left(Arc::clone(&metadata.schema))), + (None, 0), + ); + + Ok((file_info, metadata)) +} + +#[cfg(feature = "csv")] +pub fn isolated_csv_file_info( + source: ScanSourceRef, + row_index: Option<&RowIndex>, + csv_options: &mut CsvReadOptions, + _cloud_options: Option<&polars_io::cloud::CloudOptions>, +) -> PolarsResult { + use std::io::{Read, Seek}; + + use polars_io::csv::read::schema_inference::SchemaInferenceResult; + use polars_io::utils::get_reader_bytes; + + let run_async = source.run_async(); + + let memslice = source.to_memslice_async_assume_latest(run_async)?; + let owned = &mut vec![]; + let mut reader = std::io::Cursor::new(maybe_decompress_bytes(&memslice, owned)?); + if reader.read(&mut [0; 4])? < 2 && csv_options.raise_if_empty { + polars_bail!(NoData: "empty CSV") + } + reader.rewind()?; + + let reader_bytes = get_reader_bytes(&mut reader).expect("could not mmap file"); + + // this needs a way to estimated bytes/rows. + let si_result = + SchemaInferenceResult::try_from_reader_bytes_and_options(&reader_bytes, csv_options)?; + + csv_options.update_with_inference_result(&si_result); + + let mut schema = csv_options + .schema + .clone() + .unwrap_or_else(|| si_result.get_inferred_schema()); + + let reader_schema = if let Some(rc) = row_index { + let reader_schema = schema.clone(); + let mut output_schema = (*reader_schema).clone(); + output_schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE)?; + schema = Arc::new(output_schema); + reader_schema + } else { + schema.clone() + }; + + let estimated_n_rows = si_result.get_estimated_n_rows(); + + Ok(FileInfo::new( + schema, + Some(Either::Right(reader_schema)), + (None, estimated_n_rows), + )) +} + +#[cfg(feature = "csv")] +pub fn csv_file_info( + sources: &ScanSources, + row_index: Option<&RowIndex>, + csv_options: &mut CsvReadOptions, + cloud_options: Option<&polars_io::cloud::CloudOptions>, +) -> PolarsResult { + use std::io::{Read, Seek}; + + use polars_core::error::feature_gated; + use polars_core::{POOL, config}; + use polars_io::csv::read::schema_inference::SchemaInferenceResult; + use polars_io::utils::get_reader_bytes; + use rayon::iter::{IntoParallelIterator, ParallelIterator}; + + polars_ensure!(!sources.is_empty(), ComputeError: "expected at least 1 source"); + + // TODO: + // * See if we can do better than scanning all files if there is a row limit + // * See if we can do this without downloading the entire file + + // prints the error message if paths is empty. + let run_async = sources.is_cloud_url() || (sources.is_paths() && config::force_async()); + + let cache_entries = { + if run_async { + feature_gated!("cloud", { + Some(polars_io::file_cache::init_entries_from_uri_list( + sources + .as_paths() + .unwrap() + .iter() + .map(|path| Arc::from(path.to_str().unwrap())) + .collect::>() + .as_slice(), + cloud_options, + )?) + }) + } else { + None + } + }; + + let infer_schema_func = |i| { + let source = sources.at(i); + let memslice = source.to_memslice_possibly_async(run_async, cache_entries.as_ref(), i)?; + let owned = &mut vec![]; + let mut reader = std::io::Cursor::new(maybe_decompress_bytes(&memslice, owned)?); + if reader.read(&mut [0; 4])? < 2 && csv_options.raise_if_empty { + polars_bail!(NoData: "empty CSV") + } + reader.rewind()?; + + let reader_bytes = get_reader_bytes(&mut reader).expect("could not mmap file"); + + // this needs a way to estimated bytes/rows. + SchemaInferenceResult::try_from_reader_bytes_and_options(&reader_bytes, csv_options) + }; + + let merge_func = |a: PolarsResult, + b: PolarsResult| { + match (a, b) { + (Err(e), _) | (_, Err(e)) => Err(e), + (Ok(a), Ok(b)) => { + let merged_schema = if csv_options.schema.is_some() { + csv_options.schema.clone().unwrap() + } else { + let schema_a = a.get_inferred_schema(); + let schema_b = b.get_inferred_schema(); + + match (schema_a.is_empty(), schema_b.is_empty()) { + (true, _) => schema_b, + (_, true) => schema_a, + _ => { + let mut s = Arc::unwrap_or_clone(schema_a); + s.to_supertype(&schema_b)?; + Arc::new(s) + }, + } + }; + + Ok(a.with_inferred_schema(merged_schema)) + }, + } + }; + + let si_results = POOL.join( + || infer_schema_func(0), + || { + (1..sources.len()) + .into_par_iter() + .map(infer_schema_func) + .reduce(|| Ok(Default::default()), merge_func) + }, + ); + + let si_result = merge_func(si_results.0, si_results.1)?; + + csv_options.update_with_inference_result(&si_result); + + let mut schema = csv_options + .schema + .clone() + .unwrap_or_else(|| si_result.get_inferred_schema()); + + let reader_schema = if let Some(rc) = row_index { + let reader_schema = schema.clone(); + let mut output_schema = (*reader_schema).clone(); + output_schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE)?; + schema = Arc::new(output_schema); + reader_schema + } else { + schema.clone() + }; + + let estimated_n_rows = si_result.get_estimated_n_rows(); + + Ok(FileInfo::new( + schema, + Some(Either::Right(reader_schema)), + (None, estimated_n_rows), + )) +} + +#[cfg(feature = "json")] +pub fn ndjson_file_info( + sources: &ScanSources, + row_index: Option<&RowIndex>, + ndjson_options: &NDJsonReadOptions, + cloud_options: Option<&polars_io::cloud::CloudOptions>, +) -> PolarsResult { + use polars_core::config; + use polars_core::error::feature_gated; + + let Some(first) = sources.first() else { + polars_bail!(ComputeError: "expected at least 1 source"); + }; + + let run_async = sources.is_cloud_url() || (sources.is_paths() && config::force_async()); + + let cache_entries = { + if run_async { + feature_gated!("cloud", { + Some(polars_io::file_cache::init_entries_from_uri_list( + sources + .as_paths() + .unwrap() + .iter() + .map(|path| Arc::from(path.to_str().unwrap())) + .collect::>() + .as_slice(), + cloud_options, + )?) + }) + } else { + None + } + }; + + let owned = &mut vec![]; + + let (mut reader_schema, schema) = if let Some(schema) = ndjson_options.schema.clone() { + if row_index.is_none() { + (schema.clone(), schema.clone()) + } else { + prepare_schemas(Arc::unwrap_or_clone(schema), row_index) + } + } else { + let memslice = first.to_memslice_possibly_async(run_async, cache_entries.as_ref(), 0)?; + let mut reader = std::io::Cursor::new(maybe_decompress_bytes(&memslice, owned)?); + + let schema = + polars_io::ndjson::infer_schema(&mut reader, ndjson_options.infer_schema_length)?; + + prepare_schemas(schema, row_index) + }; + + if let Some(overwriting_schema) = &ndjson_options.schema_overwrite { + let schema = Arc::make_mut(&mut reader_schema); + overwrite_schema(schema, overwriting_schema)?; + } + + Ok(FileInfo::new( + schema, + Some(Either::Right(reader_schema)), + (None, usize::MAX), + )) +} diff --git a/crates/polars-plan/src/plans/conversion/stack_opt.rs b/crates/polars-plan/src/plans/conversion/stack_opt.rs new file mode 100644 index 000000000000..e04bdf856dfe --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/stack_opt.rs @@ -0,0 +1,110 @@ +use std::borrow::Borrow; + +use self::type_check::TypeCheckRule; +use super::*; + +/// Applies expression simplification and type coercion during conversion to IR. +pub struct ConversionOptimizer { + scratch: Vec, + + simplify: Option, + coerce: Option, + check: Option, + // IR's can be cached in the DSL. + // But if they are used multiple times in DSL (e.g. concat/join) + // then it can occur that we take a slot multiple times. + // So we keep track of the arena versions used and allow only + // one unique IR cache to be reused. + pub(super) used_arenas: PlHashSet, +} + +impl ConversionOptimizer { + pub fn new(simplify: bool, type_coercion: bool, type_check: bool) -> Self { + let simplify = if simplify { + Some(SimplifyExprRule {}) + } else { + None + }; + + let coerce = if type_coercion { + Some(TypeCoercionRule {}) + } else { + None + }; + + let check = if type_check { + Some(TypeCheckRule) + } else { + None + }; + + ConversionOptimizer { + scratch: Vec::with_capacity(8), + simplify, + coerce, + check, + used_arenas: Default::default(), + } + } + + pub fn push_scratch(&mut self, expr: Node, expr_arena: &Arena) { + self.scratch.push(expr); + // traverse all subexpressions and add to the stack + let expr = unsafe { expr_arena.get_unchecked(expr) }; + expr.inputs_rev(&mut self.scratch); + } + + pub fn fill_scratch>(&mut self, exprs: &[N], expr_arena: &Arena) { + for e in exprs { + let node = *e.borrow(); + self.push_scratch(node, expr_arena); + } + } + + /// Optimizes the expressions in the scratch space. This should be called after filling the + /// scratch space with the expressions that you want to optimize. + pub fn optimize_exprs( + &mut self, + expr_arena: &mut Arena, + ir_arena: &mut Arena, + current_ir_node: Node, + ) -> PolarsResult<()> { + // Different from the stack-opt in the optimizer phase, this does a single pass until fixed point per expression. + + if let Some(rule) = &mut self.check { + while let Some(x) = rule.optimize_plan(ir_arena, expr_arena, current_ir_node)? { + ir_arena.replace(current_ir_node, x); + } + } + + // process the expressions on the stack and apply optimizations. + while let Some(current_expr_node) = self.scratch.pop() { + let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; + + if expr.is_leaf() { + continue; + } + + if let Some(rule) = &mut self.simplify { + while let Some(x) = + rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_ir_node)? + { + expr_arena.replace(current_expr_node, x); + } + } + if let Some(rule) = &mut self.coerce { + while let Some(x) = + rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_ir_node)? + { + expr_arena.replace(current_expr_node, x); + } + } + + let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; + // traverse subexpressions and add to the stack + expr.inputs_rev(&mut self.scratch) + } + + Ok(()) + } +} diff --git a/crates/polars-plan/src/plans/conversion/type_check/mod.rs b/crates/polars-plan/src/plans/conversion/type_check/mod.rs new file mode 100644 index 000000000000..cb36d2709e08 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_check/mod.rs @@ -0,0 +1,48 @@ +use polars_core::error::{PolarsResult, polars_ensure}; +use polars_core::prelude::DataType; +use polars_utils::arena::{Arena, Node}; + +use super::{AExpr, IR, OptimizationRule}; +use crate::plans::Context; +use crate::plans::conversion::get_schema; + +pub struct TypeCheckRule; + +impl OptimizationRule for TypeCheckRule { + fn optimize_plan( + &mut self, + ir_arena: &mut Arena, + expr_arena: &mut Arena, + node: Node, + ) -> PolarsResult> { + let ir = ir_arena.get(node); + match ir { + IR::Scan { + predicate: Some(predicate), + .. + } => { + let input_schema = get_schema(ir_arena, node); + let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?; + + polars_ensure!( + matches!(dtype, DataType::Boolean | DataType::Unknown(_)), + InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`" + ); + + Ok(None) + }, + IR::Filter { predicate, .. } => { + let input_schema = get_schema(ir_arena, node); + let dtype = predicate.dtype(input_schema.as_ref(), Context::Default, expr_arena)?; + + polars_ensure!( + matches!(dtype, DataType::Boolean | DataType::Unknown(_)), + InvalidOperation: "filter predicate must be of type `Boolean`, got `{dtype:?}`" + ); + + Ok(None) + }, + _ => Ok(None), + } + } +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs new file mode 100644 index 000000000000..133fc135a76e --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -0,0 +1,285 @@ +#[cfg(feature = "dtype-categorical")] +use polars_utils::matches_any_order; + +use super::*; + +macro_rules! unpack { + ($packed:expr) => {{ + match $packed { + Some(payload) => payload, + None => return Ok(None), + } + }}; +} + +#[allow(unused_variables)] +fn compares_cat_to_string(type_left: &DataType, type_right: &DataType, op: Operator) -> bool { + #[cfg(feature = "dtype-categorical")] + { + op.is_comparison_or_bitwise() + && matches_any_order!( + type_left, + type_right, + DataType::String | DataType::Unknown(UnknownKind::Str), + DataType::Categorical(_, _) | DataType::Enum(_, _) + ) + } + #[cfg(not(feature = "dtype-categorical"))] + { + false + } +} + +#[allow(unused_variables)] +fn is_cat_str_binary(type_left: &DataType, type_right: &DataType) -> bool { + #[cfg(feature = "dtype-categorical")] + { + matches_any_order!( + type_left, + type_right, + DataType::String, + DataType::Categorical(_, _) | DataType::Enum(_, _) + ) + } + #[cfg(not(feature = "dtype-categorical"))] + { + false + } +} + +#[cfg(feature = "dtype-struct")] +// Ensure we don't cast to supertype +// otherwise we will fill a struct with null fields +fn process_struct_numeric_arithmetic( + type_left: DataType, + type_right: DataType, + node_left: Node, + node_right: Node, + op: Operator, + expr_arena: &mut Arena, +) -> PolarsResult> { + match (&type_left, &type_right) { + (DataType::Struct(fields), _) => { + if let Some(first) = fields.first() { + let new_node_right = expr_arena.add(AExpr::Cast { + expr: node_right, + dtype: DataType::Struct(vec![first.clone()]), + options: CastOptions::NonStrict, + }); + Ok(Some(AExpr::BinaryExpr { + left: node_left, + op, + right: new_node_right, + })) + } else { + Ok(None) + } + }, + (_, DataType::Struct(fields)) => { + if let Some(first) = fields.first() { + let new_node_left = expr_arena.add(AExpr::Cast { + expr: node_left, + dtype: DataType::Struct(vec![first.clone()]), + options: CastOptions::NonStrict, + }); + + Ok(Some(AExpr::BinaryExpr { + left: new_node_left, + op, + right: node_right, + })) + } else { + Ok(None) + } + }, + _ => unreachable!(), + } +} + +#[cfg(any( + feature = "dtype-date", + feature = "dtype-datetime", + feature = "dtype-time" +))] +fn err_date_str_compare() -> PolarsResult<()> { + if cfg!(feature = "python") { + polars_bail!( + InvalidOperation: + "cannot compare 'date/datetime/time' to a string value \ + (create native python {{ 'date', 'datetime', 'time' }} or compare to a temporal column)" + ); + } else { + polars_bail!( + InvalidOperation: "cannot compare 'date/datetime/time' to a string value" + ); + } +} + +pub(super) fn process_binary( + expr_arena: &mut Arena, + lp_arena: &Arena, + lp_node: Node, + node_left: Node, + op: Operator, + node_right: Node, +) -> PolarsResult> { + let input_schema = get_schema(lp_arena, lp_node); + let (left, type_left): (&AExpr, DataType) = + unpack!(get_aexpr_and_type(expr_arena, node_left, &input_schema)); + let (right, type_right): (&AExpr, DataType) = + unpack!(get_aexpr_and_type(expr_arena, node_right, &input_schema)); + + match (&type_left, &type_right) { + (Unknown(UnknownKind::Any), Unknown(UnknownKind::Any)) => return Ok(None), + ( + Unknown(UnknownKind::Any), + Unknown(UnknownKind::Int(_) | UnknownKind::Float | UnknownKind::Str), + ) => { + let right = unpack!(materialize(right)); + let right = expr_arena.add(right); + + return Ok(Some(AExpr::BinaryExpr { + left: node_left, + op, + right, + })); + }, + ( + Unknown(UnknownKind::Int(_) | UnknownKind::Float | UnknownKind::Str), + Unknown(UnknownKind::Any), + ) => { + let left = unpack!(materialize(left)); + let left = expr_arena.add(left); + + return Ok(Some(AExpr::BinaryExpr { + left, + op, + right: node_right, + })); + }, + _ => { + unpack!(early_escape(&type_left, &type_right)); + }, + } + + use DataType::*; + // don't coerce string with number comparisons. They must error + match (&type_left, &type_right, op) { + #[cfg(not(feature = "dtype-categorical"))] + (DataType::String, dt, op) | (dt, DataType::String, op) + if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() => + { + return Ok(None); + }, + #[cfg(feature = "dtype-categorical")] + (String | Unknown(UnknownKind::Str) | Categorical(_, _), dt, op) + | (dt, Unknown(UnknownKind::Str) | String | Categorical(_, _), op) + if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() => + { + return Ok(None); + }, + #[cfg(feature = "dtype-categorical")] + (Unknown(UnknownKind::Str) | String | Enum(_, _), dt, op) + | (dt, Unknown(UnknownKind::Str) | String | Enum(_, _), op) + if op.is_comparison_or_bitwise() && dt.is_primitive_numeric() => + { + return Ok(None); + }, + #[cfg(feature = "dtype-date")] + (Date, String | Unknown(UnknownKind::Str), op) + | (String | Unknown(UnknownKind::Str), Date, op) + if op.is_comparison_or_bitwise() => + { + err_date_str_compare()? + }, + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), String | Unknown(UnknownKind::Str), op) + | (String | Unknown(UnknownKind::Str), Datetime(_, _), op) + if op.is_comparison_or_bitwise() => + { + err_date_str_compare()? + }, + #[cfg(feature = "dtype-time")] + (Time | Unknown(UnknownKind::Str), String, op) if op.is_comparison_or_bitwise() => { + err_date_str_compare()? + }, + // structs can be arbitrarily nested, leave the complexity to the caller for now. + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_), _op) => return Ok(None), + _ => {}, + } + + if op.is_arithmetic() { + match (&type_left, &type_right) { + (Duration(_), Duration(_)) => return Ok(None), + (Duration(_), r) if r.is_primitive_numeric() => return Ok(None), + (String, a) | (a, String) if a.is_primitive_numeric() => { + polars_bail!(InvalidOperation: "arithmetic on string and numeric not allowed, try an explicit cast first") + }, + (Datetime(_, _), _) + | (_, Datetime(_, _)) + | (Date, _) + | (_, Date) + | (Duration(_), _) + | (_, Duration(_)) + | (Time, _) + | (_, Time) + | (List(_), _) + | (_, List(_)) => return Ok(None), + #[cfg(feature = "dtype-array")] + (Array(..), _) | (_, Array(..)) => return Ok(None), + #[cfg(feature = "dtype-struct")] + (Struct(_), a) | (a, Struct(_)) if a.is_primitive_numeric() => { + return process_struct_numeric_arithmetic( + type_left, type_right, node_left, node_right, op, expr_arena, + ); + }, + _ => {}, + } + } else if compares_cat_to_string(&type_left, &type_right, op) { + return Ok(None); + } + + // Coerce types: + let st = unpack!(get_supertype(&type_left, &type_right)); + let mut st = modify_supertype(st, left, right, &type_left, &type_right); + + if is_cat_str_binary(&type_left, &type_right) { + st = String + } + + // TODO! raise here? + // We should at least never cast to Unknown. + if matches!(st, DataType::Unknown(UnknownKind::Any)) { + return Ok(None); + } + + // Only cast if the type is not already the super type. + // this can prevent an expensive flattening and subsequent aggregation + // in a group_by context. To be able to cast the groups need to be + // flattened + let new_node_left = if type_left != st { + expr_arena.add(AExpr::Cast { + expr: node_left, + dtype: st.clone(), + options: CastOptions::NonStrict, + }) + } else { + node_left + }; + let new_node_right = if type_right != st { + expr_arena.add(AExpr::Cast { + expr: node_right, + dtype: st, + options: CastOptions::NonStrict, + }) + } else { + node_right + }; + + Ok(Some(AExpr::BinaryExpr { + left: new_node_left, + op, + right: new_node_right, + })) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs b/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs new file mode 100644 index 000000000000..2d84f7ef4dde --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs @@ -0,0 +1,73 @@ +use super::*; + +/// Get the datatypes of function arguments. +/// +/// If all arguments give the same datatype or a datatype cannot be found, `Ok(None)` is returned. +pub(super) fn get_function_dtypes( + input: &[ExprIR], + expr_arena: &Arena, + input_schema: &Schema, + function: &FunctionExpr, +) -> PolarsResult>> { + let mut dtypes = Vec::with_capacity(input.len()); + let mut first = true; + for e in input { + let Some((_, dtype)) = get_aexpr_and_type(expr_arena, e.node(), input_schema) else { + return Ok(None); + }; + + if first { + check_namespace(function, &dtype)?; + first = false; + } + // Ignore Unknown in the inputs. + // We will raise if we cannot find the supertype later. + match dtype { + DataType::Unknown(UnknownKind::Any) => { + return Ok(None); + }, + _ => dtypes.push(dtype), + } + } + + if dtypes.iter().all_equal() { + return Ok(None); + } + Ok(Some(dtypes)) +} + +// `str` namespace belongs to `String` +// `cat` namespace belongs to `Categorical` etc. +fn check_namespace(function: &FunctionExpr, first_dtype: &DataType) -> PolarsResult<()> { + match function { + #[cfg(feature = "strings")] + FunctionExpr::StringExpr(_) => { + polars_ensure!(first_dtype == &DataType::String, InvalidOperation: "expected String type, got: {}", first_dtype) + }, + FunctionExpr::BinaryExpr(_) => { + polars_ensure!(first_dtype == &DataType::Binary, InvalidOperation: "expected Binary type, got: {}", first_dtype) + }, + #[cfg(feature = "temporal")] + FunctionExpr::TemporalExpr(_) => { + polars_ensure!(first_dtype.is_temporal(), InvalidOperation: "expected Date(time)/Duration type, got: {}", first_dtype) + }, + FunctionExpr::ListExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::List(_)), InvalidOperation: "expected List type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-array")] + FunctionExpr::ArrayExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::Array(_, _)), InvalidOperation: "expected Array type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::Struct(_)), InvalidOperation: "expected Struct type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-categorical")] + FunctionExpr::Categorical(_) => { + polars_ensure!(matches!(first_dtype, DataType::Categorical(_, _)), InvalidOperation: "expected Categorical type, got: {}", first_dtype) + }, + _ => {}, + } + + Ok(()) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs new file mode 100644 index 000000000000..fd353d1f2a31 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs @@ -0,0 +1,198 @@ +use super::*; + +pub(super) enum IsInTypeCoercionResult { + SuperType(DataType, DataType), + SelfCast { dtype: DataType, strict: bool }, + OtherCast { dtype: DataType, strict: bool }, + Implode, +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn resolve_is_in( + input: &[ExprIR], + expr_arena: &Arena, + lp_arena: &Arena, + lp_node: Node, + is_contains: bool, + op: &'static str, + flat_idx: usize, + nested_idx: usize, +) -> PolarsResult> { + let input_schema = get_schema(lp_arena, lp_node); + let (_, type_left) = unpack!(get_aexpr_and_type( + expr_arena, + input[flat_idx].node(), + &input_schema + )); + let (_, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + input[nested_idx].node(), + &input_schema + )); + + let left_nl = type_left.nesting_level(); + let right_nl = type_other.nesting_level(); + + // @HACK. This needs to happen until 2.0 because we support `pl.col.a.is_in(pl.col.a)`. + if !is_contains && left_nl == right_nl { + polars_warn!( + Deprecation, + "`is_in` with a collection of the same datatype is ambiguous and deprecated. +Please use `implode` to return to previous behavior. + +See https://github.com/pola-rs/polars/issues/22149 for more information." + ); + return Ok(Some(IsInTypeCoercionResult::Implode)); + } + + if left_nl + 1 != right_nl { + polars_bail!(InvalidOperation: "'{op}' cannot check for {:?} values in {:?} data", &type_other, &type_left); + } + + let type_other_inner = type_other.inner_dtype().unwrap(); + + unpack!(early_escape(&type_left, type_other_inner)); + + let cast_type = match &type_other { + DataType::List(_) => DataType::List(Box::new(type_left.clone())), + #[cfg(feature = "dtype-array")] + DataType::Array(_, width) => DataType::Array(Box::new(type_left.clone()), *width), + _ => unreachable!(), + }; + + let casted_expr = match (&type_left, type_other_inner) { + // types are equal, do nothing + (a, b) if a == b => return Ok(None), + // all-null can represent anything (and/or empty list), so cast to target dtype + (_, DataType::Null) => IsInTypeCoercionResult::OtherCast { + dtype: cast_type, + strict: false, + }, + #[cfg(feature = "dtype-categorical")] + (DataType::Enum(_, _), DataType::String) => IsInTypeCoercionResult::OtherCast { + dtype: cast_type, + strict: true, + }, + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Enum(_, _)) => IsInTypeCoercionResult::SelfCast { + dtype: type_other_inner.clone(), + strict: true, + }, + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(Some(rm), ordering)) if rm.is_global() => { + IsInTypeCoercionResult::SelfCast { + dtype: DataType::Categorical(None, *ordering), + strict: false, + } + }, + + // @NOTE: Local Categorical coercion has to happen in the kernel, which makes it streaming + // incompatible. + #[cfg(feature = "dtype-categorical")] + (DataType::Categorical(Some(rm), ordering), DataType::String) if rm.is_global() => { + IsInTypeCoercionResult::OtherCast { + dtype: match &type_other { + DataType::List(_) => { + DataType::List(Box::new(DataType::Categorical(None, *ordering))) + }, + #[cfg(feature = "dtype-array")] + DataType::Array(_, width) => { + DataType::Array(Box::new(DataType::Categorical(None, *ordering)), *width) + }, + _ => unreachable!(), + }, + strict: false, + } + }, + + #[cfg(feature = "dtype-categorical")] + (DataType::Categorical(_, _), DataType::String) => return Ok(None), + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, _)) => return Ok(None), + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), dt) if dt.is_primitive_numeric() => { + IsInTypeCoercionResult::OtherCast { + dtype: cast_type, + strict: false, + } + }, + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { + polars_bail!(InvalidOperation: "'{op}' cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + // can't check for more granular time_unit in less-granular time_unit data, + // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) + (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "'{op}' cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) + } + }, + (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "'{op}' cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) + } + }, + (_, DataType::List(_)) => { + polars_ensure!( + &type_left == type_other_inner, + InvalidOperation: "'{op}' cannot check for {:?} values in {:?} data", + &type_left, &type_other + ); + return Ok(None); + }, + #[cfg(feature = "dtype-array")] + (_, DataType::Array(_, _)) => { + polars_ensure!( + &type_left == type_other_inner, + InvalidOperation: "'{op}' cannot check for {:?} values in {:?} data", + &type_left, &type_other + ); + return Ok(None); + }, + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), _) | (_, DataType::Struct(_)) => { + polars_ensure!( + &type_left == type_other_inner, + InvalidOperation: "'{op}' cannot check for {:?} values in {:?} data", + &type_left, &type_other + ); + return Ok(None); + }, + + // don't attempt to cast between obviously mismatched types, but + // allow integer/float comparison (will use their supertypes). + (a, b) => { + if (a.is_primitive_numeric() && b.is_primitive_numeric()) || (a == &DataType::Null) { + if a != b { + // @TAG: 2.0 + // @HACK: `is_in` does supertype casting between primitive numerics, which + // honestly makes very little sense. To stay backwards compatible we keep this, + // but please in 2.0 remove this. + + let super_type = + polars_core::utils::try_get_supertype(&type_left, type_other_inner)?; + let other_type = match &type_other { + DataType::List(_) => DataType::List(Box::new(super_type.clone())), + #[cfg(feature = "dtype-array")] + DataType::Array(_, width) => { + DataType::Array(Box::new(super_type.clone()), *width) + }, + _ => unreachable!(), + }; + + return Ok(Some(IsInTypeCoercionResult::SuperType( + super_type, other_type, + ))); + } + + return Ok(None); + } + polars_bail!(InvalidOperation: "'{op}' cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + }; + Ok(Some(casted_expr)) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs new file mode 100644 index 000000000000..702aef62073c --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -0,0 +1,722 @@ +mod binary; +mod functions; +#[cfg(feature = "is_in")] +mod is_in; + +use binary::process_binary; +use polars_core::chunked_array::cast::CastOptions; +use polars_core::prelude::*; +use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int}; +use polars_utils::format_list; +use polars_utils::itertools::Itertools; + +use super::*; + +pub struct TypeCoercionRule {} + +macro_rules! unpack { + ($packed:expr) => { + match $packed { + Some(payload) => payload, + None => return Ok(None), + } + }; +} +pub(super) use unpack; + +/// determine if we use the supertype or not. For instance when we have a column Int64 and we compare with literal UInt32 +/// it would be wasteful to cast the column instead of the literal. +fn modify_supertype( + mut st: DataType, + left: &AExpr, + right: &AExpr, + type_left: &DataType, + type_right: &DataType, +) -> DataType { + // TODO! This must be removed and dealt properly with dynamic str. + use DataType::*; + match (type_left, type_right, left, right) { + // if the we compare a categorical to a literal string we want to cast the literal to categorical + #[cfg(feature = "dtype-categorical")] + (Categorical(_, ordering), String | Unknown(UnknownKind::Str), _, AExpr::Literal(_)) + | (String | Unknown(UnknownKind::Str), Categorical(_, ordering), AExpr::Literal(_), _) => { + st = Categorical(None, *ordering) + }, + #[cfg(feature = "dtype-categorical")] + (dt @ Enum(_, _), String | Unknown(UnknownKind::Str), _, AExpr::Literal(_)) + | (String | Unknown(UnknownKind::Str), dt @ Enum(_, _), AExpr::Literal(_), _) => { + st = dt.clone() + }, + // when then expression literals can have a different list type. + // so we cast the literal to the other hand side. + (List(inner), List(other), _, AExpr::Literal(_)) + | (List(other), List(inner), AExpr::Literal(_), _) + if inner != other => + { + st = match &**inner { + #[cfg(feature = "dtype-categorical")] + Categorical(_, ordering) => List(Box::new(Categorical(None, *ordering))), + _ => List(inner.clone()), + }; + }, + // do nothing + _ => {}, + } + st +} + +fn get_aexpr_and_type<'a>( + expr_arena: &'a Arena, + e: Node, + input_schema: &Schema, +) -> Option<(&'a AExpr, DataType)> { + let ae = expr_arena.get(e); + Some(( + ae, + ae.get_type(input_schema, Context::Default, expr_arena) + .ok()?, + )) +} + +fn materialize(aexpr: &AExpr) -> Option { + match aexpr { + AExpr::Literal(lv) => Some(AExpr::Literal(lv.clone().materialize())), + _ => None, + } +} + +impl OptimizationRule for TypeCoercionRule { + fn optimize_expr( + &mut self, + expr_arena: &mut Arena, + expr_node: Node, + lp_arena: &Arena, + lp_node: Node, + ) -> PolarsResult> { + let expr = expr_arena.get(expr_node); + let out = match *expr { + AExpr::Cast { + expr, + ref dtype, + options, + } => { + let input = expr_arena.get(expr); + + inline_or_prune_cast(input, dtype, options, lp_node, lp_arena, expr_arena)? + }, + AExpr::Ternary { + truthy: truthy_node, + falsy: falsy_node, + predicate, + } => { + let input_schema = get_schema(lp_arena, lp_node); + let (truthy, type_true) = + unpack!(get_aexpr_and_type(expr_arena, truthy_node, &input_schema)); + let (falsy, type_false) = + unpack!(get_aexpr_and_type(expr_arena, falsy_node, &input_schema)); + + if type_true == type_false { + return Ok(None); + } + let st = unpack!(get_supertype(&type_true, &type_false)); + let st = modify_supertype(st, truthy, falsy, &type_true, &type_false); + + // only cast if the type is not already the super type. + // this can prevent an expensive flattening and subsequent aggregation + // in a group_by context. To be able to cast the groups need to be + // flattened + let new_node_truthy = if type_true != st { + expr_arena.add(AExpr::Cast { + expr: truthy_node, + dtype: st.clone(), + options: CastOptions::Strict, + }) + } else { + truthy_node + }; + + let new_node_falsy = if type_false != st { + expr_arena.add(AExpr::Cast { + expr: falsy_node, + dtype: st, + options: CastOptions::Strict, + }) + } else { + falsy_node + }; + + Some(AExpr::Ternary { + truthy: new_node_truthy, + falsy: new_node_falsy, + predicate, + }) + }, + AExpr::BinaryExpr { + left: node_left, + op, + right: node_right, + } => return process_binary(expr_arena, lp_arena, lp_node, node_left, op, node_right), + #[cfg(feature = "is_in")] + AExpr::Function { + ref function, + ref input, + options, + } if { + let mut matches = matches!( + function, + FunctionExpr::Boolean(BooleanFunction::IsIn { .. }) + | FunctionExpr::ListExpr(ListFunction::Contains) + ); + #[cfg(feature = "dtype-array")] + { + matches |= matches!(function, FunctionExpr::ArrayExpr(ArrayFunction::Contains)); + } + matches + } => + { + let (op, flat, nested, is_contains) = match function { + FunctionExpr::Boolean(BooleanFunction::IsIn { .. }) => ("is_in", 0, 1, false), + FunctionExpr::ListExpr(ListFunction::Contains) => ("list.contains", 1, 0, true), + #[cfg(feature = "dtype-array")] + FunctionExpr::ArrayExpr(ArrayFunction::Contains) => { + ("arr.contains", 1, 0, true) + }, + _ => unreachable!(), + }; + + let Some(result) = is_in::resolve_is_in( + input, + expr_arena, + lp_arena, + lp_node, + is_contains, + op, + flat, + nested, + )? + else { + return Ok(None); + }; + + let function = function.clone(); + let mut input = input.to_vec(); + use self::is_in::IsInTypeCoercionResult; + match result { + IsInTypeCoercionResult::SuperType(flat_type, nested_type) => { + let input_schema = get_schema(lp_arena, lp_node); + let (_, type_left) = unpack!(get_aexpr_and_type( + expr_arena, + input[flat].node(), + &input_schema + )); + let (_, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + input[nested].node(), + &input_schema + )); + cast_expr_ir( + &mut input[flat], + &type_left, + &flat_type, + expr_arena, + CastOptions::NonStrict, + )?; + cast_expr_ir( + &mut input[nested], + &type_other, + &nested_type, + expr_arena, + CastOptions::NonStrict, + )?; + }, + IsInTypeCoercionResult::SelfCast { dtype, strict } => { + let input_schema = get_schema(lp_arena, lp_node); + let (_, type_self) = unpack!(get_aexpr_and_type( + expr_arena, + input[flat].node(), + &input_schema + )); + let options = if strict { + CastOptions::Strict + } else { + CastOptions::NonStrict + }; + cast_expr_ir(&mut input[flat], &type_self, &dtype, expr_arena, options)?; + }, + IsInTypeCoercionResult::OtherCast { dtype, strict } => { + let input_schema = get_schema(lp_arena, lp_node); + let (_, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + input[nested].node(), + &input_schema + )); + let options = if strict { + CastOptions::Strict + } else { + CastOptions::NonStrict + }; + cast_expr_ir(&mut input[nested], &type_other, &dtype, expr_arena, options)?; + }, + IsInTypeCoercionResult::Implode => { + assert!(!is_contains); + let other_input = + expr_arena.add(AExpr::Agg(IRAggExpr::Implode(input[1].node()))); + input[1].set_node(other_input); + }, + } + + Some(AExpr::Function { + function, + input, + options, + }) + }, + // shift and fill should only cast left and fill value to super type. + AExpr::Function { + function: FunctionExpr::ShiftAndFill, + ref input, + options, + } => { + let input_schema = get_schema(lp_arena, lp_node); + let left_node = input[0].node(); + let fill_value_node = input[2].node(); + let (left, type_left) = + unpack!(get_aexpr_and_type(expr_arena, left_node, &input_schema)); + let (fill_value, type_fill_value) = unpack!(get_aexpr_and_type( + expr_arena, + fill_value_node, + &input_schema + )); + + unpack!(early_escape(&type_left, &type_fill_value)); + + let super_type = unpack!(get_supertype(&type_left, &type_fill_value)); + let super_type = + modify_supertype(super_type, left, fill_value, &type_left, &type_fill_value); + + let mut input = input.clone(); + let new_node_left = if type_left != super_type { + expr_arena.add(AExpr::Cast { + expr: left_node, + dtype: super_type.clone(), + options: CastOptions::NonStrict, + }) + } else { + left_node + }; + + let new_node_fill_value = if type_fill_value != super_type { + expr_arena.add(AExpr::Cast { + expr: fill_value_node, + dtype: super_type.clone(), + options: CastOptions::NonStrict, + }) + } else { + fill_value_node + }; + + input[0].set_node(new_node_left); + input[2].set_node(new_node_fill_value); + + Some(AExpr::Function { + function: FunctionExpr::ShiftAndFill, + input, + options, + }) + }, + // generic type coercion of any function. + AExpr::Function { + // only for `DataType::Unknown` as it still has to be set. + ref function, + ref input, + mut options, + } if options.cast_options.is_some() => { + let casting_rules = options.cast_options.unwrap(); + let input_schema = get_schema(lp_arena, lp_node); + + let function = function.clone(); + let mut input = input.clone(); + + if let Some(dtypes) = + functions::get_function_dtypes(&input, expr_arena, &input_schema, &function)? + { + let self_e = input[0].clone(); + let (self_ae, type_self) = + unpack!(get_aexpr_and_type(expr_arena, self_e.node(), &input_schema)); + let mut super_type = type_self.clone(); + match casting_rules { + CastingRules::Supertype(super_type_opts) => { + for other in &input[1..] { + let (other, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + other.node(), + &input_schema + )); + + let Some(new_st) = get_supertype_with_options( + &super_type, + &type_other, + super_type_opts, + ) else { + raise_supertype(&function, &input, &input_schema, expr_arena)?; + unreachable!() + }; + if input.len() == 2 { + // modify_supertype is a bit more conservative of casting columns + // to literals + super_type = modify_supertype( + new_st, + self_ae, + other, + &type_self, + &type_other, + ) + } else { + // when dealing with more than 1 argument, we simply find the supertypes + super_type = new_st + } + } + }, + CastingRules::FirstArgLossless => { + if super_type.is_integer() { + for other in &input[1..] { + let other = + other.dtype(&input_schema, Context::Default, expr_arena)?; + if other.is_float() { + polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other) + } + } + } + }, + } + + if matches!(super_type, DataType::Unknown(UnknownKind::Any)) { + raise_supertype(&function, &input, &input_schema, expr_arena)?; + unreachable!() + } + + match super_type { + DataType::Unknown(UnknownKind::Float) => super_type = DataType::Float64, + DataType::Unknown(UnknownKind::Int(v)) => { + super_type = materialize_dyn_int(v).dtype() + }, + _ => {}, + } + + for (e, dtype) in input.iter_mut().zip(dtypes) { + match super_type { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) if dtype.is_string() => { + // pass + }, + _ => cast_expr_ir( + e, + &dtype, + &super_type, + expr_arena, + CastOptions::NonStrict, + )?, + } + } + } + + // Ensure we don't go through this on next iteration. + options.cast_options = None; + Some(AExpr::Function { + function, + input, + options, + }) + }, + #[cfg(feature = "temporal")] + AExpr::Function { + function: ref function @ FunctionExpr::TemporalExpr(TemporalFunction::Duration(_)), + ref input, + options, + } => { + let input_schema = get_schema(lp_arena, lp_node); + + for (i, expr) in input.iter().enumerate() { + let (_, dtype) = + unpack!(get_aexpr_and_type(expr_arena, expr.node(), &input_schema)); + + if !matches!(dtype, DataType::Int64) { + let function = function.clone(); + let mut input = input.to_vec(); + cast_expr_ir( + &mut input[i], + &dtype, + &DataType::Int64, + expr_arena, + CastOptions::NonStrict, + )?; + for expr in &mut input[i + 1..] { + let (_, dtype) = + unpack!(get_aexpr_and_type(expr_arena, expr.node(), &input_schema)); + cast_expr_ir( + expr, + &dtype, + &DataType::Int64, + expr_arena, + CastOptions::Strict, + )?; + } + + return Ok(Some(AExpr::Function { + function, + input, + options, + })); + } + } + + None + }, + AExpr::Slice { offset, length, .. } => { + let input_schema = get_schema(lp_arena, lp_node); + let (_, offset_dtype) = + unpack!(get_aexpr_and_type(expr_arena, offset, &input_schema)); + polars_ensure!(offset_dtype.is_integer(), InvalidOperation: "offset must be integral for slice expression, not {}", offset_dtype); + let (_, length_dtype) = + unpack!(get_aexpr_and_type(expr_arena, length, &input_schema)); + polars_ensure!(length_dtype.is_integer() || length_dtype.is_null(), InvalidOperation: "length must be integral for slice expression, not {}", length_dtype); + None + }, + _ => None, + }; + Ok(out) + } +} + +fn inline_or_prune_cast( + aexpr: &AExpr, + dtype: &DataType, + options: CastOptions, + lp_node: Node, + lp_arena: &Arena, + expr_arena: &Arena, +) -> PolarsResult> { + if !dtype.is_known() { + return Ok(None); + } + + let out = match aexpr { + // PRUNE + AExpr::BinaryExpr { op, .. } => { + use Operator::*; + + match op { + LogicalOr | LogicalAnd => { + if let Some(schema) = lp_arena.get(lp_node).input_schema(lp_arena) { + let field = aexpr.to_field(&schema, Context::Default, expr_arena)?; + if field.dtype == *dtype { + return Ok(Some(aexpr.clone())); + } + } + + None + }, + Eq | EqValidity | NotEq | NotEqValidity | Lt | LtEq | Gt | GtEq => { + if dtype.is_bool() { + Some(aexpr.clone()) + } else { + None + } + }, + _ => None, + } + }, + // INLINE + AExpr::Literal(lv) => try_inline_literal_cast(lv, dtype, options)?.map(AExpr::Literal), + _ => None, + }; + + Ok(out) +} + +fn try_inline_literal_cast( + lv: &LiteralValue, + dtype: &DataType, + options: CastOptions, +) -> PolarsResult> { + let lv = match lv { + LiteralValue::Series(s) => { + let s = s.cast_with_options(dtype, options)?; + LiteralValue::Series(SpecialEq::new(s)) + }, + LiteralValue::Dyn(dyn_value) => dyn_value.clone().try_materialize_to_dtype(dtype)?.into(), + lv if lv.is_null() => match dtype { + DataType::Unknown(UnknownKind::Float | UnknownKind::Int(_) | UnknownKind::Str) => { + LiteralValue::untyped_null() + }, + _ => return Ok(None), + }, + LiteralValue::Scalar(sc) => sc.clone().cast_with_options(dtype, options)?.into(), + lv => { + let Some(av) = lv.to_any_value() else { + return Ok(None); + }; + if dtype == &av.dtype() { + return Ok(Some(lv.clone())); + } + match (av, dtype) { + // casting null always remains null + (AnyValue::Null, _) => return Ok(None), + // series cast should do this one + #[cfg(feature = "dtype-datetime")] + (AnyValue::Datetime(_, _, _), DataType::Datetime(_, _)) => return Ok(None), + #[cfg(feature = "dtype-duration")] + (AnyValue::Duration(_, _), _) => return Ok(None), + #[cfg(feature = "dtype-categorical")] + (AnyValue::Categorical(_, _, _), _) | (_, DataType::Categorical(_, _)) => { + return Ok(None); + }, + #[cfg(feature = "dtype-categorical")] + (AnyValue::Enum(_, _, _), _) | (_, DataType::Enum(_, _)) => return Ok(None), + #[cfg(feature = "dtype-struct")] + (_, DataType::Struct(_)) => return Ok(None), + (av, _) => { + let out = { + match av.strict_cast(dtype) { + Some(out) => out, + None => return Ok(None), + } + }; + out.into() + }, + } + }, + }; + + Ok(Some(lv)) +} + +fn cast_expr_ir( + e: &mut ExprIR, + from_dtype: &DataType, + to_dtype: &DataType, + expr_arena: &mut Arena, + options: CastOptions, +) -> PolarsResult<()> { + if from_dtype == to_dtype { + return Ok(()); + } + + check_cast(from_dtype, to_dtype)?; + + if let AExpr::Literal(lv) = expr_arena.get(e.node()) { + if let Some(literal) = try_inline_literal_cast(lv, to_dtype, options)? { + e.set_node(expr_arena.add(AExpr::Literal(literal))); + e.set_dtype(to_dtype.clone()); + return Ok(()); + } + } + + e.set_node(expr_arena.add(AExpr::Cast { + expr: e.node(), + dtype: to_dtype.clone(), + options: CastOptions::Strict, + })); + e.set_dtype(to_dtype.clone()); + + Ok(()) +} + +fn check_cast(from: &DataType, to: &DataType) -> PolarsResult<()> { + polars_ensure!( + from.can_cast_to(to) != Some(false), + InvalidOperation: "casting from {from:?} to {to:?} not supported" + ); + Ok(()) +} + +fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { + match (type_self, type_other) { + (lhs, rhs) if lhs == rhs => None, + _ => Some(()), + } +} + +fn raise_supertype( + function: &FunctionExpr, + inputs: &[ExprIR], + input_schema: &Schema, + expr_arena: &Arena, +) -> PolarsResult<()> { + let dtypes = inputs + .iter() + .map(|e| e.dtype(input_schema, Context::Default, expr_arena).cloned()) + .collect::>>()?; + + let st = dtypes + .iter() + .cloned() + .map(Some) + .reduce(|a, b| get_supertype(&a?, &b?)) + .expect("always at least 2 inputs"); + // We could get a supertype with the default options, so the input types are not allowed for this + // specific operation. + if st.is_some() { + polars_bail!(InvalidOperation: "got invalid or ambiguous dtypes: '{}' in expression '{}'\ + \n\nConsider explicitly casting your input types to resolve potential ambiguity.", format_list!(&dtypes), function); + } else { + polars_bail!(InvalidOperation: "could not determine supertype of: {} in expression '{}'\ + \n\nIt might also be the case that the type combination isn't allowed in this specific operation.", format_list!(&dtypes), function); + } +} + +#[cfg(test)] +#[cfg(feature = "dtype-categorical")] +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_categorical_string() { + let mut expr_arena = Arena::new(); + let mut lp_arena = Arena::new(); + let optimizer = StackOptimizer {}; + let rules: &mut [Box] = &mut [Box::new(TypeCoercionRule {})]; + + let df = DataFrame::new(Vec::from([Column::new_empty( + PlSmallStr::from_static("fruits"), + &DataType::Categorical(None, Default::default()), + )])) + .unwrap(); + + let expr_in = vec![col("fruits").eq(lit("somestr"))]; + let lp = DslBuilder::from_existing_df(df.clone()) + .project(expr_in.clone(), Default::default()) + .build(); + + let mut lp_top = + to_alp(lp, &mut expr_arena, &mut lp_arena, &mut OptFlags::default()).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is not cast to string for the comparison + if let DslPlan::Select { expr, .. } = lp { + assert_eq!(expr, expr_in); + }; + + let expr_in = vec![col("fruits") + (lit("somestr"))]; + let lp = DslBuilder::from_existing_df(df) + .project(expr_in, Default::default()) + .build(); + let mut lp_top = + to_alp(lp, &mut expr_arena, &mut lp_arena, &mut OptFlags::default()).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is cast to string for the addition + let expected = vec![col("fruits").cast(DataType::String) + lit("somestr")]; + if let DslPlan::Select { expr, .. } = lp { + assert_eq!(expr, expected); + }; + } +} diff --git a/crates/polars-plan/src/plans/debug.rs b/crates/polars-plan/src/plans/debug.rs new file mode 100644 index 000000000000..1c0a5ce1ad07 --- /dev/null +++ b/crates/polars-plan/src/plans/debug.rs @@ -0,0 +1,13 @@ +#![allow(dead_code)] +use polars_utils::arena::{Arena, Node}; + +use crate::prelude::{AExpr, node_to_expr}; + +pub fn dbg_nodes(nodes: &[Node], arena: &Arena) { + println!("["); + for node in nodes { + let e = node_to_expr(*node, arena); + println!("{e:?}") + } + println!("]"); +} diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs new file mode 100644 index 000000000000..b2a3543a899c --- /dev/null +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -0,0 +1,335 @@ +use std::borrow::Borrow; +use std::hash::Hash; +#[cfg(feature = "cse")] +use std::hash::Hasher; +use std::sync::OnceLock; + +use polars_utils::format_pl_smallstr; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; + +use super::*; +use crate::constants::{get_len_name, get_literal_name}; + +#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub enum OutputName { + /// No not yet set. + #[default] + None, + /// The most left-hand-side literal will be the output name. + LiteralLhs(PlSmallStr), + /// The most left-hand-side column will be the output name. + ColumnLhs(PlSmallStr), + /// Rename the output as `PlSmallStr`. + Alias(PlSmallStr), + #[cfg(feature = "dtype-struct")] + /// A struct field. + Field(PlSmallStr), +} + +impl OutputName { + pub fn get(&self) -> Option<&PlSmallStr> { + match self { + OutputName::Alias(name) => Some(name), + OutputName::ColumnLhs(name) => Some(name), + OutputName::LiteralLhs(name) => Some(name), + #[cfg(feature = "dtype-struct")] + OutputName::Field(name) => Some(name), + OutputName::None => None, + } + } + + pub fn unwrap(&self) -> &PlSmallStr { + self.get().expect("no output name set") + } + + pub(crate) fn is_none(&self) -> bool { + matches!(self, OutputName::None) + } +} + +#[derive(Debug)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub struct ExprIR { + /// Output name of this expression. + output_name: OutputName, + /// Output dtype of this expression + /// Reduced expression. + /// This expression is pruned from `alias` and already expanded. + node: Node, + #[cfg_attr(feature = "ir_serde", serde(skip))] + output_dtype: OnceLock, +} + +impl Eq for ExprIR {} + +impl PartialEq for ExprIR { + fn eq(&self, other: &Self) -> bool { + self.node == other.node && self.output_name == other.output_name + } +} + +impl Clone for ExprIR { + fn clone(&self) -> Self { + let output_dtype = OnceLock::new(); + if let Some(dt) = self.output_dtype.get() { + output_dtype.set(dt.clone()).unwrap() + } + + ExprIR { + output_name: self.output_name.clone(), + node: self.node, + output_dtype, + } + } +} + +impl Borrow for ExprIR { + fn borrow(&self) -> &Node { + &self.node + } +} + +impl ExprIR { + pub fn new(node: Node, output_name: OutputName) -> Self { + debug_assert!(!output_name.is_none()); + ExprIR { + output_name, + node, + output_dtype: OnceLock::new(), + } + } + + pub fn with_dtype(self, dtype: DataType) -> Self { + let _ = self.output_dtype.set(dtype); + self + } + + pub(crate) fn set_dtype(&mut self, dtype: DataType) { + self.output_dtype = OnceLock::from(dtype); + } + + pub fn from_node(node: Node, arena: &Arena) -> Self { + let mut out = Self { + node, + output_name: OutputName::None, + output_dtype: OnceLock::new(), + }; + out.node = node; + for (_, ae) in arena.iter(node) { + match ae { + AExpr::Column(name) => { + out.output_name = OutputName::ColumnLhs(name.clone()); + break; + }, + AExpr::Literal(lv) => { + if let LiteralValue::Series(s) = lv { + out.output_name = OutputName::LiteralLhs(s.name().clone()); + } else { + out.output_name = OutputName::LiteralLhs(get_literal_name().clone()); + } + break; + }, + AExpr::Function { + input, function, .. + } => { + match function { + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => { + out.output_name = OutputName::Field(name.clone()); + }, + _ => { + if input.is_empty() { + out.output_name = + OutputName::LiteralLhs(format_pl_smallstr!("{}", function)); + } else { + out.output_name = input[0].output_name.clone(); + } + }, + } + break; + }, + AExpr::AnonymousFunction { input, options, .. } => { + if input.is_empty() { + out.output_name = + OutputName::LiteralLhs(PlSmallStr::from_static(options.fmt_str)); + } else { + out.output_name = input[0].output_name.clone(); + } + break; + }, + AExpr::Len => { + out.output_name = OutputName::LiteralLhs(get_len_name()); + break; + }, + AExpr::Alias(_, _) => { + // Should be removed during conversion. + #[cfg(debug_assertions)] + { + unreachable!() + } + }, + _ => {}, + } + } + debug_assert!(!out.output_name.is_none()); + out + } + + #[inline] + pub fn node(&self) -> Node { + self.node + } + + /// Create a `ExprIR` structure that implements display + pub fn display<'a>(&'a self, expr_arena: &'a Arena) -> ExprIRDisplay<'a> { + ExprIRDisplay { + node: self.node(), + output_name: self.output_name_inner(), + expr_arena, + } + } + + pub(crate) fn set_node(&mut self, node: Node) { + self.node = node; + self.output_dtype = OnceLock::new(); + } + + pub(crate) fn set_alias(&mut self, name: PlSmallStr) { + self.output_name = OutputName::Alias(name) + } + + pub fn with_alias(&self, name: PlSmallStr) -> Self { + Self { + output_name: OutputName::Alias(name), + node: self.node, + output_dtype: self.output_dtype.clone(), + } + } + + pub(crate) fn set_columnlhs(&mut self, name: PlSmallStr) { + debug_assert!(matches!( + self.output_name, + OutputName::ColumnLhs(_) | OutputName::None + )); + self.output_name = OutputName::ColumnLhs(name) + } + + pub fn output_name_inner(&self) -> &OutputName { + &self.output_name + } + + pub fn output_name(&self) -> &PlSmallStr { + self.output_name.unwrap() + } + + pub fn to_expr(&self, expr_arena: &Arena) -> Expr { + let out = node_to_expr(self.node, expr_arena); + + match &self.output_name { + OutputName::Alias(name) => out.alias(name.clone()), + _ => out, + } + } + + pub fn get_alias(&self) -> Option<&PlSmallStr> { + match &self.output_name { + OutputName::Alias(name) => Some(name), + _ => None, + } + } + + // Utility for debugging. + #[cfg(debug_assertions)] + #[allow(dead_code)] + pub(crate) fn print(&self, expr_arena: &Arena) { + eprintln!("{:?}", self.to_expr(expr_arena)) + } + + pub(crate) fn has_alias(&self) -> bool { + matches!(self.output_name, OutputName::Alias(_)) + } + + #[cfg(feature = "cse")] + pub(crate) fn traverse_and_hash(&self, expr_arena: &Arena, state: &mut H) { + traverse_and_hash_aexpr(self.node, expr_arena, state); + if let Some(alias) = self.get_alias() { + alias.hash(state) + } + } + + pub fn is_scalar(&self, expr_arena: &Arena) -> bool { + is_scalar_ae(self.node, expr_arena) + } + + pub fn dtype( + &self, + schema: &Schema, + ctxt: Context, + expr_arena: &Arena, + ) -> PolarsResult<&DataType> { + match self.output_dtype.get() { + Some(dtype) => Ok(dtype), + None => { + let dtype = expr_arena + .get(self.node) + .to_dtype(schema, ctxt, expr_arena)?; + let _ = self.output_dtype.set(dtype); + Ok(self.output_dtype.get().unwrap()) + }, + } + } + + pub fn field( + &self, + schema: &Schema, + ctxt: Context, + expr_arena: &Arena, + ) -> PolarsResult { + let dtype = self.dtype(schema, ctxt, expr_arena)?; + let name = self.output_name(); + Ok(Field::new(name.clone(), dtype.clone())) + } +} + +impl AsRef for ExprIR { + fn as_ref(&self) -> &ExprIR { + self + } +} + +/// A Node that is restricted to `AExpr::Column` +#[repr(transparent)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct ColumnNode(pub(crate) Node); + +impl From for Node { + fn from(value: ColumnNode) -> Self { + value.0 + } +} +impl From<&ExprIR> for Node { + fn from(value: &ExprIR) -> Self { + value.node() + } +} + +pub(crate) fn name_to_expr_ir(name: PlSmallStr, expr_arena: &mut Arena) -> ExprIR { + let node = expr_arena.add(AExpr::Column(name.clone())); + ExprIR::new(node, OutputName::ColumnLhs(name)) +} + +pub(crate) fn names_to_expr_irs(names: I, expr_arena: &mut Arena) -> Vec +where + I: IntoIterator, + S: Into, +{ + names + .into_iter() + .map(|name| { + let name = name.into(); + name_to_expr_ir(name, expr_arena) + }) + .collect() +} diff --git a/crates/polars-plan/src/plans/functions/count.rs b/crates/polars-plan/src/plans/functions/count.rs new file mode 100644 index 000000000000..2c4bf7c0501a --- /dev/null +++ b/crates/polars-plan/src/plans/functions/count.rs @@ -0,0 +1,248 @@ +#[cfg(feature = "ipc")] +use arrow::io::ipc::read::get_row_count as count_rows_ipc_sync; +#[cfg(any( + feature = "parquet", + feature = "ipc", + feature = "json", + feature = "csv" +))] +use polars_core::error::feature_gated; +#[cfg(any(feature = "json", feature = "parquet"))] +use polars_io::SerReader; +#[cfg(any(feature = "parquet", feature = "json"))] +use polars_io::cloud::CloudOptions; +#[cfg(feature = "parquet")] +use polars_io::parquet::read::ParquetReader; +#[cfg(all(feature = "parquet", feature = "async"))] +use polars_io::pl_async::{get_runtime, with_concurrency_budget}; + +use super::*; + +#[allow(unused_variables)] +pub fn count_rows( + sources: &ScanSources, + scan_type: &FileScan, + cloud_options: Option<&CloudOptions>, + alias: Option, +) -> PolarsResult { + #[cfg(not(any( + feature = "parquet", + feature = "ipc", + feature = "json", + feature = "csv" + )))] + { + unreachable!() + } + + #[cfg(any( + feature = "parquet", + feature = "ipc", + feature = "json", + feature = "csv" + ))] + { + let count: PolarsResult = match scan_type { + #[cfg(feature = "csv")] + FileScan::Csv { options } => count_all_rows_csv(sources, options), + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => count_rows_parquet(sources, cloud_options), + #[cfg(feature = "ipc")] + FileScan::Ipc { options, metadata } => count_rows_ipc( + sources, + #[cfg(feature = "cloud")] + cloud_options, + metadata.as_deref(), + ), + #[cfg(feature = "json")] + FileScan::NDJson { options } => count_rows_ndjson(sources, cloud_options), + FileScan::Anonymous { .. } => { + unreachable!() + }, + }; + let count = count?; + let count: IdxSize = count.try_into().map_err( + |_| polars_err!(ComputeError: "count of {} exceeded maximum row size", count), + )?; + let column_name = alias.unwrap_or(PlSmallStr::from_static(crate::constants::LEN)); + DataFrame::new(vec![Column::new(column_name, [count])]) + } +} + +#[cfg(feature = "csv")] +fn count_all_rows_csv( + sources: &ScanSources, + options: &polars_io::prelude::CsvReadOptions, +) -> PolarsResult { + let parse_options = options.get_parse_options(); + + sources + .iter() + .map(|source| match source { + ScanSourceRef::Path(path) => polars_io::csv::read::count_rows( + path, + parse_options.separator, + parse_options.quote_char, + parse_options.comment_prefix.as_ref(), + parse_options.eol_char, + options.has_header, + ), + _ => { + let memslice = source.to_memslice()?; + + polars_io::csv::read::count_rows_from_slice_par( + &memslice[..], + parse_options.separator, + parse_options.quote_char, + parse_options.comment_prefix.as_ref(), + parse_options.eol_char, + options.has_header, + ) + }, + }) + .sum() +} + +#[cfg(feature = "parquet")] +pub(super) fn count_rows_parquet( + sources: &ScanSources, + #[allow(unused)] cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + if sources.is_empty() { + return Ok(0); + }; + let is_cloud = sources.is_cloud_url(); + + if is_cloud { + feature_gated!("cloud", { + get_runtime().block_on(count_rows_cloud_parquet( + sources.as_paths().unwrap(), + cloud_options, + )) + }) + } else { + sources + .iter() + .map(|source| { + ParquetReader::new(std::io::Cursor::new(source.to_memslice()?)).num_rows() + }) + .sum::>() + } +} + +#[cfg(all(feature = "parquet", feature = "async"))] +async fn count_rows_cloud_parquet( + paths: &[std::path::PathBuf], + cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + use polars_io::prelude::ParquetObjectStore; + + let collection = paths.iter().map(|path| { + with_concurrency_budget(1, || async { + let mut reader = + ParquetObjectStore::from_uri(&path.to_string_lossy(), cloud_options, None).await?; + reader.num_rows().await + }) + }); + futures::future::try_join_all(collection) + .await + .map(|rows| rows.iter().sum()) +} + +#[cfg(feature = "ipc")] +pub(super) fn count_rows_ipc( + sources: &ScanSources, + #[cfg(feature = "cloud")] cloud_options: Option<&CloudOptions>, + metadata: Option<&arrow::io::ipc::read::FileMetadata>, +) -> PolarsResult { + if sources.is_empty() { + return Ok(0); + }; + let is_cloud = sources.is_cloud_url(); + + if is_cloud { + feature_gated!("cloud", { + get_runtime().block_on(count_rows_cloud_ipc( + sources.as_paths().unwrap(), + cloud_options, + metadata, + )) + }) + } else { + sources + .iter() + .map(|source| { + let memslice = source.to_memslice()?; + count_rows_ipc_sync(&mut std::io::Cursor::new(memslice)).map(|v| v as usize) + }) + .sum::>() + } +} + +#[cfg(all(feature = "ipc", feature = "async"))] +async fn count_rows_cloud_ipc( + paths: &[std::path::PathBuf], + cloud_options: Option<&CloudOptions>, + metadata: Option<&arrow::io::ipc::read::FileMetadata>, +) -> PolarsResult { + use polars_io::ipc::IpcReaderAsync; + + let collection = paths.iter().map(|path| { + with_concurrency_budget(1, || async { + let reader = IpcReaderAsync::from_uri(&path.to_string_lossy(), cloud_options).await?; + reader.count_rows(metadata).await + }) + }); + futures::future::try_join_all(collection) + .await + .map(|rows| rows.iter().map(|v| *v as usize).sum()) +} + +#[cfg(feature = "json")] +pub(super) fn count_rows_ndjson( + sources: &ScanSources, + cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + use polars_core::config; + use polars_io::utils::compression::maybe_decompress_bytes; + + if sources.is_empty() { + return Ok(0); + } + + let is_cloud_url = sources.is_cloud_url(); + let run_async = is_cloud_url || (sources.is_paths() && config::force_async()); + + let cache_entries = { + if run_async { + feature_gated!("cloud", { + Some(polars_io::file_cache::init_entries_from_uri_list( + sources + .as_paths() + .unwrap() + .iter() + .map(|path| Arc::from(path.to_str().unwrap())) + .collect::>() + .as_slice(), + cloud_options, + )?) + }) + } else { + None + } + }; + + sources + .iter() + .map(|source| { + let memslice = + source.to_memslice_possibly_async(run_async, cache_entries.as_ref(), 0)?; + + let owned = &mut vec![]; + let reader = polars_io::ndjson::core::JsonLineReader::new(std::io::Cursor::new( + maybe_decompress_bytes(&memslice[..], owned)?, + )); + reader.count() + }) + .sum() +} diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs new file mode 100644 index 000000000000..7586c438788d --- /dev/null +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -0,0 +1,183 @@ +use polars_compute::rolling::QuantileMethod; +use strum_macros::IntoStaticStr; + +use super::*; + +#[cfg(feature = "python")] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone)] +pub struct OpaquePythonUdf { + pub function: PythonFunction, + pub schema: Option, + /// allow predicate pushdown optimizations + pub predicate_pd: bool, + /// allow projection pushdown optimizations + pub projection_pd: bool, + pub streamable: bool, + pub validate_output: bool, +} + +// Except for Opaque functions, this only has the DSL name of the function. +#[derive(Clone, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "SCREAMING_SNAKE_CASE")] +pub enum DslFunction { + RowIndex { + name: PlSmallStr, + offset: Option, + }, + // This is both in DSL and IR because we want to be able to serialize it. + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf), + Explode { + columns: Vec, + allow_empty: bool, + }, + #[cfg(feature = "pivot")] + Unpivot { + args: UnpivotArgsDSL, + }, + Rename { + existing: Arc<[PlSmallStr]>, + new: Arc<[PlSmallStr]>, + strict: bool, + }, + Unnest(Vec), + Stats(StatsFunction), + /// FillValue + FillNan(Expr), + Drop(DropFunction), + // Function that is already converted to IR. + #[cfg_attr(feature = "serde", serde(skip))] + FunctionIR(FunctionIR), +} + +#[derive(Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DropFunction { + /// Columns that are going to be dropped + pub(crate) to_drop: Vec, + /// If `true`, performs a check for each item in `to_drop` against the schema. Returns an + /// `ColumnNotFound` error if the column does not exist in the schema. + pub(crate) strict: bool, +} + +#[derive(Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum StatsFunction { + Var { + ddof: u8, + }, + Std { + ddof: u8, + }, + Quantile { + quantile: Expr, + method: QuantileMethod, + }, + Median, + Mean, + Sum, + Min, + Max, +} + +pub(crate) fn validate_columns_in_input, I: IntoIterator>( + columns: I, + input_schema: &Schema, + operation_name: &str, +) -> PolarsResult<()> { + let columns = columns.into_iter(); + for c in columns { + polars_ensure!(input_schema.contains(c.as_ref()), ColumnNotFound: "'{}' on column: '{}' is invalid\n\nSchema at this point: {:?}", operation_name, c.as_ref(), input_schema) + } + Ok(()) +} + +impl DslFunction { + pub(crate) fn into_function_ir(self, input_schema: &Schema) -> PolarsResult { + let function = match self { + #[cfg(feature = "pivot")] + DslFunction::Unpivot { args } => { + let on = expand_selectors(args.on, input_schema, &[])?; + let index = expand_selectors(args.index, input_schema, &[])?; + validate_columns_in_input(on.as_ref(), input_schema, "unpivot")?; + validate_columns_in_input(index.as_ref(), input_schema, "unpivot")?; + + let args = UnpivotArgsIR { + on: on.iter().cloned().collect(), + index: index.iter().cloned().collect(), + variable_name: args.variable_name.clone(), + value_name: args.value_name.clone(), + }; + + FunctionIR::Unpivot { + args: Arc::new(args), + schema: Default::default(), + } + }, + DslFunction::FunctionIR(func) => func, + DslFunction::RowIndex { name, offset } => FunctionIR::RowIndex { + name, + offset, + schema: Default::default(), + }, + DslFunction::Rename { + existing, + new, + strict, + } => { + let swapping = new.iter().any(|name| input_schema.get(name).is_some()); + if strict { + validate_columns_in_input(existing.as_ref(), input_schema, "rename")?; + } + FunctionIR::Rename { + existing, + new, + swapping, + schema: Default::default(), + } + }, + DslFunction::Unnest(selectors) => { + let columns = expand_selectors(selectors, input_schema, &[])?; + validate_columns_in_input(columns.as_ref(), input_schema, "explode")?; + FunctionIR::Unnest { columns } + }, + #[cfg(feature = "python")] + DslFunction::OpaquePython(inner) => FunctionIR::OpaquePython(inner), + DslFunction::Stats(_) + | DslFunction::FillNan(_) + | DslFunction::Drop(_) + | DslFunction::Explode { .. } => { + // We should not reach this. + panic!("impl error") + }, + }; + Ok(function) + } +} + +impl Debug for DslFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +impl Display for DslFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use DslFunction::*; + match self { + FunctionIR(inner) => write!(f, "{inner}"), + v => { + let s: &str = v.into(); + write!(f, "{s}") + }, + } + } +} + +impl From for DslFunction { + fn from(value: FunctionIR) -> Self { + DslFunction::FunctionIR(value) + } +} diff --git a/crates/polars-plan/src/plans/functions/mod.rs b/crates/polars-plan/src/plans/functions/mod.rs new file mode 100644 index 000000000000..630966039a13 --- /dev/null +++ b/crates/polars-plan/src/plans/functions/mod.rs @@ -0,0 +1,350 @@ +mod count; +mod dsl; +#[cfg(feature = "python")] +mod python_udf; +mod rename; +mod schema; + +use std::borrow::Cow; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; + +pub use dsl::*; +use polars_core::error::feature_gated; +use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +#[cfg(feature = "python")] +use crate::dsl::python_dsl::PythonFunction; +use crate::plans::ir::ScanSourcesDisplay; +use crate::prelude::*; + +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +#[derive(Clone, IntoStaticStr)] +#[strum(serialize_all = "SCREAMING_SNAKE_CASE")] +pub enum FunctionIR { + RowIndex { + name: PlSmallStr, + offset: Option, + // Might be cached. + #[cfg_attr(feature = "ir_serde", serde(skip))] + schema: CachedSchema, + }, + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf), + + FastCount { + sources: ScanSources, + scan_type: Box, + cloud_options: Option, + alias: Option, + }, + + Unnest { + columns: Arc<[PlSmallStr]>, + }, + Rechunk, + Rename { + existing: Arc<[PlSmallStr]>, + new: Arc<[PlSmallStr]>, + // A column name gets swapped with an existing column + swapping: bool, + #[cfg_attr(feature = "ir_serde", serde(skip))] + schema: CachedSchema, + }, + Explode { + columns: Arc<[PlSmallStr]>, + #[cfg_attr(feature = "ir_serde", serde(skip))] + schema: CachedSchema, + }, + #[cfg(feature = "pivot")] + Unpivot { + args: Arc, + #[cfg_attr(feature = "ir_serde", serde(skip))] + schema: CachedSchema, + }, + #[cfg_attr(feature = "ir_serde", serde(skip))] + Opaque { + function: Arc, + schema: Option>, + /// allow predicate pushdown optimizations + predicate_pd: bool, + /// allow projection pushdown optimizations + projection_pd: bool, + streamable: bool, + // used for formatting + fmt_str: PlSmallStr, + }, + /// Streaming engine pipeline + #[cfg_attr(feature = "ir_serde", serde(skip))] + Pipeline { + function: Arc>, + schema: SchemaRef, + original: Option>, + }, +} + +impl Eq for FunctionIR {} + +impl PartialEq for FunctionIR { + fn eq(&self, other: &Self) -> bool { + use FunctionIR::*; + match (self, other) { + (Rechunk, Rechunk) => true, + ( + FastCount { + sources: srcs_l, .. + }, + FastCount { + sources: srcs_r, .. + }, + ) => srcs_l == srcs_r, + ( + Rename { + existing: existing_l, + new: new_l, + .. + }, + Rename { + existing: existing_r, + new: new_r, + .. + }, + ) => existing_l == existing_r && new_l == new_r, + (Explode { columns: l, .. }, Explode { columns: r, .. }) => l == r, + #[cfg(feature = "pivot")] + (Unpivot { args: l, .. }, Unpivot { args: r, .. }) => l == r, + (RowIndex { name: l, .. }, RowIndex { name: r, .. }) => l == r, + _ => false, + } + } +} + +impl Hash for FunctionIR { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + #[cfg(feature = "python")] + FunctionIR::OpaquePython { .. } => {}, + FunctionIR::Opaque { fmt_str, .. } => fmt_str.hash(state), + FunctionIR::FastCount { + sources, + scan_type, + cloud_options, + alias, + } => { + sources.hash(state); + scan_type.hash(state); + cloud_options.hash(state); + alias.hash(state); + }, + FunctionIR::Pipeline { .. } => {}, + FunctionIR::Unnest { columns } => columns.hash(state), + FunctionIR::Rechunk => {}, + FunctionIR::Rename { + existing, + new, + swapping: _, + .. + } => { + existing.hash(state); + new.hash(state); + }, + FunctionIR::Explode { columns, schema: _ } => columns.hash(state), + #[cfg(feature = "pivot")] + FunctionIR::Unpivot { args, schema: _ } => args.hash(state), + FunctionIR::RowIndex { + name, + schema: _, + offset, + } => { + name.hash(state); + offset.hash(state); + }, + } + } +} + +impl FunctionIR { + /// Whether this function can run on batches of data at a time. + pub fn is_streamable(&self) -> bool { + use FunctionIR::*; + match self { + Rechunk | Pipeline { .. } => false, + FastCount { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, + Opaque { streamable, .. } => *streamable, + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf { streamable, .. }) => *streamable, + RowIndex { .. } => false, + } + } + + /// Whether this function will increase the number of rows + pub fn expands_rows(&self) -> bool { + use FunctionIR::*; + match self { + #[cfg(feature = "pivot")] + Unpivot { .. } => true, + Explode { .. } => true, + _ => false, + } + } + + pub(crate) fn allow_predicate_pd(&self) -> bool { + use FunctionIR::*; + match self { + Opaque { predicate_pd, .. } => *predicate_pd, + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf { predicate_pd, .. }) => *predicate_pd, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, + Rechunk | Unnest { .. } | Rename { .. } | Explode { .. } => true, + RowIndex { .. } | FastCount { .. } => false, + Pipeline { .. } => unimplemented!(), + } + } + + pub(crate) fn allow_projection_pd(&self) -> bool { + use FunctionIR::*; + match self { + Opaque { projection_pd, .. } => *projection_pd, + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf { projection_pd, .. }) => *projection_pd, + Rechunk | FastCount { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, + RowIndex { .. } => true, + Pipeline { .. } => unimplemented!(), + } + } + + pub(crate) fn additional_projection_pd_columns(&self) -> Cow<[PlSmallStr]> { + use FunctionIR::*; + match self { + Unnest { columns } => Cow::Borrowed(columns.as_ref()), + Explode { columns, .. } => Cow::Borrowed(columns.as_ref()), + _ => Cow::Borrowed(&[]), + } + } + + pub fn evaluate(&self, mut df: DataFrame) -> PolarsResult { + use FunctionIR::*; + match self { + Opaque { function, .. } => function.call_udf(df), + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf { + function, + validate_output, + schema, + .. + }) => python_udf::call_python_udf(function, df, *validate_output, schema.clone()), + FastCount { + sources, + scan_type, + cloud_options, + alias, + } => count::count_rows(sources, scan_type, cloud_options.as_ref(), alias.clone()), + Rechunk => { + df.as_single_chunk_par(); + Ok(df) + }, + Unnest { columns: _columns } => { + feature_gated!("dtype-struct", df.unnest(_columns.iter().cloned())) + }, + Pipeline { function, .. } => { + // we use a global string cache here as streaming chunks all have different rev maps + #[cfg(feature = "dtype-categorical")] + { + let _sc = StringCacheHolder::hold(); + function.lock().unwrap().call_udf(df) + } + + #[cfg(not(feature = "dtype-categorical"))] + { + function.lock().unwrap().call_udf(df) + } + }, + Rename { existing, new, .. } => rename::rename_impl(df, existing, new), + Explode { columns, .. } => df.explode(columns.iter().cloned()), + #[cfg(feature = "pivot")] + Unpivot { args, .. } => { + use polars_ops::pivot::UnpivotDF; + let args = (**args).clone(); + df.unpivot2(args) + }, + RowIndex { name, offset, .. } => df.with_row_index(name.clone(), *offset), + } + } + + pub fn to_streaming_lp(&self) -> Option { + let Self::Pipeline { + function: _, + schema: _, + original, + } = self + else { + return None; + }; + + Some(original.as_ref()?.as_ref().as_ref()) + } +} + +impl Debug for FunctionIR { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +impl Display for FunctionIR { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use FunctionIR::*; + match self { + Opaque { fmt_str, .. } => write!(f, "{fmt_str}"), + Unnest { columns } => { + write!(f, "UNNEST by:")?; + let columns = columns.as_ref(); + fmt_column_delimited(f, columns, "[", "]") + }, + Pipeline { original, .. } => { + if let Some(original) = original { + let ir_display = original.as_ref().display(); + + writeln!(f, "--- STREAMING")?; + write!(f, "{ir_display}")?; + let indent = 2; + write!(f, "{:indent$}--- END STREAMING", "") + } else { + write!(f, "STREAMING") + } + }, + FastCount { + sources, + scan_type, + cloud_options: _, + alias, + } => { + let scan_type: &str = (&(**scan_type)).into(); + let default_column_name = PlSmallStr::from_static(crate::constants::LEN); + let alias = alias.as_ref().unwrap_or(&default_column_name); + + write!( + f, + "FAST COUNT ({scan_type}) {} as \"{alias}\"", + ScanSourcesDisplay(sources) + ) + }, + v => { + let s: &str = v.into(); + write!(f, "{s}") + }, + } + } +} diff --git a/crates/polars-plan/src/plans/functions/python_udf.rs b/crates/polars-plan/src/plans/functions/python_udf.rs new file mode 100644 index 000000000000..2dac388cd2f7 --- /dev/null +++ b/crates/polars-plan/src/plans/functions/python_udf.rs @@ -0,0 +1,36 @@ +use super::*; + +pub(super) fn call_python_udf( + function: &PythonFunction, + df: DataFrame, + validate_output: bool, + opt_schema: Option, +) -> PolarsResult { + let expected_schema = if let Some(schema) = opt_schema { + Some(schema) + } + // only materialize if we validate the output + else if validate_output { + Some(df.schema().clone()) + } + // do not materialize the schema, we will ignore it. + else { + None + }; + let out = DataFrameUdf::call_udf(function, df)?; + + if validate_output { + let output_schema = out.schema(); + let expected = expected_schema.unwrap(); + if &expected != output_schema { + return Err(PolarsError::ComputeError( + format!( + "The output schema of 'LazyFrame.map' is incorrect. Expected: {expected:?}\n\ + Got: {output_schema:?}" + ) + .into(), + )); + } + } + Ok(out) +} diff --git a/crates/polars-plan/src/plans/functions/rename.rs b/crates/polars-plan/src/plans/functions/rename.rs new file mode 100644 index 000000000000..49b99a3cc6f6 --- /dev/null +++ b/crates/polars-plan/src/plans/functions/rename.rs @@ -0,0 +1,32 @@ +use super::*; + +pub(super) fn rename_impl( + mut df: DataFrame, + existing: &[PlSmallStr], + new: &[PlSmallStr], +) -> PolarsResult { + let positions = if existing.len() > 1 && df.get_columns().len() > 10 { + let schema = df.schema(); + existing + .iter() + .map(|old| schema.get_full(old).map(|(idx, _, _)| idx)) + .collect::>() + } else { + existing + .iter() + .map(|old| df.get_column_index(old)) + .collect::>() + }; + + for (pos, name) in positions.iter().zip(new.iter()) { + // the column might be removed due to projection pushdown + // so we only update if we can find it. + if let Some(pos) = pos { + // SAFETY: We do not adjust the columns except their names + unsafe { df.get_columns_mut()[*pos].rename(name.clone()) }; + } + } + // recreate dataframe so we check duplicates + let columns = df.take_columns(); + DataFrame::new(columns) +} diff --git a/crates/polars-plan/src/plans/functions/schema.rs b/crates/polars-plan/src/plans/functions/schema.rs new file mode 100644 index 000000000000..6264464d61e2 --- /dev/null +++ b/crates/polars-plan/src/plans/functions/schema.rs @@ -0,0 +1,225 @@ +#[cfg(feature = "pivot")] +use polars_core::utils::try_get_supertype; + +use super::*; +use crate::constants::get_len_name; + +impl FunctionIR { + pub(crate) fn clear_cached_schema(&self) { + use FunctionIR::*; + // We will likely add more branches later + #[allow(clippy::single_match)] + match self { + #[cfg(feature = "pivot")] + Unpivot { schema, .. } => { + let mut guard = schema.lock().unwrap(); + *guard = None; + }, + RowIndex { schema, .. } | Explode { schema, .. } | Rename { schema, .. } => { + let mut guard = schema.lock().unwrap(); + *guard = None; + }, + _ => {}, + } + } + + pub(crate) fn schema<'a>( + &self, + input_schema: &'a SchemaRef, + ) -> PolarsResult> { + use FunctionIR::*; + match self { + Opaque { schema, .. } => match schema { + None => Ok(Cow::Borrowed(input_schema)), + Some(schema_fn) => { + let output_schema = schema_fn.get_schema(input_schema)?; + Ok(Cow::Owned(output_schema)) + }, + }, + #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf { schema, .. }) => Ok(schema + .as_ref() + .map(|schema| Cow::Owned(schema.clone())) + .unwrap_or_else(|| Cow::Borrowed(input_schema))), + Pipeline { schema, .. } => Ok(Cow::Owned(schema.clone())), + FastCount { alias, .. } => { + let mut schema: Schema = Schema::with_capacity(1); + let name = alias.clone().unwrap_or_else(get_len_name); + schema.insert_at_index(0, name, IDX_DTYPE)?; + Ok(Cow::Owned(Arc::new(schema))) + }, + Rechunk => Ok(Cow::Borrowed(input_schema)), + Unnest { columns: _columns } => { + #[cfg(feature = "dtype-struct")] + { + let mut new_schema = Schema::with_capacity(input_schema.len() * 2); + for (name, dtype) in input_schema.iter() { + if _columns.iter().any(|item| item == name) { + match dtype { + DataType::Struct(flds) => { + for fld in flds { + new_schema + .with_column(fld.name().clone(), fld.dtype().clone()); + } + }, + DataType::Unknown(_) => { + // pass through unknown + }, + _ => { + polars_bail!( + SchemaMismatch: "expected struct dtype, got: `{}`", dtype + ); + }, + } + } else { + new_schema.with_column(name.clone(), dtype.clone()); + } + } + + Ok(Cow::Owned(Arc::new(new_schema))) + } + #[cfg(not(feature = "dtype-struct"))] + { + panic!("activate feature 'dtype-struct'") + } + }, + Rename { + existing, + new, + schema, + .. + } => rename_schema(input_schema, existing, new, schema), + RowIndex { schema, name, .. } => Ok(Cow::Owned(row_index_schema( + schema, + input_schema, + name.clone(), + ))), + Explode { schema, columns } => explode_schema(schema, input_schema, columns), + #[cfg(feature = "pivot")] + Unpivot { schema, args } => unpivot_schema(args, schema, input_schema), + } + } +} + +fn row_index_schema( + cached_schema: &CachedSchema, + input_schema: &SchemaRef, + name: PlSmallStr, +) -> SchemaRef { + let mut guard = cached_schema.lock().unwrap(); + if let Some(schema) = &*guard { + return schema.clone(); + } + let mut schema = (**input_schema).clone(); + schema.insert_at_index(0, name, IDX_DTYPE).unwrap(); + let schema_ref = Arc::new(schema); + *guard = Some(schema_ref.clone()); + schema_ref +} + +fn explode_schema<'a>( + cached_schema: &CachedSchema, + schema: &'a Schema, + columns: &[PlSmallStr], +) -> PolarsResult> { + let mut guard = cached_schema.lock().unwrap(); + if let Some(schema) = &*guard { + return Ok(Cow::Owned(schema.clone())); + } + let mut schema = schema.clone(); + + // columns to string + columns.iter().try_for_each(|name| { + match schema.try_get(name)? { + DataType::List(inner) => { + schema.with_column(name.clone(), inner.as_ref().clone()); + }, + #[cfg(feature = "dtype-array")] + DataType::Array(inner, _) => { + schema.with_column(name.clone(), inner.as_ref().clone()); + }, + _ => {}, + } + + PolarsResult::Ok(()) + })?; + let schema = Arc::new(schema); + *guard = Some(schema.clone()); + Ok(Cow::Owned(schema)) +} + +#[cfg(feature = "pivot")] +fn unpivot_schema<'a>( + args: &UnpivotArgsIR, + cached_schema: &CachedSchema, + input_schema: &'a Schema, +) -> PolarsResult> { + let mut guard = cached_schema.lock().unwrap(); + if let Some(schema) = &*guard { + return Ok(Cow::Owned(schema.clone())); + } + + let mut new_schema = args + .index + .iter() + .map(|id| Ok(Field::new(id.clone(), input_schema.try_get(id)?.clone()))) + .collect::>()?; + let variable_name = args + .variable_name + .as_ref() + .cloned() + .unwrap_or_else(|| "variable".into()); + let value_name = args + .value_name + .as_ref() + .cloned() + .unwrap_or_else(|| "value".into()); + + new_schema.with_column(variable_name, DataType::String); + + // We need to determine the supertype of all value columns. + let mut supertype = DataType::Null; + + // take all columns that are not in `id_vars` as `value_var` + if args.on.is_empty() { + let id_vars = PlHashSet::from_iter(&args.index); + for (name, dtype) in input_schema.iter() { + if !id_vars.contains(name) { + supertype = try_get_supertype(&supertype, dtype)?; + } + } + } else { + for name in &args.on { + let dtype = input_schema.try_get(name)?; + supertype = try_get_supertype(&supertype, dtype)?; + } + } + new_schema.with_column(value_name, supertype); + let schema = Arc::new(new_schema); + *guard = Some(schema.clone()); + Ok(Cow::Owned(schema)) +} + +fn rename_schema<'a>( + input_schema: &'a SchemaRef, + existing: &[PlSmallStr], + new: &[PlSmallStr], + cached_schema: &CachedSchema, +) -> PolarsResult> { + let mut guard = cached_schema.lock().unwrap(); + if let Some(schema) = &*guard { + return Ok(Cow::Owned(schema.clone())); + } + let mut new_schema = input_schema.iter_fields().collect::>(); + + for (old, new) in existing.iter().zip(new.iter()) { + // The column might be removed due to projection pushdown + // so we only update if we can find it. + if let Some((idx, _, _)) = input_schema.get_full(old) { + new_schema[idx].name = new.as_str().into(); + } + } + let schema: SchemaRef = Arc::new(new_schema.into_iter().collect()); + *guard = Some(schema.clone()); + Ok(Cow::Owned(schema)) +} diff --git a/crates/polars-plan/src/plans/hive.rs b/crates/polars-plan/src/plans/hive.rs new file mode 100644 index 000000000000..ddccac14c43b --- /dev/null +++ b/crates/polars-plan/src/plans/hive.rs @@ -0,0 +1,252 @@ +use std::path::{Path, PathBuf}; + +use polars_core::prelude::*; +use polars_io::prelude::schema_inference::{finish_infer_field_schema, infer_field_schema}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct HivePartitionsDf(DataFrame); + +impl HivePartitionsDf { + pub fn get_projection_schema_and_indices( + &self, + names: &PlHashSet, + ) -> (SchemaRef, Vec) { + let mut out_schema = Schema::with_capacity(self.schema().len()); + let mut out_indices = Vec::with_capacity(self.0.get_columns().len()); + + for (i, column) in self.0.get_columns().iter().enumerate() { + let name = column.name(); + if names.contains(name.as_str()) { + out_indices.push(i); + out_schema + .insert_at_index(out_schema.len(), name.clone(), column.dtype().clone()) + .unwrap(); + } + } + + (out_schema.into(), out_indices) + } + + pub fn apply_projection(&mut self, column_indices: &[usize]) { + let schema = self.schema(); + let projected_schema = schema.try_project_indices(column_indices).unwrap(); + self.0 = self.0.select(projected_schema.iter_names_cloned()).unwrap(); + } + + pub fn take_indices(&self, row_indexes: &[IdxSize]) -> Self { + if !row_indexes.is_empty() { + let mut max_idx = 0; + for &i in row_indexes { + max_idx = max_idx.max(i); + } + assert!(max_idx < self.0.height() as IdxSize); + } + // SAFETY: Checked bounds before. + Self(unsafe { self.0.take_slice_unchecked(row_indexes) }) + } + + pub fn df(&self) -> &DataFrame { + &self.0 + } + + pub fn schema(&self) -> &SchemaRef { + self.0.schema() + } +} + +impl From for HivePartitionsDf { + fn from(value: DataFrame) -> Self { + Self(value) + } +} + +/// Note: Returned hive partitions are ordered by their position in the `reader_schema` +/// +/// # Safety +/// `hive_start_idx <= [min path length]` +pub fn hive_partitions_from_paths( + paths: &[PathBuf], + hive_start_idx: usize, + schema: Option, + reader_schema: &Schema, + try_parse_dates: bool, +) -> PolarsResult> { + let Some(path) = paths.first() else { + return Ok(None); + }; + + let sep = separator(path); + let path_string = path.to_str().unwrap(); + + fn parse_hive_string_and_decode(part: &'_ str) -> Option<(&'_ str, std::borrow::Cow<'_, str>)> { + let (k, v) = parse_hive_string(part)?; + let v = percent_encoding::percent_decode(v.as_bytes()) + .decode_utf8() + .ok()?; + + Some((k, v)) + } + + macro_rules! get_hive_parts_iter { + ($e:expr) => {{ + let path_parts = $e[hive_start_idx..].split(sep); + let file_index = path_parts.clone().count() - 1; + + path_parts.enumerate().filter_map(move |(index, part)| { + if index == file_index { + return None; + } + + parse_hive_string_and_decode(part) + }) + }}; + } + + let hive_schema = if let Some(ref schema) = schema { + Arc::new(get_hive_parts_iter!(path_string).map(|(name, _)| { + let Some(dtype) = schema.get(name) else { + polars_bail!( + SchemaFieldNotFound: + "path contains column not present in the given Hive schema: {:?}, path = {:?}", + name, + path + ) + }; + + let dtype = if !try_parse_dates && dtype.is_temporal() { + DataType::String + } else { + dtype.clone() + }; + + Ok(Field::new(PlSmallStr::from_str(name), dtype)) + }).collect::>()?) + } else { + let mut hive_schema = Schema::with_capacity(16); + let mut schema_inference_map: PlHashMap<&str, PlHashSet> = + PlHashMap::with_capacity(16); + + for (name, _) in get_hive_parts_iter!(path_string) { + // If the column is also in the file we can use the dtype stored there. + if let Some(dtype) = reader_schema.get(name) { + let dtype = if !try_parse_dates && dtype.is_temporal() { + DataType::String + } else { + dtype.clone() + }; + + hive_schema.insert_at_index(hive_schema.len(), name.into(), dtype.clone())?; + continue; + } + + hive_schema.insert_at_index(hive_schema.len(), name.into(), DataType::String)?; + schema_inference_map.insert(name, PlHashSet::with_capacity(4)); + } + + if hive_schema.is_empty() && schema_inference_map.is_empty() { + return Ok(None); + } + + if !schema_inference_map.is_empty() { + for path in paths { + for (name, value) in get_hive_parts_iter!(path.to_str().unwrap()) { + let Some(entry) = schema_inference_map.get_mut(name) else { + continue; + }; + + if value.is_empty() || value == "__HIVE_DEFAULT_PARTITION__" { + continue; + } + + entry.insert(infer_field_schema(value.as_ref(), try_parse_dates, false)); + } + } + + for (name, ref possibilities) in schema_inference_map.drain() { + let dtype = finish_infer_field_schema(possibilities); + *hive_schema.try_get_mut(name).unwrap() = dtype; + } + } + Arc::new(hive_schema) + }; + + let mut buffers = polars_io::csv::read::buffer::init_buffers( + &(0..hive_schema.len()).collect::>(), + paths.len(), + hive_schema.as_ref(), + None, + polars_io::prelude::CsvEncoding::Utf8, + false, + )?; + + for path in paths { + let path = path.to_str().unwrap(); + + for (name, value) in get_hive_parts_iter!(path) { + let Some(index) = hive_schema.index_of(name) else { + polars_bail!( + SchemaFieldNotFound: + "path contains column not present in the given Hive schema: {:?}, path = {:?}", + name, + path + ) + }; + + let buf = buffers.get_mut(index).unwrap(); + + if !value.is_empty() && value != "__HIVE_DEFAULT_PARTITION__" { + buf.add(value.as_bytes(), false, false, false)?; + } else { + buf.add_null(false); + } + } + } + + let mut buffers = buffers + .into_iter() + .map(|x| Ok(x.into_series()?.into_column())) + .collect::>>()?; + buffers.sort_by_key(|s| reader_schema.index_of(s.name()).unwrap_or(usize::MAX)); + + Ok(Some(HivePartitionsDf(DataFrame::new_with_height( + paths.len(), + buffers, + )?))) +} + +/// Determine the path separator for identifying Hive partitions. +fn separator(url: &Path) -> &[char] { + if cfg!(target_family = "windows") { + if polars_io::path_utils::is_cloud_url(url) { + &['/'] + } else { + &['/', '\\'] + } + } else { + &['/'] + } +} + +/// Parse a Hive partition string (e.g. "column=1.5") into a name and value part. +/// +/// Returns `None` if the string is not a Hive partition string. +fn parse_hive_string(part: &'_ str) -> Option<(&'_ str, &'_ str)> { + let mut it = part.split('='); + let name = it.next()?; + let value = it.next()?; + + // Having multiple '=' doesn't seem like a valid Hive partition. + if it.next().is_some() { + return None; + } + + // Files are not Hive partitions, so globs are not valid. + if value.contains('*') { + return None; + }; + + Some((name, value)) +} diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs new file mode 100644 index 000000000000..9c6c572cf0a0 --- /dev/null +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -0,0 +1,465 @@ +use std::fmt; +use std::path::PathBuf; + +use polars_core::schema::Schema; +use polars_utils::pl_str::PlSmallStr; + +use super::format::ExprIRSliceDisplay; +use crate::constants::UNLIMITED_CACHE; +use crate::prelude::ir::format::ColumnsDisplay; +use crate::prelude::*; + +pub struct IRDotDisplay<'a> { + is_streaming: bool, + lp: IRPlanRef<'a>, +} + +const INDENT: &str = " "; + +#[derive(Clone, Copy)] +enum DotNode { + Plain(usize), + Cache(usize), +} + +impl fmt::Display for DotNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DotNode::Plain(n) => write!(f, "p{n}"), + DotNode::Cache(n) => write!(f, "c{n}"), + } + } +} + +#[inline(always)] +fn write_label<'a, 'b>( + f: &'a mut fmt::Formatter<'b>, + id: DotNode, + mut w: impl FnMut(&mut EscapeLabel<'a>) -> fmt::Result, +) -> fmt::Result { + write!(f, "{INDENT}{id}[label=\"")?; + + let mut escaped = EscapeLabel(f); + w(&mut escaped)?; + let EscapeLabel(f) = escaped; + + writeln!(f, "\"]")?; + + Ok(()) +} + +impl<'a> IRDotDisplay<'a> { + pub fn new(lp: IRPlanRef<'a>) -> Self { + if let Some(streaming_lp) = lp.extract_streaming_plan() { + return Self::new_streaming(streaming_lp); + } + + Self { + is_streaming: false, + lp, + } + } + + fn new_streaming(lp: IRPlanRef<'a>) -> Self { + Self { + is_streaming: true, + lp, + } + } + + fn with_root(&self, root: Node) -> Self { + Self { + is_streaming: false, + lp: self.lp.with_root(root), + } + } + + fn display_expr(&self, expr: &'a ExprIR) -> ExprIRDisplay<'a> { + expr.display(self.lp.expr_arena) + } + + fn display_exprs(&self, exprs: &'a [ExprIR]) -> ExprIRSliceDisplay<'a, ExprIR> { + ExprIRSliceDisplay { + exprs, + expr_arena: self.lp.expr_arena, + } + } + + fn _format( + &self, + f: &mut fmt::Formatter<'_>, + parent: Option, + last: &mut usize, + ) -> std::fmt::Result { + use fmt::Write; + + let root = self.lp.root(); + + let mut parent = parent; + if self.is_streaming { + *last += 1; + let streaming_node = DotNode::Plain(*last); + + if let Some(parent) = parent { + writeln!(f, "{INDENT}{parent} -- {streaming_node}")?; + write_label(f, streaming_node, |f| f.write_str("STREAMING"))?; + } + + parent = Some(streaming_node); + } + let parent = parent; + + let id = if let IR::Cache { id, .. } = root { + DotNode::Cache(*id) + } else { + *last += 1; + DotNode::Plain(*last) + }; + + if let Some(parent) = parent { + writeln!(f, "{INDENT}{parent} -- {id}")?; + } + + use IR::*; + match root { + Union { inputs, .. } => { + for input in inputs { + self.with_root(*input)._format(f, Some(id), last)?; + } + + write_label(f, id, |f| f.write_str("UNION"))?; + }, + HConcat { inputs, .. } => { + for input in inputs { + self.with_root(*input)._format(f, Some(id), last)?; + } + + write_label(f, id, |f| f.write_str("HCONCAT"))?; + }, + Cache { + input, cache_hits, .. + } => { + self.with_root(*input)._format(f, Some(id), last)?; + + if *cache_hits == UNLIMITED_CACHE { + write_label(f, id, |f| f.write_str("CACHE"))?; + } else { + write_label(f, id, |f| write!(f, "CACHE: {cache_hits} times"))?; + }; + }, + Filter { predicate, input } => { + self.with_root(*input)._format(f, Some(id), last)?; + + let pred = self.display_expr(predicate); + write_label(f, id, |f| write!(f, "FILTER BY {pred}"))?; + }, + #[cfg(feature = "python")] + PythonScan { options } => { + let predicate = match &options.predicate { + PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)), + PythonPredicate::PyArrow(s) => s.clone(), + PythonPredicate::None => "none".to_string(), + }; + let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref())); + let total_columns = options.schema.len(); + + write_label(f, id, |f| { + write!( + f, + "PYTHON SCAN\nπ {with_columns}/{total_columns};\nσ {predicate}" + ) + })? + }, + Select { + expr, + input, + schema, + .. + } => { + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "π {}/{}", expr.len(), schema.len()))?; + }, + Sort { + input, by_column, .. + } => { + let by_column = self.display_exprs(by_column); + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "SORT BY {by_column}"))?; + }, + GroupBy { + input, keys, aggs, .. + } => { + let keys = self.display_exprs(keys); + let aggs = self.display_exprs(aggs); + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "AGG {aggs}\nBY\n{keys}"))?; + }, + HStack { input, exprs, .. } => { + let exprs = self.display_exprs(exprs); + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "WITH COLUMNS {exprs}"))?; + }, + Slice { input, offset, len } => { + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "SLICE offset: {offset}; len: {len}"))?; + }, + Distinct { input, options, .. } => { + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| { + f.write_str("DISTINCT")?; + + if let Some(subset) = &options.subset { + f.write_str(" BY ")?; + + let mut subset = subset.iter(); + + if let Some(fst) = subset.next() { + f.write_str(fst)?; + for name in subset { + write!(f, ", \"{name}\"")?; + } + } else { + f.write_str("None")?; + } + } + + Ok(()) + })?; + }, + DataFrameScan { + schema, + output_schema, + .. + } => { + let num_columns = NumColumnsSchema(output_schema.as_ref().map(|p| p.as_ref())); + let total_columns = schema.len(); + + write_label(f, id, |f| { + write!(f, "TABLE\nπ {num_columns}/{total_columns}") + })?; + }, + Scan { + sources, + file_info, + hive_parts: _, + predicate, + scan_type, + unified_scan_args, + output_schema: _, + } => { + let name: &str = (&**scan_type).into(); + let path = ScanSourcesDisplay(sources); + let with_columns = unified_scan_args + .projection + .as_ref() + .map(|cols| cols.as_ref()); + let with_columns = NumColumns(with_columns); + let total_columns = + file_info.schema.len() - usize::from(unified_scan_args.row_index.is_some()); + + write_label(f, id, |f| { + write!(f, "{name} SCAN {path}\nπ {with_columns}/{total_columns};",)?; + + if let Some(predicate) = predicate.as_ref() { + write!(f, "\nσ {}", self.display_expr(predicate))?; + } + + if let Some(row_index) = unified_scan_args.row_index.as_ref() { + write!(f, "\nrow index: {} (+{})", row_index.name, row_index.offset)?; + } + + Ok(()) + })?; + }, + Join { + input_left, + input_right, + left_on, + right_on, + options, + .. + } => { + self.with_root(*input_left)._format(f, Some(id), last)?; + self.with_root(*input_right)._format(f, Some(id), last)?; + + let left_on = self.display_exprs(left_on); + let right_on = self.display_exprs(right_on); + + write_label(f, id, |f| { + write!( + f, + "JOIN {}\nleft: {left_on};\nright: {right_on}", + options.args.how + ) + })?; + }, + MapFunction { + input, function, .. + } => { + if let Some(streaming_lp) = function.to_streaming_lp() { + Self::new_streaming(streaming_lp)._format(f, Some(id), last)?; + } else { + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| write!(f, "{function}"))?; + } + }, + ExtContext { input, .. } => { + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| f.write_str("EXTERNAL_CONTEXT"))?; + }, + Sink { input, payload, .. } => { + self.with_root(*input)._format(f, Some(id), last)?; + + write_label(f, id, |f| { + f.write_str(match payload { + SinkTypeIR::Memory => "SINK (MEMORY)", + SinkTypeIR::File { .. } => "SINK (FILE)", + SinkTypeIR::Partition { .. } => "SINK (PARTITION)", + }) + })?; + }, + SinkMultiple { inputs } => { + for input in inputs { + self.with_root(*input)._format(f, Some(id), last)?; + } + + write_label(f, id, |f| f.write_str("SINK MULTIPLE"))?; + }, + SimpleProjection { input, columns } => { + let num_columns = columns.as_ref().len(); + let total_columns = self.lp.lp_arena.get(*input).schema(self.lp.lp_arena).len(); + + let columns = ColumnsDisplay(columns.as_ref()); + self.with_root(*input)._format(f, Some(id), last)?; + write_label(f, id, |f| { + write!(f, "simple π {num_columns}/{total_columns}\n[{columns}]") + })?; + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + key, + } => { + self.with_root(*input_left)._format(f, Some(id), last)?; + self.with_root(*input_right)._format(f, Some(id), last)?; + + write_label(f, id, |f| write!(f, "MERGE_SORTED ON '{key}'",))?; + }, + Invalid => write_label(f, id, |f| f.write_str("INVALID"))?, + } + + Ok(()) + } +} + +// A few utility structures for formatting +pub struct PathsDisplay<'a>(pub &'a [PathBuf]); +pub struct ScanSourcesDisplay<'a>(pub &'a ScanSources); +struct NumColumns<'a>(Option<&'a [PlSmallStr]>); +struct NumColumnsSchema<'a>(Option<&'a Schema>); + +impl fmt::Display for ScanSourceRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ScanSourceRef::Path(path) => path.display().fmt(f), + ScanSourceRef::File(_) => f.write_str("open-file"), + ScanSourceRef::Buffer(buff) => write!(f, "{} in-mem bytes", buff.len()), + } + } +} + +impl fmt::Display for ScanSourcesDisplay<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0.len() { + 0 => write!(f, "[]"), + 1 => write!(f, "[{}]", self.0.at(0)), + 2 => write!(f, "[{}, {}]", self.0.at(0), self.0.at(1)), + _ => write!( + f, + "[{}, ... {} other sources]", + self.0.at(0), + self.0.len() - 1, + ), + } + } +} + +impl fmt::Display for PathsDisplay<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0.len() { + 0 => write!(f, "[]"), + 1 => write!(f, "[{}]", self.0[0].display()), + 2 => write!(f, "[{}, {}]", self.0[0].display(), self.0[1].display()), + _ => write!( + f, + "[{}, ... {} other files]", + self.0[0].display(), + self.0.len() - 1, + ), + } + } +} + +impl fmt::Display for NumColumns<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + None => f.write_str("*"), + Some(columns) => columns.len().fmt(f), + } + } +} + +impl fmt::Display for NumColumnsSchema<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + None => f.write_str("*"), + Some(columns) => columns.len().fmt(f), + } + } +} + +/// Utility structure to write to a [`fmt::Formatter`] whilst escaping the output as a label name +pub struct EscapeLabel<'a>(pub &'a mut dyn fmt::Write); + +impl fmt::Write for EscapeLabel<'_> { + fn write_str(&mut self, mut s: &str) -> fmt::Result { + loop { + let mut char_indices = s.char_indices(); + + // This escapes quotes and new lines + // @NOTE: I am aware this does not work for \" and such. I am ignoring that fact as we + // are not really using such strings. + let f = char_indices.find_map(|(i, c)| match c { + '"' => Some((i, r#"\""#)), + '\n' => Some((i, r#"\n"#)), + _ => None, + }); + + let Some((at, to_write)) = f else { + break; + }; + + self.0.write_str(&s[..at])?; + self.0.write_str(to_write)?; + s = &s[at + 1..]; + } + + self.0.write_str(s)?; + + Ok(()) + } +} + +impl fmt::Display for IRDotDisplay<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "graph polars_query {{")?; + + let mut last = 0; + self._format(f, None, &mut last)?; + + writeln!(f, "}}")?; + + Ok(()) + } +} diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs new file mode 100644 index 000000000000..363271236c47 --- /dev/null +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -0,0 +1,801 @@ +use std::fmt::{self, Display, Formatter}; + +use polars_core::schema::Schema; +use polars_io::RowIndex; +use polars_utils::format_list_truncated; +use polars_utils::slice_enum::Slice; +use recursive::recursive; + +use self::ir::dot::ScanSourcesDisplay; +use crate::prelude::*; + +pub struct IRDisplay<'a> { + is_streaming: bool, + lp: IRPlanRef<'a>, +} + +#[derive(Clone, Copy)] +pub struct ExprIRDisplay<'a> { + pub(crate) node: Node, + pub(crate) output_name: &'a OutputName, + pub(crate) expr_arena: &'a Arena, +} + +impl<'a> ExprIRDisplay<'a> { + pub fn display_node(node: Node, expr_arena: &'a Arena) -> Self { + Self { + node, + output_name: &OutputName::None, + expr_arena, + } + } +} + +/// Utility structure to display several [`ExprIR`]'s in a nice way +pub(crate) struct ExprIRSliceDisplay<'a, T: AsExpr> { + pub(crate) exprs: &'a [T], + pub(crate) expr_arena: &'a Arena, +} + +pub(crate) trait AsExpr { + fn node(&self) -> Node; + fn output_name(&self) -> &OutputName; +} + +impl AsExpr for Node { + fn node(&self) -> Node { + *self + } + fn output_name(&self) -> &OutputName { + &OutputName::None + } +} + +impl AsExpr for ExprIR { + fn node(&self) -> Node { + self.node() + } + fn output_name(&self) -> &OutputName { + self.output_name_inner() + } +} + +#[allow(clippy::too_many_arguments)] +fn write_scan( + f: &mut Formatter, + name: &str, + sources: &ScanSources, + indent: usize, + n_columns: i64, + total_columns: usize, + predicate: &Option>, + pre_slice: Option, + row_index: Option<&RowIndex>, +) -> fmt::Result { + write!( + f, + "{:indent$}{name} SCAN {}", + "", + ScanSourcesDisplay(sources) + )?; + + let total_columns = total_columns - usize::from(row_index.is_some()); + if n_columns > 0 { + write!( + f, + "\n{:indent$}PROJECT {n_columns}/{total_columns} COLUMNS", + "", + )?; + } else { + write!(f, "\n{:indent$}PROJECT */{total_columns} COLUMNS", "")?; + } + if let Some(predicate) = predicate { + write!(f, "\n{:indent$}SELECTION: {predicate}", "")?; + } + if let Some(pre_slice) = pre_slice { + write!(f, "\n{:indent$}SLICE: {pre_slice:?}", "")?; + } + if let Some(row_index) = row_index { + write!(f, "\n{:indent$}ROW_INDEX: {}", "", row_index.name)?; + if row_index.offset != 0 { + write!(f, " (offset: {})", row_index.offset)?; + } + } + Ok(()) +} + +impl<'a> IRDisplay<'a> { + pub fn new(lp: IRPlanRef<'a>) -> Self { + if let Some(streaming_lp) = lp.extract_streaming_plan() { + return Self::new_streaming(streaming_lp); + } + + Self { + is_streaming: false, + lp, + } + } + + fn new_streaming(lp: IRPlanRef<'a>) -> Self { + Self { + is_streaming: true, + lp, + } + } + + fn root(&self) -> &IR { + self.lp.root() + } + + fn with_root(&self, root: Node) -> Self { + Self { + is_streaming: false, + lp: self.lp.with_root(root), + } + } + + fn display_expr(&self, root: &'a ExprIR) -> ExprIRDisplay<'a> { + ExprIRDisplay { + node: root.node(), + output_name: root.output_name_inner(), + expr_arena: self.lp.expr_arena, + } + } + + fn display_expr_slice(&self, exprs: &'a [ExprIR]) -> ExprIRSliceDisplay<'a, ExprIR> { + ExprIRSliceDisplay { + exprs, + expr_arena: self.lp.expr_arena, + } + } + + #[recursive] + fn _format(&self, f: &mut Formatter, indent: usize) -> fmt::Result { + let indent_increment = 2; + let indent = if self.is_streaming { + writeln!(f, "{:indent$}STREAMING:", "")?; + indent + indent_increment + } else { + if indent != 0 { + writeln!(f)?; + } + indent + }; + + let sub_indent = indent + indent_increment; + use IR::*; + + match self.root() { + #[cfg(feature = "python")] + PythonScan { options } => { + let total_columns = options.schema.len(); + let n_columns = options + .with_columns + .as_ref() + .map(|s| s.len() as i64) + .unwrap_or(-1); + + let predicate = match &options.predicate { + PythonPredicate::Polars(e) => Some(self.display_expr(e)), + PythonPredicate::PyArrow(_) => None, + PythonPredicate::None => None, + }; + + write_scan( + f, + "PYTHON", + &ScanSources::default(), + indent, + n_columns, + total_columns, + &predicate, + options + .n_rows + .map(|len| polars_utils::slice_enum::Slice::Positive { offset: 0, len }), + None, + ) + }, + Union { inputs, options } => { + let name = if let Some(slice) = options.slice { + format!("SLICED UNION: {slice:?}") + } else { + "UNION".to_string() + }; + + // 3 levels of indentation + // - 0 => UNION ... END UNION + // - 1 => PLAN 0, PLAN 1, ... PLAN N + // - 2 => actual formatting of plans + let sub_sub_indent = sub_indent + indent_increment; + write!(f, "{:indent$}{name}", "")?; + for (i, plan) in inputs.iter().enumerate() { + write!(f, "\n{:sub_indent$}PLAN {i}:", "")?; + self.with_root(*plan)._format(f, sub_sub_indent)?; + } + write!(f, "\n{:indent$}END {name}", "") + }, + HConcat { inputs, .. } => { + let sub_sub_indent = sub_indent + indent_increment; + write!(f, "{:indent$}HCONCAT", "")?; + for (i, plan) in inputs.iter().enumerate() { + write!(f, "\n{:sub_indent$}PLAN {i}:", "")?; + self.with_root(*plan)._format(f, sub_sub_indent)?; + } + write!(f, "\n{:indent$}END HCONCAT", "") + }, + Cache { + input, + id, + cache_hits, + } => { + write!( + f, + "{:indent$}CACHE[id: {:x}, cache_hits: {}]", + "", *id, *cache_hits + )?; + self.with_root(*input)._format(f, sub_indent) + }, + Scan { + sources, + file_info, + predicate, + scan_type, + unified_scan_args, + hive_parts: _, + output_schema: _, + } => { + let n_columns = unified_scan_args + .projection + .as_ref() + .map(|columns| columns.len() as i64) + .unwrap_or(-1); + + let predicate = predicate.as_ref().map(|p| self.display_expr(p)); + + write_scan( + f, + (&**scan_type).into(), + sources, + indent, + n_columns, + file_info.schema.len(), + &predicate, + unified_scan_args.pre_slice.clone(), + unified_scan_args.row_index.as_ref(), + ) + }, + Filter { predicate, input } => { + let predicate = self.display_expr(predicate); + // this one is writeln because we don't increase indent (which inserts a line) + write!(f, "{:indent$}FILTER {predicate}", "")?; + write!(f, "\n{:indent$}FROM", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + DataFrameScan { + schema, + output_schema, + .. + } => { + let total_columns = schema.len(); + let (n_columns, projected) = if let Some(schema) = output_schema { + ( + format!("{}", schema.len()), + format_list_truncated!(schema.iter_names(), 4, '"'), + ) + } else { + ("*".to_string(), "".to_string()) + }; + write!( + f, + "{:indent$}DF {}; PROJECT{} {}/{} COLUMNS", + "", + format_list_truncated!(schema.iter_names(), 4, '"'), + projected, + n_columns, + total_columns, + ) + }, + Select { expr, input, .. } => { + // @NOTE: Maybe there should be a clear delimiter here? + let exprs = self.display_expr_slice(expr); + write!(f, "{:indent$}SELECT {exprs}", "")?; + write!(f, "\n{:indent$}FROM", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + Sort { + input, by_column, .. + } => { + let by_column = self.display_expr_slice(by_column); + write!(f, "{:indent$}SORT BY {by_column}", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + GroupBy { + input, + keys, + aggs, + apply, + .. + } => { + let keys = self.display_expr_slice(keys); + write!(f, "{:indent$}AGGREGATE", "")?; + if apply.is_some() { + write!(f, "\n{:sub_indent$}MAP_GROUPS BY {keys}", "")?; + write!(f, "\n{:sub_indent$}FROM", "")?; + } else { + let aggs = self.display_expr_slice(aggs); + write!(f, "\n{:sub_indent$}{aggs} BY {keys}", "")?; + write!(f, "\n{:sub_indent$}FROM", "")?; + } + self.with_root(*input)._format(f, sub_indent) + }, + Join { + input_left, + input_right, + left_on, + right_on, + options, + .. + } => { + let left_on = self.display_expr_slice(left_on); + let right_on = self.display_expr_slice(right_on); + + // Fused cross + filter (show as nested loop join) + if let Some(JoinTypeOptionsIR::Cross { predicate }) = &options.options { + let predicate = self.display_expr(predicate); + let name = "NESTED LOOP"; + write!(f, "{:indent$}{name} JOIN ON {predicate}:", "")?; + write!(f, "\n{:indent$}LEFT PLAN:", "")?; + self.with_root(*input_left)._format(f, sub_indent)?; + write!(f, "\n{:indent$}RIGHT PLAN:", "")?; + self.with_root(*input_right)._format(f, sub_indent)?; + write!(f, "\n{:indent$}END {name} JOIN", "") + } else { + let how = &options.args.how; + write!(f, "{:indent$}{how} JOIN:", "")?; + write!(f, "\n{:indent$}LEFT PLAN ON: {left_on}", "")?; + self.with_root(*input_left)._format(f, sub_indent)?; + write!(f, "\n{:indent$}RIGHT PLAN ON: {right_on}", "")?; + self.with_root(*input_right)._format(f, sub_indent)?; + write!(f, "\n{:indent$}END {how} JOIN", "") + } + }, + HStack { input, exprs, .. } => { + // @NOTE: Maybe there should be a clear delimiter here? + let exprs = self.display_expr_slice(exprs); + + write!(f, "{:indent$} WITH_COLUMNS:", "",)?; + write!(f, "\n{:indent$} {exprs} ", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + Distinct { input, options } => { + write!( + f, + "{:indent$}UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}", + "", options.maintain_order, options.keep_strategy, options.subset + )?; + self.with_root(*input)._format(f, sub_indent) + }, + Slice { input, offset, len } => { + write!(f, "{:indent$}SLICE[offset: {offset}, len: {len}]", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + MapFunction { + input, function, .. + } => { + if let Some(streaming_lp) = function.to_streaming_lp() { + IRDisplay::new_streaming(streaming_lp)._format(f, indent) + } else { + write!(f, "{:indent$}{function}", "")?; + self.with_root(*input)._format(f, sub_indent) + } + }, + ExtContext { input, .. } => { + write!(f, "{:indent$}EXTERNAL_CONTEXT", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + Sink { input, payload, .. } => { + let name = match payload { + SinkTypeIR::Memory => "SINK (memory)", + SinkTypeIR::File { .. } => "SINK (file)", + SinkTypeIR::Partition { .. } => "SINK (partition)", + }; + write!(f, "{:indent$}{name}", "")?; + self.with_root(*input)._format(f, sub_indent) + }, + SinkMultiple { inputs } => { + // 3 levels of indentation + // - 0 => SINK_MULTIPLE ... END SINK_MULTIPLE + // - 1 => PLAN 0, PLAN 1, ... PLAN N + // - 2 => actual formatting of plans + let sub_sub_indent = sub_indent + 2; + write!(f, "{:indent$}SINK_MULTIPLE", "")?; + for (i, plan) in inputs.iter().enumerate() { + write!(f, "\n{:sub_indent$}PLAN {i}:", "")?; + self.with_root(*plan)._format(f, sub_sub_indent)?; + } + write!(f, "\n{:indent$}END SINK_MULTIPLE", "") + }, + SimpleProjection { input, columns } => { + let num_columns = columns.as_ref().len(); + let total_columns = self.lp.lp_arena.get(*input).schema(self.lp.lp_arena).len(); + + let columns = ColumnsDisplay(columns.as_ref()); + write!( + f, + "{:indent$}simple π {num_columns}/{total_columns} [{columns}]", + "" + )?; + + self.with_root(*input)._format(f, sub_indent) + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + key, + } => { + write!(f, "{:indent$}MERGE SORTED ON '{key}':", "")?; + write!(f, "\n{:indent$}LEFT PLAN:", "")?; + self.with_root(*input_left)._format(f, sub_indent)?; + write!(f, "\n{:indent$}RIGHT PLAN:", "")?; + self.with_root(*input_right)._format(f, sub_indent)?; + write!(f, "\n{:indent$}END MERGE_SORTED", "") + }, + Invalid => write!(f, "{:indent$}INVALID", ""), + } + } +} + +impl<'a> ExprIRDisplay<'a> { + fn with_slice(&self, exprs: &'a [T]) -> ExprIRSliceDisplay<'a, T> { + ExprIRSliceDisplay { + exprs, + expr_arena: self.expr_arena, + } + } + + fn with_root(&self, root: &'a T) -> Self { + Self { + node: root.node(), + output_name: root.output_name(), + expr_arena: self.expr_arena, + } + } +} + +impl Display for IRDisplay<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self._format(f, 0) + } +} + +impl fmt::Debug for IRDisplay<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self, f) + } +} + +impl Display for ExprIRSliceDisplay<'_, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // Display items in slice delimited by a comma + + use std::fmt::Write; + + let mut iter = self.exprs.iter(); + + f.write_char('[')?; + if let Some(fst) = iter.next() { + let fst = ExprIRDisplay { + node: fst.node(), + output_name: fst.output_name(), + expr_arena: self.expr_arena, + }; + write!(f, "{fst}")?; + } + + for expr in iter { + let expr = ExprIRDisplay { + node: expr.node(), + output_name: expr.output_name(), + expr_arena: self.expr_arena, + }; + write!(f, ", {expr}")?; + } + + f.write_char(']')?; + + Ok(()) + } +} + +impl fmt::Debug for ExprIRSliceDisplay<'_, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(self, f) + } +} + +impl Display for ExprIRDisplay<'_> { + #[recursive] + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let root = self.expr_arena.get(self.node); + + use AExpr::*; + match root { + Window { + function, + partition_by, + order_by, + options, + } => { + let function = self.with_root(function); + let partition_by = self.with_slice(partition_by); + match options { + #[cfg(feature = "dynamic_group_by")] + WindowType::Rolling(options) => { + write!( + f, + "{function}.rolling(by='{}', offset={}, period={})", + options.index_column, options.offset, options.period + ) + }, + _ => { + if let Some((order_by, _)) = order_by { + let order_by = self.with_root(order_by); + write!( + f, + "{function}.over(partition_by: {partition_by}, order_by: {order_by})" + ) + } else { + write!(f, "{function}.over({partition_by})") + } + }, + } + }, + Len => write!(f, "len()"), + Explode(expr) => { + let expr = self.with_root(expr); + write!(f, "{expr}.explode()") + }, + Alias(expr, name) => { + let expr = self.with_root(expr); + write!(f, "{expr}.alias(\"{name}\")") + }, + Column(name) => write!(f, "col(\"{name}\")"), + Literal(v) => write!(f, "{v:?}"), + BinaryExpr { left, op, right } => { + let left = self.with_root(left); + let right = self.with_root(right); + write!(f, "[({left}) {op:?} ({right})]") + }, + Sort { expr, options } => { + let expr = self.with_root(expr); + if options.descending { + write!(f, "{expr}.sort(desc)") + } else { + write!(f, "{expr}.sort(asc)") + } + }, + SortBy { + expr, + by, + sort_options, + } => { + let expr = self.with_root(expr); + let by = self.with_slice(by); + write!(f, "{expr}.sort_by(by={by}, sort_option={sort_options:?})",) + }, + Filter { input, by } => { + let input = self.with_root(input); + let by = self.with_root(by); + + write!(f, "{input}.filter({by})") + }, + Gather { + expr, + idx, + returns_scalar, + } => { + let expr = self.with_root(expr); + let idx = self.with_root(idx); + expr.fmt(f)?; + + if *returns_scalar { + write!(f, ".get({idx})") + } else { + write!(f, ".gather({idx})") + } + }, + Agg(agg) => { + use IRAggExpr::*; + match agg { + Min { + input, + propagate_nans, + } => { + self.with_root(input).fmt(f)?; + if *propagate_nans { + write!(f, ".nan_min()") + } else { + write!(f, ".min()") + } + }, + Max { + input, + propagate_nans, + } => { + self.with_root(input).fmt(f)?; + if *propagate_nans { + write!(f, ".nan_max()") + } else { + write!(f, ".max()") + } + }, + Median(expr) => write!(f, "{}.median()", self.with_root(expr)), + Mean(expr) => write!(f, "{}.mean()", self.with_root(expr)), + First(expr) => write!(f, "{}.first()", self.with_root(expr)), + Last(expr) => write!(f, "{}.last()", self.with_root(expr)), + Implode(expr) => write!(f, "{}.implode()", self.with_root(expr)), + NUnique(expr) => write!(f, "{}.n_unique()", self.with_root(expr)), + Sum(expr) => write!(f, "{}.sum()", self.with_root(expr)), + AggGroups(expr) => write!(f, "{}.groups()", self.with_root(expr)), + Count(expr, _) => write!(f, "{}.count()", self.with_root(expr)), + Var(expr, _) => write!(f, "{}.var()", self.with_root(expr)), + Std(expr, _) => write!(f, "{}.std()", self.with_root(expr)), + Quantile { expr, .. } => write!(f, "{}.quantile()", self.with_root(expr)), + } + }, + Cast { + expr, + dtype, + options, + } => { + self.with_root(expr).fmt(f)?; + if options.strict() { + write!(f, ".strict_cast({dtype:?})") + } else { + write!(f, ".cast({dtype:?})") + } + }, + Ternary { + predicate, + truthy, + falsy, + } => { + let predicate = self.with_root(predicate); + let truthy = self.with_root(truthy); + let falsy = self.with_root(falsy); + write!(f, "when({predicate}).then({truthy}).otherwise({falsy})",) + }, + Function { + input, function, .. + } => { + let fst = self.with_root(&input[0]); + fst.fmt(f)?; + if input.len() >= 2 { + write!(f, ".{function}({})", self.with_slice(&input[1..])) + } else { + write!(f, ".{function}()") + } + }, + AnonymousFunction { input, options, .. } => { + let fst = self.with_root(&input[0]); + fst.fmt(f)?; + if input.len() >= 2 { + write!(f, ".{}({})", options.fmt_str, self.with_slice(&input[1..])) + } else { + write!(f, ".{}()", options.fmt_str) + } + }, + Slice { + input, + offset, + length, + } => { + let input = self.with_root(input); + let offset = self.with_root(offset); + let length = self.with_root(length); + + write!(f, "{input}.slice(offset={offset}, length={length})") + }, + }?; + + match self.output_name { + OutputName::None => {}, + OutputName::LiteralLhs(_) => {}, + OutputName::ColumnLhs(_) => {}, + #[cfg(feature = "dtype-struct")] + OutputName::Field(_) => {}, + OutputName::Alias(name) => write!(f, r#".alias("{name}")"#)?, + } + + Ok(()) + } +} + +impl fmt::Debug for ExprIRDisplay<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(self, f) + } +} + +pub(crate) struct ColumnsDisplay<'a>(pub(crate) &'a Schema); + +impl fmt::Display for ColumnsDisplay<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let len = self.0.len(); + let mut iter_names = self.0.iter_names().enumerate(); + + const MAX_LEN: usize = 32; + const ADD_PER_ITEM: usize = 4; + + let mut current_len = 0; + + if let Some((_, fst)) = iter_names.next() { + write!(f, "\"{fst}\"")?; + + current_len += fst.len() + ADD_PER_ITEM; + } + + for (i, col) in iter_names { + current_len += col.len() + ADD_PER_ITEM; + + if current_len > MAX_LEN { + write!(f, ", ... {} other ", len - i)?; + if len - i == 1 { + f.write_str("column")?; + } else { + f.write_str("columns")?; + } + + break; + } + + write!(f, ", \"{col}\"")?; + } + + Ok(()) + } +} + +impl fmt::Debug for Operator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(self, f) + } +} + +impl fmt::Debug for LiteralValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use LiteralValue::*; + + match self { + Self::Scalar(sc) => write!(f, "{}", sc.value()), + Self::Series(s) => { + let name = s.name(); + if name.is_empty() { + write!(f, "Series") + } else { + write!(f, "Series[{name}]") + } + }, + Range(range) => fmt::Debug::fmt(range, f), + Dyn(d) => fmt::Debug::fmt(d, f), + } + } +} + +impl fmt::Debug for DynLiteralValue { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Int(v) => write!(f, "dyn int: {v}"), + Self::Float(v) => write!(f, "dyn float: {}", v), + Self::Str(v) => write!(f, "dyn str: {v}"), + Self::List(_) => todo!(), + } + } +} + +impl fmt::Debug for RangeLiteralValue { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "range({}, {})", self.low, self.high) + } +} diff --git a/crates/polars-plan/src/plans/ir/inputs.rs b/crates/polars-plan/src/plans/ir/inputs.rs new file mode 100644 index 000000000000..174658e9ed31 --- /dev/null +++ b/crates/polars-plan/src/plans/ir/inputs.rs @@ -0,0 +1,294 @@ +use super::*; + +impl IR { + /// Takes the expressions of an LP node and the inputs of that node and reconstruct + pub fn with_exprs_and_input(&self, mut exprs: Vec, mut inputs: Vec) -> IR { + use IR::*; + + match self { + #[cfg(feature = "python")] + PythonScan { options } => PythonScan { + options: options.clone(), + }, + Union { options, .. } => Union { + inputs, + options: *options, + }, + HConcat { + schema, options, .. + } => HConcat { + inputs, + schema: schema.clone(), + options: *options, + }, + Slice { offset, len, .. } => Slice { + input: inputs[0], + offset: *offset, + len: *len, + }, + Filter { .. } => Filter { + input: inputs[0], + predicate: exprs.pop().unwrap(), + }, + Select { + schema, options, .. + } => Select { + input: inputs[0], + expr: exprs, + schema: schema.clone(), + options: *options, + }, + GroupBy { + keys, + schema, + apply, + maintain_order, + options: dynamic_options, + .. + } => GroupBy { + input: inputs[0], + keys: exprs[..keys.len()].to_vec(), + aggs: exprs[keys.len()..].to_vec(), + schema: schema.clone(), + apply: apply.clone(), + maintain_order: *maintain_order, + options: dynamic_options.clone(), + }, + Join { + schema, + left_on, + options, + .. + } => Join { + input_left: inputs[0], + input_right: inputs[1], + schema: schema.clone(), + left_on: exprs[..left_on.len()].to_vec(), + right_on: exprs[left_on.len()..].to_vec(), + options: options.clone(), + }, + Sort { + by_column, + slice, + sort_options, + .. + } => Sort { + input: inputs[0], + by_column: by_column.clone(), + slice: *slice, + sort_options: sort_options.clone(), + }, + Cache { id, cache_hits, .. } => Cache { + input: inputs[0], + id: *id, + cache_hits: *cache_hits, + }, + Distinct { options, .. } => Distinct { + input: inputs[0], + options: options.clone(), + }, + HStack { + schema, options, .. + } => HStack { + input: inputs[0], + exprs, + schema: schema.clone(), + options: *options, + }, + Scan { + sources, + file_info, + hive_parts, + output_schema, + predicate, + unified_scan_args, + scan_type, + } => { + let mut new_predicate = None; + if predicate.is_some() { + new_predicate = exprs.pop() + } + + Scan { + sources: sources.clone(), + file_info: file_info.clone(), + hive_parts: hive_parts.clone(), + output_schema: output_schema.clone(), + unified_scan_args: unified_scan_args.clone(), + predicate: new_predicate, + scan_type: scan_type.clone(), + } + }, + DataFrameScan { + df, + schema, + output_schema, + } => DataFrameScan { + df: df.clone(), + schema: schema.clone(), + output_schema: output_schema.clone(), + }, + MapFunction { function, .. } => MapFunction { + input: inputs[0], + function: function.clone(), + }, + ExtContext { schema, .. } => ExtContext { + input: inputs.pop().unwrap(), + contexts: inputs, + schema: schema.clone(), + }, + Sink { payload, .. } => Sink { + input: inputs.pop().unwrap(), + payload: payload.clone(), + }, + SinkMultiple { .. } => SinkMultiple { inputs }, + SimpleProjection { columns, .. } => SimpleProjection { + input: inputs.pop().unwrap(), + columns: columns.clone(), + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left: _, + input_right: _, + key, + } => MergeSorted { + input_left: inputs[0], + input_right: inputs[1], + key: key.clone(), + }, + Invalid => unreachable!(), + } + } + + /// Copy the exprs in this LP node to an existing container. + pub fn copy_exprs(&self, container: &mut Vec) { + use IR::*; + match self { + Slice { .. } + | Cache { .. } + | Distinct { .. } + | Union { .. } + | MapFunction { .. } + | SinkMultiple { .. } => {}, + Sort { by_column, .. } => container.extend_from_slice(by_column), + Filter { predicate, .. } => container.push(predicate.clone()), + Select { expr, .. } => container.extend_from_slice(expr), + GroupBy { keys, aggs, .. } => { + let iter = keys.iter().cloned().chain(aggs.iter().cloned()); + container.extend(iter) + }, + Join { + left_on, right_on, .. + } => { + let iter = left_on.iter().cloned().chain(right_on.iter().cloned()); + container.extend(iter) + }, + HStack { exprs, .. } => container.extend_from_slice(exprs), + Scan { predicate, .. } => { + if let Some(pred) = predicate { + container.push(pred.clone()) + } + }, + DataFrameScan { .. } => {}, + #[cfg(feature = "python")] + PythonScan { .. } => {}, + Sink { payload, .. } => { + if let SinkTypeIR::Partition(p) = payload { + match &p.variant { + PartitionVariantIR::Parted { key_exprs, .. } + | PartitionVariantIR::ByKey { key_exprs, .. } => { + container.extend_from_slice(key_exprs) + }, + _ => (), + } + } + }, + HConcat { .. } => {}, + ExtContext { .. } | SimpleProjection { .. } => {}, + #[cfg(feature = "merge_sorted")] + MergeSorted { .. } => {}, + Invalid => unreachable!(), + } + } + + /// Get expressions in this node. + pub fn get_exprs(&self) -> Vec { + let mut exprs = Vec::new(); + self.copy_exprs(&mut exprs); + exprs + } + + /// Push inputs of the LP in of this node to an existing container. + /// Most plans have typically one input. A join has two and a scan (CsvScan) + /// or an in-memory DataFrame has none. A Union has multiple. + pub fn copy_inputs(&self, container: &mut T) + where + T: Extend, + { + use IR::*; + let input = match self { + Union { inputs, .. } | HConcat { inputs, .. } | SinkMultiple { inputs } => { + container.extend(inputs.iter().cloned()); + return; + }, + Slice { input, .. } => *input, + Filter { input, .. } => *input, + Select { input, .. } => *input, + SimpleProjection { input, .. } => *input, + Sort { input, .. } => *input, + Cache { input, .. } => *input, + GroupBy { input, .. } => *input, + Join { + input_left, + input_right, + .. + } => { + container.extend([*input_left, *input_right]); + return; + }, + HStack { input, .. } => *input, + Distinct { input, .. } => *input, + MapFunction { input, .. } => *input, + Sink { input, .. } => *input, + ExtContext { + input, contexts, .. + } => { + container.extend(contexts.iter().cloned()); + *input + }, + Scan { .. } => return, + DataFrameScan { .. } => return, + #[cfg(feature = "python")] + PythonScan { .. } => return, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + .. + } => { + container.extend([*input_left, *input_right]); + return; + }, + Invalid => unreachable!(), + }; + container.extend([input]) + } + + pub fn get_inputs(&self) -> UnitVec { + let mut inputs: UnitVec = unitvec!(); + self.copy_inputs(&mut inputs); + inputs + } + + pub fn get_inputs_vec(&self) -> Vec { + let mut inputs = vec![]; + self.copy_inputs(&mut inputs); + inputs + } + + pub(crate) fn get_input(&self) -> Option { + let mut inputs: UnitVec = unitvec!(); + self.copy_inputs(&mut inputs); + inputs.first().copied() + } +} diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs new file mode 100644 index 000000000000..3adf1b671265 --- /dev/null +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -0,0 +1,279 @@ +mod dot; +mod format; +mod inputs; +mod schema; +pub(crate) mod tree_format; + +use std::borrow::Cow; +use std::fmt; + +pub use dot::{EscapeLabel, IRDotDisplay, PathsDisplay, ScanSourcesDisplay}; +pub use format::{ExprIRDisplay, IRDisplay}; +use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; +use strum_macros::IntoStaticStr; + +use self::hive::HivePartitionsDf; +use crate::prelude::*; + +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub struct IRPlan { + pub lp_top: Node, + pub lp_arena: Arena, + pub expr_arena: Arena, +} + +#[derive(Clone, Copy)] +pub struct IRPlanRef<'a> { + pub lp_top: Node, + pub lp_arena: &'a Arena, + pub expr_arena: &'a Arena, +} + +/// [`IR`] is a representation of [`DslPlan`] with [`Node`]s which are allocated in an [`Arena`] +/// In this IR the logical plan has access to the full dataset. +#[derive(Clone, Debug, Default, IntoStaticStr)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "SCREAMING_SNAKE_CASE")] +pub enum IR { + #[cfg(feature = "python")] + PythonScan { + options: PythonOptions, + }, + Slice { + input: Node, + offset: i64, + len: IdxSize, + }, + Filter { + input: Node, + predicate: ExprIR, + }, + Scan { + sources: ScanSources, + file_info: FileInfo, + hive_parts: Option, + predicate: Option, + /// schema of the projected file + output_schema: Option, + scan_type: Box, + /// generic options that can be used for all file types. + unified_scan_args: Box, + }, + DataFrameScan { + df: Arc, + schema: SchemaRef, + // Schema of the projected file + // If `None`, no projection is applied + output_schema: Option, + }, + // Only selects columns (semantically only has row access). + // This is a more restricted operation than `Select`. + SimpleProjection { + input: Node, + columns: SchemaRef, + }, + // Polars' `select` operation. This may access full materialized data. + Select { + input: Node, + expr: Vec, + schema: SchemaRef, + options: ProjectionOptions, + }, + Sort { + input: Node, + by_column: Vec, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + }, + Cache { + input: Node, + // Unique ID. + id: usize, + /// How many hits the cache must be saved in memory. + cache_hits: u32, + }, + GroupBy { + input: Node, + keys: Vec, + aggs: Vec, + schema: SchemaRef, + maintain_order: bool, + options: Arc, + #[cfg_attr(feature = "ir_serde", serde(skip))] + apply: Option>, + }, + Join { + input_left: Node, + input_right: Node, + schema: SchemaRef, + left_on: Vec, + right_on: Vec, + options: Arc, + }, + HStack { + input: Node, + exprs: Vec, + schema: SchemaRef, + options: ProjectionOptions, + }, + Distinct { + input: Node, + options: DistinctOptionsIR, + }, + MapFunction { + input: Node, + function: FunctionIR, + }, + Union { + inputs: Vec, + options: UnionOptions, + }, + /// Horizontal concatenation + /// - Invariant: the names will be unique + HConcat { + inputs: Vec, + schema: SchemaRef, + options: HConcatOptions, + }, + ExtContext { + input: Node, + contexts: Vec, + schema: SchemaRef, + }, + Sink { + input: Node, + payload: SinkTypeIR, + }, + /// Node that allows for multiple plans to be executed in parallel with common subplan + /// elimination and everything. + SinkMultiple { + inputs: Vec, + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left: Node, + input_right: Node, + key: PlSmallStr, + }, + #[default] + Invalid, +} + +impl IRPlan { + pub fn new(top: Node, ir_arena: Arena, expr_arena: Arena) -> Self { + Self { + lp_top: top, + lp_arena: ir_arena, + expr_arena, + } + } + + pub fn root(&self) -> &IR { + self.lp_arena.get(self.lp_top) + } + + pub fn as_ref(&self) -> IRPlanRef { + IRPlanRef { + lp_top: self.lp_top, + lp_arena: &self.lp_arena, + expr_arena: &self.expr_arena, + } + } + + /// Extract the original logical plan if the plan is for the Streaming Engine + pub fn extract_streaming_plan(&self) -> Option { + self.as_ref().extract_streaming_plan() + } + + pub fn describe(&self) -> String { + self.as_ref().describe() + } + + pub fn describe_tree_format(&self) -> String { + self.as_ref().describe_tree_format() + } + + pub fn display(&self) -> format::IRDisplay { + self.as_ref().display() + } + + pub fn display_dot(&self) -> dot::IRDotDisplay { + self.as_ref().display_dot() + } +} + +impl<'a> IRPlanRef<'a> { + pub fn root(self) -> &'a IR { + self.lp_arena.get(self.lp_top) + } + + pub fn with_root(self, root: Node) -> Self { + Self { + lp_top: root, + lp_arena: self.lp_arena, + expr_arena: self.expr_arena, + } + } + + /// Extract the original logical plan if the plan is for the Streaming Engine + pub fn extract_streaming_plan(self) -> Option> { + // @NOTE: the streaming engine replaces the whole tree with a MapFunction { Pipeline, .. } + // and puts the original plan somewhere in there. This is how we extract it. Disgusting, I + // know. + let IR::MapFunction { input: _, function } = self.root() else { + return None; + }; + + let FunctionIR::Pipeline { original, .. } = function else { + return None; + }; + + Some(original.as_ref()?.as_ref().as_ref()) + } + + pub fn display(self) -> format::IRDisplay<'a> { + format::IRDisplay::new(self) + } + + pub fn display_dot(self) -> dot::IRDotDisplay<'a> { + dot::IRDotDisplay::new(self) + } + + pub fn describe(self) -> String { + self.display().to_string() + } + + pub fn describe_tree_format(self) -> String { + let mut visitor = tree_format::TreeFmtVisitor::default(); + tree_format::TreeFmtNode::root_logical_plan(self).traverse(&mut visitor); + format!("{visitor:#?}") + } +} + +impl fmt::Debug for IRPlan { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(&self.display(), f) + } +} + +impl fmt::Debug for IRPlanRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(&self.display(), f) + } +} + +#[cfg(test)] +mod test { + use super::*; + + // skipped for now + #[ignore] + #[test] + fn test_alp_size() { + assert!(size_of::() <= 152); + } +} diff --git a/crates/polars-plan/src/plans/ir/schema.rs b/crates/polars-plan/src/plans/ir/schema.rs new file mode 100644 index 000000000000..e2591036b51d --- /dev/null +++ b/crates/polars-plan/src/plans/ir/schema.rs @@ -0,0 +1,178 @@ +use recursive::recursive; + +use super::*; + +impl IR { + /// Get the schema of the logical plan node but don't take projections into account at the scan + /// level. This ensures we can apply the predicate + pub(crate) fn scan_schema(&self) -> &SchemaRef { + use IR::*; + match self { + Scan { file_info, .. } => &file_info.schema, + #[cfg(feature = "python")] + PythonScan { options, .. } => &options.schema, + _ => unreachable!(), + } + } + + pub fn name(&self) -> &'static str { + use IR::*; + match self { + Scan { scan_type, .. } => (&**scan_type).into(), + #[cfg(feature = "python")] + PythonScan { .. } => "python_scan", + Slice { .. } => "slice", + Filter { .. } => "selection", + DataFrameScan { .. } => "df", + Select { .. } => "projection", + Sort { .. } => "sort", + Cache { .. } => "cache", + GroupBy { .. } => "aggregate", + Join { .. } => "join", + HStack { .. } => "hstack", + Distinct { .. } => "distinct", + MapFunction { .. } => "map_function", + Union { .. } => "union", + HConcat { .. } => "hconcat", + ExtContext { .. } => "ext_context", + Sink { payload, .. } => match payload { + SinkTypeIR::Memory => "sink (memory)", + SinkTypeIR::File { .. } => "sink (file)", + SinkTypeIR::Partition { .. } => "sink (partition)", + }, + SinkMultiple { .. } => "sink multiple", + SimpleProjection { .. } => "simple_projection", + #[cfg(feature = "merge_sorted")] + MergeSorted { .. } => "merge_sorted", + Invalid => "invalid", + } + } + + pub fn input_schema<'a>(&'a self, arena: &'a Arena) -> Option> { + use IR::*; + let schema = match self { + #[cfg(feature = "python")] + PythonScan { options } => &options.schema, + DataFrameScan { schema, .. } => schema, + Scan { file_info, .. } => &file_info.schema, + node => { + let input = node.get_input()?; + return Some(arena.get(input).schema(arena)); + }, + }; + Some(Cow::Borrowed(schema)) + } + + /// Get the schema of the logical plan node. + #[recursive] + pub fn schema<'a>(&'a self, arena: &'a Arena) -> Cow<'a, SchemaRef> { + use IR::*; + let schema = match self { + #[cfg(feature = "python")] + PythonScan { options } => options.output_schema.as_ref().unwrap_or(&options.schema), + Union { inputs, .. } => return arena.get(inputs[0]).schema(arena), + HConcat { schema, .. } => schema, + Cache { input, .. } => return arena.get(*input).schema(arena), + Sort { input, .. } => return arena.get(*input).schema(arena), + Scan { + output_schema, + file_info, + .. + } => output_schema.as_ref().unwrap_or(&file_info.schema), + DataFrameScan { + schema, + output_schema, + .. + } => output_schema.as_ref().unwrap_or(schema), + Filter { input, .. } => return arena.get(*input).schema(arena), + Select { schema, .. } => schema, + SimpleProjection { columns, .. } => columns, + GroupBy { schema, .. } => schema, + Join { schema, .. } => schema, + HStack { schema, .. } => schema, + Distinct { input, .. } + | Sink { + input, + payload: SinkTypeIR::Memory, + } => return arena.get(*input).schema(arena), + Sink { .. } | SinkMultiple { .. } => return Cow::Owned(Arc::new(Schema::default())), + Slice { input, .. } => return arena.get(*input).schema(arena), + MapFunction { input, function } => { + let input_schema = arena.get(*input).schema(arena); + + return match input_schema { + Cow::Owned(schema) => { + Cow::Owned(function.schema(&schema).unwrap().into_owned()) + }, + Cow::Borrowed(schema) => function.schema(schema).unwrap(), + }; + }, + ExtContext { schema, .. } => schema, + #[cfg(feature = "merge_sorted")] + MergeSorted { input_left, .. } => return arena.get(*input_left).schema(arena), + Invalid => unreachable!(), + }; + Cow::Borrowed(schema) + } + + /// Get the schema of the logical plan node, using caching. + #[recursive] + pub fn schema_with_cache<'a>( + node: Node, + arena: &'a Arena, + cache: &mut PlHashMap>, + ) -> Arc { + use IR::*; + if let Some(schema) = cache.get(&node) { + return schema.clone(); + } + + let schema = match arena.get(node) { + #[cfg(feature = "python")] + PythonScan { options } => options + .output_schema + .as_ref() + .unwrap_or(&options.schema) + .clone(), + Union { inputs, .. } => IR::schema_with_cache(inputs[0], arena, cache), + HConcat { schema, .. } => schema.clone(), + Cache { input, .. } + | Sort { input, .. } + | Filter { input, .. } + | Distinct { input, .. } + | Sink { + input, + payload: SinkTypeIR::Memory, + } + | Slice { input, .. } => IR::schema_with_cache(*input, arena, cache), + Sink { .. } | SinkMultiple { .. } => Arc::new(Schema::default()), + Scan { + output_schema, + file_info, + .. + } => output_schema.as_ref().unwrap_or(&file_info.schema).clone(), + DataFrameScan { + schema, + output_schema, + .. + } => output_schema.as_ref().unwrap_or(schema).clone(), + Select { schema, .. } + | GroupBy { schema, .. } + | Join { schema, .. } + | HStack { schema, .. } + | ExtContext { schema, .. } + | SimpleProjection { + columns: schema, .. + } => schema.clone(), + MapFunction { input, function } => { + let input_schema = IR::schema_with_cache(*input, arena, cache); + function.schema(&input_schema).unwrap().into_owned() + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { input_left, .. } => IR::schema_with_cache(*input_left, arena, cache), + Invalid => unreachable!(), + }; + cache.insert(node, schema.clone()); + schema + } +} diff --git a/crates/polars-plan/src/plans/ir/tree_format.rs b/crates/polars-plan/src/plans/ir/tree_format.rs new file mode 100644 index 000000000000..493c0c934236 --- /dev/null +++ b/crates/polars-plan/src/plans/ir/tree_format.rs @@ -0,0 +1,953 @@ +use std::fmt; + +use polars_core::error::*; +use polars_utils::{format_list_container_truncated, format_list_truncated}; + +use crate::constants; +use crate::plans::ir::IRPlanRef; +use crate::plans::visitor::{VisitRecursion, Visitor}; +use crate::prelude::ir::format::ColumnsDisplay; +use crate::prelude::visitor::AexprNode; +use crate::prelude::*; + +pub struct TreeFmtNode<'a> { + h: Option, + content: TreeFmtNodeContent<'a>, + + lp: IRPlanRef<'a>, +} + +pub struct TreeFmtAExpr<'a>(&'a AExpr); + +/// Hack UpperExpr trait to get a kind of formatting that doesn't traverse the nodes. +/// So we can format with {foo:E} +impl fmt::Display for TreeFmtAExpr<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self.0 { + AExpr::Explode(_) => "explode", + AExpr::Alias(_, name) => return write!(f, "alias({})", name), + AExpr::Column(name) => return write!(f, "col({})", name), + AExpr::Literal(lv) => return write!(f, "lit({lv:?})"), + AExpr::BinaryExpr { op, .. } => return write!(f, "binary: {}", op), + AExpr::Cast { dtype, options, .. } => { + return if options.strict() { + write!(f, "strict cast({})", dtype) + } else { + write!(f, "cast({})", dtype) + }; + }, + AExpr::Sort { options, .. } => { + return write!( + f, + "sort: {}{}{}", + options.descending as u8, options.nulls_last as u8, options.multithreaded as u8 + ); + }, + AExpr::Gather { .. } => "gather", + AExpr::SortBy { sort_options, .. } => { + write!(f, "sort_by:")?; + for i in &sort_options.descending { + write!(f, "{}", *i as u8)?; + } + for i in &sort_options.nulls_last { + write!(f, "{}", *i as u8)?; + } + write!(f, "{}", sort_options.multithreaded as u8)?; + return Ok(()); + }, + AExpr::Filter { .. } => "filter", + AExpr::Agg(a) => { + let s: &str = a.into(); + return write!(f, "{}", s.to_lowercase()); + }, + AExpr::Ternary { .. } => "ternary", + AExpr::AnonymousFunction { options, .. } => { + return write!(f, "anonymous_function: {}", options.fmt_str); + }, + AExpr::Function { function, .. } => return write!(f, "function: {function}"), + AExpr::Window { .. } => "window", + AExpr::Slice { .. } => "slice", + AExpr::Len => constants::LEN, + }; + + write!(f, "{s}") + } +} + +pub enum TreeFmtNodeContent<'a> { + Expression(&'a ExprIR), + LogicalPlan(Node), +} + +struct TreeFmtNodeData<'a>(String, Vec>); + +fn with_header(header: &Option, text: &str) -> String { + if let Some(header) = header { + format!("{header}\n{text}") + } else { + text.to_string() + } +} + +#[cfg(feature = "regex")] +fn multiline_expression(expr: &str) -> std::borrow::Cow<'_, str> { + polars_utils::regex_cache::cached_regex! { + static RE = r"([\)\]])(\.[a-z0-9]+\()"; + } + RE.replace_all(expr, "$1\n $2") +} + +impl<'a> TreeFmtNode<'a> { + pub fn root_logical_plan(lp: IRPlanRef<'a>) -> Self { + if let Some(streaming_lp) = lp.extract_streaming_plan() { + return Self::streaming_root_logical_plan(streaming_lp); + } + + Self { + h: None, + content: TreeFmtNodeContent::LogicalPlan(lp.lp_top), + + lp, + } + } + + fn streaming_root_logical_plan(lp: IRPlanRef<'a>) -> Self { + Self { + h: Some("Streaming".to_string()), + content: TreeFmtNodeContent::LogicalPlan(lp.lp_top), + + lp, + } + } + + pub fn lp_node(&self, h: Option, root: Node) -> Self { + Self { + h, + content: TreeFmtNodeContent::LogicalPlan(root), + + lp: self.lp, + } + } + + pub fn expr_node(&self, h: Option, expr: &'a ExprIR) -> Self { + Self { + h, + content: TreeFmtNodeContent::Expression(expr), + + lp: self.lp, + } + } + + pub fn traverse(&self, visitor: &mut TreeFmtVisitor) { + let TreeFmtNodeData(title, child_nodes) = self.node_data(); + + if visitor.levels.len() <= visitor.depth { + visitor.levels.push(vec![]); + } + + let row = visitor.levels.get_mut(visitor.depth).unwrap(); + row.resize(visitor.width + 1, "".to_string()); + + row[visitor.width] = title; + visitor.prev_depth = visitor.depth; + visitor.depth += 1; + + for child in &child_nodes { + child.traverse(visitor); + } + + visitor.depth -= 1; + visitor.width += if visitor.prev_depth == visitor.depth { + 1 + } else { + 0 + }; + } + + fn node_data(&self) -> TreeFmtNodeData<'_> { + use {TreeFmtNodeContent as C, TreeFmtNodeData as ND, with_header as wh}; + + let lp = &self.lp; + let h = &self.h; + + use IR::*; + match self.content { + #[cfg(feature = "regex")] + C::Expression(expr) => ND( + wh( + h, + &multiline_expression(&expr.display(self.lp.expr_arena).to_string()), + ), + vec![], + ), + #[cfg(not(feature = "regex"))] + C::Expression(expr) => ND(wh(h, &expr.display(self.lp.expr_arena).to_string()), vec![]), + C::LogicalPlan(lp_top) => { + match self.lp.with_root(lp_top).root() { + #[cfg(feature = "python")] + PythonScan { .. } => ND(wh(h, &lp.describe()), vec![]), + Scan { .. } => ND(wh(h, &lp.describe()), vec![]), + DataFrameScan { + schema, + output_schema, + .. + } => { + let (n_columns, projected) = if let Some(schema) = output_schema { + ( + format!("{}", schema.len()), + format!( + ": {};", + format_list_truncated!(schema.iter_names(), 4, '"') + ), + ) + } else { + ("*".to_string(), "".to_string()) + }; + ND( + wh( + h, + &format!( + "DF {}\nPROJECT{} {}/{} COLUMNS", + format_list_truncated!(schema.iter_names(), 4, '"'), + projected, + n_columns, + schema.len() + ), + ), + vec![], + ) + }, + + Union { inputs, .. } => ND( + wh( + h, + // THis is commented out, but must be restored when we convert to IR's. + // &(if let Some(slice) = options.slice { + // format!("SLICED UNION: {slice:?}") + // } else { + // "UNION".to_string() + // }), + "UNION", + ), + inputs + .iter() + .enumerate() + .map(|(i, lp_root)| self.lp_node(Some(format!("PLAN {i}:")), *lp_root)) + .collect(), + ), + HConcat { inputs, .. } => ND( + wh(h, "HCONCAT"), + inputs + .iter() + .enumerate() + .map(|(i, lp_root)| self.lp_node(Some(format!("PLAN {i}:")), *lp_root)) + .collect(), + ), + Cache { + input, + id, + cache_hits, + } => ND( + wh( + h, + &format!("CACHE[id: {:x}, cache_hits: {}]", *id, *cache_hits), + ), + vec![self.lp_node(None, *input)], + ), + Filter { input, predicate } => ND( + wh(h, "FILTER"), + vec![ + self.expr_node(Some("predicate:".to_string()), predicate), + self.lp_node(Some("FROM:".to_string()), *input), + ], + ), + Select { expr, input, .. } => ND( + wh(h, "SELECT"), + expr.iter() + .map(|expr| self.expr_node(Some("expression:".to_string()), expr)) + .chain([self.lp_node(Some("FROM:".to_string()), *input)]) + .collect(), + ), + Sort { + input, by_column, .. + } => ND( + wh(h, "SORT BY"), + by_column + .iter() + .map(|expr| self.expr_node(Some("expression:".to_string()), expr)) + .chain([self.lp_node(None, *input)]) + .collect(), + ), + GroupBy { + input, keys, aggs, .. + } => ND( + wh(h, "AGGREGATE"), + aggs.iter() + .map(|expr| self.expr_node(Some("expression:".to_string()), expr)) + .chain(keys.iter().map(|expr| { + self.expr_node(Some("aggregate by:".to_string()), expr) + })) + .chain([self.lp_node(Some("FROM:".to_string()), *input)]) + .collect(), + ), + Join { + input_left, + input_right, + left_on, + right_on, + options, + .. + } => ND( + wh(h, &format!("{} JOIN", options.args.how)), + left_on + .iter() + .map(|expr| self.expr_node(Some("left on:".to_string()), expr)) + .chain([self.lp_node(Some("LEFT PLAN:".to_string()), *input_left)]) + .chain( + right_on.iter().map(|expr| { + self.expr_node(Some("right on:".to_string()), expr) + }), + ) + .chain([self.lp_node(Some("RIGHT PLAN:".to_string()), *input_right)]) + .collect(), + ), + HStack { input, exprs, .. } => ND( + wh(h, "WITH_COLUMNS"), + exprs + .iter() + .map(|expr| self.expr_node(Some("expression:".to_string()), expr)) + .chain([self.lp_node(None, *input)]) + .collect(), + ), + Distinct { input, options } => ND( + wh( + h, + &format!( + "UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}", + options.maintain_order, options.keep_strategy, options.subset + ), + ), + vec![self.lp_node(None, *input)], + ), + Slice { input, offset, len } => ND( + wh(h, &format!("SLICE[offset: {offset}, len: {len}]")), + vec![self.lp_node(None, *input)], + ), + MapFunction { input, function } => { + if let Some(streaming_lp) = function.to_streaming_lp() { + ND( + String::default(), + vec![TreeFmtNode::streaming_root_logical_plan(streaming_lp)], + ) + } else { + ND( + wh(h, &format!("{function}")), + vec![self.lp_node(None, *input)], + ) + } + }, + ExtContext { input, .. } => { + ND(wh(h, "EXTERNAL_CONTEXT"), vec![self.lp_node(None, *input)]) + }, + Sink { input, payload } => ND( + wh( + h, + match payload { + SinkTypeIR::Memory => "SINK (memory)", + SinkTypeIR::File { .. } => "SINK (file)", + SinkTypeIR::Partition { .. } => "SINK (partition)", + }, + ), + vec![self.lp_node(None, *input)], + ), + SinkMultiple { inputs } => ND( + wh(h, "SINK_MULTIPLE"), + inputs + .iter() + .enumerate() + .map(|(i, lp_root)| self.lp_node(Some(format!("PLAN {i}:")), *lp_root)) + .collect(), + ), + SimpleProjection { input, columns } => { + let num_columns = columns.as_ref().len(); + let total_columns = lp.lp_arena.get(*input).schema(lp.lp_arena).len(); + + let columns = ColumnsDisplay(columns.as_ref()); + ND( + wh( + h, + &format!("simple π {num_columns}/{total_columns} [{columns}]"), + ), + vec![self.lp_node(None, *input)], + ) + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + key, + } => ND( + wh(h, &format!("MERGE SORTED ON '{key}")), + [self.lp_node(Some("LEFT PLAN:".to_string()), *input_left)] + .into_iter() + .chain([self.lp_node(Some("RIGHT PLAN:".to_string()), *input_right)]) + .collect(), + ), + Invalid => ND(wh(h, "INVALID"), vec![]), + } + }, + } + } +} + +#[derive(Default)] +pub enum TreeFmtVisitorDisplay { + #[default] + DisplayText, + DisplayDot, +} + +#[derive(Default)] +pub(crate) struct TreeFmtVisitor { + levels: Vec>, + prev_depth: usize, + depth: usize, + width: usize, + pub(crate) display: TreeFmtVisitorDisplay, +} + +impl Visitor for TreeFmtVisitor { + type Node = AexprNode; + type Arena = Arena; + + /// Invoked before any children of `node` are visited. + fn pre_visit( + &mut self, + node: &Self::Node, + arena: &Self::Arena, + ) -> PolarsResult { + let repr = TreeFmtAExpr(arena.get(node.node())); + let repr = repr.to_string(); + + if self.levels.len() <= self.depth { + self.levels.push(vec![]) + } + + // the post-visit ensures the width of this node is known + let row = self.levels.get_mut(self.depth).unwrap(); + + // set default values to ensure we format at the right width + row.resize(self.width + 1, "".to_string()); + row[self.width] = repr; + + // before entering a depth-first branch we preserve the depth to control the width increase + // in the post-visit + self.prev_depth = self.depth; + + // we will enter depth first, we enter child so depth increases + self.depth += 1; + + Ok(VisitRecursion::Continue) + } + + fn post_visit( + &mut self, + _node: &Self::Node, + _arena: &Self::Arena, + ) -> PolarsResult { + // we finished this branch so we decrease in depth, back the caller node + self.depth -= 1; + + // because we traverse depth first + // the width is increased once after one or more depth-first branches + // this way we avoid empty columns in the resulting tree representation + self.width += if self.prev_depth == self.depth { 1 } else { 0 }; + + Ok(VisitRecursion::Continue) + } +} + +/// Calculates the number of digits in a `usize` number +/// Useful for the alignment of `usize` values when they are displayed +fn digits(n: usize) -> usize { + if n == 0 { + 1 + } else { + f64::log10(n as f64) as usize + 1 + } +} + +/// Meta-info of a column in a populated `TreeFmtVisitor` required for the pretty-print of a tree +#[derive(Clone, Default, Debug)] +struct TreeViewColumn { + offset: usize, + width: usize, + center: usize, +} + +/// Meta-info of a column in a populated `TreeFmtVisitor` required for the pretty-print of a tree +#[derive(Clone, Default, Debug)] +struct TreeViewRow { + offset: usize, + height: usize, + center: usize, +} + +/// Meta-info of a cell in a populated `TreeFmtVisitor` +#[derive(Clone, Default, Debug)] +struct TreeViewCell<'a> { + text: Vec<&'a str>, + /// A `Vec` of indices of `TreeViewColumn`-s stored elsewhere in another `Vec` + /// For a cell on a row `i` these indices point to the columns that contain child-cells on a + /// row `i + 1` (if the latter exists) + /// NOTE: might warrant a rethink should this code become used broader + children_columns: Vec, +} + +/// The complete intermediate representation of a `TreeFmtVisitor` that can be drawn on a `Canvas` +/// down the line +#[derive(Default, Debug)] +struct TreeView<'a> { + n_rows: usize, + n_rows_width: usize, + matrix: Vec>>, + /// NOTE: `TreeViewCell`'s `children_columns` field contains indices pointing at the elements + /// of this `Vec` + columns: Vec, + rows: Vec, +} + +// NOTE: the code below this line is full of hardcoded integer offsets which may not be a big +// problem as long as it remains the private implementation of the pretty-print +/// The conversion from a reference to `levels` field of a `TreeFmtVisitor` +impl<'a> From<&'a [Vec]> for TreeView<'a> { + #[allow(clippy::needless_range_loop)] + fn from(value: &'a [Vec]) -> Self { + let n_rows = value.len(); + let n_cols = value.iter().map(|row| row.len()).max().unwrap_or(0); + if n_rows == 0 || n_cols == 0 { + return TreeView::default(); + } + // the character-width of the highest index of a row + let n_rows_width = digits(n_rows - 1); + + let mut matrix = vec![vec![TreeViewCell::default(); n_cols]; n_rows]; + for i in 0..n_rows { + for j in 0..n_cols { + if j < value[i].len() && !value[i][j].is_empty() { + matrix[i][j].text = value[i][j].split('\n').collect(); + if i < n_rows - 1 { + if j < value[i + 1].len() && !value[i + 1][j].is_empty() { + matrix[i][j].children_columns.push(j); + } + for k in j + 1..n_cols { + if (k >= value[i].len() || value[i][k].is_empty()) + && k < value[i + 1].len() + { + if !value[i + 1][k].is_empty() { + matrix[i][j].children_columns.push(k); + } + } else { + break; + } + } + } + } + } + } + + let mut y_offset = 3; + let mut rows = vec![TreeViewRow::default(); n_rows]; + for i in 0..n_rows { + let mut height = 0; + for j in 0..n_cols { + height = [matrix[i][j].text.len(), height].into_iter().max().unwrap(); + } + height += 2; + rows[i].offset = y_offset; + rows[i].height = height; + rows[i].center = height / 2; + y_offset += height + 3; + } + + let mut x_offset = n_rows_width + 4; + let mut columns = vec![TreeViewColumn::default(); n_cols]; + // the two nested loops below are those `needless_range_loop`s + // more readable this way to my taste + for j in 0..n_cols { + let mut width = 0; + for i in 0..n_rows { + width = [ + matrix[i][j].text.iter().map(|l| l.len()).max().unwrap_or(0), + width, + ] + .into_iter() + .max() + .unwrap(); + } + width += 6; + columns[j].offset = x_offset; + columns[j].width = width; + columns[j].center = width / 2 + width % 2; + x_offset += width; + } + + Self { + n_rows, + n_rows_width, + matrix, + columns, + rows, + } + } +} + +/// The basic charset that's used for drawing lines and boxes on a `Canvas` +struct Glyphs { + void: char, + vertical_line: char, + horizontal_line: char, + top_left_corner: char, + top_right_corner: char, + bottom_left_corner: char, + bottom_right_corner: char, + tee_down: char, + tee_up: char, +} + +impl Default for Glyphs { + fn default() -> Self { + Self { + void: ' ', + vertical_line: '│', + horizontal_line: '─', + top_left_corner: '╭', + top_right_corner: '╮', + bottom_left_corner: '╰', + bottom_right_corner: '╯', + tee_down: '┬', + tee_up: '┴', + } + } +} + +/// A `Point` on a `Canvas` +#[derive(Clone, Copy)] +struct Point(usize, usize); + +/// The orientation of a line on a `Canvas` +#[derive(Clone, Copy)] +enum Orientation { + Vertical, + Horizontal, +} + +/// `Canvas` +struct Canvas { + width: usize, + height: usize, + canvas: Vec>, + glyphs: Glyphs, +} + +impl Canvas { + fn new(width: usize, height: usize, glyphs: Glyphs) -> Self { + Self { + width, + height, + canvas: vec![vec![glyphs.void; width]; height], + glyphs, + } + } + + /// Draws a single `symbol` on the `Canvas` + /// NOTE: The `Point`s that lay outside of the `Canvas` are quietly ignored + fn draw_symbol(&mut self, point: Point, symbol: char) { + let Point(x, y) = point; + if x < self.width && y < self.height { + self.canvas[y][x] = symbol; + } + } + + /// Draws a line of `length` from an `origin` along the `orientation` + fn draw_line(&mut self, origin: Point, orientation: Orientation, length: usize) { + let Point(x, y) = origin; + if let Orientation::Vertical = orientation { + let mut down = 0; + while down < length { + self.draw_symbol(Point(x, y + down), self.glyphs.vertical_line); + down += 1; + } + } else if let Orientation::Horizontal = orientation { + let mut right = 0; + while right < length { + self.draw_symbol(Point(x + right, y), self.glyphs.horizontal_line); + right += 1; + } + } + } + + /// Draws a box of `width` and `height` with an `origin` being the top left corner + fn draw_box(&mut self, origin: Point, width: usize, height: usize) { + let Point(x, y) = origin; + self.draw_symbol(origin, self.glyphs.top_left_corner); + self.draw_symbol(Point(x + width - 1, y), self.glyphs.top_right_corner); + self.draw_symbol(Point(x, y + height - 1), self.glyphs.bottom_left_corner); + self.draw_symbol( + Point(x + width - 1, y + height - 1), + self.glyphs.bottom_right_corner, + ); + self.draw_line(Point(x + 1, y), Orientation::Horizontal, width - 2); + self.draw_line( + Point(x + 1, y + height - 1), + Orientation::Horizontal, + width - 2, + ); + self.draw_line(Point(x, y + 1), Orientation::Vertical, height - 2); + self.draw_line( + Point(x + width - 1, y + 1), + Orientation::Vertical, + height - 2, + ); + } + + /// Draws a box of height `2 + text.len()` containing a left-aligned text + fn draw_label_centered(&mut self, center: Point, text: &[&str]) { + if !text.is_empty() { + let Point(x, y) = center; + let text_width = text.iter().map(|l| l.len()).max().unwrap(); + let half_width = text_width / 2 + text_width % 2; + let half_height = text.len() / 2; + if x >= half_width + 2 && y > half_height { + self.draw_box( + Point(x - half_width - 2, y - half_height - 1), + text_width + 4, + text.len() + 2, + ); + for (i, line) in text.iter().enumerate() { + for (j, c) in line.chars().enumerate() { + self.draw_symbol(Point(x - half_width + j, y - half_height + i), c); + } + } + } + } + } + + /// Draws branched lines from a `Point` to multiple `Point`s below + /// NOTE: the shape of these connections is very specific for this particular kind of the + /// representation of a tree + fn draw_connections(&mut self, from: Point, to: &[Point], branching_offset: usize) { + let mut start_with_corner = true; + let Point(mut x_from, mut y_from) = from; + for (i, Point(x, y)) in to.iter().enumerate() { + if *x >= x_from && *y >= y_from - 1 { + self.draw_symbol(Point(*x, *y), self.glyphs.tee_up); + if *x == x_from { + // if the first connection goes straight below + self.draw_symbol(Point(x_from, y_from - 1), self.glyphs.tee_down); + self.draw_line(Point(x_from, y_from), Orientation::Vertical, *y - y_from); + x_from += 1; + } else { + if start_with_corner { + // if the first or the second connection steers to the right + self.draw_symbol(Point(x_from, y_from - 1), self.glyphs.tee_down); + self.draw_line( + Point(x_from, y_from), + Orientation::Vertical, + branching_offset, + ); + y_from += branching_offset; + self.draw_symbol(Point(x_from, y_from), self.glyphs.bottom_left_corner); + start_with_corner = false; + x_from += 1; + } + let length = *x - x_from; + self.draw_line(Point(x_from, y_from), Orientation::Horizontal, length); + x_from += length; + if i == to.len() - 1 { + self.draw_symbol(Point(x_from, y_from), self.glyphs.top_right_corner); + } else { + self.draw_symbol(Point(x_from, y_from), self.glyphs.tee_down); + } + self.draw_line( + Point(x_from, y_from + 1), + Orientation::Vertical, + *y - y_from - 1, + ); + x_from += 1; + } + } + } + } +} + +/// The actual drawing happens in the conversion of the intermediate `TreeView` into `Canvas` +impl From> for Canvas { + fn from(value: TreeView<'_>) -> Self { + let width = value.n_rows_width + 3 + value.columns.iter().map(|c| c.width).sum::(); + let height = + 3 + value.rows.iter().map(|r| r.height).sum::() + 3 * (value.n_rows - 1); + let mut canvas = Canvas::new(width, height, Glyphs::default()); + + // Axles + let (x, y) = (value.n_rows_width + 2, 1); + canvas.draw_symbol(Point(x, y), '┌'); + canvas.draw_line(Point(x + 1, y), Orientation::Horizontal, width - x); + canvas.draw_line(Point(x, y + 1), Orientation::Vertical, height - y); + + // Row and column indices + for (i, row) in value.rows.iter().enumerate() { + // the prefix `Vec` of spaces compensates for the row indices that are shorter than the + // highest index, effectively, row indices are right-aligned + for (j, c) in vec![' '; value.n_rows_width - digits(i)] + .into_iter() + .chain(format!("{i}").chars()) + .enumerate() + { + canvas.draw_symbol(Point(j + 1, row.offset + row.center), c); + } + } + for (j, col) in value.columns.iter().enumerate() { + let j_width = digits(j); + let start = col.offset + col.center - (j_width / 2 + j_width % 2); + for (k, c) in format!("{j}").chars().enumerate() { + canvas.draw_symbol(Point(start + k, 0), c); + } + } + + // Non-empty cells (nodes) and their connections (edges) + for (i, row) in value.matrix.iter().enumerate() { + for (j, cell) in row.iter().enumerate() { + if !cell.text.is_empty() { + canvas.draw_label_centered( + Point( + value.columns[j].offset + value.columns[j].center, + value.rows[i].offset + value.rows[i].center, + ), + &cell.text, + ); + } + } + } + + fn even_odd(a: usize, b: usize) -> usize { + if a % 2 == 0 && b % 2 == 1 { 1 } else { 0 } + } + + for (i, row) in value.matrix.iter().enumerate() { + for (j, cell) in row.iter().enumerate() { + if !cell.text.is_empty() && i < value.rows.len() - 1 { + let children_points = cell + .children_columns + .iter() + .map(|k| { + let child_total_padding = + value.rows[i + 1].height - value.matrix[i + 1][*k].text.len() - 2; + let even_cell_in_odd_row = even_odd( + value.matrix[i + 1][*k].text.len(), + value.rows[i + 1].height, + ); + Point( + value.columns[*k].offset + value.columns[*k].center - 1, + value.rows[i + 1].offset + + child_total_padding / 2 + + child_total_padding % 2 + - even_cell_in_odd_row, + ) + }) + .collect::>(); + + let parent_total_padding = + value.rows[i].height - value.matrix[i][j].text.len() - 2; + let even_cell_in_odd_row = + even_odd(value.matrix[i][j].text.len(), value.rows[i].height); + + canvas.draw_connections( + Point( + value.columns[j].offset + value.columns[j].center - 1, + value.rows[i].offset + value.rows[i].height + - parent_total_padding / 2 + - even_cell_in_odd_row, + ), + &children_points, + parent_total_padding / 2 + 1 + even_cell_in_odd_row, + ); + } + } + } + + canvas + } +} + +impl fmt::Display for Canvas { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + for row in &self.canvas { + writeln!(f, "{}", row.iter().collect::().trim_end())?; + } + + Ok(()) + } +} + +fn tree_fmt_text(tree: &TreeFmtVisitor, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + let tree_view: TreeView<'_> = tree.levels.as_slice().into(); + let canvas: Canvas = tree_view.into(); + write!(f, "{canvas}")?; + + Ok(()) +} + +// GraphViz Output +// Create a simple DOT graph String from TreeFmtVisitor +fn tree_fmt_dot(tree: &TreeFmtVisitor, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + // Build a dot graph as a string + let tree_view: TreeView<'_> = tree.levels.as_slice().into(); + let mut relations: Vec = Vec::new(); + + // Non-empty cells (nodes) and their connections (edges) + for (i, row) in tree_view.matrix.iter().enumerate() { + for (j, cell) in row.iter().enumerate() { + if !cell.text.is_empty() { + // Add node + let node_label = &cell.text.join("\n"); + let node_desc = format!("n{i}{j} [label=\"{node_label}\", ordering=\"out\"]"); + relations.push(node_desc); + + // Add child edges + if i < tree_view.rows.len() - 1 { + // Iter in reversed order to undo the reversed child order when iterating expressions + for child_col in cell.children_columns.iter().rev() { + let next_row = i + 1; + let edge = format!("n{i}{j} -- n{next_row}{child_col}"); + relations.push(edge); + } + } + } + } + } + + let graph_str = relations.join("\n "); + let s = format!("graph {{\n {graph_str}\n}}"); + write!(f, "{s}")?; + Ok(()) +} + +fn tree_fmt(tree: &TreeFmtVisitor, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match tree.display { + TreeFmtVisitorDisplay::DisplayText => tree_fmt_text(tree, f), + TreeFmtVisitorDisplay::DisplayDot => tree_fmt_dot(tree, f), + } +} + +impl fmt::Display for TreeFmtVisitor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + tree_fmt(self, f) + } +} + +impl fmt::Debug for TreeFmtVisitor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + tree_fmt(self, f) + } +} diff --git a/crates/polars-plan/src/plans/iterator.rs b/crates/polars-plan/src/plans/iterator.rs new file mode 100644 index 000000000000..257962f815f9 --- /dev/null +++ b/crates/polars-plan/src/plans/iterator.rs @@ -0,0 +1,227 @@ +use std::sync::Arc; + +use polars_core::error::PolarsResult; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; +use visitor::{RewritingVisitor, TreeWalker}; + +use crate::prelude::*; + +macro_rules! push_expr { + ($current_expr:expr, $c:ident, $push:ident, $push_owned:ident, $iter:ident) => {{ + use Expr::*; + match $current_expr { + Nth(_) | Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) + | IndexColumn(_) | Len => {}, + #[cfg(feature = "dtype-struct")] + Field(_) => {}, + Alias(e, _) => $push($c, e), + BinaryExpr { left, op: _, right } => { + // reverse order so that left is popped first + $push($c, right); + $push($c, left); + }, + Cast { expr, .. } => $push($c, expr), + Sort { expr, .. } => $push($c, expr), + Gather { expr, idx, .. } => { + $push($c, idx); + $push($c, expr); + }, + Filter { input, by } => { + $push($c, by); + // latest, so that it is popped first + $push($c, input); + }, + SortBy { expr, by, .. } => { + for e in by { + $push_owned($c, e) + } + // latest, so that it is popped first + $push($c, expr); + }, + Agg(agg_e) => { + use AggExpr::*; + match agg_e { + Max { input, .. } => $push($c, input), + Min { input, .. } => $push($c, input), + Mean(e) => $push($c, e), + Median(e) => $push($c, e), + NUnique(e) => $push($c, e), + First(e) => $push($c, e), + Last(e) => $push($c, e), + Implode(e) => $push($c, e), + Count(e, _) => $push($c, e), + Quantile { expr, .. } => $push($c, expr), + Sum(e) => $push($c, e), + AggGroups(e) => $push($c, e), + Std(e, _) => $push($c, e), + Var(e, _) => $push($c, e), + } + }, + Ternary { + truthy, + falsy, + predicate, + } => { + $push($c, predicate); + $push($c, falsy); + // latest, so that it is popped first + $push($c, truthy); + }, + // we iterate in reverse order, so that the lhs is popped first and will be found + // as the root columns/ input columns by `_suffix` and `_keep_name` etc. + AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)), + Function { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)), + Explode(e) => $push($c, e), + Window { + function, + partition_by, + .. + } => { + for e in partition_by.into_iter().rev() { + $push_owned($c, e) + } + // latest so that it is popped first + $push($c, function); + }, + Slice { + input, + offset, + length, + } => { + $push($c, length); + $push($c, offset); + // latest, so that it is popped first + $push($c, input); + }, + Exclude(e, _) => $push($c, e), + KeepName(e) => $push($c, e), + RenameAlias { expr, .. } => $push($c, expr), + SubPlan { .. } => {}, + // pass + Selector(_) => {}, + } + }}; +} + +pub struct ExprIter<'a> { + stack: UnitVec<&'a Expr>, +} + +impl<'a> Iterator for ExprIter<'a> { + type Item = &'a Expr; + + fn next(&mut self) -> Option { + self.stack + .pop() + .inspect(|current_expr| current_expr.nodes(&mut self.stack)) + } +} + +pub struct ExprMapper { + f: F, +} + +impl PolarsResult> RewritingVisitor for ExprMapper { + type Node = Expr; + type Arena = (); + + fn mutate(&mut self, node: Self::Node, _arena: &mut Self::Arena) -> PolarsResult { + (self.f)(node) + } +} + +impl Expr { + pub fn nodes<'a>(&'a self, container: &mut UnitVec<&'a Expr>) { + let push = |c: &mut UnitVec<&'a Expr>, e: &'a Expr| c.push(e); + push_expr!(self, container, push, push, iter); + } + + pub fn nodes_owned(self, container: &mut UnitVec) { + let push_arc = |c: &mut UnitVec, e: Arc| c.push(Arc::unwrap_or_clone(e)); + let push_owned = |c: &mut UnitVec, e: Expr| c.push(e); + push_expr!(self, container, push_arc, push_owned, into_iter); + } + + pub fn map_expr Self>(self, mut f: F) -> Self { + self.rewrite(&mut ExprMapper { f: |e| Ok(f(e)) }, &mut ()) + .unwrap() + } + + pub fn try_map_expr PolarsResult>(self, f: F) -> PolarsResult { + self.rewrite(&mut ExprMapper { f }, &mut ()) + } +} + +impl<'a> IntoIterator for &'a Expr { + type Item = &'a Expr; + type IntoIter = ExprIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + let stack = unitvec!(self); + ExprIter { stack } + } +} + +pub struct AExprIter<'a> { + stack: UnitVec, + arena: Option<&'a Arena>, +} + +impl<'a> Iterator for AExprIter<'a> { + type Item = (Node, &'a AExpr); + + fn next(&mut self) -> Option { + self.stack.pop().map(|node| { + // take the arena because the bchk doesn't allow a mutable borrow to the field. + let arena = self.arena.unwrap(); + let current_expr = arena.get(node); + current_expr.inputs_rev(&mut self.stack); + + self.arena = Some(arena); + (node, current_expr) + }) + } +} + +pub trait ArenaExprIter<'a> { + fn iter(&self, root: Node) -> AExprIter<'a>; +} + +impl<'a> ArenaExprIter<'a> for &'a Arena { + fn iter(&self, root: Node) -> AExprIter<'a> { + let stack = unitvec![root]; + AExprIter { + stack, + arena: Some(self), + } + } +} + +pub struct AlpIter<'a> { + stack: UnitVec, + arena: &'a Arena, +} + +pub trait ArenaLpIter<'a> { + fn iter(&self, root: Node) -> AlpIter<'a>; +} + +impl<'a> ArenaLpIter<'a> for &'a Arena { + fn iter(&self, root: Node) -> AlpIter<'a> { + let stack = unitvec![root]; + AlpIter { stack, arena: self } + } +} + +impl<'a> Iterator for AlpIter<'a> { + type Item = (Node, &'a IR); + + fn next(&mut self) -> Option { + self.stack.pop().map(|node| { + let lp = self.arena.get(node); + lp.copy_inputs(&mut self.stack); + (node, lp) + }) + } +} diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs new file mode 100644 index 000000000000..02d4f3b82166 --- /dev/null +++ b/crates/polars-plan/src/plans/lit.rs @@ -0,0 +1,633 @@ +use std::hash::{Hash, Hasher}; + +#[cfg(feature = "temporal")] +use chrono::{Duration as ChronoDuration, NaiveDate, NaiveDateTime}; +use polars_core::chunked_array::cast::CastOptions; +use polars_core::prelude::*; +use polars_core::utils::materialize_dyn_int; +use polars_utils::hashing::hash_to_partition; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::constants::get_literal_name; +use crate::prelude::*; + +#[derive(Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum DynLiteralValue { + Str(PlSmallStr), + Int(i128), + Float(f64), + List(DynListLiteralValue), +} +#[derive(Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum DynListLiteralValue { + Str(Box<[Option]>), + Int(Box<[Option]>), + Float(Box<[Option]>), + List(Box<[Option]>), +} + +impl Hash for DynLiteralValue { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Str(i) => i.hash(state), + Self::Int(i) => i.hash(state), + Self::Float(i) => i.to_ne_bytes().hash(state), + Self::List(i) => i.hash(state), + } + } +} + +impl Hash for DynListLiteralValue { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Str(i) => i.hash(state), + Self::Int(i) => i.hash(state), + Self::Float(i) => i + .iter() + .for_each(|i| i.map(|i| i.to_ne_bytes()).hash(state)), + Self::List(i) => i.hash(state), + } + } +} + +#[derive(Clone, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct RangeLiteralValue { + pub low: i128, + pub high: i128, + pub dtype: DataType, +} +#[derive(Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum LiteralValue { + /// A dynamically inferred literal value. This needs to be materialized into a specific type. + Dyn(DynLiteralValue), + Scalar(Scalar), + Series(SpecialEq), + Range(RangeLiteralValue), +} + +pub enum MaterializedLiteralValue { + Scalar(Scalar), + Series(Series), +} + +impl DynListLiteralValue { + pub fn try_materialize_to_dtype(self, dtype: &DataType) -> PolarsResult { + let Some(inner_dtype) = dtype.inner_dtype() else { + polars_bail!(InvalidOperation: "conversion from list literal to `{dtype}` failed."); + }; + + let s = match self { + DynListLiteralValue::Str(vs) => { + StringChunked::from_iter_options(PlSmallStr::from_static("literal"), vs.into_iter()) + .into_series() + }, + DynListLiteralValue::Int(vs) => { + #[cfg(feature = "dtype-i128")] + { + Int128Chunked::from_iter_options( + PlSmallStr::from_static("literal"), + vs.into_iter(), + ) + .into_series() + } + + #[cfg(not(feature = "dtype-i128"))] + { + Int64Chunked::from_iter_options( + PlSmallStr::from_static("literal"), + vs.into_iter().map(|v| v.map(|v| v as i64)), + ) + .into_series() + } + }, + DynListLiteralValue::Float(vs) => Float64Chunked::from_iter_options( + PlSmallStr::from_static("literal"), + vs.into_iter(), + ) + .into_series(), + DynListLiteralValue::List(_) => todo!("nested lists"), + }; + + let s = s.cast_with_options(inner_dtype, CastOptions::Strict)?; + let value = match dtype { + DataType::List(_) => AnyValue::List(s), + #[cfg(feature = "dtype-array")] + DataType::Array(_, size) => AnyValue::Array(s, *size), + _ => unreachable!(), + }; + + Ok(Scalar::new(dtype.clone(), value)) + } +} + +impl DynLiteralValue { + pub fn try_materialize_to_dtype(self, dtype: &DataType) -> PolarsResult { + match self { + DynLiteralValue::Str(s) => { + Ok(Scalar::from(s).cast_with_options(dtype, CastOptions::Strict)?) + }, + DynLiteralValue::Int(i) => { + Ok(Scalar::from(i).cast_with_options(dtype, CastOptions::Strict)?) + }, + DynLiteralValue::Float(f) => { + Ok(Scalar::from(f).cast_with_options(dtype, CastOptions::Strict)?) + }, + DynLiteralValue::List(dyn_list_value) => dyn_list_value.try_materialize_to_dtype(dtype), + } + } +} + +impl RangeLiteralValue { + pub fn try_materialize_to_series(self, dtype: &DataType) -> PolarsResult { + fn handle_range_oob(range: &RangeLiteralValue, to_dtype: &DataType) -> PolarsResult<()> { + polars_bail!( + InvalidOperation: + "conversion from `{}` to `{to_dtype}` failed for range({}, {})", + range.dtype, range.low, range.high, + ) + } + + let s = match dtype { + DataType::Int32 => { + if self.low < i32::MIN as i128 || self.high > i32::MAX as i128 { + handle_range_oob(&self, dtype)?; + } + + new_int_range::( + self.low as i32, + self.high as i32, + 1, + PlSmallStr::from_static("range"), + ) + .unwrap() + }, + DataType::Int64 => { + if self.low < i64::MIN as i128 || self.high > i64::MAX as i128 { + handle_range_oob(&self, dtype)?; + } + + new_int_range::( + self.low as i64, + self.high as i64, + 1, + PlSmallStr::from_static("range"), + ) + .unwrap() + }, + DataType::UInt32 => { + if self.low < u32::MIN as i128 || self.high > u32::MAX as i128 { + handle_range_oob(&self, dtype)?; + } + new_int_range::( + self.low as u32, + self.high as u32, + 1, + PlSmallStr::from_static("range"), + ) + .unwrap() + }, + _ => polars_bail!(InvalidOperation: "unsupported range datatype `{dtype}`"), + }; + + Ok(s) + } +} + +impl LiteralValue { + /// Get the output name as `&str`. + pub(crate) fn output_name(&self) -> &PlSmallStr { + match self { + LiteralValue::Series(s) => s.name(), + _ => get_literal_name(), + } + } + + /// Get the output name as [`PlSmallStr`]. + pub(crate) fn output_column_name(&self) -> &PlSmallStr { + match self { + LiteralValue::Series(s) => s.name(), + _ => get_literal_name(), + } + } + + pub fn try_materialize_to_dtype( + self, + dtype: &DataType, + ) -> PolarsResult { + use LiteralValue as L; + match self { + L::Dyn(dyn_value) => dyn_value + .try_materialize_to_dtype(dtype) + .map(MaterializedLiteralValue::Scalar), + L::Scalar(sc) => Ok(MaterializedLiteralValue::Scalar( + sc.cast_with_options(dtype, CastOptions::Strict)?, + )), + L::Range(range) => { + let Some(inner_dtype) = dtype.inner_dtype() else { + polars_bail!( + InvalidOperation: "cannot turn `{}` range into `{dtype}`", + range.dtype + ); + }; + + let s = range.try_materialize_to_series(inner_dtype)?; + let value = match dtype { + DataType::List(_) => AnyValue::List(s), + #[cfg(feature = "dtype-array")] + DataType::Array(_, size) => AnyValue::Array(s, *size), + _ => unreachable!(), + }; + Ok(MaterializedLiteralValue::Scalar(Scalar::new( + dtype.clone(), + value, + ))) + }, + L::Series(s) => Ok(MaterializedLiteralValue::Series( + s.cast_with_options(dtype, CastOptions::Strict)?, + )), + } + } + + pub fn extract_usize(&self) -> PolarsResult { + macro_rules! cast_usize { + ($v:expr) => { + usize::try_from($v).map_err( + |_| polars_err!(InvalidOperation: "cannot convert value {} to usize", $v) + ) + } + } + match &self { + Self::Dyn(DynLiteralValue::Int(v)) => cast_usize!(*v), + Self::Scalar(sc) => match sc.as_any_value() { + AnyValue::UInt8(v) => Ok(v as usize), + AnyValue::UInt16(v) => Ok(v as usize), + AnyValue::UInt32(v) => cast_usize!(v), + AnyValue::UInt64(v) => cast_usize!(v), + AnyValue::Int8(v) => cast_usize!(v), + AnyValue::Int16(v) => cast_usize!(v), + AnyValue::Int32(v) => cast_usize!(v), + AnyValue::Int64(v) => cast_usize!(v), + AnyValue::Int128(v) => cast_usize!(v), + _ => { + polars_bail!(InvalidOperation: "expression must be constant literal to extract integer") + }, + }, + _ => { + polars_bail!(InvalidOperation: "expression must be constant literal to extract integer") + }, + } + } + + pub fn materialize(self) -> Self { + match self { + LiteralValue::Dyn(_) => { + let av = self.to_any_value().unwrap(); + av.into() + }, + lv => lv, + } + } + + pub fn is_scalar(&self) -> bool { + !matches!(self, LiteralValue::Series(_) | LiteralValue::Range { .. }) + } + + pub fn to_any_value(&self) -> Option { + let av = match self { + Self::Scalar(sc) => sc.value().clone(), + Self::Range(range) => { + let s = range.clone().try_materialize_to_series(&range.dtype).ok()?; + AnyValue::List(s) + }, + Self::Series(_) => return None, + Self::Dyn(d) => match d { + DynLiteralValue::Int(v) => materialize_dyn_int(*v), + DynLiteralValue::Float(v) => AnyValue::Float64(*v), + DynLiteralValue::Str(v) => AnyValue::String(v), + DynLiteralValue::List(_) => todo!(), + }, + }; + Some(av) + } + + /// Getter for the `DataType` of the value + pub fn get_datatype(&self) -> DataType { + match self { + Self::Dyn(d) => match d { + DynLiteralValue::Int(v) => DataType::Unknown(UnknownKind::Int(*v)), + DynLiteralValue::Float(_) => DataType::Unknown(UnknownKind::Float), + DynLiteralValue::Str(_) => DataType::Unknown(UnknownKind::Str), + DynLiteralValue::List(_) => todo!(), + }, + Self::Scalar(sc) => sc.dtype().clone(), + Self::Series(s) => s.dtype().clone(), + Self::Range(s) => s.dtype.clone(), + } + } + + pub fn new_idxsize(value: IdxSize) -> Self { + LiteralValue::Scalar(value.into()) + } + + pub fn extract_str(&self) -> Option<&str> { + match self { + LiteralValue::Dyn(DynLiteralValue::Str(s)) => Some(s.as_str()), + LiteralValue::Scalar(sc) => match sc.value() { + AnyValue::String(s) => Some(s), + AnyValue::StringOwned(s) => Some(s), + _ => None, + }, + _ => None, + } + } + + pub fn extract_binary(&self) -> Option<&[u8]> { + match self { + LiteralValue::Scalar(sc) => match sc.value() { + AnyValue::Binary(s) => Some(s), + AnyValue::BinaryOwned(s) => Some(s), + _ => None, + }, + _ => None, + } + } + + pub fn is_null(&self) -> bool { + match self { + Self::Scalar(sc) => sc.is_null(), + Self::Series(s) => s.len() == 1 && s.null_count() == 1, + _ => false, + } + } + + pub fn bool(&self) -> Option { + match self { + LiteralValue::Scalar(s) => match s.as_any_value() { + AnyValue::Boolean(b) => Some(b), + _ => None, + }, + _ => None, + } + } + + pub const fn untyped_null() -> Self { + Self::Scalar(Scalar::null(DataType::Null)) + } +} + +impl From for LiteralValue { + fn from(value: Scalar) -> Self { + Self::Scalar(value) + } +} + +pub trait Literal { + /// [Literal](Expr::Literal) expression. + fn lit(self) -> Expr; +} + +pub trait TypedLiteral: Literal { + /// [Literal](Expr::Literal) expression. + fn typed_lit(self) -> Expr + where + Self: Sized, + { + self.lit() + } +} + +impl TypedLiteral for String {} +impl TypedLiteral for &str {} + +impl Literal for PlSmallStr { + fn lit(self) -> Expr { + Expr::Literal(Scalar::from(self).into()) + } +} + +impl Literal for String { + fn lit(self) -> Expr { + Expr::Literal(Scalar::from(PlSmallStr::from_string(self)).into()) + } +} + +impl Literal for &str { + fn lit(self) -> Expr { + Expr::Literal(Scalar::from(PlSmallStr::from_str(self)).into()) + } +} + +impl Literal for Vec { + fn lit(self) -> Expr { + Expr::Literal(Scalar::from(self).into()) + } +} + +impl Literal for &[u8] { + fn lit(self) -> Expr { + Expr::Literal(Scalar::from(self.to_vec()).into()) + } +} + +impl From> for LiteralValue { + fn from(value: AnyValue<'_>) -> Self { + Self::Scalar(Scalar::new(value.dtype(), value.into_static())) + } +} + +macro_rules! make_literal { + ($TYPE:ty, $SCALAR:ident) => { + impl Literal for $TYPE { + fn lit(self) -> Expr { + Expr::Literal(Scalar::from(self).into()) + } + } + }; +} + +macro_rules! make_literal_typed { + ($TYPE:ty, $SCALAR:ident) => { + impl TypedLiteral for $TYPE { + fn typed_lit(self) -> Expr { + Expr::Literal(Scalar::from(self).into()) + } + } + }; +} + +macro_rules! make_dyn_lit { + ($TYPE:ty, $SCALAR:ident) => { + impl Literal for $TYPE { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::Dyn(DynLiteralValue::$SCALAR( + self.try_into().unwrap(), + ))) + } + } + }; +} + +make_literal!(bool, Boolean); +make_literal_typed!(f32, Float32); +make_literal_typed!(f64, Float64); +make_literal_typed!(i8, Int8); +make_literal_typed!(i16, Int16); +make_literal_typed!(i32, Int32); +make_literal_typed!(i64, Int64); +make_literal_typed!(i128, Int128); +make_literal_typed!(u8, UInt8); +make_literal_typed!(u16, UInt16); +make_literal_typed!(u32, UInt32); +make_literal_typed!(u64, UInt64); + +make_dyn_lit!(f32, Float); +make_dyn_lit!(f64, Float); +make_dyn_lit!(i8, Int); +make_dyn_lit!(i16, Int); +make_dyn_lit!(i32, Int); +make_dyn_lit!(i64, Int); +make_dyn_lit!(u8, Int); +make_dyn_lit!(u16, Int); +make_dyn_lit!(u32, Int); +make_dyn_lit!(u64, Int); +make_dyn_lit!(i128, Int); + +/// The literal Null +pub struct Null {} +pub const NULL: Null = Null {}; + +impl Literal for Null { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::Scalar(Scalar::null(DataType::Null))) + } +} + +#[cfg(feature = "dtype-datetime")] +impl Literal for NaiveDateTime { + fn lit(self) -> Expr { + if in_nanoseconds_window(&self) { + Expr::Literal( + Scalar::new_datetime( + self.and_utc().timestamp_nanos_opt().unwrap(), + TimeUnit::Nanoseconds, + None, + ) + .into(), + ) + } else { + Expr::Literal( + Scalar::new_datetime( + self.and_utc().timestamp_micros(), + TimeUnit::Microseconds, + None, + ) + .into(), + ) + } + } +} + +#[cfg(feature = "dtype-duration")] +impl Literal for ChronoDuration { + fn lit(self) -> Expr { + if let Some(value) = self.num_nanoseconds() { + Expr::Literal(Scalar::new_duration(value, TimeUnit::Nanoseconds).into()) + } else { + Expr::Literal( + Scalar::new_duration(self.num_microseconds().unwrap(), TimeUnit::Microseconds) + .into(), + ) + } + } +} + +#[cfg(feature = "dtype-duration")] +impl Literal for Duration { + fn lit(self) -> Expr { + assert!( + self.months() == 0, + "Cannot create literal duration that is not of fixed length; found {}", + self + ); + let ns = self.duration_ns(); + Expr::Literal( + Scalar::new_duration( + if self.negative() { -ns } else { ns }, + TimeUnit::Nanoseconds, + ) + .into(), + ) + } +} + +#[cfg(feature = "dtype-datetime")] +impl Literal for NaiveDate { + fn lit(self) -> Expr { + self.and_hms_opt(0, 0, 0).unwrap().lit() + } +} + +impl Literal for Series { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::Series(SpecialEq::new(self))) + } +} + +impl Literal for LiteralValue { + fn lit(self) -> Expr { + Expr::Literal(self) + } +} + +impl Literal for Scalar { + fn lit(self) -> Expr { + Expr::Literal(self.into()) + } +} + +/// Create a Literal Expression from `L`. A literal expression behaves like a column that contains a single distinct +/// value. +/// +/// The column is automatically of the "correct" length to make the operations work. Often this is determined by the +/// length of the `LazyFrame` it is being used with. For instance, `lazy_df.with_column(lit(5).alias("five"))` creates a +/// new column named "five" that is the length of the Dataframe (at the time `collect` is called), where every value in +/// the column is `5`. +pub fn lit(t: L) -> Expr { + t.lit() +} + +pub fn typed_lit(t: L) -> Expr { + t.typed_lit() +} + +impl Hash for LiteralValue { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + LiteralValue::Series(s) => { + // Free stats + s.dtype().hash(state); + let len = s.len(); + len.hash(state); + s.null_count().hash(state); + const RANDOM: u64 = 0x2c194fa5df32a367; + let mut rng = (len as u64) ^ RANDOM; + for _ in 0..std::cmp::min(5, len) { + let idx = hash_to_partition(rng, len); + s.get(idx).unwrap().hash(state); + rng = rng.rotate_right(17).wrapping_add(RANDOM); + } + }, + LiteralValue::Range(range) => range.hash(state), + LiteralValue::Scalar(sc) => sc.hash(state), + LiteralValue::Dyn(d) => d.hash(state), + } + } +} diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs new file mode 100644 index 000000000000..1d21bc106dc4 --- /dev/null +++ b/crates/polars-plan/src/plans/mod.rs @@ -0,0 +1,50 @@ +use std::sync::Arc; + +use polars_core::prelude::*; + +use crate::prelude::*; + +pub(crate) mod aexpr; +pub(crate) mod anonymous_scan; +pub(crate) mod ir; + +mod apply; +mod builder_ir; +pub(crate) mod conversion; +#[cfg(feature = "debugging")] +pub(crate) mod debug; +pub mod expr_ir; +mod functions; +pub mod hive; +pub(crate) mod iterator; +mod lit; +pub(crate) mod optimizer; +pub(crate) mod options; +#[cfg(feature = "python")] +pub mod python; +#[cfg(feature = "python")] +pub use python::*; +mod schema; +pub mod visitor; + +pub use aexpr::*; +pub use anonymous_scan::*; +pub use apply::*; +pub use builder_ir::*; +pub use conversion::*; +pub(crate) use expr_ir::*; +pub use functions::*; +pub use ir::*; +pub use iterator::*; +pub use lit::*; +pub use optimizer::*; +pub use schema::*; + +#[derive(Clone, Copy, Debug, Default)] +pub enum Context { + /// Any operation that is done on groups + Aggregation, + /// Any operation that is done while projection/ selection of data + #[default] + Default, +} diff --git a/crates/polars-plan/src/plans/optimizer/cache_states.rs b/crates/polars-plan/src/plans/optimizer/cache_states.rs new file mode 100644 index 000000000000..6825d809f237 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/cache_states.rs @@ -0,0 +1,395 @@ +use std::collections::BTreeMap; + +use super::*; + +fn get_upper_projections( + parent: Node, + lp_arena: &Arena, + expr_arena: &Arena, + names_scratch: &mut Vec, + found_required_columns: &mut bool, +) -> bool { + let parent = lp_arena.get(parent); + + use IR::*; + // During projection pushdown all accumulated. + match parent { + SimpleProjection { columns, .. } => { + let iter = columns.iter_names_cloned(); + names_scratch.extend(iter); + *found_required_columns = true; + false + }, + Filter { predicate, .. } => { + // Also add predicate, as the projection is above the filter node. + names_scratch.extend(aexpr_to_leaf_names(predicate.node(), expr_arena)); + + true + }, + // Only filter and projection nodes are allowed, any other node we stop. + _ => false, + } +} + +fn get_upper_predicates( + parent: Node, + lp_arena: &Arena, + expr_arena: &mut Arena, + predicate_scratch: &mut Vec, +) -> bool { + let parent = lp_arena.get(parent); + + use IR::*; + match parent { + Filter { predicate, .. } => { + let expr = predicate.to_expr(expr_arena); + predicate_scratch.push(expr); + false + }, + SimpleProjection { .. } => true, + // Only filter and projection nodes are allowed, any other node we stop. + _ => false, + } +} + +type TwoParents = [Option; 2]; + +// 1. This will ensure that all equal caches communicate the amount of columns +// they need to project. +// 2. This will ensure we apply predicate in the subtrees below the caches. +// If the predicate above the cache is the same for all matching caches, that filter will be +// applied as well. +// +// # Example +// Consider this tree, where `SUB-TREE` is duplicate and can be cached. +// +// +// Tree +// | +// | +// |--------------------|-------------------| +// | | +// SUB-TREE SUB-TREE +// +// STEPS: +// - 1. CSE will run and will insert cache nodes +// +// Tree +// | +// | +// |--------------------|-------------------| +// | | +// | CACHE 0 | CACHE 0 +// | | +// SUB-TREE SUB-TREE +// +// - 2. predicate and projection pushdown will run and will insert optional FILTER and PROJECTION above the caches +// +// Tree +// | +// | +// |--------------------|-------------------| +// | FILTER (optional) | FILTER (optional) +// | PROJ (optional) | PROJ (optional) +// | | +// | CACHE 0 | CACHE 0 +// | | +// SUB-TREE SUB-TREE +// +// # Projection optimization +// The union of the projection is determined and the projection will be pushed down. +// +// Tree +// | +// | +// |--------------------|-------------------| +// | FILTER (optional) | FILTER (optional) +// | CACHE 0 | CACHE 0 +// | | +// SUB-TREE SUB-TREE +// UNION PROJ (optional) UNION PROJ (optional) +// +// # Filter optimization +// Depending on the predicates the predicate pushdown optimization will run. +// Possible cases: +// - NO FILTERS: run predicate pd from the cache nodes -> finish +// - Above the filters the caches are the same -> run predicate pd from the filter node -> finish +// - There is a cache without predicates above the cache node -> run predicate form the cache nodes -> finish +// - The predicates above the cache nodes are all different -> remove the cache nodes -> finish +pub(super) fn set_cache_states( + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + scratch: &mut Vec, + expr_eval: ExprEval<'_>, + verbose: bool, + new_streaming: bool, +) -> PolarsResult<()> { + let mut stack = Vec::with_capacity(4); + let mut names_scratch = vec![]; + let mut predicates_scratch = vec![]; + + scratch.clear(); + stack.clear(); + + #[derive(Default)] + struct Value { + // All the children of the cache per cache-id. + children: Vec, + parents: Vec, + cache_nodes: Vec, + // Union over projected names. + names_union: PlHashSet, + // Union over predicates. + predicate_union: PlHashMap, + } + let mut cache_schema_and_children = BTreeMap::new(); + + // Stack frame + #[derive(Default, Copy, Clone)] + struct Frame { + current: Node, + cache_id: Option, + parent: TwoParents, + previous_cache: Option, + } + let init = Frame { + current: root, + ..Default::default() + }; + + stack.push(init); + + // # First traversal. + // Collect the union of columns per cache id. + // And find the cache parents. + while let Some(mut frame) = stack.pop() { + let lp = lp_arena.get(frame.current); + lp.copy_inputs(scratch); + + use IR::*; + + if let Cache { input, id, .. } = lp { + if let Some(cache_id) = frame.cache_id { + frame.previous_cache = Some(cache_id) + } + if frame.parent[0].is_some() { + // Projection pushdown has already run and blocked on cache nodes + // the pushed down columns are projected just above this cache + // if there were no pushed down column, we just take the current + // nodes schema + // we never want to naively take parents, as a join or aggregate for instance + // change the schema + + let v = cache_schema_and_children + .entry(*id) + .or_insert_with(Value::default); + v.children.push(*input); + v.parents.push(frame.parent); + v.cache_nodes.push(frame.current); + + let mut found_required_columns = false; + + for parent_node in frame.parent.into_iter().flatten() { + let keep_going = get_upper_projections( + parent_node, + lp_arena, + expr_arena, + &mut names_scratch, + &mut found_required_columns, + ); + if !names_scratch.is_empty() { + v.names_union.extend(names_scratch.drain(..)); + } + // We stop early as we want to find the first projection node above the cache. + if !keep_going { + break; + } + } + + for parent_node in frame.parent.into_iter().flatten() { + let keep_going = get_upper_predicates( + parent_node, + lp_arena, + expr_arena, + &mut predicates_scratch, + ); + if !predicates_scratch.is_empty() { + for pred in predicates_scratch.drain(..) { + let count = v.predicate_union.entry(pred).or_insert(0); + *count += 1; + } + } + // We stop early as we want to find the first predicate node above the cache. + if !keep_going { + break; + } + } + + // There was no explicit projection and we must take + // all columns + if !found_required_columns { + let schema = lp.schema(lp_arena); + v.names_union.extend(schema.iter_names_cloned()); + } + } + frame.cache_id = Some(*id); + }; + + // Shift parents. + frame.parent[1] = frame.parent[0]; + frame.parent[0] = Some(frame.current); + for n in scratch.iter() { + let mut new_frame = frame; + new_frame.current = *n; + stack.push(new_frame); + } + scratch.clear(); + } + + // # Second pass. + // we create a subtree where we project the columns + // just before the cache. Then we do another projection pushdown + // and finally remove that last projection and stitch the subplan + // back to the cache node again + if !cache_schema_and_children.is_empty() { + let mut proj_pd = ProjectionPushDown::new(); + let mut pred_pd = PredicatePushDown::new(expr_eval, new_streaming).block_at_cache(false); + for (_cache_id, v) in cache_schema_and_children { + // # CHECK IF WE NEED TO REMOVE CACHES + // If we encounter multiple predicates we remove the cache nodes completely as we don't + // want to loose predicate pushdown in favor of scan sharing. + if v.predicate_union.len() > 1 { + if verbose { + eprintln!("cache nodes will be removed because predicates don't match") + } + for ((&child, cache), parents) in + v.children.iter().zip(v.cache_nodes).zip(v.parents) + { + // Remove the cache and assign the child the cache location. + lp_arena.swap(child, cache); + + // Restart predicate and projection pushdown from most top parent. + // This to ensure we continue the optimization where it was blocked initially. + // We pick up the blocked filter and projection. + let mut node = cache; + for p_node in parents.into_iter().flatten() { + if matches!( + lp_arena.get(p_node), + IR::Filter { .. } | IR::SimpleProjection { .. } + ) { + node = p_node + } else { + break; + } + } + + let lp = lp_arena.take(node); + let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; + let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?; + lp_arena.replace(node, lp); + } + return Ok(()); + } + // Below we restart projection and predicates pushdown + // on the first cache node. As it are cache nodes, the others are the same + // and we can reuse the optimized state for all inputs. + // See #21637 + + // # RUN PROJECTION PUSHDOWN + if !v.names_union.is_empty() { + let first_child = *v.children.first().expect("at least on child"); + + let columns = &v.names_union; + let child_lp = lp_arena.take(first_child); + + // Make sure we project in the order of the schema + // if we don't a union may fail as we would project by the + // order we discovered all values. + let child_schema = child_lp.schema(lp_arena); + let child_schema = child_schema.as_ref(); + let projection = child_schema + .iter_names() + .flat_map(|name| columns.get(name.as_str()).cloned()) + .collect::>(); + + let new_child = lp_arena.add(child_lp); + + let lp = IRBuilder::new(new_child, expr_arena, lp_arena) + .project_simple(projection) + .expect("unique names") + .build(); + + let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; + // Optimization can lead to a double projection. Only take the last. + let lp = if let IR::SimpleProjection { input, columns } = lp { + let input = + if let IR::SimpleProjection { input: input2, .. } = lp_arena.get(input) { + *input2 + } else { + input + }; + IR::SimpleProjection { input, columns } + } else { + lp + }; + lp_arena.replace(first_child, lp.clone()); + + // Set the remaining children to the same node. + for &child in &v.children[1..] { + lp_arena.replace(child, lp.clone()); + } + } else { + // No upper projections to include, run projection pushdown from cache node. + let first_child = *v.children.first().expect("at least on child"); + let child_lp = lp_arena.take(first_child); + let lp = proj_pd.optimize(child_lp, lp_arena, expr_arena)?; + lp_arena.replace(first_child, lp.clone()); + + for &child in &v.children[1..] { + lp_arena.replace(child, lp.clone()); + } + } + + // # RUN PREDICATE PUSHDOWN + // Run this after projection pushdown, otherwise the predicate columns will not be projected. + + // - If all predicates of parent are the same we will restart predicate pushdown from the parent FILTER node. + // - Otherwise we will start predicate pushdown from the cache node. + let allow_parent_predicate_pushdown = v.predicate_union.len() == 1 && { + let (_pred, count) = v.predicate_union.iter().next().unwrap(); + *count == v.children.len() as u32 + }; + + if allow_parent_predicate_pushdown { + let parents = *v.parents.first().unwrap(); + let node = get_filter_node(parents, lp_arena) + .expect("expected filter; this is an optimizer bug"); + let start_lp = lp_arena.take(node); + let lp = pred_pd.optimize(start_lp, lp_arena, expr_arena)?; + lp_arena.replace(node, lp.clone()); + for &parents in &v.parents[1..] { + let node = get_filter_node(parents, lp_arena) + .expect("expected filter; this is an optimizer bug"); + lp_arena.replace(node, lp.clone()); + } + } else { + let child = *v.children.first().unwrap(); + let child_lp = lp_arena.take(child); + let lp = pred_pd.optimize(child_lp, lp_arena, expr_arena)?; + lp_arena.replace(child, lp.clone()); + for &child in &v.children[1..] { + lp_arena.replace(child, lp.clone()); + } + } + } + } + Ok(()) +} + +fn get_filter_node(parents: TwoParents, lp_arena: &Arena) -> Option { + parents + .into_iter() + .flatten() + .find(|&parent| matches!(lp_arena.get(parent), IR::Filter { .. })) +} diff --git a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs new file mode 100644 index 000000000000..f94dc8c0b969 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs @@ -0,0 +1,300 @@ +use std::sync::Arc; + +use arrow::bitmap::MutableBitmap; +use polars_core::schema::Schema; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; +use polars_utils::arena::{Arena, Node}; +use polars_utils::vec::inplace_zip_filtermap; + +use super::aexpr::AExpr; +use super::ir::IR; +use super::{PlSmallStr, aexpr_to_leaf_names_iter}; + +type ColumnMap = PlHashMap; + +fn column_map_finalize_bitset(bitset: &mut MutableBitmap, column_map: &ColumnMap) { + assert!(bitset.len() <= column_map.len()); + + let size = bitset.len(); + bitset.extend_constant(column_map.len() - size, false); +} + +fn column_map_set(bitset: &mut MutableBitmap, column_map: &mut ColumnMap, column: PlSmallStr) { + let size = column_map.len(); + column_map + .entry(column) + .and_modify(|idx| bitset.set(*idx, true)) + .or_insert_with(|| { + bitset.push(true); + size + }); +} + +pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) { + let mut ir_stack = Vec::with_capacity(16); + ir_stack.push(root); + + // We define these here to reuse the allocations across the loops + let mut column_map = ColumnMap::with_capacity(8); + let mut input_genset = MutableBitmap::with_capacity(16); + let mut current_expr_livesets: Vec = Vec::with_capacity(16); + let mut current_liveset = MutableBitmap::with_capacity(16); + let mut pushable = MutableBitmap::with_capacity(16); + let mut potential_pushable = Vec::with_capacity(4); + + while let Some(current) = ir_stack.pop() { + let current_ir = lp_arena.get(current); + current_ir.copy_inputs(&mut ir_stack); + let IR::HStack { input, .. } = current_ir else { + continue; + }; + let input = *input; + + let [current_ir, input_ir] = lp_arena.get_many_mut([current, input]); + + let IR::HStack { + input: current_input, + exprs: current_exprs, + schema: current_schema, + options: current_options, + } = current_ir + else { + unreachable!(); + }; + let IR::HStack { + input: input_input, + exprs: input_exprs, + schema: input_schema, + options: input_options, + } = input_ir + else { + continue; + }; + + let column_map = &mut column_map; + + // Reuse the allocations of the previous loop + column_map.clear(); + input_genset.clear(); + current_expr_livesets.clear(); + current_liveset.clear(); + pushable.clear(); + potential_pushable.clear(); + + pushable.reserve(current_exprs.len()); + potential_pushable.reserve(current_exprs.len()); + + // @NOTE + // We can pushdown any column that utilizes no live columns that are generated in the + // input. + + for input_expr in input_exprs.iter() { + column_map_set( + &mut input_genset, + column_map, + input_expr.output_name().clone(), + ); + } + + for expr in current_exprs.iter() { + let mut liveset = MutableBitmap::from_len_zeroed(column_map.len()); + + for live in aexpr_to_leaf_names_iter(expr.node(), expr_arena) { + column_map_set(&mut liveset, column_map, live.clone()); + } + + current_expr_livesets.push(liveset); + } + + // Force that column_map is not further mutated from this point on + let column_map = column_map as &_; + + column_map_finalize_bitset(&mut input_genset, column_map); + + current_liveset.extend_constant(column_map.len(), false); + for expr_liveset in &mut current_expr_livesets { + use std::ops::BitOrAssign; + column_map_finalize_bitset(expr_liveset, column_map); + (&mut current_liveset).bitor_assign(expr_liveset as &_); + } + + // Check for every expression in the current WITH_COLUMNS node whether it can be pushed + // down or pruned. + inplace_zip_filtermap( + current_exprs, + &mut current_expr_livesets, + |mut expr, liveset| { + let does_input_assign_column_that_expr_used = + input_genset.intersects_with(&liveset); + + if does_input_assign_column_that_expr_used { + pushable.push(false); + return Some((expr, liveset)); + } + + let column_name = expr.output_name(); + let is_pushable = if let Some(idx) = column_map.get(column_name) { + let does_input_alias_also_expr = input_genset.get(*idx); + let is_alias_live_in_current = current_liveset.get(*idx); + + if does_input_alias_also_expr && !is_alias_live_in_current { + // @NOTE: Pruning of re-assigned columns + // + // We checked if this expression output is also assigned by the input and + // that this assignment is not used in the current WITH_COLUMNS. + // Consequently, we are free to prune the input's assignment to the output. + // + // We immediately prune here to simplify the later code. + // + // @NOTE: Expressions in a `WITH_COLUMNS` cannot alias to the same column. + // Otherwise, this would be faulty and would panic. + let input_expr = input_exprs + .iter_mut() + .find(|input_expr| column_name == input_expr.output_name()) + .expect("No assigning expression for generated column"); + + // @NOTE + // Since we are reassigning a column and we are pushing to the input, we do + // not need to change the schema of the current or input nodes. + std::mem::swap(&mut expr, input_expr); + return None; + } + + // We cannot have multiple assignments to the same column in one WITH_COLUMNS + // and we need to make sure that we are not changing the column value that + // neighbouring expressions are seeing. + + // @NOTE: In this case it might be possible to push this down if all the + // expressions that use the output are also being pushed down. + if !does_input_alias_also_expr && is_alias_live_in_current { + potential_pushable.push(pushable.len()); + pushable.push(false); + return Some((expr, liveset)); + } + + !does_input_alias_also_expr && !is_alias_live_in_current + } else { + true + }; + + pushable.push(is_pushable); + Some((expr, liveset)) + }, + ); + + debug_assert_eq!(pushable.len(), current_exprs.len()); + + // Here we do a last check for expressions to push down. + // This will pushdown the expressions that "has an output column that is mentioned by + // neighbour columns, but all those neighbours were being pushed down". + for candidate in potential_pushable.iter().copied() { + let column_name = current_exprs[candidate].output_name(); + let column_idx = column_map.get(column_name).unwrap(); + + current_liveset.clear(); + current_liveset.extend_constant(column_map.len(), false); + for (i, expr_liveset) in current_expr_livesets.iter().enumerate() { + if pushable.get(i) || i == candidate { + continue; + } + use std::ops::BitOrAssign; + (&mut current_liveset).bitor_assign(expr_liveset as &_); + } + + if !current_liveset.get(*column_idx) { + pushable.set(candidate, true); + } + } + + let pushable_set_bits = pushable.set_bits(); + + // If all columns are pushable, we can merge the input into the current. This should be + // a relatively common case. + if pushable_set_bits == pushable.len() { + // @NOTE: To keep the schema correct, we reverse the order here. As a + // `WITH_COLUMNS` higher up produces later columns. This also allows us not to + // have to deal with schemas. + input_exprs.extend(std::mem::take(current_exprs)); + std::mem::swap(current_exprs, input_exprs); + + // Here, we perform the trick where we switch the inputs. This makes it possible to + // change the essentially remove the `current` node without knowing the parent of + // `current`. Essentially, we move the input node to the current node. + *current_input = *input_input; + *current_options = current_options.merge_options(input_options); + + // Let us just make this node invalid so we can detect when someone tries to + // mention it later. + lp_arena.take(input); + + // Since we merged the current and input nodes and the input node might have + // optimizations with their input, we loop again on this node. + ir_stack.pop(); + ir_stack.push(current); + continue; + } + + // There is nothing to push down. Move on. + if pushable_set_bits == 0 { + continue; + } + + let input_schema_inner = Arc::make_mut(input_schema); + + // @NOTE: We don't have to insert a SimpleProjection or redo the `current_schema` if + // `pushable` contains only 0..N for some N. We use these two variables to keep track + // of this. + let mut has_seen_unpushable = false; + let mut needs_simple_projection = false; + + input_schema_inner.reserve(pushable_set_bits); + input_exprs.reserve(pushable_set_bits); + *current_exprs = std::mem::take(current_exprs) + .into_iter() + .zip(pushable.iter()) + .filter_map(|(expr, do_pushdown)| { + if do_pushdown { + needs_simple_projection = has_seen_unpushable; + + let column = expr.output_name().as_ref(); + // @NOTE: we cannot just use the index here, as there might be renames that sit + // earlier in the schema + let datatype = current_schema.get(column).unwrap(); + input_schema_inner.with_column(column.into(), datatype.clone()); + input_exprs.push(expr); + + None + } else { + has_seen_unpushable = true; + Some(expr) + } + }) + .collect(); + + let options = current_options.merge_options(input_options); + *current_options = options; + *input_options = options; + + // @NOTE: Here we add a simple projection to make sure that the output still + // has the right schema. + if needs_simple_projection { + // @NOTE: This may seem stupid, but this way we prioritize the input columns and then + // the existing columns which is exactly what we want. + let mut new_current_schema = Schema::with_capacity(current_schema.len()); + new_current_schema.merge_from_ref(input_schema.as_ref()); + new_current_schema.merge_from_ref(current_schema.as_ref()); + + debug_assert_eq!(new_current_schema.len(), current_schema.len()); + + let proj_schema = std::mem::replace(current_schema, Arc::new(new_current_schema)); + + let moved_current = lp_arena.add(IR::Invalid); + let projection = IR::SimpleProjection { + input: moved_current, + columns: proj_schema, + }; + let current = lp_arena.replace(current, projection); + lp_arena.replace(moved_current, current); + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs b/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs new file mode 100644 index 000000000000..912c33e9c77b --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs @@ -0,0 +1,132 @@ +use std::collections::BTreeSet; + +use super::*; + +/// Projection in the physical plan is done by selecting an expression per thread. +/// In case of many projections and columns this can be expensive when the expressions are simple +/// column selections. These can be selected on a single thread. The single thread is faster, because +/// the eager selection algorithm hashes the column names, making the projection complexity linear +/// instead of quadratic. +/// +/// It is important that this optimization is ran after projection pushdown. +/// +/// The schema reported after this optimization is also +pub(super) struct SimpleProjectionAndCollapse { + /// Keep track of nodes that are already processed when they + /// can be expensive. Schema materialization can be for instance. + processed: BTreeSet, + eager: bool, +} + +impl SimpleProjectionAndCollapse { + pub(super) fn new(eager: bool) -> Self { + Self { + processed: Default::default(), + eager, + } + } +} + +impl OptimizationRule for SimpleProjectionAndCollapse { + fn optimize_plan( + &mut self, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + node: Node, + ) -> PolarsResult> { + use IR::*; + let lp = lp_arena.get(node); + + match lp { + Select { input, expr, .. } => { + if !matches!(lp_arena.get(*input), ExtContext { .. }) + && !self.processed.contains(&node) + { + // First check if we can apply the optimization before we allocate. + if !expr.iter().all(|e| { + matches!(expr_arena.get(e.node()), AExpr::Column(_)) && !e.has_alias() + }) { + self.processed.insert(node); + return Ok(None); + } + + let exprs = expr + .iter() + .map(|e| e.output_name().clone()) + .collect::>(); + let Some(alp) = IRBuilder::new(*input, expr_arena, lp_arena) + .project_simple(exprs.iter().cloned()) + .ok() + else { + return Ok(None); + }; + let alp = alp.build(); + + Ok(Some(alp)) + } else { + self.processed.insert(node); + Ok(None) + } + }, + SimpleProjection { columns, input } if !self.eager => { + match lp_arena.get(*input) { + // If there are 2 subsequent fast projections, flatten them and only take the last + SimpleProjection { + input: prev_input, .. + } => Ok(Some(SimpleProjection { + input: *prev_input, + columns: columns.clone(), + })), + // Cleanup projections set in projection pushdown just above caches + // they are not needed. + cache_lp @ Cache { .. } if self.processed.contains(&node) => { + let cache_schema = cache_lp.schema(lp_arena); + if cache_schema.len() == columns.len() + && cache_schema.iter_names().zip(columns.iter_names()).all( + |(left_name, right_name)| left_name.as_str() == right_name.as_str(), + ) + { + Ok(Some(cache_lp.clone())) + } else { + Ok(None) + } + }, + // If a projection does nothing, remove it. + other => { + let input_schema = other.schema(lp_arena); + // This will fail fast if lengths are not equal + if *input_schema.as_ref() == *columns { + Ok(Some(other.clone())) + } else { + self.processed.insert(node); + Ok(None) + } + }, + } + }, + // if there are 2 subsequent caches, flatten them and only take the inner + Cache { + input, + cache_hits: outer_cache_hits, + .. + } if !self.eager => { + if let Cache { + input: prev_input, + id, + cache_hits, + } = lp_arena.get(*input) + { + Ok(Some(Cache { + input: *prev_input, + id: *id, + // ensure the counts are updated + cache_hits: cache_hits.saturating_add(*outer_cache_hits), + })) + } else { + Ok(None) + } + }, + _ => Ok(None), + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs new file mode 100644 index 000000000000..6b71643076be --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -0,0 +1,382 @@ +//! Optimization that collapses several a join with several filters into faster join. +//! +//! For example, `join(how='cross').filter(pl.col.l == pl.col.r)` can be collapsed to +//! `join(how='inner', left_on=pl.col.l, right_on=pl.col.r)`. + +use std::sync::Arc; + +use polars_core::schema::*; +#[cfg(feature = "iejoin")] +use polars_ops::frame::{IEJoinOptions, InequalityOperator}; +use polars_ops::frame::{JoinCoalesce, JoinType, MaintainOrderJoin}; +use polars_utils::arena::{Arena, Node}; + +use super::{AExpr, ExprOrigin, IR, JoinOptions, aexpr_to_leaf_names_iter}; +use crate::dsl::{JoinTypeOptionsIR, Operator}; +use crate::plans::optimizer::join_utils::remove_suffix; +use crate::plans::{ExprIR, MintermIter}; + +fn and_expr(left: Node, right: Node, expr_arena: &mut Arena) -> Node { + expr_arena.add(AExpr::BinaryExpr { + left, + op: Operator::And, + right, + }) +} + +pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &mut Arena) { + let mut predicates = Vec::with_capacity(4); + + // Partition to: + // - equality predicates + // - IEjoin supported inequality predicates + // - remaining predicates + #[cfg(feature = "iejoin")] + let mut ie_op = Vec::new(); + let mut remaining_predicates = Vec::new(); + + let mut ir_stack = Vec::with_capacity(16); + ir_stack.push(root); + + while let Some(current) = ir_stack.pop() { + let current_ir = lp_arena.get(current); + current_ir.copy_inputs(&mut ir_stack); + + match current_ir { + IR::Filter { + input: _, + predicate, + } => { + predicates.push((current, predicate.node())); + }, + IR::Join { + input_left, + input_right, + schema, + left_on, + right_on, + options, + } if options.args.how.is_cross() => { + if predicates.is_empty() { + continue; + } + + let suffix = options.args.suffix(); + + debug_assert!(left_on.is_empty()); + debug_assert!(right_on.is_empty()); + + let mut eq_left_on = Vec::new(); + let mut eq_right_on = Vec::new(); + + #[cfg(feature = "iejoin")] + let mut ie_left_on = Vec::new(); + #[cfg(feature = "iejoin")] + let mut ie_right_on = Vec::new(); + + #[cfg(feature = "iejoin")] + { + ie_op.clear(); + } + + remaining_predicates.clear(); + + #[cfg(feature = "iejoin")] + fn to_inequality_operator(op: &Operator) -> Option { + match op { + Operator::Lt => Some(InequalityOperator::Lt), + Operator::LtEq => Some(InequalityOperator::LtEq), + Operator::Gt => Some(InequalityOperator::Gt), + Operator::GtEq => Some(InequalityOperator::GtEq), + _ => None, + } + } + + let left_schema = lp_arena.get(*input_left).schema(lp_arena); + let right_schema = lp_arena.get(*input_right).schema(lp_arena); + + let left_schema = left_schema.as_ref(); + let right_schema = right_schema.as_ref(); + + for (_, predicate_node) in &predicates { + for node in MintermIter::new(*predicate_node, expr_arena) { + let AExpr::BinaryExpr { left, op, right } = expr_arena.get(node) else { + remaining_predicates.push(node); + continue; + }; + + if !op.is_comparison_or_bitwise() { + // @NOTE: This is not a valid predicate, but we should not handle that + // here. + remaining_predicates.push(node); + continue; + } + + let mut left = *left; + let mut op = *op; + let mut right = *right; + + let left_origin = ExprOrigin::get_expr_origin( + left, + expr_arena, + left_schema, + right_schema, + suffix.as_str(), + ) + .unwrap(); + let right_origin = ExprOrigin::get_expr_origin( + right, + expr_arena, + left_schema, + right_schema, + suffix.as_str(), + ) + .unwrap(); + + use ExprOrigin as EO; + + // We can only join if both sides of the binary expression stem from + // different sides of the join. + match (left_origin, right_origin) { + (EO::Both, _) | (_, EO::Both) => { + // If either expression originates from the both sides, we need to + // filter it afterwards. + remaining_predicates.push(node); + continue; + }, + (EO::None, _) | (_, EO::None) => { + // @TODO: This should probably be pushed down + remaining_predicates.push(node); + continue; + }, + (EO::Left, EO::Left) | (EO::Right, EO::Right) => { + // @TODO: This can probably be pushed down in the predicate + // pushdown, but for now just take it as is. + remaining_predicates.push(node); + continue; + }, + (EO::Right, EO::Left) => { + // Swap around the expressions so they match with the left_on and + // right_on. + std::mem::swap(&mut left, &mut right); + op = op.swap_operands(); + }, + (EO::Left, EO::Right) => {}, + } + + if matches!(op, Operator::Eq) { + eq_left_on.push(ExprIR::from_node(left, expr_arena)); + eq_right_on.push(ExprIR::from_node(right, expr_arena)); + } else { + #[cfg(feature = "iejoin")] + if let Some(ie_op_) = to_inequality_operator(&op) { + fn is_numeric( + node: Node, + expr_arena: &Arena, + schema: &Schema, + ) -> bool { + aexpr_to_leaf_names_iter(node, expr_arena).any(|name| { + if let Some(dt) = schema.get(name.as_str()) { + dt.to_physical().is_primitive_numeric() + } else { + false + } + }) + } + + // We fallback to remaining if: + // - we already have an IEjoin or Inner join + // - we already have an Inner join + // - data is not numeric (our iejoin doesn't yet implement that) + if ie_op.len() >= 2 + || !eq_left_on.is_empty() + || !is_numeric(left, expr_arena, left_schema) + { + remaining_predicates.push(node); + } else { + ie_left_on.push(ExprIR::from_node(left, expr_arena)); + ie_right_on.push(ExprIR::from_node(right, expr_arena)); + ie_op.push(ie_op_); + } + } else { + remaining_predicates.push(node); + } + + #[cfg(not(feature = "iejoin"))] + remaining_predicates.push(node); + } + } + } + + let mut can_simplify_join = false; + + if !eq_left_on.is_empty() { + for expr in eq_right_on.iter_mut() { + remove_suffix(expr, expr_arena, right_schema, suffix.as_str()); + } + can_simplify_join = true; + } else { + #[cfg(feature = "iejoin")] + if !ie_op.is_empty() { + for expr in ie_right_on.iter_mut() { + remove_suffix(expr, expr_arena, right_schema, suffix.as_str()); + } + can_simplify_join = true; + } + can_simplify_join |= options.args.how.is_cross(); + } + + if can_simplify_join { + let new_join = insert_fitting_join( + eq_left_on, + eq_right_on, + #[cfg(feature = "iejoin")] + ie_left_on, + #[cfg(feature = "iejoin")] + ie_right_on, + #[cfg(feature = "iejoin")] + &ie_op, + &remaining_predicates, + lp_arena, + expr_arena, + options.as_ref().clone(), + *input_left, + *input_right, + schema.clone(), + ); + + lp_arena.swap(predicates[0].0, new_join); + } + + predicates.clear(); + }, + _ => { + predicates.clear(); + }, + } + } +} + +#[allow(clippy::too_many_arguments)] +fn insert_fitting_join( + eq_left_on: Vec, + eq_right_on: Vec, + #[cfg(feature = "iejoin")] ie_left_on: Vec, + #[cfg(feature = "iejoin")] ie_right_on: Vec, + #[cfg(feature = "iejoin")] ie_op: &[InequalityOperator], + remaining_predicates: &[Node], + lp_arena: &mut Arena, + expr_arena: &mut Arena, + mut options: JoinOptions, + input_left: Node, + input_right: Node, + schema: SchemaRef, +) -> Node { + debug_assert_eq!(eq_left_on.len(), eq_right_on.len()); + #[cfg(feature = "iejoin")] + { + debug_assert_eq!(ie_op.len(), ie_left_on.len()); + debug_assert_eq!(ie_left_on.len(), ie_right_on.len()); + debug_assert!(ie_op.len() <= 2); + } + debug_assert!(matches!(options.args.how, JoinType::Cross)); + + let remaining_predicates = remaining_predicates + .iter() + .copied() + .reduce(|left, right| and_expr(left, right, expr_arena)); + + let (left_on, right_on, remaining_predicates) = match () { + _ if !eq_left_on.is_empty() => { + options.args.how = JoinType::Inner; + // We need to make sure not to delete any columns + options.args.coalesce = JoinCoalesce::KeepColumns; + + #[cfg(feature = "iejoin")] + let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold( + remaining_predicates, + |acc, ((left, op), right)| { + let e = expr_arena.add(AExpr::BinaryExpr { + left: left.node(), + op: (*op).into(), + right: right.node(), + }); + Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena))) + }, + ); + + (eq_left_on, eq_right_on, remaining_predicates) + }, + #[cfg(feature = "iejoin")] + _ if !ie_op.is_empty() => { + // We can only IE join up to 2 operators + + let operator1 = ie_op[0]; + let operator2 = ie_op.get(1).copied(); + + // Do an IEjoin. + options.args.how = JoinType::IEJoin; + options.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions { + operator1, + operator2, + })); + // We need to make sure not to delete any columns + options.args.coalesce = JoinCoalesce::KeepColumns; + + (ie_left_on, ie_right_on, remaining_predicates) + }, + // If anything just fall back to a cross join. + _ => { + options.args.how = JoinType::Cross; + // We need to make sure not to delete any columns + options.args.coalesce = JoinCoalesce::KeepColumns; + + #[cfg(feature = "iejoin")] + let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold( + remaining_predicates, + |acc, ((left, op), right)| { + let e = expr_arena.add(AExpr::BinaryExpr { + left: left.node(), + op: (*op).into(), + right: right.node(), + }); + Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena))) + }, + ); + + let mut remaining_predicates = remaining_predicates; + if let Some(pred) = remaining_predicates + .take_if(|_| matches!(options.args.maintain_order, MaintainOrderJoin::None)) + { + options.options = Some(JoinTypeOptionsIR::Cross { + predicate: ExprIR::from_node(pred, expr_arena), + }) + } + + (Vec::new(), Vec::new(), remaining_predicates) + }, + }; + + // Note: We expect key type upcasting / expression optimizations have already been done during + // DSL->IR conversion. + + let join_ir = IR::Join { + input_left, + input_right, + schema, + left_on, + right_on, + options: Arc::new(options), + }; + + let join_node = lp_arena.add(join_ir); + + if let Some(predicate) = remaining_predicates { + lp_arena.add(IR::Filter { + input: join_node, + predicate: ExprIR::from_node(predicate, &*expr_arena), + }) + } else { + join_node + } +} diff --git a/crates/polars-plan/src/plans/optimizer/collect_members.rs b/crates/polars-plan/src/plans/optimizer/collect_members.rs new file mode 100644 index 000000000000..1c3f4ac2c127 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/collect_members.rs @@ -0,0 +1,95 @@ +use std::hash::BuildHasher; + +use super::*; + +// Utility to cheaply check if we have duplicate sources. +// This may have false positives. +#[cfg(feature = "cse")] +#[derive(Default)] +struct UniqueScans { + ids: PlHashSet, + count: usize, +} + +#[cfg(feature = "cse")] +impl UniqueScans { + fn insert(&mut self, node: Node, lp_arena: &Arena, expr_arena: &Arena) { + let alp_node = IRNode::new(node); + self.ids.insert( + self.ids + .hasher() + .hash_one(alp_node.hashable_and_cmp(lp_arena, expr_arena)), + ); + self.count += 1; + } +} + +pub(super) struct MemberCollector { + pub(crate) has_joins_or_unions: bool, + pub(crate) has_sink_multiple: bool, + pub(crate) has_cache: bool, + pub(crate) has_ext_context: bool, + pub(crate) has_filter_with_join_input: bool, + pub(crate) has_distinct: bool, + pub(crate) has_sort: bool, + pub(crate) has_group_by: bool, + #[cfg(feature = "cse")] + scans: UniqueScans, +} + +impl MemberCollector { + pub(super) fn new() -> Self { + Self { + has_joins_or_unions: false, + has_sink_multiple: false, + has_cache: false, + has_ext_context: false, + has_filter_with_join_input: false, + has_distinct: false, + has_sort: false, + has_group_by: false, + #[cfg(feature = "cse")] + scans: UniqueScans::default(), + } + } + pub(super) fn collect(&mut self, root: Node, lp_arena: &Arena, _expr_arena: &Arena) { + use IR::*; + for (_node, alp) in lp_arena.iter(root) { + match alp { + SinkMultiple { .. } => self.has_sink_multiple = true, + Join { .. } | Union { .. } => self.has_joins_or_unions = true, + Filter { input, .. } => { + self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how.is_cross()) + }, + Distinct { .. } => { + self.has_distinct = true; + }, + GroupBy { .. } => { + self.has_group_by = true; + }, + Sort { .. } => { + self.has_sort = true; + }, + Cache { .. } => self.has_cache = true, + ExtContext { .. } => self.has_ext_context = true, + #[cfg(feature = "cse")] + Scan { .. } => { + self.scans.insert(_node, lp_arena, _expr_arena); + }, + HConcat { .. } => { + self.has_joins_or_unions = true; + }, + #[cfg(feature = "cse")] + DataFrameScan { .. } => { + self.scans.insert(_node, lp_arena, _expr_arena); + }, + _ => {}, + } + } + } + + #[cfg(feature = "cse")] + pub(super) fn has_duplicate_scans(&self) -> bool { + self.scans.count != self.scans.ids.len() + } +} diff --git a/crates/polars-plan/src/plans/optimizer/count_star.rs b/crates/polars-plan/src/plans/optimizer/count_star.rs new file mode 100644 index 000000000000..5fb51e798bf6 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/count_star.rs @@ -0,0 +1,228 @@ +use std::path::PathBuf; + +use polars_io::cloud::CloudOptions; +use polars_utils::mmap::MemSlice; + +use super::*; + +pub(super) struct CountStar; + +impl CountStar { + pub(super) fn new() -> Self { + Self + } +} + +impl OptimizationRule for CountStar { + // Replace select count(*) from datasource with specialized map function. + fn optimize_plan( + &mut self, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + mut node: Node, + ) -> PolarsResult> { + // New-streaming always puts a sink on top. + if let IR::Sink { input, .. } = lp_arena.get(node) { + node = *input; + } + + // Note: This will be a useful flag later for testing parallel CountLines on CSV. + let use_fast_file_count = match std::env::var("POLARS_FAST_FILE_COUNT_DISPATCH").as_deref() + { + Ok("1") => Some(true), + Ok("0") => Some(false), + Ok(v) => panic!( + "POLARS_FAST_FILE_COUNT_DISPATCH must be one of ('0', '1'), got: {}", + v + ), + Err(_) => None, + }; + + Ok(visit_logical_plan_for_scan_paths( + node, + lp_arena, + expr_arena, + false, + use_fast_file_count, + ) + .map(|count_star_expr| { + // MapFunction needs a leaf node, hence we create a dummy placeholder node + let placeholder = IR::DataFrameScan { + df: Arc::new(Default::default()), + schema: Arc::new(Default::default()), + output_schema: None, + }; + let placeholder_node = lp_arena.add(placeholder); + + let alp = IR::MapFunction { + input: placeholder_node, + function: FunctionIR::FastCount { + sources: count_star_expr.sources, + scan_type: count_star_expr.scan_type, + cloud_options: count_star_expr.cloud_options, + alias: count_star_expr.alias, + }, + }; + + lp_arena.replace(count_star_expr.node, alp.clone()); + alp + })) + } +} + +struct CountStarExpr { + // Top node of the projection to replace + node: Node, + // Paths to the input files + sources: ScanSources, + cloud_options: Option, + // File Type + scan_type: Box, + // Column Alias + alias: Option, +} + +// Visit the logical plan and return CountStarExpr with the expr information gathered +// Return None if query is not a simple COUNT(*) FROM SOURCE +fn visit_logical_plan_for_scan_paths( + node: Node, + lp_arena: &Arena, + expr_arena: &Arena, + inside_union: bool, // Inside union's we do not check for COUNT(*) expression + use_fast_file_count: Option, // Overrides if Some +) -> Option { + match lp_arena.get(node) { + IR::Union { inputs, .. } => { + enum MutableSources { + Paths(Vec), + Buffers(Vec), + } + + let mut scan_type: Option> = None; + let mut cloud_options = None; + let mut sources = None; + + for input in inputs { + match visit_logical_plan_for_scan_paths( + *input, + lp_arena, + expr_arena, + true, + use_fast_file_count, + ) { + Some(expr) => { + match (expr.sources, &mut sources) { + ( + ScanSources::Paths(paths), + Some(MutableSources::Paths(mutable_paths)), + ) => mutable_paths.extend_from_slice(&paths[..]), + (ScanSources::Paths(paths), None) => { + sources = Some(MutableSources::Paths(paths.to_vec())) + }, + ( + ScanSources::Buffers(buffers), + Some(MutableSources::Buffers(mutable_buffers)), + ) => mutable_buffers.extend_from_slice(&buffers[..]), + (ScanSources::Buffers(buffers), None) => { + sources = Some(MutableSources::Buffers(buffers.to_vec())) + }, + _ => return None, + } + + // Take the first Some(_) cloud option + // TODO: Should check the cloud types are the same. + cloud_options = cloud_options.or(expr.cloud_options); + + match &scan_type { + None => scan_type = Some(expr.scan_type), + Some(scan_type) => { + // All scans must be of the same type (e.g. csv / parquet) + if std::mem::discriminant(&**scan_type) + != std::mem::discriminant(&*expr.scan_type) + { + return None; + } + }, + }; + }, + None => return None, + } + } + Some(CountStarExpr { + sources: match sources { + Some(MutableSources::Paths(paths)) => ScanSources::Paths(paths.into()), + Some(MutableSources::Buffers(buffers)) => ScanSources::Buffers(buffers.into()), + None => ScanSources::default(), + }, + scan_type: scan_type.unwrap(), + cloud_options, + node, + alias: None, + }) + }, + IR::Scan { + scan_type, + sources, + unified_scan_args, + .. + } => { + // New-streaming is generally on par for all except CSV (see https://github.com/pola-rs/polars/pull/22363). + // In the future we can potentially remove the dedicated count codepaths. + + let use_fast_file_count = use_fast_file_count.unwrap_or(match scan_type.as_ref() { + #[cfg(feature = "csv")] + FileScan::Csv { .. } => true, + _ => false, + }); + + if use_fast_file_count { + Some(CountStarExpr { + sources: sources.clone(), + scan_type: scan_type.clone(), + cloud_options: unified_scan_args.cloud_options.clone(), + node, + alias: None, + }) + } else { + None + } + }, + // A union can insert a simple projection to ensure all projections align. + // We can ignore that if we are inside a count star. + IR::SimpleProjection { input, .. } if inside_union => visit_logical_plan_for_scan_paths( + *input, + lp_arena, + expr_arena, + false, + use_fast_file_count, + ), + IR::Select { input, expr, .. } => { + if expr.len() == 1 { + let (valid, alias) = is_valid_count_expr(&expr[0], expr_arena); + if valid || inside_union { + return visit_logical_plan_for_scan_paths( + *input, + lp_arena, + expr_arena, + false, + use_fast_file_count, + ) + .map(|mut expr| { + expr.alias = alias; + expr.node = node; + expr + }); + } + } + None + }, + _ => None, + } +} + +fn is_valid_count_expr(e: &ExprIR, expr_arena: &Arena) -> (bool, Option) { + match expr_arena.get(e.node()) { + AExpr::Len => (true, e.get_alias().cloned()), + _ => (false, None), + } +} diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs new file mode 100644 index 000000000000..ba71a03d42bc --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs @@ -0,0 +1,1041 @@ +use std::hash::BuildHasher; + +use hashbrown::hash_map::RawEntryMut; +use polars_core::CHEAP_SERIES_HASH_LIMIT; +use polars_utils::aliases::PlFixedStateQuality; +use polars_utils::format_pl_smallstr; +use polars_utils::hashing::_boost_hash_combine; +use polars_utils::vec::CapacityByFactor; + +use super::*; +use crate::constants::CSE_REPLACED; +use crate::prelude::visitor::AexprNode; + +#[derive(Debug, Clone)] +struct ProjectionExprs { + expr: Vec, + /// offset from the back + /// `expr[expr.len() - common_sub_offset..]` + /// are the common sub expressions + common_sub_offset: usize, +} + +impl ProjectionExprs { + fn default_exprs(&self) -> &[ExprIR] { + &self.expr[..self.expr.len() - self.common_sub_offset] + } + + fn cse_exprs(&self) -> &[ExprIR] { + &self.expr[self.expr.len() - self.common_sub_offset..] + } + + fn new_with_cse(expr: Vec, common_sub_offset: usize) -> Self { + Self { + expr, + common_sub_offset, + } + } +} + +/// Identifier that shows the sub-expression path. +/// Must implement hash and equality and ideally +/// have little collisions +/// We will do a full expression comparison to check if the +/// expressions with equal identifiers are truly equal +#[derive(Clone, Debug)] +pub(super) struct Identifier { + inner: Option, + last_node: Option, + hb: PlFixedStateQuality, +} + +impl Identifier { + fn new() -> Self { + Self { + inner: None, + last_node: None, + hb: PlFixedStateQuality::with_seed(0), + } + } + + fn hash(&self) -> u64 { + self.inner.unwrap_or(0) + } + + fn ae_node(&self) -> AexprNode { + self.last_node.unwrap() + } + + fn is_equal(&self, other: &Self, arena: &Arena) -> bool { + self.inner == other.inner + && self.last_node.map(|v| v.hashable_and_cmp(arena)) + == other.last_node.map(|v| v.hashable_and_cmp(arena)) + } + + fn is_valid(&self) -> bool { + self.inner.is_some() + } + + fn materialize(&self) -> PlSmallStr { + format_pl_smallstr!("{}{:#x}", CSE_REPLACED, self.materialized_hash()) + } + + fn materialized_hash(&self) -> u64 { + self.inner.unwrap_or(0) + } + + fn combine(&mut self, other: &Identifier) { + let inner = match (self.inner, other.inner) { + (Some(l), Some(r)) => _boost_hash_combine(l, r), + (None, Some(r)) => r, + (Some(l), None) => l, + _ => return, + }; + self.inner = Some(inner); + } + + fn add_ae_node(&self, ae: &AexprNode, arena: &Arena) -> Self { + let hashed = self.hb.hash_one(ae.to_aexpr(arena)); + let inner = Some( + self.inner + .map_or(hashed, |l| _boost_hash_combine(l, hashed)), + ); + Self { + inner, + last_node: Some(*ae), + hb: self.hb, + } + } +} + +#[derive(Default)] +struct IdentifierMap { + inner: PlHashMap, +} + +impl IdentifierMap { + fn get(&self, id: &Identifier, arena: &Arena) -> Option<&V> { + self.inner + .raw_entry() + .from_hash(id.hash(), |k| k.is_equal(id, arena)) + .map(|(_k, v)| v) + } + + fn entry<'a, F: FnOnce() -> V>( + &'a mut self, + id: Identifier, + v: F, + arena: &Arena, + ) -> &'a mut V { + let h = id.hash(); + match self + .inner + .raw_entry_mut() + .from_hash(h, |k| k.is_equal(&id, arena)) + { + RawEntryMut::Occupied(entry) => entry.into_mut(), + RawEntryMut::Vacant(entry) => { + let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash()); + v + }, + } + } + fn insert(&mut self, id: Identifier, v: V, arena: &Arena) { + self.entry(id, || v, arena); + } + + fn iter(&self) -> impl Iterator { + self.inner.iter() + } +} + +/// Merges identical expressions into identical IDs. +/// +/// Does no analysis whether this leads to legal substitutions. +#[derive(Default)] +pub struct NaiveExprMerger { + node_to_uniq_id: PlHashMap, + uniq_id_to_node: Vec, + identifier_to_uniq_id: IdentifierMap, + arg_stack: Vec>, +} + +impl NaiveExprMerger { + pub fn add_expr(&mut self, node: Node, arena: &Arena) { + let node = AexprNode::new(node); + node.visit(self, arena).unwrap(); + } + + pub fn get_uniq_id(&self, node: Node) -> Option { + self.node_to_uniq_id.get(&node).copied() + } + + pub fn get_node(&self, uniq_id: u32) -> Option { + self.uniq_id_to_node.get(uniq_id as usize).copied() + } +} + +impl Visitor for NaiveExprMerger { + type Node = AexprNode; + type Arena = Arena; + + fn pre_visit( + &mut self, + _node: &Self::Node, + _arena: &Self::Arena, + ) -> PolarsResult { + self.arg_stack.push(None); + Ok(VisitRecursion::Continue) + } + + fn post_visit( + &mut self, + node: &Self::Node, + arena: &Self::Arena, + ) -> PolarsResult { + let mut identifier = Identifier::new(); + while let Some(Some(arg)) = self.arg_stack.pop() { + identifier.combine(&arg); + } + identifier = identifier.add_ae_node(node, arena); + let uniq_id = *self.identifier_to_uniq_id.entry( + identifier, + || { + let uniq_id = self.uniq_id_to_node.len() as u32; + self.uniq_id_to_node.push(node.node()); + uniq_id + }, + arena, + ); + self.node_to_uniq_id.insert(node.node(), uniq_id); + Ok(VisitRecursion::Continue) + } +} + +/// Identifier maps to Expr Node and count. +type SubExprCount = IdentifierMap<(Node, u32)>; +/// (post_visit_idx, identifier); +type IdentifierArray = Vec<(usize, Identifier)>; + +#[derive(Debug)] +enum VisitRecord { + /// entered a new expression + Entered(usize), + /// Every visited sub-expression pushes their identifier to the stack. + // The `bool` indicates if this expression is valid. + // This can be `AND` accumulated by the lineage of the expression to determine + // of the whole expression can be added. + // For instance a in a group_by we only want to use elementwise operation in cse: + // - `(col("a") * 2).sum(), (col("a") * 2)` -> we want to do `col("a") * 2` on a `with_columns` + // - `col("a").sum() * col("a").sum()` -> we don't want `sum` to run on `with_columns` + // as that doesn't have groups context. If we encounter a `sum` it should be flagged as `false` + // + // This should have the following stack + // id valid + // col(a) true + // sum false + // col(a) true + // sum false + // binary true + // -------------- accumulated + // false + SubExprId(Identifier, bool), +} + +fn skip_pre_visit(ae: &AExpr, is_groupby: bool) -> bool { + match ae { + AExpr::Window { .. } => true, + #[cfg(feature = "dtype-struct")] + AExpr::Ternary { .. } => is_groupby, + _ => false, + } +} + +/// Goes through an expression and generates a identifier +/// +/// The visitor uses a `visit_stack` to track traversal order. +/// +/// # Entering a node +/// When `pre-visit` is called we enter a new (sub)-expression and +/// we add `Entered` to the stack. +/// # Leaving a node +/// On `post-visit` when we leave the node and we pop all `SubExprIds` nodes. +/// Those are considered sub-expression of the leaving node +/// +/// We also record an `id_array` that followed the pre-visit order. This +/// is used to cache the `Identifiers`. +// +// # Example (this is not a docstring as clippy complains about spacing) +// Say we have the expression: `(col("f00").min() * col("bar")).sum()` +// with the following call tree: +// +// sum +// +// | +// +// binary: * +// +// | | +// +// col(bar) min +// +// | +// +// col(f00) +// +// # call order +// function-called stack stack-after(pop until E, push I) # ID +// pre-visit: sum E - +// pre-visit: binary: * EE - +// pre-visit: col(bar) EEE - +// post-visit: col(bar) EEE EEI id: col(bar) +// pre-visit: min EEIE - +// pre-visit: col(f00) EEIEE - +// post-visit: col(f00) EEIEE EEIEI id: col(f00) +// post-visit: min EEIEI EEII id: min!col(f00) +// post-visit: binary: * EEII EI id: binary: *!min!col(f00)!col(bar) +// post-visit: sum EI I id: sum!binary: *!min!col(f00)!col(bar) +struct ExprIdentifierVisitor<'a> { + se_count: &'a mut SubExprCount, + /// Materialized `CSE` materialized (name) hashes can collide. So we validate that all CSE counts + /// match name hash counts. + name_validation: &'a mut PlHashMap, + identifier_array: &'a mut IdentifierArray, + // Index in pre-visit traversal order. + pre_visit_idx: usize, + post_visit_idx: usize, + visit_stack: &'a mut Vec, + /// Offset in the identifier array + /// this allows us to use a single `vec` on multiple expressions + id_array_offset: usize, + // Whether the expression replaced a subexpression. + has_sub_expr: bool, + // During aggregation we only identify element-wise operations + is_group_by: bool, +} + +impl ExprIdentifierVisitor<'_> { + fn new<'a>( + se_count: &'a mut SubExprCount, + identifier_array: &'a mut IdentifierArray, + visit_stack: &'a mut Vec, + is_group_by: bool, + name_validation: &'a mut PlHashMap, + ) -> ExprIdentifierVisitor<'a> { + let id_array_offset = identifier_array.len(); + ExprIdentifierVisitor { + se_count, + name_validation, + identifier_array, + pre_visit_idx: 0, + post_visit_idx: 0, + visit_stack, + id_array_offset, + has_sub_expr: false, + is_group_by, + } + } + + /// pop all visit-records until an `Entered` is found. We accumulate a `SubExprId`s + /// to `id`. Finally we return the expression `idx` and `Identifier`. + /// This works due to the stack. + /// If we traverse another expression in the mean time, it will get popped of the stack first + /// so the returned identifier belongs to a single sub-expression + fn pop_until_entered(&mut self) -> (usize, Identifier, bool) { + let mut id = Identifier::new(); + let mut is_valid_accumulated = true; + + while let Some(item) = self.visit_stack.pop() { + match item { + VisitRecord::Entered(idx) => return (idx, id, is_valid_accumulated), + VisitRecord::SubExprId(s, valid) => { + id.combine(&s); + is_valid_accumulated &= valid + }, + } + } + unreachable!() + } + + /// return `None` -> node is accepted + /// return `Some(_)` node is not accepted and apply the given recursion operation + /// `Some(_, true)` don't accept this node, but can be a member of a cse. + /// `Some(_, false)` don't accept this node, and don't allow as a member of a cse. + fn accept_node_post_visit(&self, ae: &AExpr) -> Accepted { + match ae { + // window expressions should `evaluate_on_groups`, not `evaluate` + // so we shouldn't cache the children as they are evaluated incorrectly + AExpr::Window { .. } => REFUSE_SKIP, + // Don't allow this for now, as we can get `null().cast()` in ternary expressions. + // TODO! Add a typed null + AExpr::Literal(LiteralValue::Scalar(sc)) if sc.is_null() => REFUSE_NO_MEMBER, + AExpr::Literal(s) => { + match s { + LiteralValue::Series(s) => { + let dtype = s.dtype(); + + // Object and nested types are harder to hash and compare. + let allow = !(dtype.is_nested() | dtype.is_object()); + + if s.len() < CHEAP_SERIES_HASH_LIMIT && allow { + REFUSE_ALLOW_MEMBER + } else { + REFUSE_NO_MEMBER + } + }, + _ => REFUSE_ALLOW_MEMBER, + } + }, + AExpr::Column(_) | AExpr::Alias(_, _) => REFUSE_ALLOW_MEMBER, + AExpr::Len => { + if self.is_group_by { + REFUSE_NO_MEMBER + } else { + REFUSE_ALLOW_MEMBER + } + }, + #[cfg(feature = "random")] + AExpr::Function { + function: FunctionExpr::Random { .. }, + .. + } => REFUSE_NO_MEMBER, + #[cfg(feature = "rolling_window")] + AExpr::Function { + function: FunctionExpr::RollingExpr { .. }, + .. + } => REFUSE_NO_MEMBER, + AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER, + _ => { + // During aggregation we only store elementwise operation in the state + // other operations we cannot add to the state as they have the output size of the + // groups, not the original dataframe + if self.is_group_by { + if !ae.is_elementwise_top_level() { + return REFUSE_NO_MEMBER; + } + match ae { + AExpr::AnonymousFunction { .. } => REFUSE_NO_MEMBER, + AExpr::Cast { .. } => REFUSE_ALLOW_MEMBER, + _ => ACCEPT, + } + } else { + ACCEPT + } + }, + } + } +} + +impl Visitor for ExprIdentifierVisitor<'_> { + type Node = AexprNode; + type Arena = Arena; + + fn pre_visit( + &mut self, + node: &Self::Node, + arena: &Self::Arena, + ) -> PolarsResult { + if skip_pre_visit(node.to_aexpr(arena), self.is_group_by) { + // Still add to the stack so that a parent becomes invalidated. + self.visit_stack + .push(VisitRecord::SubExprId(Identifier::new(), false)); + return Ok(VisitRecursion::Skip); + } + + self.visit_stack + .push(VisitRecord::Entered(self.pre_visit_idx)); + self.pre_visit_idx += 1; + + // implement default placeholders + self.identifier_array + .push((self.id_array_offset, Identifier::new())); + + Ok(VisitRecursion::Continue) + } + + fn post_visit( + &mut self, + node: &Self::Node, + arena: &Self::Arena, + ) -> PolarsResult { + let ae = node.to_aexpr(arena); + self.post_visit_idx += 1; + + let (pre_visit_idx, sub_expr_id, is_valid_accumulated) = self.pop_until_entered(); + // Create the Id of this node. + let id: Identifier = sub_expr_id.add_ae_node(node, arena); + + if !is_valid_accumulated { + self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx; + self.visit_stack.push(VisitRecord::SubExprId(id, false)); + return Ok(VisitRecursion::Continue); + } + + // If we don't store this node + // we only push the visit_stack, so the parents know the trail. + if let Some((recurse, local_is_valid)) = self.accept_node_post_visit(ae) { + self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx; + + self.visit_stack + .push(VisitRecord::SubExprId(id, local_is_valid)); + return Ok(recurse); + } + + // Store the created id. + self.identifier_array[pre_visit_idx + self.id_array_offset] = + (self.post_visit_idx, id.clone()); + + // We popped until entered, push this Id on the stack so the trail + // is available for the parent expression. + self.visit_stack + .push(VisitRecord::SubExprId(id.clone(), true)); + + let mat_h = id.materialized_hash(); + let (_, se_count) = self.se_count.entry(id, || (node.node(), 0), arena); + + *se_count += 1; + *self.name_validation.entry(mat_h).or_insert(0) += 1; + self.has_sub_expr |= *se_count > 1; + + Ok(VisitRecursion::Continue) + } +} + +struct CommonSubExprRewriter<'a> { + sub_expr_map: &'a SubExprCount, + identifier_array: &'a IdentifierArray, + /// keep track of the replaced identifiers. + replaced_identifiers: &'a mut IdentifierMap<()>, + + max_post_visit_idx: usize, + /// index in traversal order in which `identifier_array` + /// was written. This is the index in `identifier_array`. + visited_idx: usize, + /// Offset in the identifier array. + /// This allows us to use a single `vec` on multiple expressions + id_array_offset: usize, + /// Indicates if this expression is rewritten. + rewritten: bool, + is_group_by: bool, +} + +impl<'a> CommonSubExprRewriter<'a> { + fn new( + sub_expr_map: &'a SubExprCount, + identifier_array: &'a IdentifierArray, + replaced_identifiers: &'a mut IdentifierMap<()>, + id_array_offset: usize, + is_group_by: bool, + ) -> Self { + Self { + sub_expr_map, + identifier_array, + replaced_identifiers, + max_post_visit_idx: 0, + visited_idx: 0, + id_array_offset, + rewritten: false, + is_group_by, + } + } +} + +// # Example +// Expression tree with [pre-visit,post-visit] indices +// counted from 1 +// [1,8] binary: + +// +// | | +// +// [2,2] sum [4,7] sum +// +// | | +// +// [3,1] col(foo) [5,6] binary: * +// +// | | +// +// [6,3] col(bar) [7,5] sum +// +// | +// +// [8,4] col(foo) +// +// in this tree `col(foo).sum()` should be post-visited/mutated +// so if we are at `[2,2]` +// +// call stack +// pre-visit [1,8] binary -> no_mutate_and_continue -> visits children +// pre-visit [2,2] sum -> mutate_and_stop -> does not visit children +// post-visit [2,2] sum -> skip index to [4,7] (because we didn't visit children) +// pre-visit [4,7] sum -> no_mutate_and_continue -> visits children +// pre-visit [5,6] binary -> no_mutate_and_continue -> visits children +// pre-visit [6,3] col -> stop_recursion -> does not mutate +// pre-visit [7,5] sum -> mutate_and_stop -> does not visit children +// post-visit [7,5] -> skip index to end +impl RewritingVisitor for CommonSubExprRewriter<'_> { + type Node = AexprNode; + type Arena = Arena; + + fn pre_visit( + &mut self, + ae_node: &Self::Node, + arena: &mut Self::Arena, + ) -> PolarsResult { + let ae = ae_node.to_aexpr(arena); + if self.visited_idx + self.id_array_offset >= self.identifier_array.len() + || self.max_post_visit_idx + > self.identifier_array[self.visited_idx + self.id_array_offset].0 + || skip_pre_visit(ae, self.is_group_by) + { + return Ok(RewriteRecursion::Stop); + } + + let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1; + + // Id placeholder not overwritten, so we can skip this sub-expression. + if !id.is_valid() { + self.visited_idx += 1; + let recurse = if ae_node.is_leaf(arena) { + RewriteRecursion::Stop + } else { + // continue visit its children to see + // if there are cse + RewriteRecursion::NoMutateAndContinue + }; + return Ok(recurse); + } + + // Because some expressions don't have hash / equality guarantee (e.g. floats) + // we can get none here. This must be changed later. + let Some((_, count)) = self.sub_expr_map.get(id, arena) else { + self.visited_idx += 1; + return Ok(RewriteRecursion::NoMutateAndContinue); + }; + if *count > 1 { + self.replaced_identifiers.insert(id.clone(), (), arena); + // rewrite this sub-expression, don't visit its children + Ok(RewriteRecursion::MutateAndStop) + } else { + // This is a unique expression + // visit its children to see if they are cse + self.visited_idx += 1; + Ok(RewriteRecursion::NoMutateAndContinue) + } + } + + fn mutate( + &mut self, + mut node: Self::Node, + arena: &mut Self::Arena, + ) -> PolarsResult { + let (post_visit_count, id) = + &self.identifier_array[self.visited_idx + self.id_array_offset]; + self.visited_idx += 1; + + // TODO!: check if we ever hit this branch + if *post_visit_count < self.max_post_visit_idx { + return Ok(node); + } + + self.max_post_visit_idx = *post_visit_count; + // DFS, so every post_visit that is smaller than `post_visit_count` + // is a subexpression of this node and we can skip that + // + // `self.visited_idx` will influence recursion strategy in `pre_visit` + // see call-stack comment above + while self.visited_idx < self.identifier_array.len() - self.id_array_offset + && *post_visit_count > self.identifier_array[self.visited_idx + self.id_array_offset].0 + { + self.visited_idx += 1; + } + // If this is not true, the traversal order in the visitor was different from the rewriter. + debug_assert_eq!( + node.hashable_and_cmp(arena), + id.ae_node().hashable_and_cmp(arena) + ); + + let name = id.materialize(); + node.assign(AExpr::col(name), arena); + self.rewritten = true; + + Ok(node) + } +} + +pub(crate) struct CommonSubExprOptimizer { + // amortize allocations + // these are cleared per lp node + se_count: SubExprCount, + id_array: IdentifierArray, + id_array_offsets: Vec, + replaced_identifiers: IdentifierMap<()>, + // these are cleared per expr node + visit_stack: Vec, + name_validation: PlHashMap, +} + +impl CommonSubExprOptimizer { + pub(crate) fn new() -> Self { + Self { + se_count: Default::default(), + id_array: Default::default(), + visit_stack: Default::default(), + id_array_offsets: Default::default(), + replaced_identifiers: Default::default(), + name_validation: Default::default(), + } + } + + fn visit_expression( + &mut self, + ae_node: AexprNode, + is_group_by: bool, + expr_arena: &mut Arena, + ) -> PolarsResult<(usize, bool)> { + let mut visitor = ExprIdentifierVisitor::new( + &mut self.se_count, + &mut self.id_array, + &mut self.visit_stack, + is_group_by, + &mut self.name_validation, + ); + ae_node.visit(&mut visitor, expr_arena).map(|_| ())?; + Ok((visitor.id_array_offset, visitor.has_sub_expr)) + } + + /// Mutate the expression. + /// Returns a new expression and a `bool` indicating if it was rewritten or not. + fn mutate_expression( + &mut self, + ae_node: AexprNode, + id_array_offset: usize, + is_group_by: bool, + expr_arena: &mut Arena, + ) -> PolarsResult<(AexprNode, bool)> { + let mut rewriter = CommonSubExprRewriter::new( + &self.se_count, + &self.id_array, + &mut self.replaced_identifiers, + id_array_offset, + is_group_by, + ); + ae_node + .rewrite(&mut rewriter, expr_arena) + .map(|out| (out, rewriter.rewritten)) + } + + fn find_cse( + &mut self, + expr: &[ExprIR], + expr_arena: &mut Arena, + id_array_offsets: &mut Vec, + is_group_by: bool, + schema: &Schema, + ) -> PolarsResult> { + let mut has_sub_expr = false; + + // First get all cse's. + for e in expr { + // The visitor can return early thus depleted its stack + // on a previous iteration. + self.visit_stack.clear(); + + // Visit expressions and collect sub-expression counts. + let ae_node = AexprNode::new(e.node()); + let (id_array_offset, this_expr_has_se) = + self.visit_expression(ae_node, is_group_by, expr_arena)?; + id_array_offsets.push(id_array_offset as u32); + has_sub_expr |= this_expr_has_se; + } + + // Ensure that the `materialized hashes` count matches that of the CSE count. + // It can happen that CSE collide and in that case we fallback and skip CSE. + for (id, (_, count)) in self.se_count.iter() { + let mat_h = id.materialized_hash(); + let valid = if let Some(name_count) = self.name_validation.get(&mat_h) { + *name_count == *count + } else { + false + }; + + if !valid { + if verbose() { + eprintln!( + "materialized names collided in common subexpression elimination.\n backtrace and run without CSE" + ) + } + return Ok(None); + } + } + + if has_sub_expr { + let mut new_expr = Vec::with_capacity_by_factor(expr.len(), 1.3); + + // Then rewrite the expressions that have a cse count > 1. + for (e, offset) in expr.iter().zip(id_array_offsets.iter()) { + let ae_node = AexprNode::new(e.node()); + + let (out, rewritten) = + self.mutate_expression(ae_node, *offset as usize, is_group_by, expr_arena)?; + + let out_node = out.node(); + let mut out_e = e.clone(); + let new_node = if !rewritten { + out_e + } else { + out_e.set_node(out_node); + + // Ensure the function ExprIR's have the proper names. + // This is needed for structs to get the proper field + let mut scratch = vec![]; + let mut stack = vec![(e.node(), out_node)]; + while let Some((original, new)) = stack.pop() { + // Don't follow identical nodes. + if original == new { + continue; + } + scratch.clear(); + let aes = expr_arena.get_many_mut([original, new]); + + // Only follow paths that are the same. + if std::mem::discriminant(aes[0]) != std::mem::discriminant(aes[1]) { + continue; + } + + aes[0].inputs_rev(&mut scratch); + let offset = scratch.len(); + aes[1].inputs_rev(&mut scratch); + + // If they have a different number of inputs, we don't follow the nodes. + if scratch.len() != offset * 2 { + continue; + } + + for i in 0..scratch.len() / 2 { + stack.push((scratch[i], scratch[i + offset])); + } + + match expr_arena.get_many_mut([original, new]) { + [ + AExpr::Function { + input: input_original, + .. + }, + AExpr::Function { + input: input_new, .. + }, + ] => { + for (new, original) in input_new.iter_mut().zip(input_original) { + new.set_alias(original.output_name().clone()); + } + }, + [ + AExpr::AnonymousFunction { + input: input_original, + .. + }, + AExpr::AnonymousFunction { + input: input_new, .. + }, + ] => { + for (new, original) in input_new.iter_mut().zip(input_original) { + new.set_alias(original.output_name().clone()); + } + }, + _ => {}, + } + } + + // If we don't end with an alias we add an alias. Because the normal left-hand + // rule we apply for determining the name will not work we now refer to + // intermediate temporary names starting with the `CSE_REPLACED` constant. + if !e.has_alias() { + let name = ae_node.to_field(schema, expr_arena)?.name; + out_e.set_alias(name.clone()); + } + out_e + }; + new_expr.push(new_node) + } + // Add the tmp columns + for id in self.replaced_identifiers.inner.keys() { + let (node, _count) = self.se_count.get(id, expr_arena).unwrap(); + let name = id.materialize(); + let out_e = ExprIR::new(*node, OutputName::Alias(name)); + new_expr.push(out_e) + } + let expr = + ProjectionExprs::new_with_cse(new_expr, self.replaced_identifiers.inner.len()); + Ok(Some(expr)) + } else { + Ok(None) + } + } +} + +impl RewritingVisitor for CommonSubExprOptimizer { + type Node = IRNode; + type Arena = IRNodeArena; + + fn pre_visit( + &mut self, + node: &Self::Node, + arena: &mut Self::Arena, + ) -> PolarsResult { + use IR::*; + Ok(match node.to_alp(&arena.0) { + Select { .. } | HStack { .. } | GroupBy { .. } => RewriteRecursion::MutateAndContinue, + _ => RewriteRecursion::NoMutateAndContinue, + }) + } + + fn mutate(&mut self, node: Self::Node, arena: &mut Self::Arena) -> PolarsResult { + let mut id_array_offsets = std::mem::take(&mut self.id_array_offsets); + + self.se_count.inner.clear(); + self.name_validation.clear(); + self.id_array.clear(); + id_array_offsets.clear(); + self.replaced_identifiers.inner.clear(); + + let arena_idx = node.node(); + let alp = arena.0.get(arena_idx); + + match alp { + IR::Select { + input, + expr, + schema, + options, + } => { + let input_schema = arena.0.get(*input).schema(&arena.0); + if let Some(expr) = self.find_cse( + expr, + &mut arena.1, + &mut id_array_offsets, + false, + input_schema.as_ref().as_ref(), + )? { + let schema = schema.clone(); + let options = *options; + + let lp = IRBuilder::new(*input, &mut arena.1, &mut arena.0) + .with_columns( + expr.cse_exprs().to_vec(), + ProjectionOptions { + run_parallel: options.run_parallel, + duplicate_check: options.duplicate_check, + // These columns might have different + // lengths from the dataframe, but + // they are only temporaries that will + // be removed by the evaluation of the + // default_exprs and the subsequent + // projection. + should_broadcast: false, + }, + ) + .build(); + let input = arena.0.add(lp); + + let lp = IR::Select { + input, + expr: expr.default_exprs().to_vec(), + schema, + options, + }; + arena.0.replace(arena_idx, lp); + } + }, + IR::HStack { + input, + exprs, + schema, + options, + } => { + let input_schema = arena.0.get(*input).schema(&arena.0); + if let Some(exprs) = self.find_cse( + exprs, + &mut arena.1, + &mut id_array_offsets, + false, + input_schema.as_ref().as_ref(), + )? { + let schema = schema.clone(); + let options = *options; + let input = *input; + + let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0) + .with_columns( + exprs.cse_exprs().to_vec(), + // These columns might have different + // lengths from the dataframe, but they + // are only temporaries that will be + // removed by the evaluation of the + // default_exprs and the subsequent + // projection. + ProjectionOptions { + run_parallel: options.run_parallel, + duplicate_check: options.duplicate_check, + should_broadcast: false, + }, + ) + .with_columns(exprs.default_exprs().to_vec(), options) + .build(); + let input = arena.0.add(lp); + + let lp = IR::SimpleProjection { + input, + columns: schema, + }; + arena.0.replace(arena_idx, lp); + } + }, + IR::GroupBy { + input, + keys, + aggs, + options, + maintain_order, + apply, + schema, + } => { + let input_schema = arena.0.get(*input).schema(&arena.0); + if let Some(aggs) = self.find_cse( + aggs, + &mut arena.1, + &mut id_array_offsets, + true, + input_schema.as_ref().as_ref(), + )? { + let keys = keys.clone(); + let options = options.clone(); + let schema = schema.clone(); + let apply = apply.clone(); + let maintain_order = *maintain_order; + let input = *input; + + let lp = IRBuilder::new(input, &mut arena.1, &mut arena.0) + .with_columns(aggs.cse_exprs().to_vec(), Default::default()) + .build(); + let input = arena.0.add(lp); + + let lp = IR::GroupBy { + input, + keys, + aggs: aggs.default_exprs().to_vec(), + options, + schema, + maintain_order, + apply, + }; + arena.0.replace(arena_idx, lp); + } + }, + _ => {}, + } + + self.id_array_offsets = id_array_offsets; + Ok(node) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs new file mode 100644 index 000000000000..9847e1561a48 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs @@ -0,0 +1,398 @@ +use std::hash::BuildHasher; + +use hashbrown::hash_map::RawEntryMut; + +use super::*; +use crate::prelude::visitor::IRNode; + +mod identifier_impl { + use polars_utils::aliases::PlFixedStateQuality; + use polars_utils::hashing::_boost_hash_combine; + + use super::*; + /// Identifier that shows the sub-expression path. + /// Must implement hash and equality and ideally + /// have little collisions + /// We will do a full expression comparison to check if the + /// expressions with equal identifiers are truly equal + #[derive(Clone)] + pub(super) struct Identifier { + inner: Option, + last_node: Option, + hb: PlFixedStateQuality, + } + + impl Identifier { + pub fn hash(&self) -> u64 { + self.inner.unwrap_or(0) + } + + pub fn is_equal( + &self, + other: &Self, + lp_arena: &Arena, + expr_arena: &Arena, + ) -> bool { + self.inner == other.inner + && match (self.last_node, other.last_node) { + (None, None) => true, + (Some(l), Some(r)) => { + // We ignore caches as they are inserted on the node locations. + // In that case we don't want to cmp the cache (as we just inserted it), + // but the input node of the cache. + l.hashable_and_cmp(lp_arena, expr_arena).ignore_caches() + == r.hashable_and_cmp(lp_arena, expr_arena).ignore_caches() + }, + _ => false, + } + } + pub fn new() -> Self { + Self { + inner: None, + last_node: None, + hb: PlFixedStateQuality::with_seed(0), + } + } + + pub fn is_valid(&self) -> bool { + self.inner.is_some() + } + + pub fn combine(&mut self, other: &Identifier) { + let inner = match (self.inner, other.inner) { + (Some(l), Some(r)) => _boost_hash_combine(l, r), + (None, Some(r)) => r, + (Some(l), None) => l, + _ => return, + }; + self.inner = Some(inner); + } + + pub fn add_alp_node( + &self, + alp: &IRNode, + lp_arena: &Arena, + expr_arena: &Arena, + ) -> Self { + let hashed = self.hb.hash_one(alp.hashable_and_cmp(lp_arena, expr_arena)); + let inner = Some( + self.inner + .map_or(hashed, |l| _boost_hash_combine(l, hashed)), + ); + Self { + inner, + last_node: Some(*alp), + hb: self.hb, + } + } + } +} +use identifier_impl::*; + +#[derive(Default)] +struct IdentifierMap { + inner: PlHashMap, +} + +impl IdentifierMap { + fn get(&self, id: &Identifier, lp_arena: &Arena, expr_arena: &Arena) -> Option<&V> { + self.inner + .raw_entry() + .from_hash(id.hash(), |k| k.is_equal(id, lp_arena, expr_arena)) + .map(|(_k, v)| v) + } + + fn entry V>( + &mut self, + id: Identifier, + v: F, + lp_arena: &Arena, + expr_arena: &Arena, + ) -> &mut V { + let h = id.hash(); + match self + .inner + .raw_entry_mut() + .from_hash(h, |k| k.is_equal(&id, lp_arena, expr_arena)) + { + RawEntryMut::Occupied(entry) => entry.into_mut(), + RawEntryMut::Vacant(entry) => { + let (_, v) = entry.insert_with_hasher(h, id, v(), |id| id.hash()); + v + }, + } + } +} + +/// Identifier maps to Expr Node and count. +type SubPlanCount = IdentifierMap<(Node, u32)>; +/// (post_visit_idx, identifier); +type IdentifierArray = Vec<(usize, Identifier)>; + +/// See Expr based CSE for explanations. +enum VisitRecord { + /// Entered a new plan node + Entered(usize), + SubPlanId(Identifier), +} + +struct LpIdentifierVisitor<'a> { + sp_count: &'a mut SubPlanCount, + identifier_array: &'a mut IdentifierArray, + // Index in pre-visit traversal order. + pre_visit_idx: usize, + post_visit_idx: usize, + visit_stack: Vec, + has_subplan: bool, +} + +impl LpIdentifierVisitor<'_> { + fn new<'a>( + sp_count: &'a mut SubPlanCount, + identifier_array: &'a mut IdentifierArray, + ) -> LpIdentifierVisitor<'a> { + LpIdentifierVisitor { + sp_count, + identifier_array, + pre_visit_idx: 0, + post_visit_idx: 0, + visit_stack: vec![], + has_subplan: false, + } + } + + fn pop_until_entered(&mut self) -> (usize, Identifier) { + let mut id = Identifier::new(); + + while let Some(item) = self.visit_stack.pop() { + match item { + VisitRecord::Entered(idx) => return (idx, id), + VisitRecord::SubPlanId(s) => { + id.combine(&s); + }, + } + } + unreachable!() + } +} + +fn skip_children(lp: &IR) -> bool { + match lp { + // Don't visit all the files in a `scan *` operation. + // Put an arbitrary limit to 20 files now. + IR::Union { + options, inputs, .. + } => options.from_partitioned_ds && inputs.len() > 20, + _ => false, + } +} + +impl Visitor for LpIdentifierVisitor<'_> { + type Node = IRNode; + type Arena = IRNodeArena; + + fn pre_visit( + &mut self, + node: &Self::Node, + arena: &Self::Arena, + ) -> PolarsResult { + self.visit_stack + .push(VisitRecord::Entered(self.pre_visit_idx)); + self.pre_visit_idx += 1; + + self.identifier_array.push((0, Identifier::new())); + + if skip_children(node.to_alp(&arena.0)) { + Ok(VisitRecursion::Skip) + } else { + Ok(VisitRecursion::Continue) + } + } + + fn post_visit( + &mut self, + node: &Self::Node, + arena: &Self::Arena, + ) -> PolarsResult { + self.post_visit_idx += 1; + + let (pre_visit_idx, sub_plan_id) = self.pop_until_entered(); + + // Create the Id of this node. + let id: Identifier = sub_plan_id.add_alp_node(node, &arena.0, &arena.1); + + // Store the created id. + self.identifier_array[pre_visit_idx] = (self.post_visit_idx, id.clone()); + + // We popped until entered, push this Id on the stack so the trail + // is available for the parent plan. + self.visit_stack.push(VisitRecord::SubPlanId(id.clone())); + + let (_, sp_count) = self + .sp_count + .entry(id, || (node.node(), 0), &arena.0, &arena.1); + *sp_count += 1; + self.has_subplan |= *sp_count > 1; + Ok(VisitRecursion::Continue) + } +} + +pub(super) type CacheId2Caches = PlHashMap)>; + +struct CommonSubPlanRewriter<'a> { + sp_count: &'a SubPlanCount, + identifier_array: &'a IdentifierArray, + + max_post_visit_idx: usize, + /// index in traversal order in which `identifier_array` + /// was written. This is the index in `identifier_array`. + visited_idx: usize, + /// Indicates if this expression is rewritten. + rewritten: bool, + cache_id: IdentifierMap, + // Maps cache_id : (cache_count and cache_nodes) + cache_id_to_caches: CacheId2Caches, +} + +impl<'a> CommonSubPlanRewriter<'a> { + fn new(sp_count: &'a SubPlanCount, identifier_array: &'a IdentifierArray) -> Self { + Self { + sp_count, + identifier_array, + max_post_visit_idx: 0, + visited_idx: 0, + rewritten: false, + cache_id: Default::default(), + cache_id_to_caches: Default::default(), + } + } +} + +impl RewritingVisitor for CommonSubPlanRewriter<'_> { + type Node = IRNode; + type Arena = IRNodeArena; + + fn pre_visit( + &mut self, + lp_node: &Self::Node, + arena: &mut Self::Arena, + ) -> PolarsResult { + if self.visited_idx >= self.identifier_array.len() + || self.max_post_visit_idx > self.identifier_array[self.visited_idx].0 + { + return Ok(RewriteRecursion::Stop); + } + + let id = &self.identifier_array[self.visited_idx].1; + + // Id placeholder not overwritten, so we can skip this sub-expression. + if !id.is_valid() { + self.visited_idx += 1; + return Ok(RewriteRecursion::NoMutateAndContinue); + } + + let Some((_, count)) = self.sp_count.get(id, &arena.0, &arena.1) else { + self.visited_idx += 1; + return Ok(RewriteRecursion::NoMutateAndContinue); + }; + + if *count > 1 { + // Rewrite this sub-plan, don't visit its children + Ok(RewriteRecursion::MutateAndStop) + } + // Never mutate if count <= 1. The post-visit will search for the node, and not be able to find it + else { + // Don't traverse the children. + if skip_children(lp_node.to_alp(&arena.0)) { + return Ok(RewriteRecursion::Stop); + } + // This is a unique plan + // visit its children to see if they are cse + self.visited_idx += 1; + Ok(RewriteRecursion::NoMutateAndContinue) + } + } + + fn mutate( + &mut self, + mut node: Self::Node, + arena: &mut Self::Arena, + ) -> PolarsResult { + let (post_visit_count, id) = &self.identifier_array[self.visited_idx]; + self.visited_idx += 1; + + if *post_visit_count < self.max_post_visit_idx { + return Ok(node); + } + self.max_post_visit_idx = *post_visit_count; + while self.visited_idx < self.identifier_array.len() + && *post_visit_count > self.identifier_array[self.visited_idx].0 + { + self.visited_idx += 1; + } + + let cache_id = self.cache_id.inner.len(); + let cache_id = *self + .cache_id + .entry(id.clone(), || cache_id, &arena.0, &arena.1); + let cache_count = self.sp_count.get(id, &arena.0, &arena.1).unwrap().1; + + let cache_node = IR::Cache { + input: node.node(), + id: cache_id, + cache_hits: cache_count - 1, + }; + node.assign(cache_node, &mut arena.0); + let (_count, nodes) = self + .cache_id_to_caches + .entry(cache_id) + .or_insert_with(|| (cache_count, vec![])); + nodes.push(node.node()); + self.rewritten = true; + Ok(node) + } +} + +pub(crate) fn elim_cmn_subplans( + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> (Node, bool, CacheId2Caches) { + let mut sp_count = Default::default(); + let mut id_array = Default::default(); + + with_ir_arena(lp_arena, expr_arena, |arena| { + let lp_node = IRNode::new_mutate(root); + let mut visitor = LpIdentifierVisitor::new(&mut sp_count, &mut id_array); + + lp_node.visit(&mut visitor, arena).map(|_| ()).unwrap(); + + let lp_node = IRNode::new_mutate(root); + let mut rewriter = CommonSubPlanRewriter::new(&sp_count, &id_array); + lp_node.rewrite(&mut rewriter, arena).unwrap(); + + (root, rewriter.rewritten, rewriter.cache_id_to_caches) + }) +} + +/// Prune unused caches. +/// In the query below the query will be insert cache 0 with a count of 2 on `lf.select` +/// and cache 1 with a count of 3 on `lf`. But because cache 0 is higher in the chain cache 1 +/// will never be used. So we prune caches that don't fit their count. +/// +/// `conctat([lf.select(), lf.select(), lf])` +/// +pub(crate) fn prune_unused_caches(lp_arena: &mut Arena, cid2c: CacheId2Caches) { + for (count, nodes) in cid2c.values() { + if *count == nodes.len() as u32 { + continue; + } + + for node in nodes { + let IR::Cache { input, .. } = lp_arena.get(*node) else { + unreachable!() + }; + lp_arena.swap(*input, *node) + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/cse/mod.rs b/crates/polars-plan/src/plans/optimizer/cse/mod.rs new file mode 100644 index 000000000000..f5df3bec7595 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/cse/mod.rs @@ -0,0 +1,17 @@ +mod cse_expr; +mod cse_lp; + +pub(super) use cse_expr::CommonSubExprOptimizer; +pub use cse_expr::NaiveExprMerger; +pub(super) use cse_lp::{elim_cmn_subplans, prune_unused_caches}; + +use super::*; + +type Accepted = Option<(VisitRecursion, bool)>; +// Don't allow this node in a cse. +const REFUSE_NO_MEMBER: Accepted = Some((VisitRecursion::Continue, false)); +// Don't allow this node, but allow as a member of a cse. +const REFUSE_ALLOW_MEMBER: Accepted = Some((VisitRecursion::Continue, true)); +const REFUSE_SKIP: Accepted = Some((VisitRecursion::Skip, false)); +// Accept this node. +const ACCEPT: Accepted = None; diff --git a/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs b/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs new file mode 100644 index 000000000000..0e0de1f02d4e --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/delay_rechunk.rs @@ -0,0 +1,69 @@ +use std::collections::BTreeSet; + +use super::*; + +#[derive(Default)] +pub(super) struct DelayRechunk { + processed: BTreeSet, +} + +impl DelayRechunk { + pub(super) fn new() -> Self { + Default::default() + } +} + +impl OptimizationRule for DelayRechunk { + fn optimize_plan( + &mut self, + lp_arena: &mut Arena, + _expr_arena: &mut Arena, + node: Node, + ) -> PolarsResult> { + match lp_arena.get(node) { + // An aggregation can be partitioned, its wasteful to rechunk before that partition. + #[allow(unused_mut)] + IR::GroupBy { input, keys, .. } => { + // Multiple keys on multiple chunks is much slower, so rechunk. + if !self.processed.insert(node.0) || keys.len() > 1 { + return Ok(None); + }; + + use IR::*; + let mut input_node = None; + for (node, lp) in (&*lp_arena).iter(*input) { + match lp { + Scan { .. } => { + input_node = Some(node); + break; + }, + Union { .. } => { + input_node = Some(node); + break; + }, + // don't delay rechunk if there is a join first + Join { .. } => break, + _ => {}, + } + } + + if let Some(node) = input_node { + match lp_arena.get_mut(node) { + Scan { + unified_scan_args, .. + } => { + unified_scan_args.rechunk = false; + }, + Union { options, .. } => { + options.rechunk = false; + }, + _ => unreachable!(), + } + }; + + Ok(None) + }, + _ => Ok(None), + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/flatten_union.rs b/crates/polars-plan/src/plans/optimizer/flatten_union.rs new file mode 100644 index 000000000000..e79c0b972dd4 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/flatten_union.rs @@ -0,0 +1,52 @@ +use IR::*; +use polars_core::error::PolarsResult; +use polars_utils::arena::{Arena, Node}; + +use super::OptimizationRule; +use crate::prelude::IR; + +pub struct FlattenUnionRule {} + +fn get_union_inputs(node: Node, lp_arena: &Arena) -> Option<&[Node]> { + match lp_arena.get(node) { + IR::Union { inputs, .. } => Some(inputs), + _ => None, + } +} + +impl OptimizationRule for FlattenUnionRule { + fn optimize_plan( + &mut self, + lp_arena: &mut polars_utils::arena::Arena, + _expr_arena: &mut polars_utils::arena::Arena, + node: polars_utils::arena::Node, + ) -> PolarsResult> { + let lp = lp_arena.get(node); + + match lp { + Union { inputs, options } + if inputs.iter().any(|node| match lp_arena.get(*node) { + Union { options, .. } => !options.flattened_by_opt, + _ => false, + }) => + { + let mut new_inputs = Vec::with_capacity(inputs.len() * 2); + let mut options = *options; + + for node in inputs { + match get_union_inputs(*node, lp_arena) { + Some(inp) => new_inputs.extend_from_slice(inp), + None => new_inputs.push(*node), + } + } + options.flattened_by_opt = true; + + Ok(Some(Union { + inputs: new_inputs, + options, + })) + }, + _ => Ok(None), + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/fused.rs b/crates/polars-plan/src/plans/optimizer/fused.rs new file mode 100644 index 000000000000..d1d97fde9b74 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/fused.rs @@ -0,0 +1,195 @@ +#[cfg(feature = "python")] +use self::python_dsl::PythonScanSource; +use super::*; + +pub struct FusedArithmetic {} + +fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena) -> AExpr { + let input = input + .iter() + .copied() + .map(|n| ExprIR::from_node(n, expr_arena)) + .collect(); + let mut options = FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + cast_options: Some(CastingRules::cast_to_supertypes()), + ..Default::default() + }; + // order of operations change because of FMA + // so we must toggle this check off + // it is still safe as it is a trusted operation + unsafe { options.no_check_lengths() } + AExpr::Function { + input, + function: FunctionExpr::Fused(op), + options, + } +} + +fn check_eligible( + left: &Node, + right: &Node, + lp_node: Node, + expr_arena: &Arena, + lp_arena: &Arena, +) -> PolarsResult<(Option, Option)> { + let Some(input_node) = lp_arena.get(lp_node).get_input() else { + return Ok((None, None)); + }; + let schema = lp_arena.get(input_node).schema(lp_arena); + let field_left = expr_arena + .get(*left) + .to_field(&schema, Context::Default, expr_arena)?; + let type_right = expr_arena + .get(*right) + .get_type(&schema, Context::Default, expr_arena)?; + let type_left = &field_left.dtype; + // Exclude literals for now as these will not benefit from fused operations downstream #9857 + // This optimization would also interfere with the `col -> lit` type-coercion rules + // And it might also interfere with constant folding which is a more suitable optimizations here + if type_left.is_primitive_numeric() + && type_right.is_primitive_numeric() + && !has_aexpr_literal(*left, expr_arena) + && !has_aexpr_literal(*right, expr_arena) + { + Ok((Some(true), Some(field_left))) + } else { + Ok((Some(false), None)) + } +} + +impl OptimizationRule for FusedArithmetic { + #[allow(clippy::float_cmp)] + fn optimize_expr( + &mut self, + expr_arena: &mut Arena, + expr_node: Node, + lp_arena: &Arena, + lp_node: Node, + ) -> PolarsResult> { + // We don't want to fuse arithmetic that we send to pyarrow. + #[cfg(feature = "python")] + if let IR::PythonScan { options } = lp_arena.get(lp_node) { + if matches!( + options.python_source, + PythonScanSource::Pyarrow | PythonScanSource::IOPlugin + ) { + return Ok(None); + } + }; + let expr = expr_arena.get(expr_node); + + use AExpr::*; + match expr { + BinaryExpr { + left, + op: Operator::Plus, + right, + } => { + // FUSED MULTIPLY ADD + // For fma the plus is always the out as the multiply takes prevalence + match expr_arena.get(*left) { + // Argument order is a + b * c + // so we must swap operands + // + // input + // (a * b) + c + // swapped as + // c + (a * b) + BinaryExpr { + left: a, + op: Operator::Multiply, + right: b, + } => match check_eligible(left, right, lp_node, expr_arena, lp_arena)? { + (None, _) | (Some(false), _) => Ok(None), + (Some(true), Some(output_field)) => { + let input = &[*right, *a, *b]; + let fma = get_expr(input, FusedOperator::MultiplyAdd, expr_arena); + let node = expr_arena.add(fma); + // we reordered the arguments, so we don't obey the left expression output name + // rule anymore, that's why we alias + Ok(Some(Alias(node, output_field.name.clone()))) + }, + _ => unreachable!(), + }, + _ => match expr_arena.get(*right) { + // input + // (a + (b * c) + // kept as input + BinaryExpr { + left: a, + op: Operator::Multiply, + right: b, + } => match check_eligible(left, right, lp_node, expr_arena, lp_arena)? { + (None, _) | (Some(false), _) => Ok(None), + (Some(true), _) => { + let input = &[*left, *a, *b]; + Ok(Some(get_expr( + input, + FusedOperator::MultiplyAdd, + expr_arena, + ))) + }, + }, + _ => Ok(None), + }, + } + }, + + BinaryExpr { + left, + op: Operator::Minus, + right, + } => { + // FUSED SUB MULTIPLY + match expr_arena.get(*right) { + // input + // (a - (b * c) + // kept as input + BinaryExpr { + left: a, + op: Operator::Multiply, + right: b, + } => match check_eligible(left, right, lp_node, expr_arena, lp_arena)? { + (None, _) | (Some(false), _) => Ok(None), + (Some(true), _) => { + let input = &[*left, *a, *b]; + Ok(Some(get_expr( + input, + FusedOperator::SubMultiply, + expr_arena, + ))) + }, + }, + _ => { + // FUSED MULTIPLY SUB + match expr_arena.get(*left) { + // input + // (a * b) - c + // kept as input + BinaryExpr { + left: a, + op: Operator::Multiply, + right: b, + } => { + match check_eligible(left, right, lp_node, expr_arena, lp_arena)? { + (None, _) | (Some(false), _) => Ok(None), + (Some(true), _) => { + let input = &[*a, *b, *right]; + Ok(Some(get_expr( + input, + FusedOperator::MultiplySub, + expr_arena, + ))) + }, + } + }, + _ => Ok(None), + } + }, + } + }, + _ => Ok(None), + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/join_utils.rs b/crates/polars-plan/src/plans/optimizer/join_utils.rs new file mode 100644 index 000000000000..8e0f1d97c1f0 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/join_utils.rs @@ -0,0 +1,149 @@ +use polars_core::error::{PolarsResult, polars_bail}; +use polars_core::schema::*; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; + +use super::{AExpr, aexpr_to_leaf_names_iter}; +use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker}; +use crate::plans::{ExprIR, OutputName}; + +/// Join origin of an expression +#[derive(Debug, Clone, PartialEq, Copy)] +#[repr(u8)] +pub(crate) enum ExprOrigin { + // Note: BitOr is implemented on this struct that relies on this exact u8 + // repr layout (i.e. treated as a bitfield). + // + /// Utilizes no columns + None = 0b00, + /// Utilizes columns from the left side of the join + Left = 0b10, + /// Utilizes columns from the right side of the join + Right = 0b01, + /// Utilizes columns from both sides of the join + Both = 0b11, +} + +impl ExprOrigin { + /// Errors with ColumnNotFound if a column cannot be found on either side. + pub(crate) fn get_expr_origin( + root: Node, + expr_arena: &Arena, + left_schema: &Schema, + right_schema: &Schema, + suffix: &str, + ) -> PolarsResult { + aexpr_to_leaf_names_iter(root, expr_arena).try_fold( + ExprOrigin::None, + |acc_origin, column_name| { + Ok(acc_origin + | Self::get_column_origin(&column_name, left_schema, right_schema, suffix)?) + }, + ) + } + + /// Errors with ColumnNotFound if a column cannot be found on either side. + pub(crate) fn get_column_origin( + column_name: &str, + left_schema: &Schema, + right_schema: &Schema, + suffix: &str, + ) -> PolarsResult { + Ok(if left_schema.contains(column_name) { + ExprOrigin::Left + } else if right_schema.contains(column_name) + || column_name + .strip_suffix(suffix) + .is_some_and(|x| right_schema.contains(x)) + { + ExprOrigin::Right + } else { + polars_bail!(ColumnNotFound: "{}", column_name) + }) + } +} + +impl std::ops::BitOr for ExprOrigin { + type Output = ExprOrigin; + + fn bitor(self, rhs: Self) -> Self::Output { + unsafe { std::mem::transmute::(self as u8 | rhs as u8) } + } +} + +impl std::ops::BitOrAssign for ExprOrigin { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } +} + +pub(super) fn remove_suffix<'a>( + expr: &mut ExprIR, + expr_arena: &mut Arena, + schema_rhs: &'a Schema, + suffix: &'a str, +) { + let schema = schema_rhs; + // Using AexprNode::rewrite() ensures we do not mutate any nodes in-place. The nodes may be + // used in other locations and mutating them will cause really confusing bugs, such as + // https://github.com/pola-rs/polars/issues/20831. + let node = AexprNode::new(expr.node()) + .rewrite(&mut RemoveSuffix { schema, suffix }, expr_arena) + .unwrap() + .node(); + + expr.set_node(node); + + if let OutputName::ColumnLhs(colname) = expr.output_name_inner() { + if colname.ends_with(suffix) && !schema.contains(colname.as_str()) { + let name = PlSmallStr::from(&colname[..colname.len() - suffix.len()]); + expr.set_columnlhs(name); + } + } + + struct RemoveSuffix<'a> { + schema: &'a Schema, + suffix: &'a str, + } + + impl RewritingVisitor for RemoveSuffix<'_> { + type Node = AexprNode; + type Arena = Arena; + + fn pre_visit( + &mut self, + node: &Self::Node, + arena: &mut Self::Arena, + ) -> polars_core::prelude::PolarsResult { + let AExpr::Column(colname) = arena.get(node.node()) else { + return Ok(RewriteRecursion::NoMutateAndContinue); + }; + + if !colname.ends_with(self.suffix) || self.schema.contains(colname.as_str()) { + return Ok(RewriteRecursion::NoMutateAndContinue); + } + + Ok(RewriteRecursion::MutateAndContinue) + } + + fn mutate( + &mut self, + node: Self::Node, + arena: &mut Self::Arena, + ) -> polars_core::prelude::PolarsResult { + let AExpr::Column(colname) = arena.get(node.node()) else { + unreachable!(); + }; + + // Safety: Checked in pre_visit() + Ok(AexprNode::new(arena.add(AExpr::Column(PlSmallStr::from( + &colname[..colname.len() - self.suffix.len()], + ))))) + } + } +} + +pub(super) fn split_suffix<'a>(name: &'a str, suffix: &str) -> &'a str { + let (original, _) = name.split_at(name.len() - suffix.len()); + original +} diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs new file mode 100644 index 000000000000..70f311c231ee --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -0,0 +1,292 @@ +use polars_core::prelude::*; + +use crate::prelude::*; + +mod cache_states; +mod delay_rechunk; + +mod cluster_with_columns; +mod collapse_and_project; +mod collapse_joins; +mod collect_members; +mod count_star; +#[cfg(feature = "cse")] +mod cse; +mod flatten_union; +#[cfg(feature = "fused")] +mod fused; +mod join_utils; +pub(crate) use join_utils::ExprOrigin; +mod predicate_pushdown; +mod projection_pushdown; +mod set_order; +mod simplify_expr; +mod slice_pushdown_expr; +mod slice_pushdown_lp; +mod stack_opt; + +use collapse_and_project::SimpleProjectionAndCollapse; +#[cfg(feature = "cse")] +pub use cse::NaiveExprMerger; +use delay_rechunk::DelayRechunk; +use polars_core::config::verbose; +use polars_io::predicates::PhysicalIoExpr; +pub use predicate_pushdown::PredicatePushDown; +pub use projection_pushdown::ProjectionPushDown; +pub use simplify_expr::{SimplifyBooleanRule, SimplifyExprRule}; +use slice_pushdown_lp::SlicePushDown; +pub use stack_opt::{OptimizationRule, StackOptimizer}; + +use self::flatten_union::FlattenUnionRule; +use self::set_order::set_order_flags; +pub use crate::frame::{AllowedOptimizations, OptFlags}; +pub use crate::plans::conversion::type_coercion::TypeCoercionRule; +use crate::plans::optimizer::count_star::CountStar; +#[cfg(feature = "cse")] +use crate::plans::optimizer::cse::CommonSubExprOptimizer; +#[cfg(feature = "cse")] +use crate::plans::optimizer::cse::prune_unused_caches; +use crate::plans::optimizer::predicate_pushdown::ExprEval; +#[cfg(feature = "cse")] +use crate::plans::visitor::*; +use crate::prelude::optimizer::collect_members::MemberCollector; + +pub trait Optimize { + fn optimize(&self, logical_plan: DslPlan) -> PolarsResult; +} + +// arbitrary constant to reduce reallocation. +const HASHMAP_SIZE: usize = 16; + +pub(crate) fn init_hashmap(max_len: Option) -> PlHashMap { + PlHashMap::with_capacity(std::cmp::min(max_len.unwrap_or(HASHMAP_SIZE), HASHMAP_SIZE)) +} + +pub fn optimize( + logical_plan: DslPlan, + mut opt_flags: OptFlags, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + scratch: &mut Vec, + expr_eval: ExprEval<'_>, +) -> PolarsResult { + #[allow(dead_code)] + let verbose = verbose(); + + #[cfg(feature = "python")] + if opt_flags.streaming() { + polars_warn!( + Deprecation, + "\ +The old streaming engine is being deprecated and will soon be replaced by the new streaming \ +engine. Starting Polars version 1.23.0 and until the new streaming engine is released, the old \ +streaming engine may become less usable. For people who rely on the old streaming engine, it is \ +suggested to pin your version to before 1.23.0. + +More information on the new streaming engine: https://github.com/pola-rs/polars/issues/20947" + ) + } + + // Gradually fill the rules passed to the optimizer + let opt = StackOptimizer {}; + let mut rules: Vec> = Vec::with_capacity(8); + + // Unset CSE + // This can be turned on again during ir-conversion. + #[allow(clippy::eq_op)] + #[cfg(feature = "cse")] + if opt_flags.contains(OptFlags::EAGER) { + opt_flags &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM); + } + let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_flags)?; + + // Don't run optimizations that don't make sense on a single node. + // This keeps eager execution more snappy. + #[cfg(feature = "cse")] + let comm_subplan_elim = opt_flags.contains(OptFlags::COMM_SUBPLAN_ELIM); + + #[cfg(feature = "cse")] + let comm_subexpr_elim = opt_flags.contains(OptFlags::COMM_SUBEXPR_ELIM); + #[cfg(not(feature = "cse"))] + let comm_subexpr_elim = false; + + // During debug we check if the optimizations have not modified the final schema. + #[cfg(debug_assertions)] + let prev_schema = lp_arena.get(lp_top).schema(lp_arena).into_owned(); + + let mut _opt_members = &mut None; + + macro_rules! get_or_init_members { + () => { + _get_or_init_members(_opt_members, lp_top, lp_arena, expr_arena) + }; + } + + macro_rules! get_members_opt { + () => { + _opt_members.as_mut() + }; + } + + // Run before slice pushdown + if opt_flags.contains(OptFlags::CHECK_ORDER_OBSERVE) { + let members = get_or_init_members!(); + if members.has_group_by | members.has_sort | members.has_distinct { + set_order_flags(lp_top, lp_arena, expr_arena, scratch); + } + } + + if opt_flags.simplify_expr() { + #[cfg(feature = "fused")] + rules.push(Box::new(fused::FusedArithmetic {})); + } + + #[cfg(feature = "cse")] + let _cse_plan_changed = if comm_subplan_elim { + let members = get_or_init_members!(); + if (members.has_sink_multiple || members.has_joins_or_unions) + && members.has_duplicate_scans() + && !members.has_cache + { + if verbose { + eprintln!("found multiple sources; run comm_subplan_elim") + } + + let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena); + + prune_unused_caches(lp_arena, cid2c); + + lp_top = lp; + members.has_cache |= changed; + changed + } else { + false + } + } else { + false + }; + #[cfg(not(feature = "cse"))] + let _cse_plan_changed = false; + + // Should be run before predicate pushdown. + if opt_flags.projection_pushdown() { + let mut projection_pushdown_opt = ProjectionPushDown::new(); + let alp = lp_arena.take(lp_top); + let alp = projection_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; + lp_arena.replace(lp_top, alp); + + if projection_pushdown_opt.is_count_star { + let mut count_star_opt = CountStar::new(); + count_star_opt.optimize_plan(lp_arena, expr_arena, lp_top)?; + } + } + + if opt_flags.predicate_pushdown() { + let mut predicate_pushdown_opt = + PredicatePushDown::new(expr_eval, opt_flags.new_streaming()); + let alp = lp_arena.take(lp_top); + let alp = predicate_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; + lp_arena.replace(lp_top, alp); + } + + if opt_flags.cluster_with_columns() { + cluster_with_columns::optimize(lp_top, lp_arena, expr_arena) + } + + // Make sure it is after predicate pushdown + if opt_flags.collapse_joins() && get_or_init_members!().has_filter_with_join_input { + collapse_joins::optimize(lp_top, lp_arena, expr_arena); + } + + // Make sure its before slice pushdown. + if opt_flags.fast_projection() { + rules.push(Box::new(SimpleProjectionAndCollapse::new( + opt_flags.eager(), + ))); + } + + if !opt_flags.eager() { + rules.push(Box::new(DelayRechunk::new())); + } + + if opt_flags.slice_pushdown() { + let mut slice_pushdown_opt = + SlicePushDown::new(opt_flags.streaming(), opt_flags.new_streaming()); + let alp = lp_arena.take(lp_top); + let alp = slice_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; + + lp_arena.replace(lp_top, alp); + + // Expressions use the stack optimizer. + rules.push(Box::new(slice_pushdown_opt)); + } + // This optimization removes branches, so we must do it when type coercion + // is completed. + if opt_flags.simplify_expr() { + rules.push(Box::new(SimplifyBooleanRule {})); + } + + if !opt_flags.eager() { + rules.push(Box::new(FlattenUnionRule {})); + } + + lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?; + + if _cse_plan_changed + && get_members_opt!() + .is_some_and(|members| members.has_joins_or_unions && members.has_cache) + { + // We only want to run this on cse inserted caches + cache_states::set_cache_states( + lp_top, + lp_arena, + expr_arena, + scratch, + expr_eval, + verbose, + opt_flags.new_streaming(), + )?; + } + + // This one should run (nearly) last as this modifies the projections + #[cfg(feature = "cse")] + if comm_subexpr_elim && !get_or_init_members!().has_ext_context { + let mut optimizer = CommonSubExprOptimizer::new(); + let alp_node = IRNode::new_mutate(lp_top); + + lp_top = try_with_ir_arena(lp_arena, expr_arena, |arena| { + let rewritten = alp_node.rewrite(&mut optimizer, arena)?; + Ok(rewritten.node()) + })?; + } + + // During debug we check if the optimizations have not modified the final schema. + #[cfg(debug_assertions)] + { + // only check by names because we may supercast types. + assert_eq!( + prev_schema.iter_names().collect::>(), + lp_arena + .get(lp_top) + .schema(lp_arena) + .iter_names() + .collect::>() + ); + }; + + Ok(lp_top) +} + +fn _get_or_init_members<'a>( + opt_members: &'a mut Option, + lp_top: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> &'a mut MemberCollector { + opt_members.get_or_insert_with(|| { + let mut members = MemberCollector::new(); + members.collect(lp_top, lp_arena, expr_arena); + + members + }) +} diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs new file mode 100644 index 000000000000..0f9645d168c1 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs @@ -0,0 +1,82 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_group_by( + opt: &mut PredicatePushDown, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + input: Node, + keys: Vec, + aggs: Vec, + schema: SchemaRef, + maintain_order: bool, + apply: Option>, + options: Arc, + acc_predicates: PlHashMap, +) -> PolarsResult { + use IR::*; + + #[cfg(feature = "dynamic_group_by")] + let no_push = { options.rolling.is_some() || options.dynamic.is_some() }; + + #[cfg(not(feature = "dynamic_group_by"))] + let no_push = false; + + // Don't pushdown predicates on these cases. + if apply.is_some() || no_push || options.slice.is_some() { + let lp = GroupBy { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + return opt.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena); + } + + // If the predicate only resolves to the keys we can push it down. + // When it filters the aggregations, the predicate should be done after aggregation. + let mut local_predicates = Vec::with_capacity(acc_predicates.len()); + let key_schema = aexprs_to_schema( + &keys, + lp_arena.get(input).schema(lp_arena).as_ref(), + Context::Default, + expr_arena, + ); + + let mut new_acc_predicates = PlHashMap::with_capacity(acc_predicates.len()); + + for (pred_name, predicate) in acc_predicates { + // Counts change due to groupby's + // TODO! handle aliases, so that the predicate that is pushed down refers to the column before alias. + let mut push_down = !has_aexpr(predicate.node(), expr_arena, |ae| matches!(ae, AExpr::Len)); + + for name in aexpr_to_leaf_names_iter(predicate.node(), expr_arena) { + push_down &= key_schema.contains(name.as_ref()); + + if !push_down { + break; + } + } + if !push_down { + local_predicates.push(predicate) + } else { + new_acc_predicates.insert(pred_name.clone(), predicate.clone()); + } + } + + opt.pushdown_and_assign(input, new_acc_predicates, lp_arena, expr_arena)?; + + let lp = GroupBy { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + }; + Ok(opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) +} diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs new file mode 100644 index 000000000000..c0ed75a601d3 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -0,0 +1,251 @@ +use super::*; +use crate::plans::optimizer::join_utils::remove_suffix; + +// Information concerning individual sides of a join. +#[derive(PartialEq, Eq)] +struct LeftRight(T, T); + +fn should_block_join_specific( + ae: &AExpr, + how: &JoinType, + on_names: &PlHashSet, + expr_arena: &Arena, + schema_left: &Schema, + schema_right: &Schema, +) -> LeftRight { + use AExpr::*; + match ae { + // joins can produce null values + Function { + function: + FunctionExpr::Boolean(BooleanFunction::IsNotNull) + | FunctionExpr::Boolean(BooleanFunction::IsNull) + | FunctionExpr::FillNull, + .. + } => join_produces_null(how), + #[cfg(feature = "is_in")] + Function { + function: FunctionExpr::Boolean(BooleanFunction::IsIn { .. }), + .. + } => join_produces_null(how), + // joins can produce duplicates + #[cfg(feature = "is_unique")] + Function { + function: + FunctionExpr::Boolean(BooleanFunction::IsUnique) + | FunctionExpr::Boolean(BooleanFunction::IsDuplicated), + .. + } => LeftRight(true, true), + #[cfg(feature = "is_first_distinct")] + Function { + function: FunctionExpr::Boolean(BooleanFunction::IsFirstDistinct), + .. + } => LeftRight(true, true), + // any operation that checks for equality or ordering can be wrong because + // the join can produce null values + // TODO! check if we can be less conservative here + BinaryExpr { + op: Operator::Eq | Operator::NotEq, + left, + right, + } => { + let LeftRight(bleft, bright) = join_produces_null(how); + + let l_name = aexpr_output_name(*left, expr_arena).unwrap(); + let r_name = aexpr_output_name(*right, expr_arena).unwrap(); + + let is_in_on = on_names.contains(&l_name) || on_names.contains(&r_name); + + let block_left = + is_in_on && (schema_left.contains(&l_name) || schema_left.contains(&r_name)); + let block_right = + is_in_on && (schema_right.contains(&l_name) || schema_right.contains(&r_name)); + LeftRight(block_left | bleft, block_right | bright) + }, + _ => join_produces_null(how), + } +} + +/// Returns a tuple indicating whether predicates should be blocked for either side based on the +/// join type. +/// +/// * `true` indicates that predicates must not be pushed to that side +fn join_produces_null(how: &JoinType) -> LeftRight { + match how { + JoinType::Left => LeftRight(false, true), + JoinType::Right => LeftRight(true, false), + + JoinType::Full => LeftRight(true, true), + JoinType::Cross => LeftRight(true, true), + #[cfg(feature = "asof_join")] + JoinType::AsOf(_) => LeftRight(true, true), + + JoinType::Inner => LeftRight(false, false), + #[cfg(feature = "semi_anti_join")] + JoinType::Semi | JoinType::Anti => LeftRight(false, false), + #[cfg(feature = "iejoin")] + JoinType::IEJoin => LeftRight(false, false), + } +} + +fn all_pred_cols_in_left_on( + predicate: &ExprIR, + expr_arena: &mut Arena, + left_on: &[ExprIR], +) -> bool { + aexpr_to_leaf_names_iter(predicate.node(), expr_arena) + .all(|pred_column_name| left_on.iter().any(|e| e.output_name() == &pred_column_name)) +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_join( + opt: &mut PredicatePushDown, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + input_left: Node, + input_right: Node, + left_on: Vec, + right_on: Vec, + schema: SchemaRef, + options: Arc, + acc_predicates: PlHashMap, +) -> PolarsResult { + use IR::*; + let schema_left = lp_arena.get(input_left).schema(lp_arena); + let schema_right = lp_arena.get(input_right).schema(lp_arena); + + let on_names = left_on + .iter() + .flat_map(|e| aexpr_to_leaf_names_iter(e.node(), expr_arena)) + .chain( + right_on + .iter() + .flat_map(|e| aexpr_to_leaf_names_iter(e.node(), expr_arena)), + ) + .collect::>(); + + let mut pushdown_left = init_hashmap(Some(acc_predicates.len())); + let mut pushdown_right = init_hashmap(Some(acc_predicates.len())); + let mut local_predicates = Vec::with_capacity(acc_predicates.len()); + + for (_, predicate) in acc_predicates { + let column_origins = ExprOrigin::get_expr_origin( + predicate.node(), + expr_arena, + &schema_left, + &schema_right, + options.args.suffix(), + ) + .unwrap_or_else(|e| { + if cfg!(debug_assertions) { + panic!("{:?}", e) + } else { + ExprOrigin::None + } + }); + + // Cross joins produce a cartesian product, so if a predicate combines columns from both tables, we should not push down. + // Inequality joins logically produce a cartesian product, so the same logic applies. + if (options.args.how.is_cross() || options.args.how.is_ie()) + && column_origins == ExprOrigin::Both + { + local_predicates.push(predicate); + continue; + } + + // check if predicate can pass the joins node + let allow_pushdown_left = !has_aexpr(predicate.node(), expr_arena, |ae| { + should_block_join_specific( + ae, + &options.args.how, + &on_names, + expr_arena, + &schema_left, + &schema_right, + ) + .0 + }); + + let allow_pushdown_right = !has_aexpr(predicate.node(), expr_arena, |ae| { + should_block_join_specific( + ae, + &options.args.how, + &on_names, + expr_arena, + &schema_left, + &schema_right, + ) + .1 + }); + + // these indicate to which tables we are going to push down the predicate + let mut filter_left = false; + let mut filter_right = false; + + if allow_pushdown_left && column_origins == ExprOrigin::Left { + filter_left = true; + + insert_and_combine_predicate(&mut pushdown_left, &predicate, expr_arena); + // If we push down to the left and all predicate columns are also + // join columns, we also push down right for inner, left or semi join + if all_pred_cols_in_left_on(&predicate, expr_arena, &left_on) { + filter_right = match &options.args.how { + // TODO! if join_on right has a different name + // we can set this to `true` IFF we rename the predicate + JoinType::Inner | JoinType::Left => { + check_input_node(predicate.node(), &schema_right, expr_arena) + }, + #[cfg(feature = "semi_anti_join")] + JoinType::Semi => check_input_node(predicate.node(), &schema_right, expr_arena), + _ => false, + } + } + // this is `else if` because if the predicate is in the left hand side + // the right hand side should be renamed with the suffix. + // in that case we should not push down as the user wants to filter on `x` + // not on `x_rhs`. + } else if allow_pushdown_right && column_origins == ExprOrigin::Right { + filter_right = true; + + let mut predicate = predicate.clone(); + remove_suffix( + &mut predicate, + expr_arena, + &schema_right, + options.args.suffix(), + ); + + insert_and_combine_predicate(&mut pushdown_right, &predicate, expr_arena); + } + + match (filter_left, filter_right, &options.args.how) { + // if not pushed down on one of the tables we have to do it locally. + (false, false, _) | + // if left join and predicate only available in right table, + // 'we should not filter right, because that would lead to + // invalid results. + // see: #2057 + (false, true, JoinType::Left) + => { + local_predicates.push(predicate); + continue; + }, + // business as usual + _ => {} + } + } + + opt.pushdown_and_assign(input_left, pushdown_left, lp_arena, expr_arena)?; + opt.pushdown_and_assign(input_right, pushdown_right, lp_arena, expr_arena)?; + + let lp = Join { + input_left, + input_right, + left_on, + right_on, + schema, + options, + }; + + Ok(opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) +} diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs new file mode 100644 index 000000000000..ea10c35bd699 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs @@ -0,0 +1,27 @@ +//! Keys in the `acc_predicates` hashmap. +use super::*; + +// an invisible ascii token we use as delimiter +const HIDDEN_DELIMITER: &str = "\u{1D17A}"; + +/// Determine the hashmap key by combining all the leaf column names of a predicate +pub(super) fn predicate_to_key(predicate: Node, expr_arena: &Arena) -> PlSmallStr { + let mut iter = aexpr_to_leaf_names_iter(predicate, expr_arena); + if let Some(first) = iter.next() { + if let Some(second) = iter.next() { + let mut new = String::with_capacity(32 * iter.size_hint().0); + new.push_str(&first); + new.push_str(HIDDEN_DELIMITER); + new.push_str(&second); + + for name in iter { + new.push_str(HIDDEN_DELIMITER); + new.push_str(&name); + } + return PlSmallStr::from_string(new); + } + first + } else { + PlSmallStr::from_str(HIDDEN_DELIMITER) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs new file mode 100644 index 000000000000..f005610facae --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -0,0 +1,701 @@ +mod group_by; +mod join; +mod keys; +mod rename; +mod utils; + +use polars_core::datatypes::PlHashMap; +use polars_core::prelude::*; +use recursive::recursive; +use utils::*; + +#[cfg(feature = "python")] +use self::python_dsl::PythonScanSource; +use super::*; +use crate::dsl::function_expr::FunctionExpr; +use crate::prelude::optimizer::predicate_pushdown::group_by::process_group_by; +use crate::prelude::optimizer::predicate_pushdown::join::process_join; +use crate::prelude::optimizer::predicate_pushdown::rename::process_rename; +use crate::utils::{check_input_node, has_aexpr}; + +pub type ExprEval<'a> = + Option<&'a dyn Fn(&ExprIR, &Arena, &SchemaRef) -> Option>>; + +/// The struct is wrapped in a mod to prevent direct member access of `nodes_scratch` +mod inner { + use polars_core::config::verbose; + use polars_utils::arena::Node; + use polars_utils::idx_vec::UnitVec; + use polars_utils::unitvec; + + use super::ExprEval; + + pub struct PredicatePushDown<'a> { + // TODO: Remove unused + #[expect(unused)] + pub(super) expr_eval: ExprEval<'a>, + #[expect(unused)] + pub(super) verbose: bool, + pub(super) block_at_cache: bool, + nodes_scratch: UnitVec, + #[expect(unused)] + pub(super) new_streaming: bool, + } + + impl<'a> PredicatePushDown<'a> { + pub fn new(expr_eval: ExprEval<'a>, new_streaming: bool) -> Self { + Self { + expr_eval, + verbose: verbose(), + block_at_cache: true, + nodes_scratch: unitvec![], + new_streaming, + } + } + + /// Returns shared scratch space after clearing. + pub(super) fn empty_nodes_scratch_mut(&mut self) -> &mut UnitVec { + self.nodes_scratch.clear(); + &mut self.nodes_scratch + } + } +} + +pub use inner::PredicatePushDown; + +impl PredicatePushDown<'_> { + pub(crate) fn block_at_cache(mut self, toggle: bool) -> Self { + self.block_at_cache = toggle; + self + } + + fn optional_apply_predicate( + &mut self, + lp: IR, + local_predicates: Vec, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> IR { + if !local_predicates.is_empty() { + let predicate = combine_predicates(local_predicates.into_iter(), expr_arena); + let input = lp_arena.add(lp); + + IR::Filter { input, predicate } + } else { + lp + } + } + + fn pushdown_and_assign( + &mut self, + input: Node, + acc_predicates: PlHashMap, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult<()> { + let alp = lp_arena.take(input); + let lp = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?; + lp_arena.replace(input, lp); + Ok(()) + } + + /// Filter will be pushed down. + fn pushdown_and_continue( + &mut self, + lp: IR, + mut acc_predicates: PlHashMap, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + has_projections: bool, + ) -> PolarsResult { + let inputs = lp.get_inputs_vec(); + let exprs = lp.get_exprs(); + + if has_projections { + // projections should only have a single input. + if inputs.len() > 1 { + // except for ExtContext + assert!(matches!(lp, IR::ExtContext { .. })); + } + let input = inputs[inputs.len() - 1]; + + let (eligibility, alias_rename_map) = pushdown_eligibility( + &exprs, + &[], + &acc_predicates, + expr_arena, + self.empty_nodes_scratch_mut(), + )?; + + let local_predicates = match eligibility { + PushdownEligibility::Full => vec![], + PushdownEligibility::Partial { to_local } => { + let mut out = Vec::with_capacity(to_local.len()); + for key in to_local { + out.push(acc_predicates.remove(&key).unwrap()); + } + out + }, + PushdownEligibility::NoPushdown => { + return self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena); + }, + }; + + if !alias_rename_map.is_empty() { + for (_, expr_ir) in acc_predicates.iter_mut() { + map_column_references(expr_ir, expr_arena, &alias_rename_map); + } + } + + let alp = lp_arena.take(input); + let alp = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?; + lp_arena.replace(input, alp); + + let lp = lp.with_exprs_and_input(exprs, inputs); + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } else { + let mut local_predicates = Vec::with_capacity(acc_predicates.len()); + + // determine new inputs by pushing down predicates + let new_inputs = inputs + .iter() + .map(|&node| { + // first we check if we are able to push down the predicate passed this node + // it could be that this node just added the column where we base the predicate on + let input_schema = lp_arena.get(node).schema(lp_arena); + let mut pushdown_predicates = + optimizer::init_hashmap(Some(acc_predicates.len())); + for (_, predicate) in acc_predicates.iter() { + // we can pushdown the predicate + if check_input_node(predicate.node(), &input_schema, expr_arena) { + insert_and_combine_predicate( + &mut pushdown_predicates, + predicate, + expr_arena, + ) + } + // we cannot pushdown the predicate we do it here + else { + local_predicates.push(predicate.clone()); + } + } + + let alp = lp_arena.take(node); + let alp = self.push_down(alp, pushdown_predicates, lp_arena, expr_arena)?; + lp_arena.replace(node, alp); + Ok(node) + }) + .collect::>>()?; + + let lp = lp.with_exprs_and_input(exprs, new_inputs); + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } + } + + /// Filter will be done at this node, but we continue optimization + fn no_pushdown_restart_opt( + &mut self, + lp: IR, + acc_predicates: PlHashMap, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let inputs = lp.get_inputs(); + let exprs = lp.get_exprs(); + + let new_inputs = inputs + .iter() + .map(|&node| { + let alp = lp_arena.take(node); + let alp = self.push_down( + alp, + init_hashmap(Some(acc_predicates.len())), + lp_arena, + expr_arena, + )?; + lp_arena.replace(node, alp); + Ok(node) + }) + .collect::>>()?; + let lp = lp.with_exprs_and_input(exprs, new_inputs); + + // all predicates are done locally + let local_predicates = acc_predicates.into_values().collect::>(); + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } + + fn no_pushdown( + &mut self, + lp: IR, + acc_predicates: PlHashMap, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + // all predicates are done locally + let local_predicates = acc_predicates.into_values().collect::>(); + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } + + /// Predicate pushdown optimizer + /// + /// # Arguments + /// + /// * `IR` - Arena based logical plan tree representing the query. + /// * `acc_predicates` - The predicates we accumulate during tree traversal. + /// The hashmap maps from leaf-column name to predicates on that column. + /// If the key is already taken we combine the predicate with a bitand operation. + /// The `Node`s are indexes in the `expr_arena` + /// * `lp_arena` - The local memory arena for the logical plan. + /// * `expr_arena` - The local memory arena for the expressions. + #[recursive] + fn push_down( + &mut self, + lp: IR, + mut acc_predicates: PlHashMap, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + use IR::*; + + match lp { + Filter { + // Note: We assume AND'ed predicates have already been split to separate IR filter + // nodes during DSL conversion so we don't do that here. + ref predicate, + input, + } => { + // Use a tmp_key to avoid inadvertently combining predicates that otherwise would have + // been partially pushed: + // + // (1) .filter(pl.count().over("key") == 1) + // (2) .filter(pl.col("key") == 1) + // + // (2) can be pushed past (1) but they both have the same predicate + // key name in the hashtable. + let tmp_key = temporary_unique_key(&acc_predicates); + acc_predicates.insert(tmp_key.clone(), predicate.clone()); + + let local_predicates = match pushdown_eligibility( + &[], + &[predicate.clone()], + &acc_predicates, + expr_arena, + self.empty_nodes_scratch_mut(), + )? + .0 + { + PushdownEligibility::Full => vec![], + PushdownEligibility::Partial { to_local } => { + let mut out = Vec::with_capacity(to_local.len()); + for key in to_local { + out.push(acc_predicates.remove(&key).unwrap()); + } + out + }, + PushdownEligibility::NoPushdown => { + let out = acc_predicates.drain().map(|t| t.1).collect(); + acc_predicates.clear(); + out + }, + }; + + if let Some(predicate) = acc_predicates.remove(&tmp_key) { + insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena); + } + + let alp = lp_arena.take(input); + let new_input = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?; + + // TODO! + // If a predicates result would be influenced by earlier applied + // predicates, we simply don't pushdown this one passed this node + // However, we can do better and let it pass but store the order of the predicates + // so that we can apply them in correct order at the deepest level + Ok( + self.optional_apply_predicate( + new_input, + local_predicates, + lp_arena, + expr_arena, + ), + ) + }, + DataFrameScan { + df, + schema, + output_schema, + } => { + let selection = predicate_at_scan(acc_predicates, None, expr_arena); + let mut lp = DataFrameScan { + df, + schema, + output_schema, + }; + + if let Some(predicate) = selection { + let input = lp_arena.add(lp); + + lp = IR::Filter { input, predicate } + } + + Ok(lp) + }, + Scan { + sources, + file_info, + hive_parts: scan_hive_parts, + ref predicate, + scan_type, + unified_scan_args, + output_schema, + } => { + let mut blocked_names = Vec::with_capacity(2); + + // TODO: Allow predicates on file names, this should be supported by new-streaming. + if let Some(col) = unified_scan_args.include_file_paths.as_deref() { + blocked_names.push(col); + } + + match &*scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => {}, + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => {}, + _ => { + // Disallow row index pushdown of other scans as they may + // not update the row index properly before applying the + // predicate (e.g. FileScan::Csv doesn't). + if let Some(ref row_index) = unified_scan_args.row_index { + blocked_names.push(row_index.name.as_ref()); + }; + }, + }; + + let local_predicates = if blocked_names.is_empty() { + vec![] + } else { + transfer_to_local_by_name(expr_arena, &mut acc_predicates, |name| { + blocked_names.contains(&name.as_ref()) + }) + }; + let predicate = predicate_at_scan(acc_predicates, predicate.clone(), expr_arena); + + let mut do_optimization = match &*scan_type { + #[cfg(feature = "csv")] + FileScan::Csv { .. } => unified_scan_args.pre_slice.is_none(), + FileScan::Anonymous { function, .. } => function.allows_predicate_pushdown(), + #[cfg(feature = "json")] + FileScan::NDJson { .. } => true, + #[allow(unreachable_patterns)] + _ => true, + }; + do_optimization &= predicate.is_some(); + + let hive_parts = scan_hive_parts; + + let lp = if do_optimization { + Scan { + sources, + file_info, + hive_parts, + predicate, + unified_scan_args, + output_schema, + scan_type, + } + } else { + let lp = Scan { + sources, + file_info, + hive_parts, + predicate: None, + unified_scan_args, + output_schema, + scan_type, + }; + if let Some(predicate) = predicate { + let input = lp_arena.add(lp); + Filter { input, predicate } + } else { + lp + } + }; + + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + }, + Distinct { input, options } => { + let subset = if let Some(ref subset) = options.subset { + subset.as_ref() + } else { + &[] + }; + let mut names_set = PlHashSet::::with_capacity(subset.len()); + for name in subset.iter() { + names_set.insert(name.clone()); + } + + let local_predicates = match options.keep_strategy { + UniqueKeepStrategy::Any => { + let condition = |e: &ExprIR| { + // if not elementwise -> to local + !is_elementwise_rec(e.node(), expr_arena) + }; + transfer_to_local_by_expr_ir(expr_arena, &mut acc_predicates, condition) + }, + UniqueKeepStrategy::First + | UniqueKeepStrategy::Last + | UniqueKeepStrategy::None => { + let condition = |name: &PlSmallStr| { + !subset.is_empty() && !names_set.contains(name.as_str()) + }; + transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition) + }, + }; + + self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?; + let lp = Distinct { input, options }; + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + }, + Join { + input_left, + input_right, + left_on, + right_on, + schema, + options, + } => process_join( + self, + lp_arena, + expr_arena, + input_left, + input_right, + left_on, + right_on, + schema, + options, + acc_predicates, + ), + MapFunction { ref function, .. } => { + if function.allow_predicate_pd() { + match function { + FunctionIR::Rename { existing, new, .. } => { + process_rename(&mut acc_predicates, expr_arena, existing, new); + + self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + ) + }, + FunctionIR::Explode { columns, .. } => { + let condition = |name: &PlSmallStr| columns.iter().any(|s| s == name); + + // first columns that refer to the exploded columns should be done here + let local_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + condition, + ); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + #[cfg(feature = "pivot")] + FunctionIR::Unpivot { args, .. } => { + let variable_name = &args + .variable_name + .clone() + .unwrap_or_else(|| PlSmallStr::from_static("variable")); + let value_name = &args + .value_name + .clone() + .unwrap_or_else(|| PlSmallStr::from_static("value")); + + // predicates that will be done at this level + let condition = |name: &PlSmallStr| { + name == variable_name + || name == value_name + || args.on.iter().any(|s| s == name) + }; + let local_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + condition, + ); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + FunctionIR::Unnest { columns } => { + let exclude = columns.iter().cloned().collect::>(); + + let local_predicates = + transfer_to_local_by_name(expr_arena, &mut acc_predicates, |x| { + exclude.contains(x) + }); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + _ => self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + ), + } + } else { + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + } + }, + GroupBy { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + } => process_group_by( + self, + lp_arena, + expr_arena, + input, + keys, + aggs, + schema, + maintain_order, + apply, + options, + acc_predicates, + ), + lp @ Union { .. } => { + if cfg!(debug_assertions) { + for v in acc_predicates.values() { + let ae = expr_arena.get(v.node()); + assert!(permits_filter_pushdown( + self.empty_nodes_scratch_mut(), + ae, + expr_arena + )); + } + } + + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) + }, + lp @ Sort { .. } => { + if cfg!(debug_assertions) { + for v in acc_predicates.values() { + let ae = expr_arena.get(v.node()); + assert!(permits_filter_pushdown( + self.empty_nodes_scratch_mut(), + ae, + expr_arena + )); + } + } + + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) + }, + lp @ Sink { .. } | lp @ SinkMultiple { .. } => { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) + }, + // Pushed down passed these nodes + lp @ HStack { .. } + | lp @ Select { .. } + | lp @ SimpleProjection { .. } + | lp @ ExtContext { .. } => { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) + }, + // NOT Pushed down passed these nodes + // predicates influence slice sizes + lp @ Slice { .. } => { + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + }, + lp @ HConcat { .. } => { + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + }, + // Caches will run predicate push-down in the `cache_states` run. + Cache { .. } => { + if self.block_at_cache { + self.no_pushdown(lp, acc_predicates, lp_arena, expr_arena) + } else { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) + } + }, + #[cfg(feature = "python")] + PythonScan { mut options } => { + let predicate = predicate_at_scan(acc_predicates, None, expr_arena); + if let Some(predicate) = predicate { + // For IO plugins we only accept streamable expressions as + // we want to apply the predicates to the batches. + if !is_elementwise_rec(predicate.node(), expr_arena) + && matches!(options.python_source, PythonScanSource::IOPlugin) + { + let lp = PythonScan { options }; + return Ok(self.optional_apply_predicate( + lp, + vec![predicate], + lp_arena, + expr_arena, + )); + } + + options.predicate = PythonPredicate::Polars(predicate); + } + Ok(PythonScan { options }) + }, + #[cfg(feature = "merge_sorted")] + lp @ MergeSorted { .. } => { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) + }, + Invalid => unreachable!(), + } + } + + pub(crate) fn optimize( + &mut self, + logical_plan: IR, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let acc_predicates = PlHashMap::new(); + self.push_down(logical_plan, acc_predicates, lp_arena, expr_arena) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs new file mode 100644 index 000000000000..ceaae27c2ee1 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs @@ -0,0 +1,19 @@ +use polars_utils::pl_str::PlSmallStr; + +use super::*; + +pub(super) fn process_rename( + acc_predicates: &mut PlHashMap, + expr_arena: &mut Arena, + existing: &[PlSmallStr], + new: &[PlSmallStr], +) { + let rename_map: PlHashMap = + new.iter().cloned().zip(existing.iter().cloned()).collect(); + + if !rename_map.is_empty() { + for (_, expr_ir) in acc_predicates.iter_mut() { + map_column_references(expr_ir, expr_arena, &rename_map); + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs new file mode 100644 index 000000000000..bbcafdf7ac7a --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs @@ -0,0 +1,460 @@ +use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; + +use super::keys::*; +use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker}; +use crate::prelude::*; +fn combine_by_and(left: Node, right: Node, arena: &mut Arena) -> Node { + arena.add(AExpr::BinaryExpr { + left, + op: Operator::And, + right, + }) +} + +/// Don't overwrite predicates but combine them. +pub(super) fn insert_and_combine_predicate( + acc_predicates: &mut PlHashMap, + predicate: &ExprIR, + arena: &mut Arena, +) { + let name = predicate_to_key(predicate.node(), arena); + + acc_predicates + .entry(name) + .and_modify(|existing_predicate| { + let node = combine_by_and(predicate.node(), existing_predicate.node(), arena); + existing_predicate.set_node(node) + }) + .or_insert_with(|| predicate.clone()); +} + +pub(super) fn temporary_unique_key(acc_predicates: &PlHashMap) -> PlSmallStr { + // TODO: Don't heap allocate during construction. + let mut out_key = '\u{1D17A}'.to_string(); + let mut existing_keys = acc_predicates.keys(); + + while acc_predicates.contains_key(&*out_key) { + out_key.push_str(existing_keys.next().unwrap()); + } + + PlSmallStr::from_string(out_key) +} + +pub(super) fn combine_predicates(iter: I, arena: &mut Arena) -> ExprIR +where + I: Iterator, +{ + let mut single_pred = None; + for e in iter { + single_pred = match single_pred { + None => Some(e.node()), + Some(left) => Some(arena.add(AExpr::BinaryExpr { + left, + op: Operator::And, + right: e.node(), + })), + }; + } + single_pred + .map(|node| ExprIR::from_node(node, arena)) + .expect("an empty iterator was passed") +} + +pub(super) fn predicate_at_scan( + acc_predicates: PlHashMap, + predicate: Option, + expr_arena: &mut Arena, +) -> Option { + if !acc_predicates.is_empty() { + let mut new_predicate = combine_predicates(acc_predicates.into_values(), expr_arena); + if let Some(pred) = predicate { + new_predicate.set_node(combine_by_and( + new_predicate.node(), + pred.node(), + expr_arena, + )); + } + Some(new_predicate) + } else { + None + } +} + +/// Evaluates a condition on the column name inputs of every predicate, where if +/// the condition evaluates to true on any column name the predicate is +/// transferred to local. +pub(super) fn transfer_to_local_by_expr_ir( + expr_arena: &Arena, + acc_predicates: &mut PlHashMap, + mut condition: F, +) -> Vec +where + F: FnMut(&ExprIR) -> bool, +{ + let mut remove_keys = Vec::with_capacity(acc_predicates.len()); + + for predicate in acc_predicates.values() { + if condition(predicate) { + if let Some(name) = aexpr_to_leaf_names_iter(predicate.node(), expr_arena).next() { + remove_keys.push(name); + } + } + } + let mut local_predicates = Vec::with_capacity(remove_keys.len()); + for key in remove_keys { + if let Some(pred) = acc_predicates.remove(&*key) { + local_predicates.push(pred) + } + } + local_predicates +} + +/// Evaluates a condition on the column name inputs of every predicate, where if +/// the condition evaluates to true on any column name the predicate is +/// transferred to local. +pub(super) fn transfer_to_local_by_name( + expr_arena: &Arena, + acc_predicates: &mut PlHashMap, + mut condition: F, +) -> Vec +where + F: FnMut(&PlSmallStr) -> bool, +{ + let mut remove_keys = Vec::with_capacity(acc_predicates.len()); + + for (key, predicate) in &*acc_predicates { + let root_names = aexpr_to_leaf_names_iter(predicate.node(), expr_arena); + for name in root_names { + if condition(&name) { + remove_keys.push(key.clone()); + break; + } + } + } + let mut local_predicates = Vec::with_capacity(remove_keys.len()); + for key in remove_keys { + if let Some(pred) = acc_predicates.remove(&*key) { + local_predicates.push(pred) + } + } + local_predicates +} + +/// * `col(A).alias(B).alias(C) => (C, A)` +/// * `col(A) => (A, A)` +/// * `col(A).sum().alias(B) => None` +fn get_maybe_aliased_projection_to_input_name_map( + e: &ExprIR, + expr_arena: &Arena, +) -> Option<(PlSmallStr, PlSmallStr)> { + let ae = expr_arena.get(e.node()); + match e.get_alias() { + Some(alias) => match ae { + AExpr::Column(c_name) => Some((alias.clone(), c_name.clone())), + _ => None, + }, + _ => match ae { + AExpr::Column(c_name) => Some((c_name.clone(), c_name.clone())), + _ => None, + }, + } +} + +pub enum PushdownEligibility { + Full, + // Partial can happen when there are window exprs. + Partial { to_local: Vec }, + NoPushdown, +} + +#[allow(clippy::type_complexity)] +pub fn pushdown_eligibility( + projection_nodes: &[ExprIR], + new_predicates: &[ExprIR], + acc_predicates: &PlHashMap, + expr_arena: &mut Arena, + scratch: &mut UnitVec, +) -> PolarsResult<(PushdownEligibility, PlHashMap)> { + scratch.clear(); + let ae_nodes_stack = scratch; + + let mut alias_to_col_map = + optimizer::init_hashmap::(Some(projection_nodes.len())); + let mut col_to_alias_map = alias_to_col_map.clone(); + + let mut modified_projection_columns = + PlHashSet::::with_capacity(projection_nodes.len()); + let mut has_window = false; + let mut common_window_inputs = PlHashSet::::new(); + + // Important: Names inserted into any data structure by this function are + // all non-aliased. + // This function returns false if pushdown cannot be performed. + let process_projection_or_predicate = + |ae_nodes_stack: &mut UnitVec, + has_window: &mut bool, + common_window_inputs: &mut PlHashSet| { + debug_assert_eq!(ae_nodes_stack.len(), 1); + + let mut partition_by_names = PlHashSet::::new(); + + while let Some(node) = ae_nodes_stack.pop() { + let ae = expr_arena.get(node); + + match ae { + AExpr::Window { + partition_by, + #[cfg(feature = "dynamic_group_by")] + options, + // The function is not checked for groups-sensitivity because + // it is applied over the windows. + .. + } => { + #[cfg(feature = "dynamic_group_by")] + if matches!(options, WindowType::Rolling(..)) { + return false; + }; + + partition_by_names.clear(); + partition_by_names.reserve(partition_by.len()); + + for node in partition_by.iter() { + // Only accept col() + if let AExpr::Column(name) = expr_arena.get(*node) { + partition_by_names.insert(name.clone()); + } else { + // Nested windows can also qualify for push down. + // e.g.: + // * expr1 = min().over(A) + // * expr2 = sum().over(A, expr1) + // Both exprs window over A, so predicates referring + // to A can still be pushed. + ae_nodes_stack.push(*node); + } + } + + if !*has_window { + for name in partition_by_names.drain() { + common_window_inputs.insert(name); + } + + *has_window = true; + } else { + common_window_inputs.retain(|k| partition_by_names.contains(k)) + } + + // Cannot push into disjoint windows: + // e.g.: + // * sum().over(A) + // * sum().over(B) + if common_window_inputs.is_empty() { + return false; + } + }, + _ => { + if !permits_filter_pushdown(ae_nodes_stack, ae, expr_arena) { + return false; + } + }, + } + } + + true + }; + + for e in projection_nodes.iter() { + if let Some((alias, column_name)) = + get_maybe_aliased_projection_to_input_name_map(e, expr_arena) + { + if alias != column_name { + alias_to_col_map.insert(alias.clone(), column_name.clone()); + col_to_alias_map.insert(column_name, alias); + } + continue; + } + + modified_projection_columns.insert(e.output_name().clone()); + + debug_assert!(ae_nodes_stack.is_empty()); + ae_nodes_stack.push(e.node()); + + if !process_projection_or_predicate( + ae_nodes_stack, + &mut has_window, + &mut common_window_inputs, + ) { + return Ok((PushdownEligibility::NoPushdown, alias_to_col_map)); + } + } + + if has_window && !col_to_alias_map.is_empty() { + // Rename to aliased names. + let mut new = PlHashSet::::with_capacity(2 * common_window_inputs.len()); + + for key in common_window_inputs.into_iter() { + if let Some(aliased) = col_to_alias_map.get(&key) { + new.insert(aliased.clone()); + } + // Ensure predicate does not refer to a different column that + // got aliased to the same name as the window column. E.g.: + // .with_columns(col(A).alias(C), sum=sum().over(C)) + // .filter(col(C) == ..) + if !alias_to_col_map.contains_key(&key) { + new.insert(key); + } + } + + if new.is_empty() { + return Ok((PushdownEligibility::NoPushdown, alias_to_col_map)); + } + + common_window_inputs = new; + } + + for e in new_predicates.iter() { + debug_assert!(ae_nodes_stack.is_empty()); + ae_nodes_stack.push(e.node()); + + if !process_projection_or_predicate( + ae_nodes_stack, + &mut has_window, + &mut common_window_inputs, + ) { + return Ok((PushdownEligibility::NoPushdown, alias_to_col_map)); + } + } + + // Should have returned early. + debug_assert!(!common_window_inputs.is_empty() || !has_window); + + if !has_window && projection_nodes.is_empty() { + return Ok((PushdownEligibility::Full, alias_to_col_map)); + } + + // Note: has_window is constant. + let can_use_column = |col: &str| { + if has_window { + common_window_inputs.contains(col) + } else { + !modified_projection_columns.contains(col) + } + }; + + let to_local = acc_predicates + .iter() + .filter_map(|(key, e)| { + debug_assert!(ae_nodes_stack.is_empty()); + + ae_nodes_stack.push(e.node()); + + let mut can_pushdown = true; + + while let Some(node) = ae_nodes_stack.pop() { + let ae = expr_arena.get(node); + + can_pushdown &= if let AExpr::Column(name) = ae { + can_use_column(name) + } else { + // May still contain window expressions that need to be blocked. + permits_filter_pushdown(ae_nodes_stack, ae, expr_arena) + }; + + if !can_pushdown { + break; + }; + } + + ae_nodes_stack.clear(); + + if !can_pushdown { + Some(key.clone()) + } else { + None + } + }) + .collect::>(); + + match to_local.len() { + 0 => Ok((PushdownEligibility::Full, alias_to_col_map)), + len if len == acc_predicates.len() => { + Ok((PushdownEligibility::NoPushdown, alias_to_col_map)) + }, + _ => Ok((PushdownEligibility::Partial { to_local }, alias_to_col_map)), + } +} + +/// Maps column references within an expression. Used to handle column renaming when pushing +/// predicates. +pub(super) fn map_column_references( + expr: &mut ExprIR, + expr_arena: &mut Arena, + rename_map: &PlHashMap, +) { + if rename_map.is_empty() { + return; + } + + let node = AexprNode::new(expr.node()) + .rewrite( + &mut MapColumnReferences { + rename_map, + column_nodes: PlHashMap::with_capacity(rename_map.len()), + }, + expr_arena, + ) + .unwrap() + .node(); + + expr.set_node(node); + + struct MapColumnReferences<'a> { + rename_map: &'a PlHashMap, + column_nodes: PlHashMap<&'a str, Node>, + } + + impl RewritingVisitor for MapColumnReferences<'_> { + type Node = AexprNode; + type Arena = Arena; + + fn pre_visit( + &mut self, + node: &Self::Node, + arena: &mut Self::Arena, + ) -> polars_core::prelude::PolarsResult { + let AExpr::Column(colname) = arena.get(node.node()) else { + return Ok(RewriteRecursion::NoMutateAndContinue); + }; + + if !self.rename_map.contains_key(colname) { + return Ok(RewriteRecursion::NoMutateAndContinue); + } + + Ok(RewriteRecursion::MutateAndContinue) + } + + fn mutate( + &mut self, + node: Self::Node, + arena: &mut Self::Arena, + ) -> polars_core::prelude::PolarsResult { + let AExpr::Column(colname) = arena.get(node.node()) else { + unreachable!(); + }; + + let new_colname = self.rename_map.get(colname).unwrap(); + + if !self.column_nodes.contains_key(new_colname.as_str()) { + self.column_nodes.insert( + new_colname.as_str(), + arena.add(AExpr::Column(new_colname.clone())), + ); + } + + // Safety: Checked in pre_visit() + Ok(AexprNode::new( + *self.column_nodes.get(new_colname.as_str()).unwrap(), + )) + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs new file mode 100644 index 000000000000..72a451442569 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs @@ -0,0 +1,119 @@ +#[cfg(feature = "pivot")] +mod unpivot; + +#[cfg(feature = "pivot")] +use unpivot::process_unpivot; + +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_functions( + proj_pd: &mut ProjectionPushDown, + input: Node, + function: FunctionIR, + mut ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + use FunctionIR::*; + match function { + Rename { + ref existing, + ref new, + swapping, + schema: _, + } => { + let clear = ctx.has_pushed_down(); + process_rename( + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + existing, + new, + swapping, + )?; + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + + if clear { + function.clear_cached_schema() + } + + let lp = IR::MapFunction { input, function }; + Ok(lp) + }, + Explode { columns, .. } => { + columns + .iter() + .for_each(|name| add_str_to_accumulated(name.clone(), &mut ctx, expr_arena)); + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + Ok(IRBuilder::new(input, expr_arena, lp_arena) + .explode(columns.clone()) + .build()) + }, + #[cfg(feature = "pivot")] + Unpivot { ref args, .. } => { + let lp = IR::MapFunction { + input, + function: function.clone(), + }; + + process_unpivot(proj_pd, lp, args, input, ctx, lp_arena, expr_arena) + }, + _ => { + if function.allow_projection_pd() && ctx.has_pushed_down() { + let original_acc_projection_len = ctx.acc_projections.len(); + + // add columns needed for the function. + for name in function.additional_projection_pd_columns().as_ref() { + let node = expr_arena.add(AExpr::Column(name.clone())); + add_expr_to_accumulated( + node, + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ) + } + let expands_schema = matches!(function, FunctionIR::Unnest { .. }); + + let local_projections = proj_pd.pushdown_and_assign_check_schema( + input, + ctx, + lp_arena, + expr_arena, + expands_schema, + )?; + + // Remove the cached schema + function.clear_cached_schema(); + let lp = IR::MapFunction { + input, + function: function.clone(), + }; + + if local_projections.is_empty() { + Ok(lp) + } else { + // if we would project, we would remove pushed down predicates + if local_projections.len() < original_acc_projection_len { + Ok(IRBuilder::from_lp(lp, expr_arena, lp_arena) + .with_columns_simple(local_projections, Default::default()) + .build()) + // all projections are local + } else { + Ok(IRBuilder::from_lp(lp, expr_arena, lp_arena) + .project_simple_nodes(local_projections) + .unwrap() + .build()) + } + } + } else { + let lp = IR::MapFunction { + input, + function: function.clone(), + }; + // restart projection pushdown + proj_pd.no_pushdown_restart_opt(lp, ctx, lp_arena, expr_arena) + } + }, + } +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs new file mode 100644 index 000000000000..85e37cc6090d --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs @@ -0,0 +1,53 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_unpivot( + proj_pd: &mut ProjectionPushDown, + lp: IR, + args: &Arc, + input: Node, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + if args.on.is_empty() { + // restart projection pushdown + proj_pd.no_pushdown_restart_opt(lp, ctx, lp_arena, expr_arena) + } else { + let (acc_projections, mut local_projections, projected_names) = split_acc_projections( + ctx.acc_projections, + lp_arena.get(input).schema(lp_arena).as_ref(), + expr_arena, + false, + ); + + if !local_projections.is_empty() { + local_projections.extend_from_slice(&acc_projections); + } + let mut ctx = ProjectionContext::new(acc_projections, projected_names, ctx.inner); + + // make sure that the requested columns are projected + args.index + .iter() + .for_each(|name| add_str_to_accumulated(name.clone(), &mut ctx, expr_arena)); + args.on + .iter() + .for_each(|name| add_str_to_accumulated(name.clone(), &mut ctx, expr_arena)); + + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + + // re-make unpivot node so that the schema is updated + let lp = IRBuilder::new(input, expr_arena, lp_arena) + .unpivot(args.clone()) + .build(); + + if local_projections.is_empty() { + Ok(lp) + } else { + Ok(IRBuilder::from_lp(lp, expr_arena, lp_arena) + .project_simple_nodes(local_projections) + .unwrap() + .build()) + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs new file mode 100644 index 000000000000..2e3ee21959df --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs @@ -0,0 +1,50 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_generic( + proj_pd: &mut ProjectionPushDown, + lp: IR, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + let inputs = lp.get_inputs(); + let exprs = lp.get_exprs(); + + // let mut first_schema = None; + // let mut names = None; + + let new_inputs = inputs + .iter() + .map(|&node| { + let alp = lp_arena.take(node); + let mut alp = proj_pd.push_down(alp, ctx.clone(), lp_arena, expr_arena)?; + + // double projection can mess up the schema ordering + // here we ensure the ordering is maintained. + // + // Consider this query + // df1 => a, b + // df2 => a, b + // + // df3 = df1.join(df2, on = a, b) + // + // concat([df1, df3]).select(a) + // + // schema after projection pd + // df3 => a, b + // df1 => a + // so we ensure we do the 'a' projection again before we concatenate + if !ctx.acc_projections.is_empty() && inputs.len() > 1 { + alp = IRBuilder::from_lp(alp, expr_arena, lp_arena) + .project_simple_nodes(ctx.acc_projections.iter().map(|e| e.0)) + .unwrap() + .build(); + } + lp_arena.replace(node, alp); + Ok(node) + }) + .collect::>>()?; + + Ok(lp.with_exprs_and_input(exprs, new_inputs)) +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs new file mode 100644 index 000000000000..7bf9ac41ef62 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs @@ -0,0 +1,91 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_group_by( + proj_pd: &mut ProjectionPushDown, + input: Node, + keys: Vec, + aggs: Vec, + apply: Option>, + schema: SchemaRef, + maintain_order: bool, + options: Arc, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + use IR::*; + + // the custom function may need all columns so we do the projections here. + if let Some(f) = apply { + let lp = GroupBy { + input, + keys, + aggs, + schema, + apply: Some(f), + maintain_order, + options, + }; + let input = lp_arena.add(lp); + + let builder = IRBuilder::new(input, expr_arena, lp_arena); + Ok(proj_pd.finish_node_simple_projection(&ctx.acc_projections, builder)) + } else { + let has_pushed_down = ctx.has_pushed_down(); + + // TODO! remove unnecessary vec alloc. + let (mut acc_projections, _local_projections, mut names) = split_acc_projections( + ctx.acc_projections, + lp_arena.get(input).schema(lp_arena).as_ref(), + expr_arena, + false, + ); + + // add the columns used in the aggregations to the projection only if they are used upstream + let projected_aggs = aggs + .into_iter() + .filter(|agg| { + if has_pushed_down && ctx.inner.projections_seen > 0 { + ctx.projected_names.contains(agg.output_name()) + } else { + true + } + }) + .collect::>(); + + for agg in &projected_aggs { + add_expr_to_accumulated(agg.node(), &mut acc_projections, &mut names, expr_arena); + } + + // make sure the keys are projected + for key in &*keys { + add_expr_to_accumulated(key.node(), &mut acc_projections, &mut names, expr_arena); + } + + // make sure that the dynamic key is projected + #[cfg(feature = "dynamic_group_by")] + if let Some(options) = &options.dynamic { + let node = expr_arena.add(AExpr::Column(options.index_column.clone())); + add_expr_to_accumulated(node, &mut acc_projections, &mut names, expr_arena); + } + // make sure that the rolling key is projected + #[cfg(feature = "dynamic_group_by")] + if let Some(options) = &options.rolling { + let node = expr_arena.add(AExpr::Column(options.index_column.clone())); + add_expr_to_accumulated(node, &mut acc_projections, &mut names, expr_arena); + } + let ctx = ProjectionContext::new(acc_projections, names, ctx.inner); + + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + + let builder = IRBuilder::new(input, expr_arena, lp_arena).group_by( + keys, + projected_aggs, + apply, + maintain_order, + options, + ); + Ok(builder.build()) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hconcat.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hconcat.rs new file mode 100644 index 000000000000..b3e3c5ace987 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hconcat.rs @@ -0,0 +1,57 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_hconcat( + proj_pd: &mut ProjectionPushDown, + inputs: Vec, + schema: SchemaRef, + options: HConcatOptions, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + // When applying projection pushdown to horizontal concatenation, + // we apply pushdown to all of the inputs using the subset of accumulated projections relevant to each input, + // then rebuild the concatenated schema. + + let schema = if ctx.acc_projections.is_empty() { + schema + } else { + let mut remaining_projections: PlHashSet<_> = ctx.acc_projections.into_iter().collect(); + + for input in inputs.iter() { + let mut input_pushdown = Vec::new(); + let input_schema = lp_arena.get(*input).schema(lp_arena); + + for proj in remaining_projections.iter() { + if check_input_column_node(*proj, input_schema.as_ref(), expr_arena) { + input_pushdown.push(*proj); + } + } + + let mut input_names = PlHashSet::new(); + for proj in &input_pushdown { + remaining_projections.remove(proj); + for name in aexpr_to_leaf_names(proj.0, expr_arena) { + input_names.insert(name); + } + } + let ctx = ProjectionContext::new(input_pushdown, input_names, ctx.inner); + proj_pd.pushdown_and_assign(*input, ctx, lp_arena, expr_arena)?; + } + + let mut schemas = Vec::with_capacity(inputs.len()); + for input in inputs.iter() { + let schema = lp_arena.get(*input).schema(lp_arena).into_owned(); + schemas.push(schema); + } + let new_schema = merge_schemas(&schemas)?; + Arc::new(new_schema) + }; + + Ok(IR::HConcat { + inputs, + schema, + options, + }) +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs new file mode 100644 index 000000000000..b459628a31b1 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs @@ -0,0 +1,67 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_hstack( + proj_pd: &mut ProjectionPushDown, + input: Node, + mut exprs: Vec, + options: ProjectionOptions, + mut ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + if ctx.has_pushed_down() { + let mut pruned_with_cols = Vec::with_capacity(exprs.len()); + + // Check if output names are used upstream + // if not, we can prune the `with_column` expression + // as it is not used in the output. + for e in exprs { + let is_used_upstream = ctx.projected_names.contains(e.output_name()); + if is_used_upstream { + pruned_with_cols.push(e); + } + } + + if pruned_with_cols.is_empty() { + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + return Ok(lp_arena.take(input)); + } + + // Make sure that columns selected with_columns are available + // only if not empty. If empty we already select everything. + for e in &pruned_with_cols { + add_expr_to_accumulated( + e.node(), + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ); + } + + exprs = pruned_with_cols + } + // projections that select columns added by + // this `with_column` operation can be dropped + // For instance in: + // + // q + // .with_column(col("a").alias("b") + // .select(["a", "b"]) + // + // we can drop the "b" projection at this level + let (acc_projections, _, names) = split_acc_projections( + ctx.acc_projections, + &lp_arena.get(input).schema(lp_arena), + expr_arena, + true, // expands_schema + ); + + let ctx = ProjectionContext::new(acc_projections, names, ctx.inner); + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + + let lp = IRBuilder::new(input, expr_arena, lp_arena) + .with_columns(exprs, options) + .build(); + Ok(lp) +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs new file mode 100644 index 000000000000..f6882aadb340 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs @@ -0,0 +1,485 @@ +#![allow(clippy::too_many_arguments)] +use std::borrow::Cow; + +use super::*; +use crate::prelude::optimizer::join_utils::split_suffix; + +fn add_keys_to_accumulated_state( + expr: Node, + acc_projections: &mut Vec, + local_projection: &mut Vec, + projected_names: &mut PlHashSet, + expr_arena: &mut Arena, + // Only for left hand side table we add local names. + add_local: bool, +) -> Option { + add_expr_to_accumulated(expr, acc_projections, projected_names, expr_arena); + // The projections may do more than simply project. + // e.g. col("foo").truncate() * col("bar") + // that means we don't want to execute the projection as that is already done by + // the JOIN executor + if add_local { + // return the left most name as output name + let names = aexpr_to_leaf_names_iter(expr, expr_arena).collect::>(); + let output_name = names.first().cloned(); + for name in names { + let node = expr_arena.add(AExpr::Column(name)); + local_projection.push(ColumnNode(node)); + } + output_name + } else { + None + } +} + +#[cfg(feature = "asof_join")] +pub(super) fn process_asof_join( + proj_pd: &mut ProjectionPushDown, + input_left: Node, + input_right: Node, + left_on: Vec, + right_on: Vec, + options: Arc, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + join_schema: &Schema, +) -> PolarsResult { + // n = 0 if no projections, so we don't allocate unneeded + let n = ctx.acc_projections.len() * 2; + let mut pushdown_left = Vec::with_capacity(n); + let mut pushdown_right = Vec::with_capacity(n); + let mut names_left = PlHashSet::with_capacity(n); + let mut names_right = PlHashSet::with_capacity(n); + let mut local_projection = Vec::with_capacity(n); + + let JoinType::AsOf(asof_options) = &options.args.how else { + unreachable!() + }; + + // if there are no projections we don't have to do anything (all columns are projected) + // otherwise we build local projections to sort out proper column names due to the + // join operation + // + // Joins on columns with different names, for example + // left_on = "a", right_on = "b + // will remove the name "b" (it is "a" now). That columns should therefore not + // be added to a local projection. + if ctx.has_pushed_down() { + let schema_left = lp_arena.get(input_left).schema(lp_arena); + let schema_right = lp_arena.get(input_right).schema(lp_arena); + + // make sure that the asof join 'by' columns are projected + if let (Some(left_by), Some(right_by)) = (&asof_options.left_by, &asof_options.right_by) { + for name in left_by { + let add = ctx.projected_names.contains(name.as_str()); + + let node = expr_arena.add(AExpr::Column(name.clone())); + add_keys_to_accumulated_state( + node, + &mut pushdown_left, + &mut local_projection, + &mut names_left, + expr_arena, + add, + ); + } + for name in right_by { + let node = expr_arena.add(AExpr::Column(name.clone())); + add_keys_to_accumulated_state( + node, + &mut pushdown_right, + &mut local_projection, + &mut names_right, + expr_arena, + false, + ); + } + } + + // The join on keys can lead that columns are already added, we don't want to create + // duplicates so store the names. + let mut local_projected_names = PlHashSet::new(); + + // We need the join columns so we push the projection downwards + for e in &left_on { + let local_name = add_keys_to_accumulated_state( + e.node(), + &mut pushdown_left, + &mut local_projection, + &mut names_left, + expr_arena, + true, + ) + .unwrap(); + local_projected_names.insert(local_name); + } + // this differs from normal joins, as in `asof_joins` + // both columns remain. So `add_local=true` also for the right table + for e in &right_on { + if let Some(local_name) = add_keys_to_accumulated_state( + e.node(), + &mut pushdown_right, + &mut local_projection, + &mut names_right, + expr_arena, + true, + ) { + // insert the name. + // if name was already added we pop the local projection + // otherwise we would project duplicate columns + if !local_projected_names.insert(local_name) { + local_projection.pop(); + } + }; + } + + for proj in ctx.acc_projections { + let add_local = if local_projected_names.is_empty() { + true + } else { + let name = column_node_to_name(proj, expr_arena); + !local_projected_names.contains(name) + }; + + process_projection( + proj_pd, + &schema_left, + &schema_right, + proj, + &mut pushdown_left, + &mut pushdown_right, + &mut names_left, + &mut names_right, + expr_arena, + &mut local_projection, + add_local, + &options, + join_schema, + ); + } + } + + let ctx_left = ProjectionContext::new(pushdown_left, names_left, ctx.inner); + let ctx_right = ProjectionContext::new(pushdown_right, names_right, ctx.inner); + + proj_pd.pushdown_and_assign(input_left, ctx_left, lp_arena, expr_arena)?; + proj_pd.pushdown_and_assign(input_right, ctx_right, lp_arena, expr_arena)?; + + resolve_join_suffixes( + input_left, + input_right, + left_on, + right_on, + options, + lp_arena, + expr_arena, + &local_projection, + ) +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_join( + proj_pd: &mut ProjectionPushDown, + input_left: Node, + input_right: Node, + left_on: Vec, + right_on: Vec, + mut options: Arc, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + join_schema: &Schema, +) -> PolarsResult { + #[cfg(feature = "asof_join")] + if matches!(options.args.how, JoinType::AsOf(_)) { + return process_asof_join( + proj_pd, + input_left, + input_right, + left_on, + right_on, + options, + ctx, + lp_arena, + expr_arena, + join_schema, + ); + } + + // n = 0 if no projections, so we don't allocate unneeded + let n = ctx.acc_projections.len() * 2; + let mut pushdown_left = Vec::with_capacity(n); + let mut pushdown_right = Vec::with_capacity(n); + let mut names_left = PlHashSet::with_capacity(n); + let mut names_right = PlHashSet::with_capacity(n); + let mut local_projection = Vec::with_capacity(n); + + // If there are no projections we don't have to do anything (all columns are projected) + // otherwise we build local projections to sort out proper column names due to the + // join operation + // + // Joins on columns with different names, for example + // left_on = "a", right_on = "b + // will remove the name "b" (it is "a" now). That columns should therefore not + // be added to a local projection. + if ctx.has_pushed_down() { + let schema_left = lp_arena.get(input_left).schema(lp_arena); + let schema_right = lp_arena.get(input_right).schema(lp_arena); + + // The join on keys can lead that columns are already added, we don't want to create + // duplicates so store the names. + let mut local_projected_names = PlHashSet::new(); + + // We need the join columns so we push the projection downwards + for e in &left_on { + if !local_projected_names.insert(e.output_name().clone()) { + // A join can have multiple leaf names, so we must still ensure all leaf names are projected. + if options.args.how.is_ie() { + add_expr_to_accumulated( + e.node(), + &mut pushdown_left, + &mut names_left, + expr_arena, + ); + } + + continue; + } + + let _ = add_keys_to_accumulated_state( + e.node(), + &mut pushdown_left, + &mut local_projection, + &mut names_left, + expr_arena, + true, + ); + } + + // For left and inner joins we can set `coalesce` to `true` if the rhs key columns are not projected. + // This saves a materialization. + if !options.args.should_coalesce() + && matches!(options.args.how, JoinType::Left | JoinType::Inner) + { + let mut allow_opt = true; + let non_coalesced_key_is_used = right_on.iter().any(|e| { + // Inline expressions other than col should not coalesce. + if !matches!(expr_arena.get(e.node()), AExpr::Column(_)) { + allow_opt = false; + return true; + } + let key_name = e.output_name(); + + // If the name is in the lhs table, a suffix is added. + let key_name_after_join = if schema_left.contains(key_name) { + Cow::Owned(_join_suffix_name(key_name, options.args.suffix())) + } else { + Cow::Borrowed(key_name) + }; + + ctx.projected_names.contains(key_name_after_join.as_ref()) + }); + + // If they key is not used, coalesce the columns as that is often cheaper. + if !non_coalesced_key_is_used && allow_opt { + let options = Arc::make_mut(&mut options); + options.args.coalesce = JoinCoalesce::CoalesceColumns; + } + } + + // In both columns remain. So `add_local=true` also for the right table + let add_local = !options.args.should_coalesce(); + for e in &right_on { + // In case of full outer joins we also add the columns. + // But before we do that we must check if the column wasn't already added by the lhs. + let add_local = if add_local { + !local_projected_names.contains(e.output_name()) + } else { + false + }; + + let local_name = add_keys_to_accumulated_state( + e.node(), + &mut pushdown_right, + &mut local_projection, + &mut names_right, + expr_arena, + add_local, + ); + + if let Some(local_name) = local_name { + local_projected_names.insert(local_name); + } + } + + for proj in ctx.acc_projections { + let add_local = if local_projected_names.is_empty() { + true + } else { + let name = column_node_to_name(proj, expr_arena); + !local_projected_names.contains(name) + }; + + process_projection( + proj_pd, + &schema_left, + &schema_right, + proj, + &mut pushdown_left, + &mut pushdown_right, + &mut names_left, + &mut names_right, + expr_arena, + &mut local_projection, + add_local, + &options, + join_schema, + ); + } + } + + let ctx_left = ProjectionContext::new(pushdown_left, names_left, ctx.inner); + let ctx_right = ProjectionContext::new(pushdown_right, names_right, ctx.inner); + + proj_pd.pushdown_and_assign(input_left, ctx_left, lp_arena, expr_arena)?; + proj_pd.pushdown_and_assign(input_right, ctx_right, lp_arena, expr_arena)?; + + resolve_join_suffixes( + input_left, + input_right, + left_on, + right_on, + options, + lp_arena, + expr_arena, + &local_projection, + ) +} + +fn process_projection( + proj_pd: &mut ProjectionPushDown, + schema_left: &Schema, + schema_right: &Schema, + proj: ColumnNode, + pushdown_left: &mut Vec, + pushdown_right: &mut Vec, + names_left: &mut PlHashSet, + names_right: &mut PlHashSet, + expr_arena: &mut Arena, + local_projection: &mut Vec, + add_local: bool, + options: &JoinOptions, + join_schema: &Schema, +) { + // Path for renamed columns due to the join. The column name of the left table + // stays as is, the column of the right will have the "_right" suffix. + // Thus joining two tables with both a foo column leads to ["foo", "foo_right"] + + // try to push down projection in either of two tables + let (pushed_at_least_once, already_projected) = proj_pd.join_push_down( + schema_left, + schema_right, + proj, + pushdown_left, + pushdown_right, + names_left, + names_right, + expr_arena, + ); + + if !(pushed_at_least_once || already_projected) + // did not succeed push down in any tables., + // this might be due to the suffix in the projection name + // this branch tries to pushdown the column without suffix + { + // Column name of the projection without any alias. + let leaf_column_name = column_node_to_name(proj, expr_arena).clone(); + + let suffix = options.args.suffix().as_str(); + // If _right suffix exists we need to push a projection down without this + // suffix. + if leaf_column_name.ends_with(suffix) && join_schema.contains(leaf_column_name.as_ref()) { + // downwards name is the name without the _right i.e. "foo". + let downwards_name = split_suffix(leaf_column_name.as_ref(), suffix); + let downwards_name = PlSmallStr::from_str(downwards_name); + + let downwards_name_column = expr_arena.add(AExpr::Column(downwards_name.clone())); + // project downwards and locally immediately alias to prevent wrong projections + if names_right.insert(downwards_name) { + pushdown_right.push(ColumnNode(downwards_name_column)); + } + local_projection.push(proj); + } + } + // did succeed pushdown at least in any of the two tables + // if not already added locally we ensure we project local as well + else if add_local && pushed_at_least_once { + // always also do the projection locally, because the join columns may not be + // included in the projection. + // for instance: + // + // SELECT [COLUMN temp] + // FROM + // JOIN (["days", "temp"]) WITH (["days", "rain"]) ON (left: days right: days) + // + // should drop the days column after the join. + local_projection.push(proj); + } +} + +// Because we do a projection pushdown +// We may influence the suffixes. +// For instance if a join would have created a schema +// +// "foo", "foo_right" +// +// but we only project the "foo_right" column, the join will not produce +// a "name_right" because we did not project its left name duplicate "foo" +// +// The code below checks if can do the suffixed projections on the schema that +// we have after the join. If we cannot then we modify the projection: +// +// col("foo_right") to col("foo").alias("foo_right") +fn resolve_join_suffixes( + input_left: Node, + input_right: Node, + left_on: Vec, + right_on: Vec, + options: Arc, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + local_projection: &[ColumnNode], +) -> PolarsResult { + let suffix = options.args.suffix().as_str(); + let alp = IRBuilder::new(input_left, expr_arena, lp_arena) + .join(input_right, left_on, right_on, options.clone()) + .build(); + let schema_after_join = alp.schema(lp_arena); + + let mut all_columns = true; + let projections = local_projection + .iter() + .map(|proj| { + let name = column_node_to_name(*proj, expr_arena).clone(); + if name.ends_with(suffix) && schema_after_join.get(&name).is_none() { + let downstream_name = &name.as_str()[..name.len() - suffix.len()]; + let col = AExpr::Column(downstream_name.into()); + let node = expr_arena.add(col); + all_columns = false; + ExprIR::new(node, OutputName::Alias(name.clone())) + } else { + ExprIR::new(proj.0, OutputName::ColumnLhs(name.clone())) + } + }) + .collect::>(); + + let builder = IRBuilder::from_lp(alp, expr_arena, lp_arena); + Ok(if all_columns { + builder + .project_simple(projections.iter().map(|e| e.output_name().clone()))? + .build() + } else { + builder.project(projections, Default::default()).build() + }) +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs new file mode 100644 index 000000000000..ab77067d5ddc --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -0,0 +1,773 @@ +mod functions; +mod generic; +mod group_by; +mod hconcat; +mod hstack; +mod joins; +mod projection; +mod rename; +#[cfg(feature = "semi_anti_join")] +mod semi_anti_join; + +use polars_core::datatypes::PlHashSet; +use polars_core::prelude::*; +use polars_io::RowIndex; +use recursive::recursive; +#[cfg(feature = "semi_anti_join")] +use semi_anti_join::process_semi_anti_join; + +use crate::prelude::optimizer::projection_pushdown::generic::process_generic; +use crate::prelude::optimizer::projection_pushdown::group_by::process_group_by; +use crate::prelude::optimizer::projection_pushdown::hconcat::process_hconcat; +use crate::prelude::optimizer::projection_pushdown::hstack::process_hstack; +use crate::prelude::optimizer::projection_pushdown::joins::process_join; +use crate::prelude::optimizer::projection_pushdown::projection::process_projection; +use crate::prelude::optimizer::projection_pushdown::rename::process_rename; +use crate::prelude::*; +use crate::utils::aexpr_to_leaf_names; + +#[derive(Default, Copy, Clone)] +struct ProjectionCopyState { + projections_seen: usize, + is_count_star: bool, +} + +#[derive(Clone, Default)] +struct ProjectionContext { + acc_projections: Vec, + projected_names: PlHashSet, + inner: ProjectionCopyState, +} + +impl ProjectionContext { + fn new( + acc_projections: Vec, + projected_names: PlHashSet, + inner: ProjectionCopyState, + ) -> Self { + Self { + acc_projections, + projected_names, + inner, + } + } + + /// If this is `true`, other nodes should add the columns + /// they need to the push down state + fn has_pushed_down(&self) -> bool { + // count star also acts like a pushdown as we will select a single column at the source + // when there were no other projections. + !self.acc_projections.is_empty() || self.inner.is_count_star + } + + fn process_count_star_at_scan(&mut self, schema: &Schema, expr_arena: &mut Arena) { + if self.acc_projections.is_empty() { + let (name, _dt) = match schema.len() { + 0 => return, + 1 => schema.get_at_index(0).unwrap(), + _ => { + // skip first as that can be the row index. + // We look for a relative cheap type, such as a numeric or bool + schema + .iter() + .skip(1) + .find(|(_name, dt)| { + let phys = dt; + phys.is_null() + || phys.is_primitive_numeric() + || phys.is_bool() + || phys.is_temporal() + }) + .unwrap_or_else(|| schema.get_at_index(schema.len() - 1).unwrap()) + }, + }; + + let node = expr_arena.add(AExpr::Column(name.clone())); + self.acc_projections.push(ColumnNode(node)); + self.projected_names.insert(name.clone()); + } + } +} + +/// utility function to get names of the columns needed in projection at scan level +fn get_scan_columns( + acc_projections: &[ColumnNode], + expr_arena: &Arena, + row_index: Option<&RowIndex>, + file_path_col: Option<&str>, +) -> Option> { + if !acc_projections.is_empty() { + Some( + acc_projections + .iter() + .filter_map(|node| { + let name = column_node_to_name(*node, expr_arena); + + if let Some(ri) = row_index { + if ri.name == name { + return None; + } + } + + if let Some(file_path_col) = file_path_col { + if file_path_col == name.as_str() { + return None; + } + } + + Some(name.clone()) + }) + .collect::>(), + ) + } else { + None + } +} + +/// split in a projection vec that can be pushed down and a projection vec that should be used +/// in this node +/// +/// # Returns +/// accumulated_projections, local_projections, accumulated_names +/// +/// - `expands_schema`. An unnest adds more columns to a schema, so we cannot use fast path +fn split_acc_projections( + acc_projections: Vec, + down_schema: &Schema, + expr_arena: &Arena, + expands_schema: bool, +) -> (Vec, Vec, PlHashSet) { + // If node above has as many columns as the projection there is nothing to pushdown. + if !expands_schema && down_schema.len() == acc_projections.len() { + let local_projections = acc_projections; + (vec![], local_projections, PlHashSet::new()) + } else { + let (acc_projections, local_projections): (Vec<_>, Vec<_>) = acc_projections + .into_iter() + .partition(|expr| check_input_column_node(*expr, down_schema, expr_arena)); + let mut names = PlHashSet::default(); + for proj in &acc_projections { + let name = column_node_to_name(*proj, expr_arena).clone(); + names.insert(name); + } + (acc_projections, local_projections, names) + } +} + +/// utility function such that we can recurse all binary expressions in the expression tree +fn add_expr_to_accumulated( + expr: Node, + acc_projections: &mut Vec, + projected_names: &mut PlHashSet, + expr_arena: &Arena, +) { + for root_node in aexpr_to_column_nodes_iter(expr, expr_arena) { + let name = column_node_to_name(root_node, expr_arena).clone(); + if projected_names.insert(name) { + acc_projections.push(root_node) + } + } +} + +fn add_str_to_accumulated( + name: PlSmallStr, + ctx: &mut ProjectionContext, + expr_arena: &mut Arena, +) { + // if not pushed down: all columns are already projected. + if ctx.has_pushed_down() && !ctx.projected_names.contains(&name) { + let node = expr_arena.add(AExpr::Column(name)); + add_expr_to_accumulated( + node, + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ); + } +} + +fn update_scan_schema( + acc_projections: &[ColumnNode], + expr_arena: &Arena, + schema: &Schema, + sort_projections: bool, +) -> PolarsResult { + let mut new_schema = Schema::with_capacity(acc_projections.len()); + let mut new_cols = Vec::with_capacity(acc_projections.len()); + for node in acc_projections.iter() { + let name = column_node_to_name(*node, expr_arena); + let item = schema.try_get_full(name)?; + new_cols.push(item); + } + // make sure that the projections are sorted by the schema. + if sort_projections { + new_cols.sort_unstable_by_key(|item| item.0); + } + for item in new_cols { + new_schema.with_column(item.1.clone(), item.2.clone()); + } + Ok(new_schema) +} + +pub struct ProjectionPushDown { + pub is_count_star: bool, +} + +impl ProjectionPushDown { + pub(super) fn new() -> Self { + Self { + is_count_star: false, + } + } + + /// Projection will be done at this node, but we continue optimization + fn no_pushdown_restart_opt( + &mut self, + lp: IR, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let inputs = lp.get_inputs(); + let exprs = lp.get_exprs(); + + let new_inputs = inputs + .iter() + .map(|&node| { + let alp = lp_arena.take(node); + let ctx = ProjectionContext::new(Default::default(), Default::default(), ctx.inner); + let alp = self.push_down(alp, ctx, lp_arena, expr_arena)?; + lp_arena.replace(node, alp); + Ok(node) + }) + .collect::>>()?; + let lp = lp.with_exprs_and_input(exprs, new_inputs); + + let builder = IRBuilder::from_lp(lp, expr_arena, lp_arena); + Ok(self.finish_node_simple_projection(&ctx.acc_projections, builder)) + } + + fn finish_node_simple_projection( + &mut self, + local_projections: &[ColumnNode], + builder: IRBuilder, + ) -> IR { + if !local_projections.is_empty() { + builder + .project_simple_nodes(local_projections.iter().map(|node| node.0)) + .unwrap() + .build() + } else { + builder.build() + } + } + + fn finish_node(&mut self, local_projections: Vec, builder: IRBuilder) -> IR { + if !local_projections.is_empty() { + builder + .project(local_projections, Default::default()) + .build() + } else { + builder.build() + } + } + + #[allow(clippy::too_many_arguments)] + fn join_push_down( + &mut self, + schema_left: &Schema, + schema_right: &Schema, + proj: ColumnNode, + pushdown_left: &mut Vec, + pushdown_right: &mut Vec, + names_left: &mut PlHashSet, + names_right: &mut PlHashSet, + expr_arena: &Arena, + ) -> (bool, bool) { + let mut pushed_at_least_one = false; + let mut already_projected = false; + + let name = column_node_to_name(proj, expr_arena); + let is_in_left = names_left.contains(name); + let is_in_right = names_right.contains(name); + already_projected |= is_in_left; + already_projected |= is_in_right; + + if check_input_column_node(proj, schema_left, expr_arena) && !is_in_left { + names_left.insert(name.clone()); + pushdown_left.push(proj); + pushed_at_least_one = true; + } + if check_input_column_node(proj, schema_right, expr_arena) && !is_in_right { + names_right.insert(name.clone()); + pushdown_right.push(proj); + pushed_at_least_one = true; + } + + (pushed_at_least_one, already_projected) + } + + /// This pushes down current node and assigns the result to this node. + fn pushdown_and_assign( + &mut self, + input: Node, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult<()> { + let alp = lp_arena.take(input); + let lp = self.push_down(alp, ctx, lp_arena, expr_arena)?; + lp_arena.replace(input, lp); + Ok(()) + } + + /// This pushes down the projection that are validated + /// that they can be done successful at the schema above + /// The result is assigned to this node. + /// + /// The local projections are return and still have to be applied + fn pushdown_and_assign_check_schema( + &mut self, + input: Node, + mut ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + // an unnest changes/expands the schema + expands_schema: bool, + ) -> PolarsResult> { + let alp = lp_arena.take(input); + let down_schema = alp.schema(lp_arena); + + let (acc_projections, local_projections, names) = split_acc_projections( + ctx.acc_projections, + &down_schema, + expr_arena, + expands_schema, + ); + + ctx.acc_projections = acc_projections; + ctx.projected_names = names; + + let lp = self.push_down(alp, ctx, lp_arena, expr_arena)?; + lp_arena.replace(input, lp); + Ok(local_projections) + } + + /// Projection pushdown optimizer + /// + /// # Arguments + /// + /// * `IR` - Arena based logical plan tree representing the query. + /// * `acc_projections` - The projections we accumulate during tree traversal. + /// * `names` - We keep track of the names to ensure we don't do duplicate projections. + /// * `projections_seen` - Count the number of projection operations during tree traversal. + /// * `lp_arena` - The local memory arena for the logical plan. + /// * `expr_arena` - The local memory arena for the expressions. + #[recursive] + fn push_down( + &mut self, + logical_plan: IR, + mut ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + use IR::*; + + match logical_plan { + Select { expr, input, .. } => { + process_projection(self, input, expr, ctx, lp_arena, expr_arena, false) + }, + SimpleProjection { columns, input, .. } => { + let exprs = names_to_expr_irs(columns.iter_names_cloned(), expr_arena); + process_projection(self, input, exprs, ctx, lp_arena, expr_arena, true) + }, + DataFrameScan { + df, + schema, + mut output_schema, + .. + } => { + // TODO: Just project 0-width morsels. + if self.is_count_star { + ctx.process_count_star_at_scan(&schema, expr_arena); + } + if ctx.has_pushed_down() { + output_schema = Some(Arc::new(update_scan_schema( + &ctx.acc_projections, + expr_arena, + &schema, + false, + )?)); + } + let lp = DataFrameScan { + df, + schema, + output_schema, + }; + Ok(lp) + }, + #[cfg(feature = "python")] + PythonScan { mut options } => { + if self.is_count_star { + ctx.process_count_star_at_scan(&options.schema, expr_arena); + } + + options.with_columns = + get_scan_columns(&ctx.acc_projections, expr_arena, None, None); + + options.output_schema = if options.with_columns.is_none() { + None + } else { + Some(Arc::new(update_scan_schema( + &ctx.acc_projections, + expr_arena, + &options.schema, + true, + )?)) + }; + Ok(PythonScan { options }) + }, + Scan { + sources, + mut file_info, + hive_parts, + scan_type, + predicate, + mut unified_scan_args, + mut output_schema, + } => { + let do_optimization = match &*scan_type { + FileScan::Anonymous { function, .. } => function.allows_projection_pushdown(), + #[cfg(feature = "json")] + FileScan::NDJson { .. } => true, + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => true, + #[cfg(feature = "csv")] + FileScan::Csv { .. } => true, + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => true, + }; + + #[expect(clippy::never_loop)] + loop { + if !do_optimization { + break; + } + + if self.is_count_star { + if let FileScan::Anonymous { .. } = &*scan_type { + // Anonymous scan is not controlled by us, we don't know if it can support + // 0-column projections, so we always project one. + use either::Either; + + let projection: Arc<[PlSmallStr]> = match &file_info.reader_schema { + Some(Either::Left(s)) => s.iter_names().next(), + Some(Either::Right(s)) => s.iter_names().next(), + None => None, + } + .into_iter() + .cloned() + .collect(); + + unified_scan_args.projection = Some(projection.clone()); + + if projection.is_empty() { + output_schema = Some(Default::default()); + break; + } + + ctx.acc_projections.push(ColumnNode( + expr_arena.add(AExpr::Column(projection[0].clone())), + )); + + unified_scan_args.projection = Some(projection) + } else { + // All nodes in new-streaming support projecting empty morsels with the correct height + // from the file. + unified_scan_args.projection = Some(Arc::from([])); + output_schema = Some(Default::default()); + break; + }; + } + + unified_scan_args.projection = get_scan_columns( + &ctx.acc_projections, + expr_arena, + unified_scan_args.row_index.as_ref(), + unified_scan_args.include_file_paths.as_deref(), + ); + + output_schema = if unified_scan_args.projection.is_some() { + let mut schema = update_scan_schema( + &ctx.acc_projections, + expr_arena, + &file_info.schema, + scan_type.sort_projection(unified_scan_args.row_index.is_some()), + )?; + + if let Some(ref file_path_col) = unified_scan_args.include_file_paths { + if let Some(i) = schema.index_of(file_path_col) { + let (name, dtype) = schema.shift_remove_index(i).unwrap(); + schema.insert_at_index(schema.len(), name, dtype)?; + } + } + + Some(Arc::new(schema)) + } else { + None + }; + + break; + } + + // File builder has a row index, but projected columns + // do not include it, so cull. + if let Some(RowIndex { ref name, .. }) = unified_scan_args.row_index { + if output_schema + .as_ref() + .is_some_and(|schema| !schema.contains(name)) + { + // Need to remove it from the input schema so + // that projection indices are correct. + let mut file_schema = Arc::unwrap_or_clone(file_info.schema); + file_schema.shift_remove(name); + file_info.schema = Arc::new(file_schema); + unified_scan_args.row_index = None; + } + }; + + if let Some(col_name) = &unified_scan_args.include_file_paths { + if output_schema + .as_ref() + .is_some_and(|schema| !schema.contains(col_name)) + { + // Need to remove it from the input schema so + // that projection indices are correct. + let mut file_schema = Arc::unwrap_or_clone(file_info.schema); + file_schema.shift_remove(col_name); + file_info.schema = Arc::new(file_schema); + unified_scan_args.include_file_paths = None; + } + }; + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + predicate, + unified_scan_args, + }; + + Ok(lp) + }, + Sort { + input, + by_column, + slice, + sort_options, + } => { + if ctx.has_pushed_down() { + // Make sure that the column(s) used for the sort is projected + by_column.iter().for_each(|node| { + add_expr_to_accumulated( + node.node(), + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ); + }); + } + + self.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + Ok(Sort { + input, + by_column, + slice, + sort_options, + }) + }, + Distinct { input, options } => { + // make sure that the set of unique columns is projected + if ctx.has_pushed_down() { + if let Some(subset) = options.subset.as_ref() { + subset.iter().for_each(|name| { + add_str_to_accumulated(name.clone(), &mut ctx, expr_arena) + }) + } else { + // distinct needs all columns + let input_schema = lp_arena.get(input).schema(lp_arena); + for name in input_schema.iter_names() { + add_str_to_accumulated(name.clone(), &mut ctx, expr_arena) + } + } + } + + self.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + Ok(Distinct { input, options }) + }, + Filter { predicate, input } => { + if ctx.has_pushed_down() { + // make sure that the filter column is projected + add_expr_to_accumulated( + predicate.node(), + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ); + }; + self.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + Ok(Filter { predicate, input }) + }, + GroupBy { + input, + keys, + aggs, + apply, + schema, + maintain_order, + options, + } => process_group_by( + self, + input, + keys, + aggs, + apply, + schema, + maintain_order, + options, + ctx, + lp_arena, + expr_arena, + ), + Join { + input_left, + input_right, + left_on, + right_on, + options, + schema, + } => match options.args.how { + #[cfg(feature = "semi_anti_join")] + JoinType::Semi | JoinType::Anti => process_semi_anti_join( + self, + input_left, + input_right, + left_on, + right_on, + options, + ctx, + lp_arena, + expr_arena, + ), + _ => process_join( + self, + input_left, + input_right, + left_on, + right_on, + options, + ctx, + lp_arena, + expr_arena, + &schema, + ), + }, + HStack { + input, + exprs, + options, + .. + } => process_hstack(self, input, exprs, options, ctx, lp_arena, expr_arena), + ExtContext { + input, contexts, .. + } => { + // local projections are ignored. These are just root nodes + // complex expression will still be done later + let _local_projections = + self.pushdown_and_assign_check_schema(input, ctx, lp_arena, expr_arena, false)?; + + let mut new_schema = lp_arena + .get(input) + .schema(lp_arena) + .as_ref() + .as_ref() + .clone(); + + for node in &contexts { + let other_schema = lp_arena.get(*node).schema(lp_arena); + for fld in other_schema.iter_fields() { + if new_schema.get(fld.name()).is_none() { + new_schema.with_column(fld.name, fld.dtype); + } + } + } + + Ok(ExtContext { + input, + contexts, + schema: Arc::new(new_schema), + }) + }, + MapFunction { input, function } => { + functions::process_functions(self, input, function, ctx, lp_arena, expr_arena) + }, + HConcat { + inputs, + schema, + options, + } => process_hconcat(self, inputs, schema, options, ctx, lp_arena, expr_arena), + lp @ Union { .. } => process_generic(self, lp, ctx, lp_arena, expr_arena), + // These nodes only have inputs and exprs, so we can use same logic. + lp @ Slice { .. } | lp @ Sink { .. } | lp @ SinkMultiple { .. } => { + process_generic(self, lp, ctx, lp_arena, expr_arena) + }, + Cache { .. } => { + // projections above this cache will be accumulated and pushed down + // later + // the redundant projection will be cleaned in the fast projection optimization + // phase. + if ctx.acc_projections.is_empty() { + Ok(logical_plan) + } else { + Ok(IRBuilder::from_lp(logical_plan, expr_arena, lp_arena) + .project_simple_nodes(ctx.acc_projections) + .unwrap() + .build()) + } + }, + #[cfg(feature = "merge_sorted")] + MergeSorted { + input_left, + input_right, + key, + } => { + if ctx.has_pushed_down() { + // make sure that the filter column is projected + add_str_to_accumulated(key.clone(), &mut ctx, expr_arena); + }; + + self.pushdown_and_assign(input_left, ctx.clone(), lp_arena, expr_arena)?; + self.pushdown_and_assign(input_right, ctx, lp_arena, expr_arena)?; + + Ok(MergeSorted { + input_left, + input_right, + key, + }) + }, + Invalid => unreachable!(), + } + } + + pub fn optimize( + &mut self, + logical_plan: IR, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let ctx = ProjectionContext::default(); + self.push_down(logical_plan, ctx, lp_arena, expr_arena) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs new file mode 100644 index 000000000000..3dd40aad3b4c --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs @@ -0,0 +1,157 @@ +use super::*; + +#[inline] +pub(super) fn is_count(node: Node, expr_arena: &Arena) -> bool { + matches!(expr_arena.get(node), AExpr::Len) +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_projection( + proj_pd: &mut ProjectionPushDown, + input: Node, + mut exprs: Vec, + mut ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + // Whether is SimpleProjection. + simple: bool, +) -> PolarsResult { + let mut local_projection = Vec::with_capacity(exprs.len()); + + // Special path for `SELECT count(*) FROM` + // as there would be no projections and we would read + // the whole file while we only want the count + if exprs.len() == 1 && is_count(exprs[0].node(), expr_arena) { + // Clear all accumulated projections since we only project a single column from this level. + ctx.acc_projections.clear(); + ctx.projected_names.clear(); + + let input_lp = lp_arena.get(input); + + // If the input node is not aware of `is_count_star` we must project a single column from + // this level, otherwise the upstream nodes may end up projecting everything. + let input_is_count_star_aware = match input_lp { + IR::DataFrameScan { .. } | IR::Scan { .. } => true, + #[cfg(feature = "python")] + IR::PythonScan { .. } => true, + _ => false, + }; + + if !input_is_count_star_aware { + if let Some(name) = input_lp + .schema(lp_arena) + .get_at_index(0) + .map(|(name, _)| name) + { + ctx.acc_projections + .push(ColumnNode(expr_arena.add(AExpr::Column(name.clone())))); + ctx.projected_names.insert(name.clone()); + } + } + + local_projection.push(exprs.pop().unwrap()); + + if input_is_count_star_aware { + ctx.inner.is_count_star = true; + proj_pd.is_count_star = true; + } + } else { + // `remove_names` tracks projected names that need to be removed as they may be aliased + // names that are created on this level. + let mut remove_names = PlHashSet::new(); + + // If there are non-scalar projections we must project at least one of them to maintain the + // output height. + let mut opt_non_scalar = None; + let mut projection_has_non_scalar = false; + + let projected_exprs: Vec = exprs + .into_iter() + .filter(|e| { + let is_non_scalar = !e.is_scalar(expr_arena); + + if opt_non_scalar.is_none() && is_non_scalar { + opt_non_scalar = Some(e.clone()) + } + + let name = match e.output_name_inner() { + OutputName::LiteralLhs(name) | OutputName::Alias(name) => { + remove_names.insert(name.clone()); + name + }, + #[cfg(feature = "dtype-struct")] + OutputName::Field(name) => { + remove_names.insert(name.clone()); + name + }, + OutputName::ColumnLhs(name) => name, + OutputName::None => { + if cfg!(debug_assertions) { + panic!() + } else { + return false; + } + }, + }; + + let project = ctx.acc_projections.is_empty() || ctx.projected_names.contains(name); + projection_has_non_scalar |= project & is_non_scalar; + project + }) + .collect(); + + // Remove aliased before adding new ones. + if !remove_names.is_empty() { + if !ctx.projected_names.is_empty() { + for name in remove_names.iter() { + ctx.projected_names.remove(name); + } + } + + ctx.acc_projections + .retain(|c| !remove_names.contains(column_node_to_name(*c, expr_arena))); + } + + for e in projected_exprs { + add_expr_to_accumulated( + e.node(), + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ); + + // do local as we still need the effect of the projection + // e.g. a projection is more than selecting a column, it can + // also be a function/ complicated expression + local_projection.push(e); + } + + if !projection_has_non_scalar { + if let Some(non_scalar) = opt_non_scalar { + add_expr_to_accumulated( + non_scalar.node(), + &mut ctx.acc_projections, + &mut ctx.projected_names, + expr_arena, + ); + + local_projection.push(non_scalar); + } + } + } + + ctx.inner.projections_seen += 1; + proj_pd.pushdown_and_assign(input, ctx, lp_arena, expr_arena)?; + + let builder = IRBuilder::new(input, expr_arena, lp_arena); + + let lp = if !local_projection.is_empty() && simple { + builder + .project_simple_nodes(local_projection.into_iter().map(|e| e.node()))? + .build() + } else { + proj_pd.finish_node(local_projection, builder) + }; + + Ok(lp) +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs new file mode 100644 index 000000000000..37142ba90943 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs @@ -0,0 +1,64 @@ +use std::collections::BTreeSet; + +use polars_utils::pl_str::PlSmallStr; + +use super::*; + +fn iter_and_update_nodes( + existing: &str, + new: &str, + acc_projections: &mut [ColumnNode], + expr_arena: &mut Arena, + processed: &mut BTreeSet, +) { + for column_node in acc_projections.iter_mut() { + let node = column_node.0; + if !processed.contains(&node.0) { + // We walk the query backwards, so we rename new to existing + if column_node_to_name(*column_node, expr_arena) == new { + let new_node = expr_arena.add(AExpr::Column(PlSmallStr::from_str(existing))); + *column_node = ColumnNode(new_node); + processed.insert(new_node.0); + } + } + } +} + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_rename( + acc_projections: &mut [ColumnNode], + projected_names: &mut PlHashSet, + expr_arena: &mut Arena, + existing: &[PlSmallStr], + new: &[PlSmallStr], + swapping: bool, +) -> PolarsResult<()> { + if swapping { + let reverse_map: PlHashMap<_, _> = + new.iter().cloned().zip(existing.iter().cloned()).collect(); + let mut new_projected_names = PlHashSet::with_capacity(projected_names.len()); + + for col in acc_projections { + let name = column_node_to_name(*col, expr_arena); + + if let Some(previous) = reverse_map.get(name) { + let new = expr_arena.add(AExpr::Column(previous.clone())); + *col = ColumnNode(new); + let _ = new_projected_names.insert(previous.clone()); + } else { + let _ = new_projected_names.insert(name.clone()); + } + } + *projected_names = new_projected_names; + } else { + let mut processed = BTreeSet::new(); + for (existing, new) in existing.iter().zip(new.iter()) { + if projected_names.remove(new.as_str()) { + let name = existing.clone(); + projected_names.insert(name); + iter_and_update_nodes(existing, new, acc_projections, expr_arena, &mut processed); + } + } + } + Ok(()) +} diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs new file mode 100644 index 000000000000..607e0db2597e --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs @@ -0,0 +1,73 @@ +use super::*; + +#[allow(clippy::too_many_arguments)] +pub(super) fn process_semi_anti_join( + proj_pd: &mut ProjectionPushDown, + input_left: Node, + input_right: Node, + left_on: Vec, + right_on: Vec, + options: Arc, + ctx: ProjectionContext, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult { + // n = 0 if no projections, so we don't allocate unneeded + let n = ctx.acc_projections.len() * 2; + let mut pushdown_left = Vec::with_capacity(n); + let mut pushdown_right = Vec::with_capacity(n); + let mut names_left = PlHashSet::with_capacity(n); + let mut names_right = PlHashSet::with_capacity(n); + + if !ctx.has_pushed_down() { + // Only project the join columns. + for e in &right_on { + add_expr_to_accumulated(e.node(), &mut pushdown_right, &mut names_right, expr_arena); + } + } else { + // We build local projections to sort out proper column names due to the + // join operation. + // Joins on columns with different names, for example + // left_on = "a", right_on = "b + // will remove the name "b" (it is "a" now). That columns should therefore not + // be added to a local projection. + let schema_left = lp_arena.get(input_left).schema(lp_arena); + let schema_right = lp_arena.get(input_right).schema(lp_arena); + + // We need the join columns so we push the projection downwards + for e in &left_on { + add_expr_to_accumulated(e.node(), &mut pushdown_left, &mut names_left, expr_arena); + } + for e in &right_on { + add_expr_to_accumulated(e.node(), &mut pushdown_right, &mut names_right, expr_arena); + } + + for proj in ctx.acc_projections { + let _ = proj_pd.join_push_down( + &schema_left, + &schema_right, + proj, + &mut pushdown_left, + &mut pushdown_right, + &mut names_left, + &mut names_right, + expr_arena, + ); + } + } + + let ctx_left = ProjectionContext::new(pushdown_left, names_left, ctx.inner); + let ctx_right = ProjectionContext::new(pushdown_right, names_right, ctx.inner); + + proj_pd.pushdown_and_assign(input_left, ctx_left, lp_arena, expr_arena)?; + proj_pd.pushdown_and_assign(input_right, ctx_right, lp_arena, expr_arena)?; + + let alp = IRBuilder::new(input_left, expr_arena, lp_arena) + .join(input_right, left_on, right_on, options) + .build(); + + let root = lp_arena.add(alp); + let builder = IRBuilder::new(root, expr_arena, lp_arena); + + Ok(proj_pd.finish_node(vec![], builder)) +} diff --git a/crates/polars-plan/src/plans/optimizer/set_order.rs b/crates/polars-plan/src/plans/optimizer/set_order.rs new file mode 100644 index 000000000000..a2f71d1568c1 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/set_order.rs @@ -0,0 +1,161 @@ +use polars_utils::unitvec; + +use super::*; + +// Can give false positives. +fn is_order_dependent_top_level(ae: &AExpr, ctx: Context) -> bool { + match ae { + AExpr::Agg(agg) => match agg { + IRAggExpr::Min { .. } => false, + IRAggExpr::Max { .. } => false, + IRAggExpr::Median(_) => false, + IRAggExpr::NUnique(_) => false, + IRAggExpr::First(_) => true, + IRAggExpr::Last(_) => true, + IRAggExpr::Mean(_) => false, + IRAggExpr::Implode(_) => true, + IRAggExpr::Quantile { .. } => false, + IRAggExpr::Sum(_) => false, + IRAggExpr::Count(_, _) => false, + IRAggExpr::Std(_, _) => false, + IRAggExpr::Var(_, _) => false, + IRAggExpr::AggGroups(_) => true, + }, + AExpr::Column(_) => matches!(ctx, Context::Aggregation), + _ => true, + } +} + +// Can give false positives. +fn is_order_dependent<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena, ctx: Context) -> bool { + let mut stack = unitvec![]; + + loop { + if is_order_dependent_top_level(ae, ctx) { + return true; + } + + let Some(node) = stack.pop() else { + break; + }; + + ae = expr_arena.get(node); + } + + false +} + +// Can give false negatives. +pub(crate) fn all_order_independent<'a, N>( + nodes: &'a [N], + expr_arena: &Arena, + ctx: Context, +) -> bool +where + Node: From<&'a N>, +{ + !nodes + .iter() + .any(|n| is_order_dependent(expr_arena.get(n.into()), expr_arena, ctx)) +} + +// Should run before slice pushdown. +pub(super) fn set_order_flags( + root: Node, + ir_arena: &mut Arena, + expr_arena: &Arena, + scratch: &mut Vec, +) { + scratch.clear(); + scratch.push(root); + + let mut maintain_order_above = true; + + while let Some(node) = scratch.pop() { + let ir = ir_arena.get_mut(node); + ir.copy_inputs(scratch); + + match ir { + IR::Sort { + input, + sort_options, + .. + } => { + debug_assert!(sort_options.limit.is_none()); + // This sort can be removed + if !maintain_order_above { + scratch.pop(); + scratch.push(node); + let input = *input; + ir_arena.swap(node, input); + continue; + } + + if !sort_options.maintain_order { + maintain_order_above = false; // `maintain_order=True` is influenced by result of earlier sorts + } + }, + IR::Distinct { options, .. } => { + debug_assert!(options.slice.is_none()); + if !maintain_order_above { + options.maintain_order = false; + continue; + } + if matches!( + options.keep_strategy, + UniqueKeepStrategy::First | UniqueKeepStrategy::Last + ) { + maintain_order_above = true; + } else if !options.maintain_order { + maintain_order_above = false; + } + }, + IR::Union { options, .. } => { + debug_assert!(options.slice.is_none()); + options.maintain_order = maintain_order_above; + }, + IR::GroupBy { + keys, + aggs, + maintain_order, + options, + apply, + .. + } => { + debug_assert!(options.slice.is_none()); + if apply.is_some() + || *maintain_order + || options.is_rolling() + || options.is_dynamic() + { + maintain_order_above = true; + continue; + } + if !maintain_order_above && *maintain_order { + *maintain_order = false; + continue; + } + + if all_elementwise(keys, expr_arena) + && all_order_independent(aggs, expr_arena, Context::Aggregation) + { + maintain_order_above = false; + continue; + } + maintain_order_above = true; + }, + // Conservative now. + IR::HStack { exprs, .. } | IR::Select { expr: exprs, .. } => { + if !maintain_order_above && all_elementwise(exprs, expr_arena) { + continue; + } + maintain_order_above = true; + }, + _ => { + // If we don't know maintain order + // Known: slice + maintain_order_above = true; + }, + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs new file mode 100644 index 000000000000..cf7b3ccccacf --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs @@ -0,0 +1,789 @@ +mod simplify_functions; + +use polars_utils::floor_divmod::FloorDivMod; +use polars_utils::total_ord::ToTotalOrd; +use simplify_functions::optimize_functions; + +use crate::plans::*; + +fn new_null_count(input: &[ExprIR]) -> AExpr { + AExpr::Function { + input: input.to_vec(), + function: FunctionExpr::NullCount, + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::ALLOW_GROUP_AWARE | FunctionFlags::RETURNS_SCALAR, + ..Default::default() + }, + } +} + +macro_rules! eval_binary_same_type { + ($lhs:expr, $rhs:expr, |$l: ident, $r: ident| $ret: expr) => {{ + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { + match (lit_left, lit_right) { + (LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => { + match (l.as_any_value(), r.as_any_value()) { + (AnyValue::Float32($l), AnyValue::Float32($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::Float64($l), AnyValue::Float64($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + + (AnyValue::Int8($l), AnyValue::Int8($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::Int16($l), AnyValue::Int16($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::Int32($l), AnyValue::Int32($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::Int64($l), AnyValue::Int64($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::Int128($l), AnyValue::Int128($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + + (AnyValue::UInt8($l), AnyValue::UInt8($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::UInt16($l), AnyValue::UInt16($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::UInt32($l), AnyValue::UInt32($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + (AnyValue::UInt64($l), AnyValue::UInt64($r)) => { + Some(AExpr::Literal(>::from($ret).into())) + }, + + _ => None, + } + .into() + }, + ( + LiteralValue::Dyn(DynLiteralValue::Float($l)), + LiteralValue::Dyn(DynLiteralValue::Float($r)), + ) => { + let $l = *$l; + let $r = *$r; + Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Float( + $ret, + )))) + }, + ( + LiteralValue::Dyn(DynLiteralValue::Int($l)), + LiteralValue::Dyn(DynLiteralValue::Int($r)), + ) => { + let $l = *$l; + let $r = *$r; + Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Int( + $ret, + )))) + }, + _ => None, + } + } else { + None + } + }}; +} + +macro_rules! eval_binary_cmp_same_type { + ($lhs:expr, $operand: tt, $rhs:expr) => {{ + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { + match (lit_left, lit_right) { + (LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => match (l.as_any_value(), r.as_any_value()) { + (AnyValue::Float32(l), AnyValue::Float32(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())), + (AnyValue::Float64(l), AnyValue::Float64(r)) => Some(AExpr::Literal({ let x: bool = l.to_total_ord() $operand r.to_total_ord(); Scalar::from(x) }.into())), + + (AnyValue::Boolean(l), AnyValue::Boolean(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + + (AnyValue::Int8(l), AnyValue::Int8(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::Int16(l), AnyValue::Int16(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::Int32(l), AnyValue::Int32(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::Int64(l), AnyValue::Int64(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::Int128(l), AnyValue::Int128(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + + (AnyValue::UInt8(l), AnyValue::UInt8(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::UInt16(l), AnyValue::UInt16(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::UInt32(l), AnyValue::UInt32(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + (AnyValue::UInt64(l), AnyValue::UInt64(r)) => Some(AExpr::Literal({ let x: bool = l $operand r; Scalar::from(x) }.into())), + + _ => None, + }.into(), + (LiteralValue::Dyn(DynLiteralValue::Float(l)), LiteralValue::Dyn(DynLiteralValue::Float(r))) => { + let x: bool = l.to_total_ord() $operand r.to_total_ord(); + Some(AExpr::Literal(Scalar::from(x).into())) + }, + (LiteralValue::Dyn(DynLiteralValue::Int(l)), LiteralValue::Dyn(DynLiteralValue::Int(r))) => { + let x: bool = l $operand r; + Some(AExpr::Literal(Scalar::from(x).into())) + }, + _ => None, + } + } else { + None + } + + }} +} + +pub struct SimplifyBooleanRule {} + +impl OptimizationRule for SimplifyBooleanRule { + fn optimize_expr( + &mut self, + expr_arena: &mut Arena, + expr_node: Node, + lp_arena: &Arena, + lp_node: Node, + ) -> PolarsResult> { + let expr = expr_arena.get(expr_node); + let in_filter = matches!(lp_arena.get(lp_node), IR::Filter { .. }); + + let out = match expr { + // true AND x => x + AExpr::BinaryExpr { + left, + op: Operator::And, + right, + } if matches!( + expr_arena.get(*left), + AExpr::Literal(lv) if lv.bool() == Some(true) + ) && in_filter => + { + // Only in filter as we might change the name from "literal" + // to whatever lhs columns is. + return Ok(Some(expr_arena.get(*right).clone())); + }, + // x AND true => x + AExpr::BinaryExpr { + left, + op: Operator::And, + right, + } if matches!( + expr_arena.get(*right), + AExpr::Literal(lv) if lv.bool() == Some(true) + ) => + { + Some(expr_arena.get(*left).clone()) + }, + + // x AND false -> false + // FIXME: we need an optimizer redesign to allow x & false to be optimized + // in general as we can forget the length of a series otherwise. + AExpr::BinaryExpr { + left, + op: Operator::And, + right, + } if matches!(expr_arena.get(*left), AExpr::Literal(_)) + && matches!( + expr_arena.get(*right), + AExpr::Literal(lv) if lv.bool() == Some(false) + ) => + { + Some(AExpr::Literal(Scalar::from(false).into())) + }, + + // false AND x -> false + // FIXME: we need an optimizer redesign to allow false & x to be optimized + // in general as we can forget the length of a series otherwise. + AExpr::BinaryExpr { + left, + op: Operator::And, + right, + } if matches!( + expr_arena.get(*left), + AExpr::Literal(lv) if lv.bool() == Some(false) + ) && matches!(expr_arena.get(*right), AExpr::Literal(_)) => + { + Some(AExpr::Literal(Scalar::from(false).into())) + }, + + // false or x => x + AExpr::BinaryExpr { + left, + op: Operator::Or, + right, + } if matches!( + expr_arena.get(*left), + AExpr::Literal(lv) if lv.bool() == Some(false) + ) && in_filter => + { + // Only in filter as we might change the name from "literal" + // to whatever lhs columns is. + return Ok(Some(expr_arena.get(*right).clone())); + }, + // x or false => x + AExpr::BinaryExpr { + left, + op: Operator::Or, + right, + .. + } if matches!( + expr_arena.get(*right), + AExpr::Literal(lv) if lv.bool() == Some(false) + ) => + { + Some(expr_arena.get(*left).clone()) + }, + + // true OR x => true + // FIXME: we need an optimizer redesign to allow true | x to be optimized + // in general as we can forget the length of a series otherwise. + AExpr::BinaryExpr { + left, + op: Operator::Or, + right, + } if matches!(expr_arena.get(*left), AExpr::Literal(_)) + && matches!( + expr_arena.get(*right), + AExpr::Literal(lv) if lv.bool() == Some(true) + ) => + { + Some(AExpr::Literal(Scalar::from(true).into())) + }, + + // x OR true => true + // FIXME: we need an optimizer redesign to allow true | x to be optimized + // in general as we can forget the length of a series otherwise. + AExpr::BinaryExpr { + left, + op: Operator::Or, + right, + } if matches!( + expr_arena.get(*left), + AExpr::Literal(lv) if lv.bool() == Some(true) + ) && matches!(expr_arena.get(*right), AExpr::Literal(_)) => + { + Some(AExpr::Literal(Scalar::from(true).into())) + }, + AExpr::Function { + input, + function: FunctionExpr::Negate, + .. + } if input.len() == 1 => { + let input = &input[0]; + let ae = expr_arena.get(input.node()); + eval_negate(ae) + }, + _ => None, + }; + Ok(out) + } +} + +fn eval_negate(ae: &AExpr) -> Option { + use std::ops::Neg; + let out = match ae { + AExpr::Literal(lv) => match lv { + LiteralValue::Scalar(sc) => match sc.as_any_value() { + AnyValue::Int8(v) => Scalar::from(v.checked_neg()?), + AnyValue::Int16(v) => Scalar::from(v.checked_neg()?), + AnyValue::Int32(v) => Scalar::from(v.checked_neg()?), + AnyValue::Int64(v) => Scalar::from(v.checked_neg()?), + AnyValue::Float32(v) => Scalar::from(v.neg()), + AnyValue::Float64(v) => Scalar::from(v.neg()), + _ => return None, + } + .into(), + LiteralValue::Dyn(d) => LiteralValue::Dyn(match d { + DynLiteralValue::Int(v) => DynLiteralValue::Int(v.checked_neg()?), + DynLiteralValue::Float(v) => DynLiteralValue::Float(v.neg()), + _ => return None, + }), + _ => return None, + }, + _ => return None, + }; + Some(AExpr::Literal(out)) +} + +fn eval_bitwise(left: &AExpr, right: &AExpr, operation: F) -> Option +where + F: Fn(bool, bool) -> bool, +{ + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = (left, right) { + return match (lit_left.bool(), lit_right.bool()) { + (Some(x), Some(y)) => Some(AExpr::Literal(Scalar::from(operation(x, y)).into())), + _ => None, + }; + } + None +} + +#[cfg(all(feature = "strings", feature = "concat_str"))] +fn string_addition_to_linear_concat( + lp_arena: &Arena, + lp_node: Node, + expr_arena: &Arena, + left_node: Node, + right_node: Node, + left_aexpr: &AExpr, + right_aexpr: &AExpr, +) -> Option { + { + let lp = lp_arena.get(lp_node); + let input = lp.get_input()?; + let schema = lp_arena.get(input).schema(lp_arena); + let left_e = ExprIR::from_node(left_node, expr_arena); + let right_e = ExprIR::from_node(right_node, expr_arena); + + let get_type = |ae: &AExpr| ae.get_type(&schema, Context::Default, expr_arena).ok(); + let type_a = get_type(left_aexpr).or_else(|| get_type(right_aexpr))?; + let type_b = get_type(right_aexpr).or_else(|| get_type(right_aexpr))?; + + if type_a != type_b { + return None; + } + + if type_a.is_string() { + match (left_aexpr, right_aexpr) { + // concat + concat + ( + AExpr::Function { + input: input_left, + function: + fun_l @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal { + delimiter: sep_l, + ignore_nulls: ignore_nulls_l, + }), + options, + }, + AExpr::Function { + input: input_right, + function: + FunctionExpr::StringExpr(StringFunction::ConcatHorizontal { + delimiter: sep_r, + ignore_nulls: ignore_nulls_r, + }), + .. + }, + ) => { + if sep_l.is_empty() && sep_r.is_empty() && ignore_nulls_l == ignore_nulls_r { + let mut input = Vec::with_capacity(input_left.len() + input_right.len()); + input.extend_from_slice(input_left); + input.extend_from_slice(input_right); + Some(AExpr::Function { + input, + function: fun_l.clone(), + options: *options, + }) + } else { + None + } + }, + // concat + str + ( + AExpr::Function { + input, + function: + fun @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal { + delimiter: sep, + ignore_nulls, + }), + options, + }, + _, + ) => { + if sep.is_empty() && !ignore_nulls { + let mut input = input.clone(); + input.push(right_e); + Some(AExpr::Function { + input, + function: fun.clone(), + options: *options, + }) + } else { + None + } + }, + // str + concat + ( + _, + AExpr::Function { + input: input_right, + function: + fun @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal { + delimiter: sep, + ignore_nulls, + }), + options, + }, + ) => { + if sep.is_empty() && !ignore_nulls { + let mut input = Vec::with_capacity(1 + input_right.len()); + input.push(left_e); + input.extend_from_slice(input_right); + Some(AExpr::Function { + input, + function: fun.clone(), + options: *options, + }) + } else { + None + } + }, + _ => Some(AExpr::Function { + input: vec![left_e, right_e], + function: StringFunction::ConcatHorizontal { + delimiter: "".into(), + ignore_nulls: false, + } + .into(), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default() + | FunctionFlags::INPUT_WILDCARD_EXPANSION + & !FunctionFlags::RETURNS_SCALAR, + ..Default::default() + }, + }), + } + } else { + None + } + } +} + +pub struct SimplifyExprRule {} + +impl OptimizationRule for SimplifyExprRule { + #[allow(clippy::float_cmp)] + fn optimize_expr( + &mut self, + expr_arena: &mut Arena, + expr_node: Node, + lp_arena: &Arena, + lp_node: Node, + ) -> PolarsResult> { + let expr = expr_arena.get(expr_node).clone(); + + let out = match &expr { + // drop_nulls().len() -> len() - null_count() + // drop_nulls().count() -> len() - null_count() + AExpr::Agg(IRAggExpr::Count(input, _)) => { + let input_expr = expr_arena.get(*input); + match input_expr { + AExpr::Function { + input, + function: FunctionExpr::DropNulls, + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let drop_nulls_input_node = input[0].node(); + match expr_arena.get(drop_nulls_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Minus, + right: expr_arena.add(new_null_count(input)), + left: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + drop_nulls_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // is_null().sum() -> null_count() + // is_not_null().sum() -> len() - null_count() + AExpr::Agg(IRAggExpr::Sum(input)) => { + let input_expr = expr_arena.get(*input); + match input_expr { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => Some(new_null_count(input)), + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_not_null_input_node = input[0].node(); + match expr_arena.get(is_not_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Minus, + right: expr_arena.add(new_null_count(input)), + left: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + is_not_null_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // lit(left) + lit(right) => lit(left + right) + // and null propagation + AExpr::BinaryExpr { left, op, right } => { + let left_aexpr = expr_arena.get(*left); + let right_aexpr = expr_arena.get(*right); + + // lit(left) + lit(right) => lit(left + right) + use Operator::*; + #[allow(clippy::manual_map)] + let out = match op { + Plus => { + match eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + r) { + Some(new) => Some(new), + None => { + // try to replace addition of string columns with `concat_str` + #[cfg(all(feature = "strings", feature = "concat_str"))] + { + string_addition_to_linear_concat( + lp_arena, + lp_node, + expr_arena, + *left, + *right, + left_aexpr, + right_aexpr, + ) + } + #[cfg(not(all(feature = "strings", feature = "concat_str")))] + { + None + } + }, + } + }, + Minus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l - r), + Multiply => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l * r), + Divide => { + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = + (left_aexpr, right_aexpr) + { + match (lit_left, lit_right) { + (LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => { + match (l.as_any_value(), r.as_any_value()) { + (AnyValue::Float32(x), AnyValue::Float32(y)) => { + Some(AExpr::Literal( + >::from(x / y).into(), + )) + }, + (AnyValue::Float64(x), AnyValue::Float64(y)) => { + Some(AExpr::Literal( + >::from(x / y).into(), + )) + }, + + (AnyValue::Int8(x), AnyValue::Int8(y)) => { + Some(AExpr::Literal( + >::from( + x.wrapping_floor_div_mod(y).0, + ) + .into(), + )) + }, + (AnyValue::Int16(x), AnyValue::Int16(y)) => { + Some(AExpr::Literal( + >::from( + x.wrapping_floor_div_mod(y).0, + ) + .into(), + )) + }, + (AnyValue::Int32(x), AnyValue::Int32(y)) => { + Some(AExpr::Literal( + >::from( + x.wrapping_floor_div_mod(y).0, + ) + .into(), + )) + }, + (AnyValue::Int64(x), AnyValue::Int64(y)) => { + Some(AExpr::Literal( + >::from( + x.wrapping_floor_div_mod(y).0, + ) + .into(), + )) + }, + (AnyValue::Int128(x), AnyValue::Int128(y)) => { + Some(AExpr::Literal( + >::from( + x.wrapping_floor_div_mod(y).0, + ) + .into(), + )) + }, + + (AnyValue::UInt8(x), AnyValue::UInt8(y)) => { + Some(AExpr::Literal( + >::from(x / y).into(), + )) + }, + (AnyValue::UInt16(x), AnyValue::UInt16(y)) => { + Some(AExpr::Literal( + >::from(x / y).into(), + )) + }, + (AnyValue::UInt32(x), AnyValue::UInt32(y)) => { + Some(AExpr::Literal( + >::from(x / y).into(), + )) + }, + (AnyValue::UInt64(x), AnyValue::UInt64(y)) => { + Some(AExpr::Literal( + >::from(x / y).into(), + )) + }, + + _ => None, + } + }, + + ( + LiteralValue::Dyn(DynLiteralValue::Float(x)), + LiteralValue::Dyn(DynLiteralValue::Float(y)), + ) => { + Some(AExpr::Literal(>::from(x / y).into())) + }, + ( + LiteralValue::Dyn(DynLiteralValue::Int(x)), + LiteralValue::Dyn(DynLiteralValue::Int(y)), + ) => Some(AExpr::Literal(LiteralValue::Dyn(DynLiteralValue::Int( + x.wrapping_floor_div_mod(*y).0, + )))), + _ => None, + } + } else { + None + } + }, + TrueDivide => { + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = + (left_aexpr, right_aexpr) + { + match (lit_left, lit_right) { + (LiteralValue::Scalar(l), LiteralValue::Scalar(r)) => { + match (l.as_any_value(), r.as_any_value()) { + (AnyValue::Float32(x), AnyValue::Float32(y)) => { + Some(AExpr::Literal(Scalar::from(x / y).into())) + }, + (AnyValue::Float64(x), AnyValue::Float64(y)) => { + Some(AExpr::Literal(Scalar::from(x / y).into())) + }, + + (AnyValue::Int8(x), AnyValue::Int8(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::Int16(x), AnyValue::Int16(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::Int32(x), AnyValue::Int32(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::Int64(x), AnyValue::Int64(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::Int128(x), AnyValue::Int128(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + + (AnyValue::UInt8(x), AnyValue::UInt8(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::UInt16(x), AnyValue::UInt16(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::UInt32(x), AnyValue::UInt32(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + (AnyValue::UInt64(x), AnyValue::UInt64(y)) => { + Some(AExpr::Literal( + Scalar::from(x as f64 / y as f64).into(), + )) + }, + + _ => None, + } + }, + + ( + LiteralValue::Dyn(DynLiteralValue::Float(x)), + LiteralValue::Dyn(DynLiteralValue::Float(y)), + ) => Some(AExpr::Literal(Scalar::from(*x / *y).into())), + ( + LiteralValue::Dyn(DynLiteralValue::Int(x)), + LiteralValue::Dyn(DynLiteralValue::Int(y)), + ) => { + Some(AExpr::Literal(Scalar::from(*x as f64 / *y as f64).into())) + }, + _ => None, + } + } else { + None + } + }, + Modulus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + .wrapping_floor_div_mod(r) + .1), + Lt => eval_binary_cmp_same_type!(left_aexpr, <, right_aexpr), + Gt => eval_binary_cmp_same_type!(left_aexpr, >, right_aexpr), + Eq | EqValidity => eval_binary_cmp_same_type!(left_aexpr, ==, right_aexpr), + NotEq | NotEqValidity => { + eval_binary_cmp_same_type!(left_aexpr, !=, right_aexpr) + }, + GtEq => eval_binary_cmp_same_type!(left_aexpr, >=, right_aexpr), + LtEq => eval_binary_cmp_same_type!(left_aexpr, <=, right_aexpr), + And | LogicalAnd => eval_bitwise(left_aexpr, right_aexpr, |l, r| l & r), + Or | LogicalOr => eval_bitwise(left_aexpr, right_aexpr, |l, r| l | r), + Xor => eval_bitwise(left_aexpr, right_aexpr, |l, r| l ^ r), + FloorDivide => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + .wrapping_floor_div_mod(r) + .0), + }; + if out.is_some() { + return Ok(out); + } + + None + }, + AExpr::Function { + input, + function, + options, + .. + } => return optimize_functions(input, function, options, expr_arena), + _ => None, + }; + Ok(out) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs new file mode 100644 index 000000000000..1cf11252729c --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs @@ -0,0 +1,366 @@ +use super::*; + +pub(super) fn optimize_functions( + input: &[ExprIR], + function: &FunctionExpr, + options: &FunctionOptions, + expr_arena: &mut Arena, +) -> PolarsResult> { + let out = match function { + // is_null().any() -> null_count() > 0 + // is_not_null().any() -> null_count() < len() + // CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS + FunctionExpr::Boolean(BooleanFunction::Any { ignore_nulls: _ }) => { + let input_node = expr_arena.get(input[0].node()); + match input_node { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(new_null_count(input)), + op: Operator::Gt, + right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))), + }), + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_not_null_input_node = input[0].node(); + match expr_arena.get(is_not_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Lt, + left: expr_arena.add(new_null_count(input)), + right: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + is_not_null_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // is_null().all() -> null_count() == len() + // is_not_null().all() -> null_count() == 0 + FunctionExpr::Boolean(BooleanFunction::All { ignore_nulls: _ }) => { + let input_node = expr_arena.get(input[0].node()); + match input_node { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_null_input_node = input[0].node(); + match expr_arena.get(is_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Eq, + right: expr_arena.add(new_null_count(input)), + left: expr_arena + .add(AExpr::Agg(IRAggExpr::Count(is_null_input_node, true))), + }), + _ => None, + } + } else { + None + } + }, + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(new_null_count(input)), + op: Operator::Eq, + right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))), + }), + _ => None, + } + }, + // sort().reverse() -> sort(reverse) + // sort_by().reverse() -> sort_by(reverse) + FunctionExpr::Reverse => { + let input = expr_arena.get(input[0].node()); + match input { + AExpr::Sort { expr, options } => { + let mut options = *options; + options.descending = !options.descending; + Some(AExpr::Sort { + expr: *expr, + options, + }) + }, + AExpr::SortBy { + expr, + by, + sort_options, + } => { + let mut sort_options = sort_options.clone(); + let reversed_descending = sort_options.descending.iter().map(|x| !*x).collect(); + sort_options.descending = reversed_descending; + Some(AExpr::SortBy { + expr: *expr, + by: by.clone(), + sort_options, + }) + }, + // TODO: add support for cum_sum and other operation that allow reversing. + _ => None, + } + }, + // flatten nested concat_str calls + #[cfg(all(feature = "strings", feature = "concat_str"))] + function @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal { + delimiter: sep, + ignore_nulls, + }) if sep.is_empty() => { + if input + .iter() + .any(|e| is_string_concat(expr_arena.get(e.node()), *ignore_nulls)) + { + let mut new_inputs = Vec::with_capacity(input.len() * 2); + + for e in input { + match get_string_concat_input(e.node(), expr_arena, *ignore_nulls) { + Some(inp) => new_inputs.extend_from_slice(inp), + None => new_inputs.push(e.clone()), + } + } + Some(AExpr::Function { + input: new_inputs, + function: function.clone(), + options: *options, + }) + } else { + None + } + }, + FunctionExpr::Boolean(BooleanFunction::Not) => { + let y = expr_arena.get(input[0].node()); + + match y { + // not(a and b) => not(a) or not(b) + AExpr::BinaryExpr { + left, + op: Operator::And | Operator::LogicalAnd, + right, + } => { + let left = *left; + let right = *right; + Some(AExpr::BinaryExpr { + left: expr_arena.add(AExpr::Function { + input: vec![ExprIR::from_node(left, expr_arena)], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + op: Operator::Or, + right: expr_arena.add(AExpr::Function { + input: vec![ExprIR::from_node(right, expr_arena)], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + }) + }, + // not(a or b) => not(a) and not(b) + AExpr::BinaryExpr { + left, + op: Operator::Or | Operator::LogicalOr, + right, + } => { + let left = *left; + let right = *right; + Some(AExpr::BinaryExpr { + left: expr_arena.add(AExpr::Function { + input: vec![ExprIR::from_node(left, expr_arena)], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + op: Operator::And, + right: expr_arena.add(AExpr::Function { + input: vec![ExprIR::from_node(right, expr_arena)], + function: FunctionExpr::Boolean(BooleanFunction::Not), + options: *options, + }), + }) + }, + // not(not x) => x + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::Not), + .. + } => Some(expr_arena.get(input[0].node()).clone()), + // not(lit x) => !x + AExpr::Literal(lv) if lv.bool().is_some() => { + Some(AExpr::Literal(Scalar::from(!lv.bool().unwrap()).into())) + }, + // not(x.is_null) => x.is_not_null + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options, + } => Some(AExpr::Function { + input: input.clone(), + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: *options, + }), + // not(x.is_not_null) => x.is_null + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options, + } => Some(AExpr::Function { + input: input.clone(), + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: *options, + }), + // not(a == b) => a != b + AExpr::BinaryExpr { + left, + op: Operator::Eq, + right, + } => Some(AExpr::BinaryExpr { + left: *left, + op: Operator::NotEq, + right: *right, + }), + // not(a != b) => a == b + AExpr::BinaryExpr { + left, + op: Operator::NotEq, + right, + } => Some(AExpr::BinaryExpr { + left: *left, + op: Operator::Eq, + right: *right, + }), + // not(a < b) => a >= b + AExpr::BinaryExpr { + left, + op: Operator::Lt, + right, + } => Some(AExpr::BinaryExpr { + left: *left, + op: Operator::GtEq, + right: *right, + }), + // not(a <= b) => a > b + AExpr::BinaryExpr { + left, + op: Operator::LtEq, + right, + } => Some(AExpr::BinaryExpr { + left: *left, + op: Operator::Gt, + right: *right, + }), + // not(a > b) => a <= b + AExpr::BinaryExpr { + left, + op: Operator::Gt, + right, + } => Some(AExpr::BinaryExpr { + left: *left, + op: Operator::LtEq, + right: *right, + }), + // not(a >= b) => a < b + AExpr::BinaryExpr { + left, + op: Operator::GtEq, + right, + } => Some(AExpr::BinaryExpr { + left: *left, + op: Operator::Lt, + right: *right, + }), + #[cfg(feature = "is_between")] + // not(col('x').is_between(a,b)) => col('x') < a || col('x') > b + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }), + .. + } => { + if !matches!(expr_arena.get(input[0].node()), AExpr::Column(_)) { + None + } else { + let left_cmp_op = match closed { + ClosedInterval::Both | ClosedInterval::Left => Operator::Lt, + ClosedInterval::None | ClosedInterval::Right => Operator::LtEq, + }; + let right_cmp_op = match closed { + ClosedInterval::Both | ClosedInterval::Right => Operator::Gt, + ClosedInterval::None | ClosedInterval::Left => Operator::GtEq, + }; + let left_left = input[0].node(); + let right_left = input[1].node(); + + let left_right = left_left; + let right_right = input[2].node(); + + // input[0] is between input[1] and input[2] + Some(AExpr::BinaryExpr { + // input[0] (<,<=) input[1] + left: expr_arena.add(AExpr::BinaryExpr { + left: left_left, + op: left_cmp_op, + right: right_left, + }), + // OR + op: Operator::Or, + // input[0] (>,>=) input[2] + right: expr_arena.add(AExpr::BinaryExpr { + left: left_right, + op: right_cmp_op, + right: right_right, + }), + }) + } + }, + _ => None, + } + }, + _ => None, + }; + Ok(out) +} + +#[cfg(all(feature = "strings", feature = "concat_str"))] +fn is_string_concat(ae: &AExpr, ignore_nulls: bool) -> bool { + matches!(ae, AExpr::Function { + function:FunctionExpr::StringExpr( + StringFunction::ConcatHorizontal{delimiter: sep, ignore_nulls: func_inore_nulls}, + ), + .. + } if sep.is_empty() && *func_inore_nulls == ignore_nulls) +} + +#[cfg(all(feature = "strings", feature = "concat_str"))] +fn get_string_concat_input( + node: Node, + expr_arena: &Arena, + ignore_nulls: bool, +) -> Option<&[ExprIR]> { + match expr_arena.get(node) { + AExpr::Function { + input, + function: + FunctionExpr::StringExpr(StringFunction::ConcatHorizontal { + delimiter: sep, + ignore_nulls: func_ignore_nulls, + }), + .. + } if sep.is_empty() && *func_ignore_nulls == ignore_nulls => Some(input), + _ => None, + } +} diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs new file mode 100644 index 000000000000..fb3508170bab --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs @@ -0,0 +1,128 @@ +use super::*; + +fn pushdown(input: Node, offset: Node, length: Node, arena: &mut Arena) -> Node { + arena.add(AExpr::Slice { + input, + offset, + length, + }) +} + +impl OptimizationRule for SlicePushDown { + fn optimize_expr( + &mut self, + expr_arena: &mut Arena, + expr_node: Node, + _lp_arena: &Arena, + _lp_node: Node, + ) -> PolarsResult> { + if let AExpr::Slice { + input, + offset, + length, + } = expr_arena.get(expr_node) + { + let offset = *offset; + let length = *length; + + use AExpr::*; + let out = match expr_arena.get(*input) { + ae @ Alias(..) | ae @ Cast { .. } => { + let ae = ae.clone(); + let scratch = self.empty_nodes_scratch_mut(); + ae.inputs_rev(scratch); + let input = scratch[0]; + let new_input = pushdown(input, offset, length, expr_arena); + Some(ae.replace_inputs(&[new_input])) + }, + Literal(lv) => { + match lv { + LiteralValue::Series(_) => None, + LiteralValue::Range { .. } => None, + // no need to slice a literal value of unit length + lv => Some(Literal(lv.clone())), + } + }, + BinaryExpr { left, right, op } => { + let left = *left; + let right = *right; + let op = *op; + + let left = pushdown(left, offset, length, expr_arena); + let right = pushdown(right, offset, length, expr_arena); + Some(BinaryExpr { left, op, right }) + }, + Ternary { + truthy, + falsy, + predicate, + } => { + let truthy = *truthy; + let falsy = *falsy; + let predicate = *predicate; + + let truthy = pushdown(truthy, offset, length, expr_arena); + let falsy = pushdown(falsy, offset, length, expr_arena); + let predicate = pushdown(predicate, offset, length, expr_arena); + Some(Ternary { + truthy, + falsy, + predicate, + }) + }, + m @ AnonymousFunction { options, .. } + if matches!(options.collect_groups, ApplyOptions::ElementWise) => + { + if let AnonymousFunction { + mut input, + function, + output_type, + options, + } = m.clone() + { + input.iter_mut().for_each(|e| { + let n = pushdown(e.node(), offset, length, expr_arena); + e.set_node(n); + }); + + Some(AnonymousFunction { + input, + function, + output_type, + options, + }) + } else { + unreachable!() + } + }, + m @ Function { options, .. } + if matches!(options.collect_groups, ApplyOptions::ElementWise) => + { + if let Function { + mut input, + function, + options, + } = m.clone() + { + input.iter_mut().for_each(|e| { + let n = pushdown(e.node(), offset, length, expr_arena); + e.set_node(n); + }); + + Some(Function { + input, + function, + options, + }) + } else { + unreachable!() + } + }, + _ => None, + }; + Ok(out) + } else { + Ok(None) + } + } +} diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs new file mode 100644 index 000000000000..124c6dd19d95 --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -0,0 +1,570 @@ +use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; +use polars_utils::slice_enum::Slice; +use recursive::recursive; + +use crate::prelude::*; + +mod inner { + use polars_utils::arena::Node; + use polars_utils::idx_vec::UnitVec; + use polars_utils::unitvec; + + pub struct SlicePushDown { + pub streaming: bool, + #[expect(unused)] + pub new_streaming: bool, + scratch: UnitVec, + } + + impl SlicePushDown { + pub fn new(streaming: bool, new_streaming: bool) -> Self { + Self { + streaming, + new_streaming, + scratch: unitvec![], + } + } + + /// Returns shared scratch space after clearing. + pub fn empty_nodes_scratch_mut(&mut self) -> &mut UnitVec { + self.scratch.clear(); + &mut self.scratch + } + } +} + +pub(super) use inner::SlicePushDown; + +#[derive(Copy, Clone)] +struct State { + offset: i64, + len: IdxSize, +} + +impl State { + fn to_slice_enum(self) -> Slice { + let offset = self.offset; + let len: usize = usize::try_from(self.len).unwrap(); + + (offset, len).into() + } +} + +/// Can push down slice when: +/// * all projections are elementwise +/// * at least 1 projection is based on a column (for height broadcast) +/// * projections not based on any column project as scalars +/// +/// Returns (can_pushdown, can_pushdown_and_any_expr_has_column) +fn can_pushdown_slice_past_projections( + exprs: &[ExprIR], + arena: &Arena, + scratch: &mut UnitVec, +) -> (bool, bool) { + scratch.clear(); + + let mut can_pushdown_and_any_expr_has_column = false; + + for expr_ir in exprs.iter() { + scratch.push(expr_ir.node()); + + // # "has_column" + // `select(c = Literal([1, 2, 3])).slice(0, 0)` must block slice pushdown, + // because `c` projects to a height independent from the input height. We check + // this by observing that `c` does not have any columns in its input nodes. + // + // TODO: Simply checking that a column node is present does not handle e.g.: + // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, + // `str.contains`, `str.contains_any` etc. - observe a column node is present + // but the output height is not dependent on it. + let mut has_column = false; + let mut literals_all_scalar = true; + + while let Some(node) = scratch.pop() { + let ae = arena.get(node); + + // We re-use the logic from predicate pushdown, as slices can be seen as a form of filtering. + // But we also do some bookkeeping here specific to slice pushdown. + + match ae { + AExpr::Column(_) => has_column = true, + AExpr::Literal(v) => literals_all_scalar &= v.is_scalar(), + _ => {}, + } + + if !permits_filter_pushdown(scratch, ae, arena) { + return (false, false); + } + } + + // If there is no column then all literals must be scalar + if !(has_column || literals_all_scalar) { + return (false, false); + } + + can_pushdown_and_any_expr_has_column |= has_column + } + + (true, can_pushdown_and_any_expr_has_column) +} + +impl SlicePushDown { + // slice will be done at this node if we found any + // we also stop optimization + fn no_pushdown_finish_opt( + &self, + lp: IR, + state: Option, + lp_arena: &mut Arena, + ) -> PolarsResult { + match state { + Some(state) => { + let input = lp_arena.add(lp); + + let lp = IR::Slice { + input, + offset: state.offset, + len: state.len, + }; + Ok(lp) + }, + None => Ok(lp), + } + } + + /// slice will be done at this node, but we continue optimization + fn no_pushdown_restart_opt( + &mut self, + lp: IR, + state: Option, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let inputs = lp.get_inputs(); + let exprs = lp.get_exprs(); + + let new_inputs = inputs + .iter() + .map(|&node| { + let alp = lp_arena.take(node); + // No state, so we do not push down the slice here. + let state = None; + let alp = self.pushdown(alp, state, lp_arena, expr_arena)?; + lp_arena.replace(node, alp); + Ok(node) + }) + .collect::>>()?; + let lp = lp.with_exprs_and_input(exprs, new_inputs); + + self.no_pushdown_finish_opt(lp, state, lp_arena) + } + + /// slice will be pushed down. + fn pushdown_and_continue( + &mut self, + lp: IR, + state: Option, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + let inputs = lp.get_inputs(); + let exprs = lp.get_exprs(); + + let new_inputs = inputs + .iter() + .map(|&node| { + let alp = lp_arena.take(node); + let alp = self.pushdown(alp, state, lp_arena, expr_arena)?; + lp_arena.replace(node, alp); + Ok(node) + }) + .collect::>>()?; + Ok(lp.with_exprs_and_input(exprs, new_inputs)) + } + + #[recursive] + fn pushdown( + &mut self, + lp: IR, + state: Option, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + use IR::*; + + match (lp, state) { + #[cfg(feature = "python")] + (PythonScan { + mut options, + }, + // TODO! we currently skip slice pushdown if there is a predicate. + // we can modify the readers to only limit after predicates have been applied + Some(state)) if state.offset == 0 && matches!(options.predicate, PythonPredicate::None) => { + options.n_rows = Some(state.len as usize); + let lp = PythonScan { + options, + }; + Ok(lp) + } + #[cfg(feature = "csv")] + (Scan { + sources, + file_info, + hive_parts, + output_schema, + mut unified_scan_args, + predicate, + scan_type, + }, Some(state)) if matches!(&*scan_type, FileScan::Csv { .. }) && predicate.is_none() => { + unified_scan_args.pre_slice = Some(state.to_slice_enum()); + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + unified_scan_args, + predicate, + }; + + Ok(lp) + }, + + #[cfg(feature = "json")] + (Scan { + sources, + file_info, + hive_parts, + output_schema, + mut unified_scan_args, + predicate, + scan_type, + }, Some(state)) if predicate.is_none() && matches!(&*scan_type, FileScan::NDJson {.. }) => { + unified_scan_args.pre_slice = Some(state.to_slice_enum()); + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + unified_scan_args, + predicate, + }; + + Ok(lp) + }, + #[cfg(feature = "parquet")] + (Scan { + sources, + file_info, + hive_parts, + output_schema, + mut unified_scan_args, + predicate, + scan_type, + }, Some(state)) if predicate.is_none() && matches!(&*scan_type, FileScan::Parquet { .. }) => { + unified_scan_args.pre_slice = Some(state.to_slice_enum()); + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + unified_scan_args, + predicate, + }; + + Ok(lp) + }, + + #[cfg(feature = "ipc")] + (Scan { + sources, + file_info, + hive_parts, + output_schema, + mut unified_scan_args, + predicate, + scan_type, + }, Some(state)) if predicate.is_none() && matches!(&*scan_type, FileScan::Ipc{..})=> { + unified_scan_args.pre_slice = Some(state.to_slice_enum()); + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + unified_scan_args, + predicate, + }; + + Ok(lp) + }, + + // TODO! we currently skip slice pushdown if there is a predicate. + (Scan { + sources, + file_info, + hive_parts, + output_schema, + mut unified_scan_args, + predicate, + scan_type + }, Some(state)) if state.offset == 0 && predicate.is_none() => { + unified_scan_args.pre_slice = Some(state.to_slice_enum()); + + let lp = Scan { + sources, + file_info, + hive_parts, + output_schema, + predicate, + unified_scan_args, + scan_type + }; + + Ok(lp) + }, + (DataFrameScan {df, schema, output_schema, }, Some(state)) => { + let df = df.slice(state.offset, state.len as usize); + let lp = DataFrameScan { + df: Arc::new(df), + schema, + output_schema, + }; + Ok(lp) + } + (Union {mut inputs, mut options }, Some(state)) => { + if state.offset == 0 { + for input in &mut inputs { + let input_lp = lp_arena.take(*input); + let input_lp = self.pushdown(input_lp, Some(state), lp_arena, expr_arena)?; + lp_arena.replace(*input, input_lp); + } + } + // The in-memory union node is slice aware. + // We still set this information, but the streaming engine will ignore it. + options.slice = Some((state.offset, state.len as usize)); + let lp = Union {inputs, options}; + + if self.streaming { + // Ensure the slice node remains. + self.no_pushdown_finish_opt(lp, Some(state), lp_arena) + } else { + Ok(lp) + } + }, + (Join { + input_left, + input_right, + schema, + left_on, + right_on, + mut options + }, Some(state)) if !self.streaming && !matches!(options.options, Some(JoinTypeOptionsIR::Cross { .. })) => { + // first restart optimization in both inputs and get the updated LP + let lp_left = lp_arena.take(input_left); + let lp_left = self.pushdown(lp_left, None, lp_arena, expr_arena)?; + let input_left = lp_arena.add(lp_left); + + let lp_right = lp_arena.take(input_right); + let lp_right = self.pushdown(lp_right, None, lp_arena, expr_arena)?; + let input_right = lp_arena.add(lp_right); + + // then assign the slice state to the join operation + + let mut_options = Arc::make_mut(&mut options); + mut_options.args.slice = Some((state.offset, state.len as usize)); + + Ok(Join { + input_left, + input_right, + schema, + left_on, + right_on, + options + }) + } + (GroupBy { input, keys, aggs, schema, apply, maintain_order, mut options }, Some(state)) => { + // first restart optimization in inputs and get the updated LP + let input_lp = lp_arena.take(input); + let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; + let input= lp_arena.add(input_lp); + + let mut_options= Arc::make_mut(&mut options); + mut_options.slice = Some((state.offset, state.len as usize)); + + Ok(GroupBy { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options + }) + } + (Distinct {input, mut options}, Some(state)) => { + // first restart optimization in inputs and get the updated LP + let input_lp = lp_arena.take(input); + let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; + let input= lp_arena.add(input_lp); + options.slice = Some((state.offset, state.len as usize)); + Ok(Distinct { + input, + options, + }) + } + (Sort {input, by_column, mut slice, + sort_options}, Some(state)) => { + // first restart optimization in inputs and get the updated LP + let input_lp = lp_arena.take(input); + let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; + let input= lp_arena.add(input_lp); + + slice = Some((state.offset, state.len as usize)); + Ok(Sort { + input, + by_column, + slice, + sort_options + }) + } + (Slice { + input, + offset, + len + }, Some(previous_state)) => { + let alp = lp_arena.take(input); + let state = Some(if previous_state.offset == offset { + State { + offset, + len: std::cmp::min(len, previous_state.len) + } + } else { + State { + offset, + len + } + }); + let lp = self.pushdown(alp, state, lp_arena, expr_arena)?; + let input = lp_arena.add(lp); + Ok(Slice { + input, + offset: previous_state.offset, + len: previous_state.len + }) + } + (Slice { + input, + offset, + len + }, None) => { + let alp = lp_arena.take(input); + let state = Some(State { + offset, + len + }); + self.pushdown(alp, state, lp_arena, expr_arena) + } + // [Do not pushdown] boundary + // here we do not pushdown. + // we reset the state and then start the optimization again + m @ (Filter { .. }, _) + // other blocking nodes + | m @ (DataFrameScan {..}, _) + | m @ (Sort {..}, _) + | m @ (MapFunction {function: FunctionIR::Explode {..}, ..}, _) + | m @ (Cache {..}, _) + | m @ (Distinct {..}, _) + | m @ (GroupBy{..},_) + // blocking in streaming + | m @ (Join{..},_) + => { + let (lp, state) = m; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + }, + #[cfg(feature = "pivot")] + m @ (MapFunction {function: FunctionIR::Unpivot {..}, ..}, _) => { + let (lp, state) = m; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + }, + // [Pushdown] + (MapFunction {input, function}, _) if function.allow_predicate_pd() => { + let lp = MapFunction {input, function}; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + }, + // [NO Pushdown] + m @ (MapFunction {..}, _) => { + let (lp, state) = m; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + } + // [Pushdown] + // these nodes will be pushed down. + // State is None, we can continue + m @ (Select {..}, None) + | m @ (HStack {..}, None) + | m @ (SimpleProjection {..}, _) + => { + let (lp, state) = m; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + // there is state, inspect the projection to determine how to deal with it + (Select {input, expr, schema, options}, Some(_)) => { + if can_pushdown_slice_past_projections(&expr, expr_arena, self.empty_nodes_scratch_mut()).1 { + let lp = Select {input, expr, schema, options}; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + // don't push down slice, but restart optimization + else { + let lp = Select {input, expr, schema, options}; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + } + } + (HStack {input, exprs, schema, options}, _) => { + let (can_pushdown, can_pushdown_and_any_expr_has_column) = can_pushdown_slice_past_projections(&exprs, expr_arena, self.empty_nodes_scratch_mut()); + + if can_pushdown_and_any_expr_has_column || ( + // If the schema length is greater then an input column is being projected, so + // the exprs in with_columns do not need to have an input column name. + schema.len() > exprs.len() && can_pushdown + ) + { + let lp = HStack {input, exprs, schema, options}; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + // don't push down slice, but restart optimization + else { + let lp = HStack {input, exprs, schema, options}; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + } + } + (HConcat {inputs, schema, options}, _) => { + // Slice can always be pushed down for horizontal concatenation + let lp = HConcat {inputs, schema, options}; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + (lp @ Sink { .. }, _) | (lp @ SinkMultiple { .. }, _) => { + // Slice can always be pushed down for sinks + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + (catch_all, state) => { + self.no_pushdown_finish_opt(catch_all, state, lp_arena) + } + } + } + + pub fn optimize( + &mut self, + logical_plan: IR, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + self.pushdown(logical_plan, None, lp_arena, expr_arena) + } +} diff --git a/crates/polars-plan/src/plans/optimizer/stack_opt.rs b/crates/polars-plan/src/plans/optimizer/stack_opt.rs new file mode 100644 index 000000000000..8c1658a0ffda --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/stack_opt.rs @@ -0,0 +1,108 @@ +use polars_core::prelude::PolarsResult; + +use crate::plans::aexpr::AExpr; +use crate::plans::ir::IR; +use crate::prelude::{Arena, Node}; + +/// Optimizer that uses a stack and memory arenas in favor of recursion +pub struct StackOptimizer {} + +impl StackOptimizer { + pub fn optimize_loop( + &self, + rules: &mut [Box], + expr_arena: &mut Arena, + lp_arena: &mut Arena, + lp_top: Node, + ) -> PolarsResult { + let mut changed = true; + + // Nodes of expressions and lp node from which the expressions are a member of. + let mut plans = vec![]; + let mut exprs = vec![]; + let mut scratch = vec![]; + + // Run loop until reaching fixed point. + while changed { + // Recurse into sub plans and expressions and apply rules. + changed = false; + plans.push(lp_top); + while let Some(current_node) = plans.pop() { + // Apply rules + for rule in rules.iter_mut() { + // keep iterating over same rule + while let Some(x) = rule.optimize_plan(lp_arena, expr_arena, current_node)? { + lp_arena.replace(current_node, x); + changed = true; + } + } + + let plan = lp_arena.get(current_node); + + // traverse subplans and expressions and add to the stack + plan.copy_exprs(&mut scratch); + plan.copy_inputs(&mut plans); + + if scratch.is_empty() { + continue; + } + + while let Some(expr_ir) = scratch.pop() { + exprs.push(expr_ir.node()); + } + + // process the expressions on the stack and apply optimizations. + while let Some(current_expr_node) = exprs.pop() { + { + let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; + if expr.is_leaf() { + continue; + } + } + for rule in rules.iter_mut() { + // keep iterating over same rule + while let Some(x) = rule.optimize_expr( + expr_arena, + current_expr_node, + lp_arena, + current_node, + )? { + expr_arena.replace(current_expr_node, x); + changed = true; + } + } + + let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; + // traverse subexpressions and add to the stack + expr.inputs_rev(&mut exprs) + } + } + } + Ok(lp_top) + } +} + +pub trait OptimizationRule { + /// Optimize (subplan) in LogicalPlan + /// + /// * `lp_arena` - LogicalPlan memory arena + /// * `expr_arena` - Expression memory arena + /// * `node` - node of the current LogicalPlan node + fn optimize_plan( + &mut self, + _lp_arena: &mut Arena, + _expr_arena: &mut Arena, + _node: Node, + ) -> PolarsResult> { + Ok(None) + } + fn optimize_expr( + &mut self, + _expr_arena: &mut Arena, + _expr_node: Node, + _lp_arena: &Arena, + _lp_node: Node, + ) -> PolarsResult> { + Ok(None) + } +} diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs new file mode 100644 index 000000000000..200033d5db0f --- /dev/null +++ b/crates/polars-plan/src/plans/options.rs @@ -0,0 +1,284 @@ +use bitflags::bitflags; +use polars_core::prelude::*; +use polars_core::utils::SuperTypeOptions; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::plans::PlSmallStr; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub struct DistinctOptionsIR { + /// Subset of columns that will be taken into account. + pub subset: Option>, + /// This will maintain the order of the input. + /// Note that this is more expensive. + /// `maintain_order` is not supported in the streaming + /// engine. + pub maintain_order: bool, + /// Which rows to keep. + pub keep_strategy: UniqueKeepStrategy, + /// Take only a slice of the result + pub slice: Option<(i64, usize)>, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ApplyOptions { + /// Collect groups to a list and apply the function over the groups. + /// This can be important in aggregation context. + /// e.g. [g1, g1, g2] -> [[g1, g1], g2] + GroupWise, + /// collect groups to a list and then apply + /// e.g. [g1, g1, g2] -> list([g1, g1, g2]) + ApplyList, + /// do not collect before apply + /// e.g. [g1, g1, g2] -> [g1, g1, g2] + ElementWise, +} + +// a boolean that can only be set to `false` safely +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnsafeBool(bool); +impl Default for UnsafeBool { + fn default() -> Self { + UnsafeBool(true) + } +} + +bitflags!( + #[repr(transparent)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] + pub struct FunctionFlags: u8 { + // Raise if use in group by + const ALLOW_GROUP_AWARE = 1 << 0; + // For example a `unique` or a `slice` + const CHANGES_LENGTH = 1 << 1; + // The physical expression may rename the output of this function. + // If set to `false` the physical engine will ensure the left input + // expression is the output name. + const ALLOW_RENAME = 1 << 2; + // if set, then the `Series` passed to the function in the group_by operation + // will ensure the name is set. This is an extra heap allocation per group. + const PASS_NAME_TO_APPLY = 1 << 3; + /// There can be two ways of expanding wildcards: + /// + /// Say the schema is 'a', 'b' and there is a function `f`. In this case, `f('*')` can expand + /// to: + /// 1. `f('a', 'b')` + /// 2. `f('a'), f('b')` + /// + /// Setting this to true, will lead to behavior 1. + /// + /// This also accounts for regex expansion. + const INPUT_WILDCARD_EXPANSION = 1 << 4; + /// Automatically explode on unit length if it ran as final aggregation. + /// + /// this is the case for aggregations like sum, min, covariance etc. + /// We need to know this because we cannot see the difference between + /// the following functions based on the output type and number of elements: + /// + /// x: {1, 2, 3} + /// + /// head_1(x) -> {1} + /// sum(x) -> {4} + const RETURNS_SCALAR = 1 << 5; + /// This can happen with UDF's that use Polars within the UDF. + /// This can lead to recursively entering the engine and sometimes deadlocks. + /// This flag must be set to handle that. + const OPTIONAL_RE_ENTRANT = 1 << 6; + /// Whether this function allows no inputs. + const ALLOW_EMPTY_INPUTS = 1 << 7; + } +); + +impl Default for FunctionFlags { + fn default() -> Self { + Self::from_bits_truncate(0) | Self::ALLOW_GROUP_AWARE + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum CastingRules { + /// Whether information may be lost during cast. E.g. a float to int is considered lossy, + /// whereas int to int is considered lossless. + /// Overflowing is not considered in this flag, that's handled in `strict` casting + FirstArgLossless, + Supertype(SuperTypeOptions), +} + +impl CastingRules { + pub fn cast_to_supertypes() -> CastingRules { + Self::Supertype(Default::default()) + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +#[cfg_attr(any(feature = "serde"), derive(Serialize, Deserialize))] +pub struct FunctionOptions { + /// Collect groups to a list and apply the function over the groups. + /// This can be important in aggregation context. + pub collect_groups: ApplyOptions, + + // Validate the output of a `map`. + // this should always be true or we could OOB + pub check_lengths: UnsafeBool, + pub flags: FunctionFlags, + + // used for formatting, (only for anonymous functions) + #[cfg_attr(feature = "serde", serde(skip))] + pub fmt_str: &'static str, + /// Options used when deciding how to cast the arguments of the function. + #[cfg_attr(feature = "serde", serde(skip))] + pub cast_options: Option, +} + +impl FunctionOptions { + #[cfg(feature = "fused")] + pub(crate) unsafe fn no_check_lengths(&mut self) { + self.check_lengths = UnsafeBool(false); + } + pub fn check_lengths(&self) -> bool { + self.check_lengths.0 + } + + pub fn set_elementwise(&mut self) { + self.collect_groups = ApplyOptions::ElementWise + } + + pub fn is_elementwise(&self) -> bool { + matches!( + self.collect_groups, + ApplyOptions::ElementWise | ApplyOptions::ApplyList + ) && !self.flags.contains(FunctionFlags::CHANGES_LENGTH) + && !self.flags.contains(FunctionFlags::RETURNS_SCALAR) + } + + pub fn is_length_preserving(&self) -> bool { + !self.flags.contains(FunctionFlags::CHANGES_LENGTH) + } + + pub fn returns_scalar(&self) -> bool { + self.flags.contains(FunctionFlags::RETURNS_SCALAR) + } + + pub fn elementwise() -> FunctionOptions { + FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + ..Default::default() + } + } + + pub fn elementwise_with_infer() -> FunctionOptions { + Self::groupwise() + } + + pub fn row_separable() -> FunctionOptions { + Self::groupwise() + } + + pub fn length_preserving() -> FunctionOptions { + Self::groupwise() + } + + pub fn groupwise() -> FunctionOptions { + FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + ..Default::default() + } + } + + pub fn aggregation() -> FunctionOptions { + let mut options = Self::groupwise(); + options.flags |= FunctionFlags::RETURNS_SCALAR; + options + } + + pub fn with_supertyping(self, supertype_options: SuperTypeOptions) -> FunctionOptions { + self.with_casting_rules(CastingRules::Supertype(supertype_options)) + } + + pub fn with_casting_rules(mut self, casting_rules: CastingRules) -> FunctionOptions { + self.cast_options = Some(casting_rules); + self + } + + pub fn with_allow_rename(mut self, allow_rename: bool) -> FunctionOptions { + self.flags.set(FunctionFlags::ALLOW_RENAME, allow_rename); + self + } + + pub fn with_pass_name_to_apply(mut self, pass_name_to_apply: bool) -> Self { + self.flags + .set(FunctionFlags::PASS_NAME_TO_APPLY, pass_name_to_apply); + self + } + + pub fn with_input_wildcard_expansion( + mut self, + input_wildcard_expansion: bool, + ) -> FunctionOptions { + self.flags.set( + FunctionFlags::INPUT_WILDCARD_EXPANSION, + input_wildcard_expansion, + ); + self + } + + pub fn with_allow_empty_inputs(mut self, allow_empty_inputs: bool) -> FunctionOptions { + self.flags + .set(FunctionFlags::ALLOW_EMPTY_INPUTS, allow_empty_inputs); + self + } + + pub fn with_changes_length(mut self, changes_length: bool) -> FunctionOptions { + self.flags + .set(FunctionFlags::ALLOW_EMPTY_INPUTS, changes_length); + self + } +} + +impl Default for FunctionOptions { + fn default() -> Self { + FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + check_lengths: UnsafeBool(true), + fmt_str: Default::default(), + cast_options: Default::default(), + flags: Default::default(), + } + } +} + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct ProjectionOptions { + pub run_parallel: bool, + pub duplicate_check: bool, + // Should length-1 Series be broadcast to the length of the dataframe. + // Only used by CSE optimizer + pub should_broadcast: bool, +} + +impl Default for ProjectionOptions { + fn default() -> Self { + Self { + run_parallel: true, + duplicate_check: true, + should_broadcast: true, + } + } +} + +impl ProjectionOptions { + /// Conservatively merge the options of two [`ProjectionOptions`] + pub fn merge_options(&self, other: &Self) -> Self { + Self { + run_parallel: self.run_parallel & other.run_parallel, + duplicate_check: self.duplicate_check & other.duplicate_check, + should_broadcast: self.should_broadcast | other.should_broadcast, + } + } +} diff --git a/crates/polars-plan/src/plans/python/mod.rs b/crates/polars-plan/src/plans/python/mod.rs new file mode 100644 index 000000000000..944847b8ba0f --- /dev/null +++ b/crates/polars-plan/src/plans/python/mod.rs @@ -0,0 +1,7 @@ +pub mod predicate; +pub mod pyarrow; +mod source; +mod utils; + +pub use source::*; +pub use utils::*; diff --git a/crates/polars-plan/src/plans/python/predicate.rs b/crates/polars-plan/src/plans/python/predicate.rs new file mode 100644 index 000000000000..c6024ec25f7d --- /dev/null +++ b/crates/polars-plan/src/plans/python/predicate.rs @@ -0,0 +1,53 @@ +use polars_core::prelude::{AnyValue, PolarsResult}; +use polars_utils::pl_serialize; +use recursive::recursive; + +use crate::prelude::*; + +#[recursive] +fn accept_as_io_predicate(e: &Expr) -> bool { + const LIMIT: usize = 1 << 16; + match e { + Expr::Literal(lv) => match lv { + LiteralValue::Scalar(v) => match v.as_any_value() { + AnyValue::Binary(v) => v.len() <= LIMIT, + AnyValue::String(v) => v.len() <= LIMIT, + _ => true, + }, + LiteralValue::Series(s) => s.estimated_size() < LIMIT, + + // Don't accept dynamic types + LiteralValue::Dyn(_) => false, + _ => true, + }, + Expr::Wildcard | Expr::Column(_) => true, + Expr::BinaryExpr { left, right, .. } => { + accept_as_io_predicate(left) && accept_as_io_predicate(right) + }, + Expr::Ternary { + truthy, + falsy, + predicate, + } => { + accept_as_io_predicate(truthy) + && accept_as_io_predicate(falsy) + && accept_as_io_predicate(predicate) + }, + Expr::Cast { expr, .. } => accept_as_io_predicate(expr), + Expr::Alias(_, _) => true, + Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } => { + options.is_elementwise() && input.iter().all(accept_as_io_predicate) + }, + _ => false, + } +} + +pub fn serialize(expr: &Expr) -> PolarsResult>> { + if !accept_as_io_predicate(expr) { + return Ok(None); + } + let mut buf = vec![]; + pl_serialize::serialize_into_writer::<_, _, true>(&mut buf, expr)?; + + Ok(Some(buf)) +} diff --git a/crates/polars-plan/src/plans/python/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs new file mode 100644 index 000000000000..1f675c6961f3 --- /dev/null +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -0,0 +1,176 @@ +use std::fmt::Write; + +use polars_core::datatypes::AnyValue; +use polars_core::prelude::{TimeUnit, TimeZone}; + +use crate::prelude::*; + +#[derive(Default, Copy, Clone)] +pub struct PyarrowArgs { + // pyarrow doesn't allow `filter([True, False])` + // but does allow `filter(field("a").isin([True, False]))` + allow_literal_series: bool, +} + +fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String { + // note: `to_py_datetime` and the `Datetime` + // dtype have to be in-scope on the python side + match tz { + None => format!("to_py_datetime({},'{}')", v, tu.to_ascii()), + Some(tz) => format!("to_py_datetime({},'{}',{})", v, tu.to_ascii(), tz), + } +} + +// convert to a pyarrow expression that can be evaluated with pythons eval +pub fn predicate_to_pa( + predicate: Node, + expr_arena: &Arena, + args: PyarrowArgs, +) -> Option { + match expr_arena.get(predicate) { + AExpr::BinaryExpr { left, right, op } => { + if op.is_comparison_or_bitwise() { + let left = predicate_to_pa(*left, expr_arena, args)?; + let right = predicate_to_pa(*right, expr_arena, args)?; + Some(format!("({left} {op} {right})")) + } else { + None + } + }, + AExpr::Column(name) => Some(format!("pa.compute.field('{}')", name)), + AExpr::Literal(LiteralValue::Series(s)) => { + if !args.allow_literal_series || s.is_empty() || s.len() > 100 { + None + } else { + let mut list_repr = String::with_capacity(s.len() * 5); + list_repr.push('['); + for av in s.rechunk().iter() { + if let AnyValue::Boolean(v) = av { + let s = if v { "True" } else { "False" }; + write!(list_repr, "{},", s).unwrap(); + } else if let AnyValue::Datetime(v, tu, tz) = av { + let dtm = to_py_datetime(v, &tu, tz); + write!(list_repr, "{dtm},").unwrap(); + } else if let AnyValue::Date(v) = av { + write!(list_repr, "to_py_date({v}),").unwrap(); + } else { + write!(list_repr, "{av},").unwrap(); + } + } + // pop last comma + list_repr.pop(); + list_repr.push(']'); + Some(list_repr) + } + }, + AExpr::Literal(lv) => { + let av = lv.to_any_value()?; + let dtype = av.dtype(); + match av.as_borrowed() { + AnyValue::String(s) => Some(format!("'{s}'")), + AnyValue::Boolean(val) => { + // python bools are capitalized + if val { + Some("pa.compute.scalar(True)".to_string()) + } else { + Some("pa.compute.scalar(False)".to_string()) + } + }, + #[cfg(feature = "dtype-date")] + AnyValue::Date(v) => { + // the function `to_py_date` and the `Date` + // dtype have to be in scope on the python side + Some(format!("to_py_date({v})")) + }, + #[cfg(feature = "dtype-datetime")] + AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz)), + // Activate once pyarrow supports them + // #[cfg(feature = "dtype-time")] + // AnyValue::Time(v) => { + // // the function `to_py_time` has to be in scope + // // on the python side + // Some(format!("to_py_time(value={v})")) + // } + // #[cfg(feature = "dtype-duration")] + // AnyValue::Duration(v, tu) => { + // // the function `to_py_timedelta` has to be in scope + // // on the python side + // Some(format!( + // "to_py_timedelta(value={}, tu='{}')", + // v, + // tu.to_ascii() + // )) + // } + av => { + if dtype.is_float() { + let val = av.extract::()?; + Some(format!("{val}")) + } else if dtype.is_integer() { + let val = av.extract::()?; + Some(format!("{val}")) + } else { + None + } + }, + } + }, + #[cfg(feature = "is_in")] + AExpr::Function { + function: FunctionExpr::Boolean(BooleanFunction::IsIn { .. }), + input, + .. + } => { + let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?; + let mut args = args; + args.allow_literal_series = true; + let values = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?; + + Some(format!("({col}).isin({values})")) + }, + #[cfg(feature = "is_between")] + AExpr::Function { + function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }), + input, + .. + } => { + if !matches!(expr_arena.get(input.first()?.node()), AExpr::Column(_)) { + None + } else { + let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?; + let left_cmp_op = match closed { + ClosedInterval::None | ClosedInterval::Right => Operator::Gt, + ClosedInterval::Both | ClosedInterval::Left => Operator::GtEq, + }; + let right_cmp_op = match closed { + ClosedInterval::None | ClosedInterval::Left => Operator::Lt, + ClosedInterval::Both | ClosedInterval::Right => Operator::LtEq, + }; + + let lower = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?; + let upper = predicate_to_pa(input.get(2)?.node(), expr_arena, args)?; + + Some(format!( + "(({col} {left_cmp_op} {lower}) & ({col} {right_cmp_op} {upper}))" + )) + } + }, + AExpr::Function { + function, input, .. + } => { + let input = input.first().unwrap().node(); + let input = predicate_to_pa(input, expr_arena, args)?; + + match function { + FunctionExpr::Boolean(BooleanFunction::Not) => Some(format!("~({input})")), + FunctionExpr::Boolean(BooleanFunction::IsNull) => { + Some(format!("({input}).is_null()")) + }, + FunctionExpr::Boolean(BooleanFunction::IsNotNull) => { + Some(format!("~({input}).is_null()")) + }, + _ => None, + } + }, + _ => None, + } +} diff --git a/crates/polars-plan/src/plans/python/source.rs b/crates/polars-plan/src/plans/python/source.rs new file mode 100644 index 000000000000..57bb59fd40f2 --- /dev/null +++ b/crates/polars-plan/src/plans/python/source.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use polars_core::schema::SchemaRef; +use polars_utils::python_function::PythonFunction; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::dsl::python_dsl::PythonScanSource; +use crate::plans::{ExprIR, PlSmallStr}; + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub struct PythonOptions { + /// A function that returns a Python Generator. + /// The generator should produce Polars DataFrame's. + pub scan_fn: Option, + /// Schema of the file. + pub schema: SchemaRef, + /// Schema the reader will produce when the file is read. + pub output_schema: Option, + // Projected column names. + pub with_columns: Option>, + // Which interface is the python function. + pub python_source: PythonScanSource, + /// A `head` call passed to the reader. + pub n_rows: Option, + /// Optional predicate the reader must apply. + pub predicate: PythonPredicate, + /// Validate if the source gives the proper schema. + pub validate_schema: bool, +} + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub enum PythonPredicate { + // A pyarrow predicate python expression + // can be evaluated with python.eval + PyArrow(String), + Polars(ExprIR), + #[default] + None, +} diff --git a/crates/polars-plan/src/plans/python/utils.rs b/crates/polars-plan/src/plans/python/utils.rs new file mode 100644 index 000000000000..c4812f7d5e45 --- /dev/null +++ b/crates/polars-plan/src/plans/python/utils.rs @@ -0,0 +1,30 @@ +use polars_core::error::{PolarsResult, polars_err}; +use polars_core::frame::DataFrame; +use polars_core::schema::SchemaRef; +use polars_ffi::version_0::SeriesExport; +use pyo3::intern; +use pyo3::prelude::*; + +pub fn python_df_to_rust(py: Python, df: Bound) -> PolarsResult { + let err = |_| polars_err!(ComputeError: "expected a polars.DataFrame; got {}", df); + let pydf = df.getattr(intern!(py, "_df")).map_err(err)?; + + let width = pydf.call_method0(intern!(py, "width")).unwrap(); + let width = width.extract::().unwrap(); + + // Don't resize the Vec<> so that the drop of the SeriesExport will not be caleld. + let mut export: Vec = Vec::with_capacity(width); + let location = export.as_mut_ptr(); + + let _ = pydf + .call_method1(intern!(py, "_export_columns"), (location as usize,)) + .unwrap(); + + unsafe { polars_ffi::version_0::import_df(location, width) } +} + +pub(crate) fn python_schema_to_rust(py: Python, schema: Bound) -> PolarsResult { + let err = |_| polars_err!(ComputeError: "expected a polars.Schema; got {}", schema); + let df = schema.call_method0("to_frame").map_err(err)?; + python_df_to_rust(py, df).map(|df| df.schema().clone()) +} diff --git a/crates/polars-plan/src/plans/schema.rs b/crates/polars-plan/src/plans/schema.rs new file mode 100644 index 000000000000..18839a243ff8 --- /dev/null +++ b/crates/polars-plan/src/plans/schema.rs @@ -0,0 +1,459 @@ +use std::ops::Deref; +use std::sync::Mutex; + +use arrow::datatypes::ArrowSchemaRef; +use either::Either; +use polars_core::prelude::*; +use polars_utils::format_pl_smallstr; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::prelude::*; + +impl DslPlan { + // Warning! This should not be used on the DSL internally. + // All schema resolving should be done during conversion to [`IR`]. + + /// Compute the schema. This requires conversion to [`IR`] and type-resolving. + pub fn compute_schema(&self) -> PolarsResult { + let mut lp_arena = Default::default(); + let mut expr_arena = Default::default(); + let node = to_alp( + self.clone(), + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::schema_only(), + )?; + + Ok(lp_arena.get(node).schema(&lp_arena).into_owned()) + } +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct FileInfo { + /// Schema of the physical file. + /// + /// Notes: + /// - Does not include logical columns like `include_file_path` and row index. + /// - Always includes all hive columns. + pub schema: SchemaRef, + /// Stores the schema used for the reader, as the main schema can contain + /// extra hive columns. + pub reader_schema: Option>, + /// - known size + /// - estimated size (set to usize::max if unknown). + pub row_estimation: (Option, usize), +} + +// Manual default because `row_estimation.1` needs to be `usize::MAX`. +impl Default for FileInfo { + fn default() -> Self { + FileInfo { + schema: Default::default(), + reader_schema: None, + row_estimation: (None, usize::MAX), + } + } +} + +impl FileInfo { + /// Constructs a new [`FileInfo`]. + pub fn new( + schema: SchemaRef, + reader_schema: Option>, + row_estimation: (Option, usize), + ) -> Self { + Self { + schema: schema.clone(), + reader_schema, + row_estimation, + } + } + + /// Merge the [`Schema`] of a [`HivePartitions`] with the schema of this [`FileInfo`]. + pub fn update_schema_with_hive_schema(&mut self, hive_schema: SchemaRef) { + let schema = Arc::make_mut(&mut self.schema); + + for field in hive_schema.iter_fields() { + if let Some(existing) = schema.get_mut(&field.name) { + *existing = field.dtype().clone(); + } else { + schema + .insert_at_index(schema.len(), field.name, field.dtype.clone()) + .unwrap(); + } + } + } +} + +#[cfg(feature = "streaming")] +fn estimate_sizes( + known_size: Option, + estimated_size: usize, + filter_count: usize, +) -> (Option, usize) { + match (known_size, filter_count) { + (Some(known_size), 0) => (Some(known_size), estimated_size), + (None, 0) => (None, estimated_size), + (_, _) => ( + None, + (estimated_size as f32 * 0.9f32.powf(filter_count as f32)) as usize, + ), + } +} + +#[cfg(feature = "streaming")] +pub fn set_estimated_row_counts( + root: Node, + lp_arena: &mut Arena, + expr_arena: &Arena, + mut _filter_count: usize, + scratch: &mut Vec, +) -> (Option, usize, usize) { + use IR::*; + + fn apply_slice(out: &mut (Option, usize, usize), slice: Option<(i64, usize)>) { + if let Some((_, len)) = slice { + out.0 = out.0.map(|known_size| std::cmp::min(len, known_size)); + out.1 = std::cmp::min(len, out.1); + } + } + + match lp_arena.get(root) { + Filter { predicate, input } => { + _filter_count += expr_arena + .iter(predicate.node()) + .filter(|(_, ae)| matches!(ae, AExpr::BinaryExpr { .. })) + .count() + + 1; + set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch) + }, + Slice { input, len, .. } => { + let len = *len as usize; + let mut out = + set_estimated_row_counts(*input, lp_arena, expr_arena, _filter_count, scratch); + apply_slice(&mut out, Some((0, len))); + out + }, + Union { .. } => { + if let Union { + inputs, + mut options, + } = lp_arena.take(root) + { + let mut sum_output = (None, 0usize); + for input in &inputs { + let mut out = + set_estimated_row_counts(*input, lp_arena, expr_arena, 0, scratch); + if let Some((_offset, len)) = options.slice { + apply_slice(&mut out, Some((0, len))) + } + // todo! deal with known as well + let out = estimate_sizes(out.0, out.1, out.2); + sum_output.1 = sum_output.1.saturating_add(out.1); + } + options.rows = sum_output; + lp_arena.replace(root, Union { inputs, options }); + (sum_output.0, sum_output.1, 0) + } else { + unreachable!() + } + }, + Join { .. } => { + if let Join { + input_left, + input_right, + mut options, + schema, + left_on, + right_on, + } = lp_arena.take(root) + { + let mut_options = Arc::make_mut(&mut options); + let (known_size, estimated_size, filter_count_left) = + set_estimated_row_counts(input_left, lp_arena, expr_arena, 0, scratch); + mut_options.rows_left = + estimate_sizes(known_size, estimated_size, filter_count_left); + let (known_size, estimated_size, filter_count_right) = + set_estimated_row_counts(input_right, lp_arena, expr_arena, 0, scratch); + mut_options.rows_right = + estimate_sizes(known_size, estimated_size, filter_count_right); + + let mut out = match options.args.how { + JoinType::Left => { + let (known_size, estimated_size) = options.rows_left; + (known_size, estimated_size, filter_count_left) + }, + JoinType::Cross | JoinType::Full => { + let (known_size_left, estimated_size_left) = options.rows_left; + let (known_size_right, estimated_size_right) = options.rows_right; + match (known_size_left, known_size_right) { + (Some(l), Some(r)) => { + (Some(l * r), estimated_size_left, estimated_size_right) + }, + _ => (None, estimated_size_left * estimated_size_right, 0), + } + }, + _ => { + let (known_size_left, estimated_size_left) = options.rows_left; + let (known_size_right, estimated_size_right) = options.rows_right; + if estimated_size_left > estimated_size_right { + (known_size_left, estimated_size_left, 0) + } else { + (known_size_right, estimated_size_right, 0) + } + }, + }; + apply_slice(&mut out, options.args.slice); + lp_arena.replace( + root, + Join { + input_left, + input_right, + options, + schema, + left_on, + right_on, + }, + ); + out + } else { + unreachable!() + } + }, + DataFrameScan { df, .. } => { + let len = df.height(); + (Some(len), len, _filter_count) + }, + Scan { file_info, .. } => { + let (known_size, estimated_size) = file_info.row_estimation; + (known_size, estimated_size, _filter_count) + }, + #[cfg(feature = "python")] + PythonScan { .. } => { + // TODO! get row estimation. + (None, usize::MAX, _filter_count) + }, + lp => { + lp.copy_inputs(scratch); + let mut sum_output = (None, 0, 0); + while let Some(input) = scratch.pop() { + let out = + set_estimated_row_counts(input, lp_arena, expr_arena, _filter_count, scratch); + sum_output.1 += out.1; + sum_output.2 += out.2; + sum_output.0 = match sum_output.0 { + None => out.0, + p => p, + }; + } + sum_output + }, + } +} + +pub(crate) fn det_join_schema( + schema_left: &SchemaRef, + schema_right: &SchemaRef, + left_on: &[ExprIR], + right_on: &[ExprIR], + options: &JoinOptions, + expr_arena: &Arena, +) -> PolarsResult { + match &options.args.how { + // semi and anti joins are just filtering operations + // the schema will never change. + #[cfg(feature = "semi_anti_join")] + JoinType::Semi | JoinType::Anti => Ok(schema_left.clone()), + // Right-join with coalesce enabled will coalesce LHS columns into RHS columns (i.e. LHS columns + // are removed). This is the opposite of what a left join does so it has its own codepath. + // + // E.g. df(cols=[A, B]).right_join(df(cols=[A, B]), on=A, coalesce=True) + // + // will result in + // + // df(cols=[B, A, B_right]) + JoinType::Right if options.args.should_coalesce() => { + // Get join names. + let mut join_on_left: PlHashSet<_> = PlHashSet::with_capacity(left_on.len()); + for e in left_on { + let field = e.field(schema_left, Context::Default, expr_arena)?; + join_on_left.insert(field.name); + } + + let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len()); + for e in right_on { + let field = e.field(schema_right, Context::Default, expr_arena)?; + join_on_right.insert(field.name); + } + + // For the error message + let mut suffixed = None; + + let new_schema = Schema::with_capacity(schema_left.len() + schema_right.len()) + // Columns from left, excluding those used as join keys + .hstack(schema_left.iter().filter_map(|(name, dtype)| { + if join_on_left.contains(name) { + return None; + } + + Some((name.clone(), dtype.clone())) + }))? + // Columns from right + .hstack(schema_right.iter().map(|(name, dtype)| { + suffixed = None; + + let in_left_schema = schema_left.contains(name.as_str()); + let is_coalesced = join_on_left.contains(name.as_str()); + + if in_left_schema && !is_coalesced { + suffixed = Some(format_pl_smallstr!("{}{}", name, options.args.suffix())); + (suffixed.clone().unwrap(), dtype.clone()) + } else { + (name.clone(), dtype.clone()) + } + })) + .map_err(|e| { + if let Some(column) = suffixed { + join_suffix_duplicate_help_msg(&column) + } else { + e + } + })?; + + Ok(Arc::new(new_schema)) + }, + _how => { + let mut new_schema = Schema::with_capacity(schema_left.len() + schema_right.len()) + .hstack(schema_left.iter_fields())?; + + let is_coalesced = options.args.should_coalesce(); + + let mut _asof_pre_added_rhs_keys: PlHashSet = PlHashSet::new(); + + // Handles coalescing of asof-joins. + // Asof joins are not equi-joins + // so the columns that are joined on, may have different + // values so if the right has a different name, it is added to the schema + #[cfg(feature = "asof_join")] + if matches!(_how, JoinType::AsOf(_)) { + for (left_on, right_on) in left_on.iter().zip(right_on) { + let field_left = left_on.field(schema_left, Context::Default, expr_arena)?; + let field_right = right_on.field(schema_right, Context::Default, expr_arena)?; + + if is_coalesced && field_left.name != field_right.name { + _asof_pre_added_rhs_keys.insert(field_right.name.clone()); + + if schema_left.contains(&field_right.name) { + new_schema.with_column( + _join_suffix_name(&field_right.name, options.args.suffix()), + field_right.dtype, + ); + } else { + new_schema.with_column(field_right.name, field_right.dtype); + } + } + } + } + + let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len()); + for e in right_on { + let field = e.field(schema_right, Context::Default, expr_arena)?; + join_on_right.insert(field.name); + } + + for (name, dtype) in schema_right.iter() { + #[cfg(feature = "asof_join")] + { + if let JoinType::AsOf(asof_options) = &options.args.how { + // Asof adds keys earlier + if _asof_pre_added_rhs_keys.contains(name) { + continue; + } + + // Asof join by columns are coalesced + if asof_options + .right_by + .as_deref() + .is_some_and(|x| x.contains(name)) + { + // Do not add suffix. The column of the left table will be used + continue; + } + } + } + + if join_on_right.contains(name.as_str()) && is_coalesced { + // Column will be coalesced into an already added LHS column. + continue; + } + + // For the error message. + let mut suffixed = None; + + let (name, dtype) = if schema_left.contains(name) { + suffixed = Some(format_pl_smallstr!("{}{}", name, options.args.suffix())); + (suffixed.clone().unwrap(), dtype.clone()) + } else { + (name.clone(), dtype.clone()) + }; + + new_schema.try_insert(name, dtype).map_err(|e| { + if let Some(column) = suffixed { + join_suffix_duplicate_help_msg(&column) + } else { + e + } + })?; + } + + Ok(Arc::new(new_schema)) + }, + } +} + +fn join_suffix_duplicate_help_msg(column_name: &str) -> PolarsError { + polars_err!( + Duplicate: + "\ +column with name '{}' already exists + +You may want to try: +- renaming the column prior to joining +- using the `suffix` parameter to specify a suffix different to the default one ('_right')", + column_name + ) +} + +// We don't use an `Arc` because caches should live in different query plans. +// For that reason we have a specialized deep clone. +#[derive(Default)] +pub struct CachedSchema(Mutex>); + +impl AsRef>> for CachedSchema { + fn as_ref(&self) -> &Mutex> { + &self.0 + } +} + +impl Deref for CachedSchema { + type Target = Mutex>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Clone for CachedSchema { + fn clone(&self) -> Self { + let inner = self.0.lock().unwrap(); + Self(Mutex::new(inner.clone())) + } +} + +impl CachedSchema { + pub fn get(&self) -> Option { + self.0.lock().unwrap().clone() + } +} diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs new file mode 100644 index 000000000000..2db617bb453d --- /dev/null +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -0,0 +1,317 @@ +use std::fmt::Debug; +#[cfg(feature = "cse")] +use std::fmt::Formatter; + +use polars_core::prelude::{Field, Schema}; +use polars_utils::unitvec; + +use super::*; +use crate::prelude::*; + +impl TreeWalker for Expr { + type Arena = (); + + fn apply_children PolarsResult>( + &self, + op: &mut F, + arena: &Self::Arena, + ) -> PolarsResult { + let mut scratch = unitvec![]; + + self.nodes(&mut scratch); + + for &child in scratch.as_slice() { + match op(child, arena)? { + // let the recursion continue + VisitRecursion::Continue | VisitRecursion::Skip => {}, + // early stop + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children PolarsResult>( + self, + f: &mut F, + _arena: &mut Self::Arena, + ) -> PolarsResult { + use polars_utils::functions::try_arc_map as am; + let mut f = |expr| f(expr, &mut ()); + use AggExpr::*; + use Expr::*; + #[rustfmt::skip] + let ret = match self { + Alias(l, r) => Alias(am(l, f)?, r), + Column(_) => self, + Columns(_) => self, + DtypeColumn(_) => self, + IndexColumn(_) => self, + Literal(_) => self, + #[cfg(feature = "dtype-struct")] + Field(_) => self, + BinaryExpr { left, op, right } => { + BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?} + }, + Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict }, + Sort { expr, options } => Sort { expr: am(expr, f)?, options }, + Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar }, + SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, sort_options }, + Agg(agg_expr) => Agg(match agg_expr { + Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans }, + Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans }, + Median(x) => Median(am(x, f)?), + NUnique(x) => NUnique(am(x, f)?), + First(x) => First(am(x, f)?), + Last(x) => Last(am(x, f)?), + Mean(x) => Mean(am(x, f)?), + Implode(x) => Implode(am(x, f)?), + Count(x, nulls) => Count(am(x, f)?, nulls), + Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol }, + Sum(x) => Sum(am(x, f)?), + AggGroups(x) => AggGroups(am(x, f)?), + Std(x, ddf) => Std(am(x, f)?, ddf), + Var(x, ddf) => Var(am(x, f)?, ddf), + }), + Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? }, + Function { input, function, options } => Function { input: input.into_iter().map(f).collect::>()?, function, options }, + Explode(expr) => Explode(am(expr, f)?), + Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? }, + Window { function, partition_by, order_by, options } => { + let partition_by = partition_by.into_iter().map(&mut f).collect::>()?; + Window { function: am(function, f)?, partition_by, order_by, options } + }, + Wildcard => Wildcard, + Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? }, + Exclude(expr, excluded) => Exclude(am(expr, f)?, excluded), + KeepName(expr) => KeepName(am(expr, f)?), + Len => Len, + Nth(_) => self, + RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? }, + AnonymousFunction { input, function, output_type, options } => { + AnonymousFunction { input: input.into_iter().map(f).collect::>()?, function, output_type, options } + }, + SubPlan(_, _) => self, + Selector(_) => self, + }; + Ok(ret) + } +} + +#[derive(Copy, Clone, Debug)] +pub struct AexprNode { + node: Node, +} + +impl AexprNode { + pub fn new(node: Node) -> Self { + Self { node } + } + + /// Get the `Node`. + pub fn node(&self) -> Node { + self.node + } + + pub fn to_aexpr<'a>(&self, arena: &'a Arena) -> &'a AExpr { + arena.get(self.node) + } + + pub fn to_expr(&self, arena: &Arena) -> Expr { + node_to_expr(self.node, arena) + } + + pub fn to_field(&self, schema: &Schema, arena: &Arena) -> PolarsResult { + let aexpr = arena.get(self.node); + aexpr.to_field(schema, Context::Default, arena) + } + + pub fn assign(&mut self, ae: AExpr, arena: &mut Arena) { + let node = arena.add(ae); + self.node = node; + } + + #[cfg(feature = "cse")] + pub(crate) fn is_leaf(&self, arena: &Arena) -> bool { + matches!(self.to_aexpr(arena), AExpr::Column(_) | AExpr::Literal(_)) + } + + #[cfg(feature = "cse")] + pub(crate) fn hashable_and_cmp<'a>(&self, arena: &'a Arena) -> AExprArena<'a> { + AExprArena { + node: self.node, + arena, + } + } +} + +#[cfg(feature = "cse")] +pub struct AExprArena<'a> { + node: Node, + arena: &'a Arena, +} + +#[cfg(feature = "cse")] +impl Debug for AExprArena<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "AexprArena: {}", self.node.0) + } +} + +impl AExpr { + #[cfg(feature = "cse")] + fn is_equal_node(&self, other: &Self) -> bool { + use AExpr::*; + match (self, other) { + (Alias(_, l), Alias(_, r)) => l == r, + (Column(l), Column(r)) => l == r, + (Literal(l), Literal(r)) => l == r, + (Window { options: l, .. }, Window { options: r, .. }) => l == r, + ( + Cast { + options: strict_l, + dtype: dtl, + .. + }, + Cast { + options: strict_r, + dtype: dtr, + .. + }, + ) => strict_l == strict_r && dtl == dtr, + (Sort { options: l, .. }, Sort { options: r, .. }) => l == r, + (Gather { .. }, Gather { .. }) + | (Filter { .. }, Filter { .. }) + | (Ternary { .. }, Ternary { .. }) + | (Len, Len) + | (Slice { .. }, Slice { .. }) + | (Explode(_), Explode(_)) => true, + ( + SortBy { + sort_options: l_sort_options, + .. + }, + SortBy { + sort_options: r_sort_options, + .. + }, + ) => l_sort_options == r_sort_options, + (Agg(l), Agg(r)) => l.equal_nodes(r), + ( + Function { + input: il, + function: fl, + options: ol, + }, + Function { + input: ir, + function: fr, + options: or, + }, + ) => { + fl == fr && ol == or && { + let mut all_same_name = true; + for (l, r) in il.iter().zip(ir) { + all_same_name &= l.output_name() == r.output_name() + } + + all_same_name + } + }, + (AnonymousFunction { .. }, AnonymousFunction { .. }) => false, + (BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r, + _ => false, + } + } +} + +#[cfg(feature = "cse")] +impl<'a> AExprArena<'a> { + pub fn new(node: Node, arena: &'a Arena) -> Self { + Self { node, arena } + } + pub fn to_aexpr(&self) -> &'a AExpr { + self.arena.get(self.node) + } + + // Check single node on equality + pub fn is_equal_single(&self, other: &Self) -> bool { + let self_ae = self.to_aexpr(); + let other_ae = other.to_aexpr(); + self_ae.is_equal_node(other_ae) + } +} + +#[cfg(feature = "cse")] +impl PartialEq for AExprArena<'_> { + fn eq(&self, other: &Self) -> bool { + let mut scratch1 = unitvec![]; + let mut scratch2 = unitvec![]; + + scratch1.push(self.node); + scratch2.push(other.node); + + loop { + match (scratch1.pop(), scratch2.pop()) { + (Some(l), Some(r)) => { + let l = Self::new(l, self.arena); + let r = Self::new(r, self.arena); + + if !l.is_equal_single(&r) { + return false; + } + + l.to_aexpr().inputs_rev(&mut scratch1); + r.to_aexpr().inputs_rev(&mut scratch2); + }, + (None, None) => return true, + _ => return false, + } + } + } +} + +impl TreeWalker for AexprNode { + type Arena = Arena; + fn apply_children PolarsResult>( + &self, + op: &mut F, + arena: &Self::Arena, + ) -> PolarsResult { + let mut scratch = unitvec![]; + + self.to_aexpr(arena).inputs_rev(&mut scratch); + for node in scratch.as_slice() { + let aenode = AexprNode::new(*node); + match op(&aenode, arena)? { + // let the recursion continue + VisitRecursion::Continue | VisitRecursion::Skip => {}, + // early stop + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children PolarsResult>( + mut self, + op: &mut F, + arena: &mut Self::Arena, + ) -> PolarsResult { + let mut scratch = unitvec![]; + + let ae = arena.get(self.node).clone(); + ae.inputs_rev(&mut scratch); + + // rewrite the nodes + for node in scratch.as_mut_slice() { + let aenode = AexprNode::new(*node); + *node = op(aenode, arena)?.node; + } + + scratch.as_mut_slice().reverse(); + let ae = ae.replace_inputs(&scratch); + self.node = arena.add(ae); + Ok(self) + } +} diff --git a/crates/polars-plan/src/plans/visitor/hash.rs b/crates/polars-plan/src/plans/visitor/hash.rs new file mode 100644 index 000000000000..b10d03352844 --- /dev/null +++ b/crates/polars-plan/src/plans/visitor/hash.rs @@ -0,0 +1,515 @@ +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use polars_utils::arena::Arena; + +use super::*; +use crate::plans::{AExpr, IR}; +use crate::prelude::ExprIR; +use crate::prelude::aexpr::traverse_and_hash_aexpr; + +impl IRNode { + pub(crate) fn hashable_and_cmp<'a>( + &'a self, + lp_arena: &'a Arena, + expr_arena: &'a Arena, + ) -> HashableEqLP<'a> { + HashableEqLP { + node: *self, + lp_arena, + expr_arena, + ignore_cache: false, + } + } +} + +pub(crate) struct HashableEqLP<'a> { + node: IRNode, + lp_arena: &'a Arena, + expr_arena: &'a Arena, + ignore_cache: bool, +} + +impl HashableEqLP<'_> { + /// When encountering a Cache node, ignore it and take the input. + #[cfg(feature = "cse")] + pub(crate) fn ignore_caches(mut self) -> Self { + self.ignore_cache = true; + self + } +} + +fn hash_option_expr(expr: &Option, expr_arena: &Arena, state: &mut H) { + if let Some(e) = expr { + e.traverse_and_hash(expr_arena, state) + } +} + +fn hash_exprs(exprs: &[ExprIR], expr_arena: &Arena, state: &mut H) { + for e in exprs { + e.traverse_and_hash(expr_arena, state); + } +} + +impl Hash for HashableEqLP<'_> { + // This hashes the variant, not the whole plan + fn hash(&self, state: &mut H) { + let alp = self.node.to_alp(self.lp_arena); + std::mem::discriminant(alp).hash(state); + match alp { + #[cfg(feature = "python")] + IR::PythonScan { .. } => {}, + IR::Slice { + offset, + len, + input: _, + } => { + len.hash(state); + offset.hash(state); + }, + IR::Filter { + input: _, + predicate, + } => { + predicate.traverse_and_hash(self.expr_arena, state); + }, + IR::Scan { + sources, + file_info: _, + hive_parts: _, + predicate, + output_schema: _, + scan_type, + unified_scan_args, + } => { + // We don't have to traverse the schema, hive partitions etc. as they are derivative from the paths. + scan_type.hash(state); + sources.hash(state); + hash_option_expr(predicate, self.expr_arena, state); + unified_scan_args.hash(state); + }, + IR::DataFrameScan { + df, + schema: _, + output_schema, + .. + } => { + (Arc::as_ptr(df) as usize).hash(state); + output_schema.hash(state); + }, + IR::SimpleProjection { columns, input: _ } => { + columns.hash(state); + }, + IR::Select { + input: _, + expr, + schema: _, + options, + } => { + hash_exprs(expr, self.expr_arena, state); + options.hash(state); + }, + IR::Sort { + input: _, + by_column, + slice, + sort_options, + } => { + hash_exprs(by_column, self.expr_arena, state); + slice.hash(state); + sort_options.hash(state); + }, + IR::GroupBy { + input: _, + keys, + aggs, + schema: _, + apply, + maintain_order, + options, + } => { + hash_exprs(keys, self.expr_arena, state); + hash_exprs(aggs, self.expr_arena, state); + apply.is_none().hash(state); + maintain_order.hash(state); + options.hash(state); + }, + IR::Join { + input_left: _, + input_right: _, + schema: _, + left_on, + right_on, + options, + } => { + hash_exprs(left_on, self.expr_arena, state); + hash_exprs(right_on, self.expr_arena, state); + options.hash(state); + }, + IR::HStack { + input: _, + exprs, + schema: _, + options, + } => { + hash_exprs(exprs, self.expr_arena, state); + options.hash(state); + }, + IR::Distinct { input: _, options } => { + options.hash(state); + }, + IR::MapFunction { input: _, function } => { + function.hash(state); + }, + IR::Union { inputs: _, options } => options.hash(state), + IR::HConcat { + inputs: _, + schema: _, + options, + } => { + options.hash(state); + }, + IR::ExtContext { + input: _, + contexts, + schema: _, + } => { + for node in contexts { + traverse_and_hash_aexpr(*node, self.expr_arena, state); + } + }, + IR::Sink { input: _, payload } => { + payload.traverse_and_hash(self.expr_arena, state); + }, + IR::SinkMultiple { .. } => {}, + IR::Cache { + input: _, + id, + cache_hits, + } => { + id.hash(state); + cache_hits.hash(state); + }, + #[cfg(feature = "merge_sorted")] + IR::MergeSorted { + input_left: _, + input_right: _, + key, + } => { + key.hash(state); + }, + IR::Invalid => unreachable!(), + } + } +} + +fn expr_irs_eq(l: &[ExprIR], r: &[ExprIR], expr_arena: &Arena) -> bool { + l.len() == r.len() && l.iter().zip(r).all(|(l, r)| expr_ir_eq(l, r, expr_arena)) +} + +fn expr_ir_eq(l: &ExprIR, r: &ExprIR, expr_arena: &Arena) -> bool { + l.get_alias() == r.get_alias() && { + let l = AexprNode::new(l.node()); + let r = AexprNode::new(r.node()); + l.hashable_and_cmp(expr_arena) == r.hashable_and_cmp(expr_arena) + } +} + +fn opt_expr_ir_eq(l: &Option, r: &Option, expr_arena: &Arena) -> bool { + match (l, r) { + (None, None) => true, + (Some(l), Some(r)) => expr_ir_eq(l, r, expr_arena), + _ => false, + } +} + +impl HashableEqLP<'_> { + fn is_equal(&self, other: &Self) -> bool { + let alp_l = self.node.to_alp(self.lp_arena); + let alp_r = other.node.to_alp(self.lp_arena); + if std::mem::discriminant(alp_l) != std::mem::discriminant(alp_r) { + return false; + } + match (alp_l, alp_r) { + ( + IR::Slice { + input: _, + offset: ol, + len: ll, + }, + IR::Slice { + input: _, + offset: or, + len: lr, + }, + ) => ol == or && ll == lr, + ( + IR::Filter { + input: _, + predicate: l, + }, + IR::Filter { + input: _, + predicate: r, + }, + ) => expr_ir_eq(l, r, self.expr_arena), + ( + IR::Scan { + sources: pl, + file_info: _, + hive_parts: _, + predicate: pred_l, + output_schema: _, + scan_type: stl, + unified_scan_args: ol, + }, + IR::Scan { + sources: pr, + file_info: _, + hive_parts: _, + predicate: pred_r, + output_schema: _, + scan_type: str, + unified_scan_args: or, + }, + ) => { + pl == pr + && stl == str + && ol == or + && opt_expr_ir_eq(pred_l, pred_r, self.expr_arena) + }, + ( + IR::DataFrameScan { + df: dfl, + schema: _, + output_schema: s_l, + }, + IR::DataFrameScan { + df: dfr, + schema: _, + output_schema: s_r, + }, + ) => std::ptr::eq(Arc::as_ptr(dfl), Arc::as_ptr(dfr)) && s_l == s_r, + ( + IR::SimpleProjection { + input: _, + columns: cl, + }, + IR::SimpleProjection { + input: _, + columns: cr, + }, + ) => cl == cr, + ( + IR::Select { + input: _, + expr: el, + options: ol, + schema: _, + }, + IR::Select { + input: _, + expr: er, + options: or, + schema: _, + }, + ) => ol == or && expr_irs_eq(el, er, self.expr_arena), + ( + IR::Sort { + input: _, + by_column: cl, + slice: l_slice, + sort_options: l_options, + }, + IR::Sort { + input: _, + by_column: cr, + slice: r_slice, + sort_options: r_options, + }, + ) => { + (l_slice == r_slice && l_options == r_options) + && expr_irs_eq(cl, cr, self.expr_arena) + }, + ( + IR::GroupBy { + input: _, + keys: keys_l, + aggs: aggs_l, + schema: _, + apply: apply_l, + maintain_order: maintain_l, + options: ol, + }, + IR::GroupBy { + input: _, + keys: keys_r, + aggs: aggs_r, + schema: _, + apply: apply_r, + maintain_order: maintain_r, + options: or, + }, + ) => { + apply_l.is_none() + && apply_r.is_none() + && ol == or + && maintain_l == maintain_r + && expr_irs_eq(keys_l, keys_r, self.expr_arena) + && expr_irs_eq(aggs_l, aggs_r, self.expr_arena) + }, + ( + IR::Join { + input_left: _, + input_right: _, + schema: _, + left_on: ll, + right_on: rl, + options: ol, + }, + IR::Join { + input_left: _, + input_right: _, + schema: _, + left_on: lr, + right_on: rr, + options: or, + }, + ) => { + ol == or + && expr_irs_eq(ll, lr, self.expr_arena) + && expr_irs_eq(rl, rr, self.expr_arena) + }, + ( + IR::HStack { + input: _, + exprs: el, + schema: _, + options: ol, + }, + IR::HStack { + input: _, + exprs: er, + schema: _, + options: or, + }, + ) => ol == or && expr_irs_eq(el, er, self.expr_arena), + ( + IR::Distinct { + input: _, + options: ol, + }, + IR::Distinct { + input: _, + options: or, + }, + ) => ol == or, + ( + IR::MapFunction { + input: _, + function: l, + }, + IR::MapFunction { + input: _, + function: r, + }, + ) => l == r, + ( + IR::Union { + inputs: _, + options: l, + }, + IR::Union { + inputs: _, + options: r, + }, + ) => l == r, + ( + IR::HConcat { + inputs: _, + schema: _, + options: l, + }, + IR::HConcat { + inputs: _, + schema: _, + options: r, + }, + ) => l == r, + ( + IR::ExtContext { + input: _, + contexts: l, + schema: _, + }, + IR::ExtContext { + input: _, + contexts: r, + schema: _, + }, + ) => { + l.len() == r.len() + && l.iter().zip(r.iter()).all(|(l, r)| { + let l = AexprNode::new(*l).hashable_and_cmp(self.expr_arena); + let r = AexprNode::new(*r).hashable_and_cmp(self.expr_arena); + l == r + }) + }, + _ => false, + } + } +} + +impl PartialEq for HashableEqLP<'_> { + fn eq(&self, other: &Self) -> bool { + let mut scratch_1 = vec![]; + let mut scratch_2 = vec![]; + + scratch_1.push(self.node.node()); + scratch_2.push(other.node.node()); + + loop { + match (scratch_1.pop(), scratch_2.pop()) { + (Some(l), Some(r)) => { + let l = IRNode::new(l); + let r = IRNode::new(r); + let l_alp = l.to_alp(self.lp_arena); + let r_alp = r.to_alp(self.lp_arena); + + if self.ignore_cache { + match (l_alp, r_alp) { + (IR::Cache { input: l, .. }, IR::Cache { input: r, .. }) => { + scratch_1.push(*l); + scratch_2.push(*r); + continue; + }, + (IR::Cache { input: l, .. }, _) => { + scratch_1.push(*l); + scratch_2.push(r.node()); + continue; + }, + (_, IR::Cache { input: r, .. }) => { + scratch_1.push(l.node()); + scratch_2.push(*r); + continue; + }, + _ => {}, + } + } + + if !l + .hashable_and_cmp(self.lp_arena, self.expr_arena) + .is_equal(&r.hashable_and_cmp(self.lp_arena, self.expr_arena)) + { + return false; + } + + l_alp.copy_inputs(&mut scratch_1); + r_alp.copy_inputs(&mut scratch_2); + }, + (None, None) => return true, + _ => return false, + } + } + } +} diff --git a/crates/polars-plan/src/plans/visitor/lp.rs b/crates/polars-plan/src/plans/visitor/lp.rs new file mode 100644 index 000000000000..87b22a627bd7 --- /dev/null +++ b/crates/polars-plan/src/plans/visitor/lp.rs @@ -0,0 +1,133 @@ +use polars_utils::unitvec; + +use super::*; +use crate::prelude::*; + +#[derive(Copy, Clone, Debug)] +pub struct IRNode { + node: Node, + // Whether it may mutate the arena on rewrite. + // If set the Rewriting Treewalker will mutate the arena. + mutate: bool, +} + +impl IRNode { + pub fn new(node: Node) -> Self { + Self { + node, + mutate: false, + } + } + + pub fn new_mutate(node: Node) -> Self { + Self { node, mutate: true } + } + + pub fn node(&self) -> Node { + self.node + } + + pub fn replace_node(&mut self, node: Node) { + self.node = node; + } + + /// Replace the current `Node` with a new `IR`. + pub fn replace(&mut self, ae: IR, arena: &mut Arena) { + let node = self.node; + arena.replace(node, ae); + } + + pub fn to_alp<'a>(&self, arena: &'a Arena) -> &'a IR { + arena.get(self.node) + } + + pub fn to_alp_mut<'a>(&mut self, arena: &'a mut Arena) -> &'a mut IR { + arena.get_mut(self.node) + } + + pub fn assign(&mut self, ir_node: IR, arena: &mut Arena) { + let node = arena.add(ir_node); + self.node = node; + } +} + +pub type IRNodeArena = (Arena, Arena); + +impl TreeWalker for IRNode { + type Arena = IRNodeArena; + + fn apply_children PolarsResult>( + &self, + op: &mut F, + arena: &Self::Arena, + ) -> PolarsResult { + let mut scratch = unitvec![]; + + self.to_alp(&arena.0).copy_inputs(&mut scratch); + for &node in scratch.as_slice() { + let mut lp_node = IRNode::new(node); + lp_node.mutate = self.mutate; + match op(&lp_node, arena)? { + // let the recursion continue + VisitRecursion::Continue | VisitRecursion::Skip => {}, + // early stop + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children PolarsResult>( + self, + op: &mut F, + arena: &mut Self::Arena, + ) -> PolarsResult { + let mut inputs = vec![]; + let mut exprs = vec![]; + + let lp = arena.0.get(self.node); + lp.copy_inputs(&mut inputs); + lp.copy_exprs(&mut exprs); + + // rewrite the nodes + for node in &mut inputs { + let mut lp_node = IRNode::new(*node); + lp_node.mutate = self.mutate; + *node = op(lp_node, arena)?.node; + } + let lp = arena.0.get(self.node); + let lp = lp.with_exprs_and_input(exprs, inputs); + if self.mutate { + arena.0.replace(self.node, lp); + Ok(self) + } else { + let node = arena.0.add(lp); + Ok(IRNode::new_mutate(node)) + } + } +} + +#[cfg(feature = "cse")] +pub(crate) fn with_ir_arena T, T>( + lp_arena: &mut Arena, + expr_arena: &mut Arena, + func: F, +) -> T { + try_with_ir_arena(lp_arena, expr_arena, |a| Ok(func(a))).unwrap() +} + +#[cfg(feature = "cse")] +pub(crate) fn try_with_ir_arena PolarsResult, T>( + lp_arena: &mut Arena, + expr_arena: &mut Arena, + func: F, +) -> PolarsResult { + let owned_lp_arena = std::mem::take(lp_arena); + let owned_expr_arena = std::mem::take(expr_arena); + + let mut arena = (owned_lp_arena, owned_expr_arena); + let out = func(&mut arena); + std::mem::swap(lp_arena, &mut arena.0); + std::mem::swap(expr_arena, &mut arena.1); + out +} diff --git a/crates/polars-plan/src/plans/visitor/mod.rs b/crates/polars-plan/src/plans/visitor/mod.rs new file mode 100644 index 000000000000..1ff6deba8c84 --- /dev/null +++ b/crates/polars-plan/src/plans/visitor/mod.rs @@ -0,0 +1,38 @@ +//! Defines different visitor patterns and for any tree. + +use arrow::legacy::error::PolarsResult; +mod expr; +#[cfg(feature = "cse")] +mod hash; +mod lp; +mod visitors; + +pub use expr::*; +pub use lp::*; +pub use visitors::*; + +/// Controls how the [`TreeWalker`] recursion should proceed for [`TreeWalker::visit`]. +#[derive(Debug)] +pub enum VisitRecursion { + /// Continue the visit to this node tree. + Continue, + /// Keep recursive but skip applying op on the children + Skip, + /// Stop the visit to this node tree. + Stop, +} + +/// Controls how the [`TreeWalker`] recursion should proceed for [`TreeWalker::rewrite`]. +#[derive(Debug)] +pub enum RewriteRecursion { + /// Continue the visit to this node and children. + MutateAndContinue, + /// Don't mutate this node, continue visiting the children + NoMutateAndContinue, + /// Stop and return. + /// This doesn't visit the children + Stop, + /// Call `op` immediately and return + /// This doesn't visit the children + MutateAndStop, +} diff --git a/crates/polars-plan/src/plans/visitor/visitors.rs b/crates/polars-plan/src/plans/visitor/visitors.rs new file mode 100644 index 000000000000..1e23fa5cac8f --- /dev/null +++ b/crates/polars-plan/src/plans/visitor/visitors.rs @@ -0,0 +1,108 @@ +use recursive::recursive; + +use super::*; + +/// An implementor of this trait decides how and in which order its nodes get traversed +/// Implemented for [`crate::dsl::Expr`] and [`AexprNode`]. +pub trait TreeWalker: Sized { + type Arena; + fn apply_children PolarsResult>( + &self, + op: &mut F, + arena: &Self::Arena, + ) -> PolarsResult; + + fn map_children PolarsResult>( + self, + op: &mut F, + arena: &mut Self::Arena, + ) -> PolarsResult; + + /// Walks all nodes in depth-first-order. + #[recursive] + fn visit>( + &self, + visitor: &mut V, + arena: &Self::Arena, + ) -> PolarsResult { + match visitor.pre_visit(self, arena)? { + VisitRecursion::Continue => {}, + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + }; + + match self.apply_children(&mut |node, arena| node.visit(visitor, arena), arena)? { + // let the recursion continue + VisitRecursion::Continue | VisitRecursion::Skip => {}, + // If the recursion should stop, no further post visit will be performed + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + + visitor.post_visit(self, arena) + } + + #[recursive] + fn rewrite>( + self, + rewriter: &mut R, + arena: &mut Self::Arena, + ) -> PolarsResult { + let mutate_this_node = match rewriter.pre_visit(&self, arena)? { + RewriteRecursion::MutateAndStop => return rewriter.mutate(self, arena), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::MutateAndContinue => true, + RewriteRecursion::NoMutateAndContinue => false, + }; + + let after_applied_children = + self.map_children(&mut |node, arena| node.rewrite(rewriter, arena), arena)?; + + if mutate_this_node { + rewriter.mutate(after_applied_children, arena) + } else { + Ok(after_applied_children) + } + } +} + +pub trait Visitor { + type Node; + type Arena; + + /// Invoked before any children of `node` are visited. + fn pre_visit( + &mut self, + _node: &Self::Node, + _arena: &Self::Arena, + ) -> PolarsResult { + Ok(VisitRecursion::Continue) + } + + /// Invoked after all children of `node` are visited. Default + /// implementation does nothing. + fn post_visit( + &mut self, + _node: &Self::Node, + _arena: &Self::Arena, + ) -> PolarsResult { + Ok(VisitRecursion::Continue) + } +} + +pub trait RewritingVisitor { + type Node; + type Arena; + + /// Invoked before any children of `node` are visited. + fn pre_visit( + &mut self, + _node: &Self::Node, + _arena: &mut Self::Arena, + ) -> PolarsResult { + Ok(RewriteRecursion::MutateAndContinue) + } + + fn mutate(&mut self, node: Self::Node, _arena: &mut Self::Arena) -> PolarsResult; +} diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs new file mode 100644 index 000000000000..d90e032cc925 --- /dev/null +++ b/crates/polars-plan/src/prelude.rs @@ -0,0 +1,18 @@ +pub(crate) use polars_ops::prelude::*; +#[cfg(feature = "temporal")] +pub(crate) use polars_time::in_nanoseconds_window; +#[cfg(any( + feature = "temporal", + feature = "dtype-duration", + feature = "dtype-date", + feature = "dtype-time" +))] +pub(crate) use polars_time::prelude::*; +pub use polars_utils::arena::{Arena, Node}; + +pub use crate::dsl::*; +#[cfg(feature = "debugging")] +pub use crate::plans::debug::*; +pub use crate::plans::options::*; +pub use crate::plans::*; +pub use crate::utils::*; diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs new file mode 100644 index 000000000000..2715d1a7c569 --- /dev/null +++ b/crates/polars-plan/src/utils.rs @@ -0,0 +1,368 @@ +use std::fmt::Formatter; +use std::iter::FlatMap; + +use polars_core::prelude::*; + +use self::visitor::{AexprNode, RewritingVisitor, TreeWalker}; +use crate::constants::get_len_name; +use crate::prelude::*; + +/// Utility to write comma delimited strings +pub fn comma_delimited(mut s: String, items: &[S]) -> String +where + S: AsRef, +{ + s.push('('); + for c in items { + s.push_str(c.as_ref()); + s.push_str(", "); + } + s.pop(); + s.pop(); + s.push(')'); + s +} + +/// Utility to write comma delimited +pub(crate) fn fmt_column_delimited>( + f: &mut Formatter<'_>, + items: &[S], + container_start: &str, + container_end: &str, +) -> std::fmt::Result { + write!(f, "{container_start}")?; + for (i, c) in items.iter().enumerate() { + write!(f, "{}", c.as_ref())?; + if i != (items.len() - 1) { + write!(f, ", ")?; + } + } + write!(f, "{container_end}") +} + +pub(crate) fn is_scan(plan: &IR) -> bool { + matches!(plan, IR::Scan { .. } | IR::DataFrameScan { .. }) +} + +/// A projection that only takes a column or a column + alias. +#[cfg(feature = "meta")] +pub(crate) fn aexpr_is_simple_projection(current_node: Node, arena: &Arena) -> bool { + arena + .iter(current_node) + .all(|(_node, e)| matches!(e, AExpr::Column(_) | AExpr::Alias(_, _))) +} + +pub fn has_aexpr(current_node: Node, arena: &Arena, matches: F) -> bool +where + F: Fn(&AExpr) -> bool, +{ + arena.iter(current_node).any(|(_node, e)| matches(e)) +} + +pub fn has_aexpr_window(current_node: Node, arena: &Arena) -> bool { + has_aexpr(current_node, arena, |e| matches!(e, AExpr::Window { .. })) +} + +pub fn has_aexpr_literal(current_node: Node, arena: &Arena) -> bool { + has_aexpr(current_node, arena, |e| matches!(e, AExpr::Literal(_))) +} + +/// Can check if an expression tree has a matching_expr. This +/// requires a dummy expression to be created that will be used to pattern match against. +pub fn has_expr(current_expr: &Expr, matches: F) -> bool +where + F: Fn(&Expr) -> bool, +{ + current_expr.into_iter().any(matches) +} + +/// Check if expression is independent from any column. +pub(crate) fn is_column_independent(expr: &Expr) -> bool { + !expr.into_iter().any(|e| match e { + Expr::Nth(_) + | Expr::Column(_) + | Expr::Columns(_) + | Expr::DtypeColumn(_) + | Expr::IndexColumn(_) + | Expr::Wildcard + | Expr::Len + | Expr::SubPlan(..) + | Expr::Selector(_) => true, + + #[cfg(feature = "dtype-struct")] + Expr::Field(_) => true, + _ => false, + }) +} + +pub fn has_null(current_expr: &Expr) -> bool { + has_expr( + current_expr, + |e| matches!(e, Expr::Literal(LiteralValue::Scalar(sc)) if sc.is_null()), + ) +} + +pub fn aexpr_output_name(node: Node, arena: &Arena) -> PolarsResult { + for (_, ae) in arena.iter(node) { + match ae { + // don't follow the partition by branch + AExpr::Window { function, .. } => return aexpr_output_name(*function, arena), + AExpr::Column(name) => return Ok(name.clone()), + AExpr::Alias(_, name) => return Ok(name.clone()), + AExpr::Len => return Ok(get_len_name()), + AExpr::Literal(val) => return Ok(val.output_column_name().clone()), + AExpr::Ternary { truthy, .. } => return aexpr_output_name(*truthy, arena), + _ => {}, + } + } + let expr = node_to_expr(node, arena); + polars_bail!( + ComputeError: + "unable to find root column name for expr '{expr:?}' when calling 'output_name'", + ); +} + +/// output name of expr +pub fn expr_output_name(expr: &Expr) -> PolarsResult { + for e in expr { + match e { + // don't follow the partition by branch + Expr::Window { function, .. } => return expr_output_name(function), + Expr::Column(name) => return Ok(name.clone()), + Expr::Alias(_, name) => return Ok(name.clone()), + Expr::KeepName(_) | Expr::Wildcard | Expr::RenameAlias { .. } => polars_bail!( + ComputeError: + "cannot determine output column without a context for this expression" + ), + Expr::Columns(_) | Expr::DtypeColumn(_) | Expr::IndexColumn(_) => polars_bail!( + ComputeError: + "this expression may produce multiple output names" + ), + Expr::Len => return Ok(get_len_name()), + Expr::Literal(val) => return Ok(val.output_column_name().clone()), + _ => {}, + } + } + polars_bail!( + ComputeError: + "unable to find root column name for expr '{expr:?}' when calling 'output_name'", + ); +} + +/// This function should be used to find the name of the start of an expression +/// Normal iteration would just return the first root column it found +pub(crate) fn get_single_leaf(expr: &Expr) -> PolarsResult { + for e in expr { + match e { + Expr::Filter { input, .. } => return get_single_leaf(input), + Expr::Gather { expr, .. } => return get_single_leaf(expr), + Expr::SortBy { expr, .. } => return get_single_leaf(expr), + Expr::Window { function, .. } => return get_single_leaf(function), + Expr::Column(name) => return Ok(name.clone()), + Expr::Len => return Ok(get_len_name()), + _ => {}, + } + } + polars_bail!( + ComputeError: "unable to find a single leaf column in expr {:?}", expr + ); +} + +#[allow(clippy::type_complexity)] +pub fn expr_to_leaf_column_names_iter(expr: &Expr) -> impl Iterator + '_ { + expr_to_leaf_column_exprs_iter(expr).flat_map(|e| expr_to_leaf_column_name(e).ok()) +} + +/// This should gradually replace expr_to_root_column as this will get all names in the tree. +pub fn expr_to_leaf_column_names(expr: &Expr) -> Vec { + expr_to_leaf_column_names_iter(expr).collect() +} + +/// unpack alias(col) to name of the root column name +pub fn expr_to_leaf_column_name(expr: &Expr) -> PolarsResult { + let mut leaves = expr_to_leaf_column_exprs_iter(expr).collect::>(); + polars_ensure!(leaves.len() <= 1, ComputeError: "found more than one root column name"); + match leaves.pop() { + Some(Expr::Column(name)) => Ok(name.clone()), + Some(Expr::Wildcard) => polars_bail!( + ComputeError: "wildcard has no root column name", + ), + Some(_) => unreachable!(), + None => polars_bail!( + ComputeError: "no root column name found", + ), + } +} + +#[allow(clippy::type_complexity)] +pub(crate) fn aexpr_to_column_nodes_iter<'a>( + root: Node, + arena: &'a Arena, +) -> FlatMap, Option, fn((Node, &'a AExpr)) -> Option> { + arena.iter(root).flat_map(|(node, ae)| { + if matches!(ae, AExpr::Column(_)) { + Some(ColumnNode(node)) + } else { + None + } + }) +} + +pub fn column_node_to_name(node: ColumnNode, arena: &Arena) -> &PlSmallStr { + if let AExpr::Column(name) = arena.get(node.0) { + name + } else { + unreachable!() + } +} + +/// Get all leaf column expressions in the expression tree. +pub(crate) fn expr_to_leaf_column_exprs_iter(expr: &Expr) -> impl Iterator { + expr.into_iter().flat_map(|e| match e { + Expr::Column(_) | Expr::Wildcard => Some(e), + _ => None, + }) +} + +/// Take a list of expressions and a schema and determine the output schema. +pub fn expressions_to_schema( + expr: &[Expr], + schema: &Schema, + ctxt: Context, +) -> PolarsResult { + let mut expr_arena = Arena::with_capacity(4 * expr.len()); + expr.iter() + .map(|expr| { + let mut field = expr.to_field_amortized(schema, ctxt, &mut expr_arena)?; + + field.dtype = field.dtype.materialize_unknown(true)?; + Ok(field) + }) + .collect() +} + +pub fn aexpr_to_leaf_names_iter( + node: Node, + arena: &Arena, +) -> impl Iterator + '_ { + aexpr_to_column_nodes_iter(node, arena).map(|node| match arena.get(node.0) { + AExpr::Column(name) => name.clone(), + _ => unreachable!(), + }) +} + +pub fn aexpr_to_leaf_names(node: Node, arena: &Arena) -> Vec { + aexpr_to_leaf_names_iter(node, arena).collect() +} + +pub fn aexpr_to_leaf_name(node: Node, arena: &Arena) -> PlSmallStr { + aexpr_to_leaf_names_iter(node, arena).next().unwrap() +} + +/// check if a selection/projection can be done on the downwards schema +pub(crate) fn check_input_node( + node: Node, + input_schema: &Schema, + expr_arena: &Arena, +) -> bool { + aexpr_to_leaf_names_iter(node, expr_arena).all(|name| input_schema.contains(name.as_ref())) +} + +pub(crate) fn check_input_column_node( + node: ColumnNode, + input_schema: &Schema, + expr_arena: &Arena, +) -> bool { + match expr_arena.get(node.0) { + AExpr::Column(name) => input_schema.contains(name.as_ref()), + // Invariant of `ColumnNode` + _ => unreachable!(), + } +} + +pub(crate) fn aexprs_to_schema, K: Into>( + expr: I, + schema: &Schema, + ctxt: Context, + arena: &Arena, +) -> Schema { + expr.into_iter() + .map(|node| { + arena + .get(node.into()) + .to_field(schema, ctxt, arena) + .unwrap() + }) + .collect() +} + +pub(crate) fn expr_irs_to_schema, K: AsRef>( + expr: I, + schema: &Schema, + ctxt: Context, + arena: &Arena, +) -> Schema { + expr.into_iter() + .map(|e| { + let e = e.as_ref(); + let mut field = e.field(schema, ctxt, arena).expect("should be resolved"); + + // TODO! (can this be removed?) + if let Some(name) = e.get_alias() { + field.name = name.clone() + } + field.dtype = field.dtype.materialize_unknown(true).unwrap(); + field + }) + .collect() +} + +/// Concatenate multiple schemas into one, disallowing duplicate field names +pub fn merge_schemas(schemas: &[SchemaRef]) -> PolarsResult { + let schema_size = schemas.iter().map(|schema| schema.len()).sum(); + let mut merged_schema = Schema::with_capacity(schema_size); + + for schema in schemas { + schema.iter().try_for_each(|(name, dtype)| { + if merged_schema.with_column(name.clone(), dtype.clone()).is_none() { + Ok(()) + } else { + Err(polars_err!(Duplicate: "Column with name '{}' has more than one occurrence", name)) + } + })?; + } + + Ok(merged_schema) +} + +/// Rename all reference to the column in `map` with their corresponding new name. +pub fn rename_columns( + node: Node, + expr_arena: &mut Arena, + map: &PlIndexMap, +) -> Node { + struct RenameColumns<'a>(&'a PlIndexMap); + impl RewritingVisitor for RenameColumns<'_> { + type Node = AexprNode; + type Arena = Arena; + + fn mutate( + &mut self, + node: Self::Node, + arena: &mut Self::Arena, + ) -> PolarsResult { + if let AExpr::Column(name) = arena.get(node.node()) { + if let Some(new_name) = self.0.get(name) { + return Ok(AexprNode::new(arena.add(AExpr::Column(new_name.clone())))); + } + } + + Ok(node) + } + } + + AexprNode::new(node) + .rewrite(&mut RenameColumns(map), expr_arena) + .unwrap() + .node() +} diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml new file mode 100644 index 000000000000..f51256572d94 --- /dev/null +++ b/crates/polars-python/Cargo.toml @@ -0,0 +1,280 @@ +[package] +name = "polars-python" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Enable running Polars workloads in Python" + +[dependencies] +polars-compute = { workspace = true } +polars-core = { workspace = true, features = ["python"] } +polars-error = { workspace = true } +polars-expr = { workspace = true } +polars-ffi = { workspace = true } +polars-io = { workspace = true } +polars-lazy = { workspace = true, features = ["python"] } +polars-mem-engine = { workspace = true } +polars-ops = { workspace = true, features = ["bitwise"] } +polars-parquet = { workspace = true, optional = true } +polars-plan = { workspace = true } +polars-row = { workspace = true } +polars-time = { workspace = true } +polars-utils = { workspace = true } + +# TODO! remove this once truly activated. This is required to make sdist building work +# polars-stream = { workspace = true } + +arboard = { workspace = true, optional = true } +arrow = { workspace = true } +bincode = { workspace = true } +bytemuck = { workspace = true } +bytes = { workspace = true } +chrono = { workspace = true } +chrono-tz = { workspace = true } +either = { workspace = true } +flate2 = { workspace = true } +hashbrown = { workspace = true } +itoa = { workspace = true } +libc = { workspace = true } +ndarray = { workspace = true } +num-traits = { workspace = true } +numpy = { workspace = true } +pyo3 = { workspace = true, features = ["abi3-py39", "chrono", "chrono-tz", "multiple-pymethods"] } +rayon = { workspace = true } +recursive = { workspace = true } +serde_json = { workspace = true, optional = true } + +[dependencies.polars] +workspace = true +features = [ + "abs", + "approx_unique", + "array_any_all", + "arg_where", + "bitwise", + "business", + "concat_str", + "cum_agg", + "cumulative_eval", + "dataframe_arithmetic", + "month_start", + "month_end", + "offset_by", + "diagonal_concat", + "diff", + "dot_diagram", + "dot_product", + "dtype-categorical", + "dtype-full", + "dynamic_group_by", + "ewma", + "ewma_by", + "fmt", + "fused", + "interpolate", + "interpolate_by", + "is_first_distinct", + "is_last_distinct", + "is_unique", + "is_between", + "lazy", + "list_eval", + "list_to_struct", + "list_arithmetic", + "array_arithmetic", + "array_to_struct", + "log", + "mode", + "moment", + "ndarray", + "partition_by", + "product", + "random", + "range", + "rank", + "reinterpret", + "replace", + "rolling_window", + "rolling_window_by", + "round_series", + "row_hash", + "rows", + "semi_anti_join", + "serde-lazy", + "string_encoding", + "string_normalize", + "string_reverse", + "string_to_integer", + "string_pad", + "strings", + "temporal", + "to_dummies", + "true_div", + "unique_counts", + "zip_with", + "cov", +] + +[build-dependencies] +version_check = { workspace = true } + +[features] +# Features below are only there to enable building a slim binary during development. +avro = ["polars/avro"] +catalog = ["polars-lazy/catalog"] +parquet = ["polars/parquet", "polars-parquet", "polars-mem-engine/parquet"] +ipc = ["polars/ipc", "polars-mem-engine/ipc"] +ipc_streaming = ["polars/ipc_streaming"] +is_in = ["polars/is_in"] +json = ["polars/serde", "serde_json", "polars/json", "polars-utils/serde", "polars-mem-engine/json"] +trigonometry = ["polars/trigonometry"] +sign = ["polars/sign"] +asof_join = ["polars/asof_join"] +iejoin = ["polars/iejoin"] +cross_join = ["polars/cross_join"] +pct_change = ["polars/pct_change"] +repeat_by = ["polars/repeat_by"] + +streaming = ["polars/streaming"] +meta = ["polars/meta"] +index_of = ["polars/index_of"] +search_sorted = ["polars/search_sorted"] +decompress = ["polars/decompress"] +regex = ["polars/regex"] +csv = ["polars/csv", "polars-mem-engine/csv"] +clipboard = ["arboard"] +extract_jsonpath = ["polars/extract_jsonpath"] +pivot = ["polars/pivot"] +top_k = ["polars/top_k"] +propagate_nans = ["polars/propagate_nans"] +sql = ["polars/sql"] +performant = ["polars/performant"] +timezones = ["polars/timezones"] +cse = ["polars/cse"] +merge_sorted = ["polars/merge_sorted"] +list_gather = ["polars/list_gather"] +list_count = ["polars/list_count"] +array_count = ["polars/array_count", "polars/dtype-array"] +binary_encoding = ["polars/binary_encoding"] +list_sets = ["polars-lazy/list_sets"] +list_any_all = ["polars/list_any_all"] +array_any_all = ["polars/array_any_all", "polars/dtype-array"] +list_drop_nulls = ["polars/list_drop_nulls"] +list_sample = ["polars/list_sample"] +cutqcut = ["polars/cutqcut"] +rle = ["polars/rle"] +extract_groups = ["polars/extract_groups"] +ffi_plugin = ["polars-plan/ffi_plugin"] +cloud = ["polars/cloud", "polars/aws", "polars/gcp", "polars/azure", "polars/http"] +peaks = ["polars/peaks"] +hist = ["polars/hist"] +find_many = ["polars/find_many"] +new_streaming = ["polars-lazy/new_streaming"] +bitwise = ["polars/bitwise"] +approx_unique = ["polars/approx_unique"] +string_normalize = ["polars/string_normalize"] + +dtype-i8 = [] +dtype-i16 = [] +dtype-u8 = [] +dtype-u16 = [] +dtype-i128 = [] +dtype-array = [] +object = ["polars/object"] + +dtypes = [ + "dtype-array", + "dtype-i16", + "dtype-i8", + "dtype-u16", + "dtype-u8", + "dtype-i128", + "object", +] + +operations = [ + "approx_unique", + "array_any_all", + "array_count", + "bitwise", + "is_in", + "repeat_by", + "trigonometry", + "sign", + "performant", + "list_gather", + "list_count", + "list_sets", + "list_any_all", + "list_drop_nulls", + "list_sample", + "cutqcut", + "rle", + "extract_groups", + "pivot", + "extract_jsonpath", + "asof_join", + "cross_join", + "pct_change", + "index_of", + "search_sorted", + "merge_sorted", + "top_k", + "propagate_nans", + "timezones", + "peaks", + "hist", + "find_many", + "string_normalize", +] + +io = [ + "json", + "parquet", + "ipc", + "ipc_streaming", + "avro", + "csv", + "cloud", + "clipboard", +] + +optimizations = [ + "cse", + "polars/fused", + "streaming", +] + +polars_cloud = ["polars/polars_cloud", "polars/ir_serde"] + +# also includes simd +nightly = ["polars/nightly"] + +pymethods = [] + +all = [ + "pymethods", + "optimizations", + "io", + "operations", + "dtypes", + "meta", + "decompress", + "regex", + "sql", + "binary_encoding", + "ffi_plugin", + "polars_cloud", + "new_streaming", +] + +# we cannot conditionally activate simd +# https://github.com/rust-lang/cargo/issues/1197 +# so we have an indirection and compile +# with --no-default-features --features=all for targets without simd +default = [ + "all", +] diff --git a/crates/polars-python/LICENSE b/crates/polars-python/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-python/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-python/README.md b/crates/polars-python/README.md new file mode 100644 index 000000000000..2244b02c543e --- /dev/null +++ b/crates/polars-python/README.md @@ -0,0 +1,7 @@ +# polars-python + +`polars-python` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) +library. It enables running Polars workloads in Python. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main +[Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-python/build.rs b/crates/polars-python/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-python/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-python/src/batched_csv.rs b/crates/polars-python/src/batched_csv.rs new file mode 100644 index 000000000000..16540cd6cc08 --- /dev/null +++ b/crates/polars-python/src/batched_csv.rs @@ -0,0 +1,148 @@ +use std::path::PathBuf; +use std::sync::Mutex; + +use polars::io::RowIndex; +use polars::io::csv::read::OwnedBatchedCsvReader; +use polars::io::mmap::MmapBytesReader; +use polars::prelude::*; +use polars_utils::open_file; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; + +use crate::error::PyPolarsErr; +use crate::utils::EnterPolarsExt; +use crate::{PyDataFrame, Wrap}; + +#[pyclass] +#[repr(transparent)] +pub struct PyBatchedCsv { + reader: Mutex, +} + +#[pymethods] +#[allow(clippy::wrong_self_convention, clippy::should_implement_trait)] +impl PyBatchedCsv { + #[staticmethod] + #[pyo3(signature = ( + infer_schema_length, chunk_size, has_header, ignore_errors, n_rows, skip_rows, skip_lines, + projection, separator, rechunk, columns, encoding, n_threads, path, schema_overrides, + overwrite_dtype_slice, low_memory, comment_prefix, quote_char, null_values, + missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header, row_index, + eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma) + )] + fn new( + infer_schema_length: Option, + chunk_size: usize, + has_header: bool, + ignore_errors: bool, + n_rows: Option, + skip_rows: usize, + skip_lines: usize, + projection: Option>, + separator: &str, + rechunk: bool, + columns: Option>, + encoding: Wrap, + n_threads: Option, + path: PathBuf, + schema_overrides: Option)>>, + overwrite_dtype_slice: Option>>, + low_memory: bool, + comment_prefix: Option<&str>, + quote_char: Option<&str>, + null_values: Option>, + missing_utf8_is_empty_string: bool, + try_parse_dates: bool, + skip_rows_after_header: usize, + row_index: Option<(String, IdxSize)>, + eol_char: &str, + raise_if_empty: bool, + truncate_ragged_lines: bool, + decimal_comma: bool, + ) -> PyResult { + let null_values = null_values.map(|w| w.0); + let eol_char = eol_char.as_bytes()[0]; + let row_index = row_index.map(|(name, offset)| RowIndex { + name: name.into(), + offset, + }); + let quote_char = if let Some(s) = quote_char { + if s.is_empty() { + None + } else { + Some(s.as_bytes()[0]) + } + } else { + None + }; + + let schema_overrides = schema_overrides.map(|overwrite_dtype| { + overwrite_dtype + .iter() + .map(|(name, dtype)| { + let dtype = dtype.0.clone(); + Field::new((&**name).into(), dtype) + }) + .collect::() + }); + + let overwrite_dtype_slice = overwrite_dtype_slice.map(|overwrite_dtype| { + overwrite_dtype + .iter() + .map(|dt| dt.0.clone()) + .collect::>() + }); + + let file = open_file(&path).map_err(PyPolarsErr::from)?; + let reader = Box::new(file) as Box; + let reader = CsvReadOptions::default() + .with_infer_schema_length(infer_schema_length) + .with_has_header(has_header) + .with_n_rows(n_rows) + .with_skip_rows(skip_rows) + .with_skip_rows(skip_lines) + .with_ignore_errors(ignore_errors) + .with_projection(projection.map(Arc::new)) + .with_rechunk(rechunk) + .with_chunk_size(chunk_size) + .with_columns(columns.map(|x| x.into_iter().map(PlSmallStr::from_string).collect())) + .with_n_threads(n_threads) + .with_dtype_overwrite(overwrite_dtype_slice.map(Arc::new)) + .with_low_memory(low_memory) + .with_schema_overwrite(schema_overrides.map(Arc::new)) + .with_skip_rows_after_header(skip_rows_after_header) + .with_row_index(row_index) + .with_raise_if_empty(raise_if_empty) + .with_parse_options( + CsvParseOptions::default() + .with_separator(separator.as_bytes()[0]) + .with_encoding(encoding.0) + .with_missing_is_null(!missing_utf8_is_empty_string) + .with_comment_prefix(comment_prefix) + .with_null_values(null_values) + .with_try_parse_dates(try_parse_dates) + .with_quote_char(quote_char) + .with_eol_char(eol_char) + .with_truncate_ragged_lines(truncate_ragged_lines) + .with_decimal_comma(decimal_comma), + ) + .into_reader_with_file_handle(reader); + + let reader = reader.batched(None).map_err(PyPolarsErr::from)?; + + Ok(PyBatchedCsv { + reader: Mutex::new(reader), + }) + } + + fn next_batches(&self, py: Python, n: usize) -> PyResult>> { + let reader = &self.reader; + let batches = py.enter_polars(move || reader.lock().unwrap().next_batches(n))?; + + // SAFETY: same memory layout + let batches = unsafe { + std::mem::transmute::>, Option>>(batches) + }; + Ok(batches) + } +} diff --git a/crates/polars-python/src/catalog/mod.rs b/crates/polars-python/src/catalog/mod.rs new file mode 100644 index 000000000000..1816b7c87322 --- /dev/null +++ b/crates/polars-python/src/catalog/mod.rs @@ -0,0 +1 @@ +pub mod unity; diff --git a/crates/polars-python/src/catalog/unity.rs b/crates/polars-python/src/catalog/unity.rs new file mode 100644 index 000000000000..dce6922750c1 --- /dev/null +++ b/crates/polars-python/src/catalog/unity.rs @@ -0,0 +1,603 @@ +use std::str::FromStr; + +use polars::prelude::{LazyFrame, PlHashMap, PlSmallStr, Schema}; +use polars_io::catalog::unity::client::{CatalogClient, CatalogClientBuilder}; +use polars_io::catalog::unity::models::{ + CatalogInfo, ColumnInfo, DataSourceFormat, NamespaceInfo, TableInfo, TableType, +}; +use polars_io::catalog::unity::schema::parse_type_json_str; +use polars_io::cloud::credential_provider::PlCredentialProvider; +use polars_io::pl_async; +use pyo3::exceptions::PyValueError; +use pyo3::sync::GILOnceCell; +use pyo3::types::{PyAnyMethods, PyDict, PyList, PyNone, PyTuple}; +use pyo3::{Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python, pyclass, pymethods}; + +use crate::lazyframe::PyLazyFrame; +use crate::prelude::{Wrap, parse_cloud_options}; +use crate::utils::{EnterPolarsExt, to_py_err}; + +macro_rules! pydict_insert_keys { + ($dict:expr, {$a:expr}) => { + $dict.set_item(stringify!($a), $a)?; + }; + + ($dict:expr, {$a:expr, $($args:expr),+}) => { + pydict_insert_keys!($dict, { $a }); + pydict_insert_keys!($dict, { $($args),+ }); + }; + + ($dict:expr, {$a:expr, $($args:expr),+,}) => { + pydict_insert_keys!($dict, {$a, $($args),+}); + }; +} + +// Result dataclasses. These are initialized from Python by calling [`PyCatalogClient::init_classes`]. + +static CATALOG_INFO_CLS: GILOnceCell> = GILOnceCell::new(); +static NAMESPACE_INFO_CLS: GILOnceCell> = GILOnceCell::new(); +static TABLE_INFO_CLS: GILOnceCell> = GILOnceCell::new(); +static COLUMN_INFO_CLS: GILOnceCell> = GILOnceCell::new(); + +#[pyclass] +pub struct PyCatalogClient(CatalogClient); + +#[pymethods] +impl PyCatalogClient { + #[pyo3(signature = (workspace_url, bearer_token))] + #[staticmethod] + pub fn new(workspace_url: String, bearer_token: Option) -> PyResult { + let builder = CatalogClientBuilder::new().with_workspace_url(workspace_url); + + let builder = if let Some(bearer_token) = bearer_token { + builder.with_bearer_token(bearer_token) + } else { + builder + }; + + builder.build().map(PyCatalogClient).map_err(to_py_err) + } + + pub fn list_catalogs(&self, py: Python) -> PyResult { + let v = py.enter_polars(|| { + pl_async::get_runtime().block_in_place_on(self.client().list_catalogs()) + })?; + + let mut opt_err = None; + + let out = PyList::new( + py, + v.into_iter().map(|x| { + let v = catalog_info_to_pyobject(py, x); + if let Ok(v) = v { + Some(v) + } else { + opt_err.replace(v); + None + } + }), + )?; + + opt_err.transpose()?; + + Ok(out.into()) + } + + #[pyo3(signature = (catalog_name))] + pub fn list_namespaces(&self, py: Python, catalog_name: &str) -> PyResult { + let v = py.enter_polars(|| { + pl_async::get_runtime().block_in_place_on(self.client().list_namespaces(catalog_name)) + })?; + + let mut opt_err = None; + + let out = PyList::new( + py, + v.into_iter().map(|x| { + let v = namespace_info_to_pyobject(py, x); + match v { + Ok(v) => Some(v), + Err(_) => { + opt_err.replace(v); + None + }, + } + }), + )?; + + opt_err.transpose()?; + + Ok(out.into()) + } + + #[pyo3(signature = (catalog_name, namespace))] + pub fn list_tables( + &self, + py: Python, + catalog_name: &str, + namespace: &str, + ) -> PyResult { + let v = py.enter_polars(|| { + pl_async::get_runtime() + .block_in_place_on(self.client().list_tables(catalog_name, namespace)) + })?; + + let mut opt_err = None; + + let out = PyList::new( + py, + v.into_iter().map(|table_info| { + let v = table_info_to_pyobject(py, table_info); + + if let Ok(v) = v { + Some(v) + } else { + opt_err.replace(v); + None + } + }), + )? + .into(); + + opt_err.transpose()?; + + Ok(out) + } + + #[pyo3(signature = (table_name, catalog_name, namespace))] + pub fn get_table_info( + &self, + py: Python, + table_name: &str, + catalog_name: &str, + namespace: &str, + ) -> PyResult { + let table_info = py + .enter_polars(|| { + pl_async::get_runtime().block_in_place_on(self.client().get_table_info( + table_name, + catalog_name, + namespace, + )) + }) + .map_err(to_py_err)?; + + table_info_to_pyobject(py, table_info).map(|x| x.into()) + } + + #[pyo3(signature = (table_id, write))] + pub fn get_table_credentials( + &self, + py: Python, + table_id: &str, + write: bool, + ) -> PyResult { + let table_credentials = py + .enter_polars(|| { + pl_async::get_runtime() + .block_in_place_on(self.client().get_table_credentials(table_id, write)) + }) + .map_err(to_py_err)?; + + let expiry = table_credentials.expiration_time; + + let credentials = PyDict::new(py); + // Keys in here are intended to be injected into `storage_options` from the Python side. + // Note this currently really only exists for `aws_endpoint_url`. + let storage_update_options = PyDict::new(py); + + { + use TableCredentialsVariants::*; + use polars_io::catalog::unity::models::{ + TableCredentialsAws, TableCredentialsAzure, TableCredentialsGcp, + TableCredentialsVariants, + }; + + match table_credentials.into_enum() { + Some(Aws(TableCredentialsAws { + access_key_id, + secret_access_key, + session_token, + access_point, + })) => { + credentials.set_item("aws_access_key_id", access_key_id)?; + credentials.set_item("aws_secret_access_key", secret_access_key)?; + + if let Some(session_token) = session_token { + credentials.set_item("aws_session_token", session_token)?; + } + + if let Some(access_point) = access_point { + storage_update_options.set_item("aws_endpoint_url", access_point)?; + } + }, + Some(Azure(TableCredentialsAzure { sas_token })) => { + credentials.set_item("sas_token", sas_token)?; + }, + Some(Gcp(TableCredentialsGcp { oauth_token })) => { + credentials.set_item("bearer_token", oauth_token)?; + }, + None => {}, + } + } + + let credentials = if credentials.len()? > 0 { + credentials.into_any() + } else { + PyNone::get(py).as_any().clone() + }; + let storage_update_options = storage_update_options.into_any(); + let expiry = expiry.into_pyobject(py)?.into_any(); + + Ok(PyTuple::new(py, [credentials, storage_update_options, expiry])?.into()) + } + + #[pyo3(signature = (catalog_name, namespace, table_name, cloud_options, credential_provider, retries))] + pub fn scan_table( + &self, + py: Python, + catalog_name: &str, + namespace: &str, + table_name: &str, + cloud_options: Option>, + credential_provider: Option, + retries: usize, + ) -> PyResult { + let table_info = py.enter_polars(|| { + pl_async::get_runtime().block_in_place_on(self.client().get_table_info( + catalog_name, + namespace, + table_name, + )) + })?; + + let Some(storage_location) = table_info.storage_location.as_deref() else { + return Err(PyValueError::new_err( + "cannot scan catalog table: no storage_location found", + )); + }; + + let cloud_options = + parse_cloud_options(storage_location, cloud_options.unwrap_or_default())? + .with_max_retries(retries) + .with_credential_provider( + credential_provider.map(PlCredentialProvider::from_python_builder), + ); + + Ok( + LazyFrame::scan_catalog_table(&table_info, Some(cloud_options)) + .map_err(to_py_err)? + .into(), + ) + } + + #[pyo3(signature = (catalog_name, comment, storage_root))] + pub fn create_catalog( + &self, + py: Python, + catalog_name: &str, + comment: Option<&str>, + storage_root: Option<&str>, + ) -> PyResult { + let catalog_info = py + .allow_threads(|| { + pl_async::get_runtime().block_in_place_on(self.client().create_catalog( + catalog_name, + comment, + storage_root, + )) + }) + .map_err(to_py_err)?; + + catalog_info_to_pyobject(py, catalog_info).map(|x| x.into()) + } + + #[pyo3(signature = (catalog_name, force))] + pub fn delete_catalog(&self, py: Python, catalog_name: &str, force: bool) -> PyResult<()> { + py.allow_threads(|| { + pl_async::get_runtime() + .block_in_place_on(self.client().delete_catalog(catalog_name, force)) + }) + .map_err(to_py_err) + } + + #[pyo3(signature = (catalog_name, namespace, comment, storage_root))] + pub fn create_namespace( + &self, + py: Python, + catalog_name: &str, + namespace: &str, + comment: Option<&str>, + storage_root: Option<&str>, + ) -> PyResult { + let namespace_info = py + .allow_threads(|| { + pl_async::get_runtime().block_in_place_on(self.client().create_namespace( + catalog_name, + namespace, + comment, + storage_root, + )) + }) + .map_err(to_py_err)?; + + namespace_info_to_pyobject(py, namespace_info).map(|x| x.into()) + } + + #[pyo3(signature = (catalog_name, namespace, force))] + pub fn delete_namespace( + &self, + py: Python, + catalog_name: &str, + namespace: &str, + force: bool, + ) -> PyResult<()> { + py.allow_threads(|| { + pl_async::get_runtime().block_in_place_on(self.client().delete_namespace( + catalog_name, + namespace, + force, + )) + }) + .map_err(to_py_err) + } + + #[pyo3(signature = ( + catalog_name, namespace, table_name, schema, table_type, data_source_format, comment, + storage_root, properties + ))] + pub fn create_table( + &self, + py: Python, + catalog_name: &str, + namespace: &str, + table_name: &str, + schema: Option>, + table_type: &str, + data_source_format: Option<&str>, + comment: Option<&str>, + storage_root: Option<&str>, + properties: Vec<(String, String)>, + ) -> PyResult { + let table_info = py.allow_threads(|| { + pl_async::get_runtime() + .block_in_place_on( + self.client().create_table( + catalog_name, + namespace, + table_name, + schema.as_ref().map(|x| &x.0), + &TableType::from_str(table_type) + .map_err(|e| PyValueError::new_err(e.to_string()))?, + data_source_format + .map(DataSourceFormat::from_str) + .transpose() + .map_err(|e| PyValueError::new_err(e.to_string()))? + .as_ref(), + comment, + storage_root, + &mut properties.iter().map(|(a, b)| (a.as_str(), b.as_str())), + ), + ) + .map_err(to_py_err) + })?; + + table_info_to_pyobject(py, table_info).map(|x| x.into()) + } + + #[pyo3(signature = (catalog_name, namespace, table_name))] + pub fn delete_table( + &self, + py: Python, + catalog_name: &str, + namespace: &str, + table_name: &str, + ) -> PyResult<()> { + py.allow_threads(|| { + pl_async::get_runtime().block_in_place_on(self.client().delete_table( + catalog_name, + namespace, + table_name, + )) + }) + .map_err(to_py_err) + } + + #[pyo3(signature = (type_json))] + #[staticmethod] + pub fn type_json_to_polars_type(py: Python, type_json: &str) -> PyResult { + Ok(Wrap(parse_type_json_str(type_json).map_err(to_py_err)?) + .into_pyobject(py)? + .unbind()) + } + + #[pyo3(signature = (catalog_info_cls, namespace_info_cls, table_info_cls, column_info_cls))] + #[staticmethod] + pub fn init_classes( + py: Python, + catalog_info_cls: Py, + namespace_info_cls: Py, + table_info_cls: Py, + column_info_cls: Py, + ) { + CATALOG_INFO_CLS.get_or_init(py, || catalog_info_cls); + NAMESPACE_INFO_CLS.get_or_init(py, || namespace_info_cls); + TABLE_INFO_CLS.get_or_init(py, || table_info_cls); + COLUMN_INFO_CLS.get_or_init(py, || column_info_cls); + } +} + +impl PyCatalogClient { + fn client(&self) -> &CatalogClient { + &self.0 + } +} + +fn catalog_info_to_pyobject( + py: Python, + CatalogInfo { + name, + comment, + storage_location, + properties, + options, + created_at, + created_by, + updated_at, + updated_by, + }: CatalogInfo, +) -> PyResult> { + let dict = PyDict::new(py); + + let properties = properties_to_pyobject(py, properties); + let options = properties_to_pyobject(py, options); + + pydict_insert_keys!(dict, { + name, + comment, + storage_location, + properties, + options, + created_at, + created_by, + updated_at, + updated_by + }); + + CATALOG_INFO_CLS + .get(py) + .unwrap() + .bind(py) + .call((), Some(&dict)) +} + +fn namespace_info_to_pyobject( + py: Python, + NamespaceInfo { + name, + comment, + properties, + storage_location, + created_at, + created_by, + updated_at, + updated_by, + }: NamespaceInfo, +) -> PyResult> { + let dict = PyDict::new(py); + + let properties = properties_to_pyobject(py, properties); + + pydict_insert_keys!(dict, { + name, + comment, + properties, + storage_location, + created_at, + created_by, + updated_at, + updated_by + }); + + NAMESPACE_INFO_CLS + .get(py) + .unwrap() + .bind(py) + .call((), Some(&dict)) +} + +fn table_info_to_pyobject(py: Python, table_info: TableInfo) -> PyResult> { + let TableInfo { + name, + table_id, + table_type, + comment, + storage_location, + data_source_format, + columns, + properties, + created_at, + created_by, + updated_at, + updated_by, + } = table_info; + + let column_info_cls = COLUMN_INFO_CLS.get(py).unwrap().bind(py); + + let columns = columns + .map(|columns| { + columns + .into_iter() + .map( + |ColumnInfo { + name, + type_name, + type_text, + type_json, + position, + comment, + partition_index, + }| { + let dict = PyDict::new(py); + + let name = name.as_str(); + let type_name = type_name.as_str(); + let type_text = type_text.as_str(); + + pydict_insert_keys!(dict, { + name, + type_name, + type_text, + type_json, + position, + comment, + partition_index, + }); + + column_info_cls.call((), Some(&dict)) + }, + ) + .collect::>>() + }) + .transpose()?; + + let dict = PyDict::new(py); + + let data_source_format = data_source_format.map(|x| x.to_string()); + let table_type = table_type.to_string(); + let properties = properties_to_pyobject(py, properties); + + pydict_insert_keys!(dict, { + name, + comment, + table_id, + table_type, + storage_location, + data_source_format, + columns, + properties, + created_at, + created_by, + updated_at, + updated_by, + }); + + TABLE_INFO_CLS + .get(py) + .unwrap() + .bind(py) + .call((), Some(&dict)) +} + +fn properties_to_pyobject( + py: Python, + properties: PlHashMap, +) -> Bound<'_, PyDict> { + let dict = PyDict::new(py); + + for (key, value) in properties.into_iter() { + dict.set_item(key.as_str(), value).unwrap(); + } + + dict +} diff --git a/crates/polars-python/src/cloud.rs b/crates/polars-python/src/cloud.rs new file mode 100644 index 000000000000..3f606d3400e8 --- /dev/null +++ b/crates/polars-python/src/cloud.rs @@ -0,0 +1,94 @@ +use polars_core::error::{PolarsResult, polars_err}; +use polars_expr::state::ExecutionState; +use polars_mem_engine::create_physical_plan; +use polars_plan::plans::{AExpr, IR, IRPlan}; +use polars_plan::prelude::{Arena, Node}; +use polars_utils::pl_serialize; +use pyo3::intern; +use pyo3::prelude::{PyAnyMethods, PyModule, Python, *}; +use pyo3::types::{IntoPyDict, PyBytes}; + +use crate::error::PyPolarsErr; +use crate::lazyframe::visit::NodeTraverser; +use crate::utils::EnterPolarsExt; +use crate::{PyDataFrame, PyLazyFrame}; + +#[pyfunction] +pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python<'_>) -> PyResult> { + let plan = lf.ldf.logical_plan; + let bytes = polars::prelude::prepare_cloud_plan(plan).map_err(PyPolarsErr::from)?; + + Ok(PyBytes::new(py, &bytes)) +} + +/// Take a serialized `IRPlan` and execute it on the GPU engine. +/// +/// This is done as a Python function because the `NodeTraverser` class created for this purpose +/// must exactly match the one expected by the `cudf_polars` package. +#[pyfunction] +pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec, py: Python) -> PyResult { + // Deserialize into IRPlan. + let mut ir_plan: IRPlan = + pl_serialize::deserialize_from_reader::<_, _, false>(ir_plan_ser.as_slice()) + .map_err(PyPolarsErr::from)?; + + // Edit for use with GPU engine. + gpu_post_opt( + py, + ir_plan.lp_top, + &mut ir_plan.lp_arena, + &mut ir_plan.expr_arena, + ) + .map_err(PyPolarsErr::from)?; + + // Convert to physical plan. + let mut physical_plan = create_physical_plan( + ir_plan.lp_top, + &mut ir_plan.lp_arena, + &mut ir_plan.expr_arena, + None, + ) + .map_err(PyPolarsErr::from)?; + + // Execute the plan. + let mut state = ExecutionState::new(); + py.enter_polars_df(|| physical_plan.execute(&mut state)) +} + +/// Prepare the IR for execution by the Polars GPU engine. +fn gpu_post_opt( + py: Python, + root: Node, + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult<()> { + // Get cuDF Python function. + let cudf = PyModule::import(py, intern!(py, "cudf_polars")).unwrap(); + let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap(); + + // Define cuDF config. + let polars = PyModule::import(py, intern!(py, "polars")).unwrap(); + let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap(); + let kwargs = [("raise_on_fail", true)].into_py_dict(py).unwrap(); + let engine = engine.call((), Some(&kwargs)).unwrap(); + + // Define node traverser. + let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena)); + + // Get a copy of the arenas. + let arenas = nt.get_arenas(); + + // Pass the node visitor which allows the Python callback to replace parts of the query plan. + // Remove "cuda" or specify better once we have multiple post-opt callbacks. + let kwargs = [("config", engine)].into_py_dict(py).unwrap(); + lambda + .call((nt,), Some(&kwargs)) + .map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?; + + // Unpack the arena's. + // At this point the `nt` is useless. + std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap()); + std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap()); + + Ok(()) +} diff --git a/crates/polars-python/src/conversion/any_value.rs b/crates/polars-python/src/conversion/any_value.rs new file mode 100644 index 000000000000..29f616f66294 --- /dev/null +++ b/crates/polars-python/src/conversion/any_value.rs @@ -0,0 +1,565 @@ +use std::borrow::{Borrow, Cow}; +use std::sync::{Arc, Mutex}; + +use chrono::{ + DateTime, Datelike, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, Timelike, +}; +use chrono_tz::Tz; +use hashbrown::HashMap; +#[cfg(feature = "object")] +use polars::chunked_array::object::PolarsObjectSafe; +#[cfg(feature = "object")] +use polars::datatypes::OwnedObject; +use polars::datatypes::{DataType, Field, TimeUnit}; +use polars::prelude::{AnyValue, PlSmallStr, Series}; +use polars_core::utils::any_values_to_supertype_and_n_dtypes; +use polars_core::utils::arrow::temporal_conversions::date32_to_date; +use polars_utils::aliases::PlFixedStateQuality; +use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; +use pyo3::types::{ + PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PySequence, PyString, PyTuple, PyType, +}; +use pyo3::{IntoPyObjectExt, intern}; + +use super::datetime::{ + datetime_to_py_object, elapsed_offset_to_timedelta, nanos_since_midnight_to_naivetime, +}; +use super::{ObjectValue, Wrap, decimal_to_digits, struct_dict}; +use crate::error::PyPolarsErr; +use crate::py_modules::{pl_series, pl_utils}; +use crate::series::PySeries; + +impl<'py> IntoPyObject<'py> for Wrap> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + any_value_into_py_object(self.0, py) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + self.clone().into_pyobject(py) + } +} + +impl<'py> FromPyObject<'py> for Wrap> { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + py_object_to_any_value(ob, true, true).map(Wrap) + } +} + +pub(crate) fn any_value_into_py_object<'py>( + av: AnyValue, + py: Python<'py>, +) -> PyResult> { + let utils = pl_utils(py).bind(py); + match av { + AnyValue::UInt8(v) => v.into_bound_py_any(py), + AnyValue::UInt16(v) => v.into_bound_py_any(py), + AnyValue::UInt32(v) => v.into_bound_py_any(py), + AnyValue::UInt64(v) => v.into_bound_py_any(py), + AnyValue::Int8(v) => v.into_bound_py_any(py), + AnyValue::Int16(v) => v.into_bound_py_any(py), + AnyValue::Int32(v) => v.into_bound_py_any(py), + AnyValue::Int64(v) => v.into_bound_py_any(py), + AnyValue::Int128(v) => v.into_bound_py_any(py), + AnyValue::Float32(v) => v.into_bound_py_any(py), + AnyValue::Float64(v) => v.into_bound_py_any(py), + AnyValue::Null => py.None().into_bound_py_any(py), + AnyValue::Boolean(v) => v.into_bound_py_any(py), + AnyValue::String(v) => v.into_bound_py_any(py), + AnyValue::StringOwned(v) => v.into_bound_py_any(py), + AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { + let s = if arr.is_null() { + rev.get(idx) + } else { + unsafe { arr.deref_unchecked().value(idx as usize) } + }; + s.into_bound_py_any(py) + }, + AnyValue::CategoricalOwned(idx, rev, arr) | AnyValue::EnumOwned(idx, rev, arr) => { + let s = if arr.is_null() { + rev.get(idx) + } else { + unsafe { arr.deref_unchecked().value(idx as usize) } + }; + s.into_bound_py_any(py) + }, + AnyValue::Date(v) => { + let date = date32_to_date(v); + date.into_bound_py_any(py) + }, + AnyValue::Datetime(v, time_unit, time_zone) => { + datetime_to_py_object(py, v, time_unit, time_zone) + }, + AnyValue::DatetimeOwned(v, time_unit, time_zone) => { + datetime_to_py_object(py, v, time_unit, time_zone.as_ref().map(AsRef::as_ref)) + }, + AnyValue::Duration(v, time_unit) => { + let time_delta = elapsed_offset_to_timedelta(v, time_unit); + time_delta.into_bound_py_any(py) + }, + AnyValue::Time(v) => nanos_since_midnight_to_naivetime(v).into_bound_py_any(py), + AnyValue::Array(v, _) | AnyValue::List(v) => PySeries::new(v).to_list(py), + ref av @ AnyValue::Struct(_, _, flds) => { + Ok(struct_dict(py, av._iter_struct_av(), flds)?.into_any()) + }, + AnyValue::StructOwned(payload) => { + Ok(struct_dict(py, payload.0.into_iter(), &payload.1)?.into_any()) + }, + #[cfg(feature = "object")] + AnyValue::Object(v) => { + let object = v.as_any().downcast_ref::().unwrap(); + Ok(object.inner.clone_ref(py).into_bound(py)) + }, + #[cfg(feature = "object")] + AnyValue::ObjectOwned(v) => { + let object = v.0.as_any().downcast_ref::().unwrap(); + Ok(object.inner.clone_ref(py).into_bound(py)) + }, + AnyValue::Binary(v) => PyBytes::new(py, v).into_bound_py_any(py), + AnyValue::BinaryOwned(v) => PyBytes::new(py, &v).into_bound_py_any(py), + AnyValue::Decimal(v, scale) => { + let convert = utils.getattr(intern!(py, "to_py_decimal"))?; + const N: usize = 3; + let mut buf = [0_u128; N]; + let n_digits = decimal_to_digits(v.abs(), &mut buf); + let buf = unsafe { + std::slice::from_raw_parts( + buf.as_slice().as_ptr() as *const u8, + N * size_of::(), + ) + }; + let digits = PyTuple::new(py, buf.iter().take(n_digits))?; + convert.call1((v.is_negative() as u8, digits, n_digits, -(scale as i32))) + }, + } +} + +/// Holds a Python type object and implements hashing / equality based on the pointer address of the +/// type object. This is used as a hashtable key instead of only the `usize` pointer value, as we +/// need to hold a ref to the Python type object to keep it alive. +#[derive(Debug)] +pub struct TypeObjectKey { + #[allow(unused)] + type_object: Py, + /// We need to store this in a field for `Borrow` + address: usize, +} + +impl TypeObjectKey { + fn new(type_object: Py) -> Self { + let address = type_object.as_ptr() as usize; + Self { + type_object, + address, + } + } +} + +impl PartialEq for TypeObjectKey { + fn eq(&self, other: &Self) -> bool { + self.address == other.address + } +} + +impl Eq for TypeObjectKey {} + +impl std::borrow::Borrow for TypeObjectKey { + fn borrow(&self) -> &usize { + &self.address + } +} + +impl std::hash::Hash for TypeObjectKey { + fn hash(&self, state: &mut H) { + let v: &usize = self.borrow(); + v.hash(state) + } +} + +type InitFn = for<'py> fn(&Bound<'py, PyAny>, bool) -> PyResult>; +pub(crate) static LUT: Mutex> = + Mutex::new(HashMap::with_hasher(PlFixedStateQuality::with_seed(0))); + +/// Convert a Python object to an [`AnyValue`]. +pub(crate) fn py_object_to_any_value<'py>( + ob: &Bound<'py, PyAny>, + strict: bool, + allow_object: bool, +) -> PyResult> { + // Conversion functions. + fn get_null(_ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + Ok(AnyValue::Null) + } + + fn get_bool(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + let b = ob.extract::()?; + Ok(AnyValue::Boolean(b)) + } + + fn get_int(ob: &Bound<'_, PyAny>, strict: bool) -> PyResult> { + if let Ok(v) = ob.extract::() { + Ok(AnyValue::Int64(v)) + } else if let Ok(v) = ob.extract::() { + Ok(AnyValue::Int128(v)) + } else if !strict { + let f = ob.extract::()?; + Ok(AnyValue::Float64(f)) + } else { + Err(PyOverflowError::new_err(format!( + "int value too large for Polars integer types: {ob}" + ))) + } + } + + fn get_float(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + Ok(AnyValue::Float64(ob.extract::()?)) + } + + fn get_str(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + // Ideally we'd be returning an AnyValue::String(&str) instead, as was + // the case in previous versions of this function. However, if compiling + // with abi3 for versions older than Python 3.10, the APIs that purport + // to return &str actually just encode to UTF-8 as a newly allocated + // PyBytes object, and then return reference to that. So what we're + // doing here isn't any different fundamentally, and the APIs to for + // converting to &str are deprecated in PyO3 0.21. + // + // Once Python 3.10 is the minimum supported version, converting to &str + // will be cheaper, and we should do that. Python 3.9 security updates + // end-of-life is Oct 31, 2025. + Ok(AnyValue::StringOwned(ob.extract::()?.into())) + } + + fn get_bytes<'py>(ob: &Bound<'py, PyAny>, _strict: bool) -> PyResult> { + let value = ob.extract::>()?; + Ok(AnyValue::BinaryOwned(value)) + } + + fn get_date(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + const UNIX_EPOCH: NaiveDate = NaiveDateTime::UNIX_EPOCH.date(); + let date = ob.extract::()?; + let elapsed = date.signed_duration_since(UNIX_EPOCH); + Ok(AnyValue::Date(elapsed.num_days() as i32)) + } + + fn get_datetime(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + let py = ob.py(); + let tzinfo = ob.getattr(intern!(py, "tzinfo"))?; + + if tzinfo.is_none() { + let datetime = ob.extract::()?; + let delta = datetime - NaiveDateTime::UNIX_EPOCH; + let timestamp = delta.num_microseconds().unwrap(); + return Ok(AnyValue::Datetime(timestamp, TimeUnit::Microseconds, None)); + } + + // Try converting `pytz` timezone to `zoneinfo` timezone + let (ob, tzinfo) = if let Some(tz) = tzinfo + .getattr(intern!(py, "zone")) + .ok() + .and_then(|zone| zone.extract::().ok()?.parse::().ok()) + { + let tzinfo = tz.into_pyobject(py)?; + ( + &ob.call_method(intern!(py, "astimezone"), (&tzinfo,), None)?, + tzinfo, + ) + } else { + (ob, tzinfo) + }; + + let (timestamp, tz) = if tzinfo.hasattr(intern!(py, "key"))? { + let datetime = ob.extract::>()?; + let tz = datetime.timezone().name().into(); + if datetime.year() >= 2100 { + // chrono-tz does not support dates after 2100 + // https://github.com/chronotope/chrono-tz/issues/135 + ( + pl_utils(py) + .bind(py) + .getattr(intern!(py, "datetime_to_int"))? + .call1((ob, intern!(py, "us")))? + .extract::()?, + tz, + ) + } else { + let delta = datetime.to_utc() - DateTime::UNIX_EPOCH; + (delta.num_microseconds().unwrap(), tz) + } + } else { + let datetime = ob.extract::>()?; + let delta = datetime.to_utc() - DateTime::UNIX_EPOCH; + (delta.num_microseconds().unwrap(), "UTC".into()) + }; + + Ok(AnyValue::DatetimeOwned( + timestamp, + TimeUnit::Microseconds, + Some(Arc::new(tz)), + )) + } + + fn get_timedelta(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + let timedelta = ob.extract::()?; + if let Some(micros) = timedelta.num_microseconds() { + Ok(AnyValue::Duration(micros, TimeUnit::Microseconds)) + } else { + Ok(AnyValue::Duration( + timedelta.num_milliseconds(), + TimeUnit::Milliseconds, + )) + } + } + + fn get_time(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + let time = ob.extract::()?; + + Ok(AnyValue::Time( + (time.num_seconds_from_midnight() as i64) * 1_000_000_000 + time.nanosecond() as i64, + )) + } + + fn get_decimal(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + fn abs_decimal_from_digits( + digits: impl IntoIterator, + exp: i32, + ) -> Option<(i128, usize)> { + const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1; + let mut v = 0_i128; + for (i, d) in digits.into_iter().map(i128::from).enumerate() { + if i < 38 { + v = v * 10 + d; + } else { + v = v.checked_mul(10).and_then(|v| v.checked_add(d))?; + } + } + // We only support non-negative scale (=> non-positive exponent). + let scale = if exp > 0 { + // The decimal may be in a non-canonical representation, try to fix it first. + v = 10_i128 + .checked_pow(exp as u32) + .and_then(|factor| v.checked_mul(factor))?; + 0 + } else { + (-exp) as usize + }; + // TODO: Do we care for checking if it fits in MAX_ABS_DEC? (if we set precision to None anyway?) + (v <= MAX_ABS_DEC).then_some((v, scale)) + } + + // Note: Using Vec is not the most efficient thing here (input is a tuple) + let (sign, digits, exp): (i8, Vec, i32) = ob + .call_method0(intern!(ob.py(), "as_tuple")) + .unwrap() + .extract() + .unwrap(); + let (mut v, scale) = abs_decimal_from_digits(digits, exp).ok_or_else(|| { + PyErr::from(PyPolarsErr::Other( + "Decimal is too large to fit in Decimal128".into(), + )) + })?; + if sign > 0 { + v = -v; // Won't overflow since -i128::MAX > i128::MIN + } + Ok(AnyValue::Decimal(v, scale)) + } + + fn get_list(ob: &Bound<'_, PyAny>, strict: bool) -> PyResult> { + fn get_list_with_constructor( + ob: &Bound<'_, PyAny>, + strict: bool, + ) -> PyResult> { + // Use the dedicated constructor. + // This constructor is able to go via dedicated type constructors + // so it can be much faster. + let py = ob.py(); + let kwargs = PyDict::new(py); + kwargs.set_item("strict", strict)?; + let s = pl_series(py).call(py, (ob,), Some(&kwargs))?; + get_list_from_series(s.bind(py), strict) + } + + if ob.is_empty()? { + Ok(AnyValue::List(Series::new_empty( + PlSmallStr::EMPTY, + &DataType::Null, + ))) + } else if ob.is_instance_of::() | ob.is_instance_of::() { + const INFER_SCHEMA_LENGTH: usize = 25; + + let list = ob.downcast::()?; + + let mut avs = Vec::with_capacity(INFER_SCHEMA_LENGTH); + let mut iter = list.try_iter()?; + let mut items = Vec::with_capacity(INFER_SCHEMA_LENGTH); + for item in (&mut iter).take(INFER_SCHEMA_LENGTH) { + items.push(item?); + let av = py_object_to_any_value(items.last().unwrap(), strict, true)?; + avs.push(av) + } + let (dtype, n_dtypes) = any_values_to_supertype_and_n_dtypes(&avs) + .map_err(|e| PyTypeError::new_err(e.to_string()))?; + + // This path is only taken if there is no question about the data type. + if dtype.is_primitive() && n_dtypes == 1 { + get_list_with_constructor(ob, strict) + } else { + // Push the rest. + let length = list.len()?; + avs.reserve(length); + let mut rest = Vec::with_capacity(length); + for item in iter { + rest.push(item?); + let av = py_object_to_any_value(rest.last().unwrap(), strict, true)?; + avs.push(av) + } + + let s = Series::from_any_values_and_dtype(PlSmallStr::EMPTY, &avs, &dtype, strict) + .map_err(|e| { + PyTypeError::new_err(format!( + "{e}\n\nHint: Try setting `strict=False` to allow passing data with mixed types." + )) + })?; + Ok(AnyValue::List(s)) + } + } else { + // range will take this branch + get_list_with_constructor(ob, strict) + } + } + + fn get_list_from_series(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + let s = super::get_series(ob)?; + Ok(AnyValue::List(s)) + } + + fn get_struct<'py>(ob: &Bound<'py, PyAny>, strict: bool) -> PyResult> { + let dict = ob.downcast::().unwrap(); + let len = dict.len(); + let mut keys = Vec::with_capacity(len); + let mut vals = Vec::with_capacity(len); + for (k, v) in dict.into_iter() { + let key = k.extract::>()?; + let val = py_object_to_any_value(&v, strict, true)?; + let dtype = val.dtype(); + keys.push(Field::new(key.as_ref().into(), dtype)); + vals.push(val) + } + Ok(AnyValue::StructOwned(Box::new((vals, keys)))) + } + + fn get_object(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult> { + #[cfg(feature = "object")] + { + // This is slow, but hey don't use objects. + let v = &ObjectValue { + inner: ob.clone().unbind(), + }; + Ok(AnyValue::ObjectOwned(OwnedObject(v.to_boxed()))) + } + #[cfg(not(feature = "object"))] + panic!("activate object") + } + + /// Determine which conversion function to use for the given object. + /// + /// Note: This function is only ran if the object's type is not already in the + /// lookup table. + fn get_conversion_function(ob: &Bound<'_, PyAny>, allow_object: bool) -> PyResult { + let py = ob.py(); + if ob.is_none() { + Ok(get_null) + } + // bool must be checked before int because Python bool is an instance of int. + else if ob.is_instance_of::() { + Ok(get_bool) + } else if ob.is_instance_of::() { + Ok(get_int) + } else if ob.is_instance_of::() { + Ok(get_float) + } else if ob.is_instance_of::() { + Ok(get_str) + } else if ob.is_instance_of::() { + Ok(get_bytes) + } else if ob.is_instance_of::() || ob.is_instance_of::() { + Ok(get_list) + } else if ob.is_instance_of::() { + Ok(get_struct) + } else { + let ob_type = ob.get_type(); + let type_name = ob_type.fully_qualified_name()?.to_string(); + match type_name.as_str() { + // Can't use pyo3::types::PyDateTime with abi3-py37 feature, + // so need this workaround instead of `isinstance(ob, datetime)`. + "datetime.date" => Ok(get_date as InitFn), + "datetime.time" => Ok(get_time as InitFn), + "datetime.datetime" => Ok(get_datetime as InitFn), + "datetime.timedelta" => Ok(get_timedelta as InitFn), + "decimal.Decimal" => Ok(get_decimal as InitFn), + "range" => Ok(get_list as InitFn), + _ => { + // Support NumPy scalars. + if ob.extract::().is_ok() || ob.extract::().is_ok() { + return Ok(get_int as InitFn); + } else if ob.extract::().is_ok() { + return Ok(get_float as InitFn); + } + + // Support custom subclasses of datetime/date. + let ancestors = ob_type.getattr(intern!(py, "__mro__"))?; + let ancestors_str_iter = ancestors + .try_iter()? + .map(|b| b.unwrap().str().unwrap().to_string()); + for c in ancestors_str_iter { + match &*c { + // datetime must be checked before date because + // Python datetime is an instance of date. + "" => { + return Ok(get_datetime as InitFn); + }, + "" => return Ok(get_date as InitFn), + "" => return Ok(get_timedelta as InitFn), + "" => return Ok(get_time as InitFn), + _ => (), + } + } + + if allow_object { + Ok(get_object as InitFn) + } else { + Err(PyValueError::new_err(format!("Cannot convert {ob}"))) + } + }, + } + } + } + + let py_type = ob.get_type(); + let py_type_address = py_type.as_ptr() as usize; + + let conversion_func = { + if let Some(cached_func) = LUT.lock().unwrap().get(&py_type_address) { + *cached_func + } else { + let k = TypeObjectKey::new(py_type.clone().unbind()); + assert_eq!(k.address, py_type_address); + + let func = get_conversion_function(ob, allow_object)?; + LUT.lock().unwrap().insert(k, func); + func + } + }; + + conversion_func(ob, strict) +} diff --git a/crates/polars-python/src/conversion/chunked_array.rs b/crates/polars-python/src/conversion/chunked_array.rs new file mode 100644 index 000000000000..fcd9802112fe --- /dev/null +++ b/crates/polars-python/src/conversion/chunked_array.rs @@ -0,0 +1,156 @@ +use chrono::NaiveTime; +use polars_core::utils::arrow::temporal_conversions::date32_to_date; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyNone, PyTuple}; +use pyo3::{BoundObject, intern}; + +use super::datetime::{ + datetime_to_py_object, elapsed_offset_to_timedelta, nanos_since_midnight_to_naivetime, +}; +use super::{decimal_to_digits, struct_dict}; +use crate::prelude::*; +use crate::py_modules::pl_utils; + +impl<'py> IntoPyObject<'py> for &Wrap<&StringChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let iter = self.0.iter(); + PyList::new(py, iter) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap<&BinaryChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let iter = self + .0 + .iter() + .map(|opt_bytes| opt_bytes.map(|bytes| PyBytes::new(py, bytes))); + PyList::new(py, iter) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap<&StructChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + fn into_pyobject(self, py: Python<'py>) -> Result { + let s = self.0.clone().into_series(); + // todo! iterate its chunks and flatten. + // make series::iter() accept a chunk index. + let s = s.rechunk(); + let iter = s.iter().map(|av| match av { + AnyValue::Struct(_, _, flds) => struct_dict(py, av._iter_struct_av(), flds) + .unwrap() + .into_any(), + AnyValue::Null => PyNone::get(py).into_bound().into_any(), + _ => unreachable!(), + }); + + PyList::new(py, iter) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap<&DurationChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let time_unit = self.0.time_unit(); + let iter = self + .0 + .iter() + .map(|opt_v| opt_v.map(|v| elapsed_offset_to_timedelta(v, time_unit))); + PyList::new(py, iter) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap<&DatetimeChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let time_zone = self.0.time_zone().as_ref(); + let time_unit = self.0.time_unit(); + let iter = self.0.iter().map(|opt_v| { + opt_v.map(|v| datetime_to_py_object(py, v, time_unit, time_zone).unwrap()) + }); + PyList::new(py, iter) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap<&TimeChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let iter = time_to_pyobject_iter(self.0); + PyList::new(py, iter) + } +} + +pub(crate) fn time_to_pyobject_iter( + ca: &TimeChunked, +) -> impl '_ + ExactSizeIterator> { + ca.0.iter() + .map(move |opt_v| opt_v.map(nanos_since_midnight_to_naivetime)) +} + +impl<'py> IntoPyObject<'py> for &Wrap<&DateChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let iter = self.0.into_iter().map(|opt_v| opt_v.map(date32_to_date)); + PyList::new(py, iter) + } +} + +impl<'py> IntoPyObject<'py> for &Wrap<&DecimalChunked> { + type Target = PyList; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + fn into_pyobject(self, py: Python<'py>) -> Result { + let iter = decimal_to_pyobject_iter(py, self.0)?; + PyList::new(py, iter) + } +} + +pub(crate) fn decimal_to_pyobject_iter<'py, 'a>( + py: Python<'py>, + ca: &'a DecimalChunked, +) -> PyResult>> + use<'py, 'a>> { + let utils = pl_utils(py).bind(py); + let convert = utils.getattr(intern!(py, "to_py_decimal"))?; + let py_scale = (-(ca.scale() as i32)).into_pyobject(py)?; + // if we don't know precision, the only safe bet is to set it to 39 + let py_precision = ca.precision().unwrap_or(39).into_pyobject(py)?; + Ok(ca.iter().map(move |opt_v| { + opt_v.map(|v| { + // TODO! use AnyValue so that we have a single impl. + const N: usize = 3; + let mut buf = [0_u128; N]; + let n_digits = decimal_to_digits(v.abs(), &mut buf); + let buf = unsafe { + std::slice::from_raw_parts( + buf.as_slice().as_ptr() as *const u8, + N * size_of::(), + ) + }; + let digits = PyTuple::new(py, buf.iter().take(n_digits)).unwrap(); + convert + .call1((v.is_negative() as u8, digits, &py_precision, &py_scale)) + .unwrap() + }) + })) +} diff --git a/crates/polars-python/src/conversion/datetime.rs b/crates/polars-python/src/conversion/datetime.rs new file mode 100644 index 000000000000..69c0d0867ca8 --- /dev/null +++ b/crates/polars-python/src/conversion/datetime.rs @@ -0,0 +1,72 @@ +//! Utilities for converting dates, times, datetimes, and so on. + +use std::str::FromStr; + +use chrono::{DateTime, Datelike, FixedOffset, NaiveDateTime, NaiveTime, TimeDelta, TimeZone as _}; +use chrono_tz::Tz; +use polars::datatypes::TimeUnit; +use polars_core::datatypes::TimeZone; +use pyo3::types::PyAnyMethods; +use pyo3::{Bound, IntoPyObjectExt, PyAny, PyResult, Python, intern}; + +use crate::error::PyPolarsErr; +use crate::py_modules::pl_utils; + +pub fn elapsed_offset_to_timedelta(elapsed: i64, time_unit: TimeUnit) -> TimeDelta { + let (in_second, nano_multiplier) = match time_unit { + TimeUnit::Nanoseconds => (1_000_000_000, 1), + TimeUnit::Microseconds => (1_000_000, 1_000), + TimeUnit::Milliseconds => (1_000, 1_000_000), + }; + let mut elapsed_sec = elapsed / in_second; + let mut elapsed_nanos = nano_multiplier * (elapsed % in_second); + if elapsed_nanos < 0 { + // TimeDelta expects nanos to always be positive. + elapsed_sec -= 1; + elapsed_nanos += 1_000_000_000; + } + TimeDelta::new(elapsed_sec, elapsed_nanos as u32).unwrap() +} + +/// Convert time-units-since-epoch to a more structured object. +pub fn timestamp_to_naive_datetime(since_epoch: i64, time_unit: TimeUnit) -> NaiveDateTime { + NaiveDateTime::UNIX_EPOCH + elapsed_offset_to_timedelta(since_epoch, time_unit) +} + +/// Convert nanoseconds-since-midnight to a more structured object. +pub fn nanos_since_midnight_to_naivetime(nanos_since_midnight: i64) -> NaiveTime { + NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + elapsed_offset_to_timedelta(nanos_since_midnight, TimeUnit::Nanoseconds) +} + +pub fn datetime_to_py_object<'py>( + py: Python<'py>, + v: i64, + tu: TimeUnit, + tz: Option<&TimeZone>, +) -> PyResult> { + if let Some(time_zone) = tz { + if let Ok(tz) = Tz::from_str(time_zone) { + let utc_datetime = DateTime::UNIX_EPOCH + elapsed_offset_to_timedelta(v, tu); + if utc_datetime.year() >= 2100 { + // chrono-tz does not support dates after 2100 + // https://github.com/chronotope/chrono-tz/issues/135 + pl_utils(py) + .bind(py) + .getattr(intern!(py, "to_py_datetime"))? + .call1((v, tu.to_ascii(), time_zone.as_str())) + } else { + let datetime = utc_datetime.with_timezone(&tz); + datetime.into_bound_py_any(py) + } + } else if let Ok(tz) = FixedOffset::from_str(time_zone) { + let naive_datetime = timestamp_to_naive_datetime(v, tu); + let datetime = tz.from_utc_datetime(&naive_datetime); + datetime.into_bound_py_any(py) + } else { + Err(PyPolarsErr::Other(format!("Could not parse timezone: {time_zone}")).into()) + } + } else { + timestamp_to_naive_datetime(v, tu).into_bound_py_any(py) + } +} diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs new file mode 100644 index 000000000000..5d1e20c2cd1c --- /dev/null +++ b/crates/polars-python/src/conversion/mod.rs @@ -0,0 +1,1331 @@ +pub(crate) mod any_value; +pub(crate) mod chunked_array; +mod datetime; + +use std::convert::Infallible; +use std::fmt::{Display, Formatter}; +use std::fs::File; +use std::hash::{Hash, Hasher}; +use std::path::PathBuf; + +#[cfg(feature = "object")] +use polars::chunked_array::object::PolarsObjectSafe; +use polars::frame::row::Row; +#[cfg(feature = "avro")] +use polars::io::avro::AvroCompression; +#[cfg(feature = "cloud")] +use polars::io::cloud::CloudOptions; +use polars::series::ops::NullBehavior; +use polars_core::utils::arrow::array::Array; +use polars_core::utils::arrow::types::NativeType; +use polars_core::utils::materialize_dyn_int; +use polars_lazy::prelude::*; +#[cfg(feature = "parquet")] +use polars_parquet::write::StatisticsOptions; +use polars_plan::dsl::ScanSources; +use polars_utils::mmap::MemSlice; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::total_ord::{TotalEq, TotalHash}; +use pyo3::basic::CompareOp; +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; +use pyo3::types::{PyDict, PyList, PySequence, PyString}; + +use crate::error::PyPolarsErr; +use crate::file::{PythonScanSourceInput, get_python_scan_source_input}; +#[cfg(feature = "object")] +use crate::object::OBJECT_NAME; +use crate::prelude::*; +use crate::py_modules::{pl_series, polars}; +use crate::series::PySeries; +use crate::{PyDataFrame, PyLazyFrame}; + +/// # Safety +/// Should only be implemented for transparent types +pub(crate) unsafe trait Transparent { + type Target; +} + +unsafe impl Transparent for PySeries { + type Target = Series; +} + +unsafe impl Transparent for Wrap { + type Target = T; +} + +unsafe impl Transparent for Option { + type Target = Option; +} + +pub(crate) fn reinterpret_vec(input: Vec) -> Vec { + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + let len = input.len(); + let cap = input.capacity(); + let mut manual_drop_vec = std::mem::ManuallyDrop::new(input); + let vec_ptr: *mut T = manual_drop_vec.as_mut_ptr(); + let ptr: *mut T::Target = vec_ptr as *mut T::Target; + unsafe { Vec::from_raw_parts(ptr, len, cap) } +} + +pub(crate) fn vec_extract_wrapped(buf: Vec>) -> Vec { + reinterpret_vec(buf) +} + +#[repr(transparent)] +pub struct Wrap(pub T); + +impl Clone for Wrap +where + T: Clone, +{ + fn clone(&self) -> Self { + Wrap(self.0.clone()) + } +} +impl From for Wrap { + fn from(t: T) -> Self { + Wrap(t) + } +} + +// extract a Rust DataFrame from a python DataFrame, that is DataFrame> +pub(crate) fn get_df(obj: &Bound<'_, PyAny>) -> PyResult { + let pydf = obj.getattr(intern!(obj.py(), "_df"))?; + Ok(pydf.extract::()?.df) +} + +pub(crate) fn get_lf(obj: &Bound<'_, PyAny>) -> PyResult { + let pydf = obj.getattr(intern!(obj.py(), "_ldf"))?; + Ok(pydf.extract::()?.ldf) +} + +pub(crate) fn get_series(obj: &Bound<'_, PyAny>) -> PyResult { + let s = obj.getattr(intern!(obj.py(), "_s"))?; + Ok(s.extract::()?.series) +} + +pub(crate) fn to_series(py: Python, s: PySeries) -> PyResult> { + let series = pl_series(py).bind(py); + let constructor = series.getattr(intern!(py, "_from_pyseries"))?; + constructor.call1((s,)) +} + +impl<'a> FromPyObject<'a> for Wrap { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + Ok(Wrap((&*ob.extract::()?).into())) + } +} + +#[cfg(feature = "csv")] +impl<'a> FromPyObject<'a> for Wrap { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + if let Ok(s) = ob.extract::() { + Ok(Wrap(NullValues::AllColumnsSingle((&*s).into()))) + } else if let Ok(s) = ob.extract::>() { + Ok(Wrap(NullValues::AllColumns( + s.into_iter().map(|x| (&*x).into()).collect(), + ))) + } else if let Ok(s) = ob.extract::>() { + Ok(Wrap(NullValues::Named( + s.into_iter() + .map(|(a, b)| ((&*a).into(), (&*b).into())) + .collect(), + ))) + } else { + Err( + PyPolarsErr::Other("could not extract value from null_values argument".into()) + .into(), + ) + } + } +} + +fn struct_dict<'py, 'a>( + py: Python<'py>, + vals: impl Iterator>, + flds: &[Field], +) -> PyResult> { + let dict = PyDict::new(py); + flds.iter().zip(vals).try_for_each(|(fld, val)| { + dict.set_item(fld.name().as_str(), Wrap(val).into_pyobject(py)?) + })?; + Ok(dict) +} + +// accept u128 array to ensure alignment is correct +fn decimal_to_digits(v: i128, buf: &mut [u128; 3]) -> usize { + const ZEROS: i128 = 0x3030_3030_3030_3030_3030_3030_3030_3030; + // SAFETY: transmute is safe as there are 48 bytes in 3 128bit ints + // and the minimal alignment of u8 fits u16 + let buf = unsafe { std::mem::transmute::<&mut [u128; 3], &mut [u8; 48]>(buf) }; + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(v); + let len = value.len(); + for (dst, src) in buf.iter_mut().zip(value.as_bytes().iter()) { + *dst = *src + } + + let ptr = buf.as_mut_ptr() as *mut i128; + unsafe { + // this is safe because we know that the buffer is exactly 48 bytes long + *ptr -= ZEROS; + *ptr.add(1) -= ZEROS; + *ptr.add(2) -= ZEROS; + } + len +} + +impl<'py> IntoPyObject<'py> for &Wrap { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let pl = polars(py).bind(py); + + match &self.0 { + DataType::Int8 => { + let class = pl.getattr(intern!(py, "Int8"))?; + class.call0() + }, + DataType::Int16 => { + let class = pl.getattr(intern!(py, "Int16"))?; + class.call0() + }, + DataType::Int32 => { + let class = pl.getattr(intern!(py, "Int32"))?; + class.call0() + }, + DataType::Int64 => { + let class = pl.getattr(intern!(py, "Int64"))?; + class.call0() + }, + DataType::UInt8 => { + let class = pl.getattr(intern!(py, "UInt8"))?; + class.call0() + }, + DataType::UInt16 => { + let class = pl.getattr(intern!(py, "UInt16"))?; + class.call0() + }, + DataType::UInt32 => { + let class = pl.getattr(intern!(py, "UInt32"))?; + class.call0() + }, + DataType::UInt64 => { + let class = pl.getattr(intern!(py, "UInt64"))?; + class.call0() + }, + DataType::Int128 => { + let class = pl.getattr(intern!(py, "Int128"))?; + class.call0() + }, + DataType::Float32 => { + let class = pl.getattr(intern!(py, "Float32"))?; + class.call0() + }, + DataType::Float64 | DataType::Unknown(UnknownKind::Float) => { + let class = pl.getattr(intern!(py, "Float64"))?; + class.call0() + }, + DataType::Decimal(precision, scale) => { + let class = pl.getattr(intern!(py, "Decimal"))?; + let args = (*precision, *scale); + class.call1(args) + }, + DataType::Boolean => { + let class = pl.getattr(intern!(py, "Boolean"))?; + class.call0() + }, + DataType::String | DataType::Unknown(UnknownKind::Str) => { + let class = pl.getattr(intern!(py, "String"))?; + class.call0() + }, + DataType::Binary => { + let class = pl.getattr(intern!(py, "Binary"))?; + class.call0() + }, + DataType::Array(inner, size) => { + let class = pl.getattr(intern!(py, "Array"))?; + let inner = Wrap(*inner.clone()); + let args = (&inner, *size); + class.call1(args) + }, + DataType::List(inner) => { + let class = pl.getattr(intern!(py, "List"))?; + let inner = Wrap(*inner.clone()); + class.call1((&inner,)) + }, + DataType::Date => { + let class = pl.getattr(intern!(py, "Date"))?; + class.call0() + }, + DataType::Datetime(tu, tz) => { + let datetime_class = pl.getattr(intern!(py, "Datetime"))?; + datetime_class.call1((tu.to_ascii(), tz.as_deref())) + }, + DataType::Duration(tu) => { + let duration_class = pl.getattr(intern!(py, "Duration"))?; + duration_class.call1((tu.to_ascii(),)) + }, + #[cfg(feature = "object")] + DataType::Object(_) => { + let class = pl.getattr(intern!(py, "Object"))?; + class.call0() + }, + DataType::Categorical(_, ordering) => { + let class = pl.getattr(intern!(py, "Categorical"))?; + class.call1((Wrap(*ordering),)) + }, + DataType::Enum(rev_map, _) => { + // we should always have an initialized rev_map coming from rust + let categories = rev_map.as_ref().unwrap().get_categories(); + let class = pl.getattr(intern!(py, "Enum"))?; + let s = + Series::from_arrow(PlSmallStr::from_static("category"), categories.to_boxed()) + .map_err(PyPolarsErr::from)?; + let series = to_series(py, s.into())?; + class.call1((series,)) + }, + DataType::Time => pl.getattr(intern!(py, "Time")), + DataType::Struct(fields) => { + let field_class = pl.getattr(intern!(py, "Field"))?; + let iter = fields.iter().map(|fld| { + let name = fld.name().as_str(); + let dtype = Wrap(fld.dtype().clone()); + field_class.call1((name, &dtype)).unwrap() + }); + let fields = PyList::new(py, iter)?; + let struct_class = pl.getattr(intern!(py, "Struct"))?; + struct_class.call1((fields,)) + }, + DataType::Null => { + let class = pl.getattr(intern!(py, "Null"))?; + class.call0() + }, + DataType::Unknown(UnknownKind::Int(v)) => { + Wrap(materialize_dyn_int(*v).dtype()).into_pyobject(py) + }, + DataType::Unknown(_) => { + let class = pl.getattr(intern!(py, "Unknown"))?; + class.call0() + }, + DataType::BinaryOffset => { + unimplemented!() + }, + } + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + let name = ob + .getattr(intern!(py, "name"))? + .str()? + .extract::()?; + let dtype = ob + .getattr(intern!(py, "dtype"))? + .extract::>()?; + Ok(Wrap(Field::new((&*name).into(), dtype.0))) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let py = ob.py(); + let type_name = ob.get_type().qualname()?.to_string(); + + let dtype = match &*type_name { + "DataTypeClass" => { + // just the class, not an object + let name = ob + .getattr(intern!(py, "__name__"))? + .str()? + .extract::()?; + match &*name { + "Int8" => DataType::Int8, + "Int16" => DataType::Int16, + "Int32" => DataType::Int32, + "Int64" => DataType::Int64, + "Int128" => DataType::Int128, + "UInt8" => DataType::UInt8, + "UInt16" => DataType::UInt16, + "UInt32" => DataType::UInt32, + "UInt64" => DataType::UInt64, + "Float32" => DataType::Float32, + "Float64" => DataType::Float64, + "Boolean" => DataType::Boolean, + "String" => DataType::String, + "Binary" => DataType::Binary, + "Categorical" => DataType::Categorical(None, Default::default()), + "Enum" => DataType::Enum(None, Default::default()), + "Date" => DataType::Date, + "Time" => DataType::Time, + "Datetime" => DataType::Datetime(TimeUnit::Microseconds, None), + "Duration" => DataType::Duration(TimeUnit::Microseconds), + "Decimal" => DataType::Decimal(None, None), // "none" scale => "infer" + "List" => DataType::List(Box::new(DataType::Null)), + "Array" => DataType::Array(Box::new(DataType::Null), 0), + "Struct" => DataType::Struct(vec![]), + "Null" => DataType::Null, + #[cfg(feature = "object")] + "Object" => DataType::Object(OBJECT_NAME), + "Unknown" => DataType::Unknown(Default::default()), + dt => { + return Err(PyTypeError::new_err(format!( + "'{dt}' is not a Polars data type", + ))); + }, + } + }, + "Int8" => DataType::Int8, + "Int16" => DataType::Int16, + "Int32" => DataType::Int32, + "Int64" => DataType::Int64, + "Int128" => DataType::Int128, + "UInt8" => DataType::UInt8, + "UInt16" => DataType::UInt16, + "UInt32" => DataType::UInt32, + "UInt64" => DataType::UInt64, + "Float32" => DataType::Float32, + "Float64" => DataType::Float64, + "Boolean" => DataType::Boolean, + "String" => DataType::String, + "Binary" => DataType::Binary, + "Categorical" => { + let ordering = ob.getattr(intern!(py, "ordering")).unwrap(); + let ordering = ordering.extract::>()?.0; + DataType::Categorical(None, ordering) + }, + "Enum" => { + let categories = ob.getattr(intern!(py, "categories")).unwrap(); + let s = get_series(&categories.as_borrowed())?; + let ca = s.str().map_err(PyPolarsErr::from)?; + let categories = ca.downcast_iter().next().unwrap().clone(); + create_enum_dtype(categories) + }, + "Date" => DataType::Date, + "Time" => DataType::Time, + "Datetime" => { + let time_unit = ob.getattr(intern!(py, "time_unit")).unwrap(); + let time_unit = time_unit.extract::>()?.0; + let time_zone = ob.getattr(intern!(py, "time_zone")).unwrap(); + let time_zone = time_zone.extract::>()?; + DataType::Datetime(time_unit, time_zone.as_deref().map(|x| x.into())) + }, + "Duration" => { + let time_unit = ob.getattr(intern!(py, "time_unit")).unwrap(); + let time_unit = time_unit.extract::>()?.0; + DataType::Duration(time_unit) + }, + "Decimal" => { + let precision = ob.getattr(intern!(py, "precision"))?.extract()?; + let scale = ob.getattr(intern!(py, "scale"))?.extract()?; + DataType::Decimal(precision, Some(scale)) + }, + "List" => { + let inner = ob.getattr(intern!(py, "inner")).unwrap(); + let inner = inner.extract::>()?; + DataType::List(Box::new(inner.0)) + }, + "Array" => { + let inner = ob.getattr(intern!(py, "inner")).unwrap(); + let size = ob.getattr(intern!(py, "size")).unwrap(); + let inner = inner.extract::>()?; + let size = size.extract::()?; + DataType::Array(Box::new(inner.0), size) + }, + "Struct" => { + let fields = ob.getattr(intern!(py, "fields"))?; + let fields = fields + .extract::>>()? + .into_iter() + .map(|f| f.0) + .collect::>(); + DataType::Struct(fields) + }, + "Null" => DataType::Null, + #[cfg(feature = "object")] + "Object" => DataType::Object(OBJECT_NAME), + "Unknown" => DataType::Unknown(Default::default()), + dt => { + return Err(PyTypeError::new_err(format!( + "'{dt}' is not a Polars data type", + ))); + }, + }; + Ok(Wrap(dtype)) + } +} + +impl<'py> IntoPyObject<'py> for Wrap { + type Target = PyString; + type Output = Bound<'py, Self::Target>; + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + match self.0 { + CategoricalOrdering::Physical => "physical", + CategoricalOrdering::Lexical => "lexical", + } + .into_pyobject(py) + } +} + +impl<'py> IntoPyObject<'py> for Wrap { + type Target = PyString; + type Output = Bound<'py, Self::Target>; + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + self.0.to_ascii().into_pyobject(py) + } +} + +#[cfg(feature = "parquet")] +impl<'s> FromPyObject<'s> for Wrap { + fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult { + let mut statistics = StatisticsOptions::empty(); + + let dict = ob.downcast::()?; + for (key, val) in dict { + let key = key.extract::()?; + let val = val.extract::()?; + + match key.as_ref() { + "min" => statistics.min_value = val, + "max" => statistics.max_value = val, + "distinct_count" => statistics.distinct_count = val, + "null_count" => statistics.null_count = val, + _ => { + return Err(PyTypeError::new_err(format!( + "'{key}' is not a valid statistic option", + ))); + }, + } + } + + Ok(Wrap(statistics)) + } +} + +impl<'s> FromPyObject<'s> for Wrap> { + fn extract_bound(ob: &Bound<'s, PyAny>) -> PyResult { + let vals = ob.extract::>>>()?; + let vals = reinterpret_vec(vals); + Ok(Wrap(Row(vals))) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let dict = ob.downcast::()?; + + Ok(Wrap( + dict.iter() + .map(|(key, val)| { + let key = key.extract::()?; + let val = val.extract::>()?; + + Ok(Field::new((&*key).into(), val.0)) + }) + .collect::>()?, + )) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let list = ob.downcast::()?.to_owned(); + + if list.is_empty() { + return Ok(Wrap(ScanSources::default())); + } + + enum MutableSources { + Paths(Vec), + Files(Vec), + Buffers(Vec), + } + + let num_items = list.len(); + let mut iter = list + .into_iter() + .map(|val| get_python_scan_source_input(val.unbind(), false)); + + let Some(first) = iter.next() else { + return Ok(Wrap(ScanSources::default())); + }; + + let mut sources = match first? { + PythonScanSourceInput::Path(path) => { + let mut sources = Vec::with_capacity(num_items); + sources.push(path); + MutableSources::Paths(sources) + }, + PythonScanSourceInput::File(file) => { + let mut sources = Vec::with_capacity(num_items); + sources.push(file.into()); + MutableSources::Files(sources) + }, + PythonScanSourceInput::Buffer(buffer) => { + let mut sources = Vec::with_capacity(num_items); + sources.push(buffer); + MutableSources::Buffers(sources) + }, + }; + + for source in iter { + match (&mut sources, source?) { + (MutableSources::Paths(v), PythonScanSourceInput::Path(p)) => v.push(p), + (MutableSources::Files(v), PythonScanSourceInput::File(f)) => v.push(f.into()), + (MutableSources::Buffers(v), PythonScanSourceInput::Buffer(f)) => v.push(f), + _ => { + return Err(PyTypeError::new_err( + "Cannot combine in-memory bytes, paths and files for scan sources", + )); + }, + } + } + + Ok(Wrap(match sources { + MutableSources::Paths(i) => ScanSources::Paths(i.into()), + MutableSources::Files(i) => ScanSources::Files(i.into()), + MutableSources::Buffers(i) => ScanSources::Buffers(i.into()), + })) + } +} + +impl<'py> IntoPyObject<'py> for Wrap<&Schema> { + type Target = PyDict; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let dict = PyDict::new(py); + self.0 + .iter() + .try_for_each(|(k, v)| dict.set_item(k.as_str(), &Wrap(v.clone())))?; + Ok(dict) + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct ObjectValue { + pub inner: PyObject, +} + +impl Clone for ObjectValue { + fn clone(&self) -> Self { + Python::with_gil(|py| Self { + inner: self.inner.clone_ref(py), + }) + } +} + +impl Hash for ObjectValue { + fn hash(&self, state: &mut H) { + let h = Python::with_gil(|py| self.inner.bind(py).hash().expect("should be hashable")); + state.write_isize(h) + } +} + +impl Eq for ObjectValue {} + +impl PartialEq for ObjectValue { + fn eq(&self, other: &Self) -> bool { + Python::with_gil(|py| { + match self + .inner + .bind(py) + .rich_compare(other.inner.bind(py), CompareOp::Eq) + { + Ok(result) => result.is_truthy().unwrap(), + Err(_) => false, + } + }) + } +} + +impl TotalEq for ObjectValue { + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +impl TotalHash for ObjectValue { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } +} + +impl Display for ObjectValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner) + } +} + +#[cfg(feature = "object")] +impl PolarsObject for ObjectValue { + fn type_name() -> &'static str { + "object" + } +} + +impl From for ObjectValue { + fn from(p: PyObject) -> Self { + Self { inner: p } + } +} + +impl<'a> FromPyObject<'a> for ObjectValue { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + Ok(ObjectValue { + inner: ob.to_owned().unbind(), + }) + } +} + +/// # Safety +/// +/// The caller is responsible for checking that val is Object otherwise UB +#[cfg(feature = "object")] +impl From<&dyn PolarsObjectSafe> for &ObjectValue { + fn from(val: &dyn PolarsObjectSafe) -> Self { + unsafe { &*(val as *const dyn PolarsObjectSafe as *const ObjectValue) } + } +} + +impl<'a, 'py> IntoPyObject<'py> for &'a ObjectValue { + type Target = PyAny; + type Output = Borrowed<'a, 'py, Self::Target>; + type Error = std::convert::Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + Ok(self.inner.bind_borrowed(py)) + } +} + +impl Default for ObjectValue { + fn default() -> Self { + Python::with_gil(|py| ObjectValue { inner: py.None() }) + } +} + +impl<'a, T: NativeType + FromPyObject<'a>> FromPyObject<'a> for Wrap> { + fn extract_bound(obj: &Bound<'a, PyAny>) -> PyResult { + let seq = obj.downcast::()?; + let mut v = Vec::with_capacity(seq.len().unwrap_or(0)); + for item in seq.try_iter()? { + v.push(item?.extract::()?); + } + Ok(Wrap(v)) + } +} + +#[cfg(feature = "asof_join")] +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*(ob.extract::()?) { + "backward" => AsofStrategy::Backward, + "forward" => AsofStrategy::Forward, + "nearest" => AsofStrategy::Nearest, + v => { + return Err(PyValueError::new_err(format!( + "asof `strategy` must be one of {{'backward', 'forward', 'nearest'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*(ob.extract::()?) { + "linear" => InterpolationMethod::Linear, + "nearest" => InterpolationMethod::Nearest, + v => { + return Err(PyValueError::new_err(format!( + "interpolation `method` must be one of {{'linear', 'nearest'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +#[cfg(feature = "avro")] +impl<'py> FromPyObject<'py> for Wrap> { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "uncompressed" => None, + "snappy" => Some(AvroCompression::Snappy), + "deflate" => Some(AvroCompression::Deflate), + v => { + return Err(PyValueError::new_err(format!( + "avro `compression` must be one of {{'uncompressed', 'snappy', 'deflate'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "physical" => CategoricalOrdering::Physical, + "lexical" => CategoricalOrdering::Lexical, + v => { + return Err(PyValueError::new_err(format!( + "categorical `ordering` must be one of {{'physical', 'lexical'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "window" => StartBy::WindowBound, + "datapoint" => StartBy::DataPoint, + "monday" => StartBy::Monday, + "tuesday" => StartBy::Tuesday, + "wednesday" => StartBy::Wednesday, + "thursday" => StartBy::Thursday, + "friday" => StartBy::Friday, + "saturday" => StartBy::Saturday, + "sunday" => StartBy::Sunday, + v => { + return Err(PyValueError::new_err(format!( + "`start_by` must be one of {{'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "left" => ClosedWindow::Left, + "right" => ClosedWindow::Right, + "both" => ClosedWindow::Both, + "none" => ClosedWindow::None, + v => { + return Err(PyValueError::new_err(format!( + "`closed` must be one of {{'left', 'right', 'both', 'none'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +#[cfg(feature = "csv")] +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "utf8" => CsvEncoding::Utf8, + "utf8-lossy" => CsvEncoding::LossyUtf8, + v => { + return Err(PyValueError::new_err(format!( + "csv `encoding` must be one of {{'utf8', 'utf8-lossy'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +#[cfg(feature = "ipc")] +impl<'py> FromPyObject<'py> for Wrap> { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "uncompressed" => None, + "lz4" => Some(IpcCompression::LZ4), + "zstd" => Some(IpcCompression::ZSTD), + v => { + return Err(PyValueError::new_err(format!( + "ipc `compression` must be one of {{'uncompressed', 'lz4', 'zstd'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let parsed = match &*ob.extract::()? { + "inner" => JoinType::Inner, + "left" => JoinType::Left, + "right" => JoinType::Right, + "full" => JoinType::Full, + "semi" => JoinType::Semi, + "anti" => JoinType::Anti, + #[cfg(feature = "cross_join")] + "cross" => JoinType::Cross, + v => { + return Err(PyValueError::new_err(format!( + "`how` must be one of {{'inner', 'left', 'full', 'semi', 'anti', 'cross'}}, got {v}", + ))); + }, + }; + Ok(Wrap(parsed)) + } +} + +impl<'py> FromPyObject<'py> for Wrap